Committed by GitHub
Add C++ runtime for SenseVoice models (#1148)
Showing 34 changed files with 1,159 additions and 38 deletions
@@ -15,7 +15,30 @@ echo "PATH: $PATH"
| 15 | 15 | ||
| 16 | which $EXE | 16 | which $EXE |
| 17 | 17 | ||
| 18 | -if false; then | 18 | +log "------------------------------------------------------------" |
| 19 | +log "Run SenseVoice models" | ||
| 20 | +log "------------------------------------------------------------" | ||
| 21 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 22 | +tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 23 | +rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 24 | +repo=sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17 | ||
| 25 | + | ||
| 26 | +for m in model.onnx model.int8.onnx; do | ||
| 27 | + for w in zh en yue ja ko; do | ||
| 28 | + for use_itn in 0 1; do | ||
| 29 | + echo "$m $w $use_itn" | ||
| 30 | + time $EXE \ | ||
| 31 | + --tokens=$repo/tokens.txt \ | ||
| 32 | + --sense-voice-model=$repo/$m \ | ||
| 33 | + --sense-voice-use-itn=$use_itn \ | ||
| 34 | + $repo/test_wavs/$w.wav | ||
| 35 | + done | ||
| 36 | + done | ||
| 37 | +done | ||
| 38 | + | ||
| 39 | +rm -rf $repo | ||
| 40 | + | ||
| 41 | +if true; then | ||
| 19 | # It has problems with onnxruntime 1.18 | 42 | # It has problems with onnxruntime 1.18 |
| 20 | log "------------------------------------------------------------" | 43 | log "------------------------------------------------------------" |
| 21 | log "Run Wenet models" | 44 | log "Run Wenet models" |
@@ -10,6 +10,18 @@ log() {
| 10 | 10 | ||
| 11 | export GIT_CLONE_PROTECTION_ACTIVE=false | 11 | export GIT_CLONE_PROTECTION_ACTIVE=false |
| 12 | 12 | ||
| 13 | +log "test offline SenseVoice CTC" | ||
| 14 | +url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 15 | +name=$(basename $url) | ||
| 16 | +repo=$(basename -s .tar.bz2 $name) | ||
| 17 | + | ||
| 18 | +curl -SL -O $url | ||
| 19 | +tar xvf $name | ||
| 20 | +rm $name | ||
| 21 | +ls -lh $repo | ||
| 22 | +python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py | ||
| 23 | +rm -rf $repo | ||
| 24 | + | ||
| 13 | log "test offline TeleSpeech CTC" | 25 | log "test offline TeleSpeech CTC" |
| 14 | url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2 | 26 | url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2 |
| 15 | name=$(basename $url) | 27 | name=$(basename $url) |
@@ -73,7 +73,7 @@ jobs:
| 73 | echo "pwd: $PWD" | 73 | echo "pwd: $PWD" |
| 74 | ls -lh ../scripts/sense-voice | 74 | ls -lh ../scripts/sense-voice |
| 75 | 75 | ||
| 76 | - rm -rf ./ | 76 | + rm -rf ./* |
| 77 | 77 | ||
| 78 | cp -v ../scripts/sense-voice/*.onnx . | 78 | cp -v ../scripts/sense-voice/*.onnx . |
| 79 | cp -v ../scripts/sense-voice/tokens.txt . | 79 | cp -v ../scripts/sense-voice/tokens.txt . |
@@ -11,7 +11,7 @@ project(sherpa-onnx)
| 11 | # ./nodejs-addon-examples | 11 | # ./nodejs-addon-examples |
| 12 | # ./dart-api-examples/ | 12 | # ./dart-api-examples/ |
| 13 | # ./CHANGELOG.md | 13 | # ./CHANGELOG.md |
| 14 | -set(SHERPA_ONNX_VERSION "1.10.16") | 14 | +set(SHERPA_ONNX_VERSION "1.10.17") |
| 15 | 15 | ||
| 16 | # Disable warning about | 16 | # Disable warning about |
| 17 | # | 17 | # |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +""" | ||
| 4 | +This file shows how to use a non-streaming SenseVoice CTC model from | ||
| 5 | +https://github.com/FunAudioLLM/SenseVoice | ||
| 6 | +to decode files. | ||
| 7 | + | ||
| 8 | +Please download model files from | ||
| 9 | +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models | ||
| 10 | + | ||
| 11 | +For instance, | ||
| 12 | + | ||
| 13 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 14 | +tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 15 | +rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 16 | +""" | ||
| 17 | + | ||
| 18 | +from pathlib import Path | ||
| 19 | + | ||
| 20 | +import sherpa_onnx | ||
| 21 | +import soundfile as sf | ||
| 22 | + | ||
| 23 | + | ||
| 24 | +def create_recognizer(): | ||
| 25 | + model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx" | ||
| 26 | + tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt" | ||
| 27 | + test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav" | ||
| 28 | + # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/en.wav" | ||
| 29 | + # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/ja.wav" | ||
| 30 | + # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/ko.wav" | ||
| 31 | + # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/yue.wav" | ||
| 32 | + | ||
| 33 | + if not Path(model).is_file() or not Path(test_wav).is_file(): | ||
| 34 | + raise ValueError( | ||
| 35 | + """Please download model files from | ||
| 36 | + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models | ||
| 37 | + """ | ||
| 38 | + ) | ||
| 39 | + return ( | ||
| 40 | + sherpa_onnx.OfflineRecognizer.from_sense_voice( | ||
| 41 | + model=model, | ||
| 42 | + tokens=tokens, | ||
| 43 | + use_itn=True, | ||
| 44 | + debug=True, | ||
| 45 | + ), | ||
| 46 | + test_wav, | ||
| 47 | + ) | ||
| 48 | + | ||
| 49 | + | ||
| 50 | +def main(): | ||
| 51 | + recognizer, wave_filename = create_recognizer() | ||
| 52 | + | ||
| 53 | + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) | ||
| 54 | + audio = audio[:, 0] # only use the first channel | ||
| 55 | + | ||
| 56 | + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] | ||
| 57 | + # sample_rate does not need to be 16000 Hz | ||
| 58 | + | ||
| 59 | + stream = recognizer.create_stream() | ||
| 60 | + stream.accept_waveform(sample_rate, audio) | ||
| 61 | + recognizer.decode_stream(stream) | ||
| 62 | + print(wave_filename) | ||
| 63 | + print(stream.result) | ||
| 64 | + | ||
| 65 | + | ||
| 66 | +if __name__ == "__main__": | ||
| 67 | + main() |
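If you want to pin the language instead of relying on auto-detection, a hedged variant is sketched below. It assumes `from_sense_voice` forwards a `language` keyword mirroring the C++ `--sense-voice-language` option added in this commit (valid values there: auto, zh, en, ja, ko, yue); treat that keyword and the `stream.result.text` attribute as assumptions rather than documented API.

```python
# Hedged sketch, not part of this commit: pin the language instead of
# auto-detection. The `language` keyword is assumed to mirror the C++
# --sense-voice-language option (auto, zh, en, ja, ko, yue).
import sherpa_onnx
import soundfile as sf

repo = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17"

recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
    model=f"{repo}/model.int8.onnx",
    tokens=f"{repo}/tokens.txt",
    language="en",  # assumed keyword; see the note above
    use_itn=True,
)

audio, sample_rate = sf.read(
    f"{repo}/test_wavs/en.wav", dtype="float32", always_2d=True
)
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio[:, 0])  # first channel only
recognizer.decode_stream(stream)
print(stream.result.text)  # attribute name taken from the C++ result struct below
```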
@@ -162,7 +162,9 @@ def main():
| 162 | "neg_mean": neg_mean, | 162 | "neg_mean": neg_mean, |
| 163 | "inv_stddev": inv_stddev, | 163 | "inv_stddev": inv_stddev, |
| 164 | "model_type": "sense_voice_ctc", | 164 | "model_type": "sense_voice_ctc", |
| 165 | - "version": "1", | 165 | + # version 1: Use QInt8 |
| 166 | + # version 2: Use QUInt8 | ||
| 167 | + "version": "2", | ||
| 166 | "model_author": "iic", | 168 | "model_author": "iic", |
| 167 | "maintainer": "k2-fsa", | 169 | "maintainer": "k2-fsa", |
| 168 | "vocab_size": vocab_size, | 170 | "vocab_size": vocab_size, |
@@ -185,7 +187,10 @@ def main():
| 185 | model_input=filename, | 187 | model_input=filename, |
| 186 | model_output=filename_int8, | 188 | model_output=filename_int8, |
| 187 | op_types_to_quantize=["MatMul"], | 189 | op_types_to_quantize=["MatMul"], |
| 188 | - weight_type=QuantType.QInt8, | 190 | + # Note that we have to use QUInt8 here. |
| 191 | + # | ||
| 192 | + # When QInt8 is used, C++ onnxruntime produces incorrect results | ||
| 193 | + weight_type=QuantType.QUInt8, | ||
| 189 | ) | 194 | ) |
| 190 | 195 | ||
| 191 | 196 |
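For context, the quantization step above reads as follows in isolation. `quantize_dynamic` and `QuantType` are the standard `onnxruntime.quantization` APIs; the file names are placeholders.

```python
# Minimal sketch of the quantization step, assuming the standard
# onnxruntime.quantization API; file names are placeholders.
from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="model.onnx",
    model_output="model.int8.onnx",
    op_types_to_quantize=["MatMul"],
    # QUInt8 rather than QInt8: with QInt8 weights the C++ onnxruntime
    # produced incorrect results for this model (see the diff above).
    weight_type=QuantType.QUInt8,
)
```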
@@ -310,6 +310,7 @@ struct SherpaOnnxOfflineStream {
| 310 | 310 | ||
| 311 | static sherpa_onnx::OfflineRecognizerConfig convertConfig( | 311 | static sherpa_onnx::OfflineRecognizerConfig convertConfig( |
| 312 | const SherpaOnnxOfflineRecognizerConfig *config); | 312 | const SherpaOnnxOfflineRecognizerConfig *config); |
| 313 | + | ||
| 313 | SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | 314 | SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( |
| 314 | const SherpaOnnxOfflineRecognizerConfig *config) { | 315 | const SherpaOnnxOfflineRecognizerConfig *config) { |
| 315 | sherpa_onnx::OfflineRecognizerConfig recognizer_config = | 316 | sherpa_onnx::OfflineRecognizerConfig recognizer_config = |
@@ -391,6 +392,15 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
| 391 | recognizer_config.model_config.telespeech_ctc = | 392 | recognizer_config.model_config.telespeech_ctc = |
| 392 | SHERPA_ONNX_OR(config->model_config.telespeech_ctc, ""); | 393 | SHERPA_ONNX_OR(config->model_config.telespeech_ctc, ""); |
| 393 | 394 | ||
| 395 | + recognizer_config.model_config.sense_voice.model = | ||
| 396 | + SHERPA_ONNX_OR(config->model_config.sense_voice.model, ""); | ||
| 397 | + | ||
| 398 | + recognizer_config.model_config.sense_voice.language = | ||
| 399 | + SHERPA_ONNX_OR(config->model_config.sense_voice.language, ""); | ||
| 400 | + | ||
| 401 | + recognizer_config.model_config.sense_voice.use_itn = | ||
| 402 | + config->model_config.sense_voice.use_itn; | ||
| 403 | + | ||
| 394 | recognizer_config.lm_config.model = | 404 | recognizer_config.lm_config.model = |
| 395 | SHERPA_ONNX_OR(config->lm_config.model, ""); | 405 | SHERPA_ONNX_OR(config->lm_config.model, ""); |
| 396 | recognizer_config.lm_config.scale = | 406 | recognizer_config.lm_config.scale = |
@@ -379,6 +379,12 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineLMConfig {
| 379 | float scale; | 379 | float scale; |
| 380 | } SherpaOnnxOfflineLMConfig; | 380 | } SherpaOnnxOfflineLMConfig; |
| 381 | 381 | ||
| 382 | +SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSenseVoiceModelConfig { | ||
| 383 | + const char *model; | ||
| 384 | + const char *language; | ||
| 385 | + int32_t use_itn; | ||
| 386 | +} SherpaOnnxOfflineSenseVoiceModelConfig; | ||
| 387 | + | ||
| 382 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig { | 388 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig { |
| 383 | SherpaOnnxOfflineTransducerModelConfig transducer; | 389 | SherpaOnnxOfflineTransducerModelConfig transducer; |
| 384 | SherpaOnnxOfflineParaformerModelConfig paraformer; | 390 | SherpaOnnxOfflineParaformerModelConfig paraformer; |
@@ -398,6 +404,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
| 398 | const char *modeling_unit; | 404 | const char *modeling_unit; |
| 399 | const char *bpe_vocab; | 405 | const char *bpe_vocab; |
| 400 | const char *telespeech_ctc; | 406 | const char *telespeech_ctc; |
| 407 | + SherpaOnnxOfflineSenseVoiceModelConfig sense_voice; | ||
| 401 | } SherpaOnnxOfflineModelConfig; | 408 | } SherpaOnnxOfflineModelConfig; |
| 402 | 409 | ||
| 403 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { | 410 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { |
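A hedged usage sketch of the new C-API fields follows. `CreateOfflineRecognizer`, `DestroyOfflineRecognizer`, and the `tokens` member are part of the existing C API; the paths are placeholders, and zero-initialization stands in for every field not shown.

```c
// Hedged sketch of filling the new SenseVoice fields in the C API.
// Paths are placeholders; unset fields fall back to their defaults.
#include <string.h>

#include "sherpa-onnx/c-api/c-api.h"

void example(void) {
  SherpaOnnxOfflineRecognizerConfig config;
  memset(&config, 0, sizeof(config));

  config.model_config.tokens = "tokens.txt";
  config.model_config.sense_voice.model = "model.int8.onnx";
  config.model_config.sense_voice.language = "auto";
  config.model_config.sense_voice.use_itn = 1;

  SherpaOnnxOfflineRecognizer *recognizer = CreateOfflineRecognizer(&config);
  // ... create a stream, accept a waveform, decode ...
  DestroyOfflineRecognizer(recognizer);
}
```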
@@ -36,6 +36,8 @@ set(sources
| 36 | offline-recognizer-impl.cc | 36 | offline-recognizer-impl.cc |
| 37 | offline-recognizer.cc | 37 | offline-recognizer.cc |
| 38 | offline-rnn-lm.cc | 38 | offline-rnn-lm.cc |
| 39 | + offline-sense-voice-model-config.cc | ||
| 40 | + offline-sense-voice-model.cc | ||
| 39 | offline-stream.cc | 41 | offline-stream.cc |
| 40 | offline-tdnn-ctc-model.cc | 42 | offline-tdnn-ctc-model.cc |
| 41 | offline-tdnn-model-config.cc | 43 | offline-tdnn-model-config.cc |
| 1 | -// sherpa-onnx/csrc/offline-ct-transformer-model-meta_data.h | 1 | +// sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h |
| 2 | // | 2 | // |
| 3 | // Copyright (c) 2024 Xiaomi Corporation | 3 | // Copyright (c) 2024 Xiaomi Corporation |
| 4 | #ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_ | 4 | #ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_ |
@@ -93,6 +93,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
| 93 | 93 | ||
| 94 | std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | 94 | std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( |
| 95 | const OfflineModelConfig &config) { | 95 | const OfflineModelConfig &config) { |
| 96 | + // TODO(fangjun): Refactor it. We don't need to use model_type here | ||
| 96 | ModelType model_type = ModelType::kUnknown; | 97 | ModelType model_type = ModelType::kUnknown; |
| 97 | 98 | ||
| 98 | std::string filename; | 99 | std::string filename; |
@@ -148,6 +149,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
| 148 | 149 | ||
| 149 | std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | 150 | std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( |
| 150 | AAssetManager *mgr, const OfflineModelConfig &config) { | 151 | AAssetManager *mgr, const OfflineModelConfig &config) { |
| 152 | + // TODO(fangjun): Refactor it. We don't need to use model_type here | ||
| 151 | ModelType model_type = ModelType::kUnknown; | 153 | ModelType model_type = ModelType::kUnknown; |
| 152 | 154 | ||
| 153 | std::string filename; | 155 | std::string filename; |
@@ -18,6 +18,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
| 18 | tdnn.Register(po); | 18 | tdnn.Register(po); |
| 19 | zipformer_ctc.Register(po); | 19 | zipformer_ctc.Register(po); |
| 20 | wenet_ctc.Register(po); | 20 | wenet_ctc.Register(po); |
| 21 | + sense_voice.Register(po); | ||
| 21 | 22 | ||
| 22 | po->Register("telespeech-ctc", &telespeech_ctc, | 23 | po->Register("telespeech-ctc", &telespeech_ctc, |
| 23 | "Path to model.onnx for telespeech ctc"); | 24 | "Path to model.onnx for telespeech ctc"); |
@@ -94,15 +95,21 @@ bool OfflineModelConfig::Validate() const {
| 94 | return wenet_ctc.Validate(); | 95 | return wenet_ctc.Validate(); |
| 95 | } | 96 | } |
| 96 | 97 | ||
| 98 | + if (!sense_voice.model.empty()) { | ||
| 99 | + return sense_voice.Validate(); | ||
| 100 | + } | ||
| 101 | + | ||
| 97 | if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) { | 102 | if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) { |
| 98 | SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist", | 103 | SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist", |
| 99 | telespeech_ctc.c_str()); | 104 | telespeech_ctc.c_str()); |
| 100 | return false; | 105 | return false; |
| 101 | - } else { | ||
| 102 | - return true; | ||
| 103 | } | 106 | } |
| 104 | 107 | ||
| 108 | + if (!transducer.encoder_filename.empty()) { | ||
| 105 | return transducer.Validate(); | 109 | return transducer.Validate(); |
| 110 | + } | ||
| 111 | + | ||
| 112 | + return true; | ||
| 106 | } | 113 | } |
| 107 | 114 | ||
| 108 | std::string OfflineModelConfig::ToString() const { | 115 | std::string OfflineModelConfig::ToString() const { |
@@ -116,6 +123,7 @@ std::string OfflineModelConfig::ToString() const {
| 116 | os << "tdnn=" << tdnn.ToString() << ", "; | 123 | os << "tdnn=" << tdnn.ToString() << ", "; |
| 117 | os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; | 124 | os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; |
| 118 | os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; | 125 | os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; |
| 126 | + os << "sense_voice=" << sense_voice.ToString() << ", "; | ||
| 119 | os << "telespeech_ctc=\"" << telespeech_ctc << "\", "; | 127 | os << "telespeech_ctc=\"" << telespeech_ctc << "\", "; |
| 120 | os << "tokens=\"" << tokens << "\", "; | 128 | os << "tokens=\"" << tokens << "\", "; |
| 121 | os << "num_threads=" << num_threads << ", "; | 129 | os << "num_threads=" << num_threads << ", "; |
@@ -8,6 +8,7 @@
| 8 | 8 | ||
| 9 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" | 9 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" |
| 10 | #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" | 10 | #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" |
| 11 | +#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h" | ||
| 11 | #include "sherpa-onnx/csrc/offline-tdnn-model-config.h" | 12 | #include "sherpa-onnx/csrc/offline-tdnn-model-config.h" |
| 12 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" | 13 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" |
| 13 | #include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h" | 14 | #include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h" |
@@ -24,6 +25,7 @@ struct OfflineModelConfig {
| 24 | OfflineTdnnModelConfig tdnn; | 25 | OfflineTdnnModelConfig tdnn; |
| 25 | OfflineZipformerCtcModelConfig zipformer_ctc; | 26 | OfflineZipformerCtcModelConfig zipformer_ctc; |
| 26 | OfflineWenetCtcModelConfig wenet_ctc; | 27 | OfflineWenetCtcModelConfig wenet_ctc; |
| 28 | + OfflineSenseVoiceModelConfig sense_voice; | ||
| 27 | std::string telespeech_ctc; | 29 | std::string telespeech_ctc; |
| 28 | 30 | ||
| 29 | std::string tokens; | 31 | std::string tokens; |
@@ -53,6 +55,7 @@ struct OfflineModelConfig {
| 53 | const OfflineTdnnModelConfig &tdnn, | 55 | const OfflineTdnnModelConfig &tdnn, |
| 54 | const OfflineZipformerCtcModelConfig &zipformer_ctc, | 56 | const OfflineZipformerCtcModelConfig &zipformer_ctc, |
| 55 | const OfflineWenetCtcModelConfig &wenet_ctc, | 57 | const OfflineWenetCtcModelConfig &wenet_ctc, |
| 58 | + const OfflineSenseVoiceModelConfig &sense_voice, | ||
| 56 | const std::string &telespeech_ctc, | 59 | const std::string &telespeech_ctc, |
| 57 | const std::string &tokens, int32_t num_threads, bool debug, | 60 | const std::string &tokens, int32_t num_threads, bool debug, |
| 58 | const std::string &provider, const std::string &model_type, | 61 | const std::string &provider, const std::string &model_type, |
@@ -65,6 +68,7 @@ struct OfflineModelConfig {
| 65 | tdnn(tdnn), | 68 | tdnn(tdnn), |
| 66 | zipformer_ctc(zipformer_ctc), | 69 | zipformer_ctc(zipformer_ctc), |
| 67 | wenet_ctc(wenet_ctc), | 70 | wenet_ctc(wenet_ctc), |
| 71 | + sense_voice(sense_voice), | ||
| 68 | telespeech_ctc(telespeech_ctc), | 72 | telespeech_ctc(telespeech_ctc), |
| 69 | tokens(tokens), | 73 | tokens(tokens), |
| 70 | num_threads(num_threads), | 74 | num_threads(num_threads), |
@@ -212,10 +212,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
| 212 | } | 212 | } |
| 213 | } | 213 | } |
| 214 | 214 | ||
| 215 | - OfflineRecognizerConfig GetConfig() const override { | ||
| 216 | - return config_; | ||
| 217 | - } | ||
| 218 | - | 215 | + OfflineRecognizerConfig GetConfig() const override { return config_; } |
| 219 | 216 | ||
| 220 | private: | 217 | private: |
| 221 | // Decode a single stream. | 218 | // Decode a single stream. |
@@ -21,6 +21,7 @@
| 21 | #include "sherpa-onnx/csrc/macros.h" | 21 | #include "sherpa-onnx/csrc/macros.h" |
| 22 | #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" | 22 | #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" |
| 23 | #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" | 23 | #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" |
| 24 | +#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h" | ||
| 24 | #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" | 25 | #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" |
| 25 | #include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h" | 26 | #include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h" |
| 26 | #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h" | 27 | #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h" |
@@ -31,6 +32,28 @@ namespace sherpa_onnx {
| 31 | 32 | ||
| 32 | std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | 33 | std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( |
| 33 | const OfflineRecognizerConfig &config) { | 34 | const OfflineRecognizerConfig &config) { |
| 35 | + if (!config.model_config.sense_voice.model.empty()) { | ||
| 36 | + return std::make_unique<OfflineRecognizerSenseVoiceImpl>(config); | ||
| 37 | + } | ||
| 38 | + | ||
| 39 | + if (!config.model_config.paraformer.model.empty()) { | ||
| 40 | + return std::make_unique<OfflineRecognizerParaformerImpl>(config); | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | + if (!config.model_config.nemo_ctc.model.empty() || | ||
| 44 | + !config.model_config.zipformer_ctc.model.empty() || | ||
| 45 | + !config.model_config.tdnn.model.empty() || | ||
| 46 | + !config.model_config.wenet_ctc.model.empty()) { | ||
| 47 | + return std::make_unique<OfflineRecognizerCtcImpl>(config); | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + if (!config.model_config.whisper.encoder.empty()) { | ||
| 51 | + return std::make_unique<OfflineRecognizerWhisperImpl>(config); | ||
| 52 | + } | ||
| 53 | + | ||
| 54 | + // TODO(fangjun): Refactor it. We only need to use model type for the | ||
| 55 | + // following models: | ||
| 56 | + // 1. transducer and nemo_transducer | ||
| 34 | if (!config.model_config.model_type.empty()) { | 57 | if (!config.model_config.model_type.empty()) { |
| 35 | const auto &model_type = config.model_config.model_type; | 58 | const auto &model_type = config.model_config.model_type; |
| 36 | if (model_type == "transducer") { | 59 | if (model_type == "transducer") { |
@@ -180,6 +203,28 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
| 180 | #if __ANDROID_API__ >= 9 | 203 | #if __ANDROID_API__ >= 9 |
| 181 | std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | 204 | std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( |
| 182 | AAssetManager *mgr, const OfflineRecognizerConfig &config) { | 205 | AAssetManager *mgr, const OfflineRecognizerConfig &config) { |
| 206 | + if (!config.model_config.sense_voice.model.empty()) { | ||
| 207 | + return std::make_unique<OfflineRecognizerSenseVoiceImpl>(mgr, config); | ||
| 208 | + } | ||
| 209 | + | ||
| 210 | + if (!config.model_config.paraformer.model.empty()) { | ||
| 211 | + return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); | ||
| 212 | + } | ||
| 213 | + | ||
| 214 | + if (!config.model_config.nemo_ctc.model.empty() || | ||
| 215 | + !config.model_config.zipformer_ctc.model.empty() || | ||
| 216 | + !config.model_config.tdnn.model.empty() || | ||
 | 217 | + !config.model_config.wenet_ctc.model.empty()) { | ||
| 218 | + return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); | ||
| 219 | + } | ||
| 220 | + | ||
| 221 | + if (!config.model_config.whisper.encoder.empty()) { | ||
| 222 | + return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); | ||
| 223 | + } | ||
| 224 | + | ||
| 225 | + // TODO(fangjun): Refactor it. We only need to use model type for the | ||
| 226 | + // following models: | ||
| 227 | + // 1. transducer and nemo_transducer | ||
| 183 | if (!config.model_config.model_type.empty()) { | 228 | if (!config.model_config.model_type.empty()) { |
| 184 | const auto &model_type = config.model_config.model_type; | 229 | const auto &model_type = config.model_config.model_type; |
| 185 | if (model_type == "transducer") { | 230 | if (model_type == "transducer") { |
@@ -102,9 +102,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
| 102 | exit(-1); | 102 | exit(-1); |
| 103 | } | 103 | } |
| 104 | 104 | ||
| 105 | - // Paraformer models assume input samples are in the range | ||
| 106 | - // [-32768, 32767], so we set normalize_samples to false | ||
| 107 | - config_.feat_config.normalize_samples = false; | 105 | + InitFeatConfig(); |
| 108 | } | 106 | } |
| 109 | 107 | ||
| 110 | #if __ANDROID_API__ >= 9 | 108 | #if __ANDROID_API__ >= 9 |
@@ -124,9 +122,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
| 124 | exit(-1); | 122 | exit(-1); |
| 125 | } | 123 | } |
| 126 | 124 | ||
| 127 | - // Paraformer models assume input samples are in the range | ||
| 128 | - // [-32768, 32767], so we set normalize_samples to false | ||
| 129 | - config_.feat_config.normalize_samples = false; | 125 | + InitFeatConfig(); |
| 130 | } | 126 | } |
| 131 | #endif | 127 | #endif |
| 132 | 128 | ||
@@ -211,11 +207,18 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
| 211 | } | 207 | } |
| 212 | } | 208 | } |
| 213 | 209 | ||
| 214 | - OfflineRecognizerConfig GetConfig() const override { | ||
| 215 | - return config_; | ||
| 216 | - } | 210 | + OfflineRecognizerConfig GetConfig() const override { return config_; } |
| 217 | 211 | ||
| 218 | private: | 212 | private: |
| 213 | + void InitFeatConfig() { | ||
| 214 | + // Paraformer models assume input samples are in the range | ||
| 215 | + // [-32768, 32767], so we set normalize_samples to false | ||
| 216 | + config_.feat_config.normalize_samples = false; | ||
| 217 | + config_.feat_config.window_type = "hamming"; | ||
| 218 | + config_.feat_config.high_freq = 0; | ||
| 219 | + config_.feat_config.snip_edges = true; | ||
| 220 | + } | ||
| 221 | + | ||
| 219 | std::vector<float> ApplyLFR(const std::vector<float> &in) const { | 222 | std::vector<float> ApplyLFR(const std::vector<float> &in) const { |
| 220 | int32_t lfr_window_size = model_->LfrWindowSize(); | 223 | int32_t lfr_window_size = model_->LfrWindowSize(); |
| 221 | int32_t lfr_window_shift = model_->LfrWindowShift(); | 224 | int32_t lfr_window_shift = model_->LfrWindowShift(); |
| 1 | +// sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <memory> | ||
| 10 | +#include <string> | ||
| 11 | +#include <utility> | ||
| 12 | +#include <vector> | ||
| 13 | + | ||
| 14 | +#if __ANDROID_API__ >= 9 | ||
| 15 | +#include "android/asset_manager.h" | ||
| 16 | +#include "android/asset_manager_jni.h" | ||
| 17 | +#endif | ||
| 18 | + | ||
| 19 | +#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h" | ||
| 20 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 21 | +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" | ||
| 22 | +#include "sherpa-onnx/csrc/offline-recognizer.h" | ||
| 23 | +#include "sherpa-onnx/csrc/offline-sense-voice-model.h" | ||
| 24 | +#include "sherpa-onnx/csrc/pad-sequence.h" | ||
| 25 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 26 | + | ||
| 27 | +namespace sherpa_onnx { | ||
| 28 | + | ||
| 29 | +static OfflineRecognitionResult ConvertSenseVoiceResult( | ||
| 30 | + const OfflineCtcDecoderResult &src, const SymbolTable &sym_table, | ||
| 31 | + int32_t frame_shift_ms, int32_t subsampling_factor) { | ||
| 32 | + OfflineRecognitionResult r; | ||
| 33 | + r.tokens.reserve(src.tokens.size()); | ||
| 34 | + r.timestamps.reserve(src.timestamps.size()); | ||
| 35 | + | ||
| 36 | + std::string text; | ||
| 37 | + | ||
| 38 | + for (int32_t i = 4; i < src.tokens.size(); ++i) { | ||
| 39 | + auto sym = sym_table[src.tokens[i]]; | ||
| 40 | + text.append(sym); | ||
| 41 | + | ||
| 42 | + r.tokens.push_back(std::move(sym)); | ||
| 43 | + } | ||
| 44 | + r.text = std::move(text); | ||
| 45 | + | ||
| 46 | + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; | ||
| 47 | + | ||
| 48 | + for (int32_t i = 4; i < src.timestamps.size(); ++i) { | ||
| 49 | + float time = frame_shift_s * (src.timestamps[i] - 4); | ||
| 50 | + r.timestamps.push_back(time); | ||
| 51 | + } | ||
| 52 | + | ||
| 53 | + r.words = std::move(src.words); | ||
| 54 | + | ||
| 55 | + return r; | ||
| 56 | +} | ||
| 57 | + | ||
| 58 | +class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl { | ||
| 59 | + public: | ||
| 60 | + explicit OfflineRecognizerSenseVoiceImpl( | ||
| 61 | + const OfflineRecognizerConfig &config) | ||
| 62 | + : OfflineRecognizerImpl(config), | ||
| 63 | + config_(config), | ||
| 64 | + symbol_table_(config_.model_config.tokens), | ||
| 65 | + model_(std::make_unique<OfflineSenseVoiceModel>(config.model_config)) { | ||
| 66 | + const auto &meta_data = model_->GetModelMetadata(); | ||
| 67 | + if (config.decoding_method == "greedy_search") { | ||
| 68 | + decoder_ = | ||
| 69 | + std::make_unique<OfflineCtcGreedySearchDecoder>(meta_data.blank_id); | ||
| 70 | + } else { | ||
| 71 | + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", | ||
| 72 | + config.decoding_method.c_str()); | ||
| 73 | + exit(-1); | ||
| 74 | + } | ||
| 75 | + | ||
| 76 | + InitFeatConfig(); | ||
| 77 | + } | ||
| 78 | + | ||
| 79 | +#if __ANDROID_API__ >= 9 | ||
| 80 | + OfflineRecognizerSenseVoiceImpl(AAssetManager *mgr, | ||
| 81 | + const OfflineRecognizerConfig &config) | ||
| 82 | + : OfflineRecognizerImpl(mgr, config), | ||
| 83 | + config_(config), | ||
| 84 | + symbol_table_(mgr, config_.model_config.tokens), | ||
| 85 | + model_(std::make_unique<OfflineSenseVoiceModel>(mgr, | ||
| 86 | + config.model_config)) { | ||
| 87 | + const auto &meta_data = model_->GetModelMetadata(); | ||
| 88 | + if (config.decoding_method == "greedy_search") { | ||
| 89 | + decoder_ = | ||
| 90 | + std::make_unique<OfflineCtcGreedySearchDecoder>(meta_data.blank_id); | ||
| 91 | + } else { | ||
| 92 | + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", | ||
| 93 | + config.decoding_method.c_str()); | ||
| 94 | + exit(-1); | ||
| 95 | + } | ||
| 96 | + | ||
| 97 | + InitFeatConfig(); | ||
| 98 | + } | ||
| 99 | +#endif | ||
| 100 | + | ||
| 101 | + std::unique_ptr<OfflineStream> CreateStream() const override { | ||
| 102 | + return std::make_unique<OfflineStream>(config_.feat_config); | ||
| 103 | + } | ||
| 104 | + | ||
| 105 | + void DecodeStreams(OfflineStream **ss, int32_t n) const override { | ||
| 106 | + if (n == 1) { | ||
| 107 | + DecodeOneStream(ss[0]); | ||
| 108 | + return; | ||
| 109 | + } | ||
| 110 | + | ||
| 111 | + const auto &meta_data = model_->GetModelMetadata(); | ||
| 112 | + // 1. Apply LFR | ||
| 113 | + // 2. Apply CMVN | ||
| 114 | + // | ||
| 115 | + // Please refer to | ||
| 116 | + // https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45555.pdf | ||
| 117 | + // for what LFR means | ||
| 118 | + // | ||
| 119 | + // "Lower Frame Rate Neural Network Acoustic Models" | ||
| 120 | + auto memory_info = | ||
| 121 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 122 | + | ||
| 123 | + std::vector<Ort::Value> features; | ||
| 124 | + features.reserve(n); | ||
| 125 | + | ||
| 126 | + int32_t feat_dim = config_.feat_config.feature_dim * meta_data.window_size; | ||
| 127 | + | ||
| 128 | + std::vector<std::vector<float>> features_vec(n); | ||
| 129 | + std::vector<int32_t> features_length_vec(n); | ||
| 130 | + for (int32_t i = 0; i != n; ++i) { | ||
| 131 | + std::vector<float> f = ss[i]->GetFrames(); | ||
| 132 | + | ||
| 133 | + f = ApplyLFR(f); | ||
| 134 | + ApplyCMVN(&f); | ||
| 135 | + | ||
| 136 | + int32_t num_frames = f.size() / feat_dim; | ||
| 137 | + features_vec[i] = std::move(f); | ||
| 138 | + | ||
| 139 | + features_length_vec[i] = num_frames; | ||
| 140 | + | ||
| 141 | + std::array<int64_t, 2> shape = {num_frames, feat_dim}; | ||
| 142 | + | ||
| 143 | + Ort::Value x = Ort::Value::CreateTensor( | ||
| 144 | + memory_info, features_vec[i].data(), features_vec[i].size(), | ||
| 145 | + shape.data(), shape.size()); | ||
| 146 | + features.push_back(std::move(x)); | ||
| 147 | + } | ||
| 148 | + | ||
| 149 | + std::vector<const Ort::Value *> features_pointer(n); | ||
| 150 | + for (int32_t i = 0; i != n; ++i) { | ||
| 151 | + features_pointer[i] = &features[i]; | ||
| 152 | + } | ||
| 153 | + | ||
| 154 | + std::array<int64_t, 1> features_length_shape = {n}; | ||
| 155 | + Ort::Value x_length = Ort::Value::CreateTensor( | ||
| 156 | + memory_info, features_length_vec.data(), n, | ||
| 157 | + features_length_shape.data(), features_length_shape.size()); | ||
| 158 | + | ||
| 159 | + // Caution(fangjun): We cannot pad it with log(eps), | ||
| 160 | + // i.e., -23.025850929940457f | ||
| 161 | + Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0); | ||
| 162 | + | ||
| 163 | + int32_t language = 0; | ||
| 164 | + if (config_.model_config.sense_voice.language.empty()) { | ||
| 165 | + language = 0; | ||
| 166 | + } else if (meta_data.lang2id.count( | ||
| 167 | + config_.model_config.sense_voice.language)) { | ||
| 168 | + language = | ||
| 169 | + meta_data.lang2id.at(config_.model_config.sense_voice.language); | ||
| 170 | + } else { | ||
| 171 | + SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.", | ||
| 172 | + config_.model_config.sense_voice.language.c_str()); | ||
| 173 | + } | ||
| 174 | + | ||
| 175 | + std::vector<int32_t> language_array(n); | ||
| 176 | + std::fill(language_array.begin(), language_array.end(), language); | ||
| 177 | + | ||
| 178 | + std::vector<int32_t> text_norm_array(n); | ||
| 179 | + std::fill(text_norm_array.begin(), text_norm_array.end(), | ||
| 180 | + config_.model_config.sense_voice.use_itn | ||
| 181 | + ? meta_data.with_itn_id | ||
| 182 | + : meta_data.without_itn_id); | ||
| 183 | + | ||
| 184 | + Ort::Value language_tensor = Ort::Value::CreateTensor( | ||
| 185 | + memory_info, language_array.data(), n, features_length_shape.data(), | ||
| 186 | + features_length_shape.size()); | ||
| 187 | + | ||
| 188 | + Ort::Value text_norm_tensor = Ort::Value::CreateTensor( | ||
| 189 | + memory_info, text_norm_array.data(), n, features_length_shape.data(), | ||
| 190 | + features_length_shape.size()); | ||
| 191 | + | ||
| 192 | + Ort::Value logits{nullptr}; | ||
| 193 | + try { | ||
| 194 | + logits = model_->Forward(std::move(x), std::move(x_length), | ||
| 195 | + std::move(language_tensor), | ||
| 196 | + std::move(text_norm_tensor)); | ||
| 197 | + } catch (const Ort::Exception &ex) { | ||
| 198 | + SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result", | ||
| 199 | + ex.what()); | ||
| 200 | + return; | ||
| 201 | + } | ||
| 202 | + | ||
| 203 | + // decoder_->Decode() requires that logits_length is of dtype int64 | ||
| 204 | + std::vector<int64_t> features_length_vec_64; | ||
| 205 | + features_length_vec_64.reserve(n); | ||
| 206 | + for (auto i : features_length_vec) { | ||
| 207 | + i += 4; | ||
| 208 | + features_length_vec_64.push_back(i); | ||
| 209 | + } | ||
| 210 | + | ||
| 211 | + Ort::Value logits_length = Ort::Value::CreateTensor( | ||
| 212 | + memory_info, features_length_vec_64.data(), n, | ||
| 213 | + features_length_shape.data(), features_length_shape.size()); | ||
| 214 | + | ||
| 215 | + auto results = | ||
| 216 | + decoder_->Decode(std::move(logits), std::move(logits_length)); | ||
| 217 | + | ||
| 218 | + int32_t frame_shift_ms = 10; | ||
| 219 | + int32_t subsampling_factor = meta_data.window_shift; | ||
| 220 | + for (int32_t i = 0; i != n; ++i) { | ||
| 221 | + auto r = ConvertSenseVoiceResult(results[i], symbol_table_, | ||
| 222 | + frame_shift_ms, subsampling_factor); | ||
| 223 | + r.text = ApplyInverseTextNormalization(std::move(r.text)); | ||
| 224 | + ss[i]->SetResult(r); | ||
| 225 | + } | ||
| 226 | + } | ||
| 227 | + | ||
| 228 | + OfflineRecognizerConfig GetConfig() const override { return config_; } | ||
| 229 | + | ||
| 230 | + private: | ||
| 231 | + void DecodeOneStream(OfflineStream *s) const { | ||
| 232 | + const auto &meta_data = model_->GetModelMetadata(); | ||
| 233 | + | ||
| 234 | + auto memory_info = | ||
| 235 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 236 | + | ||
| 237 | + int32_t feat_dim = config_.feat_config.feature_dim * meta_data.window_size; | ||
| 238 | + std::vector<float> f = s->GetFrames(); | ||
| 239 | + f = ApplyLFR(f); | ||
| 240 | + ApplyCMVN(&f); | ||
| 241 | + int32_t num_frames = f.size() / feat_dim; | ||
| 242 | + std::array<int64_t, 3> shape = {1, num_frames, feat_dim}; | ||
| 243 | + Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), | ||
| 244 | + shape.data(), shape.size()); | ||
| 245 | + | ||
| 246 | + int64_t scale_shape = 1; | ||
| 247 | + | ||
| 248 | + Ort::Value x_length = | ||
| 249 | + Ort::Value::CreateTensor(memory_info, &num_frames, 1, &scale_shape, 1); | ||
| 250 | + | ||
| 251 | + int32_t language = 0; | ||
| 252 | + if (config_.model_config.sense_voice.language.empty()) { | ||
| 253 | + language = 0; | ||
| 254 | + } else if (meta_data.lang2id.count( | ||
| 255 | + config_.model_config.sense_voice.language)) { | ||
| 256 | + language = | ||
| 257 | + meta_data.lang2id.at(config_.model_config.sense_voice.language); | ||
| 258 | + } else { | ||
| 259 | + SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.", | ||
| 260 | + config_.model_config.sense_voice.language.c_str()); | ||
| 261 | + } | ||
| 262 | + | ||
| 263 | + int32_t text_norm = config_.model_config.sense_voice.use_itn | ||
| 264 | + ? meta_data.with_itn_id | ||
| 265 | + : meta_data.without_itn_id; | ||
| 266 | + | ||
| 267 | + Ort::Value language_tensor = | ||
| 268 | + Ort::Value::CreateTensor(memory_info, &language, 1, &scale_shape, 1); | ||
| 269 | + | ||
| 270 | + Ort::Value text_norm_tensor = | ||
| 271 | + Ort::Value::CreateTensor(memory_info, &text_norm, 1, &scale_shape, 1); | ||
| 272 | + | ||
| 273 | + Ort::Value logits{nullptr}; | ||
| 274 | + try { | ||
| 275 | + logits = model_->Forward(std::move(x), std::move(x_length), | ||
| 276 | + std::move(language_tensor), | ||
| 277 | + std::move(text_norm_tensor)); | ||
| 278 | + } catch (const Ort::Exception &ex) { | ||
| 279 | + SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result", | ||
| 280 | + ex.what()); | ||
| 281 | + return; | ||
| 282 | + } | ||
| 283 | + | ||
| 284 | + int64_t new_num_frames = num_frames + 4; | ||
| 285 | + Ort::Value logits_length = Ort::Value::CreateTensor( | ||
| 286 | + memory_info, &new_num_frames, 1, &scale_shape, 1); | ||
| 287 | + | ||
| 288 | + auto results = | ||
| 289 | + decoder_->Decode(std::move(logits), std::move(logits_length)); | ||
| 290 | + | ||
| 291 | + int32_t frame_shift_ms = 10; | ||
| 292 | + int32_t subsampling_factor = meta_data.window_shift; | ||
| 293 | + auto r = ConvertSenseVoiceResult(results[0], symbol_table_, frame_shift_ms, | ||
| 294 | + subsampling_factor); | ||
| 295 | + | ||
| 296 | + r.text = ApplyInverseTextNormalization(std::move(r.text)); | ||
| 297 | + s->SetResult(r); | ||
| 298 | + } | ||
| 299 | + | ||
| 300 | + void InitFeatConfig() { | ||
| 301 | + const auto &meta_data = model_->GetModelMetadata(); | ||
| 302 | + | ||
| 303 | + config_.feat_config.normalize_samples = meta_data.normalize_samples; | ||
| 304 | + config_.feat_config.window_type = "hamming"; | ||
| 305 | + config_.feat_config.high_freq = 0; | ||
| 306 | + config_.feat_config.snip_edges = true; | ||
| 307 | + } | ||
| 308 | + std::vector<float> ApplyLFR(const std::vector<float> &in) const { | ||
| 309 | + const auto &meta_data = model_->GetModelMetadata(); | ||
| 310 | + | ||
| 311 | + int32_t lfr_window_size = meta_data.window_size; | ||
| 312 | + int32_t lfr_window_shift = meta_data.window_shift; | ||
| 313 | + int32_t in_feat_dim = config_.feat_config.feature_dim; | ||
| 314 | + | ||
| 315 | + int32_t in_num_frames = in.size() / in_feat_dim; | ||
| 316 | + int32_t out_num_frames = | ||
| 317 | + (in_num_frames - lfr_window_size) / lfr_window_shift + 1; | ||
| 318 | + int32_t out_feat_dim = in_feat_dim * lfr_window_size; | ||
| 319 | + | ||
| 320 | + std::vector<float> out(out_num_frames * out_feat_dim); | ||
| 321 | + | ||
| 322 | + const float *p_in = in.data(); | ||
| 323 | + float *p_out = out.data(); | ||
| 324 | + | ||
| 325 | + for (int32_t i = 0; i != out_num_frames; ++i) { | ||
| 326 | + std::copy(p_in, p_in + out_feat_dim, p_out); | ||
| 327 | + | ||
| 328 | + p_out += out_feat_dim; | ||
| 329 | + p_in += lfr_window_shift * in_feat_dim; | ||
| 330 | + } | ||
| 331 | + | ||
| 332 | + return out; | ||
| 333 | + } | ||
| 334 | + | ||
| 335 | + void ApplyCMVN(std::vector<float> *v) const { | ||
| 336 | + const auto &meta_data = model_->GetModelMetadata(); | ||
| 337 | + | ||
| 338 | + const std::vector<float> &neg_mean = meta_data.neg_mean; | ||
| 339 | + const std::vector<float> &inv_stddev = meta_data.inv_stddev; | ||
| 340 | + | ||
| 341 | + int32_t dim = neg_mean.size(); | ||
| 342 | + int32_t num_frames = v->size() / dim; | ||
| 343 | + | ||
| 344 | + float *p = v->data(); | ||
| 345 | + | ||
| 346 | + for (int32_t i = 0; i != num_frames; ++i) { | ||
| 347 | + for (int32_t k = 0; k != dim; ++k) { | ||
| 348 | + p[k] = (p[k] + neg_mean[k]) * inv_stddev[k]; | ||
| 349 | + } | ||
| 350 | + | ||
| 351 | + p += dim; | ||
| 352 | + } | ||
| 353 | + } | ||
| 354 | + | ||
| 355 | + OfflineRecognizerConfig config_; | ||
| 356 | + SymbolTable symbol_table_; | ||
| 357 | + std::unique_ptr<OfflineSenseVoiceModel> model_; | ||
| 358 | + std::unique_ptr<OfflineCtcDecoder> decoder_; | ||
| 359 | +}; | ||
| 360 | + | ||
| 361 | +} // namespace sherpa_onnx | ||
| 362 | + | ||
| 363 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_ |
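A note on the recurring constant 4 in this file: SenseVoice prepends four special positions to its output, corresponding to the language, emotion, event, and text-normalization tokens (e.g. `<|zh|><|NEUTRAL|><|Speech|><|woitn|>`), which is why `ConvertSenseVoiceResult` skips the first four tokens and why `logits_length` is `num_frames + 4`. The frame bookkeeping in `ApplyLFR` can be sanity-checked with the hedged sketch below; `lfr_m = 7` and `lfr_n = 6` are typical FunASR values, assumed here purely for illustration.

```cpp
// Hedged sanity check of the ApplyLFR bookkeeping above.
// window_size (lfr_m) = 7 and window_shift (lfr_n) = 6 are typical
// FunASR values, assumed for illustration; real values come from the
// model metadata.
#include <cstdint>
#include <cstdio>

int main() {
  int32_t in_num_frames = 100;  // 10 ms frames -> about 1 second of audio
  int32_t in_feat_dim = 80;
  int32_t lfr_m = 7;  // meta_data.window_size
  int32_t lfr_n = 6;  // meta_data.window_shift

  // Same formulas as ApplyLFR: each output frame stacks lfr_m input
  // frames and the window advances by lfr_n input frames.
  int32_t out_num_frames = (in_num_frames - lfr_m) / lfr_n + 1;  // 16
  int32_t out_feat_dim = in_feat_dim * lfr_m;                    // 560

  std::printf("LFR output: %d frames x %d dims\n", out_num_frames,
              out_feat_dim);
  return 0;
}
```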
| 1 | +// sherpa-onnx/csrc/offline-sense-voice-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OfflineSenseVoiceModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("sense-voice-model", &model, | ||
| 14 | + "Path to model.onnx of SenseVoice."); | ||
| 15 | + po->Register( | ||
| 16 | + "sense-voice-language", &language, | ||
| 17 | + "Valid values: auto, zh, en, ja, ko, yue. If left empty, auto is used"); | ||
| 18 | + po->Register( | ||
| 19 | + "sense-voice-use-itn", &use_itn, | ||
| 20 | + "True to enable inverse text normalization. False to disable it."); | ||
| 21 | +} | ||
| 22 | + | ||
| 23 | +bool OfflineSenseVoiceModelConfig::Validate() const { | ||
| 24 | + if (!FileExists(model)) { | ||
| 25 | + SHERPA_ONNX_LOGE("SenseVoice model '%s' does not exist", model.c_str()); | ||
| 26 | + return false; | ||
| 27 | + } | ||
| 28 | + | ||
| 29 | + if (!language.empty()) { | ||
| 30 | + if (language != "auto" && language != "zh" && language != "en" && | ||
| 31 | + language != "ja" && language != "ko" && language != "yue") { | ||
| 32 | + SHERPA_ONNX_LOGE( | ||
| 33 | + "Invalid sense-voice-language: '%s'. Valid values are: auto, zh, en, " | ||
| 34 | + "ja, ko, yue. Or you can leave it empty to use 'auto'", | ||
| 35 | + language.c_str()); | ||
| 36 | + | ||
| 37 | + return false; | ||
| 38 | + } | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + return true; | ||
| 42 | +} | ||
| 43 | + | ||
| 44 | +std::string OfflineSenseVoiceModelConfig::ToString() const { | ||
| 45 | + std::ostringstream os; | ||
| 46 | + | ||
| 47 | + os << "OfflineSenseVoiceModelConfig("; | ||
| 48 | + os << "model=\"" << model << "\", "; | ||
| 49 | + os << "language=\"" << language << "\", "; | ||
| 50 | + os << "use_itn=" << (use_itn ? "True" : "False") << ")"; | ||
| 51 | + | ||
| 52 | + return os.str(); | ||
| 53 | +} | ||
| 54 | + | ||
| 55 | +} // namespace sherpa_onnx |
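These options surface on the command line of the offline decoder. Below is a hedged invocation sketch, assuming the repo's existing `sherpa-onnx-offline` executable (the `$EXE` of the CI script at the top of this commit) and the model directory downloaded there.

```bash
# Hedged sketch; the binary name and paths are assumed from the CI
# script at the top of this commit.
./build/bin/sherpa-onnx-offline \
  --tokens=./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt \
  --sense-voice-model=./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx \
  --sense-voice-language=auto \
  --sense-voice-use-itn=1 \
  ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav
```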
| 1 | +// sherpa-onnx/csrc/offline-sense-voice-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OfflineSenseVoiceModelConfig { | ||
| 14 | + std::string model; | ||
| 15 | + | ||
| 16 | + // "" or "auto" to let the model recognize the language | ||
| 17 | + // valid values: | ||
| 18 | + // zh, en, ja, ko, yue, auto | ||
| 19 | + std::string language = "auto"; | ||
| 20 | + | ||
| 21 | + // true to use inverse text normalization | ||
| 22 | + // false to not use inverse text normalization | ||
| 23 | + bool use_itn = false; | ||
| 24 | + | ||
| 25 | + OfflineSenseVoiceModelConfig() = default; | ||
| 26 | + explicit OfflineSenseVoiceModelConfig(const std::string &model, | ||
| 27 | + const std::string &language, | ||
| 28 | + bool use_itn) | ||
| 29 | + : model(model), language(language), use_itn(use_itn) {} | ||
| 30 | + | ||
| 31 | + void Register(ParseOptions *po); | ||
| 32 | + bool Validate() const; | ||
| 33 | + | ||
| 34 | + std::string ToString() const; | ||
| 35 | +}; | ||
| 36 | + | ||
| 37 | +} // namespace sherpa_onnx | ||
| 38 | + | ||
| 39 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <unordered_map> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OfflineSenseVoiceModelMetaData { | ||
| 14 | + // ID for using inverse text normalization | ||
| 15 | + int32_t with_itn_id; | ||
| 16 | + | ||
| 17 | + // ID for not using inverse text normalization | ||
| 18 | + int32_t without_itn_id; | ||
| 19 | + | ||
| 20 | + int32_t window_size; // lfr_m | ||
| 21 | + int32_t window_shift; // lfr_n | ||
| 22 | + int32_t vocab_size; | ||
| 23 | + | ||
| 24 | + int32_t subsampling_factor = 1; | ||
| 25 | + | ||
| 26 | + // Usually 0 for SenseVoice models. | ||
 | 27 | + // 0 means samples are scaled to [-32768, 32767] before they are | ||
 | 28 | + // sent to the feature extractor | ||
| 29 | + int32_t normalize_samples = 0; | ||
| 30 | + | ||
| 31 | + int32_t blank_id = 0; | ||
| 32 | + | ||
| 33 | + // possible values: | ||
| 34 | + // zh, en, ja, ko, yue, auto | ||
| 35 | + // where | ||
| 36 | + // zh is Chinese (Mandarin) | ||
| 37 | + // en is English | ||
| 38 | + // ja is Japanese | ||
| 39 | + // ko is Korean | ||
| 40 | + // yue is Cantonese | ||
| 41 | + // auto is to let the model recognize the language | ||
| 42 | + std::unordered_map<std::string, int32_t> lang2id; | ||
| 43 | + | ||
| 44 | + std::vector<float> neg_mean; | ||
| 45 | + std::vector<float> inv_stddev; | ||
| 46 | +}; | ||
| 47 | + | ||
| 48 | +} // namespace sherpa_onnx | ||
| 49 | + | ||
| 50 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_ |
| 1 | +// sherpa-onnx/csrc/offline-sense-voice-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-sense-voice-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | +#include "sherpa-onnx/csrc/session.h" | ||
| 13 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +class OfflineSenseVoiceModel::Impl { | ||
| 18 | + public: | ||
| 19 | + explicit Impl(const OfflineModelConfig &config) | ||
| 20 | + : config_(config), | ||
| 21 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 22 | + sess_opts_(GetSessionOptions(config)), | ||
| 23 | + allocator_{} { | ||
| 24 | + auto buf = ReadFile(config_.sense_voice.model); | ||
| 25 | + Init(buf.data(), buf.size()); | ||
| 26 | + } | ||
| 27 | + | ||
| 28 | +#if __ANDROID_API__ >= 9 | ||
| 29 | + Impl(AAssetManager *mgr, const OfflineModelConfig &config) | ||
| 30 | + : config_(config), | ||
| 31 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 32 | + sess_opts_(GetSessionOptions(config)), | ||
| 33 | + allocator_{} { | ||
| 34 | + auto buf = ReadFile(mgr, config_.sense_voice.model); | ||
| 35 | + Init(buf.data(), buf.size()); | ||
| 36 | + } | ||
| 37 | +#endif | ||
| 38 | + | ||
| 39 | + Ort::Value Forward(Ort::Value features, Ort::Value features_length, | ||
| 40 | + Ort::Value language, Ort::Value text_norm) { | ||
| 41 | + std::array<Ort::Value, 4> inputs = { | ||
| 42 | + std::move(features), | ||
| 43 | + std::move(features_length), | ||
| 44 | + std::move(language), | ||
| 45 | + std::move(text_norm), | ||
| 46 | + }; | ||
| 47 | + | ||
| 48 | + auto ans = | ||
| 49 | + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 50 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 51 | + return std::move(ans[0]); | ||
| 52 | + } | ||
| 53 | + | ||
| 54 | + const OfflineSenseVoiceModelMetaData &GetModelMetadata() const { | ||
| 55 | + return meta_data_; | ||
| 56 | + } | ||
| 57 | + | ||
| 58 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 59 | + | ||
| 60 | + private: | ||
| 61 | + void Init(void *model_data, size_t model_data_length) { | ||
| 62 | + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 63 | + sess_opts_); | ||
| 64 | + | ||
| 65 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 66 | + | ||
| 67 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 68 | + | ||
| 69 | + // get meta data | ||
| 70 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 71 | + if (config_.debug) { | ||
| 72 | + std::ostringstream os; | ||
| 73 | + PrintModelMetadata(os, meta_data); | ||
| 74 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 75 | + } | ||
| 76 | + | ||
| 77 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 78 | + SHERPA_ONNX_READ_META_DATA(meta_data_.vocab_size, "vocab_size"); | ||
| 79 | + SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "lfr_window_size"); | ||
| 80 | + SHERPA_ONNX_READ_META_DATA(meta_data_.window_shift, "lfr_window_shift"); | ||
| 81 | + SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_samples, | ||
| 82 | + "normalize_samples"); | ||
| 83 | + | ||
| 84 | + SHERPA_ONNX_READ_META_DATA(meta_data_.with_itn_id, "with_itn"); | ||
| 85 | + | ||
| 86 | + SHERPA_ONNX_READ_META_DATA(meta_data_.without_itn_id, "without_itn"); | ||
| 87 | + | ||
| 88 | + int32_t lang_auto = 0; | ||
| 89 | + int32_t lang_zh = 0; | ||
| 90 | + int32_t lang_en = 0; | ||
| 91 | + int32_t lang_ja = 0; | ||
| 92 | + int32_t lang_ko = 0; | ||
| 93 | + int32_t lang_yue = 0; | ||
| 94 | + | ||
| 95 | + SHERPA_ONNX_READ_META_DATA(lang_auto, "lang_auto"); | ||
| 96 | + SHERPA_ONNX_READ_META_DATA(lang_zh, "lang_zh"); | ||
| 97 | + SHERPA_ONNX_READ_META_DATA(lang_en, "lang_en"); | ||
| 98 | + SHERPA_ONNX_READ_META_DATA(lang_ja, "lang_ja"); | ||
| 99 | + SHERPA_ONNX_READ_META_DATA(lang_ko, "lang_ko"); | ||
| 100 | + SHERPA_ONNX_READ_META_DATA(lang_yue, "lang_yue"); | ||
| 101 | + | ||
| 102 | + meta_data_.lang2id = { | ||
 | 103 | + {"auto", lang_auto}, {"zh", lang_zh}, {"en", lang_en}, | ||
 | 104 | + {"ja", lang_ja}, {"ko", lang_ko}, {"yue", lang_yue}, | ||
| 105 | + }; | ||
| 106 | + | ||
| 107 | + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.neg_mean, "neg_mean"); | ||
| 108 | + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev, "inv_stddev"); | ||
| 109 | + } | ||
| 110 | + | ||
| 111 | + private: | ||
| 112 | + OfflineModelConfig config_; | ||
| 113 | + Ort::Env env_; | ||
| 114 | + Ort::SessionOptions sess_opts_; | ||
| 115 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 116 | + | ||
| 117 | + std::unique_ptr<Ort::Session> sess_; | ||
| 118 | + | ||
| 119 | + std::vector<std::string> input_names_; | ||
| 120 | + std::vector<const char *> input_names_ptr_; | ||
| 121 | + | ||
| 122 | + std::vector<std::string> output_names_; | ||
| 123 | + std::vector<const char *> output_names_ptr_; | ||
| 124 | + | ||
| 125 | + OfflineSenseVoiceModelMetaData meta_data_; | ||
| 126 | +}; | ||
| 127 | + | ||
| 128 | +OfflineSenseVoiceModel::OfflineSenseVoiceModel(const OfflineModelConfig &config) | ||
| 129 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 130 | + | ||
| 131 | +#if __ANDROID_API__ >= 9 | ||
| 132 | +OfflineSenseVoiceModel::OfflineSenseVoiceModel(AAssetManager *mgr, | ||
| 133 | + const OfflineModelConfig &config) | ||
| 134 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 135 | +#endif | ||
| 136 | + | ||
| 137 | +OfflineSenseVoiceModel::~OfflineSenseVoiceModel() = default; | ||
| 138 | + | ||
| 139 | +Ort::Value OfflineSenseVoiceModel::Forward(Ort::Value features, | ||
| 140 | + Ort::Value features_length, | ||
| 141 | + Ort::Value language, | ||
| 142 | + Ort::Value text_norm) const { | ||
| 143 | + return impl_->Forward(std::move(features), std::move(features_length), | ||
| 144 | + std::move(language), std::move(text_norm)); | ||
| 145 | +} | ||
| 146 | + | ||
| 147 | +const OfflineSenseVoiceModelMetaData &OfflineSenseVoiceModel::GetModelMetadata() | ||
| 148 | + const { | ||
| 149 | + return impl_->GetModelMetadata(); | ||
| 150 | +} | ||
| 151 | + | ||
| 152 | +OrtAllocator *OfflineSenseVoiceModel::Allocator() const { | ||
| 153 | + return impl_->Allocator(); | ||
| 154 | +} | ||
| 155 | + | ||
| 156 | +} // namespace sherpa_onnx |
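`Init()` above expects the keys `vocab_size`, `lfr_window_size`, `lfr_window_shift`, `normalize_samples`, `with_itn`, `without_itn`, `lang_*`, `neg_mean`, and `inv_stddev` as ONNX model metadata. A hedged sketch of how such keys are attached on the export side, using the standard `onnx` `metadata_props` API; the helper name and the concrete values are illustrative, not taken from the export script.

```python
# Hedged export-side sketch using the standard onnx API. The helper
# name and the metadata values are illustrative placeholders.
import onnx


def add_meta_data(filename: str, meta_data: dict):
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)  # Init() parses these strings back
    onnx.save(model, filename)


add_meta_data(
    "model.onnx",
    {
        "lfr_window_size": 7,   # read back as meta_data_.window_size
        "lfr_window_shift": 6,  # read back as meta_data_.window_shift
        "normalize_samples": 0,
        "with_itn": 14,         # placeholder token IDs
        "without_itn": 15,
    },
)
```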
sherpa-onnx/csrc/offline-sense-voice-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-sense-voice-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#if __ANDROID_API__ >= 9 | ||
| 11 | +#include "android/asset_manager.h" | ||
| 12 | +#include "android/asset_manager_jni.h" | ||
| 13 | +#endif | ||
| 14 | + | ||
| 15 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 16 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 17 | +#include "sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h" | ||
| 18 | + | ||
| 19 | +namespace sherpa_onnx { | ||
| 20 | + | ||
| 21 | +class OfflineSenseVoiceModel { | ||
| 22 | + public: | ||
| 23 | + explicit OfflineSenseVoiceModel(const OfflineModelConfig &config); | ||
| 24 | + | ||
| 25 | +#if __ANDROID_API__ >= 9 | ||
| 26 | + OfflineSenseVoiceModel(AAssetManager *mgr, const OfflineModelConfig &config); | ||
| 27 | +#endif | ||
| 28 | + | ||
| 29 | + ~OfflineSenseVoiceModel(); | ||
| 30 | + | ||
| 31 | + /** Run the forward method of the model. | ||
| 32 | + * | ||
| 33 | + * @param features A tensor of shape (N, T, C). It is changed in-place. | ||
| 34 | + * @param features_length A 1-D tensor of shape (N,) containing number of | ||
| 35 | + * valid frames in `features` before padding. | ||
| 36 | + * Its dtype is int32_t. | ||
| 37 | + * @param language A 1-D tensor of shape (N,) with dtype int32_t | ||
| 38 | + * @param text_norm A 1-D tensor of shape (N,) with dtype int32_t | ||
| 39 | + * | ||
| 40 | + * @return Return logits of shape (N, T, C) with dtype float | ||
| 41 | + * | ||
| 42 | + * Note: The subsampling factor is 1 for SenseVoice, so there is | ||
| 43 | + * no need to output logits_length. | ||
| 44 | + */ | ||
| 45 | + Ort::Value Forward(Ort::Value features, Ort::Value features_length, | ||
| 46 | + Ort::Value language, Ort::Value text_norm) const; | ||
| 47 | + | ||
| 48 | + const OfflineSenseVoiceModelMetaData &GetModelMetadata() const; | ||
| 49 | + | ||
| 50 | + /** Return an allocator for allocating memory | ||
| 51 | + */ | ||
| 52 | + OrtAllocator *Allocator() const; | ||
| 53 | + | ||
| 54 | + private: | ||
| 55 | + class Impl; | ||
| 56 | + std::unique_ptr<Impl> impl_; | ||
| 57 | +}; | ||
| 58 | + | ||
| 59 | +} // namespace sherpa_onnx | ||
| 60 | + | ||
| 61 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_ |
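
The Forward API above takes four tensors. Below is a hedged sketch of driving it for a single utterance (N = 1) with onnxruntime's C++ API; the language id 0 and text-norm id 0 are placeholders — in real use they are resolved through the model metadata:

    #include <array>
    #include <utility>
    #include <vector>

    #include "onnxruntime_cxx_api.h"
    #include "sherpa-onnx/csrc/offline-sense-voice-model.h"

    // Sketch only: `features` is a flattened (1, num_frames, feat_dim)
    // fbank matrix; it must stay alive until Forward() returns.
    Ort::Value RunForward(sherpa_onnx::OfflineSenseVoiceModel &model,
                          std::vector<float> &features, int32_t num_frames,
                          int32_t feat_dim) {
      auto memory_info =
          Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

      std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
      Ort::Value x = Ort::Value::CreateTensor(
          memory_info, features.data(), features.size(), x_shape.data(),
          x_shape.size());

      std::array<int64_t, 1> scalar_shape{1};
      int32_t x_len = num_frames;
      int32_t language = 0;   // placeholder id; look up via GetModelMetadata()
      int32_t text_norm = 0;  // placeholder id; selects itn vs. no-itn output

      Ort::Value x_len_t = Ort::Value::CreateTensor(
          memory_info, &x_len, 1, scalar_shape.data(), scalar_shape.size());
      Ort::Value lang_t = Ort::Value::CreateTensor(
          memory_info, &language, 1, scalar_shape.data(), scalar_shape.size());
      Ort::Value norm_t = Ort::Value::CreateTensor(
          memory_info, &text_norm, 1, scalar_shape.data(), scalar_shape.size());

      // Returns logits of shape (1, T, vocab_size), ready for CTC decoding.
      return model.Forward(std::move(x), std::move(x_len_t),
                           std::move(lang_t), std::move(norm_t));
    }
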
| @@ -6,6 +6,8 @@ | @@ -6,6 +6,8 @@ | ||
| 6 | 6 | ||
| 7 | #include <algorithm> | 7 | #include <algorithm> |
| 8 | #include <fstream> | 8 | #include <fstream> |
| 9 | +#include <functional> | ||
| 10 | +#include <numeric> | ||
| 9 | #include <sstream> | 11 | #include <sstream> |
| 10 | #include <string> | 12 | #include <string> |
| 11 | 13 | ||
| @@ -153,23 +155,60 @@ Ort::Value View(Ort::Value *v) { | @@ -153,23 +155,60 @@ Ort::Value View(Ort::Value *v) { | ||
| 153 | } | 155 | } |
| 154 | } | 156 | } |
| 155 | 157 | ||
| 158 | +float ComputeSum(const Ort::Value *v, int32_t n /*= -1*/) { | ||
| 159 | + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 160 | + auto size = static_cast<int32_t>(std::accumulate( | ||
| 161 | + shape.begin(), shape.end(), 1, std::multiplies<int64_t>())); | ||
| 162 | + if (n != -1 && n < size && n > 0) { | ||
| 163 | + size = n; | ||
| 164 | + } | ||
| 165 | + | ||
| 166 | + const float *p = v->GetTensorData<float>(); | ||
| 167 | + | ||
| 168 | + return std::accumulate(p, p + size, 0.0f); | ||
| 169 | +} | ||
| 170 | + | ||
| 171 | +float ComputeMean(const Ort::Value *v, int32_t n /*= -1*/) { | ||
| 172 | + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 173 | + auto size = static_cast<int32_t>(std::accumulate( | ||
| 174 | + shape.begin(), shape.end(), 1, std::multiplies<int64_t>())); | ||
| 175 | + | ||
| 176 | + if (n != -1 && n < size && n > 0) { | ||
| 177 | + size = n; | ||
| 178 | + } | ||
| 179 | + | ||
| 180 | + auto sum = ComputeSum(v, n); | ||
| 181 | + return sum / size; | ||
| 182 | +} | ||
| 183 | + | ||
| 184 | +void PrintShape(const Ort::Value *v) { | ||
| 185 | + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 186 | + std::ostringstream os; | ||
| 187 | + for (auto i : shape) { | ||
| 188 | + os << i << ", "; | ||
| 189 | + } | ||
| 190 | + os << "\n"; | ||
| 191 | + fprintf(stderr, "%s", os.str().c_str()); | ||
| 192 | +} | ||
| 193 | + | ||
| 156 | template <typename T /*= float*/> | 194 | template <typename T /*= float*/> |
| 157 | -void Print1D(Ort::Value *v) { | 195 | +void Print1D(const Ort::Value *v) { |
| 158 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | 196 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); |
| 159 | const T *d = v->GetTensorData<T>(); | 197 | const T *d = v->GetTensorData<T>(); |
| 160 | std::ostringstream os; | 198 | std::ostringstream os; |
| 161 | for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { | 199 | for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { |
| 162 | - os << *d << " "; | 200 | + os << d[i] << " "; |
| 163 | } | 201 | } |
| 164 | os << "\n"; | 202 | os << "\n"; |
| 165 | fprintf(stderr, "%s\n", os.str().c_str()); | 203 | fprintf(stderr, "%s\n", os.str().c_str()); |
| 166 | } | 204 | } |
| 167 | 205 | ||
| 168 | -template void Print1D<int64_t>(Ort::Value *v); | ||
| 169 | -template void Print1D<float>(Ort::Value *v); | 206 | +template void Print1D<int64_t>(const Ort::Value *v); |
| 207 | +template void Print1D<int32_t>(const Ort::Value *v); | ||
| 208 | +template void Print1D<float>(const Ort::Value *v); | ||
| 170 | 209 | ||
| 171 | template <typename T /*= float*/> | 210 | template <typename T /*= float*/> |
| 172 | -void Print2D(Ort::Value *v) { | 211 | +void Print2D(const Ort::Value *v) { |
| 173 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | 212 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); |
| 174 | const T *d = v->GetTensorData<T>(); | 213 | const T *d = v->GetTensorData<T>(); |
| 175 | 214 | ||
| @@ -183,10 +222,10 @@ void Print2D(Ort::Value *v) { | @@ -183,10 +222,10 @@ void Print2D(Ort::Value *v) { | ||
| 183 | fprintf(stderr, "%s\n", os.str().c_str()); | 222 | fprintf(stderr, "%s\n", os.str().c_str()); |
| 184 | } | 223 | } |
| 185 | 224 | ||
| 186 | -template void Print2D<int64_t>(Ort::Value *v); | ||
| 187 | -template void Print2D<float>(Ort::Value *v); | 225 | +template void Print2D<int64_t>(const Ort::Value *v); |
| 226 | +template void Print2D<float>(const Ort::Value *v); | ||
| 188 | 227 | ||
| 189 | -void Print3D(Ort::Value *v) { | 228 | +void Print3D(const Ort::Value *v) { |
| 190 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | 229 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); |
| 191 | const float *d = v->GetTensorData<float>(); | 230 | const float *d = v->GetTensorData<float>(); |
| 192 | 231 | ||
| @@ -202,7 +241,7 @@ void Print3D(Ort::Value *v) { | @@ -202,7 +241,7 @@ void Print3D(Ort::Value *v) { | ||
| 202 | fprintf(stderr, "\n"); | 241 | fprintf(stderr, "\n"); |
| 203 | } | 242 | } |
| 204 | 243 | ||
| 205 | -void Print4D(Ort::Value *v) { | 244 | +void Print4D(const Ort::Value *v) { |
| 206 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | 245 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); |
| 207 | const float *d = v->GetTensorData<float>(); | 246 | const float *d = v->GetTensorData<float>(); |
| 208 | 247 |
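
ComputeSum and ComputeMean above are debug aids: they reduce the first n elements of a float tensor, or the whole tensor when n is -1. A short usage sketch, assuming t points to a valid 2-D float tensor:

    #include <cstdio>

    #include "sherpa-onnx/csrc/onnx-utils.h"

    // Sketch: dump shape, contents, and summary statistics of a tensor.
    void Inspect(const Ort::Value *t) {
      sherpa_onnx::PrintShape(t);      // prints e.g. "93, 80," (values illustrative)
      sherpa_onnx::Print2D<float>(t);  // element-by-element dump
      fprintf(stderr, "sum: %f, mean of first 10: %f\n",
              sherpa_onnx::ComputeSum(t), sherpa_onnx::ComputeMean(t, 10));
    }
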
| @@ -68,19 +68,24 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); | @@ -68,19 +68,24 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); | ||
| 68 | // Return a shallow copy | 68 | // Return a shallow copy |
| 69 | Ort::Value View(Ort::Value *v); | 69 | Ort::Value View(Ort::Value *v); |
| 70 | 70 | ||
| 71 | +float ComputeSum(const Ort::Value *v, int32_t n = -1); | ||
| 72 | +float ComputeMean(const Ort::Value *v, int32_t n = -1); | ||
| 73 | + | ||
| 71 | // Print a 1-D tensor to stderr | 74 | // Print a 1-D tensor to stderr |
| 72 | template <typename T = float> | 75 | template <typename T = float> |
| 73 | -void Print1D(Ort::Value *v); | 76 | +void Print1D(const Ort::Value *v); |
| 74 | 77 | ||
| 75 | // Print a 2-D tensor to stderr | 78 | // Print a 2-D tensor to stderr |
| 76 | template <typename T = float> | 79 | template <typename T = float> |
| 77 | -void Print2D(Ort::Value *v); | 80 | +void Print2D(const Ort::Value *v); |
| 78 | 81 | ||
| 79 | // Print a 3-D tensor to stderr | 82 | // Print a 3-D tensor to stderr |
| 80 | -void Print3D(Ort::Value *v); | 83 | +void Print3D(const Ort::Value *v); |
| 81 | 84 | ||
| 82 | // Print a 4-D tensor to stderr | 85 | // Print a 4-D tensor to stderr |
| 83 | -void Print4D(Ort::Value *v); | 86 | +void Print4D(const Ort::Value *v); |
| 87 | + | ||
| 88 | +void PrintShape(const Ort::Value *v); | ||
| 84 | 89 | ||
| 85 | template <typename T = float> | 90 | template <typename T = float> |
| 86 | void Fill(Ort::Value *tensor, T value) { | 91 | void Fill(Ort::Value *tensor, T value) { |
| @@ -15,6 +15,7 @@ set(srcs | @@ -15,6 +15,7 @@ set(srcs | ||
| 15 | offline-paraformer-model-config.cc | 15 | offline-paraformer-model-config.cc |
| 16 | offline-punctuation.cc | 16 | offline-punctuation.cc |
| 17 | offline-recognizer.cc | 17 | offline-recognizer.cc |
| 18 | + offline-sense-voice-model-config.cc | ||
| 18 | offline-stream.cc | 19 | offline-stream.cc |
| 19 | offline-tdnn-model-config.cc | 20 | offline-tdnn-model-config.cc |
| 20 | offline-transducer-model-config.cc | 21 | offline-transducer-model-config.cc |
| @@ -10,6 +10,7 @@ | @@ -10,6 +10,7 @@ | ||
| 10 | #include "sherpa-onnx/csrc/offline-model-config.h" | 10 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| 11 | #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" | 11 | #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" |
| 12 | #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" | 12 | #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" |
| 13 | +#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h" | ||
| 13 | #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" | 14 | #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" |
| 14 | #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" | 15 | #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" |
| 15 | #include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h" | 16 | #include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h" |
| @@ -26,6 +27,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -26,6 +27,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 26 | PybindOfflineTdnnModelConfig(m); | 27 | PybindOfflineTdnnModelConfig(m); |
| 27 | PybindOfflineZipformerCtcModelConfig(m); | 28 | PybindOfflineZipformerCtcModelConfig(m); |
| 28 | PybindOfflineWenetCtcModelConfig(m); | 29 | PybindOfflineWenetCtcModelConfig(m); |
| 30 | + PybindOfflineSenseVoiceModelConfig(m); | ||
| 29 | 31 | ||
| 30 | using PyClass = OfflineModelConfig; | 32 | using PyClass = OfflineModelConfig; |
| 31 | py::class_<PyClass>(*m, "OfflineModelConfig") | 33 | py::class_<PyClass>(*m, "OfflineModelConfig") |
| @@ -36,7 +38,8 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -36,7 +38,8 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 36 | const OfflineNemoEncDecCtcModelConfig &, | 38 | const OfflineNemoEncDecCtcModelConfig &, |
| 37 | const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &, | 39 | const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &, |
| 38 | const OfflineZipformerCtcModelConfig &, | 40 | const OfflineZipformerCtcModelConfig &, |
| 39 | - const OfflineWenetCtcModelConfig &, const std::string &, | 41 | + const OfflineWenetCtcModelConfig &, |
| 42 | + const OfflineSenseVoiceModelConfig &, const std::string &, | ||
| 40 | const std::string &, int32_t, bool, const std::string &, | 43 | const std::string &, int32_t, bool, const std::string &, |
| 41 | const std::string &, const std::string &, const std::string &>(), | 44 | const std::string &, const std::string &, const std::string &>(), |
| 42 | py::arg("transducer") = OfflineTransducerModelConfig(), | 45 | py::arg("transducer") = OfflineTransducerModelConfig(), |
| @@ -46,6 +49,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -46,6 +49,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 46 | py::arg("tdnn") = OfflineTdnnModelConfig(), | 49 | py::arg("tdnn") = OfflineTdnnModelConfig(), |
| 47 | py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), | 50 | py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), |
| 48 | py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), | 51 | py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), |
| 52 | + py::arg("sense_voice") = OfflineSenseVoiceModelConfig(), | ||
| 49 | py::arg("telespeech_ctc") = "", py::arg("tokens"), | 53 | py::arg("telespeech_ctc") = "", py::arg("tokens"), |
| 50 | py::arg("num_threads"), py::arg("debug") = false, | 54 | py::arg("num_threads"), py::arg("debug") = false, |
| 51 | py::arg("provider") = "cpu", py::arg("model_type") = "", | 55 | py::arg("provider") = "cpu", py::arg("model_type") = "", |
| @@ -57,6 +61,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -57,6 +61,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 57 | .def_readwrite("tdnn", &PyClass::tdnn) | 61 | .def_readwrite("tdnn", &PyClass::tdnn) |
| 58 | .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) | 62 | .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) |
| 59 | .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) | 63 | .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) |
| 64 | + .def_readwrite("sense_voice", &PyClass::sense_voice) | ||
| 60 | .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) | 65 | .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) |
| 61 | .def_readwrite("tokens", &PyClass::tokens) | 66 | .def_readwrite("tokens", &PyClass::tokens) |
| 62 | .def_readwrite("num_threads", &PyClass::num_threads) | 67 | .def_readwrite("num_threads", &PyClass::num_threads) |
| @@ -14,6 +14,7 @@ namespace sherpa_onnx { | @@ -14,6 +14,7 @@ namespace sherpa_onnx { | ||
| 14 | void PybindOfflineParaformerModelConfig(py::module *m) { | 14 | void PybindOfflineParaformerModelConfig(py::module *m) { |
| 15 | using PyClass = OfflineParaformerModelConfig; | 15 | using PyClass = OfflineParaformerModelConfig; |
| 16 | py::class_<PyClass>(*m, "OfflineParaformerModelConfig") | 16 | py::class_<PyClass>(*m, "OfflineParaformerModelConfig") |
| 17 | + .def(py::init<>()) | ||
| 17 | .def(py::init<const std::string &>(), py::arg("model")) | 18 | .def(py::init<const std::string &>(), py::arg("model")) |
| 18 | .def_readwrite("model", &PyClass::model) | 19 | .def_readwrite("model", &PyClass::model) |
| 19 | .def("__str__", &PyClass::ToString); | 20 | .def("__str__", &PyClass::ToString); |
| 1 | +// sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void PybindOfflineSenseVoiceModelConfig(py::module *m) { | ||
| 15 | + using PyClass = OfflineSenseVoiceModelConfig; | ||
| 16 | + py::class_<PyClass>(*m, "OfflineSenseVoiceModelConfig") | ||
| 17 | + .def(py::init<>()) | ||
| 18 | + .def(py::init<const std::string &, const std::string &, bool>(), | ||
| 19 | + py::arg("model"), py::arg("language"), py::arg("use_itn")) | ||
| 20 | + .def_readwrite("model", &PyClass::model) | ||
| 21 | + .def_readwrite("language", &PyClass::language) | ||
| 22 | + .def_readwrite("use_itn", &PyClass::use_itn) | ||
| 23 | + .def("__str__", &PyClass::ToString); | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/offline-sense-voice-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOfflineSenseVoiceModelConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} // namespace sherpa_onnx | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ |
| @@ -10,6 +10,7 @@ from _sherpa_onnx import ( | @@ -10,6 +10,7 @@ from _sherpa_onnx import ( | ||
| 10 | OfflineModelConfig, | 10 | OfflineModelConfig, |
| 11 | OfflineNemoEncDecCtcModelConfig, | 11 | OfflineNemoEncDecCtcModelConfig, |
| 12 | OfflineParaformerModelConfig, | 12 | OfflineParaformerModelConfig, |
| 13 | + OfflineSenseVoiceModelConfig, | ||
| 13 | ) | 14 | ) |
| 14 | from _sherpa_onnx import OfflineRecognizer as _Recognizer | 15 | from _sherpa_onnx import OfflineRecognizer as _Recognizer |
| 15 | from _sherpa_onnx import ( | 16 | from _sherpa_onnx import ( |
| @@ -174,6 +175,88 @@ class OfflineRecognizer(object): | @@ -174,6 +175,88 @@ class OfflineRecognizer(object): | ||
| 174 | return self | 175 | return self |
| 175 | 176 | ||
| 176 | @classmethod | 177 | @classmethod |
| 178 | + def from_sense_voice( | ||
| 179 | + cls, | ||
| 180 | + model: str, | ||
| 181 | + tokens: str, | ||
| 182 | + num_threads: int = 1, | ||
| 183 | + sample_rate: int = 16000, | ||
| 184 | + feature_dim: int = 80, | ||
| 185 | + decoding_method: str = "greedy_search", | ||
| 186 | + debug: bool = False, | ||
| 187 | + provider: str = "cpu", | ||
| 188 | + language: str = "", | ||
| 189 | + use_itn: bool = False, | ||
| 190 | + rule_fsts: str = "", | ||
| 191 | + rule_fars: str = "", | ||
| 192 | + ): | ||
| 193 | + """ | ||
| 194 | + Please refer to | ||
| 195 | + `<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_ | ||
| 196 | + to download pre-trained models. | ||
| 197 | + | ||
| 198 | + Args: | ||
| 199 | + tokens: | ||
| 200 | + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two | ||
| 201 | + columns:: | ||
| 202 | + | ||
| 203 | + symbol integer_id | ||
| 204 | + | ||
| 205 | + model: | ||
| 206 | + Path to ``model.onnx``. | ||
| 207 | + num_threads: | ||
| 208 | + Number of threads for neural network computation. | ||
| 209 | + sample_rate: | ||
| 210 | + Sample rate of the training data used to train the model. | ||
| 211 | + feature_dim: | ||
| 212 | + Dimension of the feature used to train the model. | ||
| 213 | + decoding_method: | ||
| 214 | + Valid values are greedy_search. | ||
| 215 | + debug: | ||
| 216 | + True to show debug messages. | ||
| 217 | + provider: | ||
| 218 | + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | ||
| 219 | + language: | ||
| 220 | + If not empty, then valid values are: auto, zh, en, ja, ko, yue | ||
| 221 | + use_itn: | ||
| 222 | + True to enable inverse text normalization; False to disable it. | ||
| 223 | + rule_fsts: | ||
| 224 | + If not empty, it specifies fsts for inverse text normalization. | ||
| 225 | + If there are multiple fsts, they are separated by a comma. | ||
| 226 | + rule_fars: | ||
| 227 | + If not empty, it specifies fst archives for inverse text normalization. | ||
| 228 | + If there are multiple archives, they are separated by a comma. | ||
| 229 | + """ | ||
| 230 | + self = cls.__new__(cls) | ||
| 231 | + model_config = OfflineModelConfig( | ||
| 232 | + sense_voice=OfflineSenseVoiceModelConfig( | ||
| 233 | + model=model, | ||
| 234 | + language=language, | ||
| 235 | + use_itn=use_itn, | ||
| 236 | + ), | ||
| 237 | + tokens=tokens, | ||
| 238 | + num_threads=num_threads, | ||
| 239 | + debug=debug, | ||
| 240 | + provider=provider, | ||
| 241 | + ) | ||
| 242 | + | ||
| 243 | + feat_config = FeatureExtractorConfig( | ||
| 244 | + sampling_rate=sample_rate, | ||
| 245 | + feature_dim=feature_dim, | ||
| 246 | + ) | ||
| 247 | + | ||
| 248 | + recognizer_config = OfflineRecognizerConfig( | ||
| 249 | + feat_config=feat_config, | ||
| 250 | + model_config=model_config, | ||
| 251 | + decoding_method=decoding_method, | ||
| 252 | + rule_fsts=rule_fsts, | ||
| 253 | + rule_fars=rule_fars, | ||
| 254 | + ) | ||
| 255 | + self.recognizer = _Recognizer(recognizer_config) | ||
| 256 | + self.config = recognizer_config | ||
| 257 | + return self | ||
| 258 | + | ||
| 259 | + @classmethod | ||
| 177 | def from_paraformer( | 260 | def from_paraformer( |
| 178 | cls, | 261 | cls, |
| 179 | paraformer: str, | 262 | paraformer: str, |
| @@ -355,6 +355,18 @@ func sherpaOnnxOfflineTdnnModelConfig( | @@ -355,6 +355,18 @@ func sherpaOnnxOfflineTdnnModelConfig( | ||
| 355 | ) | 355 | ) |
| 356 | } | 356 | } |
| 357 | 357 | ||
| 358 | +func sherpaOnnxOfflineSenseVoiceModelConfig( | ||
| 359 | + model: String = "", | ||
| 360 | + language: String = "", | ||
| 361 | + useInverseTextNormalization: Bool = false | ||
| 362 | +) -> SherpaOnnxOfflineSenseVoiceModelConfig { | ||
| 363 | + return SherpaOnnxOfflineSenseVoiceModelConfig( | ||
| 364 | + model: toCPointer(model), | ||
| 365 | + language: toCPointer(language), | ||
| 366 | + use_itn: useInverseTextNormalization ? 1 : 0 | ||
| 367 | + ) | ||
| 368 | +} | ||
| 369 | + | ||
| 358 | func sherpaOnnxOfflineLMConfig( | 370 | func sherpaOnnxOfflineLMConfig( |
| 359 | model: String = "", | 371 | model: String = "", |
| 360 | scale: Float = 1.0 | 372 | scale: Float = 1.0 |
| @@ -378,7 +390,8 @@ func sherpaOnnxOfflineModelConfig( | @@ -378,7 +390,8 @@ func sherpaOnnxOfflineModelConfig( | ||
| 378 | modelType: String = "", | 390 | modelType: String = "", |
| 379 | modelingUnit: String = "cjkchar", | 391 | modelingUnit: String = "cjkchar", |
| 380 | bpeVocab: String = "", | 392 | bpeVocab: String = "", |
| 381 | - teleSpeechCtc: String = "" | 393 | + teleSpeechCtc: String = "", |
| 394 | + senseVoice: SherpaOnnxOfflineSenseVoiceModelConfig = sherpaOnnxOfflineSenseVoiceModelConfig() | ||
| 382 | ) -> SherpaOnnxOfflineModelConfig { | 395 | ) -> SherpaOnnxOfflineModelConfig { |
| 383 | return SherpaOnnxOfflineModelConfig( | 396 | return SherpaOnnxOfflineModelConfig( |
| 384 | transducer: transducer, | 397 | transducer: transducer, |
| @@ -393,7 +406,8 @@ func sherpaOnnxOfflineModelConfig( | @@ -393,7 +406,8 @@ func sherpaOnnxOfflineModelConfig( | ||
| 393 | model_type: toCPointer(modelType), | 406 | model_type: toCPointer(modelType), |
| 394 | modeling_unit: toCPointer(modelingUnit), | 407 | modeling_unit: toCPointer(modelingUnit), |
| 395 | bpe_vocab: toCPointer(bpeVocab), | 408 | bpe_vocab: toCPointer(bpeVocab), |
| 396 | - telespeech_ctc: toCPointer(teleSpeechCtc) | 409 | + telespeech_ctc: toCPointer(teleSpeechCtc), |
| 410 | + sense_voice: senseVoice | ||
| 397 | ) | 411 | ) |
| 398 | } | 412 | } |
| 399 | 413 |
| @@ -17,6 +17,7 @@ func run() { | @@ -17,6 +17,7 @@ func run() { | ||
| 17 | var modelConfig: SherpaOnnxOfflineModelConfig | 17 | var modelConfig: SherpaOnnxOfflineModelConfig |
| 18 | var modelType = "whisper" | 18 | var modelType = "whisper" |
| 19 | // modelType = "paraformer" | 19 | // modelType = "paraformer" |
| 20 | + // modelType = "sense_voice" | ||
| 20 | 21 | ||
| 21 | if modelType == "whisper" { | 22 | if modelType == "whisper" { |
| 22 | let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx" | 23 | let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx" |
| @@ -47,6 +48,19 @@ func run() { | @@ -47,6 +48,19 @@ func run() { | ||
| 47 | debug: 0, | 48 | debug: 0, |
| 48 | modelType: "paraformer" | 49 | modelType: "paraformer" |
| 49 | ) | 50 | ) |
| 51 | + } else if modelType == "sense_voice" { | ||
| 52 | + let model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx" | ||
| 53 | + let tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt" | ||
| 54 | + let senseVoiceConfig = sherpaOnnxOfflineSenseVoiceModelConfig( | ||
| 55 | + model: model, | ||
| 56 | + useInverseTextNormalization: true | ||
| 57 | + ) | ||
| 58 | + | ||
| 59 | + modelConfig = sherpaOnnxOfflineModelConfig( | ||
| 60 | + tokens: tokens, | ||
| 61 | + debug: 0, | ||
| 62 | + senseVoice: senseVoiceConfig | ||
| 63 | + ) | ||
| 50 | } else { | 64 | } else { |
| 51 | print("Please specify a supported modelType \(modelType)") | 65 | print("Please specify a supported modelType \(modelType)") |
| 52 | return | 66 | return |
| @@ -63,7 +77,10 @@ func run() { | @@ -63,7 +77,10 @@ func run() { | ||
| 63 | 77 | ||
| 64 | recognizer = SherpaOnnxOfflineRecognizer(config: &config) | 78 | recognizer = SherpaOnnxOfflineRecognizer(config: &config) |
| 65 | 79 | ||
| 66 | - let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav" | 80 | + var filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav" |
| 81 | + if modelType == "sense_voice" { | ||
| 82 | + filePath = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav" | ||
| 83 | + } | ||
| 67 | let fileURL: NSURL = NSURL(fileURLWithPath: filePath) | 84 | let fileURL: NSURL = NSURL(fileURLWithPath: filePath) |
| 68 | let audioFile = try! AVAudioFile(forReading: fileURL as URL) | 85 | let audioFile = try! AVAudioFile(forReading: fileURL as URL) |
| 69 | 86 |