Fix modified beam search for iOS and android (#76)

* Use Int type for sampling rate
* Fix swift
* Fix iOS

Committed by GitHub

Showing 15 changed files with 125 additions and 93 deletions.
@@ -4,7 +4,7 @@ import android.content.res.AssetManager
 
 fun main() {
     var featConfig = FeatureConfig(
-        sampleRate = 16000.0f,
+        sampleRate = 16000,
         featureDim = 80,
     )
 
@@ -13,7 +13,7 @@ fun main() {
         decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
         joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
         tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
-        numThreads = 4,
+        numThreads = 1,
         debug = false,
     )
 
@@ -24,22 +24,31 @@ fun main() {
         featConfig = featConfig,
         endpointConfig = endpointConfig,
         enableEndpoint = true,
+        decodingMethod = "greedy_search",
+        maxActivePaths = 4,
     )
 
     var model = SherpaOnnx(
         assetManager = AssetManager(),
         config = config,
     )
+
     var samples = WaveReader.readWave(
         assetManager = AssetManager(),
         filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav",
    )
 
-    model.decodeSamples(samples!!)
+    model.acceptWaveform(samples!!, sampleRate=16000)
+    while (model.isReady()) {
+        model.decode()
+    }
 
     var tail_paddings = FloatArray(8000) // 0.5 seconds
-    model.decodeSamples(tail_paddings)
-
+    model.acceptWaveform(tail_paddings, sampleRate=16000)
     model.inputFinished()
+    while (model.isReady()) {
+        model.decode()
+    }
+
     println("results: ${model.text}")
 }
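The Kotlin example above is the heart of the API change: the old push-style model.decodeSamples() call, which fed audio and decoded in one step, is split into acceptWaveform() plus an explicit isReady()/decode() drain loop, with inputFinished() flushing the stream at the end. A minimal sketch of the resulting pattern; the helper name decodeAll is hypothetical, not part of this PR:

// Hypothetical convenience wrapper around the new pull-style API.
fun decodeAll(model: SherpaOnnx, samples: FloatArray, sampleRate: Int = 16000) {
    model.acceptWaveform(samples, sampleRate = sampleRate)
    // decode() consumes one batch of buffered feature frames;
    // isReady() reports whether another full batch is available.
    while (model.isReady()) {
        model.decode()
    }
}

Pulling the loop out of the native layer lets callers decide when decoding happens, which is what the JNI changes further down rely on.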
@@ -121,7 +121,10 @@ class MainActivity : AppCompatActivity() {
             val ret = audioRecord?.read(buffer, 0, buffer.size)
             if (ret != null && ret > 0) {
                 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
-                model.decodeSamples(samples)
+                model.acceptWaveform(samples, sampleRate=16000)
+                while (model.isReady()) {
+                    model.decode()
+                }
                 runOnUiThread {
                     val isEndpoint = model.isEndpoint()
                     val text = model.text
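The microphone path runs the same acceptWaveform/decode drain inside the AudioRecord read loop, after normalizing 16-bit PCM to floats in [-1, 1] by dividing by 32768 (2^15). A standalone sketch of that conversion; the function name is illustrative:

// Convert the first n valid samples of a 16-bit PCM buffer to
// floats in [-1, 1], as expected by acceptWaveform().
fun pcm16ToFloat(buffer: ShortArray, n: Int): FloatArray {
    return FloatArray(n) { i -> buffer[i] / 32768.0f }
}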
@@ -177,33 +180,17 @@ class MainActivity : AppCompatActivity() {
         val type = 0
         println("Select model type ${type}")
         val config = OnlineRecognizerConfig(
-            featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80),
+            featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
             modelConfig = getModelConfig(type = type)!!,
             endpointConfig = getEndpointConfig(),
-            enableEndpoint = true
+            enableEndpoint = true,
+            decodingMethod = "greedy_search",
+            maxActivePaths = 4,
         )
 
         model = SherpaOnnx(
             assetManager = application.assets,
             config = config,
         )
-        /*
-        println("reading samples")
-        val samples = WaveReader.readWave(
-            assetManager = application.assets,
-            // filename = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav",
-            filename = "sherpa-onnx-lstm-zh-2023-02-20/test_wavs/0.wav",
-            // filename="sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav"
-        )
-        println("samples read done!")
-
-        model.decodeSamples(samples!!)
-
-        val tailPaddings = FloatArray(8000) // 0.5 seconds
-        model.decodeSamples(tailPaddings)
-
-        println("result is: ${model.text}")
-        model.reset()
-        */
     }
 }
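With decodingMethod and maxActivePaths now exposed on OnlineRecognizerConfig, switching this activity to the beam search that this PR fixes is a one-field change. A sketch, assuming the same helper functions used above:

val config = OnlineRecognizerConfig(
    featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
    modelConfig = getModelConfig(type = 0)!!,
    endpointConfig = getEndpointConfig(),
    enableEndpoint = true,
    decodingMethod = "modified_beam_search",  // instead of "greedy_search"
    maxActivePaths = 4,  // number of hypotheses kept by modified_beam_search
)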
@@ -24,7 +24,7 @@ data class OnlineTransducerModelConfig(
 )
 
 data class FeatureConfig(
-    var sampleRate: Float = 16000.0f,
+    var sampleRate: Int = 16000,
     var featureDim: Int = 80,
 )
 
@@ -32,7 +32,9 @@ data class OnlineRecognizerConfig(
     var featConfig: FeatureConfig = FeatureConfig(),
     var modelConfig: OnlineTransducerModelConfig,
     var endpointConfig: EndpointConfig = EndpointConfig(),
-    var enableEndpoint: Boolean,
+    var enableEndpoint: Boolean = true,
+    var decodingMethod: String = "greedy_search",
+    var maxActivePaths: Int = 4,
 )
 
 class SherpaOnnx(
@@ -49,12 +51,14 @@ class SherpaOnnx(
     }
 
 
-    fun decodeSamples(samples: FloatArray) =
-        decodeSamples(ptr, samples, sampleRate = config.featConfig.sampleRate)
+    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
+        acceptWaveform(ptr, samples, sampleRate)
 
     fun inputFinished() = inputFinished(ptr)
     fun reset() = reset(ptr)
+    fun decode() = decode(ptr)
     fun isEndpoint(): Boolean = isEndpoint(ptr)
+    fun isReady(): Boolean = isReady(ptr)
 
     val text: String
         get() = getText(ptr)
@@ -66,11 +70,13 @@ class SherpaOnnx(
         config: OnlineRecognizerConfig,
     ): Long
 
-    private external fun decodeSamples(ptr: Long, samples: FloatArray, sampleRate: Float)
+    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
     private external fun inputFinished(ptr: Long)
     private external fun getText(ptr: Long): String
     private external fun reset(ptr: Long)
+    private external fun decode(ptr: Long)
     private external fun isEndpoint(ptr: Long): Boolean
+    private external fun isReady(ptr: Long): Boolean
 
     companion object {
         init {
@@ -79,7 +85,7 @@ class SherpaOnnx(
     }
 }
 
-fun getFeatureConfig(sampleRate: Float, featureDim: Int): FeatureConfig {
+fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
     return FeatureConfig(sampleRate=sampleRate, featureDim=featureDim)
 }
 
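Because every field added to OnlineRecognizerConfig carries a default (and enableEndpoint gains one), existing Kotlin call sites compile unchanged and silently keep greedy search. Two examples of building a config under the new definition, assuming a modelConfig is already in scope:

// Falls back to the new defaults: enableEndpoint = true,
// decodingMethod = "greedy_search", maxActivePaths = 4.
val defaultConfig = OnlineRecognizerConfig(modelConfig = modelConfig)

// Opts in to the decoding method this PR fixes.
val beamConfig = OnlineRecognizerConfig(
    modelConfig = modelConfig,
    decodingMethod = "modified_beam_search",
)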
@@ -78,30 +78,32 @@ class ViewController: UIViewController {
             recordBtn.setTitle("Start", for: .normal)
         }
     }
-    
+
     func initRecognizer() {
         // Please select one model that is best suitable for you.
         //
         // You can also modify Model.swift to add new pre-trained models from
         // https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
-    
+
         let modelConfig = getBilingualStreamZhEnZipformer20230220()
-    
+
         let featConfig = sherpaOnnxFeatureConfig(
             sampleRate: 16000,
             featureDim: 80)
-    
+
         var config = sherpaOnnxOnlineRecognizerConfig(
             featConfig: featConfig,
             modelConfig: modelConfig,
             enableEndpoint: true,
             rule1MinTrailingSilence: 2.4,
             rule2MinTrailingSilence: 0.8,
-            rule3MinUtteranceLength: 30
+            rule3MinUtteranceLength: 30,
+            decodingMethod: "greedy_search",
+            maxActivePaths: 4
         )
         recognizer = SherpaOnnxRecognizer(config: &config)
     }
-    
+
     func initRecorder() {
         print("init recorder")
         audioEngine = AVAudioEngine()

(The remaining hunks in this Swift view controller, @@ -23,10 +23,10 @@, @@ -42,7 +42,7 @@, @@ -52,23 +52,23 @@, @@ -112,9 +114,9 @@, @@ -122,34 +124,34 @@, @@ -158,13 +160,13 @@, @@ -175,13 +177,13 @@, and @@ -189,7 +191,7 @@, only strip trailing whitespace.)
@@ -76,7 +76,7 @@ SherpaOnnxOnlineStream *CreateOnlineStream(
 
 void DestoryOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; }
 
-void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate,
+void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
                     const float *samples, int32_t n) {
   stream->impl->AcceptWaveform(sample_rate, samples, n);
 }
@@ -120,7 +120,7 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream);
 /// @param samples A pointer to a 1-D array containing audio samples.
 ///                The range of samples has to be normalized to [-1, 1].
 /// @param n Number of elements in the samples array.
-void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate,
+void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
                     const float *samples, int32_t n);
 
 /// Return 1 if there are enough number of feature frames for decoding.
@@ -48,7 +48,7 @@ class FeatureExtractor::Impl {
     fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
   }
 
-  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
+  void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
     std::lock_guard<std::mutex> lock(mutex_);
     fbank_->AcceptWaveform(sampling_rate, waveform, n);
   }
@@ -107,7 +107,7 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
 
 FeatureExtractor::~FeatureExtractor() = default;
 
-void FeatureExtractor::AcceptWaveform(float sampling_rate,
+void FeatureExtractor::AcceptWaveform(int32_t sampling_rate,
                                       const float *waveform, int32_t n) {
   impl_->AcceptWaveform(sampling_rate, waveform, n);
 }
@@ -14,7 +14,7 @@
 namespace sherpa_onnx {
 
 struct FeatureExtractorConfig {
-  float sampling_rate = 16000;
+  int32_t sampling_rate = 16000;
   int32_t feature_dim = 80;
   int32_t max_feature_vectors = -1;
 
@@ -34,7 +34,7 @@ class FeatureExtractor {
      @param waveform Pointer to a 1-D array of size n
      @param n Number of entries in waveform
    */
-  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
+  void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
 
   /**
    * InputFinished() tells the class you won't be providing any
@@ -112,7 +112,7 @@ for a list of pre-trained models to download.
 
   param.suggestedLatency = info->defaultLowInputLatency;
   param.hostApiSpecificStreamInfo = nullptr;
-  const float sample_rate = 16000;
+  float sample_rate = 16000;
 
   PaStream *stream;
   PaError err =
@@ -61,7 +61,7 @@ for a list of pre-trained models to download.
 
   sherpa_onnx::OnlineRecognizer recognizer(config);
 
-  float expected_sampling_rate = config.feat_config.sampling_rate;
+  int32_t expected_sampling_rate = config.feat_config.sampling_rate;
 
   bool is_ok = false;
   std::vector<float> samples =
@@ -72,7 +72,7 @@ for a list of pre-trained models to download.
     return -1;
   }
 
-  float duration = samples.size() / expected_sampling_rate;
+  float duration = samples.size() / static_cast<float>(expected_sampling_rate);
 
   fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
   fprintf(stderr, "wav duration (s): %.3f\n", duration);
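The duration fix matters because expected_sampling_rate is now an integer: dividing the sample count by it performs integer division, so a 1.5-second file would report 1 second. The same truncation happens in Kotlin; a quick illustration:

val samples = FloatArray(24000)  // 1.5 s of audio at 16 kHz
val sampleRate = 16000

val truncated = samples.size / sampleRate           // Int / Int -> 1
val duration = samples.size.toFloat() / sampleRate  // -> 1.5f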
@@ -40,19 +40,18 @@ class SherpaOnnx {
             mgr,
 #endif
             config),
-        stream_(recognizer_.CreateStream()),
-        tail_padding_(16000 * 0.32, 0) {
+        stream_(recognizer_.CreateStream()) {
   }
 
-  void DecodeSamples(float sample_rate, const float *samples, int32_t n) const {
+  void AcceptWaveform(int32_t sample_rate, const float *samples,
+                      int32_t n) const {
     stream_->AcceptWaveform(sample_rate, samples, n);
-    Decode();
   }
 
   void InputFinished() const {
-    stream_->AcceptWaveform(16000, tail_padding_.data(), tail_padding_.size());
+    std::vector<float> tail_padding(16000 * 0.32, 0);
+    stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size());
     stream_->InputFinished();
-    Decode();
   }
 
   const std::string GetText() const {
@@ -62,19 +61,15 @@ class SherpaOnnx {
 
   bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
 
+  bool IsReady() const { return recognizer_.IsReady(stream_.get()); }
+
   void Reset() const { return recognizer_.Reset(stream_.get()); }
 
- private:
-  void Decode() const {
-    while (recognizer_.IsReady(stream_.get())) {
-      recognizer_.DecodeStream(stream_.get());
-    }
-  }
+  void Decode() const { recognizer_.DecodeStream(stream_.get()); }
 
  private:
  sherpa_onnx::OnlineRecognizer recognizer_;
  std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
-  std::vector<float> tail_padding_;
 };
 
@@ -86,14 +81,24 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
   // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
   // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
 
+  //---------- decoding ----------
+  fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
+  jstring s = (jstring)env->GetObjectField(config, fid);
+  const char *p = env->GetStringUTFChars(s, nullptr);
+  ans.decoding_method = p;
+  env->ReleaseStringUTFChars(s, p);
+
+  fid = env->GetFieldID(cls, "maxActivePaths", "I");
+  ans.max_active_paths = env->GetIntField(config, fid);
+
   //---------- feat config ----------
   fid = env->GetFieldID(cls, "featConfig",
                         "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
   jobject feat_config = env->GetObjectField(config, fid);
   jclass feat_config_cls = env->GetObjectClass(feat_config);
 
-  fid = env->GetFieldID(feat_config_cls, "sampleRate", "F");
-  ans.feat_config.sampling_rate = env->GetFloatField(feat_config, fid);
+  fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
+  ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
 
   fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
   ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
@@ -153,8 +158,8 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
   jclass model_config_cls = env->GetObjectClass(model_config);
 
   fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
-  jstring s = (jstring)env->GetObjectField(model_config, fid);
-  const char *p = env->GetStringUTFChars(s, nullptr);
+  s = (jstring)env->GetObjectField(model_config, fid);
+  p = env->GetStringUTFChars(s, nullptr);
   ans.model_config.encoder_filename = p;
   env->ReleaseStringUTFChars(s, p);
 
@@ -198,6 +203,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
 #endif
 
   auto config = sherpa_onnx::GetConfig(env, _config);
+  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
   auto model = new sherpa_onnx::SherpaOnnx(
 #if __ANDROID_API__ >= 9
       mgr,
@@ -221,6 +227,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
 }
 
 SHERPA_ONNX_EXTERN_C
+JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
+  return model->IsReady();
+}
+
+SHERPA_ONNX_EXTERN_C
 JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
     JNIEnv *env, jobject /*obj*/, jlong ptr) {
   auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
@@ -228,15 +241,22 @@ JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
 }
 
 SHERPA_ONNX_EXTERN_C
-JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decodeSamples(
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
+  model->Decode();
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform(
     JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
-    jfloat sample_rate) {
+    jint sample_rate) {
   auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
 
   jfloat *p = env->GetFloatArrayElements(samples, nullptr);
   jsize n = env->GetArrayLength(samples);
 
-  model->DecodeSamples(sample_rate, p, n);
+  model->AcceptWaveform(sample_rate, p, n);
 
   env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
 }
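Each native function above is located through the standard JNI naming scheme, Java_<package with underscores>_<class>_<method>, so renaming decodeSamples to acceptWaveform on the C++ side only works because the Kotlin external declarations were renamed in lockstep. A minimal sketch of the Kotlin half of that contract; the library name is an assumption, not taken from this diff:

class SherpaOnnx {
    // Each declaration resolves to the matching
    // Java_com_k2fsa_sherpa_onnx_SherpaOnnx_* symbol in the shared library.
    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
    private external fun decode(ptr: Long)
    private external fun isReady(ptr: Long): Boolean

    companion object {
        init {
            // Assumed library name, for illustration only.
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}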
@@ -62,11 +62,15 @@ func sherpaOnnxOnlineRecognizerConfig(
   enableEndpoint: Bool = false,
   rule1MinTrailingSilence: Float = 2.4,
   rule2MinTrailingSilence: Float = 1.2,
-  rule3MinUtteranceLength: Float = 30
+  rule3MinUtteranceLength: Float = 30,
+  decodingMethod: String = "greedy_search",
+  maxActivePaths: Int = 4
 ) -> SherpaOnnxOnlineRecognizerConfig{
   return SherpaOnnxOnlineRecognizerConfig(
     feat_config: featConfig,
     model_config: modelConfig,
+    decoding_method: toCPointer(decodingMethod),
+    max_active_paths: Int32(maxActivePaths),
     enable_endpoint: enableEndpoint ? 1 : 0,
     rule1_min_trailing_silence: rule1MinTrailingSilence,
     rule2_min_trailing_silence: rule2MinTrailingSilence,
@@ -128,12 +132,12 @@ class SherpaOnnxRecognizer {
   /// Decode wave samples.
   ///
   /// - Parameters:
-  ///   - samples: Audio samples normalzed to the range [-1, 1]
+  ///   - samples: Audio samples normalized to the range [-1, 1]
   ///   - sampleRate: Sample rate of the input audio samples. Must match
   ///                 the one expected by the model. It must be 16000 for
   ///                 models from icefall.
-  func acceptWaveform(samples: [Float], sampleRate: Float = 16000) {
-    AcceptWaveform(stream, sampleRate, samples, Int32(samples.count))
+  func acceptWaveform(samples: [Float], sampleRate: Int = 16000) {
+    AcceptWaveform(stream, Int32(sampleRate), samples, Int32(samples.count))
   }
 
   func isReady() -> Bool {
@@ -32,7 +32,9 @@ func run() {
   var config = sherpaOnnxOnlineRecognizerConfig(
     featConfig: featConfig,
     modelConfig: modelConfig,
-    enableEndpoint: false
+    enableEndpoint: false,
+    decodingMethod: "modified_beam_search",
+    maxActivePaths: 4
   )
 
 