audio-tagging-from-a-file-ced.py
3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python3
"""
This script shows how to use the sherpa-onnx audio tagging Python API to tag
a wave file with a CED model. See the comments in the code below for where to
download the required model files and the test wave file.
"""
import logging
import time
from pathlib import Path
import numpy as np
import sherpa_onnx
import soundfile as sf
def read_test_wave():
    """Load the test recording and return its first channel.

    The wave file ships with the audio tagging models released at
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models

    Returns:
      A tuple ``(samples, sample_rate)`` where ``samples`` is a contiguous
      1-D float32 array holding the first channel of the file and
      ``sample_rate`` is the scalar rate returned by ``soundfile.read``.

    Raises:
      ValueError: If the test wave file is not present on disk.
    """
    test_wave = "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/test_wavs/6.wav"

    wave_is_present = Path(test_wave).is_file()
    if not wave_is_present:
        download_hint = (
            f"Please download {test_wave} from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
        )
        raise ValueError(download_hint)

    # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read
    # always_2d=True keeps the result two-dimensional even for mono input,
    # so the channel slice below works for any channel count.
    audio, sample_rate = sf.read(test_wave, dtype="float32", always_2d=True)

    # Keep only the first channel; the column slice may not be contiguous
    # in memory, so request a contiguous 1-D array.
    first_channel = np.ascontiguousarray(audio[:, 0])

    return first_channel, sample_rate
def create_audio_tagger():
    """Build a ``sherpa_onnx.AudioTagging`` object for the CED mini model.

    The model and label files come from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models

    The configuration is printed to stdout once it has been validated.

    Returns:
      A ``sherpa_onnx.AudioTagging`` instance created from the configuration.

    Raises:
      ValueError: If the model file or the label file is missing, or if
        the assembled configuration does not validate.
    """
    model_file = "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/model.int8.onnx"
    label_file = (
        "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/class_labels_indices.csv"
    )

    # The model file is checked first, then the labels file; the first
    # missing one is reported.
    for required_file in (model_file, label_file):
        if Path(required_file).is_file():
            continue
        raise ValueError(
            f"Please download {required_file} from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
        )

    model_config = sherpa_onnx.AudioTaggingModelConfig(
        ced=model_file,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    config = sherpa_onnx.AudioTaggingConfig(
        model=model_config,
        labels=label_file,
        top_k=5,
    )

    if config.validate():
        print(config)
        return sherpa_onnx.AudioTagging(config)

    raise ValueError(f"Please check the config: {config}")
def main():
    """Tag the test wave file and log the results together with timing.

    Creates the audio tagger, reads the test wave, runs tagging on a single
    stream, and logs the elapsed time, the audio duration, the real-time
    factor (elapsed / duration) and one line per returned result entry.
    """
    logging.info("Create audio tagger")
    audio_tagger = create_audio_tagger()

    logging.info("Read test wave")
    samples, sample_rate = read_test_wave()

    logging.info("Computing")
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring short intervals; time.time() can jump if the system clock
    # is adjusted while the computation runs.
    start_time = time.perf_counter()

    stream = audio_tagger.create_stream()
    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
    result = audio_tagger.compute(stream)

    elapsed_seconds = time.perf_counter() - start_time

    audio_duration = len(samples) / sample_rate
    # An empty waveform has zero duration; report an infinite RTF instead
    # of raising ZeroDivisionError after the computation already finished.
    if audio_duration > 0:
        real_time_factor = elapsed_seconds / audio_duration
    else:
        real_time_factor = float("inf")

    logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
    logging.info(
        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
    )

    # Leading newline so the numbered entries start on their own line
    # after the log prefix.
    s = "\n" + "".join(f"{i}: {e}\n" for i, e in enumerate(result))
    logging.info(s)
if __name__ == "__main__":
    # Configure root logging before running so every message carries a
    # timestamp, its level, and the file/line that emitted it.
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    main()