Fangjun Kuang
Committed by GitHub

Fix punctuation (#976)

@@ -76,6 +76,14 @@ jobs: @@ -76,6 +76,14 @@ jobs:
76 otool -L build/bin/sherpa-onnx 76 otool -L build/bin/sherpa-onnx
77 otool -l build/bin/sherpa-onnx 77 otool -l build/bin/sherpa-onnx
78 78
  79 + - name: Test offline punctuation
  80 + shell: bash
  81 + run: |
  82 + export PATH=$PWD/build/bin:$PATH
  83 + export EXE=sherpa-onnx-offline-punctuation
  84 +
  85 + .github/scripts/test-offline-punctuation.sh
  86 +
79 - name: Test offline transducer 87 - name: Test offline transducer
80 shell: bash 88 shell: bash
81 run: | 89 run: |
@@ -92,13 +100,7 @@ jobs: @@ -92,13 +100,7 @@ jobs:
92 100
93 .github/scripts/test-online-ctc.sh 101 .github/scripts/test-online-ctc.sh
94 102
95 - - name: Test offline punctuation  
96 - shell: bash  
97 - run: |  
98 - export PATH=$PWD/build/bin:$PATH  
99 - export EXE=sherpa-onnx-offline-punctuation  
100 103
101 - .github/scripts/test-offline-punctuation.sh  
102 104
103 - name: Test C API 105 - name: Test C API
104 shell: bash 106 shell: bash
@@ -69,8 +69,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { @@ -69,8 +69,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
69 std::vector<int32_t> punctuations; 69 std::vector<int32_t> punctuations;
70 int32_t last = -1; 70 int32_t last = -1;
71 for (int32_t i = 0; i != num_segments; ++i) { 71 for (int32_t i = 0; i != num_segments; ++i) {
72 - int32_t this_start = i * segment_size; // inclusive  
73 - int32_t this_end = this_start + segment_size; // exclusive 72 + int32_t this_start = i * segment_size; // included
  73 + int32_t this_end = this_start + segment_size; // not included
74 if (this_end > static_cast<int32_t>(token_ids.size())) { 74 if (this_end > static_cast<int32_t>(token_ids.size())) {
75 this_end = token_ids.size(); 75 this_end = token_ids.size();
76 } 76 }
@@ -113,7 +113,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { @@ -113,7 +113,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
113 int32_t dot_index = -1; 113 int32_t dot_index = -1;
114 int32_t comma_index = -1; 114 int32_t comma_index = -1;
115 115
116 - for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) { 116 + for (int32_t m = static_cast<int32_t>(this_punctuations.size()) - 2;
  117 + m >= 1; --m) {
117 int32_t punct_id = this_punctuations[m]; 118 int32_t punct_id = this_punctuations[m];
118 119
119 if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) { 120 if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
@@ -137,13 +138,13 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { @@ -137,13 +138,13 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
137 } 138 }
138 139
139 if (i == num_segments - 1) { 140 if (i == num_segments - 1) {
140 - dot_index = token_ids.size() - 1; 141 + dot_index = static_cast<int32_t>(this_punctuations.size()) - 1;
141 } 142 }
142 } else { 143 } else {
143 last = this_start + dot_index + 1; 144 last = this_start + dot_index + 1;
144 } 145 }
145 146
146 - if (dot_index != 1) { 147 + if (dot_index != -1) {
147 punctuations.insert(punctuations.end(), this_punctuations.begin(), 148 punctuations.insert(punctuations.end(), this_punctuations.begin(),
148 this_punctuations.begin() + (dot_index + 1)); 149 this_punctuations.begin() + (dot_index + 1));
149 } 150 }