Committed by
GitHub
Support getting word IDs for CTC HLG decoding. (#978)
正在显示
13 个修改的文件
包含
59 行增加
和
12 行删除
| @@ -15,8 +15,16 @@ struct OfflineCtcDecoderResult { | @@ -15,8 +15,16 @@ struct OfflineCtcDecoderResult { | ||
| 15 | /// The decoded token IDs | 15 | /// The decoded token IDs |
| 16 | std::vector<int64_t> tokens; | 16 | std::vector<int64_t> tokens; |
| 17 | 17 | ||
| 18 | + /// The decoded word IDs | ||
| 19 | + /// Note: tokens.size() is usually not equal to words.size() | ||
| 20 | + /// words is empty for greedy search decoding. | ||
| 21 | + /// It is not empty when an HLG graph is used. | ||
| 22 | + std::vector<int32_t> words; | ||
| 23 | + | ||
| 18 | /// timestamps[i] contains the output frame index where tokens[i] is decoded. | 24 | /// timestamps[i] contains the output frame index where tokens[i] is decoded. |
| 19 | /// Note: The index is after subsampling | 25 | /// Note: The index is after subsampling |
| 26 | + /// | ||
| 27 | + /// tokens.size() == timestamps.size() | ||
| 20 | std::vector<int32_t> timestamps; | 28 | std::vector<int32_t> timestamps; |
| 21 | }; | 29 | }; |
| 22 | 30 |
| @@ -108,6 +108,9 @@ static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder, | @@ -108,6 +108,9 @@ static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder, | ||
| 108 | // -1 here since the input labels are incremented during graph | 108 | // -1 here since the input labels are incremented during graph |
| 109 | // construction | 109 | // construction |
| 110 | r.tokens.push_back(arc.ilabel - 1); | 110 | r.tokens.push_back(arc.ilabel - 1); |
| 111 | + if (arc.olabel != 0) { | ||
| 112 | + r.words.push_back(arc.olabel); | ||
| 113 | + } | ||
| 111 | 114 | ||
| 112 | r.timestamps.push_back(t); | 115 | r.timestamps.push_back(t); |
| 113 | prev = arc.ilabel; | 116 | prev = arc.ilabel; |
| @@ -64,10 +64,6 @@ OfflineParaformerGreedySearchDecoder::Decode( | @@ -64,10 +64,6 @@ OfflineParaformerGreedySearchDecoder::Decode( | ||
| 64 | 64 | ||
| 65 | if (timestamps.size() == results[i].tokens.size()) { | 65 | if (timestamps.size() == results[i].tokens.size()) { |
| 66 | results[i].timestamps = std::move(timestamps); | 66 | results[i].timestamps = std::move(timestamps); |
| 67 | - } else { | ||
| 68 | - SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i, | ||
| 69 | - static_cast<int32_t>(results[i].tokens.size()), | ||
| 70 | - static_cast<int32_t>(timestamps.size())); | ||
| 71 | } | 67 | } |
| 72 | } | 68 | } |
| 73 | } | 69 | } |
| @@ -65,6 +65,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | @@ -65,6 +65,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | ||
| 65 | r.timestamps.push_back(time); | 65 | r.timestamps.push_back(time); |
| 66 | } | 66 | } |
| 67 | 67 | ||
| 68 | + r.words = std::move(src.words); | ||
| 69 | + | ||
| 68 | return r; | 70 | return r; |
| 69 | } | 71 | } |
| 70 | 72 |
| @@ -339,6 +339,20 @@ std::string OfflineRecognitionResult::AsJsonString() const { | @@ -339,6 +339,20 @@ std::string OfflineRecognitionResult::AsJsonString() const { | ||
| 339 | } | 339 | } |
| 340 | sep = ", "; | 340 | sep = ", "; |
| 341 | } | 341 | } |
| 342 | + os << "], "; | ||
| 343 | + | ||
| 344 | + sep = ""; | ||
| 345 | + | ||
| 346 | + os << "\"" | ||
| 347 | + << "words" | ||
| 348 | + << "\"" | ||
| 349 | + << ": "; | ||
| 350 | + os << "["; | ||
| 351 | + for (int32_t w : words) { | ||
| 352 | + os << sep << w; | ||
| 353 | + sep = ", "; | ||
| 354 | + } | ||
| 355 | + | ||
| 342 | os << "]"; | 356 | os << "]"; |
| 343 | os << "}"; | 357 | os << "}"; |
| 344 | 358 |
| @@ -30,6 +30,8 @@ struct OfflineRecognitionResult { | @@ -30,6 +30,8 @@ struct OfflineRecognitionResult { | ||
| 30 | /// timestamps[i] records the time in seconds when tokens[i] is decoded. | 30 | /// timestamps[i] records the time in seconds when tokens[i] is decoded. |
| 31 | std::vector<float> timestamps; | 31 | std::vector<float> timestamps; |
| 32 | 32 | ||
| 33 | + std::vector<int32_t> words; | ||
| 34 | + | ||
| 33 | std::string AsJsonString() const; | 35 | std::string AsJsonString() const; |
| 34 | }; | 36 | }; |
| 35 | 37 |
| @@ -22,8 +22,16 @@ struct OnlineCtcDecoderResult { | @@ -22,8 +22,16 @@ struct OnlineCtcDecoderResult { | ||
| 22 | /// The decoded token IDs | 22 | /// The decoded token IDs |
| 23 | std::vector<int64_t> tokens; | 23 | std::vector<int64_t> tokens; |
| 24 | 24 | ||
| 25 | + /// The decoded word IDs | ||
| 26 | + /// Note: tokens.size() is usually not equal to words.size() | ||
| 27 | + /// words is empty for greedy search decoding. | ||
| 28 | + /// It is not empty when an HLG graph is used. | ||
| 29 | + std::vector<int32_t> words; | ||
| 30 | + | ||
| 25 | /// timestamps[i] contains the output frame index where tokens[i] is decoded. | 31 | /// timestamps[i] contains the output frame index where tokens[i] is decoded. |
| 26 | /// Note: The index is after subsampling | 32 | /// Note: The index is after subsampling |
| 33 | + /// | ||
| 34 | + /// tokens.size() == timestamps.size() | ||
| 27 | std::vector<int32_t> timestamps; | 35 | std::vector<int32_t> timestamps; |
| 28 | 36 | ||
| 29 | int32_t num_trailing_blanks = 0; | 37 | int32_t num_trailing_blanks = 0; |
| @@ -51,9 +51,9 @@ static void DecodeOne(const float *log_probs, int32_t num_rows, | @@ -51,9 +51,9 @@ static void DecodeOne(const float *log_probs, int32_t num_rows, | ||
| 51 | bool ok = decoder->GetBestPath(&fst_out); | 51 | bool ok = decoder->GetBestPath(&fst_out); |
| 52 | if (ok) { | 52 | if (ok) { |
| 53 | std::vector<int32_t> isymbols_out; | 53 | std::vector<int32_t> isymbols_out; |
| 54 | - std::vector<int32_t> osymbols_out_unused; | ||
| 55 | - ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, | ||
| 56 | - &osymbols_out_unused, nullptr); | 54 | + std::vector<int32_t> osymbols_out; |
| 55 | + ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, &osymbols_out, | ||
| 56 | + nullptr); | ||
| 57 | std::vector<int64_t> tokens; | 57 | std::vector<int64_t> tokens; |
| 58 | tokens.reserve(isymbols_out.size()); | 58 | tokens.reserve(isymbols_out.size()); |
| 59 | 59 | ||
| @@ -83,6 +83,7 @@ static void DecodeOne(const float *log_probs, int32_t num_rows, | @@ -83,6 +83,7 @@ static void DecodeOne(const float *log_probs, int32_t num_rows, | ||
| 83 | } | 83 | } |
| 84 | 84 | ||
| 85 | result->tokens = std::move(tokens); | 85 | result->tokens = std::move(tokens); |
| 86 | + result->words = std::move(osymbols_out); | ||
| 86 | result->timestamps = std::move(timestamps); | 87 | result->timestamps = std::move(timestamps); |
| 87 | // no need to set frame_offset | 88 | // no need to set frame_offset |
| 88 | } | 89 | } |
| @@ -59,6 +59,7 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, | @@ -59,6 +59,7 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, | ||
| 59 | } | 59 | } |
| 60 | 60 | ||
| 61 | r.segment = segment; | 61 | r.segment = segment; |
| 62 | + r.words = std::move(src.words); | ||
| 62 | r.start_time = frames_since_start * frame_shift_ms / 1000.; | 63 | r.start_time = frames_since_start * frame_shift_ms / 1000.; |
| 63 | 64 | ||
| 64 | return r; | 65 | return r; |
| @@ -22,14 +22,16 @@ namespace sherpa_onnx { | @@ -22,14 +22,16 @@ namespace sherpa_onnx { | ||
| 22 | template <typename T> | 22 | template <typename T> |
| 23 | std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) { | 23 | std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) { |
| 24 | std::ostringstream oss; | 24 | std::ostringstream oss; |
| 25 | + if (precision != 0) { | ||
| 25 | oss << std::fixed << std::setprecision(precision); | 26 | oss << std::fixed << std::setprecision(precision); |
| 26 | - oss << "[ "; | 27 | + } |
| 28 | + oss << "["; | ||
| 27 | std::string sep = ""; | 29 | std::string sep = ""; |
| 28 | for (const auto &item : vec) { | 30 | for (const auto &item : vec) { |
| 29 | oss << sep << item; | 31 | oss << sep << item; |
| 30 | sep = ", "; | 32 | sep = ", "; |
| 31 | } | 33 | } |
| 32 | - oss << " ]"; | 34 | + oss << "]"; |
| 33 | return oss.str(); | 35 | return oss.str(); |
| 34 | } | 36 | } |
| 35 | 37 | ||
| @@ -38,26 +40,29 @@ template <> // explicit specialization for T = std::string | @@ -38,26 +40,29 @@ template <> // explicit specialization for T = std::string | ||
| 38 | std::string VecToString<std::string>(const std::vector<std::string> &vec, | 40 | std::string VecToString<std::string>(const std::vector<std::string> &vec, |
| 39 | int32_t) { // ignore 2nd arg | 41 | int32_t) { // ignore 2nd arg |
| 40 | std::ostringstream oss; | 42 | std::ostringstream oss; |
| 41 | - oss << "[ "; | 43 | + oss << "["; |
| 42 | std::string sep = ""; | 44 | std::string sep = ""; |
| 43 | for (const auto &item : vec) { | 45 | for (const auto &item : vec) { |
| 44 | oss << sep << "\"" << item << "\""; | 46 | oss << sep << "\"" << item << "\""; |
| 45 | sep = ", "; | 47 | sep = ", "; |
| 46 | } | 48 | } |
| 47 | - oss << " ]"; | 49 | + oss << "]"; |
| 48 | return oss.str(); | 50 | return oss.str(); |
| 49 | } | 51 | } |
| 50 | 52 | ||
| 51 | std::string OnlineRecognizerResult::AsJsonString() const { | 53 | std::string OnlineRecognizerResult::AsJsonString() const { |
| 52 | std::ostringstream os; | 54 | std::ostringstream os; |
| 53 | os << "{ "; | 55 | os << "{ "; |
| 54 | - os << "\"text\": " << "\"" << text << "\"" << ", "; | 56 | + os << "\"text\": " |
| 57 | + << "\"" << text << "\"" | ||
| 58 | + << ", "; | ||
| 55 | os << "\"tokens\": " << VecToString(tokens) << ", "; | 59 | os << "\"tokens\": " << VecToString(tokens) << ", "; |
| 56 | os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; | 60 | os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; |
| 57 | os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; | 61 | os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; |
| 58 | os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", "; | 62 | os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", "; |
| 59 | os << "\"context_scores\": " << VecToString(context_scores, 6) << ", "; | 63 | os << "\"context_scores\": " << VecToString(context_scores, 6) << ", "; |
| 60 | os << "\"segment\": " << segment << ", "; | 64 | os << "\"segment\": " << segment << ", "; |
| 65 | + os << "\"words\": " << VecToString(words, 0) << ", "; | ||
| 61 | os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time | 66 | os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time |
| 62 | << ", "; | 67 | << ", "; |
| 63 | os << "\"is_final\": " << (is_final ? "true" : "false"); | 68 | os << "\"is_final\": " << (is_final ? "true" : "false"); |
| @@ -47,6 +47,8 @@ struct OnlineRecognizerResult { | @@ -47,6 +47,8 @@ struct OnlineRecognizerResult { | ||
| 47 | /// log-domain scores from "hot-phrase" contextual boosting | 47 | /// log-domain scores from "hot-phrase" contextual boosting |
| 48 | std::vector<float> context_scores; | 48 | std::vector<float> context_scores; |
| 49 | 49 | ||
| 50 | + std::vector<int32_t> words; | ||
| 51 | + | ||
| 50 | /// ID of this segment | 52 | /// ID of this segment |
| 51 | /// When an endpoint is detected, it is incremented | 53 | /// When an endpoint is detected, it is incremented |
| 52 | int32_t segment = 0; | 54 | int32_t segment = 0; |
| @@ -34,6 +34,8 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT | @@ -34,6 +34,8 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT | ||
| 34 | }) | 34 | }) |
| 35 | .def_property_readonly("tokens", | 35 | .def_property_readonly("tokens", |
| 36 | [](const PyClass &self) { return self.tokens; }) | 36 | [](const PyClass &self) { return self.tokens; }) |
| 37 | + .def_property_readonly("words", | ||
| 38 | + [](const PyClass &self) { return self.words; }) | ||
| 37 | .def_property_readonly( | 39 | .def_property_readonly( |
| 38 | "timestamps", [](const PyClass &self) { return self.timestamps; }); | 40 | "timestamps", [](const PyClass &self) { return self.timestamps; }); |
| 39 | } | 41 | } |
| @@ -41,6 +41,9 @@ static void PybindOnlineRecognizerResult(py::module *m) { | @@ -41,6 +41,9 @@ static void PybindOnlineRecognizerResult(py::module *m) { | ||
| 41 | .def_property_readonly( | 41 | .def_property_readonly( |
| 42 | "segment", [](PyClass &self) -> int32_t { return self.segment; }) | 42 | "segment", [](PyClass &self) -> int32_t { return self.segment; }) |
| 43 | .def_property_readonly( | 43 | .def_property_readonly( |
| 44 | + "words", | ||
| 45 | + [](PyClass &self) -> std::vector<int32_t> { return self.words; }) | ||
| 46 | + .def_property_readonly( | ||
| 44 | "is_final", [](PyClass &self) -> bool { return self.is_final; }) | 47 | "is_final", [](PyClass &self) -> bool { return self.is_final; }) |
| 45 | .def("__str__", &PyClass::AsJsonString, | 48 | .def("__str__", &PyClass::AsJsonString, |
| 46 | py::call_guard<py::gil_scoped_release>()) | 49 | py::call_guard<py::gil_scoped_release>()) |
-
请 注册 或 登录 后发表评论