// speaker-identification-c-api.c (6.4 KB)
// c-api-examples/speaker-identification-c-api.c
//
// Copyright (c)  2024  Xiaomi Corporation

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "sherpa-onnx/c-api/c-api.h"

/*
 * Loads `wav_filename`, pushes all of its samples through a freshly created
 * stream of the extractor `ex`, and returns the speaker embedding for it.
 *
 * The caller owns the returned vector and releases it with
 * SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding().
 *
 * This helper never returns on failure. If the file cannot be read, or it is
 * too short for the extractor to be ready, a message is printed to stderr and
 * the process exits with status -1.
 */
static const float *ComputeEmbedding(
    const SherpaOnnxSpeakerEmbeddingExtractor *ex, const char *wav_filename) {
  const SherpaOnnxWave *audio = SherpaOnnxReadWave(wav_filename);
  if (!audio) {
    fprintf(stderr, "Failed to read %s\n", wav_filename);
    exit(-1);
  }

  const SherpaOnnxOnlineStream *feed =
      SherpaOnnxSpeakerEmbeddingExtractorCreateStream(ex);

  // Hand over the whole file at once, then mark the input as complete
  // before asking whether an embedding can be produced.
  SherpaOnnxOnlineStreamAcceptWaveform(feed, audio->sample_rate,
                                       audio->samples, audio->num_samples);
  SherpaOnnxOnlineStreamInputFinished(feed);

  if (SherpaOnnxSpeakerEmbeddingExtractorIsReady(ex, feed) == 0) {
    fprintf(stderr, "The input wave file %s is too short!\n", wav_filename);
    exit(-1);
  }

  const float *embedding =
      SherpaOnnxSpeakerEmbeddingExtractorComputeEmbedding(ex, feed);

  // The embedding outlives both the stream and the decoded wave, so those
  // are released before it is handed back to the caller.
  SherpaOnnxDestroyOnlineStream(feed);
  SherpaOnnxFreeWave(audio);

  return embedding;
}

/*
 * Writes the command-line synopsis and one example invocation to stderr,
 * using `program_name` as the executable name in both lines.
 */
void PrintUsage(const char *program_name) {
  // Adjacent string literals are concatenated by the compiler, so each
  // message below is still a single format string.
  fprintf(stderr,
          "Usage: %s <model_path> <threshold> "
          "<speaker1_name> <speaker1_wav1> [speaker1_wav2] [speaker1_wav3] "
          "<speaker2_name> <speaker2_wav1> [speaker2_wav2] [speaker2_wav3] "
          "<output_file> <test_wav1> <test_wav2> ...\n",
          program_name);
  fprintf(stderr,
          "Example: %s 3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx 0.6 "
          "fangjun ./sr-data/enroll/fangjun-sr-1.wav "
          "./sr-data/enroll/fangjun-sr-2.wav ./sr-data/enroll/fangjun-sr-3.wav "
          "leijun ./sr-data/enroll/leijun-sr-1.wav ./sr-data/enroll/leijun-sr-2.wav "
          "result.txt ./sr-data/test/fangjun-test-sr-1.wav "
          "./sr-data/test/leijun-test-sr-1.wav "
          "./sr-data/test/liudehua-test-sr-1.wav\n",
          program_name);
}

// Capacity of the fixed-size tables main() fills while parsing argv.
enum {
  kMaxSpeakers = 10,
  kMaxWavsPerSpeaker = 4,
};

/*
 * Speaker identification demo.
 *
 * Argument layout (see PrintUsage()):
 *   argv[1]  model path
 *   argv[2]  similarity threshold (must parse completely as a float)
 *   then one or more "<speaker_name> <wav> [<wav> ...]" groups
 *   then the output file (the first argument after the threshold that
 *        contains ".txt")
 *   then one or more test wave files.
 *
 * Each speaker's enrollment waves are embedded and registered with a
 * SherpaOnnxSpeakerEmbeddingManager. Each test wave is then embedded,
 * searched against all speakers, and verified against each speaker
 * individually. Progress goes to stderr; a report goes to the output file.
 *
 * Returns 0 on success, or -1 on invalid arguments, setup failure,
 * registration failure, or a failed write of the output file. Note that
 * ComputeEmbedding() exits the process directly (status -1) when a wave file
 * is unreadable or too short.
 */
int32_t main(int32_t argc, char *argv[]) {
  if (argc < 7) {
    PrintUsage(argv[0]);
    return -1;
  }

  // Parse command line arguments
  const char *model_path = argv[1];

  // strtof() reports where parsing stopped, unlike atof(). A malformed
  // threshold is therefore rejected instead of silently becoming 0.
  char *threshold_end = NULL;
  float threshold = strtof(argv[2], &threshold_end);
  if (threshold_end == argv[2] || *threshold_end != '\0') {
    fprintf(stderr, "Invalid threshold: %s\n", argv[2]);
    PrintUsage(argv[0]);
    return -1;
  }

  // Find the position of output file and test files
  int32_t output_file_index = -1;
  for (int32_t i = 3; i < argc; i++) {
    if (strstr(argv[i], ".txt") != NULL) {
      output_file_index = i;
      break;
    }
  }

  if (output_file_index == -1 || output_file_index >= argc - 1) {
    fprintf(stderr, "Output file not found or no test files provided\n");
    PrintUsage(argv[0]);
    return -1;
  }

  const char *output_file = argv[output_file_index];
  int32_t num_test_files = argc - output_file_index - 1;
  const char **test_files = (const char **)&argv[output_file_index + 1];

  // Parse "<name> <wav> [<wav> ...]" groups that sit between the threshold
  // and the output file. A group's file list ends at the first argument that
  // does not contain ".wav".
  int32_t num_speakers = 0;
  const char *speaker_names[kMaxSpeakers] = {NULL};
  const char *speaker_files[kMaxSpeakers][kMaxWavsPerSpeaker] = {{NULL}};
  int32_t speaker_file_counts[kMaxSpeakers] = {0};

  int32_t current_index = 3;
  while (current_index < output_file_index) {
    // Reject extra speakers explicitly rather than silently ignoring them.
    if (num_speakers == kMaxSpeakers) {
      fprintf(stderr, "Too many speakers: at most %d are supported\n",
              kMaxSpeakers);
      return -1;
    }

    // Speaker name
    speaker_names[num_speakers] = argv[current_index++];

    // Speaker wave files. An overflow must be reported here. Otherwise the
    // extra wave file would be misread as the next speaker's name.
    int32_t file_count = 0;
    while (current_index < output_file_index &&
           strstr(argv[current_index], ".wav") != NULL) {
      if (file_count == kMaxWavsPerSpeaker) {
        fprintf(stderr,
                "Speaker %s has too many wave files: at most %d are "
                "supported\n",
                speaker_names[num_speakers], kMaxWavsPerSpeaker);
        return -1;
      }
      speaker_files[num_speakers][file_count++] = argv[current_index++];
    }

    if (file_count == 0) {
      fprintf(stderr, "Speaker %s has no wave files\n",
              speaker_names[num_speakers]);
      PrintUsage(argv[0]);
      return -1;
    }

    speaker_file_counts[num_speakers] = file_count;
    num_speakers++;
  }

  // Resources released at `cleanup`. They are declared up front so that
  // every goto below jumps past no initialized declarations.
  int32_t ret = -1;
  const SherpaOnnxSpeakerEmbeddingExtractor *ex = NULL;
  const SherpaOnnxSpeakerEmbeddingManager *manager = NULL;

  // Open output file
  FILE *fp = fopen(output_file, "w");
  if (!fp) {
    fprintf(stderr, "Failed to open output file: %s\n", output_file);
    return -1;
  }

  fprintf(fp, "Speaker Identification Results\n");
  fprintf(fp, "Model: %s\n", model_path);
  fprintf(fp, "Threshold: %.2f\n", threshold);
  fprintf(fp, "========================================\n");

  // Initialize speaker embedding extractor
  SherpaOnnxSpeakerEmbeddingExtractorConfig config;
  memset(&config, 0, sizeof(config));
  config.model = model_path;
  config.num_threads = 1;
  config.debug = 0;
  config.provider = "cpu";

  ex = SherpaOnnxCreateSpeakerEmbeddingExtractor(&config);
  if (!ex) {
    fprintf(stderr, "Failed to create speaker embedding extractor\n");
    goto cleanup;
  }

  manager = SherpaOnnxCreateSpeakerEmbeddingManager(
      SherpaOnnxSpeakerEmbeddingExtractorDim(ex));
  if (!manager) {
    fprintf(stderr, "Failed to create speaker embedding manager\n");
    goto cleanup;
  }

  // Register speakers
  for (int32_t i = 0; i < num_speakers; i++) {
    // AddList() reads the list up to the first NULL entry (see c-api.h). The
    // extra slot is never written, so the list stays NULL-terminated even
    // when a speaker has kMaxWavsPerSpeaker files.
    const float *embeddings[kMaxWavsPerSpeaker + 1] = {NULL};
    int32_t count = speaker_file_counts[i];

    for (int32_t j = 0; j < count; j++) {
      embeddings[j] = ComputeEmbedding(ex, speaker_files[i][j]);
    }

    int32_t added = SherpaOnnxSpeakerEmbeddingManagerAddList(
        manager, speaker_names[i], embeddings);

    // The embeddings are released right after AddList() on the success
    // path. Free them before checking the result so the failure path does
    // not leak them.
    for (int32_t j = 0; j < count; j++) {
      SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(embeddings[j]);
    }

    if (!added) {
      fprintf(stderr, "Failed to register %s\n", speaker_names[i]);
      fprintf(fp, "Failed to register %s\n", speaker_names[i]);
      goto cleanup;
    }

    fprintf(stderr, "Registered speaker: %s with %d wave files\n",
            speaker_names[i], count);
    fprintf(fp, "Registered speaker: %s with %d wave files\n",
            speaker_names[i], count);
  }

  fprintf(fp, "\nTest Results:\n");
  fprintf(fp, "========================================\n");

  // Process test files
  for (int32_t i = 0; i < num_test_files; i++) {
    const char *test_file = test_files[i];
    const float *v = ComputeEmbedding(ex, test_file);

    const char *name =
        SherpaOnnxSpeakerEmbeddingManagerSearch(manager, v, threshold);

    fprintf(stderr, "Testing %s: ", test_file);
    fprintf(fp, "Test file: %s\n", test_file);

    if (name) {
      fprintf(stderr, "Found %s\n", name);
      fprintf(fp, "  Result: Found speaker - %s\n", name);
      SherpaOnnxSpeakerEmbeddingManagerFreeSearch(name);
    } else {
      fprintf(stderr, "Not found\n");
      fprintf(fp, "  Result: Speaker not found\n");
    }

    // Verify against all registered speakers
    for (int32_t j = 0; j < num_speakers; j++) {
      int32_t ok = SherpaOnnxSpeakerEmbeddingManagerVerify(
          manager, speaker_names[j], v, threshold);
      fprintf(fp, "  Verify with %s: %s\n", speaker_names[j],
              ok ? "MATCH" : "NO MATCH");
    }

    fprintf(fp, "\n");
    SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(v);
  }

  ret = 0;

cleanup:
  if (manager) {
    SherpaOnnxDestroySpeakerEmbeddingManager(manager);
  }
  if (ex) {
    SherpaOnnxDestroySpeakerEmbeddingExtractor(ex);
  }

  // fclose() flushes the buffered report. If it fails, the output file may
  // be incomplete, so the run does not count as a success.
  if (fclose(fp) != 0) {
    fprintf(stderr, "Failed to write output file: %s\n", output_file);
    ret = -1;
  }

  if (ret == 0) {
    fprintf(stderr, "Results saved to: %s\n", output_file);
  }

  return ret;
}