online_punctuation.dart
3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import 'dart:ffi';
import 'package:ffi/ffi.dart';
import './sherpa_onnx_bindings.dart';
/// Model settings for the sherpa-onnx online punctuation model.
class OnlinePunctuationModelConfig {
  OnlinePunctuationModelConfig(
      {required this.cnnBiLstm,
      required this.bpeVocab,
      this.numThreads = _defaultNumThreads,
      this.provider = _defaultProvider,
      this.debug = _defaultDebug});

  /// Builds a config from a JSON-style map, e.g. the output of [toJson].
  ///
  /// `cnnBiLstm` and `bpeVocab` must be present as strings. The optional keys
  /// `numThreads`, `provider` and `debug` fall back to the same defaults as
  /// the unnamed constructor when they are absent or null.
  ///
  /// Throws a [TypeError] if a required key is missing or a value has the
  /// wrong type.
  factory OnlinePunctuationModelConfig.fromJson(Map<String, dynamic> json) {
    return OnlinePunctuationModelConfig(
      cnnBiLstm: json['cnnBiLstm'] as String,
      bpeVocab: json['bpeVocab'] as String,
      // Without these fallbacks a missing key would pass null into a
      // non-nullable parameter and throw instead of using the default.
      numThreads: json['numThreads'] as int? ?? _defaultNumThreads,
      provider: json['provider'] as String? ?? _defaultProvider,
      debug: json['debug'] as bool? ?? _defaultDebug,
    );
  }

  // Shared by the constructor and fromJson so the two cannot drift apart.
  static const _defaultNumThreads = 1;
  static const _defaultProvider = 'cpu';
  static const _defaultDebug = true;

  @override
  String toString() {
    return 'OnlinePunctuationModelConfig(cnnBiLstm: $cnnBiLstm, '
        'bpeVocab: $bpeVocab, numThreads: $numThreads, '
        'provider: $provider, debug: $debug)';
  }

  /// Serializes this config to a map that [fromJson] accepts.
  Map<String, dynamic> toJson() {
    return {
      'cnnBiLstm': cnnBiLstm,
      'bpeVocab': bpeVocab,
      'numThreads': numThreads,
      'provider': provider,
      'debug': debug,
    };
  }

  /// CNN-BiLSTM model location, handed to the native config unchanged.
  final String cnnBiLstm;

  /// BPE vocabulary location, handed to the native config unchanged.
  final String bpeVocab;

  /// Thread count forwarded to the native model config.
  final int numThreads;

  /// Execution provider name forwarded to the native config (default `'cpu'`).
  final String provider;

  /// Native debug flag; forwarded as 1 (true) or 0 (false).
  final bool debug;
}
/// Top-level configuration for [OnlinePunctuation]; it currently holds only
/// the model settings.
class OnlinePunctuationConfig {
  OnlinePunctuationConfig({required this.model});

  /// Rebuilds a config from a map shaped like the output of [toJson].
  ///
  /// The nested `model` entry is decoded with
  /// [OnlinePunctuationModelConfig.fromJson]; a missing or non-map entry
  /// throws a [TypeError].
  factory OnlinePunctuationConfig.fromJson(Map<String, dynamic> json) =>
      OnlinePunctuationConfig(
        model: OnlinePunctuationModelConfig.fromJson(
            json['model'] as Map<String, dynamic>),
      );

  /// Settings for the underlying punctuation model.
  final OnlinePunctuationModelConfig model;

  @override
  String toString() => 'OnlinePunctuationConfig(model: $model)';

  /// Serializes to a map whose `model` entry is the model's own JSON map.
  Map<String, dynamic> toJson() => {'model': model.toJson()};
}
/// Adds punctuation to text using the native sherpa-onnx online punctuation
/// model.
///
/// Holds a native handle; call [free] when done to release it.
class OnlinePunctuation {
  /// Wraps an existing native handle [ptr].
  ///
  /// NOTE(review): [free] destroys [ptr], so this presumably takes ownership
  /// of the handle — confirm with callers.
  OnlinePunctuation.fromPtr({required this.ptr, required this.config});

  OnlinePunctuation._({required this.ptr, required this.config});

  /// Creates the native punctuation model described by [config].
  ///
  /// The user has to invoke [free] to avoid a memory leak.
  ///
  /// Throws an [Exception] if the sherpa-onnx bindings are not initialized
  /// or the native library fails to create the model. The temporary native
  /// copy of [config] is released on every path, including failures.
  factory OnlinePunctuation({required OnlinePunctuationConfig config}) {
    // Check the binding before allocating anything so this exit cannot leak.
    final create = SherpaOnnxBindings.sherpaOnnxCreateOnlinePunctuation;
    if (create == null) {
      throw Exception("Please initialize sherpa-onnx first");
    }

    // Allocate the strings with calloc so they are released by the same
    // allocator that frees them below (toNativeUtf8 defaults to malloc).
    final c = calloc<SherpaOnnxOnlinePunctuationConfig>();
    final cnnBiLstmPtr =
        config.model.cnnBiLstm.toNativeUtf8(allocator: calloc);
    final bpeVocabPtr = config.model.bpeVocab.toNativeUtf8(allocator: calloc);
    final providerPtr = config.model.provider.toNativeUtf8(allocator: calloc);
    try {
      c.ref.model.cnnBiLstm = cnnBiLstmPtr;
      c.ref.model.bpeVocab = bpeVocabPtr;
      c.ref.model.numThreads = config.model.numThreads;
      c.ref.model.debug = config.model.debug ? 1 : 0;
      c.ref.model.provider = providerPtr;

      final ptr = create(c);
      if (ptr == nullptr) {
        throw Exception(
            "Failed to create online punctuation. Please check your config");
      }
      return OnlinePunctuation._(ptr: ptr, config: config);
    } finally {
      // The native config and its strings are only needed during creation;
      // free them whether creation succeeded or threw.
      calloc.free(providerPtr);
      calloc.free(cnnBiLstmPtr);
      calloc.free(bpeVocabPtr);
      calloc.free(c);
    }
  }

  /// Destroys the native model and clears [ptr].
  ///
  /// Calling this more than once is a no-op after the first call.
  void free() {
    if (ptr == nullptr) {
      return;
    }
    SherpaOnnxBindings.sherpaOnnxDestroyOnlinePunctuation?.call(ptr);
    ptr = nullptr;
  }

  /// Returns [text] with punctuation added by the native model.
  ///
  /// Returns an empty string if the binding is unavailable or the native call
  /// yields a null result. Throws a [StateError] if the native handle is null
  /// (for example after [free]), so a use-after-free fails in Dart instead of
  /// passing a null handle to native code.
  String addPunct(String text) {
    if (ptr == nullptr) {
      throw StateError('OnlinePunctuation.addPunct called after free()');
    }
    final textPtr = text.toNativeUtf8(allocator: calloc);
    final p = SherpaOnnxBindings.sherpaOnnxOnlinePunctuationAddPunct
            ?.call(ptr, textPtr) ??
        nullptr;
    calloc.free(textPtr);
    if (p == nullptr) {
      return '';
    }
    try {
      return p.toDartString();
    } finally {
      // Release the native result buffer even if UTF-8 decoding throws.
      SherpaOnnxBindings.sherpaOnnxOnlinePunctuationFreeText?.call(p);
    }
  }

  /// Native model handle; reset to [nullptr] by [free].
  Pointer<SherpaOnnxOnlinePunctuation> ptr;

  /// Configuration supplied when this instance was constructed.
  final OnlinePunctuationConfig config;
}