Committed by
GitHub
Use log probs for paraformer (#120)
* Use log probs for paraformer
* Fix
正在显示 3 个修改的文件，包含 15 行增加和 11 行删除
| @@ -23,8 +23,7 @@ class OfflineParaformerDecoder { | @@ -23,8 +23,7 @@ class OfflineParaformerDecoder { | ||
| 23 | /** Run beam search given the output from the paraformer model. | 23 | /** Run beam search given the output from the paraformer model. |
| 24 | * | 24 | * |
| 25 | * @param log_probs A 3-D tensor of shape (N, T, vocab_size) | 25 | * @param log_probs A 3-D tensor of shape (N, T, vocab_size) |
| 26 | - * @param token_num A 2-D tensor of shape (N, T). Its dtype is int64_t. | ||
| 27 | - * log_probs[i].argmax(axis=-1) equals to token_num[i] | 26 | + * @param token_num A 1-D tensor of shape (N). token_num equals to T. |
| 28 | * | 27 | * |
| 29 | * @return Return a vector of size `N` containing the decoded results. | 28 | * @return Return a vector of size `N` containing the decoded results. |
| 30 | */ | 29 | */ |
| @@ -4,28 +4,33 @@ | @@ -4,28 +4,33 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h" | 5 | #include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h" |
| 6 | 6 | ||
| 7 | +#include <algorithm> | ||
| 7 | #include <vector> | 8 | #include <vector> |
| 8 | 9 | ||
| 9 | namespace sherpa_onnx { | 10 | namespace sherpa_onnx { |
| 10 | 11 | ||
| 11 | std::vector<OfflineParaformerDecoderResult> | 12 | std::vector<OfflineParaformerDecoderResult> |
| 12 | -OfflineParaformerGreedySearchDecoder::Decode(Ort::Value /*log_probs*/, | ||
| 13 | - Ort::Value token_num) { | ||
| 14 | - std::vector<int64_t> shape = token_num.GetTensorTypeAndShapeInfo().GetShape(); | 13 | +OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs, |
| 14 | + Ort::Value /*token_num*/) { | ||
| 15 | + std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 15 | int32_t batch_size = shape[0]; | 16 | int32_t batch_size = shape[0]; |
| 16 | int32_t num_tokens = shape[1]; | 17 | int32_t num_tokens = shape[1]; |
| 18 | + int32_t vocab_size = shape[2]; | ||
| 17 | 19 | ||
| 18 | std::vector<OfflineParaformerDecoderResult> results(batch_size); | 20 | std::vector<OfflineParaformerDecoderResult> results(batch_size); |
| 19 | 21 | ||
| 20 | - const int64_t *p = token_num.GetTensorData<int64_t>(); | ||
| 21 | for (int32_t i = 0; i != batch_size; ++i) { | 22 | for (int32_t i = 0; i != batch_size; ++i) { |
| 23 | + const float *p = | ||
| 24 | + log_probs.GetTensorData<float>() + i * num_tokens * vocab_size; | ||
| 22 | for (int32_t k = 0; k != num_tokens; ++k) { | 25 | for (int32_t k = 0; k != num_tokens; ++k) { |
| 23 | - if (p[k] == eos_id_) break; | 26 | + auto max_idx = static_cast<int64_t>( |
| 27 | + std::distance(p, std::max_element(p, p + vocab_size))); | ||
| 28 | + if (max_idx == eos_id_) break; | ||
| 24 | 29 | ||
| 25 | - results[i].tokens.push_back(p[k]); | ||
| 26 | - } | 30 | + results[i].tokens.push_back(max_idx); |
| 27 | 31 | ||
| 28 | - p += num_tokens; | 32 | + p += vocab_size; |
| 33 | + } | ||
| 29 | } | 34 | } |
| 30 | 35 | ||
| 31 | return results; | 36 | return results; |
| @@ -17,7 +17,7 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder { | @@ -17,7 +17,7 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder { | ||
| 17 | : eos_id_(eos_id) {} | 17 | : eos_id_(eos_id) {} |
| 18 | 18 | ||
| 19 | std::vector<OfflineParaformerDecoderResult> Decode( | 19 | std::vector<OfflineParaformerDecoderResult> Decode( |
| 20 | - Ort::Value /*log_probs*/, Ort::Value token_num) override; | 20 | + Ort::Value log_probs, Ort::Value /*token_num*/) override; |
| 21 | 21 | ||
| 22 | private: | 22 | private: |
| 23 | int32_t eos_id_; | 23 | int32_t eos_id_; |
-
请 注册 或 登录 后发表评论