Fangjun Kuang
Committed by GitHub

Add Java API for speaker identification (#822)

@@ -106,6 +106,15 @@ jobs: @@ -106,6 +106,15 @@ jobs:
106 make -j4 106 make -j4
107 ls -lh lib 107 ls -lh lib
108 108
  109 + - name: Run java test (speaker identification)
  110 + shell: bash
  111 + run: |
  112 + cd ./java-api-examples
  113 + ./run-speaker-identification.sh
  114 + # Delete model files to save space
  115 + rm -rf *.onnx
  116 + rm -rf sr-data
  117 +
109 - name: Run java test (audio tagging) 118 - name: Run java test (audio tagging)
110 shell: bash 119 shell: bash
111 run: | 120 run: |
@@ -50,3 +50,9 @@ The punctuation model supports both English and Chinese. @@ -50,3 +50,9 @@ The punctuation model supports both English and Chinese.
50 ./run-audio-tagging-zipformer-from-file.sh 50 ./run-audio-tagging-zipformer-from-file.sh
51 ./run-audio-tagging-ced-from-file.sh 51 ./run-audio-tagging-ced-from-file.sh
52 ``` 52 ```
  53 +
  54 +## Speaker identification
  55 +
  56 +```bash
  57 +./run-speaker-identification.sh
  58 +```
  1 +// Copyright 2024 Xiaomi Corporation
  2 +
  3 +// This file shows how to use a speaker embedding extractor model for speaker
  4 +// identification.
  5 +import com.k2fsa.sherpa.onnx.*;
  6 +
  7 +public class SpeakerIdentification {
  8 + public static float[] computeEmbedding(SpeakerEmbeddingExtractor extractor, String filename) {
  9 + WaveReader reader = new WaveReader(filename);
  10 +
  11 + OnlineStream stream = extractor.createStream();
  12 + stream.acceptWaveform(reader.getSamples(), reader.getSampleRate());
  13 + stream.inputFinished();
  14 +
  15 + float[] embedding = extractor.compute(stream);
  16 + stream.release();
  17 +
  18 + return embedding;
  19 + }
  20 +
  21 + public static void main(String[] args) {
  22 + // Please download the model from
  23 + // https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
  24 + String model = "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx";
  25 + SpeakerEmbeddingExtractorConfig config =
  26 + SpeakerEmbeddingExtractorConfig.builder()
  27 + .setModel(model)
  28 + .setNumThreads(1)
  29 + .setDebug(true)
  30 + .build();
  31 + SpeakerEmbeddingExtractor extractor = new SpeakerEmbeddingExtractor(config);
  32 + SpeakerEmbeddingManager manager = new SpeakerEmbeddingManager(extractor.getDim());
  33 +
  34 + String[] spk1Files =
  35 + new String[] {
  36 + "./sr-data/enroll/fangjun-sr-1.wav",
  37 + "./sr-data/enroll/fangjun-sr-2.wav",
  38 + "./sr-data/enroll/fangjun-sr-3.wav",
  39 + };
  40 +
  41 + float[][] spk1Vec = new float[spk1Files.length][];
  42 +
  43 + for (int i = 0; i < spk1Files.length; ++i) {
  44 + spk1Vec[i] = computeEmbedding(extractor, spk1Files[i]);
  45 + }
  46 +
  47 + String[] spk2Files =
  48 + new String[] {
  49 + "./sr-data/enroll/leijun-sr-1.wav", "./sr-data/enroll/leijun-sr-2.wav",
  50 + };
  51 +
  52 + float[][] spk2Vec = new float[spk2Files.length][];
  53 +
  54 + for (int i = 0; i < spk2Files.length; ++i) {
  55 + spk2Vec[i] = computeEmbedding(extractor, spk2Files[i]);
  56 + }
  57 +
  58 + if (!manager.add("fangjun", spk1Vec)) {
  59 + System.out.println("Failed to register fangjun");
  60 + return;
  61 + }
  62 +
  63 + if (!manager.add("leijun", spk2Vec)) {
  64 + System.out.println("Failed to register leijun");
  65 + return;
  66 + }
  67 +
  68 + if (manager.getNumSpeakers() != 2) {
  69 + System.out.println("There should be two speakers");
  70 + return;
  71 + }
  72 +
  73 + if (!manager.contains("fangjun")) {
  74 + System.out.println("It should contain the speaker fangjun");
  75 + return;
  76 + }
  77 +
  78 + if (!manager.contains("leijun")) {
  79 + System.out.println("It should contain the speaker leijun");
  80 + return;
  81 + }
  82 +
  83 + System.out.println("---All speakers---");
  84 + String[] allSpeakers = manager.getAllSpeakerNames();
  85 + for (String s : allSpeakers) {
  86 + System.out.println(s);
  87 + }
  88 + System.out.println("------------");
  89 +
  90 + String[] testFiles =
  91 + new String[] {
  92 + "./sr-data/test/fangjun-test-sr-1.wav",
  93 + "./sr-data/test/leijun-test-sr-1.wav",
  94 + "./sr-data/test/liudehua-test-sr-1.wav"
  95 + };
  96 +
  97 + float threshold = 0.6f;
  98 + for (String file : testFiles) {
  99 + float[] embedding = computeEmbedding(extractor, file);
  100 +
  101 + String name = manager.search(embedding, threshold);
  102 + if (name.isEmpty()) {
  103 + name = "<Unknown>";
  104 + }
  105 + System.out.printf("%s: %s\n", file, name);
  106 + }
  107 +
  108 + // test verify
  109 + if (!manager.verify("fangjun", computeEmbedding(extractor, testFiles[0]), threshold)) {
  110 + System.out.printf("testFiles[0] should match fangjun!");
  111 + return;
  112 + }
  113 +
  114 + if (!manager.remove("fangjun")) {
  115 + System.out.println("Failed to remove fangjun");
  116 + return;
  117 + }
  118 +
  119 + if (manager.verify("fangjun", computeEmbedding(extractor, testFiles[0]), threshold)) {
  120 + System.out.printf("%s should match no one!\n", testFiles[0]);
  121 + return;
  122 + }
  123 +
  124 + if (manager.getNumSpeakers() != 1) {
  125 + System.out.println("There should only 1 speaker left.");
  126 + return;
  127 + }
  128 +
  129 + extractor.release();
  130 + manager.release();
  131 + }
  132 +}
  1 +#!/usr/bin/env bash
  2 +
  3 +set -ex
  4 +
  5 +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
  6 + mkdir -p ../build
  7 + pushd ../build
  8 + cmake \
  9 + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
  10 + -DSHERPA_ONNX_ENABLE_TESTS=OFF \
  11 + -DSHERPA_ONNX_ENABLE_CHECK=OFF \
  12 + -DBUILD_SHARED_LIBS=ON \
  13 + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
  14 + -DSHERPA_ONNX_ENABLE_JNI=ON \
  15 + ..
  16 +
  17 + make -j4
  18 + ls -lh lib
  19 + popd
  20 +fi
  21 +
  22 +if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then
  23 + pushd ../sherpa-onnx/java-api
  24 + make
  25 + popd
  26 +fi
  27 +
  28 +if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then
  29 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
  30 +fi
  31 +
  32 +if [ ! -f ./sr-data/enroll/leijun-sr-1.wav ]; then
  33 + curl -SL -o sr-data.tar.gz https://github.com/csukuangfj/sr-data/archive/refs/tags/v1.0.0.tar.gz
  34 + tar xvf sr-data.tar.gz
  35 + mv sr-data-1.0.0 sr-data
  36 +fi
  37 +
  38 +java \
  39 + -Djava.library.path=$PWD/../build/lib \
  40 + -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \
  41 + ./SpeakerIdentification.java
@@ -51,6 +51,10 @@ java_files += AudioTaggingConfig.java @@ -51,6 +51,10 @@ java_files += AudioTaggingConfig.java
51 java_files += AudioEvent.java 51 java_files += AudioEvent.java
52 java_files += AudioTagging.java 52 java_files += AudioTagging.java
53 53
  54 +java_files += SpeakerEmbeddingExtractorConfig.java
  55 +java_files += SpeakerEmbeddingExtractor.java
  56 +java_files += SpeakerEmbeddingManager.java
  57 +
54 class_files := $(java_files:%.java=%.class) 58 class_files := $(java_files:%.java=%.class)
55 59
56 java_files := $(addprefix src/$(package_dir)/,$(java_files)) 60 java_files := $(addprefix src/$(package_dir)/,$(java_files))
@@ -14,7 +14,7 @@ public class AudioTaggingConfig { @@ -14,7 +14,7 @@ public class AudioTaggingConfig {
14 } 14 }
15 15
16 public static Builder builder() { 16 public static Builder builder() {
17 - return new AudioTaggingConfig.Builder(); 17 + return new Builder();
18 } 18 }
19 19
20 public static class Builder { 20 public static class Builder {
@@ -7,7 +7,7 @@ public class OfflineRecognizer { @@ -7,7 +7,7 @@ public class OfflineRecognizer {
7 System.loadLibrary("sherpa-onnx-jni"); 7 System.loadLibrary("sherpa-onnx-jni");
8 } 8 }
9 9
10 - private long ptr = 0; // this is the asr engine ptrss 10 + private long ptr = 0;
11 11
12 public OfflineRecognizer(OfflineRecognizerConfig config) { 12 public OfflineRecognizer(OfflineRecognizerConfig config) {
13 ptr = newFromFile(config); 13 ptr = newFromFile(config);
@@ -7,7 +7,7 @@ public class OfflineTts { @@ -7,7 +7,7 @@ public class OfflineTts {
7 System.loadLibrary("sherpa-onnx-jni"); 7 System.loadLibrary("sherpa-onnx-jni");
8 } 8 }
9 9
10 - private long ptr = 0; // this is the asr engine ptrss 10 + private long ptr = 0;
11 11
12 public OfflineTts(OfflineTtsConfig config) { 12 public OfflineTts(OfflineTtsConfig config) {
13 ptr = newFromFile(config); 13 ptr = newFromFile(config);
@@ -8,7 +8,7 @@ public class OnlineRecognizer { @@ -8,7 +8,7 @@ public class OnlineRecognizer {
8 System.loadLibrary("sherpa-onnx-jni"); 8 System.loadLibrary("sherpa-onnx-jni");
9 } 9 }
10 10
11 - private long ptr = 0; // this is the asr engine ptrss 11 + private long ptr = 0;
12 12
13 13
14 public OnlineRecognizer(OnlineRecognizerConfig config) { 14 public OnlineRecognizer(OnlineRecognizerConfig config) {
  1 +// Copyright 2024 Xiaomi Corporation
  2 +
  3 +package com.k2fsa.sherpa.onnx;
  4 +
  5 +public class SpeakerEmbeddingExtractor {
  6 + static {
  7 + System.loadLibrary("sherpa-onnx-jni");
  8 + }
  9 +
  10 + private long ptr = 0;
  11 +
  12 + public SpeakerEmbeddingExtractor(SpeakerEmbeddingExtractorConfig config) {
  13 + ptr = newFromFile(config);
  14 + }
  15 +
  16 + @Override
  17 + protected void finalize() throws Throwable {
  18 + release();
  19 + }
  20 +
  21 + public void release() {
  22 + if (this.ptr == 0) {
  23 + return;
  24 + }
  25 + delete(this.ptr);
  26 + this.ptr = 0;
  27 + }
  28 +
  29 + public OnlineStream createStream() {
  30 + long p = createStream(ptr);
  31 + return new OnlineStream(p);
  32 + }
  33 +
  34 + public boolean isReady(OnlineStream s) {
  35 + return isReady(ptr, s.getPtr());
  36 + }
  37 +
  38 + public float[] compute(OnlineStream s) {
  39 + return compute(ptr, s.getPtr());
  40 + }
  41 +
  42 + public int getDim() {
  43 + return dim(ptr);
  44 + }
  45 +
  46 + private native void delete(long ptr);
  47 +
  48 + private native long newFromFile(SpeakerEmbeddingExtractorConfig config);
  49 +
  50 + private native long createStream(long ptr);
  51 +
  52 + private native boolean isReady(long ptr, long streamPtr);
  53 +
  54 + private native float[] compute(long ptr, long streamPtr);
  55 +
  56 + private native int dim(long ptr);
  57 +}
  1 +// Copyright 2024 Xiaomi Corporation
  2 +
  3 +package com.k2fsa.sherpa.onnx;
  4 +
  5 +public class SpeakerEmbeddingExtractorConfig {
  6 + private final String model;
  7 + private final int numThreads;
  8 + private final boolean debug;
  9 + private final String provider;
  10 +
  11 + private SpeakerEmbeddingExtractorConfig(Builder builder) {
  12 + this.model = builder.model;
  13 + this.numThreads = builder.numThreads;
  14 + this.debug = builder.debug;
  15 + this.provider = builder.provider;
  16 + }
  17 +
  18 + public static Builder builder() {
  19 + return new Builder();
  20 + }
  21 +
  22 + public static class Builder {
  23 + private String model = "";
  24 + private int numThreads = 1;
  25 + private boolean debug = true;
  26 + private String provider = "cpu";
  27 +
  28 + public SpeakerEmbeddingExtractorConfig build() {
  29 + return new SpeakerEmbeddingExtractorConfig(this);
  30 + }
  31 +
  32 +
  33 + public Builder setModel(String model) {
  34 + this.model = model;
  35 + return this;
  36 + }
  37 +
  38 + public Builder setNumThreads(int numThreads) {
  39 + this.numThreads = numThreads;
  40 + return this;
  41 + }
  42 +
  43 + public Builder setDebug(boolean debug) {
  44 + this.debug = debug;
  45 + return this;
  46 + }
  47 +
  48 + public Builder setProvider(String provider) {
  49 + this.provider = provider;
  50 + return this;
  51 + }
  52 + }
  53 +
  54 +}
  1 +// Copyright 2024 Xiaomi Corporation
  2 +
  3 +package com.k2fsa.sherpa.onnx;
  4 +
  5 +public class SpeakerEmbeddingManager {
  6 + static {
  7 + System.loadLibrary("sherpa-onnx-jni");
  8 + }
  9 +
  10 + private long ptr = 0;
  11 +
  12 + public SpeakerEmbeddingManager(int dim) {
  13 + ptr = create(dim);
  14 + }
  15 +
  16 + @Override
  17 + protected void finalize() throws Throwable {
  18 + release();
  19 + }
  20 +
  21 + public void release() {
  22 + if (this.ptr == 0) {
  23 + return;
  24 + }
  25 + delete(this.ptr);
  26 + this.ptr = 0;
  27 + }
  28 +
  29 + public boolean add(String name, float[] embedding) {
  30 + return add(ptr, name, embedding);
  31 + }
  32 +
  33 + public boolean add(String name, float[][] embedding) {
  34 + return addList(ptr, name, embedding);
  35 + }
  36 +
  37 + public boolean remove(String name) {
  38 + return remove(ptr, name);
  39 + }
  40 +
  41 + public String search(float[] embedding, float threshold) {
  42 + return search(ptr, embedding, threshold);
  43 + }
  44 +
  45 + public boolean verify(String name, float[] embedding, float threshold) {
  46 + return verify(ptr, name, embedding, threshold);
  47 + }
  48 +
  49 + public boolean contains(String name) {
  50 + return contains(ptr, name);
  51 + }
  52 +
  53 + public int getNumSpeakers() {
  54 + return numSpeakers(ptr);
  55 + }
  56 +
  57 + public String[] getAllSpeakerNames() {
  58 + return allSpeakerNames(ptr);
  59 + }
  60 +
  61 + private native long create(int dim);
  62 +
  63 + private native void delete(long ptr);
  64 +
  65 + private native boolean add(long ptr, String name, float[] embedding);
  66 +
  67 + private native boolean addList(long ptr, String name, float[][] embedding);
  68 +
  69 + private native boolean remove(long ptr, String name);
  70 +
  71 + private native String search(long ptr, float[] embedding, float threshold);
  72 +
  73 + private native boolean verify(long ptr, String name, float[] embedding, float threshold);
  74 +
  75 + private native boolean contains(long ptr, String name);
  76 +
  77 + private native int numSpeakers(long ptr);
  78 +
  79 + private native String[] allSpeakerNames(long ptr);
  80 +}
@@ -12,7 +12,7 @@ public class SpokenLanguageIdentification { @@ -12,7 +12,7 @@ public class SpokenLanguageIdentification {
12 } 12 }
13 13
14 private final Map<String, String> localeMap; 14 private final Map<String, String> localeMap;
15 - private long ptr = 0; // this is the asr engine ptrss 15 + private long ptr = 0;
16 16
17 public SpokenLanguageIdentification(SpokenLanguageIdentificationConfig config) { 17 public SpokenLanguageIdentification(SpokenLanguageIdentificationConfig config) {
18 ptr = newFromFile(config); 18 ptr = newFromFile(config);