Fangjun Kuang
Committed by GitHub

Fix a bug for multilingual ASR (#281)

1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR) 1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
2 project(sherpa-onnx) 2 project(sherpa-onnx)
3 3
4 -set(SHERPA_ONNX_VERSION "1.7.8") 4 +set(SHERPA_ONNX_VERSION "1.7.9")
5 5
6 # Disable warning about 6 # Disable warning about
7 # 7 #
@@ -136,8 +136,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, @@ -136,8 +136,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
136 auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape(); 136 auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
137 int32_t vocab_size = logits_shape[2]; 137 int32_t vocab_size = logits_shape[2];
138 138
139 - int32_t max_token_id = static_cast<int32_t>(std::distance(  
140 - p_logits, std::max_element(p_logits, p_logits + vocab_size))); 139 + const float *p_start = p_logits + (logits_shape[1] - 1) * vocab_size;
  140 +
  141 + int32_t max_token_id = static_cast<int32_t>(
  142 + std::distance(p_start, std::max_element(p_start, p_start + vocab_size)));
141 143
142 int32_t n_text_ctx = model_->TextCtx(); 144 int32_t n_text_ctx = model_->TextCtx();
143 145