Wei Kang
Committed by GitHub

Fix context graph (#292)

... ... @@ -22,8 +22,9 @@ TEST(ContextGraph, TestBasic) {
auto context_graph = ContextGraph(contexts, 1);
auto queries = std::map<std::string, float>{
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6},
{"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
for (const auto &iter : queries) {
float total_scores = 0;
... ...
... ... @@ -19,7 +19,7 @@ void ContextGraph::Build(
bool is_end = j == token_ids[i].size() - 1;
node->next[token] = std::make_unique<ContextState>(
token, context_score_, node->node_score + context_score_,
is_end ? 0 : node->local_node_score + context_score_, is_end);
is_end ? node->node_score + context_score_ : 0, is_end);
}
node = node->next[token].get();
}
... ... @@ -34,7 +34,6 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
if (1 == state->next.count(token)) {
node = state->next.at(token).get();
score = node->token_score;
if (state->is_end) score += state->node_score;
} else {
node = state->fail;
while (0 == node->next.count(token)) {
... ... @@ -44,24 +43,15 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
if (1 == node->next.count(token)) {
node = node->next.at(token).get();
}
score = node->node_score - state->local_node_score;
score = node->node_score - state->node_score;
}
SHERPA_ONNX_CHECK(nullptr != node);
float matched_score = 0;
auto output = node->output;
while (nullptr != output) {
matched_score += output->node_score;
output = output->output;
}
return std::make_pair(score + matched_score, node);
return std::make_pair(score + node->output_score, node);
}
std::pair<float, const ContextState *> ContextGraph::Finalize(
const ContextState *state) const {
float score = -state->node_score;
if (state->is_end) {
score = 0;
}
return std::make_pair(score, root_.get());
}
... ... @@ -98,6 +88,7 @@ void ContextGraph::FillFailOutput() const {
}
}
kv.second->output = output;
kv.second->output_score += output == nullptr ? 0 : output->output_score;
node_queue.push(kv.second.get());
}
}
... ...
... ... @@ -21,7 +21,7 @@ struct ContextState {
int32_t token;
float token_score;
float node_score;
float local_node_score;
float output_score;
bool is_end;
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
const ContextState *fail = nullptr;
... ... @@ -29,11 +29,11 @@ struct ContextState {
ContextState() = default;
ContextState(int32_t token, float token_score, float node_score,
float local_node_score, bool is_end)
float output_score, bool is_end)
: token(token),
token_score(token_score),
node_score(node_score),
local_node_score(local_node_score),
output_score(output_score),
is_end(is_end) {}
};
... ...