Committed by
GitHub
add blank_penalty for online transducer (#548)
正在显示
13 个修改的文件
包含
94 行增加
和
14 行删除
| @@ -217,6 +217,18 @@ def get_args(): | @@ -217,6 +217,18 @@ def get_args(): | ||
| 217 | ) | 217 | ) |
| 218 | 218 | ||
| 219 | parser.add_argument( | 219 | parser.add_argument( |
| 220 | + "--blank-penalty", | ||
| 221 | + type=float, | ||
| 222 | + default=0.0, | ||
| 223 | + help=""" | ||
| 224 | + The penalty applied on blank symbol during decoding. | ||
| 225 | + Note: It is a positive value that would be applied to logits like | ||
| 226 | + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is | ||
| 227 | + [batch_size, vocab] and blank id is 0). | ||
| 228 | + """, | ||
| 229 | + ) | ||
| 230 | + | ||
| 231 | + parser.add_argument( | ||
| 220 | "sound_files", | 232 | "sound_files", |
| 221 | type=str, | 233 | type=str, |
| 222 | nargs="+", | 234 | nargs="+", |
| @@ -290,6 +302,7 @@ def main(): | @@ -290,6 +302,7 @@ def main(): | ||
| 290 | lm_scale=args.lm_scale, | 302 | lm_scale=args.lm_scale, |
| 291 | hotwords_file=args.hotwords_file, | 303 | hotwords_file=args.hotwords_file, |
| 292 | hotwords_score=args.hotwords_score, | 304 | hotwords_score=args.hotwords_score, |
| 305 | + blank_penalty=args.blank_penalty, | ||
| 293 | ) | 306 | ) |
| 294 | elif args.zipformer2_ctc: | 307 | elif args.zipformer2_ctc: |
| 295 | recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc( | 308 | recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc( |
| @@ -102,6 +102,17 @@ def get_args(): | @@ -102,6 +102,17 @@ def get_args(): | ||
| 102 | """, | 102 | """, |
| 103 | ) | 103 | ) |
| 104 | 104 | ||
| 105 | + parser.add_argument( | ||
| 106 | + "--blank-penalty", | ||
| 107 | + type=float, | ||
| 108 | + default=0.0, | ||
| 109 | + help=""" | ||
| 110 | + The penalty applied on blank symbol during decoding. | ||
| 111 | + Note: It is a positive value that would be applied to logits like | ||
| 112 | + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is | ||
| 113 | + [batch_size, vocab] and blank id is 0). | ||
| 114 | + """, | ||
| 115 | + ) | ||
| 105 | 116 | ||
| 106 | return parser.parse_args() | 117 | return parser.parse_args() |
| 107 | 118 | ||
| @@ -130,6 +141,7 @@ def create_recognizer(args): | @@ -130,6 +141,7 @@ def create_recognizer(args): | ||
| 130 | provider=args.provider, | 141 | provider=args.provider, |
| 131 | hotwords_file=args.hotwords_file, | 142 | hotwords_file=args.hotwords_file, |
| 132 | hotwords_score=args.hotwords_score, | 143 | hotwords_score=args.hotwords_score, |
| 144 | + blank_penalty=args.blank_penalty, | ||
| 133 | ) | 145 | ) |
| 134 | return recognizer | 146 | return recognizer |
| 135 | 147 |
| @@ -111,6 +111,17 @@ def get_args(): | @@ -111,6 +111,17 @@ def get_args(): | ||
| 111 | """, | 111 | """, |
| 112 | ) | 112 | ) |
| 113 | 113 | ||
| 114 | + parser.add_argument( | ||
| 115 | + "--blank-penalty", | ||
| 116 | + type=float, | ||
| 117 | + default=0.0, | ||
| 118 | + help=""" | ||
| 119 | + The penalty applied on blank symbol during decoding. | ||
| 120 | + Note: It is a positive value that would be applied to logits like | ||
| 121 | + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is | ||
| 122 | + [batch_size, vocab] and blank id is 0). | ||
| 123 | + """, | ||
| 124 | + ) | ||
| 114 | 125 | ||
| 115 | return parser.parse_args() | 126 | return parser.parse_args() |
| 116 | 127 | ||
| @@ -136,6 +147,7 @@ def create_recognizer(args): | @@ -136,6 +147,7 @@ def create_recognizer(args): | ||
| 136 | provider=args.provider, | 147 | provider=args.provider, |
| 137 | hotwords_file=args.hotwords_file, | 148 | hotwords_file=args.hotwords_file, |
| 138 | hotwords_score=args.hotwords_score, | 149 | hotwords_score=args.hotwords_score, |
| 150 | + blank_penalty=args.blank_penalty, | ||
| 139 | ) | 151 | ) |
| 140 | return recognizer | 152 | return recognizer |
| 141 | 153 |
| @@ -241,6 +241,18 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): | @@ -241,6 +241,18 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): | ||
| 241 | """, | 241 | """, |
| 242 | ) | 242 | ) |
| 243 | 243 | ||
| 244 | +def add_blank_penalty_args(parser: argparse.ArgumentParser): | ||
| 245 | + parser.add_argument( | ||
| 246 | + "--blank-penalty", | ||
| 247 | + type=float, | ||
| 248 | + default=0.0, | ||
| 249 | + help=""" | ||
| 250 | + The penalty applied on blank symbol during decoding. | ||
| 251 | + Note: It is a positive value that would be applied to logits like | ||
| 252 | + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is | ||
| 253 | + [batch_size, vocab] and blank id is 0). | ||
| 254 | + """, | ||
| 255 | + ) | ||
| 244 | 256 | ||
| 245 | def add_endpointing_args(parser: argparse.ArgumentParser): | 257 | def add_endpointing_args(parser: argparse.ArgumentParser): |
| 246 | parser.add_argument( | 258 | parser.add_argument( |
| @@ -284,6 +296,7 @@ def get_args(): | @@ -284,6 +296,7 @@ def get_args(): | ||
| 284 | add_decoding_args(parser) | 296 | add_decoding_args(parser) |
| 285 | add_endpointing_args(parser) | 297 | add_endpointing_args(parser) |
| 286 | add_hotwords_args(parser) | 298 | add_hotwords_args(parser) |
| 299 | + add_blank_penalty_args(parser) | ||
| 287 | 300 | ||
| 288 | parser.add_argument( | 301 | parser.add_argument( |
| 289 | "--port", | 302 | "--port", |
| @@ -390,6 +403,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: | @@ -390,6 +403,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: | ||
| 390 | max_active_paths=args.num_active_paths, | 403 | max_active_paths=args.num_active_paths, |
| 391 | hotwords_score=args.hotwords_score, | 404 | hotwords_score=args.hotwords_score, |
| 392 | hotwords_file=args.hotwords_file, | 405 | hotwords_file=args.hotwords_file, |
| 406 | + blank_penalty=args.blank_penalty, | ||
| 393 | enable_endpoint_detection=args.use_endpoint != 0, | 407 | enable_endpoint_detection=args.use_endpoint != 0, |
| 394 | rule1_min_trailing_silence=args.rule1_min_trailing_silence, | 408 | rule1_min_trailing_silence=args.rule1_min_trailing_silence, |
| 395 | rule2_min_trailing_silence=args.rule2_min_trailing_silence, | 409 | rule2_min_trailing_silence=args.rule2_min_trailing_silence, |
| @@ -95,10 +95,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -95,10 +95,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 95 | 95 | ||
| 96 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 96 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 97 | model_.get(), lm_.get(), config_.max_active_paths, | 97 | model_.get(), lm_.get(), config_.max_active_paths, |
| 98 | - config_.lm_config.scale, unk_id_); | 98 | + config_.lm_config.scale, unk_id_, config_.blank_penalty); |
| 99 | } else if (config.decoding_method == "greedy_search") { | 99 | } else if (config.decoding_method == "greedy_search") { |
| 100 | decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( | 100 | decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( |
| 101 | - model_.get(), unk_id_); | 101 | + model_.get(), unk_id_, config_.blank_penalty); |
| 102 | } else { | 102 | } else { |
| 103 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | 103 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 104 | config.decoding_method.c_str()); | 104 | config.decoding_method.c_str()); |
| @@ -131,10 +131,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -131,10 +131,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 131 | 131 | ||
| 132 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 132 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 133 | model_.get(), lm_.get(), config_.max_active_paths, | 133 | model_.get(), lm_.get(), config_.max_active_paths, |
| 134 | - config_.lm_config.scale, unk_id_); | 134 | + config_.lm_config.scale, unk_id_, config_.blank_penalty); |
| 135 | } else if (config.decoding_method == "greedy_search") { | 135 | } else if (config.decoding_method == "greedy_search") { |
| 136 | decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( | 136 | decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( |
| 137 | - model_.get(), unk_id_); | 137 | + model_.get(), unk_id_, config_.blank_penalty); |
| 138 | } else { | 138 | } else { |
| 139 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | 139 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 140 | config.decoding_method.c_str()); | 140 | config.decoding_method.c_str()); |
| @@ -81,6 +81,12 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | @@ -81,6 +81,12 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | ||
| 81 | "True to enable endpoint detection. False to disable it."); | 81 | "True to enable endpoint detection. False to disable it."); |
| 82 | po->Register("max-active-paths", &max_active_paths, | 82 | po->Register("max-active-paths", &max_active_paths, |
| 83 | "beam size used in modified beam search."); | 83 | "beam size used in modified beam search."); |
| 84 | + po->Register("blank-penalty", &blank_penalty, | ||
| 85 | + "The penalty applied on blank symbol during decoding. " | ||
| 86 | + "Note: It is a positive value. " | ||
| 87 | + "Increasing value will lead to lower deletion at the cost" | ||
| 88 | + "of higher insertions. " | ||
| 89 | + "Currently only applicable for transducer models."); | ||
| 84 | po->Register("hotwords-score", &hotwords_score, | 90 | po->Register("hotwords-score", &hotwords_score, |
| 85 | "The bonus score for each token in context word/phrase. " | 91 | "The bonus score for each token in context word/phrase. " |
| 86 | "Used only when decoding_method is modified_beam_search"); | 92 | "Used only when decoding_method is modified_beam_search"); |
| @@ -131,7 +137,8 @@ std::string OnlineRecognizerConfig::ToString() const { | @@ -131,7 +137,8 @@ std::string OnlineRecognizerConfig::ToString() const { | ||
| 131 | os << "max_active_paths=" << max_active_paths << ", "; | 137 | os << "max_active_paths=" << max_active_paths << ", "; |
| 132 | os << "hotwords_score=" << hotwords_score << ", "; | 138 | os << "hotwords_score=" << hotwords_score << ", "; |
| 133 | os << "hotwords_file=\"" << hotwords_file << "\", "; | 139 | os << "hotwords_file=\"" << hotwords_file << "\", "; |
| 134 | - os << "decoding_method=\"" << decoding_method << "\")"; | 140 | + os << "decoding_method=\"" << decoding_method << "\", "; |
| 141 | + os << "blank_penalty=" << blank_penalty << ")"; | ||
| 135 | 142 | ||
| 136 | return os.str(); | 143 | return os.str(); |
| 137 | } | 144 | } |
| @@ -83,6 +83,8 @@ struct OnlineRecognizerConfig { | @@ -83,6 +83,8 @@ struct OnlineRecognizerConfig { | ||
| 83 | float hotwords_score = 1.5; | 83 | float hotwords_score = 1.5; |
| 84 | std::string hotwords_file; | 84 | std::string hotwords_file; |
| 85 | 85 | ||
| 86 | + float blank_penalty = 0.0; | ||
| 87 | + | ||
| 86 | OnlineRecognizerConfig() = default; | 88 | OnlineRecognizerConfig() = default; |
| 87 | 89 | ||
| 88 | OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, | 90 | OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, |
| @@ -92,7 +94,8 @@ struct OnlineRecognizerConfig { | @@ -92,7 +94,8 @@ struct OnlineRecognizerConfig { | ||
| 92 | bool enable_endpoint, | 94 | bool enable_endpoint, |
| 93 | const std::string &decoding_method, | 95 | const std::string &decoding_method, |
| 94 | int32_t max_active_paths, | 96 | int32_t max_active_paths, |
| 95 | - const std::string &hotwords_file, float hotwords_score) | 97 | + const std::string &hotwords_file, float hotwords_score, |
| 98 | + float blank_penalty) | ||
| 96 | : feat_config(feat_config), | 99 | : feat_config(feat_config), |
| 97 | model_config(model_config), | 100 | model_config(model_config), |
| 98 | lm_config(lm_config), | 101 | lm_config(lm_config), |
| @@ -101,7 +104,8 @@ struct OnlineRecognizerConfig { | @@ -101,7 +104,8 @@ struct OnlineRecognizerConfig { | ||
| 101 | decoding_method(decoding_method), | 104 | decoding_method(decoding_method), |
| 102 | max_active_paths(max_active_paths), | 105 | max_active_paths(max_active_paths), |
| 103 | hotwords_score(hotwords_score), | 106 | hotwords_score(hotwords_score), |
| 104 | - hotwords_file(hotwords_file) {} | 107 | + hotwords_file(hotwords_file), |
| 108 | + blank_penalty(blank_penalty) {} | ||
| 105 | 109 | ||
| 106 | void Register(ParseOptions *po); | 110 | void Register(ParseOptions *po); |
| 107 | bool Validate() const; | 111 | bool Validate() const; |
| @@ -116,11 +116,14 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -116,11 +116,14 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 116 | Ort::Value logit = model_->RunJoiner( | 116 | Ort::Value logit = model_->RunJoiner( |
| 117 | std::move(cur_encoder_out), View(&decoder_out)); | 117 | std::move(cur_encoder_out), View(&decoder_out)); |
| 118 | 118 | ||
| 119 | - const float *p_logit = logit.GetTensorData<float>(); | 119 | + float *p_logit = logit.GetTensorMutableData<float>(); |
| 120 | 120 | ||
| 121 | bool emitted = false; | 121 | bool emitted = false; |
| 122 | for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) { | 122 | for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) { |
| 123 | auto &r = (*result)[i]; | 123 | auto &r = (*result)[i]; |
| 124 | + if (blank_penalty_ > 0.0) { | ||
| 125 | + p_logit[0] -= blank_penalty_; // assuming blank id is 0 | ||
| 126 | + } | ||
| 124 | auto y = static_cast<int32_t>(std::distance( | 127 | auto y = static_cast<int32_t>(std::distance( |
| 125 | static_cast<const float *>(p_logit), | 128 | static_cast<const float *>(p_logit), |
| 126 | std::max_element(static_cast<const float *>(p_logit), | 129 | std::max_element(static_cast<const float *>(p_logit), |
| @@ -15,8 +15,9 @@ namespace sherpa_onnx { | @@ -15,8 +15,9 @@ namespace sherpa_onnx { | ||
| 15 | class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { | 15 | class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { |
| 16 | public: | 16 | public: |
| 17 | OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, | 17 | OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, |
| 18 | - int32_t unk_id) | ||
| 19 | - : model_(model), unk_id_(unk_id) {} | 18 | + int32_t unk_id, |
| 19 | + float blank_penalty) | ||
| 20 | + : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} | ||
| 20 | 21 | ||
| 21 | OnlineTransducerDecoderResult GetEmptyResult() const override; | 22 | OnlineTransducerDecoderResult GetEmptyResult() const override; |
| 22 | 23 | ||
| @@ -28,6 +29,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { | @@ -28,6 +29,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { | ||
| 28 | private: | 29 | private: |
| 29 | OnlineTransducerModel *model_; // Not owned | 30 | OnlineTransducerModel *model_; // Not owned |
| 30 | int32_t unk_id_; | 31 | int32_t unk_id_; |
| 32 | + float blank_penalty_; | ||
| 31 | }; | 33 | }; |
| 32 | 34 | ||
| 33 | } // namespace sherpa_onnx | 35 | } // namespace sherpa_onnx |
| @@ -123,6 +123,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -123,6 +123,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 123 | model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); | 123 | model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); |
| 124 | 124 | ||
| 125 | float *p_logit = logit.GetTensorMutableData<float>(); | 125 | float *p_logit = logit.GetTensorMutableData<float>(); |
| 126 | + if (blank_penalty_ > 0.0) { | ||
| 127 | + // assuming blank id is 0 | ||
| 128 | + SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_); | ||
| 129 | + } | ||
| 126 | LogSoftmax(p_logit, vocab_size, num_hyps); | 130 | LogSoftmax(p_logit, vocab_size, num_hyps); |
| 127 | 131 | ||
| 128 | // now p_logit contains log_softmax output, we rename it to p_logprob | 132 | // now p_logit contains log_softmax output, we rename it to p_logprob |
| @@ -21,12 +21,14 @@ class OnlineTransducerModifiedBeamSearchDecoder | @@ -21,12 +21,14 @@ class OnlineTransducerModifiedBeamSearchDecoder | ||
| 21 | OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, | 21 | OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, |
| 22 | OnlineLM *lm, | 22 | OnlineLM *lm, |
| 23 | int32_t max_active_paths, | 23 | int32_t max_active_paths, |
| 24 | - float lm_scale, int32_t unk_id) | 24 | + float lm_scale, int32_t unk_id, |
| 25 | + float blank_penalty) | ||
| 25 | : model_(model), | 26 | : model_(model), |
| 26 | lm_(lm), | 27 | lm_(lm), |
| 27 | max_active_paths_(max_active_paths), | 28 | max_active_paths_(max_active_paths), |
| 28 | lm_scale_(lm_scale), | 29 | lm_scale_(lm_scale), |
| 29 | - unk_id_(unk_id) {} | 30 | + unk_id_(unk_id), |
| 31 | + blank_penalty_(blank_penalty) {} | ||
| 30 | 32 | ||
| 31 | OnlineTransducerDecoderResult GetEmptyResult() const override; | 33 | OnlineTransducerDecoderResult GetEmptyResult() const override; |
| 32 | 34 | ||
| @@ -47,6 +49,7 @@ class OnlineTransducerModifiedBeamSearchDecoder | @@ -47,6 +49,7 @@ class OnlineTransducerModifiedBeamSearchDecoder | ||
| 47 | int32_t max_active_paths_; | 49 | int32_t max_active_paths_; |
| 48 | float lm_scale_; // used only when lm_ is not nullptr | 50 | float lm_scale_; // used only when lm_ is not nullptr |
| 49 | int32_t unk_id_; | 51 | int32_t unk_id_; |
| 52 | + float blank_penalty_; | ||
| 50 | }; | 53 | }; |
| 51 | 54 | ||
| 52 | } // namespace sherpa_onnx | 55 | } // namespace sherpa_onnx |
| @@ -33,12 +33,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -33,12 +33,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 33 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") | 33 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") |
| 34 | .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, | 34 | .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, |
| 35 | const OnlineLMConfig &, const EndpointConfig &, bool, | 35 | const OnlineLMConfig &, const EndpointConfig &, bool, |
| 36 | - const std::string &, int32_t, const std::string &, float>(), | 36 | + const std::string &, int32_t, const std::string &, float, |
| 37 | + float>(), | ||
| 37 | py::arg("feat_config"), py::arg("model_config"), | 38 | py::arg("feat_config"), py::arg("model_config"), |
| 38 | py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), | 39 | py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), |
| 39 | py::arg("enable_endpoint"), py::arg("decoding_method"), | 40 | py::arg("enable_endpoint"), py::arg("decoding_method"), |
| 40 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", | 41 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", |
| 41 | - py::arg("hotwords_score") = 0) | 42 | + py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0) |
| 42 | .def_readwrite("feat_config", &PyClass::feat_config) | 43 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 43 | .def_readwrite("model_config", &PyClass::model_config) | 44 | .def_readwrite("model_config", &PyClass::model_config) |
| 44 | .def_readwrite("lm_config", &PyClass::lm_config) | 45 | .def_readwrite("lm_config", &PyClass::lm_config) |
| @@ -48,6 +49,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -48,6 +49,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 48 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) | 49 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) |
| 49 | .def_readwrite("hotwords_file", &PyClass::hotwords_file) | 50 | .def_readwrite("hotwords_file", &PyClass::hotwords_file) |
| 50 | .def_readwrite("hotwords_score", &PyClass::hotwords_score) | 51 | .def_readwrite("hotwords_score", &PyClass::hotwords_score) |
| 52 | + .def_readwrite("blank_penalty", &PyClass::blank_penalty) | ||
| 51 | .def("__str__", &PyClass::ToString); | 53 | .def("__str__", &PyClass::ToString); |
| 52 | } | 54 | } |
| 53 | 55 |
| @@ -48,6 +48,7 @@ class OnlineRecognizer(object): | @@ -48,6 +48,7 @@ class OnlineRecognizer(object): | ||
| 48 | decoding_method: str = "greedy_search", | 48 | decoding_method: str = "greedy_search", |
| 49 | max_active_paths: int = 4, | 49 | max_active_paths: int = 4, |
| 50 | hotwords_score: float = 1.5, | 50 | hotwords_score: float = 1.5, |
| 51 | + blank_penalty: float = 0.0, | ||
| 51 | hotwords_file: str = "", | 52 | hotwords_file: str = "", |
| 52 | provider: str = "cpu", | 53 | provider: str = "cpu", |
| 53 | model_type: str = "", | 54 | model_type: str = "", |
| @@ -100,6 +101,8 @@ class OnlineRecognizer(object): | @@ -100,6 +101,8 @@ class OnlineRecognizer(object): | ||
| 100 | max_active_paths: | 101 | max_active_paths: |
| 101 | Use only when decoding_method is modified_beam_search. It specifies | 102 | Use only when decoding_method is modified_beam_search. It specifies |
| 102 | the maximum number of active paths during beam search. | 103 | the maximum number of active paths during beam search. |
| 104 | + blank_penalty: | ||
| 105 | + The penalty applied on blank symbol during decoding. | ||
| 103 | hotwords_file: | 106 | hotwords_file: |
| 104 | The file containing hotwords, one words/phrases per line, and for each | 107 | The file containing hotwords, one words/phrases per line, and for each |
| 105 | phrase the bpe/cjkchar are separated by a space. | 108 | phrase the bpe/cjkchar are separated by a space. |
| @@ -172,6 +175,7 @@ class OnlineRecognizer(object): | @@ -172,6 +175,7 @@ class OnlineRecognizer(object): | ||
| 172 | max_active_paths=max_active_paths, | 175 | max_active_paths=max_active_paths, |
| 173 | hotwords_score=hotwords_score, | 176 | hotwords_score=hotwords_score, |
| 174 | hotwords_file=hotwords_file, | 177 | hotwords_file=hotwords_file, |
| 178 | + blank_penalty=blank_penalty, | ||
| 175 | ) | 179 | ) |
| 176 | 180 | ||
| 177 | self.recognizer = _Recognizer(recognizer_config) | 181 | self.recognizer = _Recognizer(recognizer_config) |
-
请 注册 或 登录 后发表评论