Committed by
GitHub
Support audio tagging using zipformer (#747)
正在显示
30 个修改的文件
包含
927 行增加
和
11 行删除
.github/scripts/test-audio-tagging.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +set -ex | ||
| 4 | + | ||
| 5 | +log() { | ||
| 6 | + # This function is from espnet | ||
| 7 | + local fname=${BASH_SOURCE[1]##*/} | ||
| 8 | + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | ||
| 9 | +} | ||
| 10 | + | ||
| 11 | +echo "EXE is $EXE" | ||
| 12 | +echo "PATH: $PATH" | ||
| 13 | + | ||
| 14 | +which $EXE | ||
| 15 | + | ||
| 16 | +log "------------------------------------------------------------" | ||
| 17 | +log "Run zipformer for audio tagging " | ||
| 18 | +log "------------------------------------------------------------" | ||
| 19 | + | ||
| 20 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 | ||
| 21 | +tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 | ||
| 22 | +rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 | ||
| 23 | +repo=sherpa-onnx-zipformer-audio-tagging-2024-04-09 | ||
| 24 | +ls -lh $repo | ||
| 25 | + | ||
| 26 | +for w in 1.wav 2.wav 3.wav 4.wav; do | ||
| 27 | + $EXE \ | ||
| 28 | + --zipformer-model=$repo/model.onnx \ | ||
| 29 | + --labels=$repo/class_labels_indices.csv \ | ||
| 30 | + $repo/test_wavs/$w | ||
| 31 | +done | ||
| 32 | +rm -rf $repo |
| @@ -15,6 +15,7 @@ on: | @@ -15,6 +15,7 @@ on: | ||
| 15 | - '.github/scripts/test-offline-ctc.sh' | 15 | - '.github/scripts/test-offline-ctc.sh' |
| 16 | - '.github/scripts/test-online-ctc.sh' | 16 | - '.github/scripts/test-online-ctc.sh' |
| 17 | - '.github/scripts/test-offline-tts.sh' | 17 | - '.github/scripts/test-offline-tts.sh' |
| 18 | + - '.github/scripts/test-audio-tagging.sh' | ||
| 18 | - 'CMakeLists.txt' | 19 | - 'CMakeLists.txt' |
| 19 | - 'cmake/**' | 20 | - 'cmake/**' |
| 20 | - 'sherpa-onnx/csrc/*' | 21 | - 'sherpa-onnx/csrc/*' |
| @@ -32,6 +33,7 @@ on: | @@ -32,6 +33,7 @@ on: | ||
| 32 | - '.github/scripts/test-offline-ctc.sh' | 33 | - '.github/scripts/test-offline-ctc.sh' |
| 33 | - '.github/scripts/test-online-ctc.sh' | 34 | - '.github/scripts/test-online-ctc.sh' |
| 34 | - '.github/scripts/test-offline-tts.sh' | 35 | - '.github/scripts/test-offline-tts.sh' |
| 36 | + - '.github/scripts/test-audio-tagging.sh' | ||
| 35 | - 'CMakeLists.txt' | 37 | - 'CMakeLists.txt' |
| 36 | - 'cmake/**' | 38 | - 'cmake/**' |
| 37 | - 'sherpa-onnx/csrc/*' | 39 | - 'sherpa-onnx/csrc/*' |
| @@ -124,6 +126,14 @@ jobs: | @@ -124,6 +126,14 @@ jobs: | ||
| 124 | name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} | 126 | name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} |
| 125 | path: build/bin/* | 127 | path: build/bin/* |
| 126 | 128 | ||
| 129 | + - name: Test Audio tagging | ||
| 130 | + shell: bash | ||
| 131 | + run: | | ||
| 132 | + export PATH=$PWD/build/bin:$PATH | ||
| 133 | + export EXE=sherpa-onnx-offline-audio-tagging | ||
| 134 | + | ||
| 135 | + .github/scripts/test-audio-tagging.sh | ||
| 136 | + | ||
| 127 | - name: Test online CTC | 137 | - name: Test online CTC |
| 128 | shell: bash | 138 | shell: bash |
| 129 | run: | | 139 | run: | |
| @@ -15,6 +15,7 @@ on: | @@ -15,6 +15,7 @@ on: | ||
| 15 | - '.github/scripts/test-offline-ctc.sh' | 15 | - '.github/scripts/test-offline-ctc.sh' |
| 16 | - '.github/scripts/test-offline-tts.sh' | 16 | - '.github/scripts/test-offline-tts.sh' |
| 17 | - '.github/scripts/test-online-ctc.sh' | 17 | - '.github/scripts/test-online-ctc.sh' |
| 18 | + - '.github/scripts/test-audio-tagging.sh' | ||
| 18 | - 'CMakeLists.txt' | 19 | - 'CMakeLists.txt' |
| 19 | - 'cmake/**' | 20 | - 'cmake/**' |
| 20 | - 'sherpa-onnx/csrc/*' | 21 | - 'sherpa-onnx/csrc/*' |
| @@ -31,6 +32,7 @@ on: | @@ -31,6 +32,7 @@ on: | ||
| 31 | - '.github/scripts/test-offline-ctc.sh' | 32 | - '.github/scripts/test-offline-ctc.sh' |
| 32 | - '.github/scripts/test-offline-tts.sh' | 33 | - '.github/scripts/test-offline-tts.sh' |
| 33 | - '.github/scripts/test-online-ctc.sh' | 34 | - '.github/scripts/test-online-ctc.sh' |
| 35 | + - '.github/scripts/test-audio-tagging.sh' | ||
| 34 | - 'CMakeLists.txt' | 36 | - 'CMakeLists.txt' |
| 35 | - 'cmake/**' | 37 | - 'cmake/**' |
| 36 | - 'sherpa-onnx/csrc/*' | 38 | - 'sherpa-onnx/csrc/*' |
| @@ -103,6 +105,14 @@ jobs: | @@ -103,6 +105,14 @@ jobs: | ||
| 103 | otool -L build/bin/sherpa-onnx | 105 | otool -L build/bin/sherpa-onnx |
| 104 | otool -l build/bin/sherpa-onnx | 106 | otool -l build/bin/sherpa-onnx |
| 105 | 107 | ||
| 108 | + - name: Test Audio tagging | ||
| 109 | + shell: bash | ||
| 110 | + run: | | ||
| 111 | + export PATH=$PWD/build/bin:$PATH | ||
| 112 | + export EXE=sherpa-onnx-offline-audio-tagging | ||
| 113 | + | ||
| 114 | + .github/scripts/test-audio-tagging.sh | ||
| 115 | + | ||
| 106 | - name: Test C API | 116 | - name: Test C API |
| 107 | shell: bash | 117 | shell: bash |
| 108 | run: | | 118 | run: | |
| @@ -14,6 +14,7 @@ on: | @@ -14,6 +14,7 @@ on: | ||
| 14 | - '.github/scripts/test-offline-ctc.sh' | 14 | - '.github/scripts/test-offline-ctc.sh' |
| 15 | - '.github/scripts/test-online-ctc.sh' | 15 | - '.github/scripts/test-online-ctc.sh' |
| 16 | - '.github/scripts/test-offline-tts.sh' | 16 | - '.github/scripts/test-offline-tts.sh' |
| 17 | + - '.github/scripts/test-audio-tagging.sh' | ||
| 17 | - 'CMakeLists.txt' | 18 | - 'CMakeLists.txt' |
| 18 | - 'cmake/**' | 19 | - 'cmake/**' |
| 19 | - 'sherpa-onnx/csrc/*' | 20 | - 'sherpa-onnx/csrc/*' |
| @@ -28,6 +29,7 @@ on: | @@ -28,6 +29,7 @@ on: | ||
| 28 | - '.github/scripts/test-offline-ctc.sh' | 29 | - '.github/scripts/test-offline-ctc.sh' |
| 29 | - '.github/scripts/test-online-ctc.sh' | 30 | - '.github/scripts/test-online-ctc.sh' |
| 30 | - '.github/scripts/test-offline-tts.sh' | 31 | - '.github/scripts/test-offline-tts.sh' |
| 32 | + - '.github/scripts/test-audio-tagging.sh' | ||
| 31 | - 'CMakeLists.txt' | 33 | - 'CMakeLists.txt' |
| 32 | - 'cmake/**' | 34 | - 'cmake/**' |
| 33 | - 'sherpa-onnx/csrc/*' | 35 | - 'sherpa-onnx/csrc/*' |
| @@ -70,6 +72,14 @@ jobs: | @@ -70,6 +72,14 @@ jobs: | ||
| 70 | 72 | ||
| 71 | ls -lh ./bin/Release/sherpa-onnx.exe | 73 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 72 | 74 | ||
| 75 | + - name: Test Audio tagging | ||
| 76 | + shell: bash | ||
| 77 | + run: | | ||
| 78 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 79 | + export EXE=sherpa-onnx-offline-audio-tagging.exe | ||
| 80 | + | ||
| 81 | + .github/scripts/test-audio-tagging.sh | ||
| 82 | + | ||
| 73 | - name: Test C API | 83 | - name: Test C API |
| 74 | shell: bash | 84 | shell: bash |
| 75 | run: | | 85 | run: | |
| @@ -14,6 +14,7 @@ on: | @@ -14,6 +14,7 @@ on: | ||
| 14 | - '.github/scripts/test-offline-ctc.sh' | 14 | - '.github/scripts/test-offline-ctc.sh' |
| 15 | - '.github/scripts/test-offline-tts.sh' | 15 | - '.github/scripts/test-offline-tts.sh' |
| 16 | - '.github/scripts/test-online-ctc.sh' | 16 | - '.github/scripts/test-online-ctc.sh' |
| 17 | + - '.github/scripts/test-audio-tagging.sh' | ||
| 17 | - 'CMakeLists.txt' | 18 | - 'CMakeLists.txt' |
| 18 | - 'cmake/**' | 19 | - 'cmake/**' |
| 19 | - 'sherpa-onnx/csrc/*' | 20 | - 'sherpa-onnx/csrc/*' |
| @@ -28,6 +29,7 @@ on: | @@ -28,6 +29,7 @@ on: | ||
| 28 | - '.github/scripts/test-offline-ctc.sh' | 29 | - '.github/scripts/test-offline-ctc.sh' |
| 29 | - '.github/scripts/test-offline-tts.sh' | 30 | - '.github/scripts/test-offline-tts.sh' |
| 30 | - '.github/scripts/test-online-ctc.sh' | 31 | - '.github/scripts/test-online-ctc.sh' |
| 32 | + - '.github/scripts/test-audio-tagging.sh' | ||
| 31 | - 'CMakeLists.txt' | 33 | - 'CMakeLists.txt' |
| 32 | - 'cmake/**' | 34 | - 'cmake/**' |
| 33 | - 'sherpa-onnx/csrc/*' | 35 | - 'sherpa-onnx/csrc/*' |
| @@ -85,6 +87,13 @@ jobs: | @@ -85,6 +87,13 @@ jobs: | ||
| 85 | # export EXE=sherpa-onnx-offline-language-identification.exe | 87 | # export EXE=sherpa-onnx-offline-language-identification.exe |
| 86 | # | 88 | # |
| 87 | # .github/scripts/test-spoken-language-identification.sh | 89 | # .github/scripts/test-spoken-language-identification.sh |
| 90 | + - name: Test Audio tagging | ||
| 91 | + shell: bash | ||
| 92 | + run: | | ||
| 93 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 94 | + export EXE=sherpa-onnx-offline-audio-tagging.exe | ||
| 95 | + | ||
| 96 | + .github/scripts/test-audio-tagging.sh | ||
| 88 | 97 | ||
| 89 | - name: Test online CTC | 98 | - name: Test online CTC |
| 90 | shell: bash | 99 | shell: bash |
| @@ -46,6 +46,7 @@ def enable_alsa(): | @@ -46,6 +46,7 @@ def enable_alsa(): | ||
| 46 | def get_binaries(): | 46 | def get_binaries(): |
| 47 | binaries = [ | 47 | binaries = [ |
| 48 | "sherpa-onnx", | 48 | "sherpa-onnx", |
| 49 | + "sherpa-onnx-offline-audio-tagging", | ||
| 49 | "sherpa-onnx-keyword-spotter", | 50 | "sherpa-onnx-keyword-spotter", |
| 50 | "sherpa-onnx-microphone", | 51 | "sherpa-onnx-microphone", |
| 51 | "sherpa-onnx-microphone-offline", | 52 | "sherpa-onnx-microphone-offline", |
| @@ -4,7 +4,7 @@ const sherpa_onnx = require('sherpa-onnx'); | @@ -4,7 +4,7 @@ const sherpa_onnx = require('sherpa-onnx'); | ||
| 4 | 4 | ||
| 5 | function createOfflineTts() { | 5 | function createOfflineTts() { |
| 6 | let offlineTtsVitsModelConfig = { | 6 | let offlineTtsVitsModelConfig = { |
| 7 | - model: './vits-icefall-zh-aishell3/vits-aishell3.onnx', | 7 | + model: './vits-icefall-zh-aishell3/model.onnx', |
| 8 | lexicon: './vits-icefall-zh-aishell3/lexicon.txt', | 8 | lexicon: './vits-icefall-zh-aishell3/lexicon.txt', |
| 9 | tokens: './vits-icefall-zh-aishell3/tokens.txt', | 9 | tokens: './vits-icefall-zh-aishell3/tokens.txt', |
| 10 | dataDir: '', | 10 | dataDir: '', |
| @@ -111,6 +111,16 @@ list(APPEND sources | @@ -111,6 +111,16 @@ list(APPEND sources | ||
| 111 | speaker-embedding-manager.cc | 111 | speaker-embedding-manager.cc |
| 112 | ) | 112 | ) |
| 113 | 113 | ||
| 114 | +# audio tagging | ||
| 115 | +list(APPEND sources | ||
| 116 | + audio-tagging-impl.cc | ||
| 117 | + audio-tagging-label-file.cc | ||
| 118 | + audio-tagging-model-config.cc | ||
| 119 | + audio-tagging.cc | ||
| 120 | + offline-zipformer-audio-tagging-model-config.cc | ||
| 121 | + offline-zipformer-audio-tagging-model.cc | ||
| 122 | +) | ||
| 123 | + | ||
| 114 | if(SHERPA_ONNX_ENABLE_TTS) | 124 | if(SHERPA_ONNX_ENABLE_TTS) |
| 115 | list(APPEND sources | 125 | list(APPEND sources |
| 116 | lexicon.cc | 126 | lexicon.cc |
| @@ -193,6 +203,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | @@ -193,6 +203,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | ||
| 193 | add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) | 203 | add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) |
| 194 | add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) | 204 | add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) |
| 195 | add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) | 205 | add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) |
| 206 | + add_executable(sherpa-onnx-offline-audio-tagging sherpa-onnx-offline-audio-tagging.cc) | ||
| 196 | 207 | ||
| 197 | if(SHERPA_ONNX_ENABLE_TTS) | 208 | if(SHERPA_ONNX_ENABLE_TTS) |
| 198 | add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) | 209 | add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) |
| @@ -204,6 +215,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | @@ -204,6 +215,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | ||
| 204 | sherpa-onnx-offline | 215 | sherpa-onnx-offline |
| 205 | sherpa-onnx-offline-parallel | 216 | sherpa-onnx-offline-parallel |
| 206 | sherpa-onnx-offline-language-identification | 217 | sherpa-onnx-offline-language-identification |
| 218 | + sherpa-onnx-offline-audio-tagging | ||
| 207 | ) | 219 | ) |
| 208 | if(SHERPA_ONNX_ENABLE_TTS) | 220 | if(SHERPA_ONNX_ENABLE_TTS) |
| 209 | list(APPEND main_exes | 221 | list(APPEND main_exes |
sherpa-onnx/csrc/audio-tagging-impl.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/audio-tagging-impl.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/audio-tagging-impl.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create( | ||
| 13 | + const AudioTaggingConfig &config) { | ||
| 14 | + if (!config.model.zipformer.model.empty()) { | ||
| 15 | + return std::make_unique<AudioTaggingZipformerImpl>(config); | ||
| 16 | + } | ||
| 17 | + | ||
| 18 | + SHERPA_ONNX_LOG( | ||
| 19 | + "Please specify an audio tagging model! Return a null pointer"); | ||
| 20 | + return nullptr; | ||
| 21 | +} | ||
| 22 | + | ||
| 23 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/audio-tagging-impl.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/audio-tagging-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/audio-tagging.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +class AudioTaggingImpl { | ||
| 15 | + public: | ||
| 16 | + virtual ~AudioTaggingImpl() = default; | ||
| 17 | + | ||
| 18 | + static std::unique_ptr<AudioTaggingImpl> Create( | ||
| 19 | + const AudioTaggingConfig &config); | ||
| 20 | + | ||
| 21 | + virtual std::unique_ptr<OfflineStream> CreateStream() const = 0; | ||
| 22 | + | ||
| 23 | + virtual std::vector<AudioEvent> Compute(OfflineStream *s, | ||
| 24 | + int32_t top_k = -1) const = 0; | ||
| 25 | +}; | ||
| 26 | + | ||
| 27 | +} // namespace sherpa_onnx | ||
| 28 | + | ||
| 29 | +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ |
sherpa-onnx/csrc/audio-tagging-label-file.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/audio-tagging-label-file.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/audio-tagging-label-file.h" | ||
| 6 | + | ||
| 7 | +#include <fstream> | ||
| 8 | +#include <sstream> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +AudioTaggingLabels::AudioTaggingLabels(const std::string &filename) { | ||
| 17 | + std::ifstream is(filename); | ||
| 18 | + Init(is); | ||
| 19 | +} | ||
| 20 | + | ||
| 21 | +// Format of a label file | ||
| 22 | +/* | ||
| 23 | +index,mid,display_name | ||
| 24 | +0,/m/09x0r,"Speech" | ||
| 25 | +1,/m/05zppz,"Male speech, man speaking" | ||
| 26 | +*/ | ||
| 27 | +void AudioTaggingLabels::Init(std::istream &is) { | ||
| 28 | + std::string line; | ||
| 29 | + std::getline(is, line); // skip the header | ||
| 30 | + | ||
| 31 | + std::string index; | ||
| 32 | + std::string tmp; | ||
| 33 | + std::string name; | ||
| 34 | + | ||
| 35 | + while (std::getline(is, line)) { | ||
| 36 | + index.clear(); | ||
| 37 | + name.clear(); | ||
| 38 | + std::istringstream input2(line); | ||
| 39 | + | ||
| 40 | + std::getline(input2, index, ','); | ||
| 41 | + std::getline(input2, tmp, ','); | ||
| 42 | + std::getline(input2, name); | ||
| 43 | + | ||
| 44 | + std::size_t pos{}; | ||
| 45 | + int32_t i = std::stoi(index, &pos); | ||
| 46 | + if (index.size() == 0 || pos != index.size()) { | ||
| 47 | + SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str()); | ||
| 48 | + exit(-1); | ||
| 49 | + } | ||
| 50 | + | ||
| 51 | + if (i != names_.size()) { | ||
| 52 | + SHERPA_ONNX_LOGE( | ||
| 53 | + "Index should be sorted and contiguous. Expected index: %d, given: " | ||
| 54 | + "%d.", | ||
| 55 | + static_cast<int32_t>(names_.size()), i); | ||
| 56 | + } | ||
| 57 | + if (name.empty() || name.front() != '"' || name.back() != '"') { | ||
| 58 | + SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str()); | ||
| 59 | + exit(-1); | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + names_.emplace_back(name.begin() + 1, name.end() - 1); | ||
| 63 | + } | ||
| 64 | +} | ||
| 65 | + | ||
| 66 | +const std::string &AudioTaggingLabels::GetEventName(int32_t index) const { | ||
| 67 | + return names_.at(index); | ||
| 68 | +} | ||
| 69 | + | ||
| 70 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/audio-tagging-label-file.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/audio-tagging-label-file.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_ | ||
| 6 | + | ||
| 7 | +#include <istream> | ||
| 8 | +#include <string> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +class AudioTaggingLabels { | ||
| 14 | + public: | ||
| 15 | + explicit AudioTaggingLabels(const std::string &filename); | ||
| 16 | + | ||
| 17 | + // Return the event name for the given index. | ||
| 18 | + // The returned reference is valid as long as this object is alive | ||
| 19 | + const std::string &GetEventName(int32_t index) const; | ||
| 20 | + int32_t NumEventClasses() const { return names_.size(); } | ||
| 21 | + | ||
| 22 | + private: | ||
| 23 | + void Init(std::istream &is); | ||
| 24 | + | ||
| 25 | + private: | ||
| 26 | + std::vector<std::string> names_; | ||
| 27 | +}; | ||
| 28 | + | ||
| 29 | +} // namespace sherpa_onnx | ||
| 30 | + | ||
| 31 | +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_ |
| 1 | +// sherpa-onnx/csrc/audio-tagging-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" | ||
| 6 | + | ||
| 7 | +namespace sherpa_onnx { | ||
| 8 | + | ||
| 9 | +void AudioTaggingModelConfig::Register(ParseOptions *po) { | ||
| 10 | + zipformer.Register(po); | ||
| 11 | + | ||
| 12 | + po->Register("num-threads", &num_threads, | ||
| 13 | + "Number of threads to run the neural network"); | ||
| 14 | + | ||
| 15 | + po->Register("debug", &debug, | ||
| 16 | + "true to print model information while loading it."); | ||
| 17 | + | ||
| 18 | + po->Register("provider", &provider, | ||
| 19 | + "Specify a provider to use: cpu, cuda, coreml"); | ||
| 20 | +} | ||
| 21 | + | ||
| 22 | +bool AudioTaggingModelConfig::Validate() const { | ||
| 23 | + if (!zipformer.model.empty() && !zipformer.Validate()) { | ||
| 24 | + return false; | ||
| 25 | + } | ||
| 26 | + | ||
| 27 | + return true; | ||
| 28 | +} | ||
| 29 | + | ||
| 30 | +std::string AudioTaggingModelConfig::ToString() const { | ||
| 31 | + std::ostringstream os; | ||
| 32 | + | ||
| 33 | + os << "AudioTaggingModelConfig("; | ||
| 34 | + os << "zipformer=" << zipformer.ToString() << ", "; | ||
| 35 | + os << "num_threads=" << num_threads << ", "; | ||
| 36 | + os << "debug=" << (debug ? "True" : "False") << ", "; | ||
| 37 | + os << "provider=\"" << provider << "\")"; | ||
| 38 | + | ||
| 39 | + return os.str(); | ||
| 40 | +} | ||
| 41 | + | ||
| 42 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/audio-tagging-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h" | ||
| 10 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct AudioTaggingModelConfig { | ||
| 15 | + struct OfflineZipformerAudioTaggingModelConfig zipformer; | ||
| 16 | + | ||
| 17 | + int32_t num_threads = 1; | ||
| 18 | + bool debug = false; | ||
| 19 | + std::string provider = "cpu"; | ||
| 20 | + | ||
| 21 | + AudioTaggingModelConfig() = default; | ||
| 22 | + | ||
| 23 | + AudioTaggingModelConfig( | ||
| 24 | + const OfflineZipformerAudioTaggingModelConfig &zipformer, | ||
| 25 | + int32_t num_threads, bool debug, const std::string &provider) | ||
| 26 | + : zipformer(zipformer), | ||
| 27 | + num_threads(num_threads), | ||
| 28 | + debug(debug), | ||
| 29 | + provider(provider) {} | ||
| 30 | + | ||
| 31 | + void Register(ParseOptions *po); | ||
| 32 | + bool Validate() const; | ||
| 33 | + | ||
| 34 | + std::string ToString() const; | ||
| 35 | +}; | ||
| 36 | + | ||
| 37 | +} // namespace sherpa_onnx | ||
| 38 | + | ||
| 39 | +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/csrc/audio-tagging-zipformer-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/audio-tagging-impl.h" | ||
| 12 | +#include "sherpa-onnx/csrc/audio-tagging-label-file.h" | ||
| 13 | +#include "sherpa-onnx/csrc/audio-tagging.h" | ||
| 14 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 15 | +#include "sherpa-onnx/csrc/math.h" | ||
| 16 | +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h" | ||
| 17 | + | ||
| 18 | +namespace sherpa_onnx { | ||
| 19 | + | ||
| 20 | +class AudioTaggingZipformerImpl : public AudioTaggingImpl { | ||
| 21 | + public: | ||
| 22 | + explicit AudioTaggingZipformerImpl(const AudioTaggingConfig &config) | ||
| 23 | + : config_(config), model_(config.model), labels_(config.labels) { | ||
| 24 | + if (model_.NumEventClasses() != labels_.NumEventClasses()) { | ||
| 25 | + SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)", | ||
| 26 | + model_.NumEventClasses(), labels_.NumEventClasses()); | ||
| 27 | + exit(-1); | ||
| 28 | + } | ||
| 29 | + } | ||
| 30 | + | ||
| 31 | + std::unique_ptr<OfflineStream> CreateStream() const override { | ||
| 32 | + return std::make_unique<OfflineStream>(); | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + std::vector<AudioEvent> Compute(OfflineStream *s, | ||
| 36 | + int32_t top_k = -1) const override { | ||
| 37 | + if (top_k < 0) { | ||
| 38 | + top_k = config_.top_k; | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + int32_t num_event_classes = model_.NumEventClasses(); | ||
| 42 | + | ||
| 43 | + if (top_k > num_event_classes) { | ||
| 44 | + top_k = num_event_classes; | ||
| 45 | + } | ||
| 46 | + | ||
| 47 | + auto memory_info = | ||
| 48 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 49 | + | ||
| 50 | + // WARNING(fangjun): It is fixed to 80 for all models from icefall | ||
| 51 | + int32_t feat_dim = 80; | ||
| 52 | + std::vector<float> f = s->GetFrames(); | ||
| 53 | + | ||
| 54 | + int32_t num_frames = f.size() / feat_dim; | ||
| 55 | + | ||
| 56 | + std::array<int64_t, 3> shape = {1, num_frames, feat_dim}; | ||
| 57 | + | ||
| 58 | + Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), | ||
| 59 | + shape.data(), shape.size()); | ||
| 60 | + | ||
| 61 | + int64_t x_length_scalar = num_frames; | ||
| 62 | + std::array<int64_t, 1> x_length_shape = {1}; | ||
| 63 | + Ort::Value x_length = | ||
| 64 | + Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1, | ||
| 65 | + x_length_shape.data(), x_length_shape.size()); | ||
| 66 | + | ||
| 67 | + Ort::Value probs = model_.Forward(std::move(x), std::move(x_length)); | ||
| 68 | + | ||
| 69 | + const float *p = probs.GetTensorData<float>(); | ||
| 70 | + | ||
| 71 | + std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k); | ||
| 72 | + | ||
| 73 | + std::vector<AudioEvent> ans(top_k); | ||
| 74 | + | ||
| 75 | + int32_t i = 0; | ||
| 76 | + | ||
| 77 | + for (int32_t index : top_k_indexes) { | ||
| 78 | + ans[i].name = labels_.GetEventName(index); | ||
| 79 | + ans[i].index = index; | ||
| 80 | + ans[i].prob = p[index]; | ||
| 81 | + i += 1; | ||
| 82 | + } | ||
| 83 | + | ||
| 84 | + return ans; | ||
| 85 | + } | ||
| 86 | + | ||
| 87 | + private: | ||
| 88 | + AudioTaggingConfig config_; | ||
| 89 | + OfflineZipformerAudioTaggingModel model_; | ||
| 90 | + AudioTaggingLabels labels_; | ||
| 91 | +}; | ||
| 92 | + | ||
| 93 | +} // namespace sherpa_onnx | ||
| 94 | + | ||
| 95 | +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ |
sherpa-onnx/csrc/audio-tagging.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/audio-tagging.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/audio-tagging.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/audio-tagging-impl.h" | ||
| 8 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 9 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +std::string AudioEvent::ToString() const { | ||
| 14 | + std::ostringstream os; | ||
| 15 | + os << "AudioEvent("; | ||
| 16 | + os << "name=\"" << name << "\", "; | ||
| 17 | + os << "index=" << index << ", "; | ||
| 18 | + os << "prob=" << prob << ")"; | ||
| 19 | + return os.str(); | ||
| 20 | +} | ||
| 21 | + | ||
| 22 | +void AudioTaggingConfig::Register(ParseOptions *po) { | ||
| 23 | + model.Register(po); | ||
| 24 | + po->Register("labels", &labels, "Event label file"); | ||
| 25 | + po->Register("top-k", &top_k, "Top k events to return in the result"); | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +bool AudioTaggingConfig::Validate() const { | ||
| 29 | + if (!model.Validate()) { | ||
| 30 | + return false; | ||
| 31 | + } | ||
| 32 | + | ||
| 33 | + if (top_k < 1) { | ||
| 34 | + SHERPA_ONNX_LOGE("--top-k should be >= 1. Given: %d", top_k); | ||
| 35 | + return false; | ||
| 36 | + } | ||
| 37 | + | ||
| 38 | + if (labels.empty()) { | ||
| 39 | + SHERPA_ONNX_LOGE("Please provide --labels"); | ||
| 40 | + return false; | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | + if (!FileExists(labels)) { | ||
| 44 | + SHERPA_ONNX_LOGE("--labels %s does not exist", labels.c_str()); | ||
| 45 | + return false; | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + return true; | ||
| 49 | +} | ||
| 50 | +std::string AudioTaggingConfig::ToString() const { | ||
| 51 | + std::ostringstream os; | ||
| 52 | + | ||
| 53 | + os << "AudioTaggingConfig("; | ||
| 54 | + os << "model=" << model.ToString() << ", "; | ||
| 55 | + os << "labels=\"" << labels << "\", "; | ||
| 56 | + os << "top_k=" << top_k << ")"; | ||
| 57 | + | ||
| 58 | + return os.str(); | ||
| 59 | +} | ||
| 60 | + | ||
| 61 | +AudioTagging::AudioTagging(const AudioTaggingConfig &config) | ||
| 62 | + : impl_(AudioTaggingImpl::Create(config)) {} | ||
| 63 | + | ||
| 64 | +AudioTagging::~AudioTagging() = default; | ||
| 65 | + | ||
| 66 | +std::unique_ptr<OfflineStream> AudioTagging::CreateStream() const { | ||
| 67 | + return impl_->CreateStream(); | ||
| 68 | +} | ||
| 69 | + | ||
| 70 | +std::vector<AudioEvent> AudioTagging::Compute(OfflineStream *s, | ||
| 71 | + int32_t top_k /*= -1*/) const { | ||
| 72 | + return impl_->Compute(s, top_k); | ||
| 73 | +} | ||
| 74 | + | ||
| 75 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/audio-tagging.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/audio-tagging.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" | ||
| 12 | +#include "sherpa-onnx/csrc/offline-stream.h" | ||
| 13 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +struct AudioTaggingConfig { | ||
| 18 | + AudioTaggingModelConfig model; | ||
| 19 | + std::string labels; | ||
| 20 | + | ||
| 21 | + int32_t top_k = 5; | ||
| 22 | + | ||
| 23 | + AudioTaggingConfig() = default; | ||
| 24 | + | ||
| 25 | + AudioTaggingConfig(const AudioTaggingModelConfig &model, | ||
| 26 | + const std::string &labels, int32_t top_k) | ||
| 27 | + : model(model), labels(labels), top_k(top_k) {} | ||
| 28 | + | ||
| 29 | + void Register(ParseOptions *po); | ||
| 30 | + bool Validate() const; | ||
| 31 | + | ||
| 32 | + std::string ToString() const; | ||
| 33 | +}; | ||
| 34 | + | ||
| 35 | +struct AudioEvent { | ||
| 36 | + std::string name; // name of the event | ||
| 37 | + int32_t index; // index of the event in the label file | ||
| 38 | + float prob; // probability of the event | ||
| 39 | + | ||
| 40 | + std::string ToString() const; | ||
| 41 | +}; | ||
| 42 | + | ||
| 43 | +class AudioTaggingImpl; | ||
| 44 | + | ||
| 45 | +class AudioTagging { | ||
| 46 | + public: | ||
| 47 | + explicit AudioTagging(const AudioTaggingConfig &config); | ||
| 48 | + | ||
| 49 | + ~AudioTagging(); | ||
| 50 | + | ||
| 51 | + std::unique_ptr<OfflineStream> CreateStream() const; | ||
| 52 | + | ||
| 53 | + // If top_k is -1, then config.top_k is used. | ||
| 54 | + // Otherwise, config.top_k is ignored | ||
| 55 | + // | ||
| 56 | + // Return top_k AudioEvent. ans[0].prob is the largest of all returned events. | ||
| 57 | + std::vector<AudioEvent> Compute(OfflineStream *s, int32_t top_k = -1) const; | ||
| 58 | + | ||
| 59 | + private: | ||
| 60 | + std::unique_ptr<AudioTaggingImpl> impl_; | ||
| 61 | +}; | ||
| 62 | + | ||
| 63 | +} // namespace sherpa_onnx | ||
| 64 | + | ||
| 65 | +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_ |
| @@ -97,8 +97,8 @@ void LogSoftmax(T *in, int32_t w, int32_t h) { | @@ -97,8 +97,8 @@ void LogSoftmax(T *in, int32_t w, int32_t h) { | ||
| 97 | } | 97 | } |
| 98 | 98 | ||
| 99 | template <typename T> | 99 | template <typename T> |
| 100 | -void SubtractBlank(T *in, int32_t w, int32_t h, | ||
| 101 | - int32_t blank_idx, float blank_penalty) { | 100 | +void SubtractBlank(T *in, int32_t w, int32_t h, int32_t blank_idx, |
| 101 | + float blank_penalty) { | ||
| 102 | for (int32_t i = 0; i != h; ++i) { | 102 | for (int32_t i = 0; i != h; ++i) { |
| 103 | in[blank_idx] -= blank_penalty; | 103 | in[blank_idx] -= blank_penalty; |
| 104 | in += w; | 104 | in += w; |
| @@ -116,8 +116,7 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { | @@ -116,8 +116,7 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { | ||
| 116 | }); | 116 | }); |
| 117 | 117 | ||
| 118 | int32_t k_num = std::min<int32_t>(size, topk); | 118 | int32_t k_num = std::min<int32_t>(size, topk); |
| 119 | - std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num); | ||
| 120 | - return index; | 119 | + return {vec_index.begin(), vec_index.begin() + k_num}; |
| 121 | } | 120 | } |
| 122 | 121 | ||
| 123 | } // namespace sherpa_onnx | 122 | } // namespace sherpa_onnx |
| @@ -234,7 +234,7 @@ OfflineStream::OfflineStream( | @@ -234,7 +234,7 @@ OfflineStream::OfflineStream( | ||
| 234 | : impl_(std::make_unique<Impl>(config, context_graph)) {} | 234 | : impl_(std::make_unique<Impl>(config, context_graph)) {} |
| 235 | 235 | ||
| 236 | OfflineStream::OfflineStream(WhisperTag tag, | 236 | OfflineStream::OfflineStream(WhisperTag tag, |
| 237 | - ContextGraphPtr context_graph /*= nullptr*/) | 237 | + ContextGraphPtr context_graph /*= {}*/) |
| 238 | : impl_(std::make_unique<Impl>(tag, context_graph)) {} | 238 | : impl_(std::make_unique<Impl>(tag, context_graph)) {} |
| 239 | 239 | ||
| 240 | OfflineStream::~OfflineStream() = default; | 240 | OfflineStream::~OfflineStream() = default; |
| @@ -71,10 +71,9 @@ struct WhisperTag {}; | @@ -71,10 +71,9 @@ struct WhisperTag {}; | ||
| 71 | class OfflineStream { | 71 | class OfflineStream { |
| 72 | public: | 72 | public: |
| 73 | explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, | 73 | explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, |
| 74 | - ContextGraphPtr context_graph = nullptr); | 74 | + ContextGraphPtr context_graph = {}); |
| 75 | 75 | ||
| 76 | - explicit OfflineStream(WhisperTag tag, | ||
| 77 | - ContextGraphPtr context_graph = nullptr); | 76 | + explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {}); |
| 78 | ~OfflineStream(); | 77 | ~OfflineStream(); |
| 79 | 78 | ||
| 80 | /** | 79 | /** |
| 1 | +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OfflineZipformerAudioTaggingModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("zipformer-model", &model, | ||
| 14 | + "Path to zipformer model for audio tagging"); | ||
| 15 | +} | ||
| 16 | + | ||
| 17 | +bool OfflineZipformerAudioTaggingModelConfig::Validate() const { | ||
| 18 | + if (model.empty()) { | ||
| 19 | + SHERPA_ONNX_LOGE("Please provide --zipformer-model"); | ||
| 20 | + return false; | ||
| 21 | + } | ||
| 22 | + | ||
| 23 | + if (!FileExists(model)) { | ||
| 24 | + SHERPA_ONNX_LOGE("--zipformer-model: %s does not exist", model.c_str()); | ||
| 25 | + return false; | ||
| 26 | + } | ||
| 27 | + | ||
| 28 | + return true; | ||
| 29 | +} | ||
| 30 | + | ||
| 31 | +std::string OfflineZipformerAudioTaggingModelConfig::ToString() const { | ||
| 32 | + std::ostringstream os; | ||
| 33 | + | ||
| 34 | + os << "OfflineZipformerAudioTaggingModelConfig("; | ||
| 35 | + os << "model=\"" << model << "\")"; | ||
| 36 | + | ||
| 37 | + return os.str(); | ||
| 38 | +} | ||
| 39 | + | ||
| 40 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OfflineZipformerAudioTaggingModelConfig { | ||
| 14 | + std::string model; | ||
| 15 | + | ||
| 16 | + OfflineZipformerAudioTaggingModelConfig() = default; | ||
| 17 | + | ||
| 18 | + explicit OfflineZipformerAudioTaggingModelConfig(const std::string &model) | ||
| 19 | + : model(model) {} | ||
| 20 | + | ||
| 21 | + void Register(ParseOptions *po); | ||
| 22 | + bool Validate() const; | ||
| 23 | + | ||
| 24 | + std::string ToString() const; | ||
| 25 | +}; | ||
| 26 | + | ||
| 27 | +} // namespace sherpa_onnx | ||
| 28 | + | ||
| 29 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 11 | +#include "sherpa-onnx/csrc/session.h" | ||
| 12 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +class OfflineZipformerAudioTaggingModel::Impl { | ||
| 17 | + public: | ||
| 18 | + explicit Impl(const AudioTaggingModelConfig &config) | ||
| 19 | + : config_(config), | ||
| 20 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 21 | + sess_opts_(GetSessionOptions(config)), | ||
| 22 | + allocator_{} { | ||
| 23 | + auto buf = ReadFile(config_.zipformer.model); | ||
| 24 | + Init(buf.data(), buf.size()); | ||
| 25 | + } | ||
| 26 | + | ||
| 27 | +#if __ANDROID_API__ >= 9 | ||
| 28 | + Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config) | ||
| 29 | + : config_(config), | ||
| 30 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 31 | + sess_opts_(GetSessionOptions(config)), | ||
| 32 | + allocator_{} { | ||
| 33 | + auto buf = ReadFile(mgr, config_.zipformer.model); | ||
| 34 | + Init(buf.data(), buf.size()); | ||
| 35 | + } | ||
| 36 | +#endif | ||
| 37 | + | ||
| 38 | + Ort::Value Forward(Ort::Value features, Ort::Value features_length) { | ||
| 39 | + std::array<Ort::Value, 2> inputs = {std::move(features), | ||
| 40 | + std::move(features_length)}; | ||
| 41 | + | ||
| 42 | + auto ans = | ||
| 43 | + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 44 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 45 | + return std::move(ans[0]); | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + int32_t NumEventClasses() const { return num_event_classes_; } | ||
| 49 | + | ||
| 50 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 51 | + | ||
| 52 | + private: | ||
| 53 | + void Init(void *model_data, size_t model_data_length) { | ||
| 54 | + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 55 | + sess_opts_); | ||
| 56 | + | ||
| 57 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 58 | + | ||
| 59 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 60 | + | ||
| 61 | + // get meta data | ||
| 62 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 63 | + if (config_.debug) { | ||
| 64 | + std::ostringstream os; | ||
| 65 | + PrintModelMetadata(os, meta_data); | ||
| 66 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + // get num_event_classes from the output[0].shape, | ||
| 70 | + // which is (N, num_event_classes) | ||
| 71 | + num_event_classes_ = | ||
| 72 | + sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
| 73 | + } | ||
| 74 | + | ||
| 75 | + private: | ||
| 76 | + AudioTaggingModelConfig config_; | ||
| 77 | + Ort::Env env_; | ||
| 78 | + Ort::SessionOptions sess_opts_; | ||
| 79 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 80 | + | ||
| 81 | + std::unique_ptr<Ort::Session> sess_; | ||
| 82 | + | ||
| 83 | + std::vector<std::string> input_names_; | ||
| 84 | + std::vector<const char *> input_names_ptr_; | ||
| 85 | + | ||
| 86 | + std::vector<std::string> output_names_; | ||
| 87 | + std::vector<const char *> output_names_ptr_; | ||
| 88 | + | ||
| 89 | + int32_t num_event_classes_ = 0; | ||
| 90 | +}; | ||
| 91 | + | ||
| 92 | +OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel( | ||
| 93 | + const AudioTaggingModelConfig &config) | ||
| 94 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 95 | + | ||
| 96 | +#if __ANDROID_API__ >= 9 | ||
| 97 | +OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel( | ||
| 98 | + AAssetManager *mgr, const AudioTaggingModelConfig &config) | ||
| 99 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 100 | +#endif | ||
| 101 | + | ||
| 102 | +OfflineZipformerAudioTaggingModel::~OfflineZipformerAudioTaggingModel() = | ||
| 103 | + default; | ||
| 104 | + | ||
| 105 | +Ort::Value OfflineZipformerAudioTaggingModel::Forward( | ||
| 106 | + Ort::Value features, Ort::Value features_length) const { | ||
| 107 | + return impl_->Forward(std::move(features), std::move(features_length)); | ||
| 108 | +} | ||
| 109 | + | ||
| 110 | +int32_t OfflineZipformerAudioTaggingModel::NumEventClasses() const { | ||
| 111 | + return impl_->NumEventClasses(); | ||
| 112 | +} | ||
| 113 | + | ||
| 114 | +OrtAllocator *OfflineZipformerAudioTaggingModel::Allocator() const { | ||
| 115 | + return impl_->Allocator(); | ||
| 116 | +} | ||
| 117 | + | ||
| 118 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_ | ||
| 6 | +#include <memory> | ||
| 7 | +#include <utility> | ||
| 8 | + | ||
| 9 | +#if __ANDROID_API__ >= 9 | ||
| 10 | +#include "android/asset_manager.h" | ||
| 11 | +#include "android/asset_manager_jni.h" | ||
| 12 | +#endif | ||
| 13 | + | ||
| 14 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 15 | +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
| 19 | +/** This class implements the zipformer CTC model of the librispeech recipe | ||
| 20 | + * from icefall. | ||
| 21 | + * | ||
| 22 | + * See | ||
| 23 | + * https://github.com/k2-fsa/icefall/blob/master/egs/audioset/AT/zipformer/export-onnx.py | ||
| 24 | + */ | ||
| 25 | +class OfflineZipformerAudioTaggingModel { | ||
| 26 | + public: | ||
| 27 | + explicit OfflineZipformerAudioTaggingModel( | ||
| 28 | + const AudioTaggingModelConfig &config); | ||
| 29 | + | ||
| 30 | +#if __ANDROID_API__ >= 9 | ||
| 31 | + OfflineZipformerAudioTaggingModel(AAssetManager *mgr, | ||
| 32 | + const AudioTaggingModelConfig &config); | ||
| 33 | +#endif | ||
| 34 | + | ||
| 35 | + ~OfflineZipformerAudioTaggingModel(); | ||
| 36 | + | ||
| 37 | + /** Run the forward method of the model. | ||
| 38 | + * | ||
| 39 | + * @param features A tensor of shape (N, T, C). | ||
| 40 | + * @param features_length A 1-D tensor of shape (N,) containing number of | ||
| 41 | + * valid frames in `features` before padding. | ||
| 42 | + * Its dtype is int64_t. | ||
| 43 | + * | ||
| 44 | + * @return Return a tensor | ||
| 45 | + * - probs: A 2-D tensor of shape (N, num_event_classes). | ||
| 46 | + */ | ||
| 47 | + Ort::Value Forward(Ort::Value features, Ort::Value features_length) const; | ||
| 48 | + | ||
| 49 | + /** Return the number of event classes of the model | ||
| 50 | + */ | ||
| 51 | + int32_t NumEventClasses() const; | ||
| 52 | + | ||
| 53 | + /** Return an allocator for allocating memory | ||
| 54 | + */ | ||
| 55 | + OrtAllocator *Allocator() const; | ||
| 56 | + | ||
| 57 | + private: | ||
| 58 | + class Impl; | ||
| 59 | + std::unique_ptr<Impl> impl_; | ||
| 60 | +}; | ||
| 61 | + | ||
| 62 | +} // namespace sherpa_onnx | ||
| 63 | + | ||
| 64 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_ |
| @@ -4,6 +4,8 @@ | @@ -4,6 +4,8 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h" | 5 | #include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h" |
| 6 | 6 | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 7 | #include "sherpa-onnx/csrc/macros.h" | 9 | #include "sherpa-onnx/csrc/macros.h" |
| 8 | #include "sherpa-onnx/csrc/onnx-utils.h" | 10 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 9 | #include "sherpa-onnx/csrc/session.h" | 11 | #include "sherpa-onnx/csrc/session.h" |
| @@ -4,7 +4,6 @@ | @@ -4,7 +4,6 @@ | ||
| 4 | #ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ | 4 | #ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ |
| 5 | #define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ | 5 | #define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ |
| 6 | #include <memory> | 6 | #include <memory> |
| 7 | -#include <string> | ||
| 8 | #include <utility> | 7 | #include <utility> |
| 9 | #include <vector> | 8 | #include <vector> |
| 10 | 9 |
| @@ -140,9 +140,11 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) { | @@ -140,9 +140,11 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) { | ||
| 140 | return GetSessionOptionsImpl(config.num_threads, config.provider); | 140 | return GetSessionOptionsImpl(config.num_threads, config.provider); |
| 141 | } | 141 | } |
| 142 | 142 | ||
| 143 | +#if SHERPA_ONNX_ENABLE_TTS | ||
| 143 | Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { | 144 | Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { |
| 144 | return GetSessionOptionsImpl(config.num_threads, config.provider); | 145 | return GetSessionOptionsImpl(config.num_threads, config.provider); |
| 145 | } | 146 | } |
| 147 | +#endif | ||
| 146 | 148 | ||
| 147 | Ort::SessionOptions GetSessionOptions( | 149 | Ort::SessionOptions GetSessionOptions( |
| 148 | const SpeakerEmbeddingExtractorConfig &config) { | 150 | const SpeakerEmbeddingExtractorConfig &config) { |
| @@ -154,4 +156,8 @@ Ort::SessionOptions GetSessionOptions( | @@ -154,4 +156,8 @@ Ort::SessionOptions GetSessionOptions( | ||
| 154 | return GetSessionOptionsImpl(config.num_threads, config.provider); | 156 | return GetSessionOptionsImpl(config.num_threads, config.provider); |
| 155 | } | 157 | } |
| 156 | 158 | ||
| 159 | +Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) { | ||
| 160 | + return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 161 | +} | ||
| 162 | + | ||
| 157 | } // namespace sherpa_onnx | 163 | } // namespace sherpa_onnx |
| @@ -6,15 +6,19 @@ | @@ -6,15 +6,19 @@ | ||
| 6 | #define SHERPA_ONNX_CSRC_SESSION_H_ | 6 | #define SHERPA_ONNX_CSRC_SESSION_H_ |
| 7 | 7 | ||
| 8 | #include "onnxruntime_cxx_api.h" // NOLINT | 8 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 9 | +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" | ||
| 9 | #include "sherpa-onnx/csrc/offline-lm-config.h" | 10 | #include "sherpa-onnx/csrc/offline-lm-config.h" |
| 10 | #include "sherpa-onnx/csrc/offline-model-config.h" | 11 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| 11 | -#include "sherpa-onnx/csrc/offline-tts-model-config.h" | ||
| 12 | #include "sherpa-onnx/csrc/online-lm-config.h" | 12 | #include "sherpa-onnx/csrc/online-lm-config.h" |
| 13 | #include "sherpa-onnx/csrc/online-model-config.h" | 13 | #include "sherpa-onnx/csrc/online-model-config.h" |
| 14 | #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" | 14 | #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" |
| 15 | #include "sherpa-onnx/csrc/spoken-language-identification.h" | 15 | #include "sherpa-onnx/csrc/spoken-language-identification.h" |
| 16 | #include "sherpa-onnx/csrc/vad-model-config.h" | 16 | #include "sherpa-onnx/csrc/vad-model-config.h" |
| 17 | 17 | ||
| 18 | +#if SHERPA_ONNX_ENABLE_TTS | ||
| 19 | +#include "sherpa-onnx/csrc/offline-tts-model-config.h" | ||
| 20 | +#endif | ||
| 21 | + | ||
| 18 | namespace sherpa_onnx { | 22 | namespace sherpa_onnx { |
| 19 | 23 | ||
| 20 | Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); | 24 | Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); |
| @@ -27,7 +31,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); | @@ -27,7 +31,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); | ||
| 27 | 31 | ||
| 28 | Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); | 32 | Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); |
| 29 | 33 | ||
| 34 | +#if SHERPA_ONNX_ENABLE_TTS | ||
| 30 | Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); | 35 | Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); |
| 36 | +#endif | ||
| 31 | 37 | ||
| 32 | Ort::SessionOptions GetSessionOptions( | 38 | Ort::SessionOptions GetSessionOptions( |
| 33 | const SpeakerEmbeddingExtractorConfig &config); | 39 | const SpeakerEmbeddingExtractorConfig &config); |
| @@ -35,6 +41,8 @@ Ort::SessionOptions GetSessionOptions( | @@ -35,6 +41,8 @@ Ort::SessionOptions GetSessionOptions( | ||
| 35 | Ort::SessionOptions GetSessionOptions( | 41 | Ort::SessionOptions GetSessionOptions( |
| 36 | const SpokenLanguageIdentificationConfig &config); | 42 | const SpokenLanguageIdentificationConfig &config); |
| 37 | 43 | ||
| 44 | +Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config); | ||
| 45 | + | ||
| 38 | } // namespace sherpa_onnx | 46 | } // namespace sherpa_onnx |
| 39 | 47 | ||
| 40 | #endif // SHERPA_ONNX_CSRC_SESSION_H_ | 48 | #endif // SHERPA_ONNX_CSRC_SESSION_H_ |
| 1 | +// sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#include <stdio.h> | ||
| 5 | + | ||
| 6 | +#include "sherpa-onnx/csrc/audio-tagging.h" | ||
| 7 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 8 | +#include "sherpa-onnx/csrc/wave-reader.h" | ||
| 9 | + | ||
| 10 | +int32_t main(int32_t argc, char *argv[]) { | ||
| 11 | + const char *kUsageMessage = R"usage( | ||
| 12 | +Audio tagging from a file. | ||
| 13 | + | ||
| 14 | +Usage: | ||
| 15 | + | ||
| 16 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 | ||
| 17 | +tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 | ||
| 18 | +rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 | ||
| 19 | + | ||
| 20 | +./bin/sherpa-onnx-offline-audio-tagging \ | ||
| 21 | + --zipformer-model=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx \ | ||
| 22 | + --labels=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv \ | ||
| 23 | + sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/0.wav | ||
| 24 | + | ||
| 25 | +Input wave files should be of single channel, 16-bit PCM encoded wave file; its | ||
| 26 | +sampling rate can be arbitrary and does not need to be 16kHz. | ||
| 27 | + | ||
| 28 | +Please see | ||
| 29 | +https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models | ||
| 30 | +for more models. | ||
| 31 | +)usage"; | ||
| 32 | + | ||
| 33 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 34 | + sherpa_onnx::AudioTaggingConfig config; | ||
| 35 | + config.Register(&po); | ||
| 36 | + po.Read(argc, argv); | ||
| 37 | + | ||
| 38 | + if (po.NumArgs() != 1) { | ||
| 39 | + fprintf(stderr, "\nError: Please provide 1 wave file\n\n"); | ||
| 40 | + po.PrintUsage(); | ||
| 41 | + exit(EXIT_FAILURE); | ||
| 42 | + } | ||
| 43 | + | ||
| 44 | + fprintf(stderr, "%s\n", config.ToString().c_str()); | ||
| 45 | + | ||
| 46 | + if (!config.Validate()) { | ||
| 47 | + fprintf(stderr, "Errors in config!\n"); | ||
| 48 | + return -1; | ||
| 49 | + } | ||
| 50 | + | ||
| 51 | + sherpa_onnx::AudioTagging tagger(config); | ||
| 52 | + std::string wav_filename = po.GetArg(1); | ||
| 53 | + | ||
| 54 | + int32_t sampling_rate = -1; | ||
| 55 | + | ||
| 56 | + bool is_ok = false; | ||
| 57 | + const std::vector<float> samples = | ||
| 58 | + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | ||
| 59 | + | ||
| 60 | + if (!is_ok) { | ||
| 61 | + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | ||
| 62 | + return -1; | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + const float duration = samples.size() / static_cast<float>(sampling_rate); | ||
| 66 | + | ||
| 67 | + fprintf(stderr, "Start to compute\n"); | ||
| 68 | + const auto begin = std::chrono::steady_clock::now(); | ||
| 69 | + | ||
| 70 | + auto stream = tagger.CreateStream(); | ||
| 71 | + | ||
| 72 | + stream->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | ||
| 73 | + | ||
| 74 | + auto results = tagger.Compute(stream.get()); | ||
| 75 | + const auto end = std::chrono::steady_clock::now(); | ||
| 76 | + fprintf(stderr, "Done\n"); | ||
| 77 | + | ||
| 78 | + int32_t i = 0; | ||
| 79 | + | ||
| 80 | + for (const auto &event : results) { | ||
| 81 | + fprintf(stderr, "%d: %s\n", i, event.ToString().c_str()); | ||
| 82 | + i += 1; | ||
| 83 | + } | ||
| 84 | + | ||
| 85 | + float elapsed_seconds = | ||
| 86 | + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
| 87 | + .count() / | ||
| 88 | + 1000.; | ||
| 89 | + float rtf = elapsed_seconds / duration; | ||
| 90 | + fprintf(stderr, "Num threads: %d\n", config.model.num_threads); | ||
| 91 | + fprintf(stderr, "Wave duration: %.3f\n", duration); | ||
| 92 | + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | ||
| 93 | + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", | ||
| 94 | + elapsed_seconds, duration, rtf); | ||
| 95 | + | ||
| 96 | + return 0; | ||
| 97 | +} |
-
请 注册 或 登录 后发表评论