Fangjun Kuang
Committed by GitHub

Add JNI (#57)

Makefile
*.jar
... ...
package android.content.res
// A dummy class for testing only.
// It stands in for android.content.res.AssetManager so that the Kotlin API,
// which takes an AssetManager parameter, can be compiled and run on a plain JVM.
class AssetManager
... ...
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
/**
 * Decodes one test wave file with a streaming zipformer transducer model
 * and prints the recognized text.
 *
 * All model files are expected below
 * `./sherpa-onnx-streaming-zipformer-en-2023-02-21`, relative to the current
 * working directory.
 *
 * @throws IllegalArgumentException if the test wave file cannot be read.
 */
fun main() {
    val modelDir = "./sherpa-onnx-streaming-zipformer-en-2023-02-21"

    val featConfig = FeatureConfig(
        sampleRate = 16000.0f,
        featureDim = 80,
    )

    val modelConfig = OnlineTransducerModelConfig(
        encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
        decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
        joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
        numThreads = 4,
        debug = false,
    )

    val endpointConfig = EndpointConfig()

    val config = OnlineRecognizerConfig(
        modelConfig = modelConfig,
        featConfig = featConfig,
        endpointConfig = endpointConfig,
        tokens = "$modelDir/tokens.txt",
        enableEndpoint = true,
    )

    val model = SherpaOnnx(
        assetManager = AssetManager(),
        config = config,
    )

    val waveFile = "$modelDir/test_wavs/1089-134686-0001.wav"
    // readWave returns null when the file cannot be read; fail with a message
    // that names the file instead of a bare NullPointerException.
    val samples = requireNotNull(
        WaveReader.readWave(
            assetManager = AssetManager(),
            filename = waveFile,
        )
    ) { "Failed to read wave file: $waveFile" }

    model.decodeSamples(samples)

    val tailPaddings = FloatArray(8000) // 0.5 seconds at 16 kHz
    model.decodeSamples(tailPaddings)

    model.inputFinished()
    println(model.text)
}
... ...
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
/**
 * A single endpointing rule.
 *
 * The native layer (jni.cc) reads the three properties by name through JNI
 * field lookups, so a property's name and type must stay in sync with it.
 *
 * NOTE(review): the two Float values are presumably durations in seconds;
 * confirm against the native endpoint implementation.
 */
data class EndpointRule(
var mustContainNonSilence: Boolean,
var minTrailingSilence: Float,
var minUtteranceLength: Float,
)
/**
 * The three endpointing rules handed to the native recognizer.
 *
 * Each rule has a default, so `EndpointConfig()` yields a complete
 * configuration. The native layer (jni.cc) looks the properties up by the
 * names `rule1`, `rule2` and `rule3`.
 */
data class EndpointConfig(
var rule1: EndpointRule = EndpointRule(false, 2.4f, 0.0f),
var rule2: EndpointRule = EndpointRule(true, 1.4f, 0.0f),
var rule3: EndpointRule = EndpointRule(false, 0.0f, 20.0f)
)
/**
 * File names of the transducer model parts plus runtime options.
 *
 * [encoder], [decoder] and [joiner] are copied verbatim into the native
 * model configuration as the encoder/decoder/joiner file names; [numThreads]
 * and [debug] are copied as well. The native layer (jni.cc) reads every
 * property by name.
 */
data class OnlineTransducerModelConfig(
var encoder: String,
var decoder: String,
var joiner: String,
var numThreads: Int = 4,
var debug: Boolean = false,
)
/**
 * Feature extraction settings.
 *
 * [sampleRate] is also the rate that SherpaOnnx.decodeSamples passes to the
 * native decoder together with the samples. The native layer (jni.cc) reads
 * both properties by name.
 */
data class FeatureConfig(
var sampleRate: Float = 16000.0f,
var featureDim: Int = 80,
)
/**
 * Top-level configuration passed to the native recognizer.
 *
 * The native layer (jni.cc) reads `featConfig`, `modelConfig`, `tokens`,
 * `endpointConfig` and `enableEndpoint` by name, so property names and types
 * must stay in sync with it. [tokens] is a file name that is copied verbatim
 * into the native configuration.
 */
data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineTransducerModelConfig,
var tokens: String,
var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean,
)
/**
 * Kotlin handle for a native streaming recognizer.
 *
 * Construction calls into the `sherpa-onnx-jni` library to create the native
 * object; the returned pointer is kept for the lifetime of this instance and
 * released from [finalize].
 *
 * @param assetManager passed through to the native constructor.
 * @param config recognizer configuration; its `featConfig.sampleRate` is the
 *   rate reported for every buffer given to [decodeSamples].
 */
class SherpaOnnx(
    assetManager: AssetManager,
    var config: OnlineRecognizerConfig,
) {
    // Opaque address of the native object returned by new().
    private val ptr: Long = new(assetManager, config)

    /** The text recognized so far, fetched from the native object on each read. */
    val text: String
        get() = getText(ptr)

    /** Feeds [samples] to the native decoder using the configured sample rate. */
    fun decodeSamples(samples: FloatArray) {
        decodeSamples(ptr, samples, sampleRate = config.featConfig.sampleRate)
    }

    /** Tells the native object that no more audio will be provided. */
    fun inputFinished() {
        inputFinished(ptr)
    }

    /** Resets the native recognizer state. */
    fun reset() {
        reset(ptr)
    }

    /** Returns whether the native recognizer reports an endpoint. */
    fun isEndpoint(): Boolean {
        return isEndpoint(ptr)
    }

    /** Releases the native object. */
    protected fun finalize() {
        delete(ptr)
    }

    private external fun new(
        assetManager: AssetManager,
        config: OnlineRecognizerConfig,
    ): Long

    private external fun delete(ptr: Long)

    private external fun decodeSamples(ptr: Long, samples: FloatArray, sampleRate: Float)

    private external fun inputFinished(ptr: Long)

    private external fun reset(ptr: Long)

    private external fun isEndpoint(ptr: Long): Boolean

    private external fun getText(ptr: Long): String

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
/** Returns a [FeatureConfig] with a 16 kHz sample rate and 80-dimensional features. */
fun getFeatureConfig(): FeatureConfig =
    FeatureConfig().apply {
        sampleRate = 16000.0f
        featureDim = 80
    }
... ...
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
/**
 * Entry point for reading wave files through the `sherpa-onnx-jni` library.
 */
class WaveReader {
companion object {
/**
 * Reads a mono wave file. No resampling is made.
 *
 * The native implementation opens [filename] through [assetManager] when
 * built for Android and as an ordinary file path otherwise.
 *
 * @param expected_sample_rate forwarded to the native wave reader.
 * @return the samples, or `null` if the native side could not read the file.
 */
external fun readWave(
assetManager: AssetManager, filename: String, expected_sample_rate: Float = 16000.0f
): FloatArray?
// Load the native library that implements readWave.
init {
System.loadLibrary("sherpa-onnx-jni")
}
}
}
... ...
#!/usr/bin/env bash
#
# Builds sherpa-onnx as shared libraries with JNI enabled, then compiles and
# runs the Kotlin example against the freshly built library.
#
# The script uses relative paths (build/, .github/scripts/), so it must be
# started from the repository root.

set -e

mkdir -p build
cd build

cmake \
  -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
  -DSHERPA_ONNX_ENABLE_TESTS=OFF \
  -DSHERPA_ONNX_ENABLE_CHECK=OFF \
  -DBUILD_SHARED_LIBS=ON \
  -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
  -DSHERPA_ONNX_ENABLE_JNI=ON \
  ..

make -j4
ls -lh lib

cd ..

# Quote the path so a checkout directory containing spaces still works, and
# only append the previous value when there is one (a trailing ':' would add
# an empty search-path entry).
export LD_LIBRARY_PATH="$PWD/build/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"

cd .github/scripts/

git lfs install

# Skip the download when the model is already present; otherwise `git clone`
# fails on the existing directory and `set -e` aborts a re-run.
if [ ! -d sherpa-onnx-streaming-zipformer-en-2023-02-21 ]; then
  git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
fi

kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt AssetManager.kt

ls -lh main.jar

java -Djava.library.path=../../build/lib -jar main.jar
... ...
# Builds the JNI library and runs the Kotlin example on Linux and macOS.
name: jni

# Run only when the build system, the C++/JNI sources, this workflow, or the
# test script changes.
on:
  push:
    branches:
      - master
    paths:
      - '.github/workflows/jni.yaml'
      - 'CMakeLists.txt'
      - 'cmake/**'
      - 'sherpa-onnx/csrc/*'
      - 'sherpa-onnx/jni/*'
      - '.github/scripts/test-jni.sh'
  pull_request:
    branches:
      - master
    paths:
      - '.github/workflows/jni.yaml'
      - 'CMakeLists.txt'
      - 'cmake/**'
      - 'sherpa-onnx/csrc/*'
      - 'sherpa-onnx/jni/*'
      - '.github/scripts/test-jni.sh'

# A newer run for the same ref cancels the one still in progress.
concurrency:
  group: jni-${{ github.ref }}
  cancel-in-progress: true

permissions:
  contents: read

jobs:
  jni:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-latest]

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Display kotlin version
        shell: bash
        run: |
          kotlinc -version

      - name: Display java version
        shell: bash
        run: |
          java -version
          echo "JAVA_HOME is: ${JAVA_HOME}"

      - name: Run JNI test
        shell: bash
        run: |
          .github/scripts/test-jni.sh
... ...
... ... @@ -16,6 +16,7 @@ option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF)
option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON)
option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
# Controls whether the sherpa-onnx/jni subdirectory is added to the build.
option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI interface" OFF)

# Put static and shared libraries into <build dir>/lib.
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
... ... @@ -44,6 +45,11 @@ if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
# When ANDROID_ABI is defined, force SHERPA_ONNX_ENABLE_JNI to ON in the
# cache, overriding whatever value it had before.
if(DEFINED ANDROID_ABI)
message(STATUS "Set SHERPA_ONNX_ENABLE_JNI to ON for Android")
set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE)
endif()
# Report the effective build settings at configure time.
message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
... ... @@ -51,6 +57,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}")
# Report the value of each feature switch at configure time.
message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}")
message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}")
message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}")
message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}")
# C++14 is the default; it is a cache entry, so it can be overridden on the
# command line. Compiler-specific language extensions are turned off.
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_CXX_EXTENSIONS OFF)
... ...
... ... @@ -2,3 +2,7 @@ add_subdirectory(csrc)
# Optional bindings: each subdirectory is added only when its switch is ON.
if(SHERPA_ONNX_ENABLE_PYTHON)
add_subdirectory(python)
endif()
if(SHERPA_ONNX_ENABLE_JNI)
add_subdirectory(jni)
endif()
... ...
# Builds the sherpa-onnx-jni library from jni.cc and links it against
# sherpa-onnx-core.
include_directories(${CMAKE_SOURCE_DIR})

# Outside Android builds, the JNI headers are taken from the JDK pointed to
# by JAVA_HOME. Both the linux and the darwin platform directories are added.
if(NOT DEFINED ANDROID_ABI)
  if(NOT DEFINED ENV{JAVA_HOME})
    message(FATAL_ERROR "Please set the environment variable JAVA_HOME")
  endif()

  include_directories(
    $ENV{JAVA_HOME}/include
    $ENV{JAVA_HOME}/include/linux
    $ENV{JAVA_HOME}/include/darwin
  )
endif()

add_library(sherpa-onnx-jni jni.cc)
target_link_libraries(sherpa-onnx-jni sherpa-onnx-core)

install(TARGETS sherpa-onnx-jni DESTINATION lib)
... ...
// sherpa-onnx/jni/jni.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
// 2022 Pingfeng Luo
// TODO(fangjun): Add documentation to functions/methods in this file
// and also show how to use them with kotlin, possibly with java.
// If you use ndk, you can find "jni.h" inside
// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
#include "jni.h" // NOLINT
#include <strstream>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#else
#include <fstream>
#endif
// SHERPA_ONNX_LOGE(fmt, ...) prints a printf-style message followed by a
// newline to stderr. When __ANDROID_API__ >= 8 the same message is also sent
// to the Android log with tag "sherpa-onnx" at priority ANDROID_LOG_WARN.
// The do { ... } while (0) wrapper makes the macro usable as one statement.
#if __ANDROID_API__ >= 8
#include <android/log.h>
#define SHERPA_ONNX_LOGE(...) \
do { \
fprintf(stderr, ##__VA_ARGS__); \
fprintf(stderr, "\n"); \
__android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \
} while (0)
#else
#define SHERPA_ONNX_LOGE(...) \
do { \
fprintf(stderr, ##__VA_ARGS__); \
fprintf(stderr, "\n"); \
} while (0)
#endif
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#define SHERPA_ONNX_EXTERN_C extern "C"
namespace sherpa_onnx {
// Bundles one OnlineRecognizer with the single OnlineStream created from it.
// This is the native object whose address is handed to Kotlin as a jlong.
class SherpaOnnx {
 public:
  // Creates the recognizer from `config` and one stream from the recognizer.
  // On Android (__ANDROID_API__ >= 9) the asset manager is forwarded to the
  // recognizer as well.
  SherpaOnnx(
#if __ANDROID_API__ >= 9
      AAssetManager *mgr,
#endif
      const sherpa_onnx::OnlineRecognizerConfig &config)
      : recognizer_(
#if __ANDROID_API__ >= 9
            mgr,
#endif
            config),
        stream_(recognizer_.CreateStream()),
        sample_rate_(config.feat_config.sampling_rate),
        // 0.32 seconds of silence at the configured sampling rate, so that
        // the padding keeps its duration when the rate is not 16 kHz.
        tail_padding_(static_cast<size_t>(sample_rate_ * 0.32), 0.0f) {
  }

  // Feeds `n` samples recorded at `sample_rate` to the stream and decodes
  // for as long as the recognizer reports the stream as ready.
  void DecodeSamples(float sample_rate, const float *samples, int32_t n) const {
    stream_->AcceptWaveform(sample_rate, samples, n);
    Decode();
  }

  // Appends the tail padding, marks the input as finished and decodes what
  // is left.
  void InputFinished() const {
    stream_->AcceptWaveform(sample_rate_, tail_padding_.data(),
                            tail_padding_.size());
    stream_->InputFinished();
    Decode();
  }

  // Returns the text of the current recognition result.
  std::string GetText() const {
    auto result = recognizer_.GetResult(stream_.get());
    return result.text;
  }

  bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }

  void Reset() const { recognizer_.Reset(stream_.get()); }

 private:
  // Runs the decoder until the recognizer says the stream is no longer ready.
  void Decode() const {
    while (recognizer_.IsReady(stream_.get())) {
      recognizer_.DecodeStream(stream_.get());
    }
  }

 private:
  sherpa_onnx::OnlineRecognizer recognizer_;
  std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
  // Declared before tail_padding_ because the latter is sized from it and
  // members are initialized in declaration order.
  float sample_rate_;
  std::vector<float> tail_padding_;
};
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
OnlineRecognizerConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "F");
ans.feat_config.sampling_rate = env->GetFloatField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
//---------- enable endpoint ----------
fid = env->GetFieldID(cls, "enableEndpoint", "Z");
ans.enable_endpoint = env->GetBooleanField(config, fid);
//---------- endpoint_config ----------
fid = env->GetFieldID(cls, "endpointConfig",
"Lcom/k2fsa/sherpa/onnx/EndpointConfig;");
jobject endpoint_config = env->GetObjectField(config, fid);
jclass endpoint_config_cls = env->GetObjectClass(endpoint_config);
fid = env->GetFieldID(endpoint_config_cls, "rule1",
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
jobject rule1 = env->GetObjectField(endpoint_config, fid);
jclass rule_class = env->GetObjectClass(rule1);
fid = env->GetFieldID(endpoint_config_cls, "rule2",
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
jobject rule2 = env->GetObjectField(endpoint_config, fid);
fid = env->GetFieldID(endpoint_config_cls, "rule3",
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
jobject rule3 = env->GetObjectField(endpoint_config, fid);
fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z");
ans.endpoint_config.rule1.must_contain_nonsilence =
env->GetBooleanField(rule1, fid);
ans.endpoint_config.rule2.must_contain_nonsilence =
env->GetBooleanField(rule2, fid);
ans.endpoint_config.rule3.must_contain_nonsilence =
env->GetBooleanField(rule3, fid);
fid = env->GetFieldID(rule_class, "minTrailingSilence", "F");
ans.endpoint_config.rule1.min_trailing_silence =
env->GetFloatField(rule1, fid);
ans.endpoint_config.rule2.min_trailing_silence =
env->GetFloatField(rule2, fid);
ans.endpoint_config.rule3.min_trailing_silence =
env->GetFloatField(rule3, fid);
fid = env->GetFieldID(rule_class, "minUtteranceLength", "F");
ans.endpoint_config.rule1.min_utterance_length =
env->GetFloatField(rule1, fid);
ans.endpoint_config.rule2.min_utterance_length =
env->GetFloatField(rule2, fid);
ans.endpoint_config.rule3.min_utterance_length =
env->GetFloatField(rule3, fid);
//---------- tokens ----------
fid = env->GetFieldID(cls, "tokens", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.tokens = p;
env->ReleaseStringUTFChars(s, p);
//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.encoder_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.decoder_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.joiner_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
return ans;
}
} // namespace sherpa_onnx
// Native side of SherpaOnnx.new(): builds the native configuration from the
// Kotlin config object, allocates a sherpa_onnx::SherpaOnnx and returns its
// address as a jlong handle. A missing asset manager on Android is logged
// but does not stop construction.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (mgr == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  }
#endif

  auto recognizer_config = sherpa_onnx::GetConfig(env, _config);

#if __ANDROID_API__ >= 9
  auto *handle = new sherpa_onnx::SherpaOnnx(mgr, recognizer_config);
#else
  auto *handle = new sherpa_onnx::SherpaOnnx(recognizer_config);
#endif

  return reinterpret_cast<jlong>(handle);
}
// Native side of SherpaOnnx.delete(): logs "freed!" and destroys the native
// object that `ptr` refers to.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  SHERPA_ONNX_LOGE("freed!");

  auto *handle = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  delete handle;
}
// Native side of SherpaOnnx.reset(): forwards to SherpaOnnx::Reset() on the
// object that `ptr` refers to.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->Reset();
}
// Native side of SherpaOnnx.isEndpoint(): returns whether the recognizer
// reports an endpoint for the stream of the object that `ptr` refers to.
//
// The Kotlin declaration returns Boolean, whose JNI type is jboolean, so the
// function is declared with that type rather than C++ bool.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  return model->IsEndpoint() ? JNI_TRUE : JNI_FALSE;
}
// Native side of SherpaOnnx.decodeSamples(): feeds the contents of the
// Kotlin FloatArray `samples`, recorded at `sample_rate`, to the native
// object that `ptr` refers to.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decodeSamples(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
    jfloat sample_rate) {
  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);

  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  if (!p) {
    // The elements could not be obtained; there is nothing to decode and
    // nothing to release.
    SHERPA_ONNX_LOGE("Failed to get the elements of the samples array");
    return;
  }

  jsize n = env->GetArrayLength(samples);

  model->DecodeSamples(sample_rate, p, n);

  // JNI_ABORT: the samples were only read, so no copy-back is needed.
  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
// Native side of SherpaOnnx.inputFinished(): forwards to
// SherpaOnnx::InputFinished() on the object that `ptr` refers to.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_inputFinished(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *handle = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  handle->InputFinished();
}
// Native side of SherpaOnnx.getText(): returns the current recognition text
// of the object that `ptr` refers to as a new Java string.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  // see
  // https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
  auto *handle = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  auto text = handle->GetText();
  return env->NewStringUTF(text.c_str());
}
// Native side of WaveReader.readWave(): reads the wave file `filename` and
// returns its samples as a new Java float array.
//
// On Android (__ANDROID_API__ >= 9) the file is opened as an asset through
// `asset_manager`; otherwise it is opened as an ordinary file.
//
// Returns nullptr (Kotlin null) when the file name, the asset manager or the
// asset cannot be obtained, or when ReadWave reports a failure. The UTF-8
// file name is released on every path that acquired it.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jfloatArray JNICALL
Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
    JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename,
    jfloat expected_sample_rate) {
  const char *p_filename = env->GetStringUTFChars(filename, nullptr);
  if (!p_filename) {
    SHERPA_ONNX_LOGE("Failed to get the wave filename");
    return nullptr;
  }

  bool is_ok = false;
  std::vector<float> samples;

#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (!mgr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
    env->ReleaseStringUTFChars(filename, p_filename);
    return nullptr;
  }

  AAsset *asset = AAssetManager_open(mgr, p_filename, AASSET_MODE_BUFFER);
  if (!asset) {
    SHERPA_ONNX_LOGE("Failed to open asset: %s", p_filename);
    env->ReleaseStringUTFChars(filename, p_filename);
    return nullptr;
  }

  size_t asset_length = AAsset_getLength(asset);
  std::vector<char> buffer(asset_length);
  AAsset_read(asset, buffer.data(), asset_length);
  // The data now lives in `buffer`, so the asset can be closed right away.
  AAsset_close(asset);

  std::istrstream is(buffer.data(), asset_length);
  samples = sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok);
#else
  std::ifstream is(p_filename, std::ios::binary);
  samples = sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok);
#endif

  env->ReleaseStringUTFChars(filename, p_filename);

  if (!is_ok) {
    return nullptr;
  }

  jfloatArray ans = env->NewFloatArray(samples.size());
  env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data());
  return ans;
}
... ...