Committed by
GitHub
Test int8 models (#107)
* Test int8 models * Fix displaying help messages * small fixes * Fix jni test
正在显示
7 个修改的文件
包含
296 行增加
和
175 行删除
| @@ -35,7 +35,7 @@ fun main() { | @@ -35,7 +35,7 @@ fun main() { | ||
| 35 | 35 | ||
| 36 | var objArray = WaveReader.readWave( | 36 | var objArray = WaveReader.readWave( |
| 37 | assetManager = AssetManager(), | 37 | assetManager = AssetManager(), |
| 38 | - filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav", | 38 | + filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav", |
| 39 | ) | 39 | ) |
| 40 | var samples : FloatArray = objArray[0] as FloatArray | 40 | var samples : FloatArray = objArray[0] as FloatArray |
| 41 | var sampleRate : Int = objArray[1] as Int | 41 | var sampleRate : Int = objArray[1] as Int |
| @@ -25,6 +25,7 @@ log "Download pretrained model and test-data from $repo_url" | @@ -25,6 +25,7 @@ log "Download pretrained model and test-data from $repo_url" | ||
| 25 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | 25 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url |
| 26 | pushd $repo | 26 | pushd $repo |
| 27 | git lfs pull --include "*.onnx" | 27 | git lfs pull --include "*.onnx" |
| 28 | +ls -lh *.onnx | ||
| 28 | popd | 29 | popd |
| 29 | 30 | ||
| 30 | time $EXE \ | 31 | time $EXE \ |
| @@ -37,6 +38,16 @@ time $EXE \ | @@ -37,6 +38,16 @@ time $EXE \ | ||
| 37 | $repo/test_wavs/1.wav \ | 38 | $repo/test_wavs/1.wav \ |
| 38 | $repo/test_wavs/8k.wav | 39 | $repo/test_wavs/8k.wav |
| 39 | 40 | ||
| 41 | +time $EXE \ | ||
| 42 | + --tokens=$repo/tokens.txt \ | ||
| 43 | + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 44 | + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 45 | + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 46 | + --num-threads=2 \ | ||
| 47 | + $repo/test_wavs/0.wav \ | ||
| 48 | + $repo/test_wavs/1.wav \ | ||
| 49 | + $repo/test_wavs/8k.wav | ||
| 50 | + | ||
| 40 | rm -rf $repo | 51 | rm -rf $repo |
| 41 | 52 | ||
| 42 | log "------------------------------------------------------------" | 53 | log "------------------------------------------------------------" |
| @@ -51,6 +62,7 @@ log "Download pretrained model and test-data from $repo_url" | @@ -51,6 +62,7 @@ log "Download pretrained model and test-data from $repo_url" | ||
| 51 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | 62 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url |
| 52 | pushd $repo | 63 | pushd $repo |
| 53 | git lfs pull --include "*.onnx" | 64 | git lfs pull --include "*.onnx" |
| 65 | +ls -lh *.onnx | ||
| 54 | popd | 66 | popd |
| 55 | 67 | ||
| 56 | time $EXE \ | 68 | time $EXE \ |
| @@ -63,6 +75,16 @@ time $EXE \ | @@ -63,6 +75,16 @@ time $EXE \ | ||
| 63 | $repo/test_wavs/1.wav \ | 75 | $repo/test_wavs/1.wav \ |
| 64 | $repo/test_wavs/8k.wav | 76 | $repo/test_wavs/8k.wav |
| 65 | 77 | ||
| 78 | +time $EXE \ | ||
| 79 | + --tokens=$repo/tokens.txt \ | ||
| 80 | + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 81 | + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 82 | + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 83 | + --num-threads=2 \ | ||
| 84 | + $repo/test_wavs/0.wav \ | ||
| 85 | + $repo/test_wavs/1.wav \ | ||
| 86 | + $repo/test_wavs/8k.wav | ||
| 87 | + | ||
| 66 | rm -rf $repo | 88 | rm -rf $repo |
| 67 | 89 | ||
| 68 | log "------------------------------------------------------------" | 90 | log "------------------------------------------------------------" |
| @@ -77,6 +99,7 @@ log "Download pretrained model and test-data from $repo_url" | @@ -77,6 +99,7 @@ log "Download pretrained model and test-data from $repo_url" | ||
| 77 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | 99 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url |
| 78 | pushd $repo | 100 | pushd $repo |
| 79 | git lfs pull --include "*.onnx" | 101 | git lfs pull --include "*.onnx" |
| 102 | +ls -lh *.onnx | ||
| 80 | popd | 103 | popd |
| 81 | 104 | ||
| 82 | time $EXE \ | 105 | time $EXE \ |
| @@ -89,4 +112,14 @@ time $EXE \ | @@ -89,4 +112,14 @@ time $EXE \ | ||
| 89 | $repo/test_wavs/2.wav \ | 112 | $repo/test_wavs/2.wav \ |
| 90 | $repo/test_wavs/8k.wav | 113 | $repo/test_wavs/8k.wav |
| 91 | 114 | ||
| 115 | +time $EXE \ | ||
| 116 | + --tokens=$repo/tokens.txt \ | ||
| 117 | + --paraformer=$repo/model.int8.onnx \ | ||
| 118 | + --num-threads=2 \ | ||
| 119 | + --decoding-method=greedy_search \ | ||
| 120 | + $repo/test_wavs/0.wav \ | ||
| 121 | + $repo/test_wavs/1.wav \ | ||
| 122 | + $repo/test_wavs/2.wav \ | ||
| 123 | + $repo/test_wavs/8k.wav | ||
| 124 | + | ||
| 92 | rm -rf $repo | 125 | rm -rf $repo |
| @@ -25,12 +25,13 @@ log "Download pretrained model and test-data from $repo_url" | @@ -25,12 +25,13 @@ log "Download pretrained model and test-data from $repo_url" | ||
| 25 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | 25 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url |
| 26 | pushd $repo | 26 | pushd $repo |
| 27 | git lfs pull --include "*.onnx" | 27 | git lfs pull --include "*.onnx" |
| 28 | +ls -lh *.onnx | ||
| 28 | popd | 29 | popd |
| 29 | 30 | ||
| 30 | waves=( | 31 | waves=( |
| 31 | -$repo/test_wavs/1089-134686-0001.wav | ||
| 32 | -$repo/test_wavs/1221-135766-0001.wav | ||
| 33 | -$repo/test_wavs/1221-135766-0002.wav | 32 | +$repo/test_wavs/0.wav |
| 33 | +$repo/test_wavs/1.wav | ||
| 34 | +$repo/test_wavs/8k.wav | ||
| 34 | ) | 35 | ) |
| 35 | 36 | ||
| 36 | for wave in ${waves[@]}; do | 37 | for wave in ${waves[@]}; do |
| @@ -43,6 +44,16 @@ for wave in ${waves[@]}; do | @@ -43,6 +44,16 @@ for wave in ${waves[@]}; do | ||
| 43 | 2 | 44 | 2 |
| 44 | done | 45 | done |
| 45 | 46 | ||
| 47 | +for wave in ${waves[@]}; do | ||
| 48 | + time $EXE \ | ||
| 49 | + $repo/tokens.txt \ | ||
| 50 | + $repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 51 | + $repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 52 | + $repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 53 | + $wave \ | ||
| 54 | + 2 | ||
| 55 | +done | ||
| 56 | + | ||
| 46 | rm -rf $repo | 57 | rm -rf $repo |
| 47 | 58 | ||
| 48 | log "------------------------------------------------------------" | 59 | log "------------------------------------------------------------" |
| @@ -57,12 +68,13 @@ log "Download pretrained model and test-data from $repo_url" | @@ -57,12 +68,13 @@ log "Download pretrained model and test-data from $repo_url" | ||
| 57 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | 68 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url |
| 58 | pushd $repo | 69 | pushd $repo |
| 59 | git lfs pull --include "*.onnx" | 70 | git lfs pull --include "*.onnx" |
| 71 | +ls -lh *.onnx | ||
| 60 | popd | 72 | popd |
| 61 | 73 | ||
| 62 | waves=( | 74 | waves=( |
| 63 | $repo/test_wavs/0.wav | 75 | $repo/test_wavs/0.wav |
| 64 | $repo/test_wavs/1.wav | 76 | $repo/test_wavs/1.wav |
| 65 | -$repo/test_wavs/2.wav | 77 | +$repo/test_wavs/8k.wav |
| 66 | ) | 78 | ) |
| 67 | 79 | ||
| 68 | for wave in ${waves[@]}; do | 80 | for wave in ${waves[@]}; do |
| @@ -75,6 +87,16 @@ for wave in ${waves[@]}; do | @@ -75,6 +87,16 @@ for wave in ${waves[@]}; do | ||
| 75 | 2 | 87 | 2 |
| 76 | done | 88 | done |
| 77 | 89 | ||
| 90 | +for wave in ${waves[@]}; do | ||
| 91 | + time $EXE \ | ||
| 92 | + $repo/tokens.txt \ | ||
| 93 | + $repo/encoder-epoch-11-avg-1.int8.onnx \ | ||
| 94 | + $repo/decoder-epoch-11-avg-1.int8.onnx \ | ||
| 95 | + $repo/joiner-epoch-11-avg-1.int8.onnx \ | ||
| 96 | + $wave \ | ||
| 97 | + 2 | ||
| 98 | +done | ||
| 99 | + | ||
| 78 | rm -rf $repo | 100 | rm -rf $repo |
| 79 | 101 | ||
| 80 | log "------------------------------------------------------------" | 102 | log "------------------------------------------------------------" |
| @@ -89,12 +111,13 @@ log "Download pretrained model and test-data from $repo_url" | @@ -89,12 +111,13 @@ log "Download pretrained model and test-data from $repo_url" | ||
| 89 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | 111 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url |
| 90 | pushd $repo | 112 | pushd $repo |
| 91 | git lfs pull --include "*.onnx" | 113 | git lfs pull --include "*.onnx" |
| 114 | +ls -lh *.onnx | ||
| 92 | popd | 115 | popd |
| 93 | 116 | ||
| 94 | waves=( | 117 | waves=( |
| 95 | -$repo/test_wavs/1089-134686-0001.wav | ||
| 96 | -$repo/test_wavs/1221-135766-0001.wav | ||
| 97 | -$repo/test_wavs/1221-135766-0002.wav | 118 | +$repo/test_wavs/0.wav |
| 119 | +$repo/test_wavs/1.wav | ||
| 120 | +$repo/test_wavs/8k.wav | ||
| 98 | ) | 121 | ) |
| 99 | 122 | ||
| 100 | for wave in ${waves[@]}; do | 123 | for wave in ${waves[@]}; do |
| @@ -107,10 +130,22 @@ for wave in ${waves[@]}; do | @@ -107,10 +130,22 @@ for wave in ${waves[@]}; do | ||
| 107 | 2 | 130 | 2 |
| 108 | done | 131 | done |
| 109 | 132 | ||
| 133 | +# test int8 | ||
| 134 | +# | ||
| 135 | +for wave in ${waves[@]}; do | ||
| 136 | + time $EXE \ | ||
| 137 | + $repo/tokens.txt \ | ||
| 138 | + $repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 139 | + $repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 140 | + $repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 141 | + $wave \ | ||
| 142 | + 2 | ||
| 143 | +done | ||
| 144 | + | ||
| 110 | rm -rf $repo | 145 | rm -rf $repo |
| 111 | 146 | ||
| 112 | log "------------------------------------------------------------" | 147 | log "------------------------------------------------------------" |
| 113 | -log "Run streaming Zipformer transducer (Bilingual, Chinse + English)" | 148 | +log "Run streaming Zipformer transducer (Bilingual, Chinese + English)" |
| 114 | log "------------------------------------------------------------" | 149 | log "------------------------------------------------------------" |
| 115 | 150 | ||
| 116 | repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 | 151 | repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 |
| @@ -121,6 +156,7 @@ log "Download pretrained model and test-data from $repo_url" | @@ -121,6 +156,7 @@ log "Download pretrained model and test-data from $repo_url" | ||
| 121 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | 156 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url |
| 122 | pushd $repo | 157 | pushd $repo |
| 123 | git lfs pull --include "*.onnx" | 158 | git lfs pull --include "*.onnx" |
| 159 | +ls -lh *.onnx | ||
| 124 | popd | 160 | popd |
| 125 | 161 | ||
| 126 | waves=( | 162 | waves=( |
| @@ -128,7 +164,7 @@ $repo/test_wavs/0.wav | @@ -128,7 +164,7 @@ $repo/test_wavs/0.wav | ||
| 128 | $repo/test_wavs/1.wav | 164 | $repo/test_wavs/1.wav |
| 129 | $repo/test_wavs/2.wav | 165 | $repo/test_wavs/2.wav |
| 130 | $repo/test_wavs/3.wav | 166 | $repo/test_wavs/3.wav |
| 131 | -$repo/test_wavs/4.wav | 167 | +$repo/test_wavs/8k.wav |
| 132 | ) | 168 | ) |
| 133 | 169 | ||
| 134 | for wave in ${waves[@]}; do | 170 | for wave in ${waves[@]}; do |
| @@ -141,6 +177,16 @@ for wave in ${waves[@]}; do | @@ -141,6 +177,16 @@ for wave in ${waves[@]}; do | ||
| 141 | 2 | 177 | 2 |
| 142 | done | 178 | done |
| 143 | 179 | ||
| 180 | +for wave in ${waves[@]}; do | ||
| 181 | + time $EXE \ | ||
| 182 | + $repo/tokens.txt \ | ||
| 183 | + $repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 184 | + $repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 185 | + $repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 186 | + $wave \ | ||
| 187 | + 2 | ||
| 188 | +done | ||
| 189 | + | ||
| 144 | # Decode a URL | 190 | # Decode a URL |
| 145 | if [ $EXE == "sherpa-onnx-ffmpeg" ]; then | 191 | if [ $EXE == "sherpa-onnx-ffmpeg" ]; then |
| 146 | time $EXE \ | 192 | time $EXE \ |
| @@ -152,4 +198,14 @@ if [ $EXE == "sherpa-onnx-ffmpeg" ]; then | @@ -152,4 +198,14 @@ if [ $EXE == "sherpa-onnx-ffmpeg" ]; then | ||
| 152 | 2 | 198 | 2 |
| 153 | fi | 199 | fi |
| 154 | 200 | ||
| 201 | +if [ $EXE == "sherpa-onnx-ffmpeg" ]; then | ||
| 202 | + time $EXE \ | ||
| 203 | + $repo/tokens.txt \ | ||
| 204 | + $repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 205 | + $repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 206 | + $repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 207 | + https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/resolve/main/test_wavs/4.wav \ | ||
| 208 | + 2 | ||
| 209 | +fi | ||
| 210 | + | ||
| 155 | rm -rf $repo | 211 | rm -rf $repo |
| @@ -46,3 +46,8 @@ run-sherpa-onnx-offline-paraformer.sh | @@ -46,3 +46,8 @@ run-sherpa-onnx-offline-paraformer.sh | ||
| 46 | run-sherpa-onnx-offline-transducer.sh | 46 | run-sherpa-onnx-offline-transducer.sh |
| 47 | sherpa-onnx-paraformer-zh-2023-03-28 | 47 | sherpa-onnx-paraformer-zh-2023-03-28 |
| 48 | run-offline-websocket-server-paraformer.sh | 48 | run-offline-websocket-server-paraformer.sh |
| 49 | +run-*int8.sh | ||
| 50 | +a.sh | ||
| 51 | +run-offline-websocket-client-*.sh | ||
| 52 | +run-sherpa-onnx-*.sh | ||
| 53 | +sherpa-onnx-zipformer-en-2023-03-30 |
| @@ -18,139 +18,13 @@ | @@ -18,139 +18,13 @@ | ||
| 18 | #include <cstring> | 18 | #include <cstring> |
| 19 | #include <fstream> | 19 | #include <fstream> |
| 20 | #include <iomanip> | 20 | #include <iomanip> |
| 21 | -#include <limits> | ||
| 22 | -#include <type_traits> | ||
| 23 | -#include <unordered_map> | ||
| 24 | 21 | ||
| 25 | #include "sherpa-onnx/csrc/log.h" | 22 | #include "sherpa-onnx/csrc/log.h" |
| 26 | - | ||
| 27 | -#ifdef _MSC_VER | ||
| 28 | -#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \ | ||
| 29 | - _strtoi64(cur_cstr, end_cstr, 10); | ||
| 30 | -#else | ||
| 31 | -#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10); | ||
| 32 | -#endif | 23 | +#include "sherpa-onnx/csrc/macros.h" |
| 24 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 33 | 25 | ||
| 34 | namespace sherpa_onnx { | 26 | namespace sherpa_onnx { |
| 35 | 27 | ||
| 36 | -/// Converts a string into an integer via strtoll and returns false if there was | ||
| 37 | -/// any kind of problem (i.e. the string was not an integer or contained extra | ||
| 38 | -/// non-whitespace junk, or the integer was too large to fit into the type it is | ||
| 39 | -/// being converted into). Only sets *out if everything was OK and it returns | ||
| 40 | -/// true. | ||
| 41 | -template <class Int> | ||
| 42 | -bool ConvertStringToInteger(const std::string &str, Int *out) { | ||
| 43 | - // copied from kaldi/src/util/text-util.h | ||
| 44 | - static_assert(std::is_integral<Int>::value, ""); | ||
| 45 | - const char *this_str = str.c_str(); | ||
| 46 | - char *end = nullptr; | ||
| 47 | - errno = 0; | ||
| 48 | - int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end); | ||
| 49 | - if (end != this_str) { | ||
| 50 | - while (isspace(*end)) ++end; | ||
| 51 | - } | ||
| 52 | - if (end == this_str || *end != '\0' || errno != 0) return false; | ||
| 53 | - Int iInt = static_cast<Int>(i); | ||
| 54 | - if (static_cast<int64_t>(iInt) != i || | ||
| 55 | - (i < 0 && !std::numeric_limits<Int>::is_signed)) { | ||
| 56 | - return false; | ||
| 57 | - } | ||
| 58 | - *out = iInt; | ||
| 59 | - return true; | ||
| 60 | -} | ||
| 61 | - | ||
| 62 | -// copied from kaldi/src/util/text-util.cc | ||
| 63 | -template <class T> | ||
| 64 | -class NumberIstream { | ||
| 65 | - public: | ||
| 66 | - explicit NumberIstream(std::istream &i) : in_(i) {} | ||
| 67 | - | ||
| 68 | - NumberIstream &operator>>(T &x) { | ||
| 69 | - if (!in_.good()) return *this; | ||
| 70 | - in_ >> x; | ||
| 71 | - if (!in_.fail() && RemainderIsOnlySpaces()) return *this; | ||
| 72 | - return ParseOnFail(&x); | ||
| 73 | - } | ||
| 74 | - | ||
| 75 | - private: | ||
| 76 | - std::istream &in_; | ||
| 77 | - | ||
| 78 | - bool RemainderIsOnlySpaces() { | ||
| 79 | - if (in_.tellg() != std::istream::pos_type(-1)) { | ||
| 80 | - std::string rem; | ||
| 81 | - in_ >> rem; | ||
| 82 | - | ||
| 83 | - if (rem.find_first_not_of(' ') != std::string::npos) { | ||
| 84 | - // there is not only spaces | ||
| 85 | - return false; | ||
| 86 | - } | ||
| 87 | - } | ||
| 88 | - | ||
| 89 | - in_.clear(); | ||
| 90 | - return true; | ||
| 91 | - } | ||
| 92 | - | ||
| 93 | - NumberIstream &ParseOnFail(T *x) { | ||
| 94 | - std::string str; | ||
| 95 | - in_.clear(); | ||
| 96 | - in_.seekg(0); | ||
| 97 | - // If the stream is broken even before trying | ||
| 98 | - // to read from it or if there are many tokens, | ||
| 99 | - // it's pointless to try. | ||
| 100 | - if (!(in_ >> str) || !RemainderIsOnlySpaces()) { | ||
| 101 | - in_.setstate(std::ios_base::failbit); | ||
| 102 | - return *this; | ||
| 103 | - } | ||
| 104 | - | ||
| 105 | - std::unordered_map<std::string, T> inf_nan_map; | ||
| 106 | - // we'll keep just uppercase values. | ||
| 107 | - inf_nan_map["INF"] = std::numeric_limits<T>::infinity(); | ||
| 108 | - inf_nan_map["+INF"] = std::numeric_limits<T>::infinity(); | ||
| 109 | - inf_nan_map["-INF"] = -std::numeric_limits<T>::infinity(); | ||
| 110 | - inf_nan_map["INFINITY"] = std::numeric_limits<T>::infinity(); | ||
| 111 | - inf_nan_map["+INFINITY"] = std::numeric_limits<T>::infinity(); | ||
| 112 | - inf_nan_map["-INFINITY"] = -std::numeric_limits<T>::infinity(); | ||
| 113 | - inf_nan_map["NAN"] = std::numeric_limits<T>::quiet_NaN(); | ||
| 114 | - inf_nan_map["+NAN"] = std::numeric_limits<T>::quiet_NaN(); | ||
| 115 | - inf_nan_map["-NAN"] = -std::numeric_limits<T>::quiet_NaN(); | ||
| 116 | - // MSVC | ||
| 117 | - inf_nan_map["1.#INF"] = std::numeric_limits<T>::infinity(); | ||
| 118 | - inf_nan_map["-1.#INF"] = -std::numeric_limits<T>::infinity(); | ||
| 119 | - inf_nan_map["1.#QNAN"] = std::numeric_limits<T>::quiet_NaN(); | ||
| 120 | - inf_nan_map["-1.#QNAN"] = -std::numeric_limits<T>::quiet_NaN(); | ||
| 121 | - | ||
| 122 | - std::transform(str.begin(), str.end(), str.begin(), ::toupper); | ||
| 123 | - | ||
| 124 | - if (inf_nan_map.find(str) != inf_nan_map.end()) { | ||
| 125 | - *x = inf_nan_map[str]; | ||
| 126 | - } else { | ||
| 127 | - in_.setstate(std::ios_base::failbit); | ||
| 128 | - } | ||
| 129 | - | ||
| 130 | - return *this; | ||
| 131 | - } | ||
| 132 | -}; | ||
| 133 | - | ||
| 134 | -/// ConvertStringToReal converts a string into either float or double | ||
| 135 | -/// and returns false if there was any kind of problem (i.e. the string | ||
| 136 | -/// was not a floating point number or contained extra non-whitespace junk). | ||
| 137 | -/// Be careful- this function will successfully read inf's or nan's. | ||
| 138 | -template <typename T> | ||
| 139 | -bool ConvertStringToReal(const std::string &str, T *out) { | ||
| 140 | - std::istringstream iss(str); | ||
| 141 | - | ||
| 142 | - NumberIstream<T> i(iss); | ||
| 143 | - | ||
| 144 | - i >> *out; | ||
| 145 | - | ||
| 146 | - if (iss.fail()) { | ||
| 147 | - // Number conversion failed. | ||
| 148 | - return false; | ||
| 149 | - } | ||
| 150 | - | ||
| 151 | - return true; | ||
| 152 | -} | ||
| 153 | - | ||
| 154 | ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po) | 28 | ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po) |
| 155 | : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) { | 29 | : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) { |
| 156 | if (po != nullptr && po->other_parser_ != nullptr) { | 30 | if (po != nullptr && po->other_parser_ != nullptr) { |
| @@ -219,8 +93,8 @@ void ParseOptions::RegisterCommon(const std::string &name, T *ptr, | @@ -219,8 +93,8 @@ void ParseOptions::RegisterCommon(const std::string &name, T *ptr, | ||
| 219 | std::string idx = name; | 93 | std::string idx = name; |
| 220 | NormalizeArgName(&idx); | 94 | NormalizeArgName(&idx); |
| 221 | if (doc_map_.find(idx) != doc_map_.end()) { | 95 | if (doc_map_.find(idx) != doc_map_.end()) { |
| 222 | - SHERPA_ONNX_LOG(WARNING) | ||
| 223 | - << "Registering option twice, ignoring second time: " << name; | 96 | + SHERPA_ONNX_LOGE("Registering option twice, ignoring second time: %s", |
| 97 | + name.c_str()); | ||
| 224 | } else { | 98 | } else { |
| 225 | this->RegisterSpecific(name, idx, ptr, doc, is_standard); | 99 | this->RegisterSpecific(name, idx, ptr, doc, is_standard); |
| 226 | } | 100 | } |
| @@ -289,12 +163,13 @@ void ParseOptions::RegisterSpecific(const std::string &name, | @@ -289,12 +163,13 @@ void ParseOptions::RegisterSpecific(const std::string &name, | ||
| 289 | 163 | ||
| 290 | void ParseOptions::DisableOption(const std::string &name) { | 164 | void ParseOptions::DisableOption(const std::string &name) { |
| 291 | if (argv_ != nullptr) { | 165 | if (argv_ != nullptr) { |
| 292 | - SHERPA_ONNX_LOG(FATAL) | ||
| 293 | - << "DisableOption must not be called after calling Read()."; | 166 | + SHERPA_ONNX_LOGE("DisableOption must not be called after calling Read()."); |
| 167 | + exit(-1); | ||
| 294 | } | 168 | } |
| 295 | if (doc_map_.erase(name) == 0) { | 169 | if (doc_map_.erase(name) == 0) { |
| 296 | - SHERPA_ONNX_LOG(FATAL) << "Option " << name | ||
| 297 | - << " was not registered so cannot be disabled: "; | 170 | + SHERPA_ONNX_LOGE("Option %s was not registered so cannot be disabled: ", |
| 171 | + name.c_str()); | ||
| 172 | + exit(-1); | ||
| 298 | } | 173 | } |
| 299 | bool_map_.erase(name); | 174 | bool_map_.erase(name); |
| 300 | int_map_.erase(name); | 175 | int_map_.erase(name); |
| @@ -308,7 +183,8 @@ int ParseOptions::NumArgs() const { return positional_args_.size(); } | @@ -308,7 +183,8 @@ int ParseOptions::NumArgs() const { return positional_args_.size(); } | ||
| 308 | 183 | ||
| 309 | std::string ParseOptions::GetArg(int i) const { | 184 | std::string ParseOptions::GetArg(int i) const { |
| 310 | if (i < 1 || i > static_cast<int>(positional_args_.size())) { | 185 | if (i < 1 || i > static_cast<int>(positional_args_.size())) { |
| 311 | - SHERPA_ONNX_LOG(FATAL) << "ParseOptions::GetArg, invalid index " << i; | 186 | + SHERPA_ONNX_LOGE("ParseOptions::GetArg, invalid index %d", i); |
| 187 | + exit(-1); | ||
| 312 | } | 188 | } |
| 313 | 189 | ||
| 314 | return positional_args_[i - 1]; | 190 | return positional_args_[i - 1]; |
| @@ -460,7 +336,8 @@ int ParseOptions::Read(int argc, const char *const argv[]) { | @@ -460,7 +336,8 @@ int ParseOptions::Read(int argc, const char *const argv[]) { | ||
| 460 | Trim(&value); | 336 | Trim(&value); |
| 461 | if (!SetOption(key, value, has_equal_sign)) { | 337 | if (!SetOption(key, value, has_equal_sign)) { |
| 462 | PrintUsage(true); | 338 | PrintUsage(true); |
| 463 | - SHERPA_ONNX_LOG(FATAL) << "Invalid option " << argv[i]; | 339 | + SHERPA_ONNX_LOGE("Invalid option %s", argv[i]); |
| 340 | + exit(-1); | ||
| 464 | } | 341 | } |
| 465 | } else { | 342 | } else { |
| 466 | break; | 343 | break; |
| @@ -481,7 +358,7 @@ int ParseOptions::Read(int argc, const char *const argv[]) { | @@ -481,7 +358,7 @@ int ParseOptions::Read(int argc, const char *const argv[]) { | ||
| 481 | std::ostringstream strm; | 358 | std::ostringstream strm; |
| 482 | for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " "; | 359 | for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " "; |
| 483 | strm << '\n'; | 360 | strm << '\n'; |
| 484 | - SHERPA_ONNX_LOG(INFO) << strm.str(); | 361 | + SHERPA_ONNX_LOGE("%s", strm.str().c_str()); |
| 485 | } | 362 | } |
| 486 | return i; | 363 | return i; |
| 487 | } | 364 | } |
| @@ -522,7 +399,7 @@ void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const { | @@ -522,7 +399,7 @@ void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const { | ||
| 522 | os << strm.str(); | 399 | os << strm.str(); |
| 523 | } | 400 | } |
| 524 | 401 | ||
| 525 | - SHERPA_ONNX_LOG(INFO) << os.str(); | 402 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); |
| 526 | } | 403 | } |
| 527 | 404 | ||
| 528 | void ParseOptions::PrintConfig(std::ostream &os) const { | 405 | void ParseOptions::PrintConfig(std::ostream &os) const { |
| @@ -544,8 +421,9 @@ void ParseOptions::PrintConfig(std::ostream &os) const { | @@ -544,8 +421,9 @@ void ParseOptions::PrintConfig(std::ostream &os) const { | ||
| 544 | } else if (string_map_.end() != string_map_.find(key)) { | 421 | } else if (string_map_.end() != string_map_.find(key)) { |
| 545 | os << "'" << *string_map_.at(key) << "'"; | 422 | os << "'" << *string_map_.at(key) << "'"; |
| 546 | } else { | 423 | } else { |
| 547 | - SHERPA_ONNX_LOG(FATAL) | ||
| 548 | - << "PrintConfig: unrecognized option " << key << "[code error]"; | 424 | + SHERPA_ONNX_LOGE("PrintConfig: unrecognized option %s [code error]", |
| 425 | + key.c_str()); | ||
| 426 | + exit(-1); | ||
| 549 | } | 427 | } |
| 550 | os << '\n'; | 428 | os << '\n'; |
| 551 | } | 429 | } |
| @@ -555,7 +433,8 @@ void ParseOptions::PrintConfig(std::ostream &os) const { | @@ -555,7 +433,8 @@ void ParseOptions::PrintConfig(std::ostream &os) const { | ||
| 555 | void ParseOptions::ReadConfigFile(const std::string &filename) { | 433 | void ParseOptions::ReadConfigFile(const std::string &filename) { |
| 556 | std::ifstream is(filename.c_str(), std::ifstream::in); | 434 | std::ifstream is(filename.c_str(), std::ifstream::in); |
| 557 | if (!is.good()) { | 435 | if (!is.good()) { |
| 558 | - SHERPA_ONNX_LOG(FATAL) << "Cannot open config file: " << filename; | 436 | + SHERPA_ONNX_LOGE("Cannot open config file: %s", filename.c_str()); |
| 437 | + exit(-1); | ||
| 559 | } | 438 | } |
| 560 | 439 | ||
| 561 | std::string line, key, value; | 440 | std::string line, key, value; |
| @@ -572,12 +451,13 @@ void ParseOptions::ReadConfigFile(const std::string &filename) { | @@ -572,12 +451,13 @@ void ParseOptions::ReadConfigFile(const std::string &filename) { | ||
| 572 | if (line.length() == 0) continue; | 451 | if (line.length() == 0) continue; |
| 573 | 452 | ||
| 574 | if (line.substr(0, 2) != "--") { | 453 | if (line.substr(0, 2) != "--") { |
| 575 | - SHERPA_ONNX_LOG(FATAL) | ||
| 576 | - << "Reading config file " << filename << ": line " << line_number | ||
| 577 | - << " does not look like a line " | ||
| 578 | - << "from a Kaldi command-line program's config file: should " | ||
| 579 | - << "be of the form --x=y. Note: config files intended to " | ||
| 580 | - << "be sourced by shell scripts lack the '--'."; | 454 | + SHERPA_ONNX_LOGE( |
| 455 | + "Reading config file %s: line %d does not look like a line " | ||
| 456 | + "from a sherpa-onnx command-line program's config file: should " | ||
| 457 | + "be of the form --x=y. Note: config files intended to " | ||
| 458 | + "be sourced by shell scripts lack the '--'.", | ||
| 459 | + filename.c_str(), line_number); | ||
| 460 | + exit(-1); | ||
| 581 | } | 461 | } |
| 582 | 462 | ||
| 583 | // parse option | 463 | // parse option |
| @@ -587,8 +467,9 @@ void ParseOptions::ReadConfigFile(const std::string &filename) { | @@ -587,8 +467,9 @@ void ParseOptions::ReadConfigFile(const std::string &filename) { | ||
| 587 | Trim(&value); | 467 | Trim(&value); |
| 588 | if (!SetOption(key, value, has_equal_sign)) { | 468 | if (!SetOption(key, value, has_equal_sign)) { |
| 589 | PrintUsage(true); | 469 | PrintUsage(true); |
| 590 | - SHERPA_ONNX_LOG(FATAL) << "Invalid option " << line << " in config file " | ||
| 591 | - << filename << ": line " << line_number; | 470 | + SHERPA_ONNX_LOGE("Invalid option %s in config file %s: line %d", |
| 471 | + line.c_str(), filename.c_str(), line_number); | ||
| 472 | + exit(-1); | ||
| 592 | } | 473 | } |
| 593 | } | 474 | } |
| 594 | } | 475 | } |
| @@ -605,7 +486,8 @@ void ParseOptions::SplitLongArg(const std::string &in, std::string *key, | @@ -605,7 +486,8 @@ void ParseOptions::SplitLongArg(const std::string &in, std::string *key, | ||
| 605 | *has_equal_sign = false; | 486 | *has_equal_sign = false; |
| 606 | } else if (pos == 2) { // we also don't allow empty keys: --=value | 487 | } else if (pos == 2) { // we also don't allow empty keys: --=value |
| 607 | PrintUsage(true); | 488 | PrintUsage(true); |
| 608 | - SHERPA_ONNX_LOG(FATAL) << "Invalid option (no key): " << in; | 489 | + SHERPA_ONNX_LOGE("Invalid option (no key): %s", in.c_str()); |
| 490 | + exit(-1); | ||
| 609 | } else { // normal case: --option=value | 491 | } else { // normal case: --option=value |
| 610 | *key = in.substr(2, pos - 2); // 2 because starts with --. | 492 | *key = in.substr(2, pos - 2); // 2 because starts with --. |
| 611 | *value = in.substr(pos + 1); | 493 | *value = in.substr(pos + 1); |
| @@ -646,7 +528,8 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value, | @@ -646,7 +528,8 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value, | ||
| 646 | bool has_equal_sign) { | 528 | bool has_equal_sign) { |
| 647 | if (bool_map_.end() != bool_map_.find(key)) { | 529 | if (bool_map_.end() != bool_map_.find(key)) { |
| 648 | if (has_equal_sign && value == "") { | 530 | if (has_equal_sign && value == "") { |
| 649 | - SHERPA_ONNX_LOG(FATAL) << "Invalid option --" << key << "="; | 531 | + SHERPA_ONNX_LOGE("Invalid option --%s=", key.c_str()); |
| 532 | + exit(-1); | ||
| 650 | } | 533 | } |
| 651 | *(bool_map_[key]) = ToBool(value); | 534 | *(bool_map_[key]) = ToBool(value); |
| 652 | } else if (int_map_.end() != int_map_.find(key)) { | 535 | } else if (int_map_.end() != int_map_.find(key)) { |
| @@ -659,8 +542,9 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value, | @@ -659,8 +542,9 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value, | ||
| 659 | *(double_map_[key]) = ToDouble(value); | 542 | *(double_map_[key]) = ToDouble(value); |
| 660 | } else if (string_map_.end() != string_map_.find(key)) { | 543 | } else if (string_map_.end() != string_map_.find(key)) { |
| 661 | if (!has_equal_sign) { | 544 | if (!has_equal_sign) { |
| 662 | - SHERPA_ONNX_LOG(FATAL) | ||
| 663 | - << "Invalid option --" << key << " (option format is --x=y)."; | 545 | + SHERPA_ONNX_LOGE("Invalid option --%s (option format is --x=y).", |
| 546 | + key.c_str()); | ||
| 547 | + exit(-1); | ||
| 664 | } | 548 | } |
| 665 | *(string_map_[key]) = value; | 549 | *(string_map_[key]) = value; |
| 666 | } else { | 550 | } else { |
| @@ -683,37 +567,46 @@ bool ParseOptions::ToBool(std::string str) const { | @@ -683,37 +567,46 @@ bool ParseOptions::ToBool(std::string str) const { | ||
| 683 | } | 567 | } |
| 684 | // if it is neither true nor false: | 568 | // if it is neither true nor false: |
| 685 | PrintUsage(true); | 569 | PrintUsage(true); |
| 686 | - SHERPA_ONNX_LOG(FATAL) | ||
| 687 | - << "Invalid format for boolean argument [expected true or false]: " | ||
| 688 | - << str; | 570 | + SHERPA_ONNX_LOGE( |
| 571 | + "Invalid format for boolean argument [expected true or false]: %s", | ||
| 572 | + str.c_str()); | ||
| 573 | + exit(-1); | ||
| 689 | return false; // never reached | 574 | return false; // never reached |
| 690 | } | 575 | } |
| 691 | 576 | ||
| 692 | int32_t ParseOptions::ToInt(const std::string &str) const { | 577 | int32_t ParseOptions::ToInt(const std::string &str) const { |
| 693 | int32_t ret = 0; | 578 | int32_t ret = 0; |
| 694 | - if (!ConvertStringToInteger(str, &ret)) | ||
| 695 | - SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\""; | 579 | + if (!ConvertStringToInteger(str, &ret)) { |
| 580 | + SHERPA_ONNX_LOGE("Invalid integer option \"%s\"", str.c_str()); | ||
| 581 | + exit(-1); | ||
| 582 | + } | ||
| 696 | return ret; | 583 | return ret; |
| 697 | } | 584 | } |
| 698 | 585 | ||
| 699 | uint32_t ParseOptions::ToUint(const std::string &str) const { | 586 | uint32_t ParseOptions::ToUint(const std::string &str) const { |
| 700 | uint32_t ret = 0; | 587 | uint32_t ret = 0; |
| 701 | - if (!ConvertStringToInteger(str, &ret)) | ||
| 702 | - SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\""; | 588 | + if (!ConvertStringToInteger(str, &ret)) { |
| 589 | + SHERPA_ONNX_LOGE("Invalid integer option \"%s\"", str.c_str()); | ||
| 590 | + exit(-1); | ||
| 591 | + } | ||
| 703 | return ret; | 592 | return ret; |
| 704 | } | 593 | } |
| 705 | 594 | ||
| 706 | float ParseOptions::ToFloat(const std::string &str) const { | 595 | float ParseOptions::ToFloat(const std::string &str) const { |
| 707 | float ret; | 596 | float ret; |
| 708 | - if (!ConvertStringToReal(str, &ret)) | ||
| 709 | - SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\""; | 597 | + if (!ConvertStringToReal(str, &ret)) { |
| 598 | + SHERPA_ONNX_LOGE("Invalid floating-point option \"%s\"", str.c_str()); | ||
| 599 | + exit(-1); | ||
| 600 | + } | ||
| 710 | return ret; | 601 | return ret; |
| 711 | } | 602 | } |
| 712 | 603 | ||
| 713 | double ParseOptions::ToDouble(const std::string &str) const { | 604 | double ParseOptions::ToDouble(const std::string &str) const { |
| 714 | double ret; | 605 | double ret; |
| 715 | - if (!ConvertStringToReal(str, &ret)) | ||
| 716 | - SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\""; | 606 | + if (!ConvertStringToReal(str, &ret)) { |
| 607 | + SHERPA_ONNX_LOGE("Invalid floating-point option \"%s\"", str.c_str()); | ||
| 608 | + exit(-1); | ||
| 609 | + } | ||
| 717 | return ret; | 610 | return ret; |
| 718 | } | 611 | } |
| 719 | 612 |
| @@ -7,7 +7,11 @@ | @@ -7,7 +7,11 @@ | ||
| 7 | 7 | ||
| 8 | #include <assert.h> | 8 | #include <assert.h> |
| 9 | 9 | ||
| 10 | +#include <algorithm> | ||
| 11 | +#include <limits> | ||
| 12 | +#include <sstream> | ||
| 10 | #include <string> | 13 | #include <string> |
| 14 | +#include <unordered_map> | ||
| 11 | #include <vector> | 15 | #include <vector> |
| 12 | 16 | ||
| 13 | // This file is copied/modified from | 17 | // This file is copied/modified from |
| @@ -15,6 +19,102 @@ | @@ -15,6 +19,102 @@ | ||
| 15 | 19 | ||
| 16 | namespace sherpa_onnx { | 20 | namespace sherpa_onnx { |
| 17 | 21 | ||
| 22 | +// copied from kaldi/src/util/text-util.cc | ||
| 23 | +template <class T> | ||
| 24 | +class NumberIstream { | ||
| 25 | + public: | ||
| 26 | + explicit NumberIstream(std::istream &i) : in_(i) {} | ||
| 27 | + | ||
| 28 | + NumberIstream &operator>>(T &x) { | ||
| 29 | + if (!in_.good()) return *this; | ||
| 30 | + in_ >> x; | ||
| 31 | + if (!in_.fail() && RemainderIsOnlySpaces()) return *this; | ||
| 32 | + return ParseOnFail(&x); | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + private: | ||
| 36 | + std::istream &in_; | ||
| 37 | + | ||
| 38 | + bool RemainderIsOnlySpaces() { | ||
| 39 | + if (in_.tellg() != std::istream::pos_type(-1)) { | ||
| 40 | + std::string rem; | ||
| 41 | + in_ >> rem; | ||
| 42 | + | ||
| 43 | + if (rem.find_first_not_of(' ') != std::string::npos) { | ||
| 44 | + // there is not only spaces | ||
| 45 | + return false; | ||
| 46 | + } | ||
| 47 | + } | ||
| 48 | + | ||
| 49 | + in_.clear(); | ||
| 50 | + return true; | ||
| 51 | + } | ||
| 52 | + | ||
| 53 | + NumberIstream &ParseOnFail(T *x) { | ||
| 54 | + std::string str; | ||
| 55 | + in_.clear(); | ||
| 56 | + in_.seekg(0); | ||
| 57 | + // If the stream is broken even before trying | ||
| 58 | + // to read from it or if there are many tokens, | ||
| 59 | + // it's pointless to try. | ||
| 60 | + if (!(in_ >> str) || !RemainderIsOnlySpaces()) { | ||
| 61 | + in_.setstate(std::ios_base::failbit); | ||
| 62 | + return *this; | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + std::unordered_map<std::string, T> inf_nan_map; | ||
| 66 | + // we'll keep just uppercase values. | ||
| 67 | + inf_nan_map["INF"] = std::numeric_limits<T>::infinity(); | ||
| 68 | + inf_nan_map["+INF"] = std::numeric_limits<T>::infinity(); | ||
| 69 | + inf_nan_map["-INF"] = -std::numeric_limits<T>::infinity(); | ||
| 70 | + inf_nan_map["INFINITY"] = std::numeric_limits<T>::infinity(); | ||
| 71 | + inf_nan_map["+INFINITY"] = std::numeric_limits<T>::infinity(); | ||
| 72 | + inf_nan_map["-INFINITY"] = -std::numeric_limits<T>::infinity(); | ||
| 73 | + inf_nan_map["NAN"] = std::numeric_limits<T>::quiet_NaN(); | ||
| 74 | + inf_nan_map["+NAN"] = std::numeric_limits<T>::quiet_NaN(); | ||
| 75 | + inf_nan_map["-NAN"] = -std::numeric_limits<T>::quiet_NaN(); | ||
| 76 | + // MSVC | ||
| 77 | + inf_nan_map["1.#INF"] = std::numeric_limits<T>::infinity(); | ||
| 78 | + inf_nan_map["-1.#INF"] = -std::numeric_limits<T>::infinity(); | ||
| 79 | + inf_nan_map["1.#QNAN"] = std::numeric_limits<T>::quiet_NaN(); | ||
| 80 | + inf_nan_map["-1.#QNAN"] = -std::numeric_limits<T>::quiet_NaN(); | ||
| 81 | + | ||
| 82 | + std::transform(str.begin(), str.end(), str.begin(), ::toupper); | ||
| 83 | + | ||
| 84 | + if (inf_nan_map.find(str) != inf_nan_map.end()) { | ||
| 85 | + *x = inf_nan_map[str]; | ||
| 86 | + } else { | ||
| 87 | + in_.setstate(std::ios_base::failbit); | ||
| 88 | + } | ||
| 89 | + | ||
| 90 | + return *this; | ||
| 91 | + } | ||
| 92 | +}; | ||
| 93 | + | ||
| 94 | +/// ConvertStringToReal converts a string into either float or double | ||
| 95 | +/// and returns false if there was any kind of problem (i.e. the string | ||
| 96 | +/// was not a floating point number or contained extra non-whitespace junk). | ||
| 97 | +/// Be careful- this function will successfully read inf's or nan's. | ||
| 98 | +template <typename T> | ||
| 99 | +bool ConvertStringToReal(const std::string &str, T *out) { | ||
| 100 | + std::istringstream iss(str); | ||
| 101 | + | ||
| 102 | + NumberIstream<T> i(iss); | ||
| 103 | + | ||
| 104 | + i >> *out; | ||
| 105 | + | ||
| 106 | + if (iss.fail()) { | ||
| 107 | + // Number conversion failed. | ||
| 108 | + return false; | ||
| 109 | + } | ||
| 110 | + | ||
| 111 | + return true; | ||
| 112 | +} | ||
| 113 | + | ||
| 114 | +template bool ConvertStringToReal<float>(const std::string &str, float *out); | ||
| 115 | + | ||
| 116 | +template bool ConvertStringToReal<double>(const std::string &str, double *out); | ||
| 117 | + | ||
| 18 | void SplitStringToVector(const std::string &full, const char *delim, | 118 | void SplitStringToVector(const std::string &full, const char *delim, |
| 19 | bool omit_empty_strings, | 119 | bool omit_empty_strings, |
| 20 | std::vector<std::string> *out) { | 120 | std::vector<std::string> *out) { |
| @@ -43,7 +143,9 @@ bool SplitStringToFloats(const std::string &full, const char *delim, | @@ -43,7 +143,9 @@ bool SplitStringToFloats(const std::string &full, const char *delim, | ||
| 43 | out->resize(split.size()); | 143 | out->resize(split.size()); |
| 44 | for (size_t i = 0; i < split.size(); ++i) { | 144 | for (size_t i = 0; i < split.size(); ++i) { |
| 45 | // assume atof never fails | 145 | // assume atof never fails |
| 46 | - (*out)[i] = atof(split[i].c_str()); | 146 | + F f = 0; |
| 147 | + if (!ConvertStringToReal(split[i], &f)) return false; | ||
| 148 | + (*out)[i] = f; | ||
| 47 | } | 149 | } |
| 48 | return true; | 150 | return true; |
| 49 | } | 151 | } |
| @@ -6,7 +6,9 @@ | @@ -6,7 +6,9 @@ | ||
| 6 | #define SHERPA_ONNX_CSRC_TEXT_UTILS_H_ | 6 | #define SHERPA_ONNX_CSRC_TEXT_UTILS_H_ |
| 7 | #include <stdlib.h> | 7 | #include <stdlib.h> |
| 8 | 8 | ||
| 9 | +#include <limits> | ||
| 9 | #include <string> | 10 | #include <string> |
| 11 | +#include <type_traits> | ||
| 10 | #include <vector> | 12 | #include <vector> |
| 11 | 13 | ||
| 12 | #ifdef _MSC_VER | 14 | #ifdef _MSC_VER |
| @@ -21,6 +23,32 @@ | @@ -21,6 +23,32 @@ | ||
| 21 | 23 | ||
| 22 | namespace sherpa_onnx { | 24 | namespace sherpa_onnx { |
| 23 | 25 | ||
| 26 | +/// Converts a string into an integer via strtoll and returns false if there was | ||
| 27 | +/// any kind of problem (i.e. the string was not an integer or contained extra | ||
| 28 | +/// non-whitespace junk, or the integer was too large to fit into the type it is | ||
| 29 | +/// being converted into). Only sets *out if everything was OK and it returns | ||
| 30 | +/// true. | ||
| 31 | +template <class Int> | ||
| 32 | +bool ConvertStringToInteger(const std::string &str, Int *out) { | ||
| 33 | + // copied from kaldi/src/util/text-util.h | ||
| 34 | + static_assert(std::is_integral<Int>::value, ""); | ||
| 35 | + const char *this_str = str.c_str(); | ||
| 36 | + char *end = nullptr; | ||
| 37 | + errno = 0; | ||
| 38 | + int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end); | ||
| 39 | + if (end != this_str) { | ||
| 40 | + while (isspace(*end)) ++end; | ||
| 41 | + } | ||
| 42 | + if (end == this_str || *end != '\0' || errno != 0) return false; | ||
| 43 | + Int iInt = static_cast<Int>(i); | ||
| 44 | + if (static_cast<int64_t>(iInt) != i || | ||
| 45 | + (i < 0 && !std::numeric_limits<Int>::is_signed)) { | ||
| 46 | + return false; | ||
| 47 | + } | ||
| 48 | + *out = iInt; | ||
| 49 | + return true; | ||
| 50 | +} | ||
| 51 | + | ||
| 24 | /// Split a string using any of the single character delimiters. | 52 | /// Split a string using any of the single character delimiters. |
| 25 | /// If omit_empty_strings == true, the output will contain any | 53 | /// If omit_empty_strings == true, the output will contain any |
| 26 | /// nonempty strings after splitting on any of the | 54 | /// nonempty strings after splitting on any of the |
| @@ -86,6 +114,10 @@ bool SplitStringToFloats(const std::string &full, const char *delim, | @@ -86,6 +114,10 @@ bool SplitStringToFloats(const std::string &full, const char *delim, | ||
| 86 | bool omit_empty_strings, // typically false | 114 | bool omit_empty_strings, // typically false |
| 87 | std::vector<F> *out); | 115 | std::vector<F> *out); |
| 88 | 116 | ||
| 117 | +// This is defined for F = float and double. | ||
| 118 | +template <typename T> | ||
| 119 | +bool ConvertStringToReal(const std::string &str, T *out); | ||
| 120 | + | ||
| 89 | } // namespace sherpa_onnx | 121 | } // namespace sherpa_onnx |
| 90 | 122 | ||
| 91 | #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ | 123 | #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ |
-
请 注册 或 登录 后发表评论