Wei Kang
Committed by GitHub

decoder for open vocabulary keyword spotting (#505)

* various fixes to ContextGraph to support open vocabulary keywords decoder

* Add keyword spotter runtime

* Add binary

* First version works

* Minor fixes

* update text2token

* default values

* Add jni for kws

* add kws android project

* Minor fixes

* Remove unused interface

* Minor fixes

* Add workflow

* handle extra info in texts

* Minor fixes

* Add more comments

* Fix ci

* fix cpp style

* Add input box in android demo so that users can specify their keywords

* Fix cpp style

* Fix comments

* Minor fixes

* Minor fixes

* minor fixes

* Minor fixes

* Minor fixes

* Add CI

* Fix code style

* cpplint

* Fix comments

* Fix error
Showing 77 changed files with 3,316 additions and 68 deletions
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
echo "EXE is $EXE"
echo "PATH: $PATH"
which $EXE
log "------------------------------------------------------------"
log "Run Chinese keyword spotting (Wenetspeech)"
log "------------------------------------------------------------"
repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.git
log "Start testing ${repo_url}"
repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd
time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
--keywords-file=$repo/test_wavs/test_keywords.txt \
--max-active-paths=4 \
--num-threads=4 \
$repo/test_wavs/3.wav $repo/test_wavs/4.wav $repo/test_wavs/5.wav $repo/test_wavs/6.wav
rm -rf $repo
log "------------------------------------------------------------"
log "Run English keyword spotting (Gigaspeech)"
log "------------------------------------------------------------"
repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.git
log "Start testing ${repo_url}"
repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd
time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
--keywords-file=$repo/test_wavs/test_keywords.txt \
--max-active-paths=4 \
--num-threads=4 \
$repo/test_wavs/0.wav $repo/test_wavs/1.wav
rm -rf $repo
... ...
name: apk-kws
on:
push:
branches:
- apk-kws
tags:
- '*'
workflow_dispatch:
concurrency:
group: apk-kws-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: write
jobs:
apk:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ${{ matrix.os }}-android
- name: Display NDK HOME
shell: bash
run: |
echo "ANDROID_NDK_LATEST_HOME: ${ANDROID_NDK_LATEST_HOME}"
ls -lh ${ANDROID_NDK_LATEST_HOME}
- name: build APK
shell: bash
run: |
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
cmake --version
export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
./build-kws-apk.sh
- name: Display APK
shell: bash
run: |
ls -lh ./apks/
- uses: actions/upload-artifact@v3
with:
path: ./apks/*.apk
- name: Release APK
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: apks/*.apk
overwrite: true
... ...
... ... @@ -107,6 +107,14 @@ jobs:
name: release-static
path: build/bin/*
- name: Test transducer kws
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-keyword-spotter
.github/scripts/test-kws.sh
- name: Test online CTC
shell: bash
run: |
... ...
... ... @@ -98,6 +98,14 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test transducer kws
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-keyword-spotter
.github/scripts/test-kws.sh
- name: Test online CTC
shell: bash
run: |
... ... @@ -106,7 +114,6 @@ jobs:
.github/scripts/test-online-ctc.sh
- name: Test offline TTS
shell: bash
run: |
... ...
... ... @@ -62,7 +62,7 @@ jobs:
- name: Install Python dependencies
shell: bash
run: |
python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96 soundfile
python3 -m pip install --upgrade pip numpy pypinyin sentencepiece==0.1.96 soundfile
- name: Install sherpa-onnx
shell: bash
... ...
... ... @@ -45,7 +45,7 @@ jobs:
- name: Install Python dependencies
shell: bash
run: |
python3 -m pip install --upgrade pip numpy sentencepiece
python3 -m pip install --upgrade pip numpy pypinyin sentencepiece
- name: Install sherpa-onnx
shell: bash
... ...
... ... @@ -45,7 +45,7 @@ jobs:
- name: Install Python dependencies
shell: bash
run: |
python3 -m pip install --upgrade pip numpy sentencepiece
python3 -m pip install --upgrade pip numpy pypinyin sentencepiece
- name: Install sherpa-onnx
shell: bash
... ...
*.iml
.gradle
/local.properties
/.idea/caches
/.idea/libraries
/.idea/modules.xml
/.idea/workspace.xml
/.idea/navEditor.xml
/.idea/assetWizardSettings.xml
.DS_Store
/build
/captures
.externalNativeBuild
.cxx
local.properties
... ...
/build
\ No newline at end of file
... ...
plugins {
id 'com.android.application'
id 'org.jetbrains.kotlin.android'
}
android {
namespace 'com.k2fsa.sherpa.onnx'
compileSdk 32
defaultConfig {
applicationId "com.k2fsa.sherpa.onnx"
minSdk 21
targetSdk 32
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
kotlinOptions {
jvmTarget = '1.8'
}
}
dependencies {
implementation 'androidx.core:core-ktx:1.7.0'
implementation 'androidx.appcompat:appcompat:1.5.1'
implementation 'com.google.android.material:material:1.7.0'
implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
testImplementation 'junit:junit:4.13.2'
androidTestImplementation 'androidx.test.ext:junit:1.1.4'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0'
}
\ No newline at end of file
... ...
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx
import androidx.test.platform.app.InstrumentationRegistry
import androidx.test.ext.junit.runners.AndroidJUnit4
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.Assert.*
/**
* Instrumented test, which will execute on an Android device.
*
* See [testing documentation](http://d.android.com/tools/testing).
*/
@RunWith(AndroidJUnit4::class)
class ExampleInstrumentedTest {
@Test
fun useAppContext() {
// Context of the app under test.
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
assertEquals("com.k2fsa.sherpa.onnx", appContext.packageName)
}
}
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools">
<uses-permission android:name="android.permission.RECORD_AUDIO" />
<application
android:allowBackup="true"
android:dataExtractionRules="@xml/data_extraction_rules"
android:fullBackupContent="@xml/backup_rules"
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/Theme.SherpaOnnx"
tools:targetApi="31">
<activity
android:name=".MainActivity"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
<meta-data
android:name="android.app.lib_name"
android:value="" />
</activity>
</application>
</manifest>
... ...
package com.k2fsa.sherpa.onnx
import android.Manifest
import android.content.pm.PackageManager
import android.media.AudioFormat
import android.media.AudioRecord
import android.media.MediaRecorder
import android.os.Bundle
import android.text.method.ScrollingMovementMethod
import android.util.Log
import android.widget.Button
import android.widget.EditText
import android.widget.TextView
import android.widget.Toast
import androidx.appcompat.app.AppCompatActivity
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.*
import kotlin.concurrent.thread
private const val TAG = "sherpa-onnx"
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
class MainActivity : AppCompatActivity() {
private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
private lateinit var model: SherpaOnnxKws
private var audioRecord: AudioRecord? = null
private lateinit var recordButton: Button
private lateinit var textView: TextView
private lateinit var inputText: EditText
private var recordingThread: Thread? = null
private val audioSource = MediaRecorder.AudioSource.MIC
private val sampleRateInHz = 16000
private val channelConfig = AudioFormat.CHANNEL_IN_MONO
// Note: We don't use AudioFormat.ENCODING_PCM_FLOAT
// since the AudioRecord.read(float[]) needs API level >= 23
// but we are targeting API level >= 21
private val audioFormat = AudioFormat.ENCODING_PCM_16BIT
private var idx: Int = 0
private var lastText: String = ""
@Volatile
private var isRecording: Boolean = false
override fun onRequestPermissionsResult(
requestCode: Int, permissions: Array<String>, grantResults: IntArray
) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults)
val permissionToRecordAccepted = if (requestCode == REQUEST_RECORD_AUDIO_PERMISSION) {
grantResults[0] == PackageManager.PERMISSION_GRANTED
} else {
false
}
if (!permissionToRecordAccepted) {
Log.e(TAG, "Audio record is disallowed")
finish()
}
Log.i(TAG, "Audio record is permitted")
}
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
Log.i(TAG, "Start to initialize model")
initModel()
Log.i(TAG, "Finished initializing model")
recordButton = findViewById(R.id.record_button)
recordButton.setOnClickListener { onclick() }
textView = findViewById(R.id.my_text)
textView.movementMethod = ScrollingMovementMethod()
inputText = findViewById(R.id.input_text)
}
private fun onclick() {
if (!isRecording) {
var keywords = inputText.text.toString()
Log.i(TAG, keywords)
keywords = keywords.replace("\n", "/")
// If keywords is an empty string, it just resets the decoding stream;
// it always returns true in this case.
// If keywords is not empty, it creates a new decoding stream with
// the given keywords appended to the default keywords.
// It returns false if an error occurred when adding keywords, true otherwise.
val status = model.reset(keywords)
if (!status) {
Log.i(TAG, "Failed to reset with keywords.")
Toast.makeText(this, "Failed to set keywords.", Toast.LENGTH_LONG).show();
return
}
val ret = initMicrophone()
if (!ret) {
Log.e(TAG, "Failed to initialize microphone")
return
}
Log.i(TAG, "state: ${audioRecord?.state}")
audioRecord!!.startRecording()
recordButton.setText(R.string.stop)
isRecording = true
textView.text = ""
lastText = ""
idx = 0
recordingThread = thread(true) {
processSamples()
}
Log.i(TAG, "Started recording")
} else {
isRecording = false
audioRecord!!.stop()
audioRecord!!.release()
audioRecord = null
recordButton.setText(R.string.start)
Log.i(TAG, "Stopped recording")
}
}
private fun processSamples() {
Log.i(TAG, "processing samples")
val interval = 0.1 // i.e., 100 ms
val bufferSize = (interval * sampleRateInHz).toInt() // in samples
val buffer = ShortArray(bufferSize)
while (isRecording) {
val ret = audioRecord?.read(buffer, 0, buffer.size)
if (ret != null && ret > 0) {
val samples = FloatArray(ret) { buffer[it] / 32768.0f }
model.acceptWaveform(samples, sampleRate=sampleRateInHz)
while (model.isReady()) {
model.decode()
}
val text = model.keyword
var textToDisplay = lastText
if (text.isNotBlank()) {
if (lastText.isBlank()) {
textToDisplay = "${idx}: ${text}"
} else {
textToDisplay = "${idx}: ${text}\n${lastText}"
}
lastText = "${idx}: ${text}\n${lastText}"
idx += 1
}
runOnUiThread {
textView.text = textToDisplay
}
}
}
}
private fun initMicrophone(): Boolean {
if (ActivityCompat.checkSelfPermission(
this, Manifest.permission.RECORD_AUDIO
) != PackageManager.PERMISSION_GRANTED
) {
ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
return false
}
val numBytes = AudioRecord.getMinBufferSize(sampleRateInHz, channelConfig, audioFormat)
Log.i(
TAG, "buffer size in milliseconds: ${numBytes * 1000.0f / sampleRateInHz}"
)
audioRecord = AudioRecord(
audioSource,
sampleRateInHz,
channelConfig,
audioFormat,
numBytes * 2 // a sample has two bytes as we are using 16-bit PCM
)
return true
}
private fun initModel() {
// Please change getModelConfig() to add new models
// See https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
// for a list of available models
val type = 0
Log.i(TAG, "Select model type ${type}")
val config = KeywordSpotterConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
modelConfig = getModelConfig(type = type)!!,
keywordsFile = getKeywordsFile(type = type)!!,
)
model = SherpaOnnxKws(
assetManager = application.assets,
config = config,
)
}
}
... ...
// Copyright (c) 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
data class OnlineTransducerModelConfig(
var encoder: String = "",
var decoder: String = "",
var joiner: String = "",
)
data class OnlineModelConfig(
var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
var tokens: String,
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
var modelType: String = "",
)
data class FeatureConfig(
var sampleRate: Int = 16000,
var featureDim: Int = 80,
)
data class KeywordSpotterConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineModelConfig,
var maxActivePaths: Int = 4,
var keywordsFile: String = "keywords.txt",
var keywordsScore: Float = 1.5f,
var keywordsThreshold: Float = 0.25f,
var numTrailingBlanks: Int = 2,
)
class SherpaOnnxKws(
assetManager: AssetManager? = null,
var config: KeywordSpotterConfig,
) {
private val ptr: Long
init {
if (assetManager != null) {
ptr = new(assetManager, config)
} else {
ptr = newFromFile(config)
}
}
protected fun finalize() {
delete(ptr)
}
fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
acceptWaveform(ptr, samples, sampleRate)
fun inputFinished() = inputFinished(ptr)
fun decode() = decode(ptr)
fun isReady(): Boolean = isReady(ptr)
fun reset(keywords: String): Boolean = reset(ptr, keywords)
val keyword: String
get() = getKeyword(ptr)
private external fun delete(ptr: Long)
private external fun new(
assetManager: AssetManager,
config: KeywordSpotterConfig,
): Long
private external fun newFromFile(
config: KeywordSpotterConfig,
): Long
private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
private external fun inputFinished(ptr: Long)
private external fun getKeyword(ptr: Long): String
private external fun reset(ptr: Long, keywords: String): Boolean
private external fun decode(ptr: Long)
private external fun isReady(ptr: Long): Boolean
companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
}
}
}
fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
}
/*
Please see
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
for a list of pre-trained models.
We only add a few here. Please change the following code
to add your own. (It should be straightforward to add a new model
by following the code)
@param type
0 - sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 (Chinese)
https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/summary
1 - sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 (English)
https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/summary
*/
fun getModelConfig(type: Int): OnlineModelConfig? {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer2",
)
}
1 -> {
val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer2",
)
}
}
return null;
}
/*
* Get the default keywords for each model.
* Caution: The types and modelDir should be the same as those in the getModelConfig
* function above.
*/
fun getKeywordsFile(type: Int) : String {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
return "$modelDir/keywords.txt"
}
1 -> {
val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"
return "$modelDir/keywords.txt"
}
}
return "";
}
... ...
// Copyright (c) 2023 Xiaomi Corporation
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
class WaveReader {
companion object {
// Read a mono wave file asset
// The returned array has two entries:
// - the first entry contains a 1-D float array
// - the second entry is the sample rate
external fun readWaveFromAsset(
assetManager: AssetManager,
filename: String,
): Array<Any>
// Read a mono wave file from disk
// The returned array has two entries:
// - the first entry contains a 1-D float array
// - the second entry is the sample rate
external fun readWaveFromFile(
filename: String,
): Array<Any>
init {
System.loadLibrary("sherpa-onnx-jni")
}
}
}
... ...
<vector xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:aapt="http://schemas.android.com/aapt"
android:width="108dp"
android:height="108dp"
android:viewportWidth="108"
android:viewportHeight="108">
<path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
<aapt:attr name="android:fillColor">
<gradient
android:endX="85.84757"
android:endY="92.4963"
android:startX="42.9492"
android:startY="49.59793"
android:type="linear">
<item
android:color="#44000000"
android:offset="0.0" />
<item
android:color="#00000000"
android:offset="1.0" />
</gradient>
</aapt:attr>
</path>
<path
android:fillColor="#FFFFFF"
android:fillType="nonZero"
android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
android:strokeWidth="1"
android:strokeColor="#00000000" />
</vector>
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?>
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="108dp"
android:height="108dp"
android:viewportWidth="108"
android:viewportHeight="108">
<path
android:fillColor="#3DDC84"
android:pathData="M0,0h108v108h-108z" />
<path
android:fillColor="#00000000"
android:pathData="M9,0L9,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,0L19,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M29,0L29,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M39,0L39,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M49,0L49,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M59,0L59,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M69,0L69,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M79,0L79,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M89,0L89,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M99,0L99,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,9L108,9"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,19L108,19"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,29L108,29"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,39L108,39"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,49L108,49"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,59L108,59"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,69L108,69"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,79L108,79"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,89L108,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,99L108,99"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,29L89,29"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,39L89,39"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,49L89,49"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,59L89,59"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,69L89,69"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,79L89,79"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M29,19L29,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M39,19L39,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M49,19L49,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M59,19L59,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M69,19L69,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M79,19L79,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
</vector>
... ...
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<LinearLayout
android:layout_width="match_parent"
android:layout_height="match_parent"
android:gravity="center"
android:orientation="vertical">
<EditText
android:id="@+id/input_text"
android:layout_width="match_parent"
android:layout_height="320dp"
android:layout_weight="2.5"
android:hint="@string/keyword_hint"
android:scrollbars="vertical"
android:text=""
android:textSize="15dp" />
<TextView
android:id="@+id/my_text"
android:layout_width="match_parent"
android:layout_height="443dp"
android:layout_weight="2.5"
android:padding="24dp"
android:scrollbars="vertical"
android:singleLine="false"
android:text="@string/hint"
android:textSize="15dp" />
<Button
android:id="@+id/record_button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_weight="0.5"
android:text="@string/start" />
</LinearLayout>
</androidx.constraintlayout.widget.ConstraintLayout>
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background" />
<foreground android:drawable="@drawable/ic_launcher_foreground" />
</adaptive-icon>
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background" />
<foreground android:drawable="@drawable/ic_launcher_foreground" />
</adaptive-icon>
\ No newline at end of file
... ...
<resources xmlns:tools="http://schemas.android.com/tools">
<!-- Base application theme. -->
<style name="Theme.SherpaOnnx" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
<!-- Primary brand color. -->
<item name="colorPrimary">@color/purple_200</item>
<item name="colorPrimaryVariant">@color/purple_700</item>
<item name="colorOnPrimary">@color/black</item>
<!-- Secondary brand color. -->
<item name="colorSecondary">@color/teal_200</item>
<item name="colorSecondaryVariant">@color/teal_200</item>
<item name="colorOnSecondary">@color/black</item>
<!-- Status bar color. -->
<item name="android:statusBarColor">?attr/colorPrimaryVariant</item>
<!-- Customize your theme here. -->
</style>
</resources>
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="purple_200">#FFBB86FC</color>
<color name="purple_500">#FF6200EE</color>
<color name="purple_700">#FF3700B3</color>
<color name="teal_200">#FF03DAC5</color>
<color name="teal_700">#FF018786</color>
<color name="black">#FF000000</color>
<color name="white">#FFFFFFFF</color>
</resources>
\ No newline at end of file
... ...
<resources>
<string name="app_name">KWS with Next-gen Kaldi</string>
<string name="hint">Click the Start button to play keyword spotting with Next-gen Kaldi.
\n
\n\n\n
The source code and pre-trained models are publicly available.
Please see https://github.com/k2-fsa/sherpa-onnx for details.
</string>
<string name="keyword_hint">Input your keywords here, one keyword perline.</string>
<string name="start">Start</string>
<string name="stop">Stop</string>
</resources>
... ...
<resources xmlns:tools="http://schemas.android.com/tools">
<!-- Base application theme. -->
<style name="Theme.SherpaOnnx" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
<!-- Primary brand color. -->
<item name="colorPrimary">@color/purple_500</item>
<item name="colorPrimaryVariant">@color/purple_700</item>
<item name="colorOnPrimary">@color/white</item>
<!-- Secondary brand color. -->
<item name="colorSecondary">@color/teal_200</item>
<item name="colorSecondaryVariant">@color/teal_700</item>
<item name="colorOnSecondary">@color/black</item>
<!-- Status bar color. -->
<item name="android:statusBarColor">?attr/colorPrimaryVariant</item>
<!-- Customize your theme here. -->
</style>
</resources>
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?><!--
Sample backup rules file; uncomment and customize as necessary.
See https://developer.android.com/guide/topics/data/autobackup
for details.
Note: This file is ignored for devices older than API 31
See https://developer.android.com/about/versions/12/backup-restore
-->
<full-backup-content>
<!--
<include domain="sharedpref" path="."/>
<exclude domain="sharedpref" path="device.xml"/>
-->
</full-backup-content>
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?><!--
Sample data extraction rules file; uncomment and customize as necessary.
See https://developer.android.com/about/versions/12/backup-restore#xml-changes
for details.
-->
<data-extraction-rules>
<cloud-backup>
<!-- TODO: Use <include> and <exclude> to control what is backed up.
<include .../>
<exclude .../>
-->
</cloud-backup>
<!--
<device-transfer>
<include .../>
<exclude .../>
</device-transfer>
-->
</data-extraction-rules>
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx
import org.junit.Test
import org.junit.Assert.*
/**
* Example local unit test, which will execute on the development machine (host).
*
* See [testing documentation](http://d.android.com/tools/testing).
*/
class ExampleUnitTest {
@Test
fun addition_isCorrect() {
assertEquals(4, 2 + 2)
}
}
\ No newline at end of file
... ...
// Top-level build file where you can add configuration options common to all sub-projects/modules.
plugins {
id 'com.android.application' version '7.3.1' apply false
id 'com.android.library' version '7.3.1' apply false
id 'org.jetbrains.kotlin.android' version '1.7.20' apply false
}
\ No newline at end of file
... ...
# Project-wide Gradle settings.
# IDE (e.g. Android Studio) users:
# Gradle settings configured through the IDE *will override*
# any settings specified in this file.
# For more details on how to configure your build environment visit
# http://www.gradle.org/docs/current/userguide/build_environment.html
# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
# org.gradle.parallel=true
# AndroidX package structure to make it clearer which packages are bundled with the
# Android operating system, and which are packaged with your app's APK
# https://developer.android.com/topic/libraries/support-library/androidx-rn
android.useAndroidX=true
# Kotlin code style for this project: "official" or "obsolete":
kotlin.code.style=official
# Enables namespacing of each library's R class so that its R class includes only the
# resources declared in the library itself and none from the library's dependencies,
# thereby reducing the size of the R class for that library
android.nonTransitiveRClass=true
\ No newline at end of file
... ...
#Thu Feb 23 11:09:06 CST 2023
distributionBase=GRADLE_USER_HOME
distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip
distributionPath=wrapper/dists
zipStorePath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
... ...
#!/usr/bin/env sh
#
# Copyright 2015 the original author or authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn () {
echo "$*"
}
die () {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin or MSYS, switch paths to Windows format before running java
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=`expr $i + 1`
done
case $i in
0) set -- ;;
1) set -- "$args0" ;;
2) set -- "$args0" "$args1" ;;
3) set -- "$args0" "$args1" "$args2" ;;
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Escape application args
save () {
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
echo " "
}
APP_ARGS=`save "$@"`
# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
exec "$JAVACMD" "$@"
... ...
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
@rem
@rem https://www.apache.org/licenses/LICENSE-2.0
@rem
@rem Unless required by applicable law or agreed to in writing, software
@rem distributed under the License is distributed on an "AS IS" BASIS,
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@rem See the License for the specific language governing permissions and
@rem limitations under the License.
@rem
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega
... ...
pluginManagement {
repositories {
gradlePluginPortal()
google()
mavenCentral()
}
}
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
mavenCentral()
}
}
rootProject.name = "SherpaOnnxKws"
include ':app'
... ...
#!/usr/bin/env bash
# Please set the environment variable ANDROID_NDK
# before running this script
# Inside the $ANDROID_NDK directory, you can find a binary ndk-build
# and some other files like the file "build/cmake/android.toolchain.cmake"
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
log "Building keyword spotting APK for sherpa-onnx v${SHERPA_ONNX_VERSION}"
log "====================arm64-v8a================="
./build-android-arm64-v8a.sh
log "====================armv7-eabi================"
./build-android-armv7-eabi.sh
log "====================x86-64===================="
./build-android-x86-64.sh
log "====================x86===================="
./build-android-x86.sh
mkdir -p apks
# Download the model
repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
if [ ! -d ./android/SherpaOnnxKws/app/src/main/assets/$repo ]; then
repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.git
log "Start testing ${repo_url}"
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
# remove .git to save space
rm -rf .git
rm *.int8.onnx
rm README.md configuration.json .gitattributes
rm -rfv test_wavs
ls -lh
popd
mv -v $repo ./android/SherpaOnnxKws/app/src/main/assets/
fi
tree ./android/SherpaOnnxKws/app/src/main/assets/
for arch in arm64-v8a armeabi-v7a x86_64 x86; do
log "------------------------------------------------------------"
log "build apk for $arch"
log "------------------------------------------------------------"
src_arch=$arch
if [ $arch == "armeabi-v7a" ]; then
src_arch=armv7-eabi
elif [ $arch == "x86_64" ]; then
src_arch=x86-64
fi
ls -lh ./build-android-$src_arch/install/lib/*.so
cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/
pushd ./android/SherpaOnnxKws
./gradlew build
popd
mv android/SherpaOnnxKws/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-kws-wenetspeech-zh-${SHERPA_ONNX_VERSION}-$arch.apk
ls -lh apks
rm -v ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/*.so
done
git checkout .
rm -rf ./android/SherpaOnnxKws/app/src/main/assets/$repo
# English model
repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
if [ ! -d ./android/SherpaOnnxKws/app/src/main/assets/$repo ]; then
repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.git
log "Start testing ${repo_url}"
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
# remove .git to save space
rm -rf .git
rm *.int8.onnx
rm README.md configuration.json .gitattributes
rm -rfv test_wavs
ls -lh
popd
mv -v $repo ./android/SherpaOnnxKws/app/src/main/assets/
fi
tree ./android/SherpaOnnxKws/app/src/main/assets/
pushd android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx
sed -i.bak s/"type = 0"/"type = 1"/ ./MainActivity.kt
git diff
popd
for arch in arm64-v8a armeabi-v7a x86_64 x86; do
log "------------------------------------------------------------"
log "build apk for $arch"
log "------------------------------------------------------------"
src_arch=$arch
if [ $arch == "armeabi-v7a" ]; then
src_arch=armv7-eabi
elif [ $arch == "x86_64" ]; then
src_arch=x86-64
fi
ls -lh ./build-android-$src_arch/install/lib/*.so
cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/
pushd ./android/SherpaOnnxKws
./gradlew build
popd
mv android/SherpaOnnxKws/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-kws-gigaspeech-en-${SHERPA_ONNX_VERSION}-$arch.apk
ls -lh apks
rm -v ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/*.so
done
git checkout .
rm -rf ./android/SherpaOnnxKws/app/src/main/assets/$repo
... ...
... ... @@ -151,6 +151,7 @@ class BuildExtension(build_ext):
# Remember to also change setup.py
binaries = ["sherpa-onnx"]
binaries += ["sherpa-onnx-keyword-spotter"]
binaries += ["sherpa-onnx-offline"]
binaries += ["sherpa-onnx-microphone"]
binaries += ["sherpa-onnx-microphone-offline"]
... ...
... ... @@ -36,13 +36,44 @@ import argparse
from sherpa_onnx import text2token
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--text",
type=str,
required=True,
help="Path to the input texts",
help="""Path to the input texts.
Each line in the texts contains the original phrase; it might also contain some
extra items, for example, the boosting score (starting with :), the triggering
threshold (starting with #, only used in the keyword spotting task) and the original
phrase (starting with @). Note: extra items will be kept in the output.
example input 1 (tokens_type = ppinyin):
小爱同学 :2.0 #0.6 @小爱同学
你好问问 :3.5 @你好问问
小艺小艺 #0.6 @小艺小艺
example output 1:
x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
n ǐ h ǎo w èn w èn :3.5 @你好问问
x iǎo y ì x iǎo y ì #0.6 @小艺小艺
example input 2 (tokens_type = bpe):
HELLO WORLD :1.5 #0.4
HI GOOGLE :2.0 #0.8
HEY SIRI #0.35
example output 2:
▁HE LL O ▁WORLD :1.5 #0.4
▁HI ▁GO O G LE :2.0 #0.8
▁HE Y ▁S I RI #0.35
""",
)
parser.add_argument(
... ... @@ -56,7 +87,11 @@ def get_args():
"--tokens-type",
type=str,
required=True,
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
choices=["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"],
help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin.
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
ppinyin means partial pinyin, it splits pinyin into initial and final,
""",
)
parser.add_argument(
... ... @@ -79,9 +114,21 @@ def main():
args = get_args()
texts = []
# extra information like the boosting score (starting with :), the triggering threshold (starting with #)
# and the original keyword (starting with @)
extra_info = []
with open(args.text, "r", encoding="utf8") as f:
for line in f:
texts.append(line.strip())
extra = []
text = []
toks = line.strip().split()
for tok in toks:
if tok[0] == ":" or tok[0] == "#" or tok[0] == "@":
extra.append(tok)
else:
text.append(tok)
texts.append(" ".join(text))
extra_info.append(extra)
encoded_texts = text2token(
texts,
tokens=args.tokens,
... ... @@ -89,7 +136,8 @@ def main():
bpe_model=args.bpe_model,
)
with open(args.output, "w", encoding="utf8") as f:
for txt in encoded_texts:
for i, txt in enumerate(encoded_texts):
txt += extra_info[i]
f.write(" ".join(txt) + "\n")
... ...
... ... @@ -51,6 +51,7 @@ def get_binaries_to_install():
# Remember to also change cmake/cmake_extension.py
binaries = ["sherpa-onnx"]
binaries += ["sherpa-onnx-keyword-spotter"]
binaries += ["sherpa-onnx-offline"]
binaries += ["sherpa-onnx-microphone"]
binaries += ["sherpa-onnx-microphone-offline"]
... ...
... ... @@ -19,6 +19,8 @@ set(sources
features.cc
file-utils.cc
hypothesis.cc
keyword-spotter-impl.cc
keyword-spotter.cc
offline-ctc-fst-decoder-config.cc
offline-ctc-fst-decoder.cc
offline-ctc-greedy-search-decoder.cc
... ... @@ -87,6 +89,7 @@ set(sources
stack.cc
symbol-table.cc
text-utils.cc
transducer-keyword-decoder.cc
transpose.cc
unbind.cc
utils.cc
... ... @@ -173,12 +176,14 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux)
endif()
add_executable(sherpa-onnx sherpa-onnx.cc)
add_executable(sherpa-onnx-keyword-spotter sherpa-onnx-keyword-spotter.cc)
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
set(main_exes
sherpa-onnx
sherpa-onnx-keyword-spotter
sherpa-onnx-offline
sherpa-onnx-offline-parallel
sherpa-onnx-offline-tts
... ...
... ... @@ -5,6 +5,7 @@
#include "sherpa-onnx/csrc/context-graph.h"
#include <chrono> // NOLINT
#include <cmath>
#include <map>
#include <random>
#include <string>
... ... @@ -15,27 +16,25 @@
namespace sherpa_onnx {
TEST(ContextGraph, TestBasic) {
static void TestHelper(const std::map<std::string, float> &queries, float score,
bool strict_mode) {
std::vector<std::string> contexts_str(
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
std::vector<std::vector<int32_t>> contexts;
std::vector<float> scores;
for (int32_t i = 0; i < contexts_str.size(); ++i) {
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
scores.push_back(std::round(score / contexts_str[i].size() * 100) / 100);
}
auto context_graph = ContextGraph(contexts, 1);
auto queries = std::map<std::string, float>{
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
auto context_graph = ContextGraph(contexts, 1, scores);
for (const auto &iter : queries) {
float total_scores = 0;
auto state = context_graph.Root();
for (auto q : iter.first) {
auto res = context_graph.ForwardOneStep(state, q);
total_scores += res.first;
state = res.second;
auto res = context_graph.ForwardOneStep(state, q, strict_mode);
total_scores += std::get<0>(res);
state = std::get<1>(res);
}
auto res = context_graph.Finalize(state);
EXPECT_EQ(res.second->token, -1);
... ... @@ -44,6 +43,37 @@ TEST(ContextGraph, TestBasic) {
}
}
TEST(ContextGraph, TestBasic) {
auto queries = std::map<std::string, float>{
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
TestHelper(queries, 0, true);
}
TEST(ContextGraph, TestBasicNonStrict) {
auto queries = std::map<std::string, float>{
{"HEHERSHE", 7}, {"HERSHE", 5}, {"HISHE", 5}, {"SHED", 3}, {"SHELF", 3},
{"HELL", 2}, {"HELLO", 2}, {"DHRHISQ", 3}, {"THEN", 2}};
TestHelper(queries, 0, false);
}
TEST(ContextGraph, TestCustomize) {
auto queries = std::map<std::string, float>{
{"HEHERSHE", 35.84}, {"HERSHE", 30.84}, {"HISHE", 24.18},
{"SHED", 18.34}, {"SHELF", 18.34}, {"HELL", 5},
{"HELLO", 13}, {"DHRHISQ", 10.84}, {"THEN", 5}};
TestHelper(queries, 5, true);
}
TEST(ContextGraph, TestCustomizeNonStrict) {
auto queries = std::map<std::string, float>{
{"HEHERSHE", 20}, {"HERSHE", 15}, {"HISHE", 10.84},
{"SHED", 10}, {"SHELF", 10}, {"HELL", 5},
{"HELLO", 5}, {"DHRHISQ", 5.84}, {"THEN", 5}};
TestHelper(queries, 5, false);
}
TEST(ContextGraph, Benchmark) {
std::random_device rd;
std::mt19937 mt(rd());
... ...
... ... @@ -4,22 +4,59 @@
#include "sherpa-onnx/csrc/context-graph.h"
#include <algorithm>
#include <cassert>
#include <queue>
#include <string>
#include <tuple>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void ContextGraph::Build(
const std::vector<std::vector<int32_t>> &token_ids) const {
void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
const std::vector<float> &scores,
const std::vector<std::string> &phrases,
const std::vector<float> &ac_thresholds) const {
if (!scores.empty()) {
SHERPA_ONNX_CHECK_EQ(token_ids.size(), scores.size());
}
if (!phrases.empty()) {
SHERPA_ONNX_CHECK_EQ(token_ids.size(), phrases.size());
}
if (!ac_thresholds.empty()) {
SHERPA_ONNX_CHECK_EQ(token_ids.size(), ac_thresholds.size());
}
for (int32_t i = 0; i < token_ids.size(); ++i) {
auto node = root_.get();
float score = scores.empty() ? 0.0f : scores[i];
score = score == 0.0f ? context_score_ : score;
float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i];
ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold;
std::string phrase = phrases.empty() ? std::string() : phrases[i];
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
int32_t token = token_ids[i][j];
if (0 == node->next.count(token)) {
bool is_end = j == token_ids[i].size() - 1;
node->next[token] = std::make_unique<ContextState>(
token, context_score_, node->node_score + context_score_,
is_end ? node->node_score + context_score_ : 0, is_end);
token, score, node->node_score + score,
is_end ? node->node_score + score : 0, j + 1,
is_end ? ac_threshold : 0.0f, is_end,
is_end ? phrase : std::string());
} else {
float token_score = std::max(score, node->next[token]->token_score);
node->next[token]->token_score = token_score;
float node_score = node->node_score + token_score;
node->next[token]->node_score = node_score;
bool is_end =
(j == token_ids[i].size() - 1) || node->next[token]->is_end;
node->next[token]->output_score = is_end ? node_score : 0.0f;
node->next[token]->is_end = is_end;
if (j == token_ids[i].size() - 1) {
node->next[token]->phrase = phrase;
node->next[token]->ac_threshold = ac_threshold;
}
}
node = node->next[token].get();
}
... ... @@ -27,8 +64,9 @@ void ContextGraph::Build(
FillFailOutput();
}
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
const ContextState *state, int32_t token) const {
std::tuple<float, const ContextState *, const ContextState *>
ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
bool strict_mode /*= true*/) const {
const ContextState *node;
float score;
if (1 == state->next.count(token)) {
... ... @@ -45,8 +83,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
}
score = node->node_score - state->node_score;
}
SHERPA_ONNX_CHECK(nullptr != node);
return std::make_pair(score + node->output_score, node);
const ContextState *matched_node =
node->is_end ? node : (node->output != nullptr ? node->output : nullptr);
if (!strict_mode && node->output_score != 0) {
SHERPA_ONNX_CHECK(nullptr != matched_node);
float output_score =
node->is_end ? node->node_score
: (node->output != nullptr ? node->output->node_score
: node->node_score);
return std::make_tuple(score + output_score - node->node_score, root_.get(),
matched_node);
}
return std::make_tuple(score + node->output_score, node, matched_node);
}
std::pair<float, const ContextState *> ContextGraph::Finalize(
... ... @@ -55,6 +107,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
return std::make_pair(score, root_.get());
}
std::pair<bool, const ContextState *> ContextGraph::IsMatched(
const ContextState *state) const {
bool status = false;
const ContextState *node = nullptr;
if (state->is_end) {
status = true;
node = state;
} else {
if (state->output != nullptr) {
status = true;
node = state->output;
}
}
return std::make_pair(status, node);
}
void ContextGraph::FillFailOutput() const {
std::queue<const ContextState *> node_queue;
for (auto &kv : root_->next) {
... ...
... ... @@ -6,6 +6,8 @@
#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
... ... @@ -22,34 +24,55 @@ struct ContextState {
float token_score;
float node_score;
float output_score;
int32_t level;
float ac_threshold;
bool is_end;
std::string phrase;
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
const ContextState *fail = nullptr;
const ContextState *output = nullptr;
ContextState() = default;
ContextState(int32_t token, float token_score, float node_score,
float output_score, bool is_end)
float output_score, int32_t level = 0, float ac_threshold = 0.0f,
bool is_end = false, const std::string &phrase = {})
: token(token),
token_score(token_score),
node_score(node_score),
output_score(output_score),
is_end(is_end) {}
level(level),
ac_threshold(ac_threshold),
is_end(is_end),
phrase(phrase) {}
};
class ContextGraph {
public:
ContextGraph() = default;
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
float context_score)
: context_score_(context_score) {
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
float context_score, float ac_threshold,
const std::vector<float> &scores = {},
const std::vector<std::string> &phrases = {},
const std::vector<float> &ac_thresholds = {})
: context_score_(context_score), ac_threshold_(ac_threshold) {
root_ = std::make_unique<ContextState>(-1, 0, 0, 0);
root_->fail = root_.get();
Build(token_ids);
Build(token_ids, scores, phrases, ac_thresholds);
}
std::pair<float, const ContextState *> ForwardOneStep(
const ContextState *state, int32_t token_id) const;
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
float context_score, const std::vector<float> &scores = {},
const std::vector<std::string> &phrases = {})
: ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
std::vector<float>()) {}
std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
const ContextState *state, int32_t token_id,
bool strict_mode = true) const;
std::pair<bool, const ContextState *> IsMatched(
const ContextState *state) const;
std::pair<float, const ContextState *> Finalize(
const ContextState *state) const;
... ... @@ -57,8 +80,12 @@ class ContextGraph {
private:
float context_score_;
float ac_threshold_;
std::unique_ptr<ContextState> root_;
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
void Build(const std::vector<std::vector<int32_t>> &token_ids,
const std::vector<float> &scores,
const std::vector<std::string> &phrases,
const std::vector<float> &ac_thresholds) const;
void FillFailOutput() const;
};
... ...
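For orientation, the snippet below is a minimal usage sketch of the extended ContextGraph API; it is not part of this change, and the token ids are hypothetical placeholders that would normally come from the symbol table.
#include <cstdint>
#include <tuple>
#include <vector>

#include "sherpa-onnx/csrc/context-graph.h"

void ContextGraphSketch() {
  using namespace sherpa_onnx;  // NOLINT
  // Hypothetical token ids for a single keyword.
  std::vector<std::vector<int32_t>> token_ids = {{23, 57, 8}};
  ContextGraph graph(token_ids, /*context_score=*/1.0f, /*ac_threshold=*/0.25f);

  const ContextState *state = graph.Root();
  for (int32_t token : {23, 57, 8}) {
    // strict_mode defaults to true, matching how the keyword decoder below
    // calls ForwardOneStep().
    auto res = graph.ForwardOneStep(state, token);
    state = std::get<1>(res);
    auto matched = graph.IsMatched(state);
    if (matched.first) {
      // matched.second->phrase (if provided) holds the triggered keyword and
      // matched.second->ac_threshold its per-keyword trigger threshold.
    }
  }
}
The transducer keyword decoder further down follows this pattern per hypothesis and resets a matched hypothesis back to Root().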
... ... @@ -28,6 +28,10 @@ struct Hypothesis {
// on which ys[i] is decoded.
std::vector<int32_t> timestamps;
// The acoustic probability for each token in ys.
  // Only used for the keyword spotting task.
std::vector<float> ys_probs;
// The total score of ys in log space.
// It contains only acoustic scores
double log_prob = 0;
... ...
// sherpa-onnx/csrc/keyword-spotter-impl.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
#include "sherpa-onnx/csrc/keyword-spotter-transducer-impl.h"
namespace sherpa_onnx {
std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
const KeywordSpotterConfig &config) {
if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<KeywordSpotterTransducerImpl>(config);
}
SHERPA_ONNX_LOGE("Please specify a model");
exit(-1);
}
#if __ANDROID_API__ >= 9
std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
AAssetManager *mgr, const KeywordSpotterConfig &config) {
if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<KeywordSpotterTransducerImpl>(mgr, config);
}
SHERPA_ONNX_LOGE("Please specify a model");
exit(-1);
}
#endif
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/keyword-spotter-impl.h
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
#include <memory>
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/online-stream.h"
namespace sherpa_onnx {
class KeywordSpotterImpl {
public:
static std::unique_ptr<KeywordSpotterImpl> Create(
const KeywordSpotterConfig &config);
#if __ANDROID_API__ >= 9
static std::unique_ptr<KeywordSpotterImpl> Create(
AAssetManager *mgr, const KeywordSpotterConfig &config);
#endif
virtual ~KeywordSpotterImpl() = default;
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
virtual std::unique_ptr<OnlineStream> CreateStream(
const std::string &keywords) const = 0;
virtual bool IsReady(OnlineStream *s) const = 0;
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
virtual KeywordResult GetResult(OnlineStream *s) const = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
... ...
// sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
#include <algorithm>
#include <memory>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
#include "sherpa-onnx/csrc/utils.h"
namespace sherpa_onnx {
static KeywordResult Convert(const TransducerKeywordResult &src,
const SymbolTable &sym_table, float frame_shift_ms,
int32_t subsampling_factor,
int32_t frames_since_start) {
KeywordResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());
r.keyword = src.keyword;
bool from_tokens = src.keyword.empty();
for (auto i : src.tokens) {
auto sym = sym_table[i];
if (from_tokens) {
r.keyword.append(sym);
}
r.tokens.push_back(std::move(sym));
}
if (from_tokens && r.keyword.size()) {
r.keyword = r.keyword.substr(1);
}
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
for (auto t : src.timestamps) {
float time = frame_shift_s * t;
r.timestamps.push_back(time);
}
r.start_time = frames_since_start * frame_shift_ms / 1000.;
return r;
}
class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
public:
explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config)
: config_(config),
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens) {
if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
InitKeywords();
decoder_ = std::make_unique<TransducerKeywordDecoder>(
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
unk_id_);
}
#if __ANDROID_API__ >= 9
KeywordSpotterTransducerImpl(AAssetManager *mgr,
const KeywordSpotterConfig &config)
: config_(config),
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
sym_(mgr, config.model_config.tokens) {
if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
InitKeywords(mgr);
decoder_ = std::make_unique<TransducerKeywordDecoder>(
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
unk_id_);
}
#endif
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream =
std::make_unique<OnlineStream>(config_.feat_config, keywords_graph_);
InitOnlineStream(stream.get());
return stream;
}
std::unique_ptr<OnlineStream> CreateStream(
const std::string &keywords) const override {
auto kws = std::regex_replace(keywords, std::regex("/"), "\n");
std::istringstream is(kws);
std::vector<std::vector<int32_t>> current_ids;
std::vector<std::string> current_kws;
std::vector<float> current_scores;
std::vector<float> current_thresholds;
if (!EncodeKeywords(is, sym_, &current_ids, &current_kws, &current_scores,
&current_thresholds)) {
SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str());
return nullptr;
}
int32_t num_kws = current_ids.size();
int32_t num_default_kws = keywords_id_.size();
current_ids.insert(current_ids.end(), keywords_id_.begin(),
keywords_id_.end());
if (!current_kws.empty() && !keywords_.empty()) {
current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end());
} else if (!current_kws.empty() && keywords_.empty()) {
current_kws.insert(current_kws.end(), num_default_kws, std::string());
} else if (current_kws.empty() && !keywords_.empty()) {
current_kws.insert(current_kws.end(), num_kws, std::string());
current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end());
} else {
// Do nothing.
}
if (!current_scores.empty() && !boost_scores_.empty()) {
current_scores.insert(current_scores.end(), boost_scores_.begin(),
boost_scores_.end());
} else if (!current_scores.empty() && boost_scores_.empty()) {
current_scores.insert(current_scores.end(), num_default_kws,
config_.keywords_score);
} else if (current_scores.empty() && !boost_scores_.empty()) {
current_scores.insert(current_scores.end(), num_kws,
config_.keywords_score);
current_scores.insert(current_scores.end(), boost_scores_.begin(),
boost_scores_.end());
} else {
// Do nothing.
}
if (!current_thresholds.empty() && !thresholds_.empty()) {
current_thresholds.insert(current_thresholds.end(), thresholds_.begin(),
thresholds_.end());
} else if (!current_thresholds.empty() && thresholds_.empty()) {
current_thresholds.insert(current_thresholds.end(), num_default_kws,
config_.keywords_threshold);
} else if (current_thresholds.empty() && !thresholds_.empty()) {
current_thresholds.insert(current_thresholds.end(), num_kws,
config_.keywords_threshold);
current_thresholds.insert(current_thresholds.end(), thresholds_.begin(),
thresholds_.end());
} else {
// Do nothing.
}
auto keywords_graph = std::make_shared<ContextGraph>(
current_ids, config_.keywords_score, config_.keywords_threshold,
current_scores, current_kws, current_thresholds);
auto stream =
std::make_unique<OnlineStream>(config_.feat_config, keywords_graph);
InitOnlineStream(stream.get());
return stream;
}
bool IsReady(OnlineStream *s) const override {
return s->GetNumProcessedFrames() + model_->ChunkSize() <
s->NumFramesReady();
}
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();
int32_t feature_dim = ss[0]->FeatureDim();
std::vector<TransducerKeywordResult> results(n);
std::vector<float> features_vec(n * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> states_vec(n);
std::vector<int64_t> all_processed_frames(n);
for (int32_t i = 0; i != n; ++i) {
SHERPA_ONNX_CHECK(ss[i]->GetContextGraph() != nullptr);
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
ss[i]->GetFrames(num_processed_frames, chunk_size);
// Question: should num_processed_frames include chunk_shift?
ss[i]->GetNumProcessedFrames() += chunk_shift;
std::copy(features.begin(), features.end(),
features_vec.data() + i * chunk_size * feature_dim);
results[i] = std::move(ss[i]->GetKeywordResult());
states_vec[i] = std::move(ss[i]->GetStates());
all_processed_frames[i] = num_processed_frames;
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
features_vec.size(), x_shape.data(),
x_shape.size());
std::array<int64_t, 1> processed_frames_shape{
static_cast<int64_t>(all_processed_frames.size())};
Ort::Value processed_frames = Ort::Value::CreateTensor(
memory_info, all_processed_frames.data(), all_processed_frames.size(),
processed_frames_shape.data(), processed_frames_shape.size());
auto states = model_->StackStates(states_vec);
auto pair = model_->RunEncoder(std::move(x), std::move(states),
std::move(processed_frames));
decoder_->Decode(std::move(pair.first), ss, &results);
std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(pair.second);
for (int32_t i = 0; i != n; ++i) {
ss[i]->SetKeywordResult(results[i]);
ss[i]->SetStates(std::move(next_states[i]));
}
}
KeywordResult GetResult(OnlineStream *s) const override {
TransducerKeywordResult decoder_result = s->GetKeywordResult(true);
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4;
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetNumFramesSinceStart());
}
private:
void InitKeywords(std::istream &is) {
if (!EncodeKeywords(is, sym_, &keywords_id_, &keywords_, &boost_scores_,
&thresholds_)) {
SHERPA_ONNX_LOGE("Encode keywords failed.");
exit(-1);
}
keywords_graph_ = std::make_shared<ContextGraph>(
keywords_id_, config_.keywords_score, config_.keywords_threshold,
boost_scores_, keywords_, thresholds_);
}
void InitKeywords() {
// each line in keywords_file contains space-separated words
std::ifstream is(config_.keywords_file);
if (!is) {
SHERPA_ONNX_LOGE("Open keywords file failed: %s",
config_.keywords_file.c_str());
exit(-1);
}
InitKeywords(is);
}
#if __ANDROID_API__ >= 9
void InitKeywords(AAssetManager *mgr) {
// each line in keywords_file contains space-separated words
auto buf = ReadFile(mgr, config_.keywords_file);
std::istrstream is(buf.data(), buf.size());
if (!is) {
SHERPA_ONNX_LOGE("Open keywords file failed: %s",
config_.keywords_file.c_str());
exit(-1);
}
InitKeywords(is);
}
#endif
void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();
SHERPA_ONNX_CHECK_EQ(r.hyps.size(), 1);
SHERPA_ONNX_CHECK(stream->GetContextGraph() != nullptr);
r.hyps.begin()->second.context_state = stream->GetContextGraph()->Root();
stream->SetKeywordResult(r);
stream->SetStates(model_->GetEncoderInitStates());
}
private:
KeywordSpotterConfig config_;
std::vector<std::vector<int32_t>> keywords_id_;
std::vector<float> boost_scores_;
std::vector<float> thresholds_;
std::vector<std::string> keywords_;
ContextGraphPtr keywords_graph_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<TransducerKeywordDecoder> decoder_;
SymbolTable sym_;
int32_t unk_id_ = -1;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
... ...
// sherpa-onnx/csrc/keyword-spotter.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include <assert.h>
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
namespace sherpa_onnx {
std::string KeywordResult::AsJsonString() const {
std::ostringstream os;
os << "{";
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
<< ", ";
os << "\"keyword\""
<< ": ";
os << "\"" << keyword << "\""
<< ", ";
os << "\""
<< "timestamps"
<< "\""
<< ": ";
os << "[";
std::string sep = "";
for (auto t : timestamps) {
os << sep << std::fixed << std::setprecision(2) << t;
sep = ", ";
}
os << "], ";
os << "\""
<< "tokens"
<< "\""
<< ":";
os << "[";
sep = "";
auto oldFlags = os.flags();
for (const auto &t : tokens) {
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
os << sep << "\""
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
<< ">"
<< "\"";
os.flags(oldFlags);
} else {
os << sep << "\"" << t << "\"";
}
sep = ", ";
}
os << "]";
os << "}";
return os.str();
}
void KeywordSpotterConfig::Register(ParseOptions *po) {
feat_config.Register(po);
model_config.Register(po);
po->Register("max-active-paths", &max_active_paths,
"beam size used in modified beam search.");
po->Register("num-trailing-blanks", &num_trailing_blanks,
"The number of trailing blanks should have after the keyword.");
po->Register("keywords-score", &keywords_score,
"The bonus score for each token in context word/phrase.");
po->Register("keywords-threshold", &keywords_threshold,
"The acoustic threshold (probability) to trigger the keywords.");
po->Register(
"keywords-file", &keywords_file,
"The file containing keywords, one word/phrase per line, and for each"
"phrase the bpe/cjkchar are separated by a space. For example: "
"▁HE LL O ▁WORLD"
"你 好 世 界");
}
bool KeywordSpotterConfig::Validate() const {
if (keywords_file.empty()) {
SHERPA_ONNX_LOGE("Please provide --keywords-file.");
return false;
}
if (!std::ifstream(keywords_file.c_str()).good()) {
SHERPA_ONNX_LOGE("Keywords file %s does not exist.", keywords_file.c_str());
return false;
}
return model_config.Validate();
}
std::string KeywordSpotterConfig::ToString() const {
std::ostringstream os;
os << "KeywordSpotterConfig(";
os << "feat_config=" << feat_config.ToString() << ", ";
os << "model_config=" << model_config.ToString() << ", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "num_trailing_blanks=" << num_trailing_blanks << ", ";
os << "keywords_score=" << keywords_score << ", ";
os << "keywords_threshold=" << keywords_threshold << ", ";
os << "keywords_file=\"" << keywords_file << "\")";
return os.str();
}
KeywordSpotter::KeywordSpotter(const KeywordSpotterConfig &config)
: impl_(KeywordSpotterImpl::Create(config)) {}
#if __ANDROID_API__ >= 9
KeywordSpotter::KeywordSpotter(AAssetManager *mgr,
const KeywordSpotterConfig &config)
: impl_(KeywordSpotterImpl::Create(mgr, config)) {}
#endif
KeywordSpotter::~KeywordSpotter() = default;
std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream() const {
return impl_->CreateStream();
}
std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream(
const std::string &keywords) const {
return impl_->CreateStream(keywords);
}
bool KeywordSpotter::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}
KeywordResult KeywordSpotter::GetResult(OnlineStream *s) const {
return impl_->GetResult(s);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/keyword-spotter.h
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
#include <memory>
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct KeywordResult {
/// The triggered keyword.
/// For English, it consists of space separated words.
/// For Chinese, it consists of Chinese words without spaces.
/// Example 1: "hello world"
/// Example 2: "你好世界"
std::string keyword;
/// Decoded results at the token level.
/// For instance, for BPE-based models it consists of a list of BPE tokens.
std::vector<std::string> tokens;
/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;
/// Starting time of this segment.
  /// When an endpoint is detected, it will change.
float start_time = 0;
/** Return a json string.
*
* The returned string contains:
* {
* "keyword": "The triggered keyword",
* "tokens": [x, x, x],
* "timestamps": [x, x, x],
* "start_time": x,
* }
*/
std::string AsJsonString() const;
};
struct KeywordSpotterConfig {
FeatureExtractorConfig feat_config;
OnlineModelConfig model_config;
int32_t max_active_paths = 4;
int32_t num_trailing_blanks = 1;
float keywords_score = 1.0;
float keywords_threshold = 0.25;
std::string keywords_file;
KeywordSpotterConfig() = default;
KeywordSpotterConfig(const FeatureExtractorConfig &feat_config,
const OnlineModelConfig &model_config,
int32_t max_active_paths, int32_t num_trailing_blanks,
float keywords_score, float keywords_threshold,
const std::string &keywords_file)
: feat_config(feat_config),
model_config(model_config),
max_active_paths(max_active_paths),
num_trailing_blanks(num_trailing_blanks),
keywords_score(keywords_score),
keywords_threshold(keywords_threshold),
keywords_file(keywords_file) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
class KeywordSpotterImpl;
class KeywordSpotter {
public:
explicit KeywordSpotter(const KeywordSpotterConfig &config);
#if __ANDROID_API__ >= 9
KeywordSpotter(AAssetManager *mgr, const KeywordSpotterConfig &config);
#endif
~KeywordSpotter();
/** Create a stream for decoding.
*
*/
std::unique_ptr<OnlineStream> CreateStream() const;
/** Create a stream for decoding.
*
   * @param keywords The keywords for this stream. It may contain several
   *        keywords separated by "/". Each keyword consists of cjkchars or
   *        bpes, and the bpe/cjkchar tokens are separated by spaces (" ").
   *        For example, the keywords I LOVE YOU and HELLO WORLD look like:
*
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
*/
std::unique_ptr<OnlineStream> CreateStream(const std::string &keywords) const;
/**
* Return true if the given stream has enough frames for decoding.
* Return false otherwise
*/
bool IsReady(OnlineStream *s) const;
/** Decode a single stream. */
void DecodeStream(OnlineStream *s) const {
OnlineStream *ss[1] = {s};
DecodeStreams(ss, 1);
}
/** Decode multiple streams in parallel
*
* @param ss Pointer array containing streams to be decoded.
* @param n Number of streams in `ss`.
*/
void DecodeStreams(OnlineStream **ss, int32_t n) const;
KeywordResult GetResult(OnlineStream *s) const;
private:
std::unique_ptr<KeywordSpotterImpl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
... ...
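The sherpa-onnx-keyword-spotter binary further down drives this API with the default keywords file; as a complement, here is a hedged sketch (not part of this change; all model paths and the keyword string are placeholders) of supplying extra keywords at runtime via CreateStream(keywords).
#include <cstdint>
#include <vector>

#include "sherpa-onnx/csrc/keyword-spotter.h"

void KeywordSpotterSketch(const std::vector<float> &samples,
                          int32_t sample_rate) {
  sherpa_onnx::KeywordSpotterConfig config;
  config.model_config.transducer.encoder = "encoder.onnx";  // placeholder paths
  config.model_config.transducer.decoder = "decoder.onnx";
  config.model_config.transducer.joiner = "joiner.onnx";
  config.model_config.tokens = "tokens.txt";
  config.keywords_file = "keywords.txt";

  sherpa_onnx::KeywordSpotter spotter(config);

  // Extra keywords on top of keywords.txt; keywords are separated by "/",
  // and every token must exist in tokens.txt.
  auto stream = spotter.CreateStream("▁HI ▁GO O G LE/▁HE Y ▁S I RI");
  if (!stream) {
    return;  // encoding the extra keywords failed
  }

  stream->AcceptWaveform(sample_rate, samples.data(),
                         static_cast<int32_t>(samples.size()));
  stream->InputFinished();

  while (spotter.IsReady(stream.get())) {
    spotter.DecodeStream(stream.get());
    auto r = spotter.GetResult(stream.get());
    if (!r.keyword.empty()) {
      // r.AsJsonString() carries the keyword, tokens and timestamps.
    }
  }
}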
... ... @@ -93,8 +93,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
// now cur_encoder_out is of shape (num_hyps, joiner_dim)
Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), View(&decoder_out));
Ort::Value logit =
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
float *p_logit = logit.GetTensorMutableData<float>();
LogSoftmax(p_logit, vocab_size, num_hyps);
... ... @@ -134,8 +134,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
if (context_graphs[i] != nullptr) {
auto context_res =
context_graphs[i]->ForwardOneStep(context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
context_score = std::get<0>(context_res);
new_hyp.context_state = std::get<1>(context_res);
}
}
... ...
... ... @@ -51,6 +51,25 @@ class OnlineStream::Impl {
OnlineTransducerDecoderResult &GetResult() { return result_; }
void SetKeywordResult(const TransducerKeywordResult &r) {
keyword_result_ = r;
}
TransducerKeywordResult &GetKeywordResult(bool remove_duplicates) {
if (remove_duplicates) {
if (!prev_keyword_result_.timestamps.empty() &&
!keyword_result_.timestamps.empty() &&
keyword_result_.timestamps[0] <=
prev_keyword_result_.timestamps.back()) {
return empty_keyword_result_;
} else {
prev_keyword_result_ = keyword_result_;
}
return keyword_result_;
} else {
return keyword_result_;
}
}
OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; }
void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; }
... ... @@ -93,6 +112,9 @@ class OnlineStream::Impl {
int32_t start_frame_index_ = 0; // never reset
int32_t segment_ = 0;
OnlineTransducerDecoderResult result_;
TransducerKeywordResult prev_keyword_result_;
TransducerKeywordResult keyword_result_;
TransducerKeywordResult empty_keyword_result_;
OnlineCtcDecoderResult ctc_result_;
std::vector<Ort::Value> states_; // states for transducer or ctc models
std::vector<float> paraformer_feat_cache_;
... ... @@ -149,6 +171,15 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
return impl_->GetResult();
}
void OnlineStream::SetKeywordResult(const TransducerKeywordResult &r) {
impl_->SetKeywordResult(r);
}
TransducerKeywordResult &OnlineStream::GetKeywordResult(
bool remove_duplicates /*=false*/) {
return impl_->GetKeywordResult(remove_duplicates);
}
OnlineCtcDecoderResult &OnlineStream::GetCtcResult() {
return impl_->GetCtcResult();
}
... ...
... ... @@ -14,9 +14,11 @@
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
namespace sherpa_onnx {
class TransducerKeywordResult;
class OnlineStream {
public:
explicit OnlineStream(const FeatureExtractorConfig &config = {},
... ... @@ -76,6 +78,9 @@ class OnlineStream {
void SetResult(const OnlineTransducerDecoderResult &r);
OnlineTransducerDecoderResult &GetResult();
void SetKeywordResult(const TransducerKeywordResult &r);
TransducerKeywordResult &GetKeywordResult(bool remove_duplicates = false);
void SetCtcResult(const OnlineCtcDecoderResult &r);
OnlineCtcDecoderResult &GetCtcResult();
... ... @@ -92,7 +97,7 @@ class OnlineStream {
*/
const ContextGraphPtr &GetContextGraph() const;
// for streaming parformer
// for streaming paraformer
std::vector<float> &GetParaformerFeatCache();
std::vector<float> &GetParaformerEncoderOutCache();
std::vector<float> &GetParaformerAlphaCache();
... ...
... ... @@ -75,10 +75,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
if (encoder_out_shape[0] != result->size()) {
fprintf(stderr,
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
static_cast<int32_t>(encoder_out_shape[0]),
static_cast<int32_t>(result->size()));
SHERPA_ONNX_LOGE(
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
static_cast<int32_t>(encoder_out_shape[0]),
static_cast<int32_t>(result->size()));
exit(-1);
}
... ... @@ -119,8 +119,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
cur_encoder_out =
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), View(&decoder_out));
Ort::Value logit =
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
float *p_logit = logit.GetTensorMutableData<float>();
LogSoftmax(p_logit, vocab_size, num_hyps);
... ... @@ -164,8 +164,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
context_score = std::get<0>(context_res);
new_hyp.context_state = std::get<1>(context_res);
}
if (lm_) {
lm_->ComputeLMScore(lm_scale_, &new_hyp);
... ...
// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#include <stdio.h>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/wave-reader.h"
typedef struct {
std::unique_ptr<sherpa_onnx::OnlineStream> online_stream;
std::string filename;
} Stream;
int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Usage:
(1) Streaming transducer
./bin/sherpa-onnx-keyword-spotter \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--provider=cpu \
--num-threads=2 \
--keywords-file=keywords.txt \
/path/to/foo.wav [bar.wav foobar.wav ...]
Note: It supports decoding multiple files in batches
Default value for num_threads is 2.
Valid values for provider: cpu (default), cuda, coreml.
foo.wav should be a single-channel, 16-bit PCM encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download.
)usage";
sherpa_onnx::ParseOptions po(kUsageMessage);
sherpa_onnx::KeywordSpotterConfig config;
config.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() < 1) {
po.PrintUsage();
exit(EXIT_FAILURE);
}
fprintf(stderr, "%s\n", config.ToString().c_str());
if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
return -1;
}
sherpa_onnx::KeywordSpotter keyword_spotter(config);
std::vector<Stream> ss;
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
const std::string wav_filename = po.GetArg(i);
int32_t sampling_rate = -1;
bool is_ok = false;
const std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
return -1;
}
auto s = keyword_spotter.CreateStream();
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
// Note: We can call AcceptWaveform() multiple times.
s->AcceptWaveform(sampling_rate, tail_paddings.data(),
tail_paddings.size());
    // Call InputFinished() to indicate that no more audio samples are available
s->InputFinished();
ss.push_back({std::move(s), wav_filename});
}
std::vector<sherpa_onnx::OnlineStream *> ready_streams;
for (;;) {
ready_streams.clear();
for (auto &s : ss) {
const auto p_ss = s.online_stream.get();
if (keyword_spotter.IsReady(p_ss)) {
ready_streams.push_back(p_ss);
}
std::ostringstream os;
const auto r = keyword_spotter.GetResult(p_ss);
if (!r.keyword.empty()) {
os << s.filename << "\n";
os << r.AsJsonString() << "\n\n";
fprintf(stderr, "%s", os.str().c_str());
}
}
if (ready_streams.empty()) {
break;
}
keyword_spotter.DecodeStreams(ready_streams.data(), ready_streams.size());
}
return 0;
}
... ...
// sherpa-onnx/csrc/transducer-keyword-decoder.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
namespace sherpa_onnx {
TransducerKeywordResult TransducerKeywordDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();
int32_t blank_id = 0; // always 0
TransducerKeywordResult r;
std::vector<int64_t> blanks(context_size, -1);
blanks.back() = blank_id;
Hypotheses blank_hyp({{blanks, 0}});
r.hyps = std::move(blank_hyp);
return r;
}
void TransducerKeywordDecoder::Decode(
Ort::Value encoder_out, OnlineStream **ss,
std::vector<TransducerKeywordResult> *result) {
std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
if (encoder_out_shape[0] != result->size()) {
SHERPA_ONNX_LOGE(
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
static_cast<int32_t>(encoder_out_shape[0]),
static_cast<int32_t>(result->size()));
exit(-1);
}
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
int32_t vocab_size = model_->VocabSize();
int32_t context_size = model_->ContextSize();
std::vector<int64_t> blanks(context_size, -1);
blanks.back() = 0; // blank_id is hardcoded to 0
std::vector<Hypotheses> cur;
for (auto &r : *result) {
cur.push_back(std::move(r.hyps));
}
std::vector<Hypothesis> prev;
for (int32_t t = 0; t != num_frames; ++t) {
// Due to merging paths with identical token sequences,
// not all utterances have "num_active_paths" paths.
auto hyps_row_splits = GetHypsRowSplits(cur);
int32_t num_hyps =
        hyps_row_splits.back();  // total num hyps for all utterances
prev.clear();
for (auto &hyps : cur) {
for (auto &h : hyps) {
prev.push_back(std::move(h.second));
}
}
cur.clear();
cur.reserve(batch_size);
Ort::Value decoder_input = model_->BuildDecoderInput(prev);
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
Ort::Value cur_encoder_out =
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
cur_encoder_out =
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
Ort::Value logit =
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
float *p_logit = logit.GetTensorMutableData<float>();
LogSoftmax(p_logit, vocab_size, num_hyps);
// The acoustic logprobs for current frame
std::vector<float> logprobs(vocab_size * num_hyps);
std::memcpy(logprobs.data(), p_logit,
sizeof(float) * vocab_size * num_hyps);
// now p_logit contains log_softmax output, we rename it to p_logprob
// to match what it actually contains
float *p_logprob = p_logit;
// add log_prob of each hypothesis to p_logprob before taking top_k
for (int32_t i = 0; i != num_hyps; ++i) {
float log_prob = prev[i].log_prob;
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
*p_logprob += log_prob;
}
}
p_logprob = p_logit; // we changed p_logprob in the above for loop
for (int32_t b = 0; b != batch_size; ++b) {
int32_t frame_offset = (*result)[b].frame_offset;
int32_t start = hyps_row_splits[b];
int32_t end = hyps_row_splits[b + 1];
auto topk =
TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_);
Hypotheses hyps;
for (auto k : topk) {
int32_t hyp_index = k / vocab_size + start;
int32_t new_token = k % vocab_size;
Hypothesis new_hyp = prev[hyp_index];
float context_score = 0;
auto context_state = new_hyp.context_state;
// blank is hardcoded to 0
// also, it treats unk as blank
if (new_token != 0 && new_token != unk_id_) {
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t + frame_offset);
new_hyp.ys_probs.push_back(
exp(logprobs[hyp_index * vocab_size + new_token]));
new_hyp.num_trailing_blanks = 0;
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
context_state, new_token);
context_score = std::get<0>(context_res);
new_hyp.context_state = std::get<1>(context_res);
// Start matching from the start state, forget the decoder history.
if (new_hyp.context_state->token == -1) {
new_hyp.ys = blanks;
new_hyp.timestamps.clear();
new_hyp.ys_probs.clear();
}
} else {
++new_hyp.num_trailing_blanks;
}
new_hyp.log_prob = p_logprob[k] + context_score;
hyps.Add(std::move(new_hyp));
} // for (auto k : topk)
auto best_hyp = hyps.GetMostProbable(false);
auto status = ss[b]->GetContextGraph()->IsMatched(best_hyp.context_state);
bool matched = std::get<0>(status);
const ContextState *matched_state = std::get<1>(status);
if (matched) {
float ys_prob = 0.0;
int32_t length = best_hyp.ys_probs.size();
for (int32_t i = 1; i <= matched_state->level; ++i) {
ys_prob += best_hyp.ys_probs[i];
}
ys_prob /= matched_state->level;
if (best_hyp.num_trailing_blanks > num_trailing_blanks_ &&
ys_prob >= matched_state->ac_threshold) {
auto &r = (*result)[b];
r.tokens = {best_hyp.ys.end() - matched_state->level,
best_hyp.ys.end()};
r.timestamps = {best_hyp.timestamps.end() - matched_state->level,
best_hyp.timestamps.end()};
r.keyword = matched_state->phrase;
hyps = Hypotheses({{blanks, 0, ss[b]->GetContextGraph()->Root()}});
}
}
cur.push_back(std::move(hyps));
p_logprob += (end - start) * vocab_size;
} // for (int32_t b = 0; b != batch_size; ++b)
}
for (int32_t b = 0; b != batch_size; ++b) {
auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(false);
auto &r = (*result)[b];
r.hyps = std::move(hyps);
r.num_trailing_blanks = best_hyp.num_trailing_blanks;
r.frame_offset += num_frames;
}
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/transducer-keyword-decoder.h
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
#define SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
namespace sherpa_onnx {
struct TransducerKeywordResult {
/// Number of frames after subsampling we have decoded so far
int32_t frame_offset = 0;
/// The decoded token IDs for keywords
std::vector<int64_t> tokens;
/// The triggered keyword
std::string keyword;
/// number of trailing blank frames decoded so far
int32_t num_trailing_blanks = 0;
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
std::vector<int32_t> timestamps;
// used only in modified beam_search
Hypotheses hyps;
};
class TransducerKeywordDecoder {
public:
TransducerKeywordDecoder(OnlineTransducerModel *model,
int32_t max_active_paths,
int32_t num_trailing_blanks, int32_t unk_id)
: model_(model),
max_active_paths_(max_active_paths),
num_trailing_blanks_(num_trailing_blanks),
unk_id_(unk_id) {}
TransducerKeywordResult GetEmptyResult() const;
void Decode(Ort::Value encoder_out, OnlineStream **ss,
std::vector<TransducerKeywordResult> *result);
private:
OnlineTransducerModel *model_; // Not owned
int32_t max_active_paths_;
int32_t num_trailing_blanks_;
int32_t unk_id_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
... ...
... ... @@ -15,16 +15,31 @@
namespace sherpa_onnx {
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *hotwords) {
hotwords->clear();
std::vector<int32_t> tmp;
static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *ids,
std::vector<std::string> *phrases,
std::vector<float> *scores,
std::vector<float> *thresholds) {
SHERPA_ONNX_CHECK(ids != nullptr);
ids->clear();
std::vector<int32_t> tmp_ids;
std::vector<float> tmp_scores;
std::vector<float> tmp_thresholds;
std::vector<std::string> tmp_phrases;
std::string line;
std::string word;
bool has_scores = false;
bool has_thresholds = false;
bool has_phrases = false;
while (std::getline(is, line)) {
float score = 0;
float threshold = 0;
std::string phrase = "";
std::istringstream iss(line);
std::vector<std::string> syms;
while (iss >> word) {
if (word.size() >= 3) {
// For BPE-based models, we replace ▁ with a space
... ... @@ -35,20 +50,72 @@ bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
}
}
if (symbol_table.contains(word)) {
int32_t number = symbol_table[word];
tmp.push_back(number);
int32_t id = symbol_table[word];
tmp_ids.push_back(id);
} else {
SHERPA_ONNX_LOGE(
"Cannot find ID for hotword %s at line: %s. (Hint: words on "
"the "
"same line are separated by spaces)",
word.c_str(), line.c_str());
return false;
switch (word[0]) {
case ':': // boosting score for current keyword
score = std::stof(word.substr(1));
has_scores = true;
break;
case '#': // triggering threshold (probability) for current keyword
threshold = std::stof(word.substr(1));
has_thresholds = true;
break;
case '@': // the original keyword string
phrase = word.substr(1);
has_phrases = true;
break;
default:
SHERPA_ONNX_LOGE(
"Cannot find ID for token %s at line: %s. (Hint: words on "
"the same line are separated by spaces)",
word.c_str(), line.c_str());
return false;
}
}
}
hotwords->push_back(std::move(tmp));
ids->push_back(std::move(tmp_ids));
tmp_scores.push_back(score);
tmp_phrases.push_back(phrase);
tmp_thresholds.push_back(threshold);
}
if (scores != nullptr) {
if (has_scores) {
scores->swap(tmp_scores);
} else {
scores->clear();
}
}
if (phrases != nullptr) {
if (has_phrases) {
*phrases = std::move(tmp_phrases);
} else {
phrases->clear();
}
}
if (thresholds != nullptr) {
if (has_thresholds) {
thresholds->swap(tmp_thresholds);
} else {
thresholds->clear();
}
}
return true;
}
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *hotwords) {
return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr);
}
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *keywords_id,
std::vector<std::string> *keywords,
std::vector<float> *boost_scores,
std::vector<float> *threshold) {
return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores,
threshold);
}
} // namespace sherpa_onnx
... ...
... ... @@ -26,7 +26,32 @@ namespace sherpa_onnx {
* otherwise returns false.
*/
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *hotwords);
std::vector<std::vector<int32_t>> *hotwords_id);
/* Encode the keywords in an input stream to token ids.
*
 * @param is The input stream. It contains several lines, one keyword per
 *            line. For each keyword, the tokens (cjkchar or bpe) are separated
 *            by spaces. A line may also contain a boosting score (starting
 *            with :), a triggering threshold (starting with #) and the original
 *            keyword string (starting with @).
 * @param symbol_table The tokens table mapping symbols to ids. All the symbols
 *                     in the stream should be in the symbol_table; if not, this
 *                     function returns false.
*
* @param keywords_id The encoded ids to be written to.
* @param keywords The original keyword string to be written to.
* @param boost_scores The boosting score for each keyword to be written to.
* @param threshold The triggering threshold for each keyword to be written to.
*
 * @return Returns true if all the symbols from ``is`` are in the symbol_table;
 *         otherwise returns false.
*/
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *keywords_id,
std::vector<std::string> *keywords,
std::vector<float> *boost_scores,
std::vector<float> *threshold);
} // namespace sherpa_onnx
... ...
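To make the keyword line format concrete, here is a hedged sketch of calling EncodeKeywords directly; it is not part of this change, the tokens.txt path is assumed, and every token in the two lines must exist in that symbol table.
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"

void EncodeKeywordsSketch() {
  // One keyword per line. Optional fields: ":" boosting score,
  // "#" triggering threshold, "@" original phrase (no spaces inside it).
  std::istringstream is(
      "▁HE LL O ▁WORLD :2.0 #0.35\n"
      "你 好 世 界 :1.5 @你好世界\n");

  sherpa_onnx::SymbolTable sym("tokens.txt");  // assumed path

  std::vector<std::vector<int32_t>> ids;
  std::vector<std::string> phrases;
  std::vector<float> scores;
  std::vector<float> thresholds;
  if (!sherpa_onnx::EncodeKeywords(is, sym, &ids, &phrases, &scores,
                                   &thresholds)) {
    // Some token was not found in tokens.txt.
  }
}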
... ... @@ -21,6 +21,7 @@
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-tts.h"
... ... @@ -140,6 +141,73 @@ class SherpaOnnxVad {
VoiceActivityDetector vad_;
};
class SherpaOnnxKws {
public:
#if __ANDROID_API__ >= 9
SherpaOnnxKws(AAssetManager *mgr, const KeywordSpotterConfig &config)
: keyword_spotter_(mgr, config),
stream_(keyword_spotter_.CreateStream()) {}
#endif
explicit SherpaOnnxKws(const KeywordSpotterConfig &config)
: keyword_spotter_(config), stream_(keyword_spotter_.CreateStream()) {}
void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
if (input_sample_rate_ == -1) {
input_sample_rate_ = sample_rate;
}
stream_->AcceptWaveform(sample_rate, samples, n);
}
void InputFinished() const {
std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
tail_padding.size());
stream_->InputFinished();
}
  // If keywords is an empty string, it just recreates the decoding stream;
  // it always returns true in this case.
  // If keywords is not empty, it creates a new decoding stream with
  // the given keywords appended to the default keywords.
  // Returns false if errors occurred while adding keywords, true otherwise.
bool Reset(const std::string &keywords = {}) {
if (keywords.empty()) {
stream_ = keyword_spotter_.CreateStream();
return true;
} else {
auto stream = keyword_spotter_.CreateStream(keywords);
      // Setting the new keywords failed; stream_ is not updated.
if (stream == nullptr) {
return false;
} else {
stream_ = std::move(stream);
return true;
}
}
}
std::string GetKeyword() const {
auto result = keyword_spotter_.GetResult(stream_.get());
return result.keyword;
}
std::vector<std::string> GetTokens() const {
auto result = keyword_spotter_.GetResult(stream_.get());
return result.tokens;
}
bool IsReady() const { return keyword_spotter_.IsReady(stream_.get()); }
void Decode() const { keyword_spotter_.DecodeStream(stream_.get()); }
private:
KeywordSpotter keyword_spotter_;
std::unique_ptr<OnlineStream> stream_;
int32_t input_sample_rate_ = -1;
};
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
OnlineRecognizerConfig ans;
... ... @@ -457,6 +525,103 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
return ans;
}
static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
KeywordSpotterConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
//---------- decoding ----------
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.keywords_file = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "keywordsScore", "F");
ans.keywords_score = env->GetFloatField(config, fid);
fid = env->GetFieldID(cls, "keywordsThreshold", "F");
ans.keywords_threshold = env->GetFloatField(config, fid);
fid = env->GetFieldID(cls, "numTrailingBlanks", "I");
ans.num_trailing_blanks = env->GetIntField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
VadModelConfig ans;
... ... @@ -1013,7 +1178,124 @@ JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(
jclass stringClass = env->FindClass("java/lang/String");
// convert C++ list into jni string array
jobjectArray result = env->NewObjectArray(size, stringClass, NULL);
jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
for (int32_t i = 0; i < size; i++) {
// Convert the C++ string to a C string
const char *cstr = tokens[i].c_str();
// Convert the C string to a jstring
jstring jstr = env->NewStringUTF(cstr);
// Set the array element
env->SetObjectArrayElement(result, i, jstr);
}
return result;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_new(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetKwsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnxKws(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetKwsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnxKws(config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_isReady(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
return model->IsReady();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_decode(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
model->Decode();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jint sample_rate) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
model->AcceptWaveform(sample_rate, p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_inputFinished(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->InputFinished();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getKeyword(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
// see
// https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
auto text = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetKeyword();
return env->NewStringUTF(text.c_str());
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_reset(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);
std::string keywords_str = p_keywords;
bool status =
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->Reset(keywords_str);
env->ReleaseStringUTFChars(keywords, p_keywords);
return status;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getTokens(JNIEnv *env, jobject /*obj*/,
jlong ptr) {
auto tokens =
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetTokens();
int32_t size = tokens.size();
jclass stringClass = env->FindClass("java/lang/String");
// convert C++ list into jni string array
jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
for (int32_t i = 0; i < size; i++) {
// Convert the C++ string to a C string
const char *cstr = tokens[i].c_str();
... ...
... ... @@ -28,9 +28,14 @@ def cli():
)
@click.option(
"--tokens-type",
type=str,
type=click.Choice(
["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"], case_sensitive=True
),
required=True,
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin.
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
ppinyin means partial pinyin, it splits pinyin into initial and final,
""",
)
@click.option(
"--bpe-model",
... ... @@ -42,14 +47,56 @@ def encode_text(
):
"""
Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
    Each line in the texts contains the original phrase; it might also contain some
    extra items, for example, the boosting score (starting with :), the triggering
    threshold (starting with #, only used in the keyword spotting task) and the original
    phrase (starting with @). Note: the extra items are kept unchanged in the output.
example input 1 (tokens_type = ppinyin):
小爱同学 :2.0 #0.6 @小爱同学
你好问问 :3.5 @你好问问
小艺小艺 #0.6 @小艺小艺
example output 1:
x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
n ǐ h ǎo w èn w èn :3.5 @你好问问
x iǎo y ì x iǎo y ì #0.6 @小艺小艺
example input 2 (tokens_type = bpe):
HELLO WORLD :1.5 #0.4
HI GOOGLE :2.0 #0.8
HEY SIRI #0.35
example output 2:
▁HE LL O ▁WORLD :1.5 #0.4
▁HI ▁GO O G LE :2.0 #0.8
▁HE Y ▁S I RI #0.35
"""
texts = []
    # extra information like the boosting score (starting with :), the triggering
    # threshold (starting with #) and the original keyword (starting with @)
extra_info = []
with open(input, "r", encoding="utf8") as f:
for line in f:
texts.append(line.strip())
extra = []
text = []
toks = line.strip().split()
for tok in toks:
if tok[0] == ":" or tok[0] == "#" or tok[0] == "@":
extra.append(tok)
else:
text.append(tok)
texts.append(" ".join(text))
extra_info.append(extra)
encoded_texts = text2token(
texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
)
with open(output, "w", encoding="utf8") as f:
for txt in encoded_texts:
for i, txt in enumerate(encoded_texts):
txt += extra_info[i]
f.write(" ".join(txt) + "\n")
... ...
... ... @@ -6,6 +6,9 @@ from typing import List, Optional, Union
import sentencepiece as spm
from pypinyin import pinyin
from pypinyin.contrib.tone_convert import to_initials, to_finals_tone
def text2token(
texts: List[str],
... ... @@ -23,7 +26,9 @@ def text2token(
tokens:
The path of the tokens.txt.
tokens_type:
The valid values are cjkchar, bpe, cjkchar+bpe.
      The valid values are cjkchar, bpe, cjkchar+bpe, fpinyin, ppinyin.
      fpinyin means full pinyin: each cjkchar has a pinyin (with tone).
      ppinyin means partial pinyin: it splits each pinyin into an initial and a final.
bpe_model:
The path of the bpe model. Only required when tokens_type is bpe or
cjkchar+bpe.
... ... @@ -53,6 +58,24 @@ def text2token(
texts_list = [list("".join(text.split())) for text in texts]
elif tokens_type == "bpe":
texts_list = sp.encode(texts, out_type=str)
elif "pinyin" in tokens_type:
for txt in texts:
py = [x[0] for x in pinyin(txt)]
if "ppinyin" == tokens_type:
res = []
for x in py:
initial = to_initials(x, strict=False)
final = to_finals_tone(x, strict=False)
if initial == "" and final == "":
res.append(x)
else:
if initial != "":
res.append(initial)
if final != "":
res.append(final)
texts_list.append(res)
else:
texts_list.append(py)
else:
assert (
tokens_type == "cjkchar+bpe"
... ...