ivan provalov
Committed by GitHub

Fixing Whisper Model Token Normalization (#1904)

@@ -134,3 +134,5 @@ us_gold.json @@ -134,3 +134,5 @@ us_gold.json
134 us_silver.json 134 us_silver.json
135 kokoro-multi-lang-v1_0 135 kokoro-multi-lang-v1_0
136 sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16 136 sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16
  137 +cmake-build-debug
  138 +README-DEV.txt
@@ -23,28 +23,6 @@ @@ -23,28 +23,6 @@
23 23
24 namespace sherpa_onnx { 24 namespace sherpa_onnx {
25 25
26 -static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,  
27 - const SymbolTable &sym_table) {  
28 - OfflineRecognitionResult r;  
29 - r.tokens.reserve(src.tokens.size());  
30 -  
31 - std::string text;  
32 - for (auto i : src.tokens) {  
33 - if (!sym_table.Contains(i)) {  
34 - continue;  
35 - }  
36 -  
37 - const auto &s = sym_table[i];  
38 - text += s;  
39 - r.tokens.push_back(s);  
40 - }  
41 -  
42 - r.text = text;  
43 - r.lang = src.lang;  
44 -  
45 - return r;  
46 -}  
47 -  
48 class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { 26 class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
49 public: 27 public:
50 explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config) 28 explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
@@ -156,7 +134,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { @@ -156,7 +134,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
156 std::move(cross_kv.second)); 134 std::move(cross_kv.second));
157 135
158 auto r = Convert(results[0], symbol_table_); 136 auto r = Convert(results[0], symbol_table_);
159 - r.text = ApplyInverseTextNormalization(std::move(r.text));  
160 s->SetResult(r); 137 s->SetResult(r);
161 } catch (const Ort::Exception &ex) { 138 } catch (const Ort::Exception &ex) {
162 SHERPA_ONNX_LOGE( 139 SHERPA_ONNX_LOGE(
@@ -170,6 +147,31 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { @@ -170,6 +147,31 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
170 } 147 }
171 148
172 private: 149 private:
  150 + OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
  151 + const SymbolTable &sym_table) const {
  152 + OfflineRecognitionResult r;
  153 + r.tokens.reserve(src.tokens.size());
  154 +
  155 + std::string text;
  156 + for (auto i : src.tokens) {
  157 + if (!sym_table.Contains(i)) {
  158 + continue;
  159 + }
  160 +
  161 + std::string s = sym_table[i];
  162 + s = ApplyInverseTextNormalization(s);
  163 +
  164 + text += s;
  165 + r.tokens.push_back(s);
  166 + }
  167 +
  168 + r.text = text;
  169 + r.lang = src.lang;
  170 +
  171 + return r;
  172 + }
  173 +
  174 + private:
173 OfflineRecognizerConfig config_; 175 OfflineRecognizerConfig config_;
174 SymbolTable symbol_table_; 176 SymbolTable symbol_table_;
175 std::unique_ptr<OfflineWhisperModel> model_; 177 std::unique_ptr<OfflineWhisperModel> model_;
@@ -55,4 +55,77 @@ TEST(RemoveInvalidUtf8Sequences, Case1) { @@ -55,4 +55,77 @@ TEST(RemoveInvalidUtf8Sequences, Case1) {
55 EXPECT_EQ(s.size() + 4, v.size()); 55 EXPECT_EQ(s.size() + 4, v.size());
56 } 56 }
57 57
  58 +
  59 +// Tests for RemoveInvalidUtf8Sequences
  60 +TEST(RemoveInvalidUtf8Sequences, ValidUtf8StringPassesUnchanged) {
  61 + std::string input = "Valid UTF-8 🌍";
  62 + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), input);
  63 +}
  64 +
  65 +TEST(RemoveInvalidUtf8Sequences, SingleInvalidByteReplaced) {
  66 + std::string input = "Invalid \xFF UTF-8";
  67 + std::string expected = "Invalid UTF-8";
  68 + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
  69 +}
  70 +
  71 +TEST(RemoveInvalidUtf8Sequences, TruncatedUtf8SequenceReplaced) {
  72 + std::string input = "Broken \xE2\x82"; // Incomplete UTF-8 sequence
  73 + std::string expected = "Broken ";
  74 + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
  75 +}
  76 +
  77 +TEST(RemoveInvalidUtf8Sequences, MultipleInvalidBytes) {
  78 + std::string input = "Test \xC0\xC0\xF8\xA0"; // Multiple invalid sequences
  79 + std::string expected = "Test ";
  80 + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
  81 +}
  82 +
  83 +TEST(RemoveInvalidUtf8Sequences, BreakingCase_SpaceFollowedByInvalidByte) {
  84 + std::string input = "\x20\xC4"; // Space followed by an invalid byte
  85 + std::string expected = " "; // 0xC4 removed
  86 + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
  87 +}
  88 +
  89 +TEST(RemoveInvalidUtf8Sequences, ValidUtf8WithEdgeCaseCharacters) {
  90 + std::string input = "Edge 🏆💯";
  91 + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), input);
  92 +}
  93 +
  94 +TEST(RemoveInvalidUtf8Sequences, MixedValidAndInvalidBytes) {
  95 + std::string input = "Mix \xE2\x82\xAC \xF0\x9F\x98\x81 \xFF";
  96 + std::string expected = "Mix € 😁 "; // Invalid bytes removed
  97 + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
  98 +}
  99 +
  100 +TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte) {
  101 + std::string input = "\x20\xC4"; // Space (0x20) followed by invalid (0xC4)
  102 + std::string expected = " "; // Space remains, 0xC4 is removed
  103 + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
  104 +}
  105 +
  106 +TEST(RemoveInvalidUtf8Sequences, RemoveTruncatedC4) {
  107 + std::string input = "Hello \xc4 world"; // Invalid `0xC4`
  108 + std::string expected = "Hello world"; // `0xC4` should be removed
  109 + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
  110 +}
  111 +
  112 +TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte_Breaking) {
  113 + std::string input = "\x20\xc4"; // Space followed by invalid `0xc4`
  114 + std::string expected = " "; // `0xc4` should be removed, space remains
  115 + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
  116 +}
  117 +
  118 +TEST(RemoveInvalidUtf8Sequences, DebugSpaceFollowedByInvalidByte) {
  119 + std::string input = "\x20\xc4"; // Space followed by invalid `0xc4`
  120 + std::string output = RemoveInvalidUtf8Sequences(input);
  121 +
  122 + std::cout << "Processed string: ";
  123 + for (unsigned char c : output) {
  124 + printf("\\x%02x ", c);
  125 + }
  126 + std::cout << std::endl;
  127 +
  128 + EXPECT_EQ(output, " "); // Expect `0xc4` to be removed, leaving only space
  129 +}
  130 +
58 } // namespace sherpa_onnx 131 } // namespace sherpa_onnx