speaker-embedding-extractor-nemo-impl.h
4.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
#include <algorithm>
#include <array>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>

#include "Eigen/Dense"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
/// Speaker-embedding extractor backed by a NeMo speaker model.
///
/// Feature-extraction settings (sample rate, feature dim, window size/stride,
/// window type) are read from the model metadata. Compute() consumes all
/// frames that are ready in a stream, optionally normalizes them per feature,
/// and runs the model to produce one embedding vector.
class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl {
 public:
  explicit SpeakerEmbeddingExtractorNeMoImpl(
      const SpeakerEmbeddingExtractorConfig &config)
      : model_(config) {}

#if __ANDROID_API__ >= 9
  /// Loads the model through an Android asset manager.
  SpeakerEmbeddingExtractorNeMoImpl(
      AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
      : model_(mgr, config) {}
#endif

  /// Embedding dimension, as reported by the model metadata.
  int32_t Dim() const override { return model_.GetMetaData().output_dim; }

  /// Creates a stream whose feature extractor matches the model metadata
  /// (librosa-style fbank, snip_edges, no DC-offset removal, low_freq = 0).
  std::unique_ptr<OnlineStream> CreateStream() const override {
    FeatureExtractorConfig feat_config;
    const auto &meta_data = model_.GetMetaData();
    feat_config.sampling_rate = meta_data.sample_rate;
    feat_config.feature_dim = meta_data.feat_dim;
    feat_config.normalize_samples = true;
    feat_config.snip_edges = true;
    feat_config.frame_shift_ms = meta_data.window_stride_ms;
    feat_config.frame_length_ms = meta_data.window_size_ms;
    feat_config.low_freq = 0;
    feat_config.is_librosa = true;
    feat_config.remove_dc_offset = false;
    feat_config.window_type = meta_data.window_type;

    return std::make_unique<OnlineStream>(feat_config);
  }

  /// True when the stream has feature frames not yet consumed by Compute().
  bool IsReady(OnlineStream *s) const override {
    return s->GetNumProcessedFrames() < s->NumFramesReady();
  }

  /// Computes an embedding from all frames of `s` not yet processed, and marks
  /// those frames as processed.
  ///
  /// @return The embedding, or an empty vector if no frames are available.
  std::vector<float> Compute(OnlineStream *s) const override {
    // NeMo-style time-axis padding granularity (mirrors the preprocessor's
    // pad_to = 16 setting — NOTE(review): confirm against the export config).
    constexpr int32_t kPadTo = 16;

    int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
    if (num_frames <= 0) {
      SHERPA_ONNX_LOGE(
          "Please make sure IsReady(s) returns true. num_frames: %d",
          num_frames);
      return {};
    }

    std::vector<float> features =
        s->GetFrames(s->GetNumProcessedFrames(), num_frames);

    s->GetNumProcessedFrames() += num_frames;

    int32_t feat_dim = static_cast<int32_t>(features.size() / num_frames);

    const auto &meta_data = model_.GetMetaData();
    if (!meta_data.feature_normalize_type.empty()) {
      if (meta_data.feature_normalize_type == "per_feature") {
        // Statistics are computed over the real frames only, before padding.
        NormalizePerFeature(features.data(), num_frames, feat_dim);
      } else {
        SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s",
                         meta_data.feature_normalize_type.c_str());
        exit(-1);
      }
    }

    // Zero-pad the time axis up to a multiple of kPadTo. The padded length
    // must be part of the tensor shape, otherwise the padding never reaches
    // the model; x_lens below still carries the real number of frames.
    int32_t padded_frames = (num_frames + kPadTo - 1) / kPadTo * kPadTo;
    features.resize(static_cast<size_t>(padded_frames) * feat_dim);

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // (1, T, C) -> (1, C, T) via Transpose12 before feeding the model.
    std::array<int64_t, 3> x_shape{1, padded_frames, feat_dim};
    Ort::Value x =
        Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
                                 x_shape.data(), x_shape.size());

    x = Transpose12(model_.Allocator(), &x);

    int64_t x_lens = num_frames;
    std::array<int64_t, 1> x_lens_shape{1};
    Ort::Value x_lens_tensor = Ort::Value::CreateTensor(
        memory_info, &x_lens, 1, x_lens_shape.data(), x_lens_shape.size());

    Ort::Value embedding =
        model_.Compute(std::move(x), std::move(x_lens_tensor));

    // Assumes the model output is (1, embedding_dim) — the first row is
    // returned.
    std::vector<int64_t> embedding_shape =
        embedding.GetTensorTypeAndShapeInfo().GetShape();

    const float *p = embedding.GetTensorData<float>();
    return std::vector<float>(p, p + embedding_shape[1]);
  }

 private:
  /// In-place per-feature (column-wise) mean/variance normalization of a
  /// row-major (num_frames, feat_dim) matrix, matching NeMo's
  /// `normalize_batch(..., "per_feature")`: unbiased standard deviation with
  /// a small constant added so constant columns (or a single frame) do not
  /// produce NaN/inf.
  ///
  /// @param p          Pointer to num_frames * feat_dim floats, row-major.
  /// @param num_frames Number of rows; callers guarantee it is >= 1.
  /// @param feat_dim   Number of columns.
  void NormalizePerFeature(float *p, int32_t num_frames,
                           int32_t feat_dim) const {
    constexpr float kEps = 1e-5f;

    auto m = Eigen::Map<
        Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
        p, num_frames, feat_dim);

    // Evaluate the mean into a concrete vector before mutating m, so no lazy
    // expression reads m while it is being written.
    Eigen::RowVectorXf mean = m.colwise().mean();
    m.rowwise() -= mean;

    // Two-pass (centered) variance is never negative, unlike E[X^2]-E[X]^2.
    int32_t denom = std::max<int32_t>(num_frames - 1, 1);
    Eigen::RowArrayXf stddev =
        (m.array().square().colwise().sum() / static_cast<float>(denom))
            .sqrt() +
        kEps;

    m.array().rowwise() /= stddev;
  }

  SpeakerEmbeddingExtractorNeMoModel model_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_