Showing 19 changed files with 610 additions and 83 deletions
sherpa-onnx/csrc/CMakeLists.txt
@@ -5,11 +5,13 @@ set(sources
   endpoint.cc
   features.cc
   file-utils.cc
+  hypothesis.cc
   online-lstm-transducer-model.cc
   online-recognizer.cc
   online-stream.cc
   online-transducer-greedy-search-decoder.cc
   online-transducer-model-config.cc
+  online-transducer-modified-beam-search-decoder.cc
   online-transducer-model.cc
   online-zipformer-transducer-model.cc
   onnx-utils.cc
sherpa-onnx/csrc/hypothesis.cc
0 → 100644
+/**
+ * Copyright (c) 2023 Xiaomi Corporation
+ *
+ */
+
+#include "sherpa-onnx/csrc/hypothesis.h"
+
+#include <algorithm>
+#include <utility>
+
+namespace sherpa_onnx {
+
+void Hypotheses::Add(Hypothesis hyp) {
+  auto key = hyp.Key();
+  auto it = hyps_dict_.find(key);
+  if (it == hyps_dict_.end()) {
+    hyps_dict_[key] = std::move(hyp);
+  } else {
+    it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
+  }
+}
+
+Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
+  if (!length_norm) {
+    return std::max_element(hyps_dict_.begin(), hyps_dict_.end(),
+                            [](const auto &left, const auto &right) -> bool {
+                              return left.second.log_prob <
+                                     right.second.log_prob;
+                            })
+        ->second;
+  } else {
+    // length_norm is true
+    return std::max_element(
+               hyps_dict_.begin(), hyps_dict_.end(),
+               [](const auto &left, const auto &right) -> bool {
+                 return left.second.log_prob / left.second.ys.size() <
+                        right.second.log_prob / right.second.ys.size();
+               })
+        ->second;
+  }
+}
+
+std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm) const {
+  k = std::max(k, 1);
+  k = std::min(k, Size());
+
+  std::vector<Hypothesis> all_hyps = Vec();
+
+  if (!length_norm) {
+    std::partial_sort(
+        all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
+        [](const auto &a, const auto &b) { return a.log_prob > b.log_prob; });
+  } else {
+    // length_norm is true
+    std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
+                      [](const auto &a, const auto &b) {
+                        return a.log_prob / a.ys.size() >
+                               b.log_prob / b.ys.size();
+                      });
+  }
+
+  return {all_hyps.begin(), all_hyps.begin() + k};
+}
+
+}  // namespace sherpa_onnx
sherpa-onnx/csrc/hypothesis.h
0 → 100644
+/**
+ * Copyright (c) 2023 Xiaomi Corporation
+ *
+ */
+
+#ifndef SHERPA_ONNX_CSRC_HYPOTHESIS_H_
+#define SHERPA_ONNX_CSRC_HYPOTHESIS_H_
+
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/math.h"
+
+namespace sherpa_onnx {
+
+struct Hypothesis {
+  // The predicted tokens so far. Newly predicted tokens are appended.
+  std::vector<int32_t> ys;
+
+  // timestamps[i] contains the frame number after subsampling
+  // on which ys[i] is decoded.
+  std::vector<int32_t> timestamps;
+
+  // The total score of ys in log space.
+  double log_prob = 0;
+
+  int32_t num_trailing_blanks = 0;
+
+  Hypothesis() = default;
+  Hypothesis(const std::vector<int32_t> &ys, double log_prob)
+      : ys(ys), log_prob(log_prob) {}
+
+  // If two Hypotheses have the same `Key`, then they contain
+  // the same token sequence.
+  std::string Key() const {
+    // TODO(fangjun): Use a hash function?
+    std::ostringstream os;
+    const std::string sep = "-";
+    for (auto i : ys) {
+      os << i << sep;
+    }
+    return os.str();
+  }
+
+  // For debugging
+  std::string ToString() const {
+    std::ostringstream os;
+    os << "(" << Key() << ", " << log_prob << ")";
+    return os.str();
+  }
+};
+
+class Hypotheses {
+ public:
+  Hypotheses() = default;
+
+  explicit Hypotheses(std::vector<Hypothesis> hyps) {
+    for (auto &h : hyps) {
+      hyps_dict_[h.Key()] = std::move(h);
+    }
+  }
+
+  explicit Hypotheses(std::unordered_map<std::string, Hypothesis> hyps_dict)
+      : hyps_dict_(std::move(hyps_dict)) {}
+
+  // Add hyp to this object. If it already exists, its log_prob
+  // is updated with the given hyp using log-sum-exp.
+  void Add(Hypothesis hyp);
+
+  // Get the hyp that has the largest log_prob.
+  // If length_norm is true, hyp's log_prob is divided by
+  // len(hyp.ys) before comparison.
+  Hypothesis GetMostProbable(bool length_norm) const;
+
+  // Get the k hyps that have the largest log_prob.
+  // If length_norm is true, hyp's log_prob is divided by
+  // len(hyp.ys) before comparison.
+  std::vector<Hypothesis> GetTopK(int32_t k, bool length_norm) const;
+
+  int32_t Size() const { return static_cast<int32_t>(hyps_dict_.size()); }
+
+  std::string ToString() const {
+    std::ostringstream os;
+    for (const auto &p : hyps_dict_) {
+      os << p.second.ToString() << "\n";
+    }
+    return os.str();
+  }
+
+  auto begin() const { return hyps_dict_.begin(); }
+  auto end() const { return hyps_dict_.end(); }
+
+  void Clear() { hyps_dict_.clear(); }
+
+ private:
+  // Return a list of hyps contained in this object.
+  std::vector<Hypothesis> Vec() const {
+    std::vector<Hypothesis> ans;
+    ans.reserve(hyps_dict_.size());
+    for (const auto &p : hyps_dict_) {
+      ans.push_back(p.second);
+    }
+    return ans;
+  }
+
+  using Map = std::unordered_map<std::string, Hypothesis>;
+  Map hyps_dict_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_HYPOTHESIS_H_
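
A quick illustration of how the container behaves (a sketch, not part of the patch; the numbers are made up for a context size of 2). Add() merges hypotheses whose token sequences are identical by log-sum-exp, so a merged path can outscore a longer path in total probability, and length normalization can change which path wins:

// Illustrative only; assumes the hypothesis.h/math.h above are available.
#include <cstdio>

#include "sherpa-onnx/csrc/hypothesis.h"

int main() {
  using namespace sherpa_onnx;

  Hypotheses hyps;
  hyps.Add(Hypothesis({0, 0, 5}, -2.0));
  hyps.Add(Hypothesis({0, 0, 5}, -2.5));  // same key "0-0-5-": merged to
                                          // -2 + log1p(exp(-0.5)) ~ -1.526
  hyps.Add(Hypothesis({0, 0}, -1.3));

  // Without length norm: -1.3 > -1.526, so "0-0-" wins.
  printf("%s\n", hyps.GetMostProbable(false).ToString().c_str());

  // With length norm: -1.526/3 ~ -0.51 beats -1.3/2 = -0.65, so "0-0-5-" wins.
  printf("%s\n", hyps.GetMostProbable(true).ToString().c_str());
  return 0;
}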
sherpa-onnx/csrc/math.h
0 → 100644
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Daniel Povey)
+ * Copyright (c) 2023 (Pingfeng Luo)
+ *
+ */
+// This file is copied from k2/csrc/utils.h
+#ifndef SHERPA_ONNX_CSRC_MATH_H_
+#define SHERPA_ONNX_CSRC_MATH_H_
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstdint>
+#include <numeric>
+#include <vector>
+
+namespace sherpa_onnx {
+
+// logf(FLT_EPSILON)
+#define SHERPA_ONNX_MIN_LOG_DIFF_FLOAT -15.9423847198486328125f
+
+// log(DBL_EPSILON)
+#define SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE \
+  -36.0436533891171535515240975655615329742431640625
+
+template <typename T>
+struct LogAdd;
+
+template <>
+struct LogAdd<double> {
+  double operator()(double x, double y) const {
+    double diff;
+
+    if (x < y) {
+      diff = x - y;
+      x = y;
+    } else {
+      diff = y - x;
+    }
+    // diff is negative. x is now the larger one.
+
+    if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE) {
+      return x + log1p(exp(diff));
+    }
+
+    return x;  // return the larger one.
+  }
+};
+
+template <>
+struct LogAdd<float> {
+  float operator()(float x, float y) const {
+    float diff;
+
+    if (x < y) {
+      diff = x - y;
+      x = y;
+    } else {
+      diff = y - x;
+    }
+    // diff is negative. x is now the larger one.
+
+    if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_FLOAT) {
+      return x + log1pf(expf(diff));
+    }
+
+    return x;  // return the larger one.
+  }
+};
+
+template <class T>
+void LogSoftmax(T *input, int32_t input_len) {
+  assert(input);
+
+  T m = *std::max_element(input, input + input_len);
+
+  T sum = 0.0;
+  for (int32_t i = 0; i < input_len; i++) {
+    sum += exp(input[i] - m);
+  }
+
+  T offset = m + log(sum);
+  for (int32_t i = 0; i < input_len; i++) {
+    input[i] -= offset;
+  }
+}
+
+template <class T>
+std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
+  std::vector<int32_t> vec_index(size);
+  std::iota(vec_index.begin(), vec_index.end(), 0);
+
+  std::sort(vec_index.begin(), vec_index.end(),
+            [vec](int32_t index_1, int32_t index_2) {
+              return vec[index_1] > vec[index_2];
+            });
+
+  int32_t k_num = std::min<int32_t>(size, topk);
+  return {vec_index.begin(), vec_index.begin() + k_num};
+}
+
+}  // namespace sherpa_onnx
+#endif  // SHERPA_ONNX_CSRC_MATH_H_
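
The helpers are small enough to check by hand (a sketch, not part of the patch). LogAdd computes log(e^x + e^y) stably as max(x, y) + log1p(exp(-|x - y|)), falling back to the max when the gap is below log(epsilon); LogSoftmax subtracts max + log(sum(exp)) in place; TopkIndex returns the indices of the k largest entries:

// Illustrative only; assumes the math.h above is available.
#include <cstdio>

#include "sherpa-onnx/csrc/math.h"

int main() {
  using namespace sherpa_onnx;

  // log(exp(-1) + exp(-2)) = -1 + log1p(exp(-1)) ~ -0.6867
  printf("%.4f\n", LogAdd<double>()(-1.0, -2.0));

  // In-place log-softmax: {1, 2, 3} -> {-2.4076, -1.4076, -0.4076}
  float logits[3] = {1.0f, 2.0f, 3.0f};
  LogSoftmax(logits, 3);

  // Indices of the two largest entries: 2, then 1.
  std::vector<int32_t> top2 = TopkIndex(logits, 3, 2);
  printf("%d %d\n", top2[0], top2[1]);
  return 0;
}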
sherpa-onnx/csrc/online-lstm-transducer-model.cc
@@ -247,24 +247,6 @@ OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
   return {std::move(encoder_out[0]), std::move(next_states)};
 }
 
-Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
-    const std::vector<OnlineTransducerDecoderResult> &results) {
-  int32_t batch_size = static_cast<int32_t>(results.size());
-  std::array<int64_t, 2> shape{batch_size, context_size_};
-  Ort::Value decoder_input =
-      Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
-  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
-
-  for (const auto &r : results) {
-    const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
-    const int64_t *end = r.tokens.data() + r.tokens.size();
-    std::copy(begin, end, p);
-    p += context_size_;
-  }
-
-  return decoder_input;
-}
-
 Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
   auto decoder_out = decoder_sess_->Run(
       {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
sherpa-onnx/csrc/online-lstm-transducer-model.h
@@ -40,9 +40,6 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
   std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
       Ort::Value features, std::vector<Ort::Value> states) override;
 
-  Ort::Value BuildDecoderInput(
-      const std::vector<OnlineTransducerDecoderResult> &results) override;
-
   Ort::Value RunDecoder(Ort::Value decoder_input) override;
 
   Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
sherpa-onnx/csrc/online-recognizer.cc
@@ -1,6 +1,7 @@
 // sherpa-onnx/csrc/online-recognizer.cc
 //
 // Copyright (c) 2023 Xiaomi Corporation
+// Copyright (c) 2023 Pingfeng Luo
 
 #include "sherpa-onnx/csrc/online-recognizer.h"
 
@@ -16,6 +17,7 @@
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
 #include "sherpa-onnx/csrc/online-transducer-model.h"
+#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
 #include "sherpa-onnx/csrc/symbol-table.h"
 
 namespace sherpa_onnx {
@@ -39,6 +41,11 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
 
   po->Register("enable-endpoint", &enable_endpoint,
                "True to enable endpoint detection. False to disable it.");
+  po->Register("max-active-paths", &max_active_paths,
+               "Beam size used in modified beam search.");
+  po->Register("decoding-method", &decoding_method,
+               "Decoding method. Supported values are: greedy_search, "
+               "modified_beam_search.");
 }
 
 bool OnlineRecognizerConfig::Validate() const {
@@ -52,7 +59,9 @@ std::string OnlineRecognizerConfig::ToString() const {
   os << "feat_config=" << feat_config.ToString() << ", ";
   os << "model_config=" << model_config.ToString() << ", ";
   os << "endpoint_config=" << endpoint_config.ToString() << ", ";
-  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")";
+  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
+  os << "max_active_paths=" << max_active_paths << ", ";
+  os << "decoding_method=\"" << decoding_method << "\")";
 
   return os.str();
 }
@@ -64,8 +73,17 @@ class OnlineRecognizer::Impl {
         model_(OnlineTransducerModel::Create(config.model_config)),
         sym_(config.model_config.tokens),
         endpoint_(config_.endpoint_config) {
+    if (config.decoding_method == "modified_beam_search") {
+      decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
+          model_.get(), config_.max_active_paths);
+    } else if (config.decoding_method == "greedy_search") {
       decoder_ =
           std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
+    } else {
+      fprintf(stderr, "Unsupported decoding method: %s\n",
+              config.decoding_method.c_str());
+      exit(-1);
+    }
   }
 
 #if __ANDROID_API__ >= 9
@@ -74,8 +92,17 @@ class OnlineRecognizer::Impl {
         model_(OnlineTransducerModel::Create(mgr, config.model_config)),
         sym_(mgr, config.model_config.tokens),
         endpoint_(config_.endpoint_config) {
+    if (config.decoding_method == "modified_beam_search") {
+      decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
+          model_.get(), config_.max_active_paths);
+    } else if (config.decoding_method == "greedy_search") {
       decoder_ =
           std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
+    } else {
+      fprintf(stderr, "Unsupported decoding method: %s\n",
+              config.decoding_method.c_str());
+      exit(-1);
+    }
   }
 #endif
 
sherpa-onnx/csrc/online-recognizer.h
@@ -32,7 +32,11 @@ struct OnlineRecognizerConfig {
   FeatureExtractorConfig feat_config;
   OnlineTransducerModelConfig model_config;
   EndpointConfig endpoint_config;
-  bool enable_endpoint;
+  bool enable_endpoint = true;
+  int32_t max_active_paths = 4;
+
+  // Supported values: greedy_search, modified_beam_search
+  std::string decoding_method = "modified_beam_search";
 
   OnlineRecognizerConfig() = default;
 
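
With the two new fields, choosing the decoder is a one-line config change. A minimal sketch (not part of the patch; the OnlineRecognizer constructor signature is assumed from the Impl shown above, and the token path is a hypothetical placeholder):

// Sketch only. Equivalent command-line flags:
//   --decoding-method=modified_beam_search --max-active-paths=4
#include "sherpa-onnx/csrc/online-recognizer.h"

int main() {
  sherpa_onnx::OnlineRecognizerConfig config;
  config.model_config.tokens = "tokens.txt";  // placeholder path
  // ... encoder/decoder/joiner model paths omitted ...
  config.decoding_method = "modified_beam_search";  // or "greedy_search"
  config.max_active_paths = 4;                      // beam size per stream
  sherpa_onnx::OnlineRecognizer recognizer(config);
  return 0;
}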
sherpa-onnx/csrc/online-transducer-decoder.h
@@ -8,6 +8,7 @@
 #include <vector>
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/hypothesis.h"
 
 namespace sherpa_onnx {
 
@@ -17,6 +18,9 @@ struct OnlineTransducerDecoderResult {
 
   /// number of trailing blank frames decoded so far
   int32_t num_trailing_blanks = 0;
+
+  // Used only in modified_beam_search
+  Hypotheses hyps;
 };
 
 class OnlineTransducerDecoder {
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
@@ -4,8 +4,6 @@
 
 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
 
-#include <assert.h>
-
 #include <algorithm>
 #include <utility>
 #include <vector>
@@ -15,39 +13,6 @@
 
 namespace sherpa_onnx {
 
-static Ort::Value GetFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
-                           int32_t t) {
-  std::vector<int64_t> encoder_out_shape =
-      encoder_out->GetTensorTypeAndShapeInfo().GetShape();
-
-  auto batch_size = encoder_out_shape[0];
-  auto num_frames = encoder_out_shape[1];
-  assert(t < num_frames);
-
-  auto encoder_out_dim = encoder_out_shape[2];
-
-  auto offset = num_frames * encoder_out_dim;
-
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-
-  std::array<int64_t, 2> shape{batch_size, encoder_out_dim};
-
-  Ort::Value ans =
-      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
-
-  float *dst = ans.GetTensorMutableData<float>();
-  const float *src = encoder_out->GetTensorData<float>();
-
-  for (int32_t i = 0; i != batch_size; ++i) {
-    std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst);
-    src += offset;
-    dst += encoder_out_dim;
-  }
-
-  return ans;
-}
-
 OnlineTransducerDecoderResult
 OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
   int32_t context_size = model_->ContextSize();
@@ -90,7 +55,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
   Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
 
   for (int32_t t = 0; t != num_frames; ++t) {
-    Ort::Value cur_encoder_out = GetFrame(model_->Allocator(), &encoder_out, t);
+    Ort::Value cur_encoder_out =
+        GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
     Ort::Value logit = model_->RunJoiner(
         std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
 
sherpa-onnx/csrc/online-transducer-model.cc
@@ -1,6 +1,7 @@
 // sherpa-onnx/csrc/online-transducer-model.cc
 //
 // Copyright (c) 2023 Xiaomi Corporation
+// Copyright (c) 2023 Pingfeng Luo
 #include "sherpa-onnx/csrc/online-transducer-model.h"
 
 #if __ANDROID_API__ >= 9
@@ -8,6 +9,7 @@
 #include "android/asset_manager_jni.h"
 #endif
 
+#include <algorithm>
 #include <memory>
 #include <sstream>
 #include <string>
@@ -75,6 +77,40 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
   return nullptr;
 }
 
+Ort::Value OnlineTransducerModel::BuildDecoderInput(
+    const std::vector<OnlineTransducerDecoderResult> &results) {
+  int32_t batch_size = static_cast<int32_t>(results.size());
+  int32_t context_size = ContextSize();
+  std::array<int64_t, 2> shape{batch_size, context_size};
+  Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
+      Allocator(), shape.data(), shape.size());
+  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
+
+  for (const auto &r : results) {
+    const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size;
+    const int64_t *end = r.tokens.data() + r.tokens.size();
+    std::copy(begin, end, p);
+    p += context_size;
+  }
+  return decoder_input;
+}
+
+Ort::Value OnlineTransducerModel::BuildDecoderInput(
+    const std::vector<Hypothesis> &hyps) {
+  int32_t batch_size = static_cast<int32_t>(hyps.size());
+  int32_t context_size = ContextSize();
+  std::array<int64_t, 2> shape{batch_size, context_size};
+  Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
+      Allocator(), shape.data(), shape.size());
+  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
+
+  for (const auto &h : hyps) {
+    std::copy(h.ys.end() - context_size, h.ys.end(), p);
+    p += context_size;
+  }
+  return decoder_input;
+}
+
 #if __ANDROID_API__ >= 9
 std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config) {
sherpa-onnx/csrc/online-transducer-model.h
@@ -14,6 +14,8 @@
 #endif
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/hypothesis.h"
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
 
 namespace sherpa_onnx {
@@ -71,9 +73,6 @@ class OnlineTransducerModel {
       Ort::Value features,
       std::vector<Ort::Value> states) = 0;  // NOLINT
 
-  virtual Ort::Value BuildDecoderInput(
-      const std::vector<OnlineTransducerDecoderResult> &results) = 0;
-
   /** Run the decoder network.
    *
    * Caution: We assume there are no recurrent connections in the decoder and
@@ -125,7 +124,13 @@ class OnlineTransducerModel {
   virtual int32_t VocabSize() const = 0;
 
   virtual int32_t SubsamplingFactor() const { return 4; }
+
   virtual OrtAllocator *Allocator() = 0;
+
+  Ort::Value BuildDecoderInput(
+      const std::vector<OnlineTransducerDecoderResult> &results);
+
+  Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &hyps);
 };
 
 }  // namespace sherpa_onnx
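
Both overloads of BuildDecoderInput produce the same layout: a (batch, context_size) int64 tensor holding the last ContextSize() tokens of each path, which is exactly what the stateless transducer decoder consumes. A plain-array sketch of that packing (illustrative only, with context_size = 2):

// Illustrative only; mirrors BuildDecoderInput without onnxruntime.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int32_t context_size = 2;
  std::vector<std::vector<int64_t>> tokens = {
      {0, 0, 17, 23},  // stream 0: leading blanks plus decoded tokens
      {0, 0, 5},       // stream 1
  };

  // The flattened (batch_size, context_size) decoder input.
  std::vector<int64_t> decoder_input;
  for (const auto &t : tokens) {
    decoder_input.insert(decoder_input.end(), t.end() - context_size, t.end());
  }

  // Prints "17 23 0 5": row b is the 2-token context of stream b.
  for (int64_t v : decoder_input) printf("%lld ", static_cast<long long>(v));
  printf("\n");
  return 0;
}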
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
0 → 100644
+// sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
+//
+// Copyright (c) 2023 Pingfeng Luo
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
+
+#include <algorithm>
+#include <array>
+#include <cstdio>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
+                         const std::vector<int32_t> &hyps_num_split) {
+  std::vector<int64_t> cur_encoder_out_shape =
+      cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  std::array<int64_t, 2> ans_shape{hyps_num_split.back(),
+                                   cur_encoder_out_shape[1]};
+
+  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
+                                                   ans_shape.size());
+
+  const float *src = cur_encoder_out->GetTensorData<float>();
+  float *dst = ans.GetTensorMutableData<float>();
+  int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1;
+  for (int32_t b = 0; b != batch_size; ++b) {
+    int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b];
+    for (int32_t i = 0; i != cur_stream_hyps_num; ++i) {
+      std::copy(src, src + cur_encoder_out_shape[1], dst);
+      dst += cur_encoder_out_shape[1];
+    }
+    src += cur_encoder_out_shape[1];
+  }
+  return ans;
+}
+
+static void LogSoftmax(float *in, int32_t w, int32_t h) {
+  for (int32_t i = 0; i != h; ++i) {
+    LogSoftmax(in, w);
+    in += w;
+  }
+}
+
+OnlineTransducerDecoderResult
+OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
+  int32_t context_size = model_->ContextSize();
+  int32_t blank_id = 0;  // always 0
+  OnlineTransducerDecoderResult r;
+  std::vector<int32_t> blanks(context_size, blank_id);
+  Hypotheses blank_hyp({{blanks, 0}});
+  r.hyps = std::move(blank_hyp);
+  return r;
+}
+
+void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
+    OnlineTransducerDecoderResult *r) const {
+  int32_t context_size = model_->ContextSize();
+  auto hyp = r->hyps.GetMostProbable(true);
+
+  std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
+  r->tokens = std::move(tokens);
+  r->num_trailing_blanks = hyp.num_trailing_blanks;
+}
+
+void OnlineTransducerModifiedBeamSearchDecoder::Decode(
+    Ort::Value encoder_out,
+    std::vector<OnlineTransducerDecoderResult> *result) {
+  std::vector<int64_t> encoder_out_shape =
+      encoder_out.GetTensorTypeAndShapeInfo().GetShape();
+
+  if (encoder_out_shape[0] != static_cast<int64_t>(result->size())) {
+    fprintf(stderr,
+            "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
+            static_cast<int32_t>(encoder_out_shape[0]),
+            static_cast<int32_t>(result->size()));
+    exit(-1);
+  }
+
+  int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
+  int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
+  int32_t vocab_size = model_->VocabSize();
+
+  std::vector<Hypotheses> cur;
+  for (auto &r : *result) {
+    cur.push_back(std::move(r.hyps));
+  }
+  std::vector<Hypothesis> prev;
+
+  for (int32_t t = 0; t != num_frames; ++t) {
+    // Due to merging paths with identical token sequences,
+    // not all streams have "max_active_paths" paths.
+    int32_t hyps_num_acc = 0;
+    std::vector<int32_t> hyps_num_split;
+    hyps_num_split.push_back(0);
+
+    prev.clear();
+    for (auto &hyps : cur) {
+      for (auto &h : hyps) {
+        prev.push_back(std::move(h.second));
+        hyps_num_acc++;
+      }
+      hyps_num_split.push_back(hyps_num_acc);
+    }
+    cur.clear();
+    cur.reserve(batch_size);
+
+    Ort::Value decoder_input = model_->BuildDecoderInput(prev);
+    Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
+
+    Ort::Value cur_encoder_out =
+        GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
+    cur_encoder_out =
+        Repeat(model_->Allocator(), &cur_encoder_out, hyps_num_split);
+    Ort::Value logit = model_->RunJoiner(
+        std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
+    float *p_logit = logit.GetTensorMutableData<float>();
+
+    for (int32_t b = 0; b < batch_size; ++b) {
+      int32_t start = hyps_num_split[b];
+      int32_t end = hyps_num_split[b + 1];
+      LogSoftmax(p_logit, vocab_size, (end - start));
+      auto topk =
+          TopkIndex(p_logit, vocab_size * (end - start), max_active_paths_);
+
+      Hypotheses hyps;
+      for (auto i : topk) {
+        int32_t hyp_index = i / vocab_size + start;
+        int32_t new_token = i % vocab_size;
+
+        Hypothesis new_hyp = prev[hyp_index];
+        if (new_token != 0) {
+          new_hyp.ys.push_back(new_token);
+          new_hyp.num_trailing_blanks = 0;
+        } else {
+          ++new_hyp.num_trailing_blanks;
+        }
+        new_hyp.log_prob += p_logit[i];
+        hyps.Add(std::move(new_hyp));
+      }
+      cur.push_back(std::move(hyps));
+      p_logit += vocab_size * (end - start);
+    }
+  }
+
+  for (int32_t b = 0; b != batch_size; ++b) {
+    (*result)[b].hyps = std::move(cur[b]);
+  }
+}
+
+}  // namespace sherpa_onnx
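
The indexing in Decode() is the subtle part: because Add() merges identical paths, each stream may keep fewer than max_active_paths hypotheses, so the ragged batch is flattened into one list and hyps_num_split records each stream's [start, end) range. Repeat then duplicates each stream's encoder frame once per surviving hypothesis, and a top-k index i decodes to hyp_index = i / vocab_size + start and new_token = i % vocab_size. A plain-array sketch of that bookkeeping (illustrative only):

// Illustrative only; mirrors the hyps_num_split/Repeat bookkeeping.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Stream 0 kept 3 hypotheses, stream 1 kept 2 (merging shrank its beam).
  std::vector<int32_t> hyps_per_stream = {3, 2};

  // Exclusive prefix sum, as built per frame in Decode(): {0, 3, 5}.
  std::vector<int32_t> hyps_num_split = {0};
  for (int32_t n : hyps_per_stream) {
    hyps_num_split.push_back(hyps_num_split.back() + n);
  }

  // Repeat(): stream b's encoder frame is copied (end - start) times so the
  // flattened hypothesis list and the joiner input rows line up one-to-one.
  std::vector<int32_t> encoder_row_of_hyp;
  for (size_t b = 0; b + 1 < hyps_num_split.size(); ++b) {
    for (int32_t i = hyps_num_split[b]; i < hyps_num_split[b + 1]; ++i) {
      encoder_row_of_hyp.push_back(static_cast<int32_t>(b));
    }
  }

  // Prints "0 0 0 1 1".
  for (int32_t b : encoder_row_of_hyp) printf("%d ", b);
  printf("\n");
  return 0;
}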
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
0 → 100644
+// sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
+//
+// Copyright (c) 2023 Pingfeng Luo
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
+#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
+
+#include <vector>
+
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+#include "sherpa-onnx/csrc/online-transducer-model.h"
+
+namespace sherpa_onnx {
+
+class OnlineTransducerModifiedBeamSearchDecoder
+    : public OnlineTransducerDecoder {
+ public:
+  OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
+                                            int32_t max_active_paths)
+      : model_(model), max_active_paths_(max_active_paths) {}
+
+  OnlineTransducerDecoderResult GetEmptyResult() const override;
+
+  void StripLeadingBlanks(OnlineTransducerDecoderResult *r) const override;
+
+  void Decode(Ort::Value encoder_out,
+              std::vector<OnlineTransducerDecoderResult> *result) override;
+
+ private:
+  OnlineTransducerModel *model_;  // Not owned
+  int32_t max_active_paths_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
sherpa-onnx/csrc/online-zipformer-transducer-model.cc
@@ -461,24 +461,6 @@ OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
   return {std::move(encoder_out[0]), std::move(next_states)};
 }
 
-Ort::Value OnlineZipformerTransducerModel::BuildDecoderInput(
-    const std::vector<OnlineTransducerDecoderResult> &results) {
-  int32_t batch_size = static_cast<int32_t>(results.size());
-  std::array<int64_t, 2> shape{batch_size, context_size_};
-  Ort::Value decoder_input =
-      Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
-  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
-
-  for (const auto &r : results) {
-    const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
-    const int64_t *end = r.tokens.data() + r.tokens.size();
-    std::copy(begin, end, p);
-    p += context_size_;
-  }
-
-  return decoder_input;
-}
-
 Ort::Value OnlineZipformerTransducerModel::RunDecoder(
     Ort::Value decoder_input) {
   auto decoder_out = decoder_sess_->Run(
sherpa-onnx/csrc/online-zipformer-transducer-model.h
@@ -41,9 +41,6 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
   std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
       Ort::Value features, std::vector<Ort::Value> states) override;
 
-  Ort::Value BuildDecoderInput(
-      const std::vector<OnlineTransducerDecoderResult> &results) override;
-
   Ort::Value RunDecoder(Ort::Value decoder_input) override;
 
   Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
sherpa-onnx/csrc/onnx-utils.cc
@@ -44,6 +44,38 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
   }
 }
 
+Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
+                              int32_t t) {
+  std::vector<int64_t> encoder_out_shape =
+      encoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  auto batch_size = encoder_out_shape[0];
+  auto num_frames = encoder_out_shape[1];
+  assert(t < num_frames);
+
+  auto encoder_out_dim = encoder_out_shape[2];
+
+  auto offset = num_frames * encoder_out_dim;
+
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  std::array<int64_t, 2> shape{batch_size, encoder_out_dim};
+
+  Ort::Value ans =
+      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
+
+  float *dst = ans.GetTensorMutableData<float>();
+  const float *src = encoder_out->GetTensorData<float>();
+
+  for (int32_t i = 0; i != batch_size; ++i) {
+    std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst);
+    src += offset;
+    dst += encoder_out_dim;
+  }
+  return ans;
+}
+
 void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
   Ort::AllocatorWithDefaultOptions allocator;
   std::vector<Ort::AllocatedStringPtr> v =
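
GetEncoderOutFrame slices frame t out of an (N, T, C) encoder output to get an (N, C) tensor, one encoder frame per stream. The copy pattern on plain arrays (illustrative only):

// Illustrative only; mirrors GetEncoderOutFrame's copy loop.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int N = 2, T = 3, C = 2;  // batch, frames, encoder dim
  std::vector<float> encoder_out(N * T * C);
  for (int i = 0; i != N * T * C; ++i) encoder_out[i] = static_cast<float>(i);

  const int t = 1;  // frame to extract
  std::vector<float> ans(N * C);
  const float *src = encoder_out.data();
  float *dst = ans.data();
  for (int i = 0; i != N; ++i) {
    std::copy(src + t * C, src + (t + 1) * C, dst);  // row t of stream i
    src += T * C;  // advance to the next stream
    dst += C;
  }

  // Prints "2 3 8 9": frame 1 of stream 0 and stream 1.
  for (float v : ans) printf("%g ", v);
  printf("\n");
  return 0;
}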
sherpa-onnx/csrc/onnx-utils.h
@@ -10,6 +10,7 @@
 #include <locale>
 #endif
 
+#include <cassert>
 #include <ostream>
 #include <string>
 #include <vector>
@@ -57,6 +58,17 @@ void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
 void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
                     std::vector<const char *> *output_names_ptr);
 
+/** Get frame t of the encoder output.
+ *
+ * @param allocator  Allocator from onnxruntime.
+ * @param encoder_out  The encoder output tensor; its shape is (N, T, C).
+ * @param t  The frame index into the T axis.
+ *
+ * @return Return a tensor of shape (N, C) containing frame t of each stream.
+ */
+Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
+                              int32_t t);
+
 void PrintModelMetadata(std::ostream &os,
                         const Ort::ModelMetadata &meta_data);  // NOLINT
 