Fangjun Kuang
Committed by GitHub

Support BPE models with byte fallback. (#2531)

@@ -171,6 +171,12 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) { @@ -171,6 +171,12 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) {
171 void SymbolTable::Init(std::istream &is) { 171 void SymbolTable::Init(std::istream &is) {
172 sym2id_ = ReadTokens(is, &id2sym_); 172 sym2id_ = ReadTokens(is, &id2sym_);
173 is_bbpe_ = IsByteBPE(sym2id_); 173 is_bbpe_ = IsByteBPE(sym2id_);
  174 +
  175 + if (sym2id_.count("<0x00>") && sym2id_.count("<0xFF>") &&
  176 + ((sym2id_.at("<0xFF>") - sym2id_.at("<0x00>")) == 255)) {
  177 + is_bpe_with_byte_fallback_ = true;
  178 + id_for_0x00_ = sym2id_.at("<0x00>");
  179 + }
174 } 180 }
175 181
176 std::string SymbolTable::ToString() const { 182 std::string SymbolTable::ToString() const {
@@ -197,13 +203,13 @@ const std::string SymbolTable::operator[](int32_t id) const { @@ -197,13 +203,13 @@ const std::string SymbolTable::operator[](int32_t id) const {
197 // id 0 is blank, id 1 is sos/eos, id 2 is unk 203 // id 0 is blank, id 1 is sos/eos, id 2 is unk
198 // 204 //
199 // Note: For moonshine models, 0 is <unk>, 1 is <s>, 2 is </s> 205 // Note: For moonshine models, 0 is <unk>, 1 is <s>, 2 is </s>
200 - if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' && 206 + if (is_bpe_with_byte_fallback_ && sym.size() == 6 && sym[0] == '<' &&
201 sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') { 207 sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
202 std::ostringstream os; 208 std::ostringstream os;
203 - os << std::hex << std::uppercase << (id - 3); 209 + os << std::hex << std::uppercase << (id - id_for_0x00_);
204 210
205 if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) { 211 if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {
206 - uint8_t i = id - 3; 212 + uint8_t i = id - id_for_0x00_;
207 sym = std::string(&i, &i + 1); 213 sym = std::string(&i, &i + 1);
208 } 214 }
209 } 215 }
@@ -66,6 +66,15 @@ class SymbolTable { @@ -66,6 +66,15 @@ class SymbolTable {
66 private: 66 private:
67 std::unordered_map<std::string, int32_t> sym2id_; 67 std::unordered_map<std::string, int32_t> sym2id_;
68 std::unordered_map<int32_t, std::string> id2sym_; 68 std::unordered_map<int32_t, std::string> id2sym_;
  69 +
  70 + // see https://github.com/k2-fsa/sherpa-onnx/issues/2524
  71 + bool is_bpe_with_byte_fallback_ = false;
  72 +
  73 + // used only when is_bpe_with_byte_fallback_ is true. It is the ID
  74 + // of <0x00> in tokens.txt
  75 + int32_t id_for_0x00_ = 0;
  76 +
  77 + // true for byte BPE. false for non byte BPE.
69 bool is_bbpe_ = false; 78 bool is_bbpe_ = false;
70 }; 79 };
71 80