Committed by GitHub
Export speaker verification models from NeMo to ONNX (#526)
Showing 10 changed files with 448 additions and 28 deletions.
 #!/usr/bin/env bash

-set -e
+set -ex

 log() {
   # This function is from espnet
@@ -21,18 +21,19 @@ model_dir=$d/wespeaker
 mkdir -p $model_dir
 pushd $model_dir
 models=(
-en_voxceleb_CAM++.onnx
-en_voxceleb_CAM++_LM.onnx
-en_voxceleb_resnet152_LM.onnx
-en_voxceleb_resnet221_LM.onnx
-en_voxceleb_resnet293_LM.onnx
-en_voxceleb_resnet34.onnx
-en_voxceleb_resnet34_LM.onnx
-zh_cnceleb_resnet34.onnx
-zh_cnceleb_resnet34_LM.onnx
+wespeaker_en_voxceleb_CAM++.onnx
+wespeaker_en_voxceleb_CAM++_LM.onnx
+wespeaker_en_voxceleb_resnet152_LM.onnx
+wespeaker_en_voxceleb_resnet221_LM.onnx
+wespeaker_en_voxceleb_resnet293_LM.onnx
+wespeaker_en_voxceleb_resnet34.onnx
+wespeaker_en_voxceleb_resnet34_LM.onnx
+wespeaker_zh_cnceleb_resnet34.onnx
+wespeaker_zh_cnceleb_resnet34_LM.onnx
 )
 for m in ${models[@]}; do
   wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m
+  wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/wespeaker_en_voxceleb_CAM++_LM.onnx
 done
 ls -lh
 popd
@@ -42,13 +43,13 @@ model_dir=$d/3dspeaker
 mkdir -p $model_dir
 pushd $model_dir
 models=(
-speech_campplus_sv_en_voxceleb_16k.onnx
-speech_campplus_sv_zh-cn_16k-common.onnx
-speech_eres2net_base_200k_sv_zh-cn_16k-common.onnx
-speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
-speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
-speech_eres2net_sv_en_voxceleb_16k.onnx
-speech_eres2net_sv_zh-cn_16k-common.onnx
+3dspeaker_speech_campplus_sv_en_voxceleb_16k.onnx
+3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx
+3dspeaker_speech_eres2net_base_200k_sv_zh-cn_16k-common.onnx
+3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
+3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
+3dspeaker_speech_eres2net_sv_en_voxceleb_16k.onnx
+3dspeaker_speech_eres2net_sv_zh-cn_16k-common.onnx
 )
 for m in ${models[@]}; do
   wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m
+name: export-nemo-speaker-verification-to-onnx
+
+on:
+  workflow_dispatch:
+
+concurrency:
+  group: export-nemo-speaker-verification-to-onnx-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  export-nemo-speaker-verification-to-onnx:
+    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
+    name: export nemo speaker verification models to ONNX
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.10"]
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Run
+        shell: bash
+        run: |
+          cd scripts/nemo/speaker-verification
+          ./run.sh
+
+          mv -v *.onnx ../../..
+
+      - name: Release
+        uses: svenstaro/upload-release-action@v2
+        with:
+          file_glob: true
+          file: ./*.onnx
+          overwrite: true
+          repo_name: k2-fsa/sherpa-onnx
+          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+          tag: speaker-recongition-models
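Because the workflow above is triggered only by `workflow_dispatch`, it has to be started manually. Below is a minimal sketch of dispatching it through the GitHub REST API; the workflow filename and the ref are assumptions, since this diff does not show the file's path:

    # Hypothetical dispatch call; the workflow filename and the ref are
    # assumptions, not values shown in this PR.
    import requests

    resp = requests.post(
        "https://api.github.com/repos/k2-fsa/sherpa-onnx/actions/workflows/"
        "export-nemo-speaker-verification-to-onnx.yaml/dispatches",
        headers={"Authorization": "token YOUR_GITHUB_TOKEN"},
        json={"ref": "master"},
    )
    resp.raise_for_status()  # the API returns 204 No Content on success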
@@ -29,7 +29,7 @@ Please visit
 https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
 to download a model. An example is given below:

-    wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/zh_cnceleb_resnet34.onnx
+    wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/wespeaker_zh_cnceleb_resnet34.onnx

 Note that `zh` means Chinese, while `en` means English.

@@ -39,7 +39,7 @@ Assume the filename of the text file is speaker.txt.

 python3 ./python-api-examples/speaker-identification.py \
   --speaker-file ./speaker.txt \
-  --model ./zh_cnceleb_resnet34.onnx
+  --model ./wespeaker_zh_cnceleb_resnet34.onnx
 """
 import argparse
 import queue
@@ -60,4 +60,6 @@ for model in ${models[@]}; do
     --model ${model}.onnx \
     --file1 ./speaker1_a_en_16k.wav \
     --file2 ./speaker2_a_en_16k.wav
+
+  mv ${model}.onnx 3dspeaker_${model}.onnx
 done
scripts/nemo/README.md
0 → 100644
scripts/nemo/speaker-verification/README.md
0 → 100644
+# Introduction
+
+This directory contains scripts for exporting speaker verification models
+from [NeMo](https://github.com/NVIDIA/NeMo/) to ONNX
+so that you can use them in `sherpa-onnx`.
+
+Specifically, the following 4 models are exported to `sherpa-onnx`
+from
+[this page](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/speaker_recognition/results.html#speaker-recognition-models):
+
+  - [titanet_large](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/titanet_large)
+  - [titanet_small](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/titanet_small)
+  - [speakerverification_speakernet](https://ngc.nvidia.com/catalog/models/nvidia:nemo:speakerverification_speakernet)
+  - [ecapa_tdnn](https://ngc.nvidia.com/catalog/models/nvidia:nemo:ecapa_tdnn)
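Stripped of argument parsing and metadata handling, the export performed by the `export-onnx.py` script that follows reduces to a few NeMo calls. A minimal sketch, assuming `nemo_toolkit[asr]` and `torch` are installed:

    # Minimal export sketch; the model name and output filename follow the
    # conventions used by export-onnx.py below.
    import nemo.collections.asr as nemo_asr

    model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
        model_name="titanet_large"
    )
    model.eval()
    model.export("nemo_en_titanet_large.onnx")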
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+
+import argparse
+from typing import Dict
+
+import nemo.collections.asr as nemo_asr
+import onnx
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        choices=[
+            "speakerverification_speakernet",
+            "titanet_large",
+            "titanet_small",
+            "ecapa_tdnn",
+        ],
+    )
+    return parser.parse_args()
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = str(value)
+
+    onnx.save(model, filename)
+
+
+@torch.no_grad()
+def main():
+    args = get_args()
+    speaker_model_config = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
+        model_name=args.model, return_config=True
+    )
+    preprocessor_config = speaker_model_config["preprocessor"]
+
+    print(args.model)
+    print(speaker_model_config)
+    print(preprocessor_config)
+
+    assert preprocessor_config["n_fft"] == 512, preprocessor_config
+
+    assert (
+        preprocessor_config["_target_"]
+        == "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor"
+    ), preprocessor_config
+
+    assert preprocessor_config["frame_splicing"] == 1, preprocessor_config
+
+    speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
+        model_name=args.model
+    )
+    speaker_model.eval()
+    filename = f"nemo_en_{args.model}.onnx"
+    speaker_model.export(filename)
+
+    print(f"Adding metadata to {filename}")
+
+    comment = "This model is from NeMo."
+    url = {
+        "titanet_large": "https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/titanet_large",
+        "titanet_small": "https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/titanet_small",
+        "speakerverification_speakernet": "https://ngc.nvidia.com/catalog/models/nvidia:nemo:speakerverification_speakernet",
+        "ecapa_tdnn": "https://ngc.nvidia.com/catalog/models/nvidia:nemo:ecapa_tdnn",
+    }[args.model]
+
+    language = "English"
+
+    meta_data = {
+        "framework": "nemo",
+        "language": language,
+        "url": url,
+        "comment": comment,
+        "sample_rate": preprocessor_config["sample_rate"],
+        "output_dim": speaker_model_config["decoder"]["emb_sizes"],
+        "feature_normalize_type": preprocessor_config["normalize"],
+        "window_size_ms": int(float(preprocessor_config["window_size"]) * 1000),
+        "window_stride_ms": int(float(preprocessor_config["window_stride"]) * 1000),
+        "window_type": preprocessor_config["window"],  # e.g., hann
+        "feat_dim": preprocessor_config["features"],
+    }
+    print(meta_data)
+    add_meta_data(filename=filename, meta_data=meta_data)
+
+
+if __name__ == "__main__":
+    main()
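To double-check what `add_meta_data()` wrote, the exported file can be reloaded and its metadata printed. A minimal sketch, assuming the `titanet_large` export produced by the script above:

    import onnx

    model = onnx.load("nemo_en_titanet_large.onnx")
    for prop in model.metadata_props:
        print(prop.key, "=", prop.value)  # e.g., framework = nemo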
scripts/nemo/speaker-verification/run.sh
0 → 100755
+#!/usr/bin/env bash
+# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+
+set -ex
+
+function install_nemo() {
+  curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+  python3 get-pip.py
+
+  pip install torch==2.1.0+cpu torchaudio==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+
+  # Note: 'matplotlib>=3.3.2' must be quoted; otherwise the shell would
+  # treat >= as a redirection.
+  pip install wget text-unidecode 'matplotlib>=3.3.2' onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile
+
+  sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip
+
+  BRANCH='main'
+  python3 -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]
+}
+
+install_nemo
+
+model_list=(
+speakerverification_speakernet
+titanet_large
+titanet_small
+# ecapa_tdnn # causes errors, see https://github.com/NVIDIA/NeMo/issues/8168
+)
+
+for model in ${model_list[@]}; do
+  python3 ./export-onnx.py --model $model
+done
+
+ls -lh
+
+function download_test_data() {
+  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_en_16k.wav
+  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_en_16k.wav
+  wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_en_16k.wav
+}
+
+download_test_data
+
+for model in ${model_list[@]}; do
+  python3 ./test-onnx.py \
+    --model nemo_en_${model}.onnx \
+    --file1 ./speaker1_a_en_16k.wav \
+    --file2 ./speaker1_b_en_16k.wav
+
+  python3 ./test-onnx.py \
+    --model nemo_en_${model}.onnx \
+    --file1 ./speaker1_a_en_16k.wav \
+    --file2 ./speaker2_a_en_16k.wav
+done
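The two `test-onnx.py` invocations per model form a sanity check: the same-speaker pair (`speaker1_a` vs. `speaker1_b`) should score clearly higher than the different-speaker pair (`speaker1_a` vs. `speaker2_a`). A sketch of turning the printed score into a decision; the 0.5 threshold is hypothetical, not a value set anywhere in this PR:

    # Hypothetical decision rule on top of the cosine similarity printed
    # by test-onnx.py; the threshold is an assumption.
    def is_same_speaker(score: float, threshold: float = 0.5) -> bool:
        return score >= threshold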
+#!/usr/bin/env python3
+# Copyright 2023-2024 Xiaomi Corp. (authors: Fangjun Kuang)
+
+"""
+This script computes the cosine similarity score (in the range [-1, 1])
+of two wave files using a speaker embedding model.
+"""
+import argparse
+import wave
+from pathlib import Path
+from typing import Tuple
+
+import kaldi_native_fbank as knf
+import numpy as np
+import onnxruntime as ort
+from numpy.linalg import norm
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        help="Path to the input onnx model. Example value: model.onnx",
+    )
+
+    parser.add_argument(
+        "--file1",
+        type=str,
+        required=True,
+        help="Input wave 1",
+    )
+
+    parser.add_argument(
+        "--file2",
+        type=str,
+        required=True,
+        help="Input wave 2",
+    )
+
+    return parser.parse_args()
+
+
+def read_wavefile(filename, expected_sample_rate: int = 16000) -> np.ndarray:
+    """
+    Args:
+      filename:
+        Path to a 16-bit wave file whose sample rate equals
+        expected_sample_rate.
+      expected_sample_rate:
+        Expected sample rate of the wave file.
+    Returns:
+      Return a 1-D float32 array containing audio samples. Each sample is in
+      the range [-1, 1].
+    """
+    filename = str(filename)
+    with wave.open(filename) as f:
+        wave_file_sample_rate = f.getframerate()
+        assert wave_file_sample_rate == expected_sample_rate, (
+            wave_file_sample_rate,
+            expected_sample_rate,
+        )
+
+        num_channels = f.getnchannels()
+        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
+        num_samples = f.getnframes()
+        samples = f.readframes(num_samples)
+        samples_int16 = np.frombuffer(samples, dtype=np.int16)
+        # keep only the first channel if the file is multi-channel
+        samples_int16 = samples_int16.reshape(-1, num_channels)[:, 0]
+        samples_float32 = samples_int16.astype(np.float32)
+
+        # normalize 16-bit samples to the range [-1, 1]
+        samples_float32 = samples_float32 / 32768
+
+    return samples_float32
+
+
+def compute_features(samples: np.ndarray, model: "OnnxModel") -> Tuple[np.ndarray, int]:
+    fbank_opts = knf.FbankOptions()
+    fbank_opts.frame_opts.samp_freq = model.sample_rate
+    fbank_opts.frame_opts.frame_length_ms = model.window_size_ms
+    fbank_opts.frame_opts.frame_shift_ms = model.window_stride_ms
+    fbank_opts.frame_opts.dither = 0
+    fbank_opts.frame_opts.remove_dc_offset = False
+    fbank_opts.frame_opts.window_type = model.window_type
+
+    fbank_opts.mel_opts.num_bins = model.feat_dim
+    fbank_opts.mel_opts.low_freq = 0
+    fbank_opts.mel_opts.is_librosa = True
+
+    fbank = knf.OnlineFbank(fbank_opts)
+    fbank.accept_waveform(model.sample_rate, samples)
+    fbank.input_finished()
+
+    features = []
+    for i in range(fbank.num_frames_ready):
+        f = fbank.get_frame(i)
+        features.append(f)
+    features = np.stack(features, axis=0)
+    # at this point, the shape of features is (T, C)
+
+    if model.feature_normalize_type != "":
+        assert model.feature_normalize_type == "per_feature"
+        mean = np.mean(features, axis=0, keepdims=True)
+        std = np.std(features, axis=0, keepdims=True)
+        features = (features - mean) / std
+
+    feature_len = features.shape[0]
+
+    # Pad the number of frames to a multiple of 16. The unpadded length is
+    # returned separately so that the model can ignore the padded frames.
+    pad = (16 - feature_len % 16) % 16
+    if pad > 0:
+        padding = np.zeros((pad, features.shape[1]), dtype=np.float32)
+        features = np.concatenate([features, padding])
+
+    features = np.expand_dims(features, axis=0)  # (T, C) -> (1, T, C)
+
+    return features, feature_len
+
+
+class OnnxModel:
+    def __init__(
+        self,
+        filename: str,
+    ):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.session_opts = session_opts
+
+        self.model = ort.InferenceSession(
+            filename,
+            sess_options=self.session_opts,
+        )
+
+        meta = self.model.get_modelmeta().custom_metadata_map
+        self.framework = meta["framework"]
+        self.sample_rate = int(meta["sample_rate"])
+        self.output_dim = int(meta["output_dim"])
+        self.feature_normalize_type = meta["feature_normalize_type"]
+        self.window_size_ms = int(meta["window_size_ms"])
+        self.window_stride_ms = int(meta["window_stride_ms"])
+        self.window_type = meta["window_type"]
+        self.feat_dim = int(meta["feat_dim"])
+        print(meta)
+
+        assert self.framework == "nemo", self.framework
+
+    def __call__(self, x: np.ndarray, x_lens: int) -> np.ndarray:
+        """
+        Args:
+          x:
+            A 3-D float32 tensor of shape (1, T, C).
+          x_lens:
+            Number of valid (unpadded) frames in x.
+        Returns:
+          Return a 1-D float32 array containing the speaker embedding.
+        """
+        x = x.transpose(0, 2, 1)  # (B, T, C) -> (B, C, T)
+        x_lens = np.asarray([x_lens], dtype=np.int64)
+
+        # The second output of the exported model holds the speaker
+        # embedding; the batch size is 1, so [0][0] extracts the only row.
+        return self.model.run(
+            [
+                self.model.get_outputs()[1].name,
+            ],
+            {
+                self.model.get_inputs()[0].name: x,
+                self.model.get_inputs()[1].name: x_lens,
+            },
+        )[0][0]
+
+
+def main():
+    args = get_args()
+    print(args)
+    filename = Path(args.model)
+    file1 = Path(args.file1)
+    file2 = Path(args.file2)
+    assert filename.is_file(), filename
+    assert file1.is_file(), file1
+    assert file2.is_file(), file2
+
+    model = OnnxModel(filename)
+    wave1 = read_wavefile(file1, model.sample_rate)
+    wave2 = read_wavefile(file2, model.sample_rate)
+
+    features1, features1_len = compute_features(wave1, model)
+    features2, features2_len = compute_features(wave2, model)
+
+    output1 = model(features1, features1_len)
+    output2 = model(features2, features2_len)
+
+    similarity = np.dot(output1, output2) / (norm(output1) * norm(output2))
+    print(f"cosine similarity (in the range [-1, 1]): {similarity}")
+
+
+if __name__ == "__main__":
+    main()
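The final score printed by the script is plain cosine similarity between the two embeddings. A small worked example with toy vectors:

    import numpy as np
    from numpy.linalg import norm

    a = np.array([1.0, 0.0, 1.0], dtype=np.float32)
    b = np.array([1.0, 1.0, 0.0], dtype=np.float32)
    # dot product = 1, |a| = |b| = sqrt(2), so the score is 1 / 2 = 0.5
    print(np.dot(a, b) / (norm(a) * norm(b)))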
@@ -24,7 +24,7 @@ ls -lh
   --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
   --file2 ./wespeaker-models/test_wavs/00010_spk2.wav

-mv voxceleb_resnet34.onnx en_voxceleb_resnet34.onnx
+mv voxceleb_resnet34.onnx wespeaker_en_voxceleb_resnet34.onnx

 ./add_meta_data.py \
   --model ./voxceleb_resnet34_LM.onnx \
@@ -38,7 +38,7 @@ mv voxceleb_resnet34.onnx en_voxceleb_resnet34.onnx
   --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
   --file2 ./wespeaker-models/test_wavs/00010_spk2.wav

-mv voxceleb_resnet34_LM.onnx en_voxceleb_resnet34_LM.onnx
+mv voxceleb_resnet34_LM.onnx wespeaker_en_voxceleb_resnet34_LM.onnx

 ./add_meta_data.py \
   --model ./voxceleb_resnet152_LM.onnx \
@@ -53,7 +53,7 @@ mv voxceleb_resnet34_LM.onnx en_voxceleb_resnet34_LM.onnx
   --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
   --file2 ./wespeaker-models/test_wavs/00010_spk2.wav

-mv voxceleb_resnet152_LM.onnx en_voxceleb_resnet152_LM.onnx
+mv voxceleb_resnet152_LM.onnx wespeaker_en_voxceleb_resnet152_LM.onnx

 ./add_meta_data.py \
   --model ./voxceleb_resnet221_LM.onnx \
@@ -68,7 +68,7 @@ mv voxceleb_resnet152_LM.onnx en_voxceleb_resnet152_LM.onnx
   --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
   --file2 ./wespeaker-models/test_wavs/00010_spk2.wav

-mv voxceleb_resnet221_LM.onnx en_voxceleb_resnet221_LM.onnx
+mv voxceleb_resnet221_LM.onnx wespeaker_en_voxceleb_resnet221_LM.onnx

 ./add_meta_data.py \
   --model ./voxceleb_resnet293_LM.onnx \
@@ -83,7 +83,7 @@ mv voxceleb_resnet221_LM.onnx en_voxceleb_resnet221_LM.onnx
   --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
   --file2 ./wespeaker-models/test_wavs/00010_spk2.wav

-mv voxceleb_resnet293_LM.onnx en_voxceleb_resnet293_LM.onnx
+mv voxceleb_resnet293_LM.onnx wespeaker_en_voxceleb_resnet293_LM.onnx

 ./add_meta_data.py \
   --model ./voxceleb_CAM++.onnx \
@@ -98,7 +98,7 @@ mv voxceleb_resnet293_LM.onnx en_voxceleb_resnet293_LM.onnx
   --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
   --file2 ./wespeaker-models/test_wavs/00010_spk2.wav

-mv voxceleb_CAM++.onnx en_voxceleb_CAM++.onnx
+mv voxceleb_CAM++.onnx wespeaker_en_voxceleb_CAM++.onnx

 ./add_meta_data.py \
   --model ./voxceleb_CAM++_LM.onnx \
@@ -113,20 +113,20 @@ mv voxceleb_CAM++.onnx en_voxceleb_CAM++.onnx
   --file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
   --file2 ./wespeaker-models/test_wavs/00010_spk2.wav

-mv voxceleb_CAM++_LM.onnx en_voxceleb_CAM++_LM.onnx
+mv voxceleb_CAM++_LM.onnx wespeaker_en_voxceleb_CAM++_LM.onnx

 ./add_meta_data.py \
   --model ./cnceleb_resnet34.onnx \
   --language Chinese \
   --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34.onnx

-mv cnceleb_resnet34.onnx zh_cnceleb_resnet34.onnx
+mv cnceleb_resnet34.onnx wespeaker_zh_cnceleb_resnet34.onnx

 ./add_meta_data.py \
   --model ./cnceleb_resnet34_LM.onnx \
   --language Chinese \
   --url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34_LM.onnx

-mv cnceleb_resnet34_LM.onnx zh_cnceleb_resnet34_LM.onnx
+mv cnceleb_resnet34_LM.onnx wespeaker_zh_cnceleb_resnet34_LM.onnx

 ls -lh