Committed by
GitHub
Add Go API for offline punctuation models (#1434)
It is contributed by a community user from [our QQ group](https://k2-fsa.github.io/sherpa/social-groups.html#qq).
正在显示
1 个修改的文件
包含
49 行增加
和
1 行删除
| @@ -1283,7 +1283,7 @@ func (sd *OfflineSpeakerDiarization) SetConfig(config *OfflineSpeakerDiarization | @@ -1283,7 +1283,7 @@ func (sd *OfflineSpeakerDiarization) SetConfig(config *OfflineSpeakerDiarization | ||
| 1283 | c.clustering.num_clusters = C.int(config.Clustering.NumClusters) | 1283 | c.clustering.num_clusters = C.int(config.Clustering.NumClusters) |
| 1284 | c.clustering.threshold = C.float(config.Clustering.Threshold) | 1284 | c.clustering.threshold = C.float(config.Clustering.Threshold) |
| 1285 | 1285 | ||
| 1286 | - SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd.impl, &c) | 1286 | + C.SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd.impl, &c) |
| 1287 | } | 1287 | } |
| 1288 | 1288 | ||
| 1289 | type OfflineSpeakerDiarizationSegment struct { | 1289 | type OfflineSpeakerDiarizationSegment struct { |
| @@ -1317,3 +1317,51 @@ func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeaker | @@ -1317,3 +1317,51 @@ func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeaker | ||
| 1317 | 1317 | ||
| 1318 | return ans | 1318 | return ans |
| 1319 | } | 1319 | } |
| 1320 | + | ||
| 1321 | +// ============================================================ | ||
| 1322 | +// For punctuation | ||
| 1323 | +// ============================================================ | ||
| 1324 | +type OfflinePunctuationModelConfig struct { | ||
| 1325 | + Ct_transformer string | ||
| 1326 | + Num_threads C.int | ||
| 1327 | + Debug C.int // true to print debug information of the model | ||
| 1328 | + Provider string | ||
| 1329 | +} | ||
| 1330 | + | ||
| 1331 | +type OfflinePunctuationConfig struct { | ||
| 1332 | + Model OfflinePunctuationModelConfig | ||
| 1333 | +} | ||
| 1334 | + | ||
| 1335 | +type OfflinePunctuation struct { | ||
| 1336 | + impl *C.struct_SherpaOnnxOfflinePunctuation | ||
| 1337 | +} | ||
| 1338 | + | ||
| 1339 | +func NewOfflinePunctuation(config *OfflinePunctuationConfig) *OfflinePunctuation { | ||
| 1340 | + cfg := C.struct_SherpaOnnxOfflinePunctuationConfig{} | ||
| 1341 | + cfg.model.ct_transformer = C.CString(config.Model.Ct_transformer) | ||
| 1342 | + defer C.free(unsafe.Pointer(cfg.model.ct_transformer)) | ||
| 1343 | + | ||
| 1344 | + cfg.model.num_threads = config.Model.Num_threads | ||
| 1345 | + cfg.model.debug = config.Model.Debug | ||
| 1346 | + cfg.model.provider = C.CString(config.Model.Provider) | ||
| 1347 | + defer C.free(unsafe.Pointer(cfg.model.provider)) | ||
| 1348 | + | ||
| 1349 | + punc := &OfflinePunctuation{} | ||
| 1350 | + punc.impl = C.SherpaOnnxCreateOfflinePunctuation(&cfg) | ||
| 1351 | + | ||
| 1352 | + return punc | ||
| 1353 | +} | ||
| 1354 | + | ||
| 1355 | +func DeleteOfflinePunc(punc *OfflinePunctuation) { | ||
| 1356 | + C.SherpaOnnxDestroyOfflinePunctuation(punc.impl) | ||
| 1357 | + punc.impl = nil | ||
| 1358 | +} | ||
| 1359 | + | ||
| 1360 | +func (punc *OfflinePunctuation) AddPunct(text string) string { | ||
| 1361 | + p := C.SherpaOfflinePunctuationAddPunct(punc.impl, C.CString(text)) | ||
| 1362 | + defer C.free(unsafe.Pointer(p)) | ||
| 1363 | + | ||
| 1364 | + text_with_punct := C.GoString(p) | ||
| 1365 | + | ||
| 1366 | + return text_with_punct | ||
| 1367 | +} |
-
请 注册 或 登录 后发表评论