sherpa-onnx.cc
5.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
// sherpa-onnx/csrc/sherpa-onnx.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include <stdio.h>

#include <chrono> // NOLINT
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/wave-reader.h"
// Per-file decoding state used by main().
struct Stream {
  // Online stream that was fed with this file's samples (plus tail padding).
  std::unique_ptr<sherpa_onnx::OnlineStream> online_stream;

  // Audio duration in seconds (number of samples / sampling rate), computed
  // before the tail padding is appended.
  float duration;

  // Wall-clock seconds, measured from the start time taken in main(), at
  // which the recognizer first reported this stream as no longer ready.
  // Stays 0 until then.
  float elapsed_seconds;
};
int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Usage:
(1) Streaming transducer
./bin/sherpa-onnx \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--provider=cpu \
--num-threads=2 \
--decoding-method=greedy_search \
/path/to/foo.wav [bar.wav foobar.wav ...]
(2) Streaming zipformer2 CTC
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
./bin/sherpa-onnx \
--debug=1 \
--zipformer2-ctc-model=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
--tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav
(3) Streaming paraformer
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
./bin/sherpa-onnx \
--tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
--paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.onnx \
--paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.onnx \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav
Note: It supports decoding multiple files in batches
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
Valid values for provider: cpu (default), cuda, coreml.
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download.
)usage";
sherpa_onnx::ParseOptions po(kUsageMessage);
sherpa_onnx::OnlineRecognizerConfig config;
config.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() < 1) {
po.PrintUsage();
exit(EXIT_FAILURE);
}
fprintf(stderr, "%s\n", config.ToString().c_str());
if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
return -1;
}
sherpa_onnx::OnlineRecognizer recognizer(config);
std::vector<Stream> ss;
const auto begin = std::chrono::steady_clock::now();
std::vector<float> durations;
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
const std::string wav_filename = po.GetArg(i);
int32_t sampling_rate = -1;
bool is_ok = false;
const std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str());
return -1;
}
const float duration = samples.size() / static_cast<float>(sampling_rate);
auto s = recognizer.CreateStream();
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
// Note: We can call AcceptWaveform() multiple times.
s->AcceptWaveform(sampling_rate, tail_paddings.data(),
tail_paddings.size());
// Call InputFinished() to indicate that no audio samples are available
s->InputFinished();
ss.push_back({std::move(s), duration, 0});
}
std::vector<sherpa_onnx::OnlineStream *> ready_streams;
for (;;) {
ready_streams.clear();
for (auto &s : ss) {
const auto p_ss = s.online_stream.get();
if (recognizer.IsReady(p_ss)) {
ready_streams.push_back(p_ss);
} else if (s.elapsed_seconds == 0) {
const auto end = std::chrono::steady_clock::now();
const float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
s.elapsed_seconds = elapsed_seconds;
}
}
if (ready_streams.empty()) {
break;
}
recognizer.DecodeStreams(ready_streams.data(), ready_streams.size());
}
std::ostringstream os;
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
const auto &s = ss[i - 1];
const float rtf = s.elapsed_seconds / s.duration;
os << po.GetArg(i) << "\n";
os << "Number of threads: " << config.model_config.num_threads << ", "
<< std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds
<< ", Audio duration (s): " << s.duration
<< ", Real time factor (RTF) = " << s.elapsed_seconds << "/"
<< s.duration << " = " << rtf << "\n";
const auto r = recognizer.GetResult(s.online_stream.get());
os << r.text << "\n";
os << r.AsJsonString() << "\n\n";
}
std::cerr << os.str();
return 0;
}