Fangjun Kuang
Committed by GitHub

Add KWS examples for Java API (#930)

@@ -107,6 +107,13 @@ jobs: @@ -107,6 +107,13 @@ jobs:
107 make -j4 107 make -j4
108 ls -lh lib 108 ls -lh lib
109 109
  110 + - name: Run java test (kws)
  111 + shell: bash
  112 + run: |
  113 + cd ./java-api-examples
  114 + ./run-kws-from-file.sh
  115 + rm -rf sherpa-onnx-*
  116 +
110 - name: Run java test (VAD + Non-streaming Paraformer) 117 - name: Run java test (VAD + Non-streaming Paraformer)
111 shell: bash 118 shell: bash
112 run: | 119 run: |
  1 +// Copyright 2024 Xiaomi Corporation
  2 +
  3 +// This file shows how to use a keyword spotter model to spot keywords from
  4 +// a file.
  5 +
  6 +import com.k2fsa.sherpa.onnx.*;
  7 +
  8 +public class KyewordSpotterFromFile {
  9 + public static void main(String[] args) {
  10 + // please download test files from https://github.com/k2-fsa/sherpa-onnx/releases/tag/kws-models
  11 + String encoder =
  12 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx";
  13 + String decoder =
  14 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx";
  15 + String joiner =
  16 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx";
  17 + String tokens = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt";
  18 +
  19 + String keywordsFile =
  20 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt";
  21 +
  22 + String waveFilename = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav";
  23 +
  24 + OnlineTransducerModelConfig transducer =
  25 + OnlineTransducerModelConfig.builder()
  26 + .setEncoder(encoder)
  27 + .setDecoder(decoder)
  28 + .setJoiner(joiner)
  29 + .build();
  30 +
  31 + OnlineModelConfig modelConfig =
  32 + OnlineModelConfig.builder()
  33 + .setTransducer(transducer)
  34 + .setTokens(tokens)
  35 + .setNumThreads(1)
  36 + .setDebug(true)
  37 + .build();
  38 +
  39 + KeywordSpotterConfig config =
  40 + KeywordSpotterConfig.builder()
  41 + .setOnlineModelConfig(modelConfig)
  42 + .setKeywordsFile(keywordsFile)
  43 + .build();
  44 +
  45 + KeywordSpotter kws = new KeywordSpotter(config);
  46 + OnlineStream stream = kws.createStream();
  47 +
  48 + WaveReader reader = new WaveReader(waveFilename);
  49 +
  50 + stream.acceptWaveform(reader.getSamples(), reader.getSampleRate());
  51 +
  52 + float[] tailPaddings = new float[(int) (0.8 * reader.getSampleRate())];
  53 + stream.acceptWaveform(tailPaddings, reader.getSampleRate());
  54 + while (kws.isReady(stream)) {
  55 + kws.decode(stream);
  56 +
  57 + String keyword = kws.getResult(stream).getKeyword();
  58 + if (!keyword.isEmpty()) {
  59 + System.out.printf("Detected keyword: %s\n", keyword);
  60 + }
  61 + }
  62 +
  63 + kws.release();
  64 + }
  65 +}
@@ -68,3 +68,9 @@ The punctuation model supports both English and Chinese. @@ -68,3 +68,9 @@ The punctuation model supports both English and Chinese.
68 ```bash 68 ```bash
69 ./run-vad-non-streaming-paraformer.sh 69 ./run-vad-non-streaming-paraformer.sh
70 ``` 70 ```
  71 +
  72 +## Keyword spotter
  73 +
  74 +```bash
  75 +./run-kws-from-file.sh
  76 +```
@@ -91,6 +91,7 @@ public class VadNonStreamingParaformer { @@ -91,6 +91,7 @@ public class VadNonStreamingParaformer {
91 stream.acceptWaveform(segment.getSamples(), 16000); 91 stream.acceptWaveform(segment.getSamples(), 16000);
92 recognizer.decode(stream); 92 recognizer.decode(stream);
93 String text = recognizer.getResult(stream).getText(); 93 String text = recognizer.getResult(stream).getText();
  94 + stream.release();
94 95
95 if (!text.isEmpty()) { 96 if (!text.isEmpty()) {
96 System.out.printf("%.3f--%.3f: %s\n", startTime, startTime + duration, text); 97 System.out.printf("%.3f--%.3f: %s\n", startTime, startTime + duration, text);
@@ -100,5 +101,8 @@ public class VadNonStreamingParaformer { @@ -100,5 +101,8 @@ public class VadNonStreamingParaformer {
100 } 101 }
101 } 102 }
102 } 103 }
  104 +
  105 + vad.release();
  106 + recognizer.release();
103 } 107 }
104 } 108 }
@@ -75,5 +75,7 @@ public class VadRemoveSilence { @@ -75,5 +75,7 @@ public class VadRemoveSilence {
75 String outFilename = "lei-jun-test-no-silence.wav"; 75 String outFilename = "lei-jun-test-no-silence.wav";
76 WaveWriter.write(outFilename, allSamples, 16000); 76 WaveWriter.write(outFilename, allSamples, 16000);
77 System.out.printf("Saved to %s\n", outFilename); 77 System.out.printf("Saved to %s\n", outFilename);
  78 +
  79 + vad.release();
78 } 80 }
79 } 81 }
  1 +#!/usr/bin/env bash
  2 +
  3 +set -ex
  4 +
  5 +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
  6 + mkdir -p ../build
  7 + pushd ../build
  8 + cmake \
  9 + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
  10 + -DSHERPA_ONNX_ENABLE_TESTS=OFF \
  11 + -DSHERPA_ONNX_ENABLE_CHECK=OFF \
  12 + -DBUILD_SHARED_LIBS=ON \
  13 + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
  14 + -DSHERPA_ONNX_ENABLE_JNI=ON \
  15 + ..
  16 +
  17 + make -j4
  18 + ls -lh lib
  19 + popd
  20 +fi
  21 +
  22 +if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then
  23 + pushd ../sherpa-onnx/java-api
  24 + make
  25 + popd
  26 +fi
  27 +
  28 +if [ ! -f ./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt ]; then
  29 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
  30 + tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
  31 + rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
  32 +fi
  33 +
  34 +java \
  35 + -Djava.library.path=$PWD/../build/lib \
  36 + -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \
  37 + ./KeywordSpotterFromFile.java
@@ -62,6 +62,10 @@ java_files += VadModelConfig.java @@ -62,6 +62,10 @@ java_files += VadModelConfig.java
62 java_files += SpeechSegment.java 62 java_files += SpeechSegment.java
63 java_files += Vad.java 63 java_files += Vad.java
64 64
  65 +java_files += KeywordSpotterConfig.java
  66 +java_files += KeywordSpotterResult.java
  67 +java_files += KeywordSpotter.java
  68 +
65 class_files := $(java_files:%.java=%.class) 69 class_files := $(java_files:%.java=%.class)
66 70
67 java_files := $(addprefix src/$(package_dir)/,$(java_files)) 71 java_files := $(addprefix src/$(package_dir)/,$(java_files))
  1 +// Copyright 2024 Xiaomi Corporation
  2 +
  3 +package com.k2fsa.sherpa.onnx;
  4 +
  5 +public class KeywordSpotter {
  6 + static {
  7 + System.loadLibrary("sherpa-onnx-jni");
  8 + }
  9 +
  10 + private long ptr = 0;
  11 +
  12 + public KeywordSpotter(KeywordSpotterConfig config) {
  13 + ptr = newFromFile(config);
  14 + }
  15 +
  16 + public OnlineStream createStream(String keywords) {
  17 + long p = createStream(ptr, keywords);
  18 + return new OnlineStream(p);
  19 + }
  20 +
  21 + public OnlineStream createStream() {
  22 + long p = createStream(ptr, "");
  23 + return new OnlineStream(p);
  24 + }
  25 +
  26 + public void decode(OnlineStream s) {
  27 + decode(ptr, s.getPtr());
  28 + }
  29 +
  30 + public boolean isReady(OnlineStream s) {
  31 + return isReady(ptr, s.getPtr());
  32 + }
  33 +
  34 + public KeywordSpotterResult getResult(OnlineStream s) {
  35 + Object[] arr = getResult(ptr, s.getPtr());
  36 + String keyword = (String) arr[0];
  37 + String[] tokens = (String[]) arr[1];
  38 + float[] timestamps = (float[]) arr[2];
  39 + return new KeywordSpotterResult(keyword, tokens, timestamps);
  40 + }
  41 +
  42 + protected void finalize() throws Throwable {
  43 + release();
  44 + }
  45 +
  46 + // You'd better call it manually if it is not used anymore
  47 + public void release() {
  48 + if (this.ptr == 0) {
  49 + return;
  50 + }
  51 + delete(this.ptr);
  52 + this.ptr = 0;
  53 + }
  54 +
  55 + private native long newFromFile(KeywordSpotterConfig config);
  56 +
  57 + private native void delete(long ptr);
  58 +
  59 + private native long createStream(long ptr, String keywords);
  60 +
  61 + private native void decode(long ptr, long streamPtr);
  62 +
  63 + private native boolean isReady(long ptr, long streamPtr);
  64 +
  65 + private native Object[] getResult(long ptr, long streamPtr);
  66 +}
  1 +// Copyright 2024 Xiaomi Corporation
  2 +
  3 +package com.k2fsa.sherpa.onnx;
  4 +
  5 +public class KeywordSpotterConfig {
  6 + private final FeatureConfig featConfig;
  7 + private final OnlineModelConfig modelConfig;
  8 +
  9 + private final int maxActivePaths;
  10 + private final String keywordsFile;
  11 + private final float keywordsScore;
  12 + private final float keywordsThreshold;
  13 + private final int numTrailingBlanks;
  14 +
  15 + private KeywordSpotterConfig(Builder builder) {
  16 + this.featConfig = builder.featConfig;
  17 + this.modelConfig = builder.modelConfig;
  18 + this.maxActivePaths = builder.maxActivePaths;
  19 + this.keywordsFile = builder.keywordsFile;
  20 + this.keywordsScore = builder.keywordsScore;
  21 + this.keywordsThreshold = builder.keywordsThreshold;
  22 + this.numTrailingBlanks = builder.numTrailingBlanks;
  23 + }
  24 +
  25 + public static Builder builder() {
  26 + return new Builder();
  27 + }
  28 +
  29 + public static class Builder {
  30 + private FeatureConfig featConfig = FeatureConfig.builder().build();
  31 + private OnlineModelConfig modelConfig = OnlineModelConfig.builder().build();
  32 + private int maxActivePaths = 4;
  33 + private String keywordsFile = "keywords.txt";
  34 + private float keywordsScore = 1.5f;
  35 + private float keywordsThreshold = 0.25f;
  36 + private int numTrailingBlanks = 2;
  37 +
  38 + public KeywordSpotterConfig build() {
  39 + return new KeywordSpotterConfig(this);
  40 + }
  41 +
  42 + public Builder setFeatureConfig(FeatureConfig featConfig) {
  43 + this.featConfig = featConfig;
  44 + return this;
  45 + }
  46 +
  47 + public Builder setOnlineModelConfig(OnlineModelConfig modelConfig) {
  48 + this.modelConfig = modelConfig;
  49 + return this;
  50 + }
  51 +
  52 + public Builder setMaxActivePaths(int maxActivePaths) {
  53 + this.maxActivePaths = maxActivePaths;
  54 + return this;
  55 + }
  56 +
  57 + public Builder setKeywordsFile(String keywordsFile) {
  58 + this.keywordsFile = keywordsFile;
  59 + return this;
  60 + }
  61 +
  62 + public Builder setKeywordsScore(float keywordsScore) {
  63 + this.keywordsScore = keywordsScore;
  64 + return this;
  65 + }
  66 +
  67 + public Builder setKeywordsThreshold(float keywordsThreshold) {
  68 + this.keywordsThreshold = keywordsThreshold;
  69 + return this;
  70 + }
  71 +
  72 + public Builder setNumTrailingBlanks(int numTrailingBlanks) {
  73 + this.numTrailingBlanks = numTrailingBlanks;
  74 + return this;
  75 + }
  76 + }
  77 +}
  1 +// Copyright 2024 Xiaomi Corporation
  2 +
  3 +package com.k2fsa.sherpa.onnx;
  4 +
  5 +public class KeywordSpotterResult {
  6 + private final String keyword;
  7 + private final String[] tokens;
  8 + private final float[] timestamps;
  9 +
  10 + public KeywordSpotterResult(String keyword, String[] tokens, float[] timestamps) {
  11 + this.keyword = keyword;
  12 + this.tokens = tokens;
  13 + this.timestamps = timestamps;
  14 + }
  15 +
  16 + public String getKeyword() {
  17 + return keyword;
  18 + }
  19 +
  20 + public String[] getTokens() {
  21 + return tokens;
  22 + }
  23 +
  24 + public float[] getTimestamps() {
  25 + return timestamps;
  26 + }
  27 +}
@@ -10,7 +10,6 @@ public class OnlineRecognizer { @@ -10,7 +10,6 @@ public class OnlineRecognizer {
10 10
11 private long ptr = 0; 11 private long ptr = 0;
12 12
13 -  
14 public OnlineRecognizer(OnlineRecognizerConfig config) { 13 public OnlineRecognizer(OnlineRecognizerConfig config) {
15 ptr = newFromFile(config); 14 ptr = newFromFile(config);
16 } 15 }
@@ -19,7 +18,6 @@ public class OnlineRecognizer { @@ -19,7 +18,6 @@ public class OnlineRecognizer {
19 decode(ptr, s.getPtr()); 18 decode(ptr, s.getPtr());
20 } 19 }
21 20
22 -  
23 public boolean isReady(OnlineStream s) { 21 public boolean isReady(OnlineStream s) {
24 return isReady(ptr, s.getPtr()); 22 return isReady(ptr, s.getPtr());
25 } 23 }