正在显示
23 个修改的文件
包含
2190 行增加
和
21 行删除
| @@ -4,9 +4,11 @@ build | @@ -4,9 +4,11 @@ build | ||
| 4 | onnxruntime-* | 4 | onnxruntime-* |
| 5 | icefall-* | 5 | icefall-* |
| 6 | run.sh | 6 | run.sh |
| 7 | -sherpa-onnx-* | ||
| 8 | __pycache__ | 7 | __pycache__ |
| 9 | dist/ | 8 | dist/ |
| 10 | sherpa_onnx.egg-info/ | 9 | sherpa_onnx.egg-info/ |
| 11 | .DS_Store | 10 | .DS_Store |
| 12 | build-aarch64-linux-gnu | 11 | build-aarch64-linux-gnu |
| 12 | +sherpa-onnx-streaming-zipformer-* | ||
| 13 | +sherpa-onnx-lstm-en-* | ||
| 14 | +sherpa-onnx-lstm-zh-* |
| @@ -13,6 +13,7 @@ endif() | @@ -13,6 +13,7 @@ endif() | ||
| 13 | 13 | ||
| 14 | option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF) | 14 | option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF) |
| 15 | option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF) | 15 | option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF) |
| 16 | +option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON) | ||
| 16 | option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) | 17 | option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) |
| 17 | 18 | ||
| 18 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") | 19 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") |
| @@ -46,6 +47,8 @@ message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") | @@ -46,6 +47,8 @@ message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") | ||
| 46 | message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}") | 47 | message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}") |
| 47 | message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}") | 48 | message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}") |
| 48 | message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}") | 49 | message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}") |
| 50 | +message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}") | ||
| 51 | +message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}") | ||
| 49 | 52 | ||
| 50 | set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") | 53 | set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") |
| 51 | set(CMAKE_CXX_EXTENSIONS OFF) | 54 | set(CMAKE_CXX_EXTENSIONS OFF) |
| @@ -56,6 +59,9 @@ if(SHERPA_ONNX_HAS_ALSA) | @@ -56,6 +59,9 @@ if(SHERPA_ONNX_HAS_ALSA) | ||
| 56 | add_definitions(-DSHERPA_ONNX_ENABLE_ALSA=1) | 59 | add_definitions(-DSHERPA_ONNX_ENABLE_ALSA=1) |
| 57 | endif() | 60 | endif() |
| 58 | 61 | ||
| 62 | +check_include_file_cxx(cxxabi.h SHERPA_ONNX_HAVE_CXXABI_H) | ||
| 63 | +check_include_file_cxx(execinfo.h SHERPA_ONNX_HAVE_EXECINFO_H) | ||
| 64 | + | ||
| 59 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) | 65 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) |
| 60 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) | 66 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) |
| 61 | 67 |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +# Real-time speech recognition from a microphone with sherpa-onnx Python API | ||
| 4 | +# with endpoint detection. | ||
| 5 | +# | ||
| 6 | +# Please refer to | ||
| 7 | +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
| 8 | +# to download pre-trained models | ||
| 9 | + | ||
| 10 | +import sys | ||
| 11 | + | ||
| 12 | +try: | ||
| 13 | + import sounddevice as sd | ||
| 14 | +except ImportError as e: | ||
| 15 | + print("Please install sounddevice first. You can use") | ||
| 16 | + print() | ||
| 17 | + print(" pip install sounddevice") | ||
| 18 | + print() | ||
| 19 | + print("to install it") | ||
| 20 | + sys.exit(-1) | ||
| 21 | + | ||
| 22 | +import sherpa_onnx | ||
| 23 | + | ||
| 24 | + | ||
| 25 | +def create_recognizer(): | ||
| 26 | + # Please replace the model files if needed. | ||
| 27 | + # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
| 28 | + # for download links. | ||
| 29 | + recognizer = sherpa_onnx.OnlineRecognizer( | ||
| 30 | + tokens="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt", | ||
| 31 | + encoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx", | ||
| 32 | + decoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx", | ||
| 33 | + joiner="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx", | ||
| 34 | + num_threads=4, | ||
| 35 | + sample_rate=16000, | ||
| 36 | + feature_dim=80, | ||
| 37 | + enable_endpoint_detection=True, | ||
| 38 | + rule1_min_trailing_silence=2.4, | ||
| 39 | + rule2_min_trailing_silence=1.2, | ||
| 40 | + rule3_min_utterance_length=300, # it essentially disables this rule | ||
| 41 | + ) | ||
| 42 | + return recognizer | ||
| 43 | + | ||
| 44 | + | ||
| 45 | +def main(): | ||
| 46 | + print("Started! Please speak") | ||
| 47 | + recognizer = create_recognizer() | ||
| 48 | + sample_rate = 16000 | ||
| 49 | + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms | ||
| 50 | + last_result = "" | ||
| 51 | + stream = recognizer.create_stream() | ||
| 52 | + | ||
| 53 | + last_result = "" | ||
| 54 | + segment_id = 0 | ||
| 55 | + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: | ||
| 56 | + while True: | ||
| 57 | + samples, _ = s.read(samples_per_read) # a blocking read | ||
| 58 | + samples = samples.reshape(-1) | ||
| 59 | + stream.accept_waveform(sample_rate, samples) | ||
| 60 | + while recognizer.is_ready(stream): | ||
| 61 | + recognizer.decode_stream(stream) | ||
| 62 | + | ||
| 63 | + is_endpoint = recognizer.is_endpoint(stream) | ||
| 64 | + | ||
| 65 | + result = recognizer.get_result(stream) | ||
| 66 | + | ||
| 67 | + if result and (last_result != result): | ||
| 68 | + last_result = result | ||
| 69 | + print(f"{segment_id}: {result}") | ||
| 70 | + | ||
| 71 | + if result and is_endpoint: | ||
| 72 | + segment_id += 1 | ||
| 73 | + recognizer.reset(stream) | ||
| 74 | + | ||
| 75 | + | ||
| 76 | +if __name__ == "__main__": | ||
| 77 | + devices = sd.query_devices() | ||
| 78 | + print(devices) | ||
| 79 | + default_input_device_idx = sd.default.device[0] | ||
| 80 | + print(f'Use default device: {devices[default_input_device_idx]["name"]}') | ||
| 81 | + | ||
| 82 | + try: | ||
| 83 | + main() | ||
| 84 | + except KeyboardInterrupt: | ||
| 85 | + print("\nCaught Ctrl + C. Exiting") |
| 1 | include_directories(${CMAKE_SOURCE_DIR}) | 1 | include_directories(${CMAKE_SOURCE_DIR}) |
| 2 | 2 | ||
| 3 | -add_library(sherpa-onnx-core | 3 | +set(sources |
| 4 | cat.cc | 4 | cat.cc |
| 5 | + endpoint.cc | ||
| 5 | features.cc | 6 | features.cc |
| 6 | online-lstm-transducer-model.cc | 7 | online-lstm-transducer-model.cc |
| 7 | online-recognizer.cc | 8 | online-recognizer.cc |
| @@ -11,6 +12,7 @@ add_library(sherpa-onnx-core | @@ -11,6 +12,7 @@ add_library(sherpa-onnx-core | ||
| 11 | online-transducer-model.cc | 12 | online-transducer-model.cc |
| 12 | online-zipformer-transducer-model.cc | 13 | online-zipformer-transducer-model.cc |
| 13 | onnx-utils.cc | 14 | onnx-utils.cc |
| 15 | + parse-options.cc | ||
| 14 | resample.cc | 16 | resample.cc |
| 15 | symbol-table.cc | 17 | symbol-table.cc |
| 16 | text-utils.cc | 18 | text-utils.cc |
| @@ -18,11 +20,29 @@ add_library(sherpa-onnx-core | @@ -18,11 +20,29 @@ add_library(sherpa-onnx-core | ||
| 18 | wave-reader.cc | 20 | wave-reader.cc |
| 19 | ) | 21 | ) |
| 20 | 22 | ||
| 23 | +if(SHERPA_ONNX_ENABLE_CHECK) | ||
| 24 | + list(APPEND sources log.cc) | ||
| 25 | +endif() | ||
| 26 | + | ||
| 27 | +add_library(sherpa-onnx-core ${sources}) | ||
| 28 | + | ||
| 21 | target_link_libraries(sherpa-onnx-core | 29 | target_link_libraries(sherpa-onnx-core |
| 22 | onnxruntime | 30 | onnxruntime |
| 23 | kaldi-native-fbank-core | 31 | kaldi-native-fbank-core |
| 24 | ) | 32 | ) |
| 25 | 33 | ||
| 34 | +if(SHERPA_ONNX_ENABLE_CHECK) | ||
| 35 | + target_compile_definitions(sherpa-onnx-core PUBLIC SHERPA_ONNX_ENABLE_CHECK=1) | ||
| 36 | + | ||
| 37 | + if(SHERPA_ONNX_HAVE_EXECINFO_H) | ||
| 38 | + target_compile_definitions(sherpa-onnx-core PRIVATE SHERPA_ONNX_HAVE_EXECINFO_H=1) | ||
| 39 | + endif() | ||
| 40 | + | ||
| 41 | + if(SHERPA_ONNX_HAVE_CXXABI_H) | ||
| 42 | + target_compile_definitions(sherpa-onnx-core PRIVATE SHERPA_ONNX_HAVE_CXXABI_H=1) | ||
| 43 | + endif() | ||
| 44 | +endif() | ||
| 45 | + | ||
| 26 | add_executable(sherpa-onnx sherpa-onnx.cc) | 46 | add_executable(sherpa-onnx sherpa-onnx.cc) |
| 27 | 47 | ||
| 28 | target_link_libraries(sherpa-onnx sherpa-onnx-core) | 48 | target_link_libraries(sherpa-onnx sherpa-onnx-core) |
sherpa-onnx/csrc/endpoint.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/endpoint.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022 (authors: Pingfeng Luo) | ||
| 4 | +// 2022-2023 Xiaomi Corporation | ||
| 5 | + | ||
| 6 | +#include "sherpa-onnx/csrc/endpoint.h" | ||
| 7 | + | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/log.h" | ||
| 11 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +static bool RuleActivated(const EndpointRule &rule, | ||
| 16 | + const std::string &rule_name, float trailing_silence, | ||
| 17 | + float utterance_length) { | ||
| 18 | + bool contain_nonsilence = utterance_length > trailing_silence; | ||
| 19 | + bool ans = (contain_nonsilence || !rule.must_contain_nonsilence) && | ||
| 20 | + trailing_silence >= rule.min_trailing_silence && | ||
| 21 | + utterance_length >= rule.min_utterance_length; | ||
| 22 | + if (ans) { | ||
| 23 | + SHERPA_ONNX_LOG(DEBUG) << "Endpointing rule " << rule_name << " activated: " | ||
| 24 | + << (contain_nonsilence ? "true" : "false") << ',' | ||
| 25 | + << trailing_silence << ',' << utterance_length; | ||
| 26 | + } | ||
| 27 | + return ans; | ||
| 28 | +} | ||
| 29 | + | ||
| 30 | +static void RegisterEndpointRule(ParseOptions *po, EndpointRule *rule, | ||
| 31 | + const std::string &rule_name) { | ||
| 32 | + po->Register( | ||
| 33 | + rule_name + "-must-contain-nonsilence", &rule->must_contain_nonsilence, | ||
| 34 | + "If True, for this endpointing " + rule_name + | ||
| 35 | + " to apply there must be nonsilence in the best-path traceback. " | ||
| 36 | + "For decoding, a non-blank token is considered as non-silence"); | ||
| 37 | + po->Register(rule_name + "-min-trailing-silence", &rule->min_trailing_silence, | ||
| 38 | + "This endpointing " + rule_name + | ||
| 39 | + " requires duration of trailing silence in seconds) to " | ||
| 40 | + "be >= this value."); | ||
| 41 | + po->Register(rule_name + "-min-utterance-length", &rule->min_utterance_length, | ||
| 42 | + "This endpointing " + rule_name + | ||
| 43 | + " requires utterance-length (in seconds) to be >= this " | ||
| 44 | + "value."); | ||
| 45 | +} | ||
| 46 | + | ||
| 47 | +std::string EndpointRule::ToString() const { | ||
| 48 | + std::ostringstream os; | ||
| 49 | + | ||
| 50 | + os << "EndpointRule("; | ||
| 51 | + os << "must_contain_nonsilence=" | ||
| 52 | + << (must_contain_nonsilence ? "True" : "False") << ", "; | ||
| 53 | + os << "min_trailing_silence=" << min_trailing_silence << ", "; | ||
| 54 | + os << "min_utterance_length=" << min_utterance_length << ")"; | ||
| 55 | + | ||
| 56 | + return os.str(); | ||
| 57 | +} | ||
| 58 | + | ||
| 59 | +void EndpointConfig::Register(ParseOptions *po) { | ||
| 60 | + RegisterEndpointRule(po, &rule1, "rule1"); | ||
| 61 | + RegisterEndpointRule(po, &rule2, "rule2"); | ||
| 62 | + RegisterEndpointRule(po, &rule3, "rule3"); | ||
| 63 | +} | ||
| 64 | + | ||
| 65 | +std::string EndpointConfig::ToString() const { | ||
| 66 | + std::ostringstream os; | ||
| 67 | + | ||
| 68 | + os << "EndpointConfig("; | ||
| 69 | + os << "rule1=" << rule1.ToString() << ", "; | ||
| 70 | + os << "rule2=" << rule2.ToString() << ", "; | ||
| 71 | + os << "rule3=" << rule3.ToString() << ")"; | ||
| 72 | + | ||
| 73 | + return os.str(); | ||
| 74 | +} | ||
| 75 | + | ||
| 76 | +bool Endpoint::IsEndpoint(int num_frames_decoded, int trailing_silence_frames, | ||
| 77 | + float frame_shift_in_seconds) const { | ||
| 78 | + float utterance_length = num_frames_decoded * frame_shift_in_seconds; | ||
| 79 | + float trailing_silence = trailing_silence_frames * frame_shift_in_seconds; | ||
| 80 | + if (RuleActivated(config_.rule1, "rule1", trailing_silence, | ||
| 81 | + utterance_length) || | ||
| 82 | + RuleActivated(config_.rule2, "rule2", trailing_silence, | ||
| 83 | + utterance_length) || | ||
| 84 | + RuleActivated(config_.rule3, "rule3", trailing_silence, | ||
| 85 | + utterance_length)) { | ||
| 86 | + return true; | ||
| 87 | + } | ||
| 88 | + return false; | ||
| 89 | +} | ||
| 90 | + | ||
| 91 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/endpoint.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/endpoint.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022 (authors: Pingfeng Luo) | ||
| 4 | +// 2022-2023 Xiaomi Corporation | ||
| 5 | + | ||
| 6 | +#ifndef SHERPA_ONNX_CSRC_ENDPOINT_H_ | ||
| 7 | +#define SHERPA_ONNX_CSRC_ENDPOINT_H_ | ||
| 8 | + | ||
| 9 | +#include <string> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct EndpointRule { | ||
| 15 | + // If True, for this endpointing rule to apply there must | ||
| 16 | + // be nonsilence in the best-path traceback. | ||
| 17 | + // For decoding, a non-blank token is considered as non-silence | ||
| 18 | + bool must_contain_nonsilence = true; | ||
| 19 | + // This endpointing rule requires duration of trailing silence | ||
| 20 | + // (in seconds) to be >= this value. | ||
| 21 | + float min_trailing_silence = 2.0; | ||
| 22 | + // This endpointing rule requires utterance-length (in seconds) | ||
| 23 | + // to be >= this value. | ||
| 24 | + float min_utterance_length = 0.0f; | ||
| 25 | + | ||
| 26 | + EndpointRule() = default; | ||
| 27 | + | ||
| 28 | + EndpointRule(bool must_contain_nonsilence, float min_trailing_silence, | ||
| 29 | + float min_utterance_length) | ||
| 30 | + : must_contain_nonsilence(must_contain_nonsilence), | ||
| 31 | + min_trailing_silence(min_trailing_silence), | ||
| 32 | + min_utterance_length(min_utterance_length) {} | ||
| 33 | + | ||
| 34 | + std::string ToString() const; | ||
| 35 | +}; | ||
| 36 | + | ||
| 37 | +class ParseOptions; | ||
| 38 | + | ||
| 39 | +struct EndpointConfig { | ||
| 40 | + // For default setting, | ||
| 41 | + // rule1 times out after 2.4 seconds of silence, even if we decoded nothing. | ||
| 42 | + // rule2 times out after 1.2 seconds of silence after decoding something. | ||
| 43 | + // rule3 times out after the utterance is 20 seconds long, regardless of | ||
| 44 | + // anything else. | ||
| 45 | + EndpointRule rule1; | ||
| 46 | + EndpointRule rule2; | ||
| 47 | + EndpointRule rule3; | ||
| 48 | + | ||
| 49 | + void Register(ParseOptions *po); | ||
| 50 | + | ||
| 51 | + EndpointConfig() | ||
| 52 | + : rule1{false, 2.4, 0}, rule2{true, 1.2, 0}, rule3{false, 0, 20} {} | ||
| 53 | + | ||
| 54 | + EndpointConfig(const EndpointRule &rule1, const EndpointRule &rule2, | ||
| 55 | + const EndpointRule &rule3) | ||
| 56 | + : rule1(rule1), rule2(rule2), rule3(rule3) {} | ||
| 57 | + | ||
| 58 | + std::string ToString() const; | ||
| 59 | +}; | ||
| 60 | + | ||
| 61 | +class Endpoint { | ||
| 62 | + public: | ||
| 63 | + explicit Endpoint(const EndpointConfig &config) : config_(config) {} | ||
| 64 | + | ||
| 65 | + /// This function returns true if this set of endpointing rules thinks we | ||
| 66 | + /// should terminate decoding. | ||
| 67 | + bool IsEndpoint(int num_frames_decoded, int trailing_silence_frames, | ||
| 68 | + float frame_shift_in_seconds) const; | ||
| 69 | + | ||
| 70 | + private: | ||
| 71 | + EndpointConfig config_; | ||
| 72 | +}; | ||
| 73 | + | ||
| 74 | +} // namespace sherpa_onnx | ||
| 75 | + | ||
| 76 | +#endif // SHERPA_ONNX_CSRC_ENDPOINT_H_ |
sherpa-onnx/csrc/log.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/log.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/log.h" | ||
| 6 | + | ||
| 7 | +#ifdef SHERPA_ONNX_HAVE_EXECINFO_H | ||
| 8 | +#include <execinfo.h> // To get stack trace in error messages. | ||
| 9 | +#ifdef SHERPA_ONNX_HAVE_CXXABI_H | ||
| 10 | +#include <cxxabi.h> // For name demangling. | ||
| 11 | +// Useful to decode the stack trace, but only used if we have execinfo.h | ||
| 12 | +#endif // SHERPA_ONNX_HAVE_CXXABI_H | ||
| 13 | +#endif // SHERPA_ONNX_HAVE_EXECINFO_H | ||
| 14 | + | ||
| 15 | +#include <stdlib.h> | ||
| 16 | + | ||
| 17 | +#include <ctime> | ||
| 18 | +#include <iomanip> | ||
| 19 | +#include <string> | ||
| 20 | + | ||
| 21 | +namespace sherpa_onnx { | ||
| 22 | + | ||
| 23 | +std::string GetDateTimeStr() { | ||
| 24 | + std::ostringstream os; | ||
| 25 | + std::time_t t = std::time(nullptr); | ||
| 26 | + std::tm tm = *std::localtime(&t); | ||
| 27 | + os << std::put_time(&tm, "%F %T"); // yyyy-mm-dd hh:mm:ss | ||
| 28 | + return os.str(); | ||
| 29 | +} | ||
| 30 | + | ||
| 31 | +static bool LocateSymbolRange(const std::string &trace_name, std::size_t *begin, | ||
| 32 | + std::size_t *end) { | ||
| 33 | + // Find the first '_' with leading ' ' or '('. | ||
| 34 | + *begin = std::string::npos; | ||
| 35 | + for (std::size_t i = 1; i < trace_name.size(); ++i) { | ||
| 36 | + if (trace_name[i] != '_') { | ||
| 37 | + continue; | ||
| 38 | + } | ||
| 39 | + if (trace_name[i - 1] == ' ' || trace_name[i - 1] == '(') { | ||
| 40 | + *begin = i; | ||
| 41 | + break; | ||
| 42 | + } | ||
| 43 | + } | ||
| 44 | + if (*begin == std::string::npos) { | ||
| 45 | + return false; | ||
| 46 | + } | ||
| 47 | + *end = trace_name.find_first_of(" +", *begin); | ||
| 48 | + return *end != std::string::npos; | ||
| 49 | +} | ||
| 50 | + | ||
| 51 | +#ifdef SHERPA_ONNX_HAVE_EXECINFO_H | ||
| 52 | +static std::string Demangle(const std::string &trace_name) { | ||
| 53 | +#ifndef SHERPA_ONNX_HAVE_CXXABI_H | ||
| 54 | + return trace_name; | ||
| 55 | +#else // SHERPA_ONNX_HAVE_CXXABI_H | ||
| 56 | + // Try demangle the symbol. We are trying to support the following formats | ||
| 57 | + // produced by different platforms: | ||
| 58 | + // | ||
| 59 | + // Linux: | ||
| 60 | + // ./kaldi-error-test(_ZN5kaldi13UnitTestErrorEv+0xb) [0x804965d] | ||
| 61 | + // | ||
| 62 | + // Mac: | ||
| 63 | + // 0 server 0x000000010f67614d _ZNK5kaldi13MessageLogger10LogMessageEv + 813 | ||
| 64 | + // | ||
| 65 | + // We want to extract the name e.g., '_ZN5kaldi13UnitTestErrorEv' and | ||
| 66 | + // demangle it info a readable name like kaldi::UnitTextError. | ||
| 67 | + std::size_t begin, end; | ||
| 68 | + if (!LocateSymbolRange(trace_name, &begin, &end)) { | ||
| 69 | + return trace_name; | ||
| 70 | + } | ||
| 71 | + std::string symbol = trace_name.substr(begin, end - begin); | ||
| 72 | + int status; | ||
| 73 | + char *demangled_name = abi::__cxa_demangle(symbol.c_str(), 0, 0, &status); | ||
| 74 | + if (status == 0 && demangled_name != nullptr) { | ||
| 75 | + symbol = demangled_name; | ||
| 76 | + free(demangled_name); | ||
| 77 | + } | ||
| 78 | + return trace_name.substr(0, begin) + symbol + | ||
| 79 | + trace_name.substr(end, std::string::npos); | ||
| 80 | +#endif // SHERPA_ONNX_HAVE_CXXABI_H | ||
| 81 | +} | ||
| 82 | +#endif // SHERPA_ONNX_HAVE_EXECINFO_H | ||
| 83 | + | ||
| 84 | +std::string GetStackTrace() { | ||
| 85 | + std::string ans; | ||
| 86 | +#ifdef SHERPA_ONNX_HAVE_EXECINFO_H | ||
| 87 | + constexpr const std::size_t kMaxTraceSize = 50; | ||
| 88 | + constexpr const std::size_t kMaxTracePrint = 50; // Must be even. | ||
| 89 | + // Buffer for the trace. | ||
| 90 | + void *trace[kMaxTraceSize]; | ||
| 91 | + // Get the trace. | ||
| 92 | + std::size_t size = backtrace(trace, kMaxTraceSize); | ||
| 93 | + // Get the trace symbols. | ||
| 94 | + char **trace_symbol = backtrace_symbols(trace, size); | ||
| 95 | + if (trace_symbol == nullptr) return ans; | ||
| 96 | + | ||
| 97 | + // Compose a human-readable backtrace string. | ||
| 98 | + ans += "[ Stack-Trace: ]\n"; | ||
| 99 | + if (size <= kMaxTracePrint) { | ||
| 100 | + for (std::size_t i = 0; i < size; ++i) { | ||
| 101 | + ans += Demangle(trace_symbol[i]) + "\n"; | ||
| 102 | + } | ||
| 103 | + } else { // Print out first+last (e.g.) 5. | ||
| 104 | + for (std::size_t i = 0; i < kMaxTracePrint / 2; ++i) { | ||
| 105 | + ans += Demangle(trace_symbol[i]) + "\n"; | ||
| 106 | + } | ||
| 107 | + ans += ".\n.\n.\n"; | ||
| 108 | + for (std::size_t i = size - kMaxTracePrint / 2; i < size; ++i) { | ||
| 109 | + ans += Demangle(trace_symbol[i]) + "\n"; | ||
| 110 | + } | ||
| 111 | + if (size == kMaxTraceSize) | ||
| 112 | + ans += ".\n.\n.\n"; // Stack was too long, probably a bug. | ||
| 113 | + } | ||
| 114 | + | ||
| 115 | + // We must free the array of pointers allocated by backtrace_symbols(), | ||
| 116 | + // but not the strings themselves. | ||
| 117 | + free(trace_symbol); | ||
| 118 | +#endif // SHERPA_ONNX_HAVE_EXECINFO_H | ||
| 119 | + return ans; | ||
| 120 | +} | ||
| 121 | + | ||
| 122 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/log.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/log.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_LOG_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_LOG_H_ | ||
| 7 | + | ||
| 8 | +#include <stdio.h> | ||
| 9 | + | ||
| 10 | +#include <mutex> // NOLINT | ||
| 11 | +#include <sstream> | ||
| 12 | +#include <string> | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +#if SHERPA_ONNX_ENABLE_CHECK | ||
| 17 | + | ||
| 18 | +#if defined(NDEBUG) | ||
| 19 | +constexpr bool kDisableDebug = true; | ||
| 20 | +#else | ||
| 21 | +constexpr bool kDisableDebug = false; | ||
| 22 | +#endif | ||
| 23 | + | ||
| 24 | +enum class LogLevel { | ||
| 25 | + kTrace = 0, | ||
| 26 | + kDebug = 1, | ||
| 27 | + kInfo = 2, | ||
| 28 | + kWarning = 3, | ||
| 29 | + kError = 4, | ||
| 30 | + kFatal = 5, // print message and abort the program | ||
| 31 | +}; | ||
| 32 | + | ||
| 33 | +// They are used in SHERPA_ONNX_LOG(xxx), so their names | ||
| 34 | +// do not follow the google c++ code style | ||
| 35 | +// | ||
| 36 | +// You can use them in the following way: | ||
| 37 | +// | ||
| 38 | +// SHERPA_ONNX_LOG(TRACE) << "some message"; | ||
| 39 | +// SHERPA_ONNX_LOG(DEBUG) << "some message"; | ||
| 40 | +#ifndef _MSC_VER | ||
| 41 | +constexpr LogLevel TRACE = LogLevel::kTrace; | ||
| 42 | +constexpr LogLevel DEBUG = LogLevel::kDebug; | ||
| 43 | +constexpr LogLevel INFO = LogLevel::kInfo; | ||
| 44 | +constexpr LogLevel WARNING = LogLevel::kWarning; | ||
| 45 | +constexpr LogLevel ERROR = LogLevel::kError; | ||
| 46 | +constexpr LogLevel FATAL = LogLevel::kFatal; | ||
| 47 | +#else | ||
| 48 | +#define TRACE LogLevel::kTrace | ||
| 49 | +#define DEBUG LogLevel::kDebug | ||
| 50 | +#define INFO LogLevel::kInfo | ||
| 51 | +#define WARNING LogLevel::kWarning | ||
| 52 | +#define ERROR LogLevel::kError | ||
| 53 | +#define FATAL LogLevel::kFatal | ||
| 54 | +#endif | ||
| 55 | + | ||
| 56 | +std::string GetStackTrace(); | ||
| 57 | + | ||
| 58 | +/* Return the current log level. | ||
| 59 | + | ||
| 60 | + | ||
| 61 | + If the current log level is TRACE, then all logged messages are printed out. | ||
| 62 | + | ||
| 63 | + If the current log level is DEBUG, log messages with "TRACE" level are not | ||
| 64 | + shown and all other levels are printed out. | ||
| 65 | + | ||
| 66 | + Similarly, if the current log level is INFO, log message with "TRACE" and | ||
| 67 | + "DEBUG" are not shown and all other levels are printed out. | ||
| 68 | + | ||
| 69 | + If it is FATAL, then only FATAL messages are shown. | ||
| 70 | + */ | ||
| 71 | +inline LogLevel GetCurrentLogLevel() { | ||
| 72 | + static LogLevel log_level = INFO; | ||
| 73 | + static std::once_flag init_flag; | ||
| 74 | + std::call_once(init_flag, []() { | ||
| 75 | + const char *env_log_level = std::getenv("SHERPA_ONNX_LOG_LEVEL"); | ||
| 76 | + if (env_log_level == nullptr) return; | ||
| 77 | + | ||
| 78 | + std::string s = env_log_level; | ||
| 79 | + if (s == "TRACE") | ||
| 80 | + log_level = TRACE; | ||
| 81 | + else if (s == "DEBUG") | ||
| 82 | + log_level = DEBUG; | ||
| 83 | + else if (s == "INFO") | ||
| 84 | + log_level = INFO; | ||
| 85 | + else if (s == "WARNING") | ||
| 86 | + log_level = WARNING; | ||
| 87 | + else if (s == "ERROR") | ||
| 88 | + log_level = ERROR; | ||
| 89 | + else if (s == "FATAL") | ||
| 90 | + log_level = FATAL; | ||
| 91 | + else | ||
| 92 | + fprintf(stderr, | ||
| 93 | + "Unknown SHERPA_ONNX_LOG_LEVEL: %s" | ||
| 94 | + "\nSupported values are: " | ||
| 95 | + "TRACE, DEBUG, INFO, WARNING, ERROR, FATAL", | ||
| 96 | + s.c_str()); | ||
| 97 | + }); | ||
| 98 | + return log_level; | ||
| 99 | +} | ||
| 100 | + | ||
| 101 | +inline bool EnableAbort() { | ||
| 102 | + static std::once_flag init_flag; | ||
| 103 | + static bool enable_abort = false; | ||
| 104 | + std::call_once(init_flag, []() { | ||
| 105 | + enable_abort = (std::getenv("SHERPA_ONNX_ABORT") != nullptr); | ||
| 106 | + }); | ||
| 107 | + return enable_abort; | ||
| 108 | +} | ||
| 109 | + | ||
| 110 | +class Logger { | ||
| 111 | + public: | ||
| 112 | + Logger(const char *filename, const char *func_name, uint32_t line_num, | ||
| 113 | + LogLevel level) | ||
| 114 | + : filename_(filename), | ||
| 115 | + func_name_(func_name), | ||
| 116 | + line_num_(line_num), | ||
| 117 | + level_(level) { | ||
| 118 | + cur_level_ = GetCurrentLogLevel(); | ||
| 119 | + switch (level) { | ||
| 120 | + case TRACE: | ||
| 121 | + if (cur_level_ <= TRACE) fprintf(stderr, "[T] "); | ||
| 122 | + break; | ||
| 123 | + case DEBUG: | ||
| 124 | + if (cur_level_ <= DEBUG) fprintf(stderr, "[D] "); | ||
| 125 | + break; | ||
| 126 | + case INFO: | ||
| 127 | + if (cur_level_ <= INFO) fprintf(stderr, "[I] "); | ||
| 128 | + break; | ||
| 129 | + case WARNING: | ||
| 130 | + if (cur_level_ <= WARNING) fprintf(stderr, "[W] "); | ||
| 131 | + break; | ||
| 132 | + case ERROR: | ||
| 133 | + if (cur_level_ <= ERROR) fprintf(stderr, "[E] "); | ||
| 134 | + break; | ||
| 135 | + case FATAL: | ||
| 136 | + if (cur_level_ <= FATAL) fprintf(stderr, "[F] "); | ||
| 137 | + break; | ||
| 138 | + } | ||
| 139 | + | ||
| 140 | + if (cur_level_ <= level_) { | ||
| 141 | + fprintf(stderr, "%s:%u:%s ", filename, line_num, func_name); | ||
| 142 | + } | ||
| 143 | + } | ||
| 144 | + | ||
| 145 | + ~Logger() noexcept(false) { | ||
| 146 | + static constexpr const char *kErrMsg = R"( | ||
| 147 | + Some bad things happened. Please read the above error messages and stack | ||
| 148 | + trace. If you are using Python, the following command may be helpful: | ||
| 149 | + | ||
| 150 | + gdb --args python /path/to/your/code.py | ||
| 151 | + | ||
| 152 | + (You can use `gdb` to debug the code. Please consider compiling | ||
| 153 | + a debug version of sherpa_onnx.). | ||
| 154 | + | ||
| 155 | + If you are unable to fix it, please open an issue at: | ||
| 156 | + | ||
| 157 | + https://github.com/csukuangfj/kaldi-native-fbank/issues/new | ||
| 158 | + )"; | ||
| 159 | + if (level_ == FATAL) { | ||
| 160 | + fprintf(stderr, "\n"); | ||
| 161 | + std::string stack_trace = GetStackTrace(); | ||
| 162 | + if (!stack_trace.empty()) { | ||
| 163 | + fprintf(stderr, "\n\n%s\n", stack_trace.c_str()); | ||
| 164 | + } | ||
| 165 | + | ||
| 166 | + fflush(nullptr); | ||
| 167 | + | ||
| 168 | +#ifndef __ANDROID_API__ | ||
| 169 | + if (EnableAbort()) { | ||
| 170 | + // NOTE: abort() will terminate the program immediately without | ||
| 171 | + // printing the Python stack backtrace. | ||
| 172 | + abort(); | ||
| 173 | + } | ||
| 174 | + | ||
| 175 | + throw std::runtime_error(kErrMsg); | ||
| 176 | +#else | ||
| 177 | + abort(); | ||
| 178 | +#endif | ||
| 179 | + } | ||
| 180 | + } | ||
| 181 | + | ||
| 182 | + const Logger &operator<<(bool b) const { | ||
| 183 | + if (cur_level_ <= level_) { | ||
| 184 | + fprintf(stderr, b ? "true" : "false"); | ||
| 185 | + } | ||
| 186 | + return *this; | ||
| 187 | + } | ||
| 188 | + | ||
| 189 | + const Logger &operator<<(int8_t i) const { | ||
| 190 | + if (cur_level_ <= level_) fprintf(stderr, "%d", i); | ||
| 191 | + return *this; | ||
| 192 | + } | ||
| 193 | + | ||
| 194 | + const Logger &operator<<(const char *s) const { | ||
| 195 | + if (cur_level_ <= level_) fprintf(stderr, "%s", s); | ||
| 196 | + return *this; | ||
| 197 | + } | ||
| 198 | + | ||
| 199 | + const Logger &operator<<(int32_t i) const { | ||
| 200 | + if (cur_level_ <= level_) fprintf(stderr, "%d", i); | ||
| 201 | + return *this; | ||
| 202 | + } | ||
| 203 | + | ||
| 204 | + const Logger &operator<<(uint32_t i) const { | ||
| 205 | + if (cur_level_ <= level_) fprintf(stderr, "%u", i); | ||
| 206 | + return *this; | ||
| 207 | + } | ||
| 208 | + | ||
| 209 | + const Logger &operator<<(uint64_t i) const { | ||
| 210 | + if (cur_level_ <= level_) | ||
| 211 | + fprintf(stderr, "%llu", (long long unsigned int)i); // NOLINT | ||
| 212 | + return *this; | ||
| 213 | + } | ||
| 214 | + | ||
| 215 | + const Logger &operator<<(int64_t i) const { | ||
| 216 | + if (cur_level_ <= level_) | ||
| 217 | + fprintf(stderr, "%lli", (long long int)i); // NOLINT | ||
| 218 | + return *this; | ||
| 219 | + } | ||
| 220 | + | ||
| 221 | + const Logger &operator<<(float f) const { | ||
| 222 | + if (cur_level_ <= level_) fprintf(stderr, "%f", f); | ||
| 223 | + return *this; | ||
| 224 | + } | ||
| 225 | + | ||
| 226 | + const Logger &operator<<(double d) const { | ||
| 227 | + if (cur_level_ <= level_) fprintf(stderr, "%f", d); | ||
| 228 | + return *this; | ||
| 229 | + } | ||
| 230 | + | ||
| 231 | + template <typename T> | ||
| 232 | + const Logger &operator<<(const T &t) const { | ||
| 233 | + // require T overloads operator<< | ||
| 234 | + std::ostringstream os; | ||
| 235 | + os << t; | ||
| 236 | + return *this << os.str().c_str(); | ||
| 237 | + } | ||
| 238 | + | ||
| 239 | + // specialization to fix compile error: `stringstream << nullptr` is ambiguous | ||
| 240 | + const Logger &operator<<(const std::nullptr_t &null) const { | ||
| 241 | + if (cur_level_ <= level_) *this << "(null)"; | ||
| 242 | + return *this; | ||
| 243 | + } | ||
| 244 | + | ||
| 245 | + private: | ||
| 246 | + const char *filename_; | ||
| 247 | + const char *func_name_; | ||
| 248 | + uint32_t line_num_; | ||
| 249 | + LogLevel level_; | ||
| 250 | + LogLevel cur_level_; | ||
| 251 | +}; | ||
| 252 | +#endif // SHERPA_ONNX_ENABLE_CHECK | ||
| 253 | + | ||
| 254 | +class Voidifier { | ||
| 255 | + public: | ||
| 256 | +#if SHERPA_ONNX_ENABLE_CHECK | ||
| 257 | + void operator&(const Logger &) const {} | ||
| 258 | +#endif | ||
| 259 | +}; | ||
| 260 | +#if !defined(SHERPA_ONNX_ENABLE_CHECK) | ||
| 261 | +template <typename T> | ||
| 262 | +const Voidifier &operator<<(const Voidifier &v, T &&) { | ||
| 263 | + return v; | ||
| 264 | +} | ||
| 265 | +#endif | ||
| 266 | + | ||
| 267 | +} // namespace sherpa_onnx | ||
| 268 | + | ||
| 269 | +#define SHERPA_ONNX_STATIC_ASSERT(x) static_assert(x, "") | ||
| 270 | + | ||
| 271 | +#ifdef SHERPA_ONNX_ENABLE_CHECK | ||
| 272 | + | ||
| 273 | +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) || \ | ||
| 274 | + defined(__PRETTY_FUNCTION__) | ||
| 275 | +// for clang and GCC | ||
| 276 | +#define SHERPA_ONNX_FUNC __PRETTY_FUNCTION__ | ||
| 277 | +#else | ||
| 278 | +// for other compilers | ||
| 279 | +#define SHERPA_ONNX_FUNC __func__ | ||
| 280 | +#endif | ||
| 281 | + | ||
| 282 | +#define SHERPA_ONNX_CHECK(x) \ | ||
| 283 | + (x) ? (void)0 \ | ||
| 284 | + : ::sherpa_onnx::Voidifier() & \ | ||
| 285 | + ::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, \ | ||
| 286 | + ::sherpa_onnx::FATAL) \ | ||
| 287 | + << "Check failed: " << #x << " " | ||
| 288 | + | ||
| 289 | +// WARNING: x and y may be evaluated multiple times, but this happens only | ||
| 290 | +// when the check fails. Since the program aborts if it fails, we don't think | ||
| 291 | +// the extra evaluation of x and y matters. | ||
| 292 | +// | ||
| 293 | +// CAUTION: we recommend the following use case: | ||
| 294 | +// | ||
| 295 | +// auto x = Foo(); | ||
| 296 | +// auto y = Bar(); | ||
| 297 | +// SHERPA_ONNX_CHECK_EQ(x, y) << "Some message"; | ||
| 298 | +// | ||
| 299 | +// And please avoid | ||
| 300 | +// | ||
| 301 | +// SHERPA_ONNX_CHECK_EQ(Foo(), Bar()); | ||
| 302 | +// | ||
| 303 | +// if `Foo()` or `Bar()` causes some side effects, e.g., changing some | ||
| 304 | +// local static variables or global variables. | ||
| 305 | +#define _SHERPA_ONNX_CHECK_OP(x, y, op) \ | ||
| 306 | + ((x)op(y)) ? (void)0 \ | ||
| 307 | + : ::sherpa_onnx::Voidifier() & \ | ||
| 308 | + ::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, \ | ||
| 309 | + ::sherpa_onnx::FATAL) \ | ||
| 310 | + << "Check failed: " << #x << " " << #op << " " << #y \ | ||
| 311 | + << " (" << (x) << " vs. " << (y) << ") " | ||
| 312 | + | ||
| 313 | +#define SHERPA_ONNX_CHECK_EQ(x, y) _SHERPA_ONNX_CHECK_OP(x, y, ==) | ||
| 314 | +#define SHERPA_ONNX_CHECK_NE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, !=) | ||
| 315 | +#define SHERPA_ONNX_CHECK_LT(x, y) _SHERPA_ONNX_CHECK_OP(x, y, <) | ||
| 316 | +#define SHERPA_ONNX_CHECK_LE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, <=) | ||
| 317 | +#define SHERPA_ONNX_CHECK_GT(x, y) _SHERPA_ONNX_CHECK_OP(x, y, >) | ||
| 318 | +#define SHERPA_ONNX_CHECK_GE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, >=) | ||
| 319 | + | ||
| 320 | +#define SHERPA_ONNX_LOG(x) \ | ||
| 321 | + ::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, ::sherpa_onnx::x) | ||
| 322 | + | ||
| 323 | +// ------------------------------------------------------------ | ||
| 324 | +// For debug check | ||
| 325 | +// ------------------------------------------------------------ | ||
| 326 | +// If you define the macro "-D NDEBUG" while compiling kaldi-native-fbank, | ||
| 327 | +// the following macros are in fact empty and does nothing. | ||
| 328 | + | ||
| 329 | +#define SHERPA_ONNX_DCHECK(x) \ | ||
| 330 | + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK(x) | ||
| 331 | + | ||
| 332 | +#define SHERPA_ONNX_DCHECK_EQ(x, y) \ | ||
| 333 | + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_EQ(x, y) | ||
| 334 | + | ||
| 335 | +#define SHERPA_ONNX_DCHECK_NE(x, y) \ | ||
| 336 | + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_NE(x, y) | ||
| 337 | + | ||
| 338 | +#define SHERPA_ONNX_DCHECK_LT(x, y) \ | ||
| 339 | + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_LT(x, y) | ||
| 340 | + | ||
| 341 | +#define SHERPA_ONNX_DCHECK_LE(x, y) \ | ||
| 342 | + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_LE(x, y) | ||
| 343 | + | ||
| 344 | +#define SHERPA_ONNX_DCHECK_GT(x, y) \ | ||
| 345 | + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_GT(x, y) | ||
| 346 | + | ||
| 347 | +#define SHERPA_ONNX_DCHECK_GE(x, y) \ | ||
| 348 | + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_GE(x, y) | ||
| 349 | + | ||
| 350 | +#define SHERPA_ONNX_DLOG(x) \ | ||
| 351 | + ::sherpa_onnx::kDisableDebug \ | ||
| 352 | + ? (void)0 \ | ||
| 353 | + : ::sherpa_onnx::Voidifier() & SHERPA_ONNX_LOG(x) | ||
| 354 | + | ||
| 355 | +#else | ||
| 356 | + | ||
| 357 | +#define SHERPA_ONNX_CHECK(x) ::sherpa_onnx::Voidifier() | ||
| 358 | +#define SHERPA_ONNX_LOG(x) ::sherpa_onnx::Voidifier() | ||
| 359 | + | ||
| 360 | +#define SHERPA_ONNX_CHECK_EQ(x, y) ::sherpa_onnx::Voidifier() | ||
| 361 | +#define SHERPA_ONNX_CHECK_NE(x, y) ::sherpa_onnx::Voidifier() | ||
| 362 | +#define SHERPA_ONNX_CHECK_LT(x, y) ::sherpa_onnx::Voidifier() | ||
| 363 | +#define SHERPA_ONNX_CHECK_LE(x, y) ::sherpa_onnx::Voidifier() | ||
| 364 | +#define SHERPA_ONNX_CHECK_GT(x, y) ::sherpa_onnx::Voidifier() | ||
| 365 | +#define SHERPA_ONNX_CHECK_GE(x, y) ::sherpa_onnx::Voidifier() | ||
| 366 | + | ||
| 367 | +#define SHERPA_ONNX_DCHECK(x) ::sherpa_onnx::Voidifier() | ||
| 368 | +#define SHERPA_ONNX_DLOG(x) ::sherpa_onnx::Voidifier() | ||
| 369 | +#define SHERPA_ONNX_DCHECK_EQ(x, y) ::sherpa_onnx::Voidifier() | ||
| 370 | +#define SHERPA_ONNX_DCHECK_NE(x, y) ::sherpa_onnx::Voidifier() | ||
| 371 | +#define SHERPA_ONNX_DCHECK_LT(x, y) ::sherpa_onnx::Voidifier() | ||
| 372 | +#define SHERPA_ONNX_DCHECK_LE(x, y) ::sherpa_onnx::Voidifier() | ||
| 373 | +#define SHERPA_ONNX_DCHECK_GT(x, y) ::sherpa_onnx::Voidifier() | ||
| 374 | +#define SHERPA_ONNX_DCHECK_GE(x, y) ::sherpa_onnx::Voidifier() | ||
| 375 | + | ||
| 376 | +#endif // SHERPA_ONNX_CHECK_NE | ||
| 377 | + | ||
| 378 | +#endif // SHERPA_ONNX_CSRC_LOG_H_ |
| @@ -37,7 +37,9 @@ std::string OnlineRecognizerConfig::ToString() const { | @@ -37,7 +37,9 @@ std::string OnlineRecognizerConfig::ToString() const { | ||
| 37 | os << "OnlineRecognizerConfig("; | 37 | os << "OnlineRecognizerConfig("; |
| 38 | os << "feat_config=" << feat_config.ToString() << ", "; | 38 | os << "feat_config=" << feat_config.ToString() << ", "; |
| 39 | os << "model_config=" << model_config.ToString() << ", "; | 39 | os << "model_config=" << model_config.ToString() << ", "; |
| 40 | - os << "tokens=\"" << tokens << "\")"; | 40 | + os << "tokens=\"" << tokens << "\", "; |
| 41 | + os << "endpoint_config=" << endpoint_config.ToString() << ", "; | ||
| 42 | + os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")"; | ||
| 41 | 43 | ||
| 42 | return os.str(); | 44 | return os.str(); |
| 43 | } | 45 | } |
| @@ -47,7 +49,8 @@ class OnlineRecognizer::Impl { | @@ -47,7 +49,8 @@ class OnlineRecognizer::Impl { | ||
| 47 | explicit Impl(const OnlineRecognizerConfig &config) | 49 | explicit Impl(const OnlineRecognizerConfig &config) |
| 48 | : config_(config), | 50 | : config_(config), |
| 49 | model_(OnlineTransducerModel::Create(config.model_config)), | 51 | model_(OnlineTransducerModel::Create(config.model_config)), |
| 50 | - sym_(config.tokens) { | 52 | + sym_(config.tokens), |
| 53 | + endpoint_(config_.endpoint_config) { | ||
| 51 | decoder_ = | 54 | decoder_ = |
| 52 | std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); | 55 | std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); |
| 53 | } | 56 | } |
| @@ -64,7 +67,7 @@ class OnlineRecognizer::Impl { | @@ -64,7 +67,7 @@ class OnlineRecognizer::Impl { | ||
| 64 | s->NumFramesReady(); | 67 | s->NumFramesReady(); |
| 65 | } | 68 | } |
| 66 | 69 | ||
| 67 | - void DecodeStreams(OnlineStream **ss, int32_t n) { | 70 | + void DecodeStreams(OnlineStream **ss, int32_t n) const { |
| 68 | int32_t chunk_size = model_->ChunkSize(); | 71 | int32_t chunk_size = model_->ChunkSize(); |
| 69 | int32_t chunk_shift = model_->ChunkShift(); | 72 | int32_t chunk_shift = model_->ChunkShift(); |
| 70 | 73 | ||
| @@ -111,18 +114,44 @@ class OnlineRecognizer::Impl { | @@ -111,18 +114,44 @@ class OnlineRecognizer::Impl { | ||
| 111 | } | 114 | } |
| 112 | } | 115 | } |
| 113 | 116 | ||
| 114 | - OnlineRecognizerResult GetResult(OnlineStream *s) { | 117 | + OnlineRecognizerResult GetResult(OnlineStream *s) const { |
| 115 | OnlineTransducerDecoderResult decoder_result = s->GetResult(); | 118 | OnlineTransducerDecoderResult decoder_result = s->GetResult(); |
| 116 | decoder_->StripLeadingBlanks(&decoder_result); | 119 | decoder_->StripLeadingBlanks(&decoder_result); |
| 117 | 120 | ||
| 118 | return Convert(decoder_result, sym_); | 121 | return Convert(decoder_result, sym_); |
| 119 | } | 122 | } |
| 120 | 123 | ||
| 124 | + bool IsEndpoint(OnlineStream *s) const { | ||
| 125 | + if (!config_.enable_endpoint) return false; | ||
| 126 | + int32_t num_processed_frames = s->GetNumProcessedFrames(); | ||
| 127 | + | ||
| 128 | + // frame shift is 10 milliseconds | ||
| 129 | + float frame_shift_in_seconds = 0.01; | ||
| 130 | + | ||
| 131 | + // subsampling factor is 4 | ||
| 132 | + int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 4; | ||
| 133 | + | ||
| 134 | + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, | ||
| 135 | + frame_shift_in_seconds); | ||
| 136 | + } | ||
| 137 | + | ||
| 138 | + void Reset(OnlineStream *s) const { | ||
| 139 | + // reset result and neural network model state, | ||
| 140 | + // but keep the feature extractor state | ||
| 141 | + | ||
| 142 | + // reset result | ||
| 143 | + s->SetResult(decoder_->GetEmptyResult()); | ||
| 144 | + | ||
| 145 | + // reset neural network model state | ||
| 146 | + s->SetStates(model_->GetEncoderInitStates()); | ||
| 147 | + } | ||
| 148 | + | ||
| 121 | private: | 149 | private: |
| 122 | OnlineRecognizerConfig config_; | 150 | OnlineRecognizerConfig config_; |
| 123 | std::unique_ptr<OnlineTransducerModel> model_; | 151 | std::unique_ptr<OnlineTransducerModel> model_; |
| 124 | std::unique_ptr<OnlineTransducerDecoder> decoder_; | 152 | std::unique_ptr<OnlineTransducerDecoder> decoder_; |
| 125 | SymbolTable sym_; | 153 | SymbolTable sym_; |
| 154 | + Endpoint endpoint_; | ||
| 126 | }; | 155 | }; |
| 127 | 156 | ||
| 128 | OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) | 157 | OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) |
| @@ -137,12 +166,18 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const { | @@ -137,12 +166,18 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const { | ||
| 137 | return impl_->IsReady(s); | 166 | return impl_->IsReady(s); |
| 138 | } | 167 | } |
| 139 | 168 | ||
| 140 | -void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) { | 169 | +void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const { |
| 141 | impl_->DecodeStreams(ss, n); | 170 | impl_->DecodeStreams(ss, n); |
| 142 | } | 171 | } |
| 143 | 172 | ||
| 144 | -OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) { | 173 | +OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) const { |
| 145 | return impl_->GetResult(s); | 174 | return impl_->GetResult(s); |
| 146 | } | 175 | } |
| 147 | 176 | ||
| 177 | +bool OnlineRecognizer::IsEndpoint(OnlineStream *s) const { | ||
| 178 | + return impl_->IsEndpoint(s); | ||
| 179 | +} | ||
| 180 | + | ||
| 181 | +void OnlineRecognizer::Reset(OnlineStream *s) const { impl_->Reset(s); } | ||
| 182 | + | ||
| 148 | } // namespace sherpa_onnx | 183 | } // namespace sherpa_onnx |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | #include <string> | 9 | #include <string> |
| 10 | 10 | ||
| 11 | +#include "sherpa-onnx/csrc/endpoint.h" | ||
| 11 | #include "sherpa-onnx/csrc/features.h" | 12 | #include "sherpa-onnx/csrc/features.h" |
| 12 | #include "sherpa-onnx/csrc/online-stream.h" | 13 | #include "sherpa-onnx/csrc/online-stream.h" |
| 13 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 14 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| @@ -22,13 +23,21 @@ struct OnlineRecognizerConfig { | @@ -22,13 +23,21 @@ struct OnlineRecognizerConfig { | ||
| 22 | FeatureExtractorConfig feat_config; | 23 | FeatureExtractorConfig feat_config; |
| 23 | OnlineTransducerModelConfig model_config; | 24 | OnlineTransducerModelConfig model_config; |
| 24 | std::string tokens; | 25 | std::string tokens; |
| 26 | + EndpointConfig endpoint_config; | ||
| 27 | + bool enable_endpoint; | ||
| 25 | 28 | ||
| 26 | OnlineRecognizerConfig() = default; | 29 | OnlineRecognizerConfig() = default; |
| 27 | 30 | ||
| 28 | OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, | 31 | OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, |
| 29 | const OnlineTransducerModelConfig &model_config, | 32 | const OnlineTransducerModelConfig &model_config, |
| 30 | - const std::string &tokens) | ||
| 31 | - : feat_config(feat_config), model_config(model_config), tokens(tokens) {} | 33 | + const std::string &tokens, |
| 34 | + const EndpointConfig &endpoint_config, | ||
| 35 | + bool enable_endpoint) | ||
| 36 | + : feat_config(feat_config), | ||
| 37 | + model_config(model_config), | ||
| 38 | + tokens(tokens), | ||
| 39 | + endpoint_config(endpoint_config), | ||
| 40 | + enable_endpoint(enable_endpoint) {} | ||
| 32 | 41 | ||
| 33 | std::string ToString() const; | 42 | std::string ToString() const; |
| 34 | }; | 43 | }; |
| @@ -48,7 +57,7 @@ class OnlineRecognizer { | @@ -48,7 +57,7 @@ class OnlineRecognizer { | ||
| 48 | bool IsReady(OnlineStream *s) const; | 57 | bool IsReady(OnlineStream *s) const; |
| 49 | 58 | ||
| 50 | /** Decode a single stream. */ | 59 | /** Decode a single stream. */ |
| 51 | - void DecodeStream(OnlineStream *s) { | 60 | + void DecodeStream(OnlineStream *s) const { |
| 52 | OnlineStream *ss[1] = {s}; | 61 | OnlineStream *ss[1] = {s}; |
| 53 | DecodeStreams(ss, 1); | 62 | DecodeStreams(ss, 1); |
| 54 | } | 63 | } |
| @@ -58,9 +67,18 @@ class OnlineRecognizer { | @@ -58,9 +67,18 @@ class OnlineRecognizer { | ||
| 58 | * @param ss Pointer array containing streams to be decoded. | 67 | * @param ss Pointer array containing streams to be decoded. |
| 59 | * @param n Number of streams in `ss`. | 68 | * @param n Number of streams in `ss`. |
| 60 | */ | 69 | */ |
| 61 | - void DecodeStreams(OnlineStream **ss, int32_t n); | 70 | + void DecodeStreams(OnlineStream **ss, int32_t n) const; |
| 62 | 71 | ||
| 63 | - OnlineRecognizerResult GetResult(OnlineStream *s); | 72 | + OnlineRecognizerResult GetResult(OnlineStream *s) const; |
| 73 | + | ||
| 74 | + // Return true if we detect an endpoint for this stream. | ||
| 75 | + // Note: If this function returns true, you usually want to | ||
| 76 | + // invoke Reset(s). | ||
| 77 | + bool IsEndpoint(OnlineStream *s) const; | ||
| 78 | + | ||
| 79 | + // Clear the state of this stream. If IsEndpoint(s) returns true, | ||
| 80 | + // after calling this function, IsEndpoint(s) will return false | ||
| 81 | + void Reset(OnlineStream *s) const; | ||
| 64 | 82 | ||
| 65 | private: | 83 | private: |
| 66 | class Impl; | 84 | class Impl; |
| @@ -55,7 +55,8 @@ class OnlineStream { | @@ -55,7 +55,8 @@ class OnlineStream { | ||
| 55 | 55 | ||
| 56 | int32_t FeatureDim() const; | 56 | int32_t FeatureDim() const; |
| 57 | 57 | ||
| 58 | - // Return a reference to the number of processed frames so far. | 58 | + // Return a reference to the number of processed frames so far |
| 59 | + // before subsampling.. | ||
| 59 | // Initially, it is 0. It is always less than NumFramesReady(). | 60 | // Initially, it is 0. It is always less than NumFramesReady(). |
| 60 | // | 61 | // |
| 61 | // The returned reference is valid as long as this object is alive. | 62 | // The returned reference is valid as long as this object is alive. |
| @@ -14,6 +14,9 @@ namespace sherpa_onnx { | @@ -14,6 +14,9 @@ namespace sherpa_onnx { | ||
| 14 | struct OnlineTransducerDecoderResult { | 14 | struct OnlineTransducerDecoderResult { |
| 15 | /// The decoded token IDs so far | 15 | /// The decoded token IDs so far |
| 16 | std::vector<int64_t> tokens; | 16 | std::vector<int64_t> tokens; |
| 17 | + | ||
| 18 | + /// number of trailing blank frames decoded so far | ||
| 19 | + int32_t num_trailing_blanks = 0; | ||
| 17 | }; | 20 | }; |
| 18 | 21 | ||
| 19 | class OnlineTransducerDecoder { | 22 | class OnlineTransducerDecoder { |
| @@ -113,6 +113,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -113,6 +113,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 113 | if (y != 0) { | 113 | if (y != 0) { |
| 114 | emitted = true; | 114 | emitted = true; |
| 115 | (*result)[i].tokens.push_back(y); | 115 | (*result)[i].tokens.push_back(y); |
| 116 | + (*result)[i].num_trailing_blanks = 0; | ||
| 117 | + } else { | ||
| 118 | + ++(*result)[i].num_trailing_blanks; | ||
| 116 | } | 119 | } |
| 117 | } | 120 | } |
| 118 | if (emitted) { | 121 | if (emitted) { |
sherpa-onnx/csrc/parse-options.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/parse-options.cc | ||
| 2 | +/** | ||
| 3 | + * Copyright 2009-2011 Karel Vesely; Microsoft Corporation; | ||
| 4 | + * Saarland University (Author: Arnab Ghoshal); | ||
| 5 | + * Copyright 2012-2013 Johns Hopkins University (Author: Daniel Povey); | ||
| 6 | + * Frantisek Skala; Arnab Ghoshal | ||
| 7 | + * Copyright 2013 Tanel Alumae | ||
| 8 | + */ | ||
| 9 | + | ||
| 10 | +// This file is copied and modified from kaldi/src/util/parse-options.cu | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 13 | + | ||
| 14 | +#include <ctype.h> | ||
| 15 | + | ||
| 16 | +#include <algorithm> | ||
| 17 | +#include <cctype> | ||
| 18 | +#include <cstring> | ||
| 19 | +#include <fstream> | ||
| 20 | +#include <iomanip> | ||
| 21 | +#include <limits> | ||
| 22 | +#include <type_traits> | ||
| 23 | +#include <unordered_map> | ||
| 24 | + | ||
| 25 | +#include "sherpa-onnx/csrc/log.h" | ||
| 26 | + | ||
| 27 | +#ifdef _MSC_VER | ||
| 28 | +#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \ | ||
| 29 | + _strtoi64(cur_cstr, end_cstr, 10); | ||
| 30 | +#else | ||
| 31 | +#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10); | ||
| 32 | +#endif | ||
| 33 | + | ||
| 34 | +namespace sherpa_onnx { | ||
| 35 | + | ||
| 36 | +/// Converts a string into an integer via strtoll and returns false if there was | ||
| 37 | +/// any kind of problem (i.e. the string was not an integer or contained extra | ||
| 38 | +/// non-whitespace junk, or the integer was too large to fit into the type it is | ||
| 39 | +/// being converted into). Only sets *out if everything was OK and it returns | ||
| 40 | +/// true. | ||
| 41 | +template <class Int> | ||
| 42 | +bool ConvertStringToInteger(const std::string &str, Int *out) { | ||
| 43 | + // copied from kaldi/src/util/text-util.h | ||
| 44 | + static_assert(std::is_integral<Int>::value, ""); | ||
| 45 | + const char *this_str = str.c_str(); | ||
| 46 | + char *end = nullptr; | ||
| 47 | + errno = 0; | ||
| 48 | + int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end); | ||
| 49 | + if (end != this_str) { | ||
| 50 | + while (isspace(*end)) ++end; | ||
| 51 | + } | ||
| 52 | + if (end == this_str || *end != '\0' || errno != 0) return false; | ||
| 53 | + Int iInt = static_cast<Int>(i); | ||
| 54 | + if (static_cast<int64_t>(iInt) != i || | ||
| 55 | + (i < 0 && !std::numeric_limits<Int>::is_signed)) { | ||
| 56 | + return false; | ||
| 57 | + } | ||
| 58 | + *out = iInt; | ||
| 59 | + return true; | ||
| 60 | +} | ||
| 61 | + | ||
| 62 | +// copied from kaldi/src/util/text-util.cc | ||
| 63 | +template <class T> | ||
| 64 | +class NumberIstream { | ||
| 65 | + public: | ||
| 66 | + explicit NumberIstream(std::istream &i) : in_(i) {} | ||
| 67 | + | ||
| 68 | + NumberIstream &operator>>(T &x) { | ||
| 69 | + if (!in_.good()) return *this; | ||
| 70 | + in_ >> x; | ||
| 71 | + if (!in_.fail() && RemainderIsOnlySpaces()) return *this; | ||
| 72 | + return ParseOnFail(&x); | ||
| 73 | + } | ||
| 74 | + | ||
| 75 | + private: | ||
| 76 | + std::istream &in_; | ||
| 77 | + | ||
| 78 | + bool RemainderIsOnlySpaces() { | ||
| 79 | + if (in_.tellg() != std::istream::pos_type(-1)) { | ||
| 80 | + std::string rem; | ||
| 81 | + in_ >> rem; | ||
| 82 | + | ||
| 83 | + if (rem.find_first_not_of(' ') != std::string::npos) { | ||
| 84 | + // there is not only spaces | ||
| 85 | + return false; | ||
| 86 | + } | ||
| 87 | + } | ||
| 88 | + | ||
| 89 | + in_.clear(); | ||
| 90 | + return true; | ||
| 91 | + } | ||
| 92 | + | ||
| 93 | + NumberIstream &ParseOnFail(T *x) { | ||
| 94 | + std::string str; | ||
| 95 | + in_.clear(); | ||
| 96 | + in_.seekg(0); | ||
| 97 | + // If the stream is broken even before trying | ||
| 98 | + // to read from it or if there are many tokens, | ||
| 99 | + // it's pointless to try. | ||
| 100 | + if (!(in_ >> str) || !RemainderIsOnlySpaces()) { | ||
| 101 | + in_.setstate(std::ios_base::failbit); | ||
| 102 | + return *this; | ||
| 103 | + } | ||
| 104 | + | ||
| 105 | + std::unordered_map<std::string, T> inf_nan_map; | ||
| 106 | + // we'll keep just uppercase values. | ||
| 107 | + inf_nan_map["INF"] = std::numeric_limits<T>::infinity(); | ||
| 108 | + inf_nan_map["+INF"] = std::numeric_limits<T>::infinity(); | ||
| 109 | + inf_nan_map["-INF"] = -std::numeric_limits<T>::infinity(); | ||
| 110 | + inf_nan_map["INFINITY"] = std::numeric_limits<T>::infinity(); | ||
| 111 | + inf_nan_map["+INFINITY"] = std::numeric_limits<T>::infinity(); | ||
| 112 | + inf_nan_map["-INFINITY"] = -std::numeric_limits<T>::infinity(); | ||
| 113 | + inf_nan_map["NAN"] = std::numeric_limits<T>::quiet_NaN(); | ||
| 114 | + inf_nan_map["+NAN"] = std::numeric_limits<T>::quiet_NaN(); | ||
| 115 | + inf_nan_map["-NAN"] = -std::numeric_limits<T>::quiet_NaN(); | ||
| 116 | + // MSVC | ||
| 117 | + inf_nan_map["1.#INF"] = std::numeric_limits<T>::infinity(); | ||
| 118 | + inf_nan_map["-1.#INF"] = -std::numeric_limits<T>::infinity(); | ||
| 119 | + inf_nan_map["1.#QNAN"] = std::numeric_limits<T>::quiet_NaN(); | ||
| 120 | + inf_nan_map["-1.#QNAN"] = -std::numeric_limits<T>::quiet_NaN(); | ||
| 121 | + | ||
| 122 | + std::transform(str.begin(), str.end(), str.begin(), ::toupper); | ||
| 123 | + | ||
| 124 | + if (inf_nan_map.find(str) != inf_nan_map.end()) { | ||
| 125 | + *x = inf_nan_map[str]; | ||
| 126 | + } else { | ||
| 127 | + in_.setstate(std::ios_base::failbit); | ||
| 128 | + } | ||
| 129 | + | ||
| 130 | + return *this; | ||
| 131 | + } | ||
| 132 | +}; | ||
| 133 | + | ||
| 134 | +/// ConvertStringToReal converts a string into either float or double | ||
| 135 | +/// and returns false if there was any kind of problem (i.e. the string | ||
| 136 | +/// was not a floating point number or contained extra non-whitespace junk). | ||
| 137 | +/// Be careful- this function will successfully read inf's or nan's. | ||
| 138 | +template <typename T> | ||
| 139 | +bool ConvertStringToReal(const std::string &str, T *out) { | ||
| 140 | + std::istringstream iss(str); | ||
| 141 | + | ||
| 142 | + NumberIstream<T> i(iss); | ||
| 143 | + | ||
| 144 | + i >> *out; | ||
| 145 | + | ||
| 146 | + if (iss.fail()) { | ||
| 147 | + // Number conversion failed. | ||
| 148 | + return false; | ||
| 149 | + } | ||
| 150 | + | ||
| 151 | + return true; | ||
| 152 | +} | ||
| 153 | + | ||
| 154 | +ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po) | ||
| 155 | + : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) { | ||
| 156 | + if (po != nullptr && po->other_parser_ != nullptr) { | ||
| 157 | + // we get here if this constructor is used twice, recursively. | ||
| 158 | + other_parser_ = po->other_parser_; | ||
| 159 | + } else { | ||
| 160 | + other_parser_ = po; | ||
| 161 | + } | ||
| 162 | + if (po != nullptr && po->prefix_ != "") { | ||
| 163 | + prefix_ = po->prefix_ + std::string(".") + prefix; | ||
| 164 | + } else { | ||
| 165 | + prefix_ = prefix; | ||
| 166 | + } | ||
| 167 | +} | ||
| 168 | + | ||
| 169 | +void ParseOptions::Register(const std::string &name, bool *ptr, | ||
| 170 | + const std::string &doc) { | ||
| 171 | + RegisterTmpl(name, ptr, doc); | ||
| 172 | +} | ||
| 173 | + | ||
| 174 | +void ParseOptions::Register(const std::string &name, int32_t *ptr, | ||
| 175 | + const std::string &doc) { | ||
| 176 | + RegisterTmpl(name, ptr, doc); | ||
| 177 | +} | ||
| 178 | + | ||
| 179 | +void ParseOptions::Register(const std::string &name, uint32_t *ptr, | ||
| 180 | + const std::string &doc) { | ||
| 181 | + RegisterTmpl(name, ptr, doc); | ||
| 182 | +} | ||
| 183 | + | ||
| 184 | +void ParseOptions::Register(const std::string &name, float *ptr, | ||
| 185 | + const std::string &doc) { | ||
| 186 | + RegisterTmpl(name, ptr, doc); | ||
| 187 | +} | ||
| 188 | + | ||
| 189 | +void ParseOptions::Register(const std::string &name, double *ptr, | ||
| 190 | + const std::string &doc) { | ||
| 191 | + RegisterTmpl(name, ptr, doc); | ||
| 192 | +} | ||
| 193 | + | ||
| 194 | +void ParseOptions::Register(const std::string &name, std::string *ptr, | ||
| 195 | + const std::string &doc) { | ||
| 196 | + RegisterTmpl(name, ptr, doc); | ||
| 197 | +} | ||
| 198 | + | ||
| 199 | +// old-style, used for registering application-specific parameters | ||
| 200 | +template <typename T> | ||
| 201 | +void ParseOptions::RegisterTmpl(const std::string &name, T *ptr, | ||
| 202 | + const std::string &doc) { | ||
| 203 | + if (other_parser_ == nullptr) { | ||
| 204 | + this->RegisterCommon(name, ptr, doc, false); | ||
| 205 | + } else { | ||
| 206 | + SHERPA_ONNX_CHECK(prefix_ != "") | ||
| 207 | + << "prefix: " << prefix_ << "\n" | ||
| 208 | + << "Cannot use empty prefix when registering with prefix."; | ||
| 209 | + std::string new_name = prefix_ + '.' + name; // name becomes prefix.name | ||
| 210 | + other_parser_->Register(new_name, ptr, doc); | ||
| 211 | + } | ||
| 212 | +} | ||
| 213 | + | ||
| 214 | +// does the common part of the job of registering a parameter | ||
| 215 | +template <typename T> | ||
| 216 | +void ParseOptions::RegisterCommon(const std::string &name, T *ptr, | ||
| 217 | + const std::string &doc, bool is_standard) { | ||
| 218 | + SHERPA_ONNX_CHECK(ptr != nullptr); | ||
| 219 | + std::string idx = name; | ||
| 220 | + NormalizeArgName(&idx); | ||
| 221 | + if (doc_map_.find(idx) != doc_map_.end()) { | ||
| 222 | + SHERPA_ONNX_LOG(WARNING) | ||
| 223 | + << "Registering option twice, ignoring second time: " << name; | ||
| 224 | + } else { | ||
| 225 | + this->RegisterSpecific(name, idx, ptr, doc, is_standard); | ||
| 226 | + } | ||
| 227 | +} | ||
| 228 | + | ||
| 229 | +// used to register standard parameters (those that are present in all of the | ||
| 230 | +// applications) | ||
| 231 | +template <typename T> | ||
| 232 | +void ParseOptions::RegisterStandard(const std::string &name, T *ptr, | ||
| 233 | + const std::string &doc) { | ||
| 234 | + this->RegisterCommon(name, ptr, doc, true); | ||
| 235 | +} | ||
| 236 | + | ||
| 237 | +void ParseOptions::RegisterSpecific(const std::string &name, | ||
| 238 | + const std::string &idx, bool *b, | ||
| 239 | + const std::string &doc, bool is_standard) { | ||
| 240 | + bool_map_[idx] = b; | ||
| 241 | + doc_map_[idx] = | ||
| 242 | + DocInfo(name, doc + " (bool, default = " + ((*b) ? "true)" : "false)"), | ||
| 243 | + is_standard); | ||
| 244 | +} | ||
| 245 | + | ||
| 246 | +void ParseOptions::RegisterSpecific(const std::string &name, | ||
| 247 | + const std::string &idx, int32_t *i, | ||
| 248 | + const std::string &doc, bool is_standard) { | ||
| 249 | + int_map_[idx] = i; | ||
| 250 | + std::ostringstream ss; | ||
| 251 | + ss << doc << " (int, default = " << *i << ")"; | ||
| 252 | + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); | ||
| 253 | +} | ||
| 254 | + | ||
| 255 | +void ParseOptions::RegisterSpecific(const std::string &name, | ||
| 256 | + const std::string &idx, uint32_t *u, | ||
| 257 | + const std::string &doc, bool is_standard) { | ||
| 258 | + uint_map_[idx] = u; | ||
| 259 | + std::ostringstream ss; | ||
| 260 | + ss << doc << " (uint, default = " << *u << ")"; | ||
| 261 | + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); | ||
| 262 | +} | ||
| 263 | + | ||
| 264 | +void ParseOptions::RegisterSpecific(const std::string &name, | ||
| 265 | + const std::string &idx, float *f, | ||
| 266 | + const std::string &doc, bool is_standard) { | ||
| 267 | + float_map_[idx] = f; | ||
| 268 | + std::ostringstream ss; | ||
| 269 | + ss << doc << " (float, default = " << *f << ")"; | ||
| 270 | + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); | ||
| 271 | +} | ||
| 272 | + | ||
| 273 | +void ParseOptions::RegisterSpecific(const std::string &name, | ||
| 274 | + const std::string &idx, double *f, | ||
| 275 | + const std::string &doc, bool is_standard) { | ||
| 276 | + double_map_[idx] = f; | ||
| 277 | + std::ostringstream ss; | ||
| 278 | + ss << doc << " (double, default = " << *f << ")"; | ||
| 279 | + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); | ||
| 280 | +} | ||
| 281 | + | ||
| 282 | +void ParseOptions::RegisterSpecific(const std::string &name, | ||
| 283 | + const std::string &idx, std::string *s, | ||
| 284 | + const std::string &doc, bool is_standard) { | ||
| 285 | + string_map_[idx] = s; | ||
| 286 | + doc_map_[idx] = | ||
| 287 | + DocInfo(name, doc + " (string, default = \"" + *s + "\")", is_standard); | ||
| 288 | +} | ||
| 289 | + | ||
| 290 | +void ParseOptions::DisableOption(const std::string &name) { | ||
| 291 | + if (argv_ != nullptr) { | ||
| 292 | + SHERPA_ONNX_LOG(FATAL) | ||
| 293 | + << "DisableOption must not be called after calling Read()."; | ||
| 294 | + } | ||
| 295 | + if (doc_map_.erase(name) == 0) { | ||
| 296 | + SHERPA_ONNX_LOG(FATAL) << "Option " << name | ||
| 297 | + << " was not registered so cannot be disabled: "; | ||
| 298 | + } | ||
| 299 | + bool_map_.erase(name); | ||
| 300 | + int_map_.erase(name); | ||
| 301 | + uint_map_.erase(name); | ||
| 302 | + float_map_.erase(name); | ||
| 303 | + double_map_.erase(name); | ||
| 304 | + string_map_.erase(name); | ||
| 305 | +} | ||
| 306 | + | ||
| 307 | +int ParseOptions::NumArgs() const { return positional_args_.size(); } | ||
| 308 | + | ||
| 309 | +std::string ParseOptions::GetArg(int i) const { | ||
| 310 | + if (i < 1 || i > static_cast<int>(positional_args_.size())) { | ||
| 311 | + SHERPA_ONNX_LOG(FATAL) << "ParseOptions::GetArg, invalid index " << i; | ||
| 312 | + } | ||
| 313 | + | ||
| 314 | + return positional_args_[i - 1]; | ||
| 315 | +} | ||
| 316 | + | ||
| 317 | +// We currently do not support any other options. | ||
| 318 | +enum ShellType { kBash = 0 }; | ||
| 319 | + | ||
| 320 | +// This can be changed in the code if it ever does need to be changed (as it's | ||
| 321 | +// unlikely that one compilation of this tool-set would use both shells). | ||
| 322 | +static ShellType kShellType = kBash; | ||
| 323 | + | ||
| 324 | +// Returns true if we need to escape a string before putting it into | ||
| 325 | +// a shell (mainly thinking of bash shell, but should work for others) | ||
| 326 | +// This is for the convenience of the user so command-lines that are | ||
| 327 | +// printed out by ParseOptions::Read (with --print-args=true) are | ||
| 328 | +// paste-able into the shell and will run. If you use a different type of | ||
| 329 | +// shell, it might be necessary to change this function. | ||
| 330 | +// But it's mostly a cosmetic issue as it basically affects how | ||
| 331 | +// the program echoes its command-line arguments to the screen. | ||
| 332 | +static bool MustBeQuoted(const std::string &str, ShellType st) { | ||
| 333 | + // Only Bash is supported (for the moment). | ||
| 334 | + SHERPA_ONNX_CHECK_EQ(st, kBash) << "Invalid shell type."; | ||
| 335 | + | ||
| 336 | + const char *c = str.c_str(); | ||
| 337 | + if (*c == '\0') { | ||
| 338 | + return true; // Must quote empty string | ||
| 339 | + } else { | ||
| 340 | + const char *ok_chars[2]; | ||
| 341 | + | ||
| 342 | + // These seem not to be interpreted as long as there are no other "bad" | ||
| 343 | + // characters involved (e.g. "," would be interpreted as part of something | ||
| 344 | + // like a{b,c}, but not on its own. | ||
| 345 | + ok_chars[kBash] = "[]~#^_-+=:.,/"; | ||
| 346 | + | ||
| 347 | + // Just want to make sure that a space character doesn't get automatically | ||
| 348 | + // inserted here via an automated style-checking script, like it did before. | ||
| 349 | + SHERPA_ONNX_CHECK(!strchr(ok_chars[kBash], ' ')); | ||
| 350 | + | ||
| 351 | + for (; *c != '\0'; ++c) { | ||
| 352 | + // For non-alphanumeric characters we have a list of characters which | ||
| 353 | + // are OK. All others are forbidden (this is easier since the shell | ||
| 354 | + // interprets most non-alphanumeric characters). | ||
| 355 | + if (!isalnum(*c)) { | ||
| 356 | + const char *d; | ||
| 357 | + for (d = ok_chars[st]; *d != '\0'; ++d) { | ||
| 358 | + if (*c == *d) break; | ||
| 359 | + } | ||
| 360 | + // If not alphanumeric or one of the "ok_chars", it must be escaped. | ||
| 361 | + if (*d == '\0') return true; | ||
| 362 | + } | ||
| 363 | + } | ||
| 364 | + return false; // The string was OK. No quoting or escaping. | ||
| 365 | + } | ||
| 366 | +} | ||
| 367 | + | ||
| 368 | +// Returns a quoted and escaped version of "str" | ||
| 369 | +// which has previously been determined to need escaping. | ||
| 370 | +// Our aim is to print out the command line in such a way that if it's | ||
| 371 | +// pasted into a shell of ShellType "st" (only bash for now), it | ||
| 372 | +// will get passed to the program in the same way. | ||
| 373 | +static std::string QuoteAndEscape(const std::string &str, ShellType st) { | ||
| 374 | + // Only Bash is supported (for the moment). | ||
| 375 | + SHERPA_ONNX_CHECK_EQ(st, kBash) << "Invalid shell type."; | ||
| 376 | + | ||
| 377 | + // For now we use the following rules: | ||
| 378 | + // In the normal case, we quote with single-quote "'", and to escape | ||
| 379 | + // a single-quote we use the string: '\'' (interpreted as closing the | ||
| 380 | + // single-quote, putting an escaped single-quote from the shell, and | ||
| 381 | + // then reopening the single quote). | ||
| 382 | + char quote_char = '\''; | ||
| 383 | + const char *escape_str = "'\\''"; // e.g. echo 'a'\''b' returns a'b | ||
| 384 | + | ||
| 385 | + // If the string contains single-quotes that would need escaping this | ||
| 386 | + // way, and we determine that the string could be safely double-quoted | ||
| 387 | + // without requiring any escaping, then we double-quote the string. | ||
| 388 | + // This is the case if the characters "`$\ do not appear in the string. | ||
| 389 | + // e.g. see http://www.redhat.com/mirrors/LDP/LDP/abs/html/quotingvar.html | ||
| 390 | + const char *c_str = str.c_str(); | ||
| 391 | + if (strchr(c_str, '\'') && !strpbrk(c_str, "\"`$\\")) { | ||
| 392 | + quote_char = '"'; | ||
| 393 | + escape_str = "\\\""; // should never be accessed. | ||
| 394 | + } | ||
| 395 | + | ||
| 396 | + char buf[2]; | ||
| 397 | + buf[1] = '\0'; | ||
| 398 | + | ||
| 399 | + buf[0] = quote_char; | ||
| 400 | + std::string ans = buf; | ||
| 401 | + const char *c = str.c_str(); | ||
| 402 | + for (; *c != '\0'; ++c) { | ||
| 403 | + if (*c == quote_char) { | ||
| 404 | + ans += escape_str; | ||
| 405 | + } else { | ||
| 406 | + buf[0] = *c; | ||
| 407 | + ans += buf; | ||
| 408 | + } | ||
| 409 | + } | ||
| 410 | + buf[0] = quote_char; | ||
| 411 | + ans += buf; | ||
| 412 | + return ans; | ||
| 413 | +} | ||
| 414 | + | ||
| 415 | +// static function | ||
| 416 | +std::string ParseOptions::Escape(const std::string &str) { | ||
| 417 | + return MustBeQuoted(str, kShellType) ? QuoteAndEscape(str, kShellType) : str; | ||
| 418 | +} | ||
| 419 | + | ||
| 420 | +int ParseOptions::Read(int argc, const char *const argv[]) { | ||
| 421 | + argc_ = argc; | ||
| 422 | + argv_ = argv; | ||
| 423 | + std::string key, value; | ||
| 424 | + int i; | ||
| 425 | + | ||
| 426 | + // first pass: look for config parameter, look for priority | ||
| 427 | + for (i = 1; i < argc; ++i) { | ||
| 428 | + if (std::strncmp(argv[i], "--", 2) == 0) { | ||
| 429 | + if (std::strcmp(argv[i], "--") == 0) { | ||
| 430 | + // a lone "--" marks the end of named options | ||
| 431 | + break; | ||
| 432 | + } | ||
| 433 | + bool has_equal_sign; | ||
| 434 | + SplitLongArg(argv[i], &key, &value, &has_equal_sign); | ||
| 435 | + NormalizeArgName(&key); | ||
| 436 | + Trim(&value); | ||
| 437 | + if (key.compare("config") == 0) { | ||
| 438 | + ReadConfigFile(value); | ||
| 439 | + } else if (key.compare("help") == 0) { | ||
| 440 | + PrintUsage(); | ||
| 441 | + exit(0); | ||
| 442 | + } | ||
| 443 | + } | ||
| 444 | + } | ||
| 445 | + | ||
| 446 | + bool double_dash_seen = false; | ||
| 447 | + // second pass: add the command line options | ||
| 448 | + for (i = 1; i < argc; ++i) { | ||
| 449 | + if (std::strncmp(argv[i], "--", 2) == 0) { | ||
| 450 | + if (std::strcmp(argv[i], "--") == 0) { | ||
| 451 | + // A lone "--" marks the end of named options. | ||
| 452 | + // Skip that option and break the processing of named options | ||
| 453 | + i += 1; | ||
| 454 | + double_dash_seen = true; | ||
| 455 | + break; | ||
| 456 | + } | ||
| 457 | + bool has_equal_sign; | ||
| 458 | + SplitLongArg(argv[i], &key, &value, &has_equal_sign); | ||
| 459 | + NormalizeArgName(&key); | ||
| 460 | + Trim(&value); | ||
| 461 | + if (!SetOption(key, value, has_equal_sign)) { | ||
| 462 | + PrintUsage(true); | ||
| 463 | + SHERPA_ONNX_LOG(FATAL) << "Invalid option " << argv[i]; | ||
| 464 | + } | ||
| 465 | + } else { | ||
| 466 | + break; | ||
| 467 | + } | ||
| 468 | + } | ||
| 469 | + | ||
| 470 | + // process remaining arguments as positional | ||
| 471 | + for (; i < argc; ++i) { | ||
| 472 | + if ((std::strcmp(argv[i], "--") == 0) && !double_dash_seen) { | ||
| 473 | + double_dash_seen = true; | ||
| 474 | + } else { | ||
| 475 | + positional_args_.push_back(std::string(argv[i])); | ||
| 476 | + } | ||
| 477 | + } | ||
| 478 | + | ||
| 479 | + // if the user did not suppress this with --print-args = false.... | ||
| 480 | + if (print_args_) { | ||
| 481 | + std::ostringstream strm; | ||
| 482 | + for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " "; | ||
| 483 | + strm << '\n'; | ||
| 484 | + SHERPA_ONNX_LOG(INFO) << strm.str(); | ||
| 485 | + } | ||
| 486 | + return i; | ||
| 487 | +} | ||
| 488 | + | ||
| 489 | +void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const { | ||
| 490 | + std::ostringstream os; | ||
| 491 | + os << '\n' << usage_ << '\n'; | ||
| 492 | + // first we print application-specific options | ||
| 493 | + bool app_specific_header_printed = false; | ||
| 494 | + for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) { | ||
| 495 | + if (it->second.is_standard_ == false) { // application-specific option | ||
| 496 | + if (app_specific_header_printed == false) { // header was not yet printed | ||
| 497 | + os << "Options:" << '\n'; | ||
| 498 | + app_specific_header_printed = true; | ||
| 499 | + } | ||
| 500 | + os << " --" << std::setw(25) << std::left << it->second.name_ << " : " | ||
| 501 | + << it->second.use_msg_ << '\n'; | ||
| 502 | + } | ||
| 503 | + } | ||
| 504 | + if (app_specific_header_printed == true) { | ||
| 505 | + os << '\n'; | ||
| 506 | + } | ||
| 507 | + | ||
| 508 | + // then the standard options | ||
| 509 | + os << "Standard options:" << '\n'; | ||
| 510 | + for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) { | ||
| 511 | + if (it->second.is_standard_ == true) { // we have standard option | ||
| 512 | + os << " --" << std::setw(25) << std::left << it->second.name_ << " : " | ||
| 513 | + << it->second.use_msg_ << '\n'; | ||
| 514 | + } | ||
| 515 | + } | ||
| 516 | + os << '\n'; | ||
| 517 | + if (print_command_line) { | ||
| 518 | + std::ostringstream strm; | ||
| 519 | + strm << "Command line was: "; | ||
| 520 | + for (int j = 0; j < argc_; ++j) strm << Escape(argv_[j]) << " "; | ||
| 521 | + strm << '\n'; | ||
| 522 | + os << strm.str(); | ||
| 523 | + } | ||
| 524 | + | ||
| 525 | + SHERPA_ONNX_LOG(INFO) << os.str(); | ||
| 526 | +} | ||
| 527 | + | ||
| 528 | +void ParseOptions::PrintConfig(std::ostream &os) const { | ||
| 529 | + os << '\n' << "[[ Configuration of UI-Registered options ]]" << '\n'; | ||
| 530 | + std::string key; | ||
| 531 | + for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) { | ||
| 532 | + key = it->first; | ||
| 533 | + os << it->second.name_ << " = "; | ||
| 534 | + if (bool_map_.end() != bool_map_.find(key)) { | ||
| 535 | + os << (*bool_map_.at(key) ? "true" : "false"); | ||
| 536 | + } else if (int_map_.end() != int_map_.find(key)) { | ||
| 537 | + os << (*int_map_.at(key)); | ||
| 538 | + } else if (uint_map_.end() != uint_map_.find(key)) { | ||
| 539 | + os << (*uint_map_.at(key)); | ||
| 540 | + } else if (float_map_.end() != float_map_.find(key)) { | ||
| 541 | + os << (*float_map_.at(key)); | ||
| 542 | + } else if (double_map_.end() != double_map_.find(key)) { | ||
| 543 | + os << (*double_map_.at(key)); | ||
| 544 | + } else if (string_map_.end() != string_map_.find(key)) { | ||
| 545 | + os << "'" << *string_map_.at(key) << "'"; | ||
| 546 | + } else { | ||
| 547 | + SHERPA_ONNX_LOG(FATAL) | ||
| 548 | + << "PrintConfig: unrecognized option " << key << "[code error]"; | ||
| 549 | + } | ||
| 550 | + os << '\n'; | ||
| 551 | + } | ||
| 552 | + os << '\n'; | ||
| 553 | +} | ||
| 554 | + | ||
| 555 | +void ParseOptions::ReadConfigFile(const std::string &filename) { | ||
| 556 | + std::ifstream is(filename.c_str(), std::ifstream::in); | ||
| 557 | + if (!is.good()) { | ||
| 558 | + SHERPA_ONNX_LOG(FATAL) << "Cannot open config file: " << filename; | ||
| 559 | + } | ||
| 560 | + | ||
| 561 | + std::string line, key, value; | ||
| 562 | + int32_t line_number = 0; | ||
| 563 | + while (std::getline(is, line)) { | ||
| 564 | + ++line_number; | ||
| 565 | + // trim out the comments | ||
| 566 | + size_t pos; | ||
| 567 | + if ((pos = line.find_first_of('#')) != std::string::npos) { | ||
| 568 | + line.erase(pos); | ||
| 569 | + } | ||
| 570 | + // skip empty lines | ||
| 571 | + Trim(&line); | ||
| 572 | + if (line.length() == 0) continue; | ||
| 573 | + | ||
| 574 | + if (line.substr(0, 2) != "--") { | ||
| 575 | + SHERPA_ONNX_LOG(FATAL) | ||
| 576 | + << "Reading config file " << filename << ": line " << line_number | ||
| 577 | + << " does not look like a line " | ||
| 578 | + << "from a Kaldi command-line program's config file: should " | ||
| 579 | + << "be of the form --x=y. Note: config files intended to " | ||
| 580 | + << "be sourced by shell scripts lack the '--'."; | ||
| 581 | + } | ||
| 582 | + | ||
| 583 | + // parse option | ||
| 584 | + bool has_equal_sign; | ||
| 585 | + SplitLongArg(line, &key, &value, &has_equal_sign); | ||
| 586 | + NormalizeArgName(&key); | ||
| 587 | + Trim(&value); | ||
| 588 | + if (!SetOption(key, value, has_equal_sign)) { | ||
| 589 | + PrintUsage(true); | ||
| 590 | + SHERPA_ONNX_LOG(FATAL) << "Invalid option " << line << " in config file " | ||
| 591 | + << filename << ": line " << line_number; | ||
| 592 | + } | ||
| 593 | + } | ||
| 594 | +} | ||
| 595 | + | ||
| 596 | +void ParseOptions::SplitLongArg(const std::string &in, std::string *key, | ||
| 597 | + std::string *value, | ||
| 598 | + bool *has_equal_sign) const { | ||
| 599 | + SHERPA_ONNX_CHECK(in.substr(0, 2) == "--") << in; // precondition. | ||
| 600 | + size_t pos = in.find_first_of('=', 0); | ||
| 601 | + if (pos == std::string::npos) { // we allow --option for bools | ||
| 602 | + // defaults to empty. We handle this differently in different cases. | ||
| 603 | + *key = in.substr(2, in.size() - 2); // 2 because starts with --. | ||
| 604 | + *value = ""; | ||
| 605 | + *has_equal_sign = false; | ||
| 606 | + } else if (pos == 2) { // we also don't allow empty keys: --=value | ||
| 607 | + PrintUsage(true); | ||
| 608 | + SHERPA_ONNX_LOG(FATAL) << "Invalid option (no key): " << in; | ||
| 609 | + } else { // normal case: --option=value | ||
| 610 | + *key = in.substr(2, pos - 2); // 2 because starts with --. | ||
| 611 | + *value = in.substr(pos + 1); | ||
| 612 | + *has_equal_sign = true; | ||
| 613 | + } | ||
| 614 | +} | ||
| 615 | + | ||
| 616 | +void ParseOptions::NormalizeArgName(std::string *str) const { | ||
| 617 | + std::string out; | ||
| 618 | + std::string::iterator it; | ||
| 619 | + | ||
| 620 | + for (it = str->begin(); it != str->end(); ++it) { | ||
| 621 | + if (*it == '_') { | ||
| 622 | + out += '-'; // convert _ to - | ||
| 623 | + } else { | ||
| 624 | + out += std::tolower(*it); | ||
| 625 | + } | ||
| 626 | + } | ||
| 627 | + *str = out; | ||
| 628 | + | ||
| 629 | + SHERPA_ONNX_CHECK_GT(str->length(), 0); | ||
| 630 | +} | ||
| 631 | + | ||
| 632 | +void ParseOptions::Trim(std::string *str) const { | ||
| 633 | + const char *white_chars = " \t\n\r\f\v"; | ||
| 634 | + | ||
| 635 | + std::string::size_type pos = str->find_last_not_of(white_chars); | ||
| 636 | + if (pos != std::string::npos) { | ||
| 637 | + str->erase(pos + 1); | ||
| 638 | + pos = str->find_first_not_of(white_chars); | ||
| 639 | + if (pos != std::string::npos) str->erase(0, pos); | ||
| 640 | + } else { | ||
| 641 | + str->erase(str->begin(), str->end()); | ||
| 642 | + } | ||
| 643 | +} | ||
| 644 | + | ||
| 645 | +bool ParseOptions::SetOption(const std::string &key, const std::string &value, | ||
| 646 | + bool has_equal_sign) { | ||
| 647 | + if (bool_map_.end() != bool_map_.find(key)) { | ||
| 648 | + if (has_equal_sign && value == "") { | ||
| 649 | + SHERPA_ONNX_LOG(FATAL) << "Invalid option --" << key << "="; | ||
| 650 | + } | ||
| 651 | + *(bool_map_[key]) = ToBool(value); | ||
| 652 | + } else if (int_map_.end() != int_map_.find(key)) { | ||
| 653 | + *(int_map_[key]) = ToInt(value); | ||
| 654 | + } else if (uint_map_.end() != uint_map_.find(key)) { | ||
| 655 | + *(uint_map_[key]) = ToUint(value); | ||
| 656 | + } else if (float_map_.end() != float_map_.find(key)) { | ||
| 657 | + *(float_map_[key]) = ToFloat(value); | ||
| 658 | + } else if (double_map_.end() != double_map_.find(key)) { | ||
| 659 | + *(double_map_[key]) = ToDouble(value); | ||
| 660 | + } else if (string_map_.end() != string_map_.find(key)) { | ||
| 661 | + if (!has_equal_sign) { | ||
| 662 | + SHERPA_ONNX_LOG(FATAL) | ||
| 663 | + << "Invalid option --" << key << " (option format is --x=y)."; | ||
| 664 | + } | ||
| 665 | + *(string_map_[key]) = value; | ||
| 666 | + } else { | ||
| 667 | + return false; | ||
| 668 | + } | ||
| 669 | + return true; | ||
| 670 | +} | ||
| 671 | + | ||
| 672 | +bool ParseOptions::ToBool(std::string str) const { | ||
| 673 | + std::transform(str.begin(), str.end(), str.begin(), ::tolower); | ||
| 674 | + | ||
| 675 | + // allow "" as a valid option for "true", so that --x is the same as --x=true | ||
| 676 | + if ((str.compare("true") == 0) || (str.compare("t") == 0) || | ||
| 677 | + (str.compare("1") == 0) || (str.compare("") == 0)) { | ||
| 678 | + return true; | ||
| 679 | + } | ||
| 680 | + if ((str.compare("false") == 0) || (str.compare("f") == 0) || | ||
| 681 | + (str.compare("0") == 0)) { | ||
| 682 | + return false; | ||
| 683 | + } | ||
| 684 | + // if it is neither true nor false: | ||
| 685 | + PrintUsage(true); | ||
| 686 | + SHERPA_ONNX_LOG(FATAL) | ||
| 687 | + << "Invalid format for boolean argument [expected true or false]: " | ||
| 688 | + << str; | ||
| 689 | + return false; // never reached | ||
| 690 | +} | ||
| 691 | + | ||
| 692 | +int32_t ParseOptions::ToInt(const std::string &str) const { | ||
| 693 | + int32_t ret = 0; | ||
| 694 | + if (!ConvertStringToInteger(str, &ret)) | ||
| 695 | + SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\""; | ||
| 696 | + return ret; | ||
| 697 | +} | ||
| 698 | + | ||
| 699 | +uint32_t ParseOptions::ToUint(const std::string &str) const { | ||
| 700 | + uint32_t ret = 0; | ||
| 701 | + if (!ConvertStringToInteger(str, &ret)) | ||
| 702 | + SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\""; | ||
| 703 | + return ret; | ||
| 704 | +} | ||
| 705 | + | ||
| 706 | +float ParseOptions::ToFloat(const std::string &str) const { | ||
| 707 | + float ret; | ||
| 708 | + if (!ConvertStringToReal(str, &ret)) | ||
| 709 | + SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\""; | ||
| 710 | + return ret; | ||
| 711 | +} | ||
| 712 | + | ||
| 713 | +double ParseOptions::ToDouble(const std::string &str) const { | ||
| 714 | + double ret; | ||
| 715 | + if (!ConvertStringToReal(str, &ret)) | ||
| 716 | + SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\""; | ||
| 717 | + return ret; | ||
| 718 | +} | ||
| 719 | + | ||
| 720 | +// instantiate templates | ||
| 721 | +template void ParseOptions::RegisterTmpl(const std::string &name, bool *ptr, | ||
| 722 | + const std::string &doc); | ||
| 723 | +template void ParseOptions::RegisterTmpl(const std::string &name, int32_t *ptr, | ||
| 724 | + const std::string &doc); | ||
| 725 | +template void ParseOptions::RegisterTmpl(const std::string &name, uint32_t *ptr, | ||
| 726 | + const std::string &doc); | ||
| 727 | +template void ParseOptions::RegisterTmpl(const std::string &name, float *ptr, | ||
| 728 | + const std::string &doc); | ||
| 729 | +template void ParseOptions::RegisterTmpl(const std::string &name, double *ptr, | ||
| 730 | + const std::string &doc); | ||
| 731 | +template void ParseOptions::RegisterTmpl(const std::string &name, | ||
| 732 | + std::string *ptr, | ||
| 733 | + const std::string &doc); | ||
| 734 | + | ||
| 735 | +template void ParseOptions::RegisterStandard(const std::string &name, bool *ptr, | ||
| 736 | + const std::string &doc); | ||
| 737 | +template void ParseOptions::RegisterStandard(const std::string &name, | ||
| 738 | + int32_t *ptr, | ||
| 739 | + const std::string &doc); | ||
| 740 | +template void ParseOptions::RegisterStandard(const std::string &name, | ||
| 741 | + uint32_t *ptr, | ||
| 742 | + const std::string &doc); | ||
| 743 | +template void ParseOptions::RegisterStandard(const std::string &name, | ||
| 744 | + float *ptr, | ||
| 745 | + const std::string &doc); | ||
| 746 | +template void ParseOptions::RegisterStandard(const std::string &name, | ||
| 747 | + double *ptr, | ||
| 748 | + const std::string &doc); | ||
| 749 | +template void ParseOptions::RegisterStandard(const std::string &name, | ||
| 750 | + std::string *ptr, | ||
| 751 | + const std::string &doc); | ||
| 752 | + | ||
| 753 | +template void ParseOptions::RegisterCommon(const std::string &name, bool *ptr, | ||
| 754 | + const std::string &doc, | ||
| 755 | + bool is_standard); | ||
| 756 | +template void ParseOptions::RegisterCommon(const std::string &name, | ||
| 757 | + int32_t *ptr, const std::string &doc, | ||
| 758 | + bool is_standard); | ||
| 759 | +template void ParseOptions::RegisterCommon(const std::string &name, | ||
| 760 | + uint32_t *ptr, | ||
| 761 | + const std::string &doc, | ||
| 762 | + bool is_standard); | ||
| 763 | +template void ParseOptions::RegisterCommon(const std::string &name, float *ptr, | ||
| 764 | + const std::string &doc, | ||
| 765 | + bool is_standard); | ||
| 766 | +template void ParseOptions::RegisterCommon(const std::string &name, double *ptr, | ||
| 767 | + const std::string &doc, | ||
| 768 | + bool is_standard); | ||
| 769 | +template void ParseOptions::RegisterCommon(const std::string &name, | ||
| 770 | + std::string *ptr, | ||
| 771 | + const std::string &doc, | ||
| 772 | + bool is_standard); | ||
| 773 | + | ||
| 774 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/parse-options.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/parse-options.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | +// | ||
| 5 | +// This file is copied and modified from kaldi/src/util/parse-options.h | ||
| 6 | + | ||
| 7 | +#ifndef SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_ | ||
| 8 | +#define SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_ | ||
| 9 | + | ||
| 10 | +#include <sstream> | ||
| 11 | +#include <string> | ||
| 12 | +#include <unordered_map> | ||
| 13 | +#include <vector> | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +class ParseOptions { | ||
| 18 | + public: | ||
| 19 | + explicit ParseOptions(const char *usage) | ||
| 20 | + : print_args_(true), | ||
| 21 | + help_(false), | ||
| 22 | + usage_(usage), | ||
| 23 | + argc_(0), | ||
| 24 | + argv_(nullptr), | ||
| 25 | + prefix_(""), | ||
| 26 | + other_parser_(nullptr) { | ||
| 27 | +#if !defined(_MSC_VER) && !defined(__CYGWIN__) | ||
| 28 | + // This is just a convenient place to set the stderr to line | ||
| 29 | + // buffering mode, since it's called at program start. | ||
| 30 | + // This helps ensure different programs' output is not mixed up. | ||
| 31 | + setlinebuf(stderr); | ||
| 32 | +#endif | ||
| 33 | + RegisterStandard("config", &config_, | ||
| 34 | + "Configuration file to read (this " | ||
| 35 | + "option may be repeated)"); | ||
| 36 | + RegisterStandard("print-args", &print_args_, | ||
| 37 | + "Print the command line arguments (to stderr)"); | ||
| 38 | + RegisterStandard("help", &help_, "Print out usage message"); | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + /** | ||
| 42 | + This is a constructor for the special case where some options are | ||
| 43 | + registered with a prefix to avoid conflicts. The object thus created will | ||
| 44 | + only be used temporarily to register an options class with the original | ||
| 45 | + options parser (which is passed as the *other pointer) using the given | ||
| 46 | + prefix. It should not be used for any other purpose, and the prefix must | ||
| 47 | + not be the empty string. It seems to be the least bad way of implementing | ||
| 48 | + options with prefixes at this point. | ||
| 49 | + Example of usage is: | ||
| 50 | + ParseOptions po; // original ParseOptions object | ||
| 51 | + ParseOptions po_mfcc("mfcc", &po); // object with prefix. | ||
| 52 | + MfccOptions mfcc_opts; | ||
| 53 | + mfcc_opts.Register(&po_mfcc); | ||
| 54 | + The options will now get registered as, e.g., --mfcc.frame-shift=10.0 | ||
| 55 | + instead of just --frame-shift=10.0 | ||
| 56 | + */ | ||
| 57 | + ParseOptions(const std::string &prefix, ParseOptions *other); | ||
| 58 | + | ||
| 59 | + ParseOptions(const ParseOptions &) = delete; | ||
| 60 | + ParseOptions &operator=(const ParseOptions &) = delete; | ||
| 61 | + ~ParseOptions() = default; | ||
| 62 | + | ||
| 63 | + void Register(const std::string &name, bool *ptr, const std::string &doc); | ||
| 64 | + void Register(const std::string &name, int32_t *ptr, const std::string &doc); | ||
| 65 | + void Register(const std::string &name, uint32_t *ptr, const std::string &doc); | ||
| 66 | + void Register(const std::string &name, float *ptr, const std::string &doc); | ||
| 67 | + void Register(const std::string &name, double *ptr, const std::string &doc); | ||
| 68 | + void Register(const std::string &name, std::string *ptr, | ||
| 69 | + const std::string &doc); | ||
| 70 | + | ||
| 71 | + /// If called after registering an option and before calling | ||
| 72 | + /// Read(), disables that option from being used. Will crash | ||
| 73 | + /// at runtime if that option had not been registered. | ||
| 74 | + void DisableOption(const std::string &name); | ||
| 75 | + | ||
| 76 | + /// This one is used for registering standard parameters of all the programs | ||
| 77 | + template <typename T> | ||
| 78 | + void RegisterStandard(const std::string &name, T *ptr, | ||
| 79 | + const std::string &doc); | ||
| 80 | + | ||
| 81 | + /** | ||
| 82 | + Parses the command line options and fills the ParseOptions-registered | ||
| 83 | + variables. This must be called after all the variables were registered!!! | ||
| 84 | + | ||
| 85 | + Initially the variables have implicit values, | ||
| 86 | + then the config file values are set-up, | ||
| 87 | + finally the command line values given. | ||
| 88 | + Returns the first position in argv that was not used. | ||
| 89 | + [typically not useful: use NumParams() and GetParam(). ] | ||
| 90 | + */ | ||
| 91 | + int Read(int argc, const char *const *argv); | ||
| 92 | + | ||
| 93 | + /// Prints the usage documentation [provided in the constructor]. | ||
| 94 | + void PrintUsage(bool print_command_line = false) const; | ||
| 95 | + | ||
| 96 | + /// Prints the actual configuration of all the registered variables | ||
| 97 | + void PrintConfig(std::ostream &os) const; | ||
| 98 | + | ||
| 99 | + /// Reads the options values from a config file. Must be called after | ||
| 100 | + /// registering all options. This is usually used internally after the | ||
| 101 | + /// standard --config option is used, but it may also be called from a | ||
| 102 | + /// program. | ||
| 103 | + void ReadConfigFile(const std::string &filename); | ||
| 104 | + | ||
| 105 | + /// Number of positional parameters (c.f. argc-1). | ||
| 106 | + int NumArgs() const; | ||
| 107 | + | ||
| 108 | + /// Returns one of the positional parameters; 1-based indexing for argc/argv | ||
| 109 | + /// compatibility. Will crash if param is not >=1 and <=NumArgs(). | ||
| 110 | + /// | ||
| 111 | + /// Note: Index is 1 based. | ||
| 112 | + std::string GetArg(int param) const; | ||
| 113 | + | ||
| 114 | + std::string GetOptArg(int param) const { | ||
| 115 | + return (param <= NumArgs() ? GetArg(param) : ""); | ||
| 116 | + } | ||
| 117 | + | ||
| 118 | + /// The following function will return a possibly quoted and escaped | ||
| 119 | + /// version of "str", according to the current shell. Currently | ||
| 120 | + /// this is just hardwired to bash. It's useful for debug output. | ||
| 121 | + static std::string Escape(const std::string &str); | ||
| 122 | + | ||
| 123 | + private: | ||
| 124 | + /// Template to register various variable types, | ||
| 125 | + /// used for program-specific parameters | ||
| 126 | + template <typename T> | ||
| 127 | + void RegisterTmpl(const std::string &name, T *ptr, const std::string &doc); | ||
| 128 | + | ||
| 129 | + // Following functions do just the datatype-specific part of the job | ||
| 130 | + /// Register boolean variable | ||
| 131 | + void RegisterSpecific(const std::string &name, const std::string &idx, | ||
| 132 | + bool *b, const std::string &doc, bool is_standard); | ||
| 133 | + /// Register int32_t variable | ||
| 134 | + void RegisterSpecific(const std::string &name, const std::string &idx, | ||
| 135 | + int32_t *i, const std::string &doc, bool is_standard); | ||
| 136 | + /// Register unsigned int32_t variable | ||
| 137 | + void RegisterSpecific(const std::string &name, const std::string &idx, | ||
| 138 | + uint32_t *u, const std::string &doc, bool is_standard); | ||
| 139 | + /// Register float variable | ||
| 140 | + void RegisterSpecific(const std::string &name, const std::string &idx, | ||
| 141 | + float *f, const std::string &doc, bool is_standard); | ||
| 142 | + /// Register double variable [useful as we change BaseFloat type]. | ||
| 143 | + void RegisterSpecific(const std::string &name, const std::string &idx, | ||
| 144 | + double *f, const std::string &doc, bool is_standard); | ||
| 145 | + /// Register string variable | ||
| 146 | + void RegisterSpecific(const std::string &name, const std::string &idx, | ||
| 147 | + std::string *s, const std::string &doc, | ||
| 148 | + bool is_standard); | ||
| 149 | + | ||
| 150 | + /// Does the actual job for both kinds of parameters | ||
| 151 | + /// Does the common part of the job for all datatypes, | ||
| 152 | + /// then calls RegisterSpecific | ||
| 153 | + template <typename T> | ||
| 154 | + void RegisterCommon(const std::string &name, T *ptr, const std::string &doc, | ||
| 155 | + bool is_standard); | ||
| 156 | + | ||
| 157 | + /// Set option with name "key" to "value"; will crash if can't do it. | ||
| 158 | + /// "has_equal_sign" is used to allow --x for a boolean option x, | ||
| 159 | + /// and --y=, for a string option y. | ||
| 160 | + bool SetOption(const std::string &key, const std::string &value, | ||
| 161 | + bool has_equal_sign); | ||
| 162 | + | ||
| 163 | + bool ToBool(std::string str) const; | ||
| 164 | + int32_t ToInt(const std::string &str) const; | ||
| 165 | + uint32_t ToUint(const std::string &str) const; | ||
| 166 | + float ToFloat(const std::string &str) const; | ||
| 167 | + double ToDouble(const std::string &str) const; | ||
| 168 | + | ||
| 169 | + // maps for option variables | ||
| 170 | + std::unordered_map<std::string, bool *> bool_map_; | ||
| 171 | + std::unordered_map<std::string, int32_t *> int_map_; | ||
| 172 | + std::unordered_map<std::string, uint32_t *> uint_map_; | ||
| 173 | + std::unordered_map<std::string, float *> float_map_; | ||
| 174 | + std::unordered_map<std::string, double *> double_map_; | ||
| 175 | + std::unordered_map<std::string, std::string *> string_map_; | ||
| 176 | + | ||
| 177 | + /** | ||
| 178 | + Structure for options' documentation | ||
| 179 | + */ | ||
| 180 | + struct DocInfo { | ||
| 181 | + DocInfo() = default; | ||
| 182 | + DocInfo(const std::string &name, const std::string &usemsg) | ||
| 183 | + : name_(name), use_msg_(usemsg), is_standard_(false) {} | ||
| 184 | + DocInfo(const std::string &name, const std::string &usemsg, | ||
| 185 | + bool is_standard) | ||
| 186 | + : name_(name), use_msg_(usemsg), is_standard_(is_standard) {} | ||
| 187 | + | ||
| 188 | + std::string name_; | ||
| 189 | + std::string use_msg_; | ||
| 190 | + bool is_standard_; | ||
| 191 | + }; | ||
| 192 | + using DocMapType = std::unordered_map<std::string, DocInfo>; | ||
| 193 | + DocMapType doc_map_; ///< map for the documentation | ||
| 194 | + | ||
| 195 | + bool print_args_; ///< variable for the implicit --print-args parameter | ||
| 196 | + bool help_; ///< variable for the implicit --help parameter | ||
| 197 | + std::string config_; ///< variable for the implicit --config parameter | ||
| 198 | + std::vector<std::string> positional_args_; | ||
| 199 | + const char *usage_; | ||
| 200 | + int argc_; | ||
| 201 | + const char *const *argv_; | ||
| 202 | + | ||
| 203 | + /// These members are not normally used. They are only used when the object | ||
| 204 | + /// is constructed with a prefix | ||
| 205 | + std::string prefix_; | ||
| 206 | + ParseOptions *other_parser_; | ||
| 207 | + | ||
| 208 | + protected: | ||
| 209 | + /// SplitLongArg parses an argument of the form --a=b, --a=, or --a, | ||
| 210 | + /// and sets "has_equal_sign" to true if an equals-sign was parsed.. | ||
| 211 | + /// this is needed in order to correctly allow --x for a boolean option | ||
| 212 | + /// x, and --y= for a string option y, and to disallow --x= and --y. | ||
| 213 | + void SplitLongArg(const std::string &in, std::string *key, std::string *value, | ||
| 214 | + bool *has_equal_sign) const; | ||
| 215 | + | ||
| 216 | + void NormalizeArgName(std::string *str) const; | ||
| 217 | + | ||
| 218 | + /// Removes the beginning and trailing whitespaces from a string | ||
| 219 | + void Trim(std::string *str) const; | ||
| 220 | +}; | ||
| 221 | + | ||
| 222 | +/// This template is provided for convenience in reading config classes from | ||
| 223 | +/// files; this is not the standard way to read configuration options, but may | ||
| 224 | +/// occasionally be needed. This function assumes the config has a function | ||
| 225 | +/// "void Register(ParseOptions *opts)" which it can call to register the | ||
| 226 | +/// ParseOptions object. | ||
| 227 | +template <class C> | ||
| 228 | +void ReadConfigFromFile(const std::string &config_filename, C *c) { | ||
| 229 | + std::ostringstream usage_str; | ||
| 230 | + usage_str << "Parsing config from " | ||
| 231 | + << "from '" << config_filename << "'"; | ||
| 232 | + ParseOptions po(usage_str.str().c_str()); | ||
| 233 | + c->Register(&po); | ||
| 234 | + po.ReadConfigFile(config_filename); | ||
| 235 | +} | ||
| 236 | + | ||
| 237 | +/// This variant of the template ReadConfigFromFile is for if you need to read | ||
| 238 | +/// two config classes from the same file. | ||
| 239 | +template <class C1, class C2> | ||
| 240 | +void ReadConfigsFromFile(const std::string &conf, C1 *c1, C2 *c2) { | ||
| 241 | + std::ostringstream usage_str; | ||
| 242 | + usage_str << "Parsing config from " | ||
| 243 | + << "from '" << conf << "'"; | ||
| 244 | + ParseOptions po(usage_str.str().c_str()); | ||
| 245 | + c1->Register(&po); | ||
| 246 | + c2->Register(&po); | ||
| 247 | + po.ReadConfigFile(conf); | ||
| 248 | +} | ||
| 249 | + | ||
| 250 | +} // namespace sherpa_onnx | ||
| 251 | + | ||
| 252 | +#endif // SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_ |
sherpa-onnx/csrc/sherpa-onnx-alsa.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/sherpa-onnx-alsa.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | +#include <signal.h> | ||
| 5 | +#include <stdio.h> | ||
| 6 | +#include <stdlib.h> | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <cctype> // std::tolower | ||
| 10 | +#include <cstdint> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/alsa.h" | ||
| 13 | +#include "sherpa-onnx/csrc/display.h" | ||
| 14 | +#include "sherpa-onnx/csrc/online-recognizer.h" | ||
| 15 | + | ||
| 16 | +bool stop = false; | ||
| 17 | + | ||
| 18 | +static void Handler(int sig) { | ||
| 19 | + stop = true; | ||
| 20 | + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n"); | ||
| 21 | +} | ||
| 22 | + | ||
| 23 | +int main(int32_t argc, char *argv[]) { | ||
| 24 | + if (argc < 6 || argc > 7) { | ||
| 25 | + const char *usage = R"usage( | ||
| 26 | +Usage: | ||
| 27 | + ./bin/sherpa-onnx-alsa \ | ||
| 28 | + /path/to/tokens.txt \ | ||
| 29 | + /path/to/encoder.onnx \ | ||
| 30 | + /path/to/decoder.onnx \ | ||
| 31 | + /path/to/joiner.onnx \ | ||
| 32 | + device_name \ | ||
| 33 | + [num_threads] | ||
| 34 | + | ||
| 35 | +Please refer to | ||
| 36 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
| 37 | +for a list of pre-trained models to download. | ||
| 38 | + | ||
| 39 | +The device name specifies which microphone to use in case there are several | ||
| 40 | +on you system. You can use | ||
| 41 | + | ||
| 42 | + arecord -l | ||
| 43 | + | ||
| 44 | +to find all available microphones on your computer. For instance, if it outputs | ||
| 45 | + | ||
| 46 | +**** List of CAPTURE Hardware Devices **** | ||
| 47 | +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] | ||
| 48 | + Subdevices: 1/1 | ||
| 49 | + Subdevice #0: subdevice #0 | ||
| 50 | + | ||
| 51 | +and if you want to select card 3 and the device 0 on that card, please use: | ||
| 52 | + | ||
| 53 | + hw:3,0 | ||
| 54 | + | ||
| 55 | +as the device_name. | ||
| 56 | +)usage"; | ||
| 57 | + | ||
| 58 | + fprintf(stderr, "%s\n", usage); | ||
| 59 | + fprintf(stderr, "argc, %d\n", argc); | ||
| 60 | + | ||
| 61 | + return 0; | ||
| 62 | + } | ||
| 63 | + | ||
| 64 | + signal(SIGINT, Handler); | ||
| 65 | + | ||
| 66 | + sherpa_onnx::OnlineRecognizerConfig config; | ||
| 67 | + | ||
| 68 | + config.tokens = argv[1]; | ||
| 69 | + | ||
| 70 | + config.model_config.debug = false; | ||
| 71 | + config.model_config.encoder_filename = argv[2]; | ||
| 72 | + config.model_config.decoder_filename = argv[3]; | ||
| 73 | + config.model_config.joiner_filename = argv[4]; | ||
| 74 | + | ||
| 75 | + const char *device_name = argv[5]; | ||
| 76 | + | ||
| 77 | + config.model_config.num_threads = 2; | ||
| 78 | + if (argc == 7 && atoi(argv[6]) > 0) { | ||
| 79 | + config.model_config.num_threads = atoi(argv[6]); | ||
| 80 | + } | ||
| 81 | + | ||
| 82 | + config.enable_endpoint = true; | ||
| 83 | + | ||
| 84 | + config.endpoint_config.rule1.min_trailing_silence = 2.4; | ||
| 85 | + config.endpoint_config.rule2.min_trailing_silence = 1.2; | ||
| 86 | + config.endpoint_config.rule3.min_utterance_length = 300; | ||
| 87 | + | ||
| 88 | + fprintf(stderr, "%s\n", config.ToString().c_str()); | ||
| 89 | + | ||
| 90 | + sherpa_onnx::OnlineRecognizer recognizer(config); | ||
| 91 | + | ||
| 92 | + int32_t expected_sample_rate = config.feat_config.sampling_rate; | ||
| 93 | + | ||
| 94 | + sherpa_onnx::Alsa alsa(device_name); | ||
| 95 | + fprintf(stderr, "Use recording device: %s\n", device_name); | ||
| 96 | + | ||
| 97 | + if (alsa.GetExpectedSampleRate() != expected_sample_rate) { | ||
| 98 | + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), | ||
| 99 | + expected_sample_rate); | ||
| 100 | + exit(-1); | ||
| 101 | + } | ||
| 102 | + | ||
| 103 | + int32_t chunk = 0.1 * alsa.GetActualSampleRate(); | ||
| 104 | + | ||
| 105 | + std::string last_text; | ||
| 106 | + | ||
| 107 | + auto stream = recognizer.CreateStream(); | ||
| 108 | + | ||
| 109 | + sherpa_onnx::Display display; | ||
| 110 | + | ||
| 111 | + int32_t segment_index = 0; | ||
| 112 | + while (!stop) { | ||
| 113 | + const std::vector<float> samples = alsa.Read(chunk); | ||
| 114 | + | ||
| 115 | + stream->AcceptWaveform(expected_sample_rate, samples.data(), | ||
| 116 | + samples.size()); | ||
| 117 | + | ||
| 118 | + while (recognizer.IsReady(stream.get())) { | ||
| 119 | + recognizer.DecodeStream(stream.get()); | ||
| 120 | + } | ||
| 121 | + | ||
| 122 | + auto text = recognizer.GetResult(stream.get()).text; | ||
| 123 | + | ||
| 124 | + bool is_endpoint = recognizer.IsEndpoint(stream.get()); | ||
| 125 | + | ||
| 126 | + if (!text.empty() && last_text != text) { | ||
| 127 | + last_text = text; | ||
| 128 | + | ||
| 129 | + std::transform(text.begin(), text.end(), text.begin(), | ||
| 130 | + [](auto c) { return std::tolower(c); }); | ||
| 131 | + | ||
| 132 | + display.Print(segment_index, text); | ||
| 133 | + } | ||
| 134 | + | ||
| 135 | + if (!text.empty() && is_endpoint) { | ||
| 136 | + ++segment_index; | ||
| 137 | + recognizer.Reset(stream.get()); | ||
| 138 | + } | ||
| 139 | + } | ||
| 140 | + | ||
| 141 | + return 0; | ||
| 142 | +} |
| @@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx | @@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 4 | features.cc | 4 | features.cc |
| 5 | online-transducer-model-config.cc | 5 | online-transducer-model-config.cc |
| 6 | sherpa-onnx.cc | 6 | sherpa-onnx.cc |
| 7 | + endpoint.cc | ||
| 7 | online-stream.cc | 8 | online-stream.cc |
| 8 | online-recognizer.cc | 9 | online-recognizer.cc |
| 9 | ) | 10 | ) |
sherpa-onnx/python/csrc/endpoint.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/endpoint.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/endpoint.h" | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/endpoint.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +static constexpr const char *kEndpointRuleInitDoc = R"doc( | ||
| 15 | +Constructor for EndpointRule. | ||
| 16 | + | ||
| 17 | +Args: | ||
| 18 | + must_contain_nonsilence: | ||
| 19 | + If True, for this endpointing rule to apply there must be nonsilence in the | ||
| 20 | + best-path traceback. For decoding, a non-blank token is considered as | ||
| 21 | + non-silence. | ||
| 22 | + min_trailing_silence: | ||
| 23 | + This endpointing rule requires duration of trailing silence (in seconds) | ||
| 24 | + to be ``>=`` this value. | ||
| 25 | + min_utterance_length: | ||
| 26 | + This endpointing rule requires utterance-length (in seconds) to | ||
| 27 | + be ``>=`` this value. | ||
| 28 | +)doc"; | ||
| 29 | + | ||
| 30 | +static constexpr const char *kEndpointConfigInitDoc = R"doc( | ||
| 31 | +If any rule in EndpointConfig is activated, it is said that an endpointing | ||
| 32 | +is detected. | ||
| 33 | + | ||
| 34 | +Args: | ||
| 35 | + rule1: | ||
| 36 | + By default, it times out after 2.4 seconds of silence, even if | ||
| 37 | + we decoded nothing. | ||
| 38 | + rule2: | ||
| 39 | + By default, it times out after 1.2 seconds of silence after decoding | ||
| 40 | + something. | ||
| 41 | + rule3: | ||
| 42 | + By default, it times out after the utterance is 20 seconds long, regardless of | ||
| 43 | + anything else. | ||
| 44 | +)doc"; | ||
| 45 | + | ||
| 46 | +static void PybindEndpointRule(py::module *m) { | ||
| 47 | + using PyClass = EndpointRule; | ||
| 48 | + py::class_<PyClass>(*m, "EndpointRule") | ||
| 49 | + .def(py::init<bool, float, float>(), py::arg("must_contain_nonsilence"), | ||
| 50 | + py::arg("min_trailing_silence"), py::arg("min_utterance_length"), | ||
| 51 | + kEndpointRuleInitDoc) | ||
| 52 | + .def("__str__", &PyClass::ToString) | ||
| 53 | + .def_readwrite("must_contain_nonsilence", | ||
| 54 | + &PyClass::must_contain_nonsilence) | ||
| 55 | + .def_readwrite("min_trailing_silence", &PyClass::min_trailing_silence) | ||
| 56 | + .def_readwrite("min_utterance_length", &PyClass::min_utterance_length); | ||
| 57 | +} | ||
| 58 | + | ||
| 59 | +static void PybindEndpointConfig(py::module *m) { | ||
| 60 | + using PyClass = EndpointConfig; | ||
| 61 | + py::class_<PyClass>(*m, "EndpointConfig") | ||
| 62 | + .def( | ||
| 63 | + py::init( | ||
| 64 | + [](float rule1_min_trailing_silence, | ||
| 65 | + float rule2_min_trailing_silence, | ||
| 66 | + float rule3_min_utterance_length) -> std::unique_ptr<PyClass> { | ||
| 67 | + EndpointRule rule1(false, rule1_min_trailing_silence, 0); | ||
| 68 | + EndpointRule rule2(true, rule2_min_trailing_silence, 0); | ||
| 69 | + EndpointRule rule3(false, 0, rule3_min_utterance_length); | ||
| 70 | + | ||
| 71 | + return std::make_unique<EndpointConfig>(rule1, rule2, rule3); | ||
| 72 | + }), | ||
| 73 | + py::arg("rule1_min_trailing_silence"), | ||
| 74 | + py::arg("rule2_min_trailing_silence"), | ||
| 75 | + py::arg("rule3_min_utterance_length")) | ||
| 76 | + .def(py::init([](const EndpointRule &rule1, const EndpointRule &rule2, | ||
| 77 | + const EndpointRule &rule3) -> std::unique_ptr<PyClass> { | ||
| 78 | + auto ans = std::make_unique<PyClass>(); | ||
| 79 | + ans->rule1 = rule1; | ||
| 80 | + ans->rule2 = rule2; | ||
| 81 | + ans->rule3 = rule3; | ||
| 82 | + return ans; | ||
| 83 | + }), | ||
| 84 | + py::arg("rule1") = EndpointRule(false, 2.4, 0), | ||
| 85 | + py::arg("rule2") = EndpointRule(true, 1.2, 0), | ||
| 86 | + py::arg("rule3") = EndpointRule(false, 0, 20), | ||
| 87 | + kEndpointConfigInitDoc) | ||
| 88 | + .def("__str__", | ||
| 89 | + [](const PyClass &self) -> std::string { return self.ToString(); }) | ||
| 90 | + .def_readwrite("rule1", &PyClass::rule1) | ||
| 91 | + .def_readwrite("rule2", &PyClass::rule2) | ||
| 92 | + .def_readwrite("rule3", &PyClass::rule3); | ||
| 93 | +} | ||
| 94 | + | ||
| 95 | +void PybindEndpoint(py::module *m) { | ||
| 96 | + PybindEndpointRule(m); | ||
| 97 | + PybindEndpointConfig(m); | ||
| 98 | +} | ||
| 99 | + | ||
| 100 | +} // namespace sherpa_onnx |
sherpa-onnx/python/csrc/endpoint.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/endpoint.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindEndpoint(py::module *m); | ||
| 13 | + | ||
| 14 | +} // namespace sherpa_onnx | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_ |
| @@ -21,11 +21,15 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -21,11 +21,15 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 21 | using PyClass = OnlineRecognizerConfig; | 21 | using PyClass = OnlineRecognizerConfig; |
| 22 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") | 22 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") |
| 23 | .def(py::init<const FeatureExtractorConfig &, | 23 | .def(py::init<const FeatureExtractorConfig &, |
| 24 | - const OnlineTransducerModelConfig &, const std::string &>(), | ||
| 25 | - py::arg("feat_config"), py::arg("model_config"), py::arg("tokens")) | 24 | + const OnlineTransducerModelConfig &, const std::string &, |
| 25 | + const EndpointConfig &, bool>(), | ||
| 26 | + py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"), | ||
| 27 | + py::arg("endpoint_config"), py::arg("enable_endpoint")) | ||
| 26 | .def_readwrite("feat_config", &PyClass::feat_config) | 28 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 27 | .def_readwrite("model_config", &PyClass::model_config) | 29 | .def_readwrite("model_config", &PyClass::model_config) |
| 28 | .def_readwrite("tokens", &PyClass::tokens) | 30 | .def_readwrite("tokens", &PyClass::tokens) |
| 31 | + .def_readwrite("endpoint_config", &PyClass::endpoint_config) | ||
| 32 | + .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) | ||
| 29 | .def("__str__", &PyClass::ToString); | 33 | .def("__str__", &PyClass::ToString); |
| 30 | } | 34 | } |
| 31 | 35 | ||
| @@ -43,7 +47,9 @@ void PybindOnlineRecognizer(py::module *m) { | @@ -43,7 +47,9 @@ void PybindOnlineRecognizer(py::module *m) { | ||
| 43 | [](PyClass &self, std::vector<OnlineStream *> ss) { | 47 | [](PyClass &self, std::vector<OnlineStream *> ss) { |
| 44 | self.DecodeStreams(ss.data(), ss.size()); | 48 | self.DecodeStreams(ss.data(), ss.size()); |
| 45 | }) | 49 | }) |
| 46 | - .def("get_result", &PyClass::GetResult); | 50 | + .def("get_result", &PyClass::GetResult) |
| 51 | + .def("is_endpoint", &PyClass::IsEndpoint) | ||
| 52 | + .def("reset", &PyClass::Reset); | ||
| 47 | } | 53 | } |
| 48 | 54 | ||
| 49 | } // namespace sherpa_onnx | 55 | } // namespace sherpa_onnx |
| @@ -4,6 +4,7 @@ | @@ -4,6 +4,7 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/python/csrc/sherpa-onnx.h" | 5 | #include "sherpa-onnx/python/csrc/sherpa-onnx.h" |
| 6 | 6 | ||
| 7 | +#include "sherpa-onnx/python/csrc/endpoint.h" | ||
| 7 | #include "sherpa-onnx/python/csrc/features.h" | 8 | #include "sherpa-onnx/python/csrc/features.h" |
| 8 | #include "sherpa-onnx/python/csrc/online-recognizer.h" | 9 | #include "sherpa-onnx/python/csrc/online-recognizer.h" |
| 9 | #include "sherpa-onnx/python/csrc/online-stream.h" | 10 | #include "sherpa-onnx/python/csrc/online-stream.h" |
| @@ -16,6 +17,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | @@ -16,6 +17,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 16 | PybindFeatures(&m); | 17 | PybindFeatures(&m); |
| 17 | PybindOnlineTransducerModelConfig(&m); | 18 | PybindOnlineTransducerModelConfig(&m); |
| 18 | PybindOnlineStream(&m); | 19 | PybindOnlineStream(&m); |
| 20 | + PybindEndpoint(&m); | ||
| 19 | PybindOnlineRecognizer(&m); | 21 | PybindOnlineRecognizer(&m); |
| 20 | } | 22 | } |
| 21 | 23 |
| @@ -2,12 +2,13 @@ from pathlib import Path | @@ -2,12 +2,13 @@ from pathlib import Path | ||
| 2 | from typing import List | 2 | from typing import List |
| 3 | 3 | ||
| 4 | from _sherpa_onnx import ( | 4 | from _sherpa_onnx import ( |
| 5 | - OnlineStream, | ||
| 6 | - OnlineTransducerModelConfig, | 5 | + EndpointConfig, |
| 7 | FeatureExtractorConfig, | 6 | FeatureExtractorConfig, |
| 7 | + OnlineRecognizer as _Recognizer, | ||
| 8 | OnlineRecognizerConfig, | 8 | OnlineRecognizerConfig, |
| 9 | + OnlineStream, | ||
| 10 | + OnlineTransducerModelConfig, | ||
| 9 | ) | 11 | ) |
| 10 | -from _sherpa_onnx import OnlineRecognizer as _Recognizer | ||
| 11 | 12 | ||
| 12 | 13 | ||
| 13 | def _assert_file_exists(f: str): | 14 | def _assert_file_exists(f: str): |
| @@ -26,6 +27,10 @@ class OnlineRecognizer(object): | @@ -26,6 +27,10 @@ class OnlineRecognizer(object): | ||
| 26 | num_threads: int = 4, | 27 | num_threads: int = 4, |
| 27 | sample_rate: float = 16000, | 28 | sample_rate: float = 16000, |
| 28 | feature_dim: int = 80, | 29 | feature_dim: int = 80, |
| 30 | + enable_endpoint_detection: bool = False, | ||
| 31 | + rule1_min_trailing_silence: int = 2.4, | ||
| 32 | + rule2_min_trailing_silence: int = 1.2, | ||
| 33 | + rule3_min_utterance_length: int = 20, | ||
| 29 | ): | 34 | ): |
| 30 | """ | 35 | """ |
| 31 | Please refer to | 36 | Please refer to |
| @@ -52,6 +57,22 @@ class OnlineRecognizer(object): | @@ -52,6 +57,22 @@ class OnlineRecognizer(object): | ||
| 52 | Sample rate of the training data used to train the model. | 57 | Sample rate of the training data used to train the model. |
| 53 | feature_dim: | 58 | feature_dim: |
| 54 | Dimension of the feature used to train the model. | 59 | Dimension of the feature used to train the model. |
| 60 | + enable_endpoint_detection: | ||
| 61 | + True to enable endpoint detection. False to disable endpoint | ||
| 62 | + detection. | ||
| 63 | + rule1_min_trailing_silence: | ||
| 64 | + Used only when enable_endpoint_detection is True. If the duration | ||
| 65 | + of trailing silence in seconds is larger than this value, we assume | ||
| 66 | + an endpoint is detected. | ||
| 67 | + rule2_min_trailing_silence: | ||
| 68 | + Used only when enable_endpoint_detection is True. If we have decoded | ||
| 69 | + something that is nonsilence and if the duration of trailing silence | ||
| 70 | + in seconds is larger than this value, we assume an endpoint is | ||
| 71 | + detected. | ||
| 72 | + rule3_min_utterance_length: | ||
| 73 | + Used only when enable_endpoint_detection is True. If the utterance | ||
| 74 | + length in seconds is larger than this value, we assume an endpoint | ||
| 75 | + is detected. | ||
| 55 | """ | 76 | """ |
| 56 | _assert_file_exists(tokens) | 77 | _assert_file_exists(tokens) |
| 57 | _assert_file_exists(encoder) | 78 | _assert_file_exists(encoder) |
| @@ -72,10 +93,18 @@ class OnlineRecognizer(object): | @@ -72,10 +93,18 @@ class OnlineRecognizer(object): | ||
| 72 | feature_dim=feature_dim, | 93 | feature_dim=feature_dim, |
| 73 | ) | 94 | ) |
| 74 | 95 | ||
| 96 | + endpoint_config = EndpointConfig( | ||
| 97 | + rule1_min_trailing_silence=rule1_min_trailing_silence, | ||
| 98 | + rule2_min_trailing_silence=rule2_min_trailing_silence, | ||
| 99 | + rule3_min_utterance_length=rule3_min_utterance_length, | ||
| 100 | + ) | ||
| 101 | + | ||
| 75 | recognizer_config = OnlineRecognizerConfig( | 102 | recognizer_config = OnlineRecognizerConfig( |
| 76 | feat_config=feat_config, | 103 | feat_config=feat_config, |
| 77 | model_config=model_config, | 104 | model_config=model_config, |
| 78 | tokens=tokens, | 105 | tokens=tokens, |
| 106 | + endpoint_config=endpoint_config, | ||
| 107 | + enable_endpoint=enable_endpoint_detection, | ||
| 79 | ) | 108 | ) |
| 80 | 109 | ||
| 81 | self.recognizer = _Recognizer(recognizer_config) | 110 | self.recognizer = _Recognizer(recognizer_config) |
| @@ -93,4 +122,10 @@ class OnlineRecognizer(object): | @@ -93,4 +122,10 @@ class OnlineRecognizer(object): | ||
| 93 | return self.recognizer.is_ready(s) | 122 | return self.recognizer.is_ready(s) |
| 94 | 123 | ||
| 95 | def get_result(self, s: OnlineStream) -> str: | 124 | def get_result(self, s: OnlineStream) -> str: |
| 96 | - return self.recognizer.get_result(s).text | 125 | + return self.recognizer.get_result(s).text.strip() |
| 126 | + | ||
| 127 | + def is_endpoint(self, s: OnlineStream) -> bool: | ||
| 128 | + return self.recognizer.is_endpoint(s) | ||
| 129 | + | ||
| 130 | + def reset(self, s: OnlineStream) -> bool: | ||
| 131 | + return self.recognizer.reset(s) |
-
请 注册 或 登录 后发表评论