Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Nickolay V. Shmyrev
2025-07-26 18:12:28 +0300
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2025-07-26 23:12:28 +0800
Commit
10e845a8bad2acbc44ab7de5a52a6ef9e07bcf95
10e845a8
1 parent
c1445749
Implement max_symbols_per_frame for GigaAM2 accurate decoding since model uses c…
…har tokens instead of BPE. (#2423)
隐藏空白字符变更
内嵌
并排对比
正在显示
1 个修改的文件
包含
29 行增加
和
24 行删除
sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
查看文件 @
10e845a
...
...
@@ -45,6 +45,7 @@ static OfflineTransducerDecoderResult DecodeOne(
int32_t
vocab_size
=
model
->
VocabSize
();
int32_t
blank_id
=
vocab_size
-
1
;
int32_t
max_symbols_per_frame
=
10
;
auto
decoder_input_pair
=
BuildDecoderInput
(
blank_id
,
model
->
Allocator
());
...
...
@@ -60,30 +61,34 @@ static OfflineTransducerDecoderResult DecodeOne(
memory_info
,
const_cast
<
float
*>
(
p
)
+
t
*
num_cols
,
num_cols
,
encoder_shape
.
data
(),
encoder_shape
.
size
());
Ort
::
Value
logit
=
model
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
View
(
&
decoder_output_pair
.
first
));
float
*
p_logit
=
logit
.
GetTensorMutableData
<
float
>
();
if
(
blank_penalty
>
0
)
{
p_logit
[
blank_id
]
-=
blank_penalty
;
for
(
int32_t
q
=
0
;
q
!=
max_symbols_per_frame
;
++
q
)
{
Ort
::
Value
logit
=
model
->
RunJoiner
(
View
(
&
cur_encoder_out
),
View
(
&
decoder_output_pair
.
first
));
float
*
p_logit
=
logit
.
GetTensorMutableData
<
float
>
();
if
(
blank_penalty
>
0
)
{
p_logit
[
blank_id
]
-=
blank_penalty
;
}
auto
y
=
static_cast
<
int32_t
>
(
std
::
distance
(
static_cast
<
const
float
*>
(
p_logit
),
std
::
max_element
(
static_cast
<
const
float
*>
(
p_logit
),
static_cast
<
const
float
*>
(
p_logit
)
+
vocab_size
)));
if
(
y
!=
blank_id
)
{
ans
.
tokens
.
push_back
(
y
);
ans
.
timestamps
.
push_back
(
t
);
decoder_input_pair
=
BuildDecoderInput
(
y
,
model
->
Allocator
());
decoder_output_pair
=
model
->
RunDecoder
(
std
::
move
(
decoder_input_pair
.
first
),
std
::
move
(
decoder_input_pair
.
second
),
std
::
move
(
decoder_output_pair
.
second
));
}
else
{
break
;
}
// if (y != blank_id)
}
auto
y
=
static_cast
<
int32_t
>
(
std
::
distance
(
static_cast
<
const
float
*>
(
p_logit
),
std
::
max_element
(
static_cast
<
const
float
*>
(
p_logit
),
static_cast
<
const
float
*>
(
p_logit
)
+
vocab_size
)));
if
(
y
!=
blank_id
)
{
ans
.
tokens
.
push_back
(
y
);
ans
.
timestamps
.
push_back
(
t
);
decoder_input_pair
=
BuildDecoderInput
(
y
,
model
->
Allocator
());
decoder_output_pair
=
model
->
RunDecoder
(
std
::
move
(
decoder_input_pair
.
first
),
std
::
move
(
decoder_input_pair
.
second
),
std
::
move
(
decoder_output_pair
.
second
));
}
// if (y != blank_id)
}
// for (int32_t i = 0; i != num_rows; ++i)
return
ans
;
...
...
@@ -99,7 +104,7 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode(
int32_t
dim1
=
static_cast
<
int32_t
>
(
shape
[
1
]);
int32_t
dim2
=
static_cast
<
int32_t
>
(
shape
[
2
]);
const
int
64_t
*
p_length
=
encoder_out_length
.
GetTensorData
<
int64
_t
>
();
const
int
32_t
*
p_length
=
encoder_out_length
.
GetTensorData
<
int32
_t
>
();
const
float
*
p
=
encoder_out
.
GetTensorData
<
float
>
();
std
::
vector
<
OfflineTransducerDecoderResult
>
ans
(
batch_size
);
...
...
请
注册
或
登录
后发表评论