Fangjun Kuang
Committed by GitHub

Add endpointing (#54)

@@ -4,9 +4,11 @@ build @@ -4,9 +4,11 @@ build
4 onnxruntime-* 4 onnxruntime-*
5 icefall-* 5 icefall-*
6 run.sh 6 run.sh
7 -sherpa-onnx-*  
8 __pycache__ 7 __pycache__
9 dist/ 8 dist/
10 sherpa_onnx.egg-info/ 9 sherpa_onnx.egg-info/
11 .DS_Store 10 .DS_Store
12 build-aarch64-linux-gnu 11 build-aarch64-linux-gnu
  12 +sherpa-onnx-streaming-zipformer-*
  13 +sherpa-onnx-lstm-en-*
  14 +sherpa-onnx-lstm-zh-*
@@ -13,6 +13,7 @@ endif() @@ -13,6 +13,7 @@ endif()
13 13
14 option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF) 14 option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF)
15 option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF) 15 option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF)
  16 +option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON)
16 option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) 17 option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
17 18
18 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") 19 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
@@ -46,6 +47,8 @@ message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") @@ -46,6 +47,8 @@ message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
46 message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}") 47 message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
47 message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}") 48 message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
48 message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}") 49 message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}")
  50 +message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}")
  51 +message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}")
49 52
50 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") 53 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
51 set(CMAKE_CXX_EXTENSIONS OFF) 54 set(CMAKE_CXX_EXTENSIONS OFF)
@@ -56,6 +59,9 @@ if(SHERPA_ONNX_HAS_ALSA) @@ -56,6 +59,9 @@ if(SHERPA_ONNX_HAS_ALSA)
56 add_definitions(-DSHERPA_ONNX_ENABLE_ALSA=1) 59 add_definitions(-DSHERPA_ONNX_ENABLE_ALSA=1)
57 endif() 60 endif()
58 61
  62 +check_include_file_cxx(cxxabi.h SHERPA_ONNX_HAVE_CXXABI_H)
  63 +check_include_file_cxx(execinfo.h SHERPA_ONNX_HAVE_EXECINFO_H)
  64 +
59 list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) 65 list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
60 list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) 66 list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
61 67
  1 +#!/usr/bin/env python3
  2 +
  3 +# Real-time speech recognition from a microphone with sherpa-onnx Python API
  4 +# with endpoint detection.
  5 +#
  6 +# Please refer to
  7 +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  8 +# to download pre-trained models
  9 +
  10 +import sys
  11 +
  12 +try:
  13 + import sounddevice as sd
  14 +except ImportError as e:
  15 + print("Please install sounddevice first. You can use")
  16 + print()
  17 + print(" pip install sounddevice")
  18 + print()
  19 + print("to install it")
  20 + sys.exit(-1)
  21 +
  22 +import sherpa_onnx
  23 +
  24 +
  25 +def create_recognizer():
  26 + # Please replace the model files if needed.
  27 + # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  28 + # for download links.
  29 + recognizer = sherpa_onnx.OnlineRecognizer(
  30 + tokens="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt",
  31 + encoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx",
  32 + decoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx",
  33 + joiner="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx",
  34 + num_threads=4,
  35 + sample_rate=16000,
  36 + feature_dim=80,
  37 + enable_endpoint_detection=True,
  38 + rule1_min_trailing_silence=2.4,
  39 + rule2_min_trailing_silence=1.2,
  40 + rule3_min_utterance_length=300, # it essentially disables this rule
  41 + )
  42 + return recognizer
  43 +
  44 +
  45 +def main():
  46 + print("Started! Please speak")
  47 + recognizer = create_recognizer()
  48 + sample_rate = 16000
  49 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
  50 + last_result = ""
  51 + stream = recognizer.create_stream()
  52 +
  53 + last_result = ""
  54 + segment_id = 0
  55 + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
  56 + while True:
  57 + samples, _ = s.read(samples_per_read) # a blocking read
  58 + samples = samples.reshape(-1)
  59 + stream.accept_waveform(sample_rate, samples)
  60 + while recognizer.is_ready(stream):
  61 + recognizer.decode_stream(stream)
  62 +
  63 + is_endpoint = recognizer.is_endpoint(stream)
  64 +
  65 + result = recognizer.get_result(stream)
  66 +
  67 + if result and (last_result != result):
  68 + last_result = result
  69 + print(f"{segment_id}: {result}")
  70 +
  71 + if result and is_endpoint:
  72 + segment_id += 1
  73 + recognizer.reset(stream)
  74 +
  75 +
  76 +if __name__ == "__main__":
  77 + devices = sd.query_devices()
  78 + print(devices)
  79 + default_input_device_idx = sd.default.device[0]
  80 + print(f'Use default device: {devices[default_input_device_idx]["name"]}')
  81 +
  82 + try:
  83 + main()
  84 + except KeyboardInterrupt:
  85 + print("\nCaught Ctrl + C. Exiting")
1 include_directories(${CMAKE_SOURCE_DIR}) 1 include_directories(${CMAKE_SOURCE_DIR})
2 2
3 -add_library(sherpa-onnx-core 3 +set(sources
4 cat.cc 4 cat.cc
  5 + endpoint.cc
5 features.cc 6 features.cc
6 online-lstm-transducer-model.cc 7 online-lstm-transducer-model.cc
7 online-recognizer.cc 8 online-recognizer.cc
@@ -11,6 +12,7 @@ add_library(sherpa-onnx-core @@ -11,6 +12,7 @@ add_library(sherpa-onnx-core
11 online-transducer-model.cc 12 online-transducer-model.cc
12 online-zipformer-transducer-model.cc 13 online-zipformer-transducer-model.cc
13 onnx-utils.cc 14 onnx-utils.cc
  15 + parse-options.cc
14 resample.cc 16 resample.cc
15 symbol-table.cc 17 symbol-table.cc
16 text-utils.cc 18 text-utils.cc
@@ -18,11 +20,29 @@ add_library(sherpa-onnx-core @@ -18,11 +20,29 @@ add_library(sherpa-onnx-core
18 wave-reader.cc 20 wave-reader.cc
19 ) 21 )
20 22
  23 +if(SHERPA_ONNX_ENABLE_CHECK)
  24 + list(APPEND sources log.cc)
  25 +endif()
  26 +
  27 +add_library(sherpa-onnx-core ${sources})
  28 +
21 target_link_libraries(sherpa-onnx-core 29 target_link_libraries(sherpa-onnx-core
22 onnxruntime 30 onnxruntime
23 kaldi-native-fbank-core 31 kaldi-native-fbank-core
24 ) 32 )
25 33
  34 +if(SHERPA_ONNX_ENABLE_CHECK)
  35 + target_compile_definitions(sherpa-onnx-core PUBLIC SHERPA_ONNX_ENABLE_CHECK=1)
  36 +
  37 + if(SHERPA_ONNX_HAVE_EXECINFO_H)
  38 + target_compile_definitions(sherpa-onnx-core PRIVATE SHERPA_ONNX_HAVE_EXECINFO_H=1)
  39 + endif()
  40 +
  41 + if(SHERPA_ONNX_HAVE_CXXABI_H)
  42 + target_compile_definitions(sherpa-onnx-core PRIVATE SHERPA_ONNX_HAVE_CXXABI_H=1)
  43 + endif()
  44 +endif()
  45 +
26 add_executable(sherpa-onnx sherpa-onnx.cc) 46 add_executable(sherpa-onnx sherpa-onnx.cc)
27 47
28 target_link_libraries(sherpa-onnx sherpa-onnx-core) 48 target_link_libraries(sherpa-onnx sherpa-onnx-core)
  1 +// sherpa-onnx/csrc/endpoint.cc
  2 +//
  3 +// Copyright (c) 2022 (authors: Pingfeng Luo)
  4 +// 2022-2023 Xiaomi Corporation
  5 +
  6 +#include "sherpa-onnx/csrc/endpoint.h"
  7 +
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/log.h"
  11 +#include "sherpa-onnx/csrc/parse-options.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +static bool RuleActivated(const EndpointRule &rule,
  16 + const std::string &rule_name, float trailing_silence,
  17 + float utterance_length) {
  18 + bool contain_nonsilence = utterance_length > trailing_silence;
  19 + bool ans = (contain_nonsilence || !rule.must_contain_nonsilence) &&
  20 + trailing_silence >= rule.min_trailing_silence &&
  21 + utterance_length >= rule.min_utterance_length;
  22 + if (ans) {
  23 + SHERPA_ONNX_LOG(DEBUG) << "Endpointing rule " << rule_name << " activated: "
  24 + << (contain_nonsilence ? "true" : "false") << ','
  25 + << trailing_silence << ',' << utterance_length;
  26 + }
  27 + return ans;
  28 +}
  29 +
  30 +static void RegisterEndpointRule(ParseOptions *po, EndpointRule *rule,
  31 + const std::string &rule_name) {
  32 + po->Register(
  33 + rule_name + "-must-contain-nonsilence", &rule->must_contain_nonsilence,
  34 + "If True, for this endpointing " + rule_name +
  35 + " to apply there must be nonsilence in the best-path traceback. "
  36 + "For decoding, a non-blank token is considered as non-silence");
  37 + po->Register(rule_name + "-min-trailing-silence", &rule->min_trailing_silence,
  38 + "This endpointing " + rule_name +
  39 + " requires duration of trailing silence in seconds) to "
  40 + "be >= this value.");
  41 + po->Register(rule_name + "-min-utterance-length", &rule->min_utterance_length,
  42 + "This endpointing " + rule_name +
  43 + " requires utterance-length (in seconds) to be >= this "
  44 + "value.");
  45 +}
  46 +
  47 +std::string EndpointRule::ToString() const {
  48 + std::ostringstream os;
  49 +
  50 + os << "EndpointRule(";
  51 + os << "must_contain_nonsilence="
  52 + << (must_contain_nonsilence ? "True" : "False") << ", ";
  53 + os << "min_trailing_silence=" << min_trailing_silence << ", ";
  54 + os << "min_utterance_length=" << min_utterance_length << ")";
  55 +
  56 + return os.str();
  57 +}
  58 +
  59 +void EndpointConfig::Register(ParseOptions *po) {
  60 + RegisterEndpointRule(po, &rule1, "rule1");
  61 + RegisterEndpointRule(po, &rule2, "rule2");
  62 + RegisterEndpointRule(po, &rule3, "rule3");
  63 +}
  64 +
  65 +std::string EndpointConfig::ToString() const {
  66 + std::ostringstream os;
  67 +
  68 + os << "EndpointConfig(";
  69 + os << "rule1=" << rule1.ToString() << ", ";
  70 + os << "rule2=" << rule2.ToString() << ", ";
  71 + os << "rule3=" << rule3.ToString() << ")";
  72 +
  73 + return os.str();
  74 +}
  75 +
  76 +bool Endpoint::IsEndpoint(int num_frames_decoded, int trailing_silence_frames,
  77 + float frame_shift_in_seconds) const {
  78 + float utterance_length = num_frames_decoded * frame_shift_in_seconds;
  79 + float trailing_silence = trailing_silence_frames * frame_shift_in_seconds;
  80 + if (RuleActivated(config_.rule1, "rule1", trailing_silence,
  81 + utterance_length) ||
  82 + RuleActivated(config_.rule2, "rule2", trailing_silence,
  83 + utterance_length) ||
  84 + RuleActivated(config_.rule3, "rule3", trailing_silence,
  85 + utterance_length)) {
  86 + return true;
  87 + }
  88 + return false;
  89 +}
  90 +
  91 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/endpoint.h
  2 +//
  3 +// Copyright (c) 2022 (authors: Pingfeng Luo)
  4 +// 2022-2023 Xiaomi Corporation
  5 +
  6 +#ifndef SHERPA_ONNX_CSRC_ENDPOINT_H_
  7 +#define SHERPA_ONNX_CSRC_ENDPOINT_H_
  8 +
  9 +#include <string>
  10 +#include <vector>
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct EndpointRule {
  15 + // If True, for this endpointing rule to apply there must
  16 + // be nonsilence in the best-path traceback.
  17 + // For decoding, a non-blank token is considered as non-silence
  18 + bool must_contain_nonsilence = true;
  19 + // This endpointing rule requires duration of trailing silence
  20 + // (in seconds) to be >= this value.
  21 + float min_trailing_silence = 2.0;
  22 + // This endpointing rule requires utterance-length (in seconds)
  23 + // to be >= this value.
  24 + float min_utterance_length = 0.0f;
  25 +
  26 + EndpointRule() = default;
  27 +
  28 + EndpointRule(bool must_contain_nonsilence, float min_trailing_silence,
  29 + float min_utterance_length)
  30 + : must_contain_nonsilence(must_contain_nonsilence),
  31 + min_trailing_silence(min_trailing_silence),
  32 + min_utterance_length(min_utterance_length) {}
  33 +
  34 + std::string ToString() const;
  35 +};
  36 +
  37 +class ParseOptions;
  38 +
  39 +struct EndpointConfig {
  40 + // For default setting,
  41 + // rule1 times out after 2.4 seconds of silence, even if we decoded nothing.
  42 + // rule2 times out after 1.2 seconds of silence after decoding something.
  43 + // rule3 times out after the utterance is 20 seconds long, regardless of
  44 + // anything else.
  45 + EndpointRule rule1;
  46 + EndpointRule rule2;
  47 + EndpointRule rule3;
  48 +
  49 + void Register(ParseOptions *po);
  50 +
  51 + EndpointConfig()
  52 + : rule1{false, 2.4, 0}, rule2{true, 1.2, 0}, rule3{false, 0, 20} {}
  53 +
  54 + EndpointConfig(const EndpointRule &rule1, const EndpointRule &rule2,
  55 + const EndpointRule &rule3)
  56 + : rule1(rule1), rule2(rule2), rule3(rule3) {}
  57 +
  58 + std::string ToString() const;
  59 +};
  60 +
  61 +class Endpoint {
  62 + public:
  63 + explicit Endpoint(const EndpointConfig &config) : config_(config) {}
  64 +
  65 + /// This function returns true if this set of endpointing rules thinks we
  66 + /// should terminate decoding.
  67 + bool IsEndpoint(int num_frames_decoded, int trailing_silence_frames,
  68 + float frame_shift_in_seconds) const;
  69 +
  70 + private:
  71 + EndpointConfig config_;
  72 +};
  73 +
  74 +} // namespace sherpa_onnx
  75 +
  76 +#endif // SHERPA_ONNX_CSRC_ENDPOINT_H_
  1 +// sherpa-onnx/csrc/log.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/log.h"
  6 +
  7 +#ifdef SHERPA_ONNX_HAVE_EXECINFO_H
  8 +#include <execinfo.h> // To get stack trace in error messages.
  9 +#ifdef SHERPA_ONNX_HAVE_CXXABI_H
  10 +#include <cxxabi.h> // For name demangling.
  11 +// Useful to decode the stack trace, but only used if we have execinfo.h
  12 +#endif // SHERPA_ONNX_HAVE_CXXABI_H
  13 +#endif // SHERPA_ONNX_HAVE_EXECINFO_H
  14 +
  15 +#include <stdlib.h>
  16 +
  17 +#include <ctime>
  18 +#include <iomanip>
  19 +#include <string>
  20 +
  21 +namespace sherpa_onnx {
  22 +
  23 +std::string GetDateTimeStr() {
  24 + std::ostringstream os;
  25 + std::time_t t = std::time(nullptr);
  26 + std::tm tm = *std::localtime(&t);
  27 + os << std::put_time(&tm, "%F %T"); // yyyy-mm-dd hh:mm:ss
  28 + return os.str();
  29 +}
  30 +
  31 +static bool LocateSymbolRange(const std::string &trace_name, std::size_t *begin,
  32 + std::size_t *end) {
  33 + // Find the first '_' with leading ' ' or '('.
  34 + *begin = std::string::npos;
  35 + for (std::size_t i = 1; i < trace_name.size(); ++i) {
  36 + if (trace_name[i] != '_') {
  37 + continue;
  38 + }
  39 + if (trace_name[i - 1] == ' ' || trace_name[i - 1] == '(') {
  40 + *begin = i;
  41 + break;
  42 + }
  43 + }
  44 + if (*begin == std::string::npos) {
  45 + return false;
  46 + }
  47 + *end = trace_name.find_first_of(" +", *begin);
  48 + return *end != std::string::npos;
  49 +}
  50 +
  51 +#ifdef SHERPA_ONNX_HAVE_EXECINFO_H
  52 +static std::string Demangle(const std::string &trace_name) {
  53 +#ifndef SHERPA_ONNX_HAVE_CXXABI_H
  54 + return trace_name;
  55 +#else // SHERPA_ONNX_HAVE_CXXABI_H
  56 + // Try demangle the symbol. We are trying to support the following formats
  57 + // produced by different platforms:
  58 + //
  59 + // Linux:
  60 + // ./kaldi-error-test(_ZN5kaldi13UnitTestErrorEv+0xb) [0x804965d]
  61 + //
  62 + // Mac:
  63 + // 0 server 0x000000010f67614d _ZNK5kaldi13MessageLogger10LogMessageEv + 813
  64 + //
  65 + // We want to extract the name e.g., '_ZN5kaldi13UnitTestErrorEv' and
  66 + // demangle it info a readable name like kaldi::UnitTextError.
  67 + std::size_t begin, end;
  68 + if (!LocateSymbolRange(trace_name, &begin, &end)) {
  69 + return trace_name;
  70 + }
  71 + std::string symbol = trace_name.substr(begin, end - begin);
  72 + int status;
  73 + char *demangled_name = abi::__cxa_demangle(symbol.c_str(), 0, 0, &status);
  74 + if (status == 0 && demangled_name != nullptr) {
  75 + symbol = demangled_name;
  76 + free(demangled_name);
  77 + }
  78 + return trace_name.substr(0, begin) + symbol +
  79 + trace_name.substr(end, std::string::npos);
  80 +#endif // SHERPA_ONNX_HAVE_CXXABI_H
  81 +}
  82 +#endif // SHERPA_ONNX_HAVE_EXECINFO_H
  83 +
  84 +std::string GetStackTrace() {
  85 + std::string ans;
  86 +#ifdef SHERPA_ONNX_HAVE_EXECINFO_H
  87 + constexpr const std::size_t kMaxTraceSize = 50;
  88 + constexpr const std::size_t kMaxTracePrint = 50; // Must be even.
  89 + // Buffer for the trace.
  90 + void *trace[kMaxTraceSize];
  91 + // Get the trace.
  92 + std::size_t size = backtrace(trace, kMaxTraceSize);
  93 + // Get the trace symbols.
  94 + char **trace_symbol = backtrace_symbols(trace, size);
  95 + if (trace_symbol == nullptr) return ans;
  96 +
  97 + // Compose a human-readable backtrace string.
  98 + ans += "[ Stack-Trace: ]\n";
  99 + if (size <= kMaxTracePrint) {
  100 + for (std::size_t i = 0; i < size; ++i) {
  101 + ans += Demangle(trace_symbol[i]) + "\n";
  102 + }
  103 + } else { // Print out first+last (e.g.) 5.
  104 + for (std::size_t i = 0; i < kMaxTracePrint / 2; ++i) {
  105 + ans += Demangle(trace_symbol[i]) + "\n";
  106 + }
  107 + ans += ".\n.\n.\n";
  108 + for (std::size_t i = size - kMaxTracePrint / 2; i < size; ++i) {
  109 + ans += Demangle(trace_symbol[i]) + "\n";
  110 + }
  111 + if (size == kMaxTraceSize)
  112 + ans += ".\n.\n.\n"; // Stack was too long, probably a bug.
  113 + }
  114 +
  115 + // We must free the array of pointers allocated by backtrace_symbols(),
  116 + // but not the strings themselves.
  117 + free(trace_symbol);
  118 +#endif // SHERPA_ONNX_HAVE_EXECINFO_H
  119 + return ans;
  120 +}
  121 +
  122 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/log.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_LOG_H_
  6 +#define SHERPA_ONNX_CSRC_LOG_H_
  7 +
  8 +#include <stdio.h>
  9 +
  10 +#include <mutex> // NOLINT
  11 +#include <sstream>
  12 +#include <string>
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +#if SHERPA_ONNX_ENABLE_CHECK
  17 +
  18 +#if defined(NDEBUG)
  19 +constexpr bool kDisableDebug = true;
  20 +#else
  21 +constexpr bool kDisableDebug = false;
  22 +#endif
  23 +
  24 +enum class LogLevel {
  25 + kTrace = 0,
  26 + kDebug = 1,
  27 + kInfo = 2,
  28 + kWarning = 3,
  29 + kError = 4,
  30 + kFatal = 5, // print message and abort the program
  31 +};
  32 +
  33 +// They are used in SHERPA_ONNX_LOG(xxx), so their names
  34 +// do not follow the google c++ code style
  35 +//
  36 +// You can use them in the following way:
  37 +//
  38 +// SHERPA_ONNX_LOG(TRACE) << "some message";
  39 +// SHERPA_ONNX_LOG(DEBUG) << "some message";
  40 +#ifndef _MSC_VER
  41 +constexpr LogLevel TRACE = LogLevel::kTrace;
  42 +constexpr LogLevel DEBUG = LogLevel::kDebug;
  43 +constexpr LogLevel INFO = LogLevel::kInfo;
  44 +constexpr LogLevel WARNING = LogLevel::kWarning;
  45 +constexpr LogLevel ERROR = LogLevel::kError;
  46 +constexpr LogLevel FATAL = LogLevel::kFatal;
  47 +#else
  48 +#define TRACE LogLevel::kTrace
  49 +#define DEBUG LogLevel::kDebug
  50 +#define INFO LogLevel::kInfo
  51 +#define WARNING LogLevel::kWarning
  52 +#define ERROR LogLevel::kError
  53 +#define FATAL LogLevel::kFatal
  54 +#endif
  55 +
  56 +std::string GetStackTrace();
  57 +
  58 +/* Return the current log level.
  59 +
  60 +
  61 + If the current log level is TRACE, then all logged messages are printed out.
  62 +
  63 + If the current log level is DEBUG, log messages with "TRACE" level are not
  64 + shown and all other levels are printed out.
  65 +
  66 + Similarly, if the current log level is INFO, log message with "TRACE" and
  67 + "DEBUG" are not shown and all other levels are printed out.
  68 +
  69 + If it is FATAL, then only FATAL messages are shown.
  70 + */
  71 +inline LogLevel GetCurrentLogLevel() {
  72 + static LogLevel log_level = INFO;
  73 + static std::once_flag init_flag;
  74 + std::call_once(init_flag, []() {
  75 + const char *env_log_level = std::getenv("SHERPA_ONNX_LOG_LEVEL");
  76 + if (env_log_level == nullptr) return;
  77 +
  78 + std::string s = env_log_level;
  79 + if (s == "TRACE")
  80 + log_level = TRACE;
  81 + else if (s == "DEBUG")
  82 + log_level = DEBUG;
  83 + else if (s == "INFO")
  84 + log_level = INFO;
  85 + else if (s == "WARNING")
  86 + log_level = WARNING;
  87 + else if (s == "ERROR")
  88 + log_level = ERROR;
  89 + else if (s == "FATAL")
  90 + log_level = FATAL;
  91 + else
  92 + fprintf(stderr,
  93 + "Unknown SHERPA_ONNX_LOG_LEVEL: %s"
  94 + "\nSupported values are: "
  95 + "TRACE, DEBUG, INFO, WARNING, ERROR, FATAL",
  96 + s.c_str());
  97 + });
  98 + return log_level;
  99 +}
  100 +
  101 +inline bool EnableAbort() {
  102 + static std::once_flag init_flag;
  103 + static bool enable_abort = false;
  104 + std::call_once(init_flag, []() {
  105 + enable_abort = (std::getenv("SHERPA_ONNX_ABORT") != nullptr);
  106 + });
  107 + return enable_abort;
  108 +}
  109 +
  110 +class Logger {
  111 + public:
  112 + Logger(const char *filename, const char *func_name, uint32_t line_num,
  113 + LogLevel level)
  114 + : filename_(filename),
  115 + func_name_(func_name),
  116 + line_num_(line_num),
  117 + level_(level) {
  118 + cur_level_ = GetCurrentLogLevel();
  119 + switch (level) {
  120 + case TRACE:
  121 + if (cur_level_ <= TRACE) fprintf(stderr, "[T] ");
  122 + break;
  123 + case DEBUG:
  124 + if (cur_level_ <= DEBUG) fprintf(stderr, "[D] ");
  125 + break;
  126 + case INFO:
  127 + if (cur_level_ <= INFO) fprintf(stderr, "[I] ");
  128 + break;
  129 + case WARNING:
  130 + if (cur_level_ <= WARNING) fprintf(stderr, "[W] ");
  131 + break;
  132 + case ERROR:
  133 + if (cur_level_ <= ERROR) fprintf(stderr, "[E] ");
  134 + break;
  135 + case FATAL:
  136 + if (cur_level_ <= FATAL) fprintf(stderr, "[F] ");
  137 + break;
  138 + }
  139 +
  140 + if (cur_level_ <= level_) {
  141 + fprintf(stderr, "%s:%u:%s ", filename, line_num, func_name);
  142 + }
  143 + }
  144 +
  145 + ~Logger() noexcept(false) {
  146 + static constexpr const char *kErrMsg = R"(
  147 + Some bad things happened. Please read the above error messages and stack
  148 + trace. If you are using Python, the following command may be helpful:
  149 +
  150 + gdb --args python /path/to/your/code.py
  151 +
  152 + (You can use `gdb` to debug the code. Please consider compiling
  153 + a debug version of sherpa_onnx.).
  154 +
  155 + If you are unable to fix it, please open an issue at:
  156 +
  157 + https://github.com/csukuangfj/kaldi-native-fbank/issues/new
  158 + )";
  159 + if (level_ == FATAL) {
  160 + fprintf(stderr, "\n");
  161 + std::string stack_trace = GetStackTrace();
  162 + if (!stack_trace.empty()) {
  163 + fprintf(stderr, "\n\n%s\n", stack_trace.c_str());
  164 + }
  165 +
  166 + fflush(nullptr);
  167 +
  168 +#ifndef __ANDROID_API__
  169 + if (EnableAbort()) {
  170 + // NOTE: abort() will terminate the program immediately without
  171 + // printing the Python stack backtrace.
  172 + abort();
  173 + }
  174 +
  175 + throw std::runtime_error(kErrMsg);
  176 +#else
  177 + abort();
  178 +#endif
  179 + }
  180 + }
  181 +
  182 + const Logger &operator<<(bool b) const {
  183 + if (cur_level_ <= level_) {
  184 + fprintf(stderr, b ? "true" : "false");
  185 + }
  186 + return *this;
  187 + }
  188 +
  189 + const Logger &operator<<(int8_t i) const {
  190 + if (cur_level_ <= level_) fprintf(stderr, "%d", i);
  191 + return *this;
  192 + }
  193 +
  194 + const Logger &operator<<(const char *s) const {
  195 + if (cur_level_ <= level_) fprintf(stderr, "%s", s);
  196 + return *this;
  197 + }
  198 +
  199 + const Logger &operator<<(int32_t i) const {
  200 + if (cur_level_ <= level_) fprintf(stderr, "%d", i);
  201 + return *this;
  202 + }
  203 +
  204 + const Logger &operator<<(uint32_t i) const {
  205 + if (cur_level_ <= level_) fprintf(stderr, "%u", i);
  206 + return *this;
  207 + }
  208 +
  209 + const Logger &operator<<(uint64_t i) const {
  210 + if (cur_level_ <= level_)
  211 + fprintf(stderr, "%llu", (long long unsigned int)i); // NOLINT
  212 + return *this;
  213 + }
  214 +
  215 + const Logger &operator<<(int64_t i) const {
  216 + if (cur_level_ <= level_)
  217 + fprintf(stderr, "%lli", (long long int)i); // NOLINT
  218 + return *this;
  219 + }
  220 +
  221 + const Logger &operator<<(float f) const {
  222 + if (cur_level_ <= level_) fprintf(stderr, "%f", f);
  223 + return *this;
  224 + }
  225 +
  226 + const Logger &operator<<(double d) const {
  227 + if (cur_level_ <= level_) fprintf(stderr, "%f", d);
  228 + return *this;
  229 + }
  230 +
  231 + template <typename T>
  232 + const Logger &operator<<(const T &t) const {
  233 + // require T overloads operator<<
  234 + std::ostringstream os;
  235 + os << t;
  236 + return *this << os.str().c_str();
  237 + }
  238 +
  239 + // specialization to fix compile error: `stringstream << nullptr` is ambiguous
  240 + const Logger &operator<<(const std::nullptr_t &null) const {
  241 + if (cur_level_ <= level_) *this << "(null)";
  242 + return *this;
  243 + }
  244 +
  245 + private:
  246 + const char *filename_;
  247 + const char *func_name_;
  248 + uint32_t line_num_;
  249 + LogLevel level_;
  250 + LogLevel cur_level_;
  251 +};
  252 +#endif // SHERPA_ONNX_ENABLE_CHECK
  253 +
  254 +class Voidifier {
  255 + public:
  256 +#if SHERPA_ONNX_ENABLE_CHECK
  257 + void operator&(const Logger &) const {}
  258 +#endif
  259 +};
  260 +#if !defined(SHERPA_ONNX_ENABLE_CHECK)
  261 +template <typename T>
  262 +const Voidifier &operator<<(const Voidifier &v, T &&) {
  263 + return v;
  264 +}
  265 +#endif
  266 +
  267 +} // namespace sherpa_onnx
  268 +
  269 +#define SHERPA_ONNX_STATIC_ASSERT(x) static_assert(x, "")
  270 +
  271 +#ifdef SHERPA_ONNX_ENABLE_CHECK
  272 +
  273 +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) || \
  274 + defined(__PRETTY_FUNCTION__)
  275 +// for clang and GCC
  276 +#define SHERPA_ONNX_FUNC __PRETTY_FUNCTION__
  277 +#else
  278 +// for other compilers
  279 +#define SHERPA_ONNX_FUNC __func__
  280 +#endif
  281 +
  282 +#define SHERPA_ONNX_CHECK(x) \
  283 + (x) ? (void)0 \
  284 + : ::sherpa_onnx::Voidifier() & \
  285 + ::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, \
  286 + ::sherpa_onnx::FATAL) \
  287 + << "Check failed: " << #x << " "
  288 +
  289 +// WARNING: x and y may be evaluated multiple times, but this happens only
  290 +// when the check fails. Since the program aborts if it fails, we don't think
  291 +// the extra evaluation of x and y matters.
  292 +//
  293 +// CAUTION: we recommend the following use case:
  294 +//
  295 +// auto x = Foo();
  296 +// auto y = Bar();
  297 +// SHERPA_ONNX_CHECK_EQ(x, y) << "Some message";
  298 +//
  299 +// And please avoid
  300 +//
  301 +// SHERPA_ONNX_CHECK_EQ(Foo(), Bar());
  302 +//
  303 +// if `Foo()` or `Bar()` causes some side effects, e.g., changing some
  304 +// local static variables or global variables.
  305 +#define _SHERPA_ONNX_CHECK_OP(x, y, op) \
  306 + ((x)op(y)) ? (void)0 \
  307 + : ::sherpa_onnx::Voidifier() & \
  308 + ::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, \
  309 + ::sherpa_onnx::FATAL) \
  310 + << "Check failed: " << #x << " " << #op << " " << #y \
  311 + << " (" << (x) << " vs. " << (y) << ") "
  312 +
  313 +#define SHERPA_ONNX_CHECK_EQ(x, y) _SHERPA_ONNX_CHECK_OP(x, y, ==)
  314 +#define SHERPA_ONNX_CHECK_NE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, !=)
  315 +#define SHERPA_ONNX_CHECK_LT(x, y) _SHERPA_ONNX_CHECK_OP(x, y, <)
  316 +#define SHERPA_ONNX_CHECK_LE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, <=)
  317 +#define SHERPA_ONNX_CHECK_GT(x, y) _SHERPA_ONNX_CHECK_OP(x, y, >)
  318 +#define SHERPA_ONNX_CHECK_GE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, >=)
  319 +
  320 +#define SHERPA_ONNX_LOG(x) \
  321 + ::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, ::sherpa_onnx::x)
  322 +
  323 +// ------------------------------------------------------------
  324 +// For debug check
  325 +// ------------------------------------------------------------
  326 +// If you define the macro "-D NDEBUG" while compiling kaldi-native-fbank,
  327 +// the following macros are in fact empty and does nothing.
  328 +
  329 +#define SHERPA_ONNX_DCHECK(x) \
  330 + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK(x)
  331 +
  332 +#define SHERPA_ONNX_DCHECK_EQ(x, y) \
  333 + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_EQ(x, y)
  334 +
  335 +#define SHERPA_ONNX_DCHECK_NE(x, y) \
  336 + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_NE(x, y)
  337 +
  338 +#define SHERPA_ONNX_DCHECK_LT(x, y) \
  339 + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_LT(x, y)
  340 +
  341 +#define SHERPA_ONNX_DCHECK_LE(x, y) \
  342 + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_LE(x, y)
  343 +
  344 +#define SHERPA_ONNX_DCHECK_GT(x, y) \
  345 + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_GT(x, y)
  346 +
  347 +#define SHERPA_ONNX_DCHECK_GE(x, y) \
  348 + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_GE(x, y)
  349 +
  350 +#define SHERPA_ONNX_DLOG(x) \
  351 + ::sherpa_onnx::kDisableDebug \
  352 + ? (void)0 \
  353 + : ::sherpa_onnx::Voidifier() & SHERPA_ONNX_LOG(x)
  354 +
  355 +#else
  356 +
  357 +#define SHERPA_ONNX_CHECK(x) ::sherpa_onnx::Voidifier()
  358 +#define SHERPA_ONNX_LOG(x) ::sherpa_onnx::Voidifier()
  359 +
  360 +#define SHERPA_ONNX_CHECK_EQ(x, y) ::sherpa_onnx::Voidifier()
  361 +#define SHERPA_ONNX_CHECK_NE(x, y) ::sherpa_onnx::Voidifier()
  362 +#define SHERPA_ONNX_CHECK_LT(x, y) ::sherpa_onnx::Voidifier()
  363 +#define SHERPA_ONNX_CHECK_LE(x, y) ::sherpa_onnx::Voidifier()
  364 +#define SHERPA_ONNX_CHECK_GT(x, y) ::sherpa_onnx::Voidifier()
  365 +#define SHERPA_ONNX_CHECK_GE(x, y) ::sherpa_onnx::Voidifier()
  366 +
  367 +#define SHERPA_ONNX_DCHECK(x) ::sherpa_onnx::Voidifier()
  368 +#define SHERPA_ONNX_DLOG(x) ::sherpa_onnx::Voidifier()
  369 +#define SHERPA_ONNX_DCHECK_EQ(x, y) ::sherpa_onnx::Voidifier()
  370 +#define SHERPA_ONNX_DCHECK_NE(x, y) ::sherpa_onnx::Voidifier()
  371 +#define SHERPA_ONNX_DCHECK_LT(x, y) ::sherpa_onnx::Voidifier()
  372 +#define SHERPA_ONNX_DCHECK_LE(x, y) ::sherpa_onnx::Voidifier()
  373 +#define SHERPA_ONNX_DCHECK_GT(x, y) ::sherpa_onnx::Voidifier()
  374 +#define SHERPA_ONNX_DCHECK_GE(x, y) ::sherpa_onnx::Voidifier()
  375 +
  376 +#endif // SHERPA_ONNX_CHECK_NE
  377 +
  378 +#endif // SHERPA_ONNX_CSRC_LOG_H_
@@ -37,7 +37,9 @@ std::string OnlineRecognizerConfig::ToString() const { @@ -37,7 +37,9 @@ std::string OnlineRecognizerConfig::ToString() const {
37 os << "OnlineRecognizerConfig("; 37 os << "OnlineRecognizerConfig(";
38 os << "feat_config=" << feat_config.ToString() << ", "; 38 os << "feat_config=" << feat_config.ToString() << ", ";
39 os << "model_config=" << model_config.ToString() << ", "; 39 os << "model_config=" << model_config.ToString() << ", ";
40 - os << "tokens=\"" << tokens << "\")"; 40 + os << "tokens=\"" << tokens << "\", ";
  41 + os << "endpoint_config=" << endpoint_config.ToString() << ", ";
  42 + os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")";
41 43
42 return os.str(); 44 return os.str();
43 } 45 }
@@ -47,7 +49,8 @@ class OnlineRecognizer::Impl { @@ -47,7 +49,8 @@ class OnlineRecognizer::Impl {
47 explicit Impl(const OnlineRecognizerConfig &config) 49 explicit Impl(const OnlineRecognizerConfig &config)
48 : config_(config), 50 : config_(config),
49 model_(OnlineTransducerModel::Create(config.model_config)), 51 model_(OnlineTransducerModel::Create(config.model_config)),
50 - sym_(config.tokens) { 52 + sym_(config.tokens),
  53 + endpoint_(config_.endpoint_config) {
51 decoder_ = 54 decoder_ =
52 std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); 55 std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
53 } 56 }
@@ -64,7 +67,7 @@ class OnlineRecognizer::Impl { @@ -64,7 +67,7 @@ class OnlineRecognizer::Impl {
64 s->NumFramesReady(); 67 s->NumFramesReady();
65 } 68 }
66 69
67 - void DecodeStreams(OnlineStream **ss, int32_t n) { 70 + void DecodeStreams(OnlineStream **ss, int32_t n) const {
68 int32_t chunk_size = model_->ChunkSize(); 71 int32_t chunk_size = model_->ChunkSize();
69 int32_t chunk_shift = model_->ChunkShift(); 72 int32_t chunk_shift = model_->ChunkShift();
70 73
@@ -111,18 +114,44 @@ class OnlineRecognizer::Impl { @@ -111,18 +114,44 @@ class OnlineRecognizer::Impl {
111 } 114 }
112 } 115 }
113 116
114 - OnlineRecognizerResult GetResult(OnlineStream *s) { 117 + OnlineRecognizerResult GetResult(OnlineStream *s) const {
115 OnlineTransducerDecoderResult decoder_result = s->GetResult(); 118 OnlineTransducerDecoderResult decoder_result = s->GetResult();
116 decoder_->StripLeadingBlanks(&decoder_result); 119 decoder_->StripLeadingBlanks(&decoder_result);
117 120
118 return Convert(decoder_result, sym_); 121 return Convert(decoder_result, sym_);
119 } 122 }
120 123
  124 + bool IsEndpoint(OnlineStream *s) const {
  125 + if (!config_.enable_endpoint) return false;
  126 + int32_t num_processed_frames = s->GetNumProcessedFrames();
  127 +
  128 + // frame shift is 10 milliseconds
  129 + float frame_shift_in_seconds = 0.01;
  130 +
  131 + // subsampling factor is 4
  132 + int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 4;
  133 +
  134 + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
  135 + frame_shift_in_seconds);
  136 + }
  137 +
  138 + void Reset(OnlineStream *s) const {
  139 + // reset result and neural network model state,
  140 + // but keep the feature extractor state
  141 +
  142 + // reset result
  143 + s->SetResult(decoder_->GetEmptyResult());
  144 +
  145 + // reset neural network model state
  146 + s->SetStates(model_->GetEncoderInitStates());
  147 + }
  148 +
121 private: 149 private:
122 OnlineRecognizerConfig config_; 150 OnlineRecognizerConfig config_;
123 std::unique_ptr<OnlineTransducerModel> model_; 151 std::unique_ptr<OnlineTransducerModel> model_;
124 std::unique_ptr<OnlineTransducerDecoder> decoder_; 152 std::unique_ptr<OnlineTransducerDecoder> decoder_;
125 SymbolTable sym_; 153 SymbolTable sym_;
  154 + Endpoint endpoint_;
126 }; 155 };
127 156
128 OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) 157 OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
@@ -137,12 +166,18 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const { @@ -137,12 +166,18 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const {
137 return impl_->IsReady(s); 166 return impl_->IsReady(s);
138 } 167 }
139 168
140 -void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) { 169 +void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
141 impl_->DecodeStreams(ss, n); 170 impl_->DecodeStreams(ss, n);
142 } 171 }
143 172
144 -OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) { 173 +OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) const {
145 return impl_->GetResult(s); 174 return impl_->GetResult(s);
146 } 175 }
147 176
  177 +bool OnlineRecognizer::IsEndpoint(OnlineStream *s) const {
  178 + return impl_->IsEndpoint(s);
  179 +}
  180 +
  181 +void OnlineRecognizer::Reset(OnlineStream *s) const { impl_->Reset(s); }
  182 +
148 } // namespace sherpa_onnx 183 } // namespace sherpa_onnx
@@ -8,6 +8,7 @@ @@ -8,6 +8,7 @@
8 #include <memory> 8 #include <memory>
9 #include <string> 9 #include <string>
10 10
  11 +#include "sherpa-onnx/csrc/endpoint.h"
11 #include "sherpa-onnx/csrc/features.h" 12 #include "sherpa-onnx/csrc/features.h"
12 #include "sherpa-onnx/csrc/online-stream.h" 13 #include "sherpa-onnx/csrc/online-stream.h"
13 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 14 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
@@ -22,13 +23,21 @@ struct OnlineRecognizerConfig { @@ -22,13 +23,21 @@ struct OnlineRecognizerConfig {
22 FeatureExtractorConfig feat_config; 23 FeatureExtractorConfig feat_config;
23 OnlineTransducerModelConfig model_config; 24 OnlineTransducerModelConfig model_config;
24 std::string tokens; 25 std::string tokens;
  26 + EndpointConfig endpoint_config;
  27 + bool enable_endpoint;
25 28
26 OnlineRecognizerConfig() = default; 29 OnlineRecognizerConfig() = default;
27 30
28 OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, 31 OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
29 const OnlineTransducerModelConfig &model_config, 32 const OnlineTransducerModelConfig &model_config,
30 - const std::string &tokens)  
31 - : feat_config(feat_config), model_config(model_config), tokens(tokens) {} 33 + const std::string &tokens,
  34 + const EndpointConfig &endpoint_config,
  35 + bool enable_endpoint)
  36 + : feat_config(feat_config),
  37 + model_config(model_config),
  38 + tokens(tokens),
  39 + endpoint_config(endpoint_config),
  40 + enable_endpoint(enable_endpoint) {}
32 41
33 std::string ToString() const; 42 std::string ToString() const;
34 }; 43 };
@@ -48,7 +57,7 @@ class OnlineRecognizer { @@ -48,7 +57,7 @@ class OnlineRecognizer {
48 bool IsReady(OnlineStream *s) const; 57 bool IsReady(OnlineStream *s) const;
49 58
50 /** Decode a single stream. */ 59 /** Decode a single stream. */
51 - void DecodeStream(OnlineStream *s) { 60 + void DecodeStream(OnlineStream *s) const {
52 OnlineStream *ss[1] = {s}; 61 OnlineStream *ss[1] = {s};
53 DecodeStreams(ss, 1); 62 DecodeStreams(ss, 1);
54 } 63 }
@@ -58,9 +67,18 @@ class OnlineRecognizer { @@ -58,9 +67,18 @@ class OnlineRecognizer {
58 * @param ss Pointer array containing streams to be decoded. 67 * @param ss Pointer array containing streams to be decoded.
59 * @param n Number of streams in `ss`. 68 * @param n Number of streams in `ss`.
60 */ 69 */
61 - void DecodeStreams(OnlineStream **ss, int32_t n); 70 + void DecodeStreams(OnlineStream **ss, int32_t n) const;
62 71
63 - OnlineRecognizerResult GetResult(OnlineStream *s); 72 + OnlineRecognizerResult GetResult(OnlineStream *s) const;
  73 +
  74 + // Return true if we detect an endpoint for this stream.
  75 + // Note: If this function returns true, you usually want to
  76 + // invoke Reset(s).
  77 + bool IsEndpoint(OnlineStream *s) const;
  78 +
  79 + // Clear the state of this stream. If IsEndpoint(s) returns true,
  80 + // after calling this function, IsEndpoint(s) will return false
  81 + void Reset(OnlineStream *s) const;
64 82
65 private: 83 private:
66 class Impl; 84 class Impl;
@@ -55,7 +55,8 @@ class OnlineStream { @@ -55,7 +55,8 @@ class OnlineStream {
55 55
56 int32_t FeatureDim() const; 56 int32_t FeatureDim() const;
57 57
58 - // Return a reference to the number of processed frames so far. 58 + // Return a reference to the number of processed frames so far
  59 + // before subsampling..
59 // Initially, it is 0. It is always less than NumFramesReady(). 60 // Initially, it is 0. It is always less than NumFramesReady().
60 // 61 //
61 // The returned reference is valid as long as this object is alive. 62 // The returned reference is valid as long as this object is alive.
@@ -14,6 +14,9 @@ namespace sherpa_onnx { @@ -14,6 +14,9 @@ namespace sherpa_onnx {
14 struct OnlineTransducerDecoderResult { 14 struct OnlineTransducerDecoderResult {
15 /// The decoded token IDs so far 15 /// The decoded token IDs so far
16 std::vector<int64_t> tokens; 16 std::vector<int64_t> tokens;
  17 +
  18 + /// number of trailing blank frames decoded so far
  19 + int32_t num_trailing_blanks = 0;
17 }; 20 };
18 21
19 class OnlineTransducerDecoder { 22 class OnlineTransducerDecoder {
@@ -113,6 +113,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -113,6 +113,9 @@ void OnlineTransducerGreedySearchDecoder::Decode(
113 if (y != 0) { 113 if (y != 0) {
114 emitted = true; 114 emitted = true;
115 (*result)[i].tokens.push_back(y); 115 (*result)[i].tokens.push_back(y);
  116 + (*result)[i].num_trailing_blanks = 0;
  117 + } else {
  118 + ++(*result)[i].num_trailing_blanks;
116 } 119 }
117 } 120 }
118 if (emitted) { 121 if (emitted) {
  1 +// sherpa-onnx/csrc/parse-options.cc
  2 +/**
  3 + * Copyright 2009-2011 Karel Vesely; Microsoft Corporation;
  4 + * Saarland University (Author: Arnab Ghoshal);
  5 + * Copyright 2012-2013 Johns Hopkins University (Author: Daniel Povey);
  6 + * Frantisek Skala; Arnab Ghoshal
  7 + * Copyright 2013 Tanel Alumae
  8 + */
  9 +
  10 +// This file is copied and modified from kaldi/src/util/parse-options.cu
  11 +
  12 +#include "sherpa-onnx/csrc/parse-options.h"
  13 +
  14 +#include <ctype.h>
  15 +
  16 +#include <algorithm>
  17 +#include <cctype>
  18 +#include <cstring>
  19 +#include <fstream>
  20 +#include <iomanip>
  21 +#include <limits>
  22 +#include <type_traits>
  23 +#include <unordered_map>
  24 +
  25 +#include "sherpa-onnx/csrc/log.h"
  26 +
  27 +#ifdef _MSC_VER
  28 +#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \
  29 + _strtoi64(cur_cstr, end_cstr, 10);
  30 +#else
  31 +#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10);
  32 +#endif
  33 +
  34 +namespace sherpa_onnx {
  35 +
  36 +/// Converts a string into an integer via strtoll and returns false if there was
  37 +/// any kind of problem (i.e. the string was not an integer or contained extra
  38 +/// non-whitespace junk, or the integer was too large to fit into the type it is
  39 +/// being converted into). Only sets *out if everything was OK and it returns
  40 +/// true.
  41 +template <class Int>
  42 +bool ConvertStringToInteger(const std::string &str, Int *out) {
  43 + // copied from kaldi/src/util/text-util.h
  44 + static_assert(std::is_integral<Int>::value, "");
  45 + const char *this_str = str.c_str();
  46 + char *end = nullptr;
  47 + errno = 0;
  48 + int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end);
  49 + if (end != this_str) {
  50 + while (isspace(*end)) ++end;
  51 + }
  52 + if (end == this_str || *end != '\0' || errno != 0) return false;
  53 + Int iInt = static_cast<Int>(i);
  54 + if (static_cast<int64_t>(iInt) != i ||
  55 + (i < 0 && !std::numeric_limits<Int>::is_signed)) {
  56 + return false;
  57 + }
  58 + *out = iInt;
  59 + return true;
  60 +}
  61 +
  62 +// copied from kaldi/src/util/text-util.cc
  63 +template <class T>
  64 +class NumberIstream {
  65 + public:
  66 + explicit NumberIstream(std::istream &i) : in_(i) {}
  67 +
  68 + NumberIstream &operator>>(T &x) {
  69 + if (!in_.good()) return *this;
  70 + in_ >> x;
  71 + if (!in_.fail() && RemainderIsOnlySpaces()) return *this;
  72 + return ParseOnFail(&x);
  73 + }
  74 +
  75 + private:
  76 + std::istream &in_;
  77 +
  78 + bool RemainderIsOnlySpaces() {
  79 + if (in_.tellg() != std::istream::pos_type(-1)) {
  80 + std::string rem;
  81 + in_ >> rem;
  82 +
  83 + if (rem.find_first_not_of(' ') != std::string::npos) {
  84 + // there is not only spaces
  85 + return false;
  86 + }
  87 + }
  88 +
  89 + in_.clear();
  90 + return true;
  91 + }
  92 +
  93 + NumberIstream &ParseOnFail(T *x) {
  94 + std::string str;
  95 + in_.clear();
  96 + in_.seekg(0);
  97 + // If the stream is broken even before trying
  98 + // to read from it or if there are many tokens,
  99 + // it's pointless to try.
  100 + if (!(in_ >> str) || !RemainderIsOnlySpaces()) {
  101 + in_.setstate(std::ios_base::failbit);
  102 + return *this;
  103 + }
  104 +
  105 + std::unordered_map<std::string, T> inf_nan_map;
  106 + // we'll keep just uppercase values.
  107 + inf_nan_map["INF"] = std::numeric_limits<T>::infinity();
  108 + inf_nan_map["+INF"] = std::numeric_limits<T>::infinity();
  109 + inf_nan_map["-INF"] = -std::numeric_limits<T>::infinity();
  110 + inf_nan_map["INFINITY"] = std::numeric_limits<T>::infinity();
  111 + inf_nan_map["+INFINITY"] = std::numeric_limits<T>::infinity();
  112 + inf_nan_map["-INFINITY"] = -std::numeric_limits<T>::infinity();
  113 + inf_nan_map["NAN"] = std::numeric_limits<T>::quiet_NaN();
  114 + inf_nan_map["+NAN"] = std::numeric_limits<T>::quiet_NaN();
  115 + inf_nan_map["-NAN"] = -std::numeric_limits<T>::quiet_NaN();
  116 + // MSVC
  117 + inf_nan_map["1.#INF"] = std::numeric_limits<T>::infinity();
  118 + inf_nan_map["-1.#INF"] = -std::numeric_limits<T>::infinity();
  119 + inf_nan_map["1.#QNAN"] = std::numeric_limits<T>::quiet_NaN();
  120 + inf_nan_map["-1.#QNAN"] = -std::numeric_limits<T>::quiet_NaN();
  121 +
  122 + std::transform(str.begin(), str.end(), str.begin(), ::toupper);
  123 +
  124 + if (inf_nan_map.find(str) != inf_nan_map.end()) {
  125 + *x = inf_nan_map[str];
  126 + } else {
  127 + in_.setstate(std::ios_base::failbit);
  128 + }
  129 +
  130 + return *this;
  131 + }
  132 +};
  133 +
  134 +/// ConvertStringToReal converts a string into either float or double
  135 +/// and returns false if there was any kind of problem (i.e. the string
  136 +/// was not a floating point number or contained extra non-whitespace junk).
  137 +/// Be careful- this function will successfully read inf's or nan's.
  138 +template <typename T>
  139 +bool ConvertStringToReal(const std::string &str, T *out) {
  140 + std::istringstream iss(str);
  141 +
  142 + NumberIstream<T> i(iss);
  143 +
  144 + i >> *out;
  145 +
  146 + if (iss.fail()) {
  147 + // Number conversion failed.
  148 + return false;
  149 + }
  150 +
  151 + return true;
  152 +}
  153 +
  154 +ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po)
  155 + : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) {
  156 + if (po != nullptr && po->other_parser_ != nullptr) {
  157 + // we get here if this constructor is used twice, recursively.
  158 + other_parser_ = po->other_parser_;
  159 + } else {
  160 + other_parser_ = po;
  161 + }
  162 + if (po != nullptr && po->prefix_ != "") {
  163 + prefix_ = po->prefix_ + std::string(".") + prefix;
  164 + } else {
  165 + prefix_ = prefix;
  166 + }
  167 +}
  168 +
  169 +void ParseOptions::Register(const std::string &name, bool *ptr,
  170 + const std::string &doc) {
  171 + RegisterTmpl(name, ptr, doc);
  172 +}
  173 +
  174 +void ParseOptions::Register(const std::string &name, int32_t *ptr,
  175 + const std::string &doc) {
  176 + RegisterTmpl(name, ptr, doc);
  177 +}
  178 +
  179 +void ParseOptions::Register(const std::string &name, uint32_t *ptr,
  180 + const std::string &doc) {
  181 + RegisterTmpl(name, ptr, doc);
  182 +}
  183 +
  184 +void ParseOptions::Register(const std::string &name, float *ptr,
  185 + const std::string &doc) {
  186 + RegisterTmpl(name, ptr, doc);
  187 +}
  188 +
  189 +void ParseOptions::Register(const std::string &name, double *ptr,
  190 + const std::string &doc) {
  191 + RegisterTmpl(name, ptr, doc);
  192 +}
  193 +
  194 +void ParseOptions::Register(const std::string &name, std::string *ptr,
  195 + const std::string &doc) {
  196 + RegisterTmpl(name, ptr, doc);
  197 +}
  198 +
  199 +// old-style, used for registering application-specific parameters
  200 +template <typename T>
  201 +void ParseOptions::RegisterTmpl(const std::string &name, T *ptr,
  202 + const std::string &doc) {
  203 + if (other_parser_ == nullptr) {
  204 + this->RegisterCommon(name, ptr, doc, false);
  205 + } else {
  206 + SHERPA_ONNX_CHECK(prefix_ != "")
  207 + << "prefix: " << prefix_ << "\n"
  208 + << "Cannot use empty prefix when registering with prefix.";
  209 + std::string new_name = prefix_ + '.' + name; // name becomes prefix.name
  210 + other_parser_->Register(new_name, ptr, doc);
  211 + }
  212 +}
  213 +
  214 +// does the common part of the job of registering a parameter
  215 +template <typename T>
  216 +void ParseOptions::RegisterCommon(const std::string &name, T *ptr,
  217 + const std::string &doc, bool is_standard) {
  218 + SHERPA_ONNX_CHECK(ptr != nullptr);
  219 + std::string idx = name;
  220 + NormalizeArgName(&idx);
  221 + if (doc_map_.find(idx) != doc_map_.end()) {
  222 + SHERPA_ONNX_LOG(WARNING)
  223 + << "Registering option twice, ignoring second time: " << name;
  224 + } else {
  225 + this->RegisterSpecific(name, idx, ptr, doc, is_standard);
  226 + }
  227 +}
  228 +
  229 +// used to register standard parameters (those that are present in all of the
  230 +// applications)
  231 +template <typename T>
  232 +void ParseOptions::RegisterStandard(const std::string &name, T *ptr,
  233 + const std::string &doc) {
  234 + this->RegisterCommon(name, ptr, doc, true);
  235 +}
  236 +
  237 +void ParseOptions::RegisterSpecific(const std::string &name,
  238 + const std::string &idx, bool *b,
  239 + const std::string &doc, bool is_standard) {
  240 + bool_map_[idx] = b;
  241 + doc_map_[idx] =
  242 + DocInfo(name, doc + " (bool, default = " + ((*b) ? "true)" : "false)"),
  243 + is_standard);
  244 +}
  245 +
  246 +void ParseOptions::RegisterSpecific(const std::string &name,
  247 + const std::string &idx, int32_t *i,
  248 + const std::string &doc, bool is_standard) {
  249 + int_map_[idx] = i;
  250 + std::ostringstream ss;
  251 + ss << doc << " (int, default = " << *i << ")";
  252 + doc_map_[idx] = DocInfo(name, ss.str(), is_standard);
  253 +}
  254 +
  255 +void ParseOptions::RegisterSpecific(const std::string &name,
  256 + const std::string &idx, uint32_t *u,
  257 + const std::string &doc, bool is_standard) {
  258 + uint_map_[idx] = u;
  259 + std::ostringstream ss;
  260 + ss << doc << " (uint, default = " << *u << ")";
  261 + doc_map_[idx] = DocInfo(name, ss.str(), is_standard);
  262 +}
  263 +
  264 +void ParseOptions::RegisterSpecific(const std::string &name,
  265 + const std::string &idx, float *f,
  266 + const std::string &doc, bool is_standard) {
  267 + float_map_[idx] = f;
  268 + std::ostringstream ss;
  269 + ss << doc << " (float, default = " << *f << ")";
  270 + doc_map_[idx] = DocInfo(name, ss.str(), is_standard);
  271 +}
  272 +
  273 +void ParseOptions::RegisterSpecific(const std::string &name,
  274 + const std::string &idx, double *f,
  275 + const std::string &doc, bool is_standard) {
  276 + double_map_[idx] = f;
  277 + std::ostringstream ss;
  278 + ss << doc << " (double, default = " << *f << ")";
  279 + doc_map_[idx] = DocInfo(name, ss.str(), is_standard);
  280 +}
  281 +
  282 +void ParseOptions::RegisterSpecific(const std::string &name,
  283 + const std::string &idx, std::string *s,
  284 + const std::string &doc, bool is_standard) {
  285 + string_map_[idx] = s;
  286 + doc_map_[idx] =
  287 + DocInfo(name, doc + " (string, default = \"" + *s + "\")", is_standard);
  288 +}
  289 +
  290 +void ParseOptions::DisableOption(const std::string &name) {
  291 + if (argv_ != nullptr) {
  292 + SHERPA_ONNX_LOG(FATAL)
  293 + << "DisableOption must not be called after calling Read().";
  294 + }
  295 + if (doc_map_.erase(name) == 0) {
  296 + SHERPA_ONNX_LOG(FATAL) << "Option " << name
  297 + << " was not registered so cannot be disabled: ";
  298 + }
  299 + bool_map_.erase(name);
  300 + int_map_.erase(name);
  301 + uint_map_.erase(name);
  302 + float_map_.erase(name);
  303 + double_map_.erase(name);
  304 + string_map_.erase(name);
  305 +}
  306 +
  307 +int ParseOptions::NumArgs() const { return positional_args_.size(); }
  308 +
  309 +std::string ParseOptions::GetArg(int i) const {
  310 + if (i < 1 || i > static_cast<int>(positional_args_.size())) {
  311 + SHERPA_ONNX_LOG(FATAL) << "ParseOptions::GetArg, invalid index " << i;
  312 + }
  313 +
  314 + return positional_args_[i - 1];
  315 +}
  316 +
  317 +// We currently do not support any other options.
  318 +enum ShellType { kBash = 0 };
  319 +
  320 +// This can be changed in the code if it ever does need to be changed (as it's
  321 +// unlikely that one compilation of this tool-set would use both shells).
  322 +static ShellType kShellType = kBash;
  323 +
  324 +// Returns true if we need to escape a string before putting it into
  325 +// a shell (mainly thinking of bash shell, but should work for others)
  326 +// This is for the convenience of the user so command-lines that are
  327 +// printed out by ParseOptions::Read (with --print-args=true) are
  328 +// paste-able into the shell and will run. If you use a different type of
  329 +// shell, it might be necessary to change this function.
  330 +// But it's mostly a cosmetic issue as it basically affects how
  331 +// the program echoes its command-line arguments to the screen.
  332 +static bool MustBeQuoted(const std::string &str, ShellType st) {
  333 + // Only Bash is supported (for the moment).
  334 + SHERPA_ONNX_CHECK_EQ(st, kBash) << "Invalid shell type.";
  335 +
  336 + const char *c = str.c_str();
  337 + if (*c == '\0') {
  338 + return true; // Must quote empty string
  339 + } else {
  340 + const char *ok_chars[2];
  341 +
  342 + // These seem not to be interpreted as long as there are no other "bad"
  343 + // characters involved (e.g. "," would be interpreted as part of something
  344 + // like a{b,c}, but not on its own.
  345 + ok_chars[kBash] = "[]~#^_-+=:.,/";
  346 +
  347 + // Just want to make sure that a space character doesn't get automatically
  348 + // inserted here via an automated style-checking script, like it did before.
  349 + SHERPA_ONNX_CHECK(!strchr(ok_chars[kBash], ' '));
  350 +
  351 + for (; *c != '\0'; ++c) {
  352 + // For non-alphanumeric characters we have a list of characters which
  353 + // are OK. All others are forbidden (this is easier since the shell
  354 + // interprets most non-alphanumeric characters).
  355 + if (!isalnum(*c)) {
  356 + const char *d;
  357 + for (d = ok_chars[st]; *d != '\0'; ++d) {
  358 + if (*c == *d) break;
  359 + }
  360 + // If not alphanumeric or one of the "ok_chars", it must be escaped.
  361 + if (*d == '\0') return true;
  362 + }
  363 + }
  364 + return false; // The string was OK. No quoting or escaping.
  365 + }
  366 +}
  367 +
  368 +// Returns a quoted and escaped version of "str"
  369 +// which has previously been determined to need escaping.
  370 +// Our aim is to print out the command line in such a way that if it's
  371 +// pasted into a shell of ShellType "st" (only bash for now), it
  372 +// will get passed to the program in the same way.
  373 +static std::string QuoteAndEscape(const std::string &str, ShellType st) {
  374 + // Only Bash is supported (for the moment).
  375 + SHERPA_ONNX_CHECK_EQ(st, kBash) << "Invalid shell type.";
  376 +
  377 + // For now we use the following rules:
  378 + // In the normal case, we quote with single-quote "'", and to escape
  379 + // a single-quote we use the string: '\'' (interpreted as closing the
  380 + // single-quote, putting an escaped single-quote from the shell, and
  381 + // then reopening the single quote).
  382 + char quote_char = '\'';
  383 + const char *escape_str = "'\\''"; // e.g. echo 'a'\''b' returns a'b
  384 +
  385 + // If the string contains single-quotes that would need escaping this
  386 + // way, and we determine that the string could be safely double-quoted
  387 + // without requiring any escaping, then we double-quote the string.
  388 + // This is the case if the characters "`$\ do not appear in the string.
  389 + // e.g. see http://www.redhat.com/mirrors/LDP/LDP/abs/html/quotingvar.html
  390 + const char *c_str = str.c_str();
  391 + if (strchr(c_str, '\'') && !strpbrk(c_str, "\"`$\\")) {
  392 + quote_char = '"';
  393 + escape_str = "\\\""; // should never be accessed.
  394 + }
  395 +
  396 + char buf[2];
  397 + buf[1] = '\0';
  398 +
  399 + buf[0] = quote_char;
  400 + std::string ans = buf;
  401 + const char *c = str.c_str();
  402 + for (; *c != '\0'; ++c) {
  403 + if (*c == quote_char) {
  404 + ans += escape_str;
  405 + } else {
  406 + buf[0] = *c;
  407 + ans += buf;
  408 + }
  409 + }
  410 + buf[0] = quote_char;
  411 + ans += buf;
  412 + return ans;
  413 +}
  414 +
  415 +// static function
  416 +std::string ParseOptions::Escape(const std::string &str) {
  417 + return MustBeQuoted(str, kShellType) ? QuoteAndEscape(str, kShellType) : str;
  418 +}
  419 +
  420 +int ParseOptions::Read(int argc, const char *const argv[]) {
  421 + argc_ = argc;
  422 + argv_ = argv;
  423 + std::string key, value;
  424 + int i;
  425 +
  426 + // first pass: look for config parameter, look for priority
  427 + for (i = 1; i < argc; ++i) {
  428 + if (std::strncmp(argv[i], "--", 2) == 0) {
  429 + if (std::strcmp(argv[i], "--") == 0) {
  430 + // a lone "--" marks the end of named options
  431 + break;
  432 + }
  433 + bool has_equal_sign;
  434 + SplitLongArg(argv[i], &key, &value, &has_equal_sign);
  435 + NormalizeArgName(&key);
  436 + Trim(&value);
  437 + if (key.compare("config") == 0) {
  438 + ReadConfigFile(value);
  439 + } else if (key.compare("help") == 0) {
  440 + PrintUsage();
  441 + exit(0);
  442 + }
  443 + }
  444 + }
  445 +
  446 + bool double_dash_seen = false;
  447 + // second pass: add the command line options
  448 + for (i = 1; i < argc; ++i) {
  449 + if (std::strncmp(argv[i], "--", 2) == 0) {
  450 + if (std::strcmp(argv[i], "--") == 0) {
  451 + // A lone "--" marks the end of named options.
  452 + // Skip that option and break the processing of named options
  453 + i += 1;
  454 + double_dash_seen = true;
  455 + break;
  456 + }
  457 + bool has_equal_sign;
  458 + SplitLongArg(argv[i], &key, &value, &has_equal_sign);
  459 + NormalizeArgName(&key);
  460 + Trim(&value);
  461 + if (!SetOption(key, value, has_equal_sign)) {
  462 + PrintUsage(true);
  463 + SHERPA_ONNX_LOG(FATAL) << "Invalid option " << argv[i];
  464 + }
  465 + } else {
  466 + break;
  467 + }
  468 + }
  469 +
  470 + // process remaining arguments as positional
  471 + for (; i < argc; ++i) {
  472 + if ((std::strcmp(argv[i], "--") == 0) && !double_dash_seen) {
  473 + double_dash_seen = true;
  474 + } else {
  475 + positional_args_.push_back(std::string(argv[i]));
  476 + }
  477 + }
  478 +
  479 + // if the user did not suppress this with --print-args = false....
  480 + if (print_args_) {
  481 + std::ostringstream strm;
  482 + for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " ";
  483 + strm << '\n';
  484 + SHERPA_ONNX_LOG(INFO) << strm.str();
  485 + }
  486 + return i;
  487 +}
  488 +
  489 +void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const {
  490 + std::ostringstream os;
  491 + os << '\n' << usage_ << '\n';
  492 + // first we print application-specific options
  493 + bool app_specific_header_printed = false;
  494 + for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) {
  495 + if (it->second.is_standard_ == false) { // application-specific option
  496 + if (app_specific_header_printed == false) { // header was not yet printed
  497 + os << "Options:" << '\n';
  498 + app_specific_header_printed = true;
  499 + }
  500 + os << " --" << std::setw(25) << std::left << it->second.name_ << " : "
  501 + << it->second.use_msg_ << '\n';
  502 + }
  503 + }
  504 + if (app_specific_header_printed == true) {
  505 + os << '\n';
  506 + }
  507 +
  508 + // then the standard options
  509 + os << "Standard options:" << '\n';
  510 + for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) {
  511 + if (it->second.is_standard_ == true) { // we have standard option
  512 + os << " --" << std::setw(25) << std::left << it->second.name_ << " : "
  513 + << it->second.use_msg_ << '\n';
  514 + }
  515 + }
  516 + os << '\n';
  517 + if (print_command_line) {
  518 + std::ostringstream strm;
  519 + strm << "Command line was: ";
  520 + for (int j = 0; j < argc_; ++j) strm << Escape(argv_[j]) << " ";
  521 + strm << '\n';
  522 + os << strm.str();
  523 + }
  524 +
  525 + SHERPA_ONNX_LOG(INFO) << os.str();
  526 +}
  527 +
  528 +void ParseOptions::PrintConfig(std::ostream &os) const {
  529 + os << '\n' << "[[ Configuration of UI-Registered options ]]" << '\n';
  530 + std::string key;
  531 + for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) {
  532 + key = it->first;
  533 + os << it->second.name_ << " = ";
  534 + if (bool_map_.end() != bool_map_.find(key)) {
  535 + os << (*bool_map_.at(key) ? "true" : "false");
  536 + } else if (int_map_.end() != int_map_.find(key)) {
  537 + os << (*int_map_.at(key));
  538 + } else if (uint_map_.end() != uint_map_.find(key)) {
  539 + os << (*uint_map_.at(key));
  540 + } else if (float_map_.end() != float_map_.find(key)) {
  541 + os << (*float_map_.at(key));
  542 + } else if (double_map_.end() != double_map_.find(key)) {
  543 + os << (*double_map_.at(key));
  544 + } else if (string_map_.end() != string_map_.find(key)) {
  545 + os << "'" << *string_map_.at(key) << "'";
  546 + } else {
  547 + SHERPA_ONNX_LOG(FATAL)
  548 + << "PrintConfig: unrecognized option " << key << "[code error]";
  549 + }
  550 + os << '\n';
  551 + }
  552 + os << '\n';
  553 +}
  554 +
  555 +void ParseOptions::ReadConfigFile(const std::string &filename) {
  556 + std::ifstream is(filename.c_str(), std::ifstream::in);
  557 + if (!is.good()) {
  558 + SHERPA_ONNX_LOG(FATAL) << "Cannot open config file: " << filename;
  559 + }
  560 +
  561 + std::string line, key, value;
  562 + int32_t line_number = 0;
  563 + while (std::getline(is, line)) {
  564 + ++line_number;
  565 + // trim out the comments
  566 + size_t pos;
  567 + if ((pos = line.find_first_of('#')) != std::string::npos) {
  568 + line.erase(pos);
  569 + }
  570 + // skip empty lines
  571 + Trim(&line);
  572 + if (line.length() == 0) continue;
  573 +
  574 + if (line.substr(0, 2) != "--") {
  575 + SHERPA_ONNX_LOG(FATAL)
  576 + << "Reading config file " << filename << ": line " << line_number
  577 + << " does not look like a line "
  578 + << "from a Kaldi command-line program's config file: should "
  579 + << "be of the form --x=y. Note: config files intended to "
  580 + << "be sourced by shell scripts lack the '--'.";
  581 + }
  582 +
  583 + // parse option
  584 + bool has_equal_sign;
  585 + SplitLongArg(line, &key, &value, &has_equal_sign);
  586 + NormalizeArgName(&key);
  587 + Trim(&value);
  588 + if (!SetOption(key, value, has_equal_sign)) {
  589 + PrintUsage(true);
  590 + SHERPA_ONNX_LOG(FATAL) << "Invalid option " << line << " in config file "
  591 + << filename << ": line " << line_number;
  592 + }
  593 + }
  594 +}
  595 +
  596 +void ParseOptions::SplitLongArg(const std::string &in, std::string *key,
  597 + std::string *value,
  598 + bool *has_equal_sign) const {
  599 + SHERPA_ONNX_CHECK(in.substr(0, 2) == "--") << in; // precondition.
  600 + size_t pos = in.find_first_of('=', 0);
  601 + if (pos == std::string::npos) { // we allow --option for bools
  602 + // defaults to empty. We handle this differently in different cases.
  603 + *key = in.substr(2, in.size() - 2); // 2 because starts with --.
  604 + *value = "";
  605 + *has_equal_sign = false;
  606 + } else if (pos == 2) { // we also don't allow empty keys: --=value
  607 + PrintUsage(true);
  608 + SHERPA_ONNX_LOG(FATAL) << "Invalid option (no key): " << in;
  609 + } else { // normal case: --option=value
  610 + *key = in.substr(2, pos - 2); // 2 because starts with --.
  611 + *value = in.substr(pos + 1);
  612 + *has_equal_sign = true;
  613 + }
  614 +}
  615 +
  616 +void ParseOptions::NormalizeArgName(std::string *str) const {
  617 + std::string out;
  618 + std::string::iterator it;
  619 +
  620 + for (it = str->begin(); it != str->end(); ++it) {
  621 + if (*it == '_') {
  622 + out += '-'; // convert _ to -
  623 + } else {
  624 + out += std::tolower(*it);
  625 + }
  626 + }
  627 + *str = out;
  628 +
  629 + SHERPA_ONNX_CHECK_GT(str->length(), 0);
  630 +}
  631 +
  632 +void ParseOptions::Trim(std::string *str) const {
  633 + const char *white_chars = " \t\n\r\f\v";
  634 +
  635 + std::string::size_type pos = str->find_last_not_of(white_chars);
  636 + if (pos != std::string::npos) {
  637 + str->erase(pos + 1);
  638 + pos = str->find_first_not_of(white_chars);
  639 + if (pos != std::string::npos) str->erase(0, pos);
  640 + } else {
  641 + str->erase(str->begin(), str->end());
  642 + }
  643 +}
  644 +
  645 +bool ParseOptions::SetOption(const std::string &key, const std::string &value,
  646 + bool has_equal_sign) {
  647 + if (bool_map_.end() != bool_map_.find(key)) {
  648 + if (has_equal_sign && value == "") {
  649 + SHERPA_ONNX_LOG(FATAL) << "Invalid option --" << key << "=";
  650 + }
  651 + *(bool_map_[key]) = ToBool(value);
  652 + } else if (int_map_.end() != int_map_.find(key)) {
  653 + *(int_map_[key]) = ToInt(value);
  654 + } else if (uint_map_.end() != uint_map_.find(key)) {
  655 + *(uint_map_[key]) = ToUint(value);
  656 + } else if (float_map_.end() != float_map_.find(key)) {
  657 + *(float_map_[key]) = ToFloat(value);
  658 + } else if (double_map_.end() != double_map_.find(key)) {
  659 + *(double_map_[key]) = ToDouble(value);
  660 + } else if (string_map_.end() != string_map_.find(key)) {
  661 + if (!has_equal_sign) {
  662 + SHERPA_ONNX_LOG(FATAL)
  663 + << "Invalid option --" << key << " (option format is --x=y).";
  664 + }
  665 + *(string_map_[key]) = value;
  666 + } else {
  667 + return false;
  668 + }
  669 + return true;
  670 +}
  671 +
  672 +bool ParseOptions::ToBool(std::string str) const {
  673 + std::transform(str.begin(), str.end(), str.begin(), ::tolower);
  674 +
  675 + // allow "" as a valid option for "true", so that --x is the same as --x=true
  676 + if ((str.compare("true") == 0) || (str.compare("t") == 0) ||
  677 + (str.compare("1") == 0) || (str.compare("") == 0)) {
  678 + return true;
  679 + }
  680 + if ((str.compare("false") == 0) || (str.compare("f") == 0) ||
  681 + (str.compare("0") == 0)) {
  682 + return false;
  683 + }
  684 + // if it is neither true nor false:
  685 + PrintUsage(true);
  686 + SHERPA_ONNX_LOG(FATAL)
  687 + << "Invalid format for boolean argument [expected true or false]: "
  688 + << str;
  689 + return false; // never reached
  690 +}
  691 +
  692 +int32_t ParseOptions::ToInt(const std::string &str) const {
  693 + int32_t ret = 0;
  694 + if (!ConvertStringToInteger(str, &ret))
  695 + SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\"";
  696 + return ret;
  697 +}
  698 +
  699 +uint32_t ParseOptions::ToUint(const std::string &str) const {
  700 + uint32_t ret = 0;
  701 + if (!ConvertStringToInteger(str, &ret))
  702 + SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\"";
  703 + return ret;
  704 +}
  705 +
  706 +float ParseOptions::ToFloat(const std::string &str) const {
  707 + float ret;
  708 + if (!ConvertStringToReal(str, &ret))
  709 + SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\"";
  710 + return ret;
  711 +}
  712 +
  713 +double ParseOptions::ToDouble(const std::string &str) const {
  714 + double ret;
  715 + if (!ConvertStringToReal(str, &ret))
  716 + SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\"";
  717 + return ret;
  718 +}
  719 +
  720 +// instantiate templates
  721 +template void ParseOptions::RegisterTmpl(const std::string &name, bool *ptr,
  722 + const std::string &doc);
  723 +template void ParseOptions::RegisterTmpl(const std::string &name, int32_t *ptr,
  724 + const std::string &doc);
  725 +template void ParseOptions::RegisterTmpl(const std::string &name, uint32_t *ptr,
  726 + const std::string &doc);
  727 +template void ParseOptions::RegisterTmpl(const std::string &name, float *ptr,
  728 + const std::string &doc);
  729 +template void ParseOptions::RegisterTmpl(const std::string &name, double *ptr,
  730 + const std::string &doc);
  731 +template void ParseOptions::RegisterTmpl(const std::string &name,
  732 + std::string *ptr,
  733 + const std::string &doc);
  734 +
  735 +template void ParseOptions::RegisterStandard(const std::string &name, bool *ptr,
  736 + const std::string &doc);
  737 +template void ParseOptions::RegisterStandard(const std::string &name,
  738 + int32_t *ptr,
  739 + const std::string &doc);
  740 +template void ParseOptions::RegisterStandard(const std::string &name,
  741 + uint32_t *ptr,
  742 + const std::string &doc);
  743 +template void ParseOptions::RegisterStandard(const std::string &name,
  744 + float *ptr,
  745 + const std::string &doc);
  746 +template void ParseOptions::RegisterStandard(const std::string &name,
  747 + double *ptr,
  748 + const std::string &doc);
  749 +template void ParseOptions::RegisterStandard(const std::string &name,
  750 + std::string *ptr,
  751 + const std::string &doc);
  752 +
  753 +template void ParseOptions::RegisterCommon(const std::string &name, bool *ptr,
  754 + const std::string &doc,
  755 + bool is_standard);
  756 +template void ParseOptions::RegisterCommon(const std::string &name,
  757 + int32_t *ptr, const std::string &doc,
  758 + bool is_standard);
  759 +template void ParseOptions::RegisterCommon(const std::string &name,
  760 + uint32_t *ptr,
  761 + const std::string &doc,
  762 + bool is_standard);
  763 +template void ParseOptions::RegisterCommon(const std::string &name, float *ptr,
  764 + const std::string &doc,
  765 + bool is_standard);
  766 +template void ParseOptions::RegisterCommon(const std::string &name, double *ptr,
  767 + const std::string &doc,
  768 + bool is_standard);
  769 +template void ParseOptions::RegisterCommon(const std::string &name,
  770 + std::string *ptr,
  771 + const std::string &doc,
  772 + bool is_standard);
  773 +
  774 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/parse-options.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +//
  5 +// This file is copied and modified from kaldi/src/util/parse-options.h
  6 +
  7 +#ifndef SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_
  8 +#define SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_
  9 +
  10 +#include <sstream>
  11 +#include <string>
  12 +#include <unordered_map>
  13 +#include <vector>
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +class ParseOptions {
  18 + public:
  19 + explicit ParseOptions(const char *usage)
  20 + : print_args_(true),
  21 + help_(false),
  22 + usage_(usage),
  23 + argc_(0),
  24 + argv_(nullptr),
  25 + prefix_(""),
  26 + other_parser_(nullptr) {
  27 +#if !defined(_MSC_VER) && !defined(__CYGWIN__)
  28 + // This is just a convenient place to set the stderr to line
  29 + // buffering mode, since it's called at program start.
  30 + // This helps ensure different programs' output is not mixed up.
  31 + setlinebuf(stderr);
  32 +#endif
  33 + RegisterStandard("config", &config_,
  34 + "Configuration file to read (this "
  35 + "option may be repeated)");
  36 + RegisterStandard("print-args", &print_args_,
  37 + "Print the command line arguments (to stderr)");
  38 + RegisterStandard("help", &help_, "Print out usage message");
  39 + }
  40 +
  41 + /**
  42 + This is a constructor for the special case where some options are
  43 + registered with a prefix to avoid conflicts. The object thus created will
  44 + only be used temporarily to register an options class with the original
  45 + options parser (which is passed as the *other pointer) using the given
  46 + prefix. It should not be used for any other purpose, and the prefix must
  47 + not be the empty string. It seems to be the least bad way of implementing
  48 + options with prefixes at this point.
  49 + Example of usage is:
  50 + ParseOptions po; // original ParseOptions object
  51 + ParseOptions po_mfcc("mfcc", &po); // object with prefix.
  52 + MfccOptions mfcc_opts;
  53 + mfcc_opts.Register(&po_mfcc);
  54 + The options will now get registered as, e.g., --mfcc.frame-shift=10.0
  55 + instead of just --frame-shift=10.0
  56 + */
  57 + ParseOptions(const std::string &prefix, ParseOptions *other);
  58 +
  59 + ParseOptions(const ParseOptions &) = delete;
  60 + ParseOptions &operator=(const ParseOptions &) = delete;
  61 + ~ParseOptions() = default;
  62 +
  63 + void Register(const std::string &name, bool *ptr, const std::string &doc);
  64 + void Register(const std::string &name, int32_t *ptr, const std::string &doc);
  65 + void Register(const std::string &name, uint32_t *ptr, const std::string &doc);
  66 + void Register(const std::string &name, float *ptr, const std::string &doc);
  67 + void Register(const std::string &name, double *ptr, const std::string &doc);
  68 + void Register(const std::string &name, std::string *ptr,
  69 + const std::string &doc);
  70 +
  71 + /// If called after registering an option and before calling
  72 + /// Read(), disables that option from being used. Will crash
  73 + /// at runtime if that option had not been registered.
  74 + void DisableOption(const std::string &name);
  75 +
  76 + /// This one is used for registering standard parameters of all the programs
  77 + template <typename T>
  78 + void RegisterStandard(const std::string &name, T *ptr,
  79 + const std::string &doc);
  80 +
  81 + /**
  82 + Parses the command line options and fills the ParseOptions-registered
  83 + variables. This must be called after all the variables were registered!!!
  84 +
  85 + Initially the variables have implicit values,
  86 + then the config file values are set-up,
  87 + finally the command line values given.
  88 + Returns the first position in argv that was not used.
  89 + [typically not useful: use NumParams() and GetParam(). ]
  90 + */
  91 + int Read(int argc, const char *const *argv);
  92 +
  93 + /// Prints the usage documentation [provided in the constructor].
  94 + void PrintUsage(bool print_command_line = false) const;
  95 +
  96 + /// Prints the actual configuration of all the registered variables
  97 + void PrintConfig(std::ostream &os) const;
  98 +
  99 + /// Reads the options values from a config file. Must be called after
  100 + /// registering all options. This is usually used internally after the
  101 + /// standard --config option is used, but it may also be called from a
  102 + /// program.
  103 + void ReadConfigFile(const std::string &filename);
  104 +
  105 + /// Number of positional parameters (c.f. argc-1).
  106 + int NumArgs() const;
  107 +
  108 + /// Returns one of the positional parameters; 1-based indexing for argc/argv
  109 + /// compatibility. Will crash if param is not >=1 and <=NumArgs().
  110 + ///
  111 + /// Note: Index is 1 based.
  112 + std::string GetArg(int param) const;
  113 +
  114 + std::string GetOptArg(int param) const {
  115 + return (param <= NumArgs() ? GetArg(param) : "");
  116 + }
  117 +
  118 + /// The following function will return a possibly quoted and escaped
  119 + /// version of "str", according to the current shell. Currently
  120 + /// this is just hardwired to bash. It's useful for debug output.
  121 + static std::string Escape(const std::string &str);
  122 +
  123 + private:
  124 + /// Template to register various variable types,
  125 + /// used for program-specific parameters
  126 + template <typename T>
  127 + void RegisterTmpl(const std::string &name, T *ptr, const std::string &doc);
  128 +
  129 + // Following functions do just the datatype-specific part of the job
  130 + /// Register boolean variable
  131 + void RegisterSpecific(const std::string &name, const std::string &idx,
  132 + bool *b, const std::string &doc, bool is_standard);
  133 + /// Register int32_t variable
  134 + void RegisterSpecific(const std::string &name, const std::string &idx,
  135 + int32_t *i, const std::string &doc, bool is_standard);
  136 + /// Register unsigned int32_t variable
  137 + void RegisterSpecific(const std::string &name, const std::string &idx,
  138 + uint32_t *u, const std::string &doc, bool is_standard);
  139 + /// Register float variable
  140 + void RegisterSpecific(const std::string &name, const std::string &idx,
  141 + float *f, const std::string &doc, bool is_standard);
  142 + /// Register double variable [useful as we change BaseFloat type].
  143 + void RegisterSpecific(const std::string &name, const std::string &idx,
  144 + double *f, const std::string &doc, bool is_standard);
  145 + /// Register string variable
  146 + void RegisterSpecific(const std::string &name, const std::string &idx,
  147 + std::string *s, const std::string &doc,
  148 + bool is_standard);
  149 +
  150 + /// Does the actual job for both kinds of parameters
  151 + /// Does the common part of the job for all datatypes,
  152 + /// then calls RegisterSpecific
  153 + template <typename T>
  154 + void RegisterCommon(const std::string &name, T *ptr, const std::string &doc,
  155 + bool is_standard);
  156 +
  157 + /// Set option with name "key" to "value"; will crash if can't do it.
  158 + /// "has_equal_sign" is used to allow --x for a boolean option x,
  159 + /// and --y=, for a string option y.
  160 + bool SetOption(const std::string &key, const std::string &value,
  161 + bool has_equal_sign);
  162 +
  163 + bool ToBool(std::string str) const;
  164 + int32_t ToInt(const std::string &str) const;
  165 + uint32_t ToUint(const std::string &str) const;
  166 + float ToFloat(const std::string &str) const;
  167 + double ToDouble(const std::string &str) const;
  168 +
  169 + // maps for option variables
  170 + std::unordered_map<std::string, bool *> bool_map_;
  171 + std::unordered_map<std::string, int32_t *> int_map_;
  172 + std::unordered_map<std::string, uint32_t *> uint_map_;
  173 + std::unordered_map<std::string, float *> float_map_;
  174 + std::unordered_map<std::string, double *> double_map_;
  175 + std::unordered_map<std::string, std::string *> string_map_;
  176 +
  177 + /**
  178 + Structure for options' documentation
  179 + */
  180 + struct DocInfo {
  181 + DocInfo() = default;
  182 + DocInfo(const std::string &name, const std::string &usemsg)
  183 + : name_(name), use_msg_(usemsg), is_standard_(false) {}
  184 + DocInfo(const std::string &name, const std::string &usemsg,
  185 + bool is_standard)
  186 + : name_(name), use_msg_(usemsg), is_standard_(is_standard) {}
  187 +
  188 + std::string name_;
  189 + std::string use_msg_;
  190 + bool is_standard_;
  191 + };
  192 + using DocMapType = std::unordered_map<std::string, DocInfo>;
  193 + DocMapType doc_map_; ///< map for the documentation
  194 +
  195 + bool print_args_; ///< variable for the implicit --print-args parameter
  196 + bool help_; ///< variable for the implicit --help parameter
  197 + std::string config_; ///< variable for the implicit --config parameter
  198 + std::vector<std::string> positional_args_;
  199 + const char *usage_;
  200 + int argc_;
  201 + const char *const *argv_;
  202 +
  203 + /// These members are not normally used. They are only used when the object
  204 + /// is constructed with a prefix
  205 + std::string prefix_;
  206 + ParseOptions *other_parser_;
  207 +
  208 + protected:
  209 + /// SplitLongArg parses an argument of the form --a=b, --a=, or --a,
  210 + /// and sets "has_equal_sign" to true if an equals-sign was parsed..
  211 + /// this is needed in order to correctly allow --x for a boolean option
  212 + /// x, and --y= for a string option y, and to disallow --x= and --y.
  213 + void SplitLongArg(const std::string &in, std::string *key, std::string *value,
  214 + bool *has_equal_sign) const;
  215 +
  216 + void NormalizeArgName(std::string *str) const;
  217 +
  218 + /// Removes the beginning and trailing whitespaces from a string
  219 + void Trim(std::string *str) const;
  220 +};
  221 +
  222 +/// This template is provided for convenience in reading config classes from
  223 +/// files; this is not the standard way to read configuration options, but may
  224 +/// occasionally be needed. This function assumes the config has a function
  225 +/// "void Register(ParseOptions *opts)" which it can call to register the
  226 +/// ParseOptions object.
  227 +template <class C>
  228 +void ReadConfigFromFile(const std::string &config_filename, C *c) {
  229 + std::ostringstream usage_str;
  230 + usage_str << "Parsing config from "
  231 + << "from '" << config_filename << "'";
  232 + ParseOptions po(usage_str.str().c_str());
  233 + c->Register(&po);
  234 + po.ReadConfigFile(config_filename);
  235 +}
  236 +
  237 +/// This variant of the template ReadConfigFromFile is for if you need to read
  238 +/// two config classes from the same file.
  239 +template <class C1, class C2>
  240 +void ReadConfigsFromFile(const std::string &conf, C1 *c1, C2 *c2) {
  241 + std::ostringstream usage_str;
  242 + usage_str << "Parsing config from "
  243 + << "from '" << conf << "'";
  244 + ParseOptions po(usage_str.str().c_str());
  245 + c1->Register(&po);
  246 + c2->Register(&po);
  247 + po.ReadConfigFile(conf);
  248 +}
  249 +
  250 +} // namespace sherpa_onnx
  251 +
  252 +#endif // SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_
  1 +// sherpa-onnx/csrc/sherpa-onnx-alsa.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +#include <signal.h>
  5 +#include <stdio.h>
  6 +#include <stdlib.h>
  7 +
  8 +#include <algorithm>
  9 +#include <cctype> // std::tolower
  10 +#include <cstdint>
  11 +
  12 +#include "sherpa-onnx/csrc/alsa.h"
  13 +#include "sherpa-onnx/csrc/display.h"
  14 +#include "sherpa-onnx/csrc/online-recognizer.h"
  15 +
  16 +bool stop = false;
  17 +
  18 +static void Handler(int sig) {
  19 + stop = true;
  20 + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n");
  21 +}
  22 +
  23 +int main(int32_t argc, char *argv[]) {
  24 + if (argc < 6 || argc > 7) {
  25 + const char *usage = R"usage(
  26 +Usage:
  27 + ./bin/sherpa-onnx-alsa \
  28 + /path/to/tokens.txt \
  29 + /path/to/encoder.onnx \
  30 + /path/to/decoder.onnx \
  31 + /path/to/joiner.onnx \
  32 + device_name \
  33 + [num_threads]
  34 +
  35 +Please refer to
  36 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  37 +for a list of pre-trained models to download.
  38 +
  39 +The device name specifies which microphone to use in case there are several
  40 +on you system. You can use
  41 +
  42 + arecord -l
  43 +
  44 +to find all available microphones on your computer. For instance, if it outputs
  45 +
  46 +**** List of CAPTURE Hardware Devices ****
  47 +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio]
  48 + Subdevices: 1/1
  49 + Subdevice #0: subdevice #0
  50 +
  51 +and if you want to select card 3 and the device 0 on that card, please use:
  52 +
  53 + hw:3,0
  54 +
  55 +as the device_name.
  56 +)usage";
  57 +
  58 + fprintf(stderr, "%s\n", usage);
  59 + fprintf(stderr, "argc, %d\n", argc);
  60 +
  61 + return 0;
  62 + }
  63 +
  64 + signal(SIGINT, Handler);
  65 +
  66 + sherpa_onnx::OnlineRecognizerConfig config;
  67 +
  68 + config.tokens = argv[1];
  69 +
  70 + config.model_config.debug = false;
  71 + config.model_config.encoder_filename = argv[2];
  72 + config.model_config.decoder_filename = argv[3];
  73 + config.model_config.joiner_filename = argv[4];
  74 +
  75 + const char *device_name = argv[5];
  76 +
  77 + config.model_config.num_threads = 2;
  78 + if (argc == 7 && atoi(argv[6]) > 0) {
  79 + config.model_config.num_threads = atoi(argv[6]);
  80 + }
  81 +
  82 + config.enable_endpoint = true;
  83 +
  84 + config.endpoint_config.rule1.min_trailing_silence = 2.4;
  85 + config.endpoint_config.rule2.min_trailing_silence = 1.2;
  86 + config.endpoint_config.rule3.min_utterance_length = 300;
  87 +
  88 + fprintf(stderr, "%s\n", config.ToString().c_str());
  89 +
  90 + sherpa_onnx::OnlineRecognizer recognizer(config);
  91 +
  92 + int32_t expected_sample_rate = config.feat_config.sampling_rate;
  93 +
  94 + sherpa_onnx::Alsa alsa(device_name);
  95 + fprintf(stderr, "Use recording device: %s\n", device_name);
  96 +
  97 + if (alsa.GetExpectedSampleRate() != expected_sample_rate) {
  98 + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(),
  99 + expected_sample_rate);
  100 + exit(-1);
  101 + }
  102 +
  103 + int32_t chunk = 0.1 * alsa.GetActualSampleRate();
  104 +
  105 + std::string last_text;
  106 +
  107 + auto stream = recognizer.CreateStream();
  108 +
  109 + sherpa_onnx::Display display;
  110 +
  111 + int32_t segment_index = 0;
  112 + while (!stop) {
  113 + const std::vector<float> samples = alsa.Read(chunk);
  114 +
  115 + stream->AcceptWaveform(expected_sample_rate, samples.data(),
  116 + samples.size());
  117 +
  118 + while (recognizer.IsReady(stream.get())) {
  119 + recognizer.DecodeStream(stream.get());
  120 + }
  121 +
  122 + auto text = recognizer.GetResult(stream.get()).text;
  123 +
  124 + bool is_endpoint = recognizer.IsEndpoint(stream.get());
  125 +
  126 + if (!text.empty() && last_text != text) {
  127 + last_text = text;
  128 +
  129 + std::transform(text.begin(), text.end(), text.begin(),
  130 + [](auto c) { return std::tolower(c); });
  131 +
  132 + display.Print(segment_index, text);
  133 + }
  134 +
  135 + if (!text.empty() && is_endpoint) {
  136 + ++segment_index;
  137 + recognizer.Reset(stream.get());
  138 + }
  139 + }
  140 +
  141 + return 0;
  142 +}
@@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx @@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx
4 features.cc 4 features.cc
5 online-transducer-model-config.cc 5 online-transducer-model-config.cc
6 sherpa-onnx.cc 6 sherpa-onnx.cc
  7 + endpoint.cc
7 online-stream.cc 8 online-stream.cc
8 online-recognizer.cc 9 online-recognizer.cc
9 ) 10 )
  1 +// sherpa-onnx/csrc/endpoint.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/endpoint.h"
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/endpoint.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +static constexpr const char *kEndpointRuleInitDoc = R"doc(
  15 +Constructor for EndpointRule.
  16 +
  17 +Args:
  18 + must_contain_nonsilence:
  19 + If True, for this endpointing rule to apply there must be nonsilence in the
  20 + best-path traceback. For decoding, a non-blank token is considered as
  21 + non-silence.
  22 + min_trailing_silence:
  23 + This endpointing rule requires duration of trailing silence (in seconds)
  24 + to be ``>=`` this value.
  25 + min_utterance_length:
  26 + This endpointing rule requires utterance-length (in seconds) to
  27 + be ``>=`` this value.
  28 +)doc";
  29 +
  30 +static constexpr const char *kEndpointConfigInitDoc = R"doc(
  31 +If any rule in EndpointConfig is activated, it is said that an endpointing
  32 +is detected.
  33 +
  34 +Args:
  35 + rule1:
  36 + By default, it times out after 2.4 seconds of silence, even if
  37 + we decoded nothing.
  38 + rule2:
  39 + By default, it times out after 1.2 seconds of silence after decoding
  40 + something.
  41 + rule3:
  42 + By default, it times out after the utterance is 20 seconds long, regardless of
  43 + anything else.
  44 +)doc";
  45 +
  46 +static void PybindEndpointRule(py::module *m) {
  47 + using PyClass = EndpointRule;
  48 + py::class_<PyClass>(*m, "EndpointRule")
  49 + .def(py::init<bool, float, float>(), py::arg("must_contain_nonsilence"),
  50 + py::arg("min_trailing_silence"), py::arg("min_utterance_length"),
  51 + kEndpointRuleInitDoc)
  52 + .def("__str__", &PyClass::ToString)
  53 + .def_readwrite("must_contain_nonsilence",
  54 + &PyClass::must_contain_nonsilence)
  55 + .def_readwrite("min_trailing_silence", &PyClass::min_trailing_silence)
  56 + .def_readwrite("min_utterance_length", &PyClass::min_utterance_length);
  57 +}
  58 +
  59 +static void PybindEndpointConfig(py::module *m) {
  60 + using PyClass = EndpointConfig;
  61 + py::class_<PyClass>(*m, "EndpointConfig")
  62 + .def(
  63 + py::init(
  64 + [](float rule1_min_trailing_silence,
  65 + float rule2_min_trailing_silence,
  66 + float rule3_min_utterance_length) -> std::unique_ptr<PyClass> {
  67 + EndpointRule rule1(false, rule1_min_trailing_silence, 0);
  68 + EndpointRule rule2(true, rule2_min_trailing_silence, 0);
  69 + EndpointRule rule3(false, 0, rule3_min_utterance_length);
  70 +
  71 + return std::make_unique<EndpointConfig>(rule1, rule2, rule3);
  72 + }),
  73 + py::arg("rule1_min_trailing_silence"),
  74 + py::arg("rule2_min_trailing_silence"),
  75 + py::arg("rule3_min_utterance_length"))
  76 + .def(py::init([](const EndpointRule &rule1, const EndpointRule &rule2,
  77 + const EndpointRule &rule3) -> std::unique_ptr<PyClass> {
  78 + auto ans = std::make_unique<PyClass>();
  79 + ans->rule1 = rule1;
  80 + ans->rule2 = rule2;
  81 + ans->rule3 = rule3;
  82 + return ans;
  83 + }),
  84 + py::arg("rule1") = EndpointRule(false, 2.4, 0),
  85 + py::arg("rule2") = EndpointRule(true, 1.2, 0),
  86 + py::arg("rule3") = EndpointRule(false, 0, 20),
  87 + kEndpointConfigInitDoc)
  88 + .def("__str__",
  89 + [](const PyClass &self) -> std::string { return self.ToString(); })
  90 + .def_readwrite("rule1", &PyClass::rule1)
  91 + .def_readwrite("rule2", &PyClass::rule2)
  92 + .def_readwrite("rule3", &PyClass::rule3);
  93 +}
  94 +
  95 +void PybindEndpoint(py::module *m) {
  96 + PybindEndpointRule(m);
  97 + PybindEndpointConfig(m);
  98 +}
  99 +
  100 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/endpoint.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindEndpoint(py::module *m);
  13 +
  14 +} // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
@@ -21,11 +21,15 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -21,11 +21,15 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
21 using PyClass = OnlineRecognizerConfig; 21 using PyClass = OnlineRecognizerConfig;
22 py::class_<PyClass>(*m, "OnlineRecognizerConfig") 22 py::class_<PyClass>(*m, "OnlineRecognizerConfig")
23 .def(py::init<const FeatureExtractorConfig &, 23 .def(py::init<const FeatureExtractorConfig &,
24 - const OnlineTransducerModelConfig &, const std::string &>(),  
25 - py::arg("feat_config"), py::arg("model_config"), py::arg("tokens")) 24 + const OnlineTransducerModelConfig &, const std::string &,
  25 + const EndpointConfig &, bool>(),
  26 + py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"),
  27 + py::arg("endpoint_config"), py::arg("enable_endpoint"))
26 .def_readwrite("feat_config", &PyClass::feat_config) 28 .def_readwrite("feat_config", &PyClass::feat_config)
27 .def_readwrite("model_config", &PyClass::model_config) 29 .def_readwrite("model_config", &PyClass::model_config)
28 .def_readwrite("tokens", &PyClass::tokens) 30 .def_readwrite("tokens", &PyClass::tokens)
  31 + .def_readwrite("endpoint_config", &PyClass::endpoint_config)
  32 + .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
29 .def("__str__", &PyClass::ToString); 33 .def("__str__", &PyClass::ToString);
30 } 34 }
31 35
@@ -43,7 +47,9 @@ void PybindOnlineRecognizer(py::module *m) { @@ -43,7 +47,9 @@ void PybindOnlineRecognizer(py::module *m) {
43 [](PyClass &self, std::vector<OnlineStream *> ss) { 47 [](PyClass &self, std::vector<OnlineStream *> ss) {
44 self.DecodeStreams(ss.data(), ss.size()); 48 self.DecodeStreams(ss.data(), ss.size());
45 }) 49 })
46 - .def("get_result", &PyClass::GetResult); 50 + .def("get_result", &PyClass::GetResult)
  51 + .def("is_endpoint", &PyClass::IsEndpoint)
  52 + .def("reset", &PyClass::Reset);
47 } 53 }
48 54
49 } // namespace sherpa_onnx 55 } // namespace sherpa_onnx
@@ -4,6 +4,7 @@ @@ -4,6 +4,7 @@
4 4
5 #include "sherpa-onnx/python/csrc/sherpa-onnx.h" 5 #include "sherpa-onnx/python/csrc/sherpa-onnx.h"
6 6
  7 +#include "sherpa-onnx/python/csrc/endpoint.h"
7 #include "sherpa-onnx/python/csrc/features.h" 8 #include "sherpa-onnx/python/csrc/features.h"
8 #include "sherpa-onnx/python/csrc/online-recognizer.h" 9 #include "sherpa-onnx/python/csrc/online-recognizer.h"
9 #include "sherpa-onnx/python/csrc/online-stream.h" 10 #include "sherpa-onnx/python/csrc/online-stream.h"
@@ -16,6 +17,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -16,6 +17,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
16 PybindFeatures(&m); 17 PybindFeatures(&m);
17 PybindOnlineTransducerModelConfig(&m); 18 PybindOnlineTransducerModelConfig(&m);
18 PybindOnlineStream(&m); 19 PybindOnlineStream(&m);
  20 + PybindEndpoint(&m);
19 PybindOnlineRecognizer(&m); 21 PybindOnlineRecognizer(&m);
20 } 22 }
21 23
1 from _sherpa_onnx import ( 1 from _sherpa_onnx import (
  2 + EndpointConfig,
2 FeatureExtractorConfig, 3 FeatureExtractorConfig,
3 OnlineRecognizerConfig, 4 OnlineRecognizerConfig,
4 OnlineStream, 5 OnlineStream,
@@ -2,12 +2,13 @@ from pathlib import Path @@ -2,12 +2,13 @@ from pathlib import Path
2 from typing import List 2 from typing import List
3 3
4 from _sherpa_onnx import ( 4 from _sherpa_onnx import (
5 - OnlineStream,  
6 - OnlineTransducerModelConfig, 5 + EndpointConfig,
7 FeatureExtractorConfig, 6 FeatureExtractorConfig,
  7 + OnlineRecognizer as _Recognizer,
8 OnlineRecognizerConfig, 8 OnlineRecognizerConfig,
  9 + OnlineStream,
  10 + OnlineTransducerModelConfig,
9 ) 11 )
10 -from _sherpa_onnx import OnlineRecognizer as _Recognizer  
11 12
12 13
13 def _assert_file_exists(f: str): 14 def _assert_file_exists(f: str):
@@ -26,6 +27,10 @@ class OnlineRecognizer(object): @@ -26,6 +27,10 @@ class OnlineRecognizer(object):
26 num_threads: int = 4, 27 num_threads: int = 4,
27 sample_rate: float = 16000, 28 sample_rate: float = 16000,
28 feature_dim: int = 80, 29 feature_dim: int = 80,
  30 + enable_endpoint_detection: bool = False,
  31 + rule1_min_trailing_silence: int = 2.4,
  32 + rule2_min_trailing_silence: int = 1.2,
  33 + rule3_min_utterance_length: int = 20,
29 ): 34 ):
30 """ 35 """
31 Please refer to 36 Please refer to
@@ -52,6 +57,22 @@ class OnlineRecognizer(object): @@ -52,6 +57,22 @@ class OnlineRecognizer(object):
52 Sample rate of the training data used to train the model. 57 Sample rate of the training data used to train the model.
53 feature_dim: 58 feature_dim:
54 Dimension of the feature used to train the model. 59 Dimension of the feature used to train the model.
  60 + enable_endpoint_detection:
  61 + True to enable endpoint detection. False to disable endpoint
  62 + detection.
  63 + rule1_min_trailing_silence:
  64 + Used only when enable_endpoint_detection is True. If the duration
  65 + of trailing silence in seconds is larger than this value, we assume
  66 + an endpoint is detected.
  67 + rule2_min_trailing_silence:
  68 + Used only when enable_endpoint_detection is True. If we have decoded
  69 + something that is nonsilence and if the duration of trailing silence
  70 + in seconds is larger than this value, we assume an endpoint is
  71 + detected.
  72 + rule3_min_utterance_length:
  73 + Used only when enable_endpoint_detection is True. If the utterance
  74 + length in seconds is larger than this value, we assume an endpoint
  75 + is detected.
55 """ 76 """
56 _assert_file_exists(tokens) 77 _assert_file_exists(tokens)
57 _assert_file_exists(encoder) 78 _assert_file_exists(encoder)
@@ -72,10 +93,18 @@ class OnlineRecognizer(object): @@ -72,10 +93,18 @@ class OnlineRecognizer(object):
72 feature_dim=feature_dim, 93 feature_dim=feature_dim,
73 ) 94 )
74 95
  96 + endpoint_config = EndpointConfig(
  97 + rule1_min_trailing_silence=rule1_min_trailing_silence,
  98 + rule2_min_trailing_silence=rule2_min_trailing_silence,
  99 + rule3_min_utterance_length=rule3_min_utterance_length,
  100 + )
  101 +
75 recognizer_config = OnlineRecognizerConfig( 102 recognizer_config = OnlineRecognizerConfig(
76 feat_config=feat_config, 103 feat_config=feat_config,
77 model_config=model_config, 104 model_config=model_config,
78 tokens=tokens, 105 tokens=tokens,
  106 + endpoint_config=endpoint_config,
  107 + enable_endpoint=enable_endpoint_detection,
79 ) 108 )
80 109
81 self.recognizer = _Recognizer(recognizer_config) 110 self.recognizer = _Recognizer(recognizer_config)
@@ -93,4 +122,10 @@ class OnlineRecognizer(object): @@ -93,4 +122,10 @@ class OnlineRecognizer(object):
93 return self.recognizer.is_ready(s) 122 return self.recognizer.is_ready(s)
94 123
95 def get_result(self, s: OnlineStream) -> str: 124 def get_result(self, s: OnlineStream) -> str:
96 - return self.recognizer.get_result(s).text 125 + return self.recognizer.get_result(s).text.strip()
  126 +
  127 + def is_endpoint(self, s: OnlineStream) -> bool:
  128 + return self.recognizer.is_endpoint(s)
  129 +
  130 + def reset(self, s: OnlineStream) -> bool:
  131 + return self.recognizer.reset(s)