Committed by GitHub
support using xnnpack as execution provider (#612)
Showing 6 changed files with 29 additions and 11 deletions.
@@ -16,7 +16,7 @@ endif()
 
 set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.18.0/onnxruntime-linux-riscv64-static_lib-1.18.0.zip")
 set(onnxruntime_URL2 "https://hub.nuaa.cf/csukuangfj/onnxruntime-libs/releases/download/v1.18.0/onnxruntime-linux-riscv64-static_lib-1.18.0.zip")
-set(onnxruntime_HASH "SHA256=6791d695d17118db6815364c975a9d7ea9a8909754516ed1b089fe015c20912e")
+set(onnxruntime_HASH "SHA256=77ecc51d8caf0953755db6edcdec2fc03bce3f6d379bedd635be50bb95f88da5")
 
 # If you don't have access to the Internet,
 # please download onnxruntime to one of the following locations.
@@ -16,7 +16,7 @@ endif()
 
 set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.18.0/onnxruntime-linux-riscv64-1.18.0.zip")
 set(onnxruntime_URL2 "https://hub.nuaa.cf/csukuangfj/onnxruntime-libs/releases/download/v1.18.0/onnxruntime-linux-riscv64-1.18.0.zip")
-set(onnxruntime_HASH "SHA256=87ef36dbba28ee332069e7e511dcb409913bdeeed231b45172fe200d71c690a2")
+set(onnxruntime_HASH "SHA256=81a11b54d1d71f4b3161b00cba8576a07594abd218aa5c0d82382960ada06092")
 
 # If you don't have access to the Internet,
 # please download onnxruntime to one of the following locations.
@@ -20,6 +20,8 @@ Provider StringToProvider(std::string s) {
     return Provider::kCUDA;
   } else if (s == "coreml") {
     return Provider::kCoreML;
+  } else if (s == "xnnpack") {
+    return Provider::kXnnpack;
   } else {
     SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
     return Provider::kCPU;
@@ -13,9 +13,10 @@ namespace sherpa_onnx {
 // https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
 // for a list of available providers
 enum class Provider {
-  kCPU = 0,     // CPUExecutionProvider
-  kCUDA = 1,    // CUDAExecutionProvider
-  kCoreML = 2,  // CoreMLExecutionProvider
+  kCPU = 0,      // CPUExecutionProvider
+  kCUDA = 1,     // CUDAExecutionProvider
+  kCoreML = 2,   // CoreMLExecutionProvider
+  kXnnpack = 3,  // XnnpackExecutionProvider
 };
 
 /**
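Taken together with the StringToProvider() change above, the new enum value lets a caller request XNNPACK by name. A minimal sketch of how the mapping might be exercised; the include path and the helper function are assumptions for illustration, only Provider and StringToProvider come from this PR:

#include <string>

#include "sherpa-onnx/csrc/provider.h"  // assumed header location

// Hypothetical helper: returns true if the user asked for XNNPACK.
// Unknown names fall back to Provider::kCPU inside StringToProvider().
bool WantsXnnpack(const std::string &name) {
  return sherpa_onnx::StringToProvider(name) ==
         sherpa_onnx::Provider::kXnnpack;
}

When this returns true, the session-options code in the hunks below is what actually appends the XNNPACK execution provider.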
@@ -25,6 +25,12 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
   sess_opts.SetIntraOpNumThreads(num_threads);
   sess_opts.SetInterOpNumThreads(num_threads);
 
+  std::vector<std::string> available_providers = Ort::GetAvailableProviders();
+  std::ostringstream os;
+  for (const auto &ep : available_providers) {
+    os << ep << ", ";
+  }
+
   // Other possible options
   // sess_opts.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
   // sess_opts.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE);
@@ -33,9 +39,17 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
   switch (p) {
     case Provider::kCPU:
       break;  // nothing to do for the CPU provider
+    case Provider::kXnnpack: {
+      if (std::find(available_providers.begin(), available_providers.end(),
+                    "XnnpackExecutionProvider") != available_providers.end()) {
+        sess_opts.AppendExecutionProvider("XNNPACK");
+      } else {
+        SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!",
+                         os.str().c_str());
+      }
+      break;
+    }
     case Provider::kCUDA: {
-      std::vector<std::string> available_providers =
-          Ort::GetAvailableProviders();
       if (std::find(available_providers.begin(), available_providers.end(),
                     "CUDAExecutionProvider") != available_providers.end()) {
         // The CUDA provider is available, proceed with setting the options
@@ -47,8 +61,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
         sess_opts.AppendExecutionProvider_CUDA(options);
       } else {
         SHERPA_ONNX_LOGE(
-            "Please compile with -DSHERPA_ONNX_ENABLE_GPU=ON. Fallback to "
-            "cpu!");
+            "Please compile with -DSHERPA_ONNX_ENABLE_GPU=ON. Available "
+            "providers: %s. Fallback to cpu!",
+            os.str().c_str());
       }
       break;
     }
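The same check-then-append pattern can be reproduced directly against the ONNX Runtime C++ API, independent of sherpa-onnx. A minimal sketch, assuming an onnxruntime build that was compiled with the XNNPACK execution provider; the function name and the empty option map are illustrative, not part of the PR:

#include <algorithm>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#include "onnxruntime_cxx_api.h"

Ort::SessionOptions MakeXnnpackSessionOptions(int32_t num_threads) {
  Ort::SessionOptions sess_opts;
  sess_opts.SetIntraOpNumThreads(num_threads);
  sess_opts.SetInterOpNumThreads(num_threads);

  // Query the execution providers compiled into this onnxruntime build.
  std::vector<std::string> providers = Ort::GetAvailableProviders();

  // Append XNNPACK only if it is actually available; otherwise the
  // session silently stays on the default CPU provider.
  if (std::find(providers.begin(), providers.end(),
                "XnnpackExecutionProvider") != providers.end()) {
    // XNNPACK is registered via the string-based overload; provider
    // options could be passed in the map if needed.
    std::unordered_map<std::string, std::string> provider_options;
    sess_opts.AppendExecutionProvider("XNNPACK", provider_options);
  }

  return sess_opts;
}

If the provider is missing from Ort::GetAvailableProviders(), the options remain valid and inference falls back to the CPU provider, which is the same behavior the diff implements, with a logged warning listing the available providers.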
@@ -43,8 +43,8 @@ This program shows how to use VAD in sherpa-onnx.
 
   ./bin/sherpa-onnx-vad-microphone \
     --silero-vad-model=/path/to/silero_vad.onnx \
-    --provider=cpu \
-    --num-threads=1
+    --vad-provider=cpu \
+    --vad-num-threads=1
 
 Please download silero_vad.onnx from
 https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx