Fangjun Kuang

Add C++ support for streaming NeMo CTC models. (#857)

... ... @@ -14,6 +14,28 @@ echo "PATH: $PATH"
which $EXE
log "------------------------------------------------------------"
log "Run streaming NeMo CTC "
log "------------------------------------------------------------"
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2
name=$(basename $url)
repo=$(basename -s .tar.bz2 $name)
curl -SL -O $url
tar xvf $name
rm $name
ls -lh $repo
$EXE \
--nemo-ctc-model=$repo/model.onnx \
--tokens=$repo/tokens.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav
rm -rf $repo
log "------------------------------------------------------------"
log "Run streaming Zipformer2 CTC HLG decoding "
log "------------------------------------------------------------"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
... ...
... ... @@ -8,6 +8,19 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "test online NeMo CTC"
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2
name=$(basename $url)
repo=$(basename -s .tar.bz2 $name)
curl -SL -O $url
tar xvf $name
rm $name
ls -lh $repo
python3 ./python-api-examples/online-nemo-ctc-decode-files.py
rm -rf $repo
log "test offline punctuation"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
... ...
... ... @@ -128,6 +128,14 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-ctc.sh
- name: Test offline transducer
shell: bash
run: |
... ... @@ -163,14 +171,6 @@ jobs:
.github/scripts/test-offline-ctc.sh
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-ctc.sh
- name: Test offline punctuation
shell: bash
run: |
... ...
#!/usr/bin/env python3
"""
This file shows how to use a streaming CTC model from NeMo
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
The example model is converted from
https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_80ms
"""
from pathlib import Path
import numpy as np
import sherpa_onnx
import soundfile as sf
def create_recognizer():
model = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/model.onnx"
tokens = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt"
test_wav = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/test_wavs/0.wav"
if not Path(model).is_file() or not Path(test_wav).is_file():
raise ValueError(
"""Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
)
return (
sherpa_onnx.OnlineRecognizer.from_nemo_ctc(
model=model,
tokens=tokens,
debug=True,
),
test_wav,
)
def main():
recognizer, wave_filename = create_recognizer()
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
audio = audio[:, 0] # only use the first channel
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
# sample_rate does not need to be 16000 Hz
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
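# Pad roughly 0.66 s of silence so the final frames can flush through the
# model's right context; the exact amount is a heuristic, not a requirement.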
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
stream.accept_waveform(sample_rate, tail_paddings)
stream.input_finished()
while recognizer.is_ready(stream):
recognizer.decode_stream(stream)
print(wave_filename)
print(recognizer.get_result_all(stream))
if __name__ == "__main__":
main()
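The example above feeds the whole file in one call; since the model is streaming, the same API also supports incremental feeding. A minimal sketch, reusing the variables from main() and assuming audio arrives in fixed 100 ms chunks (the chunk size is arbitrary):
chunk = int(0.1 * sample_rate)  # 100 ms of samples
for start in range(0, len(audio), chunk):
    stream.accept_waveform(sample_rate, audio[start : start + chunk])
    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)
    print(recognizer.get_result(stream))  # partial result so far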
... ...
... ... @@ -100,7 +100,7 @@ class OnnxModel:
dtype=torch.float32,
).numpy()
self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy()
self.cache_last_channel_len = torch.zeros([1], dtype=torch.int64).numpy()
def __call__(self, x: np.ndarray):
# x: (T, C)
... ...
... ... @@ -142,7 +142,7 @@ class OnnxModel:
dtype=torch.float32,
).numpy()
self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy()
self.cache_last_channel_len = torch.zeros([1], dtype=torch.int64).numpy()
def run_encoder(self, x: np.ndarray):
# x: (T, C)
... ...
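The two hunks above both correct the initial cache_last_channel_len from ones to zeros: when a stream starts, no frames have been cached yet, so the cached-length state must begin at zero. A minimal standalone sketch of the corrected initialization (the [1] shape matches batch size 1, as in the export scripts):
import torch

# a stream starts with an empty cache, so its cached length is 0, not 1
cache_last_channel_len = torch.zeros([1], dtype=torch.int64).numpy()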
... ... @@ -61,6 +61,8 @@ set(sources
online-lm.cc
online-lstm-transducer-model.cc
online-model-config.cc
online-nemo-ctc-model-config.cc
online-nemo-ctc-model.cc
online-paraformer-model-config.cc
online-paraformer-model.cc
online-recognizer-impl.cc
... ...
... ... @@ -4,11 +4,12 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
#include <math.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <math.h>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
... ... @@ -61,7 +62,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
int32_t segment_size = 20;
int32_t max_len = 200;
int32_t num_segments = ceil(((float)token_ids.size() + segment_size - 1) / segment_size);
int32_t num_segments =
ceil((static_cast<float>(token_ids.size()) + segment_size - 1) /
segment_size);
std::vector<int32_t> punctuations;
int32_t last = -1;
... ...
... ... @@ -10,6 +10,7 @@
#include <string>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-nemo-ctc-model.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
... ... @@ -22,6 +23,8 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
return std::make_unique<OnlineWenetCtcModel>(config);
} else if (!config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineZipformer2CtcModel>(config);
} else if (!config.nemo_ctc.model.empty()) {
return std::make_unique<OnlineNeMoCtcModel>(config);
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
... ... @@ -36,6 +39,8 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
return std::make_unique<OnlineWenetCtcModel>(mgr, config);
} else if (!config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineZipformer2CtcModel>(mgr, config);
} else if (!config.nemo_ctc.model.empty()) {
return std::make_unique<OnlineNeMoCtcModel>(mgr, config);
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
... ...
... ... @@ -15,6 +15,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
paraformer.Register(po);
wenet_ctc.Register(po);
zipformer2_ctc.Register(po);
nemo_ctc.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt");
... ... @@ -31,11 +32,11 @@ void OnlineModelConfig::Register(ParseOptions *po) {
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
po->Register(
"model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: conformer, lstm, zipformer, zipformer2, wenet_ctc"
"All other values lead to loading the model twice.");
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: conformer, lstm, zipformer, zipformer2, "
"wenet_ctc, nemo_ctc. "
"All other values lead to loading the model twice.");
}
bool OnlineModelConfig::Validate() const {
... ... @@ -61,6 +62,10 @@ bool OnlineModelConfig::Validate() const {
return zipformer2_ctc.Validate();
}
if (!nemo_ctc.model.empty()) {
return nemo_ctc.Validate();
}
return transducer.Validate();
}
... ... @@ -72,6 +77,7 @@ std::string OnlineModelConfig::ToString() const {
os << "paraformer=" << paraformer.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "warm_up=" << warm_up << ", ";
... ...
... ... @@ -6,6 +6,7 @@
#include <string>
#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
... ... @@ -18,6 +19,7 @@ struct OnlineModelConfig {
OnlineParaformerModelConfig paraformer;
OnlineWenetCtcModelConfig wenet_ctc;
OnlineZipformer2CtcModelConfig zipformer2_ctc;
OnlineNeMoCtcModelConfig nemo_ctc;
std::string tokens;
int32_t num_threads = 1;
int32_t warm_up = 0;
... ... @@ -30,6 +32,7 @@ struct OnlineModelConfig {
// - zipformer, zipformer transducer from icefall
// - zipformer2, zipformer2 transducer or CTC from icefall
// - wenet_ctc, wenet CTC model
// - nemo_ctc, NeMo CTC model
//
// All other values are invalid and lead to loading the model twice.
std::string model_type;
... ... @@ -39,6 +42,7 @@ struct OnlineModelConfig {
const OnlineParaformerModelConfig &paraformer,
const OnlineWenetCtcModelConfig &wenet_ctc,
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
const OnlineNeMoCtcModelConfig &nemo_ctc,
const std::string &tokens, int32_t num_threads,
int32_t warm_up, bool debug, const std::string &provider,
const std::string &model_type)
... ... @@ -46,6 +50,7 @@ struct OnlineModelConfig {
paraformer(paraformer),
wenet_ctc(wenet_ctc),
zipformer2_ctc(zipformer2_ctc),
nemo_ctc(nemo_ctc),
tokens(tokens),
num_threads(num_threads),
warm_up(warm_up),
... ...
// sherpa-onnx/csrc/online-nemo-ctc-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OnlineNeMoCtcModelConfig::Register(ParseOptions *po) {
po->Register("nemo-ctc-model", &model,
"Path to CTC model.onnx from NeMo. Please see "
"https://github.com/k2-fsa/sherpa-onnx/pull/843");
}
bool OnlineNeMoCtcModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("NeMo CTC model '%s' does not exist", model.c_str());
return false;
}
return true;
}
std::string OnlineNeMoCtcModelConfig::ToString() const {
std::ostringstream os;
os << "OnlineNeMoCtcModelConfig(";
os << "model=\"" << model << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-nemo-ctc-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OnlineNeMoCtcModelConfig {
std::string model;
OnlineNeMoCtcModelConfig() = default;
explicit OnlineNeMoCtcModelConfig(const std::string &model) : model(model) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/online-nemo-ctc-model.cc
//
// Copyright (c)  2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-nemo-ctc-model.h"
#include <algorithm>
#include <cmath>
#include <string>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
#include "sherpa-onnx/csrc/unbind.h"
namespace sherpa_onnx {
class OnlineNeMoCtcModel::Impl {
public:
explicit Impl(const OnlineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.nemo_ctc.model);
Init(buf.data(), buf.size());
}
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OnlineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_WARNING),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.nemo_ctc.model);
Init(buf.data(), buf.size());
}
}
#endif
std::vector<Ort::Value> Forward(Ort::Value x,
std::vector<Ort::Value> states) {
Ort::Value &cache_last_channel = states[0];
Ort::Value &cache_last_time = states[1];
Ort::Value &cache_last_channel_len = states[2];
int32_t batch_size = x.GetTensorTypeAndShapeInfo().GetShape()[0];
std::array<int64_t, 1> length_shape{batch_size};
Ort::Value length = Ort::Value::CreateTensor<int64_t>(
allocator_, length_shape.data(), length_shape.size());
int64_t *p_length = length.GetTensorMutableData<int64_t>();
std::fill(p_length, p_length + batch_size, ChunkLength());
// (B, T, C) -> (B, C, T)
x = Transpose12(allocator_, &x);
std::array<Ort::Value, 5> inputs = {
std::move(x), View(&length), std::move(cache_last_channel),
std::move(cache_last_time), std::move(cache_last_channel_len)};
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
// out[0]: logit
// out[1]: logit_length
// out[2:] states_next
//
// we need to remove out[1]
std::vector<Ort::Value> ans;
ans.reserve(out.size() - 1);
for (int32_t i = 0; i != out.size(); ++i) {
if (i == 1) {
continue;
}
ans.push_back(std::move(out[i]));
}
return ans;
}
int32_t VocabSize() const { return vocab_size_; }
int32_t ChunkLength() const { return window_size_; }
int32_t ChunkShift() const { return chunk_shift_; }
OrtAllocator *Allocator() const { return allocator_; }
// Return a vector containing 3 tensors
// - cache_last_channel
// - cache_last_time
// - cache_last_channel_len
std::vector<Ort::Value> GetInitStates() {
std::vector<Ort::Value> ans;
ans.reserve(3);
ans.push_back(View(&cache_last_channel_));
ans.push_back(View(&cache_last_time_));
ans.push_back(View(&cache_last_channel_len_));
return ans;
}
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const {
int32_t batch_size = static_cast<int32_t>(states.size());
if (batch_size == 1) {
return std::move(states[0]);
}
std::vector<Ort::Value> ans;
std::vector<const Ort::Value *> buf(batch_size);
// stack each of the 3 states (cache_last_channel, cache_last_time,
// cache_last_channel_len) across the batch
for (int32_t i = 0; i != 3; ++i) {
buf.clear();
buf.reserve(batch_size);
for (int32_t b = 0; b != batch_size; ++b) {
assert(states[b].size() == 3);
buf.push_back(&states[b][i]);
}
Ort::Value c{nullptr};
if (i == 2) {
c = Cat<int64_t>(allocator_, buf, 0);
} else {
c = Cat(allocator_, buf, 0);
}
ans.push_back(std::move(c));
}
return ans;
}
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const {
assert(states.size() == 3);
std::vector<std::vector<Ort::Value>> ans;
auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
int32_t batch_size = shape[0];
ans.resize(batch_size);
if (batch_size == 1) {
ans[0] = std::move(states);
return ans;
}
for (int32_t i = 0; i != 3; ++i) {
std::vector<Ort::Value> v;
if (i == 2) {
v = Unbind<int64_t>(allocator_, &states[i], 0);
} else {
v = Unbind(allocator_, &states[i], 0);
}
assert(v.size() == batch_size);
for (int32_t b = 0; b != batch_size; ++b) {
ans[b].push_back(std::move(v[b]));
}
}
return ans;
}
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(window_size_, "window_size");
SHERPA_ONNX_READ_META_DATA(chunk_shift_, "chunk_shift");
SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim1_,
"cache_last_channel_dim1");
SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim2_,
"cache_last_channel_dim2");
SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim3_,
"cache_last_channel_dim3");
SHERPA_ONNX_READ_META_DATA(cache_last_time_dim1_, "cache_last_time_dim1");
SHERPA_ONNX_READ_META_DATA(cache_last_time_dim2_, "cache_last_time_dim2");
SHERPA_ONNX_READ_META_DATA(cache_last_time_dim3_, "cache_last_time_dim3");
// need to increase by 1 since the blank token is not included in computing
// vocab_size in NeMo.
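// For example, if the metadata reports vocab_size = 1024, the logits
// actually have 1025 columns, the extra one being the CTC blank.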
vocab_size_ += 1;
InitStates();
}
void InitStates() {
std::array<int64_t, 4> cache_last_channel_shape{1, cache_last_channel_dim1_,
cache_last_channel_dim2_,
cache_last_channel_dim3_};
cache_last_channel_ = Ort::Value::CreateTensor<float>(
allocator_, cache_last_channel_shape.data(),
cache_last_channel_shape.size());
Fill<float>(&cache_last_channel_, 0);
std::array<int64_t, 4> cache_last_time_shape{
1, cache_last_time_dim1_, cache_last_time_dim2_, cache_last_time_dim3_};
cache_last_time_ = Ort::Value::CreateTensor<float>(
allocator_, cache_last_time_shape.data(), cache_last_time_shape.size());
Fill<float>(&cache_last_time_, 0);
int64_t shape = 1;
cache_last_channel_len_ =
Ort::Value::CreateTensor<int64_t>(allocator_, &shape, 1);
cache_last_channel_len_.GetTensorMutableData<int64_t>()[0] = 0;
}
private:
OnlineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
int32_t window_size_;
int32_t chunk_shift_;
int32_t subsampling_factor_;
int32_t vocab_size_;
int32_t cache_last_channel_dim1_;
int32_t cache_last_channel_dim2_;
int32_t cache_last_channel_dim3_;
int32_t cache_last_time_dim1_;
int32_t cache_last_time_dim2_;
int32_t cache_last_time_dim3_;
Ort::Value cache_last_channel_{nullptr};
Ort::Value cache_last_time_{nullptr};
Ort::Value cache_last_channel_len_{nullptr};
};
OnlineNeMoCtcModel::OnlineNeMoCtcModel(const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OnlineNeMoCtcModel::OnlineNeMoCtcModel(AAssetManager *mgr,
const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OnlineNeMoCtcModel::~OnlineNeMoCtcModel() = default;
std::vector<Ort::Value> OnlineNeMoCtcModel::Forward(
Ort::Value x, std::vector<Ort::Value> states) const {
return impl_->Forward(std::move(x), std::move(states));
}
int32_t OnlineNeMoCtcModel::VocabSize() const { return impl_->VocabSize(); }
int32_t OnlineNeMoCtcModel::ChunkLength() const { return impl_->ChunkLength(); }
int32_t OnlineNeMoCtcModel::ChunkShift() const { return impl_->ChunkShift(); }
OrtAllocator *OnlineNeMoCtcModel::Allocator() const {
return impl_->Allocator();
}
std::vector<Ort::Value> OnlineNeMoCtcModel::GetInitStates() const {
return impl_->GetInitStates();
}
std::vector<Ort::Value> OnlineNeMoCtcModel::StackStates(
std::vector<std::vector<Ort::Value>> states) const {
return impl_->StackStates(std::move(states));
}
std::vector<std::vector<Ort::Value>> OnlineNeMoCtcModel::UnStackStates(
std::vector<Ort::Value> states) const {
return impl_->UnStackStates(std::move(states));
}
} // namespace sherpa_onnx
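StackStates concatenates each of the three per-stream caches along the batch axis, and UnStackStates slices them back apart, with the int64 cache_last_channel_len handled by the typed Cat/Unbind overloads. A numpy sketch of the same semantics, with purely illustrative shapes:
import numpy as np

# two streams, each holding 3 cached tensors (shapes are illustrative)
s0 = [np.zeros((1, 2, 3), np.float32), np.zeros((1, 4, 5), np.float32),
      np.zeros((1,), np.int64)]
s1 = [np.ones((1, 2, 3), np.float32), np.ones((1, 4, 5), np.float32),
      np.ones((1,), np.int64)]

# StackStates: batch axis 0 grows from 1 to 2 for every state
stacked = [np.concatenate([a, b], axis=0) for a, b in zip(s0, s1)]

# UnStackStates: slice each state back into per-stream views
unstacked = [[t[i : i + 1] for t in stacked] for i in range(2)]
assert unstacked[0][0].shape == s0[0].shape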
... ...
// sherpa-onnx/csrc/online-nemo-ctc-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_
#include <memory>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-ctc-model.h"
#include "sherpa-onnx/csrc/online-model-config.h"
namespace sherpa_onnx {
class OnlineNeMoCtcModel : public OnlineCtcModel {
public:
explicit OnlineNeMoCtcModel(const OnlineModelConfig &config);
#if __ANDROID_API__ >= 9
OnlineNeMoCtcModel(AAssetManager *mgr, const OnlineModelConfig &config);
#endif
~OnlineNeMoCtcModel() override;
// A list of 3 tensors:
// - cache_last_channel
// - cache_last_time
// - cache_last_channel_len
std::vector<Ort::Value> GetInitStates() const override;
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const override;
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const override;
/**
*
* @param x A 3-D tensor of shape (N, T, C).
* @param states It is from GetInitStates() or returned from this method.
*
* @return Return a list of tensors
* - ans[0] contains log_probs, of shape (N, T, C)
* - ans[1:] contains next_states
*/
std::vector<Ort::Value> Forward(
Ort::Value x, std::vector<Ort::Value> states) const override;
/** Return the vocabulary size of the model
*/
int32_t VocabSize() const override;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const override;
// The model takes this number of feature frames (counted before subsampling) as input per chunk
int32_t ChunkLength() const override;
// Similar to frame_shift in feature extractor, after processing
// ChunkLength() frames, we advance by ChunkShift() frames
// before we process the next chunk.
int32_t ChunkShift() const override;
bool SupportBatchProcessing() const override { return true; }
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_
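For reference, the ChunkLength/ChunkShift contract documented above translates into a sliding-window loop over feature frames. A sketch in Python; the numeric values are purely illustrative and are not taken from the 80 ms model:
chunk_length, chunk_shift = 76, 72           # illustrative values only
features = [[0.0] * 80 for _ in range(400)]  # stand-in for 400 feature frames

start = 0
while start + chunk_length <= len(features):
    window = features[start : start + chunk_length]  # ChunkLength() frames in
    # ... run Forward() on this window together with the current states ...
    start += chunk_shift                              # advance by ChunkShift()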
... ...
... ... @@ -21,7 +21,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
}
if (!config.model_config.wenet_ctc.model.empty() ||
!config.model_config.zipformer2_ctc.model.empty()) {
!config.model_config.zipformer2_ctc.model.empty() ||
!config.model_config.nemo_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcImpl>(config);
}
... ... @@ -41,7 +42,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
}
if (!config.model_config.wenet_ctc.model.empty() ||
!config.model_config.zipformer2_ctc.model.empty()) {
!config.model_config.zipformer2_ctc.model.empty() ||
!config.model_config.nemo_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config);
}
... ...
... ... @@ -23,6 +23,7 @@ set(srcs
online-ctc-fst-decoder-config.cc
online-lm-config.cc
online-model-config.cc
online-nemo-ctc-model-config.cc
online-paraformer-model-config.cc
online-recognizer.cc
online-stream.cc
... ...
... ... @@ -9,6 +9,7 @@
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
... ... @@ -21,26 +22,30 @@ void PybindOnlineModelConfig(py::module *m) {
PybindOnlineParaformerModelConfig(m);
PybindOnlineWenetCtcModelConfig(m);
PybindOnlineZipformer2CtcModelConfig(m);
PybindOnlineNeMoCtcModelConfig(m);
using PyClass = OnlineModelConfig;
py::class_<PyClass>(*m, "OnlineModelConfig")
.def(py::init<const OnlineTransducerModelConfig &,
const OnlineParaformerModelConfig &,
const OnlineWenetCtcModelConfig &,
const OnlineZipformer2CtcModelConfig &, const std::string &,
const OnlineZipformer2CtcModelConfig &,
const OnlineNeMoCtcModelConfig &, const std::string &,
int32_t, int32_t, bool, const std::string &,
const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(),
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"),
py::arg("num_threads"), py::arg("warm_up") = 0,
py::arg("debug") = false, py::arg("provider") = "cpu",
py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
... ...
// sherpa-onnx/python/csrc/online-nemo-ctc-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
namespace sherpa_onnx {
void PybindOnlineNeMoCtcModelConfig(py::module *m) {
using PyClass = OnlineNeMoCtcModelConfig;
py::class_<PyClass>(*m, "OnlineNeMoCtcModelConfig")
.def(py::init<const std::string &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
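The binding above makes the config constructible from Python with a model keyword. A minimal sketch of direct use (most users will go through OnlineRecognizer.from_nemo_ctc instead; the top-level import path is an assumption, since the wrapper itself imports the class from _sherpa_onnx):
import sherpa_onnx  # assumed re-export at the package top level

cfg = sherpa_onnx.OnlineNeMoCtcModelConfig(model="./model.onnx")
print(cfg)  # OnlineNeMoCtcModelConfig(model="./model.onnx")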
... ...
// sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOnlineNeMoCtcModelConfig(py::module *m);
}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
... ...
... ... @@ -42,6 +42,8 @@ static void PybindOnlineRecognizerResult(py::module *m) {
"segment", [](PyClass &self) -> int32_t { return self.segment; })
.def_property_readonly(
"is_final", [](PyClass &self) -> bool { return self.is_final; })
.def("__str__", &PyClass::AsJsonString,
py::call_guard<py::gil_scoped_release>())
.def("as_json_string", &PyClass::AsJsonString,
py::call_guard<py::gil_scoped_release>());
}
... ... @@ -50,29 +52,17 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
using PyClass = OnlineRecognizerConfig;
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(
py::init<const FeatureExtractorConfig &,
const OnlineModelConfig &,
const OnlineLMConfig &,
const EndpointConfig &,
const OnlineCtcFstDecoderConfig &,
bool,
const std::string &,
int32_t,
const std::string &,
float,
float,
float>(),
py::arg("feat_config"),
py::arg("model_config"),
py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
const OnlineLMConfig &, const EndpointConfig &,
const OnlineCtcFstDecoderConfig &, bool, const std::string &,
int32_t, const std::string &, float, float, float>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(),
py::arg("endpoint_config") = EndpointConfig(),
py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
py::arg("enable_endpoint"),
py::arg("decoding_method"),
py::arg("max_active_paths") = 4,
py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0,
py::arg("blank_penalty") = 0.0,
py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
py::arg("temperature_scale") = 2.0)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
... ...
... ... @@ -12,9 +12,11 @@ from _sherpa_onnx import (
from _sherpa_onnx import OnlineRecognizer as _Recognizer
from _sherpa_onnx import (
OnlineRecognizerConfig,
OnlineRecognizerResult,
OnlineStream,
OnlineTransducerModelConfig,
OnlineWenetCtcModelConfig,
OnlineNeMoCtcModelConfig,
OnlineZipformer2CtcModelConfig,
OnlineCtcFstDecoderConfig,
)
... ... @@ -59,6 +61,7 @@ class OnlineRecognizer(object):
lm: str = "",
lm_scale: float = 0.1,
temperature_scale: float = 2.0,
debug: bool = False,
):
"""
Please refer to
... ... @@ -154,6 +157,7 @@ class OnlineRecognizer(object):
num_threads=num_threads,
provider=provider,
model_type=model_type,
debug=debug,
)
feat_config = FeatureExtractorConfig(
... ... @@ -220,6 +224,7 @@ class OnlineRecognizer(object):
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
provider: str = "cpu",
debug: bool = False,
):
"""
Please refer to
... ... @@ -283,6 +288,7 @@ class OnlineRecognizer(object):
num_threads=num_threads,
provider=provider,
model_type="paraformer",
debug=debug,
)
feat_config = FeatureExtractorConfig(
... ... @@ -324,6 +330,7 @@ class OnlineRecognizer(object):
ctc_graph: str = "",
ctc_max_active: int = 3000,
provider: str = "cpu",
debug: bool = False,
):
"""
Please refer to
... ... @@ -386,6 +393,7 @@ class OnlineRecognizer(object):
tokens=tokens,
num_threads=num_threads,
provider=provider,
debug=debug,
)
feat_config = FeatureExtractorConfig(
... ... @@ -418,6 +426,106 @@ class OnlineRecognizer(object):
return self
@classmethod
def from_nemo_ctc(
cls,
tokens: str,
model: str,
num_threads: int = 2,
sample_rate: float = 16000,
feature_dim: int = 80,
enable_endpoint_detection: bool = False,
rule1_min_trailing_silence: float = 2.4,
rule2_min_trailing_silence: float = 1.2,
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
provider: str = "cpu",
debug: bool = False,
):
"""
Please refer to
`<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
to download pre-trained models.
Args:
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
model:
Path to ``model.onnx``.
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
rule1_min_trailing_silence:
Used only when enable_endpoint_detection is True. If the duration
of trailing silence in seconds is larger than this value, we assume
an endpoint is detected.
rule2_min_trailing_silence:
Used only when enable_endpoint_detection is True. If we have decoded
something that is nonsilence and if the duration of trailing silence
in seconds is larger than this value, we assume an endpoint is
detected.
rule3_min_utterance_length:
Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint
is detected.
decoding_method:
The only valid value is greedy_search.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
debug:
True to print the model's metadata when loading it.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
_assert_file_exists(model)
assert num_threads > 0, num_threads
nemo_ctc_config = OnlineNeMoCtcModelConfig(
model=model,
)
model_config = OnlineModelConfig(
nemo_ctc=nemo_ctc_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
debug=debug,
)
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
endpoint_config = EndpointConfig(
rule1_min_trailing_silence=rule1_min_trailing_silence,
rule2_min_trailing_silence=rule2_min_trailing_silence,
rule3_min_utterance_length=rule3_min_utterance_length,
)
recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
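Taken together, the endpoint rules documented in the docstring above can be exercised like this; a sketch reusing the model directory from the test scripts earlier in this commit:
recognizer = sherpa_onnx.OnlineRecognizer.from_nemo_ctc(
    model="./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/model.onnx",
    tokens="./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt",
    enable_endpoint_detection=True,
    rule1_min_trailing_silence=2.4,   # long silence, nothing decoded yet
    rule2_min_trailing_silence=1.2,   # shorter silence after some text
    rule3_min_utterance_length=20.0,  # cap utterance length at 20 s
)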
@classmethod
def from_wenet_ctc(
cls,
tokens: str,
... ... @@ -433,6 +541,7 @@ class OnlineRecognizer(object):
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
provider: str = "cpu",
debug: bool = False,
):
"""
Please refer to
... ... @@ -497,6 +606,7 @@ class OnlineRecognizer(object):
tokens=tokens,
num_threads=num_threads,
provider=provider,
debug=debug,
)
feat_config = FeatureExtractorConfig(
... ... @@ -537,6 +647,9 @@ class OnlineRecognizer(object):
def is_ready(self, s: OnlineStream) -> bool:
return self.recognizer.is_ready(s)
def get_result_all(self, s: OnlineStream) -> OnlineRecognizerResult:
return self.recognizer.get_result(s)
def get_result(self, s: OnlineStream) -> str:
return self.recognizer.get_result(s).text.strip()
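With the new __str__ binding on OnlineRecognizerResult (see the pybind change above), the object returned by get_result_all can be printed directly to obtain its JSON form. A minimal sketch, assuming s is a stream that has been decoded:
result = recognizer.get_result_all(s)  # OnlineRecognizerResult
print(result)                          # JSON string via AsJsonString
print(recognizer.get_result(s))        # plain text, stripped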
... ...