Wei Kang
Committed by GitHub

Fix context graph (#292)

@@ -22,8 +22,9 @@ TEST(ContextGraph, TestBasic) { @@ -22,8 +22,9 @@ TEST(ContextGraph, TestBasic) {
22 auto context_graph = ContextGraph(contexts, 1); 22 auto context_graph = ContextGraph(contexts, 1);
23 23
24 auto queries = std::map<std::string, float>{ 24 auto queries = std::map<std::string, float>{
25 - {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6},  
26 - {"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; 25 + {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
  26 + {"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
  27 + {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
27 28
28 for (const auto &iter : queries) { 29 for (const auto &iter : queries) {
29 float total_scores = 0; 30 float total_scores = 0;
@@ -19,7 +19,7 @@ void ContextGraph::Build( @@ -19,7 +19,7 @@ void ContextGraph::Build(
19 bool is_end = j == token_ids[i].size() - 1; 19 bool is_end = j == token_ids[i].size() - 1;
20 node->next[token] = std::make_unique<ContextState>( 20 node->next[token] = std::make_unique<ContextState>(
21 token, context_score_, node->node_score + context_score_, 21 token, context_score_, node->node_score + context_score_,
22 - is_end ? 0 : node->local_node_score + context_score_, is_end); 22 + is_end ? node->node_score + context_score_ : 0, is_end);
23 } 23 }
24 node = node->next[token].get(); 24 node = node->next[token].get();
25 } 25 }
@@ -34,7 +34,6 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep( @@ -34,7 +34,6 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
34 if (1 == state->next.count(token)) { 34 if (1 == state->next.count(token)) {
35 node = state->next.at(token).get(); 35 node = state->next.at(token).get();
36 score = node->token_score; 36 score = node->token_score;
37 - if (state->is_end) score += state->node_score;  
38 } else { 37 } else {
39 node = state->fail; 38 node = state->fail;
40 while (0 == node->next.count(token)) { 39 while (0 == node->next.count(token)) {
@@ -44,24 +43,15 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep( @@ -44,24 +43,15 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
44 if (1 == node->next.count(token)) { 43 if (1 == node->next.count(token)) {
45 node = node->next.at(token).get(); 44 node = node->next.at(token).get();
46 } 45 }
47 - score = node->node_score - state->local_node_score; 46 + score = node->node_score - state->node_score;
48 } 47 }
49 SHERPA_ONNX_CHECK(nullptr != node); 48 SHERPA_ONNX_CHECK(nullptr != node);
50 - float matched_score = 0;  
51 - auto output = node->output;  
52 - while (nullptr != output) {  
53 - matched_score += output->node_score;  
54 - output = output->output;  
55 - }  
56 - return std::make_pair(score + matched_score, node); 49 + return std::make_pair(score + node->output_score, node);
57 } 50 }
58 51
59 std::pair<float, const ContextState *> ContextGraph::Finalize( 52 std::pair<float, const ContextState *> ContextGraph::Finalize(
60 const ContextState *state) const { 53 const ContextState *state) const {
61 float score = -state->node_score; 54 float score = -state->node_score;
62 - if (state->is_end) {  
63 - score = 0;  
64 - }  
65 return std::make_pair(score, root_.get()); 55 return std::make_pair(score, root_.get());
66 } 56 }
67 57
@@ -98,6 +88,7 @@ void ContextGraph::FillFailOutput() const { @@ -98,6 +88,7 @@ void ContextGraph::FillFailOutput() const {
98 } 88 }
99 } 89 }
100 kv.second->output = output; 90 kv.second->output = output;
  91 + kv.second->output_score += output == nullptr ? 0 : output->output_score;
101 node_queue.push(kv.second.get()); 92 node_queue.push(kv.second.get());
102 } 93 }
103 } 94 }
@@ -21,7 +21,7 @@ struct ContextState { @@ -21,7 +21,7 @@ struct ContextState {
21 int32_t token; 21 int32_t token;
22 float token_score; 22 float token_score;
23 float node_score; 23 float node_score;
24 - float local_node_score; 24 + float output_score;
25 bool is_end; 25 bool is_end;
26 std::unordered_map<int32_t, std::unique_ptr<ContextState>> next; 26 std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
27 const ContextState *fail = nullptr; 27 const ContextState *fail = nullptr;
@@ -29,11 +29,11 @@ struct ContextState { @@ -29,11 +29,11 @@ struct ContextState {
29 29
30 ContextState() = default; 30 ContextState() = default;
31 ContextState(int32_t token, float token_score, float node_score, 31 ContextState(int32_t token, float token_score, float node_score,
32 - float local_node_score, bool is_end) 32 + float output_score, bool is_end)
33 : token(token), 33 : token(token),
34 token_score(token_score), 34 token_score(token_score),
35 node_score(node_score), 35 node_score(node_score),
36 - local_node_score(local_node_score), 36 + output_score(output_score),
37 is_end(is_end) {} 37 is_end(is_end) {}
38 }; 38 };
39 39