Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2025-03-02 21:40:13 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2025-03-02 21:40:13 +0800
Commit
d5e7b51af5f45328cfbbc3c3e26a2d12280fd10c
d5e7b51a
1 parent
dfcbc8d4
Support RKNN for Zipformer CTC models. (#1948)
显示空白字符变更
内嵌
并排对比
正在显示
17 个修改的文件
包含
815 行增加
和
110 行删除
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/online-ctc-decoder.h
sherpa-onnx/csrc/online-ctc-fst-decoder.cc
sherpa-onnx/csrc/online-ctc-fst-decoder.h
sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc
sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h
sherpa-onnx/csrc/online-model-config.cc
sherpa-onnx/csrc/online-recognizer-ctc-impl.h
sherpa-onnx/csrc/online-recognizer-impl.cc
sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h
sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h
sherpa-onnx/csrc/rknn/utils.cc
sherpa-onnx/csrc/rknn/utils.h
sherpa-onnx/csrc/sherpa-onnx.cc
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
d5e7b51
...
...
@@ -155,7 +155,9 @@ if(SHERPA_ONNX_ENABLE_RKNN)
list
(
APPEND sources
./rknn/online-stream-rknn.cc
./rknn/online-transducer-greedy-search-decoder-rknn.cc
./rknn/online-zipformer-ctc-model-rknn.cc
./rknn/online-zipformer-transducer-model-rknn.cc
./rknn/utils.cc
)
endif
()
...
...
sherpa-onnx/csrc/online-ctc-decoder.h
查看文件 @
d5e7b51
...
...
@@ -43,12 +43,14 @@ class OnlineCtcDecoder {
/** Run streaming CTC decoding given the output from the encoder model.
*
* @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing
* lob_probs.
* @param log_probs A 3-D tensor of shape
* (batch_size, num_frames, vocab_size) containing
* lob_probs in row major.
*
* @param results Input & Output parameters..
*/
virtual
void
Decode
(
Ort
::
Value
log_probs
,
virtual
void
Decode
(
const
float
*
log_probs
,
int32_t
batch_size
,
int32_t
num_frames
,
int32_t
vocab_size
,
std
::
vector
<
OnlineCtcDecoderResult
>
*
results
,
OnlineStream
**
ss
=
nullptr
,
int32_t
n
=
0
)
=
0
;
...
...
sherpa-onnx/csrc/online-ctc-fst-decoder.cc
查看文件 @
d5e7b51
...
...
@@ -91,30 +91,23 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
processed_frames
+=
num_rows
;
}
void
OnlineCtcFstDecoder
::
Decode
(
Ort
::
Value
log_probs
,
void
OnlineCtcFstDecoder
::
Decode
(
const
float
*
log_probs
,
int32_t
batch_size
,
int32_t
num_frames
,
int32_t
vocab_size
,
std
::
vector
<
OnlineCtcDecoderResult
>
*
results
,
OnlineStream
**
ss
,
int32_t
n
)
{
std
::
vector
<
int64_t
>
log_probs_shape
=
log_probs
.
GetTensorTypeAndShapeInfo
().
GetShape
();
if
(
log_probs_shape
[
0
]
!=
results
->
size
())
{
if
(
batch_size
!=
results
->
size
())
{
SHERPA_ONNX_LOGE
(
"Size mismatch! log_probs.size(0) %d, results.size(0): %d"
,
static_cast
<
int32_t
>
(
log_probs_shape
[
0
]),
static_cast
<
int32_t
>
(
results
->
size
()));
batch_size
,
static_cast
<
int32_t
>
(
results
->
size
()));
exit
(
-
1
);
}
if
(
log_probs_shape
[
0
]
!=
n
)
{
SHERPA_ONNX_LOGE
(
"Size mismatch! log_probs.size(0) %d, n: %d"
,
static_cast
<
int32_t
>
(
log_probs_shape
[
0
]),
n
);
if
(
batch_size
!=
n
)
{
SHERPA_ONNX_LOGE
(
"Size mismatch! log_probs.size(0) %d, n: %d"
,
batch_size
,
n
);
exit
(
-
1
);
}
int32_t
batch_size
=
static_cast
<
int32_t
>
(
log_probs_shape
[
0
]);
int32_t
num_frames
=
static_cast
<
int32_t
>
(
log_probs_shape
[
1
]);
int32_t
vocab_size
=
static_cast
<
int32_t
>
(
log_probs_shape
[
2
]);
const
float
*
p
=
log_probs
.
GetTensorData
<
float
>
();
const
float
*
p
=
log_probs
;
for
(
int32_t
i
=
0
;
i
!=
batch_size
;
++
i
)
{
DecodeOne
(
p
+
i
*
num_frames
*
vocab_size
,
num_frames
,
vocab_size
,
...
...
sherpa-onnx/csrc/online-ctc-fst-decoder.h
查看文件 @
d5e7b51
...
...
@@ -19,8 +19,8 @@ class OnlineCtcFstDecoder : public OnlineCtcDecoder {
OnlineCtcFstDecoder
(
const
OnlineCtcFstDecoderConfig
&
config
,
int32_t
blank_id
);
void
Decode
(
Ort
::
Value
log_probs
,
std
::
vector
<
OnlineCtcDecoderResult
>
*
results
,
void
Decode
(
const
float
*
log_probs
,
int32_t
batch_size
,
int32_t
num_frames
,
int32_t
vocab_size
,
std
::
vector
<
OnlineCtcDecoderResult
>
*
results
,
OnlineStream
**
ss
=
nullptr
,
int32_t
n
=
0
)
override
;
std
::
unique_ptr
<
kaldi_decoder
::
FasterDecoder
>
CreateFasterDecoder
()
...
...
sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc
查看文件 @
d5e7b51
...
...
@@ -13,23 +13,16 @@
namespace
sherpa_onnx
{
void
OnlineCtcGreedySearchDecoder
::
Decode
(
Ort
::
Value
log_probs
,
std
::
vector
<
OnlineCtcDecoderResult
>
*
results
,
const
float
*
log_probs
,
int32_t
batch_size
,
int32_t
num_frames
,
int32_t
vocab_size
,
std
::
vector
<
OnlineCtcDecoderResult
>
*
results
,
OnlineStream
**
/*ss=nullptr*/
,
int32_t
/*n = 0*/
)
{
std
::
vector
<
int64_t
>
log_probs_shape
=
log_probs
.
GetTensorTypeAndShapeInfo
().
GetShape
();
if
(
log_probs_shape
[
0
]
!=
results
->
size
())
{
if
(
batch_size
!=
results
->
size
())
{
SHERPA_ONNX_LOGE
(
"Size mismatch! log_probs.size(0) %d, results.size(0): %d"
,
static_cast
<
int32_t
>
(
log_probs_shape
[
0
]),
static_cast
<
int32_t
>
(
results
->
size
()));
batch_size
,
static_cast
<
int32_t
>
(
results
->
size
()));
exit
(
-
1
);
}
int32_t
batch_size
=
static_cast
<
int32_t
>
(
log_probs_shape
[
0
]);
int32_t
num_frames
=
static_cast
<
int32_t
>
(
log_probs_shape
[
1
]);
int32_t
vocab_size
=
static_cast
<
int32_t
>
(
log_probs_shape
[
2
]);
const
float
*
p
=
log_probs
.
GetTensorData
<
float
>
();
const
float
*
p
=
log_probs
;
for
(
int32_t
b
=
0
;
b
!=
batch_size
;
++
b
)
{
auto
&
r
=
(
*
results
)[
b
];
...
...
sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h
查看文件 @
d5e7b51
...
...
@@ -16,8 +16,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder {
explicit
OnlineCtcGreedySearchDecoder
(
int32_t
blank_id
)
:
blank_id_
(
blank_id
)
{}
void
Decode
(
Ort
::
Value
log_probs
,
std
::
vector
<
OnlineCtcDecoderResult
>
*
results
,
void
Decode
(
const
float
*
log_probs
,
int32_t
batch_size
,
int32_t
num_frames
,
int32_t
vocab_size
,
std
::
vector
<
OnlineCtcDecoderResult
>
*
results
,
OnlineStream
**
ss
=
nullptr
,
int32_t
n
=
0
)
override
;
private
:
...
...
sherpa-onnx/csrc/online-model-config.cc
查看文件 @
d5e7b51
...
...
@@ -76,6 +76,15 @@ bool OnlineModelConfig::Validate() const {
transducer
.
decoder
.
c_str
(),
transducer
.
joiner
.
c_str
());
return
false
;
}
if
(
!
zipformer2_ctc
.
model
.
empty
()
&&
EndsWith
(
zipformer2_ctc
.
model
,
".rknn"
))
{
SHERPA_ONNX_LOGE
(
"--provider is %s, which is not rknn, but you pass rknn model "
"filename for zipformer2_ctc: '%s'"
,
provider_config
.
provider
.
c_str
(),
zipformer2_ctc
.
model
.
c_str
());
return
false
;
}
}
if
(
provider_config
.
provider
==
"rknn"
)
{
...
...
@@ -89,6 +98,15 @@ bool OnlineModelConfig::Validate() const {
transducer
.
joiner
.
c_str
());
return
false
;
}
if
(
!
zipformer2_ctc
.
model
.
empty
()
&&
EndsWith
(
zipformer2_ctc
.
model
,
".onnx"
))
{
SHERPA_ONNX_LOGE
(
"--provider rknn, but you pass onnx model filename for "
"zipformer2_ctc: '%s'"
,
zipformer2_ctc
.
model
.
c_str
());
return
false
;
}
}
if
(
!
tokens_buf
.
empty
()
&&
FileExists
(
tokens
))
{
...
...
sherpa-onnx/csrc/online-recognizer-ctc-impl.h
查看文件 @
d5e7b51
...
...
@@ -24,11 +24,10 @@
namespace
sherpa_onnx
{
static
OnlineRecognizerResult
Convert
(
const
OnlineCtcDecoderResult
&
src
,
OnlineRecognizerResult
ConvertCtc
(
const
OnlineCtcDecoderResult
&
src
,
const
SymbolTable
&
sym_table
,
float
frame_shift_ms
,
int32_t
subsampling_factor
,
int32_t
segment
,
int32_t
subsampling_factor
,
int32_t
segment
,
int32_t
frames_since_start
)
{
OnlineRecognizerResult
r
;
r
.
tokens
.
reserve
(
src
.
tokens
.
size
());
...
...
@@ -182,7 +181,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
next_states
=
model_
->
UnStackStates
(
std
::
move
(
out_states
));
decoder_
->
Decode
(
std
::
move
(
out
[
0
]),
&
results
,
ss
,
n
);
std
::
vector
<
int64_t
>
log_probs_shape
=
out
[
0
].
GetTensorTypeAndShapeInfo
().
GetShape
();
decoder_
->
Decode
(
out
[
0
].
GetTensorData
<
float
>
(),
log_probs_shape
[
0
],
log_probs_shape
[
1
],
log_probs_shape
[
2
],
&
results
,
ss
,
n
);
for
(
int32_t
k
=
0
;
k
!=
n
;
++
k
)
{
ss
[
k
]
->
SetCtcResult
(
results
[
k
]);
...
...
@@ -196,7 +198,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed
int32_t
frame_shift_ms
=
10
;
int32_t
subsampling_factor
=
4
;
auto
r
=
Convert
(
decoder_result
,
sym_
,
frame_shift_ms
,
subsampling_factor
,
auto
r
=
ConvertCtc
(
decoder_result
,
sym_
,
frame_shift_ms
,
subsampling_factor
,
s
->
GetCurrentSegment
(),
s
->
GetNumFramesSinceStart
());
r
.
text
=
ApplyInverseTextNormalization
(
r
.
text
);
return
r
;
...
...
@@ -306,7 +309,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
std
::
vector
<
OnlineCtcDecoderResult
>
results
(
1
);
results
[
0
]
=
std
::
move
(
s
->
GetCtcResult
());
decoder_
->
Decode
(
std
::
move
(
out
[
0
]),
&
results
,
&
s
,
1
);
std
::
vector
<
int64_t
>
log_probs_shape
=
out
[
0
].
GetTensorTypeAndShapeInfo
().
GetShape
();
decoder_
->
Decode
(
out
[
0
].
GetTensorData
<
float
>
(),
log_probs_shape
[
0
],
log_probs_shape
[
1
],
log_probs_shape
[
2
],
&
results
,
&
s
,
1
);
s
->
SetCtcResult
(
results
[
0
]);
}
...
...
sherpa-onnx/csrc/online-recognizer-impl.cc
查看文件 @
d5e7b51
...
...
@@ -27,6 +27,7 @@
#include "sherpa-onnx/csrc/text-utils.h"
#if SHERPA_ONNX_ENABLE_RKNN
#include "sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h"
#include "sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h"
#endif
...
...
@@ -37,12 +38,15 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
if
(
config
.
model_config
.
provider_config
.
provider
==
"rknn"
)
{
#if SHERPA_ONNX_ENABLE_RKNN
// Currently, only zipformer v1 is suported for rknn
if
(
config
.
model_config
.
transducer
.
encoder
.
empty
())
{
if
(
config
.
model_config
.
transducer
.
encoder
.
empty
()
&&
config
.
model_config
.
zipformer2_ctc
.
model
.
empty
())
{
SHERPA_ONNX_LOGE
(
"Only Zipformer transducers are currently supported by rknn. "
"Fallback to CPU"
);
}
else
{
"Only Zipformer transducers and CTC models are currently supported "
"by rknn. Fallback to CPU"
);
}
else
if
(
!
config
.
model_config
.
transducer
.
encoder
.
empty
())
{
return
std
::
make_unique
<
OnlineRecognizerTransducerRknnImpl
>
(
config
);
}
else
if
(
!
config
.
model_config
.
zipformer2_ctc
.
model
.
empty
())
{
return
std
::
make_unique
<
OnlineRecognizerCtcRknnImpl
>
(
config
);
}
#else
SHERPA_ONNX_LOGE
(
...
...
sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h
0 → 100644
查看文件 @
d5e7b51
// sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
#include <algorithm>
#include <ios>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h"
#include "sherpa-onnx/csrc/symbol-table.h"
namespace
sherpa_onnx
{
// defined in ../online-recognizer-ctc-impl.h
OnlineRecognizerResult
ConvertCtc
(
const
OnlineCtcDecoderResult
&
src
,
const
SymbolTable
&
sym_table
,
float
frame_shift_ms
,
int32_t
subsampling_factor
,
int32_t
segment
,
int32_t
frames_since_start
);
class
OnlineRecognizerCtcRknnImpl
:
public
OnlineRecognizerImpl
{
public
:
explicit
OnlineRecognizerCtcRknnImpl
(
const
OnlineRecognizerConfig
&
config
)
:
OnlineRecognizerImpl
(
config
),
config_
(
config
),
model_
(
std
::
make_unique
<
OnlineZipformerCtcModelRknn
>
(
config
.
model_config
)),
endpoint_
(
config_
.
endpoint_config
)
{
if
(
!
config
.
model_config
.
tokens_buf
.
empty
())
{
sym_
=
SymbolTable
(
config
.
model_config
.
tokens_buf
,
false
);
}
else
{
/// assuming tokens_buf and tokens are guaranteed not being both empty
sym_
=
SymbolTable
(
config
.
model_config
.
tokens
,
true
);
}
InitDecoder
();
}
template
<
typename
Manager
>
explicit
OnlineRecognizerCtcRknnImpl
(
Manager
*
mgr
,
const
OnlineRecognizerConfig
&
config
)
:
OnlineRecognizerImpl
(
mgr
,
config
),
config_
(
config
),
model_
(
std
::
make_unique
<
OnlineZipformerCtcModelRknn
>
(
config
.
model_config
)),
sym_
(
mgr
,
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
InitDecoder
();
}
std
::
unique_ptr
<
OnlineStream
>
CreateStream
()
const
override
{
auto
stream
=
std
::
make_unique
<
OnlineStreamRknn
>
(
config_
.
feat_config
);
stream
->
SetZipformerEncoderStates
(
model_
->
GetInitStates
());
stream
->
SetFasterDecoder
(
decoder_
->
CreateFasterDecoder
());
return
stream
;
}
bool
IsReady
(
OnlineStream
*
s
)
const
override
{
return
s
->
GetNumProcessedFrames
()
+
model_
->
ChunkSize
()
<
s
->
NumFramesReady
();
}
void
DecodeStreams
(
OnlineStream
**
ss
,
int32_t
n
)
const
override
{
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
DecodeStream
(
reinterpret_cast
<
OnlineStreamRknn
*>
(
ss
[
i
]));
}
}
OnlineRecognizerResult
GetResult
(
OnlineStream
*
s
)
const
override
{
OnlineCtcDecoderResult
decoder_result
=
s
->
GetCtcResult
();
// TODO(fangjun): Remember to change these constants if needed
int32_t
frame_shift_ms
=
10
;
int32_t
subsampling_factor
=
4
;
auto
r
=
ConvertCtc
(
decoder_result
,
sym_
,
frame_shift_ms
,
subsampling_factor
,
s
->
GetCurrentSegment
(),
s
->
GetNumFramesSinceStart
());
r
.
text
=
ApplyInverseTextNormalization
(
r
.
text
);
return
r
;
}
bool
IsEndpoint
(
OnlineStream
*
s
)
const
override
{
if
(
!
config_
.
enable_endpoint
)
{
return
false
;
}
int32_t
num_processed_frames
=
s
->
GetNumProcessedFrames
();
// frame shift is 10 milliseconds
float
frame_shift_in_seconds
=
0
.
01
;
// subsampling factor is 4
int32_t
trailing_silence_frames
=
s
->
GetCtcResult
().
num_trailing_blanks
*
4
;
return
endpoint_
.
IsEndpoint
(
num_processed_frames
,
trailing_silence_frames
,
frame_shift_in_seconds
);
}
void
Reset
(
OnlineStream
*
s
)
const
override
{
// segment is incremented only when the last
// result is not empty
const
auto
&
r
=
s
->
GetCtcResult
();
if
(
!
r
.
tokens
.
empty
())
{
s
->
GetCurrentSegment
()
+=
1
;
}
// clear result
s
->
SetCtcResult
({});
// clear states
reinterpret_cast
<
OnlineStreamRknn
*>
(
s
)
->
SetZipformerEncoderStates
(
model_
->
GetInitStates
());
s
->
GetFasterDecoderProcessedFrames
()
=
0
;
// Note: We only update counters. The underlying audio samples
// are not discarded.
s
->
Reset
();
}
private
:
void
InitDecoder
()
{
if
(
!
sym_
.
Contains
(
"<blk>"
)
&&
!
sym_
.
Contains
(
"<eps>"
)
&&
!
sym_
.
Contains
(
"<blank>"
))
{
SHERPA_ONNX_LOGE
(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> or <blank> and its ID."
);
exit
(
-
1
);
}
int32_t
blank_id
=
0
;
if
(
sym_
.
Contains
(
"<blk>"
))
{
blank_id
=
sym_
[
"<blk>"
];
}
else
if
(
sym_
.
Contains
(
"<eps>"
))
{
// for tdnn models of the yesno recipe from icefall
blank_id
=
sym_
[
"<eps>"
];
}
else
if
(
sym_
.
Contains
(
"<blank>"
))
{
// for WeNet CTC models
blank_id
=
sym_
[
"<blank>"
];
}
if
(
!
config_
.
ctc_fst_decoder_config
.
graph
.
empty
())
{
decoder_
=
std
::
make_unique
<
OnlineCtcFstDecoder
>
(
config_
.
ctc_fst_decoder_config
,
blank_id
);
}
else
if
(
config_
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineCtcGreedySearchDecoder
>
(
blank_id
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported decoding method: %s for streaming CTC models"
,
config_
.
decoding_method
.
c_str
());
exit
(
-
1
);
}
}
void
DecodeStream
(
OnlineStreamRknn
*
s
)
const
{
int32_t
chunk_size
=
model_
->
ChunkSize
();
int32_t
chunk_shift
=
model_
->
ChunkShift
();
int32_t
feat_dim
=
s
->
FeatureDim
();
const
auto
num_processed_frames
=
s
->
GetNumProcessedFrames
();
std
::
vector
<
float
>
features
=
s
->
GetFrames
(
num_processed_frames
,
chunk_size
);
s
->
GetNumProcessedFrames
()
+=
chunk_shift
;
auto
&
states
=
s
->
GetZipformerEncoderStates
();
auto
p
=
model_
->
Run
(
features
,
std
::
move
(
states
));
states
=
std
::
move
(
p
.
second
);
std
::
vector
<
OnlineCtcDecoderResult
>
results
(
1
);
results
[
0
]
=
std
::
move
(
s
->
GetCtcResult
());
auto
attr
=
model_
->
GetOutAttr
();
decoder_
->
Decode
(
p
.
first
.
data
(),
attr
.
dims
[
0
],
attr
.
dims
[
1
],
attr
.
dims
[
2
],
&
results
,
reinterpret_cast
<
OnlineStream
**>
(
&
s
),
1
);
s
->
SetCtcResult
(
results
[
0
]);
}
private
:
OnlineRecognizerConfig
config_
;
std
::
unique_ptr
<
OnlineZipformerCtcModelRknn
>
model_
;
std
::
unique_ptr
<
OnlineCtcDecoder
>
decoder_
;
SymbolTable
sym_
;
Endpoint
endpoint_
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
...
...
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
0 → 100644
查看文件 @
d5e7b51
// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h"
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/rknn/macros.h"
#include "sherpa-onnx/csrc/rknn/utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace
sherpa_onnx
{
class
OnlineZipformerCtcModelRknn
::
Impl
{
public
:
~
Impl
()
{
auto
ret
=
rknn_destroy
(
ctx_
);
if
(
ret
!=
RKNN_SUCC
)
{
SHERPA_ONNX_LOGE
(
"Failed to destroy the context"
);
}
}
explicit
Impl
(
const
OnlineModelConfig
&
config
)
:
config_
(
config
)
{
{
auto
buf
=
ReadFile
(
config
.
zipformer2_ctc
.
model
);
Init
(
buf
.
data
(),
buf
.
size
());
}
int32_t
ret
=
RKNN_SUCC
;
switch
(
config_
.
num_threads
)
{
case
1
:
ret
=
rknn_set_core_mask
(
ctx_
,
RKNN_NPU_CORE_AUTO
);
break
;
case
0
:
ret
=
rknn_set_core_mask
(
ctx_
,
RKNN_NPU_CORE_0
);
break
;
case
-
1
:
ret
=
rknn_set_core_mask
(
ctx_
,
RKNN_NPU_CORE_1
);
break
;
case
-
2
:
ret
=
rknn_set_core_mask
(
ctx_
,
RKNN_NPU_CORE_2
);
break
;
case
-
3
:
ret
=
rknn_set_core_mask
(
ctx_
,
RKNN_NPU_CORE_0_1
);
break
;
case
-
4
:
ret
=
rknn_set_core_mask
(
ctx_
,
RKNN_NPU_CORE_0_1_2
);
break
;
default
:
SHERPA_ONNX_LOGE
(
"Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
"1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d"
,
config_
.
num_threads
);
break
;
}
if
(
ret
!=
RKNN_SUCC
)
{
SHERPA_ONNX_LOGE
(
"Failed to select npu core to run the model (You can ignore it if "
"you "
"are not using RK3588."
);
}
}
// TODO(fangjun): Support Android
std
::
vector
<
std
::
vector
<
uint8_t
>>
GetInitStates
()
const
{
// input_attrs_[0] is for the feature
// input_attrs_[1:] is for states
// so we use -1 here
std
::
vector
<
std
::
vector
<
uint8_t
>>
states
(
input_attrs_
.
size
()
-
1
);
int32_t
i
=
-
1
;
for
(
auto
&
attr
:
input_attrs_
)
{
i
+=
1
;
if
(
i
==
0
)
{
// skip processing the attr for features.
continue
;
}
if
(
attr
.
type
==
RKNN_TENSOR_FLOAT16
)
{
states
[
i
-
1
].
resize
(
attr
.
n_elems
*
sizeof
(
float
));
}
else
if
(
attr
.
type
==
RKNN_TENSOR_INT64
)
{
states
[
i
-
1
].
resize
(
attr
.
n_elems
*
sizeof
(
int64_t
));
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported tensor type: %d, %s"
,
attr
.
type
,
get_type_string
(
attr
.
type
));
SHERPA_ONNX_EXIT
(
-
1
);
}
}
return
states
;
}
std
::
pair
<
std
::
vector
<
float
>
,
std
::
vector
<
std
::
vector
<
uint8_t
>>>
Run
(
std
::
vector
<
float
>
features
,
std
::
vector
<
std
::
vector
<
uint8_t
>>
states
)
const
{
std
::
vector
<
rknn_input
>
inputs
(
input_attrs_
.
size
());
for
(
int32_t
i
=
0
;
i
<
static_cast
<
int32_t
>
(
inputs
.
size
());
++
i
)
{
auto
&
input
=
inputs
[
i
];
auto
&
attr
=
input_attrs_
[
i
];
input
.
index
=
attr
.
index
;
if
(
attr
.
type
==
RKNN_TENSOR_FLOAT16
)
{
input
.
type
=
RKNN_TENSOR_FLOAT32
;
}
else
if
(
attr
.
type
==
RKNN_TENSOR_INT64
)
{
input
.
type
=
RKNN_TENSOR_INT64
;
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported tensor type %d, %s"
,
attr
.
type
,
get_type_string
(
attr
.
type
));
SHERPA_ONNX_EXIT
(
-
1
);
}
input
.
fmt
=
attr
.
fmt
;
if
(
i
==
0
)
{
input
.
buf
=
reinterpret_cast
<
void
*>
(
features
.
data
());
input
.
size
=
features
.
size
()
*
sizeof
(
float
);
}
else
{
input
.
buf
=
reinterpret_cast
<
void
*>
(
states
[
i
-
1
].
data
());
input
.
size
=
states
[
i
-
1
].
size
();
}
}
std
::
vector
<
float
>
out
(
output_attrs_
[
0
].
n_elems
);
// Note(fangjun): We can reuse the memory from input argument `states`
// auto next_states = GetInitStates();
auto
&
next_states
=
states
;
std
::
vector
<
rknn_output
>
outputs
(
output_attrs_
.
size
());
for
(
int32_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
auto
&
output
=
outputs
[
i
];
auto
&
attr
=
output_attrs_
[
i
];
output
.
index
=
attr
.
index
;
output
.
is_prealloc
=
1
;
if
(
attr
.
type
==
RKNN_TENSOR_FLOAT16
)
{
output
.
want_float
=
1
;
}
else
if
(
attr
.
type
==
RKNN_TENSOR_INT64
)
{
output
.
want_float
=
0
;
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported tensor type %d, %s"
,
attr
.
type
,
get_type_string
(
attr
.
type
));
SHERPA_ONNX_EXIT
(
-
1
);
}
if
(
i
==
0
)
{
output
.
size
=
out
.
size
()
*
sizeof
(
float
);
output
.
buf
=
reinterpret_cast
<
void
*>
(
out
.
data
());
}
else
{
output
.
size
=
next_states
[
i
-
1
].
size
();
output
.
buf
=
reinterpret_cast
<
void
*>
(
next_states
[
i
-
1
].
data
());
}
}
auto
ret
=
rknn_inputs_set
(
ctx_
,
inputs
.
size
(),
inputs
.
data
());
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to set inputs"
);
ret
=
rknn_run
(
ctx_
,
nullptr
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to run the model"
);
ret
=
rknn_outputs_get
(
ctx_
,
outputs
.
size
(),
outputs
.
data
(),
nullptr
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to get model output"
);
for
(
int32_t
i
=
0
;
i
<
next_states
.
size
();
++
i
)
{
const
auto
&
attr
=
input_attrs_
[
i
+
1
];
if
(
attr
.
n_dims
==
4
)
{
// TODO(fangjun): The transpose is copied from
// https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22
// I don't understand why we need to do that.
std
::
vector
<
uint8_t
>
dst
(
next_states
[
i
].
size
());
int32_t
n
=
attr
.
dims
[
0
];
int32_t
h
=
attr
.
dims
[
1
];
int32_t
w
=
attr
.
dims
[
2
];
int32_t
c
=
attr
.
dims
[
3
];
ConvertNCHWtoNHWC
(
reinterpret_cast
<
const
float
*>
(
next_states
[
i
].
data
()),
n
,
c
,
h
,
w
,
reinterpret_cast
<
float
*>
(
dst
.
data
()));
next_states
[
i
]
=
std
::
move
(
dst
);
}
}
return
{
std
::
move
(
out
),
std
::
move
(
next_states
)};
}
int32_t
ChunkSize
()
const
{
return
T_
;
}
int32_t
ChunkShift
()
const
{
return
decode_chunk_len_
;
}
int32_t
VocabSize
()
const
{
return
vocab_size_
;
}
rknn_tensor_attr
GetOutAttr
()
const
{
return
output_attrs_
[
0
];
}
private
:
void
Init
(
void
*
model_data
,
size_t
model_data_length
)
{
auto
ret
=
rknn_init
(
&
ctx_
,
model_data
,
model_data_length
,
0
,
nullptr
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to init model '%s'"
,
config_
.
zipformer2_ctc
.
model
.
c_str
());
if
(
config_
.
debug
)
{
rknn_sdk_version
v
;
ret
=
rknn_query
(
ctx_
,
RKNN_QUERY_SDK_VERSION
,
&
v
,
sizeof
(
v
));
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to get rknn sdk version"
);
SHERPA_ONNX_LOGE
(
"sdk api version: %s, driver version: %s"
,
v
.
api_version
,
v
.
drv_version
);
}
rknn_input_output_num
io_num
;
ret
=
rknn_query
(
ctx_
,
RKNN_QUERY_IN_OUT_NUM
,
&
io_num
,
sizeof
(
io_num
));
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to get I/O information for the model"
);
if
(
config_
.
debug
)
{
SHERPA_ONNX_LOGE
(
"model: %d inputs, %d outputs"
,
static_cast
<
int32_t
>
(
io_num
.
n_input
),
static_cast
<
int32_t
>
(
io_num
.
n_output
));
}
input_attrs_
.
resize
(
io_num
.
n_input
);
output_attrs_
.
resize
(
io_num
.
n_output
);
int32_t
i
=
0
;
for
(
auto
&
attr
:
input_attrs_
)
{
memset
(
&
attr
,
0
,
sizeof
(
attr
));
attr
.
index
=
i
;
ret
=
rknn_query
(
ctx_
,
RKNN_QUERY_INPUT_ATTR
,
&
attr
,
sizeof
(
attr
));
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to get attr for model input %d"
,
i
);
i
+=
1
;
}
if
(
config_
.
debug
)
{
std
::
ostringstream
os
;
std
::
string
sep
;
for
(
auto
&
attr
:
input_attrs_
)
{
os
<<
sep
<<
ToString
(
attr
);
sep
=
"
\n
"
;
}
SHERPA_ONNX_LOGE
(
"
\n
----------Model inputs info----------
\n
%s"
,
os
.
str
().
c_str
());
}
i
=
0
;
for
(
auto
&
attr
:
output_attrs_
)
{
memset
(
&
attr
,
0
,
sizeof
(
attr
));
attr
.
index
=
i
;
ret
=
rknn_query
(
ctx_
,
RKNN_QUERY_OUTPUT_ATTR
,
&
attr
,
sizeof
(
attr
));
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to get attr for model output %d"
,
i
);
i
+=
1
;
}
if
(
config_
.
debug
)
{
std
::
ostringstream
os
;
std
::
string
sep
;
for
(
auto
&
attr
:
output_attrs_
)
{
os
<<
sep
<<
ToString
(
attr
);
sep
=
"
\n
"
;
}
SHERPA_ONNX_LOGE
(
"
\n
----------Model outputs info----------
\n
%s"
,
os
.
str
().
c_str
());
}
rknn_custom_string
custom_string
;
ret
=
rknn_query
(
ctx_
,
RKNN_QUERY_CUSTOM_STRING
,
&
custom_string
,
sizeof
(
custom_string
));
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to read custom string from the model"
);
if
(
config_
.
debug
)
{
SHERPA_ONNX_LOGE
(
"customs string: %s"
,
custom_string
.
string
);
}
auto
meta
=
Parse
(
custom_string
);
if
(
config_
.
debug
)
{
for
(
const
auto
&
p
:
meta
)
{
SHERPA_ONNX_LOGE
(
"%s: %s"
,
p
.
first
.
c_str
(),
p
.
second
.
c_str
());
}
}
if
(
meta
.
count
(
"T"
))
{
T_
=
atoi
(
meta
.
at
(
"T"
).
c_str
());
}
if
(
meta
.
count
(
"decode_chunk_len"
))
{
decode_chunk_len_
=
atoi
(
meta
.
at
(
"decode_chunk_len"
).
c_str
());
}
vocab_size_
=
output_attrs_
[
0
].
dims
[
2
];
if
(
config_
.
debug
)
{
#if __OHOS__
SHERPA_ONNX_LOGE
(
"T: %{public}d"
,
T_
);
SHERPA_ONNX_LOGE
(
"decode_chunk_len_: %{public}d"
,
decode_chunk_len_
);
SHERPA_ONNX_LOGE
(
"vocab_size: %{public}d"
,
vocab_size
);
#else
SHERPA_ONNX_LOGE
(
"T: %d"
,
T_
);
SHERPA_ONNX_LOGE
(
"decode_chunk_len_: %d"
,
decode_chunk_len_
);
SHERPA_ONNX_LOGE
(
"vocab_size: %d"
,
vocab_size_
);
#endif
}
if
(
T_
==
0
)
{
SHERPA_ONNX_LOGE
(
"Invalid T. Please use the script from icefall to export your model"
);
SHERPA_ONNX_EXIT
(
-
1
);
}
if
(
decode_chunk_len_
==
0
)
{
SHERPA_ONNX_LOGE
(
"Invalid decode_chunk_len. Please use the script from icefall to "
"export your model"
);
SHERPA_ONNX_EXIT
(
-
1
);
}
}
private
:
OnlineModelConfig
config_
;
rknn_context
ctx_
=
0
;
std
::
vector
<
rknn_tensor_attr
>
input_attrs_
;
std
::
vector
<
rknn_tensor_attr
>
output_attrs_
;
int32_t
T_
=
0
;
int32_t
decode_chunk_len_
=
0
;
int32_t
vocab_size_
=
0
;
};
OnlineZipformerCtcModelRknn
::~
OnlineZipformerCtcModelRknn
()
=
default
;
OnlineZipformerCtcModelRknn
::
OnlineZipformerCtcModelRknn
(
const
OnlineModelConfig
&
config
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
config
))
{}
template
<
typename
Manager
>
OnlineZipformerCtcModelRknn
::
OnlineZipformerCtcModelRknn
(
Manager
*
mgr
,
const
OnlineModelConfig
&
config
)
:
impl_
(
std
::
make_unique
<
OnlineZipformerCtcModelRknn
>
(
mgr
,
config
))
{}
std
::
vector
<
std
::
vector
<
uint8_t
>>
OnlineZipformerCtcModelRknn
::
GetInitStates
()
const
{
return
impl_
->
GetInitStates
();
}
std
::
pair
<
std
::
vector
<
float
>
,
std
::
vector
<
std
::
vector
<
uint8_t
>>>
OnlineZipformerCtcModelRknn
::
Run
(
std
::
vector
<
float
>
features
,
std
::
vector
<
std
::
vector
<
uint8_t
>>
states
)
const
{
return
impl_
->
Run
(
std
::
move
(
features
),
std
::
move
(
states
));
}
int32_t
OnlineZipformerCtcModelRknn
::
ChunkSize
()
const
{
return
impl_
->
ChunkSize
();
}
int32_t
OnlineZipformerCtcModelRknn
::
ChunkShift
()
const
{
return
impl_
->
ChunkShift
();
}
int32_t
OnlineZipformerCtcModelRknn
::
VocabSize
()
const
{
return
impl_
->
VocabSize
();
}
rknn_tensor_attr
OnlineZipformerCtcModelRknn
::
GetOutAttr
()
const
{
return
impl_
->
GetOutAttr
();
}
#if __ANDROID_API__ >= 9
template
OnlineZipformerCtcModelRknn
::
OnlineZipformerCtcModelRknn
(
AAssetManager
*
mgr
,
const
OnlineModelConfig
&
config
);
#endif
#if __OHOS__
template
OnlineZipformerCtcModelRknn
::
OnlineZipformerCtcModelRknn
(
NativeResourceManager
*
mgr
,
const
OnlineModelConfig
&
config
);
#endif
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h
0 → 100644
查看文件 @
d5e7b51
// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
#include <memory>
#include <utility>
#include <vector>
#include "rknn_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-model-config.h"
namespace
sherpa_onnx
{
class
OnlineZipformerCtcModelRknn
{
public
:
~
OnlineZipformerCtcModelRknn
();
explicit
OnlineZipformerCtcModelRknn
(
const
OnlineModelConfig
&
config
);
template
<
typename
Manager
>
OnlineZipformerCtcModelRknn
(
Manager
*
mgr
,
const
OnlineModelConfig
&
config
);
std
::
vector
<
std
::
vector
<
uint8_t
>>
GetInitStates
()
const
;
std
::
pair
<
std
::
vector
<
float
>
,
std
::
vector
<
std
::
vector
<
uint8_t
>>>
Run
(
std
::
vector
<
float
>
features
,
std
::
vector
<
std
::
vector
<
uint8_t
>>
states
)
const
;
int32_t
ChunkSize
()
const
;
int32_t
ChunkShift
()
const
;
int32_t
VocabSize
()
const
;
rknn_tensor_attr
GetOutAttr
()
const
;
private
:
class
Impl
;
std
::
unique_ptr
<
Impl
>
impl_
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
...
...
sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
查看文件 @
d5e7b51
// sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
//
// Copyright (c) 202
3
Xiaomi Corporation
// Copyright (c) 202
5
Xiaomi Corporation
#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
...
...
@@ -22,68 +22,11 @@
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/rknn/macros.h"
#include "sherpa-onnx/csrc/rknn/utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace
sherpa_onnx
{
// chw -> hwc
static
void
Transpose
(
const
float
*
src
,
int32_t
n
,
int32_t
channel
,
int32_t
height
,
int32_t
width
,
float
*
dst
)
{
for
(
int32_t
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int32_t
h
=
0
;
h
<
height
;
++
h
)
{
for
(
int32_t
w
=
0
;
w
<
width
;
++
w
)
{
for
(
int32_t
c
=
0
;
c
<
channel
;
++
c
)
{
// dst[h, w, c] = src[c, h, w]
dst
[
i
*
height
*
width
*
channel
+
h
*
width
*
channel
+
w
*
channel
+
c
]
=
src
[
i
*
height
*
width
*
channel
+
c
*
height
*
width
+
h
*
width
+
w
];
}
}
}
}
}
static
std
::
string
ToString
(
const
rknn_tensor_attr
&
attr
)
{
std
::
ostringstream
os
;
os
<<
"{"
;
os
<<
attr
.
index
;
os
<<
", name: "
<<
attr
.
name
;
os
<<
", shape: ("
;
std
::
string
sep
;
for
(
int32_t
i
=
0
;
i
<
static_cast
<
int32_t
>
(
attr
.
n_dims
);
++
i
)
{
os
<<
sep
<<
attr
.
dims
[
i
];
sep
=
","
;
}
os
<<
")"
;
os
<<
", n_elems: "
<<
attr
.
n_elems
;
os
<<
", size: "
<<
attr
.
size
;
os
<<
", fmt: "
<<
get_format_string
(
attr
.
fmt
);
os
<<
", type: "
<<
get_type_string
(
attr
.
type
);
os
<<
", pass_through: "
<<
(
attr
.
pass_through
?
"true"
:
"false"
);
os
<<
"}"
;
return
os
.
str
();
}
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Parse
(
const
rknn_custom_string
&
custom_string
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
ans
;
std
::
vector
<
std
::
string
>
fields
;
SplitStringToVector
(
custom_string
.
string
,
";"
,
false
,
&
fields
);
std
::
vector
<
std
::
string
>
tmp
;
for
(
const
auto
&
f
:
fields
)
{
SplitStringToVector
(
f
,
"="
,
false
,
&
tmp
);
if
(
tmp
.
size
()
!=
2
)
{
SHERPA_ONNX_LOGE
(
"Invalid custom string %s for %s"
,
custom_string
.
string
,
f
.
c_str
());
SHERPA_ONNX_EXIT
(
-
1
);
}
ans
[
std
::
move
(
tmp
[
0
])]
=
std
::
move
(
tmp
[
1
]);
}
return
ans
;
}
class
OnlineZipformerTransducerModelRknn
::
Impl
{
public
:
~
Impl
()
{
...
...
@@ -285,7 +228,7 @@ class OnlineZipformerTransducerModelRknn::Impl {
for
(
int32_t
i
=
0
;
i
<
next_states
.
size
();
++
i
)
{
const
auto
&
attr
=
encoder_input_attrs_
[
i
+
1
];
if
(
attr
.
n_dims
==
4
)
{
// TODO(fangjun): The
transpose
is copied from
// TODO(fangjun): The
ConvertNCHWtoNHWC
is copied from
// https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22
// I don't understand why we need to do that.
std
::
vector
<
uint8_t
>
dst
(
next_states
[
i
].
size
());
...
...
@@ -293,8 +236,9 @@ class OnlineZipformerTransducerModelRknn::Impl {
int32_t
h
=
attr
.
dims
[
1
];
int32_t
w
=
attr
.
dims
[
2
];
int32_t
c
=
attr
.
dims
[
3
];
Transpose
(
reinterpret_cast
<
const
float
*>
(
next_states
[
i
].
data
()),
n
,
c
,
h
,
w
,
reinterpret_cast
<
float
*>
(
dst
.
data
()));
ConvertNCHWtoNHWC
(
reinterpret_cast
<
const
float
*>
(
next_states
[
i
].
data
()),
n
,
c
,
h
,
w
,
reinterpret_cast
<
float
*>
(
dst
.
data
()));
next_states
[
i
]
=
std
::
move
(
dst
);
}
}
...
...
@@ -527,11 +471,9 @@ class OnlineZipformerTransducerModelRknn::Impl {
#if __OHOS__
SHERPA_ONNX_LOGE
(
"T: %{public}d"
,
T_
);
SHERPA_ONNX_LOGE
(
"decode_chunk_len_: %{public}d"
,
decode_chunk_len_
);
SHERPA_ONNX_LOGE
(
"context_size: %{public}d"
,
context_size_
);
#else
SHERPA_ONNX_LOGE
(
"T: %d"
,
T_
);
SHERPA_ONNX_LOGE
(
"decode_chunk_len_: %d"
,
decode_chunk_len_
);
SHERPA_ONNX_LOGE
(
"context_size: %d"
,
context_size_
);
#endif
}
}
...
...
@@ -597,6 +539,11 @@ class OnlineZipformerTransducerModelRknn::Impl {
SHERPA_ONNX_EXIT
(
-
1
);
}
context_size_
=
decoder_input_attrs_
[
0
].
dims
[
1
];
if
(
config_
.
debug
)
{
SHERPA_ONNX_LOGE
(
"context_size: %d"
,
context_size_
);
}
i
=
0
;
for
(
auto
&
attr
:
decoder_output_attrs_
)
{
memset
(
&
attr
,
0
,
sizeof
(
attr
));
...
...
sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h
查看文件 @
d5e7b51
...
...
@@ -14,8 +14,11 @@
namespace
sherpa_onnx
{
// this is for zipformer v1, i.e., the folder
// pruned_transducer_statelss7_streaming from icefall
// this is for zipformer v1 and v2, i.e., the folder
// pruned_transducer_statelss7_streaming
// and
// zipformer
// from icefall
class
OnlineZipformerTransducerModelRknn
{
public
:
~
OnlineZipformerTransducerModelRknn
();
...
...
sherpa-onnx/csrc/rknn/utils.cc
0 → 100644
查看文件 @
d5e7b51
// sherpa-onnx/csrc/utils.cc
//
// Copyright 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/rknn/utils.h"
#include <sstream>
#include <unordered_map>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace
sherpa_onnx
{
void
ConvertNCHWtoNHWC
(
const
float
*
src
,
int32_t
n
,
int32_t
channel
,
int32_t
height
,
int32_t
width
,
float
*
dst
)
{
for
(
int32_t
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int32_t
h
=
0
;
h
<
height
;
++
h
)
{
for
(
int32_t
w
=
0
;
w
<
width
;
++
w
)
{
for
(
int32_t
c
=
0
;
c
<
channel
;
++
c
)
{
// dst[h, w, c] = src[c, h, w]
dst
[
i
*
height
*
width
*
channel
+
h
*
width
*
channel
+
w
*
channel
+
c
]
=
src
[
i
*
height
*
width
*
channel
+
c
*
height
*
width
+
h
*
width
+
w
];
}
}
}
}
}
std
::
string
ToString
(
const
rknn_tensor_attr
&
attr
)
{
std
::
ostringstream
os
;
os
<<
"{"
;
os
<<
attr
.
index
;
os
<<
", name: "
<<
attr
.
name
;
os
<<
", shape: ("
;
std
::
string
sep
;
for
(
int32_t
i
=
0
;
i
<
static_cast
<
int32_t
>
(
attr
.
n_dims
);
++
i
)
{
os
<<
sep
<<
attr
.
dims
[
i
];
sep
=
","
;
}
os
<<
")"
;
os
<<
", n_elems: "
<<
attr
.
n_elems
;
os
<<
", size: "
<<
attr
.
size
;
os
<<
", fmt: "
<<
get_format_string
(
attr
.
fmt
);
os
<<
", type: "
<<
get_type_string
(
attr
.
type
);
os
<<
", pass_through: "
<<
(
attr
.
pass_through
?
"true"
:
"false"
);
os
<<
"}"
;
return
os
.
str
();
}
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Parse
(
const
rknn_custom_string
&
custom_string
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
ans
;
std
::
vector
<
std
::
string
>
fields
;
SplitStringToVector
(
custom_string
.
string
,
";"
,
false
,
&
fields
);
std
::
vector
<
std
::
string
>
tmp
;
for
(
const
auto
&
f
:
fields
)
{
SplitStringToVector
(
f
,
"="
,
false
,
&
tmp
);
if
(
tmp
.
size
()
!=
2
)
{
SHERPA_ONNX_LOGE
(
"Invalid custom string %s for %s"
,
custom_string
.
string
,
f
.
c_str
());
SHERPA_ONNX_EXIT
(
-
1
);
}
ans
[
std
::
move
(
tmp
[
0
])]
=
std
::
move
(
tmp
[
1
]);
}
return
ans
;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/rknn/utils.h
0 → 100644
查看文件 @
d5e7b51
// sherpa-onnx/csrc/utils.h
//
// Copyright 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_UTILS_H_
#define SHERPA_ONNX_CSRC_RKNN_UTILS_H_
#include <string>
#include <unordered_map>
#include "rknn_api.h" // NOLINT
namespace
sherpa_onnx
{
void
ConvertNCHWtoNHWC
(
const
float
*
src
,
int32_t
n
,
int32_t
channel
,
int32_t
height
,
int32_t
width
,
float
*
dst
);
std
::
string
ToString
(
const
rknn_tensor_attr
&
attr
);
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Parse
(
const
rknn_custom_string
&
custom_string
);
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_UTILS_H_
...
...
sherpa-onnx/csrc/sherpa-onnx.cc
查看文件 @
d5e7b51
...
...
@@ -83,6 +83,7 @@ for a list of pre-trained models to download.
po
.
Read
(
argc
,
argv
);
if
(
po
.
NumArgs
()
<
1
)
{
po
.
PrintUsage
();
fprintf
(
stderr
,
"Error! Please provide at lease 1 wav file
\n
"
);
exit
(
EXIT_FAILURE
);
}
...
...
请
注册
或
登录
后发表评论