Add lm decode for the Python API. (#353)
* Add lm decode for the Python API.
* Fix style.
* Fix LogAdd: the lm_log_prob should not be doubled when merging paths with the same prefix.
* Sort the imports alphabetically.
Showing 4 changed files with 36 additions and 5 deletions.
Python example script: add the --lm and --lm-scale command-line options and forward them to the recognizer.

```diff
@@ -116,6 +116,24 @@ def get_args():
     )
 
     parser.add_argument(
+        "--lm",
+        type=str,
+        default="",
+        help="""Used only when --decoding-method is modified_beam_search.
+        path of language model.
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.1,
+        help="""Used only when --decoding-method is modified_beam_search.
+        scale of language model.
+        """,
+    )
+
+    parser.add_argument(
         "--provider",
         type=str,
         default="cpu",
@@ -215,6 +233,8 @@ def main():
         feature_dim=80,
         decoding_method=args.decoding_method,
         max_active_paths=args.max_active_paths,
+        lm=args.lm,
+        lm_scale=args.lm_scale,
         hotwords_file=args.hotwords_file,
         hotwords_score=args.hotwords_score,
     )
```
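The new --lm-scale (default 0.1) weights the LM scores against the transducer scores during modified_beam_search. As a hedged sketch of the usual shallow-fusion combination (the diff does not show where sherpa-onnx applies the scale internally, so treat this as the standard formulation rather than the PR's exact code):

```python
def fused_score(am_log_prob: float, lm_log_prob: float, lm_scale: float = 0.1) -> float:
    """Typical shallow fusion: down-weight the LM log-probability by lm_scale
    and add it to the acoustic/transducer log-probability.  Illustrative only;
    not code taken from this PR."""
    return am_log_prob + lm_scale * lm_log_prob


# With the default scale of 0.1 the acoustics still dominate: an LM preference
# of -4.0 only shifts the beam-search score by -0.4.
print(fused_score(-12.3, -4.0))  # -12.7
```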
Hypotheses::Add (C++): when merging hypotheses that share the same prefix, stop LogAdd-ing their LM log-probabilities.

```diff
@@ -17,11 +17,6 @@ void Hypotheses::Add(Hypothesis hyp) {
     hyps_dict_[key] = std::move(hyp);
   } else {
     it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
-
-    if (it->second.lm_log_prob != 0 && hyp.lm_log_prob != 0) {
-      it->second.lm_log_prob =
-          LogAdd<double>()(it->second.lm_log_prob, hyp.lm_log_prob);
-    }
   }
 }
 
```
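The removed block is the "Fix LogAdd" part of the commit message: when two hypotheses with the same token prefix are merged, their acoustic log-probs are combined with LogAdd because they come from genuinely different alignments, but both carry the LM score of the same token sequence, so LogAdd-ing the LM terms as well would count the LM contribution twice. A small Python sketch of the corrected merge (names and types are illustrative, not the C++ API):

```python
import math
from dataclasses import dataclass
from typing import Dict, Tuple


def log_add(a: float, b: float) -> float:
    """Stable log(exp(a) + exp(b)), mirroring LogAdd<double>() in the C++ code."""
    if a < b:
        a, b = b, a
    return a + math.log1p(math.exp(b - a))


@dataclass
class Hyp:
    ys: Tuple[int, ...]       # decoded token sequence (the prefix used as the key)
    log_prob: float           # transducer/acoustic log-probability of this path
    lm_log_prob: float = 0.0  # LM log-probability of ys


def add(hyps: Dict[Tuple[int, ...], Hyp], hyp: Hyp) -> None:
    existing = hyps.get(hyp.ys)
    if existing is None:
        hyps[hyp.ys] = hyp
        return
    # Two different alignments produced the same token sequence: acoustically
    # they are distinct events, so their probabilities sum (LogAdd) ...
    existing.log_prob = log_add(existing.log_prob, hyp.log_prob)
    # ... but the LM score depends only on the token sequence, which is the
    # same for both paths, so adding the LM terms as well would double the
    # LM probability.  That is the bug this PR removes.
```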
Python bindings (C++): expose lm_config as a read/write attribute of OnlineRecognizerConfig.

```diff
@@ -37,6 +37,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
            py::arg("hotwords_score") = 0)
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
+      .def_readwrite("lm_config", &PyClass::lm_config)
       .def_readwrite("endpoint_config", &PyClass::endpoint_config)
       .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
       .def_readwrite("decoding_method", &PyClass::decoding_method)
```
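With lm_config exposed via def_readwrite, the field behaves like the other config attributes (feat_config, model_config) and can be set from Python, which is what the wrapper change in the next file relies on. A minimal sketch; OnlineLMConfig and its model/scale keywords are taken from the wrapper diff below, while the path is a placeholder:

```python
from _sherpa_onnx import OnlineLMConfig

# Placeholder path: point this at a real LM model when actually decoding.
lm_config = OnlineLMConfig(model="./lm.onnx", scale=0.1)

# Thanks to the new binding, this object can be assigned to (or read from)
# OnlineRecognizerConfig.lm_config just like feat_config or model_config.
```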
Python wrapper: import OnlineLMConfig, accept lm and lm_scale, require modified_beam_search when an LM is given, and pass an OnlineLMConfig into OnlineRecognizerConfig.

```diff
@@ -5,6 +5,7 @@ from typing import List, Optional
 from _sherpa_onnx import (
     EndpointConfig,
     FeatureExtractorConfig,
+    OnlineLMConfig,
     OnlineModelConfig,
     OnlineParaformerModelConfig,
     OnlineRecognizer as _Recognizer,
@@ -46,6 +47,8 @@ class OnlineRecognizer(object):
         hotwords_file: str = "",
         provider: str = "cpu",
         model_type: str = "",
+        lm: str = "",
+        lm_scale: float = 0.1,
     ):
         """
         Please refer to
@@ -137,10 +140,22 @@ class OnlineRecognizer(object):
                 "Please use --decoding-method=modified_beam_search when using "
                 f"--hotwords-file. Currently given: {decoding_method}"
             )
+
+        if lm and decoding_method != "modified_beam_search":
+            raise ValueError(
+                "Please use --decoding-method=modified_beam_search when using "
+                f"--lm. Currently given: {decoding_method}"
+            )
+
+        lm_config = OnlineLMConfig(
+            model=lm,
+            scale=lm_scale,
+        )
 
         recognizer_config = OnlineRecognizerConfig(
             feat_config=feat_config,
             model_config=model_config,
+            lm_config=lm_config,
             endpoint_config=endpoint_config,
             enable_endpoint=enable_endpoint_detection,
             decoding_method=decoding_method,
```
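Putting the wrapper changes together, enabling the LM from Python looks roughly like the sketch below. The lm, lm_scale, and decoding_method keywords, and the rule that a non-empty lm requires modified_beam_search, come from this diff; the from_transducer factory name and all file paths are placeholders standing in for whatever models and constructor you actually use:

```python
import sherpa_onnx

# All paths are placeholders; `from_transducer` is assumed to be the factory
# that accepts the keyword arguments shown in the example-script diff above.
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
    tokens="./tokens.txt",
    encoder="./encoder.onnx",
    decoder="./decoder.onnx",
    joiner="./joiner.onnx",
    decoding_method="modified_beam_search",  # required whenever lm is non-empty
    lm="./lm.onnx",   # an empty string ("") keeps the previous LM-free behaviour
    lm_scale=0.1,     # weight given to the LM scores, as added in this PR
)
```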