offline-ctc-fst-decoder.cc
3.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
// sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h"
#include <string>
#include <utility>
#include "fst/fstlib.h"
#include "kaldi-decoder/csrc/decodable-ctc.h"
#include "kaldi-decoder/csrc/eigen.h"
#include "kaldi-decoder/csrc/faster-decoder.h"
#include "sherpa-onnx/csrc/fst-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
/**
 * Decode one utterance with the CTC decoding graph held by @p decoder and
 * convert the resulting linear best path into tokens, word IDs and
 * per-token frame indexes.
 *
 * Input labels in the graph are token IDs shifted by +1 during graph
 * construction so that 0 can serve as epsilon; the blank token (ID 0)
 * therefore appears as input label 1.
 *
 * @param decoder The decoder to run. It is reused across calls.
 * @param p Pointer to a 2-d array of shape (num_frames, vocab_size)
 * @param num_frames Number of rows in the 2-d array.
 * @param vocab_size Number of columns in the 2-d array.
 * @return Return the decoded result. It is empty if no final state was
 *         reached or the best path is empty (an error is logged).
 */
static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder,
                                         const float *p, int32_t num_frames,
                                         int32_t vocab_size) {
  OfflineCtcDecoderResult r;

  kaldi_decoder::DecodableCtc decodable(p, num_frames, vocab_size);
  decoder->Decode(&decodable);

  if (!decoder->ReachedFinal()) {
    SHERPA_ONNX_LOGE("Not reached final!");
    return r;
  }

  fst::VectorFst<fst::LatticeArc> decoded;  // linear FST.
  decoder->GetBestPath(&decoded);

  if (decoded.NumStates() == 0) {
    SHERPA_ONNX_LOGE("Empty best path!");
    return r;
  }

  constexpr int32_t kEpsilon = 0;
  constexpr int32_t kBlankId = 0;

  auto cur_state = decoded.Start();
  int32_t t = 0;      // index of the frame consumed by the next emitting arc
  int32_t prev = -1;  // input label of the previous emitting arc

  // The best path is linear: every non-final state has exactly one arc.
  while (decoded.NumArcs(cur_state) == 1) {
    fst::ArcIterator<fst::Fst<fst::LatticeArc>> iter(decoded, cur_state);
    const auto &arc = iter.Value();
    cur_state = arc.nextstate;

    // The word sequence is the concatenation of all non-epsilon output
    // labels along the path, whichever kind of input arc carries them
    // (epsilon, blank, repeated or new token).
    if (arc.olabel != kEpsilon) {
      r.words.push_back(arc.olabel);
    }

    // Arcs with an epsilon input label are non-emitting: they consume no
    // frame, so neither the frame counter nor the CTC collapse state
    // (prev) may change here.
    if (arc.ilabel == kEpsilon) {
      continue;
    }

    // CTC collapse: a new token starts when the label differs from the
    // previous frame's label and is not blank.
    if (arc.ilabel != prev && arc.ilabel != kBlankId + 1) {
      // -1 here since the input labels are incremented during graph
      // construction
      r.tokens.push_back(arc.ilabel - 1);
      r.timestamps.push_back(t);
    }

    prev = arc.ilabel;
    ++t;
  }

  return r;
}
// Stores a copy of `config` and loads the decoding graph from the path in
// config.graph via ReadGraph().
OfflineCtcFstDecoder::OfflineCtcFstDecoder(
    const OfflineCtcFstDecoderConfig &config)
    // Read the graph path from the constructor argument instead of config_,
    // so correctness does not depend on config_ being declared (and hence
    // initialized) before fst_ in the class definition.
    : config_(config), fst_(ReadGraph(config.graph)) {}
/**
 * Decode a batch of CTC log-probabilities with the loaded FST graph.
 *
 * @param log_probs 3-d float tensor; the asserts below require rank 3, and
 *        it is indexed as (batch_size, T, vocab_size).
 * @param log_probs_length 1-d int64 tensor with batch_size entries giving
 *        the number of valid frames of each utterance. Values are clamped
 *        to [0, T] so an inconsistent length cannot read past the rows of
 *        its utterance.
 * @return One OfflineCtcDecoderResult per utterance, in batch order.
 */
std::vector<OfflineCtcDecoderResult> OfflineCtcFstDecoder::Decode(
    Ort::Value log_probs, Ort::Value log_probs_length) {
  std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
  assert(static_cast<int32_t>(shape.size()) == 3);
  int32_t batch_size = static_cast<int32_t>(shape[0]);
  int32_t T = static_cast<int32_t>(shape[1]);
  int32_t vocab_size = static_cast<int32_t>(shape[2]);

  std::vector<int64_t> length_shape =
      log_probs_length.GetTensorTypeAndShapeInfo().GetShape();
  assert(static_cast<int32_t>(length_shape.size()) == 1);
  assert(shape[0] == length_shape[0]);

  kaldi_decoder::FasterDecoderOptions opts;
  opts.max_active = config_.max_active;
  kaldi_decoder::FasterDecoder faster_decoder(*fst_, opts);

  const float *start = log_probs.GetTensorData<float>();
  const int64_t *lengths = log_probs_length.GetTensorData<int64_t>();

  // Elements per utterance; computed in 64 bits so that the offset
  // i * T * vocab_size cannot overflow int32 for large batches.
  const int64_t stride = static_cast<int64_t>(T) * vocab_size;

  std::vector<OfflineCtcDecoderResult> ans;
  ans.reserve(batch_size);

  for (int32_t i = 0; i != batch_size; ++i) {
    const float *p = start + i * stride;

    int64_t len = lengths[i];
    if (len < 0) {
      len = 0;
    } else if (len > T) {
      len = T;
    }
    int32_t num_frames = static_cast<int32_t>(len);

    ans.push_back(DecodeOne(&faster_decoder, p, num_frames, vocab_size));
  }

  return ans;
}
} // namespace sherpa_onnx