Jingzhao Ou

add batch processing to sherpa-onnx (#166)
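In short: previously each wave file's stream was decoded to completion with DecodeStream() before the next file was processed. With this change the example keeps one Stream record per input file and repeatedly calls OnlineRecognizer::DecodeStreams() on every stream that still has frames to process, so the files are decoded together in batches while per-file timing is still recorded for the elapsed time and real time factor (RTF) report.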

@@ -5,6 +5,8 @@
 #include <stdio.h>
 #include <chrono>  // NOLINT
+#include <iomanip>
+#include <iostream>
 #include <string>
 #include <vector>
@@ -14,6 +16,12 @@
 #include "sherpa-onnx/csrc/parse-options.h"
 #include "sherpa-onnx/csrc/wave-reader.h"

+typedef struct {
+  std::unique_ptr<sherpa_onnx::OnlineStream> online_stream;
+  float duration;
+  float elapsed_seconds;
+} Stream;
+
 int main(int32_t argc, char *argv[]) {
   const char *kUsageMessage = R"usage(
 Usage:
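The new Stream struct pairs each sherpa_onnx::OnlineStream with the duration of its wave file and the wall-clock time at which that stream finished decoding. This is what allows a per-file real time factor to be reported even though decoding is now interleaved across files (see the loop in the last hunk below).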
@@ -61,29 +69,26 @@ for a list of pre-trained models to download.
   sherpa_onnx::OnlineRecognizer recognizer(config);

-  float duration = 0;
+  std::vector<Stream> ss;
+  const auto begin = std::chrono::steady_clock::now();
+  std::vector<float> durations;
   for (int32_t i = 1; i <= po.NumArgs(); ++i) {
     const std::string wav_filename = po.GetArg(i);
     int32_t sampling_rate = -1;
     bool is_ok = false;
     const std::vector<float> samples =
         sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
     if (!is_ok) {
       fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
       return -1;
     }
-    fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);

     const float duration = samples.size() / static_cast<float>(sampling_rate);
-    fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
-    fprintf(stderr, "wav duration (s): %.3f\n", duration);
-    fprintf(stderr, "Started\n");
-    const auto begin = std::chrono::steady_clock::now();

     auto s = recognizer.CreateStream();
     s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
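Note that the start timestamp (begin) has been hoisted out of the per-file loop: all streams are now measured from the same starting point, and each stream's own finish time is recorded later, the first time it is found not ready.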
@@ -94,33 +99,46 @@ for a list of pre-trained models to download.
     // Call InputFinished() to indicate that no audio samples are available
     s->InputFinished();

-    while (recognizer.IsReady(s.get())) {
-      recognizer.DecodeStream(s.get());
-    }
-
-    const std::string text = recognizer.GetResult(s.get()).AsJsonString();
-
-    const auto end = std::chrono::steady_clock::now();
-    const float elapsed_seconds =
-        std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
-            .count() / 1000.;
-
-    fprintf(stderr, "Done!\n");
-    fprintf(stderr,
-            "Recognition result for %s:\n%s\n",
-            wav_filename.c_str(), text.c_str());
-    fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
-    fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
-    if (config.decoding_method == "modified_beam_search") {
-      fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
-    }
-    fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
-    const float rtf = elapsed_seconds / duration;
-    fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
-            elapsed_seconds, duration, rtf);
-  }
+    ss.push_back({ std::move(s), duration, 0 });
+  }
+
+  std::vector<sherpa_onnx::OnlineStream *> ready_streams;
+  for (;;) {
+    ready_streams.clear();
+    for (auto &s : ss) {
+      const auto p_ss = s.online_stream.get();
+      if (recognizer.IsReady(p_ss)) {
+        ready_streams.push_back(p_ss);
+      } else if (s.elapsed_seconds == 0) {
+        const auto end = std::chrono::steady_clock::now();
+        const float elapsed_seconds =
+            std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
+                .count() / 1000.;
+        s.elapsed_seconds = elapsed_seconds;
+      }
+    }
+
+    if (ready_streams.empty()) {
+      break;
+    }
+
+    recognizer.DecodeStreams(ready_streams.data(), ready_streams.size());
+  }
+
+  std::ostringstream os;
+  for (int32_t i = 1; i <= po.NumArgs(); ++i) {
+    const auto &s = ss[i - 1];
+    const float rtf = s.elapsed_seconds / s.duration;
+    os << po.GetArg(i) << "\n";
+    os << std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds
+       << ", Real time factor (RTF): " << rtf << "\n";
+    const auto r = recognizer.GetResult(s.online_stream.get());
+    os << r.text << "\n";
+    os << r.AsJsonString() << "\n\n";
+  }
+
+  std::cerr << os.str();

   return 0;
 }
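Because the new decoding logic is split across the hunks above, here is a condensed restatement of it as a standalone helper. The wrapper function and its name (DecodeInBatches) are only for illustration; the committed code keeps this logic inline in main(), and only sherpa_onnx calls that appear in the diff are used:

static void DecodeInBatches(sherpa_onnx::OnlineRecognizer &recognizer,
                            std::vector<Stream> &ss,
                            std::chrono::steady_clock::time_point begin) {
  std::vector<sherpa_onnx::OnlineStream *> ready_streams;
  for (;;) {
    ready_streams.clear();
    for (auto &s : ss) {
      sherpa_onnx::OnlineStream *p = s.online_stream.get();
      if (recognizer.IsReady(p)) {
        // This stream still has feature frames to decode in this round.
        ready_streams.push_back(p);
      } else if (s.elapsed_seconds == 0) {
        // First round in which this stream has nothing left: record how long
        // it took, measured from the shared start time.
        const auto end = std::chrono::steady_clock::now();
        s.elapsed_seconds =
            std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
                .count() / 1000.;
      }
    }
    if (ready_streams.empty()) {
      break;  // every stream has been fully decoded
    }
    // Decode all ready streams together in one batched call.
    recognizer.DecodeStreams(ready_streams.data(), ready_streams.size());
  }
}

Short files drop out of ready_streams first while longer ones keep being decoded, so each Stream ends up with its own elapsed_seconds. The per-file RTF printed at the end is elapsed_seconds / duration; for example (made-up numbers), a 6.2 s wave whose stream finished 1.5 s after decoding started would report an RTF of about 1.5 / 6.2 ≈ 0.24.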