Fangjun Kuang
Committed by GitHub

Minor fixes for rknn (#1925)

... ... @@ -99,7 +99,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
int32_t n_text_ctx = model_->TextCtx();
std::vector<int32_t> predicted_tokens;
for (int32_t i = 0; i < n_text_ctx; ++i) {
for (int32_t i = 0; i < n_text_ctx / 2; ++i) {
if (max_token_id == model_->EOT()) {
break;
}
... ...
... ... @@ -7,6 +7,7 @@
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
... ... @@ -65,6 +66,29 @@ bool OnlineModelConfig::Validate() const {
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
return false;
}
// Reject .rknn model files when the selected provider is not "rknn".
// NOTE(review): this block appears to sit inside a branch taken only when
// provider_config.provider != "rknn" — confirm against the enclosing code.
// Only transducer.encoder is checked for emptiness; decoder/joiner suffixes
// are still tested, so a non-empty encoder with an .rknn decoder also fails.
if (!transducer.encoder.empty() && (EndsWith(transducer.encoder, ".rknn") ||
EndsWith(transducer.decoder, ".rknn") ||
EndsWith(transducer.joiner, ".rknn"))) {
SHERPA_ONNX_LOGE(
"--provider is %s, which is not rknn, but you pass rknn model "
"filenames. encoder: '%s', decoder: '%s', joiner: '%s'",
provider_config.provider.c_str(), transducer.encoder.c_str(),
transducer.decoder.c_str(), transducer.joiner.c_str());
return false;
}
}
// Reject .onnx model files when the selected provider is "rknn":
// the rknn runtime cannot load onnx graphs, so fail validation early
// with an explicit message instead of a confusing load error later.
if (provider_config.provider == "rknn") {
if (!transducer.encoder.empty() && (EndsWith(transducer.encoder, ".onnx") ||
EndsWith(transducer.decoder, ".onnx") ||
EndsWith(transducer.joiner, ".onnx"))) {
SHERPA_ONNX_LOGE(
"--provider is rknn, but you pass onnx model "
// Fixed malformed conversion: was "joiner: %'s'" — the apostrophe
// inside the specifier is invalid for %s and garbles the log line.
"filenames. encoder: '%s', decoder: '%s', joiner: '%s'",
transducer.encoder.c_str(), transducer.decoder.c_str(),
transducer.joiner.c_str());
return false;
}
}
if (!tokens_buf.empty() && FileExists(tokens)) {
... ...
... ... @@ -463,8 +463,10 @@ class OnlineZipformerTransducerModelRknn::Impl {
}
auto meta = Parse(custom_string);
for (const auto &p : meta) {
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
if (config_.debug) {
for (const auto &p : meta) {
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
}
}
if (meta.count("encoder_dims")) {
... ...
... ... @@ -90,6 +90,8 @@ as the device_name.
exit(-1);
}
fprintf(stderr, "Started! Please speak\n");
int32_t chunk = 0.1 * alsa.GetActualSampleRate();
std::string last_text;
... ...
... ... @@ -158,8 +158,11 @@ for a list of pre-trained models to download.
const float rtf = s.elapsed_seconds / s.duration;
os << po.GetArg(i) << "\n";
os << std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds
<< ", Real time factor (RTF): " << rtf << "\n";
os << "Number of threads: " << config.model_config.num_threads << ", "
<< std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds
<< ", Audio duration (s): " << s.duration
<< ", Real time factor (RTF) = " << s.elapsed_seconds << "/"
<< s.duration << " = " << rtf << "\n";
const auto r = recognizer.GetResult(s.online_stream.get());
os << r.text << "\n";
os << r.AsJsonString() << "\n\n";
... ...
... ... @@ -699,4 +699,12 @@ std::string ToString(const std::wstring &s) {
return converter.to_bytes(s);
}
// Returns true iff `needle` is a suffix of `haystack`.
// An empty needle is a suffix of every string, including the empty one.
bool EndsWith(const std::string &haystack, const std::string &needle) {
  const std::size_t hay_len = haystack.size();
  const std::size_t nee_len = needle.size();
  if (nee_len > hay_len) {
    return false;
  }
  // Compare the trailing nee_len characters of haystack against needle.
  return haystack.compare(hay_len - nee_len, nee_len, needle) == 0;
}
} // namespace sherpa_onnx
... ...
... ... @@ -145,6 +145,8 @@ std::wstring ToWideString(const std::string &s);
std::string ToString(const std::wstring &s);
bool EndsWith(const std::string &haystack, const std::string &needle);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
... ...