Committed by
GitHub
Add Java and Kotlin API for punctuation models (#818)
正在显示
19 个修改的文件
包含
515 行增加
和
0 行删除
| @@ -106,6 +106,14 @@ jobs: | @@ -106,6 +106,14 @@ jobs: | ||
| 106 | make -j4 | 106 | make -j4 |
| 107 | ls -lh lib | 107 | ls -lh lib |
| 108 | 108 | ||
| 109 | + - name: Run java test (add punctuations) | ||
| 110 | + shell: bash | ||
| 111 | + run: | | ||
| 112 | + cd ./java-api-examples | ||
| 113 | + ./run-add-punctuation-zh-en.sh | ||
| 114 | + # Delete model files to save space | ||
| 115 | + rm -rf sherpa-onnx-punct-* | ||
| 116 | + | ||
| 109 | - name: Run java test (Spoken language identification) | 117 | - name: Run java test (Spoken language identification) |
| 110 | shell: bash | 118 | shell: bash |
| 111 | run: | | 119 | run: | |
java-api-examples/AddPunctuation.java
0 → 100644
| 1 | +// Copyright 2024 Xiaomi Corporation | ||
| 2 | + | ||
| 3 | +// This file shows how to use a punctuation model to add punctuations to text. | ||
| 4 | +// | ||
| 5 | +// The model supports both English and Chinese. | ||
| 6 | +import com.k2fsa.sherpa.onnx.*; | ||
| 7 | + | ||
| 8 | +public class AddPunctuation { | ||
| 9 | + public static void main(String[] args) { | ||
| 10 | + // please download the model from | ||
| 11 | + // https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models | ||
| 12 | + String model = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx"; | ||
| 13 | + OfflinePunctuationModelConfig modelConfig = | ||
| 14 | + OfflinePunctuationModelConfig.builder() | ||
| 15 | + .setCtTransformer(model) | ||
| 16 | + .setNumThreads(1) | ||
| 17 | + .setDebug(true) | ||
| 18 | + .build(); | ||
| 19 | + OfflinePunctuationConfig config = | ||
| 20 | + OfflinePunctuationConfig.builder().setModel(modelConfig).build(); | ||
| 21 | + | ||
| 22 | + OfflinePunctuation punct = new OfflinePunctuation(config); | ||
| 23 | + | ||
| 24 | + String[] sentences = | ||
| 25 | + new String[] { | ||
| 26 | + "这是一个测试你好吗How are you我很好thank you are you ok谢谢你", | ||
| 27 | + "我们都是木头人不会说话不会动", | ||
| 28 | + "The African blogosphere is rapidly expanding bringing more voices online in the form of" | ||
| 29 | + + " commentaries opinions analyses rants and poetry", | ||
| 30 | + }; | ||
| 31 | + | ||
| 32 | + System.out.println("---"); | ||
| 33 | + for (String text : sentences) { | ||
| 34 | + String out = punct.addPunctuation(text); | ||
| 35 | + System.out.printf("Input: %s\n", text); | ||
| 36 | + System.out.printf("Output: %s\n", out); | ||
| 37 | + System.out.println("---"); | ||
| 38 | + } | ||
| 39 | + } | ||
| 40 | +} |
| @@ -35,3 +35,11 @@ This directory contains examples for the JAVA API of sherpa-onnx. | @@ -35,3 +35,11 @@ This directory contains examples for the JAVA API of sherpa-onnx. | ||
| 35 | ```bash | 35 | ```bash |
| 36 | ./run-spoken-language-identification-whisper.sh | 36 | ./run-spoken-language-identification-whisper.sh |
| 37 | ``` | 37 | ``` |
| 38 | + | ||
| 39 | +## Add puncutations to text | ||
| 40 | + | ||
| 41 | +The punctuation model supports both English and Chinese. | ||
| 42 | + | ||
| 43 | +```bash | ||
| 44 | +./run-add-punctuation-zh-en.sh | ||
| 45 | +``` |
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +set -ex | ||
| 4 | + | ||
| 5 | +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then | ||
| 6 | + mkdir -p ../build | ||
| 7 | + pushd ../build | ||
| 8 | + cmake \ | ||
| 9 | + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ | ||
| 10 | + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ | ||
| 11 | + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ | ||
| 12 | + -DBUILD_SHARED_LIBS=ON \ | ||
| 13 | + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ | ||
| 14 | + -DSHERPA_ONNX_ENABLE_JNI=ON \ | ||
| 15 | + .. | ||
| 16 | + | ||
| 17 | + make -j4 | ||
| 18 | + ls -lh lib | ||
| 19 | + popd | ||
| 20 | +fi | ||
| 21 | + | ||
| 22 | +if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then | ||
| 23 | + pushd ../sherpa-onnx/java-api | ||
| 24 | + make | ||
| 25 | + popd | ||
| 26 | +fi | ||
| 27 | + | ||
| 28 | +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then | ||
| 29 | + cmake \ | ||
| 30 | + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ | ||
| 31 | + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ | ||
| 32 | + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ | ||
| 33 | + -DBUILD_SHARED_LIBS=ON \ | ||
| 34 | + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ | ||
| 35 | + -DSHERPA_ONNX_ENABLE_JNI=ON \ | ||
| 36 | + .. | ||
| 37 | + | ||
| 38 | + make -j4 | ||
| 39 | + ls -lh lib | ||
| 40 | +fi | ||
| 41 | + | ||
| 42 | +if [ ! -f ./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx ]; then | ||
| 43 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 | ||
| 44 | + tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 | ||
| 45 | + rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 | ||
| 46 | +fi | ||
| 47 | + | ||
| 48 | +java \ | ||
| 49 | + -Djava.library.path=$PWD/../build/lib \ | ||
| 50 | + -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \ | ||
| 51 | + ./AddPunctuation.java |
kotlin-api-examples/OfflinePunctuation.kt
0 → 120000
| 1 | +../sherpa-onnx/kotlin-api/OfflinePunctuation.kt |
| @@ -197,9 +197,29 @@ function testOfflineAsr() { | @@ -197,9 +197,29 @@ function testOfflineAsr() { | ||
| 197 | java -Djava.library.path=../build/lib -jar $out_filename | 197 | java -Djava.library.path=../build/lib -jar $out_filename |
| 198 | } | 198 | } |
| 199 | 199 | ||
| 200 | +function testPunctuation() { | ||
| 201 | + if [ ! -f ./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx ]; then | ||
| 202 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 | ||
| 203 | + tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 | ||
| 204 | + rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 | ||
| 205 | + fi | ||
| 206 | + | ||
| 207 | + out_filename=test_punctuation.jar | ||
| 208 | + kotlinc-jvm -include-runtime -d $out_filename \ | ||
| 209 | + ./test_punctuation.kt \ | ||
| 210 | + ./OfflinePunctuation.kt \ | ||
| 211 | + faked-asset-manager.kt \ | ||
| 212 | + faked-log.kt | ||
| 213 | + | ||
| 214 | + ls -lh $out_filename | ||
| 215 | + | ||
| 216 | + java -Djava.library.path=../build/lib -jar $out_filename | ||
| 217 | +} | ||
| 218 | + | ||
| 200 | testSpeakerEmbeddingExtractor | 219 | testSpeakerEmbeddingExtractor |
| 201 | testOnlineAsr | 220 | testOnlineAsr |
| 202 | testTts | 221 | testTts |
| 203 | testAudioTagging | 222 | testAudioTagging |
| 204 | testSpokenLanguageIdentification | 223 | testSpokenLanguageIdentification |
| 205 | testOfflineAsr | 224 | testOfflineAsr |
| 225 | +testPunctuation |
kotlin-api-examples/test_punctuation.kt
0 → 100644
| 1 | +package com.k2fsa.sherpa.onnx | ||
| 2 | + | ||
| 3 | +fun main() { | ||
| 4 | + testPunctuation() | ||
| 5 | +} | ||
| 6 | + | ||
| 7 | +fun testPunctuation() { | ||
| 8 | + val config = OfflinePunctuationConfig( | ||
| 9 | + model=OfflinePunctuationModelConfig( | ||
| 10 | + ctTransformer="./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx", | ||
| 11 | + numThreads=1, | ||
| 12 | + debug=true, | ||
| 13 | + provider="cpu", | ||
| 14 | + ) | ||
| 15 | + ) | ||
| 16 | + val punct = OfflinePunctuation(config = config) | ||
| 17 | + val sentences = arrayOf( | ||
| 18 | + "这是一个测试你好吗How are you我很好thank you are you ok谢谢你", | ||
| 19 | + "我们都是木头人不会说话不会动", | ||
| 20 | + "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry", | ||
| 21 | + ) | ||
| 22 | + println("---") | ||
| 23 | + for (text in sentences) { | ||
| 24 | + val out = punct.addPunctuation(text) | ||
| 25 | + println("Input: $text") | ||
| 26 | + println("Output: $out") | ||
| 27 | + println("---") | ||
| 28 | + } | ||
| 29 | + println(sentences) | ||
| 30 | + | ||
| 31 | +} |
| @@ -9,6 +9,11 @@ | @@ -9,6 +9,11 @@ | ||
| 9 | #include <utility> | 9 | #include <utility> |
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 12 | #include "sherpa-onnx/csrc/macros.h" | 17 | #include "sherpa-onnx/csrc/macros.h" |
| 13 | #include "sherpa-onnx/csrc/math.h" | 18 | #include "sherpa-onnx/csrc/math.h" |
| 14 | #include "sherpa-onnx/csrc/offline-ct-transformer-model.h" | 19 | #include "sherpa-onnx/csrc/offline-ct-transformer-model.h" |
| @@ -24,6 +29,12 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | @@ -24,6 +29,12 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { | ||
| 24 | const OfflinePunctuationConfig &config) | 29 | const OfflinePunctuationConfig &config) |
| 25 | : config_(config), model_(config.model) {} | 30 | : config_(config), model_(config.model) {} |
| 26 | 31 | ||
| 32 | +#if __ANDROID_API__ >= 9 | ||
| 33 | + OfflinePunctuationCtTransformerImpl(AAssetManager *mgr, | ||
| 34 | + const OfflinePunctuationConfig &config) | ||
| 35 | + : config_(config), model_(mgr, config.model) {} | ||
| 36 | +#endif | ||
| 37 | + | ||
| 27 | std::string AddPunctuation(const std::string &text) const override { | 38 | std::string AddPunctuation(const std::string &text) const override { |
| 28 | if (text.empty()) { | 39 | if (text.empty()) { |
| 29 | return {}; | 40 | return {}; |
| @@ -4,6 +4,11 @@ | @@ -4,6 +4,11 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/offline-punctuation-impl.h" | 5 | #include "sherpa-onnx/csrc/offline-punctuation-impl.h" |
| 6 | 6 | ||
| 7 | +#if __ANDROID_API__ >= 9 | ||
| 8 | +#include "android/asset_manager.h" | ||
| 9 | +#include "android/asset_manager_jni.h" | ||
| 10 | +#endif | ||
| 11 | + | ||
| 7 | #include "sherpa-onnx/csrc/macros.h" | 12 | #include "sherpa-onnx/csrc/macros.h" |
| 8 | #include "sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h" | 13 | #include "sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h" |
| 9 | 14 | ||
| @@ -19,4 +24,16 @@ std::unique_ptr<OfflinePunctuationImpl> OfflinePunctuationImpl::Create( | @@ -19,4 +24,16 @@ std::unique_ptr<OfflinePunctuationImpl> OfflinePunctuationImpl::Create( | ||
| 19 | return nullptr; | 24 | return nullptr; |
| 20 | } | 25 | } |
| 21 | 26 | ||
| 27 | +#if __ANDROID_API__ >= 9 | ||
| 28 | +std::unique_ptr<OfflinePunctuationImpl> OfflinePunctuationImpl::Create( | ||
| 29 | + AAssetManager *mgr, const OfflinePunctuationConfig &config) { | ||
| 30 | + if (!config.model.ct_transformer.empty()) { | ||
| 31 | + return std::make_unique<OfflinePunctuationCtTransformerImpl>(mgr, config); | ||
| 32 | + } | ||
| 33 | + | ||
| 34 | + SHERPA_ONNX_LOGE("Please specify a punctuation model! Return a null pointer"); | ||
| 35 | + return nullptr; | ||
| 36 | +} | ||
| 37 | +#endif | ||
| 38 | + | ||
| 22 | } // namespace sherpa_onnx | 39 | } // namespace sherpa_onnx |
| @@ -7,6 +7,10 @@ | @@ -7,6 +7,10 @@ | ||
| 7 | #include <memory> | 7 | #include <memory> |
| 8 | #include <string> | 8 | #include <string> |
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | +#if __ANDROID_API__ >= 9 | ||
| 11 | +#include "android/asset_manager.h" | ||
| 12 | +#include "android/asset_manager_jni.h" | ||
| 13 | +#endif | ||
| 10 | 14 | ||
| 11 | #include "sherpa-onnx/csrc/offline-punctuation.h" | 15 | #include "sherpa-onnx/csrc/offline-punctuation.h" |
| 12 | 16 | ||
| @@ -19,6 +23,11 @@ class OfflinePunctuationImpl { | @@ -19,6 +23,11 @@ class OfflinePunctuationImpl { | ||
| 19 | static std::unique_ptr<OfflinePunctuationImpl> Create( | 23 | static std::unique_ptr<OfflinePunctuationImpl> Create( |
| 20 | const OfflinePunctuationConfig &config); | 24 | const OfflinePunctuationConfig &config); |
| 21 | 25 | ||
| 26 | +#if __ANDROID_API__ >= 9 | ||
| 27 | + static std::unique_ptr<OfflinePunctuationImpl> Create( | ||
| 28 | + AAssetManager *mgr, const OfflinePunctuationConfig &config); | ||
| 29 | +#endif | ||
| 30 | + | ||
| 22 | virtual std::string AddPunctuation(const std::string &text) const = 0; | 31 | virtual std::string AddPunctuation(const std::string &text) const = 0; |
| 23 | }; | 32 | }; |
| 24 | 33 |
| @@ -4,6 +4,11 @@ | @@ -4,6 +4,11 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/offline-punctuation.h" | 5 | #include "sherpa-onnx/csrc/offline-punctuation.h" |
| 6 | 6 | ||
| 7 | +#if __ANDROID_API__ >= 9 | ||
| 8 | +#include "android/asset_manager.h" | ||
| 9 | +#include "android/asset_manager_jni.h" | ||
| 10 | +#endif | ||
| 11 | + | ||
| 7 | #include "sherpa-onnx/csrc/macros.h" | 12 | #include "sherpa-onnx/csrc/macros.h" |
| 8 | #include "sherpa-onnx/csrc/offline-punctuation-impl.h" | 13 | #include "sherpa-onnx/csrc/offline-punctuation-impl.h" |
| 9 | 14 | ||
| @@ -33,6 +38,12 @@ std::string OfflinePunctuationConfig::ToString() const { | @@ -33,6 +38,12 @@ std::string OfflinePunctuationConfig::ToString() const { | ||
| 33 | OfflinePunctuation::OfflinePunctuation(const OfflinePunctuationConfig &config) | 38 | OfflinePunctuation::OfflinePunctuation(const OfflinePunctuationConfig &config) |
| 34 | : impl_(OfflinePunctuationImpl::Create(config)) {} | 39 | : impl_(OfflinePunctuationImpl::Create(config)) {} |
| 35 | 40 | ||
| 41 | +#if __ANDROID_API__ >= 9 | ||
| 42 | +OfflinePunctuation::OfflinePunctuation(AAssetManager *mgr, | ||
| 43 | + const OfflinePunctuationConfig &config) | ||
| 44 | + : impl_(OfflinePunctuationImpl::Create(mgr, config)) {} | ||
| 45 | +#endif | ||
| 46 | + | ||
| 36 | OfflinePunctuation::~OfflinePunctuation() = default; | 47 | OfflinePunctuation::~OfflinePunctuation() = default; |
| 37 | 48 | ||
| 38 | std::string OfflinePunctuation::AddPunctuation(const std::string &text) const { | 49 | std::string OfflinePunctuation::AddPunctuation(const std::string &text) const { |
| @@ -8,6 +8,11 @@ | @@ -8,6 +8,11 @@ | ||
| 8 | #include <string> | 8 | #include <string> |
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | 10 | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 11 | #include "sherpa-onnx/csrc/offline-punctuation-model-config.h" | 16 | #include "sherpa-onnx/csrc/offline-punctuation-model-config.h" |
| 12 | #include "sherpa-onnx/csrc/parse-options.h" | 17 | #include "sherpa-onnx/csrc/parse-options.h" |
| 13 | 18 | ||
| @@ -33,6 +38,11 @@ class OfflinePunctuation { | @@ -33,6 +38,11 @@ class OfflinePunctuation { | ||
| 33 | public: | 38 | public: |
| 34 | explicit OfflinePunctuation(const OfflinePunctuationConfig &config); | 39 | explicit OfflinePunctuation(const OfflinePunctuationConfig &config); |
| 35 | 40 | ||
| 41 | +#if __ANDROID_API__ >= 9 | ||
| 42 | + OfflinePunctuation(AAssetManager *mgr, | ||
| 43 | + const OfflinePunctuationConfig &config); | ||
| 44 | +#endif | ||
| 45 | + | ||
| 36 | ~OfflinePunctuation(); | 46 | ~OfflinePunctuation(); |
| 37 | 47 | ||
| 38 | // Add punctuation to the input text and return it. | 48 | // Add punctuation to the input text and return it. |
| @@ -40,6 +40,10 @@ java_files += SpokenLanguageIdentificationWhisperConfig.java | @@ -40,6 +40,10 @@ java_files += SpokenLanguageIdentificationWhisperConfig.java | ||
| 40 | java_files += SpokenLanguageIdentificationConfig.java | 40 | java_files += SpokenLanguageIdentificationConfig.java |
| 41 | java_files += SpokenLanguageIdentification.java | 41 | java_files += SpokenLanguageIdentification.java |
| 42 | 42 | ||
| 43 | +java_files += OfflinePunctuationModelConfig.java | ||
| 44 | +java_files += OfflinePunctuationConfig.java | ||
| 45 | +java_files += OfflinePunctuation.java | ||
| 46 | + | ||
| 43 | class_files := $(java_files:%.java=%.class) | 47 | class_files := $(java_files:%.java=%.class) |
| 44 | 48 | ||
| 45 | java_files := $(addprefix src/$(package_dir)/,$(java_files)) | 49 | java_files := $(addprefix src/$(package_dir)/,$(java_files)) |
| 1 | +// Copyright 2024 Xiaomi Corporation | ||
| 2 | + | ||
| 3 | +package com.k2fsa.sherpa.onnx; | ||
| 4 | + | ||
| 5 | +public class OfflinePunctuation { | ||
| 6 | + static { | ||
| 7 | + System.loadLibrary("sherpa-onnx-jni"); | ||
| 8 | + } | ||
| 9 | + | ||
| 10 | + private long ptr = 0; // this is the asr engine ptrss | ||
| 11 | + | ||
| 12 | + public OfflinePunctuation(OfflinePunctuationConfig config) { | ||
| 13 | + ptr = newFromFile(config); | ||
| 14 | + } | ||
| 15 | + | ||
| 16 | + public String addPunctuation(String text) { | ||
| 17 | + return addPunctuation(ptr, text); | ||
| 18 | + } | ||
| 19 | + | ||
| 20 | + @Override | ||
| 21 | + protected void finalize() throws Throwable { | ||
| 22 | + release(); | ||
| 23 | + } | ||
| 24 | + | ||
| 25 | + // You'd better call it manually if it is not used anymore | ||
| 26 | + public void release() { | ||
| 27 | + if (this.ptr == 0) { | ||
| 28 | + return; | ||
| 29 | + } | ||
| 30 | + delete(this.ptr); | ||
| 31 | + this.ptr = 0; | ||
| 32 | + } | ||
| 33 | + | ||
| 34 | + private native void delete(long ptr); | ||
| 35 | + | ||
| 36 | + private native long newFromFile(OfflinePunctuationConfig config); | ||
| 37 | + | ||
| 38 | + private native String addPunctuation(long ptr, String text); | ||
| 39 | +} |
| 1 | +// Copyright 2024 Xiaomi Corporation | ||
| 2 | + | ||
| 3 | +package com.k2fsa.sherpa.onnx; | ||
| 4 | + | ||
| 5 | +public class OfflinePunctuationConfig { | ||
| 6 | + private final OfflinePunctuationModelConfig model; | ||
| 7 | + | ||
| 8 | + private OfflinePunctuationConfig(Builder builder) { | ||
| 9 | + this.model = builder.model; | ||
| 10 | + } | ||
| 11 | + | ||
| 12 | + public static Builder builder() { | ||
| 13 | + return new Builder(); | ||
| 14 | + } | ||
| 15 | + | ||
| 16 | + public OfflinePunctuationModelConfig getModel() { | ||
| 17 | + return model; | ||
| 18 | + } | ||
| 19 | + | ||
| 20 | + | ||
| 21 | + public static class Builder { | ||
| 22 | + private OfflinePunctuationModelConfig model = OfflinePunctuationModelConfig.builder().build(); | ||
| 23 | + | ||
| 24 | + public OfflinePunctuationConfig build() { | ||
| 25 | + return new OfflinePunctuationConfig(this); | ||
| 26 | + } | ||
| 27 | + | ||
| 28 | + public Builder setModel(OfflinePunctuationModelConfig model) { | ||
| 29 | + this.model = model; | ||
| 30 | + return this; | ||
| 31 | + } | ||
| 32 | + } | ||
| 33 | +} |
| 1 | +// Copyright 2024 Xiaomi Corporation | ||
| 2 | + | ||
| 3 | +package com.k2fsa.sherpa.onnx; | ||
| 4 | + | ||
| 5 | +public class OfflinePunctuationModelConfig { | ||
| 6 | + private final String ctTransformer; | ||
| 7 | + private final int numThreads; | ||
| 8 | + private final boolean debug; | ||
| 9 | + private final String provider; | ||
| 10 | + | ||
| 11 | + private OfflinePunctuationModelConfig(Builder builder) { | ||
| 12 | + this.ctTransformer = builder.ctTransformer; | ||
| 13 | + this.numThreads = builder.numThreads; | ||
| 14 | + this.debug = builder.debug; | ||
| 15 | + this.provider = builder.provider; | ||
| 16 | + } | ||
| 17 | + | ||
| 18 | + public static Builder builder() { | ||
| 19 | + return new Builder(); | ||
| 20 | + } | ||
| 21 | + | ||
| 22 | + public String getCtTransformer() { | ||
| 23 | + return ctTransformer; | ||
| 24 | + } | ||
| 25 | + | ||
| 26 | + public static class Builder { | ||
| 27 | + private String ctTransformer = ""; | ||
| 28 | + private int numThreads = 1; | ||
| 29 | + private boolean debug = true; | ||
| 30 | + private String provider = "cpu"; | ||
| 31 | + | ||
| 32 | + public OfflinePunctuationModelConfig build() { | ||
| 33 | + return new OfflinePunctuationModelConfig(this); | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | + public Builder setCtTransformer(String ctTransformer) { | ||
| 37 | + this.ctTransformer = ctTransformer; | ||
| 38 | + return this; | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + public Builder setNumThreads(int numThreads) { | ||
| 42 | + this.numThreads = numThreads; | ||
| 43 | + return this; | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + public Builder setDebug(boolean debug) { | ||
| 47 | + this.debug = debug; | ||
| 48 | + return this; | ||
| 49 | + } | ||
| 50 | + | ||
| 51 | + public Builder setProvider(String provider) { | ||
| 52 | + this.provider = provider; | ||
| 53 | + return this; | ||
| 54 | + } | ||
| 55 | + } | ||
| 56 | +} |
| @@ -13,6 +13,7 @@ set(sources | @@ -13,6 +13,7 @@ set(sources | ||
| 13 | audio-tagging.cc | 13 | audio-tagging.cc |
| 14 | jni.cc | 14 | jni.cc |
| 15 | keyword-spotter.cc | 15 | keyword-spotter.cc |
| 16 | + offline-punctuation.cc | ||
| 16 | offline-recognizer.cc | 17 | offline-recognizer.cc |
| 17 | offline-stream.cc | 18 | offline-stream.cc |
| 18 | online-recognizer.cc | 19 | online-recognizer.cc |
sherpa-onnx/jni/offline-punctuation.cc
0 → 100644
| 1 | +// sherpa-onnx/jni/offline-punctuation.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-punctuation.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 8 | +#include "sherpa-onnx/jni/common.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +static OfflinePunctuationConfig GetOfflinePunctuationConfig(JNIEnv *env, | ||
| 13 | + jobject config) { | ||
| 14 | + OfflinePunctuationConfig ans; | ||
| 15 | + | ||
| 16 | + jclass cls = env->GetObjectClass(config); | ||
| 17 | + jfieldID fid; | ||
| 18 | + | ||
| 19 | + fid = env->GetFieldID( | ||
| 20 | + cls, "model", "Lcom/k2fsa/sherpa/onnx/OfflinePunctuationModelConfig;"); | ||
| 21 | + jobject model_config = env->GetObjectField(config, fid); | ||
| 22 | + jclass model_config_cls = env->GetObjectClass(model_config); | ||
| 23 | + | ||
| 24 | + fid = | ||
| 25 | + env->GetFieldID(model_config_cls, "ctTransformer", "Ljava/lang/String;"); | ||
| 26 | + jstring s = (jstring)env->GetObjectField(model_config, fid); | ||
| 27 | + const char *p = env->GetStringUTFChars(s, nullptr); | ||
| 28 | + ans.model.ct_transformer = p; | ||
| 29 | + env->ReleaseStringUTFChars(s, p); | ||
| 30 | + | ||
| 31 | + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); | ||
| 32 | + ans.model.num_threads = env->GetIntField(model_config, fid); | ||
| 33 | + | ||
| 34 | + fid = env->GetFieldID(model_config_cls, "debug", "Z"); | ||
| 35 | + ans.model.debug = env->GetBooleanField(model_config, fid); | ||
| 36 | + | ||
| 37 | + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); | ||
| 38 | + s = (jstring)env->GetObjectField(model_config, fid); | ||
| 39 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 40 | + ans.model.provider = p; | ||
| 41 | + env->ReleaseStringUTFChars(s, p); | ||
| 42 | + | ||
| 43 | + return ans; | ||
| 44 | +} | ||
| 45 | + | ||
| 46 | +} // namespace sherpa_onnx | ||
| 47 | + | ||
| 48 | +SHERPA_ONNX_EXTERN_C | ||
| 49 | +JNIEXPORT jlong JNICALL | ||
| 50 | +Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_newFromAsset( | ||
| 51 | + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { | ||
| 52 | +#if __ANDROID_API__ >= 9 | ||
| 53 | + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); | ||
| 54 | + if (!mgr) { | ||
| 55 | + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); | ||
| 56 | + } | ||
| 57 | +#endif | ||
| 58 | + auto config = sherpa_onnx::GetOfflinePunctuationConfig(env, _config); | ||
| 59 | + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); | ||
| 60 | + auto model = new sherpa_onnx::OfflinePunctuation( | ||
| 61 | +#if __ANDROID_API__ >= 9 | ||
| 62 | + mgr, | ||
| 63 | +#endif | ||
| 64 | + config); | ||
| 65 | + | ||
| 66 | + return (jlong)model; | ||
| 67 | +} | ||
| 68 | + | ||
| 69 | +SHERPA_ONNX_EXTERN_C | ||
| 70 | +JNIEXPORT jlong JNICALL | ||
| 71 | +Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_newFromFile(JNIEnv *env, | ||
| 72 | + jobject /*obj*/, | ||
| 73 | + jobject _config) { | ||
| 74 | + auto config = sherpa_onnx::GetOfflinePunctuationConfig(env, _config); | ||
| 75 | + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); | ||
| 76 | + | ||
| 77 | + if (!config.Validate()) { | ||
| 78 | + SHERPA_ONNX_LOGE("Errors found in config!"); | ||
| 79 | + return 0; | ||
| 80 | + } | ||
| 81 | + | ||
| 82 | + auto model = new sherpa_onnx::OfflinePunctuation(config); | ||
| 83 | + | ||
| 84 | + return (jlong)model; | ||
| 85 | +} | ||
| 86 | + | ||
| 87 | +SHERPA_ONNX_EXTERN_C | ||
| 88 | +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_delete( | ||
| 89 | + JNIEnv *env, jobject /*obj*/, jlong ptr) { | ||
| 90 | + delete reinterpret_cast<sherpa_onnx::OfflinePunctuation *>(ptr); | ||
| 91 | +} | ||
| 92 | + | ||
| 93 | +SHERPA_ONNX_EXTERN_C | ||
| 94 | +JNIEXPORT jstring JNICALL | ||
| 95 | +Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_addPunctuation(JNIEnv *env, | ||
| 96 | + jobject /*obj*/, | ||
| 97 | + jlong ptr, | ||
| 98 | + jstring text) { | ||
| 99 | + auto punct = reinterpret_cast<const sherpa_onnx::OfflinePunctuation *>(ptr); | ||
| 100 | + | ||
| 101 | + const char *ptext = env->GetStringUTFChars(text, nullptr); | ||
| 102 | + | ||
| 103 | + std::string result = punct->AddPunctuation(ptext); | ||
| 104 | + | ||
| 105 | + env->ReleaseStringUTFChars(text, ptext); | ||
| 106 | + | ||
| 107 | + return env->NewStringUTF(result.c_str()); | ||
| 108 | +} |
sherpa-onnx/kotlin-api/OfflinePunctuation.kt
0 → 100644
| 1 | +package com.k2fsa.sherpa.onnx | ||
| 2 | + | ||
| 3 | +import android.content.res.AssetManager | ||
| 4 | + | ||
| 5 | +data class OfflinePunctuationModelConfig( | ||
| 6 | + var ctTransformer: String, | ||
| 7 | + var numThreads: Int = 1, | ||
| 8 | + var debug: Boolean = false, | ||
| 9 | + var provider: String = "cpu", | ||
| 10 | +) | ||
| 11 | + | ||
| 12 | + | ||
| 13 | +data class OfflinePunctuationConfig( | ||
| 14 | + var model: OfflinePunctuationModelConfig, | ||
| 15 | +) | ||
| 16 | + | ||
| 17 | +class OfflinePunctuation( | ||
| 18 | + assetManager: AssetManager? = null, | ||
| 19 | + config: OfflinePunctuationConfig, | ||
| 20 | +) { | ||
| 21 | + private val ptr: Long | ||
| 22 | + | ||
| 23 | + init { | ||
| 24 | + ptr = if (assetManager != null) { | ||
| 25 | + newFromAsset(assetManager, config) | ||
| 26 | + } else { | ||
| 27 | + newFromFile(config) | ||
| 28 | + } | ||
| 29 | + } | ||
| 30 | + | ||
| 31 | + protected fun finalize() { | ||
| 32 | + delete(ptr) | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + fun release() = finalize() | ||
| 36 | + | ||
| 37 | + fun addPunctuation(text: String) = addPunctuation(ptr, text) | ||
| 38 | + | ||
| 39 | + private external fun delete(ptr: Long) | ||
| 40 | + | ||
| 41 | + private external fun addPunctuation(ptr: Long, text: String): String | ||
| 42 | + | ||
| 43 | + private external fun newFromAsset( | ||
| 44 | + assetManager: AssetManager, | ||
| 45 | + config: OfflinePunctuationConfig, | ||
| 46 | + ): Long | ||
| 47 | + | ||
| 48 | + private external fun newFromFile( | ||
| 49 | + config: OfflinePunctuationConfig, | ||
| 50 | + ): Long | ||
| 51 | + | ||
| 52 | + companion object { | ||
| 53 | + init { | ||
| 54 | + System.loadLibrary("sherpa-onnx-jni") | ||
| 55 | + } | ||
| 56 | + } | ||
| 57 | +} |
-
请 注册 或 登录 后发表评论