offline-transducer-greedy-search-decoder.cc
2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
#include <algorithm>
#include <iterator>
#include <utility>
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/packed-sequence.h"
#include "sherpa-onnx/csrc/slice.h"
namespace sherpa_onnx {
std::vector<OfflineTransducerDecoderResult>
OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
Ort::Value encoder_out_length,
OfflineStream **ss /*= nullptr*/,
int32_t n /*= 0*/) {
PackedSequence packed_encoder_out = PackPaddedSequence(
model_->Allocator(), &encoder_out, &encoder_out_length);
int32_t batch_size =
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
int32_t vocab_size = model_->VocabSize();
int32_t context_size = model_->ContextSize();
std::vector<OfflineTransducerDecoderResult> ans(batch_size);
for (auto &r : ans) {
r.tokens.resize(context_size, -1);
// 0 is the ID of the blank token
r.tokens.back() = 0;
}
auto decoder_input = model_->BuildDecoderInput(ans, ans.size());
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
int32_t start = 0;
int32_t t = 0;
for (auto n : packed_encoder_out.batch_sizes) {
Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n);
Ort::Value cur_decoder_out = Slice(model_->Allocator(), &decoder_out, 0, n);
start += n;
Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
std::move(cur_decoder_out));
float *p_logit = logit.GetTensorMutableData<float>();
bool emitted = false;
for (int32_t i = 0; i != n; ++i) {
if (blank_penalty_ > 0.0) {
p_logit[0] -= blank_penalty_; // assuming blank id is 0
}
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));
p_logit += vocab_size;
// blank id is hardcoded to 0
// also, it treats unk as blank
if (y != 0 && y != unk_id_) {
ans[i].tokens.push_back(y);
ans[i].timestamps.push_back(t);
emitted = true;
}
}
if (emitted) {
Ort::Value decoder_input = model_->BuildDecoderInput(ans, n);
decoder_out = model_->RunDecoder(std::move(decoder_input));
}
++t;
}
for (auto &r : ans) {
r.tokens = {r.tokens.begin() + context_size, r.tokens.end()};
}
std::vector<OfflineTransducerDecoderResult> unsorted_ans(batch_size);
for (int32_t i = 0; i != batch_size; ++i) {
unsorted_ans[packed_encoder_out.sorted_indexes[i]] = std::move(ans[i]);
}
return unsorted_ans;
}
} // namespace sherpa_onnx