Committed by
GitHub
Track token scores (#571)
* add export of per-token scores (ys, lm, context) - for best path of the modified-beam-search decoding of transducer * refactoring JSON export of OnlineRecognitionResult, extending pybind11 API of OnlineRecognitionResult * export per-token scores also for greedy-search (online-transducer) - export un-scaled lm_probs (modified-beam search, online-transducer) - polishing * fill lm_probs/context_scores only if LM/ContextGraph is present (make Result smaller)
Showing
11 changed files
with
152 additions
and
46 deletions
| @@ -29,9 +29,21 @@ struct Hypothesis { | @@ -29,9 +29,21 @@ struct Hypothesis { | ||
| 29 | std::vector<int32_t> timestamps; | 29 | std::vector<int32_t> timestamps; |
| 30 | 30 | ||
| 31 | // The acoustic probability for each token in ys. | 31 | // The acoustic probability for each token in ys. |
| 32 | - // Only used for keyword spotting task. | 32 | + // Used for keyword spotting task. |
| 33 | + // For transducer modified beam-search and greedy-search, | ||
| 34 | + // this is filled with log_posterior scores. | ||
| 33 | std::vector<float> ys_probs; | 35 | std::vector<float> ys_probs; |
| 34 | 36 | ||
| 37 | + // lm_probs[i] contains the lm score for each token in ys. | ||
| 38 | + // Used only in transducer modified beam-search. | ||
| 39 | + // Elements filled only if LM is used. | ||
| 40 | + std::vector<float> lm_probs; | ||
| 41 | + | ||
| 42 | + // context_scores[i] contains the context-graph score for each token in ys. | ||
| 43 | + // Used only in transducer modified beam-search. | ||
| 44 | + // Elements filled only if `ContextGraph` is used. | ||
| 45 | + std::vector<float> context_scores; | ||
| 46 | + | ||
| 35 | // The total score of ys in log space. | 47 | // The total score of ys in log space. |
| 36 | // It contains only acoustic scores | 48 | // It contains only acoustic scores |
| 37 | double log_prob = 0; | 49 | double log_prob = 0; |
| @@ -69,6 +69,10 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | @@ -69,6 +69,10 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | ||
| 69 | r.timestamps.push_back(time); | 69 | r.timestamps.push_back(time); |
| 70 | } | 70 | } |
| 71 | 71 | ||
| 72 | + r.ys_probs = std::move(src.ys_probs); | ||
| 73 | + r.lm_probs = std::move(src.lm_probs); | ||
| 74 | + r.context_scores = std::move(src.context_scores); | ||
| 75 | + | ||
| 72 | r.segment = segment; | 76 | r.segment = segment; |
| 73 | r.start_time = frames_since_start * frame_shift_ms / 1000.; | 77 | r.start_time = frames_since_start * frame_shift_ms / 1000.; |
| 74 | 78 |
| @@ -18,56 +18,50 @@ | @@ -18,56 +18,50 @@ | ||
| 18 | 18 | ||
| 19 | namespace sherpa_onnx { | 19 | namespace sherpa_onnx { |
| 20 | 20 | ||
| 21 | -std::string OnlineRecognizerResult::AsJsonString() const { | ||
| 22 | - std::ostringstream os; | ||
| 23 | - os << "{"; | ||
| 24 | - os << "\"is_final\":" << (is_final ? "true" : "false") << ", "; | ||
| 25 | - os << "\"segment\":" << segment << ", "; | ||
| 26 | - os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time | ||
| 27 | - << ", "; | ||
| 28 | - | ||
| 29 | - os << "\"text\"" | ||
| 30 | - << ": "; | ||
| 31 | - os << "\"" << text << "\"" | ||
| 32 | - << ", "; | ||
| 33 | - | ||
| 34 | - os << "\"" | ||
| 35 | - << "timestamps" | ||
| 36 | - << "\"" | ||
| 37 | - << ": "; | ||
| 38 | - os << "["; | ||
| 39 | - | 21 | +/// Helper for `OnlineRecognizerResult::AsJsonString()` |
| 22 | +template<typename T> | ||
| 23 | +std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) { | ||
| 24 | + std::ostringstream oss; | ||
| 25 | + oss << std::fixed << std::setprecision(precision); | ||
| 26 | + oss << "[ "; | ||
| 40 | std::string sep = ""; | 27 | std::string sep = ""; |
| 41 | - for (auto t : timestamps) { | ||
| 42 | - os << sep << std::fixed << std::setprecision(2) << t; | 28 | + for (const auto& item : vec) { |
| 29 | + oss << sep << item; | ||
| 43 | sep = ", "; | 30 | sep = ", "; |
| 44 | } | 31 | } |
| 45 | - os << "], "; | ||
| 46 | - | ||
| 47 | - os << "\"" | ||
| 48 | - << "tokens" | ||
| 49 | - << "\"" | ||
| 50 | - << ":"; | ||
| 51 | - os << "["; | ||
| 52 | - | ||
| 53 | - sep = ""; | ||
| 54 | - auto oldFlags = os.flags(); | ||
| 55 | - for (const auto &t : tokens) { | ||
| 56 | - if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) { | ||
| 57 | - const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str()); | ||
| 58 | - os << sep << "\"" | ||
| 59 | - << "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0]) | ||
| 60 | - << ">" | ||
| 61 | - << "\""; | ||
| 62 | - os.flags(oldFlags); | ||
| 63 | - } else { | ||
| 64 | - os << sep << "\"" << t << "\""; | ||
| 65 | - } | 32 | + oss << " ]"; |
| 33 | + return oss.str(); | ||
| 34 | +} | ||
| 35 | + | ||
| 36 | +/// Helper for `OnlineRecognizerResult::AsJsonString()` | ||
| 37 | +template<> // explicit specialization for T = std::string | ||
| 38 | +std::string VecToString<std::string>(const std::vector<std::string>& vec, | ||
| 39 | + int32_t) { // ignore 2nd arg | ||
| 40 | + std::ostringstream oss; | ||
| 41 | + oss << "[ "; | ||
| 42 | + std::string sep = ""; | ||
| 43 | + for (const auto& item : vec) { | ||
| 44 | + oss << sep << "\"" << item << "\""; | ||
| 66 | sep = ", "; | 45 | sep = ", "; |
| 67 | } | 46 | } |
| 68 | - os << "]"; | ||
| 69 | - os << "}"; | 47 | + oss << " ]"; |
| 48 | + return oss.str(); | ||
| 49 | +} | ||
| 70 | 50 | ||
| 51 | +std::string OnlineRecognizerResult::AsJsonString() const { | ||
| 52 | + std::ostringstream os; | ||
| 53 | + os << "{ "; | ||
| 54 | + os << "\"text\": " << "\"" << text << "\"" << ", "; | ||
| 55 | + os << "\"tokens\": " << VecToString(tokens) << ", "; | ||
| 56 | + os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; | ||
| 57 | + os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; | ||
| 58 | + os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", "; | ||
| 59 | + os << "\"context_scores\": " << VecToString(context_scores, 6) << ", "; | ||
| 60 | + os << "\"segment\": " << segment << ", "; | ||
| 61 | + os << "\"start_time\": " << std::fixed << std::setprecision(2) | ||
| 62 | + << start_time << ", "; | ||
| 63 | + os << "\"is_final\": " << (is_final ? "true" : "false"); | ||
| 64 | + os << "}"; | ||
| 71 | return os.str(); | 65 | return os.str(); |
| 72 | } | 66 | } |
| 73 | 67 |
| @@ -40,6 +40,12 @@ struct OnlineRecognizerResult { | @@ -40,6 +40,12 @@ struct OnlineRecognizerResult { | ||
| 40 | /// timestamps[i] records the time in seconds when tokens[i] is decoded. | 40 | /// timestamps[i] records the time in seconds when tokens[i] is decoded. |
| 41 | std::vector<float> timestamps; | 41 | std::vector<float> timestamps; |
| 42 | 42 | ||
| 43 | + std::vector<float> ys_probs; //< log-prob scores from ASR model | ||
| 44 | + std::vector<float> lm_probs; //< log-prob scores from language model | ||
| 45 | + // | ||
| 46 | + /// log-domain scores from "hot-phrase" contextual boosting | ||
| 47 | + std::vector<float> context_scores; | ||
| 48 | + | ||
| 43 | /// ID of this segment | 49 | /// ID of this segment |
| 44 | /// When an endpoint is detected, it is incremented | 50 | /// When an endpoint is detected, it is incremented |
| 45 | int32_t segment = 0; | 51 | int32_t segment = 0; |
| @@ -58,6 +64,9 @@ struct OnlineRecognizerResult { | @@ -58,6 +64,9 @@ struct OnlineRecognizerResult { | ||
| 58 | * "text": "The recognition result", | 64 | * "text": "The recognition result", |
| 59 | * "tokens": [x, x, x], | 65 | * "tokens": [x, x, x], |
| 60 | * "timestamps": [x, x, x], | 66 | * "timestamps": [x, x, x], |
| 67 | + * "ys_probs": [x, x, x], | ||
| 68 | + * "lm_probs": [x, x, x], | ||
| 69 | + * "context_scores": [x, x, x], | ||
| 61 | * "segment": x, | 70 | * "segment": x, |
| 62 | * "start_time": x, | 71 | * "start_time": x, |
| 63 | * "is_final": true|false | 72 | * "is_final": true|false |
| @@ -37,6 +37,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( | @@ -37,6 +37,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( | ||
| 37 | frame_offset = other.frame_offset; | 37 | frame_offset = other.frame_offset; |
| 38 | timestamps = other.timestamps; | 38 | timestamps = other.timestamps; |
| 39 | 39 | ||
| 40 | + ys_probs = other.ys_probs; | ||
| 41 | + lm_probs = other.lm_probs; | ||
| 42 | + context_scores = other.context_scores; | ||
| 43 | + | ||
| 40 | return *this; | 44 | return *this; |
| 41 | } | 45 | } |
| 42 | 46 | ||
| @@ -60,6 +64,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( | @@ -60,6 +64,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( | ||
| 60 | frame_offset = other.frame_offset; | 64 | frame_offset = other.frame_offset; |
| 61 | timestamps = std::move(other.timestamps); | 65 | timestamps = std::move(other.timestamps); |
| 62 | 66 | ||
| 67 | + ys_probs = std::move(other.ys_probs); | ||
| 68 | + lm_probs = std::move(other.lm_probs); | ||
| 69 | + context_scores = std::move(other.context_scores); | ||
| 70 | + | ||
| 63 | return *this; | 71 | return *this; |
| 64 | } | 72 | } |
| 65 | 73 |
| @@ -26,6 +26,10 @@ struct OnlineTransducerDecoderResult { | @@ -26,6 +26,10 @@ struct OnlineTransducerDecoderResult { | ||
| 26 | /// timestamps[i] contains the output frame index where tokens[i] is decoded. | 26 | /// timestamps[i] contains the output frame index where tokens[i] is decoded. |
| 27 | std::vector<int32_t> timestamps; | 27 | std::vector<int32_t> timestamps; |
| 28 | 28 | ||
| 29 | + std::vector<float> ys_probs; | ||
| 30 | + std::vector<float> lm_probs; | ||
| 31 | + std::vector<float> context_scores; | ||
| 32 | + | ||
| 29 | // Cache decoder_out for endpointing | 33 | // Cache decoder_out for endpointing |
| 30 | Ort::Value decoder_out; | 34 | Ort::Value decoder_out; |
| 31 | 35 |
| @@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( | @@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( | ||
| 71 | r->tokens = std::vector<int64_t>(start, end); | 71 | r->tokens = std::vector<int64_t>(start, end); |
| 72 | } | 72 | } |
| 73 | 73 | ||
| 74 | + | ||
| 74 | void OnlineTransducerGreedySearchDecoder::Decode( | 75 | void OnlineTransducerGreedySearchDecoder::Decode( |
| 75 | Ort::Value encoder_out, | 76 | Ort::Value encoder_out, |
| 76 | std::vector<OnlineTransducerDecoderResult> *result) { | 77 | std::vector<OnlineTransducerDecoderResult> *result) { |
| 78 | + | ||
| 77 | std::vector<int64_t> encoder_out_shape = | 79 | std::vector<int64_t> encoder_out_shape = |
| 78 | encoder_out.GetTensorTypeAndShapeInfo().GetShape(); | 80 | encoder_out.GetTensorTypeAndShapeInfo().GetShape(); |
| 79 | 81 | ||
| @@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 97 | break; | 99 | break; |
| 98 | } | 100 | } |
| 99 | } | 101 | } |
| 102 | + | ||
| 100 | if (is_batch_decoder_out_cached) { | 103 | if (is_batch_decoder_out_cached) { |
| 101 | auto &r = result->front(); | 104 | auto &r = result->front(); |
| 102 | std::vector<int64_t> decoder_out_shape = | 105 | std::vector<int64_t> decoder_out_shape = |
| @@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 124 | if (blank_penalty_ > 0.0) { | 127 | if (blank_penalty_ > 0.0) { |
| 125 | p_logit[0] -= blank_penalty_; // assuming blank id is 0 | 128 | p_logit[0] -= blank_penalty_; // assuming blank id is 0 |
| 126 | } | 129 | } |
| 130 | + | ||
| 127 | auto y = static_cast<int32_t>(std::distance( | 131 | auto y = static_cast<int32_t>(std::distance( |
| 128 | static_cast<const float *>(p_logit), | 132 | static_cast<const float *>(p_logit), |
| 129 | std::max_element(static_cast<const float *>(p_logit), | 133 | std::max_element(static_cast<const float *>(p_logit), |
| @@ -138,6 +142,17 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -138,6 +142,17 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 138 | } else { | 142 | } else { |
| 139 | ++r.num_trailing_blanks; | 143 | ++r.num_trailing_blanks; |
| 140 | } | 144 | } |
| 145 | + | ||
| 146 | + // export the per-token log scores | ||
| 147 | + if (y != 0 && y != unk_id_) { | ||
| 148 | + LogSoftmax(p_logit, vocab_size); // renormalize probabilities, | ||
| 149 | + // save time by doing it only for | ||
| 150 | + // emitted symbols | ||
| 151 | + const float *p_logprob = p_logit; // rename p_logit as p_logprob, | ||
| 152 | + // now it contains normalized | ||
| 153 | + // probability | ||
| 154 | + r.ys_probs.push_back(p_logprob[y]); | ||
| 155 | + } | ||
| 141 | } | 156 | } |
| 142 | if (emitted) { | 157 | if (emitted) { |
| 143 | Ort::Value decoder_input = model_->BuildDecoderInput(*result); | 158 | Ort::Value decoder_input = model_->BuildDecoderInput(*result); |
| @@ -59,6 +59,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks( | @@ -59,6 +59,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks( | ||
| 59 | std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end()); | 59 | std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end()); |
| 60 | r->tokens = std::move(tokens); | 60 | r->tokens = std::move(tokens); |
| 61 | r->timestamps = std::move(hyp.timestamps); | 61 | r->timestamps = std::move(hyp.timestamps); |
| 62 | + | ||
| 63 | + // export per-token scores | ||
| 64 | + r->ys_probs = std::move(hyp.ys_probs); | ||
| 65 | + r->lm_probs = std::move(hyp.lm_probs); | ||
| 66 | + r->context_scores = std::move(hyp.context_scores); | ||
| 67 | + | ||
| 62 | r->num_trailing_blanks = hyp.num_trailing_blanks; | 68 | r->num_trailing_blanks = hyp.num_trailing_blanks; |
| 63 | } | 69 | } |
| 64 | 70 | ||
| @@ -180,6 +186,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -180,6 +186,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 180 | new_hyp.log_prob = p_logprob[k] + context_score - | 186 | new_hyp.log_prob = p_logprob[k] + context_score - |
| 181 | prev_lm_log_prob; // log_prob only includes the | 187 | prev_lm_log_prob; // log_prob only includes the |
| 182 | // score of the transducer | 188 | // score of the transducer |
| 189 | + // export the per-token log scores | ||
| 190 | + if (new_token != 0 && new_token != unk_id_) { | ||
| 191 | + const Hypothesis& prev_i = prev[hyp_index]; | ||
| 192 | + // subtract 'prev[i]' path scores, which were added before | ||
| 193 | + // getting topk tokens | ||
| 194 | + float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob; | ||
| 195 | + new_hyp.ys_probs.push_back(y_prob); | ||
| 196 | + | ||
| 197 | + if (lm_) { // export only when LM is used | ||
| 198 | + float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob; | ||
| 199 | + if (lm_scale_ != 0.0) { | ||
| 200 | + lm_prob /= lm_scale_; // remove lm-scale | ||
| 201 | + } | ||
| 202 | + new_hyp.lm_probs.push_back(lm_prob); | ||
| 203 | + } | ||
| 204 | + | ||
| 205 | + // export only when `ContextGraph` is used | ||
| 206 | + if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) { | ||
| 207 | + new_hyp.context_scores.push_back(context_score); | ||
| 208 | + } | ||
| 209 | + } | ||
| 210 | + | ||
| 183 | hyps.Add(std::move(new_hyp)); | 211 | hyps.Add(std::move(new_hyp)); |
| 184 | } // for (auto k : topk) | 212 | } // for (auto k : topk) |
| 185 | cur.push_back(std::move(hyps)); | 213 | cur.push_back(std::move(hyps)); |
| @@ -28,7 +28,26 @@ static void PybindOnlineRecognizerResult(py::module *m) { | @@ -28,7 +28,26 @@ static void PybindOnlineRecognizerResult(py::module *m) { | ||
| 28 | [](PyClass &self) -> float { return self.start_time; }) | 28 | [](PyClass &self) -> float { return self.start_time; }) |
| 29 | .def_property_readonly( | 29 | .def_property_readonly( |
| 30 | "timestamps", | 30 | "timestamps", |
| 31 | - [](PyClass &self) -> std::vector<float> { return self.timestamps; }); | 31 | + [](PyClass &self) -> std::vector<float> { return self.timestamps; }) |
| 32 | + .def_property_readonly( | ||
| 33 | + "ys_probs", | ||
| 34 | + [](PyClass &self) -> std::vector<float> { return self.ys_probs; }) | ||
| 35 | + .def_property_readonly( | ||
| 36 | + "lm_probs", | ||
| 37 | + [](PyClass &self) -> std::vector<float> { return self.lm_probs; }) | ||
| 38 | + .def_property_readonly( | ||
| 39 | + "context_scores", | ||
| 40 | + [](PyClass &self) -> std::vector<float> { | ||
| 41 | + return self.context_scores; | ||
| 42 | + }) | ||
| 43 | + .def_property_readonly( | ||
| 44 | + "segment", | ||
| 45 | + [](PyClass &self) -> int32_t { return self.segment; }) | ||
| 46 | + .def_property_readonly( | ||
| 47 | + "is_final", | ||
| 48 | + [](PyClass &self) -> bool { return self.is_final; }) | ||
| 49 | + .def("as_json_string", &PyClass::AsJsonString, | ||
| 50 | + py::call_guard<py::gil_scoped_release>()); | ||
| 32 | } | 51 | } |
| 33 | 52 | ||
| 34 | static void PybindOnlineRecognizerConfig(py::module *m) { | 53 | static void PybindOnlineRecognizerConfig(py::module *m) { |
| @@ -503,6 +503,9 @@ class OnlineRecognizer(object): | @@ -503,6 +503,9 @@ class OnlineRecognizer(object): | ||
| 503 | def get_result(self, s: OnlineStream) -> str: | 503 | def get_result(self, s: OnlineStream) -> str: |
| 504 | return self.recognizer.get_result(s).text.strip() | 504 | return self.recognizer.get_result(s).text.strip() |
| 505 | 505 | ||
| 506 | + def get_result_as_json_string(self, s: OnlineStream) -> str: | ||
| 507 | + return self.recognizer.get_result(s).as_json_string() | ||
| 508 | + | ||
| 506 | def tokens(self, s: OnlineStream) -> List[str]: | 509 | def tokens(self, s: OnlineStream) -> List[str]: |
| 507 | return self.recognizer.get_result(s).tokens | 510 | return self.recognizer.get_result(s).tokens |
| 508 | 511 | ||
| @@ -512,6 +515,15 @@ class OnlineRecognizer(object): | @@ -512,6 +515,15 @@ class OnlineRecognizer(object): | ||
| 512 | def start_time(self, s: OnlineStream) -> float: | 515 | def start_time(self, s: OnlineStream) -> float: |
| 513 | return self.recognizer.get_result(s).start_time | 516 | return self.recognizer.get_result(s).start_time |
| 514 | 517 | ||
| 518 | + def ys_probs(self, s: OnlineStream) -> List[float]: | ||
| 519 | + return self.recognizer.get_result(s).ys_probs | ||
| 520 | + | ||
| 521 | + def lm_probs(self, s: OnlineStream) -> List[float]: | ||
| 522 | + return self.recognizer.get_result(s).lm_probs | ||
| 523 | + | ||
| 524 | + def context_scores(self, s: OnlineStream) -> List[float]: | ||
| 525 | + return self.recognizer.get_result(s).context_scores | ||
| 526 | + | ||
| 515 | def is_endpoint(self, s: OnlineStream) -> bool: | 527 | def is_endpoint(self, s: OnlineStream) -> bool: |
| 516 | return self.recognizer.is_endpoint(s) | 528 | return self.recognizer.is_endpoint(s) |
| 517 | 529 |
-
Please register or sign in to comment