Committed by
GitHub
Add Python binding for online punctuation models (#1312)
正在显示
8 个修改的文件
包含
133 行增加
和
0 行删除
| @@ -91,6 +91,18 @@ python3 ./python-api-examples/add-punctuation.py | @@ -91,6 +91,18 @@ python3 ./python-api-examples/add-punctuation.py | ||
| 91 | 91 | ||
| 92 | rm -rf $repo | 92 | rm -rf $repo |
| 93 | 93 | ||
| 94 | +log "test online punctuation" | ||
| 95 | + | ||
| 96 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 | ||
| 97 | +tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 | ||
| 98 | +rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 | ||
| 99 | +repo=sherpa-onnx-online-punct-en-2024-08-06 | ||
| 100 | +ls -lh $repo | ||
| 101 | + | ||
| 102 | +python3 ./python-api-examples/add-punctuation-online.py | ||
| 103 | + | ||
| 104 | +rm -rf $repo | ||
| 105 | + | ||
| 94 | log "test audio tagging" | 106 | log "test audio tagging" |
| 95 | 107 | ||
| 96 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 | 108 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +""" | ||
| 4 | +This script shows how to add punctuations to text using sherpa-onnx Python API. | ||
| 5 | + | ||
| 6 | +Please download the model from | ||
| 7 | +https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models | ||
| 8 | + | ||
| 9 | +The following is an example | ||
| 10 | + | ||
| 11 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 | ||
| 12 | +tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 | ||
| 13 | +rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 | ||
| 14 | +""" | ||
| 15 | + | ||
| 16 | +from pathlib import Path | ||
| 17 | + | ||
| 18 | +import sherpa_onnx | ||
| 19 | + | ||
| 20 | + | ||
| 21 | +def main(): | ||
| 22 | + model = "./sherpa-onnx-online-punct-en-2024-08-06/model.onnx" | ||
| 23 | + bpe = "./sherpa-onnx-online-punct-en-2024-08-06/bpe.vocab" | ||
| 24 | + if not Path(model).is_file(): | ||
| 25 | + raise ValueError(f"{model} does not exist") | ||
| 26 | + if not Path(bpe).is_file(): | ||
| 27 | + raise ValueError(f"{bpe} does not exist") | ||
| 28 | + | ||
| 29 | + model_config = sherpa_onnx.OnlinePunctuationModelConfig( | ||
| 30 | + cnn_bilstm=model, bpe_vocab=bpe | ||
| 31 | + ) | ||
| 32 | + config = sherpa_onnx.OnlinePunctuationConfig(model_config=model_config) | ||
| 33 | + punct = sherpa_onnx.OnlinePunctuation(config) | ||
| 34 | + | ||
| 35 | + texts = [ | ||
| 36 | + "how are you i am fine thank you", | ||
| 37 | + "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry", | ||
| 38 | + ] | ||
| 39 | + for text in texts: | ||
| 40 | + text_with_punct = punct.add_punctuation_with_case(text) | ||
| 41 | + print("----------") | ||
| 42 | + print(f"input : {text}") | ||
| 43 | + print(f"output: {text_with_punct}") | ||
| 44 | + print("----------") | ||
| 45 | + | ||
| 46 | + | ||
| 47 | +if __name__ == "__main__": | ||
| 48 | + main() |
| @@ -27,6 +27,7 @@ set(srcs | @@ -27,6 +27,7 @@ set(srcs | ||
| 27 | online-model-config.cc | 27 | online-model-config.cc |
| 28 | online-nemo-ctc-model-config.cc | 28 | online-nemo-ctc-model-config.cc |
| 29 | online-paraformer-model-config.cc | 29 | online-paraformer-model-config.cc |
| 30 | + online-punctuation.cc | ||
| 30 | online-recognizer.cc | 31 | online-recognizer.cc |
| 31 | online-stream.cc | 32 | online-stream.cc |
| 32 | online-transducer-model-config.cc | 33 | online-transducer-model-config.cc |
| 1 | +// sherpa-onnx/python/csrc/online-punctuation.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/online-punctuation.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/online-punctuation.h" | ||
| 8 | + | ||
| 9 | +namespace sherpa_onnx { | ||
| 10 | + | ||
| 11 | +static void PybindOnlinePunctuationModelConfig(py::module *m) { | ||
| 12 | + using PyClass = OnlinePunctuationModelConfig; | ||
| 13 | + py::class_<PyClass>(*m, "OnlinePunctuationModelConfig") | ||
| 14 | + .def(py::init<>()) | ||
| 15 | + .def(py::init<const std::string &, const std::string &, int32_t, bool, const std::string &>(), | ||
| 16 | + py::arg("cnn_bilstm"), py::arg("bpe_vocab"), py::arg("num_threads") = 1, | ||
| 17 | + py::arg("debug") = false, py::arg("provider") = "cpu") | ||
| 18 | + .def_readwrite("cnn_bilstm", &PyClass::cnn_bilstm) | ||
| 19 | + .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) | ||
| 20 | + .def_readwrite("num_threads", &PyClass::num_threads) | ||
| 21 | + .def_readwrite("debug", &PyClass::debug) | ||
| 22 | + .def_readwrite("provider", &PyClass::provider) | ||
| 23 | + .def("validate", &PyClass::Validate) | ||
| 24 | + .def("__str__", &PyClass::ToString); | ||
| 25 | +} | ||
| 26 | + | ||
| 27 | +static void PybindOnlinePunctuationConfig(py::module *m) { | ||
| 28 | + PybindOnlinePunctuationModelConfig(m); | ||
| 29 | + using PyClass = OnlinePunctuationConfig; | ||
| 30 | + | ||
| 31 | + py::class_<PyClass>(*m, "OnlinePunctuationConfig") | ||
| 32 | + .def(py::init<>()) | ||
| 33 | + .def(py::init<const OnlinePunctuationModelConfig &>(), py::arg("model_config")) | ||
| 34 | + .def_readwrite("model_config", &PyClass::model) | ||
| 35 | + .def("validate", &PyClass::Validate) | ||
| 36 | + .def("__str__", &PyClass::ToString); | ||
| 37 | +} | ||
| 38 | + | ||
| 39 | +void PybindOnlinePunctuation(py::module *m) { | ||
| 40 | + PybindOnlinePunctuationConfig(m); | ||
| 41 | + using PyClass = OnlinePunctuation; | ||
| 42 | + | ||
| 43 | + py::class_<PyClass>(*m, "OnlinePunctuation") | ||
| 44 | + .def(py::init<const OnlinePunctuationConfig &>(), py::arg("config"), | ||
| 45 | + py::call_guard<py::gil_scoped_release>()) | ||
| 46 | + .def("add_punctuation_with_case", &PyClass::AddPunctuationWithCase, py::arg("text"), | ||
| 47 | + py::call_guard<py::gil_scoped_release>()); | ||
| 48 | +} | ||
| 49 | + | ||
| 50 | +} // namespace sherpa_onnx |
sherpa-onnx/python/csrc/online-punctuation.h
0 → 100644
| 1 | +// sherpa-onnx/python/csrc/online-punctuation.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOnlinePunctuation(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_ |
| @@ -20,6 +20,7 @@ | @@ -20,6 +20,7 @@ | ||
| 20 | #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" | 20 | #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" |
| 21 | #include "sherpa-onnx/python/csrc/online-lm-config.h" | 21 | #include "sherpa-onnx/python/csrc/online-lm-config.h" |
| 22 | #include "sherpa-onnx/python/csrc/online-model-config.h" | 22 | #include "sherpa-onnx/python/csrc/online-model-config.h" |
| 23 | +#include "sherpa-onnx/python/csrc/online-punctuation.h" | ||
| 23 | #include "sherpa-onnx/python/csrc/online-recognizer.h" | 24 | #include "sherpa-onnx/python/csrc/online-recognizer.h" |
| 24 | #include "sherpa-onnx/python/csrc/online-stream.h" | 25 | #include "sherpa-onnx/python/csrc/online-stream.h" |
| 25 | #include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h" | 26 | #include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h" |
| @@ -42,6 +43,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | @@ -42,6 +43,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 42 | PybindWaveWriter(&m); | 43 | PybindWaveWriter(&m); |
| 43 | PybindAudioTagging(&m); | 44 | PybindAudioTagging(&m); |
| 44 | PybindOfflinePunctuation(&m); | 45 | PybindOfflinePunctuation(&m); |
| 46 | + PybindOnlinePunctuation(&m); | ||
| 45 | 47 | ||
| 46 | PybindFeatures(&m); | 48 | PybindFeatures(&m); |
| 47 | PybindOnlineCtcFstDecoderConfig(&m); | 49 | PybindOnlineCtcFstDecoderConfig(&m); |
| @@ -15,6 +15,9 @@ from _sherpa_onnx import ( | @@ -15,6 +15,9 @@ from _sherpa_onnx import ( | ||
| 15 | OfflineTtsModelConfig, | 15 | OfflineTtsModelConfig, |
| 16 | OfflineTtsVitsModelConfig, | 16 | OfflineTtsVitsModelConfig, |
| 17 | OfflineZipformerAudioTaggingModelConfig, | 17 | OfflineZipformerAudioTaggingModelConfig, |
| 18 | + OnlinePunctuation, | ||
| 19 | + OnlinePunctuationConfig, | ||
| 20 | + OnlinePunctuationModelConfig, | ||
| 18 | OnlineStream, | 21 | OnlineStream, |
| 19 | SileroVadModelConfig, | 22 | SileroVadModelConfig, |
| 20 | SpeakerEmbeddingExtractor, | 23 | SpeakerEmbeddingExtractor, |
-
请 注册 或 登录 后发表评论