正在显示
2 个修改的文件
包含
14 行增加
和
11 行删除
| @@ -76,6 +76,14 @@ jobs: | @@ -76,6 +76,14 @@ jobs: | ||
| 76 | otool -L build/bin/sherpa-onnx | 76 | otool -L build/bin/sherpa-onnx |
| 77 | otool -l build/bin/sherpa-onnx | 77 | otool -l build/bin/sherpa-onnx |
| 78 | 78 | ||
| 79 | + - name: Test offline punctuation | ||
| 80 | + shell: bash | ||
| 81 | + run: | | ||
| 82 | + export PATH=$PWD/build/bin:$PATH | ||
| 83 | + export EXE=sherpa-onnx-offline-punctuation | ||
| 84 | + | ||
| 85 | + .github/scripts/test-offline-punctuation.sh | ||
| 86 | + | ||
| 79 | - name: Test offline transducer | 87 | - name: Test offline transducer |
| 80 | shell: bash | 88 | shell: bash |
| 81 | run: | | 89 | run: | |
| @@ -92,13 +100,7 @@ jobs: | @@ -92,13 +100,7 @@ jobs: | ||
| 92 | 100 | ||
| 93 | .github/scripts/test-online-ctc.sh | 101 | .github/scripts/test-online-ctc.sh |
| 94 | 102 | ||
| 95 | - - name: Test offline punctuation | ||
| 96 | - shell: bash | ||
| 97 | - run: | | ||
| 98 | - export PATH=$PWD/build/bin:$PATH | ||
| 99 | - export EXE=sherpa-onnx-offline-punctuation | ||
| 100 | 103 | ||
| 101 | - .github/scripts/test-offline-punctuation.sh | ||
| 102 | 104 | ||
| 103 | - name: Test C API | 105 | - name: Test C API |
| 104 | shell: bash | 106 | shell: bash |
| @@ -69,8 +69,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | @@ -69,8 +69,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | ||
| 69 | std::vector<int32_t> punctuations; | 69 | std::vector<int32_t> punctuations; |
| 70 | int32_t last = -1; | 70 | int32_t last = -1; |
| 71 | for (int32_t i = 0; i != num_segments; ++i) { | 71 | for (int32_t i = 0; i != num_segments; ++i) { |
| 72 | - int32_t this_start = i * segment_size; // inclusive | ||
| 73 | - int32_t this_end = this_start + segment_size; // exclusive | 72 | + int32_t this_start = i * segment_size; // included |
| 73 | + int32_t this_end = this_start + segment_size; // not included | ||
| 74 | if (this_end > static_cast<int32_t>(token_ids.size())) { | 74 | if (this_end > static_cast<int32_t>(token_ids.size())) { |
| 75 | this_end = token_ids.size(); | 75 | this_end = token_ids.size(); |
| 76 | } | 76 | } |
| @@ -113,7 +113,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | @@ -113,7 +113,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | ||
| 113 | int32_t dot_index = -1; | 113 | int32_t dot_index = -1; |
| 114 | int32_t comma_index = -1; | 114 | int32_t comma_index = -1; |
| 115 | 115 | ||
| 116 | - for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) { | 116 | + for (int32_t m = static_cast<int32_t>(this_punctuations.size()) - 2; |
| 117 | + m >= 1; --m) { | ||
| 117 | int32_t punct_id = this_punctuations[m]; | 118 | int32_t punct_id = this_punctuations[m]; |
| 118 | 119 | ||
| 119 | if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) { | 120 | if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) { |
| @@ -137,13 +138,13 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | @@ -137,13 +138,13 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | ||
| 137 | } | 138 | } |
| 138 | 139 | ||
| 139 | if (i == num_segments - 1) { | 140 | if (i == num_segments - 1) { |
| 140 | - dot_index = token_ids.size() - 1; | 141 | + dot_index = static_cast<int32_t>(this_punctuations.size()) - 1; |
| 141 | } | 142 | } |
| 142 | } else { | 143 | } else { |
| 143 | last = this_start + dot_index + 1; | 144 | last = this_start + dot_index + 1; |
| 144 | } | 145 | } |
| 145 | 146 | ||
| 146 | - if (dot_index != 1) { | 147 | + if (dot_index != -1) { |
| 147 | punctuations.insert(punctuations.end(), this_punctuations.begin(), | 148 | punctuations.insert(punctuations.end(), this_punctuations.begin(), |
| 148 | this_punctuations.begin() + (dot_index + 1)); | 149 | this_punctuations.begin() + (dot_index + 1)); |
| 149 | } | 150 | } |
-
请 注册 或 登录 后发表评论