decoder for open vocabulary keyword spotting (#505)
* various fixes to ContextGraph to support open vocabulary keywords decoder
* Add keyword spotter runtime
* Add binary
* First version works
* Minor fixes
* update text2token
* default values
* Add jni for kws
* add kws android project
* Minor fixes
* Remove unused interface
* Minor fixes
* Add workflow
* handle extra info in texts
* Minor fixes
* Add more comments
* Fix ci
* fix cpp style
* Add input box in android demo so that users can specify their keywords
* Fix cpp style
* Fix comments
* Minor fixes
* Minor fixes
* minor fixes
* Minor fixes
* Minor fixes
* Add CI
* Fix code style
* cpplint
* Fix comments
* Fix error
Showing 77 changed files with 3,316 additions and 68 deletions.
.github/scripts/test-kws.sh
0 → 100755
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+echo "EXE is $EXE"
+echo "PATH: $PATH"
+
+which $EXE
+
+log "------------------------------------------------------------"
+log "Run Chinese keyword spotting (Wenetspeech)"
+log "------------------------------------------------------------"
+
+repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.git
+log "Start testing ${repo_url}"
+repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
+log "Download pretrained model and test-data from $repo_url"
+
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo
+git lfs pull --include "*.onnx"
+ls -lh *.onnx
+popd
+
+time $EXE \
+  --tokens=$repo/tokens.txt \
+  --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
+  --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
+  --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
+  --keywords-file=$repo/test_wavs/test_keywords.txt \
+  --max-active-paths=4 \
+  --num-threads=4 \
+  $repo/test_wavs/3.wav $repo/test_wavs/4.wav $repo/test_wavs/5.wav $repo/test_wavs/6.wav
+
+rm -rf $repo
+
+log "------------------------------------------------------------"
+log "Run English keyword spotting (Gigaspeech)"
+log "------------------------------------------------------------"
+
+repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.git
+log "Start testing ${repo_url}"
+repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
+log "Download pretrained model and test-data from $repo_url"
+
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo
+git lfs pull --include "*.onnx"
+ls -lh *.onnx
+popd
+
+time $EXE \
+  --tokens=$repo/tokens.txt \
+  --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
+  --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
+  --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
+  --keywords-file=$repo/test_wavs/test_keywords.txt \
+  --max-active-paths=4 \
+  --num-threads=4 \
+  $repo/test_wavs/0.wav $repo/test_wavs/1.wav
+
+rm -rf $repo
.github/workflows/apk-kws.yaml
0 → 100644
+name: apk-kws
+
+on:
+  push:
+    branches:
+      - apk-kws
+    tags:
+      - '*'
+
+  workflow_dispatch:
+
+concurrency:
+  group: apk-kws-${{ github.ref }}
+  cancel-in-progress: true
+
+permissions:
+  contents: write
+
+jobs:
+  apk:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: ccache
+        uses: hendrikmuhs/ccache-action@v1.2
+        with:
+          key: ${{ matrix.os }}-android
+
+      - name: Display NDK HOME
+        shell: bash
+        run: |
+          echo "ANDROID_NDK_LATEST_HOME: ${ANDROID_NDK_LATEST_HOME}"
+          ls -lh ${ANDROID_NDK_LATEST_HOME}
+
+      - name: build APK
+        shell: bash
+        run: |
+          export CMAKE_CXX_COMPILER_LAUNCHER=ccache
+          export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
+          cmake --version
+
+          export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
+          ./build-kws-apk.sh
+
+      - name: Display APK
+        shell: bash
+        run: |
+          ls -lh ./apks/
+
+      - uses: actions/upload-artifact@v3
+        with:
+          path: ./apks/*.apk
+
+      - name: Release APK
+        uses: svenstaro/upload-release-action@v2
+        with:
+          file_glob: true
+          file: apks/*.apk
+          overwrite: true
@@ -107,6 +107,14 @@ jobs:
           name: release-static
           path: build/bin/*
 
+      - name: Test transducer kws
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-keyword-spotter
+
+          .github/scripts/test-kws.sh
+
       - name: Test online CTC
         shell: bash
         run: |
@@ -98,6 +98,14 @@ jobs:
           otool -L build/bin/sherpa-onnx
           otool -l build/bin/sherpa-onnx
 
+      - name: Test transducer kws
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-keyword-spotter
+
+          .github/scripts/test-kws.sh
+
       - name: Test online CTC
         shell: bash
         run: |
@@ -106,7 +114,6 @@ jobs:
 
           .github/scripts/test-online-ctc.sh
 
-
       - name: Test offline TTS
         shell: bash
         run: |
@@ -62,7 +62,7 @@ jobs:
       - name: Install Python dependencies
         shell: bash
         run: |
-          python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96 soundfile
+          python3 -m pip install --upgrade pip numpy pypinyin sentencepiece==0.1.96 soundfile
 
       - name: Install sherpa-onnx
         shell: bash
@@ -45,7 +45,7 @@ jobs:
       - name: Install Python dependencies
         shell: bash
         run: |
-          python3 -m pip install --upgrade pip numpy sentencepiece
+          python3 -m pip install --upgrade pip numpy pypinyin sentencepiece
 
       - name: Install sherpa-onnx
         shell: bash
@@ -45,7 +45,7 @@ jobs:
       - name: Install Python dependencies
         shell: bash
         run: |
-          python3 -m pip install --upgrade pip numpy sentencepiece
+          python3 -m pip install --upgrade pip numpy pypinyin sentencepiece
 
       - name: Install sherpa-onnx
         shell: bash
android/SherpaOnnxKws/.gitignore
0 → 100644
android/SherpaOnnxKws/app/.gitignore
0 → 100644
+/build
android/SherpaOnnxKws/app/build.gradle
0 → 100644
+plugins {
+    id 'com.android.application'
+    id 'org.jetbrains.kotlin.android'
+}
+
+android {
+    namespace 'com.k2fsa.sherpa.onnx'
+    compileSdk 32
+
+    defaultConfig {
+        applicationId "com.k2fsa.sherpa.onnx"
+        minSdk 21
+        targetSdk 32
+        versionCode 1
+        versionName "1.0"
+
+        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
+    }
+
+    buildTypes {
+        release {
+            minifyEnabled false
+            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
+        }
+    }
+    compileOptions {
+        sourceCompatibility JavaVersion.VERSION_1_8
+        targetCompatibility JavaVersion.VERSION_1_8
+    }
+    kotlinOptions {
+        jvmTarget = '1.8'
+    }
+}
+
+dependencies {
+    implementation 'androidx.core:core-ktx:1.7.0'
+    implementation 'androidx.appcompat:appcompat:1.5.1'
+    implementation 'com.google.android.material:material:1.7.0'
+    implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
+    testImplementation 'junit:junit:4.13.2'
+    androidTestImplementation 'androidx.test.ext:junit:1.1.4'
+    androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0'
+}
android/SherpaOnnxKws/app/proguard-rules.pro
0 → 100644
+# Add project specific ProGuard rules here.
+# You can control the set of applied configuration files using the
+# proguardFiles setting in build.gradle.
+#
+# For more details, see
+#   http://developer.android.com/guide/developing/tools/proguard.html
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+#   public *;
+#}
+
+# Uncomment this to preserve the line number information for
+# debugging stack traces.
+#-keepattributes SourceFile,LineNumberTable
+
+# If you keep the line number information, uncomment this to
+# hide the original source file name.
+#-renamesourcefileattribute SourceFile
android/SherpaOnnxKws/app/src/androidTest/java/com/k2fsa/sherpa/onnx/ExampleInstrumentedTest.kt
0 → 100644
+package com.k2fsa.sherpa.onnx
+
+import androidx.test.platform.app.InstrumentationRegistry
+import androidx.test.ext.junit.runners.AndroidJUnit4
+
+import org.junit.Test
+import org.junit.runner.RunWith
+
+import org.junit.Assert.*
+
+/**
+ * Instrumented test, which will execute on an Android device.
+ *
+ * See [testing documentation](http://d.android.com/tools/testing).
+ */
+@RunWith(AndroidJUnit4::class)
+class ExampleInstrumentedTest {
+    @Test
+    fun useAppContext() {
+        // Context of the app under test.
+        val appContext = InstrumentationRegistry.getInstrumentation().targetContext
+        assertEquals("com.k2fsa.sherpa.onnx", appContext.packageName)
+    }
+}
android/SherpaOnnxKws/app/src/main/AndroidManifest.xml
0 → 100644
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+    xmlns:tools="http://schemas.android.com/tools">
+
+    <uses-permission android:name="android.permission.RECORD_AUDIO" />
+
+    <application
+        android:allowBackup="true"
+        android:dataExtractionRules="@xml/data_extraction_rules"
+        android:fullBackupContent="@xml/backup_rules"
+        android:icon="@mipmap/ic_launcher"
+        android:label="@string/app_name"
+        android:roundIcon="@mipmap/ic_launcher_round"
+        android:supportsRtl="true"
+        android:theme="@style/Theme.SherpaOnnx"
+        tools:targetApi="31">
+        <activity
+            android:name=".MainActivity"
+            android:exported="true">
+            <intent-filter>
+                <action android:name="android.intent.action.MAIN" />
+
+                <category android:name="android.intent.category.LAUNCHER" />
+            </intent-filter>
+
+            <meta-data
+                android:name="android.app.lib_name"
+                android:value="" />
+        </activity>
+    </application>
+
+</manifest>
android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
0 → 100644
+package com.k2fsa.sherpa.onnx
+
+import android.Manifest
+import android.content.pm.PackageManager
+import android.media.AudioFormat
+import android.media.AudioRecord
+import android.media.MediaRecorder
+import android.os.Bundle
+import android.text.method.ScrollingMovementMethod
+import android.util.Log
+import android.widget.Button
+import android.widget.EditText
+import android.widget.TextView
+import android.widget.Toast
+import androidx.appcompat.app.AppCompatActivity
+import androidx.core.app.ActivityCompat
+import com.k2fsa.sherpa.onnx.*
+import kotlin.concurrent.thread
+
+private const val TAG = "sherpa-onnx"
+private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
+
+class MainActivity : AppCompatActivity() {
+    private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
+
+    private lateinit var model: SherpaOnnxKws
+    private var audioRecord: AudioRecord? = null
+    private lateinit var recordButton: Button
+    private lateinit var textView: TextView
+    private lateinit var inputText: EditText
+    private var recordingThread: Thread? = null
+
+    private val audioSource = MediaRecorder.AudioSource.MIC
+    private val sampleRateInHz = 16000
+    private val channelConfig = AudioFormat.CHANNEL_IN_MONO
+
+    // Note: We don't use AudioFormat.ENCODING_PCM_FLOAT
+    // since AudioRecord.read(float[]) needs API level >= 23
+    // but we are targeting API level >= 21
+    private val audioFormat = AudioFormat.ENCODING_PCM_16BIT
+    private var idx: Int = 0
+    private var lastText: String = ""
+
+    @Volatile
+    private var isRecording: Boolean = false
+
+    override fun onRequestPermissionsResult(
+        requestCode: Int, permissions: Array<String>, grantResults: IntArray
+    ) {
+        super.onRequestPermissionsResult(requestCode, permissions, grantResults)
+        val permissionToRecordAccepted = if (requestCode == REQUEST_RECORD_AUDIO_PERMISSION) {
+            grantResults[0] == PackageManager.PERMISSION_GRANTED
+        } else {
+            false
+        }
+
+        if (!permissionToRecordAccepted) {
+            Log.e(TAG, "Audio record is disallowed")
+            finish()
+            return // finish() alone does not stop execution of this method
+        }
+
+        Log.i(TAG, "Audio record is permitted")
+    }
+
+    override fun onCreate(savedInstanceState: Bundle?) {
+        super.onCreate(savedInstanceState)
+        setContentView(R.layout.activity_main)
+
+        ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
+
+        Log.i(TAG, "Start to initialize model")
+        initModel()
+        Log.i(TAG, "Finished initializing model")
+
+        recordButton = findViewById(R.id.record_button)
+        recordButton.setOnClickListener { onclick() }
+
+        textView = findViewById(R.id.my_text)
+        textView.movementMethod = ScrollingMovementMethod()
+
+        inputText = findViewById(R.id.input_text)
+    }
+
+    private fun onclick() {
+        if (!isRecording) {
+            var keywords = inputText.text.toString()
+
+            Log.i(TAG, keywords)
+            keywords = keywords.replace("\n", "/")
+            // If keywords is an empty string, reset() just resets the decoding
+            // stream and always returns true.
+            // If keywords is not empty, it creates a new decoding stream with
+            // the given keywords appended to the default keywords.
+            // It returns false if errors occurred when adding keywords,
+            // true otherwise.
+            val status = model.reset(keywords)
+            if (!status) {
+                Log.i(TAG, "Failed to reset with keywords.")
+                Toast.makeText(this, "Failed to set keywords.", Toast.LENGTH_LONG).show()
+                return
+            }
+
+            val ret = initMicrophone()
+            if (!ret) {
+                Log.e(TAG, "Failed to initialize microphone")
+                return
+            }
+            Log.i(TAG, "state: ${audioRecord?.state}")
+            audioRecord!!.startRecording()
+            recordButton.setText(R.string.stop)
+            isRecording = true
+            textView.text = ""
+            lastText = ""
+            idx = 0
+
+            recordingThread = thread(true) {
+                processSamples()
+            }
+            Log.i(TAG, "Started recording")
+        } else {
+            isRecording = false
+            audioRecord!!.stop()
+            audioRecord!!.release()
+            audioRecord = null
+            recordButton.setText(R.string.start)
+            Log.i(TAG, "Stopped recording")
+        }
+    }
+
+    private fun processSamples() {
+        Log.i(TAG, "processing samples")
+
+        val interval = 0.1 // i.e., 100 ms
+        val bufferSize = (interval * sampleRateInHz).toInt() // in samples
+        val buffer = ShortArray(bufferSize)
+
+        while (isRecording) {
+            val ret = audioRecord?.read(buffer, 0, buffer.size)
+            if (ret != null && ret > 0) {
+                // Convert 16-bit PCM samples to floats in [-1, 1]
+                val samples = FloatArray(ret) { buffer[it] / 32768.0f }
+                model.acceptWaveform(samples, sampleRate = sampleRateInHz)
+                while (model.isReady()) {
+                    model.decode()
+                }
+
+                val text = model.keyword
+
+                var textToDisplay = lastText
+
+                if (text.isNotBlank()) {
+                    textToDisplay = if (lastText.isBlank()) {
+                        "${idx}: ${text}"
+                    } else {
+                        "${idx}: ${text}\n${lastText}"
+                    }
+                    lastText = "${idx}: ${text}\n${lastText}"
+                    idx += 1
+                }
+
+                runOnUiThread {
+                    textView.text = textToDisplay
+                }
+            }
+        }
+    }
+
+    private fun initMicrophone(): Boolean {
+        if (ActivityCompat.checkSelfPermission(
+                this, Manifest.permission.RECORD_AUDIO
+            ) != PackageManager.PERMISSION_GRANTED
+        ) {
+            ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
+            return false
+        }
+
+        val numBytes = AudioRecord.getMinBufferSize(sampleRateInHz, channelConfig, audioFormat)
+        Log.i(
+            TAG, "buffer size in milliseconds: ${numBytes * 1000.0f / sampleRateInHz}"
+        )
+
+        audioRecord = AudioRecord(
+            audioSource,
+            sampleRateInHz,
+            channelConfig,
+            audioFormat,
+            numBytes * 2 // a sample has two bytes since we are using 16-bit PCM
+        )
+        return true
+    }
+
+    private fun initModel() {
+        // Please change getModelConfig() to add new models
+        // See https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
+        // for a list of available models
+        val type = 0
+        Log.i(TAG, "Select model type ${type}")
+        val config = KeywordSpotterConfig(
+            featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
+            modelConfig = getModelConfig(type = type)!!,
+            keywordsFile = getKeywordsFile(type = type),
+        )
+
+        model = SherpaOnnxKws(
+            assetManager = application.assets,
+            config = config,
+        )
+    }
+}
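As the onclick() handler above shows, keywords typed on separate lines in the input box are joined with "/" before being passed to reset(). A minimal Kotlin sketch of that contract (the keyword strings here are made-up examples):

    // Hypothetical input: two keywords, one per line, as typed into the EditText
    val userInput = "HELLO WORLD\nHEY KALDI"
    // reset() expects "/"-separated keywords; an empty string just resets the stream
    val ok = model.reset(userInput.replace("\n", "/")) // -> "HELLO WORLD/HEY KALDI"
    if (!ok) {
        // reset() returns false when the keywords could not be added
    }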
android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnxKws.kt
0 → 100644
+// Copyright (c) 2024 Xiaomi Corporation
+package com.k2fsa.sherpa.onnx
+
+import android.content.res.AssetManager
+
+data class OnlineTransducerModelConfig(
+    var encoder: String = "",
+    var decoder: String = "",
+    var joiner: String = "",
+)
+
+data class OnlineModelConfig(
+    var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
+    var tokens: String,
+    var numThreads: Int = 1,
+    var debug: Boolean = false,
+    var provider: String = "cpu",
+    var modelType: String = "",
+)
+
+data class FeatureConfig(
+    var sampleRate: Int = 16000,
+    var featureDim: Int = 80,
+)
+
+data class KeywordSpotterConfig(
+    var featConfig: FeatureConfig = FeatureConfig(),
+    var modelConfig: OnlineModelConfig,
+    var maxActivePaths: Int = 4,
+    var keywordsFile: String = "keywords.txt",
+    var keywordsScore: Float = 1.5f,
+    var keywordsThreshold: Float = 0.25f,
+    var numTrailingBlanks: Int = 2,
+)
+
+class SherpaOnnxKws(
+    assetManager: AssetManager? = null,
+    var config: KeywordSpotterConfig,
+) {
+    private val ptr: Long
+
+    init {
+        ptr = if (assetManager != null) {
+            new(assetManager, config)
+        } else {
+            newFromFile(config)
+        }
+    }
+
+    protected fun finalize() {
+        delete(ptr)
+    }
+
+    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
+        acceptWaveform(ptr, samples, sampleRate)
+
+    fun inputFinished() = inputFinished(ptr)
+    fun decode() = decode(ptr)
+    fun isReady(): Boolean = isReady(ptr)
+    fun reset(keywords: String): Boolean = reset(ptr, keywords)
+
+    val keyword: String
+        get() = getKeyword(ptr)
+
+    private external fun delete(ptr: Long)
+
+    private external fun new(
+        assetManager: AssetManager,
+        config: KeywordSpotterConfig,
+    ): Long
+
+    private external fun newFromFile(
+        config: KeywordSpotterConfig,
+    ): Long
+
+    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
+    private external fun inputFinished(ptr: Long)
+    private external fun getKeyword(ptr: Long): String
+    private external fun reset(ptr: Long, keywords: String): Boolean
+    private external fun decode(ptr: Long)
+    private external fun isReady(ptr: Long): Boolean
+
+    companion object {
+        init {
+            System.loadLibrary("sherpa-onnx-jni")
+        }
+    }
+}
+
+fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
+    return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
+}
+
+/*
+Please see
+https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
+for a list of pre-trained models.
+
+We only add a few here. Please change the following code
+to add your own. (It should be straightforward to add a new model
+by following the code)
+
+@param type
+0 - sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 (Chinese)
+    https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/summary
+
+1 - sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 (English)
+    https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/summary
+ */
+fun getModelConfig(type: Int): OnlineModelConfig? {
+    when (type) {
+        0 -> {
+            val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
+            return OnlineModelConfig(
+                transducer = OnlineTransducerModelConfig(
+                    encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
+                    decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
+                    joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
+                ),
+                tokens = "$modelDir/tokens.txt",
+                modelType = "zipformer2",
+            )
+        }
+
+        1 -> {
+            val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"
+            return OnlineModelConfig(
+                transducer = OnlineTransducerModelConfig(
+                    encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
+                    decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
+                    joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
+                ),
+                tokens = "$modelDir/tokens.txt",
+                modelType = "zipformer2",
+            )
+        }
+    }
+    return null
+}
+
+/*
+ * Get the default keywords file for each model.
+ * Caution: The types and modelDir should be the same as those in the
+ * getModelConfig function above.
+ */
+fun getKeywordsFile(type: Int): String {
+    when (type) {
+        0 -> {
+            val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
+            return "$modelDir/keywords.txt"
+        }
+
+        1 -> {
+            val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"
+            return "$modelDir/keywords.txt"
+        }
+    }
+    return ""
+}
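For reference, a minimal sketch of how SherpaOnnxKws is driven, mirroring the loop in MainActivity.processSamples above (the 16 kHz rate and [-1, 1] float range are the same assumptions used there):

    // Assumes `model` was built as in MainActivity.initModel() and `samples`
    // holds 16 kHz mono audio as a FloatArray with values in [-1, 1].
    model.acceptWaveform(samples, sampleRate = 16000)
    while (model.isReady()) {
        model.decode()
    }
    val detected = model.keyword // non-blank right after a keyword is spotted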
android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt
0 → 100644
+// Copyright (c) 2023 Xiaomi Corporation
+package com.k2fsa.sherpa.onnx
+
+import android.content.res.AssetManager
+
+class WaveReader {
+    companion object {
+        // Read a mono wave file asset
+        // The returned array has two entries:
+        //  - the first entry contains a 1-D float array
+        //  - the second entry is the sample rate
+        external fun readWaveFromAsset(
+            assetManager: AssetManager,
+            filename: String,
+        ): Array<Any>
+
+        // Read a mono wave file from disk
+        // The returned array has two entries:
+        //  - the first entry contains a 1-D float array
+        //  - the second entry is the sample rate
+        external fun readWaveFromFile(
+            filename: String,
+        ): Array<Any>
+
+        init {
+            System.loadLibrary("sherpa-onnx-jni")
+        }
+    }
+}
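Because both readers return an untyped Array<Any>, the caller has to cast the two entries itself. A small sketch under that assumption (the wave path is hypothetical):

    // data[0] holds the 1-D FloatArray of samples; data[1] holds the sample rate
    val data = WaveReader.readWaveFromFile("/path/to/test.wav") // hypothetical path
    val samples = data[0] as FloatArray
    val sampleRate = data[1] as Int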
+<vector xmlns:android="http://schemas.android.com/apk/res/android"
+    xmlns:aapt="http://schemas.android.com/aapt"
+    android:width="108dp"
+    android:height="108dp"
+    android:viewportWidth="108"
+    android:viewportHeight="108">
+    <path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
+        <aapt:attr name="android:fillColor">
+            <gradient
+                android:endX="85.84757"
+                android:endY="92.4963"
+                android:startX="42.9492"
+                android:startY="49.59793"
+                android:type="linear">
+                <item
+                    android:color="#44000000"
+                    android:offset="0.0" />
+                <item
+                    android:color="#00000000"
+                    android:offset="1.0" />
+            </gradient>
+        </aapt:attr>
+    </path>
+    <path
+        android:fillColor="#FFFFFF"
+        android:fillType="nonZero"
+        android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
+        android:strokeWidth="1"
+        android:strokeColor="#00000000" />
+</vector>
+<?xml version="1.0" encoding="utf-8"?>
+<vector xmlns:android="http://schemas.android.com/apk/res/android"
+    android:width="108dp"
+    android:height="108dp"
+    android:viewportWidth="108"
+    android:viewportHeight="108">
+    <path
+        android:fillColor="#3DDC84"
+        android:pathData="M0,0h108v108h-108z" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M9,0L9,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,0L19,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M29,0L29,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M39,0L39,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M49,0L49,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M59,0L59,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M69,0L69,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M79,0L79,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M89,0L89,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M99,0L99,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,9L108,9"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,19L108,19"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,29L108,29"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,39L108,39"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,49L108,49"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,59L108,59"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,69L108,69"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,79L108,79"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,89L108,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,99L108,99"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,29L89,29"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,39L89,39"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,49L89,49"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,59L89,59"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,69L89,69"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,79L89,79"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M29,19L29,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M39,19L39,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M49,19L49,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M59,19L59,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M69,19L69,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M79,19L79,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+</vector>
android/SherpaOnnxKws/app/src/main/res/layout/activity_main.xml
0 → 100644
+<?xml version="1.0" encoding="utf-8"?>
+<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
+    xmlns:app="http://schemas.android.com/apk/res-auto"
+    xmlns:tools="http://schemas.android.com/tools"
+    android:layout_width="match_parent"
+    android:layout_height="match_parent"
+    tools:context=".MainActivity">
+
+    <LinearLayout
+        android:layout_width="match_parent"
+        android:layout_height="match_parent"
+        android:gravity="center"
+        android:orientation="vertical">
+
+        <EditText
+            android:id="@+id/input_text"
+            android:layout_width="match_parent"
+            android:layout_height="320dp"
+            android:layout_weight="2.5"
+            android:hint="@string/keyword_hint"
+            android:scrollbars="vertical"
+            android:text=""
+            android:textSize="15dp" />
+
+        <TextView
+            android:id="@+id/my_text"
+            android:layout_width="match_parent"
+            android:layout_height="443dp"
+            android:layout_weight="2.5"
+            android:padding="24dp"
+            android:scrollbars="vertical"
+            android:singleLine="false"
+            android:text="@string/hint"
+            android:textSize="15dp" />
+
+        <Button
+            android:id="@+id/record_button"
+            android:layout_width="wrap_content"
+            android:layout_height="wrap_content"
+            android:layout_weight="0.5"
+            android:text="@string/start" />
+
+    </LinearLayout>
+
+</androidx.constraintlayout.widget.ConstraintLayout>
(10 binary files: preview not available)
android/SherpaOnnxKws/app/src/main/res/values-night/themes.xml
0 → 100644
+<resources xmlns:tools="http://schemas.android.com/tools">
+    <!-- Base application theme. -->
+    <style name="Theme.SherpaOnnx" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
+        <!-- Primary brand color. -->
+        <item name="colorPrimary">@color/purple_200</item>
+        <item name="colorPrimaryVariant">@color/purple_700</item>
+        <item name="colorOnPrimary">@color/black</item>
+        <!-- Secondary brand color. -->
+        <item name="colorSecondary">@color/teal_200</item>
+        <item name="colorSecondaryVariant">@color/teal_200</item>
+        <item name="colorOnSecondary">@color/black</item>
+        <!-- Status bar color. -->
+        <item name="android:statusBarColor">?attr/colorPrimaryVariant</item>
+        <!-- Customize your theme here. -->
+    </style>
+</resources>
android/SherpaOnnxKws/app/src/main/res/values/colors.xml
0 → 100644
+<?xml version="1.0" encoding="utf-8"?>
+<resources>
+    <color name="purple_200">#FFBB86FC</color>
+    <color name="purple_500">#FF6200EE</color>
+    <color name="purple_700">#FF3700B3</color>
+    <color name="teal_200">#FF03DAC5</color>
+    <color name="teal_700">#FF018786</color>
+    <color name="black">#FF000000</color>
+    <color name="white">#FFFFFFFF</color>
+</resources>
android/SherpaOnnxKws/app/src/main/res/values/strings.xml
0 → 100644
+<resources>
+    <string name="app_name">KWS with Next-gen Kaldi</string>
+    <string name="hint">Click the Start button to try keyword spotting with Next-gen Kaldi.
+        \n
+        \n\n\n
+        The source code and pre-trained models are publicly available.
+        Please see https://github.com/k2-fsa/sherpa-onnx for details.
+    </string>
+    <string name="keyword_hint">Input your keywords here, one keyword per line.</string>
+    <string name="start">Start</string>
+    <string name="stop">Stop</string>
+</resources>
android/SherpaOnnxKws/app/src/main/res/values/themes.xml
0 → 100644
+<resources xmlns:tools="http://schemas.android.com/tools">
+    <!-- Base application theme. -->
+    <style name="Theme.SherpaOnnx" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
+        <!-- Primary brand color. -->
+        <item name="colorPrimary">@color/purple_500</item>
+        <item name="colorPrimaryVariant">@color/purple_700</item>
+        <item name="colorOnPrimary">@color/white</item>
+        <!-- Secondary brand color. -->
+        <item name="colorSecondary">@color/teal_200</item>
+        <item name="colorSecondaryVariant">@color/teal_700</item>
+        <item name="colorOnSecondary">@color/black</item>
+        <!-- Status bar color. -->
+        <item name="android:statusBarColor">?attr/colorPrimaryVariant</item>
+        <!-- Customize your theme here. -->
+    </style>
+</resources>
android/SherpaOnnxKws/app/src/main/res/xml/backup_rules.xml
0 → 100644
+<?xml version="1.0" encoding="utf-8"?><!--
+   Sample backup rules file; uncomment and customize as necessary.
+   See https://developer.android.com/guide/topics/data/autobackup
+   for details.
+   Note: This file is ignored for devices older than API 31
+   See https://developer.android.com/about/versions/12/backup-restore
+-->
+<full-backup-content>
+    <!--
+    <include domain="sharedpref" path="."/>
+    <exclude domain="sharedpref" path="device.xml"/>
+-->
+</full-backup-content>
android/SherpaOnnxKws/app/src/main/res/xml/data_extraction_rules.xml
0 → 100644
+<?xml version="1.0" encoding="utf-8"?><!--
+   Sample data extraction rules file; uncomment and customize as necessary.
+   See https://developer.android.com/about/versions/12/backup-restore#xml-changes
+   for details.
+-->
+<data-extraction-rules>
+    <cloud-backup>
+        <!-- TODO: Use <include> and <exclude> to control what is backed up.
+        <include .../>
+        <exclude .../>
+        -->
+    </cloud-backup>
+    <!--
+    <device-transfer>
+        <include .../>
+        <exclude .../>
+    </device-transfer>
+    -->
+</data-extraction-rules>
android/SherpaOnnxKws/app/src/test/java/com/k2fsa/sherpa/onnx/ExampleUnitTest.kt
0 → 100644
+package com.k2fsa.sherpa.onnx
+
+import org.junit.Test
+
+import org.junit.Assert.*
+
+/**
+ * Example local unit test, which will execute on the development machine (host).
+ *
+ * See [testing documentation](http://d.android.com/tools/testing).
+ */
+class ExampleUnitTest {
+    @Test
+    fun addition_isCorrect() {
+        assertEquals(4, 2 + 2)
+    }
+}
android/SherpaOnnxKws/build.gradle
0 → 100644
+// Top-level build file where you can add configuration options common to all sub-projects/modules.
+plugins {
+    id 'com.android.application' version '7.3.1' apply false
+    id 'com.android.library' version '7.3.1' apply false
+    id 'org.jetbrains.kotlin.android' version '1.7.20' apply false
+}
android/SherpaOnnxKws/gradle.properties
0 → 100644
+# Project-wide Gradle settings.
+# IDE (e.g. Android Studio) users:
+# Gradle settings configured through the IDE *will override*
+# any settings specified in this file.
+# For more details on how to configure your build environment visit
+# http://www.gradle.org/docs/current/userguide/build_environment.html
+# Specifies the JVM arguments used for the daemon process.
+# The setting is particularly useful for tweaking memory settings.
+org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
+# When configured, Gradle will run in incubating parallel mode.
+# This option should only be used with decoupled projects. More details, visit
+# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
+# org.gradle.parallel=true
+# AndroidX package structure to make it clearer which packages are bundled with the
+# Android operating system, and which are packaged with your app's APK
+# https://developer.android.com/topic/libraries/support-library/androidx-rn
+android.useAndroidX=true
+# Kotlin code style for this project: "official" or "obsolete":
+kotlin.code.style=official
+# Enables namespacing of each library's R class so that its R class includes only the
+# resources declared in the library itself and none from the library's dependencies,
+# thereby reducing the size of the R class for that library
+android.nonTransitiveRClass=true
(binary file: preview not available)
android/SherpaOnnxKws/gradlew
0 → 100755
+#!/usr/bin/env sh
+
+#
+# Copyright 2015 the original author or authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+##############################################################################
+##
+##  Gradle start up script for UN*X
+##
+##############################################################################
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+    ls=`ls -ld "$PRG"`
+    link=`expr "$ls" : '.*-> \(.*\)$'`
+    if expr "$link" : '/.*' > /dev/null; then
+        PRG="$link"
+    else
+        PRG=`dirname "$PRG"`"/$link"
+    fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn () {
+    echo "$*"
+}
+
+die () {
+    echo
+    echo "$*"
+    echo
+    exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+nonstop=false
+case "`uname`" in
+  CYGWIN* )
+    cygwin=true
+    ;;
+  Darwin* )
+    darwin=true
+    ;;
+  MINGW* )
+    msys=true
+    ;;
+  NONSTOP* )
+    nonstop=true
+    ;;
+esac
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+        # IBM's JDK on AIX uses strange locations for the executables
+        JAVACMD="$JAVA_HOME/jre/sh/java"
+    else
+        JAVACMD="$JAVA_HOME/bin/java"
+    fi
+    if [ ! -x "$JAVACMD" ] ; then
+        die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+    fi
+else
+    JAVACMD="java"
+    which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
+    MAX_FD_LIMIT=`ulimit -H -n`
+    if [ $? -eq 0 ] ; then
+        if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+            MAX_FD="$MAX_FD_LIMIT"
+        fi
+        ulimit -n $MAX_FD
+        if [ $? -ne 0 ] ; then
+            warn "Could not set maximum file descriptor limit: $MAX_FD"
+        fi
+    else
+        warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+    fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+    GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin or MSYS, switch paths to Windows format before running java
+if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
+    APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+    CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+
+    JAVACMD=`cygpath --unix "$JAVACMD"`
+
+    # We build the pattern for arguments to be converted via cygpath
+    ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+    SEP=""
+    for dir in $ROOTDIRSRAW ; do
+        ROOTDIRS="$ROOTDIRS$SEP$dir"
+        SEP="|"
+    done
+    OURCYGPATTERN="(^($ROOTDIRS))"
+    # Add a user-defined pattern to the cygpath arguments
+    if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+        OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+    fi
+    # Now convert the arguments - kludge to limit ourselves to /bin/sh
+    i=0
+    for arg in "$@" ; do
+        CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+        CHECK2=`echo "$arg"|egrep -c "^-"`    ### Determine if an option
+
+        if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then    ### Added a condition
+            eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+        else
+            eval `echo args$i`="\"$arg\""
+        fi
+        i=`expr $i + 1`
+    done
+    case $i in
+        0) set -- ;;
+        1) set -- "$args0" ;;
+        2) set -- "$args0" "$args1" ;;
+        3) set -- "$args0" "$args1" "$args2" ;;
+        4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+        5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+        6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+        7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+        8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+        9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+    esac
+fi
+
+# Escape application args
+save () {
+    for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
+    echo " "
+}
+APP_ARGS=`save "$@"`
+
+# Collect all arguments for the java command, following the shell quoting and substitution rules
+eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
+
+exec "$JAVACMD" "$@"
android/SherpaOnnxKws/gradlew.bat
0 → 100644
+@rem
+@rem Copyright 2015 the original author or authors.
+@rem
+@rem Licensed under the Apache License, Version 2.0 (the "License");
+@rem you may not use this file except in compliance with the License.
+@rem You may obtain a copy of the License at
+@rem
+@rem      https://www.apache.org/licenses/LICENSE-2.0
+@rem
+@rem Unless required by applicable law or agreed to in writing, software
+@rem distributed under the License is distributed on an "AS IS" BASIS,
+@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+@rem See the License for the specific language governing permissions and
+@rem limitations under the License.
+@rem
+
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem  Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Resolve any "." and ".." in APP_HOME to make it shorter.
+for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto execute
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto execute
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
+
+:end
+@rem End local scope for the variables with windows NT shell
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega
android/SherpaOnnxKws/settings.gradle
0 → 100644
| 1 | +pluginManagement { | ||
| 2 | + repositories { | ||
| 3 | + gradlePluginPortal() | ||
| 4 | + google() | ||
| 5 | + mavenCentral() | ||
| 6 | + } | ||
| 7 | +} | ||
| 8 | +dependencyResolutionManagement { | ||
| 9 | + repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) | ||
| 10 | + repositories { | ||
| 11 | + google() | ||
| 12 | + mavenCentral() | ||
| 13 | + } | ||
| 14 | +} | ||
| 15 | +rootProject.name = "SherpaOnnxKws" | ||
| 16 | +include ':app' |
build-kws-apk.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +# Please set the environment variable ANDROID_NDK | ||
| 4 | +# before running this script | ||
| 5 | + | ||
| 6 | +# Inside the $ANDROID_NDK directory, you can find a binary ndk-build | ||
| 7 | +# and some other files like the file "build/cmake/android.toolchain.cmake" | ||
| 8 | + | ||
| 9 | +set -e | ||
| 10 | + | ||
| 11 | +log() { | ||
| 12 | + # This function is from espnet | ||
| 13 | + local fname=${BASH_SOURCE[1]##*/} | ||
| 14 | + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | ||
| 15 | +} | ||
| 16 | + | ||
| 17 | +SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) | ||
| 18 | + | ||
| 19 | +log "Building keyword spotting APK for sherpa-onnx v${SHERPA_ONNX_VERSION}" | ||
| 20 | + | ||
| 21 | +log "====================arm64-v8a=================" | ||
| 22 | +./build-android-arm64-v8a.sh | ||
| 23 | +log "====================armv7-eabi================" | ||
| 24 | +./build-android-armv7-eabi.sh | ||
| 25 | +log "====================x86-64====================" | ||
| 26 | +./build-android-x86-64.sh | ||
| 27 | +log "====================x86====================" | ||
| 28 | +./build-android-x86.sh | ||
| 29 | + | ||
| 30 | +mkdir -p apks | ||
| 31 | + | ||
| 32 | +# Download the model | ||
| 33 | +repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 | ||
| 34 | + | ||
| 35 | +if [ ! -d ./android/SherpaOnnxKws/app/src/main/assets/$repo ]; then | ||
| 36 | + repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.git | ||
| 37 | + log "Start testing ${repo_url}" | ||
| 38 | + log "Download pretrained model and test-data from $repo_url" | ||
| 39 | + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 40 | + pushd $repo | ||
| 41 | + git lfs pull --include "*.onnx" | ||
| 42 | + | ||
| 43 | + # remove .git to save space | ||
| 44 | + rm -rf .git | ||
| 45 | + rm *.int8.onnx | ||
| 46 | + rm README.md configuration.json .gitattributes | ||
| 47 | + rm -rfv test_wavs | ||
| 48 | + ls -lh | ||
| 49 | + popd | ||
| 50 | + | ||
| 51 | + mv -v $repo ./android/SherpaOnnxKws/app/src/main/assets/ | ||
| 52 | +fi | ||
| 53 | + | ||
| 54 | +tree ./android/SherpaOnnxKws/app/src/main/assets/ | ||
| 55 | + | ||
| 56 | +for arch in arm64-v8a armeabi-v7a x86_64 x86; do | ||
| 57 | + log "------------------------------------------------------------" | ||
| 58 | + log "build apk for $arch" | ||
| 59 | + log "------------------------------------------------------------" | ||
| 60 | + src_arch=$arch | ||
| 61 | + if [ $arch == "armeabi-v7a" ]; then | ||
| 62 | + src_arch=armv7-eabi | ||
| 63 | + elif [ $arch == "x86_64" ]; then | ||
| 64 | + src_arch=x86-64 | ||
| 65 | + fi | ||
| 66 | + | ||
| 67 | + ls -lh ./build-android-$src_arch/install/lib/*.so | ||
| 68 | + | ||
| 69 | + cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/ | ||
| 70 | + | ||
| 71 | + pushd ./android/SherpaOnnxKws | ||
| 72 | + ./gradlew build | ||
| 73 | + popd | ||
| 74 | + | ||
| 75 | + mv android/SherpaOnnxKws/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-kws-wenetspeech-zh-${SHERPA_ONNX_VERSION}-$arch.apk | ||
| 76 | + ls -lh apks | ||
| 77 | + rm -v ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/*.so | ||
| 78 | +done | ||
| 79 | + | ||
| 80 | +git checkout . | ||
| 81 | + | ||
| 82 | +rm -rf ./android/SherpaOnnxKws/app/src/main/assets/$repo | ||
| 83 | + | ||
| 84 | +# English model | ||
| 85 | +repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 | ||
| 86 | + | ||
| 87 | +if [ ! -d ./android/SherpaOnnxKws/app/src/main/assets/$repo ]; then | ||
| 88 | + repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.git | ||
| 89 | + log "Start testing ${repo_url}" | ||
| 90 | + log "Download pretrained model and test-data from $repo_url" | ||
| 91 | + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 92 | + pushd $repo | ||
| 93 | + git lfs pull --include "*.onnx" | ||
| 94 | + | ||
| 95 | + # remove .git to save space | ||
| 96 | + rm -rf .git | ||
| 97 | + rm *.int8.onnx | ||
| 98 | + rm README.md configuration.json .gitattributes | ||
| 99 | + rm -rfv test_wavs | ||
| 100 | + ls -lh | ||
| 101 | + popd | ||
| 102 | + | ||
| 103 | + mv -v $repo ./android/SherpaOnnxKws/app/src/main/assets/ | ||
| 104 | +fi | ||
| 105 | + | ||
| 106 | +tree ./android/SherpaOnnxKws/app/src/main/assets/ | ||
| 107 | + | ||
| 108 | +pushd android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx | ||
| 109 | +sed -i.bak s/"type = 0"/"type = 1"/ ./MainActivity.kt | ||
| 110 | +git diff | ||
| 111 | +popd | ||
| 112 | + | ||
| 113 | +for arch in arm64-v8a armeabi-v7a x86_64 x86; do | ||
| 114 | + log "------------------------------------------------------------" | ||
| 115 | + log "build apk for $arch" | ||
| 116 | + log "------------------------------------------------------------" | ||
| 117 | + src_arch=$arch | ||
| 118 | + if [ $arch == "armeabi-v7a" ]; then | ||
| 119 | + src_arch=armv7-eabi | ||
| 120 | + elif [ $arch == "x86_64" ]; then | ||
| 121 | + src_arch=x86-64 | ||
| 122 | + fi | ||
| 123 | + | ||
| 124 | + ls -lh ./build-android-$src_arch/install/lib/*.so | ||
| 125 | + | ||
| 126 | + cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/ | ||
| 127 | + | ||
| 128 | + pushd ./android/SherpaOnnxKws | ||
| 129 | + ./gradlew build | ||
| 130 | + popd | ||
| 131 | + | ||
| 132 | + mv android/SherpaOnnxKws/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-kws-gigaspeech-en-${SHERPA_ONNX_VERSION}-$arch.apk | ||
| 133 | + ls -lh apks | ||
| 134 | + rm -v ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/*.so | ||
| 135 | +done | ||
| 136 | + | ||
| 137 | +git checkout . | ||
| 138 | + | ||
| 139 | +rm -rf ./android/SherpaOnnxKws/app/src/main/assets/$repo |
| @@ -151,6 +151,7 @@ class BuildExtension(build_ext): | @@ -151,6 +151,7 @@ class BuildExtension(build_ext): | ||
| 151 | # Remember to also change setup.py | 151 | # Remember to also change setup.py |
| 152 | 152 | ||
| 153 | binaries = ["sherpa-onnx"] | 153 | binaries = ["sherpa-onnx"] |
| 154 | + binaries += ["sherpa-onnx-keyword-spotter"] | ||
| 154 | binaries += ["sherpa-onnx-offline"] | 155 | binaries += ["sherpa-onnx-offline"] |
| 155 | binaries += ["sherpa-onnx-microphone"] | 156 | binaries += ["sherpa-onnx-microphone"] |
| 156 | binaries += ["sherpa-onnx-microphone-offline"] | 157 | binaries += ["sherpa-onnx-microphone-offline"] |
| @@ -36,13 +36,44 @@ import argparse | @@ -36,13 +36,44 @@ import argparse | ||
| 36 | 36 | ||
| 37 | from sherpa_onnx import text2token | 37 | from sherpa_onnx import text2token |
| 38 | 38 | ||
| 39 | + | ||
| 39 | def get_args(): | 40 | def get_args(): |
| 40 | parser = argparse.ArgumentParser() | 41 | parser = argparse.ArgumentParser() |
| 41 | parser.add_argument( | 42 | parser.add_argument( |
| 42 | "--text", | 43 | "--text", |
| 43 | type=str, | 44 | type=str, |
| 44 | required=True, | 45 | required=True, |
| 45 | - help="Path to the input texts", | 46 | + help="""Path to the input texts. |
| 47 | + | ||
| 48 | + Each line in the texts contains the original phrase; it might also contain | ||
| 49 | + some extra items, for example, the boosting score (starting with :), the | ||
| 50 | + triggering threshold (starting with #, only used in the keyword spotting task) | ||
| 51 | + and the original phrase (starting with @). Note: extra items are kept in the output. | ||
| 52 | + | ||
| 53 | + example input 1 (tokens_type = ppinyin): | ||
| 54 | + | ||
| 55 | + 小爱同学 :2.0 #0.6 @小爱同学 | ||
| 56 | + 你好问问 :3.5 @你好问问 | ||
| 57 | + 小艺小艺 #0.6 @小艺小艺 | ||
| 58 | + | ||
| 59 | + example output 1: | ||
| 60 | + | ||
| 61 | + x iǎo ài t óng x ué :2.0 #0.6 @小爱同学 | ||
| 62 | + n ǐ h ǎo w èn w èn :3.5 @你好问问 | ||
| 63 | + x iǎo y ì x iǎo y ì #0.6 @小艺小艺 | ||
| 64 | + | ||
| 65 | + example input 2 (tokens_type = bpe): | ||
| 66 | + | ||
| 67 | + HELLO WORLD :1.5 #0.4 | ||
| 68 | + HI GOOGLE :2.0 #0.8 | ||
| 69 | + HEY SIRI #0.35 | ||
| 70 | + | ||
| 71 | + example output 2: | ||
| 72 | + | ||
| 73 | + ▁HE LL O ▁WORLD :1.5 #0.4 | ||
| 74 | + ▁HI ▁GO O G LE :2.0 #0.8 | ||
| 75 | + ▁HE Y ▁S I RI #0.35 | ||
| 76 | + """, | ||
| 46 | ) | 77 | ) |
| 47 | 78 | ||
| 48 | parser.add_argument( | 79 | parser.add_argument( |
| @@ -56,7 +87,11 @@ def get_args(): | @@ -56,7 +87,11 @@ def get_args(): | ||
| 56 | "--tokens-type", | 87 | "--tokens-type", |
| 57 | type=str, | 88 | type=str, |
| 58 | required=True, | 89 | required=True, |
| 59 | - help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe", | 90 | + choices=["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"], |
| 91 | + help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin. | ||
| 92 | + fpinyin means full pinyin: each cjkchar has a pinyin (with tone). | ||
| 93 | + ppinyin means partial pinyin: it splits each pinyin into an initial and a final. | ||
| 94 | + """, | ||
| 60 | ) | 95 | ) |
| 61 | 96 | ||
| 62 | parser.add_argument( | 97 | parser.add_argument( |
| @@ -79,9 +114,21 @@ def main(): | @@ -79,9 +114,21 @@ def main(): | ||
| 79 | args = get_args() | 114 | args = get_args() |
| 80 | 115 | ||
| 81 | texts = [] | 116 | texts = [] |
| 117 | + # extra information like the boosting score (starts with :), the triggering | ||
| 118 | + # threshold (starts with #) and the original keyword (starts with @) | ||
| 119 | + extra_info = [] | ||
| 82 | with open(args.text, "r", encoding="utf8") as f: | 120 | with open(args.text, "r", encoding="utf8") as f: |
| 83 | for line in f: | 121 | for line in f: |
| 84 | - texts.append(line.strip()) | 122 | + extra = [] |
| 123 | + text = [] | ||
| 124 | + toks = line.strip().split() | ||
| 125 | + for tok in toks: | ||
| 126 | + if tok[0] in (":", "#", "@"): | ||
| 127 | + extra.append(tok) | ||
| 128 | + else: | ||
| 129 | + text.append(tok) | ||
| 130 | + texts.append(" ".join(text)) | ||
| 131 | + extra_info.append(extra) | ||
| 85 | encoded_texts = text2token( | 132 | encoded_texts = text2token( |
| 86 | texts, | 133 | texts, |
| 87 | tokens=args.tokens, | 134 | tokens=args.tokens, |
| @@ -89,7 +136,8 @@ def main(): | @@ -89,7 +136,8 @@ def main(): | ||
| 89 | bpe_model=args.bpe_model, | 136 | bpe_model=args.bpe_model, |
| 90 | ) | 137 | ) |
| 91 | with open(args.output, "w", encoding="utf8") as f: | 138 | with open(args.output, "w", encoding="utf8") as f: |
| 92 | - for txt in encoded_texts: | 139 | + for i, txt in enumerate(encoded_texts): |
| 140 | + txt += extra_info[i] | ||
| 93 | f.write(" ".join(txt) + "\n") | 141 | f.write(" ".join(txt) + "\n") |
| 94 | 142 | ||
| 95 | 143 |
| @@ -51,6 +51,7 @@ def get_binaries_to_install(): | @@ -51,6 +51,7 @@ def get_binaries_to_install(): | ||
| 51 | 51 | ||
| 52 | # Remember to also change cmake/cmake_extension.py | 52 | # Remember to also change cmake/cmake_extension.py |
| 53 | binaries = ["sherpa-onnx"] | 53 | binaries = ["sherpa-onnx"] |
| 54 | + binaries += ["sherpa-onnx-keyword-spotter"] | ||
| 54 | binaries += ["sherpa-onnx-offline"] | 55 | binaries += ["sherpa-onnx-offline"] |
| 55 | binaries += ["sherpa-onnx-microphone"] | 56 | binaries += ["sherpa-onnx-microphone"] |
| 56 | binaries += ["sherpa-onnx-microphone-offline"] | 57 | binaries += ["sherpa-onnx-microphone-offline"] |
| @@ -19,6 +19,8 @@ set(sources | @@ -19,6 +19,8 @@ set(sources | ||
| 19 | features.cc | 19 | features.cc |
| 20 | file-utils.cc | 20 | file-utils.cc |
| 21 | hypothesis.cc | 21 | hypothesis.cc |
| 22 | + keyword-spotter-impl.cc | ||
| 23 | + keyword-spotter.cc | ||
| 22 | offline-ctc-fst-decoder-config.cc | 24 | offline-ctc-fst-decoder-config.cc |
| 23 | offline-ctc-fst-decoder.cc | 25 | offline-ctc-fst-decoder.cc |
| 24 | offline-ctc-greedy-search-decoder.cc | 26 | offline-ctc-greedy-search-decoder.cc |
| @@ -87,6 +89,7 @@ set(sources | @@ -87,6 +89,7 @@ set(sources | ||
| 87 | stack.cc | 89 | stack.cc |
| 88 | symbol-table.cc | 90 | symbol-table.cc |
| 89 | text-utils.cc | 91 | text-utils.cc |
| 92 | + transducer-keyword-decoder.cc | ||
| 90 | transpose.cc | 93 | transpose.cc |
| 91 | unbind.cc | 94 | unbind.cc |
| 92 | utils.cc | 95 | utils.cc |
| @@ -173,12 +176,14 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux) | @@ -173,12 +176,14 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux) | ||
| 173 | endif() | 176 | endif() |
| 174 | 177 | ||
| 175 | add_executable(sherpa-onnx sherpa-onnx.cc) | 178 | add_executable(sherpa-onnx sherpa-onnx.cc) |
| 179 | +add_executable(sherpa-onnx-keyword-spotter sherpa-onnx-keyword-spotter.cc) | ||
| 176 | add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) | 180 | add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) |
| 177 | add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) | 181 | add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) |
| 178 | add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) | 182 | add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) |
| 179 | 183 | ||
| 180 | set(main_exes | 184 | set(main_exes |
| 181 | sherpa-onnx | 185 | sherpa-onnx |
| 186 | + sherpa-onnx-keyword-spotter | ||
| 182 | sherpa-onnx-offline | 187 | sherpa-onnx-offline |
| 183 | sherpa-onnx-offline-parallel | 188 | sherpa-onnx-offline-parallel |
| 184 | sherpa-onnx-offline-tts | 189 | sherpa-onnx-offline-tts |
| @@ -5,6 +5,7 @@ | @@ -5,6 +5,7 @@ | ||
| 5 | #include "sherpa-onnx/csrc/context-graph.h" | 5 | #include "sherpa-onnx/csrc/context-graph.h" |
| 6 | 6 | ||
| 7 | #include <chrono> // NOLINT | 7 | #include <chrono> // NOLINT |
| 8 | +#include <cmath> | ||
| 8 | #include <map> | 9 | #include <map> |
| 9 | #include <random> | 10 | #include <random> |
| 10 | #include <string> | 11 | #include <string> |
| @@ -15,27 +16,25 @@ | @@ -15,27 +16,25 @@ | ||
| 15 | 16 | ||
| 16 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| 17 | 18 | ||
| 18 | -TEST(ContextGraph, TestBasic) { | 19 | +static void TestHelper(const std::map<std::string, float> &queries, float score, |
| 20 | + bool strict_mode) { | ||
| 19 | std::vector<std::string> contexts_str( | 21 | std::vector<std::string> contexts_str( |
| 20 | {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"}); | 22 | {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"}); |
| 21 | std::vector<std::vector<int32_t>> contexts; | 23 | std::vector<std::vector<int32_t>> contexts; |
| 24 | + std::vector<float> scores; | ||
| 22 | for (int32_t i = 0; i < contexts_str.size(); ++i) { | 25 | for (int32_t i = 0; i < contexts_str.size(); ++i) { |
| 23 | contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end()); | 26 | contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end()); |
| 27 | + scores.push_back(std::round(score / contexts_str[i].size() * 100) / 100); | ||
| 24 | } | 28 | } |
| 25 | - auto context_graph = ContextGraph(contexts, 1); | ||
| 26 | - | ||
| 27 | - auto queries = std::map<std::string, float>{ | ||
| 28 | - {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, | ||
| 29 | - {"SHED", 6}, {"SHELF", 6}, {"HELL", 2}, | ||
| 30 | - {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; | 29 | + auto context_graph = ContextGraph(contexts, 1, scores); |
| 31 | 30 | ||
| 32 | for (const auto &iter : queries) { | 31 | for (const auto &iter : queries) { |
| 33 | float total_scores = 0; | 32 | float total_scores = 0; |
| 34 | auto state = context_graph.Root(); | 33 | auto state = context_graph.Root(); |
| 35 | for (auto q : iter.first) { | 34 | for (auto q : iter.first) { |
| 36 | - auto res = context_graph.ForwardOneStep(state, q); | ||
| 37 | - total_scores += res.first; | ||
| 38 | - state = res.second; | 35 | + auto res = context_graph.ForwardOneStep(state, q, strict_mode); |
| 36 | + total_scores += std::get<0>(res); | ||
| 37 | + state = std::get<1>(res); | ||
| 39 | } | 38 | } |
| 40 | auto res = context_graph.Finalize(state); | 39 | auto res = context_graph.Finalize(state); |
| 41 | EXPECT_EQ(res.second->token, -1); | 40 | EXPECT_EQ(res.second->token, -1); |
| @@ -44,6 +43,37 @@ TEST(ContextGraph, TestBasic) { | @@ -44,6 +43,37 @@ TEST(ContextGraph, TestBasic) { | ||
| 44 | } | 43 | } |
| 45 | } | 44 | } |
| 46 | 45 | ||
| 46 | +TEST(ContextGraph, TestBasic) { | ||
| 47 | + auto queries = std::map<std::string, float>{ | ||
| 48 | + {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, | ||
| 49 | + {"SHED", 6}, {"SHELF", 6}, {"HELL", 2}, | ||
| 50 | + {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; | ||
| 51 | + TestHelper(queries, 0, true); | ||
| 52 | +} | ||
| 53 | + | ||
| 54 | +TEST(ContextGraph, TestBasicNonStrict) { | ||
| 55 | + auto queries = std::map<std::string, float>{ | ||
| 56 | + {"HEHERSHE", 7}, {"HERSHE", 5}, {"HISHE", 5}, {"SHED", 3}, {"SHELF", 3}, | ||
| 57 | + {"HELL", 2}, {"HELLO", 2}, {"DHRHISQ", 3}, {"THEN", 2}}; | ||
| 58 | + TestHelper(queries, 0, false); | ||
| 59 | +} | ||
| 60 | + | ||
| 61 | +TEST(ContextGraph, TestCustomize) { | ||
| 62 | + auto queries = std::map<std::string, float>{ | ||
| 63 | + {"HEHERSHE", 35.84}, {"HERSHE", 30.84}, {"HISHE", 24.18}, | ||
| 64 | + {"SHED", 18.34}, {"SHELF", 18.34}, {"HELL", 5}, | ||
| 65 | + {"HELLO", 13}, {"DHRHISQ", 10.84}, {"THEN", 5}}; | ||
| 66 | + TestHelper(queries, 5, true); | ||
| 67 | +} | ||
| 68 | + | ||
| 69 | +TEST(ContextGraph, TestCustomizeNonStrict) { | ||
| 70 | + auto queries = std::map<std::string, float>{ | ||
| 71 | + {"HEHERSHE", 20}, {"HERSHE", 15}, {"HISHE", 10.84}, | ||
| 72 | + {"SHED", 10}, {"SHELF", 10}, {"HELL", 5}, | ||
| 73 | + {"HELLO", 5}, {"DHRHISQ", 5.84}, {"THEN", 5}}; | ||
| 74 | + TestHelper(queries, 5, false); | ||
| 75 | +} | ||
| 76 | + | ||
| 47 | TEST(ContextGraph, Benchmark) { | 77 | TEST(ContextGraph, Benchmark) { |
| 48 | std::random_device rd; | 78 | std::random_device rd; |
| 49 | std::mt19937 mt(rd()); | 79 | std::mt19937 mt(rd()); |
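One detail worth spelling out: in `TestHelper` above, each phrase's per-token boosting score is derived from the single `score` argument $S$, normalized by phrase length and rounded to two decimals:

    $\text{score}_i = \mathrm{round}(100 \cdot S / |w_i|) / 100$

where $|w_i|$ is the number of tokens in the $i$-th phrase. With $S = 0$ (as in `TestBasic`), `Build()` falls back to the default `context_score_` for every node, matching the original behavior.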
| @@ -4,22 +4,59 @@ | @@ -4,22 +4,59 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/context-graph.h" | 5 | #include "sherpa-onnx/csrc/context-graph.h" |
| 6 | 6 | ||
| 7 | +#include <algorithm> | ||
| 7 | #include <cassert> | 8 | #include <cassert> |
| 8 | #include <queue> | 9 | #include <queue> |
| 10 | +#include <string> | ||
| 11 | +#include <tuple> | ||
| 9 | #include <utility> | 12 | #include <utility> |
| 10 | 13 | ||
| 14 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 15 | + | ||
| 11 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 12 | -void ContextGraph::Build( | ||
| 13 | - const std::vector<std::vector<int32_t>> &token_ids) const { | 17 | +void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids, |
| 18 | + const std::vector<float> &scores, | ||
| 19 | + const std::vector<std::string> &phrases, | ||
| 20 | + const std::vector<float> &ac_thresholds) const { | ||
| 21 | + if (!scores.empty()) { | ||
| 22 | + SHERPA_ONNX_CHECK_EQ(token_ids.size(), scores.size()); | ||
| 23 | + } | ||
| 24 | + if (!phrases.empty()) { | ||
| 25 | + SHERPA_ONNX_CHECK_EQ(token_ids.size(), phrases.size()); | ||
| 26 | + } | ||
| 27 | + if (!ac_thresholds.empty()) { | ||
| 28 | + SHERPA_ONNX_CHECK_EQ(token_ids.size(), ac_thresholds.size()); | ||
| 29 | + } | ||
| 14 | for (int32_t i = 0; i < token_ids.size(); ++i) { | 30 | for (int32_t i = 0; i < token_ids.size(); ++i) { |
| 15 | auto node = root_.get(); | 31 | auto node = root_.get(); |
| 32 | + float score = scores.empty() ? 0.0f : scores[i]; | ||
| 33 | + score = score == 0.0f ? context_score_ : score; | ||
| 34 | + float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i]; | ||
| 35 | + ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold; | ||
| 36 | + std::string phrase = phrases.empty() ? std::string() : phrases[i]; | ||
| 37 | + | ||
| 16 | for (int32_t j = 0; j < token_ids[i].size(); ++j) { | 38 | for (int32_t j = 0; j < token_ids[i].size(); ++j) { |
| 17 | int32_t token = token_ids[i][j]; | 39 | int32_t token = token_ids[i][j]; |
| 18 | if (0 == node->next.count(token)) { | 40 | if (0 == node->next.count(token)) { |
| 19 | bool is_end = j == token_ids[i].size() - 1; | 41 | bool is_end = j == token_ids[i].size() - 1; |
| 20 | node->next[token] = std::make_unique<ContextState>( | 42 | node->next[token] = std::make_unique<ContextState>( |
| 21 | - token, context_score_, node->node_score + context_score_, | ||
| 22 | - is_end ? node->node_score + context_score_ : 0, is_end); | 43 | + token, score, node->node_score + score, |
| 44 | + is_end ? node->node_score + score : 0, j + 1, | ||
| 45 | + is_end ? ac_threshold : 0.0f, is_end, | ||
| 46 | + is_end ? phrase : std::string()); | ||
| 47 | + } else { | ||
| 48 | + float token_score = std::max(score, node->next[token]->token_score); | ||
| 49 | + node->next[token]->token_score = token_score; | ||
| 50 | + float node_score = node->node_score + token_score; | ||
| 51 | + node->next[token]->node_score = node_score; | ||
| 52 | + bool is_end = | ||
| 53 | + (j == token_ids[i].size() - 1) || node->next[token]->is_end; | ||
| 54 | + node->next[token]->output_score = is_end ? node_score : 0.0f; | ||
| 55 | + node->next[token]->is_end = is_end; | ||
| 56 | + if (j == token_ids[i].size() - 1) { | ||
| 57 | + node->next[token]->phrase = phrase; | ||
| 58 | + node->next[token]->ac_threshold = ac_threshold; | ||
| 59 | + } | ||
| 23 | } | 60 | } |
| 24 | node = node->next[token].get(); | 61 | node = node->next[token].get(); |
| 25 | } | 62 | } |
| @@ -27,8 +64,9 @@ void ContextGraph::Build( | @@ -27,8 +64,9 @@ void ContextGraph::Build( | ||
| 27 | FillFailOutput(); | 64 | FillFailOutput(); |
| 28 | } | 65 | } |
| 29 | 66 | ||
| 30 | -std::pair<float, const ContextState *> ContextGraph::ForwardOneStep( | ||
| 31 | - const ContextState *state, int32_t token) const { | 67 | +std::tuple<float, const ContextState *, const ContextState *> |
| 68 | +ContextGraph::ForwardOneStep(const ContextState *state, int32_t token, | ||
| 69 | + bool strict_mode /*= true*/) const { | ||
| 32 | const ContextState *node; | 70 | const ContextState *node; |
| 33 | float score; | 71 | float score; |
| 34 | if (1 == state->next.count(token)) { | 72 | if (1 == state->next.count(token)) { |
| @@ -45,8 +83,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep( | @@ -45,8 +83,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep( | ||
| 45 | } | 83 | } |
| 46 | score = node->node_score - state->node_score; | 84 | score = node->node_score - state->node_score; |
| 47 | } | 85 | } |
| 86 | + | ||
| 48 | SHERPA_ONNX_CHECK(nullptr != node); | 87 | SHERPA_ONNX_CHECK(nullptr != node); |
| 49 | - return std::make_pair(score + node->output_score, node); | 88 | + |
| 89 | + const ContextState *matched_node = | ||
| 90 | + node->is_end ? node : (node->output != nullptr ? node->output : nullptr); | ||
| 91 | + | ||
| 92 | + if (!strict_mode && node->output_score != 0) { | ||
| 93 | + SHERPA_ONNX_CHECK(nullptr != matched_node); | ||
| 94 | + float output_score = | ||
| 95 | + node->is_end ? node->node_score | ||
| 96 | + : (node->output != nullptr ? node->output->node_score | ||
| 97 | + : node->node_score); | ||
| 98 | + return std::make_tuple(score + output_score - node->node_score, root_.get(), | ||
| 99 | + matched_node); | ||
| 100 | + } | ||
| 101 | + return std::make_tuple(score + node->output_score, node, matched_node); | ||
| 50 | } | 102 | } |
| 51 | 103 | ||
| 52 | std::pair<float, const ContextState *> ContextGraph::Finalize( | 104 | std::pair<float, const ContextState *> ContextGraph::Finalize( |
| @@ -55,6 +107,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize( | @@ -55,6 +107,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize( | ||
| 55 | return std::make_pair(score, root_.get()); | 107 | return std::make_pair(score, root_.get()); |
| 56 | } | 108 | } |
| 57 | 109 | ||
| 110 | +std::pair<bool, const ContextState *> ContextGraph::IsMatched( | ||
| 111 | + const ContextState *state) const { | ||
| 112 | + bool status = false; | ||
| 113 | + const ContextState *node = nullptr; | ||
| 114 | + if (state->is_end) { | ||
| 115 | + status = true; | ||
| 116 | + node = state; | ||
| 117 | + } else { | ||
| 118 | + if (state->output != nullptr) { | ||
| 119 | + status = true; | ||
| 120 | + node = state->output; | ||
| 121 | + } | ||
| 122 | + } | ||
| 123 | + return std::make_pair(status, node); | ||
| 124 | +} | ||
| 125 | + | ||
| 58 | void ContextGraph::FillFailOutput() const { | 126 | void ContextGraph::FillFailOutput() const { |
| 59 | std::queue<const ContextState *> node_queue; | 127 | std::queue<const ContextState *> node_queue; |
| 60 | for (auto &kv : root_->next) { | 128 | for (auto &kv : root_->next) { |
| @@ -6,6 +6,8 @@ | @@ -6,6 +6,8 @@ | ||
| 6 | #define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_ | 6 | #define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_ |
| 7 | 7 | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | +#include <string> | ||
| 10 | +#include <tuple> | ||
| 9 | #include <unordered_map> | 11 | #include <unordered_map> |
| 10 | #include <utility> | 12 | #include <utility> |
| 11 | #include <vector> | 13 | #include <vector> |
| @@ -22,34 +24,55 @@ struct ContextState { | @@ -22,34 +24,55 @@ struct ContextState { | ||
| 22 | float token_score; | 24 | float token_score; |
| 23 | float node_score; | 25 | float node_score; |
| 24 | float output_score; | 26 | float output_score; |
| 27 | + int32_t level; | ||
| 28 | + float ac_threshold; | ||
| 25 | bool is_end; | 29 | bool is_end; |
| 30 | + std::string phrase; | ||
| 26 | std::unordered_map<int32_t, std::unique_ptr<ContextState>> next; | 31 | std::unordered_map<int32_t, std::unique_ptr<ContextState>> next; |
| 27 | const ContextState *fail = nullptr; | 32 | const ContextState *fail = nullptr; |
| 28 | const ContextState *output = nullptr; | 33 | const ContextState *output = nullptr; |
| 29 | 34 | ||
| 30 | ContextState() = default; | 35 | ContextState() = default; |
| 31 | ContextState(int32_t token, float token_score, float node_score, | 36 | ContextState(int32_t token, float token_score, float node_score, |
| 32 | - float output_score, bool is_end) | 37 | + float output_score, int32_t level = 0, float ac_threshold = 0.0f, |
| 38 | + bool is_end = false, const std::string &phrase = {}) | ||
| 33 | : token(token), | 39 | : token(token), |
| 34 | token_score(token_score), | 40 | token_score(token_score), |
| 35 | node_score(node_score), | 41 | node_score(node_score), |
| 36 | output_score(output_score), | 42 | output_score(output_score), |
| 37 | - is_end(is_end) {} | 43 | + level(level), |
| 44 | + ac_threshold(ac_threshold), | ||
| 45 | + is_end(is_end), | ||
| 46 | + phrase(phrase) {} | ||
| 38 | }; | 47 | }; |
| 39 | 48 | ||
| 40 | class ContextGraph { | 49 | class ContextGraph { |
| 41 | public: | 50 | public: |
| 42 | ContextGraph() = default; | 51 | ContextGraph() = default; |
| 43 | ContextGraph(const std::vector<std::vector<int32_t>> &token_ids, | 52 | ContextGraph(const std::vector<std::vector<int32_t>> &token_ids, |
| 44 | - float context_score) | ||
| 45 | - : context_score_(context_score) { | ||
| 46 | - root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false); | 53 | + float context_score, float ac_threshold, |
| 54 | + const std::vector<float> &scores = {}, | ||
| 55 | + const std::vector<std::string> &phrases = {}, | ||
| 56 | + const std::vector<float> &ac_thresholds = {}) | ||
| 57 | + : context_score_(context_score), ac_threshold_(ac_threshold) { | ||
| 58 | + root_ = std::make_unique<ContextState>(-1, 0, 0, 0); | ||
| 47 | root_->fail = root_.get(); | 59 | root_->fail = root_.get(); |
| 48 | - Build(token_ids); | 60 | + Build(token_ids, scores, phrases, ac_thresholds); |
| 49 | } | 61 | } |
| 50 | 62 | ||
| 51 | - std::pair<float, const ContextState *> ForwardOneStep( | ||
| 52 | - const ContextState *state, int32_t token_id) const; | 63 | + ContextGraph(const std::vector<std::vector<int32_t>> &token_ids, |
| 64 | + float context_score, const std::vector<float> &scores = {}, | ||
| 65 | + const std::vector<std::string> &phrases = {}) | ||
| 66 | + : ContextGraph(token_ids, context_score, 0.0f, scores, phrases, | ||
| 67 | + std::vector<float>()) {} | ||
| 68 | + | ||
| 69 | + std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep( | ||
| 70 | + const ContextState *state, int32_t token_id, | ||
| 71 | + bool strict_mode = true) const; | ||
| 72 | + | ||
| 73 | + std::pair<bool, const ContextState *> IsMatched( | ||
| 74 | + const ContextState *state) const; | ||
| 75 | + | ||
| 53 | std::pair<float, const ContextState *> Finalize( | 76 | std::pair<float, const ContextState *> Finalize( |
| 54 | const ContextState *state) const; | 77 | const ContextState *state) const; |
| 55 | 78 | ||
| @@ -57,8 +80,12 @@ class ContextGraph { | @@ -57,8 +80,12 @@ class ContextGraph { | ||
| 57 | 80 | ||
| 58 | private: | 81 | private: |
| 59 | float context_score_; | 82 | float context_score_; |
| 83 | + float ac_threshold_; | ||
| 60 | std::unique_ptr<ContextState> root_; | 84 | std::unique_ptr<ContextState> root_; |
| 61 | - void Build(const std::vector<std::vector<int32_t>> &token_ids) const; | 85 | + void Build(const std::vector<std::vector<int32_t>> &token_ids, |
| 86 | + const std::vector<float> &scores, | ||
| 87 | + const std::vector<std::string> &phrases, | ||
| 88 | + const std::vector<float> &ac_thresholds) const; | ||
| 62 | void FillFailOutput() const; | 89 | void FillFailOutput() const; |
| 63 | }; | 90 | }; |
| 64 | 91 |
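Putting the new header together: a graph is built once with optional per-keyword scores, then driven token by token. A minimal sketch mirroring the unit tests above (the token ids and scores are made up; only calls visible in this diff are used):

    #include <cstdint>
    #include <tuple>
    #include <vector>

    #include "sherpa-onnx/csrc/context-graph.h"

    void Sketch() {
      using sherpa_onnx::ContextGraph;
      using sherpa_onnx::ContextState;

      // One keyword made of token ids {1, 2, 3} with a boosting score of 2.0.
      std::vector<std::vector<int32_t>> token_ids = {{1, 2, 3}};
      std::vector<float> scores = {2.0f};
      ContextGraph graph(token_ids, /*context_score=*/1.0f, scores);

      const ContextState *state = graph.Root();
      float total = 0.0f;
      for (int32_t token : {1, 2, 3}) {
        // ForwardOneStep() now returns (score delta, next state, matched node).
        auto res = graph.ForwardOneStep(state, token, /*strict_mode=*/true);
        total += std::get<0>(res);
        state = std::get<1>(res);
        if (const ContextState *matched = std::get<2>(res)) {
          (void)matched;  // matched->phrase / matched->ac_threshold describe the hit
        }
      }
      // Finalize() resets to the root state (token == -1 in the tests above).
      auto fin = graph.Finalize(state);
      (void)fin;
      (void)total;
    }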
| @@ -28,6 +28,10 @@ struct Hypothesis { | @@ -28,6 +28,10 @@ struct Hypothesis { | ||
| 28 | // on which ys[i] is decoded. | 28 | // on which ys[i] is decoded. |
| 29 | std::vector<int32_t> timestamps; | 29 | std::vector<int32_t> timestamps; |
| 30 | 30 | ||
| 31 | + // The acoustic probability for each token in ys. | ||
| 32 | + // Only used for keyword spotting task. | ||
| 33 | + std::vector<float> ys_probs; | ||
| 34 | + | ||
| 31 | // The total score of ys in log space. | 35 | // The total score of ys in log space. |
| 32 | // It contains only acoustic scores | 36 | // It contains only acoustic scores |
| 33 | double log_prob = 0; | 37 | double log_prob = 0; |
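The new `ys_probs` field records the acoustic probability of each decoded token. Its consumer is `transducer-keyword-decoder.cc` (listed in the CMake changes above but not reproduced in this excerpt); one plausible use, sketched under that assumption, is averaging a candidate's token probabilities and comparing the mean against the keyword's triggering threshold:

    #include <numeric>
    #include <vector>

    // Hedged sketch: returns true if the mean acoustic probability of the
    // decoded tokens reaches the keyword's ac_threshold. The actual decision
    // logic lives in transducer-keyword-decoder.cc, which this diff excerpt
    // does not show.
    static bool PassesAcousticCheck(const std::vector<float> &ys_probs,
                                    float ac_threshold) {
      if (ys_probs.empty()) return false;
      float mean = std::accumulate(ys_probs.begin(), ys_probs.end(), 0.0f) /
                   ys_probs.size();
      return mean >= ac_threshold;
    }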
sherpa-onnx/csrc/keyword-spotter-impl.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/keyword-spotter-impl.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023-2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/keyword-spotter-impl.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/keyword-spotter-transducer-impl.h" | ||
| 8 | + | ||
| 9 | +namespace sherpa_onnx { | ||
| 10 | + | ||
| 11 | +std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create( | ||
| 12 | + const KeywordSpotterConfig &config) { | ||
| 13 | + if (!config.model_config.transducer.encoder.empty()) { | ||
| 14 | + return std::make_unique<KeywordSpotterTransducerImpl>(config); | ||
| 15 | + } | ||
| 16 | + | ||
| 17 | + SHERPA_ONNX_LOGE("Please specify a model"); | ||
| 18 | + exit(-1); | ||
| 19 | +} | ||
| 20 | + | ||
| 21 | +#if __ANDROID_API__ >= 9 | ||
| 22 | +std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create( | ||
| 23 | + AAssetManager *mgr, const KeywordSpotterConfig &config) { | ||
| 24 | + if (!config.model_config.transducer.encoder.empty()) { | ||
| 25 | + return std::make_unique<KeywordSpotterTransducerImpl>(mgr, config); | ||
| 26 | + } | ||
| 27 | + | ||
| 28 | + SHERPA_ONNX_LOGE("Please specify a model"); | ||
| 29 | + exit(-1); | ||
| 30 | +} | ||
| 31 | +#endif | ||
| 32 | + | ||
| 33 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/keyword-spotter-impl.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/keyword-spotter-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023-2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 17 | +#include "sherpa-onnx/csrc/keyword-spotter.h" | ||
| 18 | +#include "sherpa-onnx/csrc/online-stream.h" | ||
| 19 | + | ||
| 20 | +namespace sherpa_onnx { | ||
| 21 | + | ||
| 22 | +class KeywordSpotterImpl { | ||
| 23 | + public: | ||
| 24 | + static std::unique_ptr<KeywordSpotterImpl> Create( | ||
| 25 | + const KeywordSpotterConfig &config); | ||
| 26 | + | ||
| 27 | +#if __ANDROID_API__ >= 9 | ||
| 28 | + static std::unique_ptr<KeywordSpotterImpl> Create( | ||
| 29 | + AAssetManager *mgr, const KeywordSpotterConfig &config); | ||
| 30 | +#endif | ||
| 31 | + | ||
| 32 | + virtual ~KeywordSpotterImpl() = default; | ||
| 33 | + | ||
| 34 | + virtual std::unique_ptr<OnlineStream> CreateStream() const = 0; | ||
| 35 | + | ||
| 36 | + virtual std::unique_ptr<OnlineStream> CreateStream( | ||
| 37 | + const std::string &keywords) const = 0; | ||
| 38 | + | ||
| 39 | + virtual bool IsReady(OnlineStream *s) const = 0; | ||
| 40 | + | ||
| 41 | + virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0; | ||
| 42 | + | ||
| 43 | + virtual KeywordResult GetResult(OnlineStream *s) const = 0; | ||
| 44 | +}; | ||
| 45 | + | ||
| 46 | +} // namespace sherpa_onnx | ||
| 47 | + | ||
| 48 | +#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/keyword-spotter-transducer-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023-2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <memory> | ||
| 10 | +#include <regex> // NOLINT | ||
| 11 | +#include <string> | ||
| 12 | +#include <utility> | ||
| 13 | +#include <vector> | ||
| 14 | + | ||
| 15 | +#if __ANDROID_API__ >= 9 | ||
| 16 | +#include <strstream> | ||
| 17 | + | ||
| 18 | +#include "android/asset_manager.h" | ||
| 19 | +#include "android/asset_manager_jni.h" | ||
| 20 | +#endif | ||
| 21 | + | ||
| 22 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 23 | +#include "sherpa-onnx/csrc/keyword-spotter-impl.h" | ||
| 24 | +#include "sherpa-onnx/csrc/keyword-spotter.h" | ||
| 25 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 26 | +#include "sherpa-onnx/csrc/online-transducer-model.h" | ||
| 27 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 28 | +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" | ||
| 29 | +#include "sherpa-onnx/csrc/utils.h" | ||
| 30 | + | ||
| 31 | +namespace sherpa_onnx { | ||
| 32 | + | ||
| 33 | +static KeywordResult Convert(const TransducerKeywordResult &src, | ||
| 34 | + const SymbolTable &sym_table, float frame_shift_ms, | ||
| 35 | + int32_t subsampling_factor, | ||
| 36 | + int32_t frames_since_start) { | ||
| 37 | + KeywordResult r; | ||
| 38 | + r.tokens.reserve(src.tokens.size()); | ||
| 39 | + r.timestamps.reserve(src.tokens.size()); | ||
| 40 | + r.keyword = src.keyword; | ||
| 41 | + bool from_tokens = src.keyword.empty(); | ||
| 42 | + | ||
| 43 | + for (auto i : src.tokens) { | ||
| 44 | + auto sym = sym_table[i]; | ||
| 45 | + if (from_tokens) { | ||
| 46 | + r.keyword.append(sym); | ||
| 47 | + } | ||
| 48 | + r.tokens.push_back(std::move(sym)); | ||
| 49 | + } | ||
| 50 | + if (from_tokens && r.keyword.size()) { | ||
| 51 | + r.keyword = r.keyword.substr(1); | ||
| 52 | + } | ||
| 53 | + | ||
| 54 | + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; | ||
| 55 | + for (auto t : src.timestamps) { | ||
| 56 | + float time = frame_shift_s * t; | ||
| 57 | + r.timestamps.push_back(time); | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + r.start_time = frames_since_start * frame_shift_ms / 1000.; | ||
| 61 | + | ||
| 62 | + return r; | ||
| 63 | +} | ||
| 64 | + | ||
| 65 | +class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | ||
| 66 | + public: | ||
| 67 | + explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config) | ||
| 68 | + : config_(config), | ||
| 69 | + model_(OnlineTransducerModel::Create(config.model_config)), | ||
| 70 | + sym_(config.model_config.tokens) { | ||
| 71 | + if (sym_.contains("<unk>")) { | ||
| 72 | + unk_id_ = sym_["<unk>"]; | ||
| 73 | + } | ||
| 74 | + | ||
| 75 | + InitKeywords(); | ||
| 76 | + | ||
| 77 | + decoder_ = std::make_unique<TransducerKeywordDecoder>( | ||
| 78 | + model_.get(), config_.max_active_paths, config_.num_trailing_blanks, | ||
| 79 | + unk_id_); | ||
| 80 | + } | ||
| 81 | + | ||
| 82 | +#if __ANDROID_API__ >= 9 | ||
| 83 | + KeywordSpotterTransducerImpl(AAssetManager *mgr, | ||
| 84 | + const KeywordSpotterConfig &config) | ||
| 85 | + : config_(config), | ||
| 86 | + model_(OnlineTransducerModel::Create(mgr, config.model_config)), | ||
| 87 | + sym_(mgr, config.model_config.tokens) { | ||
| 88 | + if (sym_.contains("<unk>")) { | ||
| 89 | + unk_id_ = sym_["<unk>"]; | ||
| 90 | + } | ||
| 91 | + | ||
| 92 | + InitKeywords(mgr); | ||
| 93 | + | ||
| 94 | + decoder_ = std::make_unique<TransducerKeywordDecoder>( | ||
| 95 | + model_.get(), config_.max_active_paths, config_.num_trailing_blanks, | ||
| 96 | + unk_id_); | ||
| 97 | + } | ||
| 98 | +#endif | ||
| 99 | + | ||
| 100 | + std::unique_ptr<OnlineStream> CreateStream() const override { | ||
| 101 | + auto stream = | ||
| 102 | + std::make_unique<OnlineStream>(config_.feat_config, keywords_graph_); | ||
| 103 | + InitOnlineStream(stream.get()); | ||
| 104 | + return stream; | ||
| 105 | + } | ||
| 106 | + | ||
| 107 | + std::unique_ptr<OnlineStream> CreateStream( | ||
| 108 | + const std::string &keywords) const override { | ||
| 109 | + auto kws = std::regex_replace(keywords, std::regex("/"), "\n"); | ||
| 110 | + std::istringstream is(kws); | ||
| 111 | + | ||
| 112 | + std::vector<std::vector<int32_t>> current_ids; | ||
| 113 | + std::vector<std::string> current_kws; | ||
| 114 | + std::vector<float> current_scores; | ||
| 115 | + std::vector<float> current_thresholds; | ||
| 116 | + | ||
| 117 | + if (!EncodeKeywords(is, sym_, ¤t_ids, ¤t_kws, ¤t_scores, | ||
| 118 | + ¤t_thresholds)) { | ||
| 119 | + SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str()); | ||
| 120 | + return nullptr; | ||
| 121 | + } | ||
| 122 | + | ||
| 123 | + int32_t num_kws = current_ids.size(); | ||
| 124 | + int32_t num_default_kws = keywords_id_.size(); | ||
| 125 | + | ||
| 126 | + current_ids.insert(current_ids.end(), keywords_id_.begin(), | ||
| 127 | + keywords_id_.end()); | ||
| 128 | + | ||
| 129 | + if (!current_kws.empty() && !keywords_.empty()) { | ||
| 130 | + current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end()); | ||
| 131 | + } else if (!current_kws.empty() && keywords_.empty()) { | ||
| 132 | + current_kws.insert(current_kws.end(), num_default_kws, std::string()); | ||
| 133 | + } else if (current_kws.empty() && !keywords_.empty()) { | ||
| 134 | + current_kws.insert(current_kws.end(), num_kws, std::string()); | ||
| 135 | + current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end()); | ||
| 136 | + } else { | ||
| 137 | + // Do nothing. | ||
| 138 | + } | ||
| 139 | + | ||
| 140 | + if (!current_scores.empty() && !boost_scores_.empty()) { | ||
| 141 | + current_scores.insert(current_scores.end(), boost_scores_.begin(), | ||
| 142 | + boost_scores_.end()); | ||
| 143 | + } else if (!current_scores.empty() && boost_scores_.empty()) { | ||
| 144 | + current_scores.insert(current_scores.end(), num_default_kws, | ||
| 145 | + config_.keywords_score); | ||
| 146 | + } else if (current_scores.empty() && !boost_scores_.empty()) { | ||
| 147 | + current_scores.insert(current_scores.end(), num_kws, | ||
| 148 | + config_.keywords_score); | ||
| 149 | + current_scores.insert(current_scores.end(), boost_scores_.begin(), | ||
| 150 | + boost_scores_.end()); | ||
| 151 | + } else { | ||
| 152 | + // Do nothing. | ||
| 153 | + } | ||
| 154 | + | ||
| 155 | + if (!current_thresholds.empty() && !thresholds_.empty()) { | ||
| 156 | + current_thresholds.insert(current_thresholds.end(), thresholds_.begin(), | ||
| 157 | + thresholds_.end()); | ||
| 158 | + } else if (!current_thresholds.empty() && thresholds_.empty()) { | ||
| 159 | + current_thresholds.insert(current_thresholds.end(), num_default_kws, | ||
| 160 | + config_.keywords_threshold); | ||
| 161 | + } else if (current_thresholds.empty() && !thresholds_.empty()) { | ||
| 162 | + current_thresholds.insert(current_thresholds.end(), num_kws, | ||
| 163 | + config_.keywords_threshold); | ||
| 164 | + current_thresholds.insert(current_thresholds.end(), thresholds_.begin(), | ||
| 165 | + thresholds_.end()); | ||
| 166 | + } else { | ||
| 167 | + // Do nothing. | ||
| 168 | + } | ||
| 169 | + | ||
| 170 | + auto keywords_graph = std::make_shared<ContextGraph>( | ||
| 171 | + current_ids, config_.keywords_score, config_.keywords_threshold, | ||
| 172 | + current_scores, current_kws, current_thresholds); | ||
| 173 | + | ||
| 174 | + auto stream = | ||
| 175 | + std::make_unique<OnlineStream>(config_.feat_config, keywords_graph); | ||
| 176 | + InitOnlineStream(stream.get()); | ||
| 177 | + return stream; | ||
| 178 | + } | ||
| 179 | + | ||
| 180 | + bool IsReady(OnlineStream *s) const override { | ||
| 181 | + return s->GetNumProcessedFrames() + model_->ChunkSize() < | ||
| 182 | + s->NumFramesReady(); | ||
| 183 | + } | ||
| 184 | + | ||
| 185 | + void DecodeStreams(OnlineStream **ss, int32_t n) const override { | ||
| 186 | + int32_t chunk_size = model_->ChunkSize(); | ||
| 187 | + int32_t chunk_shift = model_->ChunkShift(); | ||
| 188 | + | ||
| 189 | + int32_t feature_dim = ss[0]->FeatureDim(); | ||
| 190 | + | ||
| 191 | + std::vector<TransducerKeywordResult> results(n); | ||
| 192 | + std::vector<float> features_vec(n * chunk_size * feature_dim); | ||
| 193 | + std::vector<std::vector<Ort::Value>> states_vec(n); | ||
| 194 | + std::vector<int64_t> all_processed_frames(n); | ||
| 195 | + | ||
| 196 | + for (int32_t i = 0; i != n; ++i) { | ||
| 197 | + SHERPA_ONNX_CHECK(ss[i]->GetContextGraph() != nullptr); | ||
| 198 | + | ||
| 199 | + const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); | ||
| 200 | + std::vector<float> features = | ||
| 201 | + ss[i]->GetFrames(num_processed_frames, chunk_size); | ||
| 202 | + | ||
| 203 | + // Question: should num_processed_frames include chunk_shift? | ||
| 204 | + ss[i]->GetNumProcessedFrames() += chunk_shift; | ||
| 205 | + | ||
| 206 | + std::copy(features.begin(), features.end(), | ||
| 207 | + features_vec.data() + i * chunk_size * feature_dim); | ||
| 208 | + | ||
| 209 | + results[i] = std::move(ss[i]->GetKeywordResult()); | ||
| 210 | + states_vec[i] = std::move(ss[i]->GetStates()); | ||
| 211 | + all_processed_frames[i] = num_processed_frames; | ||
| 212 | + } | ||
| 213 | + | ||
| 214 | + auto memory_info = | ||
| 215 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 216 | + | ||
| 217 | + std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim}; | ||
| 218 | + | ||
| 219 | + Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(), | ||
| 220 | + features_vec.size(), x_shape.data(), | ||
| 221 | + x_shape.size()); | ||
| 222 | + | ||
| 223 | + std::array<int64_t, 1> processed_frames_shape{ | ||
| 224 | + static_cast<int64_t>(all_processed_frames.size())}; | ||
| 225 | + | ||
| 226 | + Ort::Value processed_frames = Ort::Value::CreateTensor( | ||
| 227 | + memory_info, all_processed_frames.data(), all_processed_frames.size(), | ||
| 228 | + processed_frames_shape.data(), processed_frames_shape.size()); | ||
| 229 | + | ||
| 230 | + auto states = model_->StackStates(states_vec); | ||
| 231 | + | ||
| 232 | + auto pair = model_->RunEncoder(std::move(x), std::move(states), | ||
| 233 | + std::move(processed_frames)); | ||
| 234 | + | ||
| 235 | + decoder_->Decode(std::move(pair.first), ss, &results); | ||
| 236 | + | ||
| 237 | + std::vector<std::vector<Ort::Value>> next_states = | ||
| 238 | + model_->UnStackStates(pair.second); | ||
| 239 | + | ||
| 240 | + for (int32_t i = 0; i != n; ++i) { | ||
| 241 | + ss[i]->SetKeywordResult(results[i]); | ||
| 242 | + ss[i]->SetStates(std::move(next_states[i])); | ||
| 243 | + } | ||
| 244 | + } | ||
| 245 | + | ||
| 246 | + KeywordResult GetResult(OnlineStream *s) const override { | ||
| 247 | + TransducerKeywordResult decoder_result = s->GetKeywordResult(true); | ||
| 248 | + | ||
| 249 | + // TODO(fangjun): Remember to change these constants if needed | ||
| 250 | + int32_t frame_shift_ms = 10; | ||
| 251 | + int32_t subsampling_factor = 4; | ||
| 252 | + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, | ||
| 253 | + s->GetNumFramesSinceStart()); | ||
| 254 | + } | ||
| 255 | + | ||
| 256 | + private: | ||
| 257 | + void InitKeywords(std::istream &is) { | ||
| 258 | + if (!EncodeKeywords(is, sym_, &keywords_id_, &keywords_, &boost_scores_, | ||
| 259 | + &thresholds_)) { | ||
| 260 | + SHERPA_ONNX_LOGE("Encode keywords failed."); | ||
| 261 | + exit(-1); | ||
| 262 | + } | ||
| 263 | + keywords_graph_ = std::make_shared<ContextGraph>( | ||
| 264 | + keywords_id_, config_.keywords_score, config_.keywords_threshold, | ||
| 265 | + boost_scores_, keywords_, thresholds_); | ||
| 266 | + } | ||
| 267 | + | ||
| 268 | + void InitKeywords() { | ||
| 269 | + // each line in keywords_file contains space-separated words | ||
| 270 | + | ||
| 271 | + std::ifstream is(config_.keywords_file); | ||
| 272 | + if (!is) { | ||
| 273 | + SHERPA_ONNX_LOGE("Open keywords file failed: %s", | ||
| 274 | + config_.keywords_file.c_str()); | ||
| 275 | + exit(-1); | ||
| 276 | + } | ||
| 277 | + InitKeywords(is); | ||
| 278 | + } | ||
| 279 | + | ||
| 280 | +#if __ANDROID_API__ >= 9 | ||
| 281 | + void InitKeywords(AAssetManager *mgr) { | ||
| 282 | + // each line in keywords_file contains space-separated words | ||
| 283 | + | ||
| 284 | + auto buf = ReadFile(mgr, config_.keywords_file); | ||
| 285 | + | ||
| 286 | + std::istrstream is(buf.data(), buf.size()); | ||
| 287 | + | ||
| 288 | + if (!is) { | ||
| 289 | + SHERPA_ONNX_LOGE("Open keywords file failed: %s", | ||
| 290 | + config_.keywords_file.c_str()); | ||
| 291 | + exit(-1); | ||
| 292 | + } | ||
| 293 | + InitKeywords(is); | ||
| 294 | + } | ||
| 295 | +#endif | ||
| 296 | + | ||
| 297 | + void InitOnlineStream(OnlineStream *stream) const { | ||
| 298 | + auto r = decoder_->GetEmptyResult(); | ||
| 299 | + SHERPA_ONNX_CHECK_EQ(r.hyps.size(), 1); | ||
| 300 | + | ||
| 301 | + SHERPA_ONNX_CHECK(stream->GetContextGraph() != nullptr); | ||
| 302 | + r.hyps.begin()->second.context_state = stream->GetContextGraph()->Root(); | ||
| 303 | + | ||
| 304 | + stream->SetKeywordResult(r); | ||
| 305 | + stream->SetStates(model_->GetEncoderInitStates()); | ||
| 306 | + } | ||
| 307 | + | ||
| 308 | + private: | ||
| 309 | + KeywordSpotterConfig config_; | ||
| 310 | + std::vector<std::vector<int32_t>> keywords_id_; | ||
| 311 | + std::vector<float> boost_scores_; | ||
| 312 | + std::vector<float> thresholds_; | ||
| 313 | + std::vector<std::string> keywords_; | ||
| 314 | + ContextGraphPtr keywords_graph_; | ||
| 315 | + std::unique_ptr<OnlineTransducerModel> model_; | ||
| 316 | + std::unique_ptr<TransducerKeywordDecoder> decoder_; | ||
| 317 | + SymbolTable sym_; | ||
| 318 | + int32_t unk_id_ = -1; | ||
| 319 | +}; | ||
| 320 | + | ||
| 321 | +} // namespace sherpa_onnx | ||
| 322 | + | ||
| 323 | +#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_ |
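The branching in `CreateStream(keywords)` above merges stream-specific keywords with the defaults loaded from `--keywords-file`. The keywords string itself packs several keywords separated by "/" (replaced by newlines before parsing), each given as space-separated tokens with optional `:boost` and `#threshold` fields, mirroring the text2token conventions earlier in this diff. A sketch (the token spellings are illustrative only):

    #include <memory>

    #include "sherpa-onnx/csrc/keyword-spotter.h"

    std::unique_ptr<sherpa_onnx::OnlineStream> MakeStream(
        const sherpa_onnx::KeywordSpotter &spotter) {
      // Two keywords separated by "/": the first carries a custom boosting
      // score (:2.0) and a custom triggering threshold (#0.6).
      return spotter.CreateStream("▁HE LL O ▁WORLD :2.0 #0.6/你 好");
    }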
sherpa-onnx/csrc/keyword-spotter.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/keyword-spotter.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023-2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/keyword-spotter.h" | ||
| 6 | + | ||
| 7 | +#include <assert.h> | ||
| 8 | + | ||
| 9 | +#include <algorithm> | ||
| 10 | +#include <fstream> | ||
| 11 | +#include <iomanip> | ||
| 12 | +#include <memory> | ||
| 13 | +#include <sstream> | ||
| 14 | +#include <utility> | ||
| 15 | +#include <vector> | ||
| 16 | + | ||
| 17 | +#include "sherpa-onnx/csrc/keyword-spotter-impl.h" | ||
| 18 | + | ||
| 19 | +namespace sherpa_onnx { | ||
| 20 | + | ||
| 21 | +std::string KeywordResult::AsJsonString() const { | ||
| 22 | + std::ostringstream os; | ||
| 23 | + os << "{"; | ||
| 24 | + os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time | ||
| 25 | + << ", "; | ||
| 26 | + | ||
| 27 | + os << "\"keyword\"" | ||
| 28 | + << ": "; | ||
| 29 | + os << "\"" << keyword << "\"" | ||
| 30 | + << ", "; | ||
| 31 | + | ||
| 32 | + os << "\"" | ||
| 33 | + << "timestamps" | ||
| 34 | + << "\"" | ||
| 35 | + << ": "; | ||
| 36 | + os << "["; | ||
| 37 | + | ||
| 38 | + std::string sep = ""; | ||
| 39 | + for (auto t : timestamps) { | ||
| 40 | + os << sep << std::fixed << std::setprecision(2) << t; | ||
| 41 | + sep = ", "; | ||
| 42 | + } | ||
| 43 | + os << "], "; | ||
| 44 | + | ||
| 45 | + os << "\"" | ||
| 46 | + << "tokens" | ||
| 47 | + << "\"" | ||
| 48 | + << ":"; | ||
| 49 | + os << "["; | ||
| 50 | + | ||
| 51 | + sep = ""; | ||
| 52 | + auto oldFlags = os.flags(); | ||
| 53 | + for (const auto &t : tokens) { | ||
| 54 | + if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) { | ||
| 55 | + const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str()); | ||
| 56 | + os << sep << "\"" | ||
| 57 | + << "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0]) | ||
| 58 | + << ">" | ||
| 59 | + << "\""; | ||
| 60 | + os.flags(oldFlags); | ||
| 61 | + } else { | ||
| 62 | + os << sep << "\"" << t << "\""; | ||
| 63 | + } | ||
| 64 | + sep = ", "; | ||
| 65 | + } | ||
| 66 | + os << "]"; | ||
| 67 | + os << "}"; | ||
| 68 | + | ||
| 69 | + return os.str(); | ||
| 70 | +} | ||
| 71 | + | ||
| 72 | +void KeywordSpotterConfig::Register(ParseOptions *po) { | ||
| 73 | + feat_config.Register(po); | ||
| 74 | + model_config.Register(po); | ||
| 75 | + | ||
| 76 | + po->Register("max-active-paths", &max_active_paths, | ||
| 77 | + "beam size used in modified beam search."); | ||
| 78 | + po->Register("num-trailing-blanks", &num_trailing_blanks, | ||
| 79 | + "The number of trailing blanks should have after the keyword."); | ||
| 80 | + po->Register("keywords-score", &keywords_score, | ||
| 81 | + "The bonus score for each token in context word/phrase."); | ||
| 82 | + po->Register("keywords-threshold", &keywords_threshold, | ||
| 83 | + "The acoustic threshold (probability) to trigger the keywords."); | ||
| 84 | + po->Register( | ||
| 85 | + "keywords-file", &keywords_file, | ||
| 86 | + "The file containing keywords, one word/phrase per line, and for each" | ||
| 87 | + "phrase the bpe/cjkchar are separated by a space. For example: " | ||
| 88 | + "▁HE LL O ▁WORLD" | ||
| 89 | + "你 好 世 界"); | ||
| 90 | +} | ||
| 91 | + | ||
| 92 | +bool KeywordSpotterConfig::Validate() const { | ||
| 93 | + if (keywords_file.empty()) { | ||
| 94 | + SHERPA_ONNX_LOGE("Please provide --keywords-file."); | ||
| 95 | + return false; | ||
| 96 | + } | ||
| 97 | + if (!std::ifstream(keywords_file.c_str()).good()) { | ||
| 98 | + SHERPA_ONNX_LOGE("Keywords file %s does not exist.", keywords_file.c_str()); | ||
| 99 | + return false; | ||
| 100 | + } | ||
| 101 | + | ||
| 102 | + return model_config.Validate(); | ||
| 103 | +} | ||
| 104 | + | ||
| 105 | +std::string KeywordSpotterConfig::ToString() const { | ||
| 106 | + std::ostringstream os; | ||
| 107 | + | ||
| 108 | + os << "KeywordSpotterConfig("; | ||
| 109 | + os << "feat_config=" << feat_config.ToString() << ", "; | ||
| 110 | + os << "model_config=" << model_config.ToString() << ", "; | ||
| 111 | + os << "max_active_paths=" << max_active_paths << ", "; | ||
| 112 | + os << "num_trailing_blanks=" << num_trailing_blanks << ", "; | ||
| 113 | + os << "keywords_score=" << keywords_score << ", "; | ||
| 114 | + os << "keywords_threshold=" << keywords_threshold << ", "; | ||
| 115 | + os << "keywords_file=\"" << keywords_file << "\")"; | ||
| 116 | + | ||
| 117 | + return os.str(); | ||
| 118 | +} | ||
| 119 | + | ||
| 120 | +KeywordSpotter::KeywordSpotter(const KeywordSpotterConfig &config) | ||
| 121 | + : impl_(KeywordSpotterImpl::Create(config)) {} | ||
| 122 | + | ||
| 123 | +#if __ANDROID_API__ >= 9 | ||
| 124 | +KeywordSpotter::KeywordSpotter(AAssetManager *mgr, | ||
| 125 | + const KeywordSpotterConfig &config) | ||
| 126 | + : impl_(KeywordSpotterImpl::Create(mgr, config)) {} | ||
| 127 | +#endif | ||
| 128 | + | ||
| 129 | +KeywordSpotter::~KeywordSpotter() = default; | ||
| 130 | + | ||
| 131 | +std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream() const { | ||
| 132 | + return impl_->CreateStream(); | ||
| 133 | +} | ||
| 134 | + | ||
| 135 | +std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream( | ||
| 136 | + const std::string &keywords) const { | ||
| 137 | + return impl_->CreateStream(keywords); | ||
| 138 | +} | ||
| 139 | + | ||
| 140 | +bool KeywordSpotter::IsReady(OnlineStream *s) const { | ||
| 141 | + return impl_->IsReady(s); | ||
| 142 | +} | ||
| 143 | + | ||
| 144 | +void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const { | ||
| 145 | + impl_->DecodeStreams(ss, n); | ||
| 146 | +} | ||
| 147 | + | ||
| 148 | +KeywordResult KeywordSpotter::GetResult(OnlineStream *s) const { | ||
| 149 | + return impl_->GetResult(s); | ||
| 150 | +} | ||
| 151 | + | ||
| 152 | +} // namespace sherpa_onnx |
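For orientation, here is a hedged end-to-end sketch of the runtime API this file implements: configure the spotter, push audio into a stream, and poll for results. All file paths are placeholders, and `AcceptWaveform()` is assumed from the existing `OnlineStream` interface rather than shown in this diff:

    #include <vector>

    #include "sherpa-onnx/csrc/keyword-spotter.h"

    int main() {
      sherpa_onnx::KeywordSpotterConfig config;
      config.model_config.transducer.encoder = "encoder.onnx";  // placeholder
      config.model_config.transducer.decoder = "decoder.onnx";  // placeholder
      config.model_config.transducer.joiner = "joiner.onnx";    // placeholder
      config.model_config.tokens = "tokens.txt";                // placeholder
      config.keywords_file = "keywords.txt";                    // placeholder

      sherpa_onnx::KeywordSpotter spotter(config);
      auto stream = spotter.CreateStream();

      std::vector<float> samples(16000, 0.0f);  // stand-in for 1 s of 16 kHz audio
      stream->AcceptWaveform(16000, samples.data(),
                             static_cast<int32_t>(samples.size()));

      while (spotter.IsReady(stream.get())) {
        sherpa_onnx::OnlineStream *ss[1] = {stream.get()};
        spotter.DecodeStreams(ss, 1);
        auto r = spotter.GetResult(stream.get());
        if (!r.keyword.empty()) {
          // r.AsJsonString() carries keyword, tokens, timestamps, start_time.
        }
      }
      return 0;
    }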
sherpa-onnx/csrc/keyword-spotter.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/keyword-spotter.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023-2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 17 | +#include "sherpa-onnx/csrc/features.h" | ||
| 18 | +#include "sherpa-onnx/csrc/online-model-config.h" | ||
| 19 | +#include "sherpa-onnx/csrc/online-stream.h" | ||
| 20 | +#include "sherpa-onnx/csrc/online-transducer-model-config.h" | ||
| 21 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 22 | + | ||
| 23 | +namespace sherpa_onnx { | ||
| 24 | + | ||
| 25 | +struct KeywordResult { | ||
| 26 | + /// The triggered keyword. | ||
| 27 | + /// For English, it consists of space-separated words. | ||
| 28 | + /// For Chinese, it consists of Chinese words without spaces. | ||
| 29 | + /// Example 1: "hello world" | ||
| 30 | + /// Example 2: "你好世界" | ||
| 31 | + std::string keyword; | ||
| 32 | + | ||
| 33 | + /// Decoded results at the token level. | ||
| 34 | + /// For instance, for BPE-based models it consists of a list of BPE tokens. | ||
| 35 | + std::vector<std::string> tokens; | ||
| 36 | + | ||
| 37 | + /// timestamps.size() == tokens.size() | ||
| 38 | + /// timestamps[i] records the time in seconds when tokens[i] is decoded. | ||
| 39 | + std::vector<float> timestamps; | ||
| 40 | + | ||
| 41 | + /// Starting time of this segment. | ||
| 42 | + /// When an endpoint is detected, it will change. | ||
| 43 | + float start_time = 0; | ||
| 44 | + | ||
| 45 | + /** Return a JSON string. | ||
| 46 | + * | ||
| 47 | + * The returned string contains: | ||
| 48 | + * { | ||
| 49 | + * "keyword": "The triggered keyword", | ||
| 50 | + * "tokens": [x, x, x], | ||
| 51 | + * "timestamps": [x, x, x], | ||
| 52 | + * "start_time": x, | ||
| 53 | + * } | ||
| 54 | + */ | ||
| 55 | + std::string AsJsonString() const; | ||
| 56 | +}; | ||
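As a concrete illustration, `AsJsonString()` for a hypothetical detection of the keyword HELLO WORLD might return something like the following (all numbers are made up; the exact field order and number formatting depend on the implementation):

```json
{
  "keyword": "HELLO WORLD",
  "tokens": ["▁HE", "LL", "O", "▁WORLD"],
  "timestamps": [0.64, 0.80, 0.88, 1.12],
  "start_time": 0.0
}
```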
| 57 | + | ||
| 58 | +struct KeywordSpotterConfig { | ||
| 59 | + FeatureExtractorConfig feat_config; | ||
| 60 | + OnlineModelConfig model_config; | ||
| 61 | + | ||
| 62 | + int32_t max_active_paths = 4; | ||
| 63 | + | ||
| 64 | + int32_t num_trailing_blanks = 1; | ||
| 65 | + | ||
| 66 | + float keywords_score = 1.0; | ||
| 67 | + | ||
| 68 | + float keywords_threshold = 0.25; | ||
| 69 | + | ||
| 70 | + std::string keywords_file; | ||
| 71 | + | ||
| 72 | + KeywordSpotterConfig() = default; | ||
| 73 | + | ||
| 74 | + KeywordSpotterConfig(const FeatureExtractorConfig &feat_config, | ||
| 75 | + const OnlineModelConfig &model_config, | ||
| 76 | + int32_t max_active_paths, int32_t num_trailing_blanks, | ||
| 77 | + float keywords_score, float keywords_threshold, | ||
| 78 | + const std::string &keywords_file) | ||
| 79 | + : feat_config(feat_config), | ||
| 80 | + model_config(model_config), | ||
| 81 | + max_active_paths(max_active_paths), | ||
| 82 | + num_trailing_blanks(num_trailing_blanks), | ||
| 83 | + keywords_score(keywords_score), | ||
| 84 | + keywords_threshold(keywords_threshold), | ||
| 85 | + keywords_file(keywords_file) {} | ||
| 86 | + | ||
| 87 | + void Register(ParseOptions *po); | ||
| 88 | + bool Validate() const; | ||
| 89 | + | ||
| 90 | + std::string ToString() const; | ||
| 91 | +}; | ||
| 92 | + | ||
| 93 | +class KeywordSpotterImpl; | ||
| 94 | + | ||
| 95 | +class KeywordSpotter { | ||
| 96 | + public: | ||
| 97 | + explicit KeywordSpotter(const KeywordSpotterConfig &config); | ||
| 98 | + | ||
| 99 | +#if __ANDROID_API__ >= 9 | ||
| 100 | + KeywordSpotter(AAssetManager *mgr, const KeywordSpotterConfig &config); | ||
| 101 | +#endif | ||
| 102 | + | ||
| 103 | + ~KeywordSpotter(); | ||
| 104 | + | ||
| 105 | + /** Create a stream for decoding. | ||
| 106 | + * | ||
| 107 | + */ | ||
| 108 | + std::unique_ptr<OnlineStream> CreateStream() const; | ||
| 109 | + | ||
| 110 | + /** Create a stream for decoding. | ||
| 111 | + * | ||
| 112 | + * @param keywords The keywords for this stream. It may contain several | ||
| 113 | + * keywords separated by "/". Within each keyword, the tokens | ||
| 114 | + * (cjkchars or bpe pieces) are separated by spaces (" "). | ||
| 115 | + * For example, the keywords I LOVE YOU and HELLO WORLD look like: | ||
| 116 | + * | ||
| 117 | + * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD" | ||
| 118 | + */ | ||
| 119 | + std::unique_ptr<OnlineStream> CreateStream(const std::string &keywords) const; | ||
| 120 | + | ||
| 121 | + /** | ||
| 122 | + * Return true if the given stream has enough frames for decoding. | ||
| 123 | + * Return false otherwise. | ||
| 124 | + */ | ||
| 125 | + bool IsReady(OnlineStream *s) const; | ||
| 126 | + | ||
| 127 | + /** Decode a single stream. */ | ||
| 128 | + void DecodeStream(OnlineStream *s) const { | ||
| 129 | + OnlineStream *ss[1] = {s}; | ||
| 130 | + DecodeStreams(ss, 1); | ||
| 131 | + } | ||
| 132 | + | ||
| 133 | + /** Decode multiple streams in parallel | ||
| 134 | + * | ||
| 135 | + * @param ss Pointer array containing streams to be decoded. | ||
| 136 | + * @param n Number of streams in `ss`. | ||
| 137 | + */ | ||
| 138 | + void DecodeStreams(OnlineStream **ss, int32_t n) const; | ||
| 139 | + | ||
| 140 | + KeywordResult GetResult(OnlineStream *s) const; | ||
| 141 | + | ||
| 142 | + private: | ||
| 143 | + std::unique_ptr<KeywordSpotterImpl> impl_; | ||
| 144 | +}; | ||
| 145 | + | ||
| 146 | +} // namespace sherpa_onnx | ||
| 147 | + | ||
| 148 | +#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_ |
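Putting the API together, a caller could drive the spotter as in the minimal sketch below (not part of this PR; config setup and error handling are omitted, and `samples` is assumed to hold 16 kHz mono audio):

```cpp
#include <vector>

#include "sherpa-onnx/csrc/keyword-spotter.h"

void RunKws(const sherpa_onnx::KeywordSpotterConfig &config,
            const std::vector<float> &samples) {
  sherpa_onnx::KeywordSpotter spotter(config);

  // Override the default keywords for this stream only.
  auto stream = spotter.CreateStream("▁HE LL O ▁WORLD/你 好 世 界");

  stream->AcceptWaveform(16000, samples.data(),
                         static_cast<int32_t>(samples.size()));
  stream->InputFinished();

  while (spotter.IsReady(stream.get())) {
    spotter.DecodeStream(stream.get());
    auto r = spotter.GetResult(stream.get());
    if (!r.keyword.empty()) {
      // A keyword fired; r.tokens and r.timestamps describe it.
    }
  }
}
```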
| @@ -93,8 +93,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | @@ -93,8 +93,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 93 | Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); | 93 | Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); |
| 94 | // now cur_encoder_out is of shape (num_hyps, joiner_dim) | 94 | // now cur_encoder_out is of shape (num_hyps, joiner_dim) |
| 95 | 95 | ||
| 96 | - Ort::Value logit = model_->RunJoiner( | ||
| 97 | - std::move(cur_encoder_out), View(&decoder_out)); | 96 | + Ort::Value logit = |
| 97 | + model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); | ||
| 98 | 98 | ||
| 99 | float *p_logit = logit.GetTensorMutableData<float>(); | 99 | float *p_logit = logit.GetTensorMutableData<float>(); |
| 100 | LogSoftmax(p_logit, vocab_size, num_hyps); | 100 | LogSoftmax(p_logit, vocab_size, num_hyps); |
| @@ -134,8 +134,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | @@ -134,8 +134,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 134 | if (context_graphs[i] != nullptr) { | 134 | if (context_graphs[i] != nullptr) { |
| 135 | auto context_res = | 135 | auto context_res = |
| 136 | context_graphs[i]->ForwardOneStep(context_state, new_token); | 136 | context_graphs[i]->ForwardOneStep(context_state, new_token); |
| 137 | - context_score = context_res.first; | ||
| 138 | - new_hyp.context_state = context_res.second; | 137 | + context_score = std::get<0>(context_res); |
| 138 | + new_hyp.context_state = std::get<1>(context_res); | ||
| 139 | } | 139 | } |
| 140 | } | 140 | } |
| 141 | 141 |
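Note: the switch from `context_res.first`/`context_res.second` to `std::get<0>`/`std::get<1>` here (and again in the online decoder below) indicates that `ContextGraph::ForwardOneStep` now returns a tuple rather than a pair, presumably so it can also carry the keyword-matching state introduced by this PR.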
| @@ -51,6 +51,25 @@ class OnlineStream::Impl { | @@ -51,6 +51,25 @@ class OnlineStream::Impl { | ||
| 51 | 51 | ||
| 52 | OnlineTransducerDecoderResult &GetResult() { return result_; } | 52 | OnlineTransducerDecoderResult &GetResult() { return result_; } |
| 53 | 53 | ||
| 54 | + void SetKeywordResult(const TransducerKeywordResult &r) { | ||
| 55 | + keyword_result_ = r; | ||
| 56 | + } | ||
| 57 | + TransducerKeywordResult &GetKeywordResult(bool remove_duplicates) { | ||
| 58 | + if (remove_duplicates) { | ||
| 59 | + if (!prev_keyword_result_.timestamps.empty() && | ||
| 60 | + !keyword_result_.timestamps.empty() && | ||
| 61 | + keyword_result_.timestamps[0] <= | ||
| 62 | + prev_keyword_result_.timestamps.back()) { | ||
| 63 | + return empty_keyword_result_; | ||
| 64 | + } else { | ||
| 65 | + prev_keyword_result_ = keyword_result_; | ||
| 66 | + } | ||
| 67 | + return keyword_result_; | ||
| 68 | + } else { | ||
| 69 | + return keyword_result_; | ||
| 70 | + } | ||
| 71 | + } | ||
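The `remove_duplicates` branch above works purely on timestamps: if the first timestamp of the current keyword result does not come after the last timestamp of the previously returned one, the two results overlap in time and an empty result is returned, so the same detection is never reported twice.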
| 72 | + | ||
| 54 | OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; } | 73 | OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; } |
| 55 | 74 | ||
| 56 | void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; } | 75 | void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; } |
| @@ -93,6 +112,9 @@ class OnlineStream::Impl { | @@ -93,6 +112,9 @@ class OnlineStream::Impl { | ||
| 93 | int32_t start_frame_index_ = 0; // never reset | 112 | int32_t start_frame_index_ = 0; // never reset |
| 94 | int32_t segment_ = 0; | 113 | int32_t segment_ = 0; |
| 95 | OnlineTransducerDecoderResult result_; | 114 | OnlineTransducerDecoderResult result_; |
| 115 | + TransducerKeywordResult prev_keyword_result_; | ||
| 116 | + TransducerKeywordResult keyword_result_; | ||
| 117 | + TransducerKeywordResult empty_keyword_result_; | ||
| 96 | OnlineCtcDecoderResult ctc_result_; | 118 | OnlineCtcDecoderResult ctc_result_; |
| 97 | std::vector<Ort::Value> states_; // states for transducer or ctc models | 119 | std::vector<Ort::Value> states_; // states for transducer or ctc models |
| 98 | std::vector<float> paraformer_feat_cache_; | 120 | std::vector<float> paraformer_feat_cache_; |
| @@ -149,6 +171,15 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() { | @@ -149,6 +171,15 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() { | ||
| 149 | return impl_->GetResult(); | 171 | return impl_->GetResult(); |
| 150 | } | 172 | } |
| 151 | 173 | ||
| 174 | +void OnlineStream::SetKeywordResult(const TransducerKeywordResult &r) { | ||
| 175 | + impl_->SetKeywordResult(r); | ||
| 176 | +} | ||
| 177 | + | ||
| 178 | +TransducerKeywordResult &OnlineStream::GetKeywordResult( | ||
| 179 | + bool remove_duplicates /*=false*/) { | ||
| 180 | + return impl_->GetKeywordResult(remove_duplicates); | ||
| 181 | +} | ||
| 182 | + | ||
| 152 | OnlineCtcDecoderResult &OnlineStream::GetCtcResult() { | 183 | OnlineCtcDecoderResult &OnlineStream::GetCtcResult() { |
| 153 | return impl_->GetCtcResult(); | 184 | return impl_->GetCtcResult(); |
| 154 | } | 185 | } |
| @@ -14,9 +14,11 @@ | @@ -14,9 +14,11 @@ | ||
| 14 | #include "sherpa-onnx/csrc/online-ctc-decoder.h" | 14 | #include "sherpa-onnx/csrc/online-ctc-decoder.h" |
| 15 | #include "sherpa-onnx/csrc/online-paraformer-decoder.h" | 15 | #include "sherpa-onnx/csrc/online-paraformer-decoder.h" |
| 16 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" | 16 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" |
| 17 | +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" | ||
| 17 | 18 | ||
| 18 | namespace sherpa_onnx { | 19 | namespace sherpa_onnx { |
| 19 | 20 | ||
| 21 | +class TransducerKeywordResult; | ||
| 20 | class OnlineStream { | 22 | class OnlineStream { |
| 21 | public: | 23 | public: |
| 22 | explicit OnlineStream(const FeatureExtractorConfig &config = {}, | 24 | explicit OnlineStream(const FeatureExtractorConfig &config = {}, |
| @@ -76,6 +78,9 @@ class OnlineStream { | @@ -76,6 +78,9 @@ class OnlineStream { | ||
| 76 | void SetResult(const OnlineTransducerDecoderResult &r); | 78 | void SetResult(const OnlineTransducerDecoderResult &r); |
| 77 | OnlineTransducerDecoderResult &GetResult(); | 79 | OnlineTransducerDecoderResult &GetResult(); |
| 78 | 80 | ||
| 81 | + void SetKeywordResult(const TransducerKeywordResult &r); | ||
| 82 | + TransducerKeywordResult &GetKeywordResult(bool remove_duplicates = false); | ||
| 83 | + | ||
| 79 | void SetCtcResult(const OnlineCtcDecoderResult &r); | 84 | void SetCtcResult(const OnlineCtcDecoderResult &r); |
| 80 | OnlineCtcDecoderResult &GetCtcResult(); | 85 | OnlineCtcDecoderResult &GetCtcResult(); |
| 81 | 86 | ||
| @@ -92,7 +97,7 @@ class OnlineStream { | @@ -92,7 +97,7 @@ class OnlineStream { | ||
| 92 | */ | 97 | */ |
| 93 | const ContextGraphPtr &GetContextGraph() const; | 98 | const ContextGraphPtr &GetContextGraph() const; |
| 94 | 99 | ||
| 95 | - // for streaming parformer | 100 | + // for streaming paraformer |
| 96 | std::vector<float> &GetParaformerFeatCache(); | 101 | std::vector<float> &GetParaformerFeatCache(); |
| 97 | std::vector<float> &GetParaformerEncoderOutCache(); | 102 | std::vector<float> &GetParaformerEncoderOutCache(); |
| 98 | std::vector<float> &GetParaformerAlphaCache(); | 103 | std::vector<float> &GetParaformerAlphaCache(); |
| @@ -75,10 +75,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -75,10 +75,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 75 | encoder_out.GetTensorTypeAndShapeInfo().GetShape(); | 75 | encoder_out.GetTensorTypeAndShapeInfo().GetShape(); |
| 76 | 76 | ||
| 77 | if (encoder_out_shape[0] != result->size()) { | 77 | if (encoder_out_shape[0] != result->size()) { |
| 78 | - fprintf(stderr, | ||
| 79 | - "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n", | ||
| 80 | - static_cast<int32_t>(encoder_out_shape[0]), | ||
| 81 | - static_cast<int32_t>(result->size())); | 78 | + SHERPA_ONNX_LOGE( |
| 79 | + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n", | ||
| 80 | + static_cast<int32_t>(encoder_out_shape[0]), | ||
| 81 | + static_cast<int32_t>(result->size())); | ||
| 82 | exit(-1); | 82 | exit(-1); |
| 83 | } | 83 | } |
| 84 | 84 | ||
| @@ -119,8 +119,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -119,8 +119,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 119 | GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); | 119 | GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); |
| 120 | cur_encoder_out = | 120 | cur_encoder_out = |
| 121 | Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); | 121 | Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); |
| 122 | - Ort::Value logit = model_->RunJoiner( | ||
| 123 | - std::move(cur_encoder_out), View(&decoder_out)); | 122 | + Ort::Value logit = |
| 123 | + model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); | ||
| 124 | 124 | ||
| 125 | float *p_logit = logit.GetTensorMutableData<float>(); | 125 | float *p_logit = logit.GetTensorMutableData<float>(); |
| 126 | LogSoftmax(p_logit, vocab_size, num_hyps); | 126 | LogSoftmax(p_logit, vocab_size, num_hyps); |
| @@ -164,8 +164,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -164,8 +164,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 164 | if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) { | 164 | if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) { |
| 165 | auto context_res = ss[b]->GetContextGraph()->ForwardOneStep( | 165 | auto context_res = ss[b]->GetContextGraph()->ForwardOneStep( |
| 166 | context_state, new_token); | 166 | context_state, new_token); |
| 167 | - context_score = context_res.first; | ||
| 168 | - new_hyp.context_state = context_res.second; | 167 | + context_score = std::get<0>(context_res); |
| 168 | + new_hyp.context_state = std::get<1>(context_res); | ||
| 169 | } | 169 | } |
| 170 | if (lm_) { | 170 | if (lm_) { |
| 171 | lm_->ComputeLMScore(lm_scale_, &new_hyp); | 171 | lm_->ComputeLMScore(lm_scale_, &new_hyp); |
| 1 | +// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023-2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include <stdio.h> | ||
| 6 | + | ||
| 7 | +#include <iomanip> | ||
| 8 | +#include <iostream> | ||
| 9 | +#include <string> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/keyword-spotter.h" | ||
| 13 | +#include "sherpa-onnx/csrc/online-stream.h" | ||
| 14 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 15 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 16 | +#include "sherpa-onnx/csrc/wave-reader.h" | ||
| 17 | + | ||
| 18 | +struct Stream { | ||
| 19 | + std::unique_ptr<sherpa_onnx::OnlineStream> online_stream; | ||
| 20 | + std::string filename; | ||
| 21 | +}; | ||
| 22 | + | ||
| 23 | +int main(int32_t argc, char *argv[]) { | ||
| 24 | + const char *kUsageMessage = R"usage( | ||
| 25 | +Usage: | ||
| 26 | + | ||
| 27 | +(1) Streaming transducer | ||
| 28 | + | ||
| 29 | + ./bin/sherpa-onnx-keyword-spotter \ | ||
| 30 | + --tokens=/path/to/tokens.txt \ | ||
| 31 | + --encoder=/path/to/encoder.onnx \ | ||
| 32 | + --decoder=/path/to/decoder.onnx \ | ||
| 33 | + --joiner=/path/to/joiner.onnx \ | ||
| 34 | + --provider=cpu \ | ||
| 35 | + --num-threads=2 \ | ||
| 36 | + --keywords-file=keywords.txt \ | ||
| 37 | + /path/to/foo.wav [bar.wav foobar.wav ...] | ||
| 38 | + | ||
| 39 | +Note: It supports decoding multiple files in batches. | ||
| 40 | + | ||
| 41 | +Default value for num_threads is 2. | ||
| 42 | +Valid values for provider: cpu (default), cuda, coreml. | ||
| 43 | +foo.wav should be a single-channel, 16-bit PCM encoded wave file; its | ||
| 44 | +sampling rate can be arbitrary and does not need to be 16 kHz. | ||
| 45 | + | ||
| 46 | +Please refer to | ||
| 47 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
| 48 | +for a list of pre-trained models to download. | ||
| 49 | +)usage"; | ||
| 50 | + | ||
| 51 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 52 | + sherpa_onnx::KeywordSpotterConfig config; | ||
| 53 | + | ||
| 54 | + config.Register(&po); | ||
| 55 | + | ||
| 56 | + po.Read(argc, argv); | ||
| 57 | + if (po.NumArgs() < 1) { | ||
| 58 | + po.PrintUsage(); | ||
| 59 | + exit(EXIT_FAILURE); | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + fprintf(stderr, "%s\n", config.ToString().c_str()); | ||
| 63 | + | ||
| 64 | + if (!config.Validate()) { | ||
| 65 | + fprintf(stderr, "Errors in config!\n"); | ||
| 66 | + return -1; | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + sherpa_onnx::KeywordSpotter keyword_spotter(config); | ||
| 70 | + | ||
| 71 | + std::vector<Stream> ss; | ||
| 72 | + | ||
| 73 | + for (int32_t i = 1; i <= po.NumArgs(); ++i) { | ||
| 74 | + const std::string wav_filename = po.GetArg(i); | ||
| 75 | + int32_t sampling_rate = -1; | ||
| 76 | + | ||
| 77 | + bool is_ok = false; | ||
| 78 | + const std::vector<float> samples = | ||
| 79 | + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | ||
| 80 | + | ||
| 81 | + if (!is_ok) { | ||
| 82 | + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | ||
| 83 | + return -1; | ||
| 84 | + } | ||
| 85 | + | ||
| 86 | + auto s = keyword_spotter.CreateStream(); | ||
| 87 | + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | ||
| 88 | + | ||
| 89 | + std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate)); | ||
| 90 | + // Note: We can call AcceptWaveform() multiple times. | ||
| 91 | + s->AcceptWaveform(sampling_rate, tail_paddings.data(), | ||
| 92 | + tail_paddings.size()); | ||
| 93 | + | ||
| 94 | + // Call InputFinished() to indicate that no more audio samples are available. | ||
| 95 | + s->InputFinished(); | ||
| 96 | + ss.push_back({std::move(s), wav_filename}); | ||
| 97 | + } | ||
| 98 | + | ||
| 99 | + std::vector<sherpa_onnx::OnlineStream *> ready_streams; | ||
| 100 | + for (;;) { | ||
| 101 | + ready_streams.clear(); | ||
| 102 | + for (auto &s : ss) { | ||
| 103 | + const auto p_ss = s.online_stream.get(); | ||
| 104 | + if (keyword_spotter.IsReady(p_ss)) { | ||
| 105 | + ready_streams.push_back(p_ss); | ||
| 106 | + } | ||
| 107 | + std::ostringstream os; | ||
| 108 | + const auto r = keyword_spotter.GetResult(p_ss); | ||
| 109 | + if (!r.keyword.empty()) { | ||
| 110 | + os << s.filename << "\n"; | ||
| 111 | + os << r.AsJsonString() << "\n\n"; | ||
| 112 | + fprintf(stderr, "%s", os.str().c_str()); | ||
| 113 | + } | ||
| 114 | + } | ||
| 115 | + | ||
| 116 | + if (ready_streams.empty()) { | ||
| 117 | + break; | ||
| 118 | + } | ||
| 119 | + keyword_spotter.DecodeStreams(ready_streams.data(), ready_streams.size()); | ||
| 120 | + } | ||
| 121 | + return 0; | ||
| 122 | +} |
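A hypothetical run over a single file would print one block per detection, along these lines (every number is made up for illustration; the real layout comes from `KeywordResult::AsJsonString()`):

```
foo.wav
{"keyword": "HELLO WORLD", "tokens": ["▁HE", "LL", "O", "▁WORLD"], "timestamps": [0.64, 0.80, 0.88, 1.12], "start_time": 0.0}
```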
| 1 | +// sherpa-onnx/csrc/transducer-keyword-decoder.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023-2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include <algorithm> | ||
| 6 | +#include <cmath> | ||
| 7 | +#include <utility> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/log.h" | ||
| 11 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 12 | +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +TransducerKeywordResult TransducerKeywordDecoder::GetEmptyResult() const { | ||
| 17 | + int32_t context_size = model_->ContextSize(); | ||
| 18 | + int32_t blank_id = 0; // always 0 | ||
| 19 | + TransducerKeywordResult r; | ||
| 20 | + std::vector<int64_t> blanks(context_size, -1); | ||
| 21 | + blanks.back() = blank_id; | ||
| 22 | + | ||
| 23 | + Hypotheses blank_hyp({{blanks, 0}}); | ||
| 24 | + r.hyps = std::move(blank_hyp); | ||
| 25 | + return r; | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +void TransducerKeywordDecoder::Decode( | ||
| 29 | + Ort::Value encoder_out, OnlineStream **ss, | ||
| 30 | + std::vector<TransducerKeywordResult> *result) { | ||
| 31 | + std::vector<int64_t> encoder_out_shape = | ||
| 32 | + encoder_out.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 33 | + | ||
| 34 | + if (encoder_out_shape[0] != result->size()) { | ||
| 35 | + SHERPA_ONNX_LOGE( | ||
| 36 | + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n", | ||
| 37 | + static_cast<int32_t>(encoder_out_shape[0]), | ||
| 38 | + static_cast<int32_t>(result->size())); | ||
| 39 | + exit(-1); | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | + int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]); | ||
| 43 | + | ||
| 44 | + int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]); | ||
| 45 | + int32_t vocab_size = model_->VocabSize(); | ||
| 46 | + int32_t context_size = model_->ContextSize(); | ||
| 47 | + std::vector<int64_t> blanks(context_size, -1); | ||
| 48 | + blanks.back() = 0; // blank_id is hardcoded to 0 | ||
| 49 | + | ||
| 50 | + std::vector<Hypotheses> cur; | ||
| 51 | + for (auto &r : *result) { | ||
| 52 | + cur.push_back(std::move(r.hyps)); | ||
| 53 | + } | ||
| 54 | + std::vector<Hypothesis> prev; | ||
| 55 | + | ||
| 56 | + for (int32_t t = 0; t != num_frames; ++t) { | ||
| 57 | + // Due to merging paths with identical token sequences, | ||
| 58 | + // not all utterances have "num_active_paths" paths. | ||
| 59 | + auto hyps_row_splits = GetHypsRowSplits(cur); | ||
| 60 | + int32_t num_hyps = | ||
| 61 | + hyps_row_splits.back(); // total num hyps for all utterances | ||
| 62 | + prev.clear(); | ||
| 63 | + for (auto &hyps : cur) { | ||
| 64 | + for (auto &h : hyps) { | ||
| 65 | + prev.push_back(std::move(h.second)); | ||
| 66 | + } | ||
| 67 | + } | ||
| 68 | + cur.clear(); | ||
| 69 | + cur.reserve(batch_size); | ||
| 70 | + | ||
| 71 | + Ort::Value decoder_input = model_->BuildDecoderInput(prev); | ||
| 72 | + Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); | ||
| 73 | + | ||
| 74 | + Ort::Value cur_encoder_out = | ||
| 75 | + GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); | ||
| 76 | + cur_encoder_out = | ||
| 77 | + Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); | ||
| 78 | + Ort::Value logit = | ||
| 79 | + model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); | ||
| 80 | + | ||
| 81 | + float *p_logit = logit.GetTensorMutableData<float>(); | ||
| 82 | + LogSoftmax(p_logit, vocab_size, num_hyps); | ||
| 83 | + | ||
| 84 | + // The acoustic logprobs for current frame | ||
| 85 | + std::vector<float> logprobs(vocab_size * num_hyps); | ||
| 86 | + std::memcpy(logprobs.data(), p_logit, | ||
| 87 | + sizeof(float) * vocab_size * num_hyps); | ||
| 88 | + | ||
| 89 | + // now p_logit contains log_softmax output, we rename it to p_logprob | ||
| 90 | + // to match what it actually contains | ||
| 91 | + float *p_logprob = p_logit; | ||
| 92 | + | ||
| 93 | + // add log_prob of each hypothesis to p_logprob before taking top_k | ||
| 94 | + for (int32_t i = 0; i != num_hyps; ++i) { | ||
| 95 | + float log_prob = prev[i].log_prob; | ||
| 96 | + for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { | ||
| 97 | + *p_logprob += log_prob; | ||
| 98 | + } | ||
| 99 | + } | ||
| 100 | + p_logprob = p_logit; // we changed p_logprob in the above for loop | ||
| 101 | + | ||
| 102 | + for (int32_t b = 0; b != batch_size; ++b) { | ||
| 103 | + int32_t frame_offset = (*result)[b].frame_offset; | ||
| 104 | + int32_t start = hyps_row_splits[b]; | ||
| 105 | + int32_t end = hyps_row_splits[b + 1]; | ||
| 106 | + auto topk = | ||
| 107 | + TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_); | ||
| 108 | + | ||
| 109 | + Hypotheses hyps; | ||
| 110 | + for (auto k : topk) { | ||
| 111 | + int32_t hyp_index = k / vocab_size + start; | ||
| 112 | + int32_t new_token = k % vocab_size; | ||
| 113 | + | ||
| 114 | + Hypothesis new_hyp = prev[hyp_index]; | ||
| 115 | + float context_score = 0; | ||
| 116 | + auto context_state = new_hyp.context_state; | ||
| 117 | + | ||
| 118 | + // blank is hardcoded to 0 | ||
| 119 | + // also, it treats unk as blank | ||
| 120 | + if (new_token != 0 && new_token != unk_id_) { | ||
| 121 | + new_hyp.ys.push_back(new_token); | ||
| 122 | + new_hyp.timestamps.push_back(t + frame_offset); | ||
| 123 | + new_hyp.ys_probs.push_back( | ||
| 124 | + exp(logprobs[hyp_index * vocab_size + new_token])); | ||
| 125 | + | ||
| 126 | + new_hyp.num_trailing_blanks = 0; | ||
| 127 | + auto context_res = ss[b]->GetContextGraph()->ForwardOneStep( | ||
| 128 | + context_state, new_token); | ||
| 129 | + context_score = std::get<0>(context_res); | ||
| 130 | + new_hyp.context_state = std::get<1>(context_res); | ||
| 131 | + // Start matching from the start state, forget the decoder history. | ||
| 132 | + if (new_hyp.context_state->token == -1) { | ||
| 133 | + new_hyp.ys = blanks; | ||
| 134 | + new_hyp.timestamps.clear(); | ||
| 135 | + new_hyp.ys_probs.clear(); | ||
| 136 | + } | ||
| 137 | + } else { | ||
| 138 | + ++new_hyp.num_trailing_blanks; | ||
| 139 | + } | ||
| 140 | + new_hyp.log_prob = p_logprob[k] + context_score; | ||
| 141 | + hyps.Add(std::move(new_hyp)); | ||
| 142 | + } // for (auto k : topk) | ||
| 143 | + | ||
| 144 | + auto best_hyp = hyps.GetMostProbable(false); | ||
| 145 | + | ||
| 146 | + auto status = ss[b]->GetContextGraph()->IsMatched(best_hyp.context_state); | ||
| 147 | + bool matched = std::get<0>(status); | ||
| 148 | + const ContextState *matched_state = std::get<1>(status); | ||
| 149 | + | ||
| 150 | + if (matched) { | ||
| 151 | + float ys_prob = 0.0; | ||
| 152 | + int32_t length = best_hyp.ys_probs.size(); | ||
| 153 | + for (int32_t i = length - matched_state->level; i != length; ++i) { | ||
| 154 | + ys_prob += best_hyp.ys_probs[i]; | ||
| 155 | + } | ||
| 156 | + ys_prob /= matched_state->level; | ||
| 157 | + if (best_hyp.num_trailing_blanks > num_trailing_blanks_ && | ||
| 158 | + ys_prob >= matched_state->ac_threshold) { | ||
| 159 | + auto &r = (*result)[b]; | ||
| 160 | + r.tokens = {best_hyp.ys.end() - matched_state->level, | ||
| 161 | + best_hyp.ys.end()}; | ||
| 162 | + r.timestamps = {best_hyp.timestamps.end() - matched_state->level, | ||
| 163 | + best_hyp.timestamps.end()}; | ||
| 164 | + r.keyword = matched_state->phrase; | ||
| 165 | + | ||
| 166 | + hyps = Hypotheses({{blanks, 0, ss[b]->GetContextGraph()->Root()}}); | ||
| 167 | + } | ||
| 168 | + } | ||
| 169 | + cur.push_back(std::move(hyps)); | ||
| 170 | + p_logprob += (end - start) * vocab_size; | ||
| 171 | + } // for (int32_t b = 0; b != batch_size; ++b) | ||
| 172 | + } | ||
| 173 | + | ||
| 174 | + for (int32_t b = 0; b != batch_size; ++b) { | ||
| 175 | + auto &hyps = cur[b]; | ||
| 176 | + auto best_hyp = hyps.GetMostProbable(false); | ||
| 177 | + auto &r = (*result)[b]; | ||
| 178 | + r.hyps = std::move(hyps); | ||
| 179 | + r.num_trailing_blanks = best_hyp.num_trailing_blanks; | ||
| 180 | + r.frame_offset += num_frames; | ||
| 181 | + } | ||
| 182 | +} | ||
| 183 | + | ||
| 184 | +} // namespace sherpa_onnx |
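To summarize the trigger logic implemented above: a keyword fires only when (1) the best hypothesis has fully matched a phrase in the context graph, (2) it has accumulated more than `num_trailing_blanks_` trailing blanks, and (3) the average per-token acoustic probability over the matched phrase's last `level` tokens, ys_prob = (1/level) * sum_i p(token_i), is at least the phrase's `ac_threshold`. After a detection, the hypothesis set for that stream is reset to a single blank hypothesis rooted at the context graph's start state, so matching starts afresh.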
| 1 | +// sherpa-onnx/csrc/transducer-keyword-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023-2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/online-stream.h" | ||
| 13 | +#include "sherpa-onnx/csrc/online-transducer-model.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +struct TransducerKeywordResult { | ||
| 18 | + /// Number of frames after subsampling we have decoded so far | ||
| 19 | + int32_t frame_offset = 0; | ||
| 20 | + | ||
| 21 | + /// The decoded token IDs for keywords | ||
| 22 | + std::vector<int64_t> tokens; | ||
| 23 | + | ||
| 24 | + /// The triggered keyword | ||
| 25 | + std::string keyword; | ||
| 26 | + | ||
| 27 | + /// number of trailing blank frames decoded so far | ||
| 28 | + int32_t num_trailing_blanks = 0; | ||
| 29 | + | ||
| 30 | + /// timestamps[i] contains the output frame index where tokens[i] is decoded. | ||
| 31 | + std::vector<int32_t> timestamps; | ||
| 32 | + | ||
| 33 | + // used only in modified beam_search | ||
| 34 | + Hypotheses hyps; | ||
| 35 | +}; | ||
| 36 | + | ||
| 37 | +class TransducerKeywordDecoder { | ||
| 38 | + public: | ||
| 39 | + TransducerKeywordDecoder(OnlineTransducerModel *model, | ||
| 40 | + int32_t max_active_paths, | ||
| 41 | + int32_t num_trailing_blanks, int32_t unk_id) | ||
| 42 | + : model_(model), | ||
| 43 | + max_active_paths_(max_active_paths), | ||
| 44 | + num_trailing_blanks_(num_trailing_blanks), | ||
| 45 | + unk_id_(unk_id) {} | ||
| 46 | + | ||
| 47 | + TransducerKeywordResult GetEmptyResult() const; | ||
| 48 | + | ||
| 49 | + void Decode(Ort::Value encoder_out, OnlineStream **ss, | ||
| 50 | + std::vector<TransducerKeywordResult> *result); | ||
| 51 | + | ||
| 52 | + private: | ||
| 53 | + OnlineTransducerModel *model_; // Not owned | ||
| 54 | + | ||
| 55 | + int32_t max_active_paths_; | ||
| 56 | + int32_t num_trailing_blanks_; | ||
| 57 | + int32_t unk_id_; | ||
| 58 | +}; | ||
| 59 | + | ||
| 60 | +} // namespace sherpa_onnx | ||
| 61 | + | ||
| 62 | +#endif // SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_ |
| @@ -15,16 +15,31 @@ | @@ -15,16 +15,31 @@ | ||
| 15 | 15 | ||
| 16 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 17 | 17 | ||
| 18 | -bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, | ||
| 19 | - std::vector<std::vector<int32_t>> *hotwords) { | ||
| 20 | - hotwords->clear(); | ||
| 21 | - std::vector<int32_t> tmp; | 18 | +static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, |
| 19 | + std::vector<std::vector<int32_t>> *ids, | ||
| 20 | + std::vector<std::string> *phrases, | ||
| 21 | + std::vector<float> *scores, | ||
| 22 | + std::vector<float> *thresholds) { | ||
| 23 | + SHERPA_ONNX_CHECK(ids != nullptr); | ||
| 24 | + ids->clear(); | ||
| 25 | + | ||
| 26 | + std::vector<int32_t> tmp_ids; | ||
| 27 | + std::vector<float> tmp_scores; | ||
| 28 | + std::vector<float> tmp_thresholds; | ||
| 29 | + std::vector<std::string> tmp_phrases; | ||
| 30 | + | ||
| 22 | std::string line; | 31 | std::string line; |
| 23 | std::string word; | 32 | std::string word; |
| 33 | + bool has_scores = false; | ||
| 34 | + bool has_thresholds = false; | ||
| 35 | + bool has_phrases = false; | ||
| 24 | 36 | ||
| 25 | while (std::getline(is, line)) { | 37 | while (std::getline(is, line)) { |
| 38 | + float score = 0; | ||
| 39 | + float threshold = 0; | ||
| 40 | + std::string phrase = ""; | ||
| 41 | + | ||
| 26 | std::istringstream iss(line); | 42 | std::istringstream iss(line); |
| 27 | - std::vector<std::string> syms; | ||
| 28 | while (iss >> word) { | 43 | while (iss >> word) { |
| 29 | if (word.size() >= 3) { | 44 | if (word.size() >= 3) { |
| 30 | // For BPE-based models, we replace ▁ with a space | 45 | // For BPE-based models, we replace ▁ with a space |
| @@ -35,20 +50,72 @@ bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, | @@ -35,20 +50,72 @@ bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, | ||
| 35 | } | 50 | } |
| 36 | } | 51 | } |
| 37 | if (symbol_table.contains(word)) { | 52 | if (symbol_table.contains(word)) { |
| 38 | - int32_t number = symbol_table[word]; | ||
| 39 | - tmp.push_back(number); | 53 | + int32_t id = symbol_table[word]; |
| 54 | + tmp_ids.push_back(id); | ||
| 40 | } else { | 55 | } else { |
| 41 | - SHERPA_ONNX_LOGE( | ||
| 42 | - "Cannot find ID for hotword %s at line: %s. (Hint: words on " | ||
| 43 | - "the " | ||
| 44 | - "same line are separated by spaces)", | ||
| 45 | - word.c_str(), line.c_str()); | ||
| 46 | - return false; | 56 | + switch (word[0]) { |
| 57 | + case ':': // boosting score for current keyword | ||
| 58 | + score = std::stof(word.substr(1)); | ||
| 59 | + has_scores = true; | ||
| 60 | + break; | ||
| 61 | + case '#': // triggering threshold (probability) for current keyword | ||
| 62 | + threshold = std::stof(word.substr(1)); | ||
| 63 | + has_thresholds = true; | ||
| 64 | + break; | ||
| 65 | + case '@': // the original keyword string | ||
| 66 | + phrase = word.substr(1); | ||
| 67 | + has_phrases = true; | ||
| 68 | + break; | ||
| 69 | + default: | ||
| 70 | + SHERPA_ONNX_LOGE( | ||
| 71 | + "Cannot find ID for token %s at line: %s. (Hint: words on " | ||
| 72 | + "the same line are separated by spaces)", | ||
| 73 | + word.c_str(), line.c_str()); | ||
| 74 | + return false; | ||
| 75 | + } | ||
| 47 | } | 76 | } |
| 48 | } | 77 | } |
| 49 | - hotwords->push_back(std::move(tmp)); | 78 | + ids->push_back(std::move(tmp_ids)); |
| 79 | + tmp_scores.push_back(score); | ||
| 80 | + tmp_phrases.push_back(phrase); | ||
| 81 | + tmp_thresholds.push_back(threshold); | ||
| 82 | + } | ||
| 83 | + if (scores != nullptr) { | ||
| 84 | + if (has_scores) { | ||
| 85 | + scores->swap(tmp_scores); | ||
| 86 | + } else { | ||
| 87 | + scores->clear(); | ||
| 88 | + } | ||
| 89 | + } | ||
| 90 | + if (phrases != nullptr) { | ||
| 91 | + if (has_phrases) { | ||
| 92 | + *phrases = std::move(tmp_phrases); | ||
| 93 | + } else { | ||
| 94 | + phrases->clear(); | ||
| 95 | + } | ||
| 96 | + } | ||
| 97 | + if (thresholds != nullptr) { | ||
| 98 | + if (has_thresholds) { | ||
| 99 | + thresholds->swap(tmp_thresholds); | ||
| 100 | + } else { | ||
| 101 | + thresholds->clear(); | ||
| 102 | + } | ||
| 50 | } | 103 | } |
| 51 | return true; | 104 | return true; |
| 52 | } | 105 | } |
| 53 | 106 | ||
| 107 | +bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, | ||
| 108 | + std::vector<std::vector<int32_t>> *hotwords) { | ||
| 109 | + return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr); | ||
| 110 | +} | ||
| 111 | + | ||
| 112 | +bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, | ||
| 113 | + std::vector<std::vector<int32_t>> *keywords_id, | ||
| 114 | + std::vector<std::string> *keywords, | ||
| 115 | + std::vector<float> *boost_scores, | ||
| 116 | + std::vector<float> *threshold) { | ||
| 117 | + return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores, | ||
| 118 | + threshold); | ||
| 119 | +} | ||
| 120 | + | ||
| 54 | } // namespace sherpa_onnx | 121 | } // namespace sherpa_onnx |
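For example, a hypothetical input line `▁HE LL O ▁WORLD :1.5 #0.4 @hello-world` would yield the token ids of `▁HE LL O ▁WORLD` in `ids`, 1.5 in `scores`, 0.4 in `thresholds`, and `hello-world` in `phrases`. A field defaults to 0 (or to an empty phrase) on lines that omit it, and if no line in the whole stream supplies a given field, the corresponding output vector is cleared instead.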
| @@ -26,7 +26,32 @@ namespace sherpa_onnx { | @@ -26,7 +26,32 @@ namespace sherpa_onnx { | ||
| 26 | * otherwise returns false. | 26 | * otherwise returns false. |
| 27 | */ | 27 | */ |
| 28 | bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, | 28 | bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, |
| 29 | - std::vector<std::vector<int32_t>> *hotwords); | 29 | + std::vector<std::vector<int32_t>> *hotwords_id); |
| 30 | + | ||
| 31 | +/* Encode the keywords in an input stream into token ids. | ||
| 32 | + * | ||
| 33 | + * @param is The input stream. It contains several lines, one keyword per | ||
| 34 | + * line. For each keyword, the tokens (cjkchar or bpe) are separated | ||
| 35 | + * by spaces; a line might also contain a boosting score (starting | ||
| 36 | + * with :), a triggering threshold (starting with #) and the keyword | ||
| 37 | + * string (starting with @). | ||
| 38 | + * @param symbol_table The tokens table mapping symbols to ids. All the symbols | ||
| 39 | + * in the stream should be in the symbol_table, if not this | ||
| 40 | + * function returns false. | ||
| 41 | + * | ||
| 42 | + * @param keywords_id The encoded ids to be written to. | ||
| 43 | + * @param keywords The original keyword string to be written to. | ||
| 44 | + * @param boost_scores The boosting score for each keyword to be written to. | ||
| 45 | + * @param threshold The triggering threshold for each keyword to be written to. | ||
| 46 | + * | ||
| 47 | + * @return If all the symbols from ``is`` are in the symbol_table, returns true; | ||
| 48 | + * otherwise returns false. | ||
| 49 | + */ | ||
| 50 | +bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, | ||
| 51 | + std::vector<std::vector<int32_t>> *keywords_id, | ||
| 52 | + std::vector<std::string> *keywords, | ||
| 53 | + std::vector<float> *boost_scores, | ||
| 54 | + std::vector<float> *threshold); | ||
| 30 | 55 | ||
| 31 | } // namespace sherpa_onnx | 56 | } // namespace sherpa_onnx |
| 32 | 57 |
| @@ -21,6 +21,7 @@ | @@ -21,6 +21,7 @@ | ||
| 21 | #include "android/asset_manager_jni.h" | 21 | #include "android/asset_manager_jni.h" |
| 22 | #endif | 22 | #endif |
| 23 | 23 | ||
| 24 | +#include "sherpa-onnx/csrc/keyword-spotter.h" | ||
| 24 | #include "sherpa-onnx/csrc/macros.h" | 25 | #include "sherpa-onnx/csrc/macros.h" |
| 25 | #include "sherpa-onnx/csrc/offline-recognizer.h" | 26 | #include "sherpa-onnx/csrc/offline-recognizer.h" |
| 26 | #include "sherpa-onnx/csrc/offline-tts.h" | 27 | #include "sherpa-onnx/csrc/offline-tts.h" |
| @@ -140,6 +141,73 @@ class SherpaOnnxVad { | @@ -140,6 +141,73 @@ class SherpaOnnxVad { | ||
| 140 | VoiceActivityDetector vad_; | 141 | VoiceActivityDetector vad_; |
| 141 | }; | 142 | }; |
| 142 | 143 | ||
| 144 | +class SherpaOnnxKws { | ||
| 145 | + public: | ||
| 146 | +#if __ANDROID_API__ >= 9 | ||
| 147 | + SherpaOnnxKws(AAssetManager *mgr, const KeywordSpotterConfig &config) | ||
| 148 | + : keyword_spotter_(mgr, config), | ||
| 149 | + stream_(keyword_spotter_.CreateStream()) {} | ||
| 150 | +#endif | ||
| 151 | + | ||
| 152 | + explicit SherpaOnnxKws(const KeywordSpotterConfig &config) | ||
| 153 | + : keyword_spotter_(config), stream_(keyword_spotter_.CreateStream()) {} | ||
| 154 | + | ||
| 155 | + void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) { | ||
| 156 | + if (input_sample_rate_ == -1) { | ||
| 157 | + input_sample_rate_ = sample_rate; | ||
| 158 | + } | ||
| 159 | + | ||
| 160 | + stream_->AcceptWaveform(sample_rate, samples, n); | ||
| 161 | + } | ||
| 162 | + | ||
| 163 | + void InputFinished() const { | ||
| 164 | + std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0); | ||
| 165 | + stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(), | ||
| 166 | + tail_padding.size()); | ||
| 167 | + stream_->InputFinished(); | ||
| 168 | + } | ||
| 169 | + | ||
| 170 | + // If keywords is an empty string, it just recreates the decoding stream and | ||
| 171 | + // always returns true in this case. | ||
| 172 | + // If keywords is not empty, it will create a new decoding stream with | ||
| 173 | + // the given keywords appended to the default keywords. | ||
| 174 | + // Return false if errors occurred when adding keywords, true otherwise. | ||
| 175 | + bool Reset(const std::string &keywords = {}) { | ||
| 176 | + if (keywords.empty()) { | ||
| 177 | + stream_ = keyword_spotter_.CreateStream(); | ||
| 178 | + return true; | ||
| 179 | + } else { | ||
| 180 | + auto stream = keyword_spotter_.CreateStream(keywords); | ||
| 181 | + // If setting the new keywords failed, stream_ is not updated. | ||
| 182 | + if (stream == nullptr) { | ||
| 183 | + return false; | ||
| 184 | + } else { | ||
| 185 | + stream_ = std::move(stream); | ||
| 186 | + return true; | ||
| 187 | + } | ||
| 188 | + } | ||
| 189 | + } | ||
| 190 | + | ||
| 191 | + std::string GetKeyword() const { | ||
| 192 | + auto result = keyword_spotter_.GetResult(stream_.get()); | ||
| 193 | + return result.keyword; | ||
| 194 | + } | ||
| 195 | + | ||
| 196 | + std::vector<std::string> GetTokens() const { | ||
| 197 | + auto result = keyword_spotter_.GetResult(stream_.get()); | ||
| 198 | + return result.tokens; | ||
| 199 | + } | ||
| 200 | + | ||
| 201 | + bool IsReady() const { return keyword_spotter_.IsReady(stream_.get()); } | ||
| 202 | + | ||
| 203 | + void Decode() const { keyword_spotter_.DecodeStream(stream_.get()); } | ||
| 204 | + | ||
| 205 | + private: | ||
| 206 | + KeywordSpotter keyword_spotter_; | ||
| 207 | + std::unique_ptr<OnlineStream> stream_; | ||
| 208 | + int32_t input_sample_rate_ = -1; | ||
| 209 | +}; | ||
| 210 | + | ||
| 143 | static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { | 211 | static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { |
| 144 | OnlineRecognizerConfig ans; | 212 | OnlineRecognizerConfig ans; |
| 145 | 213 | ||
| @@ -457,6 +525,103 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { | @@ -457,6 +525,103 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { | ||
| 457 | return ans; | 525 | return ans; |
| 458 | } | 526 | } |
| 459 | 527 | ||
| 528 | +static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) { | ||
| 529 | + KeywordSpotterConfig ans; | ||
| 530 | + | ||
| 531 | + jclass cls = env->GetObjectClass(config); | ||
| 532 | + jfieldID fid; | ||
| 533 | + | ||
| 534 | + // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html | ||
| 535 | + // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html | ||
| 536 | + | ||
| 537 | + //---------- decoding ---------- | ||
| 538 | + fid = env->GetFieldID(cls, "maxActivePaths", "I"); | ||
| 539 | + ans.max_active_paths = env->GetIntField(config, fid); | ||
| 540 | + | ||
| 541 | + fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;"); | ||
| 542 | + jstring s = (jstring)env->GetObjectField(config, fid); | ||
| 543 | + const char *p = env->GetStringUTFChars(s, nullptr); | ||
| 544 | + ans.keywords_file = p; | ||
| 545 | + env->ReleaseStringUTFChars(s, p); | ||
| 546 | + | ||
| 547 | + fid = env->GetFieldID(cls, "keywordsScore", "F"); | ||
| 548 | + ans.keywords_score = env->GetFloatField(config, fid); | ||
| 549 | + | ||
| 550 | + fid = env->GetFieldID(cls, "keywordsThreshold", "F"); | ||
| 551 | + ans.keywords_threshold = env->GetFloatField(config, fid); | ||
| 552 | + | ||
| 553 | + fid = env->GetFieldID(cls, "numTrailingBlanks", "I"); | ||
| 554 | + ans.num_trailing_blanks = env->GetIntField(config, fid); | ||
| 555 | + | ||
| 556 | + //---------- feat config ---------- | ||
| 557 | + fid = env->GetFieldID(cls, "featConfig", | ||
| 558 | + "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); | ||
| 559 | + jobject feat_config = env->GetObjectField(config, fid); | ||
| 560 | + jclass feat_config_cls = env->GetObjectClass(feat_config); | ||
| 561 | + | ||
| 562 | + fid = env->GetFieldID(feat_config_cls, "sampleRate", "I"); | ||
| 563 | + ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid); | ||
| 564 | + | ||
| 565 | + fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); | ||
| 566 | + ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); | ||
| 567 | + | ||
| 568 | + //---------- model config ---------- | ||
| 569 | + fid = env->GetFieldID(cls, "modelConfig", | ||
| 570 | + "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;"); | ||
| 571 | + jobject model_config = env->GetObjectField(config, fid); | ||
| 572 | + jclass model_config_cls = env->GetObjectClass(model_config); | ||
| 573 | + | ||
| 574 | + // transducer | ||
| 575 | + fid = env->GetFieldID(model_config_cls, "transducer", | ||
| 576 | + "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); | ||
| 577 | + jobject transducer_config = env->GetObjectField(model_config, fid); | ||
| 578 | + jclass transducer_config_cls = env->GetObjectClass(transducer_config); | ||
| 579 | + | ||
| 580 | + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;"); | ||
| 581 | + s = (jstring)env->GetObjectField(transducer_config, fid); | ||
| 582 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 583 | + ans.model_config.transducer.encoder = p; | ||
| 584 | + env->ReleaseStringUTFChars(s, p); | ||
| 585 | + | ||
| 586 | + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;"); | ||
| 587 | + s = (jstring)env->GetObjectField(transducer_config, fid); | ||
| 588 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 589 | + ans.model_config.transducer.decoder = p; | ||
| 590 | + env->ReleaseStringUTFChars(s, p); | ||
| 591 | + | ||
| 592 | + fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;"); | ||
| 593 | + s = (jstring)env->GetObjectField(transducer_config, fid); | ||
| 594 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 595 | + ans.model_config.transducer.joiner = p; | ||
| 596 | + env->ReleaseStringUTFChars(s, p); | ||
| 597 | + | ||
| 598 | + fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); | ||
| 599 | + s = (jstring)env->GetObjectField(model_config, fid); | ||
| 600 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 601 | + ans.model_config.tokens = p; | ||
| 602 | + env->ReleaseStringUTFChars(s, p); | ||
| 603 | + | ||
| 604 | + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); | ||
| 605 | + ans.model_config.num_threads = env->GetIntField(model_config, fid); | ||
| 606 | + | ||
| 607 | + fid = env->GetFieldID(model_config_cls, "debug", "Z"); | ||
| 608 | + ans.model_config.debug = env->GetBooleanField(model_config, fid); | ||
| 609 | + | ||
| 610 | + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); | ||
| 611 | + s = (jstring)env->GetObjectField(model_config, fid); | ||
| 612 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 613 | + ans.model_config.provider = p; | ||
| 614 | + env->ReleaseStringUTFChars(s, p); | ||
| 615 | + | ||
| 616 | + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); | ||
| 617 | + s = (jstring)env->GetObjectField(model_config, fid); | ||
| 618 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 619 | + ans.model_config.model_type = p; | ||
| 620 | + env->ReleaseStringUTFChars(s, p); | ||
| 621 | + | ||
| 622 | + return ans; | ||
| 623 | +} | ||
| 624 | + | ||
| 460 | static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) { | 625 | static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) { |
| 461 | VadModelConfig ans; | 626 | VadModelConfig ans; |
| 462 | 627 | ||
| @@ -1013,7 +1178,124 @@ JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens( | @@ -1013,7 +1178,124 @@ JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens( | ||
| 1013 | jclass stringClass = env->FindClass("java/lang/String"); | 1178 | jclass stringClass = env->FindClass("java/lang/String"); |
| 1014 | 1179 | ||
| 1015 | // convert C++ list into jni string array | 1180 | // convert C++ list into jni string array |
| 1016 | - jobjectArray result = env->NewObjectArray(size, stringClass, NULL); | 1181 | + jobjectArray result = env->NewObjectArray(size, stringClass, nullptr); |
| 1182 | + for (int32_t i = 0; i < size; i++) { | ||
| 1183 | + // Convert the C++ string to a C string | ||
| 1184 | + const char *cstr = tokens[i].c_str(); | ||
| 1185 | + | ||
| 1186 | + // Convert the C string to a jstring | ||
| 1187 | + jstring jstr = env->NewStringUTF(cstr); | ||
| 1188 | + | ||
| 1189 | + // Set the array element | ||
| 1190 | + env->SetObjectArrayElement(result, i, jstr); | ||
| 1191 | + } | ||
| 1192 | + | ||
| 1193 | + return result; | ||
| 1194 | +} | ||
| 1195 | + | ||
| 1196 | +SHERPA_ONNX_EXTERN_C | ||
| 1197 | +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_new( | ||
| 1198 | + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { | ||
| 1199 | +#if __ANDROID_API__ >= 9 | ||
| 1200 | + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); | ||
| 1201 | + if (!mgr) { | ||
| 1202 | + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); | ||
| 1203 | + } | ||
| 1204 | +#endif | ||
| 1205 | + auto config = sherpa_onnx::GetKwsConfig(env, _config); | ||
| 1206 | + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); | ||
| 1207 | + auto model = new sherpa_onnx::SherpaOnnxKws( | ||
| 1208 | +#if __ANDROID_API__ >= 9 | ||
| 1209 | + mgr, | ||
| 1210 | +#endif | ||
| 1211 | + config); | ||
| 1212 | + | ||
| 1213 | + return (jlong)model; | ||
| 1214 | +} | ||
| 1215 | + | ||
| 1216 | +SHERPA_ONNX_EXTERN_C | ||
| 1217 | +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_newFromFile( | ||
| 1218 | + JNIEnv *env, jobject /*obj*/, jobject _config) { | ||
| 1219 | + auto config = sherpa_onnx::GetKwsConfig(env, _config); | ||
| 1220 | + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); | ||
| 1221 | + auto model = new sherpa_onnx::SherpaOnnxKws(config); | ||
| 1222 | + | ||
| 1223 | + return (jlong)model; | ||
| 1224 | +} | ||
| 1225 | + | ||
| 1226 | +SHERPA_ONNX_EXTERN_C | ||
| 1227 | +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_delete( | ||
| 1228 | + JNIEnv *env, jobject /*obj*/, jlong ptr) { | ||
| 1229 | + delete reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr); | ||
| 1230 | +} | ||
| 1231 | + | ||
| 1232 | +SHERPA_ONNX_EXTERN_C | ||
| 1233 | +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_isReady( | ||
| 1234 | + JNIEnv *env, jobject /*obj*/, jlong ptr) { | ||
| 1235 | + auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr); | ||
| 1236 | + return model->IsReady(); | ||
| 1237 | +} | ||
| 1238 | + | ||
| 1239 | +SHERPA_ONNX_EXTERN_C | ||
| 1240 | +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_decode( | ||
| 1241 | + JNIEnv *env, jobject /*obj*/, jlong ptr) { | ||
| 1242 | + auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr); | ||
| 1243 | + model->Decode(); | ||
| 1244 | +} | ||
| 1245 | + | ||
| 1246 | +SHERPA_ONNX_EXTERN_C | ||
| 1247 | +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_acceptWaveform( | ||
| 1248 | + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, | ||
| 1249 | + jint sample_rate) { | ||
| 1250 | + auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr); | ||
| 1251 | + | ||
| 1252 | + jfloat *p = env->GetFloatArrayElements(samples, nullptr); | ||
| 1253 | + jsize n = env->GetArrayLength(samples); | ||
| 1254 | + | ||
| 1255 | + model->AcceptWaveform(sample_rate, p, n); | ||
| 1256 | + | ||
| 1257 | + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); | ||
| 1258 | +} | ||
| 1259 | + | ||
| 1260 | +SHERPA_ONNX_EXTERN_C | ||
| 1261 | +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_inputFinished( | ||
| 1262 | + JNIEnv *env, jobject /*obj*/, jlong ptr) { | ||
| 1263 | + reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->InputFinished(); | ||
| 1264 | +} | ||
| 1265 | + | ||
| 1266 | +SHERPA_ONNX_EXTERN_C | ||
| 1267 | +JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getKeyword( | ||
| 1268 | + JNIEnv *env, jobject /*obj*/, jlong ptr) { | ||
| 1269 | + // see | ||
| 1270 | + // https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni | ||
| 1271 | + auto text = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetKeyword(); | ||
| 1272 | + return env->NewStringUTF(text.c_str()); | ||
| 1273 | +} | ||
| 1274 | + | ||
| 1275 | +SHERPA_ONNX_EXTERN_C | ||
| 1276 | +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_reset( | ||
| 1277 | + JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) { | ||
| 1278 | + const char *p_keywords = env->GetStringUTFChars(keywords, nullptr); | ||
| 1279 | + | ||
| 1280 | + std::string keywords_str = p_keywords; | ||
| 1281 | + | ||
| 1282 | + bool status = | ||
| 1283 | + reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->Reset(keywords_str); | ||
| 1284 | + env->ReleaseStringUTFChars(keywords, p_keywords); | ||
| 1285 | + return status; | ||
| 1286 | +} | ||
| 1287 | + | ||
| 1288 | +SHERPA_ONNX_EXTERN_C | ||
| 1289 | +JNIEXPORT jobjectArray JNICALL | ||
| 1290 | +Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getTokens(JNIEnv *env, jobject /*obj*/, | ||
| 1291 | + jlong ptr) { | ||
| 1292 | + auto tokens = | ||
| 1293 | + reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetTokens(); | ||
| 1294 | + int32_t size = tokens.size(); | ||
| 1295 | + jclass stringClass = env->FindClass("java/lang/String"); | ||
| 1296 | + | ||
| 1297 | + // convert C++ list into jni string array | ||
| 1298 | + jobjectArray result = env->NewObjectArray(size, stringClass, nullptr); | ||
| 1017 | for (int32_t i = 0; i < size; i++) { | 1299 | for (int32_t i = 0; i < size; i++) { |
| 1018 | // Convert the C++ string to a C string | 1300 | // Convert the C++ string to a C string |
| 1019 | const char *cstr = tokens[i].c_str(); | 1301 | const char *cstr = tokens[i].c_str(); |
| @@ -28,9 +28,14 @@ def cli(): | @@ -28,9 +28,14 @@ def cli(): | ||
| 28 | ) | 28 | ) |
| 29 | @click.option( | 29 | @click.option( |
| 30 | "--tokens-type", | 30 | "--tokens-type", |
| 31 | - type=str, | 31 | + type=click.Choice( |
| 32 | + ["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"], case_sensitive=True | ||
| 33 | + ), | ||
| 32 | required=True, | 34 | required=True, |
| 33 | - help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe", | 35 | + help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin. |
| 36 | + fpinyin means full pinyin: each cjkchar has a pinyin (with tone). | ||
| 37 | + ppinyin means partial pinyin: it splits each pinyin into an initial and a final. | ||
| 38 | + """, | ||
| 34 | ) | 39 | ) |
| 35 | @click.option( | 40 | @click.option( |
| 36 | "--bpe-model", | 41 | "--bpe-model", |
| @@ -42,14 +47,56 @@ def encode_text( | @@ -42,14 +47,56 @@ def encode_text( | ||
| 42 | ): | 47 | ): |
| 43 | """ | 48 | """ |
| 44 | Encode the texts given by the INPUT to tokens and write the results to the OUTPUT. | 49 | Encode the texts given by the INPUT to tokens and write the results to the OUTPUT. |
| 50 | + Each line in the texts contains the original phrase; it might also contain some | ||
| 51 | + extra items, for example, the boosting score (starting with :), the triggering | ||
| 52 | + threshold (starting with #, only used in the keyword spotting task) and the | ||
| 53 | + original phrase (starting with @). Note: the extra items are kept unchanged in the output. | ||
| 54 | + | ||
| 55 | + example input 1 (tokens_type = ppinyin): | ||
| 56 | + | ||
| 57 | + 小爱同学 :2.0 #0.6 @小爱同学 | ||
| 58 | + 你好问问 :3.5 @你好问问 | ||
| 59 | + 小艺小艺 #0.6 @小艺小艺 | ||
| 60 | + | ||
| 61 | + example output 1: | ||
| 62 | + | ||
| 63 | + x iǎo ài t óng x ué :2.0 #0.6 @小爱同学 | ||
| 64 | + n ǐ h ǎo w èn w èn :3.5 @你好问问 | ||
| 65 | + x iǎo y ì x iǎo y ì #0.6 @小艺小艺 | ||
| 66 | + | ||
| 67 | + example input 2 (tokens_type = bpe): | ||
| 68 | + | ||
| 69 | + HELLO WORLD :1.5 #0.4 | ||
| 70 | + HI GOOGLE :2.0 #0.8 | ||
| 71 | + HEY SIRI #0.35 | ||
| 72 | + | ||
| 73 | + example output 2: | ||
| 74 | + | ||
| 75 | + ▁HE LL O ▁WORLD :1.5 #0.4 | ||
| 76 | + ▁HI ▁GO O G LE :2.0 #0.8 | ||
| 77 | + ▁HE Y ▁S I RI #0.35 | ||
| 45 | """ | 78 | """ |
| 46 | texts = [] | 79 | texts = [] |
| 80 | + # Extra information such as the boosting score (starts with :), the triggering | ||
| 81 | + # threshold (starts with #) and the original keyword (starts with @). | ||
| 82 | + extra_info = [] | ||
| 47 | with open(input, "r", encoding="utf8") as f: | 83 | with open(input, "r", encoding="utf8") as f: |
| 48 | for line in f: | 84 | for line in f: |
| 49 | - texts.append(line.strip()) | 85 | + extra = [] |
| 86 | + text = [] | ||
| 87 | + toks = line.strip().split() | ||
| 88 | + for tok in toks: | ||
| 89 | + if tok[0] == ":" or tok[0] == "#" or tok[0] == "@": | ||
| 90 | + extra.append(tok) | ||
| 91 | + else: | ||
| 92 | + text.append(tok) | ||
| 93 | + texts.append(" ".join(text)) | ||
| 94 | + extra_info.append(extra) | ||
| 95 | + | ||
| 50 | encoded_texts = text2token( | 96 | encoded_texts = text2token( |
| 51 | texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model | 97 | texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model |
| 52 | ) | 98 | ) |
| 53 | with open(output, "w", encoding="utf8") as f: | 99 | with open(output, "w", encoding="utf8") as f: |
| 54 | - for txt in encoded_texts: | 100 | + for i, txt in enumerate(encoded_texts): |
| 101 | + txt += extra_info[i] | ||
| 55 | f.write(" ".join(txt) + "\n") | 102 | f.write(" ".join(txt) + "\n") |
| @@ -6,6 +6,9 @@ from typing import List, Optional, Union | @@ -6,6 +6,9 @@ from typing import List, Optional, Union | ||
| 6 | 6 | ||
| 7 | import sentencepiece as spm | 7 | import sentencepiece as spm |
| 8 | 8 | ||
| 9 | +from pypinyin import pinyin | ||
| 10 | +from pypinyin.contrib.tone_convert import to_initials, to_finals_tone | ||
| 11 | + | ||
| 9 | 12 | ||
| 10 | def text2token( | 13 | def text2token( |
| 11 | texts: List[str], | 14 | texts: List[str], |
| @@ -23,7 +26,9 @@ def text2token( | @@ -23,7 +26,9 @@ def text2token( | ||
| 23 | tokens: | 26 | tokens: |
| 24 | The path of the tokens.txt. | 27 | The path of the tokens.txt. |
| 25 | tokens_type: | 28 | tokens_type: |
| 26 | - The valid values are cjkchar, bpe, cjkchar+bpe. | 29 | + The valid values are cjkchar, bpe, cjkchar+bpe, fpinyin, ppinyin. |
| 30 | + fpinyin means full pinyin: each cjkchar has a pinyin (with tone). | ||
| 31 | + ppinyin means partial pinyin: it splits each pinyin into an initial and a final. | ||
| 27 | bpe_model: | 32 | bpe_model: |
| 28 | The path of the bpe model. Only required when tokens_type is bpe or | 33 | The path of the bpe model. Only required when tokens_type is bpe or |
| 29 | cjkchar+bpe. | 34 | cjkchar+bpe. |
| @@ -53,6 +58,24 @@ def text2token( | @@ -53,6 +58,24 @@ def text2token( | ||
| 53 | texts_list = [list("".join(text.split())) for text in texts] | 58 | texts_list = [list("".join(text.split())) for text in texts] |
| 54 | elif tokens_type == "bpe": | 59 | elif tokens_type == "bpe": |
| 55 | texts_list = sp.encode(texts, out_type=str) | 60 | texts_list = sp.encode(texts, out_type=str) |
| 61 | + elif "pinyin" in tokens_type: | ||
| 62 | + for txt in texts: | ||
| 63 | + py = [x[0] for x in pinyin(txt)] | ||
| 64 | + if "ppinyin" == tokens_type: | ||
| 65 | + res = [] | ||
| 66 | + for x in py: | ||
| 67 | + initial = to_initials(x, strict=False) | ||
| 68 | + final = to_finals_tone(x, strict=False) | ||
| 69 | + if initial == "" and final == "": | ||
| 70 | + res.append(x) | ||
| 71 | + else: | ||
| 72 | + if initial != "": | ||
| 73 | + res.append(initial) | ||
| 74 | + if final != "": | ||
| 75 | + res.append(final) | ||
| 76 | + texts_list.append(res) | ||
| 77 | + else: | ||
| 78 | + texts_list.append(py) | ||
| 56 | else: | 79 | else: |
| 57 | assert ( | 80 | assert ( |
| 58 | tokens_type == "cjkchar+bpe" | 81 | tokens_type == "cjkchar+bpe" |
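For instance, for the text 小爱同学 (the phrase used in the ppinyin example earlier in this diff), fpinyin would yield `xiǎo ài tóng xué` while ppinyin yields `x iǎo ài t óng x ué`: assuming pypinyin's usual behavior, 爱 (ài) is a zero-initial syllable, so only its final is emitted and it stays a single token.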