Committed by
GitHub
Support customize scores for hotwords (#926)
* Support customize scores for hotwords * Skip blank lines
正在显示
6 个修改的文件
包含
103 行增加
和
35 行删除
| @@ -61,10 +61,9 @@ class ContextGraph { | @@ -61,10 +61,9 @@ class ContextGraph { | ||
| 61 | } | 61 | } |
| 62 | 62 | ||
| 63 | ContextGraph(const std::vector<std::vector<int32_t>> &token_ids, | 63 | ContextGraph(const std::vector<std::vector<int32_t>> &token_ids, |
| 64 | - float context_score, const std::vector<float> &scores = {}, | ||
| 65 | - const std::vector<std::string> &phrases = {}) | ||
| 66 | - : ContextGraph(token_ids, context_score, 0.0f, scores, phrases, | ||
| 67 | - std::vector<float>()) {} | 64 | + float context_score, const std::vector<float> &scores = {}) |
| 65 | + : ContextGraph(token_ids, context_score, 0.0f, scores, | ||
| 66 | + std::vector<std::string>(), std::vector<float>()) {} | ||
| 68 | 67 | ||
| 69 | std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep( | 68 | std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep( |
| 70 | const ContextState *state, int32_t token_id, | 69 | const ContextState *state, int32_t token_id, |
| @@ -145,15 +145,35 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -145,15 +145,35 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 145 | auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); | 145 | auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); |
| 146 | std::istringstream is(hws); | 146 | std::istringstream is(hws); |
| 147 | std::vector<std::vector<int32_t>> current; | 147 | std::vector<std::vector<int32_t>> current; |
| 148 | + std::vector<float> current_scores; | ||
| 148 | if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, | 149 | if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, |
| 149 | - bpe_encoder_.get(), &current)) { | 150 | + bpe_encoder_.get(), &current, &current_scores)) { |
| 150 | SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", | 151 | SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", |
| 151 | hotwords.c_str()); | 152 | hotwords.c_str()); |
| 152 | } | 153 | } |
| 154 | + | ||
| 155 | + int32_t num_default_hws = hotwords_.size(); | ||
| 156 | + int32_t num_hws = current.size(); | ||
| 157 | + | ||
| 153 | current.insert(current.end(), hotwords_.begin(), hotwords_.end()); | 158 | current.insert(current.end(), hotwords_.begin(), hotwords_.end()); |
| 154 | 159 | ||
| 155 | - auto context_graph = | ||
| 156 | - std::make_shared<ContextGraph>(current, config_.hotwords_score); | 160 | + if (!current_scores.empty() && !boost_scores_.empty()) { |
| 161 | + current_scores.insert(current_scores.end(), boost_scores_.begin(), | ||
| 162 | + boost_scores_.end()); | ||
| 163 | + } else if (!current_scores.empty() && boost_scores_.empty()) { | ||
| 164 | + current_scores.insert(current_scores.end(), num_default_hws, | ||
| 165 | + config_.hotwords_score); | ||
| 166 | + } else if (current_scores.empty() && !boost_scores_.empty()) { | ||
| 167 | + current_scores.insert(current_scores.end(), num_hws, | ||
| 168 | + config_.hotwords_score); | ||
| 169 | + current_scores.insert(current_scores.end(), boost_scores_.begin(), | ||
| 170 | + boost_scores_.end()); | ||
| 171 | + } else { | ||
| 172 | + // Do nothing. | ||
| 173 | + } | ||
| 174 | + | ||
| 175 | + auto context_graph = std::make_shared<ContextGraph>( | ||
| 176 | + current, config_.hotwords_score, current_scores); | ||
| 157 | return std::make_unique<OfflineStream>(config_.feat_config, context_graph); | 177 | return std::make_unique<OfflineStream>(config_.feat_config, context_graph); |
| 158 | } | 178 | } |
| 159 | 179 | ||
| @@ -226,13 +246,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -226,13 +246,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 226 | } | 246 | } |
| 227 | 247 | ||
| 228 | if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, | 248 | if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, |
| 229 | - bpe_encoder_.get(), &hotwords_)) { | 249 | + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { |
| 230 | SHERPA_ONNX_LOGE( | 250 | SHERPA_ONNX_LOGE( |
| 231 | "Failed to encode some hotwords, skip them already, see logs above " | 251 | "Failed to encode some hotwords, skip them already, see logs above " |
| 232 | "for details."); | 252 | "for details."); |
| 233 | } | 253 | } |
| 234 | - hotwords_graph_ = | ||
| 235 | - std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); | 254 | + hotwords_graph_ = std::make_shared<ContextGraph>( |
| 255 | + hotwords_, config_.hotwords_score, boost_scores_); | ||
| 236 | } | 256 | } |
| 237 | 257 | ||
| 238 | #if __ANDROID_API__ >= 9 | 258 | #if __ANDROID_API__ >= 9 |
| @@ -250,13 +270,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -250,13 +270,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 250 | } | 270 | } |
| 251 | 271 | ||
| 252 | if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, | 272 | if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, |
| 253 | - bpe_encoder_.get(), &hotwords_)) { | 273 | + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { |
| 254 | SHERPA_ONNX_LOGE( | 274 | SHERPA_ONNX_LOGE( |
| 255 | "Failed to encode some hotwords, skip them already, see logs above " | 275 | "Failed to encode some hotwords, skip them already, see logs above " |
| 256 | "for details."); | 276 | "for details."); |
| 257 | } | 277 | } |
| 258 | - hotwords_graph_ = | ||
| 259 | - std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); | 278 | + hotwords_graph_ = std::make_shared<ContextGraph>( |
| 279 | + hotwords_, config_.hotwords_score, boost_scores_); | ||
| 260 | } | 280 | } |
| 261 | #endif | 281 | #endif |
| 262 | 282 | ||
| @@ -264,6 +284,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -264,6 +284,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 264 | OfflineRecognizerConfig config_; | 284 | OfflineRecognizerConfig config_; |
| 265 | SymbolTable symbol_table_; | 285 | SymbolTable symbol_table_; |
| 266 | std::vector<std::vector<int32_t>> hotwords_; | 286 | std::vector<std::vector<int32_t>> hotwords_; |
| 287 | + std::vector<float> boost_scores_; | ||
| 267 | ContextGraphPtr hotwords_graph_; | 288 | ContextGraphPtr hotwords_graph_; |
| 268 | std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_; | 289 | std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_; |
| 269 | std::unique_ptr<OfflineTransducerModel> model_; | 290 | std::unique_ptr<OfflineTransducerModel> model_; |
| @@ -182,14 +182,35 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -182,14 +182,35 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 182 | auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); | 182 | auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); |
| 183 | std::istringstream is(hws); | 183 | std::istringstream is(hws); |
| 184 | std::vector<std::vector<int32_t>> current; | 184 | std::vector<std::vector<int32_t>> current; |
| 185 | + std::vector<float> current_scores; | ||
| 185 | if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, | 186 | if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, |
| 186 | - bpe_encoder_.get(), &current)) { | 187 | + bpe_encoder_.get(), &current, &current_scores)) { |
| 187 | SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", | 188 | SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", |
| 188 | hotwords.c_str()); | 189 | hotwords.c_str()); |
| 189 | } | 190 | } |
| 191 | + | ||
| 192 | + int32_t num_default_hws = hotwords_.size(); | ||
| 193 | + int32_t num_hws = current.size(); | ||
| 194 | + | ||
| 190 | current.insert(current.end(), hotwords_.begin(), hotwords_.end()); | 195 | current.insert(current.end(), hotwords_.begin(), hotwords_.end()); |
| 191 | - auto context_graph = | ||
| 192 | - std::make_shared<ContextGraph>(current, config_.hotwords_score); | 196 | + |
| 197 | + if (!current_scores.empty() && !boost_scores_.empty()) { | ||
| 198 | + current_scores.insert(current_scores.end(), boost_scores_.begin(), | ||
| 199 | + boost_scores_.end()); | ||
| 200 | + } else if (!current_scores.empty() && boost_scores_.empty()) { | ||
| 201 | + current_scores.insert(current_scores.end(), num_default_hws, | ||
| 202 | + config_.hotwords_score); | ||
| 203 | + } else if (current_scores.empty() && !boost_scores_.empty()) { | ||
| 204 | + current_scores.insert(current_scores.end(), num_hws, | ||
| 205 | + config_.hotwords_score); | ||
| 206 | + current_scores.insert(current_scores.end(), boost_scores_.begin(), | ||
| 207 | + boost_scores_.end()); | ||
| 208 | + } else { | ||
| 209 | + // Do nothing. | ||
| 210 | + } | ||
| 211 | + | ||
| 212 | + auto context_graph = std::make_shared<ContextGraph>( | ||
| 213 | + current, config_.hotwords_score, current_scores); | ||
| 193 | auto stream = | 214 | auto stream = |
| 194 | std::make_unique<OnlineStream>(config_.feat_config, context_graph); | 215 | std::make_unique<OnlineStream>(config_.feat_config, context_graph); |
| 195 | InitOnlineStream(stream.get()); | 216 | InitOnlineStream(stream.get()); |
| @@ -376,13 +397,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -376,13 +397,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 376 | } | 397 | } |
| 377 | 398 | ||
| 378 | if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, | 399 | if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, |
| 379 | - bpe_encoder_.get(), &hotwords_)) { | 400 | + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { |
| 380 | SHERPA_ONNX_LOGE( | 401 | SHERPA_ONNX_LOGE( |
| 381 | "Failed to encode some hotwords, skip them already, see logs above " | 402 | "Failed to encode some hotwords, skip them already, see logs above " |
| 382 | "for details."); | 403 | "for details."); |
| 383 | } | 404 | } |
| 384 | - hotwords_graph_ = | ||
| 385 | - std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); | 405 | + hotwords_graph_ = std::make_shared<ContextGraph>( |
| 406 | + hotwords_, config_.hotwords_score, boost_scores_); | ||
| 386 | } | 407 | } |
| 387 | 408 | ||
| 388 | #if __ANDROID_API__ >= 9 | 409 | #if __ANDROID_API__ >= 9 |
| @@ -400,13 +421,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -400,13 +421,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 400 | } | 421 | } |
| 401 | 422 | ||
| 402 | if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, | 423 | if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, |
| 403 | - bpe_encoder_.get(), &hotwords_)) { | 424 | + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { |
| 404 | SHERPA_ONNX_LOGE( | 425 | SHERPA_ONNX_LOGE( |
| 405 | "Failed to encode some hotwords, skip them already, see logs above " | 426 | "Failed to encode some hotwords, skip them already, see logs above " |
| 406 | "for details."); | 427 | "for details."); |
| 407 | } | 428 | } |
| 408 | - hotwords_graph_ = | ||
| 409 | - std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); | 429 | + hotwords_graph_ = std::make_shared<ContextGraph>( |
| 430 | + hotwords_, config_.hotwords_score, boost_scores_); | ||
| 410 | } | 431 | } |
| 411 | #endif | 432 | #endif |
| 412 | 433 | ||
| @@ -428,6 +449,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -428,6 +449,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 428 | private: | 449 | private: |
| 429 | OnlineRecognizerConfig config_; | 450 | OnlineRecognizerConfig config_; |
| 430 | std::vector<std::vector<int32_t>> hotwords_; | 451 | std::vector<std::vector<int32_t>> hotwords_; |
| 452 | + std::vector<float> boost_scores_; | ||
| 431 | ContextGraphPtr hotwords_graph_; | 453 | ContextGraphPtr hotwords_graph_; |
| 432 | std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_; | 454 | std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_; |
| 433 | std::unique_ptr<OnlineTransducerModel> model_; | 455 | std::unique_ptr<OnlineTransducerModel> model_; |
| @@ -35,17 +35,21 @@ TEST(TEXT2TOKEN, TEST_cjkchar) { | @@ -35,17 +35,21 @@ TEST(TEXT2TOKEN, TEST_cjkchar) { | ||
| 35 | 35 | ||
| 36 | auto sym_table = SymbolTable(tokens); | 36 | auto sym_table = SymbolTable(tokens); |
| 37 | 37 | ||
| 38 | - std::string text = "世界人民大团结\n中国 V S 美国"; | 38 | + std::string text = |
| 39 | + "世界人民大团结\n中国 V S 美国\n\n"; // Test blank lines also | ||
| 39 | 40 | ||
| 40 | std::istringstream iss(text); | 41 | std::istringstream iss(text); |
| 41 | 42 | ||
| 42 | std::vector<std::vector<int32_t>> ids; | 43 | std::vector<std::vector<int32_t>> ids; |
| 44 | + std::vector<float> scores; | ||
| 43 | 45 | ||
| 44 | - auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids); | 46 | + auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids, &scores); |
| 45 | 47 | ||
| 46 | std::vector<std::vector<int32_t>> expected_ids( | 48 | std::vector<std::vector<int32_t>> expected_ids( |
| 47 | {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}}); | 49 | {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}}); |
| 48 | EXPECT_EQ(ids, expected_ids); | 50 | EXPECT_EQ(ids, expected_ids); |
| 51 | + | ||
| 52 | + EXPECT_EQ(scores.size(), 0); | ||
| 49 | } | 53 | } |
| 50 | 54 | ||
| 51 | TEST(TEXT2TOKEN, TEST_bpe) { | 55 | TEST(TEXT2TOKEN, TEST_bpe) { |
| @@ -68,17 +72,22 @@ TEST(TEXT2TOKEN, TEST_bpe) { | @@ -68,17 +72,22 @@ TEST(TEXT2TOKEN, TEST_bpe) { | ||
| 68 | auto sym_table = SymbolTable(tokens); | 72 | auto sym_table = SymbolTable(tokens); |
| 69 | auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe); | 73 | auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe); |
| 70 | 74 | ||
| 71 | - std::string text = "HELLO WORLD\nI LOVE YOU"; | 75 | + std::string text = "HELLO WORLD\nI LOVE YOU :2.0"; |
| 72 | 76 | ||
| 73 | std::istringstream iss(text); | 77 | std::istringstream iss(text); |
| 74 | 78 | ||
| 75 | std::vector<std::vector<int32_t>> ids; | 79 | std::vector<std::vector<int32_t>> ids; |
| 80 | + std::vector<float> scores; | ||
| 76 | 81 | ||
| 77 | - auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids); | 82 | + auto r = |
| 83 | + EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores); | ||
| 78 | 84 | ||
| 79 | std::vector<std::vector<int32_t>> expected_ids( | 85 | std::vector<std::vector<int32_t>> expected_ids( |
| 80 | {{22, 58, 24, 425}, {19, 370, 47}}); | 86 | {{22, 58, 24, 425}, {19, 370, 47}}); |
| 81 | EXPECT_EQ(ids, expected_ids); | 87 | EXPECT_EQ(ids, expected_ids); |
| 88 | + | ||
| 89 | + std::vector<float> expected_scores({0, 2.0}); | ||
| 90 | + EXPECT_EQ(scores, expected_scores); | ||
| 82 | } | 91 | } |
| 83 | 92 | ||
| 84 | TEST(TEXT2TOKEN, TEST_cjkchar_bpe) { | 93 | TEST(TEXT2TOKEN, TEST_cjkchar_bpe) { |
| @@ -101,19 +110,23 @@ TEST(TEXT2TOKEN, TEST_cjkchar_bpe) { | @@ -101,19 +110,23 @@ TEST(TEXT2TOKEN, TEST_cjkchar_bpe) { | ||
| 101 | auto sym_table = SymbolTable(tokens); | 110 | auto sym_table = SymbolTable(tokens); |
| 102 | auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe); | 111 | auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe); |
| 103 | 112 | ||
| 104 | - std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国"; | 113 | + std::string text = "世界人民 GOES TOGETHER :1.5\n中国 GOES WITH 美国 :0.5"; |
| 105 | 114 | ||
| 106 | std::istringstream iss(text); | 115 | std::istringstream iss(text); |
| 107 | 116 | ||
| 108 | std::vector<std::vector<int32_t>> ids; | 117 | std::vector<std::vector<int32_t>> ids; |
| 118 | + std::vector<float> scores; | ||
| 109 | 119 | ||
| 110 | - auto r = | ||
| 111 | - EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), &ids); | 120 | + auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), |
| 121 | + &ids, &scores); | ||
| 112 | 122 | ||
| 113 | std::vector<std::vector<int32_t>> expected_ids( | 123 | std::vector<std::vector<int32_t>> expected_ids( |
| 114 | {{1368, 1392, 557, 680, 275, 178, 475}, | 124 | {{1368, 1392, 557, 680, 275, 178, 475}, |
| 115 | {685, 736, 275, 178, 179, 921, 736}}); | 125 | {685, 736, 275, 178, 179, 921, 736}}); |
| 116 | EXPECT_EQ(ids, expected_ids); | 126 | EXPECT_EQ(ids, expected_ids); |
| 127 | + | ||
| 128 | + std::vector<float> expected_scores({1.5, 0.5}); | ||
| 129 | + EXPECT_EQ(scores, expected_scores); | ||
| 117 | } | 130 | } |
| 118 | 131 | ||
| 119 | TEST(TEXT2TOKEN, TEST_bbpe) { | 132 | TEST(TEXT2TOKEN, TEST_bbpe) { |
| @@ -136,17 +149,22 @@ TEST(TEXT2TOKEN, TEST_bbpe) { | @@ -136,17 +149,22 @@ TEST(TEXT2TOKEN, TEST_bbpe) { | ||
| 136 | auto sym_table = SymbolTable(tokens); | 149 | auto sym_table = SymbolTable(tokens); |
| 137 | auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe); | 150 | auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe); |
| 138 | 151 | ||
| 139 | - std::string text = "频繁\n李鞑靼"; | 152 | + std::string text = "频繁 :1.0\n李鞑靼"; |
| 140 | 153 | ||
| 141 | std::istringstream iss(text); | 154 | std::istringstream iss(text); |
| 142 | 155 | ||
| 143 | std::vector<std::vector<int32_t>> ids; | 156 | std::vector<std::vector<int32_t>> ids; |
| 157 | + std::vector<float> scores; | ||
| 144 | 158 | ||
| 145 | - auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids); | 159 | + auto r = |
| 160 | + EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores); | ||
| 146 | 161 | ||
| 147 | std::vector<std::vector<int32_t>> expected_ids( | 162 | std::vector<std::vector<int32_t>> expected_ids( |
| 148 | {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}}); | 163 | {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}}); |
| 149 | EXPECT_EQ(ids, expected_ids); | 164 | EXPECT_EQ(ids, expected_ids); |
| 165 | + | ||
| 166 | + std::vector<float> expected_scores({1.0, 0}); | ||
| 167 | + EXPECT_EQ(scores, expected_scores); | ||
| 150 | } | 168 | } |
| 151 | 169 | ||
| 152 | } // namespace sherpa_onnx | 170 | } // namespace sherpa_onnx |
| @@ -103,7 +103,8 @@ static bool EncodeBase(const std::vector<std::string> &lines, | @@ -103,7 +103,8 @@ static bool EncodeBase(const std::vector<std::string> &lines, | ||
| 103 | bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, | 103 | bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, |
| 104 | const SymbolTable &symbol_table, | 104 | const SymbolTable &symbol_table, |
| 105 | const ssentencepiece::Ssentencepiece *bpe_encoder, | 105 | const ssentencepiece::Ssentencepiece *bpe_encoder, |
| 106 | - std::vector<std::vector<int32_t>> *hotwords) { | 106 | + std::vector<std::vector<int32_t>> *hotwords, |
| 107 | + std::vector<float> *boost_scores) { | ||
| 107 | std::vector<std::string> lines; | 108 | std::vector<std::string> lines; |
| 108 | std::string line; | 109 | std::string line; |
| 109 | std::string word; | 110 | std::string word; |
| @@ -131,7 +132,12 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, | @@ -131,7 +132,12 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, | ||
| 131 | break; | 132 | break; |
| 132 | } | 133 | } |
| 133 | } | 134 | } |
| 134 | - phrase = oss.str().substr(1); | 135 | + phrase = oss.str(); |
| 136 | + if (phrase.empty()) { | ||
| 137 | + continue; | ||
| 138 | + } else { | ||
| 139 | + phrase = phrase.substr(1); | ||
| 140 | + } | ||
| 135 | std::istringstream piss(phrase); | 141 | std::istringstream piss(phrase); |
| 136 | oss.clear(); | 142 | oss.clear(); |
| 137 | oss.str(""); | 143 | oss.str(""); |
| @@ -177,7 +183,8 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, | @@ -177,7 +183,8 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, | ||
| 177 | } | 183 | } |
| 178 | lines.push_back(oss.str()); | 184 | lines.push_back(oss.str()); |
| 179 | } | 185 | } |
| 180 | - return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr); | 186 | + return EncodeBase(lines, symbol_table, hotwords, nullptr, boost_scores, |
| 187 | + nullptr); | ||
| 181 | } | 188 | } |
| 182 | 189 | ||
| 183 | bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, | 190 | bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, |
| @@ -29,7 +29,8 @@ namespace sherpa_onnx { | @@ -29,7 +29,8 @@ namespace sherpa_onnx { | ||
| 29 | bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, | 29 | bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, |
| 30 | const SymbolTable &symbol_table, | 30 | const SymbolTable &symbol_table, |
| 31 | const ssentencepiece::Ssentencepiece *bpe_encoder_, | 31 | const ssentencepiece::Ssentencepiece *bpe_encoder_, |
| 32 | - std::vector<std::vector<int32_t>> *hotwords_id); | 32 | + std::vector<std::vector<int32_t>> *hotwords_id, |
| 33 | + std::vector<float> *boost_scores); | ||
| 33 | 34 | ||
| 34 | /* Encode the keywords in an input stream to be tokens ids. | 35 | /* Encode the keywords in an input stream to be tokens ids. |
| 35 | * | 36 | * |
-
请 注册 或 登录 后发表评论