Fangjun Kuang
Committed by GitHub

Fix a punctuation bug (#764)

1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR) 1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
2 project(sherpa-onnx) 2 project(sherpa-onnx)
3 3
4 -set(SHERPA_ONNX_VERSION "1.9.18") 4 +set(SHERPA_ONNX_VERSION "1.9.19")
5 5
6 # Disable warning about 6 # Disable warning about
7 # 7 #
@@ -98,7 +98,7 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { @@ -98,7 +98,7 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
98 int32_t dot_index = -1; 98 int32_t dot_index = -1;
99 int32_t comma_index = -1; 99 int32_t comma_index = -1;
100 100
101 - for (int32_t m = this_punctuations.size() - 1; m >= 1; --m) { 101 + for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) {
102 int32_t punct_id = this_punctuations[m]; 102 int32_t punct_id = this_punctuations[m];
103 103
104 if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) { 104 if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
@@ -126,27 +126,20 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { @@ -126,27 +126,20 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
126 } 126 }
127 } else { 127 } else {
128 last = this_start + dot_index + 1; 128 last = this_start + dot_index + 1;
  129 + }
129 130
  131 + if (dot_index != 1) {
130 punctuations.insert(punctuations.end(), this_punctuations.begin(), 132 punctuations.insert(punctuations.end(), this_punctuations.begin(),
131 this_punctuations.begin() + (dot_index + 1)); 133 this_punctuations.begin() + (dot_index + 1));
132 } 134 }
133 } // for (int32_t i = 0; i != num_segments; ++i) 135 } // for (int32_t i = 0; i != num_segments; ++i)
134 136
135 - if (punctuations.size() != token_ids.size() &&  
136 - punctuations.size() + 1 == token_ids.size()) {  
137 - punctuations.push_back(meta_data.dot_id);  
138 - }  
139 -  
140 - if (punctuations.size() != token_ids.size()) {  
141 - SHERPA_ONNX_LOGE("%s, %d, %d. Some unexpected things happened",  
142 - text.c_str(), static_cast<int32_t>(punctuations.size()),  
143 - static_cast<int32_t>(token_ids.size()));  
144 - return text;  
145 - }  
146 -  
147 std::string ans; 137 std::string ans;
148 138
149 for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) { 139 for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) {
  140 + if (i > tokens.size()) {
  141 + break;
  142 + }
150 const std::string &w = tokens[i]; 143 const std::string &w = tokens[i];
151 if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) { 144 if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) {
152 ans.push_back(' '); 145 ans.push_back(' ');
@@ -156,6 +149,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { @@ -156,6 +149,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
156 ans.append(meta_data.id2punct[punctuations[i]]); 149 ans.append(meta_data.id2punct[punctuations[i]]);
157 } 150 }
158 } 151 }
  152 + if (ans.back() != meta_data.dot_id && ans.back() != meta_data.quest_id) {
  153 + ans.push_back(meta_data.dot_id);
  154 + }
159 155
160 return ans; 156 return ans;
161 } 157 }