Fangjun Kuang
Committed by GitHub

Add C++ runtime for SenseVoice models (#1148)

Showing 34 changed files with 1,159 additions and 38 deletions
@@ -15,7 +15,30 @@ echo "PATH: $PATH"

 which $EXE

-if false; then
+log "------------------------------------------------------------"
+log "Run SenseVoice models"
+log "------------------------------------------------------------"
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+repo=sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17
+
+for m in model.onnx model.int8.onnx; do
+  for w in zh en yue ja ko; do
+    for use_itn in 0 1; do
+      echo "$m $w $use_itn"
+      time $EXE \
+        --tokens=$repo/tokens.txt \
+        --sense-voice-model=$repo/$m \
+        --sense-voice-use-itn=$use_itn \
+        $repo/test_wavs/$w.wav
+    done
+  done
+done
+
+rm -rf $repo
+
+if true; then
   # It has problems with onnxruntime 1.18
   log "------------------------------------------------------------"
   log "Run Wenet models"
@@ -10,6 +10,18 @@ log() {

 export GIT_CLONE_PROTECTION_ACTIVE=false

+log "test offline SenseVoice CTC"
+url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+name=$(basename $url)
+repo=$(basename -s .tar.bz2 $name)
+
+curl -SL -O $url
+tar xvf $name
+rm $name
+ls -lh $repo
+python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py
+rm -rf $repo
+
 log "test offline TeleSpeech CTC"
 url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
 name=$(basename $url)
@@ -73,7 +73,7 @@ jobs:
 echo "pwd: $PWD"
 ls -lh ../scripts/sense-voice

-rm -rf ./
+rm -rf ./*

 cp -v ../scripts/sense-voice/*.onnx .
 cp -v ../scripts/sense-voice/tokens.txt .
@@ -111,3 +111,4 @@ sherpa-onnx-telespeech-ctc-*
 *.fst
 .ccache
 lib*.a
+sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17
+## 1.10.17
+
+* Support SenseVoice CTC models.
+
 ## 1.10.16

 * Support zh-en TTS model from MeloTTS.
@@ -11,7 +11,7 @@ project(sherpa-onnx)
 # ./nodejs-addon-examples
 # ./dart-api-examples/
 # ./CHANGELOG.md
-set(SHERPA_ONNX_VERSION "1.10.16")
+set(SHERPA_ONNX_VERSION "1.10.17")

 # Disable warning about
 #
+#!/usr/bin/env python3
+
+"""
+This file shows how to use a non-streaming SenseVoice CTC model from
+https://github.com/FunAudioLLM/SenseVoice
+to decode files.
+
+Please download model files from
+https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+
+For instance,
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+"""
+
+from pathlib import Path
+
+import sherpa_onnx
+import soundfile as sf
+
+
+def create_recognizer():
+    model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx"
+    tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"
+    test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav"
+    # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/en.wav"
+    # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/ja.wav"
+    # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/ko.wav"
+    # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/yue.wav"
+
+    if not Path(model).is_file() or not Path(test_wav).is_file():
+        raise ValueError(
+            """Please download model files from
+            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+            """
+        )
+    return (
+        sherpa_onnx.OfflineRecognizer.from_sense_voice(
+            model=model,
+            tokens=tokens,
+            use_itn=True,
+            debug=True,
+        ),
+        test_wav,
+    )
+
+
+def main():
+    recognizer, wave_filename = create_recognizer()
+
+    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
+    audio = audio[:, 0]  # only use the first channel
+
+    # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
+    # sample_rate does not need to be 16000 Hz
+
+    stream = recognizer.create_stream()
+    stream.accept_waveform(sample_rate, audio)
+    recognizer.decode_stream(stream)
+    print(wave_filename)
+    print(stream.result)
+
+
+if __name__ == "__main__":
+    main()
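
Note: this is the example that the CI change earlier in this commit invokes as `python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py` from the repository root, after downloading the model directory there.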
@@ -162,7 +162,9 @@ def main():
         "neg_mean": neg_mean,
         "inv_stddev": inv_stddev,
         "model_type": "sense_voice_ctc",
-        "version": "1",
+        # version 1: Use QInt8
+        # version 2: Use QUInt8
+        "version": "2",
         "model_author": "iic",
         "maintainer": "k2-fsa",
         "vocab_size": vocab_size,
@@ -185,7 +187,10 @@ def main():
         model_input=filename,
         model_output=filename_int8,
         op_types_to_quantize=["MatMul"],
-        weight_type=QuantType.QInt8,
+        # Note that we have to use QUInt8 here.
+        #
+        # When QInt8 is used, C++ onnxruntime produces incorrect results
+        weight_type=QuantType.QUInt8,
     )

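For reference, the quantization change above corresponds to a dynamic-quantization call like the following minimal sketch; the placeholder paths stand in for the script's own `filename` / `filename_int8` variables:

from onnxruntime.quantization import QuantType, quantize_dynamic

# Quantize only MatMul weights. Per the change above, QUInt8 is required:
# with QInt8 weights, C++ onnxruntime produces incorrect results for this
# model.
quantize_dynamic(
    model_input="model.onnx",        # placeholder for `filename`
    model_output="model.int8.onnx",  # placeholder for `filename_int8`
    op_types_to_quantize=["MatMul"],
    weight_type=QuantType.QUInt8,
)
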
@@ -310,6 +310,7 @@ struct SherpaOnnxOfflineStream {

 static sherpa_onnx::OfflineRecognizerConfig convertConfig(
     const SherpaOnnxOfflineRecognizerConfig *config);
+
 SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
     const SherpaOnnxOfflineRecognizerConfig *config) {
   sherpa_onnx::OfflineRecognizerConfig recognizer_config =
@@ -391,6 +392,15 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
   recognizer_config.model_config.telespeech_ctc =
       SHERPA_ONNX_OR(config->model_config.telespeech_ctc, "");

+  recognizer_config.model_config.sense_voice.model =
+      SHERPA_ONNX_OR(config->model_config.sense_voice.model, "");
+
+  recognizer_config.model_config.sense_voice.language =
+      SHERPA_ONNX_OR(config->model_config.sense_voice.language, "");
+
+  recognizer_config.model_config.sense_voice.use_itn =
+      config->model_config.sense_voice.use_itn;
+
   recognizer_config.lm_config.model =
       SHERPA_ONNX_OR(config->lm_config.model, "");
   recognizer_config.lm_config.scale =
@@ -379,6 +379,12 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineLMConfig {
   float scale;
 } SherpaOnnxOfflineLMConfig;

+SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSenseVoiceModelConfig {
+  const char *model;
+  const char *language;
+  int32_t use_itn;
+} SherpaOnnxOfflineSenseVoiceModelConfig;
+
 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
   SherpaOnnxOfflineTransducerModelConfig transducer;
   SherpaOnnxOfflineParaformerModelConfig paraformer;
@@ -398,6 +404,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
   const char *modeling_unit;
   const char *bpe_vocab;
   const char *telespeech_ctc;
+  SherpaOnnxOfflineSenseVoiceModelConfig sense_voice;
 } SherpaOnnxOfflineModelConfig;

 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {
@@ -36,6 +36,8 @@ set(sources
   offline-recognizer-impl.cc
   offline-recognizer.cc
   offline-rnn-lm.cc
+  offline-sense-voice-model-config.cc
+  offline-sense-voice-model.cc
   offline-stream.cc
   offline-tdnn-ctc-model.cc
   offline-tdnn-model-config.cc
-// sherpa-onnx/csrc/offline-ct-transformer-model-meta_data.h
+// sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h
 //
 // Copyright (c) 2024 Xiaomi Corporation
 #ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
@@ -93,6 +93,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,

 std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
     const OfflineModelConfig &config) {
+  // TODO(fangjun): Refactor it. We don't need to use model_type here
   ModelType model_type = ModelType::kUnknown;

   std::string filename;
@@ -148,6 +149,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(

 std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
     AAssetManager *mgr, const OfflineModelConfig &config) {
+  // TODO(fangjun): Refactor it. We don't need to use model_type here
   ModelType model_type = ModelType::kUnknown;

   std::string filename;
@@ -18,6 +18,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
   tdnn.Register(po);
   zipformer_ctc.Register(po);
   wenet_ctc.Register(po);
+  sense_voice.Register(po);

   po->Register("telespeech-ctc", &telespeech_ctc,
                "Path to model.onnx for telespeech ctc");
@@ -94,15 +95,21 @@ bool OfflineModelConfig::Validate() const {
     return wenet_ctc.Validate();
   }

+  if (!sense_voice.model.empty()) {
+    return sense_voice.Validate();
+  }
+
   if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
     SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
                      telespeech_ctc.c_str());
     return false;
-  } else {
-    return true;
   }

+  if (!transducer.encoder_filename.empty()) {
     return transducer.Validate();
+  }
+
+  return true;
 }

 std::string OfflineModelConfig::ToString() const {
@@ -116,6 +123,7 @@ std::string OfflineModelConfig::ToString() const {
   os << "tdnn=" << tdnn.ToString() << ", ";
   os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
   os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
+  os << "sense_voice=" << sense_voice.ToString() << ", ";
   os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
@@ -8,6 +8,7 @@

 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
+#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
 #include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
 #include "sherpa-onnx/csrc/offline-transducer-model-config.h"
 #include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
@@ -24,6 +25,7 @@ struct OfflineModelConfig {
   OfflineTdnnModelConfig tdnn;
   OfflineZipformerCtcModelConfig zipformer_ctc;
   OfflineWenetCtcModelConfig wenet_ctc;
+  OfflineSenseVoiceModelConfig sense_voice;
   std::string telespeech_ctc;

   std::string tokens;
@@ -53,6 +55,7 @@ struct OfflineModelConfig {
                      const OfflineTdnnModelConfig &tdnn,
                      const OfflineZipformerCtcModelConfig &zipformer_ctc,
                      const OfflineWenetCtcModelConfig &wenet_ctc,
+                     const OfflineSenseVoiceModelConfig &sense_voice,
                      const std::string &telespeech_ctc,
                      const std::string &tokens, int32_t num_threads, bool debug,
                      const std::string &provider, const std::string &model_type,
@@ -65,6 +68,7 @@ struct OfflineModelConfig {
         tdnn(tdnn),
         zipformer_ctc(zipformer_ctc),
         wenet_ctc(wenet_ctc),
+        sense_voice(sense_voice),
         telespeech_ctc(telespeech_ctc),
         tokens(tokens),
         num_threads(num_threads),
@@ -212,10 +212,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
     }
   }

-  OfflineRecognizerConfig GetConfig() const override {
-    return config_;
-  }
-
+  OfflineRecognizerConfig GetConfig() const override { return config_; }

  private:
   // Decode a single stream.
@@ -21,6 +21,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
@@ -31,6 +32,28 @@ namespace sherpa_onnx {

 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     const OfflineRecognizerConfig &config) {
+  if (!config.model_config.sense_voice.model.empty()) {
+    return std::make_unique<OfflineRecognizerSenseVoiceImpl>(config);
+  }
+
+  if (!config.model_config.paraformer.model.empty()) {
+    return std::make_unique<OfflineRecognizerParaformerImpl>(config);
+  }
+
+  if (!config.model_config.nemo_ctc.model.empty() ||
+      !config.model_config.zipformer_ctc.model.empty() ||
+      !config.model_config.tdnn.model.empty() ||
+      !config.model_config.wenet_ctc.model.empty()) {
+    return std::make_unique<OfflineRecognizerCtcImpl>(config);
+  }
+
+  if (!config.model_config.whisper.encoder.empty()) {
+    return std::make_unique<OfflineRecognizerWhisperImpl>(config);
+  }
+
+  // TODO(fangjun): Refactor it. We only need to use model type for the
+  // following models:
+  // 1. transducer and nemo_transducer
   if (!config.model_config.model_type.empty()) {
     const auto &model_type = config.model_config.model_type;
     if (model_type == "transducer") {
@@ -180,6 +203,28 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
 #if __ANDROID_API__ >= 9
 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     AAssetManager *mgr, const OfflineRecognizerConfig &config) {
+  if (!config.model_config.sense_voice.model.empty()) {
+    return std::make_unique<OfflineRecognizerSenseVoiceImpl>(mgr, config);
+  }
+
+  if (!config.model_config.paraformer.model.empty()) {
+    return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
+  }
+
+  if (!config.model_config.nemo_ctc.model.empty() ||
+      !config.model_config.zipformer_ctc.model.empty() ||
+      !config.model_config.tdnn.model.empty() ||
+      !config.model_config.wenet_ctc.model.empty()) {
+    return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
+  }
+
+  if (!config.model_config.whisper.encoder.empty()) {
+    return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
+  }
+
+  // TODO(fangjun): Refactor it. We only need to use model type for the
+  // following models:
+  // 1. transducer and nemo_transducer
   if (!config.model_config.model_type.empty()) {
     const auto &model_type = config.model_config.model_type;
     if (model_type == "transducer") {
@@ -102,9 +102,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
       exit(-1);
     }

-    // Paraformer models assume input samples are in the range
-    // [-32768, 32767], so we set normalize_samples to false
-    config_.feat_config.normalize_samples = false;
+    InitFeatConfig();
   }

 #if __ANDROID_API__ >= 9
@@ -124,9 +122,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
       exit(-1);
     }

-    // Paraformer models assume input samples are in the range
-    // [-32768, 32767], so we set normalize_samples to false
-    config_.feat_config.normalize_samples = false;
+    InitFeatConfig();
   }
 #endif

@@ -211,11 +207,18 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
     }
   }

-  OfflineRecognizerConfig GetConfig() const override {
-    return config_;
-  }
+  OfflineRecognizerConfig GetConfig() const override { return config_; }

  private:
+  void InitFeatConfig() {
+    // Paraformer models assume input samples are in the range
+    // [-32768, 32767], so we set normalize_samples to false
+    config_.feat_config.normalize_samples = false;
+    config_.feat_config.window_type = "hamming";
+    config_.feat_config.high_freq = 0;
+    config_.feat_config.snip_edges = true;
+  }
+
   std::vector<float> ApplyLFR(const std::vector<float> &in) const {
     int32_t lfr_window_size = model_->LfrWindowSize();
     int32_t lfr_window_shift = model_->LfrWindowShift();
+// sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
+#include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer.h"
+#include "sherpa-onnx/csrc/offline-sense-voice-model.h"
+#include "sherpa-onnx/csrc/pad-sequence.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
+
+namespace sherpa_onnx {
+
+static OfflineRecognitionResult ConvertSenseVoiceResult(
+    const OfflineCtcDecoderResult &src, const SymbolTable &sym_table,
+    int32_t frame_shift_ms, int32_t subsampling_factor) {
+  OfflineRecognitionResult r;
+  r.tokens.reserve(src.tokens.size());
+  r.timestamps.reserve(src.timestamps.size());
+
+  std::string text;
+
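+  // Note: the first 4 decoded tokens appear to be SenseVoice's special tags
+  // (language, emotion, event, with/without inverse text normalization), so
+  // they are skipped below when assembling the text and timestamps.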
+  for (int32_t i = 4; i < src.tokens.size(); ++i) {
+    auto sym = sym_table[src.tokens[i]];
+    text.append(sym);
+
+    r.tokens.push_back(std::move(sym));
+  }
+  r.text = std::move(text);
+
+  float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
+
+  for (int32_t i = 4; i < src.timestamps.size(); ++i) {
+    float time = frame_shift_s * (src.timestamps[i] - 4);
+    r.timestamps.push_back(time);
+  }
+
+  r.words = std::move(src.words);
+
+  return r;
+}
+
+class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
+ public:
+  explicit OfflineRecognizerSenseVoiceImpl(
+      const OfflineRecognizerConfig &config)
+      : OfflineRecognizerImpl(config),
+        config_(config),
+        symbol_table_(config_.model_config.tokens),
+        model_(std::make_unique<OfflineSenseVoiceModel>(config.model_config)) {
+    const auto &meta_data = model_->GetModelMetadata();
+    if (config.decoding_method == "greedy_search") {
+      decoder_ =
+          std::make_unique<OfflineCtcGreedySearchDecoder>(meta_data.blank_id);
+    } else {
+      SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
+                       config.decoding_method.c_str());
+      exit(-1);
+    }
+
+    InitFeatConfig();
+  }
+
+#if __ANDROID_API__ >= 9
+  OfflineRecognizerSenseVoiceImpl(AAssetManager *mgr,
+                                  const OfflineRecognizerConfig &config)
+      : OfflineRecognizerImpl(mgr, config),
+        config_(config),
+        symbol_table_(mgr, config_.model_config.tokens),
+        model_(std::make_unique<OfflineSenseVoiceModel>(mgr,
+                                                        config.model_config)) {
+    const auto &meta_data = model_->GetModelMetadata();
+    if (config.decoding_method == "greedy_search") {
+      decoder_ =
+          std::make_unique<OfflineCtcGreedySearchDecoder>(meta_data.blank_id);
+    } else {
+      SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
+                       config.decoding_method.c_str());
+      exit(-1);
+    }
+
+    InitFeatConfig();
+  }
+#endif
+
+  std::unique_ptr<OfflineStream> CreateStream() const override {
+    return std::make_unique<OfflineStream>(config_.feat_config);
+  }
+
+  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
+    if (n == 1) {
+      DecodeOneStream(ss[0]);
+      return;
+    }
+
+    const auto &meta_data = model_->GetModelMetadata();
+    // 1. Apply LFR
+    // 2. Apply CMVN
+    //
+    // Please refer to
+    // https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45555.pdf
+    // for what LFR means
+    //
+    // "Lower Frame Rate Neural Network Acoustic Models"
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::vector<Ort::Value> features;
+    features.reserve(n);
+
+    int32_t feat_dim = config_.feat_config.feature_dim * meta_data.window_size;
+
+    std::vector<std::vector<float>> features_vec(n);
+    std::vector<int32_t> features_length_vec(n);
+    for (int32_t i = 0; i != n; ++i) {
+      std::vector<float> f = ss[i]->GetFrames();
+
+      f = ApplyLFR(f);
+      ApplyCMVN(&f);
+
+      int32_t num_frames = f.size() / feat_dim;
+      features_vec[i] = std::move(f);
+
+      features_length_vec[i] = num_frames;
+
+      std::array<int64_t, 2> shape = {num_frames, feat_dim};
+
+      Ort::Value x = Ort::Value::CreateTensor(
+          memory_info, features_vec[i].data(), features_vec[i].size(),
+          shape.data(), shape.size());
+      features.push_back(std::move(x));
+    }
+
+    std::vector<const Ort::Value *> features_pointer(n);
+    for (int32_t i = 0; i != n; ++i) {
+      features_pointer[i] = &features[i];
+    }
+
+    std::array<int64_t, 1> features_length_shape = {n};
+    Ort::Value x_length = Ort::Value::CreateTensor(
+        memory_info, features_length_vec.data(), n,
+        features_length_shape.data(), features_length_shape.size());
+
+    // Caution(fangjun): We cannot pad it with log(eps),
+    // i.e., -23.025850929940457f
+    Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
+
+    int32_t language = 0;
+    if (config_.model_config.sense_voice.language.empty()) {
+      language = 0;
+    } else if (meta_data.lang2id.count(
+                   config_.model_config.sense_voice.language)) {
+      language =
+          meta_data.lang2id.at(config_.model_config.sense_voice.language);
+    } else {
+      SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.",
+                       config_.model_config.sense_voice.language.c_str());
+    }
+
+    std::vector<int32_t> language_array(n);
+    std::fill(language_array.begin(), language_array.end(), language);
+
+    std::vector<int32_t> text_norm_array(n);
+    std::fill(text_norm_array.begin(), text_norm_array.end(),
+              config_.model_config.sense_voice.use_itn
+                  ? meta_data.with_itn_id
+                  : meta_data.without_itn_id);
+
+    Ort::Value language_tensor = Ort::Value::CreateTensor(
+        memory_info, language_array.data(), n, features_length_shape.data(),
+        features_length_shape.size());
+
+    Ort::Value text_norm_tensor = Ort::Value::CreateTensor(
+        memory_info, text_norm_array.data(), n, features_length_shape.data(),
+        features_length_shape.size());
+
+    Ort::Value logits{nullptr};
+    try {
+      logits = model_->Forward(std::move(x), std::move(x_length),
+                               std::move(language_tensor),
+                               std::move(text_norm_tensor));
+    } catch (const Ort::Exception &ex) {
+      SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result",
+                       ex.what());
+      return;
+    }
+
+    // decoder_->Decode() requires that logits_length is of dtype int64
+    std::vector<int64_t> features_length_vec_64;
+    features_length_vec_64.reserve(n);
+    for (auto i : features_length_vec) {
+      i += 4;
+      features_length_vec_64.push_back(i);
+    }
+
+    Ort::Value logits_length = Ort::Value::CreateTensor(
+        memory_info, features_length_vec_64.data(), n,
+        features_length_shape.data(), features_length_shape.size());
+
+    auto results =
+        decoder_->Decode(std::move(logits), std::move(logits_length));
+
+    int32_t frame_shift_ms = 10;
+    int32_t subsampling_factor = meta_data.window_shift;
+    for (int32_t i = 0; i != n; ++i) {
+      auto r = ConvertSenseVoiceResult(results[i], symbol_table_,
+                                       frame_shift_ms, subsampling_factor);
+      r.text = ApplyInverseTextNormalization(std::move(r.text));
+      ss[i]->SetResult(r);
+    }
+  }
+
+  OfflineRecognizerConfig GetConfig() const override { return config_; }
+
+ private:
+  void DecodeOneStream(OfflineStream *s) const {
+    const auto &meta_data = model_->GetModelMetadata();
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    int32_t feat_dim = config_.feat_config.feature_dim * meta_data.window_size;
+    std::vector<float> f = s->GetFrames();
+    f = ApplyLFR(f);
+    ApplyCMVN(&f);
+    int32_t num_frames = f.size() / feat_dim;
+    std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
+    Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
+                                            shape.data(), shape.size());
+
+    int64_t scale_shape = 1;
+
+    Ort::Value x_length =
+        Ort::Value::CreateTensor(memory_info, &num_frames, 1, &scale_shape, 1);
+
+    int32_t language = 0;
+    if (config_.model_config.sense_voice.language.empty()) {
+      language = 0;
+    } else if (meta_data.lang2id.count(
+                   config_.model_config.sense_voice.language)) {
+      language =
+          meta_data.lang2id.at(config_.model_config.sense_voice.language);
+    } else {
+      SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.",
+                       config_.model_config.sense_voice.language.c_str());
+    }
+
+    int32_t text_norm = config_.model_config.sense_voice.use_itn
+                            ? meta_data.with_itn_id
+                            : meta_data.without_itn_id;
+
+    Ort::Value language_tensor =
+        Ort::Value::CreateTensor(memory_info, &language, 1, &scale_shape, 1);
+
+    Ort::Value text_norm_tensor =
+        Ort::Value::CreateTensor(memory_info, &text_norm, 1, &scale_shape, 1);
+
+    Ort::Value logits{nullptr};
+    try {
+      logits = model_->Forward(std::move(x), std::move(x_length),
+                               std::move(language_tensor),
+                               std::move(text_norm_tensor));
+    } catch (const Ort::Exception &ex) {
+      SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result",
+                       ex.what());
+      return;
+    }
+
+    int64_t new_num_frames = num_frames + 4;
+    Ort::Value logits_length = Ort::Value::CreateTensor(
+        memory_info, &new_num_frames, 1, &scale_shape, 1);
+
+    auto results =
+        decoder_->Decode(std::move(logits), std::move(logits_length));
+
+    int32_t frame_shift_ms = 10;
+    int32_t subsampling_factor = meta_data.window_shift;
+    auto r = ConvertSenseVoiceResult(results[0], symbol_table_, frame_shift_ms,
+                                     subsampling_factor);
+
+    r.text = ApplyInverseTextNormalization(std::move(r.text));
+    s->SetResult(r);
+  }
+
+  void InitFeatConfig() {
+    const auto &meta_data = model_->GetModelMetadata();
+
+    config_.feat_config.normalize_samples = meta_data.normalize_samples;
+    config_.feat_config.window_type = "hamming";
+    config_.feat_config.high_freq = 0;
+    config_.feat_config.snip_edges = true;
+  }
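+
+  // LFR (Low Frame Rate) stacks window_size consecutive feature frames into
+  // one and advances by window_shift frames. For example, with 80-dim fbank
+  // features and the typical window_size = 7, window_shift = 6 for this model
+  // family, each output frame is 7 * 80 = 560-dim and the frame rate drops
+  // from 100 frames/second to about 100 / 6 frames/second.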
+  std::vector<float> ApplyLFR(const std::vector<float> &in) const {
+    const auto &meta_data = model_->GetModelMetadata();
+
+    int32_t lfr_window_size = meta_data.window_size;
+    int32_t lfr_window_shift = meta_data.window_shift;
+    int32_t in_feat_dim = config_.feat_config.feature_dim;
+
+    int32_t in_num_frames = in.size() / in_feat_dim;
+    int32_t out_num_frames =
+        (in_num_frames - lfr_window_size) / lfr_window_shift + 1;
+    int32_t out_feat_dim = in_feat_dim * lfr_window_size;
+
+    std::vector<float> out(out_num_frames * out_feat_dim);
+
+    const float *p_in = in.data();
+    float *p_out = out.data();
+
+    for (int32_t i = 0; i != out_num_frames; ++i) {
+      std::copy(p_in, p_in + out_feat_dim, p_out);
+
+      p_out += out_feat_dim;
+      p_in += lfr_window_shift * in_feat_dim;
+    }
+
+    return out;
+  }
+
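+  // CMVN as an affine transform: the model's metadata stores the negated
+  // mean and the inverse stddev, so normalization reduces to
+  // (x + neg_mean) * inv_stddev per value instead of (x - mean) / stddev.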
+  void ApplyCMVN(std::vector<float> *v) const {
+    const auto &meta_data = model_->GetModelMetadata();
+
+    const std::vector<float> &neg_mean = meta_data.neg_mean;
+    const std::vector<float> &inv_stddev = meta_data.inv_stddev;
+
+    int32_t dim = neg_mean.size();
+    int32_t num_frames = v->size() / dim;
+
+    float *p = v->data();
+
+    for (int32_t i = 0; i != num_frames; ++i) {
+      for (int32_t k = 0; k != dim; ++k) {
+        p[k] = (p[k] + neg_mean[k]) * inv_stddev[k];
+      }
+
+      p += dim;
+    }
+  }
+
+  OfflineRecognizerConfig config_;
+  SymbolTable symbol_table_;
+  std::unique_ptr<OfflineSenseVoiceModel> model_;
+  std::unique_ptr<OfflineCtcDecoder> decoder_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_
+// sherpa-onnx/csrc/offline-sense-voice-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OfflineSenseVoiceModelConfig::Register(ParseOptions *po) {
+  po->Register("sense-voice-model", &model,
+               "Path to model.onnx of SenseVoice.");
+  po->Register(
+      "sense-voice-language", &language,
+      "Valid values: auto, zh, en, ja, ko, yue. If left empty, auto is used");
+  po->Register(
+      "sense-voice-use-itn", &use_itn,
+      "True to enable inverse text normalization. False to disable it.");
+}
+
+bool OfflineSenseVoiceModelConfig::Validate() const {
+  if (!FileExists(model)) {
+    SHERPA_ONNX_LOGE("SenseVoice model '%s' does not exist", model.c_str());
+    return false;
+  }
+
+  if (!language.empty()) {
+    if (language != "auto" && language != "zh" && language != "en" &&
+        language != "ja" && language != "ko" && language != "yue") {
+      SHERPA_ONNX_LOGE(
+          "Invalid sense-voice-language: '%s'. Valid values are: auto, zh, en, "
+          "ja, ko, yue. Or you can leave it empty to use 'auto'",
+          language.c_str());
+
+      return false;
+    }
+  }
+
+  return true;
+}
+
+std::string OfflineSenseVoiceModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineSenseVoiceModelConfig(";
+  os << "model=\"" << model << "\", ";
+  os << "language=\"" << language << "\", ";
+  os << "use_itn=" << (use_itn ? "True" : "False") << ")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/csrc/offline-sense-voice-model-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflineSenseVoiceModelConfig {
+  std::string model;
+
+  // "" or "auto" to let the model recognize the language
+  // valid values:
+  //  zh, en, ja, ko, yue, auto
+  std::string language = "auto";
+
+  // true to use inverse text normalization
+  // false to not use inverse text normalization
+  bool use_itn = false;
+
+  OfflineSenseVoiceModelConfig() = default;
+  explicit OfflineSenseVoiceModelConfig(const std::string &model,
+                                        const std::string &language,
+                                        bool use_itn)
+      : model(model), language(language), use_itn(use_itn) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
+// sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace sherpa_onnx {
+
+struct OfflineSenseVoiceModelMetaData {
+  // ID for using inverse text normalization
+  int32_t with_itn_id;
+
+  // ID for not using inverse text normalization
+  int32_t without_itn_id;
+
+  int32_t window_size;   // lfr_m
+  int32_t window_shift;  // lfr_n
+  int32_t vocab_size;
+
+  int32_t subsampling_factor = 1;
+
+  // Usually 0 for SenseVoice models.
+  // 0 means samples are scaled to [-32768, 32767] before being sent to the
+  // feature extractor
+  int32_t normalize_samples = 0;
+
+  int32_t blank_id = 0;
+
+  // possible values:
+  //  zh, en, ja, ko, yue, auto
+  // where
+  //  zh is Chinese (Mandarin)
+  //  en is English
+  //  ja is Japanese
+  //  ko is Korean
+  //  yue is Cantonese
+  //  auto is to let the model recognize the language
+  std::unordered_map<std::string, int32_t> lang2id;
+
+  std::vector<float> neg_mean;
+  std::vector<float> inv_stddev;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_
+// sherpa-onnx/csrc/offline-sense-voice-model.cc
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-sense-voice-model.h"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+class OfflineSenseVoiceModel::Impl {
+ public:
+  explicit Impl(const OfflineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    auto buf = ReadFile(config_.sense_voice.model);
+    Init(buf.data(), buf.size());
+  }
+
+#if __ANDROID_API__ >= 9
+  Impl(AAssetManager *mgr, const OfflineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    auto buf = ReadFile(mgr, config_.sense_voice.model);
+    Init(buf.data(), buf.size());
+  }
+#endif
+
+  Ort::Value Forward(Ort::Value features, Ort::Value features_length,
+                     Ort::Value language, Ort::Value text_norm) {
+    std::array<Ort::Value, 4> inputs = {
+        std::move(features),
+        std::move(features_length),
+        std::move(language),
+        std::move(text_norm),
+    };
+
+    auto ans =
+        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
+                   output_names_ptr_.data(), output_names_ptr_.size());
+    return std::move(ans[0]);
+  }
+
+  const OfflineSenseVoiceModelMetaData &GetModelMetadata() const {
+    return meta_data_;
+  }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+ private:
+  void Init(void *model_data, size_t model_data_length) {
+    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
+                                           sess_opts_);
+
+    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
+
+    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      PrintModelMetadata(os, meta_data);
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+    SHERPA_ONNX_READ_META_DATA(meta_data_.vocab_size, "vocab_size");
+    SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "lfr_window_size");
+    SHERPA_ONNX_READ_META_DATA(meta_data_.window_shift, "lfr_window_shift");
+    SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_samples,
+                               "normalize_samples");
+
+    SHERPA_ONNX_READ_META_DATA(meta_data_.with_itn_id, "with_itn");
+
+    SHERPA_ONNX_READ_META_DATA(meta_data_.without_itn_id, "without_itn");
+
+    int32_t lang_auto = 0;
+    int32_t lang_zh = 0;
+    int32_t lang_en = 0;
+    int32_t lang_ja = 0;
+    int32_t lang_ko = 0;
+    int32_t lang_yue = 0;
+
+    SHERPA_ONNX_READ_META_DATA(lang_auto, "lang_auto");
+    SHERPA_ONNX_READ_META_DATA(lang_zh, "lang_zh");
+    SHERPA_ONNX_READ_META_DATA(lang_en, "lang_en");
+    SHERPA_ONNX_READ_META_DATA(lang_ja, "lang_ja");
+    SHERPA_ONNX_READ_META_DATA(lang_ko, "lang_ko");
+    SHERPA_ONNX_READ_META_DATA(lang_yue, "lang_yue");
+
+    meta_data_.lang2id = {
+        {"auto", lang_auto}, {"zh", lang_zh},  {"en", lang_en},
+        {"ja", lang_ja},     {"ko", lang_ko},  {"yue", lang_yue},
+    };
+
+    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.neg_mean, "neg_mean");
+    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev, "inv_stddev");
+  }
+
+ private:
+  OfflineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> sess_;
+
+  std::vector<std::string> input_names_;
+  std::vector<const char *> input_names_ptr_;
+
+  std::vector<std::string> output_names_;
+  std::vector<const char *> output_names_ptr_;
+
+  OfflineSenseVoiceModelMetaData meta_data_;
+};
+
+OfflineSenseVoiceModel::OfflineSenseVoiceModel(const OfflineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OfflineSenseVoiceModel::OfflineSenseVoiceModel(AAssetManager *mgr,
+                                               const OfflineModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OfflineSenseVoiceModel::~OfflineSenseVoiceModel() = default;
+
+Ort::Value OfflineSenseVoiceModel::Forward(Ort::Value features,
+                                           Ort::Value features_length,
+                                           Ort::Value language,
+                                           Ort::Value text_norm) const {
+  return impl_->Forward(std::move(features), std::move(features_length),
+                        std::move(language), std::move(text_norm));
+}
+
+const OfflineSenseVoiceModelMetaData &OfflineSenseVoiceModel::GetModelMetadata()
+    const {
+  return impl_->GetModelMetadata();
+}
+
+OrtAllocator *OfflineSenseVoiceModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/csrc/offline-sense-voice-model.h
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_
+
+#include <memory>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h"
+
+namespace sherpa_onnx {
+
+class OfflineSenseVoiceModel {
+ public:
+  explicit OfflineSenseVoiceModel(const OfflineModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+  OfflineSenseVoiceModel(AAssetManager *mgr, const OfflineModelConfig &config);
+#endif
+
+  ~OfflineSenseVoiceModel();
+
+  /** Run the forward method of the model.
+   *
+   * @param features  A tensor of shape (N, T, C). It is changed in-place.
+   * @param features_length  A 1-D tensor of shape (N,) containing number of
+   *                         valid frames in `features` before padding.
+   *                         Its dtype is int32_t.
+   * @param language  A 1-D tensor of shape (N,) with dtype int32_t
+   * @param text_norm  A 1-D tensor of shape (N,) with dtype int32_t
+   *
+   * @return Return logits of shape (N, T, C) with dtype float
+   *
+   * Note: The subsampling factor is 1 for SenseVoice, so there is
+   * no need to output logits_length.
+   */
+  Ort::Value Forward(Ort::Value features, Ort::Value features_length,
+                     Ort::Value language, Ort::Value text_norm) const;
+
+  const OfflineSenseVoiceModelMetaData &GetModelMetadata() const;
+
+  /** Return an allocator for allocating memory
+   */
+  OrtAllocator *Allocator() const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_
@@ -6,6 +6,8 @@

 #include <algorithm>
 #include <fstream>
+#include <functional>
+#include <numeric>
 #include <sstream>
 #include <string>

@@ -153,23 +155,60 @@ Ort::Value View(Ort::Value *v) {
   }
 }

+float ComputeSum(const Ort::Value *v, int32_t n /*= -1*/) {
+  std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
+  auto size = static_cast<int32_t>(std::accumulate(
+      shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
+  if (n != -1 && n < size && n > 0) {
+    size = n;
+  }
+
+  const float *p = v->GetTensorData<float>();
+
+  return std::accumulate(p, p + size, 0.0f);
+}
+
+float ComputeMean(const Ort::Value *v, int32_t n /*= -1*/) {
+  std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
+  auto size = static_cast<int32_t>(std::accumulate(
+      shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
+
+  if (n != -1 && n < size && n > 0) {
+    size = n;
+  }
+
+  auto sum = ComputeSum(v, n);
+  return sum / size;
+}
+
+void PrintShape(const Ort::Value *v) {
+  std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
+  std::ostringstream os;
+  for (auto i : shape) {
+    os << i << ", ";
+  }
+  os << "\n";
+  fprintf(stderr, "%s", os.str().c_str());
+}
+
 template <typename T /*= float*/>
-void Print1D(Ort::Value *v) {
+void Print1D(const Ort::Value *v) {
   std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
   const T *d = v->GetTensorData<T>();
   std::ostringstream os;
   for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
-    os << *d << " ";
+    os << d[i] << " ";
   }
   os << "\n";
   fprintf(stderr, "%s\n", os.str().c_str());
 }

-template void Print1D<int64_t>(Ort::Value *v);
-template void Print1D<float>(Ort::Value *v);
+template void Print1D<int64_t>(const Ort::Value *v);
+template void Print1D<int32_t>(const Ort::Value *v);
+template void Print1D<float>(const Ort::Value *v);

 template <typename T /*= float*/>
-void Print2D(Ort::Value *v) {
+void Print2D(const Ort::Value *v) {
   std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
   const T *d = v->GetTensorData<T>();

@@ -183,10 +222,10 @@ void Print2D(Ort::Value *v) {
   fprintf(stderr, "%s\n", os.str().c_str());
 }

-template void Print2D<int64_t>(Ort::Value *v);
-template void Print2D<float>(Ort::Value *v);
+template void Print2D<int64_t>(const Ort::Value *v);
+template void Print2D<float>(const Ort::Value *v);

-void Print3D(Ort::Value *v) {
+void Print3D(const Ort::Value *v) {
   std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
   const float *d = v->GetTensorData<float>();

@@ -202,7 +241,7 @@ void Print3D(Ort::Value *v) {
   fprintf(stderr, "\n");
 }

-void Print4D(Ort::Value *v) {
+void Print4D(const Ort::Value *v) {
   std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
   const float *d = v->GetTensorData<float>();

@@ -68,19 +68,24 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
 // Return a shallow copy
 Ort::Value View(Ort::Value *v);

+float ComputeSum(const Ort::Value *v, int32_t n = -1);
+float ComputeMean(const Ort::Value *v, int32_t n = -1);
+
 // Print a 1-D tensor to stderr
 template <typename T = float>
-void Print1D(Ort::Value *v);
+void Print1D(const Ort::Value *v);

 // Print a 2-D tensor to stderr
 template <typename T = float>
-void Print2D(Ort::Value *v);
+void Print2D(const Ort::Value *v);

 // Print a 3-D tensor to stderr
-void Print3D(Ort::Value *v);
+void Print3D(const Ort::Value *v);

 // Print a 4-D tensor to stderr
-void Print4D(Ort::Value *v);
+void Print4D(const Ort::Value *v);
+
+void PrintShape(const Ort::Value *v);

 template <typename T = float>
 void Fill(Ort::Value *tensor, T value) {
@@ -15,6 +15,7 @@ set(srcs
   offline-paraformer-model-config.cc
   offline-punctuation.cc
   offline-recognizer.cc
+  offline-sense-voice-model-config.cc
   offline-stream.cc
   offline-tdnn-model-config.cc
   offline-transducer-model-config.cc
@@ -10,6 +10,7 @@ @@ -10,6 +10,7 @@
10 #include "sherpa-onnx/csrc/offline-model-config.h" 10 #include "sherpa-onnx/csrc/offline-model-config.h"
11 #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" 11 #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
12 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" 12 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
  13 +#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
13 #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" 14 #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
14 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" 15 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
15 #include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h" 16 #include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
@@ -26,6 +27,7 @@ void PybindOfflineModelConfig(py::module *m) { @@ -26,6 +27,7 @@ void PybindOfflineModelConfig(py::module *m) {
26 PybindOfflineTdnnModelConfig(m); 27 PybindOfflineTdnnModelConfig(m);
27 PybindOfflineZipformerCtcModelConfig(m); 28 PybindOfflineZipformerCtcModelConfig(m);
28 PybindOfflineWenetCtcModelConfig(m); 29 PybindOfflineWenetCtcModelConfig(m);
  30 + PybindOfflineSenseVoiceModelConfig(m);
29 31
30 using PyClass = OfflineModelConfig; 32 using PyClass = OfflineModelConfig;
31 py::class_<PyClass>(*m, "OfflineModelConfig") 33 py::class_<PyClass>(*m, "OfflineModelConfig")
@@ -36,7 +38,8 @@ void PybindOfflineModelConfig(py::module *m) { @@ -36,7 +38,8 @@ void PybindOfflineModelConfig(py::module *m) {
36 const OfflineNemoEncDecCtcModelConfig &, 38 const OfflineNemoEncDecCtcModelConfig &,
37 const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &, 39 const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
38 const OfflineZipformerCtcModelConfig &, 40 const OfflineZipformerCtcModelConfig &,
39 - const OfflineWenetCtcModelConfig &, const std::string &, 41 + const OfflineWenetCtcModelConfig &,
  42 + const OfflineSenseVoiceModelConfig &, const std::string &,
40 const std::string &, int32_t, bool, const std::string &, 43 const std::string &, int32_t, bool, const std::string &,
41 const std::string &, const std::string &, const std::string &>(), 44 const std::string &, const std::string &, const std::string &>(),
42 py::arg("transducer") = OfflineTransducerModelConfig(), 45 py::arg("transducer") = OfflineTransducerModelConfig(),
@@ -46,6 +49,7 @@ void PybindOfflineModelConfig(py::module *m) { @@ -46,6 +49,7 @@ void PybindOfflineModelConfig(py::module *m) {
46 py::arg("tdnn") = OfflineTdnnModelConfig(), 49 py::arg("tdnn") = OfflineTdnnModelConfig(),
47 py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), 50 py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
48 py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), 51 py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
  52 + py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
49 py::arg("telespeech_ctc") = "", py::arg("tokens"), 53 py::arg("telespeech_ctc") = "", py::arg("tokens"),
50 py::arg("num_threads"), py::arg("debug") = false, 54 py::arg("num_threads"), py::arg("debug") = false,
51 py::arg("provider") = "cpu", py::arg("model_type") = "", 55 py::arg("provider") = "cpu", py::arg("model_type") = "",
@@ -57,6 +61,7 @@ void PybindOfflineModelConfig(py::module *m) { @@ -57,6 +61,7 @@ void PybindOfflineModelConfig(py::module *m) {
57 .def_readwrite("tdnn", &PyClass::tdnn) 61 .def_readwrite("tdnn", &PyClass::tdnn)
58 .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) 62 .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
59 .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) 63 .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
  64 + .def_readwrite("sense_voice", &PyClass::sense_voice)
60 .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) 65 .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
61 .def_readwrite("tokens", &PyClass::tokens) 66 .def_readwrite("tokens", &PyClass::tokens)
62 .def_readwrite("num_threads", &PyClass::num_threads) 67 .def_readwrite("num_threads", &PyClass::num_threads)
@@ -14,6 +14,7 @@ namespace sherpa_onnx { @@ -14,6 +14,7 @@ namespace sherpa_onnx {
14 void PybindOfflineParaformerModelConfig(py::module *m) { 14 void PybindOfflineParaformerModelConfig(py::module *m) {
15 using PyClass = OfflineParaformerModelConfig; 15 using PyClass = OfflineParaformerModelConfig;
16 py::class_<PyClass>(*m, "OfflineParaformerModelConfig") 16 py::class_<PyClass>(*m, "OfflineParaformerModelConfig")
  17 + .def(py::init<>())
17 .def(py::init<const std::string &>(), py::arg("model")) 18 .def(py::init<const std::string &>(), py::arg("model"))
18 .def_readwrite("model", &PyClass::model) 19 .def_readwrite("model", &PyClass::model)
19 .def("__str__", &PyClass::ToString); 20 .def("__str__", &PyClass::ToString);
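
The added py::init<>() lets Python code default-construct an empty paraformer config, e.g. when another model type such as SenseVoice is selected instead. A quick sketch:

    import sherpa_onnx

    cfg = sherpa_onnx.OfflineParaformerModelConfig()  # empty model path
    print(cfg)  # __str__ is bound to ToString()
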
  1 +// sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindOfflineSenseVoiceModelConfig(py::module *m) {
  15 + using PyClass = OfflineSenseVoiceModelConfig;
  16 + py::class_<PyClass>(*m, "OfflineSenseVoiceModelConfig")
  17 + .def(py::init<>())
  18 + .def(py::init<const std::string &, const std::string &, bool>(),
  19 + py::arg("model"), py::arg("language"), py::arg("use_itn"))
  20 + .def_readwrite("model", &PyClass::model)
  21 + .def_readwrite("language", &PyClass::language)
  22 + .def_readwrite("use_itn", &PyClass::use_itn)
  23 + .def("__str__", &PyClass::ToString);
  24 +}
  25 +
  26 +} // namespace sherpa_onnx
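
The resulting Python class mirrors the C++ OfflineSenseVoiceModelConfig one-to-one. A short sketch of what the binding exposes (the file name is illustrative):

    from sherpa_onnx import OfflineSenseVoiceModelConfig

    cfg = OfflineSenseVoiceModelConfig(
        model="model.int8.onnx", language="auto", use_itn=True
    )
    cfg.language = "zh"  # model, language, and use_itn are all read/write
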
  1 +// sherpa-onnx/python/csrc/offline-sense-voice-model-config.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineSenseVoiceModelConfig(py::module *m);
  13 +
  14 +}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
@@ -10,6 +10,7 @@ from _sherpa_onnx import ( @@ -10,6 +10,7 @@ from _sherpa_onnx import (
10 OfflineModelConfig, 10 OfflineModelConfig,
11 OfflineNemoEncDecCtcModelConfig, 11 OfflineNemoEncDecCtcModelConfig,
12 OfflineParaformerModelConfig, 12 OfflineParaformerModelConfig,
  13 + OfflineSenseVoiceModelConfig,
13 ) 14 )
14 from _sherpa_onnx import OfflineRecognizer as _Recognizer 15 from _sherpa_onnx import OfflineRecognizer as _Recognizer
15 from _sherpa_onnx import ( 16 from _sherpa_onnx import (
@@ -174,6 +175,88 @@ class OfflineRecognizer(object): @@ -174,6 +175,88 @@ class OfflineRecognizer(object):
174 return self 175 return self
175 176
176 @classmethod 177 @classmethod
  178 + def from_sense_voice(
  179 + cls,
  180 + model: str,
  181 + tokens: str,
  182 + num_threads: int = 1,
  183 + sample_rate: int = 16000,
  184 + feature_dim: int = 80,
  185 + decoding_method: str = "greedy_search",
  186 + debug: bool = False,
  187 + provider: str = "cpu",
  188 + language: str = "",
  189 + use_itn: bool = False,
  190 + rule_fsts: str = "",
  191 + rule_fars: str = "",
  192 + ):
  193 + """
  194 + Please refer to
  195 + `<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
  196 + to download pre-trained models.
  197 +
  198 + Args:
  199 + tokens:
  200 + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
  201 + columns::
  202 +
  203 + symbol integer_id
  204 +
  205 + model:
  206 + Path to ``model.onnx``.
  207 + num_threads:
  208 + Number of threads for neural network computation.
  209 + sample_rate:
  210 + Sample rate of the training data used to train the model.
  211 + feature_dim:
  212 + Dimension of the feature used to train the model.
  213 + decoding_method:
  214 +        The only valid value is ``greedy_search``.
  215 + debug:
  216 + True to show debug messages.
  217 + provider:
  218 + onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  219 + language:
  220 +        If not empty, valid values are: auto, zh, en, ja, ko, yue.
  221 + use_itn:
  222 + True to enable inverse text normalization; False to disable it.
  223 + rule_fsts:
  224 + If not empty, it specifies fsts for inverse text normalization.
  225 + If there are multiple fsts, they are separated by a comma.
  226 + rule_fars:
  227 + If not empty, it specifies fst archives for inverse text normalization.
  228 + If there are multiple archives, they are separated by a comma.
  229 + """
  230 + self = cls.__new__(cls)
  231 + model_config = OfflineModelConfig(
  232 + sense_voice=OfflineSenseVoiceModelConfig(
  233 + model=model,
  234 + language=language,
  235 + use_itn=use_itn,
  236 + ),
  237 + tokens=tokens,
  238 + num_threads=num_threads,
  239 + debug=debug,
  240 + provider=provider,
  241 + )
  242 +
  243 + feat_config = FeatureExtractorConfig(
  244 + sampling_rate=sample_rate,
  245 + feature_dim=feature_dim,
  246 + )
  247 +
  248 + recognizer_config = OfflineRecognizerConfig(
  249 + feat_config=feat_config,
  250 + model_config=model_config,
  251 + decoding_method=decoding_method,
  252 + rule_fsts=rule_fsts,
  253 + rule_fars=rule_fars,
  254 + )
  255 + self.recognizer = _Recognizer(recognizer_config)
  256 + self.config = recognizer_config
  257 + return self
  258 +
  259 + @classmethod
177 def from_paraformer( 260 def from_paraformer(
178 cls, 261 cls,
179 paraformer: str, 262 paraformer: str,
@@ -355,6 +355,18 @@ func sherpaOnnxOfflineTdnnModelConfig( @@ -355,6 +355,18 @@ func sherpaOnnxOfflineTdnnModelConfig(
355 ) 355 )
356 } 356 }
357 357
  358 +func sherpaOnnxOfflineSenseVoiceModelConfig(
  359 + model: String = "",
  360 + language: String = "",
  361 + useInverseTextNormalization: Bool = false
  362 +) -> SherpaOnnxOfflineSenseVoiceModelConfig {
  363 + return SherpaOnnxOfflineSenseVoiceModelConfig(
  364 + model: toCPointer(model),
  365 + language: toCPointer(language),
  366 + use_itn: useInverseTextNormalization ? 1 : 0
  367 + )
  368 +}
  369 +
358 func sherpaOnnxOfflineLMConfig( 370 func sherpaOnnxOfflineLMConfig(
359 model: String = "", 371 model: String = "",
360 scale: Float = 1.0 372 scale: Float = 1.0
@@ -378,7 +390,8 @@ func sherpaOnnxOfflineModelConfig( @@ -378,7 +390,8 @@ func sherpaOnnxOfflineModelConfig(
378 modelType: String = "", 390 modelType: String = "",
379 modelingUnit: String = "cjkchar", 391 modelingUnit: String = "cjkchar",
380 bpeVocab: String = "", 392 bpeVocab: String = "",
381 - teleSpeechCtc: String = "" 393 + teleSpeechCtc: String = "",
  394 + senseVoice: SherpaOnnxOfflineSenseVoiceModelConfig = sherpaOnnxOfflineSenseVoiceModelConfig()
382 ) -> SherpaOnnxOfflineModelConfig { 395 ) -> SherpaOnnxOfflineModelConfig {
383 return SherpaOnnxOfflineModelConfig( 396 return SherpaOnnxOfflineModelConfig(
384 transducer: transducer, 397 transducer: transducer,
@@ -393,7 +406,8 @@ func sherpaOnnxOfflineModelConfig( @@ -393,7 +406,8 @@ func sherpaOnnxOfflineModelConfig(
393 model_type: toCPointer(modelType), 406 model_type: toCPointer(modelType),
394 modeling_unit: toCPointer(modelingUnit), 407 modeling_unit: toCPointer(modelingUnit),
395 bpe_vocab: toCPointer(bpeVocab), 408 bpe_vocab: toCPointer(bpeVocab),
396 - telespeech_ctc: toCPointer(teleSpeechCtc) 409 + telespeech_ctc: toCPointer(teleSpeechCtc),
  410 + sense_voice: senseVoice
397 ) 411 )
398 } 412 }
399 413
@@ -17,6 +17,7 @@ func run() { @@ -17,6 +17,7 @@ func run() {
17 var modelConfig: SherpaOnnxOfflineModelConfig 17 var modelConfig: SherpaOnnxOfflineModelConfig
18 var modelType = "whisper" 18 var modelType = "whisper"
19 // modelType = "paraformer" 19 // modelType = "paraformer"
  20 + // modelType = "sense_voice"
20 21
21 if modelType == "whisper" { 22 if modelType == "whisper" {
22 let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx" 23 let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
@@ -47,6 +48,19 @@ func run() { @@ -47,6 +48,19 @@ func run() {
47 debug: 0, 48 debug: 0,
48 modelType: "paraformer" 49 modelType: "paraformer"
49 ) 50 )
  51 + } else if modelType == "sense_voice" {
  52 + let model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx"
  53 + let tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"
  54 + let senseVoiceConfig = sherpaOnnxOfflineSenseVoiceModelConfig(
  55 + model: model,
  56 + useInverseTextNormalization: true
  57 + )
  58 +
  59 + modelConfig = sherpaOnnxOfflineModelConfig(
  60 + tokens: tokens,
  61 + debug: 0,
  62 + senseVoice: senseVoiceConfig
  63 + )
50 } else { 64 } else {
51 print("Please specify a supported modelType \(modelType)") 65 print("Please specify a supported modelType \(modelType)")
52 return 66 return
@@ -63,7 +77,10 @@ func run() { @@ -63,7 +77,10 @@ func run() {
63 77
64 recognizer = SherpaOnnxOfflineRecognizer(config: &config) 78 recognizer = SherpaOnnxOfflineRecognizer(config: &config)
65 79
66 - let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav" 80 + var filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
  81 + if modelType == "sense_voice" {
  82 + filePath = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav"
  83 + }
67 let fileURL: NSURL = NSURL(fileURLWithPath: filePath) 84 let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
68 let audioFile = try! AVAudioFile(forReading: fileURL as URL) 85 let audioFile = try! AVAudioFile(forReading: fileURL as URL)
69 86