Committed by
GitHub
Implement max_symbols_per_frame for GigaAM2 accurate decoding since model uses char tokens instead of BPE. (#2423)
正在显示
1 个修改的文件
包含
7 行增加
和
2 行删除
| @@ -45,6 +45,7 @@ static OfflineTransducerDecoderResult DecodeOne( | @@ -45,6 +45,7 @@ static OfflineTransducerDecoderResult DecodeOne( | ||
| 45 | 45 | ||
| 46 | int32_t vocab_size = model->VocabSize(); | 46 | int32_t vocab_size = model->VocabSize(); |
| 47 | int32_t blank_id = vocab_size - 1; | 47 | int32_t blank_id = vocab_size - 1; |
| 48 | + int32_t max_symbols_per_frame = 10; | ||
| 48 | 49 | ||
| 49 | auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); | 50 | auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); |
| 50 | 51 | ||
| @@ -60,7 +61,8 @@ static OfflineTransducerDecoderResult DecodeOne( | @@ -60,7 +61,8 @@ static OfflineTransducerDecoderResult DecodeOne( | ||
| 60 | memory_info, const_cast<float *>(p) + t * num_cols, num_cols, | 61 | memory_info, const_cast<float *>(p) + t * num_cols, num_cols, |
| 61 | encoder_shape.data(), encoder_shape.size()); | 62 | encoder_shape.data(), encoder_shape.size()); |
| 62 | 63 | ||
| 63 | - Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out), | 64 | + for (int32_t q = 0; q != max_symbols_per_frame; ++q) { |
| 65 | + Ort::Value logit = model->RunJoiner(View(&cur_encoder_out), | ||
| 64 | View(&decoder_output_pair.first)); | 66 | View(&decoder_output_pair.first)); |
| 65 | 67 | ||
| 66 | float *p_logit = logit.GetTensorMutableData<float>(); | 68 | float *p_logit = logit.GetTensorMutableData<float>(); |
| @@ -83,7 +85,10 @@ static OfflineTransducerDecoderResult DecodeOne( | @@ -83,7 +85,10 @@ static OfflineTransducerDecoderResult DecodeOne( | ||
| 83 | model->RunDecoder(std::move(decoder_input_pair.first), | 85 | model->RunDecoder(std::move(decoder_input_pair.first), |
| 84 | std::move(decoder_input_pair.second), | 86 | std::move(decoder_input_pair.second), |
| 85 | std::move(decoder_output_pair.second)); | 87 | std::move(decoder_output_pair.second)); |
| 88 | + } else { | ||
| 89 | + break; | ||
| 86 | } // if (y != blank_id) | 90 | } // if (y != blank_id) |
| 91 | + } | ||
| 87 | } // for (int32_t i = 0; i != num_rows; ++i) | 92 | } // for (int32_t i = 0; i != num_rows; ++i) |
| 88 | 93 | ||
| 89 | return ans; | 94 | return ans; |
| @@ -99,7 +104,7 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( | @@ -99,7 +104,7 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( | ||
| 99 | int32_t dim1 = static_cast<int32_t>(shape[1]); | 104 | int32_t dim1 = static_cast<int32_t>(shape[1]); |
| 100 | int32_t dim2 = static_cast<int32_t>(shape[2]); | 105 | int32_t dim2 = static_cast<int32_t>(shape[2]); |
| 101 | 106 | ||
| 102 | - const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>(); | 107 | + const int32_t *p_length = encoder_out_length.GetTensorData<int32_t>(); |
| 103 | const float *p = encoder_out.GetTensorData<float>(); | 108 | const float *p = encoder_out.GetTensorData<float>(); |
| 104 | 109 | ||
| 105 | std::vector<OfflineTransducerDecoderResult> ans(batch_size); | 110 | std::vector<OfflineTransducerDecoderResult> ans(batch_size); |
-
请 注册 或 登录 后发表评论