Committed by GitHub
Support BPE models with byte fallback. (#2531)
正在显示 2 个修改的文件，包含 18 行增加和 3 行删除
| @@ -171,6 +171,12 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) { | @@ -171,6 +171,12 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) { | ||
| 171 | void SymbolTable::Init(std::istream &is) { | 171 | void SymbolTable::Init(std::istream &is) { |
| 172 | sym2id_ = ReadTokens(is, &id2sym_); | 172 | sym2id_ = ReadTokens(is, &id2sym_); |
| 173 | is_bbpe_ = IsByteBPE(sym2id_); | 173 | is_bbpe_ = IsByteBPE(sym2id_); |
| 174 | + | ||
| 175 | + if (sym2id_.count("<0x00>") && sym2id_.count("<0xFF>") && | ||
| 176 | + ((sym2id_.at("<0xFF>") - sym2id_.at("<0x00>")) == 255)) { | ||
| 177 | + is_bpe_with_byte_fallback_ = true; | ||
| 178 | + id_for_0x00_ = sym2id_.at("<0x00>"); | ||
| 179 | + } | ||
| 174 | } | 180 | } |
| 175 | 181 | ||
| 176 | std::string SymbolTable::ToString() const { | 182 | std::string SymbolTable::ToString() const { |
| @@ -197,13 +203,13 @@ const std::string SymbolTable::operator[](int32_t id) const { | @@ -197,13 +203,13 @@ const std::string SymbolTable::operator[](int32_t id) const { | ||
| 197 | // id 0 is blank, id 1 is sos/eos, id 2 is unk | 203 | // id 0 is blank, id 1 is sos/eos, id 2 is unk |
| 198 | // | 204 | // |
| 199 | // Note: For moonshine models, 0 is <unk>, 1 is <s>, 2 is </s> | 205 | // Note: For moonshine models, 0 is <unk>, 1 is <s>, 2 is </s> |
| 200 | - if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' && | 206 | + if (is_bpe_with_byte_fallback_ && sym.size() == 6 && sym[0] == '<' && |
| 201 | sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') { | 207 | sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') { |
| 202 | std::ostringstream os; | 208 | std::ostringstream os; |
| 203 | - os << std::hex << std::uppercase << (id - 3); | 209 | + os << std::hex << std::uppercase << (id - id_for_0x00_); |
| 204 | 210 | ||
| 205 | if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) { | 211 | if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) { |
| 206 | - uint8_t i = id - 3; | 212 | + uint8_t i = id - id_for_0x00_; |
| 207 | sym = std::string(&i, &i + 1); | 213 | sym = std::string(&i, &i + 1); |
| 208 | } | 214 | } |
| 209 | } | 215 | } |
| @@ -66,6 +66,15 @@ class SymbolTable { | @@ -66,6 +66,15 @@ class SymbolTable { | ||
| 66 | private: | 66 | private: |
| 67 | std::unordered_map<std::string, int32_t> sym2id_; | 67 | std::unordered_map<std::string, int32_t> sym2id_; |
| 68 | std::unordered_map<int32_t, std::string> id2sym_; | 68 | std::unordered_map<int32_t, std::string> id2sym_; |
| 69 | + | ||
| 70 | + // see https://github.com/k2-fsa/sherpa-onnx/issues/2524 | ||
| 71 | + bool is_bpe_with_byte_fallback_ = false; | ||
| 72 | + | ||
| 73 | + // used only when is_bpe_with_byte_fallback_ is true. It is the ID | ||
| 74 | + // of <0x00> in tokens.txt | ||
| 75 | + int32_t id_for_0x00_ = 0; | ||
| 76 | + | ||
| 77 | + // true for byte BPE. false for non byte BPE. | ||
| 69 | bool is_bbpe_ = false; | 78 | bool is_bbpe_ = false; |
| 70 | }; | 79 | }; |
| 71 | 80 |
-
请 注册 或 登录 后发表评论