Committed by
GitHub
Share the GetHypsRowSplits interface and fix top-k selection not taking each hypothesis's log_prob into account (#131)
正在显示
4 个修改的文件
包含
51 行增加
和
37 行删除
| @@ -66,4 +66,19 @@ std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm) const { | @@ -66,4 +66,19 @@ std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm) const { | ||
| 66 | return {all_hyps.begin(), all_hyps.begin() + k}; | 66 | return {all_hyps.begin(), all_hyps.begin() + k}; |
| 67 | } | 67 | } |
| 68 | 68 | ||
| 69 | +const std::vector<int32_t> GetHypsRowSplits( | ||
| 70 | + const std::vector<Hypotheses> &hyps) { | ||
| 71 | + std::vector<int32_t> row_splits; | ||
| 72 | + row_splits.reserve(hyps.size() + 1); | ||
| 73 | + | ||
| 74 | + row_splits.push_back(0); | ||
| 75 | + int32_t s = 0; | ||
| 76 | + for (const auto &h : hyps) { | ||
| 77 | + s += h.Size(); | ||
| 78 | + row_splits.push_back(s); | ||
| 79 | + } | ||
| 80 | + | ||
| 81 | + return row_splits; | ||
| 82 | +} | ||
| 83 | + | ||
| 69 | } // namespace sherpa_onnx | 84 | } // namespace sherpa_onnx |
| @@ -121,6 +121,9 @@ class Hypotheses { | @@ -121,6 +121,9 @@ class Hypotheses { | ||
| 121 | Map hyps_dict_; | 121 | Map hyps_dict_; |
| 122 | }; | 122 | }; |
| 123 | 123 | ||
| 124 | +const std::vector<int32_t> GetHypsRowSplits( | ||
| 125 | + const std::vector<Hypotheses> &hyps); | ||
| 126 | + | ||
| 124 | } // namespace sherpa_onnx | 127 | } // namespace sherpa_onnx |
| 125 | 128 | ||
| 126 | #endif // SHERPA_ONNX_CSRC_HYPOTHESIS_H_ | 129 | #endif // SHERPA_ONNX_CSRC_HYPOTHESIS_H_ |
| @@ -15,21 +15,6 @@ | @@ -15,21 +15,6 @@ | ||
| 15 | 15 | ||
| 16 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 17 | 17 | ||
| 18 | -static std::vector<int32_t> GetHypsRowSplits( | ||
| 19 | - const std::vector<Hypotheses> &hyps) { | ||
| 20 | - std::vector<int32_t> row_splits; | ||
| 21 | - row_splits.reserve(hyps.size() + 1); | ||
| 22 | - | ||
| 23 | - row_splits.push_back(0); | ||
| 24 | - int32_t s = 0; | ||
| 25 | - for (const auto &h : hyps) { | ||
| 26 | - s += h.Size(); | ||
| 27 | - row_splits.push_back(s); | ||
| 28 | - } | ||
| 29 | - | ||
| 30 | - return row_splits; | ||
| 31 | -} | ||
| 32 | - | ||
| 33 | std::vector<OfflineTransducerDecoderResult> | 18 | std::vector<OfflineTransducerDecoderResult> |
| 34 | OfflineTransducerModifiedBeamSearchDecoder::Decode( | 19 | OfflineTransducerModifiedBeamSearchDecoder::Decode( |
| 35 | Ort::Value encoder_out, Ort::Value encoder_out_length) { | 20 | Ort::Value encoder_out, Ort::Value encoder_out_length) { |
| @@ -14,7 +14,7 @@ | @@ -14,7 +14,7 @@ | ||
| 14 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 15 | 15 | ||
| 16 | static void UseCachedDecoderOut( | 16 | static void UseCachedDecoderOut( |
| 17 | - const std::vector<int32_t> &hyps_num_split, | 17 | + const std::vector<int32_t> &hyps_row_splits, |
| 18 | const std::vector<OnlineTransducerDecoderResult> &results, | 18 | const std::vector<OnlineTransducerDecoderResult> &results, |
| 19 | int32_t context_size, Ort::Value *decoder_out) { | 19 | int32_t context_size, Ort::Value *decoder_out) { |
| 20 | std::vector<int64_t> shape = | 20 | std::vector<int64_t> shape = |
| @@ -24,7 +24,7 @@ static void UseCachedDecoderOut( | @@ -24,7 +24,7 @@ static void UseCachedDecoderOut( | ||
| 24 | 24 | ||
| 25 | int32_t batch_size = static_cast<int32_t>(results.size()); | 25 | int32_t batch_size = static_cast<int32_t>(results.size()); |
| 26 | for (int32_t i = 0; i != batch_size; ++i) { | 26 | for (int32_t i = 0; i != batch_size; ++i) { |
| 27 | - int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i]; | 27 | + int32_t num_hyps = hyps_row_splits[i + 1] - hyps_row_splits[i]; |
| 28 | if (num_hyps > 1 || !results[i].decoder_out) { | 28 | if (num_hyps > 1 || !results[i].decoder_out) { |
| 29 | dst += num_hyps * shape[1]; | 29 | dst += num_hyps * shape[1]; |
| 30 | continue; | 30 | continue; |
| @@ -86,17 +86,14 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -86,17 +86,14 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 86 | for (int32_t t = 0; t != num_frames; ++t) { | 86 | for (int32_t t = 0; t != num_frames; ++t) { |
| 87 | // Due to merging paths with identical token sequences, | 87 | // Due to merging paths with identical token sequences, |
| 88 | // not all utterances have "num_active_paths" paths. | 88 | // not all utterances have "num_active_paths" paths. |
| 89 | - int32_t hyps_num_acc = 0; | ||
| 90 | - std::vector<int32_t> hyps_num_split; | ||
| 91 | - hyps_num_split.push_back(0); | ||
| 92 | - | 89 | + auto hyps_row_splits = GetHypsRowSplits(cur); |
| 90 | + int32_t num_hyps = | ||
| | 91 | + hyps_row_splits.back(); // total number of hyps across all utterances | |
| 93 | prev.clear(); | 92 | prev.clear(); |
| 94 | for (auto &hyps : cur) { | 93 | for (auto &hyps : cur) { |
| 95 | for (auto &h : hyps) { | 94 | for (auto &h : hyps) { |
| 96 | prev.push_back(std::move(h.second)); | 95 | prev.push_back(std::move(h.second)); |
| 97 | - hyps_num_acc++; | ||
| 98 | } | 96 | } |
| 99 | - hyps_num_split.push_back(hyps_num_acc); | ||
| 100 | } | 97 | } |
| 101 | cur.clear(); | 98 | cur.clear(); |
| 102 | cur.reserve(batch_size); | 99 | cur.reserve(batch_size); |
| @@ -104,30 +101,44 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -104,30 +101,44 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 104 | Ort::Value decoder_input = model_->BuildDecoderInput(prev); | 101 | Ort::Value decoder_input = model_->BuildDecoderInput(prev); |
| 105 | Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); | 102 | Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); |
| 106 | if (t == 0) { | 103 | if (t == 0) { |
| 107 | - UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(), | 104 | + UseCachedDecoderOut(hyps_row_splits, *result, model_->ContextSize(), |
| 108 | &decoder_out); | 105 | &decoder_out); |
| 109 | } | 106 | } |
| 110 | 107 | ||
| 111 | Ort::Value cur_encoder_out = | 108 | Ort::Value cur_encoder_out = |
| 112 | GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); | 109 | GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); |
| 113 | cur_encoder_out = | 110 | cur_encoder_out = |
| 114 | - Repeat(model_->Allocator(), &cur_encoder_out, hyps_num_split); | 111 | + Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); |
| 115 | Ort::Value logit = model_->RunJoiner( | 112 | Ort::Value logit = model_->RunJoiner( |
| 116 | std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); | 113 | std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); |
| 114 | + | ||
| 117 | float *p_logit = logit.GetTensorMutableData<float>(); | 115 | float *p_logit = logit.GetTensorMutableData<float>(); |
| 116 | + LogSoftmax(p_logit, vocab_size, num_hyps); | ||
| 117 | + | ||
| 118 | + // now p_logit contains log_softmax output, we rename it to p_logprob | ||
| 119 | + // to match what it actually contains | ||
| 120 | + float *p_logprob = p_logit; | ||
| 121 | + | ||
| 122 | + // add log_prob of each hypothesis to p_logprob before taking top_k | ||
| 123 | + for (int32_t i = 0; i != num_hyps; ++i) { | ||
| 124 | + float log_prob = prev[i].log_prob; | ||
| 125 | + for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { | ||
| 126 | + *p_logprob += log_prob; | ||
| 127 | + } | ||
| 128 | + } | ||
| 129 | + p_logprob = p_logit; // we changed p_logprob in the above for loop | ||
| 118 | 130 | ||
| 119 | - for (int32_t b = 0; b < batch_size; ++b) { | 131 | + for (int32_t b = 0; b != batch_size; ++b) { |
| 120 | int32_t frame_offset = (*result)[b].frame_offset; | 132 | int32_t frame_offset = (*result)[b].frame_offset; |
| 121 | - int32_t start = hyps_num_split[b]; | ||
| 122 | - int32_t end = hyps_num_split[b + 1]; | ||
| 123 | - LogSoftmax(p_logit, vocab_size, (end - start)); | 133 | + int32_t start = hyps_row_splits[b]; |
| 134 | + int32_t end = hyps_row_splits[b + 1]; | ||
| 124 | auto topk = | 135 | auto topk = |
| 125 | - TopkIndex(p_logit, vocab_size * (end - start), max_active_paths_); | 136 | + TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_); |
| 126 | 137 | ||
| 127 | Hypotheses hyps; | 138 | Hypotheses hyps; |
| 128 | - for (auto i : topk) { | ||
| 129 | - int32_t hyp_index = i / vocab_size + start; | ||
| 130 | - int32_t new_token = i % vocab_size; | 139 | + for (auto k : topk) { |
| 140 | + int32_t hyp_index = k / vocab_size + start; | ||
| 141 | + int32_t new_token = k % vocab_size; | ||
| 131 | 142 | ||
| 132 | Hypothesis new_hyp = prev[hyp_index]; | 143 | Hypothesis new_hyp = prev[hyp_index]; |
| 133 | if (new_token != 0) { | 144 | if (new_token != 0) { |
| @@ -137,12 +148,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -137,12 +148,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 137 | } else { | 148 | } else { |
| 138 | ++new_hyp.num_trailing_blanks; | 149 | ++new_hyp.num_trailing_blanks; |
| 139 | } | 150 | } |
| 140 | - new_hyp.log_prob += p_logit[i]; | 151 | + new_hyp.log_prob = p_logprob[k]; |
| 141 | hyps.Add(std::move(new_hyp)); | 152 | hyps.Add(std::move(new_hyp)); |
| 142 | - } | 153 | + } // for (auto k : topk) |
| 143 | cur.push_back(std::move(hyps)); | 154 | cur.push_back(std::move(hyps)); |
| 144 | - p_logit += vocab_size * (end - start); | ||
| 145 | - } | 155 | + p_logprob += (end - start) * vocab_size; |
| 156 | + } // for (int32_t b = 0; b != batch_size; ++b) | ||
| 146 | } | 157 | } |
| 147 | 158 | ||
| 148 | for (int32_t b = 0; b != batch_size; ++b) { | 159 | for (int32_t b = 0; b != batch_size; ++b) { |
-
请 注册 或 登录 后发表评论