Committed by
GitHub
Support decoding with byte-level BPE (bbpe) models. (#1633)
正在显示
11 个修改的文件
包含
270 行增加
和
10 行删除
scripts/bbpe/.gitignore
0 → 100644
| 1 | +bbpe.cc |
scripts/bbpe/generate_bbpe_table.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | +# | ||
| 4 | +# See https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_bpe.py#L28 | ||
| 5 | +# and | ||
| 6 | +# https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py | ||
| 7 | +# | ||
| 8 | +# Caution: The PRINTABLE_LATIN from fairseq is different from PRINTABLE_BASE_CHARS from icefall | ||
| 9 | + | ||
| 10 | +import re | ||
| 11 | + | ||
| 12 | +BPE_UNK = chr(8263) | ||
| 13 | +PRINTABLE_BASE_CHARS = ( | ||
| 14 | + list(range(256, 287 + 1)) | ||
| 15 | + + list(range(32, 126 + 1)) | ||
| 16 | + + list(range(288, 305 + 1)) | ||
| 17 | + + list(range(308, 318 + 1)) | ||
| 18 | + + list(range(321, 328 + 1)) | ||
| 19 | + + list(range(330, 382 + 1)) | ||
| 20 | + + list(range(384, 422 + 1)) | ||
| 21 | +) | ||
| 22 | + | ||
| 23 | + | ||
| 24 | +BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)} | ||
| 25 | +BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()} | ||
| 26 | +BCHAR_TO_BYTE[BPE_UNK] = 32 # map unk to space | ||
| 27 | + | ||
| 28 | + | ||
| 29 | +def main(): | ||
| 30 | + s = "" | ||
| 31 | + s += "// sherpa-onnx/csrc/bbpe.cc\n" | ||
| 32 | + s += "//\n" | ||
| 33 | + s += "// Copyright (c) 2024 Xiaomi Corporation\n" | ||
| 34 | + s += "\n" | ||
| 35 | + s += "// Auto-generated! DO NOT EDIT\n" | ||
| 36 | + s += "\n" | ||
| 37 | + s += '#include "sherpa-onnx/csrc/bbpe.h"\n' | ||
| 38 | + s += "\n" | ||
| 39 | + s += "#include <cstdint>\n" | ||
| 40 | + s += "#include <string>\n" | ||
| 41 | + s += "#include <unordered_map>\n" | ||
| 42 | + s += "\n" | ||
| 43 | + s += "const std::unordered_map<std::string, uint8_t> &GetByteBpeTable() {\n" | ||
| 44 | + s += " static const std::unordered_map<std::string, uint8_t> table = {\n" | ||
| 45 | + | ||
| 46 | + s += " " | ||
| 47 | + for i, (k, v) in enumerate(BCHAR_TO_BYTE.items()): | ||
| 48 | + s += "{" | ||
| 49 | + if k in ["\\", '"']: | ||
| 50 | + s += f'"\{k}", {v}' | ||
| 51 | + else: | ||
| 52 | + s += f'"{k}", {v}' | ||
| 53 | + s += "}, " | ||
| 54 | + if i > 0 and i % 7 == 0: | ||
| 55 | + s += "\n" | ||
| 56 | + s += " " | ||
| 57 | + s += "};\n" | ||
| 58 | + s += "\n" | ||
| 59 | + s += " return table\n;" | ||
| 60 | + s += "}\n" | ||
| 61 | + | ||
| 62 | + with open("bbpe.cc", "w", encoding="utf-8") as f: | ||
| 63 | + f.write(s) | ||
| 64 | + | ||
| 65 | + | ||
| 66 | +if __name__ == "__main__": | ||
| 67 | + main() |
| @@ -12,6 +12,7 @@ endif() | @@ -12,6 +12,7 @@ endif() | ||
| 12 | 12 | ||
| 13 | set(sources | 13 | set(sources |
| 14 | base64-decode.cc | 14 | base64-decode.cc |
| 15 | + bbpe.cc | ||
| 15 | cat.cc | 16 | cat.cc |
| 16 | circular-buffer.cc | 17 | circular-buffer.cc |
| 17 | context-graph.cc | 18 | context-graph.cc |
| @@ -78,11 +79,11 @@ set(sources | @@ -78,11 +79,11 @@ set(sources | ||
| 78 | online-stream.cc | 79 | online-stream.cc |
| 79 | online-transducer-decoder.cc | 80 | online-transducer-decoder.cc |
| 80 | online-transducer-greedy-search-decoder.cc | 81 | online-transducer-greedy-search-decoder.cc |
| 82 | + online-transducer-greedy-search-nemo-decoder.cc | ||
| 81 | online-transducer-model-config.cc | 83 | online-transducer-model-config.cc |
| 82 | online-transducer-model.cc | 84 | online-transducer-model.cc |
| 83 | online-transducer-modified-beam-search-decoder.cc | 85 | online-transducer-modified-beam-search-decoder.cc |
| 84 | online-transducer-nemo-model.cc | 86 | online-transducer-nemo-model.cc |
| 85 | - online-transducer-greedy-search-nemo-decoder.cc | ||
| 86 | online-wenet-ctc-model-config.cc | 87 | online-wenet-ctc-model-config.cc |
| 87 | online-wenet-ctc-model.cc | 88 | online-wenet-ctc-model.cc |
| 88 | online-zipformer-transducer-model.cc | 89 | online-zipformer-transducer-model.cc |
sherpa-onnx/csrc/bbpe.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/bbpe.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +// Auto-generated! DO NOT EDIT | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/bbpe.h" | ||
| 8 | + | ||
| 9 | +#include <cstdint> | ||
| 10 | +#include <string> | ||
| 11 | +#include <unordered_map> | ||
| 12 | + | ||
| 13 | +const std::unordered_map<std::string, uint8_t> &GetByteBpeTable() { | ||
| 14 | + static const std::unordered_map<std::string, uint8_t> table = { | ||
| 15 | + {"Ā", 0}, {"ā", 1}, {"Ă", 2}, {"ă", 3}, {"Ą", 4}, {"ą", 5}, | ||
| 16 | + {"Ć", 6}, {"ć", 7}, {"Ĉ", 8}, {"ĉ", 9}, {"Ċ", 10}, {"ċ", 11}, | ||
| 17 | + {"Č", 12}, {"č", 13}, {"Ď", 14}, {"ď", 15}, {"Đ", 16}, {"đ", 17}, | ||
| 18 | + {"Ē", 18}, {"ē", 19}, {"Ĕ", 20}, {"ĕ", 21}, {"Ė", 22}, {"ė", 23}, | ||
| 19 | + {"Ę", 24}, {"ę", 25}, {"Ě", 26}, {"ě", 27}, {"Ĝ", 28}, {"ĝ", 29}, | ||
| 20 | + {"Ğ", 30}, {"ğ", 31}, {" ", 32}, {"!", 33}, {"\"", 34}, {"#", 35}, | ||
| 21 | + {"$", 36}, {"%", 37}, {"&", 38}, {"'", 39}, {"(", 40}, {")", 41}, | ||
| 22 | + {"*", 42}, {"+", 43}, {",", 44}, {"-", 45}, {".", 46}, {"/", 47}, | ||
| 23 | + {"0", 48}, {"1", 49}, {"2", 50}, {"3", 51}, {"4", 52}, {"5", 53}, | ||
| 24 | + {"6", 54}, {"7", 55}, {"8", 56}, {"9", 57}, {":", 58}, {";", 59}, | ||
| 25 | + {"<", 60}, {"=", 61}, {">", 62}, {"?", 63}, {"@", 64}, {"A", 65}, | ||
| 26 | + {"B", 66}, {"C", 67}, {"D", 68}, {"E", 69}, {"F", 70}, {"G", 71}, | ||
| 27 | + {"H", 72}, {"I", 73}, {"J", 74}, {"K", 75}, {"L", 76}, {"M", 77}, | ||
| 28 | + {"N", 78}, {"O", 79}, {"P", 80}, {"Q", 81}, {"R", 82}, {"S", 83}, | ||
| 29 | + {"T", 84}, {"U", 85}, {"V", 86}, {"W", 87}, {"X", 88}, {"Y", 89}, | ||
| 30 | + {"Z", 90}, {"[", 91}, {"\\", 92}, {"]", 93}, {"^", 94}, {"_", 95}, | ||
| 31 | + {"`", 96}, {"a", 97}, {"b", 98}, {"c", 99}, {"d", 100}, {"e", 101}, | ||
| 32 | + {"f", 102}, {"g", 103}, {"h", 104}, {"i", 105}, {"j", 106}, {"k", 107}, | ||
| 33 | + {"l", 108}, {"m", 109}, {"n", 110}, {"o", 111}, {"p", 112}, {"q", 113}, | ||
| 34 | + {"r", 114}, {"s", 115}, {"t", 116}, {"u", 117}, {"v", 118}, {"w", 119}, | ||
| 35 | + {"x", 120}, {"y", 121}, {"z", 122}, {"{", 123}, {"|", 124}, {"}", 125}, | ||
| 36 | + {"~", 126}, {"Ġ", 127}, {"ġ", 128}, {"Ģ", 129}, {"ģ", 130}, {"Ĥ", 131}, | ||
| 37 | + {"ĥ", 132}, {"Ħ", 133}, {"ħ", 134}, {"Ĩ", 135}, {"ĩ", 136}, {"Ī", 137}, | ||
| 38 | + {"ī", 138}, {"Ĭ", 139}, {"ĭ", 140}, {"Į", 141}, {"į", 142}, {"İ", 143}, | ||
| 39 | + {"ı", 144}, {"Ĵ", 145}, {"ĵ", 146}, {"Ķ", 147}, {"ķ", 148}, {"ĸ", 149}, | ||
| 40 | + {"Ĺ", 150}, {"ĺ", 151}, {"Ļ", 152}, {"ļ", 153}, {"Ľ", 154}, {"ľ", 155}, | ||
| 41 | + {"Ł", 156}, {"ł", 157}, {"Ń", 158}, {"ń", 159}, {"Ņ", 160}, {"ņ", 161}, | ||
| 42 | + {"Ň", 162}, {"ň", 163}, {"Ŋ", 164}, {"ŋ", 165}, {"Ō", 166}, {"ō", 167}, | ||
| 43 | + {"Ŏ", 168}, {"ŏ", 169}, {"Ő", 170}, {"ő", 171}, {"Œ", 172}, {"œ", 173}, | ||
| 44 | + {"Ŕ", 174}, {"ŕ", 175}, {"Ŗ", 176}, {"ŗ", 177}, {"Ř", 178}, {"ř", 179}, | ||
| 45 | + {"Ś", 180}, {"ś", 181}, {"Ŝ", 182}, {"ŝ", 183}, {"Ş", 184}, {"ş", 185}, | ||
| 46 | + {"Š", 186}, {"š", 187}, {"Ţ", 188}, {"ţ", 189}, {"Ť", 190}, {"ť", 191}, | ||
| 47 | + {"Ŧ", 192}, {"ŧ", 193}, {"Ũ", 194}, {"ũ", 195}, {"Ū", 196}, {"ū", 197}, | ||
| 48 | + {"Ŭ", 198}, {"ŭ", 199}, {"Ů", 200}, {"ů", 201}, {"Ű", 202}, {"ű", 203}, | ||
| 49 | + {"Ų", 204}, {"ų", 205}, {"Ŵ", 206}, {"ŵ", 207}, {"Ŷ", 208}, {"ŷ", 209}, | ||
| 50 | + {"Ÿ", 210}, {"Ź", 211}, {"ź", 212}, {"Ż", 213}, {"ż", 214}, {"Ž", 215}, | ||
| 51 | + {"ž", 216}, {"ƀ", 217}, {"Ɓ", 218}, {"Ƃ", 219}, {"ƃ", 220}, {"Ƅ", 221}, | ||
| 52 | + {"ƅ", 222}, {"Ɔ", 223}, {"Ƈ", 224}, {"ƈ", 225}, {"Ɖ", 226}, {"Ɗ", 227}, | ||
| 53 | + {"Ƌ", 228}, {"ƌ", 229}, {"ƍ", 230}, {"Ǝ", 231}, {"Ə", 232}, {"Ɛ", 233}, | ||
| 54 | + {"Ƒ", 234}, {"ƒ", 235}, {"Ɠ", 236}, {"Ɣ", 237}, {"ƕ", 238}, {"Ɩ", 239}, | ||
| 55 | + {"Ɨ", 240}, {"Ƙ", 241}, {"ƙ", 242}, {"ƚ", 243}, {"ƛ", 244}, {"Ɯ", 245}, | ||
| 56 | + {"Ɲ", 246}, {"ƞ", 247}, {"Ɵ", 248}, {"Ơ", 249}, {"ơ", 250}, {"Ƣ", 251}, | ||
| 57 | + {"ƣ", 252}, {"Ƥ", 253}, {"ƥ", 254}, {"Ʀ", 255}, {"⁇", 32}, | ||
| 58 | + }; | ||
| 59 | + | ||
| 60 | + return table; | ||
| 61 | +} |
sherpa-onnx/csrc/bbpe.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/bbpe.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_BBPE_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_BBPE_H_ | ||
| 7 | +#include <cstdint> | ||
| 8 | +#include <string> | ||
| 9 | +#include <unordered_map> | ||
| 10 | + | ||
| 11 | +// It is equivalent to the map BCHAR_TO_BYTE | ||
| 12 | +// from | ||
| 13 | +// https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py#L280 | ||
| 14 | +const std::unordered_map<std::string, uint8_t> &GetByteBpeTable(); | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_CSRC_BBPE_H_ |
| @@ -41,7 +41,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | @@ -41,7 +41,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | ||
| 41 | text.append(sym); | 41 | text.append(sym); |
| 42 | 42 | ||
| 43 | if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { | 43 | if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { |
| 44 | - // for byte bpe models | 44 | + // for bpe models with byte_fallback |
| 45 | // (but don't rewrite printable characters 0x20..0x7e, | 45 | // (but don't rewrite printable characters 0x20..0x7e, |
| 46 | // which collide with standard BPE units) | 46 | // which collide with standard BPE units) |
| 47 | std::ostringstream os; | 47 | std::ostringstream os; |
| @@ -52,6 +52,11 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | @@ -52,6 +52,11 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | ||
| 52 | 52 | ||
| 53 | r.tokens.push_back(std::move(sym)); | 53 | r.tokens.push_back(std::move(sym)); |
| 54 | } | 54 | } |
| 55 | + | ||
| 56 | + if (sym_table.IsByteBpe()) { | ||
| 57 | + text = sym_table.DecodeByteBpe(text); | ||
| 58 | + } | ||
| 59 | + | ||
| 55 | r.text = std::move(text); | 60 | r.text = std::move(text); |
| 56 | 61 | ||
| 57 | float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; | 62 | float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; |
| @@ -43,7 +43,7 @@ static OfflineRecognitionResult Convert( | @@ -43,7 +43,7 @@ static OfflineRecognitionResult Convert( | ||
| 43 | text.append(sym); | 43 | text.append(sym); |
| 44 | 44 | ||
| 45 | if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { | 45 | if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { |
| 46 | - // for byte bpe models, | 46 | + // for bpe models with byte_fallback, |
| 47 | // (but don't rewrite printable characters 0x20..0x7e, | 47 | // (but don't rewrite printable characters 0x20..0x7e, |
| 48 | // which collide with standard BPE units) | 48 | // which collide with standard BPE units) |
| 49 | std::ostringstream os; | 49 | std::ostringstream os; |
| @@ -54,6 +54,10 @@ static OfflineRecognitionResult Convert( | @@ -54,6 +54,10 @@ static OfflineRecognitionResult Convert( | ||
| 54 | 54 | ||
| 55 | r.tokens.push_back(std::move(sym)); | 55 | r.tokens.push_back(std::move(sym)); |
| 56 | } | 56 | } |
| 57 | + if (sym_table.IsByteBpe()) { | ||
| 58 | + text = sym_table.DecodeByteBpe(text); | ||
| 59 | + } | ||
| 60 | + | ||
| 57 | r.text = std::move(text); | 61 | r.text = std::move(text); |
| 58 | 62 | ||
| 59 | float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; | 63 | float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; |
| @@ -34,13 +34,14 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, | @@ -34,13 +34,14 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, | ||
| 34 | r.tokens.reserve(src.tokens.size()); | 34 | r.tokens.reserve(src.tokens.size()); |
| 35 | r.timestamps.reserve(src.tokens.size()); | 35 | r.timestamps.reserve(src.tokens.size()); |
| 36 | 36 | ||
| 37 | + std::string text; | ||
| 37 | for (auto i : src.tokens) { | 38 | for (auto i : src.tokens) { |
| 38 | auto sym = sym_table[i]; | 39 | auto sym = sym_table[i]; |
| 39 | 40 | ||
| 40 | - r.text.append(sym); | 41 | + text.append(sym); |
| 41 | 42 | ||
| 42 | if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { | 43 | if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { |
| 43 | - // for byte bpe models | 44 | + // for bpe models with byte_fallback |
| 44 | // (but don't rewrite printable characters 0x20..0x7e, | 45 | // (but don't rewrite printable characters 0x20..0x7e, |
| 45 | // which collide with standard BPE units) | 46 | // which collide with standard BPE units) |
| 46 | std::ostringstream os; | 47 | std::ostringstream os; |
| @@ -52,6 +53,12 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, | @@ -52,6 +53,12 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, | ||
| 52 | r.tokens.push_back(std::move(sym)); | 53 | r.tokens.push_back(std::move(sym)); |
| 53 | } | 54 | } |
| 54 | 55 | ||
| 56 | + if (sym_table.IsByteBpe()) { | ||
| 57 | + text = sym_table.DecodeByteBpe(text); | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + r.text = std::move(text); | ||
| 61 | + | ||
| 55 | float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; | 62 | float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; |
| 56 | for (auto t : src.timestamps) { | 63 | for (auto t : src.timestamps) { |
| 57 | float time = frame_shift_s * t; | 64 | float time = frame_shift_s * t; |
| @@ -38,13 +38,14 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | @@ -38,13 +38,14 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | ||
| 38 | r.tokens.reserve(src.tokens.size()); | 38 | r.tokens.reserve(src.tokens.size()); |
| 39 | r.timestamps.reserve(src.tokens.size()); | 39 | r.timestamps.reserve(src.tokens.size()); |
| 40 | 40 | ||
| 41 | + std::string text; | ||
| 41 | for (auto i : src.tokens) { | 42 | for (auto i : src.tokens) { |
| 42 | auto sym = sym_table[i]; | 43 | auto sym = sym_table[i]; |
| 43 | 44 | ||
| 44 | - r.text.append(sym); | 45 | + text.append(sym); |
| 45 | 46 | ||
| 46 | if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { | 47 | if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { |
| 47 | - // for byte bpe models | 48 | + // for bpe models with byte_fallback |
| 48 | // (but don't rewrite printable characters 0x20..0x7e, | 49 | // (but don't rewrite printable characters 0x20..0x7e, |
| 49 | // which collide with standard BPE units) | 50 | // which collide with standard BPE units) |
| 50 | std::ostringstream os; | 51 | std::ostringstream os; |
| @@ -56,6 +57,12 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | @@ -56,6 +57,12 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | ||
| 56 | r.tokens.push_back(std::move(sym)); | 57 | r.tokens.push_back(std::move(sym)); |
| 57 | } | 58 | } |
| 58 | 59 | ||
| 60 | + if (sym_table.IsByteBpe()) { | ||
| 61 | + text = sym_table.DecodeByteBpe(text); | ||
| 62 | + } | ||
| 63 | + | ||
| 64 | + r.text = std::move(text); | ||
| 65 | + | ||
| 59 | float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; | 66 | float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; |
| 60 | for (auto t : src.timestamps) { | 67 | for (auto t : src.timestamps) { |
| 61 | float time = frame_shift_s * t; | 68 | float time = frame_shift_s * t; |
| @@ -5,6 +5,7 @@ | @@ -5,6 +5,7 @@ | ||
| 5 | #include "sherpa-onnx/csrc/symbol-table.h" | 5 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 6 | 6 | ||
| 7 | #include <cassert> | 7 | #include <cassert> |
| 8 | +#include <cctype> | ||
| 8 | #include <fstream> | 9 | #include <fstream> |
| 9 | #include <sstream> | 10 | #include <sstream> |
| 10 | #include <string> | 11 | #include <string> |
| @@ -22,8 +23,10 @@ | @@ -22,8 +23,10 @@ | ||
| 22 | #endif | 23 | #endif |
| 23 | 24 | ||
| 24 | #include "sherpa-onnx/csrc/base64-decode.h" | 25 | #include "sherpa-onnx/csrc/base64-decode.h" |
| 26 | +#include "sherpa-onnx/csrc/bbpe.h" | ||
| 25 | #include "sherpa-onnx/csrc/lexicon.h" | 27 | #include "sherpa-onnx/csrc/lexicon.h" |
| 26 | #include "sherpa-onnx/csrc/onnx-utils.h" | 28 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 29 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 27 | 30 | ||
| 28 | namespace sherpa_onnx { | 31 | namespace sherpa_onnx { |
| 29 | 32 | ||
| @@ -47,6 +50,59 @@ inline void Trim(std::string *s, const char *t = ws) { | @@ -47,6 +50,59 @@ inline void Trim(std::string *s, const char *t = ws) { | ||
| 47 | TrimRight(s, t); | 50 | TrimRight(s, t); |
| 48 | TrimLeft(s, t); | 51 | TrimLeft(s, t); |
| 49 | } | 52 | } |
| 53 | + | ||
| 54 | +bool IsByteBPE(const char *s, int32_t n) { | ||
| 55 | + const uint8_t *p = reinterpret_cast<const uint8_t *>(s); | ||
| 56 | + if (n >= 3 && p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { | ||
| 57 | + return IsByteBPE(s + 3, n - 3); | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + for (int32_t i = 0; i != n; ++i) { | ||
| 61 | + if (p[i] > 0xc6) { | ||
| 62 | + return false; | ||
| 63 | + } | ||
| 64 | + } | ||
| 65 | + | ||
| 66 | + return true; | ||
| 67 | +} | ||
| 68 | + | ||
| 69 | +bool IsByteBPE(const std::unordered_map<std::string, int32_t> &sym2id) { | ||
| 70 | + uint8_t max_v = 0; | ||
| 71 | + for (const auto &p : sym2id) { | ||
| 72 | + const auto &s = p.first; | ||
| 73 | + if (!IsByteBPE(s.c_str(), s.size())) { | ||
| 74 | + return false; | ||
| 75 | + } | ||
| 76 | + | ||
| 77 | + uint8_t m = 0; | ||
| 78 | + if (s.size() >= 3) { | ||
| 79 | + const uint8_t *p = reinterpret_cast<const uint8_t *>(s.c_str()); | ||
| 80 | + | ||
| 81 | + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { | ||
| 82 | + if (s.size() > 3) { | ||
| 83 | + m = *std::max_element( | ||
| 84 | + reinterpret_cast<const uint8_t *>(s.data()) + 3, | ||
| 85 | + reinterpret_cast<const uint8_t *>(s.data()) + s.size()); | ||
| 86 | + } else { | ||
| 87 | + m = 0; | ||
| 88 | + } | ||
| 89 | + } else { | ||
| 90 | + m = *std::max_element( | ||
| 91 | + reinterpret_cast<const uint8_t *>(s.data()), | ||
| 92 | + reinterpret_cast<const uint8_t *>(s.data()) + s.size()); | ||
| 93 | + } | ||
| 94 | + } else { | ||
| 95 | + m = *std::max_element( | ||
| 96 | + reinterpret_cast<const uint8_t *>(s.data()), | ||
| 97 | + reinterpret_cast<const uint8_t *>(s.data()) + s.size()); | ||
| 98 | + } | ||
| 99 | + | ||
| 100 | + max_v = (m > max_v) ? m : max_v; | ||
| 101 | + } | ||
| 102 | + | ||
| 103 | + return static_cast<uint8_t>(max_v) == 0xc6; | ||
| 104 | +} | ||
| 105 | + | ||
| 50 | } // namespace | 106 | } // namespace |
| 51 | 107 | ||
| 52 | std::unordered_map<std::string, int32_t> ReadTokens( | 108 | std::unordered_map<std::string, int32_t> ReadTokens( |
| @@ -111,7 +167,10 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) { | @@ -111,7 +167,10 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) { | ||
| 111 | Init(is); | 167 | Init(is); |
| 112 | } | 168 | } |
| 113 | 169 | ||
| 114 | -void SymbolTable::Init(std::istream &is) { sym2id_ = ReadTokens(is, &id2sym_); } | 170 | +void SymbolTable::Init(std::istream &is) { |
| 171 | + sym2id_ = ReadTokens(is, &id2sym_); | ||
| 172 | + is_bbpe_ = IsByteBPE(sym2id_); | ||
| 173 | +} | ||
| 115 | 174 | ||
| 116 | std::string SymbolTable::ToString() const { | 175 | std::string SymbolTable::ToString() const { |
| 117 | std::ostringstream os; | 176 | std::ostringstream os; |
| @@ -124,7 +183,7 @@ std::string SymbolTable::ToString() const { | @@ -124,7 +183,7 @@ std::string SymbolTable::ToString() const { | ||
| 124 | 183 | ||
| 125 | const std::string SymbolTable::operator[](int32_t id) const { | 184 | const std::string SymbolTable::operator[](int32_t id) const { |
| 126 | std::string sym = id2sym_.at(id); | 185 | std::string sym = id2sym_.at(id); |
| 127 | - if (sym.size() >= 3) { | 186 | + if (sym.size() >= 3 && !is_bbpe_) { |
| 128 | // For BPE-based models, we replace ▁ with a space | 187 | // For BPE-based models, we replace ▁ with a space |
| 129 | // Unicode 9601, hex 0x2581, utf8 0xe29681 | 188 | // Unicode 9601, hex 0x2581, utf8 0xe29681 |
| 130 | const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str()); | 189 | const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str()); |
| @@ -133,7 +192,7 @@ const std::string SymbolTable::operator[](int32_t id) const { | @@ -133,7 +192,7 @@ const std::string SymbolTable::operator[](int32_t id) const { | ||
| 133 | } | 192 | } |
| 134 | } | 193 | } |
| 135 | 194 | ||
| 136 | - // for byte-level BPE | 195 | + // for BPE with byte_fallback |
| 137 | // id 0 is blank, id 1 is sos/eos, id 2 is unk | 196 | // id 0 is blank, id 1 is sos/eos, id 2 is unk |
| 138 | // | 197 | // |
| 139 | // Note: For moonshine models, 0 is <unk>, 1, is <s>, 2 is</s> | 198 | // Note: For moonshine models, 0 is <unk>, 1, is <s>, 2 is</s> |
| @@ -172,6 +231,33 @@ void SymbolTable::ApplyBase64Decode() { | @@ -172,6 +231,33 @@ void SymbolTable::ApplyBase64Decode() { | ||
| 172 | } | 231 | } |
| 173 | } | 232 | } |
| 174 | 233 | ||
| 234 | +std::string SymbolTable::DecodeByteBpe(const std::string &text) const { | ||
| 235 | + if (!is_bbpe_) { | ||
| 236 | + return text; | ||
| 237 | + } | ||
| 238 | + auto v = SplitUtf8(text); | ||
| 239 | + | ||
| 240 | + const auto &bbpe_table = GetByteBpeTable(); | ||
| 241 | + std::string ans; | ||
| 242 | + for (const auto &s : v) { | ||
| 243 | + if (s == "▁") { | ||
| 244 | + if (!ans.empty() && ans.back() != ' ' && std::isprint(ans.back())) { | ||
| 245 | + ans.push_back(' '); | ||
| 246 | + } | ||
| 247 | + } else if (bbpe_table.count(s)) { | ||
| 248 | + ans.push_back(bbpe_table.at(s)); | ||
| 249 | + } else if (std::isprint(s[0])) { | ||
| 250 | + ans.append(s); | ||
| 251 | + } else { | ||
| 252 | + // Should not happen | ||
| 253 | + SHERPA_ONNX_LOGE("Skip OOV: %s from %s", s.c_str(), text.c_str()); | ||
| 254 | + } | ||
| 255 | + } | ||
| 256 | + | ||
| 257 | + // TODO(fangjun): Filter invalid utf-8 sequences | ||
| 258 | + return ans; | ||
| 259 | +} | ||
| 260 | + | ||
| 175 | #if __ANDROID_API__ >= 9 | 261 | #if __ANDROID_API__ >= 9 |
| 176 | template SymbolTable::SymbolTable(AAssetManager *mgr, | 262 | template SymbolTable::SymbolTable(AAssetManager *mgr, |
| 177 | const std::string &filename); | 263 | const std::string &filename); |
| @@ -56,12 +56,17 @@ class SymbolTable { | @@ -56,12 +56,17 @@ class SymbolTable { | ||
| 56 | 56 | ||
| 57 | int32_t NumSymbols() const { return id2sym_.size(); } | 57 | int32_t NumSymbols() const { return id2sym_.size(); } |
| 58 | 58 | ||
| 59 | + std::string DecodeByteBpe(const std::string &text) const; | ||
| 60 | + | ||
| 61 | + bool IsByteBpe() const { return is_bbpe_; } | ||
| 62 | + | ||
| 59 | private: | 63 | private: |
| 60 | void Init(std::istream &is); | 64 | void Init(std::istream &is); |
| 61 | 65 | ||
| 62 | private: | 66 | private: |
| 63 | std::unordered_map<std::string, int32_t> sym2id_; | 67 | std::unordered_map<std::string, int32_t> sym2id_; |
| 64 | std::unordered_map<int32_t, std::string> id2sym_; | 68 | std::unordered_map<int32_t, std::string> id2sym_; |
| 69 | + bool is_bbpe_ = false; | ||
| 65 | }; | 70 | }; |
| 66 | 71 | ||
| 67 | std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table); | 72 | std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table); |
-
请 注册 或 登录 后发表评论