Fangjun Kuang
Committed by GitHub

Refactor rknn code (#2079)
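In short: the rknn setup code that each model duplicated (context creation, input/output attribute queries, custom-string handling, NPU core selection) moves into shared helpers in sherpa-onnx/csrc/rknn/utils.{h,cc}, and Parse() gains a debug parameter so it can log the parsed metadata itself. A simplified sketch (not verbatim from the hunks below) of how a model Impl now composes the helpers:

// Simplified sketch only; the member names ctx_, input_attrs_, output_attrs_
// and config_ are taken from the hunks below.
void Init(void *model_data, size_t model_data_length) {
  // rknn_init() plus optional SDK/driver version logging
  InitContext(model_data, model_data_length, config_.debug, &ctx_);

  // query the number of inputs/outputs and fill the per-tensor attributes
  InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_);

  // read the custom string embedded in the .rknn file and turn it into
  // key/value metadata (e.g. "T", "encoder_dims")
  rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug);
  auto meta = Parse(custom_string, config_.debug);
}
// The constructors additionally call SetCoreMask(ctx_, config_.num_threads)
// to select the NPU core(s) after Init() returns.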

... ... @@ -92,6 +92,26 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
template <typename Manager>
std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
Manager *mgr, const OnlineRecognizerConfig &config) {
if (config.model_config.provider_config.provider == "rknn") {
#if SHERPA_ONNX_ENABLE_RKNN
// Currently, only zipformer v1 is supported for rknn
if (config.model_config.transducer.encoder.empty() &&
config.model_config.zipformer2_ctc.model.empty()) {
SHERPA_ONNX_LOGE(
"Only Zipformer transducers and CTC models are currently supported "
"by rknn. Fallback to CPU");
} else if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<OnlineRecognizerTransducerRknnImpl>(mgr, config);
} else if (!config.model_config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcRknnImpl>(mgr, config);
}
#else
SHERPA_ONNX_LOGE(
"Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
"want to use rknn. Fallback to CPU");
#endif
}
if (!config.model_config.transducer.encoder.empty()) {
Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
... ...
... ... @@ -42,39 +42,17 @@ class OnlineZipformerCtcModelRknn::Impl {
Init(buf.data(), buf.size());
}
int32_t ret = RKNN_SUCC;
switch (config_.num_threads) {
case 1:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_AUTO);
break;
case 0:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0);
break;
case -1:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_1);
break;
case -2:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_2);
break;
case -3:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1);
break;
case -4:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1_2);
break;
default:
SHERPA_ONNX_LOGE(
"Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
"1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
config_.num_threads);
break;
}
if (ret != RKNN_SUCC) {
  SHERPA_ONNX_LOGE(
      "Failed to select npu core to run the model (You can ignore it if "
      "you "
      "are not using RK3588.");
}

SetCoreMask(ctx_, config_.num_threads);
}
template <typename Manager>
Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config) {
{
auto buf = ReadFile(mgr, config.zipformer2_ctc.model);
Init(buf.data(), buf.size());
}
SetCoreMask(ctx_, config_.num_threads);
}
// TODO(fangjun): Support Android
... ... @@ -209,86 +187,13 @@ class OnlineZipformerCtcModelRknn::Impl {
private:
void Init(void *model_data, size_t model_data_length) {
auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init model '%s'",
config_.zipformer2_ctc.model.c_str());
if (config_.debug) {
rknn_sdk_version v;
ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
v.drv_version);
}
rknn_input_output_num io_num;
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
if (config_.debug) {
SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
static_cast<int32_t>(io_num.n_input),
static_cast<int32_t>(io_num.n_output));
}
input_attrs_.resize(io_num.n_input);
output_attrs_.resize(io_num.n_output);
int32_t i = 0;
for (auto &attr : input_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
i += 1;
}
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : input_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
os.str().c_str());
}
i = 0;
for (auto &attr : output_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
i += 1;
}
InitContext(model_data, model_data_length, config_.debug, &ctx_);
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : output_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
os.str().c_str());
}
InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_);
rknn_custom_string custom_string;
ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
                 sizeof(custom_string));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
if (config_.debug) {
  SHERPA_ONNX_LOGE("custom string: %s", custom_string.string);
}
auto meta = Parse(custom_string);
if (config_.debug) {
  for (const auto &p : meta) {
    SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
  }
}

rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug);
auto meta = Parse(custom_string, config_.debug);
if (meta.count("T")) {
T_ = atoi(meta.at("T").c_str());
... ...
... ... @@ -62,65 +62,31 @@ class OnlineZipformerTransducerModelRknn::Impl {
InitJoiner(buf.data(), buf.size());
}
// Now select which core to run for RK3588
int32_t ret_encoder = RKNN_SUCC;
int32_t ret_decoder = RKNN_SUCC;
int32_t ret_joiner = RKNN_SUCC;
switch (config_.num_threads) {
case 1:
ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_AUTO);
ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_AUTO);
ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_AUTO);
break;
case 0:
ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0);
ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0);
ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0);
break;
case -1:
ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_1);
ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_1);
ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_1);
break;
case -2:
ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_2);
ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_2);
ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_2);
break;
case -3:
ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0_1);
ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0_1);
ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0_1);
break;
case -4:
ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0_1_2);
ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0_1_2);
ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0_1_2);
break;
default:
SHERPA_ONNX_LOGE(
"Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
"1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
config_.num_threads);
break;
}
if (ret_encoder != RKNN_SUCC) {
  SHERPA_ONNX_LOGE(
      "Failed to select npu core to run encoder (You can ignore it if you "
      "are not using RK3588.");
}
if (ret_decoder != RKNN_SUCC) {
  SHERPA_ONNX_LOGE(
      "Failed to select npu core to run decoder (You can ignore it if you "
      "are not using RK3588.");
}
if (ret_decoder != RKNN_SUCC) {
  SHERPA_ONNX_LOGE(
      "Failed to select npu core to run joiner (You can ignore it if you "
      "are not using RK3588.");
}

SetCoreMask(encoder_ctx_, config_.num_threads);
SetCoreMask(decoder_ctx_, config_.num_threads);
SetCoreMask(joiner_ctx_, config_.num_threads);
}

template <typename Manager>
Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config) {
  {
    auto buf = ReadFile(mgr, config.transducer.encoder);
    InitEncoder(buf.data(), buf.size());
  }
  {
    auto buf = ReadFile(mgr, config.transducer.decoder);
    InitDecoder(buf.data(), buf.size());
  }
  {
    auto buf = ReadFile(mgr, config.transducer.joiner);
    InitJoiner(buf.data(), buf.size());
  }

  SetCoreMask(encoder_ctx_, config_.num_threads);
  SetCoreMask(decoder_ctx_, config_.num_threads);
  SetCoreMask(joiner_ctx_, config_.num_threads);
}
// TODO(fangjun): Support Android
... ... @@ -325,93 +291,15 @@ class OnlineZipformerTransducerModelRknn::Impl {
private:
void InitEncoder(void *model_data, size_t model_data_length) {
auto ret =
rknn_init(&encoder_ctx_, model_data, model_data_length, 0, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init encoder '%s'",
config_.transducer.encoder.c_str());
if (config_.debug) {
rknn_sdk_version v;
ret = rknn_query(encoder_ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
v.drv_version);
}
rknn_input_output_num io_num;
ret = rknn_query(encoder_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num,
sizeof(io_num));
SHERPA_ONNX_RKNN_CHECK(ret,
"Failed to get I/O information for the encoder");
if (config_.debug) {
SHERPA_ONNX_LOGE("encoder: %d inputs, %d outputs",
static_cast<int32_t>(io_num.n_input),
static_cast<int32_t>(io_num.n_output));
}
encoder_input_attrs_.resize(io_num.n_input);
encoder_output_attrs_.resize(io_num.n_output);
int32_t i = 0;
for (auto &attr : encoder_input_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret =
rknn_query(encoder_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for encoder input %d", i);
i += 1;
}
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : encoder_input_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Encoder inputs info----------\n%s",
os.str().c_str());
}
i = 0;
for (auto &attr : encoder_output_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret =
rknn_query(encoder_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for encoder output %d",
i);
i += 1;
}
InitContext(model_data, model_data_length, config_.debug, &encoder_ctx_);
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : encoder_output_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Encoder outputs info----------\n%s",
os.str().c_str());
}
InitInputOutputAttrs(encoder_ctx_, config_.debug, &encoder_input_attrs_,
&encoder_output_attrs_);
rknn_custom_string custom_string;
ret = rknn_query(encoder_ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
                 sizeof(custom_string));
SHERPA_ONNX_RKNN_CHECK(
    ret, "Failed to read custom string from the encoder model");
if (config_.debug) {
  SHERPA_ONNX_LOGE("custom string: %s", custom_string.string);
}
auto meta = Parse(custom_string);
if (config_.debug) {
  for (const auto &p : meta) {
    SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
  }
}

rknn_custom_string custom_string =
    GetCustomString(encoder_ctx_, config_.debug);
auto meta = Parse(custom_string, config_.debug);
if (meta.count("encoder_dims")) {
SplitStringToIntegers(meta.at("encoder_dims"), ",", false,
... ... @@ -479,58 +367,10 @@ class OnlineZipformerTransducerModelRknn::Impl {
}
void InitDecoder(void *model_data, size_t model_data_length) {
auto ret =
rknn_init(&decoder_ctx_, model_data, model_data_length, 0, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init decoder '%s'",
config_.transducer.decoder.c_str());
rknn_input_output_num io_num;
ret = rknn_query(decoder_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num,
sizeof(io_num));
SHERPA_ONNX_RKNN_CHECK(ret,
"Failed to get I/O information for the decoder");
if (io_num.n_input != 1) {
SHERPA_ONNX_LOGE("Expect only 1 decoder input. Given %d",
static_cast<int32_t>(io_num.n_input));
SHERPA_ONNX_EXIT(-1);
}
InitContext(model_data, model_data_length, config_.debug, &decoder_ctx_);
if (io_num.n_output != 1) {
SHERPA_ONNX_LOGE("Expect only 1 decoder output. Given %d",
static_cast<int32_t>(io_num.n_output));
SHERPA_ONNX_EXIT(-1);
}
if (config_.debug) {
SHERPA_ONNX_LOGE("decoder: %d inputs, %d outputs",
static_cast<int32_t>(io_num.n_input),
static_cast<int32_t>(io_num.n_output));
}
decoder_input_attrs_.resize(io_num.n_input);
decoder_output_attrs_.resize(io_num.n_output);
int32_t i = 0;
for (auto &attr : decoder_input_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret =
rknn_query(decoder_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for decoder input %d", i);
i += 1;
}
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : decoder_input_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Decoder inputs info----------\n%s",
os.str().c_str());
}
InitInputOutputAttrs(decoder_ctx_, config_.debug, &decoder_input_attrs_,
&decoder_output_attrs_);
if (decoder_input_attrs_[0].type != RKNN_TENSOR_INT64) {
SHERPA_ONNX_LOGE("Expect int64 for decoder input. Given: %d, %s",
... ... @@ -543,90 +383,13 @@ class OnlineZipformerTransducerModelRknn::Impl {
if (config_.debug) {
SHERPA_ONNX_LOGE("context_size: %d", context_size_);
}
i = 0;
for (auto &attr : decoder_output_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret =
rknn_query(decoder_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for decoder output %d",
i);
i += 1;
}
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : decoder_output_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Decoder outputs info----------\n%s",
os.str().c_str());
}
}
void InitJoiner(void *model_data, size_t model_data_length) {
auto ret =
rknn_init(&joiner_ctx_, model_data, model_data_length, 0, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init joiner '%s'",
config_.transducer.joiner.c_str());
InitContext(model_data, model_data_length, config_.debug, &joiner_ctx_);
rknn_input_output_num io_num;
ret =
rknn_query(joiner_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the joiner");
if (config_.debug) {
SHERPA_ONNX_LOGE("joiner: %d inputs, %d outputs",
static_cast<int32_t>(io_num.n_input),
static_cast<int32_t>(io_num.n_output));
}
joiner_input_attrs_.resize(io_num.n_input);
joiner_output_attrs_.resize(io_num.n_output);
int32_t i = 0;
for (auto &attr : joiner_input_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(joiner_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for joiner input %d", i);
i += 1;
}
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : joiner_input_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Joiner inputs info----------\n%s",
os.str().c_str());
}
i = 0;
for (auto &attr : joiner_output_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret =
rknn_query(joiner_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for joiner output %d", i);
i += 1;
}
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : joiner_output_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Joiner outputs info----------\n%s",
os.str().c_str());
}
InitInputOutputAttrs(joiner_ctx_, config_.debug, &joiner_input_attrs_,
&joiner_output_attrs_);
vocab_size_ = joiner_output_attrs_[0].dims[1];
if (config_.debug) {
... ...
... ... @@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
... ... @@ -39,6 +40,8 @@ class SileroVadModelRknn::Impl {
auto buf = ReadFile(config.silero_vad.model);
Init(buf.data(), buf.size());
SetCoreMask(ctx_, config_.num_threads);
if (sample_rate_ != 16000) {
SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
config.sample_rate);
... ... @@ -57,6 +60,8 @@ class SileroVadModelRknn::Impl {
auto buf = ReadFile(mgr, config.silero_vad.model);
Init(buf.data(), buf.size());
SetCoreMask(ctx_, config_.num_threads);
if (sample_rate_ != 16000) {
SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
config.sample_rate);
... ... @@ -172,80 +177,13 @@ class SileroVadModelRknn::Impl {
private:
void Init(void *model_data, size_t model_data_length) {
auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init silero vad model '%s'",
config_.silero_vad.model.c_str());
if (config_.debug) {
rknn_sdk_version v;
ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
v.drv_version);
}
InitContext(model_data, model_data_length, config_.debug, &ctx_);
rknn_input_output_num io_num;
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_);
if (config_.debug) {
SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
static_cast<int32_t>(io_num.n_input),
static_cast<int32_t>(io_num.n_output));
}
input_attrs_.resize(io_num.n_input);
output_attrs_.resize(io_num.n_output);
rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug);
int32_t i = 0;
for (auto &attr : input_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
i += 1;
}
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : input_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
os.str().c_str());
}
i = 0;
for (auto &attr : output_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
i += 1;
}
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : output_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
os.str().c_str());
}
rknn_custom_string custom_string;
ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
sizeof(custom_string));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
if (config_.debug) {
SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
}
auto meta = Parse(custom_string);
auto meta = Parse(custom_string, config_.debug);
if (config_.silero_vad.window_size != 512) {
SHERPA_ONNX_LOGE("we require window_size to be 512. Given: %d",
... ...
... ... @@ -4,12 +4,15 @@
#include "sherpa-onnx/csrc/rknn/utils.h"
#include <string.h>
#include <sstream>
#include <unordered_map>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/rknn/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
... ... @@ -52,7 +55,7 @@ std::string ToString(const rknn_tensor_attr &attr) {
}
std::unordered_map<std::string, std::string> Parse(
const rknn_custom_string &custom_string) {
const rknn_custom_string &custom_string, bool debug /*= false*/) {
std::unordered_map<std::string, std::string> ans;
std::vector<std::string> fields;
SplitStringToVector(custom_string.string, ";", false, &fields);
... ... @@ -68,7 +71,131 @@ std::unordered_map<std::string, std::string> Parse(
ans[std::move(tmp[0])] = std::move(tmp[1]);
}
if (debug) {
for (const auto &p : ans) {
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
}
}
return ans;
}
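
For orientation, a hypothetical snippet (ShowMetadata() is not part of this change, and the custom-string contents are invented) showing what Parse() consumes and produces:

// Illustration only; the exact per-field key/value split is handled in
// lines elided from this hunk.
void ShowMetadata(rknn_context ctx) {  // ctx: created earlier via InitContext()
  rknn_custom_string custom_string = GetCustomString(ctx, /*debug=*/false);

  // For a custom string like "T=39;encoder_dims=192,256,384" this yields
  //   meta["T"] == "39" and meta["encoder_dims"] == "192,256,384";
  // with debug == true the pairs are also logged, as added above.
  auto meta = Parse(custom_string, /*debug=*/true);

  if (meta.count("T")) {
    int32_t t = atoi(meta.at("T").c_str());  // same pattern the models use
    SHERPA_ONNX_LOGE("T: %d", t);
  }
}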
void InitContext(void *model_data, size_t model_data_length, bool debug,
rknn_context *ctx) {
auto ret = rknn_init(ctx, model_data, model_data_length, 0, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init rknn");
if (debug) {
rknn_sdk_version v;
ret = rknn_query(*ctx, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
v.drv_version);
}
}
void InitInputOutputAttrs(rknn_context ctx, bool debug,
std::vector<rknn_tensor_attr> *input_attrs,
std::vector<rknn_tensor_attr> *output_attrs) {
rknn_input_output_num io_num;
auto ret = rknn_query(ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
if (debug) {
SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
static_cast<int32_t>(io_num.n_input),
static_cast<int32_t>(io_num.n_output));
}
input_attrs->resize(io_num.n_input);
output_attrs->resize(io_num.n_output);
int32_t i = 0;
for (auto &attr : *input_attrs) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
i += 1;
}
if (debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : *input_attrs) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
os.str().c_str());
}
i = 0;
for (auto &attr : *output_attrs) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
i += 1;
}
if (debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : *output_attrs) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
os.str().c_str());
}
}
rknn_custom_string GetCustomString(rknn_context ctx, bool debug) {
rknn_custom_string custom_string;
auto ret = rknn_query(ctx, RKNN_QUERY_CUSTOM_STRING, &custom_string,
sizeof(custom_string));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
if (debug) {
SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
}
return custom_string;
}
void SetCoreMask(rknn_context ctx, int32_t num_threads) {
int32_t ret = RKNN_SUCC;
switch (num_threads) {
case 1:
ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_AUTO);
break;
case 0:
ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0);
break;
case -1:
ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_1);
break;
case -2:
ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_2);
break;
case -3:
ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0_1);
break;
case -4:
ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0_1_2);
break;
default:
SHERPA_ONNX_LOGE(
"Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
"1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
num_threads);
break;
}
if (ret != RKNN_SUCC) {
SHERPA_ONNX_LOGE(
"Failed to select npu core to run the model (You can ignore it if "
"you are not using RK3588.");
}
}
} // namespace sherpa_onnx
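
The num_threads-to-core-mask convention interpreted by SetCoreMask() is unchanged from the per-model code it replaces. A hedged usage sketch (OnlineModelConfig and its num_threads field appear in the model hunks above; the value chosen here is only an example):

// Illustration only: requesting specific NPU cores via num_threads.
//   1 -> RKNN_NPU_CORE_AUTO, 0 -> core 0, -1 -> core 1, -2 -> core 2,
//  -3 -> cores 0 and 1, -4 -> cores 0, 1 and 2 (multi-core masks need RK3588).
OnlineModelConfig config;
config.num_threads = -3;  // ask for NPU cores 0 and 1
// Each model Impl then forwards this to the shared helper:
//   SetCoreMask(ctx_, config_.num_threads);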
... ...
... ... @@ -7,17 +7,31 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "rknn_api.h" // NOLINT
namespace sherpa_onnx {
void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
int32_t height, int32_t width, float *dst);
std::string ToString(const rknn_tensor_attr &attr);
std::unordered_map<std::string, std::string> Parse(
const rknn_custom_string &custom_string);
const rknn_custom_string &custom_string, bool debug = false);
void InitContext(void *model_data, size_t model_data_length, bool debug,
rknn_context *ctx);
void InitInputOutputAttrs(rknn_context ctx, bool debug,
std::vector<rknn_tensor_attr> *input_attrs,
std::vector<rknn_tensor_attr> *output_attrs);
rknn_custom_string GetCustomString(rknn_context ctx, bool debug);
void SetCoreMask(rknn_context ctx, int32_t num_threads);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_UTILS_H_
... ...