Fangjun Kuang
Committed by GitHub

Add RNN LM rescore for offline ASR with modified_beam_search (#125)

@@ -56,3 +56,4 @@ run-offline-decode-files.sh
 sherpa-onnx-nemo-ctc-en-citrinet-512
 run-offline-decode-files-nemo-ctc.sh
 *.jar
+sherpa-onnx-nemo-ctc-*
@@ -51,6 +51,11 @@ if(DEFINED ANDROID_ABI)
   set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE)
 endif()

+if(SHERPA_ONNX_ENABLE_PYTHON AND NOT BUILD_SHARED_LIBS)
+  message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_PYTHON is ON")
+  set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
+endif()
+
 if(SHERPA_ONNX_ENABLE_JNI AND NOT BUILD_SHARED_LIBS)
   message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_JNI is ON")
   set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
@@ -18,6 +18,8 @@ set(sources
   hypothesis.cc
   offline-ctc-greedy-search-decoder.cc
   offline-ctc-model.cc
+  offline-lm-config.cc
+  offline-lm.cc
   offline-model-config.cc
   offline-nemo-enc-dec-ctc-model-config.cc
   offline-nemo-enc-dec-ctc-model.cc
@@ -26,10 +28,13 @@ set(sources
   offline-paraformer-model.cc
   offline-recognizer-impl.cc
   offline-recognizer.cc
+  offline-rnn-lm.cc
   offline-stream.cc
   offline-transducer-greedy-search-decoder.cc
   offline-transducer-model-config.cc
   offline-transducer-model.cc
+  offline-transducer-modified-beam-search-decoder.cc
+  online-lm-config.cc
   online-lstm-transducer-model.cc
   online-recognizer.cc
   online-stream.cc
@@ -17,6 +17,9 @@ void Hypotheses::Add(Hypothesis hyp) {
     hyps_dict_[key] = std::move(hyp);
   } else {
     it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
+
+    it->second.lm_log_prob =
+        LogAdd<double>()(it->second.lm_log_prob, hyp.lm_log_prob);
   }
 }

@@ -24,8 +27,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
   if (length_norm == false) {
     return std::max_element(hyps_dict_.begin(), hyps_dict_.end(),
                             [](const auto &left, auto &right) -> bool {
-                              return left.second.log_prob <
-                                     right.second.log_prob;
+                              return left.second.TotalLogProb() <
+                                     right.second.TotalLogProb();
                             })
         ->second;
   } else {
@@ -33,8 +36,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
     return std::max_element(
         hyps_dict_.begin(), hyps_dict_.end(),
         [](const auto &left, const auto &right) -> bool {
-          return left.second.log_prob / left.second.ys.size() <
-                 right.second.log_prob / right.second.ys.size();
+          return left.second.TotalLogProb() / left.second.ys.size() <
+                 right.second.TotalLogProb() / right.second.ys.size();
         })
         ->second;
   }
@@ -47,15 +50,16 @@ std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm) const {
   std::vector<Hypothesis> all_hyps = Vec();

   if (length_norm == false) {
-    std::partial_sort(
-        all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
-        [](const auto &a, const auto &b) { return a.log_prob > b.log_prob; });
+    std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
+                      [](const auto &a, const auto &b) {
+                        return a.TotalLogProb() > b.TotalLogProb();
+                      });
   } else {
     // for length_norm is true
     std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
                       [](const auto &a, const auto &b) {
-                        return a.log_prob / a.ys.size() >
-                               b.log_prob / b.ys.size();
+                        return a.TotalLogProb() / a.ys.size() >
+                               b.TotalLogProb() / b.ys.size();
                       });
   }

@@ -25,14 +25,20 @@ struct Hypothesis {
   std::vector<int32_t> timestamps;

   // The total score of ys in log space.
+  // It contains only acoustic scores.
   double log_prob = 0;

+  // LM log prob if any.
+  double lm_log_prob = 0;
+
   int32_t num_trailing_blanks = 0;

   Hypothesis() = default;
   Hypothesis(const std::vector<int64_t> &ys, double log_prob)
       : ys(ys), log_prob(log_prob) {}

+  double TotalLogProb() const { return log_prob + lm_log_prob; }
+
   // If two Hypotheses have the same `Key`, then they contain
   // the same token sequence.
   std::string Key() const {
@@ -94,6 +100,9 @@ class Hypotheses {
   const auto begin() const { return hyps_dict_.begin(); }
   const auto end() const { return hyps_dict_.end(); }

+  auto begin() { return hyps_dict_.begin(); }
+  auto end() { return hyps_dict_.end(); }
+
   void Clear() { hyps_dict_.clear(); }

  private:
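
With the new lm_log_prob field, every ranking site above switches from log_prob to TotalLogProb(). A minimal sketch of the effect, assuming hypothesis.h is included (all scores below are made up for illustration):

    // Sketch: an LM score can flip the ranking between two hypotheses that
    // the acoustic model alone would order the other way.
    Hypothesis a({0, 0, 5, 9}, /*log_prob=*/-1.0);  // acoustic winner
    Hypothesis b({0, 0, 7, 2}, /*log_prob=*/-1.5);
    a.lm_log_prob = -2.0;  // the LM dislikes a
    b.lm_log_prob = -0.5;  // the LM prefers b
    // a.TotalLogProb() == -3.0, b.TotalLogProb() == -2.0, so b now ranks first.
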
@@ -88,6 +88,16 @@ void LogSoftmax(T *input, int32_t input_len) {
   }
 }

+template <typename T>
+void LogSoftmax(T *in, int32_t w, int32_t h) {
+  for (int32_t i = 0; i != h; ++i) {
+    LogSoftmax(in, w);
+    in += w;
+  }
+}
+
+// TODO(fangjun): use std::partial_sort to replace std::sort.
+// Remember also to fix sherpa-ncnn
 template <class T>
 std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
   std::vector<int32_t> vec_index(size);
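
The new overload reuses the existing 1-D LogSoftmax row by row over an h-by-w matrix stored contiguously; the decoder below calls it on the flattened joiner output. A hedged usage sketch with made-up values:

    // Sketch: normalize each row of a flattened 2 x 3 logit matrix in place.
    float logits[6] = {1.0f, 2.0f, 3.0f,   // row 0
                       0.5f, 0.5f, 0.5f};  // row 1
    LogSoftmax(logits, /*w=*/3, /*h=*/2);
    // Each row now holds log-probabilities whose exponentials sum to 1.
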
+// sherpa-onnx/csrc/offline-lm-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-lm-config.h"
+
+#include <string>
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OfflineLMConfig::Register(ParseOptions *po) {
+  po->Register("lm", &model, "Path to LM model.");
+  po->Register("lm-scale", &scale, "LM scale.");
+}
+
+bool OfflineLMConfig::Validate() const {
+  if (!FileExists(model)) {
+    SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
+    return false;
+  }
+
+  return true;
+}
+
+std::string OfflineLMConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineLMConfig(";
+  os << "model=\"" << model << "\", ";
+  os << "scale=" << scale << ")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/csrc/offline-lm-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflineLMConfig {
+  // path to the onnx model
+  std::string model;
+
+  // LM scale
+  float scale = 1.0;
+
+  OfflineLMConfig() = default;
+
+  OfflineLMConfig(const std::string &model, float scale)
+      : model(model), scale(scale) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_
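
A hedged sketch of how this config surfaces on the command line, assuming the usual sherpa-onnx ParseOptions flow (the usage string and flag values are illustrative):

    // Sketch: wiring OfflineLMConfig into a ParseOptions-driven binary.
    #include "sherpa-onnx/csrc/offline-lm-config.h"
    #include "sherpa-onnx/csrc/parse-options.h"

    int main(int argc, char *argv[]) {
      sherpa_onnx::ParseOptions po("usage: ...");
      sherpa_onnx::OfflineLMConfig lm_config;
      lm_config.Register(&po);  // exposes --lm and --lm-scale
      po.Read(argc, argv);      // e.g. --lm=rnn-lm.onnx --lm-scale=0.5
      if (!lm_config.Validate()) return -1;
      return 0;
    }
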
+// sherpa-onnx/csrc/offline-lm.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-lm.h"
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-rnn-lm.h"
+
+namespace sherpa_onnx {
+
+std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) {
+  return std::make_unique<OfflineRnnLM>(config);
+}
+
+void OfflineLM::ComputeLMScore(float scale, int32_t context_size,
+                               std::vector<Hypotheses> *hyps) {
+  // compute the max token seq so that we know how much space to allocate
+  int32_t max_token_seq = 0;
+  int32_t num_hyps = 0;
+
+  // we subtract context_size below since each token sequence is prepended
+  // with context_size blanks
+  for (const auto &h : *hyps) {
+    num_hyps += h.Size();
+    for (const auto &t : h) {
+      max_token_seq =
+          std::max<int32_t>(max_token_seq, t.second.ys.size() - context_size);
+    }
+  }
+
+  Ort::AllocatorWithDefaultOptions allocator;
+  std::array<int64_t, 2> x_shape{num_hyps, max_token_seq};
+  Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator, x_shape.data(),
+                                                   x_shape.size());
+
+  std::array<int64_t, 1> x_lens_shape{num_hyps};
+  Ort::Value x_lens = Ort::Value::CreateTensor<int64_t>(
+      allocator, x_lens_shape.data(), x_lens_shape.size());
+
+  int64_t *p = x.GetTensorMutableData<int64_t>();
+  std::fill(p, p + num_hyps * max_token_seq, 0);
+
+  int64_t *p_lens = x_lens.GetTensorMutableData<int64_t>();
+
+  for (const auto &h : *hyps) {
+    for (const auto &t : h) {
+      const auto &ys = t.second.ys;
+      int32_t len = ys.size() - context_size;
+      std::copy(ys.begin() + context_size, ys.end(), p);
+      *p_lens = len;
+
+      p += max_token_seq;
+      ++p_lens;
+    }
+  }
+  auto negative_loglike = Rescore(std::move(x), std::move(x_lens));
+  const float *p_nll = negative_loglike.GetTensorData<float>();
+  for (auto &h : *hyps) {
+    for (auto &t : h) {
+      // Use -scale here since we want to change negative loglike to loglike.
+      t.second.lm_log_prob = -scale * (*p_nll);
+      ++p_nll;
+    }
+  }
+}
+
+}  // namespace sherpa_onnx
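
A hedged worked example of the packing above, with context_size = 2: for two hypotheses with ys = {0, 0, 5, 9} and ys = {0, 0, 7}, the leading blanks are stripped, max_token_seq is 2, x becomes the zero-padded tensor [[5, 9], [7, 0]], and x_lens is [2, 1]. Rescore then returns one negative log likelihood per row, which the final loop negates and scales into lm_log_prob.
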
+// sherpa-onnx/csrc/offline-lm.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_LM_H_
+
+#include <memory>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/hypothesis.h"
+#include "sherpa-onnx/csrc/offline-lm-config.h"
+
+namespace sherpa_onnx {
+
+class OfflineLM {
+ public:
+  virtual ~OfflineLM() = default;
+
+  static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
+
+  /** Rescore a batch of sentences.
+   *
+   * @param x A 2-D tensor of shape (N, L) with data type int64.
+   * @param x_lens A 1-D tensor of shape (N,) with data type int64.
+   *               It contains the number of valid tokens in x before padding.
+   * @return Return a 1-D tensor of shape (N,) containing the negative log
+   *         likelihood of each utterance. Its data type is float32.
+   *
+   * Caution: It returns negative log likelihood (nll), not log likelihood.
+   */
+  virtual Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) = 0;
+
+  // This function updates hyp.lm_log_prob of hyps.
+  //
+  // @param scale LM scale.
+  // @param context_size Context size of the transducer decoder model.
+  // @param hyps It is changed in-place.
+  void ComputeLMScore(float scale, int32_t context_size,
+                      std::vector<Hypotheses> *hyps);
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_LM_H_
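
A minimal call-pattern sketch, mirroring how the decoder below uses this interface (the scale and context_size values are illustrative):

    // Sketch: create an LM from its config and rescore a batch of Hypotheses.
    auto lm = sherpa_onnx::OfflineLM::Create(config.lm_config);
    std::vector<sherpa_onnx::Hypotheses> hyps;  // filled by beam search
    lm->ComputeLMScore(/*scale=*/0.5f, /*context_size=*/2, &hyps);
    // Every Hypothesis now has lm_log_prob set; ranking uses TotalLogProb().
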
@@ -16,6 +16,7 @@
 #include "sherpa-onnx/csrc/offline-transducer-decoder.h"
 #include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
 #include "sherpa-onnx/csrc/offline-transducer-model.h"
+#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
 #include "sherpa-onnx/csrc/pad-sequence.h"
 #include "sherpa-onnx/csrc/symbol-table.h"

@@ -57,8 +58,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
       decoder_ =
           std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
     } else if (config_.decoding_method == "modified_beam_search") {
-      SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented");
-      exit(-1);
+      if (!config_.lm_config.model.empty()) {
+        lm_ = OfflineLM::Create(config.lm_config);
+      }
+
+      decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
+          model_.get(), lm_.get(), config_.max_active_paths,
+          config_.lm_config.scale);
     } else {
       SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                        config_.decoding_method.c_str());
@@ -127,6 +133,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
   SymbolTable symbol_table_;
   std::unique_ptr<OfflineTransducerModel> model_;
   std::unique_ptr<OfflineTransducerDecoder> decoder_;
+  std::unique_ptr<OfflineLM> lm_;
 };

 }  // namespace sherpa_onnx
@@ -8,6 +8,7 @@

 #include "sherpa-onnx/csrc/file-utils.h"
 #include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-lm-config.h"
 #include "sherpa-onnx/csrc/offline-recognizer-impl.h"

 namespace sherpa_onnx {
@@ -15,13 +16,28 @@ namespace sherpa_onnx {
 void OfflineRecognizerConfig::Register(ParseOptions *po) {
   feat_config.Register(po);
   model_config.Register(po);
+  lm_config.Register(po);

-  po->Register("decoding-method", &decoding_method,
-               "decoding method,"
-               "Valid values: greedy_search.");
+  po->Register(
+      "decoding-method", &decoding_method,
+      "decoding method. "
+      "Valid values: greedy_search, modified_beam_search. "
+      "modified_beam_search is applicable only to transducer models.");
+
+  po->Register("max-active-paths", &max_active_paths,
+               "Used only when decoding_method is modified_beam_search");
 }

 bool OfflineRecognizerConfig::Validate() const {
+  if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) {
+    if (max_active_paths <= 0) {
+      SHERPA_ONNX_LOGE("max_active_paths must be positive! Given: %d",
+                       max_active_paths);
+      return false;
+    }
+    if (!lm_config.Validate()) return false;
+  }
+
   return model_config.Validate();
 }

@@ -31,7 +47,9 @@ std::string OfflineRecognizerConfig::ToString() const {
   os << "OfflineRecognizerConfig(";
   os << "feat_config=" << feat_config.ToString() << ", ";
   os << "model_config=" << model_config.ToString() << ", ";
-  os << "decoding_method=\"" << decoding_method << "\")";
+  os << "lm_config=" << lm_config.ToString() << ", ";
+  os << "decoding_method=\"" << decoding_method << "\", ";
+  os << "max_active_paths=" << max_active_paths << ")";

   return os.str();
 }
@@ -9,6 +9,7 @@
 #include <string>
 #include <vector>

+#include "sherpa-onnx/csrc/offline-lm-config.h"
 #include "sherpa-onnx/csrc/offline-model-config.h"
 #include "sherpa-onnx/csrc/offline-stream.h"
 #include "sherpa-onnx/csrc/offline-transducer-model-config.h"
@@ -21,18 +22,24 @@ struct OfflineRecognitionResult;
 struct OfflineRecognizerConfig {
   OfflineFeatureExtractorConfig feat_config;
   OfflineModelConfig model_config;
+  OfflineLMConfig lm_config;

   std::string decoding_method = "greedy_search";
+  int32_t max_active_paths = 4;
   // only greedy_search is implemented
   // TODO(fangjun): Implement modified_beam_search

   OfflineRecognizerConfig() = default;
   OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
                           const OfflineModelConfig &model_config,
-                          const std::string &decoding_method)
+                          const OfflineLMConfig &lm_config,
+                          const std::string &decoding_method,
+                          int32_t max_active_paths)
       : feat_config(feat_config),
         model_config(model_config),
-        decoding_method(decoding_method) {}
+        lm_config(lm_config),
+        decoding_method(decoding_method),
+        max_active_paths(max_active_paths) {}

   void Register(ParseOptions *po);
   bool Validate() const;
+// sherpa-onnx/csrc/offline-rnn-lm.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-rnn-lm.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+class OfflineRnnLM::Impl {
+ public:
+  explicit Impl(const OfflineLMConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_{},
+        allocator_{} {
+    Init(config);
+  }
+
+  Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
+    std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)};
+
+    auto out =
+        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
+                   output_names_ptr_.data(), output_names_ptr_.size());
+
+    return std::move(out[0]);
+  }
+
+ private:
+  void Init(const OfflineLMConfig &config) {
+    auto buf = ReadFile(config_.model);
+
+    sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
+                                           sess_opts_);
+
+    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
+
+    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
+  }
+
+ private:
+  OfflineLMConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> sess_;
+
+  std::vector<std::string> input_names_;
+  std::vector<const char *> input_names_ptr_;
+
+  std::vector<std::string> output_names_;
+  std::vector<const char *> output_names_ptr_;
+};
+
+OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+OfflineRnnLM::~OfflineRnnLM() = default;
+
+Ort::Value OfflineRnnLM::Rescore(Ort::Value x, Ort::Value x_lens) {
+  return impl_->Rescore(std::move(x), std::move(x_lens));
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/csrc/offline-rnn-lm.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_
+
+#include <memory>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-lm-config.h"
+#include "sherpa-onnx/csrc/offline-lm.h"
+
+namespace sherpa_onnx {
+
+class OfflineRnnLM : public OfflineLM {
+ public:
+  ~OfflineRnnLM() override;
+
+  explicit OfflineRnnLM(const OfflineLMConfig &config);
+
+  /** Rescore a batch of sentences.
+   *
+   * @param x A 2-D tensor of shape (N, L) with data type int64.
+   * @param x_lens A 1-D tensor of shape (N,) with data type int64.
+   *               It contains the number of valid tokens in x before padding.
+   * @return Return a 1-D tensor of shape (N,) containing the negative log
+   *         likelihood of each utterance. Its data type is float32.
+   *
+   * Caution: Like the base class, it returns negative log likelihood (nll),
+   * not log likelihood; ComputeLMScore negates it when filling lm_log_prob.
+   */
+  Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) override;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_
@@ -95,6 +95,30 @@ class OfflineTransducerModel::Impl {
       std::copy(begin, end, p);
       p += context_size;
     }
+
+    return decoder_input;
+  }
+
+  Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
+                               int32_t end_index) const {
+    assert(end_index <= results.size());
+
+    int32_t batch_size = end_index;
+    int32_t context_size = ContextSize();
+    std::array<int64_t, 2> shape{batch_size, context_size};
+
+    Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
+        Allocator(), shape.data(), shape.size());
+    int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
+
+    for (int32_t i = 0; i != batch_size; ++i) {
+      const auto &r = results[i];
+      const int64_t *begin = r.ys.data() + r.ys.size() - context_size;
+      const int64_t *end = r.ys.data() + r.ys.size();
+      std::copy(begin, end, p);
+      p += context_size;
+    }
+
     return decoder_input;
   }

@@ -234,4 +258,9 @@ Ort::Value OfflineTransducerModel::BuildDecoderInput(
   return impl_->BuildDecoderInput(results, end_index);
 }

+Ort::Value OfflineTransducerModel::BuildDecoderInput(
+    const std::vector<Hypothesis> &results, int32_t end_index) const {
+  return impl_->BuildDecoderInput(results, end_index);
+}
+
 }  // namespace sherpa_onnx
@@ -9,6 +9,7 @@
 #include <vector>

 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/hypothesis.h"
 #include "sherpa-onnx/csrc/offline-model-config.h"

 namespace sherpa_onnx {
@@ -79,13 +80,16 @@ class OfflineTransducerModel {
    *
    * @param results Current decoded results.
    * @param end_index We only use results[0:end_index] to build
-   *                  the decoder_input.
+   *                  the decoder_input. results[end_index] is not used.
    * @return Return a tensor of shape (results.size(), ContextSize())
    */
   Ort::Value BuildDecoderInput(
       const std::vector<OfflineTransducerDecoderResult> &results,
       int32_t end_index) const;

+  Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
+                               int32_t end_index) const;
+
  private:
   class Impl;
   std::unique_ptr<Impl> impl_;
+// sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
+
+#include <deque>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/hypothesis.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/packed-sequence.h"
+#include "sherpa-onnx/csrc/slice.h"
+
+namespace sherpa_onnx {
+
+static std::vector<int32_t> GetHypsRowSplits(
+    const std::vector<Hypotheses> &hyps) {
+  std::vector<int32_t> row_splits;
+  row_splits.reserve(hyps.size() + 1);
+
+  row_splits.push_back(0);
+  int32_t s = 0;
+  for (const auto &h : hyps) {
+    s += h.Size();
+    row_splits.push_back(s);
+  }
+
+  return row_splits;
+}
+
+std::vector<OfflineTransducerDecoderResult>
+OfflineTransducerModifiedBeamSearchDecoder::Decode(
+    Ort::Value encoder_out, Ort::Value encoder_out_length) {
+  PackedSequence packed_encoder_out = PackPaddedSequence(
+      model_->Allocator(), &encoder_out, &encoder_out_length);
+
+  int32_t batch_size =
+      static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
+
+  int32_t vocab_size = model_->VocabSize();
+  int32_t context_size = model_->ContextSize();
+
+  std::vector<int64_t> blanks(context_size, 0);
+  Hypotheses blank_hyp({{blanks, 0}});
+
+  std::deque<Hypotheses> finalized;
+  std::vector<Hypotheses> cur(batch_size, blank_hyp);
+  std::vector<Hypothesis> prev;
+
+  int32_t start = 0;
+  int32_t t = 0;
+  for (auto n : packed_encoder_out.batch_sizes) {
+    Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n);
+    start += n;
+
+    if (n < static_cast<int32_t>(cur.size())) {
+      for (int32_t k = static_cast<int32_t>(cur.size()) - 1; k >= n; --k) {
+        finalized.push_front(std::move(cur[k]));
+      }
+
+      cur.erase(cur.begin() + n, cur.end());
+    }  // if (n < static_cast<int32_t>(cur.size()))
+
+    // Due to merging paths with identical token sequences,
+    // not all utterances have "max_active_paths" paths.
+    auto hyps_row_splits = GetHypsRowSplits(cur);
+    int32_t num_hyps = hyps_row_splits.back();
+
+    prev.clear();
+    prev.reserve(num_hyps);
+
+    for (auto &hyps : cur) {
+      for (auto &h : hyps) {
+        prev.push_back(std::move(h.second));
+      }
+    }
+    cur.clear();
+    cur.reserve(n);
+
+    auto decoder_input = model_->BuildDecoderInput(prev, num_hyps);
+    // decoder_input shape: (num_hyps, context_size)
+
+    auto decoder_out = model_->RunDecoder(std::move(decoder_input));
+    // decoder_out is (num_hyps, joiner_dim)
+
+    cur_encoder_out =
+        Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
+    // now cur_encoder_out is of shape (num_hyps, joiner_dim)
+
+    Ort::Value logit = model_->RunJoiner(
+        std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
+
+    float *p_logit = logit.GetTensorMutableData<float>();
+    LogSoftmax(p_logit, vocab_size, num_hyps);
+
+    // now p_logit contains log_softmax output, we rename it to p_logprob
+    // to match what it actually contains
+    float *p_logprob = p_logit;
+
+    // add log_prob of each hypothesis to p_logprob before taking top_k
+    for (int32_t i = 0; i != num_hyps; ++i) {
+      float log_prob = prev[i].log_prob;
+      for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
+        *p_logprob += log_prob;
+      }
+    }
+    p_logprob = p_logit;  // we changed p_logprob in the above for loop
+
+    // Now compute top_k for each utterance
+    for (int32_t i = 0; i != n; ++i) {
+      int32_t start = hyps_row_splits[i];
+      int32_t end = hyps_row_splits[i + 1];
+      auto topk =
+          TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_);
+
+      Hypotheses hyps;
+      for (auto k : topk) {
+        int32_t hyp_index = k / vocab_size + start;
+        int32_t new_token = k % vocab_size;
+        Hypothesis new_hyp = prev[hyp_index];
+
+        if (new_token != 0) {
+          // blank id is fixed to 0
+          new_hyp.ys.push_back(new_token);
+          new_hyp.timestamps.push_back(t);
+        }
+
+        new_hyp.log_prob = p_logprob[k];
+        hyps.Add(std::move(new_hyp));
+      }  // for (auto k : topk)
+      p_logprob += (end - start) * vocab_size;
+      cur.push_back(std::move(hyps));
+    }  // for (int32_t i = 0; i != n; ++i)
+
+    ++t;
+  }  // for (auto n : packed_encoder_out.batch_sizes)
+
+  for (auto &h : finalized) {
+    cur.push_back(std::move(h));
+  }
+
+  if (lm_) {
+    // use LM for rescoring
+    lm_->ComputeLMScore(lm_scale_, context_size, &cur);
+  }
+
+  std::vector<OfflineTransducerDecoderResult> unsorted_ans(batch_size);
+  for (int32_t i = 0; i != batch_size; ++i) {
+    Hypothesis hyp = cur[i].GetMostProbable(true);
+
+    auto &r = unsorted_ans[packed_encoder_out.sorted_indexes[i]];
+
+    // strip leading blanks
+    r.tokens = {hyp.ys.begin() + context_size, hyp.ys.end()};
+    r.timestamps = std::move(hyp.timestamps);
+  }
+
+  return unsorted_ans;
+}
+
+}  // namespace sherpa_onnx
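
A hedged worked example of the index arithmetic above: if an utterance has two active hypotheses (start = 0, end = 2) and vocab_size = 500, TopkIndex scans a flattened block of 1000 joint log-probs; a winning index k = 503 maps back to hyp_index = 503 / 500 + 0 = 1 and new_token = 503 % 500 = 3, so hypothesis 1 is extended with token 3 (or merged by Hypotheses::Add if another path already ends in the same token sequence).
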
+// sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
+
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-lm.h"
+#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
+#include "sherpa-onnx/csrc/offline-transducer-model.h"
+
+namespace sherpa_onnx {
+
+class OfflineTransducerModifiedBeamSearchDecoder
+    : public OfflineTransducerDecoder {
+ public:
+  OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
+                                             OfflineLM *lm,
+                                             int32_t max_active_paths,
+                                             float lm_scale)
+      : model_(model),
+        lm_(lm),
+        max_active_paths_(max_active_paths),
+        lm_scale_(lm_scale) {}
+
+  std::vector<OfflineTransducerDecoderResult> Decode(
+      Ort::Value encoder_out, Ort::Value encoder_out_length) override;
+
+ private:
+  OfflineTransducerModel *model_;  // Not owned
+  OfflineLM *lm_;                  // Not owned; may be nullptr
+
+  int32_t max_active_paths_;
+  float lm_scale_;  // used only when lm_ is not nullptr
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
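
A minimal construction sketch, mirroring the recognizer wiring earlier in this commit; model and lm are assumed to come from OfflineTransducerModel and OfflineLM::Create, and passing nullptr for the LM disables rescoring:

    // Sketch: modified beam search with optional LM rescoring.
    sherpa_onnx::OfflineTransducerModifiedBeamSearchDecoder decoder(
        model.get(), lm.get() /* or nullptr */, /*max_active_paths=*/4,
        /*lm_scale=*/0.5f);
    auto results =
        decoder.Decode(std::move(encoder_out), std::move(encoder_out_length));
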
+// sherpa-onnx/csrc/online-lm-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-lm-config.h"
+
+#include <string>
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OnlineLMConfig::Register(ParseOptions *po) {
+  po->Register("lm", &model, "Path to LM model.");
+  po->Register("lm-scale", &scale, "LM scale.");
+}
+
+bool OnlineLMConfig::Validate() const {
+  if (!FileExists(model)) {
+    SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
+    return false;
+  }
+
+  return true;
+}
+
+std::string OnlineLMConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OnlineLMConfig(";
+  os << "model=\"" << model << "\", ";
+  os << "scale=" << scale << ")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/csrc/online-lm-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_
+#define SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OnlineLMConfig {
+  // path to the onnx model
+  std::string model;
+
+  // LM scale
+  float scale = 1.0;
+
+  OnlineLMConfig() = default;
+
+  OnlineLMConfig(const std::string &model, float scale)
+      : model(model), scale(scale) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_
+// sherpa-onnx/csrc/online-lm.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_H_
+#define SHERPA_ONNX_CSRC_ONLINE_LM_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/hypothesis.h"
+#include "sherpa-onnx/csrc/online-lm-config.h"
+
+namespace sherpa_onnx {
+
+class OnlineLM {
+ public:
+  virtual ~OnlineLM() = default;
+
+  static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
+
+  virtual std::vector<Ort::Value> GetInitStates() = 0;
+
+  /** Rescore a batch of sentences.
+   *
+   * @param x A 2-D tensor of shape (N, L) with data type int64.
+   * @param y A 2-D tensor of shape (N, L) with data type int64.
+   * @param states It contains the states for the LM model.
+   * @return Return a pair containing
+   *           - negative loglike
+   *           - updated states
+   *
+   * Caution: It returns negative log likelihood (nll), not log likelihood.
+   */
+  virtual std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
+      Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0;
+
+  // This function updates hyp.lm_log_prob of hyps.
+  //
+  // @param scale LM scale.
+  // @param context_size Context size of the transducer decoder model.
+  // @param hyps It is changed in-place.
+  void ComputeLMScore(float scale, int32_t context_size,
+                      std::vector<Hypotheses> *hyps);
+
+  /** TODO(fangjun):
+   *
+   * 1. Add two fields to Hypothesis
+   *    (a) int32_t lm_cur_pos = 0;  // number of tokens scored so far
+   *    (b) std::vector<Ort::Value> lm_states;
+   * 2. When we want to score a hypothesis, we construct x and y as follows:
+   *
+   *      std::vector x = {hyp.ys.begin() + context_size + hyp.lm_cur_pos,
+   *                       hyp.ys.end() - 1};
+   *      std::vector y = {hyp.ys.begin() + context_size + hyp.lm_cur_pos + 1,
+   *                       hyp.ys.end()};
+   *      hyp.lm_cur_pos += hyp.ys.size() - context_size - hyp.lm_cur_pos;
+   */
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_LM_H_
@@ -36,38 +36,6 @@ static void UseCachedDecoderOut(
   }
 }

-static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
-                         const std::vector<int32_t> &hyps_num_split) {
-  std::vector<int64_t> cur_encoder_out_shape =
-      cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
-
-  std::array<int64_t, 2> ans_shape{hyps_num_split.back(),
-                                   cur_encoder_out_shape[1]};
-
-  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
-                                                   ans_shape.size());
-
-  const float *src = cur_encoder_out->GetTensorData<float>();
-  float *dst = ans.GetTensorMutableData<float>();
-  int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1;
-  for (int32_t b = 0; b != batch_size; ++b) {
-    int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b];
-    for (int32_t i = 0; i != cur_stream_hyps_num; ++i) {
-      std::copy(src, src + cur_encoder_out_shape[1], dst);
-      dst += cur_encoder_out_shape[1];
-    }
-    src += cur_encoder_out_shape[1];
-  }
-  return ans;
-}
-
-static void LogSoftmax(float *in, int32_t w, int32_t h) {
-  for (int32_t i = 0; i != h; ++i) {
-    LogSoftmax(in, w);
-    in += w;
-  }
-}
-
 OnlineTransducerDecoderResult
 OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
   int32_t context_size = model_->ContextSize();
@@ -193,4 +193,29 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
 }
 #endif

+Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
+                  const std::vector<int32_t> &hyps_num_split) {
+  std::vector<int64_t> cur_encoder_out_shape =
+      cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  std::array<int64_t, 2> ans_shape{hyps_num_split.back(),
+                                   cur_encoder_out_shape[1]};
+
+  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
+                                                   ans_shape.size());
+
+  const float *src = cur_encoder_out->GetTensorData<float>();
+  float *dst = ans.GetTensorMutableData<float>();
+  int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1;
+  for (int32_t b = 0; b != batch_size; ++b) {
+    int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b];
+    for (int32_t i = 0; i != cur_stream_hyps_num; ++i) {
+      std::copy(src, src + cur_encoder_out_shape[1], dst);
+      dst += cur_encoder_out_shape[1];
+    }
+    src += cur_encoder_out_shape[1];
+  }
+  return ans;
+}
+
 }  // namespace sherpa_onnx
@@ -86,6 +86,9 @@ std::vector<char> ReadFile(const std::string &filename);
 std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
 #endif

+// TODO(fangjun): Document it
+Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
+                  const std::vector<int32_t> &hyps_num_split);
 }  // namespace sherpa_onnx

 #endif  // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
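
Pending the TODO above, a hedged description of Repeat based on its definition: it row-expands a (batch, dim) float tensor according to hyps_num_split, copying row b once per hypothesis of stream b. For example, with hyps_num_split = {0, 2, 3}, row 0 is copied twice and row 1 once, giving a (3, dim) result whose rows line up with the flattened hypothesis list.
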
@@ -111,6 +111,9 @@ for a list of pre-trained models to download.

   fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
   fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
+  if (config.decoding_method == "modified_beam_search") {
+    fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
+  }

   fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
   float rtf = elapsed_seconds / duration;
@@ -117,6 +117,9 @@ for a list of pre-trained models to download.

   fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
   fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
+  if (config.decoding_method == "modified_beam_search") {
+    fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
+  }

   fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
   float rtf = elapsed_seconds / duration;
@@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx
   display.cc
   endpoint.cc
   features.cc
+  offline-lm-config.cc
   offline-model-config.cc
   offline-nemo-enc-dec-ctc-model-config.cc
   offline-paraformer-model-config.cc
+// sherpa-onnx/python/csrc/offline-lm-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/python/csrc/offline-lm-config.h"
+
+#include <string>
+
+#include "sherpa-onnx/csrc/offline-lm-config.h"
+
+namespace sherpa_onnx {
+
+void PybindOfflineLMConfig(py::module *m) {
+  using PyClass = OfflineLMConfig;
+  py::class_<PyClass>(*m, "OfflineLMConfig")
+      .def(py::init<const std::string &, float>(), py::arg("model"),
+           py::arg("scale"))
+      .def_readwrite("model", &PyClass::model)
+      .def_readwrite("scale", &PyClass::scale)
+      .def("__str__", &PyClass::ToString);
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/python/csrc/offline-lm-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindOfflineLMConfig(py::module *m);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_
@@ -15,12 +15,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
   using PyClass = OfflineRecognizerConfig;
   py::class_<PyClass>(*m, "OfflineRecognizerConfig")
       .def(py::init<const OfflineFeatureExtractorConfig &,
-                    const OfflineModelConfig &, const std::string &>(),
+                    const OfflineModelConfig &, const OfflineLMConfig &,
+                    const std::string &, int32_t>(),
           py::arg("feat_config"), py::arg("model_config"),
-           py::arg("decoding_method"))
+           py::arg("lm_config") = OfflineLMConfig(),
+           py::arg("decoding_method") = "greedy_search",
+           py::arg("max_active_paths") = 4)
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
+      .def_readwrite("lm_config", &PyClass::lm_config)
       .def_readwrite("decoding_method", &PyClass::decoding_method)
+      .def_readwrite("max_active_paths", &PyClass::max_active_paths)
       .def("__str__", &PyClass::ToString);
 }

26 31
@@ -7,6 +7,7 @@ @@ -7,6 +7,7 @@
7 #include "sherpa-onnx/python/csrc/display.h" 7 #include "sherpa-onnx/python/csrc/display.h"
8 #include "sherpa-onnx/python/csrc/endpoint.h" 8 #include "sherpa-onnx/python/csrc/endpoint.h"
9 #include "sherpa-onnx/python/csrc/features.h" 9 #include "sherpa-onnx/python/csrc/features.h"
  10 +#include "sherpa-onnx/python/csrc/offline-lm-config.h"
10 #include "sherpa-onnx/python/csrc/offline-model-config.h" 11 #include "sherpa-onnx/python/csrc/offline-model-config.h"
11 #include "sherpa-onnx/python/csrc/offline-recognizer.h" 12 #include "sherpa-onnx/python/csrc/offline-recognizer.h"
12 #include "sherpa-onnx/python/csrc/offline-stream.h" 13 #include "sherpa-onnx/python/csrc/offline-stream.h"
@@ -28,6 +29,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -28,6 +29,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
28 PybindDisplay(&m); 29 PybindDisplay(&m);
29 30
30 PybindOfflineStream(&m); 31 PybindOfflineStream(&m);
  32 + PybindOfflineLMConfig(&m);
31 PybindOfflineModelConfig(&m); 33 PybindOfflineModelConfig(&m);
32 PybindOfflineRecognizer(&m); 34 PybindOfflineRecognizer(&m);
33 } 35 }