Fangjun Kuang
Committed by GitHub

Refactor rknn code (#2079)

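Extract the rknn setup logic that was duplicated across the zipformer CTC, zipformer transducer, and silero-vad models into shared helpers in sherpa-onnx/csrc/rknn/utils.h / utils.cc (SetCoreMask, InitContext, InitInputOutputAttrs, GetCustomString, and a debug-aware Parse), and add an rknn dispatch path to OnlineRecognizerImpl::Create().
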
@@ -92,6 +92,26 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
92 template <typename Manager> 92 template <typename Manager>
93 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( 93 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
94 Manager *mgr, const OnlineRecognizerConfig &config) { 94 Manager *mgr, const OnlineRecognizerConfig &config) {
  95 + if (config.model_config.provider_config.provider == "rknn") {
  96 +#if SHERPA_ONNX_ENABLE_RKNN
  97 + // Currently, only zipformer v1 is supported for rknn
  98 + if (config.model_config.transducer.encoder.empty() &&
  99 + config.model_config.zipformer2_ctc.model.empty()) {
  100 + SHERPA_ONNX_LOGE(
  101 + "Only Zipformer transducers and CTC models are currently supported "
  102 + "by rknn. Fallback to CPU");
  103 + } else if (!config.model_config.transducer.encoder.empty()) {
  104 + return std::make_unique<OnlineRecognizerTransducerRknnImpl>(mgr, config);
  105 + } else if (!config.model_config.zipformer2_ctc.model.empty()) {
  106 + return std::make_unique<OnlineRecognizerCtcRknnImpl>(mgr, config);
  107 + }
  108 +#else
  109 + SHERPA_ONNX_LOGE(
  110 + "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
  111 + "want to use rknn. Fallback to CPU");
  112 +#endif
  113 + }
  114 +
95 if (!config.model_config.transducer.encoder.empty()) { 115 if (!config.model_config.transducer.encoder.empty()) {
96 Ort::Env env(ORT_LOGGING_LEVEL_ERROR); 116 Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
97 117
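
For context, a minimal usage sketch (not part of this commit) of how the new rknn branch above is reached. The config field names match those referenced in the hunk; the model file paths are placeholders.

```cpp
#include "sherpa-onnx/csrc/online-recognizer.h"

// Hedged sketch: request the rknn provider for a streaming zipformer transducer.
// If sherpa-onnx was built without -DSHERPA_ONNX_ENABLE_RKNN=ON, Create() logs a
// warning and falls back to the CPU (onnxruntime) path that follows this branch.
sherpa_onnx::OnlineRecognizerConfig MakeRknnConfig() {
  sherpa_onnx::OnlineRecognizerConfig config;
  config.model_config.provider_config.provider = "rknn";
  config.model_config.transducer.encoder = "encoder.rknn";  // placeholder paths
  config.model_config.transducer.decoder = "decoder.rknn";
  config.model_config.transducer.joiner = "joiner.rknn";
  config.model_config.tokens = "tokens.txt";
  return config;
}
```
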
@@ -42,39 +42,17 @@ class OnlineZipformerCtcModelRknn::Impl {
42 Init(buf.data(), buf.size()); 42 Init(buf.data(), buf.size());
43 } 43 }
44 44
45 - int32_t ret = RKNN_SUCC;  
46 - switch (config_.num_threads) {  
47 - case 1:  
48 - ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_AUTO);  
49 - break;  
50 - case 0:  
51 - ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0);  
52 - break;  
53 - case -1:  
54 - ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_1);  
55 - break;  
56 - case -2:  
57 - ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_2);  
58 - break;  
59 - case -3:  
60 - ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1);  
61 - break;  
62 - case -4:  
63 - ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1_2);  
64 - break;  
65 - default:  
66 - SHERPA_ONNX_LOGE(  
67 - "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "  
68 - "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",  
69 - config_.num_threads);  
70 - break;  
71 - }  
72 - if (ret != RKNN_SUCC) {  
73 - SHERPA_ONNX_LOGE(  
74 - "Failed to select npu core to run the model (You can ignore it if "  
75 - "you "  
76 - "are not using RK3588."); 45 + SetCoreMask(ctx_, config_.num_threads);
  46 + }
  47 +
  48 + template <typename Manager>
  49 + Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config) {
  50 + {
  51 + auto buf = ReadFile(mgr, config.zipformer2_ctc.model);
  52 + Init(buf.data(), buf.size());
77 } 53 }
  54 +
  55 + SetCoreMask(ctx_, config_.num_threads);
78 } 56 }
79 57
80 // TODO(fangjun): Support Android 58 // TODO(fangjun): Support Android
@@ -209,86 +187,13 @@ class OnlineZipformerCtcModelRknn::Impl {
209 187
210 private: 188 private:
211 void Init(void *model_data, size_t model_data_length) { 189 void Init(void *model_data, size_t model_data_length) {
212 - auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);  
213 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init model '%s'",  
214 - config_.zipformer2_ctc.model.c_str());  
215 -  
216 - if (config_.debug) {  
217 - rknn_sdk_version v;  
218 - ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));  
219 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");  
220 -  
221 - SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,  
222 - v.drv_version);  
223 - }  
224 -  
225 - rknn_input_output_num io_num;  
226 - ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));  
227 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");  
228 -  
229 - if (config_.debug) {  
230 - SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",  
231 - static_cast<int32_t>(io_num.n_input),  
232 - static_cast<int32_t>(io_num.n_output));  
233 - }  
234 -  
235 - input_attrs_.resize(io_num.n_input);  
236 - output_attrs_.resize(io_num.n_output);  
237 -  
238 - int32_t i = 0;  
239 - for (auto &attr : input_attrs_) {  
240 - memset(&attr, 0, sizeof(attr));  
241 - attr.index = i;  
242 - ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));  
243 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);  
244 - i += 1;  
245 - }  
246 -  
247 - if (config_.debug) {  
248 - std::ostringstream os;  
249 - std::string sep;  
250 - for (auto &attr : input_attrs_) {  
251 - os << sep << ToString(attr);  
252 - sep = "\n";  
253 - }  
254 - SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",  
255 - os.str().c_str());  
256 - }  
257 -  
258 - i = 0;  
259 - for (auto &attr : output_attrs_) {  
260 - memset(&attr, 0, sizeof(attr));  
261 - attr.index = i;  
262 - ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));  
263 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);  
264 - i += 1;  
265 - } 190 + InitContext(model_data, model_data_length, config_.debug, &ctx_);
266 191
267 - if (config_.debug) {  
268 - std::ostringstream os;  
269 - std::string sep;  
270 - for (auto &attr : output_attrs_) {  
271 - os << sep << ToString(attr);  
272 - sep = "\n";  
273 - }  
274 - SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",  
275 - os.str().c_str());  
276 - } 192 + InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_);
277 193
278 - rknn_custom_string custom_string;  
279 - ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,  
280 - sizeof(custom_string));  
281 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");  
282 - if (config_.debug) {  
283 - SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);  
284 - }  
285 - auto meta = Parse(custom_string); 194 + rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug);
286 195
287 - if (config_.debug) {  
288 - for (const auto &p : meta) {  
289 - SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());  
290 - }  
291 - } 196 + auto meta = Parse(custom_string, config_.debug);
292 197
293 if (meta.count("T")) { 198 if (meta.count("T")) {
294 T_ = atoi(meta.at("T").c_str()); 199 T_ = atoi(meta.at("T").c_str());
@@ -62,65 +62,31 @@ class OnlineZipformerTransducerModelRknn::Impl {
62 InitJoiner(buf.data(), buf.size()); 62 InitJoiner(buf.data(), buf.size());
63 } 63 }
64 64
65 - // Now select which core to run for RK3588  
66 - int32_t ret_encoder = RKNN_SUCC;  
67 - int32_t ret_decoder = RKNN_SUCC;  
68 - int32_t ret_joiner = RKNN_SUCC;  
69 - switch (config_.num_threads) {  
70 - case 1:  
71 - ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_AUTO);  
72 - ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_AUTO);  
73 - ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_AUTO);  
74 - break;  
75 - case 0:  
76 - ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0);  
77 - ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0);  
78 - ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0);  
79 - break;  
80 - case -1:  
81 - ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_1);  
82 - ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_1);  
83 - ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_1);  
84 - break;  
85 - case -2:  
86 - ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_2);  
87 - ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_2);  
88 - ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_2);  
89 - break;  
90 - case -3:  
91 - ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0_1);  
92 - ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0_1);  
93 - ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0_1);  
94 - break;  
95 - case -4:  
96 - ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0_1_2);  
97 - ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0_1_2);  
98 - ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0_1_2);  
99 - break;  
100 - default:  
101 - SHERPA_ONNX_LOGE(  
102 - "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "  
103 - "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",  
104 - config_.num_threads);  
105 - break;  
106 - }  
107 - if (ret_encoder != RKNN_SUCC) {  
108 - SHERPA_ONNX_LOGE(  
109 - "Failed to select npu core to run encoder (You can ignore it if you "  
110 - "are not using RK3588."); 65 + SetCoreMask(encoder_ctx_, config_.num_threads);
  66 + SetCoreMask(decoder_ctx_, config_.num_threads);
  67 + SetCoreMask(joiner_ctx_, config_.num_threads);
  68 + }
  69 +
  70 + template <typename Manager>
  71 + Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config) {
  72 + {
  73 + auto buf = ReadFile(mgr, config.transducer.encoder);
  74 + InitEncoder(buf.data(), buf.size());
111 } 75 }
112 76
113 - if (ret_decoder != RKNN_SUCC) {  
114 - SHERPA_ONNX_LOGE(  
115 - "Failed to select npu core to run decoder (You can ignore it if you "  
116 - "are not using RK3588."); 77 + {
  78 + auto buf = ReadFile(mgr, config.transducer.decoder);
  79 + InitDecoder(buf.data(), buf.size());
117 } 80 }
118 81
119 - if (ret_decoder != RKNN_SUCC) {  
120 - SHERPA_ONNX_LOGE(  
121 - "Failed to select npu core to run joiner (You can ignore it if you "  
122 - "are not using RK3588."); 82 + {
  83 + auto buf = ReadFile(mgr, config.transducer.joiner);
  84 + InitJoiner(buf.data(), buf.size());
123 } 85 }
  86 +
  87 + SetCoreMask(encoder_ctx_, config_.num_threads);
  88 + SetCoreMask(decoder_ctx_, config_.num_threads);
  89 + SetCoreMask(joiner_ctx_, config_.num_threads);
124 } 90 }
125 91
126 // TODO(fangjun): Support Android 92 // TODO(fangjun): Support Android
@@ -325,93 +291,15 @@ class OnlineZipformerTransducerModelRknn::Impl {
325 291
326 private: 292 private:
327 void InitEncoder(void *model_data, size_t model_data_length) { 293 void InitEncoder(void *model_data, size_t model_data_length) {
328 - auto ret =  
329 - rknn_init(&encoder_ctx_, model_data, model_data_length, 0, nullptr);  
330 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init encoder '%s'",  
331 - config_.transducer.encoder.c_str());  
332 -  
333 - if (config_.debug) {  
334 - rknn_sdk_version v;  
335 - ret = rknn_query(encoder_ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));  
336 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");  
337 -  
338 - SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,  
339 - v.drv_version);  
340 - }  
341 -  
342 - rknn_input_output_num io_num;  
343 - ret = rknn_query(encoder_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num,  
344 - sizeof(io_num));  
345 - SHERPA_ONNX_RKNN_CHECK(ret,  
346 - "Failed to get I/O information for the encoder");  
347 -  
348 - if (config_.debug) {  
349 - SHERPA_ONNX_LOGE("encoder: %d inputs, %d outputs",  
350 - static_cast<int32_t>(io_num.n_input),  
351 - static_cast<int32_t>(io_num.n_output));  
352 - }  
353 -  
354 - encoder_input_attrs_.resize(io_num.n_input);  
355 - encoder_output_attrs_.resize(io_num.n_output);  
356 -  
357 - int32_t i = 0;  
358 - for (auto &attr : encoder_input_attrs_) {  
359 - memset(&attr, 0, sizeof(attr));  
360 - attr.index = i;  
361 - ret =  
362 - rknn_query(encoder_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));  
363 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for encoder input %d", i);  
364 - i += 1;  
365 - }  
366 -  
367 - if (config_.debug) {  
368 - std::ostringstream os;  
369 - std::string sep;  
370 - for (auto &attr : encoder_input_attrs_) {  
371 - os << sep << ToString(attr);  
372 - sep = "\n";  
373 - }  
374 - SHERPA_ONNX_LOGE("\n----------Encoder inputs info----------\n%s",  
375 - os.str().c_str());  
376 - }  
377 -  
378 - i = 0;  
379 - for (auto &attr : encoder_output_attrs_) {  
380 - memset(&attr, 0, sizeof(attr));  
381 - attr.index = i;  
382 - ret =  
383 - rknn_query(encoder_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));  
384 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for encoder output %d",  
385 - i);  
386 - i += 1;  
387 - } 294 + InitContext(model_data, model_data_length, config_.debug, &encoder_ctx_);
388 295
389 - if (config_.debug) {  
390 - std::ostringstream os;  
391 - std::string sep;  
392 - for (auto &attr : encoder_output_attrs_) {  
393 - os << sep << ToString(attr);  
394 - sep = "\n";  
395 - }  
396 - SHERPA_ONNX_LOGE("\n----------Encoder outputs info----------\n%s",  
397 - os.str().c_str());  
398 - } 296 + InitInputOutputAttrs(encoder_ctx_, config_.debug, &encoder_input_attrs_,
  297 + &encoder_output_attrs_);
399 298
400 - rknn_custom_string custom_string;  
401 - ret = rknn_query(encoder_ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,  
402 - sizeof(custom_string));  
403 - SHERPA_ONNX_RKNN_CHECK(  
404 - ret, "Failed to read custom string from the encoder model");  
405 - if (config_.debug) {  
406 - SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);  
407 - }  
408 - auto meta = Parse(custom_string); 299 + rknn_custom_string custom_string =
  300 + GetCustomString(encoder_ctx_, config_.debug);
409 301
410 - if (config_.debug) {  
411 - for (const auto &p : meta) {  
412 - SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());  
413 - }  
414 - } 302 + auto meta = Parse(custom_string, config_.debug);
415 303
416 if (meta.count("encoder_dims")) { 304 if (meta.count("encoder_dims")) {
417 SplitStringToIntegers(meta.at("encoder_dims"), ",", false, 305 SplitStringToIntegers(meta.at("encoder_dims"), ",", false,
@@ -479,58 +367,10 @@ class OnlineZipformerTransducerModelRknn::Impl {
479 } 367 }
480 368
481 void InitDecoder(void *model_data, size_t model_data_length) { 369 void InitDecoder(void *model_data, size_t model_data_length) {
482 - auto ret =  
483 - rknn_init(&decoder_ctx_, model_data, model_data_length, 0, nullptr);  
484 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init decoder '%s'",  
485 - config_.transducer.decoder.c_str());  
486 -  
487 - rknn_input_output_num io_num;  
488 - ret = rknn_query(decoder_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num,  
489 - sizeof(io_num));  
490 - SHERPA_ONNX_RKNN_CHECK(ret,  
491 - "Failed to get I/O information for the decoder");  
492 -  
493 - if (io_num.n_input != 1) {  
494 - SHERPA_ONNX_LOGE("Expect only 1 decoder input. Given %d",  
495 - static_cast<int32_t>(io_num.n_input));  
496 - SHERPA_ONNX_EXIT(-1);  
497 - } 370 + InitContext(model_data, model_data_length, config_.debug, &decoder_ctx_);
498 371
499 - if (io_num.n_output != 1) {  
500 - SHERPA_ONNX_LOGE("Expect only 1 decoder output. Given %d",  
501 - static_cast<int32_t>(io_num.n_output));  
502 - SHERPA_ONNX_EXIT(-1);  
503 - }  
504 -  
505 - if (config_.debug) {  
506 - SHERPA_ONNX_LOGE("decoder: %d inputs, %d outputs",  
507 - static_cast<int32_t>(io_num.n_input),  
508 - static_cast<int32_t>(io_num.n_output));  
509 - }  
510 -  
511 - decoder_input_attrs_.resize(io_num.n_input);  
512 - decoder_output_attrs_.resize(io_num.n_output);  
513 -  
514 - int32_t i = 0;  
515 - for (auto &attr : decoder_input_attrs_) {  
516 - memset(&attr, 0, sizeof(attr));  
517 - attr.index = i;  
518 - ret =  
519 - rknn_query(decoder_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));  
520 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for decoder input %d", i);  
521 - i += 1;  
522 - }  
523 -  
524 - if (config_.debug) {  
525 - std::ostringstream os;  
526 - std::string sep;  
527 - for (auto &attr : decoder_input_attrs_) {  
528 - os << sep << ToString(attr);  
529 - sep = "\n";  
530 - }  
531 - SHERPA_ONNX_LOGE("\n----------Decoder inputs info----------\n%s",  
532 - os.str().c_str());  
533 - } 372 + InitInputOutputAttrs(decoder_ctx_, config_.debug, &decoder_input_attrs_,
  373 + &decoder_output_attrs_);
534 374
535 if (decoder_input_attrs_[0].type != RKNN_TENSOR_INT64) { 375 if (decoder_input_attrs_[0].type != RKNN_TENSOR_INT64) {
536 SHERPA_ONNX_LOGE("Expect int64 for decoder input. Given: %d, %s", 376 SHERPA_ONNX_LOGE("Expect int64 for decoder input. Given: %d, %s",
@@ -543,90 +383,13 @@ class OnlineZipformerTransducerModelRknn::Impl {
543 if (config_.debug) { 383 if (config_.debug) {
544 SHERPA_ONNX_LOGE("context_size: %d", context_size_); 384 SHERPA_ONNX_LOGE("context_size: %d", context_size_);
545 } 385 }
546 -  
547 - i = 0;  
548 - for (auto &attr : decoder_output_attrs_) {  
549 - memset(&attr, 0, sizeof(attr));  
550 - attr.index = i;  
551 - ret =  
552 - rknn_query(decoder_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));  
553 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for decoder output %d",  
554 - i);  
555 - i += 1;  
556 - }  
557 -  
558 - if (config_.debug) {  
559 - std::ostringstream os;  
560 - std::string sep;  
561 - for (auto &attr : decoder_output_attrs_) {  
562 - os << sep << ToString(attr);  
563 - sep = "\n";  
564 - }  
565 - SHERPA_ONNX_LOGE("\n----------Decoder outputs info----------\n%s",  
566 - os.str().c_str());  
567 - }  
568 } 386 }
569 387
570 void InitJoiner(void *model_data, size_t model_data_length) { 388 void InitJoiner(void *model_data, size_t model_data_length) {
571 - auto ret =  
572 - rknn_init(&joiner_ctx_, model_data, model_data_length, 0, nullptr);  
573 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init joiner '%s'",  
574 - config_.transducer.joiner.c_str()); 389 + InitContext(model_data, model_data_length, config_.debug, &joiner_ctx_);
575 390
576 - rknn_input_output_num io_num;  
577 - ret =  
578 - rknn_query(joiner_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));  
579 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the joiner");  
580 -  
581 - if (config_.debug) {  
582 - SHERPA_ONNX_LOGE("joiner: %d inputs, %d outputs",  
583 - static_cast<int32_t>(io_num.n_input),  
584 - static_cast<int32_t>(io_num.n_output));  
585 - }  
586 -  
587 - joiner_input_attrs_.resize(io_num.n_input);  
588 - joiner_output_attrs_.resize(io_num.n_output);  
589 -  
590 - int32_t i = 0;  
591 - for (auto &attr : joiner_input_attrs_) {  
592 - memset(&attr, 0, sizeof(attr));  
593 - attr.index = i;  
594 - ret = rknn_query(joiner_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));  
595 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for joiner input %d", i);  
596 - i += 1;  
597 - }  
598 -  
599 - if (config_.debug) {  
600 - std::ostringstream os;  
601 - std::string sep;  
602 - for (auto &attr : joiner_input_attrs_) {  
603 - os << sep << ToString(attr);  
604 - sep = "\n";  
605 - }  
606 - SHERPA_ONNX_LOGE("\n----------Joiner inputs info----------\n%s",  
607 - os.str().c_str());  
608 - }  
609 -  
610 - i = 0;  
611 - for (auto &attr : joiner_output_attrs_) {  
612 - memset(&attr, 0, sizeof(attr));  
613 - attr.index = i;  
614 - ret =  
615 - rknn_query(joiner_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));  
616 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for joiner output %d", i);  
617 - i += 1;  
618 - }  
619 -  
620 - if (config_.debug) {  
621 - std::ostringstream os;  
622 - std::string sep;  
623 - for (auto &attr : joiner_output_attrs_) {  
624 - os << sep << ToString(attr);  
625 - sep = "\n";  
626 - }  
627 - SHERPA_ONNX_LOGE("\n----------Joiner outputs info----------\n%s",  
628 - os.str().c_str());  
629 - } 391 + InitInputOutputAttrs(joiner_ctx_, config_.debug, &joiner_input_attrs_,
  392 + &joiner_output_attrs_);
630 393
631 vocab_size_ = joiner_output_attrs_[0].dims[1]; 394 vocab_size_ = joiner_output_attrs_[0].dims[1];
632 if (config_.debug) { 395 if (config_.debug) {
@@ -4,6 +4,7 @@
4 4
5 #include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h" 5 #include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h"
6 6
  7 +#include <memory>
7 #include <string> 8 #include <string>
8 #include <utility> 9 #include <utility>
9 #include <vector> 10 #include <vector>
@@ -39,6 +40,8 @@ class SileroVadModelRknn::Impl {
39 auto buf = ReadFile(config.silero_vad.model); 40 auto buf = ReadFile(config.silero_vad.model);
40 Init(buf.data(), buf.size()); 41 Init(buf.data(), buf.size());
41 42
  43 + SetCoreMask(ctx_, config_.num_threads);
  44 +
42 if (sample_rate_ != 16000) { 45 if (sample_rate_ != 16000) {
43 SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d", 46 SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
44 config.sample_rate); 47 config.sample_rate);
@@ -57,6 +60,8 @@ class SileroVadModelRknn::Impl {
57 auto buf = ReadFile(mgr, config.silero_vad.model); 60 auto buf = ReadFile(mgr, config.silero_vad.model);
58 Init(buf.data(), buf.size()); 61 Init(buf.data(), buf.size());
59 62
  63 + SetCoreMask(ctx_, config_.num_threads);
  64 +
60 if (sample_rate_ != 16000) { 65 if (sample_rate_ != 16000) {
61 SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d", 66 SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
62 config.sample_rate); 67 config.sample_rate);
@@ -172,80 +177,13 @@ class SileroVadModelRknn::Impl {
172 177
173 private: 178 private:
174 void Init(void *model_data, size_t model_data_length) { 179 void Init(void *model_data, size_t model_data_length) {
175 - auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);  
176 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init silero vad model '%s'",  
177 - config_.silero_vad.model.c_str());  
178 -  
179 - if (config_.debug) {  
180 - rknn_sdk_version v;  
181 - ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));  
182 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");  
183 -  
184 - SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,  
185 - v.drv_version);  
186 - } 180 + InitContext(model_data, model_data_length, config_.debug, &ctx_);
187 181
188 - rknn_input_output_num io_num;  
189 - ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));  
190 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model"); 182 + InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_);
191 183
192 - if (config_.debug) {  
193 - SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",  
194 - static_cast<int32_t>(io_num.n_input),  
195 - static_cast<int32_t>(io_num.n_output));  
196 - }  
197 -  
198 - input_attrs_.resize(io_num.n_input);  
199 - output_attrs_.resize(io_num.n_output); 184 + rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug);
200 185
201 - int32_t i = 0;  
202 - for (auto &attr : input_attrs_) {  
203 - memset(&attr, 0, sizeof(attr));  
204 - attr.index = i;  
205 - ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));  
206 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);  
207 - i += 1;  
208 - }  
209 -  
210 - if (config_.debug) {  
211 - std::ostringstream os;  
212 - std::string sep;  
213 - for (auto &attr : input_attrs_) {  
214 - os << sep << ToString(attr);  
215 - sep = "\n";  
216 - }  
217 - SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",  
218 - os.str().c_str());  
219 - }  
220 -  
221 - i = 0;  
222 - for (auto &attr : output_attrs_) {  
223 - memset(&attr, 0, sizeof(attr));  
224 - attr.index = i;  
225 - ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));  
226 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);  
227 - i += 1;  
228 - }  
229 -  
230 - if (config_.debug) {  
231 - std::ostringstream os;  
232 - std::string sep;  
233 - for (auto &attr : output_attrs_) {  
234 - os << sep << ToString(attr);  
235 - sep = "\n";  
236 - }  
237 - SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",  
238 - os.str().c_str());  
239 - }  
240 -  
241 - rknn_custom_string custom_string;  
242 - ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,  
243 - sizeof(custom_string));  
244 - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");  
245 - if (config_.debug) {  
246 - SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);  
247 - }  
248 - auto meta = Parse(custom_string); 186 + auto meta = Parse(custom_string, config_.debug);
249 187
250 if (config_.silero_vad.window_size != 512) { 188 if (config_.silero_vad.window_size != 512) {
251 SHERPA_ONNX_LOGE("we require window_size to be 512. Given: %d", 189 SHERPA_ONNX_LOGE("we require window_size to be 512. Given: %d",
@@ -4,12 +4,15 @@
4 4
5 #include "sherpa-onnx/csrc/rknn/utils.h" 5 #include "sherpa-onnx/csrc/rknn/utils.h"
6 6
  7 +#include <string.h>
  8 +
7 #include <sstream> 9 #include <sstream>
8 #include <unordered_map> 10 #include <unordered_map>
9 #include <utility> 11 #include <utility>
10 #include <vector> 12 #include <vector>
11 13
12 #include "sherpa-onnx/csrc/macros.h" 14 #include "sherpa-onnx/csrc/macros.h"
  15 +#include "sherpa-onnx/csrc/rknn/macros.h"
13 #include "sherpa-onnx/csrc/text-utils.h" 16 #include "sherpa-onnx/csrc/text-utils.h"
14 17
15 namespace sherpa_onnx { 18 namespace sherpa_onnx {
@@ -52,7 +55,7 @@ std::string ToString(const rknn_tensor_attr &attr) {
52 } 55 }
53 56
54 std::unordered_map<std::string, std::string> Parse( 57 std::unordered_map<std::string, std::string> Parse(
55 - const rknn_custom_string &custom_string) { 58 + const rknn_custom_string &custom_string, bool debug /*= false*/) {
56 std::unordered_map<std::string, std::string> ans; 59 std::unordered_map<std::string, std::string> ans;
57 std::vector<std::string> fields; 60 std::vector<std::string> fields;
58 SplitStringToVector(custom_string.string, ";", false, &fields); 61 SplitStringToVector(custom_string.string, ";", false, &fields);
@@ -68,7 +71,131 @@ std::unordered_map<std::string, std::string> Parse(
68 ans[std::move(tmp[0])] = std::move(tmp[1]); 71 ans[std::move(tmp[0])] = std::move(tmp[1]);
69 } 72 }
70 73
  74 + if (debug) {
  75 + for (const auto &p : ans) {
  76 + SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
  77 + }
  78 + }
  79 +
71 return ans; 80 return ans;
72 } 81 }
73 82
  83 +void InitContext(void *model_data, size_t model_data_length, bool debug,
  84 + rknn_context *ctx) {
  85 + auto ret = rknn_init(ctx, model_data, model_data_length, 0, nullptr);
  86 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init rknn");
  87 +
  88 + if (debug) {
  89 + rknn_sdk_version v;
  90 + ret = rknn_query(*ctx, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
  91 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
  92 +
  93 + SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
  94 + v.drv_version);
  95 + }
  96 +}
  97 +
  98 +void InitInputOutputAttrs(rknn_context ctx, bool debug,
  99 + std::vector<rknn_tensor_attr> *input_attrs,
  100 + std::vector<rknn_tensor_attr> *output_attrs) {
  101 + rknn_input_output_num io_num;
  102 + auto ret = rknn_query(ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
  103 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
  104 +
  105 + if (debug) {
  106 + SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
  107 + static_cast<int32_t>(io_num.n_input),
  108 + static_cast<int32_t>(io_num.n_output));
  109 + }
  110 +
  111 + input_attrs->resize(io_num.n_input);
  112 + output_attrs->resize(io_num.n_output);
  113 +
  114 + int32_t i = 0;
  115 + for (auto &attr : *input_attrs) {
  116 + memset(&attr, 0, sizeof(attr));
  117 + attr.index = i;
  118 + ret = rknn_query(ctx, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
  119 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
  120 + i += 1;
  121 + }
  122 +
  123 + if (debug) {
  124 + std::ostringstream os;
  125 + std::string sep;
  126 + for (auto &attr : *input_attrs) {
  127 + os << sep << ToString(attr);
  128 + sep = "\n";
  129 + }
  130 + SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
  131 + os.str().c_str());
  132 + }
  133 +
  134 + i = 0;
  135 + for (auto &attr : *output_attrs) {
  136 + memset(&attr, 0, sizeof(attr));
  137 + attr.index = i;
  138 + ret = rknn_query(ctx, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
  139 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
  140 + i += 1;
  141 + }
  142 +
  143 + if (debug) {
  144 + std::ostringstream os;
  145 + std::string sep;
  146 + for (auto &attr : *output_attrs) {
  147 + os << sep << ToString(attr);
  148 + sep = "\n";
  149 + }
  150 + SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
  151 + os.str().c_str());
  152 + }
  153 +}
  154 +
  155 +rknn_custom_string GetCustomString(rknn_context ctx, bool debug) {
  156 + rknn_custom_string custom_string;
  157 + auto ret = rknn_query(ctx, RKNN_QUERY_CUSTOM_STRING, &custom_string,
  158 + sizeof(custom_string));
  159 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
  160 + if (debug) {
  161 + SHERPA_ONNX_LOGE("custom string: %s", custom_string.string);
  162 + }
  163 + return custom_string;
  164 +}
  165 +
  166 +void SetCoreMask(rknn_context ctx, int32_t num_threads) {
  167 + int32_t ret = RKNN_SUCC;
  168 + switch (num_threads) {
  169 + case 1:
  170 + ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_AUTO);
  171 + break;
  172 + case 0:
  173 + ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0);
  174 + break;
  175 + case -1:
  176 + ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_1);
  177 + break;
  178 + case -2:
  179 + ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_2);
  180 + break;
  181 + case -3:
  182 + ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0_1);
  183 + break;
  184 + case -4:
  185 + ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0_1_2);
  186 + break;
  187 + default:
  188 + SHERPA_ONNX_LOGE(
  189 + "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
  190 + "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
  191 + num_threads);
  192 + break;
  193 + }
  194 + if (ret != RKNN_SUCC) {
  195 + SHERPA_ONNX_LOGE(
  196 + "Failed to select npu core to run the model (You can ignore it if "
  197 + "you are not using RK3588.");
  198 + }
  199 +}
  200 +
74 } // namespace sherpa_onnx 201 } // namespace sherpa_onnx
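
As a usage note for the helpers above, here is a short hedged sketch (the file path and LoadRknnModel wrapper are hypothetical) of loading one .rknn model and pinning it to specific NPU cores; ReadFile is the helper already used by the model classes in this diff.

```cpp
#include <cstdint>
#include <string>

#include "sherpa-onnx/csrc/file-utils.h"  // ReadFile(), as used by the model Impls above
#include "sherpa-onnx/csrc/rknn/utils.h"

// Hedged sketch: num_threads encodes the NPU core mask, per SetCoreMask() above:
//   1 -> AUTO, 0 -> core 0, -1 -> core 1, -2 -> core 2, -3 -> cores 0+1, -4 -> cores 0+1+2
rknn_context LoadRknnModel(const std::string &path, int32_t num_threads) {
  auto buf = sherpa_onnx::ReadFile(path);

  rknn_context ctx = 0;
  sherpa_onnx::InitContext(buf.data(), buf.size(), /*debug=*/true, &ctx);
  sherpa_onnx::SetCoreMask(ctx, num_threads);  // e.g. -3 pins to cores 0 and 1 on RK3588
  return ctx;
}
```
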
@@ -7,17 +7,31 @@
7 7
8 #include <string> 8 #include <string>
9 #include <unordered_map> 9 #include <unordered_map>
  10 +#include <vector>
10 11
11 #include "rknn_api.h" // NOLINT 12 #include "rknn_api.h" // NOLINT
12 13
13 namespace sherpa_onnx { 14 namespace sherpa_onnx {
  15 +
14 void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel, 16 void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
15 int32_t height, int32_t width, float *dst); 17 int32_t height, int32_t width, float *dst);
16 18
17 std::string ToString(const rknn_tensor_attr &attr); 19 std::string ToString(const rknn_tensor_attr &attr);
18 20
19 std::unordered_map<std::string, std::string> Parse( 21 std::unordered_map<std::string, std::string> Parse(
20 - const rknn_custom_string &custom_string); 22 + const rknn_custom_string &custom_string, bool debug = false);
  23 +
  24 +void InitContext(void *model_data, size_t model_data_length, bool debug,
  25 + rknn_context *ctx);
  26 +
  27 +void InitInputOutputAttrs(rknn_context ctx, bool debug,
  28 + std::vector<rknn_tensor_attr> *input_attrs,
  29 + std::vector<rknn_tensor_attr> *output_attrs);
  30 +
  31 +rknn_custom_string GetCustomString(rknn_context ctx, bool debug);
  32 +
  33 +void SetCoreMask(rknn_context ctx, int32_t num_threads);
  34 +
21 } // namespace sherpa_onnx 35 } // namespace sherpa_onnx
22 36
23 #endif // SHERPA_ONNX_CSRC_RKNN_UTILS_H_ 37 #endif // SHERPA_ONNX_CSRC_RKNN_UTILS_H_
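
Finally, a hedged sketch (the DumpMeta wrapper is hypothetical) of how the remaining helpers declared above compose, mirroring the refactored Init() methods; keys such as "T" and "encoder_dims" are the ones consulted in this commit.

```cpp
#include <cstdlib>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"      // SHERPA_ONNX_LOGE
#include "sherpa-onnx/csrc/rknn/utils.h"

// Hedged sketch: query I/O attributes and custom-string metadata of an
// already-initialized rknn context, as the refactored Init() methods now do.
void DumpMeta(rknn_context ctx) {
  std::vector<rknn_tensor_attr> input_attrs;
  std::vector<rknn_tensor_attr> output_attrs;
  sherpa_onnx::InitInputOutputAttrs(ctx, /*debug=*/true, &input_attrs, &output_attrs);

  rknn_custom_string s = sherpa_onnx::GetCustomString(ctx, /*debug=*/true);
  auto meta = sherpa_onnx::Parse(s, /*debug=*/true);  // metadata fields separated by ';'

  if (meta.count("T")) {
    SHERPA_ONNX_LOGE("T: %d", atoi(meta.at("T").c_str()));
  }
}
```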