ced.dart
1.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
// Copyright (c) 2024 Xiaomi Corporation
import 'dart:io';
import 'package:args/args.dart';
import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
import './init.dart';
/// Command-line demo: tags the audio events in a wave file using a
/// CED (zipformer) audio-tagging model via sherpa_onnx.
///
/// Required options: --model, --labels, --wav. Optional: --top-k
/// (number of events to report, defaults to 5). Exits with status 1
/// and prints usage when a required option is missing, or an error
/// message when a given path does not exist.
void main(List<String> arguments) async {
  await initSherpaOnnx();

  // Declare the command-line interface.
  final parser = ArgParser()
    ..addOption('model', help: 'Path to the zipformer model')
    ..addOption('labels', help: 'Path to class_labels_indices.csv')
    ..addOption('top-k', help: 'topK events to be returned', defaultsTo: '5')
    ..addOption('wav', help: 'Path to test.wav to be tagged');

  final res = parser.parse(arguments);
  if (res['model'] == null || res['labels'] == null || res['wav'] == null) {
    print(parser.usage);
    exit(1);
  }

  final model = res['model'] as String;
  final labels = res['labels'] as String;
  // Silently fall back to 5 when --top-k is not a valid integer.
  final topK = int.tryParse(res['top-k'] as String) ?? 5;
  final wav = res['wav'] as String;

  // Fail early with a readable message rather than letting the native
  // library (or readWave) crash on a non-existent path.
  for (final path in [model, labels, wav]) {
    if (!File(path).existsSync()) {
      print('File not found: $path');
      exit(1);
    }
  }

  final modelConfig = sherpa_onnx.AudioTaggingModelConfig(
    ced: model,
    numThreads: 1,
    debug: true,
    provider: 'cpu',
  );
  final config = sherpa_onnx.AudioTaggingConfig(
    model: modelConfig,
    labels: labels,
  );
  final at = sherpa_onnx.AudioTagging(config: config);

  // Feed the whole wave file into a stream and compute the top-K events.
  final waveData = sherpa_onnx.readWave(wav);
  final stream = at.createStream();
  stream.acceptWaveform(
      samples: waveData.samples, sampleRate: waveData.sampleRate);

  final events = at.compute(stream: stream, topK: topK);
  print(events);

  // Explicitly release resources held by the native library.
  stream.free();
  at.free();
}