Fangjun Kuang
Committed by GitHub

Support RKNN for Zipformer CTC models. (#1948)

... ... @@ -155,7 +155,9 @@ if(SHERPA_ONNX_ENABLE_RKNN)
list(APPEND sources
./rknn/online-stream-rknn.cc
./rknn/online-transducer-greedy-search-decoder-rknn.cc
./rknn/online-zipformer-ctc-model-rknn.cc
./rknn/online-zipformer-transducer-model-rknn.cc
./rknn/utils.cc
)
endif()
... ...
... ... @@ -43,12 +43,14 @@ class OnlineCtcDecoder {
/** Run streaming CTC decoding given the output from the encoder model.
*
* @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing
* lob_probs.
* @param log_probs A 3-D tensor of shape
* (batch_size, num_frames, vocab_size) containing
* lob_probs in row major.
*
* @param results Input & Output parameters..
*/
virtual void Decode(Ort::Value log_probs,
virtual void Decode(const float *log_probs, int32_t batch_size,
int32_t num_frames, int32_t vocab_size,
std::vector<OnlineCtcDecoderResult> *results,
OnlineStream **ss = nullptr, int32_t n = 0) = 0;
... ...
... ... @@ -91,30 +91,23 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
processed_frames += num_rows;
}
void OnlineCtcFstDecoder::Decode(Ort::Value log_probs,
void OnlineCtcFstDecoder::Decode(const float *log_probs, int32_t batch_size,
int32_t num_frames, int32_t vocab_size,
std::vector<OnlineCtcDecoderResult> *results,
OnlineStream **ss, int32_t n) {
std::vector<int64_t> log_probs_shape =
log_probs.GetTensorTypeAndShapeInfo().GetShape();
if (log_probs_shape[0] != results->size()) {
if (batch_size != results->size()) {
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
static_cast<int32_t>(log_probs_shape[0]),
static_cast<int32_t>(results->size()));
batch_size, static_cast<int32_t>(results->size()));
exit(-1);
}
if (log_probs_shape[0] != n) {
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d",
static_cast<int32_t>(log_probs_shape[0]), n);
if (batch_size != n) {
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d", batch_size,
n);
exit(-1);
}
int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);
const float *p = log_probs.GetTensorData<float>();
const float *p = log_probs;
for (int32_t i = 0; i != batch_size; ++i) {
DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size,
... ...
... ... @@ -19,8 +19,8 @@ class OnlineCtcFstDecoder : public OnlineCtcDecoder {
OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config,
int32_t blank_id);
void Decode(Ort::Value log_probs,
std::vector<OnlineCtcDecoderResult> *results,
void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames,
int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results,
OnlineStream **ss = nullptr, int32_t n = 0) override;
std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
... ...
... ... @@ -13,23 +13,16 @@
namespace sherpa_onnx {
void OnlineCtcGreedySearchDecoder::Decode(
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results,
const float *log_probs, int32_t batch_size, int32_t num_frames,
int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results,
OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) {
std::vector<int64_t> log_probs_shape =
log_probs.GetTensorTypeAndShapeInfo().GetShape();
if (log_probs_shape[0] != results->size()) {
if (batch_size != results->size()) {
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
static_cast<int32_t>(log_probs_shape[0]),
static_cast<int32_t>(results->size()));
batch_size, static_cast<int32_t>(results->size()));
exit(-1);
}
int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);
const float *p = log_probs.GetTensorData<float>();
const float *p = log_probs;
for (int32_t b = 0; b != batch_size; ++b) {
auto &r = (*results)[b];
... ...
... ... @@ -16,8 +16,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder {
explicit OnlineCtcGreedySearchDecoder(int32_t blank_id)
: blank_id_(blank_id) {}
void Decode(Ort::Value log_probs,
std::vector<OnlineCtcDecoderResult> *results,
void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames,
int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results,
OnlineStream **ss = nullptr, int32_t n = 0) override;
private:
... ...
... ... @@ -76,6 +76,15 @@ bool OnlineModelConfig::Validate() const {
transducer.decoder.c_str(), transducer.joiner.c_str());
return false;
}
if (!zipformer2_ctc.model.empty() &&
EndsWith(zipformer2_ctc.model, ".rknn")) {
SHERPA_ONNX_LOGE(
"--provider is %s, which is not rknn, but you pass rknn model "
"filename for zipformer2_ctc: '%s'",
provider_config.provider.c_str(), zipformer2_ctc.model.c_str());
return false;
}
}
if (provider_config.provider == "rknn") {
... ... @@ -89,6 +98,15 @@ bool OnlineModelConfig::Validate() const {
transducer.joiner.c_str());
return false;
}
if (!zipformer2_ctc.model.empty() &&
EndsWith(zipformer2_ctc.model, ".onnx")) {
SHERPA_ONNX_LOGE(
"--provider rknn, but you pass onnx model filename for "
"zipformer2_ctc: '%s'",
zipformer2_ctc.model.c_str());
return false;
}
}
if (!tokens_buf.empty() && FileExists(tokens)) {
... ...
... ... @@ -24,11 +24,10 @@
namespace sherpa_onnx {
static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms,
int32_t subsampling_factor,
int32_t segment,
int32_t subsampling_factor, int32_t segment,
int32_t frames_since_start) {
OnlineRecognizerResult r;
r.tokens.reserve(src.tokens.size());
... ... @@ -182,7 +181,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(std::move(out_states));
decoder_->Decode(std::move(out[0]), &results, ss, n);
std::vector<int64_t> log_probs_shape =
out[0].GetTensorTypeAndShapeInfo().GetShape();
decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0],
log_probs_shape[1], log_probs_shape[2], &results, ss, n);
for (int32_t k = 0; k != n; ++k) {
ss[k]->SetCtcResult(results[k]);
... ... @@ -196,7 +198,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4;
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
auto r =
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(r.text);
return r;
... ... @@ -306,7 +309,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
std::vector<OnlineCtcDecoderResult> results(1);
results[0] = std::move(s->GetCtcResult());
decoder_->Decode(std::move(out[0]), &results, &s, 1);
std::vector<int64_t> log_probs_shape =
out[0].GetTensorTypeAndShapeInfo().GetShape();
decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0],
log_probs_shape[1], log_probs_shape[2], &results, &s, 1);
s->SetCtcResult(results[0]);
}
... ...
... ... @@ -27,6 +27,7 @@
#include "sherpa-onnx/csrc/text-utils.h"
#if SHERPA_ONNX_ENABLE_RKNN
#include "sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h"
#include "sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h"
#endif
... ... @@ -37,12 +38,15 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
if (config.model_config.provider_config.provider == "rknn") {
#if SHERPA_ONNX_ENABLE_RKNN
// Currently, only zipformer v1 is suported for rknn
if (config.model_config.transducer.encoder.empty()) {
if (config.model_config.transducer.encoder.empty() &&
config.model_config.zipformer2_ctc.model.empty()) {
SHERPA_ONNX_LOGE(
"Only Zipformer transducers are currently supported by rknn. "
"Fallback to CPU");
} else {
"Only Zipformer transducers and CTC models are currently supported "
"by rknn. Fallback to CPU");
} else if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<OnlineRecognizerTransducerRknnImpl>(config);
} else if (!config.model_config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcRknnImpl>(config);
}
#else
SHERPA_ONNX_LOGE(
... ...
// sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
#include <algorithm>
#include <ios>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h"
#include "sherpa-onnx/csrc/symbol-table.h"
namespace sherpa_onnx {
// defined in ../online-recognizer-ctc-impl.h
OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms,
int32_t subsampling_factor, int32_t segment,
int32_t frames_since_start);
class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl {
public:
explicit OnlineRecognizerCtcRknnImpl(const OnlineRecognizerConfig &config)
: OnlineRecognizerImpl(config),
config_(config),
model_(
std::make_unique<OnlineZipformerCtcModelRknn>(config.model_config)),
endpoint_(config_.endpoint_config) {
if (!config.model_config.tokens_buf.empty()) {
sym_ = SymbolTable(config.model_config.tokens_buf, false);
} else {
/// assuming tokens_buf and tokens are guaranteed not being both empty
sym_ = SymbolTable(config.model_config.tokens, true);
}
InitDecoder();
}
template <typename Manager>
explicit OnlineRecognizerCtcRknnImpl(Manager *mgr,
const OnlineRecognizerConfig &config)
: OnlineRecognizerImpl(mgr, config),
config_(config),
model_(
std::make_unique<OnlineZipformerCtcModelRknn>(config.model_config)),
sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) {
InitDecoder();
}
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStreamRknn>(config_.feat_config);
stream->SetZipformerEncoderStates(model_->GetInitStates());
stream->SetFasterDecoder(decoder_->CreateFasterDecoder());
return stream;
}
bool IsReady(OnlineStream *s) const override {
return s->GetNumProcessedFrames() + model_->ChunkSize() <
s->NumFramesReady();
}
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
for (int32_t i = 0; i != n; ++i) {
DecodeStream(reinterpret_cast<OnlineStreamRknn *>(ss[i]));
}
}
OnlineRecognizerResult GetResult(OnlineStream *s) const override {
OnlineCtcDecoderResult decoder_result = s->GetCtcResult();
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4;
auto r =
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(r.text);
return r;
}
bool IsEndpoint(OnlineStream *s) const override {
if (!config_.enable_endpoint) {
return false;
}
int32_t num_processed_frames = s->GetNumProcessedFrames();
// frame shift is 10 milliseconds
float frame_shift_in_seconds = 0.01;
// subsampling factor is 4
int32_t trailing_silence_frames = s->GetCtcResult().num_trailing_blanks * 4;
return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
frame_shift_in_seconds);
}
void Reset(OnlineStream *s) const override {
// segment is incremented only when the last
// result is not empty
const auto &r = s->GetCtcResult();
if (!r.tokens.empty()) {
s->GetCurrentSegment() += 1;
}
// clear result
s->SetCtcResult({});
// clear states
reinterpret_cast<OnlineStreamRknn *>(s)->SetZipformerEncoderStates(
model_->GetInitStates());
s->GetFasterDecoderProcessedFrames() = 0;
// Note: We only update counters. The underlying audio samples
// are not discarded.
s->Reset();
}
private:
void InitDecoder() {
if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") &&
!sym_.Contains("<blank>")) {
SHERPA_ONNX_LOGE(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> or <blank> and its ID.");
exit(-1);
}
int32_t blank_id = 0;
if (sym_.Contains("<blk>")) {
blank_id = sym_["<blk>"];
} else if (sym_.Contains("<eps>")) {
// for tdnn models of the yesno recipe from icefall
blank_id = sym_["<eps>"];
} else if (sym_.Contains("<blank>")) {
// for WeNet CTC models
blank_id = sym_["<blank>"];
}
if (!config_.ctc_fst_decoder_config.graph.empty()) {
decoder_ = std::make_unique<OnlineCtcFstDecoder>(
config_.ctc_fst_decoder_config, blank_id);
} else if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
} else {
SHERPA_ONNX_LOGE(
"Unsupported decoding method: %s for streaming CTC models",
config_.decoding_method.c_str());
exit(-1);
}
}
void DecodeStream(OnlineStreamRknn *s) const {
int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();
int32_t feat_dim = s->FeatureDim();
const auto num_processed_frames = s->GetNumProcessedFrames();
std::vector<float> features =
s->GetFrames(num_processed_frames, chunk_size);
s->GetNumProcessedFrames() += chunk_shift;
auto &states = s->GetZipformerEncoderStates();
auto p = model_->Run(features, std::move(states));
states = std::move(p.second);
std::vector<OnlineCtcDecoderResult> results(1);
results[0] = std::move(s->GetCtcResult());
auto attr = model_->GetOutAttr();
decoder_->Decode(p.first.data(), attr.dims[0], attr.dims[1], attr.dims[2],
&results, reinterpret_cast<OnlineStream **>(&s), 1);
s->SetCtcResult(results[0]);
}
private:
OnlineRecognizerConfig config_;
std::unique_ptr<OnlineZipformerCtcModelRknn> model_;
std::unique_ptr<OnlineCtcDecoder> decoder_;
SymbolTable sym_;
Endpoint endpoint_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
... ...
// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h"
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/rknn/macros.h"
#include "sherpa-onnx/csrc/rknn/utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OnlineZipformerCtcModelRknn::Impl {
public:
~Impl() {
auto ret = rknn_destroy(ctx_);
if (ret != RKNN_SUCC) {
SHERPA_ONNX_LOGE("Failed to destroy the context");
}
}
explicit Impl(const OnlineModelConfig &config) : config_(config) {
{
auto buf = ReadFile(config.zipformer2_ctc.model);
Init(buf.data(), buf.size());
}
int32_t ret = RKNN_SUCC;
switch (config_.num_threads) {
case 1:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_AUTO);
break;
case 0:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0);
break;
case -1:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_1);
break;
case -2:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_2);
break;
case -3:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1);
break;
case -4:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1_2);
break;
default:
SHERPA_ONNX_LOGE(
"Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
"1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
config_.num_threads);
break;
}
if (ret != RKNN_SUCC) {
SHERPA_ONNX_LOGE(
"Failed to select npu core to run the model (You can ignore it if "
"you "
"are not using RK3588.");
}
}
// TODO(fangjun): Support Android
std::vector<std::vector<uint8_t>> GetInitStates() const {
// input_attrs_[0] is for the feature
// input_attrs_[1:] is for states
// so we use -1 here
std::vector<std::vector<uint8_t>> states(input_attrs_.size() - 1);
int32_t i = -1;
for (auto &attr : input_attrs_) {
i += 1;
if (i == 0) {
// skip processing the attr for features.
continue;
}
if (attr.type == RKNN_TENSOR_FLOAT16) {
states[i - 1].resize(attr.n_elems * sizeof(float));
} else if (attr.type == RKNN_TENSOR_INT64) {
states[i - 1].resize(attr.n_elems * sizeof(int64_t));
} else {
SHERPA_ONNX_LOGE("Unsupported tensor type: %d, %s", attr.type,
get_type_string(attr.type));
SHERPA_ONNX_EXIT(-1);
}
}
return states;
}
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run(
std::vector<float> features,
std::vector<std::vector<uint8_t>> states) const {
std::vector<rknn_input> inputs(input_attrs_.size());
for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
auto &input = inputs[i];
auto &attr = input_attrs_[i];
input.index = attr.index;
if (attr.type == RKNN_TENSOR_FLOAT16) {
input.type = RKNN_TENSOR_FLOAT32;
} else if (attr.type == RKNN_TENSOR_INT64) {
input.type = RKNN_TENSOR_INT64;
} else {
SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
get_type_string(attr.type));
SHERPA_ONNX_EXIT(-1);
}
input.fmt = attr.fmt;
if (i == 0) {
input.buf = reinterpret_cast<void *>(features.data());
input.size = features.size() * sizeof(float);
} else {
input.buf = reinterpret_cast<void *>(states[i - 1].data());
input.size = states[i - 1].size();
}
}
std::vector<float> out(output_attrs_[0].n_elems);
// Note(fangjun): We can reuse the memory from input argument `states`
// auto next_states = GetInitStates();
auto &next_states = states;
std::vector<rknn_output> outputs(output_attrs_.size());
for (int32_t i = 0; i < outputs.size(); ++i) {
auto &output = outputs[i];
auto &attr = output_attrs_[i];
output.index = attr.index;
output.is_prealloc = 1;
if (attr.type == RKNN_TENSOR_FLOAT16) {
output.want_float = 1;
} else if (attr.type == RKNN_TENSOR_INT64) {
output.want_float = 0;
} else {
SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
get_type_string(attr.type));
SHERPA_ONNX_EXIT(-1);
}
if (i == 0) {
output.size = out.size() * sizeof(float);
output.buf = reinterpret_cast<void *>(out.data());
} else {
output.size = next_states[i - 1].size();
output.buf = reinterpret_cast<void *>(next_states[i - 1].data());
}
}
auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data());
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");
ret = rknn_run(ctx_, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");
ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");
for (int32_t i = 0; i < next_states.size(); ++i) {
const auto &attr = input_attrs_[i + 1];
if (attr.n_dims == 4) {
// TODO(fangjun): The transpose is copied from
// https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22
// I don't understand why we need to do that.
std::vector<uint8_t> dst(next_states[i].size());
int32_t n = attr.dims[0];
int32_t h = attr.dims[1];
int32_t w = attr.dims[2];
int32_t c = attr.dims[3];
ConvertNCHWtoNHWC(
reinterpret_cast<const float *>(next_states[i].data()), n, c, h, w,
reinterpret_cast<float *>(dst.data()));
next_states[i] = std::move(dst);
}
}
return {std::move(out), std::move(next_states)};
}
int32_t ChunkSize() const { return T_; }
int32_t ChunkShift() const { return decode_chunk_len_; }
int32_t VocabSize() const { return vocab_size_; }
rknn_tensor_attr GetOutAttr() const { return output_attrs_[0]; }
private:
void Init(void *model_data, size_t model_data_length) {
auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init model '%s'",
config_.zipformer2_ctc.model.c_str());
if (config_.debug) {
rknn_sdk_version v;
ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
v.drv_version);
}
rknn_input_output_num io_num;
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
if (config_.debug) {
SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
static_cast<int32_t>(io_num.n_input),
static_cast<int32_t>(io_num.n_output));
}
input_attrs_.resize(io_num.n_input);
output_attrs_.resize(io_num.n_output);
int32_t i = 0;
for (auto &attr : input_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
i += 1;
}
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : input_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
os.str().c_str());
}
i = 0;
for (auto &attr : output_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
i += 1;
}
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : output_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
os.str().c_str());
}
rknn_custom_string custom_string;
ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
sizeof(custom_string));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
if (config_.debug) {
SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
}
auto meta = Parse(custom_string);
if (config_.debug) {
for (const auto &p : meta) {
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
}
}
if (meta.count("T")) {
T_ = atoi(meta.at("T").c_str());
}
if (meta.count("decode_chunk_len")) {
decode_chunk_len_ = atoi(meta.at("decode_chunk_len").c_str());
}
vocab_size_ = output_attrs_[0].dims[2];
if (config_.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("T: %{public}d", T_);
SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_);
SHERPA_ONNX_LOGE("vocab_size: %{public}d", vocab_size);
#else
SHERPA_ONNX_LOGE("T: %d", T_);
SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
SHERPA_ONNX_LOGE("vocab_size: %d", vocab_size_);
#endif
}
if (T_ == 0) {
SHERPA_ONNX_LOGE(
"Invalid T. Please use the script from icefall to export your model");
SHERPA_ONNX_EXIT(-1);
}
if (decode_chunk_len_ == 0) {
SHERPA_ONNX_LOGE(
"Invalid decode_chunk_len. Please use the script from icefall to "
"export your model");
SHERPA_ONNX_EXIT(-1);
}
}
private:
OnlineModelConfig config_;
rknn_context ctx_ = 0;
std::vector<rknn_tensor_attr> input_attrs_;
std::vector<rknn_tensor_attr> output_attrs_;
int32_t T_ = 0;
int32_t decode_chunk_len_ = 0;
int32_t vocab_size_ = 0;
};
OnlineZipformerCtcModelRknn::~OnlineZipformerCtcModelRknn() = default;
OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
template <typename Manager>
OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
Manager *mgr, const OnlineModelConfig &config)
: impl_(std::make_unique<OnlineZipformerCtcModelRknn>(mgr, config)) {}
std::vector<std::vector<uint8_t>> OnlineZipformerCtcModelRknn::GetInitStates()
const {
return impl_->GetInitStates();
}
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>>
OnlineZipformerCtcModelRknn::Run(
std::vector<float> features,
std::vector<std::vector<uint8_t>> states) const {
return impl_->Run(std::move(features), std::move(states));
}
int32_t OnlineZipformerCtcModelRknn::ChunkSize() const {
return impl_->ChunkSize();
}
int32_t OnlineZipformerCtcModelRknn::ChunkShift() const {
return impl_->ChunkShift();
}
int32_t OnlineZipformerCtcModelRknn::VocabSize() const {
return impl_->VocabSize();
}
rknn_tensor_attr OnlineZipformerCtcModelRknn::GetOutAttr() const {
return impl_->GetOutAttr();
}
#if __ANDROID_API__ >= 9
template OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
AAssetManager *mgr, const OnlineModelConfig &config);
#endif
#if __OHOS__
template OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
NativeResourceManager *mgr, const OnlineModelConfig &config);
#endif
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
#include <memory>
#include <utility>
#include <vector>
#include "rknn_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-model-config.h"
namespace sherpa_onnx {
class OnlineZipformerCtcModelRknn {
public:
~OnlineZipformerCtcModelRknn();
explicit OnlineZipformerCtcModelRknn(const OnlineModelConfig &config);
template <typename Manager>
OnlineZipformerCtcModelRknn(Manager *mgr, const OnlineModelConfig &config);
std::vector<std::vector<uint8_t>> GetInitStates() const;
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run(
std::vector<float> features,
std::vector<std::vector<uint8_t>> states) const;
int32_t ChunkSize() const;
int32_t ChunkShift() const;
int32_t VocabSize() const;
rknn_tensor_attr GetOutAttr() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
... ...
// sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
... ... @@ -22,68 +22,11 @@
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/rknn/macros.h"
#include "sherpa-onnx/csrc/rknn/utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
// chw -> hwc
static void Transpose(const float *src, int32_t n, int32_t channel,
int32_t height, int32_t width, float *dst) {
for (int32_t i = 0; i < n; ++i) {
for (int32_t h = 0; h < height; ++h) {
for (int32_t w = 0; w < width; ++w) {
for (int32_t c = 0; c < channel; ++c) {
// dst[h, w, c] = src[c, h, w]
dst[i * height * width * channel + h * width * channel + w * channel +
c] = src[i * height * width * channel + c * height * width +
h * width + w];
}
}
}
}
}
static std::string ToString(const rknn_tensor_attr &attr) {
std::ostringstream os;
os << "{";
os << attr.index;
os << ", name: " << attr.name;
os << ", shape: (";
std::string sep;
for (int32_t i = 0; i < static_cast<int32_t>(attr.n_dims); ++i) {
os << sep << attr.dims[i];
sep = ",";
}
os << ")";
os << ", n_elems: " << attr.n_elems;
os << ", size: " << attr.size;
os << ", fmt: " << get_format_string(attr.fmt);
os << ", type: " << get_type_string(attr.type);
os << ", pass_through: " << (attr.pass_through ? "true" : "false");
os << "}";
return os.str();
}
static std::unordered_map<std::string, std::string> Parse(
const rknn_custom_string &custom_string) {
std::unordered_map<std::string, std::string> ans;
std::vector<std::string> fields;
SplitStringToVector(custom_string.string, ";", false, &fields);
std::vector<std::string> tmp;
for (const auto &f : fields) {
SplitStringToVector(f, "=", false, &tmp);
if (tmp.size() != 2) {
SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string,
f.c_str());
SHERPA_ONNX_EXIT(-1);
}
ans[std::move(tmp[0])] = std::move(tmp[1]);
}
return ans;
}
class OnlineZipformerTransducerModelRknn::Impl {
public:
~Impl() {
... ... @@ -285,7 +228,7 @@ class OnlineZipformerTransducerModelRknn::Impl {
for (int32_t i = 0; i < next_states.size(); ++i) {
const auto &attr = encoder_input_attrs_[i + 1];
if (attr.n_dims == 4) {
// TODO(fangjun): The transpose is copied from
// TODO(fangjun): The ConvertNCHWtoNHWC is copied from
// https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22
// I don't understand why we need to do that.
std::vector<uint8_t> dst(next_states[i].size());
... ... @@ -293,8 +236,9 @@ class OnlineZipformerTransducerModelRknn::Impl {
int32_t h = attr.dims[1];
int32_t w = attr.dims[2];
int32_t c = attr.dims[3];
Transpose(reinterpret_cast<const float *>(next_states[i].data()), n, c,
h, w, reinterpret_cast<float *>(dst.data()));
ConvertNCHWtoNHWC(
reinterpret_cast<const float *>(next_states[i].data()), n, c, h, w,
reinterpret_cast<float *>(dst.data()));
next_states[i] = std::move(dst);
}
}
... ... @@ -527,11 +471,9 @@ class OnlineZipformerTransducerModelRknn::Impl {
#if __OHOS__
SHERPA_ONNX_LOGE("T: %{public}d", T_);
SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_);
SHERPA_ONNX_LOGE("context_size: %{public}d", context_size_);
#else
SHERPA_ONNX_LOGE("T: %d", T_);
SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
SHERPA_ONNX_LOGE("context_size: %d", context_size_);
#endif
}
}
... ... @@ -597,6 +539,11 @@ class OnlineZipformerTransducerModelRknn::Impl {
SHERPA_ONNX_EXIT(-1);
}
context_size_ = decoder_input_attrs_[0].dims[1];
if (config_.debug) {
SHERPA_ONNX_LOGE("context_size: %d", context_size_);
}
i = 0;
for (auto &attr : decoder_output_attrs_) {
memset(&attr, 0, sizeof(attr));
... ...
... ... @@ -14,8 +14,11 @@
namespace sherpa_onnx {
// this is for zipformer v1, i.e., the folder
// pruned_transducer_statelss7_streaming from icefall
// this is for zipformer v1 and v2, i.e., the folder
// pruned_transducer_statelss7_streaming
// and
// zipformer
// from icefall
class OnlineZipformerTransducerModelRknn {
public:
~OnlineZipformerTransducerModelRknn();
... ...
// sherpa-onnx/csrc/utils.cc
//
// Copyright 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/rknn/utils.h"
#include <sstream>
#include <unordered_map>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
int32_t height, int32_t width, float *dst) {
for (int32_t i = 0; i < n; ++i) {
for (int32_t h = 0; h < height; ++h) {
for (int32_t w = 0; w < width; ++w) {
for (int32_t c = 0; c < channel; ++c) {
// dst[h, w, c] = src[c, h, w]
dst[i * height * width * channel + h * width * channel + w * channel +
c] = src[i * height * width * channel + c * height * width +
h * width + w];
}
}
}
}
}
std::string ToString(const rknn_tensor_attr &attr) {
std::ostringstream os;
os << "{";
os << attr.index;
os << ", name: " << attr.name;
os << ", shape: (";
std::string sep;
for (int32_t i = 0; i < static_cast<int32_t>(attr.n_dims); ++i) {
os << sep << attr.dims[i];
sep = ",";
}
os << ")";
os << ", n_elems: " << attr.n_elems;
os << ", size: " << attr.size;
os << ", fmt: " << get_format_string(attr.fmt);
os << ", type: " << get_type_string(attr.type);
os << ", pass_through: " << (attr.pass_through ? "true" : "false");
os << "}";
return os.str();
}
std::unordered_map<std::string, std::string> Parse(
const rknn_custom_string &custom_string) {
std::unordered_map<std::string, std::string> ans;
std::vector<std::string> fields;
SplitStringToVector(custom_string.string, ";", false, &fields);
std::vector<std::string> tmp;
for (const auto &f : fields) {
SplitStringToVector(f, "=", false, &tmp);
if (tmp.size() != 2) {
SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string,
f.c_str());
SHERPA_ONNX_EXIT(-1);
}
ans[std::move(tmp[0])] = std::move(tmp[1]);
}
return ans;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/utils.h
//
// Copyright 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_UTILS_H_
#define SHERPA_ONNX_CSRC_RKNN_UTILS_H_
#include <string>
#include <unordered_map>
#include "rknn_api.h" // NOLINT
namespace sherpa_onnx {
void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
int32_t height, int32_t width, float *dst);
std::string ToString(const rknn_tensor_attr &attr);
std::unordered_map<std::string, std::string> Parse(
const rknn_custom_string &custom_string);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_UTILS_H_
... ...
... ... @@ -83,6 +83,7 @@ for a list of pre-trained models to download.
po.Read(argc, argv);
if (po.NumArgs() < 1) {
po.PrintUsage();
fprintf(stderr, "Error! Please provide at lease 1 wav file\n");
exit(EXIT_FAILURE);
}
... ...