Committed by
GitHub
Support RKNN for Zipformer CTC models. (#1948)
正在显示
17 个修改的文件
包含
819 行增加
和
114 行删除
| @@ -155,7 +155,9 @@ if(SHERPA_ONNX_ENABLE_RKNN) | @@ -155,7 +155,9 @@ if(SHERPA_ONNX_ENABLE_RKNN) | ||
| 155 | list(APPEND sources | 155 | list(APPEND sources |
| 156 | ./rknn/online-stream-rknn.cc | 156 | ./rknn/online-stream-rknn.cc |
| 157 | ./rknn/online-transducer-greedy-search-decoder-rknn.cc | 157 | ./rknn/online-transducer-greedy-search-decoder-rknn.cc |
| 158 | + ./rknn/online-zipformer-ctc-model-rknn.cc | ||
| 158 | ./rknn/online-zipformer-transducer-model-rknn.cc | 159 | ./rknn/online-zipformer-transducer-model-rknn.cc |
| 160 | + ./rknn/utils.cc | ||
| 159 | ) | 161 | ) |
| 160 | 162 | ||
| 161 | endif() | 163 | endif() |
| @@ -43,12 +43,14 @@ class OnlineCtcDecoder { | @@ -43,12 +43,14 @@ class OnlineCtcDecoder { | ||
| 43 | 43 | ||
| 44 | /** Run streaming CTC decoding given the output from the encoder model. | 44 | /** Run streaming CTC decoding given the output from the encoder model. |
| 45 | * | 45 | * |
| 46 | - * @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing | ||
| 47 | - * lob_probs. | 46 | + * @param log_probs A 3-D tensor of shape |
| 47 | + * (batch_size, num_frames, vocab_size) containing | ||
| 48 | + * log_probs in row major. | ||
| 48 | * | 49 | * |
| 49 | * @param results Input & Output parameters.. | 50 | * @param results Input & Output parameters.. |
| 50 | */ | 51 | */ |
| 51 | - virtual void Decode(Ort::Value log_probs, | 52 | + virtual void Decode(const float *log_probs, int32_t batch_size, |
| 53 | + int32_t num_frames, int32_t vocab_size, | ||
| 52 | std::vector<OnlineCtcDecoderResult> *results, | 54 | std::vector<OnlineCtcDecoderResult> *results, |
| 53 | OnlineStream **ss = nullptr, int32_t n = 0) = 0; | 55 | OnlineStream **ss = nullptr, int32_t n = 0) = 0; |
| 54 | 56 |
| @@ -91,30 +91,23 @@ static void DecodeOne(const float *log_probs, int32_t num_rows, | @@ -91,30 +91,23 @@ static void DecodeOne(const float *log_probs, int32_t num_rows, | ||
| 91 | processed_frames += num_rows; | 91 | processed_frames += num_rows; |
| 92 | } | 92 | } |
| 93 | 93 | ||
| 94 | -void OnlineCtcFstDecoder::Decode(Ort::Value log_probs, | 94 | +void OnlineCtcFstDecoder::Decode(const float *log_probs, int32_t batch_size, |
| 95 | + int32_t num_frames, int32_t vocab_size, | ||
| 95 | std::vector<OnlineCtcDecoderResult> *results, | 96 | std::vector<OnlineCtcDecoderResult> *results, |
| 96 | OnlineStream **ss, int32_t n) { | 97 | OnlineStream **ss, int32_t n) { |
| 97 | - std::vector<int64_t> log_probs_shape = | ||
| 98 | - log_probs.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 99 | - | ||
| 100 | - if (log_probs_shape[0] != results->size()) { | 98 | + if (batch_size != results->size()) { |
| 101 | SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", | 99 | SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", |
| 102 | - static_cast<int32_t>(log_probs_shape[0]), | ||
| 103 | - static_cast<int32_t>(results->size())); | 100 | + batch_size, static_cast<int32_t>(results->size())); |
| 104 | exit(-1); | 101 | exit(-1); |
| 105 | } | 102 | } |
| 106 | 103 | ||
| 107 | - if (log_probs_shape[0] != n) { | ||
| 108 | - SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d", | ||
| 109 | - static_cast<int32_t>(log_probs_shape[0]), n); | 104 | + if (batch_size != n) { |
| 105 | + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d", batch_size, | ||
| 106 | + n); | ||
| 110 | exit(-1); | 107 | exit(-1); |
| 111 | } | 108 | } |
| 112 | 109 | ||
| 113 | - int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]); | ||
| 114 | - int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]); | ||
| 115 | - int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]); | ||
| 116 | - | ||
| 117 | - const float *p = log_probs.GetTensorData<float>(); | 110 | + const float *p = log_probs; |
| 118 | 111 | ||
| 119 | for (int32_t i = 0; i != batch_size; ++i) { | 112 | for (int32_t i = 0; i != batch_size; ++i) { |
| 120 | DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size, | 113 | DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size, |
| @@ -19,8 +19,8 @@ class OnlineCtcFstDecoder : public OnlineCtcDecoder { | @@ -19,8 +19,8 @@ class OnlineCtcFstDecoder : public OnlineCtcDecoder { | ||
| 19 | OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config, | 19 | OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config, |
| 20 | int32_t blank_id); | 20 | int32_t blank_id); |
| 21 | 21 | ||
| 22 | - void Decode(Ort::Value log_probs, | ||
| 23 | - std::vector<OnlineCtcDecoderResult> *results, | 22 | + void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames, |
| 23 | + int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results, | ||
| 24 | OnlineStream **ss = nullptr, int32_t n = 0) override; | 24 | OnlineStream **ss = nullptr, int32_t n = 0) override; |
| 25 | 25 | ||
| 26 | std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder() | 26 | std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder() |
| @@ -13,23 +13,16 @@ | @@ -13,23 +13,16 @@ | ||
| 13 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 14 | 14 | ||
| 15 | void OnlineCtcGreedySearchDecoder::Decode( | 15 | void OnlineCtcGreedySearchDecoder::Decode( |
| 16 | - Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results, | 16 | + const float *log_probs, int32_t batch_size, int32_t num_frames, |
| 17 | + int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results, | ||
| 17 | OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) { | 18 | OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) { |
| 18 | - std::vector<int64_t> log_probs_shape = | ||
| 19 | - log_probs.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 20 | - | ||
| 21 | - if (log_probs_shape[0] != results->size()) { | 19 | + if (batch_size != results->size()) { |
| 22 | SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", | 20 | SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", |
| 23 | - static_cast<int32_t>(log_probs_shape[0]), | ||
| 24 | - static_cast<int32_t>(results->size())); | 21 | + batch_size, static_cast<int32_t>(results->size())); |
| 25 | exit(-1); | 22 | exit(-1); |
| 26 | } | 23 | } |
| 27 | 24 | ||
| 28 | - int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]); | ||
| 29 | - int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]); | ||
| 30 | - int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]); | ||
| 31 | - | ||
| 32 | - const float *p = log_probs.GetTensorData<float>(); | 25 | + const float *p = log_probs; |
| 33 | 26 | ||
| 34 | for (int32_t b = 0; b != batch_size; ++b) { | 27 | for (int32_t b = 0; b != batch_size; ++b) { |
| 35 | auto &r = (*results)[b]; | 28 | auto &r = (*results)[b]; |
| @@ -16,8 +16,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder { | @@ -16,8 +16,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder { | ||
| 16 | explicit OnlineCtcGreedySearchDecoder(int32_t blank_id) | 16 | explicit OnlineCtcGreedySearchDecoder(int32_t blank_id) |
| 17 | : blank_id_(blank_id) {} | 17 | : blank_id_(blank_id) {} |
| 18 | 18 | ||
| 19 | - void Decode(Ort::Value log_probs, | ||
| 20 | - std::vector<OnlineCtcDecoderResult> *results, | 19 | + void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames, |
| 20 | + int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results, | ||
| 21 | OnlineStream **ss = nullptr, int32_t n = 0) override; | 21 | OnlineStream **ss = nullptr, int32_t n = 0) override; |
| 22 | 22 | ||
| 23 | private: | 23 | private: |
| @@ -76,6 +76,15 @@ bool OnlineModelConfig::Validate() const { | @@ -76,6 +76,15 @@ bool OnlineModelConfig::Validate() const { | ||
| 76 | transducer.decoder.c_str(), transducer.joiner.c_str()); | 76 | transducer.decoder.c_str(), transducer.joiner.c_str()); |
| 77 | return false; | 77 | return false; |
| 78 | } | 78 | } |
| 79 | + | ||
| 80 | + if (!zipformer2_ctc.model.empty() && | ||
| 81 | + EndsWith(zipformer2_ctc.model, ".rknn")) { | ||
| 82 | + SHERPA_ONNX_LOGE( | ||
| 83 | + "--provider is %s, which is not rknn, but you pass rknn model " | ||
| 84 | + "filename for zipformer2_ctc: '%s'", | ||
| 85 | + provider_config.provider.c_str(), zipformer2_ctc.model.c_str()); | ||
| 86 | + return false; | ||
| 87 | + } | ||
| 79 | } | 88 | } |
| 80 | 89 | ||
| 81 | if (provider_config.provider == "rknn") { | 90 | if (provider_config.provider == "rknn") { |
| @@ -89,6 +98,15 @@ bool OnlineModelConfig::Validate() const { | @@ -89,6 +98,15 @@ bool OnlineModelConfig::Validate() const { | ||
| 89 | transducer.joiner.c_str()); | 98 | transducer.joiner.c_str()); |
| 90 | return false; | 99 | return false; |
| 91 | } | 100 | } |
| 101 | + | ||
| 102 | + if (!zipformer2_ctc.model.empty() && | ||
| 103 | + EndsWith(zipformer2_ctc.model, ".onnx")) { | ||
| 104 | + SHERPA_ONNX_LOGE( | ||
| 105 | + "--provider rknn, but you pass onnx model filename for " | ||
| 106 | + "zipformer2_ctc: '%s'", | ||
| 107 | + zipformer2_ctc.model.c_str()); | ||
| 108 | + return false; | ||
| 109 | + } | ||
| 92 | } | 110 | } |
| 93 | 111 | ||
| 94 | if (!tokens_buf.empty() && FileExists(tokens)) { | 112 | if (!tokens_buf.empty() && FileExists(tokens)) { |
| @@ -24,12 +24,11 @@ | @@ -24,12 +24,11 @@ | ||
| 24 | 24 | ||
| 25 | namespace sherpa_onnx { | 25 | namespace sherpa_onnx { |
| 26 | 26 | ||
| 27 | -static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, | ||
| 28 | - const SymbolTable &sym_table, | ||
| 29 | - float frame_shift_ms, | ||
| 30 | - int32_t subsampling_factor, | ||
| 31 | - int32_t segment, | ||
| 32 | - int32_t frames_since_start) { | 27 | +OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src, |
| 28 | + const SymbolTable &sym_table, | ||
| 29 | + float frame_shift_ms, | ||
| 30 | + int32_t subsampling_factor, int32_t segment, | ||
| 31 | + int32_t frames_since_start) { | ||
| 33 | OnlineRecognizerResult r; | 32 | OnlineRecognizerResult r; |
| 34 | r.tokens.reserve(src.tokens.size()); | 33 | r.tokens.reserve(src.tokens.size()); |
| 35 | r.timestamps.reserve(src.tokens.size()); | 34 | r.timestamps.reserve(src.tokens.size()); |
| @@ -182,7 +181,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -182,7 +181,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 182 | std::vector<std::vector<Ort::Value>> next_states = | 181 | std::vector<std::vector<Ort::Value>> next_states = |
| 183 | model_->UnStackStates(std::move(out_states)); | 182 | model_->UnStackStates(std::move(out_states)); |
| 184 | 183 | ||
| 185 | - decoder_->Decode(std::move(out[0]), &results, ss, n); | 184 | + std::vector<int64_t> log_probs_shape = |
| 185 | + out[0].GetTensorTypeAndShapeInfo().GetShape(); | ||
| 186 | + decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0], | ||
| 187 | + log_probs_shape[1], log_probs_shape[2], &results, ss, n); | ||
| 186 | 188 | ||
| 187 | for (int32_t k = 0; k != n; ++k) { | 189 | for (int32_t k = 0; k != n; ++k) { |
| 188 | ss[k]->SetCtcResult(results[k]); | 190 | ss[k]->SetCtcResult(results[k]); |
| @@ -196,8 +198,9 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -196,8 +198,9 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 196 | // TODO(fangjun): Remember to change these constants if needed | 198 | // TODO(fangjun): Remember to change these constants if needed |
| 197 | int32_t frame_shift_ms = 10; | 199 | int32_t frame_shift_ms = 10; |
| 198 | int32_t subsampling_factor = 4; | 200 | int32_t subsampling_factor = 4; |
| 199 | - auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, | ||
| 200 | - s->GetCurrentSegment(), s->GetNumFramesSinceStart()); | 201 | + auto r = |
| 202 | + ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor, | ||
| 203 | + s->GetCurrentSegment(), s->GetNumFramesSinceStart()); | ||
| 201 | r.text = ApplyInverseTextNormalization(r.text); | 204 | r.text = ApplyInverseTextNormalization(r.text); |
| 202 | return r; | 205 | return r; |
| 203 | } | 206 | } |
| @@ -306,7 +309,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -306,7 +309,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 306 | std::vector<OnlineCtcDecoderResult> results(1); | 309 | std::vector<OnlineCtcDecoderResult> results(1); |
| 307 | results[0] = std::move(s->GetCtcResult()); | 310 | results[0] = std::move(s->GetCtcResult()); |
| 308 | 311 | ||
| 309 | - decoder_->Decode(std::move(out[0]), &results, &s, 1); | 312 | + std::vector<int64_t> log_probs_shape = |
| 313 | + out[0].GetTensorTypeAndShapeInfo().GetShape(); | ||
| 314 | + decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0], | ||
| 315 | + log_probs_shape[1], log_probs_shape[2], &results, &s, 1); | ||
| 310 | s->SetCtcResult(results[0]); | 316 | s->SetCtcResult(results[0]); |
| 311 | } | 317 | } |
| 312 | 318 |
| @@ -27,6 +27,7 @@ | @@ -27,6 +27,7 @@ | ||
| 27 | #include "sherpa-onnx/csrc/text-utils.h" | 27 | #include "sherpa-onnx/csrc/text-utils.h" |
| 28 | 28 | ||
| 29 | #if SHERPA_ONNX_ENABLE_RKNN | 29 | #if SHERPA_ONNX_ENABLE_RKNN |
| 30 | +#include "sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h" | ||
| 30 | #include "sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h" | 31 | #include "sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h" |
| 31 | #endif | 32 | #endif |
| 32 | 33 | ||
| @@ -37,12 +38,15 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | @@ -37,12 +38,15 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | ||
| 37 | if (config.model_config.provider_config.provider == "rknn") { | 38 | if (config.model_config.provider_config.provider == "rknn") { |
| 38 | #if SHERPA_ONNX_ENABLE_RKNN | 39 | #if SHERPA_ONNX_ENABLE_RKNN |
| 39 | // Currently, only zipformer v1 is suported for rknn | 40 | // Currently, only zipformer v1 is suported for rknn |
| 40 | - if (config.model_config.transducer.encoder.empty()) { | 41 | + if (config.model_config.transducer.encoder.empty() && |
| 42 | + config.model_config.zipformer2_ctc.model.empty()) { | ||
| 41 | SHERPA_ONNX_LOGE( | 43 | SHERPA_ONNX_LOGE( |
| 42 | - "Only Zipformer transducers are currently supported by rknn. " | ||
| 43 | - "Fallback to CPU"); | ||
| 44 | - } else { | 44 | + "Only Zipformer transducers and CTC models are currently supported " |
| 45 | + "by rknn. Fallback to CPU"); | ||
| 46 | + } else if (!config.model_config.transducer.encoder.empty()) { | ||
| 45 | return std::make_unique<OnlineRecognizerTransducerRknnImpl>(config); | 47 | return std::make_unique<OnlineRecognizerTransducerRknnImpl>(config); |
| 48 | + } else if (!config.model_config.zipformer2_ctc.model.empty()) { | ||
| 49 | + return std::make_unique<OnlineRecognizerCtcRknnImpl>(config); | ||
| 46 | } | 50 | } |
| 47 | #else | 51 | #else |
| 48 | SHERPA_ONNX_LOGE( | 52 | SHERPA_ONNX_LOGE( |
| 1 | +// sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <ios> | ||
| 10 | +#include <memory> | ||
| 11 | +#include <sstream> | ||
| 12 | +#include <string> | ||
| 13 | +#include <utility> | ||
| 14 | +#include <vector> | ||
| 15 | + | ||
| 16 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 17 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 18 | +#include "sherpa-onnx/csrc/online-ctc-decoder.h" | ||
| 19 | +#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h" | ||
| 20 | +#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" | ||
| 21 | +#include "sherpa-onnx/csrc/online-recognizer-impl.h" | ||
| 22 | +#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h" | ||
| 23 | +#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h" | ||
| 24 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 25 | + | ||
| 26 | +namespace sherpa_onnx { | ||
| 27 | + | ||
| 28 | +// defined in ../online-recognizer-ctc-impl.h | ||
| 29 | +OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src, | ||
| 30 | + const SymbolTable &sym_table, | ||
| 31 | + float frame_shift_ms, | ||
| 32 | + int32_t subsampling_factor, int32_t segment, | ||
| 33 | + int32_t frames_since_start); | ||
| 34 | + | ||
| 35 | +class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl { | ||
| 36 | + public: | ||
| 37 | + explicit OnlineRecognizerCtcRknnImpl(const OnlineRecognizerConfig &config) | ||
| 38 | + : OnlineRecognizerImpl(config), | ||
| 39 | + config_(config), | ||
| 40 | + model_( | ||
| 41 | + std::make_unique<OnlineZipformerCtcModelRknn>(config.model_config)), | ||
| 42 | + endpoint_(config_.endpoint_config) { | ||
| 43 | + if (!config.model_config.tokens_buf.empty()) { | ||
| 44 | + sym_ = SymbolTable(config.model_config.tokens_buf, false); | ||
| 45 | + } else { | ||
| 46 | + /// assuming tokens_buf and tokens are guaranteed to not both be empty | ||
| 47 | + sym_ = SymbolTable(config.model_config.tokens, true); | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + InitDecoder(); | ||
| 51 | + } | ||
| 52 | + | ||
| 53 | + template <typename Manager> | ||
| 54 | + explicit OnlineRecognizerCtcRknnImpl(Manager *mgr, | ||
| 55 | + const OnlineRecognizerConfig &config) | ||
| 56 | + : OnlineRecognizerImpl(mgr, config), | ||
| 57 | + config_(config), | ||
| 58 | + model_( | ||
| 59 | + std::make_unique<OnlineZipformerCtcModelRknn>(config.model_config)), | ||
| 60 | + sym_(mgr, config.model_config.tokens), | ||
| 61 | + endpoint_(config_.endpoint_config) { | ||
| 62 | + InitDecoder(); | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + std::unique_ptr<OnlineStream> CreateStream() const override { | ||
| 66 | + auto stream = std::make_unique<OnlineStreamRknn>(config_.feat_config); | ||
| 67 | + stream->SetZipformerEncoderStates(model_->GetInitStates()); | ||
| 68 | + stream->SetFasterDecoder(decoder_->CreateFasterDecoder()); | ||
| 69 | + return stream; | ||
| 70 | + } | ||
| 71 | + | ||
| 72 | + bool IsReady(OnlineStream *s) const override { | ||
| 73 | + return s->GetNumProcessedFrames() + model_->ChunkSize() < | ||
| 74 | + s->NumFramesReady(); | ||
| 75 | + } | ||
| 76 | + | ||
| 77 | + void DecodeStreams(OnlineStream **ss, int32_t n) const override { | ||
| 78 | + for (int32_t i = 0; i != n; ++i) { | ||
| 79 | + DecodeStream(reinterpret_cast<OnlineStreamRknn *>(ss[i])); | ||
| 80 | + } | ||
| 81 | + } | ||
| 82 | + | ||
| 83 | + OnlineRecognizerResult GetResult(OnlineStream *s) const override { | ||
| 84 | + OnlineCtcDecoderResult decoder_result = s->GetCtcResult(); | ||
| 85 | + | ||
| 86 | + // TODO(fangjun): Remember to change these constants if needed | ||
| 87 | + int32_t frame_shift_ms = 10; | ||
| 88 | + int32_t subsampling_factor = 4; | ||
| 89 | + auto r = | ||
| 90 | + ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor, | ||
| 91 | + s->GetCurrentSegment(), s->GetNumFramesSinceStart()); | ||
| 92 | + r.text = ApplyInverseTextNormalization(r.text); | ||
| 93 | + return r; | ||
| 94 | + } | ||
| 95 | + | ||
| 96 | + bool IsEndpoint(OnlineStream *s) const override { | ||
| 97 | + if (!config_.enable_endpoint) { | ||
| 98 | + return false; | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + int32_t num_processed_frames = s->GetNumProcessedFrames(); | ||
| 102 | + | ||
| 103 | + // frame shift is 10 milliseconds | ||
| 104 | + float frame_shift_in_seconds = 0.01; | ||
| 105 | + | ||
| 106 | + // subsampling factor is 4 | ||
| 107 | + int32_t trailing_silence_frames = s->GetCtcResult().num_trailing_blanks * 4; | ||
| 108 | + | ||
| 109 | + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, | ||
| 110 | + frame_shift_in_seconds); | ||
| 111 | + } | ||
| 112 | + | ||
| 113 | + void Reset(OnlineStream *s) const override { | ||
| 114 | + // segment is incremented only when the last | ||
| 115 | + // result is not empty | ||
| 116 | + const auto &r = s->GetCtcResult(); | ||
| 117 | + if (!r.tokens.empty()) { | ||
| 118 | + s->GetCurrentSegment() += 1; | ||
| 119 | + } | ||
| 120 | + | ||
| 121 | + // clear result | ||
| 122 | + s->SetCtcResult({}); | ||
| 123 | + | ||
| 124 | + // clear states | ||
| 125 | + reinterpret_cast<OnlineStreamRknn *>(s)->SetZipformerEncoderStates( | ||
| 126 | + model_->GetInitStates()); | ||
| 127 | + | ||
| 128 | + s->GetFasterDecoderProcessedFrames() = 0; | ||
| 129 | + | ||
| 130 | + // Note: We only update counters. The underlying audio samples | ||
| 131 | + // are not discarded. | ||
| 132 | + s->Reset(); | ||
| 133 | + } | ||
| 134 | + | ||
| 135 | + private: | ||
| 136 | + void InitDecoder() { | ||
| 137 | + if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") && | ||
| 138 | + !sym_.Contains("<blank>")) { | ||
| 139 | + SHERPA_ONNX_LOGE( | ||
| 140 | + "We expect that tokens.txt contains " | ||
| 141 | + "the symbol <blk> or <eps> or <blank> and its ID."); | ||
| 142 | + exit(-1); | ||
| 143 | + } | ||
| 144 | + | ||
| 145 | + int32_t blank_id = 0; | ||
| 146 | + if (sym_.Contains("<blk>")) { | ||
| 147 | + blank_id = sym_["<blk>"]; | ||
| 148 | + } else if (sym_.Contains("<eps>")) { | ||
| 149 | + // for tdnn models of the yesno recipe from icefall | ||
| 150 | + blank_id = sym_["<eps>"]; | ||
| 151 | + } else if (sym_.Contains("<blank>")) { | ||
| 152 | + // for WeNet CTC models | ||
| 153 | + blank_id = sym_["<blank>"]; | ||
| 154 | + } | ||
| 155 | + | ||
| 156 | + if (!config_.ctc_fst_decoder_config.graph.empty()) { | ||
| 157 | + decoder_ = std::make_unique<OnlineCtcFstDecoder>( | ||
| 158 | + config_.ctc_fst_decoder_config, blank_id); | ||
| 159 | + } else if (config_.decoding_method == "greedy_search") { | ||
| 160 | + decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id); | ||
| 161 | + } else { | ||
| 162 | + SHERPA_ONNX_LOGE( | ||
| 163 | + "Unsupported decoding method: %s for streaming CTC models", | ||
| 164 | + config_.decoding_method.c_str()); | ||
| 165 | + exit(-1); | ||
| 166 | + } | ||
| 167 | + } | ||
| 168 | + | ||
| 169 | + void DecodeStream(OnlineStreamRknn *s) const { | ||
| 170 | + int32_t chunk_size = model_->ChunkSize(); | ||
| 171 | + int32_t chunk_shift = model_->ChunkShift(); | ||
| 172 | + | ||
| 173 | + int32_t feat_dim = s->FeatureDim(); | ||
| 174 | + | ||
| 175 | + const auto num_processed_frames = s->GetNumProcessedFrames(); | ||
| 176 | + std::vector<float> features = | ||
| 177 | + s->GetFrames(num_processed_frames, chunk_size); | ||
| 178 | + s->GetNumProcessedFrames() += chunk_shift; | ||
| 179 | + | ||
| 180 | + auto &states = s->GetZipformerEncoderStates(); | ||
| 181 | + auto p = model_->Run(features, std::move(states)); | ||
| 182 | + states = std::move(p.second); | ||
| 183 | + | ||
| 184 | + std::vector<OnlineCtcDecoderResult> results(1); | ||
| 185 | + results[0] = std::move(s->GetCtcResult()); | ||
| 186 | + | ||
| 187 | + auto attr = model_->GetOutAttr(); | ||
| 188 | + | ||
| 189 | + decoder_->Decode(p.first.data(), attr.dims[0], attr.dims[1], attr.dims[2], | ||
| 190 | + &results, reinterpret_cast<OnlineStream **>(&s), 1); | ||
| 191 | + s->SetCtcResult(results[0]); | ||
| 192 | + } | ||
| 193 | + | ||
| 194 | + private: | ||
| 195 | + OnlineRecognizerConfig config_; | ||
| 196 | + std::unique_ptr<OnlineZipformerCtcModelRknn> model_; | ||
| 197 | + std::unique_ptr<OnlineCtcDecoder> decoder_; | ||
| 198 | + SymbolTable sym_; | ||
| 199 | + Endpoint endpoint_; | ||
| 200 | +}; | ||
| 201 | + | ||
| 202 | +} // namespace sherpa_onnx | ||
| 203 | + | ||
| 204 | +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h" | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <sstream> | ||
| 9 | +#include <string> | ||
| 10 | +#include <unordered_map> | ||
| 11 | +#include <utility> | ||
| 12 | +#include <vector> | ||
| 13 | + | ||
| 14 | +#if __ANDROID_API__ >= 9 | ||
| 15 | +#include "android/asset_manager.h" | ||
| 16 | +#include "android/asset_manager_jni.h" | ||
| 17 | +#endif | ||
| 18 | + | ||
| 19 | +#if __OHOS__ | ||
| 20 | +#include "rawfile/raw_file_manager.h" | ||
| 21 | +#endif | ||
| 22 | + | ||
| 23 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 24 | +#include "sherpa-onnx/csrc/rknn/macros.h" | ||
| 25 | +#include "sherpa-onnx/csrc/rknn/utils.h" | ||
| 26 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 27 | + | ||
| 28 | +namespace sherpa_onnx { | ||
| 29 | + | ||
| 30 | +class OnlineZipformerCtcModelRknn::Impl { | ||
| 31 | + public: | ||
| 32 | + ~Impl() { | ||
| 33 | + auto ret = rknn_destroy(ctx_); | ||
| 34 | + if (ret != RKNN_SUCC) { | ||
| 35 | + SHERPA_ONNX_LOGE("Failed to destroy the context"); | ||
| 36 | + } | ||
| 37 | + } | ||
| 38 | + | ||
| 39 | + explicit Impl(const OnlineModelConfig &config) : config_(config) { | ||
| 40 | + { | ||
| 41 | + auto buf = ReadFile(config.zipformer2_ctc.model); | ||
| 42 | + Init(buf.data(), buf.size()); | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + int32_t ret = RKNN_SUCC; | ||
| 46 | + switch (config_.num_threads) { | ||
| 47 | + case 1: | ||
| 48 | + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_AUTO); | ||
| 49 | + break; | ||
| 50 | + case 0: | ||
| 51 | + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0); | ||
| 52 | + break; | ||
| 53 | + case -1: | ||
| 54 | + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_1); | ||
| 55 | + break; | ||
| 56 | + case -2: | ||
| 57 | + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_2); | ||
| 58 | + break; | ||
| 59 | + case -3: | ||
| 60 | + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1); | ||
| 61 | + break; | ||
| 62 | + case -4: | ||
| 63 | + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1_2); | ||
| 64 | + break; | ||
| 65 | + default: | ||
| 66 | + SHERPA_ONNX_LOGE( | ||
| 67 | + "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core " | ||
| 68 | + "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d", | ||
| 69 | + config_.num_threads); | ||
| 70 | + break; | ||
| 71 | + } | ||
| 72 | + if (ret != RKNN_SUCC) { | ||
| 73 | + SHERPA_ONNX_LOGE( | ||
| 74 | + "Failed to select npu core to run the model (You can ignore it if " | ||
| 75 | + "you " | ||
| 76 | + "are not using RK3588."); | ||
| 77 | + } | ||
| 78 | + } | ||
| 79 | + | ||
| 80 | + // TODO(fangjun): Support Android | ||
| 81 | + | ||
| 82 | + std::vector<std::vector<uint8_t>> GetInitStates() const { | ||
| 83 | + // input_attrs_[0] is for the feature | ||
| 84 | + // input_attrs_[1:] is for states | ||
| 85 | + // so we use -1 here | ||
| 86 | + std::vector<std::vector<uint8_t>> states(input_attrs_.size() - 1); | ||
| 87 | + | ||
| 88 | + int32_t i = -1; | ||
| 89 | + for (auto &attr : input_attrs_) { | ||
| 90 | + i += 1; | ||
| 91 | + if (i == 0) { | ||
| 92 | + // skip processing the attr for features. | ||
| 93 | + continue; | ||
| 94 | + } | ||
| 95 | + | ||
| 96 | + if (attr.type == RKNN_TENSOR_FLOAT16) { | ||
| 97 | + states[i - 1].resize(attr.n_elems * sizeof(float)); | ||
| 98 | + } else if (attr.type == RKNN_TENSOR_INT64) { | ||
| 99 | + states[i - 1].resize(attr.n_elems * sizeof(int64_t)); | ||
| 100 | + } else { | ||
| 101 | + SHERPA_ONNX_LOGE("Unsupported tensor type: %d, %s", attr.type, | ||
| 102 | + get_type_string(attr.type)); | ||
| 103 | + SHERPA_ONNX_EXIT(-1); | ||
| 104 | + } | ||
| 105 | + } | ||
| 106 | + | ||
| 107 | + return states; | ||
| 108 | + } | ||
| 109 | + | ||
| 110 | + std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run( | ||
| 111 | + std::vector<float> features, | ||
| 112 | + std::vector<std::vector<uint8_t>> states) const { | ||
| 113 | + std::vector<rknn_input> inputs(input_attrs_.size()); | ||
| 114 | + | ||
| 115 | + for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) { | ||
| 116 | + auto &input = inputs[i]; | ||
| 117 | + auto &attr = input_attrs_[i]; | ||
| 118 | + input.index = attr.index; | ||
| 119 | + | ||
| 120 | + if (attr.type == RKNN_TENSOR_FLOAT16) { | ||
| 121 | + input.type = RKNN_TENSOR_FLOAT32; | ||
| 122 | + } else if (attr.type == RKNN_TENSOR_INT64) { | ||
| 123 | + input.type = RKNN_TENSOR_INT64; | ||
| 124 | + } else { | ||
| 125 | + SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type, | ||
| 126 | + get_type_string(attr.type)); | ||
| 127 | + SHERPA_ONNX_EXIT(-1); | ||
| 128 | + } | ||
| 129 | + | ||
| 130 | + input.fmt = attr.fmt; | ||
| 131 | + if (i == 0) { | ||
| 132 | + input.buf = reinterpret_cast<void *>(features.data()); | ||
| 133 | + input.size = features.size() * sizeof(float); | ||
| 134 | + } else { | ||
| 135 | + input.buf = reinterpret_cast<void *>(states[i - 1].data()); | ||
| 136 | + input.size = states[i - 1].size(); | ||
| 137 | + } | ||
| 138 | + } | ||
| 139 | + | ||
| 140 | + std::vector<float> out(output_attrs_[0].n_elems); | ||
| 141 | + | ||
| 142 | + // Note(fangjun): We can reuse the memory from input argument `states` | ||
| 143 | + // auto next_states = GetInitStates(); | ||
| 144 | + auto &next_states = states; | ||
| 145 | + | ||
| 146 | + std::vector<rknn_output> outputs(output_attrs_.size()); | ||
| 147 | + for (int32_t i = 0; i < outputs.size(); ++i) { | ||
| 148 | + auto &output = outputs[i]; | ||
| 149 | + auto &attr = output_attrs_[i]; | ||
| 150 | + output.index = attr.index; | ||
| 151 | + output.is_prealloc = 1; | ||
| 152 | + | ||
| 153 | + if (attr.type == RKNN_TENSOR_FLOAT16) { | ||
| 154 | + output.want_float = 1; | ||
| 155 | + } else if (attr.type == RKNN_TENSOR_INT64) { | ||
| 156 | + output.want_float = 0; | ||
| 157 | + } else { | ||
| 158 | + SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type, | ||
| 159 | + get_type_string(attr.type)); | ||
| 160 | + SHERPA_ONNX_EXIT(-1); | ||
| 161 | + } | ||
| 162 | + | ||
| 163 | + if (i == 0) { | ||
| 164 | + output.size = out.size() * sizeof(float); | ||
| 165 | + output.buf = reinterpret_cast<void *>(out.data()); | ||
| 166 | + } else { | ||
| 167 | + output.size = next_states[i - 1].size(); | ||
| 168 | + output.buf = reinterpret_cast<void *>(next_states[i - 1].data()); | ||
| 169 | + } | ||
| 170 | + } | ||
| 171 | + | ||
| 172 | + auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data()); | ||
| 173 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs"); | ||
| 174 | + | ||
| 175 | + ret = rknn_run(ctx_, nullptr); | ||
| 176 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model"); | ||
| 177 | + | ||
| 178 | + ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr); | ||
| 179 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output"); | ||
| 180 | + | ||
| 181 | + for (int32_t i = 0; i < next_states.size(); ++i) { | ||
| 182 | + const auto &attr = input_attrs_[i + 1]; | ||
| 183 | + if (attr.n_dims == 4) { | ||
| 184 | + // TODO(fangjun): The transpose is copied from | ||
| 185 | + // https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22 | ||
| 186 | + // I don't understand why we need to do that. | ||
| 187 | + std::vector<uint8_t> dst(next_states[i].size()); | ||
| 188 | + int32_t n = attr.dims[0]; | ||
| 189 | + int32_t h = attr.dims[1]; | ||
| 190 | + int32_t w = attr.dims[2]; | ||
| 191 | + int32_t c = attr.dims[3]; | ||
| 192 | + ConvertNCHWtoNHWC( | ||
| 193 | + reinterpret_cast<const float *>(next_states[i].data()), n, c, h, w, | ||
| 194 | + reinterpret_cast<float *>(dst.data())); | ||
| 195 | + next_states[i] = std::move(dst); | ||
| 196 | + } | ||
| 197 | + } | ||
| 198 | + | ||
| 199 | + return {std::move(out), std::move(next_states)}; | ||
| 200 | + } | ||
| 201 | + | ||
  // Number of feature frames the model consumes per chunk
  // (model metadata key "T").
  int32_t ChunkSize() const { return T_; }

  // Number of frames to advance between consecutive chunks
  // (model metadata key "decode_chunk_len").
  int32_t ChunkShift() const { return decode_chunk_len_; }

  // Vocabulary size, read from the last dimension of the first model output.
  int32_t VocabSize() const { return vocab_size_; }

  // Attribute (shape/type/format) of the first model output.
  rknn_tensor_attr GetOutAttr() const { return output_attrs_[0]; }
| 209 | + | ||
| 210 | + private: | ||
| 211 | + void Init(void *model_data, size_t model_data_length) { | ||
| 212 | + auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr); | ||
| 213 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init model '%s'", | ||
| 214 | + config_.zipformer2_ctc.model.c_str()); | ||
| 215 | + | ||
| 216 | + if (config_.debug) { | ||
| 217 | + rknn_sdk_version v; | ||
| 218 | + ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v)); | ||
| 219 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version"); | ||
| 220 | + | ||
| 221 | + SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version, | ||
| 222 | + v.drv_version); | ||
| 223 | + } | ||
| 224 | + | ||
| 225 | + rknn_input_output_num io_num; | ||
| 226 | + ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num)); | ||
| 227 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model"); | ||
| 228 | + | ||
| 229 | + if (config_.debug) { | ||
| 230 | + SHERPA_ONNX_LOGE("model: %d inputs, %d outputs", | ||
| 231 | + static_cast<int32_t>(io_num.n_input), | ||
| 232 | + static_cast<int32_t>(io_num.n_output)); | ||
| 233 | + } | ||
| 234 | + | ||
| 235 | + input_attrs_.resize(io_num.n_input); | ||
| 236 | + output_attrs_.resize(io_num.n_output); | ||
| 237 | + | ||
| 238 | + int32_t i = 0; | ||
| 239 | + for (auto &attr : input_attrs_) { | ||
| 240 | + memset(&attr, 0, sizeof(attr)); | ||
| 241 | + attr.index = i; | ||
| 242 | + ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); | ||
| 243 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i); | ||
| 244 | + i += 1; | ||
| 245 | + } | ||
| 246 | + | ||
| 247 | + if (config_.debug) { | ||
| 248 | + std::ostringstream os; | ||
| 249 | + std::string sep; | ||
| 250 | + for (auto &attr : input_attrs_) { | ||
| 251 | + os << sep << ToString(attr); | ||
| 252 | + sep = "\n"; | ||
| 253 | + } | ||
| 254 | + SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s", | ||
| 255 | + os.str().c_str()); | ||
| 256 | + } | ||
| 257 | + | ||
| 258 | + i = 0; | ||
| 259 | + for (auto &attr : output_attrs_) { | ||
| 260 | + memset(&attr, 0, sizeof(attr)); | ||
| 261 | + attr.index = i; | ||
| 262 | + ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); | ||
| 263 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i); | ||
| 264 | + i += 1; | ||
| 265 | + } | ||
| 266 | + | ||
| 267 | + if (config_.debug) { | ||
| 268 | + std::ostringstream os; | ||
| 269 | + std::string sep; | ||
| 270 | + for (auto &attr : output_attrs_) { | ||
| 271 | + os << sep << ToString(attr); | ||
| 272 | + sep = "\n"; | ||
| 273 | + } | ||
| 274 | + SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s", | ||
| 275 | + os.str().c_str()); | ||
| 276 | + } | ||
| 277 | + | ||
| 278 | + rknn_custom_string custom_string; | ||
| 279 | + ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string, | ||
| 280 | + sizeof(custom_string)); | ||
| 281 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model"); | ||
| 282 | + if (config_.debug) { | ||
| 283 | + SHERPA_ONNX_LOGE("customs string: %s", custom_string.string); | ||
| 284 | + } | ||
| 285 | + auto meta = Parse(custom_string); | ||
| 286 | + | ||
| 287 | + if (config_.debug) { | ||
| 288 | + for (const auto &p : meta) { | ||
| 289 | + SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str()); | ||
| 290 | + } | ||
| 291 | + } | ||
| 292 | + | ||
| 293 | + if (meta.count("T")) { | ||
| 294 | + T_ = atoi(meta.at("T").c_str()); | ||
| 295 | + } | ||
| 296 | + | ||
| 297 | + if (meta.count("decode_chunk_len")) { | ||
| 298 | + decode_chunk_len_ = atoi(meta.at("decode_chunk_len").c_str()); | ||
| 299 | + } | ||
| 300 | + | ||
| 301 | + vocab_size_ = output_attrs_[0].dims[2]; | ||
| 302 | + | ||
| 303 | + if (config_.debug) { | ||
| 304 | +#if __OHOS__ | ||
| 305 | + SHERPA_ONNX_LOGE("T: %{public}d", T_); | ||
| 306 | + SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_); | ||
| 307 | + SHERPA_ONNX_LOGE("vocab_size: %{public}d", vocab_size); | ||
| 308 | +#else | ||
| 309 | + SHERPA_ONNX_LOGE("T: %d", T_); | ||
| 310 | + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); | ||
| 311 | + SHERPA_ONNX_LOGE("vocab_size: %d", vocab_size_); | ||
| 312 | +#endif | ||
| 313 | + } | ||
| 314 | + | ||
| 315 | + if (T_ == 0) { | ||
| 316 | + SHERPA_ONNX_LOGE( | ||
| 317 | + "Invalid T. Please use the script from icefall to export your model"); | ||
| 318 | + SHERPA_ONNX_EXIT(-1); | ||
| 319 | + } | ||
| 320 | + | ||
| 321 | + if (decode_chunk_len_ == 0) { | ||
| 322 | + SHERPA_ONNX_LOGE( | ||
| 323 | + "Invalid decode_chunk_len. Please use the script from icefall to " | ||
| 324 | + "export your model"); | ||
| 325 | + SHERPA_ONNX_EXIT(-1); | ||
| 326 | + } | ||
| 327 | + } | ||
| 328 | + | ||
 private:
  OnlineModelConfig config_;  // model path, debug flag, etc.
  rknn_context ctx_ = 0;      // rknn runtime context created by rknn_init()

  std::vector<rknn_tensor_attr> input_attrs_;   // one attr per model input
  std::vector<rknn_tensor_attr> output_attrs_;  // one attr per model output

  int32_t T_ = 0;                 // frames per chunk (metadata key "T")
  int32_t decode_chunk_len_ = 0;  // chunk shift (metadata "decode_chunk_len")
  int32_t vocab_size_ = 0;        // from output_attrs_[0].dims[2]
};
| 340 | + | ||
// Out-of-line so the std::unique_ptr<Impl> deleter sees the complete Impl
// type (Impl is only forward-declared in the header).
OnlineZipformerCtcModelRknn::~OnlineZipformerCtcModelRknn() = default;

// Loads the .rknn model described by config and initializes the runtime.
OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
    const OnlineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}
| 346 | + | ||
| 347 | +template <typename Manager> | ||
| 348 | +OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn( | ||
| 349 | + Manager *mgr, const OnlineModelConfig &config) | ||
| 350 | + : impl_(std::make_unique<OnlineZipformerCtcModelRknn>(mgr, config)) {} | ||
| 351 | + | ||
// Returns the initial encoder states for a new stream
// (see Impl::GetInitStates).
std::vector<std::vector<uint8_t>> OnlineZipformerCtcModelRknn::GetInitStates()
    const {
  return impl_->GetInitStates();
}

// Runs one chunk of features through the model; returns the model output
// for the chunk together with the updated encoder states.
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>>
OnlineZipformerCtcModelRknn::Run(
    std::vector<float> features,
    std::vector<std::vector<uint8_t>> states) const {
  return impl_->Run(std::move(features), std::move(states));
}

// Frames consumed per chunk (model metadata key "T").
int32_t OnlineZipformerCtcModelRknn::ChunkSize() const {
  return impl_->ChunkSize();
}

// Frames to advance between chunks (metadata key "decode_chunk_len").
int32_t OnlineZipformerCtcModelRknn::ChunkShift() const {
  return impl_->ChunkShift();
}

// Vocabulary size, taken from the model output shape.
int32_t OnlineZipformerCtcModelRknn::VocabSize() const {
  return impl_->VocabSize();
}

// Attribute of the first model output.
rknn_tensor_attr OnlineZipformerCtcModelRknn::GetOutAttr() const {
  return impl_->GetOutAttr();
}
| 379 | + | ||
#if __ANDROID_API__ >= 9
// Explicit instantiation so Android callers can load the model from an
// asset manager.
template OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
    AAssetManager *mgr, const OnlineModelConfig &config);
#endif

#if __OHOS__
// Explicit instantiation for HarmonyOS resource-manager based loading.
template OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
    NativeResourceManager *mgr, const OnlineModelConfig &config);
#endif

}  // namespace sherpa_onnx
// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_

#include <memory>
#include <utility>
#include <vector>

#include "rknn_api.h"  // NOLINT
#include "sherpa-onnx/csrc/online-model-config.h"

namespace sherpa_onnx {

// Streaming zipformer CTC model running on a Rockchip NPU via the RKNN
// runtime.  Pimpl wrapper; all real work happens in the Impl class
// defined in the .cc file.
class OnlineZipformerCtcModelRknn {
 public:
  ~OnlineZipformerCtcModelRknn();

  // Loads the .rknn model given by the configuration.
  explicit OnlineZipformerCtcModelRknn(const OnlineModelConfig &config);

  // Loads the model through a platform resource manager, e.g.,
  // AAssetManager on Android or NativeResourceManager on HarmonyOS.
  template <typename Manager>
  OnlineZipformerCtcModelRknn(Manager *mgr, const OnlineModelConfig &config);

  // Returns the initial model states for a new stream, one byte buffer
  // per state tensor.
  std::vector<std::vector<uint8_t>> GetInitStates() const;

  // Runs one chunk of features through the model.
  //
  // @param features  Flattened feature chunk for this run.
  // @param states    States returned by the previous call (or by
  //                  GetInitStates() for the first call).
  // @return A pair of (model output for the chunk, next states).
  std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run(
      std::vector<float> features,
      std::vector<std::vector<uint8_t>> states) const;

  // Number of feature frames consumed per chunk (metadata key "T").
  int32_t ChunkSize() const;

  // Number of frames to advance between consecutive chunks
  // (metadata key "decode_chunk_len").
  int32_t ChunkShift() const;

  // Vocabulary size, taken from the model output shape.
  int32_t VocabSize() const;

  // Attribute (shape/type/format) of the first model output.
  rknn_tensor_attr GetOutAttr() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
| 1 | // sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc | 1 | // sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023 Xiaomi Corporation | 3 | +// Copyright (c) 2025 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" | 5 | #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" |
| 6 | 6 | ||
| @@ -22,68 +22,11 @@ | @@ -22,68 +22,11 @@ | ||
| 22 | 22 | ||
| 23 | #include "sherpa-onnx/csrc/file-utils.h" | 23 | #include "sherpa-onnx/csrc/file-utils.h" |
| 24 | #include "sherpa-onnx/csrc/rknn/macros.h" | 24 | #include "sherpa-onnx/csrc/rknn/macros.h" |
| 25 | +#include "sherpa-onnx/csrc/rknn/utils.h" | ||
| 25 | #include "sherpa-onnx/csrc/text-utils.h" | 26 | #include "sherpa-onnx/csrc/text-utils.h" |
| 26 | 27 | ||
| 27 | namespace sherpa_onnx { | 28 | namespace sherpa_onnx { |
| 28 | 29 | ||
| 29 | -// chw -> hwc | ||
| 30 | -static void Transpose(const float *src, int32_t n, int32_t channel, | ||
| 31 | - int32_t height, int32_t width, float *dst) { | ||
| 32 | - for (int32_t i = 0; i < n; ++i) { | ||
| 33 | - for (int32_t h = 0; h < height; ++h) { | ||
| 34 | - for (int32_t w = 0; w < width; ++w) { | ||
| 35 | - for (int32_t c = 0; c < channel; ++c) { | ||
| 36 | - // dst[h, w, c] = src[c, h, w] | ||
| 37 | - dst[i * height * width * channel + h * width * channel + w * channel + | ||
| 38 | - c] = src[i * height * width * channel + c * height * width + | ||
| 39 | - h * width + w]; | ||
| 40 | - } | ||
| 41 | - } | ||
| 42 | - } | ||
| 43 | - } | ||
| 44 | -} | ||
| 45 | - | ||
| 46 | -static std::string ToString(const rknn_tensor_attr &attr) { | ||
| 47 | - std::ostringstream os; | ||
| 48 | - os << "{"; | ||
| 49 | - os << attr.index; | ||
| 50 | - os << ", name: " << attr.name; | ||
| 51 | - os << ", shape: ("; | ||
| 52 | - std::string sep; | ||
| 53 | - for (int32_t i = 0; i < static_cast<int32_t>(attr.n_dims); ++i) { | ||
| 54 | - os << sep << attr.dims[i]; | ||
| 55 | - sep = ","; | ||
| 56 | - } | ||
| 57 | - os << ")"; | ||
| 58 | - os << ", n_elems: " << attr.n_elems; | ||
| 59 | - os << ", size: " << attr.size; | ||
| 60 | - os << ", fmt: " << get_format_string(attr.fmt); | ||
| 61 | - os << ", type: " << get_type_string(attr.type); | ||
| 62 | - os << ", pass_through: " << (attr.pass_through ? "true" : "false"); | ||
| 63 | - os << "}"; | ||
| 64 | - return os.str(); | ||
| 65 | -} | ||
| 66 | - | ||
| 67 | -static std::unordered_map<std::string, std::string> Parse( | ||
| 68 | - const rknn_custom_string &custom_string) { | ||
| 69 | - std::unordered_map<std::string, std::string> ans; | ||
| 70 | - std::vector<std::string> fields; | ||
| 71 | - SplitStringToVector(custom_string.string, ";", false, &fields); | ||
| 72 | - | ||
| 73 | - std::vector<std::string> tmp; | ||
| 74 | - for (const auto &f : fields) { | ||
| 75 | - SplitStringToVector(f, "=", false, &tmp); | ||
| 76 | - if (tmp.size() != 2) { | ||
| 77 | - SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string, | ||
| 78 | - f.c_str()); | ||
| 79 | - SHERPA_ONNX_EXIT(-1); | ||
| 80 | - } | ||
| 81 | - ans[std::move(tmp[0])] = std::move(tmp[1]); | ||
| 82 | - } | ||
| 83 | - | ||
| 84 | - return ans; | ||
| 85 | -} | ||
| 86 | - | ||
| 87 | class OnlineZipformerTransducerModelRknn::Impl { | 30 | class OnlineZipformerTransducerModelRknn::Impl { |
| 88 | public: | 31 | public: |
| 89 | ~Impl() { | 32 | ~Impl() { |
| @@ -285,7 +228,7 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -285,7 +228,7 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 285 | for (int32_t i = 0; i < next_states.size(); ++i) { | 228 | for (int32_t i = 0; i < next_states.size(); ++i) { |
| 286 | const auto &attr = encoder_input_attrs_[i + 1]; | 229 | const auto &attr = encoder_input_attrs_[i + 1]; |
| 287 | if (attr.n_dims == 4) { | 230 | if (attr.n_dims == 4) { |
| 288 | - // TODO(fangjun): The transpose is copied from | 231 | + // TODO(fangjun): The ConvertNCHWtoNHWC is copied from |
| 289 | // https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22 | 232 | // https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22 |
| 290 | // I don't understand why we need to do that. | 233 | // I don't understand why we need to do that. |
| 291 | std::vector<uint8_t> dst(next_states[i].size()); | 234 | std::vector<uint8_t> dst(next_states[i].size()); |
| @@ -293,8 +236,9 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -293,8 +236,9 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 293 | int32_t h = attr.dims[1]; | 236 | int32_t h = attr.dims[1]; |
| 294 | int32_t w = attr.dims[2]; | 237 | int32_t w = attr.dims[2]; |
| 295 | int32_t c = attr.dims[3]; | 238 | int32_t c = attr.dims[3]; |
| 296 | - Transpose(reinterpret_cast<const float *>(next_states[i].data()), n, c, | ||
| 297 | - h, w, reinterpret_cast<float *>(dst.data())); | 239 | + ConvertNCHWtoNHWC( |
| 240 | + reinterpret_cast<const float *>(next_states[i].data()), n, c, h, w, | ||
| 241 | + reinterpret_cast<float *>(dst.data())); | ||
| 298 | next_states[i] = std::move(dst); | 242 | next_states[i] = std::move(dst); |
| 299 | } | 243 | } |
| 300 | } | 244 | } |
| @@ -527,11 +471,9 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -527,11 +471,9 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 527 | #if __OHOS__ | 471 | #if __OHOS__ |
| 528 | SHERPA_ONNX_LOGE("T: %{public}d", T_); | 472 | SHERPA_ONNX_LOGE("T: %{public}d", T_); |
| 529 | SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_); | 473 | SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_); |
| 530 | - SHERPA_ONNX_LOGE("context_size: %{public}d", context_size_); | ||
| 531 | #else | 474 | #else |
| 532 | SHERPA_ONNX_LOGE("T: %d", T_); | 475 | SHERPA_ONNX_LOGE("T: %d", T_); |
| 533 | SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); | 476 | SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); |
| 534 | - SHERPA_ONNX_LOGE("context_size: %d", context_size_); | ||
| 535 | #endif | 477 | #endif |
| 536 | } | 478 | } |
| 537 | } | 479 | } |
| @@ -597,6 +539,11 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -597,6 +539,11 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 597 | SHERPA_ONNX_EXIT(-1); | 539 | SHERPA_ONNX_EXIT(-1); |
| 598 | } | 540 | } |
| 599 | 541 | ||
| 542 | + context_size_ = decoder_input_attrs_[0].dims[1]; | ||
| 543 | + if (config_.debug) { | ||
| 544 | + SHERPA_ONNX_LOGE("context_size: %d", context_size_); | ||
| 545 | + } | ||
| 546 | + | ||
| 600 | i = 0; | 547 | i = 0; |
| 601 | for (auto &attr : decoder_output_attrs_) { | 548 | for (auto &attr : decoder_output_attrs_) { |
| 602 | memset(&attr, 0, sizeof(attr)); | 549 | memset(&attr, 0, sizeof(attr)); |
| @@ -14,8 +14,11 @@ | @@ -14,8 +14,11 @@ | ||
| 14 | 14 | ||
| 15 | namespace sherpa_onnx { | 15 | namespace sherpa_onnx { |
| 16 | 16 | ||
| 17 | -// this is for zipformer v1, i.e., the folder | ||
| 18 | -// pruned_transducer_statelss7_streaming from icefall | 17 | +// this is for zipformer v1 and v2, i.e., the folder |
| 18 | +// pruned_transducer_stateless7_streaming | ||
| 19 | +// and | ||
| 20 | +// zipformer | ||
| 21 | +// from icefall | ||
| 19 | class OnlineZipformerTransducerModelRknn { | 22 | class OnlineZipformerTransducerModelRknn { |
| 20 | public: | 23 | public: |
| 21 | ~OnlineZipformerTransducerModelRknn(); | 24 | ~OnlineZipformerTransducerModelRknn(); |
sherpa-onnx/csrc/rknn/utils.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/rknn/utils.cc | ||
| 2 | +// | ||
| 3 | +// Copyright 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/rknn/utils.h" | ||
| 6 | + | ||
| 7 | +#include <sstream> | ||
| 8 | +#include <unordered_map> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
// Converts a 4-D float tensor from NCHW layout to NHWC layout.
//
// @param src      Input in NCHW order; n * channel * height * width floats.
// @param n        Batch size.
// @param channel  Number of channels.
// @param height   Height.
// @param width    Width.
// @param dst      Output in NHWC order; must hold
//                 n * height * width * channel floats and must not
//                 overlap src.
//
// The per-batch and per-pixel offsets are hoisted out of the inner loop
// instead of recomputing the full 4-D index products on every element.
void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
                       int32_t height, int32_t width, float *dst) {
  int32_t hw = height * width;
  int32_t chw = channel * hw;

  for (int32_t i = 0; i != n; ++i) {
    const float *src_batch = src + i * chw;
    float *dst_batch = dst + i * chw;
    for (int32_t h = 0; h != height; ++h) {
      for (int32_t w = 0; w != width; ++w) {
        float *dst_pixel = dst_batch + (h * width + w) * channel;
        for (int32_t c = 0; c != channel; ++c) {
          // dst[i, h, w, c] = src[i, c, h, w]
          dst_pixel[c] = src_batch[c * hw + h * width + w];
        }
      }
    }
  }
}
| 31 | + | ||
| 32 | +std::string ToString(const rknn_tensor_attr &attr) { | ||
| 33 | + std::ostringstream os; | ||
| 34 | + os << "{"; | ||
| 35 | + os << attr.index; | ||
| 36 | + os << ", name: " << attr.name; | ||
| 37 | + os << ", shape: ("; | ||
| 38 | + std::string sep; | ||
| 39 | + for (int32_t i = 0; i < static_cast<int32_t>(attr.n_dims); ++i) { | ||
| 40 | + os << sep << attr.dims[i]; | ||
| 41 | + sep = ","; | ||
| 42 | + } | ||
| 43 | + os << ")"; | ||
| 44 | + os << ", n_elems: " << attr.n_elems; | ||
| 45 | + os << ", size: " << attr.size; | ||
| 46 | + os << ", fmt: " << get_format_string(attr.fmt); | ||
| 47 | + os << ", type: " << get_type_string(attr.type); | ||
| 48 | + os << ", pass_through: " << (attr.pass_through ? "true" : "false"); | ||
| 49 | + os << "}"; | ||
| 50 | + return os.str(); | ||
| 51 | +} | ||
| 52 | + | ||
| 53 | +std::unordered_map<std::string, std::string> Parse( | ||
| 54 | + const rknn_custom_string &custom_string) { | ||
| 55 | + std::unordered_map<std::string, std::string> ans; | ||
| 56 | + std::vector<std::string> fields; | ||
| 57 | + SplitStringToVector(custom_string.string, ";", false, &fields); | ||
| 58 | + | ||
| 59 | + std::vector<std::string> tmp; | ||
| 60 | + for (const auto &f : fields) { | ||
| 61 | + SplitStringToVector(f, "=", false, &tmp); | ||
| 62 | + if (tmp.size() != 2) { | ||
| 63 | + SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string, | ||
| 64 | + f.c_str()); | ||
| 65 | + SHERPA_ONNX_EXIT(-1); | ||
| 66 | + } | ||
| 67 | + ans[std::move(tmp[0])] = std::move(tmp[1]); | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + return ans; | ||
| 71 | +} | ||
| 72 | + | ||
| 73 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/rknn/utils.h
0 → 100644
// sherpa-onnx/csrc/rknn/utils.h
//
// Copyright 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_RKNN_UTILS_H_
#define SHERPA_ONNX_CSRC_RKNN_UTILS_H_

#include <string>
#include <unordered_map>

#include "rknn_api.h"  // NOLINT

namespace sherpa_onnx {
// Converts a 4-D float tensor from NCHW layout to NHWC layout.
// dst must hold n * channel * height * width floats and must not
// overlap src.
void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
                       int32_t height, int32_t width, float *dst);

// Returns a human-readable description of a tensor attribute (index,
// name, shape, size, format, type, ...) for debug logging.
std::string ToString(const rknn_tensor_attr &attr);

// Parses the model's custom string, a ';'-separated list of key=value
// entries, into a key -> value map.  Exits the process on a malformed
// entry.
std::unordered_map<std::string, std::string> Parse(
    const rknn_custom_string &custom_string);
}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_RKNN_UTILS_H_
| @@ -83,6 +83,7 @@ for a list of pre-trained models to download. | @@ -83,6 +83,7 @@ for a list of pre-trained models to download. | ||
| 83 | po.Read(argc, argv); | 83 | po.Read(argc, argv); |
| 84 | if (po.NumArgs() < 1) { | 84 | if (po.NumArgs() < 1) { |
| 85 | po.PrintUsage(); | 85 | po.PrintUsage(); |
| 86 | + fprintf(stderr, "Error! Please provide at lease 1 wav file\n"); | ||
| 86 | exit(EXIT_FAILURE); | 87 | exit(EXIT_FAILURE); |
| 87 | } | 88 | } |
| 88 | 89 |
-
请 注册 或 登录 后发表评论