// offline-speaker-diarization-c-api.c 4.2 KB
// c-api-examples/offline-speaker-diarization-c-api.c
//
// Copyright (c)  2024  Xiaomi Corporation

//
// This file demonstrates how to implement speaker diarization with
// sherpa-onnx's C API.

// clang-format off
/*
Usage:

Step 1: Download a speaker segmentation model

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example

  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

Step 2: Download a speaker embedding extractor model

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example

  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

Step 3. Download test wave files

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example

  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav

Step 4. Run it

 */
// clang-format on

#include <stdio.h>
#include <string.h>

#include "sherpa-onnx/c-api/c-api.h"

// Progress hook passed to the diarization call in main(). It writes the
// percentage of chunks processed so far to stderr.
//
// num_processed_chunks: number of chunks handled so far
// num_total_chunks:     total number of chunks to handle
// arg:                  opaque user pointer; not used by this example
//
// Always returns 0.
static int32_t ProgressCallback(int32_t num_processed_chunks,
                                int32_t num_total_chunks, void *arg) {
  (void)arg;  // no user data is needed here

  // The ratio is evaluated in double precision and then narrowed to float
  // before it is printed.
  const double percent_done = 100.0 * num_processed_chunks / num_total_chunks;
  const float progress = (float)percent_done;

  fprintf(stderr, "progress %.2f%%\n", progress);

  // the return value is currently ignored
  return 0;
}

// Runs offline speaker diarization on a fixed test wave file and prints one
// line per detected segment ("start -- end speaker_XX") to stderr.
//
// Returns 0 on success and -1 if any step fails. Every resource that was
// acquired is released on all paths.
int main(void) {
  // Please see the comments at the start of this file for how to download
  // the .onnx file and .wav files below
  const char *segmentation_model =
      "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx";

  const char *embedding_extractor_model =
      "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx";

  const char *wav_filename = "./0-four-speakers-zh.wav";

  // All handles start out as NULL and are declared before the first goto,
  // so the single cleanup path below never reads an indeterminate pointer
  // and can tell which resources were actually acquired.
  const SherpaOnnxWave *wave = NULL;
  const SherpaOnnxOfflineSpeakerDiarization *sd = NULL;
  const SherpaOnnxOfflineSpeakerDiarizationResult *result = NULL;
  const SherpaOnnxOfflineSpeakerDiarizationSegment *segments = NULL;

  // Pessimistic default; only set to 0 once everything has succeeded.
  int exit_code = -1;

  wave = SherpaOnnxReadWave(wav_filename);
  if (wave == NULL) {
    fprintf(stderr, "Failed to read %s\n", wav_filename);
    goto cleanup;
  }

  SherpaOnnxOfflineSpeakerDiarizationConfig config;
  memset(&config, 0, sizeof(config));

  config.segmentation.pyannote.model = segmentation_model;
  config.embedding.model = embedding_extractor_model;

  // the test wave ./0-four-speakers-zh.wav has 4 speakers, so
  // we set num_clusters to 4
  //
  config.clustering.num_clusters = 4;
  // If you don't know the number of speakers in the test wave file, please
  // use
  // config.clustering.threshold = 0.5; // You need to tune this threshold

  sd = SherpaOnnxCreateOfflineSpeakerDiarization(&config);
  if (sd == NULL) {
    fprintf(stderr, "Failed to initialize offline speaker diarization\n");
    goto cleanup;
  }

  // The wave file must have the sample rate that the diarization object
  // reports as expected.
  const int32_t expected_sample_rate =
      SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd);
  if (expected_sample_rate != wave->sample_rate) {
    fprintf(
        stderr,
        "Expected sample rate: %d. Actual sample rate from the wave file: %d\n",
        expected_sample_rate, wave->sample_rate);
    goto cleanup;
  }

  result = SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback(
      sd, wave->samples, wave->num_samples, ProgressCallback, NULL);
  if (result == NULL) {
    fprintf(stderr, "Failed to do speaker diarization\n");
    goto cleanup;
  }

  const int32_t num_segments =
      SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(result);

  segments = SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(result);
  if (segments == NULL && num_segments > 0) {
    // Never index a NULL array when segments were reported.
    fprintf(stderr, "Failed to get the sorted segments\n");
    goto cleanup;
  }

  for (int32_t i = 0; i < num_segments; ++i) {
    fprintf(stderr, "%.3f -- %.3f speaker_%02d\n", segments[i].start,
            segments[i].end, segments[i].speaker);
  }

  exit_code = 0;

cleanup:
  // Release in reverse order of acquisition; skip anything never obtained.
  if (segments != NULL) {
    SherpaOnnxOfflineSpeakerDiarizationDestroySegment(segments);
  }
  if (result != NULL) {
    SherpaOnnxOfflineSpeakerDiarizationDestroyResult(result);
  }
  if (sd != NULL) {
    SherpaOnnxDestroyOfflineSpeakerDiarization(sd);
  }
  if (wave != NULL) {
    SherpaOnnxFreeWave(wave);
  }

  return exit_code;
}