正在显示
14 个修改的文件
包含
355 行增加
和
25 行删除
| @@ -47,9 +47,23 @@ for type in base small; do | @@ -47,9 +47,23 @@ for type in base small; do | ||
| 47 | rm -rf sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02 | 47 | rm -rf sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02 |
| 48 | done | 48 | done |
| 49 | 49 | ||
| 50 | +log "------------------------------------------------------------" | ||
| 51 | +log "Run NeMo GigaAM Russian models v2" | ||
| 52 | +log "------------------------------------------------------------" | ||
| 53 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19.tar.bz2 | ||
| 54 | +tar xvf sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19.tar.bz2 | ||
| 55 | +rm sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19.tar.bz2 | ||
| 56 | + | ||
| 57 | +$EXE \ | ||
| 58 | + --nemo-ctc-model=./sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19/model.int8.onnx \ | ||
| 59 | + --tokens=./sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19/tokens.txt \ | ||
| 60 | + --debug=1 \ | ||
| 61 | + ./sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19/test_wavs/example.wav | ||
| 62 | + | ||
| 63 | +rm -rf sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19 | ||
| 50 | 64 | ||
| 51 | log "------------------------------------------------------------" | 65 | log "------------------------------------------------------------" |
| 52 | -log "Run NeMo GigaAM Russian models" | 66 | +log "Run NeMo GigaAM Russian models v1" |
| 53 | log "------------------------------------------------------------" | 67 | log "------------------------------------------------------------" |
| 54 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2 | 68 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2 |
| 55 | tar xvf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2 | 69 | tar xvf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2 |
| @@ -15,6 +15,24 @@ echo "PATH: $PATH" | @@ -15,6 +15,24 @@ echo "PATH: $PATH" | ||
| 15 | 15 | ||
| 16 | which $EXE | 16 | which $EXE |
| 17 | 17 | ||
| 18 | +log "------------------------------------------------------------" | ||
| 19 | +log "Run NeMo GigaAM Russian models v2" | ||
| 20 | +log "------------------------------------------------------------" | ||
| 21 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19.tar.bz2 | ||
| 22 | +tar xvf sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19.tar.bz2 | ||
| 23 | +rm sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19.tar.bz2 | ||
| 24 | + | ||
| 25 | +$EXE \ | ||
| 26 | + --encoder=./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/encoder.int8.onnx \ | ||
| 27 | + --decoder=./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/decoder.onnx \ | ||
| 28 | + --joiner=./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/joiner.onnx \ | ||
| 29 | + --tokens=./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/tokens.txt \ | ||
| 30 | + --model-type=nemo_transducer \ | ||
| 31 | + ./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/test_wavs/example.wav | ||
| 32 | + | ||
| 33 | +rm sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19 | ||
| 34 | + | ||
| 35 | + | ||
| 18 | log "------------------------------------------------------------------------" | 36 | log "------------------------------------------------------------------------" |
| 19 | log "Run zipformer transducer models (Russian) " | 37 | log "Run zipformer transducer models (Russian) " |
| 20 | log "------------------------------------------------------------------------" | 38 | log "------------------------------------------------------------------------" |
| @@ -43,7 +43,8 @@ jobs: | @@ -43,7 +43,8 @@ jobs: | ||
| 43 | mv -v scripts/nemo/GigaAM/tokens.txt $d/ | 43 | mv -v scripts/nemo/GigaAM/tokens.txt $d/ |
| 44 | mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/ | 44 | mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/ |
| 45 | mv -v scripts/nemo/GigaAM/run-ctc.sh $d/ | 45 | mv -v scripts/nemo/GigaAM/run-ctc.sh $d/ |
| 46 | - mv -v scripts/nemo/GigaAM/*-ctc.py $d/ | 46 | + mv -v scripts/nemo/GigaAM/export-onnx-ctc.py $d/ |
| 47 | + cp -v scripts/nemo/GigaAM/test-onnx-ctc.py $d/ | ||
| 47 | 48 | ||
| 48 | ls -lh scripts/nemo/GigaAM/ | 49 | ls -lh scripts/nemo/GigaAM/ |
| 49 | 50 | ||
| @@ -71,7 +72,8 @@ jobs: | @@ -71,7 +72,8 @@ jobs: | ||
| 71 | mv -v scripts/nemo/GigaAM/tokens.txt $d/ | 72 | mv -v scripts/nemo/GigaAM/tokens.txt $d/ |
| 72 | mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/ | 73 | mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/ |
| 73 | mv -v scripts/nemo/GigaAM/run-rnnt.sh $d/ | 74 | mv -v scripts/nemo/GigaAM/run-rnnt.sh $d/ |
| 74 | - mv -v scripts/nemo/GigaAM/*-rnnt.py $d/ | 75 | + mv -v scripts/nemo/GigaAM/export-onnx-rnnt.py $d/ |
| 76 | + cp -v scripts/nemo/GigaAM/test-onnx-rnnt.py $d/ | ||
| 75 | 77 | ||
| 76 | ls -lh scripts/nemo/GigaAM/ | 78 | ls -lh scripts/nemo/GigaAM/ |
| 77 | 79 | ||
| @@ -91,11 +93,12 @@ jobs: | @@ -91,11 +93,12 @@ jobs: | ||
| 91 | mkdir $d/test_wavs | 93 | mkdir $d/test_wavs |
| 92 | rm scripts/nemo/GigaAM/v2_ctc.onnx | 94 | rm scripts/nemo/GigaAM/v2_ctc.onnx |
| 93 | mv -v scripts/nemo/GigaAM/*.int8.onnx $d/ | 95 | mv -v scripts/nemo/GigaAM/*.int8.onnx $d/ |
| 94 | - cp -v scripts/nemo/GigaAM/LICENCE $d/ | 96 | + cp -v scripts/nemo/GigaAM/LICENSE $d/ |
| 95 | mv -v scripts/nemo/GigaAM/tokens.txt $d/ | 97 | mv -v scripts/nemo/GigaAM/tokens.txt $d/ |
| 96 | mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/ | 98 | mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/ |
| 97 | - mv -v scripts/nemo/GigaAM/run-ctc.sh $d/ | 99 | + mv -v scripts/nemo/GigaAM/run-ctc-v2.sh $d/ |
| 98 | mv -v scripts/nemo/GigaAM/*-ctc-v2.py $d/ | 100 | mv -v scripts/nemo/GigaAM/*-ctc-v2.py $d/ |
| 101 | + cp -v scripts/nemo/GigaAM/test-onnx-ctc.py $d/ | ||
| 99 | 102 | ||
| 100 | ls -lh scripts/nemo/GigaAM/ | 103 | ls -lh scripts/nemo/GigaAM/ |
| 101 | 104 | ||
| @@ -103,8 +106,36 @@ jobs: | @@ -103,8 +106,36 @@ jobs: | ||
| 103 | 106 | ||
| 104 | tar cjvf ${d}.tar.bz2 $d | 107 | tar cjvf ${d}.tar.bz2 $d |
| 105 | 108 | ||
| 109 | + - name: Run Transducer v2 | ||
| 110 | + shell: bash | ||
| 111 | + run: | | ||
| 112 | + pushd scripts/nemo/GigaAM | ||
| 113 | + ./run-rnnt-v2.sh | ||
| 114 | + popd | ||
| 115 | + | ||
| 116 | + d=sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19 | ||
| 117 | + mkdir $d | ||
| 118 | + mkdir $d/test_wavs | ||
| 119 | + | ||
| 120 | + mv -v scripts/nemo/GigaAM/encoder.int8.onnx $d/ | ||
| 121 | + mv -v scripts/nemo/GigaAM/decoder.onnx $d/ | ||
| 122 | + mv -v scripts/nemo/GigaAM/joiner.onnx $d/ | ||
| 123 | + | ||
| 124 | + cp -v scripts/nemo/GigaAM/*.md $d/ | ||
| 125 | + cp -v scripts/nemo/GigaAM/LICENSE $d/ | ||
| 126 | + mv -v scripts/nemo/GigaAM/tokens.txt $d/ | ||
| 127 | + mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/ | ||
| 128 | + mv -v scripts/nemo/GigaAM/run-rnnt-v2.sh $d/ | ||
| 129 | + cp -v scripts/nemo/GigaAM/test-onnx-rnnt.py $d/ | ||
| 130 | + | ||
| 131 | + ls -lh scripts/nemo/GigaAM/ | ||
| 132 | + | ||
| 133 | + ls -lh $d | ||
| 134 | + | ||
| 135 | + tar cjvf ${d}.tar.bz2 $d | ||
| 106 | 136 | ||
| 107 | - name: Release | 137 | - name: Release |
| 138 | + if: github.repository_owner == 'csukuangfj' | ||
| 108 | uses: svenstaro/upload-release-action@v2 | 139 | uses: svenstaro/upload-release-action@v2 |
| 109 | with: | 140 | with: |
| 110 | file_glob: true | 141 | file_glob: true |
| @@ -114,7 +145,16 @@ jobs: | @@ -114,7 +145,16 @@ jobs: | ||
| 114 | repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} | 145 | repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} |
| 115 | tag: asr-models | 146 | tag: asr-models |
| 116 | 147 | ||
| 117 | - - name: Publish to huggingface (Transducer) | 148 | + - name: Release |
| 149 | + if: github.repository_owner == 'k2-fsa' | ||
| 150 | + uses: svenstaro/upload-release-action@v2 | ||
| 151 | + with: | ||
| 152 | + file_glob: true | ||
| 153 | + file: ./*.tar.bz2 | ||
| 154 | + overwrite: true | ||
| 155 | + tag: asr-models | ||
| 156 | + | ||
| 157 | + - name: Publish to huggingface (CTC) | ||
| 118 | env: | 158 | env: |
| 119 | HF_TOKEN: ${{ secrets.HF_TOKEN }} | 159 | HF_TOKEN: ${{ secrets.HF_TOKEN }} |
| 120 | uses: nick-fields/retry@v3 | 160 | uses: nick-fields/retry@v3 |
| @@ -126,11 +166,66 @@ jobs: | @@ -126,11 +166,66 @@ jobs: | ||
| 126 | git config --global user.email "csukuangfj@gmail.com" | 166 | git config --global user.email "csukuangfj@gmail.com" |
| 127 | git config --global user.name "Fangjun Kuang" | 167 | git config --global user.name "Fangjun Kuang" |
| 128 | 168 | ||
| 169 | + d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/ | ||
| 170 | + export GIT_LFS_SKIP_SMUDGE=1 | ||
| 171 | + export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 172 | + rm -rf huggingface | ||
| 173 | + git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface | ||
| 174 | + cp -av $d/* ./huggingface | ||
| 175 | + cd huggingface | ||
| 176 | + git lfs track "*.onnx" | ||
| 177 | + git lfs track "*.wav" | ||
| 178 | + git status | ||
| 179 | + git add . | ||
| 180 | + git status | ||
| 181 | + git commit -m "add models" | ||
| 182 | + git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main | ||
| 183 | + | ||
| 184 | + - name: Publish to huggingface (Transducer) | ||
| 185 | + env: | ||
| 186 | + HF_TOKEN: ${{ secrets.HF_TOKEN }} | ||
| 187 | + uses: nick-fields/retry@v3 | ||
| 188 | + with: | ||
| 189 | + max_attempts: 5 | ||
| 190 | + timeout_seconds: 200 | ||
| 191 | + shell: bash | ||
| 192 | + command: | | ||
| 193 | + git config --global user.email "csukuangfj@gmail.com" | ||
| 194 | + git config --global user.name "Fangjun Kuang" | ||
| 195 | + | ||
| 129 | d=sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24/ | 196 | d=sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24/ |
| 130 | export GIT_LFS_SKIP_SMUDGE=1 | 197 | export GIT_LFS_SKIP_SMUDGE=1 |
| 131 | export GIT_CLONE_PROTECTION_ACTIVE=false | 198 | export GIT_CLONE_PROTECTION_ACTIVE=false |
| 199 | + rm -rf huggingface | ||
| 200 | + git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface | ||
| 201 | + cp -av $d/* ./huggingface | ||
| 202 | + cd huggingface | ||
| 203 | + git lfs track "*.onnx" | ||
| 204 | + git lfs track "*.wav" | ||
| 205 | + git status | ||
| 206 | + git add . | ||
| 207 | + git status | ||
| 208 | + git commit -m "add models" | ||
| 209 | + git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main | ||
| 210 | + | ||
| 211 | + - name: Publish v2 to huggingface (CTC) | ||
| 212 | + env: | ||
| 213 | + HF_TOKEN: ${{ secrets.HF_TOKEN }} | ||
| 214 | + uses: nick-fields/retry@v3 | ||
| 215 | + with: | ||
| 216 | + max_attempts: 5 | ||
| 217 | + timeout_seconds: 200 | ||
| 218 | + shell: bash | ||
| 219 | + command: | | ||
| 220 | + git config --global user.email "csukuangfj@gmail.com" | ||
| 221 | + git config --global user.name "Fangjun Kuang" | ||
| 222 | + | ||
| 223 | + d=sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19/ | ||
| 224 | + export GIT_LFS_SKIP_SMUDGE=1 | ||
| 225 | + export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 226 | + rm -rf huggingface | ||
| 132 | git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface | 227 | git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface |
| 133 | - mv -v $d/* ./huggingface | 228 | + cp -av $d/* ./huggingface |
| 134 | cd huggingface | 229 | cd huggingface |
| 135 | git lfs track "*.onnx" | 230 | git lfs track "*.onnx" |
| 136 | git lfs track "*.wav" | 231 | git lfs track "*.wav" |
| @@ -145,7 +240,7 @@ jobs: | @@ -145,7 +240,7 @@ jobs: | ||
| 145 | HF_TOKEN: ${{ secrets.HF_TOKEN }} | 240 | HF_TOKEN: ${{ secrets.HF_TOKEN }} |
| 146 | uses: nick-fields/retry@v3 | 241 | uses: nick-fields/retry@v3 |
| 147 | with: | 242 | with: |
| 148 | - max_attempts: 20 | 243 | + max_attempts: 5 |
| 149 | timeout_seconds: 200 | 244 | timeout_seconds: 200 |
| 150 | shell: bash | 245 | shell: bash |
| 151 | command: | | 246 | command: | |
| @@ -155,8 +250,9 @@ jobs: | @@ -155,8 +250,9 @@ jobs: | ||
| 155 | d=sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/ | 250 | d=sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/ |
| 156 | export GIT_LFS_SKIP_SMUDGE=1 | 251 | export GIT_LFS_SKIP_SMUDGE=1 |
| 157 | export GIT_CLONE_PROTECTION_ACTIVE=false | 252 | export GIT_CLONE_PROTECTION_ACTIVE=false |
| 253 | + rm -rf huggingface | ||
| 158 | git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface | 254 | git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface |
| 159 | - mv -v $d/* ./huggingface | 255 | + cp -av $d/* ./huggingface |
| 160 | cd huggingface | 256 | cd huggingface |
| 161 | git lfs track "*.onnx" | 257 | git lfs track "*.onnx" |
| 162 | git lfs track "*.wav" | 258 | git lfs track "*.wav" |
| @@ -7,4 +7,4 @@ to sherpa-onnx. | @@ -7,4 +7,4 @@ to sherpa-onnx. | ||
| 7 | The ASR models are for Russian speech recognition in this folder. | 7 | The ASR models are for Russian speech recognition in this folder. |
| 8 | 8 | ||
| 9 | Please see the license of the models at | 9 | Please see the license of the models at |
| 10 | -https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf | 10 | +https://github.com/salute-developers/GigaAM/blob/main/LICENSE |
scripts/nemo/GigaAM/export-onnx-ctc-v2.py
100644 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 1 | import gigaam | 2 | import gigaam |
| 2 | import onnx | 3 | import onnx |
| 3 | import torch | 4 | import torch |
| @@ -27,7 +28,13 @@ def add_meta_data(filename: str, meta_data: dict[str, str]): | @@ -27,7 +28,13 @@ def add_meta_data(filename: str, meta_data: dict[str, str]): | ||
| 27 | 28 | ||
| 28 | def main() -> None: | 29 | def main() -> None: |
| 29 | model_name = "v2_ctc" | 30 | model_name = "v2_ctc" |
| 30 | - model = gigaam.load_model(model_name, fp16_encoder=False, use_flash=False, download_root=".") | 31 | + model = gigaam.load_model( |
| 32 | + model_name, fp16_encoder=False, use_flash=False, download_root="." | ||
| 33 | + ) | ||
| 34 | + | ||
| 35 | + # use characters | ||
| 36 | + # space is 0 | ||
| 37 | + # <blk> is the last token | ||
| 31 | with open("./tokens.txt", "w", encoding="utf-8") as f: | 38 | with open("./tokens.txt", "w", encoding="utf-8") as f: |
| 32 | for i, s in enumerate(model.cfg["labels"]): | 39 | for i, s in enumerate(model.cfg["labels"]): |
| 33 | f.write(f"{s} {i}\n") | 40 | f.write(f"{s} {i}\n") |
| @@ -53,5 +60,5 @@ def main() -> None: | @@ -53,5 +60,5 @@ def main() -> None: | ||
| 53 | ) | 60 | ) |
| 54 | 61 | ||
| 55 | 62 | ||
| 56 | -if __name__ == '__main__': | 63 | +if __name__ == "__main__": |
| 57 | main() | 64 | main() |
| @@ -82,6 +82,9 @@ def main(): | @@ -82,6 +82,9 @@ def main(): | ||
| 82 | model.load_state_dict(ckpt, strict=False) | 82 | model.load_state_dict(ckpt, strict=False) |
| 83 | model.eval() | 83 | model.eval() |
| 84 | 84 | ||
| 85 | + # use characters | ||
| 86 | + # space is 0 | ||
| 87 | + # <blk> is the last token | ||
| 85 | with open("tokens.txt", "w", encoding="utf-8") as f: | 88 | with open("tokens.txt", "w", encoding="utf-8") as f: |
| 86 | for i, t in enumerate(model.cfg.labels): | 89 | for i, t in enumerate(model.cfg.labels): |
| 87 | f.write(f"{t} {i}\n") | 90 | f.write(f"{t} {i}\n") |
scripts/nemo/GigaAM/export-onnx-rnnt-v2.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | +import os | ||
| 4 | + | ||
| 5 | +import gigaam | ||
| 6 | +import onnx | ||
| 7 | +import torch | ||
| 8 | +from gigaam.utils import onnx_converter | ||
| 9 | +from onnxruntime.quantization import QuantType, quantize_dynamic | ||
| 10 | +from torch import Tensor | ||
| 11 | + | ||
| 12 | +""" | ||
| 13 | +==========Input========== | ||
| 14 | +NodeArg(name='audio_signal', type='tensor(float)', shape=['batch_size', 64, 'seq_len']) | ||
| 15 | +NodeArg(name='length', type='tensor(int64)', shape=['batch_size']) | ||
| 16 | +==========Output========== | ||
| 17 | +NodeArg(name='encoded', type='tensor(float)', shape=['batch_size', 768, 'Transposeencoded_dim_2']) | ||
| 18 | +NodeArg(name='encoded_len', type='tensor(int32)', shape=['batch_size']) | ||
| 19 | + | ||
| 20 | +==========Input========== | ||
| 21 | +NodeArg(name='x', type='tensor(int32)', shape=[1, 1]) | ||
| 22 | +NodeArg(name='unused_x_len.1', type='tensor(int32)', shape=[1]) | ||
| 23 | +NodeArg(name='h.1', type='tensor(float)', shape=[1, 1, 320]) | ||
| 24 | +NodeArg(name='c.1', type='tensor(float)', shape=[1, 1, 320]) | ||
| 25 | +==========Output========== | ||
| 26 | +NodeArg(name='dec', type='tensor(float)', shape=[1, 320, 1]) | ||
| 27 | +NodeArg(name='unused_x_len', type='tensor(int32)', shape=[1]) | ||
| 28 | +NodeArg(name='h', type='tensor(float)', shape=[1, 1, 320]) | ||
| 29 | +NodeArg(name='c', type='tensor(float)', shape=[1, 1, 320]) | ||
| 30 | + | ||
| 31 | +==========Input========== | ||
| 32 | +NodeArg(name='enc', type='tensor(float)', shape=[1, 768, 1]) | ||
| 33 | +NodeArg(name='dec', type='tensor(float)', shape=[1, 320, 1]) | ||
| 34 | +==========Output========== | ||
| 35 | +NodeArg(name='joint', type='tensor(float)', shape=[1, 1, 1, 34]) | ||
| 36 | +""" | ||
| 37 | + | ||
| 38 | + | ||
| 39 | +def add_meta_data(filename: str, meta_data: dict[str, str]): | ||
| 40 | + """Add meta data to an ONNX model. It is changed in-place. | ||
| 41 | + | ||
| 42 | + Args: | ||
| 43 | + filename: | ||
| 44 | + Filename of the ONNX model to be changed. | ||
| 45 | + meta_data: | ||
| 46 | + Key-value pairs. | ||
| 47 | + """ | ||
| 48 | + model = onnx.load(filename) | ||
| 49 | + while len(model.metadata_props): | ||
| 50 | + model.metadata_props.pop() | ||
| 51 | + | ||
| 52 | + for key, value in meta_data.items(): | ||
| 53 | + meta = model.metadata_props.add() | ||
| 54 | + meta.key = key | ||
| 55 | + meta.value = str(value) | ||
| 56 | + | ||
| 57 | + onnx.save(model, filename) | ||
| 58 | + | ||
| 59 | + | ||
| 60 | +class EncoderWrapper(torch.nn.Module): | ||
| 61 | + def __init__(self, m): | ||
| 62 | + super().__init__() | ||
| 63 | + self.m = m | ||
| 64 | + | ||
| 65 | + def forward(self, audio_signal: Tensor, length: Tensor): | ||
| 66 | + # https://github.com/salute-developers/GigaAM/blob/main/gigaam/encoder.py#L499 | ||
| 67 | + out, out_len = self.m.encoder(audio_signal, length) | ||
| 68 | + | ||
| 69 | + return out, out_len.to(torch.int64) | ||
| 70 | + | ||
| 71 | + def to_onnx(self, dir_path: str = "."): | ||
| 72 | + onnx_converter( | ||
| 73 | + model_name=f"{self.m.cfg.model_name}_encoder", | ||
| 74 | + out_dir=dir_path, | ||
| 75 | + module=self.m.encoder, | ||
| 76 | + dynamic_axes=self.m.encoder.dynamic_axes(), | ||
| 77 | + ) | ||
| 78 | + | ||
| 79 | + | ||
| 80 | +class DecoderWrapper(torch.nn.Module): | ||
| 81 | + def __init__(self, m): | ||
| 82 | + super().__init__() | ||
| 83 | + self.m = m | ||
| 84 | + | ||
| 85 | + def forward(self, x: Tensor, unused_x_len: Tensor, h: Tensor, c: Tensor): | ||
| 86 | + # https://github.com/salute-developers/GigaAM/blob/main/gigaam/decoder.py#L110C17-L110C54 | ||
| 87 | + emb = self.m.head.decoder.embed(x) | ||
| 88 | + g, (h, c) = self.m.head.decoder.lstm(emb.transpose(0, 1), (h, c)) | ||
| 89 | + return g.permute(1, 2, 0), unused_x_len + 1, h, c | ||
| 90 | + | ||
| 91 | + def to_onnx(self, dir_path: str = "."): | ||
| 92 | + label, hidden_h, hidden_c = self.m.head.decoder.input_example() | ||
| 93 | + label = label.to(torch.int32) | ||
| 94 | + label_len = torch.zeros(1, dtype=torch.int32) | ||
| 95 | + | ||
| 96 | + onnx_converter( | ||
| 97 | + model_name=f"{self.m.cfg.model_name}_decoder", | ||
| 98 | + out_dir=dir_path, | ||
| 99 | + module=self, | ||
| 100 | + dynamic_axes=self.m.encoder.dynamic_axes(), | ||
| 101 | + inputs=(label, label_len, hidden_h, hidden_c), | ||
| 102 | + input_names=["x", "unused_x_len.1", "h.1", "c.1"], | ||
| 103 | + output_names=["dec", "unused_x_len", "h", "c"], | ||
| 104 | + ) | ||
| 105 | + | ||
| 106 | + | ||
| 107 | +def main() -> None: | ||
| 108 | + model_name = "v2_rnnt" | ||
| 109 | + model = gigaam.load_model( | ||
| 110 | + model_name, fp16_encoder=False, use_flash=False, download_root="." | ||
| 111 | + ) | ||
| 112 | + | ||
| 113 | + # use characters | ||
| 114 | + # space is 0 | ||
| 115 | + # <blk> is the last token | ||
| 116 | + with open("./tokens.txt", "w", encoding="utf-8") as f: | ||
| 117 | + for i, s in enumerate(model.cfg["labels"]): | ||
| 118 | + f.write(f"{s} {i}\n") | ||
| 119 | + f.write(f"<blk> {i+1}\n") | ||
| 120 | + print("Saved to tokens.txt") | ||
| 121 | + | ||
| 122 | + EncoderWrapper(model).to_onnx(".") | ||
| 123 | + DecoderWrapper(model).to_onnx(".") | ||
| 124 | + | ||
| 125 | + onnx_converter( | ||
| 126 | + model_name=f"{model.cfg.model_name}_joint", | ||
| 127 | + out_dir=".", | ||
| 128 | + module=model.head.joint, | ||
| 129 | + ) | ||
| 130 | + meta_data = { | ||
| 131 | + # vocab_size does not include the blank | ||
| 132 | + # we will increase vocab_size by 1 in the c++ code | ||
| 133 | + "vocab_size": model.cfg["head"]["decoder"]["num_classes"] - 1, | ||
| 134 | + "pred_rnn_layers": model.cfg["head"]["decoder"]["pred_rnn_layers"], | ||
| 135 | + "pred_hidden": model.cfg["head"]["decoder"]["pred_hidden"], | ||
| 136 | + "normalize_type": "", | ||
| 137 | + "subsampling_factor": 4, | ||
| 138 | + "model_type": "EncDecRNNTBPEModel", | ||
| 139 | + "version": "2", | ||
| 140 | + "model_author": "https://github.com/salute-developers/GigaAM", | ||
| 141 | + "license": "https://github.com/salute-developers/GigaAM/blob/main/LICENSE", | ||
| 142 | + "language": "Russian", | ||
| 143 | + "is_giga_am": 1, | ||
| 144 | + } | ||
| 145 | + | ||
| 146 | + add_meta_data(f"./{model_name}_encoder.onnx", meta_data) | ||
| 147 | + quantize_dynamic( | ||
| 148 | + model_input=f"./{model_name}_encoder.onnx", | ||
| 149 | + model_output="./encoder.int8.onnx", | ||
| 150 | + weight_type=QuantType.QUInt8, | ||
| 151 | + ) | ||
| 152 | + os.rename(f"./{model_name}_decoder.onnx", "decoder.onnx") | ||
| 153 | + os.rename(f"./{model_name}_joint.onnx", "joiner.onnx") | ||
| 154 | + os.remove(f"./{model_name}_encoder.onnx") | ||
| 155 | + | ||
| 156 | + | ||
| 157 | +if __name__ == "__main__": | ||
| 158 | + main() |
scripts/nemo/GigaAM/export-onnx-rnnt.py
100644 → 100755
| @@ -83,6 +83,7 @@ def main(): | @@ -83,6 +83,7 @@ def main(): | ||
| 83 | model.load_state_dict(ckpt, strict=False) | 83 | model.load_state_dict(ckpt, strict=False) |
| 84 | model.eval() | 84 | model.eval() |
| 85 | 85 | ||
| 86 | + # use bpe | ||
| 86 | with open("./tokens.txt", "w", encoding="utf-8") as f: | 87 | with open("./tokens.txt", "w", encoding="utf-8") as f: |
| 87 | for i, s in enumerate(model.joint.vocabulary): | 88 | for i, s in enumerate(model.joint.vocabulary): |
| 88 | f.write(f"{s} {i}\n") | 89 | f.write(f"{s} {i}\n") |
| @@ -94,7 +95,9 @@ def main(): | @@ -94,7 +95,9 @@ def main(): | ||
| 94 | model.joint.export("joiner.onnx") | 95 | model.joint.export("joiner.onnx") |
| 95 | 96 | ||
| 96 | meta_data = { | 97 | meta_data = { |
| 97 | - "vocab_size": model.decoder.vocab_size, # not including the blank | 98 | + # not including the blank |
| 99 | + # we increase vocab_size in the C++ code | ||
| 100 | + "vocab_size": model.decoder.vocab_size, | ||
| 98 | "pred_rnn_layers": model.decoder.pred_rnn_layers, | 101 | "pred_rnn_layers": model.decoder.pred_rnn_layers, |
| 99 | "pred_hidden": model.decoder.pred_hidden, | 102 | "pred_hidden": model.decoder.pred_hidden, |
| 100 | "normalize_type": "", | 103 | "normalize_type": "", |
| @@ -5,11 +5,14 @@ set -ex | @@ -5,11 +5,14 @@ set -ex | ||
| 5 | function install_gigaam() { | 5 | function install_gigaam() { |
| 6 | curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py | 6 | curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py |
| 7 | python3 get-pip.py | 7 | python3 get-pip.py |
| 8 | + pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html | ||
| 9 | + pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa | ||
| 8 | 10 | ||
| 9 | BRANCH='main' | 11 | BRANCH='main' |
| 10 | python3 -m pip install git+https://github.com/salute-developers/GigaAM.git@$BRANCH#egg=gigaam | 12 | python3 -m pip install git+https://github.com/salute-developers/GigaAM.git@$BRANCH#egg=gigaam |
| 11 | 13 | ||
| 12 | python3 -m pip install -qq kaldi-native-fbank | 14 | python3 -m pip install -qq kaldi-native-fbank |
| 15 | + pip install numpy==1.26.4 | ||
| 13 | } | 16 | } |
| 14 | 17 | ||
| 15 | function download_files() { | 18 | function download_files() { |
| @@ -9,7 +9,7 @@ function install_nemo() { | @@ -9,7 +9,7 @@ function install_nemo() { | ||
| 9 | 9 | ||
| 10 | pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html | 10 | pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html |
| 11 | 11 | ||
| 12 | - pip install -qq wget text-unidecode matplotlib>=3.3.2 onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile librosa | 12 | + pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa |
| 13 | pip install -qq ipython | 13 | pip install -qq ipython |
| 14 | 14 | ||
| 15 | # sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython | 15 | # sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython |
scripts/nemo/GigaAM/run-rnnt-v2.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +set -ex | ||
| 5 | + | ||
| 6 | +function install_gigaam() { | ||
| 7 | + curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py | ||
| 8 | + python3 get-pip.py | ||
| 9 | + pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html | ||
| 10 | + pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa | ||
| 11 | + | ||
| 12 | + BRANCH='main' | ||
| 13 | + python3 -m pip install git+https://github.com/salute-developers/GigaAM.git@$BRANCH#egg=gigaam | ||
| 14 | + | ||
| 15 | + python3 -m pip install -qq kaldi-native-fbank | ||
| 16 | + pip install numpy==1.26.4 | ||
| 17 | +} | ||
| 18 | + | ||
| 19 | +function download_files() { | ||
| 20 | + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/example.wav | ||
| 21 | + curl -SL -O https://github.com/salute-developers/GigaAM/blob/main/LICENSE | ||
| 22 | +} | ||
| 23 | + | ||
| 24 | +install_gigaam | ||
| 25 | +download_files | ||
| 26 | + | ||
| 27 | +python3 ./export-onnx-rnnt-v2.py | ||
| 28 | +ls -lh | ||
| 29 | +python3 ./test-onnx-rnnt.py |
| @@ -9,7 +9,7 @@ function install_nemo() { | @@ -9,7 +9,7 @@ function install_nemo() { | ||
| 9 | 9 | ||
| 10 | pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html | 10 | pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html |
| 11 | 11 | ||
| 12 | - pip install -qq wget text-unidecode matplotlib>=3.3.2 onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile librosa | 12 | + pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa |
| 13 | pip install -qq ipython | 13 | pip install -qq ipython |
| 14 | 14 | ||
| 15 | # sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython | 15 | # sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython |
| @@ -19,7 +19,7 @@ def create_fbank(): | @@ -19,7 +19,7 @@ def create_fbank(): | ||
| 19 | opts.frame_opts.window_type = "hann" | 19 | opts.frame_opts.window_type = "hann" |
| 20 | 20 | ||
| 21 | # Even though GigaAM uses 400 for fft, here we use 512 | 21 | # Even though GigaAM uses 400 for fft, here we use 512 |
| 22 | - # since kaldi-native-fbank only support fft for power of 2. | 22 | + # since kaldi-native-fbank only supports fft for power of 2. |
| 23 | opts.frame_opts.round_to_power_of_two = True | 23 | opts.frame_opts.round_to_power_of_two = True |
| 24 | 24 | ||
| 25 | opts.mel_opts.low_freq = 0 | 25 | opts.mel_opts.low_freq = 0 |
| @@ -20,7 +20,7 @@ def create_fbank(): | @@ -20,7 +20,7 @@ def create_fbank(): | ||
| 20 | opts.frame_opts.window_type = "hann" | 20 | opts.frame_opts.window_type = "hann" |
| 21 | 21 | ||
| 22 | # Even though GigaAM uses 400 for fft, here we use 512 | 22 | # Even though GigaAM uses 400 for fft, here we use 512 |
| 23 | - # since kaldi-native-fbank only support fft for power of 2. | 23 | + # since kaldi-native-fbank only supports fft for power of 2. |
| 24 | opts.frame_opts.round_to_power_of_two = True | 24 | opts.frame_opts.round_to_power_of_two = True |
| 25 | 25 | ||
| 26 | opts.mel_opts.low_freq = 0 | 26 | opts.mel_opts.low_freq = 0 |
| @@ -166,12 +166,7 @@ class OnnxModel: | @@ -166,12 +166,7 @@ class OnnxModel: | ||
| 166 | target = torch.tensor([[token]], dtype=torch.int32).numpy() | 166 | target = torch.tensor([[token]], dtype=torch.int32).numpy() |
| 167 | target_len = torch.tensor([1], dtype=torch.int32).numpy() | 167 | target_len = torch.tensor([1], dtype=torch.int32).numpy() |
| 168 | 168 | ||
| 169 | - ( | ||
| 170 | - decoder_out, | ||
| 171 | - decoder_out_length, | ||
| 172 | - state0_next, | ||
| 173 | - state1_next, | ||
| 174 | - ) = self.decoder.run( | 169 | + (decoder_out, decoder_out_length, state0_next, state1_next,) = self.decoder.run( |
| 175 | [ | 170 | [ |
| 176 | self.decoder.get_outputs()[0].name, | 171 | self.decoder.get_outputs()[0].name, |
| 177 | self.decoder.get_outputs()[1].name, | 172 | self.decoder.get_outputs()[1].name, |
| @@ -213,8 +208,12 @@ def main(): | @@ -213,8 +208,12 @@ def main(): | ||
| 213 | id2token = dict() | 208 | id2token = dict() |
| 214 | with open("./tokens.txt", encoding="utf-8") as f: | 209 | with open("./tokens.txt", encoding="utf-8") as f: |
| 215 | for line in f: | 210 | for line in f: |
| 216 | - t, idx = line.split() | ||
| 217 | - id2token[int(idx)] = t | 211 | + fields = line.split() |
| 212 | + if len(fields) == 1: | ||
| 213 | + id2token[int(fields[0])] = " " | ||
| 214 | + else: | ||
| 215 | + t, idx = fields | ||
| 216 | + id2token[int(idx)] = t | ||
| 218 | 217 | ||
| 219 | fbank = create_fbank() | 218 | fbank = create_fbank() |
| 220 | audio, sample_rate = sf.read("./example.wav", dtype="float32", always_2d=True) | 219 | audio, sample_rate = sf.read("./example.wav", dtype="float32", always_2d=True) |
-
请 注册 或 登录 后发表评论