Fangjun Kuang
Committed by GitHub

Support audio tagging using zipformer (#747)

#!/usr/bin/env bash
# CI helper: download the zipformer audio-tagging model and run the
# audio-tagging executable (named by the EXE environment variable) on the
# bundled test wave files.
#
# -e: abort on the first failing command; -x: echo each command before running.
set -ex

# Print a timestamped log line prefixed with the caller's file name,
# line number, and function name.
log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# EXE and PATH are expected to be exported by the CI workflow that calls
# this script (see the workflow steps that set EXE before invoking it).
echo "EXE is $EXE"
echo "PATH: $PATH"

which $EXE

log "------------------------------------------------------------"
log "Run zipformer for audio tagging "
log "------------------------------------------------------------"

# Fetch and unpack the pre-trained model; delete the archive to save disk.
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2

repo=sherpa-onnx-zipformer-audio-tagging-2024-04-09
ls -lh $repo

# Tag each bundled test wave using the downloaded model and label file.
for w in 1.wav 2.wav 3.wav 4.wav; do
  $EXE \
    --zipformer-model=$repo/model.onnx \
    --labels=$repo/class_labels_indices.csv \
    $repo/test_wavs/$w
done

# Clean up the extracted model directory.
rm -rf $repo
... ...
... ... @@ -15,6 +15,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -32,6 +33,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -124,6 +126,14 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: build/bin/*
- name: Test Audio tagging
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-audio-tagging
.github/scripts/test-audio-tagging.sh
- name: Test online CTC
shell: bash
run: |
... ...
... ... @@ -15,6 +15,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -31,6 +32,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -103,6 +105,14 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test Audio tagging
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-audio-tagging
.github/scripts/test-audio-tagging.sh
- name: Test C API
shell: bash
run: |
... ...
... ... @@ -14,6 +14,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -28,6 +29,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -70,6 +72,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test Audio tagging
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline-audio-tagging.exe
.github/scripts/test-audio-tagging.sh
- name: Test C API
shell: bash
run: |
... ...
... ... @@ -14,6 +14,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -28,6 +29,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -85,6 +87,13 @@ jobs:
# export EXE=sherpa-onnx-offline-language-identification.exe
#
# .github/scripts/test-spoken-language-identification.sh
- name: Test Audio tagging
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline-audio-tagging.exe
.github/scripts/test-audio-tagging.sh
- name: Test online CTC
shell: bash
... ...
... ... @@ -46,6 +46,7 @@ def enable_alsa():
def get_binaries():
binaries = [
"sherpa-onnx",
"sherpa-onnx-offline-audio-tagging",
"sherpa-onnx-keyword-spotter",
"sherpa-onnx-microphone",
"sherpa-onnx-microphone-offline",
... ...
go.sum
vad-asr-paraformer
... ...
... ... @@ -4,7 +4,7 @@ const sherpa_onnx = require('sherpa-onnx');
function createOfflineTts() {
let offlineTtsVitsModelConfig = {
model: './vits-icefall-zh-aishell3/vits-aishell3.onnx',
model: './vits-icefall-zh-aishell3/model.onnx',
lexicon: './vits-icefall-zh-aishell3/lexicon.txt',
tokens: './vits-icefall-zh-aishell3/tokens.txt',
dataDir: '',
... ...
... ... @@ -111,6 +111,16 @@ list(APPEND sources
speaker-embedding-manager.cc
)
# audio tagging
list(APPEND sources
audio-tagging-impl.cc
audio-tagging-label-file.cc
audio-tagging-model-config.cc
audio-tagging.cc
offline-zipformer-audio-tagging-model-config.cc
offline-zipformer-audio-tagging-model.cc
)
if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND sources
lexicon.cc
... ... @@ -193,6 +203,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
add_executable(sherpa-onnx-offline-audio-tagging sherpa-onnx-offline-audio-tagging.cc)
if(SHERPA_ONNX_ENABLE_TTS)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
... ... @@ -204,6 +215,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
sherpa-onnx-offline
sherpa-onnx-offline-parallel
sherpa-onnx-offline-language-identification
sherpa-onnx-offline-audio-tagging
)
if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND main_exes
... ...
// sherpa-onnx/csrc/audio-tagging-impl.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
const AudioTaggingConfig &config) {
if (!config.model.zipformer.model.empty()) {
return std::make_unique<AudioTaggingZipformerImpl>(config);
}
SHERPA_ONNX_LOG(
"Please specify an audio tagging model! Return a null pointer");
return nullptr;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/audio-tagging-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_
#include <memory>
#include <vector>
#include "sherpa-onnx/csrc/audio-tagging.h"
namespace sherpa_onnx {
// Abstract base class for audio tagging implementations.
//
// Concrete subclasses (e.g., the zipformer-based implementation) are
// obtained through the static Create() factory.
class AudioTaggingImpl {
 public:
  virtual ~AudioTaggingImpl() = default;

  // Factory method. May return nullptr when the config does not specify a
  // supported model, so callers must check the result.
  static std::unique_ptr<AudioTaggingImpl> Create(
      const AudioTaggingConfig &config);

  // Create a stream for accepting the audio samples to be tagged.
  virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;

  // Compute the top_k most probable audio events for the audio in `s`.
  // top_k == -1 means "use the implementation's configured default".
  virtual std::vector<AudioEvent> Compute(OfflineStream *s,
                                          int32_t top_k = -1) const = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_
... ...
// sherpa-onnx/csrc/audio-tagging-label-file.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
#include <fstream>
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
// Load event labels from the given CSV file (format documented below).
// NOTE(review): the stream is not checked for success, so a missing or
// unreadable file silently yields an empty label table — confirm callers
// validate the path first (AudioTaggingConfig::Validate checks FileExists).
AudioTaggingLabels::AudioTaggingLabels(const std::string &filename) {
  std::ifstream is(filename);
  Init(is);
}
// Format of a label file
/*
index,mid,display_name
0,/m/09x0r,"Speech"
1,/m/05zppz,"Male speech, man speaking"
*/
// Parse the label CSV (see the format example above). The first line is a
// header and is skipped. The display name may itself contain commas, so
// everything after the second comma — including the surrounding double
// quotes — is taken as the name.
void AudioTaggingLabels::Init(std::istream &is) {
  std::string line;
  std::getline(is, line);  // skip the header

  std::string index;
  std::string tmp;
  std::string name;

  while (std::getline(is, line)) {
    index.clear();
    name.clear();
    std::istringstream input2(line);
    std::getline(input2, index, ',');
    std::getline(input2, tmp, ',');
    std::getline(input2, name);

    // The index column must be a complete integer (no trailing junk).
    // NOTE(review): std::stoi throws std::invalid_argument for an empty or
    // non-numeric index, so the size()==0 check below is unreachable for
    // such lines — consider pre-checking if malformed files must be handled
    // gracefully.
    std::size_t pos{};
    int32_t i = std::stoi(index, &pos);
    if (index.size() == 0 || pos != index.size()) {
      SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str());
      exit(-1);
    }

    if (i != static_cast<int32_t>(names_.size())) {
      SHERPA_ONNX_LOGE(
          "Index should be sorted and contiguous. Expected index: %d, given: "
          "%d.",
          static_cast<int32_t>(names_.size()), i);
      // Fix: the original only logged this error and kept going, which
      // leaves names_ misaligned so every later GetEventName() would return
      // the wrong label. Abort instead, consistent with the other two error
      // paths in this loop.
      exit(-1);
    }

    if (name.empty() || name.front() != '"' || name.back() != '"') {
      SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str());
      exit(-1);
    }

    // Strip the surrounding double quotes.
    names_.emplace_back(name.begin() + 1, name.end() - 1);
  }
}
// Return the event name for the given 0-based index.
// Uses std::vector::at(), so an out-of-range index throws std::out_of_range.
const std::string &AudioTaggingLabels::GetEventName(int32_t index) const {
  return names_.at(index);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/audio-tagging-label-file.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_
#include <istream>
#include <string>
#include <vector>
namespace sherpa_onnx {
// In-memory table of audio event labels loaded from a CSV file whose rows
// are: index,mid,"display name" (the first row is a header).
class AudioTaggingLabels {
 public:
  // Load labels from the given CSV file.
  explicit AudioTaggingLabels(const std::string &filename);

  // Return the event name for the given index.
  // The returned reference is valid as long as this object is alive
  const std::string &GetEventName(int32_t index) const;

  // Number of event classes (rows) loaded from the label file.
  // (size() is narrowed to int32_t.)
  int32_t NumEventClasses() const { return names_.size(); }

 private:
  // Parse the label CSV from the stream into names_.
  void Init(std::istream &is);

 private:
  std::vector<std::string> names_;  // names_[i] is the label for event i
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_
... ...
// sherpa-onnx/csrc/audio-tagging-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
namespace sherpa_onnx {
// Register command-line options for the audio tagging model, including the
// nested zipformer model options.
void AudioTaggingModelConfig::Register(ParseOptions *po) {
  zipformer.Register(po);

  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");

  po->Register("debug", &debug,
               "true to print model information while loading it.");

  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml");
}
// Check that a usable audio tagging model has been supplied.
//
// Fix: the original returned true when zipformer.model was empty, so a
// config with no model at all passed validation — yet
// AudioTaggingImpl::Create() returns nullptr for that config and
// AudioTagging dereferences its impl_ without a null check. Since the
// zipformer model is currently the only supported model type, delegate
// unconditionally: zipformer.Validate() logs a detailed error and returns
// false when --zipformer-model is missing or the file does not exist.
bool AudioTaggingModelConfig::Validate() const {
  return zipformer.Validate();
}
// Return a human-readable description of this config for logging.
std::string AudioTaggingModelConfig::ToString() const {
  std::ostringstream os;

  os << "AudioTaggingModelConfig("
     << "zipformer=" << zipformer.ToString() << ", "
     << "num_threads=" << num_threads << ", "
     << "debug=" << (debug ? "True" : "False") << ", "
     << "provider=\"" << provider << "\")";

  return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/audio-tagging-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
// Configuration for the neural network used for audio tagging.
struct AudioTaggingModelConfig {
  // Config for the zipformer audio tagging model (currently the only
  // supported model type).
  struct OfflineZipformerAudioTaggingModelConfig zipformer;

  int32_t num_threads = 1;  // number of threads to run the network
  bool debug = false;       // true to print model information while loading
  // Execution provider: cpu, cuda, or coreml (see Register()).
  std::string provider = "cpu";

  AudioTaggingModelConfig() = default;

  AudioTaggingModelConfig(
      const OfflineZipformerAudioTaggingModelConfig &zipformer,
      int32_t num_threads, bool debug, const std::string &provider)
      : zipformer(zipformer),
        num_threads(num_threads),
        debug(debug),
        provider(provider) {}

  // Register command-line options with the parser.
  void Register(ParseOptions *po);
  // Return true if the config is usable.
  bool Validate() const;
  // Human-readable description for logging.
  std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/audio-tagging-zipformer-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
#include <memory>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
#include "sherpa-onnx/csrc/audio-tagging.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h"
namespace sherpa_onnx {
// Audio tagging implementation backed by a zipformer onnx model.
class AudioTaggingZipformerImpl : public AudioTaggingImpl {
 public:
  explicit AudioTaggingZipformerImpl(const AudioTaggingConfig &config)
      : config_(config), model_(config.model), labels_(config.labels) {
    // The model's output dimension and the label file must agree, or an
    // event index from the model would map to the wrong label.
    if (model_.NumEventClasses() != labels_.NumEventClasses()) {
      SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
                       model_.NumEventClasses(), labels_.NumEventClasses());
      exit(-1);
    }
  }

  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>();
  }

  // Run the model on the feature frames accumulated in `s` and return the
  // top_k events in the order produced by TopkIndex. top_k < 0 means "use
  // config_.top_k"; top_k is clamped to the number of event classes.
  std::vector<AudioEvent> Compute(OfflineStream *s,
                                  int32_t top_k = -1) const override {
    if (top_k < 0) {
      top_k = config_.top_k;
    }

    int32_t num_event_classes = model_.NumEventClasses();

    if (top_k > num_event_classes) {
      top_k = num_event_classes;
    }

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // WARNING(fangjun): It is fixed to 80 for all models from icefall
    int32_t feat_dim = 80;

    // Features come back flattened; interpret as (num_frames, feat_dim).
    std::vector<float> f = s->GetFrames();

    int32_t num_frames = f.size() / feat_dim;

    // NOTE(review): std::array is used below but <array> is not among this
    // header's own includes — it relies on a transitive include; confirm.
    std::array<int64_t, 3> shape = {1, num_frames, feat_dim};

    Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
                                            shape.data(), shape.size());

    int64_t x_length_scalar = num_frames;
    std::array<int64_t, 1> x_length_shape = {1};
    Ort::Value x_length =
        Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
                                 x_length_shape.data(), x_length_shape.size());

    Ort::Value probs = model_.Forward(std::move(x), std::move(x_length));

    const float *p = probs.GetTensorData<float>();

    // Indexes of the top_k largest values in p[0..num_event_classes).
    std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k);

    std::vector<AudioEvent> ans(top_k);

    int32_t i = 0;
    for (int32_t index : top_k_indexes) {
      ans[i].name = labels_.GetEventName(index);
      ans[i].index = index;
      ans[i].prob = p[index];
      i += 1;
    }

    return ans;
  }

 private:
  AudioTaggingConfig config_;
  OfflineZipformerAudioTaggingModel model_;
  AudioTaggingLabels labels_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
... ...
// sherpa-onnx/csrc/audio-tagging.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/audio-tagging.h"
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Format this event as: AudioEvent(name="...", index=..., prob=...)
std::string AudioEvent::ToString() const {
  std::ostringstream os;

  os << "AudioEvent("
     << "name=\"" << name << "\", "
     << "index=" << index << ", "
     << "prob=" << prob << ")";

  return os.str();
}
// Register audio tagging command-line options (including the nested model
// options) with the option parser.
void AudioTaggingConfig::Register(ParseOptions *po) {
  model.Register(po);
  po->Register("labels", &labels, "Event label file");
  po->Register("top-k", &top_k, "Top k events to return in the result");
}
// Check that the config is usable: the model config must validate, top_k
// must be at least 1, and the label file must be given and must exist.
// Logs a specific error and returns false at the first failed check.
bool AudioTaggingConfig::Validate() const {
  if (!model.Validate()) {
    return false;
  }

  if (top_k < 1) {
    SHERPA_ONNX_LOGE("--top-k should be >= 1. Given: %d", top_k);
    return false;
  }

  if (labels.empty()) {
    SHERPA_ONNX_LOGE("Please provide --labels");
    return false;
  }

  if (!FileExists(labels)) {
    SHERPA_ONNX_LOGE("--labels %s does not exist", labels.c_str());
    return false;
  }

  return true;
}
// Return a human-readable description of this config for logging.
std::string AudioTaggingConfig::ToString() const {
  std::ostringstream os;

  os << "AudioTaggingConfig("
     << "model=" << model.ToString() << ", "
     << "labels=\"" << labels << "\", "
     << "top_k=" << top_k << ")";

  return os.str();
}
// Construct from a config.
// NOTE(review): AudioTaggingImpl::Create() returns nullptr when the config
// specifies no model, and impl_ is dereferenced below without a null
// check — confirm that all callers run config.Validate() first.
AudioTagging::AudioTagging(const AudioTaggingConfig &config)
    : impl_(AudioTaggingImpl::Create(config)) {}

AudioTagging::~AudioTagging() = default;

// Create a stream for feeding audio samples to Compute().
std::unique_ptr<OfflineStream> AudioTagging::CreateStream() const {
  return impl_->CreateStream();
}

// Forward to the implementation. top_k == -1 means "use config.top_k".
std::vector<AudioEvent> AudioTagging::Compute(OfflineStream *s,
                                              int32_t top_k /*= -1*/) const {
  return impl_->Compute(s, top_k);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/audio-tagging.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_
#include <memory>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
#include "sherpa-onnx/csrc/offline-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
// Top-level configuration for audio tagging.
struct AudioTaggingConfig {
  AudioTaggingModelConfig model;  // which model to use and how to run it
  std::string labels;             // path to the event label CSV file
  int32_t top_k = 5;              // default number of events to return

  AudioTaggingConfig() = default;

  AudioTaggingConfig(const AudioTaggingModelConfig &model,
                     const std::string &labels, int32_t top_k)
      : model(model), labels(labels), top_k(top_k) {}

  // Register command-line options with the parser.
  void Register(ParseOptions *po);
  // Return true if the config is usable; logs an error otherwise.
  bool Validate() const;
  // Human-readable description for logging.
  std::string ToString() const;
};
// A single detected audio event.
struct AudioEvent {
  std::string name;  // name of the event
  int32_t index;     // index of the event in the label file
  float prob;        // probability of the event

  // Format as: AudioEvent(name="...", index=..., prob=...)
  std::string ToString() const;
};
class AudioTaggingImpl;

// Public API for audio tagging: create a stream, feed audio into it, then
// call Compute() to obtain the most probable audio events.
class AudioTagging {
 public:
  explicit AudioTagging(const AudioTaggingConfig &config);
  ~AudioTagging();

  // Create a stream for accepting the audio samples to be tagged.
  std::unique_ptr<OfflineStream> CreateStream() const;

  // If top_k is -1, then config.top_k is used.
  // Otherwise, config.top_k is ignored
  //
  // Return top_k AudioEvent. ans[0].prob is the largest of all returned events.
  std::vector<AudioEvent> Compute(OfflineStream *s, int32_t top_k = -1) const;

 private:
  std::unique_ptr<AudioTaggingImpl> impl_;  // pimpl; see audio-tagging-impl.h
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_
... ...
... ... @@ -97,8 +97,8 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
}
template <typename T>
void SubtractBlank(T *in, int32_t w, int32_t h,
int32_t blank_idx, float blank_penalty) {
void SubtractBlank(T *in, int32_t w, int32_t h, int32_t blank_idx,
float blank_penalty) {
for (int32_t i = 0; i != h; ++i) {
in[blank_idx] -= blank_penalty;
in += w;
... ... @@ -116,8 +116,7 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
});
int32_t k_num = std::min<int32_t>(size, topk);
std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num);
return index;
return {vec_index.begin(), vec_index.begin() + k_num};
}
} // namespace sherpa_onnx
... ...
... ... @@ -234,7 +234,7 @@ OfflineStream::OfflineStream(
: impl_(std::make_unique<Impl>(config, context_graph)) {}
OfflineStream::OfflineStream(WhisperTag tag,
ContextGraphPtr context_graph /*= nullptr*/)
ContextGraphPtr context_graph /*= {}*/)
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
OfflineStream::~OfflineStream() = default;
... ...
... ... @@ -71,10 +71,9 @@ struct WhisperTag {};
class OfflineStream {
public:
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
ContextGraphPtr context_graph = nullptr);
ContextGraphPtr context_graph = {});
explicit OfflineStream(WhisperTag tag,
ContextGraphPtr context_graph = nullptr);
explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {});
~OfflineStream();
/**
... ...
// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Register the --zipformer-model command-line option.
void OfflineZipformerAudioTaggingModelConfig::Register(ParseOptions *po) {
  po->Register("zipformer-model", &model,
               "Path to zipformer model for audio tagging");
}

// Return true if a model path was given and the file exists; otherwise log
// an error and return false.
bool OfflineZipformerAudioTaggingModelConfig::Validate() const {
  if (model.empty()) {
    SHERPA_ONNX_LOGE("Please provide --zipformer-model");
    return false;
  }

  if (!FileExists(model)) {
    SHERPA_ONNX_LOGE("--zipformer-model: %s does not exist", model.c_str());
    return false;
  }

  return true;
}

// Return a human-readable description of this config for logging.
std::string OfflineZipformerAudioTaggingModelConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineZipformerAudioTaggingModelConfig(";
  os << "model=\"" << model << "\")";

  return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
// Configuration for the zipformer-based offline audio tagging model.
struct OfflineZipformerAudioTaggingModelConfig {
  std::string model;  // path to the zipformer onnx model file

  OfflineZipformerAudioTaggingModelConfig() = default;

  explicit OfflineZipformerAudioTaggingModelConfig(const std::string &model)
      : model(model) {}

  // Register the --zipformer-model command-line option.
  void Register(ParseOptions *po);
  // Return true if the model path is given and the file exists.
  bool Validate() const;
  // Human-readable description for logging.
  std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
// Pimpl for OfflineZipformerAudioTaggingModel: owns the onnxruntime
// session and runs the forward pass.
class OfflineZipformerAudioTaggingModel::Impl {
 public:
  explicit Impl(const AudioTaggingModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto buf = ReadFile(config_.zipformer.model);
    Init(buf.data(), buf.size());
  }

#if __ANDROID_API__ >= 9
  // Android variant: the model file is read via the APK's asset manager.
  Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto buf = ReadFile(mgr, config_.zipformer.model);
    Init(buf.data(), buf.size());
  }
#endif

  // Run the model. `features` and `features_length` are consumed (moved
  // into the session inputs). Returns the model's first output, of shape
  // (N, num_event_classes) per the comment in Init().
  Ort::Value Forward(Ort::Value features, Ort::Value features_length) {
    std::array<Ort::Value, 2> inputs = {std::move(features),
                                        std::move(features_length)};

    auto ans =
        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                   output_names_ptr_.data(), output_names_ptr_.size());
    return std::move(ans[0]);
  }

  int32_t NumEventClasses() const { return num_event_classes_; }

  OrtAllocator *Allocator() const { return allocator_; }

 private:
  // Create the session from the in-memory model and cache its input/output
  // names and the number of event classes.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    // get num_event_classes from the output[0].shape,
    // which is (N, num_event_classes)
    num_event_classes_ =
        sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1];
  }

 private:
  AudioTaggingModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  int32_t num_event_classes_ = 0;  // taken from output[0]'s static shape
};
OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel(
    const AudioTaggingModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

#if __ANDROID_API__ >= 9
// Android variant: the model file is read via the APK's asset manager.
OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel(
    AAssetManager *mgr, const AudioTaggingModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}
#endif

OfflineZipformerAudioTaggingModel::~OfflineZipformerAudioTaggingModel() =
    default;

// Forward to the pimpl; see the header for the tensor shapes.
Ort::Value OfflineZipformerAudioTaggingModel::Forward(
    Ort::Value features, Ort::Value features_length) const {
  return impl_->Forward(std::move(features), std::move(features_length));
}

int32_t OfflineZipformerAudioTaggingModel::NumEventClasses() const {
  return impl_->NumEventClasses();
}

OrtAllocator *OfflineZipformerAudioTaggingModel::Allocator() const {
  return impl_->Allocator();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_
#include <memory>
#include <utility>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
namespace sherpa_onnx {
/** This class runs the zipformer audio tagging model of the audioset
 * recipe from icefall.
 *
 * (The original comment said "zipformer CTC model of the librispeech
 * recipe", a copy-paste leftover — the export script referenced below
 * lives in egs/audioset/AT.)
 *
 * See
 * https://github.com/k2-fsa/icefall/blob/master/egs/audioset/AT/zipformer/export-onnx.py
 */
class OfflineZipformerAudioTaggingModel {
 public:
  explicit OfflineZipformerAudioTaggingModel(
      const AudioTaggingModelConfig &config);

#if __ANDROID_API__ >= 9
  // Android variant: the model file is read via the APK's asset manager.
  OfflineZipformerAudioTaggingModel(AAssetManager *mgr,
                                    const AudioTaggingModelConfig &config);
#endif

  ~OfflineZipformerAudioTaggingModel();

  /** Run the forward method of the model.
   *
   * @param features A tensor of shape (N, T, C).
   * @param features_length A 1-D tensor of shape (N,) containing number of
   *                        valid frames in `features` before padding.
   *                        Its dtype is int64_t.
   *
   * @return Return a tensor
   *  - probs: A 2-D tensor of shape (N, num_event_classes).
   */
  Ort::Value Forward(Ort::Value features, Ort::Value features_length) const;

  /** Return the number of event classes of the model
   */
  int32_t NumEventClasses() const;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_
... ...
... ... @@ -4,6 +4,8 @@
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
#include <string>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
... ...
... ... @@ -4,7 +4,6 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
... ...
... ... @@ -140,9 +140,11 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
#if SHERPA_ONNX_ENABLE_TTS
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
#endif
Ort::SessionOptions GetSessionOptions(
const SpeakerEmbeddingExtractorConfig &config) {
... ... @@ -154,4 +156,8 @@ Ort::SessionOptions GetSessionOptions(
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
} // namespace sherpa_onnx
... ...
... ... @@ -6,15 +6,19 @@
#define SHERPA_ONNX_CSRC_SESSION_H_
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/vad-model-config.h"
#if SHERPA_ONNX_ENABLE_TTS
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#endif
namespace sherpa_onnx {
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
... ... @@ -27,7 +31,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
#if SHERPA_ONNX_ENABLE_TTS
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
#endif
Ort::SessionOptions GetSessionOptions(
const SpeakerEmbeddingExtractorConfig &config);
... ... @@ -35,6 +41,8 @@ Ort::SessionOptions GetSessionOptions(
Ort::SessionOptions GetSessionOptions(
const SpokenLanguageIdentificationConfig &config);
Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SESSION_H_
... ...
// sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include <stdio.h>

#include <chrono>
#include <cstdlib>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/audio-tagging.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
// Command-line tool: run audio tagging on a single wave file and print the
// top events (and timing statistics) to stderr.
//
// Fix: error paths previously mixed exit(EXIT_FAILURE) with `return -1`;
// -1 wraps to 255 as a process exit status, so all error returns are
// normalized to EXIT_FAILURE. The needed standard headers (<chrono>,
// <cstdlib>, <string>, <vector>) are now included explicitly instead of
// relying on transitive includes.
int32_t main(int32_t argc, char *argv[]) {
  const char *kUsageMessage = R"usage(
Audio tagging from a file.
Usage:
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
./bin/sherpa-onnx-offline-audio-tagging \
--zipformer-model=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx \
--labels=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv \
sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/0.wav
Input wave files should be of single channel, 16-bit PCM encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.
Please see
https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
for more models.
)usage";

  sherpa_onnx::ParseOptions po(kUsageMessage);
  sherpa_onnx::AudioTaggingConfig config;
  config.Register(&po);

  po.Read(argc, argv);

  if (po.NumArgs() != 1) {
    fprintf(stderr, "\nError: Please provide 1 wave file\n\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  fprintf(stderr, "%s\n", config.ToString().c_str());

  if (!config.Validate()) {
    fprintf(stderr, "Errors in config!\n");
    return EXIT_FAILURE;  // was -1, which becomes 255 as an exit status
  }

  sherpa_onnx::AudioTagging tagger(config);

  std::string wav_filename = po.GetArg(1);
  int32_t sampling_rate = -1;
  bool is_ok = false;
  const std::vector<float> samples =
      sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
  if (!is_ok) {
    fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
    return EXIT_FAILURE;  // was -1
  }

  const float duration = samples.size() / static_cast<float>(sampling_rate);

  fprintf(stderr, "Start to compute\n");
  const auto begin = std::chrono::steady_clock::now();

  auto stream = tagger.CreateStream();
  stream->AcceptWaveform(sampling_rate, samples.data(), samples.size());
  auto results = tagger.Compute(stream.get());

  const auto end = std::chrono::steady_clock::now();
  fprintf(stderr, "Done\n");

  // Print the events, most probable first.
  int32_t i = 0;
  for (const auto &event : results) {
    fprintf(stderr, "%d: %s\n", i, event.ToString().c_str());
    i += 1;
  }

  // Elapsed wall-clock time in seconds (millisecond resolution).
  float elapsed_seconds =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
          .count() /
      1000.;
  float rtf = elapsed_seconds / duration;
  fprintf(stderr, "Num threads: %d\n", config.model.num_threads);
  fprintf(stderr, "Wave duration: %.3f\n", duration);
  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
          elapsed_seconds, duration, rtf);

  return 0;
}
... ...