Nickolay V. Shmyrev
Committed by GitHub

Implement max_symbols_per_frame for GigaAM2 accurate decoding since model uses c…

…har tokens instead of BPE. (#2423)
@@ -45,6 +45,7 @@ static OfflineTransducerDecoderResult DecodeOne( @@ -45,6 +45,7 @@ static OfflineTransducerDecoderResult DecodeOne(
45 45
46 int32_t vocab_size = model->VocabSize(); 46 int32_t vocab_size = model->VocabSize();
47 int32_t blank_id = vocab_size - 1; 47 int32_t blank_id = vocab_size - 1;
  48 + int32_t max_symbols_per_frame = 10;
48 49
49 auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); 50 auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
50 51
@@ -60,30 +61,34 @@ static OfflineTransducerDecoderResult DecodeOne( @@ -60,30 +61,34 @@ static OfflineTransducerDecoderResult DecodeOne(
60 memory_info, const_cast<float *>(p) + t * num_cols, num_cols, 61 memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
61 encoder_shape.data(), encoder_shape.size()); 62 encoder_shape.data(), encoder_shape.size());
62 63
63 - Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),  
64 - View(&decoder_output_pair.first));  
65 -  
66 - float *p_logit = logit.GetTensorMutableData<float>();  
67 - if (blank_penalty > 0) {  
68 - p_logit[blank_id] -= blank_penalty; 64 + for (int32_t q = 0; q != max_symbols_per_frame; ++q) {
  65 + Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
  66 + View(&decoder_output_pair.first));
  67 +
  68 + float *p_logit = logit.GetTensorMutableData<float>();
  69 + if (blank_penalty > 0) {
  70 + p_logit[blank_id] -= blank_penalty;
  71 + }
  72 +
  73 + auto y = static_cast<int32_t>(std::distance(
  74 + static_cast<const float *>(p_logit),
  75 + std::max_element(static_cast<const float *>(p_logit),
  76 + static_cast<const float *>(p_logit) + vocab_size)));
  77 +
  78 + if (y != blank_id) {
  79 + ans.tokens.push_back(y);
  80 + ans.timestamps.push_back(t);
  81 +
  82 + decoder_input_pair = BuildDecoderInput(y, model->Allocator());
  83 +
  84 + decoder_output_pair =
  85 + model->RunDecoder(std::move(decoder_input_pair.first),
  86 + std::move(decoder_input_pair.second),
  87 + std::move(decoder_output_pair.second));
  88 + } else {
  89 + break;
  90 + } // if (y != blank_id)
69 } 91 }
70 -  
71 - auto y = static_cast<int32_t>(std::distance(  
72 - static_cast<const float *>(p_logit),  
73 - std::max_element(static_cast<const float *>(p_logit),  
74 - static_cast<const float *>(p_logit) + vocab_size)));  
75 -  
76 - if (y != blank_id) {  
77 - ans.tokens.push_back(y);  
78 - ans.timestamps.push_back(t);  
79 -  
80 - decoder_input_pair = BuildDecoderInput(y, model->Allocator());  
81 -  
82 - decoder_output_pair =  
83 - model->RunDecoder(std::move(decoder_input_pair.first),  
84 - std::move(decoder_input_pair.second),  
85 - std::move(decoder_output_pair.second));  
86 - } // if (y != blank_id)  
87 } // for (int32_t i = 0; i != num_rows; ++i) 92 } // for (int32_t i = 0; i != num_rows; ++i)
88 93
89 return ans; 94 return ans;
@@ -99,7 +104,7 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( @@ -99,7 +104,7 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode(
99 int32_t dim1 = static_cast<int32_t>(shape[1]); 104 int32_t dim1 = static_cast<int32_t>(shape[1]);
100 int32_t dim2 = static_cast<int32_t>(shape[2]); 105 int32_t dim2 = static_cast<int32_t>(shape[2]);
101 106
102 - const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>(); 107 + const int32_t *p_length = encoder_out_length.GetTensorData<int32_t>();
103 const float *p = encoder_out.GetTensorData<float>(); 108 const float *p = encoder_out.GetTensorData<float>();
104 109
105 std::vector<OfflineTransducerDecoderResult> ans(batch_size); 110 std::vector<OfflineTransducerDecoderResult> ans(batch_size);