Fangjun Kuang
Committed by GitHub

Add Dart API for ten-vad (#2386)

@@ -4,6 +4,12 @@ set -ex @@ -4,6 +4,12 @@ set -ex
4 4
5 cd dart-api-examples 5 cd dart-api-examples
6 6
  7 +pushd vad
  8 +./run-ten-vad.sh
  9 +./run.sh
  10 +rm *.onnx
  11 +popd
  12 +
7 pushd non-streaming-asr 13 pushd non-streaming-asr
8 14
9 echo '----------Zipformer CTC----------' 15 echo '----------Zipformer CTC----------'
@@ -186,9 +192,3 @@ echo '----------streaming paraformer----------' @@ -186,9 +192,3 @@ echo '----------streaming paraformer----------'
186 rm -rf sherpa-onnx-* 192 rm -rf sherpa-onnx-*
187 193
188 popd # streaming-asr 194 popd # streaming-asr
189 -  
190 -pushd vad  
191 -./run.sh  
192 -rm *.onnx  
193 -popd  
194 -  
  1 +// Copyright (c) 2024 Xiaomi Corporation
  2 +import 'dart:io';
  3 +import 'dart:typed_data';
  4 +
  5 +import 'package:args/args.dart';
  6 +import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
  7 +import './init.dart';
  8 +
  9 +void main(List<String> arguments) async {
  10 + await initSherpaOnnx();
  11 +
  12 + final parser = ArgParser()
  13 + ..addOption('ten-vad', help: 'Path to ten-vad.onnx')
  14 + ..addOption('input-wav', help: 'Path to input.wav')
  15 + ..addOption('output-wav', help: 'Path to output.wav');
  16 +
  17 + final res = parser.parse(arguments);
  18 + if (res['ten-vad'] == null ||
  19 + res['input-wav'] == null ||
  20 + res['output-wav'] == null) {
  21 + print(parser.usage);
  22 + exit(1);
  23 + }
  24 +
  25 + final tenVad = res['ten-vad'] as String;
  26 + final inputWav = res['input-wav'] as String;
  27 + final outputWav = res['output-wav'] as String;
  28 +
  29 + final tenVadConfig = sherpa_onnx.TenVadModelConfig(
  30 + model: tenVad,
  31 + threshold: 0.25,
  32 + minSilenceDuration: 0.25,
  33 + minSpeechDuration: 0.5,
  34 + windowSize: 256,
  35 + );
  36 +
  37 + final config = sherpa_onnx.VadModelConfig(
  38 + tenVad: tenVadConfig,
  39 + numThreads: 1,
  40 + debug: true,
  41 + );
  42 +
  43 + final vad = sherpa_onnx.VoiceActivityDetector(
  44 + config: config, bufferSizeInSeconds: 10);
  45 +
  46 + final waveData = sherpa_onnx.readWave(inputWav);
  47 + if (waveData.sampleRate != 16000) {
  48 + print('Only 16000 Hz is supported. Given: ${waveData.sampleRate}');
  49 + exit(1);
  50 + }
  51 +
  52 + int numSamples = waveData.samples.length;
  53 + int numIter = numSamples ~/ config.tenVad.windowSize;
  54 +
  55 + List<List<double>> allSamples = [];
  56 +
  57 + for (int i = 0; i != numIter; ++i) {
  58 + int start = i * config.tenVad.windowSize;
  59 + vad.acceptWaveform(Float32List.sublistView(
  60 + waveData.samples, start, start + config.tenVad.windowSize));
  61 +
  62 + if (vad.isDetected()) {
  63 + while (!vad.isEmpty()) {
  64 + allSamples.add(vad.front().samples);
  65 + vad.pop();
  66 + }
  67 + }
  68 + }
  69 +
  70 + vad.flush();
  71 + while (!vad.isEmpty()) {
  72 + allSamples.add(vad.front().samples);
  73 + vad.pop();
  74 + }
  75 +
  76 + vad.free();
  77 +
  78 + final s = Float32List.fromList(allSamples.expand((x) => x).toList());
  79 + sherpa_onnx.writeWave(
  80 + filename: outputWav, samples: s, sampleRate: waveData.sampleRate);
  81 +
  82 + print('Saved to $outputWav');
  83 +}
  1 +#!/usr/bin/env bash
  2 +
  3 +set -ex
  4 +
  5 +dart pub get
  6 +
  7 +
  8 +if [[ ! -f ./ten-vad.onnx ]]; then
  9 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/ten-vad.onnx
  10 +fi
  11 +
  12 +if [[ ! -f ./lei-jun-test.wav ]]; then
  13 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav
  14 +fi
  15 +
  16 +dart run \
  17 + ./bin/ten-vad.dart \
  18 + --ten-vad ./ten-vad.onnx \
  19 + --input-wav ./lei-jun-test.wav \
  20 + --output-wav ./lei-jun-test-no-silence.wav
  21 +
  22 +ls -lh *.wav
@@ -487,6 +487,25 @@ final class SherpaOnnxSileroVadModelConfig extends Struct { @@ -487,6 +487,25 @@ final class SherpaOnnxSileroVadModelConfig extends Struct {
487 external double maxSpeechDuration; 487 external double maxSpeechDuration;
488 } 488 }
489 489
  490 +final class SherpaOnnxTenVadModelConfig extends Struct {
  491 + external Pointer<Utf8> model;
  492 +
  493 + @Float()
  494 + external double threshold;
  495 +
  496 + @Float()
  497 + external double minSilenceDuration;
  498 +
  499 + @Float()
  500 + external double minSpeechDuration;
  501 +
  502 + @Int32()
  503 + external int windowSize;
  504 +
  505 + @Float()
  506 + external double maxSpeechDuration;
  507 +}
  508 +
490 final class SherpaOnnxVadModelConfig extends Struct { 509 final class SherpaOnnxVadModelConfig extends Struct {
491 external SherpaOnnxSileroVadModelConfig sileroVad; 510 external SherpaOnnxSileroVadModelConfig sileroVad;
492 511
@@ -500,6 +519,8 @@ final class SherpaOnnxVadModelConfig extends Struct { @@ -500,6 +519,8 @@ final class SherpaOnnxVadModelConfig extends Struct {
500 519
501 @Int32() 520 @Int32()
502 external int debug; 521 external int debug;
  522 +
  523 + external SherpaOnnxTenVadModelConfig tenVad;
503 } 524 }
504 525
505 final class SherpaOnnxSpeechSegment extends Struct { 526 final class SherpaOnnxSpeechSegment extends Struct {
@@ -49,6 +49,50 @@ class SileroVadModelConfig { @@ -49,6 +49,50 @@ class SileroVadModelConfig {
49 final double maxSpeechDuration; 49 final double maxSpeechDuration;
50 } 50 }
51 51
  52 +class TenVadModelConfig {
  53 + const TenVadModelConfig(
  54 + {this.model = '',
  55 + this.threshold = 0.5,
  56 + this.minSilenceDuration = 0.5,
  57 + this.minSpeechDuration = 0.25,
  58 + this.windowSize = 256,
  59 + this.maxSpeechDuration = 5.0});
  60 +
  61 + factory TenVadModelConfig.fromJson(Map<String, dynamic> json) {
  62 + return TenVadModelConfig(
  63 + model: json['model'] as String? ?? '',
  64 + threshold: (json['threshold'] as num?)?.toDouble() ?? 0.5,
  65 + minSilenceDuration:
  66 + (json['minSilenceDuration'] as num?)?.toDouble() ?? 0.5,
  67 + minSpeechDuration:
  68 + (json['minSpeechDuration'] as num?)?.toDouble() ?? 0.25,
  69 + windowSize: json['windowSize'] as int? ?? 256,
  70 + maxSpeechDuration: (json['maxSpeechDuration'] as num?)?.toDouble() ?? 5.0,
  71 + );
  72 + }
  73 +
  74 + @override
  75 + String toString() {
  76 + return 'TenVadModelConfig(model: $model, threshold: $threshold, minSilenceDuration: $minSilenceDuration, minSpeechDuration: $minSpeechDuration, windowSize: $windowSize, maxSpeechDuration: $maxSpeechDuration)';
  77 + }
  78 +
  79 + Map<String, dynamic> toJson() => {
  80 + 'model': model,
  81 + 'threshold': threshold,
  82 + 'minSilenceDuration': minSilenceDuration,
  83 + 'minSpeechDuration': minSpeechDuration,
  84 + 'windowSize': windowSize,
  85 + 'maxSpeechDuration': maxSpeechDuration,
  86 + };
  87 +
  88 + final String model;
  89 + final double threshold;
  90 + final double minSilenceDuration;
  91 + final double minSpeechDuration;
  92 + final int windowSize;
  93 + final double maxSpeechDuration;
  94 +}
  95 +
52 class VadModelConfig { 96 class VadModelConfig {
53 VadModelConfig({ 97 VadModelConfig({
54 this.sileroVad = const SileroVadModelConfig(), 98 this.sileroVad = const SileroVadModelConfig(),
@@ -56,9 +100,11 @@ class VadModelConfig { @@ -56,9 +100,11 @@ class VadModelConfig {
56 this.numThreads = 1, 100 this.numThreads = 1,
57 this.provider = 'cpu', 101 this.provider = 'cpu',
58 this.debug = true, 102 this.debug = true,
  103 + this.tenVad = const TenVadModelConfig(),
59 }); 104 });
60 105
61 final SileroVadModelConfig sileroVad; 106 final SileroVadModelConfig sileroVad;
  107 + final TenVadModelConfig tenVad;
62 final int sampleRate; 108 final int sampleRate;
63 final int numThreads; 109 final int numThreads;
64 final String provider; 110 final String provider;
@@ -68,6 +114,8 @@ class VadModelConfig { @@ -68,6 +114,8 @@ class VadModelConfig {
68 return VadModelConfig( 114 return VadModelConfig(
69 sileroVad: SileroVadModelConfig.fromJson( 115 sileroVad: SileroVadModelConfig.fromJson(
70 json['sileroVad'] as Map<String, dynamic>? ?? const {}), 116 json['sileroVad'] as Map<String, dynamic>? ?? const {}),
  117 + tenVad: TenVadModelConfig.fromJson(
  118 + json['tenVad'] as Map<String, dynamic>? ?? const {}),
71 sampleRate: json['sampleRate'] as int? ?? 16000, 119 sampleRate: json['sampleRate'] as int? ?? 16000,
72 numThreads: json['numThreads'] as int? ?? 1, 120 numThreads: json['numThreads'] as int? ?? 1,
73 provider: json['provider'] as String? ?? 'cpu', 121 provider: json['provider'] as String? ?? 'cpu',
@@ -77,6 +125,7 @@ class VadModelConfig { @@ -77,6 +125,7 @@ class VadModelConfig {
77 125
78 Map<String, dynamic> toJson() => { 126 Map<String, dynamic> toJson() => {
79 'sileroVad': sileroVad.toJson(), 127 'sileroVad': sileroVad.toJson(),
  128 + 'tenVad': tenVad.toJson(),
80 'sampleRate': sampleRate, 129 'sampleRate': sampleRate,
81 'numThreads': numThreads, 130 'numThreads': numThreads,
82 'provider': provider, 131 'provider': provider,
@@ -85,7 +134,7 @@ class VadModelConfig { @@ -85,7 +134,7 @@ class VadModelConfig {
85 134
86 @override 135 @override
87 String toString() { 136 String toString() {
88 - return 'VadModelConfig(sileroVad: $sileroVad, sampleRate: $sampleRate, numThreads: $numThreads, provider: $provider, debug: $debug)'; 137 + return 'VadModelConfig(sileroVad: $sileroVad, tenVad: $tenVad, sampleRate: $sampleRate, numThreads: $numThreads, provider: $provider, debug: $debug)';
89 } 138 }
90 } 139 }
91 140
@@ -168,8 +217,8 @@ class VoiceActivityDetector { @@ -168,8 +217,8 @@ class VoiceActivityDetector {
168 {required VadModelConfig config, required double bufferSizeInSeconds}) { 217 {required VadModelConfig config, required double bufferSizeInSeconds}) {
169 final c = calloc<SherpaOnnxVadModelConfig>(); 218 final c = calloc<SherpaOnnxVadModelConfig>();
170 219
171 - final modelPtr = config.sileroVad.model.toNativeUtf8();  
172 - c.ref.sileroVad.model = modelPtr; 220 + final sileroVadModelPtr = config.sileroVad.model.toNativeUtf8();
  221 + c.ref.sileroVad.model = sileroVadModelPtr;
173 222
174 c.ref.sileroVad.threshold = config.sileroVad.threshold; 223 c.ref.sileroVad.threshold = config.sileroVad.threshold;
175 c.ref.sileroVad.minSilenceDuration = config.sileroVad.minSilenceDuration; 224 c.ref.sileroVad.minSilenceDuration = config.sileroVad.minSilenceDuration;
@@ -177,6 +226,15 @@ class VoiceActivityDetector { @@ -177,6 +226,15 @@ class VoiceActivityDetector {
177 c.ref.sileroVad.windowSize = config.sileroVad.windowSize; 226 c.ref.sileroVad.windowSize = config.sileroVad.windowSize;
178 c.ref.sileroVad.maxSpeechDuration = config.sileroVad.maxSpeechDuration; 227 c.ref.sileroVad.maxSpeechDuration = config.sileroVad.maxSpeechDuration;
179 228
  229 + final tenVadModelPtr = config.tenVad.model.toNativeUtf8();
  230 + c.ref.tenVad.model = tenVadModelPtr;
  231 +
  232 + c.ref.tenVad.threshold = config.tenVad.threshold;
  233 + c.ref.tenVad.minSilenceDuration = config.tenVad.minSilenceDuration;
  234 + c.ref.tenVad.minSpeechDuration = config.tenVad.minSpeechDuration;
  235 + c.ref.tenVad.windowSize = config.tenVad.windowSize;
  236 + c.ref.tenVad.maxSpeechDuration = config.tenVad.maxSpeechDuration;
  237 +
180 c.ref.sampleRate = config.sampleRate; 238 c.ref.sampleRate = config.sampleRate;
181 c.ref.numThreads = config.numThreads; 239 c.ref.numThreads = config.numThreads;
182 240
@@ -190,7 +248,8 @@ class VoiceActivityDetector { @@ -190,7 +248,8 @@ class VoiceActivityDetector {
190 nullptr; 248 nullptr;
191 249
192 calloc.free(providerPtr); 250 calloc.free(providerPtr);
193 - calloc.free(modelPtr); 251 + calloc.free(tenVadModelPtr);
  252 + calloc.free(sileroVadModelPtr);
194 calloc.free(c); 253 calloc.free(c);
195 254
196 return VoiceActivityDetector._(ptr: ptr, config: config); 255 return VoiceActivityDetector._(ptr: ptr, config: config);