Committed by
GitHub
add streaming websocket server and client (#62)
正在显示
20 个修改的文件
包含
1197 行增加
和
1 行删除
| @@ -18,6 +18,7 @@ option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) | @@ -18,6 +18,7 @@ option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) | ||
| 18 | option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) | 18 | option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) |
| 19 | option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) | 19 | option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) |
| 20 | option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) | 20 | option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) |
| 21 | +option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON) | ||
| 21 | 22 | ||
| 22 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") | 23 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") |
| 23 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") | 24 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") |
| @@ -59,6 +60,8 @@ message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}") | @@ -59,6 +60,8 @@ message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}") | ||
| 59 | message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}") | 60 | message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}") |
| 60 | message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}") | 61 | message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}") |
| 61 | message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}") | 62 | message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}") |
| 63 | +message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}") | ||
| 64 | +message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}") | ||
| 62 | 65 | ||
| 63 | set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") | 66 | set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") |
| 64 | set(CMAKE_CXX_EXTENSIONS OFF) | 67 | set(CMAKE_CXX_EXTENSIONS OFF) |
| @@ -91,6 +94,11 @@ if(SHERPA_ONNX_ENABLE_TESTS) | @@ -91,6 +94,11 @@ if(SHERPA_ONNX_ENABLE_TESTS) | ||
| 91 | include(googletest) | 94 | include(googletest) |
| 92 | endif() | 95 | endif() |
| 93 | 96 | ||
| 97 | +if(SHERPA_ONNX_ENABLE_WEBSOCKET) | ||
| 98 | + include(websocketpp) | ||
| 99 | + include(asio) | ||
| 100 | +endif() | ||
| 101 | + | ||
| 94 | add_subdirectory(sherpa-onnx) | 102 | add_subdirectory(sherpa-onnx) |
| 95 | 103 | ||
| 96 | if(SHERPA_ONNX_ENABLE_C_API) | 104 | if(SHERPA_ONNX_ENABLE_C_API) |
| @@ -40,6 +40,10 @@ cmake \ | @@ -40,6 +40,10 @@ cmake \ | ||
| 40 | -DSHERPA_ONNX_ENABLE_TESTS=OFF \ | 40 | -DSHERPA_ONNX_ENABLE_TESTS=OFF \ |
| 41 | -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ | 41 | -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ |
| 42 | -DSHERPA_ONNX_ENABLE_CHECK=OFF \ | 42 | -DSHERPA_ONNX_ENABLE_CHECK=OFF \ |
| 43 | + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ | ||
| 44 | + -DSHERPA_ONNX_ENABLE_JNI=OFF \ | ||
| 45 | + -DSHERPA_ONNX_ENABLE_C_API=OFF \ | ||
| 46 | + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \ | ||
| 43 | -DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu.toolchain.cmake \ | 47 | -DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu.toolchain.cmake \ |
| 44 | .. | 48 | .. |
| 45 | 49 |
| @@ -76,6 +76,8 @@ cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" | @@ -76,6 +76,8 @@ cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" | ||
| 76 | -DSHERPA_ONNX_ENABLE_JNI=ON \ | 76 | -DSHERPA_ONNX_ENABLE_JNI=ON \ |
| 77 | -DCMAKE_INSTALL_PREFIX=./install \ | 77 | -DCMAKE_INSTALL_PREFIX=./install \ |
| 78 | -DANDROID_ABI="x86_64" \ | 78 | -DANDROID_ABI="x86_64" \ |
| 79 | + -DSHERPA_ONNX_ENABLE_C_API=OFF \ | ||
| 80 | + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \ | ||
| 79 | -DANDROID_PLATFORM=android-21 .. | 81 | -DANDROID_PLATFORM=android-21 .. |
| 80 | 82 | ||
| 81 | # make VERBOSE=1 -j4 | 83 | # make VERBOSE=1 -j4 |
cmake/asio.cmake
0 → 100644
| 1 | +function(download_asio) | ||
| 2 | + include(FetchContent) | ||
| 3 | + | ||
| 4 | + set(asio_URL "https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz") | ||
| 5 | + set(asio_HASH "SHA256=cbcaaba0f66722787b1a7c33afe1befb3a012b5af3ad7da7ff0f6b8c9b7a8a5b") | ||
| 6 | + | ||
| 7 | + # If you don't have access to the Internet, | ||
| 8 | + # please pre-download asio | ||
| 9 | + set(possible_file_locations | ||
| 10 | + $ENV{HOME}/Downloads/asio-asio-1-24-0.tar.gz | ||
| 11 | + ${PROJECT_SOURCE_DIR}/asio-asio-1-24-0.tar.gz | ||
| 12 | + ${PROJECT_BINARY_DIR}/asio-asio-1-24-0.tar.gz | ||
| 13 | + /tmp/asio-asio-1-24-0.tar.gz | ||
| 14 | + /star-fj/fangjun/download/github/asio-asio-1-24-0.tar.gz | ||
| 15 | + ) | ||
| 16 | + | ||
| 17 | + foreach(f IN LISTS possible_file_locations) | ||
| 18 | + if(EXISTS ${f}) | ||
| 19 | + set(asio_URL "file://${f}") | ||
| 20 | + break() | ||
| 21 | + endif() | ||
| 22 | + endforeach() | ||
| 23 | + | ||
| 24 | + FetchContent_Declare(asio | ||
| 25 | + URL ${asio_URL} | ||
| 26 | + URL_HASH ${asio_HASH} | ||
| 27 | + ) | ||
| 28 | + | ||
| 29 | + FetchContent_GetProperties(asio) | ||
| 30 | + if(NOT asio_POPULATED) | ||
| 31 | + message(STATUS "Downloading asio ${asio_URL}") | ||
| 32 | + FetchContent_Populate(asio) | ||
| 33 | + endif() | ||
| 34 | + message(STATUS "asio is downloaded to ${asio_SOURCE_DIR}") | ||
| 35 | + # add_subdirectory(${asio_SOURCE_DIR} ${asio_BINARY_DIR} EXCLUDE_FROM_ALL) | ||
| 36 | + include_directories(${asio_SOURCE_DIR}/asio/include) | ||
| 37 | +endfunction() | ||
| 38 | + | ||
| 39 | +download_asio() |
cmake/websocketpp.cmake
0 → 100644
| 1 | +function(download_websocketpp) | ||
| 2 | + include(FetchContent) | ||
| 3 | + | ||
| 4 | + # The latest commit on the develop branch os as 2022-10-22 | ||
| 5 | + set(websocketpp_URL "https://github.com/zaphoyd/websocketpp/archive/b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip") | ||
| 6 | + set(websocketpp_HASH "SHA256=1385135ede8191a7fbef9ec8099e3c5a673d48df0c143958216cd1690567f583") | ||
| 7 | + | ||
| 8 | + # If you don't have access to the Internet, | ||
| 9 | + # please pre-download websocketpp | ||
| 10 | + set(possible_file_locations | ||
| 11 | + $ENV{HOME}/Downloads/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip | ||
| 12 | + ${PROJECT_SOURCE_DIR}/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip | ||
| 13 | + ${PROJECT_BINARY_DIR}/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip | ||
| 14 | + /tmp/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip | ||
| 15 | + /star-fj/fangjun/download/github/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip | ||
| 16 | + ) | ||
| 17 | + | ||
| 18 | + foreach(f IN LISTS possible_file_locations) | ||
| 19 | + if(EXISTS ${f}) | ||
| 20 | + set(websocketpp_URL "file://${f}") | ||
| 21 | + break() | ||
| 22 | + endif() | ||
| 23 | + endforeach() | ||
| 24 | + | ||
| 25 | + FetchContent_Declare(websocketpp | ||
| 26 | + URL ${websocketpp_URL} | ||
| 27 | + URL_HASH ${websocketpp_HASH} | ||
| 28 | + ) | ||
| 29 | + | ||
| 30 | + FetchContent_GetProperties(websocketpp) | ||
| 31 | + if(NOT websocketpp_POPULATED) | ||
| 32 | + message(STATUS "Downloading websocketpp from ${websocketpp_URL}") | ||
| 33 | + FetchContent_Populate(websocketpp) | ||
| 34 | + endif() | ||
| 35 | + message(STATUS "websocketpp is downloaded to ${websocketpp_SOURCE_DIR}") | ||
| 36 | + # add_subdirectory(${websocketpp_SOURCE_DIR} ${websocketpp_BINARY_DIR} EXCLUDE_FROM_ALL) | ||
| 37 | + include_directories(${websocketpp_SOURCE_DIR}) | ||
| 38 | +endfunction() | ||
| 39 | + | ||
| 40 | +download_websocketpp() |
| @@ -4,6 +4,7 @@ set(sources | @@ -4,6 +4,7 @@ set(sources | ||
| 4 | cat.cc | 4 | cat.cc |
| 5 | endpoint.cc | 5 | endpoint.cc |
| 6 | features.cc | 6 | features.cc |
| 7 | + file-utils.cc | ||
| 7 | online-lstm-transducer-model.cc | 8 | online-lstm-transducer-model.cc |
| 8 | online-recognizer.cc | 9 | online-recognizer.cc |
| 9 | online-stream.cc | 10 | online-stream.cc |
| @@ -86,6 +87,32 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO) | @@ -86,6 +87,32 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO) | ||
| 86 | install(TARGETS sherpa-onnx-microphone DESTINATION bin) | 87 | install(TARGETS sherpa-onnx-microphone DESTINATION bin) |
| 87 | endif() | 88 | endif() |
| 88 | 89 | ||
| 90 | +if(SHERPA_ONNX_ENABLE_WEBSOCKET) | ||
| 91 | + add_definitions(-DASIO_STANDALONE) | ||
| 92 | + add_definitions(-D_WEBSOCKETPP_CPP11_STL_) | ||
| 93 | + | ||
| 94 | + add_executable(sherpa-onnx-online-websocket-server | ||
| 95 | + online-websocket-server-impl.cc | ||
| 96 | + online-websocket-server.cc | ||
| 97 | + ) | ||
| 98 | + target_link_libraries(sherpa-onnx-online-websocket-server sherpa-onnx-core) | ||
| 99 | + | ||
| 100 | + | ||
| 101 | + add_executable(sherpa-onnx-online-websocket-client | ||
| 102 | + online-websocket-client.cc | ||
| 103 | + ) | ||
| 104 | + target_link_libraries(sherpa-onnx-online-websocket-client sherpa-onnx-core) | ||
| 105 | + | ||
| 106 | + if(NOT WIN32) | ||
| 107 | + target_link_libraries(sherpa-onnx-online-websocket-server -pthread) | ||
| 108 | + target_compile_options(sherpa-onnx-online-websocket-server PRIVATE -Wno-deprecated-declarations) | ||
| 109 | + | ||
| 110 | + target_link_libraries(sherpa-onnx-online-websocket-client -pthread) | ||
| 111 | + target_compile_options(sherpa-onnx-online-websocket-client PRIVATE -Wno-deprecated-declarations) | ||
| 112 | + endif() | ||
| 113 | + | ||
| 114 | +endif() | ||
| 115 | + | ||
| 89 | 116 | ||
| 90 | if(SHERPA_ONNX_ENABLE_TESTS) | 117 | if(SHERPA_ONNX_ENABLE_TESTS) |
| 91 | set(sherpa_onnx_test_srcs | 118 | set(sherpa_onnx_test_srcs |
sherpa-onnx/csrc/CPPLINT.cfg
0 → 100644
| 1 | +exclude_files=tee-stream.h |
| @@ -14,6 +14,15 @@ | @@ -14,6 +14,15 @@ | ||
| 14 | 14 | ||
| 15 | namespace sherpa_onnx { | 15 | namespace sherpa_onnx { |
| 16 | 16 | ||
| 17 | +void FeatureExtractorConfig::Register(ParseOptions *po) { | ||
| 18 | + po->Register("sample-rate", &sampling_rate, | ||
| 19 | + "Sampling rate of the input waveform. Must match the one " | ||
| 20 | + "expected by the model."); | ||
| 21 | + | ||
| 22 | + po->Register("feat-dim", &feature_dim, | ||
| 23 | + "Feature dimension. Must match the one expected by the model."); | ||
| 24 | +} | ||
| 25 | + | ||
| 17 | std::string FeatureExtractorConfig::ToString() const { | 26 | std::string FeatureExtractorConfig::ToString() const { |
| 18 | std::ostringstream os; | 27 | std::ostringstream os; |
| 19 | 28 |
| @@ -9,6 +9,8 @@ | @@ -9,6 +9,8 @@ | ||
| 9 | #include <string> | 9 | #include <string> |
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 13 | + | ||
| 12 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 13 | 15 | ||
| 14 | struct FeatureExtractorConfig { | 16 | struct FeatureExtractorConfig { |
| @@ -16,6 +18,8 @@ struct FeatureExtractorConfig { | @@ -16,6 +18,8 @@ struct FeatureExtractorConfig { | ||
| 16 | int32_t feature_dim = 80; | 18 | int32_t feature_dim = 80; |
| 17 | 19 | ||
| 18 | std::string ToString() const; | 20 | std::string ToString() const; |
| 21 | + | ||
| 22 | + void Register(ParseOptions *po); | ||
| 19 | }; | 23 | }; |
| 20 | 24 | ||
| 21 | class FeatureExtractor { | 25 | class FeatureExtractor { |
sherpa-onnx/csrc/file-utils.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/file-utils.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 6 | + | ||
| 7 | +#include <fstream> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/log.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +bool FileExists(const std::string &filename) { | ||
| 15 | + return std::ifstream(filename).good(); | ||
| 16 | +} | ||
| 17 | + | ||
| 18 | +void AssertFileExists(const std::string &filename) { | ||
| 19 | + if (!FileExists(filename)) { | ||
| 20 | + SHERPA_ONNX_LOG(FATAL) << filename << " does not exist!"; | ||
| 21 | + } | ||
| 22 | +} | ||
| 23 | + | ||
| 24 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/file-utils.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/file-utils.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_FILE_UTILS_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_FILE_UTILS_H_ | ||
| 7 | + | ||
| 8 | +#include <fstream> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +/** Check whether a given path is a file or not | ||
| 14 | + * | ||
| 15 | + * @param filename Path to check. | ||
| 16 | + * @return Return true if the given path is a file; return false otherwise. | ||
| 17 | + */ | ||
| 18 | +bool FileExists(const std::string &filename); | ||
| 19 | + | ||
| 20 | +/** Abort if the file does not exist. | ||
| 21 | + * | ||
| 22 | + * @param filename The file to check. | ||
| 23 | + */ | ||
| 24 | +void AssertFileExists(const std::string &filename); | ||
| 25 | + | ||
| 26 | +} // namespace sherpa_onnx | ||
| 27 | + | ||
| 28 | +#endif // SHERPA_ONNX_CSRC_FILE_UTILS_H_ |
| @@ -12,6 +12,7 @@ | @@ -12,6 +12,7 @@ | ||
| 12 | #include <utility> | 12 | #include <utility> |
| 13 | #include <vector> | 13 | #include <vector> |
| 14 | 14 | ||
| 15 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 15 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" | 16 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" |
| 16 | #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" | 17 | #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" |
| 17 | #include "sherpa-onnx/csrc/online-transducer-model.h" | 18 | #include "sherpa-onnx/csrc/online-transducer-model.h" |
| @@ -31,6 +32,19 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | @@ -31,6 +32,19 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | ||
| 31 | return ans; | 32 | return ans; |
| 32 | } | 33 | } |
| 33 | 34 | ||
| 35 | +void OnlineRecognizerConfig::Register(ParseOptions *po) { | ||
| 36 | + feat_config.Register(po); | ||
| 37 | + model_config.Register(po); | ||
| 38 | + endpoint_config.Register(po); | ||
| 39 | + | ||
| 40 | + po->Register("enable-endpoint", &enable_endpoint, | ||
| 41 | + "True to enable endpoint detection. False to disable it."); | ||
| 42 | +} | ||
| 43 | + | ||
| 44 | +bool OnlineRecognizerConfig::Validate() const { | ||
| 45 | + return model_config.Validate(); | ||
| 46 | +} | ||
| 47 | + | ||
| 34 | std::string OnlineRecognizerConfig::ToString() const { | 48 | std::string OnlineRecognizerConfig::ToString() const { |
| 35 | std::ostringstream os; | 49 | std::ostringstream os; |
| 36 | 50 |
| @@ -17,11 +17,15 @@ | @@ -17,11 +17,15 @@ | ||
| 17 | #include "sherpa-onnx/csrc/features.h" | 17 | #include "sherpa-onnx/csrc/features.h" |
| 18 | #include "sherpa-onnx/csrc/online-stream.h" | 18 | #include "sherpa-onnx/csrc/online-stream.h" |
| 19 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 19 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 20 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 20 | 21 | ||
| 21 | namespace sherpa_onnx { | 22 | namespace sherpa_onnx { |
| 22 | 23 | ||
| 23 | struct OnlineRecognizerResult { | 24 | struct OnlineRecognizerResult { |
| 24 | std::string text; | 25 | std::string text; |
| 26 | + | ||
| 27 | + // TODO(fangjun): Add a method to return a json string | ||
| 28 | + std::string ToString() const { return text; } | ||
| 25 | }; | 29 | }; |
| 26 | 30 | ||
| 27 | struct OnlineRecognizerConfig { | 31 | struct OnlineRecognizerConfig { |
| @@ -41,6 +45,9 @@ struct OnlineRecognizerConfig { | @@ -41,6 +45,9 @@ struct OnlineRecognizerConfig { | ||
| 41 | endpoint_config(endpoint_config), | 45 | endpoint_config(endpoint_config), |
| 42 | enable_endpoint(enable_endpoint) {} | 46 | enable_endpoint(enable_endpoint) {} |
| 43 | 47 | ||
| 48 | + void Register(ParseOptions *po); | ||
| 49 | + bool Validate() const; | ||
| 50 | + | ||
| 44 | std::string ToString() const; | 51 | std::string ToString() const; |
| 45 | }; | 52 | }; |
| 46 | 53 |
| @@ -5,8 +5,52 @@ | @@ -5,8 +5,52 @@ | ||
| 5 | 5 | ||
| 6 | #include <sstream> | 6 | #include <sstream> |
| 7 | 7 | ||
| 8 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 9 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 10 | + | ||
| 8 | namespace sherpa_onnx { | 11 | namespace sherpa_onnx { |
| 9 | 12 | ||
| 13 | +void OnlineTransducerModelConfig::Register(ParseOptions *po) { | ||
| 14 | + po->Register("encoder", &encoder_filename, "Path to encoder.onnx"); | ||
| 15 | + po->Register("decoder", &decoder_filename, "Path to decoder.onnx"); | ||
| 16 | + po->Register("joiner", &joiner_filename, "Path to joiner.onnx"); | ||
| 17 | + po->Register("tokens", &tokens, "Path to tokens.txt"); | ||
| 18 | + po->Register("num_threads", &num_threads, | ||
| 19 | + "Number of threads to run the neural network"); | ||
| 20 | + | ||
| 21 | + po->Register("debug", &debug, | ||
| 22 | + "true to print model information while loading it."); | ||
| 23 | +} | ||
| 24 | + | ||
| 25 | +bool OnlineTransducerModelConfig::Validate() const { | ||
| 26 | + if (!FileExists(tokens)) { | ||
| 27 | + SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); | ||
| 28 | + return false; | ||
| 29 | + } | ||
| 30 | + | ||
| 31 | + if (!FileExists(encoder_filename)) { | ||
| 32 | + SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str()); | ||
| 33 | + return false; | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | + if (!FileExists(decoder_filename)) { | ||
| 37 | + SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str()); | ||
| 38 | + return false; | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + if (!FileExists(joiner_filename)) { | ||
| 42 | + SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str()); | ||
| 43 | + return false; | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + if (num_threads < 1) { | ||
| 47 | + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); | ||
| 48 | + return false; | ||
| 49 | + } | ||
| 50 | + | ||
| 51 | + return true; | ||
| 52 | +} | ||
| 53 | + | ||
| 10 | std::string OnlineTransducerModelConfig::ToString() const { | 54 | std::string OnlineTransducerModelConfig::ToString() const { |
| 11 | std::ostringstream os; | 55 | std::ostringstream os; |
| 12 | 56 |
| @@ -6,6 +6,8 @@ | @@ -6,6 +6,8 @@ | ||
| 6 | 6 | ||
| 7 | #include <string> | 7 | #include <string> |
| 8 | 8 | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 9 | namespace sherpa_onnx { | 11 | namespace sherpa_onnx { |
| 10 | 12 | ||
| 11 | struct OnlineTransducerModelConfig { | 13 | struct OnlineTransducerModelConfig { |
| @@ -13,7 +15,7 @@ struct OnlineTransducerModelConfig { | @@ -13,7 +15,7 @@ struct OnlineTransducerModelConfig { | ||
| 13 | std::string decoder_filename; | 15 | std::string decoder_filename; |
| 14 | std::string joiner_filename; | 16 | std::string joiner_filename; |
| 15 | std::string tokens; | 17 | std::string tokens; |
| 16 | - int32_t num_threads; | 18 | + int32_t num_threads = 2; |
| 17 | bool debug = false; | 19 | bool debug = false; |
| 18 | 20 | ||
| 19 | OnlineTransducerModelConfig() = default; | 21 | OnlineTransducerModelConfig() = default; |
| @@ -29,6 +31,9 @@ struct OnlineTransducerModelConfig { | @@ -29,6 +31,9 @@ struct OnlineTransducerModelConfig { | ||
| 29 | num_threads(num_threads), | 31 | num_threads(num_threads), |
| 30 | debug(debug) {} | 32 | debug(debug) {} |
| 31 | 33 | ||
| 34 | + void Register(ParseOptions *po); | ||
| 35 | + bool Validate() const; | ||
| 36 | + | ||
| 32 | std::string ToString() const; | 37 | std::string ToString() const; |
| 33 | }; | 38 | }; |
| 34 | 39 |
sherpa-onnx/csrc/online-websocket-client.cc
0 → 100644
| 1 | +// sherpa/cpp_api/websocket/online-websocket-client.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022 Xiaomi Corporation | ||
| 4 | +#include <chrono> // NOLINT | ||
| 5 | +#include <fstream> | ||
| 6 | +#include <string> | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | +#include "sherpa-onnx/csrc/wave-reader.h" | ||
| 11 | +#include "websocketpp/client.hpp" | ||
| 12 | +#include "websocketpp/config/asio_no_tls_client.hpp" | ||
| 13 | +#include "websocketpp/uri.hpp" | ||
| 14 | + | ||
| 15 | +using client = websocketpp::client<websocketpp::config::asio_client>; | ||
| 16 | + | ||
| 17 | +using message_ptr = client::message_ptr; | ||
| 18 | +using websocketpp::connection_hdl; | ||
| 19 | + | ||
| 20 | +static constexpr const char *kUsageMessage = R"( | ||
| 21 | +Automatic speech recognition with sherpa-onnx using websocket. | ||
| 22 | + | ||
| 23 | +Usage: | ||
| 24 | + | ||
| 25 | +./bin/sherpa-onnx-online-websocket-client --help | ||
| 26 | + | ||
| 27 | +./bin/sherpa-onnx-online-websocket-client \ | ||
| 28 | + --server-ip=127.0.0.1 \ | ||
| 29 | + --server-port=6006 \ | ||
| 30 | + --samples-per-message=8000 \ | ||
| 31 | + --seconds-per-message=0.2 \ | ||
| 32 | + /path/to/foo.wav | ||
| 33 | + | ||
| 34 | +It support only wave of with a single channel, 16kHz, 16-bit samples. | ||
| 35 | +)"; | ||
| 36 | + | ||
| 37 | +class Client { | ||
| 38 | + public: | ||
| 39 | + Client(asio::io_context &io, // NOLINT | ||
| 40 | + const std::string &ip, int16_t port, const std::vector<float> &samples, | ||
| 41 | + int32_t samples_per_message, float seconds_per_message) | ||
| 42 | + : io_(io), | ||
| 43 | + uri_(/*secure*/ false, ip, port, /*resource*/ "/"), | ||
| 44 | + samples_(samples), | ||
| 45 | + samples_per_message_(samples_per_message), | ||
| 46 | + seconds_per_message_(seconds_per_message) { | ||
| 47 | + c_.clear_access_channels(websocketpp::log::alevel::all); | ||
| 48 | + // c_.set_access_channels(websocketpp::log::alevel::connect); | ||
| 49 | + // c_.set_access_channels(websocketpp::log::alevel::disconnect); | ||
| 50 | + | ||
| 51 | + c_.init_asio(&io_); | ||
| 52 | + c_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); }); | ||
| 53 | + c_.set_close_handler( | ||
| 54 | + [this](connection_hdl /*hdl*/) { SHERPA_ONNX_LOGE("Disconnected"); }); | ||
| 55 | + c_.set_message_handler( | ||
| 56 | + [this](connection_hdl hdl, message_ptr msg) { OnMessage(hdl, msg); }); | ||
| 57 | + | ||
| 58 | + Run(); | ||
| 59 | + } | ||
| 60 | + | ||
| 61 | + private: | ||
| 62 | + void Run() { | ||
| 63 | + websocketpp::lib::error_code ec; | ||
| 64 | + client::connection_ptr con = c_.get_connection(uri_.str(), ec); | ||
| 65 | + if (ec) { | ||
| 66 | + SHERPA_ONNX_LOGE("Could not create connection to %s because %s", | ||
| 67 | + uri_.str().c_str(), ec.message().c_str()); | ||
| 68 | + exit(EXIT_FAILURE); | ||
| 69 | + } | ||
| 70 | + | ||
| 71 | + c_.connect(con); | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + void OnOpen(connection_hdl hdl) { | ||
| 75 | + auto start_time = std::chrono::steady_clock::now(); | ||
| 76 | + asio::post( | ||
| 77 | + io_, [this, hdl, start_time]() { this->SendMessage(hdl, start_time); }); | ||
| 78 | + } | ||
| 79 | + | ||
| 80 | + void OnMessage(connection_hdl hdl, message_ptr msg) { | ||
| 81 | + const std::string &payload = msg->get_payload(); | ||
| 82 | + | ||
| 83 | + if (payload == "Done!") { | ||
| 84 | + websocketpp::lib::error_code ec; | ||
| 85 | + c_.close(hdl, websocketpp::close::status::normal, "I'm exiting now", ec); | ||
| 86 | + if (ec) { | ||
| 87 | + SHERPA_ONNX_LOGE("Failed to close because %s", ec.message().c_str()); | ||
| 88 | + exit(EXIT_FAILURE); | ||
| 89 | + } | ||
| 90 | + } else { | ||
| 91 | + SHERPA_ONNX_LOGE("%s", payload.c_str()); | ||
| 92 | + } | ||
| 93 | + } | ||
| 94 | + | ||
| 95 | + void SendMessage( | ||
| 96 | + connection_hdl hdl, | ||
| 97 | + std::chrono::time_point<std::chrono::steady_clock> start_time) { | ||
| 98 | + int32_t num_samples = samples_.size(); | ||
| 99 | + int32_t num_messages = num_samples / samples_per_message_; | ||
| 100 | + | ||
| 101 | + websocketpp::lib::error_code ec; | ||
| 102 | + auto time = std::chrono::steady_clock::now(); | ||
| 103 | + int elapsed_time_ms = | ||
| 104 | + std::chrono::duration_cast<std::chrono::milliseconds>(time - start_time) | ||
| 105 | + .count(); | ||
| 106 | + | ||
| 107 | + if (elapsed_time_ms < | ||
| 108 | + static_cast<int>(seconds_per_message_ * num_sent_messages_ * 1000)) { | ||
| 109 | + std::this_thread::sleep_for(std::chrono::milliseconds(int( | ||
| 110 | + seconds_per_message_ * num_sent_messages_ * 1000 - elapsed_time_ms))); | ||
| 111 | + } | ||
| 112 | + | ||
| 113 | + if (num_sent_messages_ < 1) { | ||
| 114 | + SHERPA_ONNX_LOGE("Starting to send audio"); | ||
| 115 | + } | ||
| 116 | + | ||
| 117 | + if (num_sent_messages_ < num_messages) { | ||
| 118 | + c_.send(hdl, samples_.data() + num_sent_messages_ * samples_per_message_, | ||
| 119 | + samples_per_message_ * sizeof(float), | ||
| 120 | + websocketpp::frame::opcode::binary, ec); | ||
| 121 | + | ||
| 122 | + if (ec) { | ||
| 123 | + SHERPA_ONNX_LOGE("Failed to send audio samples because %s", | ||
| 124 | + ec.message().c_str()); | ||
| 125 | + exit(EXIT_FAILURE); | ||
| 126 | + } | ||
| 127 | + | ||
| 128 | + ec.clear(); | ||
| 129 | + | ||
| 130 | + ++num_sent_messages_; | ||
| 131 | + } | ||
| 132 | + | ||
| 133 | + if (num_sent_messages_ == num_messages) { | ||
| 134 | + int32_t remaining_samples = num_samples % samples_per_message_; | ||
| 135 | + if (remaining_samples) { | ||
| 136 | + c_.send(hdl, | ||
| 137 | + samples_.data() + num_sent_messages_ * samples_per_message_, | ||
| 138 | + remaining_samples * sizeof(float), | ||
| 139 | + websocketpp::frame::opcode::binary, ec); | ||
| 140 | + | ||
| 141 | + if (ec) { | ||
| 142 | + SHERPA_ONNX_LOGE("Failed to send audio samples because %s", | ||
| 143 | + ec.message().c_str()); | ||
| 144 | + exit(EXIT_FAILURE); | ||
| 145 | + } | ||
| 146 | + ec.clear(); | ||
| 147 | + } | ||
| 148 | + | ||
| 149 | + // To signal that we have send all the messages | ||
| 150 | + c_.send(hdl, "Done", websocketpp::frame::opcode::text, ec); | ||
| 151 | + SHERPA_ONNX_LOGE("Sent Done Signal"); | ||
| 152 | + | ||
| 153 | + if (ec) { | ||
| 154 | + SHERPA_ONNX_LOGE("Failed to send audio samples because %s", | ||
| 155 | + ec.message().c_str()); | ||
| 156 | + exit(EXIT_FAILURE); | ||
| 157 | + } | ||
| 158 | + } else { | ||
| 159 | + asio::post(io_, [this, hdl, start_time]() { | ||
| 160 | + this->SendMessage(hdl, start_time); | ||
| 161 | + }); | ||
| 162 | + } | ||
| 163 | + } | ||
| 164 | + | ||
| 165 | + private: | ||
| 166 | + client c_; | ||
| 167 | + asio::io_context &io_; | ||
| 168 | + websocketpp::uri uri_; | ||
| 169 | + std::vector<float> samples_; | ||
| 170 | + int32_t samples_per_message_ = 8000; // 0.5 seconds | ||
| 171 | + float seconds_per_message_ = 0.2; | ||
| 172 | + int32_t num_sent_messages_ = 0; | ||
| 173 | +}; | ||
| 174 | + | ||
| 175 | +int32_t main(int32_t argc, char *argv[]) { | ||
| 176 | + std::string server_ip = "127.0.0.1"; | ||
| 177 | + int32_t server_port = 6006; | ||
| 178 | + | ||
| 179 | + // Sample rate of the input wave. No resampling is made. | ||
| 180 | + int32_t sample_rate = 16000; | ||
| 181 | + int32_t samples_per_message = 8000; | ||
| 182 | + float seconds_per_message = 0.2; | ||
| 183 | + | ||
| 184 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 185 | + | ||
| 186 | + po.Register("server-ip", &server_ip, "IP address of the websocket server"); | ||
| 187 | + po.Register("server-port", &server_port, "Port of the websocket server"); | ||
| 188 | + po.Register("sample-rate", &sample_rate, | ||
| 189 | + "Sample rate of the input wave. Should be the one expected by " | ||
| 190 | + "the server"); | ||
| 191 | + | ||
| 192 | + po.Register("samples-per-message", &samples_per_message, | ||
| 193 | + "Send this number of samples per message."); | ||
| 194 | + | ||
| 195 | + po.Register("seconds-per-message", &seconds_per_message, | ||
| 196 | + "We will simulate that each message takes this number of seconds " | ||
| 197 | + "to send. If you select a very large value, it will take a long " | ||
| 198 | + "time to send all the samples"); | ||
| 199 | + | ||
| 200 | + po.Read(argc, argv); | ||
| 201 | + | ||
| 202 | + if (!websocketpp::uri_helper::ipv4_literal(server_ip.begin(), | ||
| 203 | + server_ip.end())) { | ||
| 204 | + SHERPA_ONNX_LOGE("Invalid server IP: %s", server_ip.c_str()); | ||
| 205 | + return -1; | ||
| 206 | + } | ||
| 207 | + | ||
| 208 | + if (server_port <= 0 || server_port > 65535) { | ||
| 209 | + SHERPA_ONNX_LOGE("Invalid server port: %d", server_port); | ||
| 210 | + return -1; | ||
| 211 | + } | ||
| 212 | + | ||
| 213 | + // 0.01 is an arbitrary value. You can change it. | ||
| 214 | + if (samples_per_message <= 0.01 * sample_rate) { | ||
| 215 | + SHERPA_ONNX_LOGE("--samples-per-message is too small: %d", | ||
| 216 | + samples_per_message); | ||
| 217 | + return -1; | ||
| 218 | + } | ||
| 219 | + | ||
| 220 | + // 100 is an arbitrary value. You can change it. | ||
| 221 | + if (samples_per_message >= sample_rate * 100) { | ||
| 222 | + SHERPA_ONNX_LOGE("--samples-per-message is too small: %d", | ||
| 223 | + samples_per_message); | ||
| 224 | + return -1; | ||
| 225 | + } | ||
| 226 | + | ||
| 227 | + if (seconds_per_message < 0) { | ||
| 228 | + SHERPA_ONNX_LOGE("--seconds-per-message is too small: %.3f", | ||
| 229 | + seconds_per_message); | ||
| 230 | + return -1; | ||
| 231 | + } | ||
| 232 | + | ||
| 233 | + // 1 is an arbitrary value. | ||
| 234 | + if (seconds_per_message > 1) { | ||
| 235 | + SHERPA_ONNX_LOGE( | ||
| 236 | + "--seconds-per-message is too large: %.3f. You will wait a long time " | ||
| 237 | + "to " | ||
| 238 | + "send all the samples", | ||
| 239 | + seconds_per_message); | ||
| 240 | + return -1; | ||
| 241 | + } | ||
| 242 | + | ||
| 243 | + if (po.NumArgs() != 1) { | ||
| 244 | + po.PrintUsage(); | ||
| 245 | + return -1; | ||
| 246 | + } | ||
| 247 | + | ||
| 248 | + std::string wave_filename = po.GetArg(1); | ||
| 249 | + | ||
| 250 | + bool is_ok = false; | ||
| 251 | + std::vector<float> samples = | ||
| 252 | + sherpa_onnx::ReadWave(wave_filename, sample_rate, &is_ok); | ||
| 253 | + | ||
| 254 | + if (!is_ok) { | ||
| 255 | + SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str()); | ||
| 256 | + return -1; | ||
| 257 | + } | ||
| 258 | + | ||
| 259 | + asio::io_context io_conn; // for network connections | ||
| 260 | + Client c(io_conn, server_ip, server_port, samples, samples_per_message, | ||
| 261 | + seconds_per_message); | ||
| 262 | + | ||
| 263 | + io_conn.run(); // will exit when the above connection is closed | ||
| 264 | + | ||
| 265 | + SHERPA_ONNX_LOGE("Done!"); | ||
| 266 | + return 0; | ||
| 267 | +} |
| 1 | +// sherpa-onnx/csrc/online-websocket-server-impl.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-websocket-server-impl.h" | ||
| 6 | + | ||
| 7 | +#include <vector> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 10 | +#include "sherpa-onnx/csrc/log.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void OnlineWebsocketDecoderConfig::Register(ParseOptions *po) { | ||
| 15 | + recognizer_config.Register(po); | ||
| 16 | + | ||
| 17 | + po->Register("loop-interval-ms", &loop_interval_ms, | ||
| 18 | + "It determines how often the decoder loop runs. "); | ||
| 19 | + | ||
| 20 | + po->Register("max-batch-size", &max_batch_size, | ||
| 21 | + "Max batch size for recognition."); | ||
| 22 | +} | ||
| 23 | + | ||
| 24 | +void OnlineWebsocketDecoderConfig::Validate() const { | ||
| 25 | + recognizer_config.Validate(); | ||
| 26 | + SHERPA_ONNX_CHECK_GT(loop_interval_ms, 0); | ||
| 27 | + SHERPA_ONNX_CHECK_GT(max_batch_size, 0); | ||
| 28 | +} | ||
| 29 | + | ||
| 30 | +void OnlineWebsocketServerConfig::Register(sherpa_onnx::ParseOptions *po) { | ||
| 31 | + decoder_config.Register(po); | ||
| 32 | + | ||
| 33 | + po->Register("log-file", &log_file, | ||
| 34 | + "Path to the log file. Logs are " | ||
| 35 | + "appended to this file"); | ||
| 36 | +} | ||
| 37 | + | ||
| 38 | +void OnlineWebsocketServerConfig::Validate() const { | ||
| 39 | + decoder_config.Validate(); | ||
| 40 | +} | ||
| 41 | + | ||
| 42 | +OnlineWebsocketDecoder::OnlineWebsocketDecoder(OnlineWebsocketServer *server) | ||
| 43 | + : server_(server), | ||
| 44 | + config_(server->GetConfig().decoder_config), | ||
| 45 | + timer_(server->GetWorkContext()) { | ||
| 46 | + recognizer_ = std::make_unique<OnlineRecognizer>(config_.recognizer_config); | ||
| 47 | +} | ||
| 48 | + | ||
| 49 | +std::shared_ptr<Connection> OnlineWebsocketDecoder::GetOrCreateConnection( | ||
| 50 | + connection_hdl hdl) { | ||
| 51 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 52 | + auto it = connections_.find(hdl); | ||
| 53 | + if (it != connections_.end()) { | ||
| 54 | + return it->second; | ||
| 55 | + } else { | ||
| 56 | + // create a new connection | ||
| 57 | + std::shared_ptr<OnlineStream> s = recognizer_->CreateStream(); | ||
| 58 | + auto c = std::make_shared<Connection>(hdl, s); | ||
| 59 | + connections_.insert({hdl, c}); | ||
| 60 | + return c; | ||
| 61 | + } | ||
| 62 | +} | ||
| 63 | + | ||
| 64 | +void OnlineWebsocketDecoder::AcceptWaveform(std::shared_ptr<Connection> c) { | ||
| 65 | + std::lock_guard<std::mutex> lock(c->mutex); | ||
| 66 | + float sample_rate = config_.recognizer_config.feat_config.sampling_rate; | ||
| 67 | + while (!c->samples.empty()) { | ||
| 68 | + const auto &s = c->samples.front(); | ||
| 69 | + c->s->AcceptWaveform(sample_rate, s.data(), s.size()); | ||
| 70 | + c->samples.pop_front(); | ||
| 71 | + } | ||
| 72 | +} | ||
| 73 | + | ||
| 74 | +void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) { | ||
| 75 | + std::lock_guard<std::mutex> lock(c->mutex); | ||
| 76 | + | ||
| 77 | + float sample_rate = config_.recognizer_config.feat_config.sampling_rate; | ||
| 78 | + | ||
| 79 | + while (!c->samples.empty()) { | ||
| 80 | + const auto &s = c->samples.front(); | ||
| 81 | + c->s->AcceptWaveform(sample_rate, s.data(), s.size()); | ||
| 82 | + c->samples.pop_front(); | ||
| 83 | + } | ||
| 84 | + | ||
| 85 | + // TODO(fangjun): Change the amount of paddings to be configurable | ||
| 86 | + std::vector<float> tail_padding(static_cast<int64_t>(0.8 * sample_rate)); | ||
| 87 | + | ||
| 88 | + c->s->AcceptWaveform(sample_rate, tail_padding.data(), tail_padding.size()); | ||
| 89 | + | ||
| 90 | + c->s->InputFinished(); | ||
| 91 | + c->eof = true; | ||
| 92 | +} | ||
| 93 | + | ||
| 94 | +void OnlineWebsocketDecoder::Run() { | ||
| 95 | + timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms)); | ||
| 96 | + | ||
| 97 | + timer_.async_wait( | ||
| 98 | + [this](const asio::error_code &ec) { ProcessConnections(ec); }); | ||
| 99 | +} | ||
| 100 | + | ||
| 101 | +void OnlineWebsocketDecoder::ProcessConnections(const asio::error_code &ec) { | ||
| 102 | + if (ec) { | ||
| 103 | + SHERPA_ONNX_LOG(FATAL) << "The decoder loop is aborted!"; | ||
| 104 | + } | ||
| 105 | + | ||
| 106 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 107 | + std::vector<connection_hdl> to_remove; | ||
| 108 | + for (auto &p : connections_) { | ||
| 109 | + auto hdl = p.first; | ||
| 110 | + auto c = p.second; | ||
| 111 | + | ||
| 112 | + // The order of `if` below matters! | ||
| 113 | + if (!server_->Contains(hdl)) { | ||
| 114 | + // If the connection is disconnected, we stop processing it | ||
| 115 | + to_remove.push_back(hdl); | ||
| 116 | + continue; | ||
| 117 | + } | ||
| 118 | + | ||
| 119 | + if (active_.count(hdl)) { | ||
| 120 | + // Another thread is decoding this stream, so skip it | ||
| 121 | + continue; | ||
| 122 | + } | ||
| 123 | + | ||
| 124 | + if (!recognizer_->IsReady(c->s.get()) && !c->eof) { | ||
| 125 | + // this stream has not enough frames to decode, so skip it | ||
| 126 | + continue; | ||
| 127 | + } | ||
| 128 | + | ||
| 129 | + if (!recognizer_->IsReady(c->s.get()) && c->eof) { | ||
| 130 | + // We won't receive samples from the client, so send a Done! to client | ||
| 131 | + | ||
| 132 | + asio::post(server_->GetWorkContext(), | ||
| 133 | + [this, hdl = c->hdl]() { server_->Send(hdl, "Done!"); }); | ||
| 134 | + | ||
| 135 | + to_remove.push_back(hdl); | ||
| 136 | + continue; | ||
| 137 | + } | ||
| 138 | + | ||
| 139 | + // TODO(fangun): If the connection is timed out, we need to also | ||
| 140 | + // add it to `to_remove` | ||
| 141 | + | ||
| 142 | + // this stream has enough frames and is currently not processed by any | ||
| 143 | + // threads, so put it into the ready queue | ||
| 144 | + ready_connections_.push_back(c); | ||
| 145 | + | ||
| 146 | + // In `Decode()`, it will remove hdl from `active_` | ||
| 147 | + active_.insert(c->hdl); | ||
| 148 | + } | ||
| 149 | + | ||
| 150 | + for (auto hdl : to_remove) { | ||
| 151 | + connections_.erase(hdl); | ||
| 152 | + } | ||
| 153 | + | ||
| 154 | + if (!ready_connections_.empty()) { | ||
| 155 | + asio::post(server_->GetWorkContext(), [this]() { Decode(); }); | ||
| 156 | + } | ||
| 157 | + | ||
| 158 | + // Schedule another call | ||
| 159 | + timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms)); | ||
| 160 | + | ||
| 161 | + timer_.async_wait( | ||
| 162 | + [this](const asio::error_code &ec) { ProcessConnections(ec); }); | ||
| 163 | +} | ||
| 164 | + | ||
| 165 | +void OnlineWebsocketDecoder::Decode() { | ||
| 166 | + std::unique_lock<std::mutex> lock(mutex_); | ||
| 167 | + if (ready_connections_.empty()) { | ||
| 168 | + // There are no connections that are ready for decoding, | ||
| 169 | + // so we return directly | ||
| 170 | + return; | ||
| 171 | + } | ||
| 172 | + | ||
| 173 | + std::vector<std::shared_ptr<Connection>> c_vec; | ||
| 174 | + std::vector<OnlineStream *> s_vec; | ||
| 175 | + while (!ready_connections_.empty() && | ||
| 176 | + static_cast<int32_t>(s_vec.size()) < config_.max_batch_size) { | ||
| 177 | + auto c = ready_connections_.front(); | ||
| 178 | + ready_connections_.pop_front(); | ||
| 179 | + | ||
| 180 | + c_vec.push_back(c); | ||
| 181 | + s_vec.push_back(c->s.get()); | ||
| 182 | + } | ||
| 183 | + | ||
| 184 | + if (!ready_connections_.empty()) { | ||
| 185 | + // there are too many ready connections but this thread can only handle | ||
| 186 | + // max_batch_size connections at a time, so we schedule another call | ||
| 187 | + // to Decode() and let other threads to process the ready connections | ||
| 188 | + asio::post(server_->GetWorkContext(), [this]() { Decode(); }); | ||
| 189 | + } | ||
| 190 | + | ||
| 191 | + lock.unlock(); | ||
| 192 | + recognizer_->DecodeStreams(s_vec.data(), s_vec.size()); | ||
| 193 | + lock.lock(); | ||
| 194 | + | ||
| 195 | + for (auto c : c_vec) { | ||
| 196 | + auto result = recognizer_->GetResult(c->s.get()); | ||
| 197 | + | ||
| 198 | + asio::post(server_->GetConnectionContext(), | ||
| 199 | + [this, hdl = c->hdl, str = result.ToString()]() { | ||
| 200 | + server_->Send(hdl, str); | ||
| 201 | + }); | ||
| 202 | + active_.erase(c->hdl); | ||
| 203 | + } | ||
| 204 | +} | ||
| 205 | + | ||
| 206 | +OnlineWebsocketServer::OnlineWebsocketServer( | ||
| 207 | + asio::io_context &io_conn, asio::io_context &io_work, | ||
| 208 | + const OnlineWebsocketServerConfig &config) | ||
| 209 | + : config_(config), | ||
| 210 | + io_conn_(io_conn), | ||
| 211 | + io_work_(io_work), | ||
| 212 | + log_(config.log_file, std::ios::app), | ||
| 213 | + tee_(std::cout, log_), | ||
| 214 | + decoder_(this) { | ||
| 215 | + SetupLog(); | ||
| 216 | + | ||
| 217 | + server_.init_asio(&io_conn_); | ||
| 218 | + | ||
| 219 | + server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); }); | ||
| 220 | + | ||
| 221 | + server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); }); | ||
| 222 | + | ||
| 223 | + server_.set_message_handler( | ||
| 224 | + [this](connection_hdl hdl, server::message_ptr msg) { | ||
| 225 | + OnMessage(hdl, msg); | ||
| 226 | + }); | ||
| 227 | +} | ||
| 228 | + | ||
| 229 | +void OnlineWebsocketServer::Run(uint16_t port) { | ||
| 230 | + server_.set_reuse_addr(true); | ||
| 231 | + server_.listen(asio::ip::tcp::v4(), port); | ||
| 232 | + server_.start_accept(); | ||
| 233 | + decoder_.Run(); | ||
| 234 | +} | ||
| 235 | + | ||
| 236 | +void OnlineWebsocketServer::SetupLog() { | ||
| 237 | + server_.clear_access_channels(websocketpp::log::alevel::all); | ||
| 238 | + // server_.set_access_channels(websocketpp::log::alevel::connect); | ||
| 239 | + // server_.set_access_channels(websocketpp::log::alevel::disconnect); | ||
| 240 | + | ||
| 241 | + // So that it also prints to std::cout and std::cerr | ||
| 242 | + server_.get_alog().set_ostream(&tee_); | ||
| 243 | + server_.get_elog().set_ostream(&tee_); | ||
| 244 | +} | ||
| 245 | + | ||
| 246 | +void OnlineWebsocketServer::Send(connection_hdl hdl, const std::string &text) { | ||
| 247 | + websocketpp::lib::error_code ec; | ||
| 248 | + if (!Contains(hdl)) { | ||
| 249 | + return; | ||
| 250 | + } | ||
| 251 | + | ||
| 252 | + server_.send(hdl, text, websocketpp::frame::opcode::text, ec); | ||
| 253 | + if (ec) { | ||
| 254 | + server_.get_alog().write(websocketpp::log::alevel::app, ec.message()); | ||
| 255 | + } | ||
| 256 | +} | ||
| 257 | + | ||
| 258 | +void OnlineWebsocketServer::OnOpen(connection_hdl hdl) { | ||
| 259 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 260 | + connections_.insert(hdl); | ||
| 261 | + | ||
| 262 | + std::ostringstream os; | ||
| 263 | + os << "New connection: " | ||
| 264 | + << server_.get_con_from_hdl(hdl)->get_remote_endpoint() << ". " | ||
| 265 | + << "Number of active connections: " << connections_.size() << ".\n"; | ||
| 266 | + SHERPA_ONNX_LOG(INFO) << os.str(); | ||
| 267 | +} | ||
| 268 | + | ||
| 269 | +void OnlineWebsocketServer::OnClose(connection_hdl hdl) { | ||
| 270 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 271 | + connections_.erase(hdl); | ||
| 272 | + | ||
| 273 | + SHERPA_ONNX_LOG(INFO) << "Number of active connections: " | ||
| 274 | + << connections_.size() << "\n"; | ||
| 275 | +} | ||
| 276 | + | ||
| 277 | +bool OnlineWebsocketServer::Contains(connection_hdl hdl) const { | ||
| 278 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 279 | + return connections_.count(hdl); | ||
| 280 | +} | ||
| 281 | + | ||
| 282 | +void OnlineWebsocketServer::OnMessage(connection_hdl hdl, | ||
| 283 | + server::message_ptr msg) { | ||
| 284 | + auto c = decoder_.GetOrCreateConnection(hdl); | ||
| 285 | + | ||
| 286 | + const std::string &payload = msg->get_payload(); | ||
| 287 | + | ||
| 288 | + switch (msg->get_opcode()) { | ||
| 289 | + case websocketpp::frame::opcode::text: | ||
| 290 | + if (payload == "Done") { | ||
| 291 | + asio::post(io_work_, [this, c]() { decoder_.InputFinished(c); }); | ||
| 292 | + } | ||
| 293 | + break; | ||
| 294 | + case websocketpp::frame::opcode::binary: { | ||
| 295 | + auto p = reinterpret_cast<const float *>(payload.data()); | ||
| 296 | + int32_t num_samples = payload.size() / sizeof(float); | ||
| 297 | + std::vector<float> samples(p, p + num_samples); | ||
| 298 | + | ||
| 299 | + c->samples.push_back(std::move(samples)); | ||
| 300 | + | ||
| 301 | + asio::post(io_work_, [this, c]() { decoder_.AcceptWaveform(c); }); | ||
| 302 | + break; | ||
| 303 | + } | ||
| 304 | + default: | ||
| 305 | + break; | ||
| 306 | + } | ||
| 307 | +} | ||
| 308 | + | ||
| 309 | +void OnlineWebsocketServer::Close(connection_hdl hdl, | ||
| 310 | + websocketpp::close::status::value code, | ||
| 311 | + const std::string &reason) { | ||
| 312 | + auto con = server_.get_con_from_hdl(hdl); | ||
| 313 | + | ||
| 314 | + std::ostringstream os; | ||
| 315 | + os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason | ||
| 316 | + << "\n"; | ||
| 317 | + | ||
| 318 | + websocketpp::lib::error_code ec; | ||
| 319 | + server_.close(hdl, code, reason, ec); | ||
| 320 | + if (ec) { | ||
| 321 | + os << "Failed to close" << con->get_remote_endpoint() << ". " | ||
| 322 | + << ec.message() << "\n"; | ||
| 323 | + } | ||
| 324 | + server_.get_alog().write(websocketpp::log::alevel::app, os.str()); | ||
| 325 | +} | ||
| 326 | + | ||
| 327 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/online-websocket-server-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <deque> | ||
| 9 | +#include <fstream> | ||
| 10 | +#include <map> | ||
| 11 | +#include <memory> | ||
| 12 | +#include <mutex> // NOLINT | ||
| 13 | +#include <set> | ||
| 14 | +#include <string> | ||
| 15 | +#include <unordered_set> | ||
| 16 | +#include <utility> | ||
| 17 | +#include <vector> | ||
| 18 | + | ||
| 19 | +#include "asio.hpp" | ||
| 20 | +#include "sherpa-onnx/csrc/online-recognizer.h" | ||
| 21 | +#include "sherpa-onnx/csrc/online-stream.h" | ||
| 22 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 23 | +#include "sherpa-onnx/csrc/tee-stream.h" | ||
| 24 | +#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS | ||
| 25 | +#include "websocketpp/server.hpp" | ||
| 26 | +using server = websocketpp::server<websocketpp::config::asio>; | ||
| 27 | +using connection_hdl = websocketpp::connection_hdl; | ||
| 28 | + | ||
| 29 | +namespace sherpa_onnx { | ||
| 30 | + | ||
| 31 | +struct Connection { | ||
| 32 | + // handle to the connection. We can use it to send messages to the client | ||
| 33 | + connection_hdl hdl; | ||
| 34 | + std::shared_ptr<OnlineStream> s; | ||
| 35 | + | ||
| 36 | + // set it to true when InputFinished() is called | ||
| 37 | + bool eof = false; | ||
| 38 | + | ||
| 39 | + // The last time we received a message from the client | ||
| 40 | + // TODO(fangjun): Use it to disconnect from a client if it is inactive | ||
| 41 | + // for a specified time. | ||
| 42 | + std::chrono::steady_clock::time_point last_active; | ||
| 43 | + | ||
| 44 | + std::mutex mutex; // protect samples | ||
| 45 | + | ||
| 46 | + // Audio samples received from the client. | ||
| 47 | + // | ||
| 48 | + // The I/O threads receive audio samples into this queue | ||
| 49 | + // and invoke work threads to compute features | ||
| 50 | + std::deque<std::vector<float>> samples; | ||
| 51 | + | ||
| 52 | + Connection() = default; | ||
| 53 | + Connection(connection_hdl hdl, std::shared_ptr<OnlineStream> s) | ||
| 54 | + : hdl(hdl), s(s), last_active(std::chrono::steady_clock::now()) {} | ||
| 55 | +}; | ||
| 56 | + | ||
| 57 | +struct OnlineWebsocketDecoderConfig { | ||
| 58 | + OnlineRecognizerConfig recognizer_config; | ||
| 59 | + | ||
| 60 | + // It determines how often the decoder loop runs. | ||
| 61 | + int32_t loop_interval_ms = 10; | ||
| 62 | + | ||
| 63 | + int32_t max_batch_size = 5; | ||
| 64 | + | ||
| 65 | + void Register(ParseOptions *po); | ||
| 66 | + void Validate() const; | ||
| 67 | +}; | ||
| 68 | + | ||
| 69 | +class OnlineWebsocketServer; | ||
| 70 | + | ||
| 71 | +class OnlineWebsocketDecoder { | ||
| 72 | + public: | ||
| 73 | + /** | ||
| 74 | + * @param server Not owned. | ||
| 75 | + */ | ||
| 76 | + explicit OnlineWebsocketDecoder(OnlineWebsocketServer *server); | ||
| 77 | + | ||
| 78 | + std::shared_ptr<Connection> GetOrCreateConnection(connection_hdl hdl); | ||
| 79 | + | ||
| 80 | + // Compute features for a stream given audio samples | ||
| 81 | + void AcceptWaveform(std::shared_ptr<Connection> c); | ||
| 82 | + | ||
| 83 | + // signal that there will be no more audio samples for a stream | ||
| 84 | + void InputFinished(std::shared_ptr<Connection> c); | ||
| 85 | + | ||
| 86 | + void Run(); | ||
| 87 | + | ||
| 88 | + private: | ||
| 89 | + void ProcessConnections(const asio::error_code &ec); | ||
| 90 | + | ||
| 91 | + /** It is called by one of the worker thread. | ||
| 92 | + */ | ||
| 93 | + void Decode(); | ||
| 94 | + | ||
| 95 | + private: | ||
| 96 | + OnlineWebsocketServer *server_; // not owned | ||
| 97 | + std::unique_ptr<OnlineRecognizer> recognizer_; | ||
| 98 | + OnlineWebsocketDecoderConfig config_; | ||
| 99 | + asio::steady_timer timer_; | ||
| 100 | + | ||
| 101 | + // It protects `connections_`, `ready_connections_`, and `active_` | ||
| 102 | + std::mutex mutex_; | ||
| 103 | + | ||
| 104 | + std::map<connection_hdl, std::shared_ptr<Connection>, | ||
| 105 | + std::owner_less<connection_hdl>> | ||
| 106 | + connections_; | ||
| 107 | + | ||
| 108 | + // Whenever a connection has enough feature frames for decoding, we put | ||
| 109 | + // it in this queue | ||
| 110 | + std::deque<std::shared_ptr<Connection>> ready_connections_; | ||
| 111 | + | ||
| 112 | + // If we are decoding a stream, we put it in the active_ set so that | ||
| 113 | + // only one thread can decode a stream at a time. | ||
| 114 | + std::set<connection_hdl, std::owner_less<connection_hdl>> active_; | ||
| 115 | +}; | ||
| 116 | + | ||
| 117 | +struct OnlineWebsocketServerConfig { | ||
| 118 | + OnlineWebsocketDecoderConfig decoder_config; | ||
| 119 | + | ||
| 120 | + std::string log_file = "./log.txt"; | ||
| 121 | + | ||
| 122 | + void Register(sherpa_onnx::ParseOptions *po); | ||
| 123 | + void Validate() const; | ||
| 124 | +}; | ||
| 125 | + | ||
| 126 | +class OnlineWebsocketServer { | ||
| 127 | + public: | ||
| 128 | + explicit OnlineWebsocketServer(asio::io_context &io_conn, // NOLINT | ||
| 129 | + asio::io_context &io_work, // NOLINT | ||
| 130 | + const OnlineWebsocketServerConfig &config); | ||
| 131 | + | ||
| 132 | + void Run(uint16_t port); | ||
| 133 | + | ||
| 134 | + const OnlineWebsocketServerConfig &GetConfig() const { return config_; } | ||
| 135 | + asio::io_context &GetConnectionContext() { return io_conn_; } | ||
| 136 | + asio::io_context &GetWorkContext() { return io_work_; } | ||
| 137 | + server &GetServer() { return server_; } | ||
| 138 | + | ||
| 139 | + void Send(connection_hdl hdl, const std::string &text); | ||
| 140 | + | ||
| 141 | + bool Contains(connection_hdl hdl) const; | ||
| 142 | + | ||
| 143 | + private: | ||
| 144 | + void SetupLog(); | ||
| 145 | + | ||
| 146 | + // When a websocket client is connected, it will invoke this method | ||
| 147 | + // (Not for HTTP) | ||
| 148 | + void OnOpen(connection_hdl hdl); | ||
| 149 | + | ||
| 150 | + // When a websocket client is disconnected, it will invoke this method | ||
| 151 | + void OnClose(connection_hdl hdl); | ||
| 152 | + | ||
| 153 | + void OnMessage(connection_hdl hdl, server::message_ptr msg); | ||
| 154 | + | ||
| 155 | + // Close a websocket connection with given code and reason | ||
| 156 | + void Close(connection_hdl hdl, websocketpp::close::status::value code, | ||
| 157 | + const std::string &reason); | ||
| 158 | + | ||
| 159 | + private: | ||
| 160 | + OnlineWebsocketServerConfig config_; | ||
| 161 | + asio::io_context &io_conn_; | ||
| 162 | + asio::io_context &io_work_; | ||
| 163 | + server server_; | ||
| 164 | + | ||
| 165 | + std::ofstream log_; | ||
| 166 | + sherpa_onnx::TeeStream tee_; | ||
| 167 | + | ||
| 168 | + OnlineWebsocketDecoder decoder_; | ||
| 169 | + | ||
| 170 | + mutable std::mutex mutex_; | ||
| 171 | + | ||
| 172 | + std::set<connection_hdl, std::owner_less<connection_hdl>> connections_; | ||
| 173 | +}; | ||
| 174 | + | ||
| 175 | +} // namespace sherpa_onnx | ||
| 176 | + | ||
| 177 | +#endif // SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_ |
sherpa-onnx/csrc/online-websocket-server.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-websocket-server.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "asio.hpp" | ||
| 6 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 7 | +#include "sherpa-onnx/csrc/online-websocket-server-impl.h" | ||
| 8 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 9 | + | ||
| 10 | +static constexpr const char *kUsageMessage = R"( | ||
| 11 | +Automatic speech recognition with sherpa-onnx using websocket. | ||
| 12 | + | ||
| 13 | +Usage: | ||
| 14 | + | ||
| 15 | +./bin/sherpa-onnx-online-websocket-server --help | ||
| 16 | + | ||
| 17 | +./bin/sherpa-onnx-online-websocket-server \ | ||
| 18 | + --port=6006 \ | ||
| 19 | + --num-work-threads=5 \ | ||
| 20 | + --tokens=/path/to/tokens.txt \ | ||
| 21 | + --encoder=/path/to/encoder.onnx \ | ||
| 22 | + --decoder=/path/to/decoder.onnx \ | ||
| 23 | + --joiner=/path/to/joiner.onnx \ | ||
| 24 | + --log-file=./log.txt \ | ||
| 25 | + --max-batch-size=5 \ | ||
| 26 | + --loop-interval-ms=10 | ||
| 27 | + | ||
| 28 | +Please refer to | ||
| 29 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
| 30 | +for a list of pre-trained models to download. | ||
| 31 | +)"; | ||
| 32 | + | ||
| 33 | +int32_t main(int32_t argc, char *argv[]) { | ||
| 34 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 35 | + | ||
| 36 | + sherpa_onnx::OnlineWebsocketServerConfig config; | ||
| 37 | + | ||
| 38 | + // the server will listen on this port | ||
| 39 | + int32_t port = 6006; | ||
| 40 | + | ||
| 41 | + // size of the thread pool for handling network connections | ||
| 42 | + int32_t num_io_threads = 1; | ||
| 43 | + | ||
| 44 | + // size of the thread pool for neural network computation and decoding | ||
| 45 | + int32_t num_work_threads = 3; | ||
| 46 | + | ||
| 47 | + po.Register("num-io-threads", &num_io_threads, | ||
| 48 | + "Thread pool size for network connections."); | ||
| 49 | + | ||
| 50 | + po.Register("num-work-threads", &num_work_threads, | ||
| 51 | + "Thread pool size for for neural network " | ||
| 52 | + "computation and decoding."); | ||
| 53 | + | ||
| 54 | + po.Register("port", &port, "The port on which the server will listen."); | ||
| 55 | + | ||
| 56 | + config.Register(&po); | ||
| 57 | + | ||
| 58 | + if (argc == 1) { | ||
| 59 | + po.PrintUsage(); | ||
| 60 | + exit(EXIT_FAILURE); | ||
| 61 | + } | ||
| 62 | + | ||
| 63 | + po.Read(argc, argv); | ||
| 64 | + | ||
| 65 | + if (po.NumArgs() != 0) { | ||
| 66 | + SHERPA_ONNX_LOGE("Unrecognized positional arguments!"); | ||
| 67 | + po.PrintUsage(); | ||
| 68 | + exit(EXIT_FAILURE); | ||
| 69 | + } | ||
| 70 | + | ||
| 71 | + config.Validate(); | ||
| 72 | + | ||
| 73 | + asio::io_context io_conn; // for network connections | ||
| 74 | + asio::io_context io_work; // for neural network and decoding | ||
| 75 | + | ||
| 76 | + sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config); | ||
| 77 | + server.Run(port); | ||
| 78 | + | ||
| 79 | + SHERPA_ONNX_LOGE("Listening on: %d", port); | ||
| 80 | + SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads); | ||
| 81 | + | ||
| 82 | + // give some work to do for the io_work pool | ||
| 83 | + auto work_guard = asio::make_work_guard(io_work); | ||
| 84 | + | ||
| 85 | + std::vector<std::thread> io_threads; | ||
| 86 | + | ||
| 87 | + // decrement since the main thread is also used for network communications | ||
| 88 | + for (int32_t i = 0; i < num_io_threads - 1; ++i) { | ||
| 89 | + io_threads.emplace_back([&io_conn]() { io_conn.run(); }); | ||
| 90 | + } | ||
| 91 | + | ||
| 92 | + std::vector<std::thread> work_threads; | ||
| 93 | + for (int32_t i = 0; i < num_work_threads; ++i) { | ||
| 94 | + work_threads.emplace_back([&io_work]() { io_work.run(); }); | ||
| 95 | + } | ||
| 96 | + | ||
| 97 | + io_conn.run(); | ||
| 98 | + | ||
| 99 | + for (auto &t : io_threads) { | ||
| 100 | + t.join(); | ||
| 101 | + } | ||
| 102 | + | ||
| 103 | + for (auto &t : work_threads) { | ||
| 104 | + t.join(); | ||
| 105 | + } | ||
| 106 | + | ||
| 107 | + return 0; | ||
| 108 | +} |
sherpa-onnx/csrc/tee-stream.h
0 → 100644
| 1 | +// Code in this file is copied and modified from | ||
| 2 | +// https://wordaligned.org/articles/cpp-streambufs | ||
| 3 | + | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_TEE_STREAM_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_TEE_STREAM_H_ | ||
| 6 | +#include <ostream> | ||
| 7 | +#include <streambuf> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +template <typename char_type, typename traits = std::char_traits<char_type>> | ||
| 13 | +class basic_teebuf : public std::basic_streambuf<char_type, traits> { | ||
| 14 | + public: | ||
| 15 | + using int_type = typename traits::int_type; | ||
| 16 | + | ||
| 17 | + basic_teebuf(std::basic_streambuf<char_type, traits> *sb1, | ||
| 18 | + std::basic_streambuf<char_type, traits> *sb2) | ||
| 19 | + : sb1(sb1), sb2(sb2) {} | ||
| 20 | + | ||
| 21 | + private: | ||
| 22 | + int sync() override { | ||
| 23 | + int const r1 = sb1->pubsync(); | ||
| 24 | + int const r2 = sb2->pubsync(); | ||
| 25 | + return r1 == 0 && r2 == 0 ? 0 : -1; | ||
| 26 | + } | ||
| 27 | + | ||
| 28 | + int_type overflow(int_type c) override { | ||
| 29 | + int_type const eof = traits::eof(); | ||
| 30 | + | ||
| 31 | + if (traits::eq_int_type(c, eof)) { | ||
| 32 | + return traits::not_eof(c); | ||
| 33 | + } else { | ||
| 34 | + char_type const ch = traits::to_char_type(c); | ||
| 35 | + int_type const r1 = sb1->sputc(ch); | ||
| 36 | + int_type const r2 = sb2->sputc(ch); | ||
| 37 | + | ||
| 38 | + return traits::eq_int_type(r1, eof) || traits::eq_int_type(r2, eof) ? eof | ||
| 39 | + : c; | ||
| 40 | + } | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | + private: | ||
| 44 | + std::basic_streambuf<char_type, traits> *sb1; | ||
| 45 | + std::basic_streambuf<char_type, traits> *sb2; | ||
| 46 | +}; | ||
| 47 | + | ||
| 48 | +using teebuf = basic_teebuf<char>; | ||
| 49 | + | ||
| 50 | +class TeeStream : public std::ostream { | ||
| 51 | + public: | ||
| 52 | + TeeStream(std::ostream &o1, std::ostream &o2) | ||
| 53 | + : std::ostream(&tbuf), tbuf(o1.rdbuf(), o2.rdbuf()) {} | ||
| 54 | + | ||
| 55 | + private: | ||
| 56 | + teebuf tbuf; | ||
| 57 | +}; | ||
| 58 | + | ||
| 59 | +} // namespace sherpa_onnx | ||
| 60 | + | ||
| 61 | +#endif // SHERPA_ONNX_CSRC_TEE_STREAM_H_ |
-
请 注册 或 登录 后发表评论