Export non-streaming NeMo faster conformer hybrid transducer and ctc to sherpa-onnx (#847)

Committed by GitHub

Showing 16 changed files with 1055 additions and 20 deletions
name: export-nemo-fast-conformer-ctc-non-streaming

on:
  workflow_dispatch:

concurrency:
  group: export-nemo-fast-conformer-hybrid-transducer-ctc-non-streaming-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-nemo-fast-conformer-hybrid-transducer-ctc-non-streaming:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: Hybrid ctc non-streaming
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install NeMo
        shell: bash
        run: |
          BRANCH='main'
          pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]
          pip install onnxruntime
          pip install kaldi-native-fbank
          pip install soundfile librosa

      - name: Run
        shell: bash
        run: |
          cd scripts/nemo/fast-conformer-hybrid-transducer-ctc
          ./run-ctc-non-streaming.sh

          mv -v sherpa-onnx-nemo* ../../..

      - name: Compress files
        shell: bash
        run: |
          dirs=(
            sherpa-onnx-nemo-fast-conformer-ctc-en-24500
            sherpa-onnx-nemo-fast-conformer-ctc-es-1424
            sherpa-onnx-nemo-fast-conformer-ctc-en-de-es-fr-14288
            sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k
          )
          for d in ${dirs[@]}; do
            tar cjvf ${d}.tar.bz2 ./$d
          done

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models
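Since the workflow above runs only `on: workflow_dispatch`, it must be started by hand. Below is a minimal sketch for triggering it through the GitHub REST API, assuming a token with `workflow` scope in the `GITHUB_TOKEN` environment variable; the workflow file name used here is inferred from the concurrency group and is not shown in this diff.

# A minimal sketch for dispatching the workflow above via the GitHub REST API.
# Assumptions: GITHUB_TOKEN holds a token with `workflow` scope, and the file
# name below matches the actual path under .github/workflows/ (inferred, not
# shown in this diff).
import os

import requests

OWNER = "k2-fsa"
REPO = "sherpa-onnx"
WORKFLOW = "export-nemo-fast-conformer-hybrid-transducer-ctc-non-streaming.yaml"

resp = requests.post(
    f"https://api.github.com/repos/{OWNER}/{REPO}/actions/workflows/{WORKFLOW}/dispatches",
    headers={
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
    },
    json={"ref": "master"},  # branch to run the workflow on
)
resp.raise_for_status()  # the API returns 204 No Content on success
print("workflow dispatched")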
@@ -10,7 +10,7 @@ concurrency:
 jobs:
   export-nemo-fast-conformer-hybrid-transducer-ctc-to-onnx:
     if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
-    name: export NeMo fast conformer
+    name: Hybrid ctc streaming
     runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false
@@ -54,13 +54,13 @@ jobs:
           curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/trans.txt
           popd

-          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-80ms
-          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-480ms
-          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-1040ms
+          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms
+          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-480ms
+          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-1040ms

-          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-80ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-80ms
-          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-480ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-480ms
-          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-1040ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-1040ms
+          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms
+          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-480ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-480ms
+          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-1040ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-1040ms

       - name: Release
         uses: svenstaro/upload-release-action@v2
.github/workflows/export-nemo-fast-conformer-hybrid-transducer-transducer-non-streaming.yaml (new file, mode 100644)

name: export-nemo-fast-conformer-transducer-non-streaming

on:
  workflow_dispatch:

concurrency:
  group: export-nemo-fast-conformer-hybrid-transducer-transducer-non-streaming-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-nemo-fast-conformer-hybrid-transducer-transducer-non-streaming:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: Hybrid transducer non-streaming
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install NeMo
        shell: bash
        run: |
          BRANCH='main'
          pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]
          pip install onnxruntime
          pip install kaldi-native-fbank
          pip install soundfile librosa

      - name: Run
        shell: bash
        run: |
          cd scripts/nemo/fast-conformer-hybrid-transducer-ctc
          ./run-transducer-non-streaming.sh

          mv -v sherpa-onnx-nemo* ../../..

      - name: Compress files
        shell: bash
        run: |
          dirs=(
            sherpa-onnx-nemo-fast-conformer-transducer-en-24500
            sherpa-onnx-nemo-fast-conformer-transducer-es-1424
            sherpa-onnx-nemo-fast-conformer-transducer-en-de-es-fr-14288
            sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k
          )
          for d in ${dirs[@]}; do
            tar cjvf ${d}.tar.bz2 ./$d
          done

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models
@@ -10,7 +10,7 @@ concurrency:
 jobs:
   export-nemo-fast-conformer-hybrid-transducer-to-onnx:
     if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
-    name: NeMo transducer
+    name: Hybrid transducer streaming
     runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false
@@ -54,13 +54,13 @@ jobs:
           curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/trans.txt
           popd

-          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms
-          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms
-          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms
+          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms
+          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-480ms
+          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-1040ms

-          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms
-          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms
-          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms
+          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms
+          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-480ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-480ms
+          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-1040ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-1040ms

       - name: Release
         uses: svenstaro/upload-release-action@v2
scripts/nemo/.gitignore (new file, mode 100644)

!run-*.sh
@@ -6,4 +6,20 @@ This folder contains scripts for exporting models from
  - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_480ms
  - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_1040ms

+ - # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_ctc_large
+ - # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_enes_conformer_transducer_large_codesw
+ - # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_transducer_large
+ - # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_enzh_fastconformer_transducer_large_codesw
+
+
+ - # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_fa_fastconformer_hybrid_large
+ - # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_it_fastconformer_hybrid_large_pc
+ - # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_pl_fastconformer_hybrid_large_pc
+ - # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_ua_fastconformer_hybrid_large_pc
+
+ - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_pc
+ - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_es_fastconformer_hybrid_large_pc
+ - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc_blend_eu
+ - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
+
 to `sherpa-onnx`.
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
from typing import Dict

import nemo.collections.asr as nemo_asr
import onnx
import torch


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--doc",
        type=str,
        default="",
    )
    return parser.parse_args()


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


@torch.no_grad()
def main():
    args = get_args()
    model_name = args.model

    asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)
    print(asr_model.cfg)
    print(asr_model)

    with open("./tokens.txt", "w", encoding="utf-8") as f:
        for i, s in enumerate(asr_model.joint.vocabulary):
            f.write(f"{s} {i}\n")
        f.write(f"<blk> {i+1}\n")
    print("Saved to tokens.txt")

    decoder_type = "ctc"
    asr_model.change_decoding_strategy(decoder_type=decoder_type)
    asr_model.eval()

    asr_model.set_export_config({"decoder_type": "ctc"})

    filename = "model.onnx"

    asr_model.export(filename)

    normalize_type = asr_model.cfg.preprocessor.normalize
    if normalize_type == "NA":
        normalize_type = ""

    meta_data = {
        "vocab_size": asr_model.decoder.vocab_size,
        "normalize_type": normalize_type,
        "subsampling_factor": 8,
        "model_type": "EncDecHybridRNNTCTCBPEModel",
        "version": "1",
        "model_author": "NeMo",
        "url": f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}",
        "comment": "Only the CTC branch is exported",
        "doc": args.doc,
    }
    add_meta_data(filename, meta_data)

    print("preprocessor", asr_model.cfg.preprocessor)
    print(meta_data)


if __name__ == "__main__":
    main()
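The key-value pairs written by add_meta_data() ride along inside the ONNX file and are what the test scripts later read back through get_modelmeta().custom_metadata_map. Below is a minimal sketch for inspecting them, assuming a model.onnx exported by the script above sits in the current directory.

# A minimal sketch for inspecting the metadata embedded by add_meta_data().
# Assumes ./model.onnx was produced by the export script above.
import onnxruntime as ort

sess = ort.InferenceSession("./model.onnx", providers=["CPUExecutionProvider"])
meta = sess.get_modelmeta().custom_metadata_map  # Dict[str, str]
for key in ("vocab_size", "normalize_type", "subsampling_factor", "model_type"):
    print(key, "=", meta.get(key))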
@@ -91,11 +91,15 @@ def main():

     asr_model.export(filename)

+    normalize_type = asr_model.cfg.preprocessor.normalize
+    if normalize_type == "NA":
+        normalize_type = ""
+
     meta_data = {
         "vocab_size": asr_model.decoder.vocab_size,
         "window_size": window_size,
         "chunk_shift": chunk_shift,
-        "normalize_type": "None",
+        "normalize_type": normalize_type,
         "cache_last_channel_dim1": cache_last_channel_dim1,
         "cache_last_channel_dim2": cache_last_channel_dim2,
         "cache_last_channel_dim3": cache_last_channel_dim3,
scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-transducer-non-streaming.py (new file, mode 100755)

#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
from typing import Dict

import nemo.collections.asr as nemo_asr
import onnx
import torch


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--doc",
        type=str,
        default="",
    )
    return parser.parse_args()


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


@torch.no_grad()
def main():
    args = get_args()
    model_name = args.model

    asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)

    with open("./tokens.txt", "w", encoding="utf-8") as f:
        for i, s in enumerate(asr_model.joint.vocabulary):
            f.write(f"{s} {i}\n")
        f.write(f"<blk> {i+1}\n")
    print("Saved to tokens.txt")

    decoder_type = "rnnt"
    asr_model.change_decoding_strategy(decoder_type=decoder_type)
    asr_model.eval()

    asr_model.set_export_config({"decoder_type": "rnnt"})

    # Note: calling asr_model.export("model.onnx") instead would treat
    # "model.onnx" as a suffix and generate two files:
    #   encoder-model.onnx
    #   decoder_joint-model.onnx
    # Here the encoder, decoder, and joiner are exported separately.
    # asr_model.export("model.onnx")
    asr_model.encoder.export("encoder.onnx")
    asr_model.decoder.export("decoder.onnx")
    asr_model.joint.export("joiner.onnx")

    normalize_type = asr_model.cfg.preprocessor.normalize
    if normalize_type == "NA":
        normalize_type = ""
    meta_data = {
        "vocab_size": asr_model.decoder.vocab_size,
        "normalize_type": normalize_type,
        "pred_rnn_layers": asr_model.decoder.pred_rnn_layers,
        "pred_hidden": asr_model.decoder.pred_hidden,
        "subsampling_factor": 8,
        "model_type": "EncDecHybridRNNTCTCBPEModel",
        "version": "1",
        "model_author": "NeMo",
        "url": f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}",
        "comment": "Only the transducer branch is exported",
        "doc": args.doc,
    }
    add_meta_data("encoder.onnx", meta_data)

    print(meta_data)


if __name__ == "__main__":
    main()
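Both export scripts write tokens.txt in the same format: one "token id" pair per line, with <blk> appended as the last (largest) id. The test scripts invert this mapping; below is a minimal round-trip sketch, assuming a tokens.txt produced by one of the export scripts above is in the current directory.

# Minimal round trip for the tokens.txt format written above: "token id" per
# line, with <blk> as the final (largest) id.
id2token = {}
with open("tokens.txt", encoding="utf-8") as f:
    for line in f:
        t, idx = line.split()
        id2token[int(idx)] = t

blank = len(id2token) - 1  # <blk> is the last id, as the test scripts assume
assert id2token[blank] == "<blk>"
print(len(id2token), "tokens; blank id =", blank)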
@@ -96,11 +96,15 @@ def main():
     #  encoder-model.onnx
     #  decoder_joint-model.onnx

+    normalize_type = asr_model.cfg.preprocessor.normalize
+    if normalize_type == "NA":
+        normalize_type = ""
+
     meta_data = {
         "vocab_size": asr_model.decoder.vocab_size,
         "window_size": window_size,
         "chunk_shift": chunk_shift,
-        "normalize_type": "None",
+        "normalize_type": normalize_type,
         "cache_last_channel_dim1": cache_last_channel_dim1,
         "cache_last_channel_dim2": cache_last_channel_dim2,
         "cache_last_channel_dim3": cache_last_channel_dim3,
#!/usr/bin/env bash
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

set -ex

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# 8500 hours of English speech
url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_pc
name=$(basename $url)
doc="This collection contains the English FastConformer Hybrid (Transducer and CTC) Large model (around 114M parameters) with Punctuation and Capitalization on NeMo ASRSet En PC with around 8500 hours of English speech (SPGI 1k, VoxPopuli, MCV11, Europarl-ASR, Fisher, LibriSpeech, NSC1, MLS). It utilizes a Google SentencePiece [1] tokenizer with a vocabulary size of 1024. It transcribes text in upper and lower case English alphabet along with spaces, periods, commas, question marks, and a few other characters."

log "Process $name at $url"
./export-onnx-ctc-non-streaming.py --model $name --doc "$doc"

d=sherpa-onnx-nemo-fast-conformer-ctc-en-24500
mkdir -p $d
mv -v model.onnx $d/
mv -v tokens.txt $d/
ls -lh $d

url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_es_fastconformer_hybrid_large_pc
name=$(basename $url)
doc="This collection contains the Spanish FastConformer Hybrid (CTC and Transducer) Large model (around 114M parameters) with Punctuation and Capitalization. It is trained on the NeMo PnC ES ASRSET (Fisher, MCV12, MLS, Voxpopuli) containing 1424 hours of Spanish speech. It utilizes a Google SentencePiece [1] tokenizer with vocabulary size 1024, and transcribes text in upper and lower case Spanish alphabet along with spaces, period, comma, question mark and inverted question mark."

./export-onnx-ctc-non-streaming.py --model $name --doc "$doc"

d=sherpa-onnx-nemo-fast-conformer-ctc-es-1424
mkdir -p $d
mv -v model.onnx $d/
mv -v tokens.txt $d/
ls -lh $d

url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc_blend_eu
name=$(basename $url)
doc="This collection contains the Multilingual FastConformer Hybrid (Transducer and CTC) Large model (around 114M parameters) with Punctuation and Capitalization. It is trained on the NeMo PnC German, English, Spanish, and French ASR sets that contain 14,288 hours of speech in total. It utilizes a Google SentencePiece [1] tokenizer with vocabulary size 256 per language and transcribes text in upper and lower case along with spaces, periods, commas, question marks and a few other language-specific characters. The total tokenizer size is 2560, of which 1024 tokens are allocated to English, German, French, and Spanish. The remaining tokens are reserved for future languages."

./export-onnx-ctc-non-streaming.py --model $name --doc "$doc"

d=sherpa-onnx-nemo-fast-conformer-ctc-en-de-es-fr-14288
mkdir -p $d
mv -v model.onnx $d/
mv -v tokens.txt $d/
ls -lh $d

url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
name=$(basename $url)
doc="This collection contains the Multilingual FastConformer Hybrid (Transducer and CTC) Large model (around 114M parameters) with Punctuation and Capitalization. It is trained on the NeMo PnC Belarusian, German, English, Spanish, French, Croatian, Italian, Polish, Russian, and Ukrainian ASR sets that contain ~20,000 hours of speech in total. It utilizes a Google SentencePiece [1] tokenizer with vocabulary size 256 per language (2560 total), and transcribes text in upper and lower case along with spaces, periods, commas, question marks and a few other language-specific characters."

./export-onnx-ctc-non-streaming.py --model $name --doc "$doc"

d=sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k
mkdir -p $d
mv -v model.onnx $d/
mv -v tokens.txt $d/
ls -lh $d

# Now test the exported model
log "Download test data"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/spoken-language-identification-test-wavs.tar.bz2
tar xvf spoken-language-identification-test-wavs.tar.bz2
rm spoken-language-identification-test-wavs.tar.bz2
data=spoken-language-identification-test-wavs

d=sherpa-onnx-nemo-fast-conformer-ctc-en-24500
python3 ./test-onnx-ctc-non-streaming.py \
  --model $d/model.onnx \
  --tokens $d/tokens.txt \
  --wav $data/en-english.wav
mkdir -p $d/test_wavs
cp -v $data/en-english.wav $d/test_wavs

d=sherpa-onnx-nemo-fast-conformer-ctc-es-1424
python3 ./test-onnx-ctc-non-streaming.py \
  --model $d/model.onnx \
  --tokens $d/tokens.txt \
  --wav $data/es-spanish.wav
mkdir -p $d/test_wavs
cp -v $data/es-spanish.wav $d/test_wavs

d=sherpa-onnx-nemo-fast-conformer-ctc-en-de-es-fr-14288
mkdir -p $d/test_wavs
for w in en-english.wav de-german.wav es-spanish.wav fr-french.wav; do
  python3 ./test-onnx-ctc-non-streaming.py \
    --model $d/model.onnx \
    --tokens $d/tokens.txt \
    --wav $data/$w
  cp -v $data/$w $d/test_wavs
done

d=sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k
mkdir -p $d/test_wavs
for w in en-english.wav de-german.wav es-spanish.wav fr-french.wav hr-croatian.wav it-italian.wav po-polish.wav ru-russian.wav uk-ukrainian.wav; do
  python3 ./test-onnx-ctc-non-streaming.py \
    --model $d/model.onnx \
    --tokens $d/tokens.txt \
    --wav $data/$w
  cp -v $data/$w $d/test_wavs
done
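Once one of the resulting directories is unpacked, it should be usable from the sherpa-onnx Python API as well. Below is a minimal sketch, assuming `pip install sherpa-onnx` and that the installed version provides OfflineRecognizer.from_nemo_ctc; check your version's API if it differs.

# A minimal decoding sketch with the sherpa-onnx Python API.
# Assumption: the installed sherpa-onnx exposes OfflineRecognizer.from_nemo_ctc;
# adjust to your version's API if the factory name differs.
import sherpa_onnx
import soundfile as sf

d = "sherpa-onnx-nemo-fast-conformer-ctc-en-24500"
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
    model=f"{d}/model.onnx",
    tokens=f"{d}/tokens.txt",
    num_threads=1,
)

audio, sample_rate = sf.read(f"{d}/test_wavs/en-english.wav", dtype="float32")
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
recognizer.decode_stream(stream)
print(stream.result.text)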
@@ -16,7 +16,7 @@ ms=(

 for m in ${ms[@]}; do
   ./export-onnx-ctc.py --model $m
-  d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-${m}ms
+  d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-${m}ms
   if [ ! -f $d/model.onnx ]; then
     mkdir -p $d
     mv -v model.onnx $d/
@@ -28,7 +28,7 @@ done
 # Now test the exported models

 for m in ${ms[@]}; do
-  d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-${m}ms
+  d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-${m}ms
   python3 ./test-onnx-ctc.py \
     --model $d/model.onnx \
     --tokens $d/tokens.txt \
#!/usr/bin/env bash
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

set -ex

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# 8500 hours of English speech
url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_pc
name=$(basename $url)
doc="This collection contains the English FastConformer Hybrid (Transducer and CTC) Large model (around 114M parameters) with Punctuation and Capitalization on NeMo ASRSet En PC with around 8500 hours of English speech (SPGI 1k, VoxPopuli, MCV11, Europarl-ASR, Fisher, LibriSpeech, NSC1, MLS). It utilizes a Google SentencePiece [1] tokenizer with a vocabulary size of 1024. It transcribes text in upper and lower case English alphabet along with spaces, periods, commas, question marks, and a few other characters."

log "Process $name at $url"
./export-onnx-transducer-non-streaming.py --model $name --doc "$doc"

d=sherpa-onnx-nemo-fast-conformer-transducer-en-24500
mkdir -p $d
mv -v *.onnx $d/
mv -v tokens.txt $d/
ls -lh $d

url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_es_fastconformer_hybrid_large_pc
name=$(basename $url)
doc="This collection contains the Spanish FastConformer Hybrid (CTC and Transducer) Large model (around 114M parameters) with Punctuation and Capitalization. It is trained on the NeMo PnC ES ASRSET (Fisher, MCV12, MLS, Voxpopuli) containing 1424 hours of Spanish speech. It utilizes a Google SentencePiece [1] tokenizer with vocabulary size 1024, and transcribes text in upper and lower case Spanish alphabet along with spaces, period, comma, question mark and inverted question mark."

./export-onnx-transducer-non-streaming.py --model $name --doc "$doc"

d=sherpa-onnx-nemo-fast-conformer-transducer-es-1424
mkdir -p $d
mv -v *.onnx $d/
mv -v tokens.txt $d/
ls -lh $d

url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc_blend_eu
name=$(basename $url)
doc="This collection contains the Multilingual FastConformer Hybrid (Transducer and CTC) Large model (around 114M parameters) with Punctuation and Capitalization. It is trained on the NeMo PnC German, English, Spanish, and French ASR sets that contain 14,288 hours of speech in total. It utilizes a Google SentencePiece [1] tokenizer with vocabulary size 256 per language and transcribes text in upper and lower case along with spaces, periods, commas, question marks and a few other language-specific characters. The total tokenizer size is 2560, of which 1024 tokens are allocated to English, German, French, and Spanish. The remaining tokens are reserved for future languages."

./export-onnx-transducer-non-streaming.py --model $name --doc "$doc"

d=sherpa-onnx-nemo-fast-conformer-transducer-en-de-es-fr-14288
mkdir -p $d
mv -v *.onnx $d/
mv -v tokens.txt $d/
ls -lh $d

url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
name=$(basename $url)
doc="This collection contains the Multilingual FastConformer Hybrid (Transducer and CTC) Large model (around 114M parameters) with Punctuation and Capitalization. It is trained on the NeMo PnC Belarusian, German, English, Spanish, French, Croatian, Italian, Polish, Russian, and Ukrainian ASR sets that contain ~20,000 hours of speech in total. It utilizes a Google SentencePiece [1] tokenizer with vocabulary size 256 per language (2560 total), and transcribes text in upper and lower case along with spaces, periods, commas, question marks and a few other language-specific characters."

./export-onnx-transducer-non-streaming.py --model $name --doc "$doc"

d=sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k
mkdir -p $d
mv -v *.onnx $d/
mv -v tokens.txt $d/
ls -lh $d

# Now test the exported model
log "Download test data"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/spoken-language-identification-test-wavs.tar.bz2
tar xvf spoken-language-identification-test-wavs.tar.bz2
rm spoken-language-identification-test-wavs.tar.bz2
data=spoken-language-identification-test-wavs

d=sherpa-onnx-nemo-fast-conformer-transducer-en-24500
python3 ./test-onnx-transducer-non-streaming.py \
  --encoder $d/encoder.onnx \
  --decoder $d/decoder.onnx \
  --joiner $d/joiner.onnx \
  --tokens $d/tokens.txt \
  --wav $data/en-english.wav
mkdir -p $d/test_wavs
cp -v $data/en-english.wav $d/test_wavs

d=sherpa-onnx-nemo-fast-conformer-transducer-es-1424
python3 ./test-onnx-transducer-non-streaming.py \
  --encoder $d/encoder.onnx \
  --decoder $d/decoder.onnx \
  --joiner $d/joiner.onnx \
  --tokens $d/tokens.txt \
  --wav $data/es-spanish.wav
mkdir -p $d/test_wavs
cp -v $data/es-spanish.wav $d/test_wavs

d=sherpa-onnx-nemo-fast-conformer-transducer-en-de-es-fr-14288
mkdir -p $d/test_wavs
for w in en-english.wav de-german.wav es-spanish.wav fr-french.wav; do
  python3 ./test-onnx-transducer-non-streaming.py \
    --encoder $d/encoder.onnx \
    --decoder $d/decoder.onnx \
    --joiner $d/joiner.onnx \
    --tokens $d/tokens.txt \
    --wav $data/$w
  cp -v $data/$w $d/test_wavs
done

d=sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k
mkdir -p $d/test_wavs
for w in en-english.wav de-german.wav es-spanish.wav fr-french.wav hr-croatian.wav it-italian.wav po-polish.wav ru-russian.wav uk-ukrainian.wav; do
  python3 ./test-onnx-transducer-non-streaming.py \
    --encoder $d/encoder.onnx \
    --decoder $d/decoder.onnx \
    --joiner $d/joiner.onnx \
    --tokens $d/tokens.txt \
    --wav $data/$w
  cp -v $data/$w $d/test_wavs
done
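The transducer directories can likewise be decoded with sherpa-onnx. Below is a minimal sketch under two assumptions: the installed version provides OfflineRecognizer.from_transducer, and it accepts model_type="nemo_transducer" for NeMo-exported models; both should be verified against your version's API.

# A minimal decoding sketch for the exported transducer directories.
# Assumptions: the installed sherpa-onnx provides OfflineRecognizer.from_transducer
# and accepts model_type="nemo_transducer"; adjust if your version's API differs.
import sherpa_onnx
import soundfile as sf

d = "sherpa-onnx-nemo-fast-conformer-transducer-en-24500"
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    encoder=f"{d}/encoder.onnx",
    decoder=f"{d}/decoder.onnx",
    joiner=f"{d}/joiner.onnx",
    tokens=f"{d}/tokens.txt",
    model_type="nemo_transducer",
)

audio, sample_rate = sf.read(f"{d}/test_wavs/en-english.wav", dtype="float32")
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
recognizer.decode_stream(stream)
print(stream.result.text)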
@@ -16,7 +16,7 @@ ms=(

 for m in ${ms[@]}; do
   ./export-onnx-transducer.py --model $m
-  d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms
+  d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-${m}ms
   if [ ! -f $d/encoder.onnx ]; then
     mkdir -p $d
     mv -v encoder.onnx $d/
@@ -30,7 +30,7 @@ done
 # Now test the exported models

 for m in ${ms[@]}; do
-  d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms
+  d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-${m}ms
   python3 ./test-onnx-transducer.py \
     --encoder $d/encoder.onnx \
     --decoder $d/decoder.onnx \
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
from pathlib import Path

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import torch
import soundfile as sf
import librosa


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True, help="Path to model.onnx")

    parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt")

    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")

    return parser.parse_args()


def create_fbank():
    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.remove_dc_offset = False
    opts.frame_opts.window_type = "hann"

    opts.mel_opts.low_freq = 0
    opts.mel_opts.num_bins = 80

    opts.mel_opts.is_librosa = True

    fbank = knf.OnlineFbank(opts)
    return fbank


def compute_features(audio, fbank):
    assert len(audio.shape) == 1, audio.shape
    fbank.accept_waveform(16000, audio)
    ans = []
    processed = 0
    while processed < fbank.num_frames_ready:
        ans.append(np.array(fbank.get_frame(processed)))
        processed += 1
    ans = np.stack(ans)
    return ans


class OnnxModel:
    def __init__(
        self,
        filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        print("==========Input==========")
        for i in self.model.get_inputs():
            print(i)
        print("==========Output==========")
        for i in self.model.get_outputs():
            print(i)
        """
        ==========Input==========
        NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 80, 'audio_signal_dynamic_axes_2'])
        NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1'])
        ==========Output==========
        NodeArg(name='logprobs', type='tensor(float)', shape=['logprobs_dynamic_axes_1', 'logprobs_dynamic_axes_2', 1025])
        """

        meta = self.model.get_modelmeta().custom_metadata_map
        self.normalize_type = meta["normalize_type"]
        print(meta)

    def __call__(self, x: np.ndarray):
        # x: (T, C)
        x = torch.from_numpy(x)
        x = x.t().unsqueeze(0)
        # x: [1, C, T]
        x_lens = torch.tensor([x.shape[-1]], dtype=torch.int64)

        log_probs = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lens.numpy(),
            },
        )[0]
        # [batch_size, T, vocab_size]
        return torch.from_numpy(log_probs)


def main():
    args = get_args()
    assert Path(args.model).is_file(), args.model
    assert Path(args.tokens).is_file(), args.tokens
    assert Path(args.wav).is_file(), args.wav

    print(vars(args))

    model = OnnxModel(args.model)

    id2token = dict()
    with open(args.tokens, encoding="utf-8") as f:
        for line in f:
            t, idx = line.split()
            id2token[int(idx)] = t

    fbank = create_fbank()
    audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    blank = len(id2token) - 1
    ans = []
    prev = -1

    print(audio.shape)
    features = compute_features(audio, fbank)
    if model.normalize_type != "":
        assert model.normalize_type == "per_feature", model.normalize_type
        features = torch.from_numpy(features)
        mean = features.mean(dim=1, keepdims=True)
        stddev = features.std(dim=1, keepdims=True)
        features = (features - mean) / stddev
        features = features.numpy()

    print("features.shape", features.shape)
    log_probs = model(features)

    print("log_probs.shape", log_probs.shape)

    log_probs = log_probs[0, :, :]  # remove batch dim
    ids = torch.argmax(log_probs, dim=1).tolist()
    for k in ids:
        if k != blank and k != prev:
            ans.append(k)
        prev = k

    tokens = [id2token[i] for i in ans]
    underline = "▁"
    # underline = b"\xe2\x96\x81".decode()
    text = "".join(tokens).replace(underline, " ").strip()
    print(args.wav)
    print(text)


main()
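The decoding loop in the script above is the standard greedy CTC collapse: take the per-frame argmax, drop consecutive repeats, then drop blanks. Below is a tiny self-contained example of the same rule on made-up ids (with blank = 3).

# Worked example of the greedy CTC collapse used above (toy ids, blank = 3).
ids = [0, 0, 3, 0, 1, 1, 3, 3, 2]
blank = 3
ans, prev = [], -1
for k in ids:
    if k != blank and k != prev:
        ans.append(k)
    prev = k
print(ans)  # [0, 0, 1, 2]: the second 0 survives because a blank separated it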
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
from pathlib import Path

import kaldi_native_fbank as knf
import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder", type=str, required=True, help="Path to encoder.onnx"
    )
    parser.add_argument(
        "--decoder", type=str, required=True, help="Path to decoder.onnx"
    )
    parser.add_argument("--joiner", type=str, required=True, help="Path to joiner.onnx")

    parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt")

    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")

    return parser.parse_args()


def create_fbank():
    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.remove_dc_offset = False
    opts.frame_opts.window_type = "hann"

    opts.mel_opts.low_freq = 0
    opts.mel_opts.num_bins = 80

    opts.mel_opts.is_librosa = True

    fbank = knf.OnlineFbank(opts)
    return fbank


def compute_features(audio, fbank):
    assert len(audio.shape) == 1, audio.shape
    fbank.accept_waveform(16000, audio)
    ans = []
    processed = 0
    while processed < fbank.num_frames_ready:
        ans.append(np.array(fbank.get_frame(processed)))
        processed += 1
    ans = np.stack(ans)
    return ans


def display(sess):
    print("==========Input==========")
    for i in sess.get_inputs():
        print(i)
    print("==========Output==========")
    for i in sess.get_outputs():
        print(i)


"""
encoder
==========Input==========
NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 80, 'audio_signal_dynamic_axes_2'])
NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1'])
==========Output==========
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 512, 'outputs_dynamic_axes_2'])
NodeArg(name='encoded_lengths', type='tensor(int64)', shape=['encoded_lengths_dynamic_axes_1'])

decoder
==========Input==========
NodeArg(name='targets', type='tensor(int32)', shape=['targets_dynamic_axes_1', 'targets_dynamic_axes_2'])
NodeArg(name='target_length', type='tensor(int32)', shape=['target_length_dynamic_axes_1'])
NodeArg(name='states.1', type='tensor(float)', shape=[1, 'states.1_dim_1', 640])
NodeArg(name='onnx::LSTM_3', type='tensor(float)', shape=[1, 1, 640])
==========Output==========
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 640, 'outputs_dynamic_axes_2'])
NodeArg(name='prednet_lengths', type='tensor(int32)', shape=['prednet_lengths_dynamic_axes_1'])
NodeArg(name='states', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 640])
NodeArg(name='74', type='tensor(float)', shape=[1, 'LSTM74_dim_1', 640])

joiner
==========Input==========
NodeArg(name='encoder_outputs', type='tensor(float)', shape=['encoder_outputs_dynamic_axes_1', 512, 'encoder_outputs_dynamic_axes_2'])
NodeArg(name='decoder_outputs', type='tensor(float)', shape=['decoder_outputs_dynamic_axes_1', 640, 'decoder_outputs_dynamic_axes_2'])
==========Output==========
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 'outputs_dynamic_axes_2', 'outputs_dynamic_axes_3', 1025])
"""


class OnnxModel:
    def __init__(
        self,
        encoder: str,
        decoder: str,
        joiner: str,
    ):
        self.init_encoder(encoder)
        display(self.encoder)
        self.init_decoder(decoder)
        display(self.decoder)
        self.init_joiner(joiner)
        display(self.joiner)

    def init_encoder(self, encoder):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.encoder = ort.InferenceSession(
            encoder,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.encoder.get_modelmeta().custom_metadata_map
        self.normalize_type = meta["normalize_type"]
        print(meta)

        self.pred_rnn_layers = int(meta["pred_rnn_layers"])
        self.pred_hidden = int(meta["pred_hidden"])

    def init_decoder(self, decoder):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.decoder = ort.InferenceSession(
            decoder,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_joiner(self, joiner):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.joiner = ort.InferenceSession(
            joiner,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def get_decoder_state(self):
        batch_size = 1
        state0 = torch.zeros(self.pred_rnn_layers, batch_size, self.pred_hidden).numpy()
        state1 = torch.zeros(self.pred_rnn_layers, batch_size, self.pred_hidden).numpy()
        return state0, state1

    def run_encoder(self, x: np.ndarray):
        # x: (T, C)
        x = torch.from_numpy(x)
        x = x.t().unsqueeze(0)
        # x: [1, C, T]
        x_lens = torch.tensor([x.shape[-1]], dtype=torch.int64)

        (encoder_out, out_len) = self.encoder.run(
            [
                self.encoder.get_outputs()[0].name,
                self.encoder.get_outputs()[1].name,
            ],
            {
                self.encoder.get_inputs()[0].name: x.numpy(),
                self.encoder.get_inputs()[1].name: x_lens.numpy(),
            },
        )
        # [batch_size, dim, T]
        return encoder_out

    def run_decoder(
        self,
        token: int,
        state0: np.ndarray,
        state1: np.ndarray,
    ):
        target = torch.tensor([[token]], dtype=torch.int32).numpy()
        target_len = torch.tensor([1], dtype=torch.int32).numpy()

        (
            decoder_out,
            decoder_out_length,
            state0_next,
            state1_next,
        ) = self.decoder.run(
            [
                self.decoder.get_outputs()[0].name,
                self.decoder.get_outputs()[1].name,
                self.decoder.get_outputs()[2].name,
                self.decoder.get_outputs()[3].name,
            ],
            {
                self.decoder.get_inputs()[0].name: target,
                self.decoder.get_inputs()[1].name: target_len,
                self.decoder.get_inputs()[2].name: state0,
                self.decoder.get_inputs()[3].name: state1,
            },
        )
        return decoder_out, state0_next, state1_next

    def run_joiner(
        self,
        encoder_out: np.ndarray,
        decoder_out: np.ndarray,
    ):
        # encoder_out: [batch_size, dim, 1]
        # decoder_out: [batch_size, dim, 1]
        logit = self.joiner.run(
            [
                self.joiner.get_outputs()[0].name,
            ],
            {
                self.joiner.get_inputs()[0].name: encoder_out,
                self.joiner.get_inputs()[1].name: decoder_out,
            },
        )[0]
        # logit: [batch_size, 1, 1, vocab_size]
        return logit


def main():
    args = get_args()
    assert Path(args.encoder).is_file(), args.encoder
    assert Path(args.decoder).is_file(), args.decoder
    assert Path(args.joiner).is_file(), args.joiner
    assert Path(args.tokens).is_file(), args.tokens
    assert Path(args.wav).is_file(), args.wav

    print(vars(args))

    model = OnnxModel(args.encoder, args.decoder, args.joiner)

    id2token = dict()
    with open(args.tokens, encoding="utf-8") as f:
        for line in f:
            t, idx = line.split()
            id2token[int(idx)] = t

    fbank = create_fbank()
    audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    tail_padding = np.zeros(sample_rate * 2)

    audio = np.concatenate([audio, tail_padding])

    blank = len(id2token) - 1
    ans = [blank]
    state0, state1 = model.get_decoder_state()
    decoder_out, state0_next, state1_next = model.run_decoder(ans[-1], state0, state1)

    features = compute_features(audio, fbank)
    if model.normalize_type != "":
        assert model.normalize_type == "per_feature", model.normalize_type
        features = torch.from_numpy(features)
        mean = features.mean(dim=1, keepdims=True)
        stddev = features.std(dim=1, keepdims=True)
        features = (features - mean) / stddev
        features = features.numpy()
    print(audio.shape)
    print("features.shape", features.shape)

    encoder_out = model.run_encoder(features)
    # encoder_out: [batch_size, dim, T]
    for t in range(encoder_out.shape[2]):
        encoder_out_t = encoder_out[:, :, t : t + 1]
        logits = model.run_joiner(encoder_out_t, decoder_out)
        logits = torch.from_numpy(logits)
        logits = logits.squeeze()
        idx = torch.argmax(logits, dim=-1).item()
        if idx != blank:
            ans.append(idx)
            state0 = state0_next
            state1 = state1_next
            decoder_out, state0_next, state1_next = model.run_decoder(
                ans[-1], state0, state1
            )

    ans = ans[1:]  # remove the first blank
    print(ans)
    tokens = [id2token[i] for i in ans]
    underline = "▁"
    # underline = b"\xe2\x96\x81".decode()
    text = "".join(tokens).replace(underline, " ").strip()
    print(args.wav)
    print(text)


main()
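A note on the tail_padding in the script above: two seconds of zeros are appended before feature extraction, a common trick so the encoder has enough trailing context to emit the final tokens rather than truncating them. A back-of-the-envelope check follows, assuming the usual 10 ms feature frame shift (an assumption; the shift is not recorded in the model metadata).

# Rough arithmetic for the 2 s tail padding used above.
# Assumption: 10 ms frame shift, i.e. 100 feature frames per second.
sample_rate = 16000
pad_samples = sample_rate * 2              # np.zeros(sample_rate * 2) in the script
extra_feature_frames = 2 * 100             # 200 feature frames of silence
extra_encoder_frames = extra_feature_frames // 8  # subsampling_factor = 8
print(pad_samples, extra_feature_frames, extra_encoder_frames)  # 32000 200 25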