Fangjun Kuang
Committed by GitHub

Fix C# APIs (#183)

* Fix c# APIs

* reformat
1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR) 1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
2 project(sherpa-onnx) 2 project(sherpa-onnx)
3 3
4 -set(SHERPA_ONNX_VERSION "1.4.4") 4 +set(SHERPA_ONNX_VERSION "1.4.5")
5 5
6 # Disable warning about 6 # Disable warning about
7 # 7 #
@@ -5,59 +5,54 @@ @@ -5,59 +5,54 @@
5 // This file shows how to use sherpa-onnx C API 5 // This file shows how to use sherpa-onnx C API
6 // to decode a file. 6 // to decode a file.
7 7
8 -#include "cargs.h"  
9 #include <stdio.h> 8 #include <stdio.h>
10 #include <stdlib.h> 9 #include <stdlib.h>
11 #include <string.h> 10 #include <string.h>
12 11
  12 +#include "cargs.h"
13 #include "sherpa-onnx/c-api/c-api.h" 13 #include "sherpa-onnx/c-api/c-api.h"
14 14
15 static struct cag_option options[] = { 15 static struct cag_option options[] = {
16 - {  
17 - .identifier = 't', 16 + {.identifier = 'h',
  17 + .access_letters = "h",
  18 + .access_name = "help",
  19 + .description = "Show help"},
  20 + {.identifier = 't',
18 .access_letters = NULL, 21 .access_letters = NULL,
19 .access_name = "tokens", 22 .access_name = "tokens",
20 .value_name = "tokens", 23 .value_name = "tokens",
21 - .description = "Tokens file"  
22 - }, {  
23 - .identifier = 'e', 24 + .description = "Tokens file"},
  25 + {.identifier = 'e',
24 .access_letters = NULL, 26 .access_letters = NULL,
25 .access_name = "encoder", 27 .access_name = "encoder",
26 .value_name = "encoder", 28 .value_name = "encoder",
27 - .description = "Encoder ONNX file"  
28 - }, {  
29 - .identifier = 'd', 29 + .description = "Encoder ONNX file"},
  30 + {.identifier = 'd',
30 .access_letters = NULL, 31 .access_letters = NULL,
31 .access_name = "decoder", 32 .access_name = "decoder",
32 .value_name = "decoder", 33 .value_name = "decoder",
33 - .description = "Decoder ONNX file"  
34 - }, {  
35 - .identifier = 'j', 34 + .description = "Decoder ONNX file"},
  35 + {.identifier = 'j',
36 .access_letters = NULL, 36 .access_letters = NULL,
37 .access_name = "joiner", 37 .access_name = "joiner",
38 .value_name = "joiner", 38 .value_name = "joiner",
39 - .description = "Joiner ONNX file"  
40 - }, {  
41 - .identifier = 'n', 39 + .description = "Joiner ONNX file"},
  40 + {.identifier = 'n',
42 .access_letters = NULL, 41 .access_letters = NULL,
43 .access_name = "num-threads", 42 .access_name = "num-threads",
44 .value_name = "num-threads", 43 .value_name = "num-threads",
45 - .description = "Number of threads"  
46 - }, {  
47 - .identifier = 'p', 44 + .description = "Number of threads"},
  45 + {.identifier = 'p',
48 .access_letters = NULL, 46 .access_letters = NULL,
49 .access_name = "provider", 47 .access_name = "provider",
50 .value_name = "provider", 48 .value_name = "provider",
51 - .description = "Provider: cpu (default), cuda, coreml"  
52 - }, {  
53 - .identifier = 'm', 49 + .description = "Provider: cpu (default), cuda, coreml"},
  50 + {.identifier = 'm',
54 .access_letters = NULL, 51 .access_letters = NULL,
55 .access_name = "decoding-method", 52 .access_name = "decoding-method",
56 .value_name = "decoding-method", 53 .value_name = "decoding-method",
57 .description = 54 .description =
58 - "Decoding method: greedy_search (default), modified_beam_search"  
59 - }  
60 -}; 55 + "Decoding method: greedy_search (default), modified_beam_search"}};
61 56
62 const char *kUsage = 57 const char *kUsage =
63 "\n" 58 "\n"
@@ -67,6 +62,7 @@ const char *kUsage = @@ -67,6 +62,7 @@ const char *kUsage =
67 " --encoder=/path/to/encoder.onnx \\\n" 62 " --encoder=/path/to/encoder.onnx \\\n"
68 " --decoder=/path/to/decoder.onnx \\\n" 63 " --decoder=/path/to/decoder.onnx \\\n"
69 " --joiner=/path/to/joiner.onnx \\\n" 64 " --joiner=/path/to/joiner.onnx \\\n"
  65 + " --provider=cpu \\\n"
70 " /path/to/foo.wav\n" 66 " /path/to/foo.wav\n"
71 "\n\n" 67 "\n\n"
72 "Default num_threads is 1.\n" 68 "Default num_threads is 1.\n"
@@ -77,6 +73,11 @@ const char *kUsage = @@ -77,6 +73,11 @@ const char *kUsage =
77 "for a list of pre-trained models to download.\n"; 73 "for a list of pre-trained models to download.\n";
78 74
79 int32_t main(int32_t argc, char *argv[]) { 75 int32_t main(int32_t argc, char *argv[]) {
  76 + if (argc < 6) {
  77 + fprintf(stderr, "%s\n", kUsage);
  78 + exit(0);
  79 + }
  80 +
80 SherpaOnnxOnlineRecognizerConfig config; 81 SherpaOnnxOnlineRecognizerConfig config;
81 82
82 config.model_config.debug = 0; 83 config.model_config.debug = 0;
@@ -105,13 +106,32 @@ int32_t main(int32_t argc, char *argv[]) { @@ -105,13 +106,32 @@ int32_t main(int32_t argc, char *argv[]) {
105 identifier = cag_option_get(&context); 106 identifier = cag_option_get(&context);
106 value = cag_option_get_value(&context); 107 value = cag_option_get_value(&context);
107 switch (identifier) { 108 switch (identifier) {
108 - case 't': config.model_config.tokens = value; break;  
109 - case 'e': config.model_config.encoder = value; break;  
110 - case 'd': config.model_config.decoder = value; break;  
111 - case 'j': config.model_config.joiner = value; break;  
112 - case 'n': config.model_config.num_threads = atoi(value); break;  
113 - case 'p': config.model_config.provider = value; break;  
114 - case 'm': config.decoding_method = value; break; 109 + case 't':
  110 + config.model_config.tokens = value;
  111 + break;
  112 + case 'e':
  113 + config.model_config.encoder = value;
  114 + break;
  115 + case 'd':
  116 + config.model_config.decoder = value;
  117 + break;
  118 + case 'j':
  119 + config.model_config.joiner = value;
  120 + break;
  121 + case 'n':
  122 + config.model_config.num_threads = atoi(value);
  123 + break;
  124 + case 'p':
  125 + config.model_config.provider = value;
  126 + break;
  127 + case 'm':
  128 + config.decoding_method = value;
  129 + break;
  130 + case 'h': {
  131 + fprintf(stderr, "%s\n", kUsage);
  132 + exit(0);
  133 + break;
  134 + }
115 default: 135 default:
116 // do nothing as config already have valid default values 136 // do nothing as config already have valid default values
117 break; 137 break;
@@ -20,6 +20,9 @@ class OnlineDecodeFiles @@ -20,6 +20,9 @@ class OnlineDecodeFiles
20 [Option(Required = true, HelpText = "Path to tokens.txt")] 20 [Option(Required = true, HelpText = "Path to tokens.txt")]
21 public string Tokens { get; set; } 21 public string Tokens { get; set; }
22 22
  23 + [Option(Required = false, Default = "cpu", HelpText = "Provider, e.g., cpu, coreml")]
  24 + public string Provider { get; set; }
  25 +
23 [Option(Required = true, HelpText = "Path to encoder.onnx")] 26 [Option(Required = true, HelpText = "Path to encoder.onnx")]
24 public string Encoder { get; set; } 27 public string Encoder { get; set; }
25 28
@@ -124,6 +127,7 @@ to download pre-trained streaming models. @@ -124,6 +127,7 @@ to download pre-trained streaming models.
124 config.TransducerModelConfig.Decoder = options.Decoder; 127 config.TransducerModelConfig.Decoder = options.Decoder;
125 config.TransducerModelConfig.Joiner = options.Joiner; 128 config.TransducerModelConfig.Joiner = options.Joiner;
126 config.TransducerModelConfig.Tokens = options.Tokens; 129 config.TransducerModelConfig.Tokens = options.Tokens;
  130 + config.TransducerModelConfig.Provider = options.Provider;
127 config.TransducerModelConfig.NumThreads = options.NumThreads; 131 config.TransducerModelConfig.NumThreads = options.NumThreads;
128 config.TransducerModelConfig.Debug = options.Debug ? 1 : 0; 132 config.TransducerModelConfig.Debug = options.Debug ? 1 : 0;
129 133
@@ -23,6 +23,7 @@ namespace SherpaOnnx @@ -23,6 +23,7 @@ namespace SherpaOnnx
23 Joiner = ""; 23 Joiner = "";
24 Tokens = ""; 24 Tokens = "";
25 NumThreads = 1; 25 NumThreads = 1;
  26 + Provider = "cpu";
26 Debug = 0; 27 Debug = 0;
27 } 28 }
28 [MarshalAs(UnmanagedType.LPStr)] 29 [MarshalAs(UnmanagedType.LPStr)]
@@ -40,6 +41,9 @@ namespace SherpaOnnx @@ -40,6 +41,9 @@ namespace SherpaOnnx
40 /// Number of threads used to run the neural network model 41 /// Number of threads used to run the neural network model
41 public int NumThreads; 42 public int NumThreads;
42 43
  44 + [MarshalAs(UnmanagedType.LPStr)]
  45 + public string Provider;
  46 +
43 /// true to print debug information of the model 47 /// true to print debug information of the model
44 public int Debug; 48 public int Debug;
45 } 49 }