正在显示 5 个修改的文件，包含 29 行增加和 12 行删除
| @@ -57,6 +57,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -57,6 +57,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 57 | model_(OnlineTransducerModel::Create(config.model_config)), | 57 | model_(OnlineTransducerModel::Create(config.model_config)), |
| 58 | sym_(config.model_config.tokens), | 58 | sym_(config.model_config.tokens), |
| 59 | endpoint_(config_.endpoint_config) { | 59 | endpoint_(config_.endpoint_config) { |
| 60 | + if (sym_.contains("<unk>")) { | ||
| 61 | + unk_id_ = sym_["<unk>"]; | ||
| 62 | + } | ||
| 63 | + | ||
| 60 | if (config.decoding_method == "modified_beam_search") { | 64 | if (config.decoding_method == "modified_beam_search") { |
| 61 | if (!config_.lm_config.model.empty()) { | 65 | if (!config_.lm_config.model.empty()) { |
| 62 | lm_ = OnlineLM::Create(config.lm_config); | 66 | lm_ = OnlineLM::Create(config.lm_config); |
| @@ -64,10 +68,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -64,10 +68,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 64 | 68 | ||
| 65 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 69 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 66 | model_.get(), lm_.get(), config_.max_active_paths, | 70 | model_.get(), lm_.get(), config_.max_active_paths, |
| 67 | - config_.lm_config.scale); | 71 | + config_.lm_config.scale, unk_id_); |
| 68 | } else if (config.decoding_method == "greedy_search") { | 72 | } else if (config.decoding_method == "greedy_search") { |
| 69 | - decoder_ = | ||
| 70 | - std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); | 73 | + decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( |
| 74 | + model_.get(), unk_id_); | ||
| 71 | } else { | 75 | } else { |
| 72 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | 76 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 73 | config.decoding_method.c_str()); | 77 | config.decoding_method.c_str()); |
| @@ -82,13 +86,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -82,13 +86,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 82 | model_(OnlineTransducerModel::Create(mgr, config.model_config)), | 86 | model_(OnlineTransducerModel::Create(mgr, config.model_config)), |
| 83 | sym_(mgr, config.model_config.tokens), | 87 | sym_(mgr, config.model_config.tokens), |
| 84 | endpoint_(config_.endpoint_config) { | 88 | endpoint_(config_.endpoint_config) { |
| 89 | + if (sym_.contains("<unk>")) { | ||
| 90 | + unk_id_ = sym_["<unk>"]; | ||
| 91 | + } | ||
| 92 | + | ||
| 85 | if (config.decoding_method == "modified_beam_search") { | 93 | if (config.decoding_method == "modified_beam_search") { |
| 86 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 94 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 87 | model_.get(), lm_.get(), config_.max_active_paths, | 95 | model_.get(), lm_.get(), config_.max_active_paths, |
| 88 | - config_.lm_config.scale); | 96 | + config_.lm_config.scale, unk_id_); |
| 89 | } else if (config.decoding_method == "greedy_search") { | 97 | } else if (config.decoding_method == "greedy_search") { |
| 90 | - decoder_ = | ||
| 91 | - std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); | 98 | + decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( |
| 99 | + model_.get(), unk_id_); | ||
| 92 | } else { | 100 | } else { |
| 93 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | 101 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 94 | config.decoding_method.c_str()); | 102 | config.decoding_method.c_str()); |
| @@ -268,6 +276,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -268,6 +276,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 268 | std::unique_ptr<OnlineTransducerDecoder> decoder_; | 276 | std::unique_ptr<OnlineTransducerDecoder> decoder_; |
| 269 | SymbolTable sym_; | 277 | SymbolTable sym_; |
| 270 | Endpoint endpoint_; | 278 | Endpoint endpoint_; |
| 279 | + int32_t unk_id_ = -1; | ||
| 271 | }; | 280 | }; |
| 272 | 281 | ||
| 273 | } // namespace sherpa_onnx | 282 | } // namespace sherpa_onnx |
| @@ -108,7 +108,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -108,7 +108,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 108 | static_cast<const float *>(p_logit), | 108 | static_cast<const float *>(p_logit), |
| 109 | std::max_element(static_cast<const float *>(p_logit), | 109 | std::max_element(static_cast<const float *>(p_logit), |
| 110 | static_cast<const float *>(p_logit) + vocab_size))); | 110 | static_cast<const float *>(p_logit) + vocab_size))); |
| 111 | - if (y != 0) { | 111 | + // blank id is hardcoded to 0 |
| 112 | + // also, it treats unk as blank | ||
| 113 | + if (y != 0 && y != unk_id_) { | ||
| 112 | emitted = true; | 114 | emitted = true; |
| 113 | r.tokens.push_back(y); | 115 | r.tokens.push_back(y); |
| 114 | r.timestamps.push_back(t + r.frame_offset); | 116 | r.timestamps.push_back(t + r.frame_offset); |
| @@ -14,8 +14,9 @@ namespace sherpa_onnx { | @@ -14,8 +14,9 @@ namespace sherpa_onnx { | ||
| 14 | 14 | ||
| 15 | class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { | 15 | class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { |
| 16 | public: | 16 | public: |
| 17 | - explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model) | ||
| 18 | - : model_(model) {} | 17 | + OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, |
| 18 | + int32_t unk_id) | ||
| 19 | + : model_(model), unk_id_(unk_id) {} | ||
| 19 | 20 | ||
| 20 | OnlineTransducerDecoderResult GetEmptyResult() const override; | 21 | OnlineTransducerDecoderResult GetEmptyResult() const override; |
| 21 | 22 | ||
| @@ -26,6 +27,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { | @@ -26,6 +27,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { | ||
| 26 | 27 | ||
| 27 | private: | 28 | private: |
| 28 | OnlineTransducerModel *model_; // Not owned | 29 | OnlineTransducerModel *model_; // Not owned |
| 30 | + int32_t unk_id_; | ||
| 29 | }; | 31 | }; |
| 30 | 32 | ||
| 31 | } // namespace sherpa_onnx | 33 | } // namespace sherpa_onnx |
| @@ -155,7 +155,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -155,7 +155,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 155 | float context_score = 0; | 155 | float context_score = 0; |
| 156 | auto context_state = new_hyp.context_state; | 156 | auto context_state = new_hyp.context_state; |
| 157 | 157 | ||
| 158 | - if (new_token != 0) { | 158 | + // blank is hardcoded to 0 |
| 159 | + // also, it treats unk as blank | ||
| 160 | + if (new_token != 0 && new_token != unk_id_) { | ||
| 159 | new_hyp.ys.push_back(new_token); | 161 | new_hyp.ys.push_back(new_token); |
| 160 | new_hyp.timestamps.push_back(t + frame_offset); | 162 | new_hyp.timestamps.push_back(t + frame_offset); |
| 161 | new_hyp.num_trailing_blanks = 0; | 163 | new_hyp.num_trailing_blanks = 0; |
| @@ -21,11 +21,12 @@ class OnlineTransducerModifiedBeamSearchDecoder | @@ -21,11 +21,12 @@ class OnlineTransducerModifiedBeamSearchDecoder | ||
| 21 | OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, | 21 | OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, |
| 22 | OnlineLM *lm, | 22 | OnlineLM *lm, |
| 23 | int32_t max_active_paths, | 23 | int32_t max_active_paths, |
| 24 | - float lm_scale) | 24 | + float lm_scale, int32_t unk_id) |
| 25 | : model_(model), | 25 | : model_(model), |
| 26 | lm_(lm), | 26 | lm_(lm), |
| 27 | max_active_paths_(max_active_paths), | 27 | max_active_paths_(max_active_paths), |
| 28 | - lm_scale_(lm_scale) {} | 28 | + lm_scale_(lm_scale), |
| 29 | + unk_id_(unk_id) {} | ||
| 29 | 30 | ||
| 30 | OnlineTransducerDecoderResult GetEmptyResult() const override; | 31 | OnlineTransducerDecoderResult GetEmptyResult() const override; |
| 31 | 32 | ||
| @@ -45,6 +46,7 @@ class OnlineTransducerModifiedBeamSearchDecoder | @@ -45,6 +46,7 @@ class OnlineTransducerModifiedBeamSearchDecoder | ||
| 45 | 46 | ||
| 46 | int32_t max_active_paths_; | 47 | int32_t max_active_paths_; |
| 47 | float lm_scale_; // used only when lm_ is not nullptr | 48 | float lm_scale_; // used only when lm_ is not nullptr |
| 49 | + int32_t unk_id_; | ||
| 48 | }; | 50 | }; |
| 49 | 51 | ||
| 50 | } // namespace sherpa_onnx | 52 | } // namespace sherpa_onnx |
-
请注册或登录后发表评论