Fangjun Kuang
Committed by GitHub

Add C++ runtime and Python API for NeMo Canary models (#2352)

... ... @@ -8,6 +8,13 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "test nemo canary"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
tar xvf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
rm sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
python3 ./python-api-examples/offline-nemo-canary-decode-files.py
rm -rf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
log "test spleeter"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2
... ...
#!/usr/bin/env python3
"""
This file shows how to use a non-streaming Canary model from NeMo
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
The example model is converted from
https://huggingface.co/nvidia/canary-180m-flash
It supports automatic speech recognition (ASR) in 4 languages
(English, German, French, Spanish) and translation from English to
German/French/Spanish and from German/French/Spanish to English, with or
without punctuation and capitalization (PnC).
"""
from pathlib import Path
import sherpa_onnx
import soundfile as sf
def create_recognizer():
encoder = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/encoder.int8.onnx"
decoder = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/decoder.int8.onnx"
tokens = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/tokens.txt"
en_wav = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/test_wavs/en.wav"
de_wav = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/test_wavs/de.wav"
if not Path(encoder).is_file() or not Path(en_wav).is_file():
raise ValueError(
"""Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
)
return (
sherpa_onnx.OfflineRecognizer.from_nemo_canary(
encoder=encoder,
decoder=decoder,
tokens=tokens,
debug=True,
),
en_wav,
de_wav,
)
def decode(recognizer, samples, sample_rate, src_lang, tgt_lang):
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, samples)
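# Only canary.src_lang, canary.tgt_lang, and canary.use_pnc are picked up
# by set_config(); the rest of the recognizer config stays unchanged.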
recognizer.recognizer.set_config(
config=sherpa_onnx.OfflineRecognizerConfig(
model_config=sherpa_onnx.OfflineModelConfig(
canary=sherpa_onnx.OfflineCanaryModelConfig(
src_lang=src_lang,
tgt_lang=tgt_lang,
)
)
)
)
recognizer.decode_stream(stream)
return stream.result.text
def main():
recognizer, en_wav, de_wav = create_recognizer()
en_audio, en_sample_rate = sf.read(en_wav, dtype="float32", always_2d=True)
en_audio = en_audio[:, 0] # only use the first channel
de_audio, de_sample_rate = sf.read(de_wav, dtype="float32", always_2d=True)
de_audio = de_audio[:, 0] # only use the first channel
en_wav_en_result = decode(
recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="en"
)
en_wav_es_result = decode(
recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="es"
)
en_wav_de_result = decode(
recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="de"
)
en_wav_fr_result = decode(
recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="fr"
)
de_wav_en_result = decode(
recognizer, de_audio, de_sample_rate, src_lang="de", tgt_lang="en"
)
de_wav_de_result = decode(
recognizer, de_audio, de_sample_rate, src_lang="de", tgt_lang="de"
)
print("en_wav_en_result", en_wav_en_result)
print("en_wav_es_result", en_wav_es_result)
print("en_wav_de_result", en_wav_de_result)
print("en_wav_fr_result", en_wav_fr_result)
print("-" * 10)
print("de_wav_en_result", de_wav_en_result)
print("de_wav_de_result", de_wav_de_result)
if __name__ == "__main__":
main()
... ...
... ... @@ -281,9 +281,14 @@ def export_decoder(canary_model):
def export_tokens(canary_model):
underline = "▁"
with open("./tokens.txt", "w", encoding="utf-8") as f:
for i in range(canary_model.tokenizer.vocab_size):
s = canary_model.tokenizer.ids_to_text([i])
if s.startswith(" "):
s = underline + s[1:]
f.write(f"{s} {i}\n")
print("Saved to tokens.txt")
... ...
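For reference, a minimal sketch (assuming the two-column "symbol integer_id" format written by export_tokens above; load_id2token is a hypothetical helper, not part of this PR) of reading tokens.txt back into the id2token map used by the decoding snippet below:
# Hypothetical helper: rebuild id2token from tokens.txt ("symbol id" lines).
def load_id2token(path="./tokens.txt"):
    id2token = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            sym, idx = line.rstrip("\n").rsplit(" ", 1)
            id2token[int(idx)] = sym
    return id2token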
... ... @@ -289,7 +289,13 @@ def main():
tokens.append(t)
print("len(tokens)", len(tokens))
print("tokens", tokens)
text = "".join([id2token[i] for i in tokens])
underline = "▁"
# underline = b"\xe2\x96\x81".decode()
text = text.replace(underline, " ").strip()
print("text:", text)
... ...
... ... @@ -5,6 +5,7 @@
#include <algorithm>
#include <cstring>
#include <utility>
namespace sherpa_onnx::cxx {
... ...
... ... @@ -25,6 +25,8 @@ set(sources
jieba.cc
keyword-spotter-impl.cc
keyword-spotter.cc
offline-canary-model-config.cc
offline-canary-model.cc
offline-ctc-fst-decoder-config.cc
offline-ctc-fst-decoder.cc
offline-ctc-greedy-search-decoder.cc
... ... @@ -50,7 +52,6 @@ set(sources
offline-rnn-lm.cc
offline-sense-voice-model-config.cc
offline-sense-voice-model.cc
offline-source-separation-impl.cc
offline-source-separation-model-config.cc
offline-source-separation-spleeter-model-config.cc
... ... @@ -58,7 +59,6 @@ set(sources
offline-source-separation-uvr-model-config.cc
offline-source-separation-uvr-model.cc
offline-source-separation.cc
offline-stream.cc
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
... ...
// sherpa-onnx/csrc/offline-canary-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-canary-model-config.h"
#include <sstream>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineCanaryModelConfig::Register(ParseOptions *po) {
po->Register("canary-encoder", &encoder,
"Path to onnx encoder of Canary, e.g., encoder.int8.onnx");
po->Register("canary-decoder", &decoder,
"Path to onnx decoder of Canary, e.g., decoder.int8.onnx");
po->Register("canary-src-lang", &src_lang,
"Valid values: en, de, es, fr. If empty, default to use en");
po->Register("canary-tgt-lang", &tgt_lang,
"Valid values: en, de, es, fr. If empty, default to use en");
po->Register("canary-use-pnc", &use_pnc,
"true to enable punctuations and casing. false to disable them");
}
bool OfflineCanaryModelConfig::Validate() const {
if (encoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --canary-encoder");
return false;
}
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("Canary encoder file '%s' does not exist",
encoder.c_str());
return false;
}
if (decoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --canary-decoder");
return false;
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("Canary decoder file '%s' does not exist",
decoder.c_str());
return false;
}
if (!src_lang.empty()) {
if (src_lang != "en" && src_lang != "de" && src_lang != "es" &&
src_lang != "fr") {
SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-src-lang");
return false;
}
}
if (!tgt_lang.empty()) {
if (tgt_lang != "en" && tgt_lang != "de" && tgt_lang != "es" &&
tgt_lang != "fr") {
SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-tgt-lang");
return false;
}
}
return true;
}
std::string OfflineCanaryModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineCanaryModelConfig(";
os << "encoder=\"" << encoder << "\", ";
os << "decoder=\"" << decoder << "\", ";
os << "src_lang=\"" << src_lang << "\", ";
os << "tgt_lang=\"" << tgt_lang << "\", ";
os << "use_pnc=" << (use_pnc ? "True" : "False") << ")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-canary-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineCanaryModelConfig {
std::string encoder;
std::string decoder;
// en, de, es, fr, or leave it empty to use en
std::string src_lang;
// en, de, es, fr, or leave it empty to use en
std::string tgt_lang;
// true to enable punctuation and casing
// false to disable punctuation and casing
bool use_pnc = true;
OfflineCanaryModelConfig() = default;
OfflineCanaryModelConfig(const std::string &encoder,
const std::string &decoder,
const std::string &src_lang,
const std::string &tgt_lang, bool use_pnc)
: encoder(encoder),
decoder(decoder),
src_lang(src_lang),
tgt_lang(tgt_lang),
use_pnc(use_pnc) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-canary-model-meta-data.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
#include <string>
#include <unordered_map>
#include <vector>
namespace sherpa_onnx {
struct OfflineCanaryModelMetaData {
int32_t vocab_size = 0;
int32_t subsampling_factor = 8;
int32_t feat_dim = 120;
std::string normalize_type;
std::unordered_map<std::string, int32_t> lang2id;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
... ...
// sherpa-onnx/csrc/offline-canary-model.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-canary-model.h"
#include <algorithm>
#include <cmath>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include "sherpa-onnx/csrc/offline-canary-model-meta-data.h"
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OfflineCanaryModel::Impl {
public:
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.canary.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.canary.decoder);
InitDecoder(buf.data(), buf.size());
}
}
template <typename Manager>
Impl(Manager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.canary.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.canary.decoder);
InitDecoder(buf.data(), buf.size());
}
}
std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
Ort::Value features_length) {
std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
std::move(features_length)};
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
encoder_inputs.size(), encoder_output_names_ptr_.data(),
encoder_output_names_ptr_.size());
return encoder_out;
}
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardDecoder(
Ort::Value tokens, std::vector<Ort::Value> decoder_states,
Ort::Value encoder_states, Ort::Value enc_mask) {
std::vector<Ort::Value> decoder_inputs;
decoder_inputs.reserve(3 + decoder_states.size());
decoder_inputs.push_back(std::move(tokens));
for (auto &s : decoder_states) {
decoder_inputs.push_back(std::move(s));
}
decoder_inputs.push_back(std::move(encoder_states));
decoder_inputs.push_back(std::move(enc_mask));
auto decoder_outputs = decoder_sess_->Run(
{}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
decoder_inputs.size(), decoder_output_names_ptr_.data(),
decoder_output_names_ptr_.size());
Ort::Value logits = std::move(decoder_outputs[0]);
std::vector<Ort::Value> output_decoder_states;
output_decoder_states.reserve(decoder_states.size());
// decoder_outputs[0] is logits (moved above); the rest are the new states
for (size_t i = 1; i != decoder_outputs.size(); ++i) {
output_decoder_states.push_back(std::move(decoder_outputs[i]));
}
return {std::move(logits), std::move(output_decoder_states)};
}
std::vector<Ort::Value> GetInitialDecoderStates() {
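// The decoder expects 6 states; each starts as an empty float tensor of
// shape (1, 0, 1024), i.e., zero length along the time axis.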
std::array<int64_t, 3> shape{1, 0, 1024};
std::vector<Ort::Value> ans;
ans.reserve(6);
for (int32_t i = 0; i < 6; ++i) {
Ort::Value state = Ort::Value::CreateTensor<float>(
Allocator(), shape.data(), shape.size());
ans.push_back(std::move(state));
}
return ans;
}
OrtAllocator *Allocator() { return allocator_; }
const OfflineCanaryModelMetaData &GetModelMetadata() const { return meta_; }
OfflineCanaryModelMetaData &GetModelMetadata() { return meta_; }
private:
void InitEncoder(void *model_data, size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
std::string model_type;
SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
if (model_type != "EncDecMultiTaskModel") {
SHERPA_ONNX_LOGE(
"Expected model type 'EncDecMultiTaskModel'. Given: '%s'",
model_type.c_str());
SHERPA_ONNX_EXIT(-1);
}
SHERPA_ONNX_READ_META_DATA(meta_.vocab_size, "vocab_size");
SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(meta_.normalize_type,
"normalize_type");
SHERPA_ONNX_READ_META_DATA(meta_.subsampling_factor, "subsampling_factor");
SHERPA_ONNX_READ_META_DATA(meta_.feat_dim, "feat_dim");
}
void InitDecoder(void *model_data, size_t model_data_length) {
decoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
}
private:
OfflineCanaryModelMetaData meta_;
OfflineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
};
OfflineCanaryModel::OfflineCanaryModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
template <typename Manager>
OfflineCanaryModel::OfflineCanaryModel(Manager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
OfflineCanaryModel::~OfflineCanaryModel() = default;
std::vector<Ort::Value> OfflineCanaryModel::ForwardEncoder(
Ort::Value features, Ort::Value features_length) const {
return impl_->ForwardEncoder(std::move(features), std::move(features_length));
}
std::pair<Ort::Value, std::vector<Ort::Value>>
OfflineCanaryModel::ForwardDecoder(Ort::Value tokens,
std::vector<Ort::Value> decoder_states,
Ort::Value encoder_states,
Ort::Value enc_mask) const {
return impl_->ForwardDecoder(std::move(tokens), std::move(decoder_states),
std::move(encoder_states), std::move(enc_mask));
}
std::vector<Ort::Value> OfflineCanaryModel::GetInitialDecoderStates() const {
return impl_->GetInitialDecoderStates();
}
OrtAllocator *OfflineCanaryModel::Allocator() const {
return impl_->Allocator();
}
const OfflineCanaryModelMetaData &OfflineCanaryModel::GetModelMetadata() const {
return impl_->GetModelMetadata();
}
OfflineCanaryModelMetaData &OfflineCanaryModel::GetModelMetadata() {
return impl_->GetModelMetadata();
}
#if __ANDROID_API__ >= 9
template OfflineCanaryModel::OfflineCanaryModel(
AAssetManager *mgr, const OfflineModelConfig &config);
#endif
#if __OHOS__
template OfflineCanaryModel::OfflineCanaryModel(
NativeResourceManager *mgr, const OfflineModelConfig &config);
#endif
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-canary-model.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-canary-model-meta-data.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
// see
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py
class OfflineCanaryModel {
public:
explicit OfflineCanaryModel(const OfflineModelConfig &config);
template <typename Manager>
OfflineCanaryModel(Manager *mgr, const OfflineModelConfig &config);
~OfflineCanaryModel();
/** Run the encoder.
*
* @param features A tensor of shape (N, T, C) of dtype float32.
* @param features_length A 1-D tensor of shape (N,) containing the number
*                        of valid frames in `features` before padding.
*                        Its dtype is int64_t.
*
* @return Return a vector containing:
*           - encoder_states: A 3-D tensor of shape (N, T', encoder_dim)
*           - encoder_len: A 1-D tensor of shape (N,) containing the number
*                          of frames in `encoder_states` before padding.
*                          Its dtype is int64_t.
*           - enc_mask: A 2-D tensor of shape (N, T') with dtype bool
*/
std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
Ort::Value features_length) const;
/** Run the decoder model.
*
* @param tokens An int32 tensor of shape (N, num_tokens)
* @param decoder_states Decoder states from GetInitialDecoderStates()
*                       or from a previous ForwardDecoder() call
* @param encoder_states Output from ForwardEncoder()
* @param enc_mask Output from ForwardEncoder()
*
* @return Return a pair:
*
*  - logits: A 3-D tensor of shape (N, num_tokens, vocab_size)
*  - new_decoder_states: Can be used as input for ForwardDecoder()
*/
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardDecoder(
Ort::Value tokens, std::vector<Ort::Value> decoder_states,
Ort::Value encoder_states, Ort::Value enc_mask) const;
// The return value can be used as input for ForwardDecoder()
std::vector<Ort::Value> GetInitialDecoderStates() const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
const OfflineCanaryModelMetaData &GetModelMetadata() const;
OfflineCanaryModelMetaData &GetModelMetadata();
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
... ...
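To make the encoder contract documented above concrete, here is a minimal sketch that runs the exported encoder directly with onnxruntime; the file path and feature dimension are assumptions, and the real pipeline uses the C++ runtime in this PR:
# Sketch only: exercises the ForwardEncoder() I/O contract documented above.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("./encoder.int8.onnx")  # assumed local path
input_names = [i.name for i in sess.get_inputs()]

features = np.zeros((1, 200, 128), dtype=np.float32)  # (N, T, C); C = feat_dim from model metadata
features_length = np.array([200], dtype=np.int64)  # (N,), valid frames before padding

enc_states, enc_len, enc_mask = sess.run(
    None, dict(zip(input_names, (features, features_length)))
)
print(enc_states.shape)  # (N, T', encoder_dim)
print(enc_len.dtype, enc_mask.dtype)  # int64 bool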
... ... @@ -22,6 +22,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
sense_voice.Register(po);
moonshine.Register(po);
dolphin.Register(po);
canary.Register(po);
po->Register("telespeech-ctc", &telespeech_ctc,
"Path to model.onnx for telespeech ctc");
... ... @@ -114,6 +115,10 @@ bool OfflineModelConfig::Validate() const {
return dolphin.Validate();
}
if (!canary.encoder.empty()) {
return canary.Validate();
}
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
telespeech_ctc.c_str());
... ... @@ -142,6 +147,7 @@ std::string OfflineModelConfig::ToString() const {
os << "sense_voice=" << sense_voice.ToString() << ", ";
os << "moonshine=" << moonshine.ToString() << ", ";
os << "dolphin=" << dolphin.ToString() << ", ";
os << "canary=" << canary.ToString() << ", ";
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
... ...
... ... @@ -6,6 +6,7 @@
#include <string>
#include "sherpa-onnx/csrc/offline-canary-model-config.h"
#include "sherpa-onnx/csrc/offline-dolphin-model-config.h"
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
... ... @@ -32,6 +33,7 @@ struct OfflineModelConfig {
OfflineSenseVoiceModelConfig sense_voice;
OfflineMoonshineModelConfig moonshine;
OfflineDolphinModelConfig dolphin;
OfflineCanaryModelConfig canary;
std::string telespeech_ctc;
std::string tokens;
... ... @@ -65,6 +67,7 @@ struct OfflineModelConfig {
const OfflineSenseVoiceModelConfig &sense_voice,
const OfflineMoonshineModelConfig &moonshine,
const OfflineDolphinModelConfig &dolphin,
const OfflineCanaryModelConfig &canary,
const std::string &telespeech_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type,
... ... @@ -81,6 +84,7 @@ struct OfflineModelConfig {
sense_voice(sense_voice),
moonshine(moonshine),
dolphin(dolphin),
canary(canary),
telespeech_ctc(telespeech_ctc),
tokens(tokens),
num_threads(num_threads),
... ...
// sherpa-onnx/csrc/offline-recognizer-canary-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
#include <algorithm>
#include <ios>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-canary-model.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
namespace sherpa_onnx {
class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerCanaryImpl(const OfflineRecognizerConfig &config)
: OfflineRecognizerImpl(config),
config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineCanaryModel>(config_.model_config)) {
PostInit();
}
template <typename Manager>
explicit OfflineRecognizerCanaryImpl(Manager *mgr,
const OfflineRecognizerConfig &config)
: OfflineRecognizerImpl(mgr, config),
config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(
std::make_unique<OfflineCanaryModel>(mgr, config_.model_config)) {
PostInit();
}
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(config_.feat_config);
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
for (int32_t i = 0; i < n; ++i) {
DecodeStream(ss[i]);
}
}
void DecodeStream(OfflineStream *s) const {
auto meta = model_->GetModelMetadata();
auto enc_out = RunEncoder(s);
Ort::Value enc_states = std::move(enc_out[0]);
Ort::Value enc_mask = std::move(enc_out[2]);
// enc_out[1] is discarded
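// Build the 9-token task prompt: context/transcript markers, emotion,
// source and target languages, and PnC/ITN/timestamp/diarization flags.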
std::vector<int32_t> decoder_input = GetInitialDecoderInput();
auto decoder_states = model_->GetInitialDecoderStates();
Ort::Value logits{nullptr};
for (int32_t i = 0; i < static_cast<int32_t>(decoder_input.size()); ++i) {
std::tie(logits, decoder_states) =
RunDecoder(decoder_input[i], i, std::move(decoder_states),
View(&enc_states), View(&enc_mask));
}
int32_t max_token_id = GetMaxTokenId(&logits);
int32_t eos = symbol_table_["<|endoftext|>"];
int32_t num_feature_frames =
enc_states.GetTensorTypeAndShapeInfo().GetShape()[1] *
meta.subsampling_factor;
std::vector<int32_t> tokens = {max_token_id};
// Assume at most 30 tokens per second of audio. This caps the loop
// below so it cannot run indefinitely if EOS is never produced.
int32_t num_tokens =
static_cast<int32_t>(num_feature_frames / 100.0 * 30) + 1;
for (int32_t i = 1; i <= num_tokens; ++i) {
if (tokens.back() == eos) {
break;
}
std::tie(logits, decoder_states) =
RunDecoder(tokens.back(), i, std::move(decoder_states),
View(&enc_states), View(&enc_mask));
tokens.push_back(GetMaxTokenId(&logits));
}
// Remove the trailing EOS token if the loop stopped on it
if (!tokens.empty() && tokens.back() == eos) {
tokens.pop_back();
}
auto r = Convert(tokens);
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
s->SetResult(r);
}
OfflineRecognizerConfig GetConfig() const override { return config_; }
void SetConfig(const OfflineRecognizerConfig &config) override {
config_.model_config.canary.src_lang = config.model_config.canary.src_lang;
config_.model_config.canary.tgt_lang = config.model_config.canary.tgt_lang;
config_.model_config.canary.use_pnc = config.model_config.canary.use_pnc;
// we don't change the config_ in the base class
}
private:
OfflineRecognitionResult Convert(const std::vector<int32_t> &tokens) const {
OfflineRecognitionResult r;
r.tokens.reserve(tokens.size());
std::string text;
for (auto i : tokens) {
if (!symbol_table_.Contains(i)) {
continue;
}
const auto &s = symbol_table_[i];
text += s;
r.tokens.push_back(s);
}
r.text = std::move(text);
return r;
}
int32_t GetMaxTokenId(Ort::Value *logits) const {
// logits is of shape (1, 1, vocab_size)
auto meta = model_->GetModelMetadata();
const float *p_logits = logits->GetTensorData<float>();
int32_t max_token_id = static_cast<int32_t>(std::distance(
p_logits, std::max_element(p_logits, p_logits + meta.vocab_size)));
return max_token_id;
}
std::vector<Ort::Value> RunEncoder(OfflineStream *s) const {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t feat_dim = config_.feat_config.feature_dim;
std::vector<float> f = s->GetFrames();
int32_t num_frames = f.size() / feat_dim;
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
shape.data(), shape.size());
int64_t x_length_scalar = num_frames;
std::array<int64_t, 1> x_length_shape = {1};
Ort::Value x_length =
Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
x_length_shape.data(), x_length_shape.size());
return model_->ForwardEncoder(std::move(x), std::move(x_length));
}
std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
int32_t token, int32_t pos, std::vector<Ort::Value> decoder_states,
Ort::Value enc_states, Ort::Value enc_mask) const {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
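// The decoder consumes (token id, position) pairs, hence shape (1, 2)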
std::array<int64_t, 2> shape = {1, 2};
std::array<int32_t, 2> _decoder_input = {token, pos};
Ort::Value decoder_input = Ort::Value::CreateTensor(
memory_info, _decoder_input.data(), _decoder_input.size(), shape.data(),
shape.size());
return model_->ForwardDecoder(std::move(decoder_input),
std::move(decoder_states),
std::move(enc_states), std::move(enc_mask));
}
// see
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py#L242
std::vector<int32_t> GetInitialDecoderInput() const {
const auto &canary_config = config_.model_config.canary;
const auto &meta = model_->GetModelMetadata();
std::vector<int32_t> decoder_input(9);
decoder_input[0] = symbol_table_["<|startofcontext|>"];
decoder_input[1] = symbol_table_["<|startoftranscript|>"];
decoder_input[2] = symbol_table_["<|emo:undefined|>"];
if (canary_config.src_lang.empty() ||
!meta.lang2id.count(canary_config.src_lang)) {
decoder_input[3] = meta.lang2id.at("en");
} else {
decoder_input[3] = meta.lang2id.at(canary_config.src_lang);
}
if (canary_config.tgt_lang.empty() ||
!meta.lang2id.count(canary_config.tgt_lang)) {
decoder_input[4] = meta.lang2id.at("en");
} else {
decoder_input[4] = meta.lang2id.at(canary_config.tgt_lang);
}
if (canary_config.use_pnc) {
decoder_input[5] = symbol_table_["<|pnc|>"];
} else {
decoder_input[5] = symbol_table_["<|nopnc|>"];
}
decoder_input[6] = symbol_table_["<|noitn|>"];
decoder_input[7] = symbol_table_["<|notimestamp|>"];
decoder_input[8] = symbol_table_["<|nodiarize|>"];
return decoder_input;
}
private:
void PostInit() {
auto &meta = model_->GetModelMetadata();
config_.feat_config.feature_dim = meta.feat_dim;
config_.feat_config.nemo_normalize_type = meta.normalize_type;
config_.feat_config.dither = 0;
config_.feat_config.remove_dc_offset = false;
config_.feat_config.low_freq = 0;
config_.feat_config.window_type = "hann";
config_.feat_config.is_librosa = true;
meta.lang2id["en"] = symbol_table_["<|en|>"];
meta.lang2id["es"] = symbol_table_["<|es|>"];
meta.lang2id["de"] = symbol_table_["<|de|>"];
meta.lang2id["fr"] = symbol_table_["<|fr|>"];
if (symbol_table_.NumSymbols() != meta.vocab_size) {
SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
symbol_table_.NumSymbols(), meta.vocab_size);
SHERPA_ONNX_EXIT(-1);
}
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OfflineCanaryModel> model_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
... ...
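The 9-token task prompt built by GetInitialDecoderInput() above, sketched in Python; the token strings come from the C++ code, while symbol_table and lang2id are assumed dicts built from tokens.txt:
# Sketch of the task prompt assembled by GetInitialDecoderInput().
def initial_decoder_input(symbol_table, lang2id, src_lang="en", tgt_lang="en", use_pnc=True):
    return [
        symbol_table["<|startofcontext|>"],
        symbol_table["<|startoftranscript|>"],
        symbol_table["<|emo:undefined|>"],
        lang2id.get(src_lang, lang2id["en"]),  # unknown/empty falls back to en
        lang2id.get(tgt_lang, lang2id["en"]),
        symbol_table["<|pnc|>"] if use_pnc else symbol_table["<|nopnc|>"],
        symbol_table["<|noitn|>"],
        symbol_table["<|notimestamp|>"],
        symbol_table["<|nodiarize|>"],
    ]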
... ... @@ -39,7 +39,7 @@ static OfflineRecognitionResult Convert(
r.tokens.push_back(s);
}
r.text = text;
r.text = std::move(text);
return r;
}
... ...
... ... @@ -24,6 +24,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-canary-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
... ... @@ -66,6 +67,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
}
if (!config.model_config.canary.encoder.empty()) {
return std::make_unique<OfflineRecognizerCanaryImpl>(config);
}
// TODO(fangjun): Refactor it. We only need to use model type for the
// following models:
// 1. transducer and nemo_transducer
... ... @@ -252,6 +257,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
}
if (!config.model_config.canary.encoder.empty()) {
return std::make_unique<OfflineRecognizerCanaryImpl>(mgr, config);
}
// TODO(fangjun): Refactor it. We only need to use model type for the
// following models:
// 1. transducer and nemo_transducer
... ...
... ... @@ -183,6 +183,10 @@ Ort::Value View(Ort::Value *v) {
return Ort::Value::CreateTensor(
memory_info, v->GetTensorMutableData<float>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
return Ort::Value::CreateTensor(
memory_info, v->GetTensorMutableData<bool>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
default:
fprintf(stderr, "Unsupported type: %d\n",
static_cast<int32_t>(type_and_shape.GetElementType()));
... ...
... ... @@ -9,6 +9,7 @@ set(srcs
features.cc
homophone-replacer.cc
keyword-spotter.cc
offline-canary-model-config.cc
offline-ctc-fst-decoder-config.cc
offline-dolphin-model-config.cc
offline-fire-red-asr-model-config.cc
... ...
// sherpa-onnx/python/csrc/offline-canary-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-canary-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/python/csrc/offline-canary-model-config.h"
namespace sherpa_onnx {
void PybindOfflineCanaryModelConfig(py::module *m) {
using PyClass = OfflineCanaryModelConfig;
py::class_<PyClass>(*m, "OfflineCanaryModelConfig")
.def(py::init<const std::string &, const std::string &,
const std::string &, const std::string &, bool>(),
py::arg("encoder") = "", py::arg("decoder") = "",
py::arg("src_lang") = "", py::arg("tgt_lang") = "",
py::arg("use_pnc") = true)
.def_readwrite("encoder", &PyClass::encoder)
.def_readwrite("decoder", &PyClass::decoder)
.def_readwrite("src_lang", &PyClass::src_lang)
.def_readwrite("tgt_lang", &PyClass::tgt_lang)
.def_readwrite("use_pnc", &PyClass::use_pnc)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
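With these bindings, the config is constructible directly from Python; a quick sanity check (assuming the package has been built and installed):
import sherpa_onnx

cfg = sherpa_onnx.OfflineCanaryModelConfig(
    encoder="encoder.int8.onnx",
    decoder="decoder.int8.onnx",
    src_lang="en",
    tgt_lang="de",
    use_pnc=True,
)
print(cfg)  # __str__ is bound to OfflineCanaryModelConfig::ToString()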
// sherpa-onnx/python/csrc/offline-canary-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineCanaryModelConfig(py::module *m);
}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
... ...
... ... @@ -8,6 +8,7 @@
#include <vector>
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-canary-model-config.h"
#include "sherpa-onnx/python/csrc/offline-dolphin-model-config.h"
#include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h"
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
... ... @@ -34,6 +35,7 @@ void PybindOfflineModelConfig(py::module *m) {
PybindOfflineSenseVoiceModelConfig(m);
PybindOfflineMoonshineModelConfig(m);
PybindOfflineDolphinModelConfig(m);
PybindOfflineCanaryModelConfig(m);
using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
... ... @@ -47,7 +49,8 @@ void PybindOfflineModelConfig(py::module *m) {
const OfflineWenetCtcModelConfig &,
const OfflineSenseVoiceModelConfig &,
const OfflineMoonshineModelConfig &,
const OfflineDolphinModelConfig &, const std::string &,
const OfflineDolphinModelConfig &,
const OfflineCanaryModelConfig &, const std::string &,
const std::string &, int32_t, bool, const std::string &,
const std::string &, const std::string &,
const std::string &>(),
... ... @@ -62,8 +65,9 @@ void PybindOfflineModelConfig(py::module *m) {
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
py::arg("moonshine") = OfflineMoonshineModelConfig(),
py::arg("dolphin") = OfflineDolphinModelConfig(),
py::arg("telespeech_ctc") = "", py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false,
py::arg("canary") = OfflineCanaryModelConfig(),
py::arg("telespeech_ctc") = "", py::arg("tokens") = "",
py::arg("num_threads") = 1, py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "",
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
.def_readwrite("transducer", &PyClass::transducer)
... ... @@ -77,6 +81,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("sense_voice", &PyClass::sense_voice)
.def_readwrite("moonshine", &PyClass::moonshine)
.def_readwrite("dolphin", &PyClass::dolphin)
.def_readwrite("canary", &PyClass::canary)
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
... ...
... ... @@ -19,7 +19,8 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
const std::string &, int32_t, const std::string &, float,
float, const std::string &, const std::string &,
const HomophoneReplacerConfig &>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("feat_config") = FeatureExtractorConfig(),
py::arg("model_config") = OfflineModelConfig(),
py::arg("lm_config") = OfflineLMConfig(),
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
py::arg("decoding_method") = "greedy_search",
... ... @@ -61,6 +62,8 @@ void PybindOfflineRecognizer(py::module *m) {
py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
.def("decode_stream", &PyClass::DecodeStream, py::arg("s"),
py::call_guard<py::gil_scoped_release>())
.def("set_config", &PyClass::SetConfig, py::arg("config"),
py::call_guard<py::gil_scoped_release>())
.def(
"decode_streams",
[](const PyClass &self, std::vector<OfflineStream *> ss) {
... ...
... ... @@ -8,9 +8,22 @@ from _sherpa_onnx import (
DenoisedAudio,
FastClustering,
FastClusteringConfig,
FeatureExtractorConfig,
HomophoneReplacerConfig,
OfflineCanaryModelConfig,
OfflineCtcFstDecoderConfig,
OfflineDolphinModelConfig,
OfflineFireRedAsrModelConfig,
OfflineLMConfig,
OfflineModelConfig,
OfflineMoonshineModelConfig,
OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig,
OfflinePunctuation,
OfflinePunctuationConfig,
OfflinePunctuationModelConfig,
OfflineRecognizerConfig,
OfflineSenseVoiceModelConfig,
OfflineSourceSeparation,
OfflineSourceSeparationConfig,
OfflineSourceSeparationModelConfig,
... ... @@ -27,13 +40,18 @@ from _sherpa_onnx import (
OfflineSpeechDenoiserGtcrnModelConfig,
OfflineSpeechDenoiserModelConfig,
OfflineStream,
OfflineTdnnModelConfig,
OfflineTransducerModelConfig,
OfflineTts,
OfflineTtsConfig,
OfflineTtsKokoroModelConfig,
OfflineTtsMatchaModelConfig,
OfflineTtsModelConfig,
OfflineTtsVitsModelConfig,
OfflineWenetCtcModelConfig,
OfflineWhisperModelConfig,
OfflineZipformerAudioTaggingModelConfig,
OfflineZipformerCtcModelConfig,
OnlinePunctuation,
OnlinePunctuationConfig,
OnlinePunctuationModelConfig,
... ...
... ... @@ -6,6 +6,7 @@ from typing import List, Optional
from _sherpa_onnx import (
FeatureExtractorConfig,
HomophoneReplacerConfig,
OfflineCanaryModelConfig,
OfflineCtcFstDecoderConfig,
OfflineDolphinModelConfig,
OfflineFireRedAsrModelConfig,
... ... @@ -425,7 +426,6 @@ class OfflineRecognizer(object):
num_threads=num_threads,
debug=debug,
provider=provider,
model_type="nemo_ctc",
)
feat_config = FeatureExtractorConfig(
... ... @@ -691,6 +691,102 @@ class OfflineRecognizer(object):
return self
@classmethod
def from_nemo_canary(
cls,
encoder: str,
decoder: str,
tokens: str,
src_lang: str = "en",
tgt_lang: str = "en",
num_threads: int = 1,
sample_rate: int = 16000,
feature_dim: int = 128, # not used
decoding_method: str = "greedy_search", # not used
debug: bool = False,
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/nemo/index.html>`_
to download pre-trained models for different languages.
Args:
encoder:
Path to ``encoder.onnx`` or ``encoder.int8.onnx``.
decoder:
Path to ``decoder.onnx`` or ``decoder.int8.onnx``.
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
src_lang:
The language of the input audio. Valid values are: en, es, de, fr.
If you leave it empty, it uses en internally.
tgt_lang:
The language of the output text. Valid values are: en, es, de, fr.
If you leave it empty, it uses en internally.
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model. Not used.
feature_dim:
Dimension of the features used to train the model. Not used.
decoding_method:
Valid values: greedy_search. Not used.
debug:
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
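hr_dict_dir:
If not empty, it specifies the dictionary directory for the
homophone replacer.
hr_rule_fsts:
If not empty, it specifies rule fsts for the homophone replacer.
hr_lexicon:
If not empty, it specifies the lexicon for the homophone replacer.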
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
canary=OfflineCanaryModelConfig(
encoder=encoder,
decoder=decoder,
src_lang=src_lang,
tgt_lang=tgt_lang,
),
tokens=tokens,
num_threads=num_threads,
debug=debug,
provider=provider,
)
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
recognizer_config = OfflineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
@classmethod
def from_whisper(
cls,
encoder: str,
... ...