Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Wei Kang
2023-08-28 19:39:22 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-08-28 19:39:22 +0800
Commit
2b0152d2a2f2b2dbb775520023f30ea435123b6f
2b0152d2
1 parent
49ec7e8f
Fix context graph (#292)
隐藏空白字符变更
内嵌
并排对比
正在显示
3 个修改的文件
包含
10 行增加
和
18 行删除
sherpa-onnx/csrc/context-graph-test.cc
sherpa-onnx/csrc/context-graph.cc
sherpa-onnx/csrc/context-graph.h
sherpa-onnx/csrc/context-graph-test.cc
查看文件 @
2b0152d
...
...
@@ -22,8 +22,9 @@ TEST(ContextGraph, TestBasic) {
auto
context_graph
=
ContextGraph
(
contexts
,
1
);
auto
queries
=
std
::
map
<
std
::
string
,
float
>
{
{
"HEHERSHE"
,
14
},
{
"HERSHE"
,
12
},
{
"HISHE"
,
9
},
{
"SHED"
,
6
},
{
"HELL"
,
2
},
{
"HELLO"
,
7
},
{
"DHRHISQ"
,
4
},
{
"THEN"
,
2
}};
{
"HEHERSHE"
,
14
},
{
"HERSHE"
,
12
},
{
"HISHE"
,
9
},
{
"SHED"
,
6
},
{
"SHELF"
,
6
},
{
"HELL"
,
2
},
{
"HELLO"
,
7
},
{
"DHRHISQ"
,
4
},
{
"THEN"
,
2
}};
for
(
const
auto
&
iter
:
queries
)
{
float
total_scores
=
0
;
...
...
sherpa-onnx/csrc/context-graph.cc
查看文件 @
2b0152d
...
...
@@ -19,7 +19,7 @@ void ContextGraph::Build(
bool
is_end
=
j
==
token_ids
[
i
].
size
()
-
1
;
node
->
next
[
token
]
=
std
::
make_unique
<
ContextState
>
(
token
,
context_score_
,
node
->
node_score
+
context_score_
,
is_end
?
0
:
node
->
local_node_score
+
context_score_
,
is_end
);
is_end
?
node
->
node_score
+
context_score_
:
0
,
is_end
);
}
node
=
node
->
next
[
token
].
get
();
}
...
...
@@ -34,7 +34,6 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
if
(
1
==
state
->
next
.
count
(
token
))
{
node
=
state
->
next
.
at
(
token
).
get
();
score
=
node
->
token_score
;
if
(
state
->
is_end
)
score
+=
state
->
node_score
;
}
else
{
node
=
state
->
fail
;
while
(
0
==
node
->
next
.
count
(
token
))
{
...
...
@@ -44,24 +43,15 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
if
(
1
==
node
->
next
.
count
(
token
))
{
node
=
node
->
next
.
at
(
token
).
get
();
}
score
=
node
->
node_score
-
state
->
local_
node_score
;
score
=
node
->
node_score
-
state
->
node_score
;
}
SHERPA_ONNX_CHECK
(
nullptr
!=
node
);
float
matched_score
=
0
;
auto
output
=
node
->
output
;
while
(
nullptr
!=
output
)
{
matched_score
+=
output
->
node_score
;
output
=
output
->
output
;
}
return
std
::
make_pair
(
score
+
matched_score
,
node
);
return
std
::
make_pair
(
score
+
node
->
output_score
,
node
);
}
std
::
pair
<
float
,
const
ContextState
*>
ContextGraph
::
Finalize
(
const
ContextState
*
state
)
const
{
float
score
=
-
state
->
node_score
;
if
(
state
->
is_end
)
{
score
=
0
;
}
return
std
::
make_pair
(
score
,
root_
.
get
());
}
...
...
@@ -98,6 +88,7 @@ void ContextGraph::FillFailOutput() const {
}
}
kv
.
second
->
output
=
output
;
kv
.
second
->
output_score
+=
output
==
nullptr
?
0
:
output
->
output_score
;
node_queue
.
push
(
kv
.
second
.
get
());
}
}
...
...
sherpa-onnx/csrc/context-graph.h
查看文件 @
2b0152d
...
...
@@ -21,7 +21,7 @@ struct ContextState {
int32_t
token
;
float
token_score
;
float
node_score
;
float
local_node
_score
;
float
output
_score
;
bool
is_end
;
std
::
unordered_map
<
int32_t
,
std
::
unique_ptr
<
ContextState
>>
next
;
const
ContextState
*
fail
=
nullptr
;
...
...
@@ -29,11 +29,11 @@ struct ContextState {
ContextState
()
=
default
;
ContextState
(
int32_t
token
,
float
token_score
,
float
node_score
,
float
local_node
_score
,
bool
is_end
)
float
output
_score
,
bool
is_end
)
:
token
(
token
),
token_score
(
token_score
),
node_score
(
node_score
),
local_node_score
(
local_node
_score
),
output_score
(
output
_score
),
is_end
(
is_end
)
{}
};
...
...
请
注册
或
登录
后发表评论