Committed by
GitHub
Added provider option to sherpa-onnx and decode-file-c-api (#162)
正在显示 9 个修改的文件，包含 268 行增加和 174 行删除
| @@ -36,22 +36,22 @@ $repo/test_wavs/8k.wav | @@ -36,22 +36,22 @@ $repo/test_wavs/8k.wav | ||
| 36 | 36 | ||
| 37 | for wave in ${waves[@]}; do | 37 | for wave in ${waves[@]}; do |
| 38 | time $EXE \ | 38 | time $EXE \ |
| 39 | - $repo/tokens.txt \ | ||
| 40 | - $repo/encoder-epoch-99-avg-1.onnx \ | ||
| 41 | - $repo/decoder-epoch-99-avg-1.onnx \ | ||
| 42 | - $repo/joiner-epoch-99-avg-1.onnx \ | ||
| 43 | - $wave \ | ||
| 44 | - 2 | 39 | + --tokens=$repo/tokens.txt \ |
| 40 | + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ | ||
| 41 | + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ | ||
| 42 | + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ | ||
| 43 | + --num-threads=2 \ | ||
| 44 | + $wave | ||
| 45 | done | 45 | done |
| 46 | 46 | ||
| 47 | for wave in ${waves[@]}; do | 47 | for wave in ${waves[@]}; do |
| 48 | time $EXE \ | 48 | time $EXE \ |
| 49 | - $repo/tokens.txt \ | ||
| 50 | - $repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 51 | - $repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 52 | - $repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 53 | - $wave \ | ||
| 54 | - 2 | 49 | + --tokens=$repo/tokens.txt \ |
| 50 | + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 51 | + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 52 | + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 53 | + --num-threads=2 \ | ||
| 54 | + $wave | ||
| 55 | done | 55 | done |
| 56 | 56 | ||
| 57 | rm -rf $repo | 57 | rm -rf $repo |
| @@ -79,22 +79,22 @@ $repo/test_wavs/8k.wav | @@ -79,22 +79,22 @@ $repo/test_wavs/8k.wav | ||
| 79 | 79 | ||
| 80 | for wave in ${waves[@]}; do | 80 | for wave in ${waves[@]}; do |
| 81 | time $EXE \ | 81 | time $EXE \ |
| 82 | - $repo/tokens.txt \ | ||
| 83 | - $repo/encoder-epoch-11-avg-1.onnx \ | ||
| 84 | - $repo/decoder-epoch-11-avg-1.onnx \ | ||
| 85 | - $repo/joiner-epoch-11-avg-1.onnx \ | ||
| 86 | - $wave \ | ||
| 87 | - 2 | 82 | + --tokens=$repo/tokens.txt \ |
| 83 | + --encoder=$repo/encoder-epoch-11-avg-1.onnx \ | ||
| 84 | + --decoder=$repo/decoder-epoch-11-avg-1.onnx \ | ||
| 85 | + --joiner=$repo/joiner-epoch-11-avg-1.onnx \ | ||
| 86 | + --num-threads=2 \ | ||
| 87 | + $wave | ||
| 88 | done | 88 | done |
| 89 | 89 | ||
| 90 | for wave in ${waves[@]}; do | 90 | for wave in ${waves[@]}; do |
| 91 | time $EXE \ | 91 | time $EXE \ |
| 92 | - $repo/tokens.txt \ | ||
| 93 | - $repo/encoder-epoch-11-avg-1.int8.onnx \ | ||
| 94 | - $repo/decoder-epoch-11-avg-1.int8.onnx \ | ||
| 95 | - $repo/joiner-epoch-11-avg-1.int8.onnx \ | ||
| 96 | - $wave \ | ||
| 97 | - 2 | 92 | + --tokens=$repo/tokens.txt \ |
| 93 | + --encoder=$repo/encoder-epoch-11-avg-1.int8.onnx \ | ||
| 94 | + --decoder=$repo/decoder-epoch-11-avg-1.int8.onnx \ | ||
| 95 | + --joiner=$repo/joiner-epoch-11-avg-1.int8.onnx \ | ||
| 96 | + --num-threads=2 \ | ||
| 97 | + $wave | ||
| 98 | done | 98 | done |
| 99 | 99 | ||
| 100 | rm -rf $repo | 100 | rm -rf $repo |
| @@ -122,24 +122,24 @@ $repo/test_wavs/8k.wav | @@ -122,24 +122,24 @@ $repo/test_wavs/8k.wav | ||
| 122 | 122 | ||
| 123 | for wave in ${waves[@]}; do | 123 | for wave in ${waves[@]}; do |
| 124 | time $EXE \ | 124 | time $EXE \ |
| 125 | - $repo/tokens.txt \ | ||
| 126 | - $repo/encoder-epoch-99-avg-1.onnx \ | ||
| 127 | - $repo/decoder-epoch-99-avg-1.onnx \ | ||
| 128 | - $repo/joiner-epoch-99-avg-1.onnx \ | ||
| 129 | - $wave \ | ||
| 130 | - 2 | 125 | + --tokens=$repo/tokens.txt \ |
| 126 | + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ | ||
| 127 | + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ | ||
| 128 | + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ | ||
| 129 | + --num-threads=2 \ | ||
| 130 | + $wave | ||
| 131 | done | 131 | done |
| 132 | 132 | ||
| 133 | # test int8 | 133 | # test int8 |
| 134 | # | 134 | # |
| 135 | for wave in ${waves[@]}; do | 135 | for wave in ${waves[@]}; do |
| 136 | time $EXE \ | 136 | time $EXE \ |
| 137 | - $repo/tokens.txt \ | ||
| 138 | - $repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 139 | - $repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 140 | - $repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 141 | - $wave \ | ||
| 142 | - 2 | 137 | + --tokens=$repo/tokens.txt \ |
| 138 | + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 139 | + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 140 | + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 141 | + --num-threads=2 \ | ||
| 142 | + $wave | ||
| 143 | done | 143 | done |
| 144 | 144 | ||
| 145 | rm -rf $repo | 145 | rm -rf $repo |
| @@ -169,22 +169,22 @@ $repo/test_wavs/8k.wav | @@ -169,22 +169,22 @@ $repo/test_wavs/8k.wav | ||
| 169 | 169 | ||
| 170 | for wave in ${waves[@]}; do | 170 | for wave in ${waves[@]}; do |
| 171 | time $EXE \ | 171 | time $EXE \ |
| 172 | - $repo/tokens.txt \ | ||
| 173 | - $repo/encoder-epoch-99-avg-1.onnx \ | ||
| 174 | - $repo/decoder-epoch-99-avg-1.onnx \ | ||
| 175 | - $repo/joiner-epoch-99-avg-1.onnx \ | ||
| 176 | - $wave \ | ||
| 177 | - 2 | 172 | + --tokens=$repo/tokens.txt \ |
| 173 | + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ | ||
| 174 | + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ | ||
| 175 | + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ | ||
| 176 | + --num-threads=2 \ | ||
| 177 | + $wave | ||
| 178 | done | 178 | done |
| 179 | 179 | ||
| 180 | for wave in ${waves[@]}; do | 180 | for wave in ${waves[@]}; do |
| 181 | time $EXE \ | 181 | time $EXE \ |
| 182 | - $repo/tokens.txt \ | ||
| 183 | - $repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 184 | - $repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 185 | - $repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 186 | - $wave \ | ||
| 187 | - 2 | 182 | + --tokens=$repo/tokens.txt \ |
| 183 | + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 184 | + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 185 | + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 186 | + --num-threads=2 \ | ||
| 187 | + $wave | ||
| 188 | done | 188 | done |
| 189 | 189 | ||
| 190 | # Decode a URL | 190 | # Decode a URL |
| @@ -233,22 +233,22 @@ $repo/test_wavs/2.wav | @@ -233,22 +233,22 @@ $repo/test_wavs/2.wav | ||
| 233 | 233 | ||
| 234 | for wave in ${waves[@]}; do | 234 | for wave in ${waves[@]}; do |
| 235 | time $EXE \ | 235 | time $EXE \ |
| 236 | - $repo/tokens.txt \ | ||
| 237 | - $repo/encoder-epoch-99-avg-1.onnx \ | ||
| 238 | - $repo/decoder-epoch-99-avg-1.onnx \ | ||
| 239 | - $repo/joiner-epoch-99-avg-1.onnx \ | ||
| 240 | - $wave \ | ||
| 241 | - 2 | 236 | + --tokens=$repo/tokens.txt \ |
| 237 | + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ | ||
| 238 | + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ | ||
| 239 | + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ | ||
| 240 | + --num-threads=2 \ | ||
| 241 | + $wave | ||
| 242 | done | 242 | done |
| 243 | 243 | ||
| 244 | for wave in ${waves[@]}; do | 244 | for wave in ${waves[@]}; do |
| 245 | time $EXE \ | 245 | time $EXE \ |
| 246 | - $repo/tokens.txt \ | ||
| 247 | - $repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 248 | - $repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 249 | - $repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 250 | - $wave \ | ||
| 251 | - 2 | 246 | + --tokens=$repo/tokens.txt \ |
| 247 | + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 248 | + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 249 | + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 250 | + --num-threads=2 \ | ||
| 251 | + $wave | ||
| 252 | done | 252 | done |
| 253 | 253 | ||
| 254 | rm -rf $repo | 254 | rm -rf $repo |
| 1 | +include(cargs) | ||
| 2 | + | ||
| 1 | include_directories(${CMAKE_SOURCE_DIR}) | 3 | include_directories(${CMAKE_SOURCE_DIR}) |
| 2 | add_executable(decode-file-c-api decode-file-c-api.c) | 4 | add_executable(decode-file-c-api decode-file-c-api.c) |
| 3 | -target_link_libraries(decode-file-c-api sherpa-onnx-c-api) | 5 | +target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs) |
| @@ -5,50 +5,85 @@ | @@ -5,50 +5,85 @@ | ||
| 5 | // This file shows how to use sherpa-onnx C API | 5 | // This file shows how to use sherpa-onnx C API |
| 6 | // to decode a file. | 6 | // to decode a file. |
| 7 | 7 | ||
| 8 | +#include "cargs.h" | ||
| 8 | #include <stdio.h> | 9 | #include <stdio.h> |
| 9 | #include <stdlib.h> | 10 | #include <stdlib.h> |
| 10 | #include <string.h> | 11 | #include <string.h> |
| 11 | 12 | ||
| 12 | #include "sherpa-onnx/c-api/c-api.h" | 13 | #include "sherpa-onnx/c-api/c-api.h" |
| 13 | 14 | ||
| 15 | +static struct cag_option options[] = { | ||
| 16 | + { | ||
| 17 | + .identifier = 't', | ||
| 18 | + .access_letters = NULL, | ||
| 19 | + .access_name = "tokens", | ||
| 20 | + .value_name = "tokens", | ||
| 21 | + .description = "Tokens file" | ||
| 22 | + }, { | ||
| 23 | + .identifier = 'e', | ||
| 24 | + .access_letters = NULL, | ||
| 25 | + .access_name = "encoder", | ||
| 26 | + .value_name = "encoder", | ||
| 27 | + .description = "Encoder ONNX file" | ||
| 28 | + }, { | ||
| 29 | + .identifier = 'd', | ||
| 30 | + .access_letters = NULL, | ||
| 31 | + .access_name = "decoder", | ||
| 32 | + .value_name = "decoder", | ||
| 33 | + .description = "Decoder ONNX file" | ||
| 34 | + }, { | ||
| 35 | + .identifier = 'j', | ||
| 36 | + .access_letters = NULL, | ||
| 37 | + .access_name = "joiner", | ||
| 38 | + .value_name = "joiner", | ||
| 39 | + .description = "Joiner ONNX file" | ||
| 40 | + }, { | ||
| 41 | + .identifier = 'n', | ||
| 42 | + .access_letters = NULL, | ||
| 43 | + .access_name = "num-threads", | ||
| 44 | + .value_name = "num-threads", | ||
| 45 | + .description = "Number of threads" | ||
| 46 | + }, { | ||
| 47 | + .identifier = 'p', | ||
| 48 | + .access_letters = NULL, | ||
| 49 | + .access_name = "provider", | ||
| 50 | + .value_name = "provider", | ||
| 51 | + .description = "Provider: cpu (default), cuda, coreml" | ||
| 52 | + }, { | ||
| 53 | + .identifier = 'm', | ||
| 54 | + .access_letters = NULL, | ||
| 55 | + .access_name = "decoding-method", | ||
| 56 | + .value_name = "decoding-method", | ||
| 57 | + .description = | ||
| 58 | + "Decoding method: greedy_search (default), modified_beam_search" | ||
| 59 | + } | ||
| 60 | +}; | ||
| 61 | + | ||
| 14 | const char *kUsage = | 62 | const char *kUsage = |
| 15 | "\n" | 63 | "\n" |
| 16 | "Usage:\n " | 64 | "Usage:\n " |
| 17 | " ./bin/decode-file-c-api \\\n" | 65 | " ./bin/decode-file-c-api \\\n" |
| 18 | - " /path/to/tokens.txt \\\n" | ||
| 19 | - " /path/to/encoder.onnx \\\n" | ||
| 20 | - " /path/to/decoder.onnx \\\n" | ||
| 21 | - " /path/to/joiner.onnx \\\n" | ||
| 22 | - " /path/to/foo.wav [num_threads [decoding_method]]\n" | 66 | + " --tokens=/path/to/tokens.txt \\\n" |
| 67 | + " --encoder=/path/to/encoder.onnx \\\n" | ||
| 68 | + " --decoder=/path/to/decoder.onnx \\\n" | ||
| 69 | + " --joiner=/path/to/joiner.onnx \\\n" | ||
| 70 | + " /path/to/foo.wav\n" | ||
| 23 | "\n\n" | 71 | "\n\n" |
| 24 | "Default num_threads is 1.\n" | 72 | "Default num_threads is 1.\n" |
| 25 | "Valid decoding_method: greedy_search (default), modified_beam_search\n\n" | 73 | "Valid decoding_method: greedy_search (default), modified_beam_search\n\n" |
| 74 | + "Valid provider: cpu (default), cuda, coreml\n\n" | ||
| 26 | "Please refer to \n" | 75 | "Please refer to \n" |
| 27 | "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n" | 76 | "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n" |
| 28 | "for a list of pre-trained models to download.\n"; | 77 | "for a list of pre-trained models to download.\n"; |
| 29 | 78 | ||
| 30 | int32_t main(int32_t argc, char *argv[]) { | 79 | int32_t main(int32_t argc, char *argv[]) { |
| 31 | - if (argc < 6 || argc > 8) { | ||
| 32 | - fprintf(stderr, "%s\n", kUsage); | ||
| 33 | - return -1; | ||
| 34 | - } | ||
| 35 | SherpaOnnxOnlineRecognizerConfig config; | 80 | SherpaOnnxOnlineRecognizerConfig config; |
| 36 | - config.model_config.tokens = argv[1]; | ||
| 37 | - config.model_config.encoder = argv[2]; | ||
| 38 | - config.model_config.decoder = argv[3]; | ||
| 39 | - config.model_config.joiner = argv[4]; | ||
| 40 | - | ||
| 41 | - int32_t num_threads = 1; | ||
| 42 | - if (argc == 7 && atoi(argv[6]) > 0) { | ||
| 43 | - num_threads = atoi(argv[6]); | ||
| 44 | - } | ||
| 45 | - config.model_config.num_threads = num_threads; | 81 | + |
| 46 | config.model_config.debug = 0; | 82 | config.model_config.debug = 0; |
| 83 | + config.model_config.num_threads = 1; | ||
| 84 | + config.model_config.provider = "cpu"; | ||
| 47 | 85 | ||
| 48 | config.decoding_method = "greedy_search"; | 86 | config.decoding_method = "greedy_search"; |
| 49 | - if (argc == 8) { | ||
| 50 | - config.decoding_method = argv[7]; | ||
| 51 | - } | ||
| 52 | 87 | ||
| 53 | config.max_active_paths = 4; | 88 | config.max_active_paths = 4; |
| 54 | 89 | ||
| @@ -60,13 +95,36 @@ int32_t main(int32_t argc, char *argv[]) { | @@ -60,13 +95,36 @@ int32_t main(int32_t argc, char *argv[]) { | ||
| 60 | config.rule2_min_trailing_silence = 1.2; | 95 | config.rule2_min_trailing_silence = 1.2; |
| 61 | config.rule3_min_utterance_length = 300; | 96 | config.rule3_min_utterance_length = 300; |
| 62 | 97 | ||
| 98 | + cag_option_context context; | ||
| 99 | + char identifier; | ||
| 100 | + const char *value; | ||
| 101 | + | ||
| 102 | + cag_option_prepare(&context, options, CAG_ARRAY_SIZE(options), argc, argv); | ||
| 103 | + | ||
| 104 | + while (cag_option_fetch(&context)) { | ||
| 105 | + identifier = cag_option_get(&context); | ||
| 106 | + value = cag_option_get_value(&context); | ||
| 107 | + switch (identifier) { | ||
| 108 | + case 't': config.model_config.tokens = value; break; | ||
| 109 | + case 'e': config.model_config.encoder = value; break; | ||
| 110 | + case 'd': config.model_config.decoder = value; break; | ||
| 111 | + case 'j': config.model_config.joiner = value; break; | ||
| 112 | + case 'n': config.model_config.num_threads = atoi(value); break; | ||
| 113 | + case 'p': config.model_config.provider = value; break; | ||
| 114 | + case 'm': config.decoding_method = value; break; | ||
| 115 | + default: | ||
| 116 | + // do nothing as config already has valid default values | ||
| 117 | + break; | ||
| 118 | + } | ||
| 119 | + } | ||
| 120 | + | ||
| 63 | SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); | 121 | SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); |
| 64 | SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); | 122 | SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); |
| 65 | 123 | ||
| 66 | SherpaOnnxDisplay *display = CreateDisplay(50); | 124 | SherpaOnnxDisplay *display = CreateDisplay(50); |
| 67 | int32_t segment_id = 0; | 125 | int32_t segment_id = 0; |
| 68 | 126 | ||
| 69 | - const char *wav_filename = argv[5]; | 127 | + const char *wav_filename = argv[context.index]; |
| 70 | FILE *fp = fopen(wav_filename, "rb"); | 128 | FILE *fp = fopen(wav_filename, "rb"); |
| 71 | if (!fp) { | 129 | if (!fp) { |
| 72 | fprintf(stderr, "Failed to open %s\n", wav_filename); | 130 | fprintf(stderr, "Failed to open %s\n", wav_filename); |
cmake/cargs.cmake
0 → 100644
| 1 | +function(download_cargs) | ||
| 2 | + include(FetchContent) | ||
| 3 | + | ||
| 4 | + set(cargs_URL "https://github.com/likle/cargs/archive/refs/tags/v1.0.3.tar.gz") | ||
| 5 | + set(cargs_HASH "SHA256=ddba25bd35e9c6c75bc706c126001b8ce8e084d40ef37050e6aa6963e836eb8b") | ||
| 6 | + | ||
| 7 | + # If you don't have access to the Internet, | ||
| 8 | + # please pre-download cargs | ||
| 9 | + set(possible_file_locations | ||
| 10 | + $ENV{HOME}/Downloads/cargs-v1-0-3.tar.gz | ||
| 11 | + ${PROJECT_SOURCE_DIR}/cargs-v1-0-3.tar.gz | ||
| 12 | + ${PROJECT_BINARY_DIR}/cargs-v1-0-3.tar.gz | ||
| 13 | + /tmp/cargs-v1-0-3.tar.gz | ||
| 14 | + /star-fj/fangjun/download/github/cargs-v1-0-3.tar.gz | ||
| 15 | + ) | ||
| 16 | + | ||
| 17 | + foreach(f IN LISTS possible_file_locations) | ||
| 18 | + if(EXISTS ${f}) | ||
| 19 | + set(cargs_URL "${f}") | ||
| 20 | + file(TO_CMAKE_PATH "${cargs_URL}" cargs_URL) | ||
| 21 | + break() | ||
| 22 | + endif() | ||
| 23 | + endforeach() | ||
| 24 | + | ||
| 25 | + FetchContent_Declare(cargs URL ${cargs_URL} URL_HASH ${cargs_HASH}) | ||
| 26 | + | ||
| 27 | + FetchContent_GetProperties(cargs) | ||
| 28 | + if(NOT cargs_POPULATED) | ||
| 29 | + message(STATUS "Downloading cargs ${cargs_URL}") | ||
| 30 | + FetchContent_Populate(cargs) | ||
| 31 | + endif() | ||
| 32 | + message(STATUS "cargs is downloaded to ${cargs_SOURCE_DIR}") | ||
| 33 | + add_subdirectory(${cargs_SOURCE_DIR} ${cargs_BINARY_DIR} EXCLUDE_FROM_ALL) | ||
| 34 | +endfunction() | ||
| 35 | + | ||
| 36 | +download_cargs() |
| @@ -41,6 +41,7 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( | @@ -41,6 +41,7 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( | ||
| 41 | recognizer_config.model_config.joiner_filename = config->model_config.joiner; | 41 | recognizer_config.model_config.joiner_filename = config->model_config.joiner; |
| 42 | recognizer_config.model_config.tokens = config->model_config.tokens; | 42 | recognizer_config.model_config.tokens = config->model_config.tokens; |
| 43 | recognizer_config.model_config.num_threads = config->model_config.num_threads; | 43 | recognizer_config.model_config.num_threads = config->model_config.num_threads; |
| 44 | + recognizer_config.model_config.provider = config->model_config.provider; | ||
| 44 | recognizer_config.model_config.debug = config->model_config.debug; | 45 | recognizer_config.model_config.debug = config->model_config.debug; |
| 45 | 46 | ||
| 46 | recognizer_config.decoding_method = config->decoding_method; | 47 | recognizer_config.decoding_method = config->decoding_method; |
| @@ -52,6 +52,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig { | @@ -52,6 +52,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig { | ||
| 52 | const char *joiner; | 52 | const char *joiner; |
| 53 | const char *tokens; | 53 | const char *tokens; |
| 54 | int32_t num_threads; | 54 | int32_t num_threads; |
| 55 | + const char *provider; | ||
| 55 | int32_t debug; // true to print debug information of the model | 56 | int32_t debug; // true to print debug information of the model |
| 56 | } SherpaOnnxOnlineTransducerModelConfig; | 57 | } SherpaOnnxOnlineTransducerModelConfig; |
| 57 | 58 |
| @@ -17,6 +17,8 @@ void OnlineTransducerModelConfig::Register(ParseOptions *po) { | @@ -17,6 +17,8 @@ void OnlineTransducerModelConfig::Register(ParseOptions *po) { | ||
| 17 | po->Register("tokens", &tokens, "Path to tokens.txt"); | 17 | po->Register("tokens", &tokens, "Path to tokens.txt"); |
| 18 | po->Register("num_threads", &num_threads, | 18 | po->Register("num_threads", &num_threads, |
| 19 | "Number of threads to run the neural network"); | 19 | "Number of threads to run the neural network"); |
| 20 | + po->Register("provider", &provider, | ||
| 21 | + "Specify a provider to use: cpu, cuda, coreml"); | ||
| 20 | 22 | ||
| 21 | po->Register("debug", &debug, | 23 | po->Register("debug", &debug, |
| 22 | "true to print model information while loading it."); | 24 | "true to print model information while loading it."); |
| @@ -60,6 +62,7 @@ std::string OnlineTransducerModelConfig::ToString() const { | @@ -60,6 +62,7 @@ std::string OnlineTransducerModelConfig::ToString() const { | ||
| 60 | os << "joiner_filename=\"" << joiner_filename << "\", "; | 62 | os << "joiner_filename=\"" << joiner_filename << "\", "; |
| 61 | os << "tokens=\"" << tokens << "\", "; | 63 | os << "tokens=\"" << tokens << "\", "; |
| 62 | os << "num_threads=" << num_threads << ", "; | 64 | os << "num_threads=" << num_threads << ", "; |
| 65 | + os << "provider=\"" << provider << "\", "; | ||
| 63 | os << "debug=" << (debug ? "True" : "False") << ")"; | 66 | os << "debug=" << (debug ? "True" : "False") << ")"; |
| 64 | 67 | ||
| 65 | return os.str(); | 68 | return os.str(); |
| @@ -69,17 +69,17 @@ for a list of pre-trained models to download. | @@ -69,17 +69,17 @@ for a list of pre-trained models to download. | ||
| 69 | fprintf(stderr, "Creating recognizer ...\n"); | 69 | fprintf(stderr, "Creating recognizer ...\n"); |
| 70 | sherpa_onnx::OfflineRecognizer recognizer(config); | 70 | sherpa_onnx::OfflineRecognizer recognizer(config); |
| 71 | 71 | ||
| 72 | - auto begin = std::chrono::steady_clock::now(); | 72 | + const auto begin = std::chrono::steady_clock::now(); |
| 73 | fprintf(stderr, "Started\n"); | 73 | fprintf(stderr, "Started\n"); |
| 74 | 74 | ||
| 75 | std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss; | 75 | std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss; |
| 76 | std::vector<sherpa_onnx::OfflineStream *> ss_pointers; | 76 | std::vector<sherpa_onnx::OfflineStream *> ss_pointers; |
| 77 | float duration = 0; | 77 | float duration = 0; |
| 78 | for (int32_t i = 1; i <= po.NumArgs(); ++i) { | 78 | for (int32_t i = 1; i <= po.NumArgs(); ++i) { |
| 79 | - std::string wav_filename = po.GetArg(i); | 79 | + const std::string wav_filename = po.GetArg(i); |
| 80 | int32_t sampling_rate = -1; | 80 | int32_t sampling_rate = -1; |
| 81 | bool is_ok = false; | 81 | bool is_ok = false; |
| 82 | - std::vector<float> samples = | 82 | + const std::vector<float> samples = |
| 83 | sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | 83 | sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); |
| 84 | if (!is_ok) { | 84 | if (!is_ok) { |
| 85 | fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | 85 | fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); |
| @@ -96,7 +96,7 @@ for a list of pre-trained models to download. | @@ -96,7 +96,7 @@ for a list of pre-trained models to download. | ||
| 96 | 96 | ||
| 97 | recognizer.DecodeStreams(ss_pointers.data(), ss_pointers.size()); | 97 | recognizer.DecodeStreams(ss_pointers.data(), ss_pointers.size()); |
| 98 | 98 | ||
| 99 | - auto end = std::chrono::steady_clock::now(); | 99 | + const auto end = std::chrono::steady_clock::now(); |
| 100 | 100 | ||
| 101 | fprintf(stderr, "Done!\n\n"); | 101 | fprintf(stderr, "Done!\n\n"); |
| 102 | for (int32_t i = 1; i <= po.NumArgs(); ++i) { | 102 | for (int32_t i = 1; i <= po.NumArgs(); ++i) { |
| @@ -11,22 +11,28 @@ | @@ -11,22 +11,28 @@ | ||
| 11 | #include "sherpa-onnx/csrc/online-recognizer.h" | 11 | #include "sherpa-onnx/csrc/online-recognizer.h" |
| 12 | #include "sherpa-onnx/csrc/online-stream.h" | 12 | #include "sherpa-onnx/csrc/online-stream.h" |
| 13 | #include "sherpa-onnx/csrc/symbol-table.h" | 13 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 14 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 14 | #include "sherpa-onnx/csrc/wave-reader.h" | 15 | #include "sherpa-onnx/csrc/wave-reader.h" |
| 15 | 16 | ||
| 16 | -// TODO(fangjun): Use ParseOptions as we are getting more args | ||
| 17 | int main(int32_t argc, char *argv[]) { | 17 | int main(int32_t argc, char *argv[]) { |
| 18 | - if (argc < 6 || argc > 9) { | ||
| 19 | - const char *usage = R"usage( | 18 | + const char *kUsageMessage = R"usage( |
| 20 | Usage: | 19 | Usage: |
| 20 | + | ||
| 21 | ./bin/sherpa-onnx \ | 21 | ./bin/sherpa-onnx \ |
| 22 | - /path/to/tokens.txt \ | ||
| 23 | - /path/to/encoder.onnx \ | ||
| 24 | - /path/to/decoder.onnx \ | ||
| 25 | - /path/to/joiner.onnx \ | ||
| 26 | - /path/to/foo.wav [num_threads [decoding_method [/path/to/rnn_lm.onnx]]] | 22 | + --tokens=/path/to/tokens.txt \ |
| 23 | + --encoder=/path/to/encoder.onnx \ | ||
| 24 | + --decoder=/path/to/decoder.onnx \ | ||
| 25 | + --joiner=/path/to/joiner.onnx \ | ||
| 26 | + --provider=cpu \ | ||
| 27 | + --num-threads=2 \ | ||
| 28 | + --decoding-method=greedy_search \ | ||
| 29 | + /path/to/foo.wav [bar.wav foobar.wav ...] | ||
| 30 | + | ||
| 31 | +Note: It supports decoding multiple files in batches | ||
| 27 | 32 | ||
| 28 | Default value for num_threads is 2. | 33 | Default value for num_threads is 2. |
| 29 | Valid values for decoding_method: greedy_search (default), modified_beam_search. | 34 | Valid values for decoding_method: greedy_search (default), modified_beam_search. |
| 35 | +Valid values for provider: cpu (default), cuda, coreml. | ||
| 30 | foo.wav should be of single channel, 16-bit PCM encoded wave file; its | 36 | foo.wav should be of single channel, 16-bit PCM encoded wave file; its |
| 31 | sampling rate can be arbitrary and does not need to be 16kHz. | 37 | sampling rate can be arbitrary and does not need to be 16kHz. |
| 32 | 38 | ||
| @@ -34,33 +40,17 @@ Please refer to | @@ -34,33 +40,17 @@ Please refer to | ||
| 34 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | 40 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html |
| 35 | for a list of pre-trained models to download. | 41 | for a list of pre-trained models to download. |
| 36 | )usage"; | 42 | )usage"; |
| 37 | - fprintf(stderr, "%s\n", usage); | ||
| 38 | - | ||
| 39 | - return 0; | ||
| 40 | - } | ||
| 41 | 43 | ||
| 44 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 42 | sherpa_onnx::OnlineRecognizerConfig config; | 45 | sherpa_onnx::OnlineRecognizerConfig config; |
| 43 | 46 | ||
| 44 | - config.model_config.tokens = argv[1]; | 47 | + config.Register(&po); |
| 45 | 48 | ||
| 46 | - config.model_config.debug = false; | ||
| 47 | - config.model_config.encoder_filename = argv[2]; | ||
| 48 | - config.model_config.decoder_filename = argv[3]; | ||
| 49 | - config.model_config.joiner_filename = argv[4]; | ||
| 50 | - | ||
| 51 | - std::string wav_filename = argv[5]; | ||
| 52 | - | ||
| 53 | - config.model_config.num_threads = 2; | ||
| 54 | - if (argc == 7 && atoi(argv[6]) > 0) { | ||
| 55 | - config.model_config.num_threads = atoi(argv[6]); | ||
| 56 | - } | ||
| 57 | - if (argc == 8) { | ||
| 58 | - config.decoding_method = argv[7]; | 49 | + po.Read(argc, argv); |
| 50 | + if (po.NumArgs() < 1) { | ||
| 51 | + po.PrintUsage(); | ||
| 52 | + exit(EXIT_FAILURE); | ||
| 59 | } | 53 | } |
| 60 | - if (argc == 9) { | ||
| 61 | - config.lm_config.model = argv[8]; | ||
| 62 | - } | ||
| 63 | - config.max_active_paths = 4; | ||
| 64 | 54 | ||
| 65 | fprintf(stderr, "%s\n", config.ToString().c_str()); | 55 | fprintf(stderr, "%s\n", config.ToString().c_str()); |
| 66 | 56 | ||
| @@ -71,63 +61,66 @@ for a list of pre-trained models to download. | @@ -71,63 +61,66 @@ for a list of pre-trained models to download. | ||
| 71 | 61 | ||
| 72 | sherpa_onnx::OnlineRecognizer recognizer(config); | 62 | sherpa_onnx::OnlineRecognizer recognizer(config); |
| 73 | 63 | ||
| 74 | - int32_t sampling_rate = -1; | ||
| 75 | - | ||
| 76 | - bool is_ok = false; | ||
| 77 | - std::vector<float> samples = | ||
| 78 | - sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | ||
| 79 | - | ||
| 80 | - if (!is_ok) { | ||
| 81 | - fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | ||
| 82 | - return -1; | ||
| 83 | - } | ||
| 84 | - fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate); | ||
| 85 | - | ||
| 86 | - float duration = samples.size() / static_cast<float>(sampling_rate); | ||
| 87 | - | ||
| 88 | - fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); | ||
| 89 | - fprintf(stderr, "wav duration (s): %.3f\n", duration); | ||
| 90 | - | ||
| 91 | - auto begin = std::chrono::steady_clock::now(); | ||
| 92 | - fprintf(stderr, "Started\n"); | ||
| 93 | - | ||
| 94 | - auto s = recognizer.CreateStream(); | ||
| 95 | - s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | ||
| 96 | - | ||
| 97 | - std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate)); | ||
| 98 | - // Note: We can call AcceptWaveform() multiple times. | ||
| 99 | - s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size()); | ||
| 100 | - | ||
| 101 | - // Call InputFinished() to indicate that no audio samples are available | ||
| 102 | - s->InputFinished(); | ||
| 103 | - | ||
| 104 | - while (recognizer.IsReady(s.get())) { | ||
| 105 | - recognizer.DecodeStream(s.get()); | 64 | + float duration = 0; |
| 65 | + for (int32_t i = 1; i <= po.NumArgs(); ++i) { | ||
| 66 | + const std::string wav_filename = po.GetArg(i); | ||
| 67 | + int32_t sampling_rate = -1; | ||
| 68 | + | ||
| 69 | + bool is_ok = false; | ||
| 70 | + const std::vector<float> samples = | ||
| 71 | + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | ||
| 72 | + | ||
| 73 | + if (!is_ok) { | ||
| 74 | + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | ||
| 75 | + return -1; | ||
| 76 | + } | ||
| 77 | + fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate); | ||
| 78 | + | ||
| 79 | + const float duration = samples.size() / static_cast<float>(sampling_rate); | ||
| 80 | + | ||
| 81 | + fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); | ||
| 82 | + fprintf(stderr, "wav duration (s): %.3f\n", duration); | ||
| 83 | + | ||
| 84 | + fprintf(stderr, "Started\n"); | ||
| 85 | + const auto begin = std::chrono::steady_clock::now(); | ||
| 86 | + | ||
| 87 | + auto s = recognizer.CreateStream(); | ||
| 88 | + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | ||
| 89 | + | ||
| 90 | + std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate)); | ||
| 91 | + // Note: We can call AcceptWaveform() multiple times. | ||
| 92 | + s->AcceptWaveform( | ||
| 93 | + sampling_rate, tail_paddings.data(), tail_paddings.size()); | ||
| 94 | + | ||
| 95 | + // Call InputFinished() to indicate that no audio samples are available | ||
| 96 | + s->InputFinished(); | ||
| 97 | + | ||
| 98 | + while (recognizer.IsReady(s.get())) { | ||
| 99 | + recognizer.DecodeStream(s.get()); | ||
| 100 | + } | ||
| 101 | + | ||
| 102 | + const std::string text = recognizer.GetResult(s.get()).AsJsonString(); | ||
| 103 | + | ||
| 104 | + const auto end = std::chrono::steady_clock::now(); | ||
| 105 | + const float elapsed_seconds = | ||
| 106 | + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
| 107 | + .count() / 1000.; | ||
| 108 | + | ||
| 109 | + fprintf(stderr, "Done!\n"); | ||
| 110 | + fprintf(stderr, | ||
| 111 | + "Recognition result for %s:\n%s\n", | ||
| 112 | + wav_filename.c_str(), text.c_str()); | ||
| 113 | + fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); | ||
| 114 | + fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); | ||
| 115 | + if (config.decoding_method == "modified_beam_search") { | ||
| 116 | + fprintf(stderr, "max active paths: %d\n", config.max_active_paths); | ||
| 117 | + } | ||
| 118 | + | ||
| 119 | + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | ||
| 120 | + const float rtf = elapsed_seconds / duration; | ||
| 121 | + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", | ||
| 122 | + elapsed_seconds, duration, rtf); | ||
| 106 | } | 123 | } |
| 107 | 124 | ||
| 108 | - std::string text = recognizer.GetResult(s.get()).AsJsonString(); | ||
| 109 | - | ||
| 110 | - fprintf(stderr, "Done!\n"); | ||
| 111 | - | ||
| 112 | - fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(), | ||
| 113 | - text.c_str()); | ||
| 114 | - | ||
| 115 | - auto end = std::chrono::steady_clock::now(); | ||
| 116 | - float elapsed_seconds = | ||
| 117 | - std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
| 118 | - .count() / | ||
| 119 | - 1000.; | ||
| 120 | - | ||
| 121 | - fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); | ||
| 122 | - fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); | ||
| 123 | - if (config.decoding_method == "modified_beam_search") { | ||
| 124 | - fprintf(stderr, "max active paths: %d\n", config.max_active_paths); | ||
| 125 | - } | ||
| 126 | - | ||
| 127 | - fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | ||
| 128 | - float rtf = elapsed_seconds / duration; | ||
| 129 | - fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", | ||
| 130 | - elapsed_seconds, duration, rtf); | ||
| 131 | - | ||
| 132 | return 0; | 125 | return 0; |
| 133 | } | 126 | } |
-
请 注册 或 登录 后发表评论