// sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_

#include <algorithm>
#include <array>
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
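
// Spoken language identification with a multilingual Whisper model: run the
// encoder on up to 30 seconds of audio, then ask the decoder which language
// token it scores highest (see Compute() below).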
class SpokenLanguageIdentificationWhisperImpl
    : public SpokenLanguageIdentificationImpl {
 public:
  explicit SpokenLanguageIdentificationWhisperImpl(
      const SpokenLanguageIdentificationConfig &config)
      : config_(config),
        model_(std::make_unique<OfflineWhisperModel>(config)) {
    Check();
  }
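
  // WhisperTag makes the stream use Whisper-style log-mel features, so the
  // frames returned by GetFrames() in Compute() match what the model expects.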
  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(WhisperTag{});
  }
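
  // Normalizes the features, zero-pads (or truncates) them to Whisper's
  // fixed 30-second window, runs the encoder, and returns the detected
  // language code, or an empty string on failure.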
  std::string Compute(OfflineStream *s) const override {
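    // Whisper consumes a fixed 30-second window: 3000 frames at a 10 ms
    // frame shift.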
    int32_t max_num_frames = 3000;
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    int32_t feat_dim = s->FeatureDim();
    std::vector<float> f = s->GetFrames();
    int32_t num_frames = f.size() / feat_dim;

    // Keep 50 frames in reserve so that there is always some zero tail
    // padding.
    if (num_frames >= max_num_frames - 50) {
      SHERPA_ONNX_LOGE(
          "Only waves shorter than 30 seconds are supported. We process only "
          "the first 30 seconds and discard the rest.");
      num_frames = max_num_frames - 50;
    }
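
    // Scale the log-mel features the way Whisper does; see
    // OfflineWhisperModel::NormalizeFeatures for the exact formula.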
    model_->NormalizeFeatures(f.data(), num_frames, feat_dim);

    // Note that 1000 is an empirically chosen value. You can replace it with
    // other values, say, 100.
    //
    // Since we have removed the 30-second constraint, we need
    // tail_padding_frames so that Whisper is able to detect the eot token.
    int32_t tail_padding_frames = 1000;
    if (config_.whisper.tail_paddings > 0) {
      tail_padding_frames = config_.whisper.tail_paddings;
    }

    int32_t actual_frames =
        std::min(num_frames + tail_padding_frames, max_num_frames);

    std::array<int64_t, 3> shape{1, actual_frames, feat_dim};

    Ort::Value mel = Ort::Value::CreateTensor<float>(
        model_->Allocator(), shape.data(), shape.size());

    float *p_mel = mel.GetTensorMutableData<float>();
    std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel);
    std::fill_n(p_mel + num_frames * feat_dim,
                (actual_frames - num_frames) * feat_dim, 0);
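
    // The tensor was filled as (batch, num_frames, feat_dim), but the Whisper
    // encoder expects (batch, feat_dim, num_frames), so swap axes 1 and 2.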
    mel = Transpose12(model_->Allocator(), &mel);
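
    // ForwardEncoder returns the cross-attention key/value caches consumed by
    // the decoder; DetectLanguage runs the decoder once on them and returns
    // the ID of the highest-scoring language token.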
    try {
      auto cross_kv = model_->ForwardEncoder(std::move(mel));
      int32_t lang_id =
          model_->DetectLanguage(cross_kv.first, cross_kv.second);

      const auto &id2lang = model_->GetID2Lang();
      if (id2lang.count(lang_id)) {
        return id2lang.at(lang_id);
      } else {
        SHERPA_ONNX_LOGE("Unknown language ID: %d. Returning an empty string.",
                         lang_id);
        return "";
      }
    } catch (const Ort::Exception &ex) {
      SHERPA_ONNX_LOGE(
          "\n\nCaught exception:\n\n%s\n\nReturning an empty result. Number "
          "of input frames: %d. Current tail paddings: %d. If you see a lot "
          "of such exceptions, please consider using a larger "
          "--whisper-tail-paddings.",
          ex.what(), num_frames, tail_padding_frames);
      return "";
    }
  }

 private:
  void Check() const {
    if (!model_->IsMultiLingual()) {
      SHERPA_ONNX_LOGE(
          "Only multilingual Whisper models can be used for spoken language "
          "identification. Given encoder: %s, decoder: %s",
          config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str());
      exit(-1);
    }
  }

 private:
  SpokenLanguageIdentificationConfig config_;
  std::unique_ptr<OfflineWhisperModel> model_;
};
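
// A minimal usage sketch (in applications this impl is normally reached
// through the public SpokenLanguageIdentification wrapper; the model paths
// and the 16 kHz sample rate below are assumptions for illustration):
//
//   SpokenLanguageIdentificationConfig config;
//   config.whisper.encoder = "tiny-encoder.onnx";
//   config.whisper.decoder = "tiny-decoder.onnx";
//
//   SpokenLanguageIdentificationWhisperImpl slid(config);
//   auto stream = slid.CreateStream();
//   stream->AcceptWaveform(16000, samples.data(), samples.size());
//   std::string lang = slid.Compute(stream.get());  // e.g., "en"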

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_