context-graph-test.cc
2.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// sherpa-onnx/csrc/context-graph-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/context-graph.h"
#include <chrono> // NOLINT
#include <map>
#include <random>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Checks the scores ContextGraph produces for a handful of query strings.
//
// A graph is built from nine upper-case phrases, using each character's code
// as one token id. Every query is then fed through the graph one character at
// a time, starting from Root(). The score returned by each ForwardOneStep()
// call plus the score returned by Finalize() are summed, and the total is
// compared with the expected value stored next to the query.
TEST(ContextGraph, TestBasic) {
  const std::vector<std::string> contexts_str(
      {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});

  // Convert every phrase into a token sequence (one int32_t per character).
  // A range-for avoids comparing a signed index with the unsigned size().
  std::vector<std::vector<int32_t>> contexts;
  contexts.reserve(contexts_str.size());
  for (const auto &phrase : contexts_str) {
    contexts.emplace_back(phrase.begin(), phrase.end());
  }

  // NOTE(review): the second argument is presumably the score awarded per
  // matched token -- confirm against context-graph.h.
  auto context_graph = ContextGraph(contexts, 1);

  // query string -> expected total score after Finalize().
  const std::map<std::string, float> queries = {
      {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
      {"SHED", 6},      {"SHELF", 6},   {"HELL", 2},
      {"HELLO", 7},     {"DHRHISQ", 4}, {"THEN", 2}};

  for (const auto &entry : queries) {
    const std::string &query = entry.first;
    const float expected = entry.second;

    float total_scores = 0;
    auto state = context_graph.Root();
    for (auto token : query) {
      // Each step yields (score for this step, next state).
      auto step = context_graph.ForwardOneStep(state, token);
      total_scores += step.first;
      state = step.second;
    }

    auto final_step = context_graph.Finalize(state);
    // NOTE(review): token == -1 presumably identifies the root state, i.e.
    // Finalize() is expected to reset the search -- confirm.
    EXPECT_EQ(final_step.second->token, -1) << "query: " << query;
    total_scores += final_step.first;

    // Floating-point totals are compared with the float-aware assertion.
    EXPECT_FLOAT_EQ(total_scores, expected) << "query: " << query;
  }
}
// Measures how long it takes to construct a ContextGraph from randomly
// generated phrases, for 10, 100, 1000 and 10000 phrases.
//
// Each phrase has a random length in [3, 8] and consists of random token ids
// in [0, 25]. The generator is seeded from std::random_device, so the inputs
// (and therefore the timings) differ from run to run. Nothing is asserted;
// the elapsed time is only logged.
TEST(ContextGraph, Benchmark) {
  std::random_device rd;
  std::mt19937 mt(rd());
  std::uniform_int_distribution<int32_t> char_dist(0, 25);
  std::uniform_int_distribution<int32_t> len_dist(3, 8);

  for (int32_t num = 10; num <= 10000; num *= 10) {
    // Generate `num` random phrases outside of the timed region.
    std::vector<std::vector<int32_t>> contexts;
    for (int32_t i = 0; i < num; ++i) {
      std::vector<int32_t> tmp;
      const int32_t word_len = len_dist(mt);
      for (int32_t j = 0; j < word_len; ++j) {
        tmp.push_back(char_dist(mt));
      }
      contexts.push_back(std::move(tmp));
    }

    // steady_clock is monotonic, so the measured interval cannot be distorted
    // by wall-clock adjustments; high_resolution_clock gives no such
    // guarantee.
    const auto start = std::chrono::steady_clock::now();
    auto context_graph = ContextGraph(contexts, 1);
    const auto stop = std::chrono::steady_clock::now();

    const auto duration =
        std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
    // Cast to int32_t so the argument matches the %d conversion.
    SHERPA_ONNX_LOGE("Construct context graph for %d item takes %d us.", num,
                     static_cast<int32_t>(duration.count()));
  }
}
} // namespace sherpa_onnx