PF Luo
Committed by GitHub

add modified beam search (#69)

@@ -34,3 +34,4 @@ decode-file
 tokens.txt
 *.onnx
 log.txt
+tags
@@ -5,11 +5,13 @@ set(sources
   endpoint.cc
   features.cc
   file-utils.cc
+  hypothesis.cc
   online-lstm-transducer-model.cc
   online-recognizer.cc
   online-stream.cc
   online-transducer-greedy-search-decoder.cc
   online-transducer-model-config.cc
+  online-transducer-modified-beam-search-decoder.cc
   online-transducer-model.cc
   online-zipformer-transducer-model.cc
   onnx-utils.cc
+/**
+ * Copyright (c) 2023 Xiaomi Corporation
+ *
+ */
+
+#include "sherpa-onnx/csrc/hypothesis.h"
+
+#include <algorithm>
+#include <utility>
+
+namespace sherpa_onnx {
+
+void Hypotheses::Add(Hypothesis hyp) {
+  auto key = hyp.Key();
+  auto it = hyps_dict_.find(key);
+  if (it == hyps_dict_.end()) {
+    hyps_dict_[key] = std::move(hyp);
+  } else {
+    it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
+  }
+}
+
+Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
+  if (!length_norm) {
+    return std::max_element(hyps_dict_.begin(), hyps_dict_.end(),
+                            [](const auto &left, const auto &right) -> bool {
+                              return left.second.log_prob <
+                                     right.second.log_prob;
+                            })
+        ->second;
+  } else {
+    // length_norm is true
+    return std::max_element(
+               hyps_dict_.begin(), hyps_dict_.end(),
+               [](const auto &left, const auto &right) -> bool {
+                 return left.second.log_prob / left.second.ys.size() <
+                        right.second.log_prob / right.second.ys.size();
+               })
+        ->second;
+  }
+}
+
+std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm) const {
+  k = std::max(k, 1);
+  k = std::min(k, Size());
+
+  std::vector<Hypothesis> all_hyps = Vec();
+
+  if (!length_norm) {
+    std::partial_sort(
+        all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
+        [](const auto &a, const auto &b) { return a.log_prob > b.log_prob; });
+  } else {
+    // length_norm is true
+    std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
+                      [](const auto &a, const auto &b) {
+                        return a.log_prob / a.ys.size() >
+                               b.log_prob / b.ys.size();
+                      });
+  }
+
+  return {all_hyps.begin(), all_hyps.begin() + k};
+}
+
+}  // namespace sherpa_onnx
+/**
+ * Copyright (c) 2023 Xiaomi Corporation
+ *
+ */
+
+#ifndef SHERPA_ONNX_CSRC_HYPOTHESIS_H_
+#define SHERPA_ONNX_CSRC_HYPOTHESIS_H_
+
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/math.h"
+
+namespace sherpa_onnx {
+
+struct Hypothesis {
+  // The predicted tokens so far. Newly predicted tokens are appended.
+  std::vector<int32_t> ys;
+
+  // timestamps[i] contains the frame number after subsampling
+  // on which ys[i] is decoded.
+  std::vector<int32_t> timestamps;
+
+  // The total score of ys in log space.
+  double log_prob = 0;
+
+  int32_t num_trailing_blanks = 0;
+
+  Hypothesis() = default;
+  Hypothesis(const std::vector<int32_t> &ys, double log_prob)
+      : ys(ys), log_prob(log_prob) {}
+
+  // If two Hypotheses have the same `Key`, then they contain
+  // the same token sequence.
+  std::string Key() const {
+    // TODO(fangjun): Use a hash function?
+    std::ostringstream os;
+    for (auto i : ys) {
+      os << i << "-";
+    }
+    return os.str();
+  }
+
+  // For debugging
+  std::string ToString() const {
+    std::ostringstream os;
+    os << "(" << Key() << ", " << log_prob << ")";
+    return os.str();
+  }
+};
+
+class Hypotheses {
+ public:
+  Hypotheses() = default;
+
+  explicit Hypotheses(std::vector<Hypothesis> hyps) {
+    for (auto &h : hyps) {
+      hyps_dict_[h.Key()] = std::move(h);
+    }
+  }
+
+  explicit Hypotheses(std::unordered_map<std::string, Hypothesis> hyps_dict)
+      : hyps_dict_(std::move(hyps_dict)) {}
+
+  // Add hyp to this object. If it already exists, its log_prob
+  // is updated with the given hyp using log-sum-exp.
+  void Add(Hypothesis hyp);
+
+  // Get the hyp that has the largest log_prob.
+  // If length_norm is true, hyp's log_prob is divided by
+  // len(hyp.ys) before comparison.
+  Hypothesis GetMostProbable(bool length_norm) const;
+
+  // Get the k hyps that have the largest log_prob.
+  // If length_norm is true, hyp's log_prob is divided by
+  // len(hyp.ys) before comparison.
+  std::vector<Hypothesis> GetTopK(int32_t k, bool length_norm) const;
+
+  int32_t Size() const { return hyps_dict_.size(); }
+
+  std::string ToString() const {
+    std::ostringstream os;
+    for (const auto &p : hyps_dict_) {
+      os << p.second.ToString() << "\n";
+    }
+    return os.str();
+  }
+
+  auto begin() const { return hyps_dict_.begin(); }
+  auto end() const { return hyps_dict_.end(); }
+
+  void Clear() { hyps_dict_.clear(); }
+
+ private:
+  // Return a list of hyps contained in this object.
+  std::vector<Hypothesis> Vec() const {
+    std::vector<Hypothesis> ans;
+    ans.reserve(hyps_dict_.size());
+    for (const auto &p : hyps_dict_) {
+      ans.push_back(p.second);
+    }
+    return ans;
+  }
+
+  using Map = std::unordered_map<std::string, Hypothesis>;
+  Map hyps_dict_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_HYPOTHESIS_H_
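
Note: a minimal sketch of how the Hypotheses container behaves, assuming only the header above. Two paths with identical token sequences share the same Key(), so Add() merges them with log-sum-exp instead of keeping duplicates. The main() harness is hypothetical and not part of this commit.

// hypothesis_demo.cc (hypothetical harness, for illustration only)
#include <cstdio>

#include "sherpa-onnx/csrc/hypothesis.h"

int main() {
  using sherpa_onnx::Hypotheses;
  using sherpa_onnx::Hypothesis;

  Hypotheses hyps;
  hyps.Add(Hypothesis({0, 0, 5}, /*log_prob=*/-1.0));
  hyps.Add(Hypothesis({0, 0, 5}, /*log_prob=*/-2.0));  // same Key() as above
  hyps.Add(Hypothesis({0, 0, 7}, /*log_prob=*/-0.5));

  // Size() is 2: the first two entries were merged into one hypothesis
  // with log_prob = log(exp(-1) + exp(-2)) ~= -0.687.
  printf("%d hypotheses\n%s", hyps.Size(), hyps.ToString().c_str());

  Hypothesis best = hyps.GetMostProbable(/*length_norm=*/true);
  printf("best: %s\n", best.ToString().c_str());
  return 0;
}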
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Daniel Povey)
+ * Copyright (c) 2023 (Pingfeng Luo)
+ *
+ */
+// This file is copied from k2/csrc/utils.h
+#ifndef SHERPA_ONNX_CSRC_MATH_H_
+#define SHERPA_ONNX_CSRC_MATH_H_
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <numeric>
+#include <vector>
+
+namespace sherpa_onnx {
+
+// logf(FLT_EPSILON)
+#define SHERPA_ONNX_MIN_LOG_DIFF_FLOAT -15.9423847198486328125f
+
+// log(DBL_EPSILON)
+#define SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE \
+  -36.0436533891171535515240975655615329742431640625
+
+template <typename T>
+struct LogAdd;
+
+template <>
+struct LogAdd<double> {
+  double operator()(double x, double y) const {
+    double diff;
+
+    if (x < y) {
+      diff = x - y;
+      x = y;
+    } else {
+      diff = y - x;
+    }
+    // diff is negative. x is now the larger one.
+
+    if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE) {
+      return x + log1p(exp(diff));
+    }
+
+    return x;  // return the larger one.
+  }
+};
+
+template <>
+struct LogAdd<float> {
+  float operator()(float x, float y) const {
+    float diff;
+
+    if (x < y) {
+      diff = x - y;
+      x = y;
+    } else {
+      diff = y - x;
+    }
+    // diff is negative. x is now the larger one.
+
+    if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_FLOAT) {
+      return x + log1pf(expf(diff));
+    }
+
+    return x;  // return the larger one.
+  }
+};
+
+template <class T>
+void LogSoftmax(T *input, int32_t input_len) {
+  assert(input);
+
+  T m = *std::max_element(input, input + input_len);
+
+  T sum = 0.0;
+  for (int32_t i = 0; i < input_len; i++) {
+    sum += exp(input[i] - m);
+  }
+
+  T offset = m + log(sum);
+  for (int32_t i = 0; i < input_len; i++) {
+    input[i] -= offset;
+  }
+}
+
+template <class T>
+std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
+  std::vector<int32_t> vec_index(size);
+  std::iota(vec_index.begin(), vec_index.end(), 0);
+
+  std::sort(vec_index.begin(), vec_index.end(),
+            [vec](int32_t index_1, int32_t index_2) {
+              return vec[index_1] > vec[index_2];
+            });
+
+  int32_t k_num = std::min<int32_t>(size, topk);
+  return {vec_index.begin(), vec_index.begin() + k_num};
+}
+
+}  // namespace sherpa_onnx
+#endif  // SHERPA_ONNX_CSRC_MATH_H_
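
Note: LogAdd computes log(exp(x) + exp(y)) stably by factoring out the larger argument: for x >= y it returns x + log1p(exp(y - x)), and it skips the correction term entirely once the difference drops below log(epsilon), where it would vanish in rounding anyway. A small standalone check (hypothetical, not part of this commit):

// math_demo.cc (hypothetical harness, for illustration only)
#include <cstdio>

#include "sherpa-onnx/csrc/math.h"

int main() {
  // Computed naively, exp(-1000) underflows to 0 and the result would be
  // -inf; LogAdd returns -1000 + log1p(exp(-1)) ~= -999.687.
  printf("%f\n", sherpa_onnx::LogAdd<double>()(-1000.0, -1001.0));

  float logits[4] = {1.0f, 2.0f, 3.0f, 0.5f};
  sherpa_onnx::LogSoftmax(logits, 4);  // in place: logits[i] -= log-sum-exp

  // Indices of the 2 largest entries, in descending order: 2, then 1.
  auto top2 = sherpa_onnx::TopkIndex(logits, 4, 2);
  printf("%d %d\n", top2[0], top2[1]);
  return 0;
}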
@@ -247,24 +247,6 @@ OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
   return {std::move(encoder_out[0]), std::move(next_states)};
 }
 
-Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
-    const std::vector<OnlineTransducerDecoderResult> &results) {
-  int32_t batch_size = static_cast<int32_t>(results.size());
-  std::array<int64_t, 2> shape{batch_size, context_size_};
-  Ort::Value decoder_input =
-      Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
-  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
-
-  for (const auto &r : results) {
-    const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
-    const int64_t *end = r.tokens.data() + r.tokens.size();
-    std::copy(begin, end, p);
-    p += context_size_;
-  }
-
-  return decoder_input;
-}
-
 Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
   auto decoder_out = decoder_sess_->Run(
       {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
@@ -40,9 +40,6 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
   std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
       Ort::Value features, std::vector<Ort::Value> states) override;
 
-  Ort::Value BuildDecoderInput(
-      const std::vector<OnlineTransducerDecoderResult> &results) override;
-
   Ort::Value RunDecoder(Ort::Value decoder_input) override;
 
   Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
 // sherpa-onnx/csrc/online-recognizer.cc
 //
 // Copyright (c) 2023 Xiaomi Corporation
+// Copyright (c) 2023 Pingfeng Luo
 
 #include "sherpa-onnx/csrc/online-recognizer.h"
 
@@ -16,6 +17,7 @@
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
 #include "sherpa-onnx/csrc/online-transducer-model.h"
+#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
 #include "sherpa-onnx/csrc/symbol-table.h"
 
 namespace sherpa_onnx {
@@ -39,6 +41,11 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
 
   po->Register("enable-endpoint", &enable_endpoint,
                "True to enable endpoint detection. False to disable it.");
+  po->Register("max-active-paths", &max_active_paths,
+               "Beam size used in modified beam search.");
+  po->Register("decoding-method", &decoding_method,
+               "Decoding method. Supported values are: greedy_search, "
+               "modified_beam_search.");
 }
 
 bool OnlineRecognizerConfig::Validate() const {
@@ -52,7 +59,9 @@ std::string OnlineRecognizerConfig::ToString() const {
   os << "feat_config=" << feat_config.ToString() << ", ";
   os << "model_config=" << model_config.ToString() << ", ";
   os << "endpoint_config=" << endpoint_config.ToString() << ", ";
-  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")";
+  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
+  os << "max_active_paths=" << max_active_paths << ", ";
+  os << "decoding_method=\"" << decoding_method << "\")";
 
   return os.str();
 }
@@ -64,8 +73,17 @@ class OnlineRecognizer::Impl {
         model_(OnlineTransducerModel::Create(config.model_config)),
         sym_(config.model_config.tokens),
         endpoint_(config_.endpoint_config) {
+    if (config.decoding_method == "modified_beam_search") {
+      decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
+          model_.get(), config_.max_active_paths);
+    } else if (config.decoding_method == "greedy_search") {
       decoder_ =
           std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
+    } else {
+      fprintf(stderr, "Unsupported decoding method: %s\n",
+              config.decoding_method.c_str());
+      exit(-1);
+    }
   }
 
 #if __ANDROID_API__ >= 9
@@ -74,8 +92,17 @@ class OnlineRecognizer::Impl {
         model_(OnlineTransducerModel::Create(mgr, config.model_config)),
         sym_(mgr, config.model_config.tokens),
         endpoint_(config_.endpoint_config) {
+    if (config.decoding_method == "modified_beam_search") {
+      decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
+          model_.get(), config_.max_active_paths);
+    } else if (config.decoding_method == "greedy_search") {
      decoder_ =
          std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
+    } else {
+      fprintf(stderr, "Unsupported decoding method: %s\n",
+              config.decoding_method.c_str());
+      exit(-1);
+    }
   }
 #endif
 
@@ -32,7 +32,11 @@ struct OnlineRecognizerConfig {
   FeatureExtractorConfig feat_config;
   OnlineTransducerModelConfig model_config;
   EndpointConfig endpoint_config;
-  bool enable_endpoint;
+  bool enable_endpoint = true;
+  int32_t max_active_paths = 4;
+
+  // Supported values: greedy_search, modified_beam_search
+  std::string decoding_method = "modified_beam_search";
 
   OnlineRecognizerConfig() = default;
 
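
Note: with these defaults, existing callers get modified beam search unless they set decoding_method to "greedy_search". A hedged sketch of how a caller might fill in the new fields, assuming the existing OnlineRecognizer(const OnlineRecognizerConfig &) constructor; the model-config fields other than tokens are not shown in this diff, so they are elided here.

// recognizer_config_sketch.cc (hypothetical, for illustration only)
#include "sherpa-onnx/csrc/online-recognizer.h"

int main() {
  sherpa_onnx::OnlineRecognizerConfig config;
  config.model_config.tokens = "tokens.txt";
  // ... set the encoder/decoder/joiner model paths on model_config ...

  config.decoding_method = "modified_beam_search";  // or "greedy_search"
  config.max_active_paths = 4;  // beam size; unused by greedy_search
  config.enable_endpoint = true;

  // Any other decoding_method value makes the recognizer exit with
  // "Unsupported decoding method: ...".
  sherpa_onnx::OnlineRecognizer recognizer(config);
  return 0;
}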
@@ -8,6 +8,7 @@
 #include <vector>
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/hypothesis.h"
 
 namespace sherpa_onnx {
 
@@ -17,6 +18,9 @@ struct OnlineTransducerDecoderResult {
 
   /// number of trailing blank frames decoded so far
   int32_t num_trailing_blanks = 0;
+
+  // used only in modified_beam_search
+  Hypotheses hyps;
 };
 
 class OnlineTransducerDecoder {
@@ -4,8 +4,6 @@
 
 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
 
-#include <assert.h>
-
 #include <algorithm>
 #include <utility>
 #include <vector>
@@ -15,39 +13,6 @@
 
 namespace sherpa_onnx {
 
-static Ort::Value GetFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
-                           int32_t t) {
-  std::vector<int64_t> encoder_out_shape =
-      encoder_out->GetTensorTypeAndShapeInfo().GetShape();
-
-  auto batch_size = encoder_out_shape[0];
-  auto num_frames = encoder_out_shape[1];
-  assert(t < num_frames);
-
-  auto encoder_out_dim = encoder_out_shape[2];
-
-  auto offset = num_frames * encoder_out_dim;
-
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-
-  std::array<int64_t, 2> shape{batch_size, encoder_out_dim};
-
-  Ort::Value ans =
-      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
-
-  float *dst = ans.GetTensorMutableData<float>();
-  const float *src = encoder_out->GetTensorData<float>();
-
-  for (int32_t i = 0; i != batch_size; ++i) {
-    std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst);
-    src += offset;
-    dst += encoder_out_dim;
-  }
-
-  return ans;
-}
-
 OnlineTransducerDecoderResult
 OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
   int32_t context_size = model_->ContextSize();
@@ -90,7 +55,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
   Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
 
   for (int32_t t = 0; t != num_frames; ++t) {
-    Ort::Value cur_encoder_out = GetFrame(model_->Allocator(), &encoder_out, t);
+    Ort::Value cur_encoder_out =
+        GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
     Ort::Value logit = model_->RunJoiner(
         std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
 
 // sherpa-onnx/csrc/online-transducer-model.cc
 //
 // Copyright (c) 2023 Xiaomi Corporation
+// Copyright (c) 2023 Pingfeng Luo
 #include "sherpa-onnx/csrc/online-transducer-model.h"
 
 #if __ANDROID_API__ >= 9
@@ -8,6 +9,7 @@
 #include "android/asset_manager_jni.h"
 #endif
 
+#include <algorithm>
 #include <memory>
 #include <sstream>
 #include <string>
@@ -75,6 +77,40 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
   return nullptr;
 }
 
+Ort::Value OnlineTransducerModel::BuildDecoderInput(
+    const std::vector<OnlineTransducerDecoderResult> &results) {
+  int32_t batch_size = static_cast<int32_t>(results.size());
+  int32_t context_size = ContextSize();
+  std::array<int64_t, 2> shape{batch_size, context_size};
+  Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
+      Allocator(), shape.data(), shape.size());
+  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
+
+  for (const auto &r : results) {
+    const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size;
+    const int64_t *end = r.tokens.data() + r.tokens.size();
+    std::copy(begin, end, p);
+    p += context_size;
+  }
+  return decoder_input;
+}
+
+Ort::Value OnlineTransducerModel::BuildDecoderInput(
+    const std::vector<Hypothesis> &hyps) {
+  int32_t batch_size = static_cast<int32_t>(hyps.size());
+  int32_t context_size = ContextSize();
+  std::array<int64_t, 2> shape{batch_size, context_size};
+  Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
+      Allocator(), shape.data(), shape.size());
+  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
+
+  for (const auto &h : hyps) {
+    std::copy(h.ys.end() - context_size, h.ys.end(), p);
+    p += context_size;
+  }
+  return decoder_input;
+}
+
 #if __ANDROID_API__ >= 9
 std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config) {
@@ -14,6 +14,8 @@
 #endif
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/hypothesis.h"
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
 
 namespace sherpa_onnx {
@@ -71,9 +73,6 @@ class OnlineTransducerModel {
       Ort::Value features,
       std::vector<Ort::Value> states) = 0;  // NOLINT
 
-  virtual Ort::Value BuildDecoderInput(
-      const std::vector<OnlineTransducerDecoderResult> &results) = 0;
-
   /** Run the decoder network.
    *
    * Caution: We assume there are no recurrent connections in the decoder and
@@ -125,7 +124,13 @@ class OnlineTransducerModel {
   virtual int32_t VocabSize() const = 0;
 
   virtual int32_t SubsamplingFactor() const { return 4; }
+
   virtual OrtAllocator *Allocator() = 0;
+
+  Ort::Value BuildDecoderInput(
+      const std::vector<OnlineTransducerDecoderResult> &results);
+
+  Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &hyps);
 };
 
 }  // namespace sherpa_onnx
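
Note: both BuildDecoderInput overloads (now shared in the base class instead of duplicated per model) produce the same layout: an int64 tensor of shape (batch_size, ContextSize()) holding, per row, the last ContextSize() tokens of that result or hypothesis. A standalone sketch of the row extraction with plain vectors (hypothetical; Ort::Value omitted):

// decoder_input_sketch.cc (hypothetical, for illustration only)
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  int32_t context_size = 2;  // model_->ContextSize(); 2 is a typical value
  std::vector<std::vector<int64_t>> ys = {
      {0, 0, 23, 57, 104},  // context_size leading blanks + decoded tokens
      {0, 0, 9},
  };

  // Flattened (batch_size, context_size) buffer, one row per hypothesis,
  // mirroring the std::copy loop in BuildDecoderInput().
  std::vector<int64_t> decoder_input;
  for (const auto &y : ys) {
    decoder_input.insert(decoder_input.end(), y.end() - context_size, y.end());
  }

  for (int64_t t : decoder_input) printf("%lld ", (long long)t);  // 57 104 0 9
  printf("\n");
  return 0;
}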
+// sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
+//
+// Copyright (c) 2023 Pingfeng Luo
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
+
+#include <algorithm>
+#include <array>
+#include <cstdio>
+#include <cstdlib>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
+                         const std::vector<int32_t> &hyps_num_split) {
+  std::vector<int64_t> cur_encoder_out_shape =
+      cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  std::array<int64_t, 2> ans_shape{hyps_num_split.back(),
+                                   cur_encoder_out_shape[1]};
+
+  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
+                                                   ans_shape.size());
+
+  const float *src = cur_encoder_out->GetTensorData<float>();
+  float *dst = ans.GetTensorMutableData<float>();
+  int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1;
+  for (int32_t b = 0; b != batch_size; ++b) {
+    int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b];
+    for (int32_t i = 0; i != cur_stream_hyps_num; ++i) {
+      std::copy(src, src + cur_encoder_out_shape[1], dst);
+      dst += cur_encoder_out_shape[1];
+    }
+    src += cur_encoder_out_shape[1];
+  }
+  return ans;
+}
+
+static void LogSoftmax(float *in, int32_t w, int32_t h) {
+  for (int32_t i = 0; i != h; ++i) {
+    LogSoftmax(in, w);
+    in += w;
+  }
+}
+
+OnlineTransducerDecoderResult
+OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
+  int32_t context_size = model_->ContextSize();
+  int32_t blank_id = 0;  // always 0
+  OnlineTransducerDecoderResult r;
+  std::vector<int32_t> blanks(context_size, blank_id);
+  Hypotheses blank_hyp({{blanks, 0}});
+  r.hyps = std::move(blank_hyp);
+  return r;
+}
+
+void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
+    OnlineTransducerDecoderResult *r) const {
+  int32_t context_size = model_->ContextSize();
+  auto hyp = r->hyps.GetMostProbable(true);
+
+  std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
+  r->tokens = std::move(tokens);
+  r->num_trailing_blanks = hyp.num_trailing_blanks;
+}
+
+void OnlineTransducerModifiedBeamSearchDecoder::Decode(
+    Ort::Value encoder_out,
+    std::vector<OnlineTransducerDecoderResult> *result) {
+  std::vector<int64_t> encoder_out_shape =
+      encoder_out.GetTensorTypeAndShapeInfo().GetShape();
+
+  if (encoder_out_shape[0] != static_cast<int64_t>(result->size())) {
+    fprintf(stderr,
+            "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
+            static_cast<int32_t>(encoder_out_shape[0]),
+            static_cast<int32_t>(result->size()));
+    exit(-1);
+  }
+
+  int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
+  int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
+  int32_t vocab_size = model_->VocabSize();
+
+  std::vector<Hypotheses> cur;
+  for (auto &r : *result) {
+    cur.push_back(std::move(r.hyps));
+  }
+  std::vector<Hypothesis> prev;
+
+  for (int32_t t = 0; t != num_frames; ++t) {
+    // Due to merging paths with identical token sequences,
+    // not all utterances have "max_active_paths" paths.
+    int32_t hyps_num_acc = 0;
+    std::vector<int32_t> hyps_num_split;
+    hyps_num_split.push_back(0);
+
+    prev.clear();
+    for (auto &hyps : cur) {
+      for (auto &h : hyps) {
+        prev.push_back(std::move(h.second));
+        hyps_num_acc++;
+      }
+      hyps_num_split.push_back(hyps_num_acc);
+    }
+    cur.clear();
+    cur.reserve(batch_size);
+
+    Ort::Value decoder_input = model_->BuildDecoderInput(prev);
+    Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
+
+    Ort::Value cur_encoder_out =
+        GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
+    cur_encoder_out =
+        Repeat(model_->Allocator(), &cur_encoder_out, hyps_num_split);
+    Ort::Value logit = model_->RunJoiner(
+        std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
+    float *p_logit = logit.GetTensorMutableData<float>();
+
+    for (int32_t b = 0; b < batch_size; ++b) {
+      int32_t start = hyps_num_split[b];
+      int32_t end = hyps_num_split[b + 1];
+      LogSoftmax(p_logit, vocab_size, end - start);
+      auto topk =
+          TopkIndex(p_logit, vocab_size * (end - start), max_active_paths_);
+
+      Hypotheses hyps;
+      for (auto i : topk) {
+        int32_t hyp_index = i / vocab_size + start;
+        int32_t new_token = i % vocab_size;
+
+        Hypothesis new_hyp = prev[hyp_index];
+        if (new_token != 0) {
+          new_hyp.ys.push_back(new_token);
+          new_hyp.num_trailing_blanks = 0;
+        } else {
+          ++new_hyp.num_trailing_blanks;
+        }
+        new_hyp.log_prob += p_logit[i];
+        hyps.Add(std::move(new_hyp));
+      }
+      cur.push_back(std::move(hyps));
+      p_logit += vocab_size * (end - start);
+    }
+  }
+
+  for (int32_t b = 0; b != batch_size; ++b) {
+    (*result)[b].hyps = std::move(cur[b]);
+  }
+}
+
+}  // namespace sherpa_onnx
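
Note: the decoder flattens the live hypotheses of all streams into one batch (prev) so the decoder and joiner each run once per frame. hyps_num_split holds prefix sums of per-stream hypothesis counts, so the joiner's logit rows can be sliced back out per stream, and Repeat() duplicates each stream's encoder frame to line up with its hypothesis rows. A toy sketch of that bookkeeping (hypothetical, not part of this commit):

// beam_split_sketch.cc (hypothetical, for illustration only)
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Stream 0 has 3 live hypotheses, stream 1 has 2 (merging identical
  // token sequences means streams need not all have max_active_paths).
  std::vector<int32_t> hyps_per_stream = {3, 2};

  // Prefix sums, like hyps_num_split in Decode(): {0, 3, 5}.
  std::vector<int32_t> split = {0};
  for (int32_t n : hyps_per_stream) split.push_back(split.back() + n);

  // The joiner then emits split.back() == 5 logit rows; rows
  // [split[b], split[b+1]) belong to stream b, and Repeat() copied stream
  // b's encoder frame split[b+1] - split[b] times so the rows align.
  for (size_t b = 0; b + 1 < split.size(); ++b) {
    printf("stream %zu -> logit rows [%d, %d)\n", b, split[b], split[b + 1]);
  }
  return 0;
}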
+// sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
+//
+// Copyright (c) 2023 Pingfeng Luo
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
+#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
+
+#include <vector>
+
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+#include "sherpa-onnx/csrc/online-transducer-model.h"
+
+namespace sherpa_onnx {
+
+class OnlineTransducerModifiedBeamSearchDecoder
+    : public OnlineTransducerDecoder {
+ public:
+  OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
+                                            int32_t max_active_paths)
+      : model_(model), max_active_paths_(max_active_paths) {}
+
+  OnlineTransducerDecoderResult GetEmptyResult() const override;
+
+  void StripLeadingBlanks(OnlineTransducerDecoderResult *r) const override;
+
+  void Decode(Ort::Value encoder_out,
+              std::vector<OnlineTransducerDecoderResult> *result) override;
+
+ private:
+  OnlineTransducerModel *model_;  // Not owned
+  int32_t max_active_paths_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
@@ -461,24 +461,6 @@ OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
   return {std::move(encoder_out[0]), std::move(next_states)};
 }
 
-Ort::Value OnlineZipformerTransducerModel::BuildDecoderInput(
-    const std::vector<OnlineTransducerDecoderResult> &results) {
-  int32_t batch_size = static_cast<int32_t>(results.size());
-  std::array<int64_t, 2> shape{batch_size, context_size_};
-  Ort::Value decoder_input =
-      Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
-  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
-
-  for (const auto &r : results) {
-    const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
-    const int64_t *end = r.tokens.data() + r.tokens.size();
-    std::copy(begin, end, p);
-    p += context_size_;
-  }
-
-  return decoder_input;
-}
-
 Ort::Value OnlineZipformerTransducerModel::RunDecoder(
     Ort::Value decoder_input) {
   auto decoder_out = decoder_sess_->Run(
@@ -41,9 +41,6 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
   std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
       Ort::Value features, std::vector<Ort::Value> states) override;
 
-  Ort::Value BuildDecoderInput(
-      const std::vector<OnlineTransducerDecoderResult> &results) override;
-
   Ort::Value RunDecoder(Ort::Value decoder_input) override;
 
   Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
@@ -44,6 +44,38 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
   }
 }
 
+Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
+                              int32_t t) {
+  std::vector<int64_t> encoder_out_shape =
+      encoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  auto batch_size = encoder_out_shape[0];
+  auto num_frames = encoder_out_shape[1];
+  assert(t < num_frames);
+
+  auto encoder_out_dim = encoder_out_shape[2];
+
+  auto offset = num_frames * encoder_out_dim;
+
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  std::array<int64_t, 2> shape{batch_size, encoder_out_dim};
+
+  Ort::Value ans =
+      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
+
+  float *dst = ans.GetTensorMutableData<float>();
+  const float *src = encoder_out->GetTensorData<float>();
+
+  for (int32_t i = 0; i != batch_size; ++i) {
+    std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst);
+    src += offset;
+    dst += encoder_out_dim;
+  }
+  return ans;
+}
+
 void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
   Ort::AllocatorWithDefaultOptions allocator;
   std::vector<Ort::AllocatedStringPtr> v =
@@ -10,6 +10,7 @@
 #include <locale>
 #endif
 
+#include <cassert>
 #include <ostream>
 #include <string>
 #include <vector>
@@ -57,6 +58,17 @@ void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
 void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
                     std::vector<const char *> *output_names_ptr);
 
+/** Get the t-th output frame from the encoder output.
+ *
+ * @param allocator  Allocator of onnxruntime.
+ * @param encoder_out  Encoder output of shape (batch_size, num_frames, dim).
+ * @param t  Frame index; must satisfy 0 <= t < num_frames.
+ * @return A tensor of shape (batch_size, dim) containing the t-th frame
+ *         of every utterance in the batch.
+ */
+Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
+                              int32_t t);
+
 void PrintModelMetadata(std::ostream &os,
                         const Ort::ModelMetadata &meta_data);  // NOLINT
 
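
Note: GetEncoderOutFrame, now shared by both decoders, slices frame t out of a (batch_size, num_frames, dim) tensor into a (batch_size, dim) tensor. The same slicing shown on plain arrays (hypothetical, for illustration only):

// encoder_frame_sketch.cc (hypothetical, for illustration only)
#include <cstdio>

int main() {
  const int batch_size = 2, num_frames = 3, dim = 2;
  // encoder_out[b][t][d], flattened row-major as onnxruntime stores it:
  const float encoder_out[batch_size * num_frames * dim] = {
      0, 1, 2, 3, 4, 5,    // utterance 0, frames 0..2
      6, 7, 8, 9, 10, 11,  // utterance 1, frames 0..2
  };

  const int t = 1;  // extract frame 1 of every utterance
  float frame[batch_size * dim];
  const float *src = encoder_out;
  float *dst = frame;
  for (int b = 0; b != batch_size; ++b) {
    // Same pointer walk as GetEncoderOutFrame: copy one frame, then skip
    // ahead to the next utterance (offset = num_frames * dim).
    for (int d = 0; d != dim; ++d) dst[d] = src[t * dim + d];
    src += num_frames * dim;
    dst += dim;
  }

  for (float v : frame) printf("%.0f ", v);  // prints: 2 3 8 9
  printf("\n");
  return 0;
}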