Fangjun Kuang
Committed by GitHub

Support getting word IDs for CTC HLG decoding. (#978)

... ... @@ -15,8 +15,16 @@ struct OfflineCtcDecoderResult {
/// The decoded token IDs
std::vector<int64_t> tokens;
/// The decoded word IDs
/// Note: tokens.size() is usually not equal to words.size()
/// words is empty for greedy search decoding.
/// it is not empty when an HLG graph or an HLG graph is used.
std::vector<int32_t> words;
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
/// Note: The index is after subsampling
///
/// tokens.size() == timestamps.size()
std::vector<int32_t> timestamps;
};
... ...
... ... @@ -108,6 +108,9 @@ static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder,
// -1 here since the input labels are incremented during graph
// construction
r.tokens.push_back(arc.ilabel - 1);
if (arc.olabel != 0) {
r.words.push_back(arc.olabel);
}
r.timestamps.push_back(t);
prev = arc.ilabel;
... ...
... ... @@ -64,10 +64,6 @@ OfflineParaformerGreedySearchDecoder::Decode(
if (timestamps.size() == results[i].tokens.size()) {
results[i].timestamps = std::move(timestamps);
} else {
SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i,
static_cast<int32_t>(results[i].tokens.size()),
static_cast<int32_t>(timestamps.size()));
}
}
}
... ...
... ... @@ -65,6 +65,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
r.timestamps.push_back(time);
}
r.words = std::move(src.words);
return r;
}
... ...
... ... @@ -339,6 +339,20 @@ std::string OfflineRecognitionResult::AsJsonString() const {
}
sep = ", ";
}
os << "], ";
sep = "";
os << "\""
<< "words"
<< "\""
<< ": ";
os << "[";
for (int32_t w : words) {
os << sep << w;
sep = ", ";
}
os << "]";
os << "}";
... ...
... ... @@ -30,6 +30,8 @@ struct OfflineRecognitionResult {
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;
std::vector<int32_t> words;
std::string AsJsonString() const;
};
... ...
... ... @@ -22,8 +22,16 @@ struct OnlineCtcDecoderResult {
/// The decoded token IDs
std::vector<int64_t> tokens;
/// The decoded word IDs
/// Note: tokens.size() is usually not equal to words.size()
/// words is empty for greedy search decoding.
/// it is not empty when an HLG graph or an HLG graph is used.
std::vector<int32_t> words;
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
/// Note: The index is after subsampling
///
/// tokens.size() == timestamps.size()
std::vector<int32_t> timestamps;
int32_t num_trailing_blanks = 0;
... ...
... ... @@ -51,9 +51,9 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
bool ok = decoder->GetBestPath(&fst_out);
if (ok) {
std::vector<int32_t> isymbols_out;
std::vector<int32_t> osymbols_out_unused;
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out,
&osymbols_out_unused, nullptr);
std::vector<int32_t> osymbols_out;
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, &osymbols_out,
nullptr);
std::vector<int64_t> tokens;
tokens.reserve(isymbols_out.size());
... ... @@ -83,6 +83,7 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
}
result->tokens = std::move(tokens);
result->words = std::move(osymbols_out);
result->timestamps = std::move(timestamps);
// no need to set frame_offset
}
... ...
... ... @@ -59,6 +59,7 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
}
r.segment = segment;
r.words = std::move(src.words);
r.start_time = frames_since_start * frame_shift_ms / 1000.;
return r;
... ...
... ... @@ -22,14 +22,16 @@ namespace sherpa_onnx {
template <typename T>
std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(precision);
oss << "[ ";
if (precision != 0) {
oss << std::fixed << std::setprecision(precision);
}
oss << "[";
std::string sep = "";
for (const auto &item : vec) {
oss << sep << item;
sep = ", ";
}
oss << " ]";
oss << "]";
return oss.str();
}
... ... @@ -38,26 +40,29 @@ template <> // explicit specialization for T = std::string
std::string VecToString<std::string>(const std::vector<std::string> &vec,
int32_t) { // ignore 2nd arg
std::ostringstream oss;
oss << "[ ";
oss << "[";
std::string sep = "";
for (const auto &item : vec) {
oss << sep << "\"" << item << "\"";
sep = ", ";
}
oss << " ]";
oss << "]";
return oss.str();
}
std::string OnlineRecognizerResult::AsJsonString() const {
std::ostringstream os;
os << "{ ";
os << "\"text\": " << "\"" << text << "\"" << ", ";
os << "\"text\": "
<< "\"" << text << "\""
<< ", ";
os << "\"tokens\": " << VecToString(tokens) << ", ";
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
os << "\"segment\": " << segment << ", ";
os << "\"words\": " << VecToString(words, 0) << ", ";
os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time
<< ", ";
os << "\"is_final\": " << (is_final ? "true" : "false");
... ...
... ... @@ -47,6 +47,8 @@ struct OnlineRecognizerResult {
/// log-domain scores from "hot-phrase" contextual boosting
std::vector<float> context_scores;
std::vector<int32_t> words;
/// ID of this segment
/// When an endpoint is detected, it is incremented
int32_t segment = 0;
... ...
... ... @@ -34,6 +34,8 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
})
.def_property_readonly("tokens",
[](const PyClass &self) { return self.tokens; })
.def_property_readonly("words",
[](const PyClass &self) { return self.words; })
.def_property_readonly(
"timestamps", [](const PyClass &self) { return self.timestamps; });
}
... ...
... ... @@ -41,6 +41,9 @@ static void PybindOnlineRecognizerResult(py::module *m) {
.def_property_readonly(
"segment", [](PyClass &self) -> int32_t { return self.segment; })
.def_property_readonly(
"words",
[](PyClass &self) -> std::vector<int32_t> { return self.words; })
.def_property_readonly(
"is_final", [](PyClass &self) -> bool { return self.is_final; })
.def("__str__", &PyClass::AsJsonString,
py::call_guard<py::gil_scoped_release>())
... ...