Fangjun Kuang
Committed by GitHub

Test int8 models (#107)

* Test int8 models

* Fix displaying help messages

* small fixes

* Fix jni test
@@ -35,7 +35,7 @@ fun main() { @@ -35,7 +35,7 @@ fun main() {
35 35
36 var objArray = WaveReader.readWave( 36 var objArray = WaveReader.readWave(
37 assetManager = AssetManager(), 37 assetManager = AssetManager(),
38 - filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav", 38 + filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
39 ) 39 )
40 var samples : FloatArray = objArray[0] as FloatArray 40 var samples : FloatArray = objArray[0] as FloatArray
41 var sampleRate : Int = objArray[1] as Int 41 var sampleRate : Int = objArray[1] as Int
@@ -25,6 +25,7 @@ log "Download pretrained model and test-data from $repo_url" @@ -25,6 +25,7 @@ log "Download pretrained model and test-data from $repo_url"
25 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url 25 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
26 pushd $repo 26 pushd $repo
27 git lfs pull --include "*.onnx" 27 git lfs pull --include "*.onnx"
  28 +ls -lh *.onnx
28 popd 29 popd
29 30
30 time $EXE \ 31 time $EXE \
@@ -37,6 +38,16 @@ time $EXE \ @@ -37,6 +38,16 @@ time $EXE \
37 $repo/test_wavs/1.wav \ 38 $repo/test_wavs/1.wav \
38 $repo/test_wavs/8k.wav 39 $repo/test_wavs/8k.wav
39 40
  41 +time $EXE \
  42 + --tokens=$repo/tokens.txt \
  43 + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
  44 + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
  45 + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
  46 + --num-threads=2 \
  47 + $repo/test_wavs/0.wav \
  48 + $repo/test_wavs/1.wav \
  49 + $repo/test_wavs/8k.wav
  50 +
40 rm -rf $repo 51 rm -rf $repo
41 52
42 log "------------------------------------------------------------" 53 log "------------------------------------------------------------"
@@ -51,6 +62,7 @@ log "Download pretrained model and test-data from $repo_url" @@ -51,6 +62,7 @@ log "Download pretrained model and test-data from $repo_url"
51 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url 62 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
52 pushd $repo 63 pushd $repo
53 git lfs pull --include "*.onnx" 64 git lfs pull --include "*.onnx"
  65 +ls -lh *.onnx
54 popd 66 popd
55 67
56 time $EXE \ 68 time $EXE \
@@ -63,6 +75,16 @@ time $EXE \ @@ -63,6 +75,16 @@ time $EXE \
63 $repo/test_wavs/1.wav \ 75 $repo/test_wavs/1.wav \
64 $repo/test_wavs/8k.wav 76 $repo/test_wavs/8k.wav
65 77
  78 +time $EXE \
  79 + --tokens=$repo/tokens.txt \
  80 + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
  81 + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
  82 + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
  83 + --num-threads=2 \
  84 + $repo/test_wavs/0.wav \
  85 + $repo/test_wavs/1.wav \
  86 + $repo/test_wavs/8k.wav
  87 +
66 rm -rf $repo 88 rm -rf $repo
67 89
68 log "------------------------------------------------------------" 90 log "------------------------------------------------------------"
@@ -77,6 +99,7 @@ log "Download pretrained model and test-data from $repo_url" @@ -77,6 +99,7 @@ log "Download pretrained model and test-data from $repo_url"
77 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url 99 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
78 pushd $repo 100 pushd $repo
79 git lfs pull --include "*.onnx" 101 git lfs pull --include "*.onnx"
  102 +ls -lh *.onnx
80 popd 103 popd
81 104
82 time $EXE \ 105 time $EXE \
@@ -89,4 +112,14 @@ time $EXE \ @@ -89,4 +112,14 @@ time $EXE \
89 $repo/test_wavs/2.wav \ 112 $repo/test_wavs/2.wav \
90 $repo/test_wavs/8k.wav 113 $repo/test_wavs/8k.wav
91 114
  115 +time $EXE \
  116 + --tokens=$repo/tokens.txt \
  117 + --paraformer=$repo/model.int8.onnx \
  118 + --num-threads=2 \
  119 + --decoding-method=greedy_search \
  120 + $repo/test_wavs/0.wav \
  121 + $repo/test_wavs/1.wav \
  122 + $repo/test_wavs/2.wav \
  123 + $repo/test_wavs/8k.wav
  124 +
92 rm -rf $repo 125 rm -rf $repo
@@ -25,12 +25,13 @@ log "Download pretrained model and test-data from $repo_url" @@ -25,12 +25,13 @@ log "Download pretrained model and test-data from $repo_url"
25 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url 25 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
26 pushd $repo 26 pushd $repo
27 git lfs pull --include "*.onnx" 27 git lfs pull --include "*.onnx"
  28 +ls -lh *.onnx
28 popd 29 popd
29 30
30 waves=( 31 waves=(
31 -$repo/test_wavs/1089-134686-0001.wav  
32 -$repo/test_wavs/1221-135766-0001.wav  
33 -$repo/test_wavs/1221-135766-0002.wav 32 +$repo/test_wavs/0.wav
  33 +$repo/test_wavs/1.wav
  34 +$repo/test_wavs/8k.wav
34 ) 35 )
35 36
36 for wave in ${waves[@]}; do 37 for wave in ${waves[@]}; do
@@ -43,6 +44,16 @@ for wave in ${waves[@]}; do @@ -43,6 +44,16 @@ for wave in ${waves[@]}; do
43 2 44 2
44 done 45 done
45 46
  47 +for wave in ${waves[@]}; do
  48 + time $EXE \
  49 + $repo/tokens.txt \
  50 + $repo/encoder-epoch-99-avg-1.int8.onnx \
  51 + $repo/decoder-epoch-99-avg-1.int8.onnx \
  52 + $repo/joiner-epoch-99-avg-1.int8.onnx \
  53 + $wave \
  54 + 2
  55 +done
  56 +
46 rm -rf $repo 57 rm -rf $repo
47 58
48 log "------------------------------------------------------------" 59 log "------------------------------------------------------------"
@@ -57,12 +68,13 @@ log "Download pretrained model and test-data from $repo_url" @@ -57,12 +68,13 @@ log "Download pretrained model and test-data from $repo_url"
57 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url 68 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
58 pushd $repo 69 pushd $repo
59 git lfs pull --include "*.onnx" 70 git lfs pull --include "*.onnx"
  71 +ls -lh *.onnx
60 popd 72 popd
61 73
62 waves=( 74 waves=(
63 $repo/test_wavs/0.wav 75 $repo/test_wavs/0.wav
64 $repo/test_wavs/1.wav 76 $repo/test_wavs/1.wav
65 -$repo/test_wavs/2.wav 77 +$repo/test_wavs/8k.wav
66 ) 78 )
67 79
68 for wave in ${waves[@]}; do 80 for wave in ${waves[@]}; do
@@ -75,6 +87,16 @@ for wave in ${waves[@]}; do @@ -75,6 +87,16 @@ for wave in ${waves[@]}; do
75 2 87 2
76 done 88 done
77 89
  90 +for wave in ${waves[@]}; do
  91 + time $EXE \
  92 + $repo/tokens.txt \
  93 + $repo/encoder-epoch-11-avg-1.int8.onnx \
  94 + $repo/decoder-epoch-11-avg-1.int8.onnx \
  95 + $repo/joiner-epoch-11-avg-1.int8.onnx \
  96 + $wave \
  97 + 2
  98 +done
  99 +
78 rm -rf $repo 100 rm -rf $repo
79 101
80 log "------------------------------------------------------------" 102 log "------------------------------------------------------------"
@@ -89,12 +111,13 @@ log "Download pretrained model and test-data from $repo_url" @@ -89,12 +111,13 @@ log "Download pretrained model and test-data from $repo_url"
89 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url 111 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
90 pushd $repo 112 pushd $repo
91 git lfs pull --include "*.onnx" 113 git lfs pull --include "*.onnx"
  114 +ls -lh *.onnx
92 popd 115 popd
93 116
94 waves=( 117 waves=(
95 -$repo/test_wavs/1089-134686-0001.wav  
96 -$repo/test_wavs/1221-135766-0001.wav  
97 -$repo/test_wavs/1221-135766-0002.wav 118 +$repo/test_wavs/0.wav
  119 +$repo/test_wavs/1.wav
  120 +$repo/test_wavs/8k.wav
98 ) 121 )
99 122
100 for wave in ${waves[@]}; do 123 for wave in ${waves[@]}; do
@@ -107,10 +130,22 @@ for wave in ${waves[@]}; do @@ -107,10 +130,22 @@ for wave in ${waves[@]}; do
107 2 130 2
108 done 131 done
109 132
  133 +# test int8
  134 +#
  135 +for wave in ${waves[@]}; do
  136 + time $EXE \
  137 + $repo/tokens.txt \
  138 + $repo/encoder-epoch-99-avg-1.int8.onnx \
  139 + $repo/decoder-epoch-99-avg-1.int8.onnx \
  140 + $repo/joiner-epoch-99-avg-1.int8.onnx \
  141 + $wave \
  142 + 2
  143 +done
  144 +
110 rm -rf $repo 145 rm -rf $repo
111 146
112 log "------------------------------------------------------------" 147 log "------------------------------------------------------------"
113 -log "Run streaming Zipformer transducer (Bilingual, Chinse + English)" 148 +log "Run streaming Zipformer transducer (Bilingual, Chinese + English)"
114 log "------------------------------------------------------------" 149 log "------------------------------------------------------------"
115 150
116 repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 151 repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
@@ -121,6 +156,7 @@ log "Download pretrained model and test-data from $repo_url" @@ -121,6 +156,7 @@ log "Download pretrained model and test-data from $repo_url"
121 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url 156 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
122 pushd $repo 157 pushd $repo
123 git lfs pull --include "*.onnx" 158 git lfs pull --include "*.onnx"
  159 +ls -lh *.onnx
124 popd 160 popd
125 161
126 waves=( 162 waves=(
@@ -128,7 +164,7 @@ $repo/test_wavs/0.wav @@ -128,7 +164,7 @@ $repo/test_wavs/0.wav
128 $repo/test_wavs/1.wav 164 $repo/test_wavs/1.wav
129 $repo/test_wavs/2.wav 165 $repo/test_wavs/2.wav
130 $repo/test_wavs/3.wav 166 $repo/test_wavs/3.wav
131 -$repo/test_wavs/4.wav 167 +$repo/test_wavs/8k.wav
132 ) 168 )
133 169
134 for wave in ${waves[@]}; do 170 for wave in ${waves[@]}; do
@@ -141,6 +177,16 @@ for wave in ${waves[@]}; do @@ -141,6 +177,16 @@ for wave in ${waves[@]}; do
141 2 177 2
142 done 178 done
143 179
  180 +for wave in ${waves[@]}; do
  181 + time $EXE \
  182 + $repo/tokens.txt \
  183 + $repo/encoder-epoch-99-avg-1.int8.onnx \
  184 + $repo/decoder-epoch-99-avg-1.int8.onnx \
  185 + $repo/joiner-epoch-99-avg-1.int8.onnx \
  186 + $wave \
  187 + 2
  188 +done
  189 +
144 # Decode a URL 190 # Decode a URL
145 if [ $EXE == "sherpa-onnx-ffmpeg" ]; then 191 if [ $EXE == "sherpa-onnx-ffmpeg" ]; then
146 time $EXE \ 192 time $EXE \
@@ -152,4 +198,14 @@ if [ $EXE == "sherpa-onnx-ffmpeg" ]; then @@ -152,4 +198,14 @@ if [ $EXE == "sherpa-onnx-ffmpeg" ]; then
152 2 198 2
153 fi 199 fi
154 200
  201 +if [ $EXE == "sherpa-onnx-ffmpeg" ]; then
  202 + time $EXE \
  203 + $repo/tokens.txt \
  204 + $repo/encoder-epoch-99-avg-1.int8.onnx \
  205 + $repo/decoder-epoch-99-avg-1.int8.onnx \
  206 + $repo/joiner-epoch-99-avg-1.int8.onnx \
  207 + https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/resolve/main/test_wavs/4.wav \
  208 + 2
  209 +fi
  210 +
155 rm -rf $repo 211 rm -rf $repo
@@ -46,3 +46,8 @@ run-sherpa-onnx-offline-paraformer.sh @@ -46,3 +46,8 @@ run-sherpa-onnx-offline-paraformer.sh
46 run-sherpa-onnx-offline-transducer.sh 46 run-sherpa-onnx-offline-transducer.sh
47 sherpa-onnx-paraformer-zh-2023-03-28 47 sherpa-onnx-paraformer-zh-2023-03-28
48 run-offline-websocket-server-paraformer.sh 48 run-offline-websocket-server-paraformer.sh
  49 +run-*int8.sh
  50 +a.sh
  51 +run-offline-websocket-client-*.sh
  52 +run-sherpa-onnx-*.sh
  53 +sherpa-onnx-zipformer-en-2023-03-30
@@ -18,139 +18,13 @@ @@ -18,139 +18,13 @@
18 #include <cstring> 18 #include <cstring>
19 #include <fstream> 19 #include <fstream>
20 #include <iomanip> 20 #include <iomanip>
21 -#include <limits>  
22 -#include <type_traits>  
23 -#include <unordered_map>  
24 21
25 #include "sherpa-onnx/csrc/log.h" 22 #include "sherpa-onnx/csrc/log.h"
26 -  
27 -#ifdef _MSC_VER  
28 -#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \  
29 - _strtoi64(cur_cstr, end_cstr, 10);  
30 -#else  
31 -#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10);  
32 -#endif 23 +#include "sherpa-onnx/csrc/macros.h"
  24 +#include "sherpa-onnx/csrc/text-utils.h"
33 25
34 namespace sherpa_onnx { 26 namespace sherpa_onnx {
35 27
36 -/// Converts a string into an integer via strtoll and returns false if there was  
37 -/// any kind of problem (i.e. the string was not an integer or contained extra  
38 -/// non-whitespace junk, or the integer was too large to fit into the type it is  
39 -/// being converted into). Only sets *out if everything was OK and it returns  
40 -/// true.  
41 -template <class Int>  
42 -bool ConvertStringToInteger(const std::string &str, Int *out) {  
43 - // copied from kaldi/src/util/text-util.h  
44 - static_assert(std::is_integral<Int>::value, "");  
45 - const char *this_str = str.c_str();  
46 - char *end = nullptr;  
47 - errno = 0;  
48 - int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end);  
49 - if (end != this_str) {  
50 - while (isspace(*end)) ++end;  
51 - }  
52 - if (end == this_str || *end != '\0' || errno != 0) return false;  
53 - Int iInt = static_cast<Int>(i);  
54 - if (static_cast<int64_t>(iInt) != i ||  
55 - (i < 0 && !std::numeric_limits<Int>::is_signed)) {  
56 - return false;  
57 - }  
58 - *out = iInt;  
59 - return true;  
60 -}  
61 -  
62 -// copied from kaldi/src/util/text-util.cc  
63 -template <class T>  
64 -class NumberIstream {  
65 - public:  
66 - explicit NumberIstream(std::istream &i) : in_(i) {}  
67 -  
68 - NumberIstream &operator>>(T &x) {  
69 - if (!in_.good()) return *this;  
70 - in_ >> x;  
71 - if (!in_.fail() && RemainderIsOnlySpaces()) return *this;  
72 - return ParseOnFail(&x);  
73 - }  
74 -  
75 - private:  
76 - std::istream &in_;  
77 -  
78 - bool RemainderIsOnlySpaces() {  
79 - if (in_.tellg() != std::istream::pos_type(-1)) {  
80 - std::string rem;  
81 - in_ >> rem;  
82 -  
83 - if (rem.find_first_not_of(' ') != std::string::npos) {  
84 - // there is not only spaces  
85 - return false;  
86 - }  
87 - }  
88 -  
89 - in_.clear();  
90 - return true;  
91 - }  
92 -  
93 - NumberIstream &ParseOnFail(T *x) {  
94 - std::string str;  
95 - in_.clear();  
96 - in_.seekg(0);  
97 - // If the stream is broken even before trying  
98 - // to read from it or if there are many tokens,  
99 - // it's pointless to try.  
100 - if (!(in_ >> str) || !RemainderIsOnlySpaces()) {  
101 - in_.setstate(std::ios_base::failbit);  
102 - return *this;  
103 - }  
104 -  
105 - std::unordered_map<std::string, T> inf_nan_map;  
106 - // we'll keep just uppercase values.  
107 - inf_nan_map["INF"] = std::numeric_limits<T>::infinity();  
108 - inf_nan_map["+INF"] = std::numeric_limits<T>::infinity();  
109 - inf_nan_map["-INF"] = -std::numeric_limits<T>::infinity();  
110 - inf_nan_map["INFINITY"] = std::numeric_limits<T>::infinity();  
111 - inf_nan_map["+INFINITY"] = std::numeric_limits<T>::infinity();  
112 - inf_nan_map["-INFINITY"] = -std::numeric_limits<T>::infinity();  
113 - inf_nan_map["NAN"] = std::numeric_limits<T>::quiet_NaN();  
114 - inf_nan_map["+NAN"] = std::numeric_limits<T>::quiet_NaN();  
115 - inf_nan_map["-NAN"] = -std::numeric_limits<T>::quiet_NaN();  
116 - // MSVC  
117 - inf_nan_map["1.#INF"] = std::numeric_limits<T>::infinity();  
118 - inf_nan_map["-1.#INF"] = -std::numeric_limits<T>::infinity();  
119 - inf_nan_map["1.#QNAN"] = std::numeric_limits<T>::quiet_NaN();  
120 - inf_nan_map["-1.#QNAN"] = -std::numeric_limits<T>::quiet_NaN();  
121 -  
122 - std::transform(str.begin(), str.end(), str.begin(), ::toupper);  
123 -  
124 - if (inf_nan_map.find(str) != inf_nan_map.end()) {  
125 - *x = inf_nan_map[str];  
126 - } else {  
127 - in_.setstate(std::ios_base::failbit);  
128 - }  
129 -  
130 - return *this;  
131 - }  
132 -};  
133 -  
134 -/// ConvertStringToReal converts a string into either float or double  
135 -/// and returns false if there was any kind of problem (i.e. the string  
136 -/// was not a floating point number or contained extra non-whitespace junk).  
137 -/// Be careful- this function will successfully read inf's or nan's.  
138 -template <typename T>  
139 -bool ConvertStringToReal(const std::string &str, T *out) {  
140 - std::istringstream iss(str);  
141 -  
142 - NumberIstream<T> i(iss);  
143 -  
144 - i >> *out;  
145 -  
146 - if (iss.fail()) {  
147 - // Number conversion failed.  
148 - return false;  
149 - }  
150 -  
151 - return true;  
152 -}  
153 -  
154 ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po) 28 ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po)
155 : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) { 29 : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) {
156 if (po != nullptr && po->other_parser_ != nullptr) { 30 if (po != nullptr && po->other_parser_ != nullptr) {
@@ -219,8 +93,8 @@ void ParseOptions::RegisterCommon(const std::string &name, T *ptr, @@ -219,8 +93,8 @@ void ParseOptions::RegisterCommon(const std::string &name, T *ptr,
219 std::string idx = name; 93 std::string idx = name;
220 NormalizeArgName(&idx); 94 NormalizeArgName(&idx);
221 if (doc_map_.find(idx) != doc_map_.end()) { 95 if (doc_map_.find(idx) != doc_map_.end()) {
222 - SHERPA_ONNX_LOG(WARNING)  
223 - << "Registering option twice, ignoring second time: " << name; 96 + SHERPA_ONNX_LOGE("Registering option twice, ignoring second time: %s",
  97 + name.c_str());
224 } else { 98 } else {
225 this->RegisterSpecific(name, idx, ptr, doc, is_standard); 99 this->RegisterSpecific(name, idx, ptr, doc, is_standard);
226 } 100 }
@@ -289,12 +163,13 @@ void ParseOptions::RegisterSpecific(const std::string &name, @@ -289,12 +163,13 @@ void ParseOptions::RegisterSpecific(const std::string &name,
289 163
290 void ParseOptions::DisableOption(const std::string &name) { 164 void ParseOptions::DisableOption(const std::string &name) {
291 if (argv_ != nullptr) { 165 if (argv_ != nullptr) {
292 - SHERPA_ONNX_LOG(FATAL)  
293 - << "DisableOption must not be called after calling Read()."; 166 + SHERPA_ONNX_LOGE("DisableOption must not be called after calling Read().");
  167 + exit(-1);
294 } 168 }
295 if (doc_map_.erase(name) == 0) { 169 if (doc_map_.erase(name) == 0) {
296 - SHERPA_ONNX_LOG(FATAL) << "Option " << name  
297 - << " was not registered so cannot be disabled: "; 170 + SHERPA_ONNX_LOGE("Option %s was not registered so cannot be disabled: ",
  171 + name.c_str());
  172 + exit(-1);
298 } 173 }
299 bool_map_.erase(name); 174 bool_map_.erase(name);
300 int_map_.erase(name); 175 int_map_.erase(name);
@@ -308,7 +183,8 @@ int ParseOptions::NumArgs() const { return positional_args_.size(); } @@ -308,7 +183,8 @@ int ParseOptions::NumArgs() const { return positional_args_.size(); }
308 183
309 std::string ParseOptions::GetArg(int i) const { 184 std::string ParseOptions::GetArg(int i) const {
310 if (i < 1 || i > static_cast<int>(positional_args_.size())) { 185 if (i < 1 || i > static_cast<int>(positional_args_.size())) {
311 - SHERPA_ONNX_LOG(FATAL) << "ParseOptions::GetArg, invalid index " << i; 186 + SHERPA_ONNX_LOGE("ParseOptions::GetArg, invalid index %d", i);
  187 + exit(-1);
312 } 188 }
313 189
314 return positional_args_[i - 1]; 190 return positional_args_[i - 1];
@@ -460,7 +336,8 @@ int ParseOptions::Read(int argc, const char *const argv[]) { @@ -460,7 +336,8 @@ int ParseOptions::Read(int argc, const char *const argv[]) {
460 Trim(&value); 336 Trim(&value);
461 if (!SetOption(key, value, has_equal_sign)) { 337 if (!SetOption(key, value, has_equal_sign)) {
462 PrintUsage(true); 338 PrintUsage(true);
463 - SHERPA_ONNX_LOG(FATAL) << "Invalid option " << argv[i]; 339 + SHERPA_ONNX_LOGE("Invalid option %s", argv[i]);
  340 + exit(-1);
464 } 341 }
465 } else { 342 } else {
466 break; 343 break;
@@ -481,7 +358,7 @@ int ParseOptions::Read(int argc, const char *const argv[]) { @@ -481,7 +358,7 @@ int ParseOptions::Read(int argc, const char *const argv[]) {
481 std::ostringstream strm; 358 std::ostringstream strm;
482 for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " "; 359 for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " ";
483 strm << '\n'; 360 strm << '\n';
484 - SHERPA_ONNX_LOG(INFO) << strm.str(); 361 + SHERPA_ONNX_LOGE("%s", strm.str().c_str());
485 } 362 }
486 return i; 363 return i;
487 } 364 }
@@ -522,7 +399,7 @@ void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const { @@ -522,7 +399,7 @@ void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const {
522 os << strm.str(); 399 os << strm.str();
523 } 400 }
524 401
525 - SHERPA_ONNX_LOG(INFO) << os.str(); 402 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
526 } 403 }
527 404
528 void ParseOptions::PrintConfig(std::ostream &os) const { 405 void ParseOptions::PrintConfig(std::ostream &os) const {
@@ -544,8 +421,9 @@ void ParseOptions::PrintConfig(std::ostream &os) const { @@ -544,8 +421,9 @@ void ParseOptions::PrintConfig(std::ostream &os) const {
544 } else if (string_map_.end() != string_map_.find(key)) { 421 } else if (string_map_.end() != string_map_.find(key)) {
545 os << "'" << *string_map_.at(key) << "'"; 422 os << "'" << *string_map_.at(key) << "'";
546 } else { 423 } else {
547 - SHERPA_ONNX_LOG(FATAL)  
548 - << "PrintConfig: unrecognized option " << key << "[code error]"; 424 + SHERPA_ONNX_LOGE("PrintConfig: unrecognized option %s [code error]",
  425 + key.c_str());
  426 + exit(-1);
549 } 427 }
550 os << '\n'; 428 os << '\n';
551 } 429 }
@@ -555,7 +433,8 @@ void ParseOptions::PrintConfig(std::ostream &os) const { @@ -555,7 +433,8 @@ void ParseOptions::PrintConfig(std::ostream &os) const {
555 void ParseOptions::ReadConfigFile(const std::string &filename) { 433 void ParseOptions::ReadConfigFile(const std::string &filename) {
556 std::ifstream is(filename.c_str(), std::ifstream::in); 434 std::ifstream is(filename.c_str(), std::ifstream::in);
557 if (!is.good()) { 435 if (!is.good()) {
558 - SHERPA_ONNX_LOG(FATAL) << "Cannot open config file: " << filename; 436 + SHERPA_ONNX_LOGE("Cannot open config file: %s", filename.c_str());
  437 + exit(-1);
559 } 438 }
560 439
561 std::string line, key, value; 440 std::string line, key, value;
@@ -572,12 +451,13 @@ void ParseOptions::ReadConfigFile(const std::string &filename) { @@ -572,12 +451,13 @@ void ParseOptions::ReadConfigFile(const std::string &filename) {
572 if (line.length() == 0) continue; 451 if (line.length() == 0) continue;
573 452
574 if (line.substr(0, 2) != "--") { 453 if (line.substr(0, 2) != "--") {
575 - SHERPA_ONNX_LOG(FATAL)  
576 - << "Reading config file " << filename << ": line " << line_number  
577 - << " does not look like a line "  
578 - << "from a Kaldi command-line program's config file: should "  
579 - << "be of the form --x=y. Note: config files intended to "  
580 - << "be sourced by shell scripts lack the '--'."; 454 + SHERPA_ONNX_LOGE(
  455 + "Reading config file %s: line %d does not look like a line "
  456 + "from a sherpa-onnx command-line program's config file: should "
  457 + "be of the form --x=y. Note: config files intended to "
  458 + "be sourced by shell scripts lack the '--'.",
  459 + filename.c_str(), line_number);
  460 + exit(-1);
581 } 461 }
582 462
583 // parse option 463 // parse option
@@ -587,8 +467,9 @@ void ParseOptions::ReadConfigFile(const std::string &filename) { @@ -587,8 +467,9 @@ void ParseOptions::ReadConfigFile(const std::string &filename) {
587 Trim(&value); 467 Trim(&value);
588 if (!SetOption(key, value, has_equal_sign)) { 468 if (!SetOption(key, value, has_equal_sign)) {
589 PrintUsage(true); 469 PrintUsage(true);
590 - SHERPA_ONNX_LOG(FATAL) << "Invalid option " << line << " in config file "  
591 - << filename << ": line " << line_number; 470 + SHERPA_ONNX_LOGE("Invalid option %s in config file %s: line %d",
  471 + line.c_str(), filename.c_str(), line_number);
  472 + exit(-1);
592 } 473 }
593 } 474 }
594 } 475 }
@@ -605,7 +486,8 @@ void ParseOptions::SplitLongArg(const std::string &in, std::string *key, @@ -605,7 +486,8 @@ void ParseOptions::SplitLongArg(const std::string &in, std::string *key,
605 *has_equal_sign = false; 486 *has_equal_sign = false;
606 } else if (pos == 2) { // we also don't allow empty keys: --=value 487 } else if (pos == 2) { // we also don't allow empty keys: --=value
607 PrintUsage(true); 488 PrintUsage(true);
608 - SHERPA_ONNX_LOG(FATAL) << "Invalid option (no key): " << in; 489 + SHERPA_ONNX_LOGE("Invalid option (no key): %s", in.c_str());
  490 + exit(-1);
609 } else { // normal case: --option=value 491 } else { // normal case: --option=value
610 *key = in.substr(2, pos - 2); // 2 because starts with --. 492 *key = in.substr(2, pos - 2); // 2 because starts with --.
611 *value = in.substr(pos + 1); 493 *value = in.substr(pos + 1);
@@ -646,7 +528,8 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value, @@ -646,7 +528,8 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value,
646 bool has_equal_sign) { 528 bool has_equal_sign) {
647 if (bool_map_.end() != bool_map_.find(key)) { 529 if (bool_map_.end() != bool_map_.find(key)) {
648 if (has_equal_sign && value == "") { 530 if (has_equal_sign && value == "") {
649 - SHERPA_ONNX_LOG(FATAL) << "Invalid option --" << key << "="; 531 + SHERPA_ONNX_LOGE("Invalid option --%s=", key.c_str());
  532 + exit(-1);
650 } 533 }
651 *(bool_map_[key]) = ToBool(value); 534 *(bool_map_[key]) = ToBool(value);
652 } else if (int_map_.end() != int_map_.find(key)) { 535 } else if (int_map_.end() != int_map_.find(key)) {
@@ -659,8 +542,9 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value, @@ -659,8 +542,9 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value,
659 *(double_map_[key]) = ToDouble(value); 542 *(double_map_[key]) = ToDouble(value);
660 } else if (string_map_.end() != string_map_.find(key)) { 543 } else if (string_map_.end() != string_map_.find(key)) {
661 if (!has_equal_sign) { 544 if (!has_equal_sign) {
662 - SHERPA_ONNX_LOG(FATAL)  
663 - << "Invalid option --" << key << " (option format is --x=y)."; 545 + SHERPA_ONNX_LOGE("Invalid option --%s (option format is --x=y).",
  546 + key.c_str());
  547 + exit(-1);
664 } 548 }
665 *(string_map_[key]) = value; 549 *(string_map_[key]) = value;
666 } else { 550 } else {
@@ -683,37 +567,46 @@ bool ParseOptions::ToBool(std::string str) const { @@ -683,37 +567,46 @@ bool ParseOptions::ToBool(std::string str) const {
683 } 567 }
684 // if it is neither true nor false: 568 // if it is neither true nor false:
685 PrintUsage(true); 569 PrintUsage(true);
686 - SHERPA_ONNX_LOG(FATAL)  
687 - << "Invalid format for boolean argument [expected true or false]: "  
688 - << str; 570 + SHERPA_ONNX_LOGE(
  571 + "Invalid format for boolean argument [expected true or false]: %s",
  572 + str.c_str());
  573 + exit(-1);
689 return false; // never reached 574 return false; // never reached
690 } 575 }
691 576
692 int32_t ParseOptions::ToInt(const std::string &str) const { 577 int32_t ParseOptions::ToInt(const std::string &str) const {
693 int32_t ret = 0; 578 int32_t ret = 0;
694 - if (!ConvertStringToInteger(str, &ret))  
695 - SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\""; 579 + if (!ConvertStringToInteger(str, &ret)) {
  580 + SHERPA_ONNX_LOGE("Invalid integer option \"%s\"", str.c_str());
  581 + exit(-1);
  582 + }
696 return ret; 583 return ret;
697 } 584 }
698 585
699 uint32_t ParseOptions::ToUint(const std::string &str) const { 586 uint32_t ParseOptions::ToUint(const std::string &str) const {
700 uint32_t ret = 0; 587 uint32_t ret = 0;
701 - if (!ConvertStringToInteger(str, &ret))  
702 - SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\""; 588 + if (!ConvertStringToInteger(str, &ret)) {
  589 + SHERPA_ONNX_LOGE("Invalid integer option \"%s\"", str.c_str());
  590 + exit(-1);
  591 + }
703 return ret; 592 return ret;
704 } 593 }
705 594
706 float ParseOptions::ToFloat(const std::string &str) const { 595 float ParseOptions::ToFloat(const std::string &str) const {
707 float ret; 596 float ret;
708 - if (!ConvertStringToReal(str, &ret))  
709 - SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\""; 597 + if (!ConvertStringToReal(str, &ret)) {
  598 + SHERPA_ONNX_LOGE("Invalid floating-point option \"%s\"", str.c_str());
  599 + exit(-1);
  600 + }
710 return ret; 601 return ret;
711 } 602 }
712 603
713 double ParseOptions::ToDouble(const std::string &str) const { 604 double ParseOptions::ToDouble(const std::string &str) const {
714 double ret; 605 double ret;
715 - if (!ConvertStringToReal(str, &ret))  
716 - SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\""; 606 + if (!ConvertStringToReal(str, &ret)) {
  607 + SHERPA_ONNX_LOGE("Invalid floating-point option \"%s\"", str.c_str());
  608 + exit(-1);
  609 + }
717 return ret; 610 return ret;
718 } 611 }
719 612
@@ -7,7 +7,11 @@ @@ -7,7 +7,11 @@
7 7
8 #include <assert.h> 8 #include <assert.h>
9 9
  10 +#include <algorithm>
  11 +#include <limits>
  12 +#include <sstream>
10 #include <string> 13 #include <string>
  14 +#include <unordered_map>
11 #include <vector> 15 #include <vector>
12 16
13 // This file is copied/modified from 17 // This file is copied/modified from
@@ -15,6 +19,102 @@ @@ -15,6 +19,102 @@
15 19
16 namespace sherpa_onnx { 20 namespace sherpa_onnx {
17 21
  22 +// copied from kaldi/src/util/text-util.cc
  23 +template <class T>
  24 +class NumberIstream {
  25 + public:
  26 + explicit NumberIstream(std::istream &i) : in_(i) {}
  27 +
  28 + NumberIstream &operator>>(T &x) {
  29 + if (!in_.good()) return *this;
  30 + in_ >> x;
  31 + if (!in_.fail() && RemainderIsOnlySpaces()) return *this;
  32 + return ParseOnFail(&x);
  33 + }
  34 +
  35 + private:
  36 + std::istream &in_;
  37 +
  38 + bool RemainderIsOnlySpaces() {
  39 + if (in_.tellg() != std::istream::pos_type(-1)) {
  40 + std::string rem;
  41 + in_ >> rem;
  42 +
  43 + if (rem.find_first_not_of(' ') != std::string::npos) {
  44 + // there is not only spaces
  45 + return false;
  46 + }
  47 + }
  48 +
  49 + in_.clear();
  50 + return true;
  51 + }
  52 +
  53 + NumberIstream &ParseOnFail(T *x) {
  54 + std::string str;
  55 + in_.clear();
  56 + in_.seekg(0);
  57 + // If the stream is broken even before trying
  58 + // to read from it or if there are many tokens,
  59 + // it's pointless to try.
  60 + if (!(in_ >> str) || !RemainderIsOnlySpaces()) {
  61 + in_.setstate(std::ios_base::failbit);
  62 + return *this;
  63 + }
  64 +
  65 + std::unordered_map<std::string, T> inf_nan_map;
  66 + // we'll keep just uppercase values.
  67 + inf_nan_map["INF"] = std::numeric_limits<T>::infinity();
  68 + inf_nan_map["+INF"] = std::numeric_limits<T>::infinity();
  69 + inf_nan_map["-INF"] = -std::numeric_limits<T>::infinity();
  70 + inf_nan_map["INFINITY"] = std::numeric_limits<T>::infinity();
  71 + inf_nan_map["+INFINITY"] = std::numeric_limits<T>::infinity();
  72 + inf_nan_map["-INFINITY"] = -std::numeric_limits<T>::infinity();
  73 + inf_nan_map["NAN"] = std::numeric_limits<T>::quiet_NaN();
  74 + inf_nan_map["+NAN"] = std::numeric_limits<T>::quiet_NaN();
  75 + inf_nan_map["-NAN"] = -std::numeric_limits<T>::quiet_NaN();
  76 + // MSVC
  77 + inf_nan_map["1.#INF"] = std::numeric_limits<T>::infinity();
  78 + inf_nan_map["-1.#INF"] = -std::numeric_limits<T>::infinity();
  79 + inf_nan_map["1.#QNAN"] = std::numeric_limits<T>::quiet_NaN();
  80 + inf_nan_map["-1.#QNAN"] = -std::numeric_limits<T>::quiet_NaN();
  81 +
  82 + std::transform(str.begin(), str.end(), str.begin(), ::toupper);
  83 +
  84 + if (inf_nan_map.find(str) != inf_nan_map.end()) {
  85 + *x = inf_nan_map[str];
  86 + } else {
  87 + in_.setstate(std::ios_base::failbit);
  88 + }
  89 +
  90 + return *this;
  91 + }
  92 +};
  93 +
  94 +/// ConvertStringToReal converts a string into either float or double
  95 +/// and returns false if there was any kind of problem (i.e. the string
  96 +/// was not a floating point number or contained extra non-whitespace junk).
  97 +/// Be careful- this function will successfully read inf's or nan's.
  98 +template <typename T>
  99 +bool ConvertStringToReal(const std::string &str, T *out) {
  100 + std::istringstream iss(str);
  101 +
  102 + NumberIstream<T> i(iss);
  103 +
  104 + i >> *out;
  105 +
  106 + if (iss.fail()) {
  107 + // Number conversion failed.
  108 + return false;
  109 + }
  110 +
  111 + return true;
  112 +}
  113 +
  114 +template bool ConvertStringToReal<float>(const std::string &str, float *out);
  115 +
  116 +template bool ConvertStringToReal<double>(const std::string &str, double *out);
  117 +
18 void SplitStringToVector(const std::string &full, const char *delim, 118 void SplitStringToVector(const std::string &full, const char *delim,
19 bool omit_empty_strings, 119 bool omit_empty_strings,
20 std::vector<std::string> *out) { 120 std::vector<std::string> *out) {
@@ -43,7 +143,9 @@ bool SplitStringToFloats(const std::string &full, const char *delim, @@ -43,7 +143,9 @@ bool SplitStringToFloats(const std::string &full, const char *delim,
43 out->resize(split.size()); 143 out->resize(split.size());
44 for (size_t i = 0; i < split.size(); ++i) { 144 for (size_t i = 0; i < split.size(); ++i) {
45 // assume atof never fails 145 // assume atof never fails
46 - (*out)[i] = atof(split[i].c_str()); 146 + F f = 0;
  147 + if (!ConvertStringToReal(split[i], &f)) return false;
  148 + (*out)[i] = f;
47 } 149 }
48 return true; 150 return true;
49 } 151 }
@@ -6,7 +6,9 @@ @@ -6,7 +6,9 @@
6 #define SHERPA_ONNX_CSRC_TEXT_UTILS_H_ 6 #define SHERPA_ONNX_CSRC_TEXT_UTILS_H_
7 #include <stdlib.h> 7 #include <stdlib.h>
8 8
  9 +#include <limits>
9 #include <string> 10 #include <string>
  11 +#include <type_traits>
10 #include <vector> 12 #include <vector>
11 13
12 #ifdef _MSC_VER 14 #ifdef _MSC_VER
@@ -21,6 +23,32 @@ @@ -21,6 +23,32 @@
21 23
22 namespace sherpa_onnx { 24 namespace sherpa_onnx {
23 25
  26 +/// Converts a string into an integer via strtoll and returns false if there was
  27 +/// any kind of problem (i.e. the string was not an integer or contained extra
  28 +/// non-whitespace junk, or the integer was too large to fit into the type it is
  29 +/// being converted into). Only sets *out if everything was OK and it returns
  30 +/// true.
  31 +template <class Int>
  32 +bool ConvertStringToInteger(const std::string &str, Int *out) {
  33 + // copied from kaldi/src/util/text-util.h
  34 + static_assert(std::is_integral<Int>::value, "");
  35 + const char *this_str = str.c_str();
  36 + char *end = nullptr;
  37 + errno = 0;
  38 + int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end);
  39 + if (end != this_str) {
  40 + while (isspace(*end)) ++end;
  41 + }
  42 + if (end == this_str || *end != '\0' || errno != 0) return false;
  43 + Int iInt = static_cast<Int>(i);
  44 + if (static_cast<int64_t>(iInt) != i ||
  45 + (i < 0 && !std::numeric_limits<Int>::is_signed)) {
  46 + return false;
  47 + }
  48 + *out = iInt;
  49 + return true;
  50 +}
  51 +
24 /// Split a string using any of the single character delimiters. 52 /// Split a string using any of the single character delimiters.
25 /// If omit_empty_strings == true, the output will contain any 53 /// If omit_empty_strings == true, the output will contain any
26 /// nonempty strings after splitting on any of the 54 /// nonempty strings after splitting on any of the
@@ -86,6 +114,10 @@ bool SplitStringToFloats(const std::string &full, const char *delim, @@ -86,6 +114,10 @@ bool SplitStringToFloats(const std::string &full, const char *delim,
86 bool omit_empty_strings, // typically false 114 bool omit_empty_strings, // typically false
87 std::vector<F> *out); 115 std::vector<F> *out);
88 116
  117 +// This is defined for T = float and double.
  118 +template <typename T>
  119 +bool ConvertStringToReal(const std::string &str, T *out);
  120 +
89 } // namespace sherpa_onnx 121 } // namespace sherpa_onnx
90 122
91 #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ 123 #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_