Fangjun Kuang
Committed by GitHub

Add Python API example for CED audio tagging. (#793)

  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This script shows how to use audio tagging Python APIs to tag a file.
  5 +
  6 +Please read the code to download the required model files and test wave file.
  7 +"""
  8 +
  9 +import logging
  10 +import time
  11 +from pathlib import Path
  12 +
  13 +import numpy as np
  14 +import sherpa_onnx
  15 +import soundfile as sf
  16 +
  17 +
def read_test_wave():
    """Load the example test recording as single-channel float32 samples.

    The model files and test wave files can be downloaded from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models

    Returns:
      A tuple ``(samples, sample_rate)`` where ``samples`` is a contiguous
      1-d array of dtype float32 holding the first channel of the file and
      ``sample_rate`` is the scalar sample rate returned by ``sf.read``.

    Raises:
      ValueError: If the test wave file does not exist.
    """
    # Kept as a plain string (not a Path) so the error message shows the
    # path exactly as written here.
    wave_file = "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/test_wavs/6.wav"
    download_url = (
        "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
    )

    if not Path(wave_file).is_file():
        raise ValueError(f"Please download {wave_file} from {download_url}")

    # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read
    # always_2d=True guarantees a 2-d result, so column 0 can always be taken
    # regardless of how many channels the file has.
    audio, sample_rate = sf.read(wave_file, always_2d=True, dtype="float32")

    # Use only the first channel; the column slice is a strided view, so
    # make it contiguous before returning it.
    first_channel = np.ascontiguousarray(audio[:, 0])

    return first_channel, sample_rate
  41 +
  42 +
def create_audio_tagger():
    """Build an audio tagger from the CED mini model and its label file.

    The model files and test wave files can be downloaded from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models

    Returns:
      A ``sherpa_onnx.AudioTagging`` instance constructed from the config.

    Raises:
      ValueError: If the model file or the label file does not exist, or
        if ``config.validate()`` returns a false value.
    """
    model_dir = "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19"
    model_file = f"{model_dir}/model.int8.onnx"
    label_file = f"{model_dir}/class_labels_indices.csv"

    # Check the model first and the labels second, so the first missing
    # file is the one reported.
    for required_file in (model_file, label_file):
        if not Path(required_file).is_file():
            raise ValueError(
                f"Please download {required_file} from "
                "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
            )

    model_config = sherpa_onnx.AudioTaggingModelConfig(
        ced=model_file,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    config = sherpa_onnx.AudioTaggingConfig(
        model=model_config,
        labels=label_file,
        top_k=5,
    )

    if not config.validate():
        raise ValueError(f"Please check the config: {config}")

    print(config)

    return sherpa_onnx.AudioTagging(config)
  79 +
  80 +
def main():
    """Tag the test wave and log timing statistics and the tagging result.

    Creates the tagger, reads the test wave, runs one compute pass over it,
    then logs the elapsed time, the audio duration, the real-time factor
    (elapsed / duration) and every entry of the result, one per line.
    """
    logging.info("Create audio tagger")
    tagger = create_audio_tagger()

    logging.info("Read test wave")
    samples, sample_rate = read_test_wave()

    logging.info("Computing")

    # The timed region covers stream creation, feeding the waveform and the
    # compute call.
    started_at = time.time()
    stream = tagger.create_stream()
    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
    events = tagger.compute(stream)
    finished_at = time.time()

    elapsed_seconds = finished_at - started_at
    audio_duration = len(samples) / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
    logging.info(
        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
    )

    # A leading newline puts the listing on its own lines after the log
    # prefix; each entry is "<index>: <event>" followed by a newline.
    listing = "\n" + "".join(f"{i}: {e}\n" for i, e in enumerate(events))
    logging.info(listing)
  112 +
  113 +
if __name__ == "__main__":
    # Configure logging before main() runs: every record shows timestamp,
    # level and source location, and INFO is enabled so the logging.info
    # calls in main() are actually emitted.
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(level=logging.INFO, format=formatter)
    main()
@@ -29,9 +29,9 @@ static void PybindAudioTaggingModelConfig(py::module *m) { @@ -29,9 +29,9 @@ static void PybindAudioTaggingModelConfig(py::module *m) {
29 .def(py::init<>()) 29 .def(py::init<>())
30 .def(py::init<const OfflineZipformerAudioTaggingModelConfig &, 30 .def(py::init<const OfflineZipformerAudioTaggingModelConfig &,
31 const std::string &, int32_t, bool, const std::string &>(), 31 const std::string &, int32_t, bool, const std::string &>(),
32 - py::arg("zipformer"), py::arg("ced") = "",  
33 - py::arg("num_threads") = 1, py::arg("debug") = false,  
34 - py::arg("provider") = "cpu") 32 + py::arg("zipformer") = OfflineZipformerAudioTaggingModelConfig{},
  33 + py::arg("ced") = "", py::arg("num_threads") = 1,
  34 + py::arg("debug") = false, py::arg("provider") = "cpu")
35 .def_readwrite("zipformer", &PyClass::zipformer) 35 .def_readwrite("zipformer", &PyClass::zipformer)
36 .def_readwrite("num_threads", &PyClass::num_threads) 36 .def_readwrite("num_threads", &PyClass::num_threads)
37 .def_readwrite("debug", &PyClass::debug) 37 .def_readwrite("debug", &PyClass::debug)