Fangjun Kuang
Committed by GitHub

Support Chinese vits models (#368)

@@ -9,6 +9,10 @@ log() { @@ -9,6 +9,10 @@ log() {
9 } 9 }
10 10
11 log "Offline TTS test" 11 log "Offline TTS test"
  12 +# test waves are saved in ./tts
  13 +mkdir ./tts
  14 +
  15 +log "vits-ljs test"
12 16
13 wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx 17 wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx
14 wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt 18 wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
@@ -18,14 +22,48 @@ python3 ./python-api-examples/offline-tts.py \ @@ -18,14 +22,48 @@ python3 ./python-api-examples/offline-tts.py \
18 --vits-model=./vits-ljs.onnx \ 22 --vits-model=./vits-ljs.onnx \
19 --vits-lexicon=./lexicon.txt \ 23 --vits-lexicon=./lexicon.txt \
20 --vits-tokens=./tokens.txt \ 24 --vits-tokens=./tokens.txt \
21 - --output-filename=./tts.wav \ 25 + --output-filename=./tts/vits-ljs.wav \
22 'liliana, the most beautiful and lovely assistant of our team!' 26 'liliana, the most beautiful and lovely assistant of our team!'
23 27
24 -ls -lh ./tts.wav  
25 -file ./tts.wav 28 +ls -lh ./tts
26 29
27 rm -v vits-ljs.onnx ./lexicon.txt ./tokens.txt 30 rm -v vits-ljs.onnx ./lexicon.txt ./tokens.txt
28 31
  32 +log "vits-vctk test"
  33 +wget -qq https://huggingface.co/csukuangfj/vits-vctk/resolve/main/vits-vctk.onnx
  34 +wget -qq https://huggingface.co/csukuangfj/vits-vctk/resolve/main/lexicon.txt
  35 +wget -qq https://huggingface.co/csukuangfj/vits-vctk/resolve/main/tokens.txt
  36 +
  37 +for sid in 0 10 90; do
  38 + python3 ./python-api-examples/offline-tts.py \
  39 + --vits-model=./vits-vctk.onnx \
  40 + --vits-lexicon=./lexicon.txt \
  41 + --vits-tokens=./tokens.txt \
  42 + --sid=$sid \
  43 + --output-filename=./tts/vits-vctk-${sid}.wav \
  44 + 'liliana, the most beautiful and lovely assistant of our team!'
  45 +done
  46 +
  47 +rm -v vits-vctk.onnx ./lexicon.txt ./tokens.txt
  48 +
  49 +log "vits-zh-aishell3"
  50 +
  51 +wget -qq https://huggingface.co/csukuangfj/vits-zh-aishell3/resolve/main/vits-aishell3.onnx
  52 +wget -qq https://huggingface.co/csukuangfj/vits-zh-aishell3/resolve/main/lexicon.txt
  53 +wget -qq https://huggingface.co/csukuangfj/vits-zh-aishell3/resolve/main/tokens.txt
  54 +
  55 +for sid in 0 10 90; do
  56 + python3 ./python-api-examples/offline-tts.py \
  57 + --vits-model=./vits-aishell3.onnx \
  58 + --vits-lexicon=./lexicon.txt \
  59 + --vits-tokens=./tokens.txt \
  60 + --sid=$sid \
  61 + --output-filename=./tts/vits-aishell3-${sid}.wav \
  62 + '林美丽最美丽'
  63 +done
  64 +
  65 +rm -v vits-aishell3.onnx ./lexicon.txt ./tokens.txt
  66 +
29 mkdir -p /tmp/icefall-models 67 mkdir -p /tmp/icefall-models
30 dir=/tmp/icefall-models 68 dir=/tmp/icefall-models
31 69
@@ -69,4 +69,4 @@ jobs: @@ -69,4 +69,4 @@ jobs:
69 - uses: actions/upload-artifact@v3 69 - uses: actions/upload-artifact@v3
70 with: 70 with:
71 name: tts-generated-test-files 71 name: tts-generated-test-files
72 - path: tts.wav 72 + path: tts
1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR) 1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
2 project(sherpa-onnx) 2 project(sherpa-onnx)
3 3
4 -set(SHERPA_ONNX_VERSION "1.8.1") 4 +set(SHERPA_ONNX_VERSION "1.8.2")
5 5
6 # Disable warning about 6 # Disable warning about
7 # 7 #
@@ -175,6 +175,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET) @@ -175,6 +175,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
175 include(asio) 175 include(asio)
176 endif() 176 endif()
177 177
  178 +include(utfcpp)
  179 +
178 add_subdirectory(sherpa-onnx) 180 add_subdirectory(sherpa-onnx)
179 181
180 if(SHERPA_ONNX_ENABLE_C_API) 182 if(SHERPA_ONNX_ENABLE_C_API)
@@ -6,7 +6,7 @@ function(download_kaldi_decoder) @@ -6,7 +6,7 @@ function(download_kaldi_decoder)
6 set(kaldi_decoder_HASH "SHA256=98bf445a5b7961ccf3c3522317d900054eaadb6a9cdcf4531e7d9caece94a56d") 6 set(kaldi_decoder_HASH "SHA256=98bf445a5b7961ccf3c3522317d900054eaadb6a9cdcf4531e7d9caece94a56d")
7 7
8 set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE) 8 set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
9 - set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE) 9 + set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
10 set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE) 10 set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
11 11
12 # If you don't have access to the Internet, 12 # If you don't have access to the Internet,
1 function(download_kaldi_native_fbank) 1 function(download_kaldi_native_fbank)
2 include(FetchContent) 2 include(FetchContent)
3 3
4 - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.1.tar.gz")  
5 - set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.1.tar.gz")  
6 - set(kaldi_native_fbank_HASH "SHA256=c7676f319fa97e8c8bca6018792de120895dcfe122fa9b4bff00f8f9165348e7") 4 + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.5.tar.gz")
  5 + set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.5.tar.gz")
  6 + set(kaldi_native_fbank_HASH "SHA256=dce0cb3bc6fece5d8053d8780cb4ce22da57cb57ebec332641661521a0425283")
7 7
8 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) 8 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
9 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) 9 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
12 # If you don't have access to the Internet, 12 # If you don't have access to the Internet,
13 # please pre-download kaldi-native-fbank 13 # please pre-download kaldi-native-fbank
14 set(possible_file_locations 14 set(possible_file_locations
15 - $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.1.tar.gz  
16 - ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.1.tar.gz  
17 - ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.1.tar.gz  
18 - /tmp/kaldi-native-fbank-1.18.1.tar.gz  
19 - /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.1.tar.gz 15 + $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.5.tar.gz
  16 + ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.5.tar.gz
  17 + ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.5.tar.gz
  18 + /tmp/kaldi-native-fbank-1.18.5.tar.gz
  19 + /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.5.tar.gz
20 ) 20 )
21 21
22 foreach(f IN LISTS possible_file_locations) 22 foreach(f IN LISTS possible_file_locations)
  1 +function(download_utfcpp)
  2 + include(FetchContent)
  3 +
  4 + set(utfcpp_URL "https://github.com/nemtrif/utfcpp/archive/refs/tags/v3.2.5.tar.gz")
  5 + set(utfcpp_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/utfcpp-3.2.5.tar.gz")
  6 + set(utfcpp_HASH "SHA256=14fd1b3c466814cb4c40771b7f207b61d2c7a0aa6a5e620ca05c00df27f25afd")
  7 +
  8 + # If you don't have access to the Internet,
  9 + # please pre-download utfcpp
  10 + set(possible_file_locations
  11 + $ENV{HOME}/Downloads/utfcpp-3.2.5.tar.gz
  12 + ${PROJECT_SOURCE_DIR}/utfcpp-3.2.5.tar.gz
  13 + ${PROJECT_BINARY_DIR}/utfcpp-3.2.5.tar.gz
  14 + /tmp/utfcpp-3.2.5.tar.gz
  15 + /star-fj/fangjun/download/github/utfcpp-3.2.5.tar.gz
  16 + )
  17 +
  18 + foreach(f IN LISTS possible_file_locations)
  19 + if(EXISTS ${f})
  20 + set(utfcpp_URL "${f}")
  21 + file(TO_CMAKE_PATH "${utfcpp_URL}" utfcpp_URL)
  22 + message(STATUS "Found local downloaded utfcpp: ${utfcpp_URL}")
  23 + set(utfcpp_URL2)
  24 + break()
  25 + endif()
  26 + endforeach()
  27 +
  28 + FetchContent_Declare(utfcpp
  29 + URL
  30 + ${utfcpp_URL}
  31 + ${utfcpp_URL2}
  32 + URL_HASH ${utfcpp_HASH}
  33 + )
  34 +
  35 + FetchContent_GetProperties(utfcpp)
  36 + if(NOT utfcpp_POPULATED)
  37 + message(STATUS "Downloading utfcpp from ${utfcpp_URL}")
  38 + FetchContent_Populate(utfcpp)
  39 + endif()
  40 + message(STATUS "utfcpp is downloaded to ${utfcpp_SOURCE_DIR}")
  41 + # add_subdirectory(${utfcpp_SOURCE_DIR} ${utfcpp_BINARY_DIR} EXCLUDE_FROM_ALL)
  42 + include_directories(${utfcpp_SOURCE_DIR})
  43 +endfunction()
  44 +
  45 +download_utfcpp()
@@ -20,9 +20,14 @@ python3 ./python-api-examples/offline-tts.py \ @@ -20,9 +20,14 @@ python3 ./python-api-examples/offline-tts.py \
20 --vits-tokens=./tokens.txt \ 20 --vits-tokens=./tokens.txt \
21 --output-filename=./generated.wav \ 21 --output-filename=./generated.wav \
22 'liliana, the most beautiful and lovely assistant of our team!' 22 'liliana, the most beautiful and lovely assistant of our team!'
  23 +
  24 +Please see
  25 +https://k2-fsa.github.io/sherpa/onnx/tts/index.html
  26 +for details.
23 """ 27 """
24 28
25 import argparse 29 import argparse
  30 +import time
26 31
27 import sherpa_onnx 32 import sherpa_onnx
28 import soundfile as sf 33 import soundfile as sf
@@ -115,7 +120,14 @@ def main(): @@ -115,7 +120,14 @@ def main():
115 ) 120 )
116 ) 121 )
117 tts = sherpa_onnx.OfflineTts(tts_config) 122 tts = sherpa_onnx.OfflineTts(tts_config)
  123 +
  124 + start = time.time()
118 audio = tts.generate(args.text, sid=args.sid) 125 audio = tts.generate(args.text, sid=args.sid)
  126 + end = time.time()
  127 + elapsed_seconds = end - start
  128 + audio_duration = len(audio.samples) / audio.sample_rate
  129 + real_time_factor = elapsed_seconds / audio_duration
  130 +
119 sf.write( 131 sf.write(
120 args.output_filename, 132 args.output_filename,
121 audio.samples, 133 audio.samples,
@@ -124,6 +136,9 @@ def main(): @@ -124,6 +136,9 @@ def main():
124 ) 136 )
125 print(f"Saved to {args.output_filename}") 137 print(f"Saved to {args.output_filename}")
126 print(f"The text is '{args.text}'") 138 print(f"The text is '{args.text}'")
  139 + print(f"Elapsed seconds: {elapsed_seconds:.3f}")
  140 + print(f"Audio duration in seconds: {audio_duration:.3f}")
  141 + print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
127 142
128 143
129 if __name__ == "__main__": 144 if __name__ == "__main__":
@@ -331,6 +331,7 @@ if(SHERPA_ONNX_ENABLE_TESTS) @@ -331,6 +331,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
331 stack-test.cc 331 stack-test.cc
332 transpose-test.cc 332 transpose-test.cc
333 unbind-test.cc 333 unbind-test.cc
  334 + utfcpp-test.cc
334 ) 335 )
335 336
336 function(sherpa_onnx_add_test source) 337 function(sherpa_onnx_add_test source)
@@ -76,9 +76,105 @@ static std::vector<int32_t> ConvertTokensToIds( @@ -76,9 +76,105 @@ static std::vector<int32_t> ConvertTokensToIds(
76 } 76 }
77 77
78 Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, 78 Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
79 - const std::string &punctuations) { 79 + const std::string &punctuations, const std::string &language) {
  80 + InitLanguage(language);
  81 + InitTokens(tokens);
  82 + InitLexicon(lexicon);
  83 + InitPunctuations(punctuations);
  84 +}
  85 +
  86 +std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
  87 + const std::string &text) const {
  88 + switch (language_) {
  89 + case Language::kEnglish:
  90 + return ConvertTextToTokenIdsEnglish(text);
  91 + case Language::kChinese:
  92 + return ConvertTextToTokenIdsChinese(text);
  93 + default:
  94 + SHERPA_ONNX_LOGE("Unknown language: %d", static_cast<int32_t>(language_));
  95 + exit(-1);
  96 + }
  97 +
  98 + return {};
  99 +}
  100 +
  101 +std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
  102 + const std::string &text) const {
  103 + std::vector<std::string> words = SplitUtf8(text);
  104 +
  105 + std::vector<int64_t> ans;
  106 +
  107 + ans.push_back(token2id_.at("sil"));
  108 +
  109 + for (const auto &w : words) {
  110 + if (!word2ids_.count(w)) {
  111 + SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
  112 + continue;
  113 + }
  114 +
  115 + const auto &token_ids = word2ids_.at(w);
  116 + ans.insert(ans.end(), token_ids.begin(), token_ids.end());
  117 + }
  118 + ans.push_back(token2id_.at("sil"));
  119 + ans.push_back(token2id_.at("eos"));
  120 + return ans;
  121 +}
  122 +
  123 +std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
  124 + const std::string &_text) const {
  125 + std::string text(_text);
  126 + ToLowerCase(&text);
  127 +
  128 + std::vector<std::string> words = SplitUtf8(text);
  129 +
  130 + std::vector<int64_t> ans;
  131 + for (const auto &w : words) {
  132 + if (punctuations_.count(w)) {
  133 + ans.push_back(token2id_.at(w));
  134 + continue;
  135 + }
  136 +
  137 + if (!word2ids_.count(w)) {
  138 + SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
  139 + continue;
  140 + }
  141 +
  142 + const auto &token_ids = word2ids_.at(w);
  143 + ans.insert(ans.end(), token_ids.begin(), token_ids.end());
  144 + if (blank_ != -1) {
  145 + ans.push_back(blank_);
  146 + }
  147 + }
  148 +
  149 + if (blank_ != -1 && !ans.empty()) {
  150 + // remove the last blank
  151 + ans.resize(ans.size() - 1);
  152 + }
  153 +
  154 + return ans;
  155 +}
  156 +
  157 +void Lexicon::InitTokens(const std::string &tokens) {
80 token2id_ = ReadTokens(tokens); 158 token2id_ = ReadTokens(tokens);
81 - blank_ = token2id_.at(" "); 159 + if (token2id_.count(" ")) {
  160 + blank_ = token2id_.at(" ");
  161 + }
  162 +}
  163 +
  164 +void Lexicon::InitLanguage(const std::string &_lang) {
  165 + std::string lang(_lang);
  166 + ToLowerCase(&lang);
  167 + if (lang == "english") {
  168 + language_ = Language::kEnglish;
  169 + } else if (lang == "chinese") {
  170 + language_ = Language::kChinese;
  171 + } else {
  172 + SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str());
  173 + exit(-1);
  174 + }
  175 +}
  176 +
  177 +void Lexicon::InitLexicon(const std::string &lexicon) {
82 std::ifstream is(lexicon); 178 std::ifstream is(lexicon);
83 179
84 std::string word; 180 std::string word;
@@ -109,8 +205,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, @@ -109,8 +205,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
109 } 205 }
110 word2ids_.insert({std::move(word), std::move(ids)}); 206 word2ids_.insert({std::move(word), std::move(ids)});
111 } 207 }
  208 +}
112 209
113 - // process punctuations 210 +void Lexicon::InitPunctuations(const std::string &punctuations) {
114 std::vector<std::string> punctuation_list; 211 std::vector<std::string> punctuation_list;
115 SplitStringToVector(punctuations, " ", false, &punctuation_list); 212 SplitStringToVector(punctuations, " ", false, &punctuation_list);
116 for (auto &s : punctuation_list) { 213 for (auto &s : punctuation_list) {
@@ -118,46 +215,4 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, @@ -118,46 +215,4 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
118 } 215 }
119 } 216 }
120 217
121 -std::vector<int64_t> Lexicon::ConvertTextToTokenIds(  
122 - const std::string &_text) const {  
123 - std::string text(_text);  
124 - ToLowerCase(&text);  
125 -  
126 - std::vector<std::string> words;  
127 - SplitStringToVector(text, " ", false, &words);  
128 -  
129 - std::vector<int64_t> ans;  
130 - for (auto w : words) {  
131 - std::vector<int64_t> prefix;  
132 - while (!w.empty() && punctuations_.count(std::string(1, w[0]))) {  
133 - // if w begins with a punctuation  
134 - prefix.push_back(token2id_.at(std::string(1, w[0])));  
135 - w = std::string(w.begin() + 1, w.end());  
136 - }  
137 -  
138 - std::vector<int64_t> suffix;  
139 - while (!w.empty() && punctuations_.count(std::string(1, w.back()))) {  
140 - suffix.push_back(token2id_.at(std::string(1, w.back())));  
141 - w = std::string(w.begin(), w.end() - 1);  
142 - }  
143 -  
144 - if (!word2ids_.count(w)) {  
145 - SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());  
146 - continue;  
147 - }  
148 -  
149 - const auto &token_ids = word2ids_.at(w);  
150 - ans.insert(ans.end(), prefix.begin(), prefix.end());  
151 - ans.insert(ans.end(), token_ids.begin(), token_ids.end());  
152 - ans.insert(ans.end(), suffix.rbegin(), suffix.rend());  
153 - ans.push_back(blank_);  
154 - }  
155 -  
156 - if (!ans.empty()) {  
157 - ans.resize(ans.size() - 1);  
158 - }  
159 -  
160 - return ans;  
161 -}  
162 -  
163 } // namespace sherpa_onnx 218 } // namespace sherpa_onnx
@@ -13,18 +13,40 @@ @@ -13,18 +13,40 @@
13 13
14 namespace sherpa_onnx { 14 namespace sherpa_onnx {
15 15
  16 +// TODO(fangjun): Refactor it to an abstract class
16 class Lexicon { 17 class Lexicon {
17 public: 18 public:
18 Lexicon(const std::string &lexicon, const std::string &tokens, 19 Lexicon(const std::string &lexicon, const std::string &tokens,
19 - const std::string &punctuations); 20 + const std::string &punctuations, const std::string &language);
20 21
21 std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const; 22 std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;
22 23
23 private: 24 private:
  25 + std::vector<int64_t> ConvertTextToTokenIdsEnglish(
  26 + const std::string &text) const;
  27 +
  28 + std::vector<int64_t> ConvertTextToTokenIdsChinese(
  29 + const std::string &text) const;
  30 +
  31 + void InitLanguage(const std::string &lang);
  32 + void InitTokens(const std::string &tokens);
  33 + void InitLexicon(const std::string &lexicon);
  34 + void InitPunctuations(const std::string &punctuations);
  35 +
  36 + private:
  37 + enum class Language {
  38 + kEnglish,
  39 + kChinese,
  40 + kUnknown,
  41 + };
  42 +
  43 + private:
24 std::unordered_map<std::string, std::vector<int32_t>> word2ids_; 44 std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
25 std::unordered_set<std::string> punctuations_; 45 std::unordered_set<std::string> punctuations_;
26 std::unordered_map<std::string, int32_t> token2id_; 46 std::unordered_map<std::string, int32_t> token2id_;
27 - int32_t blank_; // ID for the blank token 47 + int32_t blank_ = -1; // ID for the blank token
  48 + Language language_;
  49 + //
28 }; 50 };
29 51
30 } // namespace sherpa_onnx 52 } // namespace sherpa_onnx
@@ -21,7 +21,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -21,7 +21,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
21 explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config) 21 explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
22 : model_(std::make_unique<OfflineTtsVitsModel>(config.model)), 22 : model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
23 lexicon_(config.model.vits.lexicon, config.model.vits.tokens, 23 lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
24 - model_->Punctuations()) {} 24 + model_->Punctuations(), model_->Language()) {}
25 25
26 GeneratedAudio Generate(const std::string &text, 26 GeneratedAudio Generate(const std::string &text,
27 int64_t sid = 0) const override { 27 int64_t sid = 0) const override {
@@ -84,6 +84,7 @@ class OfflineTtsVitsModel::Impl { @@ -84,6 +84,7 @@ class OfflineTtsVitsModel::Impl {
84 bool AddBlank() const { return add_blank_; } 84 bool AddBlank() const { return add_blank_; }
85 85
86 std::string Punctuations() const { return punctuations_; } 86 std::string Punctuations() const { return punctuations_; }
  87 + std::string Language() const { return language_; }
87 88
88 private: 89 private:
89 void Init(void *model_data, size_t model_data_length) { 90 void Init(void *model_data, size_t model_data_length) {
@@ -108,6 +109,7 @@ class OfflineTtsVitsModel::Impl { @@ -108,6 +109,7 @@ class OfflineTtsVitsModel::Impl {
108 SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); 109 SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
109 SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers"); 110 SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
110 SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); 111 SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
  112 + SHERPA_ONNX_READ_META_DATA_STR(language_, "language");
111 } 113 }
112 114
113 private: 115 private:
@@ -128,6 +130,7 @@ class OfflineTtsVitsModel::Impl { @@ -128,6 +130,7 @@ class OfflineTtsVitsModel::Impl {
128 int32_t add_blank_; 130 int32_t add_blank_;
129 int32_t n_speakers_; 131 int32_t n_speakers_;
130 std::string punctuations_; 132 std::string punctuations_;
  133 + std::string language_;
131 }; 134 };
132 135
133 OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) 136 OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
@@ -147,4 +150,6 @@ std::string OfflineTtsVitsModel::Punctuations() const { @@ -147,4 +150,6 @@ std::string OfflineTtsVitsModel::Punctuations() const {
147 return impl_->Punctuations(); 150 return impl_->Punctuations();
148 } 151 }
149 152
  153 +std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); }
  154 +
150 } // namespace sherpa_onnx 155 } // namespace sherpa_onnx
@@ -38,6 +38,7 @@ class OfflineTtsVitsModel { @@ -38,6 +38,7 @@ class OfflineTtsVitsModel {
38 bool AddBlank() const; 38 bool AddBlank() const;
39 39
40 std::string Punctuations() const; 40 std::string Punctuations() const;
  41 + std::string Language() const;
41 42
42 private: 43 private:
43 class Impl; 44 class Impl;
@@ -8,12 +8,16 @@ @@ -8,12 +8,16 @@
8 #include <assert.h> 8 #include <assert.h>
9 9
10 #include <algorithm> 10 #include <algorithm>
  11 +#include <cctype>
11 #include <limits> 12 #include <limits>
12 #include <sstream> 13 #include <sstream>
13 #include <string> 14 #include <string>
14 #include <unordered_map> 15 #include <unordered_map>
  16 +#include <utility>
15 #include <vector> 17 #include <vector>
16 18
  19 +#include "source/utf8.h"
  20 +
17 // This file is copied/modified from 21 // This file is copied/modified from
18 // https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.cc 22 // https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.cc
19 23
@@ -158,4 +162,57 @@ template bool SplitStringToFloats(const std::string &full, const char *delim, @@ -158,4 +162,57 @@ template bool SplitStringToFloats(const std::string &full, const char *delim,
158 bool omit_empty_strings, 162 bool omit_empty_strings,
159 std::vector<double> *out); 163 std::vector<double> *out);
160 164
  165 +std::vector<std::string> SplitUtf8(const std::string &text) {
  166 + char *begin = const_cast<char *>(text.c_str());
  167 + char *end = begin + text.size();
  168 +
  169 + std::vector<std::string> ans;
  170 + std::string buf;
  171 +
  172 + while (begin < end) {
  173 + uint32_t code = utf8::next(begin, end);
  174 +
  175 + // 1. is punctuation
  176 + if (std::ispunct(code)) {
  177 + if (!buf.empty()) {
  178 + ans.push_back(std::move(buf));
  179 + }
  180 +
  181 + char s[5] = {0};
  182 + utf8::append(code, s);
  183 + ans.push_back(s);
  184 + continue;
  185 + }
  186 +
  187 + // 2. is space
  188 + if (std::isspace(code)) {
  189 + if (!buf.empty()) {
  190 + ans.push_back(std::move(buf));
  191 + }
  192 + continue;
  193 + }
  194 +
  195 + // 3. is alpha
  196 + if (std::isalpha(code)) {
  197 + buf.push_back(code);
  198 + continue;
  199 + }
  200 +
  201 + if (!buf.empty()) {
  202 + ans.push_back(std::move(buf));
  203 + }
  204 +
  205 + // for others
  206 +
  207 + char s[5] = {0};
  208 + utf8::append(code, s);
  209 + ans.push_back(s);
  210 + }
  211 +
  212 + if (!buf.empty()) {
  213 + ans.push_back(std::move(buf));
  214 + }
  215 +
  216 + return ans;
  217 +}
161 } // namespace sherpa_onnx 218 } // namespace sherpa_onnx
@@ -119,6 +119,8 @@ bool SplitStringToFloats(const std::string &full, const char *delim, @@ -119,6 +119,8 @@ bool SplitStringToFloats(const std::string &full, const char *delim,
119 template <typename T> 119 template <typename T>
120 bool ConvertStringToReal(const std::string &str, T *out); 120 bool ConvertStringToReal(const std::string &str, T *out);
121 121
  122 +std::vector<std::string> SplitUtf8(const std::string &text);
  123 +
122 } // namespace sherpa_onnx 124 } // namespace sherpa_onnx
123 125
124 #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ 126 #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
  1 +// sherpa-onnx/csrc/utfcpp-test.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include <cctype>
  6 +#include <string>
  7 +
  8 +#include "gtest/gtest.h"
  9 +#include "sherpa-onnx/csrc/text-utils.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +TEST(UTF8, Case1) {
  14 + std::string hello = "你好, 早上好!世界. hello!。Hallo";
  15 + std::vector<std::string> ss = SplitUtf8(hello);
  16 + for (const auto &s : ss) {
  17 + std::cout << s << "\n";
  18 + }
  19 +}
  20 +
  21 +} // namespace sherpa_onnx