Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-02-19 15:04:24 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-02-19 15:04:24 +0800
Commit
8acc059b3f5d88fc804408ce91b6a32b564d56b1
8acc059b
1 parent
ebc3b47f
Support batch greedy search decoding (#30)
显示空白字符变更
内嵌
并排对比
正在显示
5 个修改的文件
包含
177 行增加
和
64 行删除
sherpa-onnx/csrc/online-lstm-transducer-model.cc
sherpa-onnx/csrc/online-lstm-transducer-model.h
sherpa-onnx/csrc/online-recognizer.cc
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
sherpa-onnx/csrc/online-transducer-model.h
sherpa-onnx/csrc/online-lstm-transducer-model.cc
查看文件 @
8acc059
...
...
@@ -3,6 +3,7 @@
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
...
...
@@ -10,6 +11,7 @@
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
...
...
@@ -114,23 +116,85 @@ void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
}
}
Ort
::
Value
OnlineLstmTransducerModel
::
StackStates
(
const
std
::
vector
<
Ort
::
Value
>
&
states
)
const
{
fprintf
(
stderr
,
"implement me: %s:%d!
\n
"
,
__func__
,
static_cast
<
int
>
(
__LINE__
));
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
int64_t
a
;
std
::
array
<
int64_t
,
3
>
x_shape
{
1
,
1
,
1
};
Ort
::
Value
x
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
&
a
,
0
,
&
a
,
0
);
return
x
;
std
::
vector
<
Ort
::
Value
>
OnlineLstmTransducerModel
::
StackStates
(
const
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
&
states
)
const
{
int32_t
batch_size
=
static_cast
<
int32_t
>
(
states
.
size
());
std
::
array
<
int64_t
,
3
>
h_shape
{
num_encoder_layers_
,
batch_size
,
d_model_
};
Ort
::
Value
h
=
Ort
::
Value
::
CreateTensor
<
float
>
(
allocator_
,
h_shape
.
data
(),
h_shape
.
size
());
std
::
array
<
int64_t
,
3
>
c_shape
{
num_encoder_layers_
,
batch_size
,
rnn_hidden_size_
};
Ort
::
Value
c
=
Ort
::
Value
::
CreateTensor
<
float
>
(
allocator_
,
c_shape
.
data
(),
c_shape
.
size
());
float
*
dst_h
=
h
.
GetTensorMutableData
<
float
>
();
float
*
dst_c
=
c
.
GetTensorMutableData
<
float
>
();
for
(
int32_t
layer
=
0
;
layer
!=
num_encoder_layers_
;
++
layer
)
{
for
(
int32_t
i
=
0
;
i
!=
batch_size
;
++
i
)
{
const
float
*
src_h
=
states
[
i
][
0
].
GetTensorData
<
float
>
()
+
layer
*
d_model_
;
const
float
*
src_c
=
states
[
i
][
1
].
GetTensorData
<
float
>
()
+
layer
*
rnn_hidden_size_
;
std
::
copy
(
src_h
,
src_h
+
d_model_
,
dst_h
);
std
::
copy
(
src_c
,
src_c
+
rnn_hidden_size_
,
dst_c
);
dst_h
+=
d_model_
;
dst_c
+=
rnn_hidden_size_
;
}
}
std
::
vector
<
Ort
::
Value
>
ans
;
ans
.
reserve
(
2
);
ans
.
push_back
(
std
::
move
(
h
));
ans
.
push_back
(
std
::
move
(
c
));
return
ans
;
}
std
::
vector
<
Ort
::
Value
>
OnlineLstmTransducerModel
::
UnStackStates
(
Ort
::
Value
states
)
const
{
fprintf
(
stderr
,
"implement me: %s:%d!
\n
"
,
__func__
,
static_cast
<
int
>
(
__LINE__
));
return
{};
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
OnlineLstmTransducerModel
::
UnStackStates
(
const
std
::
vector
<
Ort
::
Value
>
&
states
)
const
{
int32_t
batch_size
=
states
[
0
].
GetTensorTypeAndShapeInfo
().
GetShape
()[
1
];
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
ans
(
batch_size
);
// allocate space
std
::
array
<
int64_t
,
3
>
h_shape
{
num_encoder_layers_
,
1
,
d_model_
};
std
::
array
<
int64_t
,
3
>
c_shape
{
num_encoder_layers_
,
1
,
rnn_hidden_size_
};
for
(
int32_t
i
=
0
;
i
!=
batch_size
;
++
i
)
{
Ort
::
Value
h
=
Ort
::
Value
::
CreateTensor
<
float
>
(
allocator_
,
h_shape
.
data
(),
h_shape
.
size
());
Ort
::
Value
c
=
Ort
::
Value
::
CreateTensor
<
float
>
(
allocator_
,
c_shape
.
data
(),
c_shape
.
size
());
ans
[
i
].
push_back
(
std
::
move
(
h
));
ans
[
i
].
push_back
(
std
::
move
(
c
));
}
for
(
int32_t
layer
=
0
;
layer
!=
num_encoder_layers_
;
++
layer
)
{
for
(
int32_t
i
=
0
;
i
!=
batch_size
;
++
i
)
{
const
float
*
src_h
=
states
[
0
].
GetTensorData
<
float
>
()
+
layer
*
batch_size
*
d_model_
+
i
*
d_model_
;
const
float
*
src_c
=
states
[
1
].
GetTensorData
<
float
>
()
+
layer
*
batch_size
*
rnn_hidden_size_
+
i
*
rnn_hidden_size_
;
float
*
dst_h
=
ans
[
i
][
0
].
GetTensorMutableData
<
float
>
()
+
layer
*
d_model_
;
float
*
dst_c
=
ans
[
i
][
1
].
GetTensorMutableData
<
float
>
()
+
layer
*
rnn_hidden_size_
;
std
::
copy
(
src_h
,
src_h
+
d_model_
,
dst_h
);
std
::
copy
(
src_c
,
src_c
+
rnn_hidden_size_
,
dst_c
);
}
}
return
ans
;
}
std
::
vector
<
Ort
::
Value
>
OnlineLstmTransducerModel
::
GetEncoderInitStates
()
{
...
...
@@ -189,16 +253,21 @@ OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
}
// Build the decoder (predictor) input for a batch of partial hypotheses.
//
// For each entry of `results`, the last context_size_ tokens of r.tokens are
// copied into one row of a (batch_size, context_size_) int64 tensor.
//
// NOTE(review): this assumes r.tokens.size() >= context_size_ for every
// result — presumably guaranteed because hypotheses start blank-padded to
// the context size; confirm with the decoder that creates
// OnlineTransducerDecoderResult.
Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
    const std::vector<OnlineTransducerDecoderResult> &results) {
  int32_t batch_size = static_cast<int32_t>(results.size());

  std::array<int64_t, 2> shape{batch_size, context_size_};

  // The tensor owns its storage (allocator-based overload), so it remains
  // valid after this function returns — unlike a view over caller memory.
  Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
      allocator_, shape.data(), shape.size());

  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
  for (const auto &r : results) {
    const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
    const int64_t *end = r.tokens.data() + r.tokens.size();
    std::copy(begin, end, p);
    p += context_size_;
  }

  return decoder_input;
}
Ort
::
Value
OnlineLstmTransducerModel
::
RunDecoder
(
Ort
::
Value
decoder_input
)
{
...
...
sherpa-onnx/csrc/online-lstm-transducer-model.h
查看文件 @
8acc059
...
...
@@ -19,16 +19,19 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
public
:
explicit
OnlineLstmTransducerModel
(
const
OnlineTransducerModelConfig
&
config
);
Ort
::
Value
StackStates
(
const
std
::
vector
<
Ort
::
Value
>
&
states
)
const
override
;
std
::
vector
<
Ort
::
Value
>
StackStates
(
const
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
&
states
)
const
override
;
std
::
vector
<
Ort
::
Value
>
UnStackStates
(
Ort
::
Value
states
)
const
override
;
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
UnStackStates
(
const
std
::
vector
<
Ort
::
Value
>
&
states
)
const
override
;
std
::
vector
<
Ort
::
Value
>
GetEncoderInitStates
()
override
;
std
::
pair
<
Ort
::
Value
,
std
::
vector
<
Ort
::
Value
>>
RunEncoder
(
Ort
::
Value
features
,
std
::
vector
<
Ort
::
Value
>
&
states
)
override
;
Ort
::
Value
BuildDecoderInput
(
const
std
::
vector
<
int64_t
>
&
hyp
)
override
;
Ort
::
Value
BuildDecoderInput
(
const
std
::
vector
<
OnlineTransducerDecoderResult
>
&
results
)
override
;
Ort
::
Value
RunDecoder
(
Ort
::
Value
decoder_input
)
override
;
...
...
@@ -41,6 +44,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
// Number of feature frames to advance between consecutive decode chunks.
int32_t ChunkShift() const override { return decode_chunk_len_; }
// Size of the recognition vocabulary.
int32_t VocabSize() const override { return vocab_size_; }
// Allocator used by this model to create tensors.
OrtAllocator *Allocator() override { return allocator_; }
private
:
void
InitEncoder
(
const
std
::
string
&
encoder_filename
);
...
...
@@ -50,7 +54,6 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
private
:
Ort
::
Env
env_
;
Ort
::
SessionOptions
sess_opts_
;
Ort
::
AllocatorWithDefaultOptions
allocator_
;
std
::
unique_ptr
<
Ort
::
Session
>
encoder_sess_
;
...
...
sherpa-onnx/csrc/online-recognizer.cc
查看文件 @
8acc059
...
...
@@ -6,6 +6,7 @@
#include <assert.h>
#include <algorithm>
#include <memory>
#include <sstream>
#include <utility>
...
...
@@ -64,39 +65,50 @@ class OnlineRecognizer::Impl {
}
// Decode a batch of n streams in one encoder/decoder pass.
//
// @param ss Pointer to an array of n streams; every stream must satisfy
//           IsReady() (asserted implicitly by the caller side of this API —
//           TODO(review): confirm; the old single-stream path asserted it).
// @param n  Number of streams in ss.
//
// For each stream, a chunk of features is gathered and its processed-frame
// counter is advanced by the chunk shift; the per-stream states and partial
// results are collected, stacked, run through the encoder and the batched
// decoder, and the updated results/states are written back to the streams.
void DecodeStreams(OnlineStream **ss, int32_t n) {
  int32_t chunk_size = model_->ChunkSize();
  int32_t chunk_shift = model_->ChunkShift();
  int32_t feature_dim = ss[0]->FeatureDim();

  std::vector<OnlineTransducerDecoderResult> results(n);
  std::vector<float> features_vec(n * chunk_size * feature_dim);
  std::vector<std::vector<Ort::Value>> states_vec(n);

  // Gather features, results and states from every stream.
  for (int32_t i = 0; i != n; ++i) {
    std::vector<float> features =
        ss[i]->GetFrames(ss[i]->GetNumProcessedFrames(), chunk_size);

    // Advance by chunk_shift (not chunk_size): consecutive chunks overlap.
    ss[i]->GetNumProcessedFrames() += chunk_shift;

    std::copy(features.begin(), features.end(),
              features_vec.data() + i * chunk_size * feature_dim);

    results[i] = std::move(ss[i]->GetResult());
    states_vec[i] = std::move(ss[i]->GetStates());
  }

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};

  // x is a view over features_vec, which stays alive until RunEncoder
  // returns.
  Ort::Value x =
      Ort::Value::CreateTensor(memory_info, features_vec.data(),
                               features_vec.size(), x_shape.data(),
                               x_shape.size());

  auto states = model_->StackStates(states_vec);
  auto pair = model_->RunEncoder(std::move(x), states);

  decoder_->Decode(std::move(pair.first), &results);

  std::vector<std::vector<Ort::Value>> next_states =
      model_->UnStackStates(pair.second);

  // Scatter the updated results and states back to their streams.
  for (int32_t i = 0; i != n; ++i) {
    ss[i]->SetResult(results[i]);
    ss[i]->SetStates(std::move(next_states[i]));
  }
}
OnlineRecognizerResult
GetResult
(
OnlineStream
*
s
)
{
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
查看文件 @
8acc059
...
...
@@ -32,6 +32,30 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
encoder_out_dim
,
shape
.
data
(),
shape
.
size
());
}
// Tile a (1, dim) encoder-output frame n times into an (n, dim) tensor so it
// can be fed to the joiner alongside a batch of decoder outputs.
// When n == 1 the input tensor is moved out as-is, avoiding any copy (the
// caller must not use *cur_encoder_out afterwards in that case).
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
                         int32_t n) {
  if (n == 1) {
    return std::move(*cur_encoder_out);
  }

  std::vector<int64_t> shape =
      cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
  int64_t dim = shape[1];

  std::array<int64_t, 2> ans_shape{n, dim};
  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
                                                   ans_shape.size());

  const float *in = cur_encoder_out->GetTensorData<float>();
  float *out = ans.GetTensorMutableData<float>();
  for (int32_t k = 0; k != n; ++k, out += dim) {
    std::copy(in, in + dim, out);
  }

  return ans;
}
OnlineTransducerDecoderResult
OnlineTransducerGreedySearchDecoder
::
GetEmptyResult
()
const
{
int32_t
context_size
=
model_
->
ContextSize
();
...
...
@@ -66,33 +90,33 @@ void OnlineTransducerGreedySearchDecoder::Decode(
exit
(
-
1
);
}
if
(
result
->
size
()
!=
1
)
{
fprintf
(
stderr
,
"only batch size == 1 is implemented. Given: %d"
,
static_cast
<
int32_t
>
(
result
->
size
()));
exit
(
-
1
);
}
auto
&
hyp
=
(
*
result
)[
0
].
tokens
;
int32_t
num_frames
=
encoder_out_shape
[
1
];
int32_t
batch_size
=
static_cast
<
int32_t
>
(
encoder_out_shape
[
0
]);
int32_t
num_frames
=
static_cast
<
int32_t
>
(
encoder_out_shape
[
1
]);
int32_t
vocab_size
=
model_
->
VocabSize
();
Ort
::
Value
decoder_input
=
model_
->
BuildDecoderInput
(
hyp
);
Ort
::
Value
decoder_input
=
model_
->
BuildDecoderInput
(
*
result
);
Ort
::
Value
decoder_out
=
model_
->
RunDecoder
(
std
::
move
(
decoder_input
));
for
(
int32_t
t
=
0
;
t
!=
num_frames
;
++
t
)
{
Ort
::
Value
cur_encoder_out
=
GetFrame
(
&
encoder_out
,
t
);
cur_encoder_out
=
Repeat
(
model_
->
Allocator
(),
&
cur_encoder_out
,
batch_size
);
Ort
::
Value
logit
=
model_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
Clone
(
&
decoder_out
));
const
float
*
p_logit
=
logit
.
GetTensorData
<
float
>
();
bool
emitted
=
false
;
for
(
int32_t
i
=
0
;
i
<
batch_size
;
++
i
,
p_logit
+=
vocab_size
)
{
auto
y
=
static_cast
<
int32_t
>
(
std
::
distance
(
static_cast
<
const
float
*>
(
p_logit
),
std
::
max_element
(
static_cast
<
const
float
*>
(
p_logit
),
static_cast
<
const
float
*>
(
p_logit
)
+
vocab_size
)));
if
(
y
!=
0
)
{
hyp
.
push_back
(
y
);
decoder_input
=
model_
->
BuildDecoderInput
(
hyp
);
emitted
=
true
;
(
*
result
)[
i
].
tokens
.
push_back
(
y
);
}
}
if
(
emitted
)
{
decoder_input
=
model_
->
BuildDecoderInput
(
*
result
);
decoder_out
=
model_
->
RunDecoder
(
std
::
move
(
decoder_input
));
}
}
...
...
sherpa-onnx/csrc/online-transducer-model.h
查看文件 @
8acc059
...
...
@@ -13,6 +13,8 @@
namespace
sherpa_onnx
{
class
OnlineTransducerDecoderResult
;
class
OnlineTransducerModel
{
public
:
virtual
~
OnlineTransducerModel
()
=
default
;
...
...
@@ -27,8 +29,8 @@ class OnlineTransducerModel {
* @param states states[i] contains the state for the i-th utterance.
* @return Return a single value representing the batched state.
*/
virtual
Ort
::
Value
StackStates
(
const
std
::
vector
<
Ort
::
Value
>
&
states
)
const
=
0
;
virtual
std
::
vector
<
Ort
::
Value
>
StackStates
(
const
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
&
states
)
const
=
0
;
/** Unstack a batch state into a list of individual states.
*
...
...
@@ -37,7 +39,8 @@ class OnlineTransducerModel {
* @param states A batched state.
* @return ans[i] contains the state for the i-th utterance.
*/
virtual
std
::
vector
<
Ort
::
Value
>
UnStackStates
(
Ort
::
Value
states
)
const
=
0
;
virtual
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
UnStackStates
(
const
std
::
vector
<
Ort
::
Value
>
&
states
)
const
=
0
;
/** Get the initial encoder states.
*
...
...
@@ -58,7 +61,8 @@ class OnlineTransducerModel {
Ort
::
Value
features
,
std
::
vector
<
Ort
::
Value
>
&
states
)
=
0
;
// NOLINT
virtual
Ort
::
Value
BuildDecoderInput
(
const
std
::
vector
<
int64_t
>
&
hyp
)
=
0
;
virtual
Ort
::
Value
BuildDecoderInput
(
const
std
::
vector
<
OnlineTransducerDecoderResult
>
&
results
)
=
0
;
/** Run the decoder network.
*
...
...
@@ -111,6 +115,7 @@ class OnlineTransducerModel {
virtual
int32_t
VocabSize
()
const
=
0
;
// Frame subsampling factor of the encoder; defaults to 4 unless a derived
// model overrides it.
virtual int32_t SubsamplingFactor() const { return 4; }
virtual
OrtAllocator
*
Allocator
()
=
0
;
};
}
// namespace sherpa_onnx
...
...
请注册或登录后发表评论