Fangjun Kuang
Committed by GitHub

Support decoding with byte-level BPE (bbpe) models. (#1633)

  1 +#!/usr/bin/env python3
  2 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +#
  4 +# See https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_bpe.py#L28
  5 +# and
  6 +# https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py
  7 +#
  8 +# Caution: The PRINTABLE_LATIN from fairseq is different from PRINTABLE_BASE_CHARS from icefall
  9 +
  10 +import re
  11 +
  12 +BPE_UNK = chr(8263)
  13 +PRINTABLE_BASE_CHARS = (
  14 + list(range(256, 287 + 1))
  15 + + list(range(32, 126 + 1))
  16 + + list(range(288, 305 + 1))
  17 + + list(range(308, 318 + 1))
  18 + + list(range(321, 328 + 1))
  19 + + list(range(330, 382 + 1))
  20 + + list(range(384, 422 + 1))
  21 +)
  22 +
  23 +
  24 +BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)}
  25 +BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
  26 +BCHAR_TO_BYTE[BPE_UNK] = 32 # map unk to space
  27 +
  28 +
  29 +def main():
  30 + s = ""
  31 + s += "// sherpa-onnx/csrc/bbpe.cc\n"
  32 + s += "//\n"
  33 + s += "// Copyright (c) 2024 Xiaomi Corporation\n"
  34 + s += "\n"
  35 + s += "// Auto-generated! DO NOT EDIT\n"
  36 + s += "\n"
  37 + s += '#include "sherpa-onnx/csrc/bbpe.h"\n'
  38 + s += "\n"
  39 + s += "#include <cstdint>\n"
  40 + s += "#include <string>\n"
  41 + s += "#include <unordered_map>\n"
  42 + s += "\n"
  43 + s += "const std::unordered_map<std::string, uint8_t> &GetByteBpeTable() {\n"
  44 + s += " static const std::unordered_map<std::string, uint8_t> table = {\n"
  45 +
  46 + s += " "
  47 + for i, (k, v) in enumerate(BCHAR_TO_BYTE.items()):
  48 + s += "{"
  49 + if k in ["\\", '"']:
  50 + s += f'"\{k}", {v}'
  51 + else:
  52 + s += f'"{k}", {v}'
  53 + s += "}, "
  54 + if i > 0 and i % 7 == 0:
  55 + s += "\n"
  56 + s += " "
  57 + s += "};\n"
  58 + s += "\n"
  59 + s += " return table\n;"
  60 + s += "}\n"
  61 +
  62 + with open("bbpe.cc", "w", encoding="utf-8") as f:
  63 + f.write(s)
  64 +
  65 +
  66 +if __name__ == "__main__":
  67 + main()
@@ -12,6 +12,7 @@ endif() @@ -12,6 +12,7 @@ endif()
12 12
13 set(sources 13 set(sources
14 base64-decode.cc 14 base64-decode.cc
  15 + bbpe.cc
15 cat.cc 16 cat.cc
16 circular-buffer.cc 17 circular-buffer.cc
17 context-graph.cc 18 context-graph.cc
@@ -78,11 +79,11 @@ set(sources @@ -78,11 +79,11 @@ set(sources
78 online-stream.cc 79 online-stream.cc
79 online-transducer-decoder.cc 80 online-transducer-decoder.cc
80 online-transducer-greedy-search-decoder.cc 81 online-transducer-greedy-search-decoder.cc
  82 + online-transducer-greedy-search-nemo-decoder.cc
81 online-transducer-model-config.cc 83 online-transducer-model-config.cc
82 online-transducer-model.cc 84 online-transducer-model.cc
83 online-transducer-modified-beam-search-decoder.cc 85 online-transducer-modified-beam-search-decoder.cc
84 online-transducer-nemo-model.cc 86 online-transducer-nemo-model.cc
85 - online-transducer-greedy-search-nemo-decoder.cc  
86 online-wenet-ctc-model-config.cc 87 online-wenet-ctc-model-config.cc
87 online-wenet-ctc-model.cc 88 online-wenet-ctc-model.cc
88 online-zipformer-transducer-model.cc 89 online-zipformer-transducer-model.cc
  1 +// sherpa-onnx/csrc/bbpe.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +// Auto-generated! DO NOT EDIT
  6 +
  7 +#include "sherpa-onnx/csrc/bbpe.h"
  8 +
  9 +#include <cstdint>
  10 +#include <string>
  11 +#include <unordered_map>
  12 +
  13 +const std::unordered_map<std::string, uint8_t> &GetByteBpeTable() {
  14 + static const std::unordered_map<std::string, uint8_t> table = {
  15 + {"Ā", 0}, {"ā", 1}, {"Ă", 2}, {"ă", 3}, {"Ą", 4}, {"ą", 5},
  16 + {"Ć", 6}, {"ć", 7}, {"Ĉ", 8}, {"ĉ", 9}, {"Ċ", 10}, {"ċ", 11},
  17 + {"Č", 12}, {"č", 13}, {"Ď", 14}, {"ď", 15}, {"Đ", 16}, {"đ", 17},
  18 + {"Ē", 18}, {"ē", 19}, {"Ĕ", 20}, {"ĕ", 21}, {"Ė", 22}, {"ė", 23},
  19 + {"Ę", 24}, {"ę", 25}, {"Ě", 26}, {"ě", 27}, {"Ĝ", 28}, {"ĝ", 29},
  20 + {"Ğ", 30}, {"ğ", 31}, {" ", 32}, {"!", 33}, {"\"", 34}, {"#", 35},
  21 + {"$", 36}, {"%", 37}, {"&", 38}, {"'", 39}, {"(", 40}, {")", 41},
  22 + {"*", 42}, {"+", 43}, {",", 44}, {"-", 45}, {".", 46}, {"/", 47},
  23 + {"0", 48}, {"1", 49}, {"2", 50}, {"3", 51}, {"4", 52}, {"5", 53},
  24 + {"6", 54}, {"7", 55}, {"8", 56}, {"9", 57}, {":", 58}, {";", 59},
  25 + {"<", 60}, {"=", 61}, {">", 62}, {"?", 63}, {"@", 64}, {"A", 65},
  26 + {"B", 66}, {"C", 67}, {"D", 68}, {"E", 69}, {"F", 70}, {"G", 71},
  27 + {"H", 72}, {"I", 73}, {"J", 74}, {"K", 75}, {"L", 76}, {"M", 77},
  28 + {"N", 78}, {"O", 79}, {"P", 80}, {"Q", 81}, {"R", 82}, {"S", 83},
  29 + {"T", 84}, {"U", 85}, {"V", 86}, {"W", 87}, {"X", 88}, {"Y", 89},
  30 + {"Z", 90}, {"[", 91}, {"\\", 92}, {"]", 93}, {"^", 94}, {"_", 95},
  31 + {"`", 96}, {"a", 97}, {"b", 98}, {"c", 99}, {"d", 100}, {"e", 101},
  32 + {"f", 102}, {"g", 103}, {"h", 104}, {"i", 105}, {"j", 106}, {"k", 107},
  33 + {"l", 108}, {"m", 109}, {"n", 110}, {"o", 111}, {"p", 112}, {"q", 113},
  34 + {"r", 114}, {"s", 115}, {"t", 116}, {"u", 117}, {"v", 118}, {"w", 119},
  35 + {"x", 120}, {"y", 121}, {"z", 122}, {"{", 123}, {"|", 124}, {"}", 125},
  36 + {"~", 126}, {"Ġ", 127}, {"ġ", 128}, {"Ģ", 129}, {"ģ", 130}, {"Ĥ", 131},
  37 + {"ĥ", 132}, {"Ħ", 133}, {"ħ", 134}, {"Ĩ", 135}, {"ĩ", 136}, {"Ī", 137},
  38 + {"ī", 138}, {"Ĭ", 139}, {"ĭ", 140}, {"Į", 141}, {"į", 142}, {"İ", 143},
  39 + {"ı", 144}, {"Ĵ", 145}, {"ĵ", 146}, {"Ķ", 147}, {"ķ", 148}, {"ĸ", 149},
  40 + {"Ĺ", 150}, {"ĺ", 151}, {"Ļ", 152}, {"ļ", 153}, {"Ľ", 154}, {"ľ", 155},
  41 + {"Ł", 156}, {"ł", 157}, {"Ń", 158}, {"ń", 159}, {"Ņ", 160}, {"ņ", 161},
  42 + {"Ň", 162}, {"ň", 163}, {"Ŋ", 164}, {"ŋ", 165}, {"Ō", 166}, {"ō", 167},
  43 + {"Ŏ", 168}, {"ŏ", 169}, {"Ő", 170}, {"ő", 171}, {"Œ", 172}, {"œ", 173},
  44 + {"Ŕ", 174}, {"ŕ", 175}, {"Ŗ", 176}, {"ŗ", 177}, {"Ř", 178}, {"ř", 179},
  45 + {"Ś", 180}, {"ś", 181}, {"Ŝ", 182}, {"ŝ", 183}, {"Ş", 184}, {"ş", 185},
  46 + {"Š", 186}, {"š", 187}, {"Ţ", 188}, {"ţ", 189}, {"Ť", 190}, {"ť", 191},
  47 + {"Ŧ", 192}, {"ŧ", 193}, {"Ũ", 194}, {"ũ", 195}, {"Ū", 196}, {"ū", 197},
  48 + {"Ŭ", 198}, {"ŭ", 199}, {"Ů", 200}, {"ů", 201}, {"Ű", 202}, {"ű", 203},
  49 + {"Ų", 204}, {"ų", 205}, {"Ŵ", 206}, {"ŵ", 207}, {"Ŷ", 208}, {"ŷ", 209},
  50 + {"Ÿ", 210}, {"Ź", 211}, {"ź", 212}, {"Ż", 213}, {"ż", 214}, {"Ž", 215},
  51 + {"ž", 216}, {"ƀ", 217}, {"Ɓ", 218}, {"Ƃ", 219}, {"ƃ", 220}, {"Ƅ", 221},
  52 + {"ƅ", 222}, {"Ɔ", 223}, {"Ƈ", 224}, {"ƈ", 225}, {"Ɖ", 226}, {"Ɗ", 227},
  53 + {"Ƌ", 228}, {"ƌ", 229}, {"ƍ", 230}, {"Ǝ", 231}, {"Ə", 232}, {"Ɛ", 233},
  54 + {"Ƒ", 234}, {"ƒ", 235}, {"Ɠ", 236}, {"Ɣ", 237}, {"ƕ", 238}, {"Ɩ", 239},
  55 + {"Ɨ", 240}, {"Ƙ", 241}, {"ƙ", 242}, {"ƚ", 243}, {"ƛ", 244}, {"Ɯ", 245},
  56 + {"Ɲ", 246}, {"ƞ", 247}, {"Ɵ", 248}, {"Ơ", 249}, {"ơ", 250}, {"Ƣ", 251},
  57 + {"ƣ", 252}, {"Ƥ", 253}, {"ƥ", 254}, {"Ʀ", 255}, {"⁇", 32},
  58 + };
  59 +
  60 + return table;
  61 +}
  1 +// sherpa-onnx/csrc/bbpe.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_BBPE_H_
  6 +#define SHERPA_ONNX_CSRC_BBPE_H_
  7 +#include <cstdint>
  8 +#include <string>
  9 +#include <unordered_map>
  10 +
  11 +// It is equivalent to the map BCHAR_TO_BYTE
  12 +// from
  13 +// https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py#L280
  14 +const std::unordered_map<std::string, uint8_t> &GetByteBpeTable();
  15 +
  16 +#endif // SHERPA_ONNX_CSRC_BBPE_H_
@@ -41,7 +41,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, @@ -41,7 +41,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
41 text.append(sym); 41 text.append(sym);
42 42
43 if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { 43 if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
44 - // for byte bpe models 44 + // for bpe models with byte_fallback
45 // (but don't rewrite printable characters 0x20..0x7e, 45 // (but don't rewrite printable characters 0x20..0x7e,
46 // which collide with standard BPE units) 46 // which collide with standard BPE units)
47 std::ostringstream os; 47 std::ostringstream os;
@@ -52,6 +52,11 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, @@ -52,6 +52,11 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
52 52
53 r.tokens.push_back(std::move(sym)); 53 r.tokens.push_back(std::move(sym));
54 } 54 }
  55 +
  56 + if (sym_table.IsByteBpe()) {
  57 + text = sym_table.DecodeByteBpe(text);
  58 + }
  59 +
55 r.text = std::move(text); 60 r.text = std::move(text);
56 61
57 float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; 62 float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
@@ -43,7 +43,7 @@ static OfflineRecognitionResult Convert( @@ -43,7 +43,7 @@ static OfflineRecognitionResult Convert(
43 text.append(sym); 43 text.append(sym);
44 44
45 if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { 45 if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
46 - // for byte bpe models, 46 + // for bpe models with byte_fallback,
47 // (but don't rewrite printable characters 0x20..0x7e, 47 // (but don't rewrite printable characters 0x20..0x7e,
48 // which collide with standard BPE units) 48 // which collide with standard BPE units)
49 std::ostringstream os; 49 std::ostringstream os;
@@ -54,6 +54,10 @@ static OfflineRecognitionResult Convert( @@ -54,6 +54,10 @@ static OfflineRecognitionResult Convert(
54 54
55 r.tokens.push_back(std::move(sym)); 55 r.tokens.push_back(std::move(sym));
56 } 56 }
  57 + if (sym_table.IsByteBpe()) {
  58 + text = sym_table.DecodeByteBpe(text);
  59 + }
  60 +
57 r.text = std::move(text); 61 r.text = std::move(text);
58 62
59 float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; 63 float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
@@ -34,13 +34,14 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, @@ -34,13 +34,14 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
34 r.tokens.reserve(src.tokens.size()); 34 r.tokens.reserve(src.tokens.size());
35 r.timestamps.reserve(src.tokens.size()); 35 r.timestamps.reserve(src.tokens.size());
36 36
  37 + std::string text;
37 for (auto i : src.tokens) { 38 for (auto i : src.tokens) {
38 auto sym = sym_table[i]; 39 auto sym = sym_table[i];
39 40
40 - r.text.append(sym); 41 + text.append(sym);
41 42
42 if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { 43 if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
43 - // for byte bpe models 44 + // for bpe models with byte_fallback
44 // (but don't rewrite printable characters 0x20..0x7e, 45 // (but don't rewrite printable characters 0x20..0x7e,
45 // which collide with standard BPE units) 46 // which collide with standard BPE units)
46 std::ostringstream os; 47 std::ostringstream os;
@@ -52,6 +53,12 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, @@ -52,6 +53,12 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
52 r.tokens.push_back(std::move(sym)); 53 r.tokens.push_back(std::move(sym));
53 } 54 }
54 55
  56 + if (sym_table.IsByteBpe()) {
  57 + text = sym_table.DecodeByteBpe(text);
  58 + }
  59 +
  60 + r.text = std::move(text);
  61 +
55 float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; 62 float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
56 for (auto t : src.timestamps) { 63 for (auto t : src.timestamps) {
57 float time = frame_shift_s * t; 64 float time = frame_shift_s * t;
@@ -38,13 +38,14 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, @@ -38,13 +38,14 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
38 r.tokens.reserve(src.tokens.size()); 38 r.tokens.reserve(src.tokens.size());
39 r.timestamps.reserve(src.tokens.size()); 39 r.timestamps.reserve(src.tokens.size());
40 40
  41 + std::string text;
41 for (auto i : src.tokens) { 42 for (auto i : src.tokens) {
42 auto sym = sym_table[i]; 43 auto sym = sym_table[i];
43 44
44 - r.text.append(sym); 45 + text.append(sym);
45 46
46 if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { 47 if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
47 - // for byte bpe models 48 + // for bpe models with byte_fallback
48 // (but don't rewrite printable characters 0x20..0x7e, 49 // (but don't rewrite printable characters 0x20..0x7e,
49 // which collide with standard BPE units) 50 // which collide with standard BPE units)
50 std::ostringstream os; 51 std::ostringstream os;
@@ -56,6 +57,12 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, @@ -56,6 +57,12 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
56 r.tokens.push_back(std::move(sym)); 57 r.tokens.push_back(std::move(sym));
57 } 58 }
58 59
  60 + if (sym_table.IsByteBpe()) {
  61 + text = sym_table.DecodeByteBpe(text);
  62 + }
  63 +
  64 + r.text = std::move(text);
  65 +
59 float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; 66 float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
60 for (auto t : src.timestamps) { 67 for (auto t : src.timestamps) {
61 float time = frame_shift_s * t; 68 float time = frame_shift_s * t;
@@ -5,6 +5,7 @@ @@ -5,6 +5,7 @@
5 #include "sherpa-onnx/csrc/symbol-table.h" 5 #include "sherpa-onnx/csrc/symbol-table.h"
6 6
7 #include <cassert> 7 #include <cassert>
  8 +#include <cctype>
8 #include <fstream> 9 #include <fstream>
9 #include <sstream> 10 #include <sstream>
10 #include <string> 11 #include <string>
@@ -22,8 +23,10 @@ @@ -22,8 +23,10 @@
22 #endif 23 #endif
23 24
24 #include "sherpa-onnx/csrc/base64-decode.h" 25 #include "sherpa-onnx/csrc/base64-decode.h"
  26 +#include "sherpa-onnx/csrc/bbpe.h"
25 #include "sherpa-onnx/csrc/lexicon.h" 27 #include "sherpa-onnx/csrc/lexicon.h"
26 #include "sherpa-onnx/csrc/onnx-utils.h" 28 #include "sherpa-onnx/csrc/onnx-utils.h"
  29 +#include "sherpa-onnx/csrc/text-utils.h"
27 30
28 namespace sherpa_onnx { 31 namespace sherpa_onnx {
29 32
@@ -47,6 +50,59 @@ inline void Trim(std::string *s, const char *t = ws) { @@ -47,6 +50,59 @@ inline void Trim(std::string *s, const char *t = ws) {
47 TrimRight(s, t); 50 TrimRight(s, t);
48 TrimLeft(s, t); 51 TrimLeft(s, t);
49 } 52 }
  53 +
  54 +bool IsByteBPE(const char *s, int32_t n) {
  55 + const uint8_t *p = reinterpret_cast<const uint8_t *>(s);
  56 + if (n >= 3 && p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
  57 + return IsByteBPE(s + 3, n - 3);
  58 + }
  59 +
  60 + for (int32_t i = 0; i != n; ++i) {
  61 + if (p[i] > 0xc6) {
  62 + return false;
  63 + }
  64 + }
  65 +
  66 + return true;
  67 +}
  68 +
  69 +bool IsByteBPE(const std::unordered_map<std::string, int32_t> &sym2id) {
  70 + uint8_t max_v = 0;
  71 + for (const auto &p : sym2id) {
  72 + const auto &s = p.first;
  73 + if (!IsByteBPE(s.c_str(), s.size())) {
  74 + return false;
  75 + }
  76 +
  77 + uint8_t m = 0;
  78 + if (s.size() >= 3) {
  79 + const uint8_t *p = reinterpret_cast<const uint8_t *>(s.c_str());
  80 +
  81 + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
  82 + if (s.size() > 3) {
  83 + m = *std::max_element(
  84 + reinterpret_cast<const uint8_t *>(s.data()) + 3,
  85 + reinterpret_cast<const uint8_t *>(s.data()) + s.size());
  86 + } else {
  87 + m = 0;
  88 + }
  89 + } else {
  90 + m = *std::max_element(
  91 + reinterpret_cast<const uint8_t *>(s.data()),
  92 + reinterpret_cast<const uint8_t *>(s.data()) + s.size());
  93 + }
  94 + } else {
  95 + m = *std::max_element(
  96 + reinterpret_cast<const uint8_t *>(s.data()),
  97 + reinterpret_cast<const uint8_t *>(s.data()) + s.size());
  98 + }
  99 +
  100 + max_v = (m > max_v) ? m : max_v;
  101 + }
  102 +
  103 + return static_cast<uint8_t>(max_v) == 0xc6;
  104 +}
  105 +
50 } // namespace 106 } // namespace
51 107
52 std::unordered_map<std::string, int32_t> ReadTokens( 108 std::unordered_map<std::string, int32_t> ReadTokens(
@@ -111,7 +167,10 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) { @@ -111,7 +167,10 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) {
111 Init(is); 167 Init(is);
112 } 168 }
113 169
114 -void SymbolTable::Init(std::istream &is) { sym2id_ = ReadTokens(is, &id2sym_); } 170 +void SymbolTable::Init(std::istream &is) {
  171 + sym2id_ = ReadTokens(is, &id2sym_);
  172 + is_bbpe_ = IsByteBPE(sym2id_);
  173 +}
115 174
116 std::string SymbolTable::ToString() const { 175 std::string SymbolTable::ToString() const {
117 std::ostringstream os; 176 std::ostringstream os;
@@ -124,7 +183,7 @@ std::string SymbolTable::ToString() const { @@ -124,7 +183,7 @@ std::string SymbolTable::ToString() const {
124 183
125 const std::string SymbolTable::operator[](int32_t id) const { 184 const std::string SymbolTable::operator[](int32_t id) const {
126 std::string sym = id2sym_.at(id); 185 std::string sym = id2sym_.at(id);
127 - if (sym.size() >= 3) { 186 + if (sym.size() >= 3 && !is_bbpe_) {
128 // For BPE-based models, we replace ▁ with a space 187 // For BPE-based models, we replace ▁ with a space
129 // Unicode 9601, hex 0x2581, utf8 0xe29681 188 // Unicode 9601, hex 0x2581, utf8 0xe29681
130 const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str()); 189 const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
@@ -133,7 +192,7 @@ const std::string SymbolTable::operator[](int32_t id) const { @@ -133,7 +192,7 @@ const std::string SymbolTable::operator[](int32_t id) const {
133 } 192 }
134 } 193 }
135 194
136 - // for byte-level BPE 195 + // for BPE with byte_fallback
137 // id 0 is blank, id 1 is sos/eos, id 2 is unk 196 // id 0 is blank, id 1 is sos/eos, id 2 is unk
138 // 197 //
139 // Note: For moonshine models, 0 is <unk>, 1, is <s>, 2 is</s> 198 // Note: For moonshine models, 0 is <unk>, 1, is <s>, 2 is</s>
@@ -172,6 +231,33 @@ void SymbolTable::ApplyBase64Decode() { @@ -172,6 +231,33 @@ void SymbolTable::ApplyBase64Decode() {
172 } 231 }
173 } 232 }
174 233
  234 +std::string SymbolTable::DecodeByteBpe(const std::string &text) const {
  235 + if (!is_bbpe_) {
  236 + return text;
  237 + }
  238 + auto v = SplitUtf8(text);
  239 +
  240 + const auto &bbpe_table = GetByteBpeTable();
  241 + std::string ans;
  242 + for (const auto &s : v) {
  243 + if (s == "▁") {
  244 + if (!ans.empty() && ans.back() != ' ' && std::isprint(ans.back())) {
  245 + ans.push_back(' ');
  246 + }
  247 + } else if (bbpe_table.count(s)) {
  248 + ans.push_back(bbpe_table.at(s));
  249 + } else if (std::isprint(s[0])) {
  250 + ans.append(s);
  251 + } else {
  252 + // Should not happen
  253 + SHERPA_ONNX_LOGE("Skip OOV: %s from %s", s.c_str(), text.c_str());
  254 + }
  255 + }
  256 +
  257 + // TODO(fangjun): Filter invalid utf-8 sequences
  258 + return ans;
  259 +}
  260 +
175 #if __ANDROID_API__ >= 9 261 #if __ANDROID_API__ >= 9
176 template SymbolTable::SymbolTable(AAssetManager *mgr, 262 template SymbolTable::SymbolTable(AAssetManager *mgr,
177 const std::string &filename); 263 const std::string &filename);
@@ -56,12 +56,17 @@ class SymbolTable { @@ -56,12 +56,17 @@ class SymbolTable {
56 56
57 int32_t NumSymbols() const { return id2sym_.size(); } 57 int32_t NumSymbols() const { return id2sym_.size(); }
58 58
  59 + std::string DecodeByteBpe(const std::string &text) const;
  60 +
  61 + bool IsByteBpe() const { return is_bbpe_; }
  62 +
59 private: 63 private:
60 void Init(std::istream &is); 64 void Init(std::istream &is);
61 65
62 private: 66 private:
63 std::unordered_map<std::string, int32_t> sym2id_; 67 std::unordered_map<std::string, int32_t> sym2id_;
64 std::unordered_map<int32_t, std::string> id2sym_; 68 std::unordered_map<int32_t, std::string> id2sym_;
  69 + bool is_bbpe_ = false;
65 }; 70 };
66 71
67 std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table); 72 std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table);