Fangjun Kuang
Committed by GitHub

Support GigaAM CTC models for Russian ASR (#1464)

See also https://github.com/salute-developers/GigaAM
... ... @@ -16,6 +16,21 @@ echo "PATH: $PATH"
which $EXE
log "------------------------------------------------------------"
log "Run NeMo GigaAM Russian models"
log "------------------------------------------------------------"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2
tar xvf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2
rm sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2
$EXE \
--nemo-ctc-model=./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/model.int8.onnx \
--tokens=./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/tokens.txt \
--debug=1 \
./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/test_wavs/example.wav
rm -rf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24
log "------------------------------------------------------------"
log "Run SenseVoice models"
log "------------------------------------------------------------"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
... ...
name: export-nemo-giga-am-to-onnx
on:
workflow_dispatch:
concurrency:
group: export-nemo-giga-am-to-onnx-${{ github.ref }}
cancel-in-progress: true
jobs:
export-nemo-am-giga-to-onnx:
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
name: export nemo GigaAM models to ONNX
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [macos-latest]
python-version: ["3.10"]
steps:
- uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Run CTC
shell: bash
run: |
pushd scripts/nemo/GigaAM
./run-ctc.sh
popd
d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24
mkdir $d
mkdir $d/test_wavs
rm scripts/nemo/GigaAM/model.onnx
mv -v scripts/nemo/GigaAM/*.int8.onnx $d/
mv -v scripts/nemo/GigaAM/*.md $d/
mv -v scripts/nemo/GigaAM/*.pdf $d/
mv -v scripts/nemo/GigaAM/tokens.txt $d/
mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
mv -v scripts/nemo/GigaAM/run-ctc.sh $d/
mv -v scripts/nemo/GigaAM/*-ctc.py $d/
ls -lh scripts/nemo/GigaAM/
ls -lh $d
tar cjvf ${d}.tar.bz2 $d
- name: Release
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: ./*.tar.bz2
overwrite: true
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: asr-models
- name: Publish to huggingface (CTC)
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24
export GIT_LFS_SKIP_SMUDGE=1
export GIT_CLONE_PROTECTION_ACTIVE=false
git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
mv -v $d/* ./huggingface
cd huggingface
git lfs track "*.onnx"
git lfs track "*.wav"
git status
git add .
git status
git commit -m "add models"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
... ...
... ... @@ -149,6 +149,16 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*
- name: Test offline CTC
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-ctc.sh
du -h -d1 .
- name: Test C++ API
shell: bash
run: |
... ... @@ -180,16 +190,6 @@ jobs:
.github/scripts/test-offline-transducer.sh
du -h -d1 .
- name: Test offline CTC
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-ctc.sh
du -h -d1 .
- name: Test online punctuation
shell: bash
run: |
... ...
... ... @@ -336,6 +336,24 @@ def get_models():
popd
""",
),
Model(
model_name="sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24",
idx=19,
lang="ru",
short_name="nemo_ctc_giga_am",
cmd="""
pushd $model_name
rm -rfv test_wavs
rm -fv *.sh
rm -fv *.py
ls -lh
popd
""",
),
]
return models
... ...
# Introduction
This folder contains scripts for converting models from
https://github.com/salute-developers/GigaAM
to sherpa-onnx.
The models in this folder are for Russian speech recognition.
Please see the license of the models at
https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf
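Once run-ctc.sh has produced model.int8.onnx and tokens.txt, the exported model can be decoded with the sherpa-onnx-offline binary (as in the CI script above) or from Python. Below is a minimal sketch using the sherpa-onnx Python API; the from_nemo_ctc helper and its keyword arguments are assumed from the current Python wheel and may need adjusting for your installed version.

import soundfile as sf
import sherpa_onnx

d = "./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24"

# Assumed helper; check your sherpa-onnx version for the exact signature.
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
    model=f"{d}/model.int8.onnx",
    tokens=f"{d}/tokens.txt",
    num_threads=1,
)

audio, sample_rate = sf.read(f"{d}/test_wavs/example.wav", dtype="float32")
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
recognizer.decode_stream(stream)
print(stream.result.text)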
... ...
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
from typing import Dict
import onnx
import torch
import torchaudio
from nemo.collections.asr.models import EncDecCTCModel
from nemo.collections.asr.modules.audio_preprocessing import (
AudioToMelSpectrogramPreprocessor as NeMoAudioToMelSpectrogramPreprocessor,
)
from nemo.collections.asr.parts.preprocessing.features import (
FilterbankFeaturesTA as NeMoFilterbankFeaturesTA,
)
from onnxruntime.quantization import QuantType, quantize_dynamic
class FilterbankFeaturesTA(NeMoFilterbankFeaturesTA):
def __init__(self, mel_scale: str = "htk", wkwargs=None, **kwargs):
if "window_size" in kwargs:
del kwargs["window_size"]
if "window_stride" in kwargs:
del kwargs["window_stride"]
super().__init__(**kwargs)
self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = (
torchaudio.transforms.MelSpectrogram(
sample_rate=self._sample_rate,
win_length=self.win_length,
hop_length=self.hop_length,
n_mels=kwargs["nfilt"],
window_fn=self.torch_windows[kwargs["window"]],
mel_scale=mel_scale,
norm=kwargs["mel_norm"],
n_fft=kwargs["n_fft"],
f_max=kwargs.get("highfreq", None),
f_min=kwargs.get("lowfreq", 0),
wkwargs=wkwargs,
)
)
class AudioToMelSpectrogramPreprocessor(NeMoAudioToMelSpectrogramPreprocessor):
def __init__(self, mel_scale: str = "htk", **kwargs):
super().__init__(**kwargs)
kwargs["nfilt"] = kwargs["features"]
del kwargs["features"]
self.featurizer = (
FilterbankFeaturesTA( # Deprecated arguments; kept for config compatibility
mel_scale=mel_scale,
**kwargs,
)
)
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
def main():
model = EncDecCTCModel.from_config_file("./ctc_model_config.yaml")
ckpt = torch.load("./ctc_model_weights.ckpt", map_location="cpu")
model.load_state_dict(ckpt, strict=False)
model.eval()
with open("tokens.txt", "w", encoding="utf-8") as f:
for i, t in enumerate(model.cfg.labels):
f.write(f"{t} {i}\n")
f.write(f"<blk> {i+1}\n")
filename = "model.onnx"
model.export(filename)
meta_data = {
"vocab_size": len(model.cfg.labels) + 1,
"normalize_type": "",
"subsampling_factor": 4,
"model_type": "EncDecCTCModel",
"version": "1",
"model_author": "https://github.com/salute-developers/GigaAM",
"license": "https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf",
"language": "Russian",
"is_giga_am": 1,
}
add_meta_data(filename, meta_data)
filename_int8 = "model.int8.onnx"
quantize_dynamic(
model_input=filename,
model_output=filename_int8,
weight_type=QuantType.QUInt8,
)
if __name__ == "__main__":
main()
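As a quick sanity check (not part of the committed script), the key/value pairs written by add_meta_data() can be read back through onnxruntime after quantization; on the C++ side the same metadata is consumed by the SHERPA_ONNX_READ_META_DATA* macros that appear later in this change.

import onnxruntime as ort

sess = ort.InferenceSession("model.int8.onnx", providers=["CPUExecutionProvider"])
meta = sess.get_modelmeta().custom_metadata_map
# Expect model_type=EncDecCTCModel, is_giga_am=1, vocab_size=34 and an
# empty normalize_type for this model.
for key in ("model_type", "is_giga_am", "vocab_size", "subsampling_factor", "normalize_type"):
    print(key, "=", repr(meta.get(key)))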
... ...
#!/usr/bin/env bash
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
set -ex
function install_nemo() {
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
python3 get-pip.py
pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile librosa
pip install -qq ipython
# sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython
BRANCH='main'
python3 -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]
pip install numpy==1.26.4
}
function download_files() {
curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/ctc_model_weights.ckpt
curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/ctc_model_config.yaml
curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/example.wav
curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/long_example.wav
curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM%20License_NC.pdf
}
install_nemo
download_files
python3 ./export-onnx-ctc.py
ls -lh
python3 ./test-onnx-ctc.py
... ...
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
# https://github.com/salute-developers/GigaAM
import kaldi_native_fbank as knf
import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
def create_fbank():
opts = knf.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.remove_dc_offset = False
opts.frame_opts.preemph_coeff = 0
opts.frame_opts.window_type = "hann"
# GigaAM uses an FFT size of 400 (a 25 ms window at 16 kHz), but we use 512 here
# since kaldi-native-fbank only supports FFT sizes that are powers of two.
opts.frame_opts.round_to_power_of_two = True
opts.mel_opts.low_freq = 0
opts.mel_opts.high_freq = 8000
opts.mel_opts.num_bins = 64
fbank = knf.OnlineFbank(opts)
return fbank
def compute_features(audio, fbank) -> np.ndarray:
"""
Args:
audio: (num_samples,), np.float32
fbank: the fbank extractor
Returns:
features: (num_frames, feat_dim), np.float32
"""
assert len(audio.shape) == 1, audio.shape
fbank.accept_waveform(16000, audio)
ans = []
processed = 0
while processed < fbank.num_frames_ready:
ans.append(np.array(fbank.get_frame(processed)))
processed += 1
ans = np.stack(ans)
return ans
def display(sess):
print("==========Input==========")
for i in sess.get_inputs():
print(i)
print("==========Output==========")
for i in sess.get_outputs():
print(i)
"""
==========Input==========
NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 64, 'audio_signal_dynamic_axes_2'])
NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1'])
==========Output==========
NodeArg(name='logprobs', type='tensor(float)', shape=['logprobs_dynamic_axes_1', 'logprobs_dynamic_axes_2', 34])
"""
class OnnxModel:
def __init__(
self,
filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.model = ort.InferenceSession(
filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
display(self.model)
def __call__(self, x: np.ndarray):
# x: (T, C)
x = torch.from_numpy(x)
x = x.t().unsqueeze(0)
# x: [1, C, T]
x_lens = torch.tensor([x.shape[-1]], dtype=torch.int64)
log_probs = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: x_lens.numpy(),
},
)[0]
# [batch_size, T, dim]
return log_probs
def main():
filename = "./model.int8.onnx"
tokens = "./tokens.txt"
wav = "./example.wav"
model = OnnxModel(filename)
id2token = dict()
with open(tokens, encoding="utf-8") as f:
for line in f:
fields = line.split()
if len(fields) == 1:
id2token[int(fields[0])] = " "
else:
t, idx = fields
id2token[int(idx)] = t
fbank = create_fbank()
audio, sample_rate = sf.read(wav, dtype="float32", always_2d=True)
audio = audio[:, 0] # only use the first channel
if sample_rate != 16000:
audio = librosa.resample(
audio,
orig_sr=sample_rate,
target_sr=16000,
)
sample_rate = 16000
features = compute_features(audio, fbank)
print("features.shape", features.shape)
blank = len(id2token) - 1
prev = -1
ans = []
log_probs = model(features)
print("log_probs", log_probs.shape)
log_probs = torch.from_numpy(log_probs)[0]
ids = torch.argmax(log_probs, dim=1).tolist()
for i in ids:
if i != blank and i != prev:
ans.append(i)
prev = i
tokens = [id2token[i] for i in ans]
text = "".join(tokens)
print(wav)
print(text)
if __name__ == "__main__":
main()
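The decoding loop in main() is plain CTC greedy search: take the argmax token per frame, merge consecutive repeats, then drop blanks (the last ID in tokens.txt). A self-contained toy illustration of the same collapse rule:

def ctc_greedy_collapse(ids, blank):
    """Merge consecutive repeats, then drop blanks.

    E.g. [5, 5, 33, 7, 7, 33, 5] with blank=33 -> [5, 7, 5].
    """
    out = []
    prev = -1
    for i in ids:
        if i != blank and i != prev:
            out.append(i)
        prev = i
    return out


assert ctc_greedy_collapse([5, 5, 33, 7, 7, 33, 5], blank=33) == [5, 7, 5]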
... ...
... ... @@ -193,6 +193,7 @@ class FeatureExtractor::Impl {
opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
opts_.frame_opts.preemph_coeff = config_.preemph_coeff;
opts_.frame_opts.window_type = config_.window_type;
opts_.mel_opts.num_bins = config_.feature_dim;
... ... @@ -211,6 +212,7 @@ class FeatureExtractor::Impl {
mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
mfcc_opts_.frame_opts.preemph_coeff = config_.preemph_coeff;
mfcc_opts_.frame_opts.window_type = config_.window_type;
mfcc_opts_.mel_opts.num_bins = config_.feature_dim;
... ...
... ... @@ -57,6 +57,7 @@ struct FeatureExtractorConfig {
float frame_length_ms = 25.0f; // in milliseconds.
bool is_librosa = false;
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
float preemph_coeff = 0.97f; // Preemphasis coefficient.
std::string window_type = "povey"; // e.g. Hamming window
// For models from NeMo
... ...
... ... @@ -10,8 +10,8 @@
#include "cppjieba/Jieba.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
... ...
... ... @@ -21,6 +21,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
... ... @@ -74,45 +75,6 @@ static std::vector<std::string> ProcessHeteronyms(
return ans;
}
// Note: We don't use SymbolTable here since tokens may contain a blank
// in the first column
std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) {
std::unordered_map<std::string, int32_t> token2id;
std::string line;
std::string sym;
int32_t id = -1;
while (std::getline(is, line)) {
std::istringstream iss(line);
iss >> sym;
if (iss.eof()) {
id = atoi(sym.c_str());
sym = " ";
} else {
iss >> id;
}
// eat the trailing \r\n on windows
iss >> std::ws;
if (!iss.eof()) {
SHERPA_ONNX_LOGE("Error: %s", line.c_str());
exit(-1);
}
#if 0
if (token2id.count(sym)) {
SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d",
sym.c_str(), line.c_str(), token2id.at(sym));
exit(-1);
}
#endif
token2id.insert({std::move(sym), id});
}
return token2id;
}
std::vector<int32_t> ConvertTokensToIds(
const std::unordered_map<std::string, int32_t> &token2id,
const std::vector<std::string> &tokens) {
... ...
... ... @@ -67,12 +67,6 @@ class Lexicon : public OfflineTtsFrontend {
bool debug_ = false;
};
std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is);
std::vector<int32_t> ConvertTokensToIds(
const std::unordered_map<std::string, int32_t> &token2id,
const std::vector<std::string> &tokens);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_LEXICON_H_
... ...
... ... @@ -41,13 +41,13 @@
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = atoi(value.get()); \
if (dst < 0) { \
SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
exit(-1); \
} \
} while (0)
... ... @@ -61,7 +61,7 @@
} else { \
dst = atoi(value.get()); \
if (dst < 0) { \
SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
exit(-1); \
} \
} \
... ... @@ -73,13 +73,13 @@
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \
if (!ret) { \
SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
exit(-1); \
} \
} while (0)
... ... @@ -96,7 +96,7 @@
\
bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \
if (!ret) { \
SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
exit(-1); \
} \
} while (0)
... ... @@ -107,14 +107,14 @@
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
SplitStringToVector(value.get(), ",", false, &dst); \
\
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \
src_key); \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
value.get(), src_key); \
exit(-1); \
} \
} while (0)
... ... @@ -125,14 +125,14 @@
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
SplitStringToVector(value.get(), sep, false, &dst); \
\
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \
src_key); \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
value.get(), src_key); \
exit(-1); \
} \
} while (0)
... ... @@ -143,17 +143,29 @@
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = value.get(); \
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value for %s\n", src_key); \
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
exit(-1); \
} \
} while (0)
#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = value.get(); \
} while (0)
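// For example, GigaAM CTC models ship an empty "normalize_type" in their
// metadata, so offline-nemo-enc-dec-ctc-model.cc reads that key with
//   SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(normalize_type_, "normalize_type");
// whereas SHERPA_ONNX_READ_META_DATA_STR rejects an empty value.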
#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
default_value) \
do { \
... ... @@ -164,7 +176,7 @@
} else { \
dst = value.get(); \
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value for %s\n", src_key); \
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
exit(-1); \
} \
} \
... ...
... ... @@ -10,8 +10,8 @@
#include "cppjieba/Jieba.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
... ...
... ... @@ -21,6 +21,7 @@ namespace {
enum class ModelType : std::uint8_t {
kEncDecCTCModelBPE,
kEncDecCTCModel,
kEncDecHybridRNNTCTCBPEModel,
kTdnn,
kZipformerCtc,
... ... @@ -75,6 +76,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
return ModelType::kEncDecCTCModelBPE;
} else if (model_type.get() == std::string("EncDecCTCModel")) {
return ModelType::kEncDecCTCModel;
} else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) {
return ModelType::kEncDecHybridRNNTCTCBPEModel;
} else if (model_type.get() == std::string("tdnn")) {
... ... @@ -121,22 +124,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
switch (model_type) {
case ModelType::kEncDecCTCModelBPE:
return std::make_unique<OfflineNemoEncDecCtcModel>(config);
break;
case ModelType::kEncDecCTCModel:
return std::make_unique<OfflineNemoEncDecCtcModel>(config);
case ModelType::kEncDecHybridRNNTCTCBPEModel:
return std::make_unique<OfflineNemoEncDecHybridRNNTCTCBPEModel>(config);
break;
case ModelType::kTdnn:
return std::make_unique<OfflineTdnnCtcModel>(config);
break;
case ModelType::kZipformerCtc:
return std::make_unique<OfflineZipformerCtcModel>(config);
break;
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(config);
break;
case ModelType::kTeleSpeechCtc:
return std::make_unique<OfflineTeleSpeechCtcModel>(config);
break;
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
... ... @@ -177,23 +176,19 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
switch (model_type) {
case ModelType::kEncDecCTCModelBPE:
return std::make_unique<OfflineNemoEncDecCtcModel>(mgr, config);
break;
case ModelType::kEncDecCTCModel:
return std::make_unique<OfflineNemoEncDecCtcModel>(mgr, config);
case ModelType::kEncDecHybridRNNTCTCBPEModel:
return std::make_unique<OfflineNemoEncDecHybridRNNTCTCBPEModel>(mgr,
config);
break;
case ModelType::kTdnn:
return std::make_unique<OfflineTdnnCtcModel>(mgr, config);
break;
case ModelType::kZipformerCtc:
return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
break;
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
break;
case ModelType::kTeleSpeechCtc:
return std::make_unique<OfflineTeleSpeechCtcModel>(mgr, config);
break;
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
... ...
... ... @@ -66,6 +66,10 @@ class OfflineCtcModel {
// Return true if the model supports batch size > 1
virtual bool SupportBatchProcessing() const { return true; }
// Return true for models from https://github.com/salute-developers/GigaAM
// and false otherwise.
virtual bool IsGigaAM() const { return false; }
};
} // namespace sherpa_onnx
... ...
... ... @@ -72,6 +72,8 @@ class OfflineNemoEncDecCtcModel::Impl {
std::string FeatureNormalizationMethod() const { return normalize_type_; }
bool IsGigaAM() const { return is_giga_am_; }
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
... ... @@ -92,7 +94,9 @@ class OfflineNemoEncDecCtcModel::Impl {
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(normalize_type_,
"normalize_type");
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0);
}
private:
... ... @@ -112,6 +116,10 @@ class OfflineNemoEncDecCtcModel::Impl {
int32_t vocab_size_ = 0;
int32_t subsampling_factor_ = 0;
std::string normalize_type_;
// it is 1 for models from
// https://github.com/salute-developers/GigaAM
int32_t is_giga_am_ = 0;
};
OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
... ... @@ -146,4 +154,6 @@ std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const {
return impl_->FeatureNormalizationMethod();
}
bool OfflineNemoEncDecCtcModel::IsGigaAM() const { return impl_->IsGigaAM(); }
} // namespace sherpa_onnx
... ...
... ... @@ -76,6 +76,8 @@ class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
// for details
std::string FeatureNormalizationMethod() const override;
bool IsGigaAM() const override;
private:
class Impl;
std::unique_ptr<Impl> impl_;
... ...
... ... @@ -104,12 +104,21 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
if (!config_.model_config.nemo_ctc.model.empty()) {
if (model_->IsGigaAM()) {
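// These settings mirror GigaAM's front-end (see create_fbank() in
// scripts/nemo/GigaAM/test-onnx-ctc.py): 16 kHz input, 64 mel bins,
// Hann window, no preemphasis, no DC-offset removal, mel range 0-8000 Hz.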
config_.feat_config.low_freq = 0;
config_.feat_config.high_freq = 8000;
config_.feat_config.remove_dc_offset = false;
config_.feat_config.preemph_coeff = 0;
config_.feat_config.window_type = "hann";
config_.feat_config.feature_dim = 64;
} else {
config_.feat_config.low_freq = 0;
config_.feat_config.high_freq = 0;
config_.feat_config.is_librosa = true;
config_.feat_config.remove_dc_offset = false;
config_.feat_config.window_type = "hann";
}
}
if (!config_.model_config.wenet_ctc.model.empty()) {
// WeNet CTC models assume input samples are in the range
... ...
... ... @@ -172,7 +172,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config);
}
if (model_type == "EncDecCTCModelBPE" ||
if (model_type == "EncDecCTCModelBPE" || model_type == "EncDecCTCModel" ||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
... ... @@ -189,6 +189,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - Non-streaming transducer models from icefall\n"
" - Non-streaming Paraformer models from FunASR\n"
" - EncDecCTCModelBPE models from NeMo\n"
" - EncDecCTCModel models from NeMo\n"
" - EncDecHybridRNNTCTCBPEModel models from NeMo\n"
" - Whisper models\n"
" - Tdnn models\n"
... ... @@ -343,7 +344,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config);
}
if (model_type == "EncDecCTCModelBPE" ||
if (model_type == "EncDecCTCModelBPE" || model_type == "EncDecCTCModel" ||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
... ... @@ -360,6 +361,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - Non-streaming transducer models from icefall\n"
" - Non-streaming Paraformer models from FunASR\n"
" - EncDecCTCModelBPE models from NeMo\n"
" - EncDecCTCModel models from NeMo\n"
" - EncDecHybridRNNTCTCBPEModel models from NeMo\n"
" - Whisper models\n"
" - Tdnn models\n"
... ...
... ... @@ -7,6 +7,8 @@
#include <cassert>
#include <fstream>
#include <sstream>
#include <string>
#include <utility>
#if __ANDROID_API__ >= 9
#include <strstream>
... ... @@ -16,10 +18,54 @@
#endif
#include "sherpa-onnx/csrc/base64-decode.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
std::unordered_map<std::string, int32_t> ReadTokens(
std::istream &is,
std::unordered_map<int32_t, std::string> *id2token /*= nullptr*/) {
std::unordered_map<std::string, int32_t> token2id;
std::string line;
std::string sym;
int32_t id = -1;
while (std::getline(is, line)) {
std::istringstream iss(line);
iss >> sym;
if (iss.eof()) {
id = atoi(sym.c_str());
sym = " ";
} else {
iss >> id;
}
// eat the trailing \r\n on windows
iss >> std::ws;
if (!iss.eof()) {
SHERPA_ONNX_LOGE("Error: %s", line.c_str());
exit(-1);
}
#if 0
if (token2id.count(sym)) {
SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d",
sym.c_str(), line.c_str(), token2id.at(sym));
exit(-1);
}
#endif
if (id2token) {
id2token->insert({id, sym});
}
token2id.insert({std::move(sym), id});
}
return token2id;
}
SymbolTable::SymbolTable(const std::string &filename, bool is_file) {
if (is_file) {
std::ifstream is(filename);
... ... @@ -39,25 +85,7 @@ SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) {
}
#endif
void SymbolTable::Init(std::istream &is) {
std::string sym;
int32_t id = 0;
while (is >> sym >> id) {
#if 0
// we disable the test here since for some multi-lingual BPE models
// from NeMo, the same symbol can appear multiple times with different IDs.
if (sym != " ") {
assert(sym2id_.count(sym) == 0);
}
#endif
assert(id2sym_.count(id) == 0);
sym2id_.insert({sym, id});
id2sym_.insert({id, sym});
}
assert(is.eof());
}
void SymbolTable::Init(std::istream &is) { sym2id_ = ReadTokens(is, &id2sym_); }
std::string SymbolTable::ToString() const {
std::ostringstream os;
... ...
... ... @@ -5,8 +5,10 @@
#ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
#define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
#include <istream>
#include <string>
#include <unordered_map>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
... ... @@ -15,6 +17,16 @@
namespace sherpa_onnx {
// The same token can be mapped to different integer IDs, so
// we need an id2token argument here.
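// Each line of tokens.txt is either "<symbol> <id>" or just "<id>"; a
// single-column line means the symbol itself is a blank, and ReadTokens()
// maps that ID to " ". Illustrative lines (IDs are made up):
//   а 0
//   32
//   <blk> 33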
std::unordered_map<std::string, int32_t> ReadTokens(
std::istream &is,
std::unordered_map<int32_t, std::string> *id2token = nullptr);
std::vector<int32_t> ConvertTokensToIds(
const std::unordered_map<std::string, int32_t> &token2id,
const std::vector<std::string> &tokens);
/// It manages mapping between symbols and integer IDs.
class SymbolTable {
public:
... ...
... ... @@ -394,6 +394,16 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
modelType = "transducer",
)
}
19 -> {
val modelDir = "sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24"
return OfflineModelConfig(
nemo = OfflineNemoEncDecCtcModelConfig(
model = "$modelDir/model.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
)
}
}
return null
}
... ...