Jingzhao Ou
Committed by GitHub

Added provider option to sherpa-onnx and decode-file-c-api (#162)

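With this change, the sherpa-onnx binary and the decode-file-c-api example take named command-line flags (parsed with ParseOptions and cargs, respectively) instead of positional arguments, and both gain a --provider option (cpu, cuda, or coreml). As a rough sketch, the C API example is now invoked roughly like this; the model and wave paths are placeholders, and --provider/--decoding-method can be omitted to keep their defaults:

  ./bin/decode-file-c-api \
    --tokens=/path/to/tokens.txt \
    --encoder=/path/to/encoder.onnx \
    --decoder=/path/to/decoder.onnx \
    --joiner=/path/to/joiner.onnx \
    --num-threads=2 \
    --provider=cpu \
    --decoding-method=greedy_search \
    /path/to/foo.wav
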
@@ -36,22 +36,22 @@ $repo/test_wavs/8k.wav
 
 for wave in ${waves[@]}; do
   time $EXE \
-    $repo/tokens.txt \
-    $repo/encoder-epoch-99-avg-1.onnx \
-    $repo/decoder-epoch-99-avg-1.onnx \
-    $repo/joiner-epoch-99-avg-1.onnx \
-    $wave \
-    2
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder-epoch-99-avg-1.onnx \
+    --decoder=$repo/decoder-epoch-99-avg-1.onnx \
+    --joiner=$repo/joiner-epoch-99-avg-1.onnx \
+    --num-threads=2 \
+    $wave
 done
 
 for wave in ${waves[@]}; do
   time $EXE \
-    $repo/tokens.txt \
-    $repo/encoder-epoch-99-avg-1.int8.onnx \
-    $repo/decoder-epoch-99-avg-1.int8.onnx \
-    $repo/joiner-epoch-99-avg-1.int8.onnx \
-    $wave \
-    2
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
+    --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
+    --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
+    --num-threads=2 \
+    $wave
 done
 
 rm -rf $repo
@@ -79,22 +79,22 @@ $repo/test_wavs/8k.wav
 
 for wave in ${waves[@]}; do
   time $EXE \
-    $repo/tokens.txt \
-    $repo/encoder-epoch-11-avg-1.onnx \
-    $repo/decoder-epoch-11-avg-1.onnx \
-    $repo/joiner-epoch-11-avg-1.onnx \
-    $wave \
-    2
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder-epoch-11-avg-1.onnx \
+    --decoder=$repo/decoder-epoch-11-avg-1.onnx \
+    --joiner=$repo/joiner-epoch-11-avg-1.onnx \
+    --num-threads=2 \
+    $wave
 done
 
 for wave in ${waves[@]}; do
   time $EXE \
-    $repo/tokens.txt \
-    $repo/encoder-epoch-11-avg-1.int8.onnx \
-    $repo/decoder-epoch-11-avg-1.int8.onnx \
-    $repo/joiner-epoch-11-avg-1.int8.onnx \
-    $wave \
-    2
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder-epoch-11-avg-1.int8.onnx \
+    --decoder=$repo/decoder-epoch-11-avg-1.int8.onnx \
+    --joiner=$repo/joiner-epoch-11-avg-1.int8.onnx \
+    --num-threads=2 \
+    $wave
 done
 
 rm -rf $repo
@@ -122,24 +122,24 @@ $repo/test_wavs/8k.wav
 
 for wave in ${waves[@]}; do
   time $EXE \
-    $repo/tokens.txt \
-    $repo/encoder-epoch-99-avg-1.onnx \
-    $repo/decoder-epoch-99-avg-1.onnx \
-    $repo/joiner-epoch-99-avg-1.onnx \
-    $wave \
-    2
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder-epoch-99-avg-1.onnx \
+    --decoder=$repo/decoder-epoch-99-avg-1.onnx \
+    --joiner=$repo/joiner-epoch-99-avg-1.onnx \
+    --num-threads=2 \
+    $wave
 done
 
 # test int8
 #
 for wave in ${waves[@]}; do
   time $EXE \
-    $repo/tokens.txt \
-    $repo/encoder-epoch-99-avg-1.int8.onnx \
-    $repo/decoder-epoch-99-avg-1.int8.onnx \
-    $repo/joiner-epoch-99-avg-1.int8.onnx \
-    $wave \
-    2
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
+    --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
+    --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
+    --num-threads=2 \
+    $wave
 done
 
 rm -rf $repo
@@ -169,22 +169,22 @@ $repo/test_wavs/8k.wav
 
 for wave in ${waves[@]}; do
   time $EXE \
-    $repo/tokens.txt \
-    $repo/encoder-epoch-99-avg-1.onnx \
-    $repo/decoder-epoch-99-avg-1.onnx \
-    $repo/joiner-epoch-99-avg-1.onnx \
-    $wave \
-    2
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder-epoch-99-avg-1.onnx \
+    --decoder=$repo/decoder-epoch-99-avg-1.onnx \
+    --joiner=$repo/joiner-epoch-99-avg-1.onnx \
+    --num-threads=2 \
+    $wave
 done
 
 for wave in ${waves[@]}; do
   time $EXE \
-    $repo/tokens.txt \
-    $repo/encoder-epoch-99-avg-1.int8.onnx \
-    $repo/decoder-epoch-99-avg-1.int8.onnx \
-    $repo/joiner-epoch-99-avg-1.int8.onnx \
-    $wave \
-    2
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
+    --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
+    --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
+    --num-threads=2 \
+    $wave
 done
 
 # Decode a URL
@@ -233,22 +233,22 @@ $repo/test_wavs/2.wav
 
 for wave in ${waves[@]}; do
   time $EXE \
-    $repo/tokens.txt \
-    $repo/encoder-epoch-99-avg-1.onnx \
-    $repo/decoder-epoch-99-avg-1.onnx \
-    $repo/joiner-epoch-99-avg-1.onnx \
-    $wave \
-    2
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder-epoch-99-avg-1.onnx \
+    --decoder=$repo/decoder-epoch-99-avg-1.onnx \
+    --joiner=$repo/joiner-epoch-99-avg-1.onnx \
+    --num-threads=2 \
+    $wave
 done
 
 for wave in ${waves[@]}; do
   time $EXE \
-    $repo/tokens.txt \
-    $repo/encoder-epoch-99-avg-1.int8.onnx \
-    $repo/decoder-epoch-99-avg-1.int8.onnx \
-    $repo/joiner-epoch-99-avg-1.int8.onnx \
-    $wave \
-    2
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
+    --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
+    --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
+    --num-threads=2 \
+    $wave
 done
 
 rm -rf $repo
@@ -1,3 +1,5 @@
+include(cargs)
+
 include_directories(${CMAKE_SOURCE_DIR})
 add_executable(decode-file-c-api decode-file-c-api.c)
-target_link_libraries(decode-file-c-api sherpa-onnx-c-api)
+target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs)
@@ -5,50 +5,85 @@
 // This file shows how to use sherpa-onnx C API
 // to decode a file.
 
+#include "cargs.h"
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
 
 #include "sherpa-onnx/c-api/c-api.h"
 
+static struct cag_option options[] = {
+  {
+    .identifier = 't',
+    .access_letters = NULL,
+    .access_name = "tokens",
+    .value_name = "tokens",
+    .description = "Tokens file"
+  }, {
+    .identifier = 'e',
+    .access_letters = NULL,
+    .access_name = "encoder",
+    .value_name = "encoder",
+    .description = "Encoder ONNX file"
+  }, {
+    .identifier = 'd',
+    .access_letters = NULL,
+    .access_name = "decoder",
+    .value_name = "decoder",
+    .description = "Decoder ONNX file"
+  }, {
+    .identifier = 'j',
+    .access_letters = NULL,
+    .access_name = "joiner",
+    .value_name = "joiner",
+    .description = "Joiner ONNX file"
+  }, {
+    .identifier = 'n',
+    .access_letters = NULL,
+    .access_name = "num-threads",
+    .value_name = "num-threads",
+    .description = "Number of threads"
+  }, {
+    .identifier = 'p',
+    .access_letters = NULL,
+    .access_name = "provider",
+    .value_name = "provider",
+    .description = "Provider: cpu (default), cuda, coreml"
+  }, {
+    .identifier = 'm',
+    .access_letters = NULL,
+    .access_name = "decoding-method",
+    .value_name = "decoding-method",
+    .description =
+        "Decoding method: greedy_search (default), modified_beam_search"
+  }
+};
+
 const char *kUsage =
     "\n"
     "Usage:\n "
     " ./bin/decode-file-c-api \\\n"
-    " /path/to/tokens.txt \\\n"
-    " /path/to/encoder.onnx \\\n"
-    " /path/to/decoder.onnx \\\n"
-    " /path/to/joiner.onnx \\\n"
-    " /path/to/foo.wav [num_threads [decoding_method]]\n"
+    " --tokens=/path/to/tokens.txt \\\n"
+    " --encoder=/path/to/encoder.onnx \\\n"
+    " --decoder=/path/to/decoder.onnx \\\n"
+    " --joiner=/path/to/joiner.onnx \\\n"
+    " /path/to/foo.wav\n"
     "\n\n"
     "Default num_threads is 1.\n"
     "Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
+    "Valid provider: cpu (default), cuda, coreml\n\n"
     "Please refer to \n"
     "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
     "for a list of pre-trained models to download.\n";
 
 int32_t main(int32_t argc, char *argv[]) {
-  if (argc < 6 || argc > 8) {
-    fprintf(stderr, "%s\n", kUsage);
-    return -1;
-  }
   SherpaOnnxOnlineRecognizerConfig config;
-  config.model_config.tokens = argv[1];
-  config.model_config.encoder = argv[2];
-  config.model_config.decoder = argv[3];
-  config.model_config.joiner = argv[4];
-
-  int32_t num_threads = 1;
-  if (argc == 7 && atoi(argv[6]) > 0) {
-    num_threads = atoi(argv[6]);
-  }
-  config.model_config.num_threads = num_threads;
+
   config.model_config.debug = 0;
+  config.model_config.num_threads = 1;
+  config.model_config.provider = "cpu";
 
   config.decoding_method = "greedy_search";
-  if (argc == 8) {
-    config.decoding_method = argv[7];
-  }
 
   config.max_active_paths = 4;
 
@@ -60,13 +95,36 @@ int32_t main(int32_t argc, char *argv[]) {
   config.rule2_min_trailing_silence = 1.2;
   config.rule3_min_utterance_length = 300;
 
+  cag_option_context context;
+  char identifier;
+  const char *value;
+
+  cag_option_prepare(&context, options, CAG_ARRAY_SIZE(options), argc, argv);
+
+  while (cag_option_fetch(&context)) {
+    identifier = cag_option_get(&context);
+    value = cag_option_get_value(&context);
+    switch (identifier) {
+      case 't': config.model_config.tokens = value; break;
+      case 'e': config.model_config.encoder = value; break;
+      case 'd': config.model_config.decoder = value; break;
+      case 'j': config.model_config.joiner = value; break;
+      case 'n': config.model_config.num_threads = atoi(value); break;
+      case 'p': config.model_config.provider = value; break;
+      case 'm': config.decoding_method = value; break;
+      default:
+        // do nothing as config already has valid default values
+        break;
+    }
+  }
+
   SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
   SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
 
   SherpaOnnxDisplay *display = CreateDisplay(50);
   int32_t segment_id = 0;
 
-  const char *wav_filename = argv[5];
+  const char *wav_filename = argv[context.index];
   FILE *fp = fopen(wav_filename, "rb");
   if (!fp) {
     fprintf(stderr, "Failed to open %s\n", wav_filename);
@@ -0,0 +1,36 @@
+function(download_cargs)
+  include(FetchContent)
+
+  set(cargs_URL  "https://github.com/likle/cargs/archive/refs/tags/v1.0.3.tar.gz")
+  set(cargs_HASH "SHA256=ddba25bd35e9c6c75bc706c126001b8ce8e084d40ef37050e6aa6963e836eb8b")
+
+  # If you don't have access to the Internet,
+  # please pre-download cargs
+  set(possible_file_locations
+    $ENV{HOME}/Downloads/cargs-v1-0-3.tar.gz
+    ${PROJECT_SOURCE_DIR}/cargs-v1-0-3.tar.gz
+    ${PROJECT_BINARY_DIR}/cargs-v1-0-3.tar.gz
+    /tmp/cargs-v1-0-3.tar.gz
+    /star-fj/fangjun/download/github/cargs-v1-0-3.tar.gz
+  )
+
+  foreach(f IN LISTS possible_file_locations)
+    if(EXISTS ${f})
+      set(cargs_URL  "${f}")
+      file(TO_CMAKE_PATH "${cargs_URL}" cargs_URL)
+      break()
+    endif()
+  endforeach()
+
+  FetchContent_Declare(cargs URL ${cargs_URL} URL_HASH ${cargs_HASH})
+
+  FetchContent_GetProperties(cargs)
+  if(NOT cargs_POPULATED)
+    message(STATUS "Downloading cargs ${cargs_URL}")
+    FetchContent_Populate(cargs)
+  endif()
+  message(STATUS "cargs is downloaded to ${cargs_SOURCE_DIR}")
+  add_subdirectory(${cargs_SOURCE_DIR} ${cargs_BINARY_DIR} EXCLUDE_FROM_ALL)
+endfunction()
+
+download_cargs()
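
The new cargs.cmake above fetches cargs v1.0.3 with FetchContent and, for machines without Internet access, first checks a few local paths. A minimal sketch of pre-downloading the archive into one of those locations (the target file name and URL come from the CMake code above; wget is just one way to fetch it):

  wget -O ~/Downloads/cargs-v1-0-3.tar.gz \
    https://github.com/likle/cargs/archive/refs/tags/v1.0.3.tar.gz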
@@ -41,6 +41,7 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
   recognizer_config.model_config.joiner_filename = config->model_config.joiner;
   recognizer_config.model_config.tokens = config->model_config.tokens;
   recognizer_config.model_config.num_threads = config->model_config.num_threads;
+  recognizer_config.model_config.provider = config->model_config.provider;
   recognizer_config.model_config.debug = config->model_config.debug;
 
   recognizer_config.decoding_method = config->decoding_method;
@@ -52,6 +52,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig {
   const char *joiner;
   const char *tokens;
   int32_t num_threads;
+  const char *provider;
   int32_t debug;  // true to print debug information of the model
 } SherpaOnnxOnlineTransducerModelConfig;
 
@@ -17,6 +17,8 @@ void OnlineTransducerModelConfig::Register(ParseOptions *po) {
   po->Register("tokens", &tokens, "Path to tokens.txt");
   po->Register("num_threads", &num_threads,
                "Number of threads to run the neural network");
+  po->Register("provider", &provider,
+               "Specify a provider to use: cpu, cuda, coreml");
 
   po->Register("debug", &debug,
                "true to print model information while loading it.");
@@ -60,6 +62,7 @@ std::string OnlineTransducerModelConfig::ToString() const {
   os << "joiner_filename=\"" << joiner_filename << "\", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
+  os << "provider=\"" << provider << "\", ";
   os << "debug=" << (debug ? "True" : "False") << ")";
 
   return os.str();
@@ -69,17 +69,17 @@ for a list of pre-trained models to download.
   fprintf(stderr, "Creating recognizer ...\n");
   sherpa_onnx::OfflineRecognizer recognizer(config);
 
-  auto begin = std::chrono::steady_clock::now();
+  const auto begin = std::chrono::steady_clock::now();
   fprintf(stderr, "Started\n");
 
   std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss;
   std::vector<sherpa_onnx::OfflineStream *> ss_pointers;
   float duration = 0;
   for (int32_t i = 1; i <= po.NumArgs(); ++i) {
-    std::string wav_filename = po.GetArg(i);
+    const std::string wav_filename = po.GetArg(i);
     int32_t sampling_rate = -1;
     bool is_ok = false;
-    std::vector<float> samples =
+    const std::vector<float> samples =
         sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
     if (!is_ok) {
       fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
@@ -96,7 +96,7 @@ for a list of pre-trained models to download.
 
   recognizer.DecodeStreams(ss_pointers.data(), ss_pointers.size());
 
-  auto end = std::chrono::steady_clock::now();
+  const auto end = std::chrono::steady_clock::now();
 
   fprintf(stderr, "Done!\n\n");
   for (int32_t i = 1; i <= po.NumArgs(); ++i) {
@@ -11,22 +11,28 @@
 #include "sherpa-onnx/csrc/online-recognizer.h"
 #include "sherpa-onnx/csrc/online-stream.h"
 #include "sherpa-onnx/csrc/symbol-table.h"
+#include "sherpa-onnx/csrc/parse-options.h"
 #include "sherpa-onnx/csrc/wave-reader.h"
 
-// TODO(fangjun): Use ParseOptions as we are getting more args
 int main(int32_t argc, char *argv[]) {
-  if (argc < 6 || argc > 9) {
-    const char *usage = R"usage(
+  const char *kUsageMessage = R"usage(
 Usage:
+
   ./bin/sherpa-onnx \
-    /path/to/tokens.txt \
-    /path/to/encoder.onnx \
-    /path/to/decoder.onnx \
-    /path/to/joiner.onnx \
-    /path/to/foo.wav [num_threads [decoding_method [/path/to/rnn_lm.onnx]]]
+    --tokens=/path/to/tokens.txt \
+    --encoder=/path/to/encoder.onnx \
+    --decoder=/path/to/decoder.onnx \
+    --joiner=/path/to/joiner.onnx \
+    --provider=cpu \
+    --num-threads=2 \
+    --decoding-method=greedy_search \
+    /path/to/foo.wav [bar.wav foobar.wav ...]
+
+Note: It supports decoding multiple files in batches
 
 Default value for num_threads is 2.
 Valid values for decoding_method: greedy_search (default), modified_beam_search.
+Valid values for provider: cpu (default), cuda, coreml.
 foo.wav should be of single channel, 16-bit PCM encoded wave file; its
 sampling rate can be arbitrary and does not need to be 16kHz.
 
@@ -34,33 +40,17 @@ Please refer to
 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
 for a list of pre-trained models to download.
 )usage";
-    fprintf(stderr, "%s\n", usage);
-
-    return 0;
-  }
 
+  sherpa_onnx::ParseOptions po(kUsageMessage);
   sherpa_onnx::OnlineRecognizerConfig config;
 
-  config.model_config.tokens = argv[1];
+  config.Register(&po);
 
-  config.model_config.debug = false;
-  config.model_config.encoder_filename = argv[2];
-  config.model_config.decoder_filename = argv[3];
-  config.model_config.joiner_filename = argv[4];
-
-  std::string wav_filename = argv[5];
-
-  config.model_config.num_threads = 2;
-  if (argc == 7 && atoi(argv[6]) > 0) {
-    config.model_config.num_threads = atoi(argv[6]);
-  }
-  if (argc == 8) {
-    config.decoding_method = argv[7];
+  po.Read(argc, argv);
+  if (po.NumArgs() < 1) {
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
   }
-  if (argc == 9) {
-    config.lm_config.model = argv[8];
-  }
-  config.max_active_paths = 4;
 
   fprintf(stderr, "%s\n", config.ToString().c_str());
 
@@ -71,63 +61,66 @@ for a list of pre-trained models to download.
 
   sherpa_onnx::OnlineRecognizer recognizer(config);
 
-  int32_t sampling_rate = -1;
-
-  bool is_ok = false;
-  std::vector<float> samples =
-      sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
-
-  if (!is_ok) {
-    fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
-    return -1;
-  }
-  fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);
-
-  float duration = samples.size() / static_cast<float>(sampling_rate);
-
-  fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
-  fprintf(stderr, "wav duration (s): %.3f\n", duration);
-
-  auto begin = std::chrono::steady_clock::now();
-  fprintf(stderr, "Started\n");
-
-  auto s = recognizer.CreateStream();
-  s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
-
-  std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate));
-  // Note: We can call AcceptWaveform() multiple times.
-  s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size());
-
-  // Call InputFinished() to indicate that no audio samples are available
-  s->InputFinished();
-
-  while (recognizer.IsReady(s.get())) {
-    recognizer.DecodeStream(s.get());
+  float duration = 0;
+  for (int32_t i = 1; i <= po.NumArgs(); ++i) {
+    const std::string wav_filename = po.GetArg(i);
+    int32_t sampling_rate = -1;
+
+    bool is_ok = false;
+    const std::vector<float> samples =
+        sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
+
+    if (!is_ok) {
+      fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
+      return -1;
+    }
+    fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);
+
+    const float duration = samples.size() / static_cast<float>(sampling_rate);
+
+    fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
+    fprintf(stderr, "wav duration (s): %.3f\n", duration);
+
+    fprintf(stderr, "Started\n");
+    const auto begin = std::chrono::steady_clock::now();
+
+    auto s = recognizer.CreateStream();
+    s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
+
+    std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate));
+    // Note: We can call AcceptWaveform() multiple times.
+    s->AcceptWaveform(
+        sampling_rate, tail_paddings.data(), tail_paddings.size());
+
+    // Call InputFinished() to indicate that no audio samples are available
+    s->InputFinished();
+
+    while (recognizer.IsReady(s.get())) {
+      recognizer.DecodeStream(s.get());
+    }
+
+    const std::string text = recognizer.GetResult(s.get()).AsJsonString();
+
+    const auto end = std::chrono::steady_clock::now();
+    const float elapsed_seconds =
+        std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
+            .count() / 1000.;
+
+    fprintf(stderr, "Done!\n");
+    fprintf(stderr,
+            "Recognition result for %s:\n%s\n",
+            wav_filename.c_str(), text.c_str());
+    fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
+    fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
+    if (config.decoding_method == "modified_beam_search") {
+      fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
+    }
+
+    fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
+    const float rtf = elapsed_seconds / duration;
+    fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
+            elapsed_seconds, duration, rtf);
   }
 
-  std::string text = recognizer.GetResult(s.get()).AsJsonString();
-
-  fprintf(stderr, "Done!\n");
-
-  fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
-          text.c_str());
-
-  auto end = std::chrono::steady_clock::now();
-  float elapsed_seconds =
-      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
-          .count() /
-      1000.;
-
-  fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
-  fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
-  if (config.decoding_method == "modified_beam_search") {
-    fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
-  }
-
-  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
-  float rtf = elapsed_seconds / duration;
-  fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
-          elapsed_seconds, duration, rtf);
-
   return 0;
 }
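
Since sherpa-onnx now reads its options through ParseOptions and loops over every positional argument, several wave files can be decoded in a single run. A hedged sketch of the new invocation, with placeholder paths and the flags taken from kUsageMessage above:

  ./bin/sherpa-onnx \
    --tokens=/path/to/tokens.txt \
    --encoder=/path/to/encoder.onnx \
    --decoder=/path/to/decoder.onnx \
    --joiner=/path/to/joiner.onnx \
    --provider=cpu \
    --num-threads=2 \
    --decoding-method=greedy_search \
    foo.wav bar.wav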