Fangjun Kuang
Committed by GitHub

Add C++ runtime for non-streaming faster conformer transducer from NeMo. (#854)

@@ -13,6 +13,105 @@ echo "PATH: $PATH"

 which $EXE

+log "------------------------------------------------------------------------"
+log "Run Nemo fast conformer hybrid transducer ctc models (transducer branch)"
+log "------------------------------------------------------------------------"
+
+url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k.tar.bz2
+name=$(basename $url)
+curl -SL -O $url
+tar xvf $name
+rm $name
+repo=$(basename -s .tar.bz2 $name)
+ls -lh $repo
+
+log "test $repo"
+test_wavs=(
+de-german.wav
+es-spanish.wav
+hr-croatian.wav
+po-polish.wav
+uk-ukrainian.wav
+en-english.wav
+fr-french.wav
+it-italian.wav
+ru-russian.wav
+)
+for w in ${test_wavs[@]}; do
+  time $EXE \
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder.onnx \
+    --decoder=$repo/decoder.onnx \
+    --joiner=$repo/joiner.onnx \
+    --debug=1 \
+    $repo/test_wavs/$w
+done
+
+rm -rf $repo
+
+url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-24500.tar.bz2
+name=$(basename $url)
+curl -SL -O $url
+tar xvf $name
+rm $name
+repo=$(basename -s .tar.bz2 $name)
+ls -lh $repo
+
+log "Test $repo"
+
+time $EXE \
+  --tokens=$repo/tokens.txt \
+  --encoder=$repo/encoder.onnx \
+  --decoder=$repo/decoder.onnx \
+  --joiner=$repo/joiner.onnx \
+  --debug=1 \
+  $repo/test_wavs/en-english.wav
+
+rm -rf $repo
+
+url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-es-1424.tar.bz2
+name=$(basename $url)
+curl -SL -O $url
+tar xvf $name
+rm $name
+repo=$(basename -s .tar.bz2 $name)
+ls -lh $repo
+
+log "test $repo"
+
+time $EXE \
+  --tokens=$repo/tokens.txt \
+  --encoder=$repo/encoder.onnx \
+  --decoder=$repo/decoder.onnx \
+  --joiner=$repo/joiner.onnx \
+  --debug=1 \
+  $repo/test_wavs/es-spanish.wav
+
+rm -rf $repo
+
+url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-de-es-fr-14288.tar.bz2
+name=$(basename $url)
+curl -SL -O $url
+tar xvf $name
+rm $name
+repo=$(basename -s .tar.bz2 $name)
+ls -lh $repo
+
+log "Test $repo"
+
+time $EXE \
+  --tokens=$repo/tokens.txt \
+  --encoder=$repo/encoder.onnx \
+  --decoder=$repo/decoder.onnx \
+  --joiner=$repo/joiner.onnx \
+  --debug=1 \
+  $repo/test_wavs/en-english.wav \
+  $repo/test_wavs/de-german.wav \
+  $repo/test_wavs/fr-french.wav \
+  $repo/test_wavs/es-spanish.wav
+
+rm -rf $repo
+
 log "------------------------------------------------------------"
 log "Run Conformer transducer (English)"
 log "------------------------------------------------------------"
@@ -128,6 +128,14 @@ jobs:
           name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
           path: install/*

+      - name: Test offline transducer
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline
+
+          .github/scripts/test-offline-transducer.sh
+
       - name: Test spoken language identification (C++ API)
         shell: bash
         run: |
@@ -215,14 +223,6 @@ jobs:

           .github/scripts/test-online-paraformer.sh

-      - name: Test offline transducer
-        shell: bash
-        run: |
-          export PATH=$PWD/build/bin:$PATH
-          export EXE=sherpa-onnx-offline
-
-          .github/scripts/test-offline-transducer.sh
-
       - name: Test online transducer
         shell: bash
         run: |
@@ -107,6 +107,14 @@ jobs:
           otool -L build/bin/sherpa-onnx
           otool -l build/bin/sherpa-onnx

+      - name: Test offline transducer
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline
+
+          .github/scripts/test-offline-transducer.sh
+
       - name: Test online CTC
         shell: bash
         run: |
@@ -192,14 +200,6 @@ jobs:

           .github/scripts/test-offline-ctc.sh

-      - name: Test offline transducer
-        shell: bash
-        run: |
-          export PATH=$PWD/build/bin:$PATH
-          export EXE=sherpa-onnx-offline
-
-          .github/scripts/test-offline-transducer.sh
-
       - name: Test online transducer
         shell: bash
         run: |
@@ -104,3 +104,4 @@ sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
 sherpa-onnx-ced-*
 node_modules
 package-lock.json
+sherpa-onnx-nemo-*
@@ -0,0 +1,68 @@
+#!/usr/bin/env python3
+
+"""
+This file shows how to use a non-streaming CTC model from NeMo
+to decode files.
+
+Please download model files from
+https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+
+
+The example model supports 10 languages and it is converted from
+https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
+"""
+
+from pathlib import Path
+
+import sherpa_onnx
+import soundfile as sf
+
+
+def create_recognizer():
+    model = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/model.onnx"
+    tokens = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt"
+
+    test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav"
+
+    if not Path(model).is_file() or not Path(test_wav).is_file():
+        raise ValueError(
+            """Please download model files from
+            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+            """
+        )
+    return (
+        sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
+            model=model,
+            tokens=tokens,
+            debug=True,
+        ),
+        test_wav,
+    )
+
+
+def main():
+    recognizer, wave_filename = create_recognizer()
+
+    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
+    audio = audio[:, 0]  # only use the first channel
+
+    # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
+    # sample_rate does not need to be 16000 Hz
+
+    stream = recognizer.create_stream()
+    stream.accept_waveform(sample_rate, audio)
+    recognizer.decode_stream(stream)
+    print(wave_filename)
+    print(stream.result)
+
+
+if __name__ == "__main__":
+    main()
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+
+"""
+This file shows how to use a non-streaming transducer model from NeMo
+to decode files.
+
+Please download model files from
+https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+
+
+The example model supports 10 languages and it is converted from
+https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
+"""
+
+from pathlib import Path
+
+import sherpa_onnx
+import soundfile as sf
+
+
+def create_recognizer():
+    encoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/encoder.onnx"
+    decoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/decoder.onnx"
+    joiner = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/joiner.onnx"
+    tokens = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt"
+
+    test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav"
+    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav"
+
+    if not Path(encoder).is_file() or not Path(test_wav).is_file():
+        raise ValueError(
+            """Please download model files from
+            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+            """
+        )
+    return (
+        sherpa_onnx.OfflineRecognizer.from_transducer(
+            encoder=encoder,
+            decoder=decoder,
+            joiner=joiner,
+            tokens=tokens,
+            model_type="nemo_transducer",
+            debug=True,
+        ),
+        test_wav,
+    )
+
+
+def main():
+    recognizer, wave_filename = create_recognizer()
+
+    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
+    audio = audio[:, 0]  # only use the first channel
+
+    # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
+    # sample_rate does not need to be 16000 Hz
+
+    stream = recognizer.create_stream()
+    stream.accept_waveform(sample_rate, audio)
+    recognizer.decode_stream(stream)
+    print(wave_filename)
+    print(stream.result)
+
+
+if __name__ == "__main__":
+    main()
@@ -40,9 +40,11 @@ set(sources
   offline-tdnn-ctc-model.cc
   offline-tdnn-model-config.cc
   offline-transducer-greedy-search-decoder.cc
+  offline-transducer-greedy-search-nemo-decoder.cc
   offline-transducer-model-config.cc
   offline-transducer-model.cc
   offline-transducer-modified-beam-search-decoder.cc
+  offline-transducer-nemo-model.cc
   offline-wenet-ctc-model-config.cc
   offline-wenet-ctc-model.cc
   offline-whisper-greedy-search-decoder.cc
@@ -56,6 +56,19 @@ struct FeatureExtractorConfig {
   bool remove_dc_offset = true;       // Subtract mean of wave before FFT.
   std::string window_type = "povey";  // e.g. Hamming window

+  // For models from NeMo
+  // This option is not exposed and is set internally when loading models.
+  // Possible values:
+  //  - per_feature
+  //  - all_features (not implemented yet)
+  //  - fixed_mean (not implemented)
+  //  - fixed_std (not implemented)
+  //  - or just leave it to empty
+  // See
+  // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
+  // for details
+  std::string nemo_normalize_type;
+
   std::string ToString() const;

   void Register(ParseOptions *po);
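
For orientation, `per_feature` (the value the NeMo fast conformer models use) normalizes each feature dimension by its own mean and standard deviation over the frames of an utterance. A minimal NumPy sketch of that idea, assuming it matches NeMo's `features.py` linked above; the `eps` guard is an assumption, not taken from this diff:

```python
import numpy as np

def per_feature_normalize(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Sketch of nemo_normalize_type == "per_feature".

    x: (num_frames, feat_dim) filterbank features for one utterance.
    Each of the feat_dim columns is normalized independently.
    """
    mean = x.mean(axis=0, keepdims=True)  # per-feature mean, shape (1, feat_dim)
    std = x.std(axis=0, keepdims=True)    # per-feature std, shape (1, feat_dim)
    return (x - mean) / (std + eps)       # eps avoids division by zero variance
```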
@@ -68,7 +68,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
       : config_(config),
         model_(OnlineTransducerModel::Create(config.model_config)),
         sym_(config.model_config.tokens) {
-    if (sym_.contains("<unk>")) {
+    if (sym_.Contains("<unk>")) {
       unk_id_ = sym_["<unk>"];
     }

@@ -87,7 +87,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
       : config_(config),
         model_(OnlineTransducerModel::Create(mgr, config.model_config)),
         sym_(mgr, config.model_config.tokens) {
-    if (sym_.contains("<unk>")) {
+    if (sym_.Contains("<unk>")) {
       unk_id_ = sym_["<unk>"];
     }

@@ -1,6 +1,6 @@
 // sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
 //
-// Copyright (c) 2023 Xiaomi Corporation
+// Copyright (c) 2023-2024 Xiaomi Corporation

 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"

@@ -38,7 +38,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
   std::string text;

   for (int32_t i = 0; i != src.tokens.size(); ++i) {
-    if (sym_table.contains("SIL") && src.tokens[i] == sym_table["SIL"]) {
+    if (sym_table.Contains("SIL") && src.tokens[i] == sym_table["SIL"]) {
       // tdnn models from yesno have a SIL token, we should remove it.
       continue;
     }
@@ -103,9 +103,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
       decoder_ = std::make_unique<OfflineCtcFstDecoder>(
           config_.ctc_fst_decoder_config);
     } else if (config_.decoding_method == "greedy_search") {
-      if (!symbol_table_.contains("<blk>") &&
-          !symbol_table_.contains("<eps>") &&
-          !symbol_table_.contains("<blank>")) {
+      if (!symbol_table_.Contains("<blk>") &&
+          !symbol_table_.Contains("<eps>") &&
+          !symbol_table_.Contains("<blank>")) {
         SHERPA_ONNX_LOGE(
             "We expect that tokens.txt contains "
             "the symbol <blk> or <eps> or <blank> and its ID.");
@@ -113,12 +113,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
     }

     int32_t blank_id = 0;
-    if (symbol_table_.contains("<blk>")) {
+    if (symbol_table_.Contains("<blk>")) {
       blank_id = symbol_table_["<blk>"];
-    } else if (symbol_table_.contains("<eps>")) {
+    } else if (symbol_table_.Contains("<eps>")) {
       // for tdnn models of the yesno recipe from icefall
       blank_id = symbol_table_["<eps>"];
-    } else if (symbol_table_.contains("<blank>")) {
+    } else if (symbol_table_.Contains("<blank>")) {
       // for Wenet CTC models
       blank_id = symbol_table_["<blank>"];
     }
@@ -11,6 +11,7 @@
 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/text-utils.h"
@@ -23,6 +24,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
   const auto &model_type = config.model_config.model_type;
   if (model_type == "transducer") {
     return std::make_unique<OfflineRecognizerTransducerImpl>(config);
+  } else if (model_type == "nemo_transducer") {
+    return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config);
   } else if (model_type == "paraformer") {
     return std::make_unique<OfflineRecognizerParaformerImpl>(config);
   } else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
@@ -122,6 +125,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerParaformerImpl>(config);
   }

+  if (model_type == "EncDecHybridRNNTCTCBPEModel" &&
+      !config.model_config.transducer.decoder_filename.empty() &&
+      !config.model_config.transducer.joiner_filename.empty()) {
+    return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config);
+  }
+
   if (model_type == "EncDecCTCModelBPE" ||
       model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
       model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
@@ -155,6 +164,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
   const auto &model_type = config.model_config.model_type;
   if (model_type == "transducer") {
     return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
+  } else if (model_type == "nemo_transducer") {
+    return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config);
   } else if (model_type == "paraformer") {
     return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
   } else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
@@ -254,6 +265,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
   }

+  if (model_type == "EncDecHybridRNNTCTCBPEModel" &&
+      !config.model_config.transducer.decoder_filename.empty() &&
+      !config.model_config.transducer.joiner_filename.empty()) {
+    return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config);
+  }
+
   if (model_type == "EncDecCTCModelBPE" ||
       model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
       model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
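
Taken together, the dispatch rule added above is: an explicit `nemo_transducer` model type always selects the new NeMo transducer implementation, and a hybrid `EncDecHybridRNNTCTCBPEModel` selects it only when both a decoder and a joiner file are supplied; otherwise the hybrid model falls through to the CTC branch. A rough Python analogue of just the branches touched by this PR (illustrative names, not a real sherpa-onnx API):

```python
def pick_impl(model_type: str, decoder_filename: str, joiner_filename: str) -> str:
    """Illustrative sketch of the branch order in OfflineRecognizerImpl::Create()."""
    if model_type == "nemo_transducer":
        return "OfflineRecognizerTransducerNeMoImpl"
    if model_type == "EncDecHybridRNNTCTCBPEModel" and decoder_filename and joiner_filename:
        # transducer branch of the hybrid RNN-T/CTC model
        return "OfflineRecognizerTransducerNeMoImpl"
    if model_type in ("EncDecCTCModelBPE", "EncDecHybridRNNTCTCBPEModel"):
        return "OfflineRecognizerCtcImpl"  # CTC branch
    raise ValueError(f"branch not covered by this sketch: {model_type}")
```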
@@ -0,0 +1,182 @@
+// sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
+//
+// Copyright (c) 2022-2024 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
+
+#include <fstream>
+#include <ios>
+#include <memory>
+#include <regex>  // NOLINT
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer.h"
+#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h"
+#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
+#include "sherpa-onnx/csrc/pad-sequence.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
+#include "sherpa-onnx/csrc/transpose.h"
+#include "sherpa-onnx/csrc/utils.h"
+
+namespace sherpa_onnx {
+
+// defined in ./offline-recognizer-transducer-impl.h
+OfflineRecognitionResult Convert(const OfflineTransducerDecoderResult &src,
+                                 const SymbolTable &sym_table,
+                                 int32_t frame_shift_ms,
+                                 int32_t subsampling_factor);
+
+class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
+ public:
+  explicit OfflineRecognizerTransducerNeMoImpl(
+      const OfflineRecognizerConfig &config)
+      : config_(config),
+        symbol_table_(config_.model_config.tokens),
+        model_(std::make_unique<OfflineTransducerNeMoModel>(
+            config_.model_config)) {
+    if (config_.decoding_method == "greedy_search") {
+      decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>(
+          model_.get(), config_.blank_penalty);
+    } else {
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config_.decoding_method.c_str());
+      exit(-1);
+    }
+    PostInit();
+  }
+
+#if __ANDROID_API__ >= 9
+  explicit OfflineRecognizerTransducerNeMoImpl(
+      AAssetManager *mgr, const OfflineRecognizerConfig &config)
+      : config_(config),
+        symbol_table_(mgr, config_.model_config.tokens),
+        model_(std::make_unique<OfflineTransducerNeMoModel>(
+            mgr, config_.model_config)) {
+    if (config_.decoding_method == "greedy_search") {
+      decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>(
+          model_.get(), config_.blank_penalty);
+    } else {
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config_.decoding_method.c_str());
+      exit(-1);
+    }
+
+    PostInit();
+  }
+#endif
+
+  std::unique_ptr<OfflineStream> CreateStream() const override {
+    return std::make_unique<OfflineStream>(config_.feat_config);
+  }
+
+  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    int32_t feat_dim = ss[0]->FeatureDim();
+
+    std::vector<Ort::Value> features;
+
+    features.reserve(n);
+
+    std::vector<std::vector<float>> features_vec(n);
+    std::vector<int64_t> features_length_vec(n);
+    for (int32_t i = 0; i != n; ++i) {
+      auto f = ss[i]->GetFrames();
+      int32_t num_frames = f.size() / feat_dim;
+
+      features_length_vec[i] = num_frames;
+      features_vec[i] = std::move(f);
+
+      std::array<int64_t, 2> shape = {num_frames, feat_dim};
+
+      Ort::Value x = Ort::Value::CreateTensor(
+          memory_info, features_vec[i].data(), features_vec[i].size(),
+          shape.data(), shape.size());
+      features.push_back(std::move(x));
+    }
+
+    std::vector<const Ort::Value *> features_pointer(n);
+    for (int32_t i = 0; i != n; ++i) {
+      features_pointer[i] = &features[i];
+    }
+
+    std::array<int64_t, 1> features_length_shape = {n};
+    Ort::Value x_length = Ort::Value::CreateTensor(
+        memory_info, features_length_vec.data(), n,
+        features_length_shape.data(), features_length_shape.size());
+
+    Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
+
+    auto t = model_->RunEncoder(std::move(x), std::move(x_length));
+    // t[0] encoder_out, float tensor, (batch_size, dim, T)
+    // t[1] encoder_out_length, int64 tensor, (batch_size,)
+
+    Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);
+
+    auto results = decoder_->Decode(std::move(encoder_out), std::move(t[1]));
+
+    int32_t frame_shift_ms = 10;
+    for (int32_t i = 0; i != n; ++i) {
+      auto r = Convert(results[i], symbol_table_, frame_shift_ms,
+                       model_->SubsamplingFactor());
+
+      ss[i]->SetResult(r);
+    }
+  }
+
+ private:
+  void PostInit() {
+    config_.feat_config.nemo_normalize_type =
+        model_->FeatureNormalizationMethod();
+
+    config_.feat_config.low_freq = 0;
+    // config_.feat_config.high_freq = 8000;
+    config_.feat_config.is_librosa = true;
+    config_.feat_config.remove_dc_offset = false;
+    // config_.feat_config.window_type = "hann";
+    config_.feat_config.dither = 0;
+    config_.feat_config.nemo_normalize_type =
+        model_->FeatureNormalizationMethod();
+
+    int32_t vocab_size = model_->VocabSize();
+
+    // check the blank ID
+    if (!symbol_table_.Contains("<blk>")) {
+      SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>");
+      exit(-1);
+    }
+
+    if (symbol_table_["<blk>"] != vocab_size - 1) {
+      SHERPA_ONNX_LOGE("<blk> is not the last token!");
+      exit(-1);
+    }
+
+    if (symbol_table_.NumSymbols() != vocab_size) {
+      SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
+                       symbol_table_.NumSymbols(), vocab_size);
+      exit(-1);
+    }
+  }
+
+ private:
+  OfflineRecognizerConfig config_;
+  SymbolTable symbol_table_;
+  std::unique_ptr<OfflineTransducerNeMoModel> model_;
+  std::unique_ptr<OfflineTransducerDecoder> decoder_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
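
Note the layout juggling in `DecodeStreams()` above: the NeMo encoder emits `encoder_out` as (batch, dim, T), while the greedy-search decoder walks it frame by frame as (batch, T, dim), so the code calls `Transpose12` before decoding. A NumPy sketch of the same step, with illustrative shapes:

```python
import numpy as np

# Hypothetical encoder output: (batch, dim, T), as returned by RunEncoder().
encoder_out = np.random.rand(2, 512, 100).astype(np.float32)

# Transpose12 swaps axes 1 and 2, yielding (batch, T, dim) for the decoder.
encoder_out_btd = np.transpose(encoder_out, (0, 2, 1))
assert encoder_out_btd.shape == (2, 100, 512)
```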
@@ -35,7 +35,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,

   std::string text;
   for (auto i : src.tokens) {
-    if (!sym_table.contains(i)) {
+    if (!sym_table.Contains(i)) {
       continue;
     }

@@ -14,6 +14,7 @@
 #include "android/asset_manager_jni.h"
 #endif

+#include "sherpa-onnx/csrc/features.h"
 #include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
 #include "sherpa-onnx/csrc/offline-lm-config.h"
 #include "sherpa-onnx/csrc/offline-model-config.h"
@@ -26,7 +27,7 @@ namespace sherpa_onnx {

 struct OfflineRecognitionResult;

 struct OfflineRecognizerConfig {
-  OfflineFeatureExtractorConfig feat_config;
+  FeatureExtractorConfig feat_config;
   OfflineModelConfig model_config;
   OfflineLMConfig lm_config;
   OfflineCtcFstDecoderConfig ctc_fst_decoder_config;
@@ -44,7 +45,7 @@ struct OfflineRecognizerConfig {

   OfflineRecognizerConfig() = default;
   OfflineRecognizerConfig(
-      const OfflineFeatureExtractorConfig &feat_config,
+      const FeatureExtractorConfig &feat_config,
       const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
       const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
       const std::string &decoding_method, int32_t max_active_paths,
@@ -52,42 +52,25 @@ static void ComputeMeanAndInvStd(const float *p, int32_t num_rows,
   }
 }

-void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
-  po->Register("sample-rate", &sampling_rate,
-               "Sampling rate of the input waveform. "
-               "Note: You can have a different "
-               "sample rate for the input waveform. We will do resampling "
-               "inside the feature extractor");
-
-  po->Register("feat-dim", &feature_dim,
-               "Feature dimension. Must match the one expected by the model.");
-}
-
-std::string OfflineFeatureExtractorConfig::ToString() const {
-  std::ostringstream os;
-
-  os << "OfflineFeatureExtractorConfig(";
-  os << "sampling_rate=" << sampling_rate << ", ";
-  os << "feature_dim=" << feature_dim << ")";
-
-  return os.str();
-}
-
 class OfflineStream::Impl {
  public:
-  explicit Impl(const OfflineFeatureExtractorConfig &config,
+  explicit Impl(const FeatureExtractorConfig &config,
                 ContextGraphPtr context_graph)
       : config_(config), context_graph_(context_graph) {
-    opts_.frame_opts.dither = 0;
-    opts_.frame_opts.snip_edges = false;
+    opts_.frame_opts.dither = config.dither;
+    opts_.frame_opts.snip_edges = config.snip_edges;
     opts_.frame_opts.samp_freq = config.sampling_rate;
+    opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
+    opts_.frame_opts.frame_length_ms = config.frame_length_ms;
+    opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
+    opts_.frame_opts.window_type = config.window_type;
+
     opts_.mel_opts.num_bins = config.feature_dim;

-    // Please see
-    // https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27
-    // and
-    // https://github.com/k2-fsa/sherpa-onnx/issues/514
-    opts_.mel_opts.high_freq = -400;
+    opts_.mel_opts.high_freq = config.high_freq;
+    opts_.mel_opts.low_freq = config.low_freq;
+
+    opts_.mel_opts.is_librosa = config.is_librosa;

     fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
   }
@@ -237,7 +220,7 @@ class OfflineStream::Impl {
   }

  private:
-  OfflineFeatureExtractorConfig config_;
+  FeatureExtractorConfig config_;
   std::unique_ptr<knf::OnlineFbank> fbank_;
   std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
   knf::FbankOptions opts_;
@@ -245,9 +228,8 @@ class OfflineStream::Impl {
   ContextGraphPtr context_graph_;
 };

-OfflineStream::OfflineStream(
-    const OfflineFeatureExtractorConfig &config /*= {}*/,
-    ContextGraphPtr context_graph /*= nullptr*/)
+OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
+                             ContextGraphPtr context_graph /*= nullptr*/)
     : impl_(std::make_unique<Impl>(config, context_graph)) {}

 OfflineStream::OfflineStream(WhisperTag tag)
@@ -11,6 +11,7 @@
 #include <vector>

 #include "sherpa-onnx/csrc/context-graph.h"
+#include "sherpa-onnx/csrc/features.h"
 #include "sherpa-onnx/csrc/parse-options.h"

 namespace sherpa_onnx {
@@ -32,46 +33,12 @@ struct OfflineRecognitionResult {
   std::string AsJsonString() const;
 };

-struct OfflineFeatureExtractorConfig {
-  // Sampling rate used by the feature extractor. If it is different from
-  // the sampling rate of the input waveform, we will do resampling inside.
-  int32_t sampling_rate = 16000;
-
-  // Feature dimension
-  int32_t feature_dim = 80;
-
-  // Set internally by some models, e.g., paraformer and wenet CTC models set
-  // it to false.
-  // This parameter is not exposed to users from the commandline
-  // If true, the feature extractor expects inputs to be normalized to
-  // the range [-1, 1].
-  // If false, we will multiply the inputs by 32768
-  bool normalize_samples = true;
-
-  // For models from NeMo
-  // This option is not exposed and is set internally when loading models.
-  // Possible values:
-  //  - per_feature
-  //  - all_features (not implemented yet)
-  //  - fixed_mean (not implemented)
-  //  - fixed_std (not implemented)
-  //  - or just leave it to empty
-  // See
-  // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
-  // for details
-  std::string nemo_normalize_type;
-
-  std::string ToString() const;
-
-  void Register(ParseOptions *po);
-};
-
 struct WhisperTag {};
 struct CEDTag {};

 class OfflineStream {
  public:
-  explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
+  explicit OfflineStream(const FeatureExtractorConfig &config = {},
                          ContextGraphPtr context_graph = {});

   explicit OfflineStream(WhisperTag tag);
@@ -14,8 +14,8 @@ namespace sherpa_onnx {

 class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
  public:
-  explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
-                                                float blank_penalty)
+  OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
+                                       float blank_penalty)
       : model_(model), blank_penalty_(blank_penalty) {}

   std::vector<OfflineTransducerDecoderResult> Decode(
@@ -0,0 +1,117 @@
+// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h"
+
+#include <algorithm>
+#include <iterator>
+#include <utility>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(
+    int32_t token, OrtAllocator *allocator) {
+  std::array<int64_t, 2> shape{1, 1};
+
+  Ort::Value decoder_input =
+      Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());
+
+  std::array<int64_t, 1> length_shape{1};
+  Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>(
+      allocator, length_shape.data(), length_shape.size());
+
+  int32_t *p = decoder_input.GetTensorMutableData<int32_t>();
+
+  int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>();
+
+  p[0] = token;
+
+  p_length[0] = 1;
+
+  return {std::move(decoder_input), std::move(decoder_input_length)};
+}
+
+static OfflineTransducerDecoderResult DecodeOne(
+    const float *p, int32_t num_rows, int32_t num_cols,
+    OfflineTransducerNeMoModel *model, float blank_penalty) {
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  OfflineTransducerDecoderResult ans;
+
+  int32_t vocab_size = model->VocabSize();
+  int32_t blank_id = vocab_size - 1;
+
+  auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
+      model->RunDecoder(std::move(decoder_input_pair.first),
+                        std::move(decoder_input_pair.second),
+                        model->GetDecoderInitStates(1));
+
+  std::array<int64_t, 3> encoder_shape{1, num_cols, 1};
+
+  for (int32_t t = 0; t != num_rows; ++t) {
+    Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
+        memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
+        encoder_shape.data(), encoder_shape.size());
+
+    Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
+                                        View(&decoder_output_pair.first));
+
+    float *p_logit = logit.GetTensorMutableData<float>();
+    if (blank_penalty > 0) {
+      p_logit[blank_id] -= blank_penalty;
+    }
+
+    auto y = static_cast<int32_t>(std::distance(
+        static_cast<const float *>(p_logit),
+        std::max_element(static_cast<const float *>(p_logit),
+                         static_cast<const float *>(p_logit) + vocab_size)));
+
+    if (y != blank_id) {
+      ans.tokens.push_back(y);
+      ans.timestamps.push_back(t);
+
+      decoder_input_pair = BuildDecoderInput(y, model->Allocator());
+
+      decoder_output_pair =
+          model->RunDecoder(std::move(decoder_input_pair.first),
+                            std::move(decoder_input_pair.second),
+                            std::move(decoder_output_pair.second));
+    }  // if (y != blank_id)
+  }    // for (int32_t t = 0; t != num_rows; ++t)
+
+  return ans;
+}
+
+std::vector<OfflineTransducerDecoderResult>
+OfflineTransducerGreedySearchNeMoDecoder::Decode(
+    Ort::Value encoder_out, Ort::Value encoder_out_length,
+    OfflineStream ** /*ss = nullptr*/, int32_t /*n = 0*/) {
+  auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
+
+  int32_t batch_size = static_cast<int32_t>(shape[0]);
+  int32_t dim1 = static_cast<int32_t>(shape[1]);
+  int32_t dim2 = static_cast<int32_t>(shape[2]);
+
+  const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();
+  const float *p = encoder_out.GetTensorData<float>();
+
+  std::vector<OfflineTransducerDecoderResult> ans(batch_size);
+
+  for (int32_t i = 0; i != batch_size; ++i) {
+    const float *this_p = p + dim1 * dim2 * i;
+    int32_t this_len = p_length[i];
+
+    ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_);
+  }
+
+  return ans;
+}
+
+}  // namespace sherpa_onnx
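
The loop in `DecodeOne()` above is standard RNN-T greedy search with a stateful (LSTM) prediction network: start from the blank token, take the argmax of the joiner logits for each encoder frame, and on a non-blank emission feed the token back through the decoder and carry its states forward; at most one token is emitted per frame. A compact sketch of the same per-utterance loop, under the assumption that `run_decoder` and `run_joiner` are hypothetical wrappers around the ONNX sessions:

```python
import numpy as np

def greedy_search_one(encoder_out, run_decoder, run_joiner,
                      vocab_size, blank_penalty=0.0):
    """Sketch of DecodeOne(); run_decoder/run_joiner are assumed helpers.

    encoder_out: (T, dim) float array for one utterance.
    run_decoder(token, states) -> (decoder_out, new_states)
    run_joiner(enc_frame, decoder_out) -> (vocab_size,) logits
    """
    blank_id = vocab_size - 1  # NeMo puts the blank token last
    tokens, timestamps = [], []
    # Prime the prediction network with blank and zero-initialized states.
    decoder_out, states = run_decoder(blank_id, None)
    for t, enc_frame in enumerate(encoder_out):
        logits = run_joiner(enc_frame, decoder_out)
        if blank_penalty > 0:
            logits[blank_id] -= blank_penalty
        y = int(np.argmax(logits))
        if y != blank_id:  # emit and advance the prediction network
            tokens.append(y)
            timestamps.append(t)
            decoder_out, states = run_decoder(y, states)
    return tokens, timestamps
```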
@@ -0,0 +1,33 @@
+// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
+
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
+#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
+
+namespace sherpa_onnx {
+
+class OfflineTransducerGreedySearchNeMoDecoder
+    : public OfflineTransducerDecoder {
+ public:
+  OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model,
+                                           float blank_penalty)
+      : model_(model), blank_penalty_(blank_penalty) {}
+
+  std::vector<OfflineTransducerDecoderResult> Decode(
+      Ort::Value encoder_out, Ort::Value encoder_out_length,
+      OfflineStream **ss = nullptr, int32_t n = 0) override;
+
+ private:
+  OfflineTransducerNeMoModel *model_;  // Not owned
+  float blank_penalty_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
@@ -0,0 +1,301 @@
+// sherpa-onnx/csrc/offline-transducer-nemo-model.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/transpose.h"
+
+namespace sherpa_onnx {
+
+class OfflineTransducerNeMoModel::Impl {
+ public:
+  explicit Impl(const OfflineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_WARNING),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(config.transducer.encoder_filename);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.transducer.decoder_filename);
+      InitDecoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.transducer.joiner_filename);
+      InitJoiner(buf.data(), buf.size());
+    }
+  }
+
+#if __ANDROID_API__ >= 9
+  Impl(AAssetManager *mgr, const OfflineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_WARNING),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(mgr, config.transducer.encoder_filename);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(mgr, config.transducer.decoder_filename);
+      InitDecoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(mgr, config.transducer.joiner_filename);
+      InitJoiner(buf.data(), buf.size());
+    }
+  }
+#endif
+
+  std::vector<Ort::Value> RunEncoder(Ort::Value features,
+                                     Ort::Value features_length) {
+    // (B, T, C) -> (B, C, T)
+    features = Transpose12(allocator_, &features);
+
+    std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
+                                                std::move(features_length)};
+
+    auto encoder_out = encoder_sess_->Run(
+        {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
+        encoder_inputs.size(), encoder_output_names_ptr_.data(),
+        encoder_output_names_ptr_.size());
+
+    return encoder_out;
+  }
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
+      Ort::Value targets, Ort::Value targets_length,
+      std::vector<Ort::Value> states) {
+    std::vector<Ort::Value> decoder_inputs;
+    decoder_inputs.reserve(2 + states.size());
+
+    decoder_inputs.push_back(std::move(targets));
+    decoder_inputs.push_back(std::move(targets_length));
+
+    for (auto &s : states) {
+      decoder_inputs.push_back(std::move(s));
+    }
+
+    auto decoder_out = decoder_sess_->Run(
+        {}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
+        decoder_inputs.size(), decoder_output_names_ptr_.data(),
+        decoder_output_names_ptr_.size());
+
+    std::vector<Ort::Value> states_next;
+    states_next.reserve(states.size());
+
+    // decoder_out[0]: decoder_output
+    // decoder_out[1]: decoder_output_length
+    // decoder_out[2:] states_next
+
+    for (int32_t i = 0; i != states.size(); ++i) {
+      states_next.push_back(std::move(decoder_out[i + 2]));
+    }
+
+    // we discard decoder_out[1]
+    return {std::move(decoder_out[0]), std::move(states_next)};
+  }
+
+  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
+    std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
+                                              std::move(decoder_out)};
+    auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(),
+                                   joiner_input.data(), joiner_input.size(),
+                                   joiner_output_names_ptr_.data(),
+                                   joiner_output_names_ptr_.size());
+
+    return std::move(logit[0]);
+  }
+
+  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
+    std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
+    Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
+                                                    s0_shape.size());
+
+    Fill<float>(&s0, 0);
+
+    std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_};
+
+    Ort::Value s1 = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(),
+                                                    s1_shape.size());
+
+    Fill<float>(&s1, 0);
+
+    std::vector<Ort::Value> states;
+
+    states.reserve(2);
+    states.push_back(std::move(s0));
+    states.push_back(std::move(s1));
+
+    return states;
+  }
+
+  int32_t SubsamplingFactor() const { return subsampling_factor_; }
+  int32_t VocabSize() const { return vocab_size_; }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+  std::string FeatureNormalizationMethod() const { return normalize_type_; }
+
+ private:
+  void InitEncoder(void *model_data, size_t model_data_length) {
+    encoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+                  &encoder_input_names_ptr_);
+
+    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+                   &encoder_output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      os << "---encoder---\n";
+      PrintModelMetadata(os, meta_data);
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+
+    // need to increase by 1 since the blank token is not included in computing
+    // vocab_size in NeMo.
+    vocab_size_ += 1;
+
+    SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
+    SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
+    SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
+    SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");
+
+    if (normalize_type_ == "NA") {
+      normalize_type_ = "";
+    }
+  }
+
+  void InitDecoder(void *model_data, size_t model_data_length) {
+    decoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
+                  &decoder_input_names_ptr_);
+
+    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
+                   &decoder_output_names_ptr_);
+  }
+
+  void InitJoiner(void *model_data, size_t model_data_length) {
+    joiner_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(joiner_sess_.get(), &joiner_input_names_,
+                  &joiner_input_names_ptr_);
+
+    GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
+                   &joiner_output_names_ptr_);
+  }
+
+ private:
+  OfflineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> encoder_sess_;
+  std::unique_ptr<Ort::Session> decoder_sess_;
+  std::unique_ptr<Ort::Session> joiner_sess_;
+
+  std::vector<std::string> encoder_input_names_;
+  std::vector<const char *> encoder_input_names_ptr_;
+
+  std::vector<std::string> encoder_output_names_;
+  std::vector<const char *> encoder_output_names_ptr_;
+
+  std::vector<std::string> decoder_input_names_;
+  std::vector<const char *> decoder_input_names_ptr_;
+
+  std::vector<std::string> decoder_output_names_;
+  std::vector<const char *> decoder_output_names_ptr_;
+
+  std::vector<std::string> joiner_input_names_;
+  std::vector<const char *> joiner_input_names_ptr_;
+
+  std::vector<std::string> joiner_output_names_;
+  std::vector<const char *> joiner_output_names_ptr_;
+
+  int32_t vocab_size_ = 0;
+  int32_t subsampling_factor_ = 8;
+  std::string normalize_type_;
+  int32_t pred_rnn_layers_ = -1;
+  int32_t pred_hidden_ = -1;
+};
+
+OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
+    const OfflineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
+    AAssetManager *mgr, const OfflineModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OfflineTransducerNeMoModel::~OfflineTransducerNeMoModel() = default;
+
+std::vector<Ort::Value> OfflineTransducerNeMoModel::RunEncoder(
+    Ort::Value features, Ort::Value features_length) const {
+  return impl_->RunEncoder(std::move(features), std::move(features_length));
+}
+
+std::pair<Ort::Value, std::vector<Ort::Value>>
+OfflineTransducerNeMoModel::RunDecoder(Ort::Value targets,
+                                       Ort::Value targets_length,
+                                       std::vector<Ort::Value> states) const {
+  return impl_->RunDecoder(std::move(targets), std::move(targets_length),
+                           std::move(states));
+}
+
+std::vector<Ort::Value> OfflineTransducerNeMoModel::GetDecoderInitStates(
+    int32_t batch_size) const {
+  return impl_->GetDecoderInitStates(batch_size);
+}
+
+Ort::Value OfflineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
+                                                 Ort::Value decoder_out) const {
+  return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
+}
+
+int32_t OfflineTransducerNeMoModel::SubsamplingFactor() const {
+  return impl_->SubsamplingFactor();
+}
+
+int32_t OfflineTransducerNeMoModel::VocabSize() const {
+  return impl_->VocabSize();
+}
+
+OrtAllocator *OfflineTransducerNeMoModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const {
+  return impl_->FeatureNormalizationMethod();
+}
+
+}  // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-transducer-nemo-model.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
  17 +#include "onnxruntime_cxx_api.h" // NOLINT
  18 +#include "sherpa-onnx/csrc/offline-model-config.h"
  19 +
  20 +namespace sherpa_onnx {
  21 +
  22 +// see
  23 +// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40
  24 +// Its decoder is stateful, not stateless.
  25 +class OfflineTransducerNeMoModel {
  26 + public:
  27 + explicit OfflineTransducerNeMoModel(const OfflineModelConfig &config);
  28 +
  29 +#if __ANDROID_API__ >= 9
  30 + OfflineTransducerNeMoModel(AAssetManager *mgr,
  31 + const OfflineModelConfig &config);
  32 +#endif
  33 +
  34 + ~OfflineTransducerNeMoModel();
  35 +
  36 + /** Run the encoder.
  37 + *
  38 + * @param features A tensor of shape (N, T, C). It is changed in-place.
  39 + * @param features_length A 1-D tensor of shape (N,) containing number of
  40 + * valid frames in `features` before padding.
  41 + * Its dtype is int64_t.
  42 + *
  43 + * @return Return a vector containing:
  44 + * - encoder_out: A 3-D tensor of shape (N, T', encoder_dim)
  45 + * - encoder_out_length: A 1-D tensor of shape (N,) containing number
  46 + * of frames in `encoder_out` before padding.
  47 + */
  48 + std::vector<Ort::Value> RunEncoder(Ort::Value features,
  49 + Ort::Value features_length) const;
  50 +
  51 + /** Run the decoder network.
  52 + *
  53 + * @param targets An int32 tensor of shape (batch_size, 1)
  54 + * @param targets_length An int32 tensor of shape (batch_size,)
  55 + * @param states The states for the decoder model.
  56 + * @return Return a pair:
  57 + * - first: the decoder output (a float tensor)
  58 + * - second: the next decoder states, to be passed to the
  59 + * following invocation (states_next)
  60 + */
  61 + std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
  62 + Ort::Value targets, Ort::Value targets_length,
  63 + std::vector<Ort::Value> states) const;
  64 +
  65 + std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const;
  66 +
  67 + /** Run the joint network.
  68 + *
  69 + * @param encoder_out Output of the encoder network.
  70 + * @param decoder_out Output of the decoder network.
  71 + * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits.
  72 + */
  73 + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const;
  74 +
  75 + /** Return the subsampling factor of the model.
  76 + */
  77 + int32_t SubsamplingFactor() const;
  78 +
  79 + int32_t VocabSize() const;
  80 +
  81 + /** Return an allocator for allocating memory.
  82 + */
  83 + OrtAllocator *Allocator() const;
  84 +
  85 + // Possible values:
  86 + // - per_feature
  87 + // - all_features (not implemented yet)
  88 + // - fixed_mean (not implemented)
  89 + // - fixed_std (not implemented)
  90 + // - or just leave it empty
  91 + // See
  92 + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
  93 + // for details
  94 + std::string FeatureNormalizationMethod() const;
  95 +
  96 + private:
  97 + class Impl;
  98 + std::unique_ptr<Impl> impl_;
  99 +};
  100 +
  101 +} // namespace sherpa_onnx
  102 +
  103 +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
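
To make the intended call sequence concrete, here is a minimal greedy-search sketch (batch size 1) wired from the methods declared above. It is not part of this PR and is illustrative only: the assumptions that blank is the last vocabulary entry and that blank doubles as the start-of-sequence token follow the NeMo transducer convention, the View() helper comes from sherpa-onnx/csrc/onnx-utils.h, and GreedySearchSketch/run_decoder are hypothetical names. A production decoder would also cap the number of symbols emitted per frame.

// Illustrative greedy-search sketch for a single utterance (N == 1).
// Assumptions: blank is the last vocabulary entry and also serves as the
// start-of-sequence token, as in NeMo transducer decoders.
#include <algorithm>
#include <array>
#include <cstdint>
#include <utility>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"  // for View()

namespace sherpa_onnx {

static std::vector<int32_t> GreedySearchSketch(
    const OfflineTransducerNeMoModel &model, std::vector<float> feats,
    int64_t num_frames, int64_t feat_dim) {
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
  Ort::Value x = Ort::Value::CreateTensor<float>(
      memory_info, feats.data(), feats.size(), x_shape.data(), x_shape.size());

  int64_t x_len = num_frames;
  std::array<int64_t, 1> len_shape{1};
  Ort::Value x_length = Ort::Value::CreateTensor<int64_t>(
      memory_info, &x_len, 1, len_shape.data(), len_shape.size());

  // enc[0]: encoder_out (1, T', encoder_dim); enc[1]: encoder_out_length (1,).
  // With batch size 1 there is no padding, so T' is taken from the shape.
  auto enc = model.RunEncoder(std::move(x), std::move(x_length));
  const float *p_enc = enc[0].GetTensorData<float>();
  auto enc_shape = enc[0].GetTensorTypeAndShapeInfo().GetShape();
  int64_t t_prime = enc_shape[1];
  int64_t d = enc_shape[2];

  int32_t vocab_size = model.VocabSize();
  int32_t blank_id = vocab_size - 1;  // assumption: blank is last

  auto run_decoder = [&](int32_t token, std::vector<Ort::Value> states) {
    std::array<int64_t, 2> y_shape{1, 1};
    Ort::Value targets = Ort::Value::CreateTensor<int32_t>(
        memory_info, &token, 1, y_shape.data(), y_shape.size());
    int32_t token_len = 1;
    Ort::Value targets_length = Ort::Value::CreateTensor<int32_t>(
        memory_info, &token_len, 1, len_shape.data(), len_shape.size());
    return model.RunDecoder(std::move(targets), std::move(targets_length),
                            std::move(states));
  };

  // Prime the decoder with blank acting as start-of-sequence.
  auto dec = run_decoder(blank_id, model.GetDecoderInitStates(1));

  std::vector<int32_t> hyp;
  for (int64_t t = 0; t != t_prime; ++t) {
    // Take the t-th encoder frame as a (1, 1, encoder_dim) tensor.
    std::vector<float> frame(p_enc + t * d, p_enc + (t + 1) * d);
    std::array<int64_t, 3> f_shape{1, 1, d};
    Ort::Value cur = Ort::Value::CreateTensor<float>(
        memory_info, frame.data(), frame.size(), f_shape.data(),
        f_shape.size());

    // logits: (1, 1, 1, vocab_size)
    Ort::Value logits = model.RunJoiner(std::move(cur), View(&dec.first));
    const float *p = logits.GetTensorData<float>();
    auto y = static_cast<int32_t>(
        std::distance(p, std::max_element(p, p + vocab_size)));

    if (y != blank_id) {
      hyp.push_back(y);
      dec = run_decoder(y, std::move(dec.second));  // advance decoder state
    }
  }
  return hyp;
}

}  // namespace sherpa_onnx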
@@ -223,8 +223,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { @@ -223,8 +223,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
223 223
224 private: 224 private:
225 void InitDecoder() { 225 void InitDecoder() {
226 - if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&  
227 - !sym_.contains("<blank>")) { 226 + if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") &&
  227 + !sym_.Contains("<blank>")) {
228 SHERPA_ONNX_LOGE( 228 SHERPA_ONNX_LOGE(
229 "We expect that tokens.txt contains " 229 "We expect that tokens.txt contains "
230 "the symbol <blk> or <eps> or <blank> and its ID."); 230 "the symbol <blk> or <eps> or <blank> and its ID.");
@@ -232,12 +232,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { @@ -232,12 +232,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
232 } 232 }
233 233
234 int32_t blank_id = 0; 234 int32_t blank_id = 0;
235 - if (sym_.contains("<blk>")) { 235 + if (sym_.Contains("<blk>")) {
236 blank_id = sym_["<blk>"]; 236 blank_id = sym_["<blk>"];
237 - } else if (sym_.contains("<eps>")) { 237 + } else if (sym_.Contains("<eps>")) {
238 // for tdnn models of the yesno recipe from icefall 238 // for tdnn models of the yesno recipe from icefall
239 blank_id = sym_["<eps>"]; 239 blank_id = sym_["<eps>"];
240 - } else if (sym_.contains("<blank>")) { 240 + } else if (sym_.Contains("<blank>")) {
241 // for WeNet CTC models 241 // for WeNet CTC models
242 blank_id = sym_["<blank>"]; 242 blank_id = sym_["<blank>"];
243 } 243 }
@@ -87,7 +87,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -87,7 +87,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
87 model_(OnlineTransducerModel::Create(config.model_config)), 87 model_(OnlineTransducerModel::Create(config.model_config)),
88 sym_(config.model_config.tokens), 88 sym_(config.model_config.tokens),
89 endpoint_(config_.endpoint_config) { 89 endpoint_(config_.endpoint_config) {
90 - if (sym_.contains("<unk>")) { 90 + if (sym_.Contains("<unk>")) {
91 unk_id_ = sym_["<unk>"]; 91 unk_id_ = sym_["<unk>"];
92 } 92 }
93 93
@@ -103,19 +103,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -103,19 +103,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
103 } 103 }
104 104
105 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 105 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
106 - model_.get(),  
107 - lm_.get(),  
108 - config_.max_active_paths,  
109 - config_.lm_config.scale,  
110 - unk_id_,  
111 - config_.blank_penalty, 106 + model_.get(), lm_.get(), config_.max_active_paths,
  107 + config_.lm_config.scale, unk_id_, config_.blank_penalty,
112 config_.temperature_scale); 108 config_.temperature_scale);
113 109
114 } else if (config.decoding_method == "greedy_search") { 110 } else if (config.decoding_method == "greedy_search") {
115 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( 111 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
116 - model_.get(),  
117 - unk_id_,  
118 - config_.blank_penalty, 112 + model_.get(), unk_id_, config_.blank_penalty,
119 config_.temperature_scale); 113 config_.temperature_scale);
120 114
121 } else { 115 } else {
@@ -132,7 +126,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -132,7 +126,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
132 model_(OnlineTransducerModel::Create(mgr, config.model_config)), 126 model_(OnlineTransducerModel::Create(mgr, config.model_config)),
133 sym_(mgr, config.model_config.tokens), 127 sym_(mgr, config.model_config.tokens),
134 endpoint_(config_.endpoint_config) { 128 endpoint_(config_.endpoint_config) {
135 - if (sym_.contains("<unk>")) { 129 + if (sym_.Contains("<unk>")) {
136 unk_id_ = sym_["<unk>"]; 130 unk_id_ = sym_["<unk>"];
137 } 131 }
138 132
@@ -151,19 +145,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -151,19 +145,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
151 } 145 }
152 146
153 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 147 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
154 - model_.get(),  
155 - lm_.get(),  
156 - config_.max_active_paths,  
157 - config_.lm_config.scale,  
158 - unk_id_,  
159 - config_.blank_penalty, 148 + model_.get(), lm_.get(), config_.max_active_paths,
  149 + config_.lm_config.scale, unk_id_, config_.blank_penalty,
160 config_.temperature_scale); 150 config_.temperature_scale);
161 151
162 } else if (config.decoding_method == "greedy_search") { 152 } else if (config.decoding_method == "greedy_search") {
163 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( 153 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
164 - model_.get(),  
165 - unk_id_,  
166 - config_.blank_penalty, 154 + model_.get(), unk_id_, config_.blank_penalty,
167 config_.temperature_scale); 155 config_.temperature_scale);
168 156
169 } else { 157 } else {
@@ -13,7 +13,7 @@ namespace sherpa_onnx { @@ -13,7 +13,7 @@ namespace sherpa_onnx {
13 * It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :] 13 * It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :]
14 * 14 *
15 * @param allocator 15 * @param allocator
16 - * @param v A 2-D tensor. Its data type is T. 16 + * @param v A 3-D tensor. Its data type is T.
17 * @param dim0_start Start index of the first dimension. 17
18 * @param dim0_end End index of the first dimension. 18
19 * @param dim1_start Start index of the second dimension. 19 * @param dim1_start Start index of the second dimension.
@@ -100,9 +100,9 @@ int32_t SymbolTable::operator[](const std::string &sym) const { @@ -100,9 +100,9 @@ int32_t SymbolTable::operator[](const std::string &sym) const {
100 return sym2id_.at(sym); 100 return sym2id_.at(sym);
101 } 101 }
102 102
103 -bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; } 103 +bool SymbolTable::Contains(int32_t id) const { return id2sym_.count(id) != 0; }
104 104
105 -bool SymbolTable::contains(const std::string &sym) const { 105 +bool SymbolTable::Contains(const std::string &sym) const {
106 return sym2id_.count(sym) != 0; 106 return sym2id_.count(sym) != 0;
107 } 107 }
108 108
@@ -40,14 +40,16 @@ class SymbolTable { @@ -40,14 +40,16 @@ class SymbolTable {
40 int32_t operator[](const std::string &sym) const; 40 int32_t operator[](const std::string &sym) const;
41 41
42 /// Return true if there is a symbol with the given ID. 42 /// Return true if there is a symbol with the given ID.
43 - bool contains(int32_t id) const; 43 + bool Contains(int32_t id) const;
44 44
45 /// Return true if there is a given symbol in the symbol table. 45 /// Return true if there is a given symbol in the symbol table.
46 - bool contains(const std::string &sym) const; 46 + bool Contains(const std::string &sym) const;
47 47
48 // for tokens.txt from Whisper 48 // for tokens.txt from Whisper
49 void ApplyBase64Decode(); 49 void ApplyBase64Decode();
50 50
  51 + int32_t NumSymbols() const { return id2sym_.size(); }
  52 +
51 private: 53 private:
52 void Init(std::istream &is); 54 void Init(std::istream &is);
53 55
@@ -49,7 +49,7 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, @@ -49,7 +49,7 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
49 word = word.replace(0, 3, " "); 49 word = word.replace(0, 3, " ");
50 } 50 }
51 } 51 }
52 - if (symbol_table.contains(word)) { 52 + if (symbol_table.Contains(word)) {
53 int32_t id = symbol_table[word]; 53 int32_t id = symbol_table[word];
54 tmp_ids.push_back(id); 54 tmp_ids.push_back(id);
55 } else { 55 } else {
@@ -14,10 +14,10 @@ namespace sherpa_onnx { @@ -14,10 +14,10 @@ namespace sherpa_onnx {
14 static void PybindOfflineRecognizerConfig(py::module *m) { 14 static void PybindOfflineRecognizerConfig(py::module *m) {
15 using PyClass = OfflineRecognizerConfig; 15 using PyClass = OfflineRecognizerConfig;
16 py::class_<PyClass>(*m, "OfflineRecognizerConfig") 16 py::class_<PyClass>(*m, "OfflineRecognizerConfig")
17 - .def(py::init<const OfflineFeatureExtractorConfig &,  
18 - const OfflineModelConfig &, const OfflineLMConfig &,  
19 - const OfflineCtcFstDecoderConfig &, const std::string &,  
20 - int32_t, const std::string &, float, float>(), 17 + .def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
  18 + const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
  19 + const std::string &, int32_t, const std::string &, float,
  20 + float>(),
21 py::arg("feat_config"), py::arg("model_config"), 21 py::arg("feat_config"), py::arg("model_config"),
22 py::arg("lm_config") = OfflineLMConfig(), 22 py::arg("lm_config") = OfflineLMConfig(),
23 py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), 23 py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
@@ -25,6 +25,7 @@ Args: @@ -25,6 +25,7 @@ Args:
25 static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT 25 static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
26 using PyClass = OfflineRecognitionResult; 26 using PyClass = OfflineRecognitionResult;
27 py::class_<PyClass>(*m, "OfflineRecognitionResult") 27 py::class_<PyClass>(*m, "OfflineRecognitionResult")
  28 + .def("__str__", &PyClass::AsJsonString)
28 .def_property_readonly( 29 .def_property_readonly(
29 "text", 30 "text",
30 [](const PyClass &self) -> py::str { 31 [](const PyClass &self) -> py::str {
@@ -37,18 +38,7 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT @@ -37,18 +38,7 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
37 "timestamps", [](const PyClass &self) { return self.timestamps; }); 38 "timestamps", [](const PyClass &self) { return self.timestamps; });
38 } 39 }
39 40
40 -static void PybindOfflineFeatureExtractorConfig(py::module *m) {  
41 - using PyClass = OfflineFeatureExtractorConfig;  
42 - py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig")  
43 - .def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000,  
44 - py::arg("feature_dim") = 80)  
45 - .def_readwrite("sampling_rate", &PyClass::sampling_rate)  
46 - .def_readwrite("feature_dim", &PyClass::feature_dim)  
47 - .def("__str__", &PyClass::ToString);  
48 -}  
49 -  
50 void PybindOfflineStream(py::module *m) { 41 void PybindOfflineStream(py::module *m) {
51 - PybindOfflineFeatureExtractorConfig(m);  
52 PybindOfflineRecognitionResult(m); 42 PybindOfflineRecognitionResult(m);
53 43
54 using PyClass = OfflineStream; 44 using PyClass = OfflineStream;
@@ -4,8 +4,8 @@ from pathlib import Path @@ -4,8 +4,8 @@ from pathlib import Path
4 from typing import List, Optional 4 from typing import List, Optional
5 5
6 from _sherpa_onnx import ( 6 from _sherpa_onnx import (
  7 + FeatureExtractorConfig,
7 OfflineCtcFstDecoderConfig, 8 OfflineCtcFstDecoderConfig,
8 - OfflineFeatureExtractorConfig,  
9 OfflineModelConfig, 9 OfflineModelConfig,
10 OfflineNemoEncDecCtcModelConfig, 10 OfflineNemoEncDecCtcModelConfig,
11 OfflineParaformerModelConfig, 11 OfflineParaformerModelConfig,
@@ -51,6 +51,7 @@ class OfflineRecognizer(object): @@ -51,6 +51,7 @@ class OfflineRecognizer(object):
51 blank_penalty: float = 0.0, 51 blank_penalty: float = 0.0,
52 debug: bool = False, 52 debug: bool = False,
53 provider: str = "cpu", 53 provider: str = "cpu",
  54 + model_type: str = "transducer",
54 ): 55 ):
55 """ 56 """
56 Please refer to 57 Please refer to
@@ -106,10 +107,10 @@ class OfflineRecognizer(object): @@ -106,10 +107,10 @@ class OfflineRecognizer(object):
106 num_threads=num_threads, 107 num_threads=num_threads,
107 debug=debug, 108 debug=debug,
108 provider=provider, 109 provider=provider,
109 - model_type="transducer", 110 + model_type=model_type,
110 ) 111 )
111 112
112 - feat_config = OfflineFeatureExtractorConfig( 113 + feat_config = FeatureExtractorConfig(
113 sampling_rate=sample_rate, 114 sampling_rate=sample_rate,
114 feature_dim=feature_dim, 115 feature_dim=feature_dim,
115 ) 116 )
@@ -182,7 +183,7 @@ class OfflineRecognizer(object): @@ -182,7 +183,7 @@ class OfflineRecognizer(object):
182 model_type="paraformer", 183 model_type="paraformer",
183 ) 184 )
184 185
185 - feat_config = OfflineFeatureExtractorConfig( 186 + feat_config = FeatureExtractorConfig(
186 sampling_rate=sample_rate, 187 sampling_rate=sample_rate,
187 feature_dim=feature_dim, 188 feature_dim=feature_dim,
188 ) 189 )
@@ -246,7 +247,7 @@ class OfflineRecognizer(object): @@ -246,7 +247,7 @@ class OfflineRecognizer(object):
246 model_type="nemo_ctc", 247 model_type="nemo_ctc",
247 ) 248 )
248 249
249 - feat_config = OfflineFeatureExtractorConfig( 250 + feat_config = FeatureExtractorConfig(
250 sampling_rate=sample_rate, 251 sampling_rate=sample_rate,
251 feature_dim=feature_dim, 252 feature_dim=feature_dim,
252 ) 253 )
@@ -326,7 +327,7 @@ class OfflineRecognizer(object): @@ -326,7 +327,7 @@ class OfflineRecognizer(object):
326 model_type="whisper", 327 model_type="whisper",
327 ) 328 )
328 329
329 - feat_config = OfflineFeatureExtractorConfig( 330 + feat_config = FeatureExtractorConfig(
330 sampling_rate=16000, 331 sampling_rate=16000,
331 feature_dim=80, 332 feature_dim=80,
332 ) 333 )
@@ -389,7 +390,7 @@ class OfflineRecognizer(object): @@ -389,7 +390,7 @@ class OfflineRecognizer(object):
389 model_type="tdnn", 390 model_type="tdnn",
390 ) 391 )
391 392
392 - feat_config = OfflineFeatureExtractorConfig( 393 + feat_config = FeatureExtractorConfig(
393 sampling_rate=sample_rate, 394 sampling_rate=sample_rate,
394 feature_dim=feature_dim, 395 feature_dim=feature_dim,
395 ) 396 )
@@ -453,7 +454,7 @@ class OfflineRecognizer(object): @@ -453,7 +454,7 @@ class OfflineRecognizer(object):
453 model_type="wenet_ctc", 454 model_type="wenet_ctc",
454 ) 455 )
455 456
456 - feat_config = OfflineFeatureExtractorConfig( 457 + feat_config = FeatureExtractorConfig(
457 sampling_rate=sample_rate, 458 sampling_rate=sample_rate,
458 feature_dim=feature_dim, 459 feature_dim=feature_dim,
459 ) 460 )