Fangjun Kuang
Committed by GitHub

convert wespeaker models to sherpa-onnx (#475)

  1 +name: export-wespeaker-to-onnx
  2 +
  3 +on:
  4 + workflow_dispatch:
  5 +
  6 +concurrency:
  7 + group: export-wespeaker-to-onnx-${{ github.ref }}
  8 + cancel-in-progress: true
  9 +
  10 +jobs:
  11 + export-wespeaker-to-onnx:
  12 + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
  13 + name: export wespeaker
  14 + runs-on: ${{ matrix.os }}
  15 + strategy:
  16 + fail-fast: false
  17 + matrix:
  18 + os: [ubuntu-latest]
  19 + python-version: ["3.8"]
  20 +
  21 + steps:
  22 + - uses: actions/checkout@v4
  23 +
  24 + - name: Setup Python ${{ matrix.python-version }}
  25 + uses: actions/setup-python@v2
  26 + with:
  27 + python-version: ${{ matrix.python-version }}
  28 +
  29 + - name: Install Python dependencies
  30 + shell: bash
  31 + run: |
  32 + pip install kaldi-native-fbank numpy onnx onnxruntime
  33 +
  34 + - name: Run
  35 + shell: bash
  36 + run: |
  37 + cd scripts/wespeaker
  38 + ./run.sh
  39 +
  40 + mv -v *.onnx ../..
  41 +
  42 + - name: Release
  43 + uses: svenstaro/upload-release-action@v2
  44 + with:
  45 + file_glob: true
  46 + file: ./*.onnx
  47 + overwrite: true
  48 + repo_name: k2-fsa/sherpa-onnx
  49 + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
  50 + tag: speaker-recognition-models
  1 +# Introduction
  2 +
  3 +This folder contains scripts for adding metadata to ONNX models from
  4 +https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md
  5 +
  6 +You can use the models with metadata in sherpa-onnx.
  7 +
  8 +
  9 +**Caution**: You have to add model metadata to the `*.onnx` models, since we
  10 +plan to support models from different frameworks and need to tell them apart.
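  11 +
  12 +The following is a minimal sketch (assuming `onnxruntime` is installed)
  13 +showing how the added metadata can be read back from a converted model,
  14 +e.g., `en_voxceleb_resnet34.onnx`:
  15 +
  16 +```python
  17 +import onnxruntime as ort
  18 +
  19 +# The key-value pairs written by add_meta_data.py end up in the
  20 +# model's custom metadata map.
  21 +sess = ort.InferenceSession("en_voxceleb_resnet34.onnx")
  22 +meta = sess.get_modelmeta().custom_metadata_map
  23 +print(meta["framework"], meta["language"], meta["sample_rate"])
  24 +```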
  1 +#!/usr/bin/env python3
  2 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +"""
  5 +This script adds metadata to a model so that it can be used in sherpa-onnx.
  6 +
  7 +Usage:
  8 +./add_meta_data.py --model ./voxceleb_resnet34.onnx --language English
  9 +"""
  10 +
  11 +import argparse
  12 +from pathlib import Path
  13 +from typing import Dict
  14 +
  15 +import onnx
  16 +import onnxruntime
  17 +
  18 +
  19 +def get_args():
  20 + parser = argparse.ArgumentParser()
  21 + parser.add_argument(
  22 + "--model",
  23 + type=str,
  24 + required=True,
  25 + help="Path to the input onnx model. Example value: model.onnx",
  26 + )
  27 +
  28 + parser.add_argument(
  29 + "--language",
  30 + type=str,
  31 + required=True,
  32 + help="""Supported language of the input model.
  33 + Example values: Chinese, English.
  34 + """,
  35 + )
  36 +
  37 + parser.add_argument(
  38 + "--url",
  39 + type=str,
  40 + default="https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md",
  41 + help="Where the model is downloaded",
  42 + )
  43 +
  44 + parser.add_argument(
  45 + "--comment",
  46 + type=str,
  47 + default="no comment",
  48 + help="Comment about the model",
  49 + )
  50 +
  51 + parser.add_argument(
  52 + "--sample-rate",
  53 + type=int,
  54 + default=16000,
  55 + help="Sample rate expected by the model",
  56 + )
  57 +
  58 + return parser.parse_args()
  59 +
  60 +
  61 +def add_meta_data(filename: str, meta_data: Dict[str, str]):
  62 + """Add meta data to an ONNX model. It is changed in-place.
  63 +
  64 + Args:
  65 + filename:
  66 + Filename of the ONNX model to be changed.
  67 + meta_data:
  68 + Key-value pairs.
  69 + """
  70 + model = onnx.load(filename)
  71 + for key, value in meta_data.items():
  72 + meta = model.metadata_props.add()
  73 + meta.key = key
  74 + meta.value = str(value)
  75 +
  76 + onnx.save(model, filename)
  77 +
  78 +
  79 +def get_output_dim(filename) -> int:
  80 + filename = str(filename)
  81 + session_opts = onnxruntime.SessionOptions()
  82 + session_opts.log_severity_level = 3 # error level
  83 + sess = onnxruntime.InferenceSession(filename, session_opts)
  84 +
  85 + for i in sess.get_inputs():
  86 + print(i)
  87 +
  88 + print("----------")
  89 +
  90 + for o in sess.get_outputs():
  91 + print(o)
  92 +
  93 + print("----------")
  94 +
  95 + assert len(sess.get_inputs()) == 1
  96 + assert len(sess.get_outputs()) == 1
  97 +
  98 + i = sess.get_inputs()[0]
  99 + o = sess.get_outputs()[0]
  100 +
  101 + assert i.shape[:2] == ["B", "T"], i.shape
  102 + assert o.shape[0] == "B"
  103 +
  104 + assert i.shape[2] == 80, i.shape
  105 +
  106 + return o.shape[1]
  107 +
  108 +
  109 +def main():
  110 + args = get_args()
  111 + model = Path(args.model)
  112 + language = args.language
  113 + url = args.url
  114 + comment = args.comment
  115 + sample_rate = args.sample_rate
  116 +
  117 + if not model.is_file():
  118 + raise ValueError(f"{model} does not exist")
  119 +
  120 + assert len(language) > 0, len(language)
  121 + assert len(url) > 0, len(url)
  122 +
  123 + output_dim = get_output_dim(model)
  124 +
  125 + # all models from wespeaker expect input samples in the range
  126 + # [-32768, 32767]
  127 + normalize_features = 0
  128 +
  129 + meta_data = {
  130 + "framework": "wespeaker",
  131 + "language": language,
  132 + "url": url,
  133 + "comment": comment,
  134 + "sample_rate": sample_rate,
  135 + "output_dim": output_dim,
  136 + "normalize_features": normalize_features,
  137 + }
  138 + print(meta_data)
  139 + add_meta_data(filename=str(model), meta_data=meta_data)
  140 +
  141 +
  142 +if __name__ == "__main__":
  143 + main()
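  144 +
  145 +# A minimal sketch (to be run separately) for checking that the metadata
  146 +# was indeed written into the model:
  147 +#
  148 +#   import onnx
  149 +#   m = onnx.load("voxceleb_resnet34.onnx")
  150 +#   print({p.key: p.value for p in m.metadata_props})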
  1 +#!/usr/bin/env bash
  2 +
  3 +set -ex
  4 +
  5 +echo "Downloading models"
  6 +export GIT_LFS_SKIP_SMUDGE=1
  7 +git clone https://huggingface.co/openspeech/wespeaker-models
  8 +cd wespeaker-models
  9 +git lfs pull --include "*.onnx"
  10 +ls -lh
  11 +cd ..
  12 +mv wespeaker-models/*.onnx .
  13 +ls -lh
  14 +
  15 +./add_meta_data.py \
  16 + --model ./voxceleb_resnet34.onnx \
  17 + --language English \
  18 + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.onnx
  19 +./test.py --model ./voxceleb_resnet34.onnx \
  20 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  21 + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav
  22 +
  23 +./test.py --model ./voxceleb_resnet34.onnx \
  24 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  25 + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav
  26 +
  27 +mv voxceleb_resnet34.onnx en_voxceleb_resnet34.onnx
  28 +
  29 +./add_meta_data.py \
  30 + --model ./voxceleb_resnet34_LM.onnx \
  31 + --language English \
  32 + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx
  33 +./test.py --model ./voxceleb_resnet34_LM.onnx \
  34 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  35 + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav
  36 +
  37 +./test.py --model ./voxceleb_resnet34_LM.onnx \
  38 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  39 + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav
  40 +
  41 +mv voxceleb_resnet34_LM.onnx en_voxceleb_resnet34_LM.onnx
  42 +
  43 +./add_meta_data.py \
  44 + --model ./voxceleb_resnet152_LM.onnx \
  45 + --language English \
  46 + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet152_LM.onnx
  47 +
  48 +./test.py --model ./voxceleb_resnet152_LM.onnx \
  49 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  50 + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav
  51 +
  52 +./test.py --model ./voxceleb_resnet152_LM.onnx \
  53 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  54 + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav
  55 +
  56 +mv voxceleb_resnet152_LM.onnx en_voxceleb_resnet152_LM.onnx
  57 +
  58 +./add_meta_data.py \
  59 + --model ./voxceleb_resnet221_LM.onnx \
  60 + --language English \
  61 + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet221_LM.onnx
  62 +
  63 +./test.py --model ./voxceleb_resnet221_LM.onnx \
  64 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  65 + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav
  66 +
  67 +./test.py --model ./voxceleb_resnet221_LM.onnx \
  68 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  69 + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav
  70 +
  71 +mv voxceleb_resnet221_LM.onnx en_voxceleb_resnet221_LM.onnx
  72 +
  73 +./add_meta_data.py \
  74 + --model ./voxceleb_resnet293_LM.onnx \
  75 + --language English \
  76 + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet293_LM.onnx
  77 +
  78 +./test.py --model ./voxceleb_resnet293_LM.onnx \
  79 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  80 + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav
  81 +
  82 +./test.py --model ./voxceleb_resnet293_LM.onnx \
  83 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  84 + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav
  85 +
  86 +mv voxceleb_resnet293_LM.onnx en_voxceleb_resnet293_LM.onnx
  87 +
  88 +./add_meta_data.py \
  89 + --model ./voxceleb_CAM++.onnx \
  90 + --language English \
  91 + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++.onnx
  92 +
  93 +./test.py --model ./voxceleb_CAM++.onnx \
  94 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  95 + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav
  96 +
  97 +./test.py --model ./voxceleb_CAM++.onnx \
  98 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  99 + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav
  100 +
  101 +mv voxceleb_CAM++.onnx en_voxceleb_CAM++.onnx
  102 +
  103 +./add_meta_data.py \
  104 + --model ./voxceleb_CAM++_LM.onnx \
  105 + --language English \
  106 + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++_LM.onnx
  107 +
  108 +./test.py --model ./voxceleb_CAM++_LM.onnx \
  109 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  110 + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav
  111 +
  112 +./test.py --model ./voxceleb_CAM++_LM.onnx \
  113 + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
  114 + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav
  115 +
  116 +mv voxceleb_CAM++_LM.onnx en_voxceleb_CAM++_LM.onnx
  117 +
  118 +./add_meta_data.py \
  119 + --model ./cnceleb_resnet34.onnx \
  120 + --language Chinese \
  121 + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34.onnx
  122 +
  123 +mv cnceleb_resnet34.onnx zh_cnceleb_resnet34.onnx
  124 +
  125 +./add_meta_data.py \
  126 + --model ./cnceleb_resnet34_LM.onnx \
  127 + --language Chinese \
  128 + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34_LM.onnx
  129 +
  130 +mv cnceleb_resnet34_LM.onnx zh_cnceleb_resnet34_LM.onnx
  131 +
  132 +ls -lh
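  133 +
  134 +# Note: each model above goes through the same add-metadata/test/rename
  135 +# pattern; a loop like the following sketch would be equivalent for,
  136 +# e.g., the two CN-Celeb models:
  137 +#
  138 +# for m in cnceleb_resnet34 cnceleb_resnet34_LM; do
  139 +#   ./add_meta_data.py --model ./$m.onnx --language Chinese \
  140 +#     --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/$m.onnx
  141 +#   mv $m.onnx zh_$m.onnx
  142 +# done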
  1 +#!/usr/bin/env python3
  2 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +"""
  5 +This script computes the cosine similarity score, in [-1, 1], between the
  6 +speaker embeddings of two wave files using a speaker recognition model.
  7 +"""
  8 +import argparse
  9 +import wave
  10 +from pathlib import Path
  11 +
  12 +import kaldi_native_fbank as knf
  13 +import numpy as np
  14 +import onnxruntime as ort
  15 +from numpy.linalg import norm
  16 +
  17 +
  18 +def get_args():
  19 + parser = argparse.ArgumentParser()
  20 + parser.add_argument(
  21 + "--model",
  22 + type=str,
  23 + required=True,
  24 + help="Path to the input onnx model. Example value: model.onnx",
  25 + )
  26 +
  27 + parser.add_argument(
  28 + "--file1",
  29 + type=str,
  30 + required=True,
  31 + help="Input wave 1",
  32 + )
  33 +
  34 + parser.add_argument(
  35 + "--file2",
  36 + type=str,
  37 + required=True,
  38 + help="Input wave 2",
  39 + )
  40 +
  41 + return parser.parse_args()
  42 +
  43 +
  44 +def read_wavefile(filename, expected_sample_rate: int = 16000) -> np.ndarray:
  45 + """
  46 + Args:
  47 + filename:
  48 + Path to a wave file, which must be 16-bit PCM with the expected sample rate.
  49 + expected_sample_rate:
  50 + Expected sample rate of the wave file.
  51 + Returns:
  52 + Return a 1-D float32 array containing audio samples. Each sample is in
  53 + the range [-1, 1].
  54 + """
  55 + filename = str(filename)
  56 + with wave.open(filename) as f:
  57 + # Note: this script does not resample. The sample rate of the wave
  58 + # file must match the sample rate expected by the model.
  59 + wave_file_sample_rate = f.getframerate()
  60 + assert wave_file_sample_rate == expected_sample_rate, (
  61 + wave_file_sample_rate,
  62 + expected_sample_rate,
  63 + )
  64 +
  65 + num_channels = f.getnchannels()
  66 + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
  67 + num_samples = f.getnframes()
  68 + samples = f.readframes(num_samples)
  69 + samples_int16 = np.frombuffer(samples, dtype=np.int16)
  70 + samples_int16 = samples_int16.reshape(-1, num_channels)[:, 0]
  71 + samples_float32 = samples_int16.astype(np.float32)
  72 +
  73 + samples_float32 = samples_float32 / 32768
  74 +
  75 + return samples_float32
  76 +
  77 +
  78 +def compute_features(samples: np.ndarray, sample_rate: int) -> np.ndarray:
  79 + opts = knf.FbankOptions()
  80 + opts.frame_opts.dither = 0
  81 + opts.frame_opts.samp_freq = sample_rate
  82 + opts.frame_opts.snip_edges = False
  83 +
  84 + opts.mel_opts.num_bins = 80
  85 + opts.mel_opts.debug_mel = False
  86 +
  87 + fbank = knf.OnlineFbank(opts)
  88 + fbank.accept_waveform(sample_rate, samples)
  89 + fbank.input_finished()
  90 +
  91 + features = []
  92 + for i in range(fbank.num_frames_ready):
  93 + f = fbank.get_frame(i)
  94 + features.append(f)
  95 + features = np.stack(features, axis=0)
  96 +
  97 + return features
  98 +
  99 +
  100 +class OnnxModel:
  101 + def __init__(
  102 + self,
  103 + filename: str,
  104 + ):
  105 + session_opts = ort.SessionOptions()
  106 + session_opts.inter_op_num_threads = 1
  107 + session_opts.intra_op_num_threads = 4
  108 +
  109 + self.session_opts = session_opts
  110 +
  111 + self.model = ort.InferenceSession(
  112 + filename,
  113 + sess_options=self.session_opts,
  114 + )
  115 +
  116 + meta = self.model.get_modelmeta().custom_metadata_map
  117 + self.normalize_features = int(meta["normalize_features"])
  118 + self.sample_rate = int(meta["sample_rate"])
  119 + self.output_dim = int(meta["output_dim"])
  120 +
  121 + def __call__(self, x: np.ndarray) -> np.ndarray:
  122 + """
  123 + Args:
  124 + x:
  125 + A 2-D float32 tensor of shape (T, C), where C is the feature dim (80).
  126 + Returns:
  127 + Return a 1-D float32 tensor containing the speaker embedding.
  128 + """
  129 + x = np.expand_dims(x, axis=0)
  130 +
  131 + return self.model.run(
  132 + [
  133 + self.model.get_outputs()[0].name,
  134 + ],
  135 + {
  136 + self.model.get_inputs()[0].name: x,
  137 + },
  138 + )[0][0]
  139 +
  140 +
  141 +def main():
  142 + args = get_args()
  143 + filename = Path(args.model)
  144 + file1 = Path(args.file1)
  145 + file2 = Path(args.file2)
  146 + assert filename.is_file(), filename
  147 + assert file1.is_file(), file1
  148 + assert file2.is_file(), file2
  149 +
  150 + model = OnnxModel(filename)
  151 + wave1 = read_wavefile(file1, model.sample_rate)
  152 + wave2 = read_wavefile(file2, model.sample_rate)
  153 +
  154 + if not model.normalize_features:
  155 + wave1 = wave1 * 32768
  156 + wave2 = wave2 * 32768
  157 +
  158 + features1 = compute_features(wave1, model.sample_rate)
  159 + features2 = compute_features(wave2, model.sample_rate)
  160 +
  161 + output1 = model(features1)
  162 + output2 = model(features2)
  163 +
  164 + print(output1.shape)
  165 + print(output2.shape)
  166 + similarity = np.dot(output1, output2) / (norm(output1) * norm(output2))
  167 + print(f"similarity in the range [0-1]: {similarity}")
  168 +
  169 +
  170 +if __name__ == "__main__":
  171 + main()
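  172 +
  173 +# A sketch of turning the similarity into a same/different-speaker
  174 +# decision; the threshold 0.5 below is only an illustrative assumption
  175 +# and should be tuned on held-out data:
  176 +#
  177 +#   threshold = 0.5
  178 +#   print("same speaker" if similarity > threshold else "different speakers")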