Jingzhao Ou
Committed by GitHub

add batch processing to sherpa-onnx (#166)

@@ -5,6 +5,8 @@ @@ -5,6 +5,8 @@
5 #include <stdio.h> 5 #include <stdio.h>
6 6
7 #include <chrono> // NOLINT 7 #include <chrono> // NOLINT
  8 +#include <iomanip>
  9 +#include <iostream>
8 #include <string> 10 #include <string>
9 #include <vector> 11 #include <vector>
10 12
@@ -14,6 +16,12 @@ @@ -14,6 +16,12 @@
14 #include "sherpa-onnx/csrc/parse-options.h" 16 #include "sherpa-onnx/csrc/parse-options.h"
15 #include "sherpa-onnx/csrc/wave-reader.h" 17 #include "sherpa-onnx/csrc/wave-reader.h"
16 18
  19 +typedef struct {
  20 + std::unique_ptr<sherpa_onnx::OnlineStream> online_stream;
  21 + float duration;
  22 + float elapsed_seconds;
  23 +} Stream;
  24 +
17 int main(int32_t argc, char *argv[]) { 25 int main(int32_t argc, char *argv[]) {
18 const char *kUsageMessage = R"usage( 26 const char *kUsageMessage = R"usage(
19 Usage: 27 Usage:
@@ -61,29 +69,26 @@ for a list of pre-trained models to download. @@ -61,29 +69,26 @@ for a list of pre-trained models to download.
61 69
62 sherpa_onnx::OnlineRecognizer recognizer(config); 70 sherpa_onnx::OnlineRecognizer recognizer(config);
63 71
64 - float duration = 0; 72 + std::vector<Stream> ss;
  73 +
  74 + const auto begin = std::chrono::steady_clock::now();
  75 + std::vector<float> durations;
  76 +
65 for (int32_t i = 1; i <= po.NumArgs(); ++i) { 77 for (int32_t i = 1; i <= po.NumArgs(); ++i) {
66 const std::string wav_filename = po.GetArg(i); 78 const std::string wav_filename = po.GetArg(i);
67 int32_t sampling_rate = -1; 79 int32_t sampling_rate = -1;
68 80
69 bool is_ok = false; 81 bool is_ok = false;
70 const std::vector<float> samples = 82 const std::vector<float> samples =
71 - sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); 83 + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
72 84
73 if (!is_ok) { 85 if (!is_ok) {
74 fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); 86 fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
75 return -1; 87 return -1;
76 } 88 }
77 - fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);  
78 89
79 const float duration = samples.size() / static_cast<float>(sampling_rate); 90 const float duration = samples.size() / static_cast<float>(sampling_rate);
80 91
81 - fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());  
82 - fprintf(stderr, "wav duration (s): %.3f\n", duration);  
83 -  
84 - fprintf(stderr, "Started\n");  
85 - const auto begin = std::chrono::steady_clock::now();  
86 -  
87 auto s = recognizer.CreateStream(); 92 auto s = recognizer.CreateStream();
88 s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); 93 s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
89 94
@@ -94,33 +99,46 @@ for a list of pre-trained models to download. @@ -94,33 +99,46 @@ for a list of pre-trained models to download.
94 99
95 // Call InputFinished() to indicate that no audio samples are available 100 // Call InputFinished() to indicate that no audio samples are available
96 s->InputFinished(); 101 s->InputFinished();
  102 + ss.push_back({ std::move(s), duration, 0 });
  103 + }
97 104
98 - while (recognizer.IsReady(s.get())) {  
99 - recognizer.DecodeStream(s.get());  
100 - }  
101 -  
102 - const std::string text = recognizer.GetResult(s.get()).AsJsonString();  
103 -  
104 - const auto end = std::chrono::steady_clock::now();  
105 - const float elapsed_seconds =  
106 - std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) 105 + std::vector<sherpa_onnx::OnlineStream *> ready_streams;
  106 + for (;;) {
  107 + ready_streams.clear();
  108 + for (auto &s : ss) {
  109 + const auto p_ss = s.online_stream.get();
  110 + if (recognizer.IsReady(p_ss)) {
  111 + ready_streams.push_back(p_ss);
  112 + } else if (s.elapsed_seconds == 0) {
  113 + const auto end = std::chrono::steady_clock::now();
  114 + const float elapsed_seconds =
  115 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
107 .count() / 1000.; 116 .count() / 1000.;
  117 + s.elapsed_seconds = elapsed_seconds;
  118 + }
  119 + }
108 120
109 - fprintf(stderr, "Done!\n");  
110 - fprintf(stderr,  
111 - "Recognition result for %s:\n%s\n",  
112 - wav_filename.c_str(), text.c_str());  
113 - fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);  
114 - fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());  
115 - if (config.decoding_method == "modified_beam_search") {  
116 - fprintf(stderr, "max active paths: %d\n", config.max_active_paths); 121 + if (ready_streams.empty()) {
  122 + break;
117 } 123 }
118 124
119 - fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);  
120 - const float rtf = elapsed_seconds / duration;  
121 - fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",  
122 - elapsed_seconds, duration, rtf); 125 + recognizer.DecodeStreams(ready_streams.data(), ready_streams.size());
123 } 126 }
124 127
  128 + std::ostringstream os;
  129 + for (int32_t i = 1; i <= po.NumArgs(); ++i) {
  130 + const auto &s = ss[i - 1];
  131 + const float rtf = s.elapsed_seconds / s.duration;
  132 +
  133 + os << po.GetArg(i) << "\n";
  134 + os << std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds
  135 + << ", Real time factor (RTF): " << rtf << "\n";
  136 + const auto r = recognizer.GetResult(s.online_stream.get());
  137 + os << r.text << "\n";
  138 + os << r.AsJsonString() << "\n\n";
  139 + }
  140 +
  141 + std::cerr << os.str();
  142 +
125 return 0; 143 return 0;
126 } 144 }