Fangjun Kuang
Committed by GitHub

Fix GigaAM transducer encoder output length data type (#2426)

This PR fixes a data type compatibility issue with the GigaAM transducer encoder: its output length tensor can be either int32 or int64, but the NeMo greedy search decoder always read it as int32. The fix checks the tensor's element type at runtime and supports both data types.

- Adds runtime detection of encoder output length tensor data type (int32 vs int64)
- Implements conditional data access based on the detected type
- Adds error handling for unsupported data types (logs the unsupported element type and exits)
... ... @@ -62,34 +62,34 @@ static OfflineTransducerDecoderResult DecodeOne(
encoder_shape.data(), encoder_shape.size());
for (int32_t q = 0; q != max_symbols_per_frame; ++q) {
Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
View(&decoder_output_pair.first));
float *p_logit = logit.GetTensorMutableData<float>();
if (blank_penalty > 0) {
p_logit[blank_id] -= blank_penalty;
}
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));
if (y != blank_id) {
ans.tokens.push_back(y);
ans.timestamps.push_back(t);
decoder_input_pair = BuildDecoderInput(y, model->Allocator());
decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_input_pair.second),
std::move(decoder_output_pair.second));
} else {
break;
} // if (y != blank_id)
Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
View(&decoder_output_pair.first));
float *p_logit = logit.GetTensorMutableData<float>();
if (blank_penalty > 0) {
p_logit[blank_id] -= blank_penalty;
}
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));
if (y != blank_id) {
ans.tokens.push_back(y);
ans.timestamps.push_back(t);
decoder_input_pair = BuildDecoderInput(y, model->Allocator());
decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_input_pair.second),
std::move(decoder_output_pair.second));
} else {
break;
} // if (y != blank_id)
}
} // for (int32_t i = 0; i != num_rows; ++i)
} // for (int32_t i = 0; i != num_rows; ++i)
return ans;
}
... ... @@ -104,14 +104,24 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode(
int32_t dim1 = static_cast<int32_t>(shape[1]);
int32_t dim2 = static_cast<int32_t>(shape[2]);
const int32_t *p_length = encoder_out_length.GetTensorData<int32_t>();
auto length_type =
encoder_out_length.GetTensorTypeAndShapeInfo().GetElementType();
if ((length_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) &&
(length_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) {
SHERPA_ONNX_LOGE("Unsupported encoder_out_length data type: %d",
static_cast<int32_t>(length_type));
SHERPA_ONNX_EXIT(-1);
}
const float *p = encoder_out.GetTensorData<float>();
std::vector<OfflineTransducerDecoderResult> ans(batch_size);
for (int32_t i = 0; i != batch_size; ++i) {
const float *this_p = p + dim1 * dim2 * i;
int32_t this_len = p_length[i];
int32_t this_len = length_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
? encoder_out_length.GetTensorData<int32_t>()[i]
: encoder_out_length.GetTensorData<int64_t>()[i];
ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_);
}
... ...