Fangjun Kuang
Committed by GitHub

Add C++ and Python API for FireRedASR AED models (#1867)

@@ -133,3 +133,4 @@ lexicon.txt @@ -133,3 +133,4 @@ lexicon.txt
133 us_gold.json 133 us_gold.json
134 us_silver.json 134 us_silver.json
135 kokoro-multi-lang-v1_0 135 kokoro-multi-lang-v1_0
  136 +sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This file shows how to use a non-streaming FireRedAsr AED model from
  5 +https://github.com/FireRedTeam/FireRedASR
  6 +to decode files.
  7 +
  8 +Please download model files from
  9 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
  10 +
  11 +For instance,
  12 +
  13 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
  14 +tar xvf sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
  15 +rm sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
  16 +"""
  17 +
  18 +from pathlib import Path
  19 +
  20 +import sherpa_onnx
  21 +import soundfile as sf
  22 +
  23 +
def create_recognizer():
    """Create an OfflineRecognizer for the FireRedASR AED model.

    Returns:
      A tuple ``(recognizer, test_wav_path)``.

    Raises:
      ValueError: if any of the required model files is missing.
    """
    encoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx"
    decoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx"
    tokens = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt"
    test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/1.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/2.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/8k.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3-sichuan.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/4-tianjin.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/5-henan.wav"

    # Also check tokens.txt: it is passed to the recognizer below, so a
    # missing tokens file should fail here with a helpful message instead
    # of failing later inside sherpa_onnx.
    if (
        not Path(encoder).is_file()
        or not Path(decoder).is_file()
        or not Path(tokens).is_file()
        or not Path(test_wav).is_file()
    ):
        raise ValueError(
            """Please download model files from
      https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
      """
        )
    return (
        sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
            encoder=encoder,
            decoder=decoder,
            tokens=tokens,
            debug=True,
        ),
        test_wav,
    )
  56 +
  57 +
def main():
    """Decode the bundled test wave file and print the result."""
    recognizer, wave_filename = create_recognizer()

    # always_2d=True yields shape (num_samples, num_channels); keep only
    # the first channel as a 1-D float32 array normalized to [-1, 1].
    samples, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    samples = samples[:, 0]

    # sample_rate does not need to be 16000 Hz

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)
    recognizer.decode_stream(stream)

    print(wave_filename)
    print(stream.result)


if __name__ == "__main__":
    main()
@@ -27,6 +27,9 @@ set(sources @@ -27,6 +27,9 @@ set(sources
27 offline-ctc-fst-decoder.cc 27 offline-ctc-fst-decoder.cc
28 offline-ctc-greedy-search-decoder.cc 28 offline-ctc-greedy-search-decoder.cc
29 offline-ctc-model.cc 29 offline-ctc-model.cc
  30 + offline-fire-red-asr-greedy-search-decoder.cc
  31 + offline-fire-red-asr-model-config.cc
  32 + offline-fire-red-asr-model.cc
30 offline-lm-config.cc 33 offline-lm-config.cc
31 offline-lm.cc 34 offline-lm.cc
32 offline-model-config.cc 35 offline-model-config.cc
// sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_

#include <cstdint>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

namespace sherpa_onnx {

// Decoding result for one utterance.
struct OfflineFireRedAsrDecoderResult {
  /// The decoded token IDs
  std::vector<int32_t> tokens;
};

// Abstract base class for FireRedASR AED decoders, e.g., greedy search.
class OfflineFireRedAsrDecoder {
 public:
  virtual ~OfflineFireRedAsrDecoder() = default;

  /** Run decoding given the output from the FireRedAsr encoder model.
   *
   * @param n_layer_cross_k A 4-D tensor of shape
   *                        (num_decoder_layers, N, T, d_model).
   * @param n_layer_cross_v A 4-D tensor of shape
   *                        (num_decoder_layers, N, T, d_model).
   *
   * @return Return a vector of size `N` containing the decoded results.
   */
  virtual std::vector<OfflineFireRedAsrDecoderResult> Decode(
      Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
  1 +// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h"
  6 +
  7 +#include <algorithm>
  8 +#include <tuple>
  9 +#include <utility>
  10 +
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +#include "sherpa-onnx/csrc/onnx-utils.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +// Note: this functions works only for batch size == 1 at present
  17 +std::vector<OfflineFireRedAsrDecoderResult>
  18 +OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k,
  19 + Ort::Value cross_v) {
  20 + const auto &meta_data = model_->GetModelMetadata();
  21 +
  22 + auto memory_info =
  23 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  24 +
  25 + // For multilingual models, initial_tokens contains [sot, language, task]
  26 + // - language is English by default
  27 + // - task is transcribe by default
  28 + //
  29 + // For non-multilingual models, initial_tokens contains [sot]
  30 + std::array<int64_t, 2> token_shape = {1, 1};
  31 + int64_t token = meta_data.sos_id;
  32 +
  33 + int32_t batch_size = 1;
  34 +
  35 + Ort::Value tokens = Ort::Value::CreateTensor(
  36 + memory_info, &token, 1, token_shape.data(), token_shape.size());
  37 +
  38 + std::array<int64_t, 1> offset_shape{1};
  39 + Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
  40 + model_->Allocator(), offset_shape.data(), offset_shape.size());
  41 + *(offset.GetTensorMutableData<int64_t>()) = 0;
  42 +
  43 + std::vector<OfflineFireRedAsrDecoderResult> ans(1);
  44 +
  45 + auto self_kv_cache = model_->GetInitialSelfKVCache();
  46 +
  47 + std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
  48 + Ort::Value>
  49 + decoder_out = {Ort::Value{nullptr},
  50 + std::move(self_kv_cache.first),
  51 + std::move(self_kv_cache.second),
  52 + std::move(cross_k),
  53 + std::move(cross_v),
  54 + std::move(offset)};
  55 +
  56 + for (int32_t i = 0; i < meta_data.max_len; ++i) {
  57 + decoder_out = model_->ForwardDecoder(View(&tokens),
  58 + std::move(std::get<1>(decoder_out)),
  59 + std::move(std::get<2>(decoder_out)),
  60 + std::move(std::get<3>(decoder_out)),
  61 + std::move(std::get<4>(decoder_out)),
  62 + std::move(std::get<5>(decoder_out)));
  63 +
  64 + const auto &logits = std::get<0>(decoder_out);
  65 + const float *p_logits = logits.GetTensorData<float>();
  66 +
  67 + auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
  68 + int32_t vocab_size = logits_shape[2];
  69 +
  70 + int32_t max_token_id = static_cast<int32_t>(std::distance(
  71 + p_logits, std::max_element(p_logits, p_logits + vocab_size)));
  72 + if (max_token_id == meta_data.eos_id) {
  73 + break;
  74 + }
  75 +
  76 + ans[0].tokens.push_back(max_token_id);
  77 +
  78 + token = max_token_id;
  79 +
  80 + // increment offset
  81 + *(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) += 1;
  82 + }
  83 +
  84 + return ans;
  85 +}
  86 +
  87 +} // namespace sherpa_onnx
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_

#include <vector>

#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h"
#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"

namespace sherpa_onnx {

// Greedy-search (arg-max) decoder for FireRedASR AED models.
// See OfflineFireRedAsrDecoder for the meaning of the Decode() arguments.
class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder {
 public:
  // The model must outlive this decoder; ownership is not transferred.
  explicit OfflineFireRedAsrGreedySearchDecoder(OfflineFireRedAsrModel *model)
      : model_(model) {}

  std::vector<OfflineFireRedAsrDecoderResult> Decode(
      Ort::Value cross_k, Ort::Value cross_v) override;

 private:
  OfflineFireRedAsrModel *model_;  // not owned
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
  1 +// sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OfflineFireRedAsrModelConfig::Register(ParseOptions *po) {
  13 + po->Register("fire-red-asr-encoder", &encoder,
  14 + "Path to onnx encoder of FireRedAsr");
  15 +
  16 + po->Register("fire-red-asr-decoder", &decoder,
  17 + "Path to onnx decoder of FireRedAsr");
  18 +}
  19 +
  20 +bool OfflineFireRedAsrModelConfig::Validate() const {
  21 + if (encoder.empty()) {
  22 + SHERPA_ONNX_LOGE("Please provide --fire-red-asr-encoder");
  23 + return false;
  24 + }
  25 +
  26 + if (!FileExists(encoder)) {
  27 + SHERPA_ONNX_LOGE("FireRedAsr encoder file '%s' does not exist",
  28 + encoder.c_str());
  29 + return false;
  30 + }
  31 +
  32 + if (decoder.empty()) {
  33 + SHERPA_ONNX_LOGE("Please provide --fire-red-asr-decoder");
  34 + return false;
  35 + }
  36 +
  37 + if (!FileExists(decoder)) {
  38 + SHERPA_ONNX_LOGE("FireRedAsr decoder file '%s' does not exist",
  39 + decoder.c_str());
  40 + return false;
  41 + }
  42 +
  43 + return true;
  44 +}
  45 +
  46 +std::string OfflineFireRedAsrModelConfig::ToString() const {
  47 + std::ostringstream os;
  48 +
  49 + os << "OfflineFireRedAsrModelConfig(";
  50 + os << "encoder=\"" << encoder << "\", ";
  51 + os << "decoder=\"" << decoder << "\")";
  52 +
  53 + return os.str();
  54 +}
  55 +
  56 +} // namespace sherpa_onnx
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

// Configuration for FireRedASR AED models.
// See https://github.com/FireRedTeam/FireRedASR
struct OfflineFireRedAsrModelConfig {
  // Path to the onnx encoder model.
  std::string encoder;

  // Path to the onnx decoder model.
  std::string decoder;

  OfflineFireRedAsrModelConfig() = default;
  OfflineFireRedAsrModelConfig(const std::string &encoder,
                               const std::string &decoder)
      : encoder(encoder), decoder(decoder) {}

  // Register --fire-red-asr-encoder and --fire-red-asr-decoder.
  void Register(ParseOptions *po);

  // Return true if both model files are provided and exist on disk.
  bool Validate() const;

  std::string ToString() const;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
// sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_

#include <cstdint>  // int32_t; previously relied on a transitive include
#include <string>
#include <unordered_map>
#include <vector>

namespace sherpa_onnx {

// Metadata read from the ONNX encoder model of FireRedASR
// (keys: sos, eos, max_len, num_decoder_layers, num_head, head_dim,
// cmvn_mean, cmvn_inv_stddev).
//
// Members are default-initialized so a freshly constructed object does
// not contain indeterminate values.
struct OfflineFireRedAsrModelMetaData {
  // Start-of-sentence / end-of-sentence token IDs; -1 means "not set".
  int32_t sos_id = -1;
  int32_t eos_id = -1;

  // Maximum number of tokens to decode for one utterance.
  int32_t max_len = 0;

  // Decoder self-attention cache geometry.
  int32_t num_decoder_layers = 0;
  int32_t num_head = 0;
  int32_t head_dim = 0;

  // Per-dimension CMVN statistics applied to the input features.
  std::vector<float> mean;
  std::vector<float> inv_stddev;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
  1 +// sherpa-onnx/csrc/offline-fire-red-asr-model.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"
  6 +
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
  13 +
  14 +#if __ANDROID_API__ >= 9
  15 +#include "android/asset_manager.h"
  16 +#include "android/asset_manager_jni.h"
  17 +#endif
  18 +
  19 +#if __OHOS__
  20 +#include "rawfile/raw_file_manager.h"
  21 +#endif
  22 +
  23 +#include "sherpa-onnx/csrc/macros.h"
  24 +#include "sherpa-onnx/csrc/onnx-utils.h"
  25 +#include "sherpa-onnx/csrc/session.h"
  26 +#include "sherpa-onnx/csrc/text-utils.h"
  27 +
  28 +namespace sherpa_onnx {
  29 +
// Pimpl implementation of OfflineFireRedAsrModel.
//
// Owns two onnxruntime sessions (encoder, decoder) and the metadata
// read from the encoder model.
class OfflineFireRedAsrModel::Impl {
 public:
  // Load the encoder and the decoder from files on disk.
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(config.fire_red_asr.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.fire_red_asr.decoder);
      InitDecoder(buf.data(), buf.size());
    }
  }

  // Load the encoder and the decoder through a platform resource
  // manager, e.g., Android AAssetManager or OHOS NativeResourceManager.
  template <typename Manager>
  Impl(Manager *mgr, const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(mgr, config.fire_red_asr.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.fire_red_asr.decoder);
      InitDecoder(buf.data(), buf.size());
    }
  }

  // Run the encoder.
  // Returns {n_layer_cross_k, n_layer_cross_v}, i.e., the first two
  // encoder outputs, which are the cross-attention key/value caches
  // consumed by the decoder.
  std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features,
                                                   Ort::Value features_length) {
    std::array<Ort::Value, 2> inputs{std::move(features),
                                     std::move(features_length)};

    auto encoder_out = encoder_sess_->Run(
        {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
        encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());

    return {std::move(encoder_out[0]), std::move(encoder_out[1])};
  }

  // Run one decoder step.
  //
  // The returned tuple is (logits, self_k, self_v, cross_k, cross_v,
  // offset). Note that the last three entries are moved back out of the
  // *inputs* so the caller can pass them in again on the next step.
  std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
             Ort::Value>
  ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
                 Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
                 Ort::Value n_layer_cross_v, Ort::Value offset) {
    std::array<Ort::Value, 6> decoder_input = {std::move(tokens),
                                               std::move(n_layer_self_k_cache),
                                               std::move(n_layer_self_v_cache),
                                               std::move(n_layer_cross_k),
                                               std::move(n_layer_cross_v),
                                               std::move(offset)};

    auto decoder_out = decoder_sess_->Run(
        {}, decoder_input_names_ptr_.data(), decoder_input.data(),
        decoder_input.size(), decoder_output_names_ptr_.data(),
        decoder_output_names_ptr_.size());

    return std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value,
                      Ort::Value, Ort::Value>{
        std::move(decoder_out[0]), std::move(decoder_out[1]),
        std::move(decoder_out[2]), std::move(decoder_input[3]),
        std::move(decoder_input[4]), std::move(decoder_input[5])};
  }

  // Return zero-filled self-attention K/V caches of shape
  // (num_decoder_layers, 1, max_len, num_head, head_dim).
  // Note: batch size is fixed to 1 here.
  std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
    int32_t batch_size = 1;
    std::array<int64_t, 5> shape{meta_data_.num_decoder_layers, batch_size,
                                 meta_data_.max_len, meta_data_.num_head,
                                 meta_data_.head_dim};

    Ort::Value n_layer_self_k_cache = Ort::Value::CreateTensor<float>(
        Allocator(), shape.data(), shape.size());

    Ort::Value n_layer_self_v_cache = Ort::Value::CreateTensor<float>(
        Allocator(), shape.data(), shape.size());

    auto n = shape[0] * shape[1] * shape[2] * shape[3] * shape[4];

    float *p_k = n_layer_self_k_cache.GetTensorMutableData<float>();
    float *p_v = n_layer_self_v_cache.GetTensorMutableData<float>();

    // CreateTensor does not zero the buffers, so do it explicitly.
    memset(p_k, 0, sizeof(float) * n);
    memset(p_v, 0, sizeof(float) * n);

    return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
  }

  OrtAllocator *Allocator() { return allocator_; }

  const OfflineFireRedAsrModelMetaData &GetModelMetadata() const {
    return meta_data_;
  }

 private:
  // Create the encoder session, cache its input/output names, and read
  // the model metadata (sos/eos/max_len, decoder geometry, CMVN stats).
  void InitEncoder(void *model_data, size_t model_data_length) {
    encoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                  &encoder_input_names_ptr_);

    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                   &encoder_output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---encoder---\n";
      PrintModelMetadata(os, meta_data);
#if __OHOS__
      SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(meta_data_.num_decoder_layers,
                               "num_decoder_layers");
    SHERPA_ONNX_READ_META_DATA(meta_data_.num_head, "num_head");
    SHERPA_ONNX_READ_META_DATA(meta_data_.head_dim, "head_dim");
    SHERPA_ONNX_READ_META_DATA(meta_data_.sos_id, "sos");
    SHERPA_ONNX_READ_META_DATA(meta_data_.eos_id, "eos");
    SHERPA_ONNX_READ_META_DATA(meta_data_.max_len, "max_len");

    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.mean, "cmvn_mean");
    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev,
                                         "cmvn_inv_stddev");
  }

  // Create the decoder session and cache its input/output names.
  void InitDecoder(void *model_data, size_t model_data_length) {
    decoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
                  &decoder_input_names_ptr_);

    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
                   &decoder_output_names_ptr_);
  }

 private:
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;

  // The *_ptr_ vectors hold borrowed C strings pointing into the
  // corresponding std::string vectors, as required by Ort::Session::Run.
  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;

  OfflineFireRedAsrModelMetaData meta_data_;
};
  202 +
// Public API: all methods below simply forward to Impl.

OfflineFireRedAsrModel::OfflineFireRedAsrModel(const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

template <typename Manager>
OfflineFireRedAsrModel::OfflineFireRedAsrModel(Manager *mgr,
                                               const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}

OfflineFireRedAsrModel::~OfflineFireRedAsrModel() = default;

std::pair<Ort::Value, Ort::Value> OfflineFireRedAsrModel::ForwardEncoder(
    Ort::Value features, Ort::Value features_length) const {
  return impl_->ForwardEncoder(std::move(features), std::move(features_length));
}

std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
           Ort::Value>
OfflineFireRedAsrModel::ForwardDecoder(Ort::Value tokens,
                                       Ort::Value n_layer_self_k_cache,
                                       Ort::Value n_layer_self_v_cache,
                                       Ort::Value n_layer_cross_k,
                                       Ort::Value n_layer_cross_v,
                                       Ort::Value offset) const {
  return impl_->ForwardDecoder(
      std::move(tokens), std::move(n_layer_self_k_cache),
      std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
      std::move(n_layer_cross_v), std::move(offset));
}

std::pair<Ort::Value, Ort::Value>
OfflineFireRedAsrModel::GetInitialSelfKVCache() const {
  return impl_->GetInitialSelfKVCache();
}

OrtAllocator *OfflineFireRedAsrModel::Allocator() const {
  return impl_->Allocator();
}

const OfflineFireRedAsrModelMetaData &OfflineFireRedAsrModel::GetModelMetadata()
    const {
  return impl_->GetModelMetadata();
}

// Explicit instantiations of the resource-manager constructor for the
// platforms that need it.
#if __ANDROID_API__ >= 9
template OfflineFireRedAsrModel::OfflineFireRedAsrModel(
    AAssetManager *mgr, const OfflineModelConfig &config);
#endif

#if __OHOS__
template OfflineFireRedAsrModel::OfflineFireRedAsrModel(
    NativeResourceManager *mgr, const OfflineModelConfig &config);
#endif

}  // namespace sherpa_onnx
// sherpa-onnx/csrc/offline-fire-red-asr-model.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_

#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h"
#include "sherpa-onnx/csrc/offline-model-config.h"

namespace sherpa_onnx {

// Wrapper around the ONNX encoder and decoder of a FireRedASR AED model.
class OfflineFireRedAsrModel {
 public:
  explicit OfflineFireRedAsrModel(const OfflineModelConfig &config);

  // Construct using a platform resource manager (Android/OHOS).
  template <typename Manager>
  OfflineFireRedAsrModel(Manager *mgr, const OfflineModelConfig &config);

  ~OfflineFireRedAsrModel();

  /** Run the encoder model.
   *
   * @param features A tensor of shape (N, T, C).
   * @param features_len A tensor of shape (N,) with dtype int64.
   *
   * @return Return a pair containing:
   *  - n_layer_cross_k: A 4-D tensor of shape
   *                     (num_decoder_layers, N, T, d_model)
   *  - n_layer_cross_v: A 4-D tensor of shape
   *                     (num_decoder_layers, N, T, d_model)
   */
  std::pair<Ort::Value, Ort::Value> ForwardEncoder(
      Ort::Value features, Ort::Value features_length) const;

  /** Run the decoder model for one step.
   *
   * @param tokens An int64 tensor of shape (N, num_words)
   * @param n_layer_self_k_cache A 5-D tensor of shape
   *                             (num_decoder_layers, N, max_len, num_head,
   *                              head_dim).
   * @param n_layer_self_v_cache A 5-D tensor of shape
   *                             (num_decoder_layers, N, max_len, num_head,
   *                              head_dim).
   * @param n_layer_cross_k A 4-D tensor of shape
   *                        (num_decoder_layers, N, T, d_model).
   * @param n_layer_cross_v A 4-D tensor of shape
   *                        (num_decoder_layers, N, T, d_model).
   * @param offset An int64 tensor of shape (N,)
   *
   * @return Return a tuple containing 6 tensors:
   *
   *  - logits A 3-D tensor of shape (N, num_words, vocab_size)
   *  - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache
   *  - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache
   *  - out_n_layer_cross_k Same as n_layer_cross_k
   *  - out_n_layer_cross_v Same as n_layer_cross_v
   *  - out_offset Same as offset
   */
  std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
             Ort::Value>
  ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
                 Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
                 Ort::Value n_layer_cross_v, Ort::Value offset) const;

  /** Return the initial (zero-filled) self kv cache in a pair
   *  - n_layer_self_k_cache A 5-D tensor of shape
   *                         (num_decoder_layers, N, max_len, num_head,
   *                          head_dim).
   *  - n_layer_self_v_cache A 5-D tensor of shape
   *                         (num_decoder_layers, N, max_len, num_head,
   *                          head_dim).
   */
  std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const;

  // Metadata read from the encoder model (sos/eos, max_len, cache
  // geometry, CMVN statistics).
  const OfflineFireRedAsrModelMetaData &GetModelMetadata() const;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_
@@ -15,6 +15,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { @@ -15,6 +15,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
15 paraformer.Register(po); 15 paraformer.Register(po);
16 nemo_ctc.Register(po); 16 nemo_ctc.Register(po);
17 whisper.Register(po); 17 whisper.Register(po);
  18 + fire_red_asr.Register(po);
18 tdnn.Register(po); 19 tdnn.Register(po);
19 zipformer_ctc.Register(po); 20 zipformer_ctc.Register(po);
20 wenet_ctc.Register(po); 21 wenet_ctc.Register(po);
@@ -38,7 +39,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { @@ -38,7 +39,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
38 po->Register("model-type", &model_type, 39 po->Register("model-type", &model_type,
39 "Specify it to reduce model initialization time. " 40 "Specify it to reduce model initialization time. "
40 "Valid values are: transducer, paraformer, nemo_ctc, whisper, " 41 "Valid values are: transducer, paraformer, nemo_ctc, whisper, "
41 - "tdnn, zipformer2_ctc, telespeech_ctc." 42 + "tdnn, zipformer2_ctc, telespeech_ctc, fire_red_asr."
42 "All other values lead to loading the model twice."); 43 "All other values lead to loading the model twice.");
43 po->Register("modeling-unit", &modeling_unit, 44 po->Register("modeling-unit", &modeling_unit,
44 "The modeling unit of the model, commonly used units are bpe, " 45 "The modeling unit of the model, commonly used units are bpe, "
@@ -84,6 +85,10 @@ bool OfflineModelConfig::Validate() const { @@ -84,6 +85,10 @@ bool OfflineModelConfig::Validate() const {
84 return whisper.Validate(); 85 return whisper.Validate();
85 } 86 }
86 87
  88 + if (!fire_red_asr.encoder.empty()) {
  89 + return fire_red_asr.Validate();
  90 + }
  91 +
87 if (!tdnn.model.empty()) { 92 if (!tdnn.model.empty()) {
88 return tdnn.Validate(); 93 return tdnn.Validate();
89 } 94 }
@@ -125,6 +130,7 @@ std::string OfflineModelConfig::ToString() const { @@ -125,6 +130,7 @@ std::string OfflineModelConfig::ToString() const {
125 os << "paraformer=" << paraformer.ToString() << ", "; 130 os << "paraformer=" << paraformer.ToString() << ", ";
126 os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; 131 os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
127 os << "whisper=" << whisper.ToString() << ", "; 132 os << "whisper=" << whisper.ToString() << ", ";
  133 + os << "fire_red_asr=" << fire_red_asr.ToString() << ", ";
128 os << "tdnn=" << tdnn.ToString() << ", "; 134 os << "tdnn=" << tdnn.ToString() << ", ";
129 os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; 135 os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
130 os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; 136 os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
@@ -6,6 +6,7 @@ @@ -6,6 +6,7 @@
6 6
7 #include <string> 7 #include <string>
8 8
  9 +#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
9 #include "sherpa-onnx/csrc/offline-moonshine-model-config.h" 10 #include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
10 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" 11 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
11 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" 12 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
@@ -23,6 +24,7 @@ struct OfflineModelConfig { @@ -23,6 +24,7 @@ struct OfflineModelConfig {
23 OfflineParaformerModelConfig paraformer; 24 OfflineParaformerModelConfig paraformer;
24 OfflineNemoEncDecCtcModelConfig nemo_ctc; 25 OfflineNemoEncDecCtcModelConfig nemo_ctc;
25 OfflineWhisperModelConfig whisper; 26 OfflineWhisperModelConfig whisper;
  27 + OfflineFireRedAsrModelConfig fire_red_asr;
26 OfflineTdnnModelConfig tdnn; 28 OfflineTdnnModelConfig tdnn;
27 OfflineZipformerCtcModelConfig zipformer_ctc; 29 OfflineZipformerCtcModelConfig zipformer_ctc;
28 OfflineWenetCtcModelConfig wenet_ctc; 30 OfflineWenetCtcModelConfig wenet_ctc;
@@ -54,6 +56,7 @@ struct OfflineModelConfig { @@ -54,6 +56,7 @@ struct OfflineModelConfig {
54 const OfflineParaformerModelConfig &paraformer, 56 const OfflineParaformerModelConfig &paraformer,
55 const OfflineNemoEncDecCtcModelConfig &nemo_ctc, 57 const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
56 const OfflineWhisperModelConfig &whisper, 58 const OfflineWhisperModelConfig &whisper,
  59 + const OfflineFireRedAsrModelConfig &fire_red_asr,
57 const OfflineTdnnModelConfig &tdnn, 60 const OfflineTdnnModelConfig &tdnn,
58 const OfflineZipformerCtcModelConfig &zipformer_ctc, 61 const OfflineZipformerCtcModelConfig &zipformer_ctc,
59 const OfflineWenetCtcModelConfig &wenet_ctc, 62 const OfflineWenetCtcModelConfig &wenet_ctc,
@@ -68,6 +71,7 @@ struct OfflineModelConfig { @@ -68,6 +71,7 @@ struct OfflineModelConfig {
68 paraformer(paraformer), 71 paraformer(paraformer),
69 nemo_ctc(nemo_ctc), 72 nemo_ctc(nemo_ctc),
70 whisper(whisper), 73 whisper(whisper),
  74 + fire_red_asr(fire_red_asr),
71 tdnn(tdnn), 75 tdnn(tdnn),
72 zipformer_ctc(zipformer_ctc), 76 zipformer_ctc(zipformer_ctc),
73 wenet_ctc(wenet_ctc), 77 wenet_ctc(wenet_ctc),
  1 +// sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h
  2 +//
// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_
  7 +
  8 +#include <algorithm>
  9 +#include <cmath>
  10 +#include <memory>
  11 +#include <string>
  12 +#include <utility>
  13 +#include <vector>
  14 +
  15 +#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h"
  16 +#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h"
  17 +#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"
  18 +#include "sherpa-onnx/csrc/offline-model-config.h"
  19 +#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
  20 +#include "sherpa-onnx/csrc/offline-recognizer.h"
  21 +#include "sherpa-onnx/csrc/symbol-table.h"
  22 +#include "sherpa-onnx/csrc/transpose.h"
  23 +
  24 +namespace sherpa_onnx {
  25 +
  26 +static OfflineRecognitionResult Convert(
  27 + const OfflineFireRedAsrDecoderResult &src, const SymbolTable &sym_table) {
  28 + OfflineRecognitionResult r;
  29 + r.tokens.reserve(src.tokens.size());
  30 +
  31 + std::string text;
  32 + for (auto i : src.tokens) {
  33 + if (!sym_table.Contains(i)) {
  34 + continue;
  35 + }
  36 +
  37 + const auto &s = sym_table[i];
  38 + text += s;
  39 + r.tokens.push_back(s);
  40 + }
  41 +
  42 + r.text = text;
  43 +
  44 + return r;
  45 +}
  46 +
  47 +class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl {
  48 + public:
  49 + explicit OfflineRecognizerFireRedAsrImpl(
  50 + const OfflineRecognizerConfig &config)
  51 + : OfflineRecognizerImpl(config),
  52 + config_(config),
  53 + symbol_table_(config_.model_config.tokens),
  54 + model_(std::make_unique<OfflineFireRedAsrModel>(config.model_config)) {
  55 + Init();
  56 + }
  57 +
  58 + template <typename Manager>
  59 + OfflineRecognizerFireRedAsrImpl(Manager *mgr,
  60 + const OfflineRecognizerConfig &config)
  61 + : OfflineRecognizerImpl(mgr, config),
  62 + config_(config),
  63 + symbol_table_(mgr, config_.model_config.tokens),
  64 + model_(std::make_unique<OfflineFireRedAsrModel>(mgr,
  65 + config.model_config)) {
  66 + Init();
  67 + }
  68 +
  69 + void Init() {
  70 + if (config_.decoding_method == "greedy_search") {
  71 + decoder_ =
  72 + std::make_unique<OfflineFireRedAsrGreedySearchDecoder>(model_.get());
  73 + } else {
  74 + SHERPA_ONNX_LOGE(
  75 + "Only greedy_search is supported at present for FireRedAsr. Given %s",
  76 + config_.decoding_method.c_str());
  77 + SHERPA_ONNX_EXIT(-1);
  78 + }
  79 +
  80 + const auto &meta_data = model_->GetModelMetadata();
  81 +
  82 + config_.feat_config.normalize_samples = false;
  83 + config_.feat_config.high_freq = 0;
  84 + config_.feat_config.snip_edges = true;
  85 + }
  86 +
  87 + std::unique_ptr<OfflineStream> CreateStream() const override {
  88 + return std::make_unique<OfflineStream>(config_.feat_config);
  89 + }
  90 +
  91 + void DecodeStreams(OfflineStream **ss, int32_t n) const override {
  92 + // batch decoding is not implemented yet
  93 + for (int32_t i = 0; i != n; ++i) {
  94 + DecodeStream(ss[i]);
  95 + }
  96 + }
  97 +
  98 + OfflineRecognizerConfig GetConfig() const override { return config_; }
  99 +
  100 + private:
  101 + void DecodeStream(OfflineStream *s) const {
  102 + auto memory_info =
  103 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  104 +
  105 + int32_t feat_dim = s->FeatureDim();
  106 + std::vector<float> f = s->GetFrames();
  107 + ApplyCMVN(&f);
  108 +
  109 + int64_t num_frames = f.size() / feat_dim;
  110 +
  111 + std::array<int64_t, 3> shape{1, num_frames, feat_dim};
  112 +
  113 + Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
  114 + shape.data(), shape.size());
  115 +
  116 + int64_t len_shape = 1;
  117 + Ort::Value x_len =
  118 + Ort::Value::CreateTensor(memory_info, &num_frames, 1, &len_shape, 1);
  119 +
  120 + auto cross_kv = model_->ForwardEncoder(std::move(x), std::move(x_len));
  121 +
  122 + auto results =
  123 + decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
  124 +
  125 + auto r = Convert(results[0], symbol_table_);
  126 +
  127 + r.text = ApplyInverseTextNormalization(std::move(r.text));
  128 + s->SetResult(r);
  129 + }
  130 +
  131 + void ApplyCMVN(std::vector<float> *v) const {
  132 + const auto &meta_data = model_->GetModelMetadata();
  133 + const auto &mean = meta_data.mean;
  134 + const auto &inv_stddev = meta_data.inv_stddev;
  135 + int32_t feat_dim = static_cast<int32_t>(mean.size());
  136 + int32_t num_frames = static_cast<int32_t>(v->size()) / feat_dim;
  137 +
  138 + float *p = v->data();
  139 +
  140 + for (int32_t i = 0; i != num_frames; ++i) {
  141 + for (int32_t k = 0; k != feat_dim; ++k) {
  142 + p[k] = (p[k] - mean[k]) * inv_stddev[k];
  143 + }
  144 +
  145 + p += feat_dim;
  146 + }
  147 + }
  148 +
  149 + private:
  150 + OfflineRecognizerConfig config_;
  151 + SymbolTable symbol_table_;
  152 + std::unique_ptr<OfflineFireRedAsrModel> model_;
  153 + std::unique_ptr<OfflineFireRedAsrDecoder> decoder_;
  154 +};
  155 +
  156 +} // namespace sherpa_onnx
  157 +
  158 +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_
@@ -24,6 +24,7 @@ @@ -24,6 +24,7 @@
24 #include "onnxruntime_cxx_api.h" // NOLINT 24 #include "onnxruntime_cxx_api.h" // NOLINT
25 #include "sherpa-onnx/csrc/macros.h" 25 #include "sherpa-onnx/csrc/macros.h"
26 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" 26 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
  27 +#include "sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h"
27 #include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h" 28 #include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
28 #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" 29 #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
29 #include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h" 30 #include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
@@ -56,6 +57,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -56,6 +57,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
56 return std::make_unique<OfflineRecognizerWhisperImpl>(config); 57 return std::make_unique<OfflineRecognizerWhisperImpl>(config);
57 } 58 }
58 59
  60 + if (!config.model_config.fire_red_asr.encoder.empty()) {
  61 + return std::make_unique<OfflineRecognizerFireRedAsrImpl>(config);
  62 + }
  63 +
59 if (!config.model_config.moonshine.preprocessor.empty()) { 64 if (!config.model_config.moonshine.preprocessor.empty()) {
60 return std::make_unique<OfflineRecognizerMoonshineImpl>(config); 65 return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
61 } 66 }
@@ -237,6 +242,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -237,6 +242,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
237 return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); 242 return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
238 } 243 }
239 244
  245 + if (!config.model_config.fire_red_asr.encoder.empty()) {
  246 + return std::make_unique<OfflineRecognizerFireRedAsrImpl>(mgr, config);
  247 + }
  248 +
240 if (!config.model_config.moonshine.preprocessor.empty()) { 249 if (!config.model_config.moonshine.preprocessor.empty()) {
241 return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config); 250 return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
242 } 251 }
@@ -9,6 +9,7 @@ set(srcs @@ -9,6 +9,7 @@ set(srcs
9 features.cc 9 features.cc
10 keyword-spotter.cc 10 keyword-spotter.cc
11 offline-ctc-fst-decoder-config.cc 11 offline-ctc-fst-decoder-config.cc
  12 + offline-fire-red-asr-model-config.cc
12 offline-lm-config.cc 13 offline-lm-config.cc
13 offline-model-config.cc 14 offline-model-config.cc
14 offline-moonshine-model-config.cc 15 offline-moonshine-model-config.cc
  1 +// sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindOfflineFireRedAsrModelConfig(py::module *m) {
  15 + using PyClass = OfflineFireRedAsrModelConfig;
  16 + py::class_<PyClass>(*m, "OfflineFireRedAsrModelConfig")
  17 + .def(py::init<const std::string &, const std::string &>(),
  18 + py::arg("encoder"), py::arg("decoder"))
  19 + .def_readwrite("encoder", &PyClass::encoder)
  20 + .def_readwrite("decoder", &PyClass::decoder)
  21 + .def("__str__", &PyClass::ToString);
  22 +}
  23 +
  24 +} // namespace sherpa_onnx
// sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

// Registers the OfflineFireRedAsrModelConfig binding on the given module.
void PybindOfflineFireRedAsrModelConfig(py::module *m);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
@@ -8,6 +8,7 @@ @@ -8,6 +8,7 @@
8 #include <vector> 8 #include <vector>
9 9
10 #include "sherpa-onnx/csrc/offline-model-config.h" 10 #include "sherpa-onnx/csrc/offline-model-config.h"
  11 +#include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h"
11 #include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h" 12 #include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
12 #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" 13 #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
13 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" 14 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
@@ -25,6 +26,7 @@ void PybindOfflineModelConfig(py::module *m) { @@ -25,6 +26,7 @@ void PybindOfflineModelConfig(py::module *m) {
25 PybindOfflineParaformerModelConfig(m); 26 PybindOfflineParaformerModelConfig(m);
26 PybindOfflineNemoEncDecCtcModelConfig(m); 27 PybindOfflineNemoEncDecCtcModelConfig(m);
27 PybindOfflineWhisperModelConfig(m); 28 PybindOfflineWhisperModelConfig(m);
  29 + PybindOfflineFireRedAsrModelConfig(m);
28 PybindOfflineTdnnModelConfig(m); 30 PybindOfflineTdnnModelConfig(m);
29 PybindOfflineZipformerCtcModelConfig(m); 31 PybindOfflineZipformerCtcModelConfig(m);
30 PybindOfflineWenetCtcModelConfig(m); 32 PybindOfflineWenetCtcModelConfig(m);
@@ -33,35 +35,38 @@ void PybindOfflineModelConfig(py::module *m) { @@ -33,35 +35,38 @@ void PybindOfflineModelConfig(py::module *m) {
33 35
34 using PyClass = OfflineModelConfig; 36 using PyClass = OfflineModelConfig;
35 py::class_<PyClass>(*m, "OfflineModelConfig") 37 py::class_<PyClass>(*m, "OfflineModelConfig")
36 - .def(  
37 - py::init<  
38 - const OfflineTransducerModelConfig &,  
39 - const OfflineParaformerModelConfig &,  
40 - const OfflineNemoEncDecCtcModelConfig &,  
41 - const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,  
42 - const OfflineZipformerCtcModelConfig &,  
43 - const OfflineWenetCtcModelConfig &,  
44 - const OfflineSenseVoiceModelConfig &,  
45 - const OfflineMoonshineModelConfig &, const std::string &,  
46 - const std::string &, int32_t, bool, const std::string &,  
47 - const std::string &, const std::string &, const std::string &>(),  
48 - py::arg("transducer") = OfflineTransducerModelConfig(),  
49 - py::arg("paraformer") = OfflineParaformerModelConfig(),  
50 - py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),  
51 - py::arg("whisper") = OfflineWhisperModelConfig(),  
52 - py::arg("tdnn") = OfflineTdnnModelConfig(),  
53 - py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),  
54 - py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),  
55 - py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),  
56 - py::arg("moonshine") = OfflineMoonshineModelConfig(),  
57 - py::arg("telespeech_ctc") = "", py::arg("tokens"),  
58 - py::arg("num_threads"), py::arg("debug") = false,  
59 - py::arg("provider") = "cpu", py::arg("model_type") = "",  
60 - py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "") 38 + .def(py::init<const OfflineTransducerModelConfig &,
  39 + const OfflineParaformerModelConfig &,
  40 + const OfflineNemoEncDecCtcModelConfig &,
  41 + const OfflineWhisperModelConfig &,
  42 + const OfflineFireRedAsrModelConfig &,
  43 + const OfflineTdnnModelConfig &,
  44 + const OfflineZipformerCtcModelConfig &,
  45 + const OfflineWenetCtcModelConfig &,
  46 + const OfflineSenseVoiceModelConfig &,
  47 + const OfflineMoonshineModelConfig &, const std::string &,
  48 + const std::string &, int32_t, bool, const std::string &,
  49 + const std::string &, const std::string &,
  50 + const std::string &>(),
  51 + py::arg("transducer") = OfflineTransducerModelConfig(),
  52 + py::arg("paraformer") = OfflineParaformerModelConfig(),
  53 + py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
  54 + py::arg("whisper") = OfflineWhisperModelConfig(),
  55 + py::arg("fire_red_asr") = OfflineFireRedAsrModelConfig(),
  56 + py::arg("tdnn") = OfflineTdnnModelConfig(),
  57 + py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
  58 + py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
  59 + py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
  60 + py::arg("moonshine") = OfflineMoonshineModelConfig(),
  61 + py::arg("telespeech_ctc") = "", py::arg("tokens"),
  62 + py::arg("num_threads"), py::arg("debug") = false,
  63 + py::arg("provider") = "cpu", py::arg("model_type") = "",
  64 + py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
61 .def_readwrite("transducer", &PyClass::transducer) 65 .def_readwrite("transducer", &PyClass::transducer)
62 .def_readwrite("paraformer", &PyClass::paraformer) 66 .def_readwrite("paraformer", &PyClass::paraformer)
63 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) 67 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
64 .def_readwrite("whisper", &PyClass::whisper) 68 .def_readwrite("whisper", &PyClass::whisper)
  69 + .def_readwrite("fire_red_asr", &PyClass::fire_red_asr)
65 .def_readwrite("tdnn", &PyClass::tdnn) 70 .def_readwrite("tdnn", &PyClass::tdnn)
66 .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) 71 .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
67 .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) 72 .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
@@ -6,6 +6,7 @@ from typing import List, Optional @@ -6,6 +6,7 @@ from typing import List, Optional
6 from _sherpa_onnx import ( 6 from _sherpa_onnx import (
7 FeatureExtractorConfig, 7 FeatureExtractorConfig,
8 OfflineCtcFstDecoderConfig, 8 OfflineCtcFstDecoderConfig,
  9 + OfflineFireRedAsrModelConfig,
9 OfflineLMConfig, 10 OfflineLMConfig,
10 OfflineModelConfig, 11 OfflineModelConfig,
11 OfflineMoonshineModelConfig, 12 OfflineMoonshineModelConfig,
@@ -572,6 +573,78 @@ class OfflineRecognizer(object): @@ -572,6 +573,78 @@ class OfflineRecognizer(object):
572 return self 573 return self
573 574
574 @classmethod 575 @classmethod
  576 + def from_fire_red_asr(
  577 + cls,
  578 + encoder: str,
  579 + decoder: str,
  580 + tokens: str,
  581 + num_threads: int = 1,
  582 + decoding_method: str = "greedy_search",
  583 + debug: bool = False,
  584 + provider: str = "cpu",
  585 + rule_fsts: str = "",
  586 + rule_fars: str = "",
  587 + ):
  588 + """
  589 + Please refer to
  590 + `<https://k2-fsa.github.io/sherpa/onnx/fire_red_asr/index.html>`_
  591 + to download pre-trained models for different kinds of FireRedAsr models,
  592 + e.g., xs, large, etc.
  593 +
  594 + Args:
  595 + encoder:
  596 + Path to the encoder model.
  597 + decoder:
  598 + Path to the decoder model.
  599 + tokens:
  600 + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
  601 + columns::
  602 +
  603 + symbol integer_id
  604 + num_threads:
  605 + Number of threads for neural network computation.
  606 + decoding_method:
  607 + Valid values: greedy_search.
  608 + debug:
  609 + True to show debug messages.
  610 + provider:
  611 + onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  612 + rule_fsts:
  613 + If not empty, it specifies fsts for inverse text normalization.
  614 + If there are multiple fsts, they are separated by a comma.
  615 + rule_fars:
  616 + If not empty, it specifies fst archives for inverse text normalization.
  617 + If there are multiple archives, they are separated by a comma.
  618 + """
  619 + self = cls.__new__(cls)
  620 + model_config = OfflineModelConfig(
  621 + fire_red_asr=OfflineFireRedAsrModelConfig(
  622 + encoder=encoder,
  623 + decoder=decoder,
  624 + ),
  625 + tokens=tokens,
  626 + num_threads=num_threads,
  627 + debug=debug,
  628 + provider=provider,
  629 + )
  630 +
  631 + feat_config = FeatureExtractorConfig(
  632 + sampling_rate=16000,
  633 + feature_dim=80,
  634 + )
  635 +
  636 + recognizer_config = OfflineRecognizerConfig(
  637 + feat_config=feat_config,
  638 + model_config=model_config,
  639 + decoding_method=decoding_method,
  640 + rule_fsts=rule_fsts,
  641 + rule_fars=rule_fars,
  642 + )
  643 + self.recognizer = _Recognizer(recognizer_config)
  644 + self.config = recognizer_config
  645 + return self
  646 +
  647 + @classmethod
575 def from_moonshine( 648 def from_moonshine(
576 cls, 649 cls,
577 preprocessor: str, 650 preprocessor: str,