Fangjun Kuang
Committed by GitHub

Fix modified beam search for iOS and android (#76)

* Use Int type for sampling rate

* Fix swift

* Fix iOS
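
The change that runs through every file below: decodeSamples(), which used to feed audio and decode it in one call, is split into acceptWaveform() plus an explicit isReady()/decode() loop driven by the caller, and the sample rate becomes an Int end to end. A minimal Kotlin sketch of the new calling pattern, assuming the SherpaOnnx wrapper exactly as it appears in the diff (the transcribe helper itself is hypothetical):

// Hypothetical helper showing the two-step API this commit introduces.
// acceptWaveform() only buffers samples; the caller drives decode() for
// as long as isReady() reports enough buffered feature frames.
fun transcribe(model: SherpaOnnx, samples: FloatArray): String {
    model.acceptWaveform(samples, sampleRate = 16000)
    while (model.isReady()) {
        model.decode()
    }
    // inputFinished() flushes the stream; per the JNI diff it also feeds
    // 0.32 s of tail padding before signalling end of input.
    model.inputFinished()
    while (model.isReady()) {
        model.decode()
    }
    return model.text
}

Moving the decode loop out of the native wrapper makes it explicit when decoding happens, rather than leaving it as a hidden side effect of feeding audio.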
 Makefile
 *.jar
+hs_err_pid*.log
@@ -4,7 +4,7 @@ import android.content.res.AssetManager
 
 fun main() {
     var featConfig = FeatureConfig(
-        sampleRate = 16000.0f,
+        sampleRate = 16000,
         featureDim = 80,
     )
 
@@ -13,7 +13,7 @@ fun main() {
         decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
         joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
         tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
-        numThreads = 4,
+        numThreads = 1,
         debug = false,
     )
 
@@ -24,22 +24,31 @@ fun main() {
         featConfig = featConfig,
         endpointConfig = endpointConfig,
         enableEndpoint = true,
+        decodingMethod = "greedy_search",
+        maxActivePaths = 4,
     )
 
     var model = SherpaOnnx(
         assetManager = AssetManager(),
         config = config,
     )
+
     var samples = WaveReader.readWave(
         assetManager = AssetManager(),
         filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav",
     )
 
-    model.decodeSamples(samples!!)
+    model.acceptWaveform(samples!!, sampleRate=16000)
+    while (model.isReady()) {
+        model.decode()
+    }
 
     var tail_paddings = FloatArray(8000) // 0.5 seconds
-    model.decodeSamples(tail_paddings)
-
+    model.acceptWaveform(tail_paddings, sampleRate=16000)
     model.inputFinished()
+    while (model.isReady()) {
+        model.decode()
+    }
+
     println("results: ${model.text}")
 }
@@ -38,3 +38,4 @@ log.txt
 tags
 run-decode-file-python.sh
 android/SherpaOnnx/app/src/main/assets/
+*.ncnn.*
@@ -121,7 +121,10 @@ class MainActivity : AppCompatActivity() {
             val ret = audioRecord?.read(buffer, 0, buffer.size)
             if (ret != null && ret > 0) {
                 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
-                model.decodeSamples(samples)
+                model.acceptWaveform(samples, sampleRate=16000)
+                while (model.isReady()) {
+                    model.decode()
+                }
                 runOnUiThread {
                     val isEndpoint = model.isEndpoint()
                     val text = model.text
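
The hunk above shows only the feeding side of the microphone loop; a sketch of the surrounding loop with endpoint handling might look like the following, where isRecording, audioRecord, buffer, and model are assumed context from the existing activity:

// Sketch only: everything except acceptWaveform/isReady/decode/
// isEndpoint/reset/text is an assumption, not part of this diff.
while (isRecording) {
    val ret = audioRecord?.read(buffer, 0, buffer.size)
    if (ret != null && ret > 0) {
        val samples = FloatArray(ret) { buffer[it] / 32768.0f }
        model.acceptWaveform(samples, sampleRate = 16000)
        while (model.isReady()) {
            model.decode()
        }
        if (model.isEndpoint()) {
            val utterance = model.text // keep the finished sentence
            model.reset()              // start decoding a new utterance
        }
    }
}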
@@ -177,33 +180,17 @@ class MainActivity : AppCompatActivity() {
         val type = 0
         println("Select model type ${type}")
         val config = OnlineRecognizerConfig(
-            featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80),
+            featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
             modelConfig = getModelConfig(type = type)!!,
             endpointConfig = getEndpointConfig(),
-            enableEndpoint = true
+            enableEndpoint = true,
+            decodingMethod = "greedy_search",
+            maxActivePaths = 4,
         )
 
         model = SherpaOnnx(
             assetManager = application.assets,
             config = config,
         )
-        /*
-        println("reading samples")
-        val samples = WaveReader.readWave(
-            assetManager = application.assets,
-            // filename = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav",
-            filename = "sherpa-onnx-lstm-zh-2023-02-20/test_wavs/0.wav",
-            // filename="sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav"
-        )
-        println("samples read done!")
-
-        model.decodeSamples(samples!!)
-
-        val tailPaddings = FloatArray(8000) // 0.5 seconds
-        model.decodeSamples(tailPaddings)
-
-        println("result is: ${model.text}")
-        model.reset()
-        */
     }
 }
@@ -24,7 +24,7 @@ data class OnlineTransducerModelConfig(
 )
 
 data class FeatureConfig(
-    var sampleRate: Float = 16000.0f,
+    var sampleRate: Int = 16000,
     var featureDim: Int = 80,
 )
 
@@ -32,7 +32,9 @@ data class OnlineRecognizerConfig(
     var featConfig: FeatureConfig = FeatureConfig(),
     var modelConfig: OnlineTransducerModelConfig,
     var endpointConfig: EndpointConfig = EndpointConfig(),
-    var enableEndpoint: Boolean,
+    var enableEndpoint: Boolean = true,
+    var decodingMethod: String = "greedy_search",
+    var maxActivePaths: Int = 4,
 )
 
 class SherpaOnnx(
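
These two new fields are what expose beam search to Kotlin: decodingMethod selects the algorithm and maxActivePaths bounds the number of hypotheses kept per step. A hedged example with placeholder model paths (not taken from this PR):

// Example: enable modified beam search; the default shown in the diff is
// "greedy_search", and maxActivePaths is only relevant to beam search.
val config = OnlineRecognizerConfig(
    featConfig = FeatureConfig(sampleRate = 16000, featureDim = 80),
    modelConfig = OnlineTransducerModelConfig(
        encoder = "encoder.onnx", // placeholder paths
        decoder = "decoder.onnx",
        joiner = "joiner.onnx",
        tokens = "tokens.txt",
    ),
    decodingMethod = "modified_beam_search",
    maxActivePaths = 4,
)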
@@ -49,12 +51,14 @@ class SherpaOnnx(
     }
 
 
-    fun decodeSamples(samples: FloatArray) =
-        decodeSamples(ptr, samples, sampleRate = config.featConfig.sampleRate)
+    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
+        acceptWaveform(ptr, samples, sampleRate)
 
     fun inputFinished() = inputFinished(ptr)
     fun reset() = reset(ptr)
+    fun decode() = decode(ptr)
     fun isEndpoint(): Boolean = isEndpoint(ptr)
+    fun isReady(): Boolean = isReady(ptr)
 
     val text: String
         get() = getText(ptr)
@@ -66,11 +70,13 @@ class SherpaOnnx(
         config: OnlineRecognizerConfig,
     ): Long
 
-    private external fun decodeSamples(ptr: Long, samples: FloatArray, sampleRate: Float)
+    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
     private external fun inputFinished(ptr: Long)
     private external fun getText(ptr: Long): String
     private external fun reset(ptr: Long)
+    private external fun decode(ptr: Long)
     private external fun isEndpoint(ptr: Long): Boolean
+    private external fun isReady(ptr: Long): Boolean
 
     companion object {
         init {
@@ -79,7 +85,7 @@ class SherpaOnnx(
     }
 }
 
-fun getFeatureConfig(sampleRate: Float, featureDim: Int): FeatureConfig {
+fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
     return FeatureConfig(sampleRate=sampleRate, featureDim=featureDim)
 }
 
@@ -23,10 +23,10 @@ extension AVAudioPCMBuffer {
 class ViewController: UIViewController {
   @IBOutlet weak var resultLabel: UILabel!
   @IBOutlet weak var recordBtn: UIButton!
-
+
   var audioEngine: AVAudioEngine? = nil
   var recognizer: SherpaOnnxRecognizer! = nil
-
+
   /// It saves the decoded results so far
   var sentences: [String] = [] {
     didSet {
@@ -42,7 +42,7 @@ class ViewController: UIViewController {
     if sentences.isEmpty {
       return "0: \(lastSentence.lowercased())"
     }
-
+
     let start = max(sentences.count - maxSentence, 0)
     if lastSentence.isEmpty {
       return sentences.enumerated().map { (index, s) in "\(index): \(s.lowercased())" }[start...]
@@ -52,23 +52,23 @@ class ViewController: UIViewController {
         .joined(separator: "\n") + "\n\(sentences.count): \(lastSentence.lowercased())"
     }
   }
-
+
   func updateLabel() {
     DispatchQueue.main.async {
       self.resultLabel.text = self.results
     }
   }
-
+
   override func viewDidLoad() {
     super.viewDidLoad()
     // Do any additional setup after loading the view.
-
+
     resultLabel.text = "ASR with Next-gen Kaldi\n\nSee https://github.com/k2-fsa/sherpa-onnx\n\nPress the Start button to run!"
     recordBtn.setTitle("Start", for: .normal)
     initRecognizer()
     initRecorder()
   }
-
+
   @IBAction func onRecordBtnClick(_ sender: UIButton) {
     if recordBtn.currentTitle == "Start" {
       startRecorder()
@@ -78,30 +78,32 @@ class ViewController: UIViewController {
       recordBtn.setTitle("Start", for: .normal)
     }
   }
-
+
   func initRecognizer() {
     // Please select one model that is best suitable for you.
     //
     // You can also modify Model.swift to add new pre-trained models from
     // https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
-
+
     let modelConfig = getBilingualStreamZhEnZipformer20230220()
-
+
     let featConfig = sherpaOnnxFeatureConfig(
       sampleRate: 16000,
       featureDim: 80)
-
+
     var config = sherpaOnnxOnlineRecognizerConfig(
       featConfig: featConfig,
       modelConfig: modelConfig,
       enableEndpoint: true,
       rule1MinTrailingSilence: 2.4,
       rule2MinTrailingSilence: 0.8,
-      rule3MinUtteranceLength: 30
+      rule3MinUtteranceLength: 30,
+      decodingMethod: "greedy_search",
+      maxActivePaths: 4
     )
     recognizer = SherpaOnnxRecognizer(config: &config)
   }
-
+
   func initRecorder() {
     print("init recorder")
     audioEngine = AVAudioEngine()
@@ -112,9 +114,9 @@ class ViewController: UIViewController {
       commonFormat: .pcmFormatFloat32,
       sampleRate: 16000, channels: 1,
       interleaved: false)!
-
+
     let converter = AVAudioConverter(from: inputFormat!, to: outputFormat)!
-
+
     inputNode!.installTap(
       onBus: bus,
       bufferSize: 1024,
@@ -122,34 +124,34 @@ class ViewController: UIViewController {
     ) {
       (buffer: AVAudioPCMBuffer, when: AVAudioTime) in
       var newBufferAvailable = true
-
+
       let inputCallback: AVAudioConverterInputBlock = {
         inNumPackets, outStatus in
         if newBufferAvailable {
           outStatus.pointee = .haveData
           newBufferAvailable = false
-
+
           return buffer
         } else {
          outStatus.pointee = .noDataNow
          return nil
        }
      }
-
+
      let convertedBuffer = AVAudioPCMBuffer(
        pcmFormat: outputFormat,
        frameCapacity:
          AVAudioFrameCount(outputFormat.sampleRate)
          * buffer.frameLength
          / AVAudioFrameCount(buffer.format.sampleRate))!
-
+
      var error: NSError?
      let _ = converter.convert(
        to: convertedBuffer,
        error: &error, withInputFrom: inputCallback)
-
+
      // TODO(fangjun): Handle status != haveData
-
+
      let array = convertedBuffer.array()
      if !array.isEmpty {
        self.recognizer.acceptWaveform(samples: array)
@@ -158,13 +160,13 @@ class ViewController: UIViewController {
      }
      let isEndpoint = self.recognizer.isEndpoint()
      let text = self.recognizer.getResult().text
-
+
      if !text.isEmpty && self.lastSentence != text {
        self.lastSentence = text
        self.updateLabel()
        print(text)
      }
-
+
      if isEndpoint {
        if !text.isEmpty {
          let tmp = self.lastSentence
@@ -175,13 +177,13 @@ class ViewController: UIViewController {
        }
      }
    }
-
+
  }
-
+
  func startRecorder() {
    lastSentence = ""
    sentences = []
-
+
    do {
      try self.audioEngine?.start()
    } catch let error as NSError {
@@ -189,7 +191,7 @@ class ViewController: UIViewController {
    }
    print("started")
  }
-
+
  func stopRecorder() {
    audioEngine?.stop()
    print("stopped")
@@ -76,7 +76,7 @@ SherpaOnnxOnlineStream *CreateOnlineStream(
 
 void DestoryOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; }
 
-void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate,
+void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
                     const float *samples, int32_t n) {
   stream->impl->AcceptWaveform(sample_rate, samples, n);
 }
@@ -120,7 +120,7 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream);
 /// @param samples A pointer to a 1-D array containing audio samples.
 ///                The range of samples has to be normalized to [-1, 1].
 /// @param n Number of elements in the samples array.
-void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate,
+void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
                     const float *samples, int32_t n);
 
 /// Return 1 if there are enough number of feature frames for decoding.
@@ -48,7 +48,7 @@ class FeatureExtractor::Impl {
     fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
   }
 
-  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
+  void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
     std::lock_guard<std::mutex> lock(mutex_);
     fbank_->AcceptWaveform(sampling_rate, waveform, n);
   }
@@ -107,7 +107,7 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
 
 FeatureExtractor::~FeatureExtractor() = default;
 
-void FeatureExtractor::AcceptWaveform(float sampling_rate,
+void FeatureExtractor::AcceptWaveform(int32_t sampling_rate,
                                       const float *waveform, int32_t n) {
   impl_->AcceptWaveform(sampling_rate, waveform, n);
 }
@@ -14,7 +14,7 @@
 namespace sherpa_onnx {
 
 struct FeatureExtractorConfig {
-  float sampling_rate = 16000;
+  int32_t sampling_rate = 16000;
   int32_t feature_dim = 80;
   int32_t max_feature_vectors = -1;
 
@@ -34,7 +34,7 @@ class FeatureExtractor {
      @param waveform Pointer to a 1-D array of size n
      @param n Number of entries in waveform
    */
-  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
+  void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
 
   /**
    * InputFinished() tells the class you won't be providing any
@@ -112,7 +112,7 @@ for a list of pre-trained models to download.
 
   param.suggestedLatency = info->defaultLowInputLatency;
   param.hostApiSpecificStreamInfo = nullptr;
-  const float sample_rate = 16000;
+  float sample_rate = 16000;
 
   PaStream *stream;
   PaError err =
@@ -61,7 +61,7 @@ for a list of pre-trained models to download.
 
   sherpa_onnx::OnlineRecognizer recognizer(config);
 
-  float expected_sampling_rate = config.feat_config.sampling_rate;
+  int32_t expected_sampling_rate = config.feat_config.sampling_rate;
 
   bool is_ok = false;
   std::vector<float> samples =
@@ -72,7 +72,7 @@ for a list of pre-trained models to download.
     return -1;
   }
 
-  float duration = samples.size() / expected_sampling_rate;
+  float duration = samples.size() / static_cast<float>(expected_sampling_rate);
 
   fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
   fprintf(stderr, "wav duration (s): %.3f\n", duration);
@@ -40,19 +40,18 @@ class SherpaOnnx {
             mgr,
 #endif
             config),
-        stream_(recognizer_.CreateStream()),
-        tail_padding_(16000 * 0.32, 0) {
+        stream_(recognizer_.CreateStream()) {
   }
 
-  void DecodeSamples(float sample_rate, const float *samples, int32_t n) const {
+  void AcceptWaveform(int32_t sample_rate, const float *samples,
+                      int32_t n) const {
     stream_->AcceptWaveform(sample_rate, samples, n);
-    Decode();
   }
 
   void InputFinished() const {
-    stream_->AcceptWaveform(16000, tail_padding_.data(), tail_padding_.size());
+    std::vector<float> tail_padding(16000 * 0.32, 0);
+    stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size());
     stream_->InputFinished();
-    Decode();
   }
 
   const std::string GetText() const {
@@ -62,19 +61,15 @@ class SherpaOnnx {
 
   bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
 
+  bool IsReady() const { return recognizer_.IsReady(stream_.get()); }
+
   void Reset() const { return recognizer_.Reset(stream_.get()); }
 
- private:
-  void Decode() const {
-    while (recognizer_.IsReady(stream_.get())) {
-      recognizer_.DecodeStream(stream_.get());
-    }
-  }
+  void Decode() const { recognizer_.DecodeStream(stream_.get()); }
 
 private:
  sherpa_onnx::OnlineRecognizer recognizer_;
  std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
-  std::vector<float> tail_padding_;
};
 
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
@@ -86,14 +81,24 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
  // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
  // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
 
+  //---------- decoding ----------
+  fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
+  jstring s = (jstring)env->GetObjectField(config, fid);
+  const char *p = env->GetStringUTFChars(s, nullptr);
+  ans.decoding_method = p;
+  env->ReleaseStringUTFChars(s, p);
+
+  fid = env->GetFieldID(cls, "maxActivePaths", "I");
+  ans.max_active_paths = env->GetIntField(config, fid);
+
  //---------- feat config ----------
  fid = env->GetFieldID(cls, "featConfig",
                        "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
  jobject feat_config = env->GetObjectField(config, fid);
  jclass feat_config_cls = env->GetObjectClass(feat_config);
 
-  fid = env->GetFieldID(feat_config_cls, "sampleRate", "F");
-  ans.feat_config.sampling_rate = env->GetFloatField(feat_config, fid);
+  fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
+  ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
 
  fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
  ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
@@ -153,8 +158,8 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
  jclass model_config_cls = env->GetObjectClass(model_config);
 
  fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
-  jstring s = (jstring)env->GetObjectField(model_config, fid);
-  const char *p = env->GetStringUTFChars(s, nullptr);
+  s = (jstring)env->GetObjectField(model_config, fid);
+  p = env->GetStringUTFChars(s, nullptr);
  ans.model_config.encoder_filename = p;
  env->ReleaseStringUTFChars(s, p);
 
@@ -198,6 +203,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
 #endif
 
  auto config = sherpa_onnx::GetConfig(env, _config);
+  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  auto model = new sherpa_onnx::SherpaOnnx(
 #if __ANDROID_API__ >= 9
      mgr,
@@ -221,6 +227,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
 }
 
 SHERPA_ONNX_EXTERN_C
+JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
+  return model->IsReady();
+}
+
+SHERPA_ONNX_EXTERN_C
 JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
     JNIEnv *env, jobject /*obj*/, jlong ptr) {
   auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
@@ -228,15 +241,22 @@ JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
 }
 
 SHERPA_ONNX_EXTERN_C
-JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decodeSamples(
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
+  model->Decode();
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform(
     JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
-    jfloat sample_rate) {
+    jint sample_rate) {
   auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
 
   jfloat *p = env->GetFloatArrayElements(samples, nullptr);
   jsize n = env->GetArrayLength(samples);
 
-  model->DecodeSamples(sample_rate, p, n);
+  model->AcceptWaveform(sample_rate, p, n);
 
   env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
 }
@@ -62,11 +62,15 @@ func sherpaOnnxOnlineRecognizerConfig(
   enableEndpoint: Bool = false,
   rule1MinTrailingSilence: Float = 2.4,
   rule2MinTrailingSilence: Float = 1.2,
-  rule3MinUtteranceLength: Float = 30
+  rule3MinUtteranceLength: Float = 30,
+  decodingMethod: String = "greedy_search",
+  maxActivePaths: Int = 4
 ) -> SherpaOnnxOnlineRecognizerConfig{
   return SherpaOnnxOnlineRecognizerConfig(
     feat_config: featConfig,
     model_config: modelConfig,
+    decoding_method: toCPointer(decodingMethod),
+    max_active_paths: Int32(maxActivePaths),
     enable_endpoint: enableEndpoint ? 1 : 0,
     rule1_min_trailing_silence: rule1MinTrailingSilence,
     rule2_min_trailing_silence: rule2MinTrailingSilence,
@@ -128,12 +132,12 @@ class SherpaOnnxRecognizer {
   /// Decode wave samples.
   ///
   /// - Parameters:
-  ///   - samples: Audio samples normalzed to the range [-1, 1]
+  ///   - samples: Audio samples normalized to the range [-1, 1]
   ///   - sampleRate: Sample rate of the input audio samples. Must match
   ///                 the one expected by the model. It must be 16000 for
   ///                 models from icefall.
-  func acceptWaveform(samples: [Float], sampleRate: Float = 16000) {
-    AcceptWaveform(stream, sampleRate, samples, Int32(samples.count))
+  func acceptWaveform(samples: [Float], sampleRate: Int = 16000) {
+    AcceptWaveform(stream, Int32(sampleRate), samples, Int32(samples.count))
   }
 
   func isReady() -> Bool {
@@ -32,7 +32,9 @@ func run() {
   var config = sherpaOnnxOnlineRecognizerConfig(
     featConfig: featConfig,
     modelConfig: modelConfig,
-    enableEndpoint: false
+    enableEndpoint: false,
+    decodingMethod: "modified_beam_search",
+    maxActivePaths: 4
   )
 
 