Fangjun Kuang
Committed by GitHub

Minor fixes for rknn (#1925)

@@ -99,7 +99,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
   int32_t n_text_ctx = model_->TextCtx();

   std::vector<int32_t> predicted_tokens;
-  for (int32_t i = 0; i < n_text_ctx; ++i) {
+  for (int32_t i = 0; i < n_text_ctx / 2; ++i) {
     if (max_token_id == model_->EOT()) {
       break;
     }

@@ -7,6 +7,7 @@

 #include "sherpa-onnx/csrc/file-utils.h"
 #include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/text-utils.h"

 namespace sherpa_onnx {

@@ -65,6 +66,29 @@ bool OnlineModelConfig::Validate() const {
       SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
       return false;
     }
+    if (!transducer.encoder.empty() && (EndsWith(transducer.encoder, ".rknn") ||
+                                        EndsWith(transducer.decoder, ".rknn") ||
+                                        EndsWith(transducer.joiner, ".rknn"))) {
+      SHERPA_ONNX_LOGE(
+          "--provider is %s, which is not rknn, but you pass rknn model "
+          "filenames. encoder: '%s', decoder: '%s', joiner: '%s'",
+          provider_config.provider.c_str(), transducer.encoder.c_str(),
+          transducer.decoder.c_str(), transducer.joiner.c_str());
+      return false;
+    }
+  }
+
+  if (provider_config.provider == "rknn") {
+    if (!transducer.encoder.empty() && (EndsWith(transducer.encoder, ".onnx") ||
+                                        EndsWith(transducer.decoder, ".onnx") ||
+                                        EndsWith(transducer.joiner, ".onnx"))) {
+      SHERPA_ONNX_LOGE(
+          "--provider is rknn, but you pass onnx model "
+          "filenames. encoder: '%s', decoder: '%s', joiner: '%s'",
+          transducer.encoder.c_str(), transducer.decoder.c_str(),
+          transducer.joiner.c_str());
+      return false;
+    }
   }

   if (!tokens_buf.empty() && FileExists(tokens)) {

@@ -463,8 +463,10 @@ class OnlineZipformerTransducerModelRknn::Impl {
     }
     auto meta = Parse(custom_string);

-    for (const auto &p : meta) {
-      SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
+    if (config_.debug) {
+      for (const auto &p : meta) {
+        SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
+      }
     }

     if (meta.count("encoder_dims")) {

@@ -90,6 +90,8 @@ as the device_name.
     exit(-1);
   }

+  fprintf(stderr, "Started! Please speak\n");
+
   int32_t chunk = 0.1 * alsa.GetActualSampleRate();

   std::string last_text;

@@ -158,8 +158,11 @@ for a list of pre-trained models to download.
     const float rtf = s.elapsed_seconds / s.duration;

     os << po.GetArg(i) << "\n";
-    os << std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds
-       << ", Real time factor (RTF): " << rtf << "\n";
+    os << "Number of threads: " << config.model_config.num_threads << ", "
+       << std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds
+       << ", Audio duration (s): " << s.duration
+       << ", Real time factor (RTF) = " << s.elapsed_seconds << "/"
+       << s.duration << " = " << rtf << "\n";
     const auto r = recognizer.GetResult(s.online_stream.get());
     os << r.text << "\n";
     os << r.AsJsonString() << "\n\n";

@@ -699,4 +699,12 @@ std::string ToString(const std::wstring &s) {
   return converter.to_bytes(s);
 }

+bool EndsWith(const std::string &haystack, const std::string &needle) {
+  if (needle.size() > haystack.size()) {
+    return false;
+  }
+
+  return std::equal(needle.rbegin(), needle.rend(), haystack.rbegin());
+}
+
 }  // namespace sherpa_onnx

@@ -145,6 +145,8 @@ std::wstring ToWideString(const std::string &s);

 std::string ToString(const std::wstring &s);

+bool EndsWith(const std::string &haystack, const std::string &needle);
+
 }  // namespace sherpa_onnx

 #endif  // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
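
The EndsWith helper added above is self-contained, so its behaviour can be checked in isolation. Below is a minimal standalone sketch, not part of the commit: main(), the provider and encoder variables, and the error messages are hypothetical stand-ins for the real OnlineModelConfig::Validate() logic; only EndsWith mirrors the function added in text-utils.cc.

// Standalone sketch (not from the commit). It reproduces the EndsWith helper
// and a simplified, hypothetical version of the provider/extension mismatch
// check that the Validate() change above performs.
#include <algorithm>
#include <cstdio>
#include <string>

static bool EndsWith(const std::string &haystack, const std::string &needle) {
  if (needle.size() > haystack.size()) {
    return false;
  }
  return std::equal(needle.rbegin(), needle.rend(), haystack.rbegin());
}

int main() {
  // Hypothetical values: rknn model files require --provider=rknn, and
  // onnx model files must not be used with --provider=rknn.
  std::string provider = "cpu";
  std::string encoder = "encoder.rknn";

  if (provider != "rknn" && EndsWith(encoder, ".rknn")) {
    std::fprintf(stderr, "provider is %s, but %s is an rknn model\n",
                 provider.c_str(), encoder.c_str());
    return 1;
  }

  if (provider == "rknn" && EndsWith(encoder, ".onnx")) {
    std::fprintf(stderr, "provider is rknn, but %s is an onnx model\n",
                 encoder.c_str());
    return 1;
  }

  std::fprintf(stderr, "provider and model filename are consistent\n");
  return 0;
}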