audio-tagging.cc
1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// sherpa-onnx/csrc/audio-tagging.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/audio-tagging.h"
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Formats this event on a single line as
//   AudioEvent(name="<name>", index=<index>, prob=<prob>)
// Each field is written with its stream insertion operator; the name is
// wrapped in double quotes but not escaped.
std::string AudioEvent::ToString() const {
  std::ostringstream out;

  out << "AudioEvent("
      << "name=\"" << name << "\", "
      << "index=" << index << ", "
      << "prob=" << prob << ")";

  return out.str();
}
// Registers the command-line options of this config with `po`.
//
// The nested model config registers its options first; afterwards the
// "labels" and "top-k" options are bound to the `labels` and `top_k` members.
//
// @param po  Option parser to register with. It is dereferenced
//            unconditionally.
void AudioTaggingConfig::Register(ParseOptions *po) {
  model.Register(po);

  ParseOptions &parser = *po;
  parser.Register("labels", &labels, "Event label file");
  parser.Register("top-k", &top_k, "Top k events to return in the result");
}
// Checks whether this config can be used.
//
// The checks run in a fixed order and the function returns false at the first
// one that fails:
//   1. the nested model config validates,
//   2. top_k is at least 1,
//   3. a labels path was given,
//   4. the labels path exists according to FileExists().
// Failures of checks 2-4 are logged via SHERPA_ONNX_LOGE; this function logs
// nothing itself when the model config fails to validate.
//
// @return true if every check passes, false otherwise.
bool AudioTaggingConfig::Validate() const {
  const bool model_ok = model.Validate();
  if (!model_ok) {
    return false;
  }

  if (top_k < 1) {
    SHERPA_ONNX_LOGE("--top-k should be >= 1. Given: %d", top_k);
    return false;
  }

  const bool labels_given = !labels.empty();
  if (!labels_given) {
    SHERPA_ONNX_LOGE("Please provide --labels");
    return false;
  }

  // Only probe the filesystem once we know a path was actually supplied.
  const bool labels_found = FileExists(labels);
  if (!labels_found) {
    SHERPA_ONNX_LOGE("--labels %s does not exist", labels.c_str());
    return false;
  }

  return true;
}
// Formats this config on a single line as
//   AudioTaggingConfig(model=<model.ToString()>, labels="<labels>", top_k=<top_k>)
// The labels path is wrapped in double quotes but not escaped.
std::string AudioTaggingConfig::ToString() const {
  std::ostringstream out;

  out << "AudioTaggingConfig("
      << "model=" << model.ToString() << ", "
      << "labels=\"" << labels << "\", "
      << "top_k=" << top_k << ")";

  return out.str();
}
AudioTagging::AudioTagging(const AudioTaggingConfig &config)
: impl_(AudioTaggingImpl::Create(config)) {}
// Defaulted out of line in this translation unit, which includes
// audio-tagging-impl.h.
// NOTE(review): presumably needed so that AudioTaggingImpl is a complete type
// at the point where impl_ is destroyed — confirm against audio-tagging.h.
AudioTagging::~AudioTagging() = default;
// Creates a new stream by forwarding to the implementation object.
//
// @return The stream returned by impl_->CreateStream(); ownership passes to
//         the caller.
std::unique_ptr<OfflineStream> AudioTagging::CreateStream() const {
  std::unique_ptr<OfflineStream> stream = impl_->CreateStream();
  return stream;
}
// Runs audio tagging on stream `s` by forwarding to the implementation object.
//
// @param s      Stream to process; passed through unchanged.
// @param top_k  Passed through unchanged. The inline comment records a default
//               of -1 in the declaration.
//               NOTE(review): -1 presumably means "use the configured top_k"
//               — confirm in the implementation.
// @return The events returned by impl_->Compute().
std::vector<AudioEvent> AudioTagging::Compute(OfflineStream *s,
                                              int32_t top_k /*= -1*/) const {
  std::vector<AudioEvent> events = impl_->Compute(s, top_k);
  return events;
}
} // namespace sherpa_onnx