Fangjun Kuang
Committed by GitHub

Add C++ runtime and Python API for NeMo Canary models (#2352)
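This PR adds a NeMo Canary (EncDecMultiTaskModel) runtime to the C++ core and exposes it through the Python API. As a quick orientation, the new API can be driven as follows (a condensed sketch of the full example file added in this diff; model paths are placeholders and a mono 16 kHz wav is assumed):

    import sherpa_onnx
    import soundfile as sf

    recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_canary(
        encoder="encoder.int8.onnx",
        decoder="decoder.int8.onnx",
        tokens="tokens.txt",
        src_lang="en",
        tgt_lang="de",  # transcribe English speech, output German text
    )
    samples, sample_rate = sf.read("en.wav", dtype="float32")
    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)
    recognizer.decode_stream(stream)
    print(stream.result.text)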

@@ -8,6 +8,13 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }
 
+log "test nemo canary"
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
+tar xvf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
+rm sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
+python3 ./python-api-examples/offline-nemo-canary-decode-files.py
+rm -rf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
+
 log "test spleeter"
 
 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2
+#!/usr/bin/env python3
+
+"""
+This file shows how to use a non-streaming Canary model from NeMo
+to decode files.
+
+Please download model files from
+https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+
+
+The example model supports 4 languages and it is converted from
+https://huggingface.co/nvidia/canary-180m-flash
+
+It supports automatic speech-to-text recognition (ASR) in 4 languages
+(English, German, French, Spanish) and translation from English to
+German/French/Spanish and from German/French/Spanish to English with or
+without punctuation and capitalization (PnC).
+"""
+
+from pathlib import Path
+
+import sherpa_onnx
+import soundfile as sf
+
+
+def create_recognizer():
+    encoder = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/encoder.int8.onnx"
+    decoder = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/decoder.int8.onnx"
+    tokens = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/tokens.txt"
+
+    en_wav = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/test_wavs/en.wav"
+    de_wav = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/test_wavs/de.wav"
+
+    if not Path(encoder).is_file() or not Path(en_wav).is_file():
+        raise ValueError(
+            """Please download model files from
+            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+            """
+        )
+    return (
+        sherpa_onnx.OfflineRecognizer.from_nemo_canary(
+            encoder=encoder,
+            decoder=decoder,
+            tokens=tokens,
+            debug=True,
+        ),
+        en_wav,
+        de_wav,
+    )
+
+
+def decode(recognizer, samples, sample_rate, src_lang, tgt_lang):
+    stream = recognizer.create_stream()
+    stream.accept_waveform(sample_rate, samples)
+
+    recognizer.recognizer.set_config(
+        config=sherpa_onnx.OfflineRecognizerConfig(
+            model_config=sherpa_onnx.OfflineModelConfig(
+                canary=sherpa_onnx.OfflineCanaryModelConfig(
+                    src_lang=src_lang,
+                    tgt_lang=tgt_lang,
+                )
+            )
+        )
+    )
+
+    recognizer.decode_stream(stream)
+    return stream.result.text
+
+
+def main():
+    recognizer, en_wav, de_wav = create_recognizer()
+
+    en_audio, en_sample_rate = sf.read(en_wav, dtype="float32", always_2d=True)
+    en_audio = en_audio[:, 0]  # only use the first channel
+
+    de_audio, de_sample_rate = sf.read(de_wav, dtype="float32", always_2d=True)
+    de_audio = de_audio[:, 0]  # only use the first channel
+
+    en_wav_en_result = decode(
+        recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="en"
+    )
+    en_wav_es_result = decode(
+        recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="es"
+    )
+    en_wav_de_result = decode(
+        recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="de"
+    )
+    en_wav_fr_result = decode(
+        recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="fr"
+    )
+
+    de_wav_en_result = decode(
+        recognizer, de_audio, de_sample_rate, src_lang="de", tgt_lang="en"
+    )
+    de_wav_de_result = decode(
+        recognizer, de_audio, de_sample_rate, src_lang="de", tgt_lang="de"
+    )
+
+    print("en_wav_en_result", en_wav_en_result)
+    print("en_wav_es_result", en_wav_es_result)
+    print("en_wav_de_result", en_wav_de_result)
+    print("en_wav_fr_result", en_wav_fr_result)
+    print("-" * 10)
+    print("de_wav_en_result", de_wav_en_result)
+    print("de_wav_de_result", de_wav_de_result)
+
+
+if __name__ == "__main__":
+    main()
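Note the pattern in decode() above: a single loaded recognizer is reused across utterances, and only src_lang/tgt_lang (and use_pnc) are switched per call through set_config(). Per the SetConfig() override in the new C++ recognizer implementation further down in this diff, only those three Canary fields are picked up from the passed config; everything else in the OfflineRecognizerConfig argument is ignored, so the model paths do not need to be repeated.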
@@ -281,9 +281,14 @@ def export_decoder(canary_model):
 
 
 def export_tokens(canary_model):
+    underline = "▁"
     with open("./tokens.txt", "w", encoding="utf-8") as f:
         for i in range(canary_model.tokenizer.vocab_size):
             s = canary_model.tokenizer.ids_to_text([i])
+
+            if s[0] == " ":
+                s = underline + s[1:]
+
             f.write(f"{s} {i}\n")
     print("Saved to tokens.txt")
 
@@ -289,7 +289,13 @@ def main():
         tokens.append(t)
     print("len(tokens)", len(tokens))
    print("tokens", tokens)
+
     text = "".join([id2token[i] for i in tokens])
+
+    underline = "▁"
+    # underline = b"\xe2\x96\x81".decode()
+
+    text = text.replace(underline, " ").strip()
     print("text:", text)
 
 
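The two hunks above implement the usual SentencePiece-style space handling: tokens.txt uses a single space to separate the symbol column from the integer id column, so a token whose text begins with a literal space would corrupt the file. The exporter therefore rewrites a leading space as U+2581 (▁), and the decoding side maps it back. A minimal round-trip sketch (plain Python, illustrative only):

    underline = "▁"  # U+2581 LOWER ONE EIGHTH BLOCK, not an ASCII underscore

    def encode_symbol(s: str) -> str:
        # what export_tokens() does to one token's text
        return underline + s[1:] if s.startswith(" ") else s

    def decode_text(symbols) -> str:
        # what the test script does after greedy decoding
        return "".join(symbols).replace(underline, " ").strip()

    assert encode_symbol(" hello") == "▁hello"
    assert decode_text(["▁how", "▁are", "▁you"]) == "how are you"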
@@ -5,6 +5,7 @@
 
 #include <algorithm>
 #include <cstring>
+#include <utility>
 
 namespace sherpa_onnx::cxx {
 
@@ -25,6 +25,8 @@ set(sources
   jieba.cc
   keyword-spotter-impl.cc
   keyword-spotter.cc
+  offline-canary-model-config.cc
+  offline-canary-model.cc
   offline-ctc-fst-decoder-config.cc
   offline-ctc-fst-decoder.cc
   offline-ctc-greedy-search-decoder.cc
@@ -50,7 +52,6 @@ set(sources
   offline-rnn-lm.cc
   offline-sense-voice-model-config.cc
   offline-sense-voice-model.cc
-
   offline-source-separation-impl.cc
   offline-source-separation-model-config.cc
   offline-source-separation-spleeter-model-config.cc
@@ -58,7 +59,6 @@ set(sources
   offline-source-separation-uvr-model-config.cc
   offline-source-separation-uvr-model.cc
   offline-source-separation.cc
-
   offline-stream.cc
   offline-tdnn-ctc-model.cc
   offline-tdnn-model-config.cc
+// sherpa-onnx/csrc/offline-canary-model-config.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-canary-model-config.h"
+
+#include <sstream>
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OfflineCanaryModelConfig::Register(ParseOptions *po) {
+  po->Register("canary-encoder", &encoder,
+               "Path to onnx encoder of Canary, e.g., encoder.int8.onnx");
+
+  po->Register("canary-decoder", &decoder,
+               "Path to onnx decoder of Canary, e.g., decoder.int8.onnx");
+
+  po->Register("canary-src-lang", &src_lang,
+               "Valid values: en, de, es, fr. If empty, default to use en");
+
+  po->Register("canary-tgt-lang", &tgt_lang,
+               "Valid values: en, de, es, fr. If empty, default to use en");
+
+  po->Register("canary-use-pnc", &use_pnc,
+               "true to enable punctuations and casing. false to disable them");
+}
+
+bool OfflineCanaryModelConfig::Validate() const {
+  if (encoder.empty()) {
+    SHERPA_ONNX_LOGE("Please provide --canary-encoder");
+    return false;
+  }
+
+  if (!FileExists(encoder)) {
+    SHERPA_ONNX_LOGE("Canary encoder file '%s' does not exist",
+                     encoder.c_str());
+    return false;
+  }
+
+  if (decoder.empty()) {
+    SHERPA_ONNX_LOGE("Please provide --canary-decoder");
+    return false;
+  }
+
+  if (!FileExists(decoder)) {
+    SHERPA_ONNX_LOGE("Canary decoder file '%s' does not exist",
+                     decoder.c_str());
+    return false;
+  }
+
+  if (!src_lang.empty()) {
+    if (src_lang != "en" && src_lang != "de" && src_lang != "es" &&
+        src_lang != "fr") {
+      SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-src-lang");
+      return false;
+    }
+  }
+
+  if (!tgt_lang.empty()) {
+    if (tgt_lang != "en" && tgt_lang != "de" && tgt_lang != "es" &&
+        tgt_lang != "fr") {
+      SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-tgt-lang");
+      return false;
+    }
+  }
+
+  return true;
+}
+
+std::string OfflineCanaryModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineCanaryModelConfig(";
+  os << "encoder=\"" << encoder << "\", ";
+  os << "decoder=\"" << decoder << "\", ";
+  os << "src_lang=\"" << src_lang << "\", ";
+  os << "tgt_lang=\"" << tgt_lang << "\", ";
+  os << "use_pnc=" << (use_pnc ? "True" : "False") << ")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/csrc/offline-canary-model-config.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflineCanaryModelConfig {
+  std::string encoder;
+  std::string decoder;
+
+  // en, de, es, fr, or leave it empty to use en
+  std::string src_lang;
+
+  // en, de, es, fr, or leave it empty to use en
+  std::string tgt_lang;
+
+  // true to enable punctuations and casing
+  // false to disable punctuations and casing
+  bool use_pnc = true;
+
+  OfflineCanaryModelConfig() = default;
+  OfflineCanaryModelConfig(const std::string &encoder,
+                           const std::string &decoder,
+                           const std::string &src_lang,
+                           const std::string &tgt_lang, bool use_pnc)
+      : encoder(encoder),
+        decoder(decoder),
+        src_lang(src_lang),
+        tgt_lang(tgt_lang),
+        use_pnc(use_pnc) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
+// sherpa-onnx/csrc/offline-canary-model-meta-data.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace sherpa_onnx {
+
+struct OfflineCanaryModelMetaData {
+  int32_t vocab_size;
+  int32_t subsampling_factor = 8;
+  int32_t feat_dim = 120;
+  std::string normalize_type;
+  std::unordered_map<std::string, int32_t> lang2id;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
+// sherpa-onnx/csrc/offline-canary-model.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-canary-model.h"
+
+#include <algorithm>
+#include <cmath>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+
+#include "sherpa-onnx/csrc/offline-canary-model-meta-data.h"
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#if __OHOS__
+#include "rawfile/raw_file_manager.h"
+#endif
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+class OfflineCanaryModel::Impl {
+ public:
+  explicit Impl(const OfflineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(config.canary.encoder);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.canary.decoder);
+      InitDecoder(buf.data(), buf.size());
+    }
+  }
+
+  template <typename Manager>
+  Impl(Manager *mgr, const OfflineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(mgr, config.canary.encoder);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(mgr, config.canary.decoder);
+      InitDecoder(buf.data(), buf.size());
+    }
+  }
+
+  std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
+                                         Ort::Value features_length) {
+    std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
+                                                std::move(features_length)};
+
+    auto encoder_out = encoder_sess_->Run(
+        {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
+        encoder_inputs.size(), encoder_output_names_ptr_.data(),
+        encoder_output_names_ptr_.size());
+
+    return encoder_out;
+  }
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardDecoder(
+      Ort::Value tokens, std::vector<Ort::Value> decoder_states,
+      Ort::Value encoder_states, Ort::Value enc_mask) {
+    std::vector<Ort::Value> decoder_inputs;
+    decoder_inputs.reserve(3 + decoder_states.size());
+
+    decoder_inputs.push_back(std::move(tokens));
+    for (auto &s : decoder_states) {
+      decoder_inputs.push_back(std::move(s));
+    }
+
+    decoder_inputs.push_back(std::move(encoder_states));
+    decoder_inputs.push_back(std::move(enc_mask));
+
+    auto decoder_outputs = decoder_sess_->Run(
+        {}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
+        decoder_inputs.size(), decoder_output_names_ptr_.data(),
+        decoder_output_names_ptr_.size());
+
+    Ort::Value logits = std::move(decoder_outputs[0]);
+
+    std::vector<Ort::Value> output_decoder_states;
+    output_decoder_states.reserve(decoder_states.size());
+
+    int32_t i = 0;
+    for (auto &s : decoder_outputs) {
+      i += 1;
+      if (i == 1) {
+        continue;
+      }
+      output_decoder_states.push_back(std::move(s));
+    }
+
+    return {std::move(logits), std::move(output_decoder_states)};
+  }
+
+  std::vector<Ort::Value> GetInitialDecoderStates() {
+    std::array<int64_t, 3> shape{1, 0, 1024};
+
+    std::vector<Ort::Value> ans;
+    ans.reserve(6);
+    for (int32_t i = 0; i < 6; ++i) {
+      Ort::Value state = Ort::Value::CreateTensor<float>(
+          Allocator(), shape.data(), shape.size());
+
+      ans.push_back(std::move(state));
+    }
+
+    return ans;
+  }
+
+  OrtAllocator *Allocator() { return allocator_; }
+
+  const OfflineCanaryModelMetaData &GetModelMetadata() const { return meta_; }
+
+  OfflineCanaryModelMetaData &GetModelMetadata() { return meta_; }
+
+ private:
+  void InitEncoder(void *model_data, size_t model_data_length) {
+    encoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+                  &encoder_input_names_ptr_);
+
+    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+                   &encoder_output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      os << "---encoder---\n";
+      PrintModelMetadata(os, meta_data);
+#if __OHOS__
+      SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
+#else
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+#endif
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+
+    std::string model_type;
+    SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
+
+    if (model_type != "EncDecMultiTaskModel") {
+      SHERPA_ONNX_LOGE(
+          "Expected model type 'EncDecMultiTaskModel'. Given: '%s'",
+          model_type.c_str());
+      SHERPA_ONNX_EXIT(-1);
+    }
+
+    SHERPA_ONNX_READ_META_DATA(meta_.vocab_size, "vocab_size");
+    SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(meta_.normalize_type,
+                                               "normalize_type");
+    SHERPA_ONNX_READ_META_DATA(meta_.subsampling_factor, "subsampling_factor");
+    SHERPA_ONNX_READ_META_DATA(meta_.feat_dim, "feat_dim");
+  }
+
+  void InitDecoder(void *model_data, size_t model_data_length) {
+    decoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
+                  &decoder_input_names_ptr_);
+
+    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
+                   &decoder_output_names_ptr_);
+  }
+
+ private:
+  OfflineCanaryModelMetaData meta_;
+  OfflineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> encoder_sess_;
+  std::unique_ptr<Ort::Session> decoder_sess_;
+
+  std::vector<std::string> encoder_input_names_;
+  std::vector<const char *> encoder_input_names_ptr_;
+
+  std::vector<std::string> encoder_output_names_;
+  std::vector<const char *> encoder_output_names_ptr_;
+
+  std::vector<std::string> decoder_input_names_;
+  std::vector<const char *> decoder_input_names_ptr_;
+
+  std::vector<std::string> decoder_output_names_;
+  std::vector<const char *> decoder_output_names_ptr_;
+};
+
+OfflineCanaryModel::OfflineCanaryModel(const OfflineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+template <typename Manager>
+OfflineCanaryModel::OfflineCanaryModel(Manager *mgr,
+                                       const OfflineModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+
+OfflineCanaryModel::~OfflineCanaryModel() = default;
+
+std::vector<Ort::Value> OfflineCanaryModel::ForwardEncoder(
+    Ort::Value features, Ort::Value features_length) const {
+  return impl_->ForwardEncoder(std::move(features), std::move(features_length));
+}
+
+std::pair<Ort::Value, std::vector<Ort::Value>>
+OfflineCanaryModel::ForwardDecoder(Ort::Value tokens,
+                                   std::vector<Ort::Value> decoder_states,
+                                   Ort::Value encoder_states,
+                                   Ort::Value enc_mask) const {
+  return impl_->ForwardDecoder(std::move(tokens), std::move(decoder_states),
+                               std::move(encoder_states), std::move(enc_mask));
+}
+
+std::vector<Ort::Value> OfflineCanaryModel::GetInitialDecoderStates() const {
+  return impl_->GetInitialDecoderStates();
+}
+
+OrtAllocator *OfflineCanaryModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+const OfflineCanaryModelMetaData &OfflineCanaryModel::GetModelMetadata() const {
+  return impl_->GetModelMetadata();
+}
+OfflineCanaryModelMetaData &OfflineCanaryModel::GetModelMetadata() {
+  return impl_->GetModelMetadata();
+}
+
+#if __ANDROID_API__ >= 9
+template OfflineCanaryModel::OfflineCanaryModel(
+    AAssetManager *mgr, const OfflineModelConfig &config);
+#endif
+
+#if __OHOS__
+template OfflineCanaryModel::OfflineCanaryModel(
+    NativeResourceManager *mgr, const OfflineModelConfig &config);
+#endif
+
+}  // namespace sherpa_onnx
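A note on GetInitialDecoderStates() above: it returns six zero-length float tensors of shape (1, 0, 1024). Presumably these are the decoder's per-layer caches, empty before the first step and replaced on each call by the grown states that ForwardDecoder() returns; the exact meaning of the six tensors is defined by the exported decoder graph.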
+// sherpa-onnx/csrc/offline-canary-model.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-canary-model-meta-data.h"
+#include "sherpa-onnx/csrc/offline-model-config.h"
+
+namespace sherpa_onnx {
+
+// see
+// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py
+class OfflineCanaryModel {
+ public:
+  explicit OfflineCanaryModel(const OfflineModelConfig &config);
+
+  template <typename Manager>
+  OfflineCanaryModel(Manager *mgr, const OfflineModelConfig &config);
+
+  ~OfflineCanaryModel();
+
+  /** Run the encoder.
+   *
+   * @param features A tensor of shape (N, T, C) of dtype float32.
+   * @param features_length A 1-D tensor of shape (N,) containing number of
+   *                        valid frames in `features` before padding.
+   *                        Its dtype is int64_t.
+   *
+   * @return Return a vector containing:
+   *  - encoder_states: A 3-D tensor of shape (N, T', encoder_dim)
+   *  - encoder_len: A 1-D tensor of shape (N,) containing number
+   *                 of frames in `encoder_out` before padding.
+   *                 Its dtype is int64_t
+   *  - enc_mask: A 2-D tensor of shape (N, T') with dtype bool
+   */
+  std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
+                                         Ort::Value features_length) const;
+
+  /** Run the decoder model.
+   *
+   * @param tokens A int32 tensor of shape (N, num_tokens)
+   * @param decoder_states std::vector<Ort::Value>
+   * @param encoder_states Output from ForwardEncoder()
+   * @param enc_mask Output from ForwardEncoder()
+   *
+   * @return Return a pair:
+   *
+   *  - logits A 3-D tensor of shape (N, num_words, vocab_size)
+   *  - new_decoder_states: Can be used as input for ForwardDecoder()
+   */
+  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardDecoder(
+      Ort::Value tokens, std::vector<Ort::Value> decoder_states,
+      Ort::Value encoder_states, Ort::Value enc_mask) const;
+
+  // The return value can be used as input for ForwardDecoder()
+  std::vector<Ort::Value> GetInitialDecoderStates() const;
+
+  /** Return an allocator for allocating memory
+   */
+  OrtAllocator *Allocator() const;
+
+  const OfflineCanaryModelMetaData &GetModelMetadata() const;
+
+  OfflineCanaryModelMetaData &GetModelMetadata();
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
@@ -22,6 +22,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
   sense_voice.Register(po);
   moonshine.Register(po);
   dolphin.Register(po);
+  canary.Register(po);
 
   po->Register("telespeech-ctc", &telespeech_ctc,
                "Path to model.onnx for telespeech ctc");
@@ -114,6 +115,10 @@ bool OfflineModelConfig::Validate() const {
     return dolphin.Validate();
   }
 
+  if (!canary.encoder.empty()) {
+    return canary.Validate();
+  }
+
   if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
     SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
                      telespeech_ctc.c_str());
@@ -142,6 +147,7 @@ std::string OfflineModelConfig::ToString() const {
   os << "sense_voice=" << sense_voice.ToString() << ", ";
   os << "moonshine=" << moonshine.ToString() << ", ";
   os << "dolphin=" << dolphin.ToString() << ", ";
+  os << "canary=" << canary.ToString() << ", ";
   os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
@@ -6,6 +6,7 @@
 
 #include <string>
 
+#include "sherpa-onnx/csrc/offline-canary-model-config.h"
 #include "sherpa-onnx/csrc/offline-dolphin-model-config.h"
 #include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
 #include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
@@ -32,6 +33,7 @@ struct OfflineModelConfig {
   OfflineSenseVoiceModelConfig sense_voice;
   OfflineMoonshineModelConfig moonshine;
   OfflineDolphinModelConfig dolphin;
+  OfflineCanaryModelConfig canary;
   std::string telespeech_ctc;
 
   std::string tokens;
@@ -65,6 +67,7 @@ struct OfflineModelConfig {
                      const OfflineSenseVoiceModelConfig &sense_voice,
                      const OfflineMoonshineModelConfig &moonshine,
                      const OfflineDolphinModelConfig &dolphin,
+                     const OfflineCanaryModelConfig &canary,
                      const std::string &telespeech_ctc,
                      const std::string &tokens, int32_t num_threads, bool debug,
                      const std::string &provider, const std::string &model_type,
@@ -81,6 +84,7 @@ struct OfflineModelConfig {
         sense_voice(sense_voice),
         moonshine(moonshine),
         dolphin(dolphin),
+        canary(canary),
         telespeech_ctc(telespeech_ctc),
         tokens(tokens),
         num_threads(num_threads),
+// sherpa-onnx/csrc/offline-recognizer-canary-impl.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
+
+#include <algorithm>
+#include <ios>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-canary-model.h"
+#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
+#include "sherpa-onnx/csrc/utils.h"
+
+namespace sherpa_onnx {
+
+class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl {
+ public:
+  explicit OfflineRecognizerCanaryImpl(const OfflineRecognizerConfig &config)
+      : OfflineRecognizerImpl(config),
+        config_(config),
+        symbol_table_(config_.model_config.tokens),
+        model_(std::make_unique<OfflineCanaryModel>(config_.model_config)) {
+    PostInit();
+  }
+
+  template <typename Manager>
+  explicit OfflineRecognizerCanaryImpl(Manager *mgr,
+                                       const OfflineRecognizerConfig &config)
+      : OfflineRecognizerImpl(mgr, config),
+        config_(config),
+        symbol_table_(mgr, config_.model_config.tokens),
+        model_(
+            std::make_unique<OfflineCanaryModel>(mgr, config_.model_config)) {
+    PostInit();
+  }
+
+  std::unique_ptr<OfflineStream> CreateStream() const override {
+    return std::make_unique<OfflineStream>(config_.feat_config);
+  }
+
+  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
+    for (int32_t i = 0; i < n; ++i) {
+      DecodeStream(ss[i]);
+    }
+  }
+
+  void DecodeStream(OfflineStream *s) const {
+    auto meta = model_->GetModelMetadata();
+    auto enc_out = RunEncoder(s);
+    Ort::Value enc_states = std::move(enc_out[0]);
+    Ort::Value enc_mask = std::move(enc_out[2]);
+    // enc_out[1] is discarded
+    std::vector<int32_t> decoder_input = GetInitialDecoderInput();
+    auto decoder_states = model_->GetInitialDecoderStates();
+    Ort::Value logits{nullptr};
+
+    for (int32_t i = 0; i < decoder_input.size(); ++i) {
+      std::tie(logits, decoder_states) =
+          RunDecoder(decoder_input[i], i, std::move(decoder_states),
+                     View(&enc_states), View(&enc_mask));
+    }
+
+    int32_t max_token_id = GetMaxTokenId(&logits);
+    int32_t eos = symbol_table_["<|endoftext|>"];
+
+    int32_t num_feature_frames =
+        enc_states.GetTensorTypeAndShapeInfo().GetShape()[1] *
+        meta.subsampling_factor;
+
+    std::vector<int32_t> tokens = {max_token_id};
+
+    // Assume 30 tokens per second. It is to avoid the following for loop
+    // running indefinitely.
+    int32_t num_tokens =
+        static_cast<int32_t>(num_feature_frames / 100.0 * 30) + 1;
+
+    for (int32_t i = 1; i <= num_tokens; ++i) {
+      if (tokens.back() == eos) {
+        break;
+      }
+
+      std::tie(logits, decoder_states) =
+          RunDecoder(tokens.back(), i, std::move(decoder_states),
+                     View(&enc_states), View(&enc_mask));
+      tokens.push_back(GetMaxTokenId(&logits));
+    }
+
+    // remove the last eos token
+    tokens.pop_back();
+
+    auto r = Convert(tokens);
+
+    r.text = ApplyInverseTextNormalization(std::move(r.text));
+    r.text = ApplyHomophoneReplacer(std::move(r.text));
+
+    s->SetResult(r);
+  }
+
+  OfflineRecognizerConfig GetConfig() const override { return config_; }
+
+  void SetConfig(const OfflineRecognizerConfig &config) override {
+    config_.model_config.canary.src_lang = config.model_config.canary.src_lang;
+    config_.model_config.canary.tgt_lang = config.model_config.canary.tgt_lang;
+    config_.model_config.canary.use_pnc = config.model_config.canary.use_pnc;
+
+    // we don't change the config_ in the base class
+  }
+
+ private:
+  OfflineRecognitionResult Convert(const std::vector<int32_t> &tokens) const {
+    OfflineRecognitionResult r;
+    r.tokens.reserve(tokens.size());
+
+    std::string text;
+    for (auto i : tokens) {
+      if (!symbol_table_.Contains(i)) {
+        continue;
+      }
+
+      const auto &s = symbol_table_[i];
+      text += s;
+      r.tokens.push_back(s);
+    }
+
+    r.text = std::move(text);
+
+    return r;
+  }
+
+  int32_t GetMaxTokenId(Ort::Value *logits) const {
+    // logits is of shape (1, 1, vocab_size)
+    auto meta = model_->GetModelMetadata();
+    const float *p_logits = logits->GetTensorData<float>();
+
+    int32_t max_token_id = static_cast<int32_t>(std::distance(
+        p_logits, std::max_element(p_logits, p_logits + meta.vocab_size)));
+
+    return max_token_id;
+  }
+
+  std::vector<Ort::Value> RunEncoder(OfflineStream *s) const {
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    int32_t feat_dim = config_.feat_config.feature_dim;
+    std::vector<float> f = s->GetFrames();
+
+    int32_t num_frames = f.size() / feat_dim;
+
+    std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
+
+    Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
+                                            shape.data(), shape.size());
+
+    int64_t x_length_scalar = num_frames;
+    std::array<int64_t, 1> x_length_shape = {1};
+    Ort::Value x_length =
+        Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
+                                 x_length_shape.data(), x_length_shape.size());
+    return model_->ForwardEncoder(std::move(x), std::move(x_length));
+  }
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
+      int32_t token, int32_t pos, std::vector<Ort::Value> decoder_states,
+      Ort::Value enc_states, Ort::Value enc_mask) const {
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::array<int64_t, 2> shape = {1, 2};
+    std::array<int32_t, 2> _decoder_input = {token, pos};
+
+    Ort::Value decoder_input = Ort::Value::CreateTensor(
+        memory_info, _decoder_input.data(), _decoder_input.size(), shape.data(),
+        shape.size());
+
+    return model_->ForwardDecoder(std::move(decoder_input),
+                                  std::move(decoder_states),
+                                  std::move(enc_states), std::move(enc_mask));
+  }
+
+  // see
+  // https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py#L242
+  std::vector<int32_t> GetInitialDecoderInput() const {
+    auto canary_config = config_.model_config.canary;
+    const auto &meta = model_->GetModelMetadata();
+
+    std::vector<int32_t> decoder_input(9);
+    decoder_input[0] = symbol_table_["<|startofcontext|>"];
+    decoder_input[1] = symbol_table_["<|startoftranscript|>"];
+    decoder_input[2] = symbol_table_["<|emo:undefined|>"];
+
+    if (canary_config.src_lang.empty() ||
+        !meta.lang2id.count(canary_config.src_lang)) {
+      decoder_input[3] = meta.lang2id.at("en");
+    } else {
+      decoder_input[3] = meta.lang2id.at(canary_config.src_lang);
+    }
+
+    if (canary_config.tgt_lang.empty() ||
+        !meta.lang2id.count(canary_config.tgt_lang)) {
+      decoder_input[4] = meta.lang2id.at("en");
+    } else {
+      decoder_input[4] = meta.lang2id.at(canary_config.tgt_lang);
+    }
+
+    if (canary_config.use_pnc) {
+      decoder_input[5] = symbol_table_["<|pnc|>"];
+    } else {
+      decoder_input[5] = symbol_table_["<|nopnc|>"];
+    }
+
+    decoder_input[6] = symbol_table_["<|noitn|>"];
+    decoder_input[7] = symbol_table_["<|notimestamp|>"];
+    decoder_input[8] = symbol_table_["<|nodiarize|>"];
+
+    return decoder_input;
+  }
+
+ private:
+  void PostInit() {
+    auto &meta = model_->GetModelMetadata();
+    config_.feat_config.feature_dim = meta.feat_dim;
+
+    config_.feat_config.nemo_normalize_type = meta.normalize_type;
+
+    config_.feat_config.dither = 0;
+    config_.feat_config.remove_dc_offset = false;
+    config_.feat_config.low_freq = 0;
+    config_.feat_config.window_type = "hann";
+    config_.feat_config.is_librosa = true;
+
+    meta.lang2id["en"] = symbol_table_["<|en|>"];
+    meta.lang2id["es"] = symbol_table_["<|es|>"];
+    meta.lang2id["de"] = symbol_table_["<|de|>"];
+    meta.lang2id["fr"] = symbol_table_["<|fr|>"];
+
+    if (symbol_table_.NumSymbols() != meta.vocab_size) {
+      SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
+                       symbol_table_.NumSymbols(), meta.vocab_size);
+      SHERPA_ONNX_EXIT(-1);
+    }
+  }
+
+ private:
+  OfflineRecognizerConfig config_;
+  SymbolTable symbol_table_;
+  std::unique_ptr<OfflineCanaryModel> model_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
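For readers following the decode loop above: the decoder is primed with a fixed 9-token task prompt, then run autoregressively with greedy argmax until <|endoftext|> or a length cap is hit. As an illustration, the prompt built by GetInitialDecoderInput() corresponds to these token strings (the language slots fall back to <|en|> when src_lang/tgt_lang is empty or unknown):

    prompt = [
        "<|startofcontext|>",
        "<|startoftranscript|>",
        "<|emo:undefined|>",
        "<|en|>",          # source language
        "<|de|>",          # target language
        "<|pnc|>",         # or "<|nopnc|>" when use_pnc is false
        "<|noitn|>",
        "<|notimestamp|>",
        "<|nodiarize|>",
    ]

Each RunDecoder() call feeds a (1, 2) int32 tensor holding the current token id and a position index. The length cap assumes at most 30 tokens per second of audio: for example, a 10-second utterance yields about 1000 feature frames (100 frames/s), so num_tokens = 1000 / 100.0 * 30 + 1 = 301.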
@@ -39,7 +39,7 @@ static OfflineRecognitionResult Convert(
     r.tokens.push_back(s);
   }
 
-  r.text = text;
+  r.text = std::move(text);
 
   return r;
 }
@@ -24,6 +24,7 @@
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/file-utils.h"
 #include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-recognizer-canary-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
@@ -66,6 +67,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
   }
 
+  if (!config.model_config.canary.encoder.empty()) {
+    return std::make_unique<OfflineRecognizerCanaryImpl>(config);
+  }
+
   // TODO(fangjun): Refactor it. We only need to use model type for the
   // following models:
   // 1. transducer and nemo_transducer
@@ -252,6 +257,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
   }
 
+  if (!config.model_config.canary.encoder.empty()) {
+    return std::make_unique<OfflineRecognizerCanaryImpl>(mgr, config);
+  }
+
   // TODO(fangjun): Refactor it. We only need to use model type for the
   // following models:
   // 1. transducer and nemo_transducer
@@ -183,6 +183,10 @@ Ort::Value View(Ort::Value *v) {
       return Ort::Value::CreateTensor(
           memory_info, v->GetTensorMutableData<float>(),
          type_and_shape.GetElementCount(), shape.data(), shape.size());
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
+      return Ort::Value::CreateTensor(
+          memory_info, v->GetTensorMutableData<bool>(),
+          type_and_shape.GetElementCount(), shape.data(), shape.size());
     default:
       fprintf(stderr, "Unsupported type: %d\n",
               static_cast<int32_t>(type_and_shape.GetElementType()));
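This View() extension is needed because the Canary encoder returns enc_mask as a bool tensor, and the recognizer passes View(&enc_mask) into the decoder on every autoregressive step; without a bool case, the helper would hit the unsupported-type branch on the first decode call.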
@@ -9,6 +9,7 @@ set(srcs
   features.cc
   homophone-replacer.cc
   keyword-spotter.cc
+  offline-canary-model-config.cc
   offline-ctc-fst-decoder-config.cc
   offline-dolphin-model-config.cc
   offline-fire-red-asr-model-config.cc
+// sherpa-onnx/python/csrc/offline-canary-model-config.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-canary-model-config.h"
+
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/python/csrc/offline-canary-model-config.h"
+
+namespace sherpa_onnx {
+
+void PybindOfflineCanaryModelConfig(py::module *m) {
+  using PyClass = OfflineCanaryModelConfig;
+  py::class_<PyClass>(*m, "OfflineCanaryModelConfig")
+      .def(py::init<const std::string &, const std::string &,
+                    const std::string &, const std::string &, bool>(),
+           py::arg("encoder") = "", py::arg("decoder") = "",
+           py::arg("src_lang") = "", py::arg("tgt_lang") = "",
+           py::arg("use_pnc") = true)
+      .def_readwrite("encoder", &PyClass::encoder)
+      .def_readwrite("decoder", &PyClass::decoder)
+      .def_readwrite("src_lang", &PyClass::src_lang)
+      .def_readwrite("tgt_lang", &PyClass::tgt_lang)
+      .def_readwrite("use_pnc", &PyClass::use_pnc)
+      .def("__str__", &PyClass::ToString);
+}
+
+}  // namespace sherpa_onnx
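A quick interactive check of the binding above (illustrative; assumes a built sherpa-onnx wheel):

    >>> import sherpa_onnx
    >>> cfg = sherpa_onnx.OfflineCanaryModelConfig(
    ...     encoder="encoder.int8.onnx",
    ...     decoder="decoder.int8.onnx",
    ...     src_lang="de",
    ...     tgt_lang="en",
    ... )
    >>> print(cfg)  # __str__ is bound to ToString()
    OfflineCanaryModelConfig(encoder="encoder.int8.onnx", decoder="decoder.int8.onnx", src_lang="de", tgt_lang="en", use_pnc=True)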
+// sherpa-onnx/python/csrc/offline-canary-model-config.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindOfflineCanaryModelConfig(py::module *m);
+
+}
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
@@ -8,6 +8,7 @@
 #include <vector>
 
 #include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/python/csrc/offline-canary-model-config.h"
 #include "sherpa-onnx/python/csrc/offline-dolphin-model-config.h"
 #include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h"
 #include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
@@ -34,6 +35,7 @@ void PybindOfflineModelConfig(py::module *m) {
   PybindOfflineSenseVoiceModelConfig(m);
   PybindOfflineMoonshineModelConfig(m);
   PybindOfflineDolphinModelConfig(m);
+  PybindOfflineCanaryModelConfig(m);
 
   using PyClass = OfflineModelConfig;
   py::class_<PyClass>(*m, "OfflineModelConfig")
@@ -47,7 +49,8 @@ void PybindOfflineModelConfig(py::module *m) {
                     const OfflineWenetCtcModelConfig &,
                     const OfflineSenseVoiceModelConfig &,
                     const OfflineMoonshineModelConfig &,
-                    const OfflineDolphinModelConfig &, const std::string &,
+                    const OfflineDolphinModelConfig &,
+                    const OfflineCanaryModelConfig &, const std::string &,
                     const std::string &, int32_t, bool, const std::string &,
                     const std::string &, const std::string &,
                     const std::string &>(),
@@ -62,8 +65,9 @@ void PybindOfflineModelConfig(py::module *m) {
           py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
           py::arg("moonshine") = OfflineMoonshineModelConfig(),
           py::arg("dolphin") = OfflineDolphinModelConfig(),
-          py::arg("telespeech_ctc") = "", py::arg("tokens"),
-          py::arg("num_threads"), py::arg("debug") = false,
+          py::arg("canary") = OfflineCanaryModelConfig(),
+          py::arg("telespeech_ctc") = "", py::arg("tokens") = "",
+          py::arg("num_threads") = 1, py::arg("debug") = false,
           py::arg("provider") = "cpu", py::arg("model_type") = "",
           py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
       .def_readwrite("transducer", &PyClass::transducer)
@@ -77,6 +81,7 @@ void PybindOfflineModelConfig(py::module *m) {
       .def_readwrite("sense_voice", &PyClass::sense_voice)
       .def_readwrite("moonshine", &PyClass::moonshine)
       .def_readwrite("dolphin", &PyClass::dolphin)
+      .def_readwrite("canary", &PyClass::canary)
       .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
       .def_readwrite("tokens", &PyClass::tokens)
       .def_readwrite("num_threads", &PyClass::num_threads)
@@ -19,7 +19,8 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
                     const std::string &, int32_t, const std::string &, float,
                     float, const std::string &, const std::string &,
                     const HomophoneReplacerConfig &>(),
-          py::arg("feat_config"), py::arg("model_config"),
+          py::arg("feat_config") = FeatureExtractorConfig(),
+          py::arg("model_config") = OfflineModelConfig(),
           py::arg("lm_config") = OfflineLMConfig(),
           py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
           py::arg("decoding_method") = "greedy_search",
@@ -61,6 +62,8 @@ void PybindOfflineRecognizer(py::module *m) {
           py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
      .def("decode_stream", &PyClass::DecodeStream, py::arg("s"),
           py::call_guard<py::gil_scoped_release>())
+      .def("set_config", &PyClass::SetConfig, py::arg("config"),
+           py::call_guard<py::gil_scoped_release>())
      .def(
          "decode_streams",
          [](const PyClass &self, std::vector<OfflineStream *> ss) {
@@ -8,9 +8,22 @@ from _sherpa_onnx import (
     DenoisedAudio,
     FastClustering,
     FastClusteringConfig,
+    FeatureExtractorConfig,
+    HomophoneReplacerConfig,
+    OfflineCanaryModelConfig,
+    OfflineCtcFstDecoderConfig,
+    OfflineDolphinModelConfig,
+    OfflineFireRedAsrModelConfig,
+    OfflineLMConfig,
+    OfflineModelConfig,
+    OfflineMoonshineModelConfig,
+    OfflineNemoEncDecCtcModelConfig,
+    OfflineParaformerModelConfig,
     OfflinePunctuation,
     OfflinePunctuationConfig,
     OfflinePunctuationModelConfig,
+    OfflineRecognizerConfig,
+    OfflineSenseVoiceModelConfig,
     OfflineSourceSeparation,
     OfflineSourceSeparationConfig,
     OfflineSourceSeparationModelConfig,
@@ -27,13 +40,18 @@ from _sherpa_onnx import (
     OfflineSpeechDenoiserGtcrnModelConfig,
     OfflineSpeechDenoiserModelConfig,
     OfflineStream,
+    OfflineTdnnModelConfig,
+    OfflineTransducerModelConfig,
     OfflineTts,
     OfflineTtsConfig,
     OfflineTtsKokoroModelConfig,
     OfflineTtsMatchaModelConfig,
     OfflineTtsModelConfig,
     OfflineTtsVitsModelConfig,
+    OfflineWenetCtcModelConfig,
+    OfflineWhisperModelConfig,
     OfflineZipformerAudioTaggingModelConfig,
+    OfflineZipformerCtcModelConfig,
     OnlinePunctuation,
     OnlinePunctuationConfig,
     OnlinePunctuationModelConfig,
@@ -6,6 +6,7 @@ from typing import List, Optional
 from _sherpa_onnx import (
     FeatureExtractorConfig,
     HomophoneReplacerConfig,
+    OfflineCanaryModelConfig,
     OfflineCtcFstDecoderConfig,
     OfflineDolphinModelConfig,
     OfflineFireRedAsrModelConfig,
@@ -425,7 +426,6 @@ class OfflineRecognizer(object):
             num_threads=num_threads,
             debug=debug,
             provider=provider,
-            model_type="nemo_ctc",
         )
 
         feat_config = FeatureExtractorConfig(
@@ -691,6 +691,102 @@ class OfflineRecognizer(object):
         return self
 
     @classmethod
+    def from_nemo_canary(
+        cls,
+        encoder: str,
+        decoder: str,
+        tokens: str,
+        src_lang: str = "en",
+        tgt_lang: str = "en",
+        num_threads: int = 1,
+        sample_rate: int = 16000,
+        feature_dim: int = 128,  # not used
+        decoding_method: str = "greedy_search",  # not used
+        debug: bool = False,
+        provider: str = "cpu",
+        rule_fsts: str = "",
+        rule_fars: str = "",
+        hr_dict_dir: str = "",
+        hr_rule_fsts: str = "",
+        hr_lexicon: str = "",
+    ):
+        """
+        Please refer to
+        `<https://k2-fsa.github.io/sherpa/onnx/nemo/index.html>`_
+        to download pre-trained models for different languages.
+
+        Args:
+          encoder:
+            Path to ``encoder.onnx`` or ``encoder.int8.onnx``.
+          decoder:
+            Path to ``decoder.onnx`` or ``decoder.int8.onnx``.
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+
+          src_lang:
+            The language of the input audio. Valid values are: en, es, de, fr.
+            If you leave it empty, it uses en internally.
+          tgt_lang:
+            The language of the output text. Valid values are: en, es, de, fr.
+            If you leave it empty, it uses en internally.
+          num_threads:
+            Number of threads for neural network computation.
+          sample_rate:
+            Sample rate of the training data used to train the model. Not used.
+          feature_dim:
+            Dimension of the feature used to train the model. Not used.
+          decoding_method:
+            Valid values are greedy_search. Not used.
+          debug:
+            True to show debug messages.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+          rule_fsts:
+            If not empty, it specifies fsts for inverse text normalization.
+            If there are multiple fsts, they are separated by a comma.
+          rule_fars:
+            If not empty, it specifies fst archives for inverse text
+            normalization. If there are multiple archives, they are separated
+            by a comma.
+        """
+        self = cls.__new__(cls)
+        model_config = OfflineModelConfig(
+            canary=OfflineCanaryModelConfig(
+                encoder=encoder,
+                decoder=decoder,
+                src_lang=src_lang,
+                tgt_lang=tgt_lang,
+            ),
+            tokens=tokens,
+            num_threads=num_threads,
+            debug=debug,
+            provider=provider,
+        )
+
+        feat_config = FeatureExtractorConfig(
+            sampling_rate=sample_rate,
+            feature_dim=feature_dim,
+        )
+
+        recognizer_config = OfflineRecognizerConfig(
+            feat_config=feat_config,
+            model_config=model_config,
+            decoding_method=decoding_method,
+            rule_fsts=rule_fsts,
+            rule_fars=rule_fars,
+            hr=HomophoneReplacerConfig(
+                dict_dir=hr_dict_dir,
+                lexicon=hr_lexicon,
+                rule_fsts=hr_rule_fsts,
+            ),
+        )
+        self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
+        return self
+
+    @classmethod
     def from_whisper(
         cls,
         encoder: str,
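Note that sample_rate, feature_dim, and decoding_method are accepted by from_nemo_canary() for signature symmetry with the other factory methods, but the docstring marks them "Not used": per PostInit() in the C++ implementation above, the feature dimension, normalization type, and the remaining feature-extractor settings are overridden from the encoder's ONNX metadata, and the runtime implements only greedy search.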