Authored by Fangjun Kuang, 2024-05-08 19:07:49 +0800
Committed by GitHub, 2024-05-08 19:07:49 +0800
Commit 68b25abf2712a4e6b5fa5f846fd9c23f72f5f860 (68b25abf)
1 parent a9f936e9

Export NeMo FastConformer Hybrid Transducer Large Streaming to ONNX (#844)
Showing 9 changed files with 611 additions and 1 deletion
.github/workflows/export-nemo-fast-conformer-hybrid-transducer-ctc.yaml
.github/workflows/export-nemo-fast-conformer-hybrid-transducer-transducer.yaml
scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-ctc.py
scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-transducer.py
scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-ctc.sh
scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-transducer.sh
scripts/nemo/fast-conformer-hybrid-transducer-ctc/show-onnx-transudcer.py
scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-ctc.py
scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-transducer.py
.github/workflows/export-nemo-fast-conformer-hybrid-transducer-ctc.yaml

-name: export-nemo-speaker-verification-to-onnx
+name: export-nemo-fast-conformer-ctc-to-onnx

 on:
   workflow_dispatch:
...
.github/workflows/export-nemo-fast-conformer-hybrid-transducer-transducer.yaml (new file, mode 100644)

name: export-nemo-fast-conformer-transducer-to-onnx

on:
  workflow_dispatch:

concurrency:
  group: export-nemo-fast-conformer-hybrid-transducer-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-nemo-fast-conformer-hybrid-transducer-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: NeMo transducer
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install NeMo
        shell: bash
        run: |
          BRANCH='main'
          pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]
          pip install onnxruntime
          pip install kaldi-native-fbank
          pip install soundfile librosa

      - name: Run
        shell: bash
        run: |
          cd scripts/nemo/fast-conformer-hybrid-transducer-ctc
          ./run-transducer.sh
          mv -v sherpa-onnx-nemo* ../../..

      - name: Download test waves
        shell: bash
        run: |
          mkdir test_wavs
          pushd test_wavs
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/0.wav
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/1.wav
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/8k.wav
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/trans.txt
          popd

          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms
          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms
          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms

          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms
          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms
          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models
scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-ctc.py

#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
from typing import Dict
...
scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-transducer.py (new file, mode 100755)

#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
from typing import Dict

import nemo.collections.asr as nemo_asr
import onnx
import torch


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=["80", "480", "1040"],
    )
    return parser.parse_args()


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


@torch.no_grad()
def main():
    args = get_args()
    model_name = f"stt_en_fastconformer_hybrid_large_streaming_{args.model}ms"
    asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)

    with open("./tokens.txt", "w", encoding="utf-8") as f:
        for i, s in enumerate(asr_model.joint.vocabulary):
            f.write(f"{s} {i}\n")

        f.write(f"<blk> {i+1}\n")
    print("Saved to tokens.txt")

    decoder_type = "rnnt"
    asr_model.change_decoding_strategy(decoder_type=decoder_type)
    asr_model.eval()

    assert asr_model.encoder.streaming_cfg is not None
    if isinstance(asr_model.encoder.streaming_cfg.chunk_size, list):
        chunk_size = asr_model.encoder.streaming_cfg.chunk_size[1]
    else:
        chunk_size = asr_model.encoder.streaming_cfg.chunk_size

    if isinstance(asr_model.encoder.streaming_cfg.pre_encode_cache_size, list):
        pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1]
    else:
        pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size

    window_size = chunk_size + pre_encode_cache_size
    print("chunk_size", chunk_size)
    print("pre_encode_cache_size", pre_encode_cache_size)
    print("window_size", window_size)

    chunk_shift = chunk_size

    # cache_last_channel: (batch_size, dim1, dim2, dim3)
    cache_last_channel_dim1 = len(asr_model.encoder.layers)
    cache_last_channel_dim2 = asr_model.encoder.streaming_cfg.last_channel_cache_size
    cache_last_channel_dim3 = asr_model.encoder.d_model

    # cache_last_time: (batch_size, dim1, dim2, dim3)
    cache_last_time_dim1 = len(asr_model.encoder.layers)
    cache_last_time_dim2 = asr_model.encoder.d_model
    cache_last_time_dim3 = asr_model.encoder.conv_context_size[0]

    asr_model.set_export_config({"decoder_type": "rnnt", "cache_support": True})

    # asr_model.export("model.onnx")
    asr_model.encoder.export("encoder.onnx")
    asr_model.decoder.export("decoder.onnx")
    asr_model.joint.export("joiner.onnx")
    # model.onnx is a suffix.
    # It will generate two files:
    #   encoder-model.onnx
    #   decoder_joint-model.onnx

    meta_data = {
        "vocab_size": asr_model.decoder.vocab_size,
        "window_size": window_size,
        "chunk_shift": chunk_shift,
        "normalize_type": "None",
        "cache_last_channel_dim1": cache_last_channel_dim1,
        "cache_last_channel_dim2": cache_last_channel_dim2,
        "cache_last_channel_dim3": cache_last_channel_dim3,
        "cache_last_time_dim1": cache_last_time_dim1,
        "cache_last_time_dim2": cache_last_time_dim2,
        "cache_last_time_dim3": cache_last_time_dim3,
        "pred_rnn_layers": asr_model.decoder.pred_rnn_layers,
        "pred_hidden": asr_model.decoder.pred_hidden,
        "subsampling_factor": 8,
        "model_type": "EncDecHybridRNNTCTCBPEModel",
        "version": "1",
        "model_author": "NeMo",
        "url": f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}",
        "comment": "Only the transducer branch is exported",
    }
    add_meta_data("encoder.onnx", meta_data)
    print(meta_data)


if __name__ == "__main__":
    main()
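A quick way to confirm that add_meta_data() above actually wrote the key/value pairs into encoder.onnx is to load the model back and print its metadata_props. A minimal sketch, assuming it is run in the directory where encoder.onnx was just written:

# Sketch: read back the metadata written by add_meta_data().
import onnx

m = onnx.load("encoder.onnx")
print({p.key: p.value for p in m.metadata_props})

The same information is what test-onnx-transducer.py later reads at run time via onnxruntime's custom_metadata_map.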
scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-ctc.sh

#!/usr/bin/env bash
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

set -ex
...
scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-transducer.sh (new file, mode 100755)

#!/usr/bin/env bash
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

set -ex

if [ ! -e ./0.wav ]; then
  # curl -SL -O https://hf-mirror.com/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav
  curl -SL -O https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav
fi

ms=(
80
480
1040
)

for m in ${ms[@]}; do
  ./export-onnx-transducer.py --model $m
  d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms
  if [ ! -f $d/encoder.onnx ]; then
    mkdir -p $d
    mv -v encoder.onnx $d/
    mv -v decoder.onnx $d/
    mv -v joiner.onnx $d/
    mv -v tokens.txt $d/
    ls -lh $d
  fi
done

# Now test the exported models
for m in ${ms[@]}; do
  d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms

  python3 ./test-onnx-transducer.py \
    --encoder $d/encoder.onnx \
    --decoder $d/decoder.onnx \
    --joiner $d/joiner.onnx \
    --tokens $d/tokens.txt \
    --wav ./0.wav
done
scripts/nemo/fast-conformer-hybrid-transducer-ctc/show-onnx-transudcer.py (new file, mode 100755)

#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import onnxruntime


def show(filename):
    session_opts = onnxruntime.SessionOptions()
    session_opts.log_severity_level = 3
    sess = onnxruntime.InferenceSession(filename, session_opts)
    for i in sess.get_inputs():
        print(i)

    print("-----")

    for i in sess.get_outputs():
        print(i)


def main():
    print("=========encoder==========")
    show("./encoder.onnx")

    print("=========decoder==========")
    show("./decoder.onnx")

    print("=========joiner==========")
    show("./joiner.onnx")


if __name__ == "__main__":
    main()

"""
=========encoder==========
NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 80, 'audio_signal_dynamic_axes_2'])
NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1'])
NodeArg(name='cache_last_channel', type='tensor(float)', shape=['cache_last_channel_dynamic_axes_1', 17, 'cache_last_channel_dynamic_axes_2', 512])
NodeArg(name='cache_last_time', type='tensor(float)', shape=['cache_last_time_dynamic_axes_1', 17, 512, 'cache_last_time_dynamic_axes_2'])
NodeArg(name='cache_last_channel_len', type='tensor(int64)', shape=['cache_last_channel_len_dynamic_axes_1'])
-----
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 512, 'outputs_dynamic_axes_2'])
NodeArg(name='encoded_lengths', type='tensor(int64)', shape=['encoded_lengths_dynamic_axes_1'])
NodeArg(name='cache_last_channel_next', type='tensor(float)', shape=['cache_last_channel_next_dynamic_axes_1', 17, 'cache_last_channel_next_dynamic_axes_2', 512])
NodeArg(name='cache_last_time_next', type='tensor(float)', shape=['cache_last_time_next_dynamic_axes_1', 17, 512, 'cache_last_time_next_dynamic_axes_2'])
NodeArg(name='cache_last_channel_next_len', type='tensor(int64)', shape=['cache_last_channel_next_len_dynamic_axes_1'])
=========decoder==========
NodeArg(name='targets', type='tensor(int32)', shape=['targets_dynamic_axes_1', 'targets_dynamic_axes_2'])
NodeArg(name='target_length', type='tensor(int32)', shape=['target_length_dynamic_axes_1'])
NodeArg(name='states.1', type='tensor(float)', shape=[1, 'states.1_dim_1', 640])
NodeArg(name='onnx::LSTM_3', type='tensor(float)', shape=[1, 1, 640])
-----
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 640, 'outputs_dynamic_axes_2'])
NodeArg(name='prednet_lengths', type='tensor(int32)', shape=['prednet_lengths_dynamic_axes_1'])
NodeArg(name='states', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 640])
NodeArg(name='74', type='tensor(float)', shape=[1, 'LSTM74_dim_1', 640])
=========joiner==========
NodeArg(name='encoder_outputs', type='tensor(float)', shape=['encoder_outputs_dynamic_axes_1', 512, 'encoder_outputs_dynamic_axes_2'])
NodeArg(name='decoder_outputs', type='tensor(float)', shape=['decoder_outputs_dynamic_axes_1', 640, 'decoder_outputs_dynamic_axes_2'])
-----
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 'outputs_dynamic_axes_2', 'outputs_dynamic_axes_3', 1025])
"""
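Note in the decoder dump above that the ONNX exporter left auto-generated names on the second LSTM state ('onnx::LSTM_3' on the input side, '74' on the output side). test-onnx-transducer.py below therefore binds the decoder's inputs and outputs positionally via get_inputs()/get_outputs() instead of hard-coding names; a minimal sketch of that pattern, assuming decoder.onnx is in the current directory:

# Sketch: bind the decoder's inputs/outputs by position, not by name,
# since two of the names are auto-generated by the exporter.
import onnxruntime as ort

sess = ort.InferenceSession("./decoder.onnx", providers=["CPUExecutionProvider"])
input_names = [i.name for i in sess.get_inputs()]    # targets, target_length, states.1, onnx::LSTM_3
output_names = [o.name for o in sess.get_outputs()]  # outputs, prednet_lengths, states, 74
print(input_names)
print(output_names)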
scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-ctc.py

#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
from pathlib import Path
...
scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-transducer.py (new file, mode 100755)

#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
from pathlib import Path

import kaldi_native_fbank as knf
import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--encoder", type=str, required=True, help="Path to encoder.onnx")

    parser.add_argument("--decoder", type=str, required=True, help="Path to decoder.onnx")

    parser.add_argument("--joiner", type=str, required=True, help="Path to joiner.onnx")

    parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt")

    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")

    return parser.parse_args()


def create_fbank():
    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.remove_dc_offset = False
    opts.frame_opts.window_type = "hann"

    opts.mel_opts.low_freq = 0
    opts.mel_opts.num_bins = 80
    opts.mel_opts.is_librosa = True

    fbank = knf.OnlineFbank(opts)
    return fbank


def compute_features(audio, fbank):
    assert len(audio.shape) == 1, audio.shape
    fbank.accept_waveform(16000, audio)
    ans = []
    processed = 0
    while processed < fbank.num_frames_ready:
        ans.append(np.array(fbank.get_frame(processed)))
        processed += 1
    ans = np.stack(ans)
    return ans


class OnnxModel:
    def __init__(
        self,
        encoder: str,
        decoder: str,
        joiner: str,
    ):
        self.init_encoder(encoder)
        self.init_decoder(decoder)
        self.init_joiner(joiner)

    def init_encoder(self, encoder):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.encoder = ort.InferenceSession(
            encoder,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.encoder.get_modelmeta().custom_metadata_map
        print(meta)

        self.window_size = int(meta["window_size"])
        self.chunk_shift = int(meta["chunk_shift"])

        self.cache_last_channel_dim1 = int(meta["cache_last_channel_dim1"])
        self.cache_last_channel_dim2 = int(meta["cache_last_channel_dim2"])
        self.cache_last_channel_dim3 = int(meta["cache_last_channel_dim3"])

        self.cache_last_time_dim1 = int(meta["cache_last_time_dim1"])
        self.cache_last_time_dim2 = int(meta["cache_last_time_dim2"])
        self.cache_last_time_dim3 = int(meta["cache_last_time_dim3"])

        self.pred_rnn_layers = int(meta["pred_rnn_layers"])
        self.pred_hidden = int(meta["pred_hidden"])

        self.init_cache_state()

    def init_decoder(self, decoder):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.decoder = ort.InferenceSession(
            decoder,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_joiner(self, joiner):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.joiner = ort.InferenceSession(
            joiner,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def get_decoder_state(self):
        batch_size = 1
        state0 = torch.zeros(self.pred_rnn_layers, batch_size, self.pred_hidden).numpy()
        state1 = torch.zeros(self.pred_rnn_layers, batch_size, self.pred_hidden).numpy()
        return state0, state1

    def init_cache_state(self):
        self.cache_last_channel = torch.zeros(
            1,
            self.cache_last_channel_dim1,
            self.cache_last_channel_dim2,
            self.cache_last_channel_dim3,
            dtype=torch.float32,
        ).numpy()

        self.cache_last_time = torch.zeros(
            1,
            self.cache_last_time_dim1,
            self.cache_last_time_dim2,
            self.cache_last_time_dim3,
            dtype=torch.float32,
        ).numpy()

        self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy()

    def run_encoder(self, x: np.ndarray):
        # x: (T, C)
        x = torch.from_numpy(x)
        x = x.t().unsqueeze(0)
        # x: [1, C, T]
        x_lens = torch.tensor([x.shape[-1]], dtype=torch.int64)

        (
            encoder_out,
            out_len,
            cache_last_channel_next,
            cache_last_time_next,
            cache_last_channel_len_next,
        ) = self.encoder.run(
            [
                self.encoder.get_outputs()[0].name,
                self.encoder.get_outputs()[1].name,
                self.encoder.get_outputs()[2].name,
                self.encoder.get_outputs()[3].name,
                self.encoder.get_outputs()[4].name,
            ],
            {
                self.encoder.get_inputs()[0].name: x.numpy(),
                self.encoder.get_inputs()[1].name: x_lens.numpy(),
                self.encoder.get_inputs()[2].name: self.cache_last_channel,
                self.encoder.get_inputs()[3].name: self.cache_last_time,
                self.encoder.get_inputs()[4].name: self.cache_last_channel_len,
            },
        )
        self.cache_last_channel = cache_last_channel_next
        self.cache_last_time = cache_last_time_next
        self.cache_last_channel_len = cache_last_channel_len_next

        # [batch_size, dim, T]
        return encoder_out

    def run_decoder(
        self,
        token: int,
        state0: np.ndarray,
        state1: np.ndarray,
    ):
        target = torch.tensor([[token]], dtype=torch.int32).numpy()
        target_len = torch.tensor([1], dtype=torch.int32).numpy()

        (
            decoder_out,
            decoder_out_length,
            state0_next,
            state1_next,
        ) = self.decoder.run(
            [
                self.decoder.get_outputs()[0].name,
                self.decoder.get_outputs()[1].name,
                self.decoder.get_outputs()[2].name,
                self.decoder.get_outputs()[3].name,
            ],
            {
                self.decoder.get_inputs()[0].name: target,
                self.decoder.get_inputs()[1].name: target_len,
                self.decoder.get_inputs()[2].name: state0,
                self.decoder.get_inputs()[3].name: state1,
            },
        )
        return decoder_out, state0_next, state1_next

    def run_joiner(
        self,
        encoder_out: np.ndarray,
        decoder_out: np.ndarray,
    ):
        # encoder_out: [batch_size, dim, 1]
        # decoder_out: [batch_size, dim, 1]
        logit = self.joiner.run(
            [
                self.joiner.get_outputs()[0].name,
            ],
            {
                self.joiner.get_inputs()[0].name: encoder_out,
                self.joiner.get_inputs()[1].name: decoder_out,
            },
        )[0]
        # logit: [batch_size, 1, 1, vocab_size]
        return logit


def main():
    args = get_args()
    assert Path(args.encoder).is_file(), args.encoder
    assert Path(args.decoder).is_file(), args.decoder
    assert Path(args.joiner).is_file(), args.joiner
    assert Path(args.tokens).is_file(), args.tokens
    assert Path(args.wav).is_file(), args.wav

    print(vars(args))

    model = OnnxModel(args.encoder, args.decoder, args.joiner)

    id2token = dict()
    with open(args.tokens, encoding="utf-8") as f:
        for line in f:
            t, idx = line.split()
            id2token[int(idx)] = t

    fbank = create_fbank()
    audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    tail_padding = np.zeros(sample_rate * 2)

    audio = np.concatenate([audio, tail_padding])

    window_size = model.window_size
    chunk_shift = model.chunk_shift

    blank = len(id2token) - 1
    ans = [blank]
    state0, state1 = model.get_decoder_state()
    decoder_out, state0_next, state1_next = model.run_decoder(ans[-1], state0, state1)

    features = compute_features(audio, fbank)
    num_chunks = (features.shape[0] - window_size) // chunk_shift + 1
    for i in range(num_chunks):
        start = i * chunk_shift
        end = start + window_size
        chunk = features[start:end, :]

        encoder_out = model.run_encoder(chunk)
        # encoder_out: [batch_size, dim, T]
        for t in range(encoder_out.shape[2]):
            encoder_out_t = encoder_out[:, :, t : t + 1]
            logits = model.run_joiner(encoder_out_t, decoder_out)
            logits = torch.from_numpy(logits)
            logits = logits.squeeze()
            idx = torch.argmax(logits, dim=-1).item()
            if idx != blank:
                ans.append(idx)
                state0 = state0_next
                state1 = state1_next
                decoder_out, state0_next, state1_next = model.run_decoder(
                    ans[-1],
                    state0,
                    state1,
                )

    ans = ans[1:]  # remove the first blank
    tokens = [id2token[i] for i in ans]
    underline = "▁"
    # underline = b"\xe2\x96\x81".decode()
    text = "".join(tokens).replace(underline, " ").strip()
    print(args.wav)
    print(text)


main()
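For reference, the chunking arithmetic used in main() above: the feature matrix is split into windows of window_size frames that advance by chunk_shift frames, so the number of full chunks is (num_frames - window_size) // chunk_shift + 1. A small self-contained check with placeholder numbers (not the real model's window_size/chunk_shift, which are read from the encoder metadata at run time):

# Placeholder values for illustration only; the real values come from encoder.onnx metadata.
def num_chunks(num_frames: int, window_size: int, chunk_shift: int) -> int:
    return (num_frames - window_size) // chunk_shift + 1

assert num_chunks(num_frames=1000, window_size=100, chunk_shift=80) == 12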