online-transducer-modified-beam-search-decoder-rknn.cc
4.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
// sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/math.h"
namespace sherpa_onnx {
OnlineTransducerDecoderResultRknn
OnlineTransducerModifiedBeamSearchDecoderRknn::GetEmptyResult() const {
  // The blank token id is hardcoded to 0 throughout this decoder.
  constexpr int32_t kBlankId = 0;
  const int32_t context_size = model_->ContextSize();

  // Seed the decoder context: (context_size - 1) padding tokens of -1
  // followed by a single blank.
  std::vector<int64_t> context(context_size, -1);
  context.back() = kBlankId;

  OnlineTransducerDecoderResultRknn res;
  res.hyps = Hypotheses({{context, 0}});
  res.tokens = std::move(context);
  return res;
}
void OnlineTransducerModifiedBeamSearchDecoderRknn::StripLeadingBlanks(
    OnlineTransducerDecoderResultRknn *r) const {
  // Take the best hypothesis and drop the context_size seed tokens that
  // GetEmptyResult() placed at the front of ys, keeping only real output.
  auto best = r->hyps.GetMostProbable(true);
  const int32_t context_size = model_->ContextSize();

  r->tokens.assign(best.ys.begin() + context_size, best.ys.end());
  r->timestamps = std::move(best.timestamps);
  r->num_trailing_blanks = best.num_trailing_blanks;
}
// For every active hypothesis, run the decoder network on its last
// context_size tokens and collect the resulting decoder output vectors,
// in the same order the hypotheses are iterated.
static std::vector<std::vector<float>> GetDecoderOut(
    OnlineZipformerTransducerModelRknn *model, const Hypotheses &hyp_vec) {
  const int32_t context_size = model->ContextSize();

  std::vector<std::vector<float>> ans;
  ans.reserve(hyp_vec.Size());

  for (const auto &kv : hyp_vec) {
    const auto &ys = kv.second.ys;
    // The decoder consumes only the trailing context_size tokens.
    std::vector<int64_t> context(ys.end() - context_size, ys.end());
    ans.push_back(model->RunDecoder(std::move(context)));
  }

  return ans;
}
// Run the joiner on one encoder frame against each hypothesis's decoder
// output and return the per-hypothesis log-softmax over the vocabulary.
static std::vector<std::vector<float>> GetJoinerOutLogSoftmax(
    OnlineZipformerTransducerModelRknn *model, const float *p_encoder_out,
    const std::vector<std::vector<float>> &decoder_out) {
  std::vector<std::vector<float>> log_probs;
  log_probs.reserve(decoder_out.size());

  for (const auto &dec : decoder_out) {
    std::vector<float> out = model->RunJoiner(p_encoder_out, dec.data());
    // Normalize logits in place so each entry is a log probability.
    LogSoftmax(out.data(), out.size());
    log_probs.push_back(std::move(out));
  }

  return log_probs;
}
void OnlineTransducerModifiedBeamSearchDecoderRknn::Decode(
std::vector<float> encoder_out,
OnlineTransducerDecoderResultRknn *result) const {
auto &r = result[0];
auto attr = model_->GetEncoderOutAttr();
int32_t num_frames = attr.dims[1];
int32_t encoder_out_dim = attr.dims[2];
int32_t vocab_size = model_->VocabSize();
int32_t context_size = model_->ContextSize();
Hypotheses cur = std::move(result->hyps);
std::vector<Hypothesis> prev;
auto decoder_out = std::move(result->previous_decoder_out2);
if (decoder_out.empty()) {
decoder_out = GetDecoderOut(model_, cur);
}
const float *p_encoder_out = encoder_out.data();
int32_t frame_offset = result->frame_offset;
for (int32_t t = 0; t != num_frames; ++t) {
prev = cur.Vec();
cur.Clear();
auto log_probs = GetJoinerOutLogSoftmax(model_, p_encoder_out, decoder_out);
p_encoder_out += encoder_out_dim;
for (int32_t i = 0; i != prev.size(); ++i) {
auto log_prob = prev[i].log_prob;
for (auto &p : log_probs[i]) {
p += log_prob;
}
}
auto topk = TopkIndex(log_probs, max_active_paths_);
for (auto k : topk) {
int32_t hyp_index = k / vocab_size;
int32_t new_token = k % vocab_size;
Hypothesis new_hyp = prev[hyp_index];
new_hyp.log_prob = log_probs[hyp_index][new_token];
// blank is hardcoded to 0
// also, it treats unk as blank
if (new_token != 0 && new_token != unk_id_) {
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t + frame_offset);
new_hyp.num_trailing_blanks = 0;
} else {
++new_hyp.num_trailing_blanks;
}
cur.Add(std::move(new_hyp));
}
decoder_out = GetDecoderOut(model_, cur);
}
result->hyps = std::move(cur);
result->frame_offset += num_frames;
result->previous_decoder_out2 = std::move(decoder_out);
}
} // namespace sherpa_onnx