Fangjun Kuang
Committed by GitHub

add streaming websocket server and client (#62)

@@ -18,6 +18,7 @@ option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) @@ -18,6 +18,7 @@ option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
18 option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) 18 option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
19 option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) 19 option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
20 option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) 20 option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
  21 +option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build websocket server/client" ON)
21 22
22 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") 23 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
23 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") 24 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
@@ -59,6 +60,8 @@ message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}") @@ -59,6 +60,8 @@ message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}")
59 message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}") 60 message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}")
60 message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}") 61 message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}")
61 message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}") 62 message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}")
  63 +message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}")
  64 +message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}")
62 65
63 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") 66 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
64 set(CMAKE_CXX_EXTENSIONS OFF) 67 set(CMAKE_CXX_EXTENSIONS OFF)
@@ -91,6 +94,11 @@ if(SHERPA_ONNX_ENABLE_TESTS) @@ -91,6 +94,11 @@ if(SHERPA_ONNX_ENABLE_TESTS)
91 include(googletest) 94 include(googletest)
92 endif() 95 endif()
93 96
  97 +if(SHERPA_ONNX_ENABLE_WEBSOCKET)
  98 + include(websocketpp)
  99 + include(asio)
  100 +endif()
  101 +
94 add_subdirectory(sherpa-onnx) 102 add_subdirectory(sherpa-onnx)
95 103
96 if(SHERPA_ONNX_ENABLE_C_API) 104 if(SHERPA_ONNX_ENABLE_C_API)
@@ -40,6 +40,10 @@ cmake \ @@ -40,6 +40,10 @@ cmake \
40 -DSHERPA_ONNX_ENABLE_TESTS=OFF \ 40 -DSHERPA_ONNX_ENABLE_TESTS=OFF \
41 -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ 41 -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
42 -DSHERPA_ONNX_ENABLE_CHECK=OFF \ 42 -DSHERPA_ONNX_ENABLE_CHECK=OFF \
  43 + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
  44 + -DSHERPA_ONNX_ENABLE_JNI=OFF \
  45 + -DSHERPA_ONNX_ENABLE_C_API=OFF \
  46 + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \
43 -DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu.toolchain.cmake \ 47 -DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu.toolchain.cmake \
44 .. 48 ..
45 49
@@ -76,6 +76,8 @@ cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" @@ -76,6 +76,8 @@ cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake"
76 -DSHERPA_ONNX_ENABLE_JNI=ON \ 76 -DSHERPA_ONNX_ENABLE_JNI=ON \
77 -DCMAKE_INSTALL_PREFIX=./install \ 77 -DCMAKE_INSTALL_PREFIX=./install \
78 -DANDROID_ABI="x86_64" \ 78 -DANDROID_ABI="x86_64" \
  79 + -DSHERPA_ONNX_ENABLE_C_API=OFF \
  80 + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \
79 -DANDROID_PLATFORM=android-21 .. 81 -DANDROID_PLATFORM=android-21 ..
80 82
81 # make VERBOSE=1 -j4 83 # make VERBOSE=1 -j4
# Fetch asio (tag asio-1-24-0) with FetchContent and add its headers to the
# include path of the current directory scope.
#
# If you don't have access to the Internet, pre-download the archive into one
# of the directories probed below; the first copy found replaces the GitHub
# URL. The archive is verified against the pinned SHA256 in both cases.
function(download_asio)
  include(FetchContent)

  set(archive_name "asio-asio-1-24-0.tar.gz")
  set(archive_url "https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz")
  set(archive_hash "SHA256=cbcaaba0f66722787b1a7c33afe1befb3a012b5af3ad7da7ff0f6b8c9b7a8a5b")

  # Directories that may hold a pre-downloaded archive, in priority order.
  foreach(dir
      "$ENV{HOME}/Downloads"
      "${PROJECT_SOURCE_DIR}"
      "${PROJECT_BINARY_DIR}"
      "/tmp"
      "/star-fj/fangjun/download/github"
  )
    if(EXISTS "${dir}/${archive_name}")
      set(archive_url "file://${dir}/${archive_name}")
      break()
    endif()
  endforeach()

  FetchContent_Declare(asio
    URL      ${archive_url}
    URL_HASH ${archive_hash}
  )

  FetchContent_GetProperties(asio)
  if(NOT asio_POPULATED)
    message(STATUS "Downloading asio ${archive_url}")
    FetchContent_Populate(asio)
  endif()
  message(STATUS "asio is downloaded to ${asio_SOURCE_DIR}")

  # Only the headers are consumed; the fetched tree is deliberately not added
  # with add_subdirectory().
  include_directories(${asio_SOURCE_DIR}/asio/include)
endfunction()

download_asio()
# Fetch websocketpp at a pinned commit with FetchContent and add its root
# directory to the include path of the current directory scope.
#
# If you don't have access to the Internet, pre-download the archive into one
# of the directories probed below; the first copy found replaces the GitHub
# URL. The archive is verified against the pinned SHA256 in both cases.
function(download_websocketpp)
  include(FetchContent)

  # The latest commit on the develop branch as of 2022-10-22
  set(commit "b9aeec6eaf3d5610503439b4fae3581d9aff08e8")
  set(archive_name "websocketpp-${commit}.zip")
  set(archive_url "https://github.com/zaphoyd/websocketpp/archive/${commit}.zip")
  set(archive_hash "SHA256=1385135ede8191a7fbef9ec8099e3c5a673d48df0c143958216cd1690567f583")

  # Directories that may hold a pre-downloaded archive, in priority order.
  foreach(dir
      "$ENV{HOME}/Downloads"
      "${PROJECT_SOURCE_DIR}"
      "${PROJECT_BINARY_DIR}"
      "/tmp"
      "/star-fj/fangjun/download/github"
  )
    if(EXISTS "${dir}/${archive_name}")
      set(archive_url "file://${dir}/${archive_name}")
      break()
    endif()
  endforeach()

  FetchContent_Declare(websocketpp
    URL      ${archive_url}
    URL_HASH ${archive_hash}
  )

  FetchContent_GetProperties(websocketpp)
  if(NOT websocketpp_POPULATED)
    message(STATUS "Downloading websocketpp from ${archive_url}")
    FetchContent_Populate(websocketpp)
  endif()
  message(STATUS "websocketpp is downloaded to ${websocketpp_SOURCE_DIR}")

  # Only the headers are consumed; the fetched tree is deliberately not added
  # with add_subdirectory().
  include_directories(${websocketpp_SOURCE_DIR})
endfunction()

download_websocketpp()
@@ -4,6 +4,7 @@ set(sources @@ -4,6 +4,7 @@ set(sources
4 cat.cc 4 cat.cc
5 endpoint.cc 5 endpoint.cc
6 features.cc 6 features.cc
  7 + file-utils.cc
7 online-lstm-transducer-model.cc 8 online-lstm-transducer-model.cc
8 online-recognizer.cc 9 online-recognizer.cc
9 online-stream.cc 10 online-stream.cc
@@ -86,6 +87,32 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO) @@ -86,6 +87,32 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO)
86 install(TARGETS sherpa-onnx-microphone DESTINATION bin) 87 install(TARGETS sherpa-onnx-microphone DESTINATION bin)
87 endif() 88 endif()
88 89
# Streaming websocket server and client.
#
# All compile definitions, warning flags and link dependencies are attached to
# the two websocket executables only, so that ASIO_STANDALONE and
# _WEBSOCKETPP_CPP11_STL_ do not leak into the other targets defined in this
# directory.
if(SHERPA_ONNX_ENABLE_WEBSOCKET)
  add_executable(sherpa-onnx-online-websocket-server
    online-websocket-server-impl.cc
    online-websocket-server.cc
  )

  add_executable(sherpa-onnx-online-websocket-client
    online-websocket-client.cc
  )

  if(NOT WIN32)
    # Ask FindThreads for -pthread (compile and link) rather than a bare
    # -lpthread; this matches the flag that was previously hard-coded.
    set(THREADS_PREFER_PTHREAD_FLAG ON)
    find_package(Threads REQUIRED)
  endif()

  foreach(websocket_target IN ITEMS
      sherpa-onnx-online-websocket-server
      sherpa-onnx-online-websocket-client
  )
    target_compile_definitions(${websocket_target}
      PRIVATE
        ASIO_STANDALONE
        _WEBSOCKETPP_CPP11_STL_
    )
    target_link_libraries(${websocket_target} PRIVATE sherpa-onnx-core)

    if(NOT WIN32)
      target_link_libraries(${websocket_target} PRIVATE Threads::Threads)
      # Silence deprecation warnings emitted when compiling these targets.
      target_compile_options(${websocket_target} PRIVATE -Wno-deprecated-declarations)
    endif()
  endforeach()
endif()

  115 +
89 116
90 if(SHERPA_ONNX_ENABLE_TESTS) 117 if(SHERPA_ONNX_ENABLE_TESTS)
91 set(sherpa_onnx_test_srcs 118 set(sherpa_onnx_test_srcs
  1 +exclude_files=tee-stream.h
@@ -14,6 +14,15 @@ @@ -14,6 +14,15 @@
14 14
15 namespace sherpa_onnx { 15 namespace sherpa_onnx {
16 16
  17 +void FeatureExtractorConfig::Register(ParseOptions *po) {
  18 + po->Register("sample-rate", &sampling_rate,
  19 + "Sampling rate of the input waveform. Must match the one "
  20 + "expected by the model.");
  21 +
  22 + po->Register("feat-dim", &feature_dim,
  23 + "Feature dimension. Must match the one expected by the model.");
  24 +}
  25 +
17 std::string FeatureExtractorConfig::ToString() const { 26 std::string FeatureExtractorConfig::ToString() const {
18 std::ostringstream os; 27 std::ostringstream os;
19 28
@@ -9,6 +9,8 @@ @@ -9,6 +9,8 @@
9 #include <string> 9 #include <string>
10 #include <vector> 10 #include <vector>
11 11
  12 +#include "sherpa-onnx/csrc/parse-options.h"
  13 +
12 namespace sherpa_onnx { 14 namespace sherpa_onnx {
13 15
14 struct FeatureExtractorConfig { 16 struct FeatureExtractorConfig {
@@ -16,6 +18,8 @@ struct FeatureExtractorConfig { @@ -16,6 +18,8 @@ struct FeatureExtractorConfig {
16 int32_t feature_dim = 80; 18 int32_t feature_dim = 80;
17 19
18 std::string ToString() const; 20 std::string ToString() const;
  21 +
  22 + void Register(ParseOptions *po);
19 }; 23 };
20 24
21 class FeatureExtractor { 25 class FeatureExtractor {
  1 +// sherpa-onnx/csrc/file-utils.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/file-utils.h"
  6 +
  7 +#include <fstream>
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/log.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
// Returns true if `filename` can be opened for reading; false otherwise.
bool FileExists(const std::string &filename) {
  std::ifstream stream(filename);
  return stream.good();
}
  17 +
  18 +void AssertFileExists(const std::string &filename) {
  19 + if (!FileExists(filename)) {
  20 + SHERPA_ONNX_LOG(FATAL) << filename << " does not exist!";
  21 + }
  22 +}
  23 +
  24 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/file-utils.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_FILE_UTILS_H_
  6 +#define SHERPA_ONNX_CSRC_FILE_UTILS_H_
  7 +
  8 +#include <fstream>
  9 +#include <string>
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +/** Check whether a given file exists and can be opened for reading.
  14 + *
  15 + * @param filename Path to check.
  16 + * @return Return true if the file can be opened for reading; return false otherwise.
  17 + */
  18 +bool FileExists(const std::string &filename);
  19 +
  20 +/** Abort if the file does not exist.
  21 + *
  22 + * @param filename The file to check.
  23 + */
  24 +void AssertFileExists(const std::string &filename);
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_FILE_UTILS_H_
@@ -12,6 +12,7 @@ @@ -12,6 +12,7 @@
12 #include <utility> 12 #include <utility>
13 #include <vector> 13 #include <vector>
14 14
  15 +#include "sherpa-onnx/csrc/file-utils.h"
15 #include "sherpa-onnx/csrc/online-transducer-decoder.h" 16 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
16 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" 17 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
17 #include "sherpa-onnx/csrc/online-transducer-model.h" 18 #include "sherpa-onnx/csrc/online-transducer-model.h"
@@ -31,6 +32,19 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, @@ -31,6 +32,19 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
31 return ans; 32 return ans;
32 } 33 }
33 34
  35 +void OnlineRecognizerConfig::Register(ParseOptions *po) {
  36 + feat_config.Register(po);
  37 + model_config.Register(po);
  38 + endpoint_config.Register(po);
  39 +
  40 + po->Register("enable-endpoint", &enable_endpoint,
  41 + "True to enable endpoint detection. False to disable it.");
  42 +}
  43 +
  44 +bool OnlineRecognizerConfig::Validate() const {
  45 + return model_config.Validate();
  46 +}
  47 +
34 std::string OnlineRecognizerConfig::ToString() const { 48 std::string OnlineRecognizerConfig::ToString() const {
35 std::ostringstream os; 49 std::ostringstream os;
36 50
@@ -17,11 +17,15 @@ @@ -17,11 +17,15 @@
17 #include "sherpa-onnx/csrc/features.h" 17 #include "sherpa-onnx/csrc/features.h"
18 #include "sherpa-onnx/csrc/online-stream.h" 18 #include "sherpa-onnx/csrc/online-stream.h"
19 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 19 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
  20 +#include "sherpa-onnx/csrc/parse-options.h"
20 21
21 namespace sherpa_onnx { 22 namespace sherpa_onnx {
22 23
23 struct OnlineRecognizerResult { 24 struct OnlineRecognizerResult {
24 std::string text; 25 std::string text;
  26 +
  27 + // TODO(fangjun): Add a method to return a json string
  28 + std::string ToString() const { return text; }
25 }; 29 };
26 30
27 struct OnlineRecognizerConfig { 31 struct OnlineRecognizerConfig {
@@ -41,6 +45,9 @@ struct OnlineRecognizerConfig { @@ -41,6 +45,9 @@ struct OnlineRecognizerConfig {
41 endpoint_config(endpoint_config), 45 endpoint_config(endpoint_config),
42 enable_endpoint(enable_endpoint) {} 46 enable_endpoint(enable_endpoint) {}
43 47
  48 + void Register(ParseOptions *po);
  49 + bool Validate() const;
  50 +
44 std::string ToString() const; 51 std::string ToString() const;
45 }; 52 };
46 53
@@ -5,8 +5,52 @@ @@ -5,8 +5,52 @@
5 5
6 #include <sstream> 6 #include <sstream>
7 7
  8 +#include "sherpa-onnx/csrc/file-utils.h"
  9 +#include "sherpa-onnx/csrc/macros.h"
  10 +
8 namespace sherpa_onnx { 11 namespace sherpa_onnx {
9 12
  13 +void OnlineTransducerModelConfig::Register(ParseOptions *po) {
  14 + po->Register("encoder", &encoder_filename, "Path to encoder.onnx");
  15 + po->Register("decoder", &decoder_filename, "Path to decoder.onnx");
  16 + po->Register("joiner", &joiner_filename, "Path to joiner.onnx");
  17 + po->Register("tokens", &tokens, "Path to tokens.txt");
  18 + po->Register("num_threads", &num_threads,
  19 + "Number of threads to run the neural network");
  20 +
  21 + po->Register("debug", &debug,
  22 + "true to print model information while loading it.");
  23 +}
  24 +
  25 +bool OnlineTransducerModelConfig::Validate() const {
  26 + if (!FileExists(tokens)) {
  27 + SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
  28 + return false;
  29 + }
  30 +
  31 + if (!FileExists(encoder_filename)) {
  32 + SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
  33 + return false;
  34 + }
  35 +
  36 + if (!FileExists(decoder_filename)) {
  37 + SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str());
  38 + return false;
  39 + }
  40 +
  41 + if (!FileExists(joiner_filename)) {
  42 + SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str());
  43 + return false;
  44 + }
  45 +
  46 + if (num_threads < 1) {
  47 + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
  48 + return false;
  49 + }
  50 +
  51 + return true;
  52 +}
  53 +
10 std::string OnlineTransducerModelConfig::ToString() const { 54 std::string OnlineTransducerModelConfig::ToString() const {
11 std::ostringstream os; 55 std::ostringstream os;
12 56
@@ -6,6 +6,8 @@ @@ -6,6 +6,8 @@
6 6
7 #include <string> 7 #include <string>
8 8
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
9 namespace sherpa_onnx { 11 namespace sherpa_onnx {
10 12
11 struct OnlineTransducerModelConfig { 13 struct OnlineTransducerModelConfig {
@@ -13,7 +15,7 @@ struct OnlineTransducerModelConfig { @@ -13,7 +15,7 @@ struct OnlineTransducerModelConfig {
13 std::string decoder_filename; 15 std::string decoder_filename;
14 std::string joiner_filename; 16 std::string joiner_filename;
15 std::string tokens; 17 std::string tokens;
16 - int32_t num_threads; 18 + int32_t num_threads = 2;
17 bool debug = false; 19 bool debug = false;
18 20
19 OnlineTransducerModelConfig() = default; 21 OnlineTransducerModelConfig() = default;
@@ -29,6 +31,9 @@ struct OnlineTransducerModelConfig { @@ -29,6 +31,9 @@ struct OnlineTransducerModelConfig {
29 num_threads(num_threads), 31 num_threads(num_threads),
30 debug(debug) {} 32 debug(debug) {}
31 33
  34 + void Register(ParseOptions *po);
  35 + bool Validate() const;
  36 +
32 std::string ToString() const; 37 std::string ToString() const;
33 }; 38 };
34 39
  1 +// sherpa/cpp_api/websocket/online-websocket-client.cc
  2 +//
  3 +// Copyright (c) 2022 Xiaomi Corporation
#include <chrono>  // NOLINT
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <string>
#include <thread>  // NOLINT
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "websocketpp/client.hpp"
#include "websocketpp/config/asio_no_tls_client.hpp"
#include "websocketpp/uri.hpp"
  14 +
  15 +using client = websocketpp::client<websocketpp::config::asio_client>;
  16 +
  17 +using message_ptr = client::message_ptr;
  18 +using websocketpp::connection_hdl;
  19 +
  20 +static constexpr const char *kUsageMessage = R"(
  21 +Automatic speech recognition with sherpa-onnx using websocket.
  22 +
  23 +Usage:
  24 +
  25 +./bin/sherpa-onnx-online-websocket-client --help
  26 +
  27 +./bin/sherpa-onnx-online-websocket-client \
  28 + --server-ip=127.0.0.1 \
  29 + --server-port=6006 \
  30 + --samples-per-message=8000 \
  31 + --seconds-per-message=0.2 \
  32 + /path/to/foo.wav
  33 +
  34 +It support only wave of with a single channel, 16kHz, 16-bit samples.
  35 +)";
  36 +
  37 +class Client {
  38 + public:
  39 + Client(asio::io_context &io, // NOLINT
  40 + const std::string &ip, int16_t port, const std::vector<float> &samples,
  41 + int32_t samples_per_message, float seconds_per_message)
  42 + : io_(io),
  43 + uri_(/*secure*/ false, ip, port, /*resource*/ "/"),
  44 + samples_(samples),
  45 + samples_per_message_(samples_per_message),
  46 + seconds_per_message_(seconds_per_message) {
  47 + c_.clear_access_channels(websocketpp::log::alevel::all);
  48 + // c_.set_access_channels(websocketpp::log::alevel::connect);
  49 + // c_.set_access_channels(websocketpp::log::alevel::disconnect);
  50 +
  51 + c_.init_asio(&io_);
  52 + c_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); });
  53 + c_.set_close_handler(
  54 + [this](connection_hdl /*hdl*/) { SHERPA_ONNX_LOGE("Disconnected"); });
  55 + c_.set_message_handler(
  56 + [this](connection_hdl hdl, message_ptr msg) { OnMessage(hdl, msg); });
  57 +
  58 + Run();
  59 + }
  60 +
  61 + private:
  62 + void Run() {
  63 + websocketpp::lib::error_code ec;
  64 + client::connection_ptr con = c_.get_connection(uri_.str(), ec);
  65 + if (ec) {
  66 + SHERPA_ONNX_LOGE("Could not create connection to %s because %s",
  67 + uri_.str().c_str(), ec.message().c_str());
  68 + exit(EXIT_FAILURE);
  69 + }
  70 +
  71 + c_.connect(con);
  72 + }
  73 +
  74 + void OnOpen(connection_hdl hdl) {
  75 + auto start_time = std::chrono::steady_clock::now();
  76 + asio::post(
  77 + io_, [this, hdl, start_time]() { this->SendMessage(hdl, start_time); });
  78 + }
  79 +
  80 + void OnMessage(connection_hdl hdl, message_ptr msg) {
  81 + const std::string &payload = msg->get_payload();
  82 +
  83 + if (payload == "Done!") {
  84 + websocketpp::lib::error_code ec;
  85 + c_.close(hdl, websocketpp::close::status::normal, "I'm exiting now", ec);
  86 + if (ec) {
  87 + SHERPA_ONNX_LOGE("Failed to close because %s", ec.message().c_str());
  88 + exit(EXIT_FAILURE);
  89 + }
  90 + } else {
  91 + SHERPA_ONNX_LOGE("%s", payload.c_str());
  92 + }
  93 + }
  94 +
  95 + void SendMessage(
  96 + connection_hdl hdl,
  97 + std::chrono::time_point<std::chrono::steady_clock> start_time) {
  98 + int32_t num_samples = samples_.size();
  99 + int32_t num_messages = num_samples / samples_per_message_;
  100 +
  101 + websocketpp::lib::error_code ec;
  102 + auto time = std::chrono::steady_clock::now();
  103 + int elapsed_time_ms =
  104 + std::chrono::duration_cast<std::chrono::milliseconds>(time - start_time)
  105 + .count();
  106 +
  107 + if (elapsed_time_ms <
  108 + static_cast<int>(seconds_per_message_ * num_sent_messages_ * 1000)) {
  109 + std::this_thread::sleep_for(std::chrono::milliseconds(int(
  110 + seconds_per_message_ * num_sent_messages_ * 1000 - elapsed_time_ms)));
  111 + }
  112 +
  113 + if (num_sent_messages_ < 1) {
  114 + SHERPA_ONNX_LOGE("Starting to send audio");
  115 + }
  116 +
  117 + if (num_sent_messages_ < num_messages) {
  118 + c_.send(hdl, samples_.data() + num_sent_messages_ * samples_per_message_,
  119 + samples_per_message_ * sizeof(float),
  120 + websocketpp::frame::opcode::binary, ec);
  121 +
  122 + if (ec) {
  123 + SHERPA_ONNX_LOGE("Failed to send audio samples because %s",
  124 + ec.message().c_str());
  125 + exit(EXIT_FAILURE);
  126 + }
  127 +
  128 + ec.clear();
  129 +
  130 + ++num_sent_messages_;
  131 + }
  132 +
  133 + if (num_sent_messages_ == num_messages) {
  134 + int32_t remaining_samples = num_samples % samples_per_message_;
  135 + if (remaining_samples) {
  136 + c_.send(hdl,
  137 + samples_.data() + num_sent_messages_ * samples_per_message_,
  138 + remaining_samples * sizeof(float),
  139 + websocketpp::frame::opcode::binary, ec);
  140 +
  141 + if (ec) {
  142 + SHERPA_ONNX_LOGE("Failed to send audio samples because %s",
  143 + ec.message().c_str());
  144 + exit(EXIT_FAILURE);
  145 + }
  146 + ec.clear();
  147 + }
  148 +
  149 + // To signal that we have send all the messages
  150 + c_.send(hdl, "Done", websocketpp::frame::opcode::text, ec);
  151 + SHERPA_ONNX_LOGE("Sent Done Signal");
  152 +
  153 + if (ec) {
  154 + SHERPA_ONNX_LOGE("Failed to send audio samples because %s",
  155 + ec.message().c_str());
  156 + exit(EXIT_FAILURE);
  157 + }
  158 + } else {
  159 + asio::post(io_, [this, hdl, start_time]() {
  160 + this->SendMessage(hdl, start_time);
  161 + });
  162 + }
  163 + }
  164 +
  165 + private:
  166 + client c_;
  167 + asio::io_context &io_;
  168 + websocketpp::uri uri_;
  169 + std::vector<float> samples_;
  170 + int32_t samples_per_message_ = 8000; // 0.5 seconds
  171 + float seconds_per_message_ = 0.2;
  172 + int32_t num_sent_messages_ = 0;
  173 +};
  174 +
  175 +int32_t main(int32_t argc, char *argv[]) {
  176 + std::string server_ip = "127.0.0.1";
  177 + int32_t server_port = 6006;
  178 +
  179 + // Sample rate of the input wave. No resampling is made.
  180 + int32_t sample_rate = 16000;
  181 + int32_t samples_per_message = 8000;
  182 + float seconds_per_message = 0.2;
  183 +
  184 + sherpa_onnx::ParseOptions po(kUsageMessage);
  185 +
  186 + po.Register("server-ip", &server_ip, "IP address of the websocket server");
  187 + po.Register("server-port", &server_port, "Port of the websocket server");
  188 + po.Register("sample-rate", &sample_rate,
  189 + "Sample rate of the input wave. Should be the one expected by "
  190 + "the server");
  191 +
  192 + po.Register("samples-per-message", &samples_per_message,
  193 + "Send this number of samples per message.");
  194 +
  195 + po.Register("seconds-per-message", &seconds_per_message,
  196 + "We will simulate that each message takes this number of seconds "
  197 + "to send. If you select a very large value, it will take a long "
  198 + "time to send all the samples");
  199 +
  200 + po.Read(argc, argv);
  201 +
  202 + if (!websocketpp::uri_helper::ipv4_literal(server_ip.begin(),
  203 + server_ip.end())) {
  204 + SHERPA_ONNX_LOGE("Invalid server IP: %s", server_ip.c_str());
  205 + return -1;
  206 + }
  207 +
  208 + if (server_port <= 0 || server_port > 65535) {
  209 + SHERPA_ONNX_LOGE("Invalid server port: %d", server_port);
  210 + return -1;
  211 + }
  212 +
  213 + // 0.01 is an arbitrary value. You can change it.
  214 + if (samples_per_message <= 0.01 * sample_rate) {
  215 + SHERPA_ONNX_LOGE("--samples-per-message is too small: %d",
  216 + samples_per_message);
  217 + return -1;
  218 + }
  219 +
  220 + // 100 is an arbitrary value. You can change it.
  221 + if (samples_per_message >= sample_rate * 100) {
  222 + SHERPA_ONNX_LOGE("--samples-per-message is too small: %d",
  223 + samples_per_message);
  224 + return -1;
  225 + }
  226 +
  227 + if (seconds_per_message < 0) {
  228 + SHERPA_ONNX_LOGE("--seconds-per-message is too small: %.3f",
  229 + seconds_per_message);
  230 + return -1;
  231 + }
  232 +
  233 + // 1 is an arbitrary value.
  234 + if (seconds_per_message > 1) {
  235 + SHERPA_ONNX_LOGE(
  236 + "--seconds-per-message is too large: %.3f. You will wait a long time "
  237 + "to "
  238 + "send all the samples",
  239 + seconds_per_message);
  240 + return -1;
  241 + }
  242 +
  243 + if (po.NumArgs() != 1) {
  244 + po.PrintUsage();
  245 + return -1;
  246 + }
  247 +
  248 + std::string wave_filename = po.GetArg(1);
  249 +
  250 + bool is_ok = false;
  251 + std::vector<float> samples =
  252 + sherpa_onnx::ReadWave(wave_filename, sample_rate, &is_ok);
  253 +
  254 + if (!is_ok) {
  255 + SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str());
  256 + return -1;
  257 + }
  258 +
  259 + asio::io_context io_conn; // for network connections
  260 + Client c(io_conn, server_ip, server_port, samples, samples_per_message,
  261 + seconds_per_message);
  262 +
  263 + io_conn.run(); // will exit when the above connection is closed
  264 +
  265 + SHERPA_ONNX_LOGE("Done!");
  266 + return 0;
  267 +}
  1 +// sherpa-onnx/csrc/online-websocket-server-impl.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-websocket-server-impl.h"
  6 +
  7 +#include <vector>
  8 +
  9 +#include "sherpa-onnx/csrc/file-utils.h"
  10 +#include "sherpa-onnx/csrc/log.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void OnlineWebsocketDecoderConfig::Register(ParseOptions *po) {
  15 + recognizer_config.Register(po);
  16 +
  17 + po->Register("loop-interval-ms", &loop_interval_ms,
  18 + "It determines how often the decoder loop runs. ");
  19 +
  20 + po->Register("max-batch-size", &max_batch_size,
  21 + "Max batch size for recognition.");
  22 +}
  23 +
  24 +void OnlineWebsocketDecoderConfig::Validate() const {
  25 + recognizer_config.Validate();
  26 + SHERPA_ONNX_CHECK_GT(loop_interval_ms, 0);
  27 + SHERPA_ONNX_CHECK_GT(max_batch_size, 0);
  28 +}
  29 +
  30 +void OnlineWebsocketServerConfig::Register(sherpa_onnx::ParseOptions *po) {
  31 + decoder_config.Register(po);
  32 +
  33 + po->Register("log-file", &log_file,
  34 + "Path to the log file. Logs are "
  35 + "appended to this file");
  36 +}
  37 +
  38 +void OnlineWebsocketServerConfig::Validate() const {
  39 + decoder_config.Validate();
  40 +}
  41 +
  42 +OnlineWebsocketDecoder::OnlineWebsocketDecoder(OnlineWebsocketServer *server)
  43 + : server_(server),
  44 + config_(server->GetConfig().decoder_config),
  45 + timer_(server->GetWorkContext()) {
  46 + recognizer_ = std::make_unique<OnlineRecognizer>(config_.recognizer_config);
  47 +}
  48 +
  49 +std::shared_ptr<Connection> OnlineWebsocketDecoder::GetOrCreateConnection(
  50 + connection_hdl hdl) {
  51 + std::lock_guard<std::mutex> lock(mutex_);
  52 + auto it = connections_.find(hdl);
  53 + if (it != connections_.end()) {
  54 + return it->second;
  55 + } else {
  56 + // create a new connection
  57 + std::shared_ptr<OnlineStream> s = recognizer_->CreateStream();
  58 + auto c = std::make_shared<Connection>(hdl, s);
  59 + connections_.insert({hdl, c});
  60 + return c;
  61 + }
  62 +}
  63 +
  64 +void OnlineWebsocketDecoder::AcceptWaveform(std::shared_ptr<Connection> c) {
  65 + std::lock_guard<std::mutex> lock(c->mutex);
  66 + float sample_rate = config_.recognizer_config.feat_config.sampling_rate;
  67 + while (!c->samples.empty()) {
  68 + const auto &s = c->samples.front();
  69 + c->s->AcceptWaveform(sample_rate, s.data(), s.size());
  70 + c->samples.pop_front();
  71 + }
  72 +}
  73 +
  74 +void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) {
  75 + std::lock_guard<std::mutex> lock(c->mutex);
  76 +
  77 + float sample_rate = config_.recognizer_config.feat_config.sampling_rate;
  78 +
  79 + while (!c->samples.empty()) {
  80 + const auto &s = c->samples.front();
  81 + c->s->AcceptWaveform(sample_rate, s.data(), s.size());
  82 + c->samples.pop_front();
  83 + }
  84 +
  85 + // TODO(fangjun): Change the amount of paddings to be configurable
  86 + std::vector<float> tail_padding(static_cast<int64_t>(0.8 * sample_rate));
  87 +
  88 + c->s->AcceptWaveform(sample_rate, tail_padding.data(), tail_padding.size());
  89 +
  90 + c->s->InputFinished();
  91 + c->eof = true;
  92 +}
  93 +
  94 +void OnlineWebsocketDecoder::Run() {
  95 + timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));
  96 +
  97 + timer_.async_wait(
  98 + [this](const asio::error_code &ec) { ProcessConnections(ec); });
  99 +}
  100 +
  101 +void OnlineWebsocketDecoder::ProcessConnections(const asio::error_code &ec) {
  102 + if (ec) {
  103 + SHERPA_ONNX_LOG(FATAL) << "The decoder loop is aborted!";
  104 + }
  105 +
  106 + std::lock_guard<std::mutex> lock(mutex_);
  107 + std::vector<connection_hdl> to_remove;
  108 + for (auto &p : connections_) {
  109 + auto hdl = p.first;
  110 + auto c = p.second;
  111 +
  112 + // The order of `if` below matters!
  113 + if (!server_->Contains(hdl)) {
  114 + // If the connection is disconnected, we stop processing it
  115 + to_remove.push_back(hdl);
  116 + continue;
  117 + }
  118 +
  119 + if (active_.count(hdl)) {
  120 + // Another thread is decoding this stream, so skip it
  121 + continue;
  122 + }
  123 +
  124 + if (!recognizer_->IsReady(c->s.get()) && !c->eof) {
  125 + // this stream has not enough frames to decode, so skip it
  126 + continue;
  127 + }
  128 +
  129 + if (!recognizer_->IsReady(c->s.get()) && c->eof) {
  130 + // We won't receive samples from the client, so send a Done! to client
  131 +
  132 + asio::post(server_->GetWorkContext(),
  133 + [this, hdl = c->hdl]() { server_->Send(hdl, "Done!"); });
  134 +
  135 + to_remove.push_back(hdl);
  136 + continue;
  137 + }
  138 +
  139 + // TODO(fangun): If the connection is timed out, we need to also
  140 + // add it to `to_remove`
  141 +
  142 + // this stream has enough frames and is currently not processed by any
  143 + // threads, so put it into the ready queue
  144 + ready_connections_.push_back(c);
  145 +
  146 + // In `Decode()`, it will remove hdl from `active_`
  147 + active_.insert(c->hdl);
  148 + }
  149 +
  150 + for (auto hdl : to_remove) {
  151 + connections_.erase(hdl);
  152 + }
  153 +
  154 + if (!ready_connections_.empty()) {
  155 + asio::post(server_->GetWorkContext(), [this]() { Decode(); });
  156 + }
  157 +
  158 + // Schedule another call
  159 + timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));
  160 +
  161 + timer_.async_wait(
  162 + [this](const asio::error_code &ec) { ProcessConnections(ec); });
  163 +}
  164 +
  165 +void OnlineWebsocketDecoder::Decode() {
  166 + std::unique_lock<std::mutex> lock(mutex_);
  167 + if (ready_connections_.empty()) {
  168 + // There are no connections that are ready for decoding,
  169 + // so we return directly
  170 + return;
  171 + }
  172 +
  173 + std::vector<std::shared_ptr<Connection>> c_vec;
  174 + std::vector<OnlineStream *> s_vec;
  175 + while (!ready_connections_.empty() &&
  176 + static_cast<int32_t>(s_vec.size()) < config_.max_batch_size) {
  177 + auto c = ready_connections_.front();
  178 + ready_connections_.pop_front();
  179 +
  180 + c_vec.push_back(c);
  181 + s_vec.push_back(c->s.get());
  182 + }
  183 +
  184 + if (!ready_connections_.empty()) {
  185 + // there are too many ready connections but this thread can only handle
  186 + // max_batch_size connections at a time, so we schedule another call
  187 + // to Decode() and let other threads to process the ready connections
  188 + asio::post(server_->GetWorkContext(), [this]() { Decode(); });
  189 + }
  190 +
  191 + lock.unlock();
  192 + recognizer_->DecodeStreams(s_vec.data(), s_vec.size());
  193 + lock.lock();
  194 +
  195 + for (auto c : c_vec) {
  196 + auto result = recognizer_->GetResult(c->s.get());
  197 +
  198 + asio::post(server_->GetConnectionContext(),
  199 + [this, hdl = c->hdl, str = result.ToString()]() {
  200 + server_->Send(hdl, str);
  201 + });
  202 + active_.erase(c->hdl);
  203 + }
  204 +}
  205 +
  206 +OnlineWebsocketServer::OnlineWebsocketServer(
  207 + asio::io_context &io_conn, asio::io_context &io_work,
  208 + const OnlineWebsocketServerConfig &config)
  209 + : config_(config),
  210 + io_conn_(io_conn),
  211 + io_work_(io_work),
  212 + log_(config.log_file, std::ios::app),
  213 + tee_(std::cout, log_),
  214 + decoder_(this) {
  215 + SetupLog();
  216 +
  217 + server_.init_asio(&io_conn_);
  218 +
  219 + server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); });
  220 +
  221 + server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); });
  222 +
  223 + server_.set_message_handler(
  224 + [this](connection_hdl hdl, server::message_ptr msg) {
  225 + OnMessage(hdl, msg);
  226 + });
  227 +}
  228 +
  229 +void OnlineWebsocketServer::Run(uint16_t port) {
  230 + server_.set_reuse_addr(true);
  231 + server_.listen(asio::ip::tcp::v4(), port);
  232 + server_.start_accept();
  233 + decoder_.Run();
  234 +}
  235 +
  236 +void OnlineWebsocketServer::SetupLog() {
  237 + server_.clear_access_channels(websocketpp::log::alevel::all);
  238 + // server_.set_access_channels(websocketpp::log::alevel::connect);
  239 + // server_.set_access_channels(websocketpp::log::alevel::disconnect);
  240 +
  241 + // So that it also prints to std::cout and std::cerr
  242 + server_.get_alog().set_ostream(&tee_);
  243 + server_.get_elog().set_ostream(&tee_);
  244 +}
  245 +
  246 +void OnlineWebsocketServer::Send(connection_hdl hdl, const std::string &text) {
  247 + websocketpp::lib::error_code ec;
  248 + if (!Contains(hdl)) {
  249 + return;
  250 + }
  251 +
  252 + server_.send(hdl, text, websocketpp::frame::opcode::text, ec);
  253 + if (ec) {
  254 + server_.get_alog().write(websocketpp::log::alevel::app, ec.message());
  255 + }
  256 +}
  257 +
  258 +void OnlineWebsocketServer::OnOpen(connection_hdl hdl) {
  259 + std::lock_guard<std::mutex> lock(mutex_);
  260 + connections_.insert(hdl);
  261 +
  262 + std::ostringstream os;
  263 + os << "New connection: "
  264 + << server_.get_con_from_hdl(hdl)->get_remote_endpoint() << ". "
  265 + << "Number of active connections: " << connections_.size() << ".\n";
  266 + SHERPA_ONNX_LOG(INFO) << os.str();
  267 +}
  268 +
  269 +void OnlineWebsocketServer::OnClose(connection_hdl hdl) {
  270 + std::lock_guard<std::mutex> lock(mutex_);
  271 + connections_.erase(hdl);
  272 +
  273 + SHERPA_ONNX_LOG(INFO) << "Number of active connections: "
  274 + << connections_.size() << "\n";
  275 +}
  276 +
  277 +bool OnlineWebsocketServer::Contains(connection_hdl hdl) const {
  278 + std::lock_guard<std::mutex> lock(mutex_);
  279 + return connections_.count(hdl);
  280 +}
  281 +
  282 +void OnlineWebsocketServer::OnMessage(connection_hdl hdl,
  283 + server::message_ptr msg) {
  284 + auto c = decoder_.GetOrCreateConnection(hdl);
  285 +
  286 + const std::string &payload = msg->get_payload();
  287 +
  288 + switch (msg->get_opcode()) {
  289 + case websocketpp::frame::opcode::text:
  290 + if (payload == "Done") {
  291 + asio::post(io_work_, [this, c]() { decoder_.InputFinished(c); });
  292 + }
  293 + break;
  294 + case websocketpp::frame::opcode::binary: {
  295 + auto p = reinterpret_cast<const float *>(payload.data());
  296 + int32_t num_samples = payload.size() / sizeof(float);
  297 + std::vector<float> samples(p, p + num_samples);
  298 +
  299 + c->samples.push_back(std::move(samples));
  300 +
  301 + asio::post(io_work_, [this, c]() { decoder_.AcceptWaveform(c); });
  302 + break;
  303 + }
  304 + default:
  305 + break;
  306 + }
  307 +}
  308 +
  309 +void OnlineWebsocketServer::Close(connection_hdl hdl,
  310 + websocketpp::close::status::value code,
  311 + const std::string &reason) {
  312 + auto con = server_.get_con_from_hdl(hdl);
  313 +
  314 + std::ostringstream os;
  315 + os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason
  316 + << "\n";
  317 +
  318 + websocketpp::lib::error_code ec;
  319 + server_.close(hdl, code, reason, ec);
  320 + if (ec) {
  321 + os << "Failed to close" << con->get_remote_endpoint() << ". "
  322 + << ec.message() << "\n";
  323 + }
  324 + server_.get_alog().write(websocketpp::log::alevel::app, os.str());
  325 +}
  326 +
  327 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-websocket-server-impl.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
  7 +
  8 +#include <deque>
  9 +#include <fstream>
  10 +#include <map>
  11 +#include <memory>
  12 +#include <mutex> // NOLINT
  13 +#include <set>
  14 +#include <string>
  15 +#include <unordered_set>
  16 +#include <utility>
  17 +#include <vector>
  18 +
  19 +#include "asio.hpp"
  20 +#include "sherpa-onnx/csrc/online-recognizer.h"
  21 +#include "sherpa-onnx/csrc/online-stream.h"
  22 +#include "sherpa-onnx/csrc/parse-options.h"
  23 +#include "sherpa-onnx/csrc/tee-stream.h"
  24 +#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS
  25 +#include "websocketpp/server.hpp"
  26 +using server = websocketpp::server<websocketpp::config::asio>;
  27 +using connection_hdl = websocketpp::connection_hdl;
  28 +
  29 +namespace sherpa_onnx {
  30 +
  31 +struct Connection {
  32 + // handle to the connection. We can use it to send messages to the client
  33 + connection_hdl hdl;
  34 + std::shared_ptr<OnlineStream> s;
  35 +
  36 + // set it to true when InputFinished() is called
  37 + bool eof = false;
  38 +
  39 + // The last time we received a message from the client
  40 + // TODO(fangjun): Use it to disconnect from a client if it is inactive
  41 + // for a specified time.
  42 + std::chrono::steady_clock::time_point last_active;
  43 +
  44 + std::mutex mutex; // protect samples
  45 +
  46 + // Audio samples received from the client.
  47 + //
  48 + // The I/O threads receive audio samples into this queue
  49 + // and invoke work threads to compute features
  50 + std::deque<std::vector<float>> samples;
  51 +
  52 + Connection() = default;
  53 + Connection(connection_hdl hdl, std::shared_ptr<OnlineStream> s)
  54 + : hdl(hdl), s(s), last_active(std::chrono::steady_clock::now()) {}
  55 +};
  56 +
  57 +struct OnlineWebsocketDecoderConfig {
  58 + OnlineRecognizerConfig recognizer_config;
  59 +
  60 + // It determines how often the decoder loop runs.
  61 + int32_t loop_interval_ms = 10;
  62 +
  63 + int32_t max_batch_size = 5;
  64 +
  65 + void Register(ParseOptions *po);
  66 + void Validate() const;
  67 +};
  68 +
  69 +class OnlineWebsocketServer;
  70 +
  71 +class OnlineWebsocketDecoder {
  72 + public:
  73 + /**
  74 + * @param server Not owned.
  75 + */
  76 + explicit OnlineWebsocketDecoder(OnlineWebsocketServer *server);
  77 +
  78 + std::shared_ptr<Connection> GetOrCreateConnection(connection_hdl hdl);
  79 +
  80 + // Compute features for a stream given audio samples
  81 + void AcceptWaveform(std::shared_ptr<Connection> c);
  82 +
  83 + // signal that there will be no more audio samples for a stream
  84 + void InputFinished(std::shared_ptr<Connection> c);
  85 +
  86 + void Run();
  87 +
  88 + private:
  89 + void ProcessConnections(const asio::error_code &ec);
  90 +
  91 + /** It is called by one of the worker thread.
  92 + */
  93 + void Decode();
  94 +
  95 + private:
  96 + OnlineWebsocketServer *server_; // not owned
  97 + std::unique_ptr<OnlineRecognizer> recognizer_;
  98 + OnlineWebsocketDecoderConfig config_;
  99 + asio::steady_timer timer_;
  100 +
  101 + // It protects `connections_`, `ready_connections_`, and `active_`
  102 + std::mutex mutex_;
  103 +
  104 + std::map<connection_hdl, std::shared_ptr<Connection>,
  105 + std::owner_less<connection_hdl>>
  106 + connections_;
  107 +
  108 + // Whenever a connection has enough feature frames for decoding, we put
  109 + // it in this queue
  110 + std::deque<std::shared_ptr<Connection>> ready_connections_;
  111 +
  112 + // If we are decoding a stream, we put it in the active_ set so that
  113 + // only one thread can decode a stream at a time.
  114 + std::set<connection_hdl, std::owner_less<connection_hdl>> active_;
  115 +};
  116 +
  117 +struct OnlineWebsocketServerConfig {
  118 + OnlineWebsocketDecoderConfig decoder_config;
  119 +
  120 + std::string log_file = "./log.txt";
  121 +
  122 + void Register(sherpa_onnx::ParseOptions *po);
  123 + void Validate() const;
  124 +};
  125 +
  126 +class OnlineWebsocketServer {
  127 + public:
  128 + explicit OnlineWebsocketServer(asio::io_context &io_conn, // NOLINT
  129 + asio::io_context &io_work, // NOLINT
  130 + const OnlineWebsocketServerConfig &config);
  131 +
  132 + void Run(uint16_t port);
  133 +
  134 + const OnlineWebsocketServerConfig &GetConfig() const { return config_; }
  135 + asio::io_context &GetConnectionContext() { return io_conn_; }
  136 + asio::io_context &GetWorkContext() { return io_work_; }
  137 + server &GetServer() { return server_; }
  138 +
  139 + void Send(connection_hdl hdl, const std::string &text);
  140 +
  141 + bool Contains(connection_hdl hdl) const;
  142 +
  143 + private:
  144 + void SetupLog();
  145 +
  146 + // When a websocket client is connected, it will invoke this method
  147 + // (Not for HTTP)
  148 + void OnOpen(connection_hdl hdl);
  149 +
  150 + // When a websocket client is disconnected, it will invoke this method
  151 + void OnClose(connection_hdl hdl);
  152 +
  153 + void OnMessage(connection_hdl hdl, server::message_ptr msg);
  154 +
  155 + // Close a websocket connection with given code and reason
  156 + void Close(connection_hdl hdl, websocketpp::close::status::value code,
  157 + const std::string &reason);
  158 +
  159 + private:
  160 + OnlineWebsocketServerConfig config_;
  161 + asio::io_context &io_conn_;
  162 + asio::io_context &io_work_;
  163 + server server_;
  164 +
  165 + std::ofstream log_;
  166 + sherpa_onnx::TeeStream tee_;
  167 +
  168 + OnlineWebsocketDecoder decoder_;
  169 +
  170 + mutable std::mutex mutex_;
  171 +
  172 + std::set<connection_hdl, std::owner_less<connection_hdl>> connections_;
  173 +};
  174 +
  175 +} // namespace sherpa_onnx
  176 +
  177 +#endif // SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
  1 +// sherpa-onnx/csrc/online-websocket-server.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#include "asio.hpp"
  6 +#include "sherpa-onnx/csrc/macros.h"
  7 +#include "sherpa-onnx/csrc/online-websocket-server-impl.h"
  8 +#include "sherpa-onnx/csrc/parse-options.h"
  9 +
  // Help text of the server executable. It is handed to
  // sherpa_onnx::ParseOptions in main().
  10 +static constexpr const char *kUsageMessage = R"(
  11 +Automatic speech recognition with sherpa-onnx using websocket.
  12 +
  13 +Usage:
  14 +
  15 +./bin/sherpa-onnx-online-websocket-server --help
  16 +
  17 +./bin/sherpa-onnx-online-websocket-server \
  18 + --port=6006 \
  19 + --num-work-threads=5 \
  20 + --tokens=/path/to/tokens.txt \
  21 + --encoder=/path/to/encoder.onnx \
  22 + --decoder=/path/to/decoder.onnx \
  23 + --joiner=/path/to/joiner.onnx \
  24 + --log-file=./log.txt \
  25 + --max-batch-size=5 \
  26 + --loop-interval-ms=10
  27 +
  28 +Please refer to
  29 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  30 +for a list of pre-trained models to download.
  31 +)";
  32 +
  33 +int32_t main(int32_t argc, char *argv[]) {
  34 + sherpa_onnx::ParseOptions po(kUsageMessage);
  35 +
  36 + sherpa_onnx::OnlineWebsocketServerConfig config;
  37 +
  38 + // the server will listen on this port
  39 + int32_t port = 6006;
  40 +
  41 + // size of the thread pool for handling network connections
  42 + int32_t num_io_threads = 1;
  43 +
  44 + // size of the thread pool for neural network computation and decoding
  45 + int32_t num_work_threads = 3;
  46 +
  47 + po.Register("num-io-threads", &num_io_threads,
  48 + "Thread pool size for network connections.");
  49 +
  50 + po.Register("num-work-threads", &num_work_threads,
  51 + "Thread pool size for for neural network "
  52 + "computation and decoding.");
  53 +
  54 + po.Register("port", &port, "The port on which the server will listen.");
  55 +
  56 + config.Register(&po);
  57 +
  58 + if (argc == 1) {
  59 + po.PrintUsage();
  60 + exit(EXIT_FAILURE);
  61 + }
  62 +
  63 + po.Read(argc, argv);
  64 +
  65 + if (po.NumArgs() != 0) {
  66 + SHERPA_ONNX_LOGE("Unrecognized positional arguments!");
  67 + po.PrintUsage();
  68 + exit(EXIT_FAILURE);
  69 + }
  70 +
  71 + config.Validate();
  72 +
  73 + asio::io_context io_conn; // for network connections
  74 + asio::io_context io_work; // for neural network and decoding
  75 +
  76 + sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config);
  77 + server.Run(port);
  78 +
  79 + SHERPA_ONNX_LOGE("Listening on: %d", port);
  80 + SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);
  81 +
  82 + // give some work to do for the io_work pool
  83 + auto work_guard = asio::make_work_guard(io_work);
  84 +
  85 + std::vector<std::thread> io_threads;
  86 +
  87 + // decrement since the main thread is also used for network communications
  88 + for (int32_t i = 0; i < num_io_threads - 1; ++i) {
  89 + io_threads.emplace_back([&io_conn]() { io_conn.run(); });
  90 + }
  91 +
  92 + std::vector<std::thread> work_threads;
  93 + for (int32_t i = 0; i < num_work_threads; ++i) {
  94 + work_threads.emplace_back([&io_work]() { io_work.run(); });
  95 + }
  96 +
  97 + io_conn.run();
  98 +
  99 + for (auto &t : io_threads) {
  100 + t.join();
  101 + }
  102 +
  103 + for (auto &t : work_threads) {
  104 + t.join();
  105 + }
  106 +
  107 + return 0;
  108 +}
  1 +// Code in this file is copied and modified from
  2 +// https://wordaligned.org/articles/cpp-streambufs
  3 +
  4 +#ifndef SHERPA_ONNX_CSRC_TEE_STREAM_H_
  5 +#define SHERPA_ONNX_CSRC_TEE_STREAM_H_
  6 +#include <ostream>
  7 +#include <streambuf>
  8 +#include <string>
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +template <typename char_type, typename traits = std::char_traits<char_type>>
  13 +class basic_teebuf : public std::basic_streambuf<char_type, traits> {
  14 + public:
  15 + using int_type = typename traits::int_type;
  16 +
  17 + basic_teebuf(std::basic_streambuf<char_type, traits> *sb1,
  18 + std::basic_streambuf<char_type, traits> *sb2)
  19 + : sb1(sb1), sb2(sb2) {}
  20 +
  21 + private:
  22 + int sync() override {
  23 + int const r1 = sb1->pubsync();
  24 + int const r2 = sb2->pubsync();
  25 + return r1 == 0 && r2 == 0 ? 0 : -1;
  26 + }
  27 +
  28 + int_type overflow(int_type c) override {
  29 + int_type const eof = traits::eof();
  30 +
  31 + if (traits::eq_int_type(c, eof)) {
  32 + return traits::not_eof(c);
  33 + } else {
  34 + char_type const ch = traits::to_char_type(c);
  35 + int_type const r1 = sb1->sputc(ch);
  36 + int_type const r2 = sb2->sputc(ch);
  37 +
  38 + return traits::eq_int_type(r1, eof) || traits::eq_int_type(r2, eof) ? eof
  39 + : c;
  40 + }
  41 + }
  42 +
  43 + private:
  44 + std::basic_streambuf<char_type, traits> *sb1;
  45 + std::basic_streambuf<char_type, traits> *sb2;
  46 +};
  47 +
  48 +using teebuf = basic_teebuf<char>;
  49 +
  50 +class TeeStream : public std::ostream {
  51 + public:
  52 + TeeStream(std::ostream &o1, std::ostream &o2)
  53 + : std::ostream(&tbuf), tbuf(o1.rdbuf(), o2.rdbuf()) {}
  54 +
  55 + private:
  56 + teebuf tbuf;
  57 +};
  58 +
  59 +} // namespace sherpa_onnx
  60 +
  61 +#endif // SHERPA_ONNX_CSRC_TEE_STREAM_H_