test_audio_tagging.kt
1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
package com.k2fsa.sherpa.onnx
/** Program entry point: runs the audio-tagging example. */
fun main() = testAudioTagging()
/**
 * Tags each bundled test wave with the zipformer audio-tagging model and prints
 * the file name followed by the events computed for it.
 *
 * All model, label, and wave paths are relative to the current working
 * directory, so the `sherpa-onnx-zipformer-audio-tagging-2024-04-09` directory
 * must be present there.
 *
 * The tagger and every per-file stream are released in `finally` blocks so that
 * a failure while reading a wave file or computing events does not skip the
 * `release()` calls.
 */
fun testAudioTagging() {
    // Single source of truth for the model directory used by every path below.
    val modelDir = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09"

    val config = AudioTaggingConfig(
        model = AudioTaggingModelConfig(
            zipformer = OfflineZipformerAudioTaggingModelConfig(
                model = "$modelDir/model.int8.onnx",
            ),
            numThreads = 1,
            debug = true,
            provider = "cpu",
        ),
        labels = "$modelDir/class_labels_indices.csv",
        topK = 5,
    )

    val tagger = AudioTagging(config = config)
    try {
        // Expands to test_wavs/1.wav .. test_wavs/4.wav, in that order.
        val testFiles = (1..4).map { index -> "$modelDir/test_wavs/$index.wav" }

        println("----------")
        for (waveFilename in testFiles) {
            val stream = tagger.createStream()
            try {
                // readWaveFromFile returns an untyped array: [0] holds the
                // samples, [1] holds the sample rate.
                val objArray = WaveReader.readWaveFromFile(
                    filename = waveFilename,
                )
                val samples: FloatArray = objArray[0] as FloatArray
                val sampleRate: Int = objArray[1] as Int

                stream.acceptWaveform(samples, sampleRate = sampleRate)
                val events = tagger.compute(stream)

                println(waveFilename)
                println(events)
                println("----------")
            } finally {
                // Release the stream even when reading or tagging throws.
                stream.release()
            }
        }
    } finally {
        // Release the tagger on every exit path, including failures above.
        tagger.release()
    }
}