Committed by
GitHub
offline transducer: treat unk as blank (#1005)
Co-authored-by: chungyi.li <chungyi.li@ailabs.tw>
正在显示
5 个修改的文件
包含
25 行增加
和
9 行删除
| @@ -78,9 +78,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -78,9 +78,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 78 | config_(config), | 78 | config_(config), |
| 79 | symbol_table_(config_.model_config.tokens), | 79 | symbol_table_(config_.model_config.tokens), |
| 80 | model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { | 80 | model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { |
| 81 | + if (symbol_table_.Contains("<unk>")) { | ||
| 82 | + unk_id_ = symbol_table_["<unk>"]; | ||
| 83 | + } | ||
| 84 | + | ||
| 81 | if (config_.decoding_method == "greedy_search") { | 85 | if (config_.decoding_method == "greedy_search") { |
| 82 | decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>( | 86 | decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>( |
| 83 | - model_.get(), config_.blank_penalty); | 87 | + model_.get(), unk_id_, config_.blank_penalty); |
| 84 | } else if (config_.decoding_method == "modified_beam_search") { | 88 | } else if (config_.decoding_method == "modified_beam_search") { |
| 85 | if (!config_.lm_config.model.empty()) { | 89 | if (!config_.lm_config.model.empty()) { |
| 86 | lm_ = OfflineLM::Create(config.lm_config); | 90 | lm_ = OfflineLM::Create(config.lm_config); |
| @@ -97,7 +101,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -97,7 +101,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 97 | 101 | ||
| 98 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( | 102 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( |
| 99 | model_.get(), lm_.get(), config_.max_active_paths, | 103 | model_.get(), lm_.get(), config_.max_active_paths, |
| 100 | - config_.lm_config.scale, config_.blank_penalty); | 104 | + config_.lm_config.scale, unk_id_, config_.blank_penalty); |
| 101 | } else { | 105 | } else { |
| 102 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | 106 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 103 | config_.decoding_method.c_str()); | 107 | config_.decoding_method.c_str()); |
| @@ -113,9 +117,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -113,9 +117,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 113 | symbol_table_(mgr, config_.model_config.tokens), | 117 | symbol_table_(mgr, config_.model_config.tokens), |
| 114 | model_(std::make_unique<OfflineTransducerModel>(mgr, | 118 | model_(std::make_unique<OfflineTransducerModel>(mgr, |
| 115 | config_.model_config)) { | 119 | config_.model_config)) { |
| 120 | + if (symbol_table_.Contains("<unk>")) { | ||
| 121 | + unk_id_ = symbol_table_["<unk>"]; | ||
| 122 | + } | ||
| 123 | + | ||
| 116 | if (config_.decoding_method == "greedy_search") { | 124 | if (config_.decoding_method == "greedy_search") { |
| 117 | decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>( | 125 | decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>( |
| 118 | - model_.get(), config_.blank_penalty); | 126 | + model_.get(), unk_id_, config_.blank_penalty); |
| 119 | } else if (config_.decoding_method == "modified_beam_search") { | 127 | } else if (config_.decoding_method == "modified_beam_search") { |
| 120 | if (!config_.lm_config.model.empty()) { | 128 | if (!config_.lm_config.model.empty()) { |
| 121 | lm_ = OfflineLM::Create(mgr, config.lm_config); | 129 | lm_ = OfflineLM::Create(mgr, config.lm_config); |
| @@ -133,7 +141,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -133,7 +141,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 133 | 141 | ||
| 134 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( | 142 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( |
| 135 | model_.get(), lm_.get(), config_.max_active_paths, | 143 | model_.get(), lm_.get(), config_.max_active_paths, |
| 136 | - config_.lm_config.scale, config_.blank_penalty); | 144 | + config_.lm_config.scale, unk_id_, config_.blank_penalty); |
| 137 | } else { | 145 | } else { |
| 138 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | 146 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 139 | config_.decoding_method.c_str()); | 147 | config_.decoding_method.c_str()); |
| @@ -293,6 +301,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -293,6 +301,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 293 | std::unique_ptr<OfflineTransducerModel> model_; | 301 | std::unique_ptr<OfflineTransducerModel> model_; |
| 294 | std::unique_ptr<OfflineTransducerDecoder> decoder_; | 302 | std::unique_ptr<OfflineTransducerDecoder> decoder_; |
| 295 | std::unique_ptr<OfflineLM> lm_; | 303 | std::unique_ptr<OfflineLM> lm_; |
| 304 | + int32_t unk_id_ = -1; | ||
| 296 | }; | 305 | }; |
| 297 | 306 | ||
| 298 | } // namespace sherpa_onnx | 307 | } // namespace sherpa_onnx |
| @@ -57,7 +57,9 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, | @@ -57,7 +57,9 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, | ||
| 57 | std::max_element(static_cast<const float *>(p_logit), | 57 | std::max_element(static_cast<const float *>(p_logit), |
| 58 | static_cast<const float *>(p_logit) + vocab_size))); | 58 | static_cast<const float *>(p_logit) + vocab_size))); |
| 59 | p_logit += vocab_size; | 59 | p_logit += vocab_size; |
| 60 | - if (y != 0) { | 60 | + // blank id is hardcoded to 0 |
| 61 | + // also, it treats unk as blank | ||
| 62 | + if (y != 0 && y != unk_id_) { | ||
| 61 | ans[i].tokens.push_back(y); | 63 | ans[i].tokens.push_back(y); |
| 62 | ans[i].timestamps.push_back(t); | 64 | ans[i].timestamps.push_back(t); |
| 63 | emitted = true; | 65 | emitted = true; |
| @@ -15,8 +15,9 @@ namespace sherpa_onnx { | @@ -15,8 +15,9 @@ namespace sherpa_onnx { | ||
| 15 | class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { | 15 | class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { |
| 16 | public: | 16 | public: |
| 17 | OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, | 17 | OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, |
| 18 | + int32_t unk_id, | ||
| 18 | float blank_penalty) | 19 | float blank_penalty) |
| 19 | - : model_(model), blank_penalty_(blank_penalty) {} | 20 | + : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} |
| 20 | 21 | ||
| 21 | std::vector<OfflineTransducerDecoderResult> Decode( | 22 | std::vector<OfflineTransducerDecoderResult> Decode( |
| 22 | Ort::Value encoder_out, Ort::Value encoder_out_length, | 23 | Ort::Value encoder_out, Ort::Value encoder_out_length, |
| @@ -24,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { | @@ -24,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { | ||
| 24 | 25 | ||
| 25 | private: | 26 | private: |
| 26 | OfflineTransducerModel *model_; // Not owned | 27 | OfflineTransducerModel *model_; // Not owned |
| 28 | + int32_t unk_id_; | ||
| 27 | float blank_penalty_; | 29 | float blank_penalty_; |
| 28 | }; | 30 | }; |
| 29 | 31 |
| @@ -131,8 +131,9 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | @@ -131,8 +131,9 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 131 | 131 | ||
| 132 | float context_score = 0; | 132 | float context_score = 0; |
| 133 | auto context_state = new_hyp.context_state; | 133 | auto context_state = new_hyp.context_state; |
| 134 | - if (new_token != 0) { | ||
| 135 | - // blank id is fixed to 0 | 134 | + // blank is hardcoded to 0 |
| 135 | + // also, it treats unk as blank | ||
| 136 | + if (new_token != 0 && new_token != unk_id_) { | ||
| 136 | new_hyp.ys.push_back(new_token); | 137 | new_hyp.ys.push_back(new_token); |
| 137 | new_hyp.timestamps.push_back(t); | 138 | new_hyp.timestamps.push_back(t); |
| 138 | if (context_graphs[i] != nullptr) { | 139 | if (context_graphs[i] != nullptr) { |
| @@ -19,12 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder | @@ -19,12 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder | ||
| 19 | OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, | 19 | OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, |
| 20 | OfflineLM *lm, | 20 | OfflineLM *lm, |
| 21 | int32_t max_active_paths, | 21 | int32_t max_active_paths, |
| 22 | - float lm_scale, | 22 | + float lm_scale, int32_t unk_id, |
| 23 | float blank_penalty) | 23 | float blank_penalty) |
| 24 | : model_(model), | 24 | : model_(model), |
| 25 | lm_(lm), | 25 | lm_(lm), |
| 26 | max_active_paths_(max_active_paths), | 26 | max_active_paths_(max_active_paths), |
| 27 | lm_scale_(lm_scale), | 27 | lm_scale_(lm_scale), |
| 28 | + unk_id_(unk_id), | ||
| 28 | blank_penalty_(blank_penalty) {} | 29 | blank_penalty_(blank_penalty) {} |
| 29 | 30 | ||
| 30 | std::vector<OfflineTransducerDecoderResult> Decode( | 31 | std::vector<OfflineTransducerDecoderResult> Decode( |
| @@ -37,6 +38,7 @@ class OfflineTransducerModifiedBeamSearchDecoder | @@ -37,6 +38,7 @@ class OfflineTransducerModifiedBeamSearchDecoder | ||
| 37 | 38 | ||
| 38 | int32_t max_active_paths_; | 39 | int32_t max_active_paths_; |
| 39 | float lm_scale_; // used only when lm_ is not nullptr | 40 | float lm_scale_; // used only when lm_ is not nullptr |
| 41 | + int32_t unk_id_; | ||
| 40 | float blank_penalty_; | 42 | float blank_penalty_; |
| 41 | }; | 43 | }; |
| 42 | 44 |
-
请 注册 或 登录 后发表评论