Fangjun Kuang
Committed by GitHub

Add transducer modified_beam_search for RKNN. (#1949)

@@ -155,6 +155,7 @@ if(SHERPA_ONNX_ENABLE_RKNN) @@ -155,6 +155,7 @@ if(SHERPA_ONNX_ENABLE_RKNN)
155 list(APPEND sources 155 list(APPEND sources
156 ./rknn/online-stream-rknn.cc 156 ./rknn/online-stream-rknn.cc
157 ./rknn/online-transducer-greedy-search-decoder-rknn.cc 157 ./rknn/online-transducer-greedy-search-decoder-rknn.cc
  158 + ./rknn/online-transducer-modified-beam-search-decoder-rknn.cc
158 ./rknn/online-zipformer-ctc-model-rknn.cc 159 ./rknn/online-zipformer-ctc-model-rknn.cc
159 ./rknn/online-zipformer-transducer-model-rknn.cc 160 ./rknn/online-zipformer-transducer-model-rknn.cc
160 ./rknn/utils.cc 161 ./rknn/utils.cc
@@ -142,7 +142,6 @@ class Hypotheses { @@ -142,7 +142,6 @@ class Hypotheses {
142 142
143 void Clear() { hyps_dict_.clear(); } 143 void Clear() { hyps_dict_.clear(); }
144 144
145 - private:  
146 // Return a list of hyps contained in this object. 145 // Return a list of hyps contained in this object.
147 std::vector<Hypothesis> Vec() const { 146 std::vector<Hypothesis> Vec() const {
148 std::vector<Hypothesis> ans; 147 std::vector<Hypothesis> ans;
@@ -119,5 +119,17 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { @@ -119,5 +119,17 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
119 return {vec_index.begin(), vec_index.begin() + k_num}; 119 return {vec_index.begin(), vec_index.begin() + k_num};
120 } 120 }
121 121
  122 +template <class T>
  123 +std::vector<int32_t> TopkIndex(const std::vector<std::vector<T>> &vec,
  124 + int32_t topk) {
  125 + std::vector<T> flatten;
  126 + flatten.reserve(vec.size() * vec[0].size());
  127 + for (const auto &v : vec) {
  128 + flatten.insert(flatten.end(), v.begin(), v.end());
  129 + }
  130 +
  131 + return TopkIndex(flatten.data(), flatten.size(), topk);
  132 +}
  133 +
122 } // namespace sherpa_onnx 134 } // namespace sherpa_onnx
123 #endif // SHERPA_ONNX_CSRC_MATH_H_ 135 #endif // SHERPA_ONNX_CSRC_MATH_H_
@@ -16,7 +16,9 @@ @@ -16,7 +16,9 @@
16 #include "sherpa-onnx/csrc/online-recognizer-impl.h" 16 #include "sherpa-onnx/csrc/online-recognizer-impl.h"
17 #include "sherpa-onnx/csrc/online-recognizer.h" 17 #include "sherpa-onnx/csrc/online-recognizer.h"
18 #include "sherpa-onnx/csrc/rknn/online-stream-rknn.h" 18 #include "sherpa-onnx/csrc/rknn/online-stream-rknn.h"
  19 +#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h"
19 #include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h" 20 #include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h"
  21 +#include "sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h"
20 #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" 22 #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
21 #include "sherpa-onnx/csrc/symbol-table.h" 23 #include "sherpa-onnx/csrc/symbol-table.h"
22 24
@@ -87,8 +89,20 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl { @@ -87,8 +89,20 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl {
87 unk_id_ = sym_["<unk>"]; 89 unk_id_ = sym_["<unk>"];
88 } 90 }
89 91
90 - decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoderRknn>(  
91 - model_.get(), unk_id_); 92 + if (config.decoding_method == "greedy_search") {
  93 + decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoderRknn>(
  94 + model_.get(), unk_id_);
  95 + } else if (config.decoding_method == "modified_beam_search") {
  96 + decoder_ =
  97 + std::make_unique<OnlineTransducerModifiedBeamSearchDecoderRknn>(
  98 + model_.get(), config.max_active_paths, unk_id_);
  99 + } else {
  100 + SHERPA_ONNX_LOGE(
  101 + "Invalid decoding method: '%s'. Support only greedy_search and "
  102 + "modified_beam_search.",
  103 + config.decoding_method.c_str());
  104 + SHERPA_ONNX_EXIT(-1);
  105 + }
92 } 106 }
93 107
94 template <typename Manager> 108 template <typename Manager>
@@ -223,7 +237,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl { @@ -223,7 +237,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl {
223 Endpoint endpoint_; 237 Endpoint endpoint_;
224 int32_t unk_id_ = -1; 238 int32_t unk_id_ = -1;
225 std::unique_ptr<OnlineZipformerTransducerModelRknn> model_; 239 std::unique_ptr<OnlineZipformerTransducerModelRknn> model_;
226 - std::unique_ptr<OnlineTransducerGreedySearchDecoderRknn> decoder_; 240 + std::unique_ptr<OnlineTransducerDecoderRknn> decoder_;
227 }; 241 };
228 242
229 } // namespace sherpa_onnx 243 } // namespace sherpa_onnx
@@ -8,7 +8,7 @@ @@ -8,7 +8,7 @@
8 8
9 #include "rknn_api.h" // NOLINT 9 #include "rknn_api.h" // NOLINT
10 #include "sherpa-onnx/csrc/online-stream.h" 10 #include "sherpa-onnx/csrc/online-stream.h"
11 -#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h" 11 +#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h"
12 12
13 namespace sherpa_onnx { 13 namespace sherpa_onnx {
14 14
  1 +// sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_DECODER_RKNN_H_
  6 +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_DECODER_RKNN_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/hypothesis.h"
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +struct OnlineTransducerDecoderResultRknn {
  16 + /// Number of frames after subsampling we have decoded so far
  17 + int32_t frame_offset = 0;
  18 +
  19 + /// The decoded token IDs so far
  20 + std::vector<int64_t> tokens;
  21 +
  22 + /// number of trailing blank frames decoded so far
  23 + int32_t num_trailing_blanks = 0;
  24 +
  25 + /// timestamps[i] contains the output frame index where tokens[i] is decoded.
  26 + std::vector<int32_t> timestamps;
  27 +
  28 + // used only by greedy_search
  29 + std::vector<float> previous_decoder_out;
  30 +
  31 + // used only by modified_beam_search
  32 + Hypotheses hyps;
  33 +
  34 + // used only by modified_beam_search
  35 + std::vector<std::vector<float>> previous_decoder_out2;
  36 +};
  37 +
  38 +class OnlineTransducerDecoderRknn {
  39 + public:
  40 + virtual ~OnlineTransducerDecoderRknn() = default;
  41 +
  42 + /* Return an empty result.
  43 + *
  44 + * To simplify the decoding code, we add `context_size` blanks
  45 + * to the beginning of the decoding result, which will be
  46 + * stripped by calling `StripLeadingBlanks()`.
  47 + */
  48 + virtual OnlineTransducerDecoderResultRknn GetEmptyResult() const = 0;
  49 +
  50 + /** Strip blanks added by `GetEmptyResult()`.
  51 + *
  52 + * @param r It is changed in-place.
  53 + */
  54 + virtual void StripLeadingBlanks(
  55 + OnlineTransducerDecoderResultRknn * /*r*/) const {}
  56 +
  57 + virtual void Decode(std::vector<float> encoder_out,
  58 + OnlineTransducerDecoderResultRknn *result) const = 0;
  59 +};
  60 +
  61 +} // namespace sherpa_onnx
  62 +
  63 +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_DECODER_RKNN_H_
@@ -7,39 +7,26 @@ @@ -7,39 +7,26 @@
7 7
8 #include <vector> 8 #include <vector>
9 9
  10 +#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h"
  11 +#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h"
10 #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" 12 #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
11 13
12 namespace sherpa_onnx { 14 namespace sherpa_onnx {
13 15
14 -struct OnlineTransducerDecoderResultRknn {  
15 - /// Number of frames after subsampling we have decoded so far  
16 - int32_t frame_offset = 0;  
17 -  
18 - /// The decoded token IDs so far  
19 - std::vector<int64_t> tokens;  
20 -  
21 - /// number of trailing blank frames decoded so far  
22 - int32_t num_trailing_blanks = 0;  
23 -  
24 - /// timestamps[i] contains the output frame index where tokens[i] is decoded.  
25 - std::vector<int32_t> timestamps;  
26 -  
27 - std::vector<float> previous_decoder_out;  
28 -};  
29 -  
30 -class OnlineTransducerGreedySearchDecoderRknn { 16 +class OnlineTransducerGreedySearchDecoderRknn
  17 + : public OnlineTransducerDecoderRknn {
31 public: 18 public:
32 explicit OnlineTransducerGreedySearchDecoderRknn( 19 explicit OnlineTransducerGreedySearchDecoderRknn(
33 OnlineZipformerTransducerModelRknn *model, int32_t unk_id = 2, 20 OnlineZipformerTransducerModelRknn *model, int32_t unk_id = 2,
34 float blank_penalty = 0.0) 21 float blank_penalty = 0.0)
35 : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} 22 : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
36 23
37 - OnlineTransducerDecoderResultRknn GetEmptyResult() const; 24 + OnlineTransducerDecoderResultRknn GetEmptyResult() const override;
38 25
39 - void StripLeadingBlanks(OnlineTransducerDecoderResultRknn *r) const; 26 + void StripLeadingBlanks(OnlineTransducerDecoderResultRknn *r) const override;
40 27
41 void Decode(std::vector<float> encoder_out, 28 void Decode(std::vector<float> encoder_out,
42 - OnlineTransducerDecoderResultRknn *result) const; 29 + OnlineTransducerDecoderResultRknn *result) const override;
43 30
44 private: 31 private:
45 OnlineZipformerTransducerModelRknn *model_; // Not owned 32 OnlineZipformerTransducerModelRknn *model_; // Not owned
  1 +// sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h"
  6 +
  7 +#include <algorithm>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/hypothesis.h"
  12 +#include "sherpa-onnx/csrc/macros.h"
  13 +#include "sherpa-onnx/csrc/math.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +OnlineTransducerDecoderResultRknn
  18 +OnlineTransducerModifiedBeamSearchDecoderRknn::GetEmptyResult() const {
  19 + int32_t context_size = model_->ContextSize();
  20 + int32_t blank_id = 0; // always 0
  21 + OnlineTransducerDecoderResultRknn r;
  22 +
  23 + std::vector<int64_t> blanks(context_size, -1);
  24 + blanks.back() = blank_id;
  25 +
  26 + Hypotheses blank_hyp({{blanks, 0}});
  27 + r.hyps = std::move(blank_hyp);
  28 + r.tokens = std::move(blanks);
  29 +
  30 + return r;
  31 +}
  32 +
  33 +void OnlineTransducerModifiedBeamSearchDecoderRknn::StripLeadingBlanks(
  34 + OnlineTransducerDecoderResultRknn *r) const {
  35 + int32_t context_size = model_->ContextSize();
  36 + auto hyp = r->hyps.GetMostProbable(true);
  37 +
  38 + std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
  39 + r->tokens = std::move(tokens);
  40 + r->timestamps = std::move(hyp.timestamps);
  41 +
  42 + r->num_trailing_blanks = hyp.num_trailing_blanks;
  43 +}
  44 +
  45 +static std::vector<std::vector<float>> GetDecoderOut(
  46 + OnlineZipformerTransducerModelRknn *model, const Hypotheses &hyp_vec) {
  47 + std::vector<std::vector<float>> ans;
  48 + ans.reserve(hyp_vec.Size());
  49 +
  50 + int32_t context_size = model->ContextSize();
  51 + for (const auto &p : hyp_vec) {
  52 + const auto &hyp = p.second;
  53 + auto start = hyp.ys.begin() + (hyp.ys.size() - context_size);
  54 + auto end = hyp.ys.end();
  55 + auto tokens = std::vector<int64_t>(start, end);
  56 + auto decoder_out = model->RunDecoder(std::move(tokens));
  57 +
  58 + ans.push_back(std::move(decoder_out));
  59 + }
  60 +
  61 + return ans;
  62 +}
  63 +
  64 +static std::vector<std::vector<float>> GetJoinerOutLogSoftmax(
  65 + OnlineZipformerTransducerModelRknn *model, const float *p_encoder_out,
  66 + const std::vector<std::vector<float>> &decoder_out) {
  67 + std::vector<std::vector<float>> ans;
  68 + ans.reserve(decoder_out.size());
  69 +
  70 + for (const auto &d : decoder_out) {
  71 + auto joiner_out = model->RunJoiner(p_encoder_out, d.data());
  72 +
  73 + LogSoftmax(joiner_out.data(), joiner_out.size());
  74 +
  75 + ans.push_back(std::move(joiner_out));
  76 + }
  77 + return ans;
  78 +}
  79 +
  80 +void OnlineTransducerModifiedBeamSearchDecoderRknn::Decode(
  81 + std::vector<float> encoder_out,
  82 + OnlineTransducerDecoderResultRknn *result) const {
  83 + auto &r = result[0];
  84 + auto attr = model_->GetEncoderOutAttr();
  85 + int32_t num_frames = attr.dims[1];
  86 + int32_t encoder_out_dim = attr.dims[2];
  87 +
  88 + int32_t vocab_size = model_->VocabSize();
  89 + int32_t context_size = model_->ContextSize();
  90 +
  91 + Hypotheses cur = std::move(result->hyps);
  92 + std::vector<Hypothesis> prev;
  93 +
  94 + auto decoder_out = std::move(result->previous_decoder_out2);
  95 + if (decoder_out.empty()) {
  96 + decoder_out = GetDecoderOut(model_, cur);
  97 + }
  98 +
  99 + const float *p_encoder_out = encoder_out.data();
  100 +
  101 + int32_t frame_offset = result->frame_offset;
  102 +
  103 + for (int32_t t = 0; t != num_frames; ++t) {
  104 + prev = cur.Vec();
  105 + cur.Clear();
  106 +
  107 + auto log_probs = GetJoinerOutLogSoftmax(model_, p_encoder_out, decoder_out);
  108 + p_encoder_out += encoder_out_dim;
  109 +
  110 + for (int32_t i = 0; i != prev.size(); ++i) {
  111 + auto log_prob = prev[i].log_prob;
  112 + for (auto &p : log_probs[i]) {
  113 + p += log_prob;
  114 + }
  115 + }
  116 +
  117 + auto topk = TopkIndex(log_probs, max_active_paths_);
  118 + for (auto k : topk) {
  119 + int32_t hyp_index = k / vocab_size;
  120 + int32_t new_token = k % vocab_size;
  121 +
  122 + Hypothesis new_hyp = prev[hyp_index];
  123 + new_hyp.log_prob = log_probs[hyp_index][new_token];
  124 +
  125 + // blank is hardcoded to 0
  126 + // also, it treats unk as blank
  127 + if (new_token != 0 && new_token != unk_id_) {
  128 + new_hyp.ys.push_back(new_token);
  129 + new_hyp.timestamps.push_back(t + frame_offset);
  130 + new_hyp.num_trailing_blanks = 0;
  131 +
  132 + } else {
  133 + ++new_hyp.num_trailing_blanks;
  134 + }
  135 + cur.Add(std::move(new_hyp));
  136 + }
  137 +
  138 + decoder_out = GetDecoderOut(model_, cur);
  139 + }
  140 +
  141 + result->hyps = std::move(cur);
  142 + result->frame_offset += num_frames;
  143 + result->previous_decoder_out2 = std::move(decoder_out);
  144 +}
  145 +
  146 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_RKNN_H_
  6 +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_RKNN_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h"
  11 +#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +class OnlineTransducerModifiedBeamSearchDecoderRknn
  16 + : public OnlineTransducerDecoderRknn {
  17 + public:
  18 + explicit OnlineTransducerModifiedBeamSearchDecoderRknn(
  19 + OnlineZipformerTransducerModelRknn *model, int32_t max_active_paths,
  20 + int32_t unk_id = 2, float blank_penalty = 0.0)
  21 + : model_(model),
  22 + max_active_paths_(max_active_paths),
  23 + unk_id_(unk_id),
  24 + blank_penalty_(blank_penalty) {}
  25 +
  26 + OnlineTransducerDecoderResultRknn GetEmptyResult() const override;
  27 +
  28 + void StripLeadingBlanks(OnlineTransducerDecoderResultRknn *r) const override;
  29 +
  30 + void Decode(std::vector<float> encoder_out,
  31 + OnlineTransducerDecoderResultRknn *result) const override;
  32 +
  33 + private:
  34 + OnlineZipformerTransducerModelRknn *model_; // Not owned
  35 + int32_t max_active_paths_;
  36 + int32_t unk_id_;
  37 + float blank_penalty_;
  38 +};
  39 +
  40 +} // namespace sherpa_onnx
  41 +
  42 +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_RKNN_H_
@@ -6,6 +6,7 @@ @@ -6,6 +6,7 @@
6 6
7 #include <sstream> 7 #include <sstream>
8 #include <unordered_map> 8 #include <unordered_map>
  9 +#include <utility>
9 #include <vector> 10 #include <vector>
10 11
11 #include "sherpa-onnx/csrc/macros.h" 12 #include "sherpa-onnx/csrc/macros.h"