Committed by
GitHub
Fix C# APIs (#183)
* Fix c# APIs * reformat
正在显示
4 个修改的文件
包含
84 行增加
和
56 行删除
| @@ -5,59 +5,54 @@ | @@ -5,59 +5,54 @@ | ||
| 5 | // This file shows how to use sherpa-onnx C API | 5 | // This file shows how to use sherpa-onnx C API |
| 6 | // to decode a file. | 6 | // to decode a file. |
| 7 | 7 | ||
| 8 | -#include "cargs.h" | ||
| 9 | #include <stdio.h> | 8 | #include <stdio.h> |
| 10 | #include <stdlib.h> | 9 | #include <stdlib.h> |
| 11 | #include <string.h> | 10 | #include <string.h> |
| 12 | 11 | ||
| 12 | +#include "cargs.h" | ||
| 13 | #include "sherpa-onnx/c-api/c-api.h" | 13 | #include "sherpa-onnx/c-api/c-api.h" |
| 14 | 14 | ||
| 15 | static struct cag_option options[] = { | 15 | static struct cag_option options[] = { |
| 16 | - { | ||
| 17 | - .identifier = 't', | ||
| 18 | - .access_letters = NULL, | ||
| 19 | - .access_name = "tokens", | ||
| 20 | - .value_name = "tokens", | ||
| 21 | - .description = "Tokens file" | ||
| 22 | - }, { | ||
| 23 | - .identifier = 'e', | ||
| 24 | - .access_letters = NULL, | ||
| 25 | - .access_name = "encoder", | ||
| 26 | - .value_name = "encoder", | ||
| 27 | - .description = "Encoder ONNX file" | ||
| 28 | - }, { | ||
| 29 | - .identifier = 'd', | ||
| 30 | - .access_letters = NULL, | ||
| 31 | - .access_name = "decoder", | ||
| 32 | - .value_name = "decoder", | ||
| 33 | - .description = "Decoder ONNX file" | ||
| 34 | - }, { | ||
| 35 | - .identifier = 'j', | ||
| 36 | - .access_letters = NULL, | ||
| 37 | - .access_name = "joiner", | ||
| 38 | - .value_name = "joiner", | ||
| 39 | - .description = "Joiner ONNX file" | ||
| 40 | - }, { | ||
| 41 | - .identifier = 'n', | ||
| 42 | - .access_letters = NULL, | ||
| 43 | - .access_name = "num-threads", | ||
| 44 | - .value_name = "num-threads", | ||
| 45 | - .description = "Number of threads" | ||
| 46 | - }, { | ||
| 47 | - .identifier = 'p', | ||
| 48 | - .access_letters = NULL, | ||
| 49 | - .access_name = "provider", | ||
| 50 | - .value_name = "provider", | ||
| 51 | - .description = "Provider: cpu (default), cuda, coreml" | ||
| 52 | - }, { | ||
| 53 | - .identifier = 'm', | ||
| 54 | - .access_letters = NULL, | ||
| 55 | - .access_name = "decoding-method", | ||
| 56 | - .value_name = "decoding-method", | ||
| 57 | - .description = | ||
| 58 | - "Decoding method: greedy_search (default), modified_beam_search" | ||
| 59 | - } | ||
| 60 | -}; | 16 | + {.identifier = 'h', |
| 17 | + .access_letters = "h", | ||
| 18 | + .access_name = "help", | ||
| 19 | + .description = "Show help"}, | ||
| 20 | + {.identifier = 't', | ||
| 21 | + .access_letters = NULL, | ||
| 22 | + .access_name = "tokens", | ||
| 23 | + .value_name = "tokens", | ||
| 24 | + .description = "Tokens file"}, | ||
| 25 | + {.identifier = 'e', | ||
| 26 | + .access_letters = NULL, | ||
| 27 | + .access_name = "encoder", | ||
| 28 | + .value_name = "encoder", | ||
| 29 | + .description = "Encoder ONNX file"}, | ||
| 30 | + {.identifier = 'd', | ||
| 31 | + .access_letters = NULL, | ||
| 32 | + .access_name = "decoder", | ||
| 33 | + .value_name = "decoder", | ||
| 34 | + .description = "Decoder ONNX file"}, | ||
| 35 | + {.identifier = 'j', | ||
| 36 | + .access_letters = NULL, | ||
| 37 | + .access_name = "joiner", | ||
| 38 | + .value_name = "joiner", | ||
| 39 | + .description = "Joiner ONNX file"}, | ||
| 40 | + {.identifier = 'n', | ||
| 41 | + .access_letters = NULL, | ||
| 42 | + .access_name = "num-threads", | ||
| 43 | + .value_name = "num-threads", | ||
| 44 | + .description = "Number of threads"}, | ||
| 45 | + {.identifier = 'p', | ||
| 46 | + .access_letters = NULL, | ||
| 47 | + .access_name = "provider", | ||
| 48 | + .value_name = "provider", | ||
| 49 | + .description = "Provider: cpu (default), cuda, coreml"}, | ||
| 50 | + {.identifier = 'm', | ||
| 51 | + .access_letters = NULL, | ||
| 52 | + .access_name = "decoding-method", | ||
| 53 | + .value_name = "decoding-method", | ||
| 54 | + .description = | ||
| 55 | + "Decoding method: greedy_search (default), modified_beam_search"}}; | ||
| 61 | 56 | ||
| 62 | const char *kUsage = | 57 | const char *kUsage = |
| 63 | "\n" | 58 | "\n" |
| @@ -67,6 +62,7 @@ const char *kUsage = | @@ -67,6 +62,7 @@ const char *kUsage = | ||
| 67 | " --encoder=/path/to/encoder.onnx \\\n" | 62 | " --encoder=/path/to/encoder.onnx \\\n" |
| 68 | " --decoder=/path/to/decoder.onnx \\\n" | 63 | " --decoder=/path/to/decoder.onnx \\\n" |
| 69 | " --joiner=/path/to/joiner.onnx \\\n" | 64 | " --joiner=/path/to/joiner.onnx \\\n" |
| 65 | + " --provider=cpu \\\n" | ||
| 70 | " /path/to/foo.wav\n" | 66 | " /path/to/foo.wav\n" |
| 71 | "\n\n" | 67 | "\n\n" |
| 72 | "Default num_threads is 1.\n" | 68 | "Default num_threads is 1.\n" |
| @@ -77,6 +73,11 @@ const char *kUsage = | @@ -77,6 +73,11 @@ const char *kUsage = | ||
| 77 | "for a list of pre-trained models to download.\n"; | 73 | "for a list of pre-trained models to download.\n"; |
| 78 | 74 | ||
| 79 | int32_t main(int32_t argc, char *argv[]) { | 75 | int32_t main(int32_t argc, char *argv[]) { |
| 76 | + if (argc < 6) { | ||
| 77 | + fprintf(stderr, "%s\n", kUsage); | ||
| 78 | + exit(0); | ||
| 79 | + } | ||
| 80 | + | ||
| 80 | SherpaOnnxOnlineRecognizerConfig config; | 81 | SherpaOnnxOnlineRecognizerConfig config; |
| 81 | 82 | ||
| 82 | config.model_config.debug = 0; | 83 | config.model_config.debug = 0; |
| @@ -105,19 +106,38 @@ int32_t main(int32_t argc, char *argv[]) { | @@ -105,19 +106,38 @@ int32_t main(int32_t argc, char *argv[]) { | ||
| 105 | identifier = cag_option_get(&context); | 106 | identifier = cag_option_get(&context); |
| 106 | value = cag_option_get_value(&context); | 107 | value = cag_option_get_value(&context); |
| 107 | switch (identifier) { | 108 | switch (identifier) { |
| 108 | - case 't': config.model_config.tokens = value; break; | ||
| 109 | - case 'e': config.model_config.encoder = value; break; | ||
| 110 | - case 'd': config.model_config.decoder = value; break; | ||
| 111 | - case 'j': config.model_config.joiner = value; break; | ||
| 112 | - case 'n': config.model_config.num_threads = atoi(value); break; | ||
| 113 | - case 'p': config.model_config.provider = value; break; | ||
| 114 | - case 'm': config.decoding_method = value; break; | ||
| 115 | - default: | 109 | + case 't': |
| 110 | + config.model_config.tokens = value; | ||
| 111 | + break; | ||
| 112 | + case 'e': | ||
| 113 | + config.model_config.encoder = value; | ||
| 114 | + break; | ||
| 115 | + case 'd': | ||
| 116 | + config.model_config.decoder = value; | ||
| 117 | + break; | ||
| 118 | + case 'j': | ||
| 119 | + config.model_config.joiner = value; | ||
| 120 | + break; | ||
| 121 | + case 'n': | ||
| 122 | + config.model_config.num_threads = atoi(value); | ||
| 123 | + break; | ||
| 124 | + case 'p': | ||
| 125 | + config.model_config.provider = value; | ||
| 126 | + break; | ||
| 127 | + case 'm': | ||
| 128 | + config.decoding_method = value; | ||
| 129 | + break; | ||
| 130 | + case 'h': { | ||
| 131 | + fprintf(stderr, "%s\n", kUsage); | ||
| 132 | + exit(0); | ||
| 133 | + break; | ||
| 134 | + } | ||
| 135 | + default: | ||
| 116 | // do nothing as config already have valid default values | 136 | // do nothing as config already have valid default values |
| 117 | break; | 137 | break; |
| 118 | } | 138 | } |
| 119 | } | 139 | } |
| 120 | - | 140 | + |
| 121 | SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); | 141 | SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); |
| 122 | SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); | 142 | SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); |
| 123 | 143 |
| @@ -20,6 +20,9 @@ class OnlineDecodeFiles | @@ -20,6 +20,9 @@ class OnlineDecodeFiles | ||
| 20 | [Option(Required = true, HelpText = "Path to tokens.txt")] | 20 | [Option(Required = true, HelpText = "Path to tokens.txt")] |
| 21 | public string Tokens { get; set; } | 21 | public string Tokens { get; set; } |
| 22 | 22 | ||
| 23 | + [Option(Required = false, Default = "cpu", HelpText = "Provider, e.g., cpu, coreml")] | ||
| 24 | + public string Provider { get; set; } | ||
| 25 | + | ||
| 23 | [Option(Required = true, HelpText = "Path to encoder.onnx")] | 26 | [Option(Required = true, HelpText = "Path to encoder.onnx")] |
| 24 | public string Encoder { get; set; } | 27 | public string Encoder { get; set; } |
| 25 | 28 | ||
| @@ -124,6 +127,7 @@ to download pre-trained streaming models. | @@ -124,6 +127,7 @@ to download pre-trained streaming models. | ||
| 124 | config.TransducerModelConfig.Decoder = options.Decoder; | 127 | config.TransducerModelConfig.Decoder = options.Decoder; |
| 125 | config.TransducerModelConfig.Joiner = options.Joiner; | 128 | config.TransducerModelConfig.Joiner = options.Joiner; |
| 126 | config.TransducerModelConfig.Tokens = options.Tokens; | 129 | config.TransducerModelConfig.Tokens = options.Tokens; |
| 130 | + config.TransducerModelConfig.Provider = options.Provider; | ||
| 127 | config.TransducerModelConfig.NumThreads = options.NumThreads; | 131 | config.TransducerModelConfig.NumThreads = options.NumThreads; |
| 128 | config.TransducerModelConfig.Debug = options.Debug ? 1 : 0; | 132 | config.TransducerModelConfig.Debug = options.Debug ? 1 : 0; |
| 129 | 133 |
| @@ -23,6 +23,7 @@ namespace SherpaOnnx | @@ -23,6 +23,7 @@ namespace SherpaOnnx | ||
| 23 | Joiner = ""; | 23 | Joiner = ""; |
| 24 | Tokens = ""; | 24 | Tokens = ""; |
| 25 | NumThreads = 1; | 25 | NumThreads = 1; |
| 26 | + Provider = "cpu"; | ||
| 26 | Debug = 0; | 27 | Debug = 0; |
| 27 | } | 28 | } |
| 28 | [MarshalAs(UnmanagedType.LPStr)] | 29 | [MarshalAs(UnmanagedType.LPStr)] |
| @@ -40,6 +41,9 @@ namespace SherpaOnnx | @@ -40,6 +41,9 @@ namespace SherpaOnnx | ||
| 40 | /// Number of threads used to run the neural network model | 41 | /// Number of threads used to run the neural network model |
| 41 | public int NumThreads; | 42 | public int NumThreads; |
| 42 | 43 | ||
| 44 | + [MarshalAs(UnmanagedType.LPStr)] | ||
| 45 | + public string Provider; | ||
| 46 | + | ||
| 43 | /// true to print debug information of the model | 47 | /// true to print debug information of the model |
| 44 | public int Debug; | 48 | public int Debug; |
| 45 | } | 49 | } |
-
请 注册 或 登录 后发表评论