Tts.kt
4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
// Copyright (c) 2023 Xiaomi Corporation
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
/**
 * Settings describing a VITS text-to-speech model.
 *
 * Only [model] and [tokens] are mandatory; every other property has a default.
 *
 * NOTE(review): this object is handed to native code (see `OfflineTts`), which
 * presumably reads the properties by name — confirm against the JNI sources
 * before renaming or retyping anything here.
 *
 * @property model path of the model file; `getOfflineTtsConfig` in this file
 *   builds it as `"$modelDir/$modelName"`.
 * @property lexicon path of the lexicon file; empty by default.
 * @property tokens path of the tokens file; `getOfflineTtsConfig` in this file
 *   builds it as `"$modelDir/tokens.txt"`.
 * @property dataDir empty by default. Presumably a directory with extra
 *   front-end data — TODO confirm against the native implementation.
 * @property dictDir empty by default. Presumably a dictionary directory —
 *   TODO confirm against the native implementation.
 * @property noiseScale defaults to `0.667f`; presumably the VITS noise scale.
 * @property noiseScaleW defaults to `0.8f`; presumably the VITS duration-noise scale.
 * @property lengthScale defaults to `1.0f`; presumably the VITS length scale.
 */
data class OfflineTtsVitsModelConfig(
    var model: String,
    var lexicon: String = "",
    var tokens: String,
    var dataDir: String = "",
    var dictDir: String = "",
    var noiseScale: Float = 0.667f,
    var noiseScaleW: Float = 0.8f,
    var lengthScale: Float = 1.0f,
)
/**
 * Model-level settings for the offline TTS engine.
 *
 * NOTE(review): this object is handed to native code (see `OfflineTts`), which
 * presumably reads the properties by name — confirm before renaming anything.
 *
 * @property vits the VITS model settings.
 * @property numThreads defaults to `1`; presumably the number of threads the
 *   native runtime uses for inference — TODO confirm.
 * @property debug defaults to `false`; presumably enables extra native logging.
 * @property provider defaults to `"cpu"`; presumably selects the execution
 *   provider of the native runtime.
 */
data class OfflineTtsModelConfig(
    var vits: OfflineTtsVitsModelConfig,
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
/**
 * Top-level configuration used to create an `OfflineTts` instance.
 *
 * NOTE(review): this object is handed to native code (see `OfflineTts`), which
 * presumably reads the properties by name — confirm before renaming anything.
 *
 * @property model the model settings.
 * @property ruleFsts empty by default. Presumably path(s) of rule FST files
 *   used for text normalization — TODO confirm the expected format.
 * @property ruleFars empty by default. Presumably path(s) of rule FAR
 *   archives — TODO confirm the expected format.
 * @property maxNumSentences defaults to `1`; presumably how many sentences the
 *   native engine processes per batch — TODO confirm.
 */
data class OfflineTtsConfig(
    var model: OfflineTtsModelConfig,
    var ruleFsts: String = "",
    var ruleFars: String = "",
    var maxNumSentences: Int = 1,
)
/**
 * Audio produced by the TTS engine: raw samples plus their sample rate.
 *
 * This class declares a native method but does not load the native library
 * itself; in this file the library is loaded by the companion object of
 * `OfflineTts`, so [save] can only work once that class has been initialized.
 *
 * @property samples the audio samples.
 * @property sampleRate the sample rate of [samples].
 */
class GeneratedAudio(
    val samples: FloatArray,
    val sampleRate: Int,
) {
    /**
     * Passes [samples] and [sampleRate] to the native `saveImpl` so they are
     * written to [filename].
     *
     * @return the value reported by the native call — presumably `true` on
     *   success (confirm against the JNI implementation).
     */
    fun save(filename: String): Boolean {
        return saveImpl(
            filename = filename,
            samples = samples,
            sampleRate = sampleRate,
        )
    }

    // Implemented in native code.
    private external fun saveImpl(
        filename: String,
        samples: FloatArray,
        sampleRate: Int
    ): Boolean
}
/**
 * JVM-side wrapper around an offline text-to-speech engine implemented in the
 * `sherpa-onnx-jni` native library.
 *
 * A native engine is created from [config] when the object is constructed and
 * is referenced through an opaque handle. The handle is destroyed by [free] or
 * [release] (or by the finalizer as a fallback) and can be re-created with
 * [allocate]. Calling [sampleRate], [numSpeakers], [generate] or
 * [generateWithCallback] while no handle is held throws
 * [IllegalStateException] instead of passing a null handle to native code.
 *
 * No synchronization is performed by this class.
 *
 * @param assetManager when non-null the engine is created through the native
 *   `newFromAsset`, otherwise through `newFromFile`. Presumably the former
 *   resolves the paths in [config] inside the application's assets — confirm
 *   against the JNI implementation.
 * @property config configuration handed to the native constructor; it is read
 *   again every time [allocate] actually creates a new engine.
 */
class OfflineTts(
    assetManager: AssetManager? = null,
    var config: OfflineTtsConfig,
) {
    // Opaque handle of the native engine; 0 means no native object is held.
    private var ptr: Long = createNative(assetManager)

    /** Returns the sample rate reported by the native engine. */
    fun sampleRate(): Int = getSampleRate(requireHandle())

    /** Returns the number of speakers reported by the native engine. */
    fun numSpeakers(): Int = getNumSpeakers(requireHandle())

    /**
     * Synthesizes [text] and returns the resulting audio.
     *
     * @param text the text to synthesize.
     * @param sid speaker id; presumably in `0 until numSpeakers()` — confirm.
     * @param speed forwarded to the native engine; presumably values above 1
     *   produce faster speech — confirm against the native implementation.
     * @throws IllegalStateException if the native engine has been released.
     */
    fun generate(
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): GeneratedAudio =
        generateImpl(requireHandle(), text = text, sid = sid, speed = speed).toGeneratedAudio()

    /**
     * Same as [generate], additionally handing [callback] to the native engine.
     *
     * @param callback invoked by native code with chunks of samples.
     *   NOTE(review): the meaning of its `Int` result is defined by the native
     *   side (presumably "continue or stop") — confirm before relying on it.
     * @throws IllegalStateException if the native engine has been released.
     */
    fun generateWithCallback(
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f,
        callback: (samples: FloatArray) -> Int
    ): GeneratedAudio =
        generateWithCallbackImpl(
            requireHandle(),
            text = text,
            sid = sid,
            speed = speed,
            callback = callback
        ).toGeneratedAudio()

    /**
     * Re-creates the native engine from the current [config] if none is held.
     * Does nothing when an engine already exists.
     */
    fun allocate(assetManager: AssetManager? = null) {
        if (ptr == 0L) {
            ptr = createNative(assetManager)
        }
    }

    /** Destroys the native engine if one is held; safe to call repeatedly. */
    fun free() {
        if (ptr != 0L) {
            delete(ptr)
            // Reset so a second free()/finalize() cannot delete the same handle twice.
            ptr = 0L
        }
    }

    // Fallback clean-up for instances that were never freed explicitly.
    protected fun finalize() {
        free()
    }

    /** Alias of [free]. */
    fun release() = free()

    // Creates a native engine from [config], from assets when a manager is given.
    private fun createNative(assetManager: AssetManager?): Long =
        if (assetManager != null) {
            newFromAsset(assetManager, config)
        } else {
            newFromFile(config)
        }

    // Returns the live native handle or fails fast when there is none.
    private fun requireHandle(): Long {
        check(ptr != 0L) {
            "OfflineTts holds no native engine (released or not created); call allocate() first"
        }
        return ptr
    }

    // Unpacks the two-entry array returned by the native generate functions.
    private fun Array<Any>.toGeneratedAudio(): GeneratedAudio =
        GeneratedAudio(
            samples = this[0] as FloatArray,
            sampleRate = this[1] as Int
        )

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: OfflineTtsConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineTtsConfig,
    ): Long

    private external fun delete(ptr: Long)
    private external fun getSampleRate(ptr: Long): Int
    private external fun getNumSpeakers(ptr: Long): Int

    // The returned array has two entries:
    //  - the first entry is a 1-D float array containing audio samples.
    //    Each sample is normalized to the range [-1, 1]
    //  - the second entry is the sample rate
    private external fun generateImpl(
        ptr: Long,
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): Array<Any>

    // Returns the same two-entry array as generateImpl.
    private external fun generateWithCallbackImpl(
        ptr: Long,
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f,
        callback: (samples: FloatArray) -> Int
    ): Array<Any>

    companion object {
        init {
            // Load the JNI library that implements the external functions above.
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
/**
 * Builds an [OfflineTtsConfig] for a VITS model whose files live under [modelDir].
 *
 * Please refer to
 * https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/index.html
 * to download models.
 *
 * @param modelDir directory prefix that is joined with `/` in front of
 *   [modelName], [lexicon] and `tokens.txt`.
 * @param modelName model file name, relative to [modelDir].
 * @param lexicon lexicon file name, relative to [modelDir]. An empty string
 *   means "no lexicon": the resulting config then keeps an empty lexicon path
 *   instead of the meaningless `"$modelDir/"`.
 * @param dataDir copied unchanged into the VITS config.
 * @param dictDir copied unchanged into the VITS config.
 * @param ruleFsts copied unchanged into the returned config.
 * @param ruleFars copied unchanged into the returned config.
 * @param numThreads value for the model config's `numThreads`; defaults to `2`.
 * @param debug value for the model config's `debug`; defaults to `true`.
 * @param provider value for the model config's `provider`; defaults to `"cpu"`.
 * @return the assembled configuration.
 */
fun getOfflineTtsConfig(
    modelDir: String,
    modelName: String,
    lexicon: String,
    dataDir: String,
    dictDir: String,
    ruleFsts: String,
    ruleFars: String,
    numThreads: Int = 2,
    debug: Boolean = true,
    provider: String = "cpu",
): OfflineTtsConfig {
    // Only prefix the lexicon when one was actually given; otherwise an empty
    // name would turn into the directory path "$modelDir/".
    val lexiconPath = if (lexicon.isEmpty()) "" else "$modelDir/$lexicon"

    return OfflineTtsConfig(
        model = OfflineTtsModelConfig(
            vits = OfflineTtsVitsModelConfig(
                model = "$modelDir/$modelName",
                lexicon = lexiconPath,
                tokens = "$modelDir/tokens.txt",
                dataDir = dataDir,
                dictDir = dictDir,
            ),
            numThreads = numThreads,
            debug = debug,
            provider = provider,
        ),
        ruleFsts = ruleFsts,
        ruleFars = ruleFars,
    )
}