Showing 2 changed files with 49 additions and 1 deletion.
--- a/sherpa-onnx/csrc/lexicon.cc
+++ b/sherpa-onnx/csrc/lexicon.cc
@@ -17,6 +17,8 @@
 #include "android/asset_manager_jni.h"
 #endif
 
+#include <regex>
+
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/text-utils.h"
@@ -147,7 +149,36 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
 
 std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
     const std::string &text) const {
-  std::vector<std::string> words = SplitUtf8(text);
+  std::vector<std::string> words;
+  if (pattern_) {
+    // Handle polyphones
+    size_t pos = 0;
+    auto begin = std::sregex_iterator(text.begin(), text.end(), *pattern_);
+    auto end = std::sregex_iterator();
+    for (std::sregex_iterator i = begin; i != end; ++i) {
+      std::smatch match = *i;
+      if (pos < match.position()) {
+        auto this_segment = text.substr(pos, match.position() - pos);
+        auto this_segment_words = SplitUtf8(this_segment);
+        words.insert(words.end(), this_segment_words.begin(),
+                     this_segment_words.end());
+        pos = match.position() + match.length();
+      } else if (pos == match.position()) {
+        pos = match.position() + match.length();
+      }
+
+      words.push_back(match.str());
+    }
+
+    if (pos < text.size()) {
+      auto this_segment = text.substr(pos, text.size() - pos);
+      auto this_segment_words = SplitUtf8(this_segment);
+      words.insert(words.end(), this_segment_words.begin(),
+                   this_segment_words.end());
+    }
+  } else {
+    words = SplitUtf8(text);
+  }
 
   if (debug_) {
     fprintf(stderr, "Input text in string: %s\n", text.c_str());
@@ -272,6 +303,9 @@ void Lexicon::InitLexicon(std::istream &is) {
   std::string line;
   std::string phone;
 
+  std::ostringstream os;
+  std::string sep;
+
   while (std::getline(is, line)) {
     std::istringstream iss(line);
 
@@ -293,8 +327,18 @@ void Lexicon::InitLexicon(std::istream &is) {
     if (ids.empty()) {
      continue;
     }
+    if (language_ == Language::kChinese && word.size() > 3) {
+      // this is not a single word;
+      os << sep << word;
+      sep = "|";
+    }
+
     word2ids_.insert({std::move(word), std::move(ids)});
   }
+
+  if (!sep.empty()) {
+    pattern_ = std::make_unique<std::regex>(os.str());
+  }
 }
 
 void Lexicon::InitPunctuations(const std::string &punctuations) {

--- a/sherpa-onnx/csrc/lexicon.h
+++ b/sherpa-onnx/csrc/lexicon.h
@@ -7,6 +7,7 @@
 
 #include <cstdint>
 #include <iostream>
+#include <regex>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
@@ -79,6 +80,9 @@ class Lexicon {
   Language language_;
   bool debug_;
   bool is_piper_;
+
+  // for Chinese polyphones
+  std::unique_ptr<std::regex> pattern_;
 };
 
 }  // namespace sherpa_onnx
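
What the change does, in short: when the lexicon is loaded, every Chinese entry longer than 3 bytes (a single CJK character is 3 bytes in UTF-8, so anything longer is a multi-character word) is collected into one alternation pattern word1|word2|..., and ConvertTextToTokenIdsChinese runs that regex over the input first. Matched words are kept whole, so the lexicon's per-word pronunciation wins for polyphonic characters, while the text between matches still falls back to per-character splitting via SplitUtf8. Below is a minimal standalone sketch of that segmentation idea, not the sherpa-onnx code itself: the SplitUtf8 here is a simplified stand-in for the project's helper, and the sample lexicon words and input text are made up for illustration.

// Minimal sketch of regex-based word segmentation for polyphone handling.
// SplitUtf8 is a simplified stand-in; the words and text are examples only.
#include <iostream>
#include <memory>
#include <regex>
#include <string>
#include <vector>

// Split a UTF-8 string into single characters (simplified stand-in for the
// helper used in sherpa-onnx).
static std::vector<std::string> SplitUtf8(const std::string &s) {
  std::vector<std::string> out;
  for (size_t i = 0; i < s.size();) {
    unsigned char c = s[i];
    size_t len = (c < 0x80) ? 1 : (c < 0xe0) ? 2 : (c < 0xf0) ? 3 : 4;
    out.push_back(s.substr(i, len));
    i += len;
  }
  return out;
}

int main() {
  // Hypothetical multi-character lexicon entries whose per-word pronunciation
  // disambiguates the polyphonic character 行.
  std::vector<std::string> multi_char_words = {"银行", "行长"};

  // Build the alternation pattern the same way InitLexicon does in the diff.
  std::string pattern_str;
  std::string sep;
  for (const auto &w : multi_char_words) {
    pattern_str += sep + w;
    sep = "|";
  }
  auto pattern = std::make_unique<std::regex>(pattern_str);

  const std::string text = "他是银行的行长";

  // Keep regex matches as whole words; split the gaps character by character.
  std::vector<std::string> words;
  size_t pos = 0;
  for (std::sregex_iterator it(text.begin(), text.end(), *pattern), end;
       it != end; ++it) {
    if (pos < static_cast<size_t>(it->position())) {
      auto gap = SplitUtf8(text.substr(pos, it->position() - pos));
      words.insert(words.end(), gap.begin(), gap.end());
    }
    words.push_back(it->str());
    pos = static_cast<size_t>(it->position() + it->length());
  }
  if (pos < text.size()) {
    auto tail = SplitUtf8(text.substr(pos));
    words.insert(words.end(), tail.begin(), tail.end());
  }

  for (const auto &w : words) {
    std::cout << w << "\n";
  }
  return 0;
}

Running the sketch on 他是银行的行长 prints 他, 是, 银行, 的, 行长, one per line, so 银行 and 行长 are looked up as whole lexicon entries instead of character by character. Note that std::regex operates on raw bytes here; that is fine for literal UTF-8 alternatives like these, which is the same assumption the diff itself relies on.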