Fangjun Kuang
Committed by GitHub

Add Go API for offline punctuation models (#1434)

It is contributed by a community user 
from [our QQ group](https://k2-fsa.github.io/sherpa/social-groups.html#qq).
@@ -1283,7 +1283,7 @@ func (sd *OfflineSpeakerDiarization) SetConfig(config *OfflineSpeakerDiarization @@ -1283,7 +1283,7 @@ func (sd *OfflineSpeakerDiarization) SetConfig(config *OfflineSpeakerDiarization
1283 c.clustering.num_clusters = C.int(config.Clustering.NumClusters) 1283 c.clustering.num_clusters = C.int(config.Clustering.NumClusters)
1284 c.clustering.threshold = C.float(config.Clustering.Threshold) 1284 c.clustering.threshold = C.float(config.Clustering.Threshold)
1285 1285
1286 - SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd.impl, &c) 1286 + C.SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd.impl, &c)
1287 } 1287 }
1288 1288
1289 type OfflineSpeakerDiarizationSegment struct { 1289 type OfflineSpeakerDiarizationSegment struct {
@@ -1317,3 +1317,51 @@ func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeaker @@ -1317,3 +1317,51 @@ func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeaker
1317 1317
1318 return ans 1318 return ans
1319 } 1319 }
  1320 +
  1321 +// ============================================================
  1322 +// For punctuation
  1323 +// ============================================================
  1324 +type OfflinePunctuationModelConfig struct {
  1325 + Ct_transformer string
  1326 + Num_threads C.int
  1327 + Debug C.int // true to print debug information of the model
  1328 + Provider string
  1329 +}
  1330 +
  1331 +type OfflinePunctuationConfig struct {
  1332 + Model OfflinePunctuationModelConfig
  1333 +}
  1334 +
  1335 +type OfflinePunctuation struct {
  1336 + impl *C.struct_SherpaOnnxOfflinePunctuation
  1337 +}
  1338 +
  1339 +func NewOfflinePunctuation(config *OfflinePunctuationConfig) *OfflinePunctuation {
  1340 + cfg := C.struct_SherpaOnnxOfflinePunctuationConfig{}
  1341 + cfg.model.ct_transformer = C.CString(config.Model.Ct_transformer)
  1342 + defer C.free(unsafe.Pointer(cfg.model.ct_transformer))
  1343 +
  1344 + cfg.model.num_threads = config.Model.Num_threads
  1345 + cfg.model.debug = config.Model.Debug
  1346 + cfg.model.provider = C.CString(config.Model.Provider)
  1347 + defer C.free(unsafe.Pointer(cfg.model.provider))
  1348 +
  1349 + punc := &OfflinePunctuation{}
  1350 + punc.impl = C.SherpaOnnxCreateOfflinePunctuation(&cfg)
  1351 +
  1352 + return punc
  1353 +}
  1354 +
  1355 +func DeleteOfflinePunc(punc *OfflinePunctuation) {
  1356 + C.SherpaOnnxDestroyOfflinePunctuation(punc.impl)
  1357 + punc.impl = nil
  1358 +}
  1359 +
  1360 +func (punc *OfflinePunctuation) AddPunct(text string) string {
  1361 + p := C.SherpaOfflinePunctuationAddPunct(punc.impl, C.CString(text))
  1362 + defer C.free(unsafe.Pointer(p))
  1363 +
  1364 + text_with_punct := C.GoString(p)
  1365 +
  1366 + return text_with_punct
  1367 +}