test_online_asr.kt
4.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package com.k2fsa.sherpa.onnx
/**
 * Entry point: runs the streaming-recognition demo once per supported model type,
 * in a fixed order.
 */
fun main() {
    val modelTypes = listOf(
        "transducer",
        "zipformer2-ctc",
        "ctc-hlg",
        "nemo-ctc",
        "tone-ctc",
    )
    modelTypes.forEach { modelType -> testOnlineAsr(modelType) }
}
/**
 * Runs streaming (online) speech recognition on a test wave file using one of the
 * pre-trained sherpa-onnx models, then prints the recognized text.
 *
 * All model and wave paths are relative to the current working directory. The matching
 * model directory must already be downloaded and extracted there. Please refer to
 * https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
 * to download pre-trained models.
 *
 * The native stream and recognizer are always released, even if reading the wave file
 * or decoding throws.
 *
 * @param type which model to test: "transducer", "zipformer2-ctc", "ctc-hlg",
 *   "nemo-ctc" or "tone-ctc".
 * @throws IllegalArgumentException if [type] is not one of the supported values.
 */
fun testOnlineAsr(type: String) {
    val featConfig = FeatureConfig(
        sampleRate = 16000,
        featureDim = 80,
    )

    // Only the "ctc-hlg" branch sets a decoding graph on this config. Every other model
    // type uses the default-constructed instance. The reference is never reassigned.
    val ctcFstDecoderConfig = OnlineCtcFstDecoderConfig()
    val waveFilename: String
    val modelConfig: OnlineModelConfig = when (type) {
        "transducer" -> {
            waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav"
            OnlineModelConfig(
                transducer = OnlineTransducerModelConfig(
                    encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
                    decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
                    joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
                ),
                tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
                numThreads = 1,
                debug = false,
            )
        }
        "zipformer2-ctc" -> {
            waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"
            OnlineModelConfig(
                zipformer2Ctc = OnlineZipformer2CtcModelConfig(
                    model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx",
                ),
                tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt",
                numThreads = 1,
                debug = false,
            )
        }
        "nemo-ctc" -> {
            waveFilename = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/test_wavs/0.wav"
            OnlineModelConfig(
                neMoCtc = OnlineNeMoCtcModelConfig(
                    model = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/model.onnx",
                ),
                tokens = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt",
                numThreads = 1,
                debug = false,
            )
        }
        "tone-ctc" -> {
            waveFilename = "./sherpa-onnx-streaming-t-one-russian-2025-09-08/0.wav"
            OnlineModelConfig(
                toneCtc = OnlineToneCtcModelConfig(
                    model = "./sherpa-onnx-streaming-t-one-russian-2025-09-08/model.onnx",
                ),
                tokens = "./sherpa-onnx-streaming-t-one-russian-2025-09-08/tokens.txt",
                numThreads = 1,
                debug = false,
            )
        }
        "ctc-hlg" -> {
            waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/1.wav"
            // Same zipformer2 CTC model family, but decoded with an HLG graph via the FST decoder.
            ctcFstDecoderConfig.graph = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst"
            OnlineModelConfig(
                zipformer2Ctc = OnlineZipformer2CtcModelConfig(
                    model = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx",
                ),
                tokens = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt",
                numThreads = 1,
                debug = false,
            )
        }
        else -> throw IllegalArgumentException("Unsupported model type: $type")
    }

    val endpointConfig = EndpointConfig()
    val lmConfig = OnlineLMConfig()
    val config = OnlineRecognizerConfig(
        modelConfig = modelConfig,
        lmConfig = lmConfig,
        featConfig = featConfig,
        ctcFstDecoderConfig = ctcFstDecoderConfig,
        endpointConfig = endpointConfig,
        enableEndpoint = true,
        decodingMethod = "greedy_search",
        maxActivePaths = 4,
    )

    val recognizer = OnlineRecognizer(
        config = config,
    )
    try {
        val objArray = WaveReader.readWaveFromFile(
            filename = waveFilename,
        )
        // readWaveFromFile returns an untyped array: [0] = samples, [1] = sample rate.
        val samples: FloatArray = objArray[0] as FloatArray
        val sampleRate: Int = objArray[1] as Int

        val stream = recognizer.createStream()
        try {
            // Decodes as long as the recognizer reports the stream is ready.
            fun decodeWhileReady() {
                while (recognizer.isReady(stream)) {
                    recognizer.decode(stream)
                }
            }

            // 0.3 seconds of leading silence (a new FloatArray is all zeros).
            val leftPaddings = FloatArray((sampleRate * 0.3).toInt())
            stream.acceptWaveform(leftPaddings, sampleRate = sampleRate)
            stream.acceptWaveform(samples, sampleRate = sampleRate)
            decodeWhileReady()

            // 0.6 seconds of trailing silence before signalling end of input. Presumably
            // this lets the streaming model see enough frames to emit the last words.
            // TODO confirm the minimum padding needed.
            val tailPaddings = FloatArray((sampleRate * 0.6).toInt())
            stream.acceptWaveform(tailPaddings, sampleRate = sampleRate)
            stream.inputFinished()
            decodeWhileReady()

            println("results: ${recognizer.getResult(stream).text}")
        } finally {
            // Released before the recognizer, matching the original release order.
            stream.release()
        }
    } finally {
        recognizer.release()
    }
}