Fangjun Kuang
Committed by GitHub

Refactor the JNI interface to make it more modular and maintainable (#802)

Showing 117 changed files with 2,949 additions and 2,763 deletions.
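In short, the monolithic `SherpaOnnx` / `SherpaOnnxOffline` / `SherpaOnnxKws` wrappers are replaced by per-task classes (`OnlineRecognizer`, `OfflineRecognizer`, `KeywordSpotter`, ...) that live in a shared `sherpa-onnx/kotlin-api` directory and operate on explicit stream objects. A minimal sketch of the new call pattern, pieced together from the diffs below; `assets`, `onlineConfig`, `offlineConfig`, and `samples` are placeholders for the values the demo apps build:

```kotlin
// Streaming (online) recognition: the recognizer owns the model,
// the stream owns the decoding state and must be released explicitly.
val recognizer = OnlineRecognizer(assetManager = assets, config = onlineConfig)
val stream = recognizer.createStream()
stream.acceptWaveform(samples, sampleRate = 16000)
while (recognizer.isReady(stream)) {
    recognizer.decode(stream)
}
val text = recognizer.getResult(stream).text
stream.release()

// Non-streaming (offline) recognition follows the same shape:
val offline = OfflineRecognizer(assetManager = assets, config = offlineConfig)
val offlineStream = offline.createStream()
offlineStream.acceptWaveform(samples, 16000)
offline.decode(offlineStream)
val offlineText = offline.getResult(offlineStream).text
offlineStream.release()
```

Separating the recognizer (model) from the stream (decoding state) lets several streams share one loaded model, but it also means each stream now has to be released explicitly, as the `stream.release()` calls in the diffs below show.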
+name: apk-asr
+
+on:
+  push:
+    tags:
+      - '*'
+
+  workflow_dispatch:
+
+concurrency:
+  group: apk-asr-${{ github.ref }}
+  cancel-in-progress: true
+
+permissions:
+  contents: write
+
+jobs:
+  apk_asr:
+    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
+    runs-on: ${{ matrix.os }}
+    name: apk for asr ${{ matrix.index }}/${{ matrix.total }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+        total: ["1"]
+        index: ["0"]
+
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      # https://github.com/actions/setup-java
+      - uses: actions/setup-java@v4
+        with:
+          distribution: 'temurin' # See 'Supported distributions' for available options
+          java-version: '21'
+
+      - name: ccache
+        uses: hendrikmuhs/ccache-action@v1.2
+        with:
+          key: ${{ matrix.os }}-android
+
+      - name: Display NDK HOME
+        shell: bash
+        run: |
+          echo "ANDROID_NDK_LATEST_HOME: ${ANDROID_NDK_LATEST_HOME}"
+          ls -lh ${ANDROID_NDK_LATEST_HOME}
+
+      - name: Install Python dependencies
+        shell: bash
+        run: |
+          python3 -m pip install --upgrade pip jinja2
+
+      - name: Setup build tool version variable
+        shell: bash
+        run: |
+          echo "---"
+          ls -lh /usr/local/lib/android/
+          echo "---"
+
+          ls -lh /usr/local/lib/android/sdk
+          echo "---"
+
+          ls -lh /usr/local/lib/android/sdk/build-tools
+          echo "---"
+
+          BUILD_TOOL_VERSION=$(ls /usr/local/lib/android/sdk/build-tools/ | tail -n 1)
+          echo "BUILD_TOOL_VERSION=$BUILD_TOOL_VERSION" >> $GITHUB_ENV
+          echo "Last build tool version is: $BUILD_TOOL_VERSION"
+
+      - name: Generate build script
+        shell: bash
+        run: |
+          cd scripts/apk
+
+          total=${{ matrix.total }}
+          index=${{ matrix.index }}
+
+          ./generate-asr-apk-script.py --total $total --index $index
+
+          chmod +x build-apk-asr.sh
+          mv -v ./build-apk-asr.sh ../..
+
+      - name: build APK
+        shell: bash
+        run: |
+          export CMAKE_CXX_COMPILER_LAUNCHER=ccache
+          export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
+          cmake --version
+
+          export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
+          ./build-apk-asr.sh
+
+      - name: Display APK
+        shell: bash
+        run: |
+          ls -lh ./apks/
+          du -h -d1 .
+
+      # https://github.com/marketplace/actions/sign-android-release
+      - uses: r0adkll/sign-android-release@v1
+        name: Sign app APK
+        with:
+          releaseDirectory: ./apks
+          signingKeyBase64: ${{ secrets.ANDROID_SIGNING_KEY }}
+          alias: ${{ secrets.ANDROID_SIGNING_KEY_ALIAS }}
+          keyStorePassword: ${{ secrets.ANDROID_SIGNING_KEY_STORE_PASSWORD }}
+        env:
+          BUILD_TOOLS_VERSION: ${{ env.BUILD_TOOL_VERSION }}
+
+      - name: Display APK after signing
+        shell: bash
+        run: |
+          ls -lh ./apks/
+          du -h -d1 .
+
+      - name: Rename APK after signing
+        shell: bash
+        run: |
+          cd apks
+          rm -fv signingKey.jks
+          rm -fv *.apk.idsig
+          rm -fv *-aligned.apk
+
+          all_apks=$(ls -1 *-signed.apk)
+          echo "----"
+          echo $all_apks
+          echo "----"
+          for apk in ${all_apks[@]}; do
+            n=$(echo $apk | sed -e s/-signed//)
+            mv -v $apk $n
+          done
+
+          cd ..
+
+          ls -lh ./apks/
+          du -h -d1 .
+
+      - name: Display APK after rename
+        shell: bash
+        run: |
+          ls -lh ./apks/
+          du -h -d1 .
+
+      - name: Publish to huggingface
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        uses: nick-fields/retry@v3
+        with:
+          max_attempts: 20
+          timeout_seconds: 200
+          shell: bash
+          command: |
+            git config --global user.email "csukuangfj@gmail.com"
+            git config --global user.name "Fangjun Kuang"
+
+            rm -rf huggingface
+            export GIT_LFS_SKIP_SMUDGE=1
+
+            git clone https://huggingface.co/csukuangfj/sherpa-onnx-apk huggingface
+            cd huggingface
+            git fetch
+            git pull
+            git merge -m "merge remote" --ff origin main
+
+            mkdir -p asr
+            cp -v ../apks/*.apk ./asr/
+            git status
+            git lfs track "*.apk"
+            git add .
+            git commit -m "add more apks"
+            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-apk main
@@ -95,3 +95,4 @@ sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
 spoken-language-identification-test-wavs
 my-release-key*
 vits-zh-hf-fanchen-C
+sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
@@ -16,6 +16,7 @@
         tools:targetApi="31">
         <activity
             android:name=".MainActivity"
+            android:label="ASR: Next-gen Kaldi"
            android:exported="true">
            <intent-filter>
                <action android:name="android.intent.action.MAIN" />
+../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt
@@ -12,16 +12,19 @@ import android.widget.Button
 import android.widget.TextView
 import androidx.appcompat.app.AppCompatActivity
 import androidx.core.app.ActivityCompat
-import com.k2fsa.sherpa.onnx.*
 import kotlin.concurrent.thread

 private const val TAG = "sherpa-onnx"
 private const val REQUEST_RECORD_AUDIO_PERMISSION = 200

+// To enable the microphone in the Android emulator, use
+//
+//   adb emu avd hostmicon
+
 class MainActivity : AppCompatActivity() {
     private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)

-    private lateinit var model: SherpaOnnx
+    private lateinit var recognizer: OnlineRecognizer
     private var audioRecord: AudioRecord? = null
     private lateinit var recordButton: Button
     private lateinit var textView: TextView
@@ -87,7 +90,6 @@ class MainActivity : AppCompatActivity() {
         audioRecord!!.startRecording()
         recordButton.setText(R.string.stop)
         isRecording = true
-        model.reset(true)
         textView.text = ""
         lastText = ""
         idx = 0
@@ -108,6 +110,7 @@ class MainActivity : AppCompatActivity() {

     private fun processSamples() {
         Log.i(TAG, "processing samples")
+        val stream = recognizer.createStream()

         val interval = 0.1 // i.e., 100 ms
         val bufferSize = (interval * sampleRateInHz).toInt() // in samples
@@ -117,29 +120,41 @@ class MainActivity : AppCompatActivity() {
             val ret = audioRecord?.read(buffer, 0, buffer.size)
             if (ret != null && ret > 0) {
                 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
-                model.acceptWaveform(samples, sampleRate=sampleRateInHz)
-                while (model.isReady()) {
-                    model.decode()
+                stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
+                while (recognizer.isReady(stream)) {
+                    recognizer.decode(stream)
                 }

-                val isEndpoint = model.isEndpoint()
-                val text = model.text
+                val isEndpoint = recognizer.isEndpoint(stream)
+                var text = recognizer.getResult(stream).text

-                var textToDisplay = lastText;
+                // For streaming paraformer, we need to manually add some
+                // padding so that it has enough right context to
+                // recognize the last word of this segment
+                if (isEndpoint && recognizer.config.modelConfig.paraformer.encoder.isNotBlank()) {
+                    val tailPaddings = FloatArray((0.8 * sampleRateInHz).toInt())
+                    stream.acceptWaveform(tailPaddings, sampleRate = sampleRateInHz)
+                    while (recognizer.isReady(stream)) {
+                        recognizer.decode(stream)
+                    }
+                    text = recognizer.getResult(stream).text
+                }

-                if(text.isNotBlank()) {
-                    if (lastText.isBlank()) {
-                        textToDisplay = "${idx}: ${text}"
+                var textToDisplay = lastText
+
+                if (text.isNotBlank()) {
+                    textToDisplay = if (lastText.isBlank()) {
+                        "${idx}: $text"
                     } else {
-                        textToDisplay = "${lastText}\n${idx}: ${text}"
+                        "${lastText}\n${idx}: $text"
                     }
                 }

                 if (isEndpoint) {
-                    model.reset()
+                    recognizer.reset(stream)
                     if (text.isNotBlank()) {
-                        lastText = "${lastText}\n${idx}: ${text}"
-                        textToDisplay = lastText;
+                        lastText = "${lastText}\n${idx}: $text"
+                        textToDisplay = lastText
                         idx += 1
                     }
                 }
@@ -149,6 +164,7 @@ class MainActivity : AppCompatActivity() {
                 }
             }
         }
+        stream.release()
     }

     private fun initMicrophone(): Boolean {
@@ -180,7 +196,7 @@ class MainActivity : AppCompatActivity() {
         // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
         // for a list of available models
         val type = 0
-        println("Select model type ${type}")
+        Log.i(TAG, "Select model type $type")
         val config = OnlineRecognizerConfig(
             featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
             modelConfig = getModelConfig(type = type)!!,
@@ -189,7 +205,7 @@ class MainActivity : AppCompatActivity() {
             enableEndpoint = true,
         )

-        model = SherpaOnnx(
+        recognizer = OnlineRecognizer(
             assetManager = application.assets,
             config = config,
         )
+../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineRecognizer.kt
+../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineStream.kt
+../../../../../../../../../../sherpa-onnx/kotlin-api/WaveReader.kt
@@ -16,6 +16,7 @@
         tools:targetApi="31">
         <activity
             android:name=".MainActivity"
+            android:label="2pass ASR: Next-gen Kaldi"
            android:exported="true">
            <intent-filter>
                <action android:name="android.intent.action.MAIN" />
+../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt
@@ -17,11 +17,13 @@ import kotlin.concurrent.thread
 private const val TAG = "sherpa-onnx"
 private const val REQUEST_RECORD_AUDIO_PERMISSION = 200

+// adb emu avd hostmicon
+// enables the microphone inside the emulator
 class MainActivity : AppCompatActivity() {
     private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)

-    private lateinit var onlineRecognizer: SherpaOnnx
-    private lateinit var offlineRecognizer: SherpaOnnxOffline
+    private lateinit var onlineRecognizer: OnlineRecognizer
+    private lateinit var offlineRecognizer: OfflineRecognizer
     private var audioRecord: AudioRecord? = null
     private lateinit var recordButton: Button
     private lateinit var textView: TextView
@@ -93,7 +95,6 @@ class MainActivity : AppCompatActivity() {
         audioRecord!!.startRecording()
         recordButton.setText(R.string.stop)
         isRecording = true
-        onlineRecognizer.reset(true)
         samplesBuffer.clear()
         textView.text = ""
         lastText = ""
@@ -115,6 +116,7 @@ class MainActivity : AppCompatActivity() {

     private fun processSamples() {
         Log.i(TAG, "processing samples")
+        val stream = onlineRecognizer.createStream()

         val interval = 0.1 // i.e., 100 ms
         val bufferSize = (interval * sampleRateInHz).toInt() // in samples
@@ -126,29 +128,29 @@ class MainActivity : AppCompatActivity() {
                 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
                 samplesBuffer.add(samples)

-                onlineRecognizer.acceptWaveform(samples, sampleRate = sampleRateInHz)
-                while (onlineRecognizer.isReady()) {
-                    onlineRecognizer.decode()
+                stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
+                while (onlineRecognizer.isReady(stream)) {
+                    onlineRecognizer.decode(stream)
                 }
-                val isEndpoint = onlineRecognizer.isEndpoint()
+                val isEndpoint = onlineRecognizer.isEndpoint(stream)
                 var textToDisplay = lastText

-                var text = onlineRecognizer.text
+                var text = onlineRecognizer.getResult(stream).text
                 if (text.isNotBlank()) {
-                    if (lastText.isBlank()) {
+                    textToDisplay = if (lastText.isBlank()) {
                         // textView.text = "${idx}: ${text}"
-                        textToDisplay = "${idx}: ${text}"
+                        "${idx}: $text"
                     } else {
-                        textToDisplay = "${lastText}\n${idx}: ${text}"
+                        "${lastText}\n${idx}: $text"
                     }
                 }

                 if (isEndpoint) {
-                    onlineRecognizer.reset()
+                    onlineRecognizer.reset(stream)

                     if (text.isNotBlank()) {
                         text = runSecondPass()
-                        lastText = "${lastText}\n${idx}: ${text}"
+                        lastText = "${lastText}\n${idx}: $text"
                         idx += 1
                     } else {
                         samplesBuffer.clear()
@@ -160,6 +162,7 @@ class MainActivity : AppCompatActivity() {
                 }
             }
         }
+        stream.release()
     }

     private fun initMicrophone(): Boolean {
@@ -190,8 +193,8 @@ class MainActivity : AppCompatActivity() {
         // Please change getModelConfig() to add new models
         // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
         // for a list of available models
-        val firstType = 1
-        println("Select model type ${firstType} for the first pass")
+        val firstType = 9
+        Log.i(TAG, "Select model type $firstType for the first pass")
         val config = OnlineRecognizerConfig(
             featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
             modelConfig = getModelConfig(type = firstType)!!,
@@ -199,7 +202,7 @@ class MainActivity : AppCompatActivity() {
             enableEndpoint = true,
         )

-        onlineRecognizer = SherpaOnnx(
+        onlineRecognizer = OnlineRecognizer(
             assetManager = application.assets,
             config = config,
         )
@@ -209,15 +212,15 @@ class MainActivity : AppCompatActivity() {
         // Please change getOfflineModelConfig() to add new models
         // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
         // for a list of available models
-        val secondType = 1
-        println("Select model type ${secondType} for the second pass")
+        val secondType = 0
+        Log.i(TAG, "Select model type $secondType for the second pass")

         val config = OfflineRecognizerConfig(
             featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
             modelConfig = getOfflineModelConfig(type = secondType)!!,
         )

-        offlineRecognizer = SherpaOnnxOffline(
+        offlineRecognizer = OfflineRecognizer(
             assetManager = application.assets,
             config = config,
         )
@@ -244,8 +247,15 @@ class MainActivity : AppCompatActivity() {
         val n = maxOf(0, samples.size - 8000)

         samplesBuffer.clear()
-        samplesBuffer.add(samples.sliceArray(n..samples.size-1))
+        samplesBuffer.add(samples.sliceArray(n until samples.size))

-        return offlineRecognizer.decode(samples.sliceArray(0..n), sampleRateInHz)
+        val stream = offlineRecognizer.createStream()
+        stream.acceptWaveform(samples.sliceArray(0..n), sampleRateInHz)
+        offlineRecognizer.decode(stream)
+        val result = offlineRecognizer.getResult(stream)
+
+        stream.release()
+
+        return result.text
     }
 }
+../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineRecognizer.kt
+../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt
+../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineRecognizer.kt
+../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineStream.kt
-package com.k2fsa.sherpa.onnx
-
-import android.content.res.AssetManager
-
-data class EndpointRule(
-    var mustContainNonSilence: Boolean,
-    var minTrailingSilence: Float,
-    var minUtteranceLength: Float,
-)
-
-data class EndpointConfig(
-    var rule1: EndpointRule = EndpointRule(false, 2.0f, 0.0f),
-    var rule2: EndpointRule = EndpointRule(true, 1.2f, 0.0f),
-    var rule3: EndpointRule = EndpointRule(false, 0.0f, 20.0f)
-)
-
-data class OnlineTransducerModelConfig(
-    var encoder: String = "",
-    var decoder: String = "",
-    var joiner: String = "",
-)
-
-data class OnlineParaformerModelConfig(
-    var encoder: String = "",
-    var decoder: String = "",
-)
-
-data class OnlineZipformer2CtcModelConfig(
-    var model: String = "",
-)
-
-data class OnlineModelConfig(
-    var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
-    var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
-    var zipformer2Ctc: OnlineZipformer2CtcModelConfig = OnlineZipformer2CtcModelConfig(),
-    var tokens: String,
-    var numThreads: Int = 1,
-    var debug: Boolean = false,
-    var provider: String = "cpu",
-    var modelType: String = "",
-)
-
-data class OnlineLMConfig(
-    var model: String = "",
-    var scale: Float = 0.5f,
-)
-
-data class FeatureConfig(
-    var sampleRate: Int = 16000,
-    var featureDim: Int = 80,
-)
-
-data class OnlineRecognizerConfig(
-    var featConfig: FeatureConfig = FeatureConfig(),
-    var modelConfig: OnlineModelConfig,
-    var lmConfig: OnlineLMConfig = OnlineLMConfig(),
-    var endpointConfig: EndpointConfig = EndpointConfig(),
-    var enableEndpoint: Boolean = true,
-    var decodingMethod: String = "greedy_search",
-    var maxActivePaths: Int = 4,
-    var hotwordsFile: String = "",
-    var hotwordsScore: Float = 1.5f,
-)
-
-data class OfflineTransducerModelConfig(
-    var encoder: String = "",
-    var decoder: String = "",
-    var joiner: String = "",
-)
-
-data class OfflineParaformerModelConfig(
-    var model: String = "",
-)
-
-data class OfflineWhisperModelConfig(
-    var encoder: String = "",
-    var decoder: String = "",
-    var language: String = "en", // Used with multilingual model
-    var task: String = "transcribe", // transcribe or translate
-    var tailPaddings: Int = 1000, // Padding added at the end of the samples
-)
-
-data class OfflineModelConfig(
-    var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(),
-    var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(),
-    var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(),
-    var numThreads: Int = 1,
-    var debug: Boolean = false,
-    var provider: String = "cpu",
-    var modelType: String = "",
-    var tokens: String,
-)
-
-data class OfflineRecognizerConfig(
-    var featConfig: FeatureConfig = FeatureConfig(),
-    var modelConfig: OfflineModelConfig,
-    // var lmConfig: OfflineLMConfig(), // TODO(fangjun): enable it
-    var decodingMethod: String = "greedy_search",
-    var maxActivePaths: Int = 4,
-    var hotwordsFile: String = "",
-    var hotwordsScore: Float = 1.5f,
-)
-
-class SherpaOnnx(
-    assetManager: AssetManager? = null,
-    var config: OnlineRecognizerConfig,
-) {
-    private val ptr: Long
-
-    init {
-        if (assetManager != null) {
-            ptr = new(assetManager, config)
-        } else {
-            ptr = newFromFile(config)
-        }
-    }
-
-    protected fun finalize() {
-        delete(ptr)
-    }
-
-    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
-        acceptWaveform(ptr, samples, sampleRate)
-
-    fun inputFinished() = inputFinished(ptr)
-    fun reset(recreate: Boolean = false, hotwords: String = "") = reset(ptr, recreate, hotwords)
-    fun decode() = decode(ptr)
-    fun isEndpoint(): Boolean = isEndpoint(ptr)
-    fun isReady(): Boolean = isReady(ptr)
-
-    val text: String
-        get() = getText(ptr)
-
-    val tokens: Array<String>
-        get() = getTokens(ptr)
-
-    private external fun delete(ptr: Long)
-
-    private external fun new(
-        assetManager: AssetManager,
-        config: OnlineRecognizerConfig,
-    ): Long
-
-    private external fun newFromFile(
-        config: OnlineRecognizerConfig,
-    ): Long
-
-    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
-    private external fun inputFinished(ptr: Long)
-    private external fun getText(ptr: Long): String
-    private external fun reset(ptr: Long, recreate: Boolean, hotwords: String)
-    private external fun decode(ptr: Long)
-    private external fun isEndpoint(ptr: Long): Boolean
-    private external fun isReady(ptr: Long): Boolean
-    private external fun getTokens(ptr: Long): Array<String>
-
-    companion object {
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
-
-class SherpaOnnxOffline(
-    assetManager: AssetManager? = null,
-    var config: OfflineRecognizerConfig,
-) {
-    private val ptr: Long
-
-    init {
-        if (assetManager != null) {
-            ptr = new(assetManager, config)
-        } else {
-            ptr = newFromFile(config)
-        }
-    }
-
-    protected fun finalize() {
-        delete(ptr)
-    }
-
-    fun decode(samples: FloatArray, sampleRate: Int) = decode(ptr, samples, sampleRate)
-
-    private external fun delete(ptr: Long)
-
-    private external fun new(
-        assetManager: AssetManager,
-        config: OfflineRecognizerConfig,
-    ): Long
-
-    private external fun newFromFile(
-        config: OfflineRecognizerConfig,
-    ): Long
-
-    private external fun decode(ptr: Long, samples: FloatArray, sampleRate: Int): String
-
-    companion object {
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
-
-fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
-    return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
-}
-
-/*
-Please see
-https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
-for a list of pre-trained models.
-
-We only add a few here. Please change the following code
-to add your own. (It should be straightforward to add a new model
-by following the code)
-
-@param type
-0 - csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23 (Chinese)
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-zh-14m-2023-02-23
-    encoder/joiner int8, decoder float32
-
-1 - csukuangfj/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17 (English)
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-en-20m-2023-02-17-english
-    encoder/joiner int8, decoder fp32
-
- */
-fun getModelConfig(type: Int): OnlineModelConfig? {
-    when (type) {
-        0 -> {
-            val modelDir = "sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23"
-            return OnlineModelConfig(
-                transducer = OnlineTransducerModelConfig(
-                    encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
-                    decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
-                    joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
-                ),
-                tokens = "$modelDir/tokens.txt",
-                modelType = "zipformer",
-            )
-        }
-
-        1 -> {
-            val modelDir = "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17"
-            return OnlineModelConfig(
-                transducer = OnlineTransducerModelConfig(
-                    encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
-                    decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
-                    joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
-                ),
-                tokens = "$modelDir/tokens.txt",
-                modelType = "zipformer",
-            )
-        }
-    }
-    return null
-}
-
-/*
-Please see
-https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
-for a list of pre-trained models.
-
-We only add a few here. Please change the following code
-to add your own LM model. (It should be straightforward to train a new NN LM model
-by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py)
-
-@param type
-0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
- */
-fun getOnlineLMConfig(type: Int): OnlineLMConfig {
-    when (type) {
-        0 -> {
-            val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
-            return OnlineLMConfig(
-                model = "$modelDir/with-state-epoch-99-avg-1.int8.onnx",
-                scale = 0.5f,
-            )
-        }
-    }
-    return OnlineLMConfig()
-}
-
-// for English models, use a small value for rule2.minTrailingSilence, e.g., 0.8
-fun getEndpointConfig(): EndpointConfig {
-    return EndpointConfig(
-        rule1 = EndpointRule(false, 2.4f, 0.0f),
-        rule2 = EndpointRule(true, 0.8f, 0.0f),
-        rule3 = EndpointRule(false, 0.0f, 20.0f)
-    )
-}
-
-/*
-Please see
-https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
-for a list of pre-trained models.
-
-We only add a few here. Please change the following code
-to add your own. (It should be straightforward to add a new model
-by following the code)
-
-@param type
-
-0 - csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 (Chinese)
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese
-    int8
-
-1 - icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04 (English)
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#icefall-asr-multidataset-pruned-transducer-stateless7-2023-05-04-english
-    encoder int8, decoder/joiner float32
-
-2 - sherpa-onnx-whisper-tiny.en
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en
-    encoder int8, decoder int8
-
-3 - sherpa-onnx-whisper-base.en
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en
-    encoder int8, decoder int8
-
-4 - pkufool/icefall-asr-zipformer-wenetspeech-20230615 (Chinese)
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#pkufool-icefall-asr-zipformer-wenetspeech-20230615-chinese
-    encoder/joiner int8, decoder fp32
-
- */
-fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
-    when (type) {
-        0 -> {
-            val modelDir = "sherpa-onnx-paraformer-zh-2023-03-28"
-            return OfflineModelConfig(
-                paraformer = OfflineParaformerModelConfig(
-                    model = "$modelDir/model.int8.onnx",
-                ),
-                tokens = "$modelDir/tokens.txt",
-                modelType = "paraformer",
-            )
-        }
-
-        1 -> {
-            val modelDir = "icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04"
-            return OfflineModelConfig(
-                transducer = OfflineTransducerModelConfig(
-                    encoder = "$modelDir/encoder-epoch-30-avg-4.int8.onnx",
-                    decoder = "$modelDir/decoder-epoch-30-avg-4.onnx",
-                    joiner = "$modelDir/joiner-epoch-30-avg-4.onnx",
-                ),
-                tokens = "$modelDir/tokens.txt",
-                modelType = "zipformer",
-            )
-        }
-
-        2 -> {
-            val modelDir = "sherpa-onnx-whisper-tiny.en"
-            return OfflineModelConfig(
-                whisper = OfflineWhisperModelConfig(
-                    encoder = "$modelDir/tiny.en-encoder.int8.onnx",
-                    decoder = "$modelDir/tiny.en-decoder.int8.onnx",
-                ),
-                tokens = "$modelDir/tiny.en-tokens.txt",
-                modelType = "whisper",
-            )
-        }
-
-        3 -> {
-            val modelDir = "sherpa-onnx-whisper-base.en"
-            return OfflineModelConfig(
-                whisper = OfflineWhisperModelConfig(
-                    encoder = "$modelDir/base.en-encoder.int8.onnx",
-                    decoder = "$modelDir/base.en-decoder.int8.onnx",
-                ),
-                tokens = "$modelDir/base.en-tokens.txt",
-                modelType = "whisper",
-            )
-        }
-
-        4 -> {
-            val modelDir = "icefall-asr-zipformer-wenetspeech-20230615"
-            return OfflineModelConfig(
-                transducer = OfflineTransducerModelConfig(
-                    encoder = "$modelDir/encoder-epoch-12-avg-4.int8.onnx",
-                    decoder = "$modelDir/decoder-epoch-12-avg-4.onnx",
-                    joiner = "$modelDir/joiner-epoch-12-avg-4.int8.onnx",
-                ),
-                tokens = "$modelDir/tokens.txt",
-                modelType = "zipformer",
-            )
-        }
-
-        5 -> {
-            val modelDir = "sherpa-onnx-zipformer-multi-zh-hans-2023-9-2"
-            return OfflineModelConfig(
-                transducer = OfflineTransducerModelConfig(
-                    encoder = "$modelDir/encoder-epoch-20-avg-1.int8.onnx",
-                    decoder = "$modelDir/decoder-epoch-20-avg-1.onnx",
-                    joiner = "$modelDir/joiner-epoch-20-avg-1.int8.onnx",
-                ),
-                tokens = "$modelDir/tokens.txt",
-                modelType = "zipformer2",
-            )
-        }
-
-    }
-    return null
-}
-package com.k2fsa.sherpa.onnx
-
-import android.content.res.AssetManager
-
-class WaveReader {
-    companion object {
-        // Read a mono wave file asset
-        // The returned array has two entries:
-        //  - the first entry contains an 1-D float array
-        //  - the second entry is the sample rate
-        external fun readWaveFromAsset(
-            assetManager: AssetManager,
-            filename: String,
-        ): Array<Any>
-
-        // Read a mono wave file from disk
-        // The returned array has two entries:
-        //  - the first entry contains an 1-D float array
-        //  - the second entry is the sample rate
-        external fun readWaveFromFile(
-            filename: String,
-        ): Array<Any>
-
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
+../../../../../../../../../../../../sherpa-onnx/kotlin-api/AudioTagging.kt
@@ -46,7 +46,6 @@ import androidx.compose.ui.unit.dp
 import androidx.compose.ui.unit.sp
 import androidx.core.app.ActivityCompat
 import com.k2fsa.sherpa.onnx.AudioEvent
-import com.k2fsa.sherpa.onnx.Tagger
 import kotlin.concurrent.thread


@@ -13,13 +13,14 @@ import androidx.compose.material3.Surface
 import androidx.compose.runtime.Composable
 import androidx.compose.ui.Modifier
 import androidx.core.app.ActivityCompat
-import com.k2fsa.sherpa.onnx.Tagger
 import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme

 const val TAG = "sherpa-onnx"

 private const val REQUEST_RECORD_AUDIO_PERMISSION = 200

+// adb emu avd hostmicon
+// enables the microphone inside the emulator
 class MainActivity : ComponentActivity() {
     private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
     override fun onCreate(savedInstanceState: Bundle?) {
+../../../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt
@@ -1,7 +1,9 @@
-package com.k2fsa.sherpa.onnx
+package com.k2fsa.sherpa.onnx.audio.tagging

 import android.content.res.AssetManager
 import android.util.Log
+import com.k2fsa.sherpa.onnx.AudioTagging
+import com.k2fsa.sherpa.onnx.getAudioTaggingConfig


 object Tagger {
@@ -17,7 +19,7 @@ object Tagger {
            return
        }

-        Log.i(TAG, "Initializing audio tagger")
+        Log.i("sherpa-onnx", "Initializing audio tagger")
        val config = getAudioTaggingConfig(type = 0, numThreads = numThreads)!!
        _tagger = AudioTagging(assetManager, config)
    }
@@ -33,7 +33,7 @@ import androidx.wear.compose.material.Button
 import androidx.wear.compose.material.MaterialTheme
 import androidx.wear.compose.material.Text
 import com.k2fsa.sherpa.onnx.AudioEvent
-import com.k2fsa.sherpa.onnx.Tagger
+import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
 import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.theme.SherpaOnnxAudioTaggingWearOsTheme
 import kotlin.concurrent.thread

@@ -17,11 +17,14 @@ import androidx.activity.compose.setContent
 import androidx.compose.runtime.Composable
 import androidx.core.app.ActivityCompat
 import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
-import com.k2fsa.sherpa.onnx.Tagger
+import com.k2fsa.sherpa.onnx.audio.tagging.Tagger

 const val TAG = "sherpa-onnx"
 private const val REQUEST_RECORD_AUDIO_PERMISSION = 200

+// adb emu avd hostmicon
+// enables the microphone inside the emulator
+
 class MainActivity : ComponentActivity() {
     private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
     override fun onCreate(savedInstanceState: Bundle?) {
@@ -15,7 +15,8 @@
         android:theme="@style/Theme.SherpaOnnx"
         tools:targetApi="31">
         <activity
-            android:name=".MainActivity"
+            android:name=".kws.MainActivity"
+            android:label="Keyword-spotter"
            android:exported="true">
            <intent-filter>
                <action android:name="android.intent.action.MAIN" />
+../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt
+../../../../../../../../../../sherpa-onnx/kotlin-api/KeywordSpotter.kt
@@ -1,4 +1,4 @@
-package com.k2fsa.sherpa.onnx
+package com.k2fsa.sherpa.onnx.kws

 import android.Manifest
 import android.content.pm.PackageManager
@@ -14,7 +14,13 @@ import android.widget.TextView
 import android.widget.Toast
 import androidx.appcompat.app.AppCompatActivity
 import androidx.core.app.ActivityCompat
-import com.k2fsa.sherpa.onnx.*
+import com.k2fsa.sherpa.onnx.KeywordSpotter
+import com.k2fsa.sherpa.onnx.KeywordSpotterConfig
+import com.k2fsa.sherpa.onnx.OnlineStream
+import com.k2fsa.sherpa.onnx.R
+import com.k2fsa.sherpa.onnx.getFeatureConfig
+import com.k2fsa.sherpa.onnx.getKeywordsFile
+import com.k2fsa.sherpa.onnx.getKwsModelConfig
 import kotlin.concurrent.thread

 private const val TAG = "sherpa-onnx"
@@ -23,7 +29,8 @@ private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
 class MainActivity : AppCompatActivity() {
     private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)

-    private lateinit var model: SherpaOnnxKws
+    private lateinit var kws: KeywordSpotter
+    private lateinit var stream: OnlineStream
     private var audioRecord: AudioRecord? = null
     private lateinit var recordButton: Button
     private lateinit var textView: TextView
@@ -87,15 +94,18 @@ class MainActivity : AppCompatActivity() {

         Log.i(TAG, keywords)
         keywords = keywords.replace("\n", "/")
+        keywords = keywords.trim()
         // If keywords is an empty string, it just resets the decoding stream
         // always returns true in this case.
         // If keywords is not empty, it will create a new decoding stream with
         // the given keywords appended to the default keywords.
-        // Return false if errors occured when adding keywords, true otherwise.
-        val status = model.reset(keywords)
-        if (!status) {
-            Log.i(TAG, "Failed to reset with keywords.")
-            Toast.makeText(this, "Failed to set keywords.", Toast.LENGTH_LONG).show();
+        // Return false if errors occurred when adding keywords, true otherwise.
+        stream.release()
+        stream = kws.createStream(keywords)
+        if (stream.ptr == 0L) {
+            Log.i(TAG, "Failed to create stream with keywords: $keywords")
+            Toast.makeText(this, "Failed to set keywords to $keywords.", Toast.LENGTH_LONG)
+                .show()
             return
         }

@@ -122,6 +132,7 @@ class MainActivity : AppCompatActivity() {
             audioRecord!!.release()
             audioRecord = null
             recordButton.setText(R.string.start)
+            stream.release()
             Log.i(TAG, "Stopped recording")
         }
     }
@@ -137,22 +148,22 @@ class MainActivity : AppCompatActivity() {
             val ret = audioRecord?.read(buffer, 0, buffer.size)
             if (ret != null && ret > 0) {
                 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
-                model.acceptWaveform(samples, sampleRate=sampleRateInHz)
-                while (model.isReady()) {
-                    model.decode()
+                stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
+                while (kws.isReady(stream)) {
+                    kws.decode(stream)
                 }

-                val text = model.keyword
+                val text = kws.getResult(stream).keyword

-                var textToDisplay = lastText;
+                var textToDisplay = lastText

-                if(text.isNotBlank()) {
+                if (text.isNotBlank()) {
                     if (lastText.isBlank()) {
-                        textToDisplay = "${idx}: ${text}"
+                        textToDisplay = "$idx: $text"
                     } else {
-                        textToDisplay = "${idx}: ${text}\n${lastText}"
+                        textToDisplay = "$idx: $text\n$lastText"
                     }
-                    lastText = "${idx}: ${text}\n${lastText}"
+                    lastText = "$idx: $text\n$lastText"
                     idx += 1
                 }

@@ -188,20 +199,21 @@ class MainActivity : AppCompatActivity() {
     }

     private fun initModel() {
-        // Please change getModelConfig() to add new models
+        // Please change getKwsModelConfig() to add new models
         // See https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
         // for a list of available models
         val type = 0
-        Log.i(TAG, "Select model type ${type}")
+        Log.i(TAG, "Select model type $type")
         val config = KeywordSpotterConfig(
             featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
-            modelConfig = getModelConfig(type = type)!!,
-            keywordsFile = getKeywordsFile(type = type)!!,
+            modelConfig = getKwsModelConfig(type = type)!!,
+            keywordsFile = getKeywordsFile(type = type),
         )

-        model = SherpaOnnxKws(
+        kws = KeywordSpotter(
             assetManager = application.assets,
             config = config,
         )
+        stream = kws.createStream()
     }
 }
+../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineRecognizer.kt
+../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineStream.kt
-// Copyright (c) 2023 Xiaomi Corporation
-package com.k2fsa.sherpa.onnx
-
-import android.content.res.AssetManager
-
-class WaveReader {
-    companion object {
-        // Read a mono wave file asset
-        // The returned array has two entries:
-        //  - the first entry contains an 1-D float array
-        //  - the second entry is the sample rate
-        external fun readWaveFromAsset(
-            assetManager: AssetManager,
-            filename: String,
-        ): Array<Any>
-
-        // Read a mono wave file from disk
-        // The returned array has two entries:
-        //  - the first entry contains an 1-D float array
-        //  - the second entry is the sample rate
-        external fun readWaveFromFile(
-            filename: String,
-        ): Array<Any>
-
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
@@ -1,12 +1,12 @@
 <resources>
-    <string name="app_name">KWS with Next-gen Kaldi</string>
+    <string name="app_name">Keyword spotting</string>
     <string name="hint">Click the Start button to play keyword spotting with Next-gen Kaldi.
        \n
        \n\n\n
        The source code and pre-trained models are publicly available.
        Please see https://github.com/k2-fsa/sherpa-onnx for details.
    </string>
-    <string name="keyword_hint">Input your keywords here, one keyword perline.</string>
+    <string name="keyword_hint">Input your keywords here, one keyword per line.\nTwo example keywords are given below:\n\nn ǐ h ǎo @你好\nd àn g ē d àn g ē @蛋哥蛋哥</string>
    <string name="start">Start</string>
    <string name="stop">Stop</string>
 </resources>
@@ -2,7 +2,7 @@ package com.k2fsa.sherpa.onnx.speaker.identification

 import androidx.compose.ui.graphics.vector.ImageVector

-data class BarItem (
+data class BarItem(
     val title: String,

     // see https://www.composables.com/icons
@@ -1,8 +1,8 @@
 package com.k2fsa.sherpa.onnx.speaker.identification

 sealed class NavRoutes(val route: String) {
-    object Home: NavRoutes("home")
-    object Register: NavRoutes("register")
-    object View: NavRoutes("view")
-    object Help: NavRoutes("help")
+    object Home : NavRoutes("home")
+    object Register : NavRoutes("register")
+    object View : NavRoutes("view")
+    object Help : NavRoutes("help")
 }
@@ -1,4 +1,4 @@
-@file:OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
+@file:OptIn(ExperimentalMaterial3Api::class)

 package com.k2fsa.sherpa.onnx.slid

@@ -9,11 +9,9 @@ import android.media.AudioFormat
 import android.media.AudioRecord
 import android.media.MediaRecorder
 import android.util.Log
-import androidx.compose.foundation.ExperimentalFoundationApi
 import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.Column
 import androidx.compose.foundation.layout.PaddingValues
-import androidx.compose.ui.Modifier
 import androidx.compose.foundation.layout.Spacer
 import androidx.compose.foundation.layout.fillMaxSize
 import androidx.compose.foundation.layout.height
@@ -31,6 +29,7 @@ import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.remember
 import androidx.compose.runtime.setValue
 import androidx.compose.ui.Alignment
+import androidx.compose.ui.Modifier
 import androidx.compose.ui.platform.LocalContext
 import androidx.compose.ui.text.font.FontWeight
 import androidx.compose.ui.unit.dp
@@ -63,13 +62,13 @@ fun Home() {
 }

 private var audioRecord: AudioRecord? = null
-private val sampleRateInHz = 16000
+private const val sampleRateInHz = 16000

 @Composable
 fun MyApp(padding: PaddingValues) {
     val activity = LocalContext.current as Activity
     var isStarted by remember { mutableStateOf(false) }
-    var result by remember { mutableStateOf<String>("") }
+    var result by remember { mutableStateOf("") }

     val onButtonClick: () -> Unit = {
         isStarted = !isStarted
@@ -114,12 +113,12 @@ fun MyApp(padding: PaddingValues) {
             }
             Log.i(TAG, "Stop recording")
             Log.i(TAG, "Start recognition")
-            val samples = Flatten(sampleList)
+            val samples = flatten(sampleList)
             val stream = Slid.slid.createStream()
             stream.acceptWaveform(samples, sampleRateInHz)
             val lang = Slid.slid.compute(stream)

-            result = Slid.localeMap.get(lang) ?: lang
+            result = Slid.localeMap[lang] ?: lang

             stream.release()
         }
@@ -152,7 +151,7 @@ fun MyApp(padding: PaddingValues) {
     }
 }

-fun Flatten(sampleList: ArrayList<FloatArray>): FloatArray {
+fun flatten(sampleList: ArrayList<FloatArray>): FloatArray {
     var totalSamples = 0
     for (a in sampleList) {
         totalSamples += a.size
@@ -10,12 +10,9 @@ import androidx.activity.compose.setContent
 import androidx.compose.foundation.layout.fillMaxSize
 import androidx.compose.material3.MaterialTheme
 import androidx.compose.material3.Surface
-import androidx.compose.material3.Text
 import androidx.compose.runtime.Composable
 import androidx.compose.ui.Modifier
-import androidx.compose.ui.tooling.preview.Preview
 import androidx.core.app.ActivityCompat
-import com.k2fsa.sherpa.onnx.SpokenLanguageIdentification
 import com.k2fsa.sherpa.onnx.slid.ui.theme.SherpaOnnxSpokenLanguageIdentificationTheme

 const val TAG = "sherpa-onnx"
@@ -32,6 +29,7 @@ class MainActivity : ComponentActivity() {
         ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
         Slid.initSlid(this.assets)
     }
+
     @Suppress("DEPRECATION")
     @Deprecated("Deprecated in Java")
     override fun onRequestPermissionsResult(
-../../../../../../../../../../SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt
+../../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt
+../../../../../../../../../../../sherpa-onnx/kotlin-api/SpokenLanguageIdentification.kt
@@ -15,7 +15,7 @@ object Slid { @@ -15,7 +15,7 @@ object Slid {
15 get() { 15 get() {
16 return _slid!! 16 return _slid!!
17 } 17 }
18 - val localeMap : Map<String, String> 18 + val localeMap: Map<String, String>
19 get() { 19 get() {
20 return _localeMap 20 return _localeMap
21 } 21 }
@@ -31,7 +31,7 @@ object Slid { @@ -31,7 +31,7 @@ object Slid {
31 } 31 }
32 32
33 if (_localeMap.isEmpty()) { 33 if (_localeMap.isEmpty()) {
34 - val allLang = Locale.getISOLanguages(); 34 + val allLang = Locale.getISOLanguages()
35 for (lang in allLang) { 35 for (lang in allLang) {
36 val locale = Locale(lang) 36 val locale = Locale(lang)
37 _localeMap[lang] = locale.displayName 37 _localeMap[lang] = locale.displayName
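The loop above is what lets MyApp display Slid.localeMap[lang] ?: lang. A standalone sketch of the same idea; note that Locale.getISOLanguages() yields two-letter ISO 639-1 codes, and displayName depends on the device's default locale:

    import java.util.Locale

    fun main() {
        val localeMap = mutableMapOf<String, String>()
        for (lang in Locale.getISOLanguages()) {
            localeMap[lang] = Locale(lang).displayName
        }
        // On an English-locale device this prints "German"; codes missing
        // from the map fall back to the raw code, as in the app above.
        println(localeMap["de"] ?: "de")
    }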
1 package com.k2fsa.sherpa.onnx 1 package com.k2fsa.sherpa.onnx
2 2
3 import android.content.res.AssetManager 3 import android.content.res.AssetManager
4 -import android.media.* 4 +import android.media.AudioAttributes
  5 +import android.media.AudioFormat
  6 +import android.media.AudioManager
  7 +import android.media.AudioTrack
  8 +import android.media.MediaPlayer
5 import android.net.Uri 9 import android.net.Uri
6 import android.os.Bundle 10 import android.os.Bundle
7 import android.util.Log 11 import android.util.Log
@@ -212,7 +216,7 @@ class MainActivity : AppCompatActivity() { @@ -212,7 +216,7 @@ class MainActivity : AppCompatActivity() {
212 } 216 }
213 217
214 if (dictDir != null) { 218 if (dictDir != null) {
215 - val newDir = copyDataDir( modelDir!!) 219 + val newDir = copyDataDir(modelDir!!)
216 modelDir = newDir + "/" + modelDir 220 modelDir = newDir + "/" + modelDir
217 dictDir = modelDir + "/" + "dict" 221 dictDir = modelDir + "/" + "dict"
218 ruleFsts = "$modelDir/phone.fst,$modelDir/date.fst,$modelDir/number.fst" 222 ruleFsts = "$modelDir/phone.fst,$modelDir/date.fst,$modelDir/number.fst"
@@ -220,7 +224,9 @@ class MainActivity : AppCompatActivity() { @@ -220,7 +224,9 @@ class MainActivity : AppCompatActivity() {
220 } 224 }
221 225
222 val config = getOfflineTtsConfig( 226 val config = getOfflineTtsConfig(
223 - modelDir = modelDir!!, modelName = modelName!!, lexicon = lexicon ?: "", 227 + modelDir = modelDir!!,
  228 + modelName = modelName!!,
  229 + lexicon = lexicon ?: "",
224 dataDir = dataDir ?: "", 230 dataDir = dataDir ?: "",
225 dictDir = dictDir ?: "", 231 dictDir = dictDir ?: "",
226 ruleFsts = ruleFsts ?: "", 232 ruleFsts = ruleFsts ?: "",
@@ -232,11 +238,11 @@ class MainActivity : AppCompatActivity() { @@ -232,11 +238,11 @@ class MainActivity : AppCompatActivity() {
232 238
233 239
234 private fun copyDataDir(dataDir: String): String { 240 private fun copyDataDir(dataDir: String): String {
235 - println("data dir is $dataDir") 241 + Log.i(TAG, "data dir is $dataDir")
236 copyAssets(dataDir) 242 copyAssets(dataDir)
237 243
238 val newDataDir = application.getExternalFilesDir(null)!!.absolutePath 244 val newDataDir = application.getExternalFilesDir(null)!!.absolutePath
239 - println("newDataDir: $newDataDir") 245 + Log.i(TAG, "newDataDir: $newDataDir")
240 return newDataDir 246 return newDataDir
241 } 247 }
242 248
@@ -256,7 +262,7 @@ class MainActivity : AppCompatActivity() { @@ -256,7 +262,7 @@ class MainActivity : AppCompatActivity() {
256 } 262 }
257 } 263 }
258 } catch (ex: IOException) { 264 } catch (ex: IOException) {
259 - Log.e(TAG, "Failed to copy $path. ${ex.toString()}") 265 + Log.e(TAG, "Failed to copy $path. $ex")
260 } 266 }
261 } 267 }
262 268
@@ -276,7 +282,7 @@ class MainActivity : AppCompatActivity() { @@ -276,7 +282,7 @@ class MainActivity : AppCompatActivity() {
276 ostream.flush() 282 ostream.flush()
277 ostream.close() 283 ostream.close()
278 } catch (ex: Exception) { 284 } catch (ex: Exception) {
279 - Log.e(TAG, "Failed to copy $filename, ${ex.toString()}") 285 + Log.e(TAG, "Failed to copy $filename, $ex")
280 } 286 }
281 } 287 }
282 } 288 }
@@ -49,10 +49,10 @@ class OfflineTts( @@ -49,10 +49,10 @@ class OfflineTts(
49 private var ptr: Long 49 private var ptr: Long
50 50
51 init { 51 init {
52 - if (assetManager != null) {  
53 - ptr = newFromAsset(assetManager, config) 52 + ptr = if (assetManager != null) {
  53 + newFromAsset(assetManager, config)
54 } else { 54 } else {
55 - ptr = newFromFile(config) 55 + newFromFile(config)
56 } 56 }
57 } 57 }
58 58
@@ -65,7 +65,7 @@ class OfflineTts( @@ -65,7 +65,7 @@ class OfflineTts(
65 sid: Int = 0, 65 sid: Int = 0,
66 speed: Float = 1.0f 66 speed: Float = 1.0f
67 ): GeneratedAudio { 67 ): GeneratedAudio {
68 - var objArray = generateImpl(ptr, text = text, sid = sid, speed = speed) 68 + val objArray = generateImpl(ptr, text = text, sid = sid, speed = speed)
69 return GeneratedAudio( 69 return GeneratedAudio(
70 samples = objArray[0] as FloatArray, 70 samples = objArray[0] as FloatArray,
71 sampleRate = objArray[1] as Int 71 sampleRate = objArray[1] as Int
@@ -78,7 +78,13 @@ class OfflineTts( @@ -78,7 +78,13 @@ class OfflineTts(
78 speed: Float = 1.0f, 78 speed: Float = 1.0f,
79 callback: (samples: FloatArray) -> Unit 79 callback: (samples: FloatArray) -> Unit
80 ): GeneratedAudio { 80 ): GeneratedAudio {
81 - var objArray = generateWithCallbackImpl(ptr, text = text, sid = sid, speed = speed, callback=callback) 81 + val objArray = generateWithCallbackImpl(
  82 + ptr,
  83 + text = text,
  84 + sid = sid,
  85 + speed = speed,
  86 + callback = callback
  87 + )
82 return GeneratedAudio( 88 return GeneratedAudio(
83 samples = objArray[0] as FloatArray, 89 samples = objArray[0] as FloatArray,
84 sampleRate = objArray[1] as Int 90 sampleRate = objArray[1] as Int
@@ -87,10 +93,10 @@ class OfflineTts( @@ -87,10 +93,10 @@ class OfflineTts(
87 93
88 fun allocate(assetManager: AssetManager? = null) { 94 fun allocate(assetManager: AssetManager? = null) {
89 if (ptr == 0L) { 95 if (ptr == 0L) {
90 - if (assetManager != null) {  
91 - ptr = newFromAsset(assetManager, config) 96 + ptr = if (assetManager != null) {
  97 + newFromAsset(assetManager, config)
92 } else { 98 } else {
93 - ptr = newFromFile(config) 99 + newFromFile(config)
94 } 100 }
95 } 101 }
96 } 102 }
@@ -103,8 +109,13 @@ class OfflineTts( @@ -103,8 +109,13 @@ class OfflineTts(
103 } 109 }
104 110
105 protected fun finalize() { 111 protected fun finalize() {
  112 + if (ptr != 0L) {
106 delete(ptr) 113 delete(ptr)
  114 + ptr = 0
107 } 115 }
  116 + }
  117 +
  118 + fun release() = finalize()
108 119
109 private external fun newFromAsset( 120 private external fun newFromAsset(
110 assetManager: AssetManager, 121 assetManager: AssetManager,
@@ -123,14 +134,14 @@ class OfflineTts( @@ -123,14 +134,14 @@ class OfflineTts(
123 // - the first entry is an 1-D float array containing audio samples. 134 // - the first entry is an 1-D float array containing audio samples.
124 // Each sample is normalized to the range [-1, 1] 135 // Each sample is normalized to the range [-1, 1]
125 // - the second entry is the sample rate 136 // - the second entry is the sample rate
126 - external fun generateImpl( 137 + private external fun generateImpl(
127 ptr: Long, 138 ptr: Long,
128 text: String, 139 text: String,
129 sid: Int = 0, 140 sid: Int = 0,
130 speed: Float = 1.0f 141 speed: Float = 1.0f
131 ): Array<Any> 142 ): Array<Any>
132 143
133 - external fun generateWithCallbackImpl( 144 + private external fun generateWithCallbackImpl(
134 ptr: Long, 145 ptr: Long,
135 text: String, 146 text: String,
136 sid: Int = 0, 147 sid: Int = 0,
@@ -156,7 +167,7 @@ fun getOfflineTtsConfig( @@ -156,7 +167,7 @@ fun getOfflineTtsConfig(
156 dictDir: String, 167 dictDir: String,
157 ruleFsts: String, 168 ruleFsts: String,
158 ruleFars: String 169 ruleFars: String
159 -): OfflineTtsConfig? { 170 +): OfflineTtsConfig {
160 return OfflineTtsConfig( 171 return OfflineTtsConfig(
161 model = OfflineTtsModelConfig( 172 model = OfflineTtsModelConfig(
162 vits = OfflineTtsVitsModelConfig( 173 vits = OfflineTtsVitsModelConfig(
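Taken together, the OfflineTts hunks above give the class a conventional native-resource lifecycle: finalize() is guarded so a double delete is impossible, release() offers deterministic cleanup, the two *Impl entry points are now private, and getOfflineTtsConfig() no longer returns a nullable. A hypothetical caller (model paths and text are placeholders) showing how the pieces fit:

    // Sketch only; assumes assetManager defaults to null off-device.
    val config = getOfflineTtsConfig(
        modelDir = "vits-piper-en_US-amy-low",
        modelName = "en_US-amy-low.onnx",
        lexicon = "",
        dataDir = "vits-piper-en_US-amy-low/espeak-ng-data",
        dictDir = "",
        ruleFsts = "",
        ruleFars = "",
    ) // non-null now, so call sites can drop the !!

    val tts = OfflineTts(assetManager = null, config = config)
    val audio = tts.generate(text = "Hello", sid = 0, speed = 1.0f)
    audio.save(filename = "hello.wav")
    tts.release() // deterministic; finalize() becomes a no-op once ptr == 0L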
1 package com.k2fsa.sherpa.onnx.tts.engine 1 package com.k2fsa.sherpa.onnx.tts.engine
2 2
3 import android.content.Intent 3 import android.content.Intent
4 -import androidx.appcompat.app.AppCompatActivity  
5 import android.os.Bundle 4 import android.os.Bundle
6 import android.speech.tts.TextToSpeech 5 import android.speech.tts.TextToSpeech
  6 +import androidx.appcompat.app.AppCompatActivity
7 7
8 class CheckVoiceData : AppCompatActivity() { 8 class CheckVoiceData : AppCompatActivity() {
9 override fun onCreate(savedInstanceState: Bundle?) { 9 override fun onCreate(savedInstanceState: Bundle?) {
10 super.onCreate(savedInstanceState) 10 super.onCreate(savedInstanceState)
11 val intent = Intent().apply { 11 val intent = Intent().apply {
12 - putStringArrayListExtra(TextToSpeech.Engine.EXTRA_AVAILABLE_VOICES, arrayListOf(TtsEngine.lang)) 12 + putStringArrayListExtra(
  13 + TextToSpeech.Engine.EXTRA_AVAILABLE_VOICES,
  14 + arrayListOf(TtsEngine.lang)
  15 + )
13 putStringArrayListExtra(TextToSpeech.Engine.EXTRA_UNAVAILABLE_VOICES, arrayListOf()) 16 putStringArrayListExtra(TextToSpeech.Engine.EXTRA_UNAVAILABLE_VOICES, arrayListOf())
14 } 17 }
15 setResult(TextToSpeech.Engine.CHECK_VOICE_DATA_PASS, intent) 18 setResult(TextToSpeech.Engine.CHECK_VOICE_DATA_PASS, intent)
@@ -2,7 +2,6 @@ package com.k2fsa.sherpa.onnx.tts.engine @@ -2,7 +2,6 @@ package com.k2fsa.sherpa.onnx.tts.engine
2 2
3 import android.app.Activity 3 import android.app.Activity
4 import android.content.Intent 4 import android.content.Intent
5 -import androidx.appcompat.app.AppCompatActivity  
6 import android.os.Bundle 5 import android.os.Bundle
7 import android.speech.tts.TextToSpeech 6 import android.speech.tts.TextToSpeech
8 7
@@ -12,120 +11,168 @@ fun getSampleText(lang: String): String { @@ -12,120 +11,168 @@ fun getSampleText(lang: String): String {
12 "ara" -> { 11 "ara" -> {
13 text = "هذا هو محرك تحويل النص إلى كلام باستخدام الجيل القادم من كالدي" 12 text = "هذا هو محرك تحويل النص إلى كلام باستخدام الجيل القادم من كالدي"
14 } 13 }
  14 +
15 "ben" -> { 15 "ben" -> {
16 text = "এটি একটি টেক্সট-টু-স্পীচ ইঞ্জিন যা পরবর্তী প্রজন্মের কালডি ব্যবহার করে" 16 text = "এটি একটি টেক্সট-টু-স্পীচ ইঞ্জিন যা পরবর্তী প্রজন্মের কালডি ব্যবহার করে"
17 } 17 }
  18 +
18 "bul" -> { 19 "bul" -> {
19 - text = "Това е машина за преобразуване на текст в реч, използваща Kaldi от следващо поколение" 20 + text =
  21 + "Това е машина за преобразуване на текст в реч, използваща Kaldi от следващо поколение"
20 } 22 }
  23 +
21 "cat" -> { 24 "cat" -> {
22 text = "Aquest és un motor de text a veu que utilitza Kaldi de nova generació" 25 text = "Aquest és un motor de text a veu que utilitza Kaldi de nova generació"
23 } 26 }
  27 +
24 "ces" -> { 28 "ces" -> {
25 text = "Toto je převodník textu na řeč využívající novou generaci kaldi" 29 text = "Toto je převodník textu na řeč využívající novou generaci kaldi"
26 } 30 }
  31 +
27 "dan" -> { 32 "dan" -> {
28 text = "Dette er en tekst til tale-motor, der bruger næste generation af kaldi" 33 text = "Dette er en tekst til tale-motor, der bruger næste generation af kaldi"
29 } 34 }
  35 +
30 "deu" -> { 36 "deu" -> {
31 - text = "Dies ist eine Text-to-Speech-Engine, die Kaldi der nächsten Generation verwendet" 37 + text =
  38 + "Dies ist eine Text-to-Speech-Engine, die Kaldi der nächsten Generation verwendet"
32 } 39 }
  40 +
33 "ell" -> { 41 "ell" -> {
34 text = "Αυτή είναι μια μηχανή κειμένου σε ομιλία που χρησιμοποιεί kaldi επόμενης γενιάς" 42 text = "Αυτή είναι μια μηχανή κειμένου σε ομιλία που χρησιμοποιεί kaldi επόμενης γενιάς"
35 } 43 }
  44 +
36 "eng" -> { 45 "eng" -> {
37 text = "This is a text-to-speech engine using next generation Kaldi" 46 text = "This is a text-to-speech engine using next generation Kaldi"
38 } 47 }
  48 +
39 "est" -> { 49 "est" -> {
40 text = "See on teksti kõneks muutmise mootor, mis kasutab järgmise põlvkonna Kaldi" 50 text = "See on teksti kõneks muutmise mootor, mis kasutab järgmise põlvkonna Kaldi"
41 } 51 }
  52 +
42 "fin" -> { 53 "fin" -> {
43 text = "Tämä on tekstistä puheeksi -moottori, joka käyttää seuraavan sukupolven kaldia" 54 text = "Tämä on tekstistä puheeksi -moottori, joka käyttää seuraavan sukupolven kaldia"
44 } 55 }
  56 +
45 "fra" -> { 57 "fra" -> {
46 text = "Il s'agit d'un moteur de synthèse vocale utilisant Kaldi de nouvelle génération" 58 text = "Il s'agit d'un moteur de synthèse vocale utilisant Kaldi de nouvelle génération"
47 } 59 }
  60 +
48 "gle" -> { 61 "gle" -> {
49 text = "Is inneall téacs-go-hurlabhra é seo a úsáideann Kaldi den chéad ghlúin eile" 62 text = "Is inneall téacs-go-hurlabhra é seo a úsáideann Kaldi den chéad ghlúin eile"
50 } 63 }
  64 +
51 "hrv" -> { 65 "hrv" -> {
52 - text = "Ovo je mehanizam za pretvaranje teksta u govor koji koristi Kaldi sljedeće generacije" 66 + text =
  67 + "Ovo je mehanizam za pretvaranje teksta u govor koji koristi Kaldi sljedeće generacije"
53 } 68 }
  69 +
54 "hun" -> { 70 "hun" -> {
55 text = "Ez egy szövegfelolvasó motor a következő generációs kaldi használatával" 71 text = "Ez egy szövegfelolvasó motor a következő generációs kaldi használatával"
56 } 72 }
  73 +
57 "isl" -> { 74 "isl" -> {
58 text = "Þetta er texta í tal vél sem notar næstu kynslóð kaldi" 75 text = "Þetta er texta í tal vél sem notar næstu kynslóð kaldi"
59 } 76 }
  77 +
60 "ita" -> { 78 "ita" -> {
61 text = "Questo è un motore di sintesi vocale che utilizza kaldi di nuova generazione" 79 text = "Questo è un motore di sintesi vocale che utilizza kaldi di nuova generazione"
62 } 80 }
  81 +
63 "kat" -> { 82 "kat" -> {
64 text = "ეს არის ტექსტიდან მეტყველების ძრავა შემდეგი თაობის კალდის გამოყენებით" 83 text = "ეს არის ტექსტიდან მეტყველების ძრავა შემდეგი თაობის კალდის გამოყენებით"
65 } 84 }
  85 +
66 "kaz" -> { 86 "kaz" -> {
67 text = "Бұл келесі буын kaldi көмегімен мәтіннен сөйлеуге арналған қозғалтқыш" 87 text = "Бұл келесі буын kaldi көмегімен мәтіннен сөйлеуге арналған қозғалтқыш"
68 } 88 }
  89 +
69 "mlt" -> { 90 "mlt" -> {
70 text = "Din hija magna text-to-speech li tuża Kaldi tal-ġenerazzjoni li jmiss" 91 text = "Din hija magna text-to-speech li tuża Kaldi tal-ġenerazzjoni li jmiss"
71 } 92 }
  93 +
72 "lav" -> { 94 "lav" -> {
73 text = "Šis ir teksta pārvēršanas runā dzinējs, kas izmanto nākamās paaudzes Kaldi" 95 text = "Šis ir teksta pārvēršanas runā dzinējs, kas izmanto nākamās paaudzes Kaldi"
74 } 96 }
  97 +
75 "lit" -> { 98 "lit" -> {
76 text = "Tai teksto į kalbą variklis, kuriame naudojamas naujos kartos Kaldi" 99 text = "Tai teksto į kalbą variklis, kuriame naudojamas naujos kartos Kaldi"
77 } 100 }
  101 +
78 "ltz" -> { 102 "ltz" -> {
79 text = "Dëst ass en Text-zu-Speech-Motor mat der nächster Generatioun Kaldi" 103 text = "Dëst ass en Text-zu-Speech-Motor mat der nächster Generatioun Kaldi"
80 } 104 }
  105 +
81 "nep" -> { 106 "nep" -> {
82 text = "यो अर्को पुस्ता काल्डी प्रयोग गरेर स्पीच इन्जिनको पाठ हो" 107 text = "यो अर्को पुस्ता काल्डी प्रयोग गरेर स्पीच इन्जिनको पाठ हो"
83 } 108 }
  109 +
84 "nld" -> { 110 "nld" -> {
85 - text = "Dit is een tekst-naar-spraak-engine die gebruik maakt van Kaldi van de volgende generatie" 111 + text =
  112 + "Dit is een tekst-naar-spraak-engine die gebruik maakt van Kaldi van de volgende generatie"
86 } 113 }
  114 +
87 "nor" -> { 115 "nor" -> {
88 text = "Dette er en tekst til tale-motor som bruker neste generasjons kaldi" 116 text = "Dette er en tekst til tale-motor som bruker neste generasjons kaldi"
89 } 117 }
  118 +
90 "pol" -> { 119 "pol" -> {
91 text = "Jest to silnik syntezatora mowy wykorzystujący Kaldi nowej generacji" 120 text = "Jest to silnik syntezatora mowy wykorzystujący Kaldi nowej generacji"
92 } 121 }
  122 +
93 "por" -> { 123 "por" -> {
94 - text = "Este é um mecanismo de conversão de texto em fala usando Kaldi de próxima geração" 124 + text =
  125 + "Este é um mecanismo de conversão de texto em fala usando Kaldi de próxima geração"
95 } 126 }
  127 +
96 "ron" -> { 128 "ron" -> {
97 text = "Acesta este un motor text to speech care folosește generația următoare de kadi" 129 text = "Acesta este un motor text to speech care folosește generația următoare de kadi"
98 } 130 }
  131 +
99 "rus" -> { 132 "rus" -> {
100 - text = "Это движок преобразования текста в речь, использующий Kaldi следующего поколения." 133 + text =
  134 + "Это движок преобразования текста в речь, использующий Kaldi следующего поколения."
101 } 135 }
  136 +
102 "slk" -> { 137 "slk" -> {
103 text = "Toto je nástroj na prevod textu na reč využívajúci kaldi novej generácie" 138 text = "Toto je nástroj na prevod textu na reč využívajúci kaldi novej generácie"
104 } 139 }
  140 +
105 "slv" -> { 141 "slv" -> {
106 - text = "To je mehanizem za pretvorbo besedila v govor, ki uporablja Kaldi naslednje generacije" 142 + text =
  143 + "To je mehanizem za pretvorbo besedila v govor, ki uporablja Kaldi naslednje generacije"
107 } 144 }
  145 +
108 "spa" -> { 146 "spa" -> {
109 text = "Este es un motor de texto a voz que utiliza kaldi de próxima generación." 147 text = "Este es un motor de texto a voz que utiliza kaldi de próxima generación."
110 } 148 }
  149 +
111 "srp" -> { 150 "srp" -> {
112 - text = "Ово је механизам за претварање текста у говор који користи калди следеће генерације" 151 + text =
  152 + "Ово је механизам за претварање текста у говор који користи калди следеће генерације"
113 } 153 }
  154 +
114 "swa" -> { 155 "swa" -> {
115 text = "Haya ni maandishi kwa injini ya hotuba kwa kutumia kizazi kijacho kaldi" 156 text = "Haya ni maandishi kwa injini ya hotuba kwa kutumia kizazi kijacho kaldi"
116 } 157 }
  158 +
117 "swe" -> { 159 "swe" -> {
118 text = "Detta är en text till tal-motor som använder nästa generations kaldi" 160 text = "Detta är en text till tal-motor som använder nästa generations kaldi"
119 } 161 }
  162 +
120 "tur" -> { 163 "tur" -> {
121 text = "Bu, yeni nesil kaldi'yi kullanan bir metinden konuşmaya motorudur" 164 text = "Bu, yeni nesil kaldi'yi kullanan bir metinden konuşmaya motorudur"
122 } 165 }
  166 +
123 "ukr" -> { 167 "ukr" -> {
124 - text = "Це механізм перетворення тексту на мовлення, який використовує kaldi нового покоління" 168 + text =
  169 + "Це механізм перетворення тексту на мовлення, який використовує kaldi нового покоління"
125 } 170 }
  171 +
126 "vie" -> { 172 "vie" -> {
127 text = "Đây là công cụ chuyển văn bản thành giọng nói sử dụng kaldi thế hệ tiếp theo" 173 text = "Đây là công cụ chuyển văn bản thành giọng nói sử dụng kaldi thế hệ tiếp theo"
128 } 174 }
  175 +
129 "zho", "cmn" -> { 176 "zho", "cmn" -> {
130 text = "使用新一代卡尔迪的语音合成引擎" 177 text = "使用新一代卡尔迪的语音合成引擎"
131 } 178 }
@@ -137,13 +184,13 @@ class GetSampleText : Activity() { @@ -137,13 +184,13 @@ class GetSampleText : Activity() {
137 override fun onCreate(savedInstanceState: Bundle?) { 184 override fun onCreate(savedInstanceState: Bundle?) {
138 super.onCreate(savedInstanceState) 185 super.onCreate(savedInstanceState)
139 var result = TextToSpeech.LANG_AVAILABLE 186 var result = TextToSpeech.LANG_AVAILABLE
140 - var text: String = getSampleText(TtsEngine.lang ?: "") 187 + val text: String = getSampleText(TtsEngine.lang ?: "")
141 if (text.isEmpty()) { 188 if (text.isEmpty()) {
142 result = TextToSpeech.LANG_NOT_SUPPORTED 189 result = TextToSpeech.LANG_NOT_SUPPORTED
143 } 190 }
144 191
145 - val intent = Intent().apply{  
146 - if(result == TextToSpeech.LANG_AVAILABLE) { 192 + val intent = Intent().apply {
  193 + if (result == TextToSpeech.LANG_AVAILABLE) {
147 putExtra(TextToSpeech.Engine.EXTRA_SAMPLE_TEXT, text) 194 putExtra(TextToSpeech.Engine.EXTRA_SAMPLE_TEXT, text)
148 } else { 195 } else {
149 putExtra("sampleText", text) 196 putExtra("sampleText", text)
@@ -26,20 +26,16 @@ import androidx.compose.material3.Scaffold @@ -26,20 +26,16 @@ import androidx.compose.material3.Scaffold
26 import androidx.compose.material3.Slider 26 import androidx.compose.material3.Slider
27 import androidx.compose.material3.Surface 27 import androidx.compose.material3.Surface
28 import androidx.compose.material3.Text 28 import androidx.compose.material3.Text
29 -import androidx.compose.material3.TextField  
30 import androidx.compose.material3.TopAppBar 29 import androidx.compose.material3.TopAppBar
31 -import androidx.compose.runtime.Composable  
32 import androidx.compose.runtime.getValue 30 import androidx.compose.runtime.getValue
33 import androidx.compose.runtime.mutableStateOf 31 import androidx.compose.runtime.mutableStateOf
34 import androidx.compose.runtime.remember 32 import androidx.compose.runtime.remember
35 import androidx.compose.runtime.setValue 33 import androidx.compose.runtime.setValue
36 import androidx.compose.ui.Modifier 34 import androidx.compose.ui.Modifier
37 import androidx.compose.ui.text.input.KeyboardType 35 import androidx.compose.ui.text.input.KeyboardType
38 -import androidx.compose.ui.tooling.preview.Preview  
39 import androidx.compose.ui.unit.dp 36 import androidx.compose.ui.unit.dp
40 import com.k2fsa.sherpa.onnx.tts.engine.ui.theme.SherpaOnnxTtsEngineTheme 37 import com.k2fsa.sherpa.onnx.tts.engine.ui.theme.SherpaOnnxTtsEngineTheme
41 import java.io.File 38 import java.io.File
42 -import java.lang.NumberFormatException  
43 39
44 const val TAG = "sherpa-onnx-tts-engine" 40 const val TAG = "sherpa-onnx-tts-engine"
45 41
@@ -88,7 +84,7 @@ class MainActivity : ComponentActivity() { @@ -88,7 +84,7 @@ class MainActivity : ComponentActivity() {
88 try { 84 try {
89 TtsEngine.speakerId = it.toString().toInt() 85 TtsEngine.speakerId = it.toString().toInt()
90 } catch (ex: NumberFormatException) { 86 } catch (ex: NumberFormatException) {
91 - Log.i(TAG, "Invalid input: ${it}") 87 + Log.i(TAG, "Invalid input: $it")
92 TtsEngine.speakerId = 0 88 TtsEngine.speakerId = 0
93 } 89 }
94 } 90 }
@@ -119,7 +115,7 @@ class MainActivity : ComponentActivity() { @@ -119,7 +115,7 @@ class MainActivity : ComponentActivity() {
119 Button( 115 Button(
120 modifier = Modifier.padding(20.dp), 116 modifier = Modifier.padding(20.dp),
121 onClick = { 117 onClick = {
122 - Log.i(TAG, "Clicked, text: ${testText}") 118 + Log.i(TAG, "Clicked, text: $testText")
123 if (testText.isBlank() || testText.isEmpty()) { 119 if (testText.isBlank() || testText.isEmpty()) {
124 Toast.makeText( 120 Toast.makeText(
125 applicationContext, 121 applicationContext,
@@ -136,7 +132,7 @@ class MainActivity : ComponentActivity() { @@ -136,7 +132,7 @@ class MainActivity : ComponentActivity() {
136 val filename = 132 val filename =
137 application.filesDir.absolutePath + "/generated.wav" 133 application.filesDir.absolutePath + "/generated.wav"
138 val ok = 134 val ok =
139 - audio.samples.size > 0 && audio.save(filename) 135 + audio.samples.isNotEmpty() && audio.save(filename)
140 136
141 if (ok) { 137 if (ok) {
142 stopMediaPlayer() 138 stopMediaPlayer()
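For context, the isNotEmpty() guard above sits in the click handler that synthesizes the test text and plays it back. A condensed sketch of the assumed surrounding flow, with names taken from the hunks in this file:

    val audio = TtsEngine.tts!!.generate(
        text = testText,
        sid = TtsEngine.speakerId,
        speed = TtsEngine.speed,
    )
    val filename = application.filesDir.absolutePath + "/generated.wav"
    // Play the result only when samples were produced and the save succeeded.
    val ok = audio.samples.isNotEmpty() && audio.save(filename)
    if (ok) {
        stopMediaPlayer()
        // ... start playback of filename (assumed; not shown in the diff)
    }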
@@ -4,8 +4,10 @@ import android.content.Context @@ -4,8 +4,10 @@ import android.content.Context
4 import android.content.res.AssetManager 4 import android.content.res.AssetManager
5 import android.util.Log 5 import android.util.Log
6 import androidx.compose.runtime.MutableState 6 import androidx.compose.runtime.MutableState
7 -import androidx.compose.runtime.mutableStateOf  
8 -import com.k2fsa.sherpa.onnx.* 7 +import androidx.compose.runtime.mutableFloatStateOf
  8 +import androidx.compose.runtime.mutableIntStateOf
  9 +import com.k2fsa.sherpa.onnx.OfflineTts
  10 +import com.k2fsa.sherpa.onnx.getOfflineTtsConfig
9 import java.io.File 11 import java.io.File
10 import java.io.FileOutputStream 12 import java.io.FileOutputStream
11 import java.io.IOException 13 import java.io.IOException
@@ -21,8 +23,8 @@ object TtsEngine { @@ -21,8 +23,8 @@ object TtsEngine {
21 var lang: String? = null 23 var lang: String? = null
22 24
23 25
24 - val speedState: MutableState<Float> = mutableStateOf(1.0F)  
25 - val speakerIdState: MutableState<Int> = mutableStateOf(0) 26 + val speedState: MutableState<Float> = mutableFloatStateOf(1.0F)
  27 + val speakerIdState: MutableState<Int> = mutableIntStateOf(0)
26 28
27 var speed: Float 29 var speed: Float
28 get() = speedState.value 30 get() = speedState.value
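mutableFloatStateOf and mutableIntStateOf are the primitive-specialized Compose state factories available in newer compose.runtime releases: they back the state with an unboxed float/int instead of allocating a boxed Float/Int on every read and write. The returned MutableFloatState/MutableIntState still implement MutableState<Float>/MutableState<Int>, which is why the declarations above can keep their original types. A tiny sketch:

    import androidx.compose.runtime.mutableFloatStateOf

    val speedState = mutableFloatStateOf(1.0F)

    fun slowDown() {
        // .floatValue reads and writes the primitive directly; .value also
        // works through MutableState<Float> but boxes at the call site.
        speedState.floatValue *= 0.9f
    }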
@@ -113,15 +115,15 @@ object TtsEngine { @@ -113,15 +115,15 @@ object TtsEngine {
113 115
114 if (dataDir != null) { 116 if (dataDir != null) {
115 val newDir = copyDataDir(context, modelDir!!) 117 val newDir = copyDataDir(context, modelDir!!)
116 - modelDir = newDir + "/" + modelDir  
117 - dataDir = newDir + "/" + dataDir 118 + modelDir = "$newDir/$modelDir"
  119 + dataDir = "$newDir/$dataDir"
118 assets = null 120 assets = null
119 } 121 }
120 122
121 if (dictDir != null) { 123 if (dictDir != null) {
122 val newDir = copyDataDir(context, modelDir!!) 124 val newDir = copyDataDir(context, modelDir!!)
123 - modelDir = newDir + "/" + modelDir  
124 - dictDir = modelDir + "/" + "dict" 125 + modelDir = "$newDir/$modelDir"
  126 + dictDir = "$modelDir/dict"
125 ruleFsts = "$modelDir/phone.fst,$modelDir/date.fst,$modelDir/number.fst" 127 ruleFsts = "$modelDir/phone.fst,$modelDir/date.fst,$modelDir/number.fst"
126 assets = null 128 assets = null
127 } 129 }
@@ -132,18 +134,18 @@ object TtsEngine { @@ -132,18 +134,18 @@ object TtsEngine {
132 dictDir = dictDir ?: "", 134 dictDir = dictDir ?: "",
133 ruleFsts = ruleFsts ?: "", 135 ruleFsts = ruleFsts ?: "",
134 ruleFars = ruleFars ?: "" 136 ruleFars = ruleFars ?: ""
135 - )!! 137 + )
136 138
137 tts = OfflineTts(assetManager = assets, config = config) 139 tts = OfflineTts(assetManager = assets, config = config)
138 } 140 }
139 141
140 142
141 private fun copyDataDir(context: Context, dataDir: String): String { 143 private fun copyDataDir(context: Context, dataDir: String): String {
142 - println("data dir is $dataDir") 144 + Log.i(TAG, "data dir is $dataDir")
143 copyAssets(context, dataDir) 145 copyAssets(context, dataDir)
144 146
145 val newDataDir = context.getExternalFilesDir(null)!!.absolutePath 147 val newDataDir = context.getExternalFilesDir(null)!!.absolutePath
146 - println("newDataDir: $newDataDir") 148 + Log.i(TAG, "newDataDir: $newDataDir")
147 return newDataDir 149 return newDataDir
148 } 150 }
149 151
@@ -158,12 +160,12 @@ object TtsEngine { @@ -158,12 +160,12 @@ object TtsEngine {
158 val dir = File(fullPath) 160 val dir = File(fullPath)
159 dir.mkdirs() 161 dir.mkdirs()
160 for (asset in assets.iterator()) { 162 for (asset in assets.iterator()) {
161 - val p: String = if (path == "") "" else path + "/" 163 + val p: String = if (path == "") "" else "$path/"
162 copyAssets(context, p + asset) 164 copyAssets(context, p + asset)
163 } 165 }
164 } 166 }
165 } catch (ex: IOException) { 167 } catch (ex: IOException) {
166 - Log.e(TAG, "Failed to copy $path. ${ex.toString()}") 168 + Log.e(TAG, "Failed to copy $path. $ex")
167 } 169 }
168 } 170 }
169 171
@@ -183,7 +185,7 @@ object TtsEngine { @@ -183,7 +185,7 @@ object TtsEngine {
183 ostream.flush() 185 ostream.flush()
184 ostream.close() 186 ostream.close()
185 } catch (ex: Exception) { 187 } catch (ex: Exception) {
186 - Log.e(TAG, "Failed to copy $filename, ${ex.toString()}") 188 + Log.e(TAG, "Failed to copy $filename, $ex")
187 } 189 }
188 } 190 }
189 } 191 }
@@ -6,7 +6,6 @@ import android.speech.tts.SynthesisRequest @@ -6,7 +6,6 @@ import android.speech.tts.SynthesisRequest
6 import android.speech.tts.TextToSpeech 6 import android.speech.tts.TextToSpeech
7 import android.speech.tts.TextToSpeechService 7 import android.speech.tts.TextToSpeechService
8 import android.util.Log 8 import android.util.Log
9 -import com.k2fsa.sherpa.onnx.*  
10 9
11 /* 10 /*
12 https://developer.android.com/reference/java/util/Locale#getISO3Language() 11 https://developer.android.com/reference/java/util/Locale#getISO3Language()
1 package com.k2fsa.sherpa.onnx.tts.engine 1 package com.k2fsa.sherpa.onnx.tts.engine
2 2
3 import android.app.Application 3 import android.app.Application
4 -import android.os.FileUtils.ProgressListener  
5 import android.speech.tts.TextToSpeech 4 import android.speech.tts.TextToSpeech
6 import android.speech.tts.TextToSpeech.OnInitListener 5 import android.speech.tts.TextToSpeech.OnInitListener
7 import android.speech.tts.UtteranceProgressListener 6 import android.speech.tts.UtteranceProgressListener
@@ -27,7 +26,7 @@ class TtsViewModel : ViewModel() { @@ -27,7 +26,7 @@ class TtsViewModel : ViewModel() {
27 private val onInitListener = object : OnInitListener { 26 private val onInitListener = object : OnInitListener {
28 override fun onInit(status: Int) { 27 override fun onInit(status: Int) {
29 when (status) { 28 when (status) {
30 - TextToSpeech.SUCCESS -> Log.i(TAG, "Init tts succeded") 29 + TextToSpeech.SUCCESS -> Log.i(TAG, "Init tts succeeded")
31 TextToSpeech.ERROR -> Log.i(TAG, "Init tts failed") 30 TextToSpeech.ERROR -> Log.i(TAG, "Init tts failed")
32 else -> Log.i(TAG, "Unknown status $status") 31 else -> Log.i(TAG, "Unknown status $status")
33 } 32 }
@@ -15,7 +15,7 @@ @@ -15,7 +15,7 @@
15 android:theme="@style/Theme.SherpaOnnxVad" 15 android:theme="@style/Theme.SherpaOnnxVad"
16 tools:targetApi="31"> 16 tools:targetApi="31">
17 <activity 17 <activity
18 - android:name=".MainActivity" 18 + android:name="com.k2fsa.sherpa.onnx.vad.MainActivity"
19 android:exported="true"> 19 android:exported="true">
20 <intent-filter> 20 <intent-filter>
21 <action android:name="android.intent.action.MAIN" /> 21 <action android:name="android.intent.action.MAIN" />
1 -package com.k2fsa.sherpa.onnx 1 +package com.k2fsa.sherpa.onnx.vad
2 2
3 import android.Manifest 3 import android.Manifest
4 import android.content.pm.PackageManager 4 import android.content.pm.PackageManager
@@ -11,6 +11,9 @@ import android.view.View @@ -11,6 +11,9 @@ import android.view.View
11 import android.widget.Button 11 import android.widget.Button
12 import androidx.appcompat.app.AppCompatActivity 12 import androidx.appcompat.app.AppCompatActivity
13 import androidx.core.app.ActivityCompat 13 import androidx.core.app.ActivityCompat
  14 +import com.k2fsa.sherpa.onnx.R
  15 +import com.k2fsa.sherpa.onnx.Vad
  16 +import com.k2fsa.sherpa.onnx.getVadModelConfig
14 import kotlin.concurrent.thread 17 import kotlin.concurrent.thread
15 18
16 19
@@ -116,7 +119,7 @@ class MainActivity : AppCompatActivity() { @@ -116,7 +119,7 @@ class MainActivity : AppCompatActivity() {
116 119
117 private fun initVadModel() { 120 private fun initVadModel() {
118 val type = 0 121 val type = 0
119 - println("Select VAD model type ${type}") 122 + Log.i(TAG, "Select VAD model type ${type}")
120 val config = getVadModelConfig(type) 123 val config = getVadModelConfig(type)
121 124
122 vad = Vad( 125 vad = Vad(
  1 +../../../../../../../../../../sherpa-onnx/kotlin-api/Vad.kt
@@ -4,7 +4,7 @@ @@ -4,7 +4,7 @@
4 xmlns:tools="http://schemas.android.com/tools" 4 xmlns:tools="http://schemas.android.com/tools"
5 android:layout_width="match_parent" 5 android:layout_width="match_parent"
6 android:layout_height="match_parent" 6 android:layout_height="match_parent"
7 - tools:context=".MainActivity"> 7 + tools:context="com.k2fsa.sherpa.onnx.vad.MainActivity">
8 <LinearLayout 8 <LinearLayout
9 android:layout_width="match_parent" 9 android:layout_width="match_parent"
10 android:layout_height="match_parent" 10 android:layout_height="match_parent"
@@ -15,7 +15,7 @@ @@ -15,7 +15,7 @@
15 android:theme="@style/Theme.SherpaOnnxVadAsr" 15 android:theme="@style/Theme.SherpaOnnxVadAsr"
16 tools:targetApi="31"> 16 tools:targetApi="31">
17 <activity 17 <activity
18 - android:name=".MainActivity" 18 + android:name=".vad.asr.MainActivity"
19 android:exported="true"> 19 android:exported="true">
20 <intent-filter> 20 <intent-filter>
21 <action android:name="android.intent.action.MAIN" /> 21 <action android:name="android.intent.action.MAIN" />
  1 +../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt
1 -package com.k2fsa.sherpa.onnx 1 +package com.k2fsa.sherpa.onnx.vad.asr
2 2
3 import android.Manifest 3 import android.Manifest
4 import android.content.pm.PackageManager 4 import android.content.pm.PackageManager
@@ -13,6 +13,13 @@ import android.widget.Button @@ -13,6 +13,13 @@ import android.widget.Button
13 import android.widget.TextView 13 import android.widget.TextView
14 import androidx.appcompat.app.AppCompatActivity 14 import androidx.appcompat.app.AppCompatActivity
15 import androidx.core.app.ActivityCompat 15 import androidx.core.app.ActivityCompat
  16 +import com.k2fsa.sherpa.onnx.OfflineRecognizer
  17 +import com.k2fsa.sherpa.onnx.OfflineRecognizerConfig
  18 +import com.k2fsa.sherpa.onnx.R
  19 +import com.k2fsa.sherpa.onnx.Vad
  20 +import com.k2fsa.sherpa.onnx.getFeatureConfig
  21 +import com.k2fsa.sherpa.onnx.getOfflineModelConfig
  22 +import com.k2fsa.sherpa.onnx.getVadModelConfig
16 import kotlin.concurrent.thread 23 import kotlin.concurrent.thread
17 24
18 25
@@ -40,7 +47,7 @@ class MainActivity : AppCompatActivity() { @@ -40,7 +47,7 @@ class MainActivity : AppCompatActivity() {
40 private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO) 47 private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
41 48
42 // Non-streaming ASR 49 // Non-streaming ASR
43 - private lateinit var offlineRecognizer: SherpaOnnxOffline 50 + private lateinit var offlineRecognizer: OfflineRecognizer
44 51
45 private var idx: Int = 0 52 private var idx: Int = 0
46 private var lastText: String = "" 53 private var lastText: String = ""
@@ -122,7 +129,7 @@ class MainActivity : AppCompatActivity() { @@ -122,7 +129,7 @@ class MainActivity : AppCompatActivity() {
122 129
123 private fun initVadModel() { 130 private fun initVadModel() {
124 val type = 0 131 val type = 0
125 - println("Select VAD model type ${type}") 132 + Log.i(TAG, "Select VAD model type ${type}")
126 val config = getVadModelConfig(type) 133 val config = getVadModelConfig(type)
127 134
128 vad = Vad( 135 vad = Vad(
@@ -194,20 +201,25 @@ class MainActivity : AppCompatActivity() { @@ -194,20 +201,25 @@ class MainActivity : AppCompatActivity() {
194 // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 201 // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
195 // for a list of available models 202 // for a list of available models
196 val secondType = 0 203 val secondType = 0
197 - println("Select model type ${secondType} for the second pass") 204 + Log.i(TAG, "Select model type ${secondType} for the second pass")
198 205
199 val config = OfflineRecognizerConfig( 206 val config = OfflineRecognizerConfig(
200 featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), 207 featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
201 modelConfig = getOfflineModelConfig(type = secondType)!!, 208 modelConfig = getOfflineModelConfig(type = secondType)!!,
202 ) 209 )
203 210
204 - offlineRecognizer = SherpaOnnxOffline( 211 + offlineRecognizer = OfflineRecognizer(
205 assetManager = application.assets, 212 assetManager = application.assets,
206 config = config, 213 config = config,
207 ) 214 )
208 } 215 }
209 216
210 private fun runSecondPass(samples: FloatArray): String { 217 private fun runSecondPass(samples: FloatArray): String {
211 - return offlineRecognizer.decode(samples, sampleRateInHz) 218 + val stream = offlineRecognizer.createStream()
  219 + stream.acceptWaveform(samples, sampleRateInHz)
  220 + offlineRecognizer.decode(stream)
  221 + val result = offlineRecognizer.getResult(stream)
  222 + stream.release()
  223 + return result.text
212 } 224 }
213 } 225 }
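runSecondPass now spells out the non-streaming lifecycle instead of the old one-shot decode(samples, sampleRateInHz): create a stream, feed the whole utterance, decode, read the result, release. The same flow with annotations (mirrors the hunk above; only the comments are additions):

    private fun runSecondPass(samples: FloatArray): String {
        val stream = offlineRecognizer.createStream() // native stream; must be released
        stream.acceptWaveform(samples, sampleRateInHz) // whole utterance at once
        offlineRecognizer.decode(stream)               // run the offline model
        val result = offlineRecognizer.getResult(stream)
        stream.release()                               // free native memory promptly
        return result.text
    }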
  1 +../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineRecognizer.kt
  1 +../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt
1 -../../../../../../../../../SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt  
1 -../../../../../../../../../SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt  
  1 +../../../../../../../../../../sherpa-onnx/kotlin-api/Vad.kt
@@ -4,7 +4,7 @@ @@ -4,7 +4,7 @@
4 xmlns:tools="http://schemas.android.com/tools" 4 xmlns:tools="http://schemas.android.com/tools"
5 android:layout_width="match_parent" 5 android:layout_width="match_parent"
6 android:layout_height="match_parent" 6 android:layout_height="match_parent"
7 - tools:context=".MainActivity"> 7 + tools:context=".vad.asr.MainActivity">
8 8
9 <LinearLayout 9 <LinearLayout
10 android:layout_width="match_parent" 10 android:layout_width="match_parent"
1 <resources> 1 <resources>
2 - <string name="app_name">VAD-ASR</string> 2 + <string name="app_name">VAD+ASR</string>
3 <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi. 3 <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
4 \n 4 \n
5 \n\n\n 5 \n\n\n
@@ -59,7 +59,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/ @@ -59,7 +59,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/
59 echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR" 59 echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR"
60 echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR" 60 echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR"
61 61
  62 +if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then
  63 + SHERPA_ONNX_ENABLE_TTS=ON
  64 +fi
  65 +
  66 +if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then
  67 + SHERPA_ONNX_ENABLE_BINARY=OFF
  68 +fi
  69 +
62 cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ 70 cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
  71 + -DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \
  72 + -DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \
63 -DBUILD_PIPER_PHONMIZE_EXE=OFF \ 73 -DBUILD_PIPER_PHONMIZE_EXE=OFF \
64 -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ 74 -DBUILD_PIPER_PHONMIZE_TESTS=OFF \
65 -DBUILD_ESPEAK_NG_EXE=OFF \ 75 -DBUILD_ESPEAK_NG_EXE=OFF \
@@ -60,7 +60,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/ @@ -60,7 +60,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/
60 echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR" 60 echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR"
61 echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR" 61 echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR"
62 62
  63 +if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then
  64 + SHERPA_ONNX_ENABLE_TTS=ON
  65 +fi
  66 +
  67 +if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then
  68 + SHERPA_ONNX_ENABLE_BINARY=OFF
  69 +fi
  70 +
63 cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ 71 cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
  72 + -DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \
  73 + -DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \
64 -DBUILD_PIPER_PHONMIZE_EXE=OFF \ 74 -DBUILD_PIPER_PHONMIZE_EXE=OFF \
65 -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ 75 -DBUILD_PIPER_PHONMIZE_TESTS=OFF \
66 -DBUILD_ESPEAK_NG_EXE=OFF \ 76 -DBUILD_ESPEAK_NG_EXE=OFF \
@@ -60,7 +60,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/ @@ -60,7 +60,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/
60 echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR" 60 echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR"
61 echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR" 61 echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR"
62 62
  63 +if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then
  64 + SHERPA_ONNX_ENABLE_TTS=ON
  65 +fi
  66 +
  67 +if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then
  68 + SHERPA_ONNX_ENABLE_BINARY=OFF
  69 +fi
  70 +
63 cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ 71 cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
  72 + -DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \
  73 + -DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \
64 -DBUILD_PIPER_PHONMIZE_EXE=OFF \ 74 -DBUILD_PIPER_PHONMIZE_EXE=OFF \
65 -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ 75 -DBUILD_PIPER_PHONMIZE_TESTS=OFF \
66 -DBUILD_ESPEAK_NG_EXE=OFF \ 76 -DBUILD_ESPEAK_NG_EXE=OFF \
@@ -60,7 +60,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/ @@ -60,7 +60,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/
60 echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR" 60 echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR"
61 echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR" 61 echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR"
62 62
  63 +if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then
  64 + SHERPA_ONNX_ENABLE_TTS=ON
  65 +fi
  66 +
  67 +if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then
  68 + SHERPA_ONNX_ENABLE_BINARY=OFF
  69 +fi
  70 +
63 cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ 71 cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
  72 + -DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \
  73 + -DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \
64 -DBUILD_PIPER_PHONMIZE_EXE=OFF \ 74 -DBUILD_PIPER_PHONMIZE_EXE=OFF \
65 -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ 75 -DBUILD_PIPER_PHONMIZE_TESTS=OFF \
66 -DBUILD_ESPEAK_NG_EXE=OFF \ 76 -DBUILD_ESPEAK_NG_EXE=OFF \
1 -../android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/AudioTagging.kt  
  1 +../sherpa-onnx/kotlin-api/AudioTagging.kt
  1 +../sherpa-onnx/kotlin-api/FeatureConfig.kt
1 -package com.k2fsa.sherpa.onnx  
2 -  
3 -import android.content.res.AssetManager  
4 -  
5 -fun callback(samples: FloatArray): Unit {  
6 - println("callback got called with ${samples.size} samples");  
7 -}  
8 -  
9 -fun main() {  
10 - testSpokenLanguageIdentification()
11 - testAudioTagging()  
12 - testSpeakerRecognition()  
13 - testTts()  
14 - testAsr("transducer")  
15 - testAsr("zipformer2-ctc")  
16 -}  
17 -  
18 -fun testSpokenLanguageIdentification() {
19 - val config = SpokenLanguageIdentificationConfig(  
20 - whisper = SpokenLanguageIdentificationWhisperConfig(  
21 - encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",  
22 - decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",  
23 - tailPaddings = 33,  
24 - ),  
25 - numThreads=1,  
26 - debug=true,  
27 - provider="cpu",  
28 - )  
29 - val slid = SpokenLanguageIdentification(assetManager=null, config=config)  
30 -  
31 - val testFiles = arrayOf(  
32 - "./spoken-language-identification-test-wavs/ar-arabic.wav",  
33 - "./spoken-language-identification-test-wavs/bg-bulgarian.wav",  
34 - "./spoken-language-identification-test-wavs/de-german.wav",  
35 - )  
36 -  
37 - for (waveFilename in testFiles) {  
38 - val objArray = WaveReader.readWaveFromFile(  
39 - filename = waveFilename,  
40 - )  
41 - val samples: FloatArray = objArray[0] as FloatArray  
42 - val sampleRate: Int = objArray[1] as Int  
43 -  
44 - val stream = slid.createStream()  
45 - stream.acceptWaveform(samples, sampleRate = sampleRate)  
46 - val lang = slid.compute(stream)  
47 - stream.release()  
48 - println(waveFilename)  
49 - println(lang)  
50 - }  
51 -}  
52 -  
53 -fun testAudioTagging() {  
54 - val config = AudioTaggingConfig(  
55 - model=AudioTaggingModelConfig(  
56 - zipformer=OfflineZipformerAudioTaggingModelConfig(  
57 - model="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx",  
58 - ),  
59 - numThreads=1,  
60 - debug=true,  
61 - provider="cpu",  
62 - ),  
63 - labels="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv",  
64 - topK=5,  
65 - )  
66 - val tagger = AudioTagging(assetManager=null, config=config)  
67 -  
68 - val testFiles = arrayOf(  
69 - "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav",  
70 - "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/2.wav",  
71 - "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/3.wav",  
72 - "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/4.wav",  
73 - )  
74 - println("----------")  
75 - for (waveFilename in testFiles) {  
76 - val stream = tagger.createStream()  
77 -  
78 - val objArray = WaveReader.readWaveFromFile(  
79 - filename = waveFilename,  
80 - )  
81 - val samples: FloatArray = objArray[0] as FloatArray  
82 - val sampleRate: Int = objArray[1] as Int  
83 -  
84 - stream.acceptWaveform(samples, sampleRate = sampleRate)  
85 - val events = tagger.compute(stream)  
86 - stream.release()  
87 -  
88 - println(waveFilename)  
89 - println(events)  
90 - println("----------")  
91 - }  
92 -  
93 - tagger.release()  
94 -}  
95 -  
96 -fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray {  
97 - var objArray = WaveReader.readWaveFromFile(  
98 - filename = filename,  
99 - )  
100 - var samples: FloatArray = objArray[0] as FloatArray  
101 - var sampleRate: Int = objArray[1] as Int  
102 -  
103 - val stream = extractor.createStream()  
104 - stream.acceptWaveform(sampleRate = sampleRate, samples=samples)  
105 - stream.inputFinished()  
106 - check(extractor.isReady(stream))  
107 -  
108 - val embedding = extractor.compute(stream)  
109 -  
110 - stream.release()  
111 -  
112 - return embedding  
113 -}  
114 -  
115 -fun testSpeakerRecognition() {  
116 - val config = SpeakerEmbeddingExtractorConfig(  
117 - model="./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx",  
118 - )  
119 - val extractor = SpeakerEmbeddingExtractor(config = config)  
120 -  
121 - val embedding1a = computeEmbedding(extractor, "./speaker1_a_cn_16k.wav")  
122 - val embedding2a = computeEmbedding(extractor, "./speaker2_a_cn_16k.wav")  
123 - val embedding1b = computeEmbedding(extractor, "./speaker1_b_cn_16k.wav")  
124 -  
125 - var manager = SpeakerEmbeddingManager(extractor.dim())  
126 - var ok = manager.add(name = "speaker1", embedding=embedding1a)  
127 - check(ok)  
128 -  
129 - manager.add(name = "speaker2", embedding=embedding2a)  
130 - check(ok)  
131 -  
132 - var name = manager.search(embedding=embedding1b, threshold=0.5f)  
133 - check(name == "speaker1")  
134 -  
135 - manager.release()  
136 -  
137 - manager = SpeakerEmbeddingManager(extractor.dim())  
138 - val embeddingList = mutableListOf(embedding1a, embedding1b)  
139 - ok = manager.add(name = "s1", embedding=embeddingList.toTypedArray())  
140 - check(ok)  
141 -  
142 - name = manager.search(embedding=embedding1b, threshold=0.5f)  
143 - check(name == "s1")  
144 -  
145 - name = manager.search(embedding=embedding2a, threshold=0.5f)  
146 - check(name.length == 0)  
147 -  
148 - manager.release()  
149 -}  
150 -  
151 -fun testTts() {  
152 - // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models  
153 - // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2  
154 - var config = OfflineTtsConfig(  
155 - model=OfflineTtsModelConfig(  
156 - vits=OfflineTtsVitsModelConfig(  
157 - model="./vits-piper-en_US-amy-low/en_US-amy-low.onnx",  
158 - tokens="./vits-piper-en_US-amy-low/tokens.txt",  
159 - dataDir="./vits-piper-en_US-amy-low/espeak-ng-data",  
160 - ),  
161 - numThreads=1,  
162 - debug=true,  
163 - )  
164 - )  
165 - val tts = OfflineTts(config=config)  
166 - val audio = tts.generateWithCallback(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”", callback=::callback)  
167 - audio.save(filename="test-en.wav")  
168 -}  
169 -  
170 -fun testAsr(type: String) {  
171 - var featConfig = FeatureConfig(  
172 - sampleRate = 16000,  
173 - featureDim = 80,  
174 - )  
175 -  
176 - var waveFilename: String  
177 - var modelConfig: OnlineModelConfig = when (type) {  
178 - "transducer" -> {  
179 - waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav"  
180 - // please refer to  
181 - // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html  
182 - // to download pre-trained models
183 - OnlineModelConfig(  
184 - transducer = OnlineTransducerModelConfig(  
185 - encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",  
186 - decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",  
187 - joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",  
188 - ),  
189 - tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",  
190 - numThreads = 1,  
191 - debug = false,  
192 - )  
193 - }  
194 - "zipformer2-ctc" -> {  
195 - waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"  
196 - OnlineModelConfig(  
197 - zipformer2Ctc = OnlineZipformer2CtcModelConfig(  
198 - model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx",  
199 - ),  
200 - tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt",  
201 - numThreads = 1,  
202 - debug = false,  
203 - )  
204 - }  
205 - else -> throw IllegalArgumentException(type)  
206 - }  
207 -  
208 - var endpointConfig = EndpointConfig()  
209 -  
210 - var lmConfig = OnlineLMConfig()  
211 -  
212 - var config = OnlineRecognizerConfig(  
213 - modelConfig = modelConfig,  
214 - lmConfig = lmConfig,  
215 - featConfig = featConfig,  
216 - endpointConfig = endpointConfig,  
217 - enableEndpoint = true,  
218 - decodingMethod = "greedy_search",  
219 - maxActivePaths = 4,  
220 - )  
221 -  
222 - var model = SherpaOnnx(  
223 - config = config,  
224 - )  
225 -  
226 - var objArray = WaveReader.readWaveFromFile(  
227 - filename = waveFilename,  
228 - )  
229 - var samples: FloatArray = objArray[0] as FloatArray  
230 - var sampleRate: Int = objArray[1] as Int  
231 -  
232 - model.acceptWaveform(samples, sampleRate = sampleRate)  
233 - while (model.isReady()) {  
234 - model.decode()  
235 - }  
236 -  
237 - var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds  
238 - model.acceptWaveform(tailPaddings, sampleRate = sampleRate)  
239 - model.inputFinished()  
240 - while (model.isReady()) {  
241 - model.decode()  
242 - }  
243 -  
244 - println("results: ${model.text}")  
245 -}  
  1 +../sherpa-onnx/kotlin-api/OfflineRecognizer.kt
1 -../android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt  
  1 +../sherpa-onnx/kotlin-api/OfflineStream.kt
  1 +../sherpa-onnx/kotlin-api/OnlineRecognizer.kt
  1 +../sherpa-onnx/kotlin-api/OnlineStream.kt
1 -../android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt  
1 -../android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt  
1 -../android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt  
  1 +../sherpa-onnx/kotlin-api/Speaker.kt
1 -../android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/SpokenLanguageIdentification.kt  
  1 +../sherpa-onnx/kotlin-api/SpokenLanguageIdentification.kt
1 -../android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt  
  1 +../sherpa-onnx/kotlin-api/Vad.kt
1 -../android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt  
  1 +../sherpa-onnx/kotlin-api/WaveReader.kt
@@ -44,9 +44,23 @@ function testSpeakerEmbeddingExtractor() { @@ -44,9 +44,23 @@ function testSpeakerEmbeddingExtractor() {
44 if [ ! -f ./speaker2_a_cn_16k.wav ]; then 44 if [ ! -f ./speaker2_a_cn_16k.wav ]; then
45 curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav 45 curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
46 fi 46 fi
  47 +
  48 + out_filename=test_speaker_id.jar
  49 + kotlinc-jvm -include-runtime -d $out_filename \
  50 + test_speaker_id.kt \
  51 + OnlineStream.kt \
  52 + Speaker.kt \
  53 + WaveReader.kt \
  54 + faked-asset-manager.kt \
  55 + faked-log.kt
  56 +
  57 + ls -lh $out_filename
  58 +
  59 + java -Djava.library.path=../build/lib -jar $out_filename
47 } 60 }
48 61
49 -function testAsr() { 62 +
  63 +function testOnlineAsr() {
50 if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then 64 if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
51 git lfs install 65 git lfs install
52 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 66 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
@@ -57,6 +71,20 @@ function testAsr() { @@ -57,6 +71,20 @@ function testAsr() {
57 tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 71 tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
58 rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 72 rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
59 fi 73 fi
  74 +
  75 + out_filename=test_online_asr.jar
  76 + kotlinc-jvm -include-runtime -d $out_filename \
  77 + test_online_asr.kt \
  78 + FeatureConfig.kt \
  79 + OnlineRecognizer.kt \
  80 + OnlineStream.kt \
  81 + WaveReader.kt \
  82 + faked-asset-manager.kt \
  83 + faked-log.kt
  84 +
  85 + ls -lh $out_filename
  86 +
  87 + java -Djava.library.path=../build/lib -jar $out_filename
60 } 88 }
61 89
62 function testTts() { 90 function testTts() {
@@ -65,16 +93,42 @@ function testTts() { @@ -65,16 +93,42 @@ function testTts() {
65 tar xf vits-piper-en_US-amy-low.tar.bz2 93 tar xf vits-piper-en_US-amy-low.tar.bz2
66 rm vits-piper-en_US-amy-low.tar.bz2 94 rm vits-piper-en_US-amy-low.tar.bz2
67 fi 95 fi
  96 +
  97 + out_filename=test_tts.jar
  98 + kotlinc-jvm -include-runtime -d $out_filename \
  99 + test_tts.kt \
  100 + Tts.kt \
  101 + faked-asset-manager.kt \
  102 + faked-log.kt
  103 +
  104 + ls -lh $out_filename
  105 +
  106 + java -Djava.library.path=../build/lib -jar $out_filename
68 } 107 }
69 108
  109 +
70 function testAudioTagging() { 110 function testAudioTagging() {
71 if [ ! -d sherpa-onnx-zipformer-audio-tagging-2024-04-09 ]; then 111 if [ ! -d sherpa-onnx-zipformer-audio-tagging-2024-04-09 ]; then
72 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 112 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
73 tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 113 tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
74 rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 114 rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
75 fi 115 fi
  116 +
  117 + out_filename=test_audio_tagging.jar
  118 + kotlinc-jvm -include-runtime -d $out_filename \
  119 + test_audio_tagging.kt \
  120 + AudioTagging.kt \
  121 + OfflineStream.kt \
  122 + WaveReader.kt \
  123 + faked-asset-manager.kt \
  124 + faked-log.kt
  125 +
  126 + ls -lh $out_filename
  127 +
  128 + java -Djava.library.path=../build/lib -jar $out_filename
76 } 129 }
77 130
  131 +
78 function testSpokenLanguageIdentification() { 132 function testSpokenLanguageIdentification() {
79 if [ ! -f ./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx ]; then 133 if [ ! -f ./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx ]; then
80 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2 134 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
@@ -87,50 +141,44 @@ function testSpokenLanguageIdentification() { @@ -87,50 +141,44 @@ function testSpokenLanguageIdentification() {
87 tar xvf spoken-language-identification-test-wavs.tar.bz2 141 tar xvf spoken-language-identification-test-wavs.tar.bz2
88 rm spoken-language-identification-test-wavs.tar.bz2 142 rm spoken-language-identification-test-wavs.tar.bz2
89 fi 143 fi
90 -}  
91 -  
92 -function test() {  
93 - testSpokenLanguageIdentification  
94 - testAudioTagging  
95 - testSpeakerEmbeddingExtractor  
96 - testAsr  
97 - testTts  
98 -}  
99 -  
100 -test  
101 144
102 -kotlinc-jvm -include-runtime -d main.jar \  
103 - AudioTagging.kt \  
104 - Main.kt \  
105 - OfflineStream.kt \  
106 - SherpaOnnx.kt \  
107 - Speaker.kt \ 145 + out_filename=test_language_id.jar
  146 + kotlinc-jvm -include-runtime -d $out_filename \
  147 + test_language_id.kt \
108 SpokenLanguageIdentification.kt \ 148 SpokenLanguageIdentification.kt \
109 - Tts.kt \ 149 + OfflineStream.kt \
110 WaveReader.kt \ 150 WaveReader.kt \
111 faked-asset-manager.kt \ 151 faked-asset-manager.kt \
112 faked-log.kt 152 faked-log.kt
113 153
114 -ls -lh main.jar  
115 -  
116 -java -Djava.library.path=../build/lib -jar main.jar
  154 + ls -lh $out_filename
117 155
118 -function testTwoPass() {  
119 - if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then  
120 - curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2  
121 - tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2  
122 - rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2  
123 - fi
  156 + java -Djava.library.path=../build/lib -jar $out_filename
  157 +}
124 158
  159 +function testOfflineAsr() {
125 160 if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then
126 161 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
127 162 tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
128 163 rm sherpa-onnx-whisper-tiny.en.tar.bz2
129 164 fi
130 165
131 - kotlinc-jvm -include-runtime -d 2pass.jar test-2pass.kt WaveReader.kt SherpaOnnx2Pass.kt faked-asset-manager.kt  
132 - ls -lh 2pass.jar  
133 - java -Djava.library.path=../build/lib -jar 2pass.jar
  166 + out_filename=test_offline_asr.jar
  167 + kotlinc-jvm -include-runtime -d $out_filename \
  168 + test_offline_asr.kt \
  169 + FeatureConfig.kt \
  170 + OfflineRecognizer.kt \
  171 + OfflineStream.kt \
  172 + WaveReader.kt \
  173 + faked-asset-manager.kt
  174 +
  175 + ls -lh $out_filename
  176 + java -Djava.library.path=../build/lib -jar $out_filename
134 177 }
135 178
136 -testTwoPass
  179 +testSpeakerEmbeddingExtractor
  180 +testOnlineAsr
  181 +testTts
  182 +testAudioTagging
  183 +testSpokenLanguageIdentification
  184 +testOfflineAsr
1 -package com.k2fsa.sherpa.onnx  
2 -  
3 -fun main() {  
4 - test2Pass()  
5 -}  
6 -  
7 -fun test2Pass() {  
8 - val firstPass = createFirstPass()  
9 - val secondPass = createSecondPass()  
10 -  
11 - val waveFilename = "./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav"  
12 -  
13 - var objArray = WaveReader.readWaveFromFile(  
14 - filename = waveFilename,  
15 - )  
16 - var samples: FloatArray = objArray[0] as FloatArray  
17 - var sampleRate: Int = objArray[1] as Int  
18 -  
19 - firstPass.acceptWaveform(samples, sampleRate = sampleRate)  
20 - while (firstPass.isReady()) {  
21 - firstPass.decode()  
22 - }  
23 -  
24 - var text = firstPass.text  
25 - println("First pass text: $text")  
26 -  
27 - text = secondPass.decode(samples, sampleRate)  
28 - println("Second pass text: $text")  
29 -}  
30 -  
31 -fun createFirstPass(): SherpaOnnx {  
32 - val config = OnlineRecognizerConfig(  
33 - featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),  
34 - modelConfig = getModelConfig(type = 1)!!,  
35 - endpointConfig = getEndpointConfig(),  
36 - enableEndpoint = true,  
37 - )  
38 -  
39 - return SherpaOnnx(config = config)  
40 -}  
41 -  
42 -fun createSecondPass(): SherpaOnnxOffline {  
43 - val config = OfflineRecognizerConfig(  
44 - featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),  
45 - modelConfig = getOfflineModelConfig(type = 2)!!,  
46 - )  
47 -  
48 - return SherpaOnnxOffline(config = config)  
49 -}  
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +fun main() {
  4 + testAudioTagging()
  5 +}
  6 +
  7 +fun testAudioTagging() {
  8 + val config = AudioTaggingConfig(
  9 + model=AudioTaggingModelConfig(
  10 + zipformer=OfflineZipformerAudioTaggingModelConfig(
  11 + model="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx",
  12 + ),
  13 + numThreads=1,
  14 + debug=true,
  15 + provider="cpu",
  16 + ),
  17 + labels="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv",
  18 + topK=5,
  19 + )
  20 + val tagger = AudioTagging(config=config)
  21 +
  22 + val testFiles = arrayOf(
  23 + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav",
  24 + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/2.wav",
  25 + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/3.wav",
  26 + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/4.wav",
  27 + )
  28 + println("----------")
  29 + for (waveFilename in testFiles) {
  30 + val stream = tagger.createStream()
  31 +
  32 + val objArray = WaveReader.readWaveFromFile(
  33 + filename = waveFilename,
  34 + )
  35 + val samples: FloatArray = objArray[0] as FloatArray
  36 + val sampleRate: Int = objArray[1] as Int
  37 +
  38 + stream.acceptWaveform(samples, sampleRate = sampleRate)
  39 + val events = tagger.compute(stream)
  40 + stream.release()
  41 +
  42 + println(waveFilename)
  43 + println(events)
  44 + println("----------")
  45 + }
  46 +
  47 + tagger.release()
  48 +}
  49 +
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +fun main() {
  4 + testSpokenLanguageIdentification()
  5 +}
  6 +
  7 + fun testSpokenLanguageIdentification() {
  8 + val config = SpokenLanguageIdentificationConfig(
  9 + whisper = SpokenLanguageIdentificationWhisperConfig(
  10 + encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",
  11 + decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",
  12 + tailPaddings = 33,
  13 + ),
  14 + numThreads=1,
  15 + debug=true,
  16 + provider="cpu",
  17 + )
  18 + val slid = SpokenLanguageIdentification(config=config)
  19 +
  20 + val testFiles = arrayOf(
  21 + "./spoken-language-identification-test-wavs/ar-arabic.wav",
  22 + "./spoken-language-identification-test-wavs/bg-bulgarian.wav",
  23 + "./spoken-language-identification-test-wavs/de-german.wav",
  24 + )
  25 +
  26 + for (waveFilename in testFiles) {
  27 + val objArray = WaveReader.readWaveFromFile(
  28 + filename = waveFilename,
  29 + )
  30 + val samples: FloatArray = objArray[0] as FloatArray
  31 + val sampleRate: Int = objArray[1] as Int
  32 +
  33 + val stream = slid.createStream()
  34 + stream.acceptWaveform(samples, sampleRate = sampleRate)
  35 + val lang = slid.compute(stream)
  36 + stream.release()
  37 + println(waveFilename)
  38 + println(lang)
  39 + }
  40 +
  41 + slid.release()
  42 +}
  43 +
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +fun main() {
  4 + val recognizer = createOfflineRecognizer()
  5 +
  6 + val waveFilename = "./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav"
  7 +
  8 + val objArray = WaveReader.readWaveFromFile(
  9 + filename = waveFilename,
  10 + )
  11 + val samples: FloatArray = objArray[0] as FloatArray
  12 + val sampleRate: Int = objArray[1] as Int
  13 +
  14 + val stream = recognizer.createStream()
  15 + stream.acceptWaveform(samples, sampleRate=sampleRate)
  16 + recognizer.decode(stream)
  17 +
  18 + val result = recognizer.getResult(stream)
  19 + println(result)
  20 +
  21 + stream.release()
  22 + recognizer.release()
  23 +}
  24 +
  25 +fun createOfflineRecognizer(): OfflineRecognizer {
  26 + val config = OfflineRecognizerConfig(
  27 + featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
  28 + modelConfig = getOfflineModelConfig(type = 2)!!,
  29 + )
  30 +
  31 + return OfflineRecognizer(config = config)
  32 +}
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +fun main() {
  4 + testOnlineAsr("transducer")
  5 + testOnlineAsr("zipformer2-ctc")
  6 +}
  7 +
  8 +fun testOnlineAsr(type: String) {
  9 + val featConfig = FeatureConfig(
  10 + sampleRate = 16000,
  11 + featureDim = 80,
  12 + )
  13 +
  14 + val waveFilename: String
  15 + val modelConfig: OnlineModelConfig = when (type) {
  16 + "transducer" -> {
  17 + waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav"
  18 + // please refer to
  19 + // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  20 + // to download pre-trained models
  21 + OnlineModelConfig(
  22 + transducer = OnlineTransducerModelConfig(
  23 + encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
  24 + decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
  25 + joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
  26 + ),
  27 + tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
  28 + numThreads = 1,
  29 + debug = false,
  30 + )
  31 + }
  32 + "zipformer2-ctc" -> {
  33 + waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"
  34 + OnlineModelConfig(
  35 + zipformer2Ctc = OnlineZipformer2CtcModelConfig(
  36 + model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx",
  37 + ),
  38 + tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt",
  39 + numThreads = 1,
  40 + debug = false,
  41 + )
  42 + }
  43 + else -> throw IllegalArgumentException(type)
  44 + }
  45 +
  46 + val endpointConfig = EndpointConfig()
  47 +
  48 + val lmConfig = OnlineLMConfig()
  49 +
  50 + val config = OnlineRecognizerConfig(
  51 + modelConfig = modelConfig,
  52 + lmConfig = lmConfig,
  53 + featConfig = featConfig,
  54 + endpointConfig = endpointConfig,
  55 + enableEndpoint = true,
  56 + decodingMethod = "greedy_search",
  57 + maxActivePaths = 4,
  58 + )
  59 +
  60 + val recognizer = OnlineRecognizer(
  61 + config = config,
  62 + )
  63 +
  64 + val objArray = WaveReader.readWaveFromFile(
  65 + filename = waveFilename,
  66 + )
  67 + val samples: FloatArray = objArray[0] as FloatArray
  68 + val sampleRate: Int = objArray[1] as Int
  69 +
  70 + val stream = recognizer.createStream()
  71 + stream.acceptWaveform(samples, sampleRate = sampleRate)
  72 + while (recognizer.isReady(stream)) {
  73 + recognizer.decode(stream)
  74 + }
  75 +
  76 + val tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
  77 + stream.acceptWaveform(tailPaddings, sampleRate = sampleRate)
  78 + stream.inputFinished()
  79 + while (recognizer.isReady(stream)) {
  80 + recognizer.decode(stream)
  81 + }
  82 +
  83 + println("results: ${recognizer.getResult(stream).text}")
  84 +
  85 + stream.release()
  86 + recognizer.release()
  87 +}
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +fun main() {
  4 + testSpeakerRecognition()
  5 +}
  6 +
  7 +fun testSpeakerRecognition() {
  8 + val config = SpeakerEmbeddingExtractorConfig(
  9 + model="./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx",
  10 + )
  11 + val extractor = SpeakerEmbeddingExtractor(config = config)
  12 +
  13 + val embedding1a = computeEmbedding(extractor, "./speaker1_a_cn_16k.wav")
  14 + val embedding2a = computeEmbedding(extractor, "./speaker2_a_cn_16k.wav")
  15 + val embedding1b = computeEmbedding(extractor, "./speaker1_b_cn_16k.wav")
  16 +
  17 + var manager = SpeakerEmbeddingManager(extractor.dim())
  18 + var ok = manager.add(name = "speaker1", embedding=embedding1a)
  19 + check(ok)
  20 +
  21 + ok = manager.add(name = "speaker2", embedding=embedding2a)
  22 + check(ok)
  23 +
  24 + var name = manager.search(embedding=embedding1b, threshold=0.5f)
  25 + check(name == "speaker1")
  26 +
  27 + manager.release()
  28 +
  29 + manager = SpeakerEmbeddingManager(extractor.dim())
  30 + val embeddingList = mutableListOf(embedding1a, embedding1b)
  31 + ok = manager.add(name = "s1", embedding=embeddingList.toTypedArray())
  32 + check(ok)
  33 +
  34 + name = manager.search(embedding=embedding1b, threshold=0.5f)
  35 + check(name == "s1")
  36 +
  37 + name = manager.search(embedding=embedding2a, threshold=0.5f)
  38 + check(name.isEmpty())
  39 +
  40 + manager.release()
  41 + extractor.release()
  42 + println("Speaker ID test done!")
  43 +}
  44 +
  45 +fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray {
  46 + val objArray = WaveReader.readWaveFromFile(
  47 + filename = filename,
  48 + )
  49 + val samples: FloatArray = objArray[0] as FloatArray
  50 + val sampleRate: Int = objArray[1] as Int
  51 +
  52 + val stream = extractor.createStream()
  53 + stream.acceptWaveform(sampleRate = sampleRate, samples=samples)
  54 + stream.inputFinished()
  55 + check(extractor.isReady(stream))
  56 +
  57 + val embedding = extractor.compute(stream)
  58 +
  59 + stream.release()
  60 +
  61 + return embedding
  62 +}
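
For reference, SpeakerEmbeddingManager.search(threshold = 0.5f) above matches an embedding against the registered speakers by a similarity score. A minimal pure-Kotlin sketch of cosine similarity between two embeddings returned by computeEmbedding(); the helper is illustrative only and assumes a cosine-style score, which this test does not itself verify:

import kotlin.math.sqrt

fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
    require(a.size == b.size) { "embedding dims differ: ${a.size} vs ${b.size}" }
    var dot = 0.0f
    var normA = 0.0f
    var normB = 0.0f
    for (i in a.indices) {
        dot += a[i] * b[i]
        normA += a[i] * a[i]
        normB += b[i] * b[i]
    }
    // Guard against zero vectors to avoid dividing by zero.
    if (normA == 0.0f || normB == 0.0f) return 0.0f
    return dot / (sqrt(normA) * sqrt(normB))
}

// e.g., cosineSimilarity(embedding1a, embedding1b) is expected to be high
// (same speaker), while cosineSimilarity(embedding1a, embedding2a) is low.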
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +fun main() {
  4 + testTts()
  5 +}
  6 +
  7 +fun testTts() {
  8 + // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
  9 + // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
  10 + val config = OfflineTtsConfig(
  11 + model=OfflineTtsModelConfig(
  12 + vits=OfflineTtsVitsModelConfig(
  13 + model="./vits-piper-en_US-amy-low/en_US-amy-low.onnx",
  14 + tokens="./vits-piper-en_US-amy-low/tokens.txt",
  15 + dataDir="./vits-piper-en_US-amy-low/espeak-ng-data",
  16 + ),
  17 + numThreads=1,
  18 + debug=true,
  19 + )
  20 + )
  21 + val tts = OfflineTts(config=config)
  22 + val audio = tts.generateWithCallback(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”", callback=::callback)
  23 + audio.save(filename="test-en.wav")
  24 + tts.release()
  25 + println("Saved to test-en.wav")
  26 +}
  27 +
  28 +fun callback(samples: FloatArray): Unit {
  29 + println("callback got called with ${samples.size} samples")
  30 +}
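
The callback above only logs each chunk. Since generateWithCallback delivers samples incrementally, the same FloatArray callback signature can also buffer audio as it is produced, e.g. for streaming playback. A small illustrative sketch (the collector class is ours, not part of the sherpa-onnx API):

// Accumulates generated chunks as they arrive; a hypothetical helper.
class SampleCollector {
    private val chunks = mutableListOf<FloatArray>()

    fun onSamples(samples: FloatArray) {
        chunks.add(samples.copyOf()) // copy in case the buffer is reused upstream
        println("received ${samples.size} samples, ${chunks.size} chunks so far")
    }

    // Concatenate all chunks received so far into a single array.
    fun all(): FloatArray {
        val out = FloatArray(chunks.sumOf { it.size })
        var offset = 0
        for (c in chunks) {
            c.copyInto(out, offset)
            offset += c.size
        }
        return out
    }
}

// Usage with the same API as shown above:
//   val collector = SampleCollector()
//   tts.generateWithCallback(text = "...", callback = collector::onSamples)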
  1 +#!/usr/bin/env bash
  2 +#
  3 +# Auto generated! Please DO NOT EDIT!
  4 +
  5 +# Please set the environment variable ANDROID_NDK
  6 +# before running this script
  7 +
  8 +# Inside the $ANDROID_NDK directory, you can find a binary ndk-build
  9 +# and some other files like the file "build/cmake/android.toolchain.cmake"
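   +#
   +# For example (the path below is illustrative; adjust to your NDK install):
   +#   export ANDROID_NDK=$HOME/Android/Sdk/ndk/26.1.10909125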
  10 +
  11 +set -ex
  12 +
  13 +log() {
  14 + # This function is from espnet
  15 + local fname=${BASH_SOURCE[1]##*/}
  16 + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
  17 +}
  18 +
  19 +SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
  20 +
  21 +log "Building streaming ASR APK for sherpa-onnx v${SHERPA_ONNX_VERSION}"
  22 +
  23 +export SHERPA_ONNX_ENABLE_TTS=OFF
  24 +
  25 +log "====================arm64-v8a================="
  26 +./build-android-arm64-v8a.sh
  27 +log "====================armv7-eabi================"
  28 +./build-android-armv7-eabi.sh
  29 +log "====================x86-64===================="
  30 +./build-android-x86-64.sh
  31 +log "====================x86===================="
  32 +./build-android-x86.sh
  33 +
  34 +mkdir -p apks
  35 +
  36 +{% for model in model_list %}
  37 +pushd ./android/SherpaOnnx/app/src/main/assets/
  38 +model_name={{ model.model_name }}
  39 +type={{ model.idx }}
  40 +lang={{ model.lang }}
  41 +short_name={{ model.short_name }}
  42 +
  43 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/${model_name}.tar.bz2
  44 +tar xvf ${model_name}.tar.bz2
  45 +
  46 +{{ model.cmd }}
  47 +
  48 +rm -rf *.tar.bz2
  49 +ls -lh $model_name
  50 +
  51 +popd
  52 +# Now we are at the project root directory
  53 +
  54 +git checkout .
  55 +pushd android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx
  56 +sed -i.bak s/"type = 0/type = $type/" ./MainActivity.kt
  57 +git diff
  58 +popd
  59 +
  60 +for arch in arm64-v8a armeabi-v7a x86_64 x86; do
  61 + log "------------------------------------------------------------"
  62 + log "build ASR apk for $arch"
  63 + log "------------------------------------------------------------"
  64 + src_arch=$arch
  65 + if [ $arch == "armeabi-v7a" ]; then
  66 + src_arch=armv7-eabi
  67 + elif [ $arch == "x86_64" ]; then
  68 + src_arch=x86-64
  69 + fi
  70 +
  71 + ls -lh ./build-android-$src_arch/install/lib/*.so
  72 +
  73 + cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnx/app/src/main/jniLibs/$arch/
  74 +
  75 + pushd ./android/SherpaOnnx
  76 + sed -i.bak s/2048/9012/g ./gradle.properties
  77 + git diff ./gradle.properties
  78 + ./gradlew assembleRelease
  79 + popd
  80 +
  81 + mv android/SherpaOnnx/app/build/outputs/apk/release/app-release-unsigned.apk ./apks/sherpa-onnx-${SHERPA_ONNX_VERSION}-$arch-asr-$lang-$short_name.apk
  82 + ls -lh apks
  83 + rm -v ./android/SherpaOnnx/app/src/main/jniLibs/$arch/*.so
  84 +done
  85 +
  86 +rm -rf ./android/SherpaOnnx/app/src/main/assets/$model_name
  87 +{% endfor %}
  88 +
  89 +git checkout .
  90 +
  91 +ls -lh apks/
@@ -29,6 +29,8 @@ log "====================x86-64===================="
29 29 log "====================x86===================="
30 30 ./build-android-x86.sh
31 31
  32 +export SHERPA_ONNX_ENABLE_TTS=OFF
  33 +
32 34 mkdir -p apks
33 35
34 36 {% for model in model_list %}
@@ -29,6 +29,8 @@ log "====================x86-64===================="
29 29 log "====================x86===================="
30 30 ./build-android-x86.sh
31 31
  32 +export SHERPA_ONNX_ENABLE_TTS=OFF
  33 +
32 34 mkdir -p apks
33 35
34 36 {% for model in model_list %}
@@ -29,6 +29,8 @@ log "====================x86-64===================="
29 29 log "====================x86===================="
30 30 ./build-android-x86.sh
31 31
  32 +export SHERPA_ONNX_ENABLE_TTS=OFF
  33 +
32 34 mkdir -p apks
33 35
34 36 {% for model in model_list %}
@@ -29,6 +29,8 @@ log "====================x86-64===================="
29 29 log "====================x86===================="
30 30 ./build-android-x86.sh
31 31
  32 +export SHERPA_ONNX_ENABLE_TTS=OFF
  33 +
32 34 mkdir -p apks
33 35
34 36 {% for model in model_list %}
@@ -29,6 +29,8 @@ log "====================x86-64===================="
29 29 log "====================x86===================="
30 30 ./build-android-x86.sh
31 31
  32 +export SHERPA_ONNX_ENABLE_TTS=ON
  33 +
32 34 mkdir -p apks
33 35
34 36 {% for tts_model in tts_model_list %}
@@ -29,6 +29,8 @@ log "====================x86-64===================="
29 29 log "====================x86===================="
30 30 ./build-android-x86.sh
31 31
  32 +export SHERPA_ONNX_ENABLE_TTS=ON
  33 +
32 34 mkdir -p apks
33 35
34 36 {% for tts_model in tts_model_list %}
  1 +#!/usr/bin/env python3
  2 +
  3 +import argparse
  4 +from dataclasses import dataclass
  5 +from typing import List, Optional
  6 +
  7 +import jinja2
  8 +
  9 +
  10 +def get_args():
  11 + parser = argparse.ArgumentParser()
  12 + parser.add_argument(
  13 + "--total",
  14 + type=int,
  15 + default=1,
  16 + help="Number of runners",
  17 + )
  18 + parser.add_argument(
  19 + "--index",
  20 + type=int,
  21 + default=0,
  22 + help="Index of the current runner",
  23 + )
  24 + return parser.parse_args()
  25 +
  26 +
  27 +@dataclass
  28 +class Model:
  29 + # We will download
  30 + # https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/{model_name}.tar.bz2
  31 + model_name: str
  32 +
  33 + # The type of the model, e.g., 0, 1, 2. It is hardcoded in the Kotlin code
  34 + idx: int
  35 +
  36 + # e.g., zh, en, zh_en
  37 + lang: str
  38 +
  39 + # e.g., whisper, paraformer, zipformer
  40 + short_name: str = ""
  41 +
  42 + # cmd is used to remove extra files from the model directory
  43 + cmd: str = ""
  44 +
  45 +
  46 +def get_models():
  47 + models = [
  48 + Model(
  49 + model_name="sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20",
  50 + idx=8,
  51 + lang="bilingual_zh_en",
  52 + short_name="zipformer",
  53 + cmd="""
  54 + pushd $model_name
  55 + rm -v decoder-epoch-99-avg-1.int8.onnx
  56 + rm -v encoder-epoch-99-avg-1.onnx
  57 + rm -v joiner-epoch-99-avg-1.onnx
  58 +
  59 + rm -v *.sh
  60 + rm -v .gitattributes
  61 + rm -v *state*
  62 + rm -rfv test_wavs
  63 +
  64 + ls -lh
  65 +
  66 + popd
  67 + """,
  68 + ),
  69 + ]
  70 +
  71 + return models
  72 +
  73 +
  74 +def main():
  75 + args = get_args()
  76 + index = args.index
  77 + total = args.total
  78 + assert 0 <= index < total, (index, total)
  79 +
  80 + all_model_list = get_models()
  81 +
  82 + num_models = len(all_model_list)
  83 +
  84 + num_per_runner = num_models // total
  85 + if num_per_runner <= 0:
  86 + raise ValueError(f"num_models: {num_models}, num_runners: {total}")
  87 +
  88 + start = index * num_per_runner
  89 + end = start + num_per_runner
  90 +
  91 + remaining = num_models - args.total * num_per_runner
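   + # For example (illustrative numbers): with num_models=5 and total=2,
   + # num_per_runner=2 and remaining=1, so runner 0 renders models 0-1 plus
   + # the leftover model 4, while runner 1 renders models 2-3.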
  92 +
  93 + print(f"{index}/{total}: {start}-{end}/{num_models}")
  94 +
  95 + d = dict()
  96 + d["model_list"] = all_model_list[start:end]
  97 + if index < remaining:
  98 + s = args.total * num_per_runner + index
  99 + d["model_list"].append(all_model_list[s])
  100 + print(f"{s}/{num_models}")
  101 +
  102 + filename_list = [
  103 + "./build-apk-asr.sh",
  104 + ]
  105 + for filename in filename_list:
  106 + environment = jinja2.Environment()
  107 + with open(f"{filename}.in") as f:
  108 + s = f.read()
  109 + template = environment.from_string(s)
  110 +
  111 + s = template.render(**d)
  112 + with open(filename, "w") as f:
  113 + print(s, file=f)
  114 +
  115 +
  116 +if __name__ == "__main__":
  117 + main()
@@ -82,7 +82,7 @@ bool OfflineTtsVitsModelConfig::Validate() const {
82 82
83 83 for (const auto &f : required_files) {
84 84 if (!FileExists(dict_dir + "/" + f)) {
85 - SHERPA_ONNX_LOGE("'%s/%s' does not exist.", data_dir.c_str(),
  85 + SHERPA_ONNX_LOGE("'%s/%s' does not exist.", dict_dir.c_str(),
86 86 f.c_str());
87 87 return false;
88 88 }
@@ -12,8 +12,15 @@ endif()
12 12 set(sources
13 13 audio-tagging.cc
14 14 jni.cc
  15 + keyword-spotter.cc
  16 + offline-recognizer.cc
15 17 offline-stream.cc
  18 + online-recognizer.cc
  19 + online-stream.cc
  20 + speaker-embedding-extractor.cc
  21 + speaker-embedding-manager.cc
16 22 spoken-language-identification.cc
  23 + voice-activity-detector.cc
17 24 )
18 25
19 26 if(SHERPA_ONNX_ENABLE_TTS)
@@ -6,6 +6,8 @@
6 6 #define SHERPA_ONNX_JNI_COMMON_H_
7 7
8 8 #if __ANDROID_API__ >= 9
  9 +#include <strstream>
  10 +
9 11 #include "android/asset_manager.h"
10 12 #include "android/asset_manager_jni.h"
11 13 #endif
@@ -4,1530 +4,43 @@
4 4 // 2022 Pingfeng Luo
5 5 // 2023 Zhaoming
6 6
7 -// TODO(fangjun): Add documentation to functions/methods in this file  
8 -// and also show how to use them with kotlin, possibly with java.  
9 -  
10 -#include <fstream>  
11 -#include <functional>  
12 -#include <strstream>  
13 -#include <utility>  
14 -  
15 -#include "sherpa-onnx/csrc/keyword-spotter.h"  
16 -#include "sherpa-onnx/csrc/macros.h"  
17 -#include "sherpa-onnx/csrc/offline-recognizer.h"  
18 -#include "sherpa-onnx/csrc/online-recognizer.h"  
19 -#include "sherpa-onnx/csrc/onnx-utils.h"  
20 -#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"  
21 -#include "sherpa-onnx/csrc/speaker-embedding-manager.h"  
22 -#include "sherpa-onnx/csrc/voice-activity-detector.h"  
23 -#include "sherpa-onnx/csrc/wave-reader.h"  
24 -#include "sherpa-onnx/csrc/wave-writer.h"  
25 -#include "sherpa-onnx/jni/common.h"  
26 -  
27 -namespace sherpa_onnx {  
28 -  
29 -class SherpaOnnx {  
30 - public:  
31 -#if __ANDROID_API__ >= 9  
32 - SherpaOnnx(AAssetManager *mgr, const OnlineRecognizerConfig &config)  
33 - : recognizer_(mgr, config), stream_(recognizer_.CreateStream()) {}  
34 -#endif  
35 -  
36 - explicit SherpaOnnx(const OnlineRecognizerConfig &config)  
37 - : recognizer_(config), stream_(recognizer_.CreateStream()) {}  
38 -  
39 - void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {  
40 - if (input_sample_rate_ == -1) {  
41 - input_sample_rate_ = sample_rate;  
42 - }  
43 -  
44 - stream_->AcceptWaveform(sample_rate, samples, n);  
45 - }  
46 -  
47 - void InputFinished() const {  
48 - std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);  
49 - stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),  
50 - tail_padding.size());  
51 - stream_->InputFinished();  
52 - }  
53 -  
54 - std::string GetText() const {  
55 - auto result = recognizer_.GetResult(stream_.get());  
56 - return result.text;  
57 - }  
58 -  
59 - const std::vector<std::string> GetTokens() const {  
60 - auto result = recognizer_.GetResult(stream_.get());  
61 - return result.tokens;  
62 - }  
63 -  
64 - bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }  
65 -  
66 - bool IsReady() const { return recognizer_.IsReady(stream_.get()); }  
67 -  
68 - // If keywords is an empty string, it just recreates the decoding stream  
69 - // If keywords is not empty, it will create a new decoding stream with  
70 - // the given keywords appended to the default keywords.  
71 - void Reset(bool recreate, const std::string &keywords = {}) {  
72 - if (keywords.empty()) {  
73 - if (recreate) {  
74 - stream_ = recognizer_.CreateStream();  
75 - } else {  
76 - recognizer_.Reset(stream_.get());  
77 - }  
78 - } else {  
79 - auto stream = recognizer_.CreateStream(keywords);  
80 - // Set new keywords failed, the stream_ will not be updated.  
81 - if (stream != nullptr) {  
82 - stream_ = std::move(stream);  
83 - } else {  
84 - SHERPA_ONNX_LOGE("Failed to set keywords: %s", keywords.c_str());  
85 - }  
86 - }  
87 - }  
88 -  
89 - void Decode() const { recognizer_.DecodeStream(stream_.get()); }  
90 -  
91 - private:  
92 - OnlineRecognizer recognizer_;  
93 - std::unique_ptr<OnlineStream> stream_;  
94 - int32_t input_sample_rate_ = -1;  
95 -};  
96 -  
97 -class SherpaOnnxOffline {  
98 - public:  
99 -#if __ANDROID_API__ >= 9  
100 - SherpaOnnxOffline(AAssetManager *mgr, const OfflineRecognizerConfig &config)  
101 - : recognizer_(mgr, config) {}  
102 -#endif  
103 -  
104 - explicit SherpaOnnxOffline(const OfflineRecognizerConfig &config)  
105 - : recognizer_(config) {}  
106 -  
107 - std::string Decode(int32_t sample_rate, const float *samples, int32_t n) {  
108 - auto stream = recognizer_.CreateStream();  
109 - stream->AcceptWaveform(sample_rate, samples, n);  
110 -  
111 - recognizer_.DecodeStream(stream.get());  
112 - return stream->GetResult().text;  
113 - }  
114 -  
115 - private:  
116 - OfflineRecognizer recognizer_;  
117 -};  
118 -  
119 -class SherpaOnnxVad {  
120 - public:  
121 -#if __ANDROID_API__ >= 9  
122 - SherpaOnnxVad(AAssetManager *mgr, const VadModelConfig &config)  
123 - : vad_(mgr, config) {}  
124 -#endif  
125 -  
126 - explicit SherpaOnnxVad(const VadModelConfig &config) : vad_(config) {}  
127 -  
128 - void AcceptWaveform(const float *samples, int32_t n) {  
129 - vad_.AcceptWaveform(samples, n);  
130 - }  
131 -  
132 - bool Empty() const { return vad_.Empty(); }  
133 -  
134 - void Pop() { vad_.Pop(); }  
135 -  
136 - void Clear() { vad_.Clear(); }  
137 -  
138 - const SpeechSegment &Front() const { return vad_.Front(); }  
139 -  
140 - bool IsSpeechDetected() const { return vad_.IsSpeechDetected(); }  
141 -  
142 - void Reset() { vad_.Reset(); }  
143 -  
144 - private:  
145 - VoiceActivityDetector vad_;  
146 -};  
147 -  
148 -class SherpaOnnxKws {  
149 - public:  
150 -#if __ANDROID_API__ >= 9  
151 - SherpaOnnxKws(AAssetManager *mgr, const KeywordSpotterConfig &config)  
152 - : keyword_spotter_(mgr, config),  
153 - stream_(keyword_spotter_.CreateStream()) {}  
154 -#endif  
155 -  
156 - explicit SherpaOnnxKws(const KeywordSpotterConfig &config)  
157 - : keyword_spotter_(config), stream_(keyword_spotter_.CreateStream()) {}  
158 -  
159 - void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {  
160 - if (input_sample_rate_ == -1) {  
161 - input_sample_rate_ = sample_rate;  
162 - }  
163 -  
164 - stream_->AcceptWaveform(sample_rate, samples, n);  
165 - }  
166 -  
167 - void InputFinished() const {  
168 - std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);  
169 - stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),  
170 - tail_padding.size());  
171 - stream_->InputFinished();  
172 - }  
173 -  
174 - // If keywords is an empty string, it just recreates the decoding stream  
175 - // always returns true in this case.  
176 - // If keywords is not empty, it will create a new decoding stream with  
177 - // the given keywords appended to the default keywords.  
178 - // Return false if errors occurred when adding keywords, true otherwise.  
179 - bool Reset(const std::string &keywords = {}) {  
180 - if (keywords.empty()) {  
181 - stream_ = keyword_spotter_.CreateStream();  
182 - return true;  
183 - } else {  
184 - auto stream = keyword_spotter_.CreateStream(keywords);  
185 - // Set new keywords failed, the stream_ will not be updated.  
186 - if (stream == nullptr) {  
187 - return false;  
188 - } else {  
189 - stream_ = std::move(stream);  
190 - return true;  
191 - }  
192 - }  
193 - }  
194 -  
195 - std::string GetKeyword() const {  
196 - auto result = keyword_spotter_.GetResult(stream_.get());  
197 - return result.keyword;  
198 - }  
199 -  
200 - std::vector<std::string> GetTokens() const {  
201 - auto result = keyword_spotter_.GetResult(stream_.get());  
202 - return result.tokens;  
203 - }  
204 -  
205 - bool IsReady() const { return keyword_spotter_.IsReady(stream_.get()); }  
206 -  
207 - void Decode() const { keyword_spotter_.DecodeStream(stream_.get()); }  
208 -  
209 - private:  
210 - KeywordSpotter keyword_spotter_;  
211 - std::unique_ptr<OnlineStream> stream_;  
212 - int32_t input_sample_rate_ = -1;  
213 -};  
214 -  
215 -class SherpaOnnxSpeakerEmbeddingExtractorStream {  
216 - public:  
217 - explicit SherpaOnnxSpeakerEmbeddingExtractorStream(  
218 - std::unique_ptr<OnlineStream> stream)  
219 - : stream_(std::move(stream)) {}  
220 -  
221 - void AcceptWaveform(int32_t sample_rate, const float *samples,  
222 - int32_t n) const {  
223 - stream_->AcceptWaveform(sample_rate, samples, n);  
224 - }  
225 -  
226 - void InputFinished() const { stream_->InputFinished(); }  
227 -  
228 - OnlineStream *Get() const { return stream_.get(); }  
229 -  
230 - private:  
231 - std::unique_ptr<OnlineStream> stream_;  
232 -};  
233 -  
234 -class SherpaOnnxSpeakerEmbeddingExtractor {  
235 - public:  
236 -#if __ANDROID_API__ >= 9  
237 - SherpaOnnxSpeakerEmbeddingExtractor(  
238 - AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)  
239 - : extractor_(mgr, config) {}  
240 -#endif  
241 -  
242 - explicit SherpaOnnxSpeakerEmbeddingExtractor(  
243 - const SpeakerEmbeddingExtractorConfig &config)  
244 - : extractor_(config) {}  
245 -  
246 - int32_t Dim() const { return extractor_.Dim(); }  
247 -  
248 - bool IsReady(const SherpaOnnxSpeakerEmbeddingExtractorStream *stream) const {  
249 - return extractor_.IsReady(stream->Get());  
250 - }  
251 -  
252 - SherpaOnnxSpeakerEmbeddingExtractorStream *CreateStream() const {  
253 - return new SherpaOnnxSpeakerEmbeddingExtractorStream(  
254 - extractor_.CreateStream());  
255 - }  
256 -  
257 - std::vector<float> Compute(  
258 - const SherpaOnnxSpeakerEmbeddingExtractorStream *stream) const {  
259 - return extractor_.Compute(stream->Get());  
260 - }  
261 -  
262 - private:  
263 - SpeakerEmbeddingExtractor extractor_;  
264 -};  
265 -  
266 -static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig(  
267 - JNIEnv *env, jobject config) {  
268 - SpeakerEmbeddingExtractorConfig ans;  
269 -  
270 - jclass cls = env->GetObjectClass(config);  
271 -  
272 - jfieldID fid = env->GetFieldID(cls, "model", "Ljava/lang/String;");  
273 - jstring s = (jstring)env->GetObjectField(config, fid);  
274 - const char *p = env->GetStringUTFChars(s, nullptr);  
275 -  
276 - ans.model = p;  
277 - env->ReleaseStringUTFChars(s, p);  
278 -  
279 - fid = env->GetFieldID(cls, "numThreads", "I");  
280 - ans.num_threads = env->GetIntField(config, fid);  
281 -  
282 - fid = env->GetFieldID(cls, "debug", "Z");  
283 - ans.debug = env->GetBooleanField(config, fid);  
284 -  
285 - fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");  
286 - s = (jstring)env->GetObjectField(config, fid);  
287 - p = env->GetStringUTFChars(s, nullptr);  
288 - ans.provider = p;  
289 - env->ReleaseStringUTFChars(s, p);  
290 -  
291 - return ans;  
292 -}  
293 -  
294 -static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {  
295 - OnlineRecognizerConfig ans;  
296 -  
297 - jclass cls = env->GetObjectClass(config);  
298 - jfieldID fid;  
299 -  
300 - // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html  
301 - // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html  
302 -  
303 - //---------- decoding ----------  
304 - fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");  
305 - jstring s = (jstring)env->GetObjectField(config, fid);  
306 - const char *p = env->GetStringUTFChars(s, nullptr);  
307 - ans.decoding_method = p;  
308 - env->ReleaseStringUTFChars(s, p);  
309 -  
310 - fid = env->GetFieldID(cls, "maxActivePaths", "I");  
311 - ans.max_active_paths = env->GetIntField(config, fid);  
312 -  
313 - fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");  
314 - s = (jstring)env->GetObjectField(config, fid);  
315 - p = env->GetStringUTFChars(s, nullptr);  
316 - ans.hotwords_file = p;  
317 - env->ReleaseStringUTFChars(s, p);  
318 -  
319 - fid = env->GetFieldID(cls, "hotwordsScore", "F");  
320 - ans.hotwords_score = env->GetFloatField(config, fid);  
321 -  
322 - //---------- feat config ----------  
323 - fid = env->GetFieldID(cls, "featConfig",  
324 - "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");  
325 - jobject feat_config = env->GetObjectField(config, fid);  
326 - jclass feat_config_cls = env->GetObjectClass(feat_config);  
327 -  
328 - fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");  
329 - ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);  
330 -  
331 - fid = env->GetFieldID(feat_config_cls, "featureDim", "I");  
332 - ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);  
333 -  
334 - //---------- enable endpoint ----------  
335 - fid = env->GetFieldID(cls, "enableEndpoint", "Z");  
336 - ans.enable_endpoint = env->GetBooleanField(config, fid);  
337 -  
338 - //---------- endpoint_config ----------  
339 -  
340 - fid = env->GetFieldID(cls, "endpointConfig",  
341 - "Lcom/k2fsa/sherpa/onnx/EndpointConfig;");  
342 - jobject endpoint_config = env->GetObjectField(config, fid);  
343 - jclass endpoint_config_cls = env->GetObjectClass(endpoint_config);  
344 -  
345 - fid = env->GetFieldID(endpoint_config_cls, "rule1",  
346 - "Lcom/k2fsa/sherpa/onnx/EndpointRule;");  
347 - jobject rule1 = env->GetObjectField(endpoint_config, fid);  
348 - jclass rule_class = env->GetObjectClass(rule1);  
349 -  
350 - fid = env->GetFieldID(endpoint_config_cls, "rule2",  
351 - "Lcom/k2fsa/sherpa/onnx/EndpointRule;");  
352 - jobject rule2 = env->GetObjectField(endpoint_config, fid);  
353 -  
354 - fid = env->GetFieldID(endpoint_config_cls, "rule3",  
355 - "Lcom/k2fsa/sherpa/onnx/EndpointRule;");  
356 - jobject rule3 = env->GetObjectField(endpoint_config, fid);  
357 -  
358 - fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z");  
359 - ans.endpoint_config.rule1.must_contain_nonsilence =  
360 - env->GetBooleanField(rule1, fid);  
361 - ans.endpoint_config.rule2.must_contain_nonsilence =  
362 - env->GetBooleanField(rule2, fid);  
363 - ans.endpoint_config.rule3.must_contain_nonsilence =  
364 - env->GetBooleanField(rule3, fid);  
365 -  
366 - fid = env->GetFieldID(rule_class, "minTrailingSilence", "F");  
367 - ans.endpoint_config.rule1.min_trailing_silence =  
368 - env->GetFloatField(rule1, fid);  
369 - ans.endpoint_config.rule2.min_trailing_silence =  
370 - env->GetFloatField(rule2, fid);  
371 - ans.endpoint_config.rule3.min_trailing_silence =  
372 - env->GetFloatField(rule3, fid);  
373 -  
374 - fid = env->GetFieldID(rule_class, "minUtteranceLength", "F");  
375 - ans.endpoint_config.rule1.min_utterance_length =  
376 - env->GetFloatField(rule1, fid);  
377 - ans.endpoint_config.rule2.min_utterance_length =  
378 - env->GetFloatField(rule2, fid);  
379 - ans.endpoint_config.rule3.min_utterance_length =  
380 - env->GetFloatField(rule3, fid);  
381 -  
382 - //---------- model config ----------  
383 - fid = env->GetFieldID(cls, "modelConfig",  
384 - "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");  
385 - jobject model_config = env->GetObjectField(config, fid);  
386 - jclass model_config_cls = env->GetObjectClass(model_config);  
387 -  
388 - // transducer  
389 - fid = env->GetFieldID(model_config_cls, "transducer",  
390 - "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");  
391 - jobject transducer_config = env->GetObjectField(model_config, fid);  
392 - jclass transducer_config_cls = env->GetObjectClass(transducer_config);  
393 -  
394 - fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");  
395 - s = (jstring)env->GetObjectField(transducer_config, fid);  
396 - p = env->GetStringUTFChars(s, nullptr);  
397 - ans.model_config.transducer.encoder = p;  
398 - env->ReleaseStringUTFChars(s, p);  
399 -  
400 - fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");  
401 - s = (jstring)env->GetObjectField(transducer_config, fid);  
402 - p = env->GetStringUTFChars(s, nullptr);  
403 - ans.model_config.transducer.decoder = p;  
404 - env->ReleaseStringUTFChars(s, p);  
405 -  
406 - fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");  
407 - s = (jstring)env->GetObjectField(transducer_config, fid);  
408 - p = env->GetStringUTFChars(s, nullptr);  
409 - ans.model_config.transducer.joiner = p;  
410 - env->ReleaseStringUTFChars(s, p);  
411 -  
412 - // paraformer  
413 - fid = env->GetFieldID(model_config_cls, "paraformer",  
414 - "Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");  
415 - jobject paraformer_config = env->GetObjectField(model_config, fid);  
416 - jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);  
417 -  
418 - fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;");  
419 - s = (jstring)env->GetObjectField(paraformer_config, fid);  
420 - p = env->GetStringUTFChars(s, nullptr);  
421 - ans.model_config.paraformer.encoder = p;  
422 - env->ReleaseStringUTFChars(s, p);  
423 -  
424 - fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;");  
425 - s = (jstring)env->GetObjectField(paraformer_config, fid);  
426 - p = env->GetStringUTFChars(s, nullptr);  
427 - ans.model_config.paraformer.decoder = p;  
428 - env->ReleaseStringUTFChars(s, p);  
429 -  
430 - // streaming zipformer2 CTC  
431 - fid =  
432 - env->GetFieldID(model_config_cls, "zipformer2Ctc",  
433 - "Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;");  
434 - jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid);  
435 - jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config);  
436 -  
437 - fid =  
438 - env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;");  
439 - s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid);  
440 - p = env->GetStringUTFChars(s, nullptr);  
441 - ans.model_config.zipformer2_ctc.model = p;  
442 - env->ReleaseStringUTFChars(s, p);  
443 -  
444 - fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");  
445 - s = (jstring)env->GetObjectField(model_config, fid);  
446 - p = env->GetStringUTFChars(s, nullptr);  
447 - ans.model_config.tokens = p;  
448 - env->ReleaseStringUTFChars(s, p);  
449 -  
450 - fid = env->GetFieldID(model_config_cls, "numThreads", "I");  
451 - ans.model_config.num_threads = env->GetIntField(model_config, fid);  
452 -  
453 - fid = env->GetFieldID(model_config_cls, "debug", "Z");  
454 - ans.model_config.debug = env->GetBooleanField(model_config, fid);  
455 -  
456 - fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");  
457 - s = (jstring)env->GetObjectField(model_config, fid);  
458 - p = env->GetStringUTFChars(s, nullptr);  
459 - ans.model_config.provider = p;  
460 - env->ReleaseStringUTFChars(s, p);  
461 -  
462 - fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");  
463 - s = (jstring)env->GetObjectField(model_config, fid);  
464 - p = env->GetStringUTFChars(s, nullptr);  
465 - ans.model_config.model_type = p;  
466 - env->ReleaseStringUTFChars(s, p);  
467 -  
468 - //---------- rnn lm model config ----------  
469 - fid = env->GetFieldID(cls, "lmConfig",  
470 - "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");  
471 - jobject lm_model_config = env->GetObjectField(config, fid);  
472 - jclass lm_model_config_cls = env->GetObjectClass(lm_model_config);  
473 -  
474 - fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;");  
475 - s = (jstring)env->GetObjectField(lm_model_config, fid);  
476 - p = env->GetStringUTFChars(s, nullptr);  
477 - ans.lm_config.model = p;  
478 - env->ReleaseStringUTFChars(s, p);  
479 -  
480 - fid = env->GetFieldID(lm_model_config_cls, "scale", "F");  
481 - ans.lm_config.scale = env->GetFloatField(lm_model_config, fid);  
482 -  
483 - return ans;  
484 -}  
485 -  
486 -static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {  
487 - OfflineRecognizerConfig ans;  
488 -  
489 - jclass cls = env->GetObjectClass(config);  
490 - jfieldID fid;  
491 -  
492 - //---------- decoding ----------  
493 - fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");  
494 - jstring s = (jstring)env->GetObjectField(config, fid);  
495 - const char *p = env->GetStringUTFChars(s, nullptr);  
496 - ans.decoding_method = p;  
497 - env->ReleaseStringUTFChars(s, p);  
498 -  
499 - fid = env->GetFieldID(cls, "maxActivePaths", "I");  
500 - ans.max_active_paths = env->GetIntField(config, fid);  
501 -  
502 - fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");  
503 - s = (jstring)env->GetObjectField(config, fid);  
504 - p = env->GetStringUTFChars(s, nullptr);  
505 - ans.hotwords_file = p;  
506 - env->ReleaseStringUTFChars(s, p);  
507 -  
508 - fid = env->GetFieldID(cls, "hotwordsScore", "F");  
509 - ans.hotwords_score = env->GetFloatField(config, fid);  
510 -  
511 - //---------- feat config ----------  
512 - fid = env->GetFieldID(cls, "featConfig",  
513 - "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");  
514 - jobject feat_config = env->GetObjectField(config, fid);  
515 - jclass feat_config_cls = env->GetObjectClass(feat_config);  
516 -  
517 - fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");  
518 - ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);  
519 -  
520 - fid = env->GetFieldID(feat_config_cls, "featureDim", "I");  
521 - ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);  
522 -  
523 - //---------- model config ----------  
524 - fid = env->GetFieldID(cls, "modelConfig",  
525 - "Lcom/k2fsa/sherpa/onnx/OfflineModelConfig;");  
526 - jobject model_config = env->GetObjectField(config, fid);  
527 - jclass model_config_cls = env->GetObjectClass(model_config);  
528 -  
529 - fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");  
530 - s = (jstring)env->GetObjectField(model_config, fid);  
531 - p = env->GetStringUTFChars(s, nullptr);  
532 - ans.model_config.tokens = p;  
533 - env->ReleaseStringUTFChars(s, p);  
534 -  
535 - fid = env->GetFieldID(model_config_cls, "numThreads", "I");  
536 - ans.model_config.num_threads = env->GetIntField(model_config, fid);  
537 -  
538 - fid = env->GetFieldID(model_config_cls, "debug", "Z");  
539 - ans.model_config.debug = env->GetBooleanField(model_config, fid);  
540 -  
541 - fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");  
542 - s = (jstring)env->GetObjectField(model_config, fid);  
543 - p = env->GetStringUTFChars(s, nullptr);  
544 - ans.model_config.provider = p;  
545 - env->ReleaseStringUTFChars(s, p);  
546 -  
547 - fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");  
548 - s = (jstring)env->GetObjectField(model_config, fid);  
549 - p = env->GetStringUTFChars(s, nullptr);  
550 - ans.model_config.model_type = p;  
551 - env->ReleaseStringUTFChars(s, p);  
552 -  
553 - // transducer  
554 - fid = env->GetFieldID(model_config_cls, "transducer",  
555 - "Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;");  
556 - jobject transducer_config = env->GetObjectField(model_config, fid);  
557 - jclass transducer_config_cls = env->GetObjectClass(transducer_config);  
558 -  
559 - fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");  
560 - s = (jstring)env->GetObjectField(transducer_config, fid);  
561 - p = env->GetStringUTFChars(s, nullptr);  
562 - ans.model_config.transducer.encoder_filename = p;  
563 - env->ReleaseStringUTFChars(s, p);  
564 -  
565 - fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");  
566 - s = (jstring)env->GetObjectField(transducer_config, fid);  
567 - p = env->GetStringUTFChars(s, nullptr);  
568 - ans.model_config.transducer.decoder_filename = p;  
569 - env->ReleaseStringUTFChars(s, p);  
570 -  
571 - fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");  
572 - s = (jstring)env->GetObjectField(transducer_config, fid);  
573 - p = env->GetStringUTFChars(s, nullptr);  
574 - ans.model_config.transducer.joiner_filename = p;  
575 - env->ReleaseStringUTFChars(s, p);  
576 -  
577 - // paraformer  
578 - fid = env->GetFieldID(model_config_cls, "paraformer",  
579 - "Lcom/k2fsa/sherpa/onnx/OfflineParaformerModelConfig;");  
580 - jobject paraformer_config = env->GetObjectField(model_config, fid);  
581 - jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);  
582 -  
583 - fid = env->GetFieldID(paraformer_config_cls, "model", "Ljava/lang/String;");  
584 -  
585 - s = (jstring)env->GetObjectField(paraformer_config, fid);  
586 - p = env->GetStringUTFChars(s, nullptr);  
587 - ans.model_config.paraformer.model = p;  
588 - env->ReleaseStringUTFChars(s, p);  
589 -  
590 - // whisper  
591 - fid = env->GetFieldID(model_config_cls, "whisper",  
592 - "Lcom/k2fsa/sherpa/onnx/OfflineWhisperModelConfig;");  
593 - jobject whisper_config = env->GetObjectField(model_config, fid);  
594 - jclass whisper_config_cls = env->GetObjectClass(whisper_config);  
595 -  
596 - fid = env->GetFieldID(whisper_config_cls, "encoder", "Ljava/lang/String;");  
597 - s = (jstring)env->GetObjectField(whisper_config, fid);  
598 - p = env->GetStringUTFChars(s, nullptr);  
599 - ans.model_config.whisper.encoder = p;  
600 - env->ReleaseStringUTFChars(s, p);  
601 -  
602 - fid = env->GetFieldID(whisper_config_cls, "decoder", "Ljava/lang/String;");  
603 - s = (jstring)env->GetObjectField(whisper_config, fid);  
604 - p = env->GetStringUTFChars(s, nullptr);  
605 - ans.model_config.whisper.decoder = p;  
606 - env->ReleaseStringUTFChars(s, p);  
607 -  
608 - fid = env->GetFieldID(whisper_config_cls, "language", "Ljava/lang/String;");  
609 - s = (jstring)env->GetObjectField(whisper_config, fid);  
610 - p = env->GetStringUTFChars(s, nullptr);  
611 - ans.model_config.whisper.language = p;  
612 - env->ReleaseStringUTFChars(s, p);  
613 -  
614 - fid = env->GetFieldID(whisper_config_cls, "task", "Ljava/lang/String;");  
615 - s = (jstring)env->GetObjectField(whisper_config, fid);  
616 - p = env->GetStringUTFChars(s, nullptr);  
617 - ans.model_config.whisper.task = p;  
618 - env->ReleaseStringUTFChars(s, p);  
619 -  
620 - fid = env->GetFieldID(whisper_config_cls, "tailPaddings", "I");  
621 - ans.model_config.whisper.tail_paddings =  
622 - env->GetIntField(whisper_config, fid);  
623 -  
624 - return ans;  
625 -}  
626 -  
627 -static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {  
628 - KeywordSpotterConfig ans;  
629 -  
630 - jclass cls = env->GetObjectClass(config);  
631 - jfieldID fid;  
632 -  
633 - // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html  
634 - // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html  
635 -  
636 - //---------- decoding ----------  
637 - fid = env->GetFieldID(cls, "maxActivePaths", "I");  
638 - ans.max_active_paths = env->GetIntField(config, fid);  
639 -  
640 - fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;");  
641 - jstring s = (jstring)env->GetObjectField(config, fid);  
642 - const char *p = env->GetStringUTFChars(s, nullptr);  
643 - ans.keywords_file = p;  
644 - env->ReleaseStringUTFChars(s, p);  
645 -  
646 - fid = env->GetFieldID(cls, "keywordsScore", "F");  
647 - ans.keywords_score = env->GetFloatField(config, fid);  
648 -  
649 - fid = env->GetFieldID(cls, "keywordsThreshold", "F");  
650 - ans.keywords_threshold = env->GetFloatField(config, fid);  
651 -  
652 - fid = env->GetFieldID(cls, "numTrailingBlanks", "I");  
653 - ans.num_trailing_blanks = env->GetIntField(config, fid);  
654 -  
655 - //---------- feat config ----------  
656 - fid = env->GetFieldID(cls, "featConfig",  
657 - "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");  
658 - jobject feat_config = env->GetObjectField(config, fid);  
659 - jclass feat_config_cls = env->GetObjectClass(feat_config);  
660 -  
661 - fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");  
662 - ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);  
663 -  
664 - fid = env->GetFieldID(feat_config_cls, "featureDim", "I");  
665 - ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);  
666 -  
667 - //---------- model config ----------  
668 - fid = env->GetFieldID(cls, "modelConfig",  
669 - "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");  
670 - jobject model_config = env->GetObjectField(config, fid);  
671 - jclass model_config_cls = env->GetObjectClass(model_config);  
672 -  
673 - // transducer  
674 - fid = env->GetFieldID(model_config_cls, "transducer",  
675 - "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");  
676 - jobject transducer_config = env->GetObjectField(model_config, fid);  
677 - jclass transducer_config_cls = env->GetObjectClass(transducer_config);  
678 -  
679 - fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");  
680 - s = (jstring)env->GetObjectField(transducer_config, fid);  
681 - p = env->GetStringUTFChars(s, nullptr);  
682 - ans.model_config.transducer.encoder = p;  
683 - env->ReleaseStringUTFChars(s, p);  
684 -  
685 - fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");  
686 - s = (jstring)env->GetObjectField(transducer_config, fid);  
687 - p = env->GetStringUTFChars(s, nullptr);  
688 - ans.model_config.transducer.decoder = p;  
689 - env->ReleaseStringUTFChars(s, p);  
690 -  
691 - fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");  
692 - s = (jstring)env->GetObjectField(transducer_config, fid);  
693 - p = env->GetStringUTFChars(s, nullptr);  
694 - ans.model_config.transducer.joiner = p;  
695 - env->ReleaseStringUTFChars(s, p);  
696 -  
697 - fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");  
698 - s = (jstring)env->GetObjectField(model_config, fid);  
699 - p = env->GetStringUTFChars(s, nullptr);  
700 - ans.model_config.tokens = p;  
701 - env->ReleaseStringUTFChars(s, p);  
702 -  
703 - fid = env->GetFieldID(model_config_cls, "numThreads", "I");  
704 - ans.model_config.num_threads = env->GetIntField(model_config, fid);  
705 -  
706 - fid = env->GetFieldID(model_config_cls, "debug", "Z");  
707 - ans.model_config.debug = env->GetBooleanField(model_config, fid);  
708 -  
709 - fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");  
710 - s = (jstring)env->GetObjectField(model_config, fid);  
711 - p = env->GetStringUTFChars(s, nullptr);  
712 - ans.model_config.provider = p;  
713 - env->ReleaseStringUTFChars(s, p);  
714 -  
715 - fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");  
716 - s = (jstring)env->GetObjectField(model_config, fid);  
717 - p = env->GetStringUTFChars(s, nullptr);  
718 - ans.model_config.model_type = p;  
719 - env->ReleaseStringUTFChars(s, p);  
720 -  
721 - return ans;  
722 -}  
723 -  
724 -static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {  
725 - VadModelConfig ans;  
726 -  
727 - jclass cls = env->GetObjectClass(config);  
728 - jfieldID fid;  
729 -  
730 - // silero_vad  
731 - fid = env->GetFieldID(cls, "sileroVadModelConfig",  
732 - "Lcom/k2fsa/sherpa/onnx/SileroVadModelConfig;");  
733 - jobject silero_vad_config = env->GetObjectField(config, fid);  
734 - jclass silero_vad_config_cls = env->GetObjectClass(silero_vad_config);  
735 -  
736 - fid = env->GetFieldID(silero_vad_config_cls, "model", "Ljava/lang/String;");  
737 - auto s = (jstring)env->GetObjectField(silero_vad_config, fid);  
738 - auto p = env->GetStringUTFChars(s, nullptr);  
739 - ans.silero_vad.model = p;  
740 - env->ReleaseStringUTFChars(s, p);  
741 -  
742 - fid = env->GetFieldID(silero_vad_config_cls, "threshold", "F");  
743 - ans.silero_vad.threshold = env->GetFloatField(silero_vad_config, fid);  
744 -  
745 - fid = env->GetFieldID(silero_vad_config_cls, "minSilenceDuration", "F");  
746 - ans.silero_vad.min_silence_duration =  
747 - env->GetFloatField(silero_vad_config, fid);  
748 -  
749 - fid = env->GetFieldID(silero_vad_config_cls, "minSpeechDuration", "F");  
750 - ans.silero_vad.min_speech_duration =  
751 - env->GetFloatField(silero_vad_config, fid);  
752 -  
753 - fid = env->GetFieldID(silero_vad_config_cls, "windowSize", "I");  
754 - ans.silero_vad.window_size = env->GetIntField(silero_vad_config, fid);  
755 -  
756 - fid = env->GetFieldID(cls, "sampleRate", "I");  
757 - ans.sample_rate = env->GetIntField(config, fid);  
758 -  
759 - fid = env->GetFieldID(cls, "numThreads", "I");  
760 - ans.num_threads = env->GetIntField(config, fid);  
761 -  
762 - fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");  
763 - s = (jstring)env->GetObjectField(config, fid);  
764 - p = env->GetStringUTFChars(s, nullptr);  
765 - ans.provider = p;  
766 - env->ReleaseStringUTFChars(s, p);  
767 -  
768 - fid = env->GetFieldID(cls, "debug", "Z");  
769 - ans.debug = env->GetBooleanField(config, fid);  
770 -  
771 - return ans;  
772 -}  
773 -  
774 -} // namespace sherpa_onnx  
775 -  
776 -SHERPA_ONNX_EXTERN_C  
777 -JNIEXPORT jlong JNICALL  
778 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_new(JNIEnv *env,  
779 - jobject /*obj*/,  
780 - jobject asset_manager,  
781 - jobject _config) {  
782 -#if __ANDROID_API__ >= 9  
783 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);  
784 - if (!mgr) {  
785 - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);  
786 - }  
787 -#endif  
788 - auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);  
789 - SHERPA_ONNX_LOGE("new config:\n%s", config.ToString().c_str());  
790 -  
791 - auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor(  
792 -#if __ANDROID_API__ >= 9  
793 - mgr,  
794 -#endif  
795 - config);  
796 -  
797 - return (jlong)extractor;  
798 -}  
799 -  
800 -SHERPA_ONNX_EXTERN_C  
801 -JNIEXPORT jlong JNICALL  
802 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromFile(  
803 - JNIEnv *env, jobject /*obj*/, jobject _config) {  
804 - auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);  
805 - SHERPA_ONNX_LOGE("newFromFile config:\n%s", config.ToString().c_str());  
806 -  
807 - if (!config.Validate()) {  
808 - SHERPA_ONNX_LOGE("Errors found in config!");  
809 - }  
810 -  
811 - auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor(config);  
812 -  
813 - return (jlong)extractor;  
814 -}  
815 -  
816 -SHERPA_ONNX_EXTERN_C  
817 -JNIEXPORT void JNICALL  
818 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_delete(JNIEnv *env,  
819 - jobject /*obj*/,  
820 - jlong ptr) {  
821 - delete reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(  
822 - ptr);  
823 -}  
824 -  
825 -SHERPA_ONNX_EXTERN_C  
826 -JNIEXPORT jlong JNICALL  
827 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_createStream(  
828 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
829 - auto stream =  
830 - reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr)  
831 - ->CreateStream();  
832 -  
833 - return (jlong)stream;  
834 -}  
835 -  
836 -SHERPA_ONNX_EXTERN_C  
837 -JNIEXPORT jboolean JNICALL  
838 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_isReady(JNIEnv *env,  
839 - jobject /*obj*/,  
840 - jlong ptr,  
841 - jlong stream_ptr) {  
842 - auto extractor =  
843 - reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);  
844 - auto stream = reinterpret_cast<  
845 - sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(stream_ptr);  
846 - return extractor->IsReady(stream);  
847 -}  
848 -  
849 -SHERPA_ONNX_EXTERN_C  
850 -JNIEXPORT jfloatArray JNICALL  
851 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_compute(JNIEnv *env,  
852 - jobject /*obj*/,  
853 - jlong ptr,  
854 - jlong stream_ptr) {  
855 - auto extractor =  
856 - reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);  
857 - auto stream = reinterpret_cast<  
858 - sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(stream_ptr);  
859 -  
860 - std::vector<float> embedding = extractor->Compute(stream);  
861 - jfloatArray embedding_arr = env->NewFloatArray(embedding.size());  
862 - env->SetFloatArrayRegion(embedding_arr, 0, embedding.size(),  
863 - embedding.data());  
864 - return embedding_arr;  
865 -}  
866 -  
867 -SHERPA_ONNX_EXTERN_C  
868 -JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_dim(  
869 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
870 - auto extractor =  
871 - reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);  
872 - return extractor->Dim();  
873 -}  
874 -  
875 -SHERPA_ONNX_EXTERN_C  
876 -JNIEXPORT void JNICALL  
877 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_delete(  
878 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
879 - delete reinterpret_cast<  
880 - sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);  
881 -}  
882 -  
883 -SHERPA_ONNX_EXTERN_C  
884 -JNIEXPORT void JNICALL  
885 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_acceptWaveform(  
886 - JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,  
887 - jint sample_rate) {  
888 - auto stream = reinterpret_cast<  
889 - sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);  
890 -  
891 - jfloat *p = env->GetFloatArrayElements(samples, nullptr);  
892 - jsize n = env->GetArrayLength(samples);  
893 - stream->AcceptWaveform(sample_rate, p, n);  
894 - env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);  
895 -}  
896 -  
897 -SHERPA_ONNX_EXTERN_C  
898 -JNIEXPORT void JNICALL  
899 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_inputFinished(  
900 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
901 - auto stream = reinterpret_cast<  
902 - sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);  
903 - stream->InputFinished();  
904 -}  
905 -  
906 -SHERPA_ONNX_EXTERN_C  
907 -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_new(  
908 - JNIEnv *env, jobject /*obj*/, jint dim) {  
909 - auto p = new sherpa_onnx::SpeakerEmbeddingManager(dim);  
910 - return (jlong)p;  
911 -}  
912 -  
913 -SHERPA_ONNX_EXTERN_C  
914 -JNIEXPORT void JNICALL  
915 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_delete(JNIEnv *env,  
916 - jobject /*obj*/,  
917 - jlong ptr) {  
918 - auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);  
919 - delete manager;  
920 -}  
921 -  
922 -SHERPA_ONNX_EXTERN_C  
923 -JNIEXPORT jboolean JNICALL  
924 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_add(JNIEnv *env,  
925 - jobject /*obj*/,  
926 - jlong ptr, jstring name,  
927 - jfloatArray embedding) {  
928 - auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);  
929 -  
930 - jfloat *p = env->GetFloatArrayElements(embedding, nullptr);  
931 - jsize n = env->GetArrayLength(embedding);  
932 -  
933 - if (n != manager->Dim()) {  
934 - SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),  
935 - static_cast<int32_t>(n));  
936 - exit(-1);  
937 - }  
938 -  
939 - const char *p_name = env->GetStringUTFChars(name, nullptr);  
940 -  
941 - jboolean ok = manager->Add(p_name, p);  
942 - env->ReleaseStringUTFChars(name, p_name);  
943 - env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);  
944 -  
945 - return ok;  
946 -}  
947 -  
948 -SHERPA_ONNX_EXTERN_C  
949 -JNIEXPORT jboolean JNICALL  
950 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_addList(  
951 - JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,  
952 - jobjectArray embedding_arr) {  
953 - auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);  
954 -  
955 - int num_embeddings = env->GetArrayLength(embedding_arr);  
956 - if (num_embeddings == 0) {  
957 - return false;  
958 - }  
959 -  
960 - std::vector<std::vector<float>> embedding_list;  
961 - embedding_list.reserve(num_embeddings);  
962 - for (int32_t i = 0; i != num_embeddings; ++i) {  
963 - jfloatArray embedding =  
964 - (jfloatArray)env->GetObjectArrayElement(embedding_arr, i);  
965 -  
966 - jfloat *p = env->GetFloatArrayElements(embedding, nullptr);  
967 - jsize n = env->GetArrayLength(embedding);  
968 -  
969 - if (n != manager->Dim()) {  
970 - SHERPA_ONNX_LOGE("i: %d. Expected dim %d, given %d", i, manager->Dim(),  
971 - static_cast<int32_t>(n));  
972 - exit(-1);  
973 - }  
974 -  
975 - embedding_list.push_back({p, p + n});  
976 - env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);  
977 - }  
978 -  
979 - const char *p_name = env->GetStringUTFChars(name, nullptr);  
980 -  
981 - jboolean ok = manager->Add(p_name, embedding_list);  
982 -  
983 - env->ReleaseStringUTFChars(name, p_name);  
984 -  
985 - return ok;  
986 -}  
987 -  
988 -SHERPA_ONNX_EXTERN_C  
989 -JNIEXPORT jboolean JNICALL  
990 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_remove(JNIEnv *env,  
991 - jobject /*obj*/,  
992 - jlong ptr,  
993 - jstring name) {  
994 - auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);  
995 -  
996 - const char *p_name = env->GetStringUTFChars(name, nullptr);  
997 -  
998 - jboolean ok = manager->Remove(p_name);  
999 -  
1000 - env->ReleaseStringUTFChars(name, p_name);  
1001 -  
1002 - return ok;  
1003 -}  
1004 -  
1005 -SHERPA_ONNX_EXTERN_C  
1006 -JNIEXPORT jstring JNICALL  
1007 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_search(JNIEnv *env,  
1008 - jobject /*obj*/,  
1009 - jlong ptr,  
1010 - jfloatArray embedding,  
1011 - jfloat threshold) {  
1012 - auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);  
1013 -  
1014 - jfloat *p = env->GetFloatArrayElements(embedding, nullptr);  
1015 - jsize n = env->GetArrayLength(embedding);  
1016 -  
1017 - if (n != manager->Dim()) {  
1018 - SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),  
1019 - static_cast<int32_t>(n));  
1020 - exit(-1);  
1021 - }  
1022 -  
1023 - std::string name = manager->Search(p, threshold);  
1024 -  
1025 - env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);  
1026 -  
1027 - return env->NewStringUTF(name.c_str());  
1028 -}  
1029 -  
1030 -SHERPA_ONNX_EXTERN_C  
1031 -JNIEXPORT jboolean JNICALL  
1032 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_verify(  
1033 - JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,  
1034 - jfloatArray embedding, jfloat threshold) {  
1035 - auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);  
1036 -  
1037 - jfloat *p = env->GetFloatArrayElements(embedding, nullptr);  
1038 - jsize n = env->GetArrayLength(embedding);  
1039 -  
1040 - if (n != manager->Dim()) {  
1041 - SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),  
1042 - static_cast<int32_t>(n));  
1043 - exit(-1);  
1044 - }  
1045 -  
1046 - const char *p_name = env->GetStringUTFChars(name, nullptr);  
1047 -  
1048 - jboolean ok = manager->Verify(p_name, p, threshold);  
1049 -  
1050 - env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);  
1051 -  
1052 - env->ReleaseStringUTFChars(name, p_name);  
1053 -  
1054 - return ok;  
1055 -}  
1056 -  
1057 -SHERPA_ONNX_EXTERN_C  
1058 -JNIEXPORT jboolean JNICALL  
1059 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_contains(JNIEnv *env,  
1060 - jobject /*obj*/,  
1061 - jlong ptr,  
1062 - jstring name) {  
1063 - auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);  
1064 -  
1065 - const char *p_name = env->GetStringUTFChars(name, nullptr);  
1066 -  
1067 - jboolean ok = manager->Contains(p_name);  
1068 -  
1069 - env->ReleaseStringUTFChars(name, p_name);  
1070 -  
1071 - return ok;  
1072 -}  
1073 -  
1074 -SHERPA_ONNX_EXTERN_C  
1075 -JNIEXPORT jint JNICALL  
1076 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_numSpeakers(JNIEnv *env,  
1077 - jobject /*obj*/,  
1078 - jlong ptr) {  
1079 - auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);  
1080 - return manager->NumSpeakers();  
1081 -}  
1082 -  
1083 -SHERPA_ONNX_EXTERN_C  
1084 -JNIEXPORT jobjectArray JNICALL  
1085 -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(  
1086 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1087 - auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);  
1088 - std::vector<std::string> all_speakers = manager->GetAllSpeakers();  
1089 -  
1090 - jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(  
1091 - all_speakers.size(), env->FindClass("java/lang/String"), nullptr);  
1092 -  
1093 - int32_t i = 0;  
1094 - for (auto &s : all_speakers) {  
1095 - jstring js = env->NewStringUTF(s.c_str());  
1096 - env->SetObjectArrayElement(obj_arr, i, js);  
1097 -  
1098 - ++i;  
1099 - }  
1100 -  
1101 - return obj_arr;  
1102 -}  
1103 -  
1104 -// see  
1105 -// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables  
1106 -jobject NewInteger(JNIEnv *env, int32_t value) {  
1107 - jclass cls = env->FindClass("java/lang/Integer");  
1108 - jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");  
1109 - return env->NewObject(cls, constructor, value);  
1110 -}  
1111 -  
1112 -jobject NewFloat(JNIEnv *env, float value) {  
1113 - jclass cls = env->FindClass("java/lang/Float");  
1114 - jmethodID constructor = env->GetMethodID(cls, "<init>", "(F)V");  
1115 - return env->NewObject(cls, constructor, value);  
1116 -}  
1117 -  
1118 -SHERPA_ONNX_EXTERN_C  
1119 -JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(  
1120 - JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,  
1121 - jint sample_rate) {  
1122 - const char *p_filename = env->GetStringUTFChars(filename, nullptr);  
1123 -  
1124 - jfloat *p = env->GetFloatArrayElements(samples, nullptr);  
1125 - jsize n = env->GetArrayLength(samples);  
1126 -  
1127 - bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n);  
1128 -  
1129 - env->ReleaseStringUTFChars(filename, p_filename);  
1130 - env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);  
1131 -  
1132 - return ok;  
1133 -}  
1134 -  
1135 -SHERPA_ONNX_EXTERN_C  
1136 -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_new(  
1137 - JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {  
1138 -#if __ANDROID_API__ >= 9  
1139 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);  
1140 - if (!mgr) {  
1141 - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);  
1142 - }  
1143 -#endif  
1144 - auto config = sherpa_onnx::GetVadModelConfig(env, _config);  
1145 - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());  
1146 - auto model = new sherpa_onnx::SherpaOnnxVad(  
1147 -#if __ANDROID_API__ >= 9  
1148 - mgr,  
1149 -#endif  
1150 - config);  
1151 -  
1152 - return (jlong)model;  
1153 -}  
1154 -  
1155 -SHERPA_ONNX_EXTERN_C  
1156 -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromFile(  
1157 - JNIEnv *env, jobject /*obj*/, jobject _config) {  
1158 - auto config = sherpa_onnx::GetVadModelConfig(env, _config);  
1159 - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());  
1160 - auto model = new sherpa_onnx::SherpaOnnxVad(config);  
1161 -  
1162 - return (jlong)model;  
1163 -}  
1164 -  
1165 -SHERPA_ONNX_EXTERN_C  
1166 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_delete(JNIEnv *env,  
1167 - jobject /*obj*/,  
1168 - jlong ptr) {  
1169 - delete reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);  
1170 -}  
1171 -  
1172 -SHERPA_ONNX_EXTERN_C  
1173 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_acceptWaveform(  
1174 - JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {  
1175 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);  
1176 -  
1177 - jfloat *p = env->GetFloatArrayElements(samples, nullptr);  
1178 - jsize n = env->GetArrayLength(samples);  
1179 -  
1180 - model->AcceptWaveform(p, n);  
1181 -  
1182 - env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);  
1183 -}  
1184 -  
1185 -SHERPA_ONNX_EXTERN_C  
1186 -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_Vad_empty(JNIEnv *env,  
1187 - jobject /*obj*/,  
1188 - jlong ptr) {  
1189 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);  
1190 - return model->Empty();  
1191 -}  
1192 -  
1193 -SHERPA_ONNX_EXTERN_C  
1194 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env,  
1195 - jobject /*obj*/,  
1196 - jlong ptr) {  
1197 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);  
1198 - model->Pop();  
1199 -}  
1200 -  
1201 -SHERPA_ONNX_EXTERN_C  
1202 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env,  
1203 - jobject /*obj*/,  
1204 - jlong ptr) {  
1205 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);  
1206 - model->Clear();  
1207 -}  
1208 -  
1209 -SHERPA_ONNX_EXTERN_C  
1210 -JNIEXPORT jobjectArray JNICALL  
1211 -Java_com_k2fsa_sherpa_onnx_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1212 - const auto &front =  
1213 - reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr)->Front();  
1214 -  
1215 - jfloatArray samples_arr = env->NewFloatArray(front.samples.size());  
1216 - env->SetFloatArrayRegion(samples_arr, 0, front.samples.size(),  
1217 - front.samples.data());  
1218 -  
1219 - jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(  
1220 - 2, env->FindClass("java/lang/Object"), nullptr);  
1221 -  
1222 - env->SetObjectArrayElement(obj_arr, 0, NewInteger(env, front.start));  
1223 - env->SetObjectArrayElement(obj_arr, 1, samples_arr);  
1224 -  
1225 - return obj_arr;  
1226 -}  
1227 -  
1228 -SHERPA_ONNX_EXTERN_C  
1229 -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_Vad_isSpeechDetected(  
1230 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1231 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);  
1232 - return model->IsSpeechDetected();  
1233 -}  
1234 -  
1235 -SHERPA_ONNX_EXTERN_C  
1236 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_reset(JNIEnv *env,  
1237 - jobject /*obj*/,  
1238 - jlong ptr) {  
1239 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);  
1240 - model->Reset();  
1241 -}  
1242 -  
1243 -SHERPA_ONNX_EXTERN_C  
1244 -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(  
1245 - JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {  
1246 -#if __ANDROID_API__ >= 9  
1247 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);  
1248 - if (!mgr) {  
1249 - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);  
1250 - }  
1251 -#endif  
1252 - auto config = sherpa_onnx::GetConfig(env, _config);  
1253 - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());  
1254 - auto model = new sherpa_onnx::SherpaOnnx(  
1255 -#if __ANDROID_API__ >= 9  
1256 - mgr,  
1257 -#endif  
1258 - config);  
1259 -  
1260 - return (jlong)model;  
1261 -}  
1262 -  
1263 -SHERPA_ONNX_EXTERN_C  
1264 -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_newFromFile(  
1265 - JNIEnv *env, jobject /*obj*/, jobject _config) {  
1266 - auto config = sherpa_onnx::GetConfig(env, _config);  
1267 - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());  
1268 - auto model = new sherpa_onnx::SherpaOnnx(config);  
1269 -  
1270 - return (jlong)model;  
1271 -}  
1272 -  
1273 -SHERPA_ONNX_EXTERN_C  
1274 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete(  
1275 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1276 - delete reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);  
1277 -}  
1278 -  
1279 -SHERPA_ONNX_EXTERN_C  
1280 -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_new(  
1281 - JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {  
1282 -#if __ANDROID_API__ >= 9  
1283 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);  
1284 - if (!mgr) {  
1285 - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);  
1286 - }  
1287 -#endif  
1288 - auto config = sherpa_onnx::GetOfflineConfig(env, _config);  
1289 - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());  
1290 - auto model = new sherpa_onnx::SherpaOnnxOffline(  
1291 -#if __ANDROID_API__ >= 9  
1292 - mgr,  
1293 -#endif  
1294 - config);  
1295 -  
1296 - return (jlong)model;  
1297 -}  
1298 -  
1299 -SHERPA_ONNX_EXTERN_C  
1300 -JNIEXPORT jlong JNICALL  
1301 -Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_newFromFile(JNIEnv *env,  
1302 - jobject /*obj*/,  
1303 - jobject _config) {  
1304 - auto config = sherpa_onnx::GetOfflineConfig(env, _config);  
1305 - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());  
1306 - auto model = new sherpa_onnx::SherpaOnnxOffline(config);  
1307 -  
1308 - return (jlong)model;  
1309 -}  
1310 -  
1311 -SHERPA_ONNX_EXTERN_C  
1312 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_delete(  
1313 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1314 - delete reinterpret_cast<sherpa_onnx::SherpaOnnxOffline *>(ptr);  
1315 -}  
1316 -  
1317 -SHERPA_ONNX_EXTERN_C  
1318 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(  
1319 - JNIEnv *env, jobject /*obj*/, jlong ptr, jboolean recreate,  
1320 - jstring keywords) {  
1321 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);  
1322 - const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);  
1323 - model->Reset(recreate, p_keywords);  
1324 - env->ReleaseStringUTFChars(keywords, p_keywords);  
1325 -}  
1326 -  
1327 -SHERPA_ONNX_EXTERN_C  
1328 -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady(  
1329 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1330 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);  
1331 - return model->IsReady();  
1332 -}  
1333 -  
1334 -SHERPA_ONNX_EXTERN_C  
1335 -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(  
1336 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1337 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);  
1338 - return model->IsEndpoint();  
1339 -}  
1340 -  
1341 -SHERPA_ONNX_EXTERN_C  
1342 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode(  
1343 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1344 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);  
1345 - model->Decode();  
1346 -}  
1347 -  
1348 -SHERPA_ONNX_EXTERN_C  
1349 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform(  
1350 - JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,  
1351 - jint sample_rate) {  
1352 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);  
1353 -  
1354 - jfloat *p = env->GetFloatArrayElements(samples, nullptr);  
1355 - jsize n = env->GetArrayLength(samples);  
1356 -  
1357 - model->AcceptWaveform(sample_rate, p, n);  
1358 -  
1359 - env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);  
1360 -}  
1361 -  
1362 -SHERPA_ONNX_EXTERN_C  
1363 -JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_decode(  
1364 - JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,  
1365 - jint sample_rate) {  
1366 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxOffline *>(ptr);  
1367 -  
1368 - jfloat *p = env->GetFloatArrayElements(samples, nullptr);  
1369 - jsize n = env->GetArrayLength(samples);  
1370 -  
1371 - auto text = model->Decode(sample_rate, p, n);  
1372 -  
1373 - env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);  
1374 -  
1375 - return env->NewStringUTF(text.c_str());  
1376 -}  
1377 -  
1378 -SHERPA_ONNX_EXTERN_C  
1379 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_inputFinished(  
1380 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1381 - reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->InputFinished();  
1382 -}  
1383 -  
1384 -SHERPA_ONNX_EXTERN_C  
1385 -JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(  
1386 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1387 - // see  
1388 - // https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni  
1389 - auto text = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->GetText();  
1390 - return env->NewStringUTF(text.c_str());  
1391 -}  
1392 -  
1393 -SHERPA_ONNX_EXTERN_C  
1394 -JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(  
1395 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1396 - auto tokens = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->GetTokens();  
1397 - int32_t size = tokens.size();  
1398 - jclass stringClass = env->FindClass("java/lang/String");  
1399 -  
1400 - // convert C++ list into jni string array  
1401 - jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);  
1402 - for (int32_t i = 0; i < size; i++) {  
1403 - // Convert the C++ string to a C string  
1404 - const char *cstr = tokens[i].c_str();  
1405 -  
1406 - // Convert the C string to a jstring  
1407 - jstring jstr = env->NewStringUTF(cstr);  
1408 -  
1409 - // Set the array element  
1410 - env->SetObjectArrayElement(result, i, jstr);  
1411 - }  
1412 -  
1413 - return result;  
1414 -}  
1415 -  
1416 -SHERPA_ONNX_EXTERN_C  
1417 -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_new(  
1418 - JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {  
1419 -#if __ANDROID_API__ >= 9  
1420 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);  
1421 - if (!mgr) {  
1422 - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);  
1423 - }  
1424 -#endif  
1425 - auto config = sherpa_onnx::GetKwsConfig(env, _config);  
1426 - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());  
1427 - auto model = new sherpa_onnx::SherpaOnnxKws(  
1428 -#if __ANDROID_API__ >= 9  
1429 - mgr,  
1430 -#endif  
1431 - config);  
1432 -  
1433 - return (jlong)model;  
1434 -}  
1435 -  
1436 -SHERPA_ONNX_EXTERN_C  
1437 -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_newFromFile(  
1438 - JNIEnv *env, jobject /*obj*/, jobject _config) {  
1439 - auto config = sherpa_onnx::GetKwsConfig(env, _config);  
1440 - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());  
1441 - auto model = new sherpa_onnx::SherpaOnnxKws(config);  
1442 -  
1443 - return (jlong)model;  
1444 -}
7 +#include <fstream>
1445 8
1446 -SHERPA_ONNX_EXTERN_C
1447 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_delete(
1448 - JNIEnv *env, jobject /*obj*/, jlong ptr) {
1449 - delete reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
1450 -}
9 +#include "sherpa-onnx/csrc/macros.h"
10 +#include "sherpa-onnx/csrc/onnx-utils.h"
11 +#include "sherpa-onnx/csrc/wave-reader.h"
12 +#include "sherpa-onnx/csrc/wave-writer.h"
13 +#include "sherpa-onnx/jni/common.h"
1451 14
1452 -SHERPA_ONNX_EXTERN_C
1453 -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_isReady(
1454 - JNIEnv *env, jobject /*obj*/, jlong ptr) {
1455 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
1456 - return model->IsReady();
15 +// see
16 +// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
17 +jobject NewInteger(JNIEnv *env, int32_t value) {
18 + jclass cls = env->FindClass("java/lang/Integer");
19 + jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");
20 + return env->NewObject(cls, constructor, value);
1457 21 }
1458 22
1459 -SHERPA_ONNX_EXTERN_C
1460 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_decode(
1461 - JNIEnv *env, jobject /*obj*/, jlong ptr) {
1462 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
1463 - model->Decode();
23 +jobject NewFloat(JNIEnv *env, float value) {
24 + jclass cls = env->FindClass("java/lang/Float");
25 + jmethodID constructor = env->GetMethodID(cls, "<init>", "(F)V");
26 + return env->NewObject(cls, constructor, value);
1464 27 }
1465 28
1466 29 SHERPA_ONNX_EXTERN_C
1467 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_acceptWaveform(
1468 - JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
30 +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
31 + JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
1469 32 jint sample_rate) {
1470 - auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
33 + const char *p_filename = env->GetStringUTFChars(filename, nullptr);
1471 34
1472 35 jfloat *p = env->GetFloatArrayElements(samples, nullptr);
1473 36 jsize n = env->GetArrayLength(samples);
1474 37
1475 - model->AcceptWaveform(sample_rate, p, n);
38 + bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n);
1476 39
40 + env->ReleaseStringUTFChars(filename, p_filename);
1477 41 env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
1478 -}  
1479 -  
1480 -SHERPA_ONNX_EXTERN_C  
1481 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_inputFinished(  
1482 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1483 - reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->InputFinished();  
1484 -}  
1485 -  
1486 -SHERPA_ONNX_EXTERN_C  
1487 -JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getKeyword(  
1488 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1489 - // see  
1490 - // https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni  
1491 - auto text = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetKeyword();  
1492 - return env->NewStringUTF(text.c_str());  
1493 -}  
1494 -  
1495 -SHERPA_ONNX_EXTERN_C  
1496 -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_reset(  
1497 - JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {  
1498 - const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);  
1499 -  
1500 - std::string keywords_str = p_keywords;  
1501 -  
1502 - bool status =  
1503 - reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->Reset(keywords_str);  
1504 - env->ReleaseStringUTFChars(keywords, p_keywords);  
1505 - return status;  
1506 -}  
1507 42
1508 -SHERPA_ONNX_EXTERN_C  
1509 -JNIEXPORT jobjectArray JNICALL  
1510 -Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getTokens(JNIEnv *env, jobject /*obj*/,  
1511 - jlong ptr) {  
1512 - auto tokens =  
1513 - reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetTokens();  
1514 - int32_t size = tokens.size();  
1515 - jclass stringClass = env->FindClass("java/lang/String");  
1516 -  
1517 - // convert C++ list into jni string array  
1518 - jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);  
1519 - for (int32_t i = 0; i < size; i++) {  
1520 - // Convert the C++ string to a C string  
1521 - const char *cstr = tokens[i].c_str();  
1522 -  
1523 - // Convert the C string to a jstring  
1524 - jstring jstr = env->NewStringUTF(cstr);  
1525 -  
1526 - // Set the array element  
1527 - env->SetObjectArrayElement(result, i, jstr);  
1528 - }  
1529 -  
1530 - return result;
43 + return ok;
1531 44 }
1532 45
1533 46 static jobjectArray ReadWaveImpl(JNIEnv *env, std::istream &is,
@@ -1593,81 +106,7 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromAsset(
1593 106 return obj_arr;
1594 107 }
1595 108
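The NewInteger() and NewFloat() helpers moved into this file exist because JNI has no multi-value return: a native method that wants to hand several values back to Java packs them into a java/lang/Object array, boxing primitives on the way (the same pattern ReadWaveImpl below and the removed Vad_front binding use). A minimal sketch of that pattern, using only standard JNI calls; the function name and the packed values are placeholders for illustration, not part of this PR:

    // Sketch: pack an int and a float[] into an Object[2]. The Java side
    // casts the elements back: (Integer) arr[0], (float[]) arr[1].
    static jobjectArray PackIntAndFloats(JNIEnv *env) {
      float samples[3] = {0.1f, 0.2f, 0.3f};  // placeholder data

      jfloatArray samples_arr = env->NewFloatArray(3);
      env->SetFloatArrayRegion(samples_arr, 0, 3, samples);

      jobjectArray ans = (jobjectArray)env->NewObjectArray(
          2, env->FindClass("java/lang/Object"), nullptr);

      env->SetObjectArrayElement(ans, 0, NewInteger(env, 42));  // boxed int
      env->SetObjectArrayElement(ans, 1, samples_arr);          // float[]
      return ans;
    }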
1596 -// ******wrapper for OnlineRecognizer*******
1597 -  
1598 -// wav reader for java interface  
1599 -SHERPA_ONNX_EXTERN_C  
1600 -JNIEXPORT jobjectArray JNICALL  
1601 -Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_readWave(JNIEnv *env,  
1602 - jclass /*cls*/,  
1603 - jstring filename) {  
1604 - auto data =  
1605 - Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromAsset(  
1606 - env, nullptr, nullptr, filename);  
1607 - return data;  
1608 -}  
1609 -  
1610 -SHERPA_ONNX_EXTERN_C  
1611 -JNIEXPORT jlong JNICALL  
1612 -  
1613 -Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_createOnlineRecognizer(  
1614 -  
1615 - JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {  
1616 -#if __ANDROID_API__ >= 9  
1617 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);  
1618 - if (!mgr) {  
1619 - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);  
1620 - }  
1621 -#endif  
1622 - sherpa_onnx::OnlineRecognizerConfig config =  
1623 - sherpa_onnx::GetConfig(env, _config);  
1624 - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());  
1625 - auto p_recognizer = new sherpa_onnx::OnlineRecognizer(  
1626 -#if __ANDROID_API__ >= 9  
1627 - mgr,  
1628 -#endif  
1629 - config);  
1630 - return (jlong)p_recognizer;  
1631 -}  
1632 -  
1633 -SHERPA_ONNX_EXTERN_C  
1634 -JNIEXPORT void JNICALL  
1635 -Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_deleteOnlineRecognizer(  
1636 - JNIEnv *env, jobject /*obj*/, jlong ptr) {  
1637 - delete reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);  
1638 -}  
1639 -  
1640 -SHERPA_ONNX_EXTERN_C  
1641 -JNIEXPORT jlong JNICALL  
1642 -Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_createStream(JNIEnv *env,  
1643 - jobject /*obj*/,  
1644 - jlong ptr) {  
1645 - std::unique_ptr<sherpa_onnx::OnlineStream> s =  
1646 - reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr)->CreateStream();  
1647 - sherpa_onnx::OnlineStream *p_stream = s.release();  
1648 - return reinterpret_cast<jlong>(p_stream);  
1649 -}  
1650 -  
1651 -SHERPA_ONNX_EXTERN_C  
1652 -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isReady(  
1653 - JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {  
1654 - sherpa_onnx::OnlineRecognizer *model =  
1655 - reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);  
1656 - sherpa_onnx::OnlineStream *s =  
1657 - reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);  
1658 - return model->IsReady(s);  
1659 -}  
1660 -  
1661 -SHERPA_ONNX_EXTERN_C  
1662 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decodeStream(  
1663 - JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {  
1664 - sherpa_onnx::OnlineRecognizer *model =  
1665 - reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);  
1666 - sherpa_onnx::OnlineStream *s =  
1667 - reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);  
1668 - model->DecodeStream(s);  
1669 -}  
1670 -
109 +#if 0
1671 110 SHERPA_ONNX_EXTERN_C
1672 111 JNIEXPORT void JNICALL
1673 112 Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decodeStreams(JNIEnv *env,
@@ -1687,92 +126,4 @@ Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decodeStreams(JNIEnv *env,
1687 126 model->DecodeStreams(p_ss.data(), n);
1688 127 env->ReleaseLongArrayElements(ss_ptr, p, JNI_ABORT);
1689 128 }
1690 -  
1691 -SHERPA_ONNX_EXTERN_C  
1692 -JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(  
1693 - JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {  
1694 - sherpa_onnx::OnlineRecognizer *model =  
1695 - reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);  
1696 - sherpa_onnx::OnlineStream *s =  
1697 - reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);  
1698 - sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s);  
1699 - return env->NewStringUTF(result.text.c_str());  
1700 -}  
1701 -  
1702 -SHERPA_ONNX_EXTERN_C  
1703 -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isEndpoint(  
1704 - JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {  
1705 - sherpa_onnx::OnlineRecognizer *model =  
1706 - reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);  
1707 - sherpa_onnx::OnlineStream *s =  
1708 - reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);  
1709 - return model->IsEndpoint(s);  
1710 -}  
1711 -  
1712 -SHERPA_ONNX_EXTERN_C  
1713 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_reSet(  
1714 - JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {  
1715 - sherpa_onnx::OnlineRecognizer *model =  
1716 - reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);  
1717 - sherpa_onnx::OnlineStream *s =  
1718 - reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);  
1719 - model->Reset(s);  
1720 -}  
1721 -  
1722 -// *********for OnlineStream *********  
1723 -SHERPA_ONNX_EXTERN_C  
1724 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_acceptWaveform(  
1725 - JNIEnv *env, jobject /*obj*/, jlong s_ptr, jint sample_rate,  
1726 - jfloatArray waveform) {  
1727 - sherpa_onnx::OnlineStream *s =  
1728 - reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);  
1729 - jfloat *p = env->GetFloatArrayElements(waveform, nullptr);  
1730 - jsize n = env->GetArrayLength(waveform);  
1731 - s->AcceptWaveform(sample_rate, p, n);  
1732 - env->ReleaseFloatArrayElements(waveform, p, JNI_ABORT);  
1733 -}  
1734 -  
1735 -SHERPA_ONNX_EXTERN_C  
1736 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_inputFinished(  
1737 - JNIEnv *env, jobject /*obj*/, jlong s_ptr) {  
1738 - sherpa_onnx::OnlineStream *s =  
1739 - reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);  
1740 - s->InputFinished();  
1741 -}  
1742 -  
1743 -SHERPA_ONNX_EXTERN_C  
1744 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_deleteStream(  
1745 - JNIEnv *env, jobject /*obj*/, jlong s_ptr) {  
1746 - delete reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);  
1747 -}  
1748 -  
1749 -SHERPA_ONNX_EXTERN_C  
1750 -JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_numFramesReady(  
1751 - JNIEnv *env, jobject /*obj*/, jlong s_ptr) {  
1752 - sherpa_onnx::OnlineStream *s =  
1753 - reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);  
1754 - return s->NumFramesReady();  
1755 -}  
1756 -  
1757 -SHERPA_ONNX_EXTERN_C  
1758 -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_isLastFrame(  
1759 - JNIEnv *env, jobject /*obj*/, jlong s_ptr, jint frame) {  
1760 - sherpa_onnx::OnlineStream *s =  
1761 - reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);  
1762 - return s->IsLastFrame(frame);  
1763 -}  
1764 -SHERPA_ONNX_EXTERN_C  
1765 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_reSet(  
1766 - JNIEnv *env, jobject /*obj*/, jlong s_ptr) {  
1767 - sherpa_onnx::OnlineStream *s =  
1768 - reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);  
1769 - s->Reset();  
1770 -}  
1771 -  
1772 -SHERPA_ONNX_EXTERN_C  
1773 -JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_featureDim(  
1774 - JNIEnv *env, jobject /*obj*/, jlong s_ptr) {  
1775 - sherpa_onnx::OnlineStream *s =  
1776 - reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);  
1777 - return s->FeatureDim();  
1778 -}
129 +#endif
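The decodeStreams() binding parked under #if 0 above is the batched-decoding path: Java passes a jlongArray of stream handles, and the native side turns each handle back into an OnlineStream* before decoding the whole batch at once. Only the last three lines of its body are visible in the hunk; a sketch of a full body consistent with those surviving lines (the reconstruction, not the visible calls, is the assumption here, and it presumes <vector> is available via the existing includes):

    // ss_ptr is the jlongArray of stream handles passed from Java.
    jlong *p = env->GetLongArrayElements(ss_ptr, nullptr);
    jsize n = env->GetArrayLength(ss_ptr);

    auto model = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);

    // Convert each jlong handle back into an OnlineStream pointer.
    std::vector<sherpa_onnx::OnlineStream *> p_ss(n);
    for (jsize i = 0; i != n; ++i) {
      p_ss[i] = reinterpret_cast<sherpa_onnx::OnlineStream *>(p[i]);
    }

    model->DecodeStreams(p_ss.data(), n);
    env->ReleaseLongArrayElements(ss_ptr, p, JNI_ABORT);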
  1 +// sherpa-onnx/jni/keyword-spotter.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/keyword-spotter.h"
  6 +
  7 +#include "sherpa-onnx/csrc/macros.h"
  8 +#include "sherpa-onnx/jni/common.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
  13 + KeywordSpotterConfig ans;
  14 +
  15 + jclass cls = env->GetObjectClass(config);
  16 + jfieldID fid;
  17 +
  18 + // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
  19 + // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
  20 +
  21 + //---------- decoding ----------
  22 + fid = env->GetFieldID(cls, "maxActivePaths", "I");
  23 + ans.max_active_paths = env->GetIntField(config, fid);
  24 +
  25 + fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;");
  26 + jstring s = (jstring)env->GetObjectField(config, fid);
  27 + const char *p = env->GetStringUTFChars(s, nullptr);
  28 + ans.keywords_file = p;
  29 + env->ReleaseStringUTFChars(s, p);
  30 +
  31 + fid = env->GetFieldID(cls, "keywordsScore", "F");
  32 + ans.keywords_score = env->GetFloatField(config, fid);
  33 +
  34 + fid = env->GetFieldID(cls, "keywordsThreshold", "F");
  35 + ans.keywords_threshold = env->GetFloatField(config, fid);
  36 +
  37 + fid = env->GetFieldID(cls, "numTrailingBlanks", "I");
  38 + ans.num_trailing_blanks = env->GetIntField(config, fid);
  39 +
  40 + //---------- feat config ----------
  41 + fid = env->GetFieldID(cls, "featConfig",
  42 + "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
  43 + jobject feat_config = env->GetObjectField(config, fid);
  44 + jclass feat_config_cls = env->GetObjectClass(feat_config);
  45 +
  46 + fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
  47 + ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
  48 +
  49 + fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
  50 + ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
  51 +
  52 + //---------- model config ----------
  53 + fid = env->GetFieldID(cls, "modelConfig",
  54 + "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
  55 + jobject model_config = env->GetObjectField(config, fid);
  56 + jclass model_config_cls = env->GetObjectClass(model_config);
  57 +
  58 + // transducer
  59 + fid = env->GetFieldID(model_config_cls, "transducer",
  60 + "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
  61 + jobject transducer_config = env->GetObjectField(model_config, fid);
  62 + jclass transducer_config_cls = env->GetObjectClass(transducer_config);
  63 +
  64 + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
  65 + s = (jstring)env->GetObjectField(transducer_config, fid);
  66 + p = env->GetStringUTFChars(s, nullptr);
  67 + ans.model_config.transducer.encoder = p;
  68 + env->ReleaseStringUTFChars(s, p);
  69 +
  70 + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
  71 + s = (jstring)env->GetObjectField(transducer_config, fid);
  72 + p = env->GetStringUTFChars(s, nullptr);
  73 + ans.model_config.transducer.decoder = p;
  74 + env->ReleaseStringUTFChars(s, p);
  75 +
  76 + fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
  77 + s = (jstring)env->GetObjectField(transducer_config, fid);
  78 + p = env->GetStringUTFChars(s, nullptr);
  79 + ans.model_config.transducer.joiner = p;
  80 + env->ReleaseStringUTFChars(s, p);
  81 +
  82 + fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
  83 + s = (jstring)env->GetObjectField(model_config, fid);
  84 + p = env->GetStringUTFChars(s, nullptr);
  85 + ans.model_config.tokens = p;
  86 + env->ReleaseStringUTFChars(s, p);
  87 +
  88 + fid = env->GetFieldID(model_config_cls, "numThreads", "I");
  89 + ans.model_config.num_threads = env->GetIntField(model_config, fid);
  90 +
  91 + fid = env->GetFieldID(model_config_cls, "debug", "Z");
  92 + ans.model_config.debug = env->GetBooleanField(model_config, fid);
  93 +
  94 + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
  95 + s = (jstring)env->GetObjectField(model_config, fid);
  96 + p = env->GetStringUTFChars(s, nullptr);
  97 + ans.model_config.provider = p;
  98 + env->ReleaseStringUTFChars(s, p);
  99 +
  100 + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
  101 + s = (jstring)env->GetObjectField(model_config, fid);
  102 + p = env->GetStringUTFChars(s, nullptr);
  103 + ans.model_config.model_type = p;
  104 + env->ReleaseStringUTFChars(s, p);
  105 +
  106 + return ans;
  107 +}
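GetKwsConfig() repeats the same four-call sequence for every String field: look up the field ID with its JNI type signature (per the Oracle documentation linked above: "I" for int, "F" for float, "Z" for boolean, "Ljava/lang/String;" for String), fetch the jstring, copy out the UTF-8 bytes, and release them. If that repetition were factored out in a later pass, a helper could look like the sketch below; GetStringField is a hypothetical name, not part of this PR, and it assumes <string> is pulled in by the existing headers:

    // Hypothetical helper (not in this PR): read a java.lang.String field
    // of `obj` into a std::string, releasing the JNI UTF buffer on the way.
    static std::string GetStringField(JNIEnv *env, jclass cls, jobject obj,
                                      const char *name) {
      jfieldID fid = env->GetFieldID(cls, name, "Ljava/lang/String;");
      jstring s = (jstring)env->GetObjectField(obj, fid);
      const char *p = env->GetStringUTFChars(s, nullptr);
      std::string ans = p;
      env->ReleaseStringUTFChars(s, p);
      return ans;
    }

    // Usage inside GetKwsConfig() would then reduce each field to one line:
    //   ans.keywords_file = GetStringField(env, cls, config, "keywordsFile");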
  108 +
  109 +} // namespace sherpa_onnx
  110 +
  111 +SHERPA_ONNX_EXTERN_C
  112 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromAsset(
  113 + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  114 +#if __ANDROID_API__ >= 9
  115 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  116 + if (!mgr) {
  117 + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  118 + }
  119 +#endif
  120 + auto config = sherpa_onnx::GetKwsConfig(env, _config);
  121 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  122 + auto kws = new sherpa_onnx::KeywordSpotter(
  123 +#if __ANDROID_API__ >= 9
  124 + mgr,
  125 +#endif
  126 + config);
  127 +
  128 + return (jlong)kws;
  129 +}
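The #if __ANDROID_API__ >= 9 guards let this translation unit build both for Android, where models are loaded through the APK's AAssetManager (available since API level 9), and for desktop JVMs, where __ANDROID_API__ is undefined, the guard evaluates false, and the asset_manager argument is simply ignored. A sketch of what the guarded constructor call above reduces to in the non-Android case (this is the preprocessor's output, not new code):

    // When __ANDROID_API__ is undefined or below 9, the asset manager is
    // never consulted and the spotter reads models from the filesystem.
    auto config = sherpa_onnx::GetKwsConfig(env, _config);
    SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
    auto kws = new sherpa_onnx::KeywordSpotter(config);
    return (jlong)kws;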
  130 +
  131 +SHERPA_ONNX_EXTERN_C
  132 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromFile(
  133 + JNIEnv *env, jobject /*obj*/, jobject _config) {
  134 + auto config = sherpa_onnx::GetKwsConfig(env, _config);
  135 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  136 +
  137 + if (!config.Validate()) {
  138 + SHERPA_ONNX_LOGE("Errors found in config!");
  139 + return 0;
  140 + }
  141 +
  142 + auto kws = new sherpa_onnx::KeywordSpotter(config);
  143 +
  144 + return (jlong)kws;
  145 +}
  146 +
  147 +SHERPA_ONNX_EXTERN_C
  148 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_delete(
  149 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  150 + delete reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
  151 +}
  152 +
  153 +SHERPA_ONNX_EXTERN_C
  154 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_decode(
  155 + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  156 + auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
  157 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
  158 +
  159 + kws->DecodeStream(stream);
  160 +}
  161 +
  162 +SHERPA_ONNX_EXTERN_C
  163 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream(
  164 + JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
  165 + auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
  166 +
  167 + const char *p = env->GetStringUTFChars(keywords, nullptr);
  168 + std::unique_ptr<sherpa_onnx::OnlineStream> stream;
  169 +
  170 + if (strlen(p) == 0) {
  171 + stream = kws->CreateStream();
  172 + } else {
  173 + stream = kws->CreateStream(p);
  174 + }
  175 +
  176 + env->ReleaseStringUTFChars(keywords, p);
  177 +
  178 + // The user is responsible for freeing the returned pointer.
  179 + //
  180 + // See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
  181 + // ./offline-stream.cc
  182 + sherpa_onnx::OnlineStream *ans = stream.release();
  183 + return (jlong)ans;
  184 +}
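createStream() releases the OnlineStream out of its unique_ptr, so ownership of the handle passes to the Java side; the cleanup the comment points at, Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() in ./offline-stream.cc, is the offline analogue, and the delete for this online stream presumably lives in the online-stream binding instead. By analogy with every other *_delete function in this refactor it would reduce to a cast plus delete; a sketch, with the function name inferred rather than taken from this diff:

    // Assumed shape of the matching cleanup binding (name inferred by
    // analogy; the actual definition is outside this hunk).
    SHERPA_ONNX_EXTERN_C
    JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_delete(
        JNIEnv *env, jobject /*obj*/, jlong ptr) {
      delete reinterpret_cast<sherpa_onnx::OnlineStream *>(ptr);
    }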
  185 +
  186 +SHERPA_ONNX_EXTERN_C
  187 +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_isReady(
  188 + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  189 + auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
  190 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
  191 +
  192 + return kws->IsReady(stream);
  193 +}
  194 +
  195 +SHERPA_ONNX_EXTERN_C
  196 +JNIEXPORT jobjectArray JNICALL
  197 +Java_com_k2fsa_sherpa_onnx_KeywordSpotter_getResult(JNIEnv *env,
  198 + jobject /*obj*/, jlong ptr,
  199 + jlong stream_ptr) {
  200 + auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
  201 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
  202 +
  203 + sherpa_onnx::KeywordResult result = kws->GetResult(stream);
  204 +
  205 + // [0]: keyword, jstring
  206 + // [1]: tokens, array of jstring
  207 + // [2]: timestamps, array of float
  208 + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
  209 + 3, env->FindClass("java/lang/Object"), nullptr);
  210 +
  211 + jstring keyword = env->NewStringUTF(result.keyword.c_str());
  212 + env->SetObjectArrayElement(obj_arr, 0, keyword);
  213 +
  214 + jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray(
  215 + result.tokens.size(), env->FindClass("java/lang/String"), nullptr);
  216 +
  217 + int32_t i = 0;
  218 + for (const auto &t : result.tokens) {
  219 + jstring jtext = env->NewStringUTF(t.c_str());
  220 + env->SetObjectArrayElement(tokens_arr, i, jtext);
  221 + i += 1;
  222 + }
  223 +
  224 + env->SetObjectArrayElement(obj_arr, 1, tokens_arr);
  225 +
  226 + jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size());
  227 + env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(),
  228 + result.timestamps.data());
  229 +
  230 + env->SetObjectArrayElement(obj_arr, 2, timestamps_arr);
  231 +
  232 + return obj_arr;
  233 +}
  1 +// sherpa-onnx/jni/offline-recognizer.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-recognizer.h"
  6 +
  7 +#include "sherpa-onnx/csrc/macros.h"
  8 +#include "sherpa-onnx/jni/common.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
  13 + OfflineRecognizerConfig ans;
  14 +
  15 + jclass cls = env->GetObjectClass(config);
  16 + jfieldID fid;
  17 +
  18 + //---------- decoding ----------
  19 + fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
  20 + jstring s = (jstring)env->GetObjectField(config, fid);
  21 + const char *p = env->GetStringUTFChars(s, nullptr);
  22 + ans.decoding_method = p;
  23 + env->ReleaseStringUTFChars(s, p);
  24 +
  25 + fid = env->GetFieldID(cls, "maxActivePaths", "I");
  26 + ans.max_active_paths = env->GetIntField(config, fid);
  27 +
  28 + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
  29 + s = (jstring)env->GetObjectField(config, fid);
  30 + p = env->GetStringUTFChars(s, nullptr);
  31 + ans.hotwords_file = p;
  32 + env->ReleaseStringUTFChars(s, p);
  33 +
  34 + fid = env->GetFieldID(cls, "hotwordsScore", "F");
  35 + ans.hotwords_score = env->GetFloatField(config, fid);
  36 +
  37 + //---------- feat config ----------
  38 + fid = env->GetFieldID(cls, "featConfig",
  39 + "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
  40 + jobject feat_config = env->GetObjectField(config, fid);
  41 + jclass feat_config_cls = env->GetObjectClass(feat_config);
  42 +
  43 + fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
  44 + ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
  45 +
  46 + fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
  47 + ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
  48 +
  49 + //---------- model config ----------
  50 + fid = env->GetFieldID(cls, "modelConfig",
  51 + "Lcom/k2fsa/sherpa/onnx/OfflineModelConfig;");
  52 + jobject model_config = env->GetObjectField(config, fid);
  53 + jclass model_config_cls = env->GetObjectClass(model_config);
  54 +
  55 + fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
  56 + s = (jstring)env->GetObjectField(model_config, fid);
  57 + p = env->GetStringUTFChars(s, nullptr);
  58 + ans.model_config.tokens = p;
  59 + env->ReleaseStringUTFChars(s, p);
  60 +
  61 + fid = env->GetFieldID(model_config_cls, "numThreads", "I");
  62 + ans.model_config.num_threads = env->GetIntField(model_config, fid);
  63 +
  64 + fid = env->GetFieldID(model_config_cls, "debug", "Z");
  65 + ans.model_config.debug = env->GetBooleanField(model_config, fid);
  66 +
  67 + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
  68 + s = (jstring)env->GetObjectField(model_config, fid);
  69 + p = env->GetStringUTFChars(s, nullptr);
  70 + ans.model_config.provider = p;
  71 + env->ReleaseStringUTFChars(s, p);
  72 +
  73 + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
  74 + s = (jstring)env->GetObjectField(model_config, fid);
  75 + p = env->GetStringUTFChars(s, nullptr);
  76 + ans.model_config.model_type = p;
  77 + env->ReleaseStringUTFChars(s, p);
  78 +
  79 + // transducer
  80 + fid = env->GetFieldID(model_config_cls, "transducer",
  81 + "Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;");
  82 + jobject transducer_config = env->GetObjectField(model_config, fid);
  83 + jclass transducer_config_cls = env->GetObjectClass(transducer_config);
  84 +
  85 + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
  86 + s = (jstring)env->GetObjectField(transducer_config, fid);
  87 + p = env->GetStringUTFChars(s, nullptr);
  88 + ans.model_config.transducer.encoder_filename = p;
  89 + env->ReleaseStringUTFChars(s, p);
  90 +
  91 + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
  92 + s = (jstring)env->GetObjectField(transducer_config, fid);
  93 + p = env->GetStringUTFChars(s, nullptr);
  94 + ans.model_config.transducer.decoder_filename = p;
  95 + env->ReleaseStringUTFChars(s, p);
  96 +
  97 + fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
  98 + s = (jstring)env->GetObjectField(transducer_config, fid);
  99 + p = env->GetStringUTFChars(s, nullptr);
  100 + ans.model_config.transducer.joiner_filename = p;
  101 + env->ReleaseStringUTFChars(s, p);
  102 +
  103 + // paraformer
  104 + fid = env->GetFieldID(model_config_cls, "paraformer",
  105 + "Lcom/k2fsa/sherpa/onnx/OfflineParaformerModelConfig;");
  106 + jobject paraformer_config = env->GetObjectField(model_config, fid);
  107 + jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
  108 +
  109 + fid = env->GetFieldID(paraformer_config_cls, "model", "Ljava/lang/String;");
  110 +
  111 + s = (jstring)env->GetObjectField(paraformer_config, fid);
  112 + p = env->GetStringUTFChars(s, nullptr);
  113 + ans.model_config.paraformer.model = p;
  114 + env->ReleaseStringUTFChars(s, p);
  115 +
  116 + // whisper
  117 + fid = env->GetFieldID(model_config_cls, "whisper",
  118 + "Lcom/k2fsa/sherpa/onnx/OfflineWhisperModelConfig;");
  119 + jobject whisper_config = env->GetObjectField(model_config, fid);
  120 + jclass whisper_config_cls = env->GetObjectClass(whisper_config);
  121 +
  122 + fid = env->GetFieldID(whisper_config_cls, "encoder", "Ljava/lang/String;");
  123 + s = (jstring)env->GetObjectField(whisper_config, fid);
  124 + p = env->GetStringUTFChars(s, nullptr);
  125 + ans.model_config.whisper.encoder = p;
  126 + env->ReleaseStringUTFChars(s, p);
  127 +
  128 + fid = env->GetFieldID(whisper_config_cls, "decoder", "Ljava/lang/String;");
  129 + s = (jstring)env->GetObjectField(whisper_config, fid);
  130 + p = env->GetStringUTFChars(s, nullptr);
  131 + ans.model_config.whisper.decoder = p;
  132 + env->ReleaseStringUTFChars(s, p);
  133 +
  134 + fid = env->GetFieldID(whisper_config_cls, "language", "Ljava/lang/String;");
  135 + s = (jstring)env->GetObjectField(whisper_config, fid);
  136 + p = env->GetStringUTFChars(s, nullptr);
  137 + ans.model_config.whisper.language = p;
  138 + env->ReleaseStringUTFChars(s, p);
  139 +
  140 + fid = env->GetFieldID(whisper_config_cls, "task", "Ljava/lang/String;");
  141 + s = (jstring)env->GetObjectField(whisper_config, fid);
  142 + p = env->GetStringUTFChars(s, nullptr);
  143 + ans.model_config.whisper.task = p;
  144 + env->ReleaseStringUTFChars(s, p);
  145 +
  146 + fid = env->GetFieldID(whisper_config_cls, "tailPaddings", "I");
  147 + ans.model_config.whisper.tail_paddings =
  148 + env->GetIntField(whisper_config, fid);
  149 +
  150 + return ans;
  151 +}
  152 +
  153 +} // namespace sherpa_onnx
  154 +
  155 +SHERPA_ONNX_EXTERN_C
  156 +JNIEXPORT jlong JNICALL
  157 +Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromAsset(JNIEnv *env,
  158 + jobject /*obj*/,
  159 + jobject asset_manager,
  160 + jobject _config) {
  161 +#if __ANDROID_API__ >= 9
  162 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  163 + if (!mgr) {
  164 + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  165 + }
  166 +#endif
  167 + auto config = sherpa_onnx::GetOfflineConfig(env, _config);
  168 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  169 + auto model = new sherpa_onnx::OfflineRecognizer(
  170 +#if __ANDROID_API__ >= 9
  171 + mgr,
  172 +#endif
  173 + config);
  174 +
  175 + return (jlong)model;
  176 +}
  177 +
  178 +SHERPA_ONNX_EXTERN_C
  179 +JNIEXPORT jlong JNICALL
  180 +Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromFile(JNIEnv *env,
  181 + jobject /*obj*/,
  182 + jobject _config) {
  183 + auto config = sherpa_onnx::GetOfflineConfig(env, _config);
  184 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  185 +
  186 + if (!config.Validate()) {
  187 + SHERPA_ONNX_LOGE("Errors found in config!");
  188 + return 0;
  189 + }
  190 +
  191 + auto model = new sherpa_onnx::OfflineRecognizer(config);
  192 +
  193 + return (jlong)model;
  194 +}
  195 +
  196 +SHERPA_ONNX_EXTERN_C
  197 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_delete(
  198 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  199 + delete reinterpret_cast<sherpa_onnx::OfflineRecognizer *>(ptr);
  200 +}
  201 +
  202 +SHERPA_ONNX_EXTERN_C
  203 +JNIEXPORT jlong JNICALL
  204 +Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_createStream(JNIEnv *env,
  205 + jobject /*obj*/,
  206 + jlong ptr) {
  207 + auto recognizer = reinterpret_cast<sherpa_onnx::OfflineRecognizer *>(ptr);
  208 + std::unique_ptr<sherpa_onnx::OfflineStream> s = recognizer->CreateStream();
  209 +
  210 + // The user is responsible for freeing the returned pointer.
  211 + //
  212 + // See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
  213 + // ./offline-stream.cc
  214 + sherpa_onnx::OfflineStream *p = s.release();
  215 + return (jlong)p;
  216 +}
  217 +
  218 +SHERPA_ONNX_EXTERN_C
  219 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_decode(
  220 + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong streamPtr) {
  221 + auto recognizer = reinterpret_cast<sherpa_onnx::OfflineRecognizer *>(ptr);
  222 + auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(streamPtr);
  223 +
  224 + recognizer->DecodeStream(stream);
  225 +}
  226 +
  227 +SHERPA_ONNX_EXTERN_C
  228 +JNIEXPORT jobjectArray JNICALL
  229 +Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env,
  230 + jobject /*obj*/,
  231 + jlong streamPtr) {
  232 + auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(streamPtr);
  233 + sherpa_onnx::OfflineRecognitionResult result = stream->GetResult();
  234 +
  235 + // [0]: text, jstring
  236 + // [1]: tokens, array of jstring
  237 + // [2]: timestamps, array of float
  238 + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
  239 + 3, env->FindClass("java/lang/Object"), nullptr);
  240 +
  241 + jstring text = env->NewStringUTF(result.text.c_str());
  242 + env->SetObjectArrayElement(obj_arr, 0, text);
  243 +
  244 + jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray(
  245 + result.tokens.size(), env->FindClass("java/lang/String"), nullptr);
  246 +
  247 + int32_t i = 0;
  248 + for (const auto &t : result.tokens) {
  249 + jstring jtext = env->NewStringUTF(t.c_str());
  250 + env->SetObjectArrayElement(tokens_arr, i, jtext);
  251 + i += 1;
  252 + }
  253 +
  254 + env->SetObjectArrayElement(obj_arr, 1, tokens_arr);
  255 +
  256 + jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size());
  257 + env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(),
  258 + result.timestamps.data());
  259 +
  260 + env->SetObjectArrayElement(obj_arr, 2, timestamps_arr);
  261 +
  262 + return obj_arr;
  263 +}
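
Note the contract here: getResult() returns a plain Object[3] rather than a dedicated result class, so the Kotlin side must cast each slot. A minimal sketch of the unpacking, matching the slot layout documented above (the function name is illustrative, not part of this commit):

// Unpack the Object[3] produced by the JNI getResult() above.
// Slot layout must stay in sync with the C++ side:
//   [0] text (String), [1] tokens (Array<String>), [2] timestamps (FloatArray)
fun unpackOfflineResult(objArray: Array<Any>): Triple<String, Array<String>, FloatArray> {
    val text = objArray[0] as String
    @Suppress("UNCHECKED_CAST")
    val tokens = objArray[1] as Array<String>
    val timestamps = objArray[2] as FloatArray
    return Triple(text, tokens, timestamps)
}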
  1 +// sherpa-onnx/jni/online-recognizer.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-recognizer.h"
  6 +
  7 +#include "sherpa-onnx/csrc/macros.h"
  8 +#include "sherpa-onnx/jni/common.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
  13 + OnlineRecognizerConfig ans;
  14 +
  15 + jclass cls = env->GetObjectClass(config);
  16 + jfieldID fid;
  17 +
  18 + // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
  19 + // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
  20 +
  21 + //---------- decoding ----------
  22 + fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
  23 + jstring s = (jstring)env->GetObjectField(config, fid);
  24 + const char *p = env->GetStringUTFChars(s, nullptr);
  25 + ans.decoding_method = p;
  26 + env->ReleaseStringUTFChars(s, p);
  27 +
  28 + fid = env->GetFieldID(cls, "maxActivePaths", "I");
  29 + ans.max_active_paths = env->GetIntField(config, fid);
  30 +
  31 + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
  32 + s = (jstring)env->GetObjectField(config, fid);
  33 + p = env->GetStringUTFChars(s, nullptr);
  34 + ans.hotwords_file = p;
  35 + env->ReleaseStringUTFChars(s, p);
  36 +
  37 + fid = env->GetFieldID(cls, "hotwordsScore", "F");
  38 + ans.hotwords_score = env->GetFloatField(config, fid);
  39 +
  40 + //---------- feat config ----------
  41 + fid = env->GetFieldID(cls, "featConfig",
  42 + "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
  43 + jobject feat_config = env->GetObjectField(config, fid);
  44 + jclass feat_config_cls = env->GetObjectClass(feat_config);
  45 +
  46 + fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
  47 + ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
  48 +
  49 + fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
  50 + ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
  51 +
  52 + //---------- enable endpoint ----------
  53 + fid = env->GetFieldID(cls, "enableEndpoint", "Z");
  54 + ans.enable_endpoint = env->GetBooleanField(config, fid);
  55 +
  56 + //---------- endpoint_config ----------
  57 +
  58 + fid = env->GetFieldID(cls, "endpointConfig",
  59 + "Lcom/k2fsa/sherpa/onnx/EndpointConfig;");
  60 + jobject endpoint_config = env->GetObjectField(config, fid);
  61 + jclass endpoint_config_cls = env->GetObjectClass(endpoint_config);
  62 +
  63 + fid = env->GetFieldID(endpoint_config_cls, "rule1",
  64 + "Lcom/k2fsa/sherpa/onnx/EndpointRule;");
  65 + jobject rule1 = env->GetObjectField(endpoint_config, fid);
  66 + jclass rule_class = env->GetObjectClass(rule1);
  67 +
  68 + fid = env->GetFieldID(endpoint_config_cls, "rule2",
  69 + "Lcom/k2fsa/sherpa/onnx/EndpointRule;");
  70 + jobject rule2 = env->GetObjectField(endpoint_config, fid);
  71 +
  72 + fid = env->GetFieldID(endpoint_config_cls, "rule3",
  73 + "Lcom/k2fsa/sherpa/onnx/EndpointRule;");
  74 + jobject rule3 = env->GetObjectField(endpoint_config, fid);
  75 +
  76 + fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z");
  77 + ans.endpoint_config.rule1.must_contain_nonsilence =
  78 + env->GetBooleanField(rule1, fid);
  79 + ans.endpoint_config.rule2.must_contain_nonsilence =
  80 + env->GetBooleanField(rule2, fid);
  81 + ans.endpoint_config.rule3.must_contain_nonsilence =
  82 + env->GetBooleanField(rule3, fid);
  83 +
  84 + fid = env->GetFieldID(rule_class, "minTrailingSilence", "F");
  85 + ans.endpoint_config.rule1.min_trailing_silence =
  86 + env->GetFloatField(rule1, fid);
  87 + ans.endpoint_config.rule2.min_trailing_silence =
  88 + env->GetFloatField(rule2, fid);
  89 + ans.endpoint_config.rule3.min_trailing_silence =
  90 + env->GetFloatField(rule3, fid);
  91 +
  92 + fid = env->GetFieldID(rule_class, "minUtteranceLength", "F");
  93 + ans.endpoint_config.rule1.min_utterance_length =
  94 + env->GetFloatField(rule1, fid);
  95 + ans.endpoint_config.rule2.min_utterance_length =
  96 + env->GetFloatField(rule2, fid);
  97 + ans.endpoint_config.rule3.min_utterance_length =
  98 + env->GetFloatField(rule3, fid);
  99 +
  100 + //---------- model config ----------
  101 + fid = env->GetFieldID(cls, "modelConfig",
  102 + "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
  103 + jobject model_config = env->GetObjectField(config, fid);
  104 + jclass model_config_cls = env->GetObjectClass(model_config);
  105 +
  106 + // transducer
  107 + fid = env->GetFieldID(model_config_cls, "transducer",
  108 + "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
  109 + jobject transducer_config = env->GetObjectField(model_config, fid);
  110 + jclass transducer_config_cls = env->GetObjectClass(transducer_config);
  111 +
  112 + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
  113 + s = (jstring)env->GetObjectField(transducer_config, fid);
  114 + p = env->GetStringUTFChars(s, nullptr);
  115 + ans.model_config.transducer.encoder = p;
  116 + env->ReleaseStringUTFChars(s, p);
  117 +
  118 + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
  119 + s = (jstring)env->GetObjectField(transducer_config, fid);
  120 + p = env->GetStringUTFChars(s, nullptr);
  121 + ans.model_config.transducer.decoder = p;
  122 + env->ReleaseStringUTFChars(s, p);
  123 +
  124 + fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
  125 + s = (jstring)env->GetObjectField(transducer_config, fid);
  126 + p = env->GetStringUTFChars(s, nullptr);
  127 + ans.model_config.transducer.joiner = p;
  128 + env->ReleaseStringUTFChars(s, p);
  129 +
  130 + // paraformer
  131 + fid = env->GetFieldID(model_config_cls, "paraformer",
  132 + "Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
  133 + jobject paraformer_config = env->GetObjectField(model_config, fid);
  134 + jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
  135 +
  136 + fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;");
  137 + s = (jstring)env->GetObjectField(paraformer_config, fid);
  138 + p = env->GetStringUTFChars(s, nullptr);
  139 + ans.model_config.paraformer.encoder = p;
  140 + env->ReleaseStringUTFChars(s, p);
  141 +
  142 + fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;");
  143 + s = (jstring)env->GetObjectField(paraformer_config, fid);
  144 + p = env->GetStringUTFChars(s, nullptr);
  145 + ans.model_config.paraformer.decoder = p;
  146 + env->ReleaseStringUTFChars(s, p);
  147 +
  148 + // streaming zipformer2 CTC
  149 + fid =
  150 + env->GetFieldID(model_config_cls, "zipformer2Ctc",
  151 + "Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;");
  152 + jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid);
  153 + jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config);
  154 +
  155 + fid =
  156 + env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;");
  157 + s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid);
  158 + p = env->GetStringUTFChars(s, nullptr);
  159 + ans.model_config.zipformer2_ctc.model = p;
  160 + env->ReleaseStringUTFChars(s, p);
  161 +
  162 + fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
  163 + s = (jstring)env->GetObjectField(model_config, fid);
  164 + p = env->GetStringUTFChars(s, nullptr);
  165 + ans.model_config.tokens = p;
  166 + env->ReleaseStringUTFChars(s, p);
  167 +
  168 + fid = env->GetFieldID(model_config_cls, "numThreads", "I");
  169 + ans.model_config.num_threads = env->GetIntField(model_config, fid);
  170 +
  171 + fid = env->GetFieldID(model_config_cls, "debug", "Z");
  172 + ans.model_config.debug = env->GetBooleanField(model_config, fid);
  173 +
  174 + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
  175 + s = (jstring)env->GetObjectField(model_config, fid);
  176 + p = env->GetStringUTFChars(s, nullptr);
  177 + ans.model_config.provider = p;
  178 + env->ReleaseStringUTFChars(s, p);
  179 +
  180 + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
  181 + s = (jstring)env->GetObjectField(model_config, fid);
  182 + p = env->GetStringUTFChars(s, nullptr);
  183 + ans.model_config.model_type = p;
  184 + env->ReleaseStringUTFChars(s, p);
  185 +
  186 + //---------- rnn lm model config ----------
  187 + fid = env->GetFieldID(cls, "lmConfig",
  188 + "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
  189 + jobject lm_model_config = env->GetObjectField(config, fid);
  190 + jclass lm_model_config_cls = env->GetObjectClass(lm_model_config);
  191 +
  192 + fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;");
  193 + s = (jstring)env->GetObjectField(lm_model_config, fid);
  194 + p = env->GetStringUTFChars(s, nullptr);
  195 + ans.lm_config.model = p;
  196 + env->ReleaseStringUTFChars(s, p);
  197 +
  198 + fid = env->GetFieldID(lm_model_config_cls, "scale", "F");
  199 + ans.lm_config.scale = env->GetFloatField(lm_model_config, fid);
  200 +
  201 + return ans;
  202 +}
  203 +} // namespace sherpa_onnx
  204 +
  205 +SHERPA_ONNX_EXTERN_C
  206 +JNIEXPORT jlong JNICALL
  207 +Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromAsset(JNIEnv *env,
  208 + jobject /*obj*/,
  209 + jobject asset_manager,
  210 + jobject _config) {
  211 +#if __ANDROID_API__ >= 9
  212 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  213 + if (!mgr) {
  214 + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  215 + }
  216 +#endif
  217 + auto config = sherpa_onnx::GetConfig(env, _config);
  218 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  219 +
  220 + auto recognizer = new sherpa_onnx::OnlineRecognizer(
  221 +#if __ANDROID_API__ >= 9
  222 + mgr,
  223 +#endif
  224 + config);
  225 +
  226 + return (jlong)recognizer;
  227 +}
  228 +
  229 +SHERPA_ONNX_EXTERN_C
  230 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromFile(
  231 + JNIEnv *env, jobject /*obj*/, jobject _config) {
  232 + auto config = sherpa_onnx::GetConfig(env, _config);
  233 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  234 +
  235 + if (!config.Validate()) {
  236 + SHERPA_ONNX_LOGE("Errors found in config!");
  237 + return 0;
  238 + }
  239 +
  240 + auto recognizer = new sherpa_onnx::OnlineRecognizer(config);
  241 +
  242 + return (jlong)recognizer;
  243 +}
  244 +
  245 +SHERPA_ONNX_EXTERN_C
  246 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_delete(
  247 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  248 + delete reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
  249 +}
  250 +
  251 +SHERPA_ONNX_EXTERN_C
  252 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_reset(
  253 + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  254 + auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
  255 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
  256 + recognizer->Reset(stream);
  257 +}
  258 +
  259 +SHERPA_ONNX_EXTERN_C
  260 +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isReady(
  261 + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  262 + auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
  263 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
  264 +
  265 + return recognizer->IsReady(stream);
  266 +}
  267 +
  268 +SHERPA_ONNX_EXTERN_C
  269 +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isEndpoint(
  270 + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  271 + auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
  272 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
  273 +
  274 + return recognizer->IsEndpoint(stream);
  275 +}
  276 +
  277 +SHERPA_ONNX_EXTERN_C
  278 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decode(
  279 + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  280 + auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
  281 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
  282 +
  283 + recognizer->DecodeStream(stream);
  284 +}
  285 +
  286 +SHERPA_ONNX_EXTERN_C
  287 +JNIEXPORT jlong JNICALL
  288 +Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_createStream(JNIEnv *env,
  289 + jobject /*obj*/,
  290 + jlong ptr,
  291 + jstring hotwords) {
  292 + auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
  293 +
  294 + const char *p = env->GetStringUTFChars(hotwords, nullptr);
  295 + std::unique_ptr<sherpa_onnx::OnlineStream> stream;
  296 +
  297 + if (strlen(p) == 0) {
  298 + stream = recognizer->CreateStream();
  299 + } else {
  300 + stream = recognizer->CreateStream(p);
  301 + }
  302 +
  303 + env->ReleaseStringUTFChars(hotwords, p);
  304 +
  305 + // The caller is responsible for freeing the returned pointer.
  306 + //
  307 + // See Java_com_k2fsa_sherpa_onnx_OnlineStream_delete() from
  308 + // ./online-stream.cc
  309 + sherpa_onnx::OnlineStream *ans = stream.release();
  310 + return (jlong)ans;
  311 +}
  312 +
  313 +SHERPA_ONNX_EXTERN_C
  314 +JNIEXPORT jobjectArray JNICALL
  315 +Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(JNIEnv *env,
  316 + jobject /*obj*/,
  317 + jlong ptr,
  318 + jlong stream_ptr) {
  319 + auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
  320 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
  321 +
  322 + sherpa_onnx::OnlineRecognizerResult result = recognizer->GetResult(stream);
  323 +
  324 + // [0]: text, jstring
  325 + // [1]: tokens, array of jstring
  326 + // [2]: timestamps, array of float
  327 + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
  328 + 3, env->FindClass("java/lang/Object"), nullptr);
  329 +
  330 + jstring text = env->NewStringUTF(result.text.c_str());
  331 + env->SetObjectArrayElement(obj_arr, 0, text);
  332 +
  333 + jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray(
  334 + result.tokens.size(), env->FindClass("java/lang/String"), nullptr);
  335 +
  336 + int32_t i = 0;
  337 + for (const auto &t : result.tokens) {
  338 + jstring jtext = env->NewStringUTF(t.c_str());
  339 + env->SetObjectArrayElement(tokens_arr, i, jtext);
  340 + i += 1;
  341 + }
  342 +
  343 + env->SetObjectArrayElement(obj_arr, 1, tokens_arr);
  344 +
  345 + jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size());
  346 + env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(),
  347 + result.timestamps.data());
  348 +
  349 + env->SetObjectArrayElement(obj_arr, 2, timestamps_arr);
  350 +
  351 + return obj_arr;
  352 +}
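
The entry points above map onto a simple endpoint-driven loop on the Kotlin side. A minimal sketch, assuming the OnlineRecognizer and OnlineStream wrappers introduced later in this commit; nextChunk() is a hypothetical microphone source yielding 16 kHz FloatArray chunks:

// Feed audio, decode whatever is ready, and emit text at each endpoint.
fun run(recognizer: OnlineRecognizer, nextChunk: () -> FloatArray?) {
    val stream = recognizer.createStream()
    while (true) {
        val chunk = nextChunk() ?: break
        stream.acceptWaveform(chunk, sampleRate = 16000)
        while (recognizer.isReady(stream)) {
            recognizer.decode(stream)
        }
        if (recognizer.isEndpoint(stream)) {
            println("endpoint: " + recognizer.getResult(stream).text)
            recognizer.reset(stream)
        }
    }
    stream.release()
}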
  1 +// sherpa-onnx/jni/online-stream.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-stream.h"
  6 +
  7 +#include "sherpa-onnx/jni/common.h"
  8 +
  9 +SHERPA_ONNX_EXTERN_C
  10 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_delete(
  11 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  12 + delete reinterpret_cast<sherpa_onnx::OnlineStream *>(ptr);
  13 +}
  14 +
  15 +SHERPA_ONNX_EXTERN_C
  16 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_acceptWaveform(
  17 + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
  18 + jint sample_rate) {
  19 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(ptr);
  20 +
  21 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  22 + jsize n = env->GetArrayLength(samples);
  23 + stream->AcceptWaveform(sample_rate, p, n);
  24 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  25 +}
  26 +
  27 +SHERPA_ONNX_EXTERN_C
  28 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_inputFinished(
  29 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  30 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(ptr);
  31 + stream->InputFinished();
  32 +}
  1 +// sherpa-onnx/jni/speaker-embedding-extractor.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
  5 +
  6 +#include "sherpa-onnx/jni/common.h"
  7 +
  8 +namespace sherpa_onnx {
  9 +
  10 +static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig(
  11 + JNIEnv *env, jobject config) {
  12 + SpeakerEmbeddingExtractorConfig ans;
  13 +
  14 + jclass cls = env->GetObjectClass(config);
  15 +
  16 + jfieldID fid = env->GetFieldID(cls, "model", "Ljava/lang/String;");
  17 + jstring s = (jstring)env->GetObjectField(config, fid);
  18 + const char *p = env->GetStringUTFChars(s, nullptr);
  19 +
  20 + ans.model = p;
  21 + env->ReleaseStringUTFChars(s, p);
  22 +
  23 + fid = env->GetFieldID(cls, "numThreads", "I");
  24 + ans.num_threads = env->GetIntField(config, fid);
  25 +
  26 + fid = env->GetFieldID(cls, "debug", "Z");
  27 + ans.debug = env->GetBooleanField(config, fid);
  28 +
  29 + fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
  30 + s = (jstring)env->GetObjectField(config, fid);
  31 + p = env->GetStringUTFChars(s, nullptr);
  32 + ans.provider = p;
  33 + env->ReleaseStringUTFChars(s, p);
  34 +
  35 + return ans;
  36 +}
  37 +
  38 +} // namespace sherpa_onnx
  39 +
  40 +SHERPA_ONNX_EXTERN_C
  41 +JNIEXPORT jlong JNICALL
  42 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromAsset(
  43 + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  44 +#if __ANDROID_API__ >= 9
  45 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  46 + if (!mgr) {
  47 + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  48 + }
  49 +#endif
  50 + auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
  51 + SHERPA_ONNX_LOGE("new config:\n%s", config.ToString().c_str());
  52 +
  53 + auto extractor = new sherpa_onnx::SpeakerEmbeddingExtractor(
  54 +#if __ANDROID_API__ >= 9
  55 + mgr,
  56 +#endif
  57 + config);
  58 +
  59 + return (jlong)extractor;
  60 +}
  61 +
  62 +SHERPA_ONNX_EXTERN_C
  63 +JNIEXPORT jlong JNICALL
  64 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromFile(
  65 + JNIEnv *env, jobject /*obj*/, jobject _config) {
  66 + auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
  67 + SHERPA_ONNX_LOGE("newFromFile config:\n%s", config.ToString().c_str());
  68 +
  69 + if (!config.Validate()) {
  70 + SHERPA_ONNX_LOGE("Errors found in config!");
  71 + }
  72 +
  73 + auto extractor = new sherpa_onnx::SpeakerEmbeddingExtractor(config);
  74 +
  75 + return (jlong)extractor;
  76 +}
  77 +
  78 +SHERPA_ONNX_EXTERN_C
  79 +JNIEXPORT void JNICALL
  80 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_delete(JNIEnv *env,
  81 + jobject /*obj*/,
  82 + jlong ptr) {
  83 + delete reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
  84 +}
  85 +
  86 +SHERPA_ONNX_EXTERN_C
  87 +JNIEXPORT jlong JNICALL
  88 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_createStream(
  89 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  90 + std::unique_ptr<sherpa_onnx::OnlineStream> s =
  91 + reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr)
  92 + ->CreateStream();
  93 +
  94 + // The caller is responsible for freeing the returned pointer.
  95 + //
  96 + // See Java_com_k2fsa_sherpa_onnx_OnlineStream_delete() from
  97 + // ./online-stream.cc
  98 + sherpa_onnx::OnlineStream *p = s.release();
  99 + return (jlong)p;
  100 +}
  101 +
  102 +SHERPA_ONNX_EXTERN_C
  103 +JNIEXPORT jboolean JNICALL
  104 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_isReady(JNIEnv *env,
  105 + jobject /*obj*/,
  106 + jlong ptr,
  107 + jlong stream_ptr) {
  108 + auto extractor =
  109 + reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
  110 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
  111 + return extractor->IsReady(stream);
  112 +}
  113 +
  114 +SHERPA_ONNX_EXTERN_C
  115 +JNIEXPORT jfloatArray JNICALL
  116 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_compute(JNIEnv *env,
  117 + jobject /*obj*/,
  118 + jlong ptr,
  119 + jlong stream_ptr) {
  120 + auto extractor =
  121 + reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
  122 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
  123 +
  124 + std::vector<float> embedding = extractor->Compute(stream);
  125 + jfloatArray embedding_arr = env->NewFloatArray(embedding.size());
  126 + env->SetFloatArrayRegion(embedding_arr, 0, embedding.size(),
  127 + embedding.data());
  128 + return embedding_arr;
  129 +}
  130 +
  131 +SHERPA_ONNX_EXTERN_C
  132 +JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_dim(
  133 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  134 + auto extractor =
  135 + reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
  136 + return extractor->Dim();
  137 +}
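
Taken together, the functions above suggest the following usage pattern. A minimal sketch, assuming a Kotlin SpeakerEmbeddingExtractor wrapper whose methods mirror these JNI entry points (createStream, isReady, compute, dim); `samples` is a hypothetical 16 kHz mono recording:

// Compute one speaker embedding from a finished utterance.
fun embed(extractor: SpeakerEmbeddingExtractor, samples: FloatArray): FloatArray {
    val stream = extractor.createStream()
    stream.acceptWaveform(samples, sampleRate = 16000)
    stream.inputFinished() // flush: no more audio for this utterance
    check(extractor.isReady(stream)) { "not enough audio for an embedding" }
    val embedding = extractor.compute(stream) // embedding.size == extractor.dim()
    stream.release()
    return embedding
}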
  1 +// sherpa-onnx/jni/speaker-embedding-manager.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
  5 +
  6 +#include "sherpa-onnx/csrc/macros.h"
  7 +#include "sherpa-onnx/jni/common.h"
  8 +
  9 +SHERPA_ONNX_EXTERN_C
  10 +JNIEXPORT jlong JNICALL
  11 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_create(JNIEnv *env,
  12 + jobject /*obj*/,
  13 + jint dim) {
  14 + auto p = new sherpa_onnx::SpeakerEmbeddingManager(dim);
  15 + return (jlong)p;
  16 +}
  17 +
  18 +SHERPA_ONNX_EXTERN_C
  19 +JNIEXPORT void JNICALL
  20 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_delete(JNIEnv *env,
  21 + jobject /*obj*/,
  22 + jlong ptr) {
  23 + auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
  24 + delete manager;
  25 +}
  26 +
  27 +SHERPA_ONNX_EXTERN_C
  28 +JNIEXPORT jboolean JNICALL
  29 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_add(JNIEnv *env,
  30 + jobject /*obj*/,
  31 + jlong ptr, jstring name,
  32 + jfloatArray embedding) {
  33 + auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
  34 +
  35 + jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
  36 + jsize n = env->GetArrayLength(embedding);
  37 +
  38 + if (n != manager->Dim()) {
  39 + SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
  40 + static_cast<int32_t>(n));
  41 + exit(-1);
  42 + }
  43 +
  44 + const char *p_name = env->GetStringUTFChars(name, nullptr);
  45 +
  46 + jboolean ok = manager->Add(p_name, p);
  47 + env->ReleaseStringUTFChars(name, p_name);
  48 + env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
  49 +
  50 + return ok;
  51 +}
  52 +
  53 +SHERPA_ONNX_EXTERN_C
  54 +JNIEXPORT jboolean JNICALL
  55 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_addList(
  56 + JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
  57 + jobjectArray embedding_arr) {
  58 + auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
  59 +
  60 + int num_embeddings = env->GetArrayLength(embedding_arr);
  61 + if (num_embeddings == 0) {
  62 + return false;
  63 + }
  64 +
  65 + std::vector<std::vector<float>> embedding_list;
  66 + embedding_list.reserve(num_embeddings);
  67 + for (int32_t i = 0; i != num_embeddings; ++i) {
  68 + jfloatArray embedding =
  69 + (jfloatArray)env->GetObjectArrayElement(embedding_arr, i);
  70 +
  71 + jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
  72 + jsize n = env->GetArrayLength(embedding);
  73 +
  74 + if (n != manager->Dim()) {
  75 + SHERPA_ONNX_LOGE("i: %d. Expected dim %d, given %d", i, manager->Dim(),
  76 + static_cast<int32_t>(n));
  77 + exit(-1);
  78 + }
  79 +
  80 + embedding_list.push_back({p, p + n});
  81 + env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
  82 + }
  83 +
  84 + const char *p_name = env->GetStringUTFChars(name, nullptr);
  85 +
  86 + jboolean ok = manager->Add(p_name, embedding_list);
  87 +
  88 + env->ReleaseStringUTFChars(name, p_name);
  89 +
  90 + return ok;
  91 +}
  92 +
  93 +SHERPA_ONNX_EXTERN_C
  94 +JNIEXPORT jboolean JNICALL
  95 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_remove(JNIEnv *env,
  96 + jobject /*obj*/,
  97 + jlong ptr,
  98 + jstring name) {
  99 + auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
  100 +
  101 + const char *p_name = env->GetStringUTFChars(name, nullptr);
  102 +
  103 + jboolean ok = manager->Remove(p_name);
  104 +
  105 + env->ReleaseStringUTFChars(name, p_name);
  106 +
  107 + return ok;
  108 +}
  109 +
  110 +SHERPA_ONNX_EXTERN_C
  111 +JNIEXPORT jstring JNICALL
  112 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_search(JNIEnv *env,
  113 + jobject /*obj*/,
  114 + jlong ptr,
  115 + jfloatArray embedding,
  116 + jfloat threshold) {
  117 + auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
  118 +
  119 + jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
  120 + jsize n = env->GetArrayLength(embedding);
  121 +
  122 + if (n != manager->Dim()) {
  123 + SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
  124 + static_cast<int32_t>(n));
  125 + exit(-1);
  126 + }
  127 +
  128 + std::string name = manager->Search(p, threshold);
  129 +
  130 + env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
  131 +
  132 + return env->NewStringUTF(name.c_str());
  133 +}
  134 +
  135 +SHERPA_ONNX_EXTERN_C
  136 +JNIEXPORT jboolean JNICALL
  137 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_verify(
  138 + JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
  139 + jfloatArray embedding, jfloat threshold) {
  140 + auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
  141 +
  142 + jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
  143 + jsize n = env->GetArrayLength(embedding);
  144 +
  145 + if (n != manager->Dim()) {
  146 + SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
  147 + static_cast<int32_t>(n));
  148 + exit(-1);
  149 + }
  150 +
  151 + const char *p_name = env->GetStringUTFChars(name, nullptr);
  152 +
  153 + jboolean ok = manager->Verify(p_name, p, threshold);
  154 +
  155 + env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
  156 +
  157 + env->ReleaseStringUTFChars(name, p_name);
  158 +
  159 + return ok;
  160 +}
  161 +
  162 +SHERPA_ONNX_EXTERN_C
  163 +JNIEXPORT jboolean JNICALL
  164 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_contains(JNIEnv *env,
  165 + jobject /*obj*/,
  166 + jlong ptr,
  167 + jstring name) {
  168 + auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
  169 +
  170 + const char *p_name = env->GetStringUTFChars(name, nullptr);
  171 +
  172 + jboolean ok = manager->Contains(p_name);
  173 +
  174 + env->ReleaseStringUTFChars(name, p_name);
  175 +
  176 + return ok;
  177 +}
  178 +
  179 +SHERPA_ONNX_EXTERN_C
  180 +JNIEXPORT jint JNICALL
  181 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_numSpeakers(JNIEnv *env,
  182 + jobject /*obj*/,
  183 + jlong ptr) {
  184 + auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
  185 + return manager->NumSpeakers();
  186 +}
  187 +
  188 +SHERPA_ONNX_EXTERN_C
  189 +JNIEXPORT jobjectArray JNICALL
  190 +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(
  191 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  192 + auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
  193 + std::vector<std::string> all_speakers = manager->GetAllSpeakers();
  194 +
  195 + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
  196 + all_speakers.size(), env->FindClass("java/lang/String"), nullptr);
  197 +
  198 + int32_t i = 0;
  199 + for (auto &s : all_speakers) {
  200 + jstring js = env->NewStringUTF(s.c_str());
  201 + env->SetObjectArrayElement(obj_arr, i, js);
  202 +
  203 + ++i;
  204 + }
  205 +
  206 + return obj_arr;
  207 +}
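
A sketch of the enroll/search/verify flow, assuming a Kotlin SpeakerEmbeddingManager wrapper that maps 1:1 onto the JNI functions above. Note that the C++ side calls exit(-1) on a dimension mismatch, so the embedding length must equal the manager's dim; the speaker name and threshold below are illustrative:

fun enrollAndMatch(manager: SpeakerEmbeddingManager, emb: FloatArray) {
    if (!manager.contains("alice")) {
        manager.add("alice", emb) // emb.size must equal the manager dim
    }
    val hit = manager.search(emb, threshold = 0.5f) // presumably "" when nothing matches
    val ok = manager.verify("alice", emb, threshold = 0.5f)
    println("search=$hit verify=$ok speakers=${manager.numSpeakers()}")
}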
  1 +// sherpa-onnx/jni/voice-activity-detector.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/voice-activity-detector.h"
  5 +
  6 +#include "sherpa-onnx/csrc/macros.h"
  7 +#include "sherpa-onnx/jni/common.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
  12 + VadModelConfig ans;
  13 +
  14 + jclass cls = env->GetObjectClass(config);
  15 + jfieldID fid;
  16 +
  17 + // silero_vad
  18 + fid = env->GetFieldID(cls, "sileroVadModelConfig",
  19 + "Lcom/k2fsa/sherpa/onnx/SileroVadModelConfig;");
  20 + jobject silero_vad_config = env->GetObjectField(config, fid);
  21 + jclass silero_vad_config_cls = env->GetObjectClass(silero_vad_config);
  22 +
  23 + fid = env->GetFieldID(silero_vad_config_cls, "model", "Ljava/lang/String;");
  24 + auto s = (jstring)env->GetObjectField(silero_vad_config, fid);
  25 + auto p = env->GetStringUTFChars(s, nullptr);
  26 + ans.silero_vad.model = p;
  27 + env->ReleaseStringUTFChars(s, p);
  28 +
  29 + fid = env->GetFieldID(silero_vad_config_cls, "threshold", "F");
  30 + ans.silero_vad.threshold = env->GetFloatField(silero_vad_config, fid);
  31 +
  32 + fid = env->GetFieldID(silero_vad_config_cls, "minSilenceDuration", "F");
  33 + ans.silero_vad.min_silence_duration =
  34 + env->GetFloatField(silero_vad_config, fid);
  35 +
  36 + fid = env->GetFieldID(silero_vad_config_cls, "minSpeechDuration", "F");
  37 + ans.silero_vad.min_speech_duration =
  38 + env->GetFloatField(silero_vad_config, fid);
  39 +
  40 + fid = env->GetFieldID(silero_vad_config_cls, "windowSize", "I");
  41 + ans.silero_vad.window_size = env->GetIntField(silero_vad_config, fid);
  42 +
  43 + fid = env->GetFieldID(cls, "sampleRate", "I");
  44 + ans.sample_rate = env->GetIntField(config, fid);
  45 +
  46 + fid = env->GetFieldID(cls, "numThreads", "I");
  47 + ans.num_threads = env->GetIntField(config, fid);
  48 +
  49 + fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
  50 + s = (jstring)env->GetObjectField(config, fid);
  51 + p = env->GetStringUTFChars(s, nullptr);
  52 + ans.provider = p;
  53 + env->ReleaseStringUTFChars(s, p);
  54 +
  55 + fid = env->GetFieldID(cls, "debug", "Z");
  56 + ans.debug = env->GetBooleanField(config, fid);
  57 +
  58 + return ans;
  59 +}
  60 +
  61 +} // namespace sherpa_onnx
  62 +
  63 +SHERPA_ONNX_EXTERN_C
  64 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromAsset(
  65 + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  66 +#if __ANDROID_API__ >= 9
  67 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  68 + if (!mgr) {
  69 + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  70 + }
  71 +#endif
  72 + auto config = sherpa_onnx::GetVadModelConfig(env, _config);
  73 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  74 + auto model = new sherpa_onnx::VoiceActivityDetector(
  75 +#if __ANDROID_API__ >= 9
  76 + mgr,
  77 +#endif
  78 + config);
  79 +
  80 + return (jlong)model;
  81 +}
  82 +
  83 +SHERPA_ONNX_EXTERN_C
  84 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromFile(
  85 + JNIEnv *env, jobject /*obj*/, jobject _config) {
  86 + auto config = sherpa_onnx::GetVadModelConfig(env, _config);
  87 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  88 +
  89 + if (!config.Validate()) {
  90 + SHERPA_ONNX_LOGE("Errors found in config!");
  91 + return 0;
  92 + }
  93 +
  94 + auto model = new sherpa_onnx::VoiceActivityDetector(config);
  95 +
  96 + return (jlong)model;
  97 +}
  98 +
  99 +SHERPA_ONNX_EXTERN_C
  100 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_delete(JNIEnv *env,
  101 + jobject /*obj*/,
  102 + jlong ptr) {
  103 + delete reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
  104 +}
  105 +
  106 +SHERPA_ONNX_EXTERN_C
  107 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_acceptWaveform(
  108 + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
  109 + auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
  110 +
  111 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  112 + jsize n = env->GetArrayLength(samples);
  113 +
  114 + model->AcceptWaveform(p, n);
  115 +
  116 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  117 +}
  118 +
  119 +SHERPA_ONNX_EXTERN_C
  120 +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_Vad_empty(JNIEnv *env,
  121 + jobject /*obj*/,
  122 + jlong ptr) {
  123 + auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
  124 + return model->Empty();
  125 +}
  126 +
  127 +SHERPA_ONNX_EXTERN_C
  128 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env,
  129 + jobject /*obj*/,
  130 + jlong ptr) {
  131 + auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
  132 + model->Pop();
  133 +}
  134 +
  135 +SHERPA_ONNX_EXTERN_C
  136 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env,
  137 + jobject /*obj*/,
  138 + jlong ptr) {
  139 + auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
  140 + model->Clear();
  141 +}
  142 +
  143 +SHERPA_ONNX_EXTERN_C
  144 +JNIEXPORT jobjectArray JNICALL
  145 +Java_com_k2fsa_sherpa_onnx_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) {
  146 + const auto &front =
  147 + reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr)->Front();
  148 +
  149 + jfloatArray samples_arr = env->NewFloatArray(front.samples.size());
  150 + env->SetFloatArrayRegion(samples_arr, 0, front.samples.size(),
  151 + front.samples.data());
  152 +
  153 + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
  154 + 2, env->FindClass("java/lang/Object"), nullptr);
  155 +
  156 + env->SetObjectArrayElement(obj_arr, 0, NewInteger(env, front.start));
  157 + env->SetObjectArrayElement(obj_arr, 1, samples_arr);
  158 +
  159 + return obj_arr;
  160 +}
  161 +
  162 +SHERPA_ONNX_EXTERN_C
  163 +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_Vad_isSpeechDetected(
  164 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  165 + auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
  166 + return model->IsSpeechDetected();
  167 +}
  168 +
  169 +SHERPA_ONNX_EXTERN_C
  170 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_reset(JNIEnv *env,
  171 + jobject /*obj*/,
  172 + jlong ptr) {
  173 + auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
  174 + model->Reset();
  175 +}
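
Vad_front() above packs an Object[2]: [0] the segment's start sample index (Integer) and [1] the samples (float[]). A minimal drain loop, assuming a Kotlin Vad wrapper over these entry points; `chunk` is a hypothetical block of 16 kHz samples sized to the configured windowSize:

fun drainSegments(vad: Vad, chunk: FloatArray) {
    vad.acceptWaveform(chunk)
    while (!vad.empty()) {
        val front = vad.front()      // Array<Any> built in Vad_front above
        val start = front[0] as Int  // first sample index of the segment
        val samples = front[1] as FloatArray
        println("speech segment at sample $start, ${samples.size} samples")
        vad.pop()
    }
}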
@@ -2,8 +2,6 @@ package com.k2fsa.sherpa.onnx
2 2
3 3 import android.content.res.AssetManager
4 4
5 -const val TAG = "sherpa-onnx"
6 -
7 5 data class OfflineZipformerAudioTaggingModelConfig(
8 6 var model: String = "",
9 7 )
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +data class FeatureConfig(
  4 + var sampleRate: Int = 16000,
  5 + var featureDim: Int = 80,
  6 +)
  7 +
  8 +fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
  9 + return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
  10 +}
@@ -3,26 +3,6 @@ package com.k2fsa.sherpa.onnx
3 3
4 4 import android.content.res.AssetManager
5 5
6 -data class OnlineTransducerModelConfig(
7 - var encoder: String = "",
8 - var decoder: String = "",
9 - var joiner: String = "",
10 -)
11 -
12 -data class OnlineModelConfig(
13 - var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
14 - var tokens: String,
15 - var numThreads: Int = 1,
16 - var debug: Boolean = false,
17 - var provider: String = "cpu",
18 - var modelType: String = "",
19 -)
20 -
21 -data class FeatureConfig(
22 - var sampleRate: Int = 16000,
23 - var featureDim: Int = 80,
24 -)
25 -
26 6 data class KeywordSpotterConfig(
27 7 var featConfig: FeatureConfig = FeatureConfig(),
28 8 var modelConfig: OnlineModelConfig,
@@ -33,17 +13,24 @@ data class KeywordSpotterConfig(
33 13 var numTrailingBlanks: Int = 2,
34 14 )
35 15
36 -class SherpaOnnxKws(
  16 +data class KeywordSpotterResult(
  17 + val keyword: String,
  18 + val tokens: Array<String>,
  19 + val timestamps: FloatArray,
  20 + // TODO(fangjun): Add more fields
  21 +)
  22 +
  23 +class KeywordSpotter(
37 24 assetManager: AssetManager? = null,
38 - var config: KeywordSpotterConfig,
  25 + val config: KeywordSpotterConfig,
39 26 ) {
40 27 private val ptr: Long
41 28
42 29 init {
43 - if (assetManager != null) {
44 - ptr = new(assetManager, config)
  30 + ptr = if (assetManager != null) {
  31 + newFromAsset(assetManager, config)
45 32 } else {
46 - ptr = newFromFile(config)
  33 + newFromFile(config)
47 34 }
48 35 }
49 36
@@ -51,20 +38,28 @@ class SherpaOnnxKws(
51 38 delete(ptr)
52 39 }
53 40
54 - fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
55 - acceptWaveform(ptr, samples, sampleRate)
  41 + fun release() = finalize()
  42 +
  43 + fun createStream(keywords: String = ""): OnlineStream {
  44 + val p = createStream(ptr, keywords)
  45 + return OnlineStream(p)
  46 + }
56 47
57 - fun inputFinished() = inputFinished(ptr)
58 - fun decode() = decode(ptr)
59 - fun isReady(): Boolean = isReady(ptr)
60 - fun reset(keywords: String): Boolean = reset(ptr, keywords)
  48 + fun decode(stream: OnlineStream) = decode(ptr, stream.ptr)
  49 + fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)
  50 + fun getResult(stream: OnlineStream): KeywordSpotterResult {
  51 + val objArray = getResult(ptr, stream.ptr)
61 52
62 - val keyword: String
63 - get() = getKeyword(ptr)
  53 + val keyword = objArray[0] as String
  54 + val tokens = objArray[1] as Array<String>
  55 + val timestamps = objArray[2] as FloatArray
  56 +
  57 + return KeywordSpotterResult(keyword = keyword, tokens = tokens, timestamps = timestamps)
  58 + }
64 59
65 60 private external fun delete(ptr: Long)
66 61
67 - private external fun new(
  62 + private external fun newFromAsset(
68 63 assetManager: AssetManager,
69 64 config: KeywordSpotterConfig,
70 65 ): Long
@@ -73,12 +68,10 @@ class SherpaOnnxKws(
73 68 config: KeywordSpotterConfig,
74 69 ): Long
75 70
76 - private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
77 - private external fun inputFinished(ptr: Long)
78 - private external fun getKeyword(ptr: Long): String
79 - private external fun reset(ptr: Long, keywords: String): Boolean
80 - private external fun decode(ptr: Long)
81 - private external fun isReady(ptr: Long): Boolean
  71 + private external fun createStream(ptr: Long, keywords: String): Long
  72 + private external fun isReady(ptr: Long, streamPtr: Long): Boolean
  73 + private external fun decode(ptr: Long, streamPtr: Long)
  74 + private external fun getResult(ptr: Long, streamPtr: Long): Array<Any>
82 75
83 76 companion object {
84 77 init {
@@ -87,10 +80,6 @@ class SherpaOnnxKws(
87 80 }
88 81 }
89 82
90 -fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
91 - return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
92 -}
93 -
94 83 /*
95 84 Please see
96 85 https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
@@ -108,7 +97,7 @@ by following the code)
108 97 https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/summary
109 98
110 99 */
111 -fun getModelConfig(type: Int): OnlineModelConfig? {
  100 +fun getKwsModelConfig(type: Int): OnlineModelConfig? {
112 101 when (type) {
113 102 0 -> {
114 103 val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
@@ -137,7 +126,7 @@ fun getModelConfig(type: Int): OnlineModelConfig? {
137 126 }
138 127
139 128 }
140 - return null;
  129 + return null
141 130 }
142 131
143 132 /*
@@ -145,7 +134,7 @@ fun getModelConfig(type: Int): OnlineModelConfig? {
145 134 * Caution: The types and modelDir should be the same as those in getModelConfig
146 135 * function above.
147 136 */
148 -fun getKeywordsFile(type: Int) : String {
  137 +fun getKeywordsFile(type: Int): String {
149 138 when (type) {
150 139 0 -> {
151 140 val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
@@ -158,5 +147,5 @@ fun getKeywordsFile(type: Int) : String {
158 147 }
159 148
160 149 }
161 - return "";
  150 + return ""
162 151 }
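
With the class renamed to KeywordSpotter and results carried by KeywordSpotterResult, a minimal spotting loop looks like the following sketch; `chunk` is a hypothetical FloatArray of 16 kHz samples, and the empty-keyword check assumes the result keyword stays empty until something fires:

fun spot(spotter: KeywordSpotter, chunk: FloatArray) {
    val stream = spotter.createStream() // pass a keywords string to override the default file
    stream.acceptWaveform(chunk, sampleRate = 16000)
    while (spotter.isReady(stream)) {
        spotter.decode(stream)
        val keyword = spotter.getResult(stream).keyword
        if (keyword.isNotEmpty()) {
            println("detected: $keyword")
        }
    }
    stream.release()
}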
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +import android.content.res.AssetManager
  4 +
  5 +data class OfflineRecognizerResult(
  6 + val text: String,
  7 + val tokens: Array<String>,
  8 + val timestamps: FloatArray,
  9 +)
  10 +
  11 +data class OfflineTransducerModelConfig(
  12 + var encoder: String = "",
  13 + var decoder: String = "",
  14 + var joiner: String = "",
  15 +)
  16 +
  17 +data class OfflineParaformerModelConfig(
  18 + var model: String = "",
  19 +)
  20 +
  21 +data class OfflineWhisperModelConfig(
  22 + var encoder: String = "",
  23 + var decoder: String = "",
  24 + var language: String = "en", // Used with multilingual model
  25 + var task: String = "transcribe", // transcribe or translate
  26 + var tailPaddings: Int = 1000, // Padding added at the end of the samples
  27 +)
  28 +
  29 +data class OfflineModelConfig(
  30 + var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(),
  31 + var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(),
  32 + var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(),
  33 + var numThreads: Int = 1,
  34 + var debug: Boolean = false,
  35 + var provider: String = "cpu",
  36 + var modelType: String = "",
  37 + var tokens: String,
  38 +)
  39 +
  40 +data class OfflineRecognizerConfig(
  41 + var featConfig: FeatureConfig = FeatureConfig(),
  42 + var modelConfig: OfflineModelConfig,
  43 + // var lmConfig: OfflineLMConfig(), // TODO(fangjun): enable it
  44 + var decodingMethod: String = "greedy_search",
  45 + var maxActivePaths: Int = 4,
  46 + var hotwordsFile: String = "",
  47 + var hotwordsScore: Float = 1.5f,
  48 +)
  49 +
  50 +class OfflineRecognizer(
  51 + assetManager: AssetManager? = null,
  52 + config: OfflineRecognizerConfig,
  53 +) {
  54 + private val ptr: Long
  55 +
  56 + init {
  57 + ptr = if (assetManager != null) {
  58 + newFromAsset(assetManager, config)
  59 + } else {
  60 + newFromFile(config)
  61 + }
  62 + }
  63 +
  64 + protected fun finalize() {
  65 + delete(ptr)
  66 + }
  67 +
  68 + fun release() = finalize()
  69 +
  70 + fun createStream(): OfflineStream {
  71 + val p = createStream(ptr)
  72 + return OfflineStream(p)
  73 + }
  74 +
  75 + fun getResult(stream: OfflineStream): OfflineRecognizerResult {
  76 + val objArray = getResult(stream.ptr)
  77 +
  78 + val text = objArray[0] as String
  79 + val tokens = objArray[1] as Array<String>
  80 + val timestamps = objArray[2] as FloatArray
  81 + return OfflineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps)
  82 + }
  83 +
  84 + fun decode(stream: OfflineStream) = decode(ptr, stream.ptr)
  85 +
  86 + private external fun delete(ptr: Long)
  87 +
  88 + private external fun createStream(ptr: Long): Long
  89 +
  90 + private external fun newFromAsset(
  91 + assetManager: AssetManager,
  92 + config: OfflineRecognizerConfig,
  93 + ): Long
  94 +
  95 + private external fun newFromFile(
  96 + config: OfflineRecognizerConfig,
  97 + ): Long
  98 +
  99 + private external fun decode(ptr: Long, streamPtr: Long)
  100 +
  101 + private external fun getResult(streamPtr: Long): Array<Any>
  102 +
  103 + companion object {
  104 + init {
  105 + System.loadLibrary("sherpa-onnx-jni")
  106 + }
  107 + }
  108 +}
  109 +
  110 +/*
  111 +Please see
  112 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  113 +for a list of pre-trained models.
  114 +
  115 +We only add a few here. Please change the following code
  116 +to add your own. (It should be straightforward to add a new model
  117 +by following the code)
  118 +
  119 +@param type
  120 +
  121 +0 - csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 (Chinese)
  122 + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese
  123 + int8
  124 +
  125 +1 - icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04 (English)
  126 + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#icefall-asr-multidataset-pruned-transducer-stateless7-2023-05-04-english
  127 + encoder int8, decoder/joiner float32
  128 +
  129 +2 - sherpa-onnx-whisper-tiny.en
  130 + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en
  131 + encoder int8, decoder int8
  132 +
  133 +3 - sherpa-onnx-whisper-base.en
  134 + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en
  135 + encoder int8, decoder int8
  136 +
  137 +4 - pkufool/icefall-asr-zipformer-wenetspeech-20230615 (Chinese)
  138 + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#pkufool-icefall-asr-zipformer-wenetspeech-20230615-chinese
  139 + encoder/joiner int8, decoder fp32
  140 +
  141 + */
  142 +fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
  143 + when (type) {
  144 + 0 -> {
  145 + val modelDir = "sherpa-onnx-paraformer-zh-2023-03-28"
  146 + return OfflineModelConfig(
  147 + paraformer = OfflineParaformerModelConfig(
  148 + model = "$modelDir/model.int8.onnx",
  149 + ),
  150 + tokens = "$modelDir/tokens.txt",
  151 + modelType = "paraformer",
  152 + )
  153 + }
  154 +
  155 + 1 -> {
  156 + val modelDir = "icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04"
  157 + return OfflineModelConfig(
  158 + transducer = OfflineTransducerModelConfig(
  159 + encoder = "$modelDir/encoder-epoch-30-avg-4.int8.onnx",
  160 + decoder = "$modelDir/decoder-epoch-30-avg-4.onnx",
  161 + joiner = "$modelDir/joiner-epoch-30-avg-4.onnx",
  162 + ),
  163 + tokens = "$modelDir/tokens.txt",
  164 + modelType = "zipformer",
  165 + )
  166 + }
  167 +
  168 + 2 -> {
  169 + val modelDir = "sherpa-onnx-whisper-tiny.en"
  170 + return OfflineModelConfig(
  171 + whisper = OfflineWhisperModelConfig(
  172 + encoder = "$modelDir/tiny.en-encoder.int8.onnx",
  173 + decoder = "$modelDir/tiny.en-decoder.int8.onnx",
  174 + ),
  175 + tokens = "$modelDir/tiny.en-tokens.txt",
  176 + modelType = "whisper",
  177 + )
  178 + }
  179 +
  180 + 3 -> {
  181 + val modelDir = "sherpa-onnx-whisper-base.en"
  182 + return OfflineModelConfig(
  183 + whisper = OfflineWhisperModelConfig(
  184 + encoder = "$modelDir/base.en-encoder.int8.onnx",
  185 + decoder = "$modelDir/base.en-decoder.int8.onnx",
  186 + ),
  187 + tokens = "$modelDir/base.en-tokens.txt",
  188 + modelType = "whisper",
  189 + )
  190 + }
  191 +
  192 +
  193 + 4 -> {
  194 + val modelDir = "icefall-asr-zipformer-wenetspeech-20230615"
  195 + return OfflineModelConfig(
  196 + transducer = OfflineTransducerModelConfig(
  197 + encoder = "$modelDir/encoder-epoch-12-avg-4.int8.onnx",
  198 + decoder = "$modelDir/decoder-epoch-12-avg-4.onnx",
  199 + joiner = "$modelDir/joiner-epoch-12-avg-4.int8.onnx",
  200 + ),
  201 + tokens = "$modelDir/tokens.txt",
  202 + modelType = "zipformer",
  203 + )
  204 + }
  205 +
  206 + 5 -> {
  207 + val modelDir = "sherpa-onnx-zipformer-multi-zh-hans-2023-9-2"
  208 + return OfflineModelConfig(
  209 + transducer = OfflineTransducerModelConfig(
  210 + encoder = "$modelDir/encoder-epoch-20-avg-1.int8.onnx",
  211 + decoder = "$modelDir/decoder-epoch-20-avg-1.onnx",
  212 + joiner = "$modelDir/joiner-epoch-20-avg-1.int8.onnx",
  213 + ),
  214 + tokens = "$modelDir/tokens.txt",
  215 + modelType = "zipformer2",
  216 + )
  217 + }
  218 +
  219 + }
  220 + return null
  221 +}
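
End to end, the new OfflineRecognizer is driven as in the sketch below. It assumes OfflineStream exposes acceptWaveform()/release() analogous to OnlineStream (its Kotlin wrapper is not shown in this hunk); `assets` and `samples` stand in for an Android AssetManager (imported at the top of this file) and a 16 kHz mono recording:

fun transcribe(assets: AssetManager, samples: FloatArray): String {
    val config = OfflineRecognizerConfig(
        featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
        modelConfig = getOfflineModelConfig(type = 0)!!, // paraformer-zh, see the list above
    )
    val recognizer = OfflineRecognizer(assetManager = assets, config = config)
    val stream = recognizer.createStream()
    stream.acceptWaveform(samples, sampleRate = 16000)
    recognizer.decode(stream)
    val text = recognizer.getResult(stream).text
    stream.release()
    recognizer.release()
    return text
}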
1 -// Copyright (c) 2023 Xiaomi Corporation  
2 package com.k2fsa.sherpa.onnx 1 package com.k2fsa.sherpa.onnx
3 2
4 import android.content.res.AssetManager 3 import android.content.res.AssetManager
@@ -46,15 +45,11 @@ data class OnlineLMConfig( @@ -46,15 +45,11 @@ data class OnlineLMConfig(
46 var scale: Float = 0.5f, 45 var scale: Float = 0.5f,
47 ) 46 )
48 47
49 -data class FeatureConfig(  
50 - var sampleRate: Int = 16000,  
51 - var featureDim: Int = 80,  
52 -)  
53 48
54 data class OnlineRecognizerConfig( 49 data class OnlineRecognizerConfig(
55 var featConfig: FeatureConfig = FeatureConfig(), 50 var featConfig: FeatureConfig = FeatureConfig(),
56 var modelConfig: OnlineModelConfig, 51 var modelConfig: OnlineModelConfig,
57 - var lmConfig: OnlineLMConfig, 52 + var lmConfig: OnlineLMConfig = OnlineLMConfig(),
58 var endpointConfig: EndpointConfig = EndpointConfig(), 53 var endpointConfig: EndpointConfig = EndpointConfig(),
59 var enableEndpoint: Boolean = true, 54 var enableEndpoint: Boolean = true,
60 var decodingMethod: String = "greedy_search", 55 var decodingMethod: String = "greedy_search",
@@ -63,17 +58,24 @@ data class OnlineRecognizerConfig( @@ -63,17 +58,24 @@ data class OnlineRecognizerConfig(
63 var hotwordsScore: Float = 1.5f, 58 var hotwordsScore: Float = 1.5f,
64 ) 59 )
65 60
66 -class SherpaOnnx( 61 +data class OnlineRecognizerResult(
  62 + val text: String,
  63 + val tokens: Array<String>,
  64 + val timestamps: FloatArray,
  65 + // TODO(fangjun): Add more fields
  66 +)
  67 +
  68 +class OnlineRecognizer(
67 assetManager: AssetManager? = null, 69 assetManager: AssetManager? = null,
68 - var config: OnlineRecognizerConfig, 70 + val config: OnlineRecognizerConfig,
69 ) { 71 ) {
70 private val ptr: Long 72 private val ptr: Long
71 73
72 init { 74 init {
73 - if (assetManager != null) {  
74 - ptr = new(assetManager, config) 75 + ptr = if (assetManager != null) {
  76 + newFromAsset(assetManager, config)
75 } else { 77 } else {
76 - ptr = newFromFile(config) 78 + newFromFile(config)
77 } 79 }
78 } 80 }
79 81
@@ -81,24 +83,30 @@ class SherpaOnnx( @@ -81,24 +83,30 @@ class SherpaOnnx(
81 delete(ptr) 83 delete(ptr)
82 } 84 }
83 85
84 - fun acceptWaveform(samples: FloatArray, sampleRate: Int) =  
85 - acceptWaveform(ptr, samples, sampleRate) 86 + fun release() = finalize()
  87 +
  88 + fun createStream(hotwords: String = ""): OnlineStream {
  89 + val p = createStream(ptr, hotwords)
  90 + return OnlineStream(p)
  91 + }
86 92
87 - fun inputFinished() = inputFinished(ptr)  
88 - fun reset(recreate: Boolean = false, hotwords: String = "") = reset(ptr, recreate, hotwords)  
89 - fun decode() = decode(ptr)  
90 - fun isEndpoint(): Boolean = isEndpoint(ptr)  
91 - fun isReady(): Boolean = isReady(ptr) 93 + fun reset(stream: OnlineStream) = reset(ptr, stream.ptr)
  94 + fun decode(stream: OnlineStream) = decode(ptr, stream.ptr)
  95 + fun isEndpoint(stream: OnlineStream) = isEndpoint(ptr, stream.ptr)
  96 + fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)
  97 + fun getResult(stream: OnlineStream): OnlineRecognizerResult {
  98 + val objArray = getResult(ptr, stream.ptr)
92 99
93 - val text: String  
94 - get() = getText(ptr) 100 + val text = objArray[0] as String
  101 + val tokens = objArray[1] as Array<String>
  102 + val timestamps = objArray[2] as FloatArray
95 103
96 - val tokens: Array<String>  
97 - get() = getTokens(ptr) 104 + return OnlineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps)
  105 + }
98 106
99 private external fun delete(ptr: Long) 107 private external fun delete(ptr: Long)
100 108
101 - private external fun new( 109 + private external fun newFromAsset(
102 assetManager: AssetManager, 110 assetManager: AssetManager,
103 config: OnlineRecognizerConfig, 111 config: OnlineRecognizerConfig,
104 ): Long 112 ): Long
@@ -107,14 +115,12 @@ class SherpaOnnx( @@ -107,14 +115,12 @@ class SherpaOnnx(
107 config: OnlineRecognizerConfig, 115 config: OnlineRecognizerConfig,
108 ): Long 116 ): Long
109 117
110 - private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)  
111 - private external fun inputFinished(ptr: Long)  
112 - private external fun getText(ptr: Long): String  
113 - private external fun reset(ptr: Long, recreate: Boolean, hotwords: String)  
114 - private external fun decode(ptr: Long)  
115 - private external fun isEndpoint(ptr: Long): Boolean  
116 - private external fun isReady(ptr: Long): Boolean  
117 - private external fun getTokens(ptr: Long): Array<String> 118 + private external fun createStream(ptr: Long, hotwords: String): Long
  119 + private external fun reset(ptr: Long, streamPtr: Long)
  120 + private external fun decode(ptr: Long, streamPtr: Long)
  121 + private external fun isEndpoint(ptr: Long, streamPtr: Long): Boolean
  122 + private external fun isReady(ptr: Long, streamPtr: Long): Boolean
  123 + private external fun getResult(ptr: Long, streamPtr: Long): Array<Any>
118 124
119 companion object { 125 companion object {
120 init { 126 init {
@@ -123,9 +129,6 @@ class SherpaOnnx( @@ -123,9 +129,6 @@ class SherpaOnnx(
123 } 129 }
124 } 130 }
125 131
126 -fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {  
127 - return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)  
128 -}  
129 132
130 /* 133 /*
131 Please see 134 Please see
@@ -277,14 +280,40 @@ fun getModelConfig(type: Int): OnlineModelConfig? { @@ -277,14 +280,40 @@ fun getModelConfig(type: Int): OnlineModelConfig? {
277 transducer = OnlineTransducerModelConfig( 280 transducer = OnlineTransducerModelConfig(
278 encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx", 281 encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
279 decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", 282 decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
280 - joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", 283 + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
  284 + ),
  285 + tokens = "$modelDir/tokens.txt",
  286 + modelType = "zipformer",
  287 + )
  288 + }
  289 +
  290 + 9 -> {
  291 + val modelDir = "sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23"
  292 + return OnlineModelConfig(
  293 + transducer = OnlineTransducerModelConfig(
  294 + encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
  295 + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
  296 + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
  297 + ),
  298 + tokens = "$modelDir/tokens.txt",
  299 + modelType = "zipformer",
  300 + )
  301 + }
  302 +
  303 + 10 -> {
  304 + val modelDir = "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17"
  305 + return OnlineModelConfig(
  306 + transducer = OnlineTransducerModelConfig(
  307 + encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
  308 + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
  309 + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
281 310 ),
282 311 tokens = "$modelDir/tokens.txt",
283 312 modelType = "zipformer",
284 313 )
285 314 }
286 315 }
287 - return null;
 316 + return null
288 317 }
289 318
290 319 /*
@@ -310,7 +339,7 @@ fun getOnlineLMConfig(type: Int): OnlineLMConfig {
310 339 )
311 340 }
312 341 }
313 - return OnlineLMConfig();
 342 + return OnlineLMConfig()
314 343 }
315 344
316 345 fun getEndpointConfig(): EndpointConfig {
@@ -320,3 +349,4 @@ fun getEndpointConfig(): EndpointConfig {
320 349 rule3 = EndpointRule(false, 0.0f, 20.0f)
321 350 )
322 351 }
  352 +
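The hunks above replace the recognizer's built-in stream state (acceptWaveform/getText on SherpaOnnx itself) with explicit OnlineStream objects, so one recognizer can serve several streams. A minimal usage sketch of the new shape follows; it assumes the SherpaOnnx constructor keeps its (assetManager, config) parameters and that the Kotlin createStream() wrapper defaults its hotwords argument (both declared outside these hunks), and config and samples are placeholders, not part of this commit:

    // Decode one utterance with the refactored, stream-based API (sketch).
    val recognizer = SherpaOnnx(assetManager = assetManager, config = config)
    val stream = recognizer.createStream()            // hotwords may be passed here

    stream.acceptWaveform(samples, sampleRate = 16000)
    while (recognizer.isReady(stream)) {
        recognizer.decode(stream)
    }
    if (recognizer.isEndpoint(stream)) {
        recognizer.reset(stream)                      // resets this stream only
    }
    val result = recognizer.getResult(stream)         // OnlineRecognizerResult
    println("text=${result.text}, tokens=${result.tokens.size}")
    stream.release()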
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +class OnlineStream(var ptr: Long = 0) {
  4 + fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
  5 + acceptWaveform(ptr, samples, sampleRate)
  6 +
  7 + fun inputFinished() = inputFinished(ptr)
  8 +
  9 + protected fun finalize() {
  10 + if (ptr != 0L) {
  11 + delete(ptr)
  12 + ptr = 0
  13 + }
  14 + }
  15 +
  16 + fun release() = finalize()
  17 +
  18 + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
  19 + private external fun inputFinished(ptr: Long)
  20 + private external fun delete(ptr: Long)
  21 +
  22 + companion object {
  23 + init {
  24 + System.loadLibrary("sherpa-onnx-jni")
  25 + }
  26 + }
  27 +}
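One lifecycle detail in the new OnlineStream class: finalize() runs only when (and if) the garbage collector gets to it, so the native handle's lifetime is non-deterministic. The release() alias exists for prompt cleanup, and the ptr != 0L guard makes a later finalize() a harmless no-op. A sketch, with recognizer and samples as placeholders:

    val stream = recognizer.createStream()
    try {
        stream.acceptWaveform(samples, sampleRate = 16000)
        stream.inputFinished()                 // no more audio for this stream
        while (recognizer.isReady(stream)) {
            recognizer.decode(stream)
        }
    } finally {
        stream.release()                       // frees the native stream immediately
    }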
@@ -3,7 +3,6 @@ package com.k2fsa.sherpa.onnx
3 3 import android.content.res.AssetManager
4 4 import android.util.Log
5 5
6 -private val TAG = "sherpa-onnx"  
7 6 data class SpeakerEmbeddingExtractorConfig(
8 7 val model: String,
9 8 var numThreads: Int = 1,
@@ -11,33 +10,6 @@ data class SpeakerEmbeddingExtractorConfig(
11 10 var provider: String = "cpu",
12 11 )
13 12
14 -class SpeakerEmbeddingExtractorStream(var ptr: Long) {  
15 - fun acceptWaveform(samples: FloatArray, sampleRate: Int) =  
16 - acceptWaveform(ptr, samples, sampleRate)  
17 -  
18 - fun inputFinished() = inputFinished(ptr)  
19 -  
20 - protected fun finalize() {  
21 - delete(ptr)  
22 - ptr = 0  
23 - }  
24 -  
25 - private external fun myTest(ptr: Long, v: Array<FloatArray>)  
26 -  
27 - fun release() = finalize()  
28 - private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)  
29 -  
30 - private external fun inputFinished(ptr: Long)  
31 -  
32 - private external fun delete(ptr: Long)  
33 -  
34 - companion object {  
35 - init {  
36 - System.loadLibrary("sherpa-onnx-jni")  
37 - }  
38 - }  
39 -}  
40 -  
41 13 class SpeakerEmbeddingExtractor(
42 14 assetManager: AssetManager? = null,
43 15 config: SpeakerEmbeddingExtractorConfig,
@@ -46,29 +18,31 @@ class SpeakerEmbeddingExtractor(
46 18
47 19 init {
48 20 ptr = if (assetManager != null) {
49 - new(assetManager, config)
  21 + newFromAsset(assetManager, config)
50 22 } else {
51 23 newFromFile(config)
52 24 }
53 25 }
54 26
55 27 protected fun finalize() {
  28 + if (ptr != 0L) {
56 29 delete(ptr)
57 30 ptr = 0
58 31 }
  32 + }
59 33
60 34 fun release() = finalize()
61 35
62 - fun createStream(): SpeakerEmbeddingExtractorStream {
  36 + fun createStream(): OnlineStream {
63 37 val p = createStream(ptr)
64 - return SpeakerEmbeddingExtractorStream(p)
  38 + return OnlineStream(p)
65 39 }
66 40
67 - fun isReady(stream: SpeakerEmbeddingExtractorStream) = isReady(ptr, stream.ptr)  
68 - fun compute(stream: SpeakerEmbeddingExtractorStream) = compute(ptr, stream.ptr)
  41 + fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)
  42 + fun compute(stream: OnlineStream) = compute(ptr, stream.ptr)
69 43 fun dim() = dim(ptr)
70 44
71 - private external fun new(
  45 + private external fun newFromAsset(
72 46 assetManager: AssetManager,
73 47 config: SpeakerEmbeddingExtractorConfig,
74 48 ): Long
@@ -98,13 +72,15 @@ class SpeakerEmbeddingManager(val dim: Int) {
98 72 private var ptr: Long
99 73
100 74 init {
101 - ptr = new(dim)
  75 + ptr = create(dim)
102 76 }
103 77
104 78 protected fun finalize() {
  79 + if (ptr != 0L) {
105 80 delete(ptr)
106 81 ptr = 0
107 82 }
  83 + }
108 84
109 85 fun release() = finalize()
110 86 fun add(name: String, embedding: FloatArray) = add(ptr, name, embedding)
@@ -119,7 +95,7 @@ class SpeakerEmbeddingManager(val dim: Int) {
119 95
120 96 fun allSpeakerNames() = allSpeakerNames(ptr)
121 97
122 - private external fun new(dim: Int): Long
  98 + private external fun create(dim: Int): Long
123 99 private external fun delete(ptr: Long): Unit
124 100 private external fun add(ptr: Long, name: String, embedding: FloatArray): Boolean
125 101 private external fun addList(ptr: Long, name: String, embedding: Array<FloatArray>): Boolean
@@ -170,7 +146,7 @@ object SpeakerRecognition {
170 146 if (_extractor != null) {
171 147 return
172 148 }
173 - Log.i(TAG, "Initializing speaker embedding extractor")
 149 + Log.i("sherpa-onnx", "Initializing speaker embedding extractor")
174 150
175 151 _extractor = SpeakerEmbeddingExtractor(
176 152 assetManager = assetManager,
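With SpeakerEmbeddingExtractorStream removed, the extractor now hands out the same OnlineStream type as the recognizer. A sketch of the resulting flow; it assumes compute() returns the FloatArray embedding that SpeakerEmbeddingManager.add() expects (its return type is declared outside this hunk), and the model path and samples are placeholders:

    val extractor = SpeakerEmbeddingExtractor(
        assetManager = assetManager,
        config = SpeakerEmbeddingExtractorConfig(model = "embedding-model.onnx"),
    )
    val stream = extractor.createStream()      // an OnlineStream, as above
    stream.acceptWaveform(samples, sampleRate = 16000)
    stream.inputFinished()

    if (extractor.isReady(stream)) {
        val embedding = extractor.compute(stream)
        val manager = SpeakerEmbeddingManager(dim = extractor.dim())
        manager.add(name = "speaker-1", embedding = embedding)
    }
    stream.release()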
1 1 package com.k2fsa.sherpa.onnx
2 2
3 3 import android.content.res.AssetManager
4 -import android.util.Log  
5 4
6 -private val TAG = "sherpa-onnx"  
7 -  
8 -data class SpokenLanguageIdentificationWhisperConfig (
  5 +data class SpokenLanguageIdentificationWhisperConfig(
9 6 var encoder: String,
10 7 var decoder: String,
11 8 var tailPaddings: Int = -1,
12 9 )
13 10
14 -data class SpokenLanguageIdentificationConfig (
 11 +data class SpokenLanguageIdentificationConfig(
15 12 var whisper: SpokenLanguageIdentificationWhisperConfig,
16 13 var numThreads: Int = 1,
17 14 var debug: Boolean = false,
18 15 var provider: String = "cpu",
19 16 )
20 17
21 -class SpokenLanguageIdentification (
 18 +class SpokenLanguageIdentification(
22 19 assetManager: AssetManager? = null,
23 20 config: SpokenLanguageIdentificationConfig,
24 21 ) {
@@ -69,10 +66,14 @@ class SpokenLanguageIdentification (
69 66 }
70 67 }
71 68 }
  69 +
72 70 // please refer to
73 71 // https://k2-fsa.github.io/sherpa/onnx/spolken-language-identification/pretrained_models.html#whisper
74 72 // to download more models
75 -fun getSpokenLanguageIdentificationConfig(type: Int, numThreads: Int=1): SpokenLanguageIdentificationConfig? {
 73 +fun getSpokenLanguageIdentificationConfig(
  74 + type: Int,
  75 + numThreads: Int = 1
  76 +): SpokenLanguageIdentificationConfig? {
76 77 when (type) {
77 78 0 -> {
78 79 val modelDir = "sherpa-onnx-whisper-tiny"
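Since getSpokenLanguageIdentificationConfig() returns null for unknown types, call sites should guard the result. A minimal sketch using the only model type visible in this hunk (type 0, sherpa-onnx-whisper-tiny):

    val slidConfig = getSpokenLanguageIdentificationConfig(type = 0, numThreads = 2)
        ?: error("unsupported spoken language identification model type")
    val slid = SpokenLanguageIdentification(assetManager = assetManager, config = slidConfig)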
@@ -27,7 +27,7 @@ class Vad(
27 27
28 28 init {
29 29 if (assetManager != null) {
30 - ptr = new(assetManager, config)
 30 + ptr = newFromAsset(assetManager, config)
31 31 } else {
32 32 ptr = newFromFile(config)
33 33 }
@@ -54,7 +54,7 @@ class Vad(
54 54
55 55 private external fun delete(ptr: Long)
56 56
57 - private external fun new(
 57 + private external fun newFromAsset(
58 58 assetManager: AssetManager,
59 59 config: VadModelConfig,
60 60 ): Long
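The new → newFromAsset rename is applied uniformly: SherpaOnnx, SpeakerEmbeddingExtractor, and Vad all dispatch to one of two JNI constructors depending on whether an AssetManager is supplied. A sketch of what that means at the call site, with vadConfig as a placeholder:

    // Models bundled in the APK: pass the AssetManager -> newFromAsset(...)
    val vadInApp = Vad(assetManager = assetManager, config = vadConfig)

    // Models on the file system: omit it -> newFromFile(...)
    val vadOnDisk = Vad(config = vadConfig)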