Fangjun Kuang

Support exporting models to onnx from 3D-Speaker (#522)

@@ -0,0 +1,45 @@
+name: export-3dspeaker-to-onnx
+
+on:
+  workflow_dispatch:
+
+concurrency:
+  group: export-3dspeaker-to-onnx-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  export-3dspeaker-to-onnx:
+    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
+    name: export 3d-speaker to ONNX
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [macos-latest]
+        python-version: ["3.8"]
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Run
+        shell: bash
+        run: |
+          cd scripts/3dspeaker
+          ./run.sh
+
+          mv -v *.onnx ../..
+
+      - name: Release
+        uses: svenstaro/upload-release-action@v2
+        with:
+          file_glob: true
+          file: ./*.onnx
+          overwrite: true
+          repo_name: k2-fsa/sherpa-onnx
+          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+          tag: speaker-recongition-models

@@ -0,0 +1,5 @@
+# Introduction
+
+This directory contains scripts
+for exporting models from https://github.com/alibaba-damo-academy/3D-Speaker
+to `onnx` so that they can be used in `sherpa-onnx`.
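
As a quick sanity check (not part of this commit), the metadata that
`export-onnx.py` writes into a model can be inspected with `onnxruntime`.
A minimal sketch, assuming `run.sh` below has already produced a model
file; the filename here is just an example:

```python
import onnxruntime as ort

# Load an exported model and print the custom metadata added by
# export-onnx.py (sample_rate, normalize_samples, etc.)
session = ort.InferenceSession("speech_campplus_sv_zh-cn_16k-common.onnx")
for key, value in session.get_modelmeta().custom_metadata_map.items():
    print(f"{key}={value}")
```
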
@@ -0,0 +1,146 @@
+#!/usr/bin/env python3
+# Copyright 2023-2024 Xiaomi Corp. (authors: Fangjun Kuang)
+
+import argparse
+import json
+import os
+import pathlib
+import re
+from typing import Dict
+
+import onnx
+import torch
+from infer_sv import supports
+from modelscope.hub.snapshot_download import snapshot_download
+from speakerlab.utils.builder import dynamic_import
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+    """Add metadata to an ONNX model. The model file is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = str(value)
+
+    onnx.save(model, filename)
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        choices=[
+            "speech_campplus_sv_en_voxceleb_16k",
+            "speech_campplus_sv_zh-cn_16k-common",
+            "speech_eres2net_sv_en_voxceleb_16k",
+            "speech_eres2net_sv_zh-cn_16k-common",
+            "speech_eres2net_base_200k_sv_zh-cn_16k-common",
+            "speech_eres2net_base_sv_zh-cn_3dspeaker_16k",
+            "speech_eres2net_large_sv_zh-cn_3dspeaker_16k",
+        ],
+    )
+    return parser.parse_args()
+
+
+@torch.no_grad()
+def main():
+    args = get_args()
+    local_model_dir = "pretrained"
+    model_id = f"damo/{args.model}"
+    conf = supports[model_id]
+    cache_dir = snapshot_download(
+        model_id,
+        revision=conf["revision"],
+    )
+    cache_dir = pathlib.Path(cache_dir)
+
+    save_dir = os.path.join(local_model_dir, model_id.split("/")[1])
+    save_dir = pathlib.Path(save_dir)
+    save_dir.mkdir(exist_ok=True, parents=True)
+
+    download_files = ["examples", conf["model_pt"]]
+    for src in cache_dir.glob("*"):
+        if re.search("|".join(download_files), src.name):
+            dst = save_dir / src.name
+            try:
+                dst.unlink()
+            except FileNotFoundError:
+                pass
+            dst.symlink_to(src)
+    pretrained_model = save_dir / conf["model_pt"]
+    pretrained_state = torch.load(pretrained_model, map_location="cpu")
+
+    model = conf["model"]
+    embedding_model = dynamic_import(model["obj"])(**model["args"])
+    embedding_model.load_state_dict(pretrained_state)
+    embedding_model.eval()
+
+    with open(f"{cache_dir}/configuration.json") as f:
+        json_config = json.loads(f.read())
+    print(json_config)
+
+    T = 100
+    C = 80
+    x = torch.rand(1, T, C)
+    filename = f"{args.model}.onnx"
+    torch.onnx.export(
+        embedding_model,
+        x,
+        filename,
+        opset_version=13,
+        input_names=["x"],
+        output_names=["embedding"],
+        dynamic_axes={
+            "x": {0: "N", 1: "T"},
+            "embedding": {0: "N"},
+        },
+    )
+
+    # all models from 3d-speaker expect input samples in the range
+    # [-1, 1]
+    normalize_samples = 1
+
+    # all models from 3d-speaker normalize the features by the global mean
+    feature_normalize_type = "global-mean"
+    sample_rate = json_config["model"]["model_config"]["sample_rate"]
+
+    feat_dim = conf["model"]["args"]["feat_dim"]
+    assert feat_dim == 80, feat_dim
+
+    output_dim = conf["model"]["args"]["embedding_size"]
+
+    if "zh-cn" in args.model:
+        language = "Chinese"
+    elif "en" in args.model:
+        language = "English"
+    else:
+        raise ValueError(f"Unsupported language for model {args.model}")
+
+    comment = f"This model is from damo/{args.model}"
+    url = f"https://www.modelscope.cn/models/damo/{args.model}/summary"
+
+    meta_data = {
+        "framework": "3d-speaker",
+        "language": language,
+        "url": url,
+        "comment": comment,
+        "sample_rate": sample_rate,
+        "output_dim": output_dim,
+        "normalize_samples": normalize_samples,
+        "feature_normalize_type": feature_normalize_type,
+    }
+    print(meta_data)
+    add_meta_data(filename=filename, meta_data=meta_data)
+
+
+main()
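
Not part of the commit: one way to validate an export like the above is to
run the PyTorch model and the ONNX model on the same random input and
compare the outputs. A sketch, assuming `embedding_model` and `filename`
as defined in `main()` above; the tolerances are illustrative:

```python
import numpy as np
import onnxruntime as ort
import torch


def verify_export(embedding_model: torch.nn.Module, filename: str):
    # (N, T, C) dummy input, matching the shape used for the export above
    x = torch.rand(1, 100, 80)
    with torch.no_grad():
        expected = embedding_model(x).numpy()

    session = ort.InferenceSession(filename)
    (actual,) = session.run(["embedding"], {"x": x.numpy()})

    # The two outputs should agree up to small numerical differences
    np.testing.assert_allclose(expected, actual, rtol=1e-3, atol=1e-4)
```
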
@@ -0,0 +1,63 @@
+#!/usr/bin/env bash
+
+set -e
+
+function install_3d_speaker() {
+  echo "Install 3D-Speaker"
+  git clone https://github.com/alibaba-damo-academy/3D-Speaker.git
+  pushd 3D-Speaker
+  pip install -q -r ./requirements.txt
+  pip install -q modelscope onnx onnxruntime kaldi-native-fbank
+  popd
+}
+
+function download_test_data() {
+  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
+  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
+  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
+
+  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_en_16k.wav
+  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_en_16k.wav
+  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_en_16k.wav
+}
+
+install_3d_speaker
+
+download_test_data
+
+export PYTHONPATH=$PWD/3D-Speaker:$PYTHONPATH
+export PYTHONPATH=$PWD/3D-Speaker/speakerlab/bin:$PYTHONPATH
+
+models=(
+speech_campplus_sv_en_voxceleb_16k
+speech_campplus_sv_zh-cn_16k-common
+speech_eres2net_sv_en_voxceleb_16k
+speech_eres2net_sv_zh-cn_16k-common
+speech_eres2net_base_200k_sv_zh-cn_16k-common
+speech_eres2net_base_sv_zh-cn_3dspeaker_16k
+speech_eres2net_large_sv_zh-cn_3dspeaker_16k
+)
+for model in ${models[@]}; do
+  echo "--------------------$model--------------------"
+  python3 ./export-onnx.py --model $model
+
+  python3 ./test-onnx.py \
+    --model ${model}.onnx \
+    --file1 ./speaker1_a_cn_16k.wav \
+    --file2 ./speaker1_b_cn_16k.wav
+
+  python3 ./test-onnx.py \
+    --model ${model}.onnx \
+    --file1 ./speaker1_a_cn_16k.wav \
+    --file2 ./speaker2_a_cn_16k.wav
+
+  python3 ./test-onnx.py \
+    --model ${model}.onnx \
+    --file1 ./speaker1_a_en_16k.wav \
+    --file2 ./speaker1_b_en_16k.wav
+
+  python3 ./test-onnx.py \
+    --model ${model}.onnx \
+    --file1 ./speaker1_a_en_16k.wav \
+    --file2 ./speaker2_a_en_16k.wav
+done

@@ -0,0 +1,173 @@
+#!/usr/bin/env python3
+# Copyright 2023-2024 Xiaomi Corp. (authors: Fangjun Kuang)
+
+"""
+This script computes speaker similarity score in the range [0-1]
+of two wave files using a speaker embedding model.
+"""
+import argparse
+import wave
+from pathlib import Path
+
+import kaldi_native_fbank as knf
+import numpy as np
+import onnxruntime as ort
+from numpy.linalg import norm
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        help="Path to the input onnx model. Example value: model.onnx",
+    )
+
+    parser.add_argument(
+        "--file1",
+        type=str,
+        required=True,
+        help="Input wave 1",
+    )
+
+    parser.add_argument(
+        "--file2",
+        type=str,
+        required=True,
+        help="Input wave 2",
+    )
+
+    return parser.parse_args()
+
+
+def read_wavefile(filename, expected_sample_rate: int = 16000) -> np.ndarray:
+    """
+    Args:
+      filename:
+        Path to a wave file. It must be 16-bit PCM sampled at expected_sample_rate.
+      expected_sample_rate:
+        Expected sample rate of the wave file.
+    Returns:
+      Return a 1-D float32 array containing audio samples. Each sample is in
+      the range [-1, 1].
+    """
+    filename = str(filename)
+    with wave.open(filename) as f:
+        wave_file_sample_rate = f.getframerate()
+        assert wave_file_sample_rate == expected_sample_rate, (
+            wave_file_sample_rate,
+            expected_sample_rate,
+        )
+
+        num_channels = f.getnchannels()
+        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
+        num_samples = f.getnframes()
+        samples = f.readframes(num_samples)
+        samples_int16 = np.frombuffer(samples, dtype=np.int16)
+        samples_int16 = samples_int16.reshape(-1, num_channels)[:, 0]
+        samples_float32 = samples_int16.astype(np.float32)
+
+        samples_float32 = samples_float32 / 32768
+
+        return samples_float32
+
+
+def compute_features(samples: np.ndarray, sample_rate: int) -> np.ndarray:
+    opts = knf.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.frame_opts.samp_freq = sample_rate
+    opts.frame_opts.snip_edges = True
+
+    opts.mel_opts.num_bins = 80
+    opts.mel_opts.debug_mel = False
+
+    fbank = knf.OnlineFbank(opts)
+    fbank.accept_waveform(sample_rate, samples)
+    fbank.input_finished()
+
+    features = []
+    for i in range(fbank.num_frames_ready):
+        f = fbank.get_frame(i)
+        features.append(f)
+    features = np.stack(features, axis=0)
+
+    return features
+
+
+class OnnxModel:
+    def __init__(
+        self,
+        filename: str,
+    ):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.session_opts = session_opts
+
+        self.model = ort.InferenceSession(
+            filename,
+            sess_options=self.session_opts,
+        )
+
+        meta = self.model.get_modelmeta().custom_metadata_map
+        self.normalize_samples = int(meta["normalize_samples"])
+        self.sample_rate = int(meta["sample_rate"])
+        self.output_dim = int(meta["output_dim"])
+        self.feature_normalize_type = meta["feature_normalize_type"]
+
+    def __call__(self, x: np.ndarray) -> np.ndarray:
+        """
+        Args:
+          x:
+            A 2-D float32 tensor of shape (T, C).
+        Returns:
+          Return a 1-D float32 tensor containing the speaker embedding.
+        """
+        x = np.expand_dims(x, axis=0)
+
+        return self.model.run(
+            [
+                self.model.get_outputs()[0].name,
+            ],
+            {
+                self.model.get_inputs()[0].name: x,
+            },
+        )[0][0]
+
+
+def main():
+    args = get_args()
+    print(args)
+    filename = Path(args.model)
+    file1 = Path(args.file1)
+    file2 = Path(args.file2)
+    assert filename.is_file(), filename
+    assert file1.is_file(), file1
+    assert file2.is_file(), file2
+
+    model = OnnxModel(filename)
+    wave1 = read_wavefile(file1, model.sample_rate)
+    wave2 = read_wavefile(file2, model.sample_rate)
+
+    if not model.normalize_samples:
+        wave1 = wave1 * 32768
+        wave2 = wave2 * 32768
+
+    features1 = compute_features(wave1, model.sample_rate)
+    features2 = compute_features(wave2, model.sample_rate)
+
+    if model.feature_normalize_type == "global-mean":
+        features1 -= features1.mean(axis=0, keepdims=True)
+        features2 -= features2.mean(axis=0, keepdims=True)
+
+    output1 = model(features1)
+    output2 = model(features2)
+
+    similarity = np.dot(output1, output2) / (norm(output1) * norm(output2))
+    print(f"similarity in the range [0-1]: {similarity}")
+
+
+if __name__ == "__main__":
+    main()
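
For reuse outside this script, the preprocessing and scoring steps in
`main()` can be folded into a single helper. A sketch (not in the original
file), assuming the `OnnxModel`, `read_wavefile`, and `compute_features`
definitions above, with `np` and `norm` imported as in this script:

```python
def speaker_similarity(model: OnnxModel, file1: str, file2: str) -> float:
    """Cosine similarity of the embeddings of two wave files, honoring
    the preprocessing conventions stored in the model metadata."""
    embeddings = []
    for f in (file1, file2):
        samples = read_wavefile(f, model.sample_rate)
        if not model.normalize_samples:
            # e.g., wespeaker models expect samples in [-32768, 32767]
            samples = samples * 32768
        features = compute_features(samples, model.sample_rate)
        if model.feature_normalize_type == "global-mean":
            features = features - features.mean(axis=0, keepdims=True)
        embeddings.append(model(features))

    a, b = embeddings
    return float(np.dot(a, b) / (norm(a) * norm(b)))
```
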
@@ -124,7 +124,7 @@ def main():
 
     # all models from wespeaker expect input samples in the range
     # [-32768, 32767]
-    normalize_features = 0
+    normalize_samples = 0
 
     meta_data = {
         "framework": "wespeaker",
@@ -133,7 +133,7 @@ def main():
         "comment": comment,
         "sample_rate": sample_rate,
         "output_dim": output_dim,
-        "normalize_features": normalize_features,
+        "normalize_samples": normalize_samples,
     }
     print(meta_data)
     add_meta_data(filename=str(model), meta_data=meta_data)

@@ -3,7 +3,7 @@
 
 """
 This script computes speaker similarity score in the range [0-1]
-of two wave files using a speaker recognition model.
+of two wave files using a speaker embedding model.
 """
 import argparse
 import wave
@@ -54,8 +54,6 @@ def read_wavefile(filename, expected_sample_rate: int = 16000) -> np.ndarray:
     """
     filename = str(filename)
     with wave.open(filename) as f:
-        # Note: If wave_file_sample_rate is different from
-        # recognizer.sample_rate, we will do resampling inside sherpa-ncnn
         wave_file_sample_rate = f.getframerate()
         assert wave_file_sample_rate == expected_sample_rate, (
             wave_file_sample_rate,
@@ -104,7 +102,7 @@ class OnnxModel:
     ):
         session_opts = ort.SessionOptions()
         session_opts.inter_op_num_threads = 1
-        session_opts.intra_op_num_threads = 4
+        session_opts.intra_op_num_threads = 1
 
         self.session_opts = session_opts
 
@@ -114,7 +112,7 @@ class OnnxModel:
         )
 
         meta = self.model.get_modelmeta().custom_metadata_map
-        self.normalize_features = int(meta["normalize_features"])
+        self.normalize_samples = int(meta["normalize_samples"])
         self.sample_rate = int(meta["sample_rate"])
         self.output_dim = int(meta["output_dim"])
 
@@ -151,7 +149,7 @@ def main():
     wave1 = read_wavefile(file1, model.sample_rate)
     wave2 = read_wavefile(file2, model.sample_rate)
 
-    if not model.normalize_features:
+    if not model.normalize_samples:
         wave1 = wave1 * 32768
         wave2 = wave2 * 32768
 
@@ -161,8 +159,6 @@ def main():
     output1 = model(features1)
     output2 = model(features2)
 
-    print(output1.shape)
-    print(output2.shape)
     similarity = np.dot(output1, output2) / (norm(output1) * norm(output2))
     print(f"similarity in the range [0-1]: {similarity}")
 
@@ -27,7 +27,7 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl
     FeatureExtractorConfig feat_config;
     auto meta_data = model_.GetMetaData();
     feat_config.sampling_rate = meta_data.sample_rate;
-    feat_config.normalize_samples = meta_data.normalize_features;
+    feat_config.normalize_samples = meta_data.normalize_samples;
 
     return std::make_unique<OnlineStream>(feat_config);
   }
@@ -12,7 +12,7 @@ namespace sherpa_onnx {
 struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData {
   int32_t output_dim = 0;
   int32_t sample_rate = 0;
-  int32_t normalize_features = 0;
+  int32_t normalize_samples = 0;
   std::string language;
 };
 
@@ -61,8 +61,8 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
     Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
     SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
    SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
-    SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_features,
-                               "normalize_features");
+    SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_samples,
+                               "normalize_samples");
     SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
 
     std::string framework;