Fangjun Kuang
Committed by GitHub

Add C++ and Python support for T-one streaming Russian ASR models (#2575)

This PR adds support for T-one streaming Russian ASR models to both the C++ and Python APIs. T-one is a CTC-based Russian speech recognition model with several model-specific characteristics: float16 model states, a 300 ms input frame length, and an 8 kHz sampling rate.

- Added new OnlineToneCtcModel implementation with specialized processing for T-one models
- Integrated T-one support into the existing CTC model pipeline and Python bindings
- Added a Python example and test scripts for the new functionality (a condensed usage sketch follows below)
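A condensed usage sketch of the new Python API, based on the example added in this PR (python-api-examples/online-t-one-ctc-decode-files.py); the paths assume the released sherpa-onnx-streaming-t-one-russian-2025-09-08 model package has been downloaded and unpacked:

import numpy as np
import sherpa_onnx
import soundfile as sf

recognizer = sherpa_onnx.OnlineRecognizer.from_t_one_ctc(
    model="./sherpa-onnx-streaming-t-one-russian-2025-09-08/model.onnx",
    tokens="./sherpa-onnx-streaming-t-one-russian-2025-09-08/tokens.txt",
)

# audio must be a 1-D float32 array in [-1, 1]; the input sample rate
# does not need to be 8000 Hz
audio, sample_rate = sf.read(
    "./sherpa-onnx-streaming-t-one-russian-2025-09-08/0.wav",
    dtype="float32",
    always_2d=True,
)
audio = audio[:, 0]  # use only the first channel

stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
# append some tail padding so the final frames are flushed through the model
stream.accept_waveform(sample_rate, np.zeros(int(0.66 * sample_rate), dtype=np.float32))
stream.input_finished()

while recognizer.is_ready(stream):
    recognizer.decode_stream(stream)

print(recognizer.get_result_all(stream))

The full example in the diff below additionally prepends 0.3 s of left padding and enables debug logging.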
... ... @@ -8,6 +8,16 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "test T-one"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
tar xvf sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
rm sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
python3 ./python-api-examples/online-t-one-ctc-decode-files.py
rm -rf sherpa-onnx-streaming-t-one-russian-2025-09-08
log "test nemo canary"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
tar xvf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
... ...
... ... @@ -149,3 +149,4 @@ kitten-nano-en-v0_1-fp16
*.egg-info
*.jar
vocab.json
*.so
... ...
... ... @@ -2,7 +2,8 @@
// Copyright (c) 2025 Xiaomi Corporation
// To use punctuation model:
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
// wget
// https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
// tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
// rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
... ... @@ -15,14 +16,17 @@ int32_t main() {
using namespace sherpa_onnx::cxx; // NOLINT
OfflinePunctuationConfig punctuation_config;
punctuation_config.model.ct_transformer = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx";
punctuation_config.model.ct_transformer =
"./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/"
"model.onnx";
punctuation_config.model.num_threads = 1;
punctuation_config.model.debug = false;
punctuation_config.model.provider = "cpu";
OfflinePunctuation punct = OfflinePunctuation::Create(punctuation_config);
if (!punct.Get()) {
std::cerr << "Failed to create punctuation model. Please check your config\n";
std::cerr
<< "Failed to create punctuation model. Please check your config\n";
return -1;
}
... ...
#!/usr/bin/env python3
"""
This file shows how to use a streaming CTC model from T-one
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
The example model is converted from
https://github.com/voicekit-team/T-one
using
https://github.com/k2-fsa/sherpa-onnx/tree/master/scripts/t-one
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
tar xvf sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
rm sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
"""
from pathlib import Path
import numpy as np
import sherpa_onnx
import soundfile as sf
def create_recognizer():
model = "./sherpa-onnx-streaming-t-one-russian-2025-09-08/model.onnx"
tokens = "./sherpa-onnx-streaming-t-one-russian-2025-09-08/tokens.txt"
test_wav = "./sherpa-onnx-streaming-t-one-russian-2025-09-08/0.wav"
if not Path(model).is_file() or not Path(test_wav).is_file():
raise ValueError(
"""Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
)
return (
sherpa_onnx.OnlineRecognizer.from_t_one_ctc(
model=model,
tokens=tokens,
debug=True,
),
test_wav,
)
def main():
recognizer, wave_filename = create_recognizer()
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
audio = audio[:, 0] # only use the first channel
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
# sample_rate does not need to be 8000 Hz
stream = recognizer.create_stream()
left_paddings = np.zeros(int(0.3 * sample_rate), dtype=np.float32)
stream.accept_waveform(sample_rate, left_paddings)
stream.accept_waveform(sample_rate, audio)
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
stream.accept_waveform(sample_rate, tail_paddings)
stream.input_finished()
while recognizer.is_ready(stream):
recognizer.decode_stream(stream)
print(wave_filename)
print(recognizer.get_result_all(stream))
if __name__ == "__main__":
main()
... ...
... ... @@ -147,14 +147,13 @@ def main():
sample_rate = model.sample_rate
# Pad 0.3 seconds (2400 samples at 8 kHz) at both the beginning and the end
samples = np.pad(samples, (0, 4000))
samples = np.pad(samples, (2400, 2400))
features = compute_feat(
samples=samples,
sample_rate=sample_rate,
frame_length_ms=model.frame_length_ms,
)
print(features.shape)
id2token = load_tokens(args.tokens)
... ...
... ... @@ -95,6 +95,8 @@ set(sources
online-recognizer.cc
online-rnn-lm.cc
online-stream.cc
online-t-one-ctc-model-config.cc
online-t-one-ctc-model.cc
online-transducer-decoder.cc
online-transducer-greedy-search-decoder.cc
online-transducer-greedy-search-nemo-decoder.cc
... ...
... ... @@ -7,8 +7,10 @@
#include <algorithm>
#include <functional>
#include <numeric>
#include <sstream>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
... ... @@ -27,10 +29,12 @@ static bool Compare(const std::vector<int64_t> &a,
}
static void PrintShape(const std::vector<int64_t> &a) {
std::ostringstream os;
for (auto i : a) {
fprintf(stderr, "%d ", static_cast<int32_t>(i));
os << i << " ";
}
fprintf(stderr, "\n");
os << "\n";
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
template <typename T /*=float*/>
... ... @@ -51,15 +55,15 @@ Ort::Value Cat(OrtAllocator *allocator,
bool ret = Compare(v0_shape, s, dim);
if (!ret) {
fprintf(stderr, "Incorrect shape in Cat !\n");
SHERPA_ONNX_LOGE("Incorrect shape in Cat !\n");
fprintf(stderr, "Shape for tensor 0: ");
SHERPA_ONNX_LOGE("Shape for tensor 0: ");
PrintShape(v0_shape);
fprintf(stderr, "Shape for tensor %d: ", i);
SHERPA_ONNX_LOGE("Shape for tensor %d: ", i);
PrintShape(s);
exit(-1);
SHERPA_ONNX_EXIT(-1);
}
}
... ... @@ -99,8 +103,77 @@ template Ort::Value Cat<float>(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values,
int32_t dim);
template Ort::Value Cat<uint16_t>(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values,
int32_t dim);
template Ort::Value Cat<int64_t>(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values,
int32_t dim);
Ort::Value CatFloat16(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values,
int32_t dim) {
if (values.size() == 1u) {
return Clone(allocator, values[0]);
}
std::vector<int64_t> v0_shape =
values[0]->GetTensorTypeAndShapeInfo().GetShape();
int64_t total_dim = v0_shape[dim];
for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
total_dim += s[dim];
bool ret = Compare(v0_shape, s, dim);
if (!ret) {
SHERPA_ONNX_LOGE("Incorrect shape in Cat !\n");
SHERPA_ONNX_LOGE("Shape for tensor 0: ");
PrintShape(v0_shape);
SHERPA_ONNX_LOGE("Shape for tensor %d: ", i);
PrintShape(s);
SHERPA_ONNX_EXIT(-1);
}
}
std::vector<int64_t> ans_shape;
ans_shape.reserve(v0_shape.size());
ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
ans_shape.push_back(total_dim);
ans_shape.insert(ans_shape.end(), v0_shape.data() + dim + 1,
v0_shape.data() + v0_shape.size());
auto leading_size = static_cast<int32_t>(std::accumulate(
v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));
auto trailing_size = static_cast<int32_t>(
std::accumulate(v0_shape.begin() + dim + 1, v0_shape.end(), 1,
std::multiplies<int64_t>()));
Ort::Value ans =
Ort::Value::CreateTensor(allocator, ans_shape.data(), ans_shape.size(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
using T = uint16_t;
T *dst = ans.GetTensorMutableData<T>();
for (int32_t i = 0; i != leading_size; ++i) {
for (auto value : values) {
auto this_dim = value->GetTensorTypeAndShapeInfo().GetShape()[dim];
const T *src = value->GetTensorData<T>();
src += i * this_dim * trailing_size;
std::copy(src, src + this_dim * trailing_size, dst);
dst += this_dim * trailing_size;
}
}
return ans;
}
} // namespace sherpa_onnx
... ...
... ... @@ -23,6 +23,10 @@ template <typename T = float>
Ort::Value Cat(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values, int32_t dim);
Ort::Value CatFloat16(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values,
int32_t dim);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_CAT_H_
... ...
... ... @@ -62,6 +62,8 @@ class FeatureExtractor::Impl {
InitMfcc();
} else if (config_.is_whisper) {
InitWhisper();
} else if (config_.is_t_one) {
InitRawAudioSamples();
} else {
InitFbank();
}
... ... @@ -135,6 +137,9 @@ class FeatureExtractor::Impl {
} else if (whisper_fbank_) {
whisper_fbank_->InputFinished();
return;
} else if (raw_audio_) {
raw_audio_->InputFinished();
return;
} else if (mfcc_) {
mfcc_->InputFinished();
return;
... ... @@ -149,6 +154,8 @@ class FeatureExtractor::Impl {
return fbank_->NumFramesReady();
} else if (whisper_fbank_) {
return whisper_fbank_->NumFramesReady();
} else if (raw_audio_) {
return raw_audio_->NumFramesReady();
} else if (mfcc_) {
return mfcc_->NumFramesReady();
}
... ... @@ -163,6 +170,8 @@ class FeatureExtractor::Impl {
return fbank_->IsLastFrame(frame);
} else if (whisper_fbank_) {
return whisper_fbank_->IsLastFrame(frame);
} else if (raw_audio_) {
return raw_audio_->IsLastFrame(frame);
} else if (mfcc_) {
return mfcc_->IsLastFrame(frame);
}
... ... @@ -209,6 +218,8 @@ class FeatureExtractor::Impl {
return opts_.mel_opts.num_bins;
} else if (mfcc_) {
return mfcc_opts_.num_ceps;
} else if (raw_audio_) {
return raw_audio_->Dim();
}
SHERPA_ONNX_LOGE("unreachable code");
... ... @@ -225,6 +236,9 @@ class FeatureExtractor::Impl {
} else if (whisper_fbank_) {
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
return;
} else if (raw_audio_) {
raw_audio_->AcceptWaveform(sampling_rate, waveform, n);
return;
} else if (mfcc_) {
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
return;
... ... @@ -239,6 +253,8 @@ class FeatureExtractor::Impl {
return fbank_->GetFrame(frame_index);
} else if (whisper_fbank_) {
return whisper_fbank_->GetFrame(frame_index);
} else if (raw_audio_) {
return raw_audio_->GetFrame(frame_index);
} else if (mfcc_) {
return mfcc_->GetFrame(frame_index);
}
... ... @@ -255,6 +271,9 @@ class FeatureExtractor::Impl {
} else if (whisper_fbank_) {
whisper_fbank_->Pop(discard_num);
return;
} else if (raw_audio_) {
raw_audio_->Pop(discard_num);
return;
} else if (mfcc_) {
mfcc_->Pop(discard_num);
return;
... ... @@ -322,11 +341,21 @@ class FeatureExtractor::Impl {
config_.sampling_rate = opts_.frame_opts.samp_freq;
}
void InitRawAudioSamples() {
opts_raw_audio_.frame_opts.samp_freq = config_.sampling_rate;
opts_raw_audio_.frame_opts.frame_length_ms = config_.frame_length_ms;
opts_raw_audio_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
raw_audio_ = std::make_unique<knf::OnlineRawAudioSamples>(opts_raw_audio_);
}
private:
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineMfcc> mfcc_;
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
std::unique_ptr<knf::OnlineRawAudioSamples> raw_audio_;
knf::FbankOptions opts_;
knf::RawAudioSamplesOptions opts_raw_audio_;
knf::MfccOptions mfcc_opts_;
FeatureExtractorConfig config_;
mutable std::mutex mutex_;
... ...
... ... @@ -81,6 +81,8 @@ struct FeatureExtractorConfig {
bool is_whisper = false;
bool is_t_one = false;
bool round_to_power_of_two = true;
std::string ToString() const;
... ...
... ... @@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/jieba-lexicon.h"
#include <algorithm>
#include <fstream>
#include <regex> // NOLINT
#include <strstream>
... ...
... ... @@ -38,7 +38,8 @@ struct OfflineRecognitionResult {
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;
/// durations[i] contains the duration (in seconds) for tokens[i] (TDT models only)
/// durations[i] contains the duration (in seconds) for tokens[i] (TDT models
/// only)
std::vector<float> durations;
std::vector<int32_t> words;
... ...
... ... @@ -4,6 +4,7 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
... ...
... ... @@ -104,7 +104,8 @@ class OfflineTtsZipvoiceModel::Impl {
int64_t feat_dim = meta_data_.feat_dim;
std::vector<float> x_data(batch_size * num_frames * feat_dim);
std::default_random_engine rng(std::random_device{}());
std::random_device rd;
std::default_random_engine rng(rd());
std::normal_distribution<float> norm(0, 1);
for (auto &v : x_data) v = norm(rng);
std::vector<int64_t> x_shape = {batch_size, num_frames, feat_dim};
... ...
... ... @@ -7,6 +7,7 @@
#include <cmath>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
... ...
... ... @@ -28,6 +28,13 @@ void OnlineCtcGreedySearchDecoder::Decode(
auto &r = (*results)[b];
int32_t prev_id = -1;
if (!r.tokens.empty()) {
if (r.num_trailing_blanks > 0) {
prev_id = blank_id_;
} else {
prev_id = r.tokens.back();
}
}
for (int32_t t = 0; t != num_frames; ++t, p += vocab_size) {
int32_t y = static_cast<int32_t>(std::distance(
... ...
... ... @@ -20,6 +20,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-nemo-ctc-model.h"
#include "sherpa-onnx/csrc/online-t-one-ctc-model.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
... ... @@ -34,9 +35,11 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
return std::make_unique<OnlineZipformer2CtcModel>(config);
} else if (!config.nemo_ctc.model.empty()) {
return std::make_unique<OnlineNeMoCtcModel>(config);
} else if (!config.t_one_ctc.model.empty()) {
return std::make_unique<OnlineToneCtcModel>(config);
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
SHERPA_ONNX_EXIT(-1);
}
}
... ... @@ -49,9 +52,11 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
return std::make_unique<OnlineZipformer2CtcModel>(mgr, config);
} else if (!config.nemo_ctc.model.empty()) {
return std::make_unique<OnlineNeMoCtcModel>(mgr, config);
} else if (!config.t_one_ctc.model.empty()) {
return std::make_unique<OnlineToneCtcModel>(mgr, config);
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
SHERPA_ONNX_EXIT(-1);
}
}
... ...
... ... @@ -17,6 +17,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
wenet_ctc.Register(po);
zipformer2_ctc.Register(po);
nemo_ctc.Register(po);
t_one_ctc.Register(po);
provider_config.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt");
... ... @@ -149,6 +150,10 @@ bool OnlineModelConfig::Validate() const {
return nemo_ctc.Validate();
}
if (!t_one_ctc.model.empty()) {
return t_one_ctc.Validate();
}
if (!provider_config.Validate()) {
return false;
}
... ... @@ -165,6 +170,7 @@ std::string OnlineModelConfig::ToString() const {
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
os << "t_one_ctc=" << t_one_ctc.ToString() << ", ";
os << "provider_config=" << provider_config.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
... ...
... ... @@ -8,6 +8,7 @@
#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/csrc/online-t-one-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
... ... @@ -21,6 +22,7 @@ struct OnlineModelConfig {
OnlineWenetCtcModelConfig wenet_ctc;
OnlineZipformer2CtcModelConfig zipformer2_ctc;
OnlineNeMoCtcModelConfig nemo_ctc;
OnlineToneCtcModelConfig t_one_ctc;
ProviderConfig provider_config;
std::string tokens;
int32_t num_threads = 1;
... ... @@ -56,6 +58,7 @@ struct OnlineModelConfig {
const OnlineWenetCtcModelConfig &wenet_ctc,
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
const OnlineNeMoCtcModelConfig &nemo_ctc,
const OnlineToneCtcModelConfig &t_one_ctc,
const ProviderConfig &provider_config,
const std::string &tokens, int32_t num_threads,
int32_t warm_up, bool debug, const std::string &model_type,
... ... @@ -66,6 +69,7 @@ struct OnlineModelConfig {
wenet_ctc(wenet_ctc),
zipformer2_ctc(zipformer2_ctc),
nemo_ctc(nemo_ctc),
t_one_ctc(t_one_ctc),
provider_config(provider_config),
tokens(tokens),
num_threads(num_threads),
... ...
... ... @@ -6,6 +6,7 @@
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_
#include <algorithm>
#include <cassert>
#include <ios>
#include <memory>
#include <sstream>
... ... @@ -79,24 +80,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
config_(config),
model_(OnlineCtcModel::Create(config.model_config)),
endpoint_(config_.endpoint_config) {
if (!config.model_config.tokens_buf.empty()) {
sym_ = SymbolTable(config.model_config.tokens_buf, false);
} else {
/// assuming tokens_buf and tokens are guaranteed not being both empty
sym_ = SymbolTable(config.model_config.tokens, true);
}
if (!config.model_config.wenet_ctc.model.empty()) {
// WeNet CTC models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}
if (model_->UseWhisperFeature()) {
config_.feat_config.is_whisper = true;
}
InitDecoder();
PostInit();
}
template <typename Manager>
... ... @@ -107,17 +91,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
model_(OnlineCtcModel::Create(mgr, config.model_config)),
sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (!config.model_config.wenet_ctc.model.empty()) {
// WeNet CTC models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}
if (model_->UseWhisperFeature()) {
config_.feat_config.is_whisper = true;
}
InitDecoder();
PostInit();
}
std::unique_ptr<OnlineStream> CreateStream() const override {
... ... @@ -211,6 +185,14 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4;
if (!config_.model_config.t_one_ctc.model.empty()) {
// Each input frame is 300 ms long and produces 10 output frames,
// so frame_shift_ms is 300 / 10 = 30 ms.
frame_shift_ms = 30;
subsampling_factor = 1;
}
auto r =
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
... ... @@ -258,6 +240,33 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
}
private:
void PostInit() {
if (!config_.model_config.tokens_buf.empty()) {
sym_ = SymbolTable(config_.model_config.tokens_buf, false);
} else {
/// assuming tokens_buf and tokens are guaranteed not to both be empty
sym_ = SymbolTable(config_.model_config.tokens, true);
}
if (!config_.model_config.wenet_ctc.model.empty()) {
// WeNet CTC models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}
if (!config_.model_config.t_one_ctc.model.empty()) {
config_.feat_config.is_t_one = true;
config_.feat_config.frame_length_ms = 300;
config_.feat_config.frame_shift_ms = 300;
config_.feat_config.sampling_rate = 8000;
}
if (model_->UseWhisperFeature()) {
config_.feat_config.is_whisper = true;
}
InitDecoder();
}
void InitDecoder() {
if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") &&
!sym_.Contains("<blank>")) {
... ...
... ... @@ -83,12 +83,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
if (!config.model_config.wenet_ctc.model.empty() ||
!config.model_config.zipformer2_ctc.model.empty() ||
!config.model_config.nemo_ctc.model.empty()) {
!config.model_config.nemo_ctc.model.empty() ||
!config.model_config.t_one_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcImpl>(config);
}
SHERPA_ONNX_LOGE("Please specify a model");
exit(-1);
SHERPA_ONNX_EXIT(-1);
}
template <typename Manager>
... ... @@ -142,12 +143,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
if (!config.model_config.wenet_ctc.model.empty() ||
!config.model_config.zipformer2_ctc.model.empty() ||
!config.model_config.nemo_ctc.model.empty()) {
!config.model_config.nemo_ctc.model.empty() ||
!config.model_config.t_one_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config);
}
SHERPA_ONNX_LOGE("Please specify a model");
exit(-1);
SHERPA_ONNX_EXIT(-1);
}
OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config)
... ...
// sherpa-onnx/csrc/online-t-one-ctc-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-t-one-ctc-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OnlineToneCtcModelConfig::Register(ParseOptions *po) {
po->Register("t-one-ctc-model", &model,
"Path to CTC model.onnx from T-one. Please see "
"https://github.com/k2-fsa/sherpa-onnx/pull/2571");
}
bool OnlineToneCtcModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("T-one CTC model '%s' does not exist", model.c_str());
return false;
}
return true;
}
std::string OnlineToneCtcModelConfig::ToString() const {
std::ostringstream os;
os << "OnlineToneCtcModelConfig(";
os << "model=\"" << model << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-t-one-ctc-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OnlineToneCtcModelConfig {
std::string model;
OnlineToneCtcModelConfig() = default;
explicit OnlineToneCtcModelConfig(const std::string &model) : model(model) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/online-t-one-ctc-model.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-t-one-ctc-model.h"
#include <algorithm>
#include <cmath>
#include <string>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/unbind.h"
namespace sherpa_onnx {
class OnlineToneCtcModel::Impl {
public:
explicit Impl(const OnlineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.t_one_ctc.model);
Init(buf.data(), buf.size());
}
}
template <typename Manager>
Impl(Manager *mgr, const OnlineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.t_one_ctc.model);
Init(buf.data(), buf.size());
}
}
std::vector<Ort::Value> Forward(Ort::Value x,
std::vector<Ort::Value> states) {
// shape0 is (batch_size, 1, num_samples)
auto shape0 = x.GetTensorTypeAndShapeInfo().GetShape();
std::array<int64_t, 3> shape = {shape0[0], shape0[2], shape0[1]};
std::vector<int32_t> samples(shape[0] * shape[1] * shape[2]);
const float *px = x.GetTensorData<float>();
for (int32_t i = 0; i < samples.size(); ++i) {
float f = px[i];
f = f > 1 ? 1 : f;
f = f < -1 ? -1 : f;
samples[i] = static_cast<int32_t>(f * 32767);
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
Ort::Value xx =
Ort::Value::CreateTensor(memory_info, samples.data(), samples.size(),
shape.data(), shape.size());
std::array<Ort::Value, 2> inputs = {std::move(xx), std::move(states[0])};
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
// out[0]: log_probs
// out[1]: next_states
return out;
}
int32_t VocabSize() const { return vocab_size_; }
int32_t ChunkLength() const { return 1; }
int32_t ChunkShift() const { return 1; }
OrtAllocator *Allocator() { return allocator_; }
// Return a vector containing 1 tensor
// - state_
std::vector<Ort::Value> GetInitStates() {
std::vector<Ort::Value> ans;
ans.push_back(View(&state_));
return ans;
}
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) {
int32_t batch_size = static_cast<int32_t>(states.size());
if (batch_size == 1) {
return std::move(states[0]);
}
std::vector<Ort::Value> ans;
ans.reserve(1);
std::vector<const Ort::Value *> buf;
buf.reserve(batch_size);
for (int32_t b = 0; b != batch_size; ++b) {
buf.push_back(&states[b][0]);
}
Ort::Value c{nullptr};
c = CatFloat16(allocator_, buf, 0);
ans.push_back(std::move(c));
return ans;
}
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const {
auto allocator = const_cast<Impl *>(this)->allocator_;
std::vector<std::vector<Ort::Value>> ans;
auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
int32_t batch_size = shape[0];
ans.resize(batch_size);
if (batch_size == 1) {
ans[0] = std::move(states);
return ans;
}
std::vector<Ort::Value> v;
v = UnbindFloat16(allocator, &states[0], 0);
for (int32_t b = 0; b != batch_size; ++b) {
ans[b].push_back(std::move(v[b]));
}
return ans;
}
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s", os.str().c_str());
#else
SHERPA_ONNX_LOGE("%s", os.str().c_str());
#endif
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(frame_length_ms_, "frame_length_ms");
SHERPA_ONNX_READ_META_DATA(state_dim_, "state_dim");
SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
InitStates();
vocab_size_ = sess_->GetOutputTypeInfo(0)
.GetTensorTypeAndShapeInfo()
.GetShape()
.back();
}
void InitStates() {
std::array<int64_t, 2> state_shape{1, state_dim_};
state_ = Ort::Value::CreateTensor(allocator_, state_shape.data(),
state_shape.size(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
auto p = state_.GetTensorMutableData<uint16_t>();
std::fill(p, p + state_dim_, 0);
}
private:
OnlineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
// One input frame is 300 ms long.
// Each input frame produces 10 output frames,
// so each output frame corresponds to 30 ms.
int32_t frame_length_ms_ = 0;
int32_t state_dim_ = 0;
int32_t sample_rate_ = 0;
int32_t vocab_size_ = 0;
Ort::Value state_{nullptr};
};
OnlineToneCtcModel::OnlineToneCtcModel(const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
template <typename Manager>
OnlineToneCtcModel::OnlineToneCtcModel(Manager *mgr,
const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
OnlineToneCtcModel::~OnlineToneCtcModel() = default;
std::vector<Ort::Value> OnlineToneCtcModel::Forward(
Ort::Value x, std::vector<Ort::Value> states) const {
return impl_->Forward(std::move(x), std::move(states));
}
int32_t OnlineToneCtcModel::VocabSize() const { return impl_->VocabSize(); }
int32_t OnlineToneCtcModel::ChunkLength() const { return impl_->ChunkLength(); }
int32_t OnlineToneCtcModel::ChunkShift() const { return impl_->ChunkShift(); }
OrtAllocator *OnlineToneCtcModel::Allocator() const {
return impl_->Allocator();
}
std::vector<Ort::Value> OnlineToneCtcModel::GetInitStates() const {
return impl_->GetInitStates();
}
std::vector<Ort::Value> OnlineToneCtcModel::StackStates(
std::vector<std::vector<Ort::Value>> states) const {
return impl_->StackStates(std::move(states));
}
std::vector<std::vector<Ort::Value>> OnlineToneCtcModel::UnStackStates(
std::vector<Ort::Value> states) const {
return impl_->UnStackStates(std::move(states));
}
#if __ANDROID_API__ >= 9
template OnlineToneCtcModel::OnlineToneCtcModel(
AAssetManager *mgr, const OnlineModelConfig &config);
#endif
#if __OHOS__
template OnlineToneCtcModel::OnlineToneCtcModel(
NativeResourceManager *mgr, const OnlineModelConfig &config);
#endif
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-t-one-ctc-model.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_H_
#include <memory>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-ctc-model.h"
#include "sherpa-onnx/csrc/online-model-config.h"
namespace sherpa_onnx {
class OnlineToneCtcModel : public OnlineCtcModel {
public:
explicit OnlineToneCtcModel(const OnlineModelConfig &config);
template <typename Manager>
OnlineToneCtcModel(Manager *mgr, const OnlineModelConfig &config);
~OnlineToneCtcModel() override;
// A list of 1 tensor:
// - (batch_size, state_dim)
std::vector<Ort::Value> GetInitStates() const override;
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const override;
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const override;
/**
*
* @param x A 3-D tensor of shape (batch_size, 1, num_samples).
* @param states It is from GetInitStates() or returned from this method.
*
* @return Return a list of tensors
* - ans[0] contains log_probs, of shape (N, T, C)
* - ans[1:] contains next_states
*/
std::vector<Ort::Value> Forward(
Ort::Value x, std::vector<Ort::Value> states) const override;
/** Return the vocabulary size of the model
*/
int32_t VocabSize() const override;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const override;
// The model accepts this number of frames before subsampling as input
int32_t ChunkLength() const override;
// Similar to frame_shift in feature extractor, after processing
// ChunkLength() frames, we advance by ChunkShift() frames
// before we process the next chunk.
int32_t ChunkShift() const override;
bool SupportBatchProcessing() const override { return true; }
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_H_
... ...
... ... @@ -155,10 +155,30 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
std::copy(start, end, dst);
return ans;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
Ort::Value ans =
Ort::Value::CreateTensor(allocator, shape.data(), shape.size(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
const auto *start = v->GetTensorData<uint16_t>();
const auto *end = start + type_and_shape.GetElementCount();
auto *dst = ans.GetTensorMutableData<uint16_t>();
std::copy(start, end, dst);
return ans;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: {
Ort::Value ans = Ort::Value::CreateTensor<uint16_t>(
allocator, shape.data(), shape.size());
const auto *start = v->GetTensorData<uint16_t>();
const auto *end = start + type_and_shape.GetElementCount();
auto *dst = ans.GetTensorMutableData<uint16_t>();
std::copy(start, end, dst);
return ans;
}
default:
fprintf(stderr, "Unsupported type: %d\n",
static_cast<int32_t>(type_and_shape.GetElementType()));
exit(-1);
SHERPA_ONNX_LOGE("Unsupported type: %d\n",
static_cast<int32_t>(type_and_shape.GetElementType()));
SHERPA_ONNX_EXIT(-1);
// unreachable code
return Ort::Value{nullptr};
}
... ... @@ -183,14 +203,23 @@ Ort::Value View(Ort::Value *v) {
return Ort::Value::CreateTensor(
memory_info, v->GetTensorMutableData<float>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
return Ort::Value::CreateTensor(
memory_info, v->GetTensorMutableData<uint16_t>(),
type_and_shape.GetElementCount() * sizeof(uint16_t), shape.data(),
shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
return Ort::Value::CreateTensor(
memory_info, v->GetTensorMutableData<uint16_t>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
return Ort::Value::CreateTensor(
memory_info, v->GetTensorMutableData<bool>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
default:
fprintf(stderr, "Unsupported type: %d\n",
static_cast<int32_t>(type_and_shape.GetElementType()));
exit(-1);
SHERPA_ONNX_LOGE("Unsupported type: %d\n",
static_cast<int32_t>(type_and_shape.GetElementType()));
SHERPA_ONNX_EXIT(-1);
// unreachable code
return Ort::Value{nullptr};
}
... ...
... ... @@ -11,6 +11,7 @@
#include <locale>
#endif
#include <algorithm>
#include <cassert>
#include <ostream>
#include <string>
... ...
... ... @@ -117,6 +117,11 @@ for a list of pre-trained models to download.
const float duration = samples.size() / static_cast<float>(sampling_rate);
auto s = recognizer.CreateStream();
std::vector<float> left_paddings(static_cast<int>(0.3 * sampling_rate));
s->AcceptWaveform(sampling_rate, left_paddings.data(),
left_paddings.size());
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
... ...
... ... @@ -4,7 +4,7 @@
#include "sherpa-onnx/csrc/text-utils.h"
#include <regex>
#include <regex> // NOLINT
#include <sstream>
#include "gtest/gtest.h"
... ...
... ... @@ -68,4 +68,49 @@ template std::vector<Ort::Value> Unbind<int64_t>(OrtAllocator *allocator,
const Ort::Value *value,
int32_t dim);
std::vector<Ort::Value> UnbindFloat16(OrtAllocator *allocator,
const Ort::Value *value, int32_t dim) {
std::vector<int64_t> shape = value->GetTensorTypeAndShapeInfo().GetShape();
assert(dim >= 0);
assert(dim < static_cast<int32_t>(shape.size()));
int32_t n = static_cast<int32_t>(shape[dim]);
if (n == 1) {
std::vector<Ort::Value> ans;
ans.push_back(Clone(allocator, value));
return ans;
}
std::vector<int64_t> ans_shape = shape;
ans_shape[dim] = 1;  // Unlike torch, we keep this dim and set it to 1
// allocate tensors
std::vector<Ort::Value> ans;
ans.reserve(n);
for (int32_t i = 0; i != n; ++i) {
Ort::Value t =
Ort::Value::CreateTensor(allocator, ans_shape.data(), ans_shape.size(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
ans.push_back(std::move(t));
}
auto leading_size = static_cast<int32_t>(std::accumulate(
shape.begin(), shape.begin() + dim, 1, std::multiplies<int64_t>()));
auto trailing_size = static_cast<int32_t>(std::accumulate(
shape.begin() + dim + 1, shape.end(), 1, std::multiplies<int64_t>()));
using T = uint16_t;
const T *src = value->GetTensorData<T>();
for (int32_t i = 0; i != leading_size; ++i) {
for (int32_t k = 0; k != n; ++k) {
T *dst = ans[k].GetTensorMutableData<T>() + i * trailing_size;
std::copy(src, src + trailing_size, dst);
src += trailing_size;
}
}
return ans;
}
} // namespace sherpa_onnx
... ...
... ... @@ -23,6 +23,9 @@ template <typename T = float>
std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
int32_t dim);
std::vector<Ort::Value> UnbindFloat16(OrtAllocator *allocator,
const Ort::Value *value, int32_t dim);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_UNBIND_H_
... ...
... ... @@ -42,6 +42,7 @@ set(srcs
online-punctuation.cc
online-recognizer.cc
online-stream.cc
online-t-one-ctc-model-config.cc
online-transducer-model-config.cc
online-wenet-ctc-model-config.cc
online-zipformer2-ctc-model-config.cc
... ...
... ... @@ -5,6 +5,7 @@
#include <algorithm>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/python/csrc/offline-tts-model-config.h"
... ...
... ... @@ -12,6 +12,7 @@
#include "sherpa-onnx/csrc/provider-config.h"
#include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/online-t-one-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
... ... @@ -25,6 +26,7 @@ void PybindOnlineModelConfig(py::module *m) {
PybindOnlineWenetCtcModelConfig(m);
PybindOnlineZipformer2CtcModelConfig(m);
PybindOnlineNeMoCtcModelConfig(m);
PybindOnlineToneCtcModelConfig(m);
PybindProviderConfig(m);
using PyClass = OnlineModelConfig;
... ... @@ -34,17 +36,18 @@ void PybindOnlineModelConfig(py::module *m) {
const OnlineWenetCtcModelConfig &,
const OnlineZipformer2CtcModelConfig &,
const OnlineNeMoCtcModelConfig &,
const ProviderConfig &,
const std::string &, int32_t, int32_t,
bool, const std::string &, const std::string &,
const OnlineToneCtcModelConfig &, const ProviderConfig &,
const std::string &, int32_t, int32_t, bool,
const std::string &, const std::string &,
const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(),
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(),
py::arg("provider_config") = ProviderConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
py::arg("t_one_ctc") = OnlineToneCtcModelConfig(),
py::arg("provider_config") = ProviderConfig(), py::arg("tokens"),
py::arg("num_threads"), py::arg("warm_up") = 0,
py::arg("debug") = false, py::arg("model_type") = "",
py::arg("modeling_unit") = "", py::arg("bpe_vocab") = "")
.def_readwrite("transducer", &PyClass::transducer)
... ... @@ -52,6 +55,7 @@ void PybindOnlineModelConfig(py::module *m) {
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
.def_readwrite("t_one_ctc", &PyClass::t_one_ctc)
.def_readwrite("provider_config", &PyClass::provider_config)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
... ...
// sherpa-onnx/python/csrc/online-t-one-ctc-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/online-t-one-ctc-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/online-t-one-ctc-model-config.h"
namespace sherpa_onnx {
void PybindOnlineToneCtcModelConfig(py::module *m) {
using PyClass = OnlineToneCtcModelConfig;
py::class_<PyClass>(*m, "OnlineToneCtcModelConfig")
.def(py::init<const std::string &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/online-t-one-ctc-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOnlineToneCtcModelConfig(py::module *m);
}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
... ...
... ... @@ -18,6 +18,7 @@ from sherpa_onnx.lib._sherpa_onnx import (
OnlineRecognizerConfig,
OnlineRecognizerResult,
OnlineStream,
OnlineToneCtcModelConfig,
OnlineTransducerModelConfig,
OnlineWenetCtcModelConfig,
OnlineZipformer2CtcModelConfig,
... ... @@ -603,6 +604,132 @@ class OnlineRecognizer(object):
return self
@classmethod
def from_t_one_ctc(
cls,
tokens: str,
model: str,
num_threads: int = 2,
sample_rate: float = 8000,
feature_dim: int = 80,
enable_endpoint_detection: bool = False,
rule1_min_trailing_silence: float = 2.4,
rule2_min_trailing_silence: float = 1.2,
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
provider: str = "cpu",
debug: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
`<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
to download pre-trained models.
Args:
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
model:
Path to ``model.onnx``.
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
rule1_min_trailing_silence:
Used only when enable_endpoint_detection is True. If the duration
of trailing silence in seconds is larger than this value, we assume
an endpoint is detected.
rule2_min_trailing_silence:
Used only when enable_endpoint_detection is True. If we have decoded
something that is nonsilence and if the duration of trailing silence
in seconds is larger than this value, we assume an endpoint is
detected.
rule3_min_utterance_length:
Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint
is detected.
decoding_method:
The only valid value is greedy_search.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
debug:
True to show meta data in the model.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
device:
onnxruntime cuda device index.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
_assert_file_exists(model)
assert num_threads > 0, num_threads
t_one_ctc_config = OnlineToneCtcModelConfig(
model=model,
)
provider_config = ProviderConfig(
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
t_one_ctc=t_one_ctc_config,
tokens=tokens,
num_threads=num_threads,
provider_config=provider_config,
debug=debug,
)
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
endpoint_config = EndpointConfig(
rule1_min_trailing_silence=rule1_min_trailing_silence,
rule2_min_trailing_silence=rule2_min_trailing_silence,
rule3_min_utterance_length=rule3_min_utterance_length,
)
recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
@classmethod
def from_nemo_ctc(
cls,
tokens: str,
... ...