正在显示
2 个修改的文件
包含
5 行增加
和
3 行删除
| @@ -136,8 +136,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | @@ -136,8 +136,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | ||
| 136 | auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape(); | 136 | auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape(); |
| 137 | int32_t vocab_size = logits_shape[2]; | 137 | int32_t vocab_size = logits_shape[2]; |
| 138 | 138 | ||
| 139 | - int32_t max_token_id = static_cast<int32_t>(std::distance( | ||
| 140 | - p_logits, std::max_element(p_logits, p_logits + vocab_size))); | 139 | + const float *p_start = p_logits + (logits_shape[1] - 1) * vocab_size; |
| 140 | + | ||
| 141 | + int32_t max_token_id = static_cast<int32_t>( | ||
| 142 | + std::distance(p_start, std::max_element(p_start, p_start + vocab_size))); | ||
| 141 | 143 | ||
| 142 | int32_t n_text_ctx = model_->TextCtx(); | 144 | int32_t n_text_ctx = model_->TextCtx(); |
| 143 | 145 |
-
请 注册 或 登录 后发表评论