Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
彭震东
2023-04-09 23:28:10 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-04-09 23:28:10 +0800
Commit
d781fcdeefbea58834167add758ebf38c75e9061
d781fcde
1 parent
80060c27
Use log probs for paraformer (#120)
* Use log probs for paraformer * Fix
显示空白字符变更
内嵌
并排对比
正在显示
3 个修改的文件
包含
15 行增加
和
11 行删除
sherpa-onnx/csrc/offline-paraformer-decoder.h
sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc
sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h
sherpa-onnx/csrc/offline-paraformer-decoder.h
查看文件 @
d781fcd
...
...
@@ -23,8 +23,7 @@ class OfflineParaformerDecoder {
/** Run beam search given the output from the paraformer model.
*
* @param log_probs A 3-D tensor of shape (N, T, vocab_size)
* @param token_num A 2-D tensor of shape (N, T). Its dtype is int64_t.
* log_probs[i].argmax(axis=-1) equals to token_num[i]
* @param token_num A 1-D tensor of shape (N). token_num equals to T.
*
* @return Return a vector of size `N` containing the decoded results.
*/
...
...
sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc
查看文件 @
d781fcd
...
...
@@ -4,28 +4,33 @@
#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
#include <algorithm>
#include <vector>
namespace
sherpa_onnx
{
std
::
vector
<
OfflineParaformerDecoderResult
>
OfflineParaformerGreedySearchDecoder
::
Decode
(
Ort
::
Value
/*log_probs*/
,
Ort
::
Value
token_num
)
{
std
::
vector
<
int64_t
>
shape
=
token_num
.
GetTensorTypeAndShapeInfo
().
GetShape
();
OfflineParaformerGreedySearchDecoder
::
Decode
(
Ort
::
Value
log_probs
,
Ort
::
Value
/*token_num*/
)
{
std
::
vector
<
int64_t
>
shape
=
log_probs
.
GetTensorTypeAndShapeInfo
().
GetShape
();
int32_t
batch_size
=
shape
[
0
];
int32_t
num_tokens
=
shape
[
1
];
int32_t
vocab_size
=
shape
[
2
];
std
::
vector
<
OfflineParaformerDecoderResult
>
results
(
batch_size
);
const
int64_t
*
p
=
token_num
.
GetTensorData
<
int64_t
>
();
for
(
int32_t
i
=
0
;
i
!=
batch_size
;
++
i
)
{
const
float
*
p
=
log_probs
.
GetTensorData
<
float
>
()
+
i
*
num_tokens
*
vocab_size
;
for
(
int32_t
k
=
0
;
k
!=
num_tokens
;
++
k
)
{
if
(
p
[
k
]
==
eos_id_
)
break
;
auto
max_idx
=
static_cast
<
int64_t
>
(
std
::
distance
(
p
,
std
::
max_element
(
p
,
p
+
vocab_size
)));
if
(
max_idx
==
eos_id_
)
break
;
results
[
i
].
tokens
.
push_back
(
p
[
k
]);
}
results
[
i
].
tokens
.
push_back
(
max_idx
);
p
+=
num_tokens
;
p
+=
vocab_size
;
}
}
return
results
;
...
...
sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h
查看文件 @
d781fcd
...
...
@@ -17,7 +17,7 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder {
:
eos_id_
(
eos_id
)
{}
std
::
vector
<
OfflineParaformerDecoderResult
>
Decode
(
Ort
::
Value
/*log_probs*/
,
Ort
::
Value
token_num
)
override
;
Ort
::
Value
log_probs
,
Ort
::
Value
/*token_num*/
)
override
;
private
:
int32_t
eos_id_
;
...
...
请
注册
或
登录
后发表评论