Fangjun Kuang
Committed by GitHub

Support RKNN for Zipformer CTC models. (#1948)

@@ -155,7 +155,9 @@ if(SHERPA_ONNX_ENABLE_RKNN) @@ -155,7 +155,9 @@ if(SHERPA_ONNX_ENABLE_RKNN)
155 list(APPEND sources 155 list(APPEND sources
156 ./rknn/online-stream-rknn.cc 156 ./rknn/online-stream-rknn.cc
157 ./rknn/online-transducer-greedy-search-decoder-rknn.cc 157 ./rknn/online-transducer-greedy-search-decoder-rknn.cc
  158 + ./rknn/online-zipformer-ctc-model-rknn.cc
158 ./rknn/online-zipformer-transducer-model-rknn.cc 159 ./rknn/online-zipformer-transducer-model-rknn.cc
  160 + ./rknn/utils.cc
159 ) 161 )
160 162
161 endif() 163 endif()
@@ -43,12 +43,14 @@ class OnlineCtcDecoder { @@ -43,12 +43,14 @@ class OnlineCtcDecoder {
43 43
44 /** Run streaming CTC decoding given the output from the encoder model. 44 /** Run streaming CTC decoding given the output from the encoder model.
45 * 45 *
46 - * @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing  
47 - * lob_probs. 46 + * @param log_probs A 3-D tensor of shape
  47 + * (batch_size, num_frames, vocab_size) containing
  47 + * log_probs in row major.
48 * 49 *
49 * @param results Input & Output parameters.. 50 * @param results Input & Output parameters..
50 */ 51 */
51 - virtual void Decode(Ort::Value log_probs, 52 + virtual void Decode(const float *log_probs, int32_t batch_size,
  53 + int32_t num_frames, int32_t vocab_size,
52 std::vector<OnlineCtcDecoderResult> *results, 54 std::vector<OnlineCtcDecoderResult> *results,
53 OnlineStream **ss = nullptr, int32_t n = 0) = 0; 55 OnlineStream **ss = nullptr, int32_t n = 0) = 0;
54 56
@@ -91,30 +91,23 @@ static void DecodeOne(const float *log_probs, int32_t num_rows, @@ -91,30 +91,23 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
91 processed_frames += num_rows; 91 processed_frames += num_rows;
92 } 92 }
93 93
94 -void OnlineCtcFstDecoder::Decode(Ort::Value log_probs, 94 +void OnlineCtcFstDecoder::Decode(const float *log_probs, int32_t batch_size,
  95 + int32_t num_frames, int32_t vocab_size,
95 std::vector<OnlineCtcDecoderResult> *results, 96 std::vector<OnlineCtcDecoderResult> *results,
96 OnlineStream **ss, int32_t n) { 97 OnlineStream **ss, int32_t n) {
97 - std::vector<int64_t> log_probs_shape =  
98 - log_probs.GetTensorTypeAndShapeInfo().GetShape();  
99 -  
100 - if (log_probs_shape[0] != results->size()) { 98 + if (batch_size != results->size()) {
101 SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", 99 SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
102 - static_cast<int32_t>(log_probs_shape[0]),  
103 - static_cast<int32_t>(results->size())); 100 + batch_size, static_cast<int32_t>(results->size()));
104 exit(-1); 101 exit(-1);
105 } 102 }
106 103
107 - if (log_probs_shape[0] != n) {  
108 - SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d",  
109 - static_cast<int32_t>(log_probs_shape[0]), n); 104 + if (batch_size != n) {
  105 + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d", batch_size,
  106 + n);
110 exit(-1); 107 exit(-1);
111 } 108 }
112 109
113 - int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);  
114 - int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);  
115 - int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);  
116 -  
117 - const float *p = log_probs.GetTensorData<float>(); 110 + const float *p = log_probs;
118 111
119 for (int32_t i = 0; i != batch_size; ++i) { 112 for (int32_t i = 0; i != batch_size; ++i) {
120 DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size, 113 DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size,
@@ -19,8 +19,8 @@ class OnlineCtcFstDecoder : public OnlineCtcDecoder { @@ -19,8 +19,8 @@ class OnlineCtcFstDecoder : public OnlineCtcDecoder {
19 OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config, 19 OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config,
20 int32_t blank_id); 20 int32_t blank_id);
21 21
22 - void Decode(Ort::Value log_probs,  
23 - std::vector<OnlineCtcDecoderResult> *results, 22 + void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames,
  23 + int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results,
24 OnlineStream **ss = nullptr, int32_t n = 0) override; 24 OnlineStream **ss = nullptr, int32_t n = 0) override;
25 25
26 std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder() 26 std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
@@ -13,23 +13,16 @@ @@ -13,23 +13,16 @@
13 namespace sherpa_onnx { 13 namespace sherpa_onnx {
14 14
15 void OnlineCtcGreedySearchDecoder::Decode( 15 void OnlineCtcGreedySearchDecoder::Decode(
16 - Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results, 16 + const float *log_probs, int32_t batch_size, int32_t num_frames,
  17 + int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results,
17 OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) { 18 OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) {
18 - std::vector<int64_t> log_probs_shape =  
19 - log_probs.GetTensorTypeAndShapeInfo().GetShape();  
20 -  
21 - if (log_probs_shape[0] != results->size()) { 19 + if (batch_size != results->size()) {
22 SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", 20 SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
23 - static_cast<int32_t>(log_probs_shape[0]),  
24 - static_cast<int32_t>(results->size())); 21 + batch_size, static_cast<int32_t>(results->size()));
25 exit(-1); 22 exit(-1);
26 } 23 }
27 24
28 - int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);  
29 - int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);  
30 - int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);  
31 -  
32 - const float *p = log_probs.GetTensorData<float>(); 25 + const float *p = log_probs;
33 26
34 for (int32_t b = 0; b != batch_size; ++b) { 27 for (int32_t b = 0; b != batch_size; ++b) {
35 auto &r = (*results)[b]; 28 auto &r = (*results)[b];
@@ -16,8 +16,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder { @@ -16,8 +16,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder {
16 explicit OnlineCtcGreedySearchDecoder(int32_t blank_id) 16 explicit OnlineCtcGreedySearchDecoder(int32_t blank_id)
17 : blank_id_(blank_id) {} 17 : blank_id_(blank_id) {}
18 18
19 - void Decode(Ort::Value log_probs,  
20 - std::vector<OnlineCtcDecoderResult> *results, 19 + void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames,
  20 + int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results,
21 OnlineStream **ss = nullptr, int32_t n = 0) override; 21 OnlineStream **ss = nullptr, int32_t n = 0) override;
22 22
23 private: 23 private:
@@ -76,6 +76,15 @@ bool OnlineModelConfig::Validate() const { @@ -76,6 +76,15 @@ bool OnlineModelConfig::Validate() const {
76 transducer.decoder.c_str(), transducer.joiner.c_str()); 76 transducer.decoder.c_str(), transducer.joiner.c_str());
77 return false; 77 return false;
78 } 78 }
  79 +
  80 + if (!zipformer2_ctc.model.empty() &&
  81 + EndsWith(zipformer2_ctc.model, ".rknn")) {
  82 + SHERPA_ONNX_LOGE(
  83 + "--provider is %s, which is not rknn, but you pass rknn model "
  84 + "filename for zipformer2_ctc: '%s'",
  85 + provider_config.provider.c_str(), zipformer2_ctc.model.c_str());
  86 + return false;
  87 + }
79 } 88 }
80 89
81 if (provider_config.provider == "rknn") { 90 if (provider_config.provider == "rknn") {
@@ -89,6 +98,15 @@ bool OnlineModelConfig::Validate() const { @@ -89,6 +98,15 @@ bool OnlineModelConfig::Validate() const {
89 transducer.joiner.c_str()); 98 transducer.joiner.c_str());
90 return false; 99 return false;
91 } 100 }
  101 +
  102 + if (!zipformer2_ctc.model.empty() &&
  103 + EndsWith(zipformer2_ctc.model, ".onnx")) {
  104 + SHERPA_ONNX_LOGE(
  105 + "--provider rknn, but you pass onnx model filename for "
  106 + "zipformer2_ctc: '%s'",
  107 + zipformer2_ctc.model.c_str());
  108 + return false;
  109 + }
92 } 110 }
93 111
94 if (!tokens_buf.empty() && FileExists(tokens)) { 112 if (!tokens_buf.empty() && FileExists(tokens)) {
@@ -24,12 +24,11 @@ @@ -24,12 +24,11 @@
24 24
25 namespace sherpa_onnx { 25 namespace sherpa_onnx {
26 26
27 -static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,  
28 - const SymbolTable &sym_table,  
29 - float frame_shift_ms,  
30 - int32_t subsampling_factor,  
31 - int32_t segment,  
32 - int32_t frames_since_start) { 27 +OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src,
  28 + const SymbolTable &sym_table,
  29 + float frame_shift_ms,
  30 + int32_t subsampling_factor, int32_t segment,
  31 + int32_t frames_since_start) {
33 OnlineRecognizerResult r; 32 OnlineRecognizerResult r;
34 r.tokens.reserve(src.tokens.size()); 33 r.tokens.reserve(src.tokens.size());
35 r.timestamps.reserve(src.tokens.size()); 34 r.timestamps.reserve(src.tokens.size());
@@ -182,7 +181,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { @@ -182,7 +181,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
182 std::vector<std::vector<Ort::Value>> next_states = 181 std::vector<std::vector<Ort::Value>> next_states =
183 model_->UnStackStates(std::move(out_states)); 182 model_->UnStackStates(std::move(out_states));
184 183
185 - decoder_->Decode(std::move(out[0]), &results, ss, n); 184 + std::vector<int64_t> log_probs_shape =
  185 + out[0].GetTensorTypeAndShapeInfo().GetShape();
  186 + decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0],
  187 + log_probs_shape[1], log_probs_shape[2], &results, ss, n);
186 188
187 for (int32_t k = 0; k != n; ++k) { 189 for (int32_t k = 0; k != n; ++k) {
188 ss[k]->SetCtcResult(results[k]); 190 ss[k]->SetCtcResult(results[k]);
@@ -196,8 +198,9 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { @@ -196,8 +198,9 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
196 // TODO(fangjun): Remember to change these constants if needed 198 // TODO(fangjun): Remember to change these constants if needed
197 int32_t frame_shift_ms = 10; 199 int32_t frame_shift_ms = 10;
198 int32_t subsampling_factor = 4; 200 int32_t subsampling_factor = 4;
199 - auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,  
200 - s->GetCurrentSegment(), s->GetNumFramesSinceStart()); 201 + auto r =
  202 + ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
  203 + s->GetCurrentSegment(), s->GetNumFramesSinceStart());
201 r.text = ApplyInverseTextNormalization(r.text); 204 r.text = ApplyInverseTextNormalization(r.text);
202 return r; 205 return r;
203 } 206 }
@@ -306,7 +309,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { @@ -306,7 +309,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
306 std::vector<OnlineCtcDecoderResult> results(1); 309 std::vector<OnlineCtcDecoderResult> results(1);
307 results[0] = std::move(s->GetCtcResult()); 310 results[0] = std::move(s->GetCtcResult());
308 311
309 - decoder_->Decode(std::move(out[0]), &results, &s, 1); 312 + std::vector<int64_t> log_probs_shape =
  313 + out[0].GetTensorTypeAndShapeInfo().GetShape();
  314 + decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0],
  315 + log_probs_shape[1], log_probs_shape[2], &results, &s, 1);
310 s->SetCtcResult(results[0]); 316 s->SetCtcResult(results[0]);
311 } 317 }
312 318
@@ -27,6 +27,7 @@ @@ -27,6 +27,7 @@
27 #include "sherpa-onnx/csrc/text-utils.h" 27 #include "sherpa-onnx/csrc/text-utils.h"
28 28
29 #if SHERPA_ONNX_ENABLE_RKNN 29 #if SHERPA_ONNX_ENABLE_RKNN
  30 +#include "sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h"
30 #include "sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h" 31 #include "sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h"
31 #endif 32 #endif
32 33
@@ -37,12 +38,15 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( @@ -37,12 +38,15 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
37 if (config.model_config.provider_config.provider == "rknn") { 38 if (config.model_config.provider_config.provider == "rknn") {
38 #if SHERPA_ONNX_ENABLE_RKNN 39 #if SHERPA_ONNX_ENABLE_RKNN
39 // Currently, only zipformer v1 is supported for rknn 40 // Currently, only zipformer v1 is supported for rknn
40 - if (config.model_config.transducer.encoder.empty()) { 41 + if (config.model_config.transducer.encoder.empty() &&
  42 + config.model_config.zipformer2_ctc.model.empty()) {
41 SHERPA_ONNX_LOGE( 43 SHERPA_ONNX_LOGE(
42 - "Only Zipformer transducers are currently supported by rknn. "  
43 - "Fallback to CPU");  
44 - } else { 44 + "Only Zipformer transducers and CTC models are currently supported "
  45 + "by rknn. Fallback to CPU");
  46 + } else if (!config.model_config.transducer.encoder.empty()) {
45 return std::make_unique<OnlineRecognizerTransducerRknnImpl>(config); 47 return std::make_unique<OnlineRecognizerTransducerRknnImpl>(config);
  48 + } else if (!config.model_config.zipformer2_ctc.model.empty()) {
  49 + return std::make_unique<OnlineRecognizerCtcRknnImpl>(config);
46 } 50 }
47 #else 51 #else
48 SHERPA_ONNX_LOGE( 52 SHERPA_ONNX_LOGE(
  1 +// sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
  7 +
  8 +#include <algorithm>
  9 +#include <ios>
  10 +#include <memory>
  11 +#include <sstream>
  12 +#include <string>
  13 +#include <utility>
  14 +#include <vector>
  15 +
  16 +#include "sherpa-onnx/csrc/file-utils.h"
  17 +#include "sherpa-onnx/csrc/macros.h"
  18 +#include "sherpa-onnx/csrc/online-ctc-decoder.h"
  19 +#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
  20 +#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
  21 +#include "sherpa-onnx/csrc/online-recognizer-impl.h"
  22 +#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h"
  23 +#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h"
  24 +#include "sherpa-onnx/csrc/symbol-table.h"
  25 +
  26 +namespace sherpa_onnx {
  27 +
  28 +// defined in ../online-recognizer-ctc-impl.h
  29 +OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src,
  30 + const SymbolTable &sym_table,
  31 + float frame_shift_ms,
  32 + int32_t subsampling_factor, int32_t segment,
  33 + int32_t frames_since_start);
  34 +
  35 +class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl {
  36 + public:
  37 + explicit OnlineRecognizerCtcRknnImpl(const OnlineRecognizerConfig &config)
  38 + : OnlineRecognizerImpl(config),
  39 + config_(config),
  40 + model_(
  41 + std::make_unique<OnlineZipformerCtcModelRknn>(config.model_config)),
  42 + endpoint_(config_.endpoint_config) {
  43 + if (!config.model_config.tokens_buf.empty()) {
  44 + sym_ = SymbolTable(config.model_config.tokens_buf, false);
  45 + } else {
  46 + /// assuming tokens_buf and tokens are guaranteed not being both empty
  47 + sym_ = SymbolTable(config.model_config.tokens, true);
  48 + }
  49 +
  50 + InitDecoder();
  51 + }
  52 +
  53 + template <typename Manager>
  54 + explicit OnlineRecognizerCtcRknnImpl(Manager *mgr,
  55 + const OnlineRecognizerConfig &config)
  56 + : OnlineRecognizerImpl(mgr, config),
  57 + config_(config),
  58 + model_(
  59 + std::make_unique<OnlineZipformerCtcModelRknn>(config.model_config)),
  60 + sym_(mgr, config.model_config.tokens),
  61 + endpoint_(config_.endpoint_config) {
  62 + InitDecoder();
  63 + }
  64 +
  65 + std::unique_ptr<OnlineStream> CreateStream() const override {
  66 + auto stream = std::make_unique<OnlineStreamRknn>(config_.feat_config);
  67 + stream->SetZipformerEncoderStates(model_->GetInitStates());
  68 + stream->SetFasterDecoder(decoder_->CreateFasterDecoder());
  69 + return stream;
  70 + }
  71 +
  72 + bool IsReady(OnlineStream *s) const override {
  73 + return s->GetNumProcessedFrames() + model_->ChunkSize() <
  74 + s->NumFramesReady();
  75 + }
  76 +
  77 + void DecodeStreams(OnlineStream **ss, int32_t n) const override {
  78 + for (int32_t i = 0; i != n; ++i) {
  79 + DecodeStream(reinterpret_cast<OnlineStreamRknn *>(ss[i]));
  80 + }
  81 + }
  82 +
  83 + OnlineRecognizerResult GetResult(OnlineStream *s) const override {
  84 + OnlineCtcDecoderResult decoder_result = s->GetCtcResult();
  85 +
  86 + // TODO(fangjun): Remember to change these constants if needed
  87 + int32_t frame_shift_ms = 10;
  88 + int32_t subsampling_factor = 4;
  89 + auto r =
  90 + ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
  91 + s->GetCurrentSegment(), s->GetNumFramesSinceStart());
  92 + r.text = ApplyInverseTextNormalization(r.text);
  93 + return r;
  94 + }
  95 +
  96 + bool IsEndpoint(OnlineStream *s) const override {
  97 + if (!config_.enable_endpoint) {
  98 + return false;
  99 + }
  100 +
  101 + int32_t num_processed_frames = s->GetNumProcessedFrames();
  102 +
  103 + // frame shift is 10 milliseconds
  104 + float frame_shift_in_seconds = 0.01;
  105 +
  106 + // subsampling factor is 4
  107 + int32_t trailing_silence_frames = s->GetCtcResult().num_trailing_blanks * 4;
  108 +
  109 + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
  110 + frame_shift_in_seconds);
  111 + }
  112 +
  113 + void Reset(OnlineStream *s) const override {
  114 + // segment is incremented only when the last
  115 + // result is not empty
  116 + const auto &r = s->GetCtcResult();
  117 + if (!r.tokens.empty()) {
  118 + s->GetCurrentSegment() += 1;
  119 + }
  120 +
  121 + // clear result
  122 + s->SetCtcResult({});
  123 +
  124 + // clear states
  125 + reinterpret_cast<OnlineStreamRknn *>(s)->SetZipformerEncoderStates(
  126 + model_->GetInitStates());
  127 +
  128 + s->GetFasterDecoderProcessedFrames() = 0;
  129 +
  130 + // Note: We only update counters. The underlying audio samples
  131 + // are not discarded.
  132 + s->Reset();
  133 + }
  134 +
  135 + private:
  136 + void InitDecoder() {
  137 + if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") &&
  138 + !sym_.Contains("<blank>")) {
  139 + SHERPA_ONNX_LOGE(
  140 + "We expect that tokens.txt contains "
  141 + "the symbol <blk> or <eps> or <blank> and its ID.");
  142 + exit(-1);
  143 + }
  144 +
  145 + int32_t blank_id = 0;
  146 + if (sym_.Contains("<blk>")) {
  147 + blank_id = sym_["<blk>"];
  148 + } else if (sym_.Contains("<eps>")) {
  149 + // for tdnn models of the yesno recipe from icefall
  150 + blank_id = sym_["<eps>"];
  151 + } else if (sym_.Contains("<blank>")) {
  152 + // for WeNet CTC models
  153 + blank_id = sym_["<blank>"];
  154 + }
  155 +
  156 + if (!config_.ctc_fst_decoder_config.graph.empty()) {
  157 + decoder_ = std::make_unique<OnlineCtcFstDecoder>(
  158 + config_.ctc_fst_decoder_config, blank_id);
  159 + } else if (config_.decoding_method == "greedy_search") {
  160 + decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
  161 + } else {
  162 + SHERPA_ONNX_LOGE(
  163 + "Unsupported decoding method: %s for streaming CTC models",
  164 + config_.decoding_method.c_str());
  165 + exit(-1);
  166 + }
  167 + }
  168 +
  169 + void DecodeStream(OnlineStreamRknn *s) const {
  170 + int32_t chunk_size = model_->ChunkSize();
  171 + int32_t chunk_shift = model_->ChunkShift();
  172 +
  173 + int32_t feat_dim = s->FeatureDim();
  174 +
  175 + const auto num_processed_frames = s->GetNumProcessedFrames();
  176 + std::vector<float> features =
  177 + s->GetFrames(num_processed_frames, chunk_size);
  178 + s->GetNumProcessedFrames() += chunk_shift;
  179 +
  180 + auto &states = s->GetZipformerEncoderStates();
  181 + auto p = model_->Run(features, std::move(states));
  182 + states = std::move(p.second);
  183 +
  184 + std::vector<OnlineCtcDecoderResult> results(1);
  185 + results[0] = std::move(s->GetCtcResult());
  186 +
  187 + auto attr = model_->GetOutAttr();
  188 +
  189 + decoder_->Decode(p.first.data(), attr.dims[0], attr.dims[1], attr.dims[2],
  190 + &results, reinterpret_cast<OnlineStream **>(&s), 1);
  191 + s->SetCtcResult(results[0]);
  192 + }
  193 +
  194 + private:
  195 + OnlineRecognizerConfig config_;
  196 + std::unique_ptr<OnlineZipformerCtcModelRknn> model_;
  197 + std::unique_ptr<OnlineCtcDecoder> decoder_;
  198 + SymbolTable sym_;
  199 + Endpoint endpoint_;
  200 +};
  201 +
  202 +} // namespace sherpa_onnx
  203 +
  204 +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
  1 +// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h"
  6 +
  7 +#include <memory>
  8 +#include <sstream>
  9 +#include <string>
  10 +#include <unordered_map>
  11 +#include <utility>
  12 +#include <vector>
  13 +
  14 +#if __ANDROID_API__ >= 9
  15 +#include "android/asset_manager.h"
  16 +#include "android/asset_manager_jni.h"
  17 +#endif
  18 +
  19 +#if __OHOS__
  20 +#include "rawfile/raw_file_manager.h"
  21 +#endif
  22 +
  23 +#include "sherpa-onnx/csrc/file-utils.h"
  24 +#include "sherpa-onnx/csrc/rknn/macros.h"
  25 +#include "sherpa-onnx/csrc/rknn/utils.h"
  26 +#include "sherpa-onnx/csrc/text-utils.h"
  27 +
  28 +namespace sherpa_onnx {
  29 +
  30 +class OnlineZipformerCtcModelRknn::Impl {
  31 + public:
  32 + ~Impl() {
  33 + auto ret = rknn_destroy(ctx_);
  34 + if (ret != RKNN_SUCC) {
  35 + SHERPA_ONNX_LOGE("Failed to destroy the context");
  36 + }
  37 + }
  38 +
  39 + explicit Impl(const OnlineModelConfig &config) : config_(config) {
  40 + {
  41 + auto buf = ReadFile(config.zipformer2_ctc.model);
  42 + Init(buf.data(), buf.size());
  43 + }
  44 +
  45 + int32_t ret = RKNN_SUCC;
  46 + switch (config_.num_threads) {
  47 + case 1:
  48 + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_AUTO);
  49 + break;
  50 + case 0:
  51 + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0);
  52 + break;
  53 + case -1:
  54 + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_1);
  55 + break;
  56 + case -2:
  57 + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_2);
  58 + break;
  59 + case -3:
  60 + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1);
  61 + break;
  62 + case -4:
  63 + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1_2);
  64 + break;
  65 + default:
  66 + SHERPA_ONNX_LOGE(
  67 + "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
  68 + "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
  69 + config_.num_threads);
  70 + break;
  71 + }
  72 + if (ret != RKNN_SUCC) {
  73 + SHERPA_ONNX_LOGE(
  74 + "Failed to select npu core to run the model (You can ignore it if "
  75 + "you "
  76 + "are not using RK3588.");
  77 + }
  78 + }
  79 +
  80 + // TODO(fangjun): Support Android
  81 +
  82 + std::vector<std::vector<uint8_t>> GetInitStates() const {
  83 + // input_attrs_[0] is for the feature
  84 + // input_attrs_[1:] is for states
  85 + // so we use -1 here
  86 + std::vector<std::vector<uint8_t>> states(input_attrs_.size() - 1);
  87 +
  88 + int32_t i = -1;
  89 + for (auto &attr : input_attrs_) {
  90 + i += 1;
  91 + if (i == 0) {
  92 + // skip processing the attr for features.
  93 + continue;
  94 + }
  95 +
  96 + if (attr.type == RKNN_TENSOR_FLOAT16) {
  97 + states[i - 1].resize(attr.n_elems * sizeof(float));
  98 + } else if (attr.type == RKNN_TENSOR_INT64) {
  99 + states[i - 1].resize(attr.n_elems * sizeof(int64_t));
  100 + } else {
  101 + SHERPA_ONNX_LOGE("Unsupported tensor type: %d, %s", attr.type,
  102 + get_type_string(attr.type));
  103 + SHERPA_ONNX_EXIT(-1);
  104 + }
  105 + }
  106 +
  107 + return states;
  108 + }
  109 +
  110 + std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run(
  111 + std::vector<float> features,
  112 + std::vector<std::vector<uint8_t>> states) const {
  113 + std::vector<rknn_input> inputs(input_attrs_.size());
  114 +
  115 + for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
  116 + auto &input = inputs[i];
  117 + auto &attr = input_attrs_[i];
  118 + input.index = attr.index;
  119 +
  120 + if (attr.type == RKNN_TENSOR_FLOAT16) {
  121 + input.type = RKNN_TENSOR_FLOAT32;
  122 + } else if (attr.type == RKNN_TENSOR_INT64) {
  123 + input.type = RKNN_TENSOR_INT64;
  124 + } else {
  125 + SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
  126 + get_type_string(attr.type));
  127 + SHERPA_ONNX_EXIT(-1);
  128 + }
  129 +
  130 + input.fmt = attr.fmt;
  131 + if (i == 0) {
  132 + input.buf = reinterpret_cast<void *>(features.data());
  133 + input.size = features.size() * sizeof(float);
  134 + } else {
  135 + input.buf = reinterpret_cast<void *>(states[i - 1].data());
  136 + input.size = states[i - 1].size();
  137 + }
  138 + }
  139 +
  140 + std::vector<float> out(output_attrs_[0].n_elems);
  141 +
  142 + // Note(fangjun): We can reuse the memory from input argument `states`
  143 + // auto next_states = GetInitStates();
  144 + auto &next_states = states;
  145 +
  146 + std::vector<rknn_output> outputs(output_attrs_.size());
  147 + for (int32_t i = 0; i < outputs.size(); ++i) {
  148 + auto &output = outputs[i];
  149 + auto &attr = output_attrs_[i];
  150 + output.index = attr.index;
  151 + output.is_prealloc = 1;
  152 +
  153 + if (attr.type == RKNN_TENSOR_FLOAT16) {
  154 + output.want_float = 1;
  155 + } else if (attr.type == RKNN_TENSOR_INT64) {
  156 + output.want_float = 0;
  157 + } else {
  158 + SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
  159 + get_type_string(attr.type));
  160 + SHERPA_ONNX_EXIT(-1);
  161 + }
  162 +
  163 + if (i == 0) {
  164 + output.size = out.size() * sizeof(float);
  165 + output.buf = reinterpret_cast<void *>(out.data());
  166 + } else {
  167 + output.size = next_states[i - 1].size();
  168 + output.buf = reinterpret_cast<void *>(next_states[i - 1].data());
  169 + }
  170 + }
  171 +
  172 + auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data());
  173 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");
  174 +
  175 + ret = rknn_run(ctx_, nullptr);
  176 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");
  177 +
  178 + ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr);
  179 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");
  180 +
  181 + for (int32_t i = 0; i < next_states.size(); ++i) {
  182 + const auto &attr = input_attrs_[i + 1];
  183 + if (attr.n_dims == 4) {
  184 + // TODO(fangjun): The transpose is copied from
  185 + // https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22
  186 + // I don't understand why we need to do that.
  187 + std::vector<uint8_t> dst(next_states[i].size());
  188 + int32_t n = attr.dims[0];
  189 + int32_t h = attr.dims[1];
  190 + int32_t w = attr.dims[2];
  191 + int32_t c = attr.dims[3];
  192 + ConvertNCHWtoNHWC(
  193 + reinterpret_cast<const float *>(next_states[i].data()), n, c, h, w,
  194 + reinterpret_cast<float *>(dst.data()));
  195 + next_states[i] = std::move(dst);
  196 + }
  197 + }
  198 +
  199 + return {std::move(out), std::move(next_states)};
  200 + }
  201 +
  202 + int32_t ChunkSize() const { return T_; }
  203 +
  204 + int32_t ChunkShift() const { return decode_chunk_len_; }
  205 +
  206 + int32_t VocabSize() const { return vocab_size_; }
  207 +
  208 + rknn_tensor_attr GetOutAttr() const { return output_attrs_[0]; }
  209 +
  210 + private:
  211 + void Init(void *model_data, size_t model_data_length) {
  212 + auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);
  213 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init model '%s'",
  214 + config_.zipformer2_ctc.model.c_str());
  215 +
  216 + if (config_.debug) {
  217 + rknn_sdk_version v;
  218 + ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
  219 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
  220 +
  221 + SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
  222 + v.drv_version);
  223 + }
  224 +
  225 + rknn_input_output_num io_num;
  226 + ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
  227 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
  228 +
  229 + if (config_.debug) {
  230 + SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
  231 + static_cast<int32_t>(io_num.n_input),
  232 + static_cast<int32_t>(io_num.n_output));
  233 + }
  234 +
  235 + input_attrs_.resize(io_num.n_input);
  236 + output_attrs_.resize(io_num.n_output);
  237 +
  238 + int32_t i = 0;
  239 + for (auto &attr : input_attrs_) {
  240 + memset(&attr, 0, sizeof(attr));
  241 + attr.index = i;
  242 + ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
  243 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
  244 + i += 1;
  245 + }
  246 +
  247 + if (config_.debug) {
  248 + std::ostringstream os;
  249 + std::string sep;
  250 + for (auto &attr : input_attrs_) {
  251 + os << sep << ToString(attr);
  252 + sep = "\n";
  253 + }
  254 + SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
  255 + os.str().c_str());
  256 + }
  257 +
  258 + i = 0;
  259 + for (auto &attr : output_attrs_) {
  260 + memset(&attr, 0, sizeof(attr));
  261 + attr.index = i;
  262 + ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
  263 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
  264 + i += 1;
  265 + }
  266 +
  267 + if (config_.debug) {
  268 + std::ostringstream os;
  269 + std::string sep;
  270 + for (auto &attr : output_attrs_) {
  271 + os << sep << ToString(attr);
  272 + sep = "\n";
  273 + }
  274 + SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
  275 + os.str().c_str());
  276 + }
  277 +
  278 + rknn_custom_string custom_string;
  279 + ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
  280 + sizeof(custom_string));
  281 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
  282 + if (config_.debug) {
  283 + SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
  284 + }
  285 + auto meta = Parse(custom_string);
  286 +
  287 + if (config_.debug) {
  288 + for (const auto &p : meta) {
  289 + SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
  290 + }
  291 + }
  292 +
  293 + if (meta.count("T")) {
  294 + T_ = atoi(meta.at("T").c_str());
  295 + }
  296 +
  297 + if (meta.count("decode_chunk_len")) {
  298 + decode_chunk_len_ = atoi(meta.at("decode_chunk_len").c_str());
  299 + }
  300 +
  301 + vocab_size_ = output_attrs_[0].dims[2];
  302 +
  303 + if (config_.debug) {
  304 +#if __OHOS__
  305 + SHERPA_ONNX_LOGE("T: %{public}d", T_);
  306 + SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_);
  307 + SHERPA_ONNX_LOGE("vocab_size: %{public}d", vocab_size);
  308 +#else
  309 + SHERPA_ONNX_LOGE("T: %d", T_);
  310 + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
  311 + SHERPA_ONNX_LOGE("vocab_size: %d", vocab_size_);
  312 +#endif
  313 + }
  314 +
  315 + if (T_ == 0) {
  316 + SHERPA_ONNX_LOGE(
  317 + "Invalid T. Please use the script from icefall to export your model");
  318 + SHERPA_ONNX_EXIT(-1);
  319 + }
  320 +
  321 + if (decode_chunk_len_ == 0) {
  322 + SHERPA_ONNX_LOGE(
  323 + "Invalid decode_chunk_len. Please use the script from icefall to "
  324 + "export your model");
  325 + SHERPA_ONNX_EXIT(-1);
  326 + }
  327 + }
  328 +
  329 + private:
  330 + OnlineModelConfig config_;
  331 + rknn_context ctx_ = 0;
  332 +
  333 + std::vector<rknn_tensor_attr> input_attrs_;
  334 + std::vector<rknn_tensor_attr> output_attrs_;
  335 +
  336 + int32_t T_ = 0;
  337 + int32_t decode_chunk_len_ = 0;
  338 + int32_t vocab_size_ = 0;
  339 +};
  340 +
  341 +OnlineZipformerCtcModelRknn::~OnlineZipformerCtcModelRknn() = default;
  342 +
  343 +OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
  344 + const OnlineModelConfig &config)
  345 + : impl_(std::make_unique<Impl>(config)) {}
  346 +
  347 +template <typename Manager>
  348 +OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
  349 + Manager *mgr, const OnlineModelConfig &config)
  350 + : impl_(std::make_unique<OnlineZipformerCtcModelRknn>(mgr, config)) {}
  351 +
  352 +std::vector<std::vector<uint8_t>> OnlineZipformerCtcModelRknn::GetInitStates()
  353 + const {
  354 + return impl_->GetInitStates();
  355 +}
  356 +
  357 +std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>>
  358 +OnlineZipformerCtcModelRknn::Run(
  359 + std::vector<float> features,
  360 + std::vector<std::vector<uint8_t>> states) const {
  361 + return impl_->Run(std::move(features), std::move(states));
  362 +}
  363 +
  364 +int32_t OnlineZipformerCtcModelRknn::ChunkSize() const {
  365 + return impl_->ChunkSize();
  366 +}
  367 +
  368 +int32_t OnlineZipformerCtcModelRknn::ChunkShift() const {
  369 + return impl_->ChunkShift();
  370 +}
  371 +
  372 +int32_t OnlineZipformerCtcModelRknn::VocabSize() const {
  373 + return impl_->VocabSize();
  374 +}
  375 +
  376 +rknn_tensor_attr OnlineZipformerCtcModelRknn::GetOutAttr() const {
  377 + return impl_->GetOutAttr();
  378 +}
  379 +
  380 +#if __ANDROID_API__ >= 9
  381 +template OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
  382 + AAssetManager *mgr, const OnlineModelConfig &config);
  383 +#endif
  384 +
  385 +#if __OHOS__
  386 +template OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
  387 + NativeResourceManager *mgr, const OnlineModelConfig &config);
  388 +#endif
  389 +
  390 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
  5 +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
  6 +
  7 +#include <memory>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "rknn_api.h" // NOLINT
  12 +#include "sherpa-onnx/csrc/online-model-config.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
// Streaming Zipformer CTC model executed with the RKNN runtime (Rockchip
// NPU). All rknn specifics live in the pimpl, so this header only needs
// rknn_api.h for rknn_tensor_attr.
class OnlineZipformerCtcModelRknn {
 public:
  ~OnlineZipformerCtcModelRknn();

  explicit OnlineZipformerCtcModelRknn(const OnlineModelConfig &config);

  // Constructs from a platform resource manager (e.g. AAssetManager on
  // Android, NativeResourceManager on OHOS).
  template <typename Manager>
  OnlineZipformerCtcModelRknn(Manager *mgr, const OnlineModelConfig &config);

  // Zero-initialized encoder state buffers, one per model state input.
  std::vector<std::vector<uint8_t>> GetInitStates();

  // Runs one chunk of features with the given states.
  // Returns {output, next_states}; the first output is the model's raw
  // (log-prob-like) scores -- presumably (batch, num_frames, vocab_size),
  // see GetOutAttr().
  std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run(
      std::vector<float> features,
      std::vector<std::vector<uint8_t>> states) const;

  // Number of feature frames consumed per Run() call.
  int32_t ChunkSize() const;

  // Number of frames to advance between consecutive chunks.
  int32_t ChunkShift() const;

  int32_t VocabSize() const;

  // Attribute (shape etc.) of the first output tensor.
  rknn_tensor_attr GetOutAttr() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
  43 +
  44 +} // namespace sherpa_onnx
  45 +
  46 +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
1 // sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc 1 // sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
2 // 2 //
3 -// Copyright (c) 2023 Xiaomi Corporation 3 +// Copyright (c) 2025 Xiaomi Corporation
4 4
5 #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" 5 #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
6 6
@@ -22,68 +22,11 @@ @@ -22,68 +22,11 @@
22 22
23 #include "sherpa-onnx/csrc/file-utils.h" 23 #include "sherpa-onnx/csrc/file-utils.h"
24 #include "sherpa-onnx/csrc/rknn/macros.h" 24 #include "sherpa-onnx/csrc/rknn/macros.h"
  25 +#include "sherpa-onnx/csrc/rknn/utils.h"
25 #include "sherpa-onnx/csrc/text-utils.h" 26 #include "sherpa-onnx/csrc/text-utils.h"
26 27
27 namespace sherpa_onnx { 28 namespace sherpa_onnx {
28 29
29 -// chw -> hwc  
30 -static void Transpose(const float *src, int32_t n, int32_t channel,  
31 - int32_t height, int32_t width, float *dst) {  
32 - for (int32_t i = 0; i < n; ++i) {  
33 - for (int32_t h = 0; h < height; ++h) {  
34 - for (int32_t w = 0; w < width; ++w) {  
35 - for (int32_t c = 0; c < channel; ++c) {  
36 - // dst[h, w, c] = src[c, h, w]  
37 - dst[i * height * width * channel + h * width * channel + w * channel +  
38 - c] = src[i * height * width * channel + c * height * width +  
39 - h * width + w];  
40 - }  
41 - }  
42 - }  
43 - }  
44 -}  
45 -  
46 -static std::string ToString(const rknn_tensor_attr &attr) {  
47 - std::ostringstream os;  
48 - os << "{";  
49 - os << attr.index;  
50 - os << ", name: " << attr.name;  
51 - os << ", shape: (";  
52 - std::string sep;  
53 - for (int32_t i = 0; i < static_cast<int32_t>(attr.n_dims); ++i) {  
54 - os << sep << attr.dims[i];  
55 - sep = ",";  
56 - }  
57 - os << ")";  
58 - os << ", n_elems: " << attr.n_elems;  
59 - os << ", size: " << attr.size;  
60 - os << ", fmt: " << get_format_string(attr.fmt);  
61 - os << ", type: " << get_type_string(attr.type);  
62 - os << ", pass_through: " << (attr.pass_through ? "true" : "false");  
63 - os << "}";  
64 - return os.str();  
65 -}  
66 -  
67 -static std::unordered_map<std::string, std::string> Parse(  
68 - const rknn_custom_string &custom_string) {  
69 - std::unordered_map<std::string, std::string> ans;  
70 - std::vector<std::string> fields;  
71 - SplitStringToVector(custom_string.string, ";", false, &fields);  
72 -  
73 - std::vector<std::string> tmp;  
74 - for (const auto &f : fields) {  
75 - SplitStringToVector(f, "=", false, &tmp);  
76 - if (tmp.size() != 2) {  
77 - SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string,  
78 - f.c_str());  
79 - SHERPA_ONNX_EXIT(-1);  
80 - }  
81 - ans[std::move(tmp[0])] = std::move(tmp[1]);  
82 - }  
83 -  
84 - return ans;  
85 -}  
86 -  
87 class OnlineZipformerTransducerModelRknn::Impl { 30 class OnlineZipformerTransducerModelRknn::Impl {
88 public: 31 public:
89 ~Impl() { 32 ~Impl() {
@@ -285,7 +228,7 @@ class OnlineZipformerTransducerModelRknn::Impl { @@ -285,7 +228,7 @@ class OnlineZipformerTransducerModelRknn::Impl {
285 for (int32_t i = 0; i < next_states.size(); ++i) { 228 for (int32_t i = 0; i < next_states.size(); ++i) {
286 const auto &attr = encoder_input_attrs_[i + 1]; 229 const auto &attr = encoder_input_attrs_[i + 1];
287 if (attr.n_dims == 4) { 230 if (attr.n_dims == 4) {
288 - // TODO(fangjun): The transpose is copied from 231 + // TODO(fangjun): The ConvertNCHWtoNHWC is copied from
289 // https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22 232 // https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22
290 // I don't understand why we need to do that. 233 // I don't understand why we need to do that.
291 std::vector<uint8_t> dst(next_states[i].size()); 234 std::vector<uint8_t> dst(next_states[i].size());
@@ -293,8 +236,9 @@ class OnlineZipformerTransducerModelRknn::Impl { @@ -293,8 +236,9 @@ class OnlineZipformerTransducerModelRknn::Impl {
293 int32_t h = attr.dims[1]; 236 int32_t h = attr.dims[1];
294 int32_t w = attr.dims[2]; 237 int32_t w = attr.dims[2];
295 int32_t c = attr.dims[3]; 238 int32_t c = attr.dims[3];
296 - Transpose(reinterpret_cast<const float *>(next_states[i].data()), n, c,  
297 - h, w, reinterpret_cast<float *>(dst.data())); 239 + ConvertNCHWtoNHWC(
  240 + reinterpret_cast<const float *>(next_states[i].data()), n, c, h, w,
  241 + reinterpret_cast<float *>(dst.data()));
298 next_states[i] = std::move(dst); 242 next_states[i] = std::move(dst);
299 } 243 }
300 } 244 }
@@ -527,11 +471,9 @@ class OnlineZipformerTransducerModelRknn::Impl { @@ -527,11 +471,9 @@ class OnlineZipformerTransducerModelRknn::Impl {
527 #if __OHOS__ 471 #if __OHOS__
528 SHERPA_ONNX_LOGE("T: %{public}d", T_); 472 SHERPA_ONNX_LOGE("T: %{public}d", T_);
529 SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_); 473 SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_);
530 - SHERPA_ONNX_LOGE("context_size: %{public}d", context_size_);  
531 #else 474 #else
532 SHERPA_ONNX_LOGE("T: %d", T_); 475 SHERPA_ONNX_LOGE("T: %d", T_);
533 SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); 476 SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
534 - SHERPA_ONNX_LOGE("context_size: %d", context_size_);  
535 #endif 477 #endif
536 } 478 }
537 } 479 }
@@ -597,6 +539,11 @@ class OnlineZipformerTransducerModelRknn::Impl { @@ -597,6 +539,11 @@ class OnlineZipformerTransducerModelRknn::Impl {
597 SHERPA_ONNX_EXIT(-1); 539 SHERPA_ONNX_EXIT(-1);
598 } 540 }
599 541
  542 + context_size_ = decoder_input_attrs_[0].dims[1];
  543 + if (config_.debug) {
  544 + SHERPA_ONNX_LOGE("context_size: %d", context_size_);
  545 + }
  546 +
600 i = 0; 547 i = 0;
601 for (auto &attr : decoder_output_attrs_) { 548 for (auto &attr : decoder_output_attrs_) {
602 memset(&attr, 0, sizeof(attr)); 549 memset(&attr, 0, sizeof(attr));
@@ -14,8 +14,11 @@ @@ -14,8 +14,11 @@
14 14
15 namespace sherpa_onnx { 15 namespace sherpa_onnx {
16 16
17 -// this is for zipformer v1, i.e., the folder  
18 -// pruned_transducer_statelss7_streaming from icefall 17 +// this is for zipformer v1 and v2, i.e., the folder
  18 +// pruned_transducer_stateless7_streaming
  19 +// and
  20 +// zipformer
  21 +// from icefall
19 class OnlineZipformerTransducerModelRknn { 22 class OnlineZipformerTransducerModelRknn {
20 public: 23 public:
21 ~OnlineZipformerTransducerModelRknn(); 24 ~OnlineZipformerTransducerModelRknn();
  1 +// sherpa-onnx/csrc/rknn/utils.cc
  2 +//
  3 +// Copyright 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/rknn/utils.h"
  6 +
  7 +#include <sstream>
  8 +#include <unordered_map>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +#include "sherpa-onnx/csrc/text-utils.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
// Convert a float tensor from NCHW layout (n, channel, height, width)
// to NHWC layout (n, height, width, channel).
//
// src and dst must each hold n*channel*height*width floats and must
// not overlap.
void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
                       int32_t height, int32_t width, float *dst) {
  const int32_t plane = height * width;          // elements per channel plane
  const int32_t batch_stride = channel * plane;  // elements per batch item
  for (int32_t i = 0; i != n; ++i) {
    const float *src_batch = src + i * batch_stride;
    float *dst_batch = dst + i * batch_stride;
    for (int32_t c = 0; c != channel; ++c) {
      const float *src_plane = src_batch + c * plane;
      for (int32_t h = 0; h != height; ++h) {
        for (int32_t w = 0; w != width; ++w) {
          // dst[i, h, w, c] = src[i, c, h, w]
          dst_batch[(h * width + w) * channel + c] = src_plane[h * width + w];
        }
      }
    }
  }
}
  31 +
  32 +std::string ToString(const rknn_tensor_attr &attr) {
  33 + std::ostringstream os;
  34 + os << "{";
  35 + os << attr.index;
  36 + os << ", name: " << attr.name;
  37 + os << ", shape: (";
  38 + std::string sep;
  39 + for (int32_t i = 0; i < static_cast<int32_t>(attr.n_dims); ++i) {
  40 + os << sep << attr.dims[i];
  41 + sep = ",";
  42 + }
  43 + os << ")";
  44 + os << ", n_elems: " << attr.n_elems;
  45 + os << ", size: " << attr.size;
  46 + os << ", fmt: " << get_format_string(attr.fmt);
  47 + os << ", type: " << get_type_string(attr.type);
  48 + os << ", pass_through: " << (attr.pass_through ? "true" : "false");
  49 + os << "}";
  50 + return os.str();
  51 +}
  52 +
  53 +std::unordered_map<std::string, std::string> Parse(
  54 + const rknn_custom_string &custom_string) {
  55 + std::unordered_map<std::string, std::string> ans;
  56 + std::vector<std::string> fields;
  57 + SplitStringToVector(custom_string.string, ";", false, &fields);
  58 +
  59 + std::vector<std::string> tmp;
  60 + for (const auto &f : fields) {
  61 + SplitStringToVector(f, "=", false, &tmp);
  62 + if (tmp.size() != 2) {
  63 + SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string,
  64 + f.c_str());
  65 + SHERPA_ONNX_EXIT(-1);
  66 + }
  67 + ans[std::move(tmp[0])] = std::move(tmp[1]);
  68 + }
  69 +
  70 + return ans;
  71 +}
  72 +
  73 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/rknn/utils.h
  2 +//
  3 +// Copyright 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_RKNN_UTILS_H_
  6 +#define SHERPA_ONNX_CSRC_RKNN_UTILS_H_
  7 +
  8 +#include <string>
  9 +#include <unordered_map>
  10 +
  11 +#include "rknn_api.h" // NOLINT
  12 +
  13 +namespace sherpa_onnx {
  14 +void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
  15 + int32_t height, int32_t width, float *dst);
  16 +
  17 +std::string ToString(const rknn_tensor_attr &attr);
  18 +
  19 +std::unordered_map<std::string, std::string> Parse(
  20 + const rknn_custom_string &custom_string);
  21 +} // namespace sherpa_onnx
  22 +
  23 +#endif // SHERPA_ONNX_CSRC_RKNN_UTILS_H_
@@ -83,6 +83,7 @@ for a list of pre-trained models to download. @@ -83,6 +83,7 @@ for a list of pre-trained models to download.
83 po.Read(argc, argv); 83 po.Read(argc, argv);
84 if (po.NumArgs() < 1) { 84 if (po.NumArgs() < 1) {
85 po.PrintUsage(); 85 po.PrintUsage();
  86 + fprintf(stderr, "Error! Please provide at least 1 wav file\n");
86 exit(EXIT_FAILURE); 87 exit(EXIT_FAILURE);
87 } 88 }
88 89