offline-whisper-greedy-search-decoder.cc
4.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h"
#include <algorithm>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
void OfflineWhisperGreedySearchDecoder::SetConfig(
const OfflineWhisperModelConfig &config) {
config_ = config;
}
// Greedy-search decoding of a single utterance (batch size 1).
//
// @param cross_k  Tensor forwarded to the decoder (presumably the
//                 cross-attention keys computed by the encoder). It is also
//                 passed to model_->DetectLanguage() when no language is
//                 configured.
// @param cross_v  Tensor forwarded to the decoder, handled like cross_k
//                 (presumably the cross-attention values).
// @return A vector with exactly one result. `tokens` holds the predicted
//         token IDs in order (the EOT token is never included); `lang` is the
//         language string found for initial_tokens[1] in the model's ID->lang
//         table, or "" if there is no such entry.
//
// For multilingual models the process exits with -1 if config_.language is
// non-empty but is not known to the model. An unsupported config_.task is
// only logged; decoding then continues with the default task.
std::vector<OfflineWhisperDecoderResult>
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
                                          Ort::Value cross_v) {
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // For multilingual models, initial_tokens contains [sot, language, task]
  //   - language is English by default
  //   - task is transcribe by default
  //
  // For non-multilingual models, initial_tokens contains [sot]
  std::vector<int64_t> initial_tokens = model_->GetInitialTokens();

  if (model_->IsMultiLingual()) {
    // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
    if (!config_.language.empty()) {
      const auto &lang2id = model_->GetLang2ID();

      // Single lookup instead of count() followed by at().
      const auto lang_it = lang2id.find(config_.language);
      if (lang_it == lang2id.end()) {
        SHERPA_ONNX_LOGE("Invalid language: %s", config_.language.c_str());
        exit(-1);
      }

      int32_t lang_id = lang_it->second;
      initial_tokens[1] = lang_id;
    } else {
      // No language requested by the user: let the model pick one.
      int32_t lang_id = model_->DetectLanguage(cross_k, cross_v);
      initial_tokens[1] = lang_id;
    }

    if (config_.task == "translate") {
      initial_tokens[2] = model_->Translate();
    } else if (config_.task != "transcribe") {
      // initial_tokens[2] is transcribe by default, so we only report the
      // problem and keep going with the default task.
      SHERPA_ONNX_LOGE(
          "Unsupported task: %s. Valid values are: transcribe, translate.",
          config_.task.c_str());
    }
  }

  initial_tokens.push_back(model_->NoTimeStampsToken());

  int32_t batch_size = 1;
  std::array<int64_t, 2> token_shape{
      batch_size, static_cast<int64_t>(initial_tokens.size())};

  // This tensor does not own its data; it wraps the buffer of initial_tokens,
  // which outlives the first ForwardDecoder() call below.
  Ort::Value tokens = Ort::Value::CreateTensor(
      memory_info, initial_tokens.data(), initial_tokens.size(),
      token_shape.data(), token_shape.size());

  std::array<int64_t, 1> offset_shape{1};
  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
      model_->Allocator(), offset_shape.data(), offset_shape.size());
  *(offset.GetTensorMutableData<int64_t>()) = 0;

  auto self_kv_cache = model_->GetInitialSelfKVCache();

  // First pass: feed all initial tokens at once, starting at offset 0.
  auto decoder_out = model_->ForwardDecoder(
      std::move(tokens), std::move(self_kv_cache.first),
      std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
      std::move(offset));

  // The next token goes right after the initial tokens.
  *(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) =
      initial_tokens.size();

  const auto logits_shape =
      std::get<0>(decoder_out).GetTensorTypeAndShapeInfo().GetShape();
  const int32_t vocab_size = static_cast<int32_t>(logits_shape[2]);

  // Index of the largest value among the vocab_size floats starting at p.
  const auto argmax = [vocab_size](const float *p) {
    return static_cast<int32_t>(
        std::distance(p, std::max_element(p, p + vocab_size)));
  };

  // Only the logits of the last input position predict the next token.
  const float *p_first_logits = std::get<0>(decoder_out).GetTensorData<float>();
  int32_t max_token_id =
      argmax(p_first_logits + (logits_shape[1] - 1) * vocab_size);

  const int32_t n_text_ctx = model_->TextCtx();

  // The decoder offset already equals initial_tokens.size() and grows by one
  // per generated token. Cap the number of steps so that the offset never
  // reaches n_text_ctx (assumed to be the maximum number of decoder positions
  // the model supports -- TODO confirm against the model export).
  const int32_t max_steps =
      n_text_ctx - static_cast<int32_t>(initial_tokens.size());

  const auto eot = model_->EOT();

  std::vector<int32_t> predicted_tokens;
  for (int32_t i = 0; i < max_steps; ++i) {
    if (max_token_id == eot) {
      break;
    }

    predicted_tokens.push_back(max_token_id);

    // Feed back only the token we just selected: shape (1, 1).
    std::array<int64_t, 2> step_shape{1, 1};
    Ort::Value step_tokens = Ort::Value::CreateTensor<int64_t>(
        model_->Allocator(), step_shape.data(), step_shape.size());
    step_tokens.GetTensorMutableData<int64_t>()[0] = max_token_id;

    // Elements 1..5 of the previous output are handed back as inputs.
    decoder_out = model_->ForwardDecoder(std::move(step_tokens),
                                         std::move(std::get<1>(decoder_out)),
                                         std::move(std::get<2>(decoder_out)),
                                         std::move(std::get<3>(decoder_out)),
                                         std::move(std::get<4>(decoder_out)),
                                         std::move(std::get<5>(decoder_out)));

    // Advance the offset for the next step.
    *(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) += 1;

    // A single token was fed, so the logits start at the tensor's beginning.
    max_token_id = argmax(std::get<0>(decoder_out).GetTensorData<float>());
  }

  std::vector<OfflineWhisperDecoderResult> ans(1);

  // initial_tokens always has at least two entries here because of the
  // push_back above; for non-multilingual models entry 1 is not a language
  // token, in which case the lookup is expected to miss.
  const auto &id2lang = model_->GetID2Lang();
  const auto id_it = id2lang.find(initial_tokens[1]);
  if (id_it != id2lang.end()) {
    ans[0].lang = id_it->second;
  } else {
    ans[0].lang = "";
  }

  ans[0].tokens = std::move(predicted_tokens);

  return ans;
}
} // namespace sherpa_onnx