Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-02-19 12:45:38 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-02-19 12:45:38 +0800
Commit
ebc3b47fb8a934491b5fbebb21acb90042f2315e
ebc3b47f
1 parent
d4b0c059
add online-recognizer (#29)
隐藏空白字符变更
内嵌
并排对比
正在显示
11 个修改的文件
包含
268 行增加
和
62 行删除
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/features.cc
sherpa-onnx/csrc/features.h
sherpa-onnx/csrc/online-recognizer.cc
sherpa-onnx/csrc/online-recognizer.h
sherpa-onnx/csrc/online-stream.cc
sherpa-onnx/csrc/online-stream.h
sherpa-onnx/csrc/online-transducer-decoder.h
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
sherpa-onnx/csrc/sherpa-onnx.cc
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
ebc3b47
...
...
@@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR})
add_executable
(
sherpa-onnx
features.cc
online-lstm-transducer-model.cc
online-recognizer.cc
online-stream.cc
online-transducer-greedy-search-decoder.cc
online-transducer-model-config.cc
...
...
sherpa-onnx/csrc/features.cc
查看文件 @
ebc3b47
...
...
@@ -7,12 +7,23 @@
#include <algorithm>
#include <memory>
#include <mutex> // NOLINT
#include <sstream>
#include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h"
namespace
sherpa_onnx
{
/// Render this configuration as a single-line, human-readable string of the
/// form "FeatureExtractorConfig(sampling_rate=<rate>, feature_dim=<dim>)".
std::string FeatureExtractorConfig::ToString() const {
  std::ostringstream out;
  out << "FeatureExtractorConfig("
      << "sampling_rate=" << sampling_rate << ", "
      << "feature_dim=" << feature_dim << ")";
  return out.str();
}
class
FeatureExtractor
::
Impl
{
public
:
explicit
Impl
(
const
FeatureExtractorConfig
&
config
)
{
...
...
sherpa-onnx/csrc/features.h
查看文件 @
ebc3b47
...
...
@@ -6,6 +6,7 @@
#define SHERPA_ONNX_CSRC_FEATURES_H_
#include <memory>
#include <string>
#include <vector>
namespace
sherpa_onnx
{
...
...
@@ -13,6 +14,8 @@ namespace sherpa_onnx {
/// Options for the feature extractor.
struct FeatureExtractorConfig {
  /// Sampling rate of the input audio; defaults to 16000.
  /// NOTE(review): presumably in Hz -- confirm against the extractor.
  float sampling_rate{16000.0f};

  /// Dimension of each extracted feature frame; defaults to 80.
  int32_t feature_dim{80};

  /// Return a human-readable description of this configuration.
  std::string ToString() const;
};
class
FeatureExtractor
{
...
...
sherpa-onnx/csrc/online-recognizer.cc
0 → 100644
查看文件 @
ebc3b47
// sherpa-onnx/csrc/online-recognizer.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-recognizer.h"
#include <assert.h>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
namespace
sherpa_onnx
{
/// Turn a decoder result into a recognizer result.
///
/// Every token ID in `src.tokens` is looked up in `sym_table` and the
/// resulting symbols are concatenated, in order, to form the output text.
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
                                      const SymbolTable &sym_table) {
  OnlineRecognizerResult result;
  for (const auto token_id : src.tokens) {
    result.text += sym_table[token_id];
  }
  return result;
}
/// Render this configuration as a single-line, human-readable string.
///
/// The nested feature and model configurations are rendered through their
/// own ToString() methods; the `tokens` value is wrapped in double quotes.
std::string OnlineRecognizerConfig::ToString() const {
  std::ostringstream out;
  out << "OnlineRecognizerConfig("
      << "feat_config=" << feat_config.ToString() << ", "
      << "model_config=" << model_config.ToString() << ", "
      << "tokens=\"" << tokens << "\")";
  return out.str();
}
/// Private implementation of OnlineRecognizer.
///
/// Owns the transducer model, a greedy-search decoder that borrows a raw
/// pointer to that model, and the symbol table used to map token IDs to text.
class OnlineRecognizer::Impl {
 public:
  /// Build the model from `config.model_config`, the decoder on top of the
  /// model, and the symbol table from `config.tokens`.
  explicit Impl(const OnlineRecognizerConfig &config)
      : config_(config),
        model_(OnlineTransducerModel::Create(config.model_config)),
        // The decoder only borrows the model. `model_` is declared before
        // `decoder_`, so it is constructed first and destroyed last.
        decoder_(std::make_unique<OnlineTransducerGreedySearchDecoder>(
            model_.get())),
        sym_(config.tokens) {}

  /// Create a new stream whose decoding result and encoder states are set to
  /// the initial values provided by the decoder and the model.
  std::unique_ptr<OnlineStream> CreateStream() const {
    auto stream = std::make_unique<OnlineStream>(config_.feat_config);
    stream->SetResult(decoder_->GetEmptyResult());
    stream->SetStates(model_->GetEncoderInitStates());
    return stream;
  }

  /// Return true if `s` has strictly more ready frames than the number of
  /// already processed frames plus one model chunk.
  bool IsReady(OnlineStream *s) const {
    return s->GetNumProcessedFrames() + model_->ChunkSize() <
           s->NumFramesReady();
  }

  /// Decode one chunk for each of the `n` streams in `ss`.
  ///
  /// Streams are processed one after another with a batch size of 1, so any
  /// `n` is supported instead of terminating the process when `n != 1`.
  /// Nothing is done when `n <= 0`.
  ///
  /// @param ss Pointer array containing the streams to be decoded. Every
  ///           stream must satisfy IsReady().
  /// @param n  Number of streams in `ss`.
  void DecodeStreams(OnlineStream **ss, int32_t n) {
    for (int32_t i = 0; i < n; ++i) {
      DecodeOneStream(ss[i]);
    }
  }

  /// Return the text recognized so far for stream `s`.
  ///
  /// Works on a copy of the stream's decoder result, so the result stored in
  /// the stream is left untouched.
  OnlineRecognizerResult GetResult(OnlineStream *s) const {
    OnlineTransducerDecoderResult decoder_result = s->GetResult();
    decoder_->StripLeadingBlanks(&decoder_result);

    return Convert(decoder_result, sym_);
  }

 private:
  /// Run the encoder and the decoder on the next chunk of a single stream
  /// and store the updated states and result back into the stream.
  void DecodeOneStream(OnlineStream *s) {
    assert(IsReady(s));

    const int32_t chunk_size = model_->ChunkSize();
    const int32_t chunk_shift = model_->ChunkShift();
    const int32_t feature_dim = s->FeatureDim();

    // Input tensor shape: (batch of 1, chunk_size, feature_dim).
    std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // `features` backs the tensor `x` below and therefore has to stay alive
    // until the encoder has been run.
    std::vector<float> features =
        s->GetFrames(s->GetNumProcessedFrames(), chunk_size);

    // Advance by the chunk shift, not by the chunk size.
    s->GetNumProcessedFrames() += chunk_shift;

    Ort::Value x =
        Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
                                 x_shape.data(), x_shape.size());

    // `pair.first` is the encoder output, `pair.second` the next states.
    auto pair = model_->RunEncoder(std::move(x), s->GetStates());
    s->SetStates(std::move(pair.second));

    std::vector<OnlineTransducerDecoderResult> results = {s->GetResult()};
    decoder_->Decode(std::move(pair.first), &results);
    s->SetResult(results[0]);
  }

  OnlineRecognizerConfig config_;
  std::unique_ptr<OnlineTransducerModel> model_;
  std::unique_ptr<OnlineTransducerDecoder> decoder_;
  SymbolTable sym_;
};
OnlineRecognizer
::
OnlineRecognizer
(
const
OnlineRecognizerConfig
&
config
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
config
))
{}
OnlineRecognizer
::~
OnlineRecognizer
()
=
default
;
/// Create a new stream by delegating to the implementation object.
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
  auto stream = impl_->CreateStream();
  return stream;
}
/// Report whether stream `s` is ready for decoding, as determined by the
/// implementation object.
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
  const bool ready = impl_->IsReady(s);
  return ready;
}
/// Forward the `n` streams in `ss` to the implementation object for decoding.
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) {
  Impl &impl = *impl_;
  impl.DecodeStreams(ss, n);
}
/// Fetch the recognition result of stream `s` from the implementation object.
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) {
  OnlineRecognizerResult result = impl_->GetResult(s);
  return result;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-recognizer.h
0 → 100644
查看文件 @
ebc3b47
// sherpa-onnx/csrc/online-recognizer.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
namespace
sherpa_onnx
{
/// Result returned by the recognizer for a stream.
struct OnlineRecognizerResult {
  /// Recognized text; empty by default.
  std::string text{};
};
/// Configuration for OnlineRecognizer.
struct OnlineRecognizerConfig {
  /// Options for feature extraction.
  FeatureExtractorConfig feat_config;

  /// Options for the transducer model.
  OnlineTransducerModelConfig model_config;

  /// Used to construct the symbol table.
  /// NOTE(review): presumably the path of a tokens file -- confirm.
  std::string tokens;

  /// Return a human-readable description of this configuration.
  std::string ToString() const;
};
/// Streaming recognizer. Streams are obtained from CreateStream(), fed with
/// audio by the caller, decoded chunk by chunk while IsReady() holds, and
/// queried with GetResult().
class OnlineRecognizer {
 public:
  explicit OnlineRecognizer(const OnlineRecognizerConfig &config);
  ~OnlineRecognizer();

  /// Create a stream for decoding.
  std::unique_ptr<OnlineStream> CreateStream() const;

  /**
   * Return true if the given stream has enough frames for decoding.
   * Return false otherwise.
   */
  bool IsReady(OnlineStream *s) const;

  /** Decode a single stream. */
  void DecodeStream(OnlineStream *s) {
    // `s` is a local copy of the pointer, so `&s` is a valid array of length 1.
    DecodeStreams(&s, 1);
  }

  /** Decode the given streams.
   *
   * @param ss Pointer array containing streams to be decoded.
   * @param n Number of streams in `ss`.
   */
  void DecodeStreams(OnlineStream **ss, int32_t n);

  /** Return the recognition result of the given stream. */
  OnlineRecognizerResult GetResult(OnlineStream *s);

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
...
...
sherpa-onnx/csrc/online-stream.cc
查看文件 @
ebc3b47
...
...
@@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/online-stream.h"
#include <memory>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/features.h"
...
...
@@ -41,10 +42,17 @@ class OnlineStream::Impl {
// Feature dimension, as reported by the underlying feature extractor.
int32_t FeatureDim() const {
  const int32_t dim = feat_extractor_.FeatureDim();
  return dim;
}
// Replace the stored states. `states` is taken by value and moved into
// place, so callers can hand over ownership with std::move().
void SetStates(std::vector<Ort::Value> states) {
  this->states_ = std::move(states);
}
std
::
vector
<
Ort
::
Value
>
&
GetStates
()
{
return
states_
;
}
private
:
FeatureExtractor
feat_extractor_
;
int32_t
num_processed_frames_
=
0
;
// before subsampling
OnlineTransducerDecoderResult
result_
;
std
::
vector
<
Ort
::
Value
>
states_
;
};
OnlineStream
::
OnlineStream
(
const
FeatureExtractorConfig
&
config
/*= {}*/
)
...
...
@@ -86,4 +94,12 @@ const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
return
impl_
->
GetResult
();
}
/// Hand `states` over to the implementation object. The vector is moved,
/// not copied, on its way in.
void OnlineStream::SetStates(std::vector<Ort::Value> states) {
  Impl &impl = *impl_;
  impl.SetStates(std::move(states));
}
/// Return a mutable reference to the states held by the implementation
/// object.
std::vector<Ort::Value> &OnlineStream::GetStates() {
  std::vector<Ort::Value> &states = impl_->GetStates();
  return states;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-stream.h
查看文件 @
ebc3b47
...
...
@@ -8,6 +8,7 @@
#include <memory>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
...
...
@@ -63,6 +64,9 @@ class OnlineStream {
void
SetResult
(
const
OnlineTransducerDecoderResult
&
r
);
const
OnlineTransducerDecoderResult
&
GetResult
()
const
;
void
SetStates
(
std
::
vector
<
Ort
::
Value
>
states
);
std
::
vector
<
Ort
::
Value
>
&
GetStates
();
private
:
class
Impl
;
std
::
unique_ptr
<
Impl
>
impl_
;
...
...
sherpa-onnx/csrc/online-transducer-decoder.h
查看文件 @
ebc3b47
...
...
@@ -26,13 +26,14 @@ class OnlineTransducerDecoder {
* to the beginning of the decoding result, which will be
* stripped by calling `StripLeadingBlanks()`.
*/
virtual
OnlineTransducerDecoderResult
GetEmptyResult
()
=
0
;
virtual
OnlineTransducerDecoderResult
GetEmptyResult
()
const
=
0
;
/** Strip blanks added by `GetEmptyResult()`.
*
* @param r It is changed in-place.
*/
virtual
void
StripLeadingBlanks
(
OnlineTransducerDecoderResult
*
/*r*/
)
{}
virtual
void
StripLeadingBlanks
(
OnlineTransducerDecoderResult
*
/*r*/
)
const
{
}
/** Run transducer beam search given the output from the encoder model.
*
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
查看文件 @
ebc3b47
...
...
@@ -33,7 +33,7 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
}
OnlineTransducerDecoderResult
OnlineTransducerGreedySearchDecoder
::
GetEmptyResult
()
{
OnlineTransducerGreedySearchDecoder
::
GetEmptyResult
()
const
{
int32_t
context_size
=
model_
->
ContextSize
();
int32_t
blank_id
=
0
;
// always 0
OnlineTransducerDecoderResult
r
;
...
...
@@ -43,7 +43,7 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() {
}
void
OnlineTransducerGreedySearchDecoder
::
StripLeadingBlanks
(
OnlineTransducerDecoderResult
*
r
)
{
OnlineTransducerDecoderResult
*
r
)
const
{
int32_t
context_size
=
model_
->
ContextSize
();
auto
start
=
r
->
tokens
.
begin
()
+
context_size
;
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
查看文件 @
ebc3b47
...
...
@@ -17,9 +17,9 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
explicit
OnlineTransducerGreedySearchDecoder
(
OnlineTransducerModel
*
model
)
:
model_
(
model
)
{}
OnlineTransducerDecoderResult
GetEmptyResult
()
override
;
OnlineTransducerDecoderResult
GetEmptyResult
()
const
override
;
void
StripLeadingBlanks
(
OnlineTransducerDecoderResult
*
r
)
override
;
void
StripLeadingBlanks
(
OnlineTransducerDecoderResult
*
r
)
const
override
;
void
Decode
(
Ort
::
Value
encoder_out
,
std
::
vector
<
OnlineTransducerDecoderResult
>
*
result
)
override
;
...
...
sherpa-onnx/csrc/sherpa-onnx.cc
查看文件 @
ebc3b47
...
...
@@ -8,6 +8,7 @@
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
...
...
@@ -35,35 +36,26 @@ for a list of pre-trained models to download.
return
0
;
}
std
::
string
tokens
=
argv
[
1
];
sherpa_onnx
::
OnlineTransducerModelConfig
config
;
config
.
debug
=
false
;
config
.
encoder_filename
=
argv
[
2
];
config
.
decoder_filename
=
argv
[
3
];
config
.
joiner_filename
=
argv
[
4
];
sherpa_onnx
::
OnlineRecognizerConfig
config
;
config
.
tokens
=
argv
[
1
];
config
.
model_config
.
debug
=
false
;
config
.
model_config
.
encoder_filename
=
argv
[
2
];
config
.
model_config
.
decoder_filename
=
argv
[
3
];
config
.
model_config
.
joiner_filename
=
argv
[
4
];
std
::
string
wav_filename
=
argv
[
5
];
config
.
num_threads
=
2
;
config
.
model_config
.
num_threads
=
2
;
if
(
argc
==
7
)
{
config
.
num_threads
=
atoi
(
argv
[
6
]);
config
.
model_config
.
num_threads
=
atoi
(
argv
[
6
]);
}
fprintf
(
stderr
,
"%s
\n
"
,
config
.
ToString
().
c_str
());
auto
model
=
sherpa_onnx
::
OnlineTransducerModel
::
Create
(
config
);
sherpa_onnx
::
SymbolTable
sym
(
tokens
);
Ort
::
AllocatorWithDefaultOptions
allocator
;
int32_t
chunk_size
=
model
->
ChunkSize
();
int32_t
chunk_shift
=
model
->
ChunkShift
();
sherpa_onnx
::
OnlineRecognizer
recognizer
(
config
);
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
std
::
vector
<
Ort
::
Value
>
states
=
model
->
GetEncoderInitStates
();
float
expected_sampling_rate
=
16000
;
float
expected_sampling_rate
=
config
.
feat_config
.
sampling_rate
;
bool
is_ok
=
false
;
std
::
vector
<
float
>
samples
=
...
...
@@ -82,44 +74,21 @@ for a list of pre-trained models to download.
auto
begin
=
std
::
chrono
::
steady_clock
::
now
();
fprintf
(
stderr
,
"Started
\n
"
);
sherpa_onnx
::
OnlineStream
stream
;
stream
.
AcceptWaveform
(
expected_sampling_rate
,
samples
.
data
(),
samples
.
size
());
auto
s
=
recognizer
.
CreateStream
();
s
->
AcceptWaveform
(
expected_sampling_rate
,
samples
.
data
(),
samples
.
size
());
std
::
vector
<
float
>
tail_paddings
(
static_cast
<
int
>
(
0.2
*
expected_sampling_rate
));
stream
.
AcceptWaveform
(
expected_sampling_rate
,
tail_paddings
.
data
(),
tail_paddings
.
size
());
stream
.
InputFinished
();
int32_t
num_frames
=
stream
.
NumFramesReady
();
int32_t
feature_dim
=
stream
.
FeatureDim
();
std
::
array
<
int64_t
,
3
>
x_shape
{
1
,
chunk_size
,
feature_dim
};
sherpa_onnx
::
OnlineTransducerGreedySearchDecoder
decoder
(
model
.
get
());
std
::
vector
<
sherpa_onnx
::
OnlineTransducerDecoderResult
>
result
=
{
decoder
.
GetEmptyResult
()};
while
(
stream
.
NumFramesReady
()
-
stream
.
GetNumProcessedFrames
()
>
chunk_size
)
{
std
::
vector
<
float
>
features
=
stream
.
GetFrames
(
stream
.
GetNumProcessedFrames
(),
chunk_size
);
stream
.
GetNumProcessedFrames
()
+=
chunk_shift
;
Ort
::
Value
x
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
features
.
data
(),
features
.
size
(),
x_shape
.
data
(),
x_shape
.
size
());
auto
pair
=
model
->
RunEncoder
(
std
::
move
(
x
),
states
);
states
=
std
::
move
(
pair
.
second
);
decoder
.
Decode
(
std
::
move
(
pair
.
first
),
&
result
);
}
decoder
.
StripLeadingBlanks
(
&
result
[
0
]);
const
auto
&
hyp
=
result
[
0
].
tokens
;
std
::
string
text
;
for
(
auto
t
:
hyp
)
{
text
+=
sym
[
t
];
s
->
AcceptWaveform
(
expected_sampling_rate
,
tail_paddings
.
data
(),
tail_paddings
.
size
());
s
->
InputFinished
();
while
(
recognizer
.
IsReady
(
s
.
get
()))
{
recognizer
.
DecodeStream
(
s
.
get
());
}
std
::
string
text
=
recognizer
.
GetResult
(
s
.
get
()).
text
;
fprintf
(
stderr
,
"Done!
\n
"
);
fprintf
(
stderr
,
"Recognition result for %s:
\n
%s
\n
"
,
wav_filename
.
c_str
(),
...
...
@@ -131,7 +100,7 @@ for a list of pre-trained models to download.
.
count
()
/
1000.
;
fprintf
(
stderr
,
"num threads: %d
\n
"
,
config
.
num_threads
);
fprintf
(
stderr
,
"num threads: %d
\n
"
,
config
.
model_config
.
num_threads
);
fprintf
(
stderr
,
"Elapsed seconds: %.3f s
\n
"
,
elapsed_seconds
);
float
rtf
=
elapsed_seconds
/
duration
;
...
...
请
注册
或
登录
后发表评论