Committed by
GitHub
Fix reading hotwords file for android (#354)
正在显示
1 个修改的文件
包含
46 行增加
和
4 行删除
| @@ -12,6 +12,13 @@ | @@ -12,6 +12,13 @@ | ||
| 12 | #include <utility> | 12 | #include <utility> |
| 13 | #include <vector> | 13 | #include <vector> |
| 14 | 14 | ||
| 15 | +#if __ANDROID_API__ >= 9 | ||
| 16 | +#include <strstream> | ||
| 17 | + | ||
| 18 | +#include "android/asset_manager.h" | ||
| 19 | +#include "android/asset_manager_jni.h" | ||
| 20 | +#endif | ||
| 21 | + | ||
| 15 | #include "sherpa-onnx/csrc/file-utils.h" | 22 | #include "sherpa-onnx/csrc/file-utils.h" |
| 16 | #include "sherpa-onnx/csrc/macros.h" | 23 | #include "sherpa-onnx/csrc/macros.h" |
| 17 | #include "sherpa-onnx/csrc/online-lm.h" | 24 | #include "sherpa-onnx/csrc/online-lm.h" |
| @@ -62,14 +69,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -62,14 +69,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 62 | model_(OnlineTransducerModel::Create(config.model_config)), | 69 | model_(OnlineTransducerModel::Create(config.model_config)), |
| 63 | sym_(config.model_config.tokens), | 70 | sym_(config.model_config.tokens), |
| 64 | endpoint_(config_.endpoint_config) { | 71 | endpoint_(config_.endpoint_config) { |
| 65 | - if (!config_.hotwords_file.empty()) { | ||
| 66 | - InitHotwords(); | ||
| 67 | - } | ||
| 68 | if (sym_.contains("<unk>")) { | 72 | if (sym_.contains("<unk>")) { |
| 69 | unk_id_ = sym_["<unk>"]; | 73 | unk_id_ = sym_["<unk>"]; |
| 70 | } | 74 | } |
| 71 | 75 | ||
| 72 | if (config.decoding_method == "modified_beam_search") { | 76 | if (config.decoding_method == "modified_beam_search") { |
| 77 | + if (!config_.hotwords_file.empty()) { | ||
| 78 | + InitHotwords(); | ||
| 79 | + } | ||
| 80 | + | ||
| 73 | if (!config_.lm_config.model.empty()) { | 81 | if (!config_.lm_config.model.empty()) { |
| 74 | lm_ = OnlineLM::Create(config.lm_config); | 82 | lm_ = OnlineLM::Create(config.lm_config); |
| 75 | } | 83 | } |
| @@ -99,6 +107,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -99,6 +107,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 99 | } | 107 | } |
| 100 | 108 | ||
| 101 | if (config.decoding_method == "modified_beam_search") { | 109 | if (config.decoding_method == "modified_beam_search") { |
| 110 | +#if 0 | ||
| 111 | + // TODO(fangjun): Implement it | ||
| 112 | + if (!config_.lm_config.model.empty()) { | ||
| 113 | + lm_ = OnlineLM::Create(mgr, config.lm_config); | ||
| 114 | + } | ||
| 115 | +#endif | ||
| 116 | + | ||
| 117 | + if (!config_.hotwords_file.empty()) { | ||
| 118 | + InitHotwords(mgr); | ||
| 119 | + } | ||
| 120 | + | ||
| 102 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 121 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 103 | model_.get(), lm_.get(), config_.max_active_paths, | 122 | model_.get(), lm_.get(), config_.max_active_paths, |
| 104 | config_.lm_config.scale, unk_id_); | 123 | config_.lm_config.scale, unk_id_); |
| @@ -268,6 +287,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -268,6 +287,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 268 | s->Reset(); | 287 | s->Reset(); |
| 269 | } | 288 | } |
| 270 | 289 | ||
| 290 | + private: | ||
| 271 | void InitHotwords() { | 291 | void InitHotwords() { |
| 272 | // each line in hotwords_file contains space-separated words | 292 | // each line in hotwords_file contains space-separated words |
| 273 | 293 | ||
| @@ -286,7 +306,29 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -286,7 +306,29 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 286 | std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); | 306 | std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); |
| 287 | } | 307 | } |
| 288 | 308 | ||
| 289 | - private: | 309 | +#if __ANDROID_API__ >= 9 |
| 310 | + void InitHotwords(AAssetManager *mgr) { | ||
| 311 | + // each line in hotwords_file contains space-separated words | ||
| 312 | + | ||
| 313 | + auto buf = ReadFile(mgr, config_.hotwords_file); | ||
| 314 | + | ||
| 315 | + std::istrstream is(buf.data(), buf.size()); | ||
| 316 | + | ||
| 317 | + if (!is) { | ||
| 318 | + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", | ||
| 319 | + config_.hotwords_file.c_str()); | ||
| 320 | + exit(-1); | ||
| 321 | + } | ||
| 322 | + | ||
| 323 | + if (!EncodeHotwords(is, sym_, &hotwords_)) { | ||
| 324 | + SHERPA_ONNX_LOGE("Encode hotwords failed."); | ||
| 325 | + exit(-1); | ||
| 326 | + } | ||
| 327 | + hotwords_graph_ = | ||
| 328 | + std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); | ||
| 329 | + } | ||
| 330 | +#endif | ||
| 331 | + | ||
| 290 | void InitOnlineStream(OnlineStream *stream) const { | 332 | void InitOnlineStream(OnlineStream *stream) const { |
| 291 | auto r = decoder_->GetEmptyResult(); | 333 | auto r = decoder_->GetEmptyResult(); |
| 292 | 334 |
-
请 注册 或 登录 后发表评论