xuning / sherpaonnx
Authored by PF Luo on 2023-04-26 11:41:04 +0800
Committed by GitHub on 2023-04-26 11:41:04 +0800
Commit aa7108729bd07e773f03ea1a49eb8bb27367ff8b (1 parent: 86017f98)
share GetHypsRowSplits interface and fix getting Topk not taking logprob (#131)
Showing 4 changed files with 51 additions and 37 deletions
sherpa-onnx/csrc/hypothesis.cc
sherpa-onnx/csrc/hypothesis.h
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
sherpa-onnx/csrc/hypothesis.cc
@@ -66,4 +66,19 @@ std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm) const {
   return {all_hyps.begin(), all_hyps.begin() + k};
 }
 
+const std::vector<int32_t> GetHypsRowSplits(
+    const std::vector<Hypotheses> &hyps) {
+  std::vector<int32_t> row_splits;
+  row_splits.reserve(hyps.size() + 1);
+  row_splits.push_back(0);
+  int32_t s = 0;
+  for (const auto &h : hyps) {
+    s += h.Size();
+    row_splits.push_back(s);
+  }
+  return row_splits;
+}
+
 }  // namespace sherpa_onnx
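Note: the added GetHypsRowSplits follows the usual row-splits convention, where entry b is the exclusive prefix sum of the per-utterance hypothesis counts, so utterance b's hypotheses occupy the index range [row_splits[b], row_splits[b+1]) and row_splits.back() is the total count. Below is a minimal, self-contained sketch of that convention; it uses a local FakeHypotheses stand-in rather than the sherpa-onnx headers, so everything beyond the prefix-sum logic is illustrative only.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Stand-in for sherpa_onnx::Hypotheses: only the Size() the helper needs.
    struct FakeHypotheses {
      int32_t size;
      int32_t Size() const { return size; }
    };

    // Same prefix-sum logic as the GetHypsRowSplits added in this commit.
    std::vector<int32_t> RowSplits(const std::vector<FakeHypotheses> &hyps) {
      std::vector<int32_t> row_splits;
      row_splits.reserve(hyps.size() + 1);
      row_splits.push_back(0);
      int32_t s = 0;
      for (const auto &h : hyps) {
        s += h.Size();
        row_splits.push_back(s);
      }
      return row_splits;
    }

    int main() {
      // Three utterances with 2, 3 and 1 active hypotheses.
      std::vector<FakeHypotheses> cur = {{2}, {3}, {1}};
      auto row_splits = RowSplits(cur);  // -> {0, 2, 5, 6}
      for (int32_t b = 0; b + 1 < static_cast<int32_t>(row_splits.size()); ++b) {
        std::cout << "utterance " << b << " owns hyps [" << row_splits[b]
                  << ", " << row_splits[b + 1] << ")\n";
      }
      std::cout << "total hyps = " << row_splits.back() << "\n";
      return 0;
    }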
sherpa-onnx/csrc/hypothesis.h
@@ -121,6 +121,9 @@ class Hypotheses {
   Map hyps_dict_;
 };
 
+const std::vector<int32_t> GetHypsRowSplits(
+    const std::vector<Hypotheses> &hyps);
+
 }  // namespace sherpa_onnx
 
 #endif  // SHERPA_ONNX_CSRC_HYPOTHESIS_H_
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
@@ -15,21 +15,6 @@
 
 namespace sherpa_onnx {
 
-static std::vector<int32_t> GetHypsRowSplits(
-    const std::vector<Hypotheses> &hyps) {
-  std::vector<int32_t> row_splits;
-  row_splits.reserve(hyps.size() + 1);
-  row_splits.push_back(0);
-  int32_t s = 0;
-  for (const auto &h : hyps) {
-    s += h.Size();
-    row_splits.push_back(s);
-  }
-  return row_splits;
-}
-
 std::vector<OfflineTransducerDecoderResult>
 OfflineTransducerModifiedBeamSearchDecoder::Decode(
     Ort::Value encoder_out, Ort::Value encoder_out_length) {
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
@@ -14,7 +14,7 @@
 
 namespace sherpa_onnx {
 
 static void UseCachedDecoderOut(
-    const std::vector<int32_t> &hyps_num_split,
+    const std::vector<int32_t> &hyps_row_splits,
     const std::vector<OnlineTransducerDecoderResult> &results,
     int32_t context_size, Ort::Value *decoder_out) {
   std::vector<int64_t> shape =
@@ -24,7 +24,7 @@ static void UseCachedDecoderOut(
   int32_t batch_size = static_cast<int32_t>(results.size());
 
   for (int32_t i = 0; i != batch_size; ++i) {
-    int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i];
+    int32_t num_hyps = hyps_row_splits[i + 1] - hyps_row_splits[i];
     if (num_hyps > 1 || !results[i].decoder_out) {
       dst += num_hyps * shape[1];
       continue;
@@ -86,17 +86,14 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
   for (int32_t t = 0; t != num_frames; ++t) {
     // Due to merging paths with identical token sequences,
     // not all utterances have "num_active_paths" paths.
-    int32_t hyps_num_acc = 0;
-    std::vector<int32_t> hyps_num_split;
-    hyps_num_split.push_back(0);
+    auto hyps_row_splits = GetHypsRowSplits(cur);
+    int32_t num_hyps = hyps_row_splits.back();  // total num hyps for all utterance
 
     prev.clear();
     for (auto &hyps : cur) {
       for (auto &h : hyps) {
         prev.push_back(std::move(h.second));
-        hyps_num_acc++;
       }
-      hyps_num_split.push_back(hyps_num_acc);
     }
 
     cur.clear();
     cur.reserve(batch_size);
@@ -104,30 +101,44 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
     Ort::Value decoder_input = model_->BuildDecoderInput(prev);
     Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
     if (t == 0) {
-      UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(),
+      UseCachedDecoderOut(hyps_row_splits, *result, model_->ContextSize(),
                           &decoder_out);
     }
 
     Ort::Value cur_encoder_out =
         GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
     cur_encoder_out =
-        Repeat(model_->Allocator(), &cur_encoder_out, hyps_num_split);
+        Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
     Ort::Value logit = model_->RunJoiner(
         std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
 
     float *p_logit = logit.GetTensorMutableData<float>();
 
-    for (int32_t b = 0; b < batch_size; ++b) {
+    LogSoftmax(p_logit, vocab_size, num_hyps);
+
+    // now p_logit contains log_softmax output, we rename it to p_logprob
+    // to match what it actually contains
+    float *p_logprob = p_logit;
+
+    // add log_prob of each hypothesis to p_logprob before taking top_k
+    for (int32_t i = 0; i != num_hyps; ++i) {
+      float log_prob = prev[i].log_prob;
+      for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
+        *p_logprob += log_prob;
+      }
+    }
+    p_logprob = p_logit;  // we changed p_logprob in the above for loop
+
+    for (int32_t b = 0; b != batch_size; ++b) {
       int32_t frame_offset = (*result)[b].frame_offset;
-      int32_t start = hyps_num_split[b];
-      int32_t end = hyps_num_split[b + 1];
-      LogSoftmax(p_logit, vocab_size, (end - start));
+      int32_t start = hyps_row_splits[b];
+      int32_t end = hyps_row_splits[b + 1];
 
-      auto topk = TopkIndex(p_logit, vocab_size * (end - start), max_active_paths_);
+      auto topk =
+          TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_);
 
       Hypotheses hyps;
-      for (auto i : topk) {
-        int32_t hyp_index = i / vocab_size + start;
-        int32_t new_token = i % vocab_size;
+      for (auto k : topk) {
+        int32_t hyp_index = k / vocab_size + start;
+        int32_t new_token = k % vocab_size;
 
         Hypothesis new_hyp = prev[hyp_index];
         if (new_token != 0) {
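Note on the "fix getting Topk not taking logprob" part of the title: switching TopkIndex from p_logit to p_logprob matters because the frame's log-softmax scores alone ignore how likely each hypothesis already is, so a token that continues a weak hypothesis can outrank a token that continues a strong one. Adding every hypothesis's accumulated log_prob to its vocab_size-long row first (the loop over prev above) turns the top-k into a comparison of full path scores. Below is a toy, self-contained sketch of the difference with two hypotheses and a three-token vocabulary; ArgTopK is my own stand-in for the library's TopkIndex, and all numbers are made up.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <numeric>
    #include <vector>

    // Return indices of the k largest entries (stand-in for TopkIndex).
    std::vector<int32_t> ArgTopK(const std::vector<float> &p, int32_t k) {
      std::vector<int32_t> idx(p.size());
      std::iota(idx.begin(), idx.end(), 0);
      std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                        [&](int32_t a, int32_t b) { return p[a] > p[b]; });
      idx.resize(k);
      return idx;
    }

    int main() {
      constexpr int32_t vocab_size = 3;
      // Accumulated path scores: hyp 0 is much more likely than hyp 1.
      std::vector<float> hyp_log_prob = {-0.5f, -5.0f};
      // Per-frame log-softmax scores, one row of vocab_size per hypothesis.
      std::vector<float> frame_logprob = {-1.2f, -1.0f, -1.1f,   // hyp 0
                                          -0.3f, -2.0f, -2.5f};  // hyp 1

      // Buggy ranking: top-k over frame scores only.
      auto wrong = ArgTopK(frame_logprob, 1);
      std::cout << "without path score: hyp " << wrong[0] / vocab_size
                << ", token " << wrong[0] % vocab_size << "\n";

      // Fixed ranking: add each hypothesis's log_prob to its row first,
      // as the loop over `prev` in this hunk does.
      std::vector<float> path_logprob = frame_logprob;
      for (int32_t i = 0; i != 2; ++i) {
        for (int32_t k = 0; k != vocab_size; ++k) {
          path_logprob[i * vocab_size + k] += hyp_log_prob[i];
        }
      }
      auto right = ArgTopK(path_logprob, 1);
      std::cout << "with path score:    hyp " << right[0] / vocab_size
                << ", token " << right[0] % vocab_size << "\n";
      return 0;
    }

Running it, the unadjusted ranking picks (hyp 1, token 0), while the adjusted ranking picks (hyp 0, token 1), which is the behaviour this hunk restores.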
@@ -137,12 +148,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
         } else {
           ++new_hyp.num_trailing_blanks;
         }
-        new_hyp.log_prob += p_logit[i];
+        new_hyp.log_prob = p_logprob[k];
         hyps.Add(std::move(new_hyp));
-      }
+      }  // for (auto k : topk)
       cur.push_back(std::move(hyps));
-      p_logit += vocab_size * (end - start);
-    }
+      p_logprob += (end - start) * vocab_size;
+    }  // for (int32_t b = 0; b != batch_size; ++b)
   }
 
   for (int32_t b = 0; b != batch_size; ++b) {
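The pointer arithmetic in this final hunk follows from the buffer layout: logit, and hence p_logprob, holds num_hyps rows of vocab_size floats, with utterance b's rows spanning [row_splits[b], row_splits[b+1]). Advancing p_logprob by (end - start) * vocab_size after each utterance is therefore equivalent to indexing from the start of the buffer at row_splits[b] * vocab_size. A small standalone sketch of that equivalence, with made-up sizes:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      constexpr int32_t vocab_size = 4;
      // Row splits for 3 utterances with 2, 3 and 1 hypotheses.
      std::vector<int32_t> row_splits = {0, 2, 5, 6};
      int32_t num_hyps = row_splits.back();

      // Flat [num_hyps, vocab_size] buffer, filled with recognisable values.
      std::vector<float> logprob(num_hyps * vocab_size);
      for (int32_t i = 0; i != num_hyps * vocab_size; ++i) logprob[i] = i;

      // Walking pointer, as in the decoder ...
      const float *p = logprob.data();
      for (int32_t b = 0; b != 3; ++b) {
        int32_t start = row_splits[b];
        int32_t end = row_splits[b + 1];
        // ... is equivalent to absolute indexing from the buffer start.
        const float *q = logprob.data() + start * vocab_size;
        std::cout << "utterance " << b << ": p[0]=" << p[0]
                  << " q[0]=" << q[0] << "\n";
        p += (end - start) * vocab_size;  // skip this utterance's rows
      }
      return 0;
    }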