Nickolay V. Shmyrev
Committed by GitHub

Implement max_symbols_per_frame for GigaAM2 accurate decoding since model uses char tokens instead of BPE. (#2423)
... ... @@ -45,6 +45,7 @@ static OfflineTransducerDecoderResult DecodeOne(
int32_t vocab_size = model->VocabSize();
int32_t blank_id = vocab_size - 1;
int32_t max_symbols_per_frame = 10;
auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
... ... @@ -60,7 +61,8 @@ static OfflineTransducerDecoderResult DecodeOne(
memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
encoder_shape.data(), encoder_shape.size());
Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
for (int32_t q = 0; q != max_symbols_per_frame; ++q) {
Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
View(&decoder_output_pair.first));
float *p_logit = logit.GetTensorMutableData<float>();
... ... @@ -83,7 +85,10 @@ static OfflineTransducerDecoderResult DecodeOne(
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_input_pair.second),
std::move(decoder_output_pair.second));
} else {
break;
} // if (y != blank_id)
}
} // for (int32_t i = 0; i != num_rows; ++i)
return ans;
... ... @@ -99,7 +104,7 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode(
int32_t dim1 = static_cast<int32_t>(shape[1]);
int32_t dim2 = static_cast<int32_t>(shape[2]);
const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();
const int32_t *p_length = encoder_out_length.GetTensorData<int32_t>();
const float *p = encoder_out.GetTensorData<float>();
std::vector<OfflineTransducerDecoderResult> ans(batch_size);
... ...