Fangjun Kuang
Committed by GitHub

Add C++ runtime for non-streaming fast conformer transducer from NeMo. (#854)

... ... @@ -13,6 +13,105 @@ echo "PATH: $PATH"
which $EXE
log "------------------------------------------------------------------------"
log "Run Nemo fast conformer hybrid transducer ctc models (transducer branch)"
log "------------------------------------------------------------------------"
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k.tar.bz2
name=$(basename $url)
curl -SL -O $url
tar xvf $name
rm $name
repo=$(basename -s .tar.bz2 $name)
ls -lh $repo
log "test $repo"
test_wavs=(
de-german.wav
es-spanish.wav
hr-croatian.wav
po-polish.wav
uk-ukrainian.wav
en-english.wav
fr-french.wav
it-italian.wav
ru-russian.wav
)
for w in "${test_wavs[@]}"; do
time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder.onnx \
--decoder=$repo/decoder.onnx \
--joiner=$repo/joiner.onnx \
--debug=1 \
$repo/test_wavs/$w
done
rm -rf $repo
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-24500.tar.bz2
name=$(basename $url)
curl -SL -O $url
tar xvf $name
rm $name
repo=$(basename -s .tar.bz2 $name)
ls -lh $repo
log "Test $repo"
time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder.onnx \
--decoder=$repo/decoder.onnx \
--joiner=$repo/joiner.onnx \
--debug=1 \
$repo/test_wavs/en-english.wav
rm -rf $repo
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-es-1424.tar.bz2
name=$(basename $url)
curl -SL -O $url
tar xvf $name
rm $name
repo=$(basename -s .tar.bz2 $name)
ls -lh $repo
log "test $repo"
time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder.onnx \
--decoder=$repo/decoder.onnx \
--joiner=$repo/joiner.onnx \
--debug=1 \
$repo/test_wavs/es-spanish.wav
rm -rf $repo
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-de-es-fr-14288.tar.bz2
name=$(basename $url)
curl -SL -O $url
tar xvf $name
rm $name
repo=$(basename -s .tar.bz2 $name)
ls -lh $repo
log "Test $repo"
time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder.onnx \
--decoder=$repo/decoder.onnx \
--joiner=$repo/joiner.onnx \
--debug=1 \
$repo/test_wavs/en-english.wav \
$repo/test_wavs/de-german.wav \
$repo/test_wavs/fr-french.wav \
$repo/test_wavs/es-spanish.wav
rm -rf $repo
log "------------------------------------------------------------"
log "Run Conformer transducer (English)"
log "------------------------------------------------------------"
... ...
... ... @@ -128,6 +128,14 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*
- name: Test offline transducer
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-transducer.sh
- name: Test spoken language identification (C++ API)
shell: bash
run: |
... ... @@ -215,14 +223,6 @@ jobs:
.github/scripts/test-online-paraformer.sh
- name: Test offline transducer
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-transducer.sh
- name: Test online transducer
shell: bash
run: |
... ...
... ... @@ -107,6 +107,14 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test offline transducer
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-transducer.sh
- name: Test online CTC
shell: bash
run: |
... ... @@ -192,14 +200,6 @@ jobs:
.github/scripts/test-offline-ctc.sh
- name: Test offline transducer
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-transducer.sh
- name: Test online transducer
shell: bash
run: |
... ...
... ... @@ -104,3 +104,4 @@ sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
sherpa-onnx-ced-*
node_modules
package-lock.json
sherpa-onnx-nemo-*
... ...
#!/usr/bin/env python3

"""
This file shows how to use a non-streaming CTC model from NeMo
to decode files.

Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models

The example model supports 10 languages and is converted from
https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
"""

from pathlib import Path

import sherpa_onnx
import soundfile as sf


def create_recognizer():
    model = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/model.onnx"
    tokens = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt"

    test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav"

    if not Path(model).is_file() or not Path(test_wav).is_file():
        raise ValueError(
            """Please download model files from
            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
            """
        )
    return (
        sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
            model=model,
            tokens=tokens,
            debug=True,
        ),
        test_wav,
    )


def main():
    recognizer, wave_filename = create_recognizer()

    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
    # sample_rate does not need to be 16000 Hz

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, audio)
    recognizer.decode_stream(stream)
    print(wave_filename)
    print(stream.result)


if __name__ == "__main__":
    main()
... ...
#!/usr/bin/env python3

"""
This file shows how to use a non-streaming transducer model from NeMo
to decode files.

Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models

The example model supports 10 languages and is converted from
https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
"""

from pathlib import Path

import sherpa_onnx
import soundfile as sf


def create_recognizer():
    encoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/encoder.onnx"
    decoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/decoder.onnx"
    joiner = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/joiner.onnx"
    tokens = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt"

    test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav"
    # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav"

    if not Path(encoder).is_file() or not Path(test_wav).is_file():
        raise ValueError(
            """Please download model files from
            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
            """
        )
    return (
        sherpa_onnx.OfflineRecognizer.from_transducer(
            encoder=encoder,
            decoder=decoder,
            joiner=joiner,
            tokens=tokens,
            model_type="nemo_transducer",
            debug=True,
        ),
        test_wav,
    )


def main():
    recognizer, wave_filename = create_recognizer()

    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
    # sample_rate does not need to be 16000 Hz

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, audio)
    recognizer.decode_stream(stream)
    print(wave_filename)
    print(stream.result)


if __name__ == "__main__":
    main()
... ...
... ... @@ -40,9 +40,11 @@ set(sources
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
offline-transducer-greedy-search-decoder.cc
offline-transducer-greedy-search-nemo-decoder.cc
offline-transducer-model-config.cc
offline-transducer-model.cc
offline-transducer-modified-beam-search-decoder.cc
offline-transducer-nemo-model.cc
offline-wenet-ctc-model-config.cc
offline-wenet-ctc-model.cc
offline-whisper-greedy-search-decoder.cc
... ...
... ... @@ -56,6 +56,19 @@ struct FeatureExtractorConfig {
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
std::string window_type = "povey";  // window type, e.g., "povey" or "hann"
// For models from NeMo
// This option is not exposed and is set internally when loading models.
// Possible values:
// - per_feature
// - all_features (not implemented yet)
// - fixed_mean (not implemented)
// - fixed_std (not implemented)
// - or leave it empty
// See
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
// for details
std::string nemo_normalize_type;
std::string ToString() const;
void Register(ParseOptions *po);
... ...
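Note: per_feature above means each feature dimension is normalized by its own mean and standard deviation computed over the time axis of the utterance. A minimal numpy sketch of that idea, assuming NeMo's usual definition (the unbiased std and the 1e-5 epsilon are assumptions; see the NeMo link above for the exact formula):

import numpy as np

def per_feature_normalize(features: np.ndarray) -> np.ndarray:
    """Normalize a (num_frames, feat_dim) matrix per feature dimension."""
    mean = features.mean(axis=0, keepdims=True)
    std = features.std(axis=0, keepdims=True, ddof=1)  # unbiased std (assumed)
    return (features - mean) / (std + 1e-5)  # epsilon is an assumption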
... ... @@ -68,7 +68,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
: config_(config),
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens) {
if (sym_.contains("<unk>")) {
if (sym_.Contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
... ... @@ -87,7 +87,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
: config_(config),
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
sym_(mgr, config.model_config.tokens) {
if (sym_.contains("<unk>")) {
if (sym_.Contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
... ...
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2023-2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
... ...
... ... @@ -38,7 +38,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
std::string text;
for (int32_t i = 0; i != src.tokens.size(); ++i) {
if (sym_table.contains("SIL") && src.tokens[i] == sym_table["SIL"]) {
if (sym_table.Contains("SIL") && src.tokens[i] == sym_table["SIL"]) {
// tdnn models from yesno have a SIL token, we should remove it.
continue;
}
... ... @@ -103,9 +103,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
decoder_ = std::make_unique<OfflineCtcFstDecoder>(
config_.ctc_fst_decoder_config);
} else if (config_.decoding_method == "greedy_search") {
if (!symbol_table_.contains("<blk>") &&
!symbol_table_.contains("<eps>") &&
!symbol_table_.contains("<blank>")) {
if (!symbol_table_.Contains("<blk>") &&
!symbol_table_.Contains("<eps>") &&
!symbol_table_.Contains("<blank>")) {
SHERPA_ONNX_LOGE(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> or <blank> and its ID.");
... ... @@ -113,12 +113,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
int32_t blank_id = 0;
if (symbol_table_.contains("<blk>")) {
if (symbol_table_.Contains("<blk>")) {
blank_id = symbol_table_["<blk>"];
} else if (symbol_table_.contains("<eps>")) {
} else if (symbol_table_.Contains("<eps>")) {
// for tdnn models of the yesno recipe from icefall
blank_id = symbol_table_["<eps>"];
} else if (symbol_table_.contains("<blank>")) {
} else if (symbol_table_.Contains("<blank>")) {
// for Wenet CTC models
blank_id = symbol_table_["<blank>"];
}
... ...
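Note: the lookup order above is deliberate: <blk> comes from icefall models, <eps> from the tdnn models of the yesno recipe, and <blank> from WeNet CTC models. A small sketch of the same resolution logic, assuming the usual sherpa-onnx tokens.txt format of one "symbol id" pair per line:

def find_blank_id(tokens_txt: str) -> int:
    """Resolve the blank ID the way OfflineRecognizerCtcImpl does."""
    sym2id = {}
    with open(tokens_txt, encoding="utf-8") as f:
        for line in f:
            fields = line.rstrip("\n").split()
            if len(fields) == 2:
                sym2id[fields[0]] = int(fields[1])
    for blank in ("<blk>", "<eps>", "<blank>"):  # same priority as the C++ code
        if blank in sym2id:
            return sym2id[blank]
    raise ValueError("tokens.txt must contain <blk>, <eps>, or <blank>")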
... ... @@ -11,6 +11,7 @@
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
... ... @@ -23,6 +24,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
const auto &model_type = config.model_config.model_type;
if (model_type == "transducer") {
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
} else if (model_type == "nemo_transducer") {
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config);
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
... ... @@ -122,6 +125,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
}
if (model_type == "EncDecHybridRNNTCTCBPEModel" &&
!config.model_config.transducer.decoder_filename.empty() &&
!config.model_config.transducer.joiner_filename.empty()) {
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config);
}
if (model_type == "EncDecCTCModelBPE" ||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
... ... @@ -155,6 +164,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
const auto &model_type = config.model_config.model_type;
if (model_type == "transducer") {
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
} else if (model_type == "nemo_transducer") {
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config);
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
... ... @@ -254,6 +265,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
}
if (model_type == "EncDecHybridRNNTCTCBPEModel" &&
!config.model_config.transducer.decoder_filename.empty() &&
!config.model_config.transducer.joiner_filename.empty()) {
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config);
}
if (model_type == "EncDecCTCModelBPE" ||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
... ...
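Note: when no model_type is given in the config, the code above falls back to the model_type attribute stored in the ONNX model's metadata; EncDecHybridRNNTCTCBPEModel together with non-empty decoder and joiner filenames selects the new NeMo transducer path. A hedged sketch for inspecting that metadata yourself with onnxruntime:

import onnxruntime as ort

def get_model_type(model_path: str) -> str:
    sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    # Exported NeMo hybrid models carry model_type=EncDecHybridRNNTCTCBPEModel
    return sess.get_modelmeta().custom_metadata_map.get("model_type", "")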
// sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
//
// Copyright (c) 2022-2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_

#include <fstream>
#include <ios>
#include <memory>
#include <regex>  // NOLINT
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
#include "sherpa-onnx/csrc/utils.h"

namespace sherpa_onnx {

// defined in ./offline-recognizer-transducer-impl.h
OfflineRecognitionResult Convert(const OfflineTransducerDecoderResult &src,
                                 const SymbolTable &sym_table,
                                 int32_t frame_shift_ms,
                                 int32_t subsampling_factor);

class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
 public:
  explicit OfflineRecognizerTransducerNeMoImpl(
      const OfflineRecognizerConfig &config)
      : config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(std::make_unique<OfflineTransducerNeMoModel>(
            config_.model_config)) {
    if (config_.decoding_method == "greedy_search") {
      decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>(
          model_.get(), config_.blank_penalty);
    } else {
      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                       config_.decoding_method.c_str());
      exit(-1);
    }

    PostInit();
  }

#if __ANDROID_API__ >= 9
  explicit OfflineRecognizerTransducerNeMoImpl(
      AAssetManager *mgr, const OfflineRecognizerConfig &config)
      : config_(config),
        symbol_table_(mgr, config_.model_config.tokens),
        model_(std::make_unique<OfflineTransducerNeMoModel>(
            mgr, config_.model_config)) {
    if (config_.decoding_method == "greedy_search") {
      decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>(
          model_.get(), config_.blank_penalty);
    } else {
      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                       config_.decoding_method.c_str());
      exit(-1);
    }

    PostInit();
  }
#endif

  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(config_.feat_config);
  }

  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    int32_t feat_dim = ss[0]->FeatureDim();

    std::vector<Ort::Value> features;
    features.reserve(n);

    std::vector<std::vector<float>> features_vec(n);
    std::vector<int64_t> features_length_vec(n);

    for (int32_t i = 0; i != n; ++i) {
      auto f = ss[i]->GetFrames();
      int32_t num_frames = f.size() / feat_dim;

      features_length_vec[i] = num_frames;
      features_vec[i] = std::move(f);

      std::array<int64_t, 2> shape = {num_frames, feat_dim};

      Ort::Value x = Ort::Value::CreateTensor(
          memory_info, features_vec[i].data(), features_vec[i].size(),
          shape.data(), shape.size());
      features.push_back(std::move(x));
    }

    std::vector<const Ort::Value *> features_pointer(n);
    for (int32_t i = 0; i != n; ++i) {
      features_pointer[i] = &features[i];
    }

    std::array<int64_t, 1> features_length_shape = {n};
    Ort::Value x_length = Ort::Value::CreateTensor(
        memory_info, features_length_vec.data(), n,
        features_length_shape.data(), features_length_shape.size());

    Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);

    auto t = model_->RunEncoder(std::move(x), std::move(x_length));
    // t[0] encoder_out, float tensor, (batch_size, dim, T)
    // t[1] encoder_out_length, int64 tensor, (batch_size,)

    Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);

    auto results = decoder_->Decode(std::move(encoder_out), std::move(t[1]));

    int32_t frame_shift_ms = 10;
    for (int32_t i = 0; i != n; ++i) {
      auto r = Convert(results[i], symbol_table_, frame_shift_ms,
                       model_->SubsamplingFactor());

      ss[i]->SetResult(r);
    }
  }

 private:
  void PostInit() {
    config_.feat_config.nemo_normalize_type =
        model_->FeatureNormalizationMethod();

    config_.feat_config.low_freq = 0;
    // config_.feat_config.high_freq = 8000;
    config_.feat_config.is_librosa = true;
    config_.feat_config.remove_dc_offset = false;
    // config_.feat_config.window_type = "hann";
    config_.feat_config.dither = 0;

    int32_t vocab_size = model_->VocabSize();

    // check the blank ID
    if (!symbol_table_.Contains("<blk>")) {
      SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>");
      exit(-1);
    }

    if (symbol_table_["<blk>"] != vocab_size - 1) {
      SHERPA_ONNX_LOGE("<blk> is not the last token!");
      exit(-1);
    }

    if (symbol_table_.NumSymbols() != vocab_size) {
      SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
                       symbol_table_.NumSymbols(), vocab_size);
      exit(-1);
    }
  }

 private:
  OfflineRecognizerConfig config_;
  SymbolTable symbol_table_;
  std::unique_ptr<OfflineTransducerNeMoModel> model_;
  std::unique_ptr<OfflineTransducerDecoder> decoder_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
... ...
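Note: DecodeStreams() above batches utterances by padding their (T, C) feature matrices to a common length before calling RunEncoder(). A numpy sketch of what PadSequence plus the length tensor amount to (zero padding assumed, matching the padding value 0 passed above):

import numpy as np

def pad_features(features):
    """features: list of (T_i, C) float32 arrays -> (N, T_max, C), (N,)."""
    feat_dim = features[0].shape[1]
    lengths = np.array([f.shape[0] for f in features], dtype=np.int64)
    padded = np.zeros((len(features), int(lengths.max()), feat_dim), np.float32)
    for i, f in enumerate(features):
        padded[i, : f.shape[0]] = f
    return padded, lengths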
... ... @@ -35,7 +35,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
std::string text;
for (auto i : src.tokens) {
if (!sym_table.contains(i)) {
if (!sym_table.Contains(i)) {
continue;
}
... ...
... ... @@ -14,6 +14,7 @@
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
... ... @@ -26,7 +27,7 @@ namespace sherpa_onnx {
struct OfflineRecognitionResult;
struct OfflineRecognizerConfig {
OfflineFeatureExtractorConfig feat_config;
FeatureExtractorConfig feat_config;
OfflineModelConfig model_config;
OfflineLMConfig lm_config;
OfflineCtcFstDecoderConfig ctc_fst_decoder_config;
... ... @@ -44,7 +45,7 @@ struct OfflineRecognizerConfig {
OfflineRecognizerConfig() = default;
OfflineRecognizerConfig(
const OfflineFeatureExtractorConfig &feat_config,
const FeatureExtractorConfig &feat_config,
const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
const std::string &decoding_method, int32_t max_active_paths,
... ...
... ... @@ -52,42 +52,25 @@ static void ComputeMeanAndInvStd(const float *p, int32_t num_rows,
}
}
void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
po->Register("sample-rate", &sampling_rate,
"Sampling rate of the input waveform. "
"Note: You can have a different "
"sample rate for the input waveform. We will do resampling "
"inside the feature extractor");
po->Register("feat-dim", &feature_dim,
"Feature dimension. Must match the one expected by the model.");
}
std::string OfflineFeatureExtractorConfig::ToString() const {
std::ostringstream os;
os << "OfflineFeatureExtractorConfig(";
os << "sampling_rate=" << sampling_rate << ", ";
os << "feature_dim=" << feature_dim << ")";
return os.str();
}
class OfflineStream::Impl {
public:
explicit Impl(const OfflineFeatureExtractorConfig &config,
explicit Impl(const FeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: config_(config), context_graph_(context_graph) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;
opts_.mel_opts.num_bins = config.feature_dim;
// Please see
// https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27
// and
// https://github.com/k2-fsa/sherpa-onnx/issues/514
opts_.mel_opts.high_freq = -400;
opts_.mel_opts.high_freq = config.high_freq;
opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
... ... @@ -237,7 +220,7 @@ class OfflineStream::Impl {
}
private:
OfflineFeatureExtractorConfig config_;
FeatureExtractorConfig config_;
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
knf::FbankOptions opts_;
... ... @@ -245,8 +228,7 @@ class OfflineStream::Impl {
ContextGraphPtr context_graph_;
};
OfflineStream::OfflineStream(
const OfflineFeatureExtractorConfig &config /*= {}*/,
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
... ...
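Note: the change above stops hard-coding dither, snip_edges, and high_freq and instead forwards every FeatureExtractorConfig field to knf::FbankOptions, which is what lets PostInit() in the NeMo recognizer request librosa-style mel banks. A sketch of the equivalent setup with the kaldi-native-fbank Python bindings, using the values PostInit() sets (window_type "hann" is still commented out there, so it is omitted here):

import kaldi_native_fbank as knf

opts = knf.FbankOptions()
opts.frame_opts.samp_freq = 16000
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.remove_dc_offset = False  # set by PostInit() for NeMo models
opts.mel_opts.num_bins = 80
opts.mel_opts.low_freq = 0
opts.mel_opts.is_librosa = True  # librosa-style mel banks, as NeMo expects
fbank = knf.OnlineFbank(opts)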
... ... @@ -11,6 +11,7 @@
#include <vector>
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
... ... @@ -32,46 +33,12 @@ struct OfflineRecognitionResult {
std::string AsJsonString() const;
};
struct OfflineFeatureExtractorConfig {
// Sampling rate used by the feature extractor. If it is different from
// the sampling rate of the input waveform, we will do resampling inside.
int32_t sampling_rate = 16000;
// Feature dimension
int32_t feature_dim = 80;
// Set internally by some models, e.g., paraformer and wenet CTC models set
// it to false.
// This parameter is not exposed to users from the commandline
// If true, the feature extractor expects inputs to be normalized to
// the range [-1, 1].
// If false, we will multiply the inputs by 32768
bool normalize_samples = true;
// For models from NeMo
// This option is not exposed and is set internally when loading models.
// Possible values:
// - per_feature
// - all_features (not implemented yet)
// - fixed_mean (not implemented)
// - fixed_std (not implemented)
// - or just leave it to empty
// See
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
// for details
std::string nemo_normalize_type;
std::string ToString() const;
void Register(ParseOptions *po);
};
struct WhisperTag {};
struct CEDTag {};
class OfflineStream {
public:
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
explicit OfflineStream(const FeatureExtractorConfig &config = {},
ContextGraphPtr context_graph = {});
explicit OfflineStream(WhisperTag tag);
... ...
... ... @@ -14,7 +14,7 @@ namespace sherpa_onnx {
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
public:
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
float blank_penalty)
: model_(model), blank_penalty_(blank_penalty) {}
... ...
// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h"

#include <algorithm>
#include <iterator>
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(
    int32_t token, OrtAllocator *allocator) {
  std::array<int64_t, 2> shape{1, 1};

  Ort::Value decoder_input =
      Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());

  std::array<int64_t, 1> length_shape{1};
  Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>(
      allocator, length_shape.data(), length_shape.size());

  int32_t *p = decoder_input.GetTensorMutableData<int32_t>();
  int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>();

  p[0] = token;
  p_length[0] = 1;

  return {std::move(decoder_input), std::move(decoder_input_length)};
}

static OfflineTransducerDecoderResult DecodeOne(
    const float *p, int32_t num_rows, int32_t num_cols,
    OfflineTransducerNeMoModel *model, float blank_penalty) {
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  OfflineTransducerDecoderResult ans;

  int32_t vocab_size = model->VocabSize();
  int32_t blank_id = vocab_size - 1;

  auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());

  std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
      model->RunDecoder(std::move(decoder_input_pair.first),
                        std::move(decoder_input_pair.second),
                        model->GetDecoderInitStates(1));

  std::array<int64_t, 3> encoder_shape{1, num_cols, 1};

  for (int32_t t = 0; t != num_rows; ++t) {
    Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
        memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
        encoder_shape.data(), encoder_shape.size());

    Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
                                        View(&decoder_output_pair.first));

    float *p_logit = logit.GetTensorMutableData<float>();
    if (blank_penalty > 0) {
      p_logit[blank_id] -= blank_penalty;
    }

    auto y = static_cast<int32_t>(std::distance(
        static_cast<const float *>(p_logit),
        std::max_element(static_cast<const float *>(p_logit),
                         static_cast<const float *>(p_logit) + vocab_size)));

    if (y != blank_id) {
      ans.tokens.push_back(y);
      ans.timestamps.push_back(t);

      decoder_input_pair = BuildDecoderInput(y, model->Allocator());

      decoder_output_pair =
          model->RunDecoder(std::move(decoder_input_pair.first),
                            std::move(decoder_input_pair.second),
                            std::move(decoder_output_pair.second));
    }  // if (y != blank_id)
  }    // for (int32_t t = 0; t != num_rows; ++t)

  return ans;
}

std::vector<OfflineTransducerDecoderResult>
OfflineTransducerGreedySearchNeMoDecoder::Decode(
    Ort::Value encoder_out, Ort::Value encoder_out_length,
    OfflineStream ** /*ss = nullptr*/, int32_t /*n = 0*/) {
  auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();

  int32_t batch_size = static_cast<int32_t>(shape[0]);
  int32_t dim1 = static_cast<int32_t>(shape[1]);
  int32_t dim2 = static_cast<int32_t>(shape[2]);

  const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();
  const float *p = encoder_out.GetTensorData<float>();

  std::vector<OfflineTransducerDecoderResult> ans(batch_size);

  for (int32_t i = 0; i != batch_size; ++i) {
    const float *this_p = p + dim1 * dim2 * i;
    int32_t this_len = p_length[i];

    ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_);
  }

  return ans;
}

}  // namespace sherpa_onnx
... ...
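Note: DecodeOne() above is a frame-synchronous greedy search. By construction it emits at most one non-blank symbol per encoder frame, re-runs the stateful prediction network only when a non-blank token is produced, and uses the last token ID (vocab_size - 1) as blank. The same control flow as a numpy-level sketch; run_decoder and run_joiner stand in for the ONNX sessions and are assumptions:

import numpy as np

def greedy_decode(encoder_out, run_decoder, run_joiner, blank_id,
                  blank_penalty=0.0):
    """Frame-synchronous greedy search mirroring DecodeOne() above.

    encoder_out: (T, C) array for one utterance.
    run_decoder(token, states) -> (decoder_out, new_states)      # stand-in
    run_joiner(enc_frame, decoder_out) -> (vocab_size,) logits   # stand-in
    """
    tokens, timestamps = [], []
    decoder_out, states = run_decoder(blank_id, None)  # init with blank
    for t, enc in enumerate(encoder_out):
        logits = np.array(run_joiner(enc, decoder_out), dtype=np.float32)
        if blank_penalty > 0:
            logits[blank_id] -= blank_penalty
        y = int(np.argmax(logits))
        if y != blank_id:  # at most one non-blank symbol per frame
            tokens.append(y)
            timestamps.append(t)
            decoder_out, states = run_decoder(y, states)
    return tokens, timestamps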
// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_

#include <vector>

#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"

namespace sherpa_onnx {

class OfflineTransducerGreedySearchNeMoDecoder
    : public OfflineTransducerDecoder {
 public:
  OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model,
                                           float blank_penalty)
      : model_(model), blank_penalty_(blank_penalty) {}

  std::vector<OfflineTransducerDecoderResult> Decode(
      Ort::Value encoder_out, Ort::Value encoder_out_length,
      OfflineStream **ss = nullptr, int32_t n = 0) override;

 private:
  OfflineTransducerNeMoModel *model_;  // Not owned
  float blank_penalty_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
... ...
// sherpa-onnx/csrc/offline-transducer-nemo-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/transpose.h"

namespace sherpa_onnx {

class OfflineTransducerNeMoModel::Impl {
 public:
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_WARNING),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(config.transducer.encoder_filename);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.transducer.decoder_filename);
      InitDecoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.transducer.joiner_filename);
      InitJoiner(buf.data(), buf.size());
    }
  }

#if __ANDROID_API__ >= 9
  Impl(AAssetManager *mgr, const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_WARNING),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(mgr, config.transducer.encoder_filename);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.transducer.decoder_filename);
      InitDecoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.transducer.joiner_filename);
      InitJoiner(buf.data(), buf.size());
    }
  }
#endif

  std::vector<Ort::Value> RunEncoder(Ort::Value features,
                                     Ort::Value features_length) {
    // (B, T, C) -> (B, C, T)
    features = Transpose12(allocator_, &features);

    std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
                                                std::move(features_length)};

    auto encoder_out = encoder_sess_->Run(
        {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
        encoder_inputs.size(), encoder_output_names_ptr_.data(),
        encoder_output_names_ptr_.size());

    return encoder_out;
  }

  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
      Ort::Value targets, Ort::Value targets_length,
      std::vector<Ort::Value> states) {
    std::vector<Ort::Value> decoder_inputs;
    decoder_inputs.reserve(2 + states.size());

    decoder_inputs.push_back(std::move(targets));
    decoder_inputs.push_back(std::move(targets_length));

    for (auto &s : states) {
      decoder_inputs.push_back(std::move(s));
    }

    auto decoder_out = decoder_sess_->Run(
        {}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
        decoder_inputs.size(), decoder_output_names_ptr_.data(),
        decoder_output_names_ptr_.size());

    std::vector<Ort::Value> states_next;
    states_next.reserve(states.size());

    // decoder_out[0]: decoder_output
    // decoder_out[1]: decoder_output_length
    // decoder_out[2:]: states_next
    for (int32_t i = 0; i != states.size(); ++i) {
      states_next.push_back(std::move(decoder_out[i + 2]));
    }

    // we discard decoder_out[1]
    return {std::move(decoder_out[0]), std::move(states_next)};
  }

  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
    std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
                                              std::move(decoder_out)};
    auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(),
                                   joiner_input.data(), joiner_input.size(),
                                   joiner_output_names_ptr_.data(),
                                   joiner_output_names_ptr_.size());

    return std::move(logit[0]);
  }

  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
    std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
    Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
                                                    s0_shape.size());

    Fill<float>(&s0, 0);

    std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_};
    Ort::Value s1 = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(),
                                                    s1_shape.size());

    Fill<float>(&s1, 0);

    std::vector<Ort::Value> states;
    states.reserve(2);

    states.push_back(std::move(s0));
    states.push_back(std::move(s1));

    return states;
  }

  int32_t SubsamplingFactor() const { return subsampling_factor_; }

  int32_t VocabSize() const { return vocab_size_; }

  OrtAllocator *Allocator() const { return allocator_; }

  std::string FeatureNormalizationMethod() const { return normalize_type_; }

 private:
  void InitEncoder(void *model_data, size_t model_data_length) {
    encoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                  &encoder_input_names_ptr_);

    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                   &encoder_output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---encoder---\n";
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");

    // We need to increase it by 1 since NeMo does not count the blank token
    // when computing vocab_size.
    vocab_size_ += 1;

    SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
    SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
    SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
    SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");

    if (normalize_type_ == "NA") {
      normalize_type_ = "";
    }
  }

  void InitDecoder(void *model_data, size_t model_data_length) {
    decoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
                  &decoder_input_names_ptr_);

    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
                   &decoder_output_names_ptr_);
  }

  void InitJoiner(void *model_data, size_t model_data_length) {
    joiner_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(joiner_sess_.get(), &joiner_input_names_,
                  &joiner_input_names_ptr_);

    GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
                   &joiner_output_names_ptr_);
  }

 private:
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;
  std::unique_ptr<Ort::Session> joiner_sess_;

  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;

  std::vector<std::string> joiner_input_names_;
  std::vector<const char *> joiner_input_names_ptr_;

  std::vector<std::string> joiner_output_names_;
  std::vector<const char *> joiner_output_names_ptr_;

  int32_t vocab_size_ = 0;
  int32_t subsampling_factor_ = 8;
  std::string normalize_type_;
  int32_t pred_rnn_layers_ = -1;
  int32_t pred_hidden_ = -1;
};

OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
    const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

#if __ANDROID_API__ >= 9
OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
    AAssetManager *mgr, const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}
#endif

OfflineTransducerNeMoModel::~OfflineTransducerNeMoModel() = default;

std::vector<Ort::Value> OfflineTransducerNeMoModel::RunEncoder(
    Ort::Value features, Ort::Value features_length) const {
  return impl_->RunEncoder(std::move(features), std::move(features_length));
}

std::pair<Ort::Value, std::vector<Ort::Value>>
OfflineTransducerNeMoModel::RunDecoder(Ort::Value targets,
                                       Ort::Value targets_length,
                                       std::vector<Ort::Value> states) const {
  return impl_->RunDecoder(std::move(targets), std::move(targets_length),
                           std::move(states));
}

std::vector<Ort::Value> OfflineTransducerNeMoModel::GetDecoderInitStates(
    int32_t batch_size) const {
  return impl_->GetDecoderInitStates(batch_size);
}

Ort::Value OfflineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
                                                 Ort::Value decoder_out) const {
  return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
}

int32_t OfflineTransducerNeMoModel::SubsamplingFactor() const {
  return impl_->SubsamplingFactor();
}

int32_t OfflineTransducerNeMoModel::VocabSize() const {
  return impl_->VocabSize();
}

OrtAllocator *OfflineTransducerNeMoModel::Allocator() const {
  return impl_->Allocator();
}

std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const {
  return impl_->FeatureNormalizationMethod();
}

}  // namespace sherpa_onnx
... ...
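Note: GetDecoderInitStates() reflects that the NeMo prediction network is a stateful LSTM: its state is a pair of zero tensors shaped (pred_rnn_layers, batch_size, pred_hidden), with both dimensions read from the encoder metadata. A sketch:

import numpy as np

def get_decoder_init_states(pred_rnn_layers, pred_hidden, batch_size=1):
    """Zero LSTM states, mirroring GetDecoderInitStates() above."""
    h = np.zeros((pred_rnn_layers, batch_size, pred_hidden), dtype=np.float32)
    c = np.zeros_like(h)  # the two states are assumed to be the LSTM h and c
    return [h, c]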
// sherpa-onnx/csrc/offline-transducer-nemo-model.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"

namespace sherpa_onnx {

// see
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40
// Its decoder is stateful, not stateless.
class OfflineTransducerNeMoModel {
 public:
  explicit OfflineTransducerNeMoModel(const OfflineModelConfig &config);

#if __ANDROID_API__ >= 9
  OfflineTransducerNeMoModel(AAssetManager *mgr,
                             const OfflineModelConfig &config);
#endif

  ~OfflineTransducerNeMoModel();

  /** Run the encoder.
   *
   * @param features  A tensor of shape (N, T, C). It is changed in-place.
   * @param features_length  A 1-D tensor of shape (N,) containing the number
   *                         of valid frames in `features` before padding.
   *                         Its dtype is int64_t.
   *
   * @return Return a vector containing:
   *  - encoder_out: A 3-D tensor of shape (N, T', encoder_dim)
   *  - encoder_out_length: A 1-D tensor of shape (N,) containing the number
   *    of frames in `encoder_out` before padding.
   */
  std::vector<Ort::Value> RunEncoder(Ort::Value features,
                                     Ort::Value features_length) const;

  /** Run the decoder network.
   *
   * @param targets  An int32 tensor of shape (batch_size, 1)
   * @param targets_length  An int32 tensor of shape (batch_size,)
   * @param states  The states for the decoder model.
   * @return Return a pair:
   *  - first: decoder_out, a float tensor
   *  - second: states_next, the updated states for the decoder model
   *    (the decoder_out_length output of the model is discarded)
   */
  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
      Ort::Value targets, Ort::Value targets_length,
      std::vector<Ort::Value> states) const;

  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const;

  /** Run the joint network.
   *
   * @param encoder_out  Output of the encoder network.
   * @param decoder_out  Output of the decoder network.
   * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits.
   */
  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const;

  /** Return the subsampling factor of the model.
   */
  int32_t SubsamplingFactor() const;

  int32_t VocabSize() const;

  /** Return an allocator for allocating memory.
   */
  OrtAllocator *Allocator() const;

  // Possible values:
  // - per_feature
  // - all_features (not implemented yet)
  // - fixed_mean (not implemented)
  // - fixed_std (not implemented)
  // - or leave it empty
  // See
  // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
  // for details
  std::string FeatureNormalizationMethod() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
... ...
... ... @@ -223,8 +223,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
private:
void InitDecoder() {
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
!sym_.contains("<blank>")) {
if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") &&
!sym_.Contains("<blank>")) {
SHERPA_ONNX_LOGE(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> or <blank> and its ID.");
... ... @@ -232,12 +232,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
}
int32_t blank_id = 0;
if (sym_.contains("<blk>")) {
if (sym_.Contains("<blk>")) {
blank_id = sym_["<blk>"];
} else if (sym_.contains("<eps>")) {
} else if (sym_.Contains("<eps>")) {
// for tdnn models of the yesno recipe from icefall
blank_id = sym_["<eps>"];
} else if (sym_.contains("<blank>")) {
} else if (sym_.Contains("<blank>")) {
// for WeNet CTC models
blank_id = sym_["<blank>"];
}
... ...
... ... @@ -87,7 +87,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (sym_.contains("<unk>")) {
if (sym_.Contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
... ... @@ -103,19 +103,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(),
lm_.get(),
config_.max_active_paths,
config_.lm_config.scale,
unk_id_,
config_.blank_penalty,
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_, config_.blank_penalty,
config_.temperature_scale);
} else if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
model_.get(),
unk_id_,
config_.blank_penalty,
model_.get(), unk_id_, config_.blank_penalty,
config_.temperature_scale);
} else {
... ... @@ -132,7 +126,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (sym_.contains("<unk>")) {
if (sym_.Contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
... ... @@ -151,19 +145,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(),
lm_.get(),
config_.max_active_paths,
config_.lm_config.scale,
unk_id_,
config_.blank_penalty,
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_, config_.blank_penalty,
config_.temperature_scale);
} else if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
model_.get(),
unk_id_,
config_.blank_penalty,
model_.get(), unk_id_, config_.blank_penalty,
config_.temperature_scale);
} else {
... ...
... ... @@ -13,7 +13,7 @@ namespace sherpa_onnx {
* It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :]
*
* @param allocator
* @param v A 2-D tensor. Its data type is T.
* @param v A 3-D tensor. Its data type is T.
* @param dim0_start Start index of the first dimension.
* @param dim0_end End index of the first dimension.
* @param dim1_start Start index of the second dimension.
... ...
... ... @@ -100,9 +100,9 @@ int32_t SymbolTable::operator[](const std::string &sym) const {
return sym2id_.at(sym);
}
bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; }
bool SymbolTable::Contains(int32_t id) const { return id2sym_.count(id) != 0; }
bool SymbolTable::contains(const std::string &sym) const {
bool SymbolTable::Contains(const std::string &sym) const {
return sym2id_.count(sym) != 0;
}
... ...
... ... @@ -40,14 +40,16 @@ class SymbolTable {
int32_t operator[](const std::string &sym) const;
/// Return true if there is a symbol with the given ID.
bool contains(int32_t id) const;
bool Contains(int32_t id) const;
/// Return true if there is a given symbol in the symbol table.
bool contains(const std::string &sym) const;
bool Contains(const std::string &sym) const;
// for tokens.txt from Whisper
void ApplyBase64Decode();
int32_t NumSymbols() const { return id2sym_.size(); }
private:
void Init(std::istream &is);
... ...
... ... @@ -49,7 +49,7 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
word = word.replace(0, 3, " ");
}
}
if (symbol_table.contains(word)) {
if (symbol_table.Contains(word)) {
int32_t id = symbol_table[word];
tmp_ids.push_back(id);
} else {
... ...
... ... @@ -14,10 +14,10 @@ namespace sherpa_onnx {
static void PybindOfflineRecognizerConfig(py::module *m) {
using PyClass = OfflineRecognizerConfig;
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
.def(py::init<const OfflineFeatureExtractorConfig &,
const OfflineModelConfig &, const OfflineLMConfig &,
const OfflineCtcFstDecoderConfig &, const std::string &,
int32_t, const std::string &, float, float>(),
.def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
const std::string &, int32_t, const std::string &, float,
float>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OfflineLMConfig(),
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
... ...
... ... @@ -25,6 +25,7 @@ Args:
static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
using PyClass = OfflineRecognitionResult;
py::class_<PyClass>(*m, "OfflineRecognitionResult")
.def("__str__", &PyClass::AsJsonString)
.def_property_readonly(
"text",
[](const PyClass &self) -> py::str {
... ... @@ -37,18 +38,7 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
"timestamps", [](const PyClass &self) { return self.timestamps; });
}
static void PybindOfflineFeatureExtractorConfig(py::module *m) {
using PyClass = OfflineFeatureExtractorConfig;
py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig")
.def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000,
py::arg("feature_dim") = 80)
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
.def_readwrite("feature_dim", &PyClass::feature_dim)
.def("__str__", &PyClass::ToString);
}
void PybindOfflineStream(py::module *m) {
PybindOfflineFeatureExtractorConfig(m);
PybindOfflineRecognitionResult(m);
using PyClass = OfflineStream;
... ...
... ... @@ -4,8 +4,8 @@ from pathlib import Path
from typing import List, Optional
from _sherpa_onnx import (
FeatureExtractorConfig,
OfflineCtcFstDecoderConfig,
OfflineFeatureExtractorConfig,
OfflineModelConfig,
OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig,
... ... @@ -51,6 +51,7 @@ class OfflineRecognizer(object):
blank_penalty: float = 0.0,
debug: bool = False,
provider: str = "cpu",
model_type: str = "transducer",
):
"""
Please refer to
... ... @@ -106,10 +107,10 @@ class OfflineRecognizer(object):
num_threads=num_threads,
debug=debug,
provider=provider,
model_type="transducer",
model_type=model_type,
)
feat_config = OfflineFeatureExtractorConfig(
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
... ... @@ -182,7 +183,7 @@ class OfflineRecognizer(object):
model_type="paraformer",
)
feat_config = OfflineFeatureExtractorConfig(
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
... ... @@ -246,7 +247,7 @@ class OfflineRecognizer(object):
model_type="nemo_ctc",
)
feat_config = OfflineFeatureExtractorConfig(
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
... ... @@ -326,7 +327,7 @@ class OfflineRecognizer(object):
model_type="whisper",
)
feat_config = OfflineFeatureExtractorConfig(
feat_config = FeatureExtractorConfig(
sampling_rate=16000,
feature_dim=80,
)
... ... @@ -389,7 +390,7 @@ class OfflineRecognizer(object):
model_type="tdnn",
)
feat_config = OfflineFeatureExtractorConfig(
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
... ... @@ -453,7 +454,7 @@ class OfflineRecognizer(object):
model_type="wenet_ctc",
)
feat_config = OfflineFeatureExtractorConfig(
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
... ...
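Note: with the new model_type parameter above, the same from_transducer() classmethod now covers both icefall and NeMo transducers; the default "transducer" keeps the old icefall behavior. A usage sketch (the file paths are placeholders):

import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    encoder="./encoder.onnx",  # placeholder paths
    decoder="./decoder.onnx",
    joiner="./joiner.onnx",
    tokens="./tokens.txt",
    model_type="nemo_transducer",  # omit or pass "transducer" for icefall models
)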