Committed by
GitHub
Add transducer modified_beam_search for RKNN. (#1949)
正在显示
10 个修改的文件
包含
290 行增加
和
25 行删除
| @@ -155,6 +155,7 @@ if(SHERPA_ONNX_ENABLE_RKNN) | @@ -155,6 +155,7 @@ if(SHERPA_ONNX_ENABLE_RKNN) | ||
| 155 | list(APPEND sources | 155 | list(APPEND sources |
| 156 | ./rknn/online-stream-rknn.cc | 156 | ./rknn/online-stream-rknn.cc |
| 157 | ./rknn/online-transducer-greedy-search-decoder-rknn.cc | 157 | ./rknn/online-transducer-greedy-search-decoder-rknn.cc |
| 158 | + ./rknn/online-transducer-modified-beam-search-decoder-rknn.cc | ||
| 158 | ./rknn/online-zipformer-ctc-model-rknn.cc | 159 | ./rknn/online-zipformer-ctc-model-rknn.cc |
| 159 | ./rknn/online-zipformer-transducer-model-rknn.cc | 160 | ./rknn/online-zipformer-transducer-model-rknn.cc |
| 160 | ./rknn/utils.cc | 161 | ./rknn/utils.cc |
| @@ -142,7 +142,6 @@ class Hypotheses { | @@ -142,7 +142,6 @@ class Hypotheses { | ||
| 142 | 142 | ||
| 143 | void Clear() { hyps_dict_.clear(); } | 143 | void Clear() { hyps_dict_.clear(); } |
| 144 | 144 | ||
| 145 | - private: | ||
| 146 | // Return a list of hyps contained in this object. | 145 | // Return a list of hyps contained in this object. |
| 147 | std::vector<Hypothesis> Vec() const { | 146 | std::vector<Hypothesis> Vec() const { |
| 148 | std::vector<Hypothesis> ans; | 147 | std::vector<Hypothesis> ans; |
| @@ -119,5 +119,17 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { | @@ -119,5 +119,17 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { | ||
| 119 | return {vec_index.begin(), vec_index.begin() + k_num}; | 119 | return {vec_index.begin(), vec_index.begin() + k_num}; |
| 120 | } | 120 | } |
| 121 | 121 | ||
| 122 | +template <class T> | ||
| 123 | +std::vector<int32_t> TopkIndex(const std::vector<std::vector<T>> &vec, | ||
| 124 | + int32_t topk) { | ||
| 125 | + std::vector<T> flatten; | ||
| 126 | + flatten.reserve(vec.size() * vec[0].size()); | ||
| 127 | + for (const auto &v : vec) { | ||
| 128 | + flatten.insert(flatten.end(), v.begin(), v.end()); | ||
| 129 | + } | ||
| 130 | + | ||
| 131 | + return TopkIndex(flatten.data(), flatten.size(), topk); | ||
| 132 | +} | ||
| 133 | + | ||
| 122 | } // namespace sherpa_onnx | 134 | } // namespace sherpa_onnx |
| 123 | #endif // SHERPA_ONNX_CSRC_MATH_H_ | 135 | #endif // SHERPA_ONNX_CSRC_MATH_H_ |
| @@ -16,7 +16,9 @@ | @@ -16,7 +16,9 @@ | ||
| 16 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" | 16 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" |
| 17 | #include "sherpa-onnx/csrc/online-recognizer.h" | 17 | #include "sherpa-onnx/csrc/online-recognizer.h" |
| 18 | #include "sherpa-onnx/csrc/rknn/online-stream-rknn.h" | 18 | #include "sherpa-onnx/csrc/rknn/online-stream-rknn.h" |
| 19 | +#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h" | ||
| 19 | #include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h" | 20 | #include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h" |
| 21 | +#include "sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h" | ||
| 20 | #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" | 22 | #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" |
| 21 | #include "sherpa-onnx/csrc/symbol-table.h" | 23 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 22 | 24 | ||
| @@ -87,8 +89,20 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl { | @@ -87,8 +89,20 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl { | ||
| 87 | unk_id_ = sym_["<unk>"]; | 89 | unk_id_ = sym_["<unk>"]; |
| 88 | } | 90 | } |
| 89 | 91 | ||
| 90 | - decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoderRknn>( | ||
| 91 | - model_.get(), unk_id_); | 92 | + if (config.decoding_method == "greedy_search") { |
| 93 | + decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoderRknn>( | ||
| 94 | + model_.get(), unk_id_); | ||
| 95 | + } else if (config.decoding_method == "modified_beam_search") { | ||
| 96 | + decoder_ = | ||
| 97 | + std::make_unique<OnlineTransducerModifiedBeamSearchDecoderRknn>( | ||
| 98 | + model_.get(), config.max_active_paths, unk_id_); | ||
| 99 | + } else { | ||
| 100 | + SHERPA_ONNX_LOGE( | ||
| 101 | + "Invalid decoding method: '%s'. Support only greedy_search and " | ||
| 102 | + "modified_beam_search.", | ||
| 103 | + config.decoding_method.c_str()); | ||
| 104 | + SHERPA_ONNX_EXIT(-1); | ||
| 105 | + } | ||
| 92 | } | 106 | } |
| 93 | 107 | ||
| 94 | template <typename Manager> | 108 | template <typename Manager> |
| @@ -223,7 +237,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl { | @@ -223,7 +237,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl { | ||
| 223 | Endpoint endpoint_; | 237 | Endpoint endpoint_; |
| 224 | int32_t unk_id_ = -1; | 238 | int32_t unk_id_ = -1; |
| 225 | std::unique_ptr<OnlineZipformerTransducerModelRknn> model_; | 239 | std::unique_ptr<OnlineZipformerTransducerModelRknn> model_; |
| 226 | - std::unique_ptr<OnlineTransducerGreedySearchDecoderRknn> decoder_; | 240 | + std::unique_ptr<OnlineTransducerDecoderRknn> decoder_; |
| 227 | }; | 241 | }; |
| 228 | 242 | ||
| 229 | } // namespace sherpa_onnx | 243 | } // namespace sherpa_onnx |
| @@ -8,7 +8,7 @@ | @@ -8,7 +8,7 @@ | ||
| 8 | 8 | ||
| 9 | #include "rknn_api.h" // NOLINT | 9 | #include "rknn_api.h" // NOLINT |
| 10 | #include "sherpa-onnx/csrc/online-stream.h" | 10 | #include "sherpa-onnx/csrc/online-stream.h" |
| 11 | -#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h" | 11 | +#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h" |
| 12 | 12 | ||
| 13 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 14 | 14 |
| 1 | +// sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_DECODER_RKNN_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_DECODER_RKNN_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/hypothesis.h" | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +struct OnlineTransducerDecoderResultRknn { | ||
| 16 | + /// Number of frames after subsampling we have decoded so far | ||
| 17 | + int32_t frame_offset = 0; | ||
| 18 | + | ||
| 19 | + /// The decoded token IDs so far | ||
| 20 | + std::vector<int64_t> tokens; | ||
| 21 | + | ||
| 22 | + /// number of trailing blank frames decoded so far | ||
| 23 | + int32_t num_trailing_blanks = 0; | ||
| 24 | + | ||
| 25 | + /// timestamps[i] contains the output frame index where tokens[i] is decoded. | ||
| 26 | + std::vector<int32_t> timestamps; | ||
| 27 | + | ||
| 28 | + // used only by greedy_search | ||
| 29 | + std::vector<float> previous_decoder_out; | ||
| 30 | + | ||
| 31 | + // used only in modified beam_search | ||
| 32 | + Hypotheses hyps; | ||
| 33 | + | ||
| 34 | + // used only by modified_beam_search | ||
| 35 | + std::vector<std::vector<float>> previous_decoder_out2; | ||
| 36 | +}; | ||
| 37 | + | ||
| 38 | +class OnlineTransducerDecoderRknn { | ||
| 39 | + public: | ||
| 40 | + virtual ~OnlineTransducerDecoderRknn() = default; | ||
| 41 | + | ||
| 42 | + /* Return an empty result. | ||
| 43 | + * | ||
| 44 | + * To simplify the decoding code, we add `context_size` blanks | ||
| 45 | + * to the beginning of the decoding result, which will be | ||
| 46 | + * stripped by calling `StripPrecedingBlanks()`. | ||
| 47 | + */ | ||
| 48 | + virtual OnlineTransducerDecoderResultRknn GetEmptyResult() const = 0; | ||
| 49 | + | ||
| 50 | + /** Strip blanks added by `GetEmptyResult()`. | ||
| 51 | + * | ||
| 52 | + * @param r It is changed in-place. | ||
| 53 | + */ | ||
| 54 | + virtual void StripLeadingBlanks( | ||
| 55 | + OnlineTransducerDecoderResultRknn * /*r*/) const {} | ||
| 56 | + | ||
| 57 | + virtual void Decode(std::vector<float> encoder_out, | ||
| 58 | + OnlineTransducerDecoderResultRknn *result) const = 0; | ||
| 59 | +}; | ||
| 60 | + | ||
| 61 | +} // namespace sherpa_onnx | ||
| 62 | + | ||
| 63 | +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_DECODER_RKNN_H_ |
| @@ -7,39 +7,26 @@ | @@ -7,39 +7,26 @@ | ||
| 7 | 7 | ||
| 8 | #include <vector> | 8 | #include <vector> |
| 9 | 9 | ||
| 10 | +#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h" | ||
| 11 | +#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h" | ||
| 10 | #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" | 12 | #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" |
| 11 | 13 | ||
| 12 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 13 | 15 | ||
| 14 | -struct OnlineTransducerDecoderResultRknn { | ||
| 15 | - /// Number of frames after subsampling we have decoded so far | ||
| 16 | - int32_t frame_offset = 0; | ||
| 17 | - | ||
| 18 | - /// The decoded token IDs so far | ||
| 19 | - std::vector<int64_t> tokens; | ||
| 20 | - | ||
| 21 | - /// number of trailing blank frames decoded so far | ||
| 22 | - int32_t num_trailing_blanks = 0; | ||
| 23 | - | ||
| 24 | - /// timestamps[i] contains the output frame index where tokens[i] is decoded. | ||
| 25 | - std::vector<int32_t> timestamps; | ||
| 26 | - | ||
| 27 | - std::vector<float> previous_decoder_out; | ||
| 28 | -}; | ||
| 29 | - | ||
| 30 | -class OnlineTransducerGreedySearchDecoderRknn { | 16 | +class OnlineTransducerGreedySearchDecoderRknn |
| 17 | + : public OnlineTransducerDecoderRknn { | ||
| 31 | public: | 18 | public: |
| 32 | explicit OnlineTransducerGreedySearchDecoderRknn( | 19 | explicit OnlineTransducerGreedySearchDecoderRknn( |
| 33 | OnlineZipformerTransducerModelRknn *model, int32_t unk_id = 2, | 20 | OnlineZipformerTransducerModelRknn *model, int32_t unk_id = 2, |
| 34 | float blank_penalty = 0.0) | 21 | float blank_penalty = 0.0) |
| 35 | : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} | 22 | : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} |
| 36 | 23 | ||
| 37 | - OnlineTransducerDecoderResultRknn GetEmptyResult() const; | 24 | + OnlineTransducerDecoderResultRknn GetEmptyResult() const override; |
| 38 | 25 | ||
| 39 | - void StripLeadingBlanks(OnlineTransducerDecoderResultRknn *r) const; | 26 | + void StripLeadingBlanks(OnlineTransducerDecoderResultRknn *r) const override; |
| 40 | 27 | ||
| 41 | void Decode(std::vector<float> encoder_out, | 28 | void Decode(std::vector<float> encoder_out, |
| 42 | - OnlineTransducerDecoderResultRknn *result) const; | 29 | + OnlineTransducerDecoderResultRknn *result) const override; |
| 43 | 30 | ||
| 44 | private: | 31 | private: |
| 45 | OnlineZipformerTransducerModelRknn *model_; // Not owned | 32 | OnlineZipformerTransducerModelRknn *model_; // Not owned |
| 1 | +// sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/hypothesis.h" | ||
| 12 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 13 | +#include "sherpa-onnx/csrc/math.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +OnlineTransducerDecoderResultRknn | ||
| 18 | +OnlineTransducerModifiedBeamSearchDecoderRknn::GetEmptyResult() const { | ||
| 19 | + int32_t context_size = model_->ContextSize(); | ||
| 20 | + int32_t blank_id = 0; // always 0 | ||
| 21 | + OnlineTransducerDecoderResultRknn r; | ||
| 22 | + | ||
| 23 | + std::vector<int64_t> blanks(context_size, -1); | ||
| 24 | + blanks.back() = blank_id; | ||
| 25 | + | ||
| 26 | + Hypotheses blank_hyp({{blanks, 0}}); | ||
| 27 | + r.hyps = std::move(blank_hyp); | ||
| 28 | + r.tokens = std::move(blanks); | ||
| 29 | + | ||
| 30 | + return r; | ||
| 31 | +} | ||
| 32 | + | ||
| 33 | +void OnlineTransducerModifiedBeamSearchDecoderRknn::StripLeadingBlanks( | ||
| 34 | + OnlineTransducerDecoderResultRknn *r) const { | ||
| 35 | + int32_t context_size = model_->ContextSize(); | ||
| 36 | + auto hyp = r->hyps.GetMostProbable(true); | ||
| 37 | + | ||
| 38 | + std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end()); | ||
| 39 | + r->tokens = std::move(tokens); | ||
| 40 | + r->timestamps = std::move(hyp.timestamps); | ||
| 41 | + | ||
| 42 | + r->num_trailing_blanks = hyp.num_trailing_blanks; | ||
| 43 | +} | ||
| 44 | + | ||
| 45 | +static std::vector<std::vector<float>> GetDecoderOut( | ||
| 46 | + OnlineZipformerTransducerModelRknn *model, const Hypotheses &hyp_vec) { | ||
| 47 | + std::vector<std::vector<float>> ans; | ||
| 48 | + ans.reserve(hyp_vec.Size()); | ||
| 49 | + | ||
| 50 | + int32_t context_size = model->ContextSize(); | ||
| 51 | + for (const auto &p : hyp_vec) { | ||
| 52 | + const auto &hyp = p.second; | ||
| 53 | + auto start = hyp.ys.begin() + (hyp.ys.size() - context_size); | ||
| 54 | + auto end = hyp.ys.end(); | ||
| 55 | + auto tokens = std::vector<int64_t>(start, end); | ||
| 56 | + auto decoder_out = model->RunDecoder(std::move(tokens)); | ||
| 57 | + | ||
| 58 | + ans.push_back(std::move(decoder_out)); | ||
| 59 | + } | ||
| 60 | + | ||
| 61 | + return ans; | ||
| 62 | +} | ||
| 63 | + | ||
| 64 | +static std::vector<std::vector<float>> GetJoinerOutLogSoftmax( | ||
| 65 | + OnlineZipformerTransducerModelRknn *model, const float *p_encoder_out, | ||
| 66 | + const std::vector<std::vector<float>> &decoder_out) { | ||
| 67 | + std::vector<std::vector<float>> ans; | ||
| 68 | + ans.reserve(decoder_out.size()); | ||
| 69 | + | ||
| 70 | + for (const auto &d : decoder_out) { | ||
| 71 | + auto joiner_out = model->RunJoiner(p_encoder_out, d.data()); | ||
| 72 | + | ||
| 73 | + LogSoftmax(joiner_out.data(), joiner_out.size()); | ||
| 74 | + | ||
| 75 | + ans.push_back(std::move(joiner_out)); | ||
| 76 | + } | ||
| 77 | + return ans; | ||
| 78 | +} | ||
| 79 | + | ||
| 80 | +void OnlineTransducerModifiedBeamSearchDecoderRknn::Decode( | ||
| 81 | + std::vector<float> encoder_out, | ||
| 82 | + OnlineTransducerDecoderResultRknn *result) const { | ||
| 83 | + auto &r = result[0]; | ||
| 84 | + auto attr = model_->GetEncoderOutAttr(); | ||
| 85 | + int32_t num_frames = attr.dims[1]; | ||
| 86 | + int32_t encoder_out_dim = attr.dims[2]; | ||
| 87 | + | ||
| 88 | + int32_t vocab_size = model_->VocabSize(); | ||
| 89 | + int32_t context_size = model_->ContextSize(); | ||
| 90 | + | ||
| 91 | + Hypotheses cur = std::move(result->hyps); | ||
| 92 | + std::vector<Hypothesis> prev; | ||
| 93 | + | ||
| 94 | + auto decoder_out = std::move(result->previous_decoder_out2); | ||
| 95 | + if (decoder_out.empty()) { | ||
| 96 | + decoder_out = GetDecoderOut(model_, cur); | ||
| 97 | + } | ||
| 98 | + | ||
| 99 | + const float *p_encoder_out = encoder_out.data(); | ||
| 100 | + | ||
| 101 | + int32_t frame_offset = result->frame_offset; | ||
| 102 | + | ||
| 103 | + for (int32_t t = 0; t != num_frames; ++t) { | ||
| 104 | + prev = cur.Vec(); | ||
| 105 | + cur.Clear(); | ||
| 106 | + | ||
| 107 | + auto log_probs = GetJoinerOutLogSoftmax(model_, p_encoder_out, decoder_out); | ||
| 108 | + p_encoder_out += encoder_out_dim; | ||
| 109 | + | ||
| 110 | + for (int32_t i = 0; i != prev.size(); ++i) { | ||
| 111 | + auto log_prob = prev[i].log_prob; | ||
| 112 | + for (auto &p : log_probs[i]) { | ||
| 113 | + p += log_prob; | ||
| 114 | + } | ||
| 115 | + } | ||
| 116 | + | ||
| 117 | + auto topk = TopkIndex(log_probs, max_active_paths_); | ||
| 118 | + for (auto k : topk) { | ||
| 119 | + int32_t hyp_index = k / vocab_size; | ||
| 120 | + int32_t new_token = k % vocab_size; | ||
| 121 | + | ||
| 122 | + Hypothesis new_hyp = prev[hyp_index]; | ||
| 123 | + new_hyp.log_prob = log_probs[hyp_index][new_token]; | ||
| 124 | + | ||
| 125 | + // blank is hardcoded to 0 | ||
| 126 | + // also, it treats unk as blank | ||
| 127 | + if (new_token != 0 && new_token != unk_id_) { | ||
| 128 | + new_hyp.ys.push_back(new_token); | ||
| 129 | + new_hyp.timestamps.push_back(t + frame_offset); | ||
| 130 | + new_hyp.num_trailing_blanks = 0; | ||
| 131 | + | ||
| 132 | + } else { | ||
| 133 | + ++new_hyp.num_trailing_blanks; | ||
| 134 | + } | ||
| 135 | + cur.Add(std::move(new_hyp)); | ||
| 136 | + } | ||
| 137 | + | ||
| 138 | + decoder_out = GetDecoderOut(model_, cur); | ||
| 139 | + } | ||
| 140 | + | ||
| 141 | + result->hyps = std::move(cur); | ||
| 142 | + result->frame_offset += num_frames; | ||
| 143 | + result->previous_decoder_out2 = std::move(decoder_out); | ||
| 144 | +} | ||
| 145 | + | ||
| 146 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_RKNN_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_RKNN_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h" | ||
| 11 | +#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +class OnlineTransducerModifiedBeamSearchDecoderRknn | ||
| 16 | + : public OnlineTransducerDecoderRknn { | ||
| 17 | + public: | ||
| 18 | + explicit OnlineTransducerModifiedBeamSearchDecoderRknn( | ||
| 19 | + OnlineZipformerTransducerModelRknn *model, int32_t max_active_paths, | ||
| 20 | + int32_t unk_id = 2, float blank_penalty = 0.0) | ||
| 21 | + : model_(model), | ||
| 22 | + max_active_paths_(max_active_paths), | ||
| 23 | + unk_id_(unk_id), | ||
| 24 | + blank_penalty_(blank_penalty) {} | ||
| 25 | + | ||
| 26 | + OnlineTransducerDecoderResultRknn GetEmptyResult() const override; | ||
| 27 | + | ||
| 28 | + void StripLeadingBlanks(OnlineTransducerDecoderResultRknn *r) const override; | ||
| 29 | + | ||
| 30 | + void Decode(std::vector<float> encoder_out, | ||
| 31 | + OnlineTransducerDecoderResultRknn *result) const override; | ||
| 32 | + | ||
| 33 | + private: | ||
| 34 | + OnlineZipformerTransducerModelRknn *model_; // Not owned | ||
| 35 | + int32_t max_active_paths_; | ||
| 36 | + int32_t unk_id_; | ||
| 37 | + float blank_penalty_; | ||
| 38 | +}; | ||
| 39 | + | ||
| 40 | +} // namespace sherpa_onnx | ||
| 41 | + | ||
| 42 | +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_RKNN_H_ |
-
请 注册 或 登录 后发表评论