Fangjun Kuang

add python API and examples for TTS (#364)
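This commit adds Python bindings for the offline text-to-speech API (OfflineTts, OfflineTtsConfig, OfflineTtsModelConfig, OfflineTtsVitsModelConfig), a new example script python-api-examples/offline-tts.py, a CI test that generates and uploads tts.wav, and the sherpa-onnx-offline-tts binary in the Python package. A minimal usage sketch of the new API, mirroring the example script in the diff below (the three model file paths are placeholders for the downloaded vits-ljs files):

import sherpa_onnx
import soundfile as sf

# Build the TTS config; the paths below are placeholders for the
# vits-ljs model files downloaded in the CI test of this commit.
config = sherpa_onnx.OfflineTtsConfig(
    model=sherpa_onnx.OfflineTtsModelConfig(
        vits=sherpa_onnx.OfflineTtsVitsModelConfig(
            model="./vits-ljs.onnx",
            lexicon="./lexicon.txt",
            tokens="./tokens.txt",
        ),
        num_threads=1,
        provider="cpu",
    )
)

tts = sherpa_onnx.OfflineTts(config)
audio = tts.generate("hello world")  # GeneratedAudio with .samples and .sample_rate
sf.write("./tts.wav", audio.samples, samplerate=audio.sample_rate, subtype="PCM_16")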

@@ -8,6 +8,24 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }
 
+log "Offline TTS test"
+
+wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx
+wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
+wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
+
+python3 ./python-api-examples/offline-tts.py \
+  --vits-model=./vits-ljs.onnx \
+  --vits-lexicon=./lexicon.txt \
+  --vits-tokens=./tokens.txt \
+  --output-filename=./tts.wav \
+  'liliana, the most beautiful and lovely assistant of our team!'
+
+ls -lh ./tts.wav
+file ./tts.wav
+
+rm -v vits-ljs.onnx ./lexicon.txt ./tokens.txt
+
 mkdir -p /tmp/icefall-models
 dir=/tmp/icefall-models
 
@@ -171,3 +189,5 @@ rm -rf $repo
 git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data
 
 python3 sherpa-onnx/python/tests/test_text2token.py --verbose
+
+rm -rf /tmp/sherpa-test-data
@@ -42,7 +42,7 @@ jobs:
         python-version: "3.10"
 
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
 
@@ -54,7 +54,7 @@ jobs:
       - name: Install Python dependencies
        shell: bash
        run: |
-         python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96
+         python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96 soundfile
 
       - name: Install sherpa-onnx
        shell: bash
@@ -65,3 +65,8 @@ jobs:
        shell: bash
        run: |
          .github/scripts/test-python.sh
+
+      - uses: actions/upload-artifact@v3
+        with:
+          name: tts-generated-test-files
+          path: tts.wav
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)
 
-set(SHERPA_ONNX_VERSION "1.7.21")
+set(SHERPA_ONNX_VERSION "1.8.0")
 
 # Disable warning about
 #
@@ -137,6 +137,7 @@ class BuildExtension(build_ext):
         binaries += ["sherpa-onnx-offline-websocket-server"]
         binaries += ["sherpa-onnx-online-websocket-client"]
         binaries += ["sherpa-onnx-vad-microphone"]
+        binaries += ["sherpa-onnx-offline-tts"]
 
         if is_windows():
             binaries += ["kaldi-native-fbank-core.dll"]
@@ -144,6 +145,9 @@ class BuildExtension(build_ext):
             binaries += ["sherpa-onnx-core.dll"]
             binaries += ["sherpa-onnx-portaudio.dll"]
             binaries += ["onnxruntime.dll"]
+            binaries += ["kaldi-decoder-core.dll"]
+            binaries += ["sherpa-onnx-fst.dll"]
+            binaries += ["sherpa-onnx-kaldifst-core.dll"]
 
         for f in binaries:
             suffix = "" if "dll" in f else suffix
+#!/usr/bin/env python3
+#
+# Copyright (c) 2023 Xiaomi Corporation
+
+"""
+This file demonstrates how to use sherpa-onnx Python API to generate audio
+from text, i.e., text-to-speech.
+
+Usage:
+
+1. Download a model
+
+wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx
+wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
+wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
+
+python3 ./python-api-examples/offline-tts.py \
+  --vits-model=./vits-ljs.onnx \
+  --vits-lexicon=./lexicon.txt \
+  --vits-tokens=./tokens.txt \
+  --output-filename=./generated.wav \
+  'liliana, the most beautiful and lovely assistant of our team!'
+"""
+
+import argparse
+
+import sherpa_onnx
+import soundfile as sf
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--vits-model",
+        type=str,
+        help="Path to vits model.onnx",
+    )
+
+    parser.add_argument(
+        "--vits-lexicon",
+        type=str,
+        help="Path to lexicon.txt",
+    )
+
+    parser.add_argument(
+        "--vits-tokens",
+        type=str,
+        help="Path to tokens.txt",
+    )
+
+    parser.add_argument(
+        "--output-filename",
+        type=str,
+        default="./generated.wav",
+        help="Path to save generated wave",
+    )
+
+    parser.add_argument(
+        "--debug",
+        type=bool,
+        default=False,
+        help="True to show debug messages",
+    )
+
+    parser.add_argument(
+        "--provider",
+        type=str,
+        default="cpu",
+        help="valid values: cpu, cuda, coreml",
+    )
+
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="Number of threads for neural network computation",
+    )
+
+    parser.add_argument(
+        "text",
+        type=str,
+        help="The input text to generate audio for",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    print(args)
+
+    tts_config = sherpa_onnx.OfflineTtsConfig(
+        model=sherpa_onnx.OfflineTtsModelConfig(
+            vits=sherpa_onnx.OfflineTtsVitsModelConfig(
+                model=args.vits_model,
+                lexicon=args.vits_lexicon,
+                tokens=args.vits_tokens,
+            ),
+            provider=args.provider,
+            debug=args.debug,
+            num_threads=args.num_threads,
+        )
+    )
+    tts = sherpa_onnx.OfflineTts(tts_config)
+    audio = tts.generate(args.text)
+    sf.write(
+        args.output_filename,
+        audio.samples,
+        samplerate=audio.sample_rate,
+        subtype="PCM_16",
+    )
+    print(f"Saved to {args.output_filename}")
+    print(f"The text is '{args.text}'")
+
+
+if __name__ == "__main__":
+    main()
@@ -57,12 +57,16 @@ def get_binaries_to_install():
     binaries += ["sherpa-onnx-offline-websocket-server"]
     binaries += ["sherpa-onnx-online-websocket-client"]
     binaries += ["sherpa-onnx-vad-microphone"]
+    binaries += ["sherpa-onnx-offline-tts"]
     if is_windows():
         binaries += ["kaldi-native-fbank-core.dll"]
         binaries += ["sherpa-onnx-c-api.dll"]
         binaries += ["sherpa-onnx-core.dll"]
         binaries += ["sherpa-onnx-portaudio.dll"]
         binaries += ["onnxruntime.dll"]
+        binaries += ["kaldi-decoder-core.dll"]
+        binaries += ["sherpa-onnx-fst.dll"]
+        binaries += ["sherpa-onnx-kaldifst-core.dll"]
 
     exe = []
     for f in binaries:
@@ -21,9 +21,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
   explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
       : model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
         lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
-                 model_->Punctuations()) {
-    SHERPA_ONNX_LOGE("config: %s\n", config.ToString().c_str());
-  }
+                 model_->Punctuations()) {}
 
   GeneratedAudio Generate(const std::string &text) const override {
     std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
@@ -14,6 +14,9 @@ pybind11_add_module(_sherpa_onnx
   offline-stream.cc
   offline-tdnn-model-config.cc
   offline-transducer-model-config.cc
+  offline-tts-model-config.cc
+  offline-tts-vits-model-config.cc
+  offline-tts.cc
   offline-whisper-model-config.cc
   offline-zipformer-ctc-model-config.cc
   online-lm-config.cc
+// sherpa-onnx/python/csrc/offline-tts-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/python/csrc/offline-tts-model-config.h"
+
+#include <string>
+
+#include "sherpa-onnx/csrc/offline-tts-model-config.h"
+#include "sherpa-onnx/python/csrc/offline-tts-vits-model-config.h"
+
+namespace sherpa_onnx {
+
+void PybindOfflineTtsModelConfig(py::module *m) {
+  PybindOfflineTtsVitsModelConfig(m);
+
+  using PyClass = OfflineTtsModelConfig;
+
+  py::class_<PyClass>(*m, "OfflineTtsModelConfig")
+      .def(py::init<>())
+      .def(py::init<const OfflineTtsVitsModelConfig &, int32_t, bool,
+                    const std::string &>(),
+           py::arg("vits"), py::arg("num_threads") = 1,
+           py::arg("debug") = false, py::arg("provider") = "cpu")
+      .def_readwrite("vits", &PyClass::vits)
+      .def_readwrite("num_threads", &PyClass::num_threads)
+      .def_readwrite("debug", &PyClass::debug)
+      .def_readwrite("provider", &PyClass::provider)
+      .def("__str__", &PyClass::ToString);
+}
+
+} // namespace sherpa_onnx
+// sherpa-onnx/python/csrc/offline-tts-model-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindOfflineTtsModelConfig(py::module *m);
+
+}
+
+#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
+// sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/python/csrc/offline-tts-vits-model-config.h"
+
+#include <string>
+
+#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h"
+
+namespace sherpa_onnx {
+
+void PybindOfflineTtsVitsModelConfig(py::module *m) {
+  using PyClass = OfflineTtsVitsModelConfig;
+
+  py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
+      .def(py::init<>())
+      .def(py::init<const std::string &, const std::string &,
+                    const std::string &>(),
+           py::arg("model"), py::arg("lexicon"), py::arg("tokens"))
+      .def_readwrite("model", &PyClass::model)
+      .def_readwrite("lexicon", &PyClass::lexicon)
+      .def_readwrite("tokens", &PyClass::tokens)
+      .def("__str__", &PyClass::ToString);
+}
+
+} // namespace sherpa_onnx
+// sherpa-onnx/python/csrc/offline-tts-vits-model-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindOfflineTtsVitsModelConfig(py::module *m);
+
+}
+
+#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_
+// sherpa-onnx/python/csrc/offline-tts.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#include "sherpa-onnx/python/csrc/offline-tts.h"
+
+#include "sherpa-onnx/csrc/offline-tts.h"
+#include "sherpa-onnx/python/csrc/offline-tts-model-config.h"
+
+namespace sherpa_onnx {
+
+static void PybindGeneratedAudio(py::module *m) {
+  using PyClass = GeneratedAudio;
+  py::class_<PyClass>(*m, "GeneratedAudio")
+      .def(py::init<>())
+      .def_readwrite("samples", &PyClass::samples)
+      .def_readwrite("sample_rate", &PyClass::sample_rate)
+      .def("__str__", [](PyClass &self) {
+        std::ostringstream os;
+        os << "GeneratedAudio(sample_rate=" << self.sample_rate << ", ";
+        os << "num_samples=" << self.samples.size() << ")";
+        return os.str();
+      });
+}
+
+static void PybindOfflineTtsConfig(py::module *m) {
+  PybindOfflineTtsModelConfig(m);
+
+  using PyClass = OfflineTtsConfig;
+  py::class_<PyClass>(*m, "OfflineTtsConfig")
+      .def(py::init<>())
+      .def(py::init<const OfflineTtsModelConfig &>(), py::arg("model"))
+      .def_readwrite("model", &PyClass::model)
+      .def("__str__", &PyClass::ToString);
+}
+
+void PybindOfflineTts(py::module *m) {
+  PybindOfflineTtsConfig(m);
+  PybindGeneratedAudio(m);
+
+  using PyClass = OfflineTts;
+  py::class_<PyClass>(*m, "OfflineTts")
+      .def(py::init<const OfflineTtsConfig &>(), py::arg("config"))
+      .def("generate", &PyClass::Generate);
+}
+
+} // namespace sherpa_onnx
+// sherpa-onnx/python/csrc/offline-tts.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_H_
+#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindOfflineTts(py::module *m);
+
+}
+
+#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_H_
@@ -13,6 +13,7 @@
 #include "sherpa-onnx/python/csrc/offline-model-config.h"
 #include "sherpa-onnx/python/csrc/offline-recognizer.h"
 #include "sherpa-onnx/python/csrc/offline-stream.h"
+#include "sherpa-onnx/python/csrc/offline-tts.h"
 #include "sherpa-onnx/python/csrc/online-lm-config.h"
 #include "sherpa-onnx/python/csrc/online-model-config.h"
 #include "sherpa-onnx/python/csrc/online-recognizer.h"
@@ -45,6 +46,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
   PybindVadModel(&m);
   PybindCircularBuffer(&m);
   PybindVoiceActivityDetector(&m);
+
+  PybindOfflineTts(&m);
 }
 
 } // namespace sherpa_onnx
@@ -2,6 +2,10 @@ from _sherpa_onnx import (
     CircularBuffer,
     Display,
     OfflineStream,
+    OfflineTts,
+    OfflineTtsConfig,
+    OfflineTtsModelConfig,
+    OfflineTtsVitsModelConfig,
     OnlineStream,
     SileroVadModelConfig,
     SpeechSegment,