Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-10-12 11:59:19 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-10-12 11:59:19 +0800
Commit
32da5ecf022d856141b4cbed5ed444d94b441d79
32da5ecf
1 parent
98b67ad8
Add script to convert vits models (#355)
显示空白字符变更
内嵌
并排对比
正在显示
4 个修改的文件
包含
307 行增加
和
0 行删除
.github/workflows/export-vits-ljspeech-to-onnx.yaml
scripts/vits/.gitignore
scripts/vits/__init__.py
scripts/vits/export-onnx-ljs.py
.github/workflows/export-vits-ljspeech-to-onnx.yaml
0 → 100644
查看文件 @
32da5ec
# Exports the VITS LJ Speech checkpoint to ONNX, runs the downloaded test
# script, and uploads the wave files it generates.
name: export-vits-ljspeech-to-onnx

on:
  push:
    branches:
      - master
    paths:
      - 'scripts/vits/**'
      - '.github/workflows/export-vits-ljspeech-to-onnx.yaml'
  pull_request:
    paths:
      - 'scripts/vits/**'
      - '.github/workflows/export-vits-ljspeech-to-onnx.yaml'

  workflow_dispatch:

# A newer run on the same ref cancels the one still in progress.
concurrency:
  group: export-vits-ljspeech-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-vits-ljspeech-onnx:
    # Only run in the upstream repository and the maintainer's fork.
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: vits ljspeech
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        torch: ["1.13.0"]

    steps:
      - uses: actions/checkout@v4

      - name: Install dependencies
        shell: bash
        run: |
          python3 -m pip install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html numpy
          python3 -m pip install onnxruntime onnx soundfile
          python3 -m pip install scipy cython unidecode phonemizer

          # required by phonemizer
          # See https://bootphon.github.io/phonemizer/install.html
          # To fix the following error: RuntimeError: espeak not installed on your system
          #
          # Refresh the package index first and pass -y so the install never
          # depends on a stale index or on an interactive confirmation.
          sudo apt-get update
          sudo apt-get install -y festival espeak-ng mbrola

      - name: export vits ljspeech
        shell: bash
        run: |
          cd scripts/vits

          echo "Downloading vits"
          git clone https://github.com/jaywalnut310/vits
          pushd vits/monotonic_align
          python3 setup.py build
          ls -lh build/
          ls -lh build/lib*/
          ls -lh build/lib*/*/

          # Copy the compiled extension next to __init__.py and patch the
          # import so that it is picked up from there.
          cp build/lib*/monotonic_align/core*.so .
          sed -i.bak s/.monotonic_align.core/.core/g ./__init__.py
          git diff
          popd

          export PYTHONPATH=$PWD/vits:$PYTHONPATH

          echo "Download models"

          wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth
          wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
          wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
          wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/test.py

          python3 ./export-onnx-ljs.py --config vits/configs/ljs_base.json --checkpoint ./pretrained_ljs.pth
          python3 ./test.py
          ls -lh *.wav

      # upload-artifact v3 is deprecated; v4 requires unique artifact names
      # within a run, which the three names below already satisfy.
      - uses: actions/upload-artifact@v4
        with:
          name: test-0.wav
          path: scripts/vits/test-0.wav

      - uses: actions/upload-artifact@v4
        with:
          name: test-1.wav
          path: scripts/vits/test-1.wav

      - uses: actions/upload-artifact@v4
        with:
          name: test-2.wav
          path: scripts/vits/test-2.wav
...
...
scripts/vits/.gitignore
0 → 100644
查看文件 @
32da5ec
tokens-ljs.txt
...
...
scripts/vits/__init__.py
0 → 100644
查看文件 @
32da5ec
scripts/vits/export-onnx-ljs.py
0 → 100755
查看文件 @
32da5ec
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This script converts vits models trained using the LJ Speech dataset.
Usage:
(1) Download vits
cd /Users/fangjun/open-source
git clone https://github.com/jaywalnut310/vits
(2) Download pre-trained models from
https://huggingface.co/csukuangfj/vits-ljs/tree/main
wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth
(3) Run this file
./export-onnx-ljs.py \
  --config ~/open-source/vits/configs/ljs_base.json \
  --checkpoint ~/open-source/icefall-models/vits-ljs/pretrained_ljs.pth
It will generate the following two files:
$ ls -lh *.onnx
-rw-r--r-- 1 fangjun staff 36M Oct 10 20:48 vits-ljs.int8.onnx
-rw-r--r-- 1 fangjun staff 109M Oct 10 20:48 vits-ljs.onnx
"""
import
sys
# Please change this line to point to the vits directory.
# You can download vits from
# https://github.com/jaywalnut310/vits
sys
.
path
.
insert
(
0
,
"/Users/fangjun/open-source/vits"
)
# noqa
import
argparse
from
pathlib
import
Path
from
typing
import
Dict
,
Any
import
commons
import
onnx
import
torch
import
utils
from
models
import
SynthesizerTrn
from
onnxruntime.quantization
import
QuantType
,
quantize_dynamic
from
text
import
text_to_sequence
from
text.symbols
import
symbols
from
text.symbols
import
_punctuation
def get_args():
    """Parse the command-line options of this script.

    Returns:
      The namespace produced by ``argparse``. It has two string attributes,
      ``config`` and ``checkpoint``; both options are required.
    """
    cli = argparse.ArgumentParser()

    # (flag, help text) for every required path option.
    required_paths = (
        (
            "--config",
            "Path to ljs_base.json.\n"
            "You can find it at\n"
            "https://huggingface.co/csukuangfj/vits-ljs/resolve/main/ljs_base.json\n",
        ),
        (
            "--checkpoint",
            "Path to the checkpoint file.\n"
            "You can find it at\n"
            "https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth\n",
        ),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, type=str, required=True, help=description)

    return cli.parse_args()
class OnnxModel(torch.nn.Module):
    """Adapter whose ``forward`` runs the wrapped model's ``infer`` method.

    Only the first element of whatever ``infer`` returns is passed on, so the
    module has a single output.
    """

    def __init__(self, model: SynthesizerTrn):
        """Store the model to wrap.

        Args:
          model:
            The model whose ``infer`` method is invoked by ``forward``.
        """
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        noise_scale=1,
        length_scale=1,
        noise_scale_w=1.0,
        sid=None,
        max_len=None,
    ):
        """Forward every argument to ``self.model.infer`` by keyword.

        Returns:
          Element 0 of the value returned by ``self.model.infer``
          (presumably the generated audio -- confirm against the vits code).
        """
        infer_kwargs = {
            "x": x,
            "x_lengths": x_lengths,
            "sid": sid,
            "noise_scale": noise_scale,
            "length_scale": length_scale,
            "noise_scale_w": noise_scale_w,
            "max_len": max_len,
        }
        outputs = self.model.infer(**infer_kwargs)
        return outputs[0]
def get_text(text, hps):
    """Convert a piece of text into a ``torch.LongTensor`` of symbol ids.

    Args:
      text:
        The input text.
      hps:
        Hyper-parameters. ``hps.data.text_cleaners`` is passed to
        ``text_to_sequence`` and ``hps.data.add_blank`` decides whether a 0
        is interspersed into the resulting sequence.

    Returns:
      The id sequence wrapped in a ``torch.LongTensor``.
    """
    data_cfg = hps.data
    ids = text_to_sequence(text, data_cfg.text_cleaners)
    if data_cfg.add_blank:
        ids = commons.intersperse(ids, 0)
    return torch.LongTensor(ids)
def check_args(args):
    """Validate that the files named on the command line exist.

    Args:
      args:
        Parsed command-line arguments. Its ``config`` and ``checkpoint``
        attributes are paths that must refer to existing regular files.

    Raises:
      FileNotFoundError:
        If either path is not an existing regular file.
    """
    # An explicit raise is used instead of ``assert`` because assertions are
    # removed when Python runs with -O, which would silently skip the check.
    for option in ("config", "checkpoint"):
        path = getattr(args, option)
        if not Path(path).is_file():
            raise FileNotFoundError(f"--{option}: {path} is not a file")
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Attach key-value metadata to an ONNX model file, rewriting it in place.

    Args:
      filename:
        Path of the ONNX model to update. The modified model is saved back
        to this same path.
      meta_data:
        Key-value pairs to store. Every value is converted with ``str``
        before it is written.
    """
    onnx_model = onnx.load(filename)

    props = onnx_model.metadata_props
    for name, raw_value in meta_data.items():
        entry = props.add()
        entry.key = name
        entry.value = str(raw_value)

    onnx.save(onnx_model, filename)
def generate_tokens():
    """Write the symbol table to ``tokens-ljs.txt`` in the current directory.

    Each line has the form ``<symbol> <index>``, where the index is the
    position of the symbol inside ``symbols``.
    """
    with open("tokens-ljs.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{s} {i}\n" for i, s in enumerate(symbols))

    print("Generated tokens-ljs.txt")
@torch.no_grad()
def main():
    """Export the LJ Speech VITS checkpoint to ONNX, plus an int8 copy.

    Steps, in order:
      1. Parse and validate the command-line arguments.
      2. Write the token table via ``generate_tokens``.
      3. Build ``SynthesizerTrn`` from the config and load the checkpoint.
      4. Export the wrapped model, using one example sentence as input, to
         ``vits-ljs.onnx``.
      5. Store metadata in the exported file.
      6. Write a dynamically quantized copy to ``vits-ljs.int8.onnx``.
    """
    args = get_args()
    check_args(args)

    generate_tokens()

    hps = utils.get_hparams_from_file(args.config)

    spec_channels = hps.data.filter_length // 2 + 1
    segment_frames = hps.train.segment_size // hps.data.hop_length
    net_g = SynthesizerTrn(
        len(symbols),
        spec_channels,
        segment_frames,
        **hps.model,
    )
    net_g.eval()
    utils.load_checkpoint(args.checkpoint, net_g, None)

    # Example inputs for the export: a batch that holds a single sentence.
    x = get_text("Liliana is the most beautiful assistant", hps).unsqueeze(0)
    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    length_scale = torch.tensor([1], dtype=torch.float32)
    noise_scale_w = torch.tensor([1], dtype=torch.float32)
    example_inputs = (x, x_length, noise_scale, length_scale, noise_scale_w)

    input_names = [
        "x",
        "x_length",
        "noise_scale",
        "length_scale",
        "noise_scale_w",
    ]
    # Axes that are allowed to vary at run time, with their symbolic names.
    dynamic_axes = {
        "x": {0: "N", 1: "L"},
        "x_length": {0: "N"},
        "y": {0: "N", 2: "L"},
    }

    filename = "vits-ljs.onnx"
    torch.onnx.export(
        OnnxModel(net_g),
        example_inputs,
        filename,
        opset_version=13,
        input_names=input_names,
        output_names=["y"],
        dynamic_axes=dynamic_axes,
    )

    meta_data = {
        "model_type": "vits",
        "comment": "ljspeech",
        "language": "English",
        "add_blank": int(hps.data.add_blank),
        "sample_rate": hps.data.sampling_rate,
        # The punctuation characters, separated by single spaces.
        "punctuation": " ".join(_punctuation),
    }
    print("meta_data", meta_data)
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")

    filename_int8 = "vits-ljs.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    print(f"Saved to {filename} and {filename_int8}")


if __name__ == "__main__":
    main()
...
...
请
注册
或
登录
后发表评论