Robin Zhong
Committed by GitHub

update kotlin api for better release native object and add user-friendly apis. (#1275)

@@ -24,7 +24,7 @@ class KeywordSpotter( @@ -24,7 +24,7 @@ class KeywordSpotter(
24 assetManager: AssetManager? = null, 24 assetManager: AssetManager? = null,
25 val config: KeywordSpotterConfig, 25 val config: KeywordSpotterConfig,
26 ) { 26 ) {
27 - private val ptr: Long 27 + private var ptr: Long
28 28
29 init { 29 init {
30 ptr = if (assetManager != null) { 30 ptr = if (assetManager != null) {
@@ -35,7 +35,10 @@ class KeywordSpotter( @@ -35,7 +35,10 @@ class KeywordSpotter(
35 } 35 }
36 36
37 protected fun finalize() { 37 protected fun finalize() {
38 - delete(ptr) 38 + if (ptr != 0L) {
  39 + delete(ptr)
  40 + ptr = 0
  41 + }
39 } 42 }
40 43
41 fun release() = finalize() 44 fun release() = finalize()
@@ -18,7 +18,7 @@ class OfflinePunctuation( @@ -18,7 +18,7 @@ class OfflinePunctuation(
18 assetManager: AssetManager? = null, 18 assetManager: AssetManager? = null,
19 config: OfflinePunctuationConfig, 19 config: OfflinePunctuationConfig,
20 ) { 20 ) {
21 - private val ptr: Long 21 + private var ptr: Long
22 22
23 init { 23 init {
24 ptr = if (assetManager != null) { 24 ptr = if (assetManager != null) {
@@ -29,7 +29,10 @@ class OfflinePunctuation( @@ -29,7 +29,10 @@ class OfflinePunctuation(
29 } 29 }
30 30
31 protected fun finalize() { 31 protected fun finalize() {
32 - delete(ptr) 32 + if (ptr != 0L) {
  33 + delete(ptr)
  34 + ptr = 0
  35 + }
33 } 36 }
34 37
35 fun release() = finalize() 38 fun release() = finalize()
@@ -72,7 +72,7 @@ class OfflineRecognizer( @@ -72,7 +72,7 @@ class OfflineRecognizer(
72 assetManager: AssetManager? = null, 72 assetManager: AssetManager? = null,
73 config: OfflineRecognizerConfig, 73 config: OfflineRecognizerConfig,
74 ) { 74 ) {
75 - private val ptr: Long 75 + private var ptr: Long
76 76
77 init { 77 init {
78 ptr = if (assetManager != null) { 78 ptr = if (assetManager != null) {
@@ -83,7 +83,10 @@ class OfflineRecognizer( @@ -83,7 +83,10 @@ class OfflineRecognizer(
83 } 83 }
84 84
85 protected fun finalize() { 85 protected fun finalize() {
86 - delete(ptr) 86 + if (ptr != 0L) {
  87 + delete(ptr)
  88 + ptr = 0
  89 + }
87 } 90 }
88 91
89 fun release() = finalize() 92 fun release() = finalize()
@@ -102,7 +105,14 @@ class OfflineRecognizer( @@ -102,7 +105,14 @@ class OfflineRecognizer(
102 val lang = objArray[3] as String 105 val lang = objArray[3] as String
103 val emotion = objArray[4] as String 106 val emotion = objArray[4] as String
104 val event = objArray[5] as String 107 val event = objArray[5] as String
105 - return OfflineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps, lang = lang, emotion = emotion, event = event) 108 + return OfflineRecognizerResult(
  109 + text = text,
  110 + tokens = tokens,
  111 + timestamps = timestamps,
  112 + lang = lang,
  113 + emotion = emotion,
  114 + event = event
  115 + )
106 } 116 }
107 117
108 fun decode(stream: OfflineStream) = decode(ptr, stream.ptr) 118 fun decode(stream: OfflineStream) = decode(ptr, stream.ptr)
@@ -13,6 +13,14 @@ class OfflineStream(var ptr: Long) { @@ -13,6 +13,14 @@ class OfflineStream(var ptr: Long) {
13 13
14 fun release() = finalize() 14 fun release() = finalize()
15 15
  16 + fun use(block: (OfflineStream) -> Unit) {
  17 + try {
  18 + block(this)
  19 + } finally {
  20 + release()
  21 + }
  22 + }
  23 +
16 private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) 24 private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
17 private external fun delete(ptr: Long) 25 private external fun delete(ptr: Long)
18 26
@@ -62,7 +62,7 @@ data class OnlineRecognizerConfig( @@ -62,7 +62,7 @@ data class OnlineRecognizerConfig(
62 var featConfig: FeatureConfig = FeatureConfig(), 62 var featConfig: FeatureConfig = FeatureConfig(),
63 var modelConfig: OnlineModelConfig, 63 var modelConfig: OnlineModelConfig,
64 var lmConfig: OnlineLMConfig = OnlineLMConfig(), 64 var lmConfig: OnlineLMConfig = OnlineLMConfig(),
65 - var ctcFstDecoderConfig : OnlineCtcFstDecoderConfig = OnlineCtcFstDecoderConfig(), 65 + var ctcFstDecoderConfig: OnlineCtcFstDecoderConfig = OnlineCtcFstDecoderConfig(),
66 var endpointConfig: EndpointConfig = EndpointConfig(), 66 var endpointConfig: EndpointConfig = EndpointConfig(),
67 var enableEndpoint: Boolean = true, 67 var enableEndpoint: Boolean = true,
68 var decodingMethod: String = "greedy_search", 68 var decodingMethod: String = "greedy_search",
@@ -85,7 +85,7 @@ class OnlineRecognizer( @@ -85,7 +85,7 @@ class OnlineRecognizer(
85 assetManager: AssetManager? = null, 85 assetManager: AssetManager? = null,
86 val config: OnlineRecognizerConfig, 86 val config: OnlineRecognizerConfig,
87 ) { 87 ) {
88 - private val ptr: Long 88 + private var ptr: Long
89 89
90 init { 90 init {
91 ptr = if (assetManager != null) { 91 ptr = if (assetManager != null) {
@@ -96,7 +96,10 @@ class OnlineRecognizer( @@ -96,7 +96,10 @@ class OnlineRecognizer(
96 } 96 }
97 97
98 protected fun finalize() { 98 protected fun finalize() {
99 - delete(ptr) 99 + if (ptr != 0L) {
  100 + delete(ptr)
  101 + ptr = 0
  102 + }
100 } 103 }
101 104
102 fun release() = finalize() 105 fun release() = finalize()
@@ -15,10 +15,19 @@ class OnlineStream(var ptr: Long = 0) { @@ -15,10 +15,19 @@ class OnlineStream(var ptr: Long = 0) {
15 15
16 fun release() = finalize() 16 fun release() = finalize()
17 17
  18 + fun use(block: (OnlineStream) -> Unit) {
  19 + try {
  20 + block(this)
  21 + } finally {
  22 + release()
  23 + }
  24 + }
  25 +
18 private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) 26 private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
19 private external fun inputFinished(ptr: Long) 27 private external fun inputFinished(ptr: Long)
20 private external fun delete(ptr: Long) 28 private external fun delete(ptr: Long)
21 29
  30 +
22 companion object { 31 companion object {
23 init { 32 init {
24 System.loadLibrary("sherpa-onnx-jni") 33 System.loadLibrary("sherpa-onnx-jni")
@@ -19,11 +19,13 @@ data class VadModelConfig( @@ -19,11 +19,13 @@ data class VadModelConfig(
19 var debug: Boolean = false, 19 var debug: Boolean = false,
20 ) 20 )
21 21
  22 +class SpeechSegment(val start: Int, val samples: FloatArray)
  23 +
22 class Vad( 24 class Vad(
23 assetManager: AssetManager? = null, 25 assetManager: AssetManager? = null,
24 var config: VadModelConfig, 26 var config: VadModelConfig,
25 ) { 27 ) {
26 - private val ptr: Long 28 + private var ptr: Long
27 29
28 init { 30 init {
29 if (assetManager != null) { 31 if (assetManager != null) {
@@ -34,17 +36,23 @@ class Vad( @@ -34,17 +36,23 @@ class Vad(
34 } 36 }
35 37
36 protected fun finalize() { 38 protected fun finalize() {
37 - delete(ptr) 39 + if (ptr != 0L) {
  40 + delete(ptr)
  41 + ptr = 0
  42 + }
38 } 43 }
39 44
  45 + fun release() = finalize()
  46 +
40 fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples) 47 fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples)
41 48
42 fun empty(): Boolean = empty(ptr) 49 fun empty(): Boolean = empty(ptr)
43 fun pop() = pop(ptr) 50 fun pop() = pop(ptr)
44 51
45 - // return an array containing  
46 - // [start: Int, samples: FloatArray]  
47 - fun front() = front(ptr) 52 + fun front(): SpeechSegment {
  53 + val segment = front(ptr)
  54 + return SpeechSegment(segment[0] as Int, segment[1] as FloatArray)
  55 + }
48 56
49 fun clear() = clear(ptr) 57 fun clear() = clear(ptr)
50 58
@@ -3,8 +3,49 @@ package com.k2fsa.sherpa.onnx @@ -3,8 +3,49 @@ package com.k2fsa.sherpa.onnx
3 3
4 import android.content.res.AssetManager 4 import android.content.res.AssetManager
5 5
  6 +data class WaveData(
  7 + val samples: FloatArray,
  8 + val sampleRate: Int,
  9 +) {
  10 + override fun equals(other: Any?): Boolean {
  11 + if (this === other) return true
  12 + if (javaClass != other?.javaClass) return false
  13 +
  14 + other as WaveData
  15 +
  16 + if (!samples.contentEquals(other.samples)) return false
  17 + if (sampleRate != other.sampleRate) return false
  18 +
  19 + return true
  20 + }
  21 +
  22 + override fun hashCode(): Int {
  23 + var result = samples.contentHashCode()
  24 + result = 31 * result + sampleRate
  25 + return result
  26 + }
  27 +}
  28 +
6 class WaveReader { 29 class WaveReader {
7 companion object { 30 companion object {
  31 +
  32 + fun readWave(
  33 + assetManager: AssetManager,
  34 + filename: String,
  35 + ): WaveData {
  36 + return readWaveFromAsset(assetManager, filename).let {
  37 + WaveData(it[0] as FloatArray, it[1] as Int)
  38 + }
  39 + }
  40 +
  41 + fun readWave(
  42 + filename: String,
  43 + ): WaveData {
  44 + return readWaveFromFile(filename).let {
  45 + WaveData(it[0] as FloatArray, it[1] as Int)
  46 + }
  47 + }
  48 +
8 // Read a mono wave file asset 49 // Read a mono wave file asset
9 // The returned array has two entries: 50 // The returned array has two entries:
10 // - the first entry contains an 1-D float array 51 // - the first entry contains an 1-D float array