Fangjun Kuang
Committed by GitHub

Fix modified beam search for iOS and android (#76)

* Use Int type for sampling rate

* Fix swift

* Fix iOS
1 Makefile 1 Makefile
2 *.jar 2 *.jar
  3 +hs_err_pid*.log
@@ -4,7 +4,7 @@ import android.content.res.AssetManager @@ -4,7 +4,7 @@ import android.content.res.AssetManager
4 4
5 fun main() { 5 fun main() {
6 var featConfig = FeatureConfig( 6 var featConfig = FeatureConfig(
7 - sampleRate = 16000.0f, 7 + sampleRate = 16000,
8 featureDim = 80, 8 featureDim = 80,
9 ) 9 )
10 10
@@ -13,7 +13,7 @@ fun main() { @@ -13,7 +13,7 @@ fun main() {
13 decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx", 13 decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
14 joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx", 14 joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
15 tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt", 15 tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
16 - numThreads = 4, 16 + numThreads = 1,
17 debug = false, 17 debug = false,
18 ) 18 )
19 19
@@ -24,22 +24,31 @@ fun main() { @@ -24,22 +24,31 @@ fun main() {
24 featConfig = featConfig, 24 featConfig = featConfig,
25 endpointConfig = endpointConfig, 25 endpointConfig = endpointConfig,
26 enableEndpoint = true, 26 enableEndpoint = true,
  27 + decodingMethod = "greedy_search",
  28 + maxActivePaths = 4,
27 ) 29 )
28 30
29 var model = SherpaOnnx( 31 var model = SherpaOnnx(
30 assetManager = AssetManager(), 32 assetManager = AssetManager(),
31 config = config, 33 config = config,
32 ) 34 )
  35 +
33 var samples = WaveReader.readWave( 36 var samples = WaveReader.readWave(
34 assetManager = AssetManager(), 37 assetManager = AssetManager(),
35 filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav", 38 filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav",
36 ) 39 )
37 40
38 - model.decodeSamples(samples!!) 41 + model.acceptWaveform(samples!!, sampleRate=16000)
  42 + while (model.isReady()) {
  43 + model.decode()
  44 + }
39 45
40 var tail_paddings = FloatArray(8000) // 0.5 seconds 46 var tail_paddings = FloatArray(8000) // 0.5 seconds
41 - model.decodeSamples(tail_paddings)  
42 - 47 + model.acceptWaveform(tail_paddings, sampleRate=16000)
43 model.inputFinished() 48 model.inputFinished()
  49 + while (model.isReady()) {
  50 + model.decode()
  51 + }
  52 +
44 println("results: ${model.text}") 53 println("results: ${model.text}")
45 } 54 }
@@ -38,3 +38,4 @@ log.txt @@ -38,3 +38,4 @@ log.txt
38 tags 38 tags
39 run-decode-file-python.sh 39 run-decode-file-python.sh
40 android/SherpaOnnx/app/src/main/assets/ 40 android/SherpaOnnx/app/src/main/assets/
  41 +*.ncnn.*
@@ -121,7 +121,10 @@ class MainActivity : AppCompatActivity() { @@ -121,7 +121,10 @@ class MainActivity : AppCompatActivity() {
121 val ret = audioRecord?.read(buffer, 0, buffer.size) 121 val ret = audioRecord?.read(buffer, 0, buffer.size)
122 if (ret != null && ret > 0) { 122 if (ret != null && ret > 0) {
123 val samples = FloatArray(ret) { buffer[it] / 32768.0f } 123 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
124 - model.decodeSamples(samples) 124 + model.acceptWaveform(samples, sampleRate=16000)
  125 + while (model.isReady()) {
  126 + model.decode()
  127 + }
125 runOnUiThread { 128 runOnUiThread {
126 val isEndpoint = model.isEndpoint() 129 val isEndpoint = model.isEndpoint()
127 val text = model.text 130 val text = model.text
@@ -177,33 +180,17 @@ class MainActivity : AppCompatActivity() { @@ -177,33 +180,17 @@ class MainActivity : AppCompatActivity() {
177 val type = 0 180 val type = 0
178 println("Select model type ${type}") 181 println("Select model type ${type}")
179 val config = OnlineRecognizerConfig( 182 val config = OnlineRecognizerConfig(
180 - featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80), 183 + featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
181 modelConfig = getModelConfig(type = type)!!, 184 modelConfig = getModelConfig(type = type)!!,
182 endpointConfig = getEndpointConfig(), 185 endpointConfig = getEndpointConfig(),
183 - enableEndpoint = true 186 + enableEndpoint = true,
  187 + decodingMethod = "greedy_search",
  188 + maxActivePaths = 4,
184 ) 189 )
185 190
186 model = SherpaOnnx( 191 model = SherpaOnnx(
187 assetManager = application.assets, 192 assetManager = application.assets,
188 config = config, 193 config = config,
189 ) 194 )
190 - /*  
191 - println("reading samples")  
192 - val samples = WaveReader.readWave(  
193 - assetManager = application.assets,  
194 - // filename = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav",  
195 - filename = "sherpa-onnx-lstm-zh-2023-02-20/test_wavs/0.wav",  
196 - // filename="sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav"  
197 - )  
198 - println("samples read done!")  
199 -  
200 - model.decodeSamples(samples!!)  
201 -  
202 - val tailPaddings = FloatArray(8000) // 0.5 seconds  
203 - model.decodeSamples(tailPaddings)  
204 -  
205 - println("result is: ${model.text}")  
206 - model.reset()  
207 - */  
208 } 195 }
209 } 196 }
@@ -24,7 +24,7 @@ data class OnlineTransducerModelConfig( @@ -24,7 +24,7 @@ data class OnlineTransducerModelConfig(
24 ) 24 )
25 25
26 data class FeatureConfig( 26 data class FeatureConfig(
27 - var sampleRate: Float = 16000.0f, 27 + var sampleRate: Int = 16000,
28 var featureDim: Int = 80, 28 var featureDim: Int = 80,
29 ) 29 )
30 30
@@ -32,7 +32,9 @@ data class OnlineRecognizerConfig( @@ -32,7 +32,9 @@ data class OnlineRecognizerConfig(
32 var featConfig: FeatureConfig = FeatureConfig(), 32 var featConfig: FeatureConfig = FeatureConfig(),
33 var modelConfig: OnlineTransducerModelConfig, 33 var modelConfig: OnlineTransducerModelConfig,
34 var endpointConfig: EndpointConfig = EndpointConfig(), 34 var endpointConfig: EndpointConfig = EndpointConfig(),
35 - var enableEndpoint: Boolean, 35 + var enableEndpoint: Boolean = true,
  36 + var decodingMethod: String = "greedy_search",
  37 + var maxActivePaths: Int = 4,
36 ) 38 )
37 39
38 class SherpaOnnx( 40 class SherpaOnnx(
@@ -49,12 +51,14 @@ class SherpaOnnx( @@ -49,12 +51,14 @@ class SherpaOnnx(
49 } 51 }
50 52
51 53
52 - fun decodeSamples(samples: FloatArray) =  
53 - decodeSamples(ptr, samples, sampleRate = config.featConfig.sampleRate) 54 + fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
  55 + acceptWaveform(ptr, samples, sampleRate)
54 56
55 fun inputFinished() = inputFinished(ptr) 57 fun inputFinished() = inputFinished(ptr)
56 fun reset() = reset(ptr) 58 fun reset() = reset(ptr)
  59 + fun decode() = decode(ptr)
57 fun isEndpoint(): Boolean = isEndpoint(ptr) 60 fun isEndpoint(): Boolean = isEndpoint(ptr)
  61 + fun isReady(): Boolean = isReady(ptr)
58 62
59 val text: String 63 val text: String
60 get() = getText(ptr) 64 get() = getText(ptr)
@@ -66,11 +70,13 @@ class SherpaOnnx( @@ -66,11 +70,13 @@ class SherpaOnnx(
66 config: OnlineRecognizerConfig, 70 config: OnlineRecognizerConfig,
67 ): Long 71 ): Long
68 72
69 - private external fun decodeSamples(ptr: Long, samples: FloatArray, sampleRate: Float) 73 + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
70 private external fun inputFinished(ptr: Long) 74 private external fun inputFinished(ptr: Long)
71 private external fun getText(ptr: Long): String 75 private external fun getText(ptr: Long): String
72 private external fun reset(ptr: Long) 76 private external fun reset(ptr: Long)
  77 + private external fun decode(ptr: Long)
73 private external fun isEndpoint(ptr: Long): Boolean 78 private external fun isEndpoint(ptr: Long): Boolean
  79 + private external fun isReady(ptr: Long): Boolean
74 80
75 companion object { 81 companion object {
76 init { 82 init {
@@ -79,7 +85,7 @@ class SherpaOnnx( @@ -79,7 +85,7 @@ class SherpaOnnx(
79 } 85 }
80 } 86 }
81 87
82 -fun getFeatureConfig(sampleRate: Float, featureDim: Int): FeatureConfig { 88 +fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
83 return FeatureConfig(sampleRate=sampleRate, featureDim=featureDim) 89 return FeatureConfig(sampleRate=sampleRate, featureDim=featureDim)
84 } 90 }
85 91
@@ -97,7 +97,9 @@ class ViewController: UIViewController { @@ -97,7 +97,9 @@ class ViewController: UIViewController {
97 enableEndpoint: true, 97 enableEndpoint: true,
98 rule1MinTrailingSilence: 2.4, 98 rule1MinTrailingSilence: 2.4,
99 rule2MinTrailingSilence: 0.8, 99 rule2MinTrailingSilence: 0.8,
100 - rule3MinUtteranceLength: 30 100 + rule3MinUtteranceLength: 30,
  101 + decodingMethod: "greedy_search",
  102 + maxActivePaths: 4
101 ) 103 )
102 recognizer = SherpaOnnxRecognizer(config: &config) 104 recognizer = SherpaOnnxRecognizer(config: &config)
103 } 105 }
@@ -76,7 +76,7 @@ SherpaOnnxOnlineStream *CreateOnlineStream( @@ -76,7 +76,7 @@ SherpaOnnxOnlineStream *CreateOnlineStream(
76 76
77 void DestoryOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; } 77 void DestoryOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; }
78 78
79 -void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate, 79 +void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
80 const float *samples, int32_t n) { 80 const float *samples, int32_t n) {
81 stream->impl->AcceptWaveform(sample_rate, samples, n); 81 stream->impl->AcceptWaveform(sample_rate, samples, n);
82 } 82 }
@@ -120,7 +120,7 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream); @@ -120,7 +120,7 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream);
120 /// @param samples A pointer to a 1-D array containing audio samples. 120 /// @param samples A pointer to a 1-D array containing audio samples.
121 /// The range of samples has to be normalized to [-1, 1]. 121 /// The range of samples has to be normalized to [-1, 1].
122 /// @param n Number of elements in the samples array. 122 /// @param n Number of elements in the samples array.
123 -void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate, 123 +void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
124 const float *samples, int32_t n); 124 const float *samples, int32_t n);
125 125
126 /// Return 1 if there are enough number of feature frames for decoding. 126 /// Return 1 if there are enough number of feature frames for decoding.
@@ -48,7 +48,7 @@ class FeatureExtractor::Impl { @@ -48,7 +48,7 @@ class FeatureExtractor::Impl {
48 fbank_ = std::make_unique<knf::OnlineFbank>(opts_); 48 fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
49 } 49 }
50 50
51 - void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) { 51 + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
52 std::lock_guard<std::mutex> lock(mutex_); 52 std::lock_guard<std::mutex> lock(mutex_);
53 fbank_->AcceptWaveform(sampling_rate, waveform, n); 53 fbank_->AcceptWaveform(sampling_rate, waveform, n);
54 } 54 }
@@ -107,7 +107,7 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) @@ -107,7 +107,7 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
107 107
108 FeatureExtractor::~FeatureExtractor() = default; 108 FeatureExtractor::~FeatureExtractor() = default;
109 109
110 -void FeatureExtractor::AcceptWaveform(float sampling_rate, 110 +void FeatureExtractor::AcceptWaveform(int32_t sampling_rate,
111 const float *waveform, int32_t n) { 111 const float *waveform, int32_t n) {
112 impl_->AcceptWaveform(sampling_rate, waveform, n); 112 impl_->AcceptWaveform(sampling_rate, waveform, n);
113 } 113 }
@@ -14,7 +14,7 @@ @@ -14,7 +14,7 @@
14 namespace sherpa_onnx { 14 namespace sherpa_onnx {
15 15
16 struct FeatureExtractorConfig { 16 struct FeatureExtractorConfig {
17 - float sampling_rate = 16000; 17 + int32_t sampling_rate = 16000;
18 int32_t feature_dim = 80; 18 int32_t feature_dim = 80;
19 int32_t max_feature_vectors = -1; 19 int32_t max_feature_vectors = -1;
20 20
@@ -34,7 +34,7 @@ class FeatureExtractor { @@ -34,7 +34,7 @@ class FeatureExtractor {
34 @param waveform Pointer to a 1-D array of size n 34 @param waveform Pointer to a 1-D array of size n
35 @param n Number of entries in waveform 35 @param n Number of entries in waveform
36 */ 36 */
37 - void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n); 37 + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
38 38
39 /** 39 /**
40 * InputFinished() tells the class you won't be providing any 40 * InputFinished() tells the class you won't be providing any
@@ -112,7 +112,7 @@ for a list of pre-trained models to download. @@ -112,7 +112,7 @@ for a list of pre-trained models to download.
112 112
113 param.suggestedLatency = info->defaultLowInputLatency; 113 param.suggestedLatency = info->defaultLowInputLatency;
114 param.hostApiSpecificStreamInfo = nullptr; 114 param.hostApiSpecificStreamInfo = nullptr;
115 - const float sample_rate = 16000; 115 + float sample_rate = 16000;
116 116
117 PaStream *stream; 117 PaStream *stream;
118 PaError err = 118 PaError err =
@@ -61,7 +61,7 @@ for a list of pre-trained models to download. @@ -61,7 +61,7 @@ for a list of pre-trained models to download.
61 61
62 sherpa_onnx::OnlineRecognizer recognizer(config); 62 sherpa_onnx::OnlineRecognizer recognizer(config);
63 63
64 - float expected_sampling_rate = config.feat_config.sampling_rate; 64 + int32_t expected_sampling_rate = config.feat_config.sampling_rate;
65 65
66 bool is_ok = false; 66 bool is_ok = false;
67 std::vector<float> samples = 67 std::vector<float> samples =
@@ -72,7 +72,7 @@ for a list of pre-trained models to download. @@ -72,7 +72,7 @@ for a list of pre-trained models to download.
72 return -1; 72 return -1;
73 } 73 }
74 74
75 - float duration = samples.size() / expected_sampling_rate; 75 + float duration = samples.size() / static_cast<float>(expected_sampling_rate);
76 76
77 fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); 77 fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
78 fprintf(stderr, "wav duration (s): %.3f\n", duration); 78 fprintf(stderr, "wav duration (s): %.3f\n", duration);
@@ -40,19 +40,18 @@ class SherpaOnnx { @@ -40,19 +40,18 @@ class SherpaOnnx {
40 mgr, 40 mgr,
41 #endif 41 #endif
42 config), 42 config),
43 - stream_(recognizer_.CreateStream()),  
44 - tail_padding_(16000 * 0.32, 0) { 43 + stream_(recognizer_.CreateStream()) {
45 } 44 }
46 45
47 - void DecodeSamples(float sample_rate, const float *samples, int32_t n) const { 46 + void AcceptWaveform(int32_t sample_rate, const float *samples,
  47 + int32_t n) const {
48 stream_->AcceptWaveform(sample_rate, samples, n); 48 stream_->AcceptWaveform(sample_rate, samples, n);
49 - Decode();  
50 } 49 }
51 50
52 void InputFinished() const { 51 void InputFinished() const {
53 - stream_->AcceptWaveform(16000, tail_padding_.data(), tail_padding_.size()); 52 + std::vector<float> tail_padding(16000 * 0.32, 0);
  53 + stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size());
54 stream_->InputFinished(); 54 stream_->InputFinished();
55 - Decode();  
56 } 55 }
57 56
58 const std::string GetText() const { 57 const std::string GetText() const {
@@ -62,19 +61,15 @@ class SherpaOnnx { @@ -62,19 +61,15 @@ class SherpaOnnx {
62 61
63 bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); } 62 bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
64 63
  64 + bool IsReady() const { return recognizer_.IsReady(stream_.get()); }
  65 +
65 void Reset() const { return recognizer_.Reset(stream_.get()); } 66 void Reset() const { return recognizer_.Reset(stream_.get()); }
66 67
67 - private:  
68 - void Decode() const {  
69 - while (recognizer_.IsReady(stream_.get())) {  
70 - recognizer_.DecodeStream(stream_.get());  
71 - }  
72 - } 68 + void Decode() const { recognizer_.DecodeStream(stream_.get()); }
73 69
74 private: 70 private:
75 sherpa_onnx::OnlineRecognizer recognizer_; 71 sherpa_onnx::OnlineRecognizer recognizer_;
76 std::unique_ptr<sherpa_onnx::OnlineStream> stream_; 72 std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
77 - std::vector<float> tail_padding_;  
78 }; 73 };
79 74
80 static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { 75 static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
@@ -86,14 +81,24 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { @@ -86,14 +81,24 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
86 // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html 81 // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
87 // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html 82 // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
88 83
  84 + //---------- decoding ----------
  85 + fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
  86 + jstring s = (jstring)env->GetObjectField(config, fid);
  87 + const char *p = env->GetStringUTFChars(s, nullptr);
  88 + ans.decoding_method = p;
  89 + env->ReleaseStringUTFChars(s, p);
  90 +
  91 + fid = env->GetFieldID(cls, "maxActivePaths", "I");
  92 + ans.max_active_paths = env->GetIntField(config, fid);
  93 +
89 //---------- feat config ---------- 94 //---------- feat config ----------
90 fid = env->GetFieldID(cls, "featConfig", 95 fid = env->GetFieldID(cls, "featConfig",
91 "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); 96 "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
92 jobject feat_config = env->GetObjectField(config, fid); 97 jobject feat_config = env->GetObjectField(config, fid);
93 jclass feat_config_cls = env->GetObjectClass(feat_config); 98 jclass feat_config_cls = env->GetObjectClass(feat_config);
94 99
95 - fid = env->GetFieldID(feat_config_cls, "sampleRate", "F");  
96 - ans.feat_config.sampling_rate = env->GetFloatField(feat_config, fid); 100 + fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
  101 + ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
97 102
98 fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); 103 fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
99 ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); 104 ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
@@ -153,8 +158,8 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { @@ -153,8 +158,8 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
153 jclass model_config_cls = env->GetObjectClass(model_config); 158 jclass model_config_cls = env->GetObjectClass(model_config);
154 159
155 fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;"); 160 fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
156 - jstring s = (jstring)env->GetObjectField(model_config, fid);  
157 - const char *p = env->GetStringUTFChars(s, nullptr); 161 + s = (jstring)env->GetObjectField(model_config, fid);
  162 + p = env->GetStringUTFChars(s, nullptr);
158 ans.model_config.encoder_filename = p; 163 ans.model_config.encoder_filename = p;
159 env->ReleaseStringUTFChars(s, p); 164 env->ReleaseStringUTFChars(s, p);
160 165
@@ -198,6 +203,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new( @@ -198,6 +203,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
198 #endif 203 #endif
199 204
200 auto config = sherpa_onnx::GetConfig(env, _config); 205 auto config = sherpa_onnx::GetConfig(env, _config);
  206 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
201 auto model = new sherpa_onnx::SherpaOnnx( 207 auto model = new sherpa_onnx::SherpaOnnx(
202 #if __ANDROID_API__ >= 9 208 #if __ANDROID_API__ >= 9
203 mgr, 209 mgr,
@@ -221,6 +227,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset( @@ -221,6 +227,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
221 } 227 }
222 228
223 SHERPA_ONNX_EXTERN_C 229 SHERPA_ONNX_EXTERN_C
  230 +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady(
  231 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  232 + auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  233 + return model->IsReady();
  234 +}
  235 +
  236 +SHERPA_ONNX_EXTERN_C
224 JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint( 237 JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
225 JNIEnv *env, jobject /*obj*/, jlong ptr) { 238 JNIEnv *env, jobject /*obj*/, jlong ptr) {
226 auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr); 239 auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
@@ -228,15 +241,22 @@ JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint( @@ -228,15 +241,22 @@ JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
228 } 241 }
229 242
230 SHERPA_ONNX_EXTERN_C 243 SHERPA_ONNX_EXTERN_C
231 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decodeSamples( 244 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode(
  245 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  246 + auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  247 + model->Decode();
  248 +}
  249 +
  250 +SHERPA_ONNX_EXTERN_C
  251 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform(
232 JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, 252 JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
233 - jfloat sample_rate) { 253 + jint sample_rate) {
234 auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr); 254 auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
235 255
236 jfloat *p = env->GetFloatArrayElements(samples, nullptr); 256 jfloat *p = env->GetFloatArrayElements(samples, nullptr);
237 jsize n = env->GetArrayLength(samples); 257 jsize n = env->GetArrayLength(samples);
238 258
239 - model->DecodeSamples(sample_rate, p, n); 259 + model->AcceptWaveform(sample_rate, p, n);
240 260
241 env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); 261 env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
242 } 262 }
@@ -62,11 +62,15 @@ func sherpaOnnxOnlineRecognizerConfig( @@ -62,11 +62,15 @@ func sherpaOnnxOnlineRecognizerConfig(
62 enableEndpoint: Bool = false, 62 enableEndpoint: Bool = false,
63 rule1MinTrailingSilence: Float = 2.4, 63 rule1MinTrailingSilence: Float = 2.4,
64 rule2MinTrailingSilence: Float = 1.2, 64 rule2MinTrailingSilence: Float = 1.2,
65 - rule3MinUtteranceLength: Float = 30 65 + rule3MinUtteranceLength: Float = 30,
  66 + decodingMethod: String = "greedy_search",
  67 + maxActivePaths: Int = 4
66 ) -> SherpaOnnxOnlineRecognizerConfig{ 68 ) -> SherpaOnnxOnlineRecognizerConfig{
67 return SherpaOnnxOnlineRecognizerConfig( 69 return SherpaOnnxOnlineRecognizerConfig(
68 feat_config: featConfig, 70 feat_config: featConfig,
69 model_config: modelConfig, 71 model_config: modelConfig,
  72 + decoding_method: toCPointer(decodingMethod),
  73 + max_active_paths: Int32(maxActivePaths),
70 enable_endpoint: enableEndpoint ? 1 : 0, 74 enable_endpoint: enableEndpoint ? 1 : 0,
71 rule1_min_trailing_silence: rule1MinTrailingSilence, 75 rule1_min_trailing_silence: rule1MinTrailingSilence,
72 rule2_min_trailing_silence: rule2MinTrailingSilence, 76 rule2_min_trailing_silence: rule2MinTrailingSilence,
@@ -128,12 +132,12 @@ class SherpaOnnxRecognizer { @@ -128,12 +132,12 @@ class SherpaOnnxRecognizer {
128 /// Decode wave samples. 132 /// Decode wave samples.
129 /// 133 ///
130 /// - Parameters: 134 /// - Parameters:
131 - /// - samples: Audio samples normalzed to the range [-1, 1] 135 + /// - samples: Audio samples normalized to the range [-1, 1]
132 /// - sampleRate: Sample rate of the input audio samples. Must match 136 /// - sampleRate: Sample rate of the input audio samples. Must match
133 /// the one expected by the model. It must be 16000 for 137 /// the one expected by the model. It must be 16000 for
134 /// models from icefall. 138 /// models from icefall.
135 - func acceptWaveform(samples: [Float], sampleRate: Float = 16000) {  
136 - AcceptWaveform(stream, sampleRate, samples, Int32(samples.count)) 139 + func acceptWaveform(samples: [Float], sampleRate: Int = 16000) {
  140 + AcceptWaveform(stream, Int32(sampleRate), samples, Int32(samples.count))
137 } 141 }
138 142
139 func isReady() -> Bool { 143 func isReady() -> Bool {
@@ -32,7 +32,9 @@ func run() { @@ -32,7 +32,9 @@ func run() {
32 var config = sherpaOnnxOnlineRecognizerConfig( 32 var config = sherpaOnnxOnlineRecognizerConfig(
33 featConfig: featConfig, 33 featConfig: featConfig,
34 modelConfig: modelConfig, 34 modelConfig: modelConfig,
35 - enableEndpoint: false 35 + enableEndpoint: false,
  36 + decodingMethod: "modified_beam_search",
  37 + maxActivePaths: 4
36 ) 38 )
37 39
38 40