Committed by
GitHub
Fixing Whisper Model Token Normalization (#1904)
正在显示
3 个修改的文件
包含
100 行增加
和
23 行删除
| @@ -134,3 +134,5 @@ us_gold.json | @@ -134,3 +134,5 @@ us_gold.json | ||
| 134 | us_silver.json | 134 | us_silver.json |
| 135 | kokoro-multi-lang-v1_0 | 135 | kokoro-multi-lang-v1_0 |
| 136 | sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16 | 136 | sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16 |
| 137 | +cmake-build-debug | ||
| 138 | +README-DEV.txt |
| @@ -23,28 +23,6 @@ | @@ -23,28 +23,6 @@ | ||
| 23 | 23 | ||
| 24 | namespace sherpa_onnx { | 24 | namespace sherpa_onnx { |
| 25 | 25 | ||
| 26 | -static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, | ||
| 27 | - const SymbolTable &sym_table) { | ||
| 28 | - OfflineRecognitionResult r; | ||
| 29 | - r.tokens.reserve(src.tokens.size()); | ||
| 30 | - | ||
| 31 | - std::string text; | ||
| 32 | - for (auto i : src.tokens) { | ||
| 33 | - if (!sym_table.Contains(i)) { | ||
| 34 | - continue; | ||
| 35 | - } | ||
| 36 | - | ||
| 37 | - const auto &s = sym_table[i]; | ||
| 38 | - text += s; | ||
| 39 | - r.tokens.push_back(s); | ||
| 40 | - } | ||
| 41 | - | ||
| 42 | - r.text = text; | ||
| 43 | - r.lang = src.lang; | ||
| 44 | - | ||
| 45 | - return r; | ||
| 46 | -} | ||
| 47 | - | ||
| 48 | class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | 26 | class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { |
| 49 | public: | 27 | public: |
| 50 | explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config) | 28 | explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config) |
| @@ -156,7 +134,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | @@ -156,7 +134,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | ||
| 156 | std::move(cross_kv.second)); | 134 | std::move(cross_kv.second)); |
| 157 | 135 | ||
| 158 | auto r = Convert(results[0], symbol_table_); | 136 | auto r = Convert(results[0], symbol_table_); |
| 159 | - r.text = ApplyInverseTextNormalization(std::move(r.text)); | ||
| 160 | s->SetResult(r); | 137 | s->SetResult(r); |
| 161 | } catch (const Ort::Exception &ex) { | 138 | } catch (const Ort::Exception &ex) { |
| 162 | SHERPA_ONNX_LOGE( | 139 | SHERPA_ONNX_LOGE( |
| @@ -170,6 +147,31 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | @@ -170,6 +147,31 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | ||
| 170 | } | 147 | } |
| 171 | 148 | ||
| 172 | private: | 149 | private: |
| 150 | + OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, | ||
| 151 | + const SymbolTable &sym_table) const { | ||
| 152 | + OfflineRecognitionResult r; | ||
| 153 | + r.tokens.reserve(src.tokens.size()); | ||
| 154 | + | ||
| 155 | + std::string text; | ||
| 156 | + for (auto i : src.tokens) { | ||
| 157 | + if (!sym_table.Contains(i)) { | ||
| 158 | + continue; | ||
| 159 | + } | ||
| 160 | + | ||
| 161 | + std::string s = sym_table[i]; | ||
| 162 | + s = ApplyInverseTextNormalization(s); | ||
| 163 | + | ||
| 164 | + text += s; | ||
| 165 | + r.tokens.push_back(s); | ||
| 166 | + } | ||
| 167 | + | ||
| 168 | + r.text = text; | ||
| 169 | + r.lang = src.lang; | ||
| 170 | + | ||
| 171 | + return r; | ||
| 172 | + } | ||
| 173 | + | ||
| 174 | + private: | ||
| 173 | OfflineRecognizerConfig config_; | 175 | OfflineRecognizerConfig config_; |
| 174 | SymbolTable symbol_table_; | 176 | SymbolTable symbol_table_; |
| 175 | std::unique_ptr<OfflineWhisperModel> model_; | 177 | std::unique_ptr<OfflineWhisperModel> model_; |
| @@ -55,4 +55,77 @@ TEST(RemoveInvalidUtf8Sequences, Case1) { | @@ -55,4 +55,77 @@ TEST(RemoveInvalidUtf8Sequences, Case1) { | ||
| 55 | EXPECT_EQ(s.size() + 4, v.size()); | 55 | EXPECT_EQ(s.size() + 4, v.size()); |
| 56 | } | 56 | } |
| 57 | 57 | ||
| 58 | + | ||
| 59 | +// Tests for sanitizeUtf8 | ||
| 60 | +TEST(RemoveInvalidUtf8Sequences, ValidUtf8StringPassesUnchanged) { | ||
| 61 | + std::string input = "Valid UTF-8 🌍"; | ||
| 62 | + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), input); | ||
| 63 | +} | ||
| 64 | + | ||
| 65 | +TEST(RemoveInvalidUtf8Sequences, SingleInvalidByteReplaced) { | ||
| 66 | + std::string input = "Invalid \xFF UTF-8"; | ||
| 67 | + std::string expected = "Invalid UTF-8"; | ||
| 68 | + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); | ||
| 69 | +} | ||
| 70 | + | ||
| 71 | +TEST(RemoveInvalidUtf8Sequences, TruncatedUtf8SequenceReplaced) { | ||
| 72 | + std::string input = "Broken \xE2\x82"; // Incomplete UTF-8 sequence | ||
| 73 | + std::string expected = "Broken "; | ||
| 74 | + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); | ||
| 75 | +} | ||
| 76 | + | ||
| 77 | +TEST(RemoveInvalidUtf8Sequences, MultipleInvalidBytes) { | ||
| 78 | + std::string input = "Test \xC0\xC0\xF8\xA0"; // Multiple invalid sequences | ||
| 79 | + std::string expected = "Test "; | ||
| 80 | + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); | ||
| 81 | +} | ||
| 82 | + | ||
| 83 | +TEST(RemoveInvalidUtf8Sequences, BreakingCase_SpaceFollowedByInvalidByte) { | ||
| 84 | + std::string input = "\x20\xC4"; // Space followed by an invalid byte | ||
| 85 | + std::string expected = " "; // 0xC4 removed | ||
| 86 | + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); | ||
| 87 | +} | ||
| 88 | + | ||
| 89 | +TEST(RemoveInvalidUtf8Sequences, ValidUtf8WithEdgeCaseCharacters) { | ||
| 90 | + std::string input = "Edge 🏆💯"; | ||
| 91 | + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), input); | ||
| 92 | +} | ||
| 93 | + | ||
| 94 | +TEST(RemoveInvalidUtf8Sequences, MixedValidAndInvalidBytes) { | ||
| 95 | + std::string input = "Mix \xE2\x82\xAC \xF0\x9F\x98\x81 \xFF"; | ||
| 96 | + std::string expected = "Mix € 😁 "; // Invalid bytes removed | ||
| 97 | + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); | ||
| 98 | +} | ||
| 99 | + | ||
| 100 | +TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte) { | ||
| 101 | + std::string input = "\x20\xC4"; // Space (0x20) followed by invalid (0xC4) | ||
| 102 | + std::string expected = " "; // Space remains, 0xC4 is removed | ||
| 103 | + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); | ||
| 104 | +} | ||
| 105 | + | ||
| 106 | +TEST(RemoveInvalidUtf8Sequences, RemoveTruncatedC4) { | ||
| 107 | + std::string input = "Hello \xc4 world"; // Invalid `0xC4` | ||
| 108 | + std::string expected = "Hello world"; // `0xC4` should be removed | ||
| 109 | + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); | ||
| 110 | +} | ||
| 111 | + | ||
| 112 | +TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte_Breaking) { | ||
| 113 | + std::string input = "\x20\xc4"; // Space followed by invalid `0xc4` | ||
| 114 | + std::string expected = " "; // `0xc4` should be removed, space remains | ||
| 115 | + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); | ||
| 116 | +} | ||
| 117 | + | ||
| 118 | +TEST(RemoveInvalidUtf8Sequences, DebugSpaceFollowedByInvalidByte) { | ||
| 119 | + std::string input = "\x20\xc4"; // Space followed by invalid `0xc4` | ||
| 120 | + std::string output = RemoveInvalidUtf8Sequences(input); | ||
| 121 | + | ||
| 122 | + std::cout << "Processed string: "; | ||
| 123 | + for (unsigned char c : output) { | ||
| 124 | + printf("\\x%02x ", c); | ||
| 125 | + } | ||
| 126 | + std::cout << std::endl; | ||
| 127 | + | ||
| 128 | + EXPECT_EQ(output, " "); // Expect `0xc4` to be removed, leaving only space | ||
| 129 | +} | ||
| 130 | + | ||
| 58 | } // namespace sherpa_onnx | 131 | } // namespace sherpa_onnx |
-
请 注册 或 登录 后发表评论