Fangjun Kuang
Committed by GitHub

Support decoding with byte-level BPE (bbpe) models. (#1633)

#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_bpe.py#L28
# and
# https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py
#
# Caution: The PRINTABLE_LATIN from fairseq is different from PRINTABLE_BASE_CHARS from icefall
import re
BPE_UNK = chr(8263)
PRINTABLE_BASE_CHARS = (
list(range(256, 287 + 1))
+ list(range(32, 126 + 1))
+ list(range(288, 305 + 1))
+ list(range(308, 318 + 1))
+ list(range(321, 328 + 1))
+ list(range(330, 382 + 1))
+ list(range(384, 422 + 1))
)
BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)}
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
BCHAR_TO_BYTE[BPE_UNK] = 32 # map unk to space
def main():
s = ""
s += "// sherpa-onnx/csrc/bbpe.cc\n"
s += "//\n"
s += "// Copyright (c) 2024 Xiaomi Corporation\n"
s += "\n"
s += "// Auto-generated! DO NOT EDIT\n"
s += "\n"
s += '#include "sherpa-onnx/csrc/bbpe.h"\n'
s += "\n"
s += "#include <cstdint>\n"
s += "#include <string>\n"
s += "#include <unordered_map>\n"
s += "\n"
s += "const std::unordered_map<std::string, uint8_t> &GetByteBpeTable() {\n"
s += " static const std::unordered_map<std::string, uint8_t> table = {\n"
s += " "
for i, (k, v) in enumerate(BCHAR_TO_BYTE.items()):
s += "{"
if k in ["\\", '"']:
s += f'"\{k}", {v}'
else:
s += f'"{k}", {v}'
s += "}, "
if i > 0 and i % 7 == 0:
s += "\n"
s += " "
s += "};\n"
s += "\n"
s += " return table\n;"
s += "}\n"
with open("bbpe.cc", "w", encoding="utf-8") as f:
f.write(s)
if __name__ == "__main__":
main()
... ...
... ... @@ -12,6 +12,7 @@ endif()
set(sources
base64-decode.cc
bbpe.cc
cat.cc
circular-buffer.cc
context-graph.cc
... ... @@ -78,11 +79,11 @@ set(sources
online-stream.cc
online-transducer-decoder.cc
online-transducer-greedy-search-decoder.cc
online-transducer-greedy-search-nemo-decoder.cc
online-transducer-model-config.cc
online-transducer-model.cc
online-transducer-modified-beam-search-decoder.cc
online-transducer-nemo-model.cc
online-transducer-greedy-search-nemo-decoder.cc
online-wenet-ctc-model-config.cc
online-wenet-ctc-model.cc
online-zipformer-transducer-model.cc
... ...
// sherpa-onnx/csrc/bbpe.cc
//
// Copyright (c) 2024 Xiaomi Corporation
// Auto-generated! DO NOT EDIT
#include "sherpa-onnx/csrc/bbpe.h"
#include <cstdint>
#include <string>
#include <unordered_map>
const std::unordered_map<std::string, uint8_t> &GetByteBpeTable() {
static const std::unordered_map<std::string, uint8_t> table = {
{"Ā", 0}, {"ā", 1}, {"Ă", 2}, {"ă", 3}, {"Ą", 4}, {"ą", 5},
{"Ć", 6}, {"ć", 7}, {"Ĉ", 8}, {"ĉ", 9}, {"Ċ", 10}, {"ċ", 11},
{"Č", 12}, {"č", 13}, {"Ď", 14}, {"ď", 15}, {"Đ", 16}, {"đ", 17},
{"Ē", 18}, {"ē", 19}, {"Ĕ", 20}, {"ĕ", 21}, {"Ė", 22}, {"ė", 23},
{"Ę", 24}, {"ę", 25}, {"Ě", 26}, {"ě", 27}, {"Ĝ", 28}, {"ĝ", 29},
{"Ğ", 30}, {"ğ", 31}, {" ", 32}, {"!", 33}, {"\"", 34}, {"#", 35},
{"$", 36}, {"%", 37}, {"&", 38}, {"'", 39}, {"(", 40}, {")", 41},
{"*", 42}, {"+", 43}, {",", 44}, {"-", 45}, {".", 46}, {"/", 47},
{"0", 48}, {"1", 49}, {"2", 50}, {"3", 51}, {"4", 52}, {"5", 53},
{"6", 54}, {"7", 55}, {"8", 56}, {"9", 57}, {":", 58}, {";", 59},
{"<", 60}, {"=", 61}, {">", 62}, {"?", 63}, {"@", 64}, {"A", 65},
{"B", 66}, {"C", 67}, {"D", 68}, {"E", 69}, {"F", 70}, {"G", 71},
{"H", 72}, {"I", 73}, {"J", 74}, {"K", 75}, {"L", 76}, {"M", 77},
{"N", 78}, {"O", 79}, {"P", 80}, {"Q", 81}, {"R", 82}, {"S", 83},
{"T", 84}, {"U", 85}, {"V", 86}, {"W", 87}, {"X", 88}, {"Y", 89},
{"Z", 90}, {"[", 91}, {"\\", 92}, {"]", 93}, {"^", 94}, {"_", 95},
{"`", 96}, {"a", 97}, {"b", 98}, {"c", 99}, {"d", 100}, {"e", 101},
{"f", 102}, {"g", 103}, {"h", 104}, {"i", 105}, {"j", 106}, {"k", 107},
{"l", 108}, {"m", 109}, {"n", 110}, {"o", 111}, {"p", 112}, {"q", 113},
{"r", 114}, {"s", 115}, {"t", 116}, {"u", 117}, {"v", 118}, {"w", 119},
{"x", 120}, {"y", 121}, {"z", 122}, {"{", 123}, {"|", 124}, {"}", 125},
{"~", 126}, {"Ġ", 127}, {"ġ", 128}, {"Ģ", 129}, {"ģ", 130}, {"Ĥ", 131},
{"ĥ", 132}, {"Ħ", 133}, {"ħ", 134}, {"Ĩ", 135}, {"ĩ", 136}, {"Ī", 137},
{"ī", 138}, {"Ĭ", 139}, {"ĭ", 140}, {"Į", 141}, {"į", 142}, {"İ", 143},
{"ı", 144}, {"Ĵ", 145}, {"ĵ", 146}, {"Ķ", 147}, {"ķ", 148}, {"ĸ", 149},
{"Ĺ", 150}, {"ĺ", 151}, {"Ļ", 152}, {"ļ", 153}, {"Ľ", 154}, {"ľ", 155},
{"Ł", 156}, {"ł", 157}, {"Ń", 158}, {"ń", 159}, {"Ņ", 160}, {"ņ", 161},
{"Ň", 162}, {"ň", 163}, {"Ŋ", 164}, {"ŋ", 165}, {"Ō", 166}, {"ō", 167},
{"Ŏ", 168}, {"ŏ", 169}, {"Ő", 170}, {"ő", 171}, {"Œ", 172}, {"œ", 173},
{"Ŕ", 174}, {"ŕ", 175}, {"Ŗ", 176}, {"ŗ", 177}, {"Ř", 178}, {"ř", 179},
{"Ś", 180}, {"ś", 181}, {"Ŝ", 182}, {"ŝ", 183}, {"Ş", 184}, {"ş", 185},
{"Š", 186}, {"š", 187}, {"Ţ", 188}, {"ţ", 189}, {"Ť", 190}, {"ť", 191},
{"Ŧ", 192}, {"ŧ", 193}, {"Ũ", 194}, {"ũ", 195}, {"Ū", 196}, {"ū", 197},
{"Ŭ", 198}, {"ŭ", 199}, {"Ů", 200}, {"ů", 201}, {"Ű", 202}, {"ű", 203},
{"Ų", 204}, {"ų", 205}, {"Ŵ", 206}, {"ŵ", 207}, {"Ŷ", 208}, {"ŷ", 209},
{"Ÿ", 210}, {"Ź", 211}, {"ź", 212}, {"Ż", 213}, {"ż", 214}, {"Ž", 215},
{"ž", 216}, {"ƀ", 217}, {"Ɓ", 218}, {"Ƃ", 219}, {"ƃ", 220}, {"Ƅ", 221},
{"ƅ", 222}, {"Ɔ", 223}, {"Ƈ", 224}, {"ƈ", 225}, {"Ɖ", 226}, {"Ɗ", 227},
{"Ƌ", 228}, {"ƌ", 229}, {"ƍ", 230}, {"Ǝ", 231}, {"Ə", 232}, {"Ɛ", 233},
{"Ƒ", 234}, {"ƒ", 235}, {"Ɠ", 236}, {"Ɣ", 237}, {"ƕ", 238}, {"Ɩ", 239},
{"Ɨ", 240}, {"Ƙ", 241}, {"ƙ", 242}, {"ƚ", 243}, {"ƛ", 244}, {"Ɯ", 245},
{"Ɲ", 246}, {"ƞ", 247}, {"Ɵ", 248}, {"Ơ", 249}, {"ơ", 250}, {"Ƣ", 251},
{"ƣ", 252}, {"Ƥ", 253}, {"ƥ", 254}, {"Ʀ", 255}, {"⁇", 32},
};
return table;
}
... ...
// sherpa-onnx/csrc/bbpe.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_BBPE_H_
#define SHERPA_ONNX_CSRC_BBPE_H_
#include <cstdint>
#include <string>
#include <unordered_map>
// It is equivalent to the map BCHAR_TO_BYTE
// from
// https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py#L280
const std::unordered_map<std::string, uint8_t> &GetByteBpeTable();
#endif // SHERPA_ONNX_CSRC_BBPE_H_
... ...
... ... @@ -41,7 +41,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
text.append(sym);
if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
// for byte bpe models
// for bpe models with byte_fallback
// (but don't rewrite printable characters 0x20..0x7e,
// which collide with standard BPE units)
std::ostringstream os;
... ... @@ -52,6 +52,11 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
r.tokens.push_back(std::move(sym));
}
if (sym_table.IsByteBpe()) {
text = sym_table.DecodeByteBpe(text);
}
r.text = std::move(text);
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
... ...
... ... @@ -43,7 +43,7 @@ static OfflineRecognitionResult Convert(
text.append(sym);
if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
// for byte bpe models,
// for bpe models with byte_fallback,
// (but don't rewrite printable characters 0x20..0x7e,
// which collide with standard BPE units)
std::ostringstream os;
... ... @@ -54,6 +54,10 @@ static OfflineRecognitionResult Convert(
r.tokens.push_back(std::move(sym));
}
if (sym_table.IsByteBpe()) {
text = sym_table.DecodeByteBpe(text);
}
r.text = std::move(text);
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
... ...
... ... @@ -34,13 +34,14 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());
std::string text;
for (auto i : src.tokens) {
auto sym = sym_table[i];
r.text.append(sym);
text.append(sym);
if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
// for byte bpe models
// for bpe models with byte_fallback
// (but don't rewrite printable characters 0x20..0x7e,
// which collide with standard BPE units)
std::ostringstream os;
... ... @@ -52,6 +53,12 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
r.tokens.push_back(std::move(sym));
}
if (sym_table.IsByteBpe()) {
text = sym_table.DecodeByteBpe(text);
}
r.text = std::move(text);
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
for (auto t : src.timestamps) {
float time = frame_shift_s * t;
... ...
... ... @@ -38,13 +38,14 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());
std::string text;
for (auto i : src.tokens) {
auto sym = sym_table[i];
r.text.append(sym);
text.append(sym);
if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
// for byte bpe models
// for bpe models with byte_fallback
// (but don't rewrite printable characters 0x20..0x7e,
// which collide with standard BPE units)
std::ostringstream os;
... ... @@ -56,6 +57,12 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
r.tokens.push_back(std::move(sym));
}
if (sym_table.IsByteBpe()) {
text = sym_table.DecodeByteBpe(text);
}
r.text = std::move(text);
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
for (auto t : src.timestamps) {
float time = frame_shift_s * t;
... ...
... ... @@ -5,6 +5,7 @@
#include "sherpa-onnx/csrc/symbol-table.h"
#include <cassert>
#include <cctype>
#include <fstream>
#include <sstream>
#include <string>
... ... @@ -22,8 +23,10 @@
#endif
#include "sherpa-onnx/csrc/base64-decode.h"
#include "sherpa-onnx/csrc/bbpe.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
... ... @@ -47,6 +50,59 @@ inline void Trim(std::string *s, const char *t = ws) {
TrimRight(s, t);
TrimLeft(s, t);
}
bool IsByteBPE(const char *s, int32_t n) {
const uint8_t *p = reinterpret_cast<const uint8_t *>(s);
if (n >= 3 && p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
return IsByteBPE(s + 3, n - 3);
}
for (int32_t i = 0; i != n; ++i) {
if (p[i] > 0xc6) {
return false;
}
}
return true;
}
bool IsByteBPE(const std::unordered_map<std::string, int32_t> &sym2id) {
uint8_t max_v = 0;
for (const auto &p : sym2id) {
const auto &s = p.first;
if (!IsByteBPE(s.c_str(), s.size())) {
return false;
}
uint8_t m = 0;
if (s.size() >= 3) {
const uint8_t *p = reinterpret_cast<const uint8_t *>(s.c_str());
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
if (s.size() > 3) {
m = *std::max_element(
reinterpret_cast<const uint8_t *>(s.data()) + 3,
reinterpret_cast<const uint8_t *>(s.data()) + s.size());
} else {
m = 0;
}
} else {
m = *std::max_element(
reinterpret_cast<const uint8_t *>(s.data()),
reinterpret_cast<const uint8_t *>(s.data()) + s.size());
}
} else {
m = *std::max_element(
reinterpret_cast<const uint8_t *>(s.data()),
reinterpret_cast<const uint8_t *>(s.data()) + s.size());
}
max_v = (m > max_v) ? m : max_v;
}
return static_cast<uint8_t>(max_v) == 0xc6;
}
} // namespace
std::unordered_map<std::string, int32_t> ReadTokens(
... ... @@ -111,7 +167,10 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) {
Init(is);
}
void SymbolTable::Init(std::istream &is) { sym2id_ = ReadTokens(is, &id2sym_); }
void SymbolTable::Init(std::istream &is) {
sym2id_ = ReadTokens(is, &id2sym_);
is_bbpe_ = IsByteBPE(sym2id_);
}
std::string SymbolTable::ToString() const {
std::ostringstream os;
... ... @@ -124,7 +183,7 @@ std::string SymbolTable::ToString() const {
const std::string SymbolTable::operator[](int32_t id) const {
std::string sym = id2sym_.at(id);
if (sym.size() >= 3) {
if (sym.size() >= 3 && !is_bbpe_) {
// For BPE-based models, we replace ▁ with a space
// Unicode 9601, hex 0x2581, utf8 0xe29681
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
... ... @@ -133,7 +192,7 @@ const std::string SymbolTable::operator[](int32_t id) const {
}
}
// for byte-level BPE
// for BPE with byte_fallback
// id 0 is blank, id 1 is sos/eos, id 2 is unk
//
// Note: For moonshine models, 0 is <unk>, 1, is <s>, 2 is</s>
... ... @@ -172,6 +231,33 @@ void SymbolTable::ApplyBase64Decode() {
}
}
std::string SymbolTable::DecodeByteBpe(const std::string &text) const {
if (!is_bbpe_) {
return text;
}
auto v = SplitUtf8(text);
const auto &bbpe_table = GetByteBpeTable();
std::string ans;
for (const auto &s : v) {
if (s == "▁") {
if (!ans.empty() && ans.back() != ' ' && std::isprint(ans.back())) {
ans.push_back(' ');
}
} else if (bbpe_table.count(s)) {
ans.push_back(bbpe_table.at(s));
} else if (std::isprint(s[0])) {
ans.append(s);
} else {
// Should not happen
SHERPA_ONNX_LOGE("Skip OOV: %s from %s", s.c_str(), text.c_str());
}
}
// TODO(fangjun): Filter invalid utf-8 sequences
return ans;
}
#if __ANDROID_API__ >= 9
template SymbolTable::SymbolTable(AAssetManager *mgr,
const std::string &filename);
... ...
... ... @@ -56,12 +56,17 @@ class SymbolTable {
int32_t NumSymbols() const { return id2sym_.size(); }
std::string DecodeByteBpe(const std::string &text) const;
bool IsByteBpe() const { return is_bbpe_; }
private:
void Init(std::istream &is);
private:
std::unordered_map<std::string, int32_t> sym2id_;
std::unordered_map<int32_t, std::string> id2sym_;
bool is_bbpe_ = false;
};
std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table);
... ...