# test_speaker_recognition.py 7.0 KB
# sherpa-onnx/python/tests/test_speaker_recognition.py
#
# Copyright (c)  2024  Xiaomi Corporation
#
# To run this single test, use
#
#  ctest --verbose -R  test_speaker_recognition_py

import unittest
import wave
from collections import defaultdict
from pathlib import Path
from typing import Tuple

import numpy as np
import sherpa_onnx

# Root directory under which the test cases below look for the per-toolkit
# model sub-directories ("wespeaker", "3dspeaker" and "nemo").
d = "/tmp/sr-models"


def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Read a mono, 16-bit PCM wave file into normalized float samples.

    Args:
      wave_filename:
        Path to a wave file. It should be single channel and each sample should
        be 16-bit. Its sample rate does not need to be 16kHz.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples, which are
       normalized to the range [-1, 1].
       - sample rate of the wave file
    Raises:
      AssertionError:
        If the file is not single channel or its samples are not 2 bytes wide.
    """

    with wave.open(wave_filename) as f:
        assert f.getnchannels() == 1, f.getnchannels()
        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
        sample_rate = f.getframerate()
        raw_bytes = f.readframes(f.getnframes())

    # PCM data in a wave file is stored little-endian. Decode it with an
    # explicit "<i2" dtype rather than the platform-native np.int16 so the
    # samples are also read correctly on big-endian hosts.
    samples_int16 = np.frombuffer(raw_bytes, dtype="<i2")

    # Dividing by 2**15 maps the int16 range onto [-1, 1).
    samples_float32 = samples_int16.astype(np.float32) / 32768
    return samples_float32, sample_rate


def load_speaker_embedding_model(model_filename):
    """Create a speaker embedding extractor for the given model file.

    The extractor is configured to run on the "cpu" provider with a single
    thread and with debug output enabled.

    Args:
      model_filename:
        Value passed as ``model`` to
        sherpa_onnx.SpeakerEmbeddingExtractorConfig.
    Returns:
      A sherpa_onnx.SpeakerEmbeddingExtractor built from that config.
    Raises:
      ValueError:
        If ``config.validate()`` reports the config as invalid.
    """
    config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        provider="cpu",
        debug=True,
        num_threads=1,
        model=model_filename,
    )
    if config.validate():
        return sherpa_onnx.SpeakerEmbeddingExtractor(config)
    raise ValueError(f"Invalid config. {config}")


def test_zh_models(model_filename: str):
    """Enroll speakers from wave files and check held-out utterances match.

    Models whose path contains "en" are skipped (a message is printed and the
    function returns). Otherwise every enrollment utterance is embedded, the
    embeddings of each speaker are averaged and registered with a
    SpeakerEmbeddingManager, and every test utterance must both verify
    against, and be found by search as, its own speaker (threshold 0.5).

    Args:
      model_filename:
        Path of the model; it is converted with ``str()`` before use.
    Raises:
      RuntimeError:
        If a speaker cannot be registered or an utterance fails to verify.
      AssertionError:
        If the extractor is not ready after the input is finished, or if
        search returns a different name than expected.
    """
    model_filename = str(model_filename)
    if "en" in model_filename:
        print(f"skip {model_filename}")
        return

    extractor = load_speaker_embedding_model(model_filename)

    def embed(wave_path):
        # Feed one complete utterance to the extractor and return its
        # embedding as a numpy array.
        samples, rate = read_wave(wave_path)
        utterance = extractor.create_stream()
        utterance.accept_waveform(sample_rate=rate, waveform=samples)
        utterance.input_finished()
        assert extractor.is_ready(utterance)
        return np.array(extractor.compute(utterance))

    # The speaker name is the part of the file stem before the first "-".
    per_speaker = defaultdict(list)
    for filename in (
        "leijun-sr-1",
        "leijun-sr-2",
        "fangjun-sr-1",
        "fangjun-sr-2",
        "fangjun-sr-3",
    ):
        print(filename)
        name = filename.split("-", maxsplit=1)[0]
        per_speaker[name].append(
            embed(f"/tmp/sr-models/sr-data/enroll/{filename}.wav")
        )

    # Register one averaged embedding per speaker.
    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
    for name, embedding_list in per_speaker.items():
        print(name, len(embedding_list))
        mean_embedding = sum(embedding_list) / len(embedding_list)
        if not manager.add(name, mean_embedding):
            raise RuntimeError(f"Failed to register speaker {name}")

    for filename in (
        "leijun-test-sr-1",
        "leijun-test-sr-2",
        "leijun-test-sr-3",
        "fangjun-test-sr-1",
        "fangjun-test-sr-2",
    ):
        name = filename.split("-", maxsplit=1)[0]
        embedding = embed(f"/tmp/sr-models/sr-data/test/{filename}.wav")
        if not manager.verify(name, embedding, threshold=0.5):
            raise RuntimeError(f"Failed to verify {name} with wave {filename}.wav")

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)


def test_en_and_zh_models(model_filename: str):
    """Register speakers from the 3d-speaker waves and verify a second take.

    Whether the model is treated as English is decided by "en" occurring in
    its path. Only wave files of the matching language are used: "cn" files
    are skipped for English models and "en" files for all other models.
    Each "a" recording is registered under its stem without the trailing
    "_16k"; each "b" recording must then verify against, and be found by
    search as, the corresponding "a" name (threshold 0.5).

    Args:
      model_filename:
        Path of the model; it is converted with ``str()`` before use.
    Raises:
      RuntimeError:
        If a speaker cannot be registered or an utterance fails to verify.
      AssertionError:
        If the extractor is not ready after the input is finished, or if
        search returns a different name than expected.
    """
    model_filename = str(model_filename)
    extractor = load_speaker_embedding_model(model_filename)
    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
    is_en = "en" in model_filename

    def wrong_language(stem):
        # English models skip the "cn" files; all other models skip "en" ones.
        return ("cn" in stem) if is_en else ("en" in stem)

    def embed(wave_path):
        # Feed one complete utterance to the extractor and return its
        # embedding as a numpy array.
        samples, rate = read_wave(wave_path)
        utterance = extractor.create_stream()
        utterance.accept_waveform(sample_rate=rate, waveform=samples)
        utterance.input_finished()
        assert extractor.is_ready(utterance)
        return np.array(extractor.compute(utterance))

    for filename in (
        "speaker1_a_cn_16k",
        "speaker2_a_cn_16k",
        "speaker1_a_en_16k",
        "speaker2_a_en_16k",
    ):
        if wrong_language(filename):
            continue

        # Drop the trailing "_16k" to get the registered speaker name.
        name = filename.rsplit("_", maxsplit=1)[0]
        embedding = embed(f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav")
        if not manager.add(name, embedding):
            raise RuntimeError(f"Failed to register speaker {name}")

    for filename in (
        "speaker1_b_cn_16k",
        "speaker1_b_en_16k",
    ):
        if wrong_language(filename):
            continue

        print(filename)
        # Map the "b" recording onto the name its "a" recording was
        # registered under.
        name = (
            filename.rsplit("_", maxsplit=1)[0]
            .replace("b_cn", "a_cn")
            .replace("b_en", "a_en")
        )
        print(name)

        embedding = embed(f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav")
        if not manager.verify(name, embedding, threshold=0.5):
            raise RuntimeError(
                f"Failed to verify {name} with wave {filename}.wav. model: {model_filename}"
            )

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)


class TestSpeakerRecognition(unittest.TestCase):
    """Run the speaker-recognition checks on every ``*.onnx`` model found.

    Each test method handles one sub-directory of the module-level model
    root ``d``. A missing sub-directory is reported as a skipped test rather
    than silently counted as a pass.
    """

    def _check_models(self, subdir, *checks):
        """Apply every function in ``checks`` to each model in ``d/subdir``.

        Args:
          subdir:
            Name of the sub-directory of ``d`` holding the ``*.onnx`` files.
          checks:
            Callables that take a model path; each one is called for each
            model, in the order given.
        Raises:
          unittest.SkipTest:
            If the sub-directory does not exist.
        """
        model_dir = Path(d) / subdir
        if not model_dir.is_dir():
            # Mark the test as skipped so a missing model download is
            # visible in the test report.
            self.skipTest(f"{model_dir} does not exist - skip it")

        # glob() yields entries in arbitrary order; sort them so the models
        # are always processed in the same order.
        for filename in sorted(model_dir.glob("*.onnx")):
            print(filename)
            for check in checks:
                check(filename)

    def test_wespeaker_models(self):
        self._check_models("wespeaker", test_zh_models, test_en_and_zh_models)

    # NOTE: the method name (without the "s" of "3dspeaker") is kept as-is
    # so that existing test ids keep working.
    def test_3dpeaker_models(self):
        self._check_models("3dspeaker", test_en_and_zh_models)

    def test_nemo_models(self):
        self._check_models("nemo", test_en_and_zh_models)


# Run all test cases of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()