Fangjun Kuang
Committed by GitHub

Support GigaAM CTC models for Russian ASR (#1464)

See also https://github.com/salute-developers/GigaAM
@@ -16,6 +16,21 @@ echo "PATH: $PATH" @@ -16,6 +16,21 @@ echo "PATH: $PATH"
16 which $EXE 16 which $EXE
17 17
18 log "------------------------------------------------------------" 18 log "------------------------------------------------------------"
  19 +log "Run NeMo GigaAM Russian models"
  20 +log "------------------------------------------------------------"
  21 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2
  22 +tar xvf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2
  23 +rm sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2
  24 +
  25 +$EXE \
  26 + --nemo-ctc-model=./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/model.int8.onnx \
  27 + --tokens=./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/tokens.txt \
  28 + --debug=1 \
  29 + ./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/test_wavs/example.wav
  30 +
  31 +rm -rf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24
  32 +
  33 +log "------------------------------------------------------------"
19 log "Run SenseVoice models" 34 log "Run SenseVoice models"
20 log "------------------------------------------------------------" 35 log "------------------------------------------------------------"
21 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 36 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
# Manually triggered workflow that exports the GigaAM CTC model to ONNX
# (via scripts/nemo/GigaAM/run-ctc.sh), packs the result into a tarball,
# uploads it to the "asr-models" release and mirrors it to Hugging Face.
name: export-nemo-giga-am-to-onnx

on:
  workflow_dispatch:

concurrency:
  group: export-nemo-giga-am-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-nemo-am-giga-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export nemo GigaAM models to ONNX
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run CTC
        shell: bash
        run: |
          pushd scripts/nemo/GigaAM
          ./run-ctc.sh
          popd

          # Collect everything that ships with the model into $d.
          # The float32 model.onnx is deleted; only *.int8.onnx is released.
          d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24
          mkdir $d
          mkdir $d/test_wavs
          rm scripts/nemo/GigaAM/model.onnx
          mv -v scripts/nemo/GigaAM/*.int8.onnx $d/
          mv -v scripts/nemo/GigaAM/*.md $d/
          mv -v scripts/nemo/GigaAM/*.pdf $d/
          mv -v scripts/nemo/GigaAM/tokens.txt $d/
          mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
          mv -v scripts/nemo/GigaAM/run-ctc.sh $d/
          mv -v scripts/nemo/GigaAM/*-ctc.py $d/

          ls -lh scripts/nemo/GigaAM/

          ls -lh $d

          tar cjvf ${d}.tar.bz2 $d

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models

      - name: Publish to huggingface (CTC)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24
            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            # This command is re-run on failure, so every attempt must start
            # from a clean state: remove a clone left by a previous attempt
            # and copy (not move) the files so $d stays intact for retries.
            rm -rf huggingface
            git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
            cp -av $d/* ./huggingface
            cd huggingface
            git lfs track "*.onnx"
            git lfs track "*.wav"
            git status
            git add .
            git status
            # Commit only when something is staged; an unchanged model must
            # not turn into 20 failing "nothing to commit" attempts.
            git diff --cached --quiet || git commit -m "add models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
@@ -149,6 +149,16 @@ jobs: @@ -149,6 +149,16 @@ jobs:
149 name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} 149 name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
150 path: install/* 150 path: install/*
151 151
  152 + - name: Test offline CTC
  153 + shell: bash
  154 + run: |
  155 + du -h -d1 .
  156 + export PATH=$PWD/build/bin:$PATH
  157 + export EXE=sherpa-onnx-offline
  158 +
  159 + .github/scripts/test-offline-ctc.sh
  160 + du -h -d1 .
  161 +
152 - name: Test C++ API 162 - name: Test C++ API
153 shell: bash 163 shell: bash
154 run: | 164 run: |
@@ -180,16 +190,6 @@ jobs: @@ -180,16 +190,6 @@ jobs:
180 .github/scripts/test-offline-transducer.sh 190 .github/scripts/test-offline-transducer.sh
181 du -h -d1 . 191 du -h -d1 .
182 192
183 - - name: Test offline CTC  
184 - shell: bash  
185 - run: |  
186 - du -h -d1 .  
187 - export PATH=$PWD/build/bin:$PATH  
188 - export EXE=sherpa-onnx-offline  
189 -  
190 - .github/scripts/test-offline-ctc.sh  
191 - du -h -d1 .  
192 -  
193 - name: Test online punctuation 193 - name: Test online punctuation
194 shell: bash 194 shell: bash
195 run: | 195 run: |
@@ -336,6 +336,24 @@ def get_models(): @@ -336,6 +336,24 @@ def get_models():
336 popd 336 popd
337 """, 337 """,
338 ), 338 ),
  339 + Model(
  340 + model_name="sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24",
  341 + idx=19,
  342 + lang="ru",
  343 + short_name="nemo_ctc_giga_am",
  344 + cmd="""
  345 + pushd $model_name
  346 +
  347 + rm -rfv test_wavs
  348 +
  349 + rm -fv *.sh
  350 + rm -fv *.py
  351 +
  352 + ls -lh
  353 +
  354 + popd
  355 + """,
  356 + ),
339 ] 357 ]
340 return models 358 return models
341 359
  1 +# Introduction
  2 +
  3 +This folder contains scripts for converting models from
  4 +https://github.com/salute-developers/GigaAM
  5 +to sherpa-onnx.
  6 +
  7 +The ASR models in this folder are for Russian speech recognition.
  8 +
  9 +Please see the license of the models at
  10 +https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf
  1 +#!/usr/bin/env python3
  2 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +from typing import Dict
  4 +
  5 +import onnx
  6 +import torch
  7 +import torchaudio
  8 +from nemo.collections.asr.models import EncDecCTCModel
  9 +from nemo.collections.asr.modules.audio_preprocessing import (
  10 + AudioToMelSpectrogramPreprocessor as NeMoAudioToMelSpectrogramPreprocessor,
  11 +)
  12 +from nemo.collections.asr.parts.preprocessing.features import (
  13 + FilterbankFeaturesTA as NeMoFilterbankFeaturesTA,
  14 +)
  15 +from onnxruntime.quantization import QuantType, quantize_dynamic
  16 +
  17 +
class FilterbankFeaturesTA(NeMoFilterbankFeaturesTA):
    """NeMo ``FilterbankFeaturesTA`` with a configurable mel scale.

    After the parent constructor has run, ``_mel_spec_extractor`` is assigned
    a ``torchaudio.transforms.MelSpectrogram`` built with ``mel_scale`` and
    ``wkwargs``.
    """

    def __init__(self, mel_scale: str = "htk", wkwargs=None, **kwargs):
        """
        Args:
          mel_scale:
            Passed to ``torchaudio.transforms.MelSpectrogram`` as ``mel_scale``.
          wkwargs:
            Passed to ``torchaudio.transforms.MelSpectrogram`` as ``wkwargs``.
          kwargs:
            Forwarded to the parent constructor. ``nfilt``, ``window``,
            ``mel_norm`` and ``n_fft`` must be present; ``highfreq`` and
            ``lowfreq`` are optional.
        """
        # These two keys are removed so the parent constructor never sees them.
        if "window_size" in kwargs:
            del kwargs["window_size"]
        if "window_stride" in kwargs:
            del kwargs["window_stride"]

        super().__init__(**kwargs)

        # NOTE(review): _sample_rate, win_length, hop_length and torch_windows
        # are presumably set by the parent constructor -- confirm against NeMo.
        self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = (
            torchaudio.transforms.MelSpectrogram(
                sample_rate=self._sample_rate,
                win_length=self.win_length,
                hop_length=self.hop_length,
                n_mels=kwargs["nfilt"],
                window_fn=self.torch_windows[kwargs["window"]],
                mel_scale=mel_scale,
                norm=kwargs["mel_norm"],
                n_fft=kwargs["n_fft"],
                f_max=kwargs.get("highfreq", None),
                f_min=kwargs.get("lowfreq", 0),
                wkwargs=wkwargs,
            )
        )
  42 +
  43 +
class AudioToMelSpectrogramPreprocessor(NeMoAudioToMelSpectrogramPreprocessor):
    """NeMo preprocessor whose ``featurizer`` is the local ``FilterbankFeaturesTA``."""

    def __init__(self, mel_scale: str = "htk", **kwargs):
        """
        Args:
          mel_scale:
            Forwarded to ``FilterbankFeaturesTA``.
          kwargs:
            Forwarded unchanged to the parent constructor; must contain
            ``features``.
        """
        super().__init__(**kwargs)
        # FilterbankFeaturesTA reads the number of mel bins from "nfilt",
        # so the "features" entry is renamed before forwarding.
        kwargs["nfilt"] = kwargs["features"]
        del kwargs["features"]
        # Overwrite the featurizer after the parent constructor has run.
        self.featurizer = (
            FilterbankFeaturesTA(  # Deprecated arguments; kept for config compatibility
                mel_scale=mel_scale,
                **kwargs,
            )
        )
  55 +
  56 +
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Replace the metadata of an ONNX model. The file is rewritten in place.

    Args:
      filename:
        Path of the ONNX model to update.
      meta_data:
        Key-value pairs to store. Each value is converted with ``str()``.
    """
    onnx_model = onnx.load(filename)

    # Discard every existing entry so only the supplied pairs remain.
    for _ in range(len(onnx_model.metadata_props)):
        onnx_model.metadata_props.pop()

    for name in meta_data:
        entry = onnx_model.metadata_props.add()
        entry.key = name
        entry.value = str(meta_data[name])

    onnx.save(onnx_model, filename)
  76 +
  77 +
def main():
    """Export the GigaAM CTC checkpoint to ONNX.

    Reads ``./ctc_model_config.yaml`` and ``./ctc_model_weights.ckpt`` and
    writes three files to the current directory:

      - ``tokens.txt``: one ``<token> <id>`` line per label, followed by
        ``<blk>`` as the last id.
      - ``model.onnx``: the exported model with metadata attached.
      - ``model.int8.onnx``: the dynamically quantized model.
    """
    model = EncDecCTCModel.from_config_file("./ctc_model_config.yaml")
    ckpt = torch.load("./ctc_model_weights.ckpt", map_location="cpu")
    model.load_state_dict(ckpt, strict=False)
    model.eval()

    labels = list(model.cfg.labels)
    # The blank token takes the id right after the last label. It is derived
    # from the label count rather than from the loop variable, so it is
    # defined even when there are no labels and always matches vocab_size.
    blank_id = len(labels)

    with open("tokens.txt", "w", encoding="utf-8") as f:
        for i, t in enumerate(labels):
            f.write(f"{t} {i}\n")
        f.write(f"<blk> {blank_id}\n")

    filename = "model.onnx"
    model.export(filename)

    meta_data = {
        "vocab_size": blank_id + 1,
        "normalize_type": "",
        "subsampling_factor": 4,
        "model_type": "EncDecCTCModel",
        "version": "1",
        "model_author": "https://github.com/salute-developers/GigaAM",
        "license": "https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf",
        "language": "Russian",
        "is_giga_am": 1,
    }
    add_meta_data(filename, meta_data)

    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )
  111 +
  112 +
  113 +if __name__ == "__main__":
  114 + main()
  1 +#!/usr/bin/env bash
  2 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +set -ex
  5 +
# Install pip, PyTorch, NeMo (from git) and the Python packages that the
# export and test scripts import.
function install_nemo() {
  curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
  python3 get-pip.py

  pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html

  # The version specifier must be quoted. Unquoted, the shell treats
  # `>=3.3.2` as a redirection: pip's output goes to a file named "=3.3.2"
  # and matplotlib is installed without any version bound.
  pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile librosa
  pip install -qq ipython

  # sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython

  BRANCH='main'
  # Quoted so that "[asr]" is never subject to pathname expansion.
  python3 -m pip install "git+https://github.com/NVIDIA/NeMo.git@${BRANCH}#egg=nemo_toolkit[asr]"

  # Installed last -- presumably to override whichever numpy version the
  # installs above pulled in.
  pip install numpy==1.26.4
}
  22 +
# Download the inputs of the export into the current directory: the CTC
# checkpoint and its config, two example wave files, and the license PDF.
# The script runs with `set -e`, so a failing curl command aborts it.
function download_files() {
  curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/ctc_model_weights.ckpt
  curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/ctc_model_config.yaml
  curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/example.wav
  curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/long_example.wav
  # The license PDF is fetched from a Hugging Face mirror.
  curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM%20License_NC.pdf
}
  30 +
  31 +install_nemo
  32 +download_files
  33 +
  34 +python3 ./export-onnx-ctc.py
  35 +ls -lh
  36 +python3 ./test-onnx-ctc.py
  1 +#!/usr/bin/env python3
  2 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +# https://github.com/salute-developers/GigaAM
  5 +
  6 +import kaldi_native_fbank as knf
  7 +import librosa
  8 +import numpy as np
  9 +import onnxruntime as ort
  10 +import soundfile as sf
  11 +import torch
  12 +
  13 +
def create_fbank():
    """Build the kaldi-native-fbank extractor used to test the exported model.

    Returns:
      A ``knf.OnlineFbank`` producing 64 mel bins over 0-8000 Hz with a
      Hann window and with dithering, DC-offset removal and pre-emphasis
      turned off.
    """
    cfg = knf.FbankOptions()

    # Mel filterbank layout.
    cfg.mel_opts.num_bins = 64
    cfg.mel_opts.low_freq = 0
    cfg.mel_opts.high_freq = 8000

    # Framing: no dithering, no DC-offset removal, no pre-emphasis.
    cfg.frame_opts.window_type = "hann"
    cfg.frame_opts.preemph_coeff = 0
    cfg.frame_opts.remove_dc_offset = False
    cfg.frame_opts.dither = 0

    # GigaAM uses an FFT size of 400, but kaldi-native-fbank only supports
    # power-of-two FFT sizes, so 512 is used here.
    cfg.frame_opts.round_to_power_of_two = True

    return knf.OnlineFbank(cfg)
  31 +
  32 +
def compute_features(audio, fbank) -> np.ndarray:
    """Feed a waveform to the extractor and collect every ready frame.

    Args:
      audio: (num_samples,), np.float32
      fbank: the fbank extractor
    Returns:
      features: (num_frames, feat_dim), np.float32
    """
    assert len(audio.shape) == 1, audio.shape
    fbank.accept_waveform(16000, audio)

    # len(frames) is the index of the next frame to fetch; the loop ends once
    # it catches up with the number of frames the extractor reports as ready.
    frames = []
    while len(frames) < fbank.num_frames_ready:
        frames.append(np.array(fbank.get_frame(len(frames))))

    return np.stack(frames)
  50 +
  51 +
def display(sess):
    """Print the input nodes and then the output nodes of a session.

    Args:
      sess: an object providing ``get_inputs()`` and ``get_outputs()``.
    """
    sections = (
        ("==========Input==========", sess.get_inputs),
        ("==========Output==========", sess.get_outputs),
    )
    for header, list_nodes in sections:
        print(header)
        for node in list_nodes():
            print(node)
  59 +
  60 +
  61 +"""
  62 +==========Input==========
  63 +NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 64, 'audio_signal_dynamic_axes_2'])
  64 +NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1'])
  65 +==========Output==========
  66 +NodeArg(name='logprobs', type='tensor(float)', shape=['logprobs_dynamic_axes_1', 'logprobs_dynamic_axes_2', 34])
  67 +"""
  68 +
  69 +
class OnnxModel:
    """Wrapper around an onnxruntime session for the exported CTC model."""

    def __init__(
        self,
        filename: str,
    ):
        """Create a single-threaded CPU session and print its input/output nodes.

        Args:
          filename:
            Path to the ONNX model file.
        """
        opts = ort.SessionOptions()
        opts.intra_op_num_threads = 1
        opts.inter_op_num_threads = 1

        self.model = ort.InferenceSession(
            filename,
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )
        display(self.model)

    def __call__(self, x: np.ndarray):
        """Run the model on the features of a single utterance.

        Args:
          x: features of shape (T, C).
        Returns:
          The first output of the session, of shape [batch_size, T, dim].
        """
        # (T, C) -> (1, C, T)
        feats = torch.from_numpy(x).t().unsqueeze(0)
        num_frames = torch.tensor([feats.shape[-1]], dtype=torch.int64)

        output_name = self.model.get_outputs()[0].name
        inputs = self.model.get_inputs()
        outputs = self.model.run(
            [output_name],
            {
                inputs[0].name: feats.numpy(),
                inputs[1].name: num_frames.numpy(),
            },
        )
        return outputs[0]
  104 +
  105 +
def main():
    """Decode ./example.wav with ./model.int8.onnx and print the greedy CTC text."""
    model_path = "./model.int8.onnx"
    tokens_path = "./tokens.txt"
    wav = "./example.wav"

    model = OnnxModel(model_path)

    # Each line is "<token> <id>". A line with a single field is the token
    # that is itself a blank character, which split() swallows.
    id2token = {}
    with open(tokens_path, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if len(fields) != 1:
                sym, sym_id = fields
                id2token[int(sym_id)] = sym
            else:
                id2token[int(fields[0])] = " "

    fbank = create_fbank()
    samples, sample_rate = sf.read(wav, dtype="float32", always_2d=True)
    samples = samples[:, 0]  # only use the first channel
    if sample_rate != 16000:
        samples = librosa.resample(
            samples,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    features = compute_features(samples, fbank)
    print("features.shape", features.shape)

    log_probs = model(features)
    print("log_probs", log_probs.shape)

    # Greedy CTC decoding: take the best id of every frame, then drop blanks
    # and ids that repeat the id of the previous frame.
    blank = len(id2token) - 1
    best = torch.argmax(torch.from_numpy(log_probs)[0], dim=1).tolist()
    kept = [
        cur
        for before, cur in zip([-1] + best[:-1], best)
        if cur != blank and cur != before
    ]

    text = "".join(id2token[k] for k in kept)
    print(wav)
    print(text)
  154 +
  155 +
  156 +if __name__ == "__main__":
  157 + main()
@@ -193,6 +193,7 @@ class FeatureExtractor::Impl { @@ -193,6 +193,7 @@ class FeatureExtractor::Impl {
193 opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms; 193 opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
194 opts_.frame_opts.frame_length_ms = config_.frame_length_ms; 194 opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
195 opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset; 195 opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
  196 + opts_.frame_opts.preemph_coeff = config_.preemph_coeff;
196 opts_.frame_opts.window_type = config_.window_type; 197 opts_.frame_opts.window_type = config_.window_type;
197 198
198 opts_.mel_opts.num_bins = config_.feature_dim; 199 opts_.mel_opts.num_bins = config_.feature_dim;
@@ -211,6 +212,7 @@ class FeatureExtractor::Impl { @@ -211,6 +212,7 @@ class FeatureExtractor::Impl {
211 mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms; 212 mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
212 mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms; 213 mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
213 mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset; 214 mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
  215 + mfcc_opts_.frame_opts.preemph_coeff = config_.preemph_coeff;
214 mfcc_opts_.frame_opts.window_type = config_.window_type; 216 mfcc_opts_.frame_opts.window_type = config_.window_type;
215 217
216 mfcc_opts_.mel_opts.num_bins = config_.feature_dim; 218 mfcc_opts_.mel_opts.num_bins = config_.feature_dim;
@@ -57,6 +57,7 @@ struct FeatureExtractorConfig { @@ -57,6 +57,7 @@ struct FeatureExtractorConfig {
57 float frame_length_ms = 25.0f; // in milliseconds. 57 float frame_length_ms = 25.0f; // in milliseconds.
58 bool is_librosa = false; 58 bool is_librosa = false;
59 bool remove_dc_offset = true; // Subtract mean of wave before FFT. 59 bool remove_dc_offset = true; // Subtract mean of wave before FFT.
  60 + float preemph_coeff = 0.97f; // Preemphasis coefficient.
60 std::string window_type = "povey"; // e.g. Hamming window 61 std::string window_type = "povey"; // e.g. Hamming window
61 62
62 // For models from NeMo 63 // For models from NeMo
@@ -10,8 +10,8 @@ @@ -10,8 +10,8 @@
10 10
11 #include "cppjieba/Jieba.hpp" 11 #include "cppjieba/Jieba.hpp"
12 #include "sherpa-onnx/csrc/file-utils.h" 12 #include "sherpa-onnx/csrc/file-utils.h"
13 -#include "sherpa-onnx/csrc/lexicon.h"  
14 #include "sherpa-onnx/csrc/macros.h" 13 #include "sherpa-onnx/csrc/macros.h"
  14 +#include "sherpa-onnx/csrc/symbol-table.h"
15 #include "sherpa-onnx/csrc/text-utils.h" 15 #include "sherpa-onnx/csrc/text-utils.h"
16 16
17 namespace sherpa_onnx { 17 namespace sherpa_onnx {
@@ -21,6 +21,7 @@ @@ -21,6 +21,7 @@
21 21
22 #include "sherpa-onnx/csrc/macros.h" 22 #include "sherpa-onnx/csrc/macros.h"
23 #include "sherpa-onnx/csrc/onnx-utils.h" 23 #include "sherpa-onnx/csrc/onnx-utils.h"
  24 +#include "sherpa-onnx/csrc/symbol-table.h"
24 #include "sherpa-onnx/csrc/text-utils.h" 25 #include "sherpa-onnx/csrc/text-utils.h"
25 26
26 namespace sherpa_onnx { 27 namespace sherpa_onnx {
@@ -74,45 +75,6 @@ static std::vector<std::string> ProcessHeteronyms( @@ -74,45 +75,6 @@ static std::vector<std::string> ProcessHeteronyms(
74 return ans; 75 return ans;
75 } 76 }
76 77
77 -// Note: We don't use SymbolTable here since tokens may contain a blank  
78 -// in the first column  
79 -std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) {  
80 - std::unordered_map<std::string, int32_t> token2id;  
81 -  
82 - std::string line;  
83 -  
84 - std::string sym;  
85 - int32_t id = -1;  
86 - while (std::getline(is, line)) {  
87 - std::istringstream iss(line);  
88 - iss >> sym;  
89 - if (iss.eof()) {  
90 - id = atoi(sym.c_str());  
91 - sym = " ";  
92 - } else {  
93 - iss >> id;  
94 - }  
95 -  
96 - // eat the trailing \r\n on windows  
97 - iss >> std::ws;  
98 - if (!iss.eof()) {  
99 - SHERPA_ONNX_LOGE("Error: %s", line.c_str());  
100 - exit(-1);  
101 - }  
102 -  
103 -#if 0  
104 - if (token2id.count(sym)) {  
105 - SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d",  
106 - sym.c_str(), line.c_str(), token2id.at(sym));  
107 - exit(-1);  
108 - }  
109 -#endif  
110 - token2id.insert({std::move(sym), id});  
111 - }  
112 -  
113 - return token2id;  
114 -}  
115 -  
116 std::vector<int32_t> ConvertTokensToIds( 78 std::vector<int32_t> ConvertTokensToIds(
117 const std::unordered_map<std::string, int32_t> &token2id, 79 const std::unordered_map<std::string, int32_t> &token2id,
118 const std::vector<std::string> &tokens) { 80 const std::vector<std::string> &tokens) {
@@ -67,12 +67,6 @@ class Lexicon : public OfflineTtsFrontend { @@ -67,12 +67,6 @@ class Lexicon : public OfflineTtsFrontend {
67 bool debug_ = false; 67 bool debug_ = false;
68 }; 68 };
69 69
70 -std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is);  
71 -  
72 -std::vector<int32_t> ConvertTokensToIds(  
73 - const std::unordered_map<std::string, int32_t> &token2id,  
74 - const std::vector<std::string> &tokens);  
75 -  
76 } // namespace sherpa_onnx 70 } // namespace sherpa_onnx
77 71
78 #endif // SHERPA_ONNX_CSRC_LEXICON_H_ 72 #endif // SHERPA_ONNX_CSRC_LEXICON_H_
@@ -41,13 +41,13 @@ @@ -41,13 +41,13 @@
41 auto value = \ 41 auto value = \
42 meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ 42 meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
43 if (!value) { \ 43 if (!value) { \
44 - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ 44 + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
45 exit(-1); \ 45 exit(-1); \
46 } \ 46 } \
47 \ 47 \
48 dst = atoi(value.get()); \ 48 dst = atoi(value.get()); \
49 if (dst < 0) { \ 49 if (dst < 0) { \
50 - SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \ 50 + SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
51 exit(-1); \ 51 exit(-1); \
52 } \ 52 } \
53 } while (0) 53 } while (0)
@@ -61,80 +61,80 @@ @@ -61,80 +61,80 @@
61 } else { \ 61 } else { \
62 dst = atoi(value.get()); \ 62 dst = atoi(value.get()); \
63 if (dst < 0) { \ 63 if (dst < 0) { \
64 - SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \ 64 + SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
65 exit(-1); \ 65 exit(-1); \
66 } \ 66 } \
67 } \ 67 } \
68 } while (0) 68 } while (0)
69 69
70 // read a vector of integers 70 // read a vector of integers
71 -#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \  
72 - do { \  
73 - auto value = \  
74 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
75 - if (!value) { \  
76 - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \  
77 - exit(-1); \  
78 - } \  
79 - \  
80 - bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \  
81 - if (!ret) { \  
82 - SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \  
83 - exit(-1); \  
84 - } \ 71 +#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
  72 + do { \
  73 + auto value = \
  74 + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
  75 + if (!value) { \
  76 + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
  77 + exit(-1); \
  78 + } \
  79 + \
  80 + bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \
  81 + if (!ret) { \
  82 + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
  83 + exit(-1); \
  84 + } \
85 } while (0) 85 } while (0)
86 86
87 // read a vector of floats 87 // read a vector of floats
88 -#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \  
89 - do { \  
90 - auto value = \  
91 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
92 - if (!value) { \  
93 - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \  
94 - exit(-1); \  
95 - } \  
96 - \  
97 - bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \  
98 - if (!ret) { \  
99 - SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \  
100 - exit(-1); \  
// Read the metadata entry src_key as a ","-separated list of floats into dst.
//
// The macro expects `meta_data` and `allocator` to exist in the enclosing
// scope. It logs an error and calls exit(-1) when src_key is not in the
// metadata or when SplitStringToFloats() reports failure. The key is quoted
// in both messages, like in the other SHERPA_ONNX_READ_META_DATA_* macros.
#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key)                   \
  do {                                                                       \
    auto value =                                                             \
        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator);      \
    if (!value) {                                                            \
      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key);      \
      exit(-1);                                                              \
    }                                                                        \
                                                                             \
    bool ret = SplitStringToFloats(value.get(), ",", true, &dst);            \
    if (!ret) {                                                              \
      SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
      exit(-1);                                                              \
    }                                                                        \
  } while (0)
103 103
104 // read a vector of strings 104 // read a vector of strings
105 -#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \  
106 - do { \  
107 - auto value = \  
108 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
109 - if (!value) { \  
110 - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \  
111 - exit(-1); \  
112 - } \  
113 - SplitStringToVector(value.get(), ",", false, &dst); \  
114 - \  
115 - if (dst.empty()) { \  
116 - SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \  
117 - src_key); \  
118 - exit(-1); \  
119 - } \ 105 +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
  106 + do { \
  107 + auto value = \
  108 + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
  109 + if (!value) { \
  110 + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
  111 + exit(-1); \
  112 + } \
  113 + SplitStringToVector(value.get(), ",", false, &dst); \
  114 + \
  115 + if (dst.empty()) { \
  116 + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
  117 + value.get(), src_key); \
  118 + exit(-1); \
  119 + } \
120 } while (0) 120 } while (0)
121 121
122 // read a vector of strings separated by sep 122 // read a vector of strings separated by sep
123 -#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \  
124 - do { \  
125 - auto value = \  
126 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
127 - if (!value) { \  
128 - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \  
129 - exit(-1); \  
130 - } \  
131 - SplitStringToVector(value.get(), sep, false, &dst); \  
132 - \  
133 - if (dst.empty()) { \  
134 - SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \  
135 - src_key); \  
136 - exit(-1); \  
137 - } \ 123 +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
  124 + do { \
  125 + auto value = \
  126 + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
  127 + if (!value) { \
  128 + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
  129 + exit(-1); \
  130 + } \
  131 + SplitStringToVector(value.get(), sep, false, &dst); \
  132 + \
  133 + if (dst.empty()) { \
  134 + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
  135 + value.get(), src_key); \
  136 + exit(-1); \
  137 + } \
138 } while (0) 138 } while (0)
139 139
140 // Read a string 140 // Read a string
@@ -143,17 +143,29 @@ @@ -143,17 +143,29 @@
143 auto value = \ 143 auto value = \
144 meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ 144 meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
145 if (!value) { \ 145 if (!value) { \
146 - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ 146 + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
147 exit(-1); \ 147 exit(-1); \
148 } \ 148 } \
149 \ 149 \
150 dst = value.get(); \ 150 dst = value.get(); \
151 if (dst.empty()) { \ 151 if (dst.empty()) { \
152 - SHERPA_ONNX_LOGE("Invalid value for %s\n", src_key); \ 152 + SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
153 exit(-1); \ 153 exit(-1); \
154 } \ 154 } \
155 } while (0) 155 } while (0)
156 156
  // Read a string. Unlike SHERPA_ONNX_READ_META_DATA_STR, an empty value is
  // accepted; only a missing key is an error (logged, then exit(-1)).
  // Expects `meta_data` and `allocator` to exist in the enclosing scope.
  #define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
  do { \
  auto value = \
  meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
  if (!value) { \
  SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
  exit(-1); \
  } \
  \
  dst = value.get(); \
  } while (0)
  168 +
157 #define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \ 169 #define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
158 default_value) \ 170 default_value) \
159 do { \ 171 do { \
@@ -164,7 +176,7 @@ @@ -164,7 +176,7 @@
164 } else { \ 176 } else { \
165 dst = value.get(); \ 177 dst = value.get(); \
166 if (dst.empty()) { \ 178 if (dst.empty()) { \
167 - SHERPA_ONNX_LOGE("Invalid value for %s\n", src_key); \ 179 + SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
168 exit(-1); \ 180 exit(-1); \
169 } \ 181 } \
170 } \ 182 } \
@@ -10,8 +10,8 @@ @@ -10,8 +10,8 @@
10 10
11 #include "cppjieba/Jieba.hpp" 11 #include "cppjieba/Jieba.hpp"
12 #include "sherpa-onnx/csrc/file-utils.h" 12 #include "sherpa-onnx/csrc/file-utils.h"
13 -#include "sherpa-onnx/csrc/lexicon.h"  
14 #include "sherpa-onnx/csrc/macros.h" 13 #include "sherpa-onnx/csrc/macros.h"
  14 +#include "sherpa-onnx/csrc/symbol-table.h"
15 #include "sherpa-onnx/csrc/text-utils.h" 15 #include "sherpa-onnx/csrc/text-utils.h"
16 16
17 namespace sherpa_onnx { 17 namespace sherpa_onnx {
@@ -21,6 +21,7 @@ namespace { @@ -21,6 +21,7 @@ namespace {
21 21
22 enum class ModelType : std::uint8_t { 22 enum class ModelType : std::uint8_t {
23 kEncDecCTCModelBPE, 23 kEncDecCTCModelBPE,
  24 + kEncDecCTCModel,
24 kEncDecHybridRNNTCTCBPEModel, 25 kEncDecHybridRNNTCTCBPEModel,
25 kTdnn, 26 kTdnn,
26 kZipformerCtc, 27 kZipformerCtc,
@@ -75,6 +76,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -75,6 +76,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
75 76
76 if (model_type.get() == std::string("EncDecCTCModelBPE")) { 77 if (model_type.get() == std::string("EncDecCTCModelBPE")) {
77 return ModelType::kEncDecCTCModelBPE; 78 return ModelType::kEncDecCTCModelBPE;
  79 + } else if (model_type.get() == std::string("EncDecCTCModel")) {
  80 + return ModelType::kEncDecCTCModel;
78 } else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) { 81 } else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) {
79 return ModelType::kEncDecHybridRNNTCTCBPEModel; 82 return ModelType::kEncDecHybridRNNTCTCBPEModel;
80 } else if (model_type.get() == std::string("tdnn")) { 83 } else if (model_type.get() == std::string("tdnn")) {
@@ -121,22 +124,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( @@ -121,22 +124,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
121 switch (model_type) { 124 switch (model_type) {
122 case ModelType::kEncDecCTCModelBPE: 125 case ModelType::kEncDecCTCModelBPE:
123 return std::make_unique<OfflineNemoEncDecCtcModel>(config); 126 return std::make_unique<OfflineNemoEncDecCtcModel>(config);
124 - break; 127 + case ModelType::kEncDecCTCModel:
  128 + return std::make_unique<OfflineNemoEncDecCtcModel>(config);
125 case ModelType::kEncDecHybridRNNTCTCBPEModel: 129 case ModelType::kEncDecHybridRNNTCTCBPEModel:
126 return std::make_unique<OfflineNemoEncDecHybridRNNTCTCBPEModel>(config); 130 return std::make_unique<OfflineNemoEncDecHybridRNNTCTCBPEModel>(config);
127 - break;  
128 case ModelType::kTdnn: 131 case ModelType::kTdnn:
129 return std::make_unique<OfflineTdnnCtcModel>(config); 132 return std::make_unique<OfflineTdnnCtcModel>(config);
130 - break;  
131 case ModelType::kZipformerCtc: 133 case ModelType::kZipformerCtc:
132 return std::make_unique<OfflineZipformerCtcModel>(config); 134 return std::make_unique<OfflineZipformerCtcModel>(config);
133 - break;  
134 case ModelType::kWenetCtc: 135 case ModelType::kWenetCtc:
135 return std::make_unique<OfflineWenetCtcModel>(config); 136 return std::make_unique<OfflineWenetCtcModel>(config);
136 - break;  
137 case ModelType::kTeleSpeechCtc: 137 case ModelType::kTeleSpeechCtc:
138 return std::make_unique<OfflineTeleSpeechCtcModel>(config); 138 return std::make_unique<OfflineTeleSpeechCtcModel>(config);
139 - break;  
140 case ModelType::kUnknown: 139 case ModelType::kUnknown:
141 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); 140 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
142 return nullptr; 141 return nullptr;
@@ -177,23 +176,19 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( @@ -177,23 +176,19 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
177 switch (model_type) { 176 switch (model_type) {
178 case ModelType::kEncDecCTCModelBPE: 177 case ModelType::kEncDecCTCModelBPE:
179 return std::make_unique<OfflineNemoEncDecCtcModel>(mgr, config); 178 return std::make_unique<OfflineNemoEncDecCtcModel>(mgr, config);
180 - break; 179 + case ModelType::kEncDecCTCModel:
  180 + return std::make_unique<OfflineNemoEncDecCtcModel>(mgr, config);
181 case ModelType::kEncDecHybridRNNTCTCBPEModel: 181 case ModelType::kEncDecHybridRNNTCTCBPEModel:
182 return std::make_unique<OfflineNemoEncDecHybridRNNTCTCBPEModel>(mgr, 182 return std::make_unique<OfflineNemoEncDecHybridRNNTCTCBPEModel>(mgr,
183 config); 183 config);
184 - break;  
185 case ModelType::kTdnn: 184 case ModelType::kTdnn:
186 return std::make_unique<OfflineTdnnCtcModel>(mgr, config); 185 return std::make_unique<OfflineTdnnCtcModel>(mgr, config);
187 - break;  
188 case ModelType::kZipformerCtc: 186 case ModelType::kZipformerCtc:
189 return std::make_unique<OfflineZipformerCtcModel>(mgr, config); 187 return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
190 - break;  
191 case ModelType::kWenetCtc: 188 case ModelType::kWenetCtc:
192 return std::make_unique<OfflineWenetCtcModel>(mgr, config); 189 return std::make_unique<OfflineWenetCtcModel>(mgr, config);
193 - break;  
194 case ModelType::kTeleSpeechCtc: 190 case ModelType::kTeleSpeechCtc:
195 return std::make_unique<OfflineTeleSpeechCtcModel>(mgr, config); 191 return std::make_unique<OfflineTeleSpeechCtcModel>(mgr, config);
196 - break;  
197 case ModelType::kUnknown: 192 case ModelType::kUnknown:
198 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); 193 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
199 return nullptr; 194 return nullptr;
@@ -66,6 +66,10 @@ class OfflineCtcModel { @@ -66,6 +66,10 @@ class OfflineCtcModel {
66 66
67 // Return true if the model supports batch size > 1 67 // Return true if the model supports batch size > 1
68 virtual bool SupportBatchProcessing() const { return true; } 68 virtual bool SupportBatchProcessing() const { return true; }
  69 +
  70 + // return true for models from https://github.com/salute-developers/GigaAM
  71 + // return false otherwise
  72 + virtual bool IsGigaAM() const { return false; }
69 }; 73 };
70 74
71 } // namespace sherpa_onnx 75 } // namespace sherpa_onnx
@@ -72,6 +72,8 @@ class OfflineNemoEncDecCtcModel::Impl { @@ -72,6 +72,8 @@ class OfflineNemoEncDecCtcModel::Impl {
72 72
73 std::string FeatureNormalizationMethod() const { return normalize_type_; } 73 std::string FeatureNormalizationMethod() const { return normalize_type_; }
74 74
  75 + bool IsGigaAM() const { return is_giga_am_; }
  76 +
75 private: 77 private:
76 void Init(void *model_data, size_t model_data_length) { 78 void Init(void *model_data, size_t model_data_length) {
77 sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, 79 sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
@@ -92,7 +94,9 @@ class OfflineNemoEncDecCtcModel::Impl { @@ -92,7 +94,9 @@ class OfflineNemoEncDecCtcModel::Impl {
92 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below 94 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
93 SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); 95 SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
94 SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); 96 SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
95 - SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type"); 97 + SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(normalize_type_,
  98 + "normalize_type");
  99 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0);
96 } 100 }
97 101
98 private: 102 private:
@@ -112,6 +116,10 @@ class OfflineNemoEncDecCtcModel::Impl { @@ -112,6 +116,10 @@ class OfflineNemoEncDecCtcModel::Impl {
112 int32_t vocab_size_ = 0; 116 int32_t vocab_size_ = 0;
113 int32_t subsampling_factor_ = 0; 117 int32_t subsampling_factor_ = 0;
114 std::string normalize_type_; 118 std::string normalize_type_;
  119 +
  120 + // it is 1 for models from
  121 + // https://github.com/salute-developers/GigaAM
  122 + int32_t is_giga_am_ = 0;
115 }; 123 };
116 124
117 OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel( 125 OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
@@ -146,4 +154,6 @@ std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const { @@ -146,4 +154,6 @@ std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const {
146 return impl_->FeatureNormalizationMethod(); 154 return impl_->FeatureNormalizationMethod();
147 } 155 }
148 156
  157 +bool OfflineNemoEncDecCtcModel::IsGigaAM() const { return impl_->IsGigaAM(); }
  158 +
149 } // namespace sherpa_onnx 159 } // namespace sherpa_onnx
@@ -76,6 +76,8 @@ class OfflineNemoEncDecCtcModel : public OfflineCtcModel { @@ -76,6 +76,8 @@ class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
76 // for details 76 // for details
77 std::string FeatureNormalizationMethod() const override; 77 std::string FeatureNormalizationMethod() const override;
78 78
  79 + bool IsGigaAM() const override;
  80 +
79 private: 81 private:
80 class Impl; 82 class Impl;
81 std::unique_ptr<Impl> impl_; 83 std::unique_ptr<Impl> impl_;
@@ -104,11 +104,20 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { @@ -104,11 +104,20 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
104 } 104 }
105 105
106 if (!config_.model_config.nemo_ctc.model.empty()) { 106 if (!config_.model_config.nemo_ctc.model.empty()) {
107 - config_.feat_config.low_freq = 0;  
108 - config_.feat_config.high_freq = 0;  
109 - config_.feat_config.is_librosa = true;  
110 - config_.feat_config.remove_dc_offset = false;  
111 - config_.feat_config.window_type = "hann"; 107 + if (model_->IsGigaAM()) {
  108 + config_.feat_config.low_freq = 0;
  109 + config_.feat_config.high_freq = 8000;
  110 + config_.feat_config.remove_dc_offset = false;
  111 + config_.feat_config.preemph_coeff = 0;
  112 + config_.feat_config.window_type = "hann";
  113 + config_.feat_config.feature_dim = 64;
  114 + } else {
  115 + config_.feat_config.low_freq = 0;
  116 + config_.feat_config.high_freq = 0;
  117 + config_.feat_config.is_librosa = true;
  118 + config_.feat_config.remove_dc_offset = false;
  119 + config_.feat_config.window_type = "hann";
  120 + }
112 } 121 }
113 122
114 if (!config_.model_config.wenet_ctc.model.empty()) { 123 if (!config_.model_config.wenet_ctc.model.empty()) {
@@ -172,7 +172,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -172,7 +172,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
172 return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config); 172 return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config);
173 } 173 }
174 174
175 - if (model_type == "EncDecCTCModelBPE" || 175 + if (model_type == "EncDecCTCModelBPE" || model_type == "EncDecCTCModel" ||
176 model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || 176 model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
177 model_type == "zipformer2_ctc" || model_type == "wenet_ctc" || 177 model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
178 model_type == "telespeech_ctc") { 178 model_type == "telespeech_ctc") {
@@ -189,6 +189,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -189,6 +189,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
189 " - Non-streaming transducer models from icefall\n" 189 " - Non-streaming transducer models from icefall\n"
190 " - Non-streaming Paraformer models from FunASR\n" 190 " - Non-streaming Paraformer models from FunASR\n"
191 " - EncDecCTCModelBPE models from NeMo\n" 191 " - EncDecCTCModelBPE models from NeMo\n"
  192 + " - EncDecCTCModel models from NeMo\n"
192 " - EncDecHybridRNNTCTCBPEModel models from NeMo\n" 193 " - EncDecHybridRNNTCTCBPEModel models from NeMo\n"
193 " - Whisper models\n" 194 " - Whisper models\n"
194 " - Tdnn models\n" 195 " - Tdnn models\n"
@@ -343,7 +344,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -343,7 +344,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
343 return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config); 344 return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config);
344 } 345 }
345 346
346 - if (model_type == "EncDecCTCModelBPE" || 347 + if (model_type == "EncDecCTCModelBPE" || model_type == "EncDecCTCModel" ||
347 model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || 348 model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
348 model_type == "zipformer2_ctc" || model_type == "wenet_ctc" || 349 model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
349 model_type == "telespeech_ctc") { 350 model_type == "telespeech_ctc") {
@@ -360,6 +361,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -360,6 +361,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
360 " - Non-streaming transducer models from icefall\n" 361 " - Non-streaming transducer models from icefall\n"
361 " - Non-streaming Paraformer models from FunASR\n" 362 " - Non-streaming Paraformer models from FunASR\n"
362 " - EncDecCTCModelBPE models from NeMo\n" 363 " - EncDecCTCModelBPE models from NeMo\n"
  364 + " - EncDecCTCModel models from NeMo\n"
363 " - EncDecHybridRNNTCTCBPEModel models from NeMo\n" 365 " - EncDecHybridRNNTCTCBPEModel models from NeMo\n"
364 " - Whisper models\n" 366 " - Whisper models\n"
365 " - Tdnn models\n" 367 " - Tdnn models\n"
@@ -7,6 +7,8 @@ @@ -7,6 +7,8 @@
7 #include <cassert> 7 #include <cassert>
8 #include <fstream> 8 #include <fstream>
9 #include <sstream> 9 #include <sstream>
  10 +#include <string>
  11 +#include <utility>
10 12
11 #if __ANDROID_API__ >= 9 13 #if __ANDROID_API__ >= 9
12 #include <strstream> 14 #include <strstream>
@@ -16,10 +18,54 @@ @@ -16,10 +18,54 @@
16 #endif 18 #endif
17 19
18 #include "sherpa-onnx/csrc/base64-decode.h" 20 #include "sherpa-onnx/csrc/base64-decode.h"
  21 +#include "sherpa-onnx/csrc/lexicon.h"
19 #include "sherpa-onnx/csrc/onnx-utils.h" 22 #include "sherpa-onnx/csrc/onnx-utils.h"
20 23
21 namespace sherpa_onnx { 24 namespace sherpa_onnx {
22 25
  26 +std::unordered_map<std::string, int32_t> ReadTokens(
  27 + std::istream &is,
  28 + std::unordered_map<int32_t, std::string> *id2token /*= nullptr*/) {
  29 + std::unordered_map<std::string, int32_t> token2id;
  30 +
  31 + std::string line;
  32 +
  33 + std::string sym;
  34 + int32_t id = -1;
  35 + while (std::getline(is, line)) {
  36 + std::istringstream iss(line);
  37 + iss >> sym;
  38 + if (iss.eof()) {
  39 + id = atoi(sym.c_str());
  40 + sym = " ";
  41 + } else {
  42 + iss >> id;
  43 + }
  44 +
  45 + // eat the trailing \r\n on windows
  46 + iss >> std::ws;
  47 + if (!iss.eof()) {
  48 + SHERPA_ONNX_LOGE("Error: %s", line.c_str());
  49 + exit(-1);
  50 + }
  51 +
  52 +#if 0
  53 + if (token2id.count(sym)) {
  54 + SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d",
  55 + sym.c_str(), line.c_str(), token2id.at(sym));
  56 + exit(-1);
  57 + }
  58 +#endif
  59 + if (id2token) {
  60 + id2token->insert({id, sym});
  61 + }
  62 +
  63 + token2id.insert({std::move(sym), id});
  64 + }
  65 +
  66 + return token2id;
  67 +}
  68 +
23 SymbolTable::SymbolTable(const std::string &filename, bool is_file) { 69 SymbolTable::SymbolTable(const std::string &filename, bool is_file) {
24 if (is_file) { 70 if (is_file) {
25 std::ifstream is(filename); 71 std::ifstream is(filename);
@@ -39,25 +85,7 @@ SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) { @@ -39,25 +85,7 @@ SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) {
39 } 85 }
40 #endif 86 #endif
41 87
42 -void SymbolTable::Init(std::istream &is) {  
43 - std::string sym;  
44 - int32_t id = 0;  
45 - while (is >> sym >> id) {  
46 -#if 0  
47 - // we disable the test here since for some multi-lingual BPE models  
48 - // from NeMo, the same symbol can appear multiple times with different IDs.  
49 - if (sym != " ") {  
50 - assert(sym2id_.count(sym) == 0);  
51 - }  
52 -#endif  
53 -  
54 - assert(id2sym_.count(id) == 0);  
55 -  
56 - sym2id_.insert({sym, id});  
57 - id2sym_.insert({id, sym});  
58 - }  
59 - assert(is.eof());  
60 -} 88 +void SymbolTable::Init(std::istream &is) { sym2id_ = ReadTokens(is, &id2sym_); }
61 89
62 std::string SymbolTable::ToString() const { 90 std::string SymbolTable::ToString() const {
63 std::ostringstream os; 91 std::ostringstream os;
@@ -5,8 +5,10 @@ @@ -5,8 +5,10 @@
5 #ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ 5 #ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
6 #define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ 6 #define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
7 7
  8 +#include <istream>
8 #include <string> 9 #include <string>
9 #include <unordered_map> 10 #include <unordered_map>
  11 +#include <vector>
10 12
11 #if __ANDROID_API__ >= 9 13 #if __ANDROID_API__ >= 9
12 #include "android/asset_manager.h" 14 #include "android/asset_manager.h"
@@ -15,6 +17,16 @@ @@ -15,6 +17,16 @@
15 17
16 namespace sherpa_onnx { 18 namespace sherpa_onnx {
17 19
  20 +// The same token can be mapped to different integer IDs, so
  21 +// we need an id2token argument here.
  22 +std::unordered_map<std::string, int32_t> ReadTokens(
  23 + std::istream &is,
  24 + std::unordered_map<int32_t, std::string> *id2token = nullptr);
  25 +
  26 +std::vector<int32_t> ConvertTokensToIds(
  27 + const std::unordered_map<std::string, int32_t> &token2id,
  28 + const std::vector<std::string> &tokens);
  29 +
18 /// It manages mapping between symbols and integer IDs. 30 /// It manages mapping between symbols and integer IDs.
19 class SymbolTable { 31 class SymbolTable {
20 public: 32 public:
@@ -394,6 +394,16 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? { @@ -394,6 +394,16 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
394 modelType = "transducer", 394 modelType = "transducer",
395 ) 395 )
396 } 396 }
  397 +
  398 + 19 -> {
  399 + val modelDir = "sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24"
  400 + return OfflineModelConfig(
  401 + nemo = OfflineNemoEncDecCtcModelConfig(
  402 + model = "$modelDir/model.int8.onnx",
  403 + ),
  404 + tokens = "$modelDir/tokens.txt",
  405 + )
  406 + }
397 } 407 }
398 return null 408 return null
399 } 409 }