Fangjun Kuang
Committed by GitHub

Fix GigaAM transducer encoder output length data type (#2426)

This PR fixes a data type compatibility issue in the GigaAM transducer encoder where the output length tensor could be either int32 or int64, but the code only handled int32. The fix adds runtime type checking and supports both data types.

- Adds runtime detection of encoder output length tensor data type (int32 vs int64)
- Implements conditional data access based on the detected type
- Adds error handling for unsupported data types
@@ -62,34 +62,34 @@ static OfflineTransducerDecoderResult DecodeOne( @@ -62,34 +62,34 @@ static OfflineTransducerDecoderResult DecodeOne(
62 encoder_shape.data(), encoder_shape.size()); 62 encoder_shape.data(), encoder_shape.size());
63 63
64 for (int32_t q = 0; q != max_symbols_per_frame; ++q) { 64 for (int32_t q = 0; q != max_symbols_per_frame; ++q) {
65 - Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),  
66 - View(&decoder_output_pair.first));  
67 -  
68 - float *p_logit = logit.GetTensorMutableData<float>();  
69 - if (blank_penalty > 0) {  
70 - p_logit[blank_id] -= blank_penalty;  
71 - }  
72 -  
73 - auto y = static_cast<int32_t>(std::distance(  
74 - static_cast<const float *>(p_logit),  
75 - std::max_element(static_cast<const float *>(p_logit),  
76 - static_cast<const float *>(p_logit) + vocab_size)));  
77 -  
78 - if (y != blank_id) {  
79 - ans.tokens.push_back(y);  
80 - ans.timestamps.push_back(t);  
81 -  
82 - decoder_input_pair = BuildDecoderInput(y, model->Allocator());  
83 -  
84 - decoder_output_pair =  
85 - model->RunDecoder(std::move(decoder_input_pair.first),  
86 - std::move(decoder_input_pair.second),  
87 - std::move(decoder_output_pair.second));  
88 - } else {  
89 - break;  
90 - } // if (y != blank_id) 65 + Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
  66 + View(&decoder_output_pair.first));
  67 +
  68 + float *p_logit = logit.GetTensorMutableData<float>();
  69 + if (blank_penalty > 0) {
  70 + p_logit[blank_id] -= blank_penalty;
  71 + }
  72 +
  73 + auto y = static_cast<int32_t>(std::distance(
  74 + static_cast<const float *>(p_logit),
  75 + std::max_element(static_cast<const float *>(p_logit),
  76 + static_cast<const float *>(p_logit) + vocab_size)));
  77 +
  78 + if (y != blank_id) {
  79 + ans.tokens.push_back(y);
  80 + ans.timestamps.push_back(t);
  81 +
  82 + decoder_input_pair = BuildDecoderInput(y, model->Allocator());
  83 +
  84 + decoder_output_pair =
  85 + model->RunDecoder(std::move(decoder_input_pair.first),
  86 + std::move(decoder_input_pair.second),
  87 + std::move(decoder_output_pair.second));
  88 + } else {
  89 + break;
  90 + } // if (y != blank_id)
91 } 91 }
92 - } // for (int32_t i = 0; i != num_rows; ++i) 92 + } // for (int32_t i = 0; i != num_rows; ++i)
93 93
94 return ans; 94 return ans;
95 } 95 }
@@ -104,14 +104,24 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( @@ -104,14 +104,24 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode(
104 int32_t dim1 = static_cast<int32_t>(shape[1]); 104 int32_t dim1 = static_cast<int32_t>(shape[1]);
105 int32_t dim2 = static_cast<int32_t>(shape[2]); 105 int32_t dim2 = static_cast<int32_t>(shape[2]);
106 106
107 - const int32_t *p_length = encoder_out_length.GetTensorData<int32_t>(); 107 + auto length_type =
  108 + encoder_out_length.GetTensorTypeAndShapeInfo().GetElementType();
  109 + if ((length_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) &&
  110 + (length_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) {
  111 + SHERPA_ONNX_LOGE("Unsupported encoder_out_length data type: %d",
  112 + static_cast<int32_t>(length_type));
  113 + SHERPA_ONNX_EXIT(-1);
  114 + }
  115 +
108 const float *p = encoder_out.GetTensorData<float>(); 116 const float *p = encoder_out.GetTensorData<float>();
109 117
110 std::vector<OfflineTransducerDecoderResult> ans(batch_size); 118 std::vector<OfflineTransducerDecoderResult> ans(batch_size);
111 119
112 for (int32_t i = 0; i != batch_size; ++i) { 120 for (int32_t i = 0; i != batch_size; ++i) {
113 const float *this_p = p + dim1 * dim2 * i; 121 const float *this_p = p + dim1 * dim2 * i;
114 - int32_t this_len = p_length[i]; 122 + int32_t this_len = length_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
  123 + ? encoder_out_length.GetTensorData<int32_t>()[i]
  124 + : encoder_out_length.GetTensorData<int64_t>()[i];
115 125
116 ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_); 126 ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_);
117 } 127 }