Fangjun Kuang
Committed by GitHub

Add Kotlin API for audio tagging (#770)

  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +import android.content.res.AssetManager
  4 +import android.util.Log
  5 +
  6 +private val TAG = "sherpa-onnx"
  7 +
  8 +data class OfflineZipformerAudioTaggingModelConfig (
  9 + val model: String,
  10 +)
  11 +
  12 +data class AudioTaggingModelConfig (
  13 + var zipformer: OfflineZipformerAudioTaggingModelConfig,
  14 + var numThreads: Int = 1,
  15 + var debug: Boolean = false,
  16 + var provider: String = "cpu",
  17 +)
  18 +
  19 +data class AudioTaggingConfig (
  20 + var model: AudioTaggingModelConfig,
  21 + var labels: String,
  22 + var topK: Int = 5,
  23 +)
  24 +
  25 +data class AudioEvent (
  26 + val name: String,
  27 + val index: Int,
  28 + val prob: Float,
  29 +)
  30 +
  31 +class AudioTagging(
  32 + assetManager: AssetManager? = null,
  33 + config: AudioTaggingConfig,
  34 +) {
  35 + private var ptr: Long
  36 +
  37 + init {
  38 + ptr = if (assetManager != null) {
  39 + newFromAsset(assetManager, config)
  40 + } else {
  41 + newFromFile(config)
  42 + }
  43 + }
  44 +
  45 + protected fun finalize() {
  46 + if(ptr != 0L) {
  47 + delete(ptr)
  48 + ptr = 0
  49 + }
  50 + }
  51 +
  52 + fun release() = finalize()
  53 +
  54 + fun createStream(): OfflineStream {
  55 + val p = createStream(ptr)
  56 + return OfflineStream(p)
  57 + }
  58 +
  59 + fun compute(stream: OfflineStream, topK: Int=-1): ArrayList<AudioEvent> {
  60 + var events :Array<Any> = compute(ptr, stream.ptr, topK)
  61 + val ans = ArrayList<AudioEvent>()
  62 +
  63 + for (e in events) {
  64 + val p :Array<Any> = e as Array<Any>
  65 + ans.add(AudioEvent(
  66 + name=p[0] as String,
  67 + index=p[1] as Int,
  68 + prob=p[2] as Float,
  69 + ))
  70 + }
  71 +
  72 + return ans
  73 + }
  74 +
  75 + private external fun newFromAsset(
  76 + assetManager: AssetManager,
  77 + config: AudioTaggingConfig,
  78 + ): Long
  79 +
  80 + private external fun newFromFile(
  81 + config: AudioTaggingConfig,
  82 + ): Long
  83 +
  84 + private external fun delete(ptr: Long)
  85 +
  86 + private external fun createStream(ptr: Long): Long
  87 +
  88 + private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array<Any>
  89 +
  90 + companion object {
  91 + init {
  92 + System.loadLibrary("sherpa-onnx-jni")
  93 + }
  94 + }
  95 +}
@@ -7,12 +7,56 @@ fun callback(samples: FloatArray): Unit { @@ -7,12 +7,56 @@ fun callback(samples: FloatArray): Unit {
7 } 7 }
8 8
9 fun main() { 9 fun main() {
  10 + testAudioTagging()
10 testSpeakerRecognition() 11 testSpeakerRecognition()
11 testTts() 12 testTts()
12 testAsr("transducer") 13 testAsr("transducer")
13 testAsr("zipformer2-ctc") 14 testAsr("zipformer2-ctc")
14 } 15 }
15 16
  17 +fun testAudioTagging() {
  18 + val config = AudioTaggingConfig(
  19 + model=AudioTaggingModelConfig(
  20 + zipformer=OfflineZipformerAudioTaggingModelConfig(
  21 + model="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx",
  22 + ),
  23 + numThreads=1,
  24 + debug=true,
  25 + provider="cpu",
  26 + ),
  27 + labels="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv",
  28 + topK=5,
  29 + )
  30 + val tagger = AudioTagging(assetManager=null, config=config)
  31 +
  32 + val testFiles = arrayOf(
  33 + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav",
  34 + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/2.wav",
  35 + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/3.wav",
  36 + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/4.wav",
  37 + )
  38 + println("----------")
  39 + for (waveFilename in testFiles) {
  40 + val stream = tagger.createStream()
  41 +
  42 + val objArray = WaveReader.readWaveFromFile(
  43 + filename = waveFilename,
  44 + )
  45 + val samples: FloatArray = objArray[0] as FloatArray
  46 + val sampleRate: Int = objArray[1] as Int
  47 +
  48 + stream.acceptWaveform(samples, sampleRate = sampleRate)
  49 + val events = tagger.compute(stream)
  50 + stream.release()
  51 +
  52 + println(waveFilename)
  53 + println(events)
  54 + println("----------")
  55 + }
  56 +
  57 + tagger.release()
  58 +}
  59 +
16 fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray { 60 fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray {
17 var objArray = WaveReader.readWaveFromFile( 61 var objArray = WaveReader.readWaveFromFile(
18 filename = filename, 62 filename = filename,
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +class OfflineStream(var ptr: Long) {
  4 + fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
  5 + acceptWaveform(ptr, samples, sampleRate)
  6 +
  7 + protected fun finalize() {
  8 + if(ptr != 0L) {
  9 + delete(ptr)
  10 + ptr = 0
  11 + }
  12 + }
  13 +
  14 + fun release() = finalize()
  15 +
  16 + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
  17 + private external fun delete(ptr: Long)
  18 +
  19 + companion object {
  20 + init {
  21 + System.loadLibrary("sherpa-onnx-jni")
  22 + }
  23 + }
  24 +}
@@ -4,8 +4,7 @@ @@ -4,8 +4,7 @@
4 # Note: This scripts runs only on Linux and macOS, though sherpa-onnx 4 # Note: This scripts runs only on Linux and macOS, though sherpa-onnx
5 # supports building JNI libs for Windows. 5 # supports building JNI libs for Windows.
6 6
7 -set -e  
8 - 7 +set -ex
9 8
10 cd .. 9 cd ..
11 mkdir -p build 10 mkdir -p build
@@ -29,59 +28,93 @@ export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH @@ -29,59 +28,93 @@ export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH
29 28
30 cd ../kotlin-api-examples 29 cd ../kotlin-api-examples
31 30
32 -if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then  
33 - wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx  
34 -fi  
35 -  
36 -if [ ! -f ./speaker1_a_cn_16k.wav ]; then  
37 - wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav  
38 -fi  
39 -  
40 -if [ ! -f ./speaker1_b_cn_16k.wav ]; then  
41 - wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav  
42 -fi  
43 -  
44 -if [ ! -f ./speaker2_a_cn_16k.wav ]; then  
45 - wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav  
46 -fi  
47 -  
48 -if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then  
49 - git lfs install  
50 - git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21  
51 -fi  
52 -  
53 -if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then  
54 - wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2  
55 - tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2  
56 - rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2  
57 -fi  
58 -  
59 -if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then  
60 - wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2  
61 - tar xf vits-piper-en_US-amy-low.tar.bz2  
62 - rm vits-piper-en_US-amy-low.tar.bz2  
63 -fi  
64 -  
65 -kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt Tts.kt Speaker.kt faked-log.kt 31 +function testSpeakerEmbeddingExtractor() {
  32 + if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then
  33 + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
  34 + fi
  35 +
  36 + if [ ! -f ./speaker1_a_cn_16k.wav ]; then
  37 + wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
  38 + fi
  39 +
  40 + if [ ! -f ./speaker1_b_cn_16k.wav ]; then
  41 + wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
  42 + fi
  43 +
  44 + if [ ! -f ./speaker2_a_cn_16k.wav ]; then
  45 + wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
  46 + fi
  47 +}
  48 +
  49 +function testAsr() {
  50 + if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
  51 + git lfs install
  52 + git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
  53 + fi
  54 +
  55 + if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
  56 + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
  57 + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
  58 + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
  59 + fi
  60 +}
  61 +
  62 +function testTts() {
  63 + if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then
  64 + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
  65 + tar xf vits-piper-en_US-amy-low.tar.bz2
  66 + rm vits-piper-en_US-amy-low.tar.bz2
  67 + fi
  68 +}
  69 +
  70 +function testAudioTagging() {
  71 + if [ ! -d sherpa-onnx-zipformer-audio-tagging-2024-04-09 ]; then
  72 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
  73 + tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
  74 + rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
  75 + fi
  76 +}
  77 +
  78 +function test() {
  79 + testAudioTagging
  80 + testSpeakerEmbeddingExtractor
  81 + testAsr
  82 + testTts
  83 +}
  84 +
  85 +test
  86 +
  87 +kotlinc-jvm -include-runtime -d main.jar \
  88 + AudioTagging.kt \
  89 + Main.kt \
  90 + OfflineStream.kt \
  91 + SherpaOnnx.kt \
  92 + Speaker.kt \
  93 + Tts.kt \
  94 + WaveReader.kt \
  95 + faked-asset-manager.kt \
  96 + faked-log.kt
66 97
67 ls -lh main.jar 98 ls -lh main.jar
68 99
69 java -Djava.library.path=../build/lib -jar main.jar 100 java -Djava.library.path=../build/lib -jar main.jar
70 101
71 -# For two-pass  
72 -  
73 -if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then  
74 - wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2  
75 - tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2  
76 - rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2  
77 -fi  
78 -  
79 -if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then  
80 - wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2  
81 - tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2  
82 - rm sherpa-onnx-whisper-tiny.en.tar.bz2  
83 -fi  
84 -  
85 -kotlinc-jvm -include-runtime -d 2pass.jar test-2pass.kt WaveReader.kt SherpaOnnx2Pass.kt faked-asset-manager.kt  
86 -ls -lh 2pass.jar  
87 -java -Djava.library.path=../build/lib -jar 2pass.jar 102 +function testTwoPass() {
  103 + if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then
  104 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
  105 + tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
  106 + rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
  107 + fi
  108 +
  109 + if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then
  110 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
  111 + tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
  112 + rm sherpa-onnx-whisper-tiny.en.tar.bz2
  113 + fi
  114 +
  115 + kotlinc-jvm -include-runtime -d 2pass.jar test-2pass.kt WaveReader.kt SherpaOnnx2Pass.kt faked-asset-manager.kt
  116 + ls -lh 2pass.jar
  117 + java -Djava.library.path=../build/lib -jar 2pass.jar
  118 +}
  119 +
  120 +testTwoPass
@@ -4,6 +4,11 @@ @@ -4,6 +4,11 @@
4 4
5 #include "sherpa-onnx/csrc/audio-tagging-impl.h" 5 #include "sherpa-onnx/csrc/audio-tagging-impl.h"
6 6
  7 +#if __ANDROID_API__ >= 9
  8 +#include "android/asset_manager.h"
  9 +#include "android/asset_manager_jni.h"
  10 +#endif
  11 +
7 #include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h" 12 #include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
8 #include "sherpa-onnx/csrc/macros.h" 13 #include "sherpa-onnx/csrc/macros.h"
9 14
@@ -20,4 +25,17 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create( @@ -20,4 +25,17 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
20 return nullptr; 25 return nullptr;
21 } 26 }
22 27
  28 +#if __ANDROID_API__ >= 9
  29 +std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
  30 + AAssetManager *mgr, const AudioTaggingConfig &config) {
  31 + if (!config.model.zipformer.model.empty()) {
  32 + return std::make_unique<AudioTaggingZipformerImpl>(mgr, config);
  33 + }
  34 +
  35 + SHERPA_ONNX_LOG(
  36 + "Please specify an audio tagging model! Return a null pointer");
  37 + return nullptr;
  38 +}
  39 +#endif
  40 +
23 } // namespace sherpa_onnx 41 } // namespace sherpa_onnx
@@ -7,6 +7,11 @@ @@ -7,6 +7,11 @@
7 #include <memory> 7 #include <memory>
8 #include <vector> 8 #include <vector>
9 9
  10 +#if __ANDROID_API__ >= 9
  11 +#include "android/asset_manager.h"
  12 +#include "android/asset_manager_jni.h"
  13 +#endif
  14 +
10 #include "sherpa-onnx/csrc/audio-tagging.h" 15 #include "sherpa-onnx/csrc/audio-tagging.h"
11 16
12 namespace sherpa_onnx { 17 namespace sherpa_onnx {
@@ -18,6 +23,11 @@ class AudioTaggingImpl { @@ -18,6 +23,11 @@ class AudioTaggingImpl {
18 static std::unique_ptr<AudioTaggingImpl> Create( 23 static std::unique_ptr<AudioTaggingImpl> Create(
19 const AudioTaggingConfig &config); 24 const AudioTaggingConfig &config);
20 25
  26 +#if __ANDROID_API__ >= 9
  27 + static std::unique_ptr<AudioTaggingImpl> Create(
  28 + AAssetManager *mgr, const AudioTaggingConfig &config);
  29 +#endif
  30 +
21 virtual std::unique_ptr<OfflineStream> CreateStream() const = 0; 31 virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
22 32
23 virtual std::vector<AudioEvent> Compute(OfflineStream *s, 33 virtual std::vector<AudioEvent> Compute(OfflineStream *s,
@@ -8,7 +8,15 @@ @@ -8,7 +8,15 @@
8 #include <sstream> 8 #include <sstream>
9 #include <string> 9 #include <string>
10 10
  11 +#if __ANDROID_API__ >= 9
  12 +#include <strstream>
  13 +
  14 +#include "android/asset_manager.h"
  15 +#include "android/asset_manager_jni.h"
  16 +#endif
  17 +
11 #include "sherpa-onnx/csrc/macros.h" 18 #include "sherpa-onnx/csrc/macros.h"
  19 +#include "sherpa-onnx/csrc/onnx-utils.h"
12 #include "sherpa-onnx/csrc/text-utils.h" 20 #include "sherpa-onnx/csrc/text-utils.h"
13 21
14 namespace sherpa_onnx { 22 namespace sherpa_onnx {
@@ -18,6 +26,15 @@ AudioTaggingLabels::AudioTaggingLabels(const std::string &filename) { @@ -18,6 +26,15 @@ AudioTaggingLabels::AudioTaggingLabels(const std::string &filename) {
18 Init(is); 26 Init(is);
19 } 27 }
20 28
  29 +#if __ANDROID_API__ >= 9
  30 +AudioTaggingLabels::AudioTaggingLabels(AAssetManager *mgr,
  31 + const std::string &filename) {
  32 + auto buf = ReadFile(mgr, filename);
  33 + std::istrstream is(buf.data(), buf.size());
  34 + Init(is);
  35 +}
  36 +#endif
  37 +
21 // Format of a label file 38 // Format of a label file
22 /* 39 /*
23 index,mid,display_name 40 index,mid,display_name
@@ -8,11 +8,19 @@ @@ -8,11 +8,19 @@
8 #include <string> 8 #include <string>
9 #include <vector> 9 #include <vector>
10 10
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
11 namespace sherpa_onnx { 16 namespace sherpa_onnx {
12 17
13 class AudioTaggingLabels { 18 class AudioTaggingLabels {
14 public: 19 public:
15 explicit AudioTaggingLabels(const std::string &filename); 20 explicit AudioTaggingLabels(const std::string &filename);
  21 +#if __ANDROID_API__ >= 9
  22 + AudioTaggingLabels(AAssetManager *mgr, const std::string &filename);
  23 +#endif
16 24
17 // Return the event name for the given index. 25 // Return the event name for the given index.
18 // The returned reference is valid as long as this object is alive 26 // The returned reference is valid as long as this object is alive
@@ -8,6 +8,11 @@ @@ -8,6 +8,11 @@
8 #include <utility> 8 #include <utility>
9 #include <vector> 9 #include <vector>
10 10
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
11 #include "sherpa-onnx/csrc/audio-tagging-impl.h" 16 #include "sherpa-onnx/csrc/audio-tagging-impl.h"
12 #include "sherpa-onnx/csrc/audio-tagging-label-file.h" 17 #include "sherpa-onnx/csrc/audio-tagging-label-file.h"
13 #include "sherpa-onnx/csrc/audio-tagging.h" 18 #include "sherpa-onnx/csrc/audio-tagging.h"
@@ -28,6 +33,20 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl { @@ -28,6 +33,20 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl {
28 } 33 }
29 } 34 }
30 35
  36 +#if __ANDROID_API__ >= 9
  37 + explicit AudioTaggingZipformerImpl(AAssetManager *mgr,
  38 + const AudioTaggingConfig &config)
  39 + : config_(config),
  40 + model_(mgr, config.model),
  41 + labels_(mgr, config.labels) {
  42 + if (model_.NumEventClasses() != labels_.NumEventClasses()) {
  43 + SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
  44 + model_.NumEventClasses(), labels_.NumEventClasses());
  45 + exit(-1);
  46 + }
  47 + }
  48 +#endif
  49 +
31 std::unique_ptr<OfflineStream> CreateStream() const override { 50 std::unique_ptr<OfflineStream> CreateStream() const override {
32 return std::make_unique<OfflineStream>(); 51 return std::make_unique<OfflineStream>();
33 } 52 }
@@ -4,6 +4,11 @@ @@ -4,6 +4,11 @@
4 4
5 #include "sherpa-onnx/csrc/audio-tagging.h" 5 #include "sherpa-onnx/csrc/audio-tagging.h"
6 6
  7 +#if __ANDROID_API__ >= 9
  8 +#include "android/asset_manager.h"
  9 +#include "android/asset_manager_jni.h"
  10 +#endif
  11 +
7 #include "sherpa-onnx/csrc/audio-tagging-impl.h" 12 #include "sherpa-onnx/csrc/audio-tagging-impl.h"
8 #include "sherpa-onnx/csrc/file-utils.h" 13 #include "sherpa-onnx/csrc/file-utils.h"
9 #include "sherpa-onnx/csrc/macros.h" 14 #include "sherpa-onnx/csrc/macros.h"
@@ -61,6 +66,11 @@ std::string AudioTaggingConfig::ToString() const { @@ -61,6 +66,11 @@ std::string AudioTaggingConfig::ToString() const {
61 AudioTagging::AudioTagging(const AudioTaggingConfig &config) 66 AudioTagging::AudioTagging(const AudioTaggingConfig &config)
62 : impl_(AudioTaggingImpl::Create(config)) {} 67 : impl_(AudioTaggingImpl::Create(config)) {}
63 68
  69 +#if __ANDROID_API__ >= 9
  70 +AudioTagging::AudioTagging(AAssetManager *mgr, const AudioTaggingConfig &config)
  71 + : impl_(AudioTaggingImpl::Create(mgr, config)) {}
  72 +#endif
  73 +
64 AudioTagging::~AudioTagging() = default; 74 AudioTagging::~AudioTagging() = default;
65 75
66 std::unique_ptr<OfflineStream> AudioTagging::CreateStream() const { 76 std::unique_ptr<OfflineStream> AudioTagging::CreateStream() const {
@@ -8,6 +8,11 @@ @@ -8,6 +8,11 @@
8 #include <string> 8 #include <string>
9 #include <vector> 9 #include <vector>
10 10
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
11 #include "sherpa-onnx/csrc/audio-tagging-model-config.h" 16 #include "sherpa-onnx/csrc/audio-tagging-model-config.h"
12 #include "sherpa-onnx/csrc/offline-stream.h" 17 #include "sherpa-onnx/csrc/offline-stream.h"
13 #include "sherpa-onnx/csrc/parse-options.h" 18 #include "sherpa-onnx/csrc/parse-options.h"
@@ -46,6 +51,10 @@ class AudioTagging { @@ -46,6 +51,10 @@ class AudioTagging {
46 public: 51 public:
47 explicit AudioTagging(const AudioTaggingConfig &config); 52 explicit AudioTagging(const AudioTaggingConfig &config);
48 53
  54 +#if __ANDROID_API__ >= 9
  55 + AudioTagging(AAssetManager *mgr, const AudioTaggingConfig &config);
  56 +#endif
  57 +
49 ~AudioTagging(); 58 ~AudioTagging();
50 59
51 std::unique_ptr<OfflineStream> CreateStream() const; 60 std::unique_ptr<OfflineStream> CreateStream() const;
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +import android.content.res.AssetManager
  4 +import android.util.Log
  5 +
  6 +private val TAG = "sherpa-onnx"
  7 +
  8 +data class OfflineZipformerAudioTaggingModelConfig (
  9 + val model: String,
  10 +)
  11 +
  12 +data class AudioTaggingModelConfig (
  13 + var zipformer: OfflineZipformerAudioTaggingModelConfig,
  14 + var numThreads: Int = 1,
  15 + var debug: Boolean = false,
  16 + var provider: String = "cpu",
  17 +)
  18 +
  19 +data class AudioTaggingConfig (
  20 + var model: AudioTaggingModelConfig,
  21 + var labels: String,
  22 + var topK: Int = 5,
  23 +)
  24 +
  25 +data class AudioEvent (
  26 + val name: String,
  27 + val index: Int,
  28 + val prob: Float,
  29 +)
  30 +
  31 +class AudioTagging(
  32 + assetManager: AssetManager? = null,
  33 + config: AudioTaggingConfig,
  34 +) {
  35 + private var ptr: Long
  36 +
  37 + init {
  38 + ptr = if (assetManager != null) {
  39 + newFromAsset(assetManager, config)
  40 + } else {
  41 + newFromFile(config)
  42 + }
  43 + }
  44 +
  45 + protected fun finalize() {
  46 + if(ptr != 0) {
  47 + delete(ptr)
  48 + ptr = 0
  49 + }
  50 + }
  51 +
  52 + fun release() = finalize()
  53 +
  54 + fun createStream(): OfflineStream {
  55 + val p = createStream(ptr)
  56 + return OfflineStream(p)
  57 + }
  58 +
  59 + // fun compute(stream: OfflineStream, topK: Int=-1): Array<AudioEvent> {
  60 + fun compute(stream: OfflineStream, topK: Int=-1): Array<Any> {
  61 + var events :Array<Any> = compute(ptr, stream.ptr, topK)
  62 + }
  63 +
  64 + private external fun newFromAsset(
  65 + assetManager: AssetManager,
  66 + config: AudioTaggingConfig,
  67 + ): Long
  68 +
  69 + private external fun newFromFile(
  70 + config: AudioTaggingConfig,
  71 + ): Long
  72 +
  73 + private external fun delete(ptr: Long)
  74 +
  75 + private external fun createStream(ptr: Long): Long
  76 +
  77 + private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array<Any>
  78 +
  79 + companion object {
  80 + init {
  81 + System.loadLibrary("sherpa-onnx-jni")
  82 + }
  83 + }
  84 +}
@@ -9,6 +9,10 @@ if(NOT DEFINED ANDROID_ABI) @@ -9,6 +9,10 @@ if(NOT DEFINED ANDROID_ABI)
9 include_directories($ENV{JAVA_HOME}/include/darwin) 9 include_directories($ENV{JAVA_HOME}/include/darwin)
10 endif() 10 endif()
11 11
12 -add_library(sherpa-onnx-jni jni.cc) 12 +add_library(sherpa-onnx-jni
  13 + audio-tagging.cc
  14 + jni.cc
  15 + offline-stream.cc
  16 +)
13 target_link_libraries(sherpa-onnx-jni sherpa-onnx-core) 17 target_link_libraries(sherpa-onnx-jni sherpa-onnx-core)
14 install(TARGETS sherpa-onnx-jni DESTINATION lib) 18 install(TARGETS sherpa-onnx-jni DESTINATION lib)
  1 +// sherpa-onnx/jni/audio-tagging.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/audio-tagging.h"
  6 +
  7 +#include "sherpa-onnx/csrc/macros.h"
  8 +#include "sherpa-onnx/jni/common.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +static AudioTaggingConfig GetAudioTaggingConfig(JNIEnv *env, jobject config) {
  13 + AudioTaggingConfig ans;
  14 +
  15 + jclass cls = env->GetObjectClass(config);
  16 +
  17 + jfieldID fid = env->GetFieldID(
  18 + cls, "model", "Lcom/k2fsa/sherpa/onnx/AudioTaggingModelConfig;");
  19 + jobject model = env->GetObjectField(config, fid);
  20 + jclass model_cls = env->GetObjectClass(model);
  21 +
  22 + fid = env->GetFieldID(
  23 + model_cls, "zipformer",
  24 + "Lcom/k2fsa/sherpa/onnx/OfflineZipformerAudioTaggingModelConfig;");
  25 + jobject zipformer = env->GetObjectField(model, fid);
  26 + jclass zipformer_cls = env->GetObjectClass(zipformer);
  27 +
  28 + fid = env->GetFieldID(zipformer_cls, "model", "Ljava/lang/String;");
  29 + jstring s = (jstring)env->GetObjectField(zipformer, fid);
  30 + const char *p = env->GetStringUTFChars(s, nullptr);
  31 + ans.model.zipformer.model = p;
  32 + env->ReleaseStringUTFChars(s, p);
  33 +
  34 + fid = env->GetFieldID(model_cls, "numThreads", "I");
  35 + ans.model.num_threads = env->GetIntField(model, fid);
  36 +
  37 + fid = env->GetFieldID(model_cls, "debug", "Z");
  38 + ans.model.debug = env->GetBooleanField(model, fid);
  39 +
  40 + fid = env->GetFieldID(model_cls, "provider", "Ljava/lang/String;");
  41 + s = (jstring)env->GetObjectField(model, fid);
  42 + p = env->GetStringUTFChars(s, nullptr);
  43 + ans.model.provider = p;
  44 + env->ReleaseStringUTFChars(s, p);
  45 +
  46 + fid = env->GetFieldID(cls, "labels", "Ljava/lang/String;");
  47 + s = (jstring)env->GetObjectField(config, fid);
  48 + p = env->GetStringUTFChars(s, nullptr);
  49 + ans.labels = p;
  50 + env->ReleaseStringUTFChars(s, p);
  51 +
  52 + fid = env->GetFieldID(cls, "topK", "I");
  53 + ans.top_k = env->GetIntField(config, fid);
  54 +
  55 + return ans;
  56 +}
  57 +
  58 +} // namespace sherpa_onnx
  59 +
  60 +SHERPA_ONNX_EXTERN_C
  61 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_newFromFile(
  62 + JNIEnv *env, jobject /*obj*/, jobject _config) {
  63 + auto config = sherpa_onnx::GetAudioTaggingConfig(env, _config);
  64 + SHERPA_ONNX_LOGE("audio tagging newFromFile config:\n%s",
  65 + config.ToString().c_str());
  66 +
  67 + if (!config.Validate()) {
  68 + SHERPA_ONNX_LOGE("Errors found in config!");
  69 + return 0;
  70 + }
  71 +
  72 + auto tagger = new sherpa_onnx::AudioTagging(config);
  73 +
  74 + return (jlong)tagger;
  75 +}
  76 +
  77 +SHERPA_ONNX_EXTERN_C
  78 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_delete(
  79 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  80 + delete reinterpret_cast<sherpa_onnx::AudioTagging *>(ptr);
  81 +}
  82 +
  83 +SHERPA_ONNX_EXTERN_C
  84 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_createStream(
  85 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  86 + auto tagger = reinterpret_cast<sherpa_onnx::AudioTagging *>(ptr);
  87 + std::unique_ptr<sherpa_onnx::OfflineStream> s = tagger->CreateStream();
  88 +
  89 + // The user is responsible to free the returned pointer.
  90 + //
  91 + // See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
  92 + // ./offline-stream.cc
  93 + sherpa_onnx::OfflineStream *p = s.release();
  94 + return (jlong)p;
  95 +}
  96 +
  97 +SHERPA_ONNX_EXTERN_C
  98 +JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_compute(
  99 + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong streamPtr, jint top_k) {
  100 + auto tagger = reinterpret_cast<sherpa_onnx::AudioTagging *>(ptr);
  101 + auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(streamPtr);
  102 + std::vector<sherpa_onnx::AudioEvent> events = tagger->Compute(stream, top_k);
  103 +
  104 + // TODO(fangjun): Return an array of AudioEvent directly
  105 + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
  106 + events.size(), env->FindClass("java/lang/Object"), nullptr);
  107 +
  108 + int32_t i = 0;
  109 + for (const auto &e : events) {
  110 + jobjectArray a = (jobjectArray)env->NewObjectArray(
  111 + 3, env->FindClass("java/lang/Object"), nullptr);
  112 +
  113 + // 0 name
  114 + // 1 index
  115 + // 2 prob
  116 + jstring js = env->NewStringUTF(e.name.c_str());
  117 + env->SetObjectArrayElement(a, 0, js);
  118 + env->SetObjectArrayElement(a, 1, NewInteger(env, e.index));
  119 + env->SetObjectArrayElement(a, 2, NewFloat(env, e.prob));
  120 +
  121 + env->SetObjectArrayElement(obj_arr, i, a);
  122 + i += 1;
  123 + }
  124 +
  125 + return obj_arr;
  126 +}
  1 +// sherpa-onnx/jni/common.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_JNI_COMMON_H_
  6 +#define SHERPA_ONNX_JNI_COMMON_H_
  7 +
  8 +#if __ANDROID_API__ >= 9
  9 +#include "android/asset_manager.h"
  10 +#include "android/asset_manager_jni.h"
  11 +#endif
  12 +
  13 +// If you use ndk, you can find "jni.h" inside
  14 +// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
  15 +#include "jni.h" // NOLINT
  16 +
  17 +#define SHERPA_ONNX_EXTERN_C extern "C"
  18 +
  19 +// defined in jni.cc
  20 +jobject NewInteger(JNIEnv *env, int32_t value);
  21 +jobject NewFloat(JNIEnv *env, float value);
  22 +
  23 +#endif // SHERPA_ONNX_JNI_COMMON_H_
@@ -7,20 +7,11 @@ @@ -7,20 +7,11 @@
7 // TODO(fangjun): Add documentation to functions/methods in this file 7 // TODO(fangjun): Add documentation to functions/methods in this file
8 // and also show how to use them with kotlin, possibly with java. 8 // and also show how to use them with kotlin, possibly with java.
9 9
10 -// If you use ndk, you can find "jni.h" inside  
11 -// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include  
12 -#include "jni.h" // NOLINT  
13 -  
14 #include <fstream> 10 #include <fstream>
15 #include <functional> 11 #include <functional>
16 #include <strstream> 12 #include <strstream>
17 #include <utility> 13 #include <utility>
18 14
19 -#if __ANDROID_API__ >= 9  
20 -#include "android/asset_manager.h"  
21 -#include "android/asset_manager_jni.h"  
22 -#endif  
23 -  
24 #include "sherpa-onnx/csrc/keyword-spotter.h" 15 #include "sherpa-onnx/csrc/keyword-spotter.h"
25 #include "sherpa-onnx/csrc/macros.h" 16 #include "sherpa-onnx/csrc/macros.h"
26 #include "sherpa-onnx/csrc/offline-recognizer.h" 17 #include "sherpa-onnx/csrc/offline-recognizer.h"
@@ -31,13 +22,12 @@ @@ -31,13 +22,12 @@
31 #include "sherpa-onnx/csrc/voice-activity-detector.h" 22 #include "sherpa-onnx/csrc/voice-activity-detector.h"
32 #include "sherpa-onnx/csrc/wave-reader.h" 23 #include "sherpa-onnx/csrc/wave-reader.h"
33 #include "sherpa-onnx/csrc/wave-writer.h" 24 #include "sherpa-onnx/csrc/wave-writer.h"
  25 +#include "sherpa-onnx/jni/common.h"
34 26
35 #if SHERPA_ONNX_ENABLE_TTS == 1 27 #if SHERPA_ONNX_ENABLE_TTS == 1
36 #include "sherpa-onnx/csrc/offline-tts.h" 28 #include "sherpa-onnx/csrc/offline-tts.h"
37 #endif 29 #endif
38 30
39 -#define SHERPA_ONNX_EXTERN_C extern "C"  
40 -  
41 namespace sherpa_onnx { 31 namespace sherpa_onnx {
42 32
43 class SherpaOnnx { 33 class SherpaOnnx {
@@ -1224,12 +1214,18 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames( @@ -1224,12 +1214,18 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(
1224 1214
1225 // see 1215 // see
1226 // https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables 1216 // https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
1227 -static jobject NewInteger(JNIEnv *env, int32_t value) { 1217 +jobject NewInteger(JNIEnv *env, int32_t value) {
1228 jclass cls = env->FindClass("java/lang/Integer"); 1218 jclass cls = env->FindClass("java/lang/Integer");
1229 jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V"); 1219 jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");
1230 return env->NewObject(cls, constructor, value); 1220 return env->NewObject(cls, constructor, value);
1231 } 1221 }
1232 1222
  1223 +jobject NewFloat(JNIEnv *env, float value) {
  1224 + jclass cls = env->FindClass("java/lang/Float");
  1225 + jmethodID constructor = env->GetMethodID(cls, "<init>", "(F)V");
  1226 + return env->NewObject(cls, constructor, value);
  1227 +}
  1228 +
1233 #if SHERPA_ONNX_ENABLE_TTS == 1 1229 #if SHERPA_ONNX_ENABLE_TTS == 1
1234 SHERPA_ONNX_EXTERN_C 1230 SHERPA_ONNX_EXTERN_C
1235 JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new( 1231 JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new(
  1 +// sherpa-onnx/jni/offline-stream.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-stream.h"
  6 +
  7 +#include "sherpa-onnx/jni/common.h"
  8 +
  9 +SHERPA_ONNX_EXTERN_C
  10 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineStream_delete(
  11 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  12 + delete reinterpret_cast<sherpa_onnx::OfflineStream *>(ptr);
  13 +}
  14 +
  15 +SHERPA_ONNX_EXTERN_C
  16 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineStream_acceptWaveform(
  17 + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
  18 + jint sample_rate) {
  19 + auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(ptr);
  20 +
  21 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  22 + jsize n = env->GetArrayLength(samples);
  23 + stream->AcceptWaveform(sample_rate, p, n);
  24 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  25 +}