Committed by
GitHub
Add C++ and Python API for Dolphin CTC models (#2085)
正在显示
27 个修改的文件
包含
671 行增加
和
26 行删除
| @@ -15,6 +15,39 @@ echo "PATH: $PATH" | @@ -15,6 +15,39 @@ echo "PATH: $PATH" | ||
| 15 | 15 | ||
| 16 | which $EXE | 16 | which $EXE |
| 17 | 17 | ||
| 18 | +for type in base small; do | ||
| 19 | + log "------------------------------------------------------------" | ||
| 20 | + log "Run Dolphin CTC models ($type int8)" | ||
| 21 | + log "------------------------------------------------------------" | ||
| 22 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-dolphin-$type-ctc-multi-lang-int8-2025-04-02.tar.bz2 | ||
| 23 | + tar xvf sherpa-onnx-dolphin-$type-ctc-multi-lang-int8-2025-04-02.tar.bz2 | ||
| 24 | + rm sherpa-onnx-dolphin-$type-ctc-multi-lang-int8-2025-04-02.tar.bz2 | ||
| 25 | + | ||
| 26 | + $EXE \ | ||
| 27 | + --dolphin-model=./sherpa-onnx-dolphin-$type-ctc-multi-lang-int8-2025-04-02/model.int8.onnx \ | ||
| 28 | + --tokens=./sherpa-onnx-dolphin-$type-ctc-multi-lang-int8-2025-04-02/tokens.txt \ | ||
| 29 | + --debug=1 \ | ||
| 30 | + ./sherpa-onnx-dolphin-$type-ctc-multi-lang-int8-2025-04-02/test_wavs/0.wav | ||
| 31 | + | ||
| 32 | + rm -rf sherpa-onnx-dolphin-$type-ctc-multi-lang-int8-2025-04-02 | ||
| 33 | + | ||
| 34 | + log "------------------------------------------------------------" | ||
| 35 | + log "Run Dolphin CTC models ($type)" | ||
| 36 | + log "------------------------------------------------------------" | ||
| 37 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02.tar.bz2 | ||
| 38 | + tar xvf sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02.tar.bz2 | ||
| 39 | + rm sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02.tar.bz2 | ||
| 40 | + | ||
| 41 | + $EXE \ | ||
| 42 | + --dolphin-model=./sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02/model.onnx \ | ||
| 43 | + --tokens=./sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02/tokens.txt \ | ||
| 44 | + --debug=1 \ | ||
| 45 | + ./sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02/test_wavs/0.wav | ||
| 46 | + | ||
| 47 | + rm -rf sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02 | ||
| 48 | +done | ||
| 49 | + | ||
| 50 | + | ||
| 18 | log "------------------------------------------------------------" | 51 | log "------------------------------------------------------------" |
| 19 | log "Run NeMo GigaAM Russian models" | 52 | log "Run NeMo GigaAM Russian models" |
| 20 | log "------------------------------------------------------------" | 53 | log "------------------------------------------------------------" |
| @@ -8,6 +8,15 @@ log() { | @@ -8,6 +8,15 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +log "test offline dolphin ctc" | ||
| 12 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02.tar.bz2 | ||
| 13 | +tar xvf sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02.tar.bz2 | ||
| 14 | +rm sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02.tar.bz2 | ||
| 15 | + | ||
| 16 | +python3 ./python-api-examples/offline-dolphin-ctc-decode-files.py | ||
| 17 | + | ||
| 18 | +rm -rf sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02 | ||
| 19 | + | ||
| 11 | log "test offline speech enhancement (GTCRN)" | 20 | log "test offline speech enhancement (GTCRN)" |
| 12 | 21 | ||
| 13 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx | 22 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx |
| 1 | +name: export-dolphin-ctc-to-onnx | ||
| 2 | + | ||
| 3 | +on: | ||
| 4 | + workflow_dispatch: | ||
| 5 | + | ||
| 6 | +concurrency: | ||
| 7 | + group: export-dolphin-ctc-to-onnx-${{ github.ref }} | ||
| 8 | + cancel-in-progress: true | ||
| 9 | + | ||
| 10 | +jobs: | ||
| 11 | + export-dolphin-ctc-to-onnx: | ||
| 12 | + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' | ||
| 13 | + name: ${{ matrix.model_type }} | ||
| 14 | + runs-on: ${{ matrix.os }} | ||
| 15 | + strategy: | ||
| 16 | + fail-fast: false | ||
| 17 | + matrix: | ||
| 18 | + os: [macos-latest] | ||
| 19 | + model_type: [small, base] | ||
| 20 | + | ||
| 21 | + steps: | ||
| 22 | + - uses: actions/checkout@v4 | ||
| 23 | + | ||
| 24 | + - name: Download ${{ matrix.model_type }} | ||
| 25 | + shell: bash | ||
| 26 | + run: | | ||
| 27 | + git lfs install | ||
| 28 | + type=${{ matrix.model_type }} | ||
| 29 | + | ||
| 30 | + git clone https://huggingface.co/csukuangfj/sherpa-onnx-dolphin-$type-ctc-multi-lang-int8-2025-04-02 | ||
| 31 | + git clone https://huggingface.co/csukuangfj/sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02 | ||
| 32 | + | ||
| 33 | + rm -rf sherpa-onnx-dolphin-*/.git* | ||
| 34 | + | ||
| 35 | + ls -lha sherpa-onnx-dolphin-*/ | ||
| 36 | + | ||
| 37 | + tar cjfv sherpa-onnx-dolphin-$type-ctc-multi-lang-int8-2025-04-02.tar.bz2 sherpa-onnx-dolphin-$type-ctc-multi-lang-int8-2025-04-02 | ||
| 38 | + tar cjfv sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02.tar.bz2 sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02 | ||
| 39 | + | ||
| 40 | + - name: Release | ||
| 41 | + uses: svenstaro/upload-release-action@v2 | ||
| 42 | + with: | ||
| 43 | + file_glob: true | ||
| 44 | + file: ./*.tar.bz2 | ||
| 45 | + overwrite: true | ||
| 46 | + repo_name: k2-fsa/sherpa-onnx | ||
| 47 | + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} | ||
| 48 | + tag: asr-models |
| @@ -205,6 +205,16 @@ jobs: | @@ -205,6 +205,16 @@ jobs: | ||
| 205 | overwrite: true | 205 | overwrite: true |
| 206 | file: sherpa-onnx-*.tar.bz2 | 206 | file: sherpa-onnx-*.tar.bz2 |
| 207 | 207 | ||
| 208 | + - name: Test offline CTC | ||
| 209 | + shell: bash | ||
| 210 | + run: | | ||
| 211 | + du -h -d1 . | ||
| 212 | + export PATH=$PWD/build/bin:$PATH | ||
| 213 | + export EXE=sherpa-onnx-offline | ||
| 214 | + | ||
| 215 | + .github/scripts/test-offline-ctc.sh | ||
| 216 | + du -h -d1 . | ||
| 217 | + | ||
| 208 | - name: Test offline speech denoiser | 218 | - name: Test offline speech denoiser |
| 209 | shell: bash | 219 | shell: bash |
| 210 | run: | | 220 | run: | |
| @@ -249,16 +259,6 @@ jobs: | @@ -249,16 +259,6 @@ jobs: | ||
| 249 | .github/scripts/test-offline-moonshine.sh | 259 | .github/scripts/test-offline-moonshine.sh |
| 250 | du -h -d1 . | 260 | du -h -d1 . |
| 251 | 261 | ||
| 252 | - - name: Test offline CTC | ||
| 253 | - shell: bash | ||
| 254 | - run: | | ||
| 255 | - du -h -d1 . | ||
| 256 | - export PATH=$PWD/build/bin:$PATH | ||
| 257 | - export EXE=sherpa-onnx-offline | ||
| 258 | - | ||
| 259 | - .github/scripts/test-offline-ctc.sh | ||
| 260 | - du -h -d1 . | ||
| 261 | - | ||
| 262 | - name: Test C++ API | 262 | - name: Test C++ API |
| 263 | shell: bash | 263 | shell: bash |
| 264 | run: | | 264 | run: | |
| @@ -162,6 +162,14 @@ jobs: | @@ -162,6 +162,14 @@ jobs: | ||
| 162 | overwrite: true | 162 | overwrite: true |
| 163 | file: sherpa-onnx-*osx-universal2*.tar.bz2 | 163 | file: sherpa-onnx-*osx-universal2*.tar.bz2 |
| 164 | 164 | ||
| 165 | + - name: Test offline CTC | ||
| 166 | + shell: bash | ||
| 167 | + run: | | ||
| 168 | + export PATH=$PWD/build/bin:$PATH | ||
| 169 | + export EXE=sherpa-onnx-offline | ||
| 170 | + | ||
| 171 | + .github/scripts/test-offline-ctc.sh | ||
| 172 | + | ||
| 165 | - name: Test offline speech denoiser | 173 | - name: Test offline speech denoiser |
| 166 | shell: bash | 174 | shell: bash |
| 167 | run: | | 175 | run: | |
| @@ -226,14 +234,6 @@ jobs: | @@ -226,14 +234,6 @@ jobs: | ||
| 226 | 234 | ||
| 227 | .github/scripts/test-online-punctuation.sh | 235 | .github/scripts/test-online-punctuation.sh |
| 228 | 236 | ||
| 229 | - - name: Test offline CTC | ||
| 230 | - shell: bash | ||
| 231 | - run: | | ||
| 232 | - export PATH=$PWD/build/bin:$PATH | ||
| 233 | - export EXE=sherpa-onnx-offline | ||
| 234 | - | ||
| 235 | - .github/scripts/test-offline-ctc.sh | ||
| 236 | - | ||
| 237 | - name: Test online CTC | 237 | - name: Test online CTC |
| 238 | shell: bash | 238 | shell: bash |
| 239 | run: | | 239 | run: | |
| 1 | +if (CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") | ||
| 2 | + set(CMAKE_POLICY_VERSION_MINIMUM 3.5) | ||
| 3 | +endif() | ||
| 4 | + | ||
| 1 | cmake_minimum_required(VERSION 3.13 FATAL_ERROR) | 5 | cmake_minimum_required(VERSION 3.13 FATAL_ERROR) |
| 2 | 6 | ||
| 3 | set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14" CACHE STRING "Minimum OS X deployment version. Used only for macOS") | 7 | set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14" CACHE STRING "Minimum OS X deployment version. Used only for macOS") |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +""" | ||
| 4 | +This file shows how to use a non-streaming CTC model from Dolphin | ||
| 5 | +to decode files. | ||
| 6 | + | ||
| 7 | +Please download model files from | ||
| 8 | +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models | ||
| 9 | +""" | ||
| 10 | + | ||
| 11 | +from pathlib import Path | ||
| 12 | +import time | ||
| 13 | + | ||
| 14 | +import sherpa_onnx | ||
| 15 | +import soundfile as sf | ||
| 16 | + | ||
| 17 | + | ||
| 18 | +def create_recognizer(): | ||
| 19 | + model = "./sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02/model.int8.onnx" | ||
| 20 | + tokens = "./sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02/tokens.txt" | ||
| 21 | + test_wav = ( | ||
| 22 | + "./sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02/test_wavs/0.wav" | ||
| 23 | + ) | ||
| 24 | + | ||
| 25 | + if not Path(model).is_file() or not Path(test_wav).is_file(): | ||
| 26 | + raise ValueError( | ||
| 27 | + """Please download model files from | ||
| 28 | + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models | ||
| 29 | + """ | ||
| 30 | + ) | ||
| 31 | + return ( | ||
| 32 | + sherpa_onnx.OfflineRecognizer.from_dolphin_ctc( | ||
| 33 | + model=model, | ||
| 34 | + tokens=tokens, | ||
| 35 | + debug=True, | ||
| 36 | + ), | ||
| 37 | + test_wav, | ||
| 38 | + ) | ||
| 39 | + | ||
| 40 | + | ||
| 41 | +def main(): | ||
| 42 | + recognizer, wave_filename = create_recognizer() | ||
| 43 | + | ||
| 44 | + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) | ||
| 45 | + audio = audio[:, 0] # only use the first channel | ||
| 46 | + | ||
| 47 | + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] | ||
| 48 | + # sample_rate does not need to be 16000 Hz | ||
| 49 | + | ||
| 50 | + start = time.time() | ||
| 51 | + stream = recognizer.create_stream() | ||
| 52 | + stream.accept_waveform(sample_rate, audio) | ||
| 53 | + recognizer.decode_stream(stream) | ||
| 54 | + end = time.time() | ||
| 55 | + | ||
| 56 | + print(wave_filename) | ||
| 57 | + print(stream.result) | ||
| 58 | + | ||
| 59 | + elapsed_seconds = end - start | ||
| 60 | + audio_duration = len(audio) / sample_rate | ||
| 61 | + real_time_factor = elapsed_seconds / audio_duration | ||
| 62 | + | ||
| 63 | + print(f"Elapsed seconds: {elapsed_seconds:.3f}") | ||
| 64 | + print(f"Audio duration in seconds: {audio_duration:.3f}") | ||
| 65 | + print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}") | ||
| 66 | + | ||
| 67 | + | ||
| 68 | +if __name__ == "__main__": | ||
| 69 | + main() |
| @@ -27,6 +27,8 @@ set(sources | @@ -27,6 +27,8 @@ set(sources | ||
| 27 | offline-ctc-fst-decoder.cc | 27 | offline-ctc-fst-decoder.cc |
| 28 | offline-ctc-greedy-search-decoder.cc | 28 | offline-ctc-greedy-search-decoder.cc |
| 29 | offline-ctc-model.cc | 29 | offline-ctc-model.cc |
| 30 | + offline-dolphin-model-config.cc | ||
| 31 | + offline-dolphin-model.cc | ||
| 30 | offline-fire-red-asr-greedy-search-decoder.cc | 32 | offline-fire-red-asr-greedy-search-decoder.cc |
| 31 | offline-fire-red-asr-model-config.cc | 33 | offline-fire-red-asr-model-config.cc |
| 32 | offline-fire-red-asr-model.cc | 34 | offline-fire-red-asr-model.cc |
| @@ -20,6 +20,7 @@ | @@ -20,6 +20,7 @@ | ||
| 20 | 20 | ||
| 21 | #include "sherpa-onnx/csrc/file-utils.h" | 21 | #include "sherpa-onnx/csrc/file-utils.h" |
| 22 | #include "sherpa-onnx/csrc/macros.h" | 22 | #include "sherpa-onnx/csrc/macros.h" |
| 23 | +#include "sherpa-onnx/csrc/offline-dolphin-model.h" | ||
| 23 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" | 24 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" |
| 24 | #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" | 25 | #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" |
| 25 | #include "sherpa-onnx/csrc/offline-telespeech-ctc-model.h" | 26 | #include "sherpa-onnx/csrc/offline-telespeech-ctc-model.h" |
| @@ -110,6 +111,10 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -110,6 +111,10 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 110 | 111 | ||
| 111 | std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | 112 | std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( |
| 112 | const OfflineModelConfig &config) { | 113 | const OfflineModelConfig &config) { |
| 114 | + if (!config.dolphin.model.empty()) { | ||
| 115 | + return std::make_unique<OfflineDolphinModel>(config); | ||
| 116 | + } | ||
| 117 | + | ||
| 113 | // TODO(fangjun): Refactor it. We don't need to use model_type here | 118 | // TODO(fangjun): Refactor it. We don't need to use model_type here |
| 114 | ModelType model_type = ModelType::kUnknown; | 119 | ModelType model_type = ModelType::kUnknown; |
| 115 | 120 | ||
| @@ -160,6 +165,10 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -160,6 +165,10 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 160 | template <typename Manager> | 165 | template <typename Manager> |
| 161 | std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | 166 | std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( |
| 162 | Manager *mgr, const OfflineModelConfig &config) { | 167 | Manager *mgr, const OfflineModelConfig &config) { |
| 168 | + if (!config.dolphin.model.empty()) { | ||
| 169 | + return std::make_unique<OfflineDolphinModel>(mgr, config); | ||
| 170 | + } | ||
| 171 | + | ||
| 163 | // TODO(fangjun): Refactor it. We don't need to use model_type here | 172 | // TODO(fangjun): Refactor it. We don't need to use model_type here |
| 164 | ModelType model_type = ModelType::kUnknown; | 173 | ModelType model_type = ModelType::kUnknown; |
| 165 | 174 |
| @@ -64,6 +64,10 @@ class OfflineCtcModel { | @@ -64,6 +64,10 @@ class OfflineCtcModel { | ||
| 64 | // return true for models from https://github.com/salute-developers/GigaAM | 64 | // return true for models from https://github.com/salute-developers/GigaAM |
| 65 | // return false otherwise | 65 | // return false otherwise |
| 66 | virtual bool IsGigaAM() const { return false; } | 66 | virtual bool IsGigaAM() const { return false; } |
| 67 | + | ||
| 68 | + // For Dolphin models, they use global CMVN | ||
| 69 | + virtual void NormalizeFeatures(float *features, int32_t num_frames, | ||
| 70 | + int32_t feat_dim) const {} | ||
| 67 | }; | 71 | }; |
| 68 | 72 | ||
| 69 | } // namespace sherpa_onnx | 73 | } // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-dolphin-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-dolphin-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OfflineDolphinModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("dolphin-model", &model, | ||
| 14 | + "Path to model.onnx of Dolphin CTC branch."); | ||
| 15 | +} | ||
| 16 | + | ||
| 17 | +bool OfflineDolphinModelConfig::Validate() const { | ||
| 18 | + if (!FileExists(model)) { | ||
| 19 | + SHERPA_ONNX_LOGE("Dolphin model '%s' does not exist", model.c_str()); | ||
| 20 | + return false; | ||
| 21 | + } | ||
| 22 | + | ||
| 23 | + return true; | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +std::string OfflineDolphinModelConfig::ToString() const { | ||
| 27 | + std::ostringstream os; | ||
| 28 | + | ||
| 29 | + os << "OfflineDolphinModelConfig("; | ||
| 30 | + os << "model=\"" << model << "\")"; | ||
| 31 | + | ||
| 32 | + return os.str(); | ||
| 33 | +} | ||
| 34 | + | ||
| 35 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-dolphin-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_DOLPHIN_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_DOLPHIN_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OfflineDolphinModelConfig { | ||
| 14 | + std::string model; | ||
| 15 | + | ||
| 16 | + OfflineDolphinModelConfig() = default; | ||
| 17 | + explicit OfflineDolphinModelConfig(const std::string &model) : model(model) {} | ||
| 18 | + | ||
| 19 | + void Register(ParseOptions *po); | ||
| 20 | + bool Validate() const; | ||
| 21 | + | ||
| 22 | + std::string ToString() const; | ||
| 23 | +}; | ||
| 24 | + | ||
| 25 | +} // namespace sherpa_onnx | ||
| 26 | + | ||
| 27 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_DOLPHIN_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/csrc/offline-dolphin-model-meta-data.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_DOLPHIN_MODEL_META_DATA_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_DOLPHIN_MODEL_META_DATA_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +struct OfflineDolphinModelMetaData { | ||
| 13 | + int32_t vocab_size; | ||
| 14 | + int32_t subsampling_factor = 4; | ||
| 15 | + std::vector<float> mean; | ||
| 16 | + std::vector<float> inv_stddev; | ||
| 17 | +}; | ||
| 18 | + | ||
| 19 | +} // namespace sherpa_onnx | ||
| 20 | + | ||
| 21 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_DOLPHIN_MODEL_META_DATA_H_ |
sherpa-onnx/csrc/offline-dolphin-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-dolphin-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-dolphin-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | + | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 16 | +#if __OHOS__ | ||
| 17 | +#include "rawfile/raw_file_manager.h" | ||
| 18 | +#endif | ||
| 19 | + | ||
| 20 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 21 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 22 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 23 | +#include "sherpa-onnx/csrc/session.h" | ||
| 24 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 25 | + | ||
| 26 | +namespace sherpa_onnx { | ||
| 27 | + | ||
| 28 | +class OfflineDolphinModel::Impl { | ||
| 29 | + public: | ||
| 30 | + explicit Impl(const OfflineModelConfig &config) | ||
| 31 | + : config_(config), | ||
| 32 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 33 | + sess_opts_(GetSessionOptions(config)), | ||
| 34 | + allocator_{} { | ||
| 35 | + auto buf = ReadFile(config_.dolphin.model); | ||
| 36 | + Init(buf.data(), buf.size()); | ||
| 37 | + } | ||
| 38 | + | ||
| 39 | + template <typename Manager> | ||
| 40 | + Impl(Manager *mgr, const OfflineModelConfig &config) | ||
| 41 | + : config_(config), | ||
| 42 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 43 | + sess_opts_(GetSessionOptions(config)), | ||
| 44 | + allocator_{} { | ||
| 45 | + auto buf = ReadFile(mgr, config_.dolphin.model); | ||
| 46 | + Init(buf.data(), buf.size()); | ||
| 47 | + } | ||
| 48 | + | ||
| 49 | + std::vector<Ort::Value> Forward(Ort::Value features, | ||
| 50 | + Ort::Value features_length) { | ||
| 51 | + std::array<Ort::Value, 2> inputs = { | ||
| 52 | + std::move(features), | ||
| 53 | + std::move(features_length), | ||
| 54 | + }; | ||
| 55 | + | ||
| 56 | + return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 57 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + int32_t VocabSize() const { return meta_data_.vocab_size; } | ||
| 61 | + | ||
| 62 | + int32_t SubsamplingFactor() const { return meta_data_.subsampling_factor; } | ||
| 63 | + | ||
| 64 | + void NormalizeFeatures(float *features, int32_t num_frames, | ||
| 65 | + int32_t feat_dim) const { | ||
| 66 | + auto p = features; | ||
| 67 | + const auto &mean = meta_data_.mean; | ||
| 68 | + const auto &invstd = meta_data_.inv_stddev; | ||
| 69 | + | ||
| 70 | + for (int32_t f = 0; f < num_frames; ++f) { | ||
| 71 | + for (int32_t d = 0; d < feat_dim; ++d) { | ||
| 72 | + p[d] = (p[d] - mean[d]) * invstd[d]; | ||
| 73 | + } | ||
| 74 | + p += feat_dim; | ||
| 75 | + } | ||
| 76 | + } | ||
| 77 | + | ||
| 78 | + OrtAllocator *Allocator() { return allocator_; } | ||
| 79 | + | ||
| 80 | + private: | ||
| 81 | + void Init(void *model_data, size_t model_data_length) { | ||
| 82 | + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 83 | + sess_opts_); | ||
| 84 | + | ||
| 85 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 86 | + | ||
| 87 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 88 | + | ||
| 89 | + // get meta data | ||
| 90 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 91 | + if (config_.debug) { | ||
| 92 | + std::ostringstream os; | ||
| 93 | + PrintModelMetadata(os, meta_data); | ||
| 94 | +#if __OHOS__ | ||
| 95 | + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); | ||
| 96 | +#else | ||
| 97 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 98 | +#endif | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 102 | + SHERPA_ONNX_READ_META_DATA(meta_data_.vocab_size, "vocab_size"); | ||
| 103 | + | ||
| 104 | + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.mean, "mean"); | ||
| 105 | + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev, "invstd"); | ||
| 106 | + } | ||
| 107 | + | ||
| 108 | + private: | ||
| 109 | + OfflineModelConfig config_; | ||
| 110 | + Ort::Env env_; | ||
| 111 | + Ort::SessionOptions sess_opts_; | ||
| 112 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 113 | + | ||
| 114 | + std::unique_ptr<Ort::Session> sess_; | ||
| 115 | + | ||
| 116 | + std::vector<std::string> input_names_; | ||
| 117 | + std::vector<const char *> input_names_ptr_; | ||
| 118 | + | ||
| 119 | + std::vector<std::string> output_names_; | ||
| 120 | + std::vector<const char *> output_names_ptr_; | ||
| 121 | + | ||
| 122 | + OfflineDolphinModelMetaData meta_data_; | ||
| 123 | +}; | ||
| 124 | + | ||
| 125 | +OfflineDolphinModel::OfflineDolphinModel(const OfflineModelConfig &config) | ||
| 126 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 127 | + | ||
| 128 | +template <typename Manager> | ||
| 129 | +OfflineDolphinModel::OfflineDolphinModel(Manager *mgr, | ||
| 130 | + const OfflineModelConfig &config) | ||
| 131 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 132 | + | ||
| 133 | +OfflineDolphinModel::~OfflineDolphinModel() = default; | ||
| 134 | + | ||
| 135 | +std::vector<Ort::Value> OfflineDolphinModel::Forward( | ||
| 136 | + Ort::Value features, Ort::Value features_length) { | ||
| 137 | + return impl_->Forward(std::move(features), std::move(features_length)); | ||
| 138 | +} | ||
| 139 | + | ||
| 140 | +int32_t OfflineDolphinModel::VocabSize() const { return impl_->VocabSize(); } | ||
| 141 | + | ||
| 142 | +int32_t OfflineDolphinModel::SubsamplingFactor() const { | ||
| 143 | + return impl_->SubsamplingFactor(); | ||
| 144 | +} | ||
| 145 | + | ||
| 146 | +void OfflineDolphinModel::NormalizeFeatures(float *features, int32_t num_frames, | ||
| 147 | + int32_t feat_dim) const { | ||
| 148 | + return impl_->NormalizeFeatures(features, num_frames, feat_dim); | ||
| 149 | +} | ||
| 150 | + | ||
| 151 | +OrtAllocator *OfflineDolphinModel::Allocator() const { | ||
| 152 | + return impl_->Allocator(); | ||
| 153 | +} | ||
| 154 | + | ||
| 155 | +#if __ANDROID_API__ >= 9 | ||
| 156 | +template OfflineDolphinModel::OfflineDolphinModel( | ||
| 157 | + AAssetManager *mgr, const OfflineModelConfig &config); | ||
| 158 | +#endif | ||
| 159 | + | ||
| 160 | +#if __OHOS__ | ||
| 161 | +template OfflineDolphinModel::OfflineDolphinModel( | ||
| 162 | + NativeResourceManager *mgr, const OfflineModelConfig &config); | ||
| 163 | +#endif | ||
| 164 | + | ||
| 165 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-dolphin-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-dolphin-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_DOLPHIN_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_DOLPHIN_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 11 | +#include "sherpa-onnx/csrc/offline-ctc-model.h" | ||
| 12 | +#include "sherpa-onnx/csrc/offline-dolphin-model-meta-data.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +class OfflineDolphinModel : public OfflineCtcModel { | ||
| 18 | + public: | ||
| 19 | + explicit OfflineDolphinModel(const OfflineModelConfig &config); | ||
| 20 | + | ||
| 21 | + template <typename Manager> | ||
| 22 | + OfflineDolphinModel(Manager *mgr, const OfflineModelConfig &config); | ||
| 23 | + | ||
| 24 | + ~OfflineDolphinModel() override; | ||
| 25 | + | ||
| 26 | + /** Run the forward method of the model. | ||
| 27 | + * | ||
| 28 | + * @param features A tensor of shape (N, T, C). | ||
| 29 | + * @param features_length A 1-D tensor of shape (N,) containing number of | ||
| 30 | + * valid frames in `features` before padding. | ||
| 31 | + * Its dtype is int64_t. | ||
| 32 | + * | ||
| 33 | + * @return Return a vector containing: | ||
| 34 | + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). | ||
| 35 | + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t | ||
| 36 | + */ | ||
| 37 | + std::vector<Ort::Value> Forward(Ort::Value features, | ||
| 38 | + Ort::Value features_length) override; | ||
| 39 | + | ||
| 40 | + /** Return the vocabulary size of the model | ||
| 41 | + */ | ||
| 42 | + int32_t VocabSize() const override; | ||
| 43 | + | ||
| 44 | + /** SubsamplingFactor of the model | ||
| 45 | + * | ||
| 46 | + * For Citrinet, the subsampling factor is usually 4. | ||
| 47 | + * For Conformer CTC, the subsampling factor is usually 8. | ||
| 48 | + */ | ||
| 49 | + int32_t SubsamplingFactor() const override; | ||
| 50 | + | ||
| 51 | + /** Return an allocator for allocating memory | ||
| 52 | + */ | ||
| 53 | + OrtAllocator *Allocator() const override; | ||
| 54 | + | ||
| 55 | + bool SupportBatchProcessing() const override { return true; } | ||
| 56 | + | ||
| 57 | + void NormalizeFeatures(float *features, int32_t num_frames, | ||
| 58 | + int32_t feat_dim) const override; | ||
| 59 | + | ||
| 60 | + private: | ||
| 61 | + class Impl; | ||
| 62 | + std::unique_ptr<Impl> impl_; | ||
| 63 | +}; | ||
| 64 | + | ||
| 65 | +} // namespace sherpa_onnx | ||
| 66 | + | ||
| 67 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_DOLPHIN_MODEL_H_ |
| @@ -21,6 +21,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { | @@ -21,6 +21,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { | ||
| 21 | wenet_ctc.Register(po); | 21 | wenet_ctc.Register(po); |
| 22 | sense_voice.Register(po); | 22 | sense_voice.Register(po); |
| 23 | moonshine.Register(po); | 23 | moonshine.Register(po); |
| 24 | + dolphin.Register(po); | ||
| 24 | 25 | ||
| 25 | po->Register("telespeech-ctc", &telespeech_ctc, | 26 | po->Register("telespeech-ctc", &telespeech_ctc, |
| 26 | "Path to model.onnx for telespeech ctc"); | 27 | "Path to model.onnx for telespeech ctc"); |
| @@ -109,6 +110,10 @@ bool OfflineModelConfig::Validate() const { | @@ -109,6 +110,10 @@ bool OfflineModelConfig::Validate() const { | ||
| 109 | return moonshine.Validate(); | 110 | return moonshine.Validate(); |
| 110 | } | 111 | } |
| 111 | 112 | ||
| 113 | + if (!dolphin.model.empty()) { | ||
| 114 | + return dolphin.Validate(); | ||
| 115 | + } | ||
| 116 | + | ||
| 112 | if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) { | 117 | if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) { |
| 113 | SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist", | 118 | SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist", |
| 114 | telespeech_ctc.c_str()); | 119 | telespeech_ctc.c_str()); |
| @@ -136,6 +141,7 @@ std::string OfflineModelConfig::ToString() const { | @@ -136,6 +141,7 @@ std::string OfflineModelConfig::ToString() const { | ||
| 136 | os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; | 141 | os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; |
| 137 | os << "sense_voice=" << sense_voice.ToString() << ", "; | 142 | os << "sense_voice=" << sense_voice.ToString() << ", "; |
| 138 | os << "moonshine=" << moonshine.ToString() << ", "; | 143 | os << "moonshine=" << moonshine.ToString() << ", "; |
| 144 | + os << "dolphin=" << dolphin.ToString() << ", "; | ||
| 139 | os << "telespeech_ctc=\"" << telespeech_ctc << "\", "; | 145 | os << "telespeech_ctc=\"" << telespeech_ctc << "\", "; |
| 140 | os << "tokens=\"" << tokens << "\", "; | 146 | os << "tokens=\"" << tokens << "\", "; |
| 141 | os << "num_threads=" << num_threads << ", "; | 147 | os << "num_threads=" << num_threads << ", "; |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | 6 | ||
| 7 | #include <string> | 7 | #include <string> |
| 8 | 8 | ||
| 9 | +#include "sherpa-onnx/csrc/offline-dolphin-model-config.h" | ||
| 9 | #include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h" | 10 | #include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h" |
| 10 | #include "sherpa-onnx/csrc/offline-moonshine-model-config.h" | 11 | #include "sherpa-onnx/csrc/offline-moonshine-model-config.h" |
| 11 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" | 12 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" |
| @@ -30,6 +31,7 @@ struct OfflineModelConfig { | @@ -30,6 +31,7 @@ struct OfflineModelConfig { | ||
| 30 | OfflineWenetCtcModelConfig wenet_ctc; | 31 | OfflineWenetCtcModelConfig wenet_ctc; |
| 31 | OfflineSenseVoiceModelConfig sense_voice; | 32 | OfflineSenseVoiceModelConfig sense_voice; |
| 32 | OfflineMoonshineModelConfig moonshine; | 33 | OfflineMoonshineModelConfig moonshine; |
| 34 | + OfflineDolphinModelConfig dolphin; | ||
| 33 | std::string telespeech_ctc; | 35 | std::string telespeech_ctc; |
| 34 | 36 | ||
| 35 | std::string tokens; | 37 | std::string tokens; |
| @@ -62,6 +64,7 @@ struct OfflineModelConfig { | @@ -62,6 +64,7 @@ struct OfflineModelConfig { | ||
| 62 | const OfflineWenetCtcModelConfig &wenet_ctc, | 64 | const OfflineWenetCtcModelConfig &wenet_ctc, |
| 63 | const OfflineSenseVoiceModelConfig &sense_voice, | 65 | const OfflineSenseVoiceModelConfig &sense_voice, |
| 64 | const OfflineMoonshineModelConfig &moonshine, | 66 | const OfflineMoonshineModelConfig &moonshine, |
| 67 | + const OfflineDolphinModelConfig &dolphin, | ||
| 65 | const std::string &telespeech_ctc, | 68 | const std::string &telespeech_ctc, |
| 66 | const std::string &tokens, int32_t num_threads, bool debug, | 69 | const std::string &tokens, int32_t num_threads, bool debug, |
| 67 | const std::string &provider, const std::string &model_type, | 70 | const std::string &provider, const std::string &model_type, |
| @@ -77,6 +80,7 @@ struct OfflineModelConfig { | @@ -77,6 +80,7 @@ struct OfflineModelConfig { | ||
| 77 | wenet_ctc(wenet_ctc), | 80 | wenet_ctc(wenet_ctc), |
| 78 | sense_voice(sense_voice), | 81 | sense_voice(sense_voice), |
| 79 | moonshine(moonshine), | 82 | moonshine(moonshine), |
| 83 | + dolphin(dolphin), | ||
| 80 | telespeech_ctc(telespeech_ctc), | 84 | telespeech_ctc(telespeech_ctc), |
| 81 | tokens(tokens), | 85 | tokens(tokens), |
| 82 | num_threads(num_threads), | 86 | num_threads(num_threads), |
| @@ -118,6 +118,19 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -118,6 +118,19 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 118 | } | 118 | } |
| 119 | } | 119 | } |
| 120 | 120 | ||
| 121 | + if (!config_.model_config.dolphin.model.empty()) { | ||
| 122 | + config_.feat_config.low_freq = 0; | ||
| 123 | + config_.feat_config.high_freq = 8000; | ||
| 124 | + config_.feat_config.remove_dc_offset = false; | ||
| 125 | + config_.feat_config.dither = 0; | ||
| 126 | + config_.feat_config.preemph_coeff = 0; | ||
| 127 | + config_.feat_config.window_type = "hann"; | ||
| 128 | + config_.feat_config.feature_dim = 80; | ||
| 129 | + config_.feat_config.is_librosa = true; | ||
| 130 | + config_.feat_config.frame_length_ms = 31.25; // 16000/512 = 31.25 | ||
| 131 | + config_.feat_config.snip_edges = false; | ||
| 132 | + } | ||
| 133 | + | ||
| 121 | if (!config_.model_config.wenet_ctc.model.empty()) { | 134 | if (!config_.model_config.wenet_ctc.model.empty()) { |
| 122 | // WeNet CTC models assume input samples are in the range | 135 | // WeNet CTC models assume input samples are in the range |
| 123 | // [-32768, 32767], so we set normalize_samples to false | 136 | // [-32768, 32767], so we set normalize_samples to false |
| @@ -157,7 +170,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -157,7 +170,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 157 | } else { | 170 | } else { |
| 158 | SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", | 171 | SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", |
| 159 | config_.decoding_method.c_str()); | 172 | config_.decoding_method.c_str()); |
| 160 | - exit(-1); | 173 | + SHERPA_ONNX_EXIT(-1); |
| 161 | } | 174 | } |
| 162 | } | 175 | } |
| 163 | 176 | ||
| @@ -166,7 +179,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -166,7 +179,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 166 | } | 179 | } |
| 167 | 180 | ||
| 168 | void DecodeStreams(OfflineStream **ss, int32_t n) const override { | 181 | void DecodeStreams(OfflineStream **ss, int32_t n) const override { |
| 169 | - if (!model_->SupportBatchProcessing()) { | 182 | + if (!model_->SupportBatchProcessing() || (n == 1)) { |
| 170 | // If the model does not support batch process, | 183 | // If the model does not support batch process, |
| 171 | // we process each stream independently. | 184 | // we process each stream independently. |
| 172 | for (int32_t i = 0; i != n; ++i) { | 185 | for (int32_t i = 0; i != n; ++i) { |
| @@ -190,6 +203,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -190,6 +203,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 190 | std::vector<float> f = ss[i]->GetFrames(); | 203 | std::vector<float> f = ss[i]->GetFrames(); |
| 191 | 204 | ||
| 192 | int32_t num_frames = f.size() / feat_dim; | 205 | int32_t num_frames = f.size() / feat_dim; |
| 206 | + | ||
| 207 | + model_->NormalizeFeatures(f.data(), num_frames, feat_dim); | ||
| 208 | + | ||
| 193 | features_vec[i] = std::move(f); | 209 | features_vec[i] = std::move(f); |
| 194 | 210 | ||
| 195 | features_length_vec[i] = num_frames; | 211 | features_length_vec[i] = num_frames; |
| @@ -241,6 +257,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -241,6 +257,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 241 | 257 | ||
| 242 | int32_t num_frames = f.size() / feat_dim; | 258 | int32_t num_frames = f.size() / feat_dim; |
| 243 | 259 | ||
| 260 | + model_->NormalizeFeatures(f.data(), num_frames, feat_dim); | ||
| 261 | + | ||
| 244 | std::array<int64_t, 3> shape = {1, num_frames, feat_dim}; | 262 | std::array<int64_t, 3> shape = {1, num_frames, feat_dim}; |
| 245 | 263 | ||
| 246 | Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), | 264 | Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), |
| @@ -49,7 +49,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -49,7 +49,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 49 | if (!config.model_config.nemo_ctc.model.empty() || | 49 | if (!config.model_config.nemo_ctc.model.empty() || |
| 50 | !config.model_config.zipformer_ctc.model.empty() || | 50 | !config.model_config.zipformer_ctc.model.empty() || |
| 51 | !config.model_config.tdnn.model.empty() || | 51 | !config.model_config.tdnn.model.empty() || |
| 52 | - !config.model_config.wenet_ctc.model.empty()) { | 52 | + !config.model_config.wenet_ctc.model.empty() || |
| 53 | + !config.model_config.dolphin.model.empty()) { | ||
| 53 | return std::make_unique<OfflineRecognizerCtcImpl>(config); | 54 | return std::make_unique<OfflineRecognizerCtcImpl>(config); |
| 54 | } | 55 | } |
| 55 | 56 | ||
| @@ -234,7 +235,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -234,7 +235,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 234 | if (!config.model_config.nemo_ctc.model.empty() || | 235 | if (!config.model_config.nemo_ctc.model.empty() || |
| 235 | !config.model_config.zipformer_ctc.model.empty() || | 236 | !config.model_config.zipformer_ctc.model.empty() || |
| 236 | !config.model_config.tdnn.model.empty() || | 237 | !config.model_config.tdnn.model.empty() || |
| 237 | - !config.model_config.wenet_ctc.model.empty()) { | 238 | + !config.model_config.wenet_ctc.model.empty() || |
| 239 | + !config.model_config.dolphin.model.empty()) { | ||
| 238 | return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); | 240 | return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); |
| 239 | } | 241 | } |
| 240 | 242 |
| @@ -23,9 +23,8 @@ struct OfflineSenseVoiceModelConfig { | @@ -23,9 +23,8 @@ struct OfflineSenseVoiceModelConfig { | ||
| 23 | bool use_itn = false; | 23 | bool use_itn = false; |
| 24 | 24 | ||
| 25 | OfflineSenseVoiceModelConfig() = default; | 25 | OfflineSenseVoiceModelConfig() = default; |
| 26 | - explicit OfflineSenseVoiceModelConfig(const std::string &model, | ||
| 27 | - const std::string &language, | ||
| 28 | - bool use_itn) | 26 | + OfflineSenseVoiceModelConfig(const std::string &model, |
| 27 | + const std::string &language, bool use_itn) | ||
| 29 | : model(model), language(language), use_itn(use_itn) {} | 28 | : model(model), language(language), use_itn(use_itn) {} |
| 30 | 29 | ||
| 31 | void Register(ParseOptions *po); | 30 | void Register(ParseOptions *po); |
| @@ -41,6 +41,9 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | @@ -41,6 +41,9 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | ||
| 41 | std::string text; | 41 | std::string text; |
| 42 | for (auto i : src.tokens) { | 42 | for (auto i : src.tokens) { |
| 43 | auto sym = sym_table[i]; | 43 | auto sym = sym_table[i]; |
| 44 | + if (sym == "<unk>") { | ||
| 45 | + continue; | ||
| 46 | + } | ||
| 44 | 47 | ||
| 45 | text.append(sym); | 48 | text.append(sym); |
| 46 | 49 |
| @@ -4,6 +4,8 @@ | @@ -4,6 +4,8 @@ | ||
| 4 | #ifndef SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_ | 4 | #ifndef SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_ |
| 5 | #define SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_ | 5 | #define SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_ |
| 6 | 6 | ||
| 7 | +#include <memory> | ||
| 8 | + | ||
| 7 | #include "rknn_api.h" // NOLINT | 9 | #include "rknn_api.h" // NOLINT |
| 8 | #include "sherpa-onnx/csrc/online-model-config.h" | 10 | #include "sherpa-onnx/csrc/online-model-config.h" |
| 9 | #include "sherpa-onnx/csrc/vad-model.h" | 11 | #include "sherpa-onnx/csrc/vad-model.h" |
| @@ -9,6 +9,7 @@ set(srcs | @@ -9,6 +9,7 @@ set(srcs | ||
| 9 | features.cc | 9 | features.cc |
| 10 | keyword-spotter.cc | 10 | keyword-spotter.cc |
| 11 | offline-ctc-fst-decoder-config.cc | 11 | offline-ctc-fst-decoder-config.cc |
| 12 | + offline-dolphin-model-config.cc | ||
| 12 | offline-fire-red-asr-model-config.cc | 13 | offline-fire-red-asr-model-config.cc |
| 13 | offline-lm-config.cc | 14 | offline-lm-config.cc |
| 14 | offline-model-config.cc | 15 | offline-model-config.cc |
| 1 | +// sherpa-onnx/python/csrc/offline-dolphin-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-dolphin-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/python/csrc/offline-dolphin-model-config.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void PybindOfflineDolphinModelConfig(py::module *m) { | ||
| 15 | + using PyClass = OfflineDolphinModelConfig; | ||
| 16 | + py::class_<PyClass>(*m, "OfflineDolphinModelConfig") | ||
| 17 | + .def(py::init<>()) | ||
| 18 | + .def(py::init<const std::string &>(), py::arg("model")) | ||
| 19 | + .def_readwrite("model", &PyClass::model) | ||
| 20 | + .def("__str__", &PyClass::ToString); | ||
| 21 | +} | ||
| 22 | + | ||
| 23 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/offline-dolphin-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_DOLPHIN_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_DOLPHIN_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOfflineDolphinModelConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_DOLPHIN_MODEL_CONFIG_H_ |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | #include <vector> | 8 | #include <vector> |
| 9 | 9 | ||
| 10 | #include "sherpa-onnx/csrc/offline-model-config.h" | 10 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| 11 | +#include "sherpa-onnx/python/csrc/offline-dolphin-model-config.h" | ||
| 11 | #include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h" | 12 | #include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h" |
| 12 | #include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h" | 13 | #include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h" |
| 13 | #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" | 14 | #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" |
| @@ -32,6 +33,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -32,6 +33,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 32 | PybindOfflineWenetCtcModelConfig(m); | 33 | PybindOfflineWenetCtcModelConfig(m); |
| 33 | PybindOfflineSenseVoiceModelConfig(m); | 34 | PybindOfflineSenseVoiceModelConfig(m); |
| 34 | PybindOfflineMoonshineModelConfig(m); | 35 | PybindOfflineMoonshineModelConfig(m); |
| 36 | + PybindOfflineDolphinModelConfig(m); | ||
| 35 | 37 | ||
| 36 | using PyClass = OfflineModelConfig; | 38 | using PyClass = OfflineModelConfig; |
| 37 | py::class_<PyClass>(*m, "OfflineModelConfig") | 39 | py::class_<PyClass>(*m, "OfflineModelConfig") |
| @@ -44,7 +46,8 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -44,7 +46,8 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 44 | const OfflineZipformerCtcModelConfig &, | 46 | const OfflineZipformerCtcModelConfig &, |
| 45 | const OfflineWenetCtcModelConfig &, | 47 | const OfflineWenetCtcModelConfig &, |
| 46 | const OfflineSenseVoiceModelConfig &, | 48 | const OfflineSenseVoiceModelConfig &, |
| 47 | - const OfflineMoonshineModelConfig &, const std::string &, | 49 | + const OfflineMoonshineModelConfig &, |
| 50 | + const OfflineDolphinModelConfig &, const std::string &, | ||
| 48 | const std::string &, int32_t, bool, const std::string &, | 51 | const std::string &, int32_t, bool, const std::string &, |
| 49 | const std::string &, const std::string &, | 52 | const std::string &, const std::string &, |
| 50 | const std::string &>(), | 53 | const std::string &>(), |
| @@ -58,6 +61,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -58,6 +61,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 58 | py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), | 61 | py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), |
| 59 | py::arg("sense_voice") = OfflineSenseVoiceModelConfig(), | 62 | py::arg("sense_voice") = OfflineSenseVoiceModelConfig(), |
| 60 | py::arg("moonshine") = OfflineMoonshineModelConfig(), | 63 | py::arg("moonshine") = OfflineMoonshineModelConfig(), |
| 64 | + py::arg("dolphin") = OfflineDolphinModelConfig(), | ||
| 61 | py::arg("telespeech_ctc") = "", py::arg("tokens"), | 65 | py::arg("telespeech_ctc") = "", py::arg("tokens"), |
| 62 | py::arg("num_threads"), py::arg("debug") = false, | 66 | py::arg("num_threads"), py::arg("debug") = false, |
| 63 | py::arg("provider") = "cpu", py::arg("model_type") = "", | 67 | py::arg("provider") = "cpu", py::arg("model_type") = "", |
| @@ -72,6 +76,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -72,6 +76,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 72 | .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) | 76 | .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) |
| 73 | .def_readwrite("sense_voice", &PyClass::sense_voice) | 77 | .def_readwrite("sense_voice", &PyClass::sense_voice) |
| 74 | .def_readwrite("moonshine", &PyClass::moonshine) | 78 | .def_readwrite("moonshine", &PyClass::moonshine) |
| 79 | + .def_readwrite("dolphin", &PyClass::dolphin) | ||
| 75 | .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) | 80 | .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) |
| 76 | .def_readwrite("tokens", &PyClass::tokens) | 81 | .def_readwrite("tokens", &PyClass::tokens) |
| 77 | .def_readwrite("num_threads", &PyClass::num_threads) | 82 | .def_readwrite("num_threads", &PyClass::num_threads) |
| @@ -6,6 +6,7 @@ from typing import List, Optional | @@ -6,6 +6,7 @@ from typing import List, Optional | ||
| 6 | from _sherpa_onnx import ( | 6 | from _sherpa_onnx import ( |
| 7 | FeatureExtractorConfig, | 7 | FeatureExtractorConfig, |
| 8 | OfflineCtcFstDecoderConfig, | 8 | OfflineCtcFstDecoderConfig, |
| 9 | + OfflineDolphinModelConfig, | ||
| 9 | OfflineFireRedAsrModelConfig, | 10 | OfflineFireRedAsrModelConfig, |
| 10 | OfflineLMConfig, | 11 | OfflineLMConfig, |
| 11 | OfflineModelConfig, | 12 | OfflineModelConfig, |
| @@ -409,6 +410,78 @@ class OfflineRecognizer(object): | @@ -409,6 +410,78 @@ class OfflineRecognizer(object): | ||
| 409 | return self | 410 | return self |
| 410 | 411 | ||
| 411 | @classmethod | 412 | @classmethod |
| 413 | + def from_dolphin_ctc( | ||
| 414 | + cls, | ||
| 415 | + model: str, | ||
| 416 | + tokens: str, | ||
| 417 | + num_threads: int = 1, | ||
| 418 | + sample_rate: int = 16000, | ||
| 419 | + feature_dim: int = 80, | ||
| 420 | + decoding_method: str = "greedy_search", | ||
| 421 | + debug: bool = False, | ||
| 422 | + provider: str = "cpu", | ||
| 423 | + rule_fsts: str = "", | ||
| 424 | + rule_fars: str = "", | ||
| 425 | + ): | ||
| 426 | + """ | ||
| 427 | + Please refer to | ||
| 428 | + `<https://k2-fsa.github.io/sherpa/onnx/dolphin/index.html>`_ | ||
| 429 | + to download pre-trained models. | ||
| 430 | + | ||
| 431 | + Args: | ||
| 432 | + model: | ||
| 433 | + Path to ``model.onnx`` or ``model.int8.onnx``. | ||
| 434 | + tokens: | ||
| 435 | + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two | ||
| 436 | + columns:: | ||
| 437 | + | ||
| 438 | + symbol integer_id | ||
| 439 | + | ||
| 440 | + num_threads: | ||
| 441 | + Number of threads for neural network computation. | ||
| 442 | + sample_rate: | ||
| 443 | + Sample rate of the training data used to train the model. | ||
| 444 | + feature_dim: | ||
| 445 | + Dimension of the feature used to train the model. | ||
| 446 | + decoding_method: | ||
| 447 | + Valid values are greedy_search. | ||
| 448 | + debug: | ||
| 449 | + True to show debug messages. | ||
| 450 | + provider: | ||
| 451 | + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | ||
| 452 | + rule_fsts: | ||
| 453 | + If not empty, it specifies fsts for inverse text normalization. | ||
| 454 | + If there are multiple fsts, they are separated by a comma. | ||
| 455 | + rule_fars: | ||
| 456 | + If not empty, it specifies fst archives for inverse text normalization. | ||
| 457 | + If there are multiple archives, they are separated by a comma. | ||
| 458 | + """ | ||
| 459 | + self = cls.__new__(cls) | ||
| 460 | + model_config = OfflineModelConfig( | ||
| 461 | + dolphin=OfflineDolphinModelConfig(model=model), | ||
| 462 | + tokens=tokens, | ||
| 463 | + num_threads=num_threads, | ||
| 464 | + debug=debug, | ||
| 465 | + provider=provider, | ||
| 466 | + ) | ||
| 467 | + | ||
| 468 | + feat_config = FeatureExtractorConfig( | ||
| 469 | + sampling_rate=sample_rate, | ||
| 470 | + feature_dim=feature_dim, | ||
| 471 | + ) | ||
| 472 | + | ||
| 473 | + recognizer_config = OfflineRecognizerConfig( | ||
| 474 | + feat_config=feat_config, | ||
| 475 | + model_config=model_config, | ||
| 476 | + decoding_method=decoding_method, | ||
| 477 | + rule_fsts=rule_fsts, | ||
| 478 | + rule_fars=rule_fars, | ||
| 479 | + ) | ||
| 480 | + self.recognizer = _Recognizer(recognizer_config) | ||
| 481 | + self.config = recognizer_config | ||
| 482 | + return self | ||
| 483 | + | ||
| 484 | + @classmethod | ||
| 412 | def from_nemo_ctc( | 485 | def from_nemo_ctc( |
| 413 | cls, | 486 | cls, |
| 414 | model: str, | 487 | model: str, |
-
请 注册 或 登录 后发表评论