Wei Kang
Committed by GitHub

Support customize scores for hotwords (#926)

* Support customize scores for hotwords

* Skip blank lines
@@ -61,10 +61,9 @@ class ContextGraph {
61 } 61 }
62 62
63 ContextGraph(const std::vector<std::vector<int32_t>> &token_ids, 63 ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
64 - float context_score, const std::vector<float> &scores = {},  
65 - const std::vector<std::string> &phrases = {})  
66 - : ContextGraph(token_ids, context_score, 0.0f, scores, phrases,  
67 - std::vector<float>()) {} 64 + float context_score, const std::vector<float> &scores = {})
  65 + : ContextGraph(token_ids, context_score, 0.0f, scores,
  66 + std::vector<std::string>(), std::vector<float>()) {}
68 67
69 std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep( 68 std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
70 const ContextState *state, int32_t token_id, 69 const ContextState *state, int32_t token_id,
@@ -145,15 +145,35 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
145 auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); 145 auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
146 std::istringstream is(hws); 146 std::istringstream is(hws);
147 std::vector<std::vector<int32_t>> current; 147 std::vector<std::vector<int32_t>> current;
  148 + std::vector<float> current_scores;
148 if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, 149 if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
149 - bpe_encoder_.get(), &current)) { 150 + bpe_encoder_.get(), &current, &current_scores)) {
150 SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", 151 SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
151 hotwords.c_str()); 152 hotwords.c_str());
152 } 153 }
  154 +
  155 + int32_t num_default_hws = hotwords_.size();
  156 + int32_t num_hws = current.size();
  157 +
153 current.insert(current.end(), hotwords_.begin(), hotwords_.end()); 158 current.insert(current.end(), hotwords_.begin(), hotwords_.end());
154 159
155 - auto context_graph =  
156 - std::make_shared<ContextGraph>(current, config_.hotwords_score); 160 + if (!current_scores.empty() && !boost_scores_.empty()) {
  161 + current_scores.insert(current_scores.end(), boost_scores_.begin(),
  162 + boost_scores_.end());
  163 + } else if (!current_scores.empty() && boost_scores_.empty()) {
  164 + current_scores.insert(current_scores.end(), num_default_hws,
  165 + config_.hotwords_score);
  166 + } else if (current_scores.empty() && !boost_scores_.empty()) {
  167 + current_scores.insert(current_scores.end(), num_hws,
  168 + config_.hotwords_score);
  169 + current_scores.insert(current_scores.end(), boost_scores_.begin(),
  170 + boost_scores_.end());
  171 + } else {
  172 + // Do nothing.
  173 + }
  174 +
  175 + auto context_graph = std::make_shared<ContextGraph>(
  176 + current, config_.hotwords_score, current_scores);
157 return std::make_unique<OfflineStream>(config_.feat_config, context_graph); 177 return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
158 } 178 }
159 179
@@ -226,13 +246,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
226 } 246 }
227 247
228 if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, 248 if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
229 - bpe_encoder_.get(), &hotwords_)) { 249 + bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
230 SHERPA_ONNX_LOGE( 250 SHERPA_ONNX_LOGE(
231 "Failed to encode some hotwords, skip them already, see logs above " 251 "Failed to encode some hotwords, skip them already, see logs above "
232 "for details."); 252 "for details.");
233 } 253 }
234 - hotwords_graph_ =  
235 - std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); 254 + hotwords_graph_ = std::make_shared<ContextGraph>(
  255 + hotwords_, config_.hotwords_score, boost_scores_);
236 } 256 }
237 257
238 #if __ANDROID_API__ >= 9 258 #if __ANDROID_API__ >= 9
@@ -250,13 +270,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
250 } 270 }
251 271
252 if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, 272 if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
253 - bpe_encoder_.get(), &hotwords_)) { 273 + bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
254 SHERPA_ONNX_LOGE( 274 SHERPA_ONNX_LOGE(
255 "Failed to encode some hotwords, skip them already, see logs above " 275 "Failed to encode some hotwords, skip them already, see logs above "
256 "for details."); 276 "for details.");
257 } 277 }
258 - hotwords_graph_ =  
259 - std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); 278 + hotwords_graph_ = std::make_shared<ContextGraph>(
  279 + hotwords_, config_.hotwords_score, boost_scores_);
260 } 280 }
261 #endif 281 #endif
262 282
@@ -264,6 +284,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
264 OfflineRecognizerConfig config_; 284 OfflineRecognizerConfig config_;
265 SymbolTable symbol_table_; 285 SymbolTable symbol_table_;
266 std::vector<std::vector<int32_t>> hotwords_; 286 std::vector<std::vector<int32_t>> hotwords_;
  287 + std::vector<float> boost_scores_;
267 ContextGraphPtr hotwords_graph_; 288 ContextGraphPtr hotwords_graph_;
268 std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_; 289 std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
269 std::unique_ptr<OfflineTransducerModel> model_; 290 std::unique_ptr<OfflineTransducerModel> model_;
@@ -182,14 +182,35 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
182 auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); 182 auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
183 std::istringstream is(hws); 183 std::istringstream is(hws);
184 std::vector<std::vector<int32_t>> current; 184 std::vector<std::vector<int32_t>> current;
  185 + std::vector<float> current_scores;
185 if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, 186 if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
186 - bpe_encoder_.get(), &current)) { 187 + bpe_encoder_.get(), &current, &current_scores)) {
187 SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", 188 SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
188 hotwords.c_str()); 189 hotwords.c_str());
189 } 190 }
  191 +
  192 + int32_t num_default_hws = hotwords_.size();
  193 + int32_t num_hws = current.size();
  194 +
190 current.insert(current.end(), hotwords_.begin(), hotwords_.end()); 195 current.insert(current.end(), hotwords_.begin(), hotwords_.end());
191 - auto context_graph =  
192 - std::make_shared<ContextGraph>(current, config_.hotwords_score); 196 +
  197 + if (!current_scores.empty() && !boost_scores_.empty()) {
  198 + current_scores.insert(current_scores.end(), boost_scores_.begin(),
  199 + boost_scores_.end());
  200 + } else if (!current_scores.empty() && boost_scores_.empty()) {
  201 + current_scores.insert(current_scores.end(), num_default_hws,
  202 + config_.hotwords_score);
  203 + } else if (current_scores.empty() && !boost_scores_.empty()) {
  204 + current_scores.insert(current_scores.end(), num_hws,
  205 + config_.hotwords_score);
  206 + current_scores.insert(current_scores.end(), boost_scores_.begin(),
  207 + boost_scores_.end());
  208 + } else {
  209 + // Do nothing.
  210 + }
  211 +
  212 + auto context_graph = std::make_shared<ContextGraph>(
  213 + current, config_.hotwords_score, current_scores);
193 auto stream = 214 auto stream =
194 std::make_unique<OnlineStream>(config_.feat_config, context_graph); 215 std::make_unique<OnlineStream>(config_.feat_config, context_graph);
195 InitOnlineStream(stream.get()); 216 InitOnlineStream(stream.get());
@@ -376,13 +397,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
376 } 397 }
377 398
378 if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, 399 if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
379 - bpe_encoder_.get(), &hotwords_)) { 400 + bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
380 SHERPA_ONNX_LOGE( 401 SHERPA_ONNX_LOGE(
381 "Failed to encode some hotwords, skip them already, see logs above " 402 "Failed to encode some hotwords, skip them already, see logs above "
382 "for details."); 403 "for details.");
383 } 404 }
384 - hotwords_graph_ =  
385 - std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); 405 + hotwords_graph_ = std::make_shared<ContextGraph>(
  406 + hotwords_, config_.hotwords_score, boost_scores_);
386 } 407 }
387 408
388 #if __ANDROID_API__ >= 9 409 #if __ANDROID_API__ >= 9
@@ -400,13 +421,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
400 } 421 }
401 422
402 if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, 423 if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
403 - bpe_encoder_.get(), &hotwords_)) { 424 + bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
404 SHERPA_ONNX_LOGE( 425 SHERPA_ONNX_LOGE(
405 "Failed to encode some hotwords, skip them already, see logs above " 426 "Failed to encode some hotwords, skip them already, see logs above "
406 "for details."); 427 "for details.");
407 } 428 }
408 - hotwords_graph_ =  
409 - std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); 429 + hotwords_graph_ = std::make_shared<ContextGraph>(
  430 + hotwords_, config_.hotwords_score, boost_scores_);
410 } 431 }
411 #endif 432 #endif
412 433
@@ -428,6 +449,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
428 private: 449 private:
429 OnlineRecognizerConfig config_; 450 OnlineRecognizerConfig config_;
430 std::vector<std::vector<int32_t>> hotwords_; 451 std::vector<std::vector<int32_t>> hotwords_;
  452 + std::vector<float> boost_scores_;
431 ContextGraphPtr hotwords_graph_; 453 ContextGraphPtr hotwords_graph_;
432 std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_; 454 std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
433 std::unique_ptr<OnlineTransducerModel> model_; 455 std::unique_ptr<OnlineTransducerModel> model_;
@@ -35,17 +35,21 @@ TEST(TEXT2TOKEN, TEST_cjkchar) {
35 35
36 auto sym_table = SymbolTable(tokens); 36 auto sym_table = SymbolTable(tokens);
37 37
38 - std::string text = "世界人民大团结\n中国 V S 美国"; 38 + std::string text =
  39 + "世界人民大团结\n中国 V S 美国\n\n"; // Test blank lines also
39 40
40 std::istringstream iss(text); 41 std::istringstream iss(text);
41 42
42 std::vector<std::vector<int32_t>> ids; 43 std::vector<std::vector<int32_t>> ids;
  44 + std::vector<float> scores;
43 45
44 - auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids); 46 + auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids, &scores);
45 47
46 std::vector<std::vector<int32_t>> expected_ids( 48 std::vector<std::vector<int32_t>> expected_ids(
47 {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}}); 49 {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}});
48 EXPECT_EQ(ids, expected_ids); 50 EXPECT_EQ(ids, expected_ids);
  51 +
  52 + EXPECT_EQ(scores.size(), 0);
49 } 53 }
50 54
51 TEST(TEXT2TOKEN, TEST_bpe) { 55 TEST(TEXT2TOKEN, TEST_bpe) {
@@ -68,17 +72,22 @@ TEST(TEXT2TOKEN, TEST_bpe) {
68 auto sym_table = SymbolTable(tokens); 72 auto sym_table = SymbolTable(tokens);
69 auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe); 73 auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
70 74
71 - std::string text = "HELLO WORLD\nI LOVE YOU"; 75 + std::string text = "HELLO WORLD\nI LOVE YOU :2.0";
72 76
73 std::istringstream iss(text); 77 std::istringstream iss(text);
74 78
75 std::vector<std::vector<int32_t>> ids; 79 std::vector<std::vector<int32_t>> ids;
  80 + std::vector<float> scores;
76 81
77 - auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids); 82 + auto r =
  83 + EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores);
78 84
79 std::vector<std::vector<int32_t>> expected_ids( 85 std::vector<std::vector<int32_t>> expected_ids(
80 {{22, 58, 24, 425}, {19, 370, 47}}); 86 {{22, 58, 24, 425}, {19, 370, 47}});
81 EXPECT_EQ(ids, expected_ids); 87 EXPECT_EQ(ids, expected_ids);
  88 +
  89 + std::vector<float> expected_scores({0, 2.0});
  90 + EXPECT_EQ(scores, expected_scores);
82 } 91 }
83 92
84 TEST(TEXT2TOKEN, TEST_cjkchar_bpe) { 93 TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
@@ -101,19 +110,23 @@ TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
101 auto sym_table = SymbolTable(tokens); 110 auto sym_table = SymbolTable(tokens);
102 auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe); 111 auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
103 112
104 - std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国"; 113 + std::string text = "世界人民 GOES TOGETHER :1.5\n中国 GOES WITH 美国 :0.5";
105 114
106 std::istringstream iss(text); 115 std::istringstream iss(text);
107 116
108 std::vector<std::vector<int32_t>> ids; 117 std::vector<std::vector<int32_t>> ids;
  118 + std::vector<float> scores;
109 119
110 - auto r =  
111 - EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), &ids); 120 + auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(),
  121 + &ids, &scores);
112 122
113 std::vector<std::vector<int32_t>> expected_ids( 123 std::vector<std::vector<int32_t>> expected_ids(
114 {{1368, 1392, 557, 680, 275, 178, 475}, 124 {{1368, 1392, 557, 680, 275, 178, 475},
115 {685, 736, 275, 178, 179, 921, 736}}); 125 {685, 736, 275, 178, 179, 921, 736}});
116 EXPECT_EQ(ids, expected_ids); 126 EXPECT_EQ(ids, expected_ids);
  127 +
  128 + std::vector<float> expected_scores({1.5, 0.5});
  129 + EXPECT_EQ(scores, expected_scores);
117 } 130 }
118 131
119 TEST(TEXT2TOKEN, TEST_bbpe) { 132 TEST(TEXT2TOKEN, TEST_bbpe) {
@@ -136,17 +149,22 @@ TEST(TEXT2TOKEN, TEST_bbpe) {
136 auto sym_table = SymbolTable(tokens); 149 auto sym_table = SymbolTable(tokens);
137 auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe); 150 auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
138 151
139 - std::string text = "频繁\n李鞑靼"; 152 + std::string text = "频繁 :1.0\n李鞑靼";
140 153
141 std::istringstream iss(text); 154 std::istringstream iss(text);
142 155
143 std::vector<std::vector<int32_t>> ids; 156 std::vector<std::vector<int32_t>> ids;
  157 + std::vector<float> scores;
144 158
145 - auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids); 159 + auto r =
  160 + EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores);
146 161
147 std::vector<std::vector<int32_t>> expected_ids( 162 std::vector<std::vector<int32_t>> expected_ids(
148 {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}}); 163 {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}});
149 EXPECT_EQ(ids, expected_ids); 164 EXPECT_EQ(ids, expected_ids);
  165 +
  166 + std::vector<float> expected_scores({1.0, 0});
  167 + EXPECT_EQ(scores, expected_scores);
150 } 168 }
151 169
152 } // namespace sherpa_onnx 170 } // namespace sherpa_onnx
@@ -103,7 +103,8 @@ static bool EncodeBase(const std::vector&lt;std::string&gt; &lines,
103 bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, 103 bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
104 const SymbolTable &symbol_table, 104 const SymbolTable &symbol_table,
105 const ssentencepiece::Ssentencepiece *bpe_encoder, 105 const ssentencepiece::Ssentencepiece *bpe_encoder,
106 - std::vector<std::vector<int32_t>> *hotwords) { 106 + std::vector<std::vector<int32_t>> *hotwords,
  107 + std::vector<float> *boost_scores) {
107 std::vector<std::string> lines; 108 std::vector<std::string> lines;
108 std::string line; 109 std::string line;
109 std::string word; 110 std::string word;
@@ -131,7 +132,12 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
131 break; 132 break;
132 } 133 }
133 } 134 }
134 - phrase = oss.str().substr(1); 135 + phrase = oss.str();
  136 + if (phrase.empty()) {
  137 + continue;
  138 + } else {
  139 + phrase = phrase.substr(1);
  140 + }
135 std::istringstream piss(phrase); 141 std::istringstream piss(phrase);
136 oss.clear(); 142 oss.clear();
137 oss.str(""); 143 oss.str("");
@@ -177,7 +183,8 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
177 } 183 }
178 lines.push_back(oss.str()); 184 lines.push_back(oss.str());
179 } 185 }
180 - return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr); 186 + return EncodeBase(lines, symbol_table, hotwords, nullptr, boost_scores,
  187 + nullptr);
181 } 188 }
182 189
183 bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, 190 bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
@@ -29,7 +29,8 @@ namespace sherpa_onnx {
29 bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, 29 bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
30 const SymbolTable &symbol_table, 30 const SymbolTable &symbol_table,
31 const ssentencepiece::Ssentencepiece *bpe_encoder_, 31 const ssentencepiece::Ssentencepiece *bpe_encoder_,
32 - std::vector<std::vector<int32_t>> *hotwords_id); 32 + std::vector<std::vector<int32_t>> *hotwords_id,
  33 + std::vector<float> *boost_scores);
33 34
34 /* Encode the keywords in an input stream to be tokens ids. 35 /* Encode the keywords in an input stream to be tokens ids.
35 * 36 *