Fangjun Kuang
Committed by GitHub

add streaming websocket server and client (#62)

... ... @@ -18,6 +18,7 @@ option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
... ... @@ -59,6 +60,8 @@ message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}")
message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}")
message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}")
message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}")
message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}")
message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}")
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_CXX_EXTENSIONS OFF)
... ... @@ -91,6 +94,11 @@ if(SHERPA_ONNX_ENABLE_TESTS)
include(googletest)
endif()
if(SHERPA_ONNX_ENABLE_WEBSOCKET)
include(websocketpp)
include(asio)
endif()
add_subdirectory(sherpa-onnx)
if(SHERPA_ONNX_ENABLE_C_API)
... ...
... ... @@ -40,6 +40,10 @@ cmake \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
-DSHERPA_ONNX_ENABLE_JNI=OFF \
-DSHERPA_ONNX_ENABLE_C_API=OFF \
-DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \
-DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu.toolchain.cmake \
..
... ...
... ... @@ -76,6 +76,8 @@ cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake"
-DSHERPA_ONNX_ENABLE_JNI=ON \
-DCMAKE_INSTALL_PREFIX=./install \
-DANDROID_ABI="x86_64" \
-DSHERPA_ONNX_ENABLE_C_API=OFF \
-DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \
-DANDROID_PLATFORM=android-21 ..
# make VERBOSE=1 -j4
... ...
function(download_asio)
include(FetchContent)
set(asio_URL "https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz")
set(asio_HASH "SHA256=cbcaaba0f66722787b1a7c33afe1befb3a012b5af3ad7da7ff0f6b8c9b7a8a5b")
# If you don't have access to the Internet,
# please pre-download asio
set(possible_file_locations
$ENV{HOME}/Downloads/asio-asio-1-24-0.tar.gz
${PROJECT_SOURCE_DIR}/asio-asio-1-24-0.tar.gz
${PROJECT_BINARY_DIR}/asio-asio-1-24-0.tar.gz
/tmp/asio-asio-1-24-0.tar.gz
/star-fj/fangjun/download/github/asio-asio-1-24-0.tar.gz
)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(asio_URL "file://${f}")
break()
endif()
endforeach()
FetchContent_Declare(asio
URL ${asio_URL}
URL_HASH ${asio_HASH}
)
FetchContent_GetProperties(asio)
if(NOT asio_POPULATED)
message(STATUS "Downloading asio ${asio_URL}")
FetchContent_Populate(asio)
endif()
message(STATUS "asio is downloaded to ${asio_SOURCE_DIR}")
# add_subdirectory(${asio_SOURCE_DIR} ${asio_BINARY_DIR} EXCLUDE_FROM_ALL)
include_directories(${asio_SOURCE_DIR}/asio/include)
endfunction()
download_asio()
... ...
function(download_websocketpp)
include(FetchContent)
# The latest commit on the develop branch os as 2022-10-22
set(websocketpp_URL "https://github.com/zaphoyd/websocketpp/archive/b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip")
set(websocketpp_HASH "SHA256=1385135ede8191a7fbef9ec8099e3c5a673d48df0c143958216cd1690567f583")
# If you don't have access to the Internet,
# please pre-download websocketpp
set(possible_file_locations
$ENV{HOME}/Downloads/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip
${PROJECT_SOURCE_DIR}/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip
${PROJECT_BINARY_DIR}/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip
/tmp/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip
/star-fj/fangjun/download/github/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip
)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(websocketpp_URL "file://${f}")
break()
endif()
endforeach()
FetchContent_Declare(websocketpp
URL ${websocketpp_URL}
URL_HASH ${websocketpp_HASH}
)
FetchContent_GetProperties(websocketpp)
if(NOT websocketpp_POPULATED)
message(STATUS "Downloading websocketpp from ${websocketpp_URL}")
FetchContent_Populate(websocketpp)
endif()
message(STATUS "websocketpp is downloaded to ${websocketpp_SOURCE_DIR}")
# add_subdirectory(${websocketpp_SOURCE_DIR} ${websocketpp_BINARY_DIR} EXCLUDE_FROM_ALL)
include_directories(${websocketpp_SOURCE_DIR})
endfunction()
download_websocketpp()
... ...
... ... @@ -4,6 +4,7 @@ set(sources
cat.cc
endpoint.cc
features.cc
file-utils.cc
online-lstm-transducer-model.cc
online-recognizer.cc
online-stream.cc
... ... @@ -86,6 +87,32 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO)
install(TARGETS sherpa-onnx-microphone DESTINATION bin)
endif()
if(SHERPA_ONNX_ENABLE_WEBSOCKET)
add_definitions(-DASIO_STANDALONE)
add_definitions(-D_WEBSOCKETPP_CPP11_STL_)
add_executable(sherpa-onnx-online-websocket-server
online-websocket-server-impl.cc
online-websocket-server.cc
)
target_link_libraries(sherpa-onnx-online-websocket-server sherpa-onnx-core)
add_executable(sherpa-onnx-online-websocket-client
online-websocket-client.cc
)
target_link_libraries(sherpa-onnx-online-websocket-client sherpa-onnx-core)
if(NOT WIN32)
target_link_libraries(sherpa-onnx-online-websocket-server -pthread)
target_compile_options(sherpa-onnx-online-websocket-server PRIVATE -Wno-deprecated-declarations)
target_link_libraries(sherpa-onnx-online-websocket-client -pthread)
target_compile_options(sherpa-onnx-online-websocket-client PRIVATE -Wno-deprecated-declarations)
endif()
endif()
if(SHERPA_ONNX_ENABLE_TESTS)
set(sherpa_onnx_test_srcs
... ...
exclude_files=tee-stream.h
... ...
... ... @@ -14,6 +14,15 @@
namespace sherpa_onnx {
void FeatureExtractorConfig::Register(ParseOptions *po) {
po->Register("sample-rate", &sampling_rate,
"Sampling rate of the input waveform. Must match the one "
"expected by the model.");
po->Register("feat-dim", &feature_dim,
"Feature dimension. Must match the one expected by the model.");
}
std::string FeatureExtractorConfig::ToString() const {
std::ostringstream os;
... ...
... ... @@ -9,6 +9,8 @@
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct FeatureExtractorConfig {
... ... @@ -16,6 +18,8 @@ struct FeatureExtractorConfig {
int32_t feature_dim = 80;
std::string ToString() const;
void Register(ParseOptions *po);
};
class FeatureExtractor {
... ...
// sherpa-onnx/csrc/file-utils.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/file-utils.h"
#include <fstream>
#include <string>
#include "sherpa-onnx/csrc/log.h"
namespace sherpa_onnx {
bool FileExists(const std::string &filename) {
return std::ifstream(filename).good();
}
void AssertFileExists(const std::string &filename) {
if (!FileExists(filename)) {
SHERPA_ONNX_LOG(FATAL) << filename << " does not exist!";
}
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/file-utils.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_FILE_UTILS_H_
#define SHERPA_ONNX_CSRC_FILE_UTILS_H_
#include <fstream>
#include <string>
namespace sherpa_onnx {
/** Check whether a given path is a file or not
*
* @param filename Path to check.
* @return Return true if the given path is a file; return false otherwise.
*/
bool FileExists(const std::string &filename);
/** Abort if the file does not exist.
*
* @param filename The file to check.
*/
void AssertFileExists(const std::string &filename);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_FILE_UTILS_H_
... ...
... ... @@ -12,6 +12,7 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
... ... @@ -31,6 +32,19 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
return ans;
}
void OnlineRecognizerConfig::Register(ParseOptions *po) {
feat_config.Register(po);
model_config.Register(po);
endpoint_config.Register(po);
po->Register("enable-endpoint", &enable_endpoint,
"True to enable endpoint detection. False to disable it.");
}
bool OnlineRecognizerConfig::Validate() const {
return model_config.Validate();
}
std::string OnlineRecognizerConfig::ToString() const {
std::ostringstream os;
... ...
... ... @@ -17,11 +17,15 @@
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OnlineRecognizerResult {
std::string text;
// TODO(fangjun): Add a method to return a json string
std::string ToString() const { return text; }
};
struct OnlineRecognizerConfig {
... ... @@ -41,6 +45,9 @@ struct OnlineRecognizerConfig {
endpoint_config(endpoint_config),
enable_endpoint(enable_endpoint) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
... ...
... ... @@ -5,8 +5,52 @@
#include <sstream>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OnlineTransducerModelConfig::Register(ParseOptions *po) {
po->Register("encoder", &encoder_filename, "Path to encoder.onnx");
po->Register("decoder", &decoder_filename, "Path to decoder.onnx");
po->Register("joiner", &joiner_filename, "Path to joiner.onnx");
po->Register("tokens", &tokens, "Path to tokens.txt");
po->Register("num_threads", &num_threads,
"Number of threads to run the neural network");
po->Register("debug", &debug,
"true to print model information while loading it.");
}
bool OnlineTransducerModelConfig::Validate() const {
if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
return false;
}
if (!FileExists(encoder_filename)) {
SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
return false;
}
if (!FileExists(decoder_filename)) {
SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str());
return false;
}
if (!FileExists(joiner_filename)) {
SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str());
return false;
}
if (num_threads < 1) {
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
return false;
}
return true;
}
std::string OnlineTransducerModelConfig::ToString() const {
std::ostringstream os;
... ...
... ... @@ -6,6 +6,8 @@
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OnlineTransducerModelConfig {
... ... @@ -13,7 +15,7 @@ struct OnlineTransducerModelConfig {
std::string decoder_filename;
std::string joiner_filename;
std::string tokens;
int32_t num_threads;
int32_t num_threads = 2;
bool debug = false;
OnlineTransducerModelConfig() = default;
... ... @@ -29,6 +31,9 @@ struct OnlineTransducerModelConfig {
num_threads(num_threads),
debug(debug) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
... ...
// sherpa/cpp_api/websocket/online-websocket-client.cc
//
// Copyright (c) 2022 Xiaomi Corporation
#include <chrono> // NOLINT
#include <fstream>
#include <string>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "websocketpp/client.hpp"
#include "websocketpp/config/asio_no_tls_client.hpp"
#include "websocketpp/uri.hpp"
using client = websocketpp::client<websocketpp::config::asio_client>;
using message_ptr = client::message_ptr;
using websocketpp::connection_hdl;
static constexpr const char *kUsageMessage = R"(
Automatic speech recognition with sherpa-onnx using websocket.
Usage:
./bin/sherpa-onnx-online-websocket-client --help
./bin/sherpa-onnx-online-websocket-client \
--server-ip=127.0.0.1 \
--server-port=6006 \
--samples-per-message=8000 \
--seconds-per-message=0.2 \
/path/to/foo.wav
It support only wave of with a single channel, 16kHz, 16-bit samples.
)";
class Client {
public:
Client(asio::io_context &io, // NOLINT
const std::string &ip, int16_t port, const std::vector<float> &samples,
int32_t samples_per_message, float seconds_per_message)
: io_(io),
uri_(/*secure*/ false, ip, port, /*resource*/ "/"),
samples_(samples),
samples_per_message_(samples_per_message),
seconds_per_message_(seconds_per_message) {
c_.clear_access_channels(websocketpp::log::alevel::all);
// c_.set_access_channels(websocketpp::log::alevel::connect);
// c_.set_access_channels(websocketpp::log::alevel::disconnect);
c_.init_asio(&io_);
c_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); });
c_.set_close_handler(
[this](connection_hdl /*hdl*/) { SHERPA_ONNX_LOGE("Disconnected"); });
c_.set_message_handler(
[this](connection_hdl hdl, message_ptr msg) { OnMessage(hdl, msg); });
Run();
}
private:
void Run() {
websocketpp::lib::error_code ec;
client::connection_ptr con = c_.get_connection(uri_.str(), ec);
if (ec) {
SHERPA_ONNX_LOGE("Could not create connection to %s because %s",
uri_.str().c_str(), ec.message().c_str());
exit(EXIT_FAILURE);
}
c_.connect(con);
}
void OnOpen(connection_hdl hdl) {
auto start_time = std::chrono::steady_clock::now();
asio::post(
io_, [this, hdl, start_time]() { this->SendMessage(hdl, start_time); });
}
void OnMessage(connection_hdl hdl, message_ptr msg) {
const std::string &payload = msg->get_payload();
if (payload == "Done!") {
websocketpp::lib::error_code ec;
c_.close(hdl, websocketpp::close::status::normal, "I'm exiting now", ec);
if (ec) {
SHERPA_ONNX_LOGE("Failed to close because %s", ec.message().c_str());
exit(EXIT_FAILURE);
}
} else {
SHERPA_ONNX_LOGE("%s", payload.c_str());
}
}
void SendMessage(
connection_hdl hdl,
std::chrono::time_point<std::chrono::steady_clock> start_time) {
int32_t num_samples = samples_.size();
int32_t num_messages = num_samples / samples_per_message_;
websocketpp::lib::error_code ec;
auto time = std::chrono::steady_clock::now();
int elapsed_time_ms =
std::chrono::duration_cast<std::chrono::milliseconds>(time - start_time)
.count();
if (elapsed_time_ms <
static_cast<int>(seconds_per_message_ * num_sent_messages_ * 1000)) {
std::this_thread::sleep_for(std::chrono::milliseconds(int(
seconds_per_message_ * num_sent_messages_ * 1000 - elapsed_time_ms)));
}
if (num_sent_messages_ < 1) {
SHERPA_ONNX_LOGE("Starting to send audio");
}
if (num_sent_messages_ < num_messages) {
c_.send(hdl, samples_.data() + num_sent_messages_ * samples_per_message_,
samples_per_message_ * sizeof(float),
websocketpp::frame::opcode::binary, ec);
if (ec) {
SHERPA_ONNX_LOGE("Failed to send audio samples because %s",
ec.message().c_str());
exit(EXIT_FAILURE);
}
ec.clear();
++num_sent_messages_;
}
if (num_sent_messages_ == num_messages) {
int32_t remaining_samples = num_samples % samples_per_message_;
if (remaining_samples) {
c_.send(hdl,
samples_.data() + num_sent_messages_ * samples_per_message_,
remaining_samples * sizeof(float),
websocketpp::frame::opcode::binary, ec);
if (ec) {
SHERPA_ONNX_LOGE("Failed to send audio samples because %s",
ec.message().c_str());
exit(EXIT_FAILURE);
}
ec.clear();
}
// To signal that we have send all the messages
c_.send(hdl, "Done", websocketpp::frame::opcode::text, ec);
SHERPA_ONNX_LOGE("Sent Done Signal");
if (ec) {
SHERPA_ONNX_LOGE("Failed to send audio samples because %s",
ec.message().c_str());
exit(EXIT_FAILURE);
}
} else {
asio::post(io_, [this, hdl, start_time]() {
this->SendMessage(hdl, start_time);
});
}
}
private:
client c_;
asio::io_context &io_;
websocketpp::uri uri_;
std::vector<float> samples_;
int32_t samples_per_message_ = 8000; // 0.5 seconds
float seconds_per_message_ = 0.2;
int32_t num_sent_messages_ = 0;
};
int32_t main(int32_t argc, char *argv[]) {
std::string server_ip = "127.0.0.1";
int32_t server_port = 6006;
// Sample rate of the input wave. No resampling is made.
int32_t sample_rate = 16000;
int32_t samples_per_message = 8000;
float seconds_per_message = 0.2;
sherpa_onnx::ParseOptions po(kUsageMessage);
po.Register("server-ip", &server_ip, "IP address of the websocket server");
po.Register("server-port", &server_port, "Port of the websocket server");
po.Register("sample-rate", &sample_rate,
"Sample rate of the input wave. Should be the one expected by "
"the server");
po.Register("samples-per-message", &samples_per_message,
"Send this number of samples per message.");
po.Register("seconds-per-message", &seconds_per_message,
"We will simulate that each message takes this number of seconds "
"to send. If you select a very large value, it will take a long "
"time to send all the samples");
po.Read(argc, argv);
if (!websocketpp::uri_helper::ipv4_literal(server_ip.begin(),
server_ip.end())) {
SHERPA_ONNX_LOGE("Invalid server IP: %s", server_ip.c_str());
return -1;
}
if (server_port <= 0 || server_port > 65535) {
SHERPA_ONNX_LOGE("Invalid server port: %d", server_port);
return -1;
}
// 0.01 is an arbitrary value. You can change it.
if (samples_per_message <= 0.01 * sample_rate) {
SHERPA_ONNX_LOGE("--samples-per-message is too small: %d",
samples_per_message);
return -1;
}
// 100 is an arbitrary value. You can change it.
if (samples_per_message >= sample_rate * 100) {
SHERPA_ONNX_LOGE("--samples-per-message is too small: %d",
samples_per_message);
return -1;
}
if (seconds_per_message < 0) {
SHERPA_ONNX_LOGE("--seconds-per-message is too small: %.3f",
seconds_per_message);
return -1;
}
// 1 is an arbitrary value.
if (seconds_per_message > 1) {
SHERPA_ONNX_LOGE(
"--seconds-per-message is too large: %.3f. You will wait a long time "
"to "
"send all the samples",
seconds_per_message);
return -1;
}
if (po.NumArgs() != 1) {
po.PrintUsage();
return -1;
}
std::string wave_filename = po.GetArg(1);
bool is_ok = false;
std::vector<float> samples =
sherpa_onnx::ReadWave(wave_filename, sample_rate, &is_ok);
if (!is_ok) {
SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str());
return -1;
}
asio::io_context io_conn; // for network connections
Client c(io_conn, server_ip, server_port, samples, samples_per_message,
seconds_per_message);
io_conn.run(); // will exit when the above connection is closed
SHERPA_ONNX_LOGE("Done!");
return 0;
}
... ...
// sherpa-onnx/csrc/online-websocket-server-impl.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-websocket-server-impl.h"
#include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/log.h"
namespace sherpa_onnx {
void OnlineWebsocketDecoderConfig::Register(ParseOptions *po) {
recognizer_config.Register(po);
po->Register("loop-interval-ms", &loop_interval_ms,
"It determines how often the decoder loop runs. ");
po->Register("max-batch-size", &max_batch_size,
"Max batch size for recognition.");
}
void OnlineWebsocketDecoderConfig::Validate() const {
recognizer_config.Validate();
SHERPA_ONNX_CHECK_GT(loop_interval_ms, 0);
SHERPA_ONNX_CHECK_GT(max_batch_size, 0);
}
void OnlineWebsocketServerConfig::Register(sherpa_onnx::ParseOptions *po) {
decoder_config.Register(po);
po->Register("log-file", &log_file,
"Path to the log file. Logs are "
"appended to this file");
}
void OnlineWebsocketServerConfig::Validate() const {
decoder_config.Validate();
}
OnlineWebsocketDecoder::OnlineWebsocketDecoder(OnlineWebsocketServer *server)
: server_(server),
config_(server->GetConfig().decoder_config),
timer_(server->GetWorkContext()) {
recognizer_ = std::make_unique<OnlineRecognizer>(config_.recognizer_config);
}
std::shared_ptr<Connection> OnlineWebsocketDecoder::GetOrCreateConnection(
connection_hdl hdl) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = connections_.find(hdl);
if (it != connections_.end()) {
return it->second;
} else {
// create a new connection
std::shared_ptr<OnlineStream> s = recognizer_->CreateStream();
auto c = std::make_shared<Connection>(hdl, s);
connections_.insert({hdl, c});
return c;
}
}
void OnlineWebsocketDecoder::AcceptWaveform(std::shared_ptr<Connection> c) {
std::lock_guard<std::mutex> lock(c->mutex);
float sample_rate = config_.recognizer_config.feat_config.sampling_rate;
while (!c->samples.empty()) {
const auto &s = c->samples.front();
c->s->AcceptWaveform(sample_rate, s.data(), s.size());
c->samples.pop_front();
}
}
void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) {
std::lock_guard<std::mutex> lock(c->mutex);
float sample_rate = config_.recognizer_config.feat_config.sampling_rate;
while (!c->samples.empty()) {
const auto &s = c->samples.front();
c->s->AcceptWaveform(sample_rate, s.data(), s.size());
c->samples.pop_front();
}
// TODO(fangjun): Change the amount of paddings to be configurable
std::vector<float> tail_padding(static_cast<int64_t>(0.8 * sample_rate));
c->s->AcceptWaveform(sample_rate, tail_padding.data(), tail_padding.size());
c->s->InputFinished();
c->eof = true;
}
void OnlineWebsocketDecoder::Run() {
timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));
timer_.async_wait(
[this](const asio::error_code &ec) { ProcessConnections(ec); });
}
void OnlineWebsocketDecoder::ProcessConnections(const asio::error_code &ec) {
if (ec) {
SHERPA_ONNX_LOG(FATAL) << "The decoder loop is aborted!";
}
std::lock_guard<std::mutex> lock(mutex_);
std::vector<connection_hdl> to_remove;
for (auto &p : connections_) {
auto hdl = p.first;
auto c = p.second;
// The order of `if` below matters!
if (!server_->Contains(hdl)) {
// If the connection is disconnected, we stop processing it
to_remove.push_back(hdl);
continue;
}
if (active_.count(hdl)) {
// Another thread is decoding this stream, so skip it
continue;
}
if (!recognizer_->IsReady(c->s.get()) && !c->eof) {
// this stream has not enough frames to decode, so skip it
continue;
}
if (!recognizer_->IsReady(c->s.get()) && c->eof) {
// We won't receive samples from the client, so send a Done! to client
asio::post(server_->GetWorkContext(),
[this, hdl = c->hdl]() { server_->Send(hdl, "Done!"); });
to_remove.push_back(hdl);
continue;
}
// TODO(fangun): If the connection is timed out, we need to also
// add it to `to_remove`
// this stream has enough frames and is currently not processed by any
// threads, so put it into the ready queue
ready_connections_.push_back(c);
// In `Decode()`, it will remove hdl from `active_`
active_.insert(c->hdl);
}
for (auto hdl : to_remove) {
connections_.erase(hdl);
}
if (!ready_connections_.empty()) {
asio::post(server_->GetWorkContext(), [this]() { Decode(); });
}
// Schedule another call
timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));
timer_.async_wait(
[this](const asio::error_code &ec) { ProcessConnections(ec); });
}
void OnlineWebsocketDecoder::Decode() {
std::unique_lock<std::mutex> lock(mutex_);
if (ready_connections_.empty()) {
// There are no connections that are ready for decoding,
// so we return directly
return;
}
std::vector<std::shared_ptr<Connection>> c_vec;
std::vector<OnlineStream *> s_vec;
while (!ready_connections_.empty() &&
static_cast<int32_t>(s_vec.size()) < config_.max_batch_size) {
auto c = ready_connections_.front();
ready_connections_.pop_front();
c_vec.push_back(c);
s_vec.push_back(c->s.get());
}
if (!ready_connections_.empty()) {
// there are too many ready connections but this thread can only handle
// max_batch_size connections at a time, so we schedule another call
// to Decode() and let other threads to process the ready connections
asio::post(server_->GetWorkContext(), [this]() { Decode(); });
}
lock.unlock();
recognizer_->DecodeStreams(s_vec.data(), s_vec.size());
lock.lock();
for (auto c : c_vec) {
auto result = recognizer_->GetResult(c->s.get());
asio::post(server_->GetConnectionContext(),
[this, hdl = c->hdl, str = result.ToString()]() {
server_->Send(hdl, str);
});
active_.erase(c->hdl);
}
}
OnlineWebsocketServer::OnlineWebsocketServer(
asio::io_context &io_conn, asio::io_context &io_work,
const OnlineWebsocketServerConfig &config)
: config_(config),
io_conn_(io_conn),
io_work_(io_work),
log_(config.log_file, std::ios::app),
tee_(std::cout, log_),
decoder_(this) {
SetupLog();
server_.init_asio(&io_conn_);
server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); });
server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); });
server_.set_message_handler(
[this](connection_hdl hdl, server::message_ptr msg) {
OnMessage(hdl, msg);
});
}
void OnlineWebsocketServer::Run(uint16_t port) {
server_.set_reuse_addr(true);
server_.listen(asio::ip::tcp::v4(), port);
server_.start_accept();
decoder_.Run();
}
void OnlineWebsocketServer::SetupLog() {
server_.clear_access_channels(websocketpp::log::alevel::all);
// server_.set_access_channels(websocketpp::log::alevel::connect);
// server_.set_access_channels(websocketpp::log::alevel::disconnect);
// So that it also prints to std::cout and std::cerr
server_.get_alog().set_ostream(&tee_);
server_.get_elog().set_ostream(&tee_);
}
void OnlineWebsocketServer::Send(connection_hdl hdl, const std::string &text) {
websocketpp::lib::error_code ec;
if (!Contains(hdl)) {
return;
}
server_.send(hdl, text, websocketpp::frame::opcode::text, ec);
if (ec) {
server_.get_alog().write(websocketpp::log::alevel::app, ec.message());
}
}
void OnlineWebsocketServer::OnOpen(connection_hdl hdl) {
std::lock_guard<std::mutex> lock(mutex_);
connections_.insert(hdl);
std::ostringstream os;
os << "New connection: "
<< server_.get_con_from_hdl(hdl)->get_remote_endpoint() << ". "
<< "Number of active connections: " << connections_.size() << ".\n";
SHERPA_ONNX_LOG(INFO) << os.str();
}
void OnlineWebsocketServer::OnClose(connection_hdl hdl) {
std::lock_guard<std::mutex> lock(mutex_);
connections_.erase(hdl);
SHERPA_ONNX_LOG(INFO) << "Number of active connections: "
<< connections_.size() << "\n";
}
bool OnlineWebsocketServer::Contains(connection_hdl hdl) const {
std::lock_guard<std::mutex> lock(mutex_);
return connections_.count(hdl);
}
void OnlineWebsocketServer::OnMessage(connection_hdl hdl,
server::message_ptr msg) {
auto c = decoder_.GetOrCreateConnection(hdl);
const std::string &payload = msg->get_payload();
switch (msg->get_opcode()) {
case websocketpp::frame::opcode::text:
if (payload == "Done") {
asio::post(io_work_, [this, c]() { decoder_.InputFinished(c); });
}
break;
case websocketpp::frame::opcode::binary: {
auto p = reinterpret_cast<const float *>(payload.data());
int32_t num_samples = payload.size() / sizeof(float);
std::vector<float> samples(p, p + num_samples);
c->samples.push_back(std::move(samples));
asio::post(io_work_, [this, c]() { decoder_.AcceptWaveform(c); });
break;
}
default:
break;
}
}
void OnlineWebsocketServer::Close(connection_hdl hdl,
websocketpp::close::status::value code,
const std::string &reason) {
auto con = server_.get_con_from_hdl(hdl);
std::ostringstream os;
os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason
<< "\n";
websocketpp::lib::error_code ec;
server_.close(hdl, code, reason, ec);
if (ec) {
os << "Failed to close" << con->get_remote_endpoint() << ". "
<< ec.message() << "\n";
}
server_.get_alog().write(websocketpp::log::alevel::app, os.str());
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-websocket-server-impl.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
#include <deque>
#include <fstream>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "asio.hpp"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/tee-stream.h"
#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS
#include "websocketpp/server.hpp"
using server = websocketpp::server<websocketpp::config::asio>;
using connection_hdl = websocketpp::connection_hdl;
namespace sherpa_onnx {
struct Connection {
// handle to the connection. We can use it to send messages to the client
connection_hdl hdl;
std::shared_ptr<OnlineStream> s;
// set it to true when InputFinished() is called
bool eof = false;
// The last time we received a message from the client
// TODO(fangjun): Use it to disconnect from a client if it is inactive
// for a specified time.
std::chrono::steady_clock::time_point last_active;
std::mutex mutex; // protect samples
// Audio samples received from the client.
//
// The I/O threads receive audio samples into this queue
// and invoke work threads to compute features
std::deque<std::vector<float>> samples;
Connection() = default;
Connection(connection_hdl hdl, std::shared_ptr<OnlineStream> s)
: hdl(hdl), s(s), last_active(std::chrono::steady_clock::now()) {}
};
struct OnlineWebsocketDecoderConfig {
OnlineRecognizerConfig recognizer_config;
// It determines how often the decoder loop runs.
int32_t loop_interval_ms = 10;
int32_t max_batch_size = 5;
void Register(ParseOptions *po);
void Validate() const;
};
class OnlineWebsocketServer;
class OnlineWebsocketDecoder {
public:
/**
* @param server Not owned.
*/
explicit OnlineWebsocketDecoder(OnlineWebsocketServer *server);
std::shared_ptr<Connection> GetOrCreateConnection(connection_hdl hdl);
// Compute features for a stream given audio samples
void AcceptWaveform(std::shared_ptr<Connection> c);
// signal that there will be no more audio samples for a stream
void InputFinished(std::shared_ptr<Connection> c);
void Run();
private:
void ProcessConnections(const asio::error_code &ec);
/** It is called by one of the worker thread.
*/
void Decode();
private:
OnlineWebsocketServer *server_; // not owned
std::unique_ptr<OnlineRecognizer> recognizer_;
OnlineWebsocketDecoderConfig config_;
asio::steady_timer timer_;
// It protects `connections_`, `ready_connections_`, and `active_`
std::mutex mutex_;
std::map<connection_hdl, std::shared_ptr<Connection>,
std::owner_less<connection_hdl>>
connections_;
// Whenever a connection has enough feature frames for decoding, we put
// it in this queue
std::deque<std::shared_ptr<Connection>> ready_connections_;
// If we are decoding a stream, we put it in the active_ set so that
// only one thread can decode a stream at a time.
std::set<connection_hdl, std::owner_less<connection_hdl>> active_;
};
struct OnlineWebsocketServerConfig {
OnlineWebsocketDecoderConfig decoder_config;
std::string log_file = "./log.txt";
void Register(sherpa_onnx::ParseOptions *po);
void Validate() const;
};
class OnlineWebsocketServer {
public:
explicit OnlineWebsocketServer(asio::io_context &io_conn, // NOLINT
asio::io_context &io_work, // NOLINT
const OnlineWebsocketServerConfig &config);
void Run(uint16_t port);
const OnlineWebsocketServerConfig &GetConfig() const { return config_; }
asio::io_context &GetConnectionContext() { return io_conn_; }
asio::io_context &GetWorkContext() { return io_work_; }
server &GetServer() { return server_; }
void Send(connection_hdl hdl, const std::string &text);
bool Contains(connection_hdl hdl) const;
private:
void SetupLog();
// When a websocket client is connected, it will invoke this method
// (Not for HTTP)
void OnOpen(connection_hdl hdl);
// When a websocket client is disconnected, it will invoke this method
void OnClose(connection_hdl hdl);
void OnMessage(connection_hdl hdl, server::message_ptr msg);
// Close a websocket connection with given code and reason
void Close(connection_hdl hdl, websocketpp::close::status::value code,
const std::string &reason);
private:
OnlineWebsocketServerConfig config_;
asio::io_context &io_conn_;
asio::io_context &io_work_;
server server_;
std::ofstream log_;
sherpa_onnx::TeeStream tee_;
OnlineWebsocketDecoder decoder_;
mutable std::mutex mutex_;
std::set<connection_hdl, std::owner_less<connection_hdl>> connections_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
... ...
// sherpa-onnx/csrc/online-websocket-server.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "asio.hpp"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-websocket-server-impl.h"
#include "sherpa-onnx/csrc/parse-options.h"
static constexpr const char *kUsageMessage = R"(
Automatic speech recognition with sherpa-onnx using websocket.
Usage:
./bin/sherpa-onnx-online-websocket-server --help
./bin/sherpa-onnx-online-websocket-server \
--port=6006 \
--num-work-threads=5 \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--log-file=./log.txt \
--max-batch-size=5 \
--loop-interval-ms=10
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download.
)";
int32_t main(int32_t argc, char *argv[]) {
sherpa_onnx::ParseOptions po(kUsageMessage);
sherpa_onnx::OnlineWebsocketServerConfig config;
// the server will listen on this port
int32_t port = 6006;
// size of the thread pool for handling network connections
int32_t num_io_threads = 1;
// size of the thread pool for neural network computation and decoding
int32_t num_work_threads = 3;
po.Register("num-io-threads", &num_io_threads,
"Thread pool size for network connections.");
po.Register("num-work-threads", &num_work_threads,
"Thread pool size for for neural network "
"computation and decoding.");
po.Register("port", &port, "The port on which the server will listen.");
config.Register(&po);
if (argc == 1) {
po.PrintUsage();
exit(EXIT_FAILURE);
}
po.Read(argc, argv);
if (po.NumArgs() != 0) {
SHERPA_ONNX_LOGE("Unrecognized positional arguments!");
po.PrintUsage();
exit(EXIT_FAILURE);
}
config.Validate();
asio::io_context io_conn; // for network connections
asio::io_context io_work; // for neural network and decoding
sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config);
server.Run(port);
SHERPA_ONNX_LOGE("Listening on: %d", port);
SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);
// give some work to do for the io_work pool
auto work_guard = asio::make_work_guard(io_work);
std::vector<std::thread> io_threads;
// decrement since the main thread is also used for network communications
for (int32_t i = 0; i < num_io_threads - 1; ++i) {
io_threads.emplace_back([&io_conn]() { io_conn.run(); });
}
std::vector<std::thread> work_threads;
for (int32_t i = 0; i < num_work_threads; ++i) {
work_threads.emplace_back([&io_work]() { io_work.run(); });
}
io_conn.run();
for (auto &t : io_threads) {
t.join();
}
for (auto &t : work_threads) {
t.join();
}
return 0;
}
... ...
// Code in this file is copied and modified from
// https://wordaligned.org/articles/cpp-streambufs
#ifndef SHERPA_ONNX_CSRC_TEE_STREAM_H_
#define SHERPA_ONNX_CSRC_TEE_STREAM_H_
#include <ostream>
#include <streambuf>
#include <string>
namespace sherpa_onnx {
template <typename char_type, typename traits = std::char_traits<char_type>>
class basic_teebuf : public std::basic_streambuf<char_type, traits> {
public:
using int_type = typename traits::int_type;
basic_teebuf(std::basic_streambuf<char_type, traits> *sb1,
std::basic_streambuf<char_type, traits> *sb2)
: sb1(sb1), sb2(sb2) {}
private:
int sync() override {
int const r1 = sb1->pubsync();
int const r2 = sb2->pubsync();
return r1 == 0 && r2 == 0 ? 0 : -1;
}
int_type overflow(int_type c) override {
int_type const eof = traits::eof();
if (traits::eq_int_type(c, eof)) {
return traits::not_eof(c);
} else {
char_type const ch = traits::to_char_type(c);
int_type const r1 = sb1->sputc(ch);
int_type const r2 = sb2->sputc(ch);
return traits::eq_int_type(r1, eof) || traits::eq_int_type(r2, eof) ? eof
: c;
}
}
private:
std::basic_streambuf<char_type, traits> *sb1;
std::basic_streambuf<char_type, traits> *sb2;
};
using teebuf = basic_teebuf<char>;
class TeeStream : public std::ostream {
public:
TeeStream(std::ostream &o1, std::ostream &o2)
: std::ostream(&tbuf), tbuf(o1.rdbuf(), o2.rdbuf()) {}
private:
teebuf tbuf;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_TEE_STREAM_H_
... ...