Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-02-19 11:42:15 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-02-19 11:42:15 +0800
Commit
d4b0c0590ad7f213c7ff1915789dd1f0d81e8d8a
d4b0c059
1 parent
0f6f58d1
Add online stream. (#28)
隐藏空白字符变更
内嵌
并排对比
正在显示
5 个修改的文件
包含
191 行增加
和
29 行删除
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/features.h
sherpa-onnx/csrc/online-stream.cc
sherpa-onnx/csrc/online-stream.h
sherpa-onnx/csrc/sherpa-onnx.cc
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
d4b0c05
...
...
@@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR})
add_executable
(
sherpa-onnx
features.cc
online-lstm-transducer-model.cc
online-stream.cc
online-transducer-greedy-search-decoder.cc
online-transducer-model-config.cc
online-transducer-model.cc
...
...
sherpa-onnx/csrc/features.h
查看文件 @
d4b0c05
...
...
@@ -11,16 +11,12 @@
namespace
sherpa_onnx
{
struct
FeatureExtractorConfig
{
int32_
t
sampling_rate
=
16000
;
floa
t
sampling_rate
=
16000
;
int32_t
feature_dim
=
80
;
};
class
FeatureExtractor
{
public
:
/**
* @param sampling_rate Sampling rate of the data used to train the model.
* @param feature_dim Dimension of the features used to train the model.
*/
explicit
FeatureExtractor
(
const
FeatureExtractorConfig
&
config
=
{});
~
FeatureExtractor
();
...
...
@@ -32,16 +28,19 @@ class FeatureExtractor {
*/
void
AcceptWaveform
(
float
sampling_rate
,
const
float
*
waveform
,
int32_t
n
);
// InputFinished() tells the class you won't be providing any
// more waveform. This will help flush out the last frame or two
// of features, in the case where snip-edges == false; it also
// affects the return value of IsLastFrame().
/**
* InputFinished() tells the class you won't be providing any
* more waveform. This will help flush out the last frame or two
* of features, in the case where snip-edges == false; it also
* affects the return value of IsLastFrame().
*/
void
InputFinished
();
int32_t
NumFramesReady
()
const
;
// Note: IsLastFrame() will only ever return true if you have called
// InputFinished() (and this frame is the last frame).
/** Note: IsLastFrame() will only ever return true if you have called
* InputFinished() (and this frame is the last frame).
*/
bool
IsLastFrame
(
int32_t
frame
)
const
;
/** Get n frames starting from the given frame index.
...
...
sherpa-onnx/csrc/online-stream.cc
0 → 100644
查看文件 @
d4b0c05
// sherpa-onnx/csrc/online-stream.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-stream.h"
#include <memory>
#include <vector>
#include "sherpa-onnx/csrc/features.h"
namespace
sherpa_onnx
{
// Private implementation of OnlineStream (pimpl).
//
// It owns a FeatureExtractor, to which every waveform/feature call is
// forwarded unchanged, plus two pieces of per-stream decoding state: the
// number of frames processed so far and the latest decoder result.
class OnlineStream::Impl {
 public:
  // @param config Passed as-is to the FeatureExtractor constructor.
  explicit Impl(const FeatureExtractorConfig &config)
      : feat_extractor_(config) {}

  // ---- Calls forwarded to the feature extractor ----

  // @param sampling_rate Sampling rate of the samples in `waveform`.
  // @param waveform Pointer to `n` samples.
  // @param n Number of entries in `waveform`.
  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
    feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
  }

  void InputFinished() { feat_extractor_.InputFinished(); }

  int32_t NumFramesReady() const {
    const int32_t num_ready = feat_extractor_.NumFramesReady();
    return num_ready;
  }

  bool IsLastFrame(int32_t frame) const {
    const bool is_last = feat_extractor_.IsLastFrame(frame);
    return is_last;
  }

  // Returns whatever FeatureExtractor::GetFrames(frame_index, n) returns.
  std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
    return feat_extractor_.GetFrames(frame_index, n);
  }

  int32_t FeatureDim() const {
    const int32_t dim = feat_extractor_.FeatureDim();
    return dim;
  }

  // Only the feature extractor is reset here; num_processed_frames_ and
  // result_ keep their current values.
  void Reset() { feat_extractor_.Reset(); }

  // ---- Per-stream decoding state ----

  // Returns a mutable reference so that the caller can advance the counter
  // in place. The reference points at a member of this object.
  int32_t &GetNumProcessedFrames() { return num_processed_frames_; }

  // Stores a copy of `r`, replacing the previously stored result.
  void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }

  // Returns the stored result by const reference (no copy is made).
  const OnlineTransducerDecoderResult &GetResult() const { return result_; }

 private:
  FeatureExtractor feat_extractor_;
  // Number of frames processed so far, counted before subsampling.
  int32_t num_processed_frames_ = 0;
  OnlineTransducerDecoderResult result_;
};
// Creates the private implementation object and hands it `config`.
// The default argument of `config` is declared in online-stream.h; it is
// repeated below as a comment only.
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
    : impl_(std::make_unique<Impl>(config)) {}
// Defaulted in this translation unit, after the definition of Impl. The
// header only forward-declares Impl, and std::unique_ptr<Impl> needs the
// complete type wherever its destructor is instantiated.
OnlineStream::~OnlineStream() = default;
// Passes the samples to the implementation object without modification.
//
// @param sampling_rate Sampling rate of the samples in `waveform`.
// @param waveform Pointer to `n` samples.
// @param n Number of entries in `waveform`.
void OnlineStream::AcceptWaveform(float sampling_rate, const float *waveform,
                                  int32_t n) {
  impl_->AcceptWaveform(sampling_rate, waveform, n);
}
// Forwards to Impl::InputFinished().
void OnlineStream::InputFinished() { impl_->InputFinished(); }
// Forwards to Impl::NumFramesReady() and returns its value unchanged.
int32_t OnlineStream::NumFramesReady() const {
  const int32_t num_ready = impl_->NumFramesReady();
  return num_ready;
}
// Forwards to Impl::IsLastFrame() and returns its value unchanged.
//
// @param frame Frame index to query.
bool OnlineStream::IsLastFrame(int32_t frame) const {
  const bool is_last = impl_->IsLastFrame(frame);
  return is_last;
}
// Forwards to Impl::GetFrames() and returns the resulting vector.
//
// @param frame_index The starting frame index.
// @param n Number of frames to get.
std::vector<float> OnlineStream::GetFrames(int32_t frame_index,
                                           int32_t n) const {
  return impl_->GetFrames(frame_index, n);
}
// Forwards to Impl::Reset().
void OnlineStream::Reset() { impl_->Reset(); }
// Forwards to Impl::FeatureDim() and returns its value unchanged.
int32_t OnlineStream::FeatureDim() const {
  const int32_t dim = impl_->FeatureDim();
  return dim;
}
// Returns the mutable reference obtained from Impl::GetNumProcessedFrames(),
// so that callers can update the counter in place.
int32_t &OnlineStream::GetNumProcessedFrames() {
  int32_t &num_processed = impl_->GetNumProcessedFrames();
  return num_processed;
}
// Forwards `r` to Impl::SetResult().
void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
  impl_->SetResult(r);
}
// Returns the const reference obtained from Impl::GetResult(); no copy of
// the result is made here.
const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
  const OnlineTransducerDecoderResult &stored = impl_->GetResult();
  return stored;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-stream.h
0 → 100644
查看文件 @
d4b0c05
// sherpa-onnx/csrc/online-stream.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_STREAM_H_
#define SHERPA_ONNX_CSRC_ONLINE_STREAM_H_
#include <memory>
#include <vector>
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
namespace
sherpa_onnx
{
/** A single stream of audio together with its decoding state.
 *
 * The class hides its data members behind a forward-declared Impl type
 * that is held through a std::unique_ptr.
 */
class OnlineStream {
 public:
  /**
   * @param config Configuration for the feature extractor used by this
   *               stream. Defaults to a default-constructed
   *               FeatureExtractorConfig.
   */
  explicit OnlineStream(const FeatureExtractorConfig &config = {});
  ~OnlineStream();

  /**
   * @param sampling_rate The sampling rate of the input waveform. Should
   *                      match the one expected by the feature extractor.
   * @param waveform Pointer to a 1-D array of size n.
   * @param n Number of entries in waveform.
   */
  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);

  /**
   * InputFinished() tells the class you won't be providing any more
   * waveform. This will help flush out the last frame or two of features,
   * in the case where snip-edges == false; it also affects the return value
   * of IsLastFrame().
   */
  void InputFinished();

  int32_t NumFramesReady() const;

  /** Note: IsLastFrame() will only ever return true if you have called
   * InputFinished() (and this frame is the last frame).
   */
  bool IsLastFrame(int32_t frame) const;

  /** Get n frames starting from the given frame index.
   *
   * @param frame_index The starting frame index.
   * @param n Number of frames to get.
   * @return Return a 2-D tensor of shape (n, feature_dim), flattened into a
   *         1-D vector in row-major order.
   */
  std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;

  void Reset();

  int32_t FeatureDim() const;

  // Return a reference to the number of processed frames so far.
  // Initially, it is 0. It is meant to stay below NumFramesReady().
  // NOTE(review): nothing in this class enforces that bound; the caller
  // updates the counter through the returned reference.
  //
  // The returned reference is valid as long as this object is alive.
  int32_t &GetNumProcessedFrames();

  // Replace the decoding result stored in this stream with `r`.
  void SetResult(const OnlineTransducerDecoderResult &r);

  // Return the decoding result stored in this stream.
  const OnlineTransducerDecoderResult &GetResult() const;

 private:
  class Impl;  // defined in online-stream.cc
  std::unique_ptr<Impl> impl_;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_STREAM_H_
...
...
sherpa-onnx/csrc/sherpa-onnx.cc
查看文件 @
d4b0c05
...
...
@@ -8,8 +8,7 @@
#include <string>
#include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
...
...
@@ -64,7 +63,7 @@ for a list of pre-trained models to download.
std
::
vector
<
Ort
::
Value
>
states
=
model
->
GetEncoderInitStates
();
int32_
t
expected_sampling_rate
=
16000
;
floa
t
expected_sampling_rate
=
16000
;
bool
is_ok
=
false
;
std
::
vector
<
float
>
samples
=
...
...
@@ -75,7 +74,7 @@ for a list of pre-trained models to download.
return
-
1
;
}
float
duration
=
samples
.
size
()
/
static_cast
<
float
>
(
expected_sampling_rate
)
;
float
duration
=
samples
.
size
()
/
expected_sampling_rate
;
fprintf
(
stderr
,
"wav filename: %s
\n
"
,
wav_filename
.
c_str
());
fprintf
(
stderr
,
"wav duration (s): %.3f
\n
"
,
duration
);
...
...
@@ -83,32 +82,33 @@ for a list of pre-trained models to download.
auto
begin
=
std
::
chrono
::
steady_clock
::
now
();
fprintf
(
stderr
,
"Started
\n
"
);
sherpa_onnx
::
FeatureExtractor
feat_extractor
;
feat_extractor
.
AcceptWaveform
(
expected_sampling_rate
,
samples
.
data
(),
samples
.
size
());
sherpa_onnx
::
OnlineStream
stream
;
stream
.
AcceptWaveform
(
expected_sampling_rate
,
samples
.
data
(),
samples
.
size
());
std
::
vector
<
float
>
tail_paddings
(
static_cast
<
int
>
(
0.2
*
expected_sampling_rate
));
feat_extractor
.
AcceptWaveform
(
expected_sampling_rate
,
tail_paddings
.
data
(),
tail_paddings
.
size
());
feat_extractor
.
InputFinished
();
stream
.
AcceptWaveform
(
expected_sampling_rate
,
tail_paddings
.
data
(),
tail_paddings
.
size
());
stream
.
InputFinished
();
int32_t
num_frames
=
feat_extractor
.
NumFramesReady
();
int32_t
feature_dim
=
feat_extractor
.
FeatureDim
();
int32_t
num_frames
=
stream
.
NumFramesReady
();
int32_t
feature_dim
=
stream
.
FeatureDim
();
std
::
array
<
int64_t
,
3
>
x_shape
{
1
,
chunk_size
,
feature_dim
};
sherpa_onnx
::
OnlineTransducerGreedySearchDecoder
decoder
(
model
.
get
());
std
::
vector
<
sherpa_onnx
::
OnlineTransducerDecoderResult
>
result
=
{
decoder
.
GetEmptyResult
()};
for
(
int32_t
start
=
0
;
start
+
chunk_size
<
num_frames
;
start
+=
chunk_shift
)
{
std
::
vector
<
float
>
features
=
feat_extractor
.
GetFrames
(
start
,
chunk_size
);
while
(
stream
.
NumFramesReady
()
-
stream
.
GetNumProcessedFrames
()
>
chunk_size
)
{
std
::
vector
<
float
>
features
=
stream
.
GetFrames
(
stream
.
GetNumProcessedFrames
(),
chunk_size
);
stream
.
GetNumProcessedFrames
()
+=
chunk_shift
;
Ort
::
Value
x
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
features
.
data
(),
features
.
size
(),
x_shape
.
data
(),
x_shape
.
size
());
auto
pair
=
model
->
RunEncoder
(
std
::
move
(
x
),
states
);
states
=
std
::
move
(
pair
.
second
);
decoder
.
Decode
(
std
::
move
(
pair
.
first
),
&
result
);
...
...
@@ -116,8 +116,8 @@ for a list of pre-trained models to download.
decoder
.
StripLeadingBlanks
(
&
result
[
0
]);
const
auto
&
hyp
=
result
[
0
].
tokens
;
std
::
string
text
;
for
(
size_t
i
=
model
->
ContextSize
();
i
!=
hyp
.
size
();
++
i
)
{
text
+=
sym
[
hyp
[
i
]];
for
(
auto
t
:
hyp
)
{
text
+=
sym
[
t
];
}
fprintf
(
stderr
,
"Done!
\n
"
);
...
...
请
注册
或
登录
后发表评论