Fangjun Kuang
Committed by GitHub

Add Android demo for speaker recognition (#536)

See pre-built Android APKs at 
https://k2-fsa.github.io/sherpa/onnx/speaker-identification/apk.html
正在显示 73 个修改的文件 包含 3022 行增加6 行删除
name: apk-speaker-identification
on:
push:
branches:
- apk
tags:
- '*'
workflow_dispatch:
concurrency:
group: apk-speaker-identification-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: write
jobs:
apk_tts:
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
runs-on: ${{ matrix.os }}
name: apk for tts ${{ matrix.index }}/${{ matrix.total }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
total: ["10"]
index: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
# https://github.com/actions/setup-java
- uses: actions/setup-java@v4
with:
distribution: 'temurin' # See 'Supported distributions' for available options
java-version: '21'
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ${{ matrix.os }}-android
- name: Display NDK HOME
shell: bash
run: |
echo "ANDROID_NDK_LATEST_HOME: ${ANDROID_NDK_LATEST_HOME}"
ls -lh ${ANDROID_NDK_LATEST_HOME}
- name: Install Python dependencies
shell: bash
run: |
python3 -m pip install --upgrade pip jinja2
- name: Generate build script
shell: bash
run: |
cd scripts/apk
total=${{ matrix.total }}
index=${{ matrix.index }}
./generate-speaker-identification-apk-script.py --total $total --index $index
chmod +x build-apk-speaker-identification.sh
mv -v ./build-apk-speaker-identification.sh ../..
- name: build APK
shell: bash
run: |
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
cmake --version
export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
./build-apk-speaker-identification.sh
- name: Display APK
shell: bash
run: |
ls -lh ./apks/
du -h -d1 .
# - name: Release
# uses: svenstaro/upload-release-action@v2
# with:
# file_glob: true
# file: ./apks/*.apk
# overwrite: true
# repo_name: k2-fsa/sherpa-onnx
# repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
# tag: speaker-recongition-models
- name: Publish to huggingface
if: true
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v2
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
rm -rf huggingface
export GIT_LFS_SKIP_SMUDGE=1
git clone https://huggingface.co/csukuangfj/sherpa-onnx-apk huggingface
cd huggingface
git fetch
git pull
git merge -m "merge remote" --ff origin main
mkdir -p speaker-identification
cp -v ../apks/*.apk ./speaker-identification/
git status
git lfs track "*.apk"
git add .
git commit -m "add more apks"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-apk main
... ...
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.9.7")
set(SHERPA_ONNX_VERSION "1.9.8")
# Disable warning about
#
... ...
*.iml
.gradle
/local.properties
/.idea/caches
/.idea/libraries
/.idea/modules.xml
/.idea/workspace.xml
/.idea/navEditor.xml
/.idea/assetWizardSettings.xml
.DS_Store
/build
/captures
.externalNativeBuild
.cxx
local.properties
... ...
/build
\ No newline at end of file
... ...
plugins {
id("com.android.application")
id("org.jetbrains.kotlin.android")
}
android {
namespace = "com.k2fsa.sherpa.onnx.speaker.identification"
compileSdk = 34
defaultConfig {
applicationId = "com.k2fsa.sherpa.onnx.speaker.identification"
minSdk = 21
targetSdk = 34
versionCode = 1
versionName = "1.0"
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
vectorDrawables {
useSupportLibrary = true
}
}
buildTypes {
release {
isMinifyEnabled = false
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_1_8
targetCompatibility = JavaVersion.VERSION_1_8
}
kotlinOptions {
jvmTarget = "1.8"
}
buildFeatures {
compose = true
}
composeOptions {
kotlinCompilerExtensionVersion = "1.5.1"
}
packaging {
resources {
excludes += "/META-INF/{AL2.0,LGPL2.1}"
}
}
}
dependencies {
implementation("androidx.core:core-ktx:1.12.0")
implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.7.0")
implementation("androidx.activity:activity-compose:1.8.2")
implementation(platform("androidx.compose:compose-bom:2023.08.00"))
implementation("androidx.compose.ui:ui")
implementation("androidx.compose.ui:ui-graphics")
implementation("androidx.compose.ui:ui-tooling-preview")
implementation("androidx.compose.material3:material3")
implementation("androidx.navigation:navigation-compose:2.7.6")
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
androidTestImplementation(platform("androidx.compose:compose-bom:2023.08.00"))
androidTestImplementation("androidx.compose.ui:ui-test-junit4")
debugImplementation("androidx.compose.ui:ui-tooling")
debugImplementation("androidx.compose.ui:ui-test-manifest")
}
\ No newline at end of file
... ...
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx.speaker.identification
import androidx.test.platform.app.InstrumentationRegistry
import androidx.test.ext.junit.runners.AndroidJUnit4
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.Assert.*
/**
* Instrumented test, which will execute on an Android device.
*
* See [testing documentation](http://d.android.com/tools/testing).
*/
@RunWith(AndroidJUnit4::class)
class ExampleInstrumentedTest {
@Test
fun useAppContext() {
// Context of the app under test.
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
assertEquals("com.k2fsa.sherpa.onnx.speaker.identification", appContext.packageName)
}
}
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools">
<uses-permission android:name="android.permission.RECORD_AUDIO" />
<application
android:allowBackup="true"
android:dataExtractionRules="@xml/data_extraction_rules"
android:fullBackupContent="@xml/backup_rules"
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/Theme.SherpaOnnxSpeakerIdentification"
tools:targetApi="31">
<activity
android:name=".MainActivity"
android:exported="true"
android:label="@string/app_name"
android:theme="@style/Theme.SherpaOnnxSpeakerIdentification">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx.speaker.identification
import androidx.compose.ui.graphics.vector.ImageVector
data class BarItem (
val title: String,
// see https://www.composables.com/icons
// and
// https://developer.android.com/reference/kotlin/androidx/compose/material/icons/filled/package-summary
val image: ImageVector,
val route: String,
)
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx.speaker.identification
import android.Manifest
import android.content.pm.PackageManager
import android.os.Bundle
import android.util.Log
import android.widget.Toast
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.NavigationBar
import androidx.compose.material3.NavigationBarItem
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.tooling.preview.Preview
import androidx.core.app.ActivityCompat
import androidx.navigation.NavGraph.Companion.findStartDestination
import androidx.navigation.NavHostController
import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
import androidx.navigation.compose.currentBackStackEntryAsState
import androidx.navigation.compose.rememberNavController
import com.k2fsa.sherpa.onnx.SpeakerRecognition
import com.k2fsa.sherpa.onnx.speaker.identification.screens.HelpScreen
import com.k2fsa.sherpa.onnx.speaker.identification.screens.HomeScreen
import com.k2fsa.sherpa.onnx.speaker.identification.screens.RegisterScreen
import com.k2fsa.sherpa.onnx.speaker.identification.screens.ViewScreen
import com.k2fsa.sherpa.onnx.speaker.identification.ui.theme.SherpaOnnxSpeakerIdentificationTheme
const val TAG = "sherpa-onnx-speaker"
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
class MainActivity : ComponentActivity() {
private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContent {
SherpaOnnxSpeakerIdentificationTheme {
// A surface container using the 'background' color from the theme
Surface(
modifier = Modifier.fillMaxSize(),
color = MaterialTheme.colorScheme.background
) {
MainScreen()
}
}
}
ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
SpeakerRecognition.initExtractor(this.assets)
}
@Deprecated("Deprecated in Java")
override fun onRequestPermissionsResult(
requestCode: Int,
permissions: Array<out String>,
grantResults: IntArray
) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults)
val permissionToRecordAccepted = if (requestCode == REQUEST_RECORD_AUDIO_PERMISSION) {
grantResults[0] == PackageManager.PERMISSION_GRANTED
} else {
false
}
if (!permissionToRecordAccepted) {
Log.e(TAG, "Audio record is disallowed")
Toast.makeText(
this,
"This App needs access to the microphone",
Toast.LENGTH_SHORT
)
.show()
finish()
}
Log.i(TAG, "Audio record is permitted")
}
}
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun MainScreen(modifier: Modifier = Modifier) {
val navController = rememberNavController()
Scaffold(
topBar = {
CenterAlignedTopAppBar(
colors = TopAppBarDefaults.topAppBarColors(
containerColor = MaterialTheme.colorScheme.primaryContainer,
titleContentColor = MaterialTheme.colorScheme.primary,
),
title = {
Text(
"Next-gen Kaldi: Speaker Identification",
fontWeight = FontWeight.Bold,
)
},
)
},
content = { padding ->
Column(Modifier.padding(padding)) {
NavigationHost(navController = navController)
}
},
bottomBar = {
BottomNavigationBar(navController = navController)
}
)
}
@Composable
fun NavigationHost(navController: NavHostController) {
NavHost(navController = navController, startDestination = NavRoutes.Home.route) {
composable(NavRoutes.Home.route) {
HomeScreen()
}
composable(NavRoutes.Register.route) {
RegisterScreen()
}
composable(NavRoutes.View.route) {
ViewScreen()
}
composable(NavRoutes.Help.route) {
HelpScreen()
}
}
}
@Composable
fun BottomNavigationBar(navController: NavHostController) {
NavigationBar {
val backStackEntry by navController.currentBackStackEntryAsState()
val currentRoute = backStackEntry?.destination?.route
NavBarItems.BarItems.forEach { navItem ->
NavigationBarItem(selected = currentRoute == navItem.route,
onClick = {
navController.navigate(navItem.route) {
popUpTo(navController.graph.findStartDestination().id) {
saveState = true
}
launchSingleTop = true
restoreState = true
}
},
icon = {
Icon(imageVector = navItem.image, contentDescription = navItem.title)
}, label = {
Text(text = navItem.title)
})
}
}
}
@Preview(showBackground = true)
@Composable
fun MainScreenPreview() {
SherpaOnnxSpeakerIdentificationTheme {
MainScreen()
}
}
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx.speaker.identification
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.AccountCircle
import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.filled.Home
import androidx.compose.material.icons.filled.Info
object NavBarItems {
val BarItems = listOf(
BarItem(
title = "Home",
image = Icons.Filled.Home,
route = "home",
),
BarItem(
title = "Register",
image = Icons.Filled.Add,
route = "register",
),
BarItem(
title = "View",
image = Icons.Filled.AccountCircle,
route = "view",
),
BarItem(
title = "Help",
image = Icons.Filled.Info,
route = "help",
),
)
}
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx.speaker.identification
sealed class NavRoutes(val route: String) {
object Home: NavRoutes("home")
object Register: NavRoutes("register")
object View: NavRoutes("view")
object Help: NavRoutes("help")
}
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
import android.util.Log
import com.k2fsa.sherpa.onnx.speaker.identification.TAG
data class SpeakerEmbeddingExtractorConfig(
val model: String,
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
)
class SpeakerEmbeddingExtractorStream(var ptr: Long) {
fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
acceptWaveform(ptr, samples, sampleRate)
fun inputFinished() = inputFinished(ptr)
protected fun finalize() {
delete(ptr)
ptr = 0
}
private external fun myTest(ptr: Long, v: Array<FloatArray>)
fun release() = finalize()
private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
private external fun inputFinished(ptr: Long)
private external fun delete(ptr: Long)
companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
}
}
}
class SpeakerEmbeddingExtractor(
assetManager: AssetManager? = null,
config: SpeakerEmbeddingExtractorConfig,
) {
private var ptr: Long
init {
ptr = if (assetManager != null) {
new(assetManager, config)
} else {
newFromFile(config)
}
}
protected fun finalize() {
delete(ptr)
ptr = 0
}
fun release() = finalize()
fun createStream(): SpeakerEmbeddingExtractorStream {
val p = createStream(ptr)
return SpeakerEmbeddingExtractorStream(p)
}
fun isReady(stream: SpeakerEmbeddingExtractorStream) = isReady(ptr, stream.ptr)
fun compute(stream: SpeakerEmbeddingExtractorStream) = compute(ptr, stream.ptr)
fun dim() = dim(ptr)
private external fun new(
assetManager: AssetManager,
config: SpeakerEmbeddingExtractorConfig,
): Long
private external fun newFromFile(
config: SpeakerEmbeddingExtractorConfig,
): Long
private external fun delete(ptr: Long)
private external fun createStream(ptr: Long): Long
private external fun isReady(ptr: Long, streamPtr: Long): Boolean
private external fun compute(ptr: Long, streamPtr: Long): FloatArray
private external fun dim(ptr: Long): Int
companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
}
}
}
class SpeakerEmbeddingManager(val dim: Int) {
private var ptr: Long
init {
ptr = new(dim)
}
protected fun finalize() {
delete(ptr)
ptr = 0
}
fun release() = finalize()
fun add(name: String, embedding: FloatArray) = add(ptr, name, embedding)
fun add(name: String, embedding: Array<FloatArray>) = addList(ptr, name, embedding)
fun remove(name: String) = remove(ptr, name)
fun search(embedding: FloatArray, threshold: Float) = search(ptr, embedding, threshold)
fun verify(name: String, embedding: FloatArray, threshold: Float) =
verify(ptr, name, embedding, threshold)
fun contains(name: String) = contains(ptr, name)
fun numSpeakers() = numSpeakers(ptr)
fun allSpeakerNames() = allSpeakerNames(ptr)
private external fun new(dim: Int): Long
private external fun delete(ptr: Long): Unit
private external fun add(ptr: Long, name: String, embedding: FloatArray): Boolean
private external fun addList(ptr: Long, name: String, embedding: Array<FloatArray>): Boolean
private external fun remove(ptr: Long, name: String): Boolean
private external fun search(ptr: Long, embedding: FloatArray, threshold: Float): String
private external fun verify(
ptr: Long,
name: String,
embedding: FloatArray,
threshold: Float
): Boolean
private external fun contains(ptr: Long, name: String): Boolean
private external fun numSpeakers(ptr: Long): Int
private external fun allSpeakerNames(ptr: Long): Array<String>
companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
}
}
}
// Please download the model file from
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
// and put it inside the assets directory.
//
// Please don't put it in a subdirectory of assets
private val modelName = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
object SpeakerRecognition {
var _extractor: SpeakerEmbeddingExtractor? = null
var _manager: SpeakerEmbeddingManager? = null
val extractor: SpeakerEmbeddingExtractor
get() {
return _extractor!!
}
val manager: SpeakerEmbeddingManager
get() {
return _manager!!
}
fun initExtractor(assetManager: AssetManager? = null) {
synchronized(this) {
if (_extractor != null) {
return
}
Log.i(TAG, "Initializing speaker embedding extractor")
_extractor = SpeakerEmbeddingExtractor(
assetManager = assetManager,
config = SpeakerEmbeddingExtractorConfig(
model = modelName,
numThreads = 2,
debug = false,
provider = "cpu",
)
)
_manager = SpeakerEmbeddingManager(dim = _extractor!!.dim())
}
}
}
... ...
package com.k2fsa.sherpa.onnx.speaker.identification.screens
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
@Composable
fun HelpScreen() {
Box(modifier= Modifier.fillMaxSize()) {
Column(
modifier = Modifier.padding(16.dp)
) {
Text("Please see http://github.com/k2-fsa/sherpa-onnx ")
Spacer(modifier = Modifier.height(16.dp))
Text("https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models")
Spacer(modifier = Modifier.height(16.dp))
Text("https://k2-fsa.github.io/sherpa/social-groups.html")
Spacer(modifier = Modifier.height(16.dp))
Text("Everything is open-sourced!")
}
}
}
... ...
package com.k2fsa.sherpa.onnx.speaker.identification.screens
import android.Manifest
import android.annotation.SuppressLint
import android.app.Activity
import android.content.pm.PackageManager
import android.media.AudioFormat
import android.media.AudioRecord
import android.media.MediaRecorder
import android.util.Log
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.width
import androidx.compose.material3.Button
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Slider
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.SpeakerRecognition
import com.k2fsa.sherpa.onnx.speaker.identification.R
import com.k2fsa.sherpa.onnx.speaker.identification.TAG
import kotlin.concurrent.thread
private var audioRecord: AudioRecord? = null
private var sampleList: MutableList<FloatArray>? = null
private val clearedResult = "-cleared-"
@Composable
fun HomeScreen() {
val activity = LocalContext.current as Activity
var threshold by remember {
mutableStateOf(0.5F)
}
var detectedName by remember {
mutableStateOf(clearedResult)
}
var isStarted by remember { mutableStateOf(false) }
val onRecordingButtonClick: () -> Unit = {
isStarted = !isStarted
if (isStarted) {
if (ActivityCompat.checkSelfPermission(
activity,
Manifest.permission.RECORD_AUDIO
) != PackageManager.PERMISSION_GRANTED
) {
Log.i(TAG, "Recording is not allowed")
} else {
// recording is allowed
val audioSource = MediaRecorder.AudioSource.MIC
val channelConfig = AudioFormat.CHANNEL_IN_MONO
val audioFormat = AudioFormat.ENCODING_PCM_16BIT
val numBytes =
AudioRecord.getMinBufferSize(sampleRateInHz, channelConfig, audioFormat)
audioRecord = AudioRecord(
audioSource,
sampleRateInHz,
AudioFormat.CHANNEL_IN_MONO,
AudioFormat.ENCODING_PCM_16BIT,
numBytes * 2 // a sample has two bytes as we are using 16-bit PCM
)
sampleList = null
detectedName = clearedResult
// recording is started here
thread(true) {
Log.i(TAG, "processing samples")
val interval = 0.1 // i.e., 100 ms
val bufferSize = (interval * sampleRateInHz).toInt() // in samples
val buffer = ShortArray(bufferSize)
audioRecord?.let {
it.startRecording()
while (isStarted) {
val ret = audioRecord?.read(buffer, 0, buffer.size)
ret?.let { n ->
val samples = FloatArray(n) { buffer[it] / 32768.0f }
if (sampleList == null) {
sampleList = mutableListOf(samples)
} else {
sampleList?.add(samples)
}
}
}
}
Log.i(TAG, "Home: Recording is stopped. ${sampleList?.count()}")
}
}
} else {
// recording is stopped here
audioRecord?.stop()
audioRecord?.release()
audioRecord = null
sampleList?.let {
val stream = SpeakerRecognition.extractor.createStream()
for (samples in it) {
stream.acceptWaveform(samples = samples, sampleRate = sampleRateInHz)
}
stream.inputFinished()
if (SpeakerRecognition.extractor.isReady(stream)) {
val embedding = SpeakerRecognition.extractor.compute(stream)
detectedName = SpeakerRecognition.manager.search(
embedding = embedding,
threshold = threshold,
)
}
}
}
}
val onThresholdChange = { newValue: Float ->
threshold = newValue
}
Box(
modifier = Modifier.fillMaxSize(),
contentAlignment = Alignment.TopCenter,
) {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
) {
HomeThresholdRow(
threshold = threshold,
onValueChange = onThresholdChange,
)
HomeButtonRow(
isStarted = isStarted,
onRecordingButtonClick = onRecordingButtonClick,
onClearButtonClick = {
detectedName = clearedResult
},
)
Spacer(modifier = Modifier.height(48.dp))
if(detectedName == clearedResult) {
// do nothing
} else if (detectedName.length > 0) {
Text(
text = "Speaker: ${detectedName}",
style = MaterialTheme.typography.headlineLarge,
fontWeight = FontWeight.Bold,
)
} else {
Text(
text = "Unknown speaker",
style = MaterialTheme.typography.headlineLarge,
fontWeight = FontWeight.Bold,
)
}
}
}
}
@SuppressLint("UnrememberedMutableState")
@Composable
private fun HomeButtonRow(
modifier: Modifier = Modifier,
isStarted: Boolean,
onRecordingButtonClick: () -> Unit,
onClearButtonClick: () -> Unit,
) {
val numSpeakers: Int by mutableStateOf(SpeakerRecognition.manager.numSpeakers())
Row(
modifier = modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.Center,
) {
Button(
enabled = numSpeakers > 0,
onClick = onRecordingButtonClick
) {
Text(text = stringResource(if (isStarted) R.string.stop else R.string.start))
}
Spacer(modifier = Modifier.width(24.dp))
Button(onClick = onClearButtonClick) {
Text(text = stringResource(id = R.string.clear))
}
}
}
@Composable
fun HomeThresholdRow(
modifier: Modifier = Modifier,
threshold: Float,
onValueChange: (Float) -> Unit,
) {
Column(modifier = Modifier) {
Text(
text = "Threshold: " + String.format("%.2f", threshold),
style = MaterialTheme.typography.headlineMedium,
fontWeight = FontWeight.Bold,
modifier = modifier.padding(bottom = 8.dp, top = 8.dp),
)
Slider(
value = threshold,
onValueChange = onValueChange,
valueRange = 0.1F..1.0F,
modifier = modifier.fillMaxWidth(),
)
}
}
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx.speaker.identification.screens
import android.Manifest
import android.annotation.SuppressLint
import android.app.Activity
import android.content.pm.PackageManager
import android.media.AudioFormat
import android.media.AudioRecord
import android.media.MediaRecorder
import android.util.Log
import android.widget.Toast
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.width
import androidx.compose.material3.Button
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.SpeakerRecognition
import com.k2fsa.sherpa.onnx.speaker.identification.R
import com.k2fsa.sherpa.onnx.speaker.identification.TAG
import kotlin.concurrent.thread
private var audioRecord: AudioRecord? = null
private var sampleList: MutableList<FloatArray>? = null
private var embeddingList: MutableList<FloatArray>? = null
val sampleRateInHz = 16000
@SuppressLint("UnrememberedMutableState")
@Preview
@Composable
fun RegisterScreen(modifier: Modifier = Modifier) {
val activity = LocalContext.current as Activity
var firstTime by remember { mutableStateOf(true) }
if (firstTime) {
firstTime = false
// clear states
embeddingList = null
}
val numberAudio: Int by mutableStateOf(embeddingList?.count() ?: 0)
Box(
modifier = Modifier.fillMaxSize(),
contentAlignment = Alignment.TopCenter
) {
var speakerName by remember { mutableStateOf("") }
val onSpeakerNameChange = { newName: String -> speakerName = newName }
var isStarted by remember { mutableStateOf(false) }
val onRecordingButtonClick: () -> Unit = {
isStarted = !isStarted
if (isStarted) {
if (ActivityCompat.checkSelfPermission(
activity,
Manifest.permission.RECORD_AUDIO
) != PackageManager.PERMISSION_GRANTED
) {
Log.i(TAG, "Recording is not allowed")
} else {
// recording is allowed
val audioSource = MediaRecorder.AudioSource.MIC
val channelConfig = AudioFormat.CHANNEL_IN_MONO
val audioFormat = AudioFormat.ENCODING_PCM_16BIT
val numBytes =
AudioRecord.getMinBufferSize(sampleRateInHz, channelConfig, audioFormat)
audioRecord = AudioRecord(
audioSource,
sampleRateInHz,
AudioFormat.CHANNEL_IN_MONO,
AudioFormat.ENCODING_PCM_16BIT,
numBytes * 2 // a sample has two bytes as we are using 16-bit PCM
)
sampleList = null
// recording is started here
thread(true) {
Log.i(TAG, "processing samples")
val interval = 0.1 // i.e., 100 ms
val bufferSize = (interval * sampleRateInHz).toInt() // in samples
val buffer = ShortArray(bufferSize)
audioRecord?.let {
it.startRecording()
while (isStarted) {
val ret = audioRecord?.read(buffer, 0, buffer.size)
ret?.let { n ->
val samples = FloatArray(n) { buffer[it] / 32768.0f }
if (sampleList == null) {
sampleList = mutableListOf(samples)
} else {
sampleList?.add(samples)
}
}
}
}
Log.i(TAG, "Recording is stopped. ${sampleList?.count()}")
}
}
} else {
// recording is stopped here
audioRecord?.stop()
audioRecord?.release()
audioRecord = null
sampleList?.let {
val stream = SpeakerRecognition.extractor.createStream()
for (samples in it) {
stream.acceptWaveform(samples=samples, sampleRate=sampleRateInHz)
}
stream.inputFinished()
if(SpeakerRecognition.extractor.isReady(stream)) {
val embedding = SpeakerRecognition.extractor.compute(stream)
if(embeddingList == null) {
embeddingList = mutableListOf(embedding)
} else {
embeddingList?.add(embedding)
}
}
}
}
}
val onAddButtonClick: () -> Unit = {
if(speakerName.isEmpty() || speakerName.isBlank()) {
Toast.makeText(
activity,
"please input a speaker name",
Toast.LENGTH_SHORT
).show()
} else if(SpeakerRecognition.manager.contains(speakerName.trim())) {
Toast.makeText(
activity,
"A speaker with $speakerName already exists. Please choose a new name",
Toast.LENGTH_SHORT
).show()
} else {
val ok = SpeakerRecognition.manager.add(speakerName.trim(), embedding = embeddingList!!.toTypedArray())
if(ok) {
Log.i(TAG, "Added ${speakerName.trim()} successfully")
Toast.makeText(
activity,
"Added ${speakerName.trim()}",
Toast.LENGTH_SHORT
).show()
embeddingList = null
sampleList = null
speakerName = ""
firstTime = true
} else {
Log.i(TAG, "Failed to add ${speakerName.trim()}")
Toast.makeText(
activity,
"Failed to add ${speakerName.trim()}",
Toast.LENGTH_SHORT
).show()
}
}
}
Column(horizontalAlignment = Alignment.CenterHorizontally) {
SpeakerNameRow(speakerName = speakerName, onValueChange = onSpeakerNameChange)
Text(
"Number of recordings: ${numberAudio}",
modifier = modifier.padding(24.dp),
style = MaterialTheme.typography.headlineMedium,
fontWeight = FontWeight.Bold,
)
RegisterSpeakerButtonRow(
modifier,
isStarted = isStarted,
onRecordingButtonClick = onRecordingButtonClick,
onAddButtonClick = onAddButtonClick,
)
}
}
}
@Composable
fun SpeakerNameRow(
modifier: Modifier = Modifier,
speakerName: String,
onValueChange: (String) -> Unit
) {
OutlinedTextField(
value = speakerName,
onValueChange = onValueChange,
label = {
Text("Please input the speaker name")
},
singleLine = true,
modifier = modifier
.fillMaxWidth()
.padding(8.dp)
)
}
@SuppressLint("UnrememberedMutableState")
@Composable
fun RegisterSpeakerButtonRow(
modifier: Modifier = Modifier,
isStarted: Boolean,
onRecordingButtonClick: () -> Unit,
onAddButtonClick: () -> Unit,
) {
val numberAudio: Int by mutableStateOf(embeddingList?.count() ?: 0)
Row(
modifier = modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.Center,
) {
Button(onClick = onRecordingButtonClick) {
Text(text = stringResource(if (isStarted) R.string.stop else R.string.start))
}
Spacer(modifier = Modifier.width(24.dp))
Button(
enabled = numberAudio > 0,
onClick = onAddButtonClick,
) {
Text(text = stringResource(id = R.string.add))
}
}
}
... ...
package com.k2fsa.sherpa.onnx.speaker.identification.screens
import android.annotation.SuppressLint
import androidx.compose.foundation.ExperimentalFoundationApi
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.material3.Button
import androidx.compose.material3.Checkbox
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.runtime.toMutableStateList
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
import com.k2fsa.sherpa.onnx.SpeakerRecognition
class SpeakerName(val name: String) {
val nameState = mutableStateOf(name)
val checked = mutableStateOf(false)
fun onCheckedChange(newValue: Boolean) {
checked.value = newValue
}
}
@SuppressLint("UnrememberedMutableState")
@OptIn(ExperimentalFoundationApi::class)
@Composable
fun ViewScreen() {
val allSpeakerNames = SpeakerRecognition.manager.allSpeakerNames()
val allSpeakerNameList = remember {
MutableList(
allSpeakerNames.size
) {
SpeakerName(allSpeakerNames[it])
}.toMutableStateList()
}
var enabled by remember {
mutableStateOf(SpeakerRecognition.manager.numSpeakers() > 0)
}
Box(
modifier = Modifier.fillMaxSize(),
contentAlignment = Alignment.TopCenter
) {
Column(
modifier = Modifier.padding(16.dp),
horizontalAlignment = Alignment.CenterHorizontally,
) {
Button(
enabled = enabled,
onClick = {
val toRemove: MutableList<SpeakerName> = mutableListOf()
for (s in allSpeakerNameList) {
if (s.checked.value) {
SpeakerRecognition.manager.remove(s.name)
toRemove.add(s)
}
}
allSpeakerNameList.removeAll(toRemove)
enabled = SpeakerRecognition.manager.numSpeakers() > 0
}) {
Text("Delete selected")
}
LazyColumn(modifier = Modifier.fillMaxSize()) {
items(allSpeakerNameList) { s: SpeakerName ->
ViewRow(speakerName = s)
}
}
}
}
}
@Composable
fun ViewRow(
modifier: Modifier = Modifier,
speakerName: SpeakerName
) {
Surface(
modifier = modifier
.fillMaxWidth()
.padding(8.dp),
color = MaterialTheme.colorScheme.inversePrimary,
) {
Row(
modifier = modifier,
horizontalArrangement = Arrangement.Center,
verticalAlignment = Alignment.CenterVertically,
) {
Text(
text = speakerName.name,
modifier = modifier.weight(1.0F),
)
Checkbox(checked = speakerName.checked.value,
onCheckedChange = { speakerName.onCheckedChange(it) }
)
}
}
}
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx.speaker.identification.ui.theme
import androidx.compose.ui.graphics.Color
val Purple80 = Color(0xFFD0BCFF)
val PurpleGrey80 = Color(0xFFCCC2DC)
val Pink80 = Color(0xFFEFB8C8)
val Purple40 = Color(0xFF6650a4)
val PurpleGrey40 = Color(0xFF625b71)
val Pink40 = Color(0xFF7D5260)
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx.speaker.identification.ui.theme
import android.app.Activity
import android.os.Build
import androidx.compose.foundation.isSystemInDarkTheme
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.darkColorScheme
import androidx.compose.material3.dynamicDarkColorScheme
import androidx.compose.material3.dynamicLightColorScheme
import androidx.compose.material3.lightColorScheme
import androidx.compose.runtime.Composable
import androidx.compose.runtime.SideEffect
import androidx.compose.ui.graphics.toArgb
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalView
import androidx.core.view.WindowCompat
private val DarkColorScheme = darkColorScheme(
primary = Purple80,
secondary = PurpleGrey80,
tertiary = Pink80
)
private val LightColorScheme = lightColorScheme(
primary = Purple40,
secondary = PurpleGrey40,
tertiary = Pink40
/* Other default colors to override
background = Color(0xFFFFFBFE),
surface = Color(0xFFFFFBFE),
onPrimary = Color.White,
onSecondary = Color.White,
onTertiary = Color.White,
onBackground = Color(0xFF1C1B1F),
onSurface = Color(0xFF1C1B1F),
*/
)
@Composable
fun SherpaOnnxSpeakerIdentificationTheme(
darkTheme: Boolean = isSystemInDarkTheme(),
// Dynamic color is available on Android 12+
dynamicColor: Boolean = true,
content: @Composable () -> Unit
) {
val colorScheme = when {
dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> {
val context = LocalContext.current
if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context)
}
darkTheme -> DarkColorScheme
else -> LightColorScheme
}
val view = LocalView.current
if (!view.isInEditMode) {
SideEffect {
val window = (view.context as Activity).window
window.statusBarColor = colorScheme.primary.toArgb()
WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme
}
}
MaterialTheme(
colorScheme = colorScheme,
typography = Typography,
content = content
)
}
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx.speaker.identification.ui.theme
import androidx.compose.material3.Typography
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.font.FontFamily
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.sp
// Set of Material typography styles to start with
val Typography = Typography(
bodyLarge = TextStyle(
fontFamily = FontFamily.Default,
fontWeight = FontWeight.Normal,
fontSize = 16.sp,
lineHeight = 24.sp,
letterSpacing = 0.5.sp
)
/* Other default text styles to override
titleLarge = TextStyle(
fontFamily = FontFamily.Default,
fontWeight = FontWeight.Normal,
fontSize = 22.sp,
lineHeight = 28.sp,
letterSpacing = 0.sp
),
labelSmall = TextStyle(
fontFamily = FontFamily.Default,
fontWeight = FontWeight.Medium,
fontSize = 11.sp,
lineHeight = 16.sp,
letterSpacing = 0.5.sp
)
*/
)
\ No newline at end of file
... ...
<vector xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:aapt="http://schemas.android.com/aapt"
android:width="108dp"
android:height="108dp"
android:viewportWidth="108"
android:viewportHeight="108">
<path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
<aapt:attr name="android:fillColor">
<gradient
android:endX="85.84757"
android:endY="92.4963"
android:startX="42.9492"
android:startY="49.59793"
android:type="linear">
<item
android:color="#44000000"
android:offset="0.0" />
<item
android:color="#00000000"
android:offset="1.0" />
</gradient>
</aapt:attr>
</path>
<path
android:fillColor="#FFFFFF"
android:fillType="nonZero"
android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
android:strokeWidth="1"
android:strokeColor="#00000000" />
</vector>
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?>
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="108dp"
android:height="108dp"
android:viewportWidth="108"
android:viewportHeight="108">
<path
android:fillColor="#3DDC84"
android:pathData="M0,0h108v108h-108z" />
<path
android:fillColor="#00000000"
android:pathData="M9,0L9,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,0L19,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M29,0L29,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M39,0L39,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M49,0L49,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M59,0L59,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M69,0L69,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M79,0L79,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M89,0L89,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M99,0L99,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,9L108,9"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,19L108,19"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,29L108,29"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,39L108,39"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,49L108,49"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,59L108,59"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,69L108,69"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,79L108,79"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,89L108,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,99L108,99"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,29L89,29"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,39L89,39"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,49L89,49"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,59L89,59"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,69L89,69"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,79L89,79"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M29,19L29,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M39,19L39,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M49,19L49,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M59,19L59,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M69,19L69,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M79,19L79,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
</vector>
... ...
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background" />
<foreground android:drawable="@drawable/ic_launcher_foreground" />
<monochrome android:drawable="@drawable/ic_launcher_foreground" />
</adaptive-icon>
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background" />
<foreground android:drawable="@drawable/ic_launcher_foreground" />
<monochrome android:drawable="@drawable/ic_launcher_foreground" />
</adaptive-icon>
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="purple_200">#FFBB86FC</color>
<color name="purple_500">#FF6200EE</color>
<color name="purple_700">#FF3700B3</color>
<color name="teal_200">#FF03DAC5</color>
<color name="teal_700">#FF018786</color>
<color name="black">#FF000000</color>
<color name="white">#FFFFFFFF</color>
</resources>
\ No newline at end of file
... ...
<resources>
<string name="app_name">Speaker Identification</string>
<string name="start">Start recording</string>
<string name="stop">Stop recording</string>
<string name="add">Add speaker</string>
<string name="clear">Clear result</string>
</resources>
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?>
<resources>
<style name="Theme.SherpaOnnxSpeakerIdentification" parent="android:Theme.Material.Light.NoActionBar" />
</resources>
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?><!--
Sample backup rules file; uncomment and customize as necessary.
See https://developer.android.com/guide/topics/data/autobackup
for details.
Note: This file is ignored for devices older that API 31
See https://developer.android.com/about/versions/12/backup-restore
-->
<full-backup-content>
<!--
<include domain="sharedpref" path="."/>
<exclude domain="sharedpref" path="device.xml"/>
-->
</full-backup-content>
\ No newline at end of file
... ...
<?xml version="1.0" encoding="utf-8"?><!--
Sample data extraction rules file; uncomment and customize as necessary.
See https://developer.android.com/about/versions/12/backup-restore#xml-changes
for details.
-->
<data-extraction-rules>
<cloud-backup>
<!-- TODO: Use <include> and <exclude> to control what is backed up.
<include .../>
<exclude .../>
-->
</cloud-backup>
<!--
<device-transfer>
<include .../>
<exclude .../>
</device-transfer>
-->
</data-extraction-rules>
\ No newline at end of file
... ...
package com.k2fsa.sherpa.onnx.speaker.identification
import org.junit.Test
import org.junit.Assert.*
/**
* Example local unit test, which will execute on the development machine (host).
*
* See [testing documentation](http://d.android.com/tools/testing).
*/
class ExampleUnitTest {
@Test
fun addition_isCorrect() {
assertEquals(4, 2 + 2)
}
}
\ No newline at end of file
... ...
// Top-level build file where you can add configuration options common to all sub-projects/modules.
plugins {
id("com.android.application") version "8.2.0" apply false
id("org.jetbrains.kotlin.android") version "1.9.0" apply false
}
\ No newline at end of file
... ...
# Project-wide Gradle settings.
# IDE (e.g. Android Studio) users:
# Gradle settings configured through the IDE *will override*
# any settings specified in this file.
# For more details on how to configure your build environment visit
# http://www.gradle.org/docs/current/userguide/build_environment.html
# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
# org.gradle.parallel=true
# AndroidX package structure to make it clearer which packages are bundled with the
# Android operating system, and which are packaged with your app's APK
# https://developer.android.com/topic/libraries/support-library/androidx-rn
android.useAndroidX=true
# Kotlin code style for this project: "official" or "obsolete":
kotlin.code.style=official
# Enables namespacing of each library's R class so that its R class includes only the
# resources declared in the library itself and none from the library's dependencies,
# thereby reducing the size of the R class for that library
android.nonTransitiveRClass=true
\ No newline at end of file
... ...
#Sun Jan 21 18:37:37 CST 2024
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
... ...
#!/usr/bin/env sh
#
# Copyright 2015 the original author or authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn () {
echo "$*"
}
die () {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin or MSYS, switch paths to Windows format before running java
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=`expr $i + 1`
done
case $i in
0) set -- ;;
1) set -- "$args0" ;;
2) set -- "$args0" "$args1" ;;
3) set -- "$args0" "$args1" "$args2" ;;
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Escape application args
save () {
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
echo " "
}
APP_ARGS=`save "$@"`
# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
exec "$JAVACMD" "$@"
... ...
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
@rem
@rem https://www.apache.org/licenses/LICENSE-2.0
@rem
@rem Unless required by applicable law or agreed to in writing, software
@rem distributed under the License is distributed on an "AS IS" BASIS,
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@rem See the License for the specific language governing permissions and
@rem limitations under the License.
@rem
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega
... ...
pluginManagement {
repositories {
google()
mavenCentral()
gradlePluginPortal()
}
}
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
mavenCentral()
}
}
rootProject.name = "SherpaOnnxSpeakerIdentification"
include(":app")
... ...
... ... @@ -7,11 +7,67 @@ fun callback(samples: FloatArray): Unit {
}
fun main() {
testSpeakerRecognition()
testTts()
testAsr("transducer")
testAsr("zipformer2-ctc")
}
fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray {
var objArray = WaveReader.readWaveFromFile(
filename = filename,
)
var samples: FloatArray = objArray[0] as FloatArray
var sampleRate: Int = objArray[1] as Int
val stream = extractor.createStream()
stream.acceptWaveform(sampleRate = sampleRate, samples=samples)
stream.inputFinished()
check(extractor.isReady(stream))
val embedding = extractor.compute(stream)
stream.release()
return embedding
}
fun testSpeakerRecognition() {
val config = SpeakerEmbeddingExtractorConfig(
model="./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx",
)
val extractor = SpeakerEmbeddingExtractor(config = config)
val embedding1a = computeEmbedding(extractor, "./speaker1_a_cn_16k.wav")
val embedding2a = computeEmbedding(extractor, "./speaker2_a_cn_16k.wav")
val embedding1b = computeEmbedding(extractor, "./speaker1_b_cn_16k.wav")
var manager = SpeakerEmbeddingManager(extractor.dim())
var ok = manager.add(name = "speaker1", embedding=embedding1a)
check(ok)
manager.add(name = "speaker2", embedding=embedding2a)
check(ok)
var name = manager.search(embedding=embedding1b, threshold=0.5f)
check(name == "speaker1")
manager.release()
manager = SpeakerEmbeddingManager(extractor.dim())
val embeddingList = mutableListOf(embedding1a, embedding1b)
ok = manager.add(name = "s1", embedding=embeddingList.toTypedArray())
check(ok)
name = manager.search(embedding=embedding1b, threshold=0.5f)
check(name == "s1")
name = manager.search(embedding=embedding2a, threshold=0.5f)
check(name.length == 0)
manager.release()
}
fun testTts() {
// see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
// https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
... ...
../android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt
\ No newline at end of file
... ...
... ... @@ -29,6 +29,22 @@ export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH
cd ../kotlin-api-examples
if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
fi
if [ ! -f ./speaker1_a_cn_16k.wav ]; then
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
fi
if [ ! -f ./speaker1_b_cn_16k.wav ]; then
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
fi
if [ ! -f ./speaker2_a_cn_16k.wav ]; then
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
fi
if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
git lfs install
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
... ... @@ -46,7 +62,7 @@ if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then
rm vits-piper-en_US-amy-low.tar.bz2
fi
kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt Tts.kt
kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt Tts.kt Speaker.kt
ls -lh main.jar
... ...
build-apk-tts.sh
!*.sh.in
... ...
#!/usr/bin/env bash
#
# Auto generated! Please DO NOT EDIT!
# Please set the environment variable ANDROID_NDK
# before running this script
# Inside the $ANDROID_NDK directory, you can find a binary ndk-build
# and some other files like the file "build/cmake/android.toolchain.cmake"
set -ex
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
log "Building Speaker identification APK for sherpa-onnx v${SHERPA_ONNX_VERSION}"
log "====================arm64-v8a================="
./build-android-arm64-v8a.sh
log "====================armv7-eabi================"
./build-android-armv7-eabi.sh
log "====================x86-64===================="
./build-android-x86-64.sh
log "====================x86===================="
./build-android-x86.sh
mkdir -p apks
{% for model in model_list %}
pushd ./android/SherpaOnnxSpeakerIdentification/app/src/main/assets/
model_name={{ model.model_name }}
short_name={{ model.short_name }}
lang={{ model.lang }}
framework={{ model.framework }}
wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$model_name
popd
# Now we are at the project root directory
git checkout .
pushd android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/
sed -i.bak s/"private val modelName.*/private val modelName = \"$model_name\"/" ./Speaker.kt
git diff
popd
for arch in arm64-v8a armeabi-v7a x86_64 x86; do
log "------------------------------------------------------------"
log "build tts apk for $arch"
log "------------------------------------------------------------"
src_arch=$arch
if [ $arch == "armeabi-v7a" ]; then
src_arch=armv7-eabi
elif [ $arch == "x86_64" ]; then
src_arch=x86-64
fi
ls -lh ./build-android-$src_arch/install/lib/*.so
cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnxSpeakerIdentification/app/src/main/jniLibs/$arch/
pushd ./android/SherpaOnnxSpeakerIdentification
./gradlew build
popd
mv android/SherpaOnnxSpeakerIdentification/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-${SHERPA_ONNX_VERSION}-$arch-$lang-speaker-identification-$framework-$short_name.apk
ls -lh apks
rm -v ./android/SherpaOnnxSpeakerIdentification/app/src/main/jniLibs/$arch/*.so
done
rm -rf ./android/SherpaOnnxSpeakerIdentification/app/src/main/assets/$model_name
{% endfor %}
git checkout .
ls -lh apks/
... ...
#!/usr/bin/env python3
import argparse
from dataclasses import dataclass
from typing import List, Optional
import jinja2
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--total",
type=int,
default=1,
help="Number of runners",
)
parser.add_argument(
"--index",
type=int,
default=0,
help="Index of the current runner",
)
return parser.parse_args()
@dataclass
class SpeakerIdentificationModel:
model_name: str
short_name: str = ""
lang: str = ""
framework: str = ""
def get_3dspeaker_models() -> List[SpeakerIdentificationModel]:
models = [
SpeakerIdentificationModel(model_name="3dspeaker_speech_campplus_sv_en_voxceleb_16k.onnx"),
SpeakerIdentificationModel(model_name="3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx"),
SpeakerIdentificationModel(model_name="3dspeaker_speech_eres2net_base_200k_sv_zh-cn_16k-common.onnx"),
SpeakerIdentificationModel(model_name="3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"),
SpeakerIdentificationModel(model_name="3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx"),
SpeakerIdentificationModel(model_name="3dspeaker_speech_eres2net_sv_en_voxceleb_16k.onnx"),
SpeakerIdentificationModel(model_name="3dspeaker_speech_eres2net_sv_zh-cn_16k-common.onnx"),
]
prefix = '3dspeaker_speech_'
num = len(prefix)
for m in models:
m.framework = '3dspeaker'
m.short_name = m.model_name[num:-5]
if '_zh-cn_' in m.model_name:
m.lang = 'zh'
elif '_en_' in m.model_name:
m.lang = 'en'
else:
raise ValueError(m)
return models
def get_wespeaker_models() -> List[SpeakerIdentificationModel]:
models = [
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_CAM++.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_CAM++_LM.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_resnet152_LM.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_resnet221_LM.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_resnet293_LM.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_resnet34.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_resnet34_LM.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_zh_cnceleb_resnet34.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_zh_cnceleb_resnet34_LM.onnx"),
]
prefix = 'wespeaker_xx_'
num = len(prefix)
for m in models:
m.framework = 'wespeaker'
m.short_name = m.model_name[num:-5]
if '_zh_' in m.model_name:
m.lang = 'zh'
elif '_en_' in m.model_name:
m.lang = 'en'
else:
raise ValueError(m)
return models
def get_nemo_models() -> List[SpeakerIdentificationModel]:
models = [
SpeakerIdentificationModel(model_name="nemo_en_speakerverification_speakernet.onnx"),
SpeakerIdentificationModel(model_name="nemo_en_titanet_large.onnx"),
SpeakerIdentificationModel(model_name="nemo_en_titanet_small.onnx"),
]
prefix = 'nemo_en_'
num = len(prefix)
for m in models:
m.framework = 'nemo'
m.short_name = m.model_name[num:-5]
if '_zh_' in m.model_name:
m.lang = 'zh'
elif '_en_' in m.model_name:
m.lang = 'en'
else:
raise ValueError(m)
return models
def main():
args = get_args()
index = args.index
total = args.total
assert 0 <= index < total, (index, total)
all_model_list = get_3dspeaker_models()
all_model_list += get_wespeaker_models()
all_model_list += get_nemo_models()
num_models = len(all_model_list)
num_per_runner = num_models // total
if num_per_runner <= 0:
raise ValueError(f"num_models: {num_models}, num_runners: {total}")
start = index * num_per_runner
end = start + num_per_runner
remaining = num_models - args.total * num_per_runner
print(f"{index}/{total}: {start}-{end}/{num_models}")
d = dict()
d["model_list"] = all_model_list[start:end]
if index < remaining:
s = args.total * num_per_runner + index
d["model_list"].append(all_model_list[s])
print(f"{s}/{num_models}")
filename_list = ["./build-apk-speaker-identification.sh"]
for filename in filename_list:
environment = jinja2.Environment()
with open(f"{filename}.in") as f:
s = f.read()
template = environment.from_string(s)
s = template.render(**d)
with open(filename, "w") as f:
print(s, file=f)
if __name__ == "__main__":
main()
... ...
*.cc-bak
*.h-bak
... ...
... ... @@ -22,6 +22,12 @@ class SpeakerEmbeddingExtractorGeneralImpl
const SpeakerEmbeddingExtractorConfig &config)
: model_(config) {}
#if __ANDROID_API__ >= 9
SpeakerEmbeddingExtractorGeneralImpl(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
: model_(mgr, config) {}
#endif
int32_t Dim() const override { return model_.GetMetaData().output_dim; }
std::unique_ptr<OnlineStream> CreateStream() const override {
... ...
... ... @@ -90,4 +90,35 @@ SpeakerEmbeddingExtractorImpl::Create(
return nullptr;
}
#if __ANDROID_API__ >= 9
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
SpeakerEmbeddingExtractorImpl::Create(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) {
ModelType model_type = ModelType::kUnkown;
{
auto buffer = ReadFile(mgr, config.model);
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
switch (model_type) {
case ModelType::kWeSpeaker:
// fall through
case ModelType::k3dSpeaker:
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(mgr,
config);
case ModelType::kNeMo:
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(mgr, config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE(
"Unknown model type in for speaker embedding extractor!");
return nullptr;
}
// unreachable code
return nullptr;
}
#endif
} // namespace sherpa_onnx
... ...
... ... @@ -9,6 +9,11 @@
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
namespace sherpa_onnx {
... ... @@ -20,6 +25,11 @@ class SpeakerEmbeddingExtractorImpl {
static std::unique_ptr<SpeakerEmbeddingExtractorImpl> Create(
const SpeakerEmbeddingExtractorConfig &config);
#if __ANDROID_API__ >= 9
static std::unique_ptr<SpeakerEmbeddingExtractorImpl> Create(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config);
#endif
virtual int32_t Dim() const = 0;
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
... ...
... ... @@ -28,6 +28,19 @@ class SpeakerEmbeddingExtractorModel::Impl {
}
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.model);
Init(buf.data(), buf.size());
}
}
#endif
Ort::Value Compute(Ort::Value x) const {
std::array<Ort::Value, 1> inputs = {std::move(x)};
... ... @@ -98,6 +111,12 @@ SpeakerEmbeddingExtractorModel::SpeakerEmbeddingExtractorModel(
const SpeakerEmbeddingExtractorConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
SpeakerEmbeddingExtractorModel::SpeakerEmbeddingExtractorModel(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
SpeakerEmbeddingExtractorModel::~SpeakerEmbeddingExtractorModel() = default;
const SpeakerEmbeddingExtractorModelMetaData &
... ...
... ... @@ -6,6 +6,11 @@
#include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
... ... @@ -17,6 +22,11 @@ class SpeakerEmbeddingExtractorModel {
explicit SpeakerEmbeddingExtractorModel(
const SpeakerEmbeddingExtractorConfig &config);
#if __ANDROID_API__ >= 9
SpeakerEmbeddingExtractorModel(AAssetManager *mgr,
const SpeakerEmbeddingExtractorConfig &config);
#endif
~SpeakerEmbeddingExtractorModel();
const SpeakerEmbeddingExtractorModelMetaData &GetMetaData() const;
... ...
... ... @@ -22,6 +22,12 @@ class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl {
const SpeakerEmbeddingExtractorConfig &config)
: model_(config) {}
#if __ANDROID_API__ >= 9
SpeakerEmbeddingExtractorNeMoImpl(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
: model_(mgr, config) {}
#endif
int32_t Dim() const override { return model_.GetMetaData().output_dim; }
std::unique_ptr<OnlineStream> CreateStream() const override {
... ...
... ... @@ -28,6 +28,19 @@ class SpeakerEmbeddingExtractorNeMoModel::Impl {
}
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.model);
Init(buf.data(), buf.size());
}
}
#endif
Ort::Value Compute(Ort::Value x, Ort::Value x_lens) const {
std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)};
... ... @@ -106,6 +119,12 @@ SpeakerEmbeddingExtractorNeMoModel::SpeakerEmbeddingExtractorNeMoModel(
const SpeakerEmbeddingExtractorConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
SpeakerEmbeddingExtractorNeMoModel::SpeakerEmbeddingExtractorNeMoModel(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
SpeakerEmbeddingExtractorNeMoModel::~SpeakerEmbeddingExtractorNeMoModel() =
default;
... ...
... ... @@ -6,6 +6,11 @@
#include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
... ... @@ -17,6 +22,11 @@ class SpeakerEmbeddingExtractorNeMoModel {
explicit SpeakerEmbeddingExtractorNeMoModel(
const SpeakerEmbeddingExtractorConfig &config);
#if __ANDROID_API__ >= 9
SpeakerEmbeddingExtractorNeMoModel(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config);
#endif
~SpeakerEmbeddingExtractorNeMoModel();
const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const;
... ...
... ... @@ -55,6 +55,12 @@ SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor(
const SpeakerEmbeddingExtractorConfig &config)
: impl_(SpeakerEmbeddingExtractorImpl::Create(config)) {}
#if __ANDROID_API__ >= 9
SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
: impl_(SpeakerEmbeddingExtractorImpl::Create(mgr, config)) {}
#endif
SpeakerEmbeddingExtractor::~SpeakerEmbeddingExtractor() = default;
int32_t SpeakerEmbeddingExtractor::Dim() const { return impl_->Dim(); }
... ...
... ... @@ -9,6 +9,11 @@
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"
... ... @@ -40,6 +45,11 @@ class SpeakerEmbeddingExtractor {
explicit SpeakerEmbeddingExtractor(
const SpeakerEmbeddingExtractorConfig &config);
#if __ANDROID_API__ >= 9
SpeakerEmbeddingExtractor(AAssetManager *mgr,
const SpeakerEmbeddingExtractorConfig &config);
#endif
~SpeakerEmbeddingExtractor();
// Return the dimension of the embedding
... ...
... ... @@ -8,6 +8,7 @@
#include <unordered_map>
#include "Eigen/Dense"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
... ... @@ -36,6 +37,52 @@ class SpeakerEmbeddingManager::Impl {
return true;
}
bool Add(const std::string &name,
const std::vector<std::vector<float>> &embedding_list) {
if (name2row_.count(name)) {
// a speaker with the same name already exists
return false;
}
if (embedding_list.empty()) {
SHERPA_ONNX_LOGE("Empty list of embeddings");
return false;
}
for (const auto &x : embedding_list) {
if (x.size() != dim_) {
SHERPA_ONNX_LOGE("Given dim: %d, expected dim: %d",
static_cast<int32_t>(x.size()), dim_);
return false;
}
}
// compute the average
Eigen::RowVectorXf v = Eigen::Map<Eigen::RowVectorXf>(
const_cast<float *>(embedding_list[0].data()), dim_);
int32_t i = -1;
for (const auto &x : embedding_list) {
++i;
if (i == 0) {
continue;
}
v += Eigen::Map<Eigen::RowVectorXf>(const_cast<float *>(x.data()), dim_);
}
// no need to compute the mean since we are going to normalize it anyway
// v /= embedding_list.size();
v.normalize();
embedding_matrix_.conservativeResize(embedding_matrix_.rows() + 1, dim_);
embedding_matrix_.bottomRows(1) = v;
name2row_[name] = embedding_matrix_.rows() - 1;
row2name_[embedding_matrix_.rows() - 1] = name;
return true;
}
bool Remove(const std::string &name) {
if (!name2row_.count(name)) {
return false;
... ... @@ -104,8 +151,24 @@ class SpeakerEmbeddingManager::Impl {
return true;
}
bool Contains(const std::string &name) const {
return name2row_.count(name) > 0;
}
int32_t NumSpeakers() const { return embedding_matrix_.rows(); }
int32_t Dim() const { return dim_; }
std::vector<std::string> GetAllSpeakers() const {
std::vector<std::string> all_speakers;
for (const auto &p : name2row_) {
all_speakers.push_back(p.first);
}
std::stable_sort(all_speakers.begin(), all_speakers.end());
return all_speakers;
}
private:
int32_t dim_;
FloatMatrix embedding_matrix_;
... ... @@ -123,6 +186,12 @@ bool SpeakerEmbeddingManager::Add(const std::string &name,
return impl_->Add(name, p);
}
bool SpeakerEmbeddingManager::Add(
const std::string &name,
const std::vector<std::vector<float>> &embedding_list) const {
return impl_->Add(name, embedding_list);
}
bool SpeakerEmbeddingManager::Remove(const std::string &name) const {
return impl_->Remove(name);
}
... ... @@ -141,4 +210,14 @@ int32_t SpeakerEmbeddingManager::NumSpeakers() const {
return impl_->NumSpeakers();
}
int32_t SpeakerEmbeddingManager::Dim() const { return impl_->Dim(); }
bool SpeakerEmbeddingManager::Contains(const std::string &name) const {
return impl_->Contains(name);
}
std::vector<std::string> SpeakerEmbeddingManager::GetAllSpeakers() const {
return impl_->GetAllSpeakers();
}
} // namespace sherpa_onnx
... ...
... ... @@ -7,6 +7,7 @@
#include <memory>
#include <string>
#include <vector>
namespace sherpa_onnx {
... ... @@ -26,6 +27,19 @@ class SpeakerEmbeddingManager {
*/
bool Add(const std::string &name, const float *p) const;
/** Add a list of embeddings of a speaker.
*
* @param name Name of the speaker
* @param embedding_list A list of embeddings. Each entry should be of size
* `dim`. The average of the list is the final
* embedding.
* @return Return true if added successfully. Return false if it failed.
* At present, the only reason for a failure is that there is already
* a speaker with the same `name`.
*/
bool Add(const std::string &name,
const std::vector<std::vector<float>> &embedding_list) const;
/* Remove a speaker by its name.
*
* @param name Name of the speaker to remove.
... ... @@ -60,8 +74,16 @@ class SpeakerEmbeddingManager {
*/
bool Verify(const std::string &name, const float *p, float threshold) const;
// Return true if the given speaker already exists; return false otherwise.
bool Contains(const std::string &name) const;
int32_t NumSpeakers() const;
int32_t Dim() const;
// Return a list of speaker names
std::vector<std::string> GetAllSpeakers() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
... ...
... ... @@ -27,6 +27,8 @@
#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"
... ... @@ -208,6 +210,85 @@ class SherpaOnnxKws {
int32_t input_sample_rate_ = -1;
};
class SherpaOnnxSpeakerEmbeddingExtractorStream {
public:
explicit SherpaOnnxSpeakerEmbeddingExtractorStream(
std::unique_ptr<OnlineStream> stream)
: stream_(std::move(stream)) {}
void AcceptWaveform(int32_t sample_rate, const float *samples,
int32_t n) const {
stream_->AcceptWaveform(sample_rate, samples, n);
}
void InputFinished() const { stream_->InputFinished(); }
OnlineStream *Get() const { return stream_.get(); }
private:
std::unique_ptr<OnlineStream> stream_;
};
class SherpaOnnxSpeakerEmbeddingExtractor {
public:
#if __ANDROID_API__ >= 9
SherpaOnnxSpeakerEmbeddingExtractor(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
: extractor_(mgr, config) {}
#endif
explicit SherpaOnnxSpeakerEmbeddingExtractor(
const SpeakerEmbeddingExtractorConfig &config)
: extractor_(config) {}
int32_t Dim() const { return extractor_.Dim(); }
bool IsReady(const SherpaOnnxSpeakerEmbeddingExtractorStream *stream) const {
return extractor_.IsReady(stream->Get());
}
SherpaOnnxSpeakerEmbeddingExtractorStream *CreateStream() const {
return new SherpaOnnxSpeakerEmbeddingExtractorStream(
extractor_.CreateStream());
}
std::vector<float> Compute(
const SherpaOnnxSpeakerEmbeddingExtractorStream *stream) const {
return extractor_.Compute(stream->Get());
}
private:
SpeakerEmbeddingExtractor extractor_;
};
static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig(
JNIEnv *env, jobject config) {
SpeakerEmbeddingExtractorConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid = env->GetFieldID(cls, "model", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "numThreads", "I");
ans.num_threads = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "debug", "Z");
ans.debug = env->GetBooleanField(config, fid);
fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.provider = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
OnlineRecognizerConfig ans;
... ... @@ -772,6 +853,334 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_new(JNIEnv *env,
jobject /*obj*/,
jobject asset_manager,
jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
SHERPA_ONNX_LOGE("new config:\n%s", config.ToString().c_str());
auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)extractor;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
SHERPA_ONNX_LOGE("newFromFile config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
}
auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor(config);
return (jlong)extractor;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_delete(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
delete reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(
ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_createStream(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto stream =
reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr)
->CreateStream();
return (jlong)stream;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_isReady(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jlong stream_ptr) {
auto extractor =
reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);
auto stream = reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(stream_ptr);
return extractor->IsReady(stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jfloatArray JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_compute(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jlong stream_ptr) {
auto extractor =
reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);
auto stream = reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(stream_ptr);
std::vector<float> embedding = extractor->Compute(stream);
jfloatArray embedding_arr = env->NewFloatArray(embedding.size());
env->SetFloatArrayRegion(embedding_arr, 0, embedding.size(),
embedding.data());
return embedding_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_dim(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto extractor =
reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);
return extractor->Dim();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jint sample_rate) {
auto stream = reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
stream->AcceptWaveform(sample_rate, p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_inputFinished(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto stream = reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);
stream->InputFinished();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_new(
JNIEnv *env, jobject /*obj*/, jint dim) {
auto p = new sherpa_onnx::SpeakerEmbeddingManager(dim);
return (jlong)p;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_delete(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
delete manager;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_add(JNIEnv *env,
jobject /*obj*/,
jlong ptr, jstring name,
jfloatArray embedding) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Add(p_name, p);
env->ReleaseStringUTFChars(name, p_name);
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_addList(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
jobjectArray embedding_arr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
int num_embeddings = env->GetArrayLength(embedding_arr);
if (num_embeddings == 0) {
return false;
}
std::vector<std::vector<float>> embedding_list;
embedding_list.reserve(num_embeddings);
for (int32_t i = 0; i != num_embeddings; ++i) {
jfloatArray embedding =
(jfloatArray)env->GetObjectArrayElement(embedding_arr, i);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("i: %d. Expected dim %d, given %d", i, manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
embedding_list.push_back({p, p + n});
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
}
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Add(p_name, embedding_list);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_remove(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jstring name) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Remove(p_name);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_search(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jfloatArray embedding,
jfloat threshold) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
std::string name = manager->Search(p, threshold);
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
return env->NewStringUTF(name.c_str());
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_verify(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
jfloatArray embedding, jfloat threshold) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Verify(p_name, p, threshold);
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_contains(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jstring name) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Contains(p_name);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_numSpeakers(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
return manager->NumSpeakers();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
std::vector<std::string> all_speakers = manager->GetAllSpeakers();
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
all_speakers.size(), env->FindClass("java/lang/String"), nullptr);
int32_t i = 0;
for (auto &s : all_speakers) {
jstring js = env->NewStringUTF(s.c_str());
env->SetObjectArrayElement(obj_arr, i, js);
++i;
}
return obj_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
... ... @@ -783,10 +1192,6 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new(
auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Erros found in config!");
}
auto tts = new sherpa_onnx::SherpaOnnxOfflineTts(
#if __ANDROID_API__ >= 9
mgr,
... ... @@ -801,6 +1206,11 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
}
auto tts = new sherpa_onnx::SherpaOnnxOfflineTts(config);
return (jlong)tts;
... ...
... ... @@ -17,6 +17,14 @@ void PybindSpeakerEmbeddingManager(py::module *m) {
.def(py::init<int32_t>(), py::arg("dim"),
py::call_guard<py::gil_scoped_release>())
.def_property_readonly("num_speakers", &PyClass::NumSpeakers)
.def_property_readonly("dim", &PyClass::Dim)
.def_property_readonly("all_speakers", &PyClass::GetAllSpeakers)
.def(
"__contains__",
[](const PyClass &self, const std::string &name) -> bool {
return self.Contains(name);
},
py::arg("name"), py::call_guard<py::gil_scoped_release>())
.def(
"add",
[](const PyClass &self, const std::string &name,
... ... @@ -26,6 +34,14 @@ void PybindSpeakerEmbeddingManager(py::module *m) {
py::arg("name"), py::arg("v"),
py::call_guard<py::gil_scoped_release>())
.def(
"add",
[](const PyClass &self, const std::string &name,
const std::vector<std::vector<float>> &embedding_list) -> bool {
return self.Add(name, embedding_list);
},
py::arg("name"), py::arg("embedding_list"),
py::call_guard<py::gil_scoped_release>())
.def(
"remove",
[](const PyClass &self, const std::string &name) -> bool {
return self.Remove(name);
... ...