Nickolay V. Shmyrev
Committed by GitHub

Implement max_symbols_per_frame for GigaAM2 accurate decoding since model uses c…

…har tokens instead of BPE. (#2423)
@@ -45,6 +45,7 @@ static OfflineTransducerDecoderResult DecodeOne( @@ -45,6 +45,7 @@ static OfflineTransducerDecoderResult DecodeOne(
45 45
46 int32_t vocab_size = model->VocabSize(); 46 int32_t vocab_size = model->VocabSize();
47 int32_t blank_id = vocab_size - 1; 47 int32_t blank_id = vocab_size - 1;
  48 + int32_t max_symbols_per_frame = 10;
48 49
49 auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); 50 auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
50 51
@@ -60,7 +61,8 @@ static OfflineTransducerDecoderResult DecodeOne( @@ -60,7 +61,8 @@ static OfflineTransducerDecoderResult DecodeOne(
60 memory_info, const_cast<float *>(p) + t * num_cols, num_cols, 61 memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
61 encoder_shape.data(), encoder_shape.size()); 62 encoder_shape.data(), encoder_shape.size());
62 63
63 - Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out), 64 + for (int32_t q = 0; q != max_symbols_per_frame; ++q) {
  65 + Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
64 View(&decoder_output_pair.first)); 66 View(&decoder_output_pair.first));
65 67
66 float *p_logit = logit.GetTensorMutableData<float>(); 68 float *p_logit = logit.GetTensorMutableData<float>();
@@ -83,7 +85,10 @@ static OfflineTransducerDecoderResult DecodeOne( @@ -83,7 +85,10 @@ static OfflineTransducerDecoderResult DecodeOne(
83 model->RunDecoder(std::move(decoder_input_pair.first), 85 model->RunDecoder(std::move(decoder_input_pair.first),
84 std::move(decoder_input_pair.second), 86 std::move(decoder_input_pair.second),
85 std::move(decoder_output_pair.second)); 87 std::move(decoder_output_pair.second));
  88 + } else {
  89 + break;
86 } // if (y != blank_id) 90 } // if (y != blank_id)
  91 + }
87 } // for (int32_t i = 0; i != num_rows; ++i) 92 } // for (int32_t i = 0; i != num_rows; ++i)
88 93
89 return ans; 94 return ans;
@@ -99,7 +104,7 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( @@ -99,7 +104,7 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode(
99 int32_t dim1 = static_cast<int32_t>(shape[1]); 104 int32_t dim1 = static_cast<int32_t>(shape[1]);
100 int32_t dim2 = static_cast<int32_t>(shape[2]); 105 int32_t dim2 = static_cast<int32_t>(shape[2]);
101 106
102 - const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>(); 107 + const int32_t *p_length = encoder_out_length.GetTensorData<int32_t>();
103 const float *p = encoder_out.GetTensorData<float>(); 108 const float *p = encoder_out.GetTensorData<float>();
104 109
105 std::vector<OfflineTransducerDecoderResult> ans(batch_size); 110 std::vector<OfflineTransducerDecoderResult> ans(batch_size);