Karel Vesely
Committed by GitHub

Track token scores (#571)

* add export of per-token scores (ys, lm, context)

- for best path of the modified-beam-search decoding of transducer

* refactoring JSON export of OnlineRecognitionResult, extending pybind11 API of OnlineRecognitionResult

* export per-token scores also for greedy-search (online-transducer)

- export un-scaled lm_probs (modified-beam search, online-transducer)
- polishing

* fill lm_probs/context_scores only if LM/ContextGraph is present (make Result smaller)
1 build 1 build
2 *.zip 2 *.zip
3 *.tgz 3 *.tgz
  4 +*.sw?
4 onnxruntime-* 5 onnxruntime-*
5 icefall-* 6 icefall-*
6 run.sh 7 run.sh
@@ -29,9 +29,21 @@ struct Hypothesis { @@ -29,9 +29,21 @@ struct Hypothesis {
29 std::vector<int32_t> timestamps; 29 std::vector<int32_t> timestamps;
30 30
31 // The acoustic probability for each token in ys. 31 // The acoustic probability for each token in ys.
32 - // Only used for keyword spotting task. 32 + // Used for keyword spotting task.
  33 + // For transducer modified beam-search and greedy-search,
  34 + // this is filled with log_posterior scores.
33 std::vector<float> ys_probs; 35 std::vector<float> ys_probs;
34 36
  37 + // lm_probs[i] contains the lm score for each token in ys.
  38 + // Used only in transducer modified beam-search.
  39 + // Elements filled only if LM is used.
  40 + std::vector<float> lm_probs;
  41 +
  42 + // context_scores[i] contains the context-graph score for each token in ys.
  43 + // Used only in transducer modified beam-search.
  44 + // Elements filled only if `ContextGraph` is used.
  45 + std::vector<float> context_scores;
  46 +
35 // The total score of ys in log space. 47 // The total score of ys in log space.
36 // It contains only acoustic scores 48 // It contains only acoustic scores
37 double log_prob = 0; 49 double log_prob = 0;
@@ -69,6 +69,10 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, @@ -69,6 +69,10 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
69 r.timestamps.push_back(time); 69 r.timestamps.push_back(time);
70 } 70 }
71 71
  72 + r.ys_probs = std::move(src.ys_probs);
  73 + r.lm_probs = std::move(src.lm_probs);
  74 + r.context_scores = std::move(src.context_scores);
  75 +
72 r.segment = segment; 76 r.segment = segment;
73 r.start_time = frames_since_start * frame_shift_ms / 1000.; 77 r.start_time = frames_since_start * frame_shift_ms / 1000.;
74 78
@@ -18,56 +18,50 @@ @@ -18,56 +18,50 @@
18 18
19 namespace sherpa_onnx { 19 namespace sherpa_onnx {
20 20
21 -std::string OnlineRecognizerResult::AsJsonString() const {  
22 - std::ostringstream os;  
23 - os << "{";  
24 - os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";  
25 - os << "\"segment\":" << segment << ", ";  
26 - os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time  
27 - << ", ";  
28 -  
29 - os << "\"text\""  
30 - << ": ";  
31 - os << "\"" << text << "\""  
32 - << ", ";  
33 -  
34 - os << "\""  
35 - << "timestamps"  
36 - << "\""  
37 - << ": ";  
38 - os << "[";  
39 - 21 +/// Helper for `OnlineRecognizerResult::AsJsonString()`
  22 +template<typename T>
  23 +std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
  24 + std::ostringstream oss;
  25 + oss << std::fixed << std::setprecision(precision);
  26 + oss << "[ ";
40 std::string sep = ""; 27 std::string sep = "";
41 - for (auto t : timestamps) {  
42 - os << sep << std::fixed << std::setprecision(2) << t; 28 + for (const auto& item : vec) {
  29 + oss << sep << item;
43 sep = ", "; 30 sep = ", ";
44 } 31 }
45 - os << "], ";  
46 -  
47 - os << "\""  
48 - << "tokens"  
49 - << "\""  
50 - << ":";  
51 - os << "[";  
52 -  
53 - sep = "";  
54 - auto oldFlags = os.flags();  
55 - for (const auto &t : tokens) {  
56 - if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {  
57 - const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());  
58 - os << sep << "\""  
59 - << "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])  
60 - << ">"  
61 - << "\"";  
62 - os.flags(oldFlags);  
63 - } else {  
64 - os << sep << "\"" << t << "\"";  
65 - } 32 + oss << " ]";
  33 + return oss.str();
  34 +}
  35 +
  36 +/// Helper for `OnlineRecognizerResult::AsJsonString()`
  37 +template<> // explicit specialization for T = std::string
  38 +std::string VecToString<std::string>(const std::vector<std::string>& vec,
  39 + int32_t) { // ignore 2nd arg
  40 + std::ostringstream oss;
  41 + oss << "[ ";
  42 + std::string sep = "";
  43 + for (const auto& item : vec) {
  44 + oss << sep << "\"" << item << "\"";
66 sep = ", "; 45 sep = ", ";
67 } 46 }
68 - os << "]";  
69 - os << "}"; 47 + oss << " ]";
  48 + return oss.str();
  49 +}
70 50
  51 +std::string OnlineRecognizerResult::AsJsonString() const {
  52 + std::ostringstream os;
  53 + os << "{ ";
  54 + os << "\"text\": " << "\"" << text << "\"" << ", ";
  55 + os << "\"tokens\": " << VecToString(tokens) << ", ";
  56 + os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
  57 + os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
  58 + os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
  59 + os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
  60 + os << "\"segment\": " << segment << ", ";
  61 + os << "\"start_time\": " << std::fixed << std::setprecision(2)
  62 + << start_time << ", ";
  63 + os << "\"is_final\": " << (is_final ? "true" : "false");
  64 + os << "}";
71 return os.str(); 65 return os.str();
72 } 66 }
73 67
@@ -40,6 +40,12 @@ struct OnlineRecognizerResult { @@ -40,6 +40,12 @@ struct OnlineRecognizerResult {
40 /// timestamps[i] records the time in seconds when tokens[i] is decoded. 40 /// timestamps[i] records the time in seconds when tokens[i] is decoded.
41 std::vector<float> timestamps; 41 std::vector<float> timestamps;
42 42
  43 + std::vector<float> ys_probs; ///< log-prob scores from ASR model
  44 + std::vector<float> lm_probs; ///< log-prob scores from language model
  45 + //
  46 + /// log-domain scores from "hot-phrase" contextual boosting
  47 + std::vector<float> context_scores;
  48 +
43 /// ID of this segment 49 /// ID of this segment
44 /// When an endpoint is detected, it is incremented 50 /// When an endpoint is detected, it is incremented
45 int32_t segment = 0; 51 int32_t segment = 0;
@@ -58,6 +64,9 @@ struct OnlineRecognizerResult { @@ -58,6 +64,9 @@ struct OnlineRecognizerResult {
58 * "text": "The recognition result", 64 * "text": "The recognition result",
59 * "tokens": [x, x, x], 65 * "tokens": [x, x, x],
60 * "timestamps": [x, x, x], 66 * "timestamps": [x, x, x],
  67 + * "ys_probs": [x, x, x],
  68 + * "lm_probs": [x, x, x],
  69 + * "context_scores": [x, x, x],
61 * "segment": x, 70 * "segment": x,
62 * "start_time": x, 71 * "start_time": x,
63 * "is_final": true|false 72 * "is_final": true|false
@@ -37,6 +37,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( @@ -37,6 +37,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
37 frame_offset = other.frame_offset; 37 frame_offset = other.frame_offset;
38 timestamps = other.timestamps; 38 timestamps = other.timestamps;
39 39
  40 + ys_probs = other.ys_probs;
  41 + lm_probs = other.lm_probs;
  42 + context_scores = other.context_scores;
  43 +
40 return *this; 44 return *this;
41 } 45 }
42 46
@@ -60,6 +64,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( @@ -60,6 +64,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
60 frame_offset = other.frame_offset; 64 frame_offset = other.frame_offset;
61 timestamps = std::move(other.timestamps); 65 timestamps = std::move(other.timestamps);
62 66
  67 + ys_probs = std::move(other.ys_probs);
  68 + lm_probs = std::move(other.lm_probs);
  69 + context_scores = std::move(other.context_scores);
  70 +
63 return *this; 71 return *this;
64 } 72 }
65 73
@@ -26,6 +26,10 @@ struct OnlineTransducerDecoderResult { @@ -26,6 +26,10 @@ struct OnlineTransducerDecoderResult {
26 /// timestamps[i] contains the output frame index where tokens[i] is decoded. 26 /// timestamps[i] contains the output frame index where tokens[i] is decoded.
27 std::vector<int32_t> timestamps; 27 std::vector<int32_t> timestamps;
28 28
  29 + std::vector<float> ys_probs;
  30 + std::vector<float> lm_probs;
  31 + std::vector<float> context_scores;
  32 +
29 // Cache decoder_out for endpointing 33 // Cache decoder_out for endpointing
30 Ort::Value decoder_out; 34 Ort::Value decoder_out;
31 35
@@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( @@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
71 r->tokens = std::vector<int64_t>(start, end); 71 r->tokens = std::vector<int64_t>(start, end);
72 } 72 }
73 73
  74 +
74 void OnlineTransducerGreedySearchDecoder::Decode( 75 void OnlineTransducerGreedySearchDecoder::Decode(
75 Ort::Value encoder_out, 76 Ort::Value encoder_out,
76 std::vector<OnlineTransducerDecoderResult> *result) { 77 std::vector<OnlineTransducerDecoderResult> *result) {
  78 +
77 std::vector<int64_t> encoder_out_shape = 79 std::vector<int64_t> encoder_out_shape =
78 encoder_out.GetTensorTypeAndShapeInfo().GetShape(); 80 encoder_out.GetTensorTypeAndShapeInfo().GetShape();
79 81
@@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
97 break; 99 break;
98 } 100 }
99 } 101 }
  102 +
100 if (is_batch_decoder_out_cached) { 103 if (is_batch_decoder_out_cached) {
101 auto &r = result->front(); 104 auto &r = result->front();
102 std::vector<int64_t> decoder_out_shape = 105 std::vector<int64_t> decoder_out_shape =
@@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
124 if (blank_penalty_ > 0.0) { 127 if (blank_penalty_ > 0.0) {
125 p_logit[0] -= blank_penalty_; // assuming blank id is 0 128 p_logit[0] -= blank_penalty_; // assuming blank id is 0
126 } 129 }
  130 +
127 auto y = static_cast<int32_t>(std::distance( 131 auto y = static_cast<int32_t>(std::distance(
128 static_cast<const float *>(p_logit), 132 static_cast<const float *>(p_logit),
129 std::max_element(static_cast<const float *>(p_logit), 133 std::max_element(static_cast<const float *>(p_logit),
@@ -138,6 +142,17 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -138,6 +142,17 @@ void OnlineTransducerGreedySearchDecoder::Decode(
138 } else { 142 } else {
139 ++r.num_trailing_blanks; 143 ++r.num_trailing_blanks;
140 } 144 }
  145 +
  146 + // export the per-token log scores
  147 + if (y != 0 && y != unk_id_) {
  148 + LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
  149 + // save time by doing it only for
  150 + // emitted symbols
  151 + const float *p_logprob = p_logit; // rename p_logit as p_logprob,
  152 + // now it contains normalized
  153 + // log-probability
  154 + r.ys_probs.push_back(p_logprob[y]);
  155 + }
141 } 156 }
142 if (emitted) { 157 if (emitted) {
143 Ort::Value decoder_input = model_->BuildDecoderInput(*result); 158 Ort::Value decoder_input = model_->BuildDecoderInput(*result);
@@ -59,6 +59,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks( @@ -59,6 +59,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
59 std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end()); 59 std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
60 r->tokens = std::move(tokens); 60 r->tokens = std::move(tokens);
61 r->timestamps = std::move(hyp.timestamps); 61 r->timestamps = std::move(hyp.timestamps);
  62 +
  63 + // export per-token scores
  64 + r->ys_probs = std::move(hyp.ys_probs);
  65 + r->lm_probs = std::move(hyp.lm_probs);
  66 + r->context_scores = std::move(hyp.context_scores);
  67 +
62 r->num_trailing_blanks = hyp.num_trailing_blanks; 68 r->num_trailing_blanks = hyp.num_trailing_blanks;
63 } 69 }
64 70
@@ -180,6 +186,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -180,6 +186,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
180 new_hyp.log_prob = p_logprob[k] + context_score - 186 new_hyp.log_prob = p_logprob[k] + context_score -
181 prev_lm_log_prob; // log_prob only includes the 187 prev_lm_log_prob; // log_prob only includes the
182 // score of the transducer 188 // score of the transducer
  189 + // export the per-token log scores
  190 + if (new_token != 0 && new_token != unk_id_) {
  191 + const Hypothesis& prev_i = prev[hyp_index];
  192 + // subtract 'prev[i]' path scores, which were added before
  193 + // getting topk tokens
  194 + float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
  195 + new_hyp.ys_probs.push_back(y_prob);
  196 +
  197 + if (lm_) { // export only when LM is used
  198 + float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
  199 + if (lm_scale_ != 0.0) {
  200 + lm_prob /= lm_scale_; // remove lm-scale
  201 + }
  202 + new_hyp.lm_probs.push_back(lm_prob);
  203 + }
  204 +
  205 + // export only when `ContextGraph` is used
  206 + if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
  207 + new_hyp.context_scores.push_back(context_score);
  208 + }
  209 + }
  210 +
183 hyps.Add(std::move(new_hyp)); 211 hyps.Add(std::move(new_hyp));
184 } // for (auto k : topk) 212 } // for (auto k : topk)
185 cur.push_back(std::move(hyps)); 213 cur.push_back(std::move(hyps));
@@ -28,7 +28,26 @@ static void PybindOnlineRecognizerResult(py::module *m) { @@ -28,7 +28,26 @@ static void PybindOnlineRecognizerResult(py::module *m) {
28 [](PyClass &self) -> float { return self.start_time; }) 28 [](PyClass &self) -> float { return self.start_time; })
29 .def_property_readonly( 29 .def_property_readonly(
30 "timestamps", 30 "timestamps",
31 - [](PyClass &self) -> std::vector<float> { return self.timestamps; }); 31 + [](PyClass &self) -> std::vector<float> { return self.timestamps; })
  32 + .def_property_readonly(
  33 + "ys_probs",
  34 + [](PyClass &self) -> std::vector<float> { return self.ys_probs; })
  35 + .def_property_readonly(
  36 + "lm_probs",
  37 + [](PyClass &self) -> std::vector<float> { return self.lm_probs; })
  38 + .def_property_readonly(
  39 + "context_scores",
  40 + [](PyClass &self) -> std::vector<float> {
  41 + return self.context_scores;
  42 + })
  43 + .def_property_readonly(
  44 + "segment",
  45 + [](PyClass &self) -> int32_t { return self.segment; })
  46 + .def_property_readonly(
  47 + "is_final",
  48 + [](PyClass &self) -> bool { return self.is_final; })
  49 + .def("as_json_string", &PyClass::AsJsonString,
  50 + py::call_guard<py::gil_scoped_release>());
32 } 51 }
33 52
34 static void PybindOnlineRecognizerConfig(py::module *m) { 53 static void PybindOnlineRecognizerConfig(py::module *m) {
@@ -503,6 +503,9 @@ class OnlineRecognizer(object): @@ -503,6 +503,9 @@ class OnlineRecognizer(object):
503 def get_result(self, s: OnlineStream) -> str: 503 def get_result(self, s: OnlineStream) -> str:
504 return self.recognizer.get_result(s).text.strip() 504 return self.recognizer.get_result(s).text.strip()
505 505
  506 + def get_result_as_json_string(self, s: OnlineStream) -> str:
  507 + return self.recognizer.get_result(s).as_json_string()
  508 +
506 def tokens(self, s: OnlineStream) -> List[str]: 509 def tokens(self, s: OnlineStream) -> List[str]:
507 return self.recognizer.get_result(s).tokens 510 return self.recognizer.get_result(s).tokens
508 511
@@ -512,6 +515,15 @@ class OnlineRecognizer(object): @@ -512,6 +515,15 @@ class OnlineRecognizer(object):
512 def start_time(self, s: OnlineStream) -> float: 515 def start_time(self, s: OnlineStream) -> float:
513 return self.recognizer.get_result(s).start_time 516 return self.recognizer.get_result(s).start_time
514 517
  518 + def ys_probs(self, s: OnlineStream) -> List[float]:
  519 + return self.recognizer.get_result(s).ys_probs
  520 +
  521 + def lm_probs(self, s: OnlineStream) -> List[float]:
  522 + return self.recognizer.get_result(s).lm_probs
  523 +
  524 + def context_scores(self, s: OnlineStream) -> List[float]:
  525 + return self.recognizer.get_result(s).context_scores
  526 +
515 def is_endpoint(self, s: OnlineStream) -> bool: 527 def is_endpoint(self, s: OnlineStream) -> bool:
516 return self.recognizer.is_endpoint(s) 528 return self.recognizer.is_endpoint(s)
517 529