Karel Vesely
Committed by GitHub

Track token scores (#571)

* add export of per-token scores (ys, lm, context)

- for best path of the modified-beam-search decoding of transducer

* refactoring JSON export of OnlineRecognitionResult, extending pybind11 API of OnlineRecognitionResult

* export per-token scores also for greedy-search (online-transducer)

- export un-scaled lm_probs (modified-beam search, online-transducer)
- polishing

* fill lm_probs/context_scores only if LM/ContextGraph is present (make Result smaller)
build
*.zip
*.tgz
*.sw?
onnxruntime-*
icefall-*
run.sh
... ...
... ... @@ -29,9 +29,21 @@ struct Hypothesis {
std::vector<int32_t> timestamps;
// The acoustic probability for each token in ys.
// Only used for keyword spotting task.
// Used for keyword spotting task.
// For transducer mofified beam-search and greedy-search,
// this is filled with log_posterior scores.
std::vector<float> ys_probs;
// lm_probs[i] contains the lm score for each token in ys.
// Used only in transducer mofified beam-search.
// Elements filled only if LM is used.
std::vector<float> lm_probs;
// context_scores[i] contains the context-graph score for each token in ys.
// Used only in transducer mofified beam-search.
// Elements filled only if `ContextGraph` is used.
std::vector<float> context_scores;
// The total score of ys in log space.
// It contains only acoustic scores
double log_prob = 0;
... ...
... ... @@ -69,6 +69,10 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
r.timestamps.push_back(time);
}
r.ys_probs = std::move(src.ys_probs);
r.lm_probs = std::move(src.lm_probs);
r.context_scores = std::move(src.context_scores);
r.segment = segment;
r.start_time = frames_since_start * frame_shift_ms / 1000.;
... ...
... ... @@ -18,56 +18,50 @@
namespace sherpa_onnx {
std::string OnlineRecognizerResult::AsJsonString() const {
std::ostringstream os;
os << "{";
os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";
os << "\"segment\":" << segment << ", ";
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
<< ", ";
os << "\"text\""
<< ": ";
os << "\"" << text << "\""
<< ", ";
os << "\""
<< "timestamps"
<< "\""
<< ": ";
os << "[";
/// Helper for `OnlineRecognizerResult::AsJsonString()`
template<typename T>
std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(precision);
oss << "[ ";
std::string sep = "";
for (auto t : timestamps) {
os << sep << std::fixed << std::setprecision(2) << t;
for (const auto& item : vec) {
oss << sep << item;
sep = ", ";
}
os << "], ";
os << "\""
<< "tokens"
<< "\""
<< ":";
os << "[";
sep = "";
auto oldFlags = os.flags();
for (const auto &t : tokens) {
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
os << sep << "\""
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
<< ">"
<< "\"";
os.flags(oldFlags);
} else {
os << sep << "\"" << t << "\"";
}
oss << " ]";
return oss.str();
}
/// Helper for `OnlineRecognizerResult::AsJsonString()`
template<> // explicit specialization for T = std::string
std::string VecToString<std::string>(const std::vector<std::string>& vec,
int32_t) { // ignore 2nd arg
std::ostringstream oss;
oss << "[ ";
std::string sep = "";
for (const auto& item : vec) {
oss << sep << "\"" << item << "\"";
sep = ", ";
}
os << "]";
os << "}";
oss << " ]";
return oss.str();
}
std::string OnlineRecognizerResult::AsJsonString() const {
std::ostringstream os;
os << "{ ";
os << "\"text\": " << "\"" << text << "\"" << ", ";
os << "\"tokens\": " << VecToString(tokens) << ", ";
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
os << "\"segment\": " << segment << ", ";
os << "\"start_time\": " << std::fixed << std::setprecision(2)
<< start_time << ", ";
os << "\"is_final\": " << (is_final ? "true" : "false");
os << "}";
return os.str();
}
... ...
... ... @@ -40,6 +40,12 @@ struct OnlineRecognizerResult {
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;
std::vector<float> ys_probs; //< log-prob scores from ASR model
std::vector<float> lm_probs; //< log-prob scores from language model
//
/// log-domain scores from "hot-phrase" contextual boosting
std::vector<float> context_scores;
/// ID of this segment
/// When an endpoint is detected, it is incremented
int32_t segment = 0;
... ... @@ -58,6 +64,9 @@ struct OnlineRecognizerResult {
* "text": "The recognition result",
* "tokens": [x, x, x],
* "timestamps": [x, x, x],
* "ys_probs": [x, x, x],
* "lm_probs": [x, x, x],
* "context_scores": [x, x, x],
* "segment": x,
* "start_time": x,
* "is_final": true|false
... ...
... ... @@ -37,6 +37,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
frame_offset = other.frame_offset;
timestamps = other.timestamps;
ys_probs = other.ys_probs;
lm_probs = other.lm_probs;
context_scores = other.context_scores;
return *this;
}
... ... @@ -60,6 +64,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
frame_offset = other.frame_offset;
timestamps = std::move(other.timestamps);
ys_probs = std::move(other.ys_probs);
lm_probs = std::move(other.lm_probs);
context_scores = std::move(other.context_scores);
return *this;
}
... ...
... ... @@ -26,6 +26,10 @@ struct OnlineTransducerDecoderResult {
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
std::vector<int32_t> timestamps;
std::vector<float> ys_probs;
std::vector<float> lm_probs;
std::vector<float> context_scores;
// Cache decoder_out for endpointing
Ort::Value decoder_out;
... ...
... ... @@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
r->tokens = std::vector<int64_t>(start, end);
}
void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) {
std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
... ... @@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
break;
}
}
if (is_batch_decoder_out_cached) {
auto &r = result->front();
std::vector<int64_t> decoder_out_shape =
... ... @@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
if (blank_penalty_ > 0.0) {
p_logit[0] -= blank_penalty_; // assuming blank id is 0
}
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
... ... @@ -138,6 +142,17 @@ void OnlineTransducerGreedySearchDecoder::Decode(
} else {
++r.num_trailing_blanks;
}
// export the per-token log scores
if (y != 0 && y != unk_id_) {
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
// save time by doing it only for
// emitted symbols
const float *p_logprob = p_logit; // rename p_logit as p_logprob,
// now it contains normalized
// probability
r.ys_probs.push_back(p_logprob[y]);
}
}
if (emitted) {
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
... ...
... ... @@ -59,6 +59,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
r->tokens = std::move(tokens);
r->timestamps = std::move(hyp.timestamps);
// export per-token scores
r->ys_probs = std::move(hyp.ys_probs);
r->lm_probs = std::move(hyp.lm_probs);
r->context_scores = std::move(hyp.context_scores);
r->num_trailing_blanks = hyp.num_trailing_blanks;
}
... ... @@ -180,6 +186,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
new_hyp.log_prob = p_logprob[k] + context_score -
prev_lm_log_prob; // log_prob only includes the
// score of the transducer
// export the per-token log scores
if (new_token != 0 && new_token != unk_id_) {
const Hypothesis& prev_i = prev[hyp_index];
// subtract 'prev[i]' path scores, which were added before
// getting topk tokens
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
new_hyp.ys_probs.push_back(y_prob);
if (lm_) { // export only when LM is used
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
if (lm_scale_ != 0.0) {
lm_prob /= lm_scale_; // remove lm-scale
}
new_hyp.lm_probs.push_back(lm_prob);
}
// export only when `ContextGraph` is used
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
new_hyp.context_scores.push_back(context_score);
}
}
hyps.Add(std::move(new_hyp));
} // for (auto k : topk)
cur.push_back(std::move(hyps));
... ...
... ... @@ -28,7 +28,26 @@ static void PybindOnlineRecognizerResult(py::module *m) {
[](PyClass &self) -> float { return self.start_time; })
.def_property_readonly(
"timestamps",
[](PyClass &self) -> std::vector<float> { return self.timestamps; });
[](PyClass &self) -> std::vector<float> { return self.timestamps; })
.def_property_readonly(
"ys_probs",
[](PyClass &self) -> std::vector<float> { return self.ys_probs; })
.def_property_readonly(
"lm_probs",
[](PyClass &self) -> std::vector<float> { return self.lm_probs; })
.def_property_readonly(
"context_scores",
[](PyClass &self) -> std::vector<float> {
return self.context_scores;
})
.def_property_readonly(
"segment",
[](PyClass &self) -> int32_t { return self.segment; })
.def_property_readonly(
"is_final",
[](PyClass &self) -> bool { return self.is_final; })
.def("as_json_string", &PyClass::AsJsonString,
py::call_guard<py::gil_scoped_release>());
}
static void PybindOnlineRecognizerConfig(py::module *m) {
... ...
... ... @@ -503,6 +503,9 @@ class OnlineRecognizer(object):
def get_result(self, s: OnlineStream) -> str:
return self.recognizer.get_result(s).text.strip()
def get_result_as_json_string(self, s: OnlineStream) -> str:
return self.recognizer.get_result(s).as_json_string()
def tokens(self, s: OnlineStream) -> List[str]:
return self.recognizer.get_result(s).tokens
... ... @@ -512,6 +515,15 @@ class OnlineRecognizer(object):
def start_time(self, s: OnlineStream) -> float:
return self.recognizer.get_result(s).start_time
def ys_probs(self, s: OnlineStream) -> List[float]:
return self.recognizer.get_result(s).ys_probs
def lm_probs(self, s: OnlineStream) -> List[float]:
return self.recognizer.get_result(s).lm_probs
def context_scores(self, s: OnlineStream) -> List[float]:
return self.recognizer.get_result(s).context_scores
def is_endpoint(self, s: OnlineStream) -> bool:
return self.recognizer.is_endpoint(s)
... ...