Peng He
Committed by GitHub

Add lm decode for the Python API. (#353)

* Add lm decode for the Python API.

* fix style.

* Fix LogAdd,

	Shouldn't double lm_log_prob when merge same prefix path

* sort the import alphabetically
@@ -116,6 +116,24 @@ def get_args(): @@ -116,6 +116,24 @@ def get_args():
116 ) 116 )
117 117
118 parser.add_argument( 118 parser.add_argument(
  119 + "--lm",
  120 + type=str,
  121 + default="",
  122 + help="""Used only when --decoding-method is modified_beam_search.
  123 + path of language model.
  124 + """,
  125 + )
  126 +
  127 + parser.add_argument(
  128 + "--lm-scale",
  129 + type=float,
  130 + default=0.1,
  131 + help="""Used only when --decoding-method is modified_beam_search.
  132 + scale of language model.
  133 + """,
  134 + )
  135 +
  136 + parser.add_argument(
119 "--provider", 137 "--provider",
120 type=str, 138 type=str,
121 default="cpu", 139 default="cpu",
@@ -215,6 +233,8 @@ def main(): @@ -215,6 +233,8 @@ def main():
215 feature_dim=80, 233 feature_dim=80,
216 decoding_method=args.decoding_method, 234 decoding_method=args.decoding_method,
217 max_active_paths=args.max_active_paths, 235 max_active_paths=args.max_active_paths,
  236 + lm=args.lm,
  237 + lm_scale=args.lm_scale,
218 hotwords_file=args.hotwords_file, 238 hotwords_file=args.hotwords_file,
219 hotwords_score=args.hotwords_score, 239 hotwords_score=args.hotwords_score,
220 ) 240 )
@@ -17,11 +17,6 @@ void Hypotheses::Add(Hypothesis hyp) { @@ -17,11 +17,6 @@ void Hypotheses::Add(Hypothesis hyp) {
17 hyps_dict_[key] = std::move(hyp); 17 hyps_dict_[key] = std::move(hyp);
18 } else { 18 } else {
19 it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob); 19 it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
20 -  
21 - if (it->second.lm_log_prob != 0 && hyp.lm_log_prob != 0) {  
22 - it->second.lm_log_prob =  
23 - LogAdd<double>()(it->second.lm_log_prob, hyp.lm_log_prob);  
24 - }  
25 } 20 }
26 } 21 }
27 22
@@ -37,6 +37,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -37,6 +37,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
37 py::arg("hotwords_score") = 0) 37 py::arg("hotwords_score") = 0)
38 .def_readwrite("feat_config", &PyClass::feat_config) 38 .def_readwrite("feat_config", &PyClass::feat_config)
39 .def_readwrite("model_config", &PyClass::model_config) 39 .def_readwrite("model_config", &PyClass::model_config)
  40 + .def_readwrite("lm_config", &PyClass::lm_config)
40 .def_readwrite("endpoint_config", &PyClass::endpoint_config) 41 .def_readwrite("endpoint_config", &PyClass::endpoint_config)
41 .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) 42 .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
42 .def_readwrite("decoding_method", &PyClass::decoding_method) 43 .def_readwrite("decoding_method", &PyClass::decoding_method)
@@ -5,6 +5,7 @@ from typing import List, Optional @@ -5,6 +5,7 @@ from typing import List, Optional
5 from _sherpa_onnx import ( 5 from _sherpa_onnx import (
6 EndpointConfig, 6 EndpointConfig,
7 FeatureExtractorConfig, 7 FeatureExtractorConfig,
  8 + OnlineLMConfig,
8 OnlineModelConfig, 9 OnlineModelConfig,
9 OnlineParaformerModelConfig, 10 OnlineParaformerModelConfig,
10 OnlineRecognizer as _Recognizer, 11 OnlineRecognizer as _Recognizer,
@@ -46,6 +47,8 @@ class OnlineRecognizer(object): @@ -46,6 +47,8 @@ class OnlineRecognizer(object):
46 hotwords_file: str = "", 47 hotwords_file: str = "",
47 provider: str = "cpu", 48 provider: str = "cpu",
48 model_type: str = "", 49 model_type: str = "",
  50 + lm: str = "",
  51 + lm_scale: float = 0.1,
49 ): 52 ):
50 """ 53 """
51 Please refer to 54 Please refer to
@@ -137,10 +140,22 @@ class OnlineRecognizer(object): @@ -137,10 +140,22 @@ class OnlineRecognizer(object):
137 "Please use --decoding-method=modified_beam_search when using " 140 "Please use --decoding-method=modified_beam_search when using "
138 f"--hotwords-file. Currently given: {decoding_method}" 141 f"--hotwords-file. Currently given: {decoding_method}"
139 ) 142 )
  143 +
  144 + if lm and decoding_method != "modified_beam_search":
  145 + raise ValueError(
  146 + "Please use --decoding-method=modified_beam_search when using "
  147 + f"--lm. Currently given: {decoding_method}"
  148 + )
  149 +
  150 + lm_config = OnlineLMConfig(
  151 + model=lm,
  152 + scale=lm_scale,
  153 + )
140 154
141 recognizer_config = OnlineRecognizerConfig( 155 recognizer_config = OnlineRecognizerConfig(
142 feat_config=feat_config, 156 feat_config=feat_config,
143 model_config=model_config, 157 model_config=model_config,
  158 + lm_config=lm_config,
144 endpoint_config=endpoint_config, 159 endpoint_config=endpoint_config,
145 enable_endpoint=enable_endpoint_detection, 160 enable_endpoint=enable_endpoint_detection,
146 decoding_method=decoding_method, 161 decoding_method=decoding_method,