Fangjun Kuang
Committed by GitHub

support reading rule FST for Android TTS (#410)

@@ -34,6 +34,11 @@ jobs: @@ -34,6 +34,11 @@ jobs:
34 with: 34 with:
35 fetch-depth: 0 35 fetch-depth: 0
36 36
  37 + - name: ccache
  38 + uses: hendrikmuhs/ccache-action@v1.2
  39 + with:
  40 + key: ${{ matrix.os }}-android
  41 +
37 - name: Display NDK HOME 42 - name: Display NDK HOME
38 shell: bash 43 shell: bash
39 run: | 44 run: |
@@ -61,6 +66,10 @@ jobs: @@ -61,6 +66,10 @@ jobs:
61 - name: build APK 66 - name: build APK
62 shell: bash 67 shell: bash
63 run: | 68 run: |
  69 + export CMAKE_CXX_COMPILER_LAUNCHER=ccache
  70 + export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
  71 + cmake --version
  72 +
64 export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME 73 export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
65 ./build-apk-tts.sh 74 ./build-apk-tts.sh
66 75
@@ -70,12 +79,14 @@ jobs: @@ -70,12 +79,14 @@ jobs:
70 ls -lh ./apks/ 79 ls -lh ./apks/
71 du -h -d1 . 80 du -h -d1 .
72 81
73 - # - uses: actions/upload-artifact@v3  
74 - # with:  
75 - # name: tts-apk  
76 - # path: ./apks/*.apk 82 + - uses: actions/upload-artifact@v3
  83 + if: false
  84 + with:
  85 + name: tts-apk
  86 + path: ./apks/*.apk
77 87
78 - name: Publish to huggingface 88 - name: Publish to huggingface
  89 + if: true
79 env: 90 env:
80 HF_TOKEN: ${{ secrets.HF_TOKEN }} 91 HF_TOKEN: ${{ secrets.HF_TOKEN }}
81 uses: nick-fields/retry@v2 92 uses: nick-fields/retry@v2
@@ -92,7 +103,9 @@ jobs: @@ -92,7 +103,9 @@ jobs:
92 103
93 git clone https://huggingface.co/csukuangfj/sherpa-onnx-apk huggingface 104 git clone https://huggingface.co/csukuangfj/sherpa-onnx-apk huggingface
94 cd huggingface 105 cd huggingface
  106 + git fetch
95 git pull 107 git pull
  108 + git merge -m "merge remote" --ff origin main
96 109
97 mkdir -p tts 110 mkdir -p tts
98 cp -v ../apks/*.apk ./tts/ 111 cp -v ../apks/*.apk ./tts/
@@ -28,6 +28,12 @@ jobs: @@ -28,6 +28,12 @@ jobs:
28 - uses: actions/checkout@v4 28 - uses: actions/checkout@v4
29 with: 29 with:
30 fetch-depth: 0 30 fetch-depth: 0
  31 +
  32 + - name: ccache
  33 + uses: hendrikmuhs/ccache-action@v1.2
  34 + with:
  35 + key: ${{ matrix.os }}-android
  36 +
31 - name: Display NDK HOME 37 - name: Display NDK HOME
32 shell: bash 38 shell: bash
33 run: | 39 run: |
@@ -37,6 +43,10 @@ jobs: @@ -37,6 +43,10 @@ jobs:
37 - name: build APK 43 - name: build APK
38 shell: bash 44 shell: bash
39 run: | 45 run: |
  46 + export CMAKE_CXX_COMPILER_LAUNCHER=ccache
  47 + export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
  48 + cmake --version
  49 +
40 export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME 50 export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
41 ./build-apk-vad.sh 51 ./build-apk-vad.sh
42 ./build-apk-two-pass.sh 52 ./build-apk-two-pass.sh
@@ -101,12 +101,14 @@ class MainActivity : AppCompatActivity() { @@ -101,12 +101,14 @@ class MainActivity : AppCompatActivity() {
101 fun initTts() { 101 fun initTts() {
102 var modelDir :String? 102 var modelDir :String?
103 var modelName :String? 103 var modelName :String?
  104 + var ruleFsts: String?
104 105
105 // The purpose of such a design is to make the CI test easier 106 // The purpose of such a design is to make the CI test easier
106 // Please see 107 // Please see
107 // https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/apk/generate-tts-apk-script.py 108 // https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/apk/generate-tts-apk-script.py
108 modelDir = null 109 modelDir = null
109 modelName = null 110 modelName = null
  111 + ruleFsts = null
110 112
111 // Example 1: 113 // Example 1:
112 // modelDir = "vits-vctk" 114 // modelDir = "vits-vctk"
@@ -116,7 +118,12 @@ class MainActivity : AppCompatActivity() { @@ -116,7 +118,12 @@ class MainActivity : AppCompatActivity() {
116 // modelDir = "vits-piper-en_US-lessac-medium" 118 // modelDir = "vits-piper-en_US-lessac-medium"
117 // modelName = "en_US-lessac-medium.onnx" 119 // modelName = "en_US-lessac-medium.onnx"
118 120
119 - val config = getOfflineTtsConfig(modelDir = modelDir!!, modelName = modelName!!)!! 121 + // Example 3:
  122 + // modelDir = "vits-zh-aishell3"
  123 + // modelName = "vits-aishell3.onnx"
  124 + // ruleFsts = "vits-zh-aishell3/rule.fst"
  125 +
  126 + val config = getOfflineTtsConfig(modelDir = modelDir!!, modelName = modelName!!, ruleFsts = ruleFsts ?: "")!!
120 tts = OfflineTts(assetManager = application.assets, config = config) 127 tts = OfflineTts(assetManager = application.assets, config = config)
121 } 128 }
122 } 129 }
@@ -21,6 +21,7 @@ data class OfflineTtsModelConfig( @@ -21,6 +21,7 @@ data class OfflineTtsModelConfig(
21 21
22 data class OfflineTtsConfig( 22 data class OfflineTtsConfig(
23 var model: OfflineTtsModelConfig, 23 var model: OfflineTtsModelConfig,
  24 + var ruleFsts: String = "",
24 ) 25 )
25 26
26 class GeneratedAudio( 27 class GeneratedAudio(
@@ -116,7 +117,7 @@ class OfflineTts( @@ -116,7 +117,7 @@ class OfflineTts(
116 // please refer to 117 // please refer to
117 // https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/index.html 118 // https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/index.html
118 // to download models 119 // to download models
119 -fun getOfflineTtsConfig(modelDir: String, modelName: String): OfflineTtsConfig? { 120 +fun getOfflineTtsConfig(modelDir: String, modelName: String, ruleFsts: String): OfflineTtsConfig? {
120 return OfflineTtsConfig( 121 return OfflineTtsConfig(
121 model = OfflineTtsModelConfig( 122 model = OfflineTtsModelConfig(
122 vits = OfflineTtsVitsModelConfig( 123 vits = OfflineTtsVitsModelConfig(
@@ -125,8 +126,9 @@ fun getOfflineTtsConfig(modelDir: String, modelName: String): OfflineTtsConfig? @@ -125,8 +126,9 @@ fun getOfflineTtsConfig(modelDir: String, modelName: String): OfflineTtsConfig?
125 tokens = "$modelDir/tokens.txt" 126 tokens = "$modelDir/tokens.txt"
126 ), 127 ),
127 numThreads = 2, 128 numThreads = 2,
128 - debug = false, 129 + debug = true,
129 provider = "cpu", 130 provider = "cpu",
130 - ) 131 + ),
  132 + ruleFsts=ruleFsts,
131 ) 133 )
132 } 134 }
1 function(download_kaldifst) 1 function(download_kaldifst)
2 include(FetchContent) 2 include(FetchContent)
3 3
4 - set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.8.tar.gz")  
5 - set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.8.tar.gz")  
6 - set(kaldifst_HASH "SHA256=94613923568ef9a240ba1059b8b9dfe3082daad794934635d99e66248a6687b5") 4 + set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.9.tar.gz")
  5 + set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.9.tar.gz")
  6 + set(kaldifst_HASH "SHA256=8c653021491dca54c38ab659565edfab391418a79ae87099257863cd5664dd39")
7 7
8 # If you don't have access to the Internet, 8 # If you don't have access to the Internet,
9 # please pre-download kaldifst 9 # please pre-download kaldifst
10 set(possible_file_locations 10 set(possible_file_locations
11 - $ENV{HOME}/Downloads/kaldifst-1.7.8.tar.gz  
12 - ${PROJECT_SOURCE_DIR}/kaldifst-1.7.8.tar.gz  
13 - ${PROJECT_BINARY_DIR}/kaldifst-1.7.8.tar.gz  
14 - /tmp/kaldifst-1.7.8.tar.gz  
15 - /star-fj/fangjun/download/github/kaldifst-1.7.8.tar.gz 11 + $ENV{HOME}/Downloads/kaldifst-1.7.9.tar.gz
  12 + ${PROJECT_SOURCE_DIR}/kaldifst-1.7.9.tar.gz
  13 + ${PROJECT_BINARY_DIR}/kaldifst-1.7.9.tar.gz
  14 + /tmp/kaldifst-1.7.9.tar.gz
  15 + /star-fj/fangjun/download/github/kaldifst-1.7.9.tar.gz
16 ) 16 )
17 17
18 foreach(f IN LISTS possible_file_locations) 18 foreach(f IN LISTS possible_file_locations)
@@ -8,7 +8,7 @@ @@ -8,7 +8,7 @@
8 # Inside the $ANDROID_NDK directory, you can find a binary ndk-build 8 # Inside the $ANDROID_NDK directory, you can find a binary ndk-build
9 # and some other files like the file "build/cmake/android.toolchain.cmake" 9 # and some other files like the file "build/cmake/android.toolchain.cmake"
10 10
11 -set -e 11 +set -ex
12 12
13 log() { 13 log() {
14 # This function is from espnet 14 # This function is from espnet
@@ -43,6 +43,7 @@ wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/$model_name @@ -43,6 +43,7 @@ wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/$model_name
43 wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/lexicon.txt 43 wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/lexicon.txt
44 wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/tokens.txt 44 wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/tokens.txt
45 wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/MODEL_CARD 2>/dev/null || true 45 wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/MODEL_CARD 2>/dev/null || true
  46 +wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/rule.fst 2>/dev/null || true
46 47
47 popd 48 popd
48 # Now we are at the project root directory 49 # Now we are at the project root directory
@@ -51,6 +52,11 @@ git checkout . @@ -51,6 +52,11 @@ git checkout .
51 pushd android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx 52 pushd android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx
52 sed -i.bak s/"modelDir = null"/"modelDir = \"$model_dir\""/ ./MainActivity.kt 53 sed -i.bak s/"modelDir = null"/"modelDir = \"$model_dir\""/ ./MainActivity.kt
53 sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt 54 sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt
  55 +{% if tts_model.rule_fsts %}
  56 + rule_fsts={{ tts_model.rule_fsts }}
  57 + sed -i.bak s%"ruleFsts = null"%"ruleFsts = \"$rule_fsts\""% ./MainActivity.kt
  58 +{% endif %}
  59 +
54 git diff 60 git diff
55 popd 61 popd
56 62
1 #!/usr/bin/env python3 1 #!/usr/bin/env python3
2 2
  3 +import argparse
3 from dataclasses import dataclass 4 from dataclasses import dataclass
  5 +from typing import List, Optional
4 6
5 import jinja2 7 import jinja2
6 -from typing import List  
7 -import argparse  
8 8
9 9
10 def get_args(): 10 def get_args():
@@ -29,12 +29,65 @@ class TtsModel: @@ -29,12 +29,65 @@ class TtsModel:
29 model_dir: str 29 model_dir: str
30 model_name: str 30 model_name: str
31 lang: str # en, zh, fr, de, etc. 31 lang: str # en, zh, fr, de, etc.
  32 + rule_fsts: Optional[List[str]] = (None,)
32 33
33 34
34 def get_all_models() -> List[TtsModel]: 35 def get_all_models() -> List[TtsModel]:
35 return [ 36 return [
  37 + # Chinese
  38 + TtsModel(
  39 + model_dir="vits-zh-aishell3",
  40 + model_name="vits-aishell3.onnx",
  41 + lang="zh",
  42 + rule_fsts="vits-zh-aishell3/rule.fst",
  43 + ),
  44 + TtsModel(
  45 + model_dir="vits-zh-hf-doom",
  46 + model_name="doom.onnx",
  47 + lang="zh",
  48 + rule_fsts="vits-zh-hf-doom/rule.fst",
  49 + ),
  50 + TtsModel(
  51 + model_dir="vits-zh-hf-echo",
  52 + model_name="echo.onnx",
  53 + lang="zh",
  54 + rule_fsts="vits-zh-hf-echo/rule.fst",
  55 + ),
  56 + TtsModel(
  57 + model_dir="vits-zh-hf-zenyatta",
  58 + model_name="zenyatta.onnx",
  59 + lang="zh",
  60 + rule_fsts="vits-zh-hf-zenyatta/rule.fst",
  61 + ),
  62 + TtsModel(
  63 + model_dir="vits-zh-hf-abyssinvoker",
  64 + model_name="abyssinvoker.onnx",
  65 + lang="zh",
  66 + rule_fsts="vits-zh-hf-abyssinvoker/rule.fst",
  67 + ),
  68 + TtsModel(
  69 + model_dir="vits-zh-hf-keqing",
  70 + model_name="keqing.onnx",
  71 + lang="zh",
  72 + rule_fsts="vits-zh-hf-keqing/rule.fst",
  73 + ),
  74 + TtsModel(
  75 + model_dir="vits-zh-hf-eula",
  76 + model_name="eula.onnx",
  77 + lang="zh",
  78 + rule_fsts="vits-zh-hf-eula/rule.fst",
  79 + ),
  80 + TtsModel(
  81 + model_dir="vits-zh-hf-bronya",
  82 + model_name="bronya.onnx",
  83 + lang="zh",
  84 + rule_fsts="vits-zh-hf-bronya/rule.fst",
  85 + ),
36 TtsModel( 86 TtsModel(
37 - model_dir="vits-zh-aishell3", model_name="vits-aishell3.onnx", lang="zh" 87 + model_dir="vits-zh-hf-theresa",
  88 + model_name="theresa.onnx",
  89 + lang="zh",
  90 + rule_fsts="vits-zh-hf-theresa/rule.fst",
38 ), 91 ),
39 # English (US) 92 # English (US)
40 # fmt: off 93 # fmt: off
@@ -196,8 +196,14 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese( @@ -196,8 +196,14 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
196 196
197 std::vector<int64_t> ans; 197 std::vector<int64_t> ans;
198 198
199 - auto sil = token2id_.at("sil");  
200 - auto eos = token2id_.at("eos"); 199 + int32_t sil = -1;
  200 + int32_t eos = -1;
  201 + if (token2id_.count("sil")) {
  202 + sil = token2id_.at("sil");
  203 + eos = token2id_.at("eos");
  204 + } else {
  205 + sil = 0;
  206 + }
201 207
202 ans.push_back(sil); 208 ans.push_back(sil);
203 209
@@ -216,7 +222,9 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese( @@ -216,7 +222,9 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
216 ans.insert(ans.end(), token_ids.begin(), token_ids.end()); 222 ans.insert(ans.end(), token_ids.begin(), token_ids.end());
217 } 223 }
218 ans.push_back(sil); 224 ans.push_back(sil);
  225 + if (eos != -1) {
219 ans.push_back(eos); 226 ans.push_back(eos);
  227 + }
220 return ans; 228 return ans;
221 } 229 }
222 230
@@ -10,15 +10,17 @@ @@ -10,15 +10,17 @@
10 #include <vector> 10 #include <vector>
11 11
12 #if __ANDROID_API__ >= 9 12 #if __ANDROID_API__ >= 9
  13 +#include <strstream>
  14 +
13 #include "android/asset_manager.h" 15 #include "android/asset_manager.h"
14 #include "android/asset_manager_jni.h" 16 #include "android/asset_manager_jni.h"
15 #endif 17 #endif
16 -  
17 #include "kaldifst/csrc/text-normalizer.h" 18 #include "kaldifst/csrc/text-normalizer.h"
18 #include "sherpa-onnx/csrc/lexicon.h" 19 #include "sherpa-onnx/csrc/lexicon.h"
19 #include "sherpa-onnx/csrc/macros.h" 20 #include "sherpa-onnx/csrc/macros.h"
20 #include "sherpa-onnx/csrc/offline-tts-impl.h" 21 #include "sherpa-onnx/csrc/offline-tts-impl.h"
21 #include "sherpa-onnx/csrc/offline-tts-vits-model.h" 22 #include "sherpa-onnx/csrc/offline-tts-vits-model.h"
  23 +#include "sherpa-onnx/csrc/onnx-utils.h"
22 #include "sherpa-onnx/csrc/text-utils.h" 24 #include "sherpa-onnx/csrc/text-utils.h"
23 25
24 namespace sherpa_onnx { 26 namespace sherpa_onnx {
@@ -52,7 +54,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -52,7 +54,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
52 model_->Punctuations(), model_->Language(), config.model.debug, 54 model_->Punctuations(), model_->Language(), config.model.debug,
53 model_->IsPiper()) { 55 model_->IsPiper()) {
54 if (!config.rule_fsts.empty()) { 56 if (!config.rule_fsts.empty()) {
55 - SHERPA_ONNX_LOGE("TODO(fangjun): Implement rule FST for Android"); 57 + std::vector<std::string> files;
  58 + SplitStringToVector(config.rule_fsts, ",", false, &files);
  59 + tn_list_.reserve(files.size());
  60 + for (const auto &f : files) {
  61 + if (config.model.debug) {
  62 + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
  63 + }
  64 + auto buf = ReadFile(mgr, f);
  65 + std::istrstream is(buf.data(), buf.size());
  66 + tn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(is));
  67 + }
56 } 68 }
57 } 69 }
58 #endif 70 #endif
@@ -566,6 +566,13 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { @@ -566,6 +566,13 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
566 ans.model.provider = p; 566 ans.model.provider = p;
567 env->ReleaseStringUTFChars(s, p); 567 env->ReleaseStringUTFChars(s, p);
568 568
  569 + // for ruleFsts
  570 + fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;");
  571 + s = (jstring)env->GetObjectField(config, fid);
  572 + p = env->GetStringUTFChars(s, nullptr);
  573 + ans.rule_fsts = p;
  574 + env->ReleaseStringUTFChars(s, p);
  575 +
569 return ans; 576 return ans;
570 } 577 }
571 578