online_punctuation.dart
2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import 'dart:ffi';
import 'package:ffi/ffi.dart';
import './sherpa_onnx_bindings.dart';
/// Model-level settings for the sherpa-onnx online punctuation model.
///
/// These values are copied into the native config struct when an
/// [OnlinePunctuation] instance is created.
class OnlinePunctuationModelConfig {
  /// Creates a model configuration.
  ///
  /// [cnnBiLstm] and [bpeVocab] must be supplied. The remaining settings
  /// default to `numThreads = 1`, `provider = 'cpu'` and `debug = true`.
  OnlinePunctuationModelConfig({
    required this.cnnBiLstm,
    required this.bpeVocab,
    this.numThreads = 1,
    this.provider = 'cpu',
    this.debug = true,
  });

  /// The CNN-BiLSTM model, passed to the native side as a C string
  /// (presumably a file path — confirm against the native API).
  final String cnnBiLstm;

  /// The BPE vocabulary, passed to the native side as a C string
  /// (presumably a file path — confirm against the native API).
  final String bpeVocab;

  /// Number of threads requested from the native runtime.
  final int numThreads;

  /// Execution provider name forwarded to the native side, e.g. `'cpu'`.
  final String provider;

  /// Whether debug mode is requested; sent to native code as `1` or `0`.
  final bool debug;

  @override
  String toString() => 'OnlinePunctuationModelConfig(cnnBiLstm: $cnnBiLstm, '
      'bpeVocab: $bpeVocab, numThreads: $numThreads, '
      'provider: $provider, debug: $debug)';
}
/// Top-level configuration used to construct an [OnlinePunctuation].
class OnlinePunctuationConfig {
  /// Creates a configuration wrapping the given [model] settings.
  OnlinePunctuationConfig({required this.model});

  /// Settings for the underlying punctuation model.
  final OnlinePunctuationModelConfig model;

  @override
  String toString() => 'OnlinePunctuationConfig(model: $model)';
}
/// Dart wrapper around a native sherpa-onnx online punctuation handle.
///
/// Instances hold native memory through [ptr]: call [free] once the object is
/// no longer needed, otherwise the native handle leaks.
class OnlinePunctuation {
  /// Wraps an already-created native handle [ptr] with its [config].
  ///
  /// Calling [free] on the result destroys [ptr].
  OnlinePunctuation.fromPtr({required this.ptr, required this.config});

  OnlinePunctuation._({required this.ptr, required this.config});

  /// Creates a native online punctuation instance from [config].
  ///
  /// The user has to invoke [free] to avoid a memory leak.
  ///
  /// If the native create binding is unavailable, [ptr] is `nullptr`.
  /// NOTE(review): presumably the native create call also returns null on
  /// failure (e.g. a bad model path) — confirm against the C API.
  factory OnlinePunctuation({required OnlinePunctuationConfig config}) {
    final model = config.model;

    final cnnBiLstmPtr = model.cnnBiLstm.toNativeUtf8();
    final bpeVocabPtr = model.bpeVocab.toNativeUtf8();
    final providerPtr = model.provider.toNativeUtf8();
    // calloc zero-initialises the struct, so any field not set below is 0.
    final c = calloc<SherpaOnnxOnlinePunctuationConfig>();

    try {
      c.ref.model.cnnBiLstm = cnnBiLstmPtr;
      c.ref.model.bpeVocab = bpeVocabPtr;
      c.ref.model.numThreads = model.numThreads;
      c.ref.model.debug = model.debug ? 1 : 0;
      c.ref.model.provider = providerPtr;

      final ptr =
          SherpaOnnxBindings.sherpaOnnxCreateOnlinePunctuation?.call(c) ??
              nullptr;
      return OnlinePunctuation._(ptr: ptr, config: config);
    } finally {
      // The temporary config struct and strings are released once the create
      // call returns (or throws). NOTE(review): this assumes the native side
      // copies whatever it needs from them during creation.
      calloc.free(providerPtr);
      calloc.free(cnnBiLstmPtr);
      calloc.free(bpeVocabPtr);
      calloc.free(c);
    }
  }

  /// Destroys the native handle and resets [ptr] to `nullptr`.
  ///
  /// Safe to call more than once: later calls do nothing, so the native
  /// destroy function never receives a null handle.
  void free() {
    if (ptr == nullptr) {
      return;
    }
    SherpaOnnxBindings.sherpaOnnxDestroyOnlinePunctuation?.call(ptr);
    ptr = nullptr;
  }

  /// Returns [text] with punctuation added by the native model.
  ///
  /// Returns an empty string in these cases:
  /// - this instance has no native handle (after [free], or when creation
  ///   yielded `nullptr`);
  /// - the native binding is unavailable;
  /// - the native call returns a null result.
  ///
  /// May throw [FormatException] if the native result is not valid UTF-8.
  /// The native result string is released in that case too.
  String addPunct(String text) {
    // Passing a null handle into native code would most likely crash the
    // process; report "no result" instead, matching the best-effort contract.
    if (ptr == nullptr) {
      return '';
    }

    final textPtr = text.toNativeUtf8();
    final result = SherpaOnnxBindings.sherpaOnnxOnlinePunctuationAddPunct
            ?.call(ptr, textPtr) ??
        nullptr;
    calloc.free(textPtr);

    if (result == nullptr) {
      return '';
    }

    try {
      return result.toDartString();
    } finally {
      // Release the native string even if decoding fails.
      SherpaOnnxBindings.sherpaOnnxOnlinePunctuationFreeText?.call(result);
    }
  }

  /// Native handle; `nullptr` once [free] has run or if creation yielded none.
  Pointer<SherpaOnnxOnlinePunctuation> ptr;

  /// The configuration this instance was created with.
  final OnlinePunctuationConfig config;
}