Fangjun Kuang
Committed by GitHub

Add C++ and Python API for FireRedASR AED models (#1867)

... ... @@ -133,3 +133,4 @@ lexicon.txt
us_gold.json
us_silver.json
kokoro-multi-lang-v1_0
sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16
... ...
#!/usr/bin/env python3
"""
This file shows how to use a non-streaming FireRedAsr AED model from
https://github.com/FireRedTeam/FireRedASR
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
For instance,
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
tar xvf sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
rm sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
"""
from pathlib import Path

import sherpa_onnx
import soundfile as sf


def create_recognizer():
    encoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx"
    decoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx"
    tokens = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt"

    test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/1.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/2.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/8k.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3-sichuan.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/4-tianjin.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/5-henan.wav"

    if (
        not Path(encoder).is_file()
        or not Path(decoder).is_file()
        or not Path(tokens).is_file()
        or not Path(test_wav).is_file()
    ):
        raise ValueError(
            """Please download model files from
            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
            """
        )

    return (
        sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
            encoder=encoder,
            decoder=decoder,
            tokens=tokens,
            debug=True,
        ),
        test_wav,
    )


def main():
    recognizer, wave_filename = create_recognizer()

    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    # audio is a 1-D float32 numpy array normalized to the range [-1, 1].
    # sample_rate does not need to be 16000 Hz.

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, audio)
    recognizer.decode_stream(stream)

    print(wave_filename)
    print(stream.result)


if __name__ == "__main__":
    main()
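If only the transcript is needed, the result object also exposes its fields individually. A minimal variant of main() (a sketch; .text is the transcript field of the result object returned by the sherpa-onnx Python API):

def print_text_only():
    # Same flow as main(), but print only the recognized text.
    recognizer, wave_filename = create_recognizer()
    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, audio[:, 0])
    recognizer.decode_stream(stream)

    print(stream.result.text)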
... ...
... ... @@ -27,6 +27,9 @@ set(sources
offline-ctc-fst-decoder.cc
offline-ctc-greedy-search-decoder.cc
offline-ctc-model.cc
offline-fire-red-asr-greedy-search-decoder.cc
offline-fire-red-asr-model-config.cc
offline-fire-red-asr-model.cc
offline-lm-config.cc
offline-lm.cc
offline-model-config.cc
... ...
// sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
#include <cstdint>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
struct OfflineFireRedAsrDecoderResult {
/// The decoded token IDs
std::vector<int32_t> tokens;
};
class OfflineFireRedAsrDecoder {
public:
virtual ~OfflineFireRedAsrDecoder() = default;
/** Run decoding given the output from the FireRedAsr encoder model.
*
* @param n_layer_cross_k A 4-D tensor of shape
* (num_decoder_layers, N, T, d_model).
* @param n_layer_cross_v A 4-D tensor of shape
* (num_decoder_layers, N, T, d_model).
*
* @return Return a vector of size `N` containing the decoded results.
*/
virtual std::vector<OfflineFireRedAsrDecoderResult> Decode(
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
... ...
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h"
#include <algorithm>
#include <tuple>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
// Note: this function works only for batch size == 1 at present
std::vector<OfflineFireRedAsrDecoderResult>
OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k,
Ort::Value cross_v) {
const auto &meta_data = model_->GetModelMetadata();
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
// Decoding starts from a single SOS token. The decoder is then run
// autoregressively, feeding back the best token from the previous step,
// until EOS is emitted or max_len steps have been taken.
std::array<int64_t, 2> token_shape = {1, 1};
int64_t token = meta_data.sos_id;
Ort::Value tokens = Ort::Value::CreateTensor(
memory_info, &token, 1, token_shape.data(), token_shape.size());
std::array<int64_t, 1> offset_shape{1};
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
model_->Allocator(), offset_shape.data(), offset_shape.size());
*(offset.GetTensorMutableData<int64_t>()) = 0;
std::vector<OfflineFireRedAsrDecoderResult> ans(1);
auto self_kv_cache = model_->GetInitialSelfKVCache();
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
Ort::Value>
decoder_out = {Ort::Value{nullptr},
std::move(self_kv_cache.first),
std::move(self_kv_cache.second),
std::move(cross_k),
std::move(cross_v),
std::move(offset)};
for (int32_t i = 0; i < meta_data.max_len; ++i) {
decoder_out = model_->ForwardDecoder(View(&tokens),
std::move(std::get<1>(decoder_out)),
std::move(std::get<2>(decoder_out)),
std::move(std::get<3>(decoder_out)),
std::move(std::get<4>(decoder_out)),
std::move(std::get<5>(decoder_out)));
const auto &logits = std::get<0>(decoder_out);
const float *p_logits = logits.GetTensorData<float>();
auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
int32_t vocab_size = logits_shape[2];
int32_t max_token_id = static_cast<int32_t>(std::distance(
p_logits, std::max_element(p_logits, p_logits + vocab_size)));
if (max_token_id == meta_data.eos_id) {
break;
}
ans[0].tokens.push_back(max_token_id);
token = max_token_id;
// increment offset
*(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) += 1;
}
return ans;
}
} // namespace sherpa_onnx
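The loop above is plain autoregressive greedy search. For readers more comfortable with Python, a minimal numpy sketch of the same control flow (step() is a hypothetical callable standing in for ForwardDecoder; it is illustrative only, not part of the sherpa-onnx API):

import numpy as np


def greedy_search(step, sos_id, eos_id, max_len, state):
    """step(token, state) -> (logits, state); logits has shape (vocab_size,).

    Mirrors OfflineFireRedAsrGreedySearchDecoder::Decode for batch size 1.
    """
    tokens = []
    token = sos_id
    for _ in range(max_len):
        logits, state = step(token, state)  # one decoder forward pass
        token = int(np.argmax(logits))      # best token of this step
        if token == eos_id:                 # stop on end-of-sentence
            break
        tokens.append(token)
    return tokens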
... ...
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h"
#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"
namespace sherpa_onnx {
class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder {
public:
explicit OfflineFireRedAsrGreedySearchDecoder(OfflineFireRedAsrModel *model)
: model_(model) {}
std::vector<OfflineFireRedAsrDecoderResult> Decode(
Ort::Value cross_k, Ort::Value cross_v) override;
private:
OfflineFireRedAsrModel *model_; // not owned
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
... ...
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineFireRedAsrModelConfig::Register(ParseOptions *po) {
po->Register("fire-red-asr-encoder", &encoder,
"Path to onnx encoder of FireRedAsr");
po->Register("fire-red-asr-decoder", &decoder,
"Path to onnx decoder of FireRedAsr");
}
bool OfflineFireRedAsrModelConfig::Validate() const {
if (encoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --fire-red-asr-encoder");
return false;
}
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("FireRedAsr encoder file '%s' does not exist",
encoder.c_str());
return false;
}
if (decoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --fire-red-asr-decoder");
return false;
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("FireRedAsr decoder file '%s' does not exist",
decoder.c_str());
return false;
}
return true;
}
std::string OfflineFireRedAsrModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineFireRedAsrModelConfig(";
os << "encoder=\"" << encoder << "\", ";
os << "decoder=\"" << decoder << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
// see https://github.com/FireRedTeam/FireRedASR
struct OfflineFireRedAsrModelConfig {
std::string encoder;
std::string decoder;
OfflineFireRedAsrModelConfig() = default;
OfflineFireRedAsrModelConfig(const std::string &encoder,
const std::string &decoder)
: encoder(encoder), decoder(decoder) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
#include <string>
#include <unordered_map>
#include <vector>
namespace sherpa_onnx {
struct OfflineFireRedAsrModelMetaData {
int32_t sos_id;
int32_t eos_id;
int32_t max_len;
int32_t num_decoder_layers;
int32_t num_head;
int32_t head_dim;
std::vector<float> mean;
std::vector<float> inv_stddev;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
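These fields are populated from the custom metadata embedded in the encoder ONNX file (see InitEncoder() in offline-fire-red-asr-model.cc below). Assuming the onnxruntime Python package and the model from the example above, the raw values can be inspected like this; cmvn_mean and cmvn_inv_stddev are float vectors serialized as strings:

import onnxruntime as ort

sess = ort.InferenceSession(
    "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx"
)
meta = sess.get_modelmeta().custom_metadata_map

# Keys read by SHERPA_ONNX_READ_META_DATA in InitEncoder().
for key in ("num_decoder_layers", "num_head", "head_dim",
            "sos", "eos", "max_len"):
    print(key, "=", meta[key])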
... ...
// sherpa-onnx/csrc/offline-fire-red-asr-model.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"
#include <algorithm>
#include <cmath>
#include <cstring>  // for memset
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OfflineFireRedAsrModel::Impl {
public:
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.fire_red_asr.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.fire_red_asr.decoder);
InitDecoder(buf.data(), buf.size());
}
}
template <typename Manager>
Impl(Manager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.fire_red_asr.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.fire_red_asr.decoder);
InitDecoder(buf.data(), buf.size());
}
}
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features,
Ort::Value features_length) {
std::array<Ort::Value, 2> inputs{std::move(features),
std::move(features_length)};
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
return {std::move(encoder_out[0]), std::move(encoder_out[1])};
}
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
Ort::Value>
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v, Ort::Value offset) {
std::array<Ort::Value, 6> decoder_input = {std::move(tokens),
std::move(n_layer_self_k_cache),
std::move(n_layer_self_v_cache),
std::move(n_layer_cross_k),
std::move(n_layer_cross_v),
std::move(offset)};
auto decoder_out = decoder_sess_->Run(
{}, decoder_input_names_ptr_.data(), decoder_input.data(),
decoder_input.size(), decoder_output_names_ptr_.data(),
decoder_output_names_ptr_.size());
return std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value,
Ort::Value, Ort::Value>{
std::move(decoder_out[0]), std::move(decoder_out[1]),
std::move(decoder_out[2]), std::move(decoder_input[3]),
std::move(decoder_input[4]), std::move(decoder_input[5])};
}
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
int32_t batch_size = 1;
std::array<int64_t, 5> shape{meta_data_.num_decoder_layers, batch_size,
meta_data_.max_len, meta_data_.num_head,
meta_data_.head_dim};
Ort::Value n_layer_self_k_cache = Ort::Value::CreateTensor<float>(
Allocator(), shape.data(), shape.size());
Ort::Value n_layer_self_v_cache = Ort::Value::CreateTensor<float>(
Allocator(), shape.data(), shape.size());
auto n = shape[0] * shape[1] * shape[2] * shape[3] * shape[4];
float *p_k = n_layer_self_k_cache.GetTensorMutableData<float>();
float *p_v = n_layer_self_v_cache.GetTensorMutableData<float>();
memset(p_k, 0, sizeof(float) * n);
memset(p_v, 0, sizeof(float) * n);
return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
}
OrtAllocator *Allocator() { return allocator_; }
const OfflineFireRedAsrModelMetaData &GetModelMetadata() const {
return meta_data_;
}
private:
void InitEncoder(void *model_data, size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(meta_data_.num_decoder_layers,
"num_decoder_layers");
SHERPA_ONNX_READ_META_DATA(meta_data_.num_head, "num_head");
SHERPA_ONNX_READ_META_DATA(meta_data_.head_dim, "head_dim");
SHERPA_ONNX_READ_META_DATA(meta_data_.sos_id, "sos");
SHERPA_ONNX_READ_META_DATA(meta_data_.eos_id, "eos");
SHERPA_ONNX_READ_META_DATA(meta_data_.max_len, "max_len");
SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.mean, "cmvn_mean");
SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev,
"cmvn_inv_stddev");
}
void InitDecoder(void *model_data, size_t model_data_length) {
decoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
}
private:
OfflineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
OfflineFireRedAsrModelMetaData meta_data_;
};
OfflineFireRedAsrModel::OfflineFireRedAsrModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
template <typename Manager>
OfflineFireRedAsrModel::OfflineFireRedAsrModel(Manager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
OfflineFireRedAsrModel::~OfflineFireRedAsrModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineFireRedAsrModel::ForwardEncoder(
Ort::Value features, Ort::Value features_length) const {
return impl_->ForwardEncoder(std::move(features), std::move(features_length));
}
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
Ort::Value>
OfflineFireRedAsrModel::ForwardDecoder(Ort::Value tokens,
Ort::Value n_layer_self_k_cache,
Ort::Value n_layer_self_v_cache,
Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v,
Ort::Value offset) const {
return impl_->ForwardDecoder(
std::move(tokens), std::move(n_layer_self_k_cache),
std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
std::move(n_layer_cross_v), std::move(offset));
}
std::pair<Ort::Value, Ort::Value>
OfflineFireRedAsrModel::GetInitialSelfKVCache() const {
return impl_->GetInitialSelfKVCache();
}
OrtAllocator *OfflineFireRedAsrModel::Allocator() const {
return impl_->Allocator();
}
const OfflineFireRedAsrModelMetaData &OfflineFireRedAsrModel::GetModelMetadata()
const {
return impl_->GetModelMetadata();
}
#if __ANDROID_API__ >= 9
template OfflineFireRedAsrModel::OfflineFireRedAsrModel(
AAssetManager *mgr, const OfflineModelConfig &config);
#endif
#if __OHOS__
template OfflineFireRedAsrModel::OfflineFireRedAsrModel(
NativeResourceManager *mgr, const OfflineModelConfig &config);
#endif
} // namespace sherpa_onnx
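GetInitialSelfKVCache() pre-allocates the full self-attention cache once and zero-fills it; the decoder then writes into it at the position given by offset. A numpy sketch of the same allocation (the numeric sizes here are illustrative; the real ones come from the encoder metadata):

import numpy as np

# Illustrative values; see OfflineFireRedAsrModelMetaData for the real source.
num_decoder_layers, batch, max_len, num_head, head_dim = 16, 1, 448, 20, 64

shape = (num_decoder_layers, batch, max_len, num_head, head_dim)
self_k_cache = np.zeros(shape, dtype=np.float32)  # matches memset(p_k, 0, ...)
self_v_cache = np.zeros(shape, dtype=np.float32)

offset = np.zeros((batch,), dtype=np.int64)  # write position, starts at 0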
... ...
// sherpa-onnx/csrc/offline-fire-red-asr-model.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
class OfflineFireRedAsrModel {
public:
explicit OfflineFireRedAsrModel(const OfflineModelConfig &config);
template <typename Manager>
OfflineFireRedAsrModel(Manager *mgr, const OfflineModelConfig &config);
~OfflineFireRedAsrModel();
/** Run the encoder model.
*
* @param features A tensor of shape (N, T, C).
* @param features_len A tensor of shape (N,) with dtype int64.
*
* @return Return a pair containing:
* - n_layer_cross_k: A 4-D tensor of shape
* (num_decoder_layers, N, T, d_model)
* - n_layer_cross_v: A 4-D tensor of shape
* (num_decoder_layers, N, T, d_model)
*/
std::pair<Ort::Value, Ort::Value> ForwardEncoder(
Ort::Value features, Ort::Value features_length) const;
/** Run the decoder model.
*
* @param tokens An int64 tensor of shape (N, num_words)
* @param n_layer_self_k_cache A 5-D tensor of shape
* (num_decoder_layers, N, max_len, num_head, head_dim).
* @param n_layer_self_v_cache A 5-D tensor of shape
* (num_decoder_layers, N, max_len, num_head, head_dim).
* @param n_layer_cross_k A 4-D tensor of shape
*        (num_decoder_layers, N, T, d_model).
* @param n_layer_cross_v A 4-D tensor of shape
*        (num_decoder_layers, N, T, d_model).
* @param offset An int64 tensor of shape (N,)
*
* @return Return a tuple containing 6 tensors:
*
* - logits A 3-D tensor of shape (N, num_words, vocab_size)
* - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache
* - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache
* - out_n_layer_cross_k Same as n_layer_cross_k
* - out_n_layer_cross_v Same as n_layer_cross_v
* - out_offset Same as offset
*/
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
Ort::Value>
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v, Ort::Value offset) const;
/** Return the initial self-attention KV cache as a pair:
* - n_layer_self_k_cache A 5-D tensor of shape
* (num_decoder_layers, N, max_len, num_head, head_dim).
* - n_layer_self_v_cache A 5-D tensor of shape
* (num_decoder_layers, N, max_len, num_head, head_dim).
*/
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const;
const OfflineFireRedAsrModelMetaData &GetModelMetadata() const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_
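ForwardEncoder/ForwardDecoder map directly onto the two ONNX sessions. A minimal Python sketch of the same dataflow with onnxruntime, reading input names from the models instead of hard-coding them (mirroring GetInputNames() in the C++ Impl). Feature extraction and CMVN are omitted; the (1, 100, 80) features array is a placeholder, and the sketch assumes the models' declared input order matches the argument order used by ForwardEncoder/ForwardDecoder above:

import numpy as np
import onnxruntime as ort

d = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16"
enc = ort.InferenceSession(f"{d}/encoder.int8.onnx")
dec = ort.InferenceSession(f"{d}/decoder.int8.onnx")
meta = enc.get_modelmeta().custom_metadata_map


def run(sess, *args):
    # Bind positional arrays to the session's declared input names.
    names = [i.name for i in sess.get_inputs()]
    return sess.run(None, dict(zip(names, args)))


features = np.zeros((1, 100, 80), dtype=np.float32)  # placeholder (N, T, C)
features_len = np.array([features.shape[1]], dtype=np.int64)
enc_out = run(enc, features, features_len)           # ForwardEncoder
cross_k, cross_v = enc_out[0], enc_out[1]

cache_shape = (int(meta["num_decoder_layers"]), 1, int(meta["max_len"]),
               int(meta["num_head"]), int(meta["head_dim"]))
self_k = np.zeros(cache_shape, dtype=np.float32)     # GetInitialSelfKVCache
self_v = np.zeros(cache_shape, dtype=np.float32)
offset = np.zeros((1,), dtype=np.int64)
tokens = np.array([[int(meta["sos"])]], dtype=np.int64)

# One ForwardDecoder step; the greedy loop repeats this, replacing tokens
# with the argmax of logits and incrementing offset.
logits, self_k, self_v = run(dec, tokens, self_k, self_v,
                             cross_k, cross_v, offset)[:3]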
... ...
... ... @@ -15,6 +15,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
paraformer.Register(po);
nemo_ctc.Register(po);
whisper.Register(po);
fire_red_asr.Register(po);
tdnn.Register(po);
zipformer_ctc.Register(po);
wenet_ctc.Register(po);
... ... @@ -38,7 +39,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
"tdnn, zipformer2_ctc, telespeech_ctc."
"tdnn, zipformer2_ctc, telespeech_ctc, fire_red_asr."
"All other values lead to loading the model twice.");
po->Register("modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
... ... @@ -84,6 +85,10 @@ bool OfflineModelConfig::Validate() const {
return whisper.Validate();
}
if (!fire_red_asr.encoder.empty()) {
return fire_red_asr.Validate();
}
if (!tdnn.model.empty()) {
return tdnn.Validate();
}
... ... @@ -125,6 +130,7 @@ std::string OfflineModelConfig::ToString() const {
os << "paraformer=" << paraformer.ToString() << ", ";
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
os << "whisper=" << whisper.ToString() << ", ";
os << "fire_red_asr=" << fire_red_asr.ToString() << ", ";
os << "tdnn=" << tdnn.ToString() << ", ";
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
... ...
... ... @@ -6,6 +6,7 @@
#include <string>
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
... ... @@ -23,6 +24,7 @@ struct OfflineModelConfig {
OfflineParaformerModelConfig paraformer;
OfflineNemoEncDecCtcModelConfig nemo_ctc;
OfflineWhisperModelConfig whisper;
OfflineFireRedAsrModelConfig fire_red_asr;
OfflineTdnnModelConfig tdnn;
OfflineZipformerCtcModelConfig zipformer_ctc;
OfflineWenetCtcModelConfig wenet_ctc;
... ... @@ -54,6 +56,7 @@ struct OfflineModelConfig {
const OfflineParaformerModelConfig &paraformer,
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
const OfflineWhisperModelConfig &whisper,
const OfflineFireRedAsrModelConfig &fire_red_asr,
const OfflineTdnnModelConfig &tdnn,
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc,
... ... @@ -68,6 +71,7 @@ struct OfflineModelConfig {
paraformer(paraformer),
nemo_ctc(nemo_ctc),
whisper(whisper),
fire_red_asr(fire_red_asr),
tdnn(tdnn),
zipformer_ctc(zipformer_ctc),
wenet_ctc(wenet_ctc),
... ...
// sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h"
#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
static OfflineRecognitionResult Convert(
const OfflineFireRedAsrDecoderResult &src, const SymbolTable &sym_table) {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
std::string text;
for (auto i : src.tokens) {
if (!sym_table.Contains(i)) {
continue;
}
const auto &s = sym_table[i];
text += s;
r.tokens.push_back(s);
}
r.text = text;
return r;
}
class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerFireRedAsrImpl(
const OfflineRecognizerConfig &config)
: OfflineRecognizerImpl(config),
config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineFireRedAsrModel>(config.model_config)) {
Init();
}
template <typename Manager>
OfflineRecognizerFireRedAsrImpl(Manager *mgr,
const OfflineRecognizerConfig &config)
: OfflineRecognizerImpl(mgr, config),
config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(std::make_unique<OfflineFireRedAsrModel>(mgr,
config.model_config)) {
Init();
}
void Init() {
if (config_.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineFireRedAsrGreedySearchDecoder>(model_.get());
} else {
SHERPA_ONNX_LOGE(
"Only greedy_search is supported at present for FireRedAsr. Given %s",
config_.decoding_method.c_str());
SHERPA_ONNX_EXIT(-1);
}
config_.feat_config.normalize_samples = false;
config_.feat_config.high_freq = 0;
config_.feat_config.snip_edges = true;
}
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(config_.feat_config);
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
// batch decoding is not implemented yet
for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]);
}
}
OfflineRecognizerConfig GetConfig() const override { return config_; }
private:
void DecodeStream(OfflineStream *s) const {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t feat_dim = s->FeatureDim();
std::vector<float> f = s->GetFrames();
ApplyCMVN(&f);
int64_t num_frames = f.size() / feat_dim;
std::array<int64_t, 3> shape{1, num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
shape.data(), shape.size());
int64_t len_shape = 1;
Ort::Value x_len =
Ort::Value::CreateTensor(memory_info, &num_frames, 1, &len_shape, 1);
auto cross_kv = model_->ForwardEncoder(std::move(x), std::move(x_len));
auto results =
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
auto r = Convert(results[0], symbol_table_);
r.text = ApplyInverseTextNormalization(std::move(r.text));
s->SetResult(r);
}
void ApplyCMVN(std::vector<float> *v) const {
const auto &meta_data = model_->GetModelMetadata();
const auto &mean = meta_data.mean;
const auto &inv_stddev = meta_data.inv_stddev;
int32_t feat_dim = static_cast<int32_t>(mean.size());
int32_t num_frames = static_cast<int32_t>(v->size()) / feat_dim;
float *p = v->data();
for (int32_t i = 0; i != num_frames; ++i) {
for (int32_t k = 0; k != feat_dim; ++k) {
p[k] = (p[k] - mean[k]) * inv_stddev[k];
}
p += feat_dim;
}
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OfflineFireRedAsrModel> model_;
std::unique_ptr<OfflineFireRedAsrDecoder> decoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_
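ApplyCMVN() above is plain per-dimension mean/variance normalization with a pre-inverted standard deviation, so the inner loop multiplies instead of divides. The same operation in numpy, as a short sketch:

import numpy as np


def apply_cmvn(frames: np.ndarray, mean: np.ndarray,
               inv_stddev: np.ndarray) -> np.ndarray:
    """frames: (num_frames, feat_dim); mean/inv_stddev: (feat_dim,)."""
    return (frames - mean) * inv_stddev  # broadcasts over the frame axis


# Example: 3 frames of an 80-dim feature (identity CMVN for illustration).
frames = np.random.randn(3, 80).astype(np.float32)
normalized = apply_cmvn(frames, np.zeros(80, np.float32),
                        np.ones(80, np.float32))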
... ...
... ... @@ -24,6 +24,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
... ... @@ -56,6 +57,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
}
if (!config.model_config.fire_red_asr.encoder.empty()) {
return std::make_unique<OfflineRecognizerFireRedAsrImpl>(config);
}
if (!config.model_config.moonshine.preprocessor.empty()) {
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
}
... ... @@ -237,6 +242,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
}
if (!config.model_config.fire_red_asr.encoder.empty()) {
return std::make_unique<OfflineRecognizerFireRedAsrImpl>(mgr, config);
}
if (!config.model_config.moonshine.preprocessor.empty()) {
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
}
... ...
... ... @@ -9,6 +9,7 @@ set(srcs
features.cc
keyword-spotter.cc
offline-ctc-fst-decoder-config.cc
offline-fire-red-asr-model-config.cc
offline-lm-config.cc
offline-model-config.cc
offline-moonshine-model-config.cc
... ...
// sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h"
namespace sherpa_onnx {
void PybindOfflineFireRedAsrModelConfig(py::module *m) {
using PyClass = OfflineFireRedAsrModelConfig;
py::class_<PyClass>(*m, "OfflineFireRedAsrModelConfig")
.def(py::init<const std::string &, const std::string &>(),
py::arg("encoder"), py::arg("decoder"))
.def_readwrite("encoder", &PyClass::encoder)
.def_readwrite("decoder", &PyClass::decoder)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
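After the extension module is rebuilt, the binding above can be exercised directly from the internal _sherpa_onnx module (the same module the Python API imports from); a quick smoke test:

from _sherpa_onnx import OfflineFireRedAsrModelConfig

config = OfflineFireRedAsrModelConfig(
    encoder="encoder.int8.onnx",
    decoder="decoder.int8.onnx",
)
print(config)  # goes through OfflineFireRedAsrModelConfig::ToString()
config.encoder = "encoder.onnx"  # fields are writable via def_readwrite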
... ...
// sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineFireRedAsrModelConfig(py::module *m);
}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
... ...
... ... @@ -8,6 +8,7 @@
#include <vector>
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h"
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
... ... @@ -25,6 +26,7 @@ void PybindOfflineModelConfig(py::module *m) {
PybindOfflineParaformerModelConfig(m);
PybindOfflineNemoEncDecCtcModelConfig(m);
PybindOfflineWhisperModelConfig(m);
PybindOfflineFireRedAsrModelConfig(m);
PybindOfflineTdnnModelConfig(m);
PybindOfflineZipformerCtcModelConfig(m);
PybindOfflineWenetCtcModelConfig(m);
... ... @@ -33,35 +35,38 @@ void PybindOfflineModelConfig(py::module *m) {
using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
.def(
py::init<
const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &,
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &,
const OfflineSenseVoiceModelConfig &,
const OfflineMoonshineModelConfig &, const std::string &,
const std::string &, int32_t, bool, const std::string &,
const std::string &, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("whisper") = OfflineWhisperModelConfig(),
py::arg("tdnn") = OfflineTdnnModelConfig(),
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
py::arg("moonshine") = OfflineMoonshineModelConfig(),
py::arg("telespeech_ctc") = "", py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "",
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
.def(py::init<const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &,
const OfflineWhisperModelConfig &,
const OfflineFireRedAsrModelConfig &,
const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &,
const OfflineSenseVoiceModelConfig &,
const OfflineMoonshineModelConfig &, const std::string &,
const std::string &, int32_t, bool, const std::string &,
const std::string &, const std::string &,
const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("whisper") = OfflineWhisperModelConfig(),
py::arg("fire_red_asr") = OfflineFireRedAsrModelConfig(),
py::arg("tdnn") = OfflineTdnnModelConfig(),
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
py::arg("moonshine") = OfflineMoonshineModelConfig(),
py::arg("telespeech_ctc") = "", py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "",
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
.def_readwrite("whisper", &PyClass::whisper)
.def_readwrite("fire_red_asr", &PyClass::fire_red_asr)
.def_readwrite("tdnn", &PyClass::tdnn)
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
... ...
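With fire_red_asr wired into the constructor, an OfflineModelConfig can now be built with only the FireRedAsr sub-config; note that tokens and num_threads carry no defaults in the binding and must be supplied:

from _sherpa_onnx import (
    OfflineFireRedAsrModelConfig,
    OfflineModelConfig,
)

model_config = OfflineModelConfig(
    fire_red_asr=OfflineFireRedAsrModelConfig(
        encoder="encoder.int8.onnx",
        decoder="decoder.int8.onnx",
    ),
    tokens="tokens.txt",
    num_threads=1,
)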
... ... @@ -6,6 +6,7 @@ from typing import List, Optional
from _sherpa_onnx import (
FeatureExtractorConfig,
OfflineCtcFstDecoderConfig,
OfflineFireRedAsrModelConfig,
OfflineLMConfig,
OfflineModelConfig,
OfflineMoonshineModelConfig,
... ... @@ -572,6 +573,78 @@ class OfflineRecognizer(object):
return self
@classmethod
def from_fire_red_asr(
cls,
encoder: str,
decoder: str,
tokens: str,
num_threads: int = 1,
decoding_method: str = "greedy_search",
debug: bool = False,
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/fire_red_asr/index.html>`_
to download pre-trained FireRedAsr models of different sizes,
e.g., xs and large.
Args:
encoder:
Path to the encoder model.
decoder:
Path to the decoder model.
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
num_threads:
Number of threads for neural network computation.
decoding_method:
Valid values: greedy_search.
debug:
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
fire_red_asr=OfflineFireRedAsrModelConfig(
encoder=encoder,
decoder=decoder,
),
tokens=tokens,
num_threads=num_threads,
debug=debug,
provider=provider,
)
feat_config = FeatureExtractorConfig(
sampling_rate=16000,
feature_dim=80,
)
recognizer_config = OfflineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
@classmethod
def from_moonshine(
cls,
preprocessor: str,
... ...
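Because DecodeStreams() in offline-recognizer-fire-red-asr-impl.h simply loops over the streams, several files can still be decoded with one call. A sketch built on the new from_fire_red_asr API, using decode_streams, the existing batch entry point of the Python OfflineRecognizer:

import sherpa_onnx
import soundfile as sf

d = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16"
recognizer = sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
    encoder=f"{d}/encoder.int8.onnx",
    decoder=f"{d}/decoder.int8.onnx",
    tokens=f"{d}/tokens.txt",
)

wave_filenames = [f"{d}/test_wavs/0.wav", f"{d}/test_wavs/1.wav"]
streams = []
for wav in wave_filenames:
    audio, sample_rate = sf.read(wav, dtype="float32", always_2d=True)
    s = recognizer.create_stream()
    s.accept_waveform(sample_rate, audio[:, 0])
    streams.append(s)

recognizer.decode_streams(streams)  # decoded one by one internally
for wav, s in zip(wave_filenames, streams):
    print(wav, s.result.text)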