Committed by
GitHub
Add jni interface and kotlin API examples for TTS. (#381)
正在显示
15 个修改的文件
包含
444 行增加
和
27 行删除
kotlin-api-examples/.gitignore
0 → 100644
| @@ -3,6 +3,28 @@ package com.k2fsa.sherpa.onnx | @@ -3,6 +3,28 @@ package com.k2fsa.sherpa.onnx | ||
| 3 | import android.content.res.AssetManager | 3 | import android.content.res.AssetManager |
| 4 | 4 | ||
/**
 * Entry point of the Kotlin API examples.
 *
 * Runs the text-to-speech example first and the speech-recognition example afterwards.
 */
fun main() {
    testTts()
    testAsr()
}
| 9 | + | ||
/**
 * Text-to-speech example: synthesizes one sentence with the VITS model stored in
 * `./vits-zh-aishell3/` and writes the result to `99.wav`.
 *
 * The model, lexicon and tokens paths are resolved relative to the working directory.
 * Speaker id 99 and speed 1.2 are passed to the engine.
 */
fun testTts() {
    val config = OfflineTtsConfig(
        model = OfflineTtsModelConfig(
            vits = OfflineTtsVitsModelConfig(
                model = "./vits-zh-aishell3/vits-aishell3.onnx",
                lexicon = "./vits-zh-aishell3/lexicon.txt",
                tokens = "./vits-zh-aishell3/tokens.txt",
            ),
            numThreads = 1,
            debug = true,
        )
    )
    val tts = OfflineTts(config = config)
    try {
        val audio = tts.generate(text = "林美丽最美丽!", sid = 99, speed = 1.2f)
        // save() reports failure through its return value; do not ignore it silently.
        if (!audio.save(filename = "99.wav")) {
            println("Failed to save 99.wav")
        }
    } finally {
        // Release the native engine deterministically instead of waiting for finalization.
        tts.free()
    }
}
| 26 | + | ||
| 27 | +fun testAsr() { | ||
| 6 | var featConfig = FeatureConfig( | 28 | var featConfig = FeatureConfig( |
| 7 | sampleRate = 16000, | 29 | sampleRate = 16000, |
| 8 | featureDim = 80, | 30 | featureDim = 80, |
kotlin-api-examples/Tts.kt
0 → 100644
| 1 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 2 | +package com.k2fsa.sherpa.onnx | ||
| 3 | + | ||
| 4 | +import android.content.res.AssetManager | ||
| 5 | + | ||
/**
 * Configuration of a VITS text-to-speech model.
 *
 * The native side reads these properties through JNI by name and type, so the
 * property names and types must stay in sync with the JNI code.
 *
 * @property model path of the ONNX model file
 * @property lexicon path of the lexicon file
 * @property tokens path of the tokens file
 * @property noiseScale forwarded to the native `noise_scale` setting
 * @property noiseScaleW forwarded to the native `noise_scale_w` setting
 * @property lengthScale forwarded to the native `length_scale` setting
 */
data class OfflineTtsVitsModelConfig(
    var model: String,
    var lexicon: String,
    var tokens: String,
    var noiseScale: Float = 0.667f,
    var noiseScaleW: Float = 0.8f,
    var lengthScale: Float = 1.0f,
)
| 14 | + | ||
/**
 * Model-level configuration of the offline text-to-speech engine.
 *
 * The native side reads these properties through JNI by name and type.
 *
 * @property vits configuration of the VITS model
 * @property numThreads forwarded to the native `num_threads` setting
 * @property debug forwarded to the native `debug` flag
 * @property provider forwarded to the native `provider` setting; defaults to `"cpu"`
 */
data class OfflineTtsModelConfig(
    var vits: OfflineTtsVitsModelConfig,
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
| 21 | + | ||
/**
 * Top-level configuration passed to [OfflineTts].
 *
 * @property model model-level settings; read by the native side through JNI
 */
data class OfflineTtsConfig(
    var model: OfflineTtsModelConfig,
)
| 25 | + | ||
/**
 * Audio returned by [OfflineTts.generate].
 *
 * This class does not load the native library itself; in this file the library
 * is loaded by the companion object of [OfflineTts].
 *
 * @property samples the audio samples
 * @property sampleRate the sample rate of [samples]
 */
class GeneratedAudio(
    val samples: FloatArray,
    val sampleRate: Int,
) {
    /**
     * Writes [samples] to the file [filename] using the native wave writer.
     *
     * @return the success flag reported by the native code
     */
    fun save(filename: String): Boolean {
        return saveImpl(filename = filename, samples = samples, sampleRate = sampleRate)
    }

    // Implemented in the JNI library.
    private external fun saveImpl(filename: String, samples: FloatArray, sampleRate: Int): Boolean
}
| 38 | + | ||
/**
 * Offline text-to-speech engine backed by the `sherpa-onnx-jni` native library.
 *
 * The native engine is created in the constructor. When [assetManager] is non-null
 * the engine is created through the asset manager; otherwise it is created from
 * the file paths in [config].
 *
 * Call [free] to release the native engine; [allocate] re-creates it afterwards.
 *
 * @param assetManager optional asset manager used to create the engine
 * @property config configuration handed to the native side
 */
class OfflineTts(
    assetManager: AssetManager? = null,
    var config: OfflineTtsConfig,
) {
    // Address of the native engine; 0 means it has been released.
    private var ptr: Long

    init {
        ptr = create(assetManager)
    }

    /**
     * Synthesizes [text].
     *
     * @param text the text to synthesize
     * @param sid speaker id passed to the native engine
     * @param speed speed value passed to the native engine
     * @return the generated samples together with their sample rate
     * @throws IllegalStateException if the engine has been released with [free]
     */
    fun generate(
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): GeneratedAudio {
        // The native code dereferences the handle unconditionally, so a released
        // engine must be rejected here instead of crashing the process.
        check(ptr != 0L) { "OfflineTts has been freed; call allocate() before generate()" }

        val objArray = generateImpl(ptr, text = text, sid = sid, speed = speed)
        return GeneratedAudio(
            samples = objArray[0] as FloatArray,
            sampleRate = objArray[1] as Int,
        )
    }

    /**
     * Re-creates the native engine if it has been released; does nothing otherwise.
     *
     * @param assetManager optional asset manager used to create the engine
     */
    fun allocate(assetManager: AssetManager? = null) {
        if (ptr == 0L) {
            ptr = create(assetManager)
        }
    }

    /** Releases the native engine. Calling it more than once is harmless. */
    fun free() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    // Last-resort cleanup in case free() was never called.
    protected fun finalize() {
        free()
    }

    // Creates the native engine from the asset manager when one is given and
    // from plain files otherwise.
    private fun create(assetManager: AssetManager?): Long =
        if (assetManager != null) new(assetManager, config) else newFromFile(config)

    private external fun new(
        assetManager: AssetManager,
        config: OfflineTtsConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineTtsConfig,
    ): Long

    private external fun delete(ptr: Long)

    // The returned array has two entries:
    //  - the first entry is a 1-D float array containing audio samples.
    //    Each sample is normalized to the range [-1, 1]
    //  - the second entry is the sample rate
    external fun generateImpl(
        ptr: Long,
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): Array<Any>

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
| @@ -6,11 +6,13 @@ | @@ -6,11 +6,13 @@ | ||
| 6 | 6 | ||
| 7 | set -e | 7 | set -e |
| 8 | 8 | ||
| 9 | + | ||
| 9 | cd .. | 10 | cd .. |
| 10 | mkdir -p build | 11 | mkdir -p build |
| 11 | cd build | 12 | cd build |
| 12 | 13 | ||
| 13 | -cmake \ | 14 | +if [ ! -f ../build/lib/libsherpa-onnx-jni.dylib ]; then |
| 15 | + cmake \ | ||
| 14 | -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ | 16 | -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ |
| 15 | -DSHERPA_ONNX_ENABLE_TESTS=OFF \ | 17 | -DSHERPA_ONNX_ENABLE_TESTS=OFF \ |
| 16 | -DSHERPA_ONNX_ENABLE_CHECK=OFF \ | 18 | -DSHERPA_ONNX_ENABLE_CHECK=OFF \ |
| @@ -19,8 +21,9 @@ cmake \ | @@ -19,8 +21,9 @@ cmake \ | ||
| 19 | -DSHERPA_ONNX_ENABLE_JNI=ON \ | 21 | -DSHERPA_ONNX_ENABLE_JNI=ON \ |
| 20 | .. | 22 | .. |
| 21 | 23 | ||
| 22 | -make -j4 | ||
| 23 | -ls -lh lib | 24 | + make -j4 |
| 25 | + ls -lh lib | ||
| 26 | +fi | ||
| 24 | 27 | ||
| 25 | export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH | 28 | export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH |
| 26 | 29 | ||
| @@ -31,7 +34,7 @@ if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then | @@ -31,7 +34,7 @@ if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then | ||
| 31 | git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 | 34 | git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 |
| 32 | fi | 35 | fi |
| 33 | 36 | ||
| 34 | -kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt | 37 | +kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt Tts.kt |
| 35 | 38 | ||
| 36 | ls -lh main.jar | 39 | ls -lh main.jar |
| 37 | 40 |
| @@ -10,7 +10,15 @@ | @@ -10,7 +10,15 @@ | ||
| 10 | #include <sstream> | 10 | #include <sstream> |
| 11 | #include <utility> | 11 | #include <utility> |
| 12 | 12 | ||
| 13 | +#if __ANDROID_API__ >= 9 | ||
| 14 | +#include <strstream> | ||
| 15 | + | ||
| 16 | +#include "android/asset_manager.h" | ||
| 17 | +#include "android/asset_manager_jni.h" | ||
| 18 | +#endif | ||
| 19 | + | ||
| 13 | #include "sherpa-onnx/csrc/macros.h" | 20 | #include "sherpa-onnx/csrc/macros.h" |
| 21 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 14 | #include "sherpa-onnx/csrc/text-utils.h" | 22 | #include "sherpa-onnx/csrc/text-utils.h" |
| 15 | 23 | ||
| 16 | namespace sherpa_onnx { | 24 | namespace sherpa_onnx { |
| @@ -22,11 +30,9 @@ static void ToLowerCase(std::string *in_out) { | @@ -22,11 +30,9 @@ static void ToLowerCase(std::string *in_out) { | ||
| 22 | 30 | ||
| 23 | // Note: We don't use SymbolTable here since tokens may contain a blank | 31 | // Note: We don't use SymbolTable here since tokens may contain a blank |
| 24 | // in the first column | 32 | // in the first column |
| 25 | -static std::unordered_map<std::string, int32_t> ReadTokens( | ||
| 26 | - const std::string &tokens) { | 33 | +static std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) { |
| 27 | std::unordered_map<std::string, int32_t> token2id; | 34 | std::unordered_map<std::string, int32_t> token2id; |
| 28 | 35 | ||
| 29 | - std::ifstream is(tokens); | ||
| 30 | std::string line; | 36 | std::string line; |
| 31 | 37 | ||
| 32 | std::string sym; | 38 | std::string sym; |
| @@ -80,11 +86,43 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, | @@ -80,11 +86,43 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, | ||
| 80 | bool debug /*= false*/) | 86 | bool debug /*= false*/) |
| 81 | : debug_(debug) { | 87 | : debug_(debug) { |
| 82 | InitLanguage(language); | 88 | InitLanguage(language); |
| 83 | - InitTokens(tokens); | ||
| 84 | - InitLexicon(lexicon); | 89 | + |
| 90 | + { | ||
| 91 | + std::ifstream is(tokens); | ||
| 92 | + InitTokens(is); | ||
| 93 | + } | ||
| 94 | + | ||
| 95 | + { | ||
| 96 | + std::ifstream is(lexicon); | ||
| 97 | + InitLexicon(is); | ||
| 98 | + } | ||
| 99 | + | ||
| 85 | InitPunctuations(punctuations); | 100 | InitPunctuations(punctuations); |
| 86 | } | 101 | } |
| 87 | 102 | ||
| 103 | +#if __ANDROID_API__ >= 9 | ||
| 104 | +Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, | ||
| 105 | + const std::string &tokens, const std::string &punctuations, | ||
| 106 | + const std::string &language, bool debug /*= false*/) | ||
| 107 | + : debug_(debug) { | ||
| 108 | + InitLanguage(language); | ||
| 109 | + | ||
| 110 | + { | ||
| 111 | + auto buf = ReadFile(mgr, tokens); | ||
| 112 | + std::istrstream is(buf.data(), buf.size()); | ||
| 113 | + InitTokens(is); | ||
| 114 | + } | ||
| 115 | + | ||
| 116 | + { | ||
| 117 | + auto buf = ReadFile(mgr, lexicon); | ||
| 118 | + std::istrstream is(buf.data(), buf.size()); | ||
| 119 | + InitLexicon(is); | ||
| 120 | + } | ||
| 121 | + | ||
| 122 | + InitPunctuations(punctuations); | ||
| 123 | +} | ||
| 124 | +#endif | ||
| 125 | + | ||
| 88 | std::vector<int64_t> Lexicon::ConvertTextToTokenIds( | 126 | std::vector<int64_t> Lexicon::ConvertTextToTokenIds( |
| 89 | const std::string &text) const { | 127 | const std::string &text) const { |
| 90 | switch (language_) { | 128 | switch (language_) { |
| @@ -192,9 +230,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( | @@ -192,9 +230,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( | ||
| 192 | return ans; | 230 | return ans; |
| 193 | } | 231 | } |
| 194 | 232 | ||
| 195 | -void Lexicon::InitTokens(const std::string &tokens) { | ||
| 196 | - token2id_ = ReadTokens(tokens); | ||
| 197 | -} | 233 | +void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); } |
| 198 | 234 | ||
| 199 | void Lexicon::InitLanguage(const std::string &_lang) { | 235 | void Lexicon::InitLanguage(const std::string &_lang) { |
| 200 | std::string lang(_lang); | 236 | std::string lang(_lang); |
| @@ -209,9 +245,7 @@ void Lexicon::InitLanguage(const std::string &_lang) { | @@ -209,9 +245,7 @@ void Lexicon::InitLanguage(const std::string &_lang) { | ||
| 209 | } | 245 | } |
| 210 | } | 246 | } |
| 211 | 247 | ||
| 212 | -void Lexicon::InitLexicon(const std::string &lexicon) { | ||
| 213 | - std::ifstream is(lexicon); | ||
| 214 | - | 248 | +void Lexicon::InitLexicon(std::istream &is) { |
| 215 | std::string word; | 249 | std::string word; |
| 216 | std::vector<std::string> token_list; | 250 | std::vector<std::string> token_list; |
| 217 | std::string line; | 251 | std::string line; |
| @@ -6,11 +6,17 @@ | @@ -6,11 +6,17 @@ | ||
| 6 | #define SHERPA_ONNX_CSRC_LEXICON_H_ | 6 | #define SHERPA_ONNX_CSRC_LEXICON_H_ |
| 7 | 7 | ||
| 8 | #include <cstdint> | 8 | #include <cstdint> |
| 9 | +#include <iostream> | ||
| 9 | #include <string> | 10 | #include <string> |
| 10 | #include <unordered_map> | 11 | #include <unordered_map> |
| 11 | #include <unordered_set> | 12 | #include <unordered_set> |
| 12 | #include <vector> | 13 | #include <vector> |
| 13 | 14 | ||
| 15 | +#if __ANDROID_API__ >= 9 | ||
| 16 | +#include "android/asset_manager.h" | ||
| 17 | +#include "android/asset_manager_jni.h" | ||
| 18 | +#endif | ||
| 19 | + | ||
| 14 | namespace sherpa_onnx { | 20 | namespace sherpa_onnx { |
| 15 | 21 | ||
| 16 | // TODO(fangjun): Refactor it to an abstract class | 22 | // TODO(fangjun): Refactor it to an abstract class |
| @@ -20,6 +26,12 @@ class Lexicon { | @@ -20,6 +26,12 @@ class Lexicon { | ||
| 20 | const std::string &punctuations, const std::string &language, | 26 | const std::string &punctuations, const std::string &language, |
| 21 | bool debug = false); | 27 | bool debug = false); |
| 22 | 28 | ||
| 29 | +#if __ANDROID_API__ >= 9 | ||
| 30 | + Lexicon(AAssetManager *mgr, const std::string &lexicon, | ||
| 31 | + const std::string &tokens, const std::string &punctuations, | ||
| 32 | + const std::string &language, bool debug = false); | ||
| 33 | +#endif | ||
| 34 | + | ||
| 23 | std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const; | 35 | std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const; |
| 24 | 36 | ||
| 25 | private: | 37 | private: |
| @@ -30,8 +42,8 @@ class Lexicon { | @@ -30,8 +42,8 @@ class Lexicon { | ||
| 30 | const std::string &text) const; | 42 | const std::string &text) const; |
| 31 | 43 | ||
| 32 | void InitLanguage(const std::string &lang); | 44 | void InitLanguage(const std::string &lang); |
| 33 | - void InitTokens(const std::string &tokens); | ||
| 34 | - void InitLexicon(const std::string &lexicon); | 45 | + void InitTokens(std::istream &is); |
| 46 | + void InitLexicon(std::istream &is); | ||
| 35 | void InitPunctuations(const std::string &punctuations); | 47 | void InitPunctuations(const std::string &punctuations); |
| 36 | 48 | ||
| 37 | private: | 49 | private: |
| @@ -16,4 +16,12 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create( | @@ -16,4 +16,12 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create( | ||
| 16 | return std::make_unique<OfflineTtsVitsImpl>(config); | 16 | return std::make_unique<OfflineTtsVitsImpl>(config); |
| 17 | } | 17 | } |
| 18 | 18 | ||
| 19 | +#if __ANDROID_API__ >= 9 | ||
| 20 | +std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create( | ||
| 21 | + AAssetManager *mgr, const OfflineTtsConfig &config) { | ||
| 22 | + // TODO(fangjun): Support other types | ||
| 23 | + return std::make_unique<OfflineTtsVitsImpl>(mgr, config); | ||
| 24 | +} | ||
| 25 | +#endif | ||
| 26 | + | ||
| 19 | } // namespace sherpa_onnx | 27 | } // namespace sherpa_onnx |
| @@ -8,6 +8,11 @@ | @@ -8,6 +8,11 @@ | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | #include <string> | 9 | #include <string> |
| 10 | 10 | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 11 | #include "sherpa-onnx/csrc/offline-tts.h" | 16 | #include "sherpa-onnx/csrc/offline-tts.h" |
| 12 | 17 | ||
| 13 | namespace sherpa_onnx { | 18 | namespace sherpa_onnx { |
| @@ -18,6 +23,11 @@ class OfflineTtsImpl { | @@ -18,6 +23,11 @@ class OfflineTtsImpl { | ||
| 18 | 23 | ||
| 19 | static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config); | 24 | static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config); |
| 20 | 25 | ||
| 26 | +#if __ANDROID_API__ >= 9 | ||
| 27 | + static std::unique_ptr<OfflineTtsImpl> Create(AAssetManager *mgr, | ||
| 28 | + const OfflineTtsConfig &config); | ||
| 29 | +#endif | ||
| 30 | + | ||
| 21 | virtual GeneratedAudio Generate(const std::string &text, int64_t sid = 0, | 31 | virtual GeneratedAudio Generate(const std::string &text, int64_t sid = 0, |
| 22 | float speed = 1.0) const = 0; | 32 | float speed = 1.0) const = 0; |
| 23 | }; | 33 | }; |
| @@ -9,6 +9,11 @@ | @@ -9,6 +9,11 @@ | ||
| 9 | #include <utility> | 9 | #include <utility> |
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 12 | #include "sherpa-onnx/csrc/lexicon.h" | 17 | #include "sherpa-onnx/csrc/lexicon.h" |
| 13 | #include "sherpa-onnx/csrc/macros.h" | 18 | #include "sherpa-onnx/csrc/macros.h" |
| 14 | #include "sherpa-onnx/csrc/offline-tts-impl.h" | 19 | #include "sherpa-onnx/csrc/offline-tts-impl.h" |
| @@ -24,6 +29,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -24,6 +29,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 24 | model_->Punctuations(), model_->Language(), | 29 | model_->Punctuations(), model_->Language(), |
| 25 | config.model.debug) {} | 30 | config.model.debug) {} |
| 26 | 31 | ||
| 32 | +#if __ANDROID_API__ >= 9 | ||
| 33 | + OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config) | ||
| 34 | + : model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)), | ||
| 35 | + lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens, | ||
| 36 | + model_->Punctuations(), model_->Language(), | ||
| 37 | + config.model.debug) {} | ||
| 38 | +#endif | ||
| 39 | + | ||
| 27 | GeneratedAudio Generate(const std::string &text, int64_t sid = 0, | 40 | GeneratedAudio Generate(const std::string &text, int64_t sid = 0, |
| 28 | float speed = 1.0) const override { | 41 | float speed = 1.0) const override { |
| 29 | int32_t num_speakers = model_->NumSpeakers(); | 42 | int32_t num_speakers = model_->NumSpeakers(); |
| @@ -26,6 +26,17 @@ class OfflineTtsVitsModel::Impl { | @@ -26,6 +26,17 @@ class OfflineTtsVitsModel::Impl { | ||
| 26 | Init(buf.data(), buf.size()); | 26 | Init(buf.data(), buf.size()); |
| 27 | } | 27 | } |
| 28 | 28 | ||
| 29 | +#if __ANDROID_API__ >= 9 | ||
| 30 | + Impl(AAssetManager *mgr, const OfflineTtsModelConfig &config) | ||
| 31 | + : config_(config), | ||
| 32 | + env_(ORT_LOGGING_LEVEL_WARNING), | ||
| 33 | + sess_opts_(GetSessionOptions(config)), | ||
| 34 | + allocator_{} { | ||
| 35 | + auto buf = ReadFile(mgr, config.vits.model); | ||
| 36 | + Init(buf.data(), buf.size()); | ||
| 37 | + } | ||
| 38 | +#endif | ||
| 39 | + | ||
| 29 | Ort::Value Run(Ort::Value x, int64_t sid, float speed) { | 40 | Ort::Value Run(Ort::Value x, int64_t sid, float speed) { |
| 30 | auto memory_info = | 41 | auto memory_info = |
| 31 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | 42 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); |
| @@ -141,6 +152,12 @@ class OfflineTtsVitsModel::Impl { | @@ -141,6 +152,12 @@ class OfflineTtsVitsModel::Impl { | ||
| 141 | OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) | 152 | OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) |
| 142 | : impl_(std::make_unique<Impl>(config)) {} | 153 | : impl_(std::make_unique<Impl>(config)) {} |
| 143 | 154 | ||
| 155 | +#if __ANDROID_API__ >= 9 | ||
| 156 | +OfflineTtsVitsModel::OfflineTtsVitsModel(AAssetManager *mgr, | ||
| 157 | + const OfflineTtsModelConfig &config) | ||
| 158 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 159 | +#endif | ||
| 160 | + | ||
| 144 | OfflineTtsVitsModel::~OfflineTtsVitsModel() = default; | 161 | OfflineTtsVitsModel::~OfflineTtsVitsModel() = default; |
| 145 | 162 | ||
| 146 | Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/, | 163 | Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/, |
| @@ -8,6 +8,11 @@ | @@ -8,6 +8,11 @@ | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | #include <string> | 9 | #include <string> |
| 10 | 10 | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 11 | #include "onnxruntime_cxx_api.h" // NOLINT | 16 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 12 | #include "sherpa-onnx/csrc/offline-tts-model-config.h" | 17 | #include "sherpa-onnx/csrc/offline-tts-model-config.h" |
| 13 | 18 | ||
| @@ -18,6 +23,9 @@ class OfflineTtsVitsModel { | @@ -18,6 +23,9 @@ class OfflineTtsVitsModel { | ||
| 18 | ~OfflineTtsVitsModel(); | 23 | ~OfflineTtsVitsModel(); |
| 19 | 24 | ||
| 20 | explicit OfflineTtsVitsModel(const OfflineTtsModelConfig &config); | 25 | explicit OfflineTtsVitsModel(const OfflineTtsModelConfig &config); |
| 26 | +#if __ANDROID_API__ >= 9 | ||
| 27 | + OfflineTtsVitsModel(AAssetManager *mgr, const OfflineTtsModelConfig &config); | ||
| 28 | +#endif | ||
| 21 | 29 | ||
| 22 | /** Run the model. | 30 | /** Run the model. |
| 23 | * | 31 | * |
| @@ -26,6 +26,11 @@ std::string OfflineTtsConfig::ToString() const { | @@ -26,6 +26,11 @@ std::string OfflineTtsConfig::ToString() const { | ||
| 26 | OfflineTts::OfflineTts(const OfflineTtsConfig &config) | 26 | OfflineTts::OfflineTts(const OfflineTtsConfig &config) |
| 27 | : impl_(OfflineTtsImpl::Create(config)) {} | 27 | : impl_(OfflineTtsImpl::Create(config)) {} |
| 28 | 28 | ||
| 29 | +#if __ANDROID_API__ >= 9 | ||
| 30 | +OfflineTts::OfflineTts(AAssetManager *mgr, const OfflineTtsConfig &config) | ||
| 31 | + : impl_(OfflineTtsImpl::Create(mgr, config)) {} | ||
| 32 | +#endif | ||
| 33 | + | ||
| 29 | OfflineTts::~OfflineTts() = default; | 34 | OfflineTts::~OfflineTts() = default; |
| 30 | 35 | ||
| 31 | GeneratedAudio OfflineTts::Generate(const std::string &text, int64_t sid /*=0*/, | 36 | GeneratedAudio OfflineTts::Generate(const std::string &text, int64_t sid /*=0*/, |
| @@ -9,6 +9,11 @@ | @@ -9,6 +9,11 @@ | ||
| 9 | #include <string> | 9 | #include <string> |
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 12 | #include "sherpa-onnx/csrc/offline-tts-model-config.h" | 17 | #include "sherpa-onnx/csrc/offline-tts-model-config.h" |
| 13 | #include "sherpa-onnx/csrc/parse-options.h" | 18 | #include "sherpa-onnx/csrc/parse-options.h" |
| 14 | 19 | ||
| @@ -38,6 +43,11 @@ class OfflineTts { | @@ -38,6 +43,11 @@ class OfflineTts { | ||
| 38 | public: | 43 | public: |
| 39 | ~OfflineTts(); | 44 | ~OfflineTts(); |
| 40 | explicit OfflineTts(const OfflineTtsConfig &config); | 45 | explicit OfflineTts(const OfflineTtsConfig &config); |
| 46 | + | ||
| 47 | +#if __ANDROID_API__ >= 9 | ||
| 48 | + OfflineTts(AAssetManager *mgr, const OfflineTtsConfig &config); | ||
| 49 | +#endif | ||
| 50 | + | ||
| 41 | // @param text A string containing words separated by spaces | 51 | // @param text A string containing words separated by spaces |
| 42 | // @param sid Speaker ID. Used only for multi-speaker models, e.g., models | 52 | // @param sid Speaker ID. Used only for multi-speaker models, e.g., models |
| 43 | // trained using the VCTK dataset. It is not used for | 53 | // trained using the VCTK dataset. It is not used for |
| @@ -7,12 +7,13 @@ | @@ -7,12 +7,13 @@ | ||
| 7 | #include <cassert> | 7 | #include <cassert> |
| 8 | #include <fstream> | 8 | #include <fstream> |
| 9 | #include <sstream> | 9 | #include <sstream> |
| 10 | -#include <strstream> | ||
| 11 | 10 | ||
| 12 | #include "sherpa-onnx/csrc/base64-decode.h" | 11 | #include "sherpa-onnx/csrc/base64-decode.h" |
| 13 | #include "sherpa-onnx/csrc/onnx-utils.h" | 12 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 14 | 13 | ||
| 15 | #if __ANDROID_API__ >= 9 | 14 | #if __ANDROID_API__ >= 9 |
| 15 | +#include <strstream> | ||
| 16 | + | ||
| 16 | #include "android/asset_manager.h" | 17 | #include "android/asset_manager.h" |
| 17 | #include "android/asset_manager_jni.h" | 18 | #include "android/asset_manager_jni.h" |
| 18 | #endif | 19 | #endif |
| @@ -21,10 +21,12 @@ | @@ -21,10 +21,12 @@ | ||
| 21 | 21 | ||
| 22 | #include "sherpa-onnx/csrc/macros.h" | 22 | #include "sherpa-onnx/csrc/macros.h" |
| 23 | #include "sherpa-onnx/csrc/offline-recognizer.h" | 23 | #include "sherpa-onnx/csrc/offline-recognizer.h" |
| 24 | +#include "sherpa-onnx/csrc/offline-tts.h" | ||
| 24 | #include "sherpa-onnx/csrc/online-recognizer.h" | 25 | #include "sherpa-onnx/csrc/online-recognizer.h" |
| 25 | #include "sherpa-onnx/csrc/onnx-utils.h" | 26 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 26 | #include "sherpa-onnx/csrc/voice-activity-detector.h" | 27 | #include "sherpa-onnx/csrc/voice-activity-detector.h" |
| 27 | #include "sherpa-onnx/csrc/wave-reader.h" | 28 | #include "sherpa-onnx/csrc/wave-reader.h" |
| 29 | +#include "sherpa-onnx/csrc/wave-writer.h" | ||
| 28 | 30 | ||
| 29 | #define SHERPA_ONNX_EXTERN_C extern "C" | 31 | #define SHERPA_ONNX_EXTERN_C extern "C" |
| 30 | 32 | ||
| @@ -124,7 +126,7 @@ class SherpaOnnxVad { | @@ -124,7 +126,7 @@ class SherpaOnnxVad { | ||
| 124 | 126 | ||
| 125 | void Pop() { vad_.Pop(); } | 127 | void Pop() { vad_.Pop(); } |
| 126 | 128 | ||
| 127 | - void Clear() { vad_.Clear();} | 129 | + void Clear() { vad_.Clear(); } |
| 128 | 130 | ||
| 129 | const SpeechSegment &Front() const { return vad_.Front(); } | 131 | const SpeechSegment &Front() const { return vad_.Front(); } |
| 130 | 132 | ||
| @@ -491,9 +493,173 @@ static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) { | @@ -491,9 +493,173 @@ static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) { | ||
| 491 | return ans; | 493 | return ans; |
| 492 | } | 494 | } |
| 493 | 495 | ||
| 496 | +class SherpaOnnxOfflineTts { | ||
| 497 | + public: | ||
| 498 | +#if __ANDROID_API__ >= 9 | ||
| 499 | + SherpaOnnxOfflineTts(AAssetManager *mgr, const OfflineTtsConfig &config) | ||
| 500 | + : tts_(mgr, config) {} | ||
| 501 | +#endif | ||
| 502 | + explicit SherpaOnnxOfflineTts(const OfflineTtsConfig &config) | ||
| 503 | + : tts_(config) {} | ||
| 504 | + | ||
| 505 | + GeneratedAudio Generate(const std::string &text, int64_t sid = 0, | ||
| 506 | + float speed = 1.0) const { | ||
| 507 | + return tts_.Generate(text, sid, speed); | ||
| 508 | + } | ||
| 509 | + | ||
| 510 | + private: | ||
| 511 | + OfflineTts tts_; | ||
| 512 | +}; | ||
| 513 | + | ||
| 514 | +static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { | ||
| 515 | + OfflineTtsConfig ans; | ||
| 516 | + | ||
| 517 | + jclass cls = env->GetObjectClass(config); | ||
| 518 | + jfieldID fid; | ||
| 519 | + | ||
| 520 | + fid = env->GetFieldID(cls, "model", | ||
| 521 | + "Lcom/k2fsa/sherpa/onnx/OfflineTtsModelConfig;"); | ||
| 522 | + jobject model = env->GetObjectField(config, fid); | ||
| 523 | + jclass model_config_cls = env->GetObjectClass(model); | ||
| 524 | + | ||
| 525 | + fid = env->GetFieldID(model_config_cls, "vits", | ||
| 526 | + "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;"); | ||
| 527 | + jobject vits = env->GetObjectField(model, fid); | ||
| 528 | + jclass vits_cls = env->GetObjectClass(vits); | ||
| 529 | + | ||
| 530 | + fid = env->GetFieldID(vits_cls, "model", "Ljava/lang/String;"); | ||
| 531 | + jstring s = (jstring)env->GetObjectField(vits, fid); | ||
| 532 | + const char *p = env->GetStringUTFChars(s, nullptr); | ||
| 533 | + ans.model.vits.model = p; | ||
| 534 | + env->ReleaseStringUTFChars(s, p); | ||
| 535 | + | ||
| 536 | + fid = env->GetFieldID(vits_cls, "lexicon", "Ljava/lang/String;"); | ||
| 537 | + s = (jstring)env->GetObjectField(vits, fid); | ||
| 538 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 539 | + ans.model.vits.lexicon = p; | ||
| 540 | + env->ReleaseStringUTFChars(s, p); | ||
| 541 | + | ||
| 542 | + fid = env->GetFieldID(vits_cls, "tokens", "Ljava/lang/String;"); | ||
| 543 | + s = (jstring)env->GetObjectField(vits, fid); | ||
| 544 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 545 | + ans.model.vits.tokens = p; | ||
| 546 | + env->ReleaseStringUTFChars(s, p); | ||
| 547 | + | ||
| 548 | + fid = env->GetFieldID(vits_cls, "noiseScale", "F"); | ||
| 549 | + ans.model.vits.noise_scale = env->GetFloatField(vits, fid); | ||
| 550 | + | ||
| 551 | + fid = env->GetFieldID(vits_cls, "noiseScaleW", "F"); | ||
| 552 | + ans.model.vits.noise_scale_w = env->GetFloatField(vits, fid); | ||
| 553 | + | ||
| 554 | + fid = env->GetFieldID(vits_cls, "lengthScale", "F"); | ||
| 555 | + ans.model.vits.length_scale = env->GetFloatField(vits, fid); | ||
| 556 | + | ||
| 557 | + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); | ||
| 558 | + ans.model.num_threads = env->GetIntField(model, fid); | ||
| 559 | + | ||
| 560 | + fid = env->GetFieldID(model_config_cls, "debug", "Z"); | ||
| 561 | + ans.model.debug = env->GetBooleanField(model, fid); | ||
| 562 | + | ||
| 563 | + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); | ||
| 564 | + s = (jstring)env->GetObjectField(model, fid); | ||
| 565 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 566 | + ans.model.provider = p; | ||
| 567 | + env->ReleaseStringUTFChars(s, p); | ||
| 568 | + | ||
| 569 | + return ans; | ||
| 570 | +} | ||
| 571 | + | ||
| 494 | } // namespace sherpa_onnx | 572 | } // namespace sherpa_onnx |
| 495 | 573 | ||
| 496 | SHERPA_ONNX_EXTERN_C | 574 | SHERPA_ONNX_EXTERN_C |
| 575 | +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new( | ||
| 576 | + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { | ||
| 577 | +#if __ANDROID_API__ >= 9 | ||
| 578 | + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); | ||
| 579 | + if (!mgr) { | ||
| 580 | + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); | ||
| 581 | + } | ||
| 582 | +#endif | ||
| 583 | + auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config); | ||
| 584 | + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); | ||
| 585 | + auto tts = new sherpa_onnx::SherpaOnnxOfflineTts( | ||
| 586 | +#if __ANDROID_API__ >= 9 | ||
| 587 | + mgr, | ||
| 588 | +#endif | ||
| 589 | + config); | ||
| 590 | + | ||
| 591 | + return (jlong)tts; | ||
| 592 | +} | ||
| 593 | + | ||
| 594 | +SHERPA_ONNX_EXTERN_C | ||
| 595 | +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile( | ||
| 596 | + JNIEnv *env, jobject /*obj*/, jobject _config) { | ||
| 597 | + auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config); | ||
| 598 | + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); | ||
| 599 | + auto tts = new sherpa_onnx::SherpaOnnxOfflineTts(config); | ||
| 600 | + | ||
| 601 | + return (jlong)tts; | ||
| 602 | +} | ||
| 603 | + | ||
| 604 | +SHERPA_ONNX_EXTERN_C | ||
| 605 | +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete( | ||
| 606 | + JNIEnv *env, jobject /*obj*/, jlong ptr) { | ||
| 607 | + delete reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr); | ||
| 608 | +} | ||
| 609 | + | ||
| 610 | +// see | ||
| 611 | +// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables | ||
| 612 | +static jobject NewInteger(JNIEnv *env, int32_t value) { | ||
| 613 | + jclass cls = env->FindClass("java/lang/Integer"); | ||
| 614 | + jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V"); | ||
| 615 | + return env->NewObject(cls, constructor, value); | ||
| 616 | +} | ||
| 617 | + | ||
| 618 | +SHERPA_ONNX_EXTERN_C | ||
| 619 | +JNIEXPORT jobjectArray JNICALL | ||
| 620 | +Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/, | ||
| 621 | + jlong ptr, jstring text, | ||
| 622 | + jint sid, jfloat speed) { | ||
| 623 | + const char *p_text = env->GetStringUTFChars(text, nullptr); | ||
| 624 | + SHERPA_ONNX_LOGE("string is: %s", p_text); | ||
| 625 | + | ||
| 626 | + auto audio = | ||
| 627 | + reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)->Generate( | ||
| 628 | + p_text, sid, speed); | ||
| 629 | + | ||
| 630 | + jfloatArray samples_arr = env->NewFloatArray(audio.samples.size()); | ||
| 631 | + env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(), | ||
| 632 | + audio.samples.data()); | ||
| 633 | + | ||
| 634 | + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( | ||
| 635 | + 2, env->FindClass("java/lang/Object"), nullptr); | ||
| 636 | + | ||
| 637 | + env->SetObjectArrayElement(obj_arr, 0, samples_arr); | ||
| 638 | + env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate)); | ||
| 639 | + | ||
| 640 | + env->ReleaseStringUTFChars(text, p_text); | ||
| 641 | + | ||
| 642 | + return obj_arr; | ||
| 643 | +} | ||
| 644 | + | ||
| 645 | +SHERPA_ONNX_EXTERN_C | ||
| 646 | +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl( | ||
| 647 | + JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples, | ||
| 648 | + jint sample_rate) { | ||
| 649 | + const char *p_filename = env->GetStringUTFChars(filename, nullptr); | ||
| 650 | + | ||
| 651 | + jfloat *p = env->GetFloatArrayElements(samples, nullptr); | ||
| 652 | + jsize n = env->GetArrayLength(samples); | ||
| 653 | + | ||
| 654 | + bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n); | ||
| 655 | + | ||
| 656 | + env->ReleaseStringUTFChars(filename, p_filename); | ||
| 657 | + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); | ||
| 658 | + | ||
| 659 | + return ok; | ||
| 660 | +} | ||
| 661 | + | ||
| 662 | +SHERPA_ONNX_EXTERN_C | ||
| 497 | JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_new( | 663 | JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_new( |
| 498 | JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { | 664 | JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { |
| 499 | #if __ANDROID_API__ >= 9 | 665 | #if __ANDROID_API__ >= 9 |
| @@ -513,6 +679,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_new( | @@ -513,6 +679,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_new( | ||
| 513 | return (jlong)model; | 679 | return (jlong)model; |
| 514 | } | 680 | } |
| 515 | 681 | ||
| 682 | +SHERPA_ONNX_EXTERN_C | ||
| 516 | JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromFile( | 683 | JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromFile( |
| 517 | JNIEnv *env, jobject /*obj*/, jobject _config) { | 684 | JNIEnv *env, jobject /*obj*/, jobject _config) { |
| 518 | auto config = sherpa_onnx::GetVadModelConfig(env, _config); | 685 | auto config = sherpa_onnx::GetVadModelConfig(env, _config); |
| @@ -566,14 +733,6 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env, | @@ -566,14 +733,6 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env, | ||
| 566 | model->Clear(); | 733 | model->Clear(); |
| 567 | } | 734 | } |
| 568 | 735 | ||
| 569 | -// see | ||
| 570 | -// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables | ||
| 571 | -static jobject NewInteger(JNIEnv *env, int32_t value) { | ||
| 572 | - jclass cls = env->FindClass("java/lang/Integer"); | ||
| 573 | - jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V"); | ||
| 574 | - return env->NewObject(cls, constructor, value); | ||
| 575 | -} | ||
| 576 | - | ||
| 577 | SHERPA_ONNX_EXTERN_C | 736 | SHERPA_ONNX_EXTERN_C |
| 578 | JNIEXPORT jobjectArray JNICALL | 737 | JNIEXPORT jobjectArray JNICALL |
| 579 | Java_com_k2fsa_sherpa_onnx_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) { | 738 | Java_com_k2fsa_sherpa_onnx_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) { |
-
请 注册 或 登录 后发表评论