Committed by
GitHub
add batch processing to sherpa-onnx (#166)
正在显示
1 个修改的文件
包含
48 行增加
和
30 行删除
| @@ -5,6 +5,8 @@ | @@ -5,6 +5,8 @@ | ||
| 5 | #include <stdio.h> | 5 | #include <stdio.h> |
| 6 | 6 | ||
| 7 | #include <chrono> // NOLINT | 7 | #include <chrono> // NOLINT |
| 8 | +#include <iomanip> | ||
| 9 | +#include <iostream> | ||
| 8 | #include <string> | 10 | #include <string> |
| 9 | #include <vector> | 11 | #include <vector> |
| 10 | 12 | ||
| @@ -14,6 +16,12 @@ | @@ -14,6 +16,12 @@ | ||
| 14 | #include "sherpa-onnx/csrc/parse-options.h" | 16 | #include "sherpa-onnx/csrc/parse-options.h" |
| 15 | #include "sherpa-onnx/csrc/wave-reader.h" | 17 | #include "sherpa-onnx/csrc/wave-reader.h" |
| 16 | 18 | ||
| 19 | +typedef struct { | ||
| 20 | + std::unique_ptr<sherpa_onnx::OnlineStream> online_stream; | ||
| 21 | + float duration; | ||
| 22 | + float elapsed_seconds; | ||
| 23 | +} Stream; | ||
| 24 | + | ||
| 17 | int main(int32_t argc, char *argv[]) { | 25 | int main(int32_t argc, char *argv[]) { |
| 18 | const char *kUsageMessage = R"usage( | 26 | const char *kUsageMessage = R"usage( |
| 19 | Usage: | 27 | Usage: |
| @@ -61,29 +69,26 @@ for a list of pre-trained models to download. | @@ -61,29 +69,26 @@ for a list of pre-trained models to download. | ||
| 61 | 69 | ||
| 62 | sherpa_onnx::OnlineRecognizer recognizer(config); | 70 | sherpa_onnx::OnlineRecognizer recognizer(config); |
| 63 | 71 | ||
| 64 | - float duration = 0; | 72 | + std::vector<Stream> ss; |
| 73 | + | ||
| 74 | + const auto begin = std::chrono::steady_clock::now(); | ||
| 75 | + std::vector<float> durations; | ||
| 76 | + | ||
| 65 | for (int32_t i = 1; i <= po.NumArgs(); ++i) { | 77 | for (int32_t i = 1; i <= po.NumArgs(); ++i) { |
| 66 | const std::string wav_filename = po.GetArg(i); | 78 | const std::string wav_filename = po.GetArg(i); |
| 67 | int32_t sampling_rate = -1; | 79 | int32_t sampling_rate = -1; |
| 68 | 80 | ||
| 69 | bool is_ok = false; | 81 | bool is_ok = false; |
| 70 | const std::vector<float> samples = | 82 | const std::vector<float> samples = |
| 71 | - sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | 83 | + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); |
| 72 | 84 | ||
| 73 | if (!is_ok) { | 85 | if (!is_ok) { |
| 74 | fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | 86 | fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); |
| 75 | return -1; | 87 | return -1; |
| 76 | } | 88 | } |
| 77 | - fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate); | ||
| 78 | 89 | ||
| 79 | const float duration = samples.size() / static_cast<float>(sampling_rate); | 90 | const float duration = samples.size() / static_cast<float>(sampling_rate); |
| 80 | 91 | ||
| 81 | - fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); | ||
| 82 | - fprintf(stderr, "wav duration (s): %.3f\n", duration); | ||
| 83 | - | ||
| 84 | - fprintf(stderr, "Started\n"); | ||
| 85 | - const auto begin = std::chrono::steady_clock::now(); | ||
| 86 | - | ||
| 87 | auto s = recognizer.CreateStream(); | 92 | auto s = recognizer.CreateStream(); |
| 88 | s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | 93 | s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); |
| 89 | 94 | ||
| @@ -94,33 +99,46 @@ for a list of pre-trained models to download. | @@ -94,33 +99,46 @@ for a list of pre-trained models to download. | ||
| 94 | 99 | ||
| 95 | // Call InputFinished() to indicate that no audio samples are available | 100 | // Call InputFinished() to indicate that no audio samples are available |
| 96 | s->InputFinished(); | 101 | s->InputFinished(); |
| 102 | + ss.push_back({ std::move(s), duration, 0 }); | ||
| 103 | + } | ||
| 97 | 104 | ||
| 98 | - while (recognizer.IsReady(s.get())) { | ||
| 99 | - recognizer.DecodeStream(s.get()); | ||
| 100 | - } | ||
| 101 | - | ||
| 102 | - const std::string text = recognizer.GetResult(s.get()).AsJsonString(); | ||
| 103 | - | ||
| 104 | - const auto end = std::chrono::steady_clock::now(); | ||
| 105 | - const float elapsed_seconds = | ||
| 106 | - std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | 105 | + std::vector<sherpa_onnx::OnlineStream *> ready_streams; |
| 106 | + for (;;) { | ||
| 107 | + ready_streams.clear(); | ||
| 108 | + for (auto &s : ss) { | ||
| 109 | + const auto p_ss = s.online_stream.get(); | ||
| 110 | + if (recognizer.IsReady(p_ss)) { | ||
| 111 | + ready_streams.push_back(p_ss); | ||
| 112 | + } else if (s.elapsed_seconds == 0) { | ||
| 113 | + const auto end = std::chrono::steady_clock::now(); | ||
| 114 | + const float elapsed_seconds = | ||
| 115 | + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
| 107 | .count() / 1000.; | 116 | .count() / 1000.; |
| 117 | + s.elapsed_seconds = elapsed_seconds; | ||
| 118 | + } | ||
| 119 | + } | ||
| 108 | 120 | ||
| 109 | - fprintf(stderr, "Done!\n"); | ||
| 110 | - fprintf(stderr, | ||
| 111 | - "Recognition result for %s:\n%s\n", | ||
| 112 | - wav_filename.c_str(), text.c_str()); | ||
| 113 | - fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); | ||
| 114 | - fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); | ||
| 115 | - if (config.decoding_method == "modified_beam_search") { | ||
| 116 | - fprintf(stderr, "max active paths: %d\n", config.max_active_paths); | 121 | + if (ready_streams.empty()) { |
| 122 | + break; | ||
| 117 | } | 123 | } |
| 118 | 124 | ||
| 119 | - fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | ||
| 120 | - const float rtf = elapsed_seconds / duration; | ||
| 121 | - fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", | ||
| 122 | - elapsed_seconds, duration, rtf); | 125 | + recognizer.DecodeStreams(ready_streams.data(), ready_streams.size()); |
| 123 | } | 126 | } |
| 124 | 127 | ||
| 128 | + std::ostringstream os; | ||
| 129 | + for (int32_t i = 1; i <= po.NumArgs(); ++i) { | ||
| 130 | + const auto &s = ss[i - 1]; | ||
| 131 | + const float rtf = s.elapsed_seconds / s.duration; | ||
| 132 | + | ||
| 133 | + os << po.GetArg(i) << "\n"; | ||
| 134 | + os << std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds | ||
| 135 | + << ", Real time factor (RTF): " << rtf << "\n"; | ||
| 136 | + const auto r = recognizer.GetResult(s.online_stream.get()); | ||
| 137 | + os << r.text << "\n"; | ||
| 138 | + os << r.AsJsonString() << "\n\n"; | ||
| 139 | + } | ||
| 140 | + | ||
| 141 | + std::cerr << os.str(); | ||
| 142 | + | ||
| 125 | return 0; | 143 | return 0; |
| 126 | } | 144 | } |
-
请 注册 或 登录 后发表评论