Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2025-08-15 11:42:46 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2025-08-15 11:42:46 +0800
Commit
5c0f7f69df47df2f7e54c181ce2a5a922bba9f81
5c0f7f69
1 parent
6b1ddbd2
Support TDT transducer decoding (#2495)
隐藏空白字符变更
内嵌
并排对比
正在显示
6 个修改的文件
包含
122 行增加
和
6 行删除
scripts/nemo/parakeet-tdt-0.6b-v2/test_onnx.py
sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h
sherpa-onnx/csrc/offline-transducer-nemo-model.cc
sherpa-onnx/csrc/offline-transducer-nemo-model.h
scripts/nemo/parakeet-tdt-0.6b-v2/test_onnx.py
查看文件 @
5c0f7f6
...
...
@@ -207,6 +207,7 @@ def main():
for
line
in
f
:
t
,
idx
=
line
.
split
()
id2token
[
int
(
idx
)]
=
t
vocab_size
=
len
(
id2token
)
start
=
time
.
time
()
fbank
=
create_fbank
()
...
...
@@ -242,12 +243,21 @@ def main():
encoder_out
=
model
.
run_encoder
(
features
)
# encoder_out: [batch_size, dim, T]
for
t
in
range
(
encoder_out
.
shape
[
2
]):
t
=
0
while
t
<
encoder_out
.
shape
[
2
]:
encoder_out_t
=
encoder_out
[:,
:,
t
:
t
+
1
]
logits
=
model
.
run_joiner
(
encoder_out_t
,
decoder_out
)
logits
=
torch
.
from_numpy
(
logits
)
logits
=
logits
.
squeeze
()
idx
=
torch
.
argmax
(
logits
,
dim
=-
1
)
.
item
()
token_logits
=
logits
[:
vocab_size
]
duration_logits
=
logits
[
vocab_size
:]
idx
=
torch
.
argmax
(
token_logits
,
dim
=-
1
)
.
item
()
skip
=
torch
.
argmax
(
duration_logits
,
dim
=-
1
)
.
item
()
if
skip
==
0
:
skip
=
1
if
idx
!=
blank
:
ans
.
append
(
idx
)
state0
=
state0_next
...
...
@@ -255,6 +265,7 @@ def main():
decoder_out
,
state0_next
,
state1_next
=
model
.
run_decoder
(
ans
[
-
1
],
state0
,
state1
)
t
+=
skip
end
=
time
.
time
()
...
...
sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
查看文件 @
5c0f7f6
...
...
@@ -43,7 +43,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
config_
.
model_config
))
{
if
(
config_
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OfflineTransducerGreedySearchNeMoDecoder
>
(
model_
.
get
(),
config_
.
blank_penalty
);
model_
.
get
(),
config_
.
blank_penalty
,
model_
->
IsTDT
()
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported decoding method: %s"
,
config_
.
decoding_method
.
c_str
());
...
...
sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
查看文件 @
5c0f7f6
...
...
@@ -94,6 +94,72 @@ static OfflineTransducerDecoderResult DecodeOne(
return
ans
;
}
// Greedy search for one utterance with a Token-and-Duration Transducer (TDT)
// model.
//
// The joiner output is treated as the concatenation
//   [ token logits (vocab_size entries) | duration logits (the rest) ].
// At every step the arg-max token and the arg-max duration index are taken;
// the duration index is used as the number of encoder frames to advance.
// The blank token is the last token ID, i.e. vocab_size - 1.
//
// @param p             Encoder output of this utterance, read as num_rows
//                      consecutive frames of num_cols floats each.
// @param num_rows      Number of encoder frames to decode.
// @param num_cols      Number of floats per encoder frame.
// @param model         Not owned. Runs the decoder and the joiner.
// @param blank_penalty If > 0, it is subtracted from the blank logit before
//                      the arg-max over tokens.
// @return The emitted token IDs together with, for each token, the index of
//         the encoder frame at which it was emitted.
static OfflineTransducerDecoderResult DecodeOneTDT(
    const float *p, int32_t num_rows, int32_t num_cols,
    OfflineTransducerNeMoModel *model, float blank_penalty) {
  // Safety cap on how many non-blank tokens may be emitted without advancing
  // to the next frame. It guarantees that the loop below terminates even if
  // the model keeps predicting duration 0.
  constexpr int32_t kMaxTokensPerFrame = 5;

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  OfflineTransducerDecoderResult ans;

  const int32_t vocab_size = model->VocabSize();
  const int32_t blank_id = vocab_size - 1;

  // Prime the decoder with the blank token and the initial states.
  auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());

  std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
      model->RunDecoder(std::move(decoder_input_pair.first),
                        std::move(decoder_input_pair.second),
                        model->GetDecoderInitStates(1));

  std::array<int64_t, 3> encoder_shape{1, num_cols, 1};

  int32_t tokens_this_frame = 0;
  int32_t skip = 0;
  for (int32_t t = 0; t < num_rows; t += skip) {
    // Wrap frame t without copying it. The const_cast is needed only because
    // CreateTensor takes a non-const pointer.
    Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
        memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
        encoder_shape.data(), encoder_shape.size());

    Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
                                        View(&decoder_output_pair.first));

    auto shape = logit.GetTensorTypeAndShapeInfo().GetShape();

    float *p_logit = logit.GetTensorMutableData<float>();
    if (blank_penalty > 0) {
      p_logit[blank_id] -= blank_penalty;
    }

    const float *token_begin = p_logit;
    const float *duration_begin = p_logit + vocab_size;
    const float *logit_end = p_logit + shape.back();

    const auto y = static_cast<int32_t>(std::distance(
        token_begin, std::max_element(token_begin, duration_begin)));

    // Only search the duration logits when that range is non-empty; an
    // inverted range would be undefined behaviour for std::max_element.
    skip = 0;
    if (duration_begin < logit_end) {
      skip = static_cast<int32_t>(std::distance(
          duration_begin, std::max_element(duration_begin, logit_end)));
    }

    if (y != blank_id) {
      ans.tokens.push_back(y);
      ans.timestamps.push_back(t);

      decoder_input_pair = BuildDecoderInput(y, model->Allocator());

      decoder_output_pair =
          model->RunDecoder(std::move(decoder_input_pair.first),
                            std::move(decoder_input_pair.second),
                            std::move(decoder_output_pair.second));

      ++tokens_this_frame;
    }

    if (skip > 0) {
      // We are moving to a later frame; restart the per-frame counter.
      tokens_this_frame = 0;
    } else if (y == blank_id || tokens_this_frame >= kMaxTokensPerFrame) {
      // Duration 0 after a non-blank token means "emit again at this frame",
      // so we only force progress when blank was predicted or the per-frame
      // cap has been reached.
      skip = 1;
      tokens_this_frame = 0;
    }
  }  // for (int32_t t = 0; t < num_rows; t += skip)

  return ans;
}
std
::
vector
<
OfflineTransducerDecoderResult
>
OfflineTransducerGreedySearchNeMoDecoder
::
Decode
(
Ort
::
Value
encoder_out
,
Ort
::
Value
encoder_out_length
,
...
...
@@ -123,7 +189,11 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode(
?
encoder_out_length
.
GetTensorData
<
int32_t
>
()[
i
]
:
encoder_out_length
.
GetTensorData
<
int64_t
>
()[
i
];
ans
[
i
]
=
DecodeOne
(
this_p
,
this_len
,
dim2
,
model_
,
blank_penalty_
);
if
(
is_tdt_
)
{
ans
[
i
]
=
DecodeOneTDT
(
this_p
,
this_len
,
dim2
,
model_
,
blank_penalty_
);
}
else
{
ans
[
i
]
=
DecodeOne
(
this_p
,
this_len
,
dim2
,
model_
,
blank_penalty_
);
}
}
return
ans
;
...
...
sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h
查看文件 @
5c0f7f6
...
...
@@ -16,8 +16,8 @@ class OfflineTransducerGreedySearchNeMoDecoder
:
public
OfflineTransducerDecoder
{
public
:
OfflineTransducerGreedySearchNeMoDecoder
(
OfflineTransducerNeMoModel
*
model
,
float
blank_penalty
)
:
model_
(
model
),
blank_penalty_
(
blank_penalty
)
{}
float
blank_penalty
,
bool
is_tdt
)
:
model_
(
model
),
blank_penalty_
(
blank_penalty
),
is_tdt_
(
is_tdt
)
{}
std
::
vector
<
OfflineTransducerDecoderResult
>
Decode
(
Ort
::
Value
encoder_out
,
Ort
::
Value
encoder_out_length
,
...
...
@@ -26,6 +26,7 @@ class OfflineTransducerGreedySearchNeMoDecoder
private
:
OfflineTransducerNeMoModel
*
model_
;
// Not owned
float
blank_penalty_
;
bool
is_tdt_
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-transducer-nemo-model.cc
查看文件 @
5c0f7f6
...
...
@@ -163,6 +163,7 @@ class OfflineTransducerNeMoModel::Impl {
// Returns a copy of normalize_type_, the stored feature normalization method.
std
::
string
FeatureNormalizationMethod
()
const
{
return
normalize_type_
;
}
// Reports whether the integer flag is_giga_am_ is set (non-zero).
bool IsGigaAM() const { return is_giga_am_ != 0; }
// Reports whether the integer flag is_tdt_ is set (non-zero).
bool IsTDT() const { return is_tdt_ != 0; }
// Returns feat_dim_, the stored feature dimension.
int32_t
FeatureDim
()
const
{
return
feat_dim_
;
}
...
...
@@ -208,6 +209,12 @@ class OfflineTransducerNeMoModel::Impl {
if
(
normalize_type_
==
"NA"
)
{
normalize_type_
=
""
;
}
std
::
string
url
;
SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY
(
url
,
"url"
);
if
(
url
.
find
(
"tdt"
)
!=
std
::
string
::
npos
)
{
is_tdt_
=
1
;
}
}
void
InitDecoder
(
void
*
model_data
,
size_t
model_data_length
)
{
...
...
@@ -230,6 +237,26 @@ class OfflineTransducerNeMoModel::Impl {
GetOutputNames
(
joiner_sess_
.
get
(),
&
joiner_output_names_
,
&
joiner_output_names_ptr_
);
auto
shape
=
joiner_sess_
->
GetOutputTypeInfo
(
0
)
.
GetTensorTypeAndShapeInfo
()
.
GetShape
();
int32_t
output_size
=
shape
.
back
();
if
(
is_tdt_
)
{
if
(
vocab_size_
==
output_size
)
{
SHERPA_ONNX_LOGE
(
"It is not a TDT model!"
);
SHERPA_ONNX_EXIT
(
-
1
);
}
if
(
config_
.
debug
)
{
SHERPA_ONNX_LOGE
(
"TDT model. vocab_size: %d, num_durations: %d"
,
vocab_size_
,
output_size
-
vocab_size_
);
}
}
else
if
(
vocab_size_
!=
output_size
)
{
SHERPA_ONNX_LOGE
(
"vocab_size: %d != output_size: %d"
,
vocab_size_
,
output_size
);
SHERPA_ONNX_EXIT
(
-
1
);
}
}
private
:
...
...
@@ -266,6 +293,7 @@ class OfflineTransducerNeMoModel::Impl {
int32_t
pred_rnn_layers_
=
-
1
;
int32_t
pred_hidden_
=
-
1
;
int32_t
is_giga_am_
=
0
;
int32_t
is_tdt_
=
0
;
// giga am uses 64
// parakeet-tdt-0.6b-v2 uses 128
...
...
@@ -325,6 +353,8 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const {
// Forwards to the implementation object's IsGigaAM().
bool
OfflineTransducerNeMoModel
::
IsGigaAM
()
const
{
return
impl_
->
IsGigaAM
();
}
// Forwards to the implementation object's IsTDT().
bool
OfflineTransducerNeMoModel
::
IsTDT
()
const
{
return
impl_
->
IsTDT
();
}
// Forwards to the implementation object's FeatureDim().
int32_t
OfflineTransducerNeMoModel
::
FeatureDim
()
const
{
return
impl_
->
FeatureDim
();
}
...
...
sherpa-onnx/csrc/offline-transducer-nemo-model.h
查看文件 @
5c0f7f6
...
...
@@ -88,6 +88,10 @@ class OfflineTransducerNeMoModel {
bool
IsGigaAM
()
const
;
// true if it is a Token-and-Duration Transducer model
// false otherwise
bool
IsTDT
()
const
;
int32_t
FeatureDim
()
const
;
private
:
...
...
请
注册
或
登录
后发表评论