正在显示 3 个修改的文件，包含 10 行增加和 18 行删除。
| @@ -22,8 +22,9 @@ TEST(ContextGraph, TestBasic) { | @@ -22,8 +22,9 @@ TEST(ContextGraph, TestBasic) { | ||
| 22 | auto context_graph = ContextGraph(contexts, 1); | 22 | auto context_graph = ContextGraph(contexts, 1); |
| 23 | 23 | ||
| 24 | auto queries = std::map<std::string, float>{ | 24 | auto queries = std::map<std::string, float>{ |
| 25 | - {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6}, | ||
| 26 | - {"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; | 25 | + {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, |
| 26 | + {"SHED", 6}, {"SHELF", 6}, {"HELL", 2}, | ||
| 27 | + {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; | ||
| 27 | 28 | ||
| 28 | for (const auto &iter : queries) { | 29 | for (const auto &iter : queries) { |
| 29 | float total_scores = 0; | 30 | float total_scores = 0; |
| @@ -19,7 +19,7 @@ void ContextGraph::Build( | @@ -19,7 +19,7 @@ void ContextGraph::Build( | ||
| 19 | bool is_end = j == token_ids[i].size() - 1; | 19 | bool is_end = j == token_ids[i].size() - 1; |
| 20 | node->next[token] = std::make_unique<ContextState>( | 20 | node->next[token] = std::make_unique<ContextState>( |
| 21 | token, context_score_, node->node_score + context_score_, | 21 | token, context_score_, node->node_score + context_score_, |
| 22 | - is_end ? 0 : node->local_node_score + context_score_, is_end); | 22 | + is_end ? node->node_score + context_score_ : 0, is_end); |
| 23 | } | 23 | } |
| 24 | node = node->next[token].get(); | 24 | node = node->next[token].get(); |
| 25 | } | 25 | } |
| @@ -34,7 +34,6 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep( | @@ -34,7 +34,6 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep( | ||
| 34 | if (1 == state->next.count(token)) { | 34 | if (1 == state->next.count(token)) { |
| 35 | node = state->next.at(token).get(); | 35 | node = state->next.at(token).get(); |
| 36 | score = node->token_score; | 36 | score = node->token_score; |
| 37 | - if (state->is_end) score += state->node_score; | ||
| 38 | } else { | 37 | } else { |
| 39 | node = state->fail; | 38 | node = state->fail; |
| 40 | while (0 == node->next.count(token)) { | 39 | while (0 == node->next.count(token)) { |
| @@ -44,24 +43,15 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep( | @@ -44,24 +43,15 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep( | ||
| 44 | if (1 == node->next.count(token)) { | 43 | if (1 == node->next.count(token)) { |
| 45 | node = node->next.at(token).get(); | 44 | node = node->next.at(token).get(); |
| 46 | } | 45 | } |
| 47 | - score = node->node_score - state->local_node_score; | 46 | + score = node->node_score - state->node_score; |
| 48 | } | 47 | } |
| 49 | SHERPA_ONNX_CHECK(nullptr != node); | 48 | SHERPA_ONNX_CHECK(nullptr != node); |
| 50 | - float matched_score = 0; | ||
| 51 | - auto output = node->output; | ||
| 52 | - while (nullptr != output) { | ||
| 53 | - matched_score += output->node_score; | ||
| 54 | - output = output->output; | ||
| 55 | - } | ||
| 56 | - return std::make_pair(score + matched_score, node); | 49 | + return std::make_pair(score + node->output_score, node); |
| 57 | } | 50 | } |
| 58 | 51 | ||
| 59 | std::pair<float, const ContextState *> ContextGraph::Finalize( | 52 | std::pair<float, const ContextState *> ContextGraph::Finalize( |
| 60 | const ContextState *state) const { | 53 | const ContextState *state) const { |
| 61 | float score = -state->node_score; | 54 | float score = -state->node_score; |
| 62 | - if (state->is_end) { | ||
| 63 | - score = 0; | ||
| 64 | - } | ||
| 65 | return std::make_pair(score, root_.get()); | 55 | return std::make_pair(score, root_.get()); |
| 66 | } | 56 | } |
| 67 | 57 | ||
| @@ -98,6 +88,7 @@ void ContextGraph::FillFailOutput() const { | @@ -98,6 +88,7 @@ void ContextGraph::FillFailOutput() const { | ||
| 98 | } | 88 | } |
| 99 | } | 89 | } |
| 100 | kv.second->output = output; | 90 | kv.second->output = output; |
| 91 | + kv.second->output_score += output == nullptr ? 0 : output->output_score; | ||
| 101 | node_queue.push(kv.second.get()); | 92 | node_queue.push(kv.second.get()); |
| 102 | } | 93 | } |
| 103 | } | 94 | } |
| @@ -21,7 +21,7 @@ struct ContextState { | @@ -21,7 +21,7 @@ struct ContextState { | ||
| 21 | int32_t token; | 21 | int32_t token; |
| 22 | float token_score; | 22 | float token_score; |
| 23 | float node_score; | 23 | float node_score; |
| 24 | - float local_node_score; | 24 | + float output_score; |
| 25 | bool is_end; | 25 | bool is_end; |
| 26 | std::unordered_map<int32_t, std::unique_ptr<ContextState>> next; | 26 | std::unordered_map<int32_t, std::unique_ptr<ContextState>> next; |
| 27 | const ContextState *fail = nullptr; | 27 | const ContextState *fail = nullptr; |
| @@ -29,11 +29,11 @@ struct ContextState { | @@ -29,11 +29,11 @@ struct ContextState { | ||
| 29 | 29 | ||
| 30 | ContextState() = default; | 30 | ContextState() = default; |
| 31 | ContextState(int32_t token, float token_score, float node_score, | 31 | ContextState(int32_t token, float token_score, float node_score, |
| 32 | - float local_node_score, bool is_end) | 32 | + float output_score, bool is_end) |
| 33 | : token(token), | 33 | : token(token), |
| 34 | token_score(token_score), | 34 | token_score(token_score), |
| 35 | node_score(node_score), | 35 | node_score(node_score), |
| 36 | - local_node_score(local_node_score), | 36 | + output_score(output_score), |
| 37 | is_end(is_end) {} | 37 | is_end(is_end) {} |
| 38 | }; | 38 | }; |
| 39 | 39 |
-
请注册或登录后发表评论