Fangjun Kuang
Committed by GitHub

Support audio tagging using zipformer (#747)

  1 +#!/usr/bin/env bash
  2 +
  3 +set -ex
  4 +
  5 +log() {
  6 + # This function is from espnet
  7 + local fname=${BASH_SOURCE[1]##*/}
  8 + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
  9 +}
  10 +
  11 +echo "EXE is $EXE"
  12 +echo "PATH: $PATH"
  13 +
  14 +which $EXE
  15 +
  16 +log "------------------------------------------------------------"
  17 +log "Run zipformer for audio tagging "
  18 +log "------------------------------------------------------------"
  19 +
  20 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
  21 +tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
  22 +rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
  23 +repo=sherpa-onnx-zipformer-audio-tagging-2024-04-09
  24 +ls -lh $repo
  25 +
  26 +for w in 1.wav 2.wav 3.wav 4.wav; do
  27 + $EXE \
  28 + --zipformer-model=$repo/model.onnx \
  29 + --labels=$repo/class_labels_indices.csv \
  30 + $repo/test_wavs/$w
  31 +done
  32 +rm -rf $repo
@@ -15,6 +15,7 @@ on:
       - '.github/scripts/test-offline-ctc.sh'
       - '.github/scripts/test-online-ctc.sh'
       - '.github/scripts/test-offline-tts.sh'
+      - '.github/scripts/test-audio-tagging.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
@@ -32,6 +33,7 @@ on:
       - '.github/scripts/test-offline-ctc.sh'
       - '.github/scripts/test-online-ctc.sh'
       - '.github/scripts/test-offline-tts.sh'
+      - '.github/scripts/test-audio-tagging.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
@@ -124,6 +126,14 @@ jobs:
          name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
          path: build/bin/*

+      - name: Test Audio tagging
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline-audio-tagging
+
+          .github/scripts/test-audio-tagging.sh
+
      - name: Test online CTC
        shell: bash
        run: |
@@ -15,6 +15,7 @@ on:
       - '.github/scripts/test-offline-ctc.sh'
       - '.github/scripts/test-offline-tts.sh'
       - '.github/scripts/test-online-ctc.sh'
+      - '.github/scripts/test-audio-tagging.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
@@ -31,6 +32,7 @@ on:
       - '.github/scripts/test-offline-ctc.sh'
       - '.github/scripts/test-offline-tts.sh'
       - '.github/scripts/test-online-ctc.sh'
+      - '.github/scripts/test-audio-tagging.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
@@ -103,6 +105,14 @@ jobs:
          otool -L build/bin/sherpa-onnx
          otool -l build/bin/sherpa-onnx

+      - name: Test Audio tagging
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline-audio-tagging
+
+          .github/scripts/test-audio-tagging.sh
+
      - name: Test C API
        shell: bash
        run: |
@@ -14,6 +14,7 @@ on:
       - '.github/scripts/test-offline-ctc.sh'
       - '.github/scripts/test-online-ctc.sh'
       - '.github/scripts/test-offline-tts.sh'
+      - '.github/scripts/test-audio-tagging.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
@@ -28,6 +29,7 @@ on:
       - '.github/scripts/test-offline-ctc.sh'
       - '.github/scripts/test-online-ctc.sh'
       - '.github/scripts/test-offline-tts.sh'
+      - '.github/scripts/test-audio-tagging.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
@@ -70,6 +72,14 @@ jobs:

          ls -lh ./bin/Release/sherpa-onnx.exe

+      - name: Test Audio tagging
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin/Release:$PATH
+          export EXE=sherpa-onnx-offline-audio-tagging.exe
+
+          .github/scripts/test-audio-tagging.sh
+
      - name: Test C API
        shell: bash
        run: |
@@ -14,6 +14,7 @@ on:
       - '.github/scripts/test-offline-ctc.sh'
       - '.github/scripts/test-offline-tts.sh'
       - '.github/scripts/test-online-ctc.sh'
+      - '.github/scripts/test-audio-tagging.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
@@ -28,6 +29,7 @@ on:
       - '.github/scripts/test-offline-ctc.sh'
       - '.github/scripts/test-offline-tts.sh'
       - '.github/scripts/test-online-ctc.sh'
+      - '.github/scripts/test-audio-tagging.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
@@ -85,6 +87,13 @@ jobs:
      #    export EXE=sherpa-onnx-offline-language-identification.exe
      #
      #    .github/scripts/test-spoken-language-identification.sh
+      - name: Test Audio tagging
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin/Release:$PATH
+          export EXE=sherpa-onnx-offline-audio-tagging.exe
+
+          .github/scripts/test-audio-tagging.sh

      - name: Test online CTC
        shell: bash
@@ -46,6 +46,7 @@ def enable_alsa():
def get_binaries():
    binaries = [
        "sherpa-onnx",
+        "sherpa-onnx-offline-audio-tagging",
        "sherpa-onnx-keyword-spotter",
        "sherpa-onnx-microphone",
        "sherpa-onnx-microphone-offline",
  1 +go.sum
  2 +vad-asr-paraformer
@@ -4,7 +4,7 @@ const sherpa_onnx = require('sherpa-onnx');

function createOfflineTts() {
  let offlineTtsVitsModelConfig = {
-    model: './vits-icefall-zh-aishell3/vits-aishell3.onnx',
+    model: './vits-icefall-zh-aishell3/model.onnx',
    lexicon: './vits-icefall-zh-aishell3/lexicon.txt',
    tokens: './vits-icefall-zh-aishell3/tokens.txt',
    dataDir: '',
@@ -111,6 +111,16 @@ list(APPEND sources
  speaker-embedding-manager.cc
)

+# audio tagging
+list(APPEND sources
+  audio-tagging-impl.cc
+  audio-tagging-label-file.cc
+  audio-tagging-model-config.cc
+  audio-tagging.cc
+  offline-zipformer-audio-tagging-model-config.cc
+  offline-zipformer-audio-tagging-model.cc
+)
+
if(SHERPA_ONNX_ENABLE_TTS)
  list(APPEND sources
    lexicon.cc
@@ -193,6 +203,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
  add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
  add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
  add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
+  add_executable(sherpa-onnx-offline-audio-tagging sherpa-onnx-offline-audio-tagging.cc)

  if(SHERPA_ONNX_ENABLE_TTS)
    add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
@@ -204,6 +215,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
    sherpa-onnx-offline
    sherpa-onnx-offline-parallel
    sherpa-onnx-offline-language-identification
+    sherpa-onnx-offline-audio-tagging
  )
  if(SHERPA_ONNX_ENABLE_TTS)
    list(APPEND main_exes
  1 +// sherpa-onnx/csrc/audio-tagging-impl.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/audio-tagging-impl.h"
  6 +
  7 +#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
  13 + const AudioTaggingConfig &config) {
  14 + if (!config.model.zipformer.model.empty()) {
  15 + return std::make_unique<AudioTaggingZipformerImpl>(config);
  16 + }
  17 +
  18 + SHERPA_ONNX_LOGE(
  19 + "Please specify an audio tagging model! Returning a null pointer");
  20 + return nullptr;
  21 +}
  22 +
  23 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/audio-tagging-impl.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_
  5 +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_
  6 +
  7 +#include <memory>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/audio-tagging.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +class AudioTaggingImpl {
  15 + public:
  16 + virtual ~AudioTaggingImpl() = default;
  17 +
  18 + static std::unique_ptr<AudioTaggingImpl> Create(
  19 + const AudioTaggingConfig &config);
  20 +
  21 + virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
  22 +
  23 + virtual std::vector<AudioEvent> Compute(OfflineStream *s,
  24 + int32_t top_k = -1) const = 0;
  25 +};
  26 +
  27 +} // namespace sherpa_onnx
  28 +
  29 +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_
  1 +// sherpa-onnx/csrc/audio-tagging-label-file.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
  6 +
  7 +#include <fstream>
  8 +#include <sstream>
  9 +#include <string>
  10 +
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +#include "sherpa-onnx/csrc/text-utils.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +AudioTaggingLabels::AudioTaggingLabels(const std::string &filename) {
  17 + std::ifstream is(filename);
  18 + Init(is);
  19 +}
  20 +
  21 +// Format of a label file
  22 +/*
  23 +index,mid,display_name
  24 +0,/m/09x0r,"Speech"
  25 +1,/m/05zppz,"Male speech, man speaking"
  26 +*/
  27 +void AudioTaggingLabels::Init(std::istream &is) {
  28 + std::string line;
  29 + std::getline(is, line); // skip the header
  30 +
  31 + std::string index;
  32 + std::string tmp;
  33 + std::string name;
  34 +
  35 + while (std::getline(is, line)) {
  36 + index.clear();
  37 + name.clear();
  38 + std::istringstream input2(line);
  39 +
  40 + std::getline(input2, index, ',');
  41 + std::getline(input2, tmp, ',');
  42 + std::getline(input2, name);
  43 +
  44 + std::size_t pos{};
  45 + int32_t i = index.empty() ? 0 : std::stoi(index, &pos);
  46 + if (index.empty() || pos != index.size()) {
  47 + SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str());
  48 + exit(-1);
  49 + }
  50 +
  51 + if (i != names_.size()) {
  52 + SHERPA_ONNX_LOGE(
  53 + "Index should be sorted and contiguous. Expected index: %d, given: "
  54 + "%d.",
  55 + static_cast<int32_t>(names_.size()), i);
  56 + }
  57 + if (name.empty() || name.front() != '"' || name.back() != '"') {
  58 + SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str());
  59 + exit(-1);
  60 + }
  61 +
  62 + names_.emplace_back(name.begin() + 1, name.end() - 1);
  63 + }
  64 +}
  65 +
  66 +const std::string &AudioTaggingLabels::GetEventName(int32_t index) const {
  67 + return names_.at(index);
  68 +}
  69 +
  70 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/audio-tagging-label-file.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_
  5 +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_
  6 +
  7 +#include <istream>
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +class AudioTaggingLabels {
  14 + public:
  15 + explicit AudioTaggingLabels(const std::string &filename);
  16 +
  17 + // Return the event name for the given index.
  18 + // The returned reference is valid as long as this object is alive
  19 + const std::string &GetEventName(int32_t index) const;
  20 + int32_t NumEventClasses() const { return names_.size(); }
  21 +
  22 + private:
  23 + void Init(std::istream &is);
  24 +
  25 + private:
  26 + std::vector<std::string> names_;
  27 +};
  28 +
  29 +} // namespace sherpa_onnx
  30 +
  31 +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_
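For orientation: the label file follows the AudioSet class_labels_indices.csv layout shown in the parser comment above, i.e. a header row followed by index,mid,display_name rows with the display name quoted. Below is a minimal usage sketch, assuming it is compiled inside the sherpa-onnx source tree; the CSV path is only an example.

// Sketch: look up event names from an AudioSet-style label file.
#include <cstdio>

#include "sherpa-onnx/csrc/audio-tagging-label-file.h"

int main() {
  // Any CSV whose rows look like: 0,/m/09x0r,"Speech"
  sherpa_onnx::AudioTaggingLabels labels("./class_labels_indices.csv");

  printf("num event classes: %d\n", labels.NumEventClasses());
  // Row 0 of AudioSet is 0,/m/09x0r,"Speech", so this prints: Speech
  printf("event 0: %s\n", labels.GetEventName(0).c_str());
  return 0;
}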
  1 +// sherpa-onnx/csrc/audio-tagging-model-config.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
  6 +
  7 +namespace sherpa_onnx {
  8 +
  9 +void AudioTaggingModelConfig::Register(ParseOptions *po) {
  10 + zipformer.Register(po);
  11 +
  12 + po->Register("num-threads", &num_threads,
  13 + "Number of threads to run the neural network");
  14 +
  15 + po->Register("debug", &debug,
  16 + "true to print model information while loading it.");
  17 +
  18 + po->Register("provider", &provider,
  19 + "Specify a provider to use: cpu, cuda, coreml");
  20 +}
  21 +
  22 +bool AudioTaggingModelConfig::Validate() const {
  23 + if (!zipformer.model.empty() && !zipformer.Validate()) {
  24 + return false;
  25 + }
  26 +
  27 + return true;
  28 +}
  29 +
  30 +std::string AudioTaggingModelConfig::ToString() const {
  31 + std::ostringstream os;
  32 +
  33 + os << "AudioTaggingModelConfig(";
  34 + os << "zipformer=" << zipformer.ToString() << ", ";
  35 + os << "num_threads=" << num_threads << ", ";
  36 + os << "debug=" << (debug ? "True" : "False") << ", ";
  37 + os << "provider=\"" << provider << "\")";
  38 +
  39 + return os.str();
  40 +}
  41 +
  42 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/audio-tagging-model-config.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h"
  10 +#include "sherpa-onnx/csrc/parse-options.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct AudioTaggingModelConfig {
  15 + struct OfflineZipformerAudioTaggingModelConfig zipformer;
  16 +
  17 + int32_t num_threads = 1;
  18 + bool debug = false;
  19 + std::string provider = "cpu";
  20 +
  21 + AudioTaggingModelConfig() = default;
  22 +
  23 + AudioTaggingModelConfig(
  24 + const OfflineZipformerAudioTaggingModelConfig &zipformer,
  25 + int32_t num_threads, bool debug, const std::string &provider)
  26 + : zipformer(zipformer),
  27 + num_threads(num_threads),
  28 + debug(debug),
  29 + provider(provider) {}
  30 +
  31 + void Register(ParseOptions *po);
  32 + bool Validate() const;
  33 +
  34 + std::string ToString() const;
  35 +};
  36 +
  37 +} // namespace sherpa_onnx
  38 +
  39 +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/audio-tagging-zipformer-impl.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
  5 +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
  6 +
  7 +#include <memory>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/audio-tagging-impl.h"
  12 +#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
  13 +#include "sherpa-onnx/csrc/audio-tagging.h"
  14 +#include "sherpa-onnx/csrc/macros.h"
  15 +#include "sherpa-onnx/csrc/math.h"
  16 +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h"
  17 +
  18 +namespace sherpa_onnx {
  19 +
  20 +class AudioTaggingZipformerImpl : public AudioTaggingImpl {
  21 + public:
  22 + explicit AudioTaggingZipformerImpl(const AudioTaggingConfig &config)
  23 + : config_(config), model_(config.model), labels_(config.labels) {
  24 + if (model_.NumEventClasses() != labels_.NumEventClasses()) {
  25 + SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
  26 + model_.NumEventClasses(), labels_.NumEventClasses());
  27 + exit(-1);
  28 + }
  29 + }
  30 +
  31 + std::unique_ptr<OfflineStream> CreateStream() const override {
  32 + return std::make_unique<OfflineStream>();
  33 + }
  34 +
  35 + std::vector<AudioEvent> Compute(OfflineStream *s,
  36 + int32_t top_k = -1) const override {
  37 + if (top_k < 0) {
  38 + top_k = config_.top_k;
  39 + }
  40 +
  41 + int32_t num_event_classes = model_.NumEventClasses();
  42 +
  43 + if (top_k > num_event_classes) {
  44 + top_k = num_event_classes;
  45 + }
  46 +
  47 + auto memory_info =
  48 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  49 +
  50 + // WARNING(fangjun): It is fixed to 80 for all models from icefall
  51 + int32_t feat_dim = 80;
  52 + std::vector<float> f = s->GetFrames();
  53 +
  54 + int32_t num_frames = f.size() / feat_dim;
  55 +
  56 + std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
  57 +
  58 + Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
  59 + shape.data(), shape.size());
  60 +
  61 + int64_t x_length_scalar = num_frames;
  62 + std::array<int64_t, 1> x_length_shape = {1};
  63 + Ort::Value x_length =
  64 + Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
  65 + x_length_shape.data(), x_length_shape.size());
  66 +
  67 + Ort::Value probs = model_.Forward(std::move(x), std::move(x_length));
  68 +
  69 + const float *p = probs.GetTensorData<float>();
  70 +
  71 + std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k);
  72 +
  73 + std::vector<AudioEvent> ans(top_k);
  74 +
  75 + int32_t i = 0;
  76 +
  77 + for (int32_t index : top_k_indexes) {
  78 + ans[i].name = labels_.GetEventName(index);
  79 + ans[i].index = index;
  80 + ans[i].prob = p[index];
  81 + i += 1;
  82 + }
  83 +
  84 + return ans;
  85 + }
  86 +
  87 + private:
  88 + AudioTaggingConfig config_;
  89 + OfflineZipformerAudioTaggingModel model_;
  90 + AudioTaggingLabels labels_;
  91 +};
  92 +
  93 +} // namespace sherpa_onnx
  94 +
  95 +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
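Compute() above packs the fbank frames into a (1, num_frames, 80) tensor, runs the model, and selects events with TopkIndex() from math.h: the indices of the top_k largest probabilities, largest first. The following standalone sketch illustrates that selection step; it is an illustration of the idea, not the project's implementation.

// Sketch: indices of the k largest values, largest first.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

std::vector<int32_t> TopkIndexSketch(const float *p, int32_t size, int32_t k) {
  std::vector<int32_t> idx(size);
  std::iota(idx.begin(), idx.end(), 0);  // 0, 1, ..., size - 1

  k = std::min(k, size);
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [p](int32_t a, int32_t b) { return p[a] > p[b]; });
  idx.resize(k);
  return idx;
}

int main() {
  std::vector<float> probs = {0.05f, 0.72f, 0.10f, 0.90f, 0.01f};
  for (int32_t i :
       TopkIndexSketch(probs.data(), static_cast<int32_t>(probs.size()), 3)) {
    printf("index %d, prob %.2f\n", i, probs[i]);  // prints indexes 3, 1, 2
  }
  return 0;
}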
  1 +// sherpa-onnx/csrc/audio-tagging.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/audio-tagging.h"
  6 +
  7 +#include "sherpa-onnx/csrc/audio-tagging-impl.h"
  8 +#include "sherpa-onnx/csrc/file-utils.h"
  9 +#include "sherpa-onnx/csrc/macros.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +std::string AudioEvent::ToString() const {
  14 + std::ostringstream os;
  15 + os << "AudioEvent(";
  16 + os << "name=\"" << name << "\", ";
  17 + os << "index=" << index << ", ";
  18 + os << "prob=" << prob << ")";
  19 + return os.str();
  20 +}
  21 +
  22 +void AudioTaggingConfig::Register(ParseOptions *po) {
  23 + model.Register(po);
  24 + po->Register("labels", &labels, "Event label file");
  25 + po->Register("top-k", &top_k, "Top k events to return in the result");
  26 +}
  27 +
  28 +bool AudioTaggingConfig::Validate() const {
  29 + if (!model.Validate()) {
  30 + return false;
  31 + }
  32 +
  33 + if (top_k < 1) {
  34 + SHERPA_ONNX_LOGE("--top-k should be >= 1. Given: %d", top_k);
  35 + return false;
  36 + }
  37 +
  38 + if (labels.empty()) {
  39 + SHERPA_ONNX_LOGE("Please provide --labels");
  40 + return false;
  41 + }
  42 +
  43 + if (!FileExists(labels)) {
  44 + SHERPA_ONNX_LOGE("--labels %s does not exist", labels.c_str());
  45 + return false;
  46 + }
  47 +
  48 + return true;
  49 +}
  50 +std::string AudioTaggingConfig::ToString() const {
  51 + std::ostringstream os;
  52 +
  53 + os << "AudioTaggingConfig(";
  54 + os << "model=" << model.ToString() << ", ";
  55 + os << "labels=\"" << labels << "\", ";
  56 + os << "top_k=" << top_k << ")";
  57 +
  58 + return os.str();
  59 +}
  60 +
  61 +AudioTagging::AudioTagging(const AudioTaggingConfig &config)
  62 + : impl_(AudioTaggingImpl::Create(config)) {}
  63 +
  64 +AudioTagging::~AudioTagging() = default;
  65 +
  66 +std::unique_ptr<OfflineStream> AudioTagging::CreateStream() const {
  67 + return impl_->CreateStream();
  68 +}
  69 +
  70 +std::vector<AudioEvent> AudioTagging::Compute(OfflineStream *s,
  71 + int32_t top_k /*= -1*/) const {
  72 + return impl_->Compute(s, top_k);
  73 +}
  74 +
  75 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/audio-tagging.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_
  5 +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
  12 +#include "sherpa-onnx/csrc/offline-stream.h"
  13 +#include "sherpa-onnx/csrc/parse-options.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +struct AudioTaggingConfig {
  18 + AudioTaggingModelConfig model;
  19 + std::string labels;
  20 +
  21 + int32_t top_k = 5;
  22 +
  23 + AudioTaggingConfig() = default;
  24 +
  25 + AudioTaggingConfig(const AudioTaggingModelConfig &model,
  26 + const std::string &labels, int32_t top_k)
  27 + : model(model), labels(labels), top_k(top_k) {}
  28 +
  29 + void Register(ParseOptions *po);
  30 + bool Validate() const;
  31 +
  32 + std::string ToString() const;
  33 +};
  34 +
  35 +struct AudioEvent {
  36 + std::string name; // name of the event
  37 + int32_t index; // index of the event in the label file
  38 + float prob; // probability of the event
  39 +
  40 + std::string ToString() const;
  41 +};
  42 +
  43 +class AudioTaggingImpl;
  44 +
  45 +class AudioTagging {
  46 + public:
  47 + explicit AudioTagging(const AudioTaggingConfig &config);
  48 +
  49 + ~AudioTagging();
  50 +
  51 + std::unique_ptr<OfflineStream> CreateStream() const;
  52 +
  53 + // If top_k is -1, then config.top_k is used.
  54 + // Otherwise, config.top_k is ignored
  55 + //
  56 + // Return top_k AudioEvent. ans[0].prob is the largest of all returned events.
  57 + std::vector<AudioEvent> Compute(OfflineStream *s, int32_t top_k = -1) const;
  58 +
  59 + private:
  60 + std::unique_ptr<AudioTaggingImpl> impl_;
  61 +};
  62 +
  63 +} // namespace sherpa_onnx
  64 +
  65 +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_
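Taken together, the intended call flow of this header is: fill in an AudioTaggingConfig, construct an AudioTagging object, create an OfflineStream, feed samples, and call Compute(). A condensed sketch follows; the file paths are placeholders, and the full command-line driver is sherpa-onnx-offline-audio-tagging.cc further down in this change.

// Sketch: tag one in-memory wave with the AudioTagging API declared above.
#include <vector>

#include "sherpa-onnx/csrc/audio-tagging.h"

std::vector<sherpa_onnx::AudioEvent> Tag(const std::vector<float> &samples,
                                         int32_t sampling_rate) {
  sherpa_onnx::AudioTaggingConfig config;
  config.model.zipformer.model = "./model.onnx";     // placeholder path
  config.labels = "./class_labels_indices.csv";      // placeholder path
  config.top_k = 5;

  sherpa_onnx::AudioTagging tagger(config);

  auto stream = tagger.CreateStream();
  stream->AcceptWaveform(sampling_rate, samples.data(), samples.size());

  // ans[0].prob is the largest probability among the returned events.
  return tagger.Compute(stream.get());
}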
@@ -97,8 +97,8 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
}

template <typename T>
-void SubtractBlank(T *in, int32_t w, int32_t h,
-                   int32_t blank_idx, float blank_penalty) {
+void SubtractBlank(T *in, int32_t w, int32_t h, int32_t blank_idx,
+                   float blank_penalty) {
  for (int32_t i = 0; i != h; ++i) {
    in[blank_idx] -= blank_penalty;
    in += w;
@@ -116,8 +116,7 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
  });

  int32_t k_num = std::min<int32_t>(size, topk);
-  std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num);
-  return index;
+  return {vec_index.begin(), vec_index.begin() + k_num};
}

}  // namespace sherpa_onnx
@@ -234,7 +234,7 @@ OfflineStream::OfflineStream(
    : impl_(std::make_unique<Impl>(config, context_graph)) {}

OfflineStream::OfflineStream(WhisperTag tag,
-                             ContextGraphPtr context_graph /*= nullptr*/)
+                             ContextGraphPtr context_graph /*= {}*/)
    : impl_(std::make_unique<Impl>(tag, context_graph)) {}

OfflineStream::~OfflineStream() = default;
@@ -71,10 +71,9 @@ struct WhisperTag {};
class OfflineStream {
 public:
  explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
-                         ContextGraphPtr context_graph = nullptr);
+                         ContextGraphPtr context_graph = {});

-  explicit OfflineStream(WhisperTag tag,
-                         ContextGraphPtr context_graph = nullptr);
+  explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {});
  ~OfflineStream();

  /**
  1 +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OfflineZipformerAudioTaggingModelConfig::Register(ParseOptions *po) {
  13 + po->Register("zipformer-model", &model,
  14 + "Path to zipformer model for audio tagging");
  15 +}
  16 +
  17 +bool OfflineZipformerAudioTaggingModelConfig::Validate() const {
  18 + if (model.empty()) {
  19 + SHERPA_ONNX_LOGE("Please provide --zipformer-model");
  20 + return false;
  21 + }
  22 +
  23 + if (!FileExists(model)) {
  24 + SHERPA_ONNX_LOGE("--zipformer-model: %s does not exist", model.c_str());
  25 + return false;
  26 + }
  27 +
  28 + return true;
  29 +}
  30 +
  31 +std::string OfflineZipformerAudioTaggingModelConfig::ToString() const {
  32 + std::ostringstream os;
  33 +
  34 + os << "OfflineZipformerAudioTaggingModelConfig(";
  35 + os << "model=\"" << model << "\")";
  36 +
  37 + return os.str();
  38 +}
  39 +
  40 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +struct OfflineZipformerAudioTaggingModelConfig {
  14 + std::string model;
  15 +
  16 + OfflineZipformerAudioTaggingModelConfig() = default;
  17 +
  18 + explicit OfflineZipformerAudioTaggingModelConfig(const std::string &model)
  19 + : model(model) {}
  20 +
  21 + void Register(ParseOptions *po);
  22 + bool Validate() const;
  23 +
  24 + std::string ToString() const;
  25 +};
  26 +
  27 +} // namespace sherpa_onnx
  28 +
  29 +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/onnx-utils.h"
  11 +#include "sherpa-onnx/csrc/session.h"
  12 +#include "sherpa-onnx/csrc/text-utils.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +class OfflineZipformerAudioTaggingModel::Impl {
  17 + public:
  18 + explicit Impl(const AudioTaggingModelConfig &config)
  19 + : config_(config),
  20 + env_(ORT_LOGGING_LEVEL_ERROR),
  21 + sess_opts_(GetSessionOptions(config)),
  22 + allocator_{} {
  23 + auto buf = ReadFile(config_.zipformer.model);
  24 + Init(buf.data(), buf.size());
  25 + }
  26 +
  27 +#if __ANDROID_API__ >= 9
  28 + Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config)
  29 + : config_(config),
  30 + env_(ORT_LOGGING_LEVEL_ERROR),
  31 + sess_opts_(GetSessionOptions(config)),
  32 + allocator_{} {
  33 + auto buf = ReadFile(mgr, config_.zipformer.model);
  34 + Init(buf.data(), buf.size());
  35 + }
  36 +#endif
  37 +
  38 + Ort::Value Forward(Ort::Value features, Ort::Value features_length) {
  39 + std::array<Ort::Value, 2> inputs = {std::move(features),
  40 + std::move(features_length)};
  41 +
  42 + auto ans =
  43 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  44 + output_names_ptr_.data(), output_names_ptr_.size());
  45 + return std::move(ans[0]);
  46 + }
  47 +
  48 + int32_t NumEventClasses() const { return num_event_classes_; }
  49 +
  50 + OrtAllocator *Allocator() const { return allocator_; }
  51 +
  52 + private:
  53 + void Init(void *model_data, size_t model_data_length) {
  54 + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
  55 + sess_opts_);
  56 +
  57 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  58 +
  59 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  60 +
  61 + // get meta data
  62 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  63 + if (config_.debug) {
  64 + std::ostringstream os;
  65 + PrintModelMetadata(os, meta_data);
  66 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  67 + }
  68 +
  69 + // get num_event_classes from the output[0].shape,
  70 + // which is (N, num_event_classes)
  71 + num_event_classes_ =
  72 + sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1];
  73 + }
  74 +
  75 + private:
  76 + AudioTaggingModelConfig config_;
  77 + Ort::Env env_;
  78 + Ort::SessionOptions sess_opts_;
  79 + Ort::AllocatorWithDefaultOptions allocator_;
  80 +
  81 + std::unique_ptr<Ort::Session> sess_;
  82 +
  83 + std::vector<std::string> input_names_;
  84 + std::vector<const char *> input_names_ptr_;
  85 +
  86 + std::vector<std::string> output_names_;
  87 + std::vector<const char *> output_names_ptr_;
  88 +
  89 + int32_t num_event_classes_ = 0;
  90 +};
  91 +
  92 +OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel(
  93 + const AudioTaggingModelConfig &config)
  94 + : impl_(std::make_unique<Impl>(config)) {}
  95 +
  96 +#if __ANDROID_API__ >= 9
  97 +OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel(
  98 + AAssetManager *mgr, const AudioTaggingModelConfig &config)
  99 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  100 +#endif
  101 +
  102 +OfflineZipformerAudioTaggingModel::~OfflineZipformerAudioTaggingModel() =
  103 + default;
  104 +
  105 +Ort::Value OfflineZipformerAudioTaggingModel::Forward(
  106 + Ort::Value features, Ort::Value features_length) const {
  107 + return impl_->Forward(std::move(features), std::move(features_length));
  108 +}
  109 +
  110 +int32_t OfflineZipformerAudioTaggingModel::NumEventClasses() const {
  111 + return impl_->NumEventClasses();
  112 +}
  113 +
  114 +OrtAllocator *OfflineZipformerAudioTaggingModel::Allocator() const {
  115 + return impl_->Allocator();
  116 +}
  117 +
  118 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_
  6 +#include <memory>
  7 +#include <utility>
  8 +
  9 +#if __ANDROID_API__ >= 9
  10 +#include "android/asset_manager.h"
  11 +#include "android/asset_manager_jni.h"
  12 +#endif
  13 +
  14 +#include "onnxruntime_cxx_api.h" // NOLINT
  15 +#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
  16 +
  17 +namespace sherpa_onnx {
  18 +
  19 +/** This class implements the zipformer audio tagging model of the AudioSet
  20 + * recipe from icefall.
  21 + *
  22 + * See
  23 + * https://github.com/k2-fsa/icefall/blob/master/egs/audioset/AT/zipformer/export-onnx.py
  24 + */
  25 +class OfflineZipformerAudioTaggingModel {
  26 + public:
  27 + explicit OfflineZipformerAudioTaggingModel(
  28 + const AudioTaggingModelConfig &config);
  29 +
  30 +#if __ANDROID_API__ >= 9
  31 + OfflineZipformerAudioTaggingModel(AAssetManager *mgr,
  32 + const AudioTaggingModelConfig &config);
  33 +#endif
  34 +
  35 + ~OfflineZipformerAudioTaggingModel();
  36 +
  37 + /** Run the forward method of the model.
  38 + *
  39 + * @param features A tensor of shape (N, T, C).
  40 + * @param features_length A 1-D tensor of shape (N,) containing number of
  41 + * valid frames in `features` before padding.
  42 + * Its dtype is int64_t.
  43 + *
  44 + * @return Return a tensor
  45 + * - probs: A 2-D tensor of shape (N, num_event_classes).
  46 + */
  47 + Ort::Value Forward(Ort::Value features, Ort::Value features_length) const;
  48 +
  49 + /** Return the number of event classes of the model
  50 + */
  51 + int32_t NumEventClasses() const;
  52 +
  53 + /** Return an allocator for allocating memory
  54 + */
  55 + OrtAllocator *Allocator() const;
  56 +
  57 + private:
  58 + class Impl;
  59 + std::unique_ptr<Impl> impl_;
  60 +};
  61 +
  62 +} // namespace sherpa_onnx
  63 +
  64 +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_
@@ -4,6 +4,8 @@

#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"

+#include <string>
+
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
@@ -4,7 +4,6 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
#include <memory>
-#include <string>
#include <utility>
#include <vector>

@@ -140,9 +140,11 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) {
  return GetSessionOptionsImpl(config.num_threads, config.provider);
}

+#if SHERPA_ONNX_ENABLE_TTS
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
  return GetSessionOptionsImpl(config.num_threads, config.provider);
}
+#endif

Ort::SessionOptions GetSessionOptions(
    const SpeakerEmbeddingExtractorConfig &config) {
@@ -154,4 +156,8 @@ Ort::SessionOptions GetSessionOptions(
  return GetSessionOptionsImpl(config.num_threads, config.provider);
}

+Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) {
+  return GetSessionOptionsImpl(config.num_threads, config.provider);
+}
+
}  // namespace sherpa_onnx
@@ -6,15 +6,19 @@
#define SHERPA_ONNX_CSRC_SESSION_H_

#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
-#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/vad-model-config.h"

+#if SHERPA_ONNX_ENABLE_TTS
+#include "sherpa-onnx/csrc/offline-tts-model-config.h"
+#endif
+
namespace sherpa_onnx {

Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
@@ -27,7 +31,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);

Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);

+#if SHERPA_ONNX_ENABLE_TTS
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
+#endif

Ort::SessionOptions GetSessionOptions(
    const SpeakerEmbeddingExtractorConfig &config);
@@ -35,6 +41,8 @@ Ort::SessionOptions GetSessionOptions(
Ort::SessionOptions GetSessionOptions(
    const SpokenLanguageIdentificationConfig &config);

+Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config);
+
}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_SESSION_H_
  1 +// sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#include <stdio.h>
  5 +
  6 +#include "sherpa-onnx/csrc/audio-tagging.h"
  7 +#include "sherpa-onnx/csrc/parse-options.h"
  8 +#include "sherpa-onnx/csrc/wave-reader.h"
  9 +
  10 +int32_t main(int32_t argc, char *argv[]) {
  11 + const char *kUsageMessage = R"usage(
  12 +Audio tagging from a file.
  13 +
  14 +Usage:
  15 +
  16 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
  17 +tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
  18 +rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
  19 +
  20 +./bin/sherpa-onnx-offline-audio-tagging \
  21 + --zipformer-model=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx \
  22 + --labels=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv \
  23 + sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/0.wav
  24 +
  25 +The input wave file should be a single-channel, 16-bit PCM encoded wave file;
  26 +its sampling rate can be arbitrary and does not need to be 16 kHz.
  27 +
  28 +Please see
  29 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
  30 +for more models.
  31 +)usage";
  32 +
  33 + sherpa_onnx::ParseOptions po(kUsageMessage);
  34 + sherpa_onnx::AudioTaggingConfig config;
  35 + config.Register(&po);
  36 + po.Read(argc, argv);
  37 +
  38 + if (po.NumArgs() != 1) {
  39 + fprintf(stderr, "\nError: Please provide 1 wave file\n\n");
  40 + po.PrintUsage();
  41 + exit(EXIT_FAILURE);
  42 + }
  43 +
  44 + fprintf(stderr, "%s\n", config.ToString().c_str());
  45 +
  46 + if (!config.Validate()) {
  47 + fprintf(stderr, "Errors in config!\n");
  48 + return -1;
  49 + }
  50 +
  51 + sherpa_onnx::AudioTagging tagger(config);
  52 + std::string wav_filename = po.GetArg(1);
  53 +
  54 + int32_t sampling_rate = -1;
  55 +
  56 + bool is_ok = false;
  57 + const std::vector<float> samples =
  58 + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
  59 +
  60 + if (!is_ok) {
  61 + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
  62 + return -1;
  63 + }
  64 +
  65 + const float duration = samples.size() / static_cast<float>(sampling_rate);
  66 +
  67 + fprintf(stderr, "Start to compute\n");
  68 + const auto begin = std::chrono::steady_clock::now();
  69 +
  70 + auto stream = tagger.CreateStream();
  71 +
  72 + stream->AcceptWaveform(sampling_rate, samples.data(), samples.size());
  73 +
  74 + auto results = tagger.Compute(stream.get());
  75 + const auto end = std::chrono::steady_clock::now();
  76 + fprintf(stderr, "Done\n");
  77 +
  78 + int32_t i = 0;
  79 +
  80 + for (const auto &event : results) {
  81 + fprintf(stderr, "%d: %s\n", i, event.ToString().c_str());
  82 + i += 1;
  83 + }
  84 +
  85 + float elapsed_seconds =
  86 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  87 + .count() /
  88 + 1000.;
  89 + float rtf = elapsed_seconds / duration;
  90 + fprintf(stderr, "Num threads: %d\n", config.model.num_threads);
  91 + fprintf(stderr, "Wave duration: %.3f\n", duration);
  92 + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  93 + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
  94 + elapsed_seconds, duration, rtf);
  95 +
  96 + return 0;
  97 +}
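As a sanity check on the numbers reported above: RTF = elapsed_seconds / duration, so tagging a 10-second wave in 0.5 s gives 0.5 / 10 = 0.05, and values below 1 mean the tagger runs faster than real time.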