Fangjun Kuang
Committed by GitHub

Refactor the JNI interface to make it more modular and maintainable (#802)

Showing 117 changed files with 2,970 additions and 2,784 deletions
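In short: the monolithic SherpaOnnx class (shown in full further below), which owned a hidden decoding stream, is replaced by an OnlineRecognizer plus explicitly created and released OnlineStream objects, and the Kotlin API files are now shared across the demo apps via symlinks into sherpa-onnx/kotlin-api/. A minimal before/after sketch of the call pattern, assembled from the call sites in this commit (the config/asset setup is elided; `samples` is assumed to be a FloatArray of 16 kHz audio):

// Before (old API): the recognizer owns its stream internally.
//   val model = SherpaOnnx(assetManager = assets, config = config)
//   model.acceptWaveform(samples, sampleRate = 16000)
//   while (model.isReady()) { model.decode() }
//   val text = model.text

// After (new API): streams are explicit and must be released by the caller.
val recognizer = OnlineRecognizer(assetManager = assets, config = config)
val stream = recognizer.createStream()
stream.acceptWaveform(samples, sampleRate = 16000)
while (recognizer.isReady(stream)) {
    recognizer.decode(stream)
}
val text = recognizer.getResult(stream).text
stream.release()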
name: apk-asr

on:
  push:
    tags:
      - '*'
  workflow_dispatch:

concurrency:
  group: apk-asr-${{ github.ref }}
  cancel-in-progress: true

permissions:
  contents: write

jobs:
  apk_asr:
    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
    runs-on: ${{ matrix.os }}
    name: apk for asr ${{ matrix.index }}/${{ matrix.total }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        total: ["1"]
        index: ["0"]

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      # https://github.com/actions/setup-java
      - uses: actions/setup-java@v4
        with:
          distribution: 'temurin' # See 'Supported distributions' for available options
          java-version: '21'

      - name: ccache
        uses: hendrikmuhs/ccache-action@v1.2
        with:
          key: ${{ matrix.os }}-android

      - name: Display NDK HOME
        shell: bash
        run: |
          echo "ANDROID_NDK_LATEST_HOME: ${ANDROID_NDK_LATEST_HOME}"
          ls -lh ${ANDROID_NDK_LATEST_HOME}

      - name: Install Python dependencies
        shell: bash
        run: |
          python3 -m pip install --upgrade pip jinja2

      - name: Setup build tool version variable
        shell: bash
        run: |
          echo "---"
          ls -lh /usr/local/lib/android/
          echo "---"
          ls -lh /usr/local/lib/android/sdk
          echo "---"
          ls -lh /usr/local/lib/android/sdk/build-tools
          echo "---"
          BUILD_TOOL_VERSION=$(ls /usr/local/lib/android/sdk/build-tools/ | tail -n 1)
          echo "BUILD_TOOL_VERSION=$BUILD_TOOL_VERSION" >> $GITHUB_ENV
          echo "Last build tool version is: $BUILD_TOOL_VERSION"

      - name: Generate build script
        shell: bash
        run: |
          cd scripts/apk
          total=${{ matrix.total }}
          index=${{ matrix.index }}
          ./generate-asr-apk-script.py --total $total --index $index
          chmod +x build-apk-asr.sh
          mv -v ./build-apk-asr.sh ../..

      - name: build APK
        shell: bash
        run: |
          export CMAKE_CXX_COMPILER_LAUNCHER=ccache
          export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
          cmake --version
          export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
          ./build-apk-asr.sh

      - name: Display APK
        shell: bash
        run: |
          ls -lh ./apks/
          du -h -d1 .

      # https://github.com/marketplace/actions/sign-android-release
      - uses: r0adkll/sign-android-release@v1
        name: Sign app APK
        with:
          releaseDirectory: ./apks
          signingKeyBase64: ${{ secrets.ANDROID_SIGNING_KEY }}
          alias: ${{ secrets.ANDROID_SIGNING_KEY_ALIAS }}
          keyStorePassword: ${{ secrets.ANDROID_SIGNING_KEY_STORE_PASSWORD }}
        env:
          BUILD_TOOLS_VERSION: ${{ env.BUILD_TOOL_VERSION }}

      - name: Display APK after signing
        shell: bash
        run: |
          ls -lh ./apks/
          du -h -d1 .

      - name: Rename APK after signing
        shell: bash
        run: |
          cd apks
          rm -fv signingKey.jks
          rm -fv *.apk.idsig
          rm -fv *-aligned.apk
          all_apks=$(ls -1 *-signed.apk)
          echo "----"
          echo $all_apks
          echo "----"
          for apk in ${all_apks[@]}; do
            n=$(echo $apk | sed -e s/-signed//)
            mv -v $apk $n
          done
          cd ..
          ls -lh ./apks/
          du -h -d1 .

      - name: Display APK after rename
        shell: bash
        run: |
          ls -lh ./apks/
          du -h -d1 .

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"
            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1
            git clone https://huggingface.co/csukuangfj/sherpa-onnx-apk huggingface
            cd huggingface
            git fetch
            git pull
            git merge -m "merge remote" --ff origin main
            mkdir -p asr
            cp -v ../apks/*.apk ./asr/
            git status
            git lfs track "*.apk"
            git add .
            git commit -m "add more apks"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-apk main
... ...
... ... @@ -95,3 +95,4 @@ sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
spoken-language-identification-test-wavs
my-release-key*
vits-zh-hf-fanchen-C
sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
... ...
... ... @@ -16,6 +16,7 @@
tools:targetApi="31">
<activity
android:name=".MainActivity"
android:label="ASR: Next-gen Kaldi"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt
\ No newline at end of file
... ...
... ... @@ -12,16 +12,19 @@ import android.widget.Button
import android.widget.TextView
import androidx.appcompat.app.AppCompatActivity
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.*
import kotlin.concurrent.thread
private const val TAG = "sherpa-onnx"
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
// To enable microphone in android emulator, use
//
// adb emu avd hostmicon
class MainActivity : AppCompatActivity() {
private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
private lateinit var model: SherpaOnnx
private lateinit var recognizer: OnlineRecognizer
private var audioRecord: AudioRecord? = null
private lateinit var recordButton: Button
private lateinit var textView: TextView
... ... @@ -87,7 +90,6 @@ class MainActivity : AppCompatActivity() {
audioRecord!!.startRecording()
recordButton.setText(R.string.stop)
isRecording = true
model.reset(true)
textView.text = ""
lastText = ""
idx = 0
... ... @@ -108,6 +110,7 @@ class MainActivity : AppCompatActivity() {
private fun processSamples() {
Log.i(TAG, "processing samples")
val stream = recognizer.createStream()
val interval = 0.1 // i.e., 100 ms
val bufferSize = (interval * sampleRateInHz).toInt() // in samples
... ... @@ -117,29 +120,41 @@ class MainActivity : AppCompatActivity() {
val ret = audioRecord?.read(buffer, 0, buffer.size)
if (ret != null && ret > 0) {
val samples = FloatArray(ret) { buffer[it] / 32768.0f }
model.acceptWaveform(samples, sampleRate=sampleRateInHz)
while (model.isReady()) {
model.decode()
stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
while (recognizer.isReady(stream)) {
recognizer.decode(stream)
}
val isEndpoint = model.isEndpoint()
val text = model.text
val isEndpoint = recognizer.isEndpoint(stream)
var text = recognizer.getResult(stream).text
// For streaming paraformer, we need to manually add some
// paddings so that it has enough right context to
// recognize the last word of this segment
if (isEndpoint && recognizer.config.modelConfig.paraformer.encoder.isNotBlank()) {
val tailPaddings = FloatArray((0.8 * sampleRateInHz).toInt())
stream.acceptWaveform(tailPaddings, sampleRate = sampleRateInHz)
while (recognizer.isReady(stream)) {
recognizer.decode(stream)
}
text = recognizer.getResult(stream).text
}
var textToDisplay = lastText;
var textToDisplay = lastText
if(text.isNotBlank()) {
if (lastText.isBlank()) {
textToDisplay = "${idx}: ${text}"
if (text.isNotBlank()) {
textToDisplay = if (lastText.isBlank()) {
"${idx}: $text"
} else {
textToDisplay = "${lastText}\n${idx}: ${text}"
"${lastText}\n${idx}: $text"
}
}
if (isEndpoint) {
model.reset()
recognizer.reset(stream)
if (text.isNotBlank()) {
lastText = "${lastText}\n${idx}: ${text}"
textToDisplay = lastText;
lastText = "${lastText}\n${idx}: $text"
textToDisplay = lastText
idx += 1
}
}
... ... @@ -149,6 +164,7 @@ class MainActivity : AppCompatActivity() {
}
}
}
stream.release()
}
private fun initMicrophone(): Boolean {
... ... @@ -180,7 +196,7 @@ class MainActivity : AppCompatActivity() {
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// for a list of available models
val type = 0
println("Select model type ${type}")
Log.i(TAG, "Select model type $type")
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
modelConfig = getModelConfig(type = type)!!,
... ... @@ -189,7 +205,7 @@ class MainActivity : AppCompatActivity() {
enableEndpoint = true,
)
model = SherpaOnnx(
recognizer = OnlineRecognizer(
assetManager = application.assets,
config = config,
)
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineRecognizer.kt
\ No newline at end of file
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineStream.kt
\ No newline at end of file
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/WaveReader.kt
\ No newline at end of file
... ...
... ... @@ -16,6 +16,7 @@
tools:targetApi="31">
<activity
android:name=".MainActivity"
android:label="2pass ASR: Next-gen Kaldi"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
... ... @@ -29,4 +30,4 @@
</activity>
</application>
</manifest>
\ No newline at end of file
</manifest>
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt
\ No newline at end of file
... ...
... ... @@ -17,11 +17,13 @@ import kotlin.concurrent.thread
private const val TAG = "sherpa-onnx"
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
// adb emu avd hostmicon
// to enable microphone inside the emulator
class MainActivity : AppCompatActivity() {
private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
private lateinit var onlineRecognizer: SherpaOnnx
private lateinit var offlineRecognizer: SherpaOnnxOffline
private lateinit var onlineRecognizer: OnlineRecognizer
private lateinit var offlineRecognizer: OfflineRecognizer
private var audioRecord: AudioRecord? = null
private lateinit var recordButton: Button
private lateinit var textView: TextView
... ... @@ -93,7 +95,6 @@ class MainActivity : AppCompatActivity() {
audioRecord!!.startRecording()
recordButton.setText(R.string.stop)
isRecording = true
onlineRecognizer.reset(true)
samplesBuffer.clear()
textView.text = ""
lastText = ""
... ... @@ -115,6 +116,7 @@ class MainActivity : AppCompatActivity() {
private fun processSamples() {
Log.i(TAG, "processing samples")
val stream = onlineRecognizer.createStream()
val interval = 0.1 // i.e., 100 ms
val bufferSize = (interval * sampleRateInHz).toInt() // in samples
... ... @@ -126,29 +128,29 @@ class MainActivity : AppCompatActivity() {
val samples = FloatArray(ret) { buffer[it] / 32768.0f }
samplesBuffer.add(samples)
onlineRecognizer.acceptWaveform(samples, sampleRate = sampleRateInHz)
while (onlineRecognizer.isReady()) {
onlineRecognizer.decode()
stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
while (onlineRecognizer.isReady(stream)) {
onlineRecognizer.decode(stream)
}
val isEndpoint = onlineRecognizer.isEndpoint()
val isEndpoint = onlineRecognizer.isEndpoint(stream)
var textToDisplay = lastText
var text = onlineRecognizer.text
var text = onlineRecognizer.getResult(stream).text
if (text.isNotBlank()) {
if (lastText.isBlank()) {
textToDisplay = if (lastText.isBlank()) {
// textView.text = "${idx}: ${text}"
textToDisplay = "${idx}: ${text}"
"${idx}: $text"
} else {
textToDisplay = "${lastText}\n${idx}: ${text}"
"${lastText}\n${idx}: $text"
}
}
if (isEndpoint) {
onlineRecognizer.reset()
onlineRecognizer.reset(stream)
if (text.isNotBlank()) {
text = runSecondPass()
lastText = "${lastText}\n${idx}: ${text}"
lastText = "${lastText}\n${idx}: $text"
idx += 1
} else {
samplesBuffer.clear()
... ... @@ -160,6 +162,7 @@ class MainActivity : AppCompatActivity() {
}
}
}
stream.release()
}
private fun initMicrophone(): Boolean {
... ... @@ -190,8 +193,8 @@ class MainActivity : AppCompatActivity() {
// Please change getModelConfig() to add new models
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// for a list of available models
val firstType = 1
println("Select model type ${firstType} for the first pass")
val firstType = 9
Log.i(TAG, "Select model type $firstType for the first pass")
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
modelConfig = getModelConfig(type = firstType)!!,
... ... @@ -199,7 +202,7 @@ class MainActivity : AppCompatActivity() {
enableEndpoint = true,
)
onlineRecognizer = SherpaOnnx(
onlineRecognizer = OnlineRecognizer(
assetManager = application.assets,
config = config,
)
... ... @@ -209,15 +212,15 @@ class MainActivity : AppCompatActivity() {
// Please change getOfflineModelConfig() to add new models
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// for a list of available models
val secondType = 1
println("Select model type ${secondType} for the second pass")
val secondType = 0
Log.i(TAG, "Select model type $secondType for the second pass")
val config = OfflineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
modelConfig = getOfflineModelConfig(type = secondType)!!,
)
offlineRecognizer = SherpaOnnxOffline(
offlineRecognizer = OfflineRecognizer(
assetManager = application.assets,
config = config,
)
... ... @@ -244,8 +247,15 @@ class MainActivity : AppCompatActivity() {
val n = maxOf(0, samples.size - 8000)
samplesBuffer.clear()
samplesBuffer.add(samples.sliceArray(n..samples.size-1))
samplesBuffer.add(samples.sliceArray(n until samples.size))
return offlineRecognizer.decode(samples.sliceArray(0..n), sampleRateInHz)
val stream = offlineRecognizer.createStream()
stream.acceptWaveform(samples.sliceArray(0..n), sampleRateInHz)
offlineRecognizer.decode(stream)
val result = offlineRecognizer.getResult(stream)
stream.release()
return result.text
}
}
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineRecognizer.kt
\ No newline at end of file
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt
\ No newline at end of file
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineRecognizer.kt
\ No newline at end of file
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineStream.kt
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx

import android.content.res.AssetManager

data class EndpointRule(
    var mustContainNonSilence: Boolean,
    var minTrailingSilence: Float,
    var minUtteranceLength: Float,
)

data class EndpointConfig(
    var rule1: EndpointRule = EndpointRule(false, 2.0f, 0.0f),
    var rule2: EndpointRule = EndpointRule(true, 1.2f, 0.0f),
    var rule3: EndpointRule = EndpointRule(false, 0.0f, 20.0f)
)

data class OnlineTransducerModelConfig(
    var encoder: String = "",
    var decoder: String = "",
    var joiner: String = "",
)

data class OnlineParaformerModelConfig(
    var encoder: String = "",
    var decoder: String = "",
)

data class OnlineZipformer2CtcModelConfig(
    var model: String = "",
)

data class OnlineModelConfig(
    var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
    var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
    var zipformer2Ctc: OnlineZipformer2CtcModelConfig = OnlineZipformer2CtcModelConfig(),
    var tokens: String,
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
    var modelType: String = "",
)

data class OnlineLMConfig(
    var model: String = "",
    var scale: Float = 0.5f,
)

data class FeatureConfig(
    var sampleRate: Int = 16000,
    var featureDim: Int = 80,
)

data class OnlineRecognizerConfig(
    var featConfig: FeatureConfig = FeatureConfig(),
    var modelConfig: OnlineModelConfig,
    var lmConfig: OnlineLMConfig = OnlineLMConfig(),
    var endpointConfig: EndpointConfig = EndpointConfig(),
    var enableEndpoint: Boolean = true,
    var decodingMethod: String = "greedy_search",
    var maxActivePaths: Int = 4,
    var hotwordsFile: String = "",
    var hotwordsScore: Float = 1.5f,
)

data class OfflineTransducerModelConfig(
    var encoder: String = "",
    var decoder: String = "",
    var joiner: String = "",
)

data class OfflineParaformerModelConfig(
    var model: String = "",
)

data class OfflineWhisperModelConfig(
    var encoder: String = "",
    var decoder: String = "",
    var language: String = "en", // Used with multilingual model
    var task: String = "transcribe", // transcribe or translate
    var tailPaddings: Int = 1000, // Padding added at the end of the samples
)

data class OfflineModelConfig(
    var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(),
    var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(),
    var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(),
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
    var modelType: String = "",
    var tokens: String,
)

data class OfflineRecognizerConfig(
    var featConfig: FeatureConfig = FeatureConfig(),
    var modelConfig: OfflineModelConfig,
    // var lmConfig: OfflineLMConfig(), // TODO(fangjun): enable it
    var decodingMethod: String = "greedy_search",
    var maxActivePaths: Int = 4,
    var hotwordsFile: String = "",
    var hotwordsScore: Float = 1.5f,
)

class SherpaOnnx(
    assetManager: AssetManager? = null,
    var config: OnlineRecognizerConfig,
) {
    private val ptr: Long

    init {
        if (assetManager != null) {
            ptr = new(assetManager, config)
        } else {
            ptr = newFromFile(config)
        }
    }

    protected fun finalize() {
        delete(ptr)
    }

    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
        acceptWaveform(ptr, samples, sampleRate)

    fun inputFinished() = inputFinished(ptr)
    fun reset(recreate: Boolean = false, hotwords: String = "") = reset(ptr, recreate, hotwords)
    fun decode() = decode(ptr)
    fun isEndpoint(): Boolean = isEndpoint(ptr)
    fun isReady(): Boolean = isReady(ptr)

    val text: String
        get() = getText(ptr)

    val tokens: Array<String>
        get() = getTokens(ptr)

    private external fun delete(ptr: Long)

    private external fun new(
        assetManager: AssetManager,
        config: OnlineRecognizerConfig,
    ): Long

    private external fun newFromFile(
        config: OnlineRecognizerConfig,
    ): Long

    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
    private external fun inputFinished(ptr: Long)
    private external fun getText(ptr: Long): String
    private external fun reset(ptr: Long, recreate: Boolean, hotwords: String)
    private external fun decode(ptr: Long)
    private external fun isEndpoint(ptr: Long): Boolean
    private external fun isReady(ptr: Long): Boolean
    private external fun getTokens(ptr: Long): Array<String>

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}

class SherpaOnnxOffline(
    assetManager: AssetManager? = null,
    var config: OfflineRecognizerConfig,
) {
    private val ptr: Long

    init {
        if (assetManager != null) {
            ptr = new(assetManager, config)
        } else {
            ptr = newFromFile(config)
        }
    }

    protected fun finalize() {
        delete(ptr)
    }

    fun decode(samples: FloatArray, sampleRate: Int) = decode(ptr, samples, sampleRate)

    private external fun delete(ptr: Long)

    private external fun new(
        assetManager: AssetManager,
        config: OfflineRecognizerConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineRecognizerConfig,
    ): Long

    private external fun decode(ptr: Long, samples: FloatArray, sampleRate: Int): String

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}

fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
    return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
}

/*
Please see
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models.

We only add a few here. Please change the following code
to add your own. (It should be straightforward to add a new model
by following the code)

@param type
0 - csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23 (Chinese)
    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-zh-14m-2023-02-23
    encoder/joiner int8, decoder float32

1 - csukuangfj/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17 (English)
    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-en-20m-2023-02-17-english
    encoder/joiner int8, decoder fp32
*/
fun getModelConfig(type: Int): OnlineModelConfig? {
    when (type) {
        0 -> {
            val modelDir = "sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23"
            return OnlineModelConfig(
                transducer = OnlineTransducerModelConfig(
                    encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
                    decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
                    joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
                ),
                tokens = "$modelDir/tokens.txt",
                modelType = "zipformer",
            )
        }

        1 -> {
            val modelDir = "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17"
            return OnlineModelConfig(
                transducer = OnlineTransducerModelConfig(
                    encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
                    decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
                    joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
                ),
                tokens = "$modelDir/tokens.txt",
                modelType = "zipformer",
            )
        }
    }
    return null
}

/*
Please see
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models.

We only add a few here. Please change the following code
to add your own LM model. (It should be straightforward to train a new NN LM model
by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py)

@param type
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
*/
fun getOnlineLMConfig(type: Int): OnlineLMConfig {
    when (type) {
        0 -> {
            val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
            return OnlineLMConfig(
                model = "$modelDir/with-state-epoch-99-avg-1.int8.onnx",
                scale = 0.5f,
            )
        }
    }
    return OnlineLMConfig()
}

// for English models, use a small value for rule2.minTrailingSilence, e.g., 0.8
fun getEndpointConfig(): EndpointConfig {
    return EndpointConfig(
        rule1 = EndpointRule(false, 2.4f, 0.0f),
        rule2 = EndpointRule(true, 0.8f, 0.0f),
        rule3 = EndpointRule(false, 0.0f, 20.0f)
    )
}

/*
Please see
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models.

We only add a few here. Please change the following code
to add your own. (It should be straightforward to add a new model
by following the code)

@param type
0 - csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 (Chinese)
    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese
    int8

1 - icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04 (English)
    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#icefall-asr-multidataset-pruned-transducer-stateless7-2023-05-04-english
    encoder int8, decoder/joiner float32

2 - sherpa-onnx-whisper-tiny.en
    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en
    encoder int8, decoder int8

3 - sherpa-onnx-whisper-base.en
    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en
    encoder int8, decoder int8

4 - pkufool/icefall-asr-zipformer-wenetspeech-20230615 (Chinese)
    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#pkufool-icefall-asr-zipformer-wenetspeech-20230615-chinese
    encoder/joiner int8, decoder fp32
*/
fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
    when (type) {
        0 -> {
            val modelDir = "sherpa-onnx-paraformer-zh-2023-03-28"
            return OfflineModelConfig(
                paraformer = OfflineParaformerModelConfig(
                    model = "$modelDir/model.int8.onnx",
                ),
                tokens = "$modelDir/tokens.txt",
                modelType = "paraformer",
            )
        }

        1 -> {
            val modelDir = "icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04"
            return OfflineModelConfig(
                transducer = OfflineTransducerModelConfig(
                    encoder = "$modelDir/encoder-epoch-30-avg-4.int8.onnx",
                    decoder = "$modelDir/decoder-epoch-30-avg-4.onnx",
                    joiner = "$modelDir/joiner-epoch-30-avg-4.onnx",
                ),
                tokens = "$modelDir/tokens.txt",
                modelType = "zipformer",
            )
        }

        2 -> {
            val modelDir = "sherpa-onnx-whisper-tiny.en"
            return OfflineModelConfig(
                whisper = OfflineWhisperModelConfig(
                    encoder = "$modelDir/tiny.en-encoder.int8.onnx",
                    decoder = "$modelDir/tiny.en-decoder.int8.onnx",
                ),
                tokens = "$modelDir/tiny.en-tokens.txt",
                modelType = "whisper",
            )
        }

        3 -> {
            val modelDir = "sherpa-onnx-whisper-base.en"
            return OfflineModelConfig(
                whisper = OfflineWhisperModelConfig(
                    encoder = "$modelDir/base.en-encoder.int8.onnx",
                    decoder = "$modelDir/base.en-decoder.int8.onnx",
                ),
                tokens = "$modelDir/base.en-tokens.txt",
                modelType = "whisper",
            )
        }

        4 -> {
            val modelDir = "icefall-asr-zipformer-wenetspeech-20230615"
            return OfflineModelConfig(
                transducer = OfflineTransducerModelConfig(
                    encoder = "$modelDir/encoder-epoch-12-avg-4.int8.onnx",
                    decoder = "$modelDir/decoder-epoch-12-avg-4.onnx",
                    joiner = "$modelDir/joiner-epoch-12-avg-4.int8.onnx",
                ),
                tokens = "$modelDir/tokens.txt",
                modelType = "zipformer",
            )
        }

        5 -> {
            val modelDir = "sherpa-onnx-zipformer-multi-zh-hans-2023-9-2"
            return OfflineModelConfig(
                transducer = OfflineTransducerModelConfig(
                    encoder = "$modelDir/encoder-epoch-20-avg-1.int8.onnx",
                    decoder = "$modelDir/decoder-epoch-20-avg-1.onnx",
                    joiner = "$modelDir/joiner-epoch-20-avg-1.int8.onnx",
                ),
                tokens = "$modelDir/tokens.txt",
                modelType = "zipformer2",
            )
        }
    }
    return null
}
package com.k2fsa.sherpa.onnx

import android.content.res.AssetManager

class WaveReader {
    companion object {
        // Read a mono wave file asset
        // The returned array has two entries:
        //  - the first entry contains a 1-D float array
        //  - the second entry is the sample rate
        external fun readWaveFromAsset(
            assetManager: AssetManager,
            filename: String,
        ): Array<Any>

        // Read a mono wave file from disk
        // The returned array has two entries:
        //  - the first entry contains a 1-D float array
        //  - the second entry is the sample rate
        external fun readWaveFromFile(
            filename: String,
        ): Array<Any>

        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
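Since both readers return an untyped Array<Any>, call sites have to unpack it themselves. A minimal sketch of that unpacking, assuming a hypothetical mono wave asset named test.wav:

// Entry 0 holds the audio samples; entry 1 holds the sample rate.
val result = WaveReader.readWaveFromAsset(assetManager, "test.wav")
val samples = result[0] as FloatArray
val sampleRate = result[1] as Int
Log.i(TAG, "Read ${samples.size} samples at $sampleRate Hz")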
../../../../../../../../../../../../sherpa-onnx/kotlin-api/AudioTagging.kt
\ No newline at end of file
... ...
... ... @@ -46,7 +46,6 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.AudioEvent
import com.k2fsa.sherpa.onnx.Tagger
import kotlin.concurrent.thread
... ...
... ... @@ -13,13 +13,14 @@ import androidx.compose.material3.Surface
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.Tagger
import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme
const val TAG = "sherpa-onnx"
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
// adb emu avd hostmicon
// to enable mic inside the emulator
class MainActivity : ComponentActivity() {
private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
override fun onCreate(savedInstanceState: Bundle?) {
... ...
../../../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx
package com.k2fsa.sherpa.onnx.audio.tagging
import android.content.res.AssetManager
import android.util.Log
import com.k2fsa.sherpa.onnx.AudioTagging
import com.k2fsa.sherpa.onnx.getAudioTaggingConfig
object Tagger {
... ... @@ -17,7 +19,7 @@ object Tagger {
return
}
Log.i(TAG, "Initializing audio tagger")
Log.i("sherpa-onnx", "Initializing audio tagger")
val config = getAudioTaggingConfig(type = 0, numThreads = numThreads)!!
_tagger = AudioTagging(assetManager, config)
}
... ...
... ... @@ -33,7 +33,7 @@ import androidx.wear.compose.material.Button
import androidx.wear.compose.material.MaterialTheme
import androidx.wear.compose.material.Text
import com.k2fsa.sherpa.onnx.AudioEvent
import com.k2fsa.sherpa.onnx.Tagger
import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.theme.SherpaOnnxAudioTaggingWearOsTheme
import kotlin.concurrent.thread
... ...
... ... @@ -17,11 +17,14 @@ import androidx.activity.compose.setContent
import androidx.compose.runtime.Composable
import androidx.core.app.ActivityCompat
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
import com.k2fsa.sherpa.onnx.Tagger
import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
const val TAG = "sherpa-onnx"
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
// adb emu avd hostmicon
// to enable mic inside the emulator
class MainActivity : ComponentActivity() {
private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
override fun onCreate(savedInstanceState: Bundle?) {
... ...
... ... @@ -15,7 +15,8 @@
android:theme="@style/Theme.SherpaOnnx"
tools:targetApi="31">
<activity
android:name=".MainActivity"
android:name=".kws.MainActivity"
android:label="Keyword-spotter"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt
\ No newline at end of file
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/KeywordSpotter.kt
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx
package com.k2fsa.sherpa.onnx.kws
import android.Manifest
import android.content.pm.PackageManager
... ... @@ -14,7 +14,13 @@ import android.widget.TextView
import android.widget.Toast
import androidx.appcompat.app.AppCompatActivity
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.*
import com.k2fsa.sherpa.onnx.KeywordSpotter
import com.k2fsa.sherpa.onnx.KeywordSpotterConfig
import com.k2fsa.sherpa.onnx.OnlineStream
import com.k2fsa.sherpa.onnx.R
import com.k2fsa.sherpa.onnx.getFeatureConfig
import com.k2fsa.sherpa.onnx.getKeywordsFile
import com.k2fsa.sherpa.onnx.getKwsModelConfig
import kotlin.concurrent.thread
private const val TAG = "sherpa-onnx"
... ... @@ -23,7 +29,8 @@ private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
class MainActivity : AppCompatActivity() {
private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
private lateinit var model: SherpaOnnxKws
private lateinit var kws: KeywordSpotter
private lateinit var stream: OnlineStream
private var audioRecord: AudioRecord? = null
private lateinit var recordButton: Button
private lateinit var textView: TextView
... ... @@ -87,15 +94,18 @@ class MainActivity : AppCompatActivity() {
Log.i(TAG, keywords)
keywords = keywords.replace("\n", "/")
keywords = keywords.trim()
// If keywords is an empty string, it just resets the decoding stream
// and always returns true in this case.
// If keywords is not empty, it will create a new decoding stream with
// the given keywords appended to the default keywords.
// Return false if errors occured when adding keywords, true otherwise.
val status = model.reset(keywords)
if (!status) {
Log.i(TAG, "Failed to reset with keywords.")
Toast.makeText(this, "Failed to set keywords.", Toast.LENGTH_LONG).show();
// Return false if errors occurred when adding keywords, true otherwise.
stream.release()
stream = kws.createStream(keywords)
if (stream.ptr == 0L) {
Log.i(TAG, "Failed to create stream with keywords: $keywords")
Toast.makeText(this, "Failed to set keywords to $keywords.", Toast.LENGTH_LONG)
.show()
return
}
... ... @@ -122,6 +132,7 @@ class MainActivity : AppCompatActivity() {
audioRecord!!.release()
audioRecord = null
recordButton.setText(R.string.start)
stream.release()
Log.i(TAG, "Stopped recording")
}
}
... ... @@ -137,22 +148,22 @@ class MainActivity : AppCompatActivity() {
val ret = audioRecord?.read(buffer, 0, buffer.size)
if (ret != null && ret > 0) {
val samples = FloatArray(ret) { buffer[it] / 32768.0f }
model.acceptWaveform(samples, sampleRate=sampleRateInHz)
while (model.isReady()) {
model.decode()
stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
while (kws.isReady(stream)) {
kws.decode(stream)
}
val text = model.keyword
val text = kws.getResult(stream).keyword
var textToDisplay = lastText;
var textToDisplay = lastText
if(text.isNotBlank()) {
if (text.isNotBlank()) {
if (lastText.isBlank()) {
textToDisplay = "${idx}: ${text}"
textToDisplay = "$idx: $text"
} else {
textToDisplay = "${idx}: ${text}\n${lastText}"
textToDisplay = "$idx: $text\n$lastText"
}
lastText = "${idx}: ${text}\n${lastText}"
lastText = "$idx: $text\n$lastText"
idx += 1
}
... ... @@ -188,20 +199,21 @@ class MainActivity : AppCompatActivity() {
}
private fun initModel() {
// Please change getModelConfig() to add new models
// Please change getKwsModelConfig() to add new models
// See https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
// for a list of available models
val type = 0
Log.i(TAG, "Select model type ${type}")
Log.i(TAG, "Select model type $type")
val config = KeywordSpotterConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
modelConfig = getModelConfig(type = type)!!,
keywordsFile = getKeywordsFile(type = type)!!,
modelConfig = getKwsModelConfig(type = type)!!,
keywordsFile = getKeywordsFile(type = type),
)
model = SherpaOnnxKws(
kws = KeywordSpotter(
assetManager = application.assets,
config = config,
)
stream = kws.createStream()
}
}
}
\ No newline at end of file
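Gathered from the hunks above, the new keyword-spotting call pattern in one place (a sketch; `assets`, `config`, and the audio-capture loop producing `samples` are elided, and the example keyword is taken from the strings.xml change below):

val kws = KeywordSpotter(assetManager = assets, config = config)
var stream = kws.createStream()                 // uses the default keywords file
// To change keywords at runtime, release the old stream and create a new one;
// a native ptr of 0 signals that the keywords could not be applied.
stream.release()
stream = kws.createStream("n ǐ h ǎo @你好")
if (stream.ptr != 0L) {
    stream.acceptWaveform(samples, sampleRate = 16000)
    while (kws.isReady(stream)) {
        kws.decode(stream)
    }
    val keyword = kws.getResult(stream).keyword
}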
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineRecognizer.kt
\ No newline at end of file
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineStream.kt
\ No newline at end of file
... ...
// Copyright (c) 2023 Xiaomi Corporation
package com.k2fsa.sherpa.onnx

import android.content.res.AssetManager

class WaveReader {
    companion object {
        // Read a mono wave file asset
        // The returned array has two entries:
        //  - the first entry contains a 1-D float array
        //  - the second entry is the sample rate
        external fun readWaveFromAsset(
            assetManager: AssetManager,
            filename: String,
        ): Array<Any>

        // Read a mono wave file from disk
        // The returned array has two entries:
        //  - the first entry contains a 1-D float array
        //  - the second entry is the sample rate
        external fun readWaveFromFile(
            filename: String,
        ): Array<Any>

        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
<resources>
<string name="app_name">KWS with Next-gen Kaldi</string>
<string name="app_name">Keyword spotting</string>
<string name="hint">Click the Start button to play keyword spotting with Next-gen Kaldi.
\n
\n\n\n
The source code and pre-trained models are publicly available.
Please see https://github.com/k2-fsa/sherpa-onnx for details.
</string>
<string name="keyword_hint">Input your keywords here, one keyword perline.</string>
<string name="keyword_hint">Input your keywords here, one keyword per line.\nTwo example keywords are given below:\n\nn ǐ h ǎo @你好\nd àn g ē d àn g ē @蛋哥蛋哥</string>
<string name="start">Start</string>
<string name="stop">Stop</string>
</resources>
... ...
... ... @@ -2,7 +2,7 @@ package com.k2fsa.sherpa.onnx.speaker.identification
import androidx.compose.ui.graphics.vector.ImageVector
data class BarItem (
data class BarItem(
val title: String,
// see https://www.composables.com/icons
... ...
package com.k2fsa.sherpa.onnx.speaker.identification
sealed class NavRoutes(val route: String) {
object Home: NavRoutes("home")
object Register: NavRoutes("register")
object View: NavRoutes("view")
object Help: NavRoutes("help")
object Home : NavRoutes("home")
object Register : NavRoutes("register")
object View : NavRoutes("view")
object Help : NavRoutes("help")
}
\ No newline at end of file
... ...
../../../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineStream.kt
\ No newline at end of file
... ...
../../../../../../../../../../../../sherpa-onnx/kotlin-api/Speaker.kt
\ No newline at end of file
... ...
@file:OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
@file:OptIn(ExperimentalMaterial3Api::class)
package com.k2fsa.sherpa.onnx.slid
... ... @@ -9,11 +9,9 @@ import android.media.AudioFormat
import android.media.AudioRecord
import android.media.MediaRecorder
import android.util.Log
import androidx.compose.foundation.ExperimentalFoundationApi
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.ui.Modifier
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.height
... ... @@ -31,6 +29,7 @@ import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
... ... @@ -63,13 +62,13 @@ fun Home() {
}
private var audioRecord: AudioRecord? = null
private val sampleRateInHz = 16000
private const val sampleRateInHz = 16000
@Composable
fun MyApp(padding: PaddingValues) {
val activity = LocalContext.current as Activity
var isStarted by remember { mutableStateOf(false) }
var result by remember { mutableStateOf<String>("") }
var result by remember { mutableStateOf("") }
val onButtonClick: () -> Unit = {
isStarted = !isStarted
... ... @@ -114,12 +113,12 @@ fun MyApp(padding: PaddingValues) {
}
Log.i(TAG, "Stop recording")
Log.i(TAG, "Start recognition")
val samples = Flatten(sampleList)
val samples = flatten(sampleList)
val stream = Slid.slid.createStream()
stream.acceptWaveform(samples, sampleRateInHz)
val lang = Slid.slid.compute(stream)
result = Slid.localeMap.get(lang) ?: lang
result = Slid.localeMap[lang] ?: lang
stream.release()
}
... ... @@ -152,7 +151,7 @@ fun MyApp(padding: PaddingValues) {
}
}
fun Flatten(sampleList: ArrayList<FloatArray>): FloatArray {
fun flatten(sampleList: ArrayList<FloatArray>): FloatArray {
var totalSamples = 0
for (a in sampleList) {
totalSamples += a.size
... ...
... ... @@ -10,12 +10,9 @@ import androidx.activity.compose.setContent
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.tooling.preview.Preview
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.SpokenLanguageIdentification
import com.k2fsa.sherpa.onnx.slid.ui.theme.SherpaOnnxSpokenLanguageIdentificationTheme
const val TAG = "sherpa-onnx"
... ... @@ -32,6 +29,7 @@ class MainActivity : ComponentActivity() {
ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
Slid.initSlid(this.assets)
}
@Suppress("DEPRECATION")
@Deprecated("Deprecated in Java")
override fun onRequestPermissionsResult(
... ...
../../../../../../../../../../SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt
\ No newline at end of file
../../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt
\ No newline at end of file
... ...
../../../../../../../../../../../sherpa-onnx/kotlin-api/SpokenLanguageIdentification.kt
\ No newline at end of file
... ...
... ... @@ -15,10 +15,10 @@ object Slid {
get() {
return _slid!!
}
val localeMap : Map<String, String>
get() {
return _localeMap
}
val localeMap: Map<String, String>
get() {
return _localeMap
}
fun initSlid(assetManager: AssetManager? = null, numThreads: Int = 1) {
synchronized(this) {
... ... @@ -31,7 +31,7 @@ object Slid {
}
if (_localeMap.isEmpty()) {
val allLang = Locale.getISOLanguages();
val allLang = Locale.getISOLanguages()
for (lang in allLang) {
val locale = Locale(lang)
_localeMap[lang] = locale.displayName
... ...
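Gathered from the spoken-language-identification hunks above, the per-utterance flow in one place (a sketch; Slid.initSlid(assets) must have run first, and `samples` is assumed to be a FloatArray of 16 kHz audio):

val stream = Slid.slid.createStream()
stream.acceptWaveform(samples, sampleRateInHz)
val lang = Slid.slid.compute(stream)          // ISO language code, e.g., "en"
val displayName = Slid.localeMap[lang] ?: lang
stream.release()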
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
import android.media.*
import android.media.AudioAttributes
import android.media.AudioFormat
import android.media.AudioManager
import android.media.AudioTrack
import android.media.MediaPlayer
import android.net.Uri
import android.os.Bundle
import android.util.Log
... ... @@ -212,7 +216,7 @@ class MainActivity : AppCompatActivity() {
}
if (dictDir != null) {
val newDir = copyDataDir( modelDir!!)
val newDir = copyDataDir(modelDir!!)
modelDir = newDir + "/" + modelDir
dictDir = modelDir + "/" + "dict"
ruleFsts = "$modelDir/phone.fst,$modelDir/date.fst,$modelDir/number.fst"
... ... @@ -220,7 +224,9 @@ class MainActivity : AppCompatActivity() {
}
val config = getOfflineTtsConfig(
modelDir = modelDir!!, modelName = modelName!!, lexicon = lexicon ?: "",
modelDir = modelDir!!,
modelName = modelName!!,
lexicon = lexicon ?: "",
dataDir = dataDir ?: "",
dictDir = dictDir ?: "",
ruleFsts = ruleFsts ?: "",
... ... @@ -232,11 +238,11 @@ class MainActivity : AppCompatActivity() {
private fun copyDataDir(dataDir: String): String {
println("data dir is $dataDir")
Log.i(TAG, "data dir is $dataDir")
copyAssets(dataDir)
val newDataDir = application.getExternalFilesDir(null)!!.absolutePath
println("newDataDir: $newDataDir")
Log.i(TAG, "newDataDir: $newDataDir")
return newDataDir
}
... ... @@ -256,7 +262,7 @@ class MainActivity : AppCompatActivity() {
}
}
} catch (ex: IOException) {
Log.e(TAG, "Failed to copy $path. ${ex.toString()}")
Log.e(TAG, "Failed to copy $path. $ex")
}
}
... ... @@ -276,7 +282,7 @@ class MainActivity : AppCompatActivity() {
ostream.flush()
ostream.close()
} catch (ex: Exception) {
Log.e(TAG, "Failed to copy $filename, ${ex.toString()}")
Log.e(TAG, "Failed to copy $filename, $ex")
}
}
}
... ...
... ... @@ -49,10 +49,10 @@ class OfflineTts(
private var ptr: Long
init {
if (assetManager != null) {
ptr = newFromAsset(assetManager, config)
ptr = if (assetManager != null) {
newFromAsset(assetManager, config)
} else {
ptr = newFromFile(config)
newFromFile(config)
}
}
... ... @@ -65,7 +65,7 @@ class OfflineTts(
sid: Int = 0,
speed: Float = 1.0f
): GeneratedAudio {
var objArray = generateImpl(ptr, text = text, sid = sid, speed = speed)
val objArray = generateImpl(ptr, text = text, sid = sid, speed = speed)
return GeneratedAudio(
samples = objArray[0] as FloatArray,
sampleRate = objArray[1] as Int
... ... @@ -78,7 +78,13 @@ class OfflineTts(
speed: Float = 1.0f,
callback: (samples: FloatArray) -> Unit
): GeneratedAudio {
var objArray = generateWithCallbackImpl(ptr, text = text, sid = sid, speed = speed, callback=callback)
val objArray = generateWithCallbackImpl(
ptr,
text = text,
sid = sid,
speed = speed,
callback = callback
)
return GeneratedAudio(
samples = objArray[0] as FloatArray,
sampleRate = objArray[1] as Int
... ... @@ -87,10 +93,10 @@ class OfflineTts(
fun allocate(assetManager: AssetManager? = null) {
if (ptr == 0L) {
if (assetManager != null) {
ptr = newFromAsset(assetManager, config)
ptr = if (assetManager != null) {
newFromAsset(assetManager, config)
} else {
ptr = newFromFile(config)
newFromFile(config)
}
}
}
... ... @@ -103,9 +109,14 @@ class OfflineTts(
}
protected fun finalize() {
delete(ptr)
if (ptr != 0L) {
delete(ptr)
ptr = 0
}
}
fun release() = finalize()
private external fun newFromAsset(
assetManager: AssetManager,
config: OfflineTtsConfig,
... ... @@ -123,14 +134,14 @@ class OfflineTts(
// - the first entry is an 1-D float array containing audio samples.
// Each sample is normalized to the range [-1, 1]
// - the second entry is the sample rate
external fun generateImpl(
private external fun generateImpl(
ptr: Long,
text: String,
sid: Int = 0,
speed: Float = 1.0f
): Array<Any>
external fun generateWithCallbackImpl(
private external fun generateWithCallbackImpl(
ptr: Long,
text: String,
sid: Int = 0,
... ... @@ -156,7 +167,7 @@ fun getOfflineTtsConfig(
dictDir: String,
ruleFsts: String,
ruleFars: String
): OfflineTtsConfig? {
): OfflineTtsConfig {
return OfflineTtsConfig(
model = OfflineTtsModelConfig(
vits = OfflineTtsVitsModelConfig(
... ...
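For context on the OfflineTts changes above, here is the generate/save/release cycle as the demo activity uses it (a sketch assembled from this commit's hunks; `assets` and `config` come from getOfflineTtsConfig, and the output path is illustrative):

val tts = OfflineTts(assetManager = assets, config = config)
val audio = tts.generate(text = "Hello", sid = 0, speed = 1.0f)
// audio.samples are floats in [-1, 1]; audio.sampleRate gives the rate.
val ok = audio.samples.isNotEmpty() && audio.save("/tmp/generated.wav")
tts.release()  // frees the native pointer; finalize() now guards against double-free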
package com.k2fsa.sherpa.onnx.tts.engine
import android.content.Intent
import androidx.appcompat.app.AppCompatActivity
import android.os.Bundle
import android.speech.tts.TextToSpeech
import androidx.appcompat.app.AppCompatActivity
class CheckVoiceData : AppCompatActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
val intent = Intent().apply {
putStringArrayListExtra(TextToSpeech.Engine.EXTRA_AVAILABLE_VOICES, arrayListOf(TtsEngine.lang))
putStringArrayListExtra(
TextToSpeech.Engine.EXTRA_AVAILABLE_VOICES,
arrayListOf(TtsEngine.lang)
)
putStringArrayListExtra(TextToSpeech.Engine.EXTRA_UNAVAILABLE_VOICES, arrayListOf())
}
setResult(TextToSpeech.Engine.CHECK_VOICE_DATA_PASS, intent)
... ...
... ... @@ -2,7 +2,6 @@ package com.k2fsa.sherpa.onnx.tts.engine
import android.app.Activity
import android.content.Intent
import androidx.appcompat.app.AppCompatActivity
import android.os.Bundle
import android.speech.tts.TextToSpeech
... ... @@ -12,120 +11,168 @@ fun getSampleText(lang: String): String {
"ara" -> {
text = "هذا هو محرك تحويل النص إلى كلام باستخدام الجيل القادم من كالدي"
}
"ben" -> {
text = "এটি একটি টেক্সট-টু-স্পীচ ইঞ্জিন যা পরবর্তী প্রজন্মের কালডি ব্যবহার করে"
}
"bul" -> {
text = "Това е машина за преобразуване на текст в реч, използваща Kaldi от следващо поколение"
text =
"Това е машина за преобразуване на текст в реч, използваща Kaldi от следващо поколение"
}
"cat" -> {
text = "Aquest és un motor de text a veu que utilitza Kaldi de nova generació"
}
"ces" -> {
text = "Toto je převodník textu na řeč využívající novou generaci kaldi"
}
"dan" -> {
text = "Dette er en tekst til tale-motor, der bruger næste generation af kaldi"
}
"deu" -> {
text = "Dies ist eine Text-to-Speech-Engine, die Kaldi der nächsten Generation verwendet"
text =
"Dies ist eine Text-to-Speech-Engine, die Kaldi der nächsten Generation verwendet"
}
"ell" -> {
text = "Αυτή είναι μια μηχανή κειμένου σε ομιλία που χρησιμοποιεί kaldi επόμενης γενιάς"
}
"eng" -> {
text = "This is a text-to-speech engine using next generation Kaldi"
}
"est" -> {
text = "See on teksti kõneks muutmise mootor, mis kasutab järgmise põlvkonna Kaldi"
}
"fin" -> {
text = "Tämä on tekstistä puheeksi -moottori, joka käyttää seuraavan sukupolven kaldia"
}
"fra" -> {
text = "Il s'agit d'un moteur de synthèse vocale utilisant Kaldi de nouvelle génération"
}
"gle" -> {
text = "Is inneall téacs-go-hurlabhra é seo a úsáideann Kaldi den chéad ghlúin eile"
}
"hrv" -> {
text = "Ovo je mehanizam za pretvaranje teksta u govor koji koristi Kaldi sljedeće generacije"
text =
"Ovo je mehanizam za pretvaranje teksta u govor koji koristi Kaldi sljedeće generacije"
}
"hun" -> {
text = "Ez egy szövegfelolvasó motor a következő generációs kaldi használatával"
}
"isl" -> {
text = "Þetta er texta í tal vél sem notar næstu kynslóð kaldi"
}
"ita" -> {
text = "Questo è un motore di sintesi vocale che utilizza kaldi di nuova generazione"
}
"kat" -> {
text = "ეს არის ტექსტიდან მეტყველების ძრავა შემდეგი თაობის კალდის გამოყენებით"
}
"kaz" -> {
text = "Бұл келесі буын kaldi көмегімен мәтіннен сөйлеуге арналған қозғалтқыш"
}
"mlt" -> {
text = "Din hija magna text-to-speech li tuża Kaldi tal-ġenerazzjoni li jmiss"
}
"lav" -> {
text = "Šis ir teksta pārvēršanas runā dzinējs, kas izmanto nākamās paaudzes Kaldi"
}
"lit" -> {
text = "Tai teksto į kalbą variklis, kuriame naudojamas naujos kartos Kaldi"
}
"ltz" -> {
text = "Dëst ass en Text-zu-Speech-Motor mat der nächster Generatioun Kaldi"
}
"nep" -> {
text = "यो अर्को पुस्ता काल्डी प्रयोग गरेर स्पीच इन्जिनको पाठ हो"
}
"nld" -> {
text = "Dit is een tekst-naar-spraak-engine die gebruik maakt van Kaldi van de volgende generatie"
text =
"Dit is een tekst-naar-spraak-engine die gebruik maakt van Kaldi van de volgende generatie"
}
"nor" -> {
text = "Dette er en tekst til tale-motor som bruker neste generasjons kaldi"
}
"pol" -> {
text = "Jest to silnik syntezatora mowy wykorzystujący Kaldi nowej generacji"
}
"por" -> {
text = "Este é um mecanismo de conversão de texto em fala usando Kaldi de próxima geração"
text =
"Este é um mecanismo de conversão de texto em fala usando Kaldi de próxima geração"
}
"ron" -> {
text = "Acesta este un motor text to speech care folosește generația următoare de kadi"
}
"rus" -> {
text = "Это движок преобразования текста в речь, использующий Kaldi следующего поколения."
text =
"Это движок преобразования текста в речь, использующий Kaldi следующего поколения."
}
"slk" -> {
text = "Toto je nástroj na prevod textu na reč využívajúci kaldi novej generácie"
}
"slv" -> {
text = "To je mehanizem za pretvorbo besedila v govor, ki uporablja Kaldi naslednje generacije"
text =
"To je mehanizem za pretvorbo besedila v govor, ki uporablja Kaldi naslednje generacije"
}
"spa" -> {
text = "Este es un motor de texto a voz que utiliza kaldi de próxima generación."
}
"srp" -> {
text = "Ово је механизам за претварање текста у говор који користи калди следеће генерације"
text =
"Ово је механизам за претварање текста у говор који користи калди следеће генерације"
}
"swa" -> {
text = "Haya ni maandishi kwa injini ya hotuba kwa kutumia kizazi kijacho kaldi"
}
"swe" -> {
text = "Detta är en text till tal-motor som använder nästa generations kaldi"
}
"tur" -> {
text = "Bu, yeni nesil kaldi'yi kullanan bir metinden konuşmaya motorudur"
}
"ukr" -> {
text = "Це механізм перетворення тексту на мовлення, який використовує kaldi нового покоління"
text =
"Це механізм перетворення тексту на мовлення, який використовує kaldi нового покоління"
}
"vie" -> {
text = "Đây là công cụ chuyển văn bản thành giọng nói sử dụng kaldi thế hệ tiếp theo"
}
"zho", "cmn" -> {
text = "使用新一代卡尔迪的语音合成引擎"
}
... ... @@ -137,13 +184,13 @@ class GetSampleText : Activity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
var result = TextToSpeech.LANG_AVAILABLE
var text: String = getSampleText(TtsEngine.lang ?: "")
val text: String = getSampleText(TtsEngine.lang ?: "")
if (text.isEmpty()) {
result = TextToSpeech.LANG_NOT_SUPPORTED
}
val intent = Intent().apply{
if(result == TextToSpeech.LANG_AVAILABLE) {
val intent = Intent().apply {
if (result == TextToSpeech.LANG_AVAILABLE) {
putExtra(TextToSpeech.Engine.EXTRA_SAMPLE_TEXT, text)
} else {
putExtra("sampleText", text)
... ...
... ... @@ -26,20 +26,16 @@ import androidx.compose.material3.Scaffold
import androidx.compose.material3.Slider
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.material3.TextField
import androidx.compose.material3.TopAppBar
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.k2fsa.sherpa.onnx.tts.engine.ui.theme.SherpaOnnxTtsEngineTheme
import java.io.File
import java.lang.NumberFormatException
const val TAG = "sherpa-onnx-tts-engine"
... ... @@ -76,7 +72,7 @@ class MainActivity : ComponentActivity() {
val testTextContent = getSampleText(TtsEngine.lang ?: "")
var testText by remember { mutableStateOf(testTextContent) }
val numSpeakers = TtsEngine.tts!!.numSpeakers()
if (numSpeakers > 1) {
OutlinedTextField(
... ... @@ -88,7 +84,7 @@ class MainActivity : ComponentActivity() {
try {
TtsEngine.speakerId = it.toString().toInt()
} catch (ex: NumberFormatException) {
Log.i(TAG, "Invalid input: ${it}")
Log.i(TAG, "Invalid input: $it")
TtsEngine.speakerId = 0
}
}
... ... @@ -119,7 +115,7 @@ class MainActivity : ComponentActivity() {
Button(
modifier = Modifier.padding(20.dp),
onClick = {
Log.i(TAG, "Clicked, text: ${testText}")
Log.i(TAG, "Clicked, text: $testText")
if (testText.isBlank() || testText.isEmpty()) {
Toast.makeText(
applicationContext,
... ... @@ -136,7 +132,7 @@ class MainActivity : ComponentActivity() {
val filename =
application.filesDir.absolutePath + "/generated.wav"
val ok =
audio.samples.size > 0 && audio.save(filename)
audio.samples.isNotEmpty() && audio.save(filename)
if (ok) {
stopMediaPlayer()
... ...
... ... @@ -4,8 +4,10 @@ import android.content.Context
import android.content.res.AssetManager
import android.util.Log
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import com.k2fsa.sherpa.onnx.*
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableIntStateOf
import com.k2fsa.sherpa.onnx.OfflineTts
import com.k2fsa.sherpa.onnx.getOfflineTtsConfig
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
... ... @@ -21,8 +23,8 @@ object TtsEngine {
var lang: String? = null
val speedState: MutableState<Float> = mutableStateOf(1.0F)
val speakerIdState: MutableState<Int> = mutableStateOf(0)
val speedState: MutableState<Float> = mutableFloatStateOf(1.0F)
val speakerIdState: MutableState<Int> = mutableIntStateOf(0)
var speed: Float
get() = speedState.value
... ... @@ -113,15 +115,15 @@ object TtsEngine {
if (dataDir != null) {
val newDir = copyDataDir(context, modelDir!!)
modelDir = newDir + "/" + modelDir
dataDir = newDir + "/" + dataDir
modelDir = "$newDir/$modelDir"
dataDir = "$newDir/$dataDir"
assets = null
}
if (dictDir != null) {
val newDir = copyDataDir(context, modelDir!!)
modelDir = newDir + "/" + modelDir
dictDir = modelDir + "/" + "dict"
modelDir = "$newDir/$modelDir"
dictDir = "$modelDir/dict"
ruleFsts = "$modelDir/phone.fst,$modelDir/date.fst,$modelDir/number.fst"
assets = null
}
... ... @@ -132,18 +134,18 @@ object TtsEngine {
dictDir = dictDir ?: "",
ruleFsts = ruleFsts ?: "",
ruleFars = ruleFars ?: ""
)!!
)
tts = OfflineTts(assetManager = assets, config = config)
}
private fun copyDataDir(context: Context, dataDir: String): String {
println("data dir is $dataDir")
Log.i(TAG, "data dir is $dataDir")
copyAssets(context, dataDir)
val newDataDir = context.getExternalFilesDir(null)!!.absolutePath
println("newDataDir: $newDataDir")
Log.i(TAG, "newDataDir: $newDataDir")
return newDataDir
}
... ... @@ -158,12 +160,12 @@ object TtsEngine {
val dir = File(fullPath)
dir.mkdirs()
for (asset in assets.iterator()) {
val p: String = if (path == "") "" else path + "/"
val p: String = if (path == "") "" else "$path/"
copyAssets(context, p + asset)
}
}
} catch (ex: IOException) {
Log.e(TAG, "Failed to copy $path. ${ex.toString()}")
Log.e(TAG, "Failed to copy $path. $ex")
}
}
... ... @@ -183,7 +185,7 @@ object TtsEngine {
ostream.flush()
ostream.close()
} catch (ex: Exception) {
Log.e(TAG, "Failed to copy $filename, ${ex.toString()}")
Log.e(TAG, "Failed to copy $filename, $ex")
}
}
}
... ...
... ... @@ -6,7 +6,6 @@ import android.speech.tts.SynthesisRequest
import android.speech.tts.TextToSpeech
import android.speech.tts.TextToSpeechService
import android.util.Log
import com.k2fsa.sherpa.onnx.*
/*
https://developer.android.com/reference/java/util/Locale#getISO3Language()
... ...
package com.k2fsa.sherpa.onnx.tts.engine
import android.app.Application
import android.os.FileUtils.ProgressListener
import android.speech.tts.TextToSpeech
import android.speech.tts.TextToSpeech.OnInitListener
import android.speech.tts.UtteranceProgressListener
... ... @@ -27,7 +26,7 @@ class TtsViewModel : ViewModel() {
private val onInitListener = object : OnInitListener {
override fun onInit(status: Int) {
when (status) {
TextToSpeech.SUCCESS -> Log.i(TAG, "Init tts succeded")
TextToSpeech.SUCCESS -> Log.i(TAG, "Init tts succeeded")
TextToSpeech.ERROR -> Log.i(TAG, "Init tts failed")
else -> Log.i(TAG, "Unknown status $status")
}
... ...
... ... @@ -15,7 +15,7 @@
android:theme="@style/Theme.SherpaOnnxVad"
tools:targetApi="31">
<activity
android:name=".MainActivity"
android:name="com.k2fsa.sherpa.onnx.vad.MainActivity"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
... ...
package com.k2fsa.sherpa.onnx
package com.k2fsa.sherpa.onnx.vad
import android.Manifest
import android.content.pm.PackageManager
... ... @@ -11,6 +11,9 @@ import android.view.View
import android.widget.Button
import androidx.appcompat.app.AppCompatActivity
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.R
import com.k2fsa.sherpa.onnx.Vad
import com.k2fsa.sherpa.onnx.getVadModelConfig
import kotlin.concurrent.thread
... ... @@ -116,7 +119,7 @@ class MainActivity : AppCompatActivity() {
private fun initVadModel() {
val type = 0
println("Select VAD model type ${type}")
Log.i(TAG, "Select VAD model type ${type}")
val config = getVadModelConfig(type)
vad = Vad(
... ... @@ -171,4 +174,4 @@ class MainActivity : AppCompatActivity() {
}
}
}
}
\ No newline at end of file
}
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/Vad.kt
\ No newline at end of file
... ...
... ... @@ -4,7 +4,7 @@
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
tools:context="com.k2fsa.sherpa.onnx.vad.MainActivity">
<LinearLayout
android:layout_width="match_parent"
android:layout_height="match_parent"
... ... @@ -40,4 +40,4 @@
</androidx.constraintlayout.widget.ConstraintLayout>
\ No newline at end of file
</androidx.constraintlayout.widget.ConstraintLayout>
... ...
... ... @@ -15,7 +15,7 @@
android:theme="@style/Theme.SherpaOnnxVadAsr"
tools:targetApi="31">
<activity
android:name=".MainActivity"
android:name=".vad.asr.MainActivity"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx
package com.k2fsa.sherpa.onnx.vad.asr
import android.Manifest
import android.content.pm.PackageManager
... ... @@ -13,6 +13,13 @@ import android.widget.Button
import android.widget.TextView
import androidx.appcompat.app.AppCompatActivity
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.OfflineRecognizer
import com.k2fsa.sherpa.onnx.OfflineRecognizerConfig
import com.k2fsa.sherpa.onnx.R
import com.k2fsa.sherpa.onnx.Vad
import com.k2fsa.sherpa.onnx.getFeatureConfig
import com.k2fsa.sherpa.onnx.getOfflineModelConfig
import com.k2fsa.sherpa.onnx.getVadModelConfig
import kotlin.concurrent.thread
... ... @@ -40,7 +47,7 @@ class MainActivity : AppCompatActivity() {
private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
// Non-streaming ASR
private lateinit var offlineRecognizer: SherpaOnnxOffline
private lateinit var offlineRecognizer: OfflineRecognizer
private var idx: Int = 0
private var lastText: String = ""
... ... @@ -122,7 +129,7 @@ class MainActivity : AppCompatActivity() {
private fun initVadModel() {
val type = 0
println("Select VAD model type ${type}")
Log.i(TAG, "Select VAD model type ${type}")
val config = getVadModelConfig(type)
vad = Vad(
... ... @@ -194,20 +201,25 @@ class MainActivity : AppCompatActivity() {
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// for a list of available models
val secondType = 0
println("Select model type ${secondType} for the second pass")
Log.i(TAG, "Select model type ${secondType} for the second pass")
val config = OfflineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
modelConfig = getOfflineModelConfig(type = secondType)!!,
)
offlineRecognizer = SherpaOnnxOffline(
offlineRecognizer = OfflineRecognizer(
assetManager = application.assets,
config = config,
)
}
private fun runSecondPass(samples: FloatArray): String {
return offlineRecognizer.decode(samples, sampleRateInHz)
val stream = offlineRecognizer.createStream()
stream.acceptWaveform(samples, sampleRateInHz)
offlineRecognizer.decode(stream)
val result = offlineRecognizer.getResult(stream)
stream.release()
return result.text
}
}
\ No newline at end of file
}
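The second pass drives the offline recognizer through an explicit stream lifecycle: create a stream, feed it the samples, decode, read the result, and release the native stream. A minimal sketch of that flow, using the same names as in runSecondPass() above:
val stream = offlineRecognizer.createStream()
stream.acceptWaveform(samples, sampleRateInHz)
offlineRecognizer.decode(stream)
val text = offlineRecognizer.getResult(stream).text
stream.release() // native streams must be released explicitly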
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineRecognizer.kt
\ No newline at end of file
... ...
../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt
\ No newline at end of file
... ...
../../../../../../../../../SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt
\ No newline at end of file
../../../../../../../../../SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt
\ No newline at end of file
../../../../../../../../../../sherpa-onnx/kotlin-api/Vad.kt
\ No newline at end of file
... ...
... ... @@ -4,7 +4,7 @@
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
tools:context=".vad.asr.MainActivity">
<LinearLayout
android:layout_width="match_parent"
... ...
<resources>
<string name="app_name">VAD-ASR</string>
<string name="app_name">VAD+ASR</string>
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
\n
\n\n\n
... ...
... ... @@ -59,7 +59,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/
echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR"
echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR"
if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then
SHERPA_ONNX_ENABLE_TTS=ON
fi
if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then
SHERPA_ONNX_ENABLE_BINARY=OFF
fi
cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
-DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \
-DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \
-DBUILD_PIPER_PHONMIZE_EXE=OFF \
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \
-DBUILD_ESPEAK_NG_EXE=OFF \
... ...
... ... @@ -60,7 +60,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/
echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR"
echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR"
if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then
SHERPA_ONNX_ENABLE_TTS=ON
fi
if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then
SHERPA_ONNX_ENABLE_BINARY=OFF
fi
cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
-DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \
-DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \
-DBUILD_PIPER_PHONMIZE_EXE=OFF \
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \
-DBUILD_ESPEAK_NG_EXE=OFF \
... ...
... ... @@ -60,7 +60,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/
echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR"
echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR"
if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then
SHERPA_ONNX_ENABLE_TTS=ON
fi
if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then
SHERPA_ONNX_ENABLE_BINARY=OFF
fi
cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
-DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \
-DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \
-DBUILD_PIPER_PHONMIZE_EXE=OFF \
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \
-DBUILD_ESPEAK_NG_EXE=OFF \
... ...
... ... @@ -60,7 +60,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/
echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR"
echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR"
if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then
SHERPA_ONNX_ENABLE_TTS=ON
fi
if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then
SHERPA_ONNX_ENABLE_BINARY=OFF
fi
cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
-DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \
-DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \
-DBUILD_PIPER_PHONMIZE_EXE=OFF \
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \
-DBUILD_ESPEAK_NG_EXE=OFF \
... ...
../android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/AudioTagging.kt
\ No newline at end of file
../sherpa-onnx/kotlin-api/AudioTagging.kt
\ No newline at end of file
... ...
../sherpa-onnx/kotlin-api/FeatureConfig.kt
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
fun callback(samples: FloatArray): Unit {
println("callback got called with ${samples.size} samples");
}
fun main() {
testSpokenLanguageIdentification()
testAudioTagging()
testSpeakerRecognition()
testTts()
testAsr("transducer")
testAsr("zipformer2-ctc")
}
fun testSpokenLanguageIdentification() {
val config = SpokenLanguageIdentificationConfig(
whisper = SpokenLanguageIdentificationWhisperConfig(
encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",
decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",
tailPaddings = 33,
),
numThreads=1,
debug=true,
provider="cpu",
)
val slid = SpokenLanguageIdentification(assetManager=null, config=config)
val testFiles = arrayOf(
"./spoken-language-identification-test-wavs/ar-arabic.wav",
"./spoken-language-identification-test-wavs/bg-bulgarian.wav",
"./spoken-language-identification-test-wavs/de-german.wav",
)
for (waveFilename in testFiles) {
val objArray = WaveReader.readWaveFromFile(
filename = waveFilename,
)
val samples: FloatArray = objArray[0] as FloatArray
val sampleRate: Int = objArray[1] as Int
val stream = slid.createStream()
stream.acceptWaveform(samples, sampleRate = sampleRate)
val lang = slid.compute(stream)
stream.release()
println(waveFilename)
println(lang)
}
}
fun testAudioTagging() {
val config = AudioTaggingConfig(
model=AudioTaggingModelConfig(
zipformer=OfflineZipformerAudioTaggingModelConfig(
model="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx",
),
numThreads=1,
debug=true,
provider="cpu",
),
labels="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv",
topK=5,
)
val tagger = AudioTagging(assetManager=null, config=config)
val testFiles = arrayOf(
"./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav",
"./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/2.wav",
"./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/3.wav",
"./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/4.wav",
)
println("----------")
for (waveFilename in testFiles) {
val stream = tagger.createStream()
val objArray = WaveReader.readWaveFromFile(
filename = waveFilename,
)
val samples: FloatArray = objArray[0] as FloatArray
val sampleRate: Int = objArray[1] as Int
stream.acceptWaveform(samples, sampleRate = sampleRate)
val events = tagger.compute(stream)
stream.release()
println(waveFilename)
println(events)
println("----------")
}
tagger.release()
}
fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray {
var objArray = WaveReader.readWaveFromFile(
filename = filename,
)
var samples: FloatArray = objArray[0] as FloatArray
var sampleRate: Int = objArray[1] as Int
val stream = extractor.createStream()
stream.acceptWaveform(sampleRate = sampleRate, samples=samples)
stream.inputFinished()
check(extractor.isReady(stream))
val embedding = extractor.compute(stream)
stream.release()
return embedding
}
fun testSpeakerRecognition() {
val config = SpeakerEmbeddingExtractorConfig(
model="./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx",
)
val extractor = SpeakerEmbeddingExtractor(config = config)
val embedding1a = computeEmbedding(extractor, "./speaker1_a_cn_16k.wav")
val embedding2a = computeEmbedding(extractor, "./speaker2_a_cn_16k.wav")
val embedding1b = computeEmbedding(extractor, "./speaker1_b_cn_16k.wav")
var manager = SpeakerEmbeddingManager(extractor.dim())
var ok = manager.add(name = "speaker1", embedding=embedding1a)
check(ok)
manager.add(name = "speaker2", embedding=embedding2a)
check(ok)
var name = manager.search(embedding=embedding1b, threshold=0.5f)
check(name == "speaker1")
manager.release()
manager = SpeakerEmbeddingManager(extractor.dim())
val embeddingList = mutableListOf(embedding1a, embedding1b)
ok = manager.add(name = "s1", embedding=embeddingList.toTypedArray())
check(ok)
name = manager.search(embedding=embedding1b, threshold=0.5f)
check(name == "s1")
name = manager.search(embedding=embedding2a, threshold=0.5f)
check(name.isEmpty())
manager.release()
}
fun testTts() {
// see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
// https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
var config = OfflineTtsConfig(
model=OfflineTtsModelConfig(
vits=OfflineTtsVitsModelConfig(
model="./vits-piper-en_US-amy-low/en_US-amy-low.onnx",
tokens="./vits-piper-en_US-amy-low/tokens.txt",
dataDir="./vits-piper-en_US-amy-low/espeak-ng-data",
),
numThreads=1,
debug=true,
)
)
val tts = OfflineTts(config=config)
val audio = tts.generateWithCallback(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”", callback=::callback)
audio.save(filename="test-en.wav")
}
fun testAsr(type: String) {
var featConfig = FeatureConfig(
sampleRate = 16000,
featureDim = 80,
)
var waveFilename: String
var modelConfig: OnlineModelConfig = when (type) {
"transducer" -> {
waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav"
// please refer to
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// to download pre-trained models
OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
),
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
numThreads = 1,
debug = false,
)
}
"zipformer2-ctc" -> {
waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"
OnlineModelConfig(
zipformer2Ctc = OnlineZipformer2CtcModelConfig(
model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx",
),
tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt",
numThreads = 1,
debug = false,
)
}
else -> throw IllegalArgumentException(type)
}
var endpointConfig = EndpointConfig()
var lmConfig = OnlineLMConfig()
var config = OnlineRecognizerConfig(
modelConfig = modelConfig,
lmConfig = lmConfig,
featConfig = featConfig,
endpointConfig = endpointConfig,
enableEndpoint = true,
decodingMethod = "greedy_search",
maxActivePaths = 4,
)
var model = SherpaOnnx(
config = config,
)
var objArray = WaveReader.readWaveFromFile(
filename = waveFilename,
)
var samples: FloatArray = objArray[0] as FloatArray
var sampleRate: Int = objArray[1] as Int
model.acceptWaveform(samples, sampleRate = sampleRate)
while (model.isReady()) {
model.decode()
}
var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
model.acceptWaveform(tailPaddings, sampleRate = sampleRate)
model.inputFinished()
while (model.isReady()) {
model.decode()
}
println("results: ${model.text}")
}
../sherpa-onnx/kotlin-api/OfflineRecognizer.kt
\ No newline at end of file
... ...
../android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt
\ No newline at end of file
../sherpa-onnx/kotlin-api/OfflineStream.kt
\ No newline at end of file
... ...
../sherpa-onnx/kotlin-api/OnlineRecognizer.kt
\ No newline at end of file
... ...
../sherpa-onnx/kotlin-api/OnlineStream.kt
\ No newline at end of file
... ...
../android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt
\ No newline at end of file
../android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt
\ No newline at end of file
../android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt
\ No newline at end of file
../sherpa-onnx/kotlin-api/Speaker.kt
\ No newline at end of file
... ...
../android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/SpokenLanguageIdentification.kt
\ No newline at end of file
../sherpa-onnx/kotlin-api/SpokenLanguageIdentification.kt
\ No newline at end of file
... ...
../android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt
\ No newline at end of file
../sherpa-onnx/kotlin-api/Vad.kt
\ No newline at end of file
... ...
../android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt
\ No newline at end of file
../sherpa-onnx/kotlin-api/WaveReader.kt
\ No newline at end of file
... ...
... ... @@ -44,9 +44,23 @@ function testSpeakerEmbeddingExtractor() {
if [ ! -f ./speaker2_a_cn_16k.wav ]; then
curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
fi
out_filename=test_speaker_id.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_speaker_id.kt \
OnlineStream.kt \
Speaker.kt \
WaveReader.kt \
faked-asset-manager.kt \
faked-log.kt
ls -lh $out_filename
java -Djava.library.path=../build/lib -jar $out_filename
}
function testAsr() {
function testOnlineAsr() {
if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
git lfs install
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
... ... @@ -57,6 +71,20 @@ function testAsr() {
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi
out_filename=test_online_asr.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_online_asr.kt \
FeatureConfig.kt \
OnlineRecognizer.kt \
OnlineStream.kt \
WaveReader.kt \
faked-asset-manager.kt \
faked-log.kt
ls -lh $out_filename
java -Djava.library.path=../build/lib -jar $out_filename
}
function testTts() {
... ... @@ -65,16 +93,42 @@ function testTts() {
tar xf vits-piper-en_US-amy-low.tar.bz2
rm vits-piper-en_US-amy-low.tar.bz2
fi
out_filename=test_tts.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_tts.kt \
Tts.kt \
faked-asset-manager.kt \
faked-log.kt
ls -lh $out_filename
java -Djava.library.path=../build/lib -jar $out_filename
}
function testAudioTagging() {
if [ ! -d sherpa-onnx-zipformer-audio-tagging-2024-04-09 ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
fi
out_filename=test_audio_tagging.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_audio_tagging.kt \
AudioTagging.kt \
OfflineStream.kt \
WaveReader.kt \
faked-asset-manager.kt \
faked-log.kt
ls -lh $out_filename
java -Djava.library.path=../build/lib -jar $out_filename
}
function testSpokenLanguageIdentification() {
if [ ! -f ./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
... ... @@ -87,50 +141,44 @@ function testSpokenLanguageIdentification() {
tar xvf spoken-language-identification-test-wavs.tar.bz2
rm spoken-language-identification-test-wavs.tar.bz2
fi
}
function test() {
testSpokenLanguageIdentification
testAudioTagging
testSpeakerEmbeddingExtractor
testAsr
testTts
}
out_filename=test_language_id.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_language_id.kt \
SpokenLanguageIdentification.kt \
OfflineStream.kt \
WaveReader.kt \
faked-asset-manager.kt \
faked-log.kt
test
kotlinc-jvm -include-runtime -d main.jar \
AudioTagging.kt \
Main.kt \
OfflineStream.kt \
SherpaOnnx.kt \
Speaker.kt \
SpokenLanguageIdentification.kt \
Tts.kt \
WaveReader.kt \
faked-asset-manager.kt \
faked-log.kt
ls -lh main.jar
java -Djava.library.path=../build/lib -jar main.jar
function testTwoPass() {
if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
fi
ls -lh $out_filename
java -Djava.library.path=../build/lib -jar $out_filename
}
function testOfflineAsr() {
if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
rm sherpa-onnx-whisper-tiny.en.tar.bz2
fi
kotlinc-jvm -include-runtime -d 2pass.jar test-2pass.kt WaveReader.kt SherpaOnnx2Pass.kt faked-asset-manager.kt
ls -lh 2pass.jar
java -Djava.library.path=../build/lib -jar 2pass.jar
out_filename=test_offline_asr.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_offline_asr.kt \
FeatureConfig.kt \
OfflineRecognizer.kt \
OfflineStream.kt \
WaveReader.kt \
faked-asset-manager.kt
ls -lh $out_filename
java -Djava.library.path=../build/lib -jar $out_filename
}
testTwoPass
testSpeakerEmbeddingExtractor
testOnlineAsr
testTts
testAudioTagging
testSpokenLanguageIdentification
testOfflineAsr
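Each test above is compiled into a self-contained jar via kotlinc-jvm -include-runtime, and -Djava.library.path points the JVM at ../build/lib so it can locate the sherpa-onnx JNI shared library at runtime.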
... ...
package com.k2fsa.sherpa.onnx
fun main() {
test2Pass()
}
fun test2Pass() {
val firstPass = createFirstPass()
val secondPass = createSecondPass()
val waveFilename = "./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav"
var objArray = WaveReader.readWaveFromFile(
filename = waveFilename,
)
var samples: FloatArray = objArray[0] as FloatArray
var sampleRate: Int = objArray[1] as Int
firstPass.acceptWaveform(samples, sampleRate = sampleRate)
while (firstPass.isReady()) {
firstPass.decode()
}
var text = firstPass.text
println("First pass text: $text")
text = secondPass.decode(samples, sampleRate)
println("Second pass text: $text")
}
fun createFirstPass(): SherpaOnnx {
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
modelConfig = getModelConfig(type = 1)!!,
endpointConfig = getEndpointConfig(),
enableEndpoint = true,
)
return SherpaOnnx(config = config)
}
fun createSecondPass(): SherpaOnnxOffline {
val config = OfflineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
modelConfig = getOfflineModelConfig(type = 2)!!,
)
return SherpaOnnxOffline(config = config)
}
package com.k2fsa.sherpa.onnx
fun main() {
testAudioTagging()
}
fun testAudioTagging() {
val config = AudioTaggingConfig(
model=AudioTaggingModelConfig(
zipformer=OfflineZipformerAudioTaggingModelConfig(
model="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx",
),
numThreads=1,
debug=true,
provider="cpu",
),
labels="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv",
topK=5,
)
val tagger = AudioTagging(config=config)
val testFiles = arrayOf(
"./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav",
"./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/2.wav",
"./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/3.wav",
"./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/4.wav",
)
println("----------")
for (waveFilename in testFiles) {
val stream = tagger.createStream()
val objArray = WaveReader.readWaveFromFile(
filename = waveFilename,
)
val samples: FloatArray = objArray[0] as FloatArray
val sampleRate: Int = objArray[1] as Int
stream.acceptWaveform(samples, sampleRate = sampleRate)
val events = tagger.compute(stream)
stream.release()
println(waveFilename)
println(events)
println("----------")
}
tagger.release()
}
... ...
package com.k2fsa.sherpa.onnx
fun main() {
testSpokenLanguageIdentification()
}
fun testSpokenLanguageIdentification() {
val config = SpokenLanguageIdentificationConfig(
whisper = SpokenLanguageIdentificationWhisperConfig(
encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",
decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",
tailPaddings = 33,
),
numThreads=1,
debug=true,
provider="cpu",
)
val slid = SpokenLanguageIdentification(config=config)
val testFiles = arrayOf(
"./spoken-language-identification-test-wavs/ar-arabic.wav",
"./spoken-language-identification-test-wavs/bg-bulgarian.wav",
"./spoken-language-identification-test-wavs/de-german.wav",
)
for (waveFilename in testFiles) {
val objArray = WaveReader.readWaveFromFile(
filename = waveFilename,
)
val samples: FloatArray = objArray[0] as FloatArray
val sampleRate: Int = objArray[1] as Int
val stream = slid.createStream()
stream.acceptWaveform(samples, sampleRate = sampleRate)
val lang = slid.compute(stream)
stream.release()
println(waveFilename)
println(lang)
}
slid.release()
}
... ...
package com.k2fsa.sherpa.onnx
fun main() {
val recognizer = createOfflineRecognizer()
val waveFilename = "./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav"
val objArray = WaveReader.readWaveFromFile(
filename = waveFilename,
)
val samples: FloatArray = objArray[0] as FloatArray
val sampleRate: Int = objArray[1] as Int
val stream = recognizer.createStream()
stream.acceptWaveform(samples, sampleRate=sampleRate)
recognizer.decode(stream)
val result = recognizer.getResult(stream)
println(result)
stream.release()
recognizer.release()
}
fun createOfflineRecognizer(): OfflineRecognizer {
val config = OfflineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
modelConfig = getOfflineModelConfig(type = 2)!!,
)
return OfflineRecognizer(config = config)
}
... ...
package com.k2fsa.sherpa.onnx
fun main() {
testOnlineAsr("transducer")
testOnlineAsr("zipformer2-ctc")
}
fun testOnlineAsr(type: String) {
val featConfig = FeatureConfig(
sampleRate = 16000,
featureDim = 80,
)
val waveFilename: String
val modelConfig: OnlineModelConfig = when (type) {
"transducer" -> {
waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav"
// please refer to
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// to download pre-trained models
OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
),
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
numThreads = 1,
debug = false,
)
}
"zipformer2-ctc" -> {
waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"
OnlineModelConfig(
zipformer2Ctc = OnlineZipformer2CtcModelConfig(
model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx",
),
tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt",
numThreads = 1,
debug = false,
)
}
else -> throw IllegalArgumentException(type)
}
val endpointConfig = EndpointConfig()
val lmConfig = OnlineLMConfig()
val config = OnlineRecognizerConfig(
modelConfig = modelConfig,
lmConfig = lmConfig,
featConfig = featConfig,
endpointConfig = endpointConfig,
enableEndpoint = true,
decodingMethod = "greedy_search",
maxActivePaths = 4,
)
val recognizer = OnlineRecognizer(
config = config,
)
val objArray = WaveReader.readWaveFromFile(
filename = waveFilename,
)
val samples: FloatArray = objArray[0] as FloatArray
val sampleRate: Int = objArray[1] as Int
val stream = recognizer.createStream()
stream.acceptWaveform(samples, sampleRate = sampleRate)
while (recognizer.isReady(stream)) {
recognizer.decode(stream)
}
val tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
stream.acceptWaveform(tailPaddings, sampleRate = sampleRate)
stream.inputFinished()
while (recognizer.isReady(stream)) {
recognizer.decode(stream)
}
println("results: ${recognizer.getResult(stream).text}")
stream.release()
recognizer.release()
}
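A note on the tail padding above: a streaming model looks at a chunk of future frames before it commits to output, so without roughly 0.5 seconds of appended silence the final words can stay stuck in the feature buffer; the padding plus inputFinished() flushes them through the decoder.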
... ...
package com.k2fsa.sherpa.onnx
fun main() {
testSpeakerRecognition()
}
fun testSpeakerRecognition() {
val config = SpeakerEmbeddingExtractorConfig(
model="./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx",
)
val extractor = SpeakerEmbeddingExtractor(config = config)
val embedding1a = computeEmbedding(extractor, "./speaker1_a_cn_16k.wav")
val embedding2a = computeEmbedding(extractor, "./speaker2_a_cn_16k.wav")
val embedding1b = computeEmbedding(extractor, "./speaker1_b_cn_16k.wav")
var manager = SpeakerEmbeddingManager(extractor.dim())
var ok = manager.add(name = "speaker1", embedding=embedding1a)
check(ok)
manager.add(name = "speaker2", embedding=embedding2a)
check(ok)
var name = manager.search(embedding=embedding1b, threshold=0.5f)
check(name == "speaker1")
manager.release()
manager = SpeakerEmbeddingManager(extractor.dim())
val embeddingList = mutableListOf(embedding1a, embedding1b)
ok = manager.add(name = "s1", embedding=embeddingList.toTypedArray())
check(ok)
name = manager.search(embedding=embedding1b, threshold=0.5f)
check(name == "s1")
name = manager.search(embedding=embedding2a, threshold=0.5f)
check(name.isEmpty())
manager.release()
extractor.release()
println("Speaker ID test done!")
}
fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray {
var objArray = WaveReader.readWaveFromFile(
filename = filename,
)
var samples: FloatArray = objArray[0] as FloatArray
var sampleRate: Int = objArray[1] as Int
val stream = extractor.createStream()
stream.acceptWaveform(sampleRate = sampleRate, samples=samples)
stream.inputFinished()
check(extractor.isReady(stream))
val embedding = extractor.compute(stream)
stream.release()
return embedding
}
... ...
package com.k2fsa.sherpa.onnx
fun main() {
testTts()
}
fun testTts() {
// see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
// https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
var config = OfflineTtsConfig(
model=OfflineTtsModelConfig(
vits=OfflineTtsVitsModelConfig(
model="./vits-piper-en_US-amy-low/en_US-amy-low.onnx",
tokens="./vits-piper-en_US-amy-low/tokens.txt",
dataDir="./vits-piper-en_US-amy-low/espeak-ng-data",
),
numThreads=1,
debug=true,
)
)
val tts = OfflineTts(config=config)
val audio = tts.generateWithCallback(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”", callback=::callback)
audio.save(filename="test-en.wav")
tts.release()
println("Saved to test-en.wav")
}
fun callback(samples: FloatArray): Unit {
println("callback got called with ${samples.size} samples");
}
... ...
#!/usr/bin/env bash
#
# Auto generated! Please DO NOT EDIT!
# Please set the environment variable ANDROID_NDK
# before running this script
# Inside the $ANDROID_NDK directory, you can find the ndk-build binary
# and some other files like the file "build/cmake/android.toolchain.cmake"
set -ex
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
log "Building streaming ASR APK for sherpa-onnx v${SHERPA_ONNX_VERSION}"
export SHERPA_ONNX_ENABLE_TTS=OFF
log "====================arm64-v8a================="
./build-android-arm64-v8a.sh
log "====================armv7-eabi================"
./build-android-armv7-eabi.sh
log "====================x86-64===================="
./build-android-x86-64.sh
log "====================x86===================="
./build-android-x86.sh
mkdir -p apks
{% for model in model_list %}
pushd ./android/SherpaOnnx/app/src/main/assets/
model_name={{ model.model_name }}
type={{ model.idx }}
lang={{ model.lang }}
short_name={{ model.short_name }}
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/${model_name}.tar.bz2
tar xvf ${model_name}.tar.bz2
{{ model.cmd }}
rm -rf *.tar.bz2
ls -lh $model_name
popd
# Now we are at the project root directory
git checkout .
pushd android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx
sed -i.bak s/"type = 0/type = $type/" ./MainActivity.kt
git diff
popd
for arch in arm64-v8a armeabi-v7a x86_64 x86; do
log "------------------------------------------------------------"
log "build ASR apk for $arch"
log "------------------------------------------------------------"
src_arch=$arch
if [ $arch == "armeabi-v7a" ]; then
src_arch=armv7-eabi
elif [ $arch == "x86_64" ]; then
src_arch=x86-64
fi
ls -lh ./build-android-$src_arch/install/lib/*.so
cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnx/app/src/main/jniLibs/$arch/
pushd ./android/SherpaOnnx
sed -i.bak s/2048/9012/g ./gradle.properties
git diff ./gradle.properties
./gradlew assembleRelease
popd
mv android/SherpaOnnx/app/build/outputs/apk/release/app-release-unsigned.apk ./apks/sherpa-onnx-${SHERPA_ONNX_VERSION}-$arch-asr-$lang-$short_name.apk
ls -lh apks
rm -v ./android/SherpaOnnx/app/src/main/jniLibs/$arch/*.so
done
rm -rf ./android/SherpaOnnx/app/src/main/assets/$model_name
{% endfor %}
git checkout .
ls -lh apks/
... ...
... ... @@ -29,6 +29,8 @@ log "====================x86-64===================="
log "====================x86===================="
./build-android-x86.sh
export SHERPA_ONNX_ENABLE_TTS=OFF
mkdir -p apks
{% for model in model_list %}
... ...
... ... @@ -29,6 +29,8 @@ log "====================x86-64===================="
log "====================x86===================="
./build-android-x86.sh
export SHERPA_ONNX_ENABLE_TTS=OFF
mkdir -p apks
{% for model in model_list %}
... ...
... ... @@ -29,6 +29,8 @@ log "====================x86-64===================="
log "====================x86===================="
./build-android-x86.sh
export SHERPA_ONNX_ENABLE_TTS=OFF
mkdir -p apks
{% for model in model_list %}
... ...
... ... @@ -29,6 +29,8 @@ log "====================x86-64===================="
log "====================x86===================="
./build-android-x86.sh
export SHERPA_ONNX_ENABLE_TTS=OFF
mkdir -p apks
{% for model in model_list %}
... ...
... ... @@ -29,6 +29,8 @@ log "====================x86-64===================="
log "====================x86===================="
./build-android-x86.sh
export SHERPA_ONNX_ENABLE_TTS=ON
mkdir -p apks
{% for tts_model in tts_model_list %}
... ...
... ... @@ -29,6 +29,8 @@ log "====================x86-64===================="
log "====================x86===================="
./build-android-x86.sh
export SHERPA_ONNX_ENABLE_TTS=ON
mkdir -p apks
{% for tts_model in tts_model_list %}
... ...
#!/usr/bin/env python3
import argparse
from dataclasses import dataclass
from typing import List, Optional
import jinja2
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--total",
type=int,
default=1,
help="Number of runners",
)
parser.add_argument(
"--index",
type=int,
default=0,
help="Index of the current runner",
)
return parser.parse_args()
@dataclass
class Model:
# We will download
# https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/{model_name}.tar.bz2
model_name: str
# The type of the model, e.g., 0, 1, 2. It is hardcoded in the Kotlin code
idx: int
# e.g., zh, en, zh_en
lang: str
# e.g., whisper, paraformer, zipformer
short_name: str = ""
# cmd is used to remove extra files from the model directory
cmd: str = ""
def get_models():
models = [
Model(
model_name="sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20",
idx=8,
lang="bilingual_zh_en",
short_name="zipformer",
cmd="""
pushd $model_name
rm -v decoder-epoch-99-avg-1.int8.onnx
rm -v encoder-epoch-99-avg-1.onnx
rm -v joiner-epoch-99-avg-1.onnx
rm -v *.sh
rm -v .gitattributes
rm -v *state*
rm -rfv test_wavs
ls -lh
popd
""",
),
]
return models
def main():
args = get_args()
index = args.index
total = args.total
assert 0 <= index < total, (index, total)
all_model_list = get_models()
num_models = len(all_model_list)
num_per_runner = num_models // total
if num_per_runner <= 0:
raise ValueError(f"num_models: {num_models}, num_runners: {total}")
start = index * num_per_runner
end = start + num_per_runner
remaining = num_models - args.total * num_per_runner
print(f"{index}/{total}: {start}-{end}/{num_models}")
d = dict()
d["model_list"] = all_model_list[start:end]
if index < remaining:
s = args.total * num_per_runner + index
d["model_list"].append(all_model_list[s])
print(f"{s}/{num_models}")
filename_list = [
"./build-apk-asr.sh",
]
for filename in filename_list:
environment = jinja2.Environment()
with open(f"{filename}.in") as f:
s = f.read()
template = environment.from_string(s)
s = template.render(**d)
with open(filename, "w") as f:
print(s, file=f)
if __name__ == "__main__":
main()
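A quick worked example of the partitioning logic: with 5 models and --total=2, num_per_runner = 5 // 2 = 2 and remaining = 1, so runner 0 gets models 0-1 plus leftover model 4 (s = 2 * 2 + 0 = 4), while runner 1 gets models 2-3.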
... ...
... ... @@ -82,7 +82,7 @@ bool OfflineTtsVitsModelConfig::Validate() const {
for (const auto &f : required_files) {
if (!FileExists(dict_dir + "/" + f)) {
SHERPA_ONNX_LOGE("'%s/%s' does not exist.", data_dir.c_str(),
SHERPA_ONNX_LOGE("'%s/%s' does not exist.", dict_dir.c_str(),
f.c_str());
return false;
}
... ...
... ... @@ -12,8 +12,15 @@ endif()
set(sources
audio-tagging.cc
jni.cc
keyword-spotter.cc
offline-recognizer.cc
offline-stream.cc
online-recognizer.cc
online-stream.cc
speaker-embedding-extractor.cc
speaker-embedding-manager.cc
spoken-language-identification.cc
voice-activity-detector.cc
)
if(SHERPA_ONNX_ENABLE_TTS)
... ...
... ... @@ -6,6 +6,8 @@
#define SHERPA_ONNX_JNI_COMMON_H_
#if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
... ...
... ... @@ -4,1530 +4,43 @@
// 2022 Pingfeng Luo
// 2023 Zhaoming
// TODO(fangjun): Add documentation to functions/methods in this file
// and also show how to use them with Kotlin, possibly with Java.
#include <fstream>
#include <functional>
#include <strstream>
#include <utility>
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
class SherpaOnnx {
public:
#if __ANDROID_API__ >= 9
SherpaOnnx(AAssetManager *mgr, const OnlineRecognizerConfig &config)
: recognizer_(mgr, config), stream_(recognizer_.CreateStream()) {}
#endif
explicit SherpaOnnx(const OnlineRecognizerConfig &config)
: recognizer_(config), stream_(recognizer_.CreateStream()) {}
void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
if (input_sample_rate_ == -1) {
input_sample_rate_ = sample_rate;
}
stream_->AcceptWaveform(sample_rate, samples, n);
}
void InputFinished() const {
std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
tail_padding.size());
stream_->InputFinished();
}
std::string GetText() const {
auto result = recognizer_.GetResult(stream_.get());
return result.text;
}
const std::vector<std::string> GetTokens() const {
auto result = recognizer_.GetResult(stream_.get());
return result.tokens;
}
bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
bool IsReady() const { return recognizer_.IsReady(stream_.get()); }
// If keywords is an empty string, it just recreates the decoding stream.
// If keywords is not empty, it will create a new decoding stream with
// the given keywords appended to the default keywords.
void Reset(bool recreate, const std::string &keywords = {}) {
if (keywords.empty()) {
if (recreate) {
stream_ = recognizer_.CreateStream();
} else {
recognizer_.Reset(stream_.get());
}
} else {
auto stream = recognizer_.CreateStream(keywords);
// If setting the new keywords failed, stream_ is not updated.
if (stream != nullptr) {
stream_ = std::move(stream);
} else {
SHERPA_ONNX_LOGE("Failed to set keywords: %s", keywords.c_str());
}
}
}
void Decode() const { recognizer_.DecodeStream(stream_.get()); }
private:
OnlineRecognizer recognizer_;
std::unique_ptr<OnlineStream> stream_;
int32_t input_sample_rate_ = -1;
};
class SherpaOnnxOffline {
public:
#if __ANDROID_API__ >= 9
SherpaOnnxOffline(AAssetManager *mgr, const OfflineRecognizerConfig &config)
: recognizer_(mgr, config) {}
#endif
explicit SherpaOnnxOffline(const OfflineRecognizerConfig &config)
: recognizer_(config) {}
std::string Decode(int32_t sample_rate, const float *samples, int32_t n) {
auto stream = recognizer_.CreateStream();
stream->AcceptWaveform(sample_rate, samples, n);
recognizer_.DecodeStream(stream.get());
return stream->GetResult().text;
}
private:
OfflineRecognizer recognizer_;
};
class SherpaOnnxVad {
public:
#if __ANDROID_API__ >= 9
SherpaOnnxVad(AAssetManager *mgr, const VadModelConfig &config)
: vad_(mgr, config) {}
#endif
explicit SherpaOnnxVad(const VadModelConfig &config) : vad_(config) {}
void AcceptWaveform(const float *samples, int32_t n) {
vad_.AcceptWaveform(samples, n);
}
bool Empty() const { return vad_.Empty(); }
void Pop() { vad_.Pop(); }
void Clear() { vad_.Clear(); }
const SpeechSegment &Front() const { return vad_.Front(); }
bool IsSpeechDetected() const { return vad_.IsSpeechDetected(); }
void Reset() { vad_.Reset(); }
private:
VoiceActivityDetector vad_;
};
class SherpaOnnxKws {
public:
#if __ANDROID_API__ >= 9
SherpaOnnxKws(AAssetManager *mgr, const KeywordSpotterConfig &config)
: keyword_spotter_(mgr, config),
stream_(keyword_spotter_.CreateStream()) {}
#endif
explicit SherpaOnnxKws(const KeywordSpotterConfig &config)
: keyword_spotter_(config), stream_(keyword_spotter_.CreateStream()) {}
void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
if (input_sample_rate_ == -1) {
input_sample_rate_ = sample_rate;
}
stream_->AcceptWaveform(sample_rate, samples, n);
}
void InputFinished() const {
std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
tail_padding.size());
stream_->InputFinished();
}
// If keywords is an empty string, it just recreates the decoding stream
// and always returns true in this case.
// If keywords is not empty, it will create a new decoding stream with
// the given keywords appended to the default keywords.
// Returns false if an error occurred when adding keywords; true otherwise.
bool Reset(const std::string &keywords = {}) {
if (keywords.empty()) {
stream_ = keyword_spotter_.CreateStream();
return true;
} else {
auto stream = keyword_spotter_.CreateStream(keywords);
// If setting the new keywords failed, stream_ is not updated.
if (stream == nullptr) {
return false;
} else {
stream_ = std::move(stream);
return true;
}
}
}
std::string GetKeyword() const {
auto result = keyword_spotter_.GetResult(stream_.get());
return result.keyword;
}
std::vector<std::string> GetTokens() const {
auto result = keyword_spotter_.GetResult(stream_.get());
return result.tokens;
}
bool IsReady() const { return keyword_spotter_.IsReady(stream_.get()); }
void Decode() const { keyword_spotter_.DecodeStream(stream_.get()); }
private:
KeywordSpotter keyword_spotter_;
std::unique_ptr<OnlineStream> stream_;
int32_t input_sample_rate_ = -1;
};
class SherpaOnnxSpeakerEmbeddingExtractorStream {
public:
explicit SherpaOnnxSpeakerEmbeddingExtractorStream(
std::unique_ptr<OnlineStream> stream)
: stream_(std::move(stream)) {}
void AcceptWaveform(int32_t sample_rate, const float *samples,
int32_t n) const {
stream_->AcceptWaveform(sample_rate, samples, n);
}
void InputFinished() const { stream_->InputFinished(); }
OnlineStream *Get() const { return stream_.get(); }
private:
std::unique_ptr<OnlineStream> stream_;
};
class SherpaOnnxSpeakerEmbeddingExtractor {
public:
#if __ANDROID_API__ >= 9
SherpaOnnxSpeakerEmbeddingExtractor(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
: extractor_(mgr, config) {}
#endif
explicit SherpaOnnxSpeakerEmbeddingExtractor(
const SpeakerEmbeddingExtractorConfig &config)
: extractor_(config) {}
int32_t Dim() const { return extractor_.Dim(); }
bool IsReady(const SherpaOnnxSpeakerEmbeddingExtractorStream *stream) const {
return extractor_.IsReady(stream->Get());
}
SherpaOnnxSpeakerEmbeddingExtractorStream *CreateStream() const {
return new SherpaOnnxSpeakerEmbeddingExtractorStream(
extractor_.CreateStream());
}
std::vector<float> Compute(
const SherpaOnnxSpeakerEmbeddingExtractorStream *stream) const {
return extractor_.Compute(stream->Get());
}
private:
SpeakerEmbeddingExtractor extractor_;
};
static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig(
JNIEnv *env, jobject config) {
SpeakerEmbeddingExtractorConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid = env->GetFieldID(cls, "model", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "numThreads", "I");
ans.num_threads = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "debug", "Z");
ans.debug = env->GetBooleanField(config, fid);
fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.provider = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
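The GetFieldID calls above imply the shape of the Kotlin-side config object. Roughly (a sketch inferred from the field names and JNI type signatures; the default values are illustrative, not the exact upstream declaration):
data class SpeakerEmbeddingExtractorConfig(
    val model: String,            // "model", Ljava/lang/String;
    val numThreads: Int = 1,      // "numThreads", I
    val debug: Boolean = false,   // "debug", Z
    val provider: String = "cpu", // "provider", Ljava/lang/String;
)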
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
OnlineRecognizerConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
//---------- decoding ----------
fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.decoding_method = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.hotwords_file = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
//---------- enable endpoint ----------
fid = env->GetFieldID(cls, "enableEndpoint", "Z");
ans.enable_endpoint = env->GetBooleanField(config, fid);
//---------- endpoint_config ----------
fid = env->GetFieldID(cls, "endpointConfig",
"Lcom/k2fsa/sherpa/onnx/EndpointConfig;");
jobject endpoint_config = env->GetObjectField(config, fid);
jclass endpoint_config_cls = env->GetObjectClass(endpoint_config);
fid = env->GetFieldID(endpoint_config_cls, "rule1",
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
jobject rule1 = env->GetObjectField(endpoint_config, fid);
jclass rule_class = env->GetObjectClass(rule1);
fid = env->GetFieldID(endpoint_config_cls, "rule2",
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
jobject rule2 = env->GetObjectField(endpoint_config, fid);
fid = env->GetFieldID(endpoint_config_cls, "rule3",
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
jobject rule3 = env->GetObjectField(endpoint_config, fid);
fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z");
ans.endpoint_config.rule1.must_contain_nonsilence =
env->GetBooleanField(rule1, fid);
ans.endpoint_config.rule2.must_contain_nonsilence =
env->GetBooleanField(rule2, fid);
ans.endpoint_config.rule3.must_contain_nonsilence =
env->GetBooleanField(rule3, fid);
fid = env->GetFieldID(rule_class, "minTrailingSilence", "F");
ans.endpoint_config.rule1.min_trailing_silence =
env->GetFloatField(rule1, fid);
ans.endpoint_config.rule2.min_trailing_silence =
env->GetFloatField(rule2, fid);
ans.endpoint_config.rule3.min_trailing_silence =
env->GetFloatField(rule3, fid);
fid = env->GetFieldID(rule_class, "minUtteranceLength", "F");
ans.endpoint_config.rule1.min_utterance_length =
env->GetFloatField(rule1, fid);
ans.endpoint_config.rule2.min_utterance_length =
env->GetFloatField(rule2, fid);
ans.endpoint_config.rule3.min_utterance_length =
env->GetFloatField(rule3, fid);
//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner = p;
env->ReleaseStringUTFChars(s, p);
// paraformer
fid = env->GetFieldID(model_config_cls, "paraformer",
"Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
jobject paraformer_config = env->GetObjectField(model_config, fid);
jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.decoder = p;
env->ReleaseStringUTFChars(s, p);
// streaming zipformer2 CTC
fid =
env->GetFieldID(model_config_cls, "zipformer2Ctc",
"Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;");
jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid);
jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config);
fid =
env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.zipformer2_ctc.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
//---------- rnn lm model config ----------
fid = env->GetFieldID(cls, "lmConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
jobject lm_model_config = env->GetObjectField(config, fid);
jclass lm_model_config_cls = env->GetObjectClass(lm_model_config);
fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(lm_model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.lm_config.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(lm_model_config_cls, "scale", "F");
ans.lm_config.scale = env->GetFloatField(lm_model_config, fid);
return ans;
}
static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
OfflineRecognizerConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
//---------- decoding ----------
fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.decoding_method = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.hotwords_file = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OfflineModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner_filename = p;
env->ReleaseStringUTFChars(s, p);
// paraformer
fid = env->GetFieldID(model_config_cls, "paraformer",
"Lcom/k2fsa/sherpa/onnx/OfflineParaformerModelConfig;");
jobject paraformer_config = env->GetObjectField(model_config, fid);
jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
fid = env->GetFieldID(paraformer_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.model = p;
env->ReleaseStringUTFChars(s, p);
// whisper
fid = env->GetFieldID(model_config_cls, "whisper",
"Lcom/k2fsa/sherpa/onnx/OfflineWhisperModelConfig;");
jobject whisper_config = env->GetObjectField(model_config, fid);
jclass whisper_config_cls = env->GetObjectClass(whisper_config);
fid = env->GetFieldID(whisper_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(whisper_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.whisper.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(whisper_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(whisper_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.whisper.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(whisper_config_cls, "language", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(whisper_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.whisper.language = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(whisper_config_cls, "task", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(whisper_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.whisper.task = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(whisper_config_cls, "tailPaddings", "I");
ans.model_config.whisper.tail_paddings =
env->GetIntField(whisper_config, fid);
return ans;
}
static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
KeywordSpotterConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
//---------- decoding ----------
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.keywords_file = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "keywordsScore", "F");
ans.keywords_score = env->GetFloatField(config, fid);
fid = env->GetFieldID(cls, "keywordsThreshold", "F");
ans.keywords_threshold = env->GetFloatField(config, fid);
fid = env->GetFieldID(cls, "numTrailingBlanks", "I");
ans.num_trailing_blanks = env->GetIntField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
VadModelConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
// silero_vad
fid = env->GetFieldID(cls, "sileroVadModelConfig",
"Lcom/k2fsa/sherpa/onnx/SileroVadModelConfig;");
jobject silero_vad_config = env->GetObjectField(config, fid);
jclass silero_vad_config_cls = env->GetObjectClass(silero_vad_config);
fid = env->GetFieldID(silero_vad_config_cls, "model", "Ljava/lang/String;");
auto s = (jstring)env->GetObjectField(silero_vad_config, fid);
auto p = env->GetStringUTFChars(s, nullptr);
ans.silero_vad.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(silero_vad_config_cls, "threshold", "F");
ans.silero_vad.threshold = env->GetFloatField(silero_vad_config, fid);
fid = env->GetFieldID(silero_vad_config_cls, "minSilenceDuration", "F");
ans.silero_vad.min_silence_duration =
env->GetFloatField(silero_vad_config, fid);
fid = env->GetFieldID(silero_vad_config_cls, "minSpeechDuration", "F");
ans.silero_vad.min_speech_duration =
env->GetFloatField(silero_vad_config, fid);
fid = env->GetFieldID(silero_vad_config_cls, "windowSize", "I");
ans.silero_vad.window_size = env->GetIntField(silero_vad_config, fid);
fid = env->GetFieldID(cls, "sampleRate", "I");
ans.sample_rate = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "numThreads", "I");
ans.num_threads = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "debug", "Z");
ans.debug = env->GetBooleanField(config, fid);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_new(JNIEnv *env,
jobject /*obj*/,
jobject asset_manager,
jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
SHERPA_ONNX_LOGE("new config:\n%s", config.ToString().c_str());
auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)extractor;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
SHERPA_ONNX_LOGE("newFromFile config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
}
auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor(config);
return (jlong)extractor;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_delete(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
delete reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(
ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_createStream(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto stream =
reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr)
->CreateStream();
return (jlong)stream;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_isReady(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jlong stream_ptr) {
auto extractor =
reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);
auto stream = reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(stream_ptr);
return extractor->IsReady(stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jfloatArray JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_compute(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jlong stream_ptr) {
auto extractor =
reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);
auto stream = reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(stream_ptr);
std::vector<float> embedding = extractor->Compute(stream);
jfloatArray embedding_arr = env->NewFloatArray(embedding.size());
env->SetFloatArrayRegion(embedding_arr, 0, embedding.size(),
embedding.data());
return embedding_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_dim(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto extractor =
reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);
return extractor->Dim();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jint sample_rate) {
auto stream = reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
stream->AcceptWaveform(sample_rate, p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_inputFinished(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto stream = reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);
stream->InputFinished();
}
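// Editorial sketch, not part of the original diff: the extractor JNI entry
// points above are expected to be driven in this order from the JVM side.
// Written here against the underlying C++ wrappers; the function name, the
// config type spelling, and taking samples as a raw pointer are assumptions
// for illustration only.
static std::vector<float> ComputeEmbeddingSketch(
    const sherpa_onnx::SpeakerEmbeddingExtractorConfig &config,
    const float *samples, int32_t n, int32_t sample_rate) {
  sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor extractor(config);  // new()
  auto *stream = extractor.CreateStream();            // createStream()
  stream->AcceptWaveform(sample_rate, samples, n);    // acceptWaveform()
  stream->InputFinished();                            // inputFinished()
  std::vector<float> embedding;
  if (extractor.IsReady(stream)) {                    // isReady()
    embedding = extractor.Compute(stream);            // compute()
  }
  delete stream;  // mirrors SpeakerEmbeddingExtractorStream_delete() above
  return embedding;
}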
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_new(
JNIEnv *env, jobject /*obj*/, jint dim) {
auto p = new sherpa_onnx::SpeakerEmbeddingManager(dim);
return (jlong)p;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_delete(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
delete manager;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_add(JNIEnv *env,
jobject /*obj*/,
jlong ptr, jstring name,
jfloatArray embedding) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Add(p_name, p);
env->ReleaseStringUTFChars(name, p_name);
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_addList(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
jobjectArray embedding_arr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
int num_embeddings = env->GetArrayLength(embedding_arr);
if (num_embeddings == 0) {
return false;
}
std::vector<std::vector<float>> embedding_list;
embedding_list.reserve(num_embeddings);
for (int32_t i = 0; i != num_embeddings; ++i) {
jfloatArray embedding =
(jfloatArray)env->GetObjectArrayElement(embedding_arr, i);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("i: %d. Expected dim %d, given %d", i, manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
embedding_list.push_back({p, p + n});
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
}
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Add(p_name, embedding_list);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_remove(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jstring name) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Remove(p_name);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_search(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jfloatArray embedding,
jfloat threshold) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
std::string name = manager->Search(p, threshold);
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
return env->NewStringUTF(name.c_str());
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_verify(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
jfloatArray embedding, jfloat threshold) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Verify(p_name, p, threshold);
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_contains(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jstring name) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Contains(p_name);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_numSpeakers(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
return manager->NumSpeakers();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
std::vector<std::string> all_speakers = manager->GetAllSpeakers();
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
all_speakers.size(), env->FindClass("java/lang/String"), nullptr);
int32_t i = 0;
for (auto &s : all_speakers) {
jstring js = env->NewStringUTF(s.c_str());
env->SetObjectArrayElement(obj_arr, i, js);
++i;
}
return obj_arr;
}
// see
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
jobject NewInteger(JNIEnv *env, int32_t value) {
jclass cls = env->FindClass("java/lang/Integer");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");
return env->NewObject(cls, constructor, value);
}
jobject NewFloat(JNIEnv *env, float value) {
jclass cls = env->FindClass("java/lang/Float");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(F)V");
return env->NewObject(cls, constructor, value);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
jint sample_rate) {
const char *p_filename = env->GetStringUTFChars(filename, nullptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n);
env->ReleaseStringUTFChars(filename, p_filename);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_new(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetVadModelConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnxVad(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetVadModelConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnxVad(config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_delete(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
delete reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
model->AcceptWaveform(p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_Vad_empty(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
return model->Empty();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
model->Pop();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
model->Clear();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) {
const auto &front =
reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr)->Front();
jfloatArray samples_arr = env->NewFloatArray(front.samples.size());
env->SetFloatArrayRegion(samples_arr, 0, front.samples.size(),
front.samples.data());
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
2, env->FindClass("java/lang/Object"), nullptr);
env->SetObjectArrayElement(obj_arr, 0, NewInteger(env, front.start));
env->SetObjectArrayElement(obj_arr, 1, samples_arr);
return obj_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_Vad_isSpeechDetected(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
return model->IsSpeechDetected();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_reset(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
model->Reset();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnx(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnx(config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_new(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnxOffline(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_newFromFile(JNIEnv *env,
jobject /*obj*/,
jobject _config) {
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnxOffline(config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::SherpaOnnxOffline *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
JNIEnv *env, jobject /*obj*/, jlong ptr, jboolean recreate,
jstring keywords) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);
model->Reset(recreate, p_keywords);
env->ReleaseStringUTFChars(keywords, p_keywords);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
return model->IsReady();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
return model->IsEndpoint();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
model->Decode();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jint sample_rate) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
model->AcceptWaveform(sample_rate, p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_decode(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jint sample_rate) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxOffline *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
auto text = model->Decode(sample_rate, p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
return env->NewStringUTF(text.c_str());
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_inputFinished(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->InputFinished();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
// see
// https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
auto text = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->GetText();
return env->NewStringUTF(text.c_str());
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto tokens = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->GetTokens();
int32_t size = tokens.size();
jclass stringClass = env->FindClass("java/lang/String");
  // Convert the C++ vector into a JNI string array
jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
for (int32_t i = 0; i < size; i++) {
// Convert the C++ string to a C string
const char *cstr = tokens[i].c_str();
// Convert the C string to a jstring
jstring jstr = env->NewStringUTF(cstr);
// Set the array element
env->SetObjectArrayElement(result, i, jstr);
}
return result;
}
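// Editorial sketch, not part of the original diff: this vector<string> ->
// jobjectArray conversion is repeated verbatim in SherpaOnnxKws_getTokens
// below; a shared helper (name illustrative) would remove the duplication.
static jobjectArray StringVectorToJavaArray(
    JNIEnv *env, const std::vector<std::string> &v) {
  jclass string_class = env->FindClass("java/lang/String");
  jobjectArray ans = env->NewObjectArray(v.size(), string_class, nullptr);
  for (int32_t i = 0; i != static_cast<int32_t>(v.size()); ++i) {
    env->SetObjectArrayElement(ans, i, env->NewStringUTF(v[i].c_str()));
  }
  return ans;
}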
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_new(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetKwsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnxKws(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetKwsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnxKws(config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_delete(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  delete reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_isReady(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
  return model->IsReady();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_decode(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
  model->Decode();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_acceptWaveform(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
    jint sample_rate) {
  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  jsize n = env->GetArrayLength(samples);
  model->AcceptWaveform(sample_rate, p, n);
  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_inputFinished(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->InputFinished();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getKeyword(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
// see
// https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
auto text = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetKeyword();
return env->NewStringUTF(text.c_str());
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_reset(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);
std::string keywords_str = p_keywords;
bool status =
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->Reset(keywords_str);
env->ReleaseStringUTFChars(keywords, p_keywords);
return status;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getTokens(JNIEnv *env, jobject /*obj*/,
jlong ptr) {
auto tokens =
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetTokens();
int32_t size = tokens.size();
jclass stringClass = env->FindClass("java/lang/String");
  // Convert the C++ vector into a JNI string array
jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
for (int32_t i = 0; i < size; i++) {
// Convert the C++ string to a C string
const char *cstr = tokens[i].c_str();
// Convert the C string to a jstring
jstring jstr = env->NewStringUTF(cstr);
// Set the array element
env->SetObjectArrayElement(result, i, jstr);
}
return result;
}
static jobjectArray ReadWaveImpl(JNIEnv *env, std::istream &is,
... ... @@ -1593,81 +106,7 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromAsset(
return obj_arr;
}
// ****** wrapper for OnlineRecognizer ******
// WAV reader for the Java interface
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_readWave(JNIEnv *env,
jclass /*cls*/,
jstring filename) {
auto data =
Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromAsset(
env, nullptr, nullptr, filename);
return data;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_createOnlineRecognizer(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
sherpa_onnx::OnlineRecognizerConfig config =
sherpa_onnx::GetConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto p_recognizer = new sherpa_onnx::OnlineRecognizer(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)p_recognizer;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_deleteOnlineRecognizer(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_createStream(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
std::unique_ptr<sherpa_onnx::OnlineStream> s =
reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr)->CreateStream();
sherpa_onnx::OnlineStream *p_stream = s.release();
return reinterpret_cast<jlong>(p_stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isReady(
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {
sherpa_onnx::OnlineRecognizer *model =
reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
return model->IsReady(s);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decodeStream(
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {
sherpa_onnx::OnlineRecognizer *model =
reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
model->DecodeStream(s);
}
#if 0
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decodeStreams(JNIEnv *env,
... ... @@ -1687,92 +126,4 @@ Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decodeStreams(JNIEnv *env,
model->DecodeStreams(p_ss.data(), n);
env->ReleaseLongArrayElements(ss_ptr, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {
sherpa_onnx::OnlineRecognizer *model =
reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s);
return env->NewStringUTF(result.text.c_str());
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isEndpoint(
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {
sherpa_onnx::OnlineRecognizer *model =
reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
return model->IsEndpoint(s);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_reSet(
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {
sherpa_onnx::OnlineRecognizer *model =
reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
model->Reset(s);
}
// *********for OnlineStream *********
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong s_ptr, jint sample_rate,
jfloatArray waveform) {
sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
jfloat *p = env->GetFloatArrayElements(waveform, nullptr);
jsize n = env->GetArrayLength(waveform);
s->AcceptWaveform(sample_rate, p, n);
env->ReleaseFloatArrayElements(waveform, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_inputFinished(
JNIEnv *env, jobject /*obj*/, jlong s_ptr) {
sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
s->InputFinished();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_deleteStream(
JNIEnv *env, jobject /*obj*/, jlong s_ptr) {
delete reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_numFramesReady(
JNIEnv *env, jobject /*obj*/, jlong s_ptr) {
sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
return s->NumFramesReady();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_isLastFrame(
JNIEnv *env, jobject /*obj*/, jlong s_ptr, jint frame) {
sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
return s->IsLastFrame(frame);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_reSet(
JNIEnv *env, jobject /*obj*/, jlong s_ptr) {
sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
s->Reset();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_featureDim(
JNIEnv *env, jobject /*obj*/, jlong s_ptr) {
sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
return s->FeatureDim();
}
#endif
... ...
// sherpa-onnx/jni/keyword-spotter.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
KeywordSpotterConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
//---------- decoding ----------
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.keywords_file = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "keywordsScore", "F");
ans.keywords_score = env->GetFloatField(config, fid);
fid = env->GetFieldID(cls, "keywordsThreshold", "F");
ans.keywords_threshold = env->GetFloatField(config, fid);
fid = env->GetFieldID(cls, "numTrailingBlanks", "I");
ans.num_trailing_blanks = env->GetIntField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromAsset(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetKwsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto kws = new sherpa_onnx::KeywordSpotter(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)kws;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetKwsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
return 0;
}
auto kws = new sherpa_onnx::KeywordSpotter(config);
return (jlong)kws;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_decode(
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
kws->DecodeStream(stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
const char *p = env->GetStringUTFChars(keywords, nullptr);
std::unique_ptr<sherpa_onnx::OnlineStream> stream;
if (strlen(p) == 0) {
stream = kws->CreateStream();
} else {
stream = kws->CreateStream(p);
}
env->ReleaseStringUTFChars(keywords, p);
  // The user is responsible for freeing the returned pointer.
//
// See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
// ./offline-stream.cc
sherpa_onnx::OnlineStream *ans = stream.release();
return (jlong)ans;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_isReady(
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
return kws->IsReady(stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_KeywordSpotter_getResult(JNIEnv *env,
jobject /*obj*/, jlong ptr,
jlong stream_ptr) {
auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
sherpa_onnx::KeywordResult result = kws->GetResult(stream);
// [0]: keyword, jstring
// [1]: tokens, array of jstring
// [2]: timestamps, array of float
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
3, env->FindClass("java/lang/Object"), nullptr);
jstring keyword = env->NewStringUTF(result.keyword.c_str());
env->SetObjectArrayElement(obj_arr, 0, keyword);
jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray(
result.tokens.size(), env->FindClass("java/lang/String"), nullptr);
int32_t i = 0;
for (const auto &t : result.tokens) {
jstring jtext = env->NewStringUTF(t.c_str());
env->SetObjectArrayElement(tokens_arr, i, jtext);
i += 1;
}
env->SetObjectArrayElement(obj_arr, 1, tokens_arr);
jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size());
env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(),
result.timestamps.data());
env->SetObjectArrayElement(obj_arr, 2, timestamps_arr);
return obj_arr;
}
... ...
// sherpa-onnx/jni/offline-recognizer.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
OfflineRecognizerConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
//---------- decoding ----------
fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.decoding_method = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.hotwords_file = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OfflineModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner_filename = p;
env->ReleaseStringUTFChars(s, p);
// paraformer
fid = env->GetFieldID(model_config_cls, "paraformer",
"Lcom/k2fsa/sherpa/onnx/OfflineParaformerModelConfig;");
jobject paraformer_config = env->GetObjectField(model_config, fid);
jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
fid = env->GetFieldID(paraformer_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.model = p;
env->ReleaseStringUTFChars(s, p);
// whisper
fid = env->GetFieldID(model_config_cls, "whisper",
"Lcom/k2fsa/sherpa/onnx/OfflineWhisperModelConfig;");
jobject whisper_config = env->GetObjectField(model_config, fid);
jclass whisper_config_cls = env->GetObjectClass(whisper_config);
fid = env->GetFieldID(whisper_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(whisper_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.whisper.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(whisper_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(whisper_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.whisper.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(whisper_config_cls, "language", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(whisper_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.whisper.language = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(whisper_config_cls, "task", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(whisper_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.whisper.task = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(whisper_config_cls, "tailPaddings", "I");
ans.model_config.whisper.tail_paddings =
env->GetIntField(whisper_config, fid);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromAsset(JNIEnv *env,
jobject /*obj*/,
jobject asset_manager,
jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::OfflineRecognizer(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromFile(JNIEnv *env,
jobject /*obj*/,
jobject _config) {
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
return 0;
}
auto model = new sherpa_onnx::OfflineRecognizer(config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::OfflineRecognizer *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_createStream(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto recognizer = reinterpret_cast<sherpa_onnx::OfflineRecognizer *>(ptr);
std::unique_ptr<sherpa_onnx::OfflineStream> s = recognizer->CreateStream();
  // The user is responsible for freeing the returned pointer.
//
// See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
// ./offline-stream.cc
sherpa_onnx::OfflineStream *p = s.release();
return (jlong)p;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_decode(
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong streamPtr) {
auto recognizer = reinterpret_cast<sherpa_onnx::OfflineRecognizer *>(ptr);
auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(streamPtr);
recognizer->DecodeStream(stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env,
jobject /*obj*/,
jlong streamPtr) {
auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(streamPtr);
sherpa_onnx::OfflineRecognitionResult result = stream->GetResult();
// [0]: text, jstring
// [1]: tokens, array of jstring
// [2]: timestamps, array of float
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
3, env->FindClass("java/lang/Object"), nullptr);
jstring text = env->NewStringUTF(result.text.c_str());
env->SetObjectArrayElement(obj_arr, 0, text);
jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray(
result.tokens.size(), env->FindClass("java/lang/String"), nullptr);
int32_t i = 0;
for (const auto &t : result.tokens) {
jstring jtext = env->NewStringUTF(t.c_str());
env->SetObjectArrayElement(tokens_arr, i, jtext);
i += 1;
}
env->SetObjectArrayElement(obj_arr, 1, tokens_arr);
jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size());
env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(),
result.timestamps.data());
env->SetObjectArrayElement(obj_arr, 2, timestamps_arr);
return obj_arr;
}
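// Editorial sketch, not part of the original diff:
// KeywordSpotter_getResult above and OfflineRecognizer_getResult here pack
// the same (text, tokens, timestamps) triple into an Object[3]. A shared
// packer along these lines (name illustrative) would keep the array layout
// in one place:
static jobjectArray PackResult(JNIEnv *env, const std::string &text,
                               const std::vector<std::string> &tokens,
                               const std::vector<float> &timestamps) {
  jobjectArray obj_arr = env->NewObjectArray(
      3, env->FindClass("java/lang/Object"), nullptr);
  env->SetObjectArrayElement(obj_arr, 0, env->NewStringUTF(text.c_str()));
  jobjectArray tokens_arr = env->NewObjectArray(
      tokens.size(), env->FindClass("java/lang/String"), nullptr);
  int32_t i = 0;
  for (const auto &t : tokens) {
    env->SetObjectArrayElement(tokens_arr, i++, env->NewStringUTF(t.c_str()));
  }
  env->SetObjectArrayElement(obj_arr, 1, tokens_arr);
  jfloatArray timestamps_arr = env->NewFloatArray(timestamps.size());
  env->SetFloatArrayRegion(timestamps_arr, 0, timestamps.size(),
                           timestamps.data());
  env->SetObjectArrayElement(obj_arr, 2, timestamps_arr);
  return obj_arr;
}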
... ...
// sherpa-onnx/jni/online-recognizer.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
OnlineRecognizerConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
//---------- decoding ----------
fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.decoding_method = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.hotwords_file = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
//---------- enable endpoint ----------
fid = env->GetFieldID(cls, "enableEndpoint", "Z");
ans.enable_endpoint = env->GetBooleanField(config, fid);
//---------- endpoint_config ----------
fid = env->GetFieldID(cls, "endpointConfig",
"Lcom/k2fsa/sherpa/onnx/EndpointConfig;");
jobject endpoint_config = env->GetObjectField(config, fid);
jclass endpoint_config_cls = env->GetObjectClass(endpoint_config);
fid = env->GetFieldID(endpoint_config_cls, "rule1",
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
jobject rule1 = env->GetObjectField(endpoint_config, fid);
jclass rule_class = env->GetObjectClass(rule1);
fid = env->GetFieldID(endpoint_config_cls, "rule2",
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
jobject rule2 = env->GetObjectField(endpoint_config, fid);
fid = env->GetFieldID(endpoint_config_cls, "rule3",
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
jobject rule3 = env->GetObjectField(endpoint_config, fid);
fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z");
ans.endpoint_config.rule1.must_contain_nonsilence =
env->GetBooleanField(rule1, fid);
ans.endpoint_config.rule2.must_contain_nonsilence =
env->GetBooleanField(rule2, fid);
ans.endpoint_config.rule3.must_contain_nonsilence =
env->GetBooleanField(rule3, fid);
fid = env->GetFieldID(rule_class, "minTrailingSilence", "F");
ans.endpoint_config.rule1.min_trailing_silence =
env->GetFloatField(rule1, fid);
ans.endpoint_config.rule2.min_trailing_silence =
env->GetFloatField(rule2, fid);
ans.endpoint_config.rule3.min_trailing_silence =
env->GetFloatField(rule3, fid);
fid = env->GetFieldID(rule_class, "minUtteranceLength", "F");
ans.endpoint_config.rule1.min_utterance_length =
env->GetFloatField(rule1, fid);
ans.endpoint_config.rule2.min_utterance_length =
env->GetFloatField(rule2, fid);
ans.endpoint_config.rule3.min_utterance_length =
env->GetFloatField(rule3, fid);
//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner = p;
env->ReleaseStringUTFChars(s, p);
// paraformer
fid = env->GetFieldID(model_config_cls, "paraformer",
"Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
jobject paraformer_config = env->GetObjectField(model_config, fid);
jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.decoder = p;
env->ReleaseStringUTFChars(s, p);
// streaming zipformer2 CTC
fid =
env->GetFieldID(model_config_cls, "zipformer2Ctc",
"Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;");
jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid);
jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config);
fid =
env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.zipformer2_ctc.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
//---------- rnn lm model config ----------
fid = env->GetFieldID(cls, "lmConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
jobject lm_model_config = env->GetObjectField(config, fid);
jclass lm_model_config_cls = env->GetObjectClass(lm_model_config);
fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(lm_model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.lm_config.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(lm_model_config_cls, "scale", "F");
ans.lm_config.scale = env->GetFloatField(lm_model_config, fid);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromAsset(JNIEnv *env,
jobject /*obj*/,
jobject asset_manager,
jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto recognizer = new sherpa_onnx::OnlineRecognizer(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)recognizer;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
return 0;
}
auto recognizer = new sherpa_onnx::OnlineRecognizer(config);
return (jlong)recognizer;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_reset(
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
recognizer->Reset(stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isReady(
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
return recognizer->IsReady(stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isEndpoint(
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
return recognizer->IsEndpoint(stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decode(
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
recognizer->DecodeStream(stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_createStream(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jstring hotwords) {
auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
const char *p = env->GetStringUTFChars(hotwords, nullptr);
std::unique_ptr<sherpa_onnx::OnlineStream> stream;
if (strlen(p) == 0) {
stream = recognizer->CreateStream();
} else {
stream = recognizer->CreateStream(p);
}
env->ReleaseStringUTFChars(hotwords, p);
  // The user is responsible for freeing the returned pointer.
  //
  // See Java_com_k2fsa_sherpa_onnx_OnlineStream_delete() from
  // ./online-stream.cc
sherpa_onnx::OnlineStream *ans = stream.release();
return (jlong)ans;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jlong stream_ptr) {
auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
sherpa_onnx::OnlineRecognizerResult result = recognizer->GetResult(stream);
// [0]: text, jstring
// [1]: tokens, array of jstring
// [2]: timestamps, array of float
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
3, env->FindClass("java/lang/Object"), nullptr);
jstring text = env->NewStringUTF(result.text.c_str());
env->SetObjectArrayElement(obj_arr, 0, text);
jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray(
result.tokens.size(), env->FindClass("java/lang/String"), nullptr);
int32_t i = 0;
for (const auto &t : result.tokens) {
jstring jtext = env->NewStringUTF(t.c_str());
env->SetObjectArrayElement(tokens_arr, i, jtext);
i += 1;
}
env->SetObjectArrayElement(obj_arr, 1, tokens_arr);
jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size());
env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(),
result.timestamps.data());
env->SetObjectArrayElement(obj_arr, 2, timestamps_arr);
return obj_arr;
}
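Taken together, the bindings above support a decode loop like the following minimal Kotlin sketch, written against the OnlineRecognizer and OnlineStream wrappers that appear later in this patch; config stands for an already-built OnlineRecognizerConfig, and the audio buffer is a placeholder:

// Minimal streaming-decode sketch over the bindings above.
val recognizer = OnlineRecognizer(assetManager = null, config = config)
val stream = recognizer.createStream()             // caller owns the native stream
stream.acceptWaveform(FloatArray(1600), sampleRate = 16000)  // placeholder audio
while (recognizer.isReady(stream)) {
    recognizer.decode(stream)
}
val result = recognizer.getResult(stream)          // text, tokens, timestamps
if (recognizer.isEndpoint(stream)) {
    recognizer.reset(stream)
}
stream.release()                                   // frees the native pointer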
... ...
// sherpa-onnx/jni/online-stream.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/jni/common.h"
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::OnlineStream *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jint sample_rate) {
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
stream->AcceptWaveform(sample_rate, p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_inputFinished(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(ptr);
stream->InputFinished();
}
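A small Kotlin sketch of the intended call pattern for these two bindings, using the OnlineStream wrapper from later in this patch; the Iterator source is just a stand-in for a real audio producer:

// Feed audio in chunks, then mark the end of input so the
// remaining buffered frames can still be decoded.
fun feed(stream: OnlineStream, source: Iterator<FloatArray>) {
    while (source.hasNext()) {
        stream.acceptWaveform(source.next(), sampleRate = 16000)
    }
    stream.inputFinished()
}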
... ...
// sherpa-onnx/jni/speaker-embedding-extractor.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig(
JNIEnv *env, jobject config) {
SpeakerEmbeddingExtractorConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid = env->GetFieldID(cls, "model", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "numThreads", "I");
ans.num_threads = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "debug", "Z");
ans.debug = env->GetBooleanField(config, fid);
fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.provider = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromAsset(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
SHERPA_ONNX_LOGE("new config:\n%s", config.ToString().c_str());
auto extractor = new sherpa_onnx::SpeakerEmbeddingExtractor(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)extractor;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
SHERPA_ONNX_LOGE("newFromFile config:\n%s", config.ToString().c_str());
  if (!config.Validate()) {
    SHERPA_ONNX_LOGE("Errors found in config!");
    return 0;
  }
auto extractor = new sherpa_onnx::SpeakerEmbeddingExtractor(config);
return (jlong)extractor;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_delete(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
delete reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_createStream(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
std::unique_ptr<sherpa_onnx::OnlineStream> s =
reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr)
->CreateStream();
  // The user is responsible for freeing the returned pointer.
//
// See Java_com_k2fsa_sherpa_onnx_OnlineStream_delete() from
// ./online-stream.cc
sherpa_onnx::OnlineStream *p = s.release();
return (jlong)p;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_isReady(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jlong stream_ptr) {
auto extractor =
reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
return extractor->IsReady(stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jfloatArray JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_compute(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jlong stream_ptr) {
auto extractor =
reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
std::vector<float> embedding = extractor->Compute(stream);
jfloatArray embedding_arr = env->NewFloatArray(embedding.size());
env->SetFloatArrayRegion(embedding_arr, 0, embedding.size(),
embedding.data());
return embedding_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_dim(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto extractor =
reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
return extractor->Dim();
}
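A minimal Kotlin sketch of the extraction flow these bindings implement; the model path and audio buffer are placeholders:

// Compute one speaker embedding from a stream of audio.
val extractor = SpeakerEmbeddingExtractor(
    config = SpeakerEmbeddingExtractorConfig(model = "model.onnx"),  // placeholder path
)
val stream = extractor.createStream()              // an OnlineStream; caller frees it
stream.acceptWaveform(FloatArray(32000), sampleRate = 16000)  // ~2 s, placeholder
stream.inputFinished()
if (extractor.isReady(stream)) {
    val embedding = extractor.compute(stream)      // FloatArray of size extractor.dim()
    println("embedding dim: ${embedding.size} (expected ${extractor.dim()})")
}
stream.release()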
... ...
// sherpa-onnx/jni/speaker-embedding-manager.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_create(JNIEnv *env,
jobject /*obj*/,
jint dim) {
auto p = new sherpa_onnx::SpeakerEmbeddingManager(dim);
return (jlong)p;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_delete(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
delete manager;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_add(JNIEnv *env,
jobject /*obj*/,
jlong ptr, jstring name,
jfloatArray embedding) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Add(p_name, p);
env->ReleaseStringUTFChars(name, p_name);
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_addList(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
jobjectArray embedding_arr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
int num_embeddings = env->GetArrayLength(embedding_arr);
if (num_embeddings == 0) {
return false;
}
std::vector<std::vector<float>> embedding_list;
embedding_list.reserve(num_embeddings);
for (int32_t i = 0; i != num_embeddings; ++i) {
jfloatArray embedding =
(jfloatArray)env->GetObjectArrayElement(embedding_arr, i);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("i: %d. Expected dim %d, given %d", i, manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
embedding_list.push_back({p, p + n});
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
}
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Add(p_name, embedding_list);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_remove(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jstring name) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Remove(p_name);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_search(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jfloatArray embedding,
jfloat threshold) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
std::string name = manager->Search(p, threshold);
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
return env->NewStringUTF(name.c_str());
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_verify(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
jfloatArray embedding, jfloat threshold) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Verify(p_name, p, threshold);
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_contains(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jstring name) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Contains(p_name);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_numSpeakers(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
return manager->NumSpeakers();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
std::vector<std::string> all_speakers = manager->GetAllSpeakers();
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
all_speakers.size(), env->FindClass("java/lang/String"), nullptr);
int32_t i = 0;
for (auto &s : all_speakers) {
jstring js = env->NewStringUTF(s.c_str());
env->SetObjectArrayElement(obj_arr, i, js);
++i;
}
return obj_arr;
}
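Assuming the Kotlin SpeakerEmbeddingManager wrapper (shown later in this patch) forwards to these bindings one-to-one, typical use looks like the sketch below; the dimension, embedding values, and threshold are placeholders:

// Enroll, look up, and verify speakers; every embedding must have length dim.
val dim = 512                                      // must match the extractor's dim()
val manager = SpeakerEmbeddingManager(dim)
val embedding = FloatArray(dim)                    // placeholder embedding
manager.add("alice", embedding)
val name = manager.search(embedding, threshold = 0.6f)  // assumed empty when no match
val same = manager.verify("alice", embedding, threshold = 0.6f)
println("match=$name verified=$same speakers=${manager.numSpeakers()}")
manager.remove("alice")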
... ...
// sherpa-onnx/jni/voice-activity-detector.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
VadModelConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
// silero_vad
fid = env->GetFieldID(cls, "sileroVadModelConfig",
"Lcom/k2fsa/sherpa/onnx/SileroVadModelConfig;");
jobject silero_vad_config = env->GetObjectField(config, fid);
jclass silero_vad_config_cls = env->GetObjectClass(silero_vad_config);
fid = env->GetFieldID(silero_vad_config_cls, "model", "Ljava/lang/String;");
auto s = (jstring)env->GetObjectField(silero_vad_config, fid);
auto p = env->GetStringUTFChars(s, nullptr);
ans.silero_vad.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(silero_vad_config_cls, "threshold", "F");
ans.silero_vad.threshold = env->GetFloatField(silero_vad_config, fid);
fid = env->GetFieldID(silero_vad_config_cls, "minSilenceDuration", "F");
ans.silero_vad.min_silence_duration =
env->GetFloatField(silero_vad_config, fid);
fid = env->GetFieldID(silero_vad_config_cls, "minSpeechDuration", "F");
ans.silero_vad.min_speech_duration =
env->GetFloatField(silero_vad_config, fid);
fid = env->GetFieldID(silero_vad_config_cls, "windowSize", "I");
ans.silero_vad.window_size = env->GetIntField(silero_vad_config, fid);
fid = env->GetFieldID(cls, "sampleRate", "I");
ans.sample_rate = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "numThreads", "I");
ans.num_threads = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "debug", "Z");
ans.debug = env->GetBooleanField(config, fid);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromAsset(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetVadModelConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::VoiceActivityDetector(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetVadModelConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
return 0;
}
auto model = new sherpa_onnx::VoiceActivityDetector(config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_delete(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
delete reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
model->AcceptWaveform(p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_Vad_empty(JNIEnv *env,
                                                                jobject /*obj*/,
                                                                jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
return model->Empty();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
model->Pop();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
model->Clear();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) {
const auto &front =
reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr)->Front();
jfloatArray samples_arr = env->NewFloatArray(front.samples.size());
env->SetFloatArrayRegion(samples_arr, 0, front.samples.size(),
front.samples.data());
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
2, env->FindClass("java/lang/Object"), nullptr);
env->SetObjectArrayElement(obj_arr, 0, NewInteger(env, front.start));
env->SetObjectArrayElement(obj_arr, 1, samples_arr);
return obj_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_Vad_isSpeechDetected(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
return model->IsSpeechDetected();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_reset(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
model->Reset();
}
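Assuming the Kotlin Vad wrapper forwards to these bindings one-to-one (its front() would then return the Object[2] built above), draining detected speech segments looks roughly like:

// Push audio through the VAD and drain detected speech segments.
vad.acceptWaveform(chunk)                          // chunk: FloatArray of windowSize samples
while (vad.isSpeechDetected() && !vad.empty()) {
    val objArray = vad.front()                     // [0] = start sample index, [1] = samples
    val start = objArray[0] as Int
    val samples = objArray[1] as FloatArray
    vad.pop()                                      // advance to the next segment
    println("segment at $start with ${samples.size} samples")
}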
... ...
... ... @@ -2,8 +2,6 @@ package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
const val TAG = "sherpa-onnx"
data class OfflineZipformerAudioTaggingModelConfig(
var model: String = "",
)
... ...
package com.k2fsa.sherpa.onnx
data class FeatureConfig(
var sampleRate: Int = 16000,
var featureDim: Int = 80,
)
fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
}
... ...
... ... @@ -3,26 +3,6 @@ package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
-data class OnlineTransducerModelConfig(
-    var encoder: String = "",
-    var decoder: String = "",
-    var joiner: String = "",
-)
-data class OnlineModelConfig(
-    var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
-    var tokens: String,
-    var numThreads: Int = 1,
-    var debug: Boolean = false,
-    var provider: String = "cpu",
-    var modelType: String = "",
-)
-data class FeatureConfig(
-    var sampleRate: Int = 16000,
-    var featureDim: Int = 80,
-)
data class KeywordSpotterConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineModelConfig,
... ... @@ -33,17 +13,24 @@ data class KeywordSpotterConfig(
var numTrailingBlanks: Int = 2,
)
-class SherpaOnnxKws(
+data class KeywordSpotterResult(
+    val keyword: String,
+    val tokens: Array<String>,
+    val timestamps: FloatArray,
+    // TODO(fangjun): Add more fields
+)
+class KeywordSpotter(
assetManager: AssetManager? = null,
-    var config: KeywordSpotterConfig,
+    val config: KeywordSpotterConfig,
) {
private val ptr: Long
init {
-        if (assetManager != null) {
-            ptr = new(assetManager, config)
+        ptr = if (assetManager != null) {
+            newFromAsset(assetManager, config)
} else {
-            ptr = newFromFile(config)
+            newFromFile(config)
}
}
... ... @@ -51,20 +38,28 @@ class SherpaOnnxKws(
delete(ptr)
}
-    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
-        acceptWaveform(ptr, samples, sampleRate)
fun release() = finalize()
-    fun inputFinished() = inputFinished(ptr)
-    fun decode() = decode(ptr)
-    fun isReady(): Boolean = isReady(ptr)
-    fun reset(keywords: String): Boolean = reset(ptr, keywords)
+    fun createStream(keywords: String = ""): OnlineStream {
+        val p = createStream(ptr, keywords)
+        return OnlineStream(p)
+    }
+    fun decode(stream: OnlineStream) = decode(ptr, stream.ptr)
+    fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)
+    fun getResult(stream: OnlineStream): KeywordSpotterResult {
+        val objArray = getResult(ptr, stream.ptr)
-    val keyword: String
-        get() = getKeyword(ptr)
+        val keyword = objArray[0] as String
+        val tokens = objArray[1] as Array<String>
+        val timestamps = objArray[2] as FloatArray
+        return KeywordSpotterResult(keyword = keyword, tokens = tokens, timestamps = timestamps)
+    }
private external fun delete(ptr: Long)
-    private external fun new(
+    private external fun newFromAsset(
assetManager: AssetManager,
config: KeywordSpotterConfig,
): Long
... ... @@ -73,12 +68,10 @@ class SherpaOnnxKws(
config: KeywordSpotterConfig,
): Long
-    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
-    private external fun inputFinished(ptr: Long)
-    private external fun getKeyword(ptr: Long): String
-    private external fun reset(ptr: Long, keywords: String): Boolean
-    private external fun decode(ptr: Long)
-    private external fun isReady(ptr: Long): Boolean
+    private external fun createStream(ptr: Long, keywords: String): Long
+    private external fun isReady(ptr: Long, streamPtr: Long): Boolean
+    private external fun decode(ptr: Long, streamPtr: Long)
+    private external fun getResult(ptr: Long, streamPtr: Long): Array<Any>
companion object {
init {
... ... @@ -87,10 +80,6 @@ class SherpaOnnxKws(
}
}
-fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
-    return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
-}
/*
Please see
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
... ... @@ -108,7 +97,7 @@ by following the code)
https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/summary
*/
-fun getModelConfig(type: Int): OnlineModelConfig? {
+fun getKwsModelConfig(type: Int): OnlineModelConfig? {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
... ... @@ -137,15 +126,15 @@ fun getModelConfig(type: Int): OnlineModelConfig? {
}
}
-    return null;
+    return null
}
/*
* Get the default keywords for each model.
 * Caution: The types and modelDir should be the same as those in the
 * getKwsModelConfig function above.
- */
-fun getKeywordsFile(type: Int) : String {
+ */
+fun getKeywordsFile(type: Int): String {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
... ... @@ -158,5 +147,5 @@ fun getKeywordsFile(type: Int) : String {
}
}
return "";
return ""
}
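A usage sketch that wires these helpers into the KeywordSpotter class above; the keywordsFile field name is assumed (the corresponding config fields are elided in this view), and the audio buffer is a placeholder:

// Spot keywords with the first bundled model (type = 0).
val config = KeywordSpotterConfig(
    featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
    modelConfig = getKwsModelConfig(type = 0)!!,
    keywordsFile = getKeywordsFile(type = 0),      // field name assumed
)
val spotter = KeywordSpotter(assetManager = null, config = config)
val stream = spotter.createStream()                // or createStream("custom keywords")
stream.acceptWaveform(FloatArray(1600), sampleRate = 16000)  // placeholder audio
while (spotter.isReady(stream)) {
    spotter.decode(stream)
    val keyword = spotter.getResult(stream).keyword
    if (keyword.isNotEmpty()) println("detected: $keyword")
}
stream.release()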
... ...
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
data class OfflineRecognizerResult(
val text: String,
val tokens: Array<String>,
val timestamps: FloatArray,
)
data class OfflineTransducerModelConfig(
var encoder: String = "",
var decoder: String = "",
var joiner: String = "",
)
data class OfflineParaformerModelConfig(
var model: String = "",
)
data class OfflineWhisperModelConfig(
var encoder: String = "",
var decoder: String = "",
var language: String = "en", // Used with multilingual model
var task: String = "transcribe", // transcribe or translate
var tailPaddings: Int = 1000, // Padding added at the end of the samples
)
data class OfflineModelConfig(
var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(),
var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(),
var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(),
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
var modelType: String = "",
var tokens: String,
)
data class OfflineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OfflineModelConfig,
// var lmConfig: OfflineLMConfig(), // TODO(fangjun): enable it
var decodingMethod: String = "greedy_search",
var maxActivePaths: Int = 4,
var hotwordsFile: String = "",
var hotwordsScore: Float = 1.5f,
)
class OfflineRecognizer(
assetManager: AssetManager? = null,
config: OfflineRecognizerConfig,
) {
private val ptr: Long
init {
ptr = if (assetManager != null) {
newFromAsset(assetManager, config)
} else {
newFromFile(config)
}
}
protected fun finalize() {
delete(ptr)
}
fun release() = finalize()
fun createStream(): OfflineStream {
val p = createStream(ptr)
return OfflineStream(p)
}
fun getResult(stream: OfflineStream): OfflineRecognizerResult {
val objArray = getResult(stream.ptr)
val text = objArray[0] as String
val tokens = objArray[1] as Array<String>
val timestamps = objArray[2] as FloatArray
return OfflineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps)
}
fun decode(stream: OfflineStream) = decode(ptr, stream.ptr)
private external fun delete(ptr: Long)
private external fun createStream(ptr: Long): Long
private external fun newFromAsset(
assetManager: AssetManager,
config: OfflineRecognizerConfig,
): Long
private external fun newFromFile(
config: OfflineRecognizerConfig,
): Long
private external fun decode(ptr: Long, streamPtr: Long)
private external fun getResult(streamPtr: Long): Array<Any>
companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
}
}
}
/*
Please see
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models.
We only add a few here. Please change the following code
to add your own. (It should be straightforward to add a new model
by following the code)
@param type
0 - csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 (Chinese)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese
int8
1 - icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04 (English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#icefall-asr-multidataset-pruned-transducer-stateless7-2023-05-04-english
encoder int8, decoder/joiner float32
2 - sherpa-onnx-whisper-tiny.en
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en
encoder int8, decoder int8
3 - sherpa-onnx-whisper-base.en
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en
encoder int8, decoder int8
4 - pkufool/icefall-asr-zipformer-wenetspeech-20230615 (Chinese)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#pkufool-icefall-asr-zipformer-wenetspeech-20230615-chinese
encoder/joiner int8, decoder fp32
*/
fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-paraformer-zh-2023-03-28"
return OfflineModelConfig(
paraformer = OfflineParaformerModelConfig(
model = "$modelDir/model.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "paraformer",
)
}
1 -> {
val modelDir = "icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04"
return OfflineModelConfig(
transducer = OfflineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-30-avg-4.int8.onnx",
decoder = "$modelDir/decoder-epoch-30-avg-4.onnx",
joiner = "$modelDir/joiner-epoch-30-avg-4.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer",
)
}
2 -> {
val modelDir = "sherpa-onnx-whisper-tiny.en"
return OfflineModelConfig(
whisper = OfflineWhisperModelConfig(
encoder = "$modelDir/tiny.en-encoder.int8.onnx",
decoder = "$modelDir/tiny.en-decoder.int8.onnx",
),
tokens = "$modelDir/tiny.en-tokens.txt",
modelType = "whisper",
)
}
3 -> {
val modelDir = "sherpa-onnx-whisper-base.en"
return OfflineModelConfig(
whisper = OfflineWhisperModelConfig(
encoder = "$modelDir/base.en-encoder.int8.onnx",
decoder = "$modelDir/base.en-decoder.int8.onnx",
),
tokens = "$modelDir/base.en-tokens.txt",
modelType = "whisper",
)
}
4 -> {
val modelDir = "icefall-asr-zipformer-wenetspeech-20230615"
return OfflineModelConfig(
transducer = OfflineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-12-avg-4.int8.onnx",
decoder = "$modelDir/decoder-epoch-12-avg-4.onnx",
joiner = "$modelDir/joiner-epoch-12-avg-4.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer",
)
}
5 -> {
val modelDir = "sherpa-onnx-zipformer-multi-zh-hans-2023-9-2"
return OfflineModelConfig(
transducer = OfflineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-20-avg-1.int8.onnx",
decoder = "$modelDir/decoder-epoch-20-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-20-avg-1.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer2",
)
}
}
return null
}
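A compact sketch of offline decoding with one of these configs; OfflineStream is assumed to expose acceptWaveform and release like its online counterpart, and the audio buffer is a placeholder:

// Decode a whole utterance at once with Whisper tiny.en (type = 2).
val config = OfflineRecognizerConfig(modelConfig = getOfflineModelConfig(type = 2)!!)
val recognizer = OfflineRecognizer(assetManager = null, config = config)
val stream = recognizer.createStream()
stream.acceptWaveform(FloatArray(16000 * 5), sampleRate = 16000)  // placeholder: 5 s
recognizer.decode(stream)
println(recognizer.getResult(stream).text)
stream.release()
recognizer.release()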
... ...
// Copyright (c) 2023 Xiaomi Corporation
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
... ... @@ -46,15 +45,11 @@ data class OnlineLMConfig(
var scale: Float = 0.5f,
)
-data class FeatureConfig(
-    var sampleRate: Int = 16000,
-    var featureDim: Int = 80,
-)
data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineModelConfig,
-    var lmConfig: OnlineLMConfig,
+    var lmConfig: OnlineLMConfig = OnlineLMConfig(),
var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
... ... @@ -63,17 +58,24 @@ data class OnlineRecognizerConfig(
var hotwordsScore: Float = 1.5f,
)
-class SherpaOnnx(
+data class OnlineRecognizerResult(
+    val text: String,
+    val tokens: Array<String>,
+    val timestamps: FloatArray,
+    // TODO(fangjun): Add more fields
+)
+class OnlineRecognizer(
assetManager: AssetManager? = null,
-    var config: OnlineRecognizerConfig,
+    val config: OnlineRecognizerConfig,
) {
private val ptr: Long
init {
-        if (assetManager != null) {
-            ptr = new(assetManager, config)
+        ptr = if (assetManager != null) {
+            newFromAsset(assetManager, config)
} else {
-            ptr = newFromFile(config)
+            newFromFile(config)
}
}
... ... @@ -81,24 +83,30 @@ class SherpaOnnx(
delete(ptr)
}
-    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
-        acceptWaveform(ptr, samples, sampleRate)
fun release() = finalize()
+    fun createStream(hotwords: String = ""): OnlineStream {
+        val p = createStream(ptr, hotwords)
+        return OnlineStream(p)
+    }
-    fun inputFinished() = inputFinished(ptr)
-    fun reset(recreate: Boolean = false, hotwords: String = "") = reset(ptr, recreate, hotwords)
-    fun decode() = decode(ptr)
-    fun isEndpoint(): Boolean = isEndpoint(ptr)
-    fun isReady(): Boolean = isReady(ptr)
+    fun reset(stream: OnlineStream) = reset(ptr, stream.ptr)
+    fun decode(stream: OnlineStream) = decode(ptr, stream.ptr)
+    fun isEndpoint(stream: OnlineStream) = isEndpoint(ptr, stream.ptr)
+    fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)
+    fun getResult(stream: OnlineStream): OnlineRecognizerResult {
+        val objArray = getResult(ptr, stream.ptr)
-    val text: String
-        get() = getText(ptr)
+        val text = objArray[0] as String
+        val tokens = objArray[1] as Array<String>
+        val timestamps = objArray[2] as FloatArray
-    val tokens: Array<String>
-        get() = getTokens(ptr)
+        return OnlineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps)
+    }
private external fun delete(ptr: Long)
-    private external fun new(
+    private external fun newFromAsset(
assetManager: AssetManager,
config: OnlineRecognizerConfig,
): Long
... ... @@ -107,14 +115,12 @@ class SherpaOnnx(
config: OnlineRecognizerConfig,
): Long
-    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
-    private external fun inputFinished(ptr: Long)
-    private external fun getText(ptr: Long): String
-    private external fun reset(ptr: Long, recreate: Boolean, hotwords: String)
-    private external fun decode(ptr: Long)
-    private external fun isEndpoint(ptr: Long): Boolean
-    private external fun isReady(ptr: Long): Boolean
-    private external fun getTokens(ptr: Long): Array<String>
+    private external fun createStream(ptr: Long, hotwords: String): Long
+    private external fun reset(ptr: Long, streamPtr: Long)
+    private external fun decode(ptr: Long, streamPtr: Long)
+    private external fun isEndpoint(ptr: Long, streamPtr: Long): Boolean
+    private external fun isReady(ptr: Long, streamPtr: Long): Boolean
+    private external fun getResult(ptr: Long, streamPtr: Long): Array<Any>
companion object {
init {
... ... @@ -123,9 +129,6 @@ class SherpaOnnx(
}
}
-fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
-    return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
-}
/*
Please see
... ... @@ -277,14 +280,40 @@ fun getModelConfig(type: Int): OnlineModelConfig? {
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer",
)
}
9 -> {
val modelDir = "sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23"
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer",
)
}
10 -> {
val modelDir = "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17"
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer",
)
}
}
-    return null;
+    return null
}
/*
... ... @@ -310,7 +339,7 @@ fun getOnlineLMConfig(type: Int): OnlineLMConfig {
)
}
}
-    return OnlineLMConfig();
+    return OnlineLMConfig()
}
fun getEndpointConfig(): EndpointConfig {
... ... @@ -320,3 +349,4 @@ fun getEndpointConfig(): EndpointConfig {
rule3 = EndpointRule(false, 0.0f, 20.0f)
)
}
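The helpers in this file compose into a complete OnlineRecognizerConfig; a minimal sketch (model type 0 is just an example):

// Wire the helper functions above into one streaming-recognizer config.
val config = OnlineRecognizerConfig(
    featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
    modelConfig = getModelConfig(type = 0)!!,      // null for an unknown type
    lmConfig = getOnlineLMConfig(type = 0),
    endpointConfig = getEndpointConfig(),
    enableEndpoint = true,
)
val recognizer = OnlineRecognizer(assetManager = null, config = config)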
... ...
package com.k2fsa.sherpa.onnx
class OnlineStream(var ptr: Long = 0) {
fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
acceptWaveform(ptr, samples, sampleRate)
fun inputFinished() = inputFinished(ptr)
protected fun finalize() {
if (ptr != 0L) {
delete(ptr)
ptr = 0
}
}
fun release() = finalize()
private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
private external fun inputFinished(ptr: Long)
private external fun delete(ptr: Long)
companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
}
}
}
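Since this class only wraps a caller-owned native pointer, the lifecycle is explicit; a short sketch, assuming a recognizer already exists:

// release() frees the native stream exactly once: finalize() zeroes ptr,
// so calling release() again is a no-op.
val stream = recognizer.createStream()
stream.acceptWaveform(FloatArray(1600), sampleRate = 16000)  // placeholder audio
stream.inputFinished()
stream.release()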
... ...
... ... @@ -3,7 +3,6 @@ package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
import android.util.Log
-private val TAG = "sherpa-onnx"
data class SpeakerEmbeddingExtractorConfig(
val model: String,
var numThreads: Int = 1,
... ... @@ -11,33 +10,6 @@ data class SpeakerEmbeddingExtractorConfig(
var provider: String = "cpu",
)
-class SpeakerEmbeddingExtractorStream(var ptr: Long) {
-    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
-        acceptWaveform(ptr, samples, sampleRate)
-    fun inputFinished() = inputFinished(ptr)
-    protected fun finalize() {
-        delete(ptr)
-        ptr = 0
-    }
-    private external fun myTest(ptr: Long, v: Array<FloatArray>)
-    fun release() = finalize()
-    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
-    private external fun inputFinished(ptr: Long)
-    private external fun delete(ptr: Long)
-    companion object {
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
class SpeakerEmbeddingExtractor(
assetManager: AssetManager? = null,
config: SpeakerEmbeddingExtractorConfig,
... ... @@ -46,29 +18,31 @@ class SpeakerEmbeddingExtractor(
init {
ptr = if (assetManager != null) {
-            new(assetManager, config)
+            newFromAsset(assetManager, config)
} else {
newFromFile(config)
}
}
protected fun finalize() {
-        delete(ptr)
-        ptr = 0
+        if (ptr != 0L) {
+            delete(ptr)
+            ptr = 0
+        }
}
fun release() = finalize()
-    fun createStream(): SpeakerEmbeddingExtractorStream {
+    fun createStream(): OnlineStream {
val p = createStream(ptr)
-        return SpeakerEmbeddingExtractorStream(p)
+        return OnlineStream(p)
}
-    fun isReady(stream: SpeakerEmbeddingExtractorStream) = isReady(ptr, stream.ptr)
-    fun compute(stream: SpeakerEmbeddingExtractorStream) = compute(ptr, stream.ptr)
+    fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)
+    fun compute(stream: OnlineStream) = compute(ptr, stream.ptr)
fun dim() = dim(ptr)
-    private external fun new(
+    private external fun newFromAsset(
assetManager: AssetManager,
config: SpeakerEmbeddingExtractorConfig,
): Long
... ... @@ -98,12 +72,14 @@ class SpeakerEmbeddingManager(val dim: Int) {
private var ptr: Long
init {
-        ptr = new(dim)
+        ptr = create(dim)
}
protected fun finalize() {
-        delete(ptr)
-        ptr = 0
+        if (ptr != 0L) {
+            delete(ptr)
+            ptr = 0
+        }
}
fun release() = finalize()
... ... @@ -119,7 +95,7 @@ class SpeakerEmbeddingManager(val dim: Int) {
fun allSpeakerNames() = allSpeakerNames(ptr)
-    private external fun new(dim: Int): Long
+    private external fun create(dim: Int): Long
private external fun delete(ptr: Long): Unit
private external fun add(ptr: Long, name: String, embedding: FloatArray): Boolean
private external fun addList(ptr: Long, name: String, embedding: Array<FloatArray>): Boolean
... ... @@ -170,7 +146,7 @@ object SpeakerRecognition {
if (_extractor != null) {
return
}
-        Log.i(TAG, "Initializing speaker embedding extractor")
+        Log.i("sherpa-onnx", "Initializing speaker embedding extractor")
_extractor = SpeakerEmbeddingExtractor(
assetManager = assetManager,
... ...
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
import android.util.Log
-private val TAG = "sherpa-onnx"
-data class SpokenLanguageIdentificationWhisperConfig (
+data class SpokenLanguageIdentificationWhisperConfig(
var encoder: String,
var decoder: String,
var tailPaddings: Int = -1,
)
-data class SpokenLanguageIdentificationConfig (
+data class SpokenLanguageIdentificationConfig(
var whisper: SpokenLanguageIdentificationWhisperConfig,
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
)
-class SpokenLanguageIdentification (
+class SpokenLanguageIdentification(
assetManager: AssetManager? = null,
config: SpokenLanguageIdentificationConfig,
) {
... ... @@ -46,7 +43,7 @@ class SpokenLanguageIdentification (
return OfflineStream(p)
}
-  fun compute(stream: OfflineStream) = compute(ptr, stream.ptr)
+    fun compute(stream: OfflineStream) = compute(ptr, stream.ptr)
private external fun newFromAsset(
assetManager: AssetManager,
... ... @@ -69,10 +66,14 @@ class SpokenLanguageIdentification (
}
}
}
// please refer to
// https://k2-fsa.github.io/sherpa/onnx/spolken-language-identification/pretrained_models.html#whisper
// to download more models
-fun getSpokenLanguageIdentificationConfig(type: Int, numThreads: Int=1): SpokenLanguageIdentificationConfig? {
+fun getSpokenLanguageIdentificationConfig(
+    type: Int,
+    numThreads: Int = 1
+): SpokenLanguageIdentificationConfig? {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-whisper-tiny"
... ...
... ... @@ -27,7 +27,7 @@ class Vad(
init {
if (assetManager != null) {
-            ptr = new(assetManager, config)
+            ptr = newFromAsset(assetManager, config)
} else {
ptr = newFromFile(config)
}
... ... @@ -54,7 +54,7 @@ class Vad(
private external fun delete(ptr: Long)
-    private external fun new(
+    private external fun newFromAsset(
assetManager: AssetManager,
config: VadModelConfig,
): Long
... ...