彭震东
Committed by GitHub

Use log probs for paraformer (#120)

* Use log probs for paraformer

* Fix
@@ -23,8 +23,7 @@ class OfflineParaformerDecoder { @@ -23,8 +23,7 @@ class OfflineParaformerDecoder {
23 /** Run beam search given the output from the paraformer model. 23 /** Run beam search given the output from the paraformer model.
24 * 24 *
25 * @param log_probs A 3-D tensor of shape (N, T, vocab_size) 25 * @param log_probs A 3-D tensor of shape (N, T, vocab_size)
26 - * @param token_num A 2-D tensor of shape (N, T). Its dtype is int64_t.  
27 - * log_probs[i].argmax(axis=-1) equals to token_num[i] 26 + * @param token_num A 1-D tensor of shape (N). token_num equals T.
28 * 27 *
29 * @return Return a vector of size `N` containing the decoded results. 28 * @return Return a vector of size `N` containing the decoded results.
30 */ 29 */
@@ -4,28 +4,33 @@ @@ -4,28 +4,33 @@
4 4
5 #include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h" 5 #include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
6 6
  7 +#include <algorithm>
7 #include <vector> 8 #include <vector>
8 9
9 namespace sherpa_onnx { 10 namespace sherpa_onnx {
10 11
11 std::vector<OfflineParaformerDecoderResult> 12 std::vector<OfflineParaformerDecoderResult>
12 -OfflineParaformerGreedySearchDecoder::Decode(Ort::Value /*log_probs*/,  
13 - Ort::Value token_num) {  
14 - std::vector<int64_t> shape = token_num.GetTensorTypeAndShapeInfo().GetShape(); 13 +OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
  14 + Ort::Value /*token_num*/) {
  15 + std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
15 int32_t batch_size = shape[0]; 16 int32_t batch_size = shape[0];
16 int32_t num_tokens = shape[1]; 17 int32_t num_tokens = shape[1];
  18 + int32_t vocab_size = shape[2];
17 19
18 std::vector<OfflineParaformerDecoderResult> results(batch_size); 20 std::vector<OfflineParaformerDecoderResult> results(batch_size);
19 21
20 - const int64_t *p = token_num.GetTensorData<int64_t>();  
21 for (int32_t i = 0; i != batch_size; ++i) { 22 for (int32_t i = 0; i != batch_size; ++i) {
  23 + const float *p =
  24 + log_probs.GetTensorData<float>() + i * num_tokens * vocab_size;
22 for (int32_t k = 0; k != num_tokens; ++k) { 25 for (int32_t k = 0; k != num_tokens; ++k) {
23 - if (p[k] == eos_id_) break; 26 + auto max_idx = static_cast<int64_t>(
  27 + std::distance(p, std::max_element(p, p + vocab_size)));
  28 + if (max_idx == eos_id_) break;
24 29
25 - results[i].tokens.push_back(p[k]);  
26 - } 30 + results[i].tokens.push_back(max_idx);
27 31
28 - p += num_tokens; 32 + p += vocab_size;
  33 + }
29 } 34 }
30 35
31 return results; 36 return results;
@@ -17,7 +17,7 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder { @@ -17,7 +17,7 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder {
17 : eos_id_(eos_id) {} 17 : eos_id_(eos_id) {}
18 18
19 std::vector<OfflineParaformerDecoderResult> Decode( 19 std::vector<OfflineParaformerDecoderResult> Decode(
20 - Ort::Value /*log_probs*/, Ort::Value token_num) override; 20 + Ort::Value log_probs, Ort::Value /*token_num*/) override;
21 21
22 private: 22 private:
23 int32_t eos_id_; 23 int32_t eos_id_;