Fangjun Kuang
Committed by GitHub

Add C++ runtime for MeloTTS (#1138)

Showing 51 changed files with 693 additions and 156 deletions.
... ... @@ -63,10 +63,16 @@ jobs:
echo "pwd: $PWD"
ls -lh ../scripts/melo-tts
rm -rf ./
cp -v ../scripts/melo-tts/*.onnx .
cp -v ../scripts/melo-tts/lexicon.txt .
cp -v ../scripts/melo-tts/tokens.txt .
cp -v ../scripts/melo-tts/README.md .
curl -SL -O https://raw.githubusercontent.com/myshell-ai/MeloTTS/main/LICENSE
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/new_heteronym.fst
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst
... ... @@ -77,6 +83,10 @@ jobs:
git lfs track "*.onnx"
git add .
ls -lh
git status
git commit -m "add models"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-zh_en main || true
... ...
... ... @@ -39,10 +39,14 @@ jobs:
cd build
cmake \
-A x64 \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_SHARED_LIBS=ON \
-DBUILD_SHARED_LIBS=ON \
-D SHERPA_ONNX_ENABLE_JNI=ON \
-D CMAKE_INSTALL_PREFIX=./install \
-DCMAKE_INSTALL_PREFIX=./install \
-DCMAKE_BUILD_TYPE=Release \
-DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \
-DBUILD_ESPEAK_NG_EXE=OFF \
-DSHERPA_ONNX_BUILD_C_API_EXAMPLES=OFF \
-DSHERPA_ONNX_ENABLE_BINARY=ON \
..
- name: Build sherpa-onnx for windows
... ...
## 1.10.16
* Support zh-en TTS model from MeloTTS.
## 1.10.15
* Downgrade onnxruntime from v1.18.1 to v1.17.1
... ...
... ... @@ -11,7 +11,7 @@ project(sherpa-onnx)
# ./nodejs-addon-examples
# ./dart-api-examples/
# ./CHANGELOG.md
set(SHERPA_ONNX_VERSION "1.10.15")
set(SHERPA_ONNX_VERSION "1.10.16")
# Disable warning about
#
... ...
... ... @@ -10,7 +10,7 @@ environment:
# Add regular dependencies here.
dependencies:
sherpa_onnx: ^1.10.15
sherpa_onnx: ^1.10.16
path: ^1.9.0
args: ^2.5.0
... ...
... ... @@ -11,7 +11,7 @@ environment:
# Add regular dependencies here.
dependencies:
sherpa_onnx: ^1.10.15
sherpa_onnx: ^1.10.16
path: ^1.9.0
args: ^2.5.0
... ...
... ... @@ -8,7 +8,7 @@ environment:
# Add regular dependencies here.
dependencies:
sherpa_onnx: ^1.10.15
sherpa_onnx: ^1.10.16
path: ^1.9.0
args: ^2.5.0
... ...
... ... @@ -9,7 +9,7 @@ environment:
sdk: ^3.4.0
dependencies:
sherpa_onnx: ^1.10.15
sherpa_onnx: ^1.10.16
path: ^1.9.0
args: ^2.5.0
... ...
... ... @@ -5,7 +5,7 @@ description: >
publish_to: 'none'
version: 1.10.14
version: 1.10.16
topics:
- speech-recognition
... ... @@ -30,7 +30,7 @@ dependencies:
record: ^5.1.0
url_launcher: ^6.2.6
sherpa_onnx: ^1.10.15
sherpa_onnx: ^1.10.16
# sherpa_onnx:
# path: ../../flutter/sherpa_onnx
... ...
... ... @@ -5,7 +5,7 @@ description: >
publish_to: 'none' # Remove this line if you wish to publish to pub.dev
version: 1.0.0
version: 1.10.16
environment:
sdk: '>=3.4.0 <4.0.0'
... ... @@ -17,7 +17,7 @@ dependencies:
cupertino_icons: ^1.0.6
path_provider: ^2.1.3
path: ^1.9.0
sherpa_onnx: ^1.10.15
sherpa_onnx: ^1.10.16
url_launcher: ^6.2.6
audioplayers: ^5.0.0
... ...
... ... @@ -17,7 +17,7 @@ topics:
- voice-activity-detection
# remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec
version: 1.10.15
version: 1.10.16
homepage: https://github.com/k2-fsa/sherpa-onnx
... ... @@ -30,19 +30,19 @@ dependencies:
flutter:
sdk: flutter
sherpa_onnx_android: ^1.10.15
sherpa_onnx_android: ^1.10.16
# path: ../sherpa_onnx_android
sherpa_onnx_macos: ^1.10.15
sherpa_onnx_macos: ^1.10.16
# path: ../sherpa_onnx_macos
sherpa_onnx_linux: ^1.10.15
sherpa_onnx_linux: ^1.10.16
# path: ../sherpa_onnx_linux
#
sherpa_onnx_windows: ^1.10.15
sherpa_onnx_windows: ^1.10.16
# path: ../sherpa_onnx_windows
sherpa_onnx_ios: ^1.10.15
sherpa_onnx_ios: ^1.10.16
# sherpa_onnx_ios:
# path: ../sherpa_onnx_ios
... ...
... ... @@ -7,7 +7,7 @@
# https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c
Pod::Spec.new do |s|
s.name = 'sherpa_onnx_ios'
s.version = '1.10.15'
s.version = '1.10.16'
s.summary = 'A new Flutter FFI plugin project.'
s.description = <<-DESC
A new Flutter FFI plugin project.
... ...
... ... @@ -4,7 +4,7 @@
#
Pod::Spec.new do |s|
s.name = 'sherpa_onnx_macos'
s.version = '1.10.15'
s.version = '1.10.16'
s.summary = 'sherpa-onnx Flutter FFI plugin project.'
s.description = <<-DESC
sherpa-onnx Flutter FFI plugin project.
... ...
{
"dependencies": {
"sherpa-onnx-node": "^1.10.15"
"sherpa-onnx-node": "^1.10.16"
}
}
... ...
... ... @@ -78,6 +78,10 @@ sed -i.bak s/"lang = null"/"lang = \"$lang_iso_639_3\""/ ./TtsEngine.kt
git diff
popd
if [[ $model_dir == vits-melo-tts-zh_en ]]; then
lang=zh_en
fi
for arch in arm64-v8a armeabi-v7a x86_64 x86; do
log "------------------------------------------------------------"
log "build tts apk for $arch"
... ...
... ... @@ -76,6 +76,10 @@ sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt
git diff
popd
if [[ $model_dir == vits-melo-tts-zh_en ]]; then
lang=zh_en
fi
for arch in arm64-v8a armeabi-v7a x86_64 x86; do
log "------------------------------------------------------------"
log "build tts apk for $arch"
... ...
... ... @@ -313,6 +313,11 @@ def get_vits_models() -> List[TtsModel]:
lang="zh",
),
TtsModel(
model_dir="vits-melo-tts-zh_en",
model_name="model.onnx",
lang="zh",
),
TtsModel(
model_dir="vits-zh-hf-fanchen-C",
model_name="vits-zh-hf-fanchen-C.onnx",
lang="zh",
... ... @@ -339,18 +344,21 @@ def get_vits_models() -> List[TtsModel]:
),
]
rule_fsts = ["phone.fst", "date.fst", "number.fst", "new_heteronym.fst"]
rule_fsts = ["phone.fst", "date.fst", "number.fst"]
for m in chinese_models:
s = [f"{m.model_dir}/{r}" for r in rule_fsts]
if "vits-zh-hf" in m.model_dir or "sherpa-onnx-vits-zh-ll" == m.model_dir:
if (
"vits-zh-hf" in m.model_dir
or "sherpa-onnx-vits-zh-ll" == m.model_dir
or "melo-tts" in m.model_dir
):
s = s[:-1]
m.dict_dir = m.model_dir + "/dict"
else:
m.rule_fars = f"{m.model_dir}/rule.far"
m.rule_fsts = ",".join(s)
if "vits-zh-hf" not in m.model_dir and "zh-ll" not in m.model_dir:
m.rule_fars = f"{m.model_dir}/rule.far"
all_models = chinese_models + [
TtsModel(
model_dir="vits-cantonese-hf-xiaomaiiwn",
... ...
... ... @@ -17,7 +17,7 @@ topics:
- voice-activity-detection
# remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec
version: 1.10.15
version: 1.10.16
homepage: https://github.com/k2-fsa/sherpa-onnx
... ...
... ... @@ -6,9 +6,6 @@ from typing import List, Optional
import jinja2
# pip install iso639-lang
from iso639 import Lang
def get_args():
parser = argparse.ArgumentParser()
... ... @@ -37,13 +34,6 @@ class TtsModel:
data_dir: Optional[str] = None
dict_dir: Optional[str] = None
is_char: bool = False
lang_iso_639_3: str = ""
def convert_lang_to_iso_639_3(models: List[TtsModel]):
for m in models:
if m.lang_iso_639_3 == "":
m.lang_iso_639_3 = Lang(m.lang).pt3
def get_coqui_models() -> List[TtsModel]:
... ... @@ -313,6 +303,11 @@ def get_vits_models() -> List[TtsModel]:
lang="zh",
),
TtsModel(
model_dir="vits-melo-tts-zh_en",
model_name="model.onnx",
lang="zh_en",
),
TtsModel(
model_dir="vits-zh-hf-fanchen-C",
model_name="vits-zh-hf-fanchen-C.onnx",
lang="zh",
... ... @@ -332,26 +327,33 @@ def get_vits_models() -> List[TtsModel]:
model_name="vits-zh-hf-fanchen-unity.onnx",
lang="zh",
),
TtsModel(
model_dir="sherpa-onnx-vits-zh-ll",
model_name="model.onnx",
lang="zh",
),
]
rule_fsts = ["phone.fst", "date.fst", "number.fst", "new_heteronym.fst"]
rule_fsts = ["phone.fst", "date.fst", "number.fst"]
for m in chinese_models:
s = [f"{m.model_dir}/{r}" for r in rule_fsts]
if "vits-zh-hf" in m.model_dir:
if (
"vits-zh-hf" in m.model_dir
or "sherpa-onnx-vits-zh-ll" == m.model_dir
or "melo-tts" in m.model_dir
):
s = s[:-1]
m.dict_dir = m.model_dir + "/dict"
else:
m.rule_fars = f"{m.model_dir}/rule.far"
m.rule_fsts = ",".join(s)
if "vits-zh-hf" not in m.model_dir:
m.rule_fars = f"{m.model_dir}/rule.far"
all_models = chinese_models + [
TtsModel(
model_dir="vits-cantonese-hf-xiaomaiiwn",
model_name="vits-cantonese-hf-xiaomaiiwn.onnx",
lang="cantonese",
lang_iso_639_3="yue",
rule_fsts="vits-cantonese-hf-xiaomaiiwn/rule.fst",
),
# English (US)
... ... @@ -374,7 +376,6 @@ def main():
all_model_list += get_piper_models()
all_model_list += get_mimic3_models()
all_model_list += get_coqui_models()
convert_lang_to_iso_639_3(all_model_list)
num_models = len(all_model_list)
... ...
# Introduction
Models in this directory are converted from
https://github.com/myshell-ai/MeloTTS
Note there is only a single female speaker in the model.
... ...
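For reference, a minimal sketch of driving the converted model through the sherpa-onnx C++ offline TTS API that this PR extends. The vits config fields match those used in offline-tts-vits-impl.cc below; the ./vits-melo-tts-zh_en paths are illustrative assumptions:

#include "sherpa-onnx/csrc/offline-tts.h"

int main() {
  sherpa_onnx::OfflineTtsConfig config;
  // Paths assume the converted model files described in this README.
  config.model.vits.model = "./vits-melo-tts-zh_en/model.onnx";
  config.model.vits.lexicon = "./vits-melo-tts-zh_en/lexicon.txt";
  config.model.vits.tokens = "./vits-melo-tts-zh_en/tokens.txt";
  config.model.vits.dict_dir = "./vits-melo-tts-zh_en/dict";

  sherpa_onnx::OfflineTts tts(config);

  // sid is effectively ignored for MeloTTS; the runtime uses the
  // speaker id stored in the model's metadata (see the model code below).
  auto audio = tts.Generate("永远相信,美好的事情即将发生。", /*sid=*/0, /*speed=*/1.0);
  // audio.samples holds float PCM at audio.sample_rate.
  return 0;
}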
... ... @@ -8,7 +8,6 @@ from melo.text import language_id_map, language_tone_start_map
from melo.text.chinese import pinyin_to_symbol_map
from melo.text.english import eng_dict, refine_syllables
from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict
from melo.text.symbols import language_tone_start_map
for k, v in pinyin_to_symbol_map.items():
if isinstance(v, list):
... ... @@ -82,6 +81,7 @@ def generate_tokens(symbol_list):
def generate_lexicon():
word_dict = pinyin_dict.pinyin_dict
phrases = phrases_dict.phrases_dict
eng_dict["kaldi"] = [["K", "AH0"], ["L", "D", "IH0"]]
with open("lexicon.txt", "w", encoding="utf-8") as f:
for word in eng_dict:
phones, tones = refine_syllables(eng_dict[word])
... ... @@ -237,9 +237,11 @@ def main():
meta_data = {
"model_type": "melo-vits",
"comment": "melo",
"version": 2,
"language": "Chinese + English",
"add_blank": int(model.hps.data.add_blank),
"n_speakers": 1,
"jieba": 1,
"sample_rate": model.hps.data.sampling_rate,
"bert_dim": 1024,
"ja_bert_dim": 768,
... ...
... ... @@ -12,7 +12,7 @@ function install() {
cd MeloTTS
pip install -r ./requirements.txt
pip install soundfile onnx onnxruntime
pip install soundfile onnx==1.15.0 onnxruntime==1.16.3
python3 -m unidic download
popd
... ...
... ... @@ -135,28 +135,11 @@ class OnnxModel:
def main():
lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt")
text = "永远相信,美好的事情即将发生。"
text = "这是一个使用 next generation kaldi 的 text to speech 中英文例子. Thank you! 你觉得如何呢? are you ok? Fantastic! How about you?"
s = jieba.cut(text, HMM=True)
phones, tones = lexicon.convert(s)
en_text = "how are you ?".split()
phones_en, tones_en = lexicon.convert(en_text)
phones += [0]
tones += [0]
phones += phones_en
tones += tones_en
text = "多音字测试, 银行,行不行?长沙长大"
s = jieba.cut(text, HMM=True)
phones2, tones2 = lexicon.convert(s)
phones += phones2
tones += tones2
model = OnnxModel("./model.onnx")
if model.add_blank:
... ...
... ... @@ -422,10 +422,10 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
void SherpaOnnxOfflineRecognizerSetConfig(
const SherpaOnnxOfflineRecognizer *recognizer,
const SherpaOnnxOfflineRecognizerConfig *config){
const SherpaOnnxOfflineRecognizerConfig *config) {
sherpa_onnx::OfflineRecognizerConfig recognizer_config =
convertConfig(config);
recognizer->impl->SetConfig(recognizer_config);
recognizer->impl->SetConfig(recognizer_config);
}
void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) {
... ... @@ -478,7 +478,7 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
pText[text.size()] = 0;
r->text = pText;
//lang
// lang
const auto &lang = result.lang;
char *c_lang = new char[lang.size() + 1];
std::copy(lang.begin(), lang.end(), c_lang);
... ... @@ -1317,7 +1317,7 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
}
delete[] r->matches;
delete r;
};
}
int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
... ...
... ... @@ -496,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
const char *text;
// Pointer to continuous memory which holds timestamps
// Pointer to continuous memory which holds timestamps
//
// It is NULL if the model does not support timestamps
float *timestamps;
... ... @@ -525,9 +525,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
*/
const char *json;
//return recognized language
// return recognized language
const char *lang;
} SherpaOnnxOfflineRecognizerResult;
/// Get the result of the offline stream.
... ...
... ... @@ -142,7 +142,9 @@ if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND sources
jieba-lexicon.cc
lexicon.cc
melo-tts-lexicon.cc
offline-tts-character-frontend.cc
offline-tts-frontend.cc
offline-tts-impl.cc
offline-tts-model-config.cc
offline-tts-vits-model-config.cc
... ...
... ... @@ -33,7 +33,7 @@ TEST(CppJieBa, Case1) {
std::vector<std::string> words;
std::vector<cppjieba::Word> jiebawords;
std::string s = "他来到了网易杭研大厦";
std::string s = "他来到了网易杭研大厦。How are you?";
std::cout << s << std::endl;
std::cout << "[demo] Cut With HMM" << std::endl;
jieba.Cut(s, words, true);
... ...
... ... @@ -17,6 +17,7 @@ namespace sherpa_onnx {
// implemented in ./lexicon.cc
std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is);
std::vector<int32_t> ConvertTokensToIds(
const std::unordered_map<std::string, int32_t> &token2id,
const std::vector<std::string> &tokens);
... ... @@ -53,8 +54,7 @@ class JiebaLexicon::Impl {
}
}
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
const std::string &text) const {
std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &text) const {
// see
// https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244
std::regex punct_re{":|、|;"};
... ... @@ -87,7 +87,7 @@ class JiebaLexicon::Impl {
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
}
std::vector<std::vector<int64_t>> ans;
std::vector<TokenIDs> ans;
std::vector<int64_t> this_sentence;
int32_t blank = token2id_.at(" ");
... ... @@ -217,7 +217,7 @@ JiebaLexicon::JiebaLexicon(const std::string &lexicon,
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
debug)) {}
std::vector<std::vector<int64_t>> JiebaLexicon::ConvertTextToTokenIds(
std::vector<TokenIDs> JiebaLexicon::ConvertTextToTokenIds(
const std::string &text, const std::string & /*unused_voice = ""*/) const {
return impl_->ConvertTextToTokenIds(text);
}
... ...
... ... @@ -10,11 +10,6 @@
#include <unordered_map>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"
... ... @@ -27,13 +22,7 @@ class JiebaLexicon : public OfflineTtsFrontend {
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
#if __ANDROID_API__ >= 9
JiebaLexicon(AAssetManager *mgr, const std::string &lexicon,
const std::string &tokens, const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data);
#endif
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text,
const std::string &unused_voice = "") const override;
... ...
... ... @@ -172,7 +172,7 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
}
#endif
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
std::vector<TokenIDs> Lexicon::ConvertTextToTokenIds(
const std::string &text, const std::string & /*voice*/ /*= ""*/) const {
switch (language_) {
case Language::kChinese:
... ... @@ -187,7 +187,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
return {};
}
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
std::vector<TokenIDs> Lexicon::ConvertTextToTokenIdsChinese(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);
... ... @@ -209,7 +209,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
fprintf(stderr, "\n");
}
std::vector<std::vector<int64_t>> ans;
std::vector<TokenIDs> ans;
std::vector<int64_t> this_sentence;
int32_t blank = -1;
... ... @@ -288,7 +288,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
return ans;
}
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
std::vector<TokenIDs> Lexicon::ConvertTextToTokenIdsNotChinese(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);
... ... @@ -311,7 +311,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
int32_t blank = token2id_.at(" ");
std::vector<std::vector<int64_t>> ans;
std::vector<TokenIDs> ans;
std::vector<int64_t> this_sentence;
for (const auto &w : words) {
... ...
... ... @@ -36,14 +36,14 @@ class Lexicon : public OfflineTtsFrontend {
const std::string &language, bool debug = false);
#endif
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text, const std::string &voice = "") const override;
private:
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsNotChinese(
std::vector<TokenIDs> ConvertTextToTokenIdsNotChinese(
const std::string &text) const;
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsChinese(
std::vector<TokenIDs> ConvertTextToTokenIdsChinese(
const std::string &text) const;
void InitLanguage(const std::string &lang);
... ...
// sherpa-onnx/csrc/melo-tts-lexicon.cc
//
// Copyright (c) 2022-2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
#include <fstream>
#include <regex> // NOLINT
#include <utility>
#include "cppjieba/Jieba.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
// implemented in ./lexicon.cc
std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is);
std::vector<int32_t> ConvertTokensToIds(
const std::unordered_map<std::string, int32_t> &token2id,
const std::vector<std::string> &tokens);
class MeloTtsLexicon::Impl {
public:
Impl(const std::string &lexicon, const std::string &tokens,
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
: meta_data_(meta_data), debug_(debug) {
std::string dict = dict_dir + "/jieba.dict.utf8";
std::string hmm = dict_dir + "/hmm_model.utf8";
std::string user_dict = dict_dir + "/user.dict.utf8";
std::string idf = dict_dir + "/idf.utf8";
std::string stop_word = dict_dir + "/stop_words.utf8";
AssertFileExists(dict);
AssertFileExists(hmm);
AssertFileExists(user_dict);
AssertFileExists(idf);
AssertFileExists(stop_word);
jieba_ =
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
{
std::ifstream is(tokens);
InitTokens(is);
}
{
std::ifstream is(lexicon);
InitLexicon(is);
}
}
std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &_text) const {
std::string text = ToLowerCase(_text);
// see
// https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244
std::regex punct_re{":|、|;"};
std::string s = std::regex_replace(text, punct_re, ",");
std::regex punct_re2("。");
s = std::regex_replace(s, punct_re2, ".");
std::regex punct_re3("?");
s = std::regex_replace(s, punct_re3, "?");
std::regex punct_re4("!");
s = std::regex_replace(s, punct_re4, "!");
std::vector<std::string> words;
bool is_hmm = true;
jieba_->Cut(s, words, is_hmm);
if (debug_) {
SHERPA_ONNX_LOGE("input text: %s", text.c_str());
SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
std::ostringstream os;
std::string sep = "";
for (const auto &w : words) {
os << sep << w;
sep = "_";
}
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
}
std::vector<TokenIDs> ans;
TokenIDs this_sentence;
int32_t blank = token2id_.at("_");
for (const auto &w : words) {
auto ids = ConvertWordToIds(w);
if (ids.tokens.empty()) {
SHERPA_ONNX_LOGE("Ignore OOV '%s'", w.c_str());
continue;
}
this_sentence.tokens.insert(this_sentence.tokens.end(),
ids.tokens.begin(), ids.tokens.end());
this_sentence.tones.insert(this_sentence.tones.end(), ids.tones.begin(),
ids.tones.end());
if (w == "." || w == "!" || w == "?" || w == ",") {
ans.push_back(std::move(this_sentence));
this_sentence = {};
}
} // for (const auto &w : words)
if (!this_sentence.tokens.empty()) {
ans.push_back(std::move(this_sentence));
}
return ans;
}
private:
TokenIDs ConvertWordToIds(const std::string &w) const {
if (word2ids_.count(w)) {
return word2ids_.at(w);
}
if (token2id_.count(w)) {
return {{token2id_.at(w)}, {0}};
}
TokenIDs ans;
std::vector<std::string> words = SplitUtf8(w);
for (const auto &word : words) {
if (word2ids_.count(word)) {
auto ids = ConvertWordToIds(word);
ans.tokens.insert(ans.tokens.end(), ids.tokens.begin(),
ids.tokens.end());
ans.tones.insert(ans.tones.end(), ids.tones.begin(), ids.tones.end());
}
}
return ans;
}
void InitTokens(std::istream &is) {
token2id_ = ReadTokens(is);
token2id_[" "] = token2id_["_"];
std::vector<std::pair<std::string, std::string>> puncts = {
{",", ","}, {".", "。"}, {"!", "!"}, {"?", "?"}};
for (const auto &p : puncts) {
if (token2id_.count(p.first) && !token2id_.count(p.second)) {
token2id_[p.second] = token2id_[p.first];
}
if (!token2id_.count(p.first) && token2id_.count(p.second)) {
token2id_[p.first] = token2id_[p.second];
}
}
if (!token2id_.count("、") && token2id_.count(",")) {
token2id_["、"] = token2id_[","];
}
}
void InitLexicon(std::istream &is) {
std::string word;
std::vector<std::string> token_list;
std::vector<std::string> phone_list;
std::vector<int64_t> tone_list;
std::string line;
std::string phone;
int32_t line_num = 0;
while (std::getline(is, line)) {
++line_num;
std::istringstream iss(line);
token_list.clear();
phone_list.clear();
tone_list.clear();
iss >> word;
ToLowerCase(&word);
if (word2ids_.count(word)) {
SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.",
word.c_str(), line_num, line.c_str());
continue;
}
while (iss >> phone) {
token_list.push_back(std::move(phone));
}
if ((token_list.size() & 1) != 0) {
SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
exit(-1);
}
int32_t num_phones = token_list.size() / 2;
phone_list.reserve(num_phones);
tone_list.reserve(num_phones);
for (int32_t i = 0; i != num_phones; ++i) {
phone_list.push_back(std::move(token_list[i]));
tone_list.push_back(std::stoi(token_list[i + num_phones], nullptr));
if (tone_list.back() < 0 || tone_list.back() > 50) {
SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
exit(-1);
}
}
std::vector<int32_t> ids = ConvertTokensToIds(token2id_, phone_list);
if (ids.empty()) {
continue;
}
if (ids.size() != num_phones) {
SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
exit(-1);
}
std::vector<int64_t> ids64{ids.begin(), ids.end()};
word2ids_.insert(
{std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}});
}
word2ids_["呣"] = word2ids_["母"];
word2ids_["嗯"] = word2ids_["恩"];
}
private:
// lexicon.txt is saved in word2ids_
std::unordered_map<std::string, TokenIDs> word2ids_;
// tokens.txt is saved in token2id_
std::unordered_map<std::string, int32_t> token2id_;
OfflineTtsVitsModelMetaData meta_data_;
std::unique_ptr<cppjieba::Jieba> jieba_;
bool debug_ = false;
};
MeloTtsLexicon::~MeloTtsLexicon() = default;
MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
const std::string &tokens,
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data,
bool debug)
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
debug)) {}
std::vector<TokenIDs> MeloTtsLexicon::ConvertTextToTokenIds(
const std::string &text, const std::string & /*unused_voice = ""*/) const {
return impl_->ConvertTextToTokenIds(text);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/melo-tts-lexicon.h
//
// Copyright (c) 2022-2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
#define SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"
namespace sherpa_onnx {
class MeloTtsLexicon : public OfflineTtsFrontend {
public:
~MeloTtsLexicon() override;
MeloTtsLexicon(const std::string &lexicon, const std::string &tokens,
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text,
const std::string &unused_voice = "") const override;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
... ...
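A usage sketch for the new frontend in isolation, assuming the lexicon.txt, tokens.txt, and dict directory produced by the export scripts above (paths illustrative); in the runtime it is constructed by OfflineTtsVitsImpl from the model metadata:

#include <iostream>
#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"

int main() {
  sherpa_onnx::OfflineTtsVitsModelMetaData meta;  // defaults suffice for a lexicon-only test
  sherpa_onnx::MeloTtsLexicon lexicon(
      "./vits-melo-tts-zh_en/lexicon.txt",
      "./vits-melo-tts-zh_en/tokens.txt",
      "./vits-melo-tts-zh_en/dict", meta, /*debug=*/true);

  // Each returned TokenIDs carries parallel token and tone sequences,
  // split at sentence-final punctuation.
  for (const auto &s : lexicon.ConvertTextToTokenIds("永远相信,美好的事情即将发生。")) {
    std::cout << s.ToString() << "\n";
  }
  return 0;
}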
... ... @@ -94,8 +94,7 @@ OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend(
#endif
std::vector<std::vector<int64_t>>
OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
std::vector<TokenIDs> OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
const std::string &_text, const std::string & /*voice = ""*/) const {
// see
// https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87
... ... @@ -112,7 +111,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
std::u32string s = conv.from_bytes(text);
std::vector<std::vector<int64_t>> ans;
std::vector<TokenIDs> ans;
std::vector<int64_t> this_sentence;
if (add_blank) {
... ...
... ... @@ -41,7 +41,7 @@ class OfflineTtsCharacterFrontend : public OfflineTtsFrontend {
* If a frontend does not support splitting the text into
* sentences, the resulting vector contains only one subvector.
*/
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text, const std::string &voice = "") const override;
private:
... ...
// sherpa-onnx/csrc/offline-tts-frontend.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include <sstream>
#include <string>
namespace sherpa_onnx {
std::string TokenIDs::ToString() const {
std::ostringstream os;
os << "TokenIDs(";
os << "tokens=[";
std::string sep;
for (auto i : tokens) {
os << sep << i;
sep = ", ";
}
os << "], ";
os << "tones=[";
sep = {};
for (auto i : tones) {
os << sep << i;
sep = ", ";
}
os << "]";
os << ")";
return os.str();
}
} // namespace sherpa_onnx
... ...
... ... @@ -8,8 +8,28 @@
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
struct TokenIDs {
TokenIDs() = default;
/*implicit*/ TokenIDs(const std::vector<int64_t> &tokens) // NOLINT
: tokens{tokens} {}
TokenIDs(const std::vector<int64_t> &tokens,
const std::vector<int64_t> &tones)
: tokens{tokens}, tones{tones} {}
std::string ToString() const;
std::vector<int64_t> tokens;
// Used only in MeloTTS
std::vector<int64_t> tones;
};
class OfflineTtsFrontend {
public:
virtual ~OfflineTtsFrontend() = default;
... ... @@ -26,7 +46,7 @@ class OfflineTtsFrontend {
* If a frontend does not support splitting the text into sentences,
* the resulting vector contains only one subvector.
*/
virtual std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
virtual std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text, const std::string &voice = "") const = 0;
};
... ...
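The struct is intentionally implicit-constructible from a bare token vector, so existing frontends keep returning token-only results unchanged while tone-aware frontends fill both fields. A small sketch of the two construction paths:

#include <iostream>
#include <vector>
#include "sherpa-onnx/csrc/offline-tts-frontend.h"

int main() {
  // Implicit-conversion path used by frontends without tones.
  sherpa_onnx::TokenIDs plain = std::vector<int64_t>{5, 12, 7};
  // Tokens plus tones, as produced by the MeloTTS lexicon.
  sherpa_onnx::TokenIDs melo{{5, 12, 7}, {0, 3, 1}};

  std::cout << plain.ToString() << "\n";  // TokenIDs(tokens=[5, 12, 7], tones=[])
  std::cout << melo.ToString() << "\n";   // TokenIDs(tokens=[5, 12, 7], tones=[0, 3, 1])
  return 0;
}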
... ... @@ -22,6 +22,7 @@
#include "sherpa-onnx/csrc/jieba-lexicon.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
#include "sherpa-onnx/csrc/offline-tts-character-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
... ... @@ -174,26 +175,47 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}
}
std::vector<std::vector<int64_t>> x =
std::vector<TokenIDs> token_ids =
frontend_->ConvertTextToTokenIds(text, meta_data.voice);
if (x.empty() || (x.size() == 1 && x[0].empty())) {
if (token_ids.empty() ||
(token_ids.size() == 1 && token_ids[0].tokens.empty())) {
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
return {};
}
std::vector<std::vector<int64_t>> x;
std::vector<std::vector<int64_t>> tones;
x.reserve(token_ids.size());
for (auto &i : token_ids) {
x.push_back(std::move(i.tokens));
}
if (!token_ids[0].tones.empty()) {
tones.reserve(token_ids.size());
for (auto &i : token_ids) {
tones.push_back(std::move(i.tones));
}
}
// TODO(fangjun): add blank inside the frontend, not here
if (meta_data.add_blank && config_.model.vits.data_dir.empty() &&
meta_data.frontend != "characters") {
for (auto &k : x) {
k = AddBlank(k);
}
for (auto &k : tones) {
k = AddBlank(k);
}
}
int32_t x_size = static_cast<int32_t>(x.size());
if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
auto ans = Process(x, sid, speed);
auto ans = Process(x, tones, sid, speed);
if (callback) {
callback(ans.samples.data(), ans.samples.size(), 1.0);
}
... ... @@ -202,9 +224,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
// the input text is too long, we process sentences within it in batches
// to avoid OOM. Batch size is config_.max_num_sentences
std::vector<std::vector<int64_t>> batch;
std::vector<std::vector<int64_t>> batch_x;
std::vector<std::vector<int64_t>> batch_tones;
int32_t batch_size = config_.max_num_sentences;
batch.reserve(config_.max_num_sentences);
batch_x.reserve(config_.max_num_sentences);
batch_tones.reserve(config_.max_num_sentences);
int32_t num_batches = x_size / batch_size;
if (config_.model.debug) {
... ... @@ -221,12 +246,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
int32_t k = 0;
for (int32_t b = 0; b != num_batches && should_continue; ++b) {
batch.clear();
batch_x.clear();
batch_tones.clear();
for (int32_t i = 0; i != batch_size; ++i, ++k) {
batch.push_back(std::move(x[k]));
batch_x.push_back(std::move(x[k]));
if (!tones.empty()) {
batch_tones.push_back(std::move(tones[k]));
}
}
auto audio = Process(batch, sid, speed);
auto audio = Process(batch_x, batch_tones, sid, speed);
ans.sample_rate = audio.sample_rate;
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
... ... @@ -239,14 +269,19 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}
}
batch.clear();
batch_x.clear();
batch_tones.clear();
while (k < static_cast<int32_t>(x.size()) && should_continue) {
batch.push_back(std::move(x[k]));
batch_x.push_back(std::move(x[k]));
if (!tones.empty()) {
batch_tones.push_back(std::move(tones[k]));
}
++k;
}
if (!batch.empty()) {
auto audio = Process(batch, sid, speed);
if (!batch_x.empty()) {
auto audio = Process(batch_x, batch_tones, sid, speed);
ans.sample_rate = audio.sample_rate;
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
... ... @@ -308,6 +343,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
if (meta_data.frontend == "characters") {
frontend_ = std::make_unique<OfflineTtsCharacterFrontend>(
config_.model.vits.tokens, meta_data);
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty() &&
meta_data.is_melo_tts) {
frontend_ = std::make_unique<MeloTtsLexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
config_.model.vits.dict_dir, model_->GetMetaData(),
config_.model.debug);
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) {
frontend_ = std::make_unique<JiebaLexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
... ... @@ -344,6 +385,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}
GeneratedAudio Process(const std::vector<std::vector<int64_t>> &tokens,
const std::vector<std::vector<int64_t>> &tones,
int32_t sid, float speed) const {
int32_t num_tokens = 0;
for (const auto &k : tokens) {
... ... @@ -356,6 +398,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
x.insert(x.end(), k.begin(), k.end());
}
std::vector<int64_t> tone_list;
if (!tones.empty()) {
tone_list.reserve(num_tokens);
for (const auto &k : tones) {
tone_list.insert(tone_list.end(), k.begin(), k.end());
}
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
... ... @@ -363,7 +413,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
Ort::Value x_tensor = Ort::Value::CreateTensor(
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
Ort::Value audio = model_->Run(std::move(x_tensor), sid, speed);
Ort::Value tones_tensor{nullptr};
if (!tones.empty()) {
tones_tensor = Ort::Value::CreateTensor(memory_info, tone_list.data(),
tone_list.size(), x_shape.data(),
x_shape.size());
}
Ort::Value audio{nullptr};
if (tones.empty()) {
audio = model_->Run(std::move(x_tensor), sid, speed);
} else {
audio =
model_->Run(std::move(x_tensor), std::move(tones_tensor), sid, speed);
}
std::vector<int64_t> audio_shape =
audio.GetTensorTypeAndShapeInfo().GetShape();
... ...
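AddBlank is applied to the token IDs and the tone sequence identically so the two stay index-aligned after interleaving. The helper itself is outside this hunk; a minimal sketch of the usual interleaving, assuming blank id 0 (the VITS add_blank convention):

#include <cstdint>
#include <vector>

// [a, b] -> [blank, a, blank, b, blank]; applying the same transform to the
// tones keeps tones[i] paired with tokens[i].
static std::vector<int64_t> AddBlank(const std::vector<int64_t> &x,
                                     int64_t blank_id = 0) {
  std::vector<int64_t> ans(2 * x.size() + 1, blank_id);
  for (size_t i = 0; i != x.size(); ++i) {
    ans[2 * i + 1] = x[i];
  }
  return ans;
}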
... ... @@ -21,6 +21,7 @@ struct OfflineTtsVitsModelMetaData {
bool is_piper = false;
bool is_coqui = false;
bool is_icefall = false;
bool is_melo_tts = false;
// for Chinese TTS models from
// https://github.com/Plachtaa/VITS-fast-fine-tuning
... ... @@ -33,6 +34,10 @@ struct OfflineTtsVitsModelMetaData {
int32_t use_eos_bos = 0;
int32_t pad_id = 0;
// for melo tts
int32_t speaker_id = 0;
int32_t version = 0;
std::string punctuations;
std::string language;
std::string voice;
... ...
... ... @@ -45,6 +45,64 @@ class OfflineTtsVitsModel::Impl {
return RunVits(std::move(x), sid, speed);
}
Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid, float speed) {
// For MeloTTS, we hardcode sid to the one contained in the meta data
sid = meta_data_.speaker_id;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape();
if (x_shape[0] != 1) {
SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d",
static_cast<int32_t>(x_shape[0]));
exit(-1);
}
int64_t len = x_shape[1];
int64_t len_shape = 1;
Ort::Value x_length =
Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
int64_t scale_shape = 1;
float noise_scale = config_.vits.noise_scale;
float length_scale = config_.vits.length_scale;
float noise_scale_w = config_.vits.noise_scale_w;
if (speed != 1 && speed > 0) {
length_scale = 1. / speed;
}
Ort::Value noise_scale_tensor =
Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);
Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
memory_info, &length_scale, 1, &scale_shape, 1);
Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
memory_info, &noise_scale_w, 1, &scale_shape, 1);
Ort::Value sid_tensor =
Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1);
std::vector<Ort::Value> inputs;
inputs.reserve(7);
inputs.push_back(std::move(x));
inputs.push_back(std::move(x_length));
inputs.push_back(std::move(tones));
inputs.push_back(std::move(sid_tensor));
inputs.push_back(std::move(noise_scale_tensor));
inputs.push_back(std::move(length_scale_tensor));
inputs.push_back(std::move(noise_scale_w_tensor));
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
return std::move(out[0]);
}
const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; }
private:
... ... @@ -83,6 +141,10 @@ class OfflineTtsVitsModel::Impl {
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank",
0);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.speaker_id, "speaker_id",
0);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 0);
SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers");
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations,
"punctuation", "");
... ... @@ -115,6 +177,22 @@ class OfflineTtsVitsModel::Impl {
if (comment.find("icefall") != std::string::npos) {
meta_data_.is_icefall = true;
}
if (comment.find("melo") != std::string::npos) {
meta_data_.is_melo_tts = true;
int32_t expected_version = 2;
if (meta_data_.version < expected_version) {
SHERPA_ONNX_LOGE(
"Please download the latest MeloTTS model and retry. Current "
"version: %d. Expected version: %d",
meta_data_.version, expected_version);
exit(-1);
}
// NOTE(fangjun):
// version 0 is the first version
// version 2: add jieba=1 to the metadata
}
}
Ort::Value RunVitsPiperOrCoqui(Ort::Value x, int64_t sid, float speed) {
... ... @@ -269,6 +347,12 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/,
return impl_->Run(std::move(x), sid, speed);
}
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, Ort::Value tones,
int64_t sid /*= 0*/,
float speed /*= 1.0*/) {
return impl_->Run(std::move(x), std::move(tones), sid, speed);
}
const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const {
return impl_->GetMetaData();
}
... ...
... ... @@ -40,6 +40,10 @@ class OfflineTtsVitsModel {
*/
Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0);
// This is for MeloTTS
Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid = 0,
float speed = 1.0);
const OfflineTtsVitsModelMetaData &GetMetaData() const;
private:
... ...
... ... @@ -5,8 +5,8 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
#include <vector>
#include <string>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
... ... @@ -36,7 +36,6 @@ class OfflineWhisperDecoder {
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
};
} // namespace sherpa_onnx
... ...
... ... @@ -12,7 +12,8 @@
namespace sherpa_onnx {
void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) {
void OfflineWhisperGreedySearchDecoder::SetConfig(
const OfflineWhisperModelConfig &config) {
config_ = config;
}
... ... @@ -135,9 +136,9 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
const auto &id2lang = model_->GetID2Lang();
if (id2lang.count(initial_tokens[1])) {
ans[0].lang = id2lang.at(initial_tokens[1]);
ans[0].lang = id2lang.at(initial_tokens[1]);
} else {
ans[0].lang = "";
ans[0].lang = "";
}
ans[0].tokens = std::move(predicted_tokens);
... ...
... ... @@ -153,15 +153,21 @@ Ort::Value View(Ort::Value *v) {
}
}
template <typename T /*= float*/>
void Print1D(Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const float *d = v->GetTensorData<float>();
const T *d = v->GetTensorData<T>();
std::ostringstream os;
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
fprintf(stderr, "%.3f ", d[i]);
os << d[i] << " ";
}
fprintf(stderr, "\n");
os << "\n";
fprintf(stderr, "%s\n", os.str().c_str());
}
template void Print1D<int64_t>(Ort::Value *v);
template void Print1D<float>(Ort::Value *v);
template <typename T /*= float*/>
void Print2D(Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
... ...
... ... @@ -69,6 +69,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
Ort::Value View(Ort::Value *v);
// Print a 1-D tensor to stderr
template <typename T = float>
void Print1D(Ort::Value *v);
// Print a 2-D tensor to stderr
... ...
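With the explicit instantiations above, integer tensors (e.g. token or tone IDs) can now be dumped as well, not just float tensors. A quick sketch, mirroring the tensor-creation pattern used elsewhere in this PR:

#include <vector>
#include "sherpa-onnx/csrc/onnx-utils.h"

int main() {
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  std::vector<int64_t> tones = {0, 3, 0, 1, 0};
  std::vector<int64_t> shape = {static_cast<int64_t>(tones.size())};
  Ort::Value t = Ort::Value::CreateTensor(memory_info, tones.data(),
                                          tones.size(), shape.data(),
                                          shape.size());
  sherpa_onnx::Print1D<int64_t>(&t);  // prints: 0 3 0 1 0
  return 0;
}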
... ... @@ -214,7 +214,7 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon(
}
#endif
std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds(
std::vector<TokenIDs> PiperPhonemizeLexicon::ConvertTextToTokenIds(
const std::string &text, const std::string &voice /*= ""*/) const {
piper::eSpeakPhonemeConfig config;
... ... @@ -232,7 +232,7 @@ std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds(
piper::phonemize_eSpeak(text, config, phonemes);
}
std::vector<std::vector<int64_t>> ans;
std::vector<TokenIDs> ans;
std::vector<int64_t> phoneme_ids;
... ...
... ... @@ -30,7 +30,7 @@ class PiperPhonemizeLexicon : public OfflineTtsFrontend {
const OfflineTtsVitsModelMetaData &meta_data);
#endif
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text, const std::string &voice = "") const override;
private:
... ...
... ... @@ -31,8 +31,8 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
api.ReleaseStatus(status);
}
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
const std::string &provider_str,
static Ort::SessionOptions GetSessionOptionsImpl(
int32_t num_threads, const std::string &provider_str,
const ProviderConfig *provider_config = nullptr) {
Provider p = StringToProvider(provider_str);
... ... @@ -67,8 +67,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
}
case Provider::kTRT: {
if (provider_config == nullptr) {
SHERPA_ONNX_LOGE("Tensorrt support for Online models ony,"
"Must be extended for offline and others");
SHERPA_ONNX_LOGE(
"Tensorrt support for Online models ony,"
"Must be extended for offline and others");
exit(1);
}
auto trt_config = provider_config->trt_config;
... ... @@ -84,29 +85,27 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
std::to_string(trt_config.trt_max_partition_iterations);
auto trt_min_subgraph_size =
std::to_string(trt_config.trt_min_subgraph_size);
auto trt_fp16_enable =
std::to_string(trt_config.trt_fp16_enable);
auto trt_fp16_enable = std::to_string(trt_config.trt_fp16_enable);
auto trt_detailed_build_log =
std::to_string(trt_config.trt_detailed_build_log);
auto trt_engine_cache_enable =
std::to_string(trt_config.trt_engine_cache_enable);
auto trt_timing_cache_enable =
std::to_string(trt_config.trt_timing_cache_enable);
auto trt_dump_subgraphs =
std::to_string(trt_config.trt_dump_subgraphs);
auto trt_dump_subgraphs = std::to_string(trt_config.trt_dump_subgraphs);
std::vector<TrtPairs> trt_options = {
{"device_id", device_id.c_str()},
{"trt_max_workspace_size", trt_max_workspace_size.c_str()},
{"trt_max_partition_iterations", trt_max_partition_iterations.c_str()},
{"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
{"trt_fp16_enable", trt_fp16_enable.c_str()},
{"trt_detailed_build_log", trt_detailed_build_log.c_str()},
{"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
{"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
{"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
{"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}
};
{"device_id", device_id.c_str()},
{"trt_max_workspace_size", trt_max_workspace_size.c_str()},
{"trt_max_partition_iterations",
trt_max_partition_iterations.c_str()},
{"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
{"trt_fp16_enable", trt_fp16_enable.c_str()},
{"trt_detailed_build_log", trt_detailed_build_log.c_str()},
{"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
{"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
{"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
{"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}};
// TODO: TRT configs
// "trt_int8_enable"
// "trt_int8_use_native_calibration_table"
... ... @@ -151,9 +150,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
if (provider_config != nullptr) {
options.device_id = provider_config->device;
options.cudnn_conv_algo_search =
OrtCudnnConvAlgoSearch(provider_config->cuda_config
.cudnn_conv_algo_search);
options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch(
provider_config->cuda_config.cudnn_conv_algo_search);
} else {
options.device_id = 0;
// Default OrtCudnnConvAlgoSearchExhaustive is extremely slow
... ... @@ -219,22 +217,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads,
config.provider_config.provider, &config.provider_config);
config.provider_config.provider,
&config.provider_config);
}
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
const std::string &model_type) {
const std::string &model_type) {
/*
Transducer models: only the encoder runs with TensorRT;
the decoder and joiner run with CUDA.
*/
if(config.provider_config.provider == "trt" &&
if (config.provider_config.provider == "trt" &&
(model_type == "decoder" || model_type == "joiner")) {
return GetSessionOptionsImpl(config.num_threads,
"cuda", &config.provider_config);
return GetSessionOptionsImpl(config.num_threads, "cuda",
&config.provider_config);
}
return GetSessionOptionsImpl(config.num_threads,
config.provider_config.provider, &config.provider_config);
config.provider_config.provider,
&config.provider_config);
}
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
... ...
... ... @@ -5,6 +5,8 @@
#ifndef SHERPA_ONNX_CSRC_SESSION_H_
#define SHERPA_ONNX_CSRC_SESSION_H_
#include <string>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
... ... @@ -25,7 +27,7 @@ namespace sherpa_onnx {
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
const std::string &model_type);
const std::string &model_type);
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);
... ...
... ... @@ -6,6 +6,7 @@
#include <algorithm>
#include <unordered_map>
#include <utility>
#include "Eigen/Dense"
#include "sherpa-onnx/csrc/macros.h"
... ...
... ... @@ -11,7 +11,7 @@
namespace sherpa_onnx {
TEST(UTF8, Case1) {
std::string hello = "你好, 早上好!世界. hello!。Hallo";
std::string hello = "你好, 早上好!世界. hello!。Hallo! how are you?";
std::vector<std::string> ss = SplitUtf8(hello);
for (const auto &s : ss) {
std::cout << s << "\n";
... ...