Committed by GitHub
Support exporting models to onnx from 3D-Speaker (#522)
Showing 10 changed files with 442 additions and 14 deletions.
name: export-3dspeaker-to-onnx

on:
  workflow_dispatch:

concurrency:
  group: export-3dspeaker-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-3dspeaker-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export 3d-speaker to ONNX
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.8"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run
        shell: bash
        run: |
          cd scripts/3dspeaker
          ./run.sh

          mv -v *.onnx ../..

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.onnx
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          # NOTE: "recongition" is misspelled, but the value is kept as-is
          # because it must match the tag of the existing release that the
          # assets are uploaded to.
          tag: speaker-recongition-models
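The workflow above runs only on `workflow_dispatch`, i.e. it must be triggered by hand. Besides the Actions tab, it can be started through the GitHub REST API; the sketch below assumes the workflow file is named `export-3dspeaker-to-onnx.yaml` and that `GITHUB_TOKEN` holds a token with `workflow` scope.

```python
# Minimal sketch: trigger the workflow_dispatch event via the GitHub REST API.
# The workflow file name and the target ref are assumptions.
import os

import requests

resp = requests.post(
    "https://api.github.com/repos/k2-fsa/sherpa-onnx/actions/workflows/"
    "export-3dspeaker-to-onnx.yaml/dispatches",
    headers={
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
        "Accept": "application/vnd.github+json",
    },
    json={"ref": "master"},
)
resp.raise_for_status()  # the endpoint returns "204 No Content" on success
```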
scripts/3dspeaker/README.md (new file, mode 100644)
scripts/3dspeaker/export-onnx.py (new file, mode 100755)
#!/usr/bin/env python3
# Copyright 2023-2024 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
import json
import os
import pathlib
import re
from typing import Dict

import onnx
import torch
from infer_sv import supports
from modelscope.hub.snapshot_download import snapshot_download
from speakerlab.utils.builder import dynamic_import


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=[
            "speech_campplus_sv_en_voxceleb_16k",
            "speech_campplus_sv_zh-cn_16k-common",
            "speech_eres2net_sv_en_voxceleb_16k",
            "speech_eres2net_sv_zh-cn_16k-common",
            "speech_eres2net_base_200k_sv_zh-cn_16k-common",
            "speech_eres2net_base_sv_zh-cn_3dspeaker_16k",
            "speech_eres2net_large_sv_zh-cn_3dspeaker_16k",
        ],
    )
    return parser.parse_args()


@torch.no_grad()
def main():
    args = get_args()
    local_model_dir = "pretrained"
    model_id = f"damo/{args.model}"
    conf = supports[model_id]
    cache_dir = snapshot_download(
        model_id,
        revision=conf["revision"],
    )
    cache_dir = pathlib.Path(cache_dir)

    save_dir = os.path.join(local_model_dir, model_id.split("/")[1])
    save_dir = pathlib.Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)

    download_files = ["examples", conf["model_pt"]]
    for src in cache_dir.glob("*"):
        if re.search("|".join(download_files), src.name):
            dst = save_dir / src.name
            try:
                dst.unlink()
            except FileNotFoundError:
                pass
            dst.symlink_to(src)

    pretrained_model = save_dir / conf["model_pt"]
    pretrained_state = torch.load(pretrained_model, map_location="cpu")

    model = conf["model"]
    embedding_model = dynamic_import(model["obj"])(**model["args"])
    embedding_model.load_state_dict(pretrained_state)
    embedding_model.eval()

    with open(f"{cache_dir}/configuration.json") as f:
        json_config = json.loads(f.read())
        print(json_config)

    T = 100
    C = 80
    x = torch.rand(1, T, C)
    filename = f"{args.model}.onnx"
    torch.onnx.export(
        embedding_model,
        x,
        filename,
        opset_version=13,
        input_names=["x"],
        output_names=["embedding"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            # The key below must match the name in output_names above;
            # the original read "embeddings", which silently left the
            # output's batch dimension static.
            "embedding": {0: "N"},
        },
    )

    # all models from 3d-speaker expect input samples in the range
    # [-1, 1]
    normalize_samples = 1

    # all models from 3d-speaker normalize the features by the global mean
    feature_normalize_type = "global-mean"
    sample_rate = json_config["model"]["model_config"]["sample_rate"]

    feat_dim = conf["model"]["args"]["feat_dim"]
    assert feat_dim == 80, feat_dim

    output_dim = conf["model"]["args"]["embedding_size"]

    if "zh-cn" in args.model:
        language = "Chinese"
    elif "en" in args.model:
        language = "English"
    else:
        raise ValueError(f"Unsupported language for model {args.model}")

    comment = f"This model is from damo/{args.model}"
    url = f"https://www.modelscope.cn/models/damo/{args.model}/summary"

    meta_data = {
        "framework": "3d-speaker",
        "language": language,
        "url": url,
        "comment": comment,
        "sample_rate": sample_rate,
        "output_dim": output_dim,
        "normalize_samples": normalize_samples,
        "feature_normalize_type": feature_normalize_type,
    }
    print(meta_data)
    add_meta_data(filename=filename, meta_data=meta_data)


if __name__ == "__main__":
    main()
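Since `dynamic_axes` is easy to get wrong (see the fixed key above), a quick sanity check after export is worthwhile. A minimal sketch, assuming one of the exported files is in the current directory: run two inputs with different frame counts through onnxruntime and confirm the embedding size matches the `output_dim` metadata.

```python
# Minimal sketch: check that the exported model accepts variable-length input.
# The model file name is an assumption; any exported model can be used.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("speech_campplus_sv_zh-cn_16k-common.onnx")
meta = sess.get_modelmeta().custom_metadata_map
output_dim = int(meta["output_dim"])

for num_frames in (100, 237):  # two different values for the T axis
    x = np.random.rand(1, num_frames, 80).astype(np.float32)
    (embedding,) = sess.run(["embedding"], {"x": x})
    assert embedding.shape == (1, output_dim), embedding.shape

print("dynamic axes OK; output_dim =", output_dim)
```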
scripts/3dspeaker/run.sh (new file, mode 100755)
#!/usr/bin/env bash

set -e

function install_3d_speaker() {
  echo "Install 3D-Speaker"
  git clone https://github.com/alibaba-damo-academy/3D-Speaker.git
  pushd 3D-Speaker
  pip install -q -r ./requirements.txt
  pip install -q modelscope onnx onnxruntime kaldi-native-fbank
  popd
}

function download_test_data() {
  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav

  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_en_16k.wav
  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_en_16k.wav
  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_en_16k.wav
}

install_3d_speaker

download_test_data

export PYTHONPATH=$PWD/3D-Speaker:$PYTHONPATH
export PYTHONPATH=$PWD/3D-Speaker/speakerlab/bin:$PYTHONPATH

models=(
speech_campplus_sv_en_voxceleb_16k
speech_campplus_sv_zh-cn_16k-common
speech_eres2net_sv_en_voxceleb_16k
speech_eres2net_sv_zh-cn_16k-common
speech_eres2net_base_200k_sv_zh-cn_16k-common
speech_eres2net_base_sv_zh-cn_3dspeaker_16k
speech_eres2net_large_sv_zh-cn_3dspeaker_16k
)
# Quote the expansions so the loop is robust to unusual model names.
for model in "${models[@]}"; do
  echo "--------------------$model--------------------"
  python3 ./export-onnx.py --model "$model"

  python3 ./test-onnx.py \
    --model ${model}.onnx \
    --file1 ./speaker1_a_cn_16k.wav \
    --file2 ./speaker1_b_cn_16k.wav

  python3 ./test-onnx.py \
    --model ${model}.onnx \
    --file1 ./speaker1_a_cn_16k.wav \
    --file2 ./speaker2_a_cn_16k.wav

  python3 ./test-onnx.py \
    --model ${model}.onnx \
    --file1 ./speaker1_a_en_16k.wav \
    --file2 ./speaker1_b_en_16k.wav

  python3 ./test-onnx.py \
    --model ${model}.onnx \
    --file1 ./speaker1_a_en_16k.wav \
    --file2 ./speaker2_a_en_16k.wav
done
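The loop checks each model on same-speaker and different-speaker pairs; a same-speaker pair should score noticeably higher. To script that comparison directly, note that `test-onnx.py` cannot be imported with a plain `import` statement because of the hyphen in its file name. A minimal sketch using `importlib` instead (model and wave file names are the ones used above):

```python
# Minimal sketch: reuse the helpers from test-onnx.py programmatically.
import importlib.util

import numpy as np

spec = importlib.util.spec_from_file_location("test_onnx", "test-onnx.py")
test_onnx = importlib.util.module_from_spec(spec)
spec.loader.exec_module(test_onnx)  # safe: main() sits behind a __main__ guard

model = test_onnx.OnnxModel("speech_campplus_sv_zh-cn_16k-common.onnx")


def embed(wav: str) -> np.ndarray:
    samples = test_onnx.read_wavefile(wav, model.sample_rate)
    features = test_onnx.compute_features(samples, model.sample_rate)
    features -= features.mean(axis=0, keepdims=True)  # global-mean normalization
    e = model(features)
    return e / np.linalg.norm(e)  # unit-normalize for cosine similarity


same = np.dot(embed("./speaker1_a_cn_16k.wav"), embed("./speaker1_b_cn_16k.wav"))
diff = np.dot(embed("./speaker1_a_cn_16k.wav"), embed("./speaker2_a_cn_16k.wav"))
assert same > diff, (same, diff)  # same speaker should score higher
```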
scripts/3dspeaker/test-onnx.py (new file, mode 100755)
#!/usr/bin/env python3
# Copyright 2023-2024 Xiaomi Corp. (authors: Fangjun Kuang)

"""
This script computes the speaker similarity score, in the range [-1, 1],
of two wave files using a speaker embedding model.
"""
import argparse
import wave
from pathlib import Path

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
from numpy.linalg import norm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to the input onnx model. Example value: model.onnx",
    )

    parser.add_argument(
        "--file1",
        type=str,
        required=True,
        help="Input wave 1",
    )

    parser.add_argument(
        "--file2",
        type=str,
        required=True,
        help="Input wave 2",
    )

    return parser.parse_args()


def read_wavefile(filename, expected_sample_rate: int = 16000) -> np.ndarray:
    """
    Args:
      filename:
        Path to a wave file, which must be of 16-bit and 16kHz.
      expected_sample_rate:
        Expected sample rate of the wave file.
    Returns:
      Return a 1-D float32 array containing audio samples. Each sample is in
      the range [-1, 1].
    """
    filename = str(filename)
    with wave.open(filename) as f:
        wave_file_sample_rate = f.getframerate()
        assert wave_file_sample_rate == expected_sample_rate, (
            wave_file_sample_rate,
            expected_sample_rate,
        )

        num_channels = f.getnchannels()
        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
        num_samples = f.getnframes()
        samples = f.readframes(num_samples)
        samples_int16 = np.frombuffer(samples, dtype=np.int16)
        samples_int16 = samples_int16.reshape(-1, num_channels)[:, 0]
        samples_float32 = samples_int16.astype(np.float32)

        samples_float32 = samples_float32 / 32768

    return samples_float32


def compute_features(samples: np.ndarray, sample_rate: int) -> np.ndarray:
    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.samp_freq = sample_rate
    opts.frame_opts.snip_edges = True

    opts.mel_opts.num_bins = 80
    opts.mel_opts.debug_mel = False

    fbank = knf.OnlineFbank(opts)
    fbank.accept_waveform(sample_rate, samples)
    fbank.input_finished()

    features = []
    for i in range(fbank.num_frames_ready):
        f = fbank.get_frame(i)
        features.append(f)
    features = np.stack(features, axis=0)

    return features


class OnnxModel:
    def __init__(
        self,
        filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
        )

        meta = self.model.get_modelmeta().custom_metadata_map
        self.normalize_samples = int(meta["normalize_samples"])
        self.sample_rate = int(meta["sample_rate"])
        self.output_dim = int(meta["output_dim"])
        self.feature_normalize_type = meta["feature_normalize_type"]

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Args:
          x:
            A 2-D float32 tensor of shape (T, C).
        Returns:
          A 1-D float32 tensor containing the speaker embedding.
        """
        x = np.expand_dims(x, axis=0)

        return self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x,
            },
        )[0][0]


def main():
    args = get_args()
    print(args)
    filename = Path(args.model)
    file1 = Path(args.file1)
    file2 = Path(args.file2)
    assert filename.is_file(), filename
    assert file1.is_file(), file1
    assert file2.is_file(), file2

    model = OnnxModel(filename)
    wave1 = read_wavefile(file1, model.sample_rate)
    wave2 = read_wavefile(file2, model.sample_rate)

    if not model.normalize_samples:
        wave1 = wave1 * 32768
        wave2 = wave2 * 32768

    features1 = compute_features(wave1, model.sample_rate)
    features2 = compute_features(wave2, model.sample_rate)

    if model.feature_normalize_type == "global-mean":
        features1 -= features1.mean(axis=0, keepdims=True)
        features2 -= features2.mean(axis=0, keepdims=True)

    output1 = model(features1)
    output2 = model(features2)

    similarity = np.dot(output1, output2) / (norm(output1) * norm(output2))
    print(f"similarity in the range [-1, 1]: {similarity}")


if __name__ == "__main__":
    main()
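The script prints the raw cosine similarity and leaves the accept/reject decision to the caller. A minimal sketch of such a decision rule; the 0.5 threshold is purely an assumption and should be tuned per model on labeled same/different-speaker pairs:

```python
# Minimal sketch: turn a cosine similarity into a same-speaker decision.
# THRESHOLD = 0.5 is an assumption; tune it per model on held-out pairs.
import numpy as np

THRESHOLD = 0.5


def same_speaker(emb1: np.ndarray, emb2: np.ndarray) -> bool:
    similarity = np.dot(emb1, emb2) / (
        np.linalg.norm(emb1) * np.linalg.norm(emb2)
    )
    return bool(similarity >= THRESHOLD)
```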
@@ -124,7 +124,7 @@ def main():
 
     # all models from wespeaker expect input samples in the range
     # [-32768, 32767]
-    normalize_features = 0
+    normalize_samples = 0
 
     meta_data = {
         "framework": "wespeaker",
@@ -133,7 +133,7 @@ def main():
         "comment": comment,
         "sample_rate": sample_rate,
         "output_dim": output_dim,
-        "normalize_features": normalize_features,
+        "normalize_samples": normalize_samples,
     }
     print(meta_data)
     add_meta_data(filename=str(model), meta_data=meta_data)
@@ -3,7 +3,7 @@
 
 """
 This script computes speaker similarity score in the range [0-1]
-of two wave files using a speaker recognition model.
+of two wave files using a speaker embedding model.
 """
 import argparse
 import wave
@@ -54,8 +54,6 @@ def read_wavefile(filename, expected_sample_rate: int = 16000) -> np.ndarray:
     """
     filename = str(filename)
     with wave.open(filename) as f:
-        # Note: If wave_file_sample_rate is different from
-        # recognizer.sample_rate, we will do resampling inside sherpa-ncnn
         wave_file_sample_rate = f.getframerate()
         assert wave_file_sample_rate == expected_sample_rate, (
             wave_file_sample_rate,
@@ -104,7 +102,7 @@ class OnnxModel:
     ):
         session_opts = ort.SessionOptions()
         session_opts.inter_op_num_threads = 1
-        session_opts.intra_op_num_threads = 4
+        session_opts.intra_op_num_threads = 1
 
         self.session_opts = session_opts
 
@@ -114,7 +112,7 @@ class OnnxModel:
         )
 
         meta = self.model.get_modelmeta().custom_metadata_map
-        self.normalize_features = int(meta["normalize_features"])
+        self.normalize_samples = int(meta["normalize_samples"])
         self.sample_rate = int(meta["sample_rate"])
         self.output_dim = int(meta["output_dim"])
 
@@ -151,7 +149,7 @@ def main():
     wave1 = read_wavefile(file1, model.sample_rate)
     wave2 = read_wavefile(file2, model.sample_rate)
 
-    if not model.normalize_features:
+    if not model.normalize_samples:
         wave1 = wave1 * 32768
         wave2 = wave2 * 32768
 
@@ -161,8 +159,6 @@ def main():
     output1 = model(features1)
     output2 = model(features2)
 
-    print(output1.shape)
-    print(output2.shape)
     similarity = np.dot(output1, output2) / (norm(output1) * norm(output2))
     print(f"similarity in the range [0-1]: {similarity}")
 
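The hunks above rename the metadata key from `normalize_features` to `normalize_samples`, so ONNX files exported before this change still carry the old key. Code that has to consume both generations of models can fall back to the old name; a minimal sketch (the default of `1` for a missing key is an assumption):

```python
# Minimal sketch: read the flag under the new key, falling back to the old one.
import onnxruntime as ort


def read_normalize_flag(model_path: str) -> int:
    sess = ort.InferenceSession(model_path)
    meta = sess.get_modelmeta().custom_metadata_map
    # Models exported before the rename use "normalize_features".
    return int(meta.get("normalize_samples", meta.get("normalize_features", 1)))
```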
@@ -27,7 +27,7 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl
     FeatureExtractorConfig feat_config;
     auto meta_data = model_.GetMetaData();
     feat_config.sampling_rate = meta_data.sample_rate;
-    feat_config.normalize_samples = meta_data.normalize_features;
+    feat_config.normalize_samples = meta_data.normalize_samples;
 
     return std::make_unique<OnlineStream>(feat_config);
   }
@@ -12,7 +12,7 @@ namespace sherpa_onnx {
 struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData {
   int32_t output_dim = 0;
   int32_t sample_rate = 0;
-  int32_t normalize_features = 0;
+  int32_t normalize_samples = 0;
   std::string language;
 };
 
@@ -61,8 +61,8 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
     Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
     SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
     SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
-    SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_features,
-                               "normalize_features");
+    SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_samples,
+                               "normalize_samples");
     SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
 
     std::string framework;