正在显示
6 个修改的文件
包含
218 行增加
和
451 行删除
| @@ -92,6 +92,26 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | @@ -92,6 +92,26 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | ||
| 92 | template <typename Manager> | 92 | template <typename Manager> |
| 93 | std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | 93 | std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( |
| 94 | Manager *mgr, const OnlineRecognizerConfig &config) { | 94 | Manager *mgr, const OnlineRecognizerConfig &config) { |
| 95 | + if (config.model_config.provider_config.provider == "rknn") { | ||
| 96 | +#if SHERPA_ONNX_ENABLE_RKNN | ||
| 97 | + // Currently, only zipformer v1 is supported for rknn | ||
| 98 | + if (config.model_config.transducer.encoder.empty() && | ||
| 99 | + config.model_config.zipformer2_ctc.model.empty()) { | ||
| 100 | + SHERPA_ONNX_LOGE( | ||
| 101 | + "Only Zipformer transducers and CTC models are currently supported " | ||
| 102 | + "by rknn. Fallback to CPU"); | ||
| 103 | + } else if (!config.model_config.transducer.encoder.empty()) { | ||
| 104 | + return std::make_unique<OnlineRecognizerTransducerRknnImpl>(mgr, config); | ||
| 105 | + } else if (!config.model_config.zipformer2_ctc.model.empty()) { | ||
| 106 | + return std::make_unique<OnlineRecognizerCtcRknnImpl>(mgr, config); | ||
| 107 | + } | ||
| 108 | +#else | ||
| 109 | + SHERPA_ONNX_LOGE( | ||
| 110 | + "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " | ||
| 111 | + "want to use rknn. Fallback to CPU"); | ||
| 112 | +#endif | ||
| 113 | + } | ||
| 114 | + | ||
| 95 | if (!config.model_config.transducer.encoder.empty()) { | 115 | if (!config.model_config.transducer.encoder.empty()) { |
| 96 | Ort::Env env(ORT_LOGGING_LEVEL_ERROR); | 116 | Ort::Env env(ORT_LOGGING_LEVEL_ERROR); |
| 97 | 117 |
| @@ -42,39 +42,17 @@ class OnlineZipformerCtcModelRknn::Impl { | @@ -42,39 +42,17 @@ class OnlineZipformerCtcModelRknn::Impl { | ||
| 42 | Init(buf.data(), buf.size()); | 42 | Init(buf.data(), buf.size()); |
| 43 | } | 43 | } |
| 44 | 44 | ||
| 45 | - int32_t ret = RKNN_SUCC; | ||
| 46 | - switch (config_.num_threads) { | ||
| 47 | - case 1: | ||
| 48 | - ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_AUTO); | ||
| 49 | - break; | ||
| 50 | - case 0: | ||
| 51 | - ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0); | ||
| 52 | - break; | ||
| 53 | - case -1: | ||
| 54 | - ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_1); | ||
| 55 | - break; | ||
| 56 | - case -2: | ||
| 57 | - ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_2); | ||
| 58 | - break; | ||
| 59 | - case -3: | ||
| 60 | - ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1); | ||
| 61 | - break; | ||
| 62 | - case -4: | ||
| 63 | - ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1_2); | ||
| 64 | - break; | ||
| 65 | - default: | ||
| 66 | - SHERPA_ONNX_LOGE( | ||
| 67 | - "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core " | ||
| 68 | - "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d", | ||
| 69 | - config_.num_threads); | ||
| 70 | - break; | ||
| 71 | - } | ||
| 72 | - if (ret != RKNN_SUCC) { | ||
| 73 | - SHERPA_ONNX_LOGE( | ||
| 74 | - "Failed to select npu core to run the model (You can ignore it if " | ||
| 75 | - "you " | ||
| 76 | - "are not using RK3588."); | 45 | + SetCoreMask(ctx_, config_.num_threads); |
| 46 | + } | ||
| 47 | + | ||
| 48 | + template <typename Manager> | ||
| 49 | + Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config) { | ||
| 50 | + { | ||
| 51 | + auto buf = ReadFile(mgr, config.zipformer2_ctc.model); | ||
| 52 | + Init(buf.data(), buf.size()); | ||
| 77 | } | 53 | } |
| 54 | + | ||
| 55 | + SetCoreMask(ctx_, config_.num_threads); | ||
| 78 | } | 56 | } |
| 79 | 57 | ||
| 80 | // TODO(fangjun): Support Android | 58 | // TODO(fangjun): Support Android |
| @@ -209,86 +187,13 @@ class OnlineZipformerCtcModelRknn::Impl { | @@ -209,86 +187,13 @@ class OnlineZipformerCtcModelRknn::Impl { | ||
| 209 | 187 | ||
| 210 | private: | 188 | private: |
| 211 | void Init(void *model_data, size_t model_data_length) { | 189 | void Init(void *model_data, size_t model_data_length) { |
| 212 | - auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr); | ||
| 213 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init model '%s'", | ||
| 214 | - config_.zipformer2_ctc.model.c_str()); | ||
| 215 | - | ||
| 216 | - if (config_.debug) { | ||
| 217 | - rknn_sdk_version v; | ||
| 218 | - ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v)); | ||
| 219 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version"); | ||
| 220 | - | ||
| 221 | - SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version, | ||
| 222 | - v.drv_version); | ||
| 223 | - } | ||
| 224 | - | ||
| 225 | - rknn_input_output_num io_num; | ||
| 226 | - ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num)); | ||
| 227 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model"); | ||
| 228 | - | ||
| 229 | - if (config_.debug) { | ||
| 230 | - SHERPA_ONNX_LOGE("model: %d inputs, %d outputs", | ||
| 231 | - static_cast<int32_t>(io_num.n_input), | ||
| 232 | - static_cast<int32_t>(io_num.n_output)); | ||
| 233 | - } | ||
| 234 | - | ||
| 235 | - input_attrs_.resize(io_num.n_input); | ||
| 236 | - output_attrs_.resize(io_num.n_output); | ||
| 237 | - | ||
| 238 | - int32_t i = 0; | ||
| 239 | - for (auto &attr : input_attrs_) { | ||
| 240 | - memset(&attr, 0, sizeof(attr)); | ||
| 241 | - attr.index = i; | ||
| 242 | - ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); | ||
| 243 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i); | ||
| 244 | - i += 1; | ||
| 245 | - } | ||
| 246 | - | ||
| 247 | - if (config_.debug) { | ||
| 248 | - std::ostringstream os; | ||
| 249 | - std::string sep; | ||
| 250 | - for (auto &attr : input_attrs_) { | ||
| 251 | - os << sep << ToString(attr); | ||
| 252 | - sep = "\n"; | ||
| 253 | - } | ||
| 254 | - SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s", | ||
| 255 | - os.str().c_str()); | ||
| 256 | - } | ||
| 257 | - | ||
| 258 | - i = 0; | ||
| 259 | - for (auto &attr : output_attrs_) { | ||
| 260 | - memset(&attr, 0, sizeof(attr)); | ||
| 261 | - attr.index = i; | ||
| 262 | - ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); | ||
| 263 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i); | ||
| 264 | - i += 1; | ||
| 265 | - } | 190 | + InitContext(model_data, model_data_length, config_.debug, &ctx_); |
| 266 | 191 | ||
| 267 | - if (config_.debug) { | ||
| 268 | - std::ostringstream os; | ||
| 269 | - std::string sep; | ||
| 270 | - for (auto &attr : output_attrs_) { | ||
| 271 | - os << sep << ToString(attr); | ||
| 272 | - sep = "\n"; | ||
| 273 | - } | ||
| 274 | - SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s", | ||
| 275 | - os.str().c_str()); | ||
| 276 | - } | 192 | + InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_); |
| 277 | 193 | ||
| 278 | - rknn_custom_string custom_string; | ||
| 279 | - ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string, | ||
| 280 | - sizeof(custom_string)); | ||
| 281 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model"); | ||
| 282 | - if (config_.debug) { | ||
| 283 | - SHERPA_ONNX_LOGE("customs string: %s", custom_string.string); | ||
| 284 | - } | ||
| 285 | - auto meta = Parse(custom_string); | 194 | + rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug); |
| 286 | 195 | ||
| 287 | - if (config_.debug) { | ||
| 288 | - for (const auto &p : meta) { | ||
| 289 | - SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str()); | ||
| 290 | - } | ||
| 291 | - } | 196 | + auto meta = Parse(custom_string, config_.debug); |
| 292 | 197 | ||
| 293 | if (meta.count("T")) { | 198 | if (meta.count("T")) { |
| 294 | T_ = atoi(meta.at("T").c_str()); | 199 | T_ = atoi(meta.at("T").c_str()); |
| @@ -62,65 +62,31 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -62,65 +62,31 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 62 | InitJoiner(buf.data(), buf.size()); | 62 | InitJoiner(buf.data(), buf.size()); |
| 63 | } | 63 | } |
| 64 | 64 | ||
| 65 | - // Now select which core to run for RK3588 | ||
| 66 | - int32_t ret_encoder = RKNN_SUCC; | ||
| 67 | - int32_t ret_decoder = RKNN_SUCC; | ||
| 68 | - int32_t ret_joiner = RKNN_SUCC; | ||
| 69 | - switch (config_.num_threads) { | ||
| 70 | - case 1: | ||
| 71 | - ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_AUTO); | ||
| 72 | - ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_AUTO); | ||
| 73 | - ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_AUTO); | ||
| 74 | - break; | ||
| 75 | - case 0: | ||
| 76 | - ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0); | ||
| 77 | - ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0); | ||
| 78 | - ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0); | ||
| 79 | - break; | ||
| 80 | - case -1: | ||
| 81 | - ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_1); | ||
| 82 | - ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_1); | ||
| 83 | - ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_1); | ||
| 84 | - break; | ||
| 85 | - case -2: | ||
| 86 | - ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_2); | ||
| 87 | - ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_2); | ||
| 88 | - ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_2); | ||
| 89 | - break; | ||
| 90 | - case -3: | ||
| 91 | - ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0_1); | ||
| 92 | - ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0_1); | ||
| 93 | - ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0_1); | ||
| 94 | - break; | ||
| 95 | - case -4: | ||
| 96 | - ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0_1_2); | ||
| 97 | - ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0_1_2); | ||
| 98 | - ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0_1_2); | ||
| 99 | - break; | ||
| 100 | - default: | ||
| 101 | - SHERPA_ONNX_LOGE( | ||
| 102 | - "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core " | ||
| 103 | - "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d", | ||
| 104 | - config_.num_threads); | ||
| 105 | - break; | ||
| 106 | - } | ||
| 107 | - if (ret_encoder != RKNN_SUCC) { | ||
| 108 | - SHERPA_ONNX_LOGE( | ||
| 109 | - "Failed to select npu core to run encoder (You can ignore it if you " | ||
| 110 | - "are not using RK3588."); | 65 | + SetCoreMask(encoder_ctx_, config_.num_threads); |
| 66 | + SetCoreMask(decoder_ctx_, config_.num_threads); | ||
| 67 | + SetCoreMask(joiner_ctx_, config_.num_threads); | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + template <typename Manager> | ||
| 71 | + Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config) { | ||
| 72 | + { | ||
| 73 | + auto buf = ReadFile(mgr, config.transducer.encoder); | ||
| 74 | + InitEncoder(buf.data(), buf.size()); | ||
| 111 | } | 75 | } |
| 112 | 76 | ||
| 113 | - if (ret_decoder != RKNN_SUCC) { | ||
| 114 | - SHERPA_ONNX_LOGE( | ||
| 115 | - "Failed to select npu core to run decoder (You can ignore it if you " | ||
| 116 | - "are not using RK3588."); | 77 | + { |
| 78 | + auto buf = ReadFile(mgr, config.transducer.decoder); | ||
| 79 | + InitDecoder(buf.data(), buf.size()); | ||
| 117 | } | 80 | } |
| 118 | 81 | ||
| 119 | - if (ret_decoder != RKNN_SUCC) { | ||
| 120 | - SHERPA_ONNX_LOGE( | ||
| 121 | - "Failed to select npu core to run joiner (You can ignore it if you " | ||
| 122 | - "are not using RK3588."); | 82 | + { |
| 83 | + auto buf = ReadFile(mgr, config.transducer.joiner); | ||
| 84 | + InitJoiner(buf.data(), buf.size()); | ||
| 123 | } | 85 | } |
| 86 | + | ||
| 87 | + SetCoreMask(encoder_ctx_, config_.num_threads); | ||
| 88 | + SetCoreMask(decoder_ctx_, config_.num_threads); | ||
| 89 | + SetCoreMask(joiner_ctx_, config_.num_threads); | ||
| 124 | } | 90 | } |
| 125 | 91 | ||
| 126 | // TODO(fangjun): Support Android | 92 | // TODO(fangjun): Support Android |
| @@ -325,93 +291,15 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -325,93 +291,15 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 325 | 291 | ||
| 326 | private: | 292 | private: |
| 327 | void InitEncoder(void *model_data, size_t model_data_length) { | 293 | void InitEncoder(void *model_data, size_t model_data_length) { |
| 328 | - auto ret = | ||
| 329 | - rknn_init(&encoder_ctx_, model_data, model_data_length, 0, nullptr); | ||
| 330 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init encoder '%s'", | ||
| 331 | - config_.transducer.encoder.c_str()); | ||
| 332 | - | ||
| 333 | - if (config_.debug) { | ||
| 334 | - rknn_sdk_version v; | ||
| 335 | - ret = rknn_query(encoder_ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v)); | ||
| 336 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version"); | ||
| 337 | - | ||
| 338 | - SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version, | ||
| 339 | - v.drv_version); | ||
| 340 | - } | ||
| 341 | - | ||
| 342 | - rknn_input_output_num io_num; | ||
| 343 | - ret = rknn_query(encoder_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, | ||
| 344 | - sizeof(io_num)); | ||
| 345 | - SHERPA_ONNX_RKNN_CHECK(ret, | ||
| 346 | - "Failed to get I/O information for the encoder"); | ||
| 347 | - | ||
| 348 | - if (config_.debug) { | ||
| 349 | - SHERPA_ONNX_LOGE("encoder: %d inputs, %d outputs", | ||
| 350 | - static_cast<int32_t>(io_num.n_input), | ||
| 351 | - static_cast<int32_t>(io_num.n_output)); | ||
| 352 | - } | ||
| 353 | - | ||
| 354 | - encoder_input_attrs_.resize(io_num.n_input); | ||
| 355 | - encoder_output_attrs_.resize(io_num.n_output); | ||
| 356 | - | ||
| 357 | - int32_t i = 0; | ||
| 358 | - for (auto &attr : encoder_input_attrs_) { | ||
| 359 | - memset(&attr, 0, sizeof(attr)); | ||
| 360 | - attr.index = i; | ||
| 361 | - ret = | ||
| 362 | - rknn_query(encoder_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); | ||
| 363 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for encoder input %d", i); | ||
| 364 | - i += 1; | ||
| 365 | - } | ||
| 366 | - | ||
| 367 | - if (config_.debug) { | ||
| 368 | - std::ostringstream os; | ||
| 369 | - std::string sep; | ||
| 370 | - for (auto &attr : encoder_input_attrs_) { | ||
| 371 | - os << sep << ToString(attr); | ||
| 372 | - sep = "\n"; | ||
| 373 | - } | ||
| 374 | - SHERPA_ONNX_LOGE("\n----------Encoder inputs info----------\n%s", | ||
| 375 | - os.str().c_str()); | ||
| 376 | - } | ||
| 377 | - | ||
| 378 | - i = 0; | ||
| 379 | - for (auto &attr : encoder_output_attrs_) { | ||
| 380 | - memset(&attr, 0, sizeof(attr)); | ||
| 381 | - attr.index = i; | ||
| 382 | - ret = | ||
| 383 | - rknn_query(encoder_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); | ||
| 384 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for encoder output %d", | ||
| 385 | - i); | ||
| 386 | - i += 1; | ||
| 387 | - } | 294 | + InitContext(model_data, model_data_length, config_.debug, &encoder_ctx_); |
| 388 | 295 | ||
| 389 | - if (config_.debug) { | ||
| 390 | - std::ostringstream os; | ||
| 391 | - std::string sep; | ||
| 392 | - for (auto &attr : encoder_output_attrs_) { | ||
| 393 | - os << sep << ToString(attr); | ||
| 394 | - sep = "\n"; | ||
| 395 | - } | ||
| 396 | - SHERPA_ONNX_LOGE("\n----------Encoder outputs info----------\n%s", | ||
| 397 | - os.str().c_str()); | ||
| 398 | - } | 296 | + InitInputOutputAttrs(encoder_ctx_, config_.debug, &encoder_input_attrs_, |
| 297 | + &encoder_output_attrs_); | ||
| 399 | 298 | ||
| 400 | - rknn_custom_string custom_string; | ||
| 401 | - ret = rknn_query(encoder_ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string, | ||
| 402 | - sizeof(custom_string)); | ||
| 403 | - SHERPA_ONNX_RKNN_CHECK( | ||
| 404 | - ret, "Failed to read custom string from the encoder model"); | ||
| 405 | - if (config_.debug) { | ||
| 406 | - SHERPA_ONNX_LOGE("customs string: %s", custom_string.string); | ||
| 407 | - } | ||
| 408 | - auto meta = Parse(custom_string); | 299 | + rknn_custom_string custom_string = |
| 300 | + GetCustomString(encoder_ctx_, config_.debug); | ||
| 409 | 301 | ||
| 410 | - if (config_.debug) { | ||
| 411 | - for (const auto &p : meta) { | ||
| 412 | - SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str()); | ||
| 413 | - } | ||
| 414 | - } | 302 | + auto meta = Parse(custom_string, config_.debug); |
| 415 | 303 | ||
| 416 | if (meta.count("encoder_dims")) { | 304 | if (meta.count("encoder_dims")) { |
| 417 | SplitStringToIntegers(meta.at("encoder_dims"), ",", false, | 305 | SplitStringToIntegers(meta.at("encoder_dims"), ",", false, |
| @@ -479,58 +367,10 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -479,58 +367,10 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 479 | } | 367 | } |
| 480 | 368 | ||
| 481 | void InitDecoder(void *model_data, size_t model_data_length) { | 369 | void InitDecoder(void *model_data, size_t model_data_length) { |
| 482 | - auto ret = | ||
| 483 | - rknn_init(&decoder_ctx_, model_data, model_data_length, 0, nullptr); | ||
| 484 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init decoder '%s'", | ||
| 485 | - config_.transducer.decoder.c_str()); | ||
| 486 | - | ||
| 487 | - rknn_input_output_num io_num; | ||
| 488 | - ret = rknn_query(decoder_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, | ||
| 489 | - sizeof(io_num)); | ||
| 490 | - SHERPA_ONNX_RKNN_CHECK(ret, | ||
| 491 | - "Failed to get I/O information for the decoder"); | ||
| 492 | - | ||
| 493 | - if (io_num.n_input != 1) { | ||
| 494 | - SHERPA_ONNX_LOGE("Expect only 1 decoder input. Given %d", | ||
| 495 | - static_cast<int32_t>(io_num.n_input)); | ||
| 496 | - SHERPA_ONNX_EXIT(-1); | ||
| 497 | - } | 370 | + InitContext(model_data, model_data_length, config_.debug, &decoder_ctx_); |
| 498 | 371 | ||
| 499 | - if (io_num.n_output != 1) { | ||
| 500 | - SHERPA_ONNX_LOGE("Expect only 1 decoder output. Given %d", | ||
| 501 | - static_cast<int32_t>(io_num.n_output)); | ||
| 502 | - SHERPA_ONNX_EXIT(-1); | ||
| 503 | - } | ||
| 504 | - | ||
| 505 | - if (config_.debug) { | ||
| 506 | - SHERPA_ONNX_LOGE("decoder: %d inputs, %d outputs", | ||
| 507 | - static_cast<int32_t>(io_num.n_input), | ||
| 508 | - static_cast<int32_t>(io_num.n_output)); | ||
| 509 | - } | ||
| 510 | - | ||
| 511 | - decoder_input_attrs_.resize(io_num.n_input); | ||
| 512 | - decoder_output_attrs_.resize(io_num.n_output); | ||
| 513 | - | ||
| 514 | - int32_t i = 0; | ||
| 515 | - for (auto &attr : decoder_input_attrs_) { | ||
| 516 | - memset(&attr, 0, sizeof(attr)); | ||
| 517 | - attr.index = i; | ||
| 518 | - ret = | ||
| 519 | - rknn_query(decoder_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); | ||
| 520 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for decoder input %d", i); | ||
| 521 | - i += 1; | ||
| 522 | - } | ||
| 523 | - | ||
| 524 | - if (config_.debug) { | ||
| 525 | - std::ostringstream os; | ||
| 526 | - std::string sep; | ||
| 527 | - for (auto &attr : decoder_input_attrs_) { | ||
| 528 | - os << sep << ToString(attr); | ||
| 529 | - sep = "\n"; | ||
| 530 | - } | ||
| 531 | - SHERPA_ONNX_LOGE("\n----------Decoder inputs info----------\n%s", | ||
| 532 | - os.str().c_str()); | ||
| 533 | - } | 372 | + InitInputOutputAttrs(decoder_ctx_, config_.debug, &decoder_input_attrs_, |
| 373 | + &decoder_output_attrs_); | ||
| 534 | 374 | ||
| 535 | if (decoder_input_attrs_[0].type != RKNN_TENSOR_INT64) { | 375 | if (decoder_input_attrs_[0].type != RKNN_TENSOR_INT64) { |
| 536 | SHERPA_ONNX_LOGE("Expect int64 for decoder input. Given: %d, %s", | 376 | SHERPA_ONNX_LOGE("Expect int64 for decoder input. Given: %d, %s", |
| @@ -543,90 +383,13 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -543,90 +383,13 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 543 | if (config_.debug) { | 383 | if (config_.debug) { |
| 544 | SHERPA_ONNX_LOGE("context_size: %d", context_size_); | 384 | SHERPA_ONNX_LOGE("context_size: %d", context_size_); |
| 545 | } | 385 | } |
| 546 | - | ||
| 547 | - i = 0; | ||
| 548 | - for (auto &attr : decoder_output_attrs_) { | ||
| 549 | - memset(&attr, 0, sizeof(attr)); | ||
| 550 | - attr.index = i; | ||
| 551 | - ret = | ||
| 552 | - rknn_query(decoder_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); | ||
| 553 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for decoder output %d", | ||
| 554 | - i); | ||
| 555 | - i += 1; | ||
| 556 | - } | ||
| 557 | - | ||
| 558 | - if (config_.debug) { | ||
| 559 | - std::ostringstream os; | ||
| 560 | - std::string sep; | ||
| 561 | - for (auto &attr : decoder_output_attrs_) { | ||
| 562 | - os << sep << ToString(attr); | ||
| 563 | - sep = "\n"; | ||
| 564 | - } | ||
| 565 | - SHERPA_ONNX_LOGE("\n----------Decoder outputs info----------\n%s", | ||
| 566 | - os.str().c_str()); | ||
| 567 | - } | ||
| 568 | } | 386 | } |
| 569 | 387 | ||
| 570 | void InitJoiner(void *model_data, size_t model_data_length) { | 388 | void InitJoiner(void *model_data, size_t model_data_length) { |
| 571 | - auto ret = | ||
| 572 | - rknn_init(&joiner_ctx_, model_data, model_data_length, 0, nullptr); | ||
| 573 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init joiner '%s'", | ||
| 574 | - config_.transducer.joiner.c_str()); | 389 | + InitContext(model_data, model_data_length, config_.debug, &joiner_ctx_); |
| 575 | 390 | ||
| 576 | - rknn_input_output_num io_num; | ||
| 577 | - ret = | ||
| 578 | - rknn_query(joiner_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num)); | ||
| 579 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the joiner"); | ||
| 580 | - | ||
| 581 | - if (config_.debug) { | ||
| 582 | - SHERPA_ONNX_LOGE("joiner: %d inputs, %d outputs", | ||
| 583 | - static_cast<int32_t>(io_num.n_input), | ||
| 584 | - static_cast<int32_t>(io_num.n_output)); | ||
| 585 | - } | ||
| 586 | - | ||
| 587 | - joiner_input_attrs_.resize(io_num.n_input); | ||
| 588 | - joiner_output_attrs_.resize(io_num.n_output); | ||
| 589 | - | ||
| 590 | - int32_t i = 0; | ||
| 591 | - for (auto &attr : joiner_input_attrs_) { | ||
| 592 | - memset(&attr, 0, sizeof(attr)); | ||
| 593 | - attr.index = i; | ||
| 594 | - ret = rknn_query(joiner_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); | ||
| 595 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for joiner input %d", i); | ||
| 596 | - i += 1; | ||
| 597 | - } | ||
| 598 | - | ||
| 599 | - if (config_.debug) { | ||
| 600 | - std::ostringstream os; | ||
| 601 | - std::string sep; | ||
| 602 | - for (auto &attr : joiner_input_attrs_) { | ||
| 603 | - os << sep << ToString(attr); | ||
| 604 | - sep = "\n"; | ||
| 605 | - } | ||
| 606 | - SHERPA_ONNX_LOGE("\n----------Joiner inputs info----------\n%s", | ||
| 607 | - os.str().c_str()); | ||
| 608 | - } | ||
| 609 | - | ||
| 610 | - i = 0; | ||
| 611 | - for (auto &attr : joiner_output_attrs_) { | ||
| 612 | - memset(&attr, 0, sizeof(attr)); | ||
| 613 | - attr.index = i; | ||
| 614 | - ret = | ||
| 615 | - rknn_query(joiner_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); | ||
| 616 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for joiner output %d", i); | ||
| 617 | - i += 1; | ||
| 618 | - } | ||
| 619 | - | ||
| 620 | - if (config_.debug) { | ||
| 621 | - std::ostringstream os; | ||
| 622 | - std::string sep; | ||
| 623 | - for (auto &attr : joiner_output_attrs_) { | ||
| 624 | - os << sep << ToString(attr); | ||
| 625 | - sep = "\n"; | ||
| 626 | - } | ||
| 627 | - SHERPA_ONNX_LOGE("\n----------Joiner outputs info----------\n%s", | ||
| 628 | - os.str().c_str()); | ||
| 629 | - } | 391 | + InitInputOutputAttrs(joiner_ctx_, config_.debug, &joiner_input_attrs_, |
| 392 | + &joiner_output_attrs_); | ||
| 630 | 393 | ||
| 631 | vocab_size_ = joiner_output_attrs_[0].dims[1]; | 394 | vocab_size_ = joiner_output_attrs_[0].dims[1]; |
| 632 | if (config_.debug) { | 395 | if (config_.debug) { |
| @@ -4,6 +4,7 @@ | @@ -4,6 +4,7 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h" | 5 | #include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h" |
| 6 | 6 | ||
| 7 | +#include <memory> | ||
| 7 | #include <string> | 8 | #include <string> |
| 8 | #include <utility> | 9 | #include <utility> |
| 9 | #include <vector> | 10 | #include <vector> |
| @@ -39,6 +40,8 @@ class SileroVadModelRknn::Impl { | @@ -39,6 +40,8 @@ class SileroVadModelRknn::Impl { | ||
| 39 | auto buf = ReadFile(config.silero_vad.model); | 40 | auto buf = ReadFile(config.silero_vad.model); |
| 40 | Init(buf.data(), buf.size()); | 41 | Init(buf.data(), buf.size()); |
| 41 | 42 | ||
| 43 | + SetCoreMask(ctx_, config_.num_threads); | ||
| 44 | + | ||
| 42 | if (sample_rate_ != 16000) { | 45 | if (sample_rate_ != 16000) { |
| 43 | SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d", | 46 | SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d", |
| 44 | config.sample_rate); | 47 | config.sample_rate); |
| @@ -57,6 +60,8 @@ class SileroVadModelRknn::Impl { | @@ -57,6 +60,8 @@ class SileroVadModelRknn::Impl { | ||
| 57 | auto buf = ReadFile(mgr, config.silero_vad.model); | 60 | auto buf = ReadFile(mgr, config.silero_vad.model); |
| 58 | Init(buf.data(), buf.size()); | 61 | Init(buf.data(), buf.size()); |
| 59 | 62 | ||
| 63 | + SetCoreMask(ctx_, config_.num_threads); | ||
| 64 | + | ||
| 60 | if (sample_rate_ != 16000) { | 65 | if (sample_rate_ != 16000) { |
| 61 | SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d", | 66 | SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d", |
| 62 | config.sample_rate); | 67 | config.sample_rate); |
| @@ -172,80 +177,13 @@ class SileroVadModelRknn::Impl { | @@ -172,80 +177,13 @@ class SileroVadModelRknn::Impl { | ||
| 172 | 177 | ||
| 173 | private: | 178 | private: |
| 174 | void Init(void *model_data, size_t model_data_length) { | 179 | void Init(void *model_data, size_t model_data_length) { |
| 175 | - auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr); | ||
| 176 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init silero vad model '%s'", | ||
| 177 | - config_.silero_vad.model.c_str()); | ||
| 178 | - | ||
| 179 | - if (config_.debug) { | ||
| 180 | - rknn_sdk_version v; | ||
| 181 | - ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v)); | ||
| 182 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version"); | ||
| 183 | - | ||
| 184 | - SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version, | ||
| 185 | - v.drv_version); | ||
| 186 | - } | 180 | + InitContext(model_data, model_data_length, config_.debug, &ctx_); |
| 187 | 181 | ||
| 188 | - rknn_input_output_num io_num; | ||
| 189 | - ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num)); | ||
| 190 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model"); | 182 | + InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_); |
| 191 | 183 | ||
| 192 | - if (config_.debug) { | ||
| 193 | - SHERPA_ONNX_LOGE("model: %d inputs, %d outputs", | ||
| 194 | - static_cast<int32_t>(io_num.n_input), | ||
| 195 | - static_cast<int32_t>(io_num.n_output)); | ||
| 196 | - } | ||
| 197 | - | ||
| 198 | - input_attrs_.resize(io_num.n_input); | ||
| 199 | - output_attrs_.resize(io_num.n_output); | 184 | + rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug); |
| 200 | 185 | ||
| 201 | - int32_t i = 0; | ||
| 202 | - for (auto &attr : input_attrs_) { | ||
| 203 | - memset(&attr, 0, sizeof(attr)); | ||
| 204 | - attr.index = i; | ||
| 205 | - ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); | ||
| 206 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i); | ||
| 207 | - i += 1; | ||
| 208 | - } | ||
| 209 | - | ||
| 210 | - if (config_.debug) { | ||
| 211 | - std::ostringstream os; | ||
| 212 | - std::string sep; | ||
| 213 | - for (auto &attr : input_attrs_) { | ||
| 214 | - os << sep << ToString(attr); | ||
| 215 | - sep = "\n"; | ||
| 216 | - } | ||
| 217 | - SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s", | ||
| 218 | - os.str().c_str()); | ||
| 219 | - } | ||
| 220 | - | ||
| 221 | - i = 0; | ||
| 222 | - for (auto &attr : output_attrs_) { | ||
| 223 | - memset(&attr, 0, sizeof(attr)); | ||
| 224 | - attr.index = i; | ||
| 225 | - ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); | ||
| 226 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i); | ||
| 227 | - i += 1; | ||
| 228 | - } | ||
| 229 | - | ||
| 230 | - if (config_.debug) { | ||
| 231 | - std::ostringstream os; | ||
| 232 | - std::string sep; | ||
| 233 | - for (auto &attr : output_attrs_) { | ||
| 234 | - os << sep << ToString(attr); | ||
| 235 | - sep = "\n"; | ||
| 236 | - } | ||
| 237 | - SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s", | ||
| 238 | - os.str().c_str()); | ||
| 239 | - } | ||
| 240 | - | ||
| 241 | - rknn_custom_string custom_string; | ||
| 242 | - ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string, | ||
| 243 | - sizeof(custom_string)); | ||
| 244 | - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model"); | ||
| 245 | - if (config_.debug) { | ||
| 246 | - SHERPA_ONNX_LOGE("customs string: %s", custom_string.string); | ||
| 247 | - } | ||
| 248 | - auto meta = Parse(custom_string); | 186 | + auto meta = Parse(custom_string, config_.debug); |
| 249 | 187 | ||
| 250 | if (config_.silero_vad.window_size != 512) { | 188 | if (config_.silero_vad.window_size != 512) { |
| 251 | SHERPA_ONNX_LOGE("we require window_size to be 512. Given: %d", | 189 | SHERPA_ONNX_LOGE("we require window_size to be 512. Given: %d", |
| @@ -4,12 +4,15 @@ | @@ -4,12 +4,15 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/rknn/utils.h" | 5 | #include "sherpa-onnx/csrc/rknn/utils.h" |
| 6 | 6 | ||
| 7 | +#include <string.h> | ||
| 8 | + | ||
| 7 | #include <sstream> | 9 | #include <sstream> |
| 8 | #include <unordered_map> | 10 | #include <unordered_map> |
| 9 | #include <utility> | 11 | #include <utility> |
| 10 | #include <vector> | 12 | #include <vector> |
| 11 | 13 | ||
| 12 | #include "sherpa-onnx/csrc/macros.h" | 14 | #include "sherpa-onnx/csrc/macros.h" |
| 15 | +#include "sherpa-onnx/csrc/rknn/macros.h" | ||
| 13 | #include "sherpa-onnx/csrc/text-utils.h" | 16 | #include "sherpa-onnx/csrc/text-utils.h" |
| 14 | 17 | ||
| 15 | namespace sherpa_onnx { | 18 | namespace sherpa_onnx { |
| @@ -52,7 +55,7 @@ std::string ToString(const rknn_tensor_attr &attr) { | @@ -52,7 +55,7 @@ std::string ToString(const rknn_tensor_attr &attr) { | ||
| 52 | } | 55 | } |
| 53 | 56 | ||
| 54 | std::unordered_map<std::string, std::string> Parse( | 57 | std::unordered_map<std::string, std::string> Parse( |
| 55 | - const rknn_custom_string &custom_string) { | 58 | + const rknn_custom_string &custom_string, bool debug /*= false*/) { |
| 56 | std::unordered_map<std::string, std::string> ans; | 59 | std::unordered_map<std::string, std::string> ans; |
| 57 | std::vector<std::string> fields; | 60 | std::vector<std::string> fields; |
| 58 | SplitStringToVector(custom_string.string, ";", false, &fields); | 61 | SplitStringToVector(custom_string.string, ";", false, &fields); |
| @@ -68,7 +71,131 @@ std::unordered_map<std::string, std::string> Parse( | @@ -68,7 +71,131 @@ std::unordered_map<std::string, std::string> Parse( | ||
| 68 | ans[std::move(tmp[0])] = std::move(tmp[1]); | 71 | ans[std::move(tmp[0])] = std::move(tmp[1]); |
| 69 | } | 72 | } |
| 70 | 73 | ||
| 74 | + if (debug) { | ||
| 75 | + for (const auto &p : ans) { | ||
| 76 | + SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str()); | ||
| 77 | + } | ||
| 78 | + } | ||
| 79 | + | ||
| 71 | return ans; | 80 | return ans; |
| 72 | } | 81 | } |
| 73 | 82 | ||
| 83 | +void InitContext(void *model_data, size_t model_data_length, bool debug, | ||
| 84 | + rknn_context *ctx) { | ||
| 85 | + auto ret = rknn_init(ctx, model_data, model_data_length, 0, nullptr); | ||
| 86 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init rknn"); | ||
| 87 | + | ||
| 88 | + if (debug) { | ||
| 89 | + rknn_sdk_version v; | ||
| 90 | + ret = rknn_query(*ctx, RKNN_QUERY_SDK_VERSION, &v, sizeof(v)); | ||
| 91 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version"); | ||
| 92 | + | ||
| 93 | + SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version, | ||
| 94 | + v.drv_version); | ||
| 95 | + } | ||
| 96 | +} | ||
| 97 | + | ||
| 98 | +void InitInputOutputAttrs(rknn_context ctx, bool debug, | ||
| 99 | + std::vector<rknn_tensor_attr> *input_attrs, | ||
| 100 | + std::vector<rknn_tensor_attr> *output_attrs) { | ||
| 101 | + rknn_input_output_num io_num; | ||
| 102 | + auto ret = rknn_query(ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num)); | ||
| 103 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model"); | ||
| 104 | + | ||
| 105 | + if (debug) { | ||
| 106 | + SHERPA_ONNX_LOGE("model: %d inputs, %d outputs", | ||
| 107 | + static_cast<int32_t>(io_num.n_input), | ||
| 108 | + static_cast<int32_t>(io_num.n_output)); | ||
| 109 | + } | ||
| 110 | + | ||
| 111 | + input_attrs->resize(io_num.n_input); | ||
| 112 | + output_attrs->resize(io_num.n_output); | ||
| 113 | + | ||
| 114 | + int32_t i = 0; | ||
| 115 | + for (auto &attr : *input_attrs) { | ||
| 116 | + memset(&attr, 0, sizeof(attr)); | ||
| 117 | + attr.index = i; | ||
| 118 | + ret = rknn_query(ctx, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); | ||
| 119 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i); | ||
| 120 | + i += 1; | ||
| 121 | + } | ||
| 122 | + | ||
| 123 | + if (debug) { | ||
| 124 | + std::ostringstream os; | ||
| 125 | + std::string sep; | ||
| 126 | + for (auto &attr : *input_attrs) { | ||
| 127 | + os << sep << ToString(attr); | ||
| 128 | + sep = "\n"; | ||
| 129 | + } | ||
| 130 | + SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s", | ||
| 131 | + os.str().c_str()); | ||
| 132 | + } | ||
| 133 | + | ||
| 134 | + i = 0; | ||
| 135 | + for (auto &attr : *output_attrs) { | ||
| 136 | + memset(&attr, 0, sizeof(attr)); | ||
| 137 | + attr.index = i; | ||
| 138 | + ret = rknn_query(ctx, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); | ||
| 139 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i); | ||
| 140 | + i += 1; | ||
| 141 | + } | ||
| 142 | + | ||
| 143 | + if (debug) { | ||
| 144 | + std::ostringstream os; | ||
| 145 | + std::string sep; | ||
| 146 | + for (auto &attr : *output_attrs) { | ||
| 147 | + os << sep << ToString(attr); | ||
| 148 | + sep = "\n"; | ||
| 149 | + } | ||
| 150 | + SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s", | ||
| 151 | + os.str().c_str()); | ||
| 152 | + } | ||
| 153 | +} | ||
| 154 | + | ||
| 155 | +rknn_custom_string GetCustomString(rknn_context ctx, bool debug) { | ||
| 156 | + rknn_custom_string custom_string; | ||
| 157 | + auto ret = rknn_query(ctx, RKNN_QUERY_CUSTOM_STRING, &custom_string, | ||
| 158 | + sizeof(custom_string)); | ||
| 159 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model"); | ||
| 160 | + if (debug) { | ||
| 161 | + SHERPA_ONNX_LOGE("custom string: %s", custom_string.string); | ||
| 162 | + } | ||
| 163 | + return custom_string; | ||
| 164 | +} | ||
| 165 | + | ||
| 166 | +void SetCoreMask(rknn_context ctx, int32_t num_threads) { | ||
| 167 | + int32_t ret = RKNN_SUCC; | ||
| 168 | + switch (num_threads) { | ||
| 169 | + case 1: | ||
| 170 | + ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_AUTO); | ||
| 171 | + break; | ||
| 172 | + case 0: | ||
| 173 | + ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0); | ||
| 174 | + break; | ||
| 175 | + case -1: | ||
| 176 | + ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_1); | ||
| 177 | + break; | ||
| 178 | + case -2: | ||
| 179 | + ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_2); | ||
| 180 | + break; | ||
| 181 | + case -3: | ||
| 182 | + ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0_1); | ||
| 183 | + break; | ||
| 184 | + case -4: | ||
| 185 | + ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0_1_2); | ||
| 186 | + break; | ||
| 187 | + default: | ||
| 188 | + SHERPA_ONNX_LOGE( | ||
| 189 | + "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core " | ||
| 190 | + "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d", | ||
| 191 | + num_threads); | ||
| 192 | + break; | ||
| 193 | + } | ||
| 194 | + if (ret != RKNN_SUCC) { | ||
| 195 | + SHERPA_ONNX_LOGE( | ||
| 196 | + "Failed to select npu core to run the model (You can ignore it if " | ||
| 197 | + "you are not using RK3588)."); | ||
| 198 | + } | ||
| 199 | +} | ||
| 200 | + | ||
| 74 | } // namespace sherpa_onnx | 201 | } // namespace sherpa_onnx |
| @@ -7,17 +7,31 @@ | @@ -7,17 +7,31 @@ | ||
| 7 | 7 | ||
| 8 | #include <string> | 8 | #include <string> |
| 9 | #include <unordered_map> | 9 | #include <unordered_map> |
| 10 | +#include <vector> | ||
| 10 | 11 | ||
| 11 | #include "rknn_api.h" // NOLINT | 12 | #include "rknn_api.h" // NOLINT |
| 12 | 13 | ||
| 13 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 15 | + | ||
| 14 | void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel, | 16 | void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel, |
| 15 | int32_t height, int32_t width, float *dst); | 17 | int32_t height, int32_t width, float *dst); |
| 16 | 18 | ||
| 17 | std::string ToString(const rknn_tensor_attr &attr); | 19 | std::string ToString(const rknn_tensor_attr &attr); |
| 18 | 20 | ||
| 19 | std::unordered_map<std::string, std::string> Parse( | 21 | std::unordered_map<std::string, std::string> Parse( |
| 20 | - const rknn_custom_string &custom_string); | 22 | + const rknn_custom_string &custom_string, bool debug = false); |
| 23 | + | ||
| 24 | +void InitContext(void *model_data, size_t model_data_length, bool debug, | ||
| 25 | + rknn_context *ctx); | ||
| 26 | + | ||
| 27 | +void InitInputOutputAttrs(rknn_context ctx, bool debug, | ||
| 28 | + std::vector<rknn_tensor_attr> *input_attrs, | ||
| 29 | + std::vector<rknn_tensor_attr> *output_attrs); | ||
| 30 | + | ||
| 31 | +rknn_custom_string GetCustomString(rknn_context ctx, bool debug); | ||
| 32 | + | ||
| 33 | +void SetCoreMask(rknn_context ctx, int32_t num_threads); | ||
| 34 | + | ||
| 21 | } // namespace sherpa_onnx | 35 | } // namespace sherpa_onnx |
| 22 | 36 | ||
| 23 | #endif // SHERPA_ONNX_CSRC_RKNN_UTILS_H_ | 37 | #endif // SHERPA_ONNX_CSRC_RKNN_UTILS_H_ |
-
请 注册 或 登录 后发表评论