Fangjun Kuang
Committed by GitHub

Fix punctuation (#976)

... ... @@ -76,6 +76,14 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test offline punctuation
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-punctuation
.github/scripts/test-offline-punctuation.sh
- name: Test offline transducer
shell: bash
run: |
... ... @@ -92,13 +100,7 @@ jobs:
.github/scripts/test-online-ctc.sh
- name: Test offline punctuation
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-punctuation
.github/scripts/test-offline-punctuation.sh
- name: Test C API
shell: bash
... ...
... ... @@ -69,8 +69,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
std::vector<int32_t> punctuations;
int32_t last = -1;
for (int32_t i = 0; i != num_segments; ++i) {
int32_t this_start = i * segment_size; // inclusive
int32_t this_end = this_start + segment_size; // exclusive
int32_t this_start = i * segment_size; // included
int32_t this_end = this_start + segment_size; // not included
if (this_end > static_cast<int32_t>(token_ids.size())) {
this_end = token_ids.size();
}
... ... @@ -113,7 +113,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
int32_t dot_index = -1;
int32_t comma_index = -1;
for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) {
for (int32_t m = static_cast<int32_t>(this_punctuations.size()) - 2;
m >= 1; --m) {
int32_t punct_id = this_punctuations[m];
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
... ... @@ -137,13 +138,13 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
}
if (i == num_segments - 1) {
dot_index = token_ids.size() - 1;
dot_index = static_cast<int32_t>(this_punctuations.size()) - 1;
}
} else {
last = this_start + dot_index + 1;
}
if (dot_index != 1) {
if (dot_index != -1) {
punctuations.insert(punctuations.end(), this_punctuations.begin(),
this_punctuations.begin() + (dot_index + 1));
}
... ...