audio-tagging-ced-impl.h
3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
// sherpa-onnx/csrc/audio-tagging-ced-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <array>
#include <memory>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
#include "sherpa-onnx/csrc/audio-tagging.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/offline-ced-model.h"
namespace sherpa_onnx {
class AudioTaggingCEDImpl : public AudioTaggingImpl {
public:
explicit AudioTaggingCEDImpl(const AudioTaggingConfig &config)
: config_(config), model_(config.model), labels_(config.labels) {
if (model_.NumEventClasses() != labels_.NumEventClasses()) {
SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
model_.NumEventClasses(), labels_.NumEventClasses());
exit(-1);
}
}
#if __ANDROID_API__ >= 9
explicit AudioTaggingCEDImpl(AAssetManager *mgr,
const AudioTaggingConfig &config)
: config_(config),
model_(mgr, config.model),
labels_(mgr, config.labels) {
if (model_.NumEventClasses() != labels_.NumEventClasses()) {
SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
model_.NumEventClasses(), labels_.NumEventClasses());
exit(-1);
}
}
#endif
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(CEDTag{});
}
std::vector<AudioEvent> Compute(OfflineStream *s,
int32_t top_k = -1) const override {
if (top_k < 0) {
top_k = config_.top_k;
}
int32_t num_event_classes = model_.NumEventClasses();
if (top_k > num_event_classes) {
top_k = num_event_classes;
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
// WARNING(fangjun): It is fixed to 64 for CED models
int32_t feat_dim = 64;
std::vector<float> f = s->GetFrames();
int32_t num_frames = f.size() / feat_dim;
assert(feat_dim * num_frames == static_cast<int32_t>(f.size()));
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
shape.data(), shape.size());
Ort::Value probs = model_.Forward(std::move(x));
const float *p = probs.GetTensorData<float>();
std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k);
std::vector<AudioEvent> ans(top_k);
int32_t i = 0;
for (int32_t index : top_k_indexes) {
ans[i].name = labels_.GetEventName(index);
ans[i].index = index;
ans[i].prob = p[index];
i += 1;
}
return ans;
}
private:
AudioTaggingConfig config_;
OfflineCEDModel model_;
AudioTaggingLabels labels_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_