Fangjun Kuang
Committed by GitHub

Add Python API and Python examples for audio tagging (#753)

@@ -8,6 +8,15 @@ log() { @@ -8,6 +8,15 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +log "test audio tagging"
  12 +
  13 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
  14 +tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
  15 +rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
  16 +python3 ./python-api-examples/audio-tagging-from-a-file.py
  17 +rm -rf sherpa-onnx-zipformer-audio-tagging-2024-04-09
  18 +
  19 +
11 log "test streaming zipformer2 ctc HLG decoding" 20 log "test streaming zipformer2 ctc HLG decoding"
12 21
13 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 22 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
@@ -17,7 +17,6 @@ fi @@ -17,7 +17,6 @@ fi
17 if [ ! -f $onnxruntime_dir/onnxruntime.xcframework/ios-arm64/onnxruntime.a ]; then 17 if [ ! -f $onnxruntime_dir/onnxruntime.xcframework/ios-arm64/onnxruntime.a ]; then
18 mkdir -p $onnxruntime_dir 18 mkdir -p $onnxruntime_dir
19 pushd $onnxruntime_dir 19 pushd $onnxruntime_dir
20 -# rm -f onnxruntime.xcframework-${onnxruntime_version}.tar.bz2  
21 wget -c https://${SHERPA_ONNX_GITHUB}/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime.xcframework-${onnxruntime_version}.tar.bz2 20 wget -c https://${SHERPA_ONNX_GITHUB}/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
22 tar xvf onnxruntime.xcframework-${onnxruntime_version}.tar.bz2 21 tar xvf onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
23 rm onnxruntime.xcframework-${onnxruntime_version}.tar.bz2 22 rm onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
@@ -3,7 +3,6 @@ function(download_kaldi_native_fbank) @@ -3,7 +3,6 @@ function(download_kaldi_native_fbank)
3 3
4 set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz") 4 set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
5 set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz") 5 set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
6 -# set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.19.1.tar.gz")  
7 set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904") 6 set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904")
8 7
9 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) 8 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This script shows how to use audio tagging Python APIs to tag a file.
  5 +
  6 +Please read the code to download the required model files and test wave file.
  7 +"""
  8 +
  9 +import logging
  10 +import time
  11 +from pathlib import Path
  12 +
  13 +import numpy as np
  14 +import sherpa_onnx
  15 +import soundfile as sf
  16 +
  17 +
def read_test_wave():
    """Load the bundled test wave as mono float32 samples.

    Returns:
      A tuple ``(samples, sample_rate)`` where ``samples`` is a contiguous
      1-D float32 numpy array holding the first channel of the file and
      ``sample_rate`` is a scalar.
    """
    # Please download the model files and test wave files from
    # https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
    test_wave = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav"

    if not Path(test_wave).is_file():
        raise ValueError(
            f"Please download {test_wave} from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
        )

    # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read
    audio, sample_rate = sf.read(test_wave, always_2d=True, dtype="float32")

    # Keep only the first channel and make the buffer contiguous in memory.
    samples = np.ascontiguousarray(audio[:, 0])

    return samples, sample_rate
  41 +
  42 +
def create_audio_tagger():
    """Build and return a configured ``sherpa_onnx.AudioTagging`` instance.

    Raises:
      ValueError: if a required model/label file is missing or if the
        resulting configuration fails validation.
    """
    # Please download the model files and test wave files from
    # https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
    model_file = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx"
    label_file = (
        "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv"
    )

    # Both files must be present before we can construct the tagger.
    for required_file in (model_file, label_file):
        if not Path(required_file).is_file():
            raise ValueError(
                f"Please download {required_file} from "
                "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
            )

    zipformer_config = sherpa_onnx.OfflineZipformerAudioTaggingModelConfig(
        model=model_file,
    )
    model_config = sherpa_onnx.AudioTaggingModelConfig(
        zipformer=zipformer_config,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    config = sherpa_onnx.AudioTaggingConfig(
        model=model_config,
        labels=label_file,
        top_k=5,
    )

    if not config.validate():
        raise ValueError(f"Please check the config: {config}")

    print(config)

    return sherpa_onnx.AudioTagging(config)
  81 +
  82 +
def main():
    """Tag the bundled test wave and log the timing plus detected events."""
    logging.info("Create audio tagger")
    audio_tagger = create_audio_tagger()

    logging.info("Read test wave")
    samples, sample_rate = read_test_wave()

    logging.info("Computing")

    begin = time.time()

    stream = audio_tagger.create_stream()
    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
    result = audio_tagger.compute(stream)

    elapsed_seconds = time.time() - begin
    audio_duration = len(samples) / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
    logging.info(
        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
    )

    # Render the detected events, one per line, prefixed by their rank.
    report = "\n" + "".join(f"{i}: {e}\n" for i, e in enumerate(result))

    logging.info(report)
  114 +
  115 +
if __name__ == "__main__":
    # Log with timestamp, level, and source location for easy debugging.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    main()
1 include_directories(${CMAKE_SOURCE_DIR}) 1 include_directories(${CMAKE_SOURCE_DIR})
2 2
3 set(srcs 3 set(srcs
  4 + audio-tagging.cc
4 circular-buffer.cc 5 circular-buffer.cc
5 display.cc 6 display.cc
6 endpoint.cc 7 endpoint.cc
  1 +// sherpa-onnx/python/csrc/audio-tagging.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/audio-tagging.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/audio-tagging.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +static void PybindOfflineZipformerAudioTaggingModelConfig(py::module *m) {
  14 + using PyClass = OfflineZipformerAudioTaggingModelConfig;
  15 + py::class_<PyClass>(*m, "OfflineZipformerAudioTaggingModelConfig")
  16 + .def(py::init<>())
  17 + .def(py::init<const std::string &>(), py::arg("model"))
  18 + .def_readwrite("model", &PyClass::model)
  19 + .def("validate", &PyClass::Validate)
  20 + .def("__str__", &PyClass::ToString);
  21 +}
  22 +
  23 +static void PybindAudioTaggingModelConfig(py::module *m) {
  24 + PybindOfflineZipformerAudioTaggingModelConfig(m);
  25 +
  26 + using PyClass = AudioTaggingModelConfig;
  27 +
  28 + py::class_<PyClass>(*m, "AudioTaggingModelConfig")
  29 + .def(py::init<>())
  30 + .def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t,
  31 + bool, const std::string &>(),
  32 + py::arg("zipformer"), py::arg("num_threads") = 1,
  33 + py::arg("debug") = false, py::arg("provider") = "cpu")
  34 + .def_readwrite("zipformer", &PyClass::zipformer)
  35 + .def_readwrite("num_threads", &PyClass::num_threads)
  36 + .def_readwrite("debug", &PyClass::debug)
  37 + .def_readwrite("provider", &PyClass::provider)
  38 + .def("validate", &PyClass::Validate)
  39 + .def("__str__", &PyClass::ToString);
  40 +}
  41 +
  42 +static void PybindAudioTaggingConfig(py::module *m) {
  43 + PybindAudioTaggingModelConfig(m);
  44 +
  45 + using PyClass = AudioTaggingConfig;
  46 +
  47 + py::class_<PyClass>(*m, "AudioTaggingConfig")
  48 + .def(py::init<>())
  49 + .def(py::init<const AudioTaggingModelConfig &, const std::string &,
  50 + int32_t>(),
  51 + py::arg("model"), py::arg("labels"), py::arg("top_k") = 5)
  52 + .def_readwrite("model", &PyClass::model)
  53 + .def_readwrite("labels", &PyClass::labels)
  54 + .def_readwrite("top_k", &PyClass::top_k)
  55 + .def("validate", &PyClass::Validate)
  56 + .def("__str__", &PyClass::ToString);
  57 +}
  58 +
  59 +static void PybindAudioEvent(py::module *m) {
  60 + using PyClass = AudioEvent;
  61 +
  62 + py::class_<PyClass>(*m, "AudioEvent")
  63 + .def_property_readonly(
  64 + "name", [](const PyClass &self) -> std::string { return self.name; })
  65 + .def_property_readonly(
  66 + "index", [](const PyClass &self) -> int32_t { return self.index; })
  67 + .def_property_readonly(
  68 + "prob", [](const PyClass &self) -> float { return self.prob; })
  69 + .def("__str__", &PyClass::ToString);
  70 +}
  71 +
  72 +void PybindAudioTagging(py::module *m) {
  73 + PybindAudioTaggingConfig(m);
  74 + PybindAudioEvent(m);
  75 +
  76 + using PyClass = AudioTagging;
  77 +
  78 + py::class_<PyClass>(*m, "AudioTagging")
  79 + .def(py::init<const AudioTaggingConfig &>(), py::arg("config"),
  80 + py::call_guard<py::gil_scoped_release>())
  81 + .def("create_stream", &PyClass::CreateStream,
  82 + py::call_guard<py::gil_scoped_release>())
  83 + .def("compute", &PyClass::Compute, py::arg("s"), py::arg("top_k") = -1,
  84 + py::call_guard<py::gil_scoped_release>());
  85 +}
  86 +
  87 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/audio-tagging.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindAudioTagging(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
@@ -16,7 +16,7 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) { @@ -16,7 +16,7 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) {
16 py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig") 16 py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
17 .def(py::init<>()) 17 .def(py::init<>())
18 .def(py::init<const std::string &, const std::string &, 18 .def(py::init<const std::string &, const std::string &,
19 - const std::string &, const std::string, float, float, 19 + const std::string &, const std::string &, float, float,
20 float>(), 20 float>(),
21 py::arg("model"), py::arg("lexicon"), py::arg("tokens"), 21 py::arg("model"), py::arg("lexicon"), py::arg("tokens"),
22 py::arg("data_dir") = "", py::arg("noise_scale") = 0.667, 22 py::arg("data_dir") = "", py::arg("noise_scale") = 0.667,
@@ -5,6 +5,7 @@ @@ -5,6 +5,7 @@
5 #include "sherpa-onnx/python/csrc/sherpa-onnx.h" 5 #include "sherpa-onnx/python/csrc/sherpa-onnx.h"
6 6
7 #include "sherpa-onnx/python/csrc/alsa.h" 7 #include "sherpa-onnx/python/csrc/alsa.h"
  8 +#include "sherpa-onnx/python/csrc/audio-tagging.h"
8 #include "sherpa-onnx/python/csrc/circular-buffer.h" 9 #include "sherpa-onnx/python/csrc/circular-buffer.h"
9 #include "sherpa-onnx/python/csrc/display.h" 10 #include "sherpa-onnx/python/csrc/display.h"
10 #include "sherpa-onnx/python/csrc/endpoint.h" 11 #include "sherpa-onnx/python/csrc/endpoint.h"
@@ -38,6 +39,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -38,6 +39,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
38 m.doc() = "pybind11 binding of sherpa-onnx"; 39 m.doc() = "pybind11 binding of sherpa-onnx";
39 40
40 PybindWaveWriter(&m); 41 PybindWaveWriter(&m);
  42 + PybindAudioTagging(&m);
41 43
42 PybindFeatures(&m); 44 PybindFeatures(&m);
43 PybindOnlineCtcFstDecoderConfig(&m); 45 PybindOnlineCtcFstDecoderConfig(&m);
@@ -14,7 +14,7 @@ static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) { @@ -14,7 +14,7 @@ static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) {
14 using PyClass = SpeakerEmbeddingExtractorConfig; 14 using PyClass = SpeakerEmbeddingExtractorConfig;
15 py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig") 15 py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
16 .def(py::init<>()) 16 .def(py::init<>())
17 - .def(py::init<const std::string &, int32_t, bool, const std::string>(), 17 + .def(py::init<const std::string &, int32_t, bool, const std::string &>(),
18 py::arg("model"), py::arg("num_threads") = 1, 18 py::arg("model"), py::arg("num_threads") = 1,
19 py::arg("debug") = false, py::arg("provider") = "cpu") 19 py::arg("debug") = false, py::arg("provider") = "cpu")
20 .def_readwrite("model", &PyClass::model) 20 .def_readwrite("model", &PyClass::model)
@@ -33,7 +33,7 @@ static void PybindSpokenLanguageIdentificationConfig(py::module *m) { @@ -33,7 +33,7 @@ static void PybindSpokenLanguageIdentificationConfig(py::module *m) {
33 py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig") 33 py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
34 .def(py::init<>()) 34 .def(py::init<>())
35 .def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t, 35 .def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
36 - bool, const std::string>(), 36 + bool, const std::string &>(),
37 py::arg("whisper"), py::arg("num_threads") = 1, 37 py::arg("whisper"), py::arg("num_threads") = 1,
38 py::arg("debug") = false, py::arg("provider") = "cpu") 38 py::arg("debug") = false, py::arg("provider") = "cpu")
39 .def_readwrite("whisper", &PyClass::whisper) 39 .def_readwrite("whisper", &PyClass::whisper)
@@ -53,7 +53,7 @@ void PybindSpokenLanguageIdentification(py::module *m) { @@ -53,7 +53,7 @@ void PybindSpokenLanguageIdentification(py::module *m) {
53 py::arg("config"), py::call_guard<py::gil_scoped_release>()) 53 py::arg("config"), py::call_guard<py::gil_scoped_release>())
54 .def("create_stream", &PyClass::CreateStream, 54 .def("create_stream", &PyClass::CreateStream,
55 py::call_guard<py::gil_scoped_release>()) 55 py::call_guard<py::gil_scoped_release>())
56 - .def("compute", &PyClass::Compute, 56 + .def("compute", &PyClass::Compute, py::arg("s"),
57 py::call_guard<py::gil_scoped_release>()); 57 py::call_guard<py::gil_scoped_release>());
58 } 58 }
59 59
1 from _sherpa_onnx import ( 1 from _sherpa_onnx import (
2 Alsa, 2 Alsa,
  3 + AudioEvent,
  4 + AudioTagging,
  5 + AudioTaggingConfig,
  6 + AudioTaggingModelConfig,
3 CircularBuffer, 7 CircularBuffer,
4 Display, 8 Display,
5 OfflineStream, 9 OfflineStream,
@@ -7,6 +11,7 @@ from _sherpa_onnx import ( @@ -7,6 +11,7 @@ from _sherpa_onnx import (
7 OfflineTtsConfig, 11 OfflineTtsConfig,
8 OfflineTtsModelConfig, 12 OfflineTtsModelConfig,
9 OfflineTtsVitsModelConfig, 13 OfflineTtsVitsModelConfig,
  14 + OfflineZipformerAudioTaggingModelConfig,
10 OnlineStream, 15 OnlineStream,
11 SileroVadModelConfig, 16 SileroVadModelConfig,
12 SpeakerEmbeddingExtractor, 17 SpeakerEmbeddingExtractor,