Fangjun Kuang
Committed by GitHub

Add endpointing (#54)

... ... @@ -4,9 +4,11 @@ build
onnxruntime-*
icefall-*
run.sh
sherpa-onnx-*
__pycache__
dist/
sherpa_onnx.egg-info/
.DS_Store
build-aarch64-linux-gnu
sherpa-onnx-streaming-zipformer-*
sherpa-onnx-lstm-en-*
sherpa-onnx-lstm-zh-*
... ...
... ... @@ -13,6 +13,7 @@ endif()
# User-facing build switches; override on the command line with
# -D<NAME>=ON|OFF.
option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF)
option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF)
# When ON, the core library is built with SHERPA_ONNX_ENABLE_CHECK=1, which
# turns on the SHERPA_ONNX_CHECK*/SHERPA_ONNX_LOG macros.
option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON)
option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
# Collect static archives in <build>/lib.
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
... ... @@ -46,6 +47,8 @@ message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
# Echo the effective configuration so it shows up in the configure log.
message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}")
message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}")
message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}")
# Default to C++14; being a cache entry, a user-supplied
# -DCMAKE_CXX_STANDARD=<n> takes precedence over this default.
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
# Request -std=c++NN rather than the GNU dialect (-std=gnu++NN).
set(CMAKE_CXX_EXTENSIONS OFF)
... ... @@ -56,6 +59,9 @@ if(SHERPA_ONNX_HAS_ALSA)
add_definitions(-DSHERPA_ONNX_ENABLE_ALSA=1)
endif()
# Probe for optional headers. The result variables are consulted when the
# core library's compile definitions are set up (stack-trace support in the
# logging code).
check_include_file_cxx(cxxabi.h SHERPA_ONNX_HAVE_CXXABI_H)
check_include_file_cxx(execinfo.h SHERPA_ONNX_HAVE_EXECINFO_H)

# Make the project's own CMake modules discoverable. PROJECT_SOURCE_DIR is
# used instead of CMAKE_SOURCE_DIR so the lookup still works when this
# project is pulled into another build via add_subdirectory()/FetchContent;
# for a standalone build both variables name the same directory.
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/Modules")
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
... ...
#!/usr/bin/env python3
# Real-time speech recognition from a microphone with sherpa-onnx Python API
# with endpoint detection.
#
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
# to download pre-trained models
import sys
try:
import sounddevice as sd
except ImportError as e:
print("Please install sounddevice first. You can use")
print()
print(" pip install sounddevice")
print()
print("to install it")
sys.exit(-1)
import sherpa_onnx
def create_recognizer():
    """Construct the streaming recognizer used by this demo.

    The model paths below point at a locally downloaded copy of the
    bilingual (zh-en) streaming zipformer model; replace them if you use a
    different model. Download links are listed at
    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html

    Endpoint detection is switched on and the three endpoint rules are
    configured explicitly.

    Returns:
      A ``sherpa_onnx.OnlineRecognizer`` instance.
    """
    return sherpa_onnx.OnlineRecognizer(
        tokens="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt",
        encoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx",
        decoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx",
        joiner="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx",
        num_threads=4,
        sample_rate=16000,
        feature_dim=80,
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=2.4,
        rule2_min_trailing_silence=1.2,
        # 300 seconds: large enough that rule 3 is effectively disabled.
        rule3_min_utterance_length=300,
    )
def main():
    """Stream microphone audio into the recognizer and print results.

    Audio is read from the default input device in 100 ms chunks. After
    each chunk the stream is decoded as far as possible and the current
    result is printed whenever it changes, prefixed with the index of the
    current segment. When the recognizer reports an endpoint for a
    non-empty result, the segment index is advanced and the stream is
    reset.

    Runs until interrupted (e.g. with Ctrl + C).
    """
    print("Started! Please speak")
    recognizer = create_recognizer()
    sample_rate = 16000
    samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms

    stream = recognizer.create_stream()

    last_result = ""
    segment_id = 0
    with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
        while True:
            samples, _ = s.read(samples_per_read)  # a blocking read
            samples = samples.reshape(-1)
            stream.accept_waveform(sample_rate, samples)
            while recognizer.is_ready(stream):
                recognizer.decode_stream(stream)

            # Query the endpoint state before fetching the result so both
            # refer to the same decoding position.
            is_endpoint = recognizer.is_endpoint(stream)
            result = recognizer.get_result(stream)

            if result and (last_result != result):
                last_result = result
                print(f"{segment_id}: {result}")

            if result and is_endpoint:
                segment_id += 1
                recognizer.reset(stream)
                # Start the next segment with a clean slate; otherwise a new
                # segment whose text equals the previous segment's final
                # text would never be printed.
                last_result = ""
if __name__ == "__main__":
    # List every audio device, then report which one is the default input.
    devices = sd.query_devices()
    print(devices)
    default_input_device_idx = sd.default.device[0]
    default_input_device_name = devices[default_input_device_idx]["name"]
    print(f"Use default device: {default_input_device_name}")

    try:
        main()
    except KeyboardInterrupt:
        # main() loops forever; Ctrl + C is the expected way to stop it.
        print("\nCaught Ctrl + C. Exiting")
... ...
include_directories(${CMAKE_SOURCE_DIR})
add_library(sherpa-onnx-core
set(sources
cat.cc
endpoint.cc
features.cc
online-lstm-transducer-model.cc
online-recognizer.cc
... ... @@ -11,6 +12,7 @@ add_library(sherpa-onnx-core
online-transducer-model.cc
online-zipformer-transducer-model.cc
onnx-utils.cc
parse-options.cc
resample.cc
symbol-table.cc
text-utils.cc
... ... @@ -18,11 +20,29 @@ add_library(sherpa-onnx-core
wave-reader.cc
)
# log.cc provides the stack-trace helpers behind the logging/check macros;
# it is compiled only when checks are enabled.
if(SHERPA_ONNX_ENABLE_CHECK)
list(APPEND sources log.cc)
endif()
# No STATIC/SHARED keyword: the library type follows BUILD_SHARED_LIBS.
add_library(sherpa-onnx-core ${sources})
# NOTE(review): keyword-less target_link_libraries signature. Adding
# PUBLIC/PRIVATE would be preferable, but every other call on this target
# must then use the keyword form too -- confirm before changing.
target_link_libraries(sherpa-onnx-core
onnxruntime
kaldi-native-fbank-core
)
# Compile definitions that accompany the logging/check support.
#
# SHERPA_ONNX_ENABLE_CHECK is PUBLIC because it is tested in a header
# (log.h) that consumers of the library include. The two HAVE_* macros are
# only needed while compiling the library itself, so they stay PRIVATE.
# Each is emitted only if the corresponding header probe succeeded; a
# generator expression that evaluates to an empty string adds nothing.
if(SHERPA_ONNX_ENABLE_CHECK)
  target_compile_definitions(sherpa-onnx-core
    PUBLIC
      SHERPA_ONNX_ENABLE_CHECK=1
    PRIVATE
      $<$<BOOL:${SHERPA_ONNX_HAVE_EXECINFO_H}>:SHERPA_ONNX_HAVE_EXECINFO_H=1>
      $<$<BOOL:${SHERPA_ONNX_HAVE_CXXABI_H}>:SHERPA_ONNX_HAVE_CXXABI_H=1>
  )
endif()
# Command-line executable built from sherpa-onnx.cc and linked against the
# core library.
add_executable(sherpa-onnx sherpa-onnx.cc)
target_link_libraries(sherpa-onnx sherpa-onnx-core)
... ...
// sherpa-onnx/csrc/endpoint.cc
//
// Copyright (c) 2022 (authors: Pingfeng Luo)
// 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/endpoint.h"
#include <string>
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
// Decides whether a single endpointing rule fires.
//
// @param rule              The rule to evaluate.
// @param rule_name         Name used only in the debug log message.
// @param trailing_silence  Duration of trailing silence, in seconds.
// @param utterance_length  Duration of the utterance so far, in seconds.
// @return true if every condition of `rule` is satisfied.
static bool RuleActivated(const EndpointRule &rule,
                          const std::string &rule_name, float trailing_silence,
                          float utterance_length) {
  // The utterance contains non-silence iff it is longer than the silence at
  // its end.
  const bool contain_nonsilence = utterance_length > trailing_silence;

  const bool nonsilence_ok =
      contain_nonsilence || !rule.must_contain_nonsilence;
  const bool silence_ok = trailing_silence >= rule.min_trailing_silence;
  const bool length_ok = utterance_length >= rule.min_utterance_length;

  if (!(nonsilence_ok && silence_ok && length_ok)) {
    return false;
  }

  SHERPA_ONNX_LOG(DEBUG) << "Endpointing rule " << rule_name << " activated: "
                         << (contain_nonsilence ? "true" : "false") << ','
                         << trailing_silence << ',' << utterance_length;
  return true;
}
// Registers the three command-line options that configure one endpoint rule.
//
// The option names are `<rule_name>-must-contain-nonsilence`,
// `<rule_name>-min-trailing-silence` and `<rule_name>-min-utterance-length`;
// parsed values are written straight into the fields of `*rule`.
//
// @param po         Option parser to register with.
// @param rule       Rule whose fields receive the parsed values.
// @param rule_name  Prefix for the option names, e.g. "rule1".
static void RegisterEndpointRule(ParseOptions *po, EndpointRule *rule,
                                 const std::string &rule_name) {
  po->Register(
      rule_name + "-must-contain-nonsilence", &rule->must_contain_nonsilence,
      "If True, for this endpointing " + rule_name +
          " to apply there must be nonsilence in the best-path traceback. "
          "For decoding, a non-blank token is considered as non-silence");

  // The help text previously read "trailing silence in seconds)" with an
  // unbalanced parenthesis; it now matches the wording used for
  // min-utterance-length below.
  po->Register(rule_name + "-min-trailing-silence", &rule->min_trailing_silence,
               "This endpointing " + rule_name +
                   " requires duration of trailing silence (in seconds) to "
                   "be >= this value.");

  po->Register(rule_name + "-min-utterance-length", &rule->min_utterance_length,
               "This endpointing " + rule_name +
                   " requires utterance-length (in seconds) to be >= this "
                   "value.");
}
// Returns a one-line, human-readable description of this rule of the form
// EndpointRule(must_contain_nonsilence=..., min_trailing_silence=...,
// min_utterance_length=...).
std::string EndpointRule::ToString() const {
  const char *nonsilence_str = must_contain_nonsilence ? "True" : "False";

  std::ostringstream os;
  os << "EndpointRule("
     << "must_contain_nonsilence=" << nonsilence_str << ", "
     << "min_trailing_silence=" << min_trailing_silence << ", "
     << "min_utterance_length=" << min_utterance_length << ")";
  return os.str();
}
// Registers the command-line options of all three rules with `po`, using the
// option-name prefixes "rule1", "rule2" and "rule3".
void EndpointConfig::Register(ParseOptions *po) {
RegisterEndpointRule(po, &rule1, "rule1");
RegisterEndpointRule(po, &rule2, "rule2");
RegisterEndpointRule(po, &rule3, "rule3");
}
// Returns a one-line, human-readable description of all three rules of the
// form EndpointConfig(rule1=..., rule2=..., rule3=...).
std::string EndpointConfig::ToString() const {
  std::ostringstream os;
  os << "EndpointConfig("
     << "rule1=" << rule1.ToString() << ", "
     << "rule2=" << rule2.ToString() << ", "
     << "rule3=" << rule3.ToString() << ")";
  return os.str();
}
// Returns true as soon as any of the three configured rules fires.
//
// Frame counts are converted to seconds with `frame_shift_in_seconds`
// before the rules are evaluated. Rules are tried in the order rule1,
// rule2, rule3 and evaluation stops at the first one that fires.
bool Endpoint::IsEndpoint(int num_frames_decoded, int trailing_silence_frames,
                          float frame_shift_in_seconds) const {
  const float utterance_seconds = num_frames_decoded * frame_shift_in_seconds;
  const float silence_seconds =
      trailing_silence_frames * frame_shift_in_seconds;

  return RuleActivated(config_.rule1, "rule1", silence_seconds,
                       utterance_seconds) ||
         RuleActivated(config_.rule2, "rule2", silence_seconds,
                       utterance_seconds) ||
         RuleActivated(config_.rule3, "rule3", silence_seconds,
                       utterance_seconds);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/endpoint.h
//
// Copyright (c) 2022 (authors: Pingfeng Luo)
// 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ENDPOINT_H_
#define SHERPA_ONNX_CSRC_ENDPOINT_H_
#include <string>
#include <vector>
namespace sherpa_onnx {
// A single endpointing rule. The rule fires when all three of its
// conditions hold at the same time.
struct EndpointRule {
  // When true, the rule can only fire if the utterance contains
  // non-silence in the best-path traceback. For decoding, a non-blank
  // token counts as non-silence.
  bool must_contain_nonsilence = true;

  // Minimum duration of trailing silence, in seconds, for the rule to fire.
  float min_trailing_silence = 2.0f;

  // Minimum utterance length, in seconds, for the rule to fire.
  float min_utterance_length = 0.0f;

  // Uses the default member values above.
  EndpointRule() = default;

  // Sets every field explicitly.
  EndpointRule(bool must_contain_nonsilence, float min_trailing_silence,
               float min_utterance_length)
      : must_contain_nonsilence(must_contain_nonsilence),
        min_trailing_silence(min_trailing_silence),
        min_utterance_length(min_utterance_length) {}

  // Human-readable description of this rule.
  std::string ToString() const;
};
class ParseOptions;
// The set of three rules used for endpoint detection.
struct EndpointConfig {
  // With the default settings:
  //  - rule1 fires after 2.4 seconds of trailing silence, even if nothing
  //    has been decoded;
  //  - rule2 fires after 1.2 seconds of trailing silence once something has
  //    been decoded;
  //  - rule3 fires once the utterance is 20 seconds long, regardless of
  //    anything else.
  EndpointRule rule1;
  EndpointRule rule2;
  EndpointRule rule3;

  // Registers the command-line options of all three rules with `po`.
  void Register(ParseOptions *po);

  // Installs the default rules described above.
  EndpointConfig()
      : rule1(false, 2.4f, 0.0f),
        rule2(true, 1.2f, 0.0f),
        rule3(false, 0.0f, 20.0f) {}

  // Uses caller-supplied rules.
  EndpointConfig(const EndpointRule &rule1, const EndpointRule &rule2,
                 const EndpointRule &rule3)
      : rule1(rule1), rule2(rule2), rule3(rule3) {}

  // Human-readable description of all three rules.
  std::string ToString() const;
};
// Evaluates the rules of an EndpointConfig against the decoding progress of
// a stream.
class Endpoint {
public:
// Keeps its own copy of `config`.
explicit Endpoint(const EndpointConfig &config) : config_(config) {}
/// This function returns true if this set of endpointing rules thinks we
/// should terminate decoding.
///
/// @param num_frames_decoded       Number of frames decoded so far.
/// @param trailing_silence_frames  Number of trailing silence frames.
/// @param frame_shift_in_seconds   Duration of one frame in seconds; both
///                                 frame counts are multiplied by it to
///                                 obtain durations.
bool IsEndpoint(int num_frames_decoded, int trailing_silence_frames,
float frame_shift_in_seconds) const;
private:
EndpointConfig config_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ENDPOINT_H_
... ...
// sherpa-onnx/csrc/log.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/log.h"
#ifdef SHERPA_ONNX_HAVE_EXECINFO_H
#include <execinfo.h> // To get stack trace in error messages.
#ifdef SHERPA_ONNX_HAVE_CXXABI_H
#include <cxxabi.h> // For name demangling.
// Useful to decode the stack trace, but only used if we have execinfo.h
#endif // SHERPA_ONNX_HAVE_CXXABI_H
#endif // SHERPA_ONNX_HAVE_EXECINFO_H
#include <stdlib.h>
#include <ctime>
#include <iomanip>
#include <string>
namespace sherpa_onnx {
// Returns the current local time formatted as "yyyy-mm-dd hh:mm:ss".
std::string GetDateTimeStr() {
  const std::time_t now = std::time(nullptr);
  // Copy the broken-down time out of the buffer returned by localtime().
  const std::tm local_time = *std::localtime(&now);

  std::ostringstream formatted;
  formatted << std::put_time(&local_time, "%F %T");  // yyyy-mm-dd hh:mm:ss
  return formatted.str();
}
// Locates the mangled symbol inside one line of backtrace output.
//
// The symbol is taken to start at the first '_' that is directly preceded by
// ' ' or '(' and to end just before the next ' ' or '+'.
//
// @param trace_name  One backtrace line.
// @param begin       Receives the index of the first character of the
//                    symbol, or std::string::npos if no start was found.
// @param end         Receives the index one past the last character of the
//                    symbol. Written only if a start was found.
// @return true if both the start and the end were found.
static bool LocateSymbolRange(const std::string &trace_name, std::size_t *begin,
                              std::size_t *end) {
  *begin = std::string::npos;

  // Start searching at index 1 so there is always a preceding character to
  // inspect.
  std::size_t candidate = trace_name.find('_', 1);
  while (candidate != std::string::npos) {
    const char previous = trace_name[candidate - 1];
    if (previous == ' ' || previous == '(') {
      *begin = candidate;
      break;
    }
    candidate = trace_name.find('_', candidate + 1);
  }

  if (*begin == std::string::npos) {
    return false;
  }

  *end = trace_name.find_first_of(" +", *begin);
  return *end != std::string::npos;
}
#ifdef SHERPA_ONNX_HAVE_EXECINFO_H
// Replaces the mangled C++ symbol in one backtrace line with its demangled
// form. The line is returned unchanged when cxxabi.h is unavailable, when no
// symbol can be located, or when demangling fails.
static std::string Demangle(const std::string &trace_name) {
#ifndef SHERPA_ONNX_HAVE_CXXABI_H
  return trace_name;
#else   // SHERPA_ONNX_HAVE_CXXABI_H
  // The following formats, produced by different platforms, are supported:
  //
  // Linux:
  //    ./kaldi-error-test(_ZN5kaldi13UnitTestErrorEv+0xb) [0x804965d]
  //
  // Mac:
  //    0 server 0x000000010f67614d _ZNK5kaldi13MessageLogger10LogMessageEv + 813
  //
  // The goal is to extract the name, e.g. '_ZN5kaldi13UnitTestErrorEv', and
  // demangle it into a readable name like kaldi::UnitTestError.
  std::size_t first = 0;
  std::size_t last = 0;
  if (!LocateSymbolRange(trace_name, &first, &last)) {
    return trace_name;
  }

  std::string symbol = trace_name.substr(first, last - first);

  int status = 0;
  char *readable =
      abi::__cxa_demangle(symbol.c_str(), nullptr, nullptr, &status);
  if (status == 0 && readable != nullptr) {
    symbol = readable;
    // The buffer returned by __cxa_demangle is malloc()-ed and owned by us.
    free(readable);
  }

  // prefix + (possibly demangled) symbol + suffix
  std::string rebuilt = trace_name.substr(0, first);
  rebuilt += symbol;
  rebuilt += trace_name.substr(last, std::string::npos);
  return rebuilt;
#endif  // SHERPA_ONNX_HAVE_CXXABI_H
}
#endif  // SHERPA_ONNX_HAVE_EXECINFO_H
// Returns a human-readable stack trace of the calling thread, one frame per
// line, preceded by a "[ Stack-Trace: ]" header.
//
// Returns an empty string when execinfo.h is unavailable or when the frame
// symbols cannot be obtained.
std::string GetStackTrace() {
  std::string ans;
#ifdef SHERPA_ONNX_HAVE_EXECINFO_H
  constexpr std::size_t kMaxTraceSize = 50;
  constexpr std::size_t kMaxTracePrint = 50;  // Must be even.

  // Capture the raw return addresses and resolve them to symbol strings.
  void *frames[kMaxTraceSize];
  const std::size_t num_frames = backtrace(frames, kMaxTraceSize);
  char **symbols = backtrace_symbols(frames, num_frames);
  if (symbols == nullptr) return ans;

  // Appends the demangled frames with index in [first, last) to `ans`.
  auto append_frames = [&ans, symbols](std::size_t first, std::size_t last) {
    for (std::size_t k = first; k < last; ++k) {
      ans += Demangle(symbols[k]) + "\n";
    }
  };

  ans += "[ Stack-Trace: ]\n";
  if (num_frames > kMaxTracePrint) {
    // Too many frames: print the first and the last kMaxTracePrint / 2 of
    // them, separated by an ellipsis.
    append_frames(0, kMaxTracePrint / 2);
    ans += ".\n.\n.\n";
    append_frames(num_frames - kMaxTracePrint / 2, num_frames);
    if (num_frames == kMaxTraceSize)
      ans += ".\n.\n.\n";  // Stack was too long, probably a bug.
  } else {
    append_frames(0, num_frames);
  }

  // Only the array returned by backtrace_symbols() must be freed, not the
  // strings it points to.
  free(symbols);
#endif  // SHERPA_ONNX_HAVE_EXECINFO_H
  return ans;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/log.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_LOG_H_
#define SHERPA_ONNX_CSRC_LOG_H_
#include <stdio.h>
#include <mutex> // NOLINT
#include <sstream>
#include <string>
namespace sherpa_onnx {
#if SHERPA_ONNX_ENABLE_CHECK
// kDisableDebug is true when NDEBUG is defined. The SHERPA_ONNX_D* macros
// consult it to turn debug-only checks and logs into no-ops.
#if defined(NDEBUG)
constexpr bool kDisableDebug = true;
#else
constexpr bool kDisableDebug = false;
#endif
// Severity of a log message, from least (kTrace) to most (kFatal) severe.
// A message is printed when its level is not below the current log level.
enum class LogLevel {
kTrace = 0,
kDebug = 1,
kInfo = 2,
kWarning = 3,
kError = 4,
kFatal = 5, // print message and abort the program
};
// They are used in SHERPA_ONNX_LOG(xxx), so their names
// do not follow the google c++ code style
//
// You can use them in the following way:
//
// SHERPA_ONNX_LOG(TRACE) << "some message";
// SHERPA_ONNX_LOG(DEBUG) << "some message";
// Short aliases for the LogLevel enumerators. They are typed constants
// everywhere except MSVC, where they are preprocessor macros instead.
// NOTE(review): presumably the macro form avoids clashes with identifiers
// such as ERROR that Windows headers define as macros -- confirm.
#ifndef _MSC_VER
constexpr LogLevel TRACE = LogLevel::kTrace;
constexpr LogLevel DEBUG = LogLevel::kDebug;
constexpr LogLevel INFO = LogLevel::kInfo;
constexpr LogLevel WARNING = LogLevel::kWarning;
constexpr LogLevel ERROR = LogLevel::kError;
constexpr LogLevel FATAL = LogLevel::kFatal;
#else
#define TRACE LogLevel::kTrace
#define DEBUG LogLevel::kDebug
#define INFO LogLevel::kInfo
#define WARNING LogLevel::kWarning
#define ERROR LogLevel::kError
#define FATAL LogLevel::kFatal
#endif
std::string GetStackTrace();
/* Return the current log level.
If the current log level is TRACE, then all logged messages are printed out.
If the current log level is DEBUG, log messages with "TRACE" level are not
shown and all other levels are printed out.
Similarly, if the current log level is INFO, log message with "TRACE" and
"DEBUG" are not shown and all other levels are printed out.
If it is FATAL, then only FATAL messages are shown.
*/
// Returns the process-wide log level.
//
// The level defaults to INFO and can be overridden with the environment
// variable SHERPA_ONNX_LOG_LEVEL, which is read exactly once, on the first
// call. An unrecognised value is reported on stderr and the default is kept.
inline LogLevel GetCurrentLogLevel() {
  static LogLevel log_level = INFO;
  static std::once_flag init_flag;
  std::call_once(init_flag, []() {
    const char *env_log_level = std::getenv("SHERPA_ONNX_LOG_LEVEL");
    if (env_log_level == nullptr) return;

    const std::string s = env_log_level;
    if (s == "TRACE") {
      log_level = TRACE;
    } else if (s == "DEBUG") {
      log_level = DEBUG;
    } else if (s == "INFO") {
      log_level = INFO;
    } else if (s == "WARNING") {
      log_level = WARNING;
    } else if (s == "ERROR") {
      log_level = ERROR;
    } else if (s == "FATAL") {
      log_level = FATAL;
    } else {
      // Terminate the diagnostic with a newline so that whatever is written
      // to stderr next starts on its own line.
      fprintf(stderr,
              "Unknown SHERPA_ONNX_LOG_LEVEL: %s"
              "\nSupported values are: "
              "TRACE, DEBUG, INFO, WARNING, ERROR, FATAL\n",
              s.c_str());
    }
  });
  return log_level;
}
// Returns true if the environment variable SHERPA_ONNX_ABORT is set (to any
// value, including the empty string).
//
// The environment is consulted only once, on the first call; the
// initialisation of a function-local static is thread-safe.
inline bool EnableAbort() {
  static const bool abort_requested =
      (std::getenv("SHERPA_ONNX_ABORT") != nullptr);
  return abort_requested;
}
class Logger {
public:
Logger(const char *filename, const char *func_name, uint32_t line_num,
LogLevel level)
: filename_(filename),
func_name_(func_name),
line_num_(line_num),
level_(level) {
cur_level_ = GetCurrentLogLevel();
switch (level) {
case TRACE:
if (cur_level_ <= TRACE) fprintf(stderr, "[T] ");
break;
case DEBUG:
if (cur_level_ <= DEBUG) fprintf(stderr, "[D] ");
break;
case INFO:
if (cur_level_ <= INFO) fprintf(stderr, "[I] ");
break;
case WARNING:
if (cur_level_ <= WARNING) fprintf(stderr, "[W] ");
break;
case ERROR:
if (cur_level_ <= ERROR) fprintf(stderr, "[E] ");
break;
case FATAL:
if (cur_level_ <= FATAL) fprintf(stderr, "[F] ");
break;
}
if (cur_level_ <= level_) {
fprintf(stderr, "%s:%u:%s ", filename, line_num, func_name);
}
}
~Logger() noexcept(false) {
static constexpr const char *kErrMsg = R"(
Some bad things happened. Please read the above error messages and stack
trace. If you are using Python, the following command may be helpful:
gdb --args python /path/to/your/code.py
(You can use `gdb` to debug the code. Please consider compiling
a debug version of sherpa_onnx.).
If you are unable to fix it, please open an issue at:
https://github.com/csukuangfj/kaldi-native-fbank/issues/new
)";
if (level_ == FATAL) {
fprintf(stderr, "\n");
std::string stack_trace = GetStackTrace();
if (!stack_trace.empty()) {
fprintf(stderr, "\n\n%s\n", stack_trace.c_str());
}
fflush(nullptr);
#ifndef __ANDROID_API__
if (EnableAbort()) {
// NOTE: abort() will terminate the program immediately without
// printing the Python stack backtrace.
abort();
}
throw std::runtime_error(kErrMsg);
#else
abort();
#endif
}
}
const Logger &operator<<(bool b) const {
if (cur_level_ <= level_) {
fprintf(stderr, b ? "true" : "false");
}
return *this;
}
const Logger &operator<<(int8_t i) const {
if (cur_level_ <= level_) fprintf(stderr, "%d", i);
return *this;
}
const Logger &operator<<(const char *s) const {
if (cur_level_ <= level_) fprintf(stderr, "%s", s);
return *this;
}
const Logger &operator<<(int32_t i) const {
if (cur_level_ <= level_) fprintf(stderr, "%d", i);
return *this;
}
const Logger &operator<<(uint32_t i) const {
if (cur_level_ <= level_) fprintf(stderr, "%u", i);
return *this;
}
const Logger &operator<<(uint64_t i) const {
if (cur_level_ <= level_)
fprintf(stderr, "%llu", (long long unsigned int)i); // NOLINT
return *this;
}
const Logger &operator<<(int64_t i) const {
if (cur_level_ <= level_)
fprintf(stderr, "%lli", (long long int)i); // NOLINT
return *this;
}
const Logger &operator<<(float f) const {
if (cur_level_ <= level_) fprintf(stderr, "%f", f);
return *this;
}
const Logger &operator<<(double d) const {
if (cur_level_ <= level_) fprintf(stderr, "%f", d);
return *this;
}
template <typename T>
const Logger &operator<<(const T &t) const {
// require T overloads operator<<
std::ostringstream os;
os << t;
return *this << os.str().c_str();
}
// specialization to fix compile error: `stringstream << nullptr` is ambiguous
const Logger &operator<<(const std::nullptr_t &null) const {
if (cur_level_ <= level_) *this << "(null)";
return *this;
}
private:
const char *filename_;
const char *func_name_;
uint32_t line_num_;
LogLevel level_;
LogLevel cur_level_;
};
#endif // SHERPA_ONNX_ENABLE_CHECK
// Helper behind the check/log macros.
//
// With checks enabled, `Voidifier() & Logger(...) << ...` turns the streamed
// Logger expression into `void`, so that it can be the false branch of
// `cond ? (void)0 : ...`.
class Voidifier {
public:
#if SHERPA_ONNX_ENABLE_CHECK
void operator&(const Logger &) const {}
#endif
};
// With checks disabled, the macros expand to a bare Voidifier(); this
// operator<< swallows whatever is streamed into it.
// NOTE(review): the guard above tests the macro's value while this one tests
// whether it is defined at all. Defining SHERPA_ONNX_ENABLE_CHECK to 0 would
// select neither -- confirm the macro is only ever defined as 1 or left
// undefined.
#if !defined(SHERPA_ONNX_ENABLE_CHECK)
template <typename T>
const Voidifier &operator<<(const Voidifier &v, T &&) {
return v;
}
#endif
} // namespace sherpa_onnx
// Compile-time assertion with an empty diagnostic message.
#define SHERPA_ONNX_STATIC_ASSERT(x) static_assert(x, "")
#ifdef SHERPA_ONNX_ENABLE_CHECK
// SHERPA_ONNX_FUNC expands to the name of the enclosing function, using the
// more detailed __PRETTY_FUNCTION__ where available.
#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) || \
defined(__PRETTY_FUNCTION__)
// for clang and GCC
#define SHERPA_ONNX_FUNC __PRETTY_FUNCTION__
#else
// for other compilers
#define SHERPA_ONNX_FUNC __func__
#endif
// SHERPA_ONNX_CHECK(x) << "msg";
//
// Does nothing when `x` is true. Otherwise logs "Check failed: <x>" plus any
// streamed message at FATAL level.
#define SHERPA_ONNX_CHECK(x) \
(x) ? (void)0 \
: ::sherpa_onnx::Voidifier() & \
::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, \
::sherpa_onnx::FATAL) \
<< "Check failed: " << #x << " "
// WARNING: x and y may be evaluated multiple times, but this happens only
// when the check fails. Since the program aborts if it fails, we don't think
// the extra evaluation of x and y matters.
//
// CAUTION: we recommend the following use case:
//
// auto x = Foo();
// auto y = Bar();
// SHERPA_ONNX_CHECK_EQ(x, y) << "Some message";
//
// And please avoid
//
// SHERPA_ONNX_CHECK_EQ(Foo(), Bar());
//
// if `Foo()` or `Bar()` causes some side effects, e.g., changing some
// local static variables or global variables.
// Shared implementation of the binary comparison checks below: does nothing
// when `(x) op (y)` holds, otherwise logs the failed comparison together
// with both values at FATAL level.
#define _SHERPA_ONNX_CHECK_OP(x, y, op) \
((x)op(y)) ? (void)0 \
: ::sherpa_onnx::Voidifier() & \
::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, \
::sherpa_onnx::FATAL) \
<< "Check failed: " << #x << " " << #op << " " << #y \
<< " (" << (x) << " vs. " << (y) << ") "
// One check macro per comparison operator.
#define SHERPA_ONNX_CHECK_EQ(x, y) _SHERPA_ONNX_CHECK_OP(x, y, ==)
#define SHERPA_ONNX_CHECK_NE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, !=)
#define SHERPA_ONNX_CHECK_LT(x, y) _SHERPA_ONNX_CHECK_OP(x, y, <)
#define SHERPA_ONNX_CHECK_LE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, <=)
#define SHERPA_ONNX_CHECK_GT(x, y) _SHERPA_ONNX_CHECK_OP(x, y, >)
#define SHERPA_ONNX_CHECK_GE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, >=)
// SHERPA_ONNX_LOG(LEVEL) << "msg";
//
// LEVEL is one of TRACE, DEBUG, INFO, WARNING, ERROR, FATAL.
#define SHERPA_ONNX_LOG(x) \
::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, ::sherpa_onnx::x)
// ------------------------------------------------------------
// For debug check
// ------------------------------------------------------------
// If you define the macro "-D NDEBUG" while compiling sherpa-onnx,
// the following macros are in fact empty and do nothing.
// Debug-only counterparts of the check/log macros above. Each one
// short-circuits to `(void)0` when kDisableDebug is true and otherwise
// forwards to the corresponding non-debug macro.
#define SHERPA_ONNX_DCHECK(x) \
::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK(x)
#define SHERPA_ONNX_DCHECK_EQ(x, y) \
::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_EQ(x, y)
#define SHERPA_ONNX_DCHECK_NE(x, y) \
::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_NE(x, y)
#define SHERPA_ONNX_DCHECK_LT(x, y) \
::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_LT(x, y)
#define SHERPA_ONNX_DCHECK_LE(x, y) \
::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_LE(x, y)
#define SHERPA_ONNX_DCHECK_GT(x, y) \
::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_GT(x, y)
#define SHERPA_ONNX_DCHECK_GE(x, y) \
::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_GE(x, y)
// The Voidifier turns the Logger expression into `void` so that both
// branches of the conditional have the same type.
#define SHERPA_ONNX_DLOG(x) \
::sherpa_onnx::kDisableDebug \
? (void)0 \
: ::sherpa_onnx::Voidifier() & SHERPA_ONNX_LOG(x)
#else
// Checks disabled: every macro expands to a temporary Voidifier, which
// ignores anything streamed into it. The macro arguments do not appear in
// the expansion, so they are not evaluated.
#define SHERPA_ONNX_CHECK(x) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_LOG(x) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_CHECK_EQ(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_CHECK_NE(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_CHECK_LT(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_CHECK_LE(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_CHECK_GT(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_CHECK_GE(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK(x) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DLOG(x) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK_EQ(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK_NE(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK_LT(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK_LE(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK_GT(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK_GE(x, y) ::sherpa_onnx::Voidifier()
#endif // SHERPA_ONNX_ENABLE_CHECK
#endif // SHERPA_ONNX_CSRC_LOG_H_
... ...
... ... @@ -37,7 +37,9 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "OnlineRecognizerConfig(";
os << "feat_config=" << feat_config.ToString() << ", ";
os << "model_config=" << model_config.ToString() << ", ";
os << "tokens=\"" << tokens << "\")";
os << "tokens=\"" << tokens << "\", ";
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")";
return os.str();
}
... ... @@ -47,7 +49,8 @@ class OnlineRecognizer::Impl {
explicit Impl(const OnlineRecognizerConfig &config)
: config_(config),
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.tokens) {
sym_(config.tokens),
endpoint_(config_.endpoint_config) {
decoder_ =
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
}
... ... @@ -64,7 +67,7 @@ class OnlineRecognizer::Impl {
s->NumFramesReady();
}
void DecodeStreams(OnlineStream **ss, int32_t n) {
void DecodeStreams(OnlineStream **ss, int32_t n) const {
int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();
... ... @@ -111,18 +114,44 @@ class OnlineRecognizer::Impl {
}
}
OnlineRecognizerResult GetResult(OnlineStream *s) {
OnlineRecognizerResult GetResult(OnlineStream *s) const {
OnlineTransducerDecoderResult decoder_result = s->GetResult();
decoder_->StripLeadingBlanks(&decoder_result);
return Convert(decoder_result, sym_);
}
// Returns true if the endpoint rules fire for stream `s`. Always returns
// false when endpoint detection is disabled in the config.
//
// NOTE(review): the utterance length is derived from the stream's total
// number of processed frames, which Reset() does not appear to clear; if
// so, rules with a minimum utterance length measure from the start of the
// stream rather than of the current segment -- confirm.
bool IsEndpoint(OnlineStream *s) const {
  if (config_.enable_endpoint) {
    // frame shift is 10 milliseconds
    constexpr float kFrameShiftInSeconds = 0.01f;
    // subsampling factor is 4
    constexpr int32_t kSubsamplingFactor = 4;

    const int32_t num_frames = s->GetNumProcessedFrames();
    // Trailing blanks are counted on subsampled frames; scale them back to
    // input frames so both counts use the same unit.
    const int32_t num_silence_frames =
        s->GetResult().num_trailing_blanks * kSubsamplingFactor;

    return endpoint_.IsEndpoint(num_frames, num_silence_frames,
                                kFrameShiftInSeconds);
  }
  return false;
}
// Prepares stream `s` for the next segment by replacing its decoding result
// with an empty one and its model states with the encoder's initial states.
void Reset(OnlineStream *s) const {
// reset result and neural network model state,
// but keep the feature extractor state
// reset result
s->SetResult(decoder_->GetEmptyResult());
// reset neural network model state
s->SetStates(model_->GetEncoderInitStates());
}
private:
OnlineRecognizerConfig config_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<OnlineTransducerDecoder> decoder_;
SymbolTable sym_;
Endpoint endpoint_;
};
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
... ... @@ -137,12 +166,18 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) {
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) {
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) const {
return impl_->GetResult(s);
}
// Forwards to the implementation object: true if an endpoint is detected for
// stream `s`.
bool OnlineRecognizer::IsEndpoint(OnlineStream *s) const {
return impl_->IsEndpoint(s);
}
// Forwards to the implementation object: resets the decoding result and
// model state of stream `s`.
void OnlineRecognizer::Reset(OnlineStream *s) const { impl_->Reset(s); }
} // namespace sherpa_onnx
... ...
... ... @@ -8,6 +8,7 @@
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/endpoint.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
... ... @@ -22,13 +23,21 @@ struct OnlineRecognizerConfig {
FeatureExtractorConfig feat_config;
OnlineTransducerModelConfig model_config;
std::string tokens;
EndpointConfig endpoint_config;
// Whether endpoint detection is active. Defaults to false so that a
// default-constructed config holds a well-defined value instead of an
// indeterminate one.
bool enable_endpoint = false;
OnlineRecognizerConfig() = default;
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
const OnlineTransducerModelConfig &model_config,
const std::string &tokens)
: feat_config(feat_config), model_config(model_config), tokens(tokens) {}
const std::string &tokens,
const EndpointConfig &endpoint_config,
bool enable_endpoint)
: feat_config(feat_config),
model_config(model_config),
tokens(tokens),
endpoint_config(endpoint_config),
enable_endpoint(enable_endpoint) {}
std::string ToString() const;
};
... ... @@ -48,7 +57,7 @@ class OnlineRecognizer {
bool IsReady(OnlineStream *s) const;
/** Decode a single stream. */
void DecodeStream(OnlineStream *s) {
void DecodeStream(OnlineStream *s) const {
OnlineStream *ss[1] = {s};
DecodeStreams(ss, 1);
}
... ... @@ -58,9 +67,18 @@ class OnlineRecognizer {
* @param ss Pointer array containing streams to be decoded.
* @param n Number of streams in `ss`.
*/
void DecodeStreams(OnlineStream **ss, int32_t n);
void DecodeStreams(OnlineStream **ss, int32_t n) const;
OnlineRecognizerResult GetResult(OnlineStream *s);
OnlineRecognizerResult GetResult(OnlineStream *s) const;
// Return true if we detect an endpoint for this stream.
// Note: If this function returns true, you usually want to
// invoke Reset(s).
bool IsEndpoint(OnlineStream *s) const;
// Clear the state of this stream. If IsEndpoint(s) returns true,
// after calling this function, IsEndpoint(s) will return false
void Reset(OnlineStream *s) const;
private:
class Impl;
... ...
... ... @@ -55,7 +55,8 @@ class OnlineStream {
int32_t FeatureDim() const;
// Return a reference to the number of processed frames so far.
// Return a reference to the number of processed frames so far
// before subsampling.
// Initially, it is 0. It is always less than NumFramesReady().
//
// The returned reference is valid as long as this object is alive.
... ...
... ... @@ -14,6 +14,9 @@ namespace sherpa_onnx {
// Decoding state accumulated for one stream.
struct OnlineTransducerDecoderResult {
/// The decoded token IDs so far
std::vector<int64_t> tokens;
/// number of trailing blank frames decoded so far
/// (used by endpoint detection as the measure of trailing silence)
int32_t num_trailing_blanks = 0;
};
class OnlineTransducerDecoder {
... ...
... ... @@ -113,6 +113,9 @@ void OnlineTransducerGreedySearchDecoder::Decode(
if (y != 0) {
emitted = true;
(*result)[i].tokens.push_back(y);
(*result)[i].num_trailing_blanks = 0;
} else {
++(*result)[i].num_trailing_blanks;
}
}
if (emitted) {
... ...
// sherpa-onnx/csrc/parse-options.cc
/**
* Copyright 2009-2011 Karel Vesely; Microsoft Corporation;
* Saarland University (Author: Arnab Ghoshal);
* Copyright 2012-2013 Johns Hopkins University (Author: Daniel Povey);
* Frantisek Skala; Arnab Ghoshal
* Copyright 2013 Tanel Alumae
*/
// This file is copied and modified from kaldi/src/util/parse-options.cc
#include "sherpa-onnx/csrc/parse-options.h"
#include <ctype.h>
#include <algorithm>
#include <cctype>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <limits>
#include <type_traits>
#include <unordered_map>
#include "sherpa-onnx/csrc/log.h"
// SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) parses a base-10 integer from the
// C string `cur_cstr` and stores a pointer to the first unparsed character
// in `*end_cstr`, using the 64-bit string-to-integer routine of the
// platform.
//
// The expansion deliberately carries no trailing semicolon, so the macro
// can be used as an ordinary expression.
#ifdef _MSC_VER
#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \
  _strtoi64((cur_cstr), (end_cstr), 10)
#else
#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \
  strtoll((cur_cstr), (end_cstr), 10)
#endif
namespace sherpa_onnx {
/// Converts a string into an integer via strtoll and returns false if there was
/// any kind of problem (i.e. the string was not an integer or contained extra
/// non-whitespace junk, or the integer was too large to fit into the type it is
/// being converted into). Only sets *out if everything was OK and it returns
/// true.
// Parses `str` as a base-10 integer of type Int.
//
// Leading whitespace (as accepted by strtoll) and trailing whitespace are
// allowed; any other extra characters make the conversion fail.
//
// @param str  Text to parse.
// @param out  Receives the parsed value. Written only on success.
// @return true on success; false if `str` is not an integer, contains
//         trailing non-whitespace characters, or holds a value that does
//         not fit into Int.
template <class Int>
bool ConvertStringToInteger(const std::string &str, Int *out) {
  // copied from kaldi/src/util/text-util.h
  static_assert(std::is_integral<Int>::value, "");
  const char *this_str = str.c_str();
  char *end = nullptr;
  errno = 0;
  int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end);
  if (end != this_str) {
    // isspace() requires a value representable as unsigned char; passing a
    // negative plain char (e.g. a byte of a UTF-8 sequence) is undefined
    // behaviour.
    while (isspace(static_cast<unsigned char>(*end))) ++end;
  }
  if (end == this_str || *end != '\0' || errno != 0) return false;

  // Reject values that do not survive the round trip through Int, and
  // negative values for unsigned targets.
  Int iInt = static_cast<Int>(i);
  if (static_cast<int64_t>(iInt) != i ||
      (i < 0 && !std::numeric_limits<Int>::is_signed)) {
    return false;
  }
  *out = iInt;
  return true;
}
// copied from kaldi/src/util/text-util.cc
// Wrapper around std::istream for reading one floating-point number.
//
// In addition to what the underlying stream extraction accepts, it
// recognises (case-insensitively) the spellings of infinity and NaN listed
// in ParseOnFail(). Input with anything but spaces after the number is
// rejected by setting failbit on the wrapped stream.
//
// copied from kaldi/src/util/text-util.cc
template <class T>
class NumberIstream {
 public:
  explicit NumberIstream(std::istream &i) : in_(i) {}

  // Reads one number into `x`. Does nothing if the stream is not good().
  // On failure, failbit is set on the wrapped stream.
  NumberIstream &operator>>(T &x) {
    if (!in_.good()) return *this;
    in_ >> x;
    if (!in_.fail() && RemainderIsOnlySpaces()) return *this;
    return ParseOnFail(&x);
  }

 private:
  std::istream &in_;

  // Returns true if the next token of the stream, if any, consists only of
  // spaces. On success the stream state is cleared.
  bool RemainderIsOnlySpaces() {
    if (in_.tellg() != std::istream::pos_type(-1)) {
      std::string rem;
      in_ >> rem;

      if (rem.find_first_not_of(' ') != std::string::npos) {
        // there is not only spaces
        return false;
      }
    }

    in_.clear();
    return true;
  }

  // Fallback used when the plain extraction failed: re-reads the input from
  // the beginning as a single token and accepts it only if it is one of the
  // known inf/nan spellings.
  NumberIstream &ParseOnFail(T *x) {
    std::string str;
    in_.clear();
    in_.seekg(0);
    // If the stream is broken even before trying
    // to read from it or if there are many tokens,
    // it's pointless to try.
    if (!(in_ >> str) || !RemainderIsOnlySpaces()) {
      in_.setstate(std::ios_base::failbit);
      return *this;
    }

    const T inf = std::numeric_limits<T>::infinity();
    const T nan = std::numeric_limits<T>::quiet_NaN();
    // Only upper-case spellings are stored; the input is upper-cased before
    // the lookup.
    const std::unordered_map<std::string, T> inf_nan_map = {
        {"INF", inf},
        {"+INF", inf},
        {"-INF", -inf},
        {"INFINITY", inf},
        {"+INFINITY", inf},
        {"-INFINITY", -inf},
        {"NAN", nan},
        {"+NAN", nan},
        {"-NAN", -nan},
        // MSVC
        {"1.#INF", inf},
        {"-1.#INF", -inf},
        {"1.#QNAN", nan},
        {"-1.#QNAN", -nan},
    };

    // toupper() requires a value representable as unsigned char; going
    // through unsigned char avoids undefined behaviour for negative chars.
    std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) {
      return static_cast<char>(std::toupper(c));
    });

    const auto it = inf_nan_map.find(str);
    if (it != inf_nan_map.end()) {
      *x = it->second;
    } else {
      in_.setstate(std::ios_base::failbit);
    }

    return *this;
  }
};
/// ConvertStringToReal converts a string into either float or double
/// and returns false if there was any kind of problem (i.e. the string
/// was not a floating point number or contained extra non-whitespace junk).
/// Be careful- this function will successfully read inf's or nan's.
template <typename T>
bool ConvertStringToReal(const std::string &str, T *out) {
std::istringstream iss(str);
NumberIstream<T> i(iss);
i >> *out;
if (iss.fail()) {
// Number conversion failed.
return false;
}
return true;
}
// Builds a temporary "prefixing" parser: options registered through it are
// forwarded to another parser under the name "<prefix>.<name>".
ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po)
    : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) {
  // When `po` is itself a prefixing parser (this constructor used twice,
  // recursively), register straight with the parser it forwards to.
  const bool parent_forwards =
      (po != nullptr) && (po->other_parser_ != nullptr);
  other_parser_ = parent_forwards ? po->other_parser_ : po;

  // Chain the prefixes so that nested parsers yield "outer.inner".
  const bool parent_has_prefix = (po != nullptr) && !po->prefix_.empty();
  prefix_ =
      parent_has_prefix ? po->prefix_ + std::string(".") + prefix : prefix;
}
void ParseOptions::Register(const std::string &name, bool *ptr,
const std::string &doc) {
RegisterTmpl(name, ptr, doc);
}
void ParseOptions::Register(const std::string &name, int32_t *ptr,
const std::string &doc) {
RegisterTmpl(name, ptr, doc);
}
void ParseOptions::Register(const std::string &name, uint32_t *ptr,
const std::string &doc) {
RegisterTmpl(name, ptr, doc);
}
void ParseOptions::Register(const std::string &name, float *ptr,
const std::string &doc) {
RegisterTmpl(name, ptr, doc);
}
void ParseOptions::Register(const std::string &name, double *ptr,
const std::string &doc) {
RegisterTmpl(name, ptr, doc);
}
void ParseOptions::Register(const std::string &name, std::string *ptr,
const std::string &doc) {
RegisterTmpl(name, ptr, doc);
}
// Registers an application-specific parameter. A parser constructed with a
// prefix does not store anything itself: it forwards the registration to
// the parser it wraps, under the name "<prefix>.<name>".
template <typename T>
void ParseOptions::RegisterTmpl(const std::string &name, T *ptr,
                                const std::string &doc) {
  if (other_parser_ != nullptr) {
    SHERPA_ONNX_CHECK(prefix_ != "")
        << "prefix: " << prefix_ << "\n"
        << "Cannot use empty prefix when registering with prefix.";
    const std::string prefixed_name = prefix_ + '.' + name;
    other_parser_->Register(prefixed_name, ptr, doc);
    return;
  }

  // Plain parser: store the option here as a non-standard one.
  this->RegisterCommon(name, ptr, doc, false);
}
// Type-independent part of registering a parameter: validates the pointer,
// derives the normalized lookup key and refuses duplicate registrations
// (the first registration wins; later ones only produce a warning).
template <typename T>
void ParseOptions::RegisterCommon(const std::string &name, T *ptr,
                                  const std::string &doc, bool is_standard) {
  SHERPA_ONNX_CHECK(ptr != nullptr);

  std::string key = name;
  NormalizeArgName(&key);

  const bool already_registered = doc_map_.count(key) != 0;
  if (already_registered) {
    SHERPA_ONNX_LOG(WARNING)
        << "Registering option twice, ignoring second time: " << name;
    return;
  }

  this->RegisterSpecific(name, key, ptr, doc, is_standard);
}
// Registers a standard parameter, i.e. one that every application has.
// Identical to an ordinary registration except that the option is flagged
// as standard.
template <typename T>
void ParseOptions::RegisterStandard(const std::string &name, T *ptr,
                                    const std::string &doc) {
  const bool is_standard = true;
  this->RegisterCommon(name, ptr, doc, is_standard);
}
// Type-specific part of registration. Each overload stores the pointer in
// the map for its type under the normalized key `idx`, and records a help
// entry whose text is `doc` followed by the type and the current (default)
// value of the variable.

// bool overload.
void ParseOptions::RegisterSpecific(const std::string &name,
                                    const std::string &idx, bool *b,
                                    const std::string &doc, bool is_standard) {
  bool_map_[idx] = b;
  const char *default_text = (*b) ? "true)" : "false)";
  const std::string help = doc + " (bool, default = " + default_text;
  doc_map_[idx] = DocInfo(name, help, is_standard);
}

// int32_t overload.
void ParseOptions::RegisterSpecific(const std::string &name,
                                    const std::string &idx, int32_t *i,
                                    const std::string &doc, bool is_standard) {
  int_map_[idx] = i;
  std::ostringstream help;
  help << doc << " (int, default = " << *i << ")";
  doc_map_[idx] = DocInfo(name, help.str(), is_standard);
}

// uint32_t overload.
void ParseOptions::RegisterSpecific(const std::string &name,
                                    const std::string &idx, uint32_t *u,
                                    const std::string &doc, bool is_standard) {
  uint_map_[idx] = u;
  std::ostringstream help;
  help << doc << " (uint, default = " << *u << ")";
  doc_map_[idx] = DocInfo(name, help.str(), is_standard);
}

// float overload.
void ParseOptions::RegisterSpecific(const std::string &name,
                                    const std::string &idx, float *f,
                                    const std::string &doc, bool is_standard) {
  float_map_[idx] = f;
  std::ostringstream help;
  help << doc << " (float, default = " << *f << ")";
  doc_map_[idx] = DocInfo(name, help.str(), is_standard);
}

// double overload.
void ParseOptions::RegisterSpecific(const std::string &name,
                                    const std::string &idx, double *f,
                                    const std::string &doc, bool is_standard) {
  double_map_[idx] = f;
  std::ostringstream help;
  help << doc << " (double, default = " << *f << ")";
  doc_map_[idx] = DocInfo(name, help.str(), is_standard);
}

// std::string overload.
void ParseOptions::RegisterSpecific(const std::string &name,
                                    const std::string &idx, std::string *s,
                                    const std::string &doc, bool is_standard) {
  string_map_[idx] = s;
  const std::string help = doc + " (string, default = \"" + *s + "\")";
  doc_map_[idx] = DocInfo(name, help, is_standard);
}
void ParseOptions::DisableOption(const std::string &name) {
if (argv_ != nullptr) {
SHERPA_ONNX_LOG(FATAL)
<< "DisableOption must not be called after calling Read().";
}
if (doc_map_.erase(name) == 0) {
SHERPA_ONNX_LOG(FATAL) << "Option " << name
<< " was not registered so cannot be disabled: ";
}
bool_map_.erase(name);
int_map_.erase(name);
uint_map_.erase(name);
float_map_.erase(name);
double_map_.erase(name);
string_map_.erase(name);
}
// Number of positional arguments collected by Read().
int ParseOptions::NumArgs() const {
  return static_cast<int>(positional_args_.size());
}
// Returns the i-th positional argument. The index is 1-based; an index
// outside [1, number of positional arguments] is reported as a fatal error.
std::string ParseOptions::GetArg(int i) const {
  const int num_args = static_cast<int>(positional_args_.size());
  const bool in_range = (i >= 1) && (i <= num_args);
  if (!in_range) {
    SHERPA_ONNX_LOG(FATAL) << "ParseOptions::GetArg, invalid index " << i;
  }
  return positional_args_[i - 1];
}
// Shells whose quoting rules we know about. We currently do not support any
// other options. kBash is pinned to 0 because it is used as an array index.
enum ShellType { kBash = 0 };

// The shell used when echoing command lines. This can be changed in the code
// if it ever does need to be changed (as it's unlikely that one compilation
// of this tool-set would use both shells). It is a compile-time choice, so
// it is declared const to rule out accidental modification at runtime.
static const ShellType kShellType = kBash;
// Returns true if we need to escape a string before putting it into
// a shell (mainly thinking of bash shell, but should work for others)
// This is for the convenience of the user so command-lines that are
// printed out by ParseOptions::Read (with --print-args=true) are
// paste-able into the shell and will run. If you use a different type of
// shell, it might be necessary to change this function.
// But it's mostly a cosmetic issue as it basically affects how
// the program echoes its command-line arguments to the screen.
//
// A string must be quoted when it is empty or contains any character that
// is neither alphanumeric nor in the shell's list of harmless characters.
static bool MustBeQuoted(const std::string &str, ShellType st) {
  // Only Bash is supported (for the moment).
  SHERPA_ONNX_CHECK_EQ(st, kBash) << "Invalid shell type.";
  const char *c = str.c_str();
  if (*c == '\0') {
    return true;  // Must quote empty string
  } else {
    const char *ok_chars[2];
    // These seem not to be interpreted as long as there are no other "bad"
    // characters involved (e.g. "," would be interpreted as part of something
    // like a{b,c}, but not on its own.
    ok_chars[kBash] = "[]~#^_-+=:.,/";
    // Just want to make sure that a space character doesn't get automatically
    // inserted here via an automated style-checking script, like it did before.
    SHERPA_ONNX_CHECK(!strchr(ok_chars[kBash], ' '));
    for (; *c != '\0'; ++c) {
      // For non-alphanumeric characters we have a list of characters which
      // are OK. All others are forbidden (this is easier since the shell
      // interprets most non-alphanumeric characters).
      //
      // The cast is required: calling isalnum() with a negative plain-char
      // value (any non-ASCII byte) is undefined behavior.
      if (!isalnum(static_cast<unsigned char>(*c))) {
        // *c is never '\0' here, so strchr cannot match the terminator.
        // If not alphanumeric or one of the "ok_chars", it must be escaped.
        if (strchr(ok_chars[st], *c) == nullptr) return true;
      }
    }
    return false;  // The string was OK. No quoting or escaping.
  }
}
// Produces a quoted and escaped copy of "str", which the caller has already
// determined to need quoting. The goal is that pasting the result into a
// shell of type "st" (only bash for now) passes the original text to the
// program unchanged.
//
// Quoting rules:
//  - Normally the string is wrapped in single quotes, and every embedded
//    single quote is written as '\'' (close the quote, emit an escaped
//    quote, reopen the quote); e.g. echo 'a'\''b' prints a'b.
//  - If the string contains a single quote but none of the characters
//    "`$\ then it can be double-quoted without any escaping, so double
//    quotes are used instead.
//    e.g. see http://www.redhat.com/mirrors/LDP/LDP/abs/html/quotingvar.html
static std::string QuoteAndEscape(const std::string &str, ShellType st) {
  // Only Bash is supported (for the moment).
  SHERPA_ONNX_CHECK_EQ(st, kBash) << "Invalid shell type.";

  const char *raw = str.c_str();
  const bool has_single_quote = strchr(raw, '\'') != nullptr;
  const bool double_quote_safe = strpbrk(raw, "\"`$\\") == nullptr;

  char quote_char = '\'';
  const char *escape_str = "'\\''";
  if (has_single_quote && double_quote_safe) {
    quote_char = '"';
    escape_str = "\\\"";  // should never be accessed.
  }

  std::string ans(1, quote_char);  // opening quote
  for (const char *p = raw; *p != '\0'; ++p) {
    if (*p == quote_char) {
      ans += escape_str;
    } else {
      ans += *p;
    }
  }
  ans += quote_char;  // closing quote
  return ans;
}
// static function
// Returns "str" unchanged when it is safe to paste into the shell as is,
// and a quoted/escaped version of it otherwise.
std::string ParseOptions::Escape(const std::string &str) {
  if (MustBeQuoted(str, kShellType)) {
    return QuoteAndEscape(str, kShellType);
  }
  return str;
}
// Parses the command line in three stages:
//   1. a pre-scan that reacts only to --config (reads that config file
//      immediately) and --help (prints the usage and exits);
//   2. the leading run of "--name[=value]" arguments is applied through
//      SetOption(); an unknown option is a fatal error;
//   3. everything that remains is stored as positional arguments.
// A lone "--" ends the named options in both scanning passes.
//
// argc/argv are remembered in argc_/argv_ (used later when echoing the
// command line). Returns the value of the loop index after stage 3.
// NOTE(review): the stage-3 loop has no early exit, so the returned index
// looks like it is never smaller than argc - confirm this is intended.
int ParseOptions::Read(int argc, const char *const argv[]) {
  argc_ = argc;
  argv_ = argv;
  std::string key, value;
  int i;
  // first pass: look for config parameter, look for priority
  // (config files are read here, before any command-line value is applied
  // in the second pass)
  for (i = 1; i < argc; ++i) {
    if (std::strncmp(argv[i], "--", 2) == 0) {
      if (std::strcmp(argv[i], "--") == 0) {
        // a lone "--" marks the end of named options
        break;
      }
      bool has_equal_sign;
      SplitLongArg(argv[i], &key, &value, &has_equal_sign);
      NormalizeArgName(&key);
      Trim(&value);
      if (key.compare("config") == 0) {
        ReadConfigFile(value);
      } else if (key.compare("help") == 0) {
        PrintUsage();
        exit(0);
      }
    }
  }
  bool double_dash_seen = false;
  // second pass: add the command line options
  // (unlike the first pass, this one stops at the first argument that does
  // not start with "--")
  for (i = 1; i < argc; ++i) {
    if (std::strncmp(argv[i], "--", 2) == 0) {
      if (std::strcmp(argv[i], "--") == 0) {
        // A lone "--" marks the end of named options.
        // Skip that option and break the processing of named options
        i += 1;
        double_dash_seen = true;
        break;
      }
      bool has_equal_sign;
      SplitLongArg(argv[i], &key, &value, &has_equal_sign);
      NormalizeArgName(&key);
      Trim(&value);
      if (!SetOption(key, value, has_equal_sign)) {
        PrintUsage(true);
        SHERPA_ONNX_LOG(FATAL) << "Invalid option " << argv[i];
      }
    } else {
      break;
    }
  }
  // process remaining arguments as positional
  // (the first lone "--" found here is dropped; any later one is kept as a
  // positional argument)
  for (; i < argc; ++i) {
    if ((std::strcmp(argv[i], "--") == 0) && !double_dash_seen) {
      double_dash_seen = true;
    } else {
      positional_args_.push_back(std::string(argv[i]));
    }
  }
  // if the user did not suppress this with --print-args = false....
  // echo the full command line, shell-escaped, to the log
  if (print_args_) {
    std::ostringstream strm;
    for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " ";
    strm << '\n';
    SHERPA_ONNX_LOG(INFO) << strm.str();
  }
  return i;
}
void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const {
std::ostringstream os;
os << '\n' << usage_ << '\n';
// first we print application-specific options
bool app_specific_header_printed = false;
for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) {
if (it->second.is_standard_ == false) { // application-specific option
if (app_specific_header_printed == false) { // header was not yet printed
os << "Options:" << '\n';
app_specific_header_printed = true;
}
os << " --" << std::setw(25) << std::left << it->second.name_ << " : "
<< it->second.use_msg_ << '\n';
}
}
if (app_specific_header_printed == true) {
os << '\n';
}
// then the standard options
os << "Standard options:" << '\n';
for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) {
if (it->second.is_standard_ == true) { // we have standard option
os << " --" << std::setw(25) << std::left << it->second.name_ << " : "
<< it->second.use_msg_ << '\n';
}
}
os << '\n';
if (print_command_line) {
std::ostringstream strm;
strm << "Command line was: ";
for (int j = 0; j < argc_; ++j) strm << Escape(argv_[j]) << " ";
strm << '\n';
os << strm.str();
}
SHERPA_ONNX_LOG(INFO) << os.str();
}
void ParseOptions::PrintConfig(std::ostream &os) const {
os << '\n' << "[[ Configuration of UI-Registered options ]]" << '\n';
std::string key;
for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) {
key = it->first;
os << it->second.name_ << " = ";
if (bool_map_.end() != bool_map_.find(key)) {
os << (*bool_map_.at(key) ? "true" : "false");
} else if (int_map_.end() != int_map_.find(key)) {
os << (*int_map_.at(key));
} else if (uint_map_.end() != uint_map_.find(key)) {
os << (*uint_map_.at(key));
} else if (float_map_.end() != float_map_.find(key)) {
os << (*float_map_.at(key));
} else if (double_map_.end() != double_map_.find(key)) {
os << (*double_map_.at(key));
} else if (string_map_.end() != string_map_.find(key)) {
os << "'" << *string_map_.at(key) << "'";
} else {
SHERPA_ONNX_LOG(FATAL)
<< "PrintConfig: unrecognized option " << key << "[code error]";
}
os << '\n';
}
os << '\n';
}
// Reads options from a config file, one "--name=value" per line.
// Everything from a '#' to the end of the line is a comment, and blank
// lines are skipped. A file that cannot be opened, a line that does not
// start with "--", and an unknown option are all fatal errors.
void ParseOptions::ReadConfigFile(const std::string &filename) {
  std::ifstream is(filename.c_str(), std::ifstream::in);
  if (!is.good()) {
    SHERPA_ONNX_LOG(FATAL) << "Cannot open config file: " << filename;
  }

  std::string line, key, value;
  // line_number is 1-based and only used in diagnostics.
  for (int32_t line_number = 1; std::getline(is, line); ++line_number) {
    // trim out the comments
    const size_t comment_start = line.find('#');
    if (comment_start != std::string::npos) {
      line.erase(comment_start);
    }

    // skip empty lines
    Trim(&line);
    if (line.empty()) continue;

    if (line.compare(0, 2, "--") != 0) {
      SHERPA_ONNX_LOG(FATAL)
          << "Reading config file " << filename << ": line " << line_number
          << " does not look like a line "
          << "from a Kaldi command-line program's config file: should "
          << "be of the form --x=y. Note: config files intended to "
          << "be sourced by shell scripts lack the '--'.";
    }

    // parse option
    bool has_equal_sign;
    SplitLongArg(line, &key, &value, &has_equal_sign);
    NormalizeArgName(&key);
    Trim(&value);
    const bool accepted = SetOption(key, value, has_equal_sign);
    if (!accepted) {
      PrintUsage(true);
      SHERPA_ONNX_LOG(FATAL) << "Invalid option " << line << " in config file "
                             << filename << ": line " << line_number;
    }
  }
}
// Splits an argument of the form "--key=value" or "--key" into its parts.
// `in` must start with "--". *has_equal_sign tells the caller whether an
// '=' was present, which lets it distinguish "--x" from "--x=".
// An empty key ("--=value") is a fatal error.
void ParseOptions::SplitLongArg(const std::string &in, std::string *key,
                                std::string *value,
                                bool *has_equal_sign) const {
  SHERPA_ONNX_CHECK(in.substr(0, 2) == "--") << in;  // precondition.

  const size_t equal_pos = in.find('=');

  if (equal_pos == std::string::npos) {
    // No '=' at all: we allow --option for bools. The value defaults to
    // empty; callers handle this differently in different cases.
    *key = in.substr(2);  // 2 because starts with --.
    value->clear();
    *has_equal_sign = false;
    return;
  }

  if (equal_pos == 2) {
    // we also don't allow empty keys: --=value
    PrintUsage(true);
    SHERPA_ONNX_LOG(FATAL) << "Invalid option (no key): " << in;
    return;
  }

  // normal case: --option=value
  *key = in.substr(2, equal_pos - 2);  // 2 because starts with --.
  *value = in.substr(equal_pos + 1);
  *has_equal_sign = true;
}
// Rewrites an option name into its canonical lookup form, in place:
// every '_' becomes '-' and every other character is lower-cased.
// The resulting name must not be empty (checked).
void ParseOptions::NormalizeArgName(std::string *str) const {
  std::string out;
  out.reserve(str->size());
  for (const char ch : *str) {
    if (ch == '_') {
      out += '-';  // convert _ to -
    } else {
      // Go through unsigned char: calling std::tolower() with a negative
      // plain-char value (any non-ASCII byte) is undefined behavior.
      out += static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
    }
  }
  *str = out;
  SHERPA_ONNX_CHECK_GT(str->length(), 0);
}
// Strips leading and trailing whitespace (space, tab, newline, carriage
// return, form feed, vertical tab) from *str in place. A string made up
// only of whitespace becomes empty.
void ParseOptions::Trim(std::string *str) const {
  const char *white_chars = " \t\n\r\f\v";

  const std::string::size_type last = str->find_last_not_of(white_chars);
  if (last == std::string::npos) {
    // Nothing but whitespace (or already empty).
    str->clear();
    return;
  }
  str->erase(last + 1);

  const std::string::size_type first = str->find_first_not_of(white_chars);
  if (first != std::string::npos) {
    str->erase(0, first);
  }
}
bool ParseOptions::SetOption(const std::string &key, const std::string &value,
bool has_equal_sign) {
if (bool_map_.end() != bool_map_.find(key)) {
if (has_equal_sign && value == "") {
SHERPA_ONNX_LOG(FATAL) << "Invalid option --" << key << "=";
}
*(bool_map_[key]) = ToBool(value);
} else if (int_map_.end() != int_map_.find(key)) {
*(int_map_[key]) = ToInt(value);
} else if (uint_map_.end() != uint_map_.find(key)) {
*(uint_map_[key]) = ToUint(value);
} else if (float_map_.end() != float_map_.find(key)) {
*(float_map_[key]) = ToFloat(value);
} else if (double_map_.end() != double_map_.find(key)) {
*(double_map_[key]) = ToDouble(value);
} else if (string_map_.end() != string_map_.find(key)) {
if (!has_equal_sign) {
SHERPA_ONNX_LOG(FATAL)
<< "Invalid option --" << key << " (option format is --x=y).";
}
*(string_map_[key]) = value;
} else {
return false;
}
return true;
}
// Converts the textual value of a boolean option, case-insensitively.
// "true", "t", "1" and the empty string mean true (so that --x is the same
// as --x=true); "false", "f" and "0" mean false. Anything else prints the
// usage and is a fatal error.
bool ParseOptions::ToBool(std::string str) const {
  // The lambda takes unsigned char because calling tolower() with a
  // negative plain-char value (any non-ASCII byte) is undefined behavior.
  std::transform(
      str.begin(), str.end(), str.begin(),
      [](unsigned char c) { return static_cast<char>(::tolower(c)); });

  // allow "" as a valid option for "true", so that --x is the same as --x=true
  if ((str.compare("true") == 0) || (str.compare("t") == 0) ||
      (str.compare("1") == 0) || (str.compare("") == 0)) {
    return true;
  }
  if ((str.compare("false") == 0) || (str.compare("f") == 0) ||
      (str.compare("0") == 0)) {
    return false;
  }
  // if it is neither true nor false:
  PrintUsage(true);
  SHERPA_ONNX_LOG(FATAL)
      << "Invalid format for boolean argument [expected true or false]: "
      << str;
  return false;  // never reached
}
// Numeric converters used by SetOption(). Each one parses the whole string
// with the matching ConvertStringTo*() helper and logs a fatal error when
// the text is not a valid number of the requested type.

// Parses a signed 32-bit integer option value.
int32_t ParseOptions::ToInt(const std::string &str) const {
  int32_t ret = 0;
  const bool ok = ConvertStringToInteger(str, &ret);
  if (!ok) {
    SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\"";
  }
  return ret;
}

// Parses an unsigned 32-bit integer option value.
uint32_t ParseOptions::ToUint(const std::string &str) const {
  uint32_t ret = 0;
  const bool ok = ConvertStringToInteger(str, &ret);
  if (!ok) {
    SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\"";
  }
  return ret;
}

// Parses a single-precision floating-point option value.
float ParseOptions::ToFloat(const std::string &str) const {
  float ret;
  const bool ok = ConvertStringToReal(str, &ret);
  if (!ok) {
    SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\"";
  }
  return ret;
}

// Parses a double-precision floating-point option value.
double ParseOptions::ToDouble(const std::string &str) const {
  double ret;
  const bool ok = ConvertStringToReal(str, &ret);
  if (!ok) {
    SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\"";
  }
  return ret;
}
// instantiate templates
//
// RegisterTmpl, RegisterStandard and RegisterCommon are member templates
// whose definitions live in this source file, so they are explicitly
// instantiated here for every supported option type (bool, int32_t,
// uint32_t, float, double, std::string).

// RegisterTmpl: application-specific options.
template void ParseOptions::RegisterTmpl(const std::string &name, bool *ptr,
                                         const std::string &doc);
template void ParseOptions::RegisterTmpl(const std::string &name, int32_t *ptr,
                                         const std::string &doc);
template void ParseOptions::RegisterTmpl(const std::string &name, uint32_t *ptr,
                                         const std::string &doc);
template void ParseOptions::RegisterTmpl(const std::string &name, float *ptr,
                                         const std::string &doc);
template void ParseOptions::RegisterTmpl(const std::string &name, double *ptr,
                                         const std::string &doc);
template void ParseOptions::RegisterTmpl(const std::string &name,
                                         std::string *ptr,
                                         const std::string &doc);
// RegisterStandard: options flagged as standard.
template void ParseOptions::RegisterStandard(const std::string &name, bool *ptr,
                                             const std::string &doc);
template void ParseOptions::RegisterStandard(const std::string &name,
                                             int32_t *ptr,
                                             const std::string &doc);
template void ParseOptions::RegisterStandard(const std::string &name,
                                             uint32_t *ptr,
                                             const std::string &doc);
template void ParseOptions::RegisterStandard(const std::string &name,
                                             float *ptr,
                                             const std::string &doc);
template void ParseOptions::RegisterStandard(const std::string &name,
                                             double *ptr,
                                             const std::string &doc);
template void ParseOptions::RegisterStandard(const std::string &name,
                                             std::string *ptr,
                                             const std::string &doc);
// RegisterCommon: shared implementation behind both of the above.
template void ParseOptions::RegisterCommon(const std::string &name, bool *ptr,
                                           const std::string &doc,
                                           bool is_standard);
template void ParseOptions::RegisterCommon(const std::string &name,
                                           int32_t *ptr, const std::string &doc,
                                           bool is_standard);
template void ParseOptions::RegisterCommon(const std::string &name,
                                           uint32_t *ptr,
                                           const std::string &doc,
                                           bool is_standard);
template void ParseOptions::RegisterCommon(const std::string &name, float *ptr,
                                           const std::string &doc,
                                           bool is_standard);
template void ParseOptions::RegisterCommon(const std::string &name, double *ptr,
                                           const std::string &doc,
                                           bool is_standard);
template void ParseOptions::RegisterCommon(const std::string &name,
                                           std::string *ptr,
                                           const std::string &doc,
                                           bool is_standard);
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/parse-options.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
//
// This file is copied and modified from kaldi/src/util/parse-options.h
#ifndef SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_
#define SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
namespace sherpa_onnx {
/// Command-line and config-file option parser.
///
/// Typical use: construct it with a usage string, Register() the variables
/// that options should write to, call Read(argc, argv), then fetch the
/// positional arguments with NumArgs() / GetArg().
class ParseOptions {
 public:
  /// Creates a parser and registers the standard options --config,
  /// --print-args and --help.
  ///
  /// @param usage  Usage text shown by PrintUsage(). Only the pointer is
  ///               stored (it is not copied), so the string must outlive
  ///               this object.
  explicit ParseOptions(const char *usage)
      : print_args_(true),
        help_(false),
        usage_(usage),
        argc_(0),
        argv_(nullptr),
        prefix_(""),
        other_parser_(nullptr) {
#if !defined(_MSC_VER) && !defined(__CYGWIN__)
    // This is just a convenient place to set the stderr to line
    // buffering mode, since it's called at program start.
    // This helps ensure different programs' output is not mixed up.
    setlinebuf(stderr);
#endif
    RegisterStandard("config", &config_,
                     "Configuration file to read (this "
                     "option may be repeated)");
    RegisterStandard("print-args", &print_args_,
                     "Print the command line arguments (to stderr)");
    RegisterStandard("help", &help_, "Print out usage message");
  }
  /**
    This is a constructor for the special case where some options are
    registered with a prefix to avoid conflicts. The object thus created will
    only be used temporarily to register an options class with the original
    options parser (which is passed as the `other` pointer) using the given
    prefix. It should not be used for any other purpose, and the prefix must
    not be the empty string. It seems to be the least bad way of implementing
    options with prefixes at this point.
    Example of usage is:
     ParseOptions po(usage);  // original ParseOptions object
     ParseOptions po_mfcc("mfcc", &po); // object with prefix.
     MfccOptions mfcc_opts;
     mfcc_opts.Register(&po_mfcc);
    The options will now get registered as, e.g., --mfcc.frame-shift=10.0
    instead of just --frame-shift=10.0
   */
  ParseOptions(const std::string &prefix, ParseOptions *other);
  // Not copyable: the maps hold raw pointers to caller-owned variables.
  ParseOptions(const ParseOptions &) = delete;
  ParseOptions &operator=(const ParseOptions &) = delete;
  ~ParseOptions() = default;
  /// Registers an option: `name` is the option name, `ptr` the variable that
  /// receives its value, `doc` the help text. One overload per value type.
  void Register(const std::string &name, bool *ptr, const std::string &doc);
  void Register(const std::string &name, int32_t *ptr, const std::string &doc);
  void Register(const std::string &name, uint32_t *ptr, const std::string &doc);
  void Register(const std::string &name, float *ptr, const std::string &doc);
  void Register(const std::string &name, double *ptr, const std::string &doc);
  void Register(const std::string &name, std::string *ptr,
                const std::string &doc);
  /// If called after registering an option and before calling
  /// Read(), disables that option from being used.  Will crash
  /// at runtime if that option had not been registered.
  void DisableOption(const std::string &name);
  /// This one is used for registering standard parameters of all the programs
  template <typename T>
  void RegisterStandard(const std::string &name, T *ptr,
                        const std::string &doc);
  /**
    Parses the command line options and fills the ParseOptions-registered
    variables. This must be called after all the variables were registered!!!
    Initially the variables have implicit values,
    then the config file values are set-up,
    finally the command line values given.
    Returns the first position in argv that was not used.
    [typically not useful: use NumArgs() and GetArg(). ]
   */
  int Read(int argc, const char *const *argv);
  /// Prints the usage documentation [provided in the constructor].
  /// If print_command_line is true, the command line is printed as well.
  void PrintUsage(bool print_command_line = false) const;
  /// Prints the actual configuration of all the registered variables
  void PrintConfig(std::ostream &os) const;
  /// Reads the options values from a config file.  Must be called after
  /// registering all options.  This is usually used internally after the
  /// standard --config option is used, but it may also be called from a
  /// program.
  void ReadConfigFile(const std::string &filename);
  /// Number of positional parameters (c.f. argc-1).
  int NumArgs() const;
  /// Returns one of the positional parameters; 1-based indexing for argc/argv
  /// compatibility. Will crash if param is not >=1 and <=NumArgs().
  ///
  /// Note: Index is 1 based.
  std::string GetArg(int param) const;
  /// Like GetArg(), but returns an empty string instead of crashing when
  /// param is greater than NumArgs().
  std::string GetOptArg(int param) const {
    return (param <= NumArgs() ? GetArg(param) : "");
  }
  /// The following function will return a possibly quoted and escaped
  /// version of "str", according to the current shell.  Currently
  /// this is just hardwired to bash.  It's useful for debug output.
  static std::string Escape(const std::string &str);

 private:
  /// Template to register various variable types,
  /// used for program-specific parameters
  template <typename T>
  void RegisterTmpl(const std::string &name, T *ptr, const std::string &doc);
  // Following functions do just the datatype-specific part of the job
  /// Register boolean variable
  void RegisterSpecific(const std::string &name, const std::string &idx,
                        bool *b, const std::string &doc, bool is_standard);
  /// Register int32_t variable
  void RegisterSpecific(const std::string &name, const std::string &idx,
                        int32_t *i, const std::string &doc, bool is_standard);
  /// Register uint32_t variable
  void RegisterSpecific(const std::string &name, const std::string &idx,
                        uint32_t *u, const std::string &doc, bool is_standard);
  /// Register float variable
  void RegisterSpecific(const std::string &name, const std::string &idx,
                        float *f, const std::string &doc, bool is_standard);
  /// Register double variable [useful as we change BaseFloat type].
  void RegisterSpecific(const std::string &name, const std::string &idx,
                        double *f, const std::string &doc, bool is_standard);
  /// Register string variable
  void RegisterSpecific(const std::string &name, const std::string &idx,
                        std::string *s, const std::string &doc,
                        bool is_standard);
  /// Does the actual job for both kinds of parameters
  /// Does the common part of the job for all datatypes,
  /// then calls RegisterSpecific
  template <typename T>
  void RegisterCommon(const std::string &name, T *ptr, const std::string &doc,
                      bool is_standard);
  /// Set option with name "key" to "value". Returns false if no option named
  /// "key" is registered; a value that cannot be converted is a fatal error.
  /// "has_equal_sign" is used to allow --x for a boolean option x,
  /// and --y=, for a string option y.
  bool SetOption(const std::string &key, const std::string &value,
                 bool has_equal_sign);
  // Converters from the textual option value to the option's type.
  bool ToBool(std::string str) const;
  int32_t ToInt(const std::string &str) const;
  uint32_t ToUint(const std::string &str) const;
  float ToFloat(const std::string &str) const;
  double ToDouble(const std::string &str) const;
  // maps for option variables
  // (normalized option name -> pointer to the caller-owned variable)
  std::unordered_map<std::string, bool *> bool_map_;
  std::unordered_map<std::string, int32_t *> int_map_;
  std::unordered_map<std::string, uint32_t *> uint_map_;
  std::unordered_map<std::string, float *> float_map_;
  std::unordered_map<std::string, double *> double_map_;
  std::unordered_map<std::string, std::string *> string_map_;
  /**
    Structure for options' documentation
   */
  struct DocInfo {
    DocInfo() = default;
    DocInfo(const std::string &name, const std::string &usemsg)
        : name_(name), use_msg_(usemsg), is_standard_(false) {}
    DocInfo(const std::string &name, const std::string &usemsg,
            bool is_standard)
        : name_(name), use_msg_(usemsg), is_standard_(is_standard) {}
    std::string name_;     // option name as passed at registration
    std::string use_msg_;  // help text shown by PrintUsage()
    bool is_standard_;     // true for options registered as standard
  };
  using DocMapType = std::unordered_map<std::string, DocInfo>;
  DocMapType doc_map_;  ///< map for the documentation
  bool print_args_;     ///< variable for the implicit --print-args parameter
  bool help_;           ///< variable for the implicit --help parameter
  std::string config_;  ///< variable for the implicit --config parameter
  std::vector<std::string> positional_args_;
  const char *usage_;  // not owned; points at the caller's usage string
  int argc_;
  const char *const *argv_;
  /// These members are not normally used. They are only used when the object
  /// is constructed with a prefix
  std::string prefix_;
  ParseOptions *other_parser_;

 protected:
  /// SplitLongArg parses an argument of the form --a=b, --a=, or --a,
  /// and sets "has_equal_sign" to true if an equals-sign was parsed.
  /// This is needed in order to correctly allow --x for a boolean option
  /// x, and --y= for a string option y, and to disallow --x= and --y.
  void SplitLongArg(const std::string &in, std::string *key, std::string *value,
                    bool *has_equal_sign) const;
  /// Converts an option name to its canonical form (used as the map key).
  void NormalizeArgName(std::string *str) const;
  /// Removes the beginning and trailing whitespaces from a string
  void Trim(std::string *str) const;
};
/// This template is provided for convenience in reading config classes from
/// files; this is not the standard way to read configuration options, but may
/// occasionally be needed.  This function assumes the config has a function
/// "void Register(ParseOptions *opts)" which it can call to register the
/// ParseOptions object.
///
/// @param config_filename  Path of the config file to read.
/// @param c                Config object whose options get registered and
///                         then filled in from the file.
template <class C>
void ReadConfigFromFile(const std::string &config_filename, C *c) {
  std::ostringstream usage_str;
  usage_str << "Parsing config from '" << config_filename << "'";
  // ParseOptions keeps only the raw pointer to its usage text, so the string
  // must stay alive as long as `po`. Passing usage_str.str().c_str()
  // directly would hand it a pointer into a temporary that is destroyed at
  // the end of that statement.
  const std::string usage = usage_str.str();
  ParseOptions po(usage.c_str());
  c->Register(&po);
  po.ReadConfigFile(config_filename);
}
/// This variant of the template ReadConfigFromFile is for if you need to read
/// two config classes from the same file.
///
/// @param conf  Path of the config file to read.
/// @param c1    First config object; must provide Register(ParseOptions *).
/// @param c2    Second config object; must provide Register(ParseOptions *).
template <class C1, class C2>
void ReadConfigsFromFile(const std::string &conf, C1 *c1, C2 *c2) {
  std::ostringstream usage_str;
  usage_str << "Parsing config from '" << conf << "'";
  // ParseOptions keeps only the raw pointer to its usage text, so the string
  // must stay alive as long as `po`. Passing usage_str.str().c_str()
  // directly would hand it a pointer into a temporary that is destroyed at
  // the end of that statement.
  const std::string usage = usage_str.str();
  ParseOptions po(usage.c_str());
  c1->Register(&po);
  c2->Register(&po);
  po.ReadConfigFile(conf);
}
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_
... ...
// sherpa-onnx/csrc/sherpa-onnx-alsa.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <cctype> // std::tolower
#include <cstdint>
#include "sherpa-onnx/csrc/alsa.h"
#include "sherpa-onnx/csrc/display.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
// Set by the SIGINT handler and polled by the main loop.
// volatile sig_atomic_t is the type the language guarantees can be written
// safely from a signal handler; a plain bool gives no such guarantee and the
// compiler may also hoist its read out of the polling loop.
volatile sig_atomic_t stop = 0;

// SIGINT handler: requests a clean shutdown of the recognition loop.
// NOTE(review): fprintf is not async-signal-safe; it is kept here only to
// give the user immediate feedback.
static void Handler(int /*sig*/) {
  stop = 1;
  fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n");
}
// Streaming speech recognition from an ALSA capture device with endpoint
// detection. Reads audio in 0.1 s chunks, decodes it, prints the partial
// result of the current segment whenever it changes and starts a new
// segment when an endpoint is detected. Runs until Ctrl + C.
//
// Arguments: tokens.txt encoder.onnx decoder.onnx joiner.onnx device_name
//            [num_threads]
int main(int32_t argc, char *argv[]) {
  if (argc < 6 || argc > 7) {
    const char *usage = R"usage(
Usage:
./bin/sherpa-onnx-alsa \
/path/to/tokens.txt \
/path/to/encoder.onnx \
/path/to/decoder.onnx \
/path/to/joiner.onnx \
device_name \
[num_threads]
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download.
The device name specifies which microphone to use in case there are several
on you system. You can use
arecord -l
to find all available microphones on your computer. For instance, if it outputs
**** List of CAPTURE Hardware Devices ****
card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio]
Subdevices: 1/1
Subdevice #0: subdevice #0
and if you want to select card 3 and the device 0 on that card, please use:
hw:3,0
as the device_name.
)usage";
    fprintf(stderr, "%s\n", usage);
    fprintf(stderr, "argc, %d\n", argc);
    return 0;
  }

  signal(SIGINT, Handler);

  sherpa_onnx::OnlineRecognizerConfig config;
  config.tokens = argv[1];
  config.model_config.debug = false;
  config.model_config.encoder_filename = argv[2];
  config.model_config.decoder_filename = argv[3];
  config.model_config.joiner_filename = argv[4];
  const char *device_name = argv[5];

  // Default to 2 threads; the optional 6th argument overrides it only when
  // it parses to a positive number. Parse it once instead of twice.
  config.model_config.num_threads = 2;
  if (argc == 7) {
    const int requested_threads = atoi(argv[6]);
    if (requested_threads > 0) {
      config.model_config.num_threads = requested_threads;
    }
  }

  config.enable_endpoint = true;
  config.endpoint_config.rule1.min_trailing_silence = 2.4;
  config.endpoint_config.rule2.min_trailing_silence = 1.2;
  config.endpoint_config.rule3.min_utterance_length = 300;
  fprintf(stderr, "%s\n", config.ToString().c_str());

  sherpa_onnx::OnlineRecognizer recognizer(config);
  int32_t expected_sample_rate = config.feat_config.sampling_rate;

  sherpa_onnx::Alsa alsa(device_name);
  fprintf(stderr, "Use recording device: %s\n", device_name);
  if (alsa.GetExpectedSampleRate() != expected_sample_rate) {
    fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(),
            expected_sample_rate);
    exit(-1);
  }

  // Number of samples per read: 0.1 second at the device's actual rate.
  int32_t chunk = 0.1 * alsa.GetActualSampleRate();

  std::string last_text;
  auto stream = recognizer.CreateStream();
  sherpa_onnx::Display display;
  int32_t segment_index = 0;

  while (!stop) {
    const std::vector<float> samples = alsa.Read(chunk);
    stream->AcceptWaveform(expected_sample_rate, samples.data(),
                           samples.size());
    while (recognizer.IsReady(stream.get())) {
      recognizer.DecodeStream(stream.get());
    }

    auto text = recognizer.GetResult(stream.get()).text;
    bool is_endpoint = recognizer.IsEndpoint(stream.get());

    // Only redraw when the (non-empty) result actually changed.
    if (!text.empty() && last_text != text) {
      last_text = text;
      // Lower-case for display. The lambda takes unsigned char because
      // calling std::tolower() with a negative plain-char value (any
      // non-ASCII byte) is undefined behavior.
      std::transform(
          text.begin(), text.end(), text.begin(),
          [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
      display.Print(segment_index, text);
    }

    // An endpoint with some recognized text closes the current segment.
    if (!text.empty() && is_endpoint) {
      ++segment_index;
      recognizer.Reset(stream.get());
    }
  }
  return 0;
}
... ...
... ... @@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx
features.cc
online-transducer-model-config.cc
sherpa-onnx.cc
endpoint.cc
online-stream.cc
online-recognizer.cc
)
... ...
// sherpa-onnx/python/csrc/endpoint.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/endpoint.h"
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/endpoint.h"
namespace sherpa_onnx {
// Python docstring passed to the EndpointRule constructor binding in
// PybindEndpointRule(). The text is shown to Python users, so it is part of
// the module's user-facing output.
static constexpr const char *kEndpointRuleInitDoc = R"doc(
Constructor for EndpointRule.
Args:
must_contain_nonsilence:
If True, for this endpointing rule to apply there must be nonsilence in the
best-path traceback. For decoding, a non-blank token is considered as
non-silence.
min_trailing_silence:
This endpointing rule requires duration of trailing silence (in seconds)
to be ``>=`` this value.
min_utterance_length:
This endpointing rule requires utterance-length (in seconds) to
be ``>=`` this value.
)doc";
// Python docstring passed to the EndpointConfig constructor overload that
// takes three EndpointRule objects (see PybindEndpointConfig()). The default
// values it mentions match the defaults declared on that overload.
static constexpr const char *kEndpointConfigInitDoc = R"doc(
If any rule in EndpointConfig is activated, it is said that an endpointing
is detected.
Args:
rule1:
By default, it times out after 2.4 seconds of silence, even if
we decoded nothing.
rule2:
By default, it times out after 1.2 seconds of silence after decoding
something.
rule3:
By default, it times out after the utterance is 20 seconds long, regardless of
anything else.
)doc";
// Exposes EndpointRule to Python as the class "EndpointRule" on module *m.
//
// The binding provides a three-argument constructor (documented by
// kEndpointRuleInitDoc), __str__ backed by EndpointRule::ToString, and
// read/write access to the three data members.
static void PybindEndpointRule(py::module *m) {
  using Rule = EndpointRule;

  py::class_<Rule> binding(*m, "EndpointRule");

  binding.def(py::init<bool, float, float>(),
              py::arg("must_contain_nonsilence"),
              py::arg("min_trailing_silence"),
              py::arg("min_utterance_length"), kEndpointRuleInitDoc);

  binding.def("__str__", &Rule::ToString);

  binding.def_readwrite("must_contain_nonsilence",
                        &Rule::must_contain_nonsilence);
  binding.def_readwrite("min_trailing_silence", &Rule::min_trailing_silence);
  binding.def_readwrite("min_utterance_length", &Rule::min_utterance_length);
}
// Exposes EndpointConfig to Python as the class "EndpointConfig" on
// module *m.
//
// Two constructor overloads are registered, in this order:
//   1. Three floats (rule1_min_trailing_silence, rule2_min_trailing_silence,
//      rule3_min_utterance_length), from which the three rules are built.
//   2. Three EndpointRule objects (rule1, rule2, rule3), each with a default.
// The binding also provides __str__ and read/write access to rule1..rule3.
static void PybindEndpointConfig(py::module *m) {
  using Config = EndpointConfig;

  py::class_<Config> binding(*m, "EndpointConfig");

  // Overload 1: build the rules from scalar thresholds.
  binding.def(
      py::init([](float rule1_min_trailing_silence,
                  float rule2_min_trailing_silence,
                  float rule3_min_utterance_length)
                   -> std::unique_ptr<Config> {
        EndpointRule silence_only(false, rule1_min_trailing_silence, 0);
        EndpointRule silence_after_speech(true, rule2_min_trailing_silence,
                                          0);
        EndpointRule max_length(false, 0, rule3_min_utterance_length);
        return std::make_unique<EndpointConfig>(
            silence_only, silence_after_speech, max_length);
      }),
      py::arg("rule1_min_trailing_silence"),
      py::arg("rule2_min_trailing_silence"),
      py::arg("rule3_min_utterance_length"));

  // Overload 2: take ready-made rules, falling back to the defaults below.
  binding.def(
      py::init([](const EndpointRule &rule1, const EndpointRule &rule2,
                  const EndpointRule &rule3) -> std::unique_ptr<Config> {
        auto config = std::make_unique<Config>();
        config->rule1 = rule1;
        config->rule2 = rule2;
        config->rule3 = rule3;
        return config;
      }),
      py::arg("rule1") = EndpointRule(false, 2.4, 0),
      py::arg("rule2") = EndpointRule(true, 1.2, 0),
      py::arg("rule3") = EndpointRule(false, 0, 20), kEndpointConfigInitDoc);

  binding.def("__str__", [](const Config &self) -> std::string {
    return self.ToString();
  });

  binding.def_readwrite("rule1", &Config::rule1);
  binding.def_readwrite("rule2", &Config::rule2);
  binding.def_readwrite("rule3", &Config::rule3);
}
// Registers the endpointing classes on the Python module *m by binding
// EndpointRule and then EndpointConfig.
//
// NOTE(review): keep this order. The EndpointConfig binding declares
// EndpointRule objects as default arguments, and pybind11 presumably needs
// EndpointRule to be registered by the time those defaults are converted --
// confirm before reordering.
void PybindEndpoint(py::module *m) {
PybindEndpointRule(m);
PybindEndpointConfig(m);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/endpoint.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
#define SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
// Adds the Python bindings for EndpointRule and EndpointConfig to module *m.
void PybindEndpoint(py::module *m);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
... ...
... ... @@ -21,11 +21,15 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
using PyClass = OnlineRecognizerConfig;
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &,
const OnlineTransducerModelConfig &, const std::string &>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"))
const OnlineTransducerModelConfig &, const std::string &,
const EndpointConfig &, bool>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"),
py::arg("endpoint_config"), py::arg("enable_endpoint"))
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def("__str__", &PyClass::ToString);
}
... ... @@ -43,7 +47,9 @@ void PybindOnlineRecognizer(py::module *m) {
[](PyClass &self, std::vector<OnlineStream *> ss) {
self.DecodeStreams(ss.data(), ss.size());
})
.def("get_result", &PyClass::GetResult);
.def("get_result", &PyClass::GetResult)
.def("is_endpoint", &PyClass::IsEndpoint)
.def("reset", &PyClass::Reset);
}
} // namespace sherpa_onnx
... ...
... ... @@ -4,6 +4,7 @@
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
#include "sherpa-onnx/python/csrc/online-stream.h"
... ... @@ -16,6 +17,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindFeatures(&m);
PybindOnlineTransducerModelConfig(&m);
PybindOnlineStream(&m);
PybindEndpoint(&m);
PybindOnlineRecognizer(&m);
}
... ...
from _sherpa_onnx import (
EndpointConfig,
FeatureExtractorConfig,
OnlineRecognizerConfig,
OnlineStream,
... ...
... ... @@ -2,12 +2,13 @@ from pathlib import Path
from typing import List
from _sherpa_onnx import (
OnlineStream,
OnlineTransducerModelConfig,
EndpointConfig,
FeatureExtractorConfig,
OnlineRecognizer as _Recognizer,
OnlineRecognizerConfig,
OnlineStream,
OnlineTransducerModelConfig,
)
from _sherpa_onnx import OnlineRecognizer as _Recognizer
def _assert_file_exists(f: str):
... ... @@ -26,6 +27,10 @@ class OnlineRecognizer(object):
num_threads: int = 4,
sample_rate: float = 16000,
feature_dim: int = 80,
enable_endpoint_detection: bool = False,
rule1_min_trailing_silence: int = 2.4,
rule2_min_trailing_silence: int = 1.2,
rule3_min_utterance_length: int = 20,
):
"""
Please refer to
... ... @@ -52,6 +57,22 @@ class OnlineRecognizer(object):
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
rule1_min_trailing_silence:
Used only when enable_endpoint_detection is True. If the duration
of trailing silence in seconds is larger than this value, we assume
an endpoint is detected.
rule2_min_trailing_silence:
Used only when enable_endpoint_detection is True. If we have decoded
something that is nonsilence and if the duration of trailing silence
in seconds is larger than this value, we assume an endpoint is
detected.
rule3_min_utterance_length:
Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint
is detected.
"""
_assert_file_exists(tokens)
_assert_file_exists(encoder)
... ... @@ -72,10 +93,18 @@ class OnlineRecognizer(object):
feature_dim=feature_dim,
)
endpoint_config = EndpointConfig(
rule1_min_trailing_silence=rule1_min_trailing_silence,
rule2_min_trailing_silence=rule2_min_trailing_silence,
rule3_min_utterance_length=rule3_min_utterance_length,
)
recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
tokens=tokens,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
)
self.recognizer = _Recognizer(recognizer_config)
... ... @@ -93,4 +122,10 @@ class OnlineRecognizer(object):
return self.recognizer.is_ready(s)
def get_result(self, s: OnlineStream) -> str:
return self.recognizer.get_result(s).text
return self.recognizer.get_result(s).text.strip()
def is_endpoint(self, s: OnlineStream) -> bool:
    """Ask the wrapped recognizer whether stream ``s`` is at an endpoint.

    Args:
      s:
        The stream to query.
    Returns:
      Whatever ``self.recognizer.is_endpoint(s)`` returns.
    """
    endpoint_detected = self.recognizer.is_endpoint(s)
    return endpoint_detected
def reset(self, s: OnlineStream) -> bool:
    """Forward a reset request for stream ``s`` to the wrapped recognizer.

    Args:
      s:
        The stream to reset.
    Returns:
      Whatever ``self.recognizer.reset(s)`` returns.
    """
    # NOTE(review): the annotation says bool, but this simply passes through
    # the result of the underlying binding -- confirm what it returns.
    outcome = self.recognizer.reset(s)
    return outcome
... ...