Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-02-19 10:39:07 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-02-19 10:39:07 +0800
Commit
0f6f58d1d355fe7b8bf6db37669175c2136c371c
0f6f58d1
1 parent
710edaa6
Add online transducer decoder (#27)
显示空白字符变更
内嵌
并排对比
正在显示
12 个修改的文件
包含
176 行增加
和
71 行删除
cmake/kaldi-native-fbank.cmake
cmake/onnxruntime.cmake
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/decode.h
sherpa-onnx/csrc/features.cc
sherpa-onnx/csrc/features.h
sherpa-onnx/csrc/online-transducer-decoder.h
sherpa-onnx/csrc/decode.cc → sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
sherpa-onnx/csrc/onnx-utils.cc
sherpa-onnx/csrc/onnx-utils.h
sherpa-onnx/csrc/sherpa-onnx.cc
cmake/kaldi-native-fbank.cmake
查看文件 @
0f6f58d
function
(
download_kaldi_native_fbank
)
include
(
FetchContent
)
set
(
kaldi_native_fbank_URL
"https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.11.tar.gz"
)
set
(
kaldi_native_fbank_HASH
"SHA256=e69ae25ef6f30566ef31ca949dd1b0b8ec3a827caeba93a61d82bb848dac5d69"
)
set
(
kaldi_native_fbank_URL
"https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.12.tar.gz"
)
set
(
kaldi_native_fbank_HASH
"SHA256=8f4dfc3f6ddb1adcd9ac0ae87743ebc6cbcae147aacf9d46e76fa54134e12b44"
)
set
(
KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL
""
FORCE
)
set
(
KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL
""
FORCE
)
...
...
@@ -11,10 +11,11 @@ function(download_kaldi_native_fbank)
# If you don't have access to the Internet,
# please pre-download kaldi-native-fbank
set
(
possible_file_locations
$ENV{HOME}/Downloads/kaldi-native-fbank-1.11.tar.gz
${
PROJECT_SOURCE_DIR
}
/kaldi-native-fbank-1.11.tar.gz
${
PROJECT_BINARY_DIR
}
/kaldi-native-fbank-1.11.tar.gz
/tmp/kaldi-native-fbank-1.11.tar.gz
$ENV{HOME}/Downloads/kaldi-native-fbank-1.12.tar.gz
${
PROJECT_SOURCE_DIR
}
/kaldi-native-fbank-1.12.tar.gz
${
PROJECT_BINARY_DIR
}
/kaldi-native-fbank-1.12.tar.gz
/tmp/kaldi-native-fbank-1.12.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.12.tar.gz
)
foreach
(
f IN LISTS possible_file_locations
)
...
...
cmake/onnxruntime.cmake
查看文件 @
0f6f58d
...
...
@@ -9,6 +9,7 @@ function(download_onnxruntime)
${
PROJECT_SOURCE_DIR
}
/onnxruntime-linux-x64-1.14.0.tgz
${
PROJECT_BINARY_DIR
}
/onnxruntime-linux-x64-1.14.0.tgz
/tmp/onnxruntime-linux-x64-1.14.0.tgz
/star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz
)
set
(
onnxruntime_URL
"https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz"
)
...
...
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
0f6f58d
include_directories
(
${
CMAKE_SOURCE_DIR
}
)
add_executable
(
sherpa-onnx
decode.cc
features.cc
online-lstm-transducer-model.cc
online-transducer-greedy-search-decoder.cc
online-transducer-model-config.cc
online-transducer-model.cc
onnx-utils.cc
...
...
sherpa-onnx/csrc/decode.h
已删除
100644 → 0
查看文件 @
710edaa
// sherpa/csrc/decode.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_DECODE_H_
#define SHERPA_ONNX_CSRC_DECODE_H_
#include <vector>
#include "sherpa-onnx/csrc/online-transducer-model.h"
namespace
sherpa_onnx
{
/** Greedy search for non-streaming ASR.
*
* @TODO(fangjun) Support batch size > 1
*
* @param model The RnntModel
* @param encoder_out Its shape is (1, num_frames, encoder_out_dim).
*/
void
GreedySearch
(
OnlineTransducerModel
*
model
,
Ort
::
Value
encoder_out
,
std
::
vector
<
int64_t
>
*
hyp
);
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_DECODE_H_
sherpa-onnx/csrc/features.cc
查看文件 @
0f6f58d
...
...
@@ -15,16 +15,16 @@ namespace sherpa_onnx {
class
FeatureExtractor
::
Impl
{
public
:
Impl
(
int32_t
sampling_rate
,
int32_t
feature_dim
)
{
explicit
Impl
(
const
FeatureExtractorConfig
&
config
)
{
opts_
.
frame_opts
.
dither
=
0
;
opts_
.
frame_opts
.
snip_edges
=
false
;
opts_
.
frame_opts
.
samp_freq
=
sampling_rate
;
opts_
.
frame_opts
.
samp_freq
=
config
.
sampling_rate
;
// cache 100 seconds of feature frames, which is more than enough
// for real needs
opts_
.
frame_opts
.
max_feature_vectors
=
100
*
100
;
opts_
.
mel_opts
.
num_bins
=
feature_dim
;
opts_
.
mel_opts
.
num_bins
=
config
.
feature_dim
;
fbank_
=
std
::
make_unique
<
knf
::
OnlineFbank
>
(
opts_
);
}
...
...
@@ -80,9 +80,8 @@ class FeatureExtractor::Impl {
mutable
std
::
mutex
mutex_
;
};
FeatureExtractor
::
FeatureExtractor
(
int32_t
sampling_rate
/*=16000*/
,
int32_t
feature_dim
/*=80*/
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
sampling_rate
,
feature_dim
))
{}
FeatureExtractor
::
FeatureExtractor
(
const
FeatureExtractorConfig
&
config
/*={}*/
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
config
))
{}
FeatureExtractor
::~
FeatureExtractor
()
=
default
;
...
...
sherpa-onnx/csrc/features.h
查看文件 @
0f6f58d
...
...
@@ -10,14 +10,18 @@
namespace
sherpa_onnx
{
struct
FeatureExtractorConfig
{
int32_t
sampling_rate
=
16000
;
int32_t
feature_dim
=
80
;
};
class
FeatureExtractor
{
public
:
/**
* @param sampling_rate Sampling rate of the data used to train the model.
* @param feature_dim Dimension of the features used to train the model.
*/
explicit
FeatureExtractor
(
int32_t
sampling_rate
=
16000
,
int32_t
feature_dim
=
80
);
explicit
FeatureExtractor
(
const
FeatureExtractorConfig
&
config
=
{});
~
FeatureExtractor
();
/**
...
...
sherpa-onnx/csrc/online-transducer-decoder.h
0 → 100644
查看文件 @
0f6f58d
// sherpa/csrc/online-transducer-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace
sherpa_onnx
{
struct
OnlineTransducerDecoderResult
{
/// The decoded token IDs so far
std
::
vector
<
int64_t
>
tokens
;
};
class
OnlineTransducerDecoder
{
public
:
virtual
~
OnlineTransducerDecoder
()
=
default
;
/* Return an empty result.
*
* To simplify the decoding code, we add `context_size` blanks
* to the beginning of the decoding result, which will be
* stripped by calling `StripPrecedingBlanks()`.
*/
virtual
OnlineTransducerDecoderResult
GetEmptyResult
()
=
0
;
/** Strip blanks added by `GetEmptyResult()`.
*
* @param r It is changed in-place.
*/
virtual
void
StripLeadingBlanks
(
OnlineTransducerDecoderResult
*
/*r*/
)
{}
/** Run transducer beam search given the output from the encoder model.
*
* @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
* @param result It is modified in-place.
*
* @note There is no need to pass encoder_out_length here since for the
* online decoding case, each utterance has the same number of frames
* and there are no paddings.
*/
virtual
void
Decode
(
Ort
::
Value
encoder_out
,
std
::
vector
<
OnlineTransducerDecoderResult
>
*
result
)
=
0
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_
...
...
sherpa-onnx/csrc/
decode
.cc → sherpa-onnx/csrc/
online-transducer-greedy-search-decoder
.cc
查看文件 @
0f6f58d
// sherpa/csrc/
decode
.cc
// sherpa/csrc/
online-transducer-greedy-search-decoder
.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/
decode
.h"
#include "sherpa-onnx/csrc/
online-transducer-greedy-search-decoder
.h"
#include <assert.h>
...
...
@@ -10,19 +10,9 @@
#include <utility>
#include <vector>
namespace
sherpa_onnx
{
static
Ort
::
Value
Clone
(
Ort
::
Value
*
v
)
{
auto
type_and_shape
=
v
->
GetTensorTypeAndShapeInfo
();
std
::
vector
<
int64_t
>
shape
=
type_and_shape
.
GetShape
();
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
#include "sherpa-onnx/csrc/onnx-utils.h"
return
Ort
::
Value
::
CreateTensor
(
memory_info
,
v
->
GetTensorMutableData
<
float
>
(),
type_and_shape
.
GetElementCount
(),
shape
.
data
(),
shape
.
size
());
}
namespace
sherpa_onnx
{
static
Ort
::
Value
GetFrame
(
Ort
::
Value
*
encoder_out
,
int32_t
t
)
{
std
::
vector
<
int64_t
>
encoder_out_shape
=
...
...
@@ -42,26 +32,58 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
encoder_out_dim
,
shape
.
data
(),
shape
.
size
());
}
void
GreedySearch
(
OnlineTransducerModel
*
model
,
Ort
::
Value
encoder_out
,
std
::
vector
<
int64_t
>
*
hyp
)
{
OnlineTransducerDecoderResult
OnlineTransducerGreedySearchDecoder
::
GetEmptyResult
()
{
int32_t
context_size
=
model_
->
ContextSize
();
int32_t
blank_id
=
0
;
// always 0
OnlineTransducerDecoderResult
r
;
r
.
tokens
.
resize
(
context_size
,
blank_id
);
return
r
;
}
void
OnlineTransducerGreedySearchDecoder
::
StripLeadingBlanks
(
OnlineTransducerDecoderResult
*
r
)
{
int32_t
context_size
=
model_
->
ContextSize
();
auto
start
=
r
->
tokens
.
begin
()
+
context_size
;
auto
end
=
r
->
tokens
.
end
();
r
->
tokens
=
std
::
vector
<
int64_t
>
(
start
,
end
);
}
void
OnlineTransducerGreedySearchDecoder
::
Decode
(
Ort
::
Value
encoder_out
,
std
::
vector
<
OnlineTransducerDecoderResult
>
*
result
)
{
std
::
vector
<
int64_t
>
encoder_out_shape
=
encoder_out
.
GetTensorTypeAndShapeInfo
().
GetShape
();
if
(
encoder_out_shape
[
0
]
>
1
)
{
fprintf
(
stderr
,
"Only batch_size=1 is implemented. Given: %d
\n
"
,
static_cast
<
int32_t
>
(
encoder_out_shape
[
0
]));
if
(
encoder_out_shape
[
0
]
!=
result
->
size
())
{
fprintf
(
stderr
,
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d
\n
"
,
static_cast
<
int32_t
>
(
encoder_out_shape
[
0
]),
static_cast
<
int32_t
>
(
result
->
size
()));
exit
(
-
1
);
}
if
(
result
->
size
()
!=
1
)
{
fprintf
(
stderr
,
"only batch size == 1 is implemented. Given: %d"
,
static_cast
<
int32_t
>
(
result
->
size
()));
exit
(
-
1
);
}
auto
&
hyp
=
(
*
result
)[
0
].
tokens
;
int32_t
num_frames
=
encoder_out_shape
[
1
];
int32_t
vocab_size
=
model
->
VocabSize
();
int32_t
vocab_size
=
model
_
->
VocabSize
();
Ort
::
Value
decoder_input
=
model
->
BuildDecoderInput
(
*
hyp
);
Ort
::
Value
decoder_out
=
model
->
RunDecoder
(
std
::
move
(
decoder_input
));
Ort
::
Value
decoder_input
=
model_
->
BuildDecoderInput
(
hyp
);
Ort
::
Value
decoder_out
=
model_
->
RunDecoder
(
std
::
move
(
decoder_input
));
for
(
int32_t
t
=
0
;
t
!=
num_frames
;
++
t
)
{
Ort
::
Value
cur_encoder_out
=
GetFrame
(
&
encoder_out
,
t
);
Ort
::
Value
logit
=
model
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
Clone
(
&
decoder_out
));
model
_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
Clone
(
&
decoder_out
));
const
float
*
p_logit
=
logit
.
GetTensorData
<
float
>
();
auto
y
=
static_cast
<
int32_t
>
(
std
::
distance
(
...
...
@@ -69,9 +91,9 @@ void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
std
::
max_element
(
static_cast
<
const
float
*>
(
p_logit
),
static_cast
<
const
float
*>
(
p_logit
)
+
vocab_size
)));
if
(
y
!=
0
)
{
hyp
->
push_back
(
y
);
decoder_input
=
model
->
BuildDecoderInput
(
*
hyp
);
decoder_out
=
model
->
RunDecoder
(
std
::
move
(
decoder_input
));
hyp
.
push_back
(
y
);
decoder_input
=
model_
->
BuildDecoderInput
(
hyp
);
decoder_out
=
model_
->
RunDecoder
(
std
::
move
(
decoder_input
));
}
}
}
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
0 → 100644
查看文件 @
0f6f58d
// sherpa/csrc/online-transducer-greedy-search-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
namespace
sherpa_onnx
{
class
OnlineTransducerGreedySearchDecoder
:
public
OnlineTransducerDecoder
{
public
:
explicit
OnlineTransducerGreedySearchDecoder
(
OnlineTransducerModel
*
model
)
:
model_
(
model
)
{}
OnlineTransducerDecoderResult
GetEmptyResult
()
override
;
void
StripLeadingBlanks
(
OnlineTransducerDecoderResult
*
r
)
override
;
void
Decode
(
Ort
::
Value
encoder_out
,
std
::
vector
<
OnlineTransducerDecoderResult
>
*
result
)
override
;
private
:
OnlineTransducerModel
*
model_
;
// Not owned
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
...
...
sherpa-onnx/csrc/onnx-utils.cc
查看文件 @
0f6f58d
...
...
@@ -46,4 +46,16 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
}
}
Ort
::
Value
Clone
(
Ort
::
Value
*
v
)
{
auto
type_and_shape
=
v
->
GetTensorTypeAndShapeInfo
();
std
::
vector
<
int64_t
>
shape
=
type_and_shape
.
GetShape
();
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
return
Ort
::
Value
::
CreateTensor
(
memory_info
,
v
->
GetTensorMutableData
<
float
>
(),
type_and_shape
.
GetElementCount
(),
shape
.
data
(),
shape
.
size
());
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/onnx-utils.h
查看文件 @
0f6f58d
...
...
@@ -55,6 +55,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
void
PrintModelMetadata
(
std
::
ostream
&
os
,
const
Ort
::
ModelMetadata
&
meta_data
);
// NOLINT
// Return a shallow copy of v
Ort
::
Value
Clone
(
Ort
::
Value
*
v
);
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
...
...
sherpa-onnx/csrc/sherpa-onnx.cc
查看文件 @
0f6f58d
...
...
@@ -9,8 +9,8 @@
#include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "sherpa-onnx/csrc/decode.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
...
...
@@ -64,8 +64,6 @@ for a list of pre-trained models to download.
std
::
vector
<
Ort
::
Value
>
states
=
model
->
GetEncoderInitStates
();
std
::
vector
<
int64_t
>
hyp
(
model
->
ContextSize
(),
0
);
int32_t
expected_sampling_rate
=
16000
;
bool
is_ok
=
false
;
...
...
@@ -100,6 +98,10 @@ for a list of pre-trained models to download.
std
::
array
<
int64_t
,
3
>
x_shape
{
1
,
chunk_size
,
feature_dim
};
sherpa_onnx
::
OnlineTransducerGreedySearchDecoder
decoder
(
model
.
get
());
std
::
vector
<
sherpa_onnx
::
OnlineTransducerDecoderResult
>
result
=
{
decoder
.
GetEmptyResult
()};
for
(
int32_t
start
=
0
;
start
+
chunk_size
<
num_frames
;
start
+=
chunk_shift
)
{
std
::
vector
<
float
>
features
=
feat_extractor
.
GetFrames
(
start
,
chunk_size
);
...
...
@@ -109,8 +111,10 @@ for a list of pre-trained models to download.
x_shape
.
data
(),
x_shape
.
size
());
auto
pair
=
model
->
RunEncoder
(
std
::
move
(
x
),
states
);
states
=
std
::
move
(
pair
.
second
);
sherpa_onnx
::
GreedySearch
(
model
.
get
(),
std
::
move
(
pair
.
first
),
&
hyp
);
decoder
.
Decode
(
std
::
move
(
pair
.
first
),
&
result
);
}
decoder
.
StripLeadingBlanks
(
&
result
[
0
]);
const
auto
&
hyp
=
result
[
0
].
tokens
;
std
::
string
text
;
for
(
size_t
i
=
model
->
ContextSize
();
i
!=
hyp
.
size
();
++
i
)
{
text
+=
sym
[
hyp
[
i
]];
...
...
请
注册
或
登录
后发表评论