Fangjun Kuang
Committed by GitHub

Support getting word IDs for CTC HLG decoding. (#978)

@@ -15,8 +15,16 @@ struct OfflineCtcDecoderResult { @@ -15,8 +15,16 @@ struct OfflineCtcDecoderResult {
15 /// The decoded token IDs 15 /// The decoded token IDs
16 std::vector<int64_t> tokens; 16 std::vector<int64_t> tokens;
17 17
  18 + /// The decoded word IDs
  19 + /// Note: tokens.size() is usually not equal to words.size()
  20 + /// words is empty for greedy search decoding.
  21 + /// it is not empty when an HLG graph or an HLG graph is used.
  22 + std::vector<int32_t> words;
  23 +
18 /// timestamps[i] contains the output frame index where tokens[i] is decoded. 24 /// timestamps[i] contains the output frame index where tokens[i] is decoded.
19 /// Note: The index is after subsampling 25 /// Note: The index is after subsampling
  26 + ///
  27 + /// tokens.size() == timestamps.size()
20 std::vector<int32_t> timestamps; 28 std::vector<int32_t> timestamps;
21 }; 29 };
22 30
@@ -108,6 +108,9 @@ static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder, @@ -108,6 +108,9 @@ static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder,
108 // -1 here since the input labels are incremented during graph 108 // -1 here since the input labels are incremented during graph
109 // construction 109 // construction
110 r.tokens.push_back(arc.ilabel - 1); 110 r.tokens.push_back(arc.ilabel - 1);
  111 + if (arc.olabel != 0) {
  112 + r.words.push_back(arc.olabel);
  113 + }
111 114
112 r.timestamps.push_back(t); 115 r.timestamps.push_back(t);
113 prev = arc.ilabel; 116 prev = arc.ilabel;
@@ -64,10 +64,6 @@ OfflineParaformerGreedySearchDecoder::Decode( @@ -64,10 +64,6 @@ OfflineParaformerGreedySearchDecoder::Decode(
64 64
65 if (timestamps.size() == results[i].tokens.size()) { 65 if (timestamps.size() == results[i].tokens.size()) {
66 results[i].timestamps = std::move(timestamps); 66 results[i].timestamps = std::move(timestamps);
67 - } else {  
68 - SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i,  
69 - static_cast<int32_t>(results[i].tokens.size()),  
70 - static_cast<int32_t>(timestamps.size()));  
71 } 67 }
72 } 68 }
73 } 69 }
@@ -65,6 +65,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, @@ -65,6 +65,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
65 r.timestamps.push_back(time); 65 r.timestamps.push_back(time);
66 } 66 }
67 67
  68 + r.words = std::move(src.words);
  69 +
68 return r; 70 return r;
69 } 71 }
70 72
@@ -339,6 +339,20 @@ std::string OfflineRecognitionResult::AsJsonString() const { @@ -339,6 +339,20 @@ std::string OfflineRecognitionResult::AsJsonString() const {
339 } 339 }
340 sep = ", "; 340 sep = ", ";
341 } 341 }
  342 + os << "], ";
  343 +
  344 + sep = "";
  345 +
  346 + os << "\""
  347 + << "words"
  348 + << "\""
  349 + << ": ";
  350 + os << "[";
  351 + for (int32_t w : words) {
  352 + os << sep << w;
  353 + sep = ", ";
  354 + }
  355 +
342 os << "]"; 356 os << "]";
343 os << "}"; 357 os << "}";
344 358
@@ -30,6 +30,8 @@ struct OfflineRecognitionResult { @@ -30,6 +30,8 @@ struct OfflineRecognitionResult {
30 /// timestamps[i] records the time in seconds when tokens[i] is decoded. 30 /// timestamps[i] records the time in seconds when tokens[i] is decoded.
31 std::vector<float> timestamps; 31 std::vector<float> timestamps;
32 32
  33 + std::vector<int32_t> words;
  34 +
33 std::string AsJsonString() const; 35 std::string AsJsonString() const;
34 }; 36 };
35 37
@@ -22,8 +22,16 @@ struct OnlineCtcDecoderResult { @@ -22,8 +22,16 @@ struct OnlineCtcDecoderResult {
22 /// The decoded token IDs 22 /// The decoded token IDs
23 std::vector<int64_t> tokens; 23 std::vector<int64_t> tokens;
24 24
  25 + /// The decoded word IDs
  26 + /// Note: tokens.size() is usually not equal to words.size()
  27 + /// words is empty for greedy search decoding.
  28 + /// it is not empty when an HLG graph or an HLG graph is used.
  29 + std::vector<int32_t> words;
  30 +
25 /// timestamps[i] contains the output frame index where tokens[i] is decoded. 31 /// timestamps[i] contains the output frame index where tokens[i] is decoded.
26 /// Note: The index is after subsampling 32 /// Note: The index is after subsampling
  33 + ///
  34 + /// tokens.size() == timestamps.size()
27 std::vector<int32_t> timestamps; 35 std::vector<int32_t> timestamps;
28 36
29 int32_t num_trailing_blanks = 0; 37 int32_t num_trailing_blanks = 0;
@@ -51,9 +51,9 @@ static void DecodeOne(const float *log_probs, int32_t num_rows, @@ -51,9 +51,9 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
51 bool ok = decoder->GetBestPath(&fst_out); 51 bool ok = decoder->GetBestPath(&fst_out);
52 if (ok) { 52 if (ok) {
53 std::vector<int32_t> isymbols_out; 53 std::vector<int32_t> isymbols_out;
54 - std::vector<int32_t> osymbols_out_unused;  
55 - ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out,  
56 - &osymbols_out_unused, nullptr); 54 + std::vector<int32_t> osymbols_out;
  55 + ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, &osymbols_out,
  56 + nullptr);
57 std::vector<int64_t> tokens; 57 std::vector<int64_t> tokens;
58 tokens.reserve(isymbols_out.size()); 58 tokens.reserve(isymbols_out.size());
59 59
@@ -83,6 +83,7 @@ static void DecodeOne(const float *log_probs, int32_t num_rows, @@ -83,6 +83,7 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
83 } 83 }
84 84
85 result->tokens = std::move(tokens); 85 result->tokens = std::move(tokens);
  86 + result->words = std::move(osymbols_out);
86 result->timestamps = std::move(timestamps); 87 result->timestamps = std::move(timestamps);
87 // no need to set frame_offset 88 // no need to set frame_offset
88 } 89 }
@@ -59,6 +59,7 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, @@ -59,6 +59,7 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
59 } 59 }
60 60
61 r.segment = segment; 61 r.segment = segment;
  62 + r.words = std::move(src.words);
62 r.start_time = frames_since_start * frame_shift_ms / 1000.; 63 r.start_time = frames_since_start * frame_shift_ms / 1000.;
63 64
64 return r; 65 return r;
@@ -22,14 +22,16 @@ namespace sherpa_onnx { @@ -22,14 +22,16 @@ namespace sherpa_onnx {
22 template <typename T> 22 template <typename T>
23 std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) { 23 std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) {
24 std::ostringstream oss; 24 std::ostringstream oss;
25 - oss << std::fixed << std::setprecision(precision);  
26 - oss << "[ "; 25 + if (precision != 0) {
  26 + oss << std::fixed << std::setprecision(precision);
  27 + }
  28 + oss << "[";
27 std::string sep = ""; 29 std::string sep = "";
28 for (const auto &item : vec) { 30 for (const auto &item : vec) {
29 oss << sep << item; 31 oss << sep << item;
30 sep = ", "; 32 sep = ", ";
31 } 33 }
32 - oss << " ]"; 34 + oss << "]";
33 return oss.str(); 35 return oss.str();
34 } 36 }
35 37
@@ -38,26 +40,29 @@ template <> // explicit specialization for T = std::string @@ -38,26 +40,29 @@ template <> // explicit specialization for T = std::string
38 std::string VecToString<std::string>(const std::vector<std::string> &vec, 40 std::string VecToString<std::string>(const std::vector<std::string> &vec,
39 int32_t) { // ignore 2nd arg 41 int32_t) { // ignore 2nd arg
40 std::ostringstream oss; 42 std::ostringstream oss;
41 - oss << "[ "; 43 + oss << "[";
42 std::string sep = ""; 44 std::string sep = "";
43 for (const auto &item : vec) { 45 for (const auto &item : vec) {
44 oss << sep << "\"" << item << "\""; 46 oss << sep << "\"" << item << "\"";
45 sep = ", "; 47 sep = ", ";
46 } 48 }
47 - oss << " ]"; 49 + oss << "]";
48 return oss.str(); 50 return oss.str();
49 } 51 }
50 52
51 std::string OnlineRecognizerResult::AsJsonString() const { 53 std::string OnlineRecognizerResult::AsJsonString() const {
52 std::ostringstream os; 54 std::ostringstream os;
53 os << "{ "; 55 os << "{ ";
54 - os << "\"text\": " << "\"" << text << "\"" << ", "; 56 + os << "\"text\": "
  57 + << "\"" << text << "\""
  58 + << ", ";
55 os << "\"tokens\": " << VecToString(tokens) << ", "; 59 os << "\"tokens\": " << VecToString(tokens) << ", ";
56 os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; 60 os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
57 os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; 61 os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
58 os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", "; 62 os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
59 os << "\"context_scores\": " << VecToString(context_scores, 6) << ", "; 63 os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
60 os << "\"segment\": " << segment << ", "; 64 os << "\"segment\": " << segment << ", ";
  65 + os << "\"words\": " << VecToString(words, 0) << ", ";
61 os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time 66 os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time
62 << ", "; 67 << ", ";
63 os << "\"is_final\": " << (is_final ? "true" : "false"); 68 os << "\"is_final\": " << (is_final ? "true" : "false");
@@ -47,6 +47,8 @@ struct OnlineRecognizerResult { @@ -47,6 +47,8 @@ struct OnlineRecognizerResult {
47 /// log-domain scores from "hot-phrase" contextual boosting 47 /// log-domain scores from "hot-phrase" contextual boosting
48 std::vector<float> context_scores; 48 std::vector<float> context_scores;
49 49
  50 + std::vector<int32_t> words;
  51 +
50 /// ID of this segment 52 /// ID of this segment
51 /// When an endpoint is detected, it is incremented 53 /// When an endpoint is detected, it is incremented
52 int32_t segment = 0; 54 int32_t segment = 0;
@@ -34,6 +34,8 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT @@ -34,6 +34,8 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
34 }) 34 })
35 .def_property_readonly("tokens", 35 .def_property_readonly("tokens",
36 [](const PyClass &self) { return self.tokens; }) 36 [](const PyClass &self) { return self.tokens; })
  37 + .def_property_readonly("words",
  38 + [](const PyClass &self) { return self.words; })
37 .def_property_readonly( 39 .def_property_readonly(
38 "timestamps", [](const PyClass &self) { return self.timestamps; }); 40 "timestamps", [](const PyClass &self) { return self.timestamps; });
39 } 41 }
@@ -41,6 +41,9 @@ static void PybindOnlineRecognizerResult(py::module *m) { @@ -41,6 +41,9 @@ static void PybindOnlineRecognizerResult(py::module *m) {
41 .def_property_readonly( 41 .def_property_readonly(
42 "segment", [](PyClass &self) -> int32_t { return self.segment; }) 42 "segment", [](PyClass &self) -> int32_t { return self.segment; })
43 .def_property_readonly( 43 .def_property_readonly(
  44 + "words",
  45 + [](PyClass &self) -> std::vector<int32_t> { return self.words; })
  46 + .def_property_readonly(
44 "is_final", [](PyClass &self) -> bool { return self.is_final; }) 47 "is_final", [](PyClass &self) -> bool { return self.is_final; })
45 .def("__str__", &PyClass::AsJsonString, 48 .def("__str__", &PyClass::AsJsonString,
46 py::call_guard<py::gil_scoped_release>()) 49 py::call_guard<py::gil_scoped_release>())