Committed by
GitHub
Fix GigaAM transducer encoder output length data type (#2426)
This PR fixes a data type compatibility issue in the GigaAM transducer encoder where the output length tensor could be either int32 or int64, but the code only handled int32. The fix adds runtime type checking and supports both data types.

- Adds runtime detection of the encoder output length tensor data type (int32 vs int64)
- Implements conditional data access based on the detected type
- Adds error handling for unsupported data types
正在显示 1 个修改的文件，包含 39 行增加和 29 行删除
| @@ -62,34 +62,34 @@ static OfflineTransducerDecoderResult DecodeOne( | @@ -62,34 +62,34 @@ static OfflineTransducerDecoderResult DecodeOne( | ||
| 62 | encoder_shape.data(), encoder_shape.size()); | 62 | encoder_shape.data(), encoder_shape.size()); |
| 63 | 63 | ||
| 64 | for (int32_t q = 0; q != max_symbols_per_frame; ++q) { | 64 | for (int32_t q = 0; q != max_symbols_per_frame; ++q) { |
| 65 | - Ort::Value logit = model->RunJoiner(View(&cur_encoder_out), | ||
| 66 | - View(&decoder_output_pair.first)); | ||
| 67 | - | ||
| 68 | - float *p_logit = logit.GetTensorMutableData<float>(); | ||
| 69 | - if (blank_penalty > 0) { | ||
| 70 | - p_logit[blank_id] -= blank_penalty; | ||
| 71 | - } | ||
| 72 | - | ||
| 73 | - auto y = static_cast<int32_t>(std::distance( | ||
| 74 | - static_cast<const float *>(p_logit), | ||
| 75 | - std::max_element(static_cast<const float *>(p_logit), | ||
| 76 | - static_cast<const float *>(p_logit) + vocab_size))); | ||
| 77 | - | ||
| 78 | - if (y != blank_id) { | ||
| 79 | - ans.tokens.push_back(y); | ||
| 80 | - ans.timestamps.push_back(t); | ||
| 81 | - | ||
| 82 | - decoder_input_pair = BuildDecoderInput(y, model->Allocator()); | ||
| 83 | - | ||
| 84 | - decoder_output_pair = | ||
| 85 | - model->RunDecoder(std::move(decoder_input_pair.first), | ||
| 86 | - std::move(decoder_input_pair.second), | ||
| 87 | - std::move(decoder_output_pair.second)); | ||
| 88 | - } else { | ||
| 89 | - break; | ||
| 90 | - } // if (y != blank_id) | 65 | + Ort::Value logit = model->RunJoiner(View(&cur_encoder_out), |
| 66 | + View(&decoder_output_pair.first)); | ||
| 67 | + | ||
| 68 | + float *p_logit = logit.GetTensorMutableData<float>(); | ||
| 69 | + if (blank_penalty > 0) { | ||
| 70 | + p_logit[blank_id] -= blank_penalty; | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + auto y = static_cast<int32_t>(std::distance( | ||
| 74 | + static_cast<const float *>(p_logit), | ||
| 75 | + std::max_element(static_cast<const float *>(p_logit), | ||
| 76 | + static_cast<const float *>(p_logit) + vocab_size))); | ||
| 77 | + | ||
| 78 | + if (y != blank_id) { | ||
| 79 | + ans.tokens.push_back(y); | ||
| 80 | + ans.timestamps.push_back(t); | ||
| 81 | + | ||
| 82 | + decoder_input_pair = BuildDecoderInput(y, model->Allocator()); | ||
| 83 | + | ||
| 84 | + decoder_output_pair = | ||
| 85 | + model->RunDecoder(std::move(decoder_input_pair.first), | ||
| 86 | + std::move(decoder_input_pair.second), | ||
| 87 | + std::move(decoder_output_pair.second)); | ||
| 88 | + } else { | ||
| 89 | + break; | ||
| 90 | + } // if (y != blank_id) | ||
| 91 | } | 91 | } |
| 92 | - } // for (int32_t i = 0; i != num_rows; ++i) | 92 | + } // for (int32_t i = 0; i != num_rows; ++i) |
| 93 | 93 | ||
| 94 | return ans; | 94 | return ans; |
| 95 | } | 95 | } |
| @@ -104,14 +104,24 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( | @@ -104,14 +104,24 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( | ||
| 104 | int32_t dim1 = static_cast<int32_t>(shape[1]); | 104 | int32_t dim1 = static_cast<int32_t>(shape[1]); |
| 105 | int32_t dim2 = static_cast<int32_t>(shape[2]); | 105 | int32_t dim2 = static_cast<int32_t>(shape[2]); |
| 106 | 106 | ||
| 107 | - const int32_t *p_length = encoder_out_length.GetTensorData<int32_t>(); | 107 | + auto length_type = |
| 108 | + encoder_out_length.GetTensorTypeAndShapeInfo().GetElementType(); | ||
| 109 | + if ((length_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) && | ||
| 110 | + (length_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) { | ||
| 111 | + SHERPA_ONNX_LOGE("Unsupported encoder_out_length data type: %d", | ||
| 112 | + static_cast<int32_t>(length_type)); | ||
| 113 | + SHERPA_ONNX_EXIT(-1); | ||
| 114 | + } | ||
| 115 | + | ||
| 108 | const float *p = encoder_out.GetTensorData<float>(); | 116 | const float *p = encoder_out.GetTensorData<float>(); |
| 109 | 117 | ||
| 110 | std::vector<OfflineTransducerDecoderResult> ans(batch_size); | 118 | std::vector<OfflineTransducerDecoderResult> ans(batch_size); |
| 111 | 119 | ||
| 112 | for (int32_t i = 0; i != batch_size; ++i) { | 120 | for (int32_t i = 0; i != batch_size; ++i) { |
| 113 | const float *this_p = p + dim1 * dim2 * i; | 121 | const float *this_p = p + dim1 * dim2 * i; |
| 114 | - int32_t this_len = p_length[i]; | 122 | + int32_t this_len = length_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 |
| 123 | + ? encoder_out_length.GetTensorData<int32_t>()[i] | ||
| 124 | + : encoder_out_length.GetTensorData<int64_t>()[i]; | ||
| 115 | 125 | ||
| 116 | ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_); | 126 | ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_); |
| 117 | } | 127 | } |
-
请注册或登录后发表评论