继续操作前请注册或者登录。
Fangjun Kuang
Committed by GitHub

Add APIs for Online NeMo CTC models (#2454)

... ... @@ -9,6 +9,49 @@ git status
ls -lh
ls -lh node_modules
# online asr
curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
rm sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
node ./test-online-paraformer.js
rm -rf sherpa-onnx-streaming-paraformer-bilingual-zh-en
curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
rm -f itn*
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
node ./test-online-transducer-itn.js
node ./test-online-transducer.js
rm -rf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
node ./test-online-zipformer2-ctc.js
rm -rf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
node ./test-online-zipformer2-ctc-hlg.js
rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
echo "----------keyword spotting----------"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
node ./test-keyword-spotter-transducer.js
rm -rf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
# asr with offline nemo canary
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
tar xvf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
... ... @@ -145,15 +188,6 @@ rm Obama.wav
rm silero_vad.onnx
rm -rf sherpa-onnx-whisper-tiny.en
echo "----------keyword spotting----------"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
node ./test-keyword-spotter-transducer.js
rm -rf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
# offline asr
#
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
... ... @@ -218,37 +252,3 @@ rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
node ./test-offline-moonshine.js
rm -rf sherpa-onnx-moonshine-*
# online asr
curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
rm sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
node ./test-online-paraformer.js
rm -rf sherpa-onnx-streaming-paraformer-bilingual-zh-en
curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
rm -f itn*
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
node ./test-online-transducer-itn.js
node ./test-online-transducer.js
rm -rf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
node ./test-online-zipformer2-ctc.js
rm -rf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
node ./test-online-zipformer2-ctc-hlg.js
rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
... ...
... ... @@ -148,7 +148,7 @@ to download pre-trained non-streaming zipformer models.
dotnet run \
--tokens=./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt \
--paraformer=./sherpa-onnx-paraformer-zh-2023-09-14/model.onnx \
--paraformer=./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx \
--files ./sherpa-onnx-zipformer-en-2023-04-01/test_wavs/0.wav \
./sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/0.wav \
./sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/1.wav \
... ...
... ... @@ -18,7 +18,7 @@ fi
dotnet run \
--tokens=./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt \
--paraformer=./sherpa-onnx-paraformer-zh-2023-09-14/model.onnx \
--paraformer=./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx \
--rule-fsts=./itn_zh_number.fst \
--num-threads=2 \
--files ./itn-zh-number.wav
... ...
... ... @@ -10,7 +10,7 @@ fi
dotnet run \
--tokens=./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt \
--paraformer=./sherpa-onnx-paraformer-zh-2023-09-14/model.onnx \
--paraformer=./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx \
--num-threads=2 \
--files ./sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/0.wav \
./sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/1.wav \
... ...
... ... @@ -13,6 +13,4 @@ dotnet run \
--tokens=./sherpa-onnx-zipformer-ctc-zh-int8-2025-07-03/tokens.txt \
--zipformer-ctc=./sherpa-onnx-zipformer-ctc-zh-int8-2025-07-03/model.int8.onnx \
--num-threads=1 \
--files ./sherpa-onnx-zipformer-ctc-zh-int8-2025-07-03/test_wavs/0.wav \
./sherpa-onnx-zipformer-ctc-zh-int8-2025-07-03/test_wavs/1.wav \
./sherpa-onnx-zipformer-ctc-zh-int8-2025-07-03/test_wavs/8k.wav
--files ./sherpa-onnx-zipformer-ctc-zh-int8-2025-07-03/test_wavs/0.wav
... ...
... ... @@ -121,6 +121,9 @@ class KeywordSpotter {
c.ref.model.zipformer2Ctc.model =
config.model.zipformer2Ctc.model.toNativeUtf8();
// nemoCtc
c.ref.model.nemoCtc.model = config.model.nemoCtc.model.toNativeUtf8();
c.ref.model.tokens = config.model.tokens.toNativeUtf8();
c.ref.model.numThreads = config.model.numThreads;
c.ref.model.provider = config.model.provider.toNativeUtf8();
... ... @@ -146,6 +149,7 @@ class KeywordSpotter {
calloc.free(c.ref.model.modelType);
calloc.free(c.ref.model.provider);
calloc.free(c.ref.model.tokens);
calloc.free(c.ref.model.nemoCtc.model);
calloc.free(c.ref.model.zipformer2Ctc.model);
calloc.free(c.ref.model.paraformer.encoder);
calloc.free(c.ref.model.paraformer.decoder);
... ...
... ... @@ -86,11 +86,33 @@ class OnlineZipformer2CtcModelConfig {
final String model;
}
/// Configuration for a streaming (online) NeMo CTC model.
///
/// Holds a single string, [model], which round-trips through
/// [OnlineNemoCtcModelConfig.fromJson] and [toJson] under the `model` key.
class OnlineNemoCtcModelConfig {
  /// Creates a config; [model] defaults to the empty string.
  const OnlineNemoCtcModelConfig({this.model = ''});

  /// Builds a config from [json]. A missing or null `model` entry yields ''.
  factory OnlineNemoCtcModelConfig.fromJson(Map<String, dynamic> json) =>
      OnlineNemoCtcModelConfig(model: (json['model'] as String?) ?? '');

  /// The model string; '' when unset.
  // NOTE(review): presumably the path to the ONNX model file - confirm.
  final String model;

  @override
  String toString() => 'OnlineNemoCtcModelConfig(model: $model)';

  /// Serializes this config as a one-entry map keyed by `model`.
  Map<String, dynamic> toJson() {
    return <String, dynamic>{'model': model};
  }
}
class OnlineModelConfig {
const OnlineModelConfig({
this.transducer = const OnlineTransducerModelConfig(),
this.paraformer = const OnlineParaformerModelConfig(),
this.zipformer2Ctc = const OnlineZipformer2CtcModelConfig(),
this.nemoCtc = const OnlineNemoCtcModelConfig(),
required this.tokens,
this.numThreads = 1,
this.provider = 'cpu',
... ... @@ -108,6 +130,8 @@ class OnlineModelConfig {
json['paraformer'] as Map<String, dynamic>? ?? const {}),
zipformer2Ctc: OnlineZipformer2CtcModelConfig.fromJson(
json['zipformer2Ctc'] as Map<String, dynamic>? ?? const {}),
nemoCtc: OnlineNemoCtcModelConfig.fromJson(
json['nemoCtc'] as Map<String, dynamic>? ?? const {}),
tokens: json['tokens'] as String,
numThreads: json['numThreads'] as int? ?? 1,
provider: json['provider'] as String? ?? 'cpu',
... ... @@ -120,13 +144,14 @@ class OnlineModelConfig {
@override
String toString() {
return 'OnlineModelConfig(transducer: $transducer, paraformer: $paraformer, zipformer2Ctc: $zipformer2Ctc, tokens: $tokens, numThreads: $numThreads, provider: $provider, debug: $debug, modelType: $modelType, modelingUnit: $modelingUnit, bpeVocab: $bpeVocab)';
return 'OnlineModelConfig(transducer: $transducer, paraformer: $paraformer, zipformer2Ctc: $zipformer2Ctc, nemoCtc: $nemoCtc, tokens: $tokens, numThreads: $numThreads, provider: $provider, debug: $debug, modelType: $modelType, modelingUnit: $modelingUnit, bpeVocab: $bpeVocab)';
}
Map<String, dynamic> toJson() => {
'transducer': transducer.toJson(),
'paraformer': paraformer.toJson(),
'zipformer2Ctc': zipformer2Ctc.toJson(),
'nemoCtc': nemoCtc.toJson(),
'tokens': tokens,
'numThreads': numThreads,
'provider': provider,
... ... @@ -139,6 +164,7 @@ class OnlineModelConfig {
final OnlineTransducerModelConfig transducer;
final OnlineParaformerModelConfig paraformer;
final OnlineZipformer2CtcModelConfig zipformer2Ctc;
final OnlineNemoCtcModelConfig nemoCtc;
final String tokens;
... ... @@ -333,6 +359,9 @@ class OnlineRecognizer {
c.ref.model.zipformer2Ctc.model =
config.model.zipformer2Ctc.model.toNativeUtf8();
// nemoCtc
c.ref.model.nemoCtc.model = config.model.nemoCtc.model.toNativeUtf8();
c.ref.model.tokens = config.model.tokens.toNativeUtf8();
c.ref.model.numThreads = config.model.numThreads;
c.ref.model.provider = config.model.provider.toNativeUtf8();
... ... @@ -377,6 +406,7 @@ class OnlineRecognizer {
calloc.free(c.ref.model.modelType);
calloc.free(c.ref.model.provider);
calloc.free(c.ref.model.tokens);
calloc.free(c.ref.model.nemoCtc.model);
calloc.free(c.ref.model.zipformer2Ctc.model);
calloc.free(c.ref.model.paraformer.encoder);
calloc.free(c.ref.model.paraformer.decoder);
... ...
... ... @@ -388,6 +388,10 @@ final class SherpaOnnxOnlineZipformer2CtcModelConfig extends Struct {
external Pointer<Utf8> model;
}
/// FFI struct describing the native online NeMo CTC model config.
final class SherpaOnnxOnlineNemoCtcModelConfig extends Struct {
/// Pointer to a native UTF-8 string holding the model value.
// NOTE(review): ownership is not visible here; callers appear to allocate
// with toNativeUtf8() and release with calloc.free() - confirm.
external Pointer<Utf8> model;
}
final class SherpaOnnxOnlineModelConfig extends Struct {
external SherpaOnnxOnlineTransducerModelConfig transducer;
external SherpaOnnxOnlineParaformerModelConfig paraformer;
... ... @@ -413,6 +417,8 @@ final class SherpaOnnxOnlineModelConfig extends Struct {
@Int32()
external int tokensBufSize;
external SherpaOnnxOnlineNemoCtcModelConfig nemoCtc;
}
final class SherpaOnnxOnlineCtcFstDecoderConfig extends Struct {
... ...
module non-streaming-canary-decode-files
go 1.17
require (
github.com/k2-fsa/sherpa-onnx-go v1.12.4
github.com/spf13/pflag v1.0.6
github.com/youpy/go-wav v0.3.2
)
require (
github.com/k2-fsa/sherpa-onnx-go-linux v1.12.4 // indirect
github.com/k2-fsa/sherpa-onnx-go-macos v1.12.4 // indirect
github.com/k2-fsa/sherpa-onnx-go-windows v1.12.4 // indirect
github.com/youpy/go-riff v0.1.0 // indirect
github.com/zaf/g711 v0.0.0-20190814101024-76a4a538f52b // indirect
)
... ...
... ... @@ -25,6 +25,7 @@ export { Samples,
} from './src/main/ets/components/NonStreamingAsr';
export { OnlineStream,
OnlineNemoCtcModelConfig,
OnlineTransducerModelConfig,
OnlineParaformerModelConfig,
OnlineZipformer2CtcModelConfig,
... ...
... ... @@ -73,6 +73,22 @@ GetOnlineZipformer2CtcModelConfig(Napi::Object obj) {
return c;
}
// Builds the C-API NeMo CTC model config from the "nemoCtc" property of `obj`.
//
// Returns a zero-filled config when `obj` has no "nemoCtc" property or when
// that property is not an object.
//
// NOTE(review): the ArkTS OnlineModelConfig class in this change names the
// field `nemo_ctc`, while this function looks up "nemoCtc" - confirm which
// key the JS side actually sets.
static SherpaOnnxOnlineNemoCtcModelConfig GetOnlineNemoCtcModelConfig(
Napi::Object obj) {
SherpaOnnxOnlineNemoCtcModelConfig c;
// Zero the struct so every pointer member starts out null.
memset(&c, 0, sizeof(c));
if (!obj.Has("nemoCtc") || !obj.Get("nemoCtc").IsObject()) {
return c;
}
Napi::Object o = obj.Get("nemoCtc").As<Napi::Object>();
// NOTE(review): SHERPA_ONNX_ASSIGN_ATTR_STR is defined elsewhere; presumably
// it copies o's "model" string into c.model and relies on the local names
// `o` and `c` - confirm before renaming either variable.
SHERPA_ONNX_ASSIGN_ATTR_STR(model, model);
return c;
}
static SherpaOnnxOnlineParaformerModelConfig GetOnlineParaformerModelConfig(
Napi::Object obj) {
SherpaOnnxOnlineParaformerModelConfig c;
... ... @@ -103,6 +119,7 @@ SherpaOnnxOnlineModelConfig GetOnlineModelConfig(Napi::Object obj) {
c.transducer = GetOnlineTransducerModelConfig(o);
c.paraformer = GetOnlineParaformerModelConfig(o);
c.zipformer2_ctc = GetOnlineZipformer2CtcModelConfig(o);
c.nemo_ctc = GetOnlineNemoCtcModelConfig(o);
SHERPA_ONNX_ASSIGN_ATTR_STR(tokens, tokens);
SHERPA_ONNX_ASSIGN_ATTR_INT32(num_threads, numThreads);
... ... @@ -248,6 +265,7 @@ static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper(
SHERPA_ONNX_DELETE_C_STR(c.model_config.paraformer.encoder);
SHERPA_ONNX_DELETE_C_STR(c.model_config.paraformer.decoder);
SHERPA_ONNX_DELETE_C_STR(c.model_config.nemo_ctc.model);
SHERPA_ONNX_DELETE_C_STR(c.model_config.zipformer2_ctc.model);
SHERPA_ONNX_DELETE_C_STR(c.model_config.tokens);
SHERPA_ONNX_DELETE_C_STR(c.model_config.provider);
... ...
... ... @@ -46,10 +46,15 @@ export class OnlineZipformer2CtcModelConfig {
public model: string = '';
}
/**
 * Options for a streaming (online) NeMo CTC model.
 */
export class OnlineNemoCtcModelConfig {
/**
 * Model string; defaults to the empty string.
 * NOTE(review): presumably the path to the ONNX model file - confirm.
 */
public model: string = '';
}
export class OnlineModelConfig {
public transducer: OnlineTransducerModelConfig = new OnlineTransducerModelConfig();
public paraformer: OnlineParaformerModelConfig = new OnlineParaformerModelConfig();
public zipformer2_ctc: OnlineZipformer2CtcModelConfig = new OnlineZipformer2CtcModelConfig();
public nemo_ctc: OnlineNemoCtcModelConfig = new OnlineNemoCtcModelConfig();
public tokens: string = '';
public numThreads: number = 1;
public provider: string = 'cpu';
... ...
... ... @@ -338,7 +338,7 @@ void CNonStreamingSpeechRecognitionDlg::ShowInitRecognizerHelpMessage() {
msg +=
"wget "
"https://huggingface.co/csukuangfj/"
"sherpa-onnx-paraformer-zh-2023-09-14/resolve/main/model.onnx\r\n";
"sherpa-onnx-paraformer-zh-2023-09-14/resolve/main/model.int8.onnx\r\n";
msg +=
"wget "
"https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-09-14/"
... ...
... ... @@ -24,6 +24,7 @@ namespace SherpaOnnx
BpeVocab = "";
TokensBuf = "";
TokensBufSize = 0;
NemoCtc = new OnlineNemoCtcModelConfig();
}
public OnlineTransducerModelConfig Transducer;
... ... @@ -55,6 +56,8 @@ namespace SherpaOnnx
public string TokensBuf;
public int TokensBufSize;
public OnlineNemoCtcModelConfig NemoCtc;
}
}
... ...
/// Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
using System.Runtime.InteropServices;
namespace SherpaOnnx
{
/// <summary>
/// Configuration for a streaming (online) NeMo CTC model.
/// Declared with sequential layout so it can be passed across the
/// managed/native boundary.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct OnlineNemoCtcModelConfig
{
/// <summary>
/// Initializes <see cref="Model"/> to the empty string.
/// </summary>
public OnlineNemoCtcModelConfig()
{
Model = "";
}
/// <summary>
/// Model string, marshaled as a null-terminated single-byte (LPStr) string.
/// NOTE(review): presumably the path to the ONNX model file - confirm.
/// </summary>
[MarshalAs(UnmanagedType.LPStr)]
public string Model;
}
}
... ...
... ... @@ -77,6 +77,10 @@ type OnlineZipformer2CtcModelConfig struct {
Model string // Path to the onnx model
}
// Configuration for online/streaming NeMo CTC models
type OnlineNemoCtcModelConfig struct {
Model string // Path to the onnx model
}
// Configuration for online/streaming models
//
// Please refer to
... ... @@ -87,6 +91,7 @@ type OnlineModelConfig struct {
Transducer OnlineTransducerModelConfig
Paraformer OnlineParaformerModelConfig
Zipformer2Ctc OnlineZipformer2CtcModelConfig
NemoCtc OnlineNemoCtcModelConfig
Tokens string // Path to tokens.txt
NumThreads int // Number of threads to use for neural network computation
Provider string // Optional. Valid values are: cpu, cuda, coreml
... ... @@ -197,6 +202,9 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer {
c.model_config.zipformer2_ctc.model = C.CString(config.ModelConfig.Zipformer2Ctc.Model)
defer C.free(unsafe.Pointer(c.model_config.zipformer2_ctc.model))
c.model_config.nemo_ctc.model = C.CString(config.ModelConfig.NemoCtc.Model)
defer C.free(unsafe.Pointer(c.model_config.nemo_ctc.model))
c.model_config.tokens = C.CString(config.ModelConfig.Tokens)
defer C.free(unsafe.Pointer(c.model_config.tokens))
... ... @@ -1814,6 +1822,9 @@ func NewKeywordSpotter(config *KeywordSpotterConfig) *KeywordSpotter {
c.model_config.zipformer2_ctc.model = C.CString(config.ModelConfig.Zipformer2Ctc.Model)
defer C.free(unsafe.Pointer(c.model_config.zipformer2_ctc.model))
c.model_config.nemo_ctc.model = C.CString(config.ModelConfig.NemoCtc.Model)
defer C.free(unsafe.Pointer(c.model_config.nemo_ctc.model))
c.model_config.tokens = C.CString(config.ModelConfig.Tokens)
defer C.free(unsafe.Pointer(c.model_config.tokens))
... ...
... ... @@ -97,6 +97,9 @@ static sherpa_onnx::OnlineRecognizerConfig GetOnlineRecognizerConfig(
config->model_config.tokens_buf, config->model_config.tokens_buf_size);
}
recognizer_config.model_config.nemo_ctc.model =
SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, "");
recognizer_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.provider_config.provider =
... ... @@ -108,8 +111,7 @@ static sherpa_onnx::OnlineRecognizerConfig GetOnlineRecognizerConfig(
recognizer_config.model_config.model_type =
SHERPA_ONNX_OR(config->model_config.model_type, "");
recognizer_config.model_config.debug =
SHERPA_ONNX_OR(config->model_config.debug, 0);
recognizer_config.model_config.debug = config->model_config.debug;
recognizer_config.model_config.modeling_unit =
SHERPA_ONNX_OR(config->model_config.modeling_unit, "cjkchar");
... ... @@ -431,8 +433,7 @@ static sherpa_onnx::OfflineRecognizerConfig GetOfflineRecognizerConfig(
SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.debug =
SHERPA_ONNX_OR(config->model_config.debug, 0);
recognizer_config.model_config.debug = config->model_config.debug;
recognizer_config.model_config.provider =
SHERPA_ONNX_OR(config->model_config.provider, "cpu");
if (recognizer_config.model_config.provider.empty()) {
... ... @@ -759,6 +760,9 @@ static sherpa_onnx::KeywordSpotterConfig GetKeywordSpotterConfig(
spotter_config.model_config.zipformer2_ctc.model =
SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, "");
spotter_config.model_config.nemo_ctc.model =
SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, "");
spotter_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
if (config->model_config.tokens_buf &&
... ... @@ -777,8 +781,7 @@ static sherpa_onnx::KeywordSpotterConfig GetKeywordSpotterConfig(
spotter_config.model_config.model_type =
SHERPA_ONNX_OR(config->model_config.model_type, "");
spotter_config.model_config.debug =
SHERPA_ONNX_OR(config->model_config.debug, 0);
spotter_config.model_config.debug = config->model_config.debug;
spotter_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4);
... ... @@ -1055,7 +1058,7 @@ sherpa_onnx::VadModelConfig GetVadModelConfig(
vad_config.provider = "cpu";
}
vad_config.debug = SHERPA_ONNX_OR(config->debug, false);
vad_config.debug = config->debug;
if (vad_config.debug) {
#if __OHOS__
... ... @@ -1542,7 +1545,7 @@ GetSpeakerEmbeddingExtractorConfig(
c.model = SHERPA_ONNX_OR(config->model, "");
c.num_threads = SHERPA_ONNX_OR(config->num_threads, 1);
c.debug = SHERPA_ONNX_OR(config->debug, 0);
c.debug = config->debug;
c.provider = SHERPA_ONNX_OR(config->provider, "cpu");
if (c.provider.empty()) {
c.provider = "cpu";
... ...
... ... @@ -100,6 +100,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineZipformer2CtcModelConfig {
const char *model;
} SherpaOnnxOnlineZipformer2CtcModelConfig;
/// Configuration for an online (streaming) NeMo CTC model.
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineNemoCtcModelConfig {
/// Model string.
/// NOTE(review): presumably the path to the ONNX model file, with NULL
/// treated as "" by the config converters - confirm.
const char *model;
} SherpaOnnxOnlineNemoCtcModelConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig {
SherpaOnnxOnlineTransducerModelConfig transducer;
SherpaOnnxOnlineParaformerModelConfig paraformer;
... ... @@ -120,6 +124,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig {
const char *tokens_buf;
/// byte size excluding the trailing '\0'
int32_t tokens_buf_size;
SherpaOnnxOnlineNemoCtcModelConfig nemo_ctc;
} SherpaOnnxOnlineModelConfig;
/// It expects 16 kHz 16-bit single channel wave format.
... ...
... ... @@ -69,6 +69,8 @@ OnlineRecognizer OnlineRecognizer::Create(
c.model_config.zipformer2_ctc.model =
config.model_config.zipformer2_ctc.model.c_str();
c.model_config.nemo_ctc.model = config.model_config.nemo_ctc.model.c_str();
c.model_config.tokens = config.model_config.tokens.c_str();
c.model_config.num_threads = config.model_config.num_threads;
c.model_config.provider = config.model_config.provider.c_str();
... ... @@ -473,6 +475,8 @@ KeywordSpotter KeywordSpotter::Create(const KeywordSpotterConfig &config) {
c.model_config.zipformer2_ctc.model =
config.model_config.zipformer2_ctc.model.c_str();
c.model_config.nemo_ctc.model = config.model_config.nemo_ctc.model.c_str();
c.model_config.tokens = config.model_config.tokens.c_str();
c.model_config.num_threads = config.model_config.num_threads;
c.model_config.provider = config.model_config.provider.c_str();
... ...
... ... @@ -32,10 +32,15 @@ struct OnlineZipformer2CtcModelConfig {
std::string model;
};
// Configuration for an online (streaming) NeMo CTC model.
struct OnlineNemoCtcModelConfig {
// Model string; empty by default.
// NOTE(review): presumably the path to the ONNX model file - confirm.
std::string model;
};
struct OnlineModelConfig {
OnlineTransducerModelConfig transducer;
OnlineParaformerModelConfig paraformer;
OnlineZipformer2CtcModelConfig zipformer2_ctc;
OnlineNemoCtcModelConfig nemo_ctc;
std::string tokens;
int32_t num_threads = 1;
std::string provider = "cpu";
... ...
... ... @@ -175,6 +175,77 @@ class SileroVadModelRknn::Impl {
config_.silero_vad.threshold = threshold;
}
// Runs one inference step of the RKNN model.
//
// @param samples Pointer to n float samples. It is bound to input 0 as-is
//                (no copy is made here).
// @param n       Number of samples pointed to by `samples`.
// @return Element 0 of output 0.
//         NOTE(review): presumably the speech probability - confirm.
//
// Inputs 1.. are bound to states_[i - 1]. Outputs 1.. are pre-allocated on
// those same states_ buffers, so the states are overwritten in place by the
// call to rknn_outputs_get().
float Run(const float *samples, int32_t n) {
  std::vector<rknn_input> inputs(input_attrs_.size());

  for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
    auto &input = inputs[i];
    auto &attr = input_attrs_[i];
    input.index = attr.index;

    // FLOAT16 attributes are fed with FLOAT32 data; INT64 stays INT64.
    // Anything else is rejected.
    if (attr.type == RKNN_TENSOR_FLOAT16) {
      input.type = RKNN_TENSOR_FLOAT32;
    } else if (attr.type == RKNN_TENSOR_INT64) {
      input.type = RKNN_TENSOR_INT64;
    } else {
      SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
                       get_type_string(attr.type));
      SHERPA_ONNX_EXIT(-1);
    }

    input.fmt = attr.fmt;

    if (i == 0) {
      // rknn_input::buf is a non-const void *, hence the const_cast.
      input.buf = reinterpret_cast<void *>(const_cast<float *>(samples));
      input.size = n * sizeof(float);
    } else {
      input.buf = reinterpret_cast<void *>(states_[i - 1].data());
      input.size = states_[i - 1].size() * sizeof(float);
    }
  }

  std::vector<float> out(output_attrs_[0].n_elems);

  // The next states are written straight back into states_.
  auto &next_states = states_;

  std::vector<rknn_output> outputs(output_attrs_.size());
  // Cast keeps the comparison signed, matching the input loop above.
  for (int32_t i = 0; i < static_cast<int32_t>(outputs.size()); ++i) {
    auto &output = outputs[i];
    auto &attr = output_attrs_[i];
    output.index = attr.index;
    output.is_prealloc = 1;  // buffers are supplied by us below

    if (attr.type == RKNN_TENSOR_FLOAT16) {
      output.want_float = 1;
    } else if (attr.type == RKNN_TENSOR_INT64) {
      output.want_float = 0;
    } else {
      SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
                       get_type_string(attr.type));
      SHERPA_ONNX_EXIT(-1);
    }

    if (i == 0) {
      output.size = out.size() * sizeof(float);
      output.buf = reinterpret_cast<void *>(out.data());
    } else {
      output.size = next_states[i - 1].size() * sizeof(float);
      output.buf = reinterpret_cast<void *>(next_states[i - 1].data());
    }
  }

  auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data());
  SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");

  ret = rknn_run(ctx_, nullptr);
  SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");

  ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr);
  SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");

  return out[0];
}
private:
void Init(void *model_data, size_t model_data_length) {
InitContext(model_data, model_data_length, config_.debug, &ctx_);
... ... @@ -267,77 +338,6 @@ class SileroVadModelRknn::Impl {
Reset();
}
// Runs one inference step of the RKNN model.
//
// @param samples Pointer to n float samples. It is bound to input 0 as-is
//                (no copy is made here).
// @param n       Number of samples pointed to by `samples`.
// @return Element 0 of output 0.
//         NOTE(review): presumably the speech probability - confirm.
//
// Inputs 1.. are bound to states_[i - 1]. Outputs 1.. are pre-allocated on
// those same states_ buffers, so the states are overwritten in place by the
// call to rknn_outputs_get().
float Run(const float *samples, int32_t n) {
  std::vector<rknn_input> inputs(input_attrs_.size());

  for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
    auto &input = inputs[i];
    auto &attr = input_attrs_[i];
    input.index = attr.index;

    // FLOAT16 attributes are fed with FLOAT32 data; INT64 stays INT64.
    // Anything else is rejected.
    if (attr.type == RKNN_TENSOR_FLOAT16) {
      input.type = RKNN_TENSOR_FLOAT32;
    } else if (attr.type == RKNN_TENSOR_INT64) {
      input.type = RKNN_TENSOR_INT64;
    } else {
      SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
                       get_type_string(attr.type));
      SHERPA_ONNX_EXIT(-1);
    }

    input.fmt = attr.fmt;

    if (i == 0) {
      // rknn_input::buf is a non-const void *, hence the const_cast.
      input.buf = reinterpret_cast<void *>(const_cast<float *>(samples));
      input.size = n * sizeof(float);
    } else {
      input.buf = reinterpret_cast<void *>(states_[i - 1].data());
      input.size = states_[i - 1].size() * sizeof(float);
    }
  }

  std::vector<float> out(output_attrs_[0].n_elems);

  // The next states are written straight back into states_.
  auto &next_states = states_;

  std::vector<rknn_output> outputs(output_attrs_.size());
  // Cast keeps the comparison signed, matching the input loop above.
  for (int32_t i = 0; i < static_cast<int32_t>(outputs.size()); ++i) {
    auto &output = outputs[i];
    auto &attr = output_attrs_[i];
    output.index = attr.index;
    output.is_prealloc = 1;  // buffers are supplied by us below

    if (attr.type == RKNN_TENSOR_FLOAT16) {
      output.want_float = 1;
    } else if (attr.type == RKNN_TENSOR_INT64) {
      output.want_float = 0;
    } else {
      SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
                       get_type_string(attr.type));
      SHERPA_ONNX_EXIT(-1);
    }

    if (i == 0) {
      output.size = out.size() * sizeof(float);
      output.buf = reinterpret_cast<void *>(out.data());
    } else {
      output.size = next_states[i - 1].size() * sizeof(float);
      output.buf = reinterpret_cast<void *>(next_states[i - 1].data());
    }
  }

  auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data());
  SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");

  ret = rknn_run(ctx_, nullptr);
  SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");

  ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr);
  SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");

  return out[0];
}
private:
VadModelConfig config_;
rknn_context ctx_ = 0;
... ... @@ -395,6 +395,10 @@ void SileroVadModelRknn::SetThreshold(float threshold) {
impl_->SetThreshold(threshold);
}
// Forwards to Impl::Run() and returns its result unchanged.
float SileroVadModelRknn::Compute(const float *samples, int32_t n) {
return impl_->Run(samples, n);
}
#if __ANDROID_API__ >= 9
template SileroVadModelRknn::SileroVadModelRknn(AAssetManager *mgr,
const VadModelConfig &config);
... ...
... ... @@ -32,6 +32,7 @@ class SileroVadModelRknn : public VadModel {
* @return Return true if speech is detected. Return false otherwise.
*/
bool IsSpeech(const float *samples, int32_t n) override;
float Compute(const float *samples, int32_t n) override;
// For silero vad V4, it is WindowShift().
int32_t WindowSize() const override;
... ...
... ... @@ -89,8 +89,8 @@ void SafeJNI(JNIEnv *env, const char *functionName, Func func) {
}
// Helper function to validate JNI pointers
inline bool ValidatePointer(JNIEnv *env, jlong ptr,
const char *functionName, const char *message) {
inline bool ValidatePointer(JNIEnv *env, jlong ptr, const char *functionName,
const char *message) {
if (ptr == 0) {
jclass exClass = env->FindClass("java/lang/NullPointerException");
if (exClass != nullptr) {
... ...
... ... @@ -9,6 +9,9 @@
namespace sherpa_onnx {
OnlineModelConfig GetOnlineModelConfig(JNIEnv *env, jclass model_config_cls,
jobject model_config);
static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
KeywordSpotterConfig ans;
... ... @@ -57,54 +60,7 @@ static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
ans.model_config = GetOnlineModelConfig(env, model_config_cls, model_config);
return ans;
}
... ...
... ... @@ -10,6 +10,117 @@
namespace sherpa_onnx {
// Converts a Java com.k2fsa.sherpa.onnx.OnlineModelConfig object into the
// native OnlineModelConfig.
//
// @param env              JNI environment of the calling thread.
// @param model_config_cls Class of `model_config`, used for field lookups.
// @param model_config     The Java config object to read from.
// @return The populated native config.
//
// A String field that is null on the Java side, or whose UTF chars cannot be
// obtained, is skipped: the corresponding native member keeps its default
// value instead of a null pointer being dereferenced.
OnlineModelConfig GetOnlineModelConfig(JNIEnv *env, jclass model_config_cls,
                                       jobject model_config) {
  OnlineModelConfig ans;

  // Every string member written below is assumed to share this type.
  using String = decltype(ans.tokens);

  // Copies the java.lang.String field `name` of `obj` into *dst.
  auto read_string = [env](jclass cls, jobject obj, const char *name,
                           String *dst) {
    jfieldID fid = env->GetFieldID(cls, name, "Ljava/lang/String;");
    auto s = (jstring)env->GetObjectField(obj, fid);
    if (s == nullptr) {
      return;
    }

    const char *p = env->GetStringUTFChars(s, nullptr);
    if (p == nullptr) {
      return;  // the JVM could not provide the characters
    }

    *dst = p;
    env->ReleaseStringUTFChars(s, p);
  };

  // Returns the object-typed field `name` (JNI signature `sig`) of `obj`.
  auto read_object = [env](jclass cls, jobject obj, const char *name,
                           const char *sig) -> jobject {
    jfieldID fid = env->GetFieldID(cls, name, sig);
    return env->GetObjectField(obj, fid);
  };

  // transducer
  jobject transducer_config =
      read_object(model_config_cls, model_config, "transducer",
                  "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
  jclass transducer_config_cls = env->GetObjectClass(transducer_config);

  read_string(transducer_config_cls, transducer_config, "encoder",
              &ans.transducer.encoder);
  read_string(transducer_config_cls, transducer_config, "decoder",
              &ans.transducer.decoder);
  read_string(transducer_config_cls, transducer_config, "joiner",
              &ans.transducer.joiner);

  // paraformer
  jobject paraformer_config =
      read_object(model_config_cls, model_config, "paraformer",
                  "Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
  jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);

  read_string(paraformer_config_cls, paraformer_config, "encoder",
              &ans.paraformer.encoder);
  read_string(paraformer_config_cls, paraformer_config, "decoder",
              &ans.paraformer.decoder);

  // streaming zipformer2 CTC
  jobject zipformer2_ctc_config =
      read_object(model_config_cls, model_config, "zipformer2Ctc",
                  "Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;");
  jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config);

  read_string(zipformer2_ctc_config_cls, zipformer2_ctc_config, "model",
              &ans.zipformer2_ctc.model);

  // streaming NeMo CTC
  jobject nemo_ctc_config =
      read_object(model_config_cls, model_config, "neMoCtc",
                  "Lcom/k2fsa/sherpa/onnx/OnlineNeMoCtcModelConfig;");
  jclass nemo_ctc_config_cls = env->GetObjectClass(nemo_ctc_config);

  read_string(nemo_ctc_config_cls, nemo_ctc_config, "model",
              &ans.nemo_ctc.model);

  // top-level fields
  read_string(model_config_cls, model_config, "tokens", &ans.tokens);

  jfieldID fid = env->GetFieldID(model_config_cls, "numThreads", "I");
  ans.num_threads = env->GetIntField(model_config, fid);

  fid = env->GetFieldID(model_config_cls, "debug", "Z");
  ans.debug = env->GetBooleanField(model_config, fid);

  read_string(model_config_cls, model_config, "provider",
              &ans.provider_config.provider);
  read_string(model_config_cls, model_config, "modelType", &ans.model_type);
  read_string(model_config_cls, model_config, "modelingUnit",
              &ans.modeling_unit);
  read_string(model_config_cls, model_config, "bpeVocab", &ans.bpe_vocab);

  return ans;
}
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
OnlineRecognizerConfig ans;
... ... @@ -122,109 +233,7 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner = p;
env->ReleaseStringUTFChars(s, p);
// paraformer
fid = env->GetFieldID(model_config_cls, "paraformer",
"Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
jobject paraformer_config = env->GetObjectField(model_config, fid);
jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.decoder = p;
env->ReleaseStringUTFChars(s, p);
// streaming zipformer2 CTC
fid =
env->GetFieldID(model_config_cls, "zipformer2Ctc",
"Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;");
jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid);
jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config);
fid =
env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.zipformer2_ctc.model = p;
env->ReleaseStringUTFChars(s, p);
// streaming NeMo CTC
fid = env->GetFieldID(model_config_cls, "neMoCtc",
"Lcom/k2fsa/sherpa/onnx/OnlineNeMoCtcModelConfig;");
jobject nemo_ctc_config = env->GetObjectField(model_config, fid);
jclass nemo_ctc_config_cls = env->GetObjectClass(nemo_ctc_config);
fid = env->GetFieldID(nemo_ctc_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(nemo_ctc_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.nemo_ctc.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.modeling_unit = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.bpe_vocab = p;
env->ReleaseStringUTFChars(s, p);
ans.model_config = GetOnlineModelConfig(env, model_config_cls, model_config);
//---------- rnn lm model config ----------
fid = env->GetFieldID(cls, "lmConfig",
... ...
... ... @@ -165,6 +165,11 @@ type
function ToString: AnsiString;
end;
// Pascal-side configuration record for an online NeMo CTC model.
// It is embedded in TSherpaOnnxOnlineModelConfig as the NemoCtc field.
TSherpaOnnxOnlineNemoCtcModelConfig = record
// Model identifier as a managed string.
// NOTE(review): presumably the path to the model file - confirm with callers.
Model: AnsiString;
// Returns a textual description of the record that includes Model.
function ToString: AnsiString;
end;
TSherpaOnnxOnlineModelConfig = record
Transducer: TSherpaOnnxOnlineTransducerModelConfig;
Paraformer: TSherpaOnnxOnlineParaformerModelConfig;
... ... @@ -178,6 +183,7 @@ type
BpeVocab: AnsiString;
TokensBuf: AnsiString;
TokensBufSize: Integer;
NemoCtc: TSherpaOnnxOnlineNemoCtcModelConfig;
function ToString: AnsiString;
class operator Initialize({$IFDEF FPC}var{$ELSE}out{$ENDIF} Dest: TSherpaOnnxOnlineModelConfig);
end;
... ... @@ -691,6 +697,10 @@ type
Model: PAnsiChar;
end;
// Raw-pointer counterpart of TSherpaOnnxOnlineNemoCtcModelConfig.
// It is embedded in SherpaOnnxOnlineModelConfig as the NemoCtc field and is
// filled by casting the AnsiString Model field to PAnsiChar.
// NOTE(review): presumably mirrors the C struct of the same name, so the
// field layout must stay in sync with the C header - confirm.
SherpaOnnxOnlineNemoCtcModelConfig = record
// C-style string pointer to the model value.
Model: PAnsiChar;
end;
SherpaOnnxOnlineModelConfig= record
Transducer: SherpaOnnxOnlineTransducerModelConfig;
Paraformer: SherpaOnnxOnlineParaformerModelConfig;
... ... @@ -704,6 +714,7 @@ type
BpeVocab: PAnsiChar;
TokensBuf: PAnsiChar;
TokensBufSize: cint32;
NemoCtc: SherpaOnnxOnlineNemoCtcModelConfig;
end;
SherpaOnnxFeatureConfig = record
SampleRate: cint32;
... ... @@ -1311,6 +1322,12 @@ begin
[Self.Model]);
end;
// Returns a human-readable description of this record in the form
// 'TSherpaOnnxOnlineNemoCtcModelConfig(Model := <value of Model>)'.
function TSherpaOnnxOnlineNemoCtcModelConfig.ToString: AnsiString;
const
  Template = 'TSherpaOnnxOnlineNemoCtcModelConfig(Model := %s)';
begin
  Result := Format(Template, [Model]);
end;
function TSherpaOnnxOnlineModelConfig.ToString: AnsiString;
begin
Result := Format('TSherpaOnnxOnlineModelConfig(Transducer := %s, ' +
... ... @@ -1322,12 +1339,13 @@ begin
'Debug := %s, ' +
'ModelType := %s, ' +
'ModelingUnit := %s, ' +
'BpeVocab := %s)'
,
'BpeVocab := %s, ' +
'NemoCtc := %s',
[Self.Transducer.ToString, Self.Paraformer.ToString,
Self.Zipformer2Ctc.ToString, Self.Tokens,
Self.NumThreads, Self.Provider, Self.Debug.ToString,
Self.ModelType, Self.ModelingUnit, Self.BpeVocab
Self.ModelType, Self.ModelingUnit, Self.BpeVocab,
Self.NemoCtc.ToString
]);
end;
... ... @@ -1426,6 +1444,7 @@ begin
C.ModelConfig.Paraformer.Decoder := PAnsiChar(Config.ModelConfig.Paraformer.Decoder);
C.ModelConfig.Zipformer2Ctc.Model := PAnsiChar(Config.ModelConfig.Zipformer2Ctc.Model);
C.ModelConfig.NemoCtc.Model := PAnsiChar(Config.ModelConfig.NemoCtc.Model);
C.ModelConfig.Tokens := PAnsiChar(Config.ModelConfig.Tokens);
C.ModelConfig.NumThreads := Config.ModelConfig.NumThreads;
... ...
... ... @@ -128,77 +128,69 @@ class TestOfflineRecognizer(unittest.TestCase):
print(s2.result.text)
def test_paraformer_single_file(self):
for use_int8 in [True, False]:
if use_int8:
model = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx"
else:
model = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/model.onnx"
model = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx"
tokens = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt"
wave0 = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/0.wav"
tokens = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt"
wave0 = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/0.wav"
if not Path(model).is_file():
print("skipping test_paraformer_single_file()")
return
if not Path(model).is_file():
print("skipping test_paraformer_single_file()")
return
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=model,
tokens=tokens,
num_threads=1,
provider="cpu",
)
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=model,
tokens=tokens,
num_threads=1,
provider="cpu",
)
s = recognizer.create_stream()
samples, sample_rate = read_wave(wave0)
s.accept_waveform(sample_rate, samples)
recognizer.decode_stream(s)
print(s.result.text)
s = recognizer.create_stream()
samples, sample_rate = read_wave(wave0)
s.accept_waveform(sample_rate, samples)
recognizer.decode_stream(s)
print(s.result.text)
def test_paraformer_multiple_files(self):
for use_int8 in [True, False]:
if use_int8:
model = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx"
else:
model = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/model.onnx"
tokens = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt"
wave0 = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/0.wav"
wave1 = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/1.wav"
wave2 = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/2.wav"
wave3 = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/8k.wav"
if not Path(model).is_file():
print("skipping test_paraformer_multiple_files()")
return
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=model,
tokens=tokens,
num_threads=1,
provider="cpu",
)
s0 = recognizer.create_stream()
samples0, sample_rate0 = read_wave(wave0)
s0.accept_waveform(sample_rate0, samples0)
s1 = recognizer.create_stream()
samples1, sample_rate1 = read_wave(wave1)
s1.accept_waveform(sample_rate1, samples1)
s2 = recognizer.create_stream()
samples2, sample_rate2 = read_wave(wave2)
s2.accept_waveform(sample_rate2, samples2)
s3 = recognizer.create_stream()
samples3, sample_rate3 = read_wave(wave3)
s3.accept_waveform(sample_rate3, samples3)
recognizer.decode_streams([s0, s1, s2, s3])
print(s0.result.text)
print(s1.result.text)
print(s2.result.text)
print(s3.result.text)
model = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx"
tokens = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt"
wave0 = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/0.wav"
wave1 = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/1.wav"
wave2 = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/2.wav"
wave3 = f"{d}/sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/8k.wav"
if not Path(model).is_file():
print("skipping test_paraformer_multiple_files()")
return
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=model,
tokens=tokens,
num_threads=1,
provider="cpu",
)
s0 = recognizer.create_stream()
samples0, sample_rate0 = read_wave(wave0)
s0.accept_waveform(sample_rate0, samples0)
s1 = recognizer.create_stream()
samples1, sample_rate1 = read_wave(wave1)
s1.accept_waveform(sample_rate1, samples1)
s2 = recognizer.create_stream()
samples2, sample_rate2 = read_wave(wave2)
s2.accept_waveform(sample_rate2, samples2)
s3 = recognizer.create_stream()
samples3, sample_rate3 = read_wave(wave3)
s3.accept_waveform(sample_rate3, samples3)
recognizer.decode_streams([s0, s1, s2, s3])
print(s0.result.text)
print(s1.result.text)
print(s2.result.text)
print(s3.result.text)
def test_nemo_ctc_single_file(self):
for use_int8 in [True, False]:
... ...
... ... @@ -68,6 +68,14 @@ func sherpaOnnxOnlineZipformer2CtcModelConfig(
)
}
/// Return an instance of SherpaOnnxOnlineNemoCtcModelConfig.
///
/// - Parameter model: String stored in the `model` field after being
///   converted with `toCPointer`. Defaults to the empty string.
/// - Returns: A SherpaOnnxOnlineNemoCtcModelConfig whose `model` field holds
///   the converted pointer.
func sherpaOnnxOnlineNemoCtcModelConfig(
  model: String = ""
) -> SherpaOnnxOnlineNemoCtcModelConfig {
  let modelPointer = toCPointer(model)
  return SherpaOnnxOnlineNemoCtcModelConfig(model: modelPointer)
}
/// Return an instance of SherpaOnnxOnlineModelConfig.
///
/// Please refer to
... ... @@ -92,7 +100,8 @@ func sherpaOnnxOnlineModelConfig(
modelingUnit: String = "cjkchar",
bpeVocab: String = "",
tokensBuf: String = "",
tokensBufSize: Int = 0
tokensBufSize: Int = 0,
nemoCtc: SherpaOnnxOnlineNemoCtcModelConfig = sherpaOnnxOnlineNemoCtcModelConfig()
) -> SherpaOnnxOnlineModelConfig {
return SherpaOnnxOnlineModelConfig(
transducer: transducer,
... ... @@ -106,7 +115,8 @@ func sherpaOnnxOnlineModelConfig(
modeling_unit: toCPointer(modelingUnit),
bpe_vocab: toCPointer(bpeVocab),
tokens_buf: toCPointer(tokensBuf),
tokens_buf_size: Int32(tokensBufSize)
tokens_buf_size: Int32(tokensBufSize),
nemo_ctc: nemoCtc
)
}
... ...
... ... @@ -15,8 +15,8 @@ function freeConfig(config, Module) {
freeConfig(config.paraformer, Module)
}
if ('ctc' in config) {
freeConfig(config.ctc, Module)
if ('zipformer2Ctc' in config) {
freeConfig(config.zipformer2Ctc, Module)
}
if ('feat' in config) {
... ... @@ -157,6 +157,22 @@ function initSherpaOnnxOnlineZipformer2CtcModelConfig(config, Module) {
}
}
/**
 * Builds the in-heap representation of an online NeMo CTC model config.
 *
 * Two allocations are made through `Module._malloc`: first a buffer holding
 * the NUL-terminated UTF-8 bytes of `config.model` (an absent or falsy value
 * is treated as the empty string), then a 4-byte struct containing a single
 * pointer to that buffer.
 *
 * @param {{model?: string}} config - NeMo CTC settings; only `model` is read.
 * @param {object} Module - Object providing `lengthBytesUTF8`, `_malloc`,
 *     `stringToUTF8` and `setValue`.
 * @returns {{buffer: number, ptr: number, len: number}} `buffer` is the
 *     address of the string bytes, `ptr` the address of the struct and `len`
 *     the struct size in bytes. This function frees neither allocation.
 */
function initSherpaOnnxOnlineNemoCtcModelConfig(config, Module) {
  const POINTER_SIZE = 4;
  const POINTER_COUNT = 1;
  const modelPath = config.model || '';

  // +1 leaves room for the trailing NUL written by stringToUTF8.
  const byteCount = Module.lengthBytesUTF8(modelPath) + 1;
  const buffer = Module._malloc(byteCount);

  const len = POINTER_COUNT * POINTER_SIZE;
  const ptr = Module._malloc(len);

  Module.stringToUTF8(modelPath, buffer, byteCount);
  Module.setValue(ptr, buffer, 'i8*');

  return {buffer, ptr, len};
}
function initSherpaOnnxOnlineModelConfig(config, Module) {
if (!('transducer' in config)) {
config.transducer = {
... ... @@ -179,6 +195,12 @@ function initSherpaOnnxOnlineModelConfig(config, Module) {
};
}
if (!('nemoCtc' in config)) {
config.nemoCtc = {
model: '',
};
}
if (!('tokensBuf' in config)) {
config.tokensBuf = '';
}
... ... @@ -193,10 +215,15 @@ function initSherpaOnnxOnlineModelConfig(config, Module) {
const paraformer =
initSherpaOnnxOnlineParaformerModelConfig(config.paraformer, Module);
const ctc = initSherpaOnnxOnlineZipformer2CtcModelConfig(
const zipformer2Ctc = initSherpaOnnxOnlineZipformer2CtcModelConfig(
config.zipformer2Ctc, Module);
const len = transducer.len + paraformer.len + ctc.len + 9 * 4;
const nemoCtc =
initSherpaOnnxOnlineNemoCtcModelConfig(config.nemoCtc, Module);
const len =
transducer.len + paraformer.len + zipformer2Ctc.len + 9 * 4 + nemoCtc.len;
const ptr = Module._malloc(len);
let offset = 0;
... ... @@ -206,8 +233,8 @@ function initSherpaOnnxOnlineModelConfig(config, Module) {
Module._CopyHeap(paraformer.ptr, paraformer.len, ptr + offset);
offset += paraformer.len;
Module._CopyHeap(ctc.ptr, ctc.len, ptr + offset);
offset += ctc.len;
Module._CopyHeap(zipformer2Ctc.ptr, zipformer2Ctc.len, ptr + offset);
offset += zipformer2Ctc.len;
const tokensLen = Module.lengthBytesUTF8(config.tokens || '') + 1;
const providerLen = Module.lengthBytesUTF8(config.provider || 'cpu') + 1;
... ... @@ -240,7 +267,7 @@ function initSherpaOnnxOnlineModelConfig(config, Module) {
Module.stringToUTF8(config.tokensBuf || '', buffer + offset, tokensBufLen);
offset += tokensBufLen;
offset = transducer.len + paraformer.len + ctc.len;
offset = transducer.len + paraformer.len + zipformer2Ctc.len;
Module.setValue(ptr + offset, buffer, 'i8*'); // tokens
offset += 4;
... ... @@ -278,9 +305,12 @@ function initSherpaOnnxOnlineModelConfig(config, Module) {
Module.setValue(ptr + offset, config.tokensBufSize || 0, 'i32');
offset += 4;
Module._CopyHeap(nemoCtc.ptr, nemoCtc.len, ptr + offset);
offset += nemoCtc.len;
return {
buffer: buffer, ptr: ptr, len: len, transducer: transducer,
paraformer: paraformer, ctc: ctc
paraformer: paraformer, zipformer2Ctc: zipformer2Ctc, nemoCtc: nemoCtc
}
}
... ... @@ -485,6 +515,10 @@ function createOnlineRecognizer(Module, myConfig) {
model: '',
};
const onlineNemoCtcModelConfig = {
model: '',
};
let type = 0;
switch (type) {
... ... @@ -500,9 +534,13 @@ function createOnlineRecognizer(Module, myConfig) {
onlineParaformerModelConfig.decoder = './decoder.onnx';
break;
case 2:
// ctc
// zipformer2Ctc
onlineZipformer2CtcModelConfig.model = './encoder.onnx';
break;
case 3:
// nemoCtc
onlineNemoCtcModelConfig.model = './nemo-ctc.onnx';
break;
}
... ... @@ -510,6 +548,7 @@ function createOnlineRecognizer(Module, myConfig) {
transducer: onlineTransducerModelConfig,
paraformer: onlineParaformerModelConfig,
zipformer2Ctc: onlineZipformer2CtcModelConfig,
nemoCtc: onlineNemoCtcModelConfig,
tokens: './tokens.txt',
numThreads: 1,
provider: 'cpu',
... ...
... ... @@ -16,10 +16,12 @@ extern "C" {
static_assert(sizeof(SherpaOnnxOnlineTransducerModelConfig) == 3 * 4, "");
static_assert(sizeof(SherpaOnnxOnlineParaformerModelConfig) == 2 * 4, "");
static_assert(sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) == 1 * 4, "");
static_assert(sizeof(SherpaOnnxOnlineNemoCtcModelConfig) == 1 * 4, "");
static_assert(sizeof(SherpaOnnxOnlineModelConfig) ==
sizeof(SherpaOnnxOnlineTransducerModelConfig) +
sizeof(SherpaOnnxOnlineParaformerModelConfig) +
sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 9 * 4,
sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 9 * 4 +
sizeof(SherpaOnnxOnlineNemoCtcModelConfig),
"");
static_assert(sizeof(SherpaOnnxFeatureConfig) == 2 * 4, "");
static_assert(sizeof(SherpaOnnxOnlineCtcFstDecoderConfig) == 2 * 4, "");
... ... @@ -36,6 +38,7 @@ void MyPrint(SherpaOnnxOnlineRecognizerConfig *config) {
auto transducer_model_config = &model_config->transducer;
auto paraformer_model_config = &model_config->paraformer;
auto ctc_model_config = &model_config->zipformer2_ctc;
auto nemo_ctc = &model_config->nemo_ctc;
fprintf(stdout, "----------online transducer model config----------\n");
fprintf(stdout, "encoder: %s\n", transducer_model_config->encoder);
... ... @@ -46,8 +49,12 @@ void MyPrint(SherpaOnnxOnlineRecognizerConfig *config) {
fprintf(stdout, "encoder: %s\n", paraformer_model_config->encoder);
fprintf(stdout, "decoder: %s\n", paraformer_model_config->decoder);
fprintf(stdout, "----------online ctc model config----------\n");
fprintf(stdout, "----------online zipformer2 ctc model config----------\n");
fprintf(stdout, "model: %s\n", ctc_model_config->model);
fprintf(stdout, "----------online nemo ctc model config----------\n");
fprintf(stdout, "model: %s\n", nemo_ctc->model);
fprintf(stdout, "tokens: %s\n", model_config->tokens);
fprintf(stdout, "num_threads: %d\n", model_config->num_threads);
fprintf(stdout, "provider: %s\n", model_config->provider);
... ...
... ... @@ -73,9 +73,12 @@ function initModelConfig(config, Module) {
const transducer =
initSherpaOnnxOnlineTransducerModelConfig(config.transducer, Module);
const paraformer_len = 2 * 4
const ctc_len = 1 * 4
const zipfomer2_ctc_len = 1 * 4
const nemo_ctc_len = 1 * 4
const len = transducer.len + paraformer_len + zipfomer2_ctc_len + 9 * 4 +
nemo_ctc_len;
const len = transducer.len + paraformer_len + ctc_len + 9 * 4;
const ptr = Module._malloc(len);
Module.HEAPU8.fill(0, ptr, ptr + len);
... ... @@ -112,7 +115,7 @@ function initModelConfig(config, Module) {
Module.stringToUTF8(config.tokensBuf || '', buffer + offset, tokensBufLen);
offset += tokensBufLen;
offset = transducer.len + paraformer_len + ctc_len;
offset = transducer.len + paraformer_len + zipfomer2_ctc_len;
Module.setValue(ptr + offset, buffer, 'i8*'); // tokens
offset += 4;
... ...