Nickolay V. Shmyrev
Committed by GitHub

Implement max_symbols_per_frame for GigaAM2 accurate decoding since the model uses char tokens instead of BPE. (#2423)
... ... @@ -45,6 +45,7 @@ static OfflineTransducerDecoderResult DecodeOne(
int32_t vocab_size = model->VocabSize();
int32_t blank_id = vocab_size - 1;
int32_t max_symbols_per_frame = 10;
auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
... ... @@ -60,30 +61,34 @@ static OfflineTransducerDecoderResult DecodeOne(
memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
encoder_shape.data(), encoder_shape.size());
Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
View(&decoder_output_pair.first));
float *p_logit = logit.GetTensorMutableData<float>();
if (blank_penalty > 0) {
p_logit[blank_id] -= blank_penalty;
for (int32_t q = 0; q != max_symbols_per_frame; ++q) {
Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
View(&decoder_output_pair.first));
float *p_logit = logit.GetTensorMutableData<float>();
if (blank_penalty > 0) {
p_logit[blank_id] -= blank_penalty;
}
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));
if (y != blank_id) {
ans.tokens.push_back(y);
ans.timestamps.push_back(t);
decoder_input_pair = BuildDecoderInput(y, model->Allocator());
decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_input_pair.second),
std::move(decoder_output_pair.second));
} else {
break;
} // if (y != blank_id)
}
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));
if (y != blank_id) {
ans.tokens.push_back(y);
ans.timestamps.push_back(t);
decoder_input_pair = BuildDecoderInput(y, model->Allocator());
decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_input_pair.second),
std::move(decoder_output_pair.second));
} // if (y != blank_id)
} // for (int32_t i = 0; i != num_rows; ++i)
return ans;
... ... @@ -99,7 +104,7 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode(
int32_t dim1 = static_cast<int32_t>(shape[1]);
int32_t dim2 = static_cast<int32_t>(shape[2]);
const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();
const int32_t *p_length = encoder_out_length.GetTensorData<int32_t>();
const float *p = encoder_out.GetTensorData<float>();
std::vector<OfflineTransducerDecoderResult> ans(batch_size);
... ...