audio-tagging-from-a-file.py 3.4 KB
#!/usr/bin/env python3

"""
This script shows how to use audio tagging Python APIs to tag a file.

Please read the code to download the required model files and test wave file.
"""

import logging
import time
from pathlib import Path

import numpy as np
import sherpa_onnx
import soundfile as sf


def read_test_wave():
    """Load the test recording and return its first channel.

    The wave file is part of the model archive published at
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
    and is expected at a fixed path relative to the current directory.

    Returns:
      A tuple ``(samples, sample_rate)``. ``samples`` is a contiguous 1-D
      numpy array read with dtype float32 that holds only the first channel
      of the file; ``sample_rate`` is the rate reported by ``soundfile``.

    Raises:
      ValueError: If the test wave file does not exist.
    """
    download_url = (
        "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
    )
    wave_file = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav"

    if not Path(wave_file).is_file():
        raise ValueError(f"Please download {wave_file} from {download_url}")

    # always_2d=True gives a 2-D array even for mono input, so taking
    # column 0 works regardless of the number of channels.
    # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read
    frames, sample_rate = sf.read(wave_file, always_2d=True, dtype="float32")

    # The column slice is a strided view; make it contiguous before returning.
    first_channel = np.ascontiguousarray(frames[:, 0])
    return first_channel, sample_rate


def create_audio_tagger():
    """Build a ``sherpa_onnx.AudioTagging`` object from the zipformer model.

    The model and label files must first be downloaded from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
    and are expected at fixed paths relative to the current directory.

    The tagger is configured to run on the ``cpu`` provider with one thread,
    with ``debug`` enabled and ``top_k=5``. The config is printed to stdout
    before the tagger is constructed.

    Returns:
      The constructed ``sherpa_onnx.AudioTagging`` instance.

    Raises:
      ValueError: If the model file or the label file is missing, or if
        the assembled config fails ``validate()``.
    """
    download_url = (
        "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
    )
    model_dir = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09"
    model_file = f"{model_dir}/model.onnx"
    label_file = f"{model_dir}/class_labels_indices.csv"

    # Check the model first, then the labels, so the first missing file is
    # the one reported.
    for required_file in (model_file, label_file):
        if not Path(required_file).is_file():
            raise ValueError(f"Please download {required_file} from {download_url}")

    zipformer_config = sherpa_onnx.OfflineZipformerAudioTaggingModelConfig(
        model=model_file,
    )
    model_config = sherpa_onnx.AudioTaggingModelConfig(
        zipformer=zipformer_config,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    config = sherpa_onnx.AudioTaggingConfig(
        model=model_config,
        labels=label_file,
        top_k=5,
    )

    if not config.validate():
        raise ValueError(f"Please check the config: {config}")

    print(config)

    return sherpa_onnx.AudioTagging(config)


def main():
    """Tag the test wave file and log the results together with timing.

    Creates the audio tagger, reads the test wave, runs tagging on it, and
    then logs the elapsed time, the audio duration, the real-time factor
    (elapsed seconds divided by audio duration), and one line per entry of
    the tagging result.

    Raises:
      ValueError: Propagated from ``create_audio_tagger()`` or
        ``read_test_wave()`` when a required file is missing.
    """
    logging.info("Create audio tagger")
    audio_tagger = create_audio_tagger()

    logging.info("Read test wave")
    samples, sample_rate = read_test_wave()

    logging.info("Computing")

    # perf_counter() is a monotonic, high-resolution clock meant for
    # measuring intervals; time.time() follows the wall clock and can jump
    # if the system time is adjusted while the computation runs.
    start_time = time.perf_counter()

    stream = audio_tagger.create_stream()
    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
    result = audio_tagger.compute(stream)

    elapsed_seconds = time.perf_counter() - start_time
    audio_duration = len(samples) / sample_rate

    real_time_factor = elapsed_seconds / audio_duration
    logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
    logging.info(
        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
    )

    # Start with a newline so the list begins on its own line after the
    # logging prefix; each entry is "<index>: <event>" on a separate line.
    s = "\n" + "".join(f"{i}: {e}\n" for i, e in enumerate(result))

    logging.info(s)


if __name__ == "__main__":
    # Configure the root logger at INFO level so the messages logged by
    # main() are shown; each line carries a timestamp, the level, and the
    # source file and line number.
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    main()