Wei Kang
Committed by GitHub

decoder for open vocabulary keyword spotting (#505)

* various fixes to ContextGraph to support open vocabulary keywords decoder

* Add keyword spotter runtime

* Add binary

* First version works

* Minor fixes

* update text2token

* default values

* Add jni for kws

* add kws android project

* Minor fixes

* Remove unused interface

* Minor fixes

* Add workflow

* handle extra info in texts

* Minor fixes

* Add more comments

* Fix ci

* fix cpp style

* Add input box in android demo so that users can specify their keywords

* Fix cpp style

* Fix comments

* Minor fixes

* Minor fixes

* minor fixes

* Minor fixes

* Minor fixes

* Add CI

* Fix code style

* cpplint

* Fix comments

* Fix error
Showing 77 changed files with 3,316 additions and 68 deletions
  1 +#!/usr/bin/env bash
  2 +
  3 +set -e
  4 +
  5 +log() {
  6 + # This function is from espnet
  7 + local fname=${BASH_SOURCE[1]##*/}
  8 + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
  9 +}
  10 +
  11 +echo "EXE is $EXE"
  12 +echo "PATH: $PATH"
  13 +
  14 +which $EXE
  15 +
  16 +log "------------------------------------------------------------"
  17 +log "Run Chinese keyword spotting (Wenetspeech)"
  18 +log "------------------------------------------------------------"
  19 +
  20 +repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.git
  21 +log "Start testing ${repo_url}"
  22 +repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
  23 +log "Download pretrained model and test-data from $repo_url"
  24 +
  25 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  26 +pushd $repo
  27 +git lfs pull --include "*.onnx"
  28 +ls -lh *.onnx
  29 +popd
  30 +
  31 +time $EXE \
  32 + --tokens=$repo/tokens.txt \
  33 + --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
  34 + --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
  35 + --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
  36 + --keywords-file=$repo/test_wavs/test_keywords.txt \
  37 + --max-active-paths=4 \
  38 + --num-threads=4 \
  39 + $repo/test_wavs/3.wav $repo/test_wavs/4.wav $repo/test_wavs/5.wav $repo/test_wavs/6.wav
  40 +
  41 +rm -rf $repo
  42 +
  43 +log "------------------------------------------------------------"
  44 +log "Run English keyword spotting (Gigaspeech)"
  45 +log "------------------------------------------------------------"
  46 +
  47 +repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.git
  48 +log "Start testing ${repo_url}"
  49 +repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
  50 +log "Download pretrained model and test-data from $repo_url"
  51 +
  52 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  53 +pushd $repo
  54 +git lfs pull --include "*.onnx"
  55 +ls -lh *.onnx
  56 +popd
  57 +
  58 +time $EXE \
  59 + --tokens=$repo/tokens.txt \
  60 + --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
  61 + --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
  62 + --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
  63 + --keywords-file=$repo/test_wavs/test_keywords.txt \
  64 + --max-active-paths=4 \
  65 + --num-threads=4 \
  66 + $repo/test_wavs/0.wav $repo/test_wavs/1.wav
  67 +
  68 +rm -rf $repo
  1 +name: apk-kws
  2 +
  3 +on:
  4 + push:
  5 + branches:
  6 + - apk-kws
  7 + tags:
  8 + - '*'
  9 +
  10 + workflow_dispatch:
  11 +
  12 +concurrency:
  13 + group: apk-kws-${{ github.ref }}
  14 + cancel-in-progress: true
  15 +
  16 +permissions:
  17 + contents: write
  18 +
  19 +jobs:
  20 + apk:
  21 + runs-on: ${{ matrix.os }}
  22 + strategy:
  23 + fail-fast: false
  24 + matrix:
  25 + os: [ubuntu-latest]
  26 +
  27 + steps:
  28 + - uses: actions/checkout@v4
  29 + with:
  30 + fetch-depth: 0
  31 +
  32 + - name: ccache
  33 + uses: hendrikmuhs/ccache-action@v1.2
  34 + with:
  35 + key: ${{ matrix.os }}-android
  36 +
  37 + - name: Display NDK HOME
  38 + shell: bash
  39 + run: |
  40 + echo "ANDROID_NDK_LATEST_HOME: ${ANDROID_NDK_LATEST_HOME}"
  41 + ls -lh ${ANDROID_NDK_LATEST_HOME}
  42 +
  43 + - name: build APK
  44 + shell: bash
  45 + run: |
  46 + export CMAKE_CXX_COMPILER_LAUNCHER=ccache
  47 + export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
  48 + cmake --version
  49 +
  50 + export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
  51 + ./build-kws-apk.sh
  52 +
  53 + - name: Display APK
  54 + shell: bash
  55 + run: |
  56 + ls -lh ./apks/
  57 +
  58 + - uses: actions/upload-artifact@v3
  59 + with:
  60 + path: ./apks/*.apk
  61 +
  62 + - name: Release APK
  63 + uses: svenstaro/upload-release-action@v2
  64 + with:
  65 + file_glob: true
  66 + file: apks/*.apk
  67 + overwrite: true
@@ -107,6 +107,14 @@ jobs:
107 107 name: release-static
108 108 path: build/bin/*
109 109
  110 + - name: Test transducer kws
  111 + shell: bash
  112 + run: |
  113 + export PATH=$PWD/build/bin:$PATH
  114 + export EXE=sherpa-onnx-keyword-spotter
  115 +
  116 + .github/scripts/test-kws.sh
  117 +
110 118 - name: Test online CTC
111 119 shell: bash
112 120 run: |
@@ -98,6 +98,14 @@ jobs:
98 98 otool -L build/bin/sherpa-onnx
99 99 otool -l build/bin/sherpa-onnx
100 100
  101 + - name: Test transducer kws
  102 + shell: bash
  103 + run: |
  104 + export PATH=$PWD/build/bin:$PATH
  105 + export EXE=sherpa-onnx-keyword-spotter
  106 +
  107 + .github/scripts/test-kws.sh
  108 +
101 109 - name: Test online CTC
102 110 shell: bash
103 111 run: |
@@ -106,7 +114,6 @@ jobs:
106 114
107 115 .github/scripts/test-online-ctc.sh
108 116
109 -  
110 117 - name: Test offline TTS
111 118 shell: bash
112 119 run: |
@@ -62,7 +62,7 @@ jobs:
62 62 - name: Install Python dependencies
63 63 shell: bash
64 64 run: |
65 - python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96 soundfile
  65 + python3 -m pip install --upgrade pip numpy pypinyin sentencepiece==0.1.96 soundfile
66 66
67 67 - name: Install sherpa-onnx
68 68 shell: bash
@@ -45,7 +45,7 @@ jobs:
45 45 - name: Install Python dependencies
46 46 shell: bash
47 47 run: |
48 - python3 -m pip install --upgrade pip numpy sentencepiece
  48 + python3 -m pip install --upgrade pip numpy pypinyin sentencepiece
49 49
50 50 - name: Install sherpa-onnx
51 51 shell: bash
@@ -45,7 +45,7 @@ jobs:
45 45 - name: Install Python dependencies
46 46 shell: bash
47 47 run: |
48 - python3 -m pip install --upgrade pip numpy sentencepiece
  48 + python3 -m pip install --upgrade pip numpy pypinyin sentencepiece
49 49
50 50 - name: Install sherpa-onnx
51 51 shell: bash
  1 +*.iml
  2 +.gradle
  3 +/local.properties
  4 +/.idea/caches
  5 +/.idea/libraries
  6 +/.idea/modules.xml
  7 +/.idea/workspace.xml
  8 +/.idea/navEditor.xml
  9 +/.idea/assetWizardSettings.xml
  10 +.DS_Store
  11 +/build
  12 +/captures
  13 +.externalNativeBuild
  14 +.cxx
  15 +local.properties
  1 +plugins {
  2 + id 'com.android.application'
  3 + id 'org.jetbrains.kotlin.android'
  4 +}
  5 +
  6 +android {
  7 + namespace 'com.k2fsa.sherpa.onnx'
  8 + compileSdk 32
  9 +
  10 + defaultConfig {
  11 + applicationId "com.k2fsa.sherpa.onnx"
  12 + minSdk 21
  13 + targetSdk 32
  14 + versionCode 1
  15 + versionName "1.0"
  16 +
  17 + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
  18 + }
  19 +
  20 + buildTypes {
  21 + release {
  22 + minifyEnabled false
  23 + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
  24 + }
  25 + }
  26 + compileOptions {
  27 + sourceCompatibility JavaVersion.VERSION_1_8
  28 + targetCompatibility JavaVersion.VERSION_1_8
  29 + }
  30 + kotlinOptions {
  31 + jvmTarget = '1.8'
  32 + }
  33 +}
  34 +
  35 +dependencies {
  36 +
  37 + implementation 'androidx.core:core-ktx:1.7.0'
  38 + implementation 'androidx.appcompat:appcompat:1.5.1'
  39 + implementation 'com.google.android.material:material:1.7.0'
  40 + implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
  41 + testImplementation 'junit:junit:4.13.2'
  42 + androidTestImplementation 'androidx.test.ext:junit:1.1.4'
  43 + androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0'
  44 +}
  1 +# Add project specific ProGuard rules here.
  2 +# You can control the set of applied configuration files using the
  3 +# proguardFiles setting in build.gradle.
  4 +#
  5 +# For more details, see
  6 +# http://developer.android.com/guide/developing/tools/proguard.html
  7 +
  8 +# If your project uses WebView with JS, uncomment the following
  9 +# and specify the fully qualified class name to the JavaScript interface
  10 +# class:
  11 +#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
  12 +# public *;
  13 +#}
  14 +
  15 +# Uncomment this to preserve the line number information for
  16 +# debugging stack traces.
  17 +#-keepattributes SourceFile,LineNumberTable
  18 +
  19 +# If you keep the line number information, uncomment this to
  20 +# hide the original source file name.
  21 +#-renamesourcefileattribute SourceFile
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +import androidx.test.platform.app.InstrumentationRegistry
  4 +import androidx.test.ext.junit.runners.AndroidJUnit4
  5 +
  6 +import org.junit.Test
  7 +import org.junit.runner.RunWith
  8 +
  9 +import org.junit.Assert.*
  10 +
  11 +/**
  12 + * Instrumented test, which will execute on an Android device.
  13 + *
  14 + * See [testing documentation](http://d.android.com/tools/testing).
  15 + */
  16 +@RunWith(AndroidJUnit4::class)
  17 +class ExampleInstrumentedTest {
  18 + @Test
  19 + fun useAppContext() {
  20 + // Context of the app under test.
  21 + val appContext = InstrumentationRegistry.getInstrumentation().targetContext
  22 + assertEquals("com.k2fsa.sherpa.onnx", appContext.packageName)
  23 + }
  24 +}
  1 +<?xml version="1.0" encoding="utf-8"?>
  2 +<manifest xmlns:android="http://schemas.android.com/apk/res/android"
  3 + xmlns:tools="http://schemas.android.com/tools">
  4 +
  5 + <uses-permission android:name="android.permission.RECORD_AUDIO" />
  6 +
  7 + <application
  8 + android:allowBackup="true"
  9 + android:dataExtractionRules="@xml/data_extraction_rules"
  10 + android:fullBackupContent="@xml/backup_rules"
  11 + android:icon="@mipmap/ic_launcher"
  12 + android:label="@string/app_name"
  13 + android:roundIcon="@mipmap/ic_launcher_round"
  14 + android:supportsRtl="true"
  15 + android:theme="@style/Theme.SherpaOnnx"
  16 + tools:targetApi="31">
  17 + <activity
  18 + android:name=".MainActivity"
  19 + android:exported="true">
  20 + <intent-filter>
  21 + <action android:name="android.intent.action.MAIN" />
  22 +
  23 + <category android:name="android.intent.category.LAUNCHER" />
  24 + </intent-filter>
  25 +
  26 + <meta-data
  27 + android:name="android.app.lib_name"
  28 + android:value="" />
  29 + </activity>
  30 + </application>
  31 +
  32 +</manifest>
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +import android.Manifest
  4 +import android.content.pm.PackageManager
  5 +import android.media.AudioFormat
  6 +import android.media.AudioRecord
  7 +import android.media.MediaRecorder
  8 +import android.os.Bundle
  9 +import android.text.method.ScrollingMovementMethod
  10 +import android.util.Log
  11 +import android.widget.Button
  12 +import android.widget.EditText
  13 +import android.widget.TextView
  14 +import android.widget.Toast
  15 +import androidx.appcompat.app.AppCompatActivity
  16 +import androidx.core.app.ActivityCompat
  17 +import com.k2fsa.sherpa.onnx.*
  18 +import kotlin.concurrent.thread
  19 +
  20 +private const val TAG = "sherpa-onnx"
  21 +private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
  22 +
  23 +class MainActivity : AppCompatActivity() {
  24 + private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
  25 +
  26 + private lateinit var model: SherpaOnnxKws
  27 + private var audioRecord: AudioRecord? = null
  28 + private lateinit var recordButton: Button
  29 + private lateinit var textView: TextView
  30 + private lateinit var inputText: EditText
  31 + private var recordingThread: Thread? = null
  32 +
  33 + private val audioSource = MediaRecorder.AudioSource.MIC
  34 + private val sampleRateInHz = 16000
  35 + private val channelConfig = AudioFormat.CHANNEL_IN_MONO
  36 +
  37 + // Note: We don't use AudioFormat.ENCODING_PCM_FLOAT
  38 + // since the AudioRecord.read(float[]) needs API level >= 23
  39 + // but we are targeting API level >= 21
  40 + private val audioFormat = AudioFormat.ENCODING_PCM_16BIT
  41 + private var idx: Int = 0
  42 + private var lastText: String = ""
  43 +
  44 + @Volatile
  45 + private var isRecording: Boolean = false
  46 +
  47 + override fun onRequestPermissionsResult(
  48 + requestCode: Int, permissions: Array<String>, grantResults: IntArray
  49 + ) {
  50 + super.onRequestPermissionsResult(requestCode, permissions, grantResults)
  51 + val permissionToRecordAccepted = if (requestCode == REQUEST_RECORD_AUDIO_PERMISSION) {
  52 + grantResults[0] == PackageManager.PERMISSION_GRANTED
  53 + } else {
  54 + false
  55 + }
  56 +
  57 + if (!permissionToRecordAccepted) {
  58 + Log.e(TAG, "Audio record is disallowed")
  59 + finish()
  60 + }
  61 +
  62 + Log.i(TAG, "Audio record is permitted")
  63 + }
  64 +
  65 + override fun onCreate(savedInstanceState: Bundle?) {
  66 + super.onCreate(savedInstanceState)
  67 + setContentView(R.layout.activity_main)
  68 +
  69 + ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
  70 +
  71 + Log.i(TAG, "Start to initialize model")
  72 + initModel()
  73 + Log.i(TAG, "Finished initializing model")
  74 +
  75 + recordButton = findViewById(R.id.record_button)
  76 + recordButton.setOnClickListener { onclick() }
  77 +
  78 + textView = findViewById(R.id.my_text)
  79 + textView.movementMethod = ScrollingMovementMethod()
  80 +
  81 + inputText = findViewById(R.id.input_text)
  82 + }
  83 +
  84 + private fun onclick() {
  85 + if (!isRecording) {
  86 + var keywords = inputText.text.toString()
  87 +
  88 + Log.i(TAG, keywords)
  89 + keywords = keywords.replace("\n", "/")
  90 + // If keywords is an empty string, it just resets the decoding stream
  91 + // and always returns true in this case.
  92 + // If keywords is not empty, it creates a new decoding stream with
  93 + // the given keywords appended to the default keywords.
  94 + // Returns false if an error occurred while adding keywords, true otherwise.
  95 + val status = model.reset(keywords)
  96 + if (!status) {
  97 + Log.i(TAG, "Failed to reset with keywords.")
  98 + Toast.makeText(this, "Failed to set keywords.", Toast.LENGTH_LONG).show()
  99 + return
  100 + }
  101 +
  102 + val ret = initMicrophone()
  103 + if (!ret) {
  104 + Log.e(TAG, "Failed to initialize microphone")
  105 + return
  106 + }
  107 + Log.i(TAG, "state: ${audioRecord?.state}")
  108 + audioRecord!!.startRecording()
  109 + recordButton.setText(R.string.stop)
  110 + isRecording = true
  111 + textView.text = ""
  112 + lastText = ""
  113 + idx = 0
  114 +
  115 + recordingThread = thread(true) {
  116 + processSamples()
  117 + }
  118 + Log.i(TAG, "Started recording")
  119 + } else {
  120 + isRecording = false
  121 + audioRecord!!.stop()
  122 + audioRecord!!.release()
  123 + audioRecord = null
  124 + recordButton.setText(R.string.start)
  125 + Log.i(TAG, "Stopped recording")
  126 + }
  127 + }
  128 +
  129 + private fun processSamples() {
  130 + Log.i(TAG, "processing samples")
  131 +
  132 + val interval = 0.1 // i.e., 100 ms
  133 + val bufferSize = (interval * sampleRateInHz).toInt() // in samples
  134 + val buffer = ShortArray(bufferSize)
  135 +
  136 + while (isRecording) {
  137 + val ret = audioRecord?.read(buffer, 0, buffer.size)
  138 + if (ret != null && ret > 0) {
  139 + val samples = FloatArray(ret) { buffer[it] / 32768.0f }
  140 + model.acceptWaveform(samples, sampleRate=sampleRateInHz)
  141 + while (model.isReady()) {
  142 + model.decode()
  143 + }
  144 +
  145 + val text = model.keyword
  146 +
  147 + var textToDisplay = lastText
  148 +
  149 + if (text.isNotBlank()) {
  150 + if (lastText.isBlank()) {
  151 + textToDisplay = "${idx}: ${text}"
  152 + } else {
  153 + textToDisplay = "${idx}: ${text}\n${lastText}"
  154 + }
  155 + lastText = "${idx}: ${text}\n${lastText}"
  156 + idx += 1
  157 + }
  158 +
  159 + runOnUiThread {
  160 + textView.text = textToDisplay
  161 + }
  162 + }
  163 + }
  164 + }
  165 +
  166 + private fun initMicrophone(): Boolean {
  167 + if (ActivityCompat.checkSelfPermission(
  168 + this, Manifest.permission.RECORD_AUDIO
  169 + ) != PackageManager.PERMISSION_GRANTED
  170 + ) {
  171 + ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
  172 + return false
  173 + }
  174 +
  175 + val numBytes = AudioRecord.getMinBufferSize(sampleRateInHz, channelConfig, audioFormat)
  176 + Log.i(
  177 + TAG, "buffer size in milliseconds: ${numBytes * 1000.0f / sampleRateInHz}"
  178 + )
  179 +
  180 + audioRecord = AudioRecord(
  181 + audioSource,
  182 + sampleRateInHz,
  183 + channelConfig,
  184 + audioFormat,
  185 + numBytes * 2 // a sample has two bytes as we are using 16-bit PCM
  186 + )
  187 + return true
  188 + }
  189 +
  190 + private fun initModel() {
  191 + // Please change getModelConfig() to add new models
  192 + // See https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
  193 + // for a list of available models
  194 + val type = 0
  195 + Log.i(TAG, "Select model type ${type}")
  196 + val config = KeywordSpotterConfig(
  197 + featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
  198 + modelConfig = getModelConfig(type = type)!!,
  199 + keywordsFile = getKeywordsFile(type = type)!!,
  200 + )
  201 +
  202 + model = SherpaOnnxKws(
  203 + assetManager = application.assets,
  204 + config = config,
  205 + )
  206 + }
  207 +}
  1 +// Copyright (c) 2024 Xiaomi Corporation
  2 +package com.k2fsa.sherpa.onnx
  3 +
  4 +import android.content.res.AssetManager
  5 +
  6 +data class OnlineTransducerModelConfig(
  7 + var encoder: String = "",
  8 + var decoder: String = "",
  9 + var joiner: String = "",
  10 +)
  11 +
  12 +data class OnlineModelConfig(
  13 + var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
  14 + var tokens: String,
  15 + var numThreads: Int = 1,
  16 + var debug: Boolean = false,
  17 + var provider: String = "cpu",
  18 + var modelType: String = "",
  19 +)
  20 +
  21 +data class FeatureConfig(
  22 + var sampleRate: Int = 16000,
  23 + var featureDim: Int = 80,
  24 +)
  25 +
  26 +data class KeywordSpotterConfig(
  27 + var featConfig: FeatureConfig = FeatureConfig(),
  28 + var modelConfig: OnlineModelConfig,
  29 + var maxActivePaths: Int = 4,
  30 + var keywordsFile: String = "keywords.txt",
  31 + var keywordsScore: Float = 1.5f,
  32 + var keywordsThreshold: Float = 0.25f,
  33 + var numTrailingBlanks: Int = 2,
  34 +)
  35 +
  36 +class SherpaOnnxKws(
  37 + assetManager: AssetManager? = null,
  38 + var config: KeywordSpotterConfig,
  39 +) {
  40 + private val ptr: Long
  41 +
  42 + init {
  43 + if (assetManager != null) {
  44 + ptr = new(assetManager, config)
  45 + } else {
  46 + ptr = newFromFile(config)
  47 + }
  48 + }
  49 +
  50 + protected fun finalize() {
  51 + delete(ptr)
  52 + }
  53 +
  54 + fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
  55 + acceptWaveform(ptr, samples, sampleRate)
  56 +
  57 + fun inputFinished() = inputFinished(ptr)
  58 + fun decode() = decode(ptr)
  59 + fun isReady(): Boolean = isReady(ptr)
  60 + fun reset(keywords: String): Boolean = reset(ptr, keywords)
  61 +
  62 + val keyword: String
  63 + get() = getKeyword(ptr)
  64 +
  65 + private external fun delete(ptr: Long)
  66 +
  67 + private external fun new(
  68 + assetManager: AssetManager,
  69 + config: KeywordSpotterConfig,
  70 + ): Long
  71 +
  72 + private external fun newFromFile(
  73 + config: KeywordSpotterConfig,
  74 + ): Long
  75 +
  76 + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
  77 + private external fun inputFinished(ptr: Long)
  78 + private external fun getKeyword(ptr: Long): String
  79 + private external fun reset(ptr: Long, keywords: String): Boolean
  80 + private external fun decode(ptr: Long)
  81 + private external fun isReady(ptr: Long): Boolean
  82 +
  83 + companion object {
  84 + init {
  85 + System.loadLibrary("sherpa-onnx-jni")
  86 + }
  87 + }
  88 +}
  89 +
  90 +fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
  91 + return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
  92 +}
  93 +
  94 +/*
  95 +Please see
  96 +https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
  97 +for a list of pre-trained models.
  98 +
  99 +We only add a few here. Please change the following code
  100 +to add your own. (It should be straightforward to add a new model
  101 +by following the code)
  102 +
  103 +@param type
  104 +0 - sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 (Chinese)
  105 + https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/summary
  106 +
  107 +1 - sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 (English)
  108 + https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/summary
  109 +
  110 + */
  111 +fun getModelConfig(type: Int): OnlineModelConfig? {
  112 + when (type) {
  113 + 0 -> {
  114 + val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
  115 + return OnlineModelConfig(
  116 + transducer = OnlineTransducerModelConfig(
  117 + encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
  118 + decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
  119 + joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
  120 + ),
  121 + tokens = "$modelDir/tokens.txt",
  122 + modelType = "zipformer2",
  123 + )
  124 + }
  125 +
  126 + 1 -> {
  127 + val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"
  128 + return OnlineModelConfig(
  129 + transducer = OnlineTransducerModelConfig(
  130 + encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
  131 + decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
  132 + joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
  133 + ),
  134 + tokens = "$modelDir/tokens.txt",
  135 + modelType = "zipformer2",
  136 + )
  137 + }
  138 +
  139 + }
  140 + return null
  141 +}
  142 +
  143 +/*
  144 + * Get the default keywords for each model.
  145 + * Caution: The types and modelDir should be the same as those in getModelConfig
  146 + * function above.
  147 + */
  148 +fun getKeywordsFile(type: Int): String {
  149 + when (type) {
  150 + 0 -> {
  151 + val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
  152 + return "$modelDir/keywords.txt"
  153 + }
  154 +
  155 + 1 -> {
  156 + val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"
  157 + return "$modelDir/keywords.txt"
  158 + }
  159 +
  160 + }
  161 + return ""
  162 +}
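
The class above is the entire JNI surface for keyword spotting. Below is a minimal sketch (not part of this diff) of how the wrapper can be driven from plain Kotlin, using only the classes and signatures listed above; the model paths are hypothetical stand-ins for the files under assets/ or on disk:

    // Sketch only: all names besides the paths come from the wrapper above.
    fun runKwsSketch() {
        val config = KeywordSpotterConfig(
            featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
            modelConfig = OnlineModelConfig(
                transducer = OnlineTransducerModelConfig(
                    encoder = "model/encoder.onnx",  // hypothetical path
                    decoder = "model/decoder.onnx",  // hypothetical path
                    joiner = "model/joiner.onnx",    // hypothetical path
                ),
                tokens = "model/tokens.txt",
                modelType = "zipformer2",
            ),
            keywordsFile = "model/keywords.txt",
        )

        // assetManager == null selects newFromFile(), i.e. loading from disk
        val kws = SherpaOnnxKws(assetManager = null, config = config)

        // reset() may add user-defined keywords at runtime; multiple keywords
        // are separated by '/', as MainActivity does with its input box
        if (!kws.reset("HELLO WORLD/HI GOOGLE")) {
            println("Failed to set keywords")
            return
        }

        val chunk = FloatArray(1600)  // 100 ms of 16 kHz audio; silence here
        kws.acceptWaveform(chunk, sampleRate = 16000)
        while (kws.isReady()) {
            kws.decode()
        }
        // keyword stays an empty string until a keyword has been spotted
        val detected = kws.keyword
        if (detected.isNotBlank()) {
            println("Spotted: $detected")
        }
    }

This mirrors the processSamples() loop in MainActivity: feed a chunk, drain the decoder with isReady()/decode(), then poll the keyword property.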
  1 +// Copyright (c) 2023 Xiaomi Corporation
  2 +package com.k2fsa.sherpa.onnx
  3 +
  4 +import android.content.res.AssetManager
  5 +
  6 +class WaveReader {
  7 + companion object {
  8 + // Read a mono wave file asset
  9 + // The returned array has two entries:
  10 + // - the first entry contains a 1-D float array
  11 + // - the second entry is the sample rate
  12 + external fun readWaveFromAsset(
  13 + assetManager: AssetManager,
  14 + filename: String,
  15 + ): Array<Any>
  16 +
  17 + // Read a mono wave file from disk
  18 + // The returned array has two entries:
  19 + // - the first entry contains a 1-D float array
  20 + // - the second entry is the sample rate
  21 + external fun readWaveFromFile(
  22 + filename: String,
  23 + ): Array<Any>
  24 +
  25 + init {
  26 + System.loadLibrary("sherpa-onnx-jni")
  27 + }
  28 + }
  29 +}
  1 +*.so
  2 +*.txt
  3 +*.onnx
  4 +*.wav
  1 +<vector xmlns:android="http://schemas.android.com/apk/res/android"
  2 + xmlns:aapt="http://schemas.android.com/aapt"
  3 + android:width="108dp"
  4 + android:height="108dp"
  5 + android:viewportWidth="108"
  6 + android:viewportHeight="108">
  7 + <path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
  8 + <aapt:attr name="android:fillColor">
  9 + <gradient
  10 + android:endX="85.84757"
  11 + android:endY="92.4963"
  12 + android:startX="42.9492"
  13 + android:startY="49.59793"
  14 + android:type="linear">
  15 + <item
  16 + android:color="#44000000"
  17 + android:offset="0.0" />
  18 + <item
  19 + android:color="#00000000"
  20 + android:offset="1.0" />
  21 + </gradient>
  22 + </aapt:attr>
  23 + </path>
  24 + <path
  25 + android:fillColor="#FFFFFF"
  26 + android:fillType="nonZero"
  27 + android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
  28 + android:strokeWidth="1"
  29 + android:strokeColor="#00000000" />
  30 +</vector>
  1 +<?xml version="1.0" encoding="utf-8"?>
  2 +<vector xmlns:android="http://schemas.android.com/apk/res/android"
  3 + android:width="108dp"
  4 + android:height="108dp"
  5 + android:viewportWidth="108"
  6 + android:viewportHeight="108">
  7 + <path
  8 + android:fillColor="#3DDC84"
  9 + android:pathData="M0,0h108v108h-108z" />
  10 + <path
  11 + android:fillColor="#00000000"
  12 + android:pathData="M9,0L9,108"
  13 + android:strokeWidth="0.8"
  14 + android:strokeColor="#33FFFFFF" />
  15 + <path
  16 + android:fillColor="#00000000"
  17 + android:pathData="M19,0L19,108"
  18 + android:strokeWidth="0.8"
  19 + android:strokeColor="#33FFFFFF" />
  20 + <path
  21 + android:fillColor="#00000000"
  22 + android:pathData="M29,0L29,108"
  23 + android:strokeWidth="0.8"
  24 + android:strokeColor="#33FFFFFF" />
  25 + <path
  26 + android:fillColor="#00000000"
  27 + android:pathData="M39,0L39,108"
  28 + android:strokeWidth="0.8"
  29 + android:strokeColor="#33FFFFFF" />
  30 + <path
  31 + android:fillColor="#00000000"
  32 + android:pathData="M49,0L49,108"
  33 + android:strokeWidth="0.8"
  34 + android:strokeColor="#33FFFFFF" />
  35 + <path
  36 + android:fillColor="#00000000"
  37 + android:pathData="M59,0L59,108"
  38 + android:strokeWidth="0.8"
  39 + android:strokeColor="#33FFFFFF" />
  40 + <path
  41 + android:fillColor="#00000000"
  42 + android:pathData="M69,0L69,108"
  43 + android:strokeWidth="0.8"
  44 + android:strokeColor="#33FFFFFF" />
  45 + <path
  46 + android:fillColor="#00000000"
  47 + android:pathData="M79,0L79,108"
  48 + android:strokeWidth="0.8"
  49 + android:strokeColor="#33FFFFFF" />
  50 + <path
  51 + android:fillColor="#00000000"
  52 + android:pathData="M89,0L89,108"
  53 + android:strokeWidth="0.8"
  54 + android:strokeColor="#33FFFFFF" />
  55 + <path
  56 + android:fillColor="#00000000"
  57 + android:pathData="M99,0L99,108"
  58 + android:strokeWidth="0.8"
  59 + android:strokeColor="#33FFFFFF" />
  60 + <path
  61 + android:fillColor="#00000000"
  62 + android:pathData="M0,9L108,9"
  63 + android:strokeWidth="0.8"
  64 + android:strokeColor="#33FFFFFF" />
  65 + <path
  66 + android:fillColor="#00000000"
  67 + android:pathData="M0,19L108,19"
  68 + android:strokeWidth="0.8"
  69 + android:strokeColor="#33FFFFFF" />
  70 + <path
  71 + android:fillColor="#00000000"
  72 + android:pathData="M0,29L108,29"
  73 + android:strokeWidth="0.8"
  74 + android:strokeColor="#33FFFFFF" />
  75 + <path
  76 + android:fillColor="#00000000"
  77 + android:pathData="M0,39L108,39"
  78 + android:strokeWidth="0.8"
  79 + android:strokeColor="#33FFFFFF" />
  80 + <path
  81 + android:fillColor="#00000000"
  82 + android:pathData="M0,49L108,49"
  83 + android:strokeWidth="0.8"
  84 + android:strokeColor="#33FFFFFF" />
  85 + <path
  86 + android:fillColor="#00000000"
  87 + android:pathData="M0,59L108,59"
  88 + android:strokeWidth="0.8"
  89 + android:strokeColor="#33FFFFFF" />
  90 + <path
  91 + android:fillColor="#00000000"
  92 + android:pathData="M0,69L108,69"
  93 + android:strokeWidth="0.8"
  94 + android:strokeColor="#33FFFFFF" />
  95 + <path
  96 + android:fillColor="#00000000"
  97 + android:pathData="M0,79L108,79"
  98 + android:strokeWidth="0.8"
  99 + android:strokeColor="#33FFFFFF" />
  100 + <path
  101 + android:fillColor="#00000000"
  102 + android:pathData="M0,89L108,89"
  103 + android:strokeWidth="0.8"
  104 + android:strokeColor="#33FFFFFF" />
  105 + <path
  106 + android:fillColor="#00000000"
  107 + android:pathData="M0,99L108,99"
  108 + android:strokeWidth="0.8"
  109 + android:strokeColor="#33FFFFFF" />
  110 + <path
  111 + android:fillColor="#00000000"
  112 + android:pathData="M19,29L89,29"
  113 + android:strokeWidth="0.8"
  114 + android:strokeColor="#33FFFFFF" />
  115 + <path
  116 + android:fillColor="#00000000"
  117 + android:pathData="M19,39L89,39"
  118 + android:strokeWidth="0.8"
  119 + android:strokeColor="#33FFFFFF" />
  120 + <path
  121 + android:fillColor="#00000000"
  122 + android:pathData="M19,49L89,49"
  123 + android:strokeWidth="0.8"
  124 + android:strokeColor="#33FFFFFF" />
  125 + <path
  126 + android:fillColor="#00000000"
  127 + android:pathData="M19,59L89,59"
  128 + android:strokeWidth="0.8"
  129 + android:strokeColor="#33FFFFFF" />
  130 + <path
  131 + android:fillColor="#00000000"
  132 + android:pathData="M19,69L89,69"
  133 + android:strokeWidth="0.8"
  134 + android:strokeColor="#33FFFFFF" />
  135 + <path
  136 + android:fillColor="#00000000"
  137 + android:pathData="M19,79L89,79"
  138 + android:strokeWidth="0.8"
  139 + android:strokeColor="#33FFFFFF" />
  140 + <path
  141 + android:fillColor="#00000000"
  142 + android:pathData="M29,19L29,89"
  143 + android:strokeWidth="0.8"
  144 + android:strokeColor="#33FFFFFF" />
  145 + <path
  146 + android:fillColor="#00000000"
  147 + android:pathData="M39,19L39,89"
  148 + android:strokeWidth="0.8"
  149 + android:strokeColor="#33FFFFFF" />
  150 + <path
  151 + android:fillColor="#00000000"
  152 + android:pathData="M49,19L49,89"
  153 + android:strokeWidth="0.8"
  154 + android:strokeColor="#33FFFFFF" />
  155 + <path
  156 + android:fillColor="#00000000"
  157 + android:pathData="M59,19L59,89"
  158 + android:strokeWidth="0.8"
  159 + android:strokeColor="#33FFFFFF" />
  160 + <path
  161 + android:fillColor="#00000000"
  162 + android:pathData="M69,19L69,89"
  163 + android:strokeWidth="0.8"
  164 + android:strokeColor="#33FFFFFF" />
  165 + <path
  166 + android:fillColor="#00000000"
  167 + android:pathData="M79,19L79,89"
  168 + android:strokeWidth="0.8"
  169 + android:strokeColor="#33FFFFFF" />
  170 +</vector>
  1 +<?xml version="1.0" encoding="utf-8"?>
  2 +<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
  3 + xmlns:app="http://schemas.android.com/apk/res-auto"
  4 + xmlns:tools="http://schemas.android.com/tools"
  5 + android:layout_width="match_parent"
  6 + android:layout_height="match_parent"
  7 + tools:context=".MainActivity">
  8 +
  9 + <LinearLayout
  10 + android:layout_width="match_parent"
  11 + android:layout_height="match_parent"
  12 + android:gravity="center"
  13 + android:orientation="vertical">
  14 +
  15 + <EditText
  16 + android:id="@+id/input_text"
  17 + android:layout_width="match_parent"
  18 + android:layout_height="320dp"
  19 + android:layout_weight="2.5"
  20 + android:hint="@string/keyword_hint"
  21 + android:scrollbars="vertical"
  22 + android:text=""
  23 + android:textSize="15dp" />
  24 +
  25 + <TextView
  26 + android:id="@+id/my_text"
  27 + android:layout_width="match_parent"
  28 + android:layout_height="443dp"
  29 + android:layout_weight="2.5"
  30 + android:padding="24dp"
  31 + android:scrollbars="vertical"
  32 + android:singleLine="false"
  33 + android:text="@string/hint"
  34 + android:textSize="15dp" />
  35 +
  36 + <Button
  37 + android:id="@+id/record_button"
  38 + android:layout_width="wrap_content"
  39 + android:layout_height="wrap_content"
  40 + android:layout_weight="0.5"
  41 + android:text="@string/start" />
  42 +
  43 + </LinearLayout>
  44 +
  45 +
  46 +</androidx.constraintlayout.widget.ConstraintLayout>
  1 +<?xml version="1.0" encoding="utf-8"?>
  2 +<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
  3 + <background android:drawable="@drawable/ic_launcher_background" />
  4 + <foreground android:drawable="@drawable/ic_launcher_foreground" />
  5 +</adaptive-icon>
  1 +<?xml version="1.0" encoding="utf-8"?>
  2 +<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
  3 + <background android:drawable="@drawable/ic_launcher_background" />
  4 + <foreground android:drawable="@drawable/ic_launcher_foreground" />
  5 +</adaptive-icon>
  1 +<resources xmlns:tools="http://schemas.android.com/tools">
  2 + <!-- Base application theme. -->
  3 + <style name="Theme.SherpaOnnx" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
  4 + <!-- Primary brand color. -->
  5 + <item name="colorPrimary">@color/purple_200</item>
  6 + <item name="colorPrimaryVariant">@color/purple_700</item>
  7 + <item name="colorOnPrimary">@color/black</item>
  8 + <!-- Secondary brand color. -->
  9 + <item name="colorSecondary">@color/teal_200</item>
  10 + <item name="colorSecondaryVariant">@color/teal_200</item>
  11 + <item name="colorOnSecondary">@color/black</item>
  12 + <!-- Status bar color. -->
  13 + <item name="android:statusBarColor">?attr/colorPrimaryVariant</item>
  14 + <!-- Customize your theme here. -->
  15 + </style>
  16 +</resources>
  1 +<?xml version="1.0" encoding="utf-8"?>
  2 +<resources>
  3 + <color name="purple_200">#FFBB86FC</color>
  4 + <color name="purple_500">#FF6200EE</color>
  5 + <color name="purple_700">#FF3700B3</color>
  6 + <color name="teal_200">#FF03DAC5</color>
  7 + <color name="teal_700">#FF018786</color>
  8 + <color name="black">#FF000000</color>
  9 + <color name="white">#FFFFFFFF</color>
  10 +</resources>
  1 +<resources>
  2 + <string name="app_name">KWS with Next-gen Kaldi</string>
  3 + <string name="hint">Click the Start button to play keyword spotting with Next-gen Kaldi.
  4 + \n
  5 + \n\n\n
  6 + The source code and pre-trained models are publicly available.
  7 + Please see https://github.com/k2-fsa/sherpa-onnx for details.
  8 + </string>
  9 + <string name="keyword_hint">Input your keywords here, one keyword perline.</string>
  10 + <string name="start">Start</string>
  11 + <string name="stop">Stop</string>
  12 +</resources>
  1 +<resources xmlns:tools="http://schemas.android.com/tools">
  2 + <!-- Base application theme. -->
  3 + <style name="Theme.SherpaOnnx" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
  4 + <!-- Primary brand color. -->
  5 + <item name="colorPrimary">@color/purple_500</item>
  6 + <item name="colorPrimaryVariant">@color/purple_700</item>
  7 + <item name="colorOnPrimary">@color/white</item>
  8 + <!-- Secondary brand color. -->
  9 + <item name="colorSecondary">@color/teal_200</item>
  10 + <item name="colorSecondaryVariant">@color/teal_700</item>
  11 + <item name="colorOnSecondary">@color/black</item>
  12 + <!-- Status bar color. -->
  13 + <item name="android:statusBarColor">?attr/colorPrimaryVariant</item>
  14 + <!-- Customize your theme here. -->
  15 + </style>
  16 +</resources>
  1 +<?xml version="1.0" encoding="utf-8"?><!--
  2 + Sample backup rules file; uncomment and customize as necessary.
  3 + See https://developer.android.com/guide/topics/data/autobackup
  4 + for details.
  5 + Note: This file is ignored for devices older than API 31
  6 + See https://developer.android.com/about/versions/12/backup-restore
  7 +-->
  8 +<full-backup-content>
  9 + <!--
  10 + <include domain="sharedpref" path="."/>
  11 + <exclude domain="sharedpref" path="device.xml"/>
  12 +-->
  13 +</full-backup-content>
  1 +<?xml version="1.0" encoding="utf-8"?><!--
  2 + Sample data extraction rules file; uncomment and customize as necessary.
  3 + See https://developer.android.com/about/versions/12/backup-restore#xml-changes
  4 + for details.
  5 +-->
  6 +<data-extraction-rules>
  7 + <cloud-backup>
  8 + <!-- TODO: Use <include> and <exclude> to control what is backed up.
  9 + <include .../>
  10 + <exclude .../>
  11 + -->
  12 + </cloud-backup>
  13 + <!--
  14 + <device-transfer>
  15 + <include .../>
  16 + <exclude .../>
  17 + </device-transfer>
  18 + -->
  19 +</data-extraction-rules>
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +import org.junit.Test
  4 +
  5 +import org.junit.Assert.*
  6 +
  7 +/**
  8 + * Example local unit test, which will execute on the development machine (host).
  9 + *
  10 + * See [testing documentation](http://d.android.com/tools/testing).
  11 + */
  12 +class ExampleUnitTest {
  13 + @Test
  14 + fun addition_isCorrect() {
  15 + assertEquals(4, 2 + 2)
  16 + }
  17 +}
  1 +// Top-level build file where you can add configuration options common to all sub-projects/modules.
  2 +plugins {
  3 + id 'com.android.application' version '7.3.1' apply false
  4 + id 'com.android.library' version '7.3.1' apply false
  5 + id 'org.jetbrains.kotlin.android' version '1.7.20' apply false
  6 +}
  1 +# Project-wide Gradle settings.
  2 +# IDE (e.g. Android Studio) users:
  3 +# Gradle settings configured through the IDE *will override*
  4 +# any settings specified in this file.
  5 +# For more details on how to configure your build environment visit
  6 +# http://www.gradle.org/docs/current/userguide/build_environment.html
  7 +# Specifies the JVM arguments used for the daemon process.
  8 +# The setting is particularly useful for tweaking memory settings.
  9 +org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
  10 +# When configured, Gradle will run in incubating parallel mode.
  11 +# This option should only be used with decoupled projects. More details, visit
  12 +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
  13 +# org.gradle.parallel=true
  14 +# AndroidX package structure to make it clearer which packages are bundled with the
  15 +# Android operating system, and which are packaged with your app's APK
  16 +# https://developer.android.com/topic/libraries/support-library/androidx-rn
  17 +android.useAndroidX=true
  18 +# Kotlin code style for this project: "official" or "obsolete":
  19 +kotlin.code.style=official
  20 +# Enables namespacing of each library's R class so that its R class includes only the
  21 +# resources declared in the library itself and none from the library's dependencies,
  22 +# thereby reducing the size of the R class for that library
  23 +android.nonTransitiveRClass=true
  1 +#Thu Feb 23 11:09:06 CST 2023
  2 +distributionBase=GRADLE_USER_HOME
  3 +distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip
  4 +distributionPath=wrapper/dists
  5 +zipStorePath=wrapper/dists
  6 +zipStoreBase=GRADLE_USER_HOME
  1 +#!/usr/bin/env sh
  2 +
  3 +#
  4 +# Copyright 2015 the original author or authors.
  5 +#
  6 +# Licensed under the Apache License, Version 2.0 (the "License");
  7 +# you may not use this file except in compliance with the License.
  8 +# You may obtain a copy of the License at
  9 +#
  10 +# https://www.apache.org/licenses/LICENSE-2.0
  11 +#
  12 +# Unless required by applicable law or agreed to in writing, software
  13 +# distributed under the License is distributed on an "AS IS" BASIS,
  14 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15 +# See the License for the specific language governing permissions and
  16 +# limitations under the License.
  17 +#
  18 +
  19 +##############################################################################
  20 +##
  21 +## Gradle start up script for UN*X
  22 +##
  23 +##############################################################################
  24 +
  25 +# Attempt to set APP_HOME
  26 +# Resolve links: $0 may be a link
  27 +PRG="$0"
  28 +# Need this for relative symlinks.
  29 +while [ -h "$PRG" ] ; do
  30 + ls=`ls -ld "$PRG"`
  31 + link=`expr "$ls" : '.*-> \(.*\)$'`
  32 + if expr "$link" : '/.*' > /dev/null; then
  33 + PRG="$link"
  34 + else
  35 + PRG=`dirname "$PRG"`"/$link"
  36 + fi
  37 +done
  38 +SAVED="`pwd`"
  39 +cd "`dirname \"$PRG\"`/" >/dev/null
  40 +APP_HOME="`pwd -P`"
  41 +cd "$SAVED" >/dev/null
  42 +
  43 +APP_NAME="Gradle"
  44 +APP_BASE_NAME=`basename "$0"`
  45 +
  46 +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
  47 +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
  48 +
  49 +# Use the maximum available, or set MAX_FD != -1 to use that value.
  50 +MAX_FD="maximum"
  51 +
  52 +warn () {
  53 + echo "$*"
  54 +}
  55 +
  56 +die () {
  57 + echo
  58 + echo "$*"
  59 + echo
  60 + exit 1
  61 +}
  62 +
  63 +# OS specific support (must be 'true' or 'false').
  64 +cygwin=false
  65 +msys=false
  66 +darwin=false
  67 +nonstop=false
  68 +case "`uname`" in
  69 + CYGWIN* )
  70 + cygwin=true
  71 + ;;
  72 + Darwin* )
  73 + darwin=true
  74 + ;;
  75 + MINGW* )
  76 + msys=true
  77 + ;;
  78 + NONSTOP* )
  79 + nonstop=true
  80 + ;;
  81 +esac
  82 +
  83 +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
  84 +
  85 +
  86 +# Determine the Java command to use to start the JVM.
  87 +if [ -n "$JAVA_HOME" ] ; then
  88 + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
  89 + # IBM's JDK on AIX uses strange locations for the executables
  90 + JAVACMD="$JAVA_HOME/jre/sh/java"
  91 + else
  92 + JAVACMD="$JAVA_HOME/bin/java"
  93 + fi
  94 + if [ ! -x "$JAVACMD" ] ; then
  95 + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
  96 +
  97 +Please set the JAVA_HOME variable in your environment to match the
  98 +location of your Java installation."
  99 + fi
  100 +else
  101 + JAVACMD="java"
  102 + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
  103 +
  104 +Please set the JAVA_HOME variable in your environment to match the
  105 +location of your Java installation."
  106 +fi
  107 +
  108 +# Increase the maximum file descriptors if we can.
  109 +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
  110 + MAX_FD_LIMIT=`ulimit -H -n`
  111 + if [ $? -eq 0 ] ; then
  112 + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
  113 + MAX_FD="$MAX_FD_LIMIT"
  114 + fi
  115 + ulimit -n $MAX_FD
  116 + if [ $? -ne 0 ] ; then
  117 + warn "Could not set maximum file descriptor limit: $MAX_FD"
  118 + fi
  119 + else
  120 + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
  121 + fi
  122 +fi
  123 +
  124 +# For Darwin, add options to specify how the application appears in the dock
  125 +if $darwin; then
  126 + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
  127 +fi
  128 +
  129 +# For Cygwin or MSYS, switch paths to Windows format before running java
  130 +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
  131 + APP_HOME=`cygpath --path --mixed "$APP_HOME"`
  132 + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
  133 +
  134 + JAVACMD=`cygpath --unix "$JAVACMD"`
  135 +
  136 + # We build the pattern for arguments to be converted via cygpath
  137 + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
  138 + SEP=""
  139 + for dir in $ROOTDIRSRAW ; do
  140 + ROOTDIRS="$ROOTDIRS$SEP$dir"
  141 + SEP="|"
  142 + done
  143 + OURCYGPATTERN="(^($ROOTDIRS))"
  144 + # Add a user-defined pattern to the cygpath arguments
  145 + if [ "$GRADLE_CYGPATTERN" != "" ] ; then
  146 + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
  147 + fi
  148 + # Now convert the arguments - kludge to limit ourselves to /bin/sh
  149 + i=0
  150 + for arg in "$@" ; do
  151 + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
  152 + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
  153 +
  154 + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
  155 + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
  156 + else
  157 + eval `echo args$i`="\"$arg\""
  158 + fi
  159 + i=`expr $i + 1`
  160 + done
  161 + case $i in
  162 + 0) set -- ;;
  163 + 1) set -- "$args0" ;;
  164 + 2) set -- "$args0" "$args1" ;;
  165 + 3) set -- "$args0" "$args1" "$args2" ;;
  166 + 4) set -- "$args0" "$args1" "$args2" "$args3" ;;
  167 + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
  168 + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
  169 + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
  170 + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
  171 + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
  172 + esac
  173 +fi
  174 +
  175 +# Escape application args
  176 +save () {
  177 + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
  178 + echo " "
  179 +}
  180 +APP_ARGS=`save "$@"`
  181 +
  182 +# Collect all arguments for the java command, following the shell quoting and substitution rules
  183 +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
  184 +
  185 +exec "$JAVACMD" "$@"
  1 +@rem
  2 +@rem Copyright 2015 the original author or authors.
  3 +@rem
  4 +@rem Licensed under the Apache License, Version 2.0 (the "License");
  5 +@rem you may not use this file except in compliance with the License.
  6 +@rem You may obtain a copy of the License at
  7 +@rem
  8 +@rem https://www.apache.org/licenses/LICENSE-2.0
  9 +@rem
  10 +@rem Unless required by applicable law or agreed to in writing, software
  11 +@rem distributed under the License is distributed on an "AS IS" BASIS,
  12 +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 +@rem See the License for the specific language governing permissions and
  14 +@rem limitations under the License.
  15 +@rem
  16 +
  17 +@if "%DEBUG%" == "" @echo off
  18 +@rem ##########################################################################
  19 +@rem
  20 +@rem Gradle startup script for Windows
  21 +@rem
  22 +@rem ##########################################################################
  23 +
  24 +@rem Set local scope for the variables with windows NT shell
  25 +if "%OS%"=="Windows_NT" setlocal
  26 +
  27 +set DIRNAME=%~dp0
  28 +if "%DIRNAME%" == "" set DIRNAME=.
  29 +set APP_BASE_NAME=%~n0
  30 +set APP_HOME=%DIRNAME%
  31 +
  32 +@rem Resolve any "." and ".." in APP_HOME to make it shorter.
  33 +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
  34 +
  35 +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
  36 +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
  37 +
  38 +@rem Find java.exe
  39 +if defined JAVA_HOME goto findJavaFromJavaHome
  40 +
  41 +set JAVA_EXE=java.exe
  42 +%JAVA_EXE% -version >NUL 2>&1
  43 +if "%ERRORLEVEL%" == "0" goto execute
  44 +
  45 +echo.
  46 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
  47 +echo.
  48 +echo Please set the JAVA_HOME variable in your environment to match the
  49 +echo location of your Java installation.
  50 +
  51 +goto fail
  52 +
  53 +:findJavaFromJavaHome
  54 +set JAVA_HOME=%JAVA_HOME:"=%
  55 +set JAVA_EXE=%JAVA_HOME%/bin/java.exe
  56 +
  57 +if exist "%JAVA_EXE%" goto execute
  58 +
  59 +echo.
  60 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
  61 +echo.
  62 +echo Please set the JAVA_HOME variable in your environment to match the
  63 +echo location of your Java installation.
  64 +
  65 +goto fail
  66 +
  67 +:execute
  68 +@rem Setup the command line
  69 +
  70 +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
  71 +
  72 +
  73 +@rem Execute Gradle
  74 +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
  75 +
  76 +:end
  77 +@rem End local scope for the variables with windows NT shell
  78 +if "%ERRORLEVEL%"=="0" goto mainEnd
  79 +
  80 +:fail
  81 +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
  82 +rem the _cmd.exe /c_ return code!
  83 +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
  84 +exit /b 1
  85 +
  86 +:mainEnd
  87 +if "%OS%"=="Windows_NT" endlocal
  88 +
  89 +:omega
  1 +pluginManagement {
  2 + repositories {
  3 + gradlePluginPortal()
  4 + google()
  5 + mavenCentral()
  6 + }
  7 +}
  8 +dependencyResolutionManagement {
  9 + repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
  10 + repositories {
  11 + google()
  12 + mavenCentral()
  13 + }
  14 +}
  15 +rootProject.name = "SherpaOnnxKws"
  16 +include ':app'
  1 +#!/usr/bin/env bash
  2 +
  3 +# Please set the environment variable ANDROID_NDK
  4 +# before running this script
  5 +
  6 +# Inside the $ANDROID_NDK directory, you can find a binary ndk-build
  7 +# and some other files like the file "build/cmake/android.toolchain.cmake"
  8 +
  9 +set -e
  10 +
  11 +log() {
  12 + # This function is from espnet
  13 + local fname=${BASH_SOURCE[1]##*/}
  14 + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
  15 +}
  16 +
  17 +SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
  18 +
  19 +log "Building keyword spotting APK for sherpa-onnx v${SHERPA_ONNX_VERSION}"
  20 +
  21 +log "====================arm64-v8a================="
  22 +./build-android-arm64-v8a.sh
  23 +log "====================armv7-eabi================"
  24 +./build-android-armv7-eabi.sh
  25 +log "====================x86-64===================="
  26 +./build-android-x86-64.sh
  27 +log "====================x86===================="
  28 +./build-android-x86.sh
  29 +
  30 +mkdir -p apks
  31 +
  32 +# Download the model
  33 +repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
  34 +
  35 +if [ ! -d ./android/SherpaOnnxKws/app/src/main/assets/$repo ]; then
  36 + repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.git
  37 + log "Start testing ${repo_url}"
  38 + log "Download pretrained model and test-data from $repo_url"
  39 + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  40 + pushd $repo
  41 + git lfs pull --include "*.onnx"
  42 +
  43 + # remove .git to save spaces
  44 + rm -rf .git
  45 + rm *.int8.onnx
  46 + rm README.md configuration.json .gitattributes
  47 + rm -rfv test_wavs
  48 + ls -lh
  49 + popd
  50 +
  51 + mv -v $repo ./android/SherpaOnnxKws/app/src/main/assets/
  52 +fi
  53 +
  54 +tree ./android/SherpaOnnxKws/app/src/main/assets/
  55 +
  56 +for arch in arm64-v8a armeabi-v7a x86_64 x86; do
  57 + log "------------------------------------------------------------"
  58 + log "build apk for $arch"
  59 + log "------------------------------------------------------------"
  60 + src_arch=$arch
  61 + if [ $arch == "armeabi-v7a" ]; then
  62 + src_arch=armv7-eabi
  63 + elif [ $arch == "x86_64" ]; then
  64 + src_arch=x86-64
  65 + fi
  66 +
  67 + ls -lh ./build-android-$src_arch/install/lib/*.so
  68 +
  69 + cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/
  70 +
  71 + pushd ./android/SherpaOnnxKws
  72 + ./gradlew build
  73 + popd
  74 +
  75 + mv android/SherpaOnnxKws/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-kws-wenetspeech-zh-${SHERPA_ONNX_VERSION}-$arch.apk
  76 + ls -lh apks
  77 + rm -v ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/*.so
  78 +done
  79 +
  80 +git checkout .
  81 +
  82 +rm -rf ./android/SherpaOnnxKws/app/src/main/assets/$repo
  83 +
  84 +# English model
  85 +repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
  86 +
  87 +if [ ! -d ./android/SherpaOnnxKws/app/src/main/assets/$repo ]; then
  88 + repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.git
  89 + log "Start testing ${repo_url}"
  90 + log "Download pretrained model and test-data from $repo_url"
  91 + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  92 + pushd $repo
  93 + git lfs pull --include "*.onnx"
  94 +
  95 + # remove .git to save spaces
  96 + rm -rf .git
  97 + rm *.int8.onnx
  98 + rm README.md configuration.json .gitattributes
  99 + rm -rfv test_wavs
  100 + ls -lh
  101 + popd
  102 +
  103 + mv -v $repo ./android/SherpaOnnxKws/app/src/main/assets/
  104 +fi
  105 +
  106 +tree ./android/SherpaOnnxKws/app/src/main/assets/
  107 +
  108 +pushd android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx
  109 +sed -i.bak s/"type = 0"/"type = 1"/ ./MainActivity.kt
  110 +git diff
  111 +popd
  112 +
  113 +for arch in arm64-v8a armeabi-v7a x86_64 x86; do
  114 + log "------------------------------------------------------------"
  115 + log "build apk for $arch"
  116 + log "------------------------------------------------------------"
  117 + src_arch=$arch
  118 + if [ $arch == "armeabi-v7a" ]; then
  119 + src_arch=armv7-eabi
  120 + elif [ $arch == "x86_64" ]; then
  121 + src_arch=x86-64
  122 + fi
  123 +
  124 + ls -lh ./build-android-$src_arch/install/lib/*.so
  125 +
  126 + cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/
  127 +
  128 + pushd ./android/SherpaOnnxKws
  129 + ./gradlew build
  130 + popd
  131 +
  132 + mv android/SherpaOnnxKws/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-kws-gigaspeech-en-${SHERPA_ONNX_VERSION}-$arch.apk
  133 + ls -lh apks
  134 + rm -v ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/*.so
  135 +done
  136 +
  137 +git checkout .
  138 +
  139 +rm -rf ./android/SherpaOnnxKws/app/src/main/assets/$repo
@@ -151,6 +151,7 @@ class BuildExtension(build_ext):
151 151 # Remember to also change setup.py
152 152
153 153 binaries = ["sherpa-onnx"]
  154 + binaries += ["sherpa-onnx-keyword-spotter"]
154 155 binaries += ["sherpa-onnx-offline"]
155 156 binaries += ["sherpa-onnx-microphone"]
156 157 binaries += ["sherpa-onnx-microphone-offline"]
@@ -36,13 +36,44 @@ import argparse
36 36
37 37 from sherpa_onnx import text2token
38 38
  39 +
39 40 def get_args():
40 41 parser = argparse.ArgumentParser()
41 42 parser.add_argument(
42 43 "--text",
43 44 type=str,
44 45 required=True,
45 - help="Path to the input texts",
  46 + help="""Path to the input texts.
  47 +
  48 + Each line in the texts contains the original phrase; it may also contain some
  49 + extra items, for example, the boosting score (starting with :), the triggering
  50 + threshold (starting with #, only used in the keyword spotting task) and the
  51 + original phrase (starting with @). Note: extra items will be kept in the output.
  52 +
  53 + example input 1 (tokens_type = ppinyin):
  54 +
  55 + 小爱同学 :2.0 #0.6 @小爱同学
  56 + 你好问问 :3.5 @你好问问
  57 + 小艺小艺 #0.6 @小艺小艺
  58 +
  59 + example output 1:
  60 +
  61 + x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
  62 + n ǐ h ǎo w èn w èn :3.5 @你好问问
  63 + x iǎo y ì x iǎo y ì #0.6 @小艺小艺
  64 +
  65 + example input 2 (tokens_type = bpe):
  66 +
  67 + HELLO WORLD :1.5 #0.4
  68 + HI GOOGLE :2.0 #0.8
  69 + HEY SIRI #0.35
  70 +
  71 + example output 2:
  72 +
  73 + ▁HE LL O ▁WORLD :1.5 #0.4
  74 + ▁HI ▁GO O G LE :2.0 #0.8
  75 + ▁HE Y ▁S I RI #0.35
  76 + """,
46 77 )
47 78
48 79 parser.add_argument(
@@ -56,7 +87,11 @@ def get_args(): @@ -56,7 +87,11 @@ def get_args():
56 "--tokens-type", 87 "--tokens-type",
57 type=str, 88 type=str,
58 required=True, 89 required=True,
59 - help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe", 90 + choices=["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"],
  91 + help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin.
  92 + fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
  93 + ppinyin means partial pinyin, it splits pinyin into initial and final,
  94 + """,
60 ) 95 )
61 96
62 parser.add_argument( 97 parser.add_argument(
@@ -79,9 +114,21 @@ def main(): @@ -79,9 +114,21 @@ def main():
79 args = get_args() 114 args = get_args()
80 115
81 texts = [] 116 texts = []
  117 + # extra information such as the boosting score (starting with :), the triggering
  118 + # threshold (starting with #) and the original keyword (starting with @)
  119 + extra_info = []
82 with open(args.text, "r", encoding="utf8") as f: 120 with open(args.text, "r", encoding="utf8") as f:
83 for line in f: 121 for line in f:
84 - texts.append(line.strip()) 122 + extra = []
  123 + text = []
  124 + toks = line.strip().split()
  125 + for tok in toks:
  126 + if tok[0] == ":" or tok[0] == "#" or tok[0] == "@":
  127 + extra.append(tok)
  128 + else:
  129 + text.append(tok)
  130 + texts.append(" ".join(text))
  131 + extra_info.append(extra)
85 encoded_texts = text2token( 132 encoded_texts = text2token(
86 texts, 133 texts,
87 tokens=args.tokens, 134 tokens=args.tokens,
@@ -89,7 +136,8 @@ def main(): @@ -89,7 +136,8 @@ def main():
89 bpe_model=args.bpe_model, 136 bpe_model=args.bpe_model,
90 ) 137 )
91 with open(args.output, "w", encoding="utf8") as f: 138 with open(args.output, "w", encoding="utf8") as f:
92 - for txt in encoded_texts: 139 + for i, txt in enumerate(encoded_texts):
  140 + txt += extra_info[i]
93 f.write(" ".join(txt) + "\n") 141 f.write(" ".join(txt) + "\n")
94 142
95 143
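
The `:boost`, `#threshold`, `@phrase` convention handled above is the same one the C++ runtime consumes when it loads a keywords file. A minimal standalone sketch of parsing one such line (illustration only; ParseKeywordLine is a hypothetical helper, not the repo's actual EncodeKeywords):

    #include <sstream>
    #include <string>
    #include <vector>

    struct KeywordLine {
      std::vector<std::string> tokens;  // modeling units of the keyword
      float boost = 0.0f;               // 0 means "use the global default"
      float threshold = 0.0f;           // 0 means "use the global default"
      std::string phrase;               // original phrase, if given via @
    };

    // Split a line such as "x iǎo ài :2.0 #0.35 @小爱" into tokens and extras.
    static KeywordLine ParseKeywordLine(const std::string &line) {
      KeywordLine r;
      std::istringstream is(line);
      std::string tok;
      while (is >> tok) {
        if (tok[0] == ':') {
          r.boost = std::stof(tok.substr(1));
        } else if (tok[0] == '#') {
          r.threshold = std::stof(tok.substr(1));
        } else if (tok[0] == '@') {
          r.phrase = tok.substr(1);
        } else {
          r.tokens.push_back(tok);
        }
      }
      return r;
    }
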
@@ -51,6 +51,7 @@ def get_binaries_to_install(): @@ -51,6 +51,7 @@ def get_binaries_to_install():
51 51
52 # Remember to also change cmake/cmake_extension.py 52 # Remember to also change cmake/cmake_extension.py
53 binaries = ["sherpa-onnx"] 53 binaries = ["sherpa-onnx"]
  54 + binaries += ["sherpa-onnx-keyword-spotter"]
54 binaries += ["sherpa-onnx-offline"] 55 binaries += ["sherpa-onnx-offline"]
55 binaries += ["sherpa-onnx-microphone"] 56 binaries += ["sherpa-onnx-microphone"]
56 binaries += ["sherpa-onnx-microphone-offline"] 57 binaries += ["sherpa-onnx-microphone-offline"]
@@ -19,6 +19,8 @@ set(sources @@ -19,6 +19,8 @@ set(sources
19 features.cc 19 features.cc
20 file-utils.cc 20 file-utils.cc
21 hypothesis.cc 21 hypothesis.cc
  22 + keyword-spotter-impl.cc
  23 + keyword-spotter.cc
22 offline-ctc-fst-decoder-config.cc 24 offline-ctc-fst-decoder-config.cc
23 offline-ctc-fst-decoder.cc 25 offline-ctc-fst-decoder.cc
24 offline-ctc-greedy-search-decoder.cc 26 offline-ctc-greedy-search-decoder.cc
@@ -87,6 +89,7 @@ set(sources @@ -87,6 +89,7 @@ set(sources
87 stack.cc 89 stack.cc
88 symbol-table.cc 90 symbol-table.cc
89 text-utils.cc 91 text-utils.cc
  92 + transducer-keyword-decoder.cc
90 transpose.cc 93 transpose.cc
91 unbind.cc 94 unbind.cc
92 utils.cc 95 utils.cc
@@ -173,12 +176,14 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux) @@ -173,12 +176,14 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux)
173 endif() 176 endif()
174 177
175 add_executable(sherpa-onnx sherpa-onnx.cc) 178 add_executable(sherpa-onnx sherpa-onnx.cc)
  179 +add_executable(sherpa-onnx-keyword-spotter sherpa-onnx-keyword-spotter.cc)
176 add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) 180 add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
177 add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) 181 add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
178 add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) 182 add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
179 183
180 set(main_exes 184 set(main_exes
181 sherpa-onnx 185 sherpa-onnx
  186 + sherpa-onnx-keyword-spotter
182 sherpa-onnx-offline 187 sherpa-onnx-offline
183 sherpa-onnx-offline-parallel 188 sherpa-onnx-offline-parallel
184 sherpa-onnx-offline-tts 189 sherpa-onnx-offline-tts
@@ -5,6 +5,7 @@ @@ -5,6 +5,7 @@
5 #include "sherpa-onnx/csrc/context-graph.h" 5 #include "sherpa-onnx/csrc/context-graph.h"
6 6
7 #include <chrono> // NOLINT 7 #include <chrono> // NOLINT
  8 +#include <cmath>
8 #include <map> 9 #include <map>
9 #include <random> 10 #include <random>
10 #include <string> 11 #include <string>
@@ -15,27 +16,25 @@ @@ -15,27 +16,25 @@
15 16
16 namespace sherpa_onnx { 17 namespace sherpa_onnx {
17 18
18 -TEST(ContextGraph, TestBasic) { 19 +static void TestHelper(const std::map<std::string, float> &queries, float score,
  20 + bool strict_mode) {
19 std::vector<std::string> contexts_str( 21 std::vector<std::string> contexts_str(
20 {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"}); 22 {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
21 std::vector<std::vector<int32_t>> contexts; 23 std::vector<std::vector<int32_t>> contexts;
  24 + std::vector<float> scores;
22 for (int32_t i = 0; i < contexts_str.size(); ++i) { 25 for (int32_t i = 0; i < contexts_str.size(); ++i) {
23 contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end()); 26 contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
  27 + scores.push_back(std::round(score / contexts_str[i].size() * 100) / 100);
24 } 28 }
25 - auto context_graph = ContextGraph(contexts, 1);  
26 -  
27 - auto queries = std::map<std::string, float>{  
28 - {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},  
29 - {"SHED", 6}, {"SHELF", 6}, {"HELL", 2},  
30 - {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; 29 + auto context_graph = ContextGraph(contexts, 1, scores);
31 30
32 for (const auto &iter : queries) { 31 for (const auto &iter : queries) {
33 float total_scores = 0; 32 float total_scores = 0;
34 auto state = context_graph.Root(); 33 auto state = context_graph.Root();
35 for (auto q : iter.first) { 34 for (auto q : iter.first) {
36 - auto res = context_graph.ForwardOneStep(state, q);  
37 - total_scores += res.first;  
38 - state = res.second; 35 + auto res = context_graph.ForwardOneStep(state, q, strict_mode);
  36 + total_scores += std::get<0>(res);
  37 + state = std::get<1>(res);
39 } 38 }
40 auto res = context_graph.Finalize(state); 39 auto res = context_graph.Finalize(state);
41 EXPECT_EQ(res.second->token, -1); 40 EXPECT_EQ(res.second->token, -1);
@@ -44,6 +43,37 @@ TEST(ContextGraph, TestBasic) { @@ -44,6 +43,37 @@ TEST(ContextGraph, TestBasic) {
44 } 43 }
45 } 44 }
46 45
  46 +TEST(ContextGraph, TestBasic) {
  47 + auto queries = std::map<std::string, float>{
  48 + {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
  49 + {"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
  50 + {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
  51 + TestHelper(queries, 0, true);
  52 +}
  53 +
  54 +TEST(ContextGraph, TestBasicNonStrict) {
  55 + auto queries = std::map<std::string, float>{
  56 + {"HEHERSHE", 7}, {"HERSHE", 5}, {"HISHE", 5}, {"SHED", 3}, {"SHELF", 3},
  57 + {"HELL", 2}, {"HELLO", 2}, {"DHRHISQ", 3}, {"THEN", 2}};
  58 + TestHelper(queries, 0, false);
  59 +}
  60 +
  61 +TEST(ContextGraph, TestCustomize) {
  62 + auto queries = std::map<std::string, float>{
  63 + {"HEHERSHE", 35.84}, {"HERSHE", 30.84}, {"HISHE", 24.18},
  64 + {"SHED", 18.34}, {"SHELF", 18.34}, {"HELL", 5},
  65 + {"HELLO", 13}, {"DHRHISQ", 10.84}, {"THEN", 5}};
  66 + TestHelper(queries, 5, true);
  67 +}
  68 +
  69 +TEST(ContextGraph, TestCustomizeNonStrict) {
  70 + auto queries = std::map<std::string, float>{
  71 + {"HEHERSHE", 20}, {"HERSHE", 15}, {"HISHE", 10.84},
  72 + {"SHED", 10}, {"SHELF", 10}, {"HELL", 5},
  73 + {"HELLO", 5}, {"DHRHISQ", 5.84}, {"THEN", 5}};
  74 + TestHelper(queries, 5, false);
  75 +}
  76 +
47 TEST(ContextGraph, Benchmark) { 77 TEST(ContextGraph, Benchmark) {
48 std::random_device rd; 78 std::random_device rd;
49 std::mt19937 mt(rd()); 79 std::mt19937 mt(rd());
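
The test helper above distributes a phrase-level boost uniformly over the tokens of each phrase, rounding to two decimals. The same arithmetic in isolation (a restatement of the line in TestHelper, with an assumed example value):

    #include <cmath>
    #include <cstdint>

    // Per-token boost for a phrase of `len` tokens with total boost `score`,
    // rounded to two decimal places as in TestHelper.
    static float PerTokenBoost(float score, int32_t len) {
      return std::round(score / len * 100.0f) / 100.0f;
    }
    // e.g. PerTokenBoost(5.0f, 3) == 1.67f for the three-token phrase "SHE".
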
@@ -4,22 +4,59 @@ @@ -4,22 +4,59 @@
4 4
5 #include "sherpa-onnx/csrc/context-graph.h" 5 #include "sherpa-onnx/csrc/context-graph.h"
6 6
  7 +#include <algorithm>
7 #include <cassert> 8 #include <cassert>
8 #include <queue> 9 #include <queue>
  10 +#include <string>
  11 +#include <tuple>
9 #include <utility> 12 #include <utility>
10 13
  14 +#include "sherpa-onnx/csrc/macros.h"
  15 +
11 namespace sherpa_onnx { 16 namespace sherpa_onnx {
12 -void ContextGraph::Build(  
13 - const std::vector<std::vector<int32_t>> &token_ids) const { 17 +void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
  18 + const std::vector<float> &scores,
  19 + const std::vector<std::string> &phrases,
  20 + const std::vector<float> &ac_thresholds) const {
  21 + if (!scores.empty()) {
  22 + SHERPA_ONNX_CHECK_EQ(token_ids.size(), scores.size());
  23 + }
  24 + if (!phrases.empty()) {
  25 + SHERPA_ONNX_CHECK_EQ(token_ids.size(), phrases.size());
  26 + }
  27 + if (!ac_thresholds.empty()) {
  28 + SHERPA_ONNX_CHECK_EQ(token_ids.size(), ac_thresholds.size());
  29 + }
14 for (int32_t i = 0; i < token_ids.size(); ++i) { 30 for (int32_t i = 0; i < token_ids.size(); ++i) {
15 auto node = root_.get(); 31 auto node = root_.get();
  32 + float score = scores.empty() ? 0.0f : scores[i];
  33 + score = score == 0.0f ? context_score_ : score;
  34 + float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i];
  35 + ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold;
  36 + std::string phrase = phrases.empty() ? std::string() : phrases[i];
  37 +
16 for (int32_t j = 0; j < token_ids[i].size(); ++j) { 38 for (int32_t j = 0; j < token_ids[i].size(); ++j) {
17 int32_t token = token_ids[i][j]; 39 int32_t token = token_ids[i][j];
18 if (0 == node->next.count(token)) { 40 if (0 == node->next.count(token)) {
19 bool is_end = j == token_ids[i].size() - 1; 41 bool is_end = j == token_ids[i].size() - 1;
20 node->next[token] = std::make_unique<ContextState>( 42 node->next[token] = std::make_unique<ContextState>(
21 - token, context_score_, node->node_score + context_score_,  
22 - is_end ? node->node_score + context_score_ : 0, is_end); 43 + token, score, node->node_score + score,
  44 + is_end ? node->node_score + score : 0, j + 1,
  45 + is_end ? ac_threshold : 0.0f, is_end,
  46 + is_end ? phrase : std::string());
  47 + } else {
  48 + float token_score = std::max(score, node->next[token]->token_score);
  49 + node->next[token]->token_score = token_score;
  50 + float node_score = node->node_score + token_score;
  51 + node->next[token]->node_score = node_score;
  52 + bool is_end =
  53 + (j == token_ids[i].size() - 1) || node->next[token]->is_end;
  54 + node->next[token]->output_score = is_end ? node_score : 0.0f;
  55 + node->next[token]->is_end = is_end;
  56 + if (j == token_ids[i].size() - 1) {
  57 + node->next[token]->phrase = phrase;
  58 + node->next[token]->ac_threshold = ac_threshold;
  59 + }
23 } 60 }
24 node = node->next[token].get(); 61 node = node->next[token].get();
25 } 62 }
@@ -27,8 +64,9 @@ void ContextGraph::Build( @@ -27,8 +64,9 @@ void ContextGraph::Build(
27 FillFailOutput(); 64 FillFailOutput();
28 } 65 }
29 66
30 -std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(  
31 - const ContextState *state, int32_t token) const { 67 +std::tuple<float, const ContextState *, const ContextState *>
  68 +ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
  69 + bool strict_mode /*= true*/) const {
32 const ContextState *node; 70 const ContextState *node;
33 float score; 71 float score;
34 if (1 == state->next.count(token)) { 72 if (1 == state->next.count(token)) {
@@ -45,8 +83,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep( @@ -45,8 +83,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
45 } 83 }
46 score = node->node_score - state->node_score; 84 score = node->node_score - state->node_score;
47 } 85 }
  86 +
48 SHERPA_ONNX_CHECK(nullptr != node); 87 SHERPA_ONNX_CHECK(nullptr != node);
49 - return std::make_pair(score + node->output_score, node); 88 +
  89 + const ContextState *matched_node =
  90 + node->is_end ? node : (node->output != nullptr ? node->output : nullptr);
  91 +
  92 + if (!strict_mode && node->output_score != 0) {
  93 + SHERPA_ONNX_CHECK(nullptr != matched_node);
  94 + float output_score =
  95 + node->is_end ? node->node_score
  96 + : (node->output != nullptr ? node->output->node_score
  97 + : node->node_score);
  98 + return std::make_tuple(score + output_score - node->node_score, root_.get(),
  99 + matched_node);
  100 + }
  101 + return std::make_tuple(score + node->output_score, node, matched_node);
50 } 102 }
51 103
52 std::pair<float, const ContextState *> ContextGraph::Finalize( 104 std::pair<float, const ContextState *> ContextGraph::Finalize(
@@ -55,6 +107,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize( @@ -55,6 +107,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
55 return std::make_pair(score, root_.get()); 107 return std::make_pair(score, root_.get());
56 } 108 }
57 109
  110 +std::pair<bool, const ContextState *> ContextGraph::IsMatched(
  111 + const ContextState *state) const {
  112 + bool status = false;
  113 + const ContextState *node = nullptr;
  114 + if (state->is_end) {
  115 + status = true;
  116 + node = state;
  117 + } else {
  118 + if (state->output != nullptr) {
  119 + status = true;
  120 + node = state->output;
  121 + }
  122 + }
  123 + return std::make_pair(status, node);
  124 +}
  125 +
58 void ContextGraph::FillFailOutput() const { 126 void ContextGraph::FillFailOutput() const {
59 std::queue<const ContextState *> node_queue; 127 std::queue<const ContextState *> node_queue;
60 for (auto &kv : root_->next) { 128 for (auto &kv : root_->next) {
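
Taken together, a caller walks a decoded token sequence through the graph one step at a time; in non-strict mode the returned state resets to the root when a keyword fires, and the third tuple element carries the matched node. A sketch of that loop (assuming a built ContextGraph; not code from this commit):

    #include <cstdint>
    #include <tuple>
    #include <vector>

    #include "sherpa-onnx/csrc/context-graph.h"

    // Accumulate context scores over `tokens`; report any keyword hits.
    static float ScanForKeywords(const sherpa_onnx::ContextGraph &graph,
                                 const std::vector<int32_t> &tokens) {
      const sherpa_onnx::ContextState *state = graph.Root();
      float total_score = 0.0f;
      for (int32_t token : tokens) {
        auto res = graph.ForwardOneStep(state, token, /*strict_mode=*/false);
        total_score += std::get<0>(res);
        state = std::get<1>(res);
        const sherpa_onnx::ContextState *matched = std::get<2>(res);
        if (matched != nullptr) {
          // matched->phrase and matched->ac_threshold describe the keyword hit
        }
      }
      return total_score;
    }
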
@@ -6,6 +6,8 @@ @@ -6,6 +6,8 @@
6 #define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_ 6 #define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
7 7
8 #include <memory> 8 #include <memory>
  9 +#include <string>
  10 +#include <tuple>
9 #include <unordered_map> 11 #include <unordered_map>
10 #include <utility> 12 #include <utility>
11 #include <vector> 13 #include <vector>
@@ -22,34 +24,55 @@ struct ContextState { @@ -22,34 +24,55 @@ struct ContextState {
22 float token_score; 24 float token_score;
23 float node_score; 25 float node_score;
24 float output_score; 26 float output_score;
  27 + int32_t level;
  28 + float ac_threshold;
25 bool is_end; 29 bool is_end;
  30 + std::string phrase;
26 std::unordered_map<int32_t, std::unique_ptr<ContextState>> next; 31 std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
27 const ContextState *fail = nullptr; 32 const ContextState *fail = nullptr;
28 const ContextState *output = nullptr; 33 const ContextState *output = nullptr;
29 34
30 ContextState() = default; 35 ContextState() = default;
31 ContextState(int32_t token, float token_score, float node_score, 36 ContextState(int32_t token, float token_score, float node_score,
32 - float output_score, bool is_end) 37 + float output_score, int32_t level = 0, float ac_threshold = 0.0f,
  38 + bool is_end = false, const std::string &phrase = {})
33 : token(token), 39 : token(token),
34 token_score(token_score), 40 token_score(token_score),
35 node_score(node_score), 41 node_score(node_score),
36 output_score(output_score), 42 output_score(output_score),
37 - is_end(is_end) {} 43 + level(level),
  44 + ac_threshold(ac_threshold),
  45 + is_end(is_end),
  46 + phrase(phrase) {}
38 }; 47 };
39 48
40 class ContextGraph { 49 class ContextGraph {
41 public: 50 public:
42 ContextGraph() = default; 51 ContextGraph() = default;
43 ContextGraph(const std::vector<std::vector<int32_t>> &token_ids, 52 ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
44 - float context_score)  
45 - : context_score_(context_score) {  
46 - root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false); 53 + float context_score, float ac_threshold,
  54 + const std::vector<float> &scores = {},
  55 + const std::vector<std::string> &phrases = {},
  56 + const std::vector<float> &ac_thresholds = {})
  57 + : context_score_(context_score), ac_threshold_(ac_threshold) {
  58 + root_ = std::make_unique<ContextState>(-1, 0, 0, 0);
47 root_->fail = root_.get(); 59 root_->fail = root_.get();
48 - Build(token_ids); 60 + Build(token_ids, scores, phrases, ac_thresholds);
49 } 61 }
50 62
51 - std::pair<float, const ContextState *> ForwardOneStep(  
52 - const ContextState *state, int32_t token_id) const; 63 + ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
  64 + float context_score, const std::vector<float> &scores = {},
  65 + const std::vector<std::string> &phrases = {})
  66 + : ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
  67 + std::vector<float>()) {}
  68 +
  69 + std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
  70 + const ContextState *state, int32_t token_id,
  71 + bool strict_mode = true) const;
  72 +
  73 + std::pair<bool, const ContextState *> IsMatched(
  74 + const ContextState *state) const;
  75 +
53 std::pair<float, const ContextState *> Finalize( 76 std::pair<float, const ContextState *> Finalize(
54 const ContextState *state) const; 77 const ContextState *state) const;
55 78
@@ -57,8 +80,12 @@ class ContextGraph { @@ -57,8 +80,12 @@ class ContextGraph {
57 80
58 private: 81 private:
59 float context_score_; 82 float context_score_;
  83 + float ac_threshold_;
60 std::unique_ptr<ContextState> root_; 84 std::unique_ptr<ContextState> root_;
61 - void Build(const std::vector<std::vector<int32_t>> &token_ids) const; 85 + void Build(const std::vector<std::vector<int32_t>> &token_ids,
  86 + const std::vector<float> &scores,
  87 + const std::vector<std::string> &phrases,
  88 + const std::vector<float> &ac_thresholds) const;
62 void FillFailOutput() const; 89 void FillFailOutput() const;
63 }; 90 };
64 91
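
For illustration, building a graph with per-keyword overrides could look like this (hypothetical token ids and values; a score or threshold of 0 falls back to the global context_score / ac_threshold, as Build above shows):

    std::vector<std::vector<int32_t>> token_ids = {{10, 20, 30}, {40, 50}};
    std::vector<float> scores = {2.0f, 0.0f};
    std::vector<std::string> phrases = {"你好问问", ""};
    std::vector<float> thresholds = {0.35f, 0.0f};
    sherpa_onnx::ContextGraph graph(token_ids, /*context_score=*/1.0f,
                                    /*ac_threshold=*/0.25f, scores, phrases,
                                    thresholds);
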
@@ -28,6 +28,10 @@ struct Hypothesis { @@ -28,6 +28,10 @@ struct Hypothesis {
28 // on which ys[i] is decoded. 28 // on which ys[i] is decoded.
29 std::vector<int32_t> timestamps; 29 std::vector<int32_t> timestamps;
30 30
  31 + // The acoustic probability for each token in ys.
  32 + // Only used for keyword spotting task.
  33 + std::vector<float> ys_probs;
  34 +
31 // The total score of ys in log space. 35 // The total score of ys in log space.
32 // It contains only acoustic scores 36 // It contains only acoustic scores
33 double log_prob = 0; 37 double log_prob = 0;
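
ys_probs gives the keyword decoder a per-token acoustic confidence to weigh against each keyword's ac_threshold. One plausible form of that check, averaging over the matched keyword's tokens, is sketched below (a reading of the design, not the decoder's verbatim code):

    #include <cstdint>
    #include <vector>

    // Average the acoustic probabilities of the last `level` tokens (the
    // matched keyword) and trigger only above the keyword's threshold.
    static bool Triggered(const std::vector<float> &ys_probs, int32_t level,
                          float ac_threshold) {
      float sum = 0.0f;
      int32_t n = static_cast<int32_t>(ys_probs.size());
      for (int32_t i = n - level; i < n; ++i) {
        sum += ys_probs[i];
      }
      return sum / level >= ac_threshold;
    }
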
  1 +// sherpa-onnx/csrc/keyword-spotter-impl.cc
  2 +//
  3 +// Copyright (c) 2023-2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
  6 +
  7 +#include "sherpa-onnx/csrc/keyword-spotter-transducer-impl.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
  12 + const KeywordSpotterConfig &config) {
  13 + if (!config.model_config.transducer.encoder.empty()) {
  14 + return std::make_unique<KeywordSpotterTransducerImpl>(config);
  15 + }
  16 +
  17 + SHERPA_ONNX_LOGE("Please specify a model");
  18 + exit(-1);
  19 +}
  20 +
  21 +#if __ANDROID_API__ >= 9
  22 +std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
  23 + AAssetManager *mgr, const KeywordSpotterConfig &config) {
  24 + if (!config.model_config.transducer.encoder.empty()) {
  25 + return std::make_unique<KeywordSpotterTransducerImpl>(mgr, config);
  26 + }
  27 +
  28 + SHERPA_ONNX_LOGE("Please specify a model");
  29 + exit(-1);
  30 +}
  31 +#endif
  32 +
  33 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/keyword-spotter-impl.h
  2 +//
  3 +// Copyright (c) 2023-2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +#include <vector>
  11 +
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
  17 +#include "sherpa-onnx/csrc/keyword-spotter.h"
  18 +#include "sherpa-onnx/csrc/online-stream.h"
  19 +
  20 +namespace sherpa_onnx {
  21 +
  22 +class KeywordSpotterImpl {
  23 + public:
  24 + static std::unique_ptr<KeywordSpotterImpl> Create(
  25 + const KeywordSpotterConfig &config);
  26 +
  27 +#if __ANDROID_API__ >= 9
  28 + static std::unique_ptr<KeywordSpotterImpl> Create(
  29 + AAssetManager *mgr, const KeywordSpotterConfig &config);
  30 +#endif
  31 +
  32 + virtual ~KeywordSpotterImpl() = default;
  33 +
  34 + virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
  35 +
  36 + virtual std::unique_ptr<OnlineStream> CreateStream(
  37 + const std::string &keywords) const = 0;
  38 +
  39 + virtual bool IsReady(OnlineStream *s) const = 0;
  40 +
  41 + virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
  42 +
  43 + virtual KeywordResult GetResult(OnlineStream *s) const = 0;
  44 +};
  45 +
  46 +} // namespace sherpa_onnx
  47 +
  48 +#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
  1 +// sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
  2 +//
  3 +// Copyright (c) 2023-2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
  7 +
  8 +#include <algorithm>
  9 +#include <memory>
  10 +#include <regex> // NOLINT
  11 +#include <string>
  12 +#include <utility>
  13 +#include <vector>
  14 +
  15 +#if __ANDROID_API__ >= 9
  16 +#include <strstream>
  17 +
  18 +#include "android/asset_manager.h"
  19 +#include "android/asset_manager_jni.h"
  20 +#endif
  21 +
  22 +#include "sherpa-onnx/csrc/file-utils.h"
  23 +#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
  24 +#include "sherpa-onnx/csrc/keyword-spotter.h"
  25 +#include "sherpa-onnx/csrc/macros.h"
  26 +#include "sherpa-onnx/csrc/online-transducer-model.h"
  27 +#include "sherpa-onnx/csrc/symbol-table.h"
  28 +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
  29 +#include "sherpa-onnx/csrc/utils.h"
  30 +
  31 +namespace sherpa_onnx {
  32 +
  33 +static KeywordResult Convert(const TransducerKeywordResult &src,
  34 + const SymbolTable &sym_table, float frame_shift_ms,
  35 + int32_t subsampling_factor,
  36 + int32_t frames_since_start) {
  37 + KeywordResult r;
  38 + r.tokens.reserve(src.tokens.size());
  39 + r.timestamps.reserve(src.tokens.size());
  40 + r.keyword = src.keyword;
  41 + bool from_tokens = src.keyword.empty();
  42 +
  43 + for (auto i : src.tokens) {
  44 + auto sym = sym_table[i];
  45 + if (from_tokens) {
  46 + r.keyword.append(sym);
  47 + }
  48 + r.tokens.push_back(std::move(sym));
  49 + }
  50 + if (from_tokens && r.keyword.size()) {
  51 + r.keyword = r.keyword.substr(1);
  52 + }
  53 +
  54 + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
  55 + for (auto t : src.timestamps) {
  56 + float time = frame_shift_s * t;
  57 + r.timestamps.push_back(time);
  58 + }
  59 +
  60 + r.start_time = frames_since_start * frame_shift_ms / 1000.;
  61 +
  62 + return r;
  63 +}
  64 +
  65 +class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
  66 + public:
  67 + explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config)
  68 + : config_(config),
  69 + model_(OnlineTransducerModel::Create(config.model_config)),
  70 + sym_(config.model_config.tokens) {
  71 + if (sym_.contains("<unk>")) {
  72 + unk_id_ = sym_["<unk>"];
  73 + }
  74 +
  75 + InitKeywords();
  76 +
  77 + decoder_ = std::make_unique<TransducerKeywordDecoder>(
  78 + model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
  79 + unk_id_);
  80 + }
  81 +
  82 +#if __ANDROID_API__ >= 9
  83 + KeywordSpotterTransducerImpl(AAssetManager *mgr,
  84 + const KeywordSpotterConfig &config)
  85 + : config_(config),
  86 + model_(OnlineTransducerModel::Create(mgr, config.model_config)),
  87 + sym_(mgr, config.model_config.tokens) {
  88 + if (sym_.contains("<unk>")) {
  89 + unk_id_ = sym_["<unk>"];
  90 + }
  91 +
  92 + InitKeywords(mgr);
  93 +
  94 + decoder_ = std::make_unique<TransducerKeywordDecoder>(
  95 + model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
  96 + unk_id_);
  97 + }
  98 +#endif
  99 +
  100 + std::unique_ptr<OnlineStream> CreateStream() const override {
  101 + auto stream =
  102 + std::make_unique<OnlineStream>(config_.feat_config, keywords_graph_);
  103 + InitOnlineStream(stream.get());
  104 + return stream;
  105 + }
  106 +
  107 + std::unique_ptr<OnlineStream> CreateStream(
  108 + const std::string &keywords) const override {
  109 + auto kws = std::regex_replace(keywords, std::regex("/"), "\n");
  110 + std::istringstream is(kws);
  111 +
  112 + std::vector<std::vector<int32_t>> current_ids;
  113 + std::vector<std::string> current_kws;
  114 + std::vector<float> current_scores;
  115 + std::vector<float> current_thresholds;
  116 +
  117 + if (!EncodeKeywords(is, sym_, &current_ids, &current_kws, &current_scores,
  118 + &current_thresholds)) {
  119 + SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str());
  120 + return nullptr;
  121 + }
  122 +
  123 + int32_t num_kws = current_ids.size();
  124 + int32_t num_default_kws = keywords_id_.size();
  125 +
  126 + current_ids.insert(current_ids.end(), keywords_id_.begin(),
  127 + keywords_id_.end());
  128 +
  129 + if (!current_kws.empty() && !keywords_.empty()) {
  130 + current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end());
  131 + } else if (!current_kws.empty() && keywords_.empty()) {
  132 + current_kws.insert(current_kws.end(), num_default_kws, std::string());
  133 + } else if (current_kws.empty() && !keywords_.empty()) {
  134 + current_kws.insert(current_kws.end(), num_kws, std::string());
  135 + current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end());
  136 + } else {
  137 + // Do nothing.
  138 + }
  139 +
  140 + if (!current_scores.empty() && !boost_scores_.empty()) {
  141 + current_scores.insert(current_scores.end(), boost_scores_.begin(),
  142 + boost_scores_.end());
  143 + } else if (!current_scores.empty() && boost_scores_.empty()) {
  144 + current_scores.insert(current_scores.end(), num_default_kws,
  145 + config_.keywords_score);
  146 + } else if (current_scores.empty() && !boost_scores_.empty()) {
  147 + current_scores.insert(current_scores.end(), num_kws,
  148 + config_.keywords_score);
  149 + current_scores.insert(current_scores.end(), boost_scores_.begin(),
  150 + boost_scores_.end());
  151 + } else {
  152 + // Do nothing.
  153 + }
  154 +
  155 + if (!current_thresholds.empty() && !thresholds_.empty()) {
  156 + current_thresholds.insert(current_thresholds.end(), thresholds_.begin(),
  157 + thresholds_.end());
  158 + } else if (!current_thresholds.empty() && thresholds_.empty()) {
  159 + current_thresholds.insert(current_thresholds.end(), num_default_kws,
  160 + config_.keywords_threshold);
  161 + } else if (current_thresholds.empty() && !thresholds_.empty()) {
  162 + current_thresholds.insert(current_thresholds.end(), num_kws,
  163 + config_.keywords_threshold);
  164 + current_thresholds.insert(current_thresholds.end(), thresholds_.begin(),
  165 + thresholds_.end());
  166 + } else {
  167 + // Do nothing.
  168 + }
  169 +
  170 + auto keywords_graph = std::make_shared<ContextGraph>(
  171 + current_ids, config_.keywords_score, config_.keywords_threshold,
  172 + current_scores, current_kws, current_thresholds);
  173 +
  174 + auto stream =
  175 + std::make_unique<OnlineStream>(config_.feat_config, keywords_graph);
  176 + InitOnlineStream(stream.get());
  177 + return stream;
  178 + }
  179 +
  180 + bool IsReady(OnlineStream *s) const override {
  181 + return s->GetNumProcessedFrames() + model_->ChunkSize() <
  182 + s->NumFramesReady();
  183 + }
  184 +
  185 + void DecodeStreams(OnlineStream **ss, int32_t n) const override {
  186 + int32_t chunk_size = model_->ChunkSize();
  187 + int32_t chunk_shift = model_->ChunkShift();
  188 +
  189 + int32_t feature_dim = ss[0]->FeatureDim();
  190 +
  191 + std::vector<TransducerKeywordResult> results(n);
  192 + std::vector<float> features_vec(n * chunk_size * feature_dim);
  193 + std::vector<std::vector<Ort::Value>> states_vec(n);
  194 + std::vector<int64_t> all_processed_frames(n);
  195 +
  196 + for (int32_t i = 0; i != n; ++i) {
  197 + SHERPA_ONNX_CHECK(ss[i]->GetContextGraph() != nullptr);
  198 +
  199 + const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
  200 + std::vector<float> features =
  201 + ss[i]->GetFrames(num_processed_frames, chunk_size);
  202 +
  203 + // Question: should num_processed_frames include chunk_shift?
  204 + ss[i]->GetNumProcessedFrames() += chunk_shift;
  205 +
  206 + std::copy(features.begin(), features.end(),
  207 + features_vec.data() + i * chunk_size * feature_dim);
  208 +
  209 + results[i] = std::move(ss[i]->GetKeywordResult());
  210 + states_vec[i] = std::move(ss[i]->GetStates());
  211 + all_processed_frames[i] = num_processed_frames;
  212 + }
  213 +
  214 + auto memory_info =
  215 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  216 +
  217 + std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
  218 +
  219 + Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
  220 + features_vec.size(), x_shape.data(),
  221 + x_shape.size());
  222 +
  223 + std::array<int64_t, 1> processed_frames_shape{
  224 + static_cast<int64_t>(all_processed_frames.size())};
  225 +
  226 + Ort::Value processed_frames = Ort::Value::CreateTensor(
  227 + memory_info, all_processed_frames.data(), all_processed_frames.size(),
  228 + processed_frames_shape.data(), processed_frames_shape.size());
  229 +
  230 + auto states = model_->StackStates(states_vec);
  231 +
  232 + auto pair = model_->RunEncoder(std::move(x), std::move(states),
  233 + std::move(processed_frames));
  234 +
  235 + decoder_->Decode(std::move(pair.first), ss, &results);
  236 +
  237 + std::vector<std::vector<Ort::Value>> next_states =
  238 + model_->UnStackStates(pair.second);
  239 +
  240 + for (int32_t i = 0; i != n; ++i) {
  241 + ss[i]->SetKeywordResult(results[i]);
  242 + ss[i]->SetStates(std::move(next_states[i]));
  243 + }
  244 + }
  245 +
  246 + KeywordResult GetResult(OnlineStream *s) const override {
  247 + TransducerKeywordResult decoder_result = s->GetKeywordResult(true);
  248 +
  249 + // TODO(fangjun): Remember to change these constants if needed
  250 + int32_t frame_shift_ms = 10;
  251 + int32_t subsampling_factor = 4;
  252 + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
  253 + s->GetNumFramesSinceStart());
  254 + }
  255 +
  256 + private:
  257 + void InitKeywords(std::istream &is) {
  258 + if (!EncodeKeywords(is, sym_, &keywords_id_, &keywords_, &boost_scores_,
  259 + &thresholds_)) {
  260 + SHERPA_ONNX_LOGE("Encode keywords failed.");
  261 + exit(-1);
  262 + }
  263 + keywords_graph_ = std::make_shared<ContextGraph>(
  264 + keywords_id_, config_.keywords_score, config_.keywords_threshold,
  265 + boost_scores_, keywords_, thresholds_);
  266 + }
  267 +
  268 + void InitKeywords() {
  269 + // each line in keywords_file contains space-separated words
  270 +
  271 + std::ifstream is(config_.keywords_file);
  272 + if (!is) {
  273 + SHERPA_ONNX_LOGE("Open keywords file failed: %s",
  274 + config_.keywords_file.c_str());
  275 + exit(-1);
  276 + }
  277 + InitKeywords(is);
  278 + }
  279 +
  280 +#if __ANDROID_API__ >= 9
  281 + void InitKeywords(AAssetManager *mgr) {
  282 + // each line in keywords_file contains space-separated words
  283 +
  284 + auto buf = ReadFile(mgr, config_.keywords_file);
  285 +
  286 + std::istrstream is(buf.data(), buf.size());
  287 +
  288 + if (!is) {
  289 + SHERPA_ONNX_LOGE("Open keywords file failed: %s",
  290 + config_.keywords_file.c_str());
  291 + exit(-1);
  292 + }
  293 + InitKeywords(is);
  294 + }
  295 +#endif
  296 +
  297 + void InitOnlineStream(OnlineStream *stream) const {
  298 + auto r = decoder_->GetEmptyResult();
  299 + SHERPA_ONNX_CHECK_EQ(r.hyps.size(), 1);
  300 +
  301 + SHERPA_ONNX_CHECK(stream->GetContextGraph() != nullptr);
  302 + r.hyps.begin()->second.context_state = stream->GetContextGraph()->Root();
  303 +
  304 + stream->SetKeywordResult(r);
  305 + stream->SetStates(model_->GetEncoderInitStates());
  306 + }
  307 +
  308 + private:
  309 + KeywordSpotterConfig config_;
  310 + std::vector<std::vector<int32_t>> keywords_id_;
  311 + std::vector<float> boost_scores_;
  312 + std::vector<float> thresholds_;
  313 + std::vector<std::string> keywords_;
  314 + ContextGraphPtr keywords_graph_;
  315 + std::unique_ptr<OnlineTransducerModel> model_;
  316 + std::unique_ptr<TransducerKeywordDecoder> decoder_;
  317 + SymbolTable sym_;
  318 + int32_t unk_id_ = -1;
  319 +};
  320 +
  321 +} // namespace sherpa_onnx
  322 +
  323 +#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
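
A stream-specific keywords string separates keywords with "/" and may carry the same extra items as a keywords file. For example (assuming `spotter` is a constructed KeywordSpotter for a cjkchar+bpe model and the tokens exist in its tokens.txt):

    auto stream =
        spotter.CreateStream("▁HE LL O ▁WORLD :1.5 #0.4/你 好 世 界 @你好世界");
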
  1 +// sherpa-onnx/csrc/keyword-spotter.cc
  2 +//
  3 +// Copyright (c) 2023-2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/keyword-spotter.h"
  6 +
  7 +#include <assert.h>
  8 +
  9 +#include <algorithm>
  10 +#include <fstream>
  11 +#include <iomanip>
  12 +#include <memory>
  13 +#include <sstream>
  14 +#include <utility>
  15 +#include <vector>
  16 +
  17 +#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
  18 +
  19 +namespace sherpa_onnx {
  20 +
  21 +std::string KeywordResult::AsJsonString() const {
  22 + std::ostringstream os;
  23 + os << "{";
  24 + os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
  25 + << ", ";
  26 +
  27 + os << "\"keyword\""
  28 + << ": ";
  29 + os << "\"" << keyword << "\""
  30 + << ", ";
  31 +
  32 + os << "\""
  33 + << "timestamps"
  34 + << "\""
  35 + << ": ";
  36 + os << "[";
  37 +
  38 + std::string sep = "";
  39 + for (auto t : timestamps) {
  40 + os << sep << std::fixed << std::setprecision(2) << t;
  41 + sep = ", ";
  42 + }
  43 + os << "], ";
  44 +
  45 + os << "\""
  46 + << "tokens"
  47 + << "\""
  48 + << ":";
  49 + os << "[";
  50 +
  51 + sep = "";
  52 + auto oldFlags = os.flags();
  53 + for (const auto &t : tokens) {
  54 + if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
  55 + const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
  56 + os << sep << "\""
  57 + << "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
  58 + << ">"
  59 + << "\"";
  60 + os.flags(oldFlags);
  61 + } else {
  62 + os << sep << "\"" << t << "\"";
  63 + }
  64 + sep = ", ";
  65 + }
  66 + os << "]";
  67 + os << "}";
  68 +
  69 + return os.str();
  70 +}
  71 +
  72 +void KeywordSpotterConfig::Register(ParseOptions *po) {
  73 + feat_config.Register(po);
  74 + model_config.Register(po);
  75 +
  76 + po->Register("max-active-paths", &max_active_paths,
  77 + "beam size used in modified beam search.");
  78 + po->Register("num-trailing-blanks", &num_trailing_blanks,
  79 + "The number of trailing blanks should have after the keyword.");
  80 + po->Register("keywords-score", &keywords_score,
  81 + "The bonus score for each token in context word/phrase.");
  82 + po->Register("keywords-threshold", &keywords_threshold,
  83 + "The acoustic threshold (probability) to trigger the keywords.");
  84 + po->Register(
  85 + "keywords-file", &keywords_file,
  86 + "The file containing keywords, one word/phrase per line, and for each"
  87 + "phrase the bpe/cjkchar are separated by a space. For example: "
  88 + "▁HE LL O ▁WORLD"
  89 + "你 好 世 界");
  90 +}
  91 +
  92 +bool KeywordSpotterConfig::Validate() const {
  93 + if (keywords_file.empty()) {
  94 + SHERPA_ONNX_LOGE("Please provide --keywords-file.");
  95 + return false;
  96 + }
  97 + if (!std::ifstream(keywords_file.c_str()).good()) {
  98 + SHERPA_ONNX_LOGE("Keywords file %s does not exist.", keywords_file.c_str());
  99 + return false;
  100 + }
  101 +
  102 + return model_config.Validate();
  103 +}
  104 +
  105 +std::string KeywordSpotterConfig::ToString() const {
  106 + std::ostringstream os;
  107 +
  108 + os << "KeywordSpotterConfig(";
  109 + os << "feat_config=" << feat_config.ToString() << ", ";
  110 + os << "model_config=" << model_config.ToString() << ", ";
  111 + os << "max_active_paths=" << max_active_paths << ", ";
  112 + os << "num_trailing_blanks=" << num_trailing_blanks << ", ";
  113 + os << "keywords_score=" << keywords_score << ", ";
  114 + os << "keywords_threshold=" << keywords_threshold << ", ";
  115 + os << "keywords_file=\"" << keywords_file << "\")";
  116 +
  117 + return os.str();
  118 +}
  119 +
  120 +KeywordSpotter::KeywordSpotter(const KeywordSpotterConfig &config)
  121 + : impl_(KeywordSpotterImpl::Create(config)) {}
  122 +
  123 +#if __ANDROID_API__ >= 9
  124 +KeywordSpotter::KeywordSpotter(AAssetManager *mgr,
  125 + const KeywordSpotterConfig &config)
  126 + : impl_(KeywordSpotterImpl::Create(mgr, config)) {}
  127 +#endif
  128 +
  129 +KeywordSpotter::~KeywordSpotter() = default;
  130 +
  131 +std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream() const {
  132 + return impl_->CreateStream();
  133 +}
  134 +
  135 +std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream(
  136 + const std::string &keywords) const {
  137 + return impl_->CreateStream(keywords);
  138 +}
  139 +
  140 +bool KeywordSpotter::IsReady(OnlineStream *s) const {
  141 + return impl_->IsReady(s);
  142 +}
  143 +
  144 +void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
  145 + impl_->DecodeStreams(ss, n);
  146 +}
  147 +
  148 +KeywordResult KeywordSpotter::GetResult(OnlineStream *s) const {
  149 + return impl_->GetResult(s);
  150 +}
  151 +
  152 +} // namespace sherpa_onnx
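
For reference, a keywords file combining these options might look like the ppinyin example from the text2token help above (entries are illustrative):

    x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
    n ǐ h ǎo w èn w èn :3.5 @你好问问
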
  1 +// sherpa-onnx/csrc/keyword-spotter.h
  2 +//
  3 +// Copyright (c) 2023-2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
  6 +#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +#include <vector>
  11 +
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
  17 +#include "sherpa-onnx/csrc/features.h"
  18 +#include "sherpa-onnx/csrc/online-model-config.h"
  19 +#include "sherpa-onnx/csrc/online-stream.h"
  20 +#include "sherpa-onnx/csrc/online-transducer-model-config.h"
  21 +#include "sherpa-onnx/csrc/parse-options.h"
  22 +
  23 +namespace sherpa_onnx {
  24 +
  25 +struct KeywordResult {
  26 + /// The triggered keyword.
  27 + /// For English, it consists of space separated words.
  28 + /// For Chinese, it consists of Chinese words without spaces.
  29 + /// Example 1: "hello world"
  30 + /// Example 2: "你好世界"
  31 + std::string keyword;
  32 +
  33 + /// Decoded results at the token level.
  34 + /// For instance, for BPE-based models it consists of a list of BPE tokens.
  35 + std::vector<std::string> tokens;
  36 +
  37 + /// timestamps.size() == tokens.size()
  38 + /// timestamps[i] records the time in seconds when tokens[i] is decoded.
  39 + std::vector<float> timestamps;
  40 +
  41 + /// Starting time of this segment.
  42 + /// When an endpoint is detected, it will change.
  43 + float start_time = 0;
  44 +
  45 + /** Return a json string.
  46 + *
  47 + * The returned string contains:
  48 + * {
  49 + * "keyword": "The triggered keyword",
  50 + * "tokens": [x, x, x],
  51 + * "timestamps": [x, x, x],
  52 + * "start_time": x,
  53 + * }
  54 + */
  55 + std::string AsJsonString() const;
  56 +};
  57 +
  58 +struct KeywordSpotterConfig {
  59 + FeatureExtractorConfig feat_config;
  60 + OnlineModelConfig model_config;
  61 +
  62 + int32_t max_active_paths = 4;
  63 +
  64 + int32_t num_trailing_blanks = 1;
  65 +
  66 + float keywords_score = 1.0;
  67 +
  68 + float keywords_threshold = 0.25;
  69 +
  70 + std::string keywords_file;
  71 +
  72 + KeywordSpotterConfig() = default;
  73 +
  74 + KeywordSpotterConfig(const FeatureExtractorConfig &feat_config,
  75 + const OnlineModelConfig &model_config,
  76 + int32_t max_active_paths, int32_t num_trailing_blanks,
  77 + float keywords_score, float keywords_threshold,
  78 + const std::string &keywords_file)
  79 + : feat_config(feat_config),
  80 + model_config(model_config),
  81 + max_active_paths(max_active_paths),
  82 + num_trailing_blanks(num_trailing_blanks),
  83 + keywords_score(keywords_score),
  84 + keywords_threshold(keywords_threshold),
  85 + keywords_file(keywords_file) {}
  86 +
  87 + void Register(ParseOptions *po);
  88 + bool Validate() const;
  89 +
  90 + std::string ToString() const;
  91 +};
  92 +
  93 +class KeywordSpotterImpl;
  94 +
  95 +class KeywordSpotter {
  96 + public:
  97 + explicit KeywordSpotter(const KeywordSpotterConfig &config);
  98 +
  99 +#if __ANDROID_API__ >= 9
  100 + KeywordSpotter(AAssetManager *mgr, const KeywordSpotterConfig &config);
  101 +#endif
  102 +
  103 + ~KeywordSpotter();
  104 +
  105 + /** Create a stream for decoding.
  106 + *
  107 + */
  108 + std::unique_ptr<OnlineStream> CreateStream() const;
  109 +
  110 + /** Create a stream for decoding.
  111 + *
  112 + * @param keywords The keywords for this stream. It may contain several
  113 + * keywords separated by "/". Each keyword consists of cjkchars or
  114 + * bpes, with the bpe/cjkchar separated by spaces (" ").
  115 + * For example, the keywords I LOVE YOU and HELLO WORLD look like:
  116 + *
  117 + * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
  118 + */
  119 + std::unique_ptr<OnlineStream> CreateStream(const std::string &keywords) const;
  120 +
  121 + /**
  122 + * Return true if the given stream has enough frames for decoding.
  123 + * Return false otherwise.
  124 + */
  125 + bool IsReady(OnlineStream *s) const;
  126 +
  127 + /** Decode a single stream. */
  128 + void DecodeStream(OnlineStream *s) const {
  129 + OnlineStream *ss[1] = {s};
  130 + DecodeStreams(ss, 1);
  131 + }
  132 +
  133 + /** Decode multiple streams in parallel
  134 + *
  135 + * @param ss Pointer array containing streams to be decoded.
  136 + * @param n Number of streams in `ss`.
  137 + */
  138 + void DecodeStreams(OnlineStream **ss, int32_t n) const;
  139 +
  140 + KeywordResult GetResult(OnlineStream *s) const;
  141 +
  142 + private:
  143 + std::unique_ptr<KeywordSpotterImpl> impl_;
  144 +};
  145 +
  146 +} // namespace sherpa_onnx
  147 +
  148 +#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
@@ -93,8 +93,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( @@ -93,8 +93,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
93 Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); 93 Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
94 // now cur_encoder_out is of shape (num_hyps, joiner_dim) 94 // now cur_encoder_out is of shape (num_hyps, joiner_dim)
95 95
96 - Ort::Value logit = model_->RunJoiner(  
97 - std::move(cur_encoder_out), View(&decoder_out)); 96 + Ort::Value logit =
  97 + model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
98 98
99 float *p_logit = logit.GetTensorMutableData<float>(); 99 float *p_logit = logit.GetTensorMutableData<float>();
100 LogSoftmax(p_logit, vocab_size, num_hyps); 100 LogSoftmax(p_logit, vocab_size, num_hyps);
@@ -134,8 +134,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( @@ -134,8 +134,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
134 if (context_graphs[i] != nullptr) { 134 if (context_graphs[i] != nullptr) {
135 auto context_res = 135 auto context_res =
136 context_graphs[i]->ForwardOneStep(context_state, new_token); 136 context_graphs[i]->ForwardOneStep(context_state, new_token);
137 - context_score = context_res.first;  
138 - new_hyp.context_state = context_res.second; 137 + context_score = std::get<0>(context_res);
  138 + new_hyp.context_state = std::get<1>(context_res);
139 } 139 }
140 } 140 }
141 141
@@ -51,6 +51,25 @@ class OnlineStream::Impl { @@ -51,6 +51,25 @@ class OnlineStream::Impl {
51 51
52 OnlineTransducerDecoderResult &GetResult() { return result_; } 52 OnlineTransducerDecoderResult &GetResult() { return result_; }
53 53
  54 + void SetKeywordResult(const TransducerKeywordResult &r) {
  55 + keyword_result_ = r;
  56 + }
  57 + TransducerKeywordResult &GetKeywordResult(bool remove_duplicates) {
  58 + if (remove_duplicates) {
  59 + if (!prev_keyword_result_.timestamps.empty() &&
  60 + !keyword_result_.timestamps.empty() &&
  61 + keyword_result_.timestamps[0] <=
  62 + prev_keyword_result_.timestamps.back()) {
  63 + return empty_keyword_result_;
  64 + } else {
  65 + prev_keyword_result_ = keyword_result_;
  66 + }
  67 + return keyword_result_;
  68 + } else {
  69 + return keyword_result_;
  70 + }
  71 + }
  72 +
54 OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; } 73 OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; }
55 74
56 void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; } 75 void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; }
@@ -93,6 +112,9 @@ class OnlineStream::Impl { @@ -93,6 +112,9 @@ class OnlineStream::Impl {
93 int32_t start_frame_index_ = 0; // never reset 112 int32_t start_frame_index_ = 0; // never reset
94 int32_t segment_ = 0; 113 int32_t segment_ = 0;
95 OnlineTransducerDecoderResult result_; 114 OnlineTransducerDecoderResult result_;
  115 + TransducerKeywordResult prev_keyword_result_;
  116 + TransducerKeywordResult keyword_result_;
  117 + TransducerKeywordResult empty_keyword_result_;
96 OnlineCtcDecoderResult ctc_result_; 118 OnlineCtcDecoderResult ctc_result_;
97 std::vector<Ort::Value> states_; // states for transducer or ctc models 119 std::vector<Ort::Value> states_; // states for transducer or ctc models
98 std::vector<float> paraformer_feat_cache_; 120 std::vector<float> paraformer_feat_cache_;
@@ -149,6 +171,15 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() { @@ -149,6 +171,15 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
149 return impl_->GetResult(); 171 return impl_->GetResult();
150 } 172 }
151 173
  174 +void OnlineStream::SetKeywordResult(const TransducerKeywordResult &r) {
  175 + impl_->SetKeywordResult(r);
  176 +}
  177 +
  178 +TransducerKeywordResult &OnlineStream::GetKeywordResult(
  179 + bool remove_duplicates /*=false*/) {
  180 + return impl_->GetKeywordResult(remove_duplicates);
  181 +}
  182 +
152 OnlineCtcDecoderResult &OnlineStream::GetCtcResult() { 183 OnlineCtcDecoderResult &OnlineStream::GetCtcResult() {
153 return impl_->GetCtcResult(); 184 return impl_->GetCtcResult();
154 } 185 }
@@ -14,9 +14,11 @@ @@ -14,9 +14,11 @@
14 #include "sherpa-onnx/csrc/online-ctc-decoder.h" 14 #include "sherpa-onnx/csrc/online-ctc-decoder.h"
15 #include "sherpa-onnx/csrc/online-paraformer-decoder.h" 15 #include "sherpa-onnx/csrc/online-paraformer-decoder.h"
16 #include "sherpa-onnx/csrc/online-transducer-decoder.h" 16 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
  17 +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
17 18
18 namespace sherpa_onnx { 19 namespace sherpa_onnx {
19 20
  21 +struct TransducerKeywordResult;
20 class OnlineStream { 22 class OnlineStream {
21 public: 23 public:
22 explicit OnlineStream(const FeatureExtractorConfig &config = {}, 24 explicit OnlineStream(const FeatureExtractorConfig &config = {},
@@ -76,6 +78,9 @@ class OnlineStream { @@ -76,6 +78,9 @@ class OnlineStream {
76 void SetResult(const OnlineTransducerDecoderResult &r); 78 void SetResult(const OnlineTransducerDecoderResult &r);
77 OnlineTransducerDecoderResult &GetResult(); 79 OnlineTransducerDecoderResult &GetResult();
78 80
  81 + void SetKeywordResult(const TransducerKeywordResult &r);
  82 + TransducerKeywordResult &GetKeywordResult(bool remove_duplicates = false);
  83 +
79 void SetCtcResult(const OnlineCtcDecoderResult &r); 84 void SetCtcResult(const OnlineCtcDecoderResult &r);
80 OnlineCtcDecoderResult &GetCtcResult(); 85 OnlineCtcDecoderResult &GetCtcResult();
81 86
@@ -92,7 +97,7 @@ class OnlineStream { @@ -92,7 +97,7 @@ class OnlineStream {
92 */ 97 */
93 const ContextGraphPtr &GetContextGraph() const; 98 const ContextGraphPtr &GetContextGraph() const;
94 99
95 - // for streaming parformer 100 + // for streaming paraformer
96 std::vector<float> &GetParaformerFeatCache(); 101 std::vector<float> &GetParaformerFeatCache();
97 std::vector<float> &GetParaformerEncoderOutCache(); 102 std::vector<float> &GetParaformerEncoderOutCache();
98 std::vector<float> &GetParaformerAlphaCache(); 103 std::vector<float> &GetParaformerAlphaCache();
@@ -75,10 +75,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -75,10 +75,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
75 encoder_out.GetTensorTypeAndShapeInfo().GetShape(); 75 encoder_out.GetTensorTypeAndShapeInfo().GetShape();
76 76
77 if (encoder_out_shape[0] != result->size()) { 77 if (encoder_out_shape[0] != result->size()) {
78 - fprintf(stderr,  
79 - "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",  
80 - static_cast<int32_t>(encoder_out_shape[0]),  
81 - static_cast<int32_t>(result->size())); 78 + SHERPA_ONNX_LOGE(
  79 + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
  80 + static_cast<int32_t>(encoder_out_shape[0]),
  81 + static_cast<int32_t>(result->size()));
82 exit(-1); 82 exit(-1);
83 } 83 }
84 84
@@ -119,8 +119,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -119,8 +119,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
119 GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); 119 GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
120 cur_encoder_out = 120 cur_encoder_out =
121 Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); 121 Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
122 - Ort::Value logit = model_->RunJoiner(  
123 - std::move(cur_encoder_out), View(&decoder_out)); 122 + Ort::Value logit =
  123 + model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
124 124
125 float *p_logit = logit.GetTensorMutableData<float>(); 125 float *p_logit = logit.GetTensorMutableData<float>();
126 LogSoftmax(p_logit, vocab_size, num_hyps); 126 LogSoftmax(p_logit, vocab_size, num_hyps);
@@ -164,8 +164,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -164,8 +164,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
164 if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) { 164 if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
165 auto context_res = ss[b]->GetContextGraph()->ForwardOneStep( 165 auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
166 context_state, new_token); 166 context_state, new_token);
167 - context_score = context_res.first;  
168 - new_hyp.context_state = context_res.second; 167 + context_score = std::get<0>(context_res);
  168 + new_hyp.context_state = std::get<1>(context_res);
169 } 169 }
170 if (lm_) { 170 if (lm_) {
171 lm_->ComputeLMScore(lm_scale_, &new_hyp); 171 lm_->ComputeLMScore(lm_scale_, &new_hyp);
  1 +// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
  2 +//
  3 +// Copyright (c) 2023-2024 Xiaomi Corporation
  4 +
  5 +#include <stdio.h>
  6 +
  7 +#include <iomanip>
  8 +#include <iostream>
  9 +#include <string>
  10 +#include <vector>
  11 +
  12 +#include "sherpa-onnx/csrc/keyword-spotter.h"
  13 +#include "sherpa-onnx/csrc/online-stream.h"
  14 +#include "sherpa-onnx/csrc/parse-options.h"
  15 +#include "sherpa-onnx/csrc/symbol-table.h"
  16 +#include "sherpa-onnx/csrc/wave-reader.h"
  17 +
  18 +typedef struct {
  19 + std::unique_ptr<sherpa_onnx::OnlineStream> online_stream;
  20 + std::string filename;
  21 +} Stream;
  22 +
  23 +int main(int32_t argc, char *argv[]) {
  24 + const char *kUsageMessage = R"usage(
  25 +Usage:
  26 +
  27 +(1) Streaming transducer
  28 +
  29 + ./bin/sherpa-onnx-keyword-spotter \
  30 + --tokens=/path/to/tokens.txt \
  31 + --encoder=/path/to/encoder.onnx \
  32 + --decoder=/path/to/decoder.onnx \
  33 + --joiner=/path/to/joiner.onnx \
  34 + --provider=cpu \
  35 + --num-threads=2 \
  36 + --keywords-file=keywords.txt \
  37 + /path/to/foo.wav [bar.wav foobar.wav ...]
  38 +
  39 +Note: It supports decoding multiple files in batches
  40 +
  41 +Default value for num_threads is 2.
  42 +Valid values for provider: cpu (default), cuda, coreml.
  43 +foo.wav should be of single channel, 16-bit PCM encoded wave file; its
  44 +sampling rate can be arbitrary and does not need to be 16kHz.
  45 +
  46 +Please refer to
  47 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  48 +for a list of pre-trained models to download.
  49 +)usage";
  50 +
  51 + sherpa_onnx::ParseOptions po(kUsageMessage);
  52 + sherpa_onnx::KeywordSpotterConfig config;
  53 +
  54 + config.Register(&po);
  55 +
  56 + po.Read(argc, argv);
  57 + if (po.NumArgs() < 1) {
  58 + po.PrintUsage();
  59 + exit(EXIT_FAILURE);
  60 + }
  61 +
  62 + fprintf(stderr, "%s\n", config.ToString().c_str());
  63 +
  64 + if (!config.Validate()) {
  65 + fprintf(stderr, "Errors in config!\n");
  66 + return -1;
  67 + }
  68 +
  69 + sherpa_onnx::KeywordSpotter keyword_spotter(config);
  70 +
  71 + std::vector<Stream> ss;
  72 +
  73 + for (int32_t i = 1; i <= po.NumArgs(); ++i) {
  74 + const std::string wav_filename = po.GetArg(i);
  75 + int32_t sampling_rate = -1;
  76 +
  77 + bool is_ok = false;
  78 + const std::vector<float> samples =
  79 + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
  80 +
  81 + if (!is_ok) {
  82 + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
  83 + return -1;
  84 + }
  85 +
  86 + auto s = keyword_spotter.CreateStream();
  87 + s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
  88 +
  89 + std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
  90 + // Note: We can call AcceptWaveform() multiple times.
  91 + s->AcceptWaveform(sampling_rate, tail_paddings.data(),
  92 + tail_paddings.size());
  93 +
  94 + // Call InputFinished() to indicate that no more audio samples will be available
  95 + s->InputFinished();
  96 + ss.push_back({std::move(s), wav_filename});
  97 + }
  98 +
  99 + std::vector<sherpa_onnx::OnlineStream *> ready_streams;
  100 + for (;;) {
  101 + ready_streams.clear();
  102 + for (auto &s : ss) {
  103 + const auto p_ss = s.online_stream.get();
  104 + if (keyword_spotter.IsReady(p_ss)) {
  105 + ready_streams.push_back(p_ss);
  106 + }
  107 + std::ostringstream os;
  108 + const auto r = keyword_spotter.GetResult(p_ss);
  109 + if (!r.keyword.empty()) {
  110 + os << s.filename << "\n";
  111 + os << r.AsJsonString() << "\n\n";
  112 + fprintf(stderr, "%s", os.str().c_str());
  113 + }
  114 + }
  115 +
  116 + if (ready_streams.empty()) {
  117 + break;
  118 + }
  119 + keyword_spotter.DecodeStreams(ready_streams.data(), ready_streams.size());
  120 + }
  121 + return 0;
  122 +}
  1 +// sherpa-onnx/csrc/transducer-keyword-decoder.cc
  2 +//
  3 +// Copyright (c) 2023-2024 Xiaomi Corporation
  4 +
  5 +#include <algorithm>
  6 +#include <cmath>
  7 +#include <cstring>  // std::memcpy
  8 +#include <utility>
  9 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/log.h"
  11 +#include "sherpa-onnx/csrc/onnx-utils.h"
  12 +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +TransducerKeywordResult TransducerKeywordDecoder::GetEmptyResult() const {
  17 + int32_t context_size = model_->ContextSize();
  18 + int32_t blank_id = 0; // always 0
  19 + TransducerKeywordResult r;
  20 + std::vector<int64_t> blanks(context_size, -1);
  21 + blanks.back() = blank_id;
  22 +
  23 + Hypotheses blank_hyp({{blanks, 0}});
  24 + r.hyps = std::move(blank_hyp);
  25 + return r;
  26 +}
  27 +
  28 +void TransducerKeywordDecoder::Decode(
  29 + Ort::Value encoder_out, OnlineStream **ss,
  30 + std::vector<TransducerKeywordResult> *result) {
  31 + std::vector<int64_t> encoder_out_shape =
  32 + encoder_out.GetTensorTypeAndShapeInfo().GetShape();
  33 +
  34 + if (encoder_out_shape[0] != result->size()) {
  35 + SHERPA_ONNX_LOGE(
  36 + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
  37 + static_cast<int32_t>(encoder_out_shape[0]),
  38 + static_cast<int32_t>(result->size()));
  39 + exit(-1);
  40 + }
  41 +
  42 + int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
  43 +
  44 + int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
  45 + int32_t vocab_size = model_->VocabSize();
  46 + int32_t context_size = model_->ContextSize();
  47 + std::vector<int64_t> blanks(context_size, -1);
  48 + blanks.back() = 0; // blank_id is hardcoded to 0
  49 +
  50 + std::vector<Hypotheses> cur;
  51 + for (auto &r : *result) {
  52 + cur.push_back(std::move(r.hyps));
  53 + }
  54 + std::vector<Hypothesis> prev;
  55 +
  56 + for (int32_t t = 0; t != num_frames; ++t) {
  57 + // Due to merging paths with identical token sequences,
  58 + // not all utterances have "num_active_paths" paths.
  59 + auto hyps_row_splits = GetHypsRowSplits(cur);
  60 + int32_t num_hyps =
  61 + hyps_row_splits.back(); // total num hyps for all utterances
  62 + prev.clear();
  63 + for (auto &hyps : cur) {
  64 + for (auto &h : hyps) {
  65 + prev.push_back(std::move(h.second));
  66 + }
  67 + }
  68 + cur.clear();
  69 + cur.reserve(batch_size);
  70 +
  71 + Ort::Value decoder_input = model_->BuildDecoderInput(prev);
  72 + Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
  73 +
  74 + Ort::Value cur_encoder_out =
  75 + GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
  76 + cur_encoder_out =
  77 + Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
  78 + Ort::Value logit =
  79 + model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
  80 +
  81 + float *p_logit = logit.GetTensorMutableData<float>();
  82 + LogSoftmax(p_logit, vocab_size, num_hyps);
  83 +
  84 + // The acoustic logprobs for the current frame
  85 + std::vector<float> logprobs(vocab_size * num_hyps);
  86 + std::memcpy(logprobs.data(), p_logit,
  87 + sizeof(float) * vocab_size * num_hyps);
  88 +
  89 + // Now p_logit contains the log_softmax output; we rename it to p_logprob
  90 + // to match what it actually contains.
  91 + float *p_logprob = p_logit;
  92 +
  93 + // add log_prob of each hypothesis to p_logprob before taking top_k
  94 + for (int32_t i = 0; i != num_hyps; ++i) {
  95 + float log_prob = prev[i].log_prob;
  96 + for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
  97 + *p_logprob += log_prob;
  98 + }
  99 + }
  100 + p_logprob = p_logit; // we changed p_logprob in the above for loop
  101 +
  102 + for (int32_t b = 0; b != batch_size; ++b) {
  103 + int32_t frame_offset = (*result)[b].frame_offset;
  104 + int32_t start = hyps_row_splits[b];
  105 + int32_t end = hyps_row_splits[b + 1];
  106 + auto topk =
  107 + TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_);
  108 +
  109 + Hypotheses hyps;
  110 + for (auto k : topk) {
  111 + int32_t hyp_index = k / vocab_size + start;
  112 + int32_t new_token = k % vocab_size;
  113 +
  114 + Hypothesis new_hyp = prev[hyp_index];
  115 + float context_score = 0;
  116 + auto context_state = new_hyp.context_state;
  117 +
  118 + // blank is hardcoded to 0
  119 + // also, it treats unk as blank
  120 + if (new_token != 0 && new_token != unk_id_) {
  121 + new_hyp.ys.push_back(new_token);
  122 + new_hyp.timestamps.push_back(t + frame_offset);
  123 + new_hyp.ys_probs.push_back(
  124 + exp(logprobs[hyp_index * vocab_size + new_token]));
  125 +
  126 + new_hyp.num_trailing_blanks = 0;
  127 + auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
  128 + context_state, new_token);
  129 + context_score = std::get<0>(context_res);
  130 + new_hyp.context_state = std::get<1>(context_res);
  131 + // Matching restarted from the start state, so forget the decoder history.
  132 + if (new_hyp.context_state->token == -1) {
  133 + new_hyp.ys = blanks;
  134 + new_hyp.timestamps.clear();
  135 + new_hyp.ys_probs.clear();
  136 + }
  137 + } else {
  138 + ++new_hyp.num_trailing_blanks;
  139 + }
  140 + new_hyp.log_prob = p_logprob[k] + context_score;
  141 + hyps.Add(std::move(new_hyp));
  142 + } // for (auto k : topk)
  143 +
  144 + auto best_hyp = hyps.GetMostProbable(false);
  145 +
  146 + auto status = ss[b]->GetContextGraph()->IsMatched(best_hyp.context_state);
  147 + bool matched = std::get<0>(status);
  148 + const ContextState *matched_state = std::get<1>(status);
  149 +
  150 + if (matched) {
  151 + float ys_prob = 0.0;
  152 + int32_t length = best_hyp.ys_probs.size();
  153 + for (int32_t i = length - matched_state->level; i < length; ++i) {
  154 + ys_prob += best_hyp.ys_probs[i];
  155 + }
  156 + ys_prob /= matched_state->level;
  157 + if (best_hyp.num_trailing_blanks > num_trailing_blanks_ &&
  158 + ys_prob >= matched_state->ac_threshold) {
  159 + auto &r = (*result)[b];
  160 + r.tokens = {best_hyp.ys.end() - matched_state->level,
  161 + best_hyp.ys.end()};
  162 + r.timestamps = {best_hyp.timestamps.end() - matched_state->level,
  163 + best_hyp.timestamps.end()};
  164 + r.keyword = matched_state->phrase;
  165 +
  166 + hyps = Hypotheses({{blanks, 0, ss[b]->GetContextGraph()->Root()}});
  167 + }
  168 + }
  169 + cur.push_back(std::move(hyps));
  170 + p_logprob += (end - start) * vocab_size;
  171 + } // for (int32_t b = 0; b != batch_size; ++b)
  172 + }
  173 +
  174 + for (int32_t b = 0; b != batch_size; ++b) {
  175 + auto &hyps = cur[b];
  176 + auto best_hyp = hyps.GetMostProbable(false);
  177 + auto &r = (*result)[b];
  178 + r.hyps = std::move(hyps);
  179 + r.num_trailing_blanks = best_hyp.num_trailing_blanks;
  180 + r.frame_offset += num_frames;
  181 + }
  182 +}
  183 +
  184 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/transducer-keyword-decoder.h
  2 +//
  3 +// Copyright (c) 2023-2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
  7 +
  8 +#include <string>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "sherpa-onnx/csrc/online-stream.h"
  13 +#include "sherpa-onnx/csrc/online-transducer-model.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +struct TransducerKeywordResult {
  18 + /// Number of frames after subsampling we have decoded so far
  19 + int32_t frame_offset = 0;
  20 +
  21 + /// The decoded token IDs for keywords
  22 + std::vector<int64_t> tokens;
  23 +
  24 + /// The triggered keyword
  25 + std::string keyword;
  26 +
  27 + /// Number of trailing blank frames decoded so far
  28 + int32_t num_trailing_blanks = 0;
  29 +
  30 + /// timestamps[i] contains the output frame index where tokens[i] is decoded.
  31 + std::vector<int32_t> timestamps;
  32 +
  33 + // used only in modified beam_search
  34 + Hypotheses hyps;
  35 +};
  36 +
  37 +class TransducerKeywordDecoder {
  38 + public:
  39 + TransducerKeywordDecoder(OnlineTransducerModel *model,
  40 + int32_t max_active_paths,
  41 + int32_t num_trailing_blanks, int32_t unk_id)
  42 + : model_(model),
  43 + max_active_paths_(max_active_paths),
  44 + num_trailing_blanks_(num_trailing_blanks),
  45 + unk_id_(unk_id) {}
  46 +
  47 + TransducerKeywordResult GetEmptyResult() const;
  48 +
  49 + void Decode(Ort::Value encoder_out, OnlineStream **ss,
  50 + std::vector<TransducerKeywordResult> *result);
  51 +
  52 + private:
  53 + OnlineTransducerModel *model_; // Not owned
  54 +
  55 + int32_t max_active_paths_;
  56 + int32_t num_trailing_blanks_;
  57 + int32_t unk_id_;
  58 +};
  59 +
  60 +} // namespace sherpa_onnx
  61 +
  62 +#endif // SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
@@ -15,16 +15,31 @@
15 15
16 16 namespace sherpa_onnx {
17 17
18 -bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,  
19 - std::vector<std::vector<int32_t>> *hotwords) {  
20 - hotwords->clear();  
21 - std::vector<int32_t> tmp;
  18 +static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
  19 + std::vector<std::vector<int32_t>> *ids,
  20 + std::vector<std::string> *phrases,
  21 + std::vector<float> *scores,
  22 + std::vector<float> *thresholds) {
  23 + SHERPA_ONNX_CHECK(ids != nullptr);
  24 + ids->clear();
  25 +
  26 + std::vector<int32_t> tmp_ids;
  27 + std::vector<float> tmp_scores;
  28 + std::vector<float> tmp_thresholds;
  29 + std::vector<std::string> tmp_phrases;
  30 +
22 31 std::string line;
23 32 std::string word;
  33 + bool has_scores = false;
  34 + bool has_thresholds = false;
  35 + bool has_phrases = false;
24 36
25 37 while (std::getline(is, line)) {
  38 + float score = 0;
  39 + float threshold = 0;
  40 + std::string phrase = "";
  41 +
26 42 std::istringstream iss(line);
27 - std::vector<std::string> syms;
28 43 while (iss >> word) {
29 44 if (word.size() >= 3) {
30 45 // For BPE-based models, we replace ▁ with a space
@@ -35,20 +50,72 @@ bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
35 50 }
36 51 }
37 52 if (symbol_table.contains(word)) {
38 - int32_t number = symbol_table[word];
39 - tmp.push_back(number);
  53 + int32_t id = symbol_table[word];
  54 + tmp_ids.push_back(id);
40 55 } else {
41 - SHERPA_ONNX_LOGE(  
42 - "Cannot find ID for hotword %s at line: %s. (Hint: words on "  
43 - "the "  
44 - "same line are separated by spaces)",  
45 - word.c_str(), line.c_str());  
46 - return false;
  56 + switch (word[0]) {
  57 + case ':': // boosting score for current keyword
  58 + score = std::stof(word.substr(1));
  59 + has_scores = true;
  60 + break;
  61 + case '#': // triggering threshold (probability) for current keyword
  62 + threshold = std::stof(word.substr(1));
  63 + has_thresholds = true;
  64 + break;
  65 + case '@': // the original keyword string
  66 + phrase = word.substr(1);
  67 + has_phrases = true;
  68 + break;
  69 + default:
  70 + SHERPA_ONNX_LOGE(
  71 + "Cannot find ID for token %s at line: %s. (Hint: words on "
  72 + "the same line are separated by spaces)",
  73 + word.c_str(), line.c_str());
  74 + return false;
  75 + }
47 76 }
48 77 }
49 - hotwords->push_back(std::move(tmp));
  78 + ids->push_back(std::move(tmp_ids));
  79 + tmp_scores.push_back(score);
  80 + tmp_phrases.push_back(phrase);
  81 + tmp_thresholds.push_back(threshold);
  82 + }
  83 + if (scores != nullptr) {
  84 + if (has_scores) {
  85 + scores->swap(tmp_scores);
  86 + } else {
  87 + scores->clear();
  88 + }
  89 + }
  90 + if (phrases != nullptr) {
  91 + if (has_phrases) {
  92 + *phrases = std::move(tmp_phrases);
  93 + } else {
  94 + phrases->clear();
  95 + }
  96 + }
  97 + if (thresholds != nullptr) {
  98 + if (has_thresholds) {
  99 + thresholds->swap(tmp_thresholds);
  100 + } else {
  101 + thresholds->clear();
  102 + }
50 103 }
51 104 return true;
52 105 }
53 106
  107 +bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
  108 + std::vector<std::vector<int32_t>> *hotwords) {
  109 + return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr);
  110 +}
  111 +
  112 +bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
  113 + std::vector<std::vector<int32_t>> *keywords_id,
  114 + std::vector<std::string> *keywords,
  115 + std::vector<float> *boost_scores,
  116 + std::vector<float> *threshold) {
  117 + return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores,
  118 + threshold);
  119 +}
  120 +
54 121 } // namespace sherpa_onnx
@@ -26,7 +26,32 @@ namespace sherpa_onnx {
26 26 * otherwise returns false.
27 27 */
28 28 bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
29 - std::vector<std::vector<int32_t>> *hotwords);
  29 + std::vector<std::vector<int32_t>> *hotwords_id);
  30 +
  31 +/* Encode the keywords in an input stream into token ids.
  32 + *
  33 + * @param is The input stream. It contains several lines, one keyword per
  34 + * line. For each keyword, the tokens (cjkchar or bpe) are separated
  35 + * by spaces. A line may also contain a boosting score (starting with :),
  36 + * a triggering threshold (starting with #) and the original keyword
  37 + * string (starting with @).
  38 + * @param symbol_table The tokens table mapping symbols to ids. All the symbols
  39 + * in the stream should be in the symbol_table; if not, this
  40 + * function returns false.
  41 + *
  42 + * @param keywords_id The encoded ids to be written to.
  43 + * @param keywords The original keyword string to be written to.
  44 + * @param boost_scores The boosting score for each keyword to be written to.
  45 + * @param threshold The triggering threshold for each keyword to be written to.
  46 + *
  47 + * @return If all the symbols from ``is`` are in the symbol_table, returns true;
  48 + * otherwise returns false.
  49 + */
  50 +bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
  51 + std::vector<std::vector<int32_t>> *keywords_id,
  52 + std::vector<std::string> *keywords,
  53 + std::vector<float> *boost_scores,
  54 + std::vector<float> *threshold);
30 55
31 56 } // namespace sherpa_onnx
32 57
@@ -21,6 +21,7 @@
21 #include "android/asset_manager_jni.h" 21 #include "android/asset_manager_jni.h"
22 #endif 22 #endif
23 23
  24 +#include "sherpa-onnx/csrc/keyword-spotter.h"
24 #include "sherpa-onnx/csrc/macros.h" 25 #include "sherpa-onnx/csrc/macros.h"
25 #include "sherpa-onnx/csrc/offline-recognizer.h" 26 #include "sherpa-onnx/csrc/offline-recognizer.h"
26 #include "sherpa-onnx/csrc/offline-tts.h" 27 #include "sherpa-onnx/csrc/offline-tts.h"
@@ -140,6 +141,73 @@ class SherpaOnnxVad {
140 141 VoiceActivityDetector vad_;
141 142 };
142 143
  144 +class SherpaOnnxKws {
  145 + public:
  146 +#if __ANDROID_API__ >= 9
  147 + SherpaOnnxKws(AAssetManager *mgr, const KeywordSpotterConfig &config)
  148 + : keyword_spotter_(mgr, config),
  149 + stream_(keyword_spotter_.CreateStream()) {}
  150 +#endif
  151 +
  152 + explicit SherpaOnnxKws(const KeywordSpotterConfig &config)
  153 + : keyword_spotter_(config), stream_(keyword_spotter_.CreateStream()) {}
  154 +
  155 + void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
  156 + if (input_sample_rate_ == -1) {
  157 + input_sample_rate_ = sample_rate;
  158 + }
  159 +
  160 + stream_->AcceptWaveform(sample_rate, samples, n);
  161 + }
  162 +
  163 + void InputFinished() const {
  164 + std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);
  165 + stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
  166 + tail_padding.size());
  167 + stream_->InputFinished();
  168 + }
  169 +
  170 + // If keywords is an empty string, it just recreates the decoding stream
  171 + // and always returns true in this case.
  172 + // If keywords is not empty, it creates a new decoding stream with
  173 + // the given keywords appended to the default keywords.
  174 + // Returns false if an error occurs while adding the keywords, true otherwise.
  175 + bool Reset(const std::string &keywords = {}) {
  176 + if (keywords.empty()) {
  177 + stream_ = keyword_spotter_.CreateStream();
  178 + return true;
  179 + } else {
  180 + auto stream = keyword_spotter_.CreateStream(keywords);
  181 + // Setting the new keywords failed; stream_ is not updated.
  182 + if (stream == nullptr) {
  183 + return false;
  184 + } else {
  185 + stream_ = std::move(stream);
  186 + return true;
  187 + }
  188 + }
  189 + }
  190 +
  191 + std::string GetKeyword() const {
  192 + auto result = keyword_spotter_.GetResult(stream_.get());
  193 + return result.keyword;
  194 + }
  195 +
  196 + std::vector<std::string> GetTokens() const {
  197 + auto result = keyword_spotter_.GetResult(stream_.get());
  198 + return result.tokens;
  199 + }
  200 +
  201 + bool IsReady() const { return keyword_spotter_.IsReady(stream_.get()); }
  202 +
  203 + void Decode() const { keyword_spotter_.DecodeStream(stream_.get()); }
  204 +
  205 + private:
  206 + KeywordSpotter keyword_spotter_;
  207 + std::unique_ptr<OnlineStream> stream_;
  208 + int32_t input_sample_rate_ = -1;
  209 +};
  210 +
143 211 static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
144 212 OnlineRecognizerConfig ans;
145 213
@@ -457,6 +525,103 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
457 525 return ans;
458 526 }
459 527
  528 +static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
  529 + KeywordSpotterConfig ans;
  530 +
  531 + jclass cls = env->GetObjectClass(config);
  532 + jfieldID fid;
  533 +
  534 + // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
  535 + // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
  536 +
  537 + //---------- decoding ----------
  538 + fid = env->GetFieldID(cls, "maxActivePaths", "I");
  539 + ans.max_active_paths = env->GetIntField(config, fid);
  540 +
  541 + fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;");
  542 + jstring s = (jstring)env->GetObjectField(config, fid);
  543 + const char *p = env->GetStringUTFChars(s, nullptr);
  544 + ans.keywords_file = p;
  545 + env->ReleaseStringUTFChars(s, p);
  546 +
  547 + fid = env->GetFieldID(cls, "keywordsScore", "F");
  548 + ans.keywords_score = env->GetFloatField(config, fid);
  549 +
  550 + fid = env->GetFieldID(cls, "keywordsThreshold", "F");
  551 + ans.keywords_threshold = env->GetFloatField(config, fid);
  552 +
  553 + fid = env->GetFieldID(cls, "numTrailingBlanks", "I");
  554 + ans.num_trailing_blanks = env->GetIntField(config, fid);
  555 +
  556 + //---------- feat config ----------
  557 + fid = env->GetFieldID(cls, "featConfig",
  558 + "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
  559 + jobject feat_config = env->GetObjectField(config, fid);
  560 + jclass feat_config_cls = env->GetObjectClass(feat_config);
  561 +
  562 + fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
  563 + ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
  564 +
  565 + fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
  566 + ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
  567 +
  568 + //---------- model config ----------
  569 + fid = env->GetFieldID(cls, "modelConfig",
  570 + "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
  571 + jobject model_config = env->GetObjectField(config, fid);
  572 + jclass model_config_cls = env->GetObjectClass(model_config);
  573 +
  574 + // transducer
  575 + fid = env->GetFieldID(model_config_cls, "transducer",
  576 + "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
  577 + jobject transducer_config = env->GetObjectField(model_config, fid);
  578 + jclass transducer_config_cls = env->GetObjectClass(transducer_config);
  579 +
  580 + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
  581 + s = (jstring)env->GetObjectField(transducer_config, fid);
  582 + p = env->GetStringUTFChars(s, nullptr);
  583 + ans.model_config.transducer.encoder = p;
  584 + env->ReleaseStringUTFChars(s, p);
  585 +
  586 + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
  587 + s = (jstring)env->GetObjectField(transducer_config, fid);
  588 + p = env->GetStringUTFChars(s, nullptr);
  589 + ans.model_config.transducer.decoder = p;
  590 + env->ReleaseStringUTFChars(s, p);
  591 +
  592 + fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
  593 + s = (jstring)env->GetObjectField(transducer_config, fid);
  594 + p = env->GetStringUTFChars(s, nullptr);
  595 + ans.model_config.transducer.joiner = p;
  596 + env->ReleaseStringUTFChars(s, p);
  597 +
  598 + fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
  599 + s = (jstring)env->GetObjectField(model_config, fid);
  600 + p = env->GetStringUTFChars(s, nullptr);
  601 + ans.model_config.tokens = p;
  602 + env->ReleaseStringUTFChars(s, p);
  603 +
  604 + fid = env->GetFieldID(model_config_cls, "numThreads", "I");
  605 + ans.model_config.num_threads = env->GetIntField(model_config, fid);
  606 +
  607 + fid = env->GetFieldID(model_config_cls, "debug", "Z");
  608 + ans.model_config.debug = env->GetBooleanField(model_config, fid);
  609 +
  610 + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
  611 + s = (jstring)env->GetObjectField(model_config, fid);
  612 + p = env->GetStringUTFChars(s, nullptr);
  613 + ans.model_config.provider = p;
  614 + env->ReleaseStringUTFChars(s, p);
  615 +
  616 + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
  617 + s = (jstring)env->GetObjectField(model_config, fid);
  618 + p = env->GetStringUTFChars(s, nullptr);
  619 + ans.model_config.model_type = p;
  620 + env->ReleaseStringUTFChars(s, p);
  621 +
  622 + return ans;
  623 +}
  624 +
460 625 static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
461 626 VadModelConfig ans;
462 627
@@ -1013,7 +1178,124 @@ JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(
1013 1178 jclass stringClass = env->FindClass("java/lang/String");
1014 1179
1015 1180 // convert C++ list into jni string array
1016 - jobjectArray result = env->NewObjectArray(size, stringClass, NULL);
  1181 + jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
  1182 + for (int32_t i = 0; i < size; i++) {
  1183 + // Convert the C++ string to a C string
  1184 + const char *cstr = tokens[i].c_str();
  1185 +
  1186 + // Convert the C string to a jstring
  1187 + jstring jstr = env->NewStringUTF(cstr);
  1188 +
  1189 + // Set the array element
  1190 + env->SetObjectArrayElement(result, i, jstr);
  1191 + }
  1192 +
  1193 + return result;
  1194 +}
  1195 +
  1196 +SHERPA_ONNX_EXTERN_C
  1197 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_new(
  1198 + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  1199 +#if __ANDROID_API__ >= 9
  1200 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  1201 + if (!mgr) {
  1202 + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  1203 + }
  1204 +#endif
  1205 + auto config = sherpa_onnx::GetKwsConfig(env, _config);
  1206 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  1207 + auto model = new sherpa_onnx::SherpaOnnxKws(
  1208 +#if __ANDROID_API__ >= 9
  1209 + mgr,
  1210 +#endif
  1211 + config);
  1212 +
  1213 + return (jlong)model;
  1214 +}
  1215 +
  1216 +SHERPA_ONNX_EXTERN_C
  1217 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_newFromFile(
  1218 + JNIEnv *env, jobject /*obj*/, jobject _config) {
  1219 + auto config = sherpa_onnx::GetKwsConfig(env, _config);
  1220 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  1221 + auto model = new sherpa_onnx::SherpaOnnxKws(config);
  1222 +
  1223 + return (jlong)model;
  1224 +}
  1225 +
  1226 +SHERPA_ONNX_EXTERN_C
  1227 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_delete(
  1228 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  1229 + delete reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
  1230 +}
  1231 +
  1232 +SHERPA_ONNX_EXTERN_C
  1233 +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_isReady(
  1234 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  1235 + auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
  1236 + return model->IsReady();
  1237 +}
  1238 +
  1239 +SHERPA_ONNX_EXTERN_C
  1240 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_decode(
  1241 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  1242 + auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
  1243 + model->Decode();
  1244 +}
  1245 +
  1246 +SHERPA_ONNX_EXTERN_C
  1247 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_acceptWaveform(
  1248 + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
  1249 + jint sample_rate) {
  1250 + auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
  1251 +
  1252 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  1253 + jsize n = env->GetArrayLength(samples);
  1254 +
  1255 + model->AcceptWaveform(sample_rate, p, n);
  1256 +
  1257 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  1258 +}
  1259 +
  1260 +SHERPA_ONNX_EXTERN_C
  1261 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_inputFinished(
  1262 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  1263 + reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->InputFinished();
  1264 +}
  1265 +
  1266 +SHERPA_ONNX_EXTERN_C
  1267 +JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getKeyword(
  1268 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  1269 + // see
  1270 + // https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
  1271 + auto text = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetKeyword();
  1272 + return env->NewStringUTF(text.c_str());
  1273 +}
  1274 +
  1275 +SHERPA_ONNX_EXTERN_C
  1276 +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_reset(
  1277 + JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
  1278 + const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);
  1279 +
  1280 + std::string keywords_str = p_keywords;
  1281 +
  1282 + bool status =
  1283 + reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->Reset(keywords_str);
  1284 + env->ReleaseStringUTFChars(keywords, p_keywords);
  1285 + return status;
  1286 +}
  1287 +
  1288 +SHERPA_ONNX_EXTERN_C
  1289 +JNIEXPORT jobjectArray JNICALL
  1290 +Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getTokens(JNIEnv *env, jobject /*obj*/,
  1291 + jlong ptr) {
  1292 + auto tokens =
  1293 + reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetTokens();
  1294 + int32_t size = tokens.size();
  1295 + jclass stringClass = env->FindClass("java/lang/String");
  1296 +
  1297 + // convert C++ list into jni string array
  1298 + jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
1017 1299 for (int32_t i = 0; i < size; i++) {
1018 1300 // Convert the C++ string to a C string
1019 1301 const char *cstr = tokens[i].c_str();
@@ -28,9 +28,14 @@ def cli():
28 28 )
29 29 @click.option(
30 30 "--tokens-type",
31 - type=str,
  31 + type=click.Choice(
  32 + ["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"], case_sensitive=True
  33 + ),
32 34 required=True,
33 - help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
  35 + help="""The type of modeling units; should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin.
  36 + fpinyin means full pinyin: each cjkchar has one pinyin (with tone).
  37 + ppinyin means partial pinyin: it splits each pinyin into an initial and a final.
  38 + """,
34 39 )
35 40 @click.option(
36 41 "--bpe-model",
@@ -42,14 +47,56 @@ def encode_text(
42 47 ):
43 48 """
44 49 Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
  50 + Each line in the texts contains the original phrase; it might also contain
  51 + extra items, for example, the boosting score (starting with :), the triggering
  52 + threshold (starting with #, only used in the keyword spotting task) and the
  53 + original phrase (starting with @). Note: extra items are copied unchanged to the output.
  54 +
  55 + example input 1 (tokens_type = ppinyin):
  56 +
  57 + 小爱同学 :2.0 #0.6 @小爱同学
  58 + 你好问问 :3.5 @你好问问
  59 + 小艺小艺 #0.6 @小艺小艺
  60 +
  61 + example output 1:
  62 +
  63 + x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
  64 + n ǐ h ǎo w èn w èn :3.5 @你好问问
  65 + x iǎo y ì x iǎo y ì #0.6 @小艺小艺
  66 +
  67 + example input 2 (tokens_type = bpe):
  68 +
  69 + HELLO WORLD :1.5 #0.4
  70 + HI GOOGLE :2.0 #0.8
  71 + HEY SIRI #0.35
  72 +
  73 + example output 2:
  74 +
  75 + ▁HE LL O ▁WORLD :1.5 #0.4
  76 + ▁HI ▁GO O G LE :2.0 #0.8
  77 + ▁HE Y ▁S I RI #0.35
45 """ 78 """
46 texts = [] 79 texts = []
  80 + # extra information like boosting score (start with :), triggering threshold (start with #)
  81 + # original keyword (start with @)
  82 + extra_info = []
47 with open(input, "r", encoding="utf8") as f: 83 with open(input, "r", encoding="utf8") as f:
48 for line in f: 84 for line in f:
49 - texts.append(line.strip())
  85 + extra = []
  86 + text = []
  87 + toks = line.strip().split()
  88 + for tok in toks:
  89 + if tok[0] == ":" or tok[0] == "#" or tok[0] == "@":
  90 + extra.append(tok)
  91 + else:
  92 + text.append(tok)
  93 + texts.append(" ".join(text))
  94 + extra_info.append(extra)
  95 +
50 96 encoded_texts = text2token(
51 97 texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
52 98 )
53 99 with open(output, "w", encoding="utf8") as f:
54 - for txt in encoded_texts:
  100 + for i, txt in enumerate(encoded_texts):
  101 + txt += extra_info[i]
55 102 f.write(" ".join(txt) + "\n")
@@ -6,6 +6,9 @@ from typing import List, Optional, Union
6 6
7 7 import sentencepiece as spm
8 8
  9 +from pypinyin import pinyin
  10 +from pypinyin.contrib.tone_convert import to_initials, to_finals_tone
  11 +
9 12
10 13 def text2token(
11 14 texts: List[str],
@@ -23,7 +26,9 @@ def text2token(
23 26 tokens:
24 27 The path of the tokens.txt.
25 28 tokens_type:
26 - The valid values are cjkchar, bpe, cjkchar+bpe.
  29 + The valid values are cjkchar, bpe, cjkchar+bpe, fpinyin, ppinyin.
  30 + fpinyin means full pinyin: each cjkchar has one pinyin (with tone).
  31 + ppinyin means partial pinyin: it splits each pinyin into an initial and a final.
27 32 bpe_model:
28 33 The path of the bpe model. Only required when tokens_type is bpe or
29 34 cjkchar+bpe.
@@ -53,6 +58,24 @@
53 58 texts_list = [list("".join(text.split())) for text in texts]
54 59 elif tokens_type == "bpe":
55 60 texts_list = sp.encode(texts, out_type=str)
  61 + elif "pinyin" in tokens_type:
  62 + for txt in texts:
  63 + py = [x[0] for x in pinyin(txt)]
  64 + if "ppinyin" == tokens_type:
  65 + res = []
  66 + for x in py:
  67 + initial = to_initials(x, strict=False)
  68 + final = to_finals_tone(x, strict=False)
  69 + if initial == "" and final == "":
  70 + res.append(x)
  71 + else:
  72 + if initial != "":
  73 + res.append(initial)
  74 + if final != "":
  75 + res.append(final)
  76 + texts_list.append(res)
  77 + else:
  78 + texts_list.append(py)
56 79 else:
57 80 assert (
58 81 tokens_type == "cjkchar+bpe"