Fangjun Kuang
Committed by GitHub

Add C++ and Python support for T-one streaming Russian ASR models (#2575)

This PR adds support for T-one streaming Russian ASR models to both the C++ and Python APIs. T-one is a CTC-based Russian speech recognition model with several characteristics that need special handling: float16 model states, 300 ms input frames, and an 8 kHz sampling rate. A minimal C++ usage sketch follows the summary below.

- Added a new OnlineToneCtcModel implementation with specialized processing for T-one models
- Integrated T-one support into the existing CTC model pipeline and the Python bindings
- Added a Python example and test scripts for the new functionality
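For reference, here is a minimal C++ sketch (illustrative only, not part of this PR; the paths are placeholders) of how the new model type is selected: once `model_config.t_one_ctc.model` is set, the existing CTC recognizer routes construction to the new `OnlineToneCtcModel`.

```cpp
#include "sherpa-onnx/csrc/online-recognizer.h"

sherpa_onnx::OnlineRecognizer CreateToneRecognizer() {
  sherpa_onnx::OnlineRecognizerConfig config;
  // Placeholder paths; the model archive is downloaded in the test
  // script below.
  config.model_config.t_one_ctc.model =
      "./sherpa-onnx-streaming-t-one-russian-2025-09-08/model.onnx";
  config.model_config.tokens =
      "./sherpa-onnx-streaming-t-one-russian-2025-09-08/tokens.txt";
  return sherpa_onnx::OnlineRecognizer(config);
}
```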
@@ -8,6 +8,16 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +log "test T-one"
  12 +
  13 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
  14 +tar xvf sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
  15 +rm sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
  16 +
  17 +python3 ./python-api-examples/online-t-one-ctc-decode-files.py
  18 +
  19 +rm -rf sherpa-onnx-streaming-t-one-russian-2025-09-08
  20 +
11 log "test nemo canary" 21 log "test nemo canary"
12 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2 22 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
13 tar xvf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2 23 tar xvf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
@@ -149,3 +149,4 @@ kitten-nano-en-v0_1-fp16
149 *.egg-info 149 *.egg-info
150 *.jar 150 *.jar
151 vocab.json 151 vocab.json
  152 +*.so
@@ -2,7 +2,8 @@
2 // Copyright (c) 2025 Xiaomi Corporation 2 // Copyright (c) 2025 Xiaomi Corporation
3 3
4 // To use punctuation model: 4 // To use punctuation model:
5 -// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 5 +// wget
  6 +// https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
6 // tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 7 // tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
7 // rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 8 // rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
8 9
@@ -15,14 +16,17 @@ int32_t main() {
15 using namespace sherpa_onnx::cxx; // NOLINT 16 using namespace sherpa_onnx::cxx; // NOLINT
16 17
17 OfflinePunctuationConfig punctuation_config; 18 OfflinePunctuationConfig punctuation_config;
18 - punctuation_config.model.ct_transformer = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx"; 19 + punctuation_config.model.ct_transformer =
  20 + "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/"
  21 + "model.onnx";
19 punctuation_config.model.num_threads = 1; 22 punctuation_config.model.num_threads = 1;
20 punctuation_config.model.debug = false; 23 punctuation_config.model.debug = false;
21 punctuation_config.model.provider = "cpu"; 24 punctuation_config.model.provider = "cpu";
22 25
23 OfflinePunctuation punct = OfflinePunctuation::Create(punctuation_config); 26 OfflinePunctuation punct = OfflinePunctuation::Create(punctuation_config);
24 if (!punct.Get()) { 27 if (!punct.Get()) {
25 - std::cerr << "Failed to create punctuation model. Please check your config\n"; 28 + std::cerr
  29 + << "Failed to create punctuation model. Please check your config\n";
26 return -1; 30 return -1;
27 } 31 }
28 32
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This file shows how to use a streaming CTC model from T-one
  5 +to decode files.
  6 +
  7 +Please download model files from
  8 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
  9 +
  10 +
  11 +The example model is converted from
  12 +https://github.com/voicekit-team/T-one
  13 +using
  14 +https://github.com/k2-fsa/sherpa-onnx/tree/master/scripts/t-one
  15 +
  16 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
  17 +tar xvf sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
  18 +rm sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
  19 +"""
  20 +
  21 +from pathlib import Path
  22 +
  23 +import numpy as np
  24 +import sherpa_onnx
  25 +import soundfile as sf
  26 +
  27 +
  28 +def create_recognizer():
  29 + model = "./sherpa-onnx-streaming-t-one-russian-2025-09-08/model.onnx"
  30 + tokens = "./sherpa-onnx-streaming-t-one-russian-2025-09-08/tokens.txt"
  31 + test_wav = "./sherpa-onnx-streaming-t-one-russian-2025-09-08/0.wav"
  32 +
  33 + if not Path(model).is_file() or not Path(test_wav).is_file():
  34 + raise ValueError(
  35 + """Please download model files from
  36 + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
  37 + """
  38 + )
  39 + return (
  40 + sherpa_onnx.OnlineRecognizer.from_t_one_ctc(
  41 + model=model,
  42 + tokens=tokens,
  43 + debug=True,
  44 + ),
  45 + test_wav,
  46 + )
  47 +
  48 +
  49 +def main():
  50 + recognizer, wave_filename = create_recognizer()
  51 +
  52 + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
  53 + audio = audio[:, 0] # only use the first channel
  54 +
  55 + # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
  56 + # sample_rate does not need to be 8000 Hz
  57 +
  58 + stream = recognizer.create_stream()
  59 + left_paddings = np.zeros(int(0.3 * sample_rate), dtype=np.float32)
  60 + stream.accept_waveform(sample_rate, left_paddings)
  61 +
  62 + stream.accept_waveform(sample_rate, audio)
  63 +
  64 + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
  65 + stream.accept_waveform(sample_rate, tail_paddings)
  66 + stream.input_finished()
  67 +
  68 + while recognizer.is_ready(stream):
  69 + recognizer.decode_stream(stream)
  70 + print(wave_filename)
  71 + print(recognizer.get_result_all(stream))
  72 +
  73 +
  74 +if __name__ == "__main__":
  75 + main()
@@ -147,14 +147,13 @@ def main():
147 sample_rate = model.sample_rate 147 sample_rate = model.sample_rate
148 148
149 # Pad 0.5 seconds 149 # Pad 0.5 seconds
150 - samples = np.pad(samples, (0, 4000)) 150 + samples = np.pad(samples, (2400, 2400))
151 151
152 features = compute_feat( 152 features = compute_feat(
153 samples=samples, 153 samples=samples,
154 sample_rate=sample_rate, 154 sample_rate=sample_rate,
155 frame_length_ms=model.frame_length_ms, 155 frame_length_ms=model.frame_length_ms,
156 ) 156 )
157 - print(features.shape)  
158 157
159 id2token = load_tokens(args.tokens) 158 id2token = load_tokens(args.tokens)
160 159
@@ -95,6 +95,8 @@ set(sources
95 online-recognizer.cc 95 online-recognizer.cc
96 online-rnn-lm.cc 96 online-rnn-lm.cc
97 online-stream.cc 97 online-stream.cc
  98 + online-t-one-ctc-model-config.cc
  99 + online-t-one-ctc-model.cc
98 online-transducer-decoder.cc 100 online-transducer-decoder.cc
99 online-transducer-greedy-search-decoder.cc 101 online-transducer-greedy-search-decoder.cc
100 online-transducer-greedy-search-nemo-decoder.cc 102 online-transducer-greedy-search-nemo-decoder.cc
@@ -7,8 +7,10 @@
7 #include <algorithm> 7 #include <algorithm>
8 #include <functional> 8 #include <functional>
9 #include <numeric> 9 #include <numeric>
  10 +#include <sstream>
10 #include <utility> 11 #include <utility>
11 12
  13 +#include "sherpa-onnx/csrc/macros.h"
12 #include "sherpa-onnx/csrc/onnx-utils.h" 14 #include "sherpa-onnx/csrc/onnx-utils.h"
13 15
14 namespace sherpa_onnx { 16 namespace sherpa_onnx {
@@ -27,10 +29,12 @@ static bool Compare(const std::vector<int64_t> &a,
27 } 29 }
28 30
29 static void PrintShape(const std::vector<int64_t> &a) { 31 static void PrintShape(const std::vector<int64_t> &a) {
  32 + std::ostringstream os;
30 for (auto i : a) { 33 for (auto i : a) {
31 - fprintf(stderr, "%d ", static_cast<int32_t>(i)); 34 + os << i << " ";
32 } 35 }
33 - fprintf(stderr, "\n"); 36 + os << "\n";
  37 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
34 } 38 }
35 39
36 template <typename T /*=float*/> 40 template <typename T /*=float*/>
@@ -51,15 +55,15 @@ Ort::Value Cat(OrtAllocator *allocator,
51 55
52 bool ret = Compare(v0_shape, s, dim); 56 bool ret = Compare(v0_shape, s, dim);
53 if (!ret) { 57 if (!ret) {
54 - fprintf(stderr, "Incorrect shape in Cat !\n"); 58 + SHERPA_ONNX_LOGE("Incorrect shape in Cat !\n");
55 59
56 - fprintf(stderr, "Shape for tensor 0: "); 60 + SHERPA_ONNX_LOGE("Shape for tensor 0: ");
57 PrintShape(v0_shape); 61 PrintShape(v0_shape);
58 62
59 - fprintf(stderr, "Shape for tensor %d: ", i); 63 + SHERPA_ONNX_LOGE("Shape for tensor %d: ", i);
60 PrintShape(s); 64 PrintShape(s);
61 65
62 - exit(-1); 66 + SHERPA_ONNX_EXIT(-1);
63 } 67 }
64 } 68 }
65 69
@@ -99,8 +103,77 @@ template Ort::Value Cat<float>(OrtAllocator *allocator,
99 const std::vector<const Ort::Value *> &values, 103 const std::vector<const Ort::Value *> &values,
100 int32_t dim); 104 int32_t dim);
101 105
  106 +template Ort::Value Cat<uint16_t>(OrtAllocator *allocator,
  107 + const std::vector<const Ort::Value *> &values,
  108 + int32_t dim);
  109 +
102 template Ort::Value Cat<int64_t>(OrtAllocator *allocator, 110 template Ort::Value Cat<int64_t>(OrtAllocator *allocator,
103 const std::vector<const Ort::Value *> &values, 111 const std::vector<const Ort::Value *> &values,
104 int32_t dim); 112 int32_t dim);
105 113
  114 +Ort::Value CatFloat16(OrtAllocator *allocator,
  115 + const std::vector<const Ort::Value *> &values,
  116 + int32_t dim) {
  117 + if (values.size() == 1u) {
  118 + return Clone(allocator, values[0]);
  119 + }
  120 +
  121 + std::vector<int64_t> v0_shape =
  122 + values[0]->GetTensorTypeAndShapeInfo().GetShape();
  123 +
  124 + int64_t total_dim = v0_shape[dim];
  125 +
  126 + for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
  127 + auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
  128 + total_dim += s[dim];
  129 +
  130 + bool ret = Compare(v0_shape, s, dim);
  131 + if (!ret) {
  132 + SHERPA_ONNX_LOGE("Incorrect shape in Cat !\n");
  133 +
  134 + SHERPA_ONNX_LOGE("Shape for tensor 0: ");
  135 + PrintShape(v0_shape);
  136 +
  137 + SHERPA_ONNX_LOGE("Shape for tensor %d: ", i);
  138 + PrintShape(s);
  139 +
  140 + SHERPA_ONNX_EXIT(-1);
  141 + }
  142 + }
  143 +
  144 + std::vector<int64_t> ans_shape;
  145 + ans_shape.reserve(v0_shape.size());
  146 + ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
  147 + ans_shape.push_back(total_dim);
  148 + ans_shape.insert(ans_shape.end(), v0_shape.data() + dim + 1,
  149 + v0_shape.data() + v0_shape.size());
  150 +
  151 + auto leading_size = static_cast<int32_t>(std::accumulate(
  152 + v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));
  153 +
  154 + auto trailing_size = static_cast<int32_t>(
  155 + std::accumulate(v0_shape.begin() + dim + 1, v0_shape.end(), 1,
  156 + std::multiplies<int64_t>()));
  157 +
  158 + Ort::Value ans =
  159 + Ort::Value::CreateTensor(allocator, ans_shape.data(), ans_shape.size(),
  160 + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
  161 + using T = uint16_t;
  162 +
  163 + T *dst = ans.GetTensorMutableData<T>();
  164 +
  165 + for (int32_t i = 0; i != leading_size; ++i) {
  166 + for (auto value : values) {
  167 + auto this_dim = value->GetTensorTypeAndShapeInfo().GetShape()[dim];
  168 + const T *src = value->GetTensorData<T>();
  169 + src += i * this_dim * trailing_size;
  170 +
  171 + std::copy(src, src + this_dim * trailing_size, dst);
  172 + dst += this_dim * trailing_size;
  173 + }
  174 + }
  175 +
  176 + return ans;
  177 +}
  178 +
106 } // namespace sherpa_onnx 179 } // namespace sherpa_onnx
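The new `CatFloat16()` above concatenates float16 tensors whose storage is accessed as `uint16_t` (there is no native C++ fp16 arithmetic type in this code path). A hedged usage sketch, with a hypothetical `MakeState` helper: two `(1, 4)` float16 tensors are stacked along dim 0 into a `(2, 4)` tensor.

```cpp
#include <algorithm>
#include <array>
#include <cstdint>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/cat.h"

// Hypothetical helper: a (1, 4) float16 tensor filled with zeros.
Ort::Value MakeState(OrtAllocator *allocator) {
  std::array<int64_t, 2> shape = {1, 4};
  Ort::Value t =
      Ort::Value::CreateTensor(allocator, shape.data(), shape.size(),
                               ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
  uint16_t *p = t.GetTensorMutableData<uint16_t>();
  std::fill(p, p + 4, 0);  // the fp16 bit pattern 0x0000 is +0.0
  return t;
}

Ort::Value StackTwoStates(OrtAllocator *allocator) {
  Ort::Value a = MakeState(allocator);
  Ort::Value b = MakeState(allocator);
  std::vector<const Ort::Value *> buf = {&a, &b};
  return sherpa_onnx::CatFloat16(allocator, buf, /*dim=*/0);  // shape (2, 4)
}
```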
@@ -23,6 +23,10 @@ template <typename T = float>
23 Ort::Value Cat(OrtAllocator *allocator, 23 Ort::Value Cat(OrtAllocator *allocator,
24 const std::vector<const Ort::Value *> &values, int32_t dim); 24 const std::vector<const Ort::Value *> &values, int32_t dim);
25 25
  26 +Ort::Value CatFloat16(OrtAllocator *allocator,
  27 + const std::vector<const Ort::Value *> &values,
  28 + int32_t dim);
  29 +
26 } // namespace sherpa_onnx 30 } // namespace sherpa_onnx
27 31
28 #endif // SHERPA_ONNX_CSRC_CAT_H_ 32 #endif // SHERPA_ONNX_CSRC_CAT_H_
@@ -62,6 +62,8 @@ class FeatureExtractor::Impl {
62 InitMfcc(); 62 InitMfcc();
63 } else if (config_.is_whisper) { 63 } else if (config_.is_whisper) {
64 InitWhisper(); 64 InitWhisper();
  65 + } else if (config_.is_t_one) {
  66 + InitRawAudioSamples();
65 } else { 67 } else {
66 InitFbank(); 68 InitFbank();
67 } 69 }
@@ -135,6 +137,9 @@ class FeatureExtractor::Impl {
135 } else if (whisper_fbank_) { 137 } else if (whisper_fbank_) {
136 whisper_fbank_->InputFinished(); 138 whisper_fbank_->InputFinished();
137 return; 139 return;
  140 + } else if (raw_audio_) {
  141 + raw_audio_->InputFinished();
  142 + return;
138 } else if (mfcc_) { 143 } else if (mfcc_) {
139 mfcc_->InputFinished(); 144 mfcc_->InputFinished();
140 return; 145 return;
@@ -149,6 +154,8 @@ class FeatureExtractor::Impl {
149 return fbank_->NumFramesReady(); 154 return fbank_->NumFramesReady();
150 } else if (whisper_fbank_) { 155 } else if (whisper_fbank_) {
151 return whisper_fbank_->NumFramesReady(); 156 return whisper_fbank_->NumFramesReady();
  157 + } else if (raw_audio_) {
  158 + return raw_audio_->NumFramesReady();
152 } else if (mfcc_) { 159 } else if (mfcc_) {
153 return mfcc_->NumFramesReady(); 160 return mfcc_->NumFramesReady();
154 } 161 }
@@ -163,6 +170,8 @@ class FeatureExtractor::Impl {
163 return fbank_->IsLastFrame(frame); 170 return fbank_->IsLastFrame(frame);
164 } else if (whisper_fbank_) { 171 } else if (whisper_fbank_) {
165 return whisper_fbank_->IsLastFrame(frame); 172 return whisper_fbank_->IsLastFrame(frame);
  173 + } else if (raw_audio_) {
  174 + return raw_audio_->IsLastFrame(frame);
166 } else if (mfcc_) { 175 } else if (mfcc_) {
167 return mfcc_->IsLastFrame(frame); 176 return mfcc_->IsLastFrame(frame);
168 } 177 }
@@ -209,6 +218,8 @@ class FeatureExtractor::Impl {
209 return opts_.mel_opts.num_bins; 218 return opts_.mel_opts.num_bins;
210 } else if (mfcc_) { 219 } else if (mfcc_) {
211 return mfcc_opts_.num_ceps; 220 return mfcc_opts_.num_ceps;
  221 + } else if (raw_audio_) {
  222 + return raw_audio_->Dim();
212 } 223 }
213 224
214 SHERPA_ONNX_LOGE("unreachable code"); 225 SHERPA_ONNX_LOGE("unreachable code");
@@ -225,6 +236,9 @@ class FeatureExtractor::Impl {
225 } else if (whisper_fbank_) { 236 } else if (whisper_fbank_) {
226 whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n); 237 whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
227 return; 238 return;
  239 + } else if (raw_audio_) {
  240 + raw_audio_->AcceptWaveform(sampling_rate, waveform, n);
  241 + return;
228 } else if (mfcc_) { 242 } else if (mfcc_) {
229 mfcc_->AcceptWaveform(sampling_rate, waveform, n); 243 mfcc_->AcceptWaveform(sampling_rate, waveform, n);
230 return; 244 return;
@@ -239,6 +253,8 @@ class FeatureExtractor::Impl {
239 return fbank_->GetFrame(frame_index); 253 return fbank_->GetFrame(frame_index);
240 } else if (whisper_fbank_) { 254 } else if (whisper_fbank_) {
241 return whisper_fbank_->GetFrame(frame_index); 255 return whisper_fbank_->GetFrame(frame_index);
  256 + } else if (raw_audio_) {
  257 + return raw_audio_->GetFrame(frame_index);
242 } else if (mfcc_) { 258 } else if (mfcc_) {
243 return mfcc_->GetFrame(frame_index); 259 return mfcc_->GetFrame(frame_index);
244 } 260 }
@@ -255,6 +271,9 @@ class FeatureExtractor::Impl {
255 } else if (whisper_fbank_) { 271 } else if (whisper_fbank_) {
256 whisper_fbank_->Pop(discard_num); 272 whisper_fbank_->Pop(discard_num);
257 return; 273 return;
  274 + } else if (raw_audio_) {
  275 + raw_audio_->Pop(discard_num);
  276 + return;
258 } else if (mfcc_) { 277 } else if (mfcc_) {
259 mfcc_->Pop(discard_num); 278 mfcc_->Pop(discard_num);
260 return; 279 return;
@@ -322,11 +341,21 @@ class FeatureExtractor::Impl {
322 config_.sampling_rate = opts_.frame_opts.samp_freq; 341 config_.sampling_rate = opts_.frame_opts.samp_freq;
323 } 342 }
324 343
  344 + void InitRawAudioSamples() {
  345 + opts_raw_audio_.frame_opts.samp_freq = config_.sampling_rate;
  346 + opts_raw_audio_.frame_opts.frame_length_ms = config_.frame_length_ms;
  347 + opts_raw_audio_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
  348 +
  349 + raw_audio_ = std::make_unique<knf::OnlineRawAudioSamples>(opts_raw_audio_);
  350 + }
  351 +
325 private: 352 private:
326 std::unique_ptr<knf::OnlineFbank> fbank_; 353 std::unique_ptr<knf::OnlineFbank> fbank_;
327 std::unique_ptr<knf::OnlineMfcc> mfcc_; 354 std::unique_ptr<knf::OnlineMfcc> mfcc_;
328 std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_; 355 std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
  356 + std::unique_ptr<knf::OnlineRawAudioSamples> raw_audio_;
329 knf::FbankOptions opts_; 357 knf::FbankOptions opts_;
  358 + knf::RawAudioSamplesOptions opts_raw_audio_;
330 knf::MfccOptions mfcc_opts_; 359 knf::MfccOptions mfcc_opts_;
331 FeatureExtractorConfig config_; 360 FeatureExtractorConfig config_;
332 mutable std::mutex mutex_; 361 mutable std::mutex mutex_;
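Since T-one uses `frame_length_ms == frame_shift_ms == 300` at 8000 Hz, the raw-audio "feature extractor" is effectively a chunker: each feature frame is the next block of 0.3 * 8000 = 2400 unmodified samples. A standalone sketch of that behavior (illustrative only, not the `knf` API):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: each "feature frame" is a verbatim block of samples.
std::vector<std::vector<float>> RawFrames(const std::vector<float> &samples,
                                          int32_t sample_rate = 8000,
                                          int32_t frame_ms = 300) {
  std::size_t frame_size = sample_rate / 1000 * frame_ms;  // 2400 samples
  std::vector<std::vector<float>> frames;
  for (std::size_t i = 0; i + frame_size <= samples.size(); i += frame_size) {
    frames.emplace_back(samples.begin() + i, samples.begin() + i + frame_size);
  }
  return frames;
}
```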
@@ -81,6 +81,8 @@ struct FeatureExtractorConfig {
81 81
82 bool is_whisper = false; 82 bool is_whisper = false;
83 83
  84 + bool is_t_one = false;
  85 +
84 bool round_to_power_of_two = true; 86 bool round_to_power_of_two = true;
85 87
86 std::string ToString() const; 88 std::string ToString() const;
@@ -4,6 +4,7 @@
4 4
5 #include "sherpa-onnx/csrc/jieba-lexicon.h" 5 #include "sherpa-onnx/csrc/jieba-lexicon.h"
6 6
  7 +#include <algorithm>
7 #include <fstream> 8 #include <fstream>
8 #include <regex> // NOLINT 9 #include <regex> // NOLINT
9 #include <strstream> 10 #include <strstream>
@@ -38,7 +38,8 @@ struct OfflineRecognitionResult {
38 /// timestamps[i] records the time in seconds when tokens[i] is decoded. 38 /// timestamps[i] records the time in seconds when tokens[i] is decoded.
39 std::vector<float> timestamps; 39 std::vector<float> timestamps;
40 40
41 - /// durations[i] contains the duration (in seconds) for tokens[i] (TDT models only) 41 + /// durations[i] contains the duration (in seconds) for tokens[i] (TDT models
  42 + /// only)
42 std::vector<float> durations; 43 std::vector<float> durations;
43 44
44 std::vector<int32_t> words; 45 std::vector<int32_t> words;
@@ -4,6 +4,7 @@
4 #ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_ 4 #ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_
5 #define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_ 5 #define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_
6 6
  7 +#include <algorithm>
7 #include <cmath> 8 #include <cmath>
8 #include <memory> 9 #include <memory>
9 #include <string> 10 #include <string>
@@ -104,7 +104,8 @@ class OfflineTtsZipvoiceModel::Impl {
104 int64_t feat_dim = meta_data_.feat_dim; 104 int64_t feat_dim = meta_data_.feat_dim;
105 105
106 std::vector<float> x_data(batch_size * num_frames * feat_dim); 106 std::vector<float> x_data(batch_size * num_frames * feat_dim);
107 - std::default_random_engine rng(std::random_device{}()); 107 + std::random_device rd;
  108 + std::default_random_engine rng(rd());
108 std::normal_distribution<float> norm(0, 1); 109 std::normal_distribution<float> norm(0, 1);
109 for (auto &v : x_data) v = norm(rng); 110 for (auto &v : x_data) v = norm(rng);
110 std::vector<int64_t> x_shape = {batch_size, num_frames, feat_dim}; 111 std::vector<int64_t> x_shape = {batch_size, num_frames, feat_dim};
@@ -7,6 +7,7 @@
7 #include <cmath> 7 #include <cmath>
8 #include <string> 8 #include <string>
9 #include <utility> 9 #include <utility>
  10 +#include <vector>
10 11
11 #if __ANDROID_API__ >= 9 12 #if __ANDROID_API__ >= 9
12 #include "android/asset_manager.h" 13 #include "android/asset_manager.h"
@@ -28,6 +28,13 @@ void OnlineCtcGreedySearchDecoder::Decode(
28 auto &r = (*results)[b]; 28 auto &r = (*results)[b];
29 29
30 int32_t prev_id = -1; 30 int32_t prev_id = -1;
  31 + if (!r.tokens.empty()) {
  32 + if (r.num_trailing_blanks > 0) {
  33 + prev_id = blank_id_;
  34 + } else {
  35 + prev_id = r.tokens.back();
  36 + }
  37 + }
31 38
32 for (int32_t t = 0; t != num_frames; ++t, p += vocab_size) { 39 for (int32_t t = 0; t != num_frames; ++t, p += vocab_size) {
33 int32_t y = static_cast<int32_t>(std::distance( 40 int32_t y = static_cast<int32_t>(std::distance(
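The `prev_id` seeding above matters because a token can span a chunk boundary in streaming decoding. A standalone sketch of the rule (simplified: here `prev_id` is carried directly across chunks, whereas the real code reconstructs it from `num_trailing_blanks` and `r.tokens.back()` as shown): a repeated argmax id is emitted again only if a blank separates the repeats.

```cpp
#include <cstdint>
#include <vector>

// Greedy CTC collapse over one chunk; *prev_id persists across chunks.
std::vector<int32_t> GreedyChunk(const std::vector<int32_t> &argmax_ids,
                                 int32_t blank_id, int32_t *prev_id) {
  std::vector<int32_t> tokens;
  for (int32_t y : argmax_ids) {
    if (y != blank_id && y != *prev_id) {
      tokens.push_back(y);
    }
    *prev_id = y;
  }
  return tokens;
}

// With blank_id = 0 and prev_id starting at -1: the chunks {5, 5} and
// {5, 0, 5} emit {5} and then {5}. The leading 5 of the second chunk
// continues the previous token; only the 5 after the blank is new.
```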
@@ -20,6 +20,7 @@
20 20
21 #include "sherpa-onnx/csrc/macros.h" 21 #include "sherpa-onnx/csrc/macros.h"
22 #include "sherpa-onnx/csrc/online-nemo-ctc-model.h" 22 #include "sherpa-onnx/csrc/online-nemo-ctc-model.h"
  23 +#include "sherpa-onnx/csrc/online-t-one-ctc-model.h"
23 #include "sherpa-onnx/csrc/online-wenet-ctc-model.h" 24 #include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
24 #include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h" 25 #include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
25 #include "sherpa-onnx/csrc/onnx-utils.h" 26 #include "sherpa-onnx/csrc/onnx-utils.h"
@@ -34,9 +35,11 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
34 return std::make_unique<OnlineZipformer2CtcModel>(config); 35 return std::make_unique<OnlineZipformer2CtcModel>(config);
35 } else if (!config.nemo_ctc.model.empty()) { 36 } else if (!config.nemo_ctc.model.empty()) {
36 return std::make_unique<OnlineNeMoCtcModel>(config); 37 return std::make_unique<OnlineNeMoCtcModel>(config);
  38 + } else if (!config.t_one_ctc.model.empty()) {
  39 + return std::make_unique<OnlineToneCtcModel>(config);
37 } else { 40 } else {
38 SHERPA_ONNX_LOGE("Please specify a CTC model"); 41 SHERPA_ONNX_LOGE("Please specify a CTC model");
39 - exit(-1); 42 + SHERPA_ONNX_EXIT(-1);
40 } 43 }
41 } 44 }
42 45
@@ -49,9 +52,11 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
49 return std::make_unique<OnlineZipformer2CtcModel>(mgr, config); 52 return std::make_unique<OnlineZipformer2CtcModel>(mgr, config);
50 } else if (!config.nemo_ctc.model.empty()) { 53 } else if (!config.nemo_ctc.model.empty()) {
51 return std::make_unique<OnlineNeMoCtcModel>(mgr, config); 54 return std::make_unique<OnlineNeMoCtcModel>(mgr, config);
  55 + } else if (!config.t_one_ctc.model.empty()) {
  56 + return std::make_unique<OnlineToneCtcModel>(mgr, config);
52 } else { 57 } else {
53 SHERPA_ONNX_LOGE("Please specify a CTC model"); 58 SHERPA_ONNX_LOGE("Please specify a CTC model");
54 - exit(-1); 59 + SHERPA_ONNX_EXIT(-1);
55 } 60 }
56 } 61 }
57 62
@@ -17,6 +17,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
17 wenet_ctc.Register(po); 17 wenet_ctc.Register(po);
18 zipformer2_ctc.Register(po); 18 zipformer2_ctc.Register(po);
19 nemo_ctc.Register(po); 19 nemo_ctc.Register(po);
  20 + t_one_ctc.Register(po);
20 provider_config.Register(po); 21 provider_config.Register(po);
21 22
22 po->Register("tokens", &tokens, "Path to tokens.txt"); 23 po->Register("tokens", &tokens, "Path to tokens.txt");
@@ -149,6 +150,10 @@ bool OnlineModelConfig::Validate() const {
149 return nemo_ctc.Validate(); 150 return nemo_ctc.Validate();
150 } 151 }
151 152
  153 + if (!t_one_ctc.model.empty()) {
  154 + return t_one_ctc.Validate();
  155 + }
  156 +
152 if (!provider_config.Validate()) { 157 if (!provider_config.Validate()) {
153 return false; 158 return false;
154 } 159 }
@@ -165,6 +170,7 @@ std::string OnlineModelConfig::ToString() const {
165 os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; 170 os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
166 os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", "; 171 os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
167 os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; 172 os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
  173 + os << "t_one_ctc=" << t_one_ctc.ToString() << ", ";
168 os << "provider_config=" << provider_config.ToString() << ", "; 174 os << "provider_config=" << provider_config.ToString() << ", ";
169 os << "tokens=\"" << tokens << "\", "; 175 os << "tokens=\"" << tokens << "\", ";
170 os << "num_threads=" << num_threads << ", "; 176 os << "num_threads=" << num_threads << ", ";
@@ -8,6 +8,7 @@
8 8
9 #include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h" 9 #include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
10 #include "sherpa-onnx/csrc/online-paraformer-model-config.h" 10 #include "sherpa-onnx/csrc/online-paraformer-model-config.h"
  11 +#include "sherpa-onnx/csrc/online-t-one-ctc-model-config.h"
11 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 12 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
12 #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" 13 #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
13 #include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h" 14 #include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
@@ -21,6 +22,7 @@ struct OnlineModelConfig {
21 OnlineWenetCtcModelConfig wenet_ctc; 22 OnlineWenetCtcModelConfig wenet_ctc;
22 OnlineZipformer2CtcModelConfig zipformer2_ctc; 23 OnlineZipformer2CtcModelConfig zipformer2_ctc;
23 OnlineNeMoCtcModelConfig nemo_ctc; 24 OnlineNeMoCtcModelConfig nemo_ctc;
  25 + OnlineToneCtcModelConfig t_one_ctc;
24 ProviderConfig provider_config; 26 ProviderConfig provider_config;
25 std::string tokens; 27 std::string tokens;
26 int32_t num_threads = 1; 28 int32_t num_threads = 1;
@@ -56,6 +58,7 @@ struct OnlineModelConfig {
56 const OnlineWenetCtcModelConfig &wenet_ctc, 58 const OnlineWenetCtcModelConfig &wenet_ctc,
57 const OnlineZipformer2CtcModelConfig &zipformer2_ctc, 59 const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
58 const OnlineNeMoCtcModelConfig &nemo_ctc, 60 const OnlineNeMoCtcModelConfig &nemo_ctc,
  61 + const OnlineToneCtcModelConfig &t_one_ctc,
59 const ProviderConfig &provider_config, 62 const ProviderConfig &provider_config,
60 const std::string &tokens, int32_t num_threads, 63 const std::string &tokens, int32_t num_threads,
61 int32_t warm_up, bool debug, const std::string &model_type, 64 int32_t warm_up, bool debug, const std::string &model_type,
@@ -66,6 +69,7 @@ struct OnlineModelConfig {
66 wenet_ctc(wenet_ctc), 69 wenet_ctc(wenet_ctc),
67 zipformer2_ctc(zipformer2_ctc), 70 zipformer2_ctc(zipformer2_ctc),
68 nemo_ctc(nemo_ctc), 71 nemo_ctc(nemo_ctc),
  72 + t_one_ctc(t_one_ctc),
69 provider_config(provider_config), 73 provider_config(provider_config),
70 tokens(tokens), 74 tokens(tokens),
71 num_threads(num_threads), 75 num_threads(num_threads),
@@ -6,6 +6,7 @@
6 #define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_ 6 #define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_
7 7
8 #include <algorithm> 8 #include <algorithm>
  9 +#include <cassert>
9 #include <ios> 10 #include <ios>
10 #include <memory> 11 #include <memory>
11 #include <sstream> 12 #include <sstream>
@@ -79,24 +80,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
79 config_(config), 80 config_(config),
80 model_(OnlineCtcModel::Create(config.model_config)), 81 model_(OnlineCtcModel::Create(config.model_config)),
81 endpoint_(config_.endpoint_config) { 82 endpoint_(config_.endpoint_config) {
82 - if (!config.model_config.tokens_buf.empty()) {  
83 - sym_ = SymbolTable(config.model_config.tokens_buf, false);  
84 - } else {  
85 - /// assuming tokens_buf and tokens are guaranteed not being both empty  
86 - sym_ = SymbolTable(config.model_config.tokens, true);  
87 - }  
88 -  
89 - if (!config.model_config.wenet_ctc.model.empty()) {  
90 - // WeNet CTC models assume input samples are in the range  
91 - // [-32768, 32767], so we set normalize_samples to false  
92 - config_.feat_config.normalize_samples = false;  
93 - }  
94 -  
95 - if (model_->UseWhisperFeature()) {  
96 - config_.feat_config.is_whisper = true;  
97 - }  
98 -  
99 - InitDecoder(); 83 + PostInit();
100 } 84 }
101 85
102 template <typename Manager> 86 template <typename Manager>
@@ -107,17 +91,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
107 model_(OnlineCtcModel::Create(mgr, config.model_config)), 91 model_(OnlineCtcModel::Create(mgr, config.model_config)),
108 sym_(mgr, config.model_config.tokens), 92 sym_(mgr, config.model_config.tokens),
109 endpoint_(config_.endpoint_config) { 93 endpoint_(config_.endpoint_config) {
110 - if (!config.model_config.wenet_ctc.model.empty()) {  
111 - // WeNet CTC models assume input samples are in the range  
112 - // [-32768, 32767], so we set normalize_samples to false  
113 - config_.feat_config.normalize_samples = false;  
114 - }  
115 -  
116 - if (model_->UseWhisperFeature()) {  
117 - config_.feat_config.is_whisper = true;  
118 - }  
119 -  
120 - InitDecoder(); 94 + PostInit();
121 } 95 }
122 96
123 std::unique_ptr<OnlineStream> CreateStream() const override { 97 std::unique_ptr<OnlineStream> CreateStream() const override {
@@ -211,6 +185,14 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
211 // TODO(fangjun): Remember to change these constants if needed 185 // TODO(fangjun): Remember to change these constants if needed
212 int32_t frame_shift_ms = 10; 186 int32_t frame_shift_ms = 10;
213 int32_t subsampling_factor = 4; 187 int32_t subsampling_factor = 4;
  188 + if (!config_.model_config.t_one_ctc.model.empty()) {
  188 + // each input frame is 300 ms long and produces 10 output frames,
  189 + // so frame_shift_ms is 300 / 10 = 30 ms
  191 + //
  192 + frame_shift_ms = 30;
  193 + subsampling_factor = 1;
  194 + }
  195 +
214 auto r = 196 auto r =
215 ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor, 197 ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
216 s->GetCurrentSegment(), s->GetNumFramesSinceStart()); 198 s->GetCurrentSegment(), s->GetNumFramesSinceStart());
@@ -258,6 +240,33 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
258 } 240 }
259 241
260 private: 242 private:
  243 + void PostInit() {
  244 + if (!config_.model_config.tokens_buf.empty()) {
  245 + sym_ = SymbolTable(config_.model_config.tokens_buf, false);
  246 + } else {
  247 + /// assuming tokens_buf and tokens are never both empty
  248 + sym_ = SymbolTable(config_.model_config.tokens, true);
  249 + }
  250 +
  251 + if (!config_.model_config.wenet_ctc.model.empty()) {
  252 + // WeNet CTC models assume input samples are in the range
  253 + // [-32768, 32767], so we set normalize_samples to false
  254 + config_.feat_config.normalize_samples = false;
  255 + }
  256 +
  257 + if (!config_.model_config.t_one_ctc.model.empty()) {
  258 + config_.feat_config.is_t_one = true;
  259 + config_.feat_config.frame_length_ms = 300;
  260 + config_.feat_config.frame_shift_ms = 300;
  261 + config_.feat_config.sampling_rate = 8000;
  262 + }
  263 +
  264 + if (model_->UseWhisperFeature()) {
  265 + config_.feat_config.is_whisper = true;
  266 + }
  267 +
  268 + InitDecoder();
  269 + }
261 void InitDecoder() { 270 void InitDecoder() {
262 if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") && 271 if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") &&
263 !sym_.Contains("<blank>")) { 272 !sym_.Contains("<blank>")) {
@@ -83,12 +83,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
83 83
84 if (!config.model_config.wenet_ctc.model.empty() || 84 if (!config.model_config.wenet_ctc.model.empty() ||
85 !config.model_config.zipformer2_ctc.model.empty() || 85 !config.model_config.zipformer2_ctc.model.empty() ||
86 - !config.model_config.nemo_ctc.model.empty()) { 86 + !config.model_config.nemo_ctc.model.empty() ||
  87 + !config.model_config.t_one_ctc.model.empty()) {
87 return std::make_unique<OnlineRecognizerCtcImpl>(config); 88 return std::make_unique<OnlineRecognizerCtcImpl>(config);
88 } 89 }
89 90
90 SHERPA_ONNX_LOGE("Please specify a model"); 91 SHERPA_ONNX_LOGE("Please specify a model");
91 - exit(-1); 92 + SHERPA_ONNX_EXIT(-1);
92 } 93 }
93 94
94 template <typename Manager> 95 template <typename Manager>
@@ -142,12 +143,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
142 143
143 if (!config.model_config.wenet_ctc.model.empty() || 144 if (!config.model_config.wenet_ctc.model.empty() ||
144 !config.model_config.zipformer2_ctc.model.empty() || 145 !config.model_config.zipformer2_ctc.model.empty() ||
145 - !config.model_config.nemo_ctc.model.empty()) { 146 + !config.model_config.nemo_ctc.model.empty() ||
  147 + !config.model_config.t_one_ctc.model.empty()) {
146 return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config); 148 return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config);
147 } 149 }
148 150
149 SHERPA_ONNX_LOGE("Please specify a model"); 151 SHERPA_ONNX_LOGE("Please specify a model");
150 - exit(-1); 152 + SHERPA_ONNX_EXIT(-1);
151 } 153 }
152 154
153 OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config) 155 OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config)
  1 +// sherpa-onnx/csrc/online-t-one-ctc-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-t-one-ctc-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OnlineToneCtcModelConfig::Register(ParseOptions *po) {
  13 + po->Register("t-one-ctc-model", &model,
  14 + "Path to CTC model.onnx from T-one. Please see "
  15 + "https://github.com/k2-fsa/sherpa-onnx/pull/2571");
  16 +}
  17 +
  18 +bool OnlineToneCtcModelConfig::Validate() const {
  19 + if (!FileExists(model)) {
  20 + SHERPA_ONNX_LOGE("T-one CTC model '%s' does not exist", model.c_str());
  21 + return false;
  22 + }
  23 +
  24 + return true;
  25 +}
  26 +
  27 +std::string OnlineToneCtcModelConfig::ToString() const {
  28 + std::ostringstream os;
  29 +
  30 + os << "OnlineToneCtcModelConfig(";
  31 + os << "model=\"" << model << "\")";
  32 +
  33 + return os.str();
  34 +}
  35 +
  36 +} // namespace sherpa_onnx
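A hedged sketch (hypothetical `main`, abbreviated usage string) of how the new flag flows through the kaldi-style `ParseOptions` machinery used across sherpa-onnx: `OnlineModelConfig::Register()` calls `t_one_ctc.Register(po)`, which exposes `--t-one-ctc-model`, and `Validate()` then checks that the file exists.

```cpp
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

int main(int argc, char *argv[]) {
  sherpa_onnx::ParseOptions po(
      "Usage: ... --t-one-ctc-model=/path/to/model.onnx --tokens=tokens.txt");

  sherpa_onnx::OnlineModelConfig config;
  config.Register(&po);  // registers --t-one-ctc-model among other flags
  po.Read(argc, argv);

  if (!config.Validate()) {
    return -1;  // e.g., the T-one model file does not exist
  }

  return 0;
}
```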
  1 +// sherpa-onnx/csrc/online-t-one-ctc-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +struct OnlineToneCtcModelConfig {
  14 + std::string model;
  15 +
  16 + OnlineToneCtcModelConfig() = default;
  17 +
  18 + explicit OnlineToneCtcModelConfig(const std::string &model) : model(model) {}
  19 +
  20 + void Register(ParseOptions *po);
  21 + bool Validate() const;
  22 +
  23 + std::string ToString() const;
  24 +};
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/online-t-one-ctc-model.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-t-one-ctc-model.h"
  6 +
  7 +#include <algorithm>
  8 +#include <cmath>
  9 +#include <string>
  10 +
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
  16 +#if __OHOS__
  17 +#include "rawfile/raw_file_manager.h"
  18 +#endif
  19 +
  20 +#include "sherpa-onnx/csrc/cat.h"
  21 +#include "sherpa-onnx/csrc/file-utils.h"
  22 +#include "sherpa-onnx/csrc/macros.h"
  23 +#include "sherpa-onnx/csrc/onnx-utils.h"
  24 +#include "sherpa-onnx/csrc/session.h"
  25 +#include "sherpa-onnx/csrc/text-utils.h"
  26 +#include "sherpa-onnx/csrc/unbind.h"
  27 +
  28 +namespace sherpa_onnx {
  29 +
  30 +class OnlineToneCtcModel::Impl {
  31 + public:
  32 + explicit Impl(const OnlineModelConfig &config)
  33 + : config_(config),
  34 + env_(ORT_LOGGING_LEVEL_ERROR),
  35 + sess_opts_(GetSessionOptions(config)),
  36 + allocator_{} {
  37 + {
  38 + auto buf = ReadFile(config.t_one_ctc.model);
  39 + Init(buf.data(), buf.size());
  40 + }
  41 + }
  42 +
  43 + template <typename Manager>
  44 + Impl(Manager *mgr, const OnlineModelConfig &config)
  45 + : config_(config),
  46 + env_(ORT_LOGGING_LEVEL_ERROR),
  47 + sess_opts_(GetSessionOptions(config)),
  48 + allocator_{} {
  49 + {
  50 + auto buf = ReadFile(mgr, config.t_one_ctc.model);
  51 + Init(buf.data(), buf.size());
  52 + }
  53 + }
  54 +
  55 + std::vector<Ort::Value> Forward(Ort::Value x,
  56 + std::vector<Ort::Value> states) {
  57 + // shape0 is (batch_size, 1, num_samples)
  58 + auto shape0 = x.GetTensorTypeAndShapeInfo().GetShape();
  59 + std::array<int64_t, 3> shape = {shape0[0], shape0[2], shape0[1]};
  60 + std::vector<int32_t> samples(shape[0] * shape[1] * shape[2]);
  61 + const float *px = x.GetTensorData<float>();
  62 +
  63 + for (int32_t i = 0; i != static_cast<int32_t>(samples.size()); ++i) {
  64 + float f = px[i];
  65 + f = f > 1 ? 1 : f;
  66 + f = f < -1 ? -1 : f;
  67 + samples[i] = static_cast<int32_t>(f * 32767);
  68 + }
  69 +
  70 + auto memory_info =
  71 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  72 +
  73 + Ort::Value xx =
  74 + Ort::Value::CreateTensor(memory_info, samples.data(), samples.size(),
  75 + shape.data(), shape.size());
  76 +
  77 + std::array<Ort::Value, 2> inputs = {std::move(xx), std::move(states[0])};
  78 +
  79 + auto out =
  80 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  81 + output_names_ptr_.data(), output_names_ptr_.size());
  82 + // out[0]: log_probs
  83 + // out[1]: next_states
  84 +
  85 + return out;
  86 + }
  87 +
  88 + int32_t VocabSize() const { return vocab_size_; }
  89 +
  90 + int32_t ChunkLength() const { return 1; }
  91 +
  92 + int32_t ChunkShift() const { return 1; }
  93 +
  94 + OrtAllocator *Allocator() { return allocator_; }
  95 +
  96 + // Return a vector containing 1 tensor
  97 + // - state_
  98 + std::vector<Ort::Value> GetInitStates() {
  99 + std::vector<Ort::Value> ans;
  100 + ans.push_back(View(&state_));
  101 +
  102 + return ans;
  103 + }
  104 +
  105 + std::vector<Ort::Value> StackStates(
  106 + std::vector<std::vector<Ort::Value>> states) {
  107 + int32_t batch_size = static_cast<int32_t>(states.size());
  108 + if (batch_size == 1) {
  109 + return std::move(states[0]);
  110 + }
  111 +
  112 + std::vector<Ort::Value> ans;
  113 + ans.reserve(1);
  114 +
  115 + std::vector<const Ort::Value *> buf;
  116 + buf.reserve(batch_size);
  117 +
  118 + for (int32_t b = 0; b != batch_size; ++b) {
  119 + buf.push_back(&states[b][0]);
  120 + }
  121 +
  122 + Ort::Value c{nullptr};
  123 + c = CatFloat16(allocator_, buf, 0);
  124 +
  125 + ans.push_back(std::move(c));
  126 +
  127 + return ans;
  128 + }
  129 +
  130 + std::vector<std::vector<Ort::Value>> UnStackStates(
  131 + std::vector<Ort::Value> states) const {
  132 + auto allocator = const_cast<Impl *>(this)->allocator_;
  133 +
  134 + std::vector<std::vector<Ort::Value>> ans;
  135 +
  136 + auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
  137 + int32_t batch_size = shape[0];
  138 + ans.resize(batch_size);
  139 +
  140 + if (batch_size == 1) {
  141 + ans[0] = std::move(states);
  142 + return ans;
  143 + }
  144 +
  145 + std::vector<Ort::Value> v;
  146 + v = UnbindFloat16(allocator, &states[0], 0);
  147 +
  148 + for (int32_t b = 0; b != batch_size; ++b) {
  149 + ans[b].push_back(std::move(v[b]));
  150 + }
  151 +
  152 + return ans;
  153 + }
  154 +
  155 + private:
  156 + void Init(void *model_data, size_t model_data_length) {
  157 + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
  158 + sess_opts_);
  159 +
  160 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  161 +
  162 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  163 +
  164 + // get meta data
  165 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  166 + if (config_.debug) {
  167 + std::ostringstream os;
  168 + PrintModelMetadata(os, meta_data);
  169 +#if __OHOS__
  170 + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str());
  171 +#else
  172 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
  173 +#endif
  174 + }
  175 +
  176 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  177 + SHERPA_ONNX_READ_META_DATA(frame_length_ms_, "frame_length_ms");
  178 + SHERPA_ONNX_READ_META_DATA(state_dim_, "state_dim");
  179 + SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
  180 +
  181 + InitStates();
  182 +
  183 + vocab_size_ = sess_->GetOutputTypeInfo(0)
  184 + .GetTensorTypeAndShapeInfo()
  185 + .GetShape()
  186 + .back();
  187 + }
  188 +
  189 + void InitStates() {
  190 + std::array<int64_t, 2> state_shape{1, state_dim_};
  191 +
  192 + state_ = Ort::Value::CreateTensor(allocator_, state_shape.data(),
  193 + state_shape.size(),
  194 + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
  195 +
  196 + auto p = state_.GetTensorMutableData<uint16_t>();
  197 + std::fill(p, p + state_dim_, 0);
  198 + }
  199 +
  200 + private:
  201 + OnlineModelConfig config_;
  202 + Ort::Env env_;
  203 + Ort::SessionOptions sess_opts_;
  204 + Ort::AllocatorWithDefaultOptions allocator_;
  205 +
  206 + std::unique_ptr<Ort::Session> sess_;
  207 +
  208 + std::vector<std::string> input_names_;
  209 + std::vector<const char *> input_names_ptr_;
  210 +
  211 + std::vector<std::string> output_names_;
  212 + std::vector<const char *> output_names_ptr_;
  213 +
  214 + // One input frame is 300 ms long.
  215 + // For each input frame, there are 10 output frames,
  216 + // so each output frame is 30ms
  217 + int32_t frame_length_ms_ = 0;
  218 + int32_t state_dim_ = 0;
  219 + int32_t sample_rate_ = 0;
  220 + int32_t vocab_size_ = 0;
  221 +
  222 + Ort::Value state_{nullptr};
  223 +};
  224 +
  225 +OnlineToneCtcModel::OnlineToneCtcModel(const OnlineModelConfig &config)
  226 + : impl_(std::make_unique<Impl>(config)) {}
  227 +
  228 +template <typename Manager>
  229 +OnlineToneCtcModel::OnlineToneCtcModel(Manager *mgr,
  230 + const OnlineModelConfig &config)
  231 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  232 +
  233 +OnlineToneCtcModel::~OnlineToneCtcModel() = default;
  234 +
  235 +std::vector<Ort::Value> OnlineToneCtcModel::Forward(
  236 + Ort::Value x, std::vector<Ort::Value> states) const {
  237 + return impl_->Forward(std::move(x), std::move(states));
  238 +}
  239 +
  240 +int32_t OnlineToneCtcModel::VocabSize() const { return impl_->VocabSize(); }
  241 +
  242 +int32_t OnlineToneCtcModel::ChunkLength() const { return impl_->ChunkLength(); }
  243 +
  244 +int32_t OnlineToneCtcModel::ChunkShift() const { return impl_->ChunkShift(); }
  245 +
  246 +OrtAllocator *OnlineToneCtcModel::Allocator() const {
  247 + return impl_->Allocator();
  248 +}
  249 +
  250 +std::vector<Ort::Value> OnlineToneCtcModel::GetInitStates() const {
  251 + return impl_->GetInitStates();
  252 +}
  253 +
  254 +std::vector<Ort::Value> OnlineToneCtcModel::StackStates(
  255 + std::vector<std::vector<Ort::Value>> states) const {
  256 + return impl_->StackStates(std::move(states));
  257 +}
  258 +
  259 +std::vector<std::vector<Ort::Value>> OnlineToneCtcModel::UnStackStates(
  260 + std::vector<Ort::Value> states) const {
  261 + return impl_->UnStackStates(std::move(states));
  262 +}
  263 +
  264 +#if __ANDROID_API__ >= 9
  265 +template OnlineToneCtcModel::OnlineToneCtcModel(
  266 + AAssetManager *mgr, const OnlineModelConfig &config);
  267 +#endif
  268 +
  269 +#if __OHOS__
  270 +template OnlineToneCtcModel::OnlineToneCtcModel(
  271 + NativeResourceManager *mgr, const OnlineModelConfig &config);
  272 +#endif
  273 +
  274 +} // namespace sherpa_onnx
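A minimal standalone sketch of the input preprocessing that `Forward()` above performs: float samples in `[-1, 1]` are clamped and rescaled to the int16 value range but stored as `int32_t`, which is the input type the exported T-one model expects.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int32_t> ToToneInput(const std::vector<float> &samples) {
  std::vector<int32_t> out(samples.size());
  for (std::size_t i = 0; i != samples.size(); ++i) {
    float f = std::max(-1.0f, std::min(1.0f, samples[i]));  // clamp
    out[i] = static_cast<int32_t>(f * 32767);  // int16 range, int32 storage
  }
  return out;
}
```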
  1 +// sherpa-onnx/csrc/online-t-one-ctc-model.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "onnxruntime_cxx_api.h" // NOLINT
  12 +#include "sherpa-onnx/csrc/online-ctc-model.h"
  13 +#include "sherpa-onnx/csrc/online-model-config.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +class OnlineToneCtcModel : public OnlineCtcModel {
  18 + public:
  19 + explicit OnlineToneCtcModel(const OnlineModelConfig &config);
  20 +
  21 + template <typename Manager>
  22 + OnlineToneCtcModel(Manager *mgr, const OnlineModelConfig &config);
  23 +
  24 + ~OnlineToneCtcModel() override;
  25 +
  26 + // A list of 1 tensor:
  27 + // - (batch_size, state_dim)
  28 + std::vector<Ort::Value> GetInitStates() const override;
  29 +
  30 + std::vector<Ort::Value> StackStates(
  31 + std::vector<std::vector<Ort::Value>> states) const override;
  32 +
  33 + std::vector<std::vector<Ort::Value>> UnStackStates(
  34 + std::vector<Ort::Value> states) const override;
  35 +
  36 + /**
  37 + *
  38 + * @param x A 3-D tensor of shape (batch_size, 1, num_samples).
  39 + * @param states It is from GetInitStates() or returned from this method.
  40 + *
  41 + * @return Return a list of tensors
  42 + * - ans[0] contains log_probs, of shape (N, T, C)
  43 + * - ans[1:] contains next_states
  44 + */
  45 + std::vector<Ort::Value> Forward(
  46 + Ort::Value x, std::vector<Ort::Value> states) const override;
  47 +
  48 + /** Return the vocabulary size of the model
  49 + */
  50 + int32_t VocabSize() const override;
  51 +
  52 + /** Return an allocator for allocating memory
  53 + */
  54 + OrtAllocator *Allocator() const override;
  55 +
  56 + // The model accepts this number of frames before subsampling as input
  57 + int32_t ChunkLength() const override;
  58 +
  59 + // Similar to frame_shift in feature extractor, after processing
  60 + // ChunkLength() frames, we advance by ChunkShift() frames
  61 + // before we process the next chunk.
  62 + int32_t ChunkShift() const override;
  63 +
  64 + bool SupportBatchProcessing() const override { return true; }
  65 +
  66 + private:
  67 + class Impl;
  68 + std::unique_ptr<Impl> impl_;
  69 +};
  70 +
  71 +} // namespace sherpa_onnx
  72 +
  73 +#endif // SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_H_
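A hedged driver sketch (the `RunChunks` helper is hypothetical) showing how the model's single float16 state tensor is threaded through successive 300 ms chunks:

```cpp
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/online-t-one-ctc-model.h"

void RunChunks(const sherpa_onnx::OnlineToneCtcModel &model,
               std::vector<Ort::Value> chunks) {
  // A single tensor of shape (batch_size, state_dim), float16, zeros.
  std::vector<Ort::Value> states = model.GetInitStates();

  for (auto &chunk : chunks) {  // each chunk: (N, 1, num_samples), float
    auto out = model.Forward(std::move(chunk), std::move(states));
    // out[0]: log_probs of shape (N, T, vocab_size) -> fed to the decoder
    states.clear();
    states.push_back(std::move(out[1]));  // carry the state forward
  }
}
```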
@@ -155,10 +155,30 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
155 std::copy(start, end, dst); 155 std::copy(start, end, dst);
156 return ans; 156 return ans;
157 } 157 }
  158 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
  159 + Ort::Value ans =
  160 + Ort::Value::CreateTensor(allocator, shape.data(), shape.size(),
  161 + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
  162 + const auto *start = v->GetTensorData<uint16_t>();
  163 + const auto *end = start + type_and_shape.GetElementCount();
  164 + auto *dst = ans.GetTensorMutableData<uint16_t>();
  165 + std::copy(start, end, dst);
  166 + return ans;
  167 + }
  168 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: {
  169 + Ort::Value ans = Ort::Value::CreateTensor<uint16_t>(
  170 + allocator, shape.data(), shape.size());
  171 + const auto *start = v->GetTensorData<uint16_t>();
  172 + const auto *end = start + type_and_shape.GetElementCount();
  173 + auto *dst = ans.GetTensorMutableData<uint16_t>();
  174 + std::copy(start, end, dst);
  175 + return ans;
  176 + }
  177 +
158 default: 178 default:
159 - fprintf(stderr, "Unsupported type: %d\n",  
160 - static_cast<int32_t>(type_and_shape.GetElementType()));  
161 - exit(-1); 179 + SHERPA_ONNX_LOGE("Unsupported type: %d\n",
  180 + static_cast<int32_t>(type_and_shape.GetElementType()));
  181 + SHERPA_ONNX_EXIT(-1);
162 // unreachable code 182 // unreachable code
163 return Ort::Value{nullptr}; 183 return Ort::Value{nullptr};
164 } 184 }
@@ -183,14 +203,23 @@ Ort::Value View(Ort::Value *v) {
183 return Ort::Value::CreateTensor( 203 return Ort::Value::CreateTensor(
184 memory_info, v->GetTensorMutableData<float>(), 204 memory_info, v->GetTensorMutableData<float>(),
185 type_and_shape.GetElementCount(), shape.data(), shape.size()); 205 type_and_shape.GetElementCount(), shape.data(), shape.size());
  206 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
  207 + return Ort::Value::CreateTensor(
  208 + memory_info, v->GetTensorMutableData<uint16_t>(),
  209 + type_and_shape.GetElementCount() * sizeof(uint16_t), shape.data(),
  210 + shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
  211 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
  212 + return Ort::Value::CreateTensor(
  213 + memory_info, v->GetTensorMutableData<uint16_t>(),
  214 + type_and_shape.GetElementCount(), shape.data(), shape.size());
186 case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: 215 case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
187 return Ort::Value::CreateTensor( 216 return Ort::Value::CreateTensor(
188 memory_info, v->GetTensorMutableData<bool>(), 217 memory_info, v->GetTensorMutableData<bool>(),
189 type_and_shape.GetElementCount(), shape.data(), shape.size()); 218 type_and_shape.GetElementCount(), shape.data(), shape.size());
190 default: 219 default:
191 - fprintf(stderr, "Unsupported type: %d\n",  
192 - static_cast<int32_t>(type_and_shape.GetElementType()));  
193 - exit(-1); 220 + SHERPA_ONNX_LOGE("Unsupported type: %d\n",
  221 + static_cast<int32_t>(type_and_shape.GetElementType()));
  222 + SHERPA_ONNX_EXIT(-1);
194 // unreachable code 223 // unreachable code
195 return Ort::Value{nullptr}; 224 return Ort::Value{nullptr};
196 } 225 }
@@ -11,6 +11,7 @@
11 #include <locale> 11 #include <locale>
12 #endif 12 #endif
13 13
  14 +#include <algorithm>
14 #include <cassert> 15 #include <cassert>
15 #include <ostream> 16 #include <ostream>
16 #include <string> 17 #include <string>
@@ -117,6 +117,11 @@ for a list of pre-trained models to download.
117 const float duration = samples.size() / static_cast<float>(sampling_rate); 117 const float duration = samples.size() / static_cast<float>(sampling_rate);
118 118
119 auto s = recognizer.CreateStream(); 119 auto s = recognizer.CreateStream();
  120 +
  121 + std::vector<float> left_paddings(static_cast<int>(0.3 * sampling_rate));
  122 + s->AcceptWaveform(sampling_rate, left_paddings.data(),
  123 + left_paddings.size());
  124 +
120 s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); 125 s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
121 126
122 std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate)); 127 std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
@@ -4,7 +4,7 @@
4 4
5 #include "sherpa-onnx/csrc/text-utils.h" 5 #include "sherpa-onnx/csrc/text-utils.h"
6 6
7 -#include <regex> 7 +#include <regex> // NOLINT
8 #include <sstream> 8 #include <sstream>
9 9
10 #include "gtest/gtest.h" 10 #include "gtest/gtest.h"
@@ -68,4 +68,49 @@ template std::vector<Ort::Value> Unbind<int64_t>(OrtAllocator *allocator,
68 const Ort::Value *value, 68 const Ort::Value *value,
69 int32_t dim); 69 int32_t dim);
70 70
  71 +std::vector<Ort::Value> UnbindFloat16(OrtAllocator *allocator,
  72 + const Ort::Value *value, int32_t dim) {
  73 + std::vector<int64_t> shape = value->GetTensorTypeAndShapeInfo().GetShape();
  74 + assert(dim >= 0);
  75 + assert(dim < static_cast<int32_t>(shape.size()));
  76 + int32_t n = static_cast<int32_t>(shape[dim]);
  77 + if (n == 1) {
  78 + std::vector<Ort::Value> ans;
  79 + ans.push_back(Clone(allocator, value));
  80 + return ans;
  81 + }
  82 +
  83 + std::vector<int64_t> ans_shape = shape;
  84 + ans_shape[dim] = 1; // unlike torch, we keep this dim and set it to 1
  85 +
  86 + // allocate the output tensors
  87 + std::vector<Ort::Value> ans;
  88 + ans.reserve(n);
  89 + for (int32_t i = 0; i != n; ++i) {
  90 + Ort::Value t =
  91 + Ort::Value::CreateTensor(allocator, ans_shape.data(), ans_shape.size(),
  92 + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
  93 + ans.push_back(std::move(t));
  94 + }
  95 +
  96 + auto leading_size = static_cast<int32_t>(std::accumulate(
  97 + shape.begin(), shape.begin() + dim, 1, std::multiplies<int64_t>()));
  98 +
  99 + auto trailing_size = static_cast<int32_t>(std::accumulate(
  100 + shape.begin() + dim + 1, shape.end(), 1, std::multiplies<int64_t>()));
  101 +
  102 + using T = uint16_t;
  103 + const T *src = value->GetTensorData<T>();
  104 +
  105 + for (int32_t i = 0; i != leading_size; ++i) {
  106 + for (int32_t k = 0; k != n; ++k) {
  107 + T *dst = ans[k].GetTensorMutableData<T>() + i * trailing_size;
  108 + std::copy(src, src + trailing_size, dst);
  109 + src += trailing_size;
  110 + }
  111 + }
  112 +
  113 + return ans;
  114 +}
  115 +
71 } // namespace sherpa_onnx 116 } // namespace sherpa_onnx
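`UnbindFloat16()` is the inverse of `CatFloat16()` along the batch dimension. A hedged sketch (helper name hypothetical): a stacked `(batch_size, state_dim)` float16 state is split back into `batch_size` tensors of shape `(1, state_dim)`, one per stream.

```cpp
#include <vector>

#include "sherpa-onnx/csrc/unbind.h"

std::vector<Ort::Value> SplitStates(OrtAllocator *allocator,
                                    const Ort::Value *stacked) {
  // Unlike torch.unbind, dim 0 is kept, so each result is (1, state_dim).
  return sherpa_onnx::UnbindFloat16(allocator, stacked, /*dim=*/0);
}
```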
@@ -23,6 +23,9 @@ template <typename T = float>
23 std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value, 23 std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
24 int32_t dim); 24 int32_t dim);
25 25
  26 +std::vector<Ort::Value> UnbindFloat16(OrtAllocator *allocator,
  27 + const Ort::Value *value, int32_t dim);
  28 +
26 } // namespace sherpa_onnx 29 } // namespace sherpa_onnx
27 30
28 #endif // SHERPA_ONNX_CSRC_UNBIND_H_ 31 #endif // SHERPA_ONNX_CSRC_UNBIND_H_
@@ -42,6 +42,7 @@ set(srcs
42 online-punctuation.cc 42 online-punctuation.cc
43 online-recognizer.cc 43 online-recognizer.cc
44 online-stream.cc 44 online-stream.cc
  45 + online-t-one-ctc-model-config.cc
45 online-transducer-model-config.cc 46 online-transducer-model-config.cc
46 online-wenet-ctc-model-config.cc 47 online-wenet-ctc-model-config.cc
47 online-zipformer2-ctc-model-config.cc 48 online-zipformer2-ctc-model-config.cc
@@ -5,6 +5,7 @@
5 5
6 #include <algorithm> 6 #include <algorithm>
7 #include <string> 7 #include <string>
  8 +#include <vector>
8 9
9 #include "sherpa-onnx/csrc/offline-tts.h" 10 #include "sherpa-onnx/csrc/offline-tts.h"
10 #include "sherpa-onnx/python/csrc/offline-tts-model-config.h" 11 #include "sherpa-onnx/python/csrc/offline-tts-model-config.h"
@@ -12,6 +12,7 @@
12 #include "sherpa-onnx/csrc/provider-config.h" 12 #include "sherpa-onnx/csrc/provider-config.h"
13 #include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h" 13 #include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
14 #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" 14 #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
  15 +#include "sherpa-onnx/python/csrc/online-t-one-ctc-model-config.h"
15 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" 16 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
16 #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h" 17 #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
17 #include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h" 18 #include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
@@ -25,6 +26,7 @@ void PybindOnlineModelConfig(py::module *m) { @@ -25,6 +26,7 @@ void PybindOnlineModelConfig(py::module *m) {
25 PybindOnlineWenetCtcModelConfig(m); 26 PybindOnlineWenetCtcModelConfig(m);
26 PybindOnlineZipformer2CtcModelConfig(m); 27 PybindOnlineZipformer2CtcModelConfig(m);
27 PybindOnlineNeMoCtcModelConfig(m); 28 PybindOnlineNeMoCtcModelConfig(m);
  29 + PybindOnlineToneCtcModelConfig(m);
28 PybindProviderConfig(m); 30 PybindProviderConfig(m);
29 31
30 using PyClass = OnlineModelConfig; 32 using PyClass = OnlineModelConfig;
@@ -34,17 +36,18 @@ void PybindOnlineModelConfig(py::module *m) { @@ -34,17 +36,18 @@ void PybindOnlineModelConfig(py::module *m) {
34 const OnlineWenetCtcModelConfig &, 36 const OnlineWenetCtcModelConfig &,
35 const OnlineZipformer2CtcModelConfig &, 37 const OnlineZipformer2CtcModelConfig &,
36 const OnlineNeMoCtcModelConfig &, 38 const OnlineNeMoCtcModelConfig &,
37 - const ProviderConfig &,  
38 - const std::string &, int32_t, int32_t,  
39 - bool, const std::string &, const std::string &, 39 + const OnlineToneCtcModelConfig &, const ProviderConfig &,
  40 + const std::string &, int32_t, int32_t, bool,
  41 + const std::string &, const std::string &,
40 const std::string &>(), 42 const std::string &>(),
41 py::arg("transducer") = OnlineTransducerModelConfig(), 43 py::arg("transducer") = OnlineTransducerModelConfig(),
42 py::arg("paraformer") = OnlineParaformerModelConfig(), 44 py::arg("paraformer") = OnlineParaformerModelConfig(),
43 py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), 45 py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
44 py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(), 46 py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
45 py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), 47 py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(),
46 - py::arg("provider_config") = ProviderConfig(),  
47 - py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0, 48 + py::arg("t_one_ctc") = OnlineToneCtcModelConfig(),
  49 + py::arg("provider_config") = ProviderConfig(), py::arg("tokens"),
  50 + py::arg("num_threads"), py::arg("warm_up") = 0,
48 py::arg("debug") = false, py::arg("model_type") = "", 51 py::arg("debug") = false, py::arg("model_type") = "",
49 py::arg("modeling_unit") = "", py::arg("bpe_vocab") = "") 52 py::arg("modeling_unit") = "", py::arg("bpe_vocab") = "")
50 .def_readwrite("transducer", &PyClass::transducer) 53 .def_readwrite("transducer", &PyClass::transducer)
@@ -52,6 +55,7 @@ void PybindOnlineModelConfig(py::module *m) { @@ -52,6 +55,7 @@ void PybindOnlineModelConfig(py::module *m) {
52 .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) 55 .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
53 .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc) 56 .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
54 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) 57 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
  58 + .def_readwrite("t_one_ctc", &PyClass::t_one_ctc)
55 .def_readwrite("provider_config", &PyClass::provider_config) 59 .def_readwrite("provider_config", &PyClass::provider_config)
56 .def_readwrite("tokens", &PyClass::tokens) 60 .def_readwrite("tokens", &PyClass::tokens)
57 .def_readwrite("num_threads", &PyClass::num_threads) 61 .def_readwrite("num_threads", &PyClass::num_threads)
  1 +// sherpa-onnx/python/csrc/online-t-one-ctc-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/online-t-one-ctc-model-config.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/online-t-one-ctc-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindOnlineToneCtcModelConfig(py::module *m) {
  15 + using PyClass = OnlineToneCtcModelConfig;
  16 + py::class_<PyClass>(*m, "OnlineToneCtcModelConfig")
  17 + .def(py::init<const std::string &>(), py::arg("model"))
  18 + .def_readwrite("model", &PyClass::model)
  19 + .def("__str__", &PyClass::ToString);
  20 +}
  21 +
  22 +} // namespace sherpa_onnx
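With the bindings above in place, the new config becomes usable from Python roughly as follows. This is an illustrative sketch, not part of the PR; "model.onnx" and "tokens.txt" are placeholder paths, and the import path mirrors the one used in the recognizer module below.

# Hypothetical usage of the new bindings; paths are placeholders.
from sherpa_onnx.lib._sherpa_onnx import (
    OnlineModelConfig,
    OnlineToneCtcModelConfig,
)

t_one = OnlineToneCtcModelConfig(model="model.onnx")
print(t_one)  # __str__ is bound to OnlineToneCtcModelConfig::ToString()

# tokens and num_threads have no defaults in the py::init above,
# so they must be supplied explicitly.
model_config = OnlineModelConfig(
    t_one_ctc=t_one,
    tokens="tokens.txt",
    num_threads=1,
)
print(model_config.t_one_ctc.model)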
  1 +// sherpa-onnx/python/csrc/online-t-one-ctc-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOnlineToneCtcModelConfig(py::module *m);
  13 +
  14 +}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
@@ -18,6 +18,7 @@ from sherpa_onnx.lib._sherpa_onnx import ( @@ -18,6 +18,7 @@ from sherpa_onnx.lib._sherpa_onnx import (
18 OnlineRecognizerConfig, 18 OnlineRecognizerConfig,
19 OnlineRecognizerResult, 19 OnlineRecognizerResult,
20 OnlineStream, 20 OnlineStream,
  21 + OnlineToneCtcModelConfig,
21 OnlineTransducerModelConfig, 22 OnlineTransducerModelConfig,
22 OnlineWenetCtcModelConfig, 23 OnlineWenetCtcModelConfig,
23 OnlineZipformer2CtcModelConfig, 24 OnlineZipformer2CtcModelConfig,
@@ -603,6 +604,132 @@ class OnlineRecognizer(object): @@ -603,6 +604,132 @@ class OnlineRecognizer(object):
603 return self 604 return self
604 605
605 @classmethod 606 @classmethod
  607 + def from_t_one_ctc(
  608 + cls,
  609 + tokens: str,
  610 + model: str,
  611 + num_threads: int = 2,
  612 + sample_rate: float = 8000,
  613 + feature_dim: int = 80,
  614 + enable_endpoint_detection: bool = False,
  615 + rule1_min_trailing_silence: float = 2.4,
  616 + rule2_min_trailing_silence: float = 1.2,
  617 + rule3_min_utterance_length: float = 20.0,
  618 + decoding_method: str = "greedy_search",
  619 + provider: str = "cpu",
  620 + debug: bool = False,
  621 + rule_fsts: str = "",
  622 + rule_fars: str = "",
  623 + device: int = 0,
  624 + hr_dict_dir: str = "",
  625 + hr_rule_fsts: str = "",
  626 + hr_lexicon: str = "",
  627 + ):
  628 + """
  629 + Please refer to
  630 + `<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
  631 + to download pre-trained models.
  632 +
  633 + Args:
  634 + tokens:
  635 + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
  636 + columns::
  637 +
  638 + symbol integer_id
  639 +
  640 + model:
  641 + Path to ``model.onnx``.
  642 + num_threads:
  643 + Number of threads for neural network computation.
  644 + sample_rate:
  645 + Sample rate of the training data used to train the model.
  646 + feature_dim:
  647 + Dimension of the feature used to train the model.
  648 + enable_endpoint_detection:
  649 + True to enable endpoint detection. False to disable endpoint
  650 + detection.
  651 + rule1_min_trailing_silence:
  652 + Used only when enable_endpoint_detection is True. If the duration
  653 + of trailing silence in seconds is larger than this value, we assume
  654 + an endpoint is detected.
  655 + rule2_min_trailing_silence:
  656 + Used only when enable_endpoint_detection is True. If we have decoded
  657 + something that is nonsilence and if the duration of trailing silence
  658 + in seconds is larger than this value, we assume an endpoint is
  659 + detected.
  660 + rule3_min_utterance_length:
  661 + Used only when enable_endpoint_detection is True. If the utterance
  662 + length in seconds is larger than this value, we assume an endpoint
  663 + is detected.
  664 + decoding_method:
  665 + The only valid value is greedy_search.
  666 + provider:
  667 + onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  668 + debug:
  669 + True to show the metadata of the model.
  670 + rule_fsts:
  671 + If not empty, it specifies fsts for inverse text normalization.
  672 + If there are multiple fsts, they are separated by a comma.
  673 + rule_fars:
  674 + If not empty, it specifies fst archives for inverse text normalization.
  675 + If there are multiple archives, they are separated by a comma.
  676 + device:
  677 + onnxruntime cuda device index.
  678 + """
  679 + self = cls.__new__(cls)
  680 + _assert_file_exists(tokens)
  681 + _assert_file_exists(model)
  682 +
  683 + assert num_threads > 0, num_threads
  684 +
  685 + t_one_ctc_config = OnlineToneCtcModelConfig(
  686 + model=model,
  687 + )
  688 +
  689 + provider_config = ProviderConfig(
  690 + provider=provider,
  691 + device=device,
  692 + )
  693 +
  694 + model_config = OnlineModelConfig(
  695 + t_one_ctc=t_one_ctc_config,
  696 + tokens=tokens,
  697 + num_threads=num_threads,
  698 + provider_config=provider_config,
  699 + debug=debug,
  700 + )
  701 +
  702 + feat_config = FeatureExtractorConfig(
  703 + sampling_rate=sample_rate,
  704 + feature_dim=feature_dim,
  705 + )
  706 +
  707 + endpoint_config = EndpointConfig(
  708 + rule1_min_trailing_silence=rule1_min_trailing_silence,
  709 + rule2_min_trailing_silence=rule2_min_trailing_silence,
  710 + rule3_min_utterance_length=rule3_min_utterance_length,
  711 + )
  712 +
  713 + recognizer_config = OnlineRecognizerConfig(
  714 + feat_config=feat_config,
  715 + model_config=model_config,
  716 + endpoint_config=endpoint_config,
  717 + enable_endpoint=enable_endpoint_detection,
  718 + decoding_method=decoding_method,
  719 + rule_fsts=rule_fsts,
  720 + rule_fars=rule_fars,
  721 + hr=HomophoneReplacerConfig(
  722 + dict_dir=hr_dict_dir,
  723 + lexicon=hr_lexicon,
  724 + rule_fsts=hr_rule_fsts,
  725 + ),
  726 + )
  727 +
  728 + self.recognizer = _Recognizer(recognizer_config)
  729 + self.config = recognizer_config
  730 + return self
  731 +
  732 + @classmethod
606 def from_nemo_ctc( 733 def from_nemo_ctc(
607 cls, 734 cls,
608 tokens: str, 735 tokens: str,
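Finally, an end-to-end sketch of the new classmethod, assuming a 16-bit mono 8 kHz wave file; all paths are placeholders, and the streaming loop follows the usual sherpa-onnx online-recognizer pattern:

# Hypothetical end-to-end sketch for the new from_t_one_ctc API;
# tokens.txt / model.onnx / test.wav are placeholder paths.
import wave

import numpy as np
import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer.from_t_one_ctc(
    tokens="tokens.txt",
    model="model.onnx",
    num_threads=1,
    sample_rate=8000,  # T-one models expect 8 kHz input
)

with wave.open("test.wav") as f:
    assert f.getframerate() == 8000, f.getframerate()
    assert f.getnchannels() == 1, f.getnchannels()
    assert f.getsampwidth() == 2, f.getsampwidth()  # 16-bit PCM
    samples = np.frombuffer(
        f.readframes(f.getnframes()), dtype=np.int16
    ).astype(np.float32) / 32768

stream = recognizer.create_stream()
stream.accept_waveform(8000, samples)

# Add some tail padding so the last frames are flushed, then decode.
stream.accept_waveform(8000, np.zeros(int(0.5 * 8000), dtype=np.float32))
stream.input_finished()

while recognizer.is_ready(stream):
    recognizer.decode_stream(stream)

print(recognizer.get_result(stream))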