Fangjun Kuang
Committed by GitHub

Add Java and Kotlin API for punctuation models (#818)

... ... @@ -106,6 +106,14 @@ jobs:
make -j4
ls -lh lib
- name: Run java test (add punctuations)
shell: bash
run: |
cd ./java-api-examples
./run-add-punctuation-zh-en.sh
# Delete model files to save space
rm -rf sherpa-onnx-punct-*
- name: Run java test (Spoken language identification)
shell: bash
run: |
... ...
// Copyright 2024 Xiaomi Corporation
// This file shows how to use a punctuation model to add punctuations to text.
//
// The model supports both English and Chinese.
import com.k2fsa.sherpa.onnx.*;
/**
 * Example program: restores punctuation in unpunctuated text using an offline
 * CT-Transformer punctuation model through the sherpa-onnx Java API.
 *
 * <p>The model directory must exist in the current working directory before
 * this program is run (see the download link below).
 */
public class AddPunctuation {
  /**
   * Loads the punctuation model, punctuates a few sample sentences and prints
   * each input together with its punctuated output.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    // please download the model from
    // https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models
    String model = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx";

    OfflinePunctuationModelConfig modelConfig =
        OfflinePunctuationModelConfig.builder()
            .setCtTransformer(model)
            .setNumThreads(1)
            .setDebug(true)
            .build();

    OfflinePunctuationConfig config =
        OfflinePunctuationConfig.builder().setModel(modelConfig).build();

    OfflinePunctuation punct = new OfflinePunctuation(config);
    try {
      String[] sentences =
          new String[] {
            "这是一个测试你好吗How are you我很好thank you are you ok谢谢你",
            "我们都是木头人不会说话不会动",
            "The African blogosphere is rapidly expanding bringing more voices online in the form of"
                + " commentaries opinions analyses rants and poetry",
          };

      System.out.println("---");
      for (String text : sentences) {
        String out = punct.addPunctuation(text);
        System.out.printf("Input: %s\n", text);
        System.out.printf("Output: %s\n", out);
        System.out.println("---");
      }
    } finally {
      // Free the native model explicitly instead of relying on finalize();
      // release() is safe to call even if the finalizer runs later.
      punct.release();
    }
  }
}
... ...
... ... @@ -35,3 +35,11 @@ This directory contains examples for the JAVA API of sherpa-onnx.
```bash
./run-spoken-language-identification-whisper.sh
```
## Add punctuations to text
The punctuation model supports both English and Chinese.
```bash
./run-add-punctuation-zh-en.sh
```
... ...
#!/usr/bin/env bash
#
# Runs the AddPunctuation.java example.
#
# Steps:
#   1. Build the sherpa-onnx JNI shared library into ../build/lib if missing.
#   2. Build ../sherpa-onnx/java-api/build/sherpa-onnx.jar if missing.
#   3. Download and unpack the punctuation model if missing.
#   4. Run AddPunctuation.java against the jar and the JNI library.

set -ex

# Build the JNI library unless a macOS (.dylib) or Linux (.so) build is
# already present. The build must run inside ../build so that `cmake ..`
# points at the repository root.
if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
  mkdir -p ../build
  pushd ../build
  cmake \
    -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
    -DSHERPA_ONNX_ENABLE_TESTS=OFF \
    -DSHERPA_ONNX_ENABLE_CHECK=OFF \
    -DBUILD_SHARED_LIBS=ON \
    -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
    -DSHERPA_ONNX_ENABLE_JNI=ON \
    ..

  make -j4
  ls -lh lib
  popd
fi

# Build the Java API jar if it has not been built yet.
if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then
  pushd ../sherpa-onnx/java-api
  make
  popd
fi

# Fetch the punctuation model on first use; the archive is removed after
# extraction to save disk space.
if [ ! -f ./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx ]; then
  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
fi

# Run the single-file Java program. $PWD is quoted so that a working
# directory containing spaces does not split the library path.
java \
  -Djava.library.path="$PWD/../build/lib" \
  -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \
  ./AddPunctuation.java
... ...
../sherpa-onnx/kotlin-api/OfflinePunctuation.kt
\ No newline at end of file
... ...
... ... @@ -197,9 +197,29 @@ function testOfflineAsr() {
java -Djava.library.path=../build/lib -jar $out_filename
}
# Builds and runs the Kotlin punctuation example.
#
# Downloads the punctuation model when it is not present yet, compiles the
# test together with the Kotlin API and the faked Android helpers into a
# runnable jar, and executes it against the JNI library in ../build/lib.
function testPunctuation() {
  # Download, unpack and clean up the model archive only on first use.
  if ! [ -f ./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx ]; then
    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
    tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
    rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  fi

  # Name of the jar produced by kotlinc-jvm (kept as a non-local variable,
  # like in the other test functions).
  out_filename=test_punctuation.jar

  kotlinc-jvm -include-runtime -d $out_filename \
    ./test_punctuation.kt ./OfflinePunctuation.kt \
    faked-asset-manager.kt faked-log.kt

  ls -lh $out_filename

  java -Djava.library.path=../build/lib -jar $out_filename
}
# Invoke each example test function in sequence; the punctuation test runs last.
testSpeakerEmbeddingExtractor
testOnlineAsr
testTts
testAudioTagging
testSpokenLanguageIdentification
testOfflineAsr
testPunctuation
... ...
package com.k2fsa.sherpa.onnx
/** Program entry point: runs the punctuation example. */
fun main() = testPunctuation()
/**
 * Loads the CT-Transformer punctuation model from the current working
 * directory, punctuates a few sample sentences and prints each input next to
 * its punctuated output.
 */
fun testPunctuation() {
    val config = OfflinePunctuationConfig(
        model = OfflinePunctuationModelConfig(
            ctTransformer = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx",
            numThreads = 1,
            debug = true,
            provider = "cpu",
        ),
    )
    // No AssetManager is passed, so the model is loaded from the file system.
    val punct = OfflinePunctuation(config = config)

    val sentences = arrayOf(
        "这是一个测试你好吗How are you我很好thank you are you ok谢谢你",
        "我们都是木头人不会说话不会动",
        "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry",
    )

    println("---")
    for (text in sentences) {
        val out = punct.addPunctuation(text)
        println("Input: $text")
        println("Output: $out")
        println("---")
    }
    // The former trailing `println(sentences)` was dropped: printing an Array
    // directly only shows its identity string (e.g. "[Ljava.lang.String;@1b6d3586").
}
... ...
... ... @@ -9,6 +9,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/offline-ct-transformer-model.h"
... ... @@ -24,6 +29,12 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
const OfflinePunctuationConfig &config)
: config_(config), model_(config.model) {}
#if __ANDROID_API__ >= 9
// Android-only constructor (compiled when __ANDROID_API__ >= 9).
// Stores `config` in config_ and constructs model_ from (mgr, config.model).
// NOTE(review): presumably the model then reads its files through `mgr`
// (APK assets) instead of the file system -- confirm in
// offline-ct-transformer-model.h.
OfflinePunctuationCtTransformerImpl(AAssetManager *mgr,
const OfflinePunctuationConfig &config)
: config_(config), model_(mgr, config.model) {}
#endif
std::string AddPunctuation(const std::string &text) const override {
if (text.empty()) {
return {};
... ...
... ... @@ -4,6 +4,11 @@
#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h"
... ... @@ -19,4 +24,16 @@ std::unique_ptr<OfflinePunctuationImpl> OfflinePunctuationImpl::Create(
return nullptr;
}
#if __ANDROID_API__ >= 9
// Android-only factory overload.
//
// Returns a CT-Transformer based implementation built from `mgr` and `config`
// when config.model.ct_transformer is non-empty; otherwise logs an error and
// returns a null pointer.
std::unique_ptr<OfflinePunctuationImpl> OfflinePunctuationImpl::Create(
    AAssetManager *mgr, const OfflinePunctuationConfig &config) {
  // Guard clause: without a CT-Transformer model there is nothing to build.
  if (config.model.ct_transformer.empty()) {
    SHERPA_ONNX_LOGE("Please specify a punctuation model! Return a null pointer");
    return nullptr;
  }

  return std::make_unique<OfflinePunctuationCtTransformerImpl>(mgr, config);
}
#endif
} // namespace sherpa_onnx
... ...
... ... @@ -7,6 +7,10 @@
#include <memory>
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-punctuation.h"
... ... @@ -19,6 +23,11 @@ class OfflinePunctuationImpl {
static std::unique_ptr<OfflinePunctuationImpl> Create(
const OfflinePunctuationConfig &config);
#if __ANDROID_API__ >= 9
// Android-only overload (guarded by __ANDROID_API__ >= 9) that additionally
// receives the AAssetManager. Its definition returns nullptr when
// config.model.ct_transformer is empty.
static std::unique_ptr<OfflinePunctuationImpl> Create(
AAssetManager *mgr, const OfflinePunctuationConfig &config);
#endif
virtual std::string AddPunctuation(const std::string &text) const = 0;
};
... ...
... ... @@ -4,6 +4,11 @@
#include "sherpa-onnx/csrc/offline-punctuation.h"
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
... ... @@ -33,6 +38,12 @@ std::string OfflinePunctuationConfig::ToString() const {
// Builds the implementation object from `config` via
// OfflinePunctuationImpl::Create and stores it in impl_.
// NOTE(review): Create can return nullptr, leaving impl_ null -- confirm that
// callers validate the config before constructing.
OfflinePunctuation::OfflinePunctuation(const OfflinePunctuationConfig &config)
: impl_(OfflinePunctuationImpl::Create(config)) {}
#if __ANDROID_API__ >= 9
// Android-only constructor: forwards `mgr` and `config` to
// OfflinePunctuationImpl::Create(mgr, config) and stores the result in impl_.
// That factory returns nullptr when config.model.ct_transformer is empty, so
// impl_ may be null in that case.
OfflinePunctuation::OfflinePunctuation(AAssetManager *mgr,
const OfflinePunctuationConfig &config)
: impl_(OfflinePunctuationImpl::Create(mgr, config)) {}
#endif
OfflinePunctuation::~OfflinePunctuation() = default;
std::string OfflinePunctuation::AddPunctuation(const std::string &text) const {
... ...
... ... @@ -8,6 +8,11 @@
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
... ... @@ -33,6 +38,11 @@ class OfflinePunctuation {
public:
explicit OfflinePunctuation(const OfflinePunctuationConfig &config);
#if __ANDROID_API__ >= 9
// Android-only constructor (guarded by __ANDROID_API__ >= 9) that takes the
// AAssetManager in addition to the config.
OfflinePunctuation(AAssetManager *mgr,
const OfflinePunctuationConfig &config);
#endif
~OfflinePunctuation();
// Add punctuation to the input text and return it.
... ...
... ... @@ -40,6 +40,10 @@ java_files += SpokenLanguageIdentificationWhisperConfig.java
java_files += SpokenLanguageIdentificationConfig.java
java_files += SpokenLanguageIdentification.java
java_files += OfflinePunctuationModelConfig.java
java_files += OfflinePunctuationConfig.java
java_files += OfflinePunctuation.java
class_files := $(java_files:%.java=%.class)
java_files := $(addprefix src/$(package_dir)/,$(java_files))
... ...
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
/**
 * Java wrapper around the native sherpa-onnx offline punctuation model.
 *
 * <p>An instance owns a native object. Call {@link #release()} when the
 * instance is no longer needed; {@link #finalize()} releases it as a fallback.
 */
public class OfflinePunctuation {
  static {
    System.loadLibrary("sherpa-onnx-jni");
  }

  // Address of the native punctuation object. 0 means there is no native
  // object: either creation failed (the native side returns 0 for an invalid
  // config) or release() has already been called.
  private long ptr = 0;

  /**
   * Creates the native punctuation model from files named in {@code config}.
   *
   * @param config model configuration
   */
  public OfflinePunctuation(OfflinePunctuationConfig config) {
    ptr = newFromFile(config);
  }

  /**
   * Adds punctuation to {@code text}.
   *
   * @param text unpunctuated input text
   * @return the text with punctuation added
   * @throws IllegalStateException if the native model could not be created or
   *     has already been released
   */
  public String addPunctuation(String text) {
    // Passing a 0 handle to native code would dereference a null pointer and
    // crash the JVM, so fail with a regular exception instead.
    if (ptr == 0) {
      throw new IllegalStateException(
          "OfflinePunctuation is not usable: model creation failed or release() was called");
    }
    return addPunctuation(ptr, text);
  }

  @Override
  protected void finalize() throws Throwable {
    try {
      release();
    } finally {
      super.finalize();
    }
  }

  // You'd better call it manually if it is not used anymore
  /** Frees the native object. Safe to call more than once. */
  public void release() {
    if (this.ptr == 0) {
      return;
    }
    delete(this.ptr);
    // Zero the handle so that a later release()/finalize() is a no-op.
    this.ptr = 0;
  }

  private native void delete(long ptr);

  private native long newFromFile(OfflinePunctuationConfig config);

  private native String addPunctuation(long ptr, String text);
}
... ...
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
/**
 * Immutable top-level configuration for {@code OfflinePunctuation}.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class OfflinePunctuationConfig {
  // Read from native code by field name ("model"); do not rename.
  private final OfflinePunctuationModelConfig model;

  private OfflinePunctuationConfig(Builder source) {
    model = source.model;
  }

  /** Returns the model configuration. */
  public OfflinePunctuationModelConfig getModel() {
    return this.model;
  }

  /** Returns a new builder holding a default model configuration. */
  public static Builder builder() {
    return new Builder();
  }

  /** Fluent builder for {@link OfflinePunctuationConfig}. */
  public static class Builder {
    // Starts out as the model config produced by an untouched model builder.
    private OfflinePunctuationModelConfig model = OfflinePunctuationModelConfig.builder().build();

    /** Replaces the model configuration and returns this builder for chaining. */
    public Builder setModel(OfflinePunctuationModelConfig model) {
      this.model = model;
      return this;
    }

    /** Creates the immutable configuration from the current builder state. */
    public OfflinePunctuationConfig build() {
      return new OfflinePunctuationConfig(this);
    }
  }
}
... ...
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
/**
 * Immutable model settings for offline punctuation.
 *
 * <p>Instances are created through {@link #builder()}. The private fields are
 * also read from native code by name, so they must not be renamed.
 */
public class OfflinePunctuationModelConfig {
  private final String ctTransformer;
  private final int numThreads;
  private final boolean debug;
  private final String provider;

  private OfflinePunctuationModelConfig(Builder builder) {
    this.ctTransformer = builder.ctTransformer;
    this.numThreads = builder.numThreads;
    this.debug = builder.debug;
    this.provider = builder.provider;
  }

  /** Returns a new builder initialised with the default values. */
  public static Builder builder() {
    return new Builder();
  }

  /** Returns the path of the CT-Transformer model file. */
  public String getCtTransformer() {
    return ctTransformer;
  }

  /** Returns the configured number of threads. */
  public int getNumThreads() {
    return numThreads;
  }

  /** Returns the value of the debug flag. */
  public boolean getDebug() {
    return debug;
  }

  /** Returns the configured provider name, e.g. "cpu". */
  public String getProvider() {
    return provider;
  }

  /** Fluent builder for {@link OfflinePunctuationModelConfig}. */
  public static class Builder {
    private String ctTransformer = "";
    private int numThreads = 1;
    private boolean debug = true;
    private String provider = "cpu";

    /** Creates the immutable configuration from the current builder state. */
    public OfflinePunctuationModelConfig build() {
      return new OfflinePunctuationModelConfig(this);
    }

    public Builder setCtTransformer(String ctTransformer) {
      this.ctTransformer = ctTransformer;
      return this;
    }

    public Builder setNumThreads(int numThreads) {
      this.numThreads = numThreads;
      return this;
    }

    public Builder setDebug(boolean debug) {
      this.debug = debug;
      return this;
    }

    public Builder setProvider(String provider) {
      this.provider = provider;
      return this;
    }
  }
}
... ...
... ... @@ -13,6 +13,7 @@ set(sources
audio-tagging.cc
jni.cc
keyword-spotter.cc
offline-punctuation.cc
offline-recognizer.cc
offline-stream.cc
online-recognizer.cc
... ...
// sherpa-onnx/jni/offline-punctuation.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-punctuation.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
// Converts a Java/Kotlin OfflinePunctuationConfig object into the C++
// OfflinePunctuationConfig struct.
//
// Fields are looked up by name, so the Java/Kotlin field names ("model",
// "ctTransformer", "numThreads", "debug", "provider") must match exactly.
//
// @param env     JNI environment of the calling thread.
// @param config  The Java/Kotlin OfflinePunctuationConfig instance.
// @return The populated C++ config. String fields that are null on the Java
//         side keep their C++ default value; if the nested model object is
//         null, the default-constructed config is returned unchanged.
static OfflinePunctuationConfig GetOfflinePunctuationConfig(JNIEnv *env,
                                                            jobject config) {
  OfflinePunctuationConfig ans;

  jclass cls = env->GetObjectClass(config);

  jfieldID fid = env->GetFieldID(
      cls, "model", "Lcom/k2fsa/sherpa/onnx/OfflinePunctuationModelConfig;");
  jobject model_config = env->GetObjectField(config, fid);
  if (model_config == nullptr) {
    // The Java builder allows setModel(null); GetObjectClass(nullptr) would
    // crash, so bail out with the defaults instead.
    SHERPA_ONNX_LOGE("The model field of OfflinePunctuationConfig is null");
    return ans;
  }

  jclass model_config_cls = env->GetObjectClass(model_config);

  // Copies the java.lang.String field `name` into *out. A null field or a
  // failed GetStringUTFChars leaves *out untouched.
  auto read_string = [env, model_config, model_config_cls](const char *name,
                                                           std::string *out) {
    jfieldID f = env->GetFieldID(model_config_cls, name, "Ljava/lang/String;");
    auto s = static_cast<jstring>(env->GetObjectField(model_config, f));
    if (s == nullptr) {
      return;
    }

    const char *p = env->GetStringUTFChars(s, nullptr);
    if (p == nullptr) {
      return;
    }

    *out = p;
    env->ReleaseStringUTFChars(s, p);
  };

  read_string("ctTransformer", &ans.model.ct_transformer);

  fid = env->GetFieldID(model_config_cls, "numThreads", "I");
  ans.model.num_threads = env->GetIntField(model_config, fid);

  fid = env->GetFieldID(model_config_cls, "debug", "Z");
  ans.model.debug = env->GetBooleanField(model_config, fid);

  read_string("provider", &ans.model.provider);

  return ans;
}
} // namespace sherpa_onnx
// JNI entry point for OfflinePunctuation.newFromAsset().
//
// Creates a native OfflinePunctuation from the given config. On Android
// (__ANDROID_API__ >= 9) the asset manager is passed to the constructor; on
// other platforms it is ignored and the file-based constructor is used.
//
// @return The address of the new native object as a jlong, or 0 if the
//         Android asset manager could not be obtained.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_newFromAsset(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (!mgr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
    // Do not continue with a null asset manager; report failure the same way
    // newFromFile does, by returning a 0 handle.
    return 0;
  }
#endif
  auto config = sherpa_onnx::GetOfflinePunctuationConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

  auto model = new sherpa_onnx::OfflinePunctuation(
#if __ANDROID_API__ >= 9
      mgr,
#endif
      config);

  return (jlong)model;
}
// JNI entry point for OfflinePunctuation.newFromFile().
//
// Converts the Java/Kotlin config, logs it, validates it and, when valid,
// creates a native OfflinePunctuation.
//
// @return The address of the new native object as a jlong, or 0 when the
//         config fails validation.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_newFromFile(JNIEnv *env,
                                                          jobject /*obj*/,
                                                          jobject _config) {
  sherpa_onnx::OfflinePunctuationConfig punct_config =
      sherpa_onnx::GetOfflinePunctuationConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", punct_config.ToString().c_str());

  if (punct_config.Validate()) {
    // Ownership of the new object passes to the Java/Kotlin side, which frees
    // it through the delete() native method.
    return reinterpret_cast<jlong>(
        new sherpa_onnx::OfflinePunctuation(punct_config));
  }

  SHERPA_ONNX_LOGE("Errors found in config!");
  return 0;
}
// JNI entry point for OfflinePunctuation.delete().
//
// Destroys the native OfflinePunctuation whose address is stored in `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_delete(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto *native_punct =
      reinterpret_cast<sherpa_onnx::OfflinePunctuation *>(ptr);
  delete native_punct;
}
// JNI entry point for OfflinePunctuation.addPunctuation().
//
// @param ptr   Address of the native OfflinePunctuation (as returned by
//              newFromFile/newFromAsset).
// @param text  The input text.
// @return The punctuated text as a new Java string. Returns an empty string
//         when `ptr` is 0, and nullptr when the characters of `text` cannot
//         be obtained.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_addPunctuation(JNIEnv *env,
                                                             jobject /*obj*/,
                                                             jlong ptr,
                                                             jstring text) {
  auto punct = reinterpret_cast<const sherpa_onnx::OfflinePunctuation *>(ptr);
  if (punct == nullptr) {
    // newFromFile/newFromAsset return 0 on failure; calling through a null
    // pointer would crash the whole JVM.
    SHERPA_ONNX_LOGE("addPunctuation() called with a null native handle");
    return env->NewStringUTF("");
  }

  const char *ptext = env->GetStringUTFChars(text, nullptr);
  if (ptext == nullptr) {
    // GetStringUTFChars failed; do not build a std::string from nullptr.
    return nullptr;
  }

  std::string result = punct->AddPunctuation(ptext);

  env->ReleaseStringUTFChars(text, ptext);

  return env->NewStringUTF(result.c_str());
}
... ...
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
/**
 * Model settings for offline punctuation.
 *
 * The native (JNI) code reads these properties by field name, so the names
 * must stay in sync with the lookups in sherpa-onnx/jni/offline-punctuation.cc.
 *
 * @property ctTransformer path of the CT-Transformer model file
 * @property numThreads number of threads passed to the native model config
 * @property debug debug flag passed to the native model config
 * @property provider provider name passed to the native model config, e.g. "cpu"
 */
data class OfflinePunctuationModelConfig(
var ctTransformer: String,
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
)
/**
 * Top-level configuration for [OfflinePunctuation].
 *
 * The native (JNI) code reads [model] by field name, so it must not be renamed.
 *
 * @property model the model settings
 */
data class OfflinePunctuationConfig(
var model: OfflinePunctuationModelConfig,
)
/**
 * Kotlin wrapper around the native sherpa-onnx offline punctuation model.
 *
 * The instance owns a native object. Call [release] when it is no longer
 * needed; [finalize] frees it as a fallback.
 *
 * @param assetManager when non-null, the model is created through
 *   `newFromAsset`; when null, it is created through `newFromFile`
 * @param config the punctuation configuration
 */
class OfflinePunctuation(
    assetManager: AssetManager? = null,
    config: OfflinePunctuationConfig,
) {
    // Address of the native object. 0 means there is no native object: either
    // creation failed (the native side can return 0) or it was already freed.
    private var ptr: Long

    init {
        ptr = if (assetManager != null) {
            newFromAsset(assetManager, config)
        } else {
            newFromFile(config)
        }
    }

    /**
     * Frees the native object if it is still alive.
     *
     * The handle is zeroed afterwards so that an explicit [release] followed
     * by the garbage collector's finalizer call does not free it twice.
     */
    protected fun finalize() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    /** Frees the native object. Safe to call more than once. */
    fun release() = finalize()

    /**
     * Adds punctuation to [text] and returns the result.
     *
     * @throws IllegalStateException if the native model could not be created
     *   or has already been released
     */
    fun addPunctuation(text: String): String {
        // A 0 handle would be dereferenced as a null pointer in native code.
        check(ptr != 0L) {
            "OfflinePunctuation is not usable: model creation failed or release() was called"
        }
        return addPunctuation(ptr, text)
    }

    private external fun delete(ptr: Long)

    private external fun addPunctuation(ptr: Long, text: String): String

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: OfflinePunctuationConfig,
    ): Long

    private external fun newFromFile(
        config: OfflinePunctuationConfig,
    ): Long

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
... ...