正在显示
17 个修改的文件
包含
839 行增加
和
0 行删除
| @@ -86,6 +86,15 @@ set(sources | @@ -86,6 +86,15 @@ set(sources | ||
| 86 | wave-reader.cc | 86 | wave-reader.cc |
| 87 | ) | 87 | ) |
| 88 | 88 | ||
# Sources that implement offline text-to-speech (TTS): the lexicon, the
# TTS front end, its configuration objects, and the VITS model wrapper.
list(
  APPEND sources
  lexicon.cc
  offline-tts-impl.cc
  offline-tts-model-config.cc
  offline-tts-vits-model-config.cc
  offline-tts-vits-model.cc
  offline-tts.cc
)
| 97 | + | ||
| 89 | if(SHERPA_ONNX_ENABLE_CHECK) | 98 | if(SHERPA_ONNX_ENABLE_CHECK) |
| 90 | list(APPEND sources log.cc) | 99 | list(APPEND sources log.cc) |
| 91 | endif() | 100 | endif() |
| @@ -135,23 +144,31 @@ endif() | @@ -135,23 +144,31 @@ endif() | ||
# Command-line executables. Every target is built from the single source
# file `<target>.cc` and links against sherpa-onnx-core.
foreach(exe
  sherpa-onnx
  sherpa-onnx-offline
  sherpa-onnx-offline-parallel
  sherpa-onnx-offline-tts
)
  add_executable(${exe} ${exe}.cc)
  target_link_libraries(${exe} sherpa-onnx-core)

  if(NOT WIN32)
    # Extra run-time library search paths, expressed relative to
    # SHERPA_ONNX_RPATH_ORIGIN (presumably the location of the executable
    # itself -- confirm where that variable is defined).
    target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
    target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")

    if(SHERPA_ONNX_ENABLE_PYTHON)
      # Additional search path inside the Python site-packages tree.
      target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
    endif()
  endif()
endforeach()
| 157 | 174 | ||
| @@ -170,6 +187,7 @@ install( | @@ -170,6 +187,7 @@ install( | ||
| 170 | sherpa-onnx | 187 | sherpa-onnx |
| 171 | sherpa-onnx-offline | 188 | sherpa-onnx-offline |
| 172 | sherpa-onnx-offline-parallel | 189 | sherpa-onnx-offline-parallel |
| 190 | + sherpa-onnx-offline-tts | ||
| 173 | DESTINATION | 191 | DESTINATION |
| 174 | bin | 192 | bin |
| 175 | ) | 193 | ) |
sherpa-onnx/csrc/lexicon.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/lexicon.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/lexicon.h" | ||
| 6 | + | ||
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <fstream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
| 12 | + | ||
| 13 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 14 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
// Lower-cases every byte of *in_out in place using std::tolower.
// Each char is converted to unsigned char first so that bytes >= 0x80 are
// never passed to std::tolower as negative values.
static void ToLowerCase(std::string *in_out) {
  for (char &ch : *in_out) {
    ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
  }
}
| 22 | + | ||
| 23 | +// Note: We don't use SymbolTable here since tokens may contain a blank | ||
| 24 | +// in the first column | ||
| 25 | +static std::unordered_map<std::string, int32_t> ReadTokens( | ||
| 26 | + const std::string &tokens) { | ||
| 27 | + std::unordered_map<std::string, int32_t> token2id; | ||
| 28 | + | ||
| 29 | + std::ifstream is(tokens); | ||
| 30 | + std::string line; | ||
| 31 | + | ||
| 32 | + std::string sym; | ||
| 33 | + int32_t id; | ||
| 34 | + while (std::getline(is, line)) { | ||
| 35 | + std::istringstream iss(line); | ||
| 36 | + iss >> sym; | ||
| 37 | + if (iss.eof()) { | ||
| 38 | + id = atoi(sym.c_str()); | ||
| 39 | + sym = " "; | ||
| 40 | + } else { | ||
| 41 | + iss >> id; | ||
| 42 | + } | ||
| 43 | + | ||
| 44 | + if (!iss.eof()) { | ||
| 45 | + SHERPA_ONNX_LOGE("Error: %s", line.c_str()); | ||
| 46 | + exit(-1); | ||
| 47 | + } | ||
| 48 | + | ||
| 49 | +#if 0 | ||
| 50 | + if (token2id.count(sym)) { | ||
| 51 | + SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", | ||
| 52 | + sym.c_str(), line.c_str(), token2id.at(sym)); | ||
| 53 | + exit(-1); | ||
| 54 | + } | ||
| 55 | +#endif | ||
| 56 | + token2id.insert({sym, id}); | ||
| 57 | + } | ||
| 58 | + | ||
| 59 | + return token2id; | ||
| 60 | +} | ||
| 61 | + | ||
// Maps every entry of `tokens` to its ID using `token2id`.
//
// Returns the IDs in the same order as `tokens`. If any token is missing
// from `token2id`, an empty vector is returned instead (an empty `tokens`
// also yields an empty vector).
static std::vector<int32_t> ConvertTokensToIds(
    const std::unordered_map<std::string, int32_t> &token2id,
    const std::vector<std::string> &tokens) {
  std::vector<int32_t> result;
  result.reserve(tokens.size());

  for (const auto &token : tokens) {
    auto pos = token2id.find(token);
    if (pos == token2id.end()) {
      // A single unknown token invalidates the whole sequence.
      return {};
    }
    result.push_back(pos->second);
  }

  return result;
}
| 77 | + | ||
| 78 | +Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, | ||
| 79 | + const std::string &punctuations) { | ||
| 80 | + token2id_ = ReadTokens(tokens); | ||
| 81 | + std::ifstream is(lexicon); | ||
| 82 | + | ||
| 83 | + std::string word; | ||
| 84 | + std::vector<std::string> token_list; | ||
| 85 | + std::string line; | ||
| 86 | + std::string phone; | ||
| 87 | + | ||
| 88 | + while (std::getline(is, line)) { | ||
| 89 | + std::istringstream iss(line); | ||
| 90 | + | ||
| 91 | + token_list.clear(); | ||
| 92 | + | ||
| 93 | + iss >> word; | ||
| 94 | + ToLowerCase(&word); | ||
| 95 | + | ||
| 96 | + if (word2ids_.count(word)) { | ||
| 97 | + SHERPA_ONNX_LOGE("Duplicated word: %s", word.c_str()); | ||
| 98 | + return; | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + while (iss >> phone) { | ||
| 102 | + token_list.push_back(std::move(phone)); | ||
| 103 | + } | ||
| 104 | + | ||
| 105 | + std::vector<int32_t> ids = ConvertTokensToIds(token2id_, token_list); | ||
| 106 | + if (ids.empty()) { | ||
| 107 | + continue; | ||
| 108 | + } | ||
| 109 | + word2ids_.insert({std::move(word), std::move(ids)}); | ||
| 110 | + } | ||
| 111 | + | ||
| 112 | + // process punctuations | ||
| 113 | + std::vector<std::string> punctuation_list; | ||
| 114 | + SplitStringToVector(punctuations, " ", false, &punctuation_list); | ||
| 115 | + for (auto &s : punctuation_list) { | ||
| 116 | + punctuations_.insert(std::move(s)); | ||
| 117 | + } | ||
| 118 | +} | ||
| 119 | + | ||
| 120 | +std::vector<int64_t> Lexicon::ConvertTextToTokenIds( | ||
| 121 | + const std::string &_text) const { | ||
| 122 | + std::string text(_text); | ||
| 123 | + ToLowerCase(&text); | ||
| 124 | + | ||
| 125 | + std::vector<std::string> words; | ||
| 126 | + SplitStringToVector(text, " ", false, &words); | ||
| 127 | + | ||
| 128 | + std::vector<int64_t> ans; | ||
| 129 | + for (auto w : words) { | ||
| 130 | + std::vector<int64_t> prefix; | ||
| 131 | + while (!w.empty() && punctuations_.count(std::string(1, w[0]))) { | ||
| 132 | + // if w begins with a punctuation | ||
| 133 | + prefix.push_back(token2id_.at(std::string(1, w[0]))); | ||
| 134 | + w = std::string(w.begin() + 1, w.end()); | ||
| 135 | + } | ||
| 136 | + | ||
| 137 | + std::vector<int64_t> suffix; | ||
| 138 | + while (!w.empty() && punctuations_.count(std::string(1, w.back()))) { | ||
| 139 | + suffix.push_back(token2id_.at(std::string(1, w.back()))); | ||
| 140 | + w = std::string(w.begin(), w.end() - 1); | ||
| 141 | + } | ||
| 142 | + | ||
| 143 | + if (!word2ids_.count(w)) { | ||
| 144 | + SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str()); | ||
| 145 | + continue; | ||
| 146 | + } | ||
| 147 | + | ||
| 148 | + const auto &token_ids = word2ids_.at(w); | ||
| 149 | + ans.insert(ans.end(), prefix.begin(), prefix.end()); | ||
| 150 | + ans.insert(ans.end(), token_ids.begin(), token_ids.end()); | ||
| 151 | + ans.insert(ans.end(), suffix.rbegin(), suffix.rend()); | ||
| 152 | + } | ||
| 153 | + | ||
| 154 | + return ans; | ||
| 155 | +} | ||
| 156 | + | ||
| 157 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/lexicon.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/lexicon.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_LEXICON_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_LEXICON_H_ | ||
| 7 | + | ||
| 8 | +#include <cstdint> | ||
| 9 | +#include <string> | ||
| 10 | +#include <unordered_map> | ||
| 11 | +#include <unordered_set> | ||
| 12 | +#include <vector> | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
// Maps words to token IDs with the help of a lexicon file and a token table.
class Lexicon {
 public:
  // @param lexicon Path to the lexicon file (a word followed by its tokens
  //                on each line).
  // @param tokens Path to the token table file.
  // @param punctuations Punctuation symbols separated by spaces.
  Lexicon(const std::string &lexicon, const std::string &tokens,
          const std::string &punctuations);

  // Converts `text`, a string of words separated by spaces, into the token
  // IDs of its words. Words that are not in the lexicon are ignored.
  std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;

 private:
  // Lower-cased word -> IDs of the tokens the word consists of.
  std::unordered_map<std::string, std::vector<int32_t>> word2ids_;

  // Symbols that are treated as punctuation.
  std::unordered_set<std::string> punctuations_;

  // Token -> token ID, as read from the token table.
  std::unordered_map<std::string, int32_t> token2id_;
};
| 28 | + | ||
| 29 | +} // namespace sherpa_onnx | ||
| 30 | + | ||
| 31 | +#endif // SHERPA_ONNX_CSRC_LEXICON_H_ |
sherpa-onnx/csrc/offline-tts-impl.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts-impl.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-tts-impl.h" | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-tts-vits-impl.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create( | ||
| 14 | + const OfflineTtsConfig &config) { | ||
| 15 | + // TODO(fangjun): Support other types | ||
| 16 | + return std::make_unique<OfflineTtsVitsImpl>(config); | ||
| 17 | +} | ||
| 18 | + | ||
| 19 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-tts-impl.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/offline-tts.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +class OfflineTtsImpl { | ||
| 16 | + public: | ||
| 17 | + virtual ~OfflineTtsImpl() = default; | ||
| 18 | + | ||
| 19 | + static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config); | ||
| 20 | + | ||
| 21 | + virtual GeneratedAudio Generate(const std::string &text) const = 0; | ||
| 22 | +}; | ||
| 23 | + | ||
| 24 | +} // namespace sherpa_onnx | ||
| 25 | + | ||
| 26 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ |
sherpa-onnx/csrc/offline-tts-model-config.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-tts-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 8 | + | ||
| 9 | +namespace sherpa_onnx { | ||
| 10 | + | ||
| 11 | +void OfflineTtsModelConfig::Register(ParseOptions *po) { | ||
| 12 | + vits.Register(po); | ||
| 13 | + | ||
| 14 | + po->Register("num-threads", &num_threads, | ||
| 15 | + "Number of threads to run the neural network"); | ||
| 16 | + | ||
| 17 | + po->Register("debug", &debug, | ||
| 18 | + "true to print model information while loading it."); | ||
| 19 | + | ||
| 20 | + po->Register("provider", &provider, | ||
| 21 | + "Specify a provider to use: cpu, cuda, coreml"); | ||
| 22 | +} | ||
| 23 | + | ||
| 24 | +bool OfflineTtsModelConfig::Validate() const { | ||
| 25 | + if (num_threads < 1) { | ||
| 26 | + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); | ||
| 27 | + return false; | ||
| 28 | + } | ||
| 29 | + | ||
| 30 | + return vits.Validate(); | ||
| 31 | +} | ||
| 32 | + | ||
| 33 | +std::string OfflineTtsModelConfig::ToString() const { | ||
| 34 | + std::ostringstream os; | ||
| 35 | + | ||
| 36 | + os << "OfflineTtsModelConfig("; | ||
| 37 | + os << "vits=" << vits.ToString() << ", "; | ||
| 38 | + os << "num_threads=" << num_threads << ", "; | ||
| 39 | + os << "debug=" << (debug ? "True" : "False") << ", "; | ||
| 40 | + os << "provider=\"" << provider << "\")"; | ||
| 41 | + | ||
| 42 | + return os.str(); | ||
| 43 | +} | ||
| 44 | + | ||
| 45 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-tts-model-config.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h" | ||
| 11 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +struct OfflineTtsModelConfig { | ||
| 16 | + OfflineTtsVitsModelConfig vits; | ||
| 17 | + | ||
| 18 | + int32_t num_threads = 1; | ||
| 19 | + bool debug = false; | ||
| 20 | + std::string provider = "cpu"; | ||
| 21 | + | ||
| 22 | + OfflineTtsModelConfig() = default; | ||
| 23 | + | ||
| 24 | + OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits, | ||
| 25 | + int32_t num_threads, bool debug, | ||
| 26 | + const std::string &provider) | ||
| 27 | + : vits(vits), | ||
| 28 | + num_threads(num_threads), | ||
| 29 | + debug(debug), | ||
| 30 | + provider(provider) {} | ||
| 31 | + | ||
| 32 | + void Register(ParseOptions *po); | ||
| 33 | + bool Validate() const; | ||
| 34 | + | ||
| 35 | + std::string ToString() const; | ||
| 36 | +}; | ||
| 37 | + | ||
| 38 | +} // namespace sherpa_onnx | ||
| 39 | + | ||
| 40 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ |
sherpa-onnx/csrc/offline-tts-vits-impl.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts-vits-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/lexicon.h" | ||
| 13 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 14 | +#include "sherpa-onnx/csrc/offline-tts-impl.h" | ||
| 15 | +#include "sherpa-onnx/csrc/offline-tts-vits-model.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
| 19 | +class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 20 | + public: | ||
| 21 | + explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config) | ||
| 22 | + : model_(std::make_unique<OfflineTtsVitsModel>(config.model)), | ||
| 23 | + lexicon_(config.model.vits.lexicon, config.model.vits.tokens, | ||
| 24 | + model_->Punctuations()) { | ||
| 25 | + SHERPA_ONNX_LOGE("config: %s\n", config.ToString().c_str()); | ||
| 26 | + } | ||
| 27 | + | ||
| 28 | + GeneratedAudio Generate(const std::string &text) const override { | ||
| 29 | + std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text); | ||
| 30 | + if (x.empty()) { | ||
| 31 | + SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); | ||
| 32 | + return {}; | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + if (model_->AddBlank()) { | ||
| 36 | + std::vector<int64_t> buffer(x.size() * 2 + 1); | ||
| 37 | + int32_t i = 1; | ||
| 38 | + for (auto k : x) { | ||
| 39 | + buffer[i] = k; | ||
| 40 | + i += 2; | ||
| 41 | + } | ||
| 42 | + x = std::move(buffer); | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + auto memory_info = | ||
| 46 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 47 | + | ||
| 48 | + std::array<int64_t, 2> x_shape = {1, static_cast<int32_t>(x.size())}; | ||
| 49 | + Ort::Value x_tensor = Ort::Value::CreateTensor( | ||
| 50 | + memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); | ||
| 51 | + | ||
| 52 | + Ort::Value audio = model_->Run(std::move(x_tensor)); | ||
| 53 | + | ||
| 54 | + std::vector<int64_t> audio_shape = | ||
| 55 | + audio.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 56 | + | ||
| 57 | + int64_t total = 1; | ||
| 58 | + // The output shape may be (1, 1, total) or (1, total) or (total,) | ||
| 59 | + for (auto i : audio_shape) { | ||
| 60 | + total *= i; | ||
| 61 | + } | ||
| 62 | + | ||
| 63 | + const float *p = audio.GetTensorData<float>(); | ||
| 64 | + | ||
| 65 | + GeneratedAudio ans; | ||
| 66 | + ans.sample_rate = model_->SampleRate(); | ||
| 67 | + ans.samples = std::vector<float>(p, p + total); | ||
| 68 | + return ans; | ||
| 69 | + } | ||
| 70 | + | ||
| 71 | + private: | ||
| 72 | + std::unique_ptr<OfflineTtsVitsModel> model_; | ||
| 73 | + Lexicon lexicon_; | ||
| 74 | +}; | ||
| 75 | + | ||
| 76 | +} // namespace sherpa_onnx | ||
| 77 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/offline-tts-vits-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OfflineTtsVitsModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("vits-model", &model, "Path to VITS model"); | ||
| 14 | + po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models"); | ||
| 15 | + po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models"); | ||
| 16 | +} | ||
| 17 | + | ||
| 18 | +bool OfflineTtsVitsModelConfig::Validate() const { | ||
| 19 | + if (model.empty()) { | ||
| 20 | + SHERPA_ONNX_LOGE("Please provide --vits-model"); | ||
| 21 | + return false; | ||
| 22 | + } | ||
| 23 | + | ||
| 24 | + if (!FileExists(model)) { | ||
| 25 | + SHERPA_ONNX_LOGE("--vits-model: %s does not exist", model.c_str()); | ||
| 26 | + return false; | ||
| 27 | + } | ||
| 28 | + | ||
| 29 | + if (lexicon.empty()) { | ||
| 30 | + SHERPA_ONNX_LOGE("Please provide --vits-lexicon"); | ||
| 31 | + return false; | ||
| 32 | + } | ||
| 33 | + | ||
| 34 | + if (!FileExists(lexicon)) { | ||
| 35 | + SHERPA_ONNX_LOGE("--vits-lexicon: %s does not exist", lexicon.c_str()); | ||
| 36 | + return false; | ||
| 37 | + } | ||
| 38 | + | ||
| 39 | + if (tokens.empty()) { | ||
| 40 | + SHERPA_ONNX_LOGE("Please provide --vits-tokens"); | ||
| 41 | + return false; | ||
| 42 | + } | ||
| 43 | + | ||
| 44 | + if (!FileExists(tokens)) { | ||
| 45 | + SHERPA_ONNX_LOGE("--vits-tokens: %s does not exist", tokens.c_str()); | ||
| 46 | + return false; | ||
| 47 | + } | ||
| 48 | + | ||
| 49 | + return true; | ||
| 50 | +} | ||
| 51 | + | ||
| 52 | +std::string OfflineTtsVitsModelConfig::ToString() const { | ||
| 53 | + std::ostringstream os; | ||
| 54 | + | ||
| 55 | + os << "OfflineTtsVitsModelConfig("; | ||
| 56 | + os << "model=\"" << model << "\", "; | ||
| 57 | + os << "lexicon=\"" << lexicon << "\", "; | ||
| 58 | + os << "tokens=\"" << tokens << "\")"; | ||
| 59 | + | ||
| 60 | + return os.str(); | ||
| 61 | +} | ||
| 62 | + | ||
| 63 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-tts-vits-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct OfflineTtsVitsModelConfig { | ||
| 15 | + std::string model; | ||
| 16 | + std::string lexicon; | ||
| 17 | + std::string tokens; | ||
| 18 | + | ||
| 19 | + OfflineTtsVitsModelConfig() = default; | ||
| 20 | + | ||
| 21 | + OfflineTtsVitsModelConfig(const std::string &model, | ||
| 22 | + const std::string &lexicon, | ||
| 23 | + const std::string &tokens) | ||
| 24 | + : model(model), lexicon(lexicon), tokens(tokens) {} | ||
| 25 | + | ||
| 26 | + void Register(ParseOptions *po); | ||
| 27 | + bool Validate() const; | ||
| 28 | + | ||
| 29 | + std::string ToString() const; | ||
| 30 | +}; | ||
| 31 | + | ||
| 32 | +} // namespace sherpa_onnx | ||
| 33 | + | ||
| 34 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_ |
sherpa-onnx/csrc/offline-tts-vits-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts-vits-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-tts-vits-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 13 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 14 | +#include "sherpa-onnx/csrc/session.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +class OfflineTtsVitsModel::Impl { | ||
| 19 | + public: | ||
| 20 | + explicit Impl(const OfflineTtsModelConfig &config) | ||
| 21 | + : config_(config), | ||
| 22 | + env_(ORT_LOGGING_LEVEL_WARNING), | ||
| 23 | + sess_opts_(GetSessionOptions(config)), | ||
| 24 | + allocator_{} { | ||
| 25 | + auto buf = ReadFile(config.vits.model); | ||
| 26 | + Init(buf.data(), buf.size()); | ||
| 27 | + } | ||
| 28 | + | ||
| 29 | + Ort::Value Run(Ort::Value x) { | ||
| 30 | + auto memory_info = | ||
| 31 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 32 | + | ||
| 33 | + std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 34 | + if (x_shape[0] != 1) { | ||
| 35 | + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d", | ||
| 36 | + static_cast<int32_t>(x_shape[0])); | ||
| 37 | + exit(-1); | ||
| 38 | + } | ||
| 39 | + | ||
| 40 | + int64_t len = x_shape[1]; | ||
| 41 | + int64_t len_shape = 1; | ||
| 42 | + | ||
| 43 | + Ort::Value x_length = | ||
| 44 | + Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1); | ||
| 45 | + | ||
| 46 | + int64_t scale_shape = 1; | ||
| 47 | + float noise_scale = 1; | ||
| 48 | + float length_scale = 1; | ||
| 49 | + float noise_scale_w = 1; | ||
| 50 | + | ||
| 51 | + Ort::Value noise_scale_tensor = | ||
| 52 | + Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1); | ||
| 53 | + Ort::Value length_scale_tensor = Ort::Value::CreateTensor( | ||
| 54 | + memory_info, &length_scale, 1, &scale_shape, 1); | ||
| 55 | + Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor( | ||
| 56 | + memory_info, &noise_scale_w, 1, &scale_shape, 1); | ||
| 57 | + | ||
| 58 | + std::array<Ort::Value, 5> inputs = { | ||
| 59 | + std::move(x), std::move(x_length), std::move(noise_scale_tensor), | ||
| 60 | + std::move(length_scale_tensor), std::move(noise_scale_w_tensor)}; | ||
| 61 | + | ||
| 62 | + auto out = | ||
| 63 | + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 64 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 65 | + | ||
| 66 | + return std::move(out[0]); | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + int32_t SampleRate() const { return sample_rate_; } | ||
| 70 | + | ||
| 71 | + bool AddBlank() const { return add_blank_; } | ||
| 72 | + | ||
| 73 | + std::string Punctuations() const { return punctuations_; } | ||
| 74 | + | ||
| 75 | + private: | ||
| 76 | + void Init(void *model_data, size_t model_data_length) { | ||
| 77 | + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 78 | + sess_opts_); | ||
| 79 | + | ||
| 80 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 81 | + | ||
| 82 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 83 | + | ||
| 84 | + // get meta data | ||
| 85 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 86 | + if (config_.debug) { | ||
| 87 | + std::ostringstream os; | ||
| 88 | + os << "---vits model---\n"; | ||
| 89 | + PrintModelMetadata(os, meta_data); | ||
| 90 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 91 | + } | ||
| 92 | + | ||
| 93 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 94 | + SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); | ||
| 95 | + SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); | ||
| 96 | + SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); | ||
| 97 | + } | ||
| 98 | + | ||
| 99 | + private: | ||
| 100 | + OfflineTtsModelConfig config_; | ||
| 101 | + Ort::Env env_; | ||
| 102 | + Ort::SessionOptions sess_opts_; | ||
| 103 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 104 | + | ||
| 105 | + std::unique_ptr<Ort::Session> sess_; | ||
| 106 | + | ||
| 107 | + std::vector<std::string> input_names_; | ||
| 108 | + std::vector<const char *> input_names_ptr_; | ||
| 109 | + | ||
| 110 | + std::vector<std::string> output_names_; | ||
| 111 | + std::vector<const char *> output_names_ptr_; | ||
| 112 | + | ||
| 113 | + int32_t sample_rate_; | ||
| 114 | + int32_t add_blank_; | ||
| 115 | + std::string punctuations_; | ||
| 116 | +}; | ||
| 117 | + | ||
| 118 | +OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) | ||
| 119 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 120 | + | ||
| 121 | +OfflineTtsVitsModel::~OfflineTtsVitsModel() = default; | ||
| 122 | + | ||
| 123 | +Ort::Value OfflineTtsVitsModel::Run(Ort::Value x) { | ||
| 124 | + return impl_->Run(std::move(x)); | ||
| 125 | +} | ||
| 126 | + | ||
| 127 | +int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); } | ||
| 128 | + | ||
| 129 | +bool OfflineTtsVitsModel::AddBlank() const { return impl_->AddBlank(); } | ||
| 130 | + | ||
| 131 | +std::string OfflineTtsVitsModel::Punctuations() const { | ||
| 132 | + return impl_->Punctuations(); | ||
| 133 | +} | ||
| 134 | + | ||
| 135 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-tts-vits-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts-vits-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 12 | +#include "sherpa-onnx/csrc/offline-tts-model-config.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +class OfflineTtsVitsModel { | ||
| 17 | + public: | ||
| 18 | + ~OfflineTtsVitsModel(); | ||
| 19 | + | ||
| 20 | + explicit OfflineTtsVitsModel(const OfflineTtsModelConfig &config); | ||
| 21 | + | ||
| 22 | + /** Run the model. | ||
| 23 | + * | ||
| 24 | + * @param x A int64 tensor of shape (1, num_tokens) | ||
| 25 | + * @return Return a float32 tensor containing audio samples. You can flatten | ||
| 26 | + * it to a 1-D tensor. | ||
| 27 | + */ | ||
| 28 | + Ort::Value Run(Ort::Value x); | ||
| 29 | + | ||
| 30 | + // Sample rate of the generated audio | ||
| 31 | + int32_t SampleRate() const; | ||
| 32 | + | ||
| 33 | + // true to insert a blank between each token | ||
| 34 | + bool AddBlank() const; | ||
| 35 | + | ||
| 36 | + std::string Punctuations() const; | ||
| 37 | + | ||
| 38 | + private: | ||
| 39 | + class Impl; | ||
| 40 | + std::unique_ptr<Impl> impl_; | ||
| 41 | +}; | ||
| 42 | + | ||
| 43 | +} // namespace sherpa_onnx | ||
| 44 | + | ||
| 45 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_ |
sherpa-onnx/csrc/offline-tts.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-tts.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-tts-impl.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void OfflineTtsConfig::Register(ParseOptions *po) { model.Register(po); } | ||
| 14 | + | ||
| 15 | +bool OfflineTtsConfig::Validate() const { return model.Validate(); } | ||
| 16 | + | ||
| 17 | +std::string OfflineTtsConfig::ToString() const { | ||
| 18 | + std::ostringstream os; | ||
| 19 | + | ||
| 20 | + os << "OfflineTtsConfig("; | ||
| 21 | + os << "model=" << model.ToString() << ")"; | ||
| 22 | + | ||
| 23 | + return os.str(); | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +OfflineTts::OfflineTts(const OfflineTtsConfig &config) | ||
| 27 | + : impl_(OfflineTtsImpl::Create(config)) {} | ||
| 28 | + | ||
| 29 | +OfflineTts::~OfflineTts() = default; | ||
| 30 | + | ||
| 31 | +GeneratedAudio OfflineTts::Generate(const std::string &text) const { | ||
| 32 | + return impl_->Generate(text); | ||
| 33 | +} | ||
| 34 | + | ||
| 35 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-tts.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_H_ | ||
| 6 | + | ||
| 7 | +#include <cstdint> | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/offline-tts-model-config.h" | ||
| 13 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +struct OfflineTtsConfig { | |
| 18 | +  OfflineTtsModelConfig model;  // nested model configuration | |
| 19 | + | |
| 20 | +  OfflineTtsConfig() = default; | |
| 21 | +  explicit OfflineTtsConfig(const OfflineTtsModelConfig &model) | |
| 22 | +      : model(model) {} | |
| 23 | +  // Register() and Validate() forward to the same-named methods of `model`. | |
| 24 | +  void Register(ParseOptions *po); | |
| 25 | +  bool Validate() const; | |
| 26 | +  // Returns "OfflineTtsConfig(model=<model.ToString()>)". | |
| 27 | +  std::string ToString() const; | |
| 28 | +}; | |
| 29 | + | ||
| 30 | +struct GeneratedAudio { | |
| 31 | +  std::vector<float> samples;  // audio returned by OfflineTts::Generate() | |
| 32 | +  int32_t sample_rate = 0;     // never indeterminate; 0 until it is filled in | |
| 33 | +}; | |
| 34 | + | ||
| 35 | +class OfflineTtsImpl; | ||
| 36 | + | ||
| 37 | +class OfflineTts { | |
| 38 | + public: | |
| 39 | +  explicit OfflineTts(const OfflineTtsConfig &config); | |
| 40 | +  ~OfflineTts();  // out-of-line: OfflineTtsImpl is incomplete here | |
| 41 | +  // @param text Space-separated words. @return The audio generated for it. | |
| 42 | +  GeneratedAudio Generate(const std::string &text) const; | |
| 43 | + | |
| 44 | + private: | |
| 45 | +  std::unique_ptr<OfflineTtsImpl> impl_;  // Generate() is forwarded to it | |
| 46 | +}; | |
| 47 | + | ||
| 48 | +} // namespace sherpa_onnx | ||
| 49 | + | ||
| 50 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_H_ |
| @@ -87,4 +87,8 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) { | @@ -87,4 +87,8 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) { | ||
| 87 | return GetSessionOptionsImpl(config.num_threads, config.provider); | 87 | return GetSessionOptionsImpl(config.num_threads, config.provider); |
| 88 | } | 88 | } |
| 89 | 89 | ||
| 90 | +// Overload for TTS models; uses config.num_threads and config.provider. | |
| 91 | +Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { | |
| 92 | +  return GetSessionOptionsImpl(config.num_threads, config.provider); | |
| 93 | +} | |
| 90 | } // namespace sherpa_onnx | 94 | } // namespace sherpa_onnx |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | #include "onnxruntime_cxx_api.h" // NOLINT | 8 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 9 | #include "sherpa-onnx/csrc/offline-lm-config.h" | 9 | #include "sherpa-onnx/csrc/offline-lm-config.h" |
| 10 | #include "sherpa-onnx/csrc/offline-model-config.h" | 10 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| 11 | +#include "sherpa-onnx/csrc/offline-tts-model-config.h" | ||
| 11 | #include "sherpa-onnx/csrc/online-lm-config.h" | 12 | #include "sherpa-onnx/csrc/online-lm-config.h" |
| 12 | #include "sherpa-onnx/csrc/online-model-config.h" | 13 | #include "sherpa-onnx/csrc/online-model-config.h" |
| 13 | #include "sherpa-onnx/csrc/vad-model-config.h" | 14 | #include "sherpa-onnx/csrc/vad-model-config.h" |
| @@ -23,6 +24,8 @@ Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); | @@ -23,6 +24,8 @@ Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); | ||
| 23 | Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); | 24 | Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); |
| 24 | 25 | ||
| 25 | Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); | 26 | Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); |
| 27 | + | ||
| 28 | +Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); | ||
| 26 | } // namespace sherpa_onnx | 29 | } // namespace sherpa_onnx |
| 27 | 30 | ||
| 28 | #endif // SHERPA_ONNX_CSRC_SESSION_H_ | 31 | #endif // SHERPA_ONNX_CSRC_SESSION_H_ |
sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include <fstream> | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/offline-tts.h" | ||
| 8 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 9 | + | ||
| 10 | +// Synthesizes the single positional argument with sherpa_onnx::OfflineTts | |
| 11 | +// and dumps the samples to ./t.pcm as raw floats (no wave header). | |
| 12 | +// Exits with EXIT_FAILURE on bad arguments, an invalid config or I/O errors. | |
| 13 | +int main(int32_t argc, char *argv[]) { | |
| 14 | +  const char *kUsageMessage = R"usage( | |
| 15 | +Offline text-to-speech with sherpa-onnx | |
| 16 | + | |
| 17 | +./bin/sherpa-onnx-offline-tts \ | |
| 18 | +  --vits-model /path/to/model.onnx \ | |
| 19 | +  --vits-lexicon /path/to/lexicon.txt \ | |
| 20 | +  --vits-tokens /path/to/tokens.txt \ | |
| 21 | +  'some text within single quotes' | |
| 22 | + | |
| 23 | +It writes the generated samples to ./t.pcm as raw 32-bit floats. Use, e.g., | |
| 24 | + | |
| 25 | +  sox -t raw -r <sample-rate> -b 32 -e floating-point -c 1 ./t.pcm ./t.wav | |
| 26 | + | |
| 27 | +to convert it to a wave file. | |
| 28 | +)usage"; | |
| 29 | + | |
| 30 | +  sherpa_onnx::ParseOptions po(kUsageMessage); | |
| 31 | +  sherpa_onnx::OfflineTtsConfig config; | |
| 32 | +  config.Register(&po); | |
| 33 | +  po.Read(argc, argv); | |
| 34 | + | |
| 35 | +  if (po.NumArgs() == 0) { | |
| 36 | +    fprintf(stderr, "Error: Please provide the text to generate audio.\n\n"); | |
| 37 | +    po.PrintUsage(); | |
| 38 | +    exit(EXIT_FAILURE); | |
| 39 | +  } | |
| 40 | + | |
| 41 | +  if (po.NumArgs() > 1) { | |
| 42 | +    fprintf(stderr, | |
| 43 | +            "Error: Accept only one positional argument. Please use single " | |
| 44 | +            "quotes to wrap your text\n"); | |
| 45 | +    po.PrintUsage(); | |
| 46 | +    exit(EXIT_FAILURE); | |
| 47 | +  } | |
| 48 | + | |
| 49 | +  if (!config.Validate()) { | |
| 50 | +    fprintf(stderr, "Errors in config!\n"); | |
| 51 | +    exit(EXIT_FAILURE); | |
| 52 | +  } | |
| 53 | + | |
| 54 | +  sherpa_onnx::OfflineTts tts(config); | |
| 55 | +  // Positional arguments of ParseOptions are numbered from 1. | |
| 56 | +  const auto audio = tts.Generate(po.GetArg(1)); | |
| 57 | + | |
| 58 | +  const char *kOutputFilename = "t.pcm"; | |
| 59 | +  std::ofstream os(kOutputFilename, std::ios::binary); | |
| 60 | +  if (!os) { | |
| 61 | +    fprintf(stderr, "Failed to open %s for writing\n", kOutputFilename); | |
| 62 | +    exit(EXIT_FAILURE); | |
| 63 | +  } | |
| 64 | + | |
| 65 | +  os.write(reinterpret_cast<const char *>(audio.samples.data()), | |
| 66 | +           sizeof(float) * audio.samples.size()); | |
| 67 | +  os.close();  // flush now so that write errors are detected below | |
| 68 | +  if (!os) { | |
| 69 | +    fprintf(stderr, "Failed to write %s\n", kOutputFilename); | |
| 70 | +    exit(EXIT_FAILURE); | |
| 71 | +  } | |
| 72 | + | |
| 73 | +  // The rate comes from audio.sample_rate, so print it instead of | |
| 74 | +  // hard-coding 22050 in the conversion hint. | |
| 75 | +  fprintf(stderr, | |
| 76 | +          "Saved to ./%s. Convert it to a wave file with:\n" | |
| 77 | +          "  sox -t raw -r %d -b 32 -e floating-point -c 1 ./%s ./t.wav\n", | |
| 78 | +          kOutputFilename, static_cast<int>(audio.sample_rate), | |
| 79 | +          kOutputFilename); | |
| 80 | + | |
| 81 | +  return 0; | |
| 82 | +} | |
-
请 注册 或 登录 后发表评论