Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-02-19 09:57:56 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-02-19 09:57:56 +0800
Commit
710edaa6f9d22c04cf9d1372df0ed25ccd8554a6
710edaa6
1 parent
cb8f85ff
Refactor feature extractor (#26)
显示空白字符变更
内嵌
并排对比
正在显示
3 个修改的文件
包含
82 行增加
和
44 行删除
sherpa-onnx/csrc/features.cc
sherpa-onnx/csrc/features.h
sherpa-onnx/csrc/sherpa-onnx.cc
sherpa-onnx/csrc/features.cc
查看文件 @
710edaa
...
...
@@ -6,52 +6,50 @@
#include <algorithm>
#include <memory>
#include <mutex> // NOLINT
#include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h"
namespace
sherpa_onnx
{
FeatureExtractor
::
FeatureExtractor
()
{
class
FeatureExtractor
::
Impl
{
public
:
Impl
(
int32_t
sampling_rate
,
int32_t
feature_dim
)
{
opts_
.
frame_opts
.
dither
=
0
;
opts_
.
frame_opts
.
snip_edges
=
false
;
opts_
.
frame_opts
.
samp_freq
=
16000
;
opts_
.
frame_opts
.
samp_freq
=
sampling_rate
;
// cache 100 seconds of feature frames, which is more than enough
// for real needs
opts_
.
frame_opts
.
max_feature_vectors
=
100
*
100
;
opts_
.
mel_opts
.
num_bins
=
80
;
// feature dim
fbank_
=
std
::
make_unique
<
knf
::
OnlineFbank
>
(
opts_
);
}
opts_
.
mel_opts
.
num_bins
=
feature_dim
;
FeatureExtractor
::
FeatureExtractor
(
const
knf
::
FbankOptions
&
opts
)
:
opts_
(
opts
)
{
fbank_
=
std
::
make_unique
<
knf
::
OnlineFbank
>
(
opts_
);
}
}
void
FeatureExtractor
::
AcceptWaveform
(
float
sampling_rate
,
const
float
*
waveform
,
int32_t
n
)
{
void
AcceptWaveform
(
float
sampling_rate
,
const
float
*
waveform
,
int32_t
n
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
fbank_
->
AcceptWaveform
(
sampling_rate
,
waveform
,
n
);
}
}
void
FeatureExtractor
::
InputFinished
()
{
void
InputFinished
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
fbank_
->
InputFinished
();
}
}
int32_t
FeatureExtractor
::
NumFramesReady
()
const
{
int32_t
NumFramesReady
()
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
fbank_
->
NumFramesReady
();
}
}
bool
FeatureExtractor
::
IsLastFrame
(
int32_t
frame
)
const
{
bool
IsLastFrame
(
int32_t
frame
)
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
fbank_
->
IsLastFrame
(
frame
);
}
}
std
::
vector
<
float
>
FeatureExtractor
::
GetFrames
(
int32_t
frame_index
,
int32_t
n
)
const
{
std
::
vector
<
float
>
GetFrames
(
int32_t
frame_index
,
int32_t
n
)
const
{
if
(
frame_index
+
n
>
NumFramesReady
())
{
fprintf
(
stderr
,
"%d + %d > %d
\n
"
,
frame_index
,
n
,
NumFramesReady
());
exit
(
-
1
);
...
...
@@ -70,10 +68,46 @@ std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
}
return
features
;
}
void
Reset
()
{
fbank_
=
std
::
make_unique
<
knf
::
OnlineFbank
>
(
opts_
);
}
int32_t
FeatureDim
()
const
{
return
opts_
.
mel_opts
.
num_bins
;
}
private
:
std
::
unique_ptr
<
knf
::
OnlineFbank
>
fbank_
;
knf
::
FbankOptions
opts_
;
mutable
std
::
mutex
mutex_
;
};
FeatureExtractor
::
FeatureExtractor
(
int32_t
sampling_rate
/*=16000*/
,
int32_t
feature_dim
/*=80*/
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
sampling_rate
,
feature_dim
))
{}
FeatureExtractor
::~
FeatureExtractor
()
=
default
;
void
FeatureExtractor
::
AcceptWaveform
(
float
sampling_rate
,
const
float
*
waveform
,
int32_t
n
)
{
impl_
->
AcceptWaveform
(
sampling_rate
,
waveform
,
n
);
}
void
FeatureExtractor
::
Reset
()
{
fbank_
=
std
::
make_unique
<
knf
::
OnlineFbank
>
(
opts_
);
void
FeatureExtractor
::
InputFinished
()
{
impl_
->
InputFinished
();
}
int32_t
FeatureExtractor
::
NumFramesReady
()
const
{
return
impl_
->
NumFramesReady
();
}
bool
FeatureExtractor
::
IsLastFrame
(
int32_t
frame
)
const
{
return
impl_
->
IsLastFrame
(
frame
);
}
std
::
vector
<
float
>
FeatureExtractor
::
GetFrames
(
int32_t
frame_index
,
int32_t
n
)
const
{
return
impl_
->
GetFrames
(
frame_index
,
n
);
}
void
FeatureExtractor
::
Reset
()
{
impl_
->
Reset
();
}
int32_t
FeatureExtractor
::
FeatureDim
()
const
{
return
impl_
->
FeatureDim
();
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/features.h
查看文件 @
710edaa
...
...
@@ -6,17 +6,19 @@
#define SHERPA_ONNX_CSRC_FEATURES_H_
#include <memory>
#include <mutex> // NOLINT
#include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h"
namespace
sherpa_onnx
{
class
FeatureExtractor
{
public
:
FeatureExtractor
();
explicit
FeatureExtractor
(
const
knf
::
FbankOptions
&
fbank_opts
);
/**
* @param sampling_rate Sampling rate of the data used to train the model.
* @param feature_dim Dimension of the features used to train the model.
*/
explicit
FeatureExtractor
(
int32_t
sampling_rate
=
16000
,
int32_t
feature_dim
=
80
);
~
FeatureExtractor
();
/**
@param sampling_rate The sampling_rate of the input waveform. Should match
...
...
@@ -48,12 +50,13 @@ class FeatureExtractor {
std
::
vector
<
float
>
GetFrames
(
int32_t
frame_index
,
int32_t
n
)
const
;
void
Reset
();
int32_t
FeatureDim
()
const
{
return
opts_
.
mel_opts
.
num_bins
;
}
/// Return feature dim of this extractor
int32_t
FeatureDim
()
const
;
private
:
std
::
unique_ptr
<
knf
::
OnlineFbank
>
fbank_
;
knf
::
FbankOptions
opts_
;
mutable
std
::
mutex
mutex_
;
class
Impl
;
std
::
unique_ptr
<
Impl
>
impl_
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/sherpa-onnx.cc
查看文件 @
710edaa
...
...
@@ -2,8 +2,9 @@
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include <stdio.h>
#include <chrono> // NOLINT
#include <iostream>
#include <string>
#include <vector>
...
...
@@ -30,14 +31,14 @@ Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download.
)usage"
;
std
::
cerr
<<
usage
<<
"
\n
"
;
fprintf
(
stderr
,
"%s
\n
"
,
usage
)
;
return
0
;
}
std
::
string
tokens
=
argv
[
1
];
sherpa_onnx
::
OnlineTransducerModelConfig
config
;
config
.
debug
=
tru
e
;
config
.
debug
=
fals
e
;
config
.
encoder_filename
=
argv
[
2
];
config
.
decoder_filename
=
argv
[
3
];
config
.
joiner_filename
=
argv
[
4
];
...
...
@@ -47,7 +48,7 @@ for a list of pre-trained models to download.
if
(
argc
==
7
)
{
config
.
num_threads
=
atoi
(
argv
[
6
]);
}
std
::
cout
<<
config
.
ToString
().
c_str
()
<<
"
\n
"
;
fprintf
(
stderr
,
"%s
\n
"
,
config
.
ToString
().
c_str
())
;
auto
model
=
sherpa_onnx
::
OnlineTransducerModel
::
Create
(
config
);
...
...
@@ -72,17 +73,17 @@ for a list of pre-trained models to download.
sherpa_onnx
::
ReadWave
(
wav_filename
,
expected_sampling_rate
,
&
is_ok
);
if
(
!
is_ok
)
{
std
::
cerr
<<
"Failed to read "
<<
wav_filename
<<
"
\n
"
;
fprintf
(
stderr
,
"Failed to read %s
\n
"
,
wav_filename
.
c_str
())
;
return
-
1
;
}
const
float
duration
=
samples
.
size
()
/
expected_sampling_rate
;
float
duration
=
samples
.
size
()
/
static_cast
<
float
>
(
expected_sampling_rate
)
;
std
::
cout
<<
"wav filename: "
<<
wav_filename
<<
"
\n
"
;
std
::
cout
<<
"wav duration (s): "
<<
duration
<<
"
\n
"
;
fprintf
(
stderr
,
"wav filename: %s
\n
"
,
wav_filename
.
c_str
());
fprintf
(
stderr
,
"wav duration (s): %.3f
\n
"
,
duration
);
auto
begin
=
std
::
chrono
::
steady_clock
::
now
();
std
::
cout
<<
"Started!
\n
"
;
fprintf
(
stderr
,
"Started
\n
"
)
;
sherpa_onnx
::
FeatureExtractor
feat_extractor
;
feat_extractor
.
AcceptWaveform
(
expected_sampling_rate
,
samples
.
data
(),
...
...
@@ -115,10 +116,10 @@ for a list of pre-trained models to download.
text
+=
sym
[
hyp
[
i
]];
}
std
::
cout
<<
"Done!
\n
"
;
fprintf
(
stderr
,
"Done!
\n
"
)
;
std
::
cout
<<
"Recognition result for "
<<
wav_filename
<<
"
\n
"
<<
text
<<
"
\n
"
;
fprintf
(
stderr
,
"Recognition result for %s:
\n
%s
\n
"
,
wav_filename
.
c_str
(),
text
.
c_str
());
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
float
elapsed_seconds
=
...
...
@@ -126,7 +127,7 @@ for a list of pre-trained models to download.
.
count
()
/
1000.
;
std
::
cout
<<
"num threads: "
<<
config
.
num_threads
<<
"
\n
"
;
fprintf
(
stderr
,
"num threads: %d
\n
"
,
config
.
num_threads
)
;
fprintf
(
stderr
,
"Elapsed seconds: %.3f s
\n
"
,
elapsed_seconds
);
float
rtf
=
elapsed_seconds
/
duration
;
...
...
请
注册
或
登录
后发表评论