yujinqiu
Committed by GitHub

Add vad clear api for better performance (#366)

* Add vad clear api for better performance

* rename to make naming consistent and remove macro

* Fix linker error

* Fix Vad.kt
@@ -161,9 +161,9 @@ class MainActivity : AppCompatActivity() { @@ -161,9 +161,9 @@ class MainActivity : AppCompatActivity() {
161 val samples = FloatArray(ret) { buffer[it] / 32768.0f } 161 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
162 162
163 vad.acceptWaveform(samples) 163 vad.acceptWaveform(samples)
164 - while(!vad.empty()) {vad.pop();}  
165 164
166 val isSpeechDetected = vad.isSpeechDetected() 165 val isSpeechDetected = vad.isSpeechDetected()
  166 + vad.clear()
167 167
168 runOnUiThread { 168 runOnUiThread {
169 onVad(isSpeechDetected) 169 onVad(isSpeechDetected)
@@ -46,6 +46,8 @@ class Vad( @@ -46,6 +46,8 @@ class Vad(
46 // [start: Int, samples: FloatArray] 46 // [start: Int, samples: FloatArray]
47 fun front() = front(ptr) 47 fun front() = front(ptr)
48 48
  49 + fun clear() = clear(ptr)
  50 +
49 fun isSpeechDetected(): Boolean = isSpeechDetected(ptr) 51 fun isSpeechDetected(): Boolean = isSpeechDetected(ptr)
50 52
51 fun reset() = reset(ptr) 53 fun reset() = reset(ptr)
@@ -64,6 +66,7 @@ class Vad( @@ -64,6 +66,7 @@ class Vad(
64 private external fun acceptWaveform(ptr: Long, samples: FloatArray) 66 private external fun acceptWaveform(ptr: Long, samples: FloatArray)
65 private external fun empty(ptr: Long): Boolean 67 private external fun empty(ptr: Long): Boolean
66 private external fun pop(ptr: Long) 68 private external fun pop(ptr: Long)
  69 + private external fun clear(ptr: Long)
67 private external fun front(ptr: Long): Array<Any> 70 private external fun front(ptr: Long): Array<Any>
68 private external fun isSpeechDetected(ptr: Long): Boolean 71 private external fun isSpeechDetected(ptr: Long): Boolean
69 private external fun reset(ptr: Long) 72 private external fun reset(ptr: Long)
@@ -493,12 +493,17 @@ int32_t SherpaOnnxVoiceActivityDetectorDetected( @@ -493,12 +493,17 @@ int32_t SherpaOnnxVoiceActivityDetectorDetected(
493 return p->impl->IsSpeechDetected(); 493 return p->impl->IsSpeechDetected();
494 } 494 }
495 495
496 -SHERPA_ONNX_API void SherpaOnnxVoiceActivityDetectorPop( 496 +void SherpaOnnxVoiceActivityDetectorPop(
497 SherpaOnnxVoiceActivityDetector *p) { 497 SherpaOnnxVoiceActivityDetector *p) {
498 p->impl->Pop(); 498 p->impl->Pop();
499 } 499 }
500 500
501 -SHERPA_ONNX_API const SherpaOnnxSpeechSegment * 501 +void SherpaOnnxVoiceActivityDetectorClear(
  502 + SherpaOnnxVoiceActivityDetector *p) {
  503 + p->impl->Clear();
  504 +}
  505 +
  506 +const SherpaOnnxSpeechSegment *
502 SherpaOnnxVoiceActivityDetectorFront(SherpaOnnxVoiceActivityDetector *p) { 507 SherpaOnnxVoiceActivityDetectorFront(SherpaOnnxVoiceActivityDetector *p) {
503 const sherpa_onnx::SpeechSegment &segment = p->impl->Front(); 508 const sherpa_onnx::SpeechSegment &segment = p->impl->Front();
504 509
@@ -580,6 +580,10 @@ SherpaOnnxVoiceActivityDetectorDetected(SherpaOnnxVoiceActivityDetector *p); @@ -580,6 +580,10 @@ SherpaOnnxVoiceActivityDetectorDetected(SherpaOnnxVoiceActivityDetector *p);
580 SHERPA_ONNX_API void SherpaOnnxVoiceActivityDetectorPop( 580 SHERPA_ONNX_API void SherpaOnnxVoiceActivityDetectorPop(
581 SherpaOnnxVoiceActivityDetector *p); 581 SherpaOnnxVoiceActivityDetector *p);
582 582
  583 +// Clear current speech segments.
  584 +SHERPA_ONNX_API void SherpaOnnxVoiceActivityDetectorClear(
  585 + SherpaOnnxVoiceActivityDetector *p);
  586 +
583 // Return the first speech segment. 587 // Return the first speech segment.
584 // The user has to use SherpaOnnxDestroySpeechSegment() to free the returned 588 // The user has to use SherpaOnnxDestroySpeechSegment() to free the returned
585 // pointer to avoid memory leak. 589 // pointer to avoid memory leak.
@@ -76,6 +76,8 @@ class VoiceActivityDetector::Impl { @@ -76,6 +76,8 @@ class VoiceActivityDetector::Impl {
76 76
77 void Pop() { segments_.pop(); } 77 void Pop() { segments_.pop(); }
78 78
  79 + void Clear() { std::queue<SpeechSegment>().swap(segments_); }
  80 +
79 const SpeechSegment &Front() const { return segments_.front(); } 81 const SpeechSegment &Front() const { return segments_.front(); }
80 82
81 void Reset() { 83 void Reset() {
@@ -121,6 +123,8 @@ bool VoiceActivityDetector::Empty() const { return impl_->Empty(); } @@ -121,6 +123,8 @@ bool VoiceActivityDetector::Empty() const { return impl_->Empty(); }
121 123
122 void VoiceActivityDetector::Pop() { impl_->Pop(); } 124 void VoiceActivityDetector::Pop() { impl_->Pop(); }
123 125
  126 +void VoiceActivityDetector::Clear() { impl_->Clear(); }
  127 +
124 const SpeechSegment &VoiceActivityDetector::Front() const { 128 const SpeechSegment &VoiceActivityDetector::Front() const {
125 return impl_->Front(); 129 return impl_->Front();
126 } 130 }
@@ -36,6 +36,7 @@ class VoiceActivityDetector { @@ -36,6 +36,7 @@ class VoiceActivityDetector {
36 void AcceptWaveform(const float *samples, int32_t n); 36 void AcceptWaveform(const float *samples, int32_t n);
37 bool Empty() const; 37 bool Empty() const;
38 void Pop(); 38 void Pop();
  39 + void Clear();
39 const SpeechSegment &Front() const; 40 const SpeechSegment &Front() const;
40 41
41 bool IsSpeechDetected() const; 42 bool IsSpeechDetected() const;
@@ -124,6 +124,8 @@ class SherpaOnnxVad { @@ -124,6 +124,8 @@ class SherpaOnnxVad {
124 124
125 void Pop() { vad_.Pop(); } 125 void Pop() { vad_.Pop(); }
126 126
  127 + void Clear() { vad_.Clear();}
  128 +
127 const SpeechSegment &Front() const { return vad_.Front(); } 129 const SpeechSegment &Front() const { return vad_.Front(); }
128 130
129 bool IsSpeechDetected() const { return vad_.IsSpeechDetected(); } 131 bool IsSpeechDetected() const { return vad_.IsSpeechDetected(); }
@@ -556,6 +558,14 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env, @@ -556,6 +558,14 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env,
556 model->Pop(); 558 model->Pop();
557 } 559 }
558 560
  561 +SHERPA_ONNX_EXTERN_C
  562 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env,
  563 + jobject /*obj*/,
  564 + jlong ptr) {
  565 + auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
  566 + model->Clear();
  567 +}
  568 +
559 // see 569 // see
560 // https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables 570 // https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
561 static jobject NewInteger(JNIEnv *env, int32_t value) { 571 static jobject NewInteger(JNIEnv *env, int32_t value) {
@@ -551,7 +551,7 @@ class SherpaOnnxVoiceActivityDetectorWrapper { @@ -551,7 +551,7 @@ class SherpaOnnxVoiceActivityDetectorWrapper {
551 return SherpaOnnxVoiceActivityDetectorEmpty(vad) == 1 551 return SherpaOnnxVoiceActivityDetectorEmpty(vad) == 1
552 } 552 }
553 553
554 - func isDetected() -> Bool { 554 + func isSpeechDetected() -> Bool {
555 return SherpaOnnxVoiceActivityDetectorDetected(vad) == 1 555 return SherpaOnnxVoiceActivityDetectorDetected(vad) == 1
556 } 556 }
557 557
@@ -559,6 +559,10 @@ class SherpaOnnxVoiceActivityDetectorWrapper { @@ -559,6 +559,10 @@ class SherpaOnnxVoiceActivityDetectorWrapper {
559 SherpaOnnxVoiceActivityDetectorPop(vad) 559 SherpaOnnxVoiceActivityDetectorPop(vad)
560 } 560 }
561 561
  562 + func clear() {
  563 + SherpaOnnxVoiceActivityDetectorClear(vad)
  564 + }
  565 +
562 func front() -> SherpaOnnxSpeechSegmentWrapper { 566 func front() -> SherpaOnnxSpeechSegmentWrapper {
563 let p: UnsafePointer<SherpaOnnxSpeechSegment>? = SherpaOnnxVoiceActivityDetectorFront(vad) 567 let p: UnsafePointer<SherpaOnnxSpeechSegment>? = SherpaOnnxVoiceActivityDetectorFront(vad)
564 return SherpaOnnxSpeechSegmentWrapper(p: p) 568 return SherpaOnnxSpeechSegmentWrapper(p: p)
@@ -174,32 +174,31 @@ func run() { @@ -174,32 +174,31 @@ func run() {
174 174
175 var segments: [SpeechSegment] = [] 175 var segments: [SpeechSegment] = []
176 176
177 - while array.count > windowSize {  
178 - // todo(fangjun): avoid extra copies here  
179 - vad.acceptWaveform(samples: [Float](array[0..<windowSize]))  
180 - array = [Float](array[windowSize..<array.count])  
181 -  
182 - while !vad.isEmpty() {  
183 - let s = vad.front()  
184 - vad.pop()  
185 - let result = recognizer.decode(samples: s.samples) 177 + for offset in stride(from: 0, to: array.count, by: windowSize) {
  178 + let end = min(offset + windowSize, array.count)
  179 + vad.acceptWaveform(samples: [Float](array[offset ..< end]))
  180 + }
186 181
187 - segments.append(  
188 - SpeechSegment(  
189 - start: Float(s.start) / Float(sampleRate),  
190 - duration: Float(s.samples.count) / Float(sampleRate),  
191 - text: result.text)) 182 + var index: Int = 0
  183 + while !vad.isEmpty() {
  184 + let s = vad.front()
  185 + vad.pop()
  186 + let result = recognizer.decode(samples: s.samples)
192 187
193 - print(segments.last!) 188 + segments.append(
  189 + SpeechSegment(
  190 + start: Float(s.start) / Float(sampleRate),
  191 + duration: Float(s.samples.count) / Float(sampleRate),
  192 + text: result.text))
194 193
195 - } 194 + print(segments.last!)
196 } 195 }
197 196
198 - let srt = zip(segments.indices, segments).map { (index, element) in 197 + let srt: String = zip(segments.indices, segments).map { (index, element) in
199 return "\(index+1)\n\(element)" 198 return "\(index+1)\n\(element)"
200 }.joined(separator: "\n\n") 199 }.joined(separator: "\n\n")
201 200
202 - let srtFilename = filePath.stringByDeletingPathExtension + ".srt" 201 + let srtFilename: String = filePath.stringByDeletingPathExtension + ".srt"
203 do { 202 do {
204 try srt.write(to: srtFilename.fileURL, atomically: true, encoding: .utf8) 203 try srt.write(to: srtFilename.fileURL, atomically: true, encoding: .utf8)
205 } catch { 204 } catch {