Committed by
GitHub
convert wespeaker models to sherpa-onnx (#475)
正在显示
5 个修改的文件
包含
506 行增加
和
0 行删除
| 1 | +name: export-wespeaker-to-onnx | ||
| 2 | + | ||
| 3 | +on: | ||
| 4 | + workflow_dispatch: | ||
| 5 | + | ||
| 6 | +concurrency: | ||
| 7 | + group: export-wespeaker-to-onnx-${{ github.ref }} | ||
| 8 | + cancel-in-progress: true | ||
| 9 | + | ||
| 10 | +jobs: | ||
| 11 | + export-wespeaker-to-onnx: | ||
| 12 | + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' | ||
| 13 | + name: export wespeaker | ||
| 14 | + runs-on: ${{ matrix.os }} | ||
| 15 | + strategy: | ||
| 16 | + fail-fast: false | ||
| 17 | + matrix: | ||
| 18 | + os: [ubuntu-latest] | ||
| 19 | + python-version: ["3.8"] | ||
| 20 | + | ||
| 21 | + steps: | ||
| 22 | + - uses: actions/checkout@v4 | ||
| 23 | + | ||
| 24 | + - name: Setup Python ${{ matrix.python-version }} | ||
| 25 | + uses: actions/setup-python@v2 | ||
| 26 | + with: | ||
| 27 | + python-version: ${{ matrix.python-version }} | ||
| 28 | + | ||
| 29 | + - name: Install Python dependencies | ||
| 30 | + shell: bash | ||
| 31 | + run: | | ||
| 32 | + pip install kaldi-native-fbank numpy onnx onnxruntime | ||
| 33 | + | ||
| 34 | + - name: Run | ||
| 35 | + shell: bash | ||
| 36 | + run: | | ||
| 37 | + cd scripts/wespeaker | ||
| 38 | + ./run.sh | ||
| 39 | + | ||
| 40 | + mv -v *.onnx ../.. | ||
| 41 | + | ||
| 42 | + - name: Release | ||
| 43 | + uses: svenstaro/upload-release-action@v2 | ||
| 44 | + with: | ||
| 45 | + file_glob: true | ||
| 46 | + file: ./*.onnx | ||
| 47 | + overwrite: true | ||
| 48 | + repo_name: k2-fsa/sherpa-onnx | ||
| 49 | + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} | ||
| 50 | + tag: speaker-recongition-models |
scripts/wespeaker/README.md
0 → 100644
| 1 | +# Introduction | ||
| 2 | + | ||
| 3 | +This folder contains script for adding meta data to onnx models from | ||
| 4 | +https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md | ||
| 5 | + | ||
| 6 | +You can use the models with metadata in sherpa-onnx. | ||
| 7 | + | ||
| 8 | + | ||
| 9 | +**Caution**: You have to add model meta data to `*.onnx` since we plan | ||
| 10 | +to support models from different frameworks. |
scripts/wespeaker/add_meta_data.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +""" | ||
| 5 | +This script adds meta data to a model so that it can be used in sherpa-onnx. | ||
| 6 | + | ||
| 7 | +Usage: | ||
| 8 | +./add_meta_data.py --model ./voxceleb_resnet34.onnx --language English | ||
| 9 | +""" | ||
| 10 | + | ||
| 11 | +import argparse | ||
| 12 | +from pathlib import Path | ||
| 13 | +from typing import Dict | ||
| 14 | + | ||
| 15 | +import onnx | ||
| 16 | +import onnxruntime | ||
| 17 | + | ||
| 18 | + | ||
| 19 | +def get_args(): | ||
| 20 | + parser = argparse.ArgumentParser() | ||
| 21 | + parser.add_argument( | ||
| 22 | + "--model", | ||
| 23 | + type=str, | ||
| 24 | + required=True, | ||
| 25 | + help="Path to the input onnx model. Example value: model.onnx", | ||
| 26 | + ) | ||
| 27 | + | ||
| 28 | + parser.add_argument( | ||
| 29 | + "--language", | ||
| 30 | + type=str, | ||
| 31 | + required=True, | ||
| 32 | + help="""Supported language of the input model. | ||
| 33 | + Example value: Chinese, English. | ||
| 34 | + """, | ||
| 35 | + ) | ||
| 36 | + | ||
| 37 | + parser.add_argument( | ||
| 38 | + "--url", | ||
| 39 | + type=str, | ||
| 40 | + default="https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md", | ||
| 41 | + help="Where the model is downloaded", | ||
| 42 | + ) | ||
| 43 | + | ||
| 44 | + parser.add_argument( | ||
| 45 | + "--comment", | ||
| 46 | + type=str, | ||
| 47 | + default="no comment", | ||
| 48 | + help="Comment about the model", | ||
| 49 | + ) | ||
| 50 | + | ||
| 51 | + parser.add_argument( | ||
| 52 | + "--sample-rate", | ||
| 53 | + type=int, | ||
| 54 | + default=16000, | ||
| 55 | + help="Sample rate expected by the model", | ||
| 56 | + ) | ||
| 57 | + | ||
| 58 | + return parser.parse_args() | ||
| 59 | + | ||
| 60 | + | ||
| 61 | +def add_meta_data(filename: str, meta_data: Dict[str, str]): | ||
| 62 | + """Add meta data to an ONNX model. It is changed in-place. | ||
| 63 | + | ||
| 64 | + Args: | ||
| 65 | + filename: | ||
| 66 | + Filename of the ONNX model to be changed. | ||
| 67 | + meta_data: | ||
| 68 | + Key-value pairs. | ||
| 69 | + """ | ||
| 70 | + model = onnx.load(filename) | ||
| 71 | + for key, value in meta_data.items(): | ||
| 72 | + meta = model.metadata_props.add() | ||
| 73 | + meta.key = key | ||
| 74 | + meta.value = str(value) | ||
| 75 | + | ||
| 76 | + onnx.save(model, filename) | ||
| 77 | + | ||
| 78 | + | ||
| 79 | +def get_output_dim(filename) -> int: | ||
| 80 | + filename = str(filename) | ||
| 81 | + session_opts = onnxruntime.SessionOptions() | ||
| 82 | + session_opts.log_severity_level = 3 # error level | ||
| 83 | + sess = onnxruntime.InferenceSession(filename, session_opts) | ||
| 84 | + | ||
| 85 | + for i in sess.get_inputs(): | ||
| 86 | + print(i) | ||
| 87 | + | ||
| 88 | + print("----------") | ||
| 89 | + | ||
| 90 | + for o in sess.get_outputs(): | ||
| 91 | + print(o) | ||
| 92 | + | ||
| 93 | + print("----------") | ||
| 94 | + | ||
| 95 | + assert len(sess.get_inputs()) == 1 | ||
| 96 | + assert len(sess.get_outputs()) == 1 | ||
| 97 | + | ||
| 98 | + i = sess.get_inputs()[0] | ||
| 99 | + o = sess.get_outputs()[0] | ||
| 100 | + | ||
| 101 | + assert i.shape[:2] == ["B", "T"], i.shape | ||
| 102 | + assert o.shape[0] == "B" | ||
| 103 | + | ||
| 104 | + assert i.shape[2] == 80, i.shape | ||
| 105 | + | ||
| 106 | + return o.shape[1] | ||
| 107 | + | ||
| 108 | + | ||
| 109 | +def main(): | ||
| 110 | + args = get_args() | ||
| 111 | + model = Path(args.model) | ||
| 112 | + language = args.language | ||
| 113 | + url = args.url | ||
| 114 | + comment = args.comment | ||
| 115 | + sample_rate = args.sample_rate | ||
| 116 | + | ||
| 117 | + if not model.is_file(): | ||
| 118 | + raise ValueError(f"{model} does not exist") | ||
| 119 | + | ||
| 120 | + assert len(language) > 0, len(language) | ||
| 121 | + assert len(url) > 0, len(url) | ||
| 122 | + | ||
| 123 | + output_dim = get_output_dim(model) | ||
| 124 | + | ||
| 125 | + # all models from wespeaker expect input samples in the range | ||
| 126 | + # [-32768, 32767] | ||
| 127 | + normalize_features = 0 | ||
| 128 | + | ||
| 129 | + meta_data = { | ||
| 130 | + "framework": "wespeaker", | ||
| 131 | + "language": language, | ||
| 132 | + "url": url, | ||
| 133 | + "comment": comment, | ||
| 134 | + "sample_rate": sample_rate, | ||
| 135 | + "output_dim": output_dim, | ||
| 136 | + "normalize_features": normalize_features, | ||
| 137 | + } | ||
| 138 | + print(meta_data) | ||
| 139 | + add_meta_data(filename=str(model), meta_data=meta_data) | ||
| 140 | + | ||
| 141 | + | ||
| 142 | +if __name__ == "__main__": | ||
| 143 | + main() |
scripts/wespeaker/run.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +set -ex | ||
| 4 | + | ||
| 5 | +echo "Downloading models" | ||
| 6 | +export GIT_LFS_SKIP_SMUDGE=1 | ||
| 7 | +git clone https://huggingface.co/openspeech/wespeaker-models | ||
| 8 | +cd wespeaker-models | ||
| 9 | +git lfs pull --include "*.onnx" | ||
| 10 | +ls -lh | ||
| 11 | +cd .. | ||
| 12 | +mv wespeaker-models/*.onnx . | ||
| 13 | +ls -lh | ||
| 14 | + | ||
| 15 | +./add_meta_data.py \ | ||
| 16 | + --model ./voxceleb_resnet34.onnx \ | ||
| 17 | + --language English \ | ||
| 18 | + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.onnx | ||
| 19 | +./test.py --model ./voxceleb_resnet34.onnx \ | ||
| 20 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 21 | + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav \ | ||
| 22 | + | ||
| 23 | +./test.py --model ./voxceleb_resnet34.onnx \ | ||
| 24 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 25 | + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav | ||
| 26 | + | ||
| 27 | +mv voxceleb_resnet34.onnx en_voxceleb_resnet34.onnx | ||
| 28 | + | ||
| 29 | +./add_meta_data.py \ | ||
| 30 | + --model ./voxceleb_resnet34_LM.onnx \ | ||
| 31 | + --language English \ | ||
| 32 | + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx | ||
| 33 | +./test.py --model ./voxceleb_resnet34_LM.onnx \ | ||
| 34 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 35 | + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav \ | ||
| 36 | + | ||
| 37 | +./test.py --model ./voxceleb_resnet34_LM.onnx \ | ||
| 38 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 39 | + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav | ||
| 40 | + | ||
| 41 | +mv voxceleb_resnet34_LM.onnx en_voxceleb_resnet34_LM.onnx | ||
| 42 | + | ||
| 43 | +./add_meta_data.py \ | ||
| 44 | + --model ./voxceleb_resnet152_LM.onnx \ | ||
| 45 | + --language English \ | ||
| 46 | + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet152_LM.onnx | ||
| 47 | + | ||
| 48 | +./test.py --model ./voxceleb_resnet152_LM.onnx \ | ||
| 49 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 50 | + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav \ | ||
| 51 | + | ||
| 52 | +./test.py --model ./voxceleb_resnet152_LM.onnx \ | ||
| 53 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 54 | + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav | ||
| 55 | + | ||
| 56 | +mv voxceleb_resnet152_LM.onnx en_voxceleb_resnet152_LM.onnx | ||
| 57 | + | ||
| 58 | +./add_meta_data.py \ | ||
| 59 | + --model ./voxceleb_resnet221_LM.onnx \ | ||
| 60 | + --language English \ | ||
| 61 | + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet221_LM.onnx | ||
| 62 | + | ||
| 63 | +./test.py --model ./voxceleb_resnet221_LM.onnx \ | ||
| 64 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 65 | + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav \ | ||
| 66 | + | ||
| 67 | +./test.py --model ./voxceleb_resnet221_LM.onnx \ | ||
| 68 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 69 | + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav | ||
| 70 | + | ||
| 71 | +mv voxceleb_resnet221_LM.onnx en_voxceleb_resnet221_LM.onnx | ||
| 72 | + | ||
| 73 | +./add_meta_data.py \ | ||
| 74 | + --model ./voxceleb_resnet293_LM.onnx \ | ||
| 75 | + --language English \ | ||
| 76 | + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet293_LM.onnx | ||
| 77 | + | ||
| 78 | +./test.py --model ./voxceleb_resnet293_LM.onnx \ | ||
| 79 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 80 | + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav \ | ||
| 81 | + | ||
| 82 | +./test.py --model ./voxceleb_resnet293_LM.onnx \ | ||
| 83 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 84 | + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav | ||
| 85 | + | ||
| 86 | +mv voxceleb_resnet293_LM.onnx en_voxceleb_resnet293_LM.onnx | ||
| 87 | + | ||
| 88 | +./add_meta_data.py \ | ||
| 89 | + --model ./voxceleb_CAM++.onnx \ | ||
| 90 | + --language English \ | ||
| 91 | + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++.onnx | ||
| 92 | + | ||
| 93 | +./test.py --model ./voxceleb_CAM++.onnx \ | ||
| 94 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 95 | + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav \ | ||
| 96 | + | ||
| 97 | +./test.py --model ./voxceleb_CAM++.onnx \ | ||
| 98 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 99 | + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav | ||
| 100 | + | ||
| 101 | +mv voxceleb_CAM++.onnx en_voxceleb_CAM++.onnx | ||
| 102 | + | ||
| 103 | +./add_meta_data.py \ | ||
| 104 | + --model ./voxceleb_CAM++_LM.onnx \ | ||
| 105 | + --language English \ | ||
| 106 | + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++_LM.onnx | ||
| 107 | + | ||
| 108 | +./test.py --model ./voxceleb_CAM++_LM.onnx \ | ||
| 109 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 110 | + --file2 ./wespeaker-models/test_wavs/00024_spk1.wav \ | ||
| 111 | + | ||
| 112 | +./test.py --model ./voxceleb_CAM++_LM.onnx \ | ||
| 113 | + --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \ | ||
| 114 | + --file2 ./wespeaker-models/test_wavs/00010_spk2.wav | ||
| 115 | + | ||
| 116 | +mv voxceleb_CAM++_LM.onnx en_voxceleb_CAM++_LM.onnx | ||
| 117 | + | ||
| 118 | +./add_meta_data.py \ | ||
| 119 | + --model ./cnceleb_resnet34.onnx \ | ||
| 120 | + --language Chinese \ | ||
| 121 | + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34.onnx | ||
| 122 | + | ||
| 123 | +mv cnceleb_resnet34.onnx zh_cnceleb_resnet34.onnx | ||
| 124 | + | ||
| 125 | +./add_meta_data.py \ | ||
| 126 | + --model ./cnceleb_resnet34_LM.onnx \ | ||
| 127 | + --language Chinese \ | ||
| 128 | + --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34_LM.onnx | ||
| 129 | + | ||
| 130 | +mv cnceleb_resnet34_LM.onnx zh_cnceleb_resnet34_LM.onnx | ||
| 131 | + | ||
| 132 | +ls -lh |
scripts/wespeaker/test.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +""" | ||
| 5 | +This script computes speaker similarity score in the range [0-1] | ||
| 6 | +of two wave files using a speaker recognition model. | ||
| 7 | +""" | ||
| 8 | +import argparse | ||
| 9 | +import wave | ||
| 10 | +from pathlib import Path | ||
| 11 | + | ||
| 12 | +import kaldi_native_fbank as knf | ||
| 13 | +import numpy as np | ||
| 14 | +import onnxruntime as ort | ||
| 15 | +from numpy.linalg import norm | ||
| 16 | + | ||
| 17 | + | ||
| 18 | +def get_args(): | ||
| 19 | + parser = argparse.ArgumentParser() | ||
| 20 | + parser.add_argument( | ||
| 21 | + "--model", | ||
| 22 | + type=str, | ||
| 23 | + required=True, | ||
| 24 | + help="Path to the input onnx model. Example value: model.onnx", | ||
| 25 | + ) | ||
| 26 | + | ||
| 27 | + parser.add_argument( | ||
| 28 | + "--file1", | ||
| 29 | + type=str, | ||
| 30 | + required=True, | ||
| 31 | + help="Input wave 1", | ||
| 32 | + ) | ||
| 33 | + | ||
| 34 | + parser.add_argument( | ||
| 35 | + "--file2", | ||
| 36 | + type=str, | ||
| 37 | + required=True, | ||
| 38 | + help="Input wave 2", | ||
| 39 | + ) | ||
| 40 | + | ||
| 41 | + return parser.parse_args() | ||
| 42 | + | ||
| 43 | + | ||
| 44 | +def read_wavefile(filename, expected_sample_rate: int = 16000) -> np.ndarray: | ||
| 45 | + """ | ||
| 46 | + Args: | ||
| 47 | + filename: | ||
| 48 | + Path to a wave file, which must be of 16-bit and 16kHz. | ||
| 49 | + expected_sample_rate: | ||
| 50 | + Expected sample rate of the wave file. | ||
| 51 | + Returns: | ||
| 52 | + Return a 1-D float32 array containing audio samples. Each sample is in | ||
| 53 | + the range [-1, 1]. | ||
| 54 | + """ | ||
| 55 | + filename = str(filename) | ||
| 56 | + with wave.open(filename) as f: | ||
| 57 | + # Note: If wave_file_sample_rate is different from | ||
| 58 | + # recognizer.sample_rate, we will do resampling inside sherpa-ncnn | ||
| 59 | + wave_file_sample_rate = f.getframerate() | ||
| 60 | + assert wave_file_sample_rate == expected_sample_rate, ( | ||
| 61 | + wave_file_sample_rate, | ||
| 62 | + expected_sample_rate, | ||
| 63 | + ) | ||
| 64 | + | ||
| 65 | + num_channels = f.getnchannels() | ||
| 66 | + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | ||
| 67 | + num_samples = f.getnframes() | ||
| 68 | + samples = f.readframes(num_samples) | ||
| 69 | + samples_int16 = np.frombuffer(samples, dtype=np.int16) | ||
| 70 | + samples_int16 = samples_int16.reshape(-1, num_channels)[:, 0] | ||
| 71 | + samples_float32 = samples_int16.astype(np.float32) | ||
| 72 | + | ||
| 73 | + samples_float32 = samples_float32 / 32768 | ||
| 74 | + | ||
| 75 | + return samples_float32 | ||
| 76 | + | ||
| 77 | + | ||
| 78 | +def compute_features(samples: np.ndarray, sample_rate: int) -> np.ndarray: | ||
| 79 | + opts = knf.FbankOptions() | ||
| 80 | + opts.frame_opts.dither = 0 | ||
| 81 | + opts.frame_opts.samp_freq = sample_rate | ||
| 82 | + opts.frame_opts.snip_edges = False | ||
| 83 | + | ||
| 84 | + opts.mel_opts.num_bins = 80 | ||
| 85 | + opts.mel_opts.debug_mel = False | ||
| 86 | + | ||
| 87 | + fbank = knf.OnlineFbank(opts) | ||
| 88 | + fbank.accept_waveform(sample_rate, samples) | ||
| 89 | + fbank.input_finished() | ||
| 90 | + | ||
| 91 | + features = [] | ||
| 92 | + for i in range(fbank.num_frames_ready): | ||
| 93 | + f = fbank.get_frame(i) | ||
| 94 | + features.append(f) | ||
| 95 | + features = np.stack(features, axis=0) | ||
| 96 | + | ||
| 97 | + return features | ||
| 98 | + | ||
| 99 | + | ||
| 100 | +class OnnxModel: | ||
| 101 | + def __init__( | ||
| 102 | + self, | ||
| 103 | + filename: str, | ||
| 104 | + ): | ||
| 105 | + session_opts = ort.SessionOptions() | ||
| 106 | + session_opts.inter_op_num_threads = 1 | ||
| 107 | + session_opts.intra_op_num_threads = 4 | ||
| 108 | + | ||
| 109 | + self.session_opts = session_opts | ||
| 110 | + | ||
| 111 | + self.model = ort.InferenceSession( | ||
| 112 | + filename, | ||
| 113 | + sess_options=self.session_opts, | ||
| 114 | + ) | ||
| 115 | + | ||
| 116 | + meta = self.model.get_modelmeta().custom_metadata_map | ||
| 117 | + self.normalize_features = int(meta["normalize_features"]) | ||
| 118 | + self.sample_rate = int(meta["sample_rate"]) | ||
| 119 | + self.output_dim = int(meta["output_dim"]) | ||
| 120 | + | ||
| 121 | + def __call__(self, x: np.ndarray) -> np.ndarray: | ||
| 122 | + """ | ||
| 123 | + Args: | ||
| 124 | + x: | ||
| 125 | + A 2-D float32 tensor of shape (T, C). | ||
| 126 | + y: | ||
| 127 | + A 1-D float32 tensor containing model output. | ||
| 128 | + """ | ||
| 129 | + x = np.expand_dims(x, axis=0) | ||
| 130 | + | ||
| 131 | + return self.model.run( | ||
| 132 | + [ | ||
| 133 | + self.model.get_outputs()[0].name, | ||
| 134 | + ], | ||
| 135 | + { | ||
| 136 | + self.model.get_inputs()[0].name: x, | ||
| 137 | + }, | ||
| 138 | + )[0][0] | ||
| 139 | + | ||
| 140 | + | ||
| 141 | +def main(): | ||
| 142 | + args = get_args() | ||
| 143 | + filename = Path(args.model) | ||
| 144 | + file1 = Path(args.file1) | ||
| 145 | + file2 = Path(args.file2) | ||
| 146 | + assert filename.is_file(), filename | ||
| 147 | + assert file1.is_file(), file1 | ||
| 148 | + assert file2.is_file(), file2 | ||
| 149 | + | ||
| 150 | + model = OnnxModel(filename) | ||
| 151 | + wave1 = read_wavefile(file1, model.sample_rate) | ||
| 152 | + wave2 = read_wavefile(file2, model.sample_rate) | ||
| 153 | + | ||
| 154 | + if not model.normalize_features: | ||
| 155 | + wave1 = wave1 * 32768 | ||
| 156 | + wave2 = wave2 * 32768 | ||
| 157 | + | ||
| 158 | + features1 = compute_features(wave1, model.sample_rate) | ||
| 159 | + features2 = compute_features(wave2, model.sample_rate) | ||
| 160 | + | ||
| 161 | + output1 = model(features1) | ||
| 162 | + output2 = model(features2) | ||
| 163 | + | ||
| 164 | + print(output1.shape) | ||
| 165 | + print(output2.shape) | ||
| 166 | + similarity = np.dot(output1, output2) / (norm(output1) * norm(output2)) | ||
| 167 | + print(f"similarity in the range [0-1]: {similarity}") | ||
| 168 | + | ||
| 169 | + | ||
| 170 | +if __name__ == "__main__": | ||
| 171 | + main() |
-
请 注册 或 登录 后发表评论