Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2025-03-03 13:15:25 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2025-03-03 13:15:25 +0800
Commit
c9d6859df7922279dbfeae6dfb36354c41919175
c9d6859d
1 parent
d5e7b51a
Add transducer modified_beam_search for RKNN. (#1949)
隐藏空白字符变更
内嵌
并排对比
正在显示
10 个修改的文件
包含
290 行增加
和
25 行删除
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/hypothesis.h
sherpa-onnx/csrc/math.h
sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h
sherpa-onnx/csrc/rknn/online-stream-rknn.h
sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h
sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h
sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc
sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h
sherpa-onnx/csrc/rknn/utils.cc
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
c9d6859
...
...
@@ -155,6 +155,7 @@ if(SHERPA_ONNX_ENABLE_RKNN)
list
(
APPEND sources
./rknn/online-stream-rknn.cc
./rknn/online-transducer-greedy-search-decoder-rknn.cc
./rknn/online-transducer-modified-beam-search-decoder-rknn.cc
./rknn/online-zipformer-ctc-model-rknn.cc
./rknn/online-zipformer-transducer-model-rknn.cc
./rknn/utils.cc
...
...
sherpa-onnx/csrc/hypothesis.h
查看文件 @
c9d6859
...
...
@@ -142,7 +142,6 @@ class Hypotheses {
// Remove every hypothesis currently stored in hyps_dict_.
void
Clear
()
{
hyps_dict_
.
clear
();
}
private
:
// Return a list of hyps contained in this object.
std
::
vector
<
Hypothesis
>
Vec
()
const
{
std
::
vector
<
Hypothesis
>
ans
;
...
...
sherpa-onnx/csrc/math.h
查看文件 @
c9d6859
...
...
@@ -119,5 +119,17 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
return
{
vec_index
.
begin
(),
vec_index
.
begin
()
+
k_num
};
}
// Top-k over a 2-D array.
//
// The rows of `vec` are concatenated, in order, into one contiguous buffer
// and that buffer is handed to the pointer-based TopkIndex() overload. Every
// returned index therefore refers to a position in the flattened buffer.
//
// @param vec   Rows to search. It may be empty and rows may differ in length.
// @param topk  Forwarded unchanged to the pointer-based overload.
// @return The result of the pointer-based overload on the flattened data.
template <class T>
std::vector<int32_t> TopkIndex(const std::vector<std::vector<T>> &vec,
                               int32_t topk) {
  // Sum the row sizes instead of reading vec[0].size(): indexing an empty
  // vector is undefined behaviour, and the first row's length is only a
  // correct estimate when all rows happen to have the same length.
  typename std::vector<T>::size_type total = 0;
  for (const auto &v : vec) {
    total += v.size();
  }

  std::vector<T> flatten;
  flatten.reserve(total);
  for (const auto &v : vec) {
    flatten.insert(flatten.end(), v.begin(), v.end());
  }

  return TopkIndex(flatten.data(), flatten.size(), topk);
}
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_MATH_H_
...
...
sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h
查看文件 @
c9d6859
...
...
@@ -16,7 +16,9 @@
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
#include "sherpa-onnx/csrc/symbol-table.h"
...
...
@@ -87,8 +89,20 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl {
unk_id_
=
sym_
[
"<unk>"
];
}
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoderRknn
>
(
model_
.
get
(),
unk_id_
);
if
(
config
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoderRknn
>
(
model_
.
get
(),
unk_id_
);
}
else
if
(
config
.
decoding_method
==
"modified_beam_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerModifiedBeamSearchDecoderRknn
>
(
model_
.
get
(),
config
.
max_active_paths
,
unk_id_
);
}
else
{
SHERPA_ONNX_LOGE
(
"Invalid decoding method: '%s'. Support only greedy_search and "
"modified_beam_search."
,
config
.
decoding_method
.
c_str
());
SHERPA_ONNX_EXIT
(
-
1
);
}
}
template
<
typename
Manager
>
...
...
@@ -223,7 +237,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl {
Endpoint
endpoint_
;
int32_t
unk_id_
=
-
1
;
std
::
unique_ptr
<
OnlineZipformerTransducerModelRknn
>
model_
;
std
::
unique_ptr
<
OnlineTransducer
GreedySearch
DecoderRknn
>
decoder_
;
std
::
unique_ptr
<
OnlineTransducerDecoderRknn
>
decoder_
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/rknn/online-stream-rknn.h
查看文件 @
c9d6859
...
...
@@ -8,7 +8,7 @@
#include "rknn_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/rknn/online-transducer-
greedy-search-
decoder-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h"
namespace
sherpa_onnx
{
...
...
sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h
0 → 100644
查看文件 @
c9d6859
// sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_DECODER_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_DECODER_RKNN_H_
#include <vector>
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/macros.h"
namespace
sherpa_onnx
{
// Decoding state of a transducer decoder for the RKNN backend.
struct OnlineTransducerDecoderResultRknn {
  /// How many frames (after subsampling) have been decoded so far.
  int32_t frame_offset = 0;

  /// Token IDs decoded so far.
  std::vector<int64_t> tokens;

  /// How many trailing blank frames have been decoded so far.
  int32_t num_trailing_blanks = 0;

  /// timestamps[i] is the output frame index at which tokens[i] was decoded.
  std::vector<int32_t> timestamps;

  /// Used only by greedy_search.
  std::vector<float> previous_decoder_out;

  /// Used only by modified_beam_search.
  Hypotheses hyps;

  /// Used only by modified_beam_search.
  std::vector<std::vector<float>> previous_decoder_out2;
};
// Abstract interface implemented by the RKNN transducer decoders.
class OnlineTransducerDecoderRknn {
 public:
  virtual ~OnlineTransducerDecoderRknn() = default;

  /** Create an empty decoding result.
   *
   * To keep the decoding code simple, `context_size` blanks are placed at
   * the beginning of the decoding result. They are removed again by
   * `StripLeadingBlanks()`.
   */
  virtual OnlineTransducerDecoderResultRknn GetEmptyResult() const = 0;

  /** Remove the blanks inserted by `GetEmptyResult()`.
   *
   * The default implementation does nothing.
   *
   * @param r  It is changed in-place.
   */
  virtual void StripLeadingBlanks(
      OnlineTransducerDecoderResultRknn * /*r*/) const {}

  /** Run decoding on one chunk of encoder output.
   *
   * @param encoder_out  Encoder output, passed by value.
   * @param result       Decoding state to read from and write back to.
   */
  virtual void Decode(std::vector<float> encoder_out,
                      OnlineTransducerDecoderResultRknn *result) const = 0;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_DECODER_RKNN_H_
...
...
sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h
查看文件 @
c9d6859
...
...
@@ -7,39 +7,26 @@
#include <vector>
#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
namespace
sherpa_onnx
{
struct
OnlineTransducerDecoderResultRknn
{
/// Number of frames after subsampling we have decoded so far
int32_t
frame_offset
=
0
;
/// The decoded token IDs so far
std
::
vector
<
int64_t
>
tokens
;
/// number of trailing blank frames decoded so far
int32_t
num_trailing_blanks
=
0
;
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
std
::
vector
<
int32_t
>
timestamps
;
std
::
vector
<
float
>
previous_decoder_out
;
};
class
OnlineTransducerGreedySearchDecoderRknn
{
class
OnlineTransducerGreedySearchDecoderRknn
:
public
OnlineTransducerDecoderRknn
{
public
:
explicit
OnlineTransducerGreedySearchDecoderRknn
(
OnlineZipformerTransducerModelRknn
*
model
,
int32_t
unk_id
=
2
,
float
blank_penalty
=
0
.
0
)
:
model_
(
model
),
unk_id_
(
unk_id
),
blank_penalty_
(
blank_penalty
)
{}
OnlineTransducerDecoderResultRknn
GetEmptyResult
()
const
;
OnlineTransducerDecoderResultRknn
GetEmptyResult
()
const
override
;
void
StripLeadingBlanks
(
OnlineTransducerDecoderResultRknn
*
r
)
const
;
void
StripLeadingBlanks
(
OnlineTransducerDecoderResultRknn
*
r
)
const
override
;
void
Decode
(
std
::
vector
<
float
>
encoder_out
,
OnlineTransducerDecoderResultRknn
*
result
)
const
;
OnlineTransducerDecoderResultRknn
*
result
)
const
override
;
private
:
OnlineZipformerTransducerModelRknn
*
model_
;
// Not owned
...
...
sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc
0 → 100644
查看文件 @
c9d6859
// sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/math.h"
namespace
sherpa_onnx
{
// Build the initial decoding state: a single hypothesis with log-prob 0
// whose token history is `context_size` entries long, all -1 except the last
// one, which is the blank ID. The same token history is also stored in
// `tokens`.
OnlineTransducerDecoderResultRknn
OnlineTransducerModifiedBeamSearchDecoderRknn::GetEmptyResult() const {
  const int32_t num_context = model_->ContextSize();
  const int32_t kBlank = 0;  // always 0

  OnlineTransducerDecoderResultRknn ans;

  std::vector<int64_t> initial_tokens(num_context, -1);
  initial_tokens.back() = kBlank;

  // The hypothesis takes a copy, so the vector can still be moved into
  // ans.tokens afterwards.
  Hypotheses seed({{initial_tokens, 0}});
  ans.hyps = std::move(seed);
  ans.tokens = std::move(initial_tokens);

  return ans;
}
// Copy the hypothesis returned by hyps.GetMostProbable(true) into the flat
// fields of `r`, dropping the first `context_size` entries of its token
// history.
//
// @param r  Updated in-place: tokens, timestamps and num_trailing_blanks
//           are overwritten.
void OnlineTransducerModifiedBeamSearchDecoderRknn::StripLeadingBlanks(
    OnlineTransducerDecoderResultRknn *r) const {
  const int32_t num_context = model_->ContextSize();

  auto best = r->hyps.GetMostProbable(true);

  // Skip the placeholder prefix at the front of the token history.
  r->tokens.assign(best.ys.begin() + num_context, best.ys.end());
  r->timestamps = std::move(best.timestamps);
  r->num_trailing_blanks = best.num_trailing_blanks;
}
// Run the decoder once per hypothesis in `hyp_vec`.
//
// For each hypothesis, its last `context_size` tokens are fed to
// model->RunDecoder(). The outputs are returned in the iteration order of
// `hyp_vec`.
static std::vector<std::vector<float>> GetDecoderOut(
    OnlineZipformerTransducerModelRknn *model, const Hypotheses &hyp_vec) {
  const int32_t num_context = model->ContextSize();

  std::vector<std::vector<float>> outputs;
  outputs.reserve(hyp_vec.Size());

  for (const auto &entry : hyp_vec) {
    const auto &ys = entry.second.ys;

    // Only the trailing `context_size` tokens are used as decoder input.
    std::vector<int64_t> context(ys.begin() + (ys.size() - num_context),
                                 ys.end());

    outputs.push_back(model->RunDecoder(std::move(context)));
  }

  return outputs;
}
// Run the joiner for one encoder frame against every decoder output and
// apply LogSoftmax() to each joiner output in-place.
//
// @param model          Model providing RunJoiner().
// @param p_encoder_out  Pointer to the encoder output of the current frame.
// @param decoder_out    One decoder output per hypothesis.
// @return One row per entry of `decoder_out`, in the same order.
static std::vector<std::vector<float>> GetJoinerOutLogSoftmax(
    OnlineZipformerTransducerModelRknn *model, const float *p_encoder_out,
    const std::vector<std::vector<float>> &decoder_out) {
  std::vector<std::vector<float>> rows;
  rows.reserve(decoder_out.size());

  for (const auto &dec : decoder_out) {
    rows.push_back(model->RunJoiner(p_encoder_out, dec.data()));

    // Normalise the row that was just appended.
    auto &row = rows.back();
    LogSoftmax(row.data(), row.size());
  }

  return rows;
}
void
OnlineTransducerModifiedBeamSearchDecoderRknn
::
Decode
(
std
::
vector
<
float
>
encoder_out
,
OnlineTransducerDecoderResultRknn
*
result
)
const
{
auto
&
r
=
result
[
0
];
auto
attr
=
model_
->
GetEncoderOutAttr
();
int32_t
num_frames
=
attr
.
dims
[
1
];
int32_t
encoder_out_dim
=
attr
.
dims
[
2
];
int32_t
vocab_size
=
model_
->
VocabSize
();
int32_t
context_size
=
model_
->
ContextSize
();
Hypotheses
cur
=
std
::
move
(
result
->
hyps
);
std
::
vector
<
Hypothesis
>
prev
;
auto
decoder_out
=
std
::
move
(
result
->
previous_decoder_out2
);
if
(
decoder_out
.
empty
())
{
decoder_out
=
GetDecoderOut
(
model_
,
cur
);
}
const
float
*
p_encoder_out
=
encoder_out
.
data
();
int32_t
frame_offset
=
result
->
frame_offset
;
for
(
int32_t
t
=
0
;
t
!=
num_frames
;
++
t
)
{
prev
=
cur
.
Vec
();
cur
.
Clear
();
auto
log_probs
=
GetJoinerOutLogSoftmax
(
model_
,
p_encoder_out
,
decoder_out
);
p_encoder_out
+=
encoder_out_dim
;
for
(
int32_t
i
=
0
;
i
!=
prev
.
size
();
++
i
)
{
auto
log_prob
=
prev
[
i
].
log_prob
;
for
(
auto
&
p
:
log_probs
[
i
])
{
p
+=
log_prob
;
}
}
auto
topk
=
TopkIndex
(
log_probs
,
max_active_paths_
);
for
(
auto
k
:
topk
)
{
int32_t
hyp_index
=
k
/
vocab_size
;
int32_t
new_token
=
k
%
vocab_size
;
Hypothesis
new_hyp
=
prev
[
hyp_index
];
new_hyp
.
log_prob
=
log_probs
[
hyp_index
][
new_token
];
// blank is hardcoded to 0
// also, it treats unk as blank
if
(
new_token
!=
0
&&
new_token
!=
unk_id_
)
{
new_hyp
.
ys
.
push_back
(
new_token
);
new_hyp
.
timestamps
.
push_back
(
t
+
frame_offset
);
new_hyp
.
num_trailing_blanks
=
0
;
}
else
{
++
new_hyp
.
num_trailing_blanks
;
}
cur
.
Add
(
std
::
move
(
new_hyp
));
}
decoder_out
=
GetDecoderOut
(
model_
,
cur
);
}
result
->
hyps
=
std
::
move
(
cur
);
result
->
frame_offset
+=
num_frames
;
result
->
previous_decoder_out2
=
std
::
move
(
decoder_out
);
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h
0 → 100644
查看文件 @
c9d6859
// sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_RKNN_H_
#include <vector>
#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
namespace
sherpa_onnx
{
// Modified-beam-search implementation of OnlineTransducerDecoderRknn.
class OnlineTransducerModifiedBeamSearchDecoderRknn
    : public OnlineTransducerDecoderRknn {
 public:
  /** Store the decoding parameters.
   *
   * @param model             The transducer model. It is not owned by this
   *                          object.
   * @param max_active_paths  Stored in max_active_paths_.
   * @param unk_id            Stored in unk_id_.
   * @param blank_penalty     Stored in blank_penalty_.
   */
  explicit OnlineTransducerModifiedBeamSearchDecoderRknn(
      OnlineZipformerTransducerModelRknn *model, int32_t max_active_paths,
      int32_t unk_id = 2, float blank_penalty = 0.0)
      : model_(model),
        max_active_paths_(max_active_paths),
        unk_id_(unk_id),
        blank_penalty_(blank_penalty) {}

  OnlineTransducerDecoderResultRknn GetEmptyResult() const override;

  void StripLeadingBlanks(OnlineTransducerDecoderResultRknn *r) const override;

  void Decode(std::vector<float> encoder_out,
              OnlineTransducerDecoderResultRknn *result) const override;

 private:
  OnlineZipformerTransducerModelRknn *model_;  // Not owned
  int32_t max_active_paths_;
  int32_t unk_id_;
  // NOTE(review): stored by the constructor; confirm that the implementation
  // file actually applies it during decoding.
  float blank_penalty_;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_RKNN_H_
...
...
sherpa-onnx/csrc/rknn/utils.cc
查看文件 @
c9d6859
...
...
@@ -6,6 +6,7 @@
#include <sstream>
#include <unordered_map>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
...
...
请
注册
或
登录
后发表评论