PF Luo
Committed by GitHub

Add lm rescore to online-modified-beam-search (#133)

@@ -182,9 +182,10 @@ class MainActivity : AppCompatActivity() { @@ -182,9 +182,10 @@ class MainActivity : AppCompatActivity() {
182 val config = OnlineRecognizerConfig( 182 val config = OnlineRecognizerConfig(
183 featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), 183 featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
184 modelConfig = getModelConfig(type = type)!!, 184 modelConfig = getModelConfig(type = type)!!,
  185 + lmConfig = getOnlineLMConfig(type = type),
185 endpointConfig = getEndpointConfig(), 186 endpointConfig = getEndpointConfig(),
186 enableEndpoint = true, 187 enableEndpoint = true,
187 - decodingMethod = "greedy_search", 188 + decodingMethod = "modified_beam_search",
188 maxActivePaths = 4, 189 maxActivePaths = 4,
189 ) 190 )
190 191
@@ -23,6 +23,11 @@ data class OnlineTransducerModelConfig( @@ -23,6 +23,11 @@ data class OnlineTransducerModelConfig(
23 var debug: Boolean = false, 23 var debug: Boolean = false,
24 ) 24 )
25 25
  26 +data class OnlineLMConfig(
  27 + var model: String = "",
  28 + var scale: Float = 0.5f,
  29 +)
  30 +
26 data class FeatureConfig( 31 data class FeatureConfig(
27 var sampleRate: Int = 16000, 32 var sampleRate: Int = 16000,
28 var featureDim: Int = 80, 33 var featureDim: Int = 80,
@@ -31,6 +36,7 @@ data class FeatureConfig( @@ -31,6 +36,7 @@ data class FeatureConfig(
31 data class OnlineRecognizerConfig( 36 data class OnlineRecognizerConfig(
32 var featConfig: FeatureConfig = FeatureConfig(), 37 var featConfig: FeatureConfig = FeatureConfig(),
33 var modelConfig: OnlineTransducerModelConfig, 38 var modelConfig: OnlineTransducerModelConfig,
  39 + var lmConfig : OnlineLMConfig,
34 var endpointConfig: EndpointConfig = EndpointConfig(), 40 var endpointConfig: EndpointConfig = EndpointConfig(),
35 var enableEndpoint: Boolean = true, 41 var enableEndpoint: Boolean = true,
36 var decodingMethod: String = "greedy_search", 42 var decodingMethod: String = "greedy_search",
@@ -151,6 +157,32 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { @@ -151,6 +157,32 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
151 return null; 157 return null;
152 } 158 }
153 159
  160 +/*
  161 +Please see
  162 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  163 +for a list of pre-trained models.
  164 +
  165 +We only add a few here. Please change the following code
  166 +to add your own LM model. (It should be straightforward to train a new NN LM model
  167 +by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py)
  168 +
  169 +@param type
  170 +0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
  171 + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
  172 + */
  173 +fun getOnlineLMConfig(type : Int): OnlineLMConfig {
  174 + when (type) {
  175 + 0 -> {
  176 + val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
  177 + return OnlineLMConfig(
  178 + model = "$modelDir/with-state-epoch-99-avg-1.int8.onnx",
  179 + scale = 0.5f,
  180 + )
  181 + }
  182 + }
  183 + return OnlineLMConfig();
  184 +}
  185 +
154 fun getEndpointConfig(): EndpointConfig { 186 fun getEndpointConfig(): EndpointConfig {
155 return EndpointConfig( 187 return EndpointConfig(
156 rule1 = EndpointRule(false, 2.4f, 0.0f), 188 rule1 = EndpointRule(false, 2.4f, 0.0f),
@@ -22,8 +22,11 @@ fun main() { @@ -22,8 +22,11 @@ fun main() {
22 22
23 var endpointConfig = EndpointConfig() 23 var endpointConfig = EndpointConfig()
24 24
  25 + var lmConfig = OnlineLMConfig()
  26 +
25 var config = OnlineRecognizerConfig( 27 var config = OnlineRecognizerConfig(
26 modelConfig = modelConfig, 28 modelConfig = modelConfig,
  29 + lmConfig = lmConfig,
27 featConfig = featConfig, 30 featConfig = featConfig,
28 endpointConfig = endpointConfig, 31 endpointConfig = endpointConfig,
29 enableEndpoint = true, 32 enableEndpoint = true,
@@ -34,9 +34,11 @@ set(sources @@ -34,9 +34,11 @@ set(sources
34 offline-transducer-model-config.cc 34 offline-transducer-model-config.cc
35 offline-transducer-model.cc 35 offline-transducer-model.cc
36 offline-transducer-modified-beam-search-decoder.cc 36 offline-transducer-modified-beam-search-decoder.cc
  37 + online-lm.cc
37 online-lm-config.cc 38 online-lm-config.cc
38 online-lstm-transducer-model.cc 39 online-lstm-transducer-model.cc
39 online-recognizer.cc 40 online-recognizer.cc
  41 + online-rnn-lm.cc
40 online-stream.cc 42 online-stream.cc
41 online-transducer-decoder.cc 43 online-transducer-decoder.cc
42 online-transducer-greedy-search-decoder.cc 44 online-transducer-greedy-search-decoder.cc
1 /** 1 /**
2 * Copyright (c) 2023 Xiaomi Corporation 2 * Copyright (c) 2023 Xiaomi Corporation
3 - * 3 + * Copyright (c) 2023 Pingfeng Luo
4 */ 4 */
5 5
6 #include "sherpa-onnx/csrc/hypothesis.h" 6 #include "sherpa-onnx/csrc/hypothesis.h"
1 /** 1 /**
2 * Copyright (c) 2023 Xiaomi Corporation 2 * Copyright (c) 2023 Xiaomi Corporation
  3 + * Copyright (c) 2023 Pingfeng Luo
3 * 4 *
4 */ 5 */
5 6
@@ -12,7 +13,9 @@ @@ -12,7 +13,9 @@
12 #include <utility> 13 #include <utility>
13 #include <vector> 14 #include <vector>
14 15
  16 +#include "onnxruntime_cxx_api.h" // NOLINT
15 #include "sherpa-onnx/csrc/math.h" 17 #include "sherpa-onnx/csrc/math.h"
  18 +#include "sherpa-onnx/csrc/onnx-utils.h"
16 19
17 namespace sherpa_onnx { 20 namespace sherpa_onnx {
18 21
@@ -31,6 +34,13 @@ struct Hypothesis { @@ -31,6 +34,13 @@ struct Hypothesis {
31 // LM log prob if any. 34 // LM log prob if any.
32 double lm_log_prob = 0; 35 double lm_log_prob = 0;
33 36
  37 + int32_t cur_scored_pos = 0; // number of tokens already scored by the RNN LM
  38 + std::vector<CopyableOrtValue> nn_lm_states;
  39 +
  40 + // TODO(fangjun): Make it configurable
  41 + // the minimum number of tokens in a chunk for streaming RNN LM rescoring
  42 + int32_t lm_rescore_min_chunk = 2; // a const
  43 +
34 int32_t num_trailing_blanks = 0; 44 int32_t num_trailing_blanks = 0;
35 45
36 Hypothesis() = default; 46 Hypothesis() = default;
@@ -96,17 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) { @@ -96,17 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
96 } 96 }
97 } 97 }
98 98
99 -// TODO(fangjun): use std::partial_sort to replace std::sort.  
100 -// Remember also to fix sherpa-ncnn  
101 template <class T> 99 template <class T>
102 std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { 100 std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
103 std::vector<int32_t> vec_index(size); 101 std::vector<int32_t> vec_index(size);
104 std::iota(vec_index.begin(), vec_index.end(), 0); 102 std::iota(vec_index.begin(), vec_index.end(), 0);
105 103
106 - std::sort(vec_index.begin(), vec_index.end(),  
107 - [vec](int32_t index_1, int32_t index_2) {  
108 - return vec[index_1] > vec[index_2];  
109 - }); 104 + std::partial_sort(vec_index.begin(), vec_index.begin() + topk,
  105 + vec_index.end(), [vec](int32_t index_1, int32_t index_2) {
  106 + return vec[index_1] > vec[index_2];
  107 + });
110 108
111 int32_t k_num = std::min<int32_t>(size, topk); 109 int32_t k_num = std::min<int32_t>(size, topk);
112 std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num); 110 std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num);
@@ -15,7 +15,7 @@ struct OfflineLMConfig { @@ -15,7 +15,7 @@ struct OfflineLMConfig {
15 std::string model; 15 std::string model;
16 16
17 // LM scale 17 // LM scale
18 - float scale = 1.0; 18 + float scale = 0.5;
19 19
20 OfflineLMConfig() = default; 20 OfflineLMConfig() = default;
21 21
@@ -15,7 +15,7 @@ struct OnlineLMConfig { @@ -15,7 +15,7 @@ struct OnlineLMConfig {
15 std::string model; 15 std::string model;
16 16
17 // LM scale 17 // LM scale
18 - float scale = 1.0; 18 + float scale = 0.5;
19 19
20 OnlineLMConfig() = default; 20 OnlineLMConfig() = default;
21 21
  1 +// sherpa-onnx/csrc/online-lm.cc
  2 +//
  3 +// Copyright (c) 2023 Pingfeng Luo
  4 +// Copyright (c) 2023 Xiaomi Corporation
  5 +
  6 +#include "sherpa-onnx/csrc/online-lm.h"
  7 +
  8 +#include <algorithm>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "sherpa-onnx/csrc/online-rnn-lm.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +static std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values) {
  17 + std::vector<CopyableOrtValue> ans;
  18 + ans.reserve(values.size());
  19 +
  20 + for (auto &v : values) {
  21 + ans.emplace_back(std::move(v));
  22 + }
  23 +
  24 + return ans;
  25 +}
  26 +
  27 +static std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
  28 + std::vector<Ort::Value> ans;
  29 + ans.reserve(values.size());
  30 +
  31 + for (auto &v : values) {
  32 + ans.emplace_back(std::move(v.value));
  33 + }
  34 +
  35 + return ans;
  36 +}
  37 +
  38 +std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) {
  39 + return std::make_unique<OnlineRnnLM>(config);
  40 +}
  41 +
  42 +void OnlineLM::ComputeLMScore(float scale, int32_t context_size,
  43 + std::vector<Hypotheses> *hyps) {
  44 + Ort::AllocatorWithDefaultOptions allocator;
  45 +
  46 + for (auto &hyp : *hyps) {
  47 + for (auto &h_m : hyp) {
  48 + auto &h = h_m.second;
  49 + auto &ys = h.ys;
  50 + const int32_t token_num_in_chunk =
  51 + ys.size() - context_size - h.cur_scored_pos - 1;
  52 +
  53 + if (token_num_in_chunk < 1) {
  54 + continue;
  55 + }
  56 +
  57 + if (h.nn_lm_states.empty()) {
  58 + h.nn_lm_states = Convert(GetInitStates());
  59 + }
  60 +
  61 + if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
  62 + std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
  63 + // shape of x and y are same
  64 + Ort::Value x = Ort::Value::CreateTensor<int64_t>(
  65 + allocator, x_shape.data(), x_shape.size());
  66 + Ort::Value y = Ort::Value::CreateTensor<int64_t>(
  67 + allocator, x_shape.data(), x_shape.size());
  68 + int64_t *p_x = x.GetTensorMutableData<int64_t>();
  69 + int64_t *p_y = y.GetTensorMutableData<int64_t>();
  70 + std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
  71 + p_x);
  72 + std::copy(ys.begin() + context_size + h.cur_scored_pos + 1, ys.end(),
  73 + p_y);
  74 +
  75 + // streaming forward by NN LM
  76 + auto out = Rescore(std::move(x), std::move(y),
  77 + Convert(std::move(h.nn_lm_states)));
  78 +
  79 + // update NN LM score in hyp
  80 + const float *p_nll = out.first.GetTensorData<float>();
  81 + h.lm_log_prob = -scale * (*p_nll);
  82 +
  83 + // update NN LM states in hyp
  84 + h.nn_lm_states = Convert(std::move(out.second));
  85 +
  86 + h.cur_scored_pos += token_num_in_chunk;
  87 + }
  88 + }
  89 + }
  90 +}
  91 +
  92 +} // namespace sherpa_onnx
@@ -34,7 +34,7 @@ class OnlineLM { @@ -34,7 +34,7 @@ class OnlineLM {
34 * 34 *
35 * Caution: It returns negative log likelihood (nll), not log likelihood 35 * Caution: It returns negative log likelihood (nll), not log likelihood
36 */ 36 */
37 - std::pair<Ort::Value, std::vector<Ort::Value>> Ort::Value Rescore( 37 + virtual std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
38 Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0; 38 Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0;
39 39
40 // This function updates hyp.lm_log_prob of hyps. 40
@@ -44,19 +44,6 @@ class OnlineLM { @@ -44,19 +44,6 @@ class OnlineLM {
44 // @param hyps It is changed in-place. 44 // @param hyps It is changed in-place.
45 void ComputeLMScore(float scale, int32_t context_size, 45 void ComputeLMScore(float scale, int32_t context_size,
46 std::vector<Hypotheses> *hyps); 46 std::vector<Hypotheses> *hyps);
47 - /** TODO(fangjun):  
48 - *  
49 - * 1. Add two fields to Hypothesis  
50 - * (a) int32_t lm_cur_pos = 0; number of scored tokens so far  
51 - * (b) std::vector<Ort::Value> lm_states;  
52 - * 2. When we want to score a hypothesis, we construct x and y as follows:  
53 - *  
54 - * std::vector x = {hyp.ys.begin() + context_size + lm_cur_pos,  
55 - * hyp.ys.end() - 1};  
56 - * std::vector y = {hyp.ys.begin() + context_size + lm_cur_pos + 1  
57 - * hyp.ys.end()};  
58 - * hyp.lm_cur_pos += hyp.ys.size() - context_size - lm_cur_pos;  
59 - */  
60 }; 47 };
61 48
62 } // namespace sherpa_onnx 49 } // namespace sherpa_onnx
@@ -16,6 +16,8 @@ @@ -16,6 +16,8 @@
16 16
17 #include "nlohmann/json.hpp" 17 #include "nlohmann/json.hpp"
18 #include "sherpa-onnx/csrc/file-utils.h" 18 #include "sherpa-onnx/csrc/file-utils.h"
  19 +#include "sherpa-onnx/csrc/macros.h"
  20 +#include "sherpa-onnx/csrc/online-lm.h"
19 #include "sherpa-onnx/csrc/online-transducer-decoder.h" 21 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
20 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" 22 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
21 #include "sherpa-onnx/csrc/online-transducer-model.h" 23 #include "sherpa-onnx/csrc/online-transducer-model.h"
@@ -80,6 +82,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { @@ -80,6 +82,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
80 feat_config.Register(po); 82 feat_config.Register(po);
81 model_config.Register(po); 83 model_config.Register(po);
82 endpoint_config.Register(po); 84 endpoint_config.Register(po);
  85 + lm_config.Register(po);
83 86
84 po->Register("enable-endpoint", &enable_endpoint, 87 po->Register("enable-endpoint", &enable_endpoint,
85 "True to enable endpoint detection. False to disable it."); 88 "True to enable endpoint detection. False to disable it.");
@@ -91,6 +94,14 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { @@ -91,6 +94,14 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
91 } 94 }
92 95
93 bool OnlineRecognizerConfig::Validate() const { 96 bool OnlineRecognizerConfig::Validate() const {
  97 + if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) {
  98 + if (max_active_paths <= 0) {
  99 + SHERPA_ONNX_LOGE("max_active_paths is less than 1! Given: %d",
  100 + max_active_paths);
  101 + return false;
  102 + }
  103 + if (!lm_config.Validate()) return false;
  104 + }
94 return model_config.Validate(); 105 return model_config.Validate();
95 } 106 }
96 107
@@ -100,6 +111,7 @@ std::string OnlineRecognizerConfig::ToString() const { @@ -100,6 +111,7 @@ std::string OnlineRecognizerConfig::ToString() const {
100 os << "OnlineRecognizerConfig("; 111 os << "OnlineRecognizerConfig(";
101 os << "feat_config=" << feat_config.ToString() << ", "; 112 os << "feat_config=" << feat_config.ToString() << ", ";
102 os << "model_config=" << model_config.ToString() << ", "; 113 os << "model_config=" << model_config.ToString() << ", ";
  114 + os << "lm_config=" << lm_config.ToString() << ", ";
103 os << "endpoint_config=" << endpoint_config.ToString() << ", "; 115 os << "endpoint_config=" << endpoint_config.ToString() << ", ";
104 os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; 116 os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
105 os << "max_active_paths=" << max_active_paths << ", "; 117 os << "max_active_paths=" << max_active_paths << ", ";
@@ -116,8 +128,13 @@ class OnlineRecognizer::Impl { @@ -116,8 +128,13 @@ class OnlineRecognizer::Impl {
116 sym_(config.model_config.tokens), 128 sym_(config.model_config.tokens),
117 endpoint_(config_.endpoint_config) { 129 endpoint_(config_.endpoint_config) {
118 if (config.decoding_method == "modified_beam_search") { 130 if (config.decoding_method == "modified_beam_search") {
  131 + if (!config_.lm_config.model.empty()) {
  132 + lm_ = OnlineLM::Create(config.lm_config);
  133 + }
  134 +
119 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 135 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
120 - model_.get(), config_.max_active_paths); 136 + model_.get(), lm_.get(), config_.max_active_paths,
  137 + config_.lm_config.scale);
121 } else if (config.decoding_method == "greedy_search") { 138 } else if (config.decoding_method == "greedy_search") {
122 decoder_ = 139 decoder_ =
123 std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); 140 std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
@@ -136,7 +153,8 @@ class OnlineRecognizer::Impl { @@ -136,7 +153,8 @@ class OnlineRecognizer::Impl {
136 endpoint_(config_.endpoint_config) { 153 endpoint_(config_.endpoint_config) {
137 if (config.decoding_method == "modified_beam_search") { 154 if (config.decoding_method == "modified_beam_search") {
138 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 155 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
139 - model_.get(), config_.max_active_paths); 156 + model_.get(), lm_.get(), config_.max_active_paths,
  157 + config_.lm_config.scale);
140 } else if (config.decoding_method == "greedy_search") { 158 } else if (config.decoding_method == "greedy_search") {
141 decoder_ = 159 decoder_ =
142 std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); 160 std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
@@ -246,6 +264,7 @@ class OnlineRecognizer::Impl { @@ -246,6 +264,7 @@ class OnlineRecognizer::Impl {
246 private: 264 private:
247 OnlineRecognizerConfig config_; 265 OnlineRecognizerConfig config_;
248 std::unique_ptr<OnlineTransducerModel> model_; 266 std::unique_ptr<OnlineTransducerModel> model_;
  267 + std::unique_ptr<OnlineLM> lm_;
249 std::unique_ptr<OnlineTransducerDecoder> decoder_; 268 std::unique_ptr<OnlineTransducerDecoder> decoder_;
250 SymbolTable sym_; 269 SymbolTable sym_;
251 Endpoint endpoint_; 270 Endpoint endpoint_;
@@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
16 16
17 #include "sherpa-onnx/csrc/endpoint.h" 17 #include "sherpa-onnx/csrc/endpoint.h"
18 #include "sherpa-onnx/csrc/features.h" 18 #include "sherpa-onnx/csrc/features.h"
  19 +#include "sherpa-onnx/csrc/online-lm-config.h"
19 #include "sherpa-onnx/csrc/online-stream.h" 20 #include "sherpa-onnx/csrc/online-stream.h"
20 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 21 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
21 #include "sherpa-onnx/csrc/parse-options.h" 22 #include "sherpa-onnx/csrc/parse-options.h"
@@ -67,10 +68,11 @@ struct OnlineRecognizerResult { @@ -67,10 +68,11 @@ struct OnlineRecognizerResult {
67 struct OnlineRecognizerConfig { 68 struct OnlineRecognizerConfig {
68 FeatureExtractorConfig feat_config; 69 FeatureExtractorConfig feat_config;
69 OnlineTransducerModelConfig model_config; 70 OnlineTransducerModelConfig model_config;
  71 + OnlineLMConfig lm_config;
70 EndpointConfig endpoint_config; 72 EndpointConfig endpoint_config;
71 bool enable_endpoint = true; 73 bool enable_endpoint = true;
72 74
73 - std::string decoding_method = "greedy_search"; 75 + std::string decoding_method = "modified_beam_search";
74 // now support modified_beam_search and greedy_search 76 // now support modified_beam_search and greedy_search
75 77
76 int32_t max_active_paths = 4; // used only for modified_beam_search 78 int32_t max_active_paths = 4; // used only for modified_beam_search
@@ -79,6 +81,7 @@ struct OnlineRecognizerConfig { @@ -79,6 +81,7 @@ struct OnlineRecognizerConfig {
79 81
80 OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, 82 OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
81 const OnlineTransducerModelConfig &model_config, 83 const OnlineTransducerModelConfig &model_config,
  84 + const OnlineLMConfig &lm_config,
82 const EndpointConfig &endpoint_config, 85 const EndpointConfig &endpoint_config,
83 bool enable_endpoint, 86 bool enable_endpoint,
84 const std::string &decoding_method, 87 const std::string &decoding_method,
  1 +// sherpa-onnx/csrc/online-rnn-lm.cc
  2 +//
  3 +// Copyright (c) 2023 Pingfeng Luo
  4 +// Copyright (c) 2023 Xiaomi Corporation
  5 +
  6 +#include "sherpa-onnx/csrc/online-rnn-lm.h"
  7 +
  8 +#include <string>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "onnxruntime_cxx_api.h" // NOLINT
  13 +#include "sherpa-onnx/csrc/macros.h"
  14 +#include "sherpa-onnx/csrc/onnx-utils.h"
  15 +#include "sherpa-onnx/csrc/text-utils.h"
  16 +
  17 +namespace sherpa_onnx {
  18 +
  19 +class OnlineRnnLM::Impl {
  20 + public:
  21 + explicit Impl(const OnlineLMConfig &config)
  22 + : config_(config),
  23 + env_(ORT_LOGGING_LEVEL_ERROR),
  24 + sess_opts_{},
  25 + allocator_{} {
  26 + Init(config);
  27 + }
  28 +
  29 + std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
  30 + Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) {
  31 + std::array<Ort::Value, 4> inputs = {
  32 + std::move(x), std::move(y), std::move(states[0]), std::move(states[1])};
  33 +
  34 + auto out =
  35 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  36 + output_names_ptr_.data(), output_names_ptr_.size());
  37 +
  38 + std::vector<Ort::Value> next_states;
  39 + next_states.reserve(2);
  40 + next_states.push_back(std::move(out[1]));
  41 + next_states.push_back(std::move(out[2]));
  42 +
  43 + return {std::move(out[0]), std::move(next_states)};
  44 + }
  45 +
  46 + std::vector<Ort::Value> GetInitStates() const {
  47 + std::vector<Ort::Value> ans;
  48 + ans.reserve(init_states_.size());
  49 +
  50 + for (const auto &s : init_states_) {
  51 + ans.emplace_back(Clone(allocator_, &s));
  52 + }
  53 +
  54 + return ans;
  55 + }
  56 +
  57 + private:
  58 + void Init(const OnlineLMConfig &config) {
  59 + auto buf = ReadFile(config_.model);
  60 +
  61 + sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
  62 + sess_opts_);
  63 +
  64 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  65 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  66 +
  67 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  68 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  69 + SHERPA_ONNX_READ_META_DATA(rnn_num_layers_, "num_layers");
  70 + SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "hidden_size");
  71 + SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id");
  72 +
  73 + ComputeInitStates();
  74 + }
  75 +
  76 + void ComputeInitStates() {
  77 + constexpr int32_t kBatchSize = 1;
  78 + std::array<int64_t, 3> h_shape{rnn_num_layers_, kBatchSize,
  79 + rnn_hidden_size_};
  80 + std::array<int64_t, 3> c_shape{rnn_num_layers_, kBatchSize,
  81 + rnn_hidden_size_};
  82 + Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
  83 + h_shape.size());
  84 + Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
  85 + c_shape.size());
  86 + Fill<float>(&h, 0);
  87 + Fill<float>(&c, 0);
  88 + std::array<int64_t, 2> x_shape{1, 1};
  89 + // shape of x and y are same
  90 + Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
  91 + x_shape.size());
  92 + Ort::Value y = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
  93 + x_shape.size());
  94 + *x.GetTensorMutableData<int64_t>() = sos_id_;
  95 + *y.GetTensorMutableData<int64_t>() = sos_id_;
  96 +
  97 + std::vector<Ort::Value> states;
  98 + states.push_back(std::move(h));
  99 + states.push_back(std::move(c));
  100 + auto pair = Rescore(std::move(x), std::move(y), std::move(states));
  101 +
  102 + init_states_ = std::move(pair.second);
  103 + }
  104 +
  105 + private:
  106 + OnlineLMConfig config_;
  107 + Ort::Env env_;
  108 + Ort::SessionOptions sess_opts_;
  109 + Ort::AllocatorWithDefaultOptions allocator_;
  110 +
  111 + std::unique_ptr<Ort::Session> sess_;
  112 +
  113 + std::vector<std::string> input_names_;
  114 + std::vector<const char *> input_names_ptr_;
  115 +
  116 + std::vector<std::string> output_names_;
  117 + std::vector<const char *> output_names_ptr_;
  118 +
  119 + std::vector<Ort::Value> init_states_;
  120 +
  121 + int32_t rnn_num_layers_ = 2;
  122 + int32_t rnn_hidden_size_ = 512;
  123 + int32_t sos_id_ = 1;
  124 +};
  125 +
  126 +OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
  127 + : impl_(std::make_unique<Impl>(config)) {}
  128 +
  129 +OnlineRnnLM::~OnlineRnnLM() = default;
  130 +
  131 +std::vector<Ort::Value> OnlineRnnLM::GetInitStates() {
  132 + return impl_->GetInitStates();
  133 +}
  134 +
  135 +std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::Rescore(
  136 + Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) {
  137 + return impl_->Rescore(std::move(x), std::move(y), std::move(states));
  138 +}
  139 +
  140 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-rnn-lm.h
  2 +//
  3 +// Copyright (c) 2023 Pingfeng Luo
  4 +// Copyright (c) 2023 Xiaomi Corporation
  5 +
  6 +#ifndef SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_
  7 +#define SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_
  8 +
  9 +#include <memory>
  10 +#include <utility>
  11 +#include <vector>
  12 +
  13 +#include "onnxruntime_cxx_api.h" // NOLINT
  14 +#include "sherpa-onnx/csrc/online-lm-config.h"
  15 +#include "sherpa-onnx/csrc/online-lm.h"
  16 +
  17 +namespace sherpa_onnx {
  18 +
  19 +class OnlineRnnLM : public OnlineLM {
  20 + public:
  21 + ~OnlineRnnLM() override;
  22 +
  23 + explicit OnlineRnnLM(const OnlineLMConfig &config);
  24 +
  25 + std::vector<Ort::Value> GetInitStates() override;
  26 +
  27 + /** Rescore a batch of sentences.
  28 + *
  29 + * @param x A 2-D tensor of shape (N, L) with data type int64.
  30 + * @param y A 2-D tensor of shape (N, L) with data type int64.
  31 + * @param states It contains the states for the LM model
  32 + * @return Return a pair containing
  33 + * - negative log-likelihood (nll)
  34 + * - updated states
  35 + *
  36 + * Caution: It returns negative log likelihood (nll), not log likelihood
  37 + */
  38 + std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
  39 + Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) override;
  40 +
  41 + private:
  42 + class Impl;
  43 + std::unique_ptr<Impl> impl_;
  44 +};
  45 +
  46 +} // namespace sherpa_onnx
  47 +
  48 +#endif // SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_
@@ -156,6 +156,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -156,6 +156,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
156 } // for (int32_t b = 0; b != batch_size; ++b) 156 } // for (int32_t b = 0; b != batch_size; ++b)
157 } 157 }
158 158
  159 + if (lm_) {
  160 + lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
  161 + }
  162 +
159 for (int32_t b = 0; b != batch_size; ++b) { 163 for (int32_t b = 0; b != batch_size; ++b) {
160 auto &hyps = cur[b]; 164 auto &hyps = cur[b];
161 auto best_hyp = hyps.GetMostProbable(true); 165 auto best_hyp = hyps.GetMostProbable(true);
@@ -8,6 +8,7 @@ @@ -8,6 +8,7 @@
8 8
9 #include <vector> 9 #include <vector>
10 10
  11 +#include "sherpa-onnx/csrc/online-lm.h"
11 #include "sherpa-onnx/csrc/online-transducer-decoder.h" 12 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
12 #include "sherpa-onnx/csrc/online-transducer-model.h" 13 #include "sherpa-onnx/csrc/online-transducer-model.h"
13 14
@@ -17,8 +18,13 @@ class OnlineTransducerModifiedBeamSearchDecoder @@ -17,8 +18,13 @@ class OnlineTransducerModifiedBeamSearchDecoder
17 : public OnlineTransducerDecoder { 18 : public OnlineTransducerDecoder {
18 public: 19 public:
19 OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, 20 OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
20 - int32_t max_active_paths)  
21 - : model_(model), max_active_paths_(max_active_paths) {} 21 + OnlineLM *lm,
  22 + int32_t max_active_paths,
  23 + float lm_scale)
  24 + : model_(model),
  25 + lm_(lm),
  26 + max_active_paths_(max_active_paths),
  27 + lm_scale_(lm_scale) {}
22 28
23 OnlineTransducerDecoderResult GetEmptyResult() const override; 29 OnlineTransducerDecoderResult GetEmptyResult() const override;
24 30
@@ -31,7 +37,10 @@ class OnlineTransducerModifiedBeamSearchDecoder @@ -31,7 +37,10 @@ class OnlineTransducerModifiedBeamSearchDecoder
31 37
32 private: 38 private:
33 OnlineTransducerModel *model_; // Not owned 39 OnlineTransducerModel *model_; // Not owned
  40 + OnlineLM *lm_; // Not owned
  41 +
34 int32_t max_active_paths_; 42 int32_t max_active_paths_;
  43 + float lm_scale_; // used only when lm_ is not nullptr
35 }; 44 };
36 45
37 } // namespace sherpa_onnx 46 } // namespace sherpa_onnx
1 // sherpa-onnx/csrc/onnx-utils.cc 1 // sherpa-onnx/csrc/onnx-utils.cc
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
  4 +// Copyright (c) 2023 Pingfeng Luo
4 #include "sherpa-onnx/csrc/onnx-utils.h" 5 #include "sherpa-onnx/csrc/onnx-utils.h"
5 6
6 #include <algorithm> 7 #include <algorithm>
7 #include <fstream> 8 #include <fstream>
8 #include <sstream> 9 #include <sstream>
9 #include <string> 10 #include <string>
10 -#include <vector>  
11 11
12 #if __ANDROID_API__ >= 9 12 #if __ANDROID_API__ >= 9
13 #include "android/asset_manager.h" 13 #include "android/asset_manager.h"
@@ -218,4 +218,31 @@ Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, @@ -218,4 +218,31 @@ Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
218 return ans; 218 return ans;
219 } 219 }
220 220
  221 +CopyableOrtValue::CopyableOrtValue(const CopyableOrtValue &other) {
  222 + *this = other;
  223 +}
  224 +
  225 +CopyableOrtValue &CopyableOrtValue::operator=(const CopyableOrtValue &other) {
  226 + if (this == &other) {
  227 + return *this;
  228 + }
  229 + if (other.value) {
  230 + Ort::AllocatorWithDefaultOptions allocator;
  231 + value = Clone(allocator, &other.value);
  232 + }
  233 + return *this;
  234 +}
  235 +
  236 +CopyableOrtValue::CopyableOrtValue(CopyableOrtValue &&other) {
  237 + *this = std::move(other);
  238 +}
  239 +
  240 +CopyableOrtValue &CopyableOrtValue::operator=(CopyableOrtValue &&other) {
  241 + if (this == &other) {
  242 + return *this;
  243 + }
  244 + value = std::move(other.value);
  245 + return *this;
  246 +}
  247 +
221 } // namespace sherpa_onnx 248 } // namespace sherpa_onnx
1 // sherpa-onnx/csrc/onnx-utils.h 1 // sherpa-onnx/csrc/onnx-utils.h
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
  4 +// Copyright (c) 2023 Pingfeng Luo
4 #ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_ 5 #ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_
5 #define SHERPA_ONNX_CSRC_ONNX_UTILS_H_ 6 #define SHERPA_ONNX_CSRC_ONNX_UTILS_H_
6 7
@@ -13,6 +14,7 @@ @@ -13,6 +14,7 @@
13 #include <cassert> 14 #include <cassert>
14 #include <ostream> 15 #include <ostream>
15 #include <string> 16 #include <string>
  17 +#include <utility>
16 #include <vector> 18 #include <vector>
17 19
18 #if __ANDROID_API__ >= 9 20 #if __ANDROID_API__ >= 9
@@ -89,6 +91,24 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename); @@ -89,6 +91,24 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
89 // TODO(fangjun): Document it 91 // TODO(fangjun): Document it
90 Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, 92 Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
91 const std::vector<int32_t> &hyps_num_split); 93 const std::vector<int32_t> &hyps_num_split);
  94 +
  95 +struct CopyableOrtValue {
  96 + Ort::Value value{nullptr};
  97 +
  98 + CopyableOrtValue() = default;
  99 +
  100 + /*explicit*/ CopyableOrtValue(Ort::Value v) // NOLINT
  101 + : value(std::move(v)) {}
  102 +
  103 + CopyableOrtValue(const CopyableOrtValue &other);
  104 +
  105 + CopyableOrtValue &operator=(const CopyableOrtValue &other);
  106 +
  107 + CopyableOrtValue(CopyableOrtValue &&other);
  108 +
  109 + CopyableOrtValue &operator=(CopyableOrtValue &&other);
  110 +};
  111 +
92 } // namespace sherpa_onnx 112 } // namespace sherpa_onnx
93 113
94 #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ 114 #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
@@ -13,8 +13,9 @@ @@ -13,8 +13,9 @@
13 #include "sherpa-onnx/csrc/symbol-table.h" 13 #include "sherpa-onnx/csrc/symbol-table.h"
14 #include "sherpa-onnx/csrc/wave-reader.h" 14 #include "sherpa-onnx/csrc/wave-reader.h"
15 15
  16 +// TODO(fangjun): Use ParseOptions as we are getting more args
16 int main(int32_t argc, char *argv[]) { 17 int main(int32_t argc, char *argv[]) {
17 - if (argc < 6 || argc > 8) { 18 + if (argc < 6 || argc > 9) {
18 const char *usage = R"usage( 19 const char *usage = R"usage(
19 Usage: 20 Usage:
20 ./bin/sherpa-onnx \ 21 ./bin/sherpa-onnx \
@@ -22,7 +23,7 @@ Usage: @@ -22,7 +23,7 @@ Usage:
22 /path/to/encoder.onnx \ 23 /path/to/encoder.onnx \
23 /path/to/decoder.onnx \ 24 /path/to/decoder.onnx \
24 /path/to/joiner.onnx \ 25 /path/to/joiner.onnx \
25 - /path/to/foo.wav [num_threads [decoding_method]] 26 + /path/to/foo.wav [num_threads [decoding_method [/path/to/rnn_lm.onnx]]]
26 27
27 Default value for num_threads is 2. 28 Default value for num_threads is 2.
28 Valid values for decoding_method: greedy_search (default), modified_beam_search. 29 Valid values for decoding_method: greedy_search (default), modified_beam_search.
@@ -53,10 +54,12 @@ for a list of pre-trained models to download. @@ -53,10 +54,12 @@ for a list of pre-trained models to download.
53 if (argc == 7 && atoi(argv[6]) > 0) { 54 if (argc == 7 && atoi(argv[6]) > 0) {
54 config.model_config.num_threads = atoi(argv[6]); 55 config.model_config.num_threads = atoi(argv[6]);
55 } 56 }
56 -  
57 if (argc == 8) { 57 if (argc == 8) {
58 config.decoding_method = argv[7]; 58 config.decoding_method = argv[7];
59 } 59 }
  60 + if (argc == 9) {
  61 + config.lm_config.model = argv[8];
  62 + }
60 config.max_active_paths = 4; 63 config.max_active_paths = 4;
61 64
62 fprintf(stderr, "%s\n", config.ToString().c_str()); 65 fprintf(stderr, "%s\n", config.ToString().c_str());
@@ -16,9 +16,8 @@ @@ -16,9 +16,8 @@
16 #if __ANDROID_API__ >= 9 16 #if __ANDROID_API__ >= 9
17 #include "android/asset_manager.h" 17 #include "android/asset_manager.h"
18 #include "android/asset_manager_jni.h" 18 #include "android/asset_manager_jni.h"
19 -#else  
20 -#include <fstream>  
21 #endif 19 #endif
  20 +#include <fstream>
22 21
23 #include "sherpa-onnx/csrc/macros.h" 22 #include "sherpa-onnx/csrc/macros.h"
24 #include "sherpa-onnx/csrc/online-recognizer.h" 23 #include "sherpa-onnx/csrc/online-recognizer.h"
@@ -188,6 +187,21 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { @@ -188,6 +187,21 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
188 fid = env->GetFieldID(model_config_cls, "debug", "Z"); 187 fid = env->GetFieldID(model_config_cls, "debug", "Z");
189 ans.model_config.debug = env->GetBooleanField(model_config, fid); 188 ans.model_config.debug = env->GetBooleanField(model_config, fid);
190 189
  190 + //---------- rnn lm model config ----------
  191 + fid = env->GetFieldID(cls, "lmConfig",
  192 + "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
  193 + jobject lm_model_config = env->GetObjectField(config, fid);
  194 + jclass lm_model_config_cls = env->GetObjectClass(lm_model_config);
  195 +
  196 + fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;");
  197 + s = (jstring)env->GetObjectField(lm_model_config, fid);
  198 + p = env->GetStringUTFChars(s, nullptr);
  199 + ans.lm_config.model = p;
  200 + env->ReleaseStringUTFChars(s, p);
  201 +
  202 + fid = env->GetFieldID(lm_model_config_cls, "scale", "F");
  203 + ans.lm_config.scale = env->GetFloatField(lm_model_config, fid);
  204 +
191 return ans; 205 return ans;
192 } 206 }
193 207
@@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx @@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx
11 offline-recognizer.cc 11 offline-recognizer.cc
12 offline-stream.cc 12 offline-stream.cc
13 offline-transducer-model-config.cc 13 offline-transducer-model-config.cc
  14 + online-lm-config.cc
14 online-recognizer.cc 15 online-recognizer.cc
15 online-stream.cc 16 online-stream.cc
16 online-transducer-model-config.cc 17 online-transducer-model-config.cc
  1 +// sherpa-onnx/python/csrc/online-lm-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/online-lm-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/online-lm-config.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void PybindOnlineLMConfig(py::module *m) {
  14 + using PyClass = OnlineLMConfig;
  15 + py::class_<PyClass>(*m, "OnlineLMConfig")
  16 + .def(py::init<const std::string &, float>(), py::arg("model") = "",
  17 + py::arg("scale") = 0.5f)
  18 + .def_readwrite("model", &PyClass::model)
  19 + .def_readwrite("scale", &PyClass::scale)
  20 + .def("__str__", &PyClass::ToString);
  21 +}
  22 +
  23 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/online-lm-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOnlineLMConfig(py::module *m);
  13 +
  14 +}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_
@@ -21,11 +21,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -21,11 +21,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
21 using PyClass = OnlineRecognizerConfig; 21 using PyClass = OnlineRecognizerConfig;
22 py::class_<PyClass>(*m, "OnlineRecognizerConfig") 22 py::class_<PyClass>(*m, "OnlineRecognizerConfig")
23 .def(py::init<const FeatureExtractorConfig &, 23 .def(py::init<const FeatureExtractorConfig &,
24 - const OnlineTransducerModelConfig &, const EndpointConfig &,  
25 - bool, const std::string &, int32_t>(), 24 + const OnlineTransducerModelConfig &, const OnlineLMConfig &,
  25 + const EndpointConfig &, bool, const std::string &,
  26 + int32_t>(),
26 py::arg("feat_config"), py::arg("model_config"), 27 py::arg("feat_config"), py::arg("model_config"),
27 - py::arg("endpoint_config"), py::arg("enable_endpoint"),  
28 - py::arg("decoding_method"), py::arg("max_active_paths")) 28 + py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
  29 + py::arg("enable_endpoint"), py::arg("decoding_method"),
  30 + py::arg("max_active_paths"))
29 .def_readwrite("feat_config", &PyClass::feat_config) 31 .def_readwrite("feat_config", &PyClass::feat_config)
30 .def_readwrite("model_config", &PyClass::model_config) 32 .def_readwrite("model_config", &PyClass::model_config)
31 .def_readwrite("endpoint_config", &PyClass::endpoint_config) 33 .def_readwrite("endpoint_config", &PyClass::endpoint_config)
@@ -11,6 +11,7 @@ @@ -11,6 +11,7 @@
11 #include "sherpa-onnx/python/csrc/offline-model-config.h" 11 #include "sherpa-onnx/python/csrc/offline-model-config.h"
12 #include "sherpa-onnx/python/csrc/offline-recognizer.h" 12 #include "sherpa-onnx/python/csrc/offline-recognizer.h"
13 #include "sherpa-onnx/python/csrc/offline-stream.h" 13 #include "sherpa-onnx/python/csrc/offline-stream.h"
  14 +#include "sherpa-onnx/python/csrc/online-lm-config.h"
14 #include "sherpa-onnx/python/csrc/online-recognizer.h" 15 #include "sherpa-onnx/python/csrc/online-recognizer.h"
15 #include "sherpa-onnx/python/csrc/online-stream.h" 16 #include "sherpa-onnx/python/csrc/online-stream.h"
16 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" 17 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
@@ -22,6 +23,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -22,6 +23,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
22 23
23 PybindFeatures(&m); 24 PybindFeatures(&m);
24 PybindOnlineTransducerModelConfig(&m); 25 PybindOnlineTransducerModelConfig(&m);
  26 + PybindOnlineLMConfig(&m);
25 PybindOnlineStream(&m); 27 PybindOnlineStream(&m);
26 PybindEndpoint(&m); 28 PybindEndpoint(&m);
27 PybindOnlineRecognizer(&m); 29 PybindOnlineRecognizer(&m);