Fangjun Kuang
Committed by GitHub

Support BPE models with byte fallback. (#2531)

... ... @@ -171,6 +171,12 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) {
// Loads the token table from `is` via ReadTokens (id2sym_ is handed to it
// as an out-argument) and caches tokenizer traits derived from the table.
void SymbolTable::Init(std::istream &is) {
  sym2id_ = ReadTokens(is, &id2sym_);
  is_bbpe_ = IsByteBPE(sym2id_);

  // Byte fallback is recognised only when both end-point byte tokens exist
  // and their IDs are exactly 255 apart.
  const auto first_byte = sym2id_.find("<0x00>");
  if (first_byte == sym2id_.end()) {
    return;
  }

  const auto last_byte = sym2id_.find("<0xFF>");
  if (last_byte == sym2id_.end()) {
    return;
  }

  if (last_byte->second - first_byte->second != 255) {
    return;
  }

  is_bpe_with_byte_fallback_ = true;
  // Remember where the byte tokens start so an ID can be mapped to a byte.
  id_for_0x00_ = first_byte->second;
}
std::string SymbolTable::ToString() const {
... ... @@ -197,13 +203,13 @@ const std::string SymbolTable::operator[](int32_t id) const {
// id 0 is blank, id 1 is sos/eos, id 2 is unk
//
// Note: For moonshine models, 0 is <unk>, 1 is <s>, 2 is </s>
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
if (is_bpe_with_byte_fallback_ && sym.size() == 6 && sym[0] == '<' &&
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
std::ostringstream os;
os << std::hex << std::uppercase << (id - 3);
os << std::hex << std::uppercase << (id - id_for_0x00_);
if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {
uint8_t i = id - 3;
uint8_t i = id - id_for_0x00_;
sym = std::string(&i, &i + 1);
}
}
... ...
... ... @@ -66,6 +66,15 @@ class SymbolTable {
private:
std::unordered_map<std::string, int32_t> sym2id_;
std::unordered_map<int32_t, std::string> id2sym_;
// see https://github.com/k2-fsa/sherpa-onnx/issues/2524
bool is_bpe_with_byte_fallback_ = false;
// used only when is_bpe_with_byte_fallback_ is true. It is the ID
// of <0x00> in tokens.txt
int32_t id_for_0x00_ = 0;
// true for byte BPE. false for non byte BPE.
bool is_bbpe_ = false;
};
... ...