utils.cc
3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// sherpa-onnx/csrc/utils.cc
//
// Copyright 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/utils.h"
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Parses the whole of `text` as a float, with the same syntax std::stof
// accepts (std::stof is specified in terms of strtof). Unlike std::stof it
// never throws. It returns false, leaving `*value` unchanged, if `text` is
// empty, has characters after the number, or is out of float range.
static bool ParseFloatAttribute(const std::string &text, float *value) {
  const char *begin = text.c_str();
  char *end = nullptr;
  errno = 0;
  const float parsed = std::strtof(begin, &end);
  if (end == begin || end != begin + text.size() || errno == ERANGE) {
    return false;
  }
  *value = parsed;
  return true;
}

// Shared parser behind EncodeHotwords() and EncodeKeywords().
//
// Each line of `is` that contains at least one word describes one entry.
// Whitespace-only lines are skipped. Words on a line are separated by
// whitespace. Each word is either:
//   - a token found in `symbol_table` (its id is appended to the entry), or
//   - ':<float>'  the boosting score of the entry,
//   - '#<float>'  the triggering threshold (probability) of the entry,
//   - '@<text>'   the original phrase of the entry (one whitespace-free word).
// If an attribute appears more than once on a line, the last one wins.
// A word starting with the BPE marker U+2581 has that marker replaced by a
// single space before the symbol-table lookup.
//
// @param is            Input stream, read line by line until EOF.
// @param symbol_table  Token-to-id lookup table.
// @param ids           Required (checked). Cleared, then receives one id
//                      sequence per entry.
// @param phrases       Optional. If any line had an '@' attribute, it receives
//                      one phrase per entry ("" where absent). Otherwise it is
//                      cleared.
// @param scores        Optional. Same as `phrases`, for ':' (0 where absent).
// @param thresholds    Optional. Same as `phrases`, for '#' (0 where absent).
// @return true on success. On an unknown word or an unparsable number, it logs
//         an error and returns false. The optional outputs are then left
//         unmodified, and `ids` holds only the entries parsed before the
//         offending line.
static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
                       std::vector<std::vector<int32_t>> *ids,
                       std::vector<std::string> *phrases,
                       std::vector<float> *scores,
                       std::vector<float> *thresholds) {
  SHERPA_ONNX_CHECK(ids != nullptr);
  ids->clear();

  // UTF-8 encoding of U+2581 ("▁"), used by BPE models as a word boundary.
  constexpr char kBpeSpace[] = "\xe2\x96\x81";

  std::vector<float> tmp_scores;
  std::vector<float> tmp_thresholds;
  std::vector<std::string> tmp_phrases;
  bool has_scores = false;
  bool has_thresholds = false;
  bool has_phrases = false;

  std::string line;
  std::string word;
  while (std::getline(is, line)) {
    std::vector<int32_t> tmp_ids;  // fresh per line; moved into *ids below
    float score = 0;
    float threshold = 0;
    std::string phrase;
    bool has_words = false;

    std::istringstream iss(line);
    while (iss >> word) {
      has_words = true;

      // For BPE-based models, replace a leading ▁ with a space.
      // compare() on a shorter word simply reports "not equal".
      if (word.compare(0, sizeof(kBpeSpace) - 1, kBpeSpace) == 0) {
        word.replace(0, sizeof(kBpeSpace) - 1, " ");
      }

      if (symbol_table.contains(word)) {
        tmp_ids.push_back(symbol_table[word]);
        continue;
      }

      switch (word[0]) {  // `word` is never empty after `iss >> word`
        case ':':  // boosting score for current keyword
          if (!ParseFloatAttribute(word.substr(1), &score)) {
            SHERPA_ONNX_LOGE("Invalid boosting score %s at line: %s",
                             word.c_str(), line.c_str());
            return false;
          }
          has_scores = true;
          break;
        case '#':  // triggering threshold (probability) for current keyword
          if (!ParseFloatAttribute(word.substr(1), &threshold)) {
            SHERPA_ONNX_LOGE("Invalid triggering threshold %s at line: %s",
                             word.c_str(), line.c_str());
            return false;
          }
          has_thresholds = true;
          break;
        case '@':  // the original keyword string
          phrase = word.substr(1);
          has_phrases = true;
          break;
        default:
          SHERPA_ONNX_LOGE(
              "Cannot find ID for token %s at line: %s. (Hint: words on "
              "the same line are separated by spaces)",
              word.c_str(), line.c_str());
          return false;
      }
    }

    // A blank line would otherwise become an empty hotword/keyword.
    if (!has_words) {
      continue;
    }

    ids->push_back(std::move(tmp_ids));
    tmp_scores.push_back(score);
    tmp_phrases.push_back(std::move(phrase));
    tmp_thresholds.push_back(threshold);
  }

  // Optional outputs are filled only if at least one line used the
  // attribute. Otherwise they are cleared so callers can test for emptiness.
  if (scores != nullptr) {
    if (has_scores) {
      scores->swap(tmp_scores);
    } else {
      scores->clear();
    }
  }
  if (phrases != nullptr) {
    if (has_phrases) {
      phrases->swap(tmp_phrases);
    } else {
      phrases->clear();
    }
  }
  if (thresholds != nullptr) {
    if (has_thresholds) {
      thresholds->swap(tmp_thresholds);
    } else {
      thresholds->clear();
    }
  }
  return true;
}
// Encodes a hotwords stream into token-id sequences, one per parsed line.
// Only the ids are returned. The phrase, score and threshold outputs of the
// shared parser are not requested, so nothing is written for them.
//
// Returns false if the shared parser rejects the input (e.g. a word that is
// not a known token and not a recognized attribute).
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
                    std::vector<std::vector<int32_t>> *hotwords) {
  std::vector<std::string> *const no_phrases = nullptr;
  std::vector<float> *const no_scores = nullptr;
  std::vector<float> *const no_thresholds = nullptr;
  return EncodeBase(is, symbol_table, hotwords, no_phrases, no_scores,
                    no_thresholds);
}
// Encodes a keywords stream for keyword spotting.
//
// `keywords_id` receives one token-id sequence per parsed entry. Each of
// `keywords` ('@' phrases), `boost_scores` (':' values) and `threshold`
// ('#' values) may be nullptr. When provided, it is filled per entry if any
// entry carried that attribute, and cleared otherwise.
//
// Returns false if the shared parser rejects the input.
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
                    std::vector<std::vector<int32_t>> *keywords_id,
                    std::vector<std::string> *keywords,
                    std::vector<float> *boost_scores,
                    std::vector<float> *threshold) {
  std::vector<std::vector<int32_t>> *const ids_out = keywords_id;
  return EncodeBase(is, symbol_table, ids_out, /*phrases=*/keywords,
                    /*scores=*/boost_scores, /*thresholds=*/threshold);
}
} // namespace sherpa_onnx