Fangjun Kuang
Committed by GitHub

Add Python API for punctuation models. (#762)

@@ -14,7 +14,7 @@ echo "PATH: $PATH" @@ -14,7 +14,7 @@ echo "PATH: $PATH"
14 which $EXE 14 which $EXE
15 15
16 log "------------------------------------------------------------" 16 log "------------------------------------------------------------"
17 -log "Download model " 17 +log "Download the punctuation model "
18 log "------------------------------------------------------------" 18 log "------------------------------------------------------------"
19 19
20 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 20 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
@@ -8,6 +8,18 @@ log() { @@ -8,6 +8,18 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +log "test offline punctuation"
  12 +
  13 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  14 +tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  15 +rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  16 +repo=sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
  17 +ls -lh $repo
  18 +
  19 +python3 ./python-api-examples/add-punctuation.py
  20 +
  21 +rm -rf $repo
  22 +
11 log "test audio tagging" 23 log "test audio tagging"
12 24
13 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 25 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
@@ -91,3 +91,4 @@ sr-data @@ -91,3 +91,4 @@ sr-data
91 *xcworkspace/xcuserdata/* 91 *xcworkspace/xcuserdata/*
92 92
93 vits-icefall-* 93 vits-icefall-*
  94 +sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
@@ -2,7 +2,7 @@ @@ -2,7 +2,7 @@
2 2
3 3
4 if [ ! -f ./silero_vad.onnx ]; then 4 if [ ! -f ./silero_vad.onnx ]; then
5 - curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx 5 + curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
6 fi 6 fi
7 7
8 if [ ! -f ./sherpa-onnx-paraformer-trilingual-zh-cantonese-en/model.int8.onnx ]; then 8 if [ ! -f ./sherpa-onnx-paraformer-trilingual-zh-cantonese-en/model.int8.onnx ]; then
@@ -2,7 +2,7 @@ @@ -2,7 +2,7 @@
2 2
3 3
4 if [ ! -f ./silero_vad.onnx ]; then 4 if [ ! -f ./silero_vad.onnx ]; then
5 - curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx 5 + curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
6 fi 6 fi
7 7
8 if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then 8 if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then
@@ -9,7 +9,7 @@ if [ ! -f ./sr-data/enroll/fangjun-sr-1.wav ]; then @@ -9,7 +9,7 @@ if [ ! -f ./sr-data/enroll/fangjun-sr-1.wav ]; then
9 fi 9 fi
10 10
11 if [ ! -f ./silero_vad.onnx ]; then 11 if [ ! -f ./silero_vad.onnx ]; then
12 - curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx 12 + curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
13 fi 13 fi
14 14
15 go mod tidy 15 go mod tidy
@@ -2,7 +2,7 @@ @@ -2,7 +2,7 @@
2 2
3 3
4 if [ ! -f ./silero_vad.onnx ]; then 4 if [ ! -f ./silero_vad.onnx ]; then
5 - curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx 5 + curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
6 fi 6 fi
7 7
8 if [ ! -f ./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx ]; then 8 if [ ! -f ./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx ]; then
@@ -2,7 +2,7 @@ @@ -2,7 +2,7 @@
2 2
3 3
4 if [ ! -f ./silero_vad.onnx ]; then 4 if [ ! -f ./silero_vad.onnx ]; then
5 - curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx 5 + curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
6 fi 6 fi
7 7
8 go mod tidy 8 go mod tidy
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This script shows how to add punctuation to text using the sherpa-onnx Python API.
  5 +
  6 +Please download the model from
  7 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models
  8 +
  9 +The following is an example:
  10 +
  11 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  12 +tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  13 +rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  14 +"""
  15 +
  16 +from pathlib import Path
  17 +
  18 +import sherpa_onnx
  19 +
  20 +
  21 +def main():
  22 + model = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx"
  23 + if not Path(model).is_file():
  24 + raise ValueError(f"{model} does not exist")
  25 + config = sherpa_onnx.OfflinePunctuationConfig(
  26 + model=sherpa_onnx.OfflinePunctuationModelConfig(ct_transformer=model),
  27 + )
  28 +
  29 + punct = sherpa_onnx.OfflinePunctuation(config)
  30 +
  31 + text_list = [
  32 + "这是一个测试你好吗How are you我很好thank you are you ok谢谢你",
  33 + "我们都是木头人不会说话不会动",
  34 + "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry",
  35 + ]
  36 + for text in text_list:
  37 + text_with_punct = punct.add_punctuation(text)
  38 + print("----------")
  39 + print(f"input: {text}")
  40 + print(f"output: {text_with_punct}")
  41 +
  42 + print("----------")
  43 +
  44 +
  45 +if __name__ == "__main__":
  46 + main()
@@ -12,6 +12,7 @@ set(srcs @@ -12,6 +12,7 @@ set(srcs
12 offline-model-config.cc 12 offline-model-config.cc
13 offline-nemo-enc-dec-ctc-model-config.cc 13 offline-nemo-enc-dec-ctc-model-config.cc
14 offline-paraformer-model-config.cc 14 offline-paraformer-model-config.cc
  15 + offline-punctuation.cc
15 offline-recognizer.cc 16 offline-recognizer.cc
16 offline-stream.cc 17 offline-stream.cc
17 offline-tdnn-model-config.cc 18 offline-tdnn-model-config.cc
  1 +// sherpa-onnx/python/csrc/offline-punctuation.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-punctuation.h"
  6 +
  7 +#include "sherpa-onnx/csrc/offline-punctuation.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +static void PybindOfflinePunctuationModelConfig(py::module *m) {
  12 + using PyClass = OfflinePunctuationModelConfig;
  13 + py::class_<PyClass>(*m, "OfflinePunctuationModelConfig")
  14 + .def(py::init<>())
  15 + .def(py::init<const std::string &, int32_t, bool, const std::string &>(),
  16 + py::arg("ct_transformer"), py::arg("num_threads") = 1,
  17 + py::arg("debug") = false, py::arg("provider") = "cpu")
  18 + .def_readwrite("ct_transformer", &PyClass::ct_transformer)
  19 + .def_readwrite("num_threads", &PyClass::num_threads)
  20 + .def_readwrite("debug", &PyClass::debug)
  21 + .def_readwrite("provider", &PyClass::provider)
  22 + .def("validate", &PyClass::Validate)
  23 + .def("__str__", &PyClass::ToString);
  24 +}
  25 +
  26 +static void PybindOfflinePunctuationConfig(py::module *m) {
  27 + PybindOfflinePunctuationModelConfig(m);
  28 + using PyClass = OfflinePunctuationConfig;
  29 +
  30 + py::class_<PyClass>(*m, "OfflinePunctuationConfig")
  31 + .def(py::init<>())
  32 + .def(py::init<const OfflinePunctuationModelConfig &>(), py::arg("model"))
  33 + .def_readwrite("model", &PyClass::model)
  34 + .def("validate", &PyClass::Validate)
  35 + .def("__str__", &PyClass::ToString);
  36 +}
  37 +
  38 +void PybindOfflinePunctuation(py::module *m) {
  39 + PybindOfflinePunctuationConfig(m);
  40 + using PyClass = OfflinePunctuation;
  41 +
  42 + py::class_<PyClass>(*m, "OfflinePunctuation")
  43 + .def(py::init<const OfflinePunctuationConfig &>(), py::arg("config"),
  44 + py::call_guard<py::gil_scoped_release>())
  45 + .def("add_punctuation", &PyClass::AddPunctuation, py::arg("text"),
  46 + py::call_guard<py::gil_scoped_release>());
  47 +}
  48 +
  49 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-punctuation.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PUNCTUATION_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PUNCTUATION_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflinePunctuation(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PUNCTUATION_H_
@@ -14,6 +14,7 @@ @@ -14,6 +14,7 @@
14 #include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h" 14 #include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
15 #include "sherpa-onnx/python/csrc/offline-lm-config.h" 15 #include "sherpa-onnx/python/csrc/offline-lm-config.h"
16 #include "sherpa-onnx/python/csrc/offline-model-config.h" 16 #include "sherpa-onnx/python/csrc/offline-model-config.h"
  17 +#include "sherpa-onnx/python/csrc/offline-punctuation.h"
17 #include "sherpa-onnx/python/csrc/offline-recognizer.h" 18 #include "sherpa-onnx/python/csrc/offline-recognizer.h"
18 #include "sherpa-onnx/python/csrc/offline-stream.h" 19 #include "sherpa-onnx/python/csrc/offline-stream.h"
19 #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" 20 #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
@@ -40,6 +41,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -40,6 +41,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
40 41
41 PybindWaveWriter(&m); 42 PybindWaveWriter(&m);
42 PybindAudioTagging(&m); 43 PybindAudioTagging(&m);
  44 + PybindOfflinePunctuation(&m);
43 45
44 PybindFeatures(&m); 46 PybindFeatures(&m);
45 PybindOnlineCtcFstDecoderConfig(&m); 47 PybindOnlineCtcFstDecoderConfig(&m);
@@ -6,6 +6,9 @@ from _sherpa_onnx import ( @@ -6,6 +6,9 @@ from _sherpa_onnx import (
6 AudioTaggingModelConfig, 6 AudioTaggingModelConfig,
7 CircularBuffer, 7 CircularBuffer,
8 Display, 8 Display,
  9 + OfflinePunctuation,
  10 + OfflinePunctuationConfig,
  11 + OfflinePunctuationModelConfig,
9 OfflineStream, 12 OfflineStream,
10 OfflineTts, 13 OfflineTts,
11 OfflineTtsConfig, 14 OfflineTtsConfig,