Fangjun Kuang
Committed by GitHub

Support removing invalid utf-8 sequences. (#1648)

@@ -545,6 +545,7 @@ if(SHERPA_ONNX_ENABLE_TESTS) @@ -545,6 +545,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
545 pad-sequence-test.cc 545 pad-sequence-test.cc
546 slice-test.cc 546 slice-test.cc
547 stack-test.cc 547 stack-test.cc
  548 + text-utils-test.cc
548 text2token-test.cc 549 text2token-test.cc
549 transpose-test.cc 550 transpose-test.cc
550 unbind-test.cc 551 unbind-test.cc
@@ -488,6 +488,8 @@ OfflineRecognizerImpl::OfflineRecognizerImpl( @@ -488,6 +488,8 @@ OfflineRecognizerImpl::OfflineRecognizerImpl(
488 488
489 std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( 489 std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
490 std::string text) const { 490 std::string text) const {
  491 + text = RemoveInvalidUtf8Sequences(text);
  492 +
491 if (!itn_list_.empty()) { 493 if (!itn_list_.empty()) {
492 for (const auto &tn : itn_list_) { 494 for (const auto &tn : itn_list_) {
493 text = tn->Normalize(text); 495 text = tn->Normalize(text);
@@ -194,6 +194,8 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(Manager *mgr, @@ -194,6 +194,8 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(Manager *mgr,
194 194
195 std::string OnlineRecognizerImpl::ApplyInverseTextNormalization( 195 std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
196 std::string text) const { 196 std::string text) const {
  197 + text = RemoveInvalidUtf8Sequences(text);
  198 +
197 if (!itn_list_.empty()) { 199 if (!itn_list_.empty()) {
198 for (const auto &tn : itn_list_) { 200 for (const auto &tn : itn_list_) {
199 text = tn->Normalize(text); 201 text = tn->Normalize(text);
  1 +// sherpa-onnx/csrc/text-utils-test.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/text-utils.h"
  6 +
  7 +#include "gtest/gtest.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +TEST(RemoveInvalidUtf8Sequences, Case1) {
  12 + std::vector<uint8_t> v = {
  13 + 0xe4, 0xbb, 0x8a, // 今
  14 + 0xe5, 0xa4, 0xa9, // 天
  15 + 'i', 's', ' ', 'M', 'o', 'd', 'a', 'y', ',', // is Monday,
  16 + ' ', 'w', 'i', 'e', ' ', 'h', 'e', 'i', 0xc3, // wie heißen Size
  17 + 0x9f, 'e', 'n', ' ', 'S', 'i', 'e', 0xf0, 0x9d, 0x84, 0x81};
  18 +
  19 + std::vector<uint8_t> v0 = v;
  20 + v0[1] = 0xc0; // make the first 3 bytes an invalid utf8 character
  21 + std::string s0{v0.begin(), v0.end()};
  22 + EXPECT_EQ(s0.size(), v0.size());
  23 +
  24 + auto s = RemoveInvalidUtf8Sequences(s0); // should remove 今
  25 +
  26 + v0 = v;
  27 + // v0[23] == 0xc3
  28 + // v0[24] == 0x9f
  29 +
  30 + v0[23] = 0xc1;
  31 +
  32 + s0 = {v0.begin(), v0.end()};
  33 + s = RemoveInvalidUtf8Sequences(s0); // should remove ß
  34 +
  35 + EXPECT_EQ(s.size() + 2, v.size());
  36 +
  37 + v0 = v;
  38 + // v0[31] = 0xf0;
  39 + // v0[32] = 0x9d;
  40 + // v0[33] = 0x84;
  41 + // v0[34] = 0x81;
  42 + v0[31] = 0xf5;
  43 +
  44 + s0 = {v0.begin(), v0.end()};
  45 + s = RemoveInvalidUtf8Sequences(s0);
  46 +
  47 + EXPECT_EQ(s.size() + 4, v.size());
  48 +}
  49 +
  50 +} // namespace sherpa_onnx
@@ -396,4 +396,110 @@ void ToLowerCase(std::string *in_out) { @@ -396,4 +396,110 @@ void ToLowerCase(std::string *in_out) {
396 [](unsigned char c) { return std::tolower(c); }); 396 [](unsigned char c) { return std::tolower(c); });
397 } 397 }
398 398
  399 +static inline bool InRange(uint8_t x, uint8_t low, uint8_t high) {
  400 + return low <= x && x <= high;
  401 +}
  402 +
  403 +/*
  404 +Please see
  405 +https://stackoverflow.com/questions/6555015/check-for-invalid-utf8
  406 +
  407 +
  408 +Table 3-7. Well-Formed UTF-8 Byte Sequences
  409 +
  410 +Code Points First Byte Second Byte Third Byte Fourth Byte
  411 +U+0000..U+007F 00..7F
  412 +U+0080..U+07FF C2..DF 80..BF
  413 +U+0800..U+0FFF E0 A0..BF 80..BF
  414 +U+1000..U+CFFF E1..EC 80..BF 80..BF
  415 +U+D000..U+D7FF ED 80..9F 80..BF
  416 +U+E000..U+FFFF EE..EF 80..BF 80..BF
  417 +U+10000..U+3FFFF F0 90..BF 80..BF 80..BF
  418 +U+40000..U+FFFFF F1..F3 80..BF 80..BF 80..BF
  419 +U+100000..U+10FFFF F4 80..8F 80..BF 80..BF
  420 + */
  421 +std::string RemoveInvalidUtf8Sequences(const std::string &text,
  422 + bool show_debug_msg /*= false*/) {
  423 + int32_t n = static_cast<int32_t>(text.size());
  424 +
  425 + std::string ans;
  426 + ans.reserve(n);
  427 +
  428 + int32_t i = 0;
  429 + const uint8_t *p = reinterpret_cast<const uint8_t *>(text.data());
  430 + while (i < n) {
  431 + if (p[i] <= 0x7f) {
  432 + ans.append(text, i, 1);
  433 + i += 1;
  434 + continue;
  435 + }
  436 +
  437 + if (InRange(p[i], 0xc2, 0xdf) && i + 1 < n &&
  438 + InRange(p[i + 1], 0x80, 0xbf)) {
  439 + ans.append(text, i, 2);
  440 + i += 2;
  441 + continue;
  442 + }
  443 +
  444 + if (p[i] == 0xe0 && i + 2 < n && InRange(p[i + 1], 0xa0, 0xbf) &&
  445 + InRange(p[i + 2], 0x80, 0xbf)) {
  446 + ans.append(text, i, 3);
  447 + i += 3;
  448 + continue;
  449 + }
  450 +
  451 + if (InRange(p[i], 0xe1, 0xec) && i + 2 < n &&
  452 + InRange(p[i + 1], 0x80, 0xbf) && InRange(p[i + 2], 0x80, 0xbf)) {
  453 + ans.append(text, i, 3);
  454 + i += 3;
  455 + continue;
  456 + }
  457 +
  458 + if (p[i] == 0xed && i + 2 < n && InRange(p[i + 1], 0x80, 0x9f) &&
  459 + InRange(p[i + 2], 0x80, 0xbf)) {
  460 + ans.append(text, i, 3);
  461 + i += 3;
  462 + continue;
  463 + }
  464 +
  465 + if (InRange(p[i], 0xee, 0xef) && i + 2 < n &&
  466 + InRange(p[i + 1], 0x80, 0xbf) && InRange(p[i + 2], 0x80, 0xbf)) {
  467 + ans.append(text, i, 3);
  468 + i += 3;
  469 + continue;
  470 + }
  471 +
  472 + if (p[i] == 0xf0 && i + 3 < n && InRange(p[i + 1], 0x90, 0xbf) &&
  473 + InRange(p[i + 2], 0x80, 0xbf) && InRange(p[i + 3], 0x80, 0xbf)) {
  474 + ans.append(text, i, 4);
  475 + i += 4;
  476 + continue;
  477 + }
  478 +
  479 + if (InRange(p[i], 0xf1, 0xf3) && i + 3 < n &&
  480 + InRange(p[i + 1], 0x80, 0xbf) && InRange(p[i + 2], 0x80, 0xbf) &&
  481 + InRange(p[i + 3], 0x80, 0xbf)) {
  482 + ans.append(text, i, 4);
  483 + i += 4;
  484 + continue;
  485 + }
  486 +
  487 + if (p[i] == 0xf4 && i + 3 < n && InRange(p[i + 1], 0x80, 0x8f) &&
  488 + InRange(p[i + 2], 0x80, 0xbf) && InRange(p[i + 3], 0x80, 0xbf)) {
  489 + ans.append(text, i, 4);
  490 + i += 4;
  491 + continue;
  492 + }
  493 +
  494 + if (show_debug_msg) {
  495 + SHERPA_ONNX_LOGE("Ignore invalid utf8 sequence at pos: %d, value: %02x",
  496 + i, p[i]);
  497 + }
  498 +
  499 + i += 1;
  500 + }
  501 +
  502 + return ans;
  503 +}
  504 +
399 } // namespace sherpa_onnx 505 } // namespace sherpa_onnx
@@ -124,6 +124,9 @@ std::vector<std::string> SplitUtf8(const std::string &text); @@ -124,6 +124,9 @@ std::vector<std::string> SplitUtf8(const std::string &text);
124 std::string ToLowerCase(const std::string &s); 124 std::string ToLowerCase(const std::string &s);
125 void ToLowerCase(std::string *in_out); 125 void ToLowerCase(std::string *in_out);
126 126
  127 +std::string RemoveInvalidUtf8Sequences(const std::string &text,
  128 + bool show_debug_msg = false);
  129 +
127 } // namespace sherpa_onnx 130 } // namespace sherpa_onnx
128 131
129 #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ 132 #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_