context-graph.cc
2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
// sherpa-onnx/csrc/context-graph.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/context-graph.h"
#include <cassert>
#include <queue>
#include <utility>
namespace sherpa_onnx {
// Inserts every phrase of `token_ids` into the trie rooted at `root_` and then
// calls FillFailOutput() to compute the fail/output links.
//
// A node is created for each token that is not yet a child of the current
// node. It is constructed with the token id, `context_score_`, the parent's
// node score plus `context_score_`, an output score (that same sum for the
// last token of a phrase, 0 otherwise) and the end-of-phrase flag.
//
// If a phrase ends on a node that already exists (e.g. "a b" is added after
// "a b c"), the existing node is promoted to an end node. Without this the
// shorter phrase would never be reported as a complete match.
//
// @param token_ids  One token-id sequence per phrase. Empty sequences add
//                   no nodes.
void ContextGraph::Build(
    const std::vector<std::vector<int32_t>> &token_ids) const {
  for (const auto &phrase : token_ids) {
    auto *node = root_.get();
    for (auto tok_it = phrase.begin(); tok_it != phrase.end(); ++tok_it) {
      const int32_t token = *tok_it;
      const bool is_end = (tok_it + 1 == phrase.end());

      auto found = node->next.find(token);
      if (found == node->next.end()) {
        const auto child_score = node->node_score + context_score_;
        found = node->next
                    .emplace(token, std::make_unique<ContextState>(
                                        token, context_score_, child_score,
                                        is_end ? child_score : 0, is_end))
                    .first;
      } else if (is_end && !found->second->is_end) {
        // The node was created as an interior node of a longer phrase; give
        // it the same end-of-phrase state a freshly created end node gets
        // (assumes the constructor stores its third argument as node_score).
        found->second->is_end = true;
        found->second->output_score = found->second->node_score;
      }
      node = found->second.get();
    }
  }
  FillFailOutput();
}
// Advances the matcher from `state` by one `token`.
//
// If `state` has a child for `token`, that child becomes the new state and
// the step score is the child's token_score. Otherwise the fail links are
// followed, starting at `state->fail`, until a node with a child for `token`
// is found or a node whose token is -1 (the root) is reached; the step score
// is then the new node's node_score minus `state`'s node_score.
//
// @param state  Current state in the graph.
// @param token  Token id to consume.
// @return  (step score plus the new state's output_score, new state).
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
    const ContextState *state, int32_t token) const {
  const ContextState *node = nullptr;
  float score = 0;

  const auto direct = state->next.find(token);
  if (direct != state->next.end()) {
    // Direct match: follow the trie edge.
    node = direct->second.get();
    score = node->token_score;
  } else {
    // No match: back off along the fail links.
    node = state->fail;
    while (node->next.count(token) == 0) {
      node = node->fail;
      if (node->token == -1) break;  // root
    }

    const auto resumed = node->next.find(token);
    if (resumed != node->next.end()) {
      node = resumed->second.get();
    }
    score = node->node_score - state->node_score;
  }

  SHERPA_ONNX_CHECK(nullptr != node);
  return std::make_pair(score + node->output_score, node);
}
std::pair<float, const ContextState *> ContextGraph::Finalize(
const ContextState *state) const {
float score = -state->node_score;
return std::make_pair(score, root_.get());
}
void ContextGraph::FillFailOutput() const {
std::queue<const ContextState *> node_queue;
for (auto &kv : root_->next) {
kv.second->fail = root_.get();
node_queue.push(kv.second.get());
}
while (!node_queue.empty()) {
auto current_node = node_queue.front();
node_queue.pop();
for (auto &kv : current_node->next) {
auto fail = current_node->fail;
if (1 == fail->next.count(kv.first)) {
fail = fail->next.at(kv.first).get();
} else {
fail = fail->fail;
while (0 == fail->next.count(kv.first)) {
fail = fail->fail;
if (-1 == fail->token) break;
}
if (1 == fail->next.count(kv.first))
fail = fail->next.at(kv.first).get();
}
kv.second->fail = fail;
// fill the output arc
auto output = fail;
while (!output->is_end) {
output = output->fail;
if (-1 == output->token) {
output = nullptr;
break;
}
}
kv.second->output = output;
kv.second->output_score += output == nullptr ? 0 : output->output_score;
node_queue.push(kv.second.get());
}
}
}
} // namespace sherpa_onnx