Fangjun Kuang
Committed by GitHub

Support adding punctuations to text for node-addon-api (#876)

@@ -18,7 +18,7 @@ fi @@ -18,7 +18,7 @@ fi
18 SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) 18 SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
19 echo "SHERPA_ONNX_VERSION $SHERPA_ONNX_VERSION" 19 echo "SHERPA_ONNX_VERSION $SHERPA_ONNX_VERSION"
20 20
21 -# SHERPA_ONNX_VERSION=1.0.22 21 +# SHERPA_ONNX_VERSION=1.0.23
22 22
23 if [ -z $owner ]; then 23 if [ -z $owner ]; then
24 owner=k2-fsa 24 owner=k2-fsa
@@ -6,6 +6,15 @@ d=nodejs-addon-examples @@ -6,6 +6,15 @@ d=nodejs-addon-examples
6 echo "dir: $d" 6 echo "dir: $d"
7 cd $d 7 cd $d
8 8
  9 +echo "----------add punctuations----------"
  10 +
  11 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  12 +tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  13 +rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  14 +
  15 +node ./test_punctuation.js
  16 +rm -rf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
  17 +
9 echo "----------audio tagging----------" 18 echo "----------audio tagging----------"
10 19
11 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2 20 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
@@ -55,7 +55,7 @@ jobs: @@ -55,7 +55,7 @@ jobs:
55 55
56 SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) 56 SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
57 echo "SHERPA_ONNX_VERSION $SHERPA_ONNX_VERSION" 57 echo "SHERPA_ONNX_VERSION $SHERPA_ONNX_VERSION"
58 - # SHERPA_ONNX_VERSION=1.0.22 58 + # SHERPA_ONNX_VERSION=1.0.23
59 59
60 src_dir=.github/scripts/node-addon 60 src_dir=.github/scripts/node-addon
61 sed -i.bak s/SHERPA_ONNX_VERSION/$SHERPA_ONNX_VERSION/g $src_dir/package.json 61 sed -i.bak s/SHERPA_ONNX_VERSION/$SHERPA_ONNX_VERSION/g $src_dir/package.json
@@ -31,6 +31,12 @@ export LD_LIBRARY_PATH=$PWD/node_modules/sherpa-onnx-linux-arm64:$LD_LIBRARY_PAT @@ -31,6 +31,12 @@ export LD_LIBRARY_PATH=$PWD/node_modules/sherpa-onnx-linux-arm64:$LD_LIBRARY_PAT
31 31
32 The following tables list the examples in this folder. 32 The following tables list the examples in this folder.
33 33
  34 +## Add punctuations to text
  35 +
  36 +|File| Description|
  37 +|---|---|
  38 +|[./test_punctuation.js](./test_punctuation.js)| Add punctuations to input text using [CT transformer](https://modelscope.cn/models/iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary). It supports both Chinese and English.|
  39 +
34 ## Voice activity detection (VAD) 40 ## Voice activity detection (VAD)
35 41
36 |File| Description| 42 |File| Description|
@@ -309,3 +315,13 @@ git clone https://github.com/csukuangfj/sr-data @@ -309,3 +315,13 @@ git clone https://github.com/csukuangfj/sr-data
309 315
310 node ./test_speaker_identification.js 316 node ./test_speaker_identification.js
311 ``` 317 ```
  318 +
  319 +### Add punctuations
  320 +
  321 +```bash
  322 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  323 +tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  324 +rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  325 +
  326 +node ./test_punctuation.js
  327 +```
  1 +// Copyright (c) 2023-2024 Xiaomi Corporation (authors: Fangjun Kuang)
  2 +
  3 +const sherpa_onnx = require('sherpa-onnx-node');
  4 +
  5 +// Please download the punctuation model from
  6 +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models
  7 +function createPunctuation() {
  8 + const config = {
  9 + model: {
  10 + ctTransformer:
  11 + './sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx',
  12 + debug: true,
  13 + numThreads: 1,
  14 + provider: 'cpu',
  15 + },
  16 + };
  17 + return new sherpa_onnx.Punctuation(config);
  18 +}
  19 +
  20 +const punct = createPunctuation();
  21 +const sentences = [
  22 + '这是一个测试你好吗How are you我很好thank you are you ok谢谢你',
  23 + '我们都是木头人不会说话不会动',
  24 + 'The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry',
  25 +];
  26 +console.log('---');
  27 +for (let sentence of sentences) {
  28 + const punct_text = punct.addPunct(sentence);
  29 + console.log(`Input: ${sentence}`);
  30 + console.log(`Output: ${punct_text}`);
  31 + console.log('---');
  32 +}
@@ -2,6 +2,8 @@ @@ -2,6 +2,8 @@
2 2
3 const sherpa_onnx = require('sherpa-onnx-node'); 3 const sherpa_onnx = require('sherpa-onnx-node');
4 4
  5 +// Please download whisper multi-lingual models from
  6 +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
5 function createSpokenLanguageID() { 7 function createSpokenLanguageID() {
6 const config = { 8 const config = {
7 whisper: { 9 whisper: {
@@ -21,6 +21,7 @@ set(srcs @@ -21,6 +21,7 @@ set(srcs
21 src/audio-tagging.cc 21 src/audio-tagging.cc
22 src/non-streaming-asr.cc 22 src/non-streaming-asr.cc
23 src/non-streaming-tts.cc 23 src/non-streaming-tts.cc
  24 + src/punctuation.cc
24 src/sherpa-onnx-node-addon-api.cc 25 src/sherpa-onnx-node-addon-api.cc
25 src/speaker-identification.cc 26 src/speaker-identification.cc
26 src/spoken-language-identification.cc 27 src/spoken-language-identification.cc
  1 +const addon = require('./addon.js');
  2 +
  3 +class Punctuation {
  4 + constructor(config) {
  5 + this.handle = addon.createOfflinePunctuation(config);
  6 + this.config = config;
  7 + }
  8 + addPunct(text) {
  9 + return addon.offlinePunctuationAddPunct(this.handle, text);
  10 + }
  11 +}
  12 +
  13 +module.exports = {
  14 + Punctuation,
  15 +}
@@ -6,6 +6,7 @@ const vad = require('./vad.js'); @@ -6,6 +6,7 @@ const vad = require('./vad.js');
6 const slid = require('./spoken-language-identification.js'); 6 const slid = require('./spoken-language-identification.js');
7 const sid = require('./speaker-identification.js'); 7 const sid = require('./speaker-identification.js');
8 const at = require('./audio-tagg.js'); 8 const at = require('./audio-tagg.js');
  9 +const punct = require('./punctuation.js');
9 10
10 module.exports = { 11 module.exports = {
11 OnlineRecognizer: streaming_asr.OnlineRecognizer, 12 OnlineRecognizer: streaming_asr.OnlineRecognizer,
@@ -20,4 +21,5 @@ module.exports = { @@ -20,4 +21,5 @@ module.exports = {
20 SpeakerEmbeddingExtractor: sid.SpeakerEmbeddingExtractor, 21 SpeakerEmbeddingExtractor: sid.SpeakerEmbeddingExtractor,
21 SpeakerEmbeddingManager: sid.SpeakerEmbeddingManager, 22 SpeakerEmbeddingManager: sid.SpeakerEmbeddingManager,
22 AudioTagging: at.AudioTagging, 23 AudioTagging: at.AudioTagging,
  24 + Punctuation: punct.Punctuation,
23 } 25 }
@@ -166,7 +166,7 @@ static Napi::Object AudioTaggingComputeWrapper(const Napi::CallbackInfo &info) { @@ -166,7 +166,7 @@ static Napi::Object AudioTaggingComputeWrapper(const Napi::CallbackInfo &info) {
166 166
167 if (!info[1].IsExternal()) { 167 if (!info[1].IsExternal()) {
168 Napi::TypeError::New( 168 Napi::TypeError::New(
169 - env, "You should pass a offline stream pointer as the second argument") 169 + env, "You should pass an offline stream pointer as the second argument")
170 .ThrowAsJavaScriptException(); 170 .ThrowAsJavaScriptException();
171 171
172 return {}; 172 return {};
  1 +// scripts/node-addon-api/src/punctuation.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#include <sstream>
  5 +
  6 +#include "macros.h" // NOLINT
  7 +#include "napi.h" // NOLINT
  8 +#include "sherpa-onnx/c-api/c-api.h"
  9 +
  10 +static SherpaOnnxOfflinePunctuationModelConfig GetOfflinePunctuationModelConfig(
  11 + Napi::Object obj) {
  12 + SherpaOnnxOfflinePunctuationModelConfig c;
  13 + memset(&c, 0, sizeof(c));
  14 +
  15 + if (!obj.Has("model") || !obj.Get("model").IsObject()) {
  16 + return c;
  17 + }
  18 +
  19 + Napi::Object o = obj.Get("model").As<Napi::Object>();
  20 +
  21 + SHERPA_ONNX_ASSIGN_ATTR_STR(ct_transformer, ctTransformer);
  22 +
  23 + SHERPA_ONNX_ASSIGN_ATTR_INT32(num_threads, numThreads);
  24 +
  25 + if (o.Has("debug") &&
  26 + (o.Get("debug").IsNumber() || o.Get("debug").IsBoolean())) {
  27 + if (o.Get("debug").IsBoolean()) {
  28 + c.debug = o.Get("debug").As<Napi::Boolean>().Value();
  29 + } else {
  30 + c.debug = o.Get("debug").As<Napi::Number>().Int32Value();
  31 + }
  32 + }
  33 + SHERPA_ONNX_ASSIGN_ATTR_STR(provider, provider);
  34 +
  35 + return c;
  36 +}
  37 +
  38 +static Napi::External<SherpaOnnxOfflinePunctuation>
  39 +CreateOfflinePunctuationWrapper(const Napi::CallbackInfo &info) {
  40 + Napi::Env env = info.Env();
  41 + if (info.Length() != 1) {
  42 + std::ostringstream os;
  43 + os << "Expect only 1 argument. Given: " << info.Length();
  44 +
  45 + Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();
  46 +
  47 + return {};
  48 + }
  49 +
  50 + if (!info[0].IsObject()) {
  51 + Napi::TypeError::New(env, "You should pass an object as the only argument.")
  52 + .ThrowAsJavaScriptException();
  53 +
  54 + return {};
  55 + }
  56 +
  57 + Napi::Object o = info[0].As<Napi::Object>();
  58 +
  59 + SherpaOnnxOfflinePunctuationConfig c;
  60 + memset(&c, 0, sizeof(c));
  61 + c.model = GetOfflinePunctuationModelConfig(o);
  62 +
  63 + const SherpaOnnxOfflinePunctuation *punct =
  64 + SherpaOnnxCreateOfflinePunctuation(&c);
  65 +
  66 + if (c.model.ct_transformer) {
  67 + delete[] c.model.ct_transformer;
  68 + }
  69 +
  70 + if (c.model.provider) {
  71 + delete[] c.model.provider;
  72 + }
  73 +
  74 + if (!punct) {
  75 + Napi::TypeError::New(env, "Please check your config!")
  76 + .ThrowAsJavaScriptException();
  77 +
  78 + return {};
  79 + }
  80 +
  81 + return Napi::External<SherpaOnnxOfflinePunctuation>::New(
  82 + env, const_cast<SherpaOnnxOfflinePunctuation *>(punct),
  83 + [](Napi::Env env, SherpaOnnxOfflinePunctuation *punct) {
  84 + SherpaOnnxDestroyOfflinePunctuation(punct);
  85 + });
  86 +}
  87 +
  88 +static Napi::String OfflinePunctuationAddPunctWraper(
  89 + const Napi::CallbackInfo &info) {
  90 + Napi::Env env = info.Env();
  91 + if (info.Length() != 2) {
  92 + std::ostringstream os;
  93 + os << "Expect only 2 arguments. Given: " << info.Length();
  94 +
  95 + Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();
  96 +
  97 + return {};
  98 + }
  99 +
  100 + if (!info[0].IsExternal()) {
  101 + Napi::TypeError::New(
  102 + env,
  103 + "You should pass an offline punctuation pointer as the first argument")
  104 + .ThrowAsJavaScriptException();
  105 +
  106 + return {};
  107 + }
  108 +
  109 + if (!info[1].IsString()) {
  110 + Napi::TypeError::New(env, "You should pass a string as the second argument")
  111 + .ThrowAsJavaScriptException();
  112 +
  113 + return {};
  114 + }
  115 +
  116 + SherpaOnnxOfflinePunctuation *punct =
  117 + info[0].As<Napi::External<SherpaOnnxOfflinePunctuation>>().Data();
  118 + Napi::String js_text = info[1].As<Napi::String>();
  119 + std::string text = js_text.Utf8Value();
  120 +
  121 + const char *punct_text =
  122 + SherpaOfflinePunctuationAddPunct(punct, text.c_str());
  123 +
  124 + Napi::String ans = Napi::String::New(env, punct_text);
  125 + SherpaOfflinePunctuationFreeText(punct_text);
  126 + return ans;
  127 +}
  128 +
  129 +void InitPunctuation(Napi::Env env, Napi::Object exports) {
  130 + exports.Set(Napi::String::New(env, "createOfflinePunctuation"),
  131 + Napi::Function::New(env, CreateOfflinePunctuationWrapper));
  132 +
  133 + exports.Set(Napi::String::New(env, "offlinePunctuationAddPunct"),
  134 + Napi::Function::New(env, OfflinePunctuationAddPunctWraper));
  135 +}
@@ -21,6 +21,8 @@ void InitSpeakerID(Napi::Env env, Napi::Object exports); @@ -21,6 +21,8 @@ void InitSpeakerID(Napi::Env env, Napi::Object exports);
21 21
22 void InitAudioTagging(Napi::Env env, Napi::Object exports); 22 void InitAudioTagging(Napi::Env env, Napi::Object exports);
23 23
  24 +void InitPunctuation(Napi::Env env, Napi::Object exports);
  25 +
24 Napi::Object Init(Napi::Env env, Napi::Object exports) { 26 Napi::Object Init(Napi::Env env, Napi::Object exports) {
25 InitStreamingAsr(env, exports); 27 InitStreamingAsr(env, exports);
26 InitNonStreamingAsr(env, exports); 28 InitNonStreamingAsr(env, exports);
@@ -31,6 +33,7 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) { @@ -31,6 +33,7 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) {
31 InitSpokenLanguageID(env, exports); 33 InitSpokenLanguageID(env, exports);
32 InitSpeakerID(env, exports); 34 InitSpeakerID(env, exports);
33 InitAudioTagging(env, exports); 35 InitAudioTagging(env, exports);
  36 + InitPunctuation(env, exports);
34 37
35 return exports; 38 return exports;
36 } 39 }