正在显示
2 个修改的文件
包含
10 行增加
和
14 行删除
| @@ -98,7 +98,7 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | @@ -98,7 +98,7 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | ||
| 98 | int32_t dot_index = -1; | 98 | int32_t dot_index = -1; |
| 99 | int32_t comma_index = -1; | 99 | int32_t comma_index = -1; |
| 100 | 100 | ||
| 101 | - for (int32_t m = this_punctuations.size() - 1; m >= 1; --m) { | 101 | + for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) { |
| 102 | int32_t punct_id = this_punctuations[m]; | 102 | int32_t punct_id = this_punctuations[m]; |
| 103 | 103 | ||
| 104 | if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) { | 104 | if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) { |
| @@ -126,27 +126,20 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | @@ -126,27 +126,20 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | ||
| 126 | } | 126 | } |
| 127 | } else { | 127 | } else { |
| 128 | last = this_start + dot_index + 1; | 128 | last = this_start + dot_index + 1; |
| 129 | + } | ||
| 129 | 130 | ||
| 131 | + if (dot_index != 1) { | ||
| 130 | punctuations.insert(punctuations.end(), this_punctuations.begin(), | 132 | punctuations.insert(punctuations.end(), this_punctuations.begin(), |
| 131 | this_punctuations.begin() + (dot_index + 1)); | 133 | this_punctuations.begin() + (dot_index + 1)); |
| 132 | } | 134 | } |
| 133 | } // for (int32_t i = 0; i != num_segments; ++i) | 135 | } // for (int32_t i = 0; i != num_segments; ++i) |
| 134 | 136 | ||
| 135 | - if (punctuations.size() != token_ids.size() && | ||
| 136 | - punctuations.size() + 1 == token_ids.size()) { | ||
| 137 | - punctuations.push_back(meta_data.dot_id); | ||
| 138 | - } | ||
| 139 | - | ||
| 140 | - if (punctuations.size() != token_ids.size()) { | ||
| 141 | - SHERPA_ONNX_LOGE("%s, %d, %d. Some unexpected things happened", | ||
| 142 | - text.c_str(), static_cast<int32_t>(punctuations.size()), | ||
| 143 | - static_cast<int32_t>(token_ids.size())); | ||
| 144 | - return text; | ||
| 145 | - } | ||
| 146 | - | ||
| 147 | std::string ans; | 137 | std::string ans; |
| 148 | 138 | ||
| 149 | for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) { | 139 | for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) { |
| 140 | + if (i > tokens.size()) { | ||
| 141 | + break; | ||
| 142 | + } | ||
| 150 | const std::string &w = tokens[i]; | 143 | const std::string &w = tokens[i]; |
| 151 | if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) { | 144 | if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) { |
| 152 | ans.push_back(' '); | 145 | ans.push_back(' '); |
| @@ -156,6 +149,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | @@ -156,6 +149,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | ||
| 156 | ans.append(meta_data.id2punct[punctuations[i]]); | 149 | ans.append(meta_data.id2punct[punctuations[i]]); |
| 157 | } | 150 | } |
| 158 | } | 151 | } |
| 152 | + if (ans.back() != meta_data.dot_id && ans.back() != meta_data.quest_id) { | ||
| 153 | + ans.push_back(meta_data.dot_id); | ||
| 154 | + } | ||
| 159 | 155 | ||
| 160 | return ans; | 156 | return ans; |
| 161 | } | 157 | } |
-
请 注册 或 登录 后发表评论