Lim Yao Chong
Committed by GitHub

Add Python binding for online punctuation models (#1312)

@@ -91,6 +91,18 @@ python3 ./python-api-examples/add-punctuation.py @@ -91,6 +91,18 @@ python3 ./python-api-examples/add-punctuation.py
91 91
92 rm -rf $repo 92 rm -rf $repo
93 93
  94 +log "test online punctuation"
  95 +
  96 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
  97 +tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
  98 +rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
  99 +repo=sherpa-onnx-online-punct-en-2024-08-06
  100 +ls -lh $repo
  101 +
  102 +python3 ./python-api-examples/add-punctuation-online.py
  103 +
  104 +rm -rf $repo
  105 +
94 log "test audio tagging" 106 log "test audio tagging"
95 107
96 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 108 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
@@ -117,3 +117,4 @@ sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17 @@ -117,3 +117,4 @@ sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17
117 vits-melo-tts-zh_en 117 vits-melo-tts-zh_en
118 *.o 118 *.o
119 *.ppu 119 *.ppu
  120 +sherpa-onnx-online-punct-en-2024-08-06
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This script shows how to add punctuation and casing to text using the sherpa-onnx Python API.
  5 +
  6 +Please download the model from
  7 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models
  8 +
  9 +For example, run the following commands to download and extract the model:
  10 +
  11 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
  12 +tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
  13 +rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
  14 +"""
  15 +
  16 +from pathlib import Path
  17 +
  18 +import sherpa_onnx
  19 +
  20 +
  21 +def main():
  22 + model = "./sherpa-onnx-online-punct-en-2024-08-06/model.onnx"
  23 + bpe = "./sherpa-onnx-online-punct-en-2024-08-06/bpe.vocab"
  24 + if not Path(model).is_file():
  25 + raise ValueError(f"{model} does not exist")
  26 + if not Path(bpe).is_file():
  27 + raise ValueError(f"{bpe} does not exist")
  28 +
  29 + model_config = sherpa_onnx.OnlinePunctuationModelConfig(
  30 + cnn_bilstm=model, bpe_vocab=bpe
  31 + )
  32 + config = sherpa_onnx.OnlinePunctuationConfig(model_config=model_config)
  33 + punct = sherpa_onnx.OnlinePunctuation(config)
  34 +
  35 + texts = [
  36 + "how are you i am fine thank you",
  37 + "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry",
  38 + ]
  39 + for text in texts:
  40 + text_with_punct = punct.add_punctuation_with_case(text)
  41 + print("----------")
  42 + print(f"input : {text}")
  43 + print(f"output: {text_with_punct}")
  44 + print("----------")
  45 +
  46 +
  47 +if __name__ == "__main__":
  48 + main()
@@ -27,6 +27,7 @@ set(srcs @@ -27,6 +27,7 @@ set(srcs
27 online-model-config.cc 27 online-model-config.cc
28 online-nemo-ctc-model-config.cc 28 online-nemo-ctc-model-config.cc
29 online-paraformer-model-config.cc 29 online-paraformer-model-config.cc
  30 + online-punctuation.cc
30 online-recognizer.cc 31 online-recognizer.cc
31 online-stream.cc 32 online-stream.cc
32 online-transducer-model-config.cc 33 online-transducer-model-config.cc
  1 +// sherpa-onnx/python/csrc/online-punctuation.cc
  2 +//
  3 +// Copyright (c) 2024
  4 +
  5 +#include "sherpa-onnx/python/csrc/online-punctuation.h"
  6 +
  7 +#include "sherpa-onnx/csrc/online-punctuation.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +static void PybindOnlinePunctuationModelConfig(py::module *m) {
  12 + using PyClass = OnlinePunctuationModelConfig;
  13 + py::class_<PyClass>(*m, "OnlinePunctuationModelConfig")
  14 + .def(py::init<>())
  15 + .def(py::init<const std::string &, const std::string &, int32_t, bool, const std::string &>(),
  16 + py::arg("cnn_bilstm"), py::arg("bpe_vocab"), py::arg("num_threads") = 1,
  17 + py::arg("debug") = false, py::arg("provider") = "cpu")
  18 + .def_readwrite("cnn_bilstm", &PyClass::cnn_bilstm)
  19 + .def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
  20 + .def_readwrite("num_threads", &PyClass::num_threads)
  21 + .def_readwrite("debug", &PyClass::debug)
  22 + .def_readwrite("provider", &PyClass::provider)
  23 + .def("validate", &PyClass::Validate)
  24 + .def("__str__", &PyClass::ToString);
  25 +}
  26 +
  27 +static void PybindOnlinePunctuationConfig(py::module *m) {
  28 + PybindOnlinePunctuationModelConfig(m);
  29 + using PyClass = OnlinePunctuationConfig;
  30 +
  31 + py::class_<PyClass>(*m, "OnlinePunctuationConfig")
  32 + .def(py::init<>())
  33 + .def(py::init<const OnlinePunctuationModelConfig &>(), py::arg("model_config"))
  34 + .def_readwrite("model_config", &PyClass::model)
  35 + .def("validate", &PyClass::Validate)
  36 + .def("__str__", &PyClass::ToString);
  37 +}
  38 +
  39 +void PybindOnlinePunctuation(py::module *m) {
  40 + PybindOnlinePunctuationConfig(m);
  41 + using PyClass = OnlinePunctuation;
  42 +
  43 + py::class_<PyClass>(*m, "OnlinePunctuation")
  44 + .def(py::init<const OnlinePunctuationConfig &>(), py::arg("config"),
  45 + py::call_guard<py::gil_scoped_release>())
  46 + .def("add_punctuation_with_case", &PyClass::AddPunctuationWithCase, py::arg("text"),
  47 + py::call_guard<py::gil_scoped_release>());
  48 +}
  49 +
  50 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/online-punctuation.h
  2 +//
  3 +// Copyright (c) 2024
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOnlinePunctuation(py::module *m);
  13 +
  14 +} // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_
@@ -20,6 +20,7 @@ @@ -20,6 +20,7 @@
20 #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" 20 #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
21 #include "sherpa-onnx/python/csrc/online-lm-config.h" 21 #include "sherpa-onnx/python/csrc/online-lm-config.h"
22 #include "sherpa-onnx/python/csrc/online-model-config.h" 22 #include "sherpa-onnx/python/csrc/online-model-config.h"
  23 +#include "sherpa-onnx/python/csrc/online-punctuation.h"
23 #include "sherpa-onnx/python/csrc/online-recognizer.h" 24 #include "sherpa-onnx/python/csrc/online-recognizer.h"
24 #include "sherpa-onnx/python/csrc/online-stream.h" 25 #include "sherpa-onnx/python/csrc/online-stream.h"
25 #include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h" 26 #include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
@@ -42,6 +43,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -42,6 +43,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
42 PybindWaveWriter(&m); 43 PybindWaveWriter(&m);
43 PybindAudioTagging(&m); 44 PybindAudioTagging(&m);
44 PybindOfflinePunctuation(&m); 45 PybindOfflinePunctuation(&m);
  46 + PybindOnlinePunctuation(&m);
45 47
46 PybindFeatures(&m); 48 PybindFeatures(&m);
47 PybindOnlineCtcFstDecoderConfig(&m); 49 PybindOnlineCtcFstDecoderConfig(&m);
@@ -15,6 +15,9 @@ from _sherpa_onnx import ( @@ -15,6 +15,9 @@ from _sherpa_onnx import (
15 OfflineTtsModelConfig, 15 OfflineTtsModelConfig,
16 OfflineTtsVitsModelConfig, 16 OfflineTtsVitsModelConfig,
17 OfflineZipformerAudioTaggingModelConfig, 17 OfflineZipformerAudioTaggingModelConfig,
  18 + OnlinePunctuation,
  19 + OnlinePunctuationConfig,
  20 + OnlinePunctuationModelConfig,
18 OnlineStream, 21 OnlineStream,
19 SileroVadModelConfig, 22 SileroVadModelConfig,
20 SpeakerEmbeddingExtractor, 23 SpeakerEmbeddingExtractor,