esavin
Committed by GitHub

Expose dither for JNI (#2215)

@@ -6,10 +6,12 @@ package com.k2fsa.sherpa.onnx; @@ -6,10 +6,12 @@ package com.k2fsa.sherpa.onnx;
6 public class FeatureConfig { 6 public class FeatureConfig {
7 private final int sampleRate; 7 private final int sampleRate;
8 private final int featureDim; 8 private final int featureDim;
  9 + private final float dither;
9 10
10 private FeatureConfig(Builder builder) { 11 private FeatureConfig(Builder builder) {
11 this.sampleRate = builder.sampleRate; 12 this.sampleRate = builder.sampleRate;
12 this.featureDim = builder.featureDim; 13 this.featureDim = builder.featureDim;
  14 + this.dither = builder.dither;
13 } 15 }
14 16
15 public static Builder builder() { 17 public static Builder builder() {
@@ -24,9 +26,14 @@ public class FeatureConfig { @@ -24,9 +26,14 @@ public class FeatureConfig {
24 return featureDim; 26 return featureDim;
25 } 27 }
26 28
  29 + public float getDither() {
  30 + return dither;
  31 + }
  32 +
27 public static class Builder { 33 public static class Builder {
28 private int sampleRate = 16000; 34 private int sampleRate = 16000;
29 private int featureDim = 80; 35 private int featureDim = 80;
  36 + private float dither = 0.0f;
30 37
31 public FeatureConfig build() { 38 public FeatureConfig build() {
32 return new FeatureConfig(this); 39 return new FeatureConfig(this);
@@ -41,5 +48,9 @@ public class FeatureConfig { @@ -41,5 +48,9 @@ public class FeatureConfig {
41 this.featureDim = featureDim; 48 this.featureDim = featureDim;
42 return this; 49 return this;
43 } 50 }
  51 + public Builder setDither(float dither) {
  52 + this.dither = dither;
  53 + return this;
  54 + }
44 } 55 }
45 } 56 }
@@ -49,6 +49,9 @@ static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) { @@ -49,6 +49,9 @@ static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
49 fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); 49 fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
50 ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); 50 ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
51 51
  52 + fid = env->GetFieldID(feat_config_cls, "dither", "F");
  53 + ans.feat_config.dither = env->GetFloatField(feat_config, fid);
  54 +
52 //---------- model config ---------- 55 //---------- model config ----------
53 fid = env->GetFieldID(cls, "modelConfig", 56 fid = env->GetFieldID(cls, "modelConfig",
54 "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;"); 57 "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
@@ -62,6 +62,9 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { @@ -62,6 +62,9 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
62 fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); 62 fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
63 ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); 63 ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
64 64
  65 + fid = env->GetFieldID(feat_config_cls, "dither", "F");
  66 + ans.feat_config.dither = env->GetFloatField(feat_config, fid);
  67 +
65 //---------- model config ---------- 68 //---------- model config ----------
66 fid = env->GetFieldID(cls, "modelConfig", 69 fid = env->GetFieldID(cls, "modelConfig",
67 "Lcom/k2fsa/sherpa/onnx/OfflineModelConfig;"); 70 "Lcom/k2fsa/sherpa/onnx/OfflineModelConfig;");
@@ -65,6 +65,9 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { @@ -65,6 +65,9 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
65 fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); 65 fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
66 ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); 66 ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
67 67
  68 + fid = env->GetFieldID(feat_config_cls, "dither", "F");
  69 + ans.feat_config.dither = env->GetFloatField(feat_config, fid);
  70 +
68 //---------- enable endpoint ---------- 71 //---------- enable endpoint ----------
69 fid = env->GetFieldID(cls, "enableEndpoint", "Z"); 72 fid = env->GetFieldID(cls, "enableEndpoint", "Z");
70 ans.enable_endpoint = env->GetBooleanField(config, fid); 73 ans.enable_endpoint = env->GetBooleanField(config, fid);
@@ -3,6 +3,7 @@ package com.k2fsa.sherpa.onnx @@ -3,6 +3,7 @@ package com.k2fsa.sherpa.onnx
3 data class FeatureConfig( 3 data class FeatureConfig(
4 var sampleRate: Int = 16000, 4 var sampleRate: Int = 16000,
5 var featureDim: Int = 80, 5 var featureDim: Int = 80,
  6 + var dither: Float = 0.0f
6 ) 7 )
7 8
8 fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig { 9 fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {