Committed by
GitHub
Implement max_symbols_per_frame for GigaAM2 accurate decoding since model uses c…
…har tokens instead of BPE. (#2423)
正在显示
1 个修改的文件
包含
29 行增加
和
24 行删除
| @@ -45,6 +45,7 @@ static OfflineTransducerDecoderResult DecodeOne( | @@ -45,6 +45,7 @@ static OfflineTransducerDecoderResult DecodeOne( | ||
| 45 | 45 | ||
| 46 | int32_t vocab_size = model->VocabSize(); | 46 | int32_t vocab_size = model->VocabSize(); |
| 47 | int32_t blank_id = vocab_size - 1; | 47 | int32_t blank_id = vocab_size - 1; |
| 48 | + int32_t max_symbols_per_frame = 10; | ||
| 48 | 49 | ||
| 49 | auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); | 50 | auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); |
| 50 | 51 | ||
| @@ -60,30 +61,34 @@ static OfflineTransducerDecoderResult DecodeOne( | @@ -60,30 +61,34 @@ static OfflineTransducerDecoderResult DecodeOne( | ||
| 60 | memory_info, const_cast<float *>(p) + t * num_cols, num_cols, | 61 | memory_info, const_cast<float *>(p) + t * num_cols, num_cols, |
| 61 | encoder_shape.data(), encoder_shape.size()); | 62 | encoder_shape.data(), encoder_shape.size()); |
| 62 | 63 | ||
| 63 | - Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out), | ||
| 64 | - View(&decoder_output_pair.first)); | ||
| 65 | - | ||
| 66 | - float *p_logit = logit.GetTensorMutableData<float>(); | ||
| 67 | - if (blank_penalty > 0) { | ||
| 68 | - p_logit[blank_id] -= blank_penalty; | 64 | + for (int32_t q = 0; q != max_symbols_per_frame; ++q) { |
| 65 | + Ort::Value logit = model->RunJoiner(View(&cur_encoder_out), | ||
| 66 | + View(&decoder_output_pair.first)); | ||
| 67 | + | ||
| 68 | + float *p_logit = logit.GetTensorMutableData<float>(); | ||
| 69 | + if (blank_penalty > 0) { | ||
| 70 | + p_logit[blank_id] -= blank_penalty; | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + auto y = static_cast<int32_t>(std::distance( | ||
| 74 | + static_cast<const float *>(p_logit), | ||
| 75 | + std::max_element(static_cast<const float *>(p_logit), | ||
| 76 | + static_cast<const float *>(p_logit) + vocab_size))); | ||
| 77 | + | ||
| 78 | + if (y != blank_id) { | ||
| 79 | + ans.tokens.push_back(y); | ||
| 80 | + ans.timestamps.push_back(t); | ||
| 81 | + | ||
| 82 | + decoder_input_pair = BuildDecoderInput(y, model->Allocator()); | ||
| 83 | + | ||
| 84 | + decoder_output_pair = | ||
| 85 | + model->RunDecoder(std::move(decoder_input_pair.first), | ||
| 86 | + std::move(decoder_input_pair.second), | ||
| 87 | + std::move(decoder_output_pair.second)); | ||
| 88 | + } else { | ||
| 89 | + break; | ||
| 90 | + } // if (y != blank_id) | ||
| 69 | } | 91 | } |
| 70 | - | ||
| 71 | - auto y = static_cast<int32_t>(std::distance( | ||
| 72 | - static_cast<const float *>(p_logit), | ||
| 73 | - std::max_element(static_cast<const float *>(p_logit), | ||
| 74 | - static_cast<const float *>(p_logit) + vocab_size))); | ||
| 75 | - | ||
| 76 | - if (y != blank_id) { | ||
| 77 | - ans.tokens.push_back(y); | ||
| 78 | - ans.timestamps.push_back(t); | ||
| 79 | - | ||
| 80 | - decoder_input_pair = BuildDecoderInput(y, model->Allocator()); | ||
| 81 | - | ||
| 82 | - decoder_output_pair = | ||
| 83 | - model->RunDecoder(std::move(decoder_input_pair.first), | ||
| 84 | - std::move(decoder_input_pair.second), | ||
| 85 | - std::move(decoder_output_pair.second)); | ||
| 86 | - } // if (y != blank_id) | ||
| 87 | } // for (int32_t i = 0; i != num_rows; ++i) | 92 | } // for (int32_t i = 0; i != num_rows; ++i) |
| 88 | 93 | ||
| 89 | return ans; | 94 | return ans; |
| @@ -99,7 +104,7 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( | @@ -99,7 +104,7 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( | ||
| 99 | int32_t dim1 = static_cast<int32_t>(shape[1]); | 104 | int32_t dim1 = static_cast<int32_t>(shape[1]); |
| 100 | int32_t dim2 = static_cast<int32_t>(shape[2]); | 105 | int32_t dim2 = static_cast<int32_t>(shape[2]); |
| 101 | 106 | ||
| 102 | - const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>(); | 107 | + const int32_t *p_length = encoder_out_length.GetTensorData<int32_t>(); |
| 103 | const float *p = encoder_out.GetTensorData<float>(); | 108 | const float *p = encoder_out.GetTensorData<float>(); |
| 104 | 109 | ||
| 105 | std::vector<OfflineTransducerDecoderResult> ans(batch_size); | 110 | std::vector<OfflineTransducerDecoderResult> ans(batch_size); |
-
请 注册 或 登录 后发表评论