Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2024-05-30 15:31:10 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-05-30 15:31:10 +0800
Commit
082f230dfb06d22762df50932fae7145e7eba6dd
082f230d
1 parent
3f472a99
Fix nemo streaming transducer greedy search (#944)
显示空白字符变更
内嵌
并排对比
正在显示
18 个修改的文件
包含
274 行增加
和
244 行删除
.github/scripts/test-online-transducer.sh
.github/workflows/aarch64-linux-gnu-shared.yaml
.github/workflows/aarch64-linux-gnu-static.yaml
.github/workflows/android.yaml
.github/workflows/arm-linux-gnueabihf.yaml
.github/workflows/build-xcframework.yaml
.github/workflows/riscv64-linux.yaml
.github/workflows/windows-x64.yaml
.github/workflows/windows-x86.yaml
sherpa-onnx/csrc/online-recognizer-impl.cc
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
sherpa-onnx/csrc/online-stream.cc
sherpa-onnx/csrc/online-stream.h
sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
sherpa-onnx/csrc/online-transducer-nemo-model.cc
sherpa-onnx/csrc/online-transducer-nemo-model.h
.github/scripts/test-online-transducer.sh
查看文件 @
082f230
...
...
@@ -16,6 +16,45 @@ echo "PATH: $PATH"
which
$EXE
log
"------------------------------------------------------------"
log
"Run NeMo transducer (English)"
log
"------------------------------------------------------------"
repo_url
=
https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
curl -SL -O
$repo_url
tar xvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
rm sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
repo
=
sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms
log
"Start testing
${
repo_url
}
"
waves
=(
$repo
/test_wavs/0.wav
$repo
/test_wavs/1.wav
$repo
/test_wavs/8k.wav
)
for
wave
in
${
waves
[@]
}
;
do
time
$EXE
\
--tokens
=
$repo
/tokens.txt
\
--encoder
=
$repo
/encoder.onnx
\
--decoder
=
$repo
/decoder.onnx
\
--joiner
=
$repo
/joiner.onnx
\
--num-threads
=
2
\
$wave
done
time
$EXE
\
--tokens
=
$repo
/tokens.txt
\
--encoder
=
$repo
/encoder.onnx
\
--decoder
=
$repo
/decoder.onnx
\
--joiner
=
$repo
/joiner.onnx
\
--num-threads
=
2
\
$repo
/test_wavs/0.wav
\
$repo
/test_wavs/1.wav
\
$repo
/test_wavs/8k.wav
rm -rf
$repo
log
"------------------------------------------------------------"
log
"Run LSTM transducer (English)"
log
"------------------------------------------------------------"
...
...
.github/workflows/aarch64-linux-gnu-shared.yaml
查看文件 @
082f230
...
...
@@ -196,7 +196,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
mkdir -p aarch64
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64
...
...
.github/workflows/aarch64-linux-gnu-static.yaml
查看文件 @
082f230
...
...
@@ -187,7 +187,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
mkdir -p aarch64
cp -v ../sherpa-onnx-*-static.tar.bz2 ./aarch64
...
...
.github/workflows/android.yaml
查看文件 @
082f230
...
...
@@ -124,7 +124,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
cp -v ../sherpa-onnx-*-android.tar.bz2 ./
...
...
.github/workflows/arm-linux-gnueabihf.yaml
查看文件 @
082f230
...
...
@@ -209,7 +209,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
mkdir -p arm32
cp -v ../sherpa-onnx-*.tar.bz2 ./arm32
...
...
.github/workflows/build-xcframework.yaml
查看文件 @
082f230
...
...
@@ -138,7 +138,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
cp -v ../sherpa-onnx-*.tar.bz2 ./
...
...
.github/workflows/riscv64-linux.yaml
查看文件 @
082f230
...
...
@@ -242,7 +242,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
mkdir -p riscv64
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./riscv64
...
...
.github/workflows/windows-x64.yaml
查看文件 @
082f230
...
...
@@ -219,7 +219,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
mkdir -p win64
cp -v ../sherpa-onnx-*.tar.bz2 ./win64
...
...
.github/workflows/windows-x86.yaml
查看文件 @
082f230
...
...
@@ -221,7 +221,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
mkdir -p win32
cp -v ../sherpa-onnx-*.tar.bz2 ./win32
...
...
sherpa-onnx/csrc/online-recognizer-impl.cc
查看文件 @
082f230
...
...
@@ -14,19 +14,18 @@ namespace sherpa_onnx {
std
::
unique_ptr
<
OnlineRecognizerImpl
>
OnlineRecognizerImpl
::
Create
(
const
OnlineRecognizerConfig
&
config
)
{
if
(
!
config
.
model_config
.
transducer
.
encoder
.
empty
())
{
Ort
::
Env
env
(
ORT_LOGGING_LEVEL_WARNING
);
auto
decoder_model
=
ReadFile
(
config
.
model_config
.
transducer
.
decoder
);
auto
sess
=
std
::
make_unique
<
Ort
::
Session
>
(
env
,
decoder_model
.
data
(),
decoder_model
.
size
(),
Ort
::
SessionOptions
{});
auto
sess
=
std
::
make_unique
<
Ort
::
Session
>
(
env
,
decoder_model
.
data
(),
decoder_model
.
size
(),
Ort
::
SessionOptions
{});
size_t
node_count
=
sess
->
GetOutputCount
();
if
(
node_count
==
1
)
{
return
std
::
make_unique
<
OnlineRecognizerTransducerImpl
>
(
config
);
}
else
{
SHERPA_ONNX_LOGE
(
"Running streaming Nemo transducer model"
);
return
std
::
make_unique
<
OnlineRecognizerTransducerNeMoImpl
>
(
config
);
}
}
...
...
@@ -52,7 +51,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
Ort
::
Env
env
(
ORT_LOGGING_LEVEL_WARNING
);
auto
decoder_model
=
ReadFile
(
mgr
,
config
.
model_config
.
transducer
.
decoder
);
auto
sess
=
std
::
make_unique
<
Ort
::
Session
>
(
env
,
decoder_model
.
data
(),
decoder_model
.
size
(),
Ort
::
SessionOptions
{});
auto
sess
=
std
::
make_unique
<
Ort
::
Session
>
(
env
,
decoder_model
.
data
(),
decoder_model
.
size
(),
Ort
::
SessionOptions
{});
size_t
node_count
=
sess
->
GetOutputCount
();
...
...
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
查看文件 @
082f230
...
...
@@ -35,18 +35,15 @@
namespace
sherpa_onnx
{
static
OnlineRecognizerResult
Convert
(
const
OnlineTransducerDecoderResult
&
src
,
OnlineRecognizerResult
Convert
(
const
OnlineTransducerDecoderResult
&
src
,
const
SymbolTable
&
sym_table
,
float
frame_shift_ms
,
int32_t
subsampling_factor
,
int32_t
segment
,
int32_t
frames_since_start
)
{
float
frame_shift_ms
,
int32_t
subsampling_factor
,
int32_t
segment
,
int32_t
frames_since_start
)
{
OnlineRecognizerResult
r
;
r
.
tokens
.
reserve
(
src
.
tokens
.
size
());
r
.
timestamps
.
reserve
(
src
.
tokens
.
size
());
for
(
auto
i
:
src
.
tokens
)
{
if
(
i
==
-
1
)
continue
;
auto
sym
=
sym_table
[
i
];
r
.
text
.
append
(
sym
);
...
...
sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
查看文件 @
082f230
...
...
@@ -6,6 +6,7 @@
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#include <algorithm>
#include <fstream>
#include <ios>
#include <memory>
...
...
@@ -32,13 +33,10 @@
namespace
sherpa_onnx
{
// defined in ./online-recognizer-transducer-impl.h
// static may or may not be here? TODDOs
static
OnlineRecognizerResult
Convert
(
const
OnlineTransducerDecoderResult
&
src
,
OnlineRecognizerResult
Convert
(
const
OnlineTransducerDecoderResult
&
src
,
const
SymbolTable
&
sym_table
,
float
frame_shift_ms
,
int32_t
subsampling_factor
,
int32_t
segment
,
int32_t
frames_since_start
);
float
frame_shift_ms
,
int32_t
subsampling_factor
,
int32_t
segment
,
int32_t
frames_since_start
);
class
OnlineRecognizerTransducerNeMoImpl
:
public
OnlineRecognizerImpl
{
public
:
...
...
@@ -47,8 +45,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
:
config_
(
config
),
symbol_table_
(
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
),
model_
(
std
::
make_unique
<
OnlineTransducerNeMoModel
>
(
config
.
model_config
))
{
model_
(
std
::
make_unique
<
OnlineTransducerNeMoModel
>
(
config
.
model_config
))
{
if
(
config
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchNeMoDecoder
>
(
model_
.
get
(),
config_
.
blank_penalty
);
...
...
@@ -83,7 +81,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
std
::
unique_ptr
<
OnlineStream
>
CreateStream
()
const
override
{
auto
stream
=
std
::
make_unique
<
OnlineStream
>
(
config_
.
feat_config
);
stream
->
SetStates
(
model_
->
GetInitStates
());
InitOnlineStream
(
stream
.
get
());
return
stream
;
}
...
...
@@ -94,14 +91,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
}
OnlineRecognizerResult
GetResult
(
OnlineStream
*
s
)
const
override
{
OnlineTransducerDecoderResult
decoder_result
=
s
->
GetResult
();
decoder_
->
StripLeadingBlanks
(
&
decoder_result
);
// TODO(fangjun): Remember to change these constants if needed
int32_t
frame_shift_ms
=
10
;
int32_t
subsampling_factor
=
8
;
return
Convert
(
decoder_result
,
symbol_table_
,
frame_shift_ms
,
subsampling_factor
,
s
->
GetCurrentSegment
(),
s
->
GetNumFramesSinceStart
());
int32_t
subsampling_factor
=
model_
->
SubsamplingFactor
();
return
Convert
(
s
->
GetResult
(),
symbol_table_
,
frame_shift_ms
,
subsampling_factor
,
s
->
GetCurrentSegment
(),
s
->
GetNumFramesSinceStart
());
}
bool
IsEndpoint
(
OnlineStream
*
s
)
const
override
{
...
...
@@ -114,8 +109,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
// frame shift is 10 milliseconds
float
frame_shift_in_seconds
=
0
.
01
;
// subsampling factor is 8
int32_t
trailing_silence_frames
=
s
->
GetResult
().
num_trailing_blanks
*
8
;
int32_t
trailing_silence_frames
=
s
->
GetResult
().
num_trailing_blanks
*
model_
->
SubsamplingFactor
();
return
endpoint_
.
IsEndpoint
(
num_processed_frames
,
trailing_silence_frames
,
frame_shift_in_seconds
);
...
...
@@ -126,19 +121,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
// segment is incremented only when the last
// result is not empty
const
auto
&
r
=
s
->
GetResult
();
if
(
!
r
.
tokens
.
empty
()
&&
r
.
tokens
.
back
()
!=
0
)
{
if
(
!
r
.
tokens
.
empty
())
{
s
->
GetCurrentSegment
()
+=
1
;
}
}
// we keep the decoder_out
decoder_
->
UpdateDecoderOut
(
&
s
->
GetResult
());
Ort
::
Value
decoder_out
=
std
::
move
(
s
->
GetResult
().
decoder_out
);
s
->
SetResult
({});
auto
r
=
decoder_
->
GetEmptyResult
(
);
s
->
SetStates
(
model_
->
GetEncoderInitStates
()
);
s
->
SetResult
(
r
);
s
->
GetResult
().
decoder_out
=
std
::
move
(
decoder_out
);
s
->
SetNeMoDecoderStates
(
model_
->
GetDecoderInitStates
());
// Note: We only update counters. The underlying audio samples
// are not discarded.
...
...
@@ -151,7 +143,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
int32_t
feature_dim
=
ss
[
0
]
->
FeatureDim
();
std
::
vector
<
OnlineTransducerDecoderResult
>
result
(
n
);
std
::
vector
<
float
>
features_vec
(
n
*
chunk_size
*
feature_dim
);
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
encoder_states
(
n
);
...
...
@@ -166,9 +157,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
std
::
copy
(
features
.
begin
(),
features
.
end
(),
features_vec
.
data
()
+
i
*
chunk_size
*
feature_dim
);
result
[
i
]
=
std
::
move
(
ss
[
i
]
->
GetResult
());
encoder_states
[
i
]
=
std
::
move
(
ss
[
i
]
->
GetStates
());
}
auto
memory_info
=
...
...
@@ -180,8 +169,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
features_vec
.
size
(),
x_shape
.
data
(),
x_shape
.
size
());
// Batch size is 1
auto
states
=
std
::
move
(
encoder_states
[
0
]);
auto
states
=
model_
->
StackStates
(
std
::
move
(
encoder_states
));
int32_t
num_states
=
states
.
size
();
// num_states = 3
auto
t
=
model_
->
RunEncoder
(
std
::
move
(
x
),
std
::
move
(
states
));
// t[0] encoder_out, float tensor, (batch_size, dim, T)
...
...
@@ -194,28 +182,22 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
out_states
.
push_back
(
std
::
move
(
t
[
k
]));
}
Ort
::
Value
encoder_out
=
Transpose12
(
model_
->
Allocator
(),
&
t
[
0
]);
// defined in online-transducer-greedy-search-nemo-decoder.h
// get intial states of decoder.
std
::
vector
<
Ort
::
Value
>
&
decoder_states
=
ss
[
0
]
->
GetNeMoDecoderStates
();
// Subsequent decoder states (for each chunks) are updated inside the Decode method.
// This returns the decoder state from the LAST chunk. We probably dont need it. So we can discard it.
decoder_states
=
decoder_
->
Decode
(
std
::
move
(
encoder_out
),
std
::
move
(
decoder_states
),
&
result
,
ss
,
n
);
auto
unstacked_states
=
model_
->
UnStackStates
(
std
::
move
(
out_states
));
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
ss
[
i
]
->
SetStates
(
std
::
move
(
unstacked_states
[
i
]));
}
ss
[
0
]
->
SetResult
(
resul
t
[
0
]);
Ort
::
Value
encoder_out
=
Transpose12
(
model_
->
Allocator
(),
&
t
[
0
]);
ss
[
0
]
->
SetStates
(
std
::
move
(
out_states
)
);
decoder_
->
Decode
(
std
::
move
(
encoder_out
),
ss
,
n
);
}
void
InitOnlineStream
(
OnlineStream
*
stream
)
const
{
auto
r
=
decoder_
->
GetEmptyResult
();
// set encoder states
stream
->
SetStates
(
model_
->
GetEncoderInitStates
());
stream
->
SetResult
(
r
);
stream
->
SetNeMoDecoderStates
(
model_
->
GetDecoderInitStates
(
1
));
// set decoder states
stream
->
SetNeMoDecoderStates
(
model_
->
GetDecoderInitStates
());
}
private
:
...
...
@@ -250,7 +232,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
symbol_table_
.
NumSymbols
(),
vocab_size
);
exit
(
-
1
);
}
}
private
:
...
...
@@ -259,7 +240,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
std
::
unique_ptr
<
OnlineTransducerNeMoModel
>
model_
;
std
::
unique_ptr
<
OnlineTransducerGreedySearchNeMoDecoder
>
decoder_
;
Endpoint
endpoint_
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-stream.cc
查看文件 @
082f230
...
...
@@ -225,7 +225,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
return
impl_
->
GetStates
();
}
void
OnlineStream
::
SetNeMoDecoderStates
(
std
::
vector
<
Ort
::
Value
>
decoder_states
)
{
void
OnlineStream
::
SetNeMoDecoderStates
(
std
::
vector
<
Ort
::
Value
>
decoder_states
)
{
return
impl_
->
SetNeMoDecoderStates
(
std
::
move
(
decoder_states
));
}
...
...
sherpa-onnx/csrc/online-stream.h
查看文件 @
082f230
sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
查看文件 @
082f230
...
...
@@ -10,96 +10,57 @@
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace
sherpa_onnx
{
static
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
BuildDecoderInput
(
int32_t
token
,
OrtAllocator
*
allocator
)
{
static
Ort
::
Value
BuildDecoderInput
(
int32_t
token
,
OrtAllocator
*
allocator
)
{
std
::
array
<
int64_t
,
2
>
shape
{
1
,
1
};
Ort
::
Value
decoder_input
=
Ort
::
Value
::
CreateTensor
<
int32_t
>
(
allocator
,
shape
.
data
(),
shape
.
size
());
std
::
array
<
int64_t
,
1
>
length_shape
{
1
};
Ort
::
Value
decoder_input_length
=
Ort
::
Value
::
CreateTensor
<
int32_t
>
(
allocator
,
length_shape
.
data
(),
length_shape
.
size
());
int32_t
*
p
=
decoder_input
.
GetTensorMutableData
<
int32_t
>
();
int32_t
*
p_length
=
decoder_input_length
.
GetTensorMutableData
<
int32_t
>
();
p
[
0
]
=
token
;
p_length
[
0
]
=
1
;
return
{
std
::
move
(
decoder_input
),
std
::
move
(
decoder_input_length
)};
}
OnlineTransducerDecoderResult
OnlineTransducerGreedySearchNeMoDecoder
::
GetEmptyResult
()
const
{
int32_t
context_size
=
8
;
int32_t
blank_id
=
0
;
// always 0
OnlineTransducerDecoderResult
r
;
r
.
tokens
.
resize
(
context_size
,
-
1
);
r
.
tokens
.
back
()
=
blank_id
;
return
r
;
}
static
void
UpdateCachedDecoderOut
(
OrtAllocator
*
allocator
,
const
Ort
::
Value
*
decoder_out
,
std
::
vector
<
OnlineTransducerDecoderResult
>
*
result
)
{
std
::
vector
<
int64_t
>
shape
=
decoder_out
->
GetTensorTypeAndShapeInfo
().
GetShape
();
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
std
::
array
<
int64_t
,
2
>
v_shape
{
1
,
shape
[
1
]};
const
float
*
src
=
decoder_out
->
GetTensorData
<
float
>
();
for
(
auto
&
r
:
*
result
)
{
if
(
!
r
.
decoder_out
)
{
r
.
decoder_out
=
Ort
::
Value
::
CreateTensor
<
float
>
(
allocator
,
v_shape
.
data
(),
v_shape
.
size
());
}
float
*
dst
=
r
.
decoder_out
.
GetTensorMutableData
<
float
>
();
std
::
copy
(
src
,
src
+
shape
[
1
],
dst
);
src
+=
shape
[
1
];
}
return
decoder_input
;
}
std
::
vector
<
Ort
::
Value
>
DecodeOne
(
const
float
*
encoder_out
,
int32_t
num_rows
,
int32_t
num_cols
,
OnlineTransducerNeMoModel
*
model
,
float
blank_penalty
,
std
::
vector
<
Ort
::
Value
>&
decoder_states
,
std
::
vector
<
OnlineTransducerDecoderResult
>
*
result
)
{
static
void
DecodeOne
(
const
float
*
encoder_out
,
int32_t
num_rows
,
int32_t
num_cols
,
OnlineTransducerNeMoModel
*
model
,
float
blank_penalty
,
OnlineStream
*
s
)
{
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
// OnlineTransducerDecoderResult result;
int32_t
vocab_size
=
model
->
VocabSize
();
int32_t
blank_id
=
vocab_size
-
1
;
auto
&
r
=
(
*
result
)[
0
];
auto
&
r
=
s
->
GetResult
();
Ort
::
Value
decoder_out
{
nullptr
};
auto
decoder_input_pair
=
BuildDecoderInput
(
blank_id
,
model
->
Allocator
());
// decoder_input_pair[0]: decoder_input
// decoder_input_pair[1]: decoder_input_length (discarded)
auto
decoder_input
=
BuildDecoderInput
(
r
.
tokens
.
empty
()
?
blank_id
:
r
.
tokens
.
back
(),
model
->
Allocator
());
std
::
vector
<
Ort
::
Value
>
&
last_decoder_states
=
s
->
GetNeMoDecoderStates
();
std
::
vector
<
Ort
::
Value
>
tmp_decoder_states
;
tmp_decoder_states
.
reserve
(
last_decoder_states
.
size
());
for
(
auto
&
v
:
last_decoder_states
)
{
tmp_decoder_states
.
push_back
(
View
(
&
v
));
}
// decoder_output_pair.second returns the next decoder state
std
::
pair
<
Ort
::
Value
,
std
::
vector
<
Ort
::
Value
>>
decoder_output_pair
=
model
->
RunDecoder
(
std
::
move
(
decoder_input_pair
.
first
),
std
::
move
(
decoder_states
));
// here decoder_states = {len=0, cap=0}. But decoder_output_pair= {first, second: {len=2, cap=2}} // ATTN
model
->
RunDecoder
(
std
::
move
(
decoder_input
),
std
::
move
(
tmp_decoder_states
));
std
::
array
<
int64_t
,
3
>
encoder_shape
{
1
,
num_cols
,
1
};
decoder_states
=
std
::
move
(
decoder_output_pair
.
second
)
;
bool
emitted
=
false
;
// TODO: Inside this loop, I need to framewise decoding.
for
(
int32_t
t
=
0
;
t
!=
num_rows
;
++
t
)
{
Ort
::
Value
cur_encoder_out
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
const_cast
<
float
*>
(
encoder_out
)
+
t
*
num_cols
,
num_cols
,
...
...
@@ -117,82 +78,52 @@ std::vector<Ort::Value> DecodeOne(
static_cast
<
const
float
*>
(
p_logit
),
std
::
max_element
(
static_cast
<
const
float
*>
(
p_logit
),
static_cast
<
const
float
*>
(
p_logit
)
+
vocab_size
)));
SHERPA_ONNX_LOGE
(
"y=%d"
,
y
);
if
(
y
!=
blank_id
)
{
emitted
=
true
;
r
.
tokens
.
push_back
(
y
);
r
.
timestamps
.
push_back
(
t
+
r
.
frame_offset
);
r
.
num_trailing_blanks
=
0
;
decoder_input
_pair
=
BuildDecoderInput
(
y
,
model
->
Allocator
());
decoder_input
=
BuildDecoderInput
(
y
,
model
->
Allocator
());
// last decoder state becomes the current state for the first chunk
decoder_output_pair
=
model
->
RunDecoder
(
std
::
move
(
decoder_input_pair
.
first
),
std
::
move
(
decoder_states
));
// Update the decoder states for the next chunk
decoder_states
=
std
::
move
(
decoder_output_pair
.
second
);
decoder_output_pair
=
model
->
RunDecoder
(
std
::
move
(
decoder_input
),
std
::
move
(
decoder_output_pair
.
second
));
}
else
{
++
r
.
num_trailing_blanks
;
}
}
decoder_out
=
std
::
move
(
decoder_output_pair
.
first
);
// UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result);
// Update frame_offset
for
(
auto
&
r
:
*
result
)
{
r
.
frame_offset
+=
num_rows
;
if
(
emitted
)
{
s
->
SetNeMoDecoderStates
(
std
::
move
(
decoder_output_pair
.
second
));
}
r
eturn
std
::
move
(
decoder_states
)
;
r
.
frame_offset
+=
num_rows
;
}
std
::
vector
<
Ort
::
Value
>
OnlineTransducerGreedySearchNeMoDecoder
::
Decode
(
Ort
::
Value
encoder_out
,
std
::
vector
<
Ort
::
Value
>
decoder_states
,
std
::
vector
<
OnlineTransducerDecoderResult
>
*
result
,
OnlineStream
**
/*ss = nullptr*/
,
int32_t
/*n= 0*/
)
{
void
OnlineTransducerGreedySearchNeMoDecoder
::
Decode
(
Ort
::
Value
encoder_out
,
OnlineStream
**
ss
,
int32_t
n
)
const
{
auto
shape
=
encoder_out
.
GetTensorTypeAndShapeInfo
().
GetShape
();
int32_t
batch_size
=
static_cast
<
int32_t
>
(
shape
[
0
]);
// bs = 1
if
(
shape
[
0
]
!=
result
->
size
())
{
SHERPA_ONNX_LOGE
(
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d"
,
static_cast
<
int32_t
>
(
shape
[
0
]),
static_cast
<
int32_t
>
(
result
->
size
()));
if
(
batch_size
!=
n
)
{
SHERPA_ONNX_LOGE
(
"Size mismatch! encoder_out.size(0) %d, n: %d"
,
static_cast
<
int32_t
>
(
shape
[
0
]),
n
);
exit
(
-
1
);
}
int32_t
batch_size
=
static_cast
<
int32_t
>
(
shape
[
0
]);
// bs = 1
int32_t
dim1
=
static_cast
<
int32_t
>
(
shape
[
1
]);
// 2
int32_t
dim2
=
static_cast
<
int32_t
>
(
shape
[
2
]);
// 512
// Define and initialize encoder_out_length
Ort
::
MemoryInfo
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeCPU
);
int64_t
length_value
=
1
;
std
::
vector
<
int64_t
>
length_shape
=
{
1
};
int32_t
dim1
=
static_cast
<
int32_t
>
(
shape
[
1
]);
// T
int32_t
dim2
=
static_cast
<
int32_t
>
(
shape
[
2
]);
// encoder_out_dim
Ort
::
Value
encoder_out_length
=
Ort
::
Value
::
CreateTensor
<
int64_t
>
(
memory_info
,
&
length_value
,
1
,
length_shape
.
data
(),
length_shape
.
size
()
);
const
int64_t
*
p_length
=
encoder_out_length
.
GetTensorData
<
int64_t
>
();
const
float
*
p
=
encoder_out
.
GetTensorData
<
float
>
();
// std::vector<OnlineTransducerDecoderResult> ans(batch_size);
for
(
int32_t
i
=
0
;
i
!=
batch_size
;
++
i
)
{
const
float
*
this_p
=
p
+
dim1
*
dim2
*
i
;
int32_t
this_len
=
p_length
[
i
];
// outputs the decoder state from last chunk.
auto
last_decoder_states
=
DecodeOne
(
this_p
,
this_len
,
dim2
,
model_
,
blank_penalty_
,
decoder_states
,
result
);
// ans[i] = decode_result_pair.first;
decoder_states
=
std
::
move
(
last_decoder_states
);
DecodeOne
(
this_p
,
dim1
,
dim2
,
model_
,
blank_penalty_
,
ss
[
i
]);
}
return
decoder_states
;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
查看文件 @
082f230
...
...
@@ -7,27 +7,22 @@
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
namespace
sherpa_onnx
{
class
OnlineStream
;
class
OnlineTransducerGreedySearchNeMoDecoder
{
public
:
OnlineTransducerGreedySearchNeMoDecoder
(
OnlineTransducerNeMoModel
*
model
,
float
blank_penalty
)
:
model_
(
model
),
blank_penalty_
(
blank_penalty
)
{}
:
model_
(
model
),
blank_penalty_
(
blank_penalty
)
{}
OnlineTransducerDecoderResult
GetEmptyResult
()
const
;
void
UpdateDecoderOut
(
OnlineTransducerDecoderResult
*
result
)
{}
void
StripLeadingBlanks
(
OnlineTransducerDecoderResult
*
/*r*/
)
const
{}
std
::
vector
<
Ort
::
Value
>
Decode
(
Ort
::
Value
encoder_out
,
std
::
vector
<
Ort
::
Value
>
decoder_states
,
std
::
vector
<
OnlineTransducerDecoderResult
>
*
result
,
OnlineStream
**
ss
=
nullptr
,
int32_t
n
=
0
);
// @param n number of elements in ss
void
Decode
(
Ort
::
Value
encoder_out
,
OnlineStream
**
ss
,
int32_t
n
)
const
;
private
:
OnlineTransducerNeMoModel
*
model_
;
// Not owned
...
...
@@ -37,4 +32,3 @@ class OnlineTransducerGreedySearchNeMoDecoder {
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
...
...
sherpa-onnx/csrc/online-transducer-nemo-model.cc
查看文件 @
082f230
...
...
@@ -102,8 +102,8 @@ class OnlineTransducerNeMoModel::Impl {
std
::
move
(
features
),
View
(
&
length
),
std
::
move
(
cache_last_channel
),
std
::
move
(
cache_last_time
),
std
::
move
(
cache_last_channel_len
)};
auto
out
=
encoder_sess_
->
Run
({},
encoder_input_names_ptr_
.
data
(),
inputs
.
data
(),
inputs
.
size
(),
auto
out
=
encoder_sess_
->
Run
(
{},
encoder_input_names_ptr_
.
data
(),
inputs
.
data
(),
inputs
.
size
(),
encoder_output_names_ptr_
.
data
(),
encoder_output_names_ptr_
.
size
());
// out[0]: logit
// out[1] logit_length
...
...
@@ -127,16 +127,18 @@ class OnlineTransducerNeMoModel::Impl {
std
::
pair
<
Ort
::
Value
,
std
::
vector
<
Ort
::
Value
>>
RunDecoder
(
Ort
::
Value
targets
,
std
::
vector
<
Ort
::
Value
>
states
)
{
Ort
::
MemoryInfo
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeCPU
);
Ort
::
MemoryInfo
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeCPU
);
auto
shape
=
targets
.
GetTensorTypeAndShapeInfo
().
GetShape
();
int32_t
batch_size
=
static_cast
<
int32_t
>
(
shape
[
0
]);
// Create the tensor with a single int32_t value of 1
int32_t
length_value
=
1
;
std
::
vector
<
int64_t
>
length_shape
=
{
1
};
std
::
vector
<
int64_t
>
length_shape
=
{
batch_size
};
std
::
vector
<
int32_t
>
length_value
(
batch_size
,
1
);
Ort
::
Value
targets_length
=
Ort
::
Value
::
CreateTensor
<
int32_t
>
(
memory_info
,
&
length_value
,
1
,
length_shape
.
data
(),
length_shape
.
size
()
);
memory_info
,
length_value
.
data
(),
batch_size
,
length_shape
.
data
(),
length_shape
.
size
());
std
::
vector
<
Ort
::
Value
>
decoder_inputs
;
decoder_inputs
.
reserve
(
2
+
states
.
size
());
...
...
@@ -171,35 +173,21 @@ class OnlineTransducerNeMoModel::Impl {
Ort
::
Value
RunJoiner
(
Ort
::
Value
encoder_out
,
Ort
::
Value
decoder_out
)
{
std
::
array
<
Ort
::
Value
,
2
>
joiner_input
=
{
std
::
move
(
encoder_out
),
std
::
move
(
decoder_out
)};
auto
logit
=
joiner_sess_
->
Run
({},
joiner_input_names_ptr_
.
data
(),
joiner_input
.
data
(),
joiner_input
.
size
(),
joiner_output_names_ptr_
.
data
(),
auto
logit
=
joiner_sess_
->
Run
({},
joiner_input_names_ptr_
.
data
(),
joiner_input
.
data
(),
joiner_input
.
size
(),
joiner_output_names_ptr_
.
data
(),
joiner_output_names_ptr_
.
size
());
return
std
::
move
(
logit
[
0
]);
}
std
::
vector
<
Ort
::
Value
>
GetDecoderInitStates
(
int32_t
batch_size
)
const
{
std
::
array
<
int64_t
,
3
>
s0_shape
{
pred_rnn_layers_
,
batch_size
,
pred_hidden_
};
Ort
::
Value
s0
=
Ort
::
Value
::
CreateTensor
<
float
>
(
allocator_
,
s0_shape
.
data
(),
s0_shape
.
size
());
Fill
<
float
>
(
&
s0
,
0
);
std
::
array
<
int64_t
,
3
>
s1_shape
{
pred_rnn_layers_
,
batch_size
,
pred_hidden_
};
Ort
::
Value
s1
=
Ort
::
Value
::
CreateTensor
<
float
>
(
allocator_
,
s1_shape
.
data
(),
s1_shape
.
size
());
Fill
<
float
>
(
&
s1
,
0
);
std
::
vector
<
Ort
::
Value
>
states
;
}
states
.
reserve
(
2
);
states
.
push_back
(
std
::
move
(
s0
));
states
.
push_back
(
std
::
move
(
s1
));
std
::
vector
<
Ort
::
Value
>
GetDecoderInitStates
()
{
std
::
vector
<
Ort
::
Value
>
ans
;
ans
.
reserve
(
2
);
ans
.
push_back
(
View
(
&
lstm0_
));
ans
.
push_back
(
View
(
&
lstm1_
));
return
state
s
;
return
an
s
;
}
int32_t
ChunkSize
()
const
{
return
window_size_
;
}
...
...
@@ -218,7 +206,7 @@ class OnlineTransducerNeMoModel::Impl {
// - cache_last_channel
// - cache_last_time_
// - cache_last_channel_len
std
::
vector
<
Ort
::
Value
>
GetInitStates
()
{
std
::
vector
<
Ort
::
Value
>
Get
Encoder
InitStates
()
{
std
::
vector
<
Ort
::
Value
>
ans
;
ans
.
reserve
(
3
);
ans
.
push_back
(
View
(
&
cache_last_channel_
));
...
...
@@ -228,7 +216,75 @@ class OnlineTransducerNeMoModel::Impl {
return
ans
;
}
private
:
std
::
vector
<
Ort
::
Value
>
StackStates
(
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
states
)
const
{
int32_t
batch_size
=
static_cast
<
int32_t
>
(
states
.
size
());
if
(
batch_size
==
1
)
{
return
std
::
move
(
states
[
0
]);
}
std
::
vector
<
Ort
::
Value
>
ans
;
// stack cache_last_channel
std
::
vector
<
const
Ort
::
Value
*>
buf
(
batch_size
);
// there are 3 states to be stacked
for
(
int32_t
i
=
0
;
i
!=
3
;
++
i
)
{
buf
.
clear
();
buf
.
reserve
(
batch_size
);
for
(
int32_t
b
=
0
;
b
!=
batch_size
;
++
b
)
{
assert
(
states
[
b
].
size
()
==
3
);
buf
.
push_back
(
&
states
[
b
][
i
]);
}
Ort
::
Value
c
{
nullptr
};
if
(
i
==
2
)
{
c
=
Cat
<
int64_t
>
(
allocator_
,
buf
,
0
);
}
else
{
c
=
Cat
(
allocator_
,
buf
,
0
);
}
ans
.
push_back
(
std
::
move
(
c
));
}
return
ans
;
}
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
UnStackStates
(
std
::
vector
<
Ort
::
Value
>
states
)
const
{
assert
(
states
.
size
()
==
3
);
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
ans
;
auto
shape
=
states
[
0
].
GetTensorTypeAndShapeInfo
().
GetShape
();
int32_t
batch_size
=
shape
[
0
];
ans
.
resize
(
batch_size
);
if
(
batch_size
==
1
)
{
ans
[
0
]
=
std
::
move
(
states
);
return
ans
;
}
for
(
int32_t
i
=
0
;
i
!=
3
;
++
i
)
{
std
::
vector
<
Ort
::
Value
>
v
;
if
(
i
==
2
)
{
v
=
Unbind
<
int64_t
>
(
allocator_
,
&
states
[
i
],
0
);
}
else
{
v
=
Unbind
(
allocator_
,
&
states
[
i
],
0
);
}
assert
(
v
.
size
()
==
batch_size
);
for
(
int32_t
b
=
0
;
b
!=
batch_size
;
++
b
)
{
ans
[
b
].
push_back
(
std
::
move
(
v
[
b
]));
}
}
return
ans
;
}
private
:
void
InitEncoder
(
void
*
model_data
,
size_t
model_data_length
)
{
encoder_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
...
...
@@ -276,10 +332,10 @@ private:
normalize_type_
=
""
;
}
InitStates
();
Init
Encoder
States
();
}
void
InitStates
()
{
void
Init
Encoder
States
()
{
std
::
array
<
int64_t
,
4
>
cache_last_channel_shape
{
1
,
cache_last_channel_dim1_
,
cache_last_channel_dim2_
,
cache_last_channel_dim3_
};
...
...
@@ -314,6 +370,24 @@ private:
GetOutputNames
(
decoder_sess_
.
get
(),
&
decoder_output_names_
,
&
decoder_output_names_ptr_
);
InitDecoderStates
();
}
void
InitDecoderStates
()
{
int32_t
batch_size
=
1
;
std
::
array
<
int64_t
,
3
>
s0_shape
{
pred_rnn_layers_
,
batch_size
,
pred_hidden_
};
lstm0_
=
Ort
::
Value
::
CreateTensor
<
float
>
(
allocator_
,
s0_shape
.
data
(),
s0_shape
.
size
());
Fill
<
float
>
(
&
lstm0_
,
0
);
std
::
array
<
int64_t
,
3
>
s1_shape
{
pred_rnn_layers_
,
batch_size
,
pred_hidden_
};
lstm1_
=
Ort
::
Value
::
CreateTensor
<
float
>
(
allocator_
,
s1_shape
.
data
(),
s1_shape
.
size
());
Fill
<
float
>
(
&
lstm1_
,
0
);
}
void
InitJoiner
(
void
*
model_data
,
size_t
model_data_length
)
{
...
...
@@ -363,6 +437,7 @@ private:
int32_t
pred_rnn_layers_
=
-
1
;
int32_t
pred_hidden_
=
-
1
;
// encoder states
int32_t
cache_last_channel_dim1_
;
int32_t
cache_last_channel_dim2_
;
int32_t
cache_last_channel_dim3_
;
...
...
@@ -370,9 +445,14 @@ private:
int32_t
cache_last_time_dim2_
;
int32_t
cache_last_time_dim3_
;
// init encoder states
Ort
::
Value
cache_last_channel_
{
nullptr
};
Ort
::
Value
cache_last_time_
{
nullptr
};
Ort
::
Value
cache_last_channel_len_
{
nullptr
};
// init decoder states
Ort
::
Value
lstm0_
{
nullptr
};
Ort
::
Value
lstm1_
{
nullptr
};
};
OnlineTransducerNeMoModel
::
OnlineTransducerNeMoModel
(
...
...
@@ -387,9 +467,8 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
OnlineTransducerNeMoModel
::~
OnlineTransducerNeMoModel
()
=
default
;
std
::
vector
<
Ort
::
Value
>
OnlineTransducerNeMoModel
::
RunEncoder
(
Ort
::
Value
features
,
std
::
vector
<
Ort
::
Value
>
states
)
const
{
std
::
vector
<
Ort
::
Value
>
OnlineTransducerNeMoModel
::
RunEncoder
(
Ort
::
Value
features
,
std
::
vector
<
Ort
::
Value
>
states
)
const
{
return
impl_
->
RunEncoder
(
std
::
move
(
features
),
std
::
move
(
states
));
}
...
...
@@ -399,9 +478,9 @@ OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets,
return
impl_
->
RunDecoder
(
std
::
move
(
targets
),
std
::
move
(
states
));
}
std
::
vector
<
Ort
::
Value
>
OnlineTransducerNeMoModel
::
GetDecoderInitStates
(
int32_t
batch_size
)
const
{
return
impl_
->
GetDecoderInitStates
(
batch_size
);
std
::
vector
<
Ort
::
Value
>
OnlineTransducerNeMoModel
::
GetDecoderInitStates
()
const
{
return
impl_
->
GetDecoderInitStates
();
}
Ort
::
Value
OnlineTransducerNeMoModel
::
RunJoiner
(
Ort
::
Value
encoder_out
,
...
...
@@ -409,14 +488,13 @@ Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
return
impl_
->
RunJoiner
(
std
::
move
(
encoder_out
),
std
::
move
(
decoder_out
));
}
int32_t
OnlineTransducerNeMoModel
::
ChunkSize
()
const
{
return
impl_
->
ChunkSize
();
}
}
int32_t
OnlineTransducerNeMoModel
::
ChunkShift
()
const
{
return
impl_
->
ChunkShift
();
}
}
int32_t
OnlineTransducerNeMoModel
::
SubsamplingFactor
()
const
{
return
impl_
->
SubsamplingFactor
();
...
...
@@ -434,8 +512,19 @@ std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const {
return
impl_
->
FeatureNormalizationMethod
();
}
std
::
vector
<
Ort
::
Value
>
OnlineTransducerNeMoModel
::
GetInitStates
()
const
{
return
impl_
->
GetInitStates
();
std
::
vector
<
Ort
::
Value
>
OnlineTransducerNeMoModel
::
GetEncoderInitStates
()
const
{
return
impl_
->
GetEncoderInitStates
();
}
std
::
vector
<
Ort
::
Value
>
OnlineTransducerNeMoModel
::
StackStates
(
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
states
)
const
{
return
impl_
->
StackStates
(
std
::
move
(
states
));
}
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
OnlineTransducerNeMoModel
::
UnStackStates
(
std
::
vector
<
Ort
::
Value
>
states
)
const
{
return
impl_
->
UnStackStates
(
std
::
move
(
states
));
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-transducer-nemo-model.h
查看文件 @
082f230
...
...
@@ -38,15 +38,24 @@ class OnlineTransducerNeMoModel {
// - cache_last_channel
// - cache_last_time
// - cache_last_channel_len
std
::
vector
<
Ort
::
Value
>
GetInitStates
()
const
;
std
::
vector
<
Ort
::
Value
>
GetEncoderInitStates
()
const
;
// stack encoder states
std
::
vector
<
Ort
::
Value
>
StackStates
(
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
states
)
const
;
// unstack encoder states
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
UnStackStates
(
std
::
vector
<
Ort
::
Value
>
states
)
const
;
/** Run the encoder.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
* @param states It is from GetInitStates() or returned from this method.
* @param states It is from GetEncoderInitStates() or returned from this
* method.
*
* @return Return a tuple containing:
* - ans[0]: encoder_out, a tensor of shape (N,
T', encoder_out_dim
)
* - ans[0]: encoder_out, a tensor of shape (N,
encoder_out_dim, T'
)
* - ans[1:]: contains next states
*/
std
::
vector
<
Ort
::
Value
>
RunEncoder
(
...
...
@@ -63,7 +72,7 @@ class OnlineTransducerNeMoModel {
std
::
pair
<
Ort
::
Value
,
std
::
vector
<
Ort
::
Value
>>
RunDecoder
(
Ort
::
Value
targets
,
std
::
vector
<
Ort
::
Value
>
states
)
const
;
std
::
vector
<
Ort
::
Value
>
GetDecoderInitStates
(
int32_t
batch_size
)
const
;
std
::
vector
<
Ort
::
Value
>
GetDecoderInitStates
()
const
;
/** Run the joint network.
*
...
...
@@ -71,9 +80,7 @@ class OnlineTransducerNeMoModel {
* @param decoder_out Output of the decoder network.
* @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits.
*/
Ort
::
Value
RunJoiner
(
Ort
::
Value
encoder_out
,
Ort
::
Value
decoder_out
)
const
;
Ort
::
Value
RunJoiner
(
Ort
::
Value
encoder_out
,
Ort
::
Value
decoder_out
)
const
;
/** We send this number of feature frames to the encoder at a time. */
int32_t
ChunkSize
()
const
;
...
...
@@ -117,7 +124,7 @@ class OnlineTransducerNeMoModel {
private
:
class
Impl
;
std
::
unique_ptr
<
Impl
>
impl_
;
};
};
}
// namespace sherpa_onnx
...
...
请
注册
或
登录
后发表评论