Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2024-07-15 10:47:19 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-07-15 10:47:19 +0800
Commit
04c2319c2c48a7a0a477c26833a89bdfbb853e5f
04c2319c
1 parent
de04b3b9
Export MeloTTS to ONNX (#1129)
隐藏空白字符变更
内嵌
并排对比
正在显示
4 个修改的文件
包含
573 行增加
和
0 行删除
.github/workflows/export-melo-tts-to-onnx.yaml
scripts/melo-tts/export-onnx.py
scripts/melo-tts/run.sh
scripts/melo-tts/test.py
.github/workflows/export-melo-tts-to-onnx.yaml
0 → 100644
查看文件 @
04c2319
# Workflow that exports MeloTTS to ONNX via scripts/melo-tts/run.sh, pushes the
# resulting files to a Hugging Face repository and uploads a tarball to the
# "tts-models" release of k2-fsa/sherpa-onnx.
name: export-melo-tts-to-onnx

on:
  push:
    branches:
      - export-melo-tts-onnx

  # Also allow starting the workflow by hand.
  workflow_dispatch:

concurrency:
  group: export-melo-tts-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-melo-tts-to-onnx:
    # Skip the job in forks owned by anyone else.
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export melo-tts
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        # Quoted so that YAML does not read the version as the number 3.1.
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run
        shell: bash
        run: |
          cd scripts/melo-tts
          ./run.sh

      # test.wav is written by scripts/melo-tts/test.py (invoked from run.sh).
      - uses: actions/upload-artifact@v4
        with:
          name: test.wav
          path: scripts/melo-tts/test.wav

      # NOTE(review): the step name says "(aishell)" but the target repository
      # is csukuangfj/vits-melo-tts-zh_en - looks like a copy-paste leftover.
      - name: Publish to huggingface (aishell)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            git clone https://huggingface.co/csukuangfj/vits-melo-tts-zh_en huggingface
            cd huggingface
            git fetch
            git pull
            echo "pwd: $PWD"
            ls -lh ../scripts/melo-tts

            cp -v ../scripts/melo-tts/*.onnx .
            cp -v ../scripts/melo-tts/lexicon.txt .
            cp -v ../scripts/melo-tts/tokens.txt .

            curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst
            curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst
            curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst

            curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2
            tar xvf dict.tar.bz2
            rm dict.tar.bz2

            git lfs track "*.onnx"
            git add .

            git commit -m "add models"
            # NOTE(review): "|| true" hides a failed push, so the tarball below
            # is still built and released even when Hugging Face was not updated.
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-zh_en main || true

            cd ..

            # Drop git metadata, then package the directory for the release step.
            rm -rf huggingface/.git*
            dst=vits-melo-tts-zh_en

            mv huggingface $dst

            tar cjvf $dst.tar.bz2 $dst
            rm -rf $dst

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: tts-models
...
...
scripts/melo-tts/export-onnx.py
0 → 100755
查看文件 @
04c2319
#!/usr/bin/env python3
from
typing
import
Any
,
Dict
import
onnx
import
torch
from
melo.api
import
TTS
from
melo.text
import
language_id_map
,
language_tone_start_map
from
melo.text.chinese
import
pinyin_to_symbol_map
from
pypinyin
import
Style
,
lazy_pinyin
,
phrases_dict
,
pinyin_dict
# Replace every value of the imported pinyin -> symbols table with the list
# obtained by splitting it on whitespace, so that later lookups yield a list
# of symbols. The table is updated in place (same dict object, same keys).
pinyin_to_symbol_map.update(
    {
        pinyin: symbols.split()
        for pinyin, symbols in pinyin_to_symbol_map.items()
    }
)
def get_initial_final_tone(word: str):
    """Convert a word or phrase to phone symbols and per-phone tones.

    The word is split into syllables with pypinyin (initials and finals with a
    trailing tone digit), each syllable is normalised to the spelling used as
    key in ``pinyin_to_symbol_map`` and then replaced by its symbols.

    Args:
      word:
        The text to convert; a single character or a phrase.

    Returns:
      A tuple ``(phones, tones)`` of two lists with the same length. ``tones``
      holds the tone digit of the syllable (as a one-character string),
      repeated once per phone of that syllable. ``([], [])`` is returned when
      some syllable of ``word`` has an empty final. Syllables whose normalised
      spelling is not in ``pinyin_to_symbol_map`` are reported and left out.
    """
    # Finals that are written in a contracted form after an initial.
    v_rep_map = {
        "uei": "ui",
        "iou": "iu",
        "uen": "un",
    }
    # Whole syllables without an initial that have a dedicated spelling.
    pinyin_rep_map = {
        "ing": "ying",
        "i": "yi",
        "in": "yin",
        "u": "wu",
    }
    # Remaining syllables without an initial: only the first letter changes.
    single_rep_map = {
        "v": "yu",
        "e": "e",
        "i": "y",
        "u": "w",
    }

    initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
    finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)

    ans_phone = []
    ans_tone = []

    for c, v in zip(initials, finals):
        # An empty final has no tone digit to read. This replaces a bare
        # ``except:`` around ``v[-1]`` that would also have hidden unrelated
        # errors (including KeyboardInterrupt).
        if not v:
            print("skip", word, initials, finals)
            return [], []

        raw_pinyin = c + v
        v_without_tone = v[:-1]
        tone = v[-1]

        pinyin = c + v_without_tone
        assert tone in "12345", (word, raw_pinyin, tone)

        if c:
            # Syllable with an initial: contract the final if needed.
            pinyin = c + v_rep_map.get(v_without_tone, v_without_tone)
        elif pinyin in pinyin_rep_map:
            pinyin = pinyin_rep_map[pinyin]
        elif pinyin and pinyin[0] in single_rep_map:
            # The ``pinyin and`` guard keeps an empty syllable from raising
            # IndexError; it is reported by the "skip" branch below instead.
            pinyin = single_rep_map[pinyin[0]] + pinyin[1:]

        if pinyin not in pinyin_to_symbol_map:
            print("skip", pinyin, word, c, v, raw_pinyin)
            continue

        phone = pinyin_to_symbol_map[pinyin]
        ans_phone += phone
        ans_tone += [tone] * len(phone)

    return ans_phone, ans_tone
def generate_tokens(symbol_list):
    """Write ``tokens.txt`` in the current directory.

    Args:
      symbol_list:
        Symbols in ID order. Line ``i`` of the file is ``"<symbol> <i>"``.
    """
    with open("tokens.txt", "w", encoding="utf-8") as out:
        out.writelines(
            f"{symbol} {index}\n" for index, symbol in enumerate(symbol_list)
        )
def generate_lexicon():
    """Write ``lexicon.txt`` from pypinyin's character and phrase tables.

    Every line has the form ``"<word> <phones...> <tones...>"`` where phones
    and tones come from ``get_initial_final_tone``. Single characters are
    written first (only code points 0x4E00..0x9FA5), followed by the phrases.
    Entries for which no phones are produced are left out.
    """
    chars = pinyin_dict.pinyin_dict
    phrase_table = phrases_dict.phrases_dict

    with open("lexicon.txt", "w", encoding="utf-8") as out:
        for code_point in chars:
            if code_point < 0x4E00 or code_point > 0x9FA5:
                continue

            char = chr(code_point)
            phones, tones = get_initial_final_tone(char)
            if phones:
                out.write(f"{char} {' '.join(phones)} {' '.join(tones)}\n")

        for phrase in phrase_table:
            phones, tones = get_initial_final_tone(phrase)
            if not phones:
                continue

            n_phones, n_tones = len(phones), len(tones)
            assert n_phones == n_tones, (n_phones, n_tones, phones, tones)

            out.write(f"{phrase} {' '.join(phones)} {' '.join(tones)}\n")
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Replace the metadata of an ONNX model file. The file is rewritten in-place.

    Args:
      filename:
        Path of the ONNX model to load and save back.
      meta_data:
        Key-value pairs to store; every value is converted with ``str()``.
        Metadata entries already present in the model are removed first.
    """
    model = onnx.load(filename)
    props = model.metadata_props

    # Drop whatever metadata the model already carries.
    while len(props) > 0:
        props.pop()

    for name, content in meta_data.items():
        entry = props.add()
        entry.key = name
        entry.value = str(content)

    onnx.save(model, filename)
class ModelWrapper(torch.nn.Module):
    """Adapter exposing ``model.infer`` as ``forward`` with positional inputs."""

    def __init__(self, model: "SynthesizerTrn"):
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        tones,
        lang_id,
        bert,
        ja_bert,
        sid,
        noise_scale,
        length_scale,
        noise_scale_w,
        max_len=None,
    ):
        """Run ``self.model.infer`` and return the first element of its result.

        The export code in this file passes batched tensors:

        Args:
          x: int64 token IDs of shape (N, L).
          x_lengths: int64 tensor of shape (N,).
          tones: int64 tensor of shape (N, L).
          lang_id: int64 tensor of shape (N, L).
          bert: float32 tensor of shape (N, 1024, L).
          ja_bert: float32 tensor of shape (N, 768, L).
          sid: int64 tensor of shape (1,) holding the speaker ID.
          noise_scale: float32 tensor of shape (1,).
          length_scale: float32 tensor of shape (1,).
          noise_scale_w: float32 tensor of shape (1,).
          max_len: Accepted but not forwarded to ``infer``.
        """
        infer_kwargs = {
            "x": x,
            "x_lengths": x_lengths,
            "sid": sid,
            "tone": tones,
            "language": lang_id,
            "bert": bert,
            "ja_bert": ja_bert,
            "noise_scale": noise_scale,
            "noise_scale_w": noise_scale_w,
            "length_scale": length_scale,
        }
        outputs = self.model.infer(**infer_kwargs)
        return outputs[0]
def main():
    """Export the Chinese ("ZH") MeloTTS model to ``model.onnx``.

    Writes ``lexicon.txt`` and ``tokens.txt`` next to the model, traces the
    wrapped model with dummy inputs for a batch of one 60-token sequence, and
    finally stores the metadata in the exported file.
    """
    generate_lexicon()

    language = "ZH"
    model = TTS(language=language, device="cpu")

    generate_tokens(model.hps["symbols"])

    torch_model = ModelWrapper(model.model)

    opset_version = 13
    bert_dim = 1024
    ja_bert_dim = 768

    # Dummy token IDs; only shapes and dtypes matter for tracing.
    token_ids = torch.randint(low=0, high=10, size=(60,), dtype=torch.int64)
    print(token_ids.shape)
    num_tokens = token_ids.shape[0]

    # Inputs in the positional order expected by ModelWrapper.forward. The
    # per-token inputs get a leading batch dimension of size 1.
    example_inputs = {
        "x": token_ids.unsqueeze(0),
        "x_lengths": torch.tensor([token_ids.size(0)], dtype=torch.int64),
        "tones": torch.zeros_like(token_ids).unsqueeze(0),
        "lang_id": torch.ones_like(token_ids).unsqueeze(0),
        "bert": torch.zeros(bert_dim, num_tokens, dtype=torch.float32).unsqueeze(0),
        "ja_bert": torch.zeros(
            ja_bert_dim, num_tokens, dtype=torch.float32
        ).unsqueeze(0),
        "sid": torch.tensor([1], dtype=torch.int64),
        "noise_scale": torch.tensor([1.0], dtype=torch.float32),
        "length_scale": torch.tensor([1.0], dtype=torch.float32),
        "noise_scale_w": torch.tensor([1.0], dtype=torch.float32),
    }

    filename = "model.onnx"

    torch.onnx.export(
        torch_model,
        tuple(example_inputs.values()),
        filename,
        opset_version=opset_version,
        input_names=list(example_inputs),
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},
            "x_lengths": {0: "N"},
            "tones": {0: "N", 1: "L"},
            "lang_id": {0: "N", 1: "L"},
            "bert": {0: "N", 2: "L"},
            "ja_bert": {0: "N", 2: "L"},
            "y": {0: "N", 1: "S", 2: "T"},
        },
    )

    data_config = model.hps.data
    meta_data = {
        "model_type": "melo-vits",
        "comment": "melo",
        "language": "Chinese + English",
        "add_blank": int(data_config.add_blank),
        "n_speakers": 1,
        "sample_rate": data_config.sampling_rate,
        "bert_dim": bert_dim,
        "ja_bert_dim": ja_bert_dim,
        # The first speaker ID listed in the model configuration.
        "speaker_id": list(data_config.spk2id.values())[0],
        "lang_id": language_id_map[model.language],
        "tone_start": language_tone_start_map[model.language],
        "url": "https://github.com/myshell-ai/MeloTTS",
        "license": "MIT license",
        "description": "MeloTTS is a high-quality multi-lingual text-to-speech library by MyShell.ai",
    }
    add_meta_data(filename, meta_data)


if __name__ == "__main__":
    main()
...
...
scripts/melo-tts/run.sh
0 → 100755
查看文件 @
04c2319
#!/usr/bin/env bash
#
# Install MeloTTS and its dependencies, export the model to ONNX with
# ./export-onnx.py, show the generated files and run ./test.py on the result.

# -e: stop at the first failing command; -x: echo each command before running it.
set -ex

# Install CPU-only PyTorch, a fresh clone of MeloTTS in /tmp and the Python
# packages used by export-onnx.py and test.py.
function install() {
  pip install torch==2.3.1+cpu torchaudio==2.3.1+cpu -f https://download.pytorch.org/whl/torch_stable.html

  pushd /tmp
  git clone https://github.com/myshell-ai/MeloTTS
  cd MeloTTS
  pip install -r ./requirements.txt

  pip install soundfile onnx onnxruntime

  python3 -m unidic download
  popd
}

install

# MeloTTS is used from the clone via PYTHONPATH; the script above does not
# pip-install the MeloTTS package itself.
export PYTHONPATH=/tmp/MeloTTS:$PYTHONPATH

echo "pwd: $PWD"

./export-onnx.py

ls -lh

# Show the beginning and the end of the generated text files.
head lexicon.txt
echo "---"
tail lexicon.txt
echo "---"
head tokens.txt
echo "---"
tail tokens.txt

./test.py
ls -lh
...
...
scripts/melo-tts/test.py
0 → 100755
查看文件 @
04c2319
#!/usr/bin/env python3
from
typing
import
Iterable
,
List
,
Tuple
import
jieba
import
onnxruntime
as
ort
import
soundfile
as
sf
import
torch
class Lexicon:
    """Map words, phrases and punctuation to token IDs and tones.

    The mapping is loaded from two text files:

      - tokens file: one ``"<symbol> <id>"`` pair per line.
      - lexicon file: one ``"<word> <phones...> <tones...>"`` entry per line,
        with as many tones as phones.

    Blank lines in either file are ignored.
    """

    # Full-width Chinese punctuation and the ASCII symbol it is looked up as.
    # Written with escapes so the mapping survives any punctuation
    # normalisation of this source file.
    _PUNCTUATION_MAP = {
        "\uff0c": ",",  # full-width comma
        "\u3002": ".",  # ideographic full stop
        "\uff01": "!",  # full-width exclamation mark
        "\uff1f": "?",  # full-width question mark
    }

    def __init__(self, lexion_filename: str, tokens_filename: str):
        """Load the token table and the lexicon.

        Args:
          lexion_filename:
            Path to the lexicon file. (The spelling of the parameter name is
            kept as-is because callers pass it by keyword.)
          tokens_filename:
            Path to the tokens file.
        """
        tokens = dict()
        with open(tokens_filename, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                s, i = line.split()
                tokens[s] = int(i)

        lexicon = dict()
        with open(lexion_filename, encoding="utf-8") as f:
            for line in f:
                splits = line.split()
                if not splits:
                    continue
                word_or_phrase = splits[0]
                phone_tone_list = splits[1:]
                # First half: phones, second half: tones.
                assert len(phone_tone_list) & 1 == 0, len(phone_tone_list)
                half = len(phone_tone_list) // 2
                phones = [tokens[p] for p in phone_tone_list[:half]]
                tones = [int(t) for t in phone_tone_list[half:]]

                lexicon[word_or_phrase] = (phones, tones)
        self.lexicon = lexicon

        # Punctuation symbols are looked up like words and get tone 0.
        punctuation = ["!", "?", "…", ",", ".", "'", "-"]
        for p in punctuation:
            self.lexicon[p] = ([tokens[p]], [0])

        # A space is represented by the "_" token.
        self.lexicon[" "] = ([tokens["_"]], [0])

    def _convert(self, text: str) -> Tuple[List[int], List[int]]:
        """Convert one segment (word, phrase or punctuation mark).

        A segment that is not in the lexicon is converted character by
        character; characters that are unknown as well contribute nothing.

        Returns:
          A tuple ``(phones, tones)`` of token IDs and tones.
        """
        text = self._PUNCTUATION_MAP.get(text, text)

        if text in self.lexicon:
            phones, tones = self.lexicon[text]
            return phones, tones

        phones = []
        tones = []
        print("t", text)
        if len(text) > 1:
            for w in text:
                print("w", w)
                # Convert the single character directly. The previous code
                # called ``self.convert(w)``, which expects an iterable of
                # segments and only worked because a one-character string
                # iterates to itself.
                p, t = self._convert(w)
                if p:
                    phones += p
                    tones += t
        return phones, tones

    def convert(self, text_list: Iterable[str]) -> Tuple[List[int], List[int]]:
        """Convert a sequence of segments and concatenate the results.

        Args:
          text_list:
            Segments to convert, e.g. the output of a word segmenter.

        Returns:
          A tuple ``(phones, tones)`` of two equally long lists.
        """
        phones = []
        tones = []
        for text in text_list:
            print(text)
            p, t = self._convert(text)
            phones += p
            tones += t
        return phones, tones
class OnnxModel:
    """Run the exported ``model.onnx`` with onnxruntime on the CPU."""

    def __init__(self, filename):
        """Create the inference session and read the model metadata.

        Args:
          filename:
            Path to the ONNX model. Its custom metadata must contain the
            integer-valued keys read below.
        """
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 4
        self.session_opts = opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.model.get_modelmeta().custom_metadata_map
        # Each of these becomes an int attribute with the same name.
        for field in (
            "bert_dim",
            "ja_bert_dim",
            "add_blank",
            "sample_rate",
            "speaker_id",
            "lang_id",
        ):
            setattr(self, field, int(meta[field]))

    def __call__(self, x, tones, lang):
        """Run the model on one sequence.

        Args:
          x: 1-D int64 torch tensor
          tones: 1-D int64 torch tensor
          lang: 1-D int64 torch tensor

        Returns:
          ``outputs[0][0][0]`` of the session's "y" output, i.e. the first
          item along the first two axes of "y".
        """
        # Add the batch dimension expected by the model.
        x = x.unsqueeze(0)
        tones = tones.unsqueeze(0)
        lang = lang.unsqueeze(0)

        print(x.shape, tones.shape, lang.shape)

        num_tokens = x.shape[-1]
        # All-zero BERT features, one column per token.
        bert = torch.zeros(1, self.bert_dim, num_tokens)
        ja_bert = torch.zeros(1, self.ja_bert_dim, num_tokens)
        sid = torch.tensor([self.speaker_id], dtype=torch.int64)
        noise_scale = torch.tensor([0.6], dtype=torch.float32)
        length_scale = torch.tensor([1.0], dtype=torch.float32)
        noise_scale_w = torch.tensor([0.8], dtype=torch.float32)
        x_lengths = torch.tensor([num_tokens], dtype=torch.int64)

        feeds = {
            "x": x.numpy(),
            "x_lengths": x_lengths.numpy(),
            "tones": tones.numpy(),
            "lang_id": lang.numpy(),
            "bert": bert.numpy(),
            "ja_bert": ja_bert.numpy(),
            "sid": sid.numpy(),
            "noise_scale": noise_scale.numpy(),
            "noise_scale_w": noise_scale_w.numpy(),
            "length_scale": length_scale.numpy(),
        }
        outputs = self.model.run(["y"], feeds)
        return outputs[0][0][0]
def main():
    """Synthesize a fixed test sentence with ``./model.onnx`` into ``./test.wav``.

    The sentence is segmented with jieba, converted to token IDs and tones via
    ``./lexicon.txt`` and ``./tokens.txt``, optionally interleaved with blanks
    and passed to the ONNX model.
    """

    def with_blanks(values):
        # 0 before, between and after the items: [0, v0, 0, v1, ..., 0].
        padded = [0] * (2 * len(values) + 1)
        padded[1::2] = values
        return padded

    lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt")

    text = "永远相信,美好的事情即将发生。多音字测试, 银行,行不行?长沙长大"
    s = jieba.cut(text, HMM=True)

    phones, tones = lexicon.convert(s)

    model = OnnxModel("./model.onnx")
    # Every token gets the language ID stored in the model metadata.
    langs = [model.lang_id] * len(phones)

    if model.add_blank:
        phones = with_blanks(phones)
        tones = with_blanks(tones)
        langs = with_blanks(langs)

    phones, tones, langs = (
        torch.tensor(values, dtype=torch.int64) for values in (phones, tones, langs)
    )

    print(phones.shape, tones.shape, langs.shape)

    y = model(x=phones, tones=tones, lang=langs)
    sf.write("./test.wav", y, model.sample_rate)


if __name__ == "__main__":
    main()
...
...
请
注册
或
登录
后发表评论