Fangjun Kuang
Committed by GitHub

Fix reading hotwords file for android (#354)

@@ -12,6 +12,13 @@ @@ -12,6 +12,13 @@
12 #include <utility> 12 #include <utility>
13 #include <vector> 13 #include <vector>
14 14
  15 +#if __ANDROID_API__ >= 9
  16 +#include <strstream>
  17 +
  18 +#include "android/asset_manager.h"
  19 +#include "android/asset_manager_jni.h"
  20 +#endif
  21 +
15 #include "sherpa-onnx/csrc/file-utils.h" 22 #include "sherpa-onnx/csrc/file-utils.h"
16 #include "sherpa-onnx/csrc/macros.h" 23 #include "sherpa-onnx/csrc/macros.h"
17 #include "sherpa-onnx/csrc/online-lm.h" 24 #include "sherpa-onnx/csrc/online-lm.h"
@@ -62,14 +69,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -62,14 +69,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
62 model_(OnlineTransducerModel::Create(config.model_config)), 69 model_(OnlineTransducerModel::Create(config.model_config)),
63 sym_(config.model_config.tokens), 70 sym_(config.model_config.tokens),
64 endpoint_(config_.endpoint_config) { 71 endpoint_(config_.endpoint_config) {
65 - if (!config_.hotwords_file.empty()) {  
66 - InitHotwords();  
67 - }  
68 if (sym_.contains("<unk>")) { 72 if (sym_.contains("<unk>")) {
69 unk_id_ = sym_["<unk>"]; 73 unk_id_ = sym_["<unk>"];
70 } 74 }
71 75
72 if (config.decoding_method == "modified_beam_search") { 76 if (config.decoding_method == "modified_beam_search") {
  77 + if (!config_.hotwords_file.empty()) {
  78 + InitHotwords();
  79 + }
  80 +
73 if (!config_.lm_config.model.empty()) { 81 if (!config_.lm_config.model.empty()) {
74 lm_ = OnlineLM::Create(config.lm_config); 82 lm_ = OnlineLM::Create(config.lm_config);
75 } 83 }
@@ -99,6 +107,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -99,6 +107,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
99 } 107 }
100 108
101 if (config.decoding_method == "modified_beam_search") { 109 if (config.decoding_method == "modified_beam_search") {
  110 +#if 0
  111 + // TODO(fangjun): Implement it
  112 + if (!config_.lm_config.model.empty()) {
  113 + lm_ = OnlineLM::Create(mgr, config.lm_config);
  114 + }
  115 +#endif
  116 +
  117 + if (!config_.hotwords_file.empty()) {
  118 + InitHotwords(mgr);
  119 + }
  120 +
102 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 121 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
103 model_.get(), lm_.get(), config_.max_active_paths, 122 model_.get(), lm_.get(), config_.max_active_paths,
104 config_.lm_config.scale, unk_id_); 123 config_.lm_config.scale, unk_id_);
@@ -268,6 +287,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -268,6 +287,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
268 s->Reset(); 287 s->Reset();
269 } 288 }
270 289
  290 + private:
271 void InitHotwords() { 291 void InitHotwords() {
272 // each line in hotwords_file contains space-separated words 292 // each line in hotwords_file contains space-separated words
273 293
@@ -286,7 +306,29 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -286,7 +306,29 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
286 std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); 306 std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
287 } 307 }
288 308
289 - private: 309 +#if __ANDROID_API__ >= 9
  310 + void InitHotwords(AAssetManager *mgr) {
  311 + // each line in hotwords_file contains space-separated words
  312 +
  313 + auto buf = ReadFile(mgr, config_.hotwords_file);
  314 +
  315 + std::istrstream is(buf.data(), buf.size());
  316 +
  317 + if (!is) {
  318 + SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
  319 + config_.hotwords_file.c_str());
  320 + exit(-1);
  321 + }
  322 +
  323 + if (!EncodeHotwords(is, sym_, &hotwords_)) {
  324 + SHERPA_ONNX_LOGE("Encode hotwords failed.");
  325 + exit(-1);
  326 + }
  327 + hotwords_graph_ =
  328 + std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
  329 + }
  330 +#endif
  331 +
290 void InitOnlineStream(OnlineStream *stream) const { 332 void InitOnlineStream(OnlineStream *stream) const {
291 auto r = decoder_->GetEmptyResult(); 333 auto r = decoder_->GetEmptyResult();
292 334