Fangjun Kuang
Committed by GitHub

Add JNI interface and Kotlin API examples for TTS. (#381)

  1 +hs_err*
  2 +main.jar
  3 +vits-zh-aishell3
@@ -3,6 +3,28 @@ package com.k2fsa.sherpa.onnx @@ -3,6 +3,28 @@ package com.k2fsa.sherpa.onnx
3 import android.content.res.AssetManager 3 import android.content.res.AssetManager
4 4
/** Entry point of the examples: runs the text-to-speech demo, then the speech-recognition demo. */
fun main() {
    testTts()
    testAsr()
}
  9 +
/**
 * Text-to-speech demo: synthesizes one Mandarin sentence with speaker 99 at
 * 1.2x speed and writes the result to `99.wav` in the working directory.
 *
 * The model, lexicon, and tokens files are read from `./vits-zh-aishell3/`.
 */
fun testTts() {
    val config = OfflineTtsConfig(
        model = OfflineTtsModelConfig(
            vits = OfflineTtsVitsModelConfig(
                model = "./vits-zh-aishell3/vits-aishell3.onnx",
                lexicon = "./vits-zh-aishell3/lexicon.txt",
                tokens = "./vits-zh-aishell3/tokens.txt",
            ),
            numThreads = 1,
            debug = true,
        )
    )
    val tts = OfflineTts(config = config)
    try {
        val audio = tts.generate(text = "林美丽最美丽!", sid = 99, speed = 1.2f)
        val filename = "99.wav"
        // save() reports failure through its return value, so surface it.
        if (!audio.save(filename = filename)) {
            println("Failed to save $filename")
        }
    } finally {
        // Release the native engine now instead of waiting for finalization.
        tts.free()
    }
}
  26 +
  27 +fun testAsr() {
6 var featConfig = FeatureConfig( 28 var featConfig = FeatureConfig(
7 sampleRate = 16000, 29 sampleRate = 16000,
8 featureDim = 80, 30 featureDim = 80,
  1 +// Copyright (c) 2023 Xiaomi Corporation
  2 +package com.k2fsa.sherpa.onnx
  3 +
  4 +import android.content.res.AssetManager
  5 +
/**
 * Settings of a VITS text-to-speech model.
 *
 * The JNI layer reads these properties by name and type, so renaming a
 * property or changing its type requires a matching change in the native code.
 *
 * @property model Path of the model file.
 * @property lexicon Path of the lexicon file.
 * @property tokens Path of the tokens file.
 * @property noiseScale Forwarded to the native `noise_scale` setting.
 * @property noiseScaleW Forwarded to the native `noise_scale_w` setting.
 * @property lengthScale Forwarded to the native `length_scale` setting.
 */
data class OfflineTtsVitsModelConfig(
    var model: String,
    var lexicon: String,
    var tokens: String,
    var noiseScale: Float = 0.667f,
    var noiseScaleW: Float = 0.8f,
    var lengthScale: Float = 1.0f,
)
  14 +
/**
 * Model-level settings of the offline text-to-speech engine.
 *
 * The JNI layer reads these properties by name and type.
 *
 * @property vits Settings of the VITS model.
 * @property numThreads Forwarded to the native `num_threads` setting.
 * @property debug Forwarded to the native `debug` flag.
 * @property provider Forwarded to the native `provider` setting; defaults to `"cpu"`.
 */
data class OfflineTtsModelConfig(
    var vits: OfflineTtsVitsModelConfig,
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
  21 +
/**
 * Top-level configuration handed to [OfflineTts].
 *
 * @property model Model-level settings; read by name from the JNI layer.
 */
data class OfflineTtsConfig(
    var model: OfflineTtsModelConfig,
)
  25 +
/**
 * Audio produced by the text-to-speech engine.
 *
 * @property samples Audio samples; the generator documents them as normalized to [-1, 1].
 * @property sampleRate Sample rate of [samples].
 */
class GeneratedAudio(
    val samples: FloatArray,
    val sampleRate: Int,
) {
    /**
     * Writes the audio to [filename] through the native wave writer.
     *
     * @return the status reported by the native writer.
     */
    fun save(filename: String): Boolean =
        saveImpl(filename = filename, samples = samples, sampleRate = sampleRate)

    private external fun saveImpl(
        filename: String,
        samples: FloatArray,
        sampleRate: Int
    ): Boolean

    companion object {
        init {
            // saveImpl is implemented in the same native library as OfflineTts.
            // Load it here too, so an instance built directly from samples can be
            // saved even if OfflineTts has not been initialized yet. Loading an
            // already-loaded library is a no-op.
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
  38 +
/**
 * Kotlin wrapper around the native offline text-to-speech engine.
 *
 * The native engine is created on construction: through [assetManager] when
 * one is given, otherwise from the file paths in [config]. Call [free] to
 * release it deterministically; it is also released on finalization.
 *
 * @param assetManager Asset manager used to read the model files, or `null`
 *   to read them from the file system.
 * @property config Engine configuration, also used by [allocate].
 */
class OfflineTts(
    assetManager: AssetManager? = null,
    var config: OfflineTtsConfig,
) {
    // Address of the native engine; 0 means no engine is currently allocated.
    private var ptr: Long = createNative(assetManager)

    /**
     * Synthesizes [text].
     *
     * @param sid Speaker ID passed to the native engine.
     * @param speed Speech speed passed to the native engine.
     * @throws IllegalStateException if the engine has been released with [free]
     *   and not re-created with [allocate].
     */
    fun generate(
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): GeneratedAudio {
        // A zero handle would be dereferenced by the native side and crash the process.
        check(ptr != 0L) { "OfflineTts has been freed; call allocate() before generate()" }

        val result = generateImpl(ptr, text = text, sid = sid, speed = speed)
        return GeneratedAudio(
            samples = result[0] as FloatArray,
            sampleRate = result[1] as Int,
        )
    }

    /** Re-creates the native engine from [config] if it is not currently allocated. */
    fun allocate(assetManager: AssetManager? = null) {
        if (ptr == 0L) {
            ptr = createNative(assetManager)
        }
    }

    /** Releases the native engine. Safe to call more than once. */
    fun free() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    protected fun finalize() {
        // Go through free() so the handle is zeroed and never deleted twice.
        free()
    }

    // Creates the native engine, choosing the loader based on whether an asset manager is given.
    private fun createNative(assetManager: AssetManager?): Long =
        if (assetManager != null) {
            new(assetManager, config)
        } else {
            newFromFile(config)
        }

    private external fun new(
        assetManager: AssetManager,
        config: OfflineTtsConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineTtsConfig,
    ): Long

    private external fun delete(ptr: Long)

    // The returned array has two entries:
    //  - the first entry is a 1-D float array containing audio samples.
    //    Each sample is normalized to the range [-1, 1]
    //  - the second entry is the sample rate
    external fun generateImpl(
        ptr: Long,
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): Array<Any>

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
@@ -6,21 +6,24 @@ @@ -6,21 +6,24 @@
6 6
7 set -e 7 set -e
8 8
  9 +
9 cd .. 10 cd ..
10 mkdir -p build 11 mkdir -p build
11 cd build 12 cd build
12 13
# Build the JNI library only when it is not already present.
# The shared-library suffix depends on the platform (.dylib on macOS, .so on
# Linux), so look for both; checking only the .dylib would trigger a rebuild
# on every run on Linux.
if [ ! -f ../build/lib/libsherpa-onnx-jni.dylib ] && [ ! -f ../build/lib/libsherpa-onnx-jni.so ]; then
  cmake \
    -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
    -DSHERPA_ONNX_ENABLE_TESTS=OFF \
    -DSHERPA_ONNX_ENABLE_CHECK=OFF \
    -DBUILD_SHARED_LIBS=ON \
    -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
    -DSHERPA_ONNX_ENABLE_JNI=ON \
    ..

  make -j4
  ls -lh lib
fi
24 27
25 export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH 28 export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH
26 29
@@ -31,7 +34,7 @@ if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then @@ -31,7 +34,7 @@ if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
31 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 34 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
32 fi 35 fi
33 36
34 -kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt 37 +kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt Tts.kt
35 38
36 ls -lh main.jar 39 ls -lh main.jar
37 40
@@ -10,7 +10,15 @@ @@ -10,7 +10,15 @@
10 #include <sstream> 10 #include <sstream>
11 #include <utility> 11 #include <utility>
12 12
  13 +#if __ANDROID_API__ >= 9
  14 +#include <strstream>
  15 +
  16 +#include "android/asset_manager.h"
  17 +#include "android/asset_manager_jni.h"
  18 +#endif
  19 +
13 #include "sherpa-onnx/csrc/macros.h" 20 #include "sherpa-onnx/csrc/macros.h"
  21 +#include "sherpa-onnx/csrc/onnx-utils.h"
14 #include "sherpa-onnx/csrc/text-utils.h" 22 #include "sherpa-onnx/csrc/text-utils.h"
15 23
16 namespace sherpa_onnx { 24 namespace sherpa_onnx {
@@ -22,11 +30,9 @@ static void ToLowerCase(std::string *in_out) { @@ -22,11 +30,9 @@ static void ToLowerCase(std::string *in_out) {
22 30
23 // Note: We don't use SymbolTable here since tokens may contain a blank 31 // Note: We don't use SymbolTable here since tokens may contain a blank
24 // in the first column 32 // in the first column
25 -static std::unordered_map<std::string, int32_t> ReadTokens(  
26 - const std::string &tokens) { 33 +static std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) {
27 std::unordered_map<std::string, int32_t> token2id; 34 std::unordered_map<std::string, int32_t> token2id;
28 35
29 - std::ifstream is(tokens);  
30 std::string line; 36 std::string line;
31 37
32 std::string sym; 38 std::string sym;
@@ -80,11 +86,43 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, @@ -80,11 +86,43 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
80 bool debug /*= false*/) 86 bool debug /*= false*/)
81 : debug_(debug) { 87 : debug_(debug) {
82 InitLanguage(language); 88 InitLanguage(language);
83 - InitTokens(tokens);  
84 - InitLexicon(lexicon); 89 +
  90 + {
  91 + std::ifstream is(tokens);
  92 + InitTokens(is);
  93 + }
  94 +
  95 + {
  96 + std::ifstream is(lexicon);
  97 + InitLexicon(is);
  98 + }
  99 +
85 InitPunctuations(punctuations); 100 InitPunctuations(punctuations);
86 } 101 }
87 102
#if __ANDROID_API__ >= 9
// Android variant of the constructor: the tokens and lexicon files are read
// through the asset manager instead of being opened as regular files.
Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
                 const std::string &tokens, const std::string &punctuations,
                 const std::string &language, bool debug /*= false*/)
    : debug_(debug) {
  InitLanguage(language);

  // Tokens are loaded before the lexicon, as in the file-based constructor.
  // Each buffer lives only as long as the stream that reads from it.
  {
    auto token_bytes = ReadFile(mgr, tokens);
    std::istrstream token_stream(token_bytes.data(), token_bytes.size());
    InitTokens(token_stream);
  }

  {
    auto lexicon_bytes = ReadFile(mgr, lexicon);
    std::istrstream lexicon_stream(lexicon_bytes.data(), lexicon_bytes.size());
    InitLexicon(lexicon_stream);
  }

  InitPunctuations(punctuations);
}
#endif
  125 +
88 std::vector<int64_t> Lexicon::ConvertTextToTokenIds( 126 std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
89 const std::string &text) const { 127 const std::string &text) const {
90 switch (language_) { 128 switch (language_) {
@@ -192,9 +230,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( @@ -192,9 +230,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
192 return ans; 230 return ans;
193 } 231 }
194 232
195 -void Lexicon::InitTokens(const std::string &tokens) {  
196 - token2id_ = ReadTokens(tokens);  
197 -} 233 +void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); }
198 234
199 void Lexicon::InitLanguage(const std::string &_lang) { 235 void Lexicon::InitLanguage(const std::string &_lang) {
200 std::string lang(_lang); 236 std::string lang(_lang);
@@ -209,9 +245,7 @@ void Lexicon::InitLanguage(const std::string &_lang) { @@ -209,9 +245,7 @@ void Lexicon::InitLanguage(const std::string &_lang) {
209 } 245 }
210 } 246 }
211 247
212 -void Lexicon::InitLexicon(const std::string &lexicon) {  
213 - std::ifstream is(lexicon);  
214 - 248 +void Lexicon::InitLexicon(std::istream &is) {
215 std::string word; 249 std::string word;
216 std::vector<std::string> token_list; 250 std::vector<std::string> token_list;
217 std::string line; 251 std::string line;
@@ -6,11 +6,17 @@ @@ -6,11 +6,17 @@
6 #define SHERPA_ONNX_CSRC_LEXICON_H_ 6 #define SHERPA_ONNX_CSRC_LEXICON_H_
7 7
8 #include <cstdint> 8 #include <cstdint>
  9 +#include <iostream>
9 #include <string> 10 #include <string>
10 #include <unordered_map> 11 #include <unordered_map>
11 #include <unordered_set> 12 #include <unordered_set>
12 #include <vector> 13 #include <vector>
13 14
  15 +#if __ANDROID_API__ >= 9
  16 +#include "android/asset_manager.h"
  17 +#include "android/asset_manager_jni.h"
  18 +#endif
  19 +
14 namespace sherpa_onnx { 20 namespace sherpa_onnx {
15 21
16 // TODO(fangjun): Refactor it to an abstract class 22 // TODO(fangjun): Refactor it to an abstract class
@@ -20,6 +26,12 @@ class Lexicon { @@ -20,6 +26,12 @@ class Lexicon {
20 const std::string &punctuations, const std::string &language, 26 const std::string &punctuations, const std::string &language,
21 bool debug = false); 27 bool debug = false);
22 28
  29 +#if __ANDROID_API__ >= 9
  30 + Lexicon(AAssetManager *mgr, const std::string &lexicon,
  31 + const std::string &tokens, const std::string &punctuations,
  32 + const std::string &language, bool debug = false);
  33 +#endif
  34 +
23 std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const; 35 std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;
24 36
25 private: 37 private:
@@ -30,8 +42,8 @@ class Lexicon { @@ -30,8 +42,8 @@ class Lexicon {
30 const std::string &text) const; 42 const std::string &text) const;
31 43
32 void InitLanguage(const std::string &lang); 44 void InitLanguage(const std::string &lang);
33 - void InitTokens(const std::string &tokens);  
34 - void InitLexicon(const std::string &lexicon); 45 + void InitTokens(std::istream &is);
  46 + void InitLexicon(std::istream &is);
35 void InitPunctuations(const std::string &punctuations); 47 void InitPunctuations(const std::string &punctuations);
36 48
37 private: 49 private:
@@ -16,4 +16,12 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create( @@ -16,4 +16,12 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
16 return std::make_unique<OfflineTtsVitsImpl>(config); 16 return std::make_unique<OfflineTtsVitsImpl>(config);
17 } 17 }
18 18
#if __ANDROID_API__ >= 9
// Android variant: forwards the asset manager to the implementation so it can
// load its files through it.
std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
    AAssetManager *mgr, const OfflineTtsConfig &config) {
  // TODO(fangjun): Support other types
  std::unique_ptr<OfflineTtsImpl> impl =
      std::make_unique<OfflineTtsVitsImpl>(mgr, config);
  return impl;
}
#endif
  26 +
19 } // namespace sherpa_onnx 27 } // namespace sherpa_onnx
@@ -8,6 +8,11 @@ @@ -8,6 +8,11 @@
8 #include <memory> 8 #include <memory>
9 #include <string> 9 #include <string>
10 10
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
11 #include "sherpa-onnx/csrc/offline-tts.h" 16 #include "sherpa-onnx/csrc/offline-tts.h"
12 17
13 namespace sherpa_onnx { 18 namespace sherpa_onnx {
@@ -18,6 +23,11 @@ class OfflineTtsImpl { @@ -18,6 +23,11 @@ class OfflineTtsImpl {
18 23
19 static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config); 24 static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);
20 25
  26 +#if __ANDROID_API__ >= 9
  27 + static std::unique_ptr<OfflineTtsImpl> Create(AAssetManager *mgr,
  28 + const OfflineTtsConfig &config);
  29 +#endif
  30 +
21 virtual GeneratedAudio Generate(const std::string &text, int64_t sid = 0, 31 virtual GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
22 float speed = 1.0) const = 0; 32 float speed = 1.0) const = 0;
23 }; 33 };
@@ -9,6 +9,11 @@ @@ -9,6 +9,11 @@
9 #include <utility> 9 #include <utility>
10 #include <vector> 10 #include <vector>
11 11
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
12 #include "sherpa-onnx/csrc/lexicon.h" 17 #include "sherpa-onnx/csrc/lexicon.h"
13 #include "sherpa-onnx/csrc/macros.h" 18 #include "sherpa-onnx/csrc/macros.h"
14 #include "sherpa-onnx/csrc/offline-tts-impl.h" 19 #include "sherpa-onnx/csrc/offline-tts-impl.h"
@@ -24,6 +29,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -24,6 +29,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
24 model_->Punctuations(), model_->Language(), 29 model_->Punctuations(), model_->Language(),
25 config.model.debug) {} 30 config.model.debug) {}
26 31
#if __ANDROID_API__ >= 9
  // Android variant: both the model and the lexicon are loaded through the
  // asset manager.
  // NOTE(review): lexicon_ is initialized from model_->Punctuations() and
  // model_->Language(), so model_ must be declared before lexicon_ in this
  // class -- confirm against the member declarations.
  OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config)
      : model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)),
        lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens,
                 model_->Punctuations(), model_->Language(),
                 config.model.debug) {}
#endif
  39 +
27 GeneratedAudio Generate(const std::string &text, int64_t sid = 0, 40 GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
28 float speed = 1.0) const override { 41 float speed = 1.0) const override {
29 int32_t num_speakers = model_->NumSpeakers(); 42 int32_t num_speakers = model_->NumSpeakers();
@@ -26,6 +26,17 @@ class OfflineTtsVitsModel::Impl { @@ -26,6 +26,17 @@ class OfflineTtsVitsModel::Impl {
26 Init(buf.data(), buf.size()); 26 Init(buf.data(), buf.size());
27 } 27 }
28 28
#if __ANDROID_API__ >= 9
  // Android variant: reads the model bytes through the asset manager and
  // initializes the session from that in-memory buffer.
  Impl(AAssetManager *mgr, const OfflineTtsModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_WARNING),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto model_bytes = ReadFile(mgr, config.vits.model);
    Init(model_bytes.data(), model_bytes.size());
  }
#endif
  39 +
29 Ort::Value Run(Ort::Value x, int64_t sid, float speed) { 40 Ort::Value Run(Ort::Value x, int64_t sid, float speed) {
30 auto memory_info = 41 auto memory_info =
31 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 42 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
@@ -141,6 +152,12 @@ class OfflineTtsVitsModel::Impl { @@ -141,6 +152,12 @@ class OfflineTtsVitsModel::Impl {
141 OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) 152 OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
142 : impl_(std::make_unique<Impl>(config)) {} 153 : impl_(std::make_unique<Impl>(config)) {}
143 154
#if __ANDROID_API__ >= 9
// Android variant: hands the asset manager to the pimpl so the model can be
// loaded through it.
OfflineTtsVitsModel::OfflineTtsVitsModel(AAssetManager *mgr,
                                         const OfflineTtsModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
  160 +
144 OfflineTtsVitsModel::~OfflineTtsVitsModel() = default; 161 OfflineTtsVitsModel::~OfflineTtsVitsModel() = default;
145 162
146 Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/, 163 Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/,
@@ -8,6 +8,11 @@ @@ -8,6 +8,11 @@
8 #include <memory> 8 #include <memory>
9 #include <string> 9 #include <string>
10 10
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
11 #include "onnxruntime_cxx_api.h" // NOLINT 16 #include "onnxruntime_cxx_api.h" // NOLINT
12 #include "sherpa-onnx/csrc/offline-tts-model-config.h" 17 #include "sherpa-onnx/csrc/offline-tts-model-config.h"
13 18
@@ -18,6 +23,9 @@ class OfflineTtsVitsModel { @@ -18,6 +23,9 @@ class OfflineTtsVitsModel {
18 ~OfflineTtsVitsModel(); 23 ~OfflineTtsVitsModel();
19 24
20 explicit OfflineTtsVitsModel(const OfflineTtsModelConfig &config); 25 explicit OfflineTtsVitsModel(const OfflineTtsModelConfig &config);
  26 +#if __ANDROID_API__ >= 9
  27 + OfflineTtsVitsModel(AAssetManager *mgr, const OfflineTtsModelConfig &config);
  28 +#endif
21 29
22 /** Run the model. 30 /** Run the model.
23 * 31 *
@@ -26,6 +26,11 @@ std::string OfflineTtsConfig::ToString() const { @@ -26,6 +26,11 @@ std::string OfflineTtsConfig::ToString() const {
26 OfflineTts::OfflineTts(const OfflineTtsConfig &config) 26 OfflineTts::OfflineTts(const OfflineTtsConfig &config)
27 : impl_(OfflineTtsImpl::Create(config)) {} 27 : impl_(OfflineTtsImpl::Create(config)) {}
28 28
#if __ANDROID_API__ >= 9
// Android variant: creates the implementation with an asset manager.
OfflineTts::OfflineTts(AAssetManager *mgr, const OfflineTtsConfig &config)
    : impl_(OfflineTtsImpl::Create(mgr, config)) {}
#endif
  33 +
29 OfflineTts::~OfflineTts() = default; 34 OfflineTts::~OfflineTts() = default;
30 35
31 GeneratedAudio OfflineTts::Generate(const std::string &text, int64_t sid /*=0*/, 36 GeneratedAudio OfflineTts::Generate(const std::string &text, int64_t sid /*=0*/,
@@ -9,6 +9,11 @@ @@ -9,6 +9,11 @@
9 #include <string> 9 #include <string>
10 #include <vector> 10 #include <vector>
11 11
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
12 #include "sherpa-onnx/csrc/offline-tts-model-config.h" 17 #include "sherpa-onnx/csrc/offline-tts-model-config.h"
13 #include "sherpa-onnx/csrc/parse-options.h" 18 #include "sherpa-onnx/csrc/parse-options.h"
14 19
@@ -38,6 +43,11 @@ class OfflineTts { @@ -38,6 +43,11 @@ class OfflineTts {
38 public: 43 public:
39 ~OfflineTts(); 44 ~OfflineTts();
40 explicit OfflineTts(const OfflineTtsConfig &config); 45 explicit OfflineTts(const OfflineTtsConfig &config);
  46 +
  47 +#if __ANDROID_API__ >= 9
  48 + OfflineTts(AAssetManager *mgr, const OfflineTtsConfig &config);
  49 +#endif
  50 +
41 // @param text A string containing words separated by spaces 51 // @param text A string containing words separated by spaces
42 // @param sid Speaker ID. Used only for multi-speaker models, e.g., models 52 // @param sid Speaker ID. Used only for multi-speaker models, e.g., models
43 // trained using the VCTK dataset. It is not used for 53 // trained using the VCTK dataset. It is not used for
@@ -7,12 +7,13 @@ @@ -7,12 +7,13 @@
7 #include <cassert> 7 #include <cassert>
8 #include <fstream> 8 #include <fstream>
9 #include <sstream> 9 #include <sstream>
10 -#include <strstream>  
11 10
12 #include "sherpa-onnx/csrc/base64-decode.h" 11 #include "sherpa-onnx/csrc/base64-decode.h"
13 #include "sherpa-onnx/csrc/onnx-utils.h" 12 #include "sherpa-onnx/csrc/onnx-utils.h"
14 13
15 #if __ANDROID_API__ >= 9 14 #if __ANDROID_API__ >= 9
  15 +#include <strstream>
  16 +
16 #include "android/asset_manager.h" 17 #include "android/asset_manager.h"
17 #include "android/asset_manager_jni.h" 18 #include "android/asset_manager_jni.h"
18 #endif 19 #endif
@@ -21,10 +21,12 @@ @@ -21,10 +21,12 @@
21 21
22 #include "sherpa-onnx/csrc/macros.h" 22 #include "sherpa-onnx/csrc/macros.h"
23 #include "sherpa-onnx/csrc/offline-recognizer.h" 23 #include "sherpa-onnx/csrc/offline-recognizer.h"
  24 +#include "sherpa-onnx/csrc/offline-tts.h"
24 #include "sherpa-onnx/csrc/online-recognizer.h" 25 #include "sherpa-onnx/csrc/online-recognizer.h"
25 #include "sherpa-onnx/csrc/onnx-utils.h" 26 #include "sherpa-onnx/csrc/onnx-utils.h"
26 #include "sherpa-onnx/csrc/voice-activity-detector.h" 27 #include "sherpa-onnx/csrc/voice-activity-detector.h"
27 #include "sherpa-onnx/csrc/wave-reader.h" 28 #include "sherpa-onnx/csrc/wave-reader.h"
  29 +#include "sherpa-onnx/csrc/wave-writer.h"
28 30
29 #define SHERPA_ONNX_EXTERN_C extern "C" 31 #define SHERPA_ONNX_EXTERN_C extern "C"
30 32
@@ -124,7 +126,7 @@ class SherpaOnnxVad { @@ -124,7 +126,7 @@ class SherpaOnnxVad {
124 126
125 void Pop() { vad_.Pop(); } 127 void Pop() { vad_.Pop(); }
126 128
127 - void Clear() { vad_.Clear();} 129 + void Clear() { vad_.Clear(); }
128 130
129 const SpeechSegment &Front() const { return vad_.Front(); } 131 const SpeechSegment &Front() const { return vad_.Front(); }
130 132
@@ -491,9 +493,173 @@ static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) { @@ -491,9 +493,173 @@ static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
491 return ans; 493 return ans;
492 } 494 }
493 495
  496 +class SherpaOnnxOfflineTts {
  497 + public:
  498 +#if __ANDROID_API__ >= 9
  499 + SherpaOnnxOfflineTts(AAssetManager *mgr, const OfflineTtsConfig &config)
  500 + : tts_(mgr, config) {}
  501 +#endif
  502 + explicit SherpaOnnxOfflineTts(const OfflineTtsConfig &config)
  503 + : tts_(config) {}
  504 +
  505 + GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
  506 + float speed = 1.0) const {
  507 + return tts_.Generate(text, sid, speed);
  508 + }
  509 +
  510 + private:
  511 + OfflineTts tts_;
  512 +};
  513 +
  514 +static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
  515 + OfflineTtsConfig ans;
  516 +
  517 + jclass cls = env->GetObjectClass(config);
  518 + jfieldID fid;
  519 +
  520 + fid = env->GetFieldID(cls, "model",
  521 + "Lcom/k2fsa/sherpa/onnx/OfflineTtsModelConfig;");
  522 + jobject model = env->GetObjectField(config, fid);
  523 + jclass model_config_cls = env->GetObjectClass(model);
  524 +
  525 + fid = env->GetFieldID(model_config_cls, "vits",
  526 + "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
  527 + jobject vits = env->GetObjectField(model, fid);
  528 + jclass vits_cls = env->GetObjectClass(vits);
  529 +
  530 + fid = env->GetFieldID(vits_cls, "model", "Ljava/lang/String;");
  531 + jstring s = (jstring)env->GetObjectField(vits, fid);
  532 + const char *p = env->GetStringUTFChars(s, nullptr);
  533 + ans.model.vits.model = p;
  534 + env->ReleaseStringUTFChars(s, p);
  535 +
  536 + fid = env->GetFieldID(vits_cls, "lexicon", "Ljava/lang/String;");
  537 + s = (jstring)env->GetObjectField(vits, fid);
  538 + p = env->GetStringUTFChars(s, nullptr);
  539 + ans.model.vits.lexicon = p;
  540 + env->ReleaseStringUTFChars(s, p);
  541 +
  542 + fid = env->GetFieldID(vits_cls, "tokens", "Ljava/lang/String;");
  543 + s = (jstring)env->GetObjectField(vits, fid);
  544 + p = env->GetStringUTFChars(s, nullptr);
  545 + ans.model.vits.tokens = p;
  546 + env->ReleaseStringUTFChars(s, p);
  547 +
  548 + fid = env->GetFieldID(vits_cls, "noiseScale", "F");
  549 + ans.model.vits.noise_scale = env->GetFloatField(vits, fid);
  550 +
  551 + fid = env->GetFieldID(vits_cls, "noiseScaleW", "F");
  552 + ans.model.vits.noise_scale_w = env->GetFloatField(vits, fid);
  553 +
  554 + fid = env->GetFieldID(vits_cls, "lengthScale", "F");
  555 + ans.model.vits.length_scale = env->GetFloatField(vits, fid);
  556 +
  557 + fid = env->GetFieldID(model_config_cls, "numThreads", "I");
  558 + ans.model.num_threads = env->GetIntField(model, fid);
  559 +
  560 + fid = env->GetFieldID(model_config_cls, "debug", "Z");
  561 + ans.model.debug = env->GetBooleanField(model, fid);
  562 +
  563 + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
  564 + s = (jstring)env->GetObjectField(model, fid);
  565 + p = env->GetStringUTFChars(s, nullptr);
  566 + ans.model.provider = p;
  567 + env->ReleaseStringUTFChars(s, p);
  568 +
  569 + return ans;
  570 +}
  571 +
494 } // namespace sherpa_onnx 572 } // namespace sherpa_onnx
495 573
496 SHERPA_ONNX_EXTERN_C 574 SHERPA_ONNX_EXTERN_C
  575 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new(
  576 + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  577 +#if __ANDROID_API__ >= 9
  578 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  579 + if (!mgr) {
  580 + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  581 + }
  582 +#endif
  583 + auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
  584 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  585 + auto tts = new sherpa_onnx::SherpaOnnxOfflineTts(
  586 +#if __ANDROID_API__ >= 9
  587 + mgr,
  588 +#endif
  589 + config);
  590 +
  591 + return (jlong)tts;
  592 +}
  593 +
  594 +SHERPA_ONNX_EXTERN_C
  595 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile(
  596 + JNIEnv *env, jobject /*obj*/, jobject _config) {
  597 + auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
  598 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  599 + auto tts = new sherpa_onnx::SherpaOnnxOfflineTts(config);
  600 +
  601 + return (jlong)tts;
  602 +}
  603 +
  604 +SHERPA_ONNX_EXTERN_C
  605 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete(
  606 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  607 + delete reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr);
  608 +}
  609 +
  610 +// see
  611 +// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
  612 +static jobject NewInteger(JNIEnv *env, int32_t value) {
  613 + jclass cls = env->FindClass("java/lang/Integer");
  614 + jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");
  615 + return env->NewObject(cls, constructor, value);
  616 +}
  617 +
  618 +SHERPA_ONNX_EXTERN_C
  619 +JNIEXPORT jobjectArray JNICALL
  620 +Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/,
  621 + jlong ptr, jstring text,
  622 + jint sid, jfloat speed) {
  623 + const char *p_text = env->GetStringUTFChars(text, nullptr);
  624 + SHERPA_ONNX_LOGE("string is: %s", p_text);
  625 +
  626 + auto audio =
  627 + reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)->Generate(
  628 + p_text, sid, speed);
  629 +
  630 + jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
  631 + env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),
  632 + audio.samples.data());
  633 +
  634 + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
  635 + 2, env->FindClass("java/lang/Object"), nullptr);
  636 +
  637 + env->SetObjectArrayElement(obj_arr, 0, samples_arr);
  638 + env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate));
  639 +
  640 + env->ReleaseStringUTFChars(text, p_text);
  641 +
  642 + return obj_arr;
  643 +}
  644 +
  645 +SHERPA_ONNX_EXTERN_C
  646 +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
  647 + JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
  648 + jint sample_rate) {
  649 + const char *p_filename = env->GetStringUTFChars(filename, nullptr);
  650 +
  651 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  652 + jsize n = env->GetArrayLength(samples);
  653 +
  654 + bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n);
  655 +
  656 + env->ReleaseStringUTFChars(filename, p_filename);
  657 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  658 +
  659 + return ok;
  660 +}
  661 +
  662 +SHERPA_ONNX_EXTERN_C
497 JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_new( 663 JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_new(
498 JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { 664 JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
499 #if __ANDROID_API__ >= 9 665 #if __ANDROID_API__ >= 9
@@ -513,6 +679,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_new( @@ -513,6 +679,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_new(
513 return (jlong)model; 679 return (jlong)model;
514 } 680 }
515 681
  682 +SHERPA_ONNX_EXTERN_C
516 JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromFile( 683 JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromFile(
517 JNIEnv *env, jobject /*obj*/, jobject _config) { 684 JNIEnv *env, jobject /*obj*/, jobject _config) {
518 auto config = sherpa_onnx::GetVadModelConfig(env, _config); 685 auto config = sherpa_onnx::GetVadModelConfig(env, _config);
@@ -560,20 +727,12 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env, @@ -560,20 +727,12 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env,
560 727
561 SHERPA_ONNX_EXTERN_C 728 SHERPA_ONNX_EXTERN_C
562 JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env, 729 JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env,
563 - jobject /*obj*/,  
564 - jlong ptr) { 730 + jobject /*obj*/,
  731 + jlong ptr) {
565 auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr); 732 auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
566 model->Clear(); 733 model->Clear();
567 } 734 }
568 735
569 -// see  
570 -// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables  
571 -static jobject NewInteger(JNIEnv *env, int32_t value) {  
572 - jclass cls = env->FindClass("java/lang/Integer");  
573 - jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");  
574 - return env->NewObject(cls, constructor, value);  
575 -}  
576 -  
577 SHERPA_ONNX_EXTERN_C 736 SHERPA_ONNX_EXTERN_C
578 JNIEXPORT jobjectArray JNICALL 737 JNIEXPORT jobjectArray JNICALL
579 Java_com_k2fsa_sherpa_onnx_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) { 738 Java_com_k2fsa_sherpa_onnx_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) {