Committed by
GitHub
add blank_penalty for offline transducer (#542)
正在显示
13 个修改的文件
包含
97 行增加
和
14 行删除
| @@ -383,6 +383,19 @@ def add_hotwords_args(parser: argparse.ArgumentParser): | @@ -383,6 +383,19 @@ def add_hotwords_args(parser: argparse.ArgumentParser): | ||
| 383 | """, | 383 | """, |
| 384 | ) | 384 | ) |
| 385 | 385 | ||
| 386 | +def add_blank_penalty_args(parser: argparse.ArgumentParser): | ||
| 387 | + parser.add_argument( | ||
| 388 | + "--blank-penalty", | ||
| 389 | + type=float, | ||
| 390 | + default=0.0, | ||
| 391 | + help=""" | ||
| 392 | + The penalty applied on blank symbol during decoding. | ||
| 393 | + Note: It is a positive value that would be applied to logits like | ||
| 394 | + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is | ||
| 395 | + [batch_size, vocab] and blank id is 0). | ||
| 396 | + """, | ||
| 397 | + ) | ||
| 398 | + | ||
| 386 | 399 | ||
| 387 | def check_args(args): | 400 | def check_args(args): |
| 388 | if not Path(args.tokens).is_file(): | 401 | if not Path(args.tokens).is_file(): |
| @@ -414,6 +427,7 @@ def get_args(): | @@ -414,6 +427,7 @@ def get_args(): | ||
| 414 | add_feature_config_args(parser) | 427 | add_feature_config_args(parser) |
| 415 | add_decoding_args(parser) | 428 | add_decoding_args(parser) |
| 416 | add_hotwords_args(parser) | 429 | add_hotwords_args(parser) |
| 430 | + add_blank_penalty_args(parser) | ||
| 417 | 431 | ||
| 418 | parser.add_argument( | 432 | parser.add_argument( |
| 419 | "--port", | 433 | "--port", |
| @@ -862,6 +876,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | @@ -862,6 +876,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 862 | max_active_paths=args.max_active_paths, | 876 | max_active_paths=args.max_active_paths, |
| 863 | hotwords_file=args.hotwords_file, | 877 | hotwords_file=args.hotwords_file, |
| 864 | hotwords_score=args.hotwords_score, | 878 | hotwords_score=args.hotwords_score, |
| 879 | + blank_penalty=args.blank_penalty, | ||
| 865 | provider=args.provider, | 880 | provider=args.provider, |
| 866 | ) | 881 | ) |
| 867 | elif args.paraformer: | 882 | elif args.paraformer: |
| @@ -232,6 +232,18 @@ def get_args(): | @@ -232,6 +232,18 @@ def get_args(): | ||
| 232 | ) | 232 | ) |
| 233 | 233 | ||
| 234 | parser.add_argument( | 234 | parser.add_argument( |
| 235 | + "--blank-penalty", | ||
| 236 | + type=float, | ||
| 237 | + default=0.0, | ||
| 238 | + help=""" | ||
| 239 | + The penalty applied on blank symbol during decoding. | ||
| 240 | + Note: It is a positive value that would be applied to logits like | ||
| 241 | + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is | ||
| 242 | + [batch_size, vocab] and blank id is 0). | ||
| 243 | + """, | ||
| 244 | + ) | ||
| 245 | + | ||
| 246 | + parser.add_argument( | ||
| 235 | "--decoding-method", | 247 | "--decoding-method", |
| 236 | type=str, | 248 | type=str, |
| 237 | default="greedy_search", | 249 | default="greedy_search", |
| @@ -335,6 +347,7 @@ def main(): | @@ -335,6 +347,7 @@ def main(): | ||
| 335 | decoding_method=args.decoding_method, | 347 | decoding_method=args.decoding_method, |
| 336 | hotwords_file=args.hotwords_file, | 348 | hotwords_file=args.hotwords_file, |
| 337 | hotwords_score=args.hotwords_score, | 349 | hotwords_score=args.hotwords_score, |
| 350 | + blank_penalty=args.blank_penalty, | ||
| 338 | debug=args.debug, | 351 | debug=args.debug, |
| 339 | ) | 352 | ) |
| 340 | elif args.paraformer: | 353 | elif args.paraformer: |
| @@ -178,6 +178,18 @@ def get_args(): | @@ -178,6 +178,18 @@ def get_args(): | ||
| 178 | ) | 178 | ) |
| 179 | 179 | ||
| 180 | parser.add_argument( | 180 | parser.add_argument( |
| 181 | + "--blank-penalty", | ||
| 182 | + type=float, | ||
| 183 | + default=0.0, | ||
| 184 | + help=""" | ||
| 185 | + The penalty applied on blank symbol during decoding. | ||
| 186 | + Note: It is a positive value that would be applied to logits like | ||
| 187 | + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is | ||
| 188 | + [batch_size, vocab] and blank id is 0). | ||
| 189 | + """, | ||
| 190 | + ) | ||
| 191 | + | ||
| 192 | + parser.add_argument( | ||
| 181 | "--decoding-method", | 193 | "--decoding-method", |
| 182 | type=str, | 194 | type=str, |
| 183 | default="greedy_search", | 195 | default="greedy_search", |
| @@ -237,6 +249,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | @@ -237,6 +249,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 237 | sample_rate=args.sample_rate, | 249 | sample_rate=args.sample_rate, |
| 238 | feature_dim=args.feature_dim, | 250 | feature_dim=args.feature_dim, |
| 239 | decoding_method=args.decoding_method, | 251 | decoding_method=args.decoding_method, |
| 252 | + blank_penalty=args.blank_penalty, | ||
| 240 | debug=args.debug, | 253 | debug=args.debug, |
| 241 | ) | 254 | ) |
| 242 | elif args.paraformer: | 255 | elif args.paraformer: |
| @@ -96,6 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) { | @@ -96,6 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) { | ||
| 96 | } | 96 | } |
| 97 | } | 97 | } |
| 98 | 98 | ||
| 99 | +template <typename T> | ||
| 100 | +void SubtractBlank(T *in, int32_t w, int32_t h, | ||
| 101 | + int32_t blank_idx, float blank_penalty) { | ||
| 102 | + for (int32_t i = 0; i != h; ++i) { | ||
| 103 | + in[blank_idx] -= blank_penalty; | ||
| 104 | + in += w; | ||
| 105 | + } | ||
| 106 | +} | ||
| 107 | + | ||
| 99 | template <class T> | 108 | template <class T> |
| 100 | std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { | 109 | std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { |
| 101 | std::vector<int32_t> vec_index(size); | 110 | std::vector<int32_t> vec_index(size); |
| @@ -79,7 +79,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -79,7 +79,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 79 | } | 79 | } |
| 80 | if (config_.decoding_method == "greedy_search") { | 80 | if (config_.decoding_method == "greedy_search") { |
| 81 | decoder_ = | 81 | decoder_ = |
| 82 | - std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); | 82 | + std::make_unique<OfflineTransducerGreedySearchDecoder>( |
| 83 | + model_.get(), config_.blank_penalty); | ||
| 83 | } else if (config_.decoding_method == "modified_beam_search") { | 84 | } else if (config_.decoding_method == "modified_beam_search") { |
| 84 | if (!config_.lm_config.model.empty()) { | 85 | if (!config_.lm_config.model.empty()) { |
| 85 | lm_ = OfflineLM::Create(config.lm_config); | 86 | lm_ = OfflineLM::Create(config.lm_config); |
| @@ -87,7 +88,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -87,7 +88,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 87 | 88 | ||
| 88 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( | 89 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( |
| 89 | model_.get(), lm_.get(), config_.max_active_paths, | 90 | model_.get(), lm_.get(), config_.max_active_paths, |
| 90 | - config_.lm_config.scale); | 91 | + config_.lm_config.scale, config_.blank_penalty); |
| 91 | } else { | 92 | } else { |
| 92 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | 93 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 93 | config_.decoding_method.c_str()); | 94 | config_.decoding_method.c_str()); |
| @@ -104,7 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -104,7 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 104 | config_.model_config)) { | 105 | config_.model_config)) { |
| 105 | if (config_.decoding_method == "greedy_search") { | 106 | if (config_.decoding_method == "greedy_search") { |
| 106 | decoder_ = | 107 | decoder_ = |
| 107 | - std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); | 108 | + std::make_unique<OfflineTransducerGreedySearchDecoder>( |
| 109 | + model_.get(), config_.blank_penalty); | ||
| 108 | } else if (config_.decoding_method == "modified_beam_search") { | 110 | } else if (config_.decoding_method == "modified_beam_search") { |
| 109 | if (!config_.lm_config.model.empty()) { | 111 | if (!config_.lm_config.model.empty()) { |
| 110 | lm_ = OfflineLM::Create(mgr, config.lm_config); | 112 | lm_ = OfflineLM::Create(mgr, config.lm_config); |
| @@ -112,7 +114,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -112,7 +114,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 112 | 114 | ||
| 113 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( | 115 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( |
| 114 | model_.get(), lm_.get(), config_.max_active_paths, | 116 | model_.get(), lm_.get(), config_.max_active_paths, |
| 115 | - config_.lm_config.scale); | 117 | + config_.lm_config.scale, config_.blank_penalty); |
| 116 | } else { | 118 | } else { |
| 117 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | 119 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 118 | config_.decoding_method.c_str()); | 120 | config_.decoding_method.c_str()); |
| @@ -28,6 +28,13 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { | @@ -28,6 +28,13 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { | ||
| 28 | po->Register("max-active-paths", &max_active_paths, | 28 | po->Register("max-active-paths", &max_active_paths, |
| 29 | "Used only when decoding_method is modified_beam_search"); | 29 | "Used only when decoding_method is modified_beam_search"); |
| 30 | 30 | ||
| 31 | + po->Register("blank-penalty", &blank_penalty, | ||
| 32 | + "The penalty applied on blank symbol during decoding. " | ||
| 33 | + "Note: It is a positive value. " | ||
| 34 | + "Increasing value will lead to lower deletion at the cost " | ||
| 35 | + "of higher insertions. " | ||
| 36 | + "Currently only applicable for transducer models."); | ||
| 37 | + | ||
| 31 | po->Register( | 38 | po->Register( |
| 32 | "hotwords-file", &hotwords_file, | 39 | "hotwords-file", &hotwords_file, |
| 33 | "The file containing hotwords, one word/phrase per line, and for each" | 40 | "The file containing hotwords, one word/phrase per line, and for each" |
| @@ -74,7 +81,8 @@ std::string OfflineRecognizerConfig::ToString() const { | @@ -74,7 +81,8 @@ std::string OfflineRecognizerConfig::ToString() const { | ||
| 74 | os << "decoding_method=\"" << decoding_method << "\", "; | 81 | os << "decoding_method=\"" << decoding_method << "\", "; |
| 75 | os << "max_active_paths=" << max_active_paths << ", "; | 82 | os << "max_active_paths=" << max_active_paths << ", "; |
| 76 | os << "hotwords_file=\"" << hotwords_file << "\", "; | 83 | os << "hotwords_file=\"" << hotwords_file << "\", "; |
| 77 | - os << "hotwords_score=" << hotwords_score << ")"; | 84 | + os << "hotwords_score=" << hotwords_score << ", "; |
| 85 | + os << "blank_penalty=" << blank_penalty << ")"; | ||
| 78 | 86 | ||
| 79 | return os.str(); | 87 | return os.str(); |
| 80 | } | 88 | } |
| @@ -37,6 +37,8 @@ struct OfflineRecognizerConfig { | @@ -37,6 +37,8 @@ struct OfflineRecognizerConfig { | ||
| 37 | std::string hotwords_file; | 37 | std::string hotwords_file; |
| 38 | float hotwords_score = 1.5; | 38 | float hotwords_score = 1.5; |
| 39 | 39 | ||
| 40 | + float blank_penalty = 0.0; | ||
| 41 | + | ||
| 40 | // only greedy_search is implemented | 42 | // only greedy_search is implemented |
| 41 | // TODO(fangjun): Implement modified_beam_search | 43 | // TODO(fangjun): Implement modified_beam_search |
| 42 | 44 | ||
| @@ -46,7 +48,8 @@ struct OfflineRecognizerConfig { | @@ -46,7 +48,8 @@ struct OfflineRecognizerConfig { | ||
| 46 | const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config, | 48 | const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config, |
| 47 | const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, | 49 | const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, |
| 48 | const std::string &decoding_method, int32_t max_active_paths, | 50 | const std::string &decoding_method, int32_t max_active_paths, |
| 49 | - const std::string &hotwords_file, float hotwords_score) | 51 | + const std::string &hotwords_file, float hotwords_score, |
| 52 | + float blank_penalty) | ||
| 50 | : feat_config(feat_config), | 53 | : feat_config(feat_config), |
| 51 | model_config(model_config), | 54 | model_config(model_config), |
| 52 | lm_config(lm_config), | 55 | lm_config(lm_config), |
| @@ -54,7 +57,8 @@ struct OfflineRecognizerConfig { | @@ -54,7 +57,8 @@ struct OfflineRecognizerConfig { | ||
| 54 | decoding_method(decoding_method), | 57 | decoding_method(decoding_method), |
| 55 | max_active_paths(max_active_paths), | 58 | max_active_paths(max_active_paths), |
| 56 | hotwords_file(hotwords_file), | 59 | hotwords_file(hotwords_file), |
| 57 | - hotwords_score(hotwords_score) {} | 60 | + hotwords_score(hotwords_score), |
| 61 | + blank_penalty(blank_penalty) {} | ||
| 58 | 62 | ||
| 59 | void Register(ParseOptions *po); | 63 | void Register(ParseOptions *po); |
| 60 | bool Validate() const; | 64 | bool Validate() const; |
| @@ -46,9 +46,12 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, | @@ -46,9 +46,12 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, | ||
| 46 | start += n; | 46 | start += n; |
| 47 | Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out), | 47 | Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out), |
| 48 | std::move(cur_decoder_out)); | 48 | std::move(cur_decoder_out)); |
| 49 | - const float *p_logit = logit.GetTensorData<float>(); | 49 | + float *p_logit = logit.GetTensorMutableData<float>(); |
| 50 | bool emitted = false; | 50 | bool emitted = false; |
| 51 | for (int32_t i = 0; i != n; ++i) { | 51 | for (int32_t i = 0; i != n; ++i) { |
| 52 | + if (blank_penalty_ > 0.0) { | ||
| 53 | + p_logit[0] -= blank_penalty_; // assuming blank id is 0 | ||
| 54 | + } | ||
| 52 | auto y = static_cast<int32_t>(std::distance( | 55 | auto y = static_cast<int32_t>(std::distance( |
| 53 | static_cast<const float *>(p_logit), | 56 | static_cast<const float *>(p_logit), |
| 54 | std::max_element(static_cast<const float *>(p_logit), | 57 | std::max_element(static_cast<const float *>(p_logit), |
| @@ -14,8 +14,10 @@ namespace sherpa_onnx { | @@ -14,8 +14,10 @@ namespace sherpa_onnx { | ||
| 14 | 14 | ||
| 15 | class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { | 15 | class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { |
| 16 | public: | 16 | public: |
| 17 | - explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model) | ||
| 18 | - : model_(model) {} | 17 | + explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, |
| 18 | + float blank_penalty) | ||
| 19 | + : model_(model), | ||
| 20 | + blank_penalty_(blank_penalty) {} | ||
| 19 | 21 | ||
| 20 | std::vector<OfflineTransducerDecoderResult> Decode( | 22 | std::vector<OfflineTransducerDecoderResult> Decode( |
| 21 | Ort::Value encoder_out, Ort::Value encoder_out_length, | 23 | Ort::Value encoder_out, Ort::Value encoder_out_length, |
| @@ -23,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { | @@ -23,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { | ||
| 23 | 25 | ||
| 24 | private: | 26 | private: |
| 25 | OfflineTransducerModel *model_; // Not owned | 27 | OfflineTransducerModel *model_; // Not owned |
| 28 | + float blank_penalty_; | ||
| 26 | }; | 29 | }; |
| 27 | 30 | ||
| 28 | } // namespace sherpa_onnx | 31 | } // namespace sherpa_onnx |
| @@ -97,6 +97,10 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | @@ -97,6 +97,10 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 97 | model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); | 97 | model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); |
| 98 | 98 | ||
| 99 | float *p_logit = logit.GetTensorMutableData<float>(); | 99 | float *p_logit = logit.GetTensorMutableData<float>(); |
| 100 | + if (blank_penalty_ > 0.0) { | ||
| 101 | + // assuming blank id is 0 | ||
| 102 | + SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_); | ||
| 103 | + } | ||
| 100 | LogSoftmax(p_logit, vocab_size, num_hyps); | 104 | LogSoftmax(p_logit, vocab_size, num_hyps); |
| 101 | 105 | ||
| 102 | // now p_logit contains log_softmax output, we rename it to p_logprob | 106 | // now p_logit contains log_softmax output, we rename it to p_logprob |
| @@ -19,11 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder | @@ -19,11 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder | ||
| 19 | OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, | 19 | OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, |
| 20 | OfflineLM *lm, | 20 | OfflineLM *lm, |
| 21 | int32_t max_active_paths, | 21 | int32_t max_active_paths, |
| 22 | - float lm_scale) | 22 | + float lm_scale, |
| 23 | + float blank_penalty) | ||
| 23 | : model_(model), | 24 | : model_(model), |
| 24 | lm_(lm), | 25 | lm_(lm), |
| 25 | max_active_paths_(max_active_paths), | 26 | max_active_paths_(max_active_paths), |
| 26 | - lm_scale_(lm_scale) {} | 27 | + lm_scale_(lm_scale), |
| 28 | + blank_penalty_(blank_penalty) {} | ||
| 27 | 29 | ||
| 28 | std::vector<OfflineTransducerDecoderResult> Decode( | 30 | std::vector<OfflineTransducerDecoderResult> Decode( |
| 29 | Ort::Value encoder_out, Ort::Value encoder_out_length, | 31 | Ort::Value encoder_out, Ort::Value encoder_out_length, |
| @@ -35,6 +37,7 @@ class OfflineTransducerModifiedBeamSearchDecoder | @@ -35,6 +37,7 @@ class OfflineTransducerModifiedBeamSearchDecoder | ||
| 35 | 37 | ||
| 36 | int32_t max_active_paths_; | 38 | int32_t max_active_paths_; |
| 37 | float lm_scale_; // used only when lm_ is not nullptr | 39 | float lm_scale_; // used only when lm_ is not nullptr |
| 40 | + float blank_penalty_; | ||
| 38 | }; | 41 | }; |
| 39 | 42 | ||
| 40 | } // namespace sherpa_onnx | 43 | } // namespace sherpa_onnx |
| @@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | @@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | ||
| 17 | .def(py::init<const OfflineFeatureExtractorConfig &, | 17 | .def(py::init<const OfflineFeatureExtractorConfig &, |
| 18 | const OfflineModelConfig &, const OfflineLMConfig &, | 18 | const OfflineModelConfig &, const OfflineLMConfig &, |
| 19 | const OfflineCtcFstDecoderConfig &, const std::string &, | 19 | const OfflineCtcFstDecoderConfig &, const std::string &, |
| 20 | - int32_t, const std::string &, float>(), | 20 | + int32_t, const std::string &, float, float>(), |
| 21 | py::arg("feat_config"), py::arg("model_config"), | 21 | py::arg("feat_config"), py::arg("model_config"), |
| 22 | py::arg("lm_config") = OfflineLMConfig(), | 22 | py::arg("lm_config") = OfflineLMConfig(), |
| 23 | py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), | 23 | py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), |
| 24 | py::arg("decoding_method") = "greedy_search", | 24 | py::arg("decoding_method") = "greedy_search", |
| 25 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", | 25 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", |
| 26 | - py::arg("hotwords_score") = 1.5) | 26 | + py::arg("hotwords_score") = 1.5, |
| 27 | + py::arg("blank_penalty") = 0.0) | ||
| 27 | .def_readwrite("feat_config", &PyClass::feat_config) | 28 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 28 | .def_readwrite("model_config", &PyClass::model_config) | 29 | .def_readwrite("model_config", &PyClass::model_config) |
| 29 | .def_readwrite("lm_config", &PyClass::lm_config) | 30 | .def_readwrite("lm_config", &PyClass::lm_config) |
| @@ -32,6 +33,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | @@ -32,6 +33,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | ||
| 32 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) | 33 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) |
| 33 | .def_readwrite("hotwords_file", &PyClass::hotwords_file) | 34 | .def_readwrite("hotwords_file", &PyClass::hotwords_file) |
| 34 | .def_readwrite("hotwords_score", &PyClass::hotwords_score) | 35 | .def_readwrite("hotwords_score", &PyClass::hotwords_score) |
| 36 | + .def_readwrite("blank_penalty", &PyClass::blank_penalty) | ||
| 35 | .def("__str__", &PyClass::ToString); | 37 | .def("__str__", &PyClass::ToString); |
| 36 | } | 38 | } |
| 37 | 39 |
| @@ -48,6 +48,7 @@ class OfflineRecognizer(object): | @@ -48,6 +48,7 @@ class OfflineRecognizer(object): | ||
| 48 | max_active_paths: int = 4, | 48 | max_active_paths: int = 4, |
| 49 | hotwords_file: str = "", | 49 | hotwords_file: str = "", |
| 50 | hotwords_score: float = 1.5, | 50 | hotwords_score: float = 1.5, |
| 51 | + blank_penalty: float = 0.0, | ||
| 51 | debug: bool = False, | 52 | debug: bool = False, |
| 52 | provider: str = "cpu", | 53 | provider: str = "cpu", |
| 53 | ): | 54 | ): |
| @@ -81,6 +82,8 @@ class OfflineRecognizer(object): | @@ -81,6 +82,8 @@ class OfflineRecognizer(object): | ||
| 81 | max_active_paths: | 82 | max_active_paths: |
| 82 | Maximum number of active paths to keep. Used only when | 83 | Maximum number of active paths to keep. Used only when |
| 83 | decoding_method is modified_beam_search. | 84 | decoding_method is modified_beam_search. |
| 85 | + blank_penalty: | ||
| 86 | + The penalty applied on blank symbol during decoding. | ||
| 84 | debug: | 87 | debug: |
| 85 | True to show debug messages. | 88 | True to show debug messages. |
| 86 | provider: | 89 | provider: |
| @@ -117,6 +120,7 @@ class OfflineRecognizer(object): | @@ -117,6 +120,7 @@ class OfflineRecognizer(object): | ||
| 117 | decoding_method=decoding_method, | 120 | decoding_method=decoding_method, |
| 118 | hotwords_file=hotwords_file, | 121 | hotwords_file=hotwords_file, |
| 119 | hotwords_score=hotwords_score, | 122 | hotwords_score=hotwords_score, |
| 123 | + blank_penalty=blank_penalty, | ||
| 120 | ) | 124 | ) |
| 121 | self.recognizer = _Recognizer(recognizer_config) | 125 | self.recognizer = _Recognizer(recognizer_config) |
| 122 | self.config = recognizer_config | 126 | self.config = recognizer_config |
-
请 注册 或 登录 后发表评论