Committed by GitHub
Support GigaAM CTC models for Russian ASR (#1464)
See also https://github.com/salute-developers/GigaAM
Showing 24 changed files with 641 additions and 160 deletions.
| @@ -16,6 +16,21 @@ echo "PATH: $PATH" | @@ -16,6 +16,21 @@ echo "PATH: $PATH" | ||
| 16 | which $EXE | 16 | which $EXE |
| 17 | 17 | ||
| 18 | log "------------------------------------------------------------" | 18 | log "------------------------------------------------------------" |
| 19 | +log "Run NeMo GigaAM Russian models" | ||
| 20 | +log "------------------------------------------------------------" | ||
| 21 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2 | ||
| 22 | +tar xvf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2 | ||
| 23 | +rm sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2 | ||
| 24 | + | ||
| 25 | +$EXE \ | ||
| 26 | + --nemo-ctc-model=./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/model.int8.onnx \ | ||
| 27 | + --tokens=./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/tokens.txt \ | ||
| 28 | + --debug=1 \ | ||
| 29 | + ./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/test_wavs/example.wav | ||
| 30 | + | ||
| 31 | +rm -rf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24 | ||
| 32 | + | ||
| 33 | +log "------------------------------------------------------------" | ||
| 19 | log "Run SenseVoice models" | 34 | log "Run SenseVoice models" |
| 20 | log "------------------------------------------------------------" | 35 | log "------------------------------------------------------------" |
| 21 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | 36 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 |
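The CI step above runs the new Russian model with the sherpa-onnx-offline CLI. A roughly equivalent check can be scripted from Python; the following is a minimal sketch that assumes the OfflineRecognizer.from_nemo_ctc helper of the sherpa-onnx Python bindings and the extracted release tarball, so treat the exact names and defaults as illustrative rather than authoritative.

import sherpa_onnx
import soundfile as sf

# Assumed layout: the extracted sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24 tarball.
d = "sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24"

# from_nemo_ctc is assumed here; it builds an offline recognizer from a NeMo CTC model.
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
    model=f"{d}/model.int8.onnx",
    tokens=f"{d}/tokens.txt",
    num_threads=1,
)

audio, sample_rate = sf.read(f"{d}/test_wavs/example.wav", dtype="float32")
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
recognizer.decode_stream(stream)
print(stream.result.text)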
| 1 | +name: export-nemo-giga-am-to-onnx | ||
| 2 | + | ||
| 3 | +on: | ||
| 4 | + workflow_dispatch: | ||
| 5 | + | ||
| 6 | +concurrency: | ||
| 7 | + group: export-nemo-giga-am-to-onnx-${{ github.ref }} | ||
| 8 | + cancel-in-progress: true | ||
| 9 | + | ||
| 10 | +jobs: | ||
| 11 | + export-nemo-am-giga-to-onnx: | ||
| 12 | + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' | ||
| 13 | + name: export nemo GigaAM models to ONNX | ||
| 14 | + runs-on: ${{ matrix.os }} | ||
| 15 | + strategy: | ||
| 16 | + fail-fast: false | ||
| 17 | + matrix: | ||
| 18 | + os: [macos-latest] | ||
| 19 | + python-version: ["3.10"] | ||
| 20 | + | ||
| 21 | + steps: | ||
| 22 | + - uses: actions/checkout@v4 | ||
| 23 | + | ||
| 24 | + - name: Setup Python ${{ matrix.python-version }} | ||
| 25 | + uses: actions/setup-python@v5 | ||
| 26 | + with: | ||
| 27 | + python-version: ${{ matrix.python-version }} | ||
| 28 | + | ||
| 29 | + - name: Run CTC | ||
| 30 | + shell: bash | ||
| 31 | + run: | | ||
| 32 | + pushd scripts/nemo/GigaAM | ||
| 33 | + ./run-ctc.sh | ||
| 34 | + popd | ||
| 35 | + | ||
| 36 | + d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24 | ||
| 37 | + mkdir $d | ||
| 38 | + mkdir $d/test_wavs | ||
| 39 | + rm scripts/nemo/GigaAM/model.onnx | ||
| 40 | + mv -v scripts/nemo/GigaAM/*.int8.onnx $d/ | ||
| 41 | + mv -v scripts/nemo/GigaAM/*.md $d/ | ||
| 42 | + mv -v scripts/nemo/GigaAM/*.pdf $d/ | ||
| 43 | + mv -v scripts/nemo/GigaAM/tokens.txt $d/ | ||
| 44 | + mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/ | ||
| 45 | + mv -v scripts/nemo/GigaAM/run-ctc.sh $d/ | ||
| 46 | + mv -v scripts/nemo/GigaAM/*-ctc.py $d/ | ||
| 47 | + | ||
| 48 | + ls -lh scripts/nemo/GigaAM/ | ||
| 49 | + | ||
| 50 | + ls -lh $d | ||
| 51 | + | ||
| 52 | + tar cjvf ${d}.tar.bz2 $d | ||
| 53 | + | ||
| 54 | + - name: Release | ||
| 55 | + uses: svenstaro/upload-release-action@v2 | ||
| 56 | + with: | ||
| 57 | + file_glob: true | ||
| 58 | + file: ./*.tar.bz2 | ||
| 59 | + overwrite: true | ||
| 60 | + repo_name: k2-fsa/sherpa-onnx | ||
| 61 | + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} | ||
| 62 | + tag: asr-models | ||
| 63 | + | ||
| 64 | + - name: Publish to huggingface (CTC) | ||
| 65 | + env: | ||
| 66 | + HF_TOKEN: ${{ secrets.HF_TOKEN }} | ||
| 67 | + uses: nick-fields/retry@v3 | ||
| 68 | + with: | ||
| 69 | + max_attempts: 20 | ||
| 70 | + timeout_seconds: 200 | ||
| 71 | + shell: bash | ||
| 72 | + command: | | ||
| 73 | + git config --global user.email "csukuangfj@gmail.com" | ||
| 74 | + git config --global user.name "Fangjun Kuang" | ||
| 75 | + | ||
| 76 | + d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24 | ||
| 77 | + export GIT_LFS_SKIP_SMUDGE=1 | ||
| 78 | + export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 79 | + git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface | ||
| 80 | + mv -v $d/* ./huggingface | ||
| 81 | + cd huggingface | ||
| 82 | + git lfs track "*.onnx" | ||
| 83 | + git lfs track "*.wav" | ||
| 84 | + git status | ||
| 85 | + git add . | ||
| 86 | + git status | ||
| 87 | + git commit -m "add models" | ||
| 88 | + git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main |
| @@ -149,6 +149,16 @@ jobs: | @@ -149,6 +149,16 @@ jobs: | ||
| 149 | name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} | 149 | name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} |
| 150 | path: install/* | 150 | path: install/* |
| 151 | 151 | ||
| 152 | + - name: Test offline CTC | ||
| 153 | + shell: bash | ||
| 154 | + run: | | ||
| 155 | + du -h -d1 . | ||
| 156 | + export PATH=$PWD/build/bin:$PATH | ||
| 157 | + export EXE=sherpa-onnx-offline | ||
| 158 | + | ||
| 159 | + .github/scripts/test-offline-ctc.sh | ||
| 160 | + du -h -d1 . | ||
| 161 | + | ||
| 152 | - name: Test C++ API | 162 | - name: Test C++ API |
| 153 | shell: bash | 163 | shell: bash |
| 154 | run: | | 164 | run: | |
| @@ -180,16 +190,6 @@ jobs: | @@ -180,16 +190,6 @@ jobs: | ||
| 180 | .github/scripts/test-offline-transducer.sh | 190 | .github/scripts/test-offline-transducer.sh |
| 181 | du -h -d1 . | 191 | du -h -d1 . |
| 182 | 192 | ||
| 183 | - - name: Test offline CTC | ||
| 184 | - shell: bash | ||
| 185 | - run: | | ||
| 186 | - du -h -d1 . | ||
| 187 | - export PATH=$PWD/build/bin:$PATH | ||
| 188 | - export EXE=sherpa-onnx-offline | ||
| 189 | - | ||
| 190 | - .github/scripts/test-offline-ctc.sh | ||
| 191 | - du -h -d1 . | ||
| 192 | - | ||
| 193 | - name: Test online punctuation | 193 | - name: Test online punctuation |
| 194 | shell: bash | 194 | shell: bash |
| 195 | run: | | 195 | run: | |
| @@ -336,6 +336,24 @@ def get_models(): | @@ -336,6 +336,24 @@ def get_models(): | ||
| 336 | popd | 336 | popd |
| 337 | """, | 337 | """, |
| 338 | ), | 338 | ), |
| 339 | + Model( | ||
| 340 | + model_name="sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24", | ||
| 341 | + idx=19, | ||
| 342 | + lang="ru", | ||
| 343 | + short_name="nemo_ctc_giga_am", | ||
| 344 | + cmd=""" | ||
| 345 | + pushd $model_name | ||
| 346 | + | ||
| 347 | + rm -rfv test_wavs | ||
| 348 | + | ||
| 349 | + rm -fv *.sh | ||
| 350 | + rm -fv *.py | ||
| 351 | + | ||
| 352 | + ls -lh | ||
| 353 | + | ||
| 354 | + popd | ||
| 355 | + """, | ||
| 356 | + ), | ||
| 339 | ] | 357 | ] |
| 340 | return models | 358 | return models |
| 341 | 359 |
scripts/nemo/GigaAM/README.md
0 → 100644
| 1 | +# Introduction | ||
| 2 | + | ||
| 3 | +This folder contains scripts for converting models from | ||
| 4 | +https://github.com/salute-developers/GigaAM | ||
| 5 | +to sherpa-onnx. | ||
| 6 | + | ||
| 7 | +The ASR models in this folder are for Russian speech recognition. | ||
| 8 | + | ||
| 9 | +Please see the license of the models at | ||
| 10 | +https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf |
scripts/nemo/GigaAM/export-onnx-ctc.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | +from typing import Dict | ||
| 4 | + | ||
| 5 | +import onnx | ||
| 6 | +import torch | ||
| 7 | +import torchaudio | ||
| 8 | +from nemo.collections.asr.models import EncDecCTCModel | ||
| 9 | +from nemo.collections.asr.modules.audio_preprocessing import ( | ||
| 10 | + AudioToMelSpectrogramPreprocessor as NeMoAudioToMelSpectrogramPreprocessor, | ||
| 11 | +) | ||
| 12 | +from nemo.collections.asr.parts.preprocessing.features import ( | ||
| 13 | + FilterbankFeaturesTA as NeMoFilterbankFeaturesTA, | ||
| 14 | +) | ||
| 15 | +from onnxruntime.quantization import QuantType, quantize_dynamic | ||
| 16 | + | ||
| 17 | + | ||
| 18 | +class FilterbankFeaturesTA(NeMoFilterbankFeaturesTA): | ||
| 19 | + def __init__(self, mel_scale: str = "htk", wkwargs=None, **kwargs): | ||
| 20 | + if "window_size" in kwargs: | ||
| 21 | + del kwargs["window_size"] | ||
| 22 | + if "window_stride" in kwargs: | ||
| 23 | + del kwargs["window_stride"] | ||
| 24 | + | ||
| 25 | + super().__init__(**kwargs) | ||
| 26 | + | ||
| 27 | + self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = ( | ||
| 28 | + torchaudio.transforms.MelSpectrogram( | ||
| 29 | + sample_rate=self._sample_rate, | ||
| 30 | + win_length=self.win_length, | ||
| 31 | + hop_length=self.hop_length, | ||
| 32 | + n_mels=kwargs["nfilt"], | ||
| 33 | + window_fn=self.torch_windows[kwargs["window"]], | ||
| 34 | + mel_scale=mel_scale, | ||
| 35 | + norm=kwargs["mel_norm"], | ||
| 36 | + n_fft=kwargs["n_fft"], | ||
| 37 | + f_max=kwargs.get("highfreq", None), | ||
| 38 | + f_min=kwargs.get("lowfreq", 0), | ||
| 39 | + wkwargs=wkwargs, | ||
| 40 | + ) | ||
| 41 | + ) | ||
| 42 | + | ||
| 43 | + | ||
| 44 | +class AudioToMelSpectrogramPreprocessor(NeMoAudioToMelSpectrogramPreprocessor): | ||
| 45 | + def __init__(self, mel_scale: str = "htk", **kwargs): | ||
| 46 | + super().__init__(**kwargs) | ||
| 47 | + kwargs["nfilt"] = kwargs["features"] | ||
| 48 | + del kwargs["features"] | ||
| 49 | + self.featurizer = ( | ||
| 50 | + FilterbankFeaturesTA( # Deprecated arguments; kept for config compatibility | ||
| 51 | + mel_scale=mel_scale, | ||
| 52 | + **kwargs, | ||
| 53 | + ) | ||
| 54 | + ) | ||
| 55 | + | ||
| 56 | + | ||
| 57 | +def add_meta_data(filename: str, meta_data: Dict[str, str]): | ||
| 58 | + """Add meta data to an ONNX model. It is changed in-place. | ||
| 59 | + | ||
| 60 | + Args: | ||
| 61 | + filename: | ||
| 62 | + Filename of the ONNX model to be changed. | ||
| 63 | + meta_data: | ||
| 64 | + Key-value pairs. | ||
| 65 | + """ | ||
| 66 | + model = onnx.load(filename) | ||
| 67 | + while len(model.metadata_props): | ||
| 68 | + model.metadata_props.pop() | ||
| 69 | + | ||
| 70 | + for key, value in meta_data.items(): | ||
| 71 | + meta = model.metadata_props.add() | ||
| 72 | + meta.key = key | ||
| 73 | + meta.value = str(value) | ||
| 74 | + | ||
| 75 | + onnx.save(model, filename) | ||
| 76 | + | ||
| 77 | + | ||
| 78 | +def main(): | ||
| 79 | + model = EncDecCTCModel.from_config_file("./ctc_model_config.yaml") | ||
| 80 | + ckpt = torch.load("./ctc_model_weights.ckpt", map_location="cpu") | ||
| 81 | + model.load_state_dict(ckpt, strict=False) | ||
| 82 | + model.eval() | ||
| 83 | + | ||
| 84 | + with open("tokens.txt", "w", encoding="utf-8") as f: | ||
| 85 | + for i, t in enumerate(model.cfg.labels): | ||
| 86 | + f.write(f"{t} {i}\n") | ||
| 87 | + f.write(f"<blk> {i+1}\n") | ||
| 88 | + | ||
| 89 | + filename = "model.onnx" | ||
| 90 | + model.export(filename) | ||
| 91 | + | ||
| 92 | + meta_data = { | ||
| 93 | + "vocab_size": len(model.cfg.labels) + 1, | ||
| 94 | + "normalize_type": "", | ||
| 95 | + "subsampling_factor": 4, | ||
| 96 | + "model_type": "EncDecCTCModel", | ||
| 97 | + "version": "1", | ||
| 98 | + "model_author": "https://github.com/salute-developers/GigaAM", | ||
| 99 | + "license": "https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf", | ||
| 100 | + "language": "Russian", | ||
| 101 | + "is_giga_am": 1, | ||
| 102 | + } | ||
| 103 | + add_meta_data(filename, meta_data) | ||
| 104 | + | ||
| 105 | + filename_int8 = "model.int8.onnx" | ||
| 106 | + quantize_dynamic( | ||
| 107 | + model_input=filename, | ||
| 108 | + model_output=filename_int8, | ||
| 109 | + weight_type=QuantType.QUInt8, | ||
| 110 | + ) | ||
| 111 | + | ||
| 112 | + | ||
| 113 | +if __name__ == "__main__": | ||
| 114 | + main() |
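For reference, the tokens.txt written by export-onnx-ctc.py holds one "symbol id" pair per line, with <blk> appended as the final ID; when the symbol itself is a space, only the ID remains visible on that line, which is why the token readers below treat a single-field line as the space token. A purely hypothetical excerpt (symbols and IDs are illustrative, not the real GigaAM vocabulary):

 0
а 1
б 2
...
<blk> 33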
scripts/nemo/GigaAM/run-ctc.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +set -ex | ||
| 5 | + | ||
| 6 | +function install_nemo() { | ||
| 7 | + curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py | ||
| 8 | + python3 get-pip.py | ||
| 9 | + | ||
| 10 | + pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html | ||
| 11 | + | ||
| 12 | + pip install -qq wget text-unidecode matplotlib>=3.3.2 onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile librosa | ||
| 13 | + pip install -qq ipython | ||
| 14 | + | ||
| 15 | + # sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython | ||
| 16 | + | ||
| 17 | + BRANCH='main' | ||
| 18 | + python3 -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr] | ||
| 19 | + | ||
| 20 | + pip install numpy==1.26.4 | ||
| 21 | +} | ||
| 22 | + | ||
| 23 | +function download_files() { | ||
| 24 | + curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/ctc_model_weights.ckpt | ||
| 25 | + curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/ctc_model_config.yaml | ||
| 26 | + curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/example.wav | ||
| 27 | + curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/long_example.wav | ||
| 28 | + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM%20License_NC.pdf | ||
| 29 | +} | ||
| 30 | + | ||
| 31 | +install_nemo | ||
| 32 | +download_files | ||
| 33 | + | ||
| 34 | +python3 ./export-onnx-ctc.py | ||
| 35 | +ls -lh | ||
| 36 | +python3 ./test-onnx-ctc.py |
scripts/nemo/GigaAM/test-onnx-ctc.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +# https://github.com/salute-developers/GigaAM | ||
| 5 | + | ||
| 6 | +import kaldi_native_fbank as knf | ||
| 7 | +import librosa | ||
| 8 | +import numpy as np | ||
| 9 | +import onnxruntime as ort | ||
| 10 | +import soundfile as sf | ||
| 11 | +import torch | ||
| 12 | + | ||
| 13 | + | ||
| 14 | +def create_fbank(): | ||
| 15 | + opts = knf.FbankOptions() | ||
| 16 | + opts.frame_opts.dither = 0 | ||
| 17 | + opts.frame_opts.remove_dc_offset = False | ||
| 18 | + opts.frame_opts.preemph_coeff = 0 | ||
| 19 | + opts.frame_opts.window_type = "hann" | ||
| 20 | + | ||
| 21 | +    # Even though GigaAM uses an FFT size of 400, we use 512 here, | ||
| 22 | +    # since kaldi-native-fbank only supports FFT sizes that are powers of 2. | ||
| 23 | + opts.frame_opts.round_to_power_of_two = True | ||
| 24 | + | ||
| 25 | + opts.mel_opts.low_freq = 0 | ||
| 26 | + opts.mel_opts.high_freq = 8000 | ||
| 27 | + opts.mel_opts.num_bins = 64 | ||
| 28 | + | ||
| 29 | + fbank = knf.OnlineFbank(opts) | ||
| 30 | + return fbank | ||
| 31 | + | ||
| 32 | + | ||
| 33 | +def compute_features(audio, fbank) -> np.ndarray: | ||
| 34 | + """ | ||
| 35 | + Args: | ||
| 36 | + audio: (num_samples,), np.float32 | ||
| 37 | + fbank: the fbank extractor | ||
| 38 | + Returns: | ||
| 39 | + features: (num_frames, feat_dim), np.float32 | ||
| 40 | + """ | ||
| 41 | + assert len(audio.shape) == 1, audio.shape | ||
| 42 | + fbank.accept_waveform(16000, audio) | ||
| 43 | + ans = [] | ||
| 44 | + processed = 0 | ||
| 45 | + while processed < fbank.num_frames_ready: | ||
| 46 | + ans.append(np.array(fbank.get_frame(processed))) | ||
| 47 | + processed += 1 | ||
| 48 | + ans = np.stack(ans) | ||
| 49 | + return ans | ||
| 50 | + | ||
| 51 | + | ||
| 52 | +def display(sess): | ||
| 53 | + print("==========Input==========") | ||
| 54 | + for i in sess.get_inputs(): | ||
| 55 | + print(i) | ||
| 56 | + print("==========Output==========") | ||
| 57 | + for i in sess.get_outputs(): | ||
| 58 | + print(i) | ||
| 59 | + | ||
| 60 | + | ||
| 61 | +""" | ||
| 62 | +==========Input========== | ||
| 63 | +NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 64, 'audio_signal_dynamic_axes_2']) | ||
| 64 | +NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1']) | ||
| 65 | +==========Output========== | ||
| 66 | +NodeArg(name='logprobs', type='tensor(float)', shape=['logprobs_dynamic_axes_1', 'logprobs_dynamic_axes_2', 34]) | ||
| 67 | +""" | ||
| 68 | + | ||
| 69 | + | ||
| 70 | +class OnnxModel: | ||
| 71 | + def __init__( | ||
| 72 | + self, | ||
| 73 | + filename: str, | ||
| 74 | + ): | ||
| 75 | + session_opts = ort.SessionOptions() | ||
| 76 | + session_opts.inter_op_num_threads = 1 | ||
| 77 | + session_opts.intra_op_num_threads = 1 | ||
| 78 | + | ||
| 79 | + self.model = ort.InferenceSession( | ||
| 80 | + filename, | ||
| 81 | + sess_options=session_opts, | ||
| 82 | + providers=["CPUExecutionProvider"], | ||
| 83 | + ) | ||
| 84 | + display(self.model) | ||
| 85 | + | ||
| 86 | + def __call__(self, x: np.ndarray): | ||
| 87 | + # x: (T, C) | ||
| 88 | + x = torch.from_numpy(x) | ||
| 89 | + x = x.t().unsqueeze(0) | ||
| 90 | + # x: [1, C, T] | ||
| 91 | + x_lens = torch.tensor([x.shape[-1]], dtype=torch.int64) | ||
| 92 | + | ||
| 93 | + log_probs = self.model.run( | ||
| 94 | + [ | ||
| 95 | + self.model.get_outputs()[0].name, | ||
| 96 | + ], | ||
| 97 | + { | ||
| 98 | + self.model.get_inputs()[0].name: x.numpy(), | ||
| 99 | + self.model.get_inputs()[1].name: x_lens.numpy(), | ||
| 100 | + }, | ||
| 101 | + )[0] | ||
| 102 | + # [batch_size, T, dim] | ||
| 103 | + return log_probs | ||
| 104 | + | ||
| 105 | + | ||
| 106 | +def main(): | ||
| 107 | + filename = "./model.int8.onnx" | ||
| 108 | + tokens = "./tokens.txt" | ||
| 109 | + wav = "./example.wav" | ||
| 110 | + | ||
| 111 | + model = OnnxModel(filename) | ||
| 112 | + | ||
| 113 | + id2token = dict() | ||
| 114 | + with open(tokens, encoding="utf-8") as f: | ||
| 115 | + for line in f: | ||
| 116 | + fields = line.split() | ||
| 117 | + if len(fields) == 1: | ||
| 118 | + id2token[int(fields[0])] = " " | ||
| 119 | + else: | ||
| 120 | + t, idx = fields | ||
| 121 | + id2token[int(idx)] = t | ||
| 122 | + | ||
| 123 | + fbank = create_fbank() | ||
| 124 | + audio, sample_rate = sf.read(wav, dtype="float32", always_2d=True) | ||
| 125 | + audio = audio[:, 0] # only use the first channel | ||
| 126 | + if sample_rate != 16000: | ||
| 127 | + audio = librosa.resample( | ||
| 128 | + audio, | ||
| 129 | + orig_sr=sample_rate, | ||
| 130 | + target_sr=16000, | ||
| 131 | + ) | ||
| 132 | + sample_rate = 16000 | ||
| 133 | + | ||
| 134 | + features = compute_features(audio, fbank) | ||
| 135 | + print("features.shape", features.shape) | ||
| 136 | + | ||
| 137 | + blank = len(id2token) - 1 | ||
| 138 | + prev = -1 | ||
| 139 | + ans = [] | ||
| 140 | + log_probs = model(features) | ||
| 141 | + print("log_probs", log_probs.shape) | ||
| 142 | + log_probs = torch.from_numpy(log_probs)[0] | ||
| 143 | + ids = torch.argmax(log_probs, dim=1).tolist() | ||
| 144 | + for i in ids: | ||
| 145 | + if i != blank and i != prev: | ||
| 146 | + ans.append(i) | ||
| 147 | + prev = i | ||
| 148 | + | ||
| 149 | + tokens = [id2token[i] for i in ans] | ||
| 150 | + | ||
| 151 | + text = "".join(tokens) | ||
| 152 | + print(wav) | ||
| 153 | + print(text) | ||
| 154 | + | ||
| 155 | + | ||
| 156 | +if __name__ == "__main__": | ||
| 157 | + main() |
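The decoding loop above is plain greedy CTC: take the argmax token ID per frame, collapse consecutive repeats, and drop the blank ID (the last entry in tokens.txt). A tiny worked example with a hypothetical blank ID of 33:

frame-wise argmax : [5, 5, 33, 5, 12, 12, 33]
collapse repeats  : [5, 33, 5, 12, 33]
drop blanks       : [5, 5, 12]   -> map through id2token and join into text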
| @@ -193,6 +193,7 @@ class FeatureExtractor::Impl { | @@ -193,6 +193,7 @@ class FeatureExtractor::Impl { | ||
| 193 | opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms; | 193 | opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms; |
| 194 | opts_.frame_opts.frame_length_ms = config_.frame_length_ms; | 194 | opts_.frame_opts.frame_length_ms = config_.frame_length_ms; |
| 195 | opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset; | 195 | opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset; |
| 196 | + opts_.frame_opts.preemph_coeff = config_.preemph_coeff; | ||
| 196 | opts_.frame_opts.window_type = config_.window_type; | 197 | opts_.frame_opts.window_type = config_.window_type; |
| 197 | 198 | ||
| 198 | opts_.mel_opts.num_bins = config_.feature_dim; | 199 | opts_.mel_opts.num_bins = config_.feature_dim; |
| @@ -211,6 +212,7 @@ class FeatureExtractor::Impl { | @@ -211,6 +212,7 @@ class FeatureExtractor::Impl { | ||
| 211 | mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms; | 212 | mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms; |
| 212 | mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms; | 213 | mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms; |
| 213 | mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset; | 214 | mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset; |
| 215 | + mfcc_opts_.frame_opts.preemph_coeff = config_.preemph_coeff; | ||
| 214 | mfcc_opts_.frame_opts.window_type = config_.window_type; | 216 | mfcc_opts_.frame_opts.window_type = config_.window_type; |
| 215 | 217 | ||
| 216 | mfcc_opts_.mel_opts.num_bins = config_.feature_dim; | 218 | mfcc_opts_.mel_opts.num_bins = config_.feature_dim; |
| @@ -57,6 +57,7 @@ struct FeatureExtractorConfig { | @@ -57,6 +57,7 @@ struct FeatureExtractorConfig { | ||
| 57 | float frame_length_ms = 25.0f; // in milliseconds. | 57 | float frame_length_ms = 25.0f; // in milliseconds. |
| 58 | bool is_librosa = false; | 58 | bool is_librosa = false; |
| 59 | bool remove_dc_offset = true; // Subtract mean of wave before FFT. | 59 | bool remove_dc_offset = true; // Subtract mean of wave before FFT. |
| 60 | + float preemph_coeff = 0.97f; // Preemphasis coefficient. | ||
| 60 | std::string window_type = "povey"; // e.g. Hamming window | 61 | std::string window_type = "povey"; // e.g. Hamming window |
| 61 | 62 | ||
| 62 | // For models from NeMo | 63 | // For models from NeMo |
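For context on the new preemph_coeff option: Kaldi-style front ends apply pre-emphasis y[n] = x[n] - c * x[n-1] (with c = 0.97 by default) to each frame before the FFT, while GigaAM's feature extractor does not, so the recognizer must be able to set the coefficient to 0 for these models (see offline-recognizer-ctc-impl.cc further down). A minimal numpy sketch of the operation, assuming the usual Kaldi convention for the first sample:

import numpy as np

def preemphasize(frame: np.ndarray, coeff: float = 0.97) -> np.ndarray:
    # y[n] = x[n] - coeff * x[n-1]; the first sample is emphasized against itself,
    # following the Kaldi convention. coeff = 0 turns this into a no-op, which is
    # what the GigaAM models expect.
    out = frame.astype(np.float32)
    out[1:] -= coeff * frame[:-1]
    out[0] -= coeff * frame[0]
    return out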
| @@ -10,8 +10,8 @@ | @@ -10,8 +10,8 @@ | ||
| 10 | 10 | ||
| 11 | #include "cppjieba/Jieba.hpp" | 11 | #include "cppjieba/Jieba.hpp" |
| 12 | #include "sherpa-onnx/csrc/file-utils.h" | 12 | #include "sherpa-onnx/csrc/file-utils.h" |
| 13 | -#include "sherpa-onnx/csrc/lexicon.h" | ||
| 14 | #include "sherpa-onnx/csrc/macros.h" | 13 | #include "sherpa-onnx/csrc/macros.h" |
| 14 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 15 | #include "sherpa-onnx/csrc/text-utils.h" | 15 | #include "sherpa-onnx/csrc/text-utils.h" |
| 16 | 16 | ||
| 17 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| @@ -21,6 +21,7 @@ | @@ -21,6 +21,7 @@ | ||
| 21 | 21 | ||
| 22 | #include "sherpa-onnx/csrc/macros.h" | 22 | #include "sherpa-onnx/csrc/macros.h" |
| 23 | #include "sherpa-onnx/csrc/onnx-utils.h" | 23 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 24 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 24 | #include "sherpa-onnx/csrc/text-utils.h" | 25 | #include "sherpa-onnx/csrc/text-utils.h" |
| 25 | 26 | ||
| 26 | namespace sherpa_onnx { | 27 | namespace sherpa_onnx { |
| @@ -74,45 +75,6 @@ static std::vector<std::string> ProcessHeteronyms( | @@ -74,45 +75,6 @@ static std::vector<std::string> ProcessHeteronyms( | ||
| 74 | return ans; | 75 | return ans; |
| 75 | } | 76 | } |
| 76 | 77 | ||
| 77 | -// Note: We don't use SymbolTable here since tokens may contain a blank | ||
| 78 | -// in the first column | ||
| 79 | -std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) { | ||
| 80 | - std::unordered_map<std::string, int32_t> token2id; | ||
| 81 | - | ||
| 82 | - std::string line; | ||
| 83 | - | ||
| 84 | - std::string sym; | ||
| 85 | - int32_t id = -1; | ||
| 86 | - while (std::getline(is, line)) { | ||
| 87 | - std::istringstream iss(line); | ||
| 88 | - iss >> sym; | ||
| 89 | - if (iss.eof()) { | ||
| 90 | - id = atoi(sym.c_str()); | ||
| 91 | - sym = " "; | ||
| 92 | - } else { | ||
| 93 | - iss >> id; | ||
| 94 | - } | ||
| 95 | - | ||
| 96 | - // eat the trailing \r\n on windows | ||
| 97 | - iss >> std::ws; | ||
| 98 | - if (!iss.eof()) { | ||
| 99 | - SHERPA_ONNX_LOGE("Error: %s", line.c_str()); | ||
| 100 | - exit(-1); | ||
| 101 | - } | ||
| 102 | - | ||
| 103 | -#if 0 | ||
| 104 | - if (token2id.count(sym)) { | ||
| 105 | - SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", | ||
| 106 | - sym.c_str(), line.c_str(), token2id.at(sym)); | ||
| 107 | - exit(-1); | ||
| 108 | - } | ||
| 109 | -#endif | ||
| 110 | - token2id.insert({std::move(sym), id}); | ||
| 111 | - } | ||
| 112 | - | ||
| 113 | - return token2id; | ||
| 114 | -} | ||
| 115 | - | ||
| 116 | std::vector<int32_t> ConvertTokensToIds( | 78 | std::vector<int32_t> ConvertTokensToIds( |
| 117 | const std::unordered_map<std::string, int32_t> &token2id, | 79 | const std::unordered_map<std::string, int32_t> &token2id, |
| 118 | const std::vector<std::string> &tokens) { | 80 | const std::vector<std::string> &tokens) { |
| @@ -67,12 +67,6 @@ class Lexicon : public OfflineTtsFrontend { | @@ -67,12 +67,6 @@ class Lexicon : public OfflineTtsFrontend { | ||
| 67 | bool debug_ = false; | 67 | bool debug_ = false; |
| 68 | }; | 68 | }; |
| 69 | 69 | ||
| 70 | -std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is); | ||
| 71 | - | ||
| 72 | -std::vector<int32_t> ConvertTokensToIds( | ||
| 73 | - const std::unordered_map<std::string, int32_t> &token2id, | ||
| 74 | - const std::vector<std::string> &tokens); | ||
| 75 | - | ||
| 76 | } // namespace sherpa_onnx | 70 | } // namespace sherpa_onnx |
| 77 | 71 | ||
| 78 | #endif // SHERPA_ONNX_CSRC_LEXICON_H_ | 72 | #endif // SHERPA_ONNX_CSRC_LEXICON_H_ |
| @@ -41,13 +41,13 @@ | @@ -41,13 +41,13 @@ | ||
| 41 | auto value = \ | 41 | auto value = \ |
| 42 | meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | 42 | meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ |
| 43 | if (!value) { \ | 43 | if (!value) { \ |
| 44 | - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ | 44 | + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ |
| 45 | exit(-1); \ | 45 | exit(-1); \ |
| 46 | } \ | 46 | } \ |
| 47 | \ | 47 | \ |
| 48 | dst = atoi(value.get()); \ | 48 | dst = atoi(value.get()); \ |
| 49 | if (dst < 0) { \ | 49 | if (dst < 0) { \ |
| 50 | - SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \ | 50 | + SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \ |
| 51 | exit(-1); \ | 51 | exit(-1); \ |
| 52 | } \ | 52 | } \ |
| 53 | } while (0) | 53 | } while (0) |
| @@ -61,80 +61,80 @@ | @@ -61,80 +61,80 @@ | ||
| 61 | } else { \ | 61 | } else { \ |
| 62 | dst = atoi(value.get()); \ | 62 | dst = atoi(value.get()); \ |
| 63 | if (dst < 0) { \ | 63 | if (dst < 0) { \ |
| 64 | - SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \ | 64 | + SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \ |
| 65 | exit(-1); \ | 65 | exit(-1); \ |
| 66 | } \ | 66 | } \ |
| 67 | } \ | 67 | } \ |
| 68 | } while (0) | 68 | } while (0) |
| 69 | 69 | ||
| 70 | // read a vector of integers | 70 | // read a vector of integers |
| 71 | -#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \ | ||
| 72 | - do { \ | ||
| 73 | - auto value = \ | ||
| 74 | - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 75 | - if (!value) { \ | ||
| 76 | - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ | ||
| 77 | - exit(-1); \ | ||
| 78 | - } \ | ||
| 79 | - \ | ||
| 80 | - bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \ | ||
| 81 | - if (!ret) { \ | ||
| 82 | - SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \ | ||
| 83 | - exit(-1); \ | ||
| 84 | - } \ | 71 | +#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \ |
| 72 | + do { \ | ||
| 73 | + auto value = \ | ||
| 74 | + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 75 | + if (!value) { \ | ||
| 76 | + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ | ||
| 77 | + exit(-1); \ | ||
| 78 | + } \ | ||
| 79 | + \ | ||
| 80 | + bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \ | ||
| 81 | + if (!ret) { \ | ||
| 82 | + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \ | ||
| 83 | + exit(-1); \ | ||
| 84 | + } \ | ||
| 85 | } while (0) | 85 | } while (0) |
| 86 | 86 | ||
| 87 | // read a vector of floats | 87 | // read a vector of floats |
| 88 | -#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \ | ||
| 89 | - do { \ | ||
| 90 | - auto value = \ | ||
| 91 | - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 92 | - if (!value) { \ | ||
| 93 | - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ | ||
| 94 | - exit(-1); \ | ||
| 95 | - } \ | ||
| 96 | - \ | ||
| 97 | - bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \ | ||
| 98 | - if (!ret) { \ | ||
| 99 | - SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \ | ||
| 100 | - exit(-1); \ | ||
| 101 | - } \ | 88 | +#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \ |
| 89 | + do { \ | ||
| 90 | + auto value = \ | ||
| 91 | + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 92 | + if (!value) { \ | ||
| 93 | + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ | ||
| 94 | + exit(-1); \ | ||
| 95 | + } \ | ||
| 96 | + \ | ||
| 97 | + bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \ | ||
| 98 | + if (!ret) { \ | ||
| 99 | + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \ | ||
| 100 | + exit(-1); \ | ||
| 101 | + } \ | ||
| 102 | } while (0) | 102 | } while (0) |
| 103 | 103 | ||
| 104 | // read a vector of strings | 104 | // read a vector of strings |
| 105 | -#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \ | ||
| 106 | - do { \ | ||
| 107 | - auto value = \ | ||
| 108 | - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 109 | - if (!value) { \ | ||
| 110 | - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ | ||
| 111 | - exit(-1); \ | ||
| 112 | - } \ | ||
| 113 | - SplitStringToVector(value.get(), ",", false, &dst); \ | ||
| 114 | - \ | ||
| 115 | - if (dst.empty()) { \ | ||
| 116 | - SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \ | ||
| 117 | - src_key); \ | ||
| 118 | - exit(-1); \ | ||
| 119 | - } \ | 105 | +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \ |
| 106 | + do { \ | ||
| 107 | + auto value = \ | ||
| 108 | + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 109 | + if (!value) { \ | ||
| 110 | + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ | ||
| 111 | + exit(-1); \ | ||
| 112 | + } \ | ||
| 113 | + SplitStringToVector(value.get(), ",", false, &dst); \ | ||
| 114 | + \ | ||
| 115 | + if (dst.empty()) { \ | ||
| 116 | + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \ | ||
| 117 | + value.get(), src_key); \ | ||
| 118 | + exit(-1); \ | ||
| 119 | + } \ | ||
| 120 | } while (0) | 120 | } while (0) |
| 121 | 121 | ||
| 122 | // read a vector of strings separated by sep | 122 | // read a vector of strings separated by sep |
| 123 | -#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \ | ||
| 124 | - do { \ | ||
| 125 | - auto value = \ | ||
| 126 | - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 127 | - if (!value) { \ | ||
| 128 | - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ | ||
| 129 | - exit(-1); \ | ||
| 130 | - } \ | ||
| 131 | - SplitStringToVector(value.get(), sep, false, &dst); \ | ||
| 132 | - \ | ||
| 133 | - if (dst.empty()) { \ | ||
| 134 | - SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \ | ||
| 135 | - src_key); \ | ||
| 136 | - exit(-1); \ | ||
| 137 | - } \ | 123 | +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \ |
| 124 | + do { \ | ||
| 125 | + auto value = \ | ||
| 126 | + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 127 | + if (!value) { \ | ||
| 128 | + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ | ||
| 129 | + exit(-1); \ | ||
| 130 | + } \ | ||
| 131 | + SplitStringToVector(value.get(), sep, false, &dst); \ | ||
| 132 | + \ | ||
| 133 | + if (dst.empty()) { \ | ||
| 134 | + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \ | ||
| 135 | + value.get(), src_key); \ | ||
| 136 | + exit(-1); \ | ||
| 137 | + } \ | ||
| 138 | } while (0) | 138 | } while (0) |
| 139 | 139 | ||
| 140 | // Read a string | 140 | // Read a string |
| @@ -143,17 +143,29 @@ | @@ -143,17 +143,29 @@ | ||
| 143 | auto value = \ | 143 | auto value = \ |
| 144 | meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | 144 | meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ |
| 145 | if (!value) { \ | 145 | if (!value) { \ |
| 146 | - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ | 146 | + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ |
| 147 | exit(-1); \ | 147 | exit(-1); \ |
| 148 | } \ | 148 | } \ |
| 149 | \ | 149 | \ |
| 150 | dst = value.get(); \ | 150 | dst = value.get(); \ |
| 151 | if (dst.empty()) { \ | 151 | if (dst.empty()) { \ |
| 152 | - SHERPA_ONNX_LOGE("Invalid value for %s\n", src_key); \ | 152 | + SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \ |
| 153 | exit(-1); \ | 153 | exit(-1); \ |
| 154 | } \ | 154 | } \ |
| 155 | } while (0) | 155 | } while (0) |
| 156 | 156 | ||
| 157 | +#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \ | ||
| 158 | + do { \ | ||
| 159 | + auto value = \ | ||
| 160 | + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 161 | + if (!value) { \ | ||
| 162 | + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ | ||
| 163 | + exit(-1); \ | ||
| 164 | + } \ | ||
| 165 | + \ | ||
| 166 | + dst = value.get(); \ | ||
| 167 | + } while (0) | ||
| 168 | + | ||
| 157 | #define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \ | 169 | #define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \ |
| 158 | default_value) \ | 170 | default_value) \ |
| 159 | do { \ | 171 | do { \ |
| @@ -164,7 +176,7 @@ | @@ -164,7 +176,7 @@ | ||
| 164 | } else { \ | 176 | } else { \ |
| 165 | dst = value.get(); \ | 177 | dst = value.get(); \ |
| 166 | if (dst.empty()) { \ | 178 | if (dst.empty()) { \ |
| 167 | - SHERPA_ONNX_LOGE("Invalid value for %s\n", src_key); \ | 179 | + SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \ |
| 168 | exit(-1); \ | 180 | exit(-1); \ |
| 169 | } \ | 181 | } \ |
| 170 | } \ | 182 | } \ |
| @@ -10,8 +10,8 @@ | @@ -10,8 +10,8 @@ | ||
| 10 | 10 | ||
| 11 | #include "cppjieba/Jieba.hpp" | 11 | #include "cppjieba/Jieba.hpp" |
| 12 | #include "sherpa-onnx/csrc/file-utils.h" | 12 | #include "sherpa-onnx/csrc/file-utils.h" |
| 13 | -#include "sherpa-onnx/csrc/lexicon.h" | ||
| 14 | #include "sherpa-onnx/csrc/macros.h" | 13 | #include "sherpa-onnx/csrc/macros.h" |
| 14 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 15 | #include "sherpa-onnx/csrc/text-utils.h" | 15 | #include "sherpa-onnx/csrc/text-utils.h" |
| 16 | 16 | ||
| 17 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| @@ -21,6 +21,7 @@ namespace { | @@ -21,6 +21,7 @@ namespace { | ||
| 21 | 21 | ||
| 22 | enum class ModelType : std::uint8_t { | 22 | enum class ModelType : std::uint8_t { |
| 23 | kEncDecCTCModelBPE, | 23 | kEncDecCTCModelBPE, |
| 24 | + kEncDecCTCModel, | ||
| 24 | kEncDecHybridRNNTCTCBPEModel, | 25 | kEncDecHybridRNNTCTCBPEModel, |
| 25 | kTdnn, | 26 | kTdnn, |
| 26 | kZipformerCtc, | 27 | kZipformerCtc, |
| @@ -75,6 +76,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -75,6 +76,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 75 | 76 | ||
| 76 | if (model_type.get() == std::string("EncDecCTCModelBPE")) { | 77 | if (model_type.get() == std::string("EncDecCTCModelBPE")) { |
| 77 | return ModelType::kEncDecCTCModelBPE; | 78 | return ModelType::kEncDecCTCModelBPE; |
| 79 | + } else if (model_type.get() == std::string("EncDecCTCModel")) { | ||
| 80 | + return ModelType::kEncDecCTCModel; | ||
| 78 | } else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) { | 81 | } else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) { |
| 79 | return ModelType::kEncDecHybridRNNTCTCBPEModel; | 82 | return ModelType::kEncDecHybridRNNTCTCBPEModel; |
| 80 | } else if (model_type.get() == std::string("tdnn")) { | 83 | } else if (model_type.get() == std::string("tdnn")) { |
| @@ -121,22 +124,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -121,22 +124,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 121 | switch (model_type) { | 124 | switch (model_type) { |
| 122 | case ModelType::kEncDecCTCModelBPE: | 125 | case ModelType::kEncDecCTCModelBPE: |
| 123 | return std::make_unique<OfflineNemoEncDecCtcModel>(config); | 126 | return std::make_unique<OfflineNemoEncDecCtcModel>(config); |
| 124 | - break; | 127 | + case ModelType::kEncDecCTCModel: |
| 128 | + return std::make_unique<OfflineNemoEncDecCtcModel>(config); | ||
| 125 | case ModelType::kEncDecHybridRNNTCTCBPEModel: | 129 | case ModelType::kEncDecHybridRNNTCTCBPEModel: |
| 126 | return std::make_unique<OfflineNemoEncDecHybridRNNTCTCBPEModel>(config); | 130 | return std::make_unique<OfflineNemoEncDecHybridRNNTCTCBPEModel>(config); |
| 127 | - break; | ||
| 128 | case ModelType::kTdnn: | 131 | case ModelType::kTdnn: |
| 129 | return std::make_unique<OfflineTdnnCtcModel>(config); | 132 | return std::make_unique<OfflineTdnnCtcModel>(config); |
| 130 | - break; | ||
| 131 | case ModelType::kZipformerCtc: | 133 | case ModelType::kZipformerCtc: |
| 132 | return std::make_unique<OfflineZipformerCtcModel>(config); | 134 | return std::make_unique<OfflineZipformerCtcModel>(config); |
| 133 | - break; | ||
| 134 | case ModelType::kWenetCtc: | 135 | case ModelType::kWenetCtc: |
| 135 | return std::make_unique<OfflineWenetCtcModel>(config); | 136 | return std::make_unique<OfflineWenetCtcModel>(config); |
| 136 | - break; | ||
| 137 | case ModelType::kTeleSpeechCtc: | 137 | case ModelType::kTeleSpeechCtc: |
| 138 | return std::make_unique<OfflineTeleSpeechCtcModel>(config); | 138 | return std::make_unique<OfflineTeleSpeechCtcModel>(config); |
| 139 | - break; | ||
| 140 | case ModelType::kUnknown: | 139 | case ModelType::kUnknown: |
| 141 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); | 140 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); |
| 142 | return nullptr; | 141 | return nullptr; |
| @@ -177,23 +176,19 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -177,23 +176,19 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 177 | switch (model_type) { | 176 | switch (model_type) { |
| 178 | case ModelType::kEncDecCTCModelBPE: | 177 | case ModelType::kEncDecCTCModelBPE: |
| 179 | return std::make_unique<OfflineNemoEncDecCtcModel>(mgr, config); | 178 | return std::make_unique<OfflineNemoEncDecCtcModel>(mgr, config); |
| 180 | - break; | 179 | + case ModelType::kEncDecCTCModel: |
| 180 | + return std::make_unique<OfflineNemoEncDecCtcModel>(mgr, config); | ||
| 181 | case ModelType::kEncDecHybridRNNTCTCBPEModel: | 181 | case ModelType::kEncDecHybridRNNTCTCBPEModel: |
| 182 | return std::make_unique<OfflineNemoEncDecHybridRNNTCTCBPEModel>(mgr, | 182 | return std::make_unique<OfflineNemoEncDecHybridRNNTCTCBPEModel>(mgr, |
| 183 | config); | 183 | config); |
| 184 | - break; | ||
| 185 | case ModelType::kTdnn: | 184 | case ModelType::kTdnn: |
| 186 | return std::make_unique<OfflineTdnnCtcModel>(mgr, config); | 185 | return std::make_unique<OfflineTdnnCtcModel>(mgr, config); |
| 187 | - break; | ||
| 188 | case ModelType::kZipformerCtc: | 186 | case ModelType::kZipformerCtc: |
| 189 | return std::make_unique<OfflineZipformerCtcModel>(mgr, config); | 187 | return std::make_unique<OfflineZipformerCtcModel>(mgr, config); |
| 190 | - break; | ||
| 191 | case ModelType::kWenetCtc: | 188 | case ModelType::kWenetCtc: |
| 192 | return std::make_unique<OfflineWenetCtcModel>(mgr, config); | 189 | return std::make_unique<OfflineWenetCtcModel>(mgr, config); |
| 193 | - break; | ||
| 194 | case ModelType::kTeleSpeechCtc: | 190 | case ModelType::kTeleSpeechCtc: |
| 195 | return std::make_unique<OfflineTeleSpeechCtcModel>(mgr, config); | 191 | return std::make_unique<OfflineTeleSpeechCtcModel>(mgr, config); |
| 196 | - break; | ||
| 197 | case ModelType::kUnknown: | 192 | case ModelType::kUnknown: |
| 198 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); | 193 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); |
| 199 | return nullptr; | 194 | return nullptr; |
| @@ -66,6 +66,10 @@ class OfflineCtcModel { | @@ -66,6 +66,10 @@ class OfflineCtcModel { | ||
| 66 | 66 | ||
| 67 | // Return true if the model supports batch size > 1 | 67 | // Return true if the model supports batch size > 1 |
| 68 | virtual bool SupportBatchProcessing() const { return true; } | 68 | virtual bool SupportBatchProcessing() const { return true; } |
| 69 | + | ||
| 70 | + // return true for models from https://github.com/salute-developers/GigaAM | ||
| 71 | + // return false otherwise | ||
| 72 | + virtual bool IsGigaAM() const { return false; } | ||
| 69 | }; | 73 | }; |
| 70 | 74 | ||
| 71 | } // namespace sherpa_onnx | 75 | } // namespace sherpa_onnx |
| @@ -72,6 +72,8 @@ class OfflineNemoEncDecCtcModel::Impl { | @@ -72,6 +72,8 @@ class OfflineNemoEncDecCtcModel::Impl { | ||
| 72 | 72 | ||
| 73 | std::string FeatureNormalizationMethod() const { return normalize_type_; } | 73 | std::string FeatureNormalizationMethod() const { return normalize_type_; } |
| 74 | 74 | ||
| 75 | + bool IsGigaAM() const { return is_giga_am_; } | ||
| 76 | + | ||
| 75 | private: | 77 | private: |
| 76 | void Init(void *model_data, size_t model_data_length) { | 78 | void Init(void *model_data, size_t model_data_length) { |
| 77 | sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | 79 | sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, |
| @@ -92,7 +94,9 @@ class OfflineNemoEncDecCtcModel::Impl { | @@ -92,7 +94,9 @@ class OfflineNemoEncDecCtcModel::Impl { | ||
| 92 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | 94 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below |
| 93 | SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); | 95 | SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); |
| 94 | SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); | 96 | SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); |
| 95 | - SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type"); | 97 | + SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(normalize_type_, |
| 98 | + "normalize_type"); | ||
| 99 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0); | ||
| 96 | } | 100 | } |
| 97 | 101 | ||
| 98 | private: | 102 | private: |
| @@ -112,6 +116,10 @@ class OfflineNemoEncDecCtcModel::Impl { | @@ -112,6 +116,10 @@ class OfflineNemoEncDecCtcModel::Impl { | ||
| 112 | int32_t vocab_size_ = 0; | 116 | int32_t vocab_size_ = 0; |
| 113 | int32_t subsampling_factor_ = 0; | 117 | int32_t subsampling_factor_ = 0; |
| 114 | std::string normalize_type_; | 118 | std::string normalize_type_; |
| 119 | + | ||
| 120 | + // it is 1 for models from | ||
| 121 | + // https://github.com/salute-developers/GigaAM | ||
| 122 | + int32_t is_giga_am_ = 0; | ||
| 115 | }; | 123 | }; |
| 116 | 124 | ||
| 117 | OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel( | 125 | OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel( |
| @@ -146,4 +154,6 @@ std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const { | @@ -146,4 +154,6 @@ std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const { | ||
| 146 | return impl_->FeatureNormalizationMethod(); | 154 | return impl_->FeatureNormalizationMethod(); |
| 147 | } | 155 | } |
| 148 | 156 | ||
| 157 | +bool OfflineNemoEncDecCtcModel::IsGigaAM() const { return impl_->IsGigaAM(); } | ||
| 158 | + | ||
| 149 | } // namespace sherpa_onnx | 159 | } // namespace sherpa_onnx |
| @@ -76,6 +76,8 @@ class OfflineNemoEncDecCtcModel : public OfflineCtcModel { | @@ -76,6 +76,8 @@ class OfflineNemoEncDecCtcModel : public OfflineCtcModel { | ||
| 76 | // for details | 76 | // for details |
| 77 | std::string FeatureNormalizationMethod() const override; | 77 | std::string FeatureNormalizationMethod() const override; |
| 78 | 78 | ||
| 79 | + bool IsGigaAM() const override; | ||
| 80 | + | ||
| 79 | private: | 81 | private: |
| 80 | class Impl; | 82 | class Impl; |
| 81 | std::unique_ptr<Impl> impl_; | 83 | std::unique_ptr<Impl> impl_; |
| @@ -104,11 +104,20 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -104,11 +104,20 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 104 | } | 104 | } |
| 105 | 105 | ||
| 106 | if (!config_.model_config.nemo_ctc.model.empty()) { | 106 | if (!config_.model_config.nemo_ctc.model.empty()) { |
| 107 | - config_.feat_config.low_freq = 0; | ||
| 108 | - config_.feat_config.high_freq = 0; | ||
| 109 | - config_.feat_config.is_librosa = true; | ||
| 110 | - config_.feat_config.remove_dc_offset = false; | ||
| 111 | - config_.feat_config.window_type = "hann"; | 107 | + if (model_->IsGigaAM()) { |
| 108 | + config_.feat_config.low_freq = 0; | ||
| 109 | + config_.feat_config.high_freq = 8000; | ||
| 110 | + config_.feat_config.remove_dc_offset = false; | ||
| 111 | + config_.feat_config.preemph_coeff = 0; | ||
| 112 | + config_.feat_config.window_type = "hann"; | ||
| 113 | + config_.feat_config.feature_dim = 64; | ||
| 114 | + } else { | ||
| 115 | + config_.feat_config.low_freq = 0; | ||
| 116 | + config_.feat_config.high_freq = 0; | ||
| 117 | + config_.feat_config.is_librosa = true; | ||
| 118 | + config_.feat_config.remove_dc_offset = false; | ||
| 119 | + config_.feat_config.window_type = "hann"; | ||
| 120 | + } | ||
| 112 | } | 121 | } |
| 113 | 122 | ||
| 114 | if (!config_.model_config.wenet_ctc.model.empty()) { | 123 | if (!config_.model_config.wenet_ctc.model.empty()) { |
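These GigaAM-specific feature settings mirror create_fbank() in scripts/nemo/GigaAM/test-onnx-ctc.py above: 64 mel bins, a 0-8000 Hz mel range, a Hann window, no DC-offset removal, and no pre-emphasis.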
| @@ -172,7 +172,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -172,7 +172,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 172 | return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config); | 172 | return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config); |
| 173 | } | 173 | } |
| 174 | 174 | ||
| 175 | - if (model_type == "EncDecCTCModelBPE" || | 175 | + if (model_type == "EncDecCTCModelBPE" || model_type == "EncDecCTCModel" || |
| 176 | model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || | 176 | model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || |
| 177 | model_type == "zipformer2_ctc" || model_type == "wenet_ctc" || | 177 | model_type == "zipformer2_ctc" || model_type == "wenet_ctc" || |
| 178 | model_type == "telespeech_ctc") { | 178 | model_type == "telespeech_ctc") { |
| @@ -189,6 +189,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -189,6 +189,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 189 | " - Non-streaming transducer models from icefall\n" | 189 | " - Non-streaming transducer models from icefall\n" |
| 190 | " - Non-streaming Paraformer models from FunASR\n" | 190 | " - Non-streaming Paraformer models from FunASR\n" |
| 191 | " - EncDecCTCModelBPE models from NeMo\n" | 191 | " - EncDecCTCModelBPE models from NeMo\n" |
| 192 | + " - EncDecCTCModel models from NeMo\n" | ||
| 192 | " - EncDecHybridRNNTCTCBPEModel models from NeMo\n" | 193 | " - EncDecHybridRNNTCTCBPEModel models from NeMo\n" |
| 193 | " - Whisper models\n" | 194 | " - Whisper models\n" |
| 194 | " - Tdnn models\n" | 195 | " - Tdnn models\n" |
| @@ -343,7 +344,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -343,7 +344,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 343 | return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config); | 344 | return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config); |
| 344 | } | 345 | } |
| 345 | 346 | ||
| 346 | - if (model_type == "EncDecCTCModelBPE" || | 347 | + if (model_type == "EncDecCTCModelBPE" || model_type == "EncDecCTCModel" || |
| 347 | model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || | 348 | model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || |
| 348 | model_type == "zipformer2_ctc" || model_type == "wenet_ctc" || | 349 | model_type == "zipformer2_ctc" || model_type == "wenet_ctc" || |
| 349 | model_type == "telespeech_ctc") { | 350 | model_type == "telespeech_ctc") { |
| @@ -360,6 +361,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -360,6 +361,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 360 | " - Non-streaming transducer models from icefall\n" | 361 | " - Non-streaming transducer models from icefall\n" |
| 361 | " - Non-streaming Paraformer models from FunASR\n" | 362 | " - Non-streaming Paraformer models from FunASR\n" |
| 362 | " - EncDecCTCModelBPE models from NeMo\n" | 363 | " - EncDecCTCModelBPE models from NeMo\n" |
| 364 | + " - EncDecCTCModel models from NeMo\n" | ||
| 363 | " - EncDecHybridRNNTCTCBPEModel models from NeMo\n" | 365 | " - EncDecHybridRNNTCTCBPEModel models from NeMo\n" |
| 364 | " - Whisper models\n" | 366 | " - Whisper models\n" |
| 365 | " - Tdnn models\n" | 367 | " - Tdnn models\n" |
| @@ -7,6 +7,8 @@ | @@ -7,6 +7,8 @@ | ||
| 7 | #include <cassert> | 7 | #include <cassert> |
| 8 | #include <fstream> | 8 | #include <fstream> |
| 9 | #include <sstream> | 9 | #include <sstream> |
| 10 | +#include <string> | ||
| 11 | +#include <utility> | ||
| 10 | 12 | ||
| 11 | #if __ANDROID_API__ >= 9 | 13 | #if __ANDROID_API__ >= 9 |
| 12 | #include <strstream> | 14 | #include <strstream> |
| @@ -16,10 +18,54 @@ | @@ -16,10 +18,54 @@ | ||
| 16 | #endif | 18 | #endif |
| 17 | 19 | ||
| 18 | #include "sherpa-onnx/csrc/base64-decode.h" | 20 | #include "sherpa-onnx/csrc/base64-decode.h" |
| 21 | +#include "sherpa-onnx/csrc/lexicon.h" | ||
| 19 | #include "sherpa-onnx/csrc/onnx-utils.h" | 22 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 20 | 23 | ||
| 21 | namespace sherpa_onnx { | 24 | namespace sherpa_onnx { |
| 22 | 25 | ||
| 26 | +std::unordered_map<std::string, int32_t> ReadTokens( | ||
| 27 | + std::istream &is, | ||
| 28 | + std::unordered_map<int32_t, std::string> *id2token /*= nullptr*/) { | ||
| 29 | + std::unordered_map<std::string, int32_t> token2id; | ||
| 30 | + | ||
| 31 | + std::string line; | ||
| 32 | + | ||
| 33 | + std::string sym; | ||
| 34 | + int32_t id = -1; | ||
| 35 | + while (std::getline(is, line)) { | ||
| 36 | + std::istringstream iss(line); | ||
| 37 | + iss >> sym; | ||
| 38 | + if (iss.eof()) { | ||
| 39 | + id = atoi(sym.c_str()); | ||
| 40 | + sym = " "; | ||
| 41 | + } else { | ||
| 42 | + iss >> id; | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + // eat the trailing \r\n on windows | ||
| 46 | + iss >> std::ws; | ||
| 47 | + if (!iss.eof()) { | ||
| 48 | + SHERPA_ONNX_LOGE("Error: %s", line.c_str()); | ||
| 49 | + exit(-1); | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | +#if 0 | ||
| 53 | + if (token2id.count(sym)) { | ||
| 54 | + SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", | ||
| 55 | + sym.c_str(), line.c_str(), token2id.at(sym)); | ||
| 56 | + exit(-1); | ||
| 57 | + } | ||
| 58 | +#endif | ||
| 59 | + if (id2token) { | ||
| 60 | + id2token->insert({id, sym}); | ||
| 61 | + } | ||
| 62 | + | ||
| 63 | + token2id.insert({std::move(sym), id}); | ||
| 64 | + } | ||
| 65 | + | ||
| 66 | + return token2id; | ||
| 67 | +} | ||
| 68 | + | ||
| 23 | SymbolTable::SymbolTable(const std::string &filename, bool is_file) { | 69 | SymbolTable::SymbolTable(const std::string &filename, bool is_file) { |
| 24 | if (is_file) { | 70 | if (is_file) { |
| 25 | std::ifstream is(filename); | 71 | std::ifstream is(filename); |
| @@ -39,25 +85,7 @@ SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) { | @@ -39,25 +85,7 @@ SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) { | ||
| 39 | } | 85 | } |
| 40 | #endif | 86 | #endif |
| 41 | 87 | ||
| 42 | -void SymbolTable::Init(std::istream &is) { | ||
| 43 | - std::string sym; | ||
| 44 | - int32_t id = 0; | ||
| 45 | - while (is >> sym >> id) { | ||
| 46 | -#if 0 | ||
| 47 | - // we disable the test here since for some multi-lingual BPE models | ||
| 48 | - // from NeMo, the same symbol can appear multiple times with different IDs. | ||
| 49 | - if (sym != " ") { | ||
| 50 | - assert(sym2id_.count(sym) == 0); | ||
| 51 | - } | ||
| 52 | -#endif | ||
| 53 | - | ||
| 54 | - assert(id2sym_.count(id) == 0); | ||
| 55 | - | ||
| 56 | - sym2id_.insert({sym, id}); | ||
| 57 | - id2sym_.insert({id, sym}); | ||
| 58 | - } | ||
| 59 | - assert(is.eof()); | ||
| 60 | -} | 88 | +void SymbolTable::Init(std::istream &is) { sym2id_ = ReadTokens(is, &id2sym_); } |
| 61 | 89 | ||
| 62 | std::string SymbolTable::ToString() const { | 90 | std::string SymbolTable::ToString() const { |
| 63 | std::ostringstream os; | 91 | std::ostringstream os; |
| @@ -5,8 +5,10 @@ | @@ -5,8 +5,10 @@ | ||
| 5 | #ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ |
| 6 | #define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ | 6 | #define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ |
| 7 | 7 | ||
| 8 | +#include <istream> | ||
| 8 | #include <string> | 9 | #include <string> |
| 9 | #include <unordered_map> | 10 | #include <unordered_map> |
| 11 | +#include <vector> | ||
| 10 | 12 | ||
| 11 | #if __ANDROID_API__ >= 9 | 13 | #if __ANDROID_API__ >= 9 |
| 12 | #include "android/asset_manager.h" | 14 | #include "android/asset_manager.h" |
| @@ -15,6 +17,16 @@ | @@ -15,6 +17,16 @@ | ||
| 15 | 17 | ||
| 16 | namespace sherpa_onnx { | 18 | namespace sherpa_onnx { |
| 17 | 19 | ||
| 20 | +// The same token can be mapped to different integer IDs, so | ||
| 21 | +// we need an id2token argument here. | ||
| 22 | +std::unordered_map<std::string, int32_t> ReadTokens( | ||
| 23 | + std::istream &is, | ||
| 24 | + std::unordered_map<int32_t, std::string> *id2token = nullptr); | ||
| 25 | + | ||
| 26 | +std::vector<int32_t> ConvertTokensToIds( | ||
| 27 | + const std::unordered_map<std::string, int32_t> &token2id, | ||
| 28 | + const std::vector<std::string> &tokens); | ||
| 29 | + | ||
| 18 | /// It manages mapping between symbols and integer IDs. | 30 | /// It manages mapping between symbols and integer IDs. |
| 19 | class SymbolTable { | 31 | class SymbolTable { |
| 20 | public: | 32 | public: |
| @@ -394,6 +394,16 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? { | @@ -394,6 +394,16 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? { | ||
| 394 | modelType = "transducer", | 394 | modelType = "transducer", |
| 395 | ) | 395 | ) |
| 396 | } | 396 | } |
| 397 | + | ||
| 398 | + 19 -> { | ||
| 399 | + val modelDir = "sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24" | ||
| 400 | + return OfflineModelConfig( | ||
| 401 | + nemo = OfflineNemoEncDecCtcModelConfig( | ||
| 402 | + model = "$modelDir/model.int8.onnx", | ||
| 403 | + ), | ||
| 404 | + tokens = "$modelDir/tokens.txt", | ||
| 405 | + ) | ||
| 406 | + } | ||
| 397 | } | 407 | } |
| 398 | return null | 408 | return null |
| 399 | } | 409 | } |