正在显示
5 个修改的文件
包含
10 行增加
和
5 行删除
| @@ -30,8 +30,9 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, | @@ -30,8 +30,9 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, | ||
| 30 | 30 | ||
| 31 | std::vector<OfflineTransducerDecoderResult> ans(batch_size); | 31 | std::vector<OfflineTransducerDecoderResult> ans(batch_size); |
| 32 | for (auto &r : ans) { | 32 | for (auto &r : ans) { |
| 33 | + r.tokens.resize(context_size, -1); | ||
| 33 | // 0 is the ID of the blank token | 34 | // 0 is the ID of the blank token |
| 34 | - r.tokens.resize(context_size, 0); | 35 | + r.tokens.back() = 0; |
| 35 | } | 36 | } |
| 36 | 37 | ||
| 37 | auto decoder_input = model_->BuildDecoderInput(ans, ans.size()); | 38 | auto decoder_input = model_->BuildDecoderInput(ans, ans.size()); |
| @@ -32,7 +32,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | @@ -32,7 +32,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 32 | int32_t vocab_size = model_->VocabSize(); | 32 | int32_t vocab_size = model_->VocabSize(); |
| 33 | int32_t context_size = model_->ContextSize(); | 33 | int32_t context_size = model_->ContextSize(); |
| 34 | 34 | ||
| 35 | - std::vector<int64_t> blanks(context_size, 0); | 35 | + std::vector<int64_t> blanks(context_size, -1); |
| 36 | + blanks.back() = 0; | ||
| 36 | 37 | ||
| 37 | std::deque<Hypotheses> finalized; | 38 | std::deque<Hypotheses> finalized; |
| 38 | std::vector<Hypotheses> cur; | 39 | std::vector<Hypotheses> cur; |
| @@ -55,7 +55,8 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() const { | @@ -55,7 +55,8 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() const { | ||
| 55 | int32_t context_size = model_->ContextSize(); | 55 | int32_t context_size = model_->ContextSize(); |
| 56 | int32_t blank_id = 0; // always 0 | 56 | int32_t blank_id = 0; // always 0 |
| 57 | OnlineTransducerDecoderResult r; | 57 | OnlineTransducerDecoderResult r; |
| 58 | - r.tokens.resize(context_size, blank_id); | 58 | + r.tokens.resize(context_size, -1); |
| 59 | + r.tokens.back() = blank_id; | ||
| 59 | 60 | ||
| 60 | return r; | 61 | return r; |
| 61 | } | 62 | } |
| @@ -42,7 +42,9 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { | @@ -42,7 +42,9 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { | ||
| 42 | int32_t context_size = model_->ContextSize(); | 42 | int32_t context_size = model_->ContextSize(); |
| 43 | int32_t blank_id = 0; // always 0 | 43 | int32_t blank_id = 0; // always 0 |
| 44 | OnlineTransducerDecoderResult r; | 44 | OnlineTransducerDecoderResult r; |
| 45 | - std::vector<int64_t> blanks(context_size, blank_id); | 45 | + std::vector<int64_t> blanks(context_size, -1); |
| 46 | + blanks.back() = blank_id; | ||
| 47 | + | ||
| 46 | Hypotheses blank_hyp({{blanks, 0}}); | 48 | Hypotheses blank_hyp({{blanks, 0}}); |
| 47 | r.hyps = std::move(blank_hyp); | 49 | r.hyps = std::move(blank_hyp); |
| 48 | r.tokens = std::move(blanks); | 50 | r.tokens = std::move(blanks); |
-
请 注册 或 登录 后发表评论