keanu
Committed by GitHub

Offline decode support multi threads (#306)

Co-authored-by: cuidongcai1035 <cuidongcai1035@wezhuiyi.com>
@@ -85,9 +85,12 @@ set(sources @@ -85,9 +85,12 @@ set(sources
85 if(SHERPA_ONNX_ENABLE_CHECK) 85 if(SHERPA_ONNX_ENABLE_CHECK)
86 list(APPEND sources log.cc) 86 list(APPEND sources log.cc)
87 endif() 87 endif()
88 -  
89 add_library(sherpa-onnx-core ${sources}) 88 add_library(sherpa-onnx-core ${sources})
90 89
  90 +if(NOT WIN32)
  91 + target_link_libraries(sherpa-onnx-core -pthread)
  92 +endif()
  93 +
91 if(ANDROID_NDK) 94 if(ANDROID_NDK)
92 target_link_libraries(sherpa-onnx-core android log) 95 target_link_libraries(sherpa-onnx-core android log)
93 endif() 96 endif()
@@ -121,19 +124,23 @@ endif() @@ -121,19 +124,23 @@ endif()
121 124
122 add_executable(sherpa-onnx sherpa-onnx.cc) 125 add_executable(sherpa-onnx sherpa-onnx.cc)
123 add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) 126 add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
  127 +add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
124 128
125 target_link_libraries(sherpa-onnx sherpa-onnx-core) 129 target_link_libraries(sherpa-onnx sherpa-onnx-core)
126 target_link_libraries(sherpa-onnx-offline sherpa-onnx-core) 130 target_link_libraries(sherpa-onnx-offline sherpa-onnx-core)
  131 +target_link_libraries(sherpa-onnx-offline-parallel sherpa-onnx-core)
127 if(NOT WIN32) 132 if(NOT WIN32)
128 target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") 133 target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
129 target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib") 134 target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
130 135
131 target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") 136 target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
132 target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib") 137 target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
  138 + target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
133 139
134 if(SHERPA_ONNX_ENABLE_PYTHON) 140 if(SHERPA_ONNX_ENABLE_PYTHON)
135 target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") 141 target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
136 target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") 142 target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
  143 + target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
137 endif() 144 endif()
138 endif() 145 endif()
139 146
@@ -151,6 +158,7 @@ install( @@ -151,6 +158,7 @@ install(
151 TARGETS 158 TARGETS
152 sherpa-onnx 159 sherpa-onnx
153 sherpa-onnx-offline 160 sherpa-onnx-offline
  161 + sherpa-onnx-offline-parallel
154 DESTINATION 162 DESTINATION
155 bin 163 bin
156 ) 164 )
@@ -78,7 +78,9 @@ class OfflineWhisperModel::Impl { @@ -78,7 +78,9 @@ class OfflineWhisperModel::Impl {
78 decoder_input.size(), decoder_output_names_ptr_.data(), 78 decoder_input.size(), decoder_output_names_ptr_.data(),
79 decoder_output_names_ptr_.size()); 79 decoder_output_names_ptr_.size());
80 80
81 - return {std::move(decoder_out[0]), std::move(decoder_out[1]), 81 + return std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value,
  82 + Ort::Value, Ort::Value>{
  83 + std::move(decoder_out[0]), std::move(decoder_out[1]),
82 std::move(decoder_out[2]), std::move(decoder_input[3]), 84 std::move(decoder_out[2]), std::move(decoder_input[3]),
83 std::move(decoder_input[4]), std::move(decoder_input[5])}; 85 std::move(decoder_input[4]), std::move(decoder_input[5])};
84 } 86 }
  1 +// sherpa-onnx/csrc/sherpa-onnx-offline-parallel.cc
  2 +//
  3 +// Copyright (c) 2022-2023 cuidc
  4 +
#include <stdio.h>

#include <algorithm>
#include <atomic>
#include <chrono>  // NOLINT
#include <cstdlib>
#include <fstream>
#include <functional>
#include <memory>
#include <mutex>  // NOLINT
#include <sstream>
#include <string>
#include <thread>  // NOLINT
#include <vector>

#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
  18 +
// Index of the next batch to decode; atomically claimed by worker threads.
std::atomic<int> wav_index(0);
// Guards the shared duration / elapsed-time totals aggregated per thread.
std::mutex mtx;
  21 +
// Splits `input` into consecutive batches of at most `batch_size` elements,
// preserving order; the final batch may be smaller.
//
// @param input       wav paths to split
// @param batch_size  desired batch size; values < 1 are treated as 1
//                    (the original loop never advanced for batch_size <= 0
//                    and spun forever)
// @return batches covering every element of `input` exactly once
std::vector<std::vector<std::string>> SplitToBatches(
    const std::vector<std::string> &input, int32_t batch_size) {
  std::vector<std::vector<std::string>> outputs;
  if (batch_size < 1) {
    batch_size = 1;  // guard against a non-positive batch size
  }
  int32_t total = static_cast<int32_t>(input.size());
  for (int32_t start = 0; start < total; start += batch_size) {
    int32_t stop = std::min(start + batch_size, total);
    outputs.emplace_back(input.begin() + start, input.begin() + stop);
  }
  return outputs;
}
  39 +
// Loads a Kaldi-style wav.scp file ("<wav-id> <wav-path>" per line) and
// returns the wav paths in file order.
//
// Blank lines and lines with fewer than two columns are skipped with a
// warning; the original code pushed an empty path for such lines because a
// failed string extraction leaves the target erased.
//
// @param wav_scp_path  path to the scp file
// @return extracted wav paths; empty if the file cannot be opened
std::vector<std::string> LoadScpFile(const std::string &wav_scp_path) {
  std::vector<std::string> wav_paths;
  std::ifstream in(wav_scp_path);
  if (!in.is_open()) {
    fprintf(stderr, "Failed to open file: %s.\n", wav_scp_path.c_str());
    return wav_paths;
  }
  std::string line;
  while (std::getline(in, line)) {
    std::istringstream iss(line);
    std::string wav_id, wav_path;
    if (iss >> wav_id >> wav_path) {
      wav_paths.push_back(std::move(wav_path));
    } else if (!wav_id.empty()) {
      fprintf(stderr, "Skipping malformed scp line: %s\n", line.c_str());
    }
  }

  return wav_paths;
}
  56 +
  57 +void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
  58 + sherpa_onnx::OfflineRecognizer* recognizer,
  59 + float* total_length, float* total_time) {
  60 + std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss;
  61 + std::vector<sherpa_onnx::OfflineStream *> ss_pointers;
  62 + float duration = 0.0f;
  63 + float elapsed_seconds_batch = 0.0f;
  64 +
  65 + // warm up
  66 + for (const auto &wav_filename : chunk_wav_paths[0]) {
  67 + int32_t sampling_rate = -1;
  68 + bool is_ok = false;
  69 + const std::vector<float> samples =
  70 + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
  71 + if (!is_ok) {
  72 + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
  73 + continue;
  74 + }
  75 + duration += samples.size() / static_cast<float>(sampling_rate);
  76 + auto s = recognizer->CreateStream();
  77 + s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
  78 +
  79 + ss.push_back(std::move(s));
  80 + ss_pointers.push_back(ss.back().get());
  81 + }
  82 + recognizer->DecodeStreams(ss_pointers.data(), ss_pointers.size());
  83 + ss_pointers.clear();
  84 + ss.clear();
  85 +
  86 + while (true) {
  87 + int chunk = wav_index.fetch_add(1);
  88 + if (chunk >= chunk_wav_paths.size()) {
  89 + break;
  90 + }
  91 + const auto &wav_paths = chunk_wav_paths[chunk];
  92 + const auto begin = std::chrono::steady_clock::now();
  93 + for (const auto &wav_filename : wav_paths) {
  94 + int32_t sampling_rate = -1;
  95 + bool is_ok = false;
  96 + const std::vector<float> samples =
  97 + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
  98 + if (!is_ok) {
  99 + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
  100 + continue;
  101 + }
  102 + duration += samples.size() / static_cast<float>(sampling_rate);
  103 + auto s = recognizer->CreateStream();
  104 + s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
  105 +
  106 + ss.push_back(std::move(s));
  107 + ss_pointers.push_back(ss.back().get());
  108 + }
  109 + recognizer->DecodeStreams(ss_pointers.data(), ss_pointers.size());
  110 + const auto end = std::chrono::steady_clock::now();
  111 + float elapsed_seconds =
  112 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  113 + .count() /
  114 + 1000.;
  115 + elapsed_seconds_batch += elapsed_seconds;
  116 + int i = 0;
  117 + for (const auto &wav_filename : wav_paths) {
  118 + fprintf(stderr, "%s\n%s\n----\n", wav_filename.c_str(),
  119 + ss[i]->GetResult().AsJsonString().c_str());
  120 + i = i + 1;
  121 + }
  122 + ss_pointers.clear();
  123 + ss.clear();
  124 + }
  125 + fprintf(stderr, "thread %lu.\n", std::this_thread::get_id());
  126 + {
  127 + std::lock_guard<std::mutex> guard(mtx);
  128 + *total_length += duration;
  129 + if (*total_time < elapsed_seconds_batch) {
  130 + *total_time = elapsed_seconds_batch;
  131 + }
  132 + }
  133 +}
  134 +
  135 +
  136 +int main(int32_t argc, char *argv[]) {
  137 + const char *kUsageMessage = R"usage(
  138 +Speech recognition using non-streaming models with sherpa-onnx.
  139 +
  140 +Usage:
  141 +
  142 +(1) Transducer from icefall
  143 +
  144 +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
  145 +
  146 + ./bin/sherpa-onnx-offline-parallel \
  147 + --tokens=/path/to/tokens.txt \
  148 + --encoder=/path/to/encoder.onnx \
  149 + --decoder=/path/to/decoder.onnx \
  150 + --joiner=/path/to/joiner.onnx \
  151 + --num-threads=1 \
  152 + --decoding-method=greedy_search \
  153 + --batch-size=8 \
  154 + --nj=1 \
  155 + --wav-scp=wav.scp
  156 +
  157 + ./bin/sherpa-onnx-offline-parallel \
  158 + --tokens=/path/to/tokens.txt \
  159 + --encoder=/path/to/encoder.onnx \
  160 + --decoder=/path/to/decoder.onnx \
  161 + --joiner=/path/to/joiner.onnx \
  162 + --num-threads=1 \
  163 + --decoding-method=greedy_search \
  164 + --batch-size=1 \
  165 + --nj=8 \
  166 + /path/to/foo.wav [bar.wav foobar.wav ...]
  167 +
  168 +(2) Paraformer from FunASR
  169 +
  170 +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
  171 +
  172 + ./bin/sherpa-onnx-offline-parallel \
  173 + --tokens=/path/to/tokens.txt \
  174 + --paraformer=/path/to/model.onnx \
  175 + --num-threads=1 \
  176 + --decoding-method=greedy_search \
  177 + /path/to/foo.wav [bar.wav foobar.wav ...]
  178 +
  179 +(3) Whisper models
  180 +
  181 +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
  182 +
  183 + ./bin/sherpa-onnx-offline-parallel \
  184 + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
  185 + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
  186 + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
  187 + --num-threads=1 \
  188 + /path/to/foo.wav [bar.wav foobar.wav ...]
  189 +
  190 +(4) NeMo CTC models
  191 +
  192 +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
  193 +
  194 + ./bin/sherpa-onnx-offline-parallel \
  195 + --tokens=./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt \
  196 + --nemo-ctc-model=./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
  197 + --num-threads=2 \
  198 + --decoding-method=greedy_search \
  199 + --debug=false \
  200 + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \
  201 + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
  202 + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
  203 +
  204 +(5) TDNN CTC model for the yesno recipe from icefall
  205 +
  206 +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html
  207 + //
  208 + ./bin/sherpa-onnx-offline-parallel \
  209 + --sample-rate=8000 \
  210 + --feat-dim=23 \
  211 + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
  212 + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
  213 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
  214 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav
  215 +
  216 +Note: It supports decoding multiple files in batches
  217 +
  218 +foo.wav should be of single channel, 16-bit PCM encoded wave file; its
  219 +sampling rate can be arbitrary and does not need to be 16kHz.
  220 +
  221 +Please refer to
  222 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  223 +for a list of pre-trained models to download.
  224 +)usage";
  225 + std::string wav_scp = ""; // file path, kaldi style wav list.
  226 + int32_t nj = 1; // thread number
  227 + int32_t batch_size = 1; // number of wav files processed at once.
  228 + sherpa_onnx::ParseOptions po(kUsageMessage);
  229 + sherpa_onnx::OfflineRecognizerConfig config;
  230 + config.Register(&po);
  231 + po.Register("wav-scp", &wav_scp,
  232 + "a file including wav-id and wav-path, kaldi style wav list."
  233 + "default="". when it is not empty, wav files which positional "
  234 + "parameters provide are invalid.");
  235 + po.Register("nj", &nj,
  236 + "multi-thread num for decoding, default=1");
  237 + po.Register("batch-size", &batch_size,
  238 + "number of wav files processed at once during the decoding"
  239 + "process. default=1");
  240 +
  241 + po.Read(argc, argv);
  242 + if (po.NumArgs() < 1 && wav_scp.empty()) {
  243 + fprintf(stderr, "Error: Please provide at least 1 wave file.\n\n");
  244 + po.PrintUsage();
  245 + exit(EXIT_FAILURE);
  246 + }
  247 +
  248 + fprintf(stderr, "%s\n", config.ToString().c_str());
  249 +
  250 + if (!config.Validate()) {
  251 + fprintf(stderr, "Errors in config!\n");
  252 + return -1;
  253 + }
  254 + std::this_thread::sleep_for(std::chrono::seconds(10)); // sleep 10s
  255 + fprintf(stderr, "Creating recognizer ...\n");
  256 + const auto begin = std::chrono::steady_clock::now();
  257 + sherpa_onnx::OfflineRecognizer recognizer(config);
  258 + const auto end = std::chrono::steady_clock::now();
  259 + float elapsed_seconds =
  260 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  261 + .count() /
  262 + 1000.;
  263 + fprintf(stderr,
  264 + "Started nj: %d, batch_size: %d, wav_path: %s. recognizer init time: "
  265 + "%.6f\n", nj, batch_size, wav_scp.c_str(), elapsed_seconds);
  266 + std::this_thread::sleep_for(std::chrono::seconds(10)); // sleep 10s
  267 + std::vector<std::string> wav_paths;
  268 + if (!wav_scp.empty()) {
  269 + wav_paths = LoadScpFile(wav_scp);
  270 + } else {
  271 + for (int32_t i = 1; i <= po.NumArgs(); ++i) {
  272 + wav_paths.emplace_back(po.GetArg(i));
  273 + }
  274 + }
  275 + if (wav_paths.empty()) {
  276 + fprintf(stderr, "wav files is empty.\n");
  277 + return -1;
  278 + }
  279 + std::vector<std::thread> threads;
  280 + std::vector<std::vector<std::string>> batch_wav_paths =
  281 + SplitToBatches(wav_paths, batch_size);
  282 + float total_length = 0.0f;
  283 + float total_time = 0.0f;
  284 + for (int i = 0; i < nj; i++) {
  285 + threads.emplace_back(std::thread(AsrInference, batch_wav_paths,
  286 + &recognizer, &total_length, &total_time));
  287 + }
  288 +
  289 + for (auto& thread : threads) {
  290 + thread.join();
  291 + }
  292 +
  293 + fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
  294 + fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
  295 + if (config.decoding_method == "modified_beam_search") {
  296 + fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
  297 + }
  298 + fprintf(stderr, "Elapsed seconds: %.3f s\n", total_time);
  299 + float rtf = total_time / total_length;
  300 + fprintf(stderr, "Real time factor (RTF): %.6f / %.6f = %.4f\n",
  301 + total_time, total_length, rtf);
  302 + fprintf(stderr, "SPEEDUP: %.4f\n", 1.0 / rtf);
  303 +
  304 + return 0;
  305 +}