Committed by
GitHub
Remove the 30-second constraint from whisper. (#471)
正在显示 10 个修改的文件，包含 178 行增加和 78 行删除
| @@ -16,8 +16,12 @@ which $EXE | @@ -16,8 +16,12 @@ which $EXE | ||
| 16 | names=( | 16 | names=( |
| 17 | tiny.en | 17 | tiny.en |
| 18 | base.en | 18 | base.en |
| 19 | -# small.en | ||
| 20 | -# medium.en | 19 | +small.en |
| 20 | +medium.en | ||
| 21 | +tiny | ||
| 22 | +base | ||
| 23 | +small | ||
| 24 | +medium | ||
| 21 | ) | 25 | ) |
| 22 | 26 | ||
| 23 | for name in ${names[@]}; do | 27 | for name in ${names[@]}; do |
| @@ -33,8 +37,8 @@ for name in ${names[@]}; do | @@ -33,8 +37,8 @@ for name in ${names[@]}; do | ||
| 33 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | 37 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url |
| 34 | pushd $repo | 38 | pushd $repo |
| 35 | git lfs pull --include "*.onnx" | 39 | git lfs pull --include "*.onnx" |
| 36 | - git lfs pull --include "*.ort" | ||
| 37 | - ls -lh *.{onnx,ort} | 40 | + # git lfs pull --include "*.ort" |
| 41 | + ls -lh *.onnx | ||
| 38 | popd | 42 | popd |
| 39 | 43 | ||
| 40 | log "test fp32 onnx" | 44 | log "test fp32 onnx" |
| @@ -43,6 +47,7 @@ for name in ${names[@]}; do | @@ -43,6 +47,7 @@ for name in ${names[@]}; do | ||
| 43 | --tokens=$repo/${name}-tokens.txt \ | 47 | --tokens=$repo/${name}-tokens.txt \ |
| 44 | --whisper-encoder=$repo/${name}-encoder.onnx \ | 48 | --whisper-encoder=$repo/${name}-encoder.onnx \ |
| 45 | --whisper-decoder=$repo/${name}-decoder.onnx \ | 49 | --whisper-decoder=$repo/${name}-decoder.onnx \ |
| 50 | + --whisper-tail-paddings=500 \ | ||
| 46 | --num-threads=2 \ | 51 | --num-threads=2 \ |
| 47 | $repo/test_wavs/0.wav \ | 52 | $repo/test_wavs/0.wav \ |
| 48 | $repo/test_wavs/1.wav \ | 53 | $repo/test_wavs/1.wav \ |
| @@ -54,28 +59,7 @@ for name in ${names[@]}; do | @@ -54,28 +59,7 @@ for name in ${names[@]}; do | ||
| 54 | --tokens=$repo/${name}-tokens.txt \ | 59 | --tokens=$repo/${name}-tokens.txt \ |
| 55 | --whisper-encoder=$repo/${name}-encoder.int8.onnx \ | 60 | --whisper-encoder=$repo/${name}-encoder.int8.onnx \ |
| 56 | --whisper-decoder=$repo/${name}-decoder.int8.onnx \ | 61 | --whisper-decoder=$repo/${name}-decoder.int8.onnx \ |
| 57 | - --num-threads=2 \ | ||
| 58 | - $repo/test_wavs/0.wav \ | ||
| 59 | - $repo/test_wavs/1.wav \ | ||
| 60 | - $repo/test_wavs/8k.wav | ||
| 61 | - | ||
| 62 | - log "test fp32 ort" | ||
| 63 | - | ||
| 64 | - time $EXE \ | ||
| 65 | - --tokens=$repo/${name}-tokens.txt \ | ||
| 66 | - --whisper-encoder=$repo/${name}-encoder.ort \ | ||
| 67 | - --whisper-decoder=$repo/${name}-decoder.ort \ | ||
| 68 | - --num-threads=2 \ | ||
| 69 | - $repo/test_wavs/0.wav \ | ||
| 70 | - $repo/test_wavs/1.wav \ | ||
| 71 | - $repo/test_wavs/8k.wav | ||
| 72 | - | ||
| 73 | - log "test int8 ort" | ||
| 74 | - | ||
| 75 | - time $EXE \ | ||
| 76 | - --tokens=$repo/${name}-tokens.txt \ | ||
| 77 | - --whisper-encoder=$repo/${name}-encoder.int8.ort \ | ||
| 78 | - --whisper-decoder=$repo/${name}-decoder.int8.ort \ | 62 | + --whisper-tail-paddings=500 \ |
| 79 | --num-threads=2 \ | 63 | --num-threads=2 \ |
| 80 | $repo/test_wavs/0.wav \ | 64 | $repo/test_wavs/0.wav \ |
| 81 | $repo/test_wavs/1.wav \ | 65 | $repo/test_wavs/1.wav \ |
| @@ -15,7 +15,7 @@ jobs: | @@ -15,7 +15,7 @@ jobs: | ||
| 15 | strategy: | 15 | strategy: |
| 16 | fail-fast: false | 16 | fail-fast: false |
| 17 | matrix: | 17 | matrix: |
| 18 | - os: [macos-latest] | 18 | + os: [ubuntu-latest] |
| 19 | model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"] | 19 | model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"] |
| 20 | python-version: ["3.8"] | 20 | python-version: ["3.8"] |
| 21 | 21 | ||
| @@ -44,7 +44,7 @@ jobs: | @@ -44,7 +44,7 @@ jobs: | ||
| 44 | ls -lh | 44 | ls -lh |
| 45 | fi | 45 | fi |
| 46 | python3 ./export-onnx.py --model ${{ matrix.model }} | 46 | python3 ./export-onnx.py --model ${{ matrix.model }} |
| 47 | - python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./ | 47 | + # python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./ |
| 48 | 48 | ||
| 49 | ls -lh | 49 | ls -lh |
| 50 | 50 | ||
| @@ -52,41 +52,61 @@ jobs: | @@ -52,41 +52,61 @@ jobs: | ||
| 52 | ls -lh ~/.cache/whisper | 52 | ls -lh ~/.cache/whisper |
| 53 | fi | 53 | fi |
| 54 | 54 | ||
| 55 | + src=sherpa-onnx-whisper-${{ matrix.model }} | ||
| 56 | + | ||
| 57 | + mkdir $src | ||
| 58 | + cp *.onnx $src/ | ||
| 59 | + cp *tokens.txt $src | ||
| 60 | + | ||
| 61 | + cd $src | ||
| 62 | + mkdir -p test_wavs | ||
| 63 | + cd test_wavs | ||
| 64 | + wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav | ||
| 65 | + wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav | ||
| 66 | + wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav | ||
| 67 | + wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt | ||
| 68 | + cd ../.. | ||
| 69 | + mv $src ../.. | ||
| 70 | + | ||
| 71 | + cd ../.. | ||
| 72 | + echo "--------------------" | ||
| 73 | + ls -lh | ||
| 74 | + ls -lh $src | ||
| 75 | + echo "--------------------" | ||
| 76 | + | ||
| 77 | + tar cjvf ./$src.tar.bz2 $src | ||
| 78 | + | ||
| 79 | + - name: Release | ||
| 80 | + uses: svenstaro/upload-release-action@v2 | ||
| 81 | + with: | ||
| 82 | + file_glob: true | ||
| 83 | + file: ./*.tar.bz2 | ||
| 84 | + overwrite: true | ||
| 85 | + repo_name: k2-fsa/sherpa-onnx | ||
| 86 | + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} | ||
| 87 | + tag: asr-models | ||
| 88 | + | ||
| 55 | - name: Publish ${{ matrix.model }} to huggingface | 89 | - name: Publish ${{ matrix.model }} to huggingface |
| 56 | shell: bash | 90 | shell: bash |
| 57 | env: | 91 | env: |
| 58 | HF_TOKEN: ${{ secrets.HF_TOKEN }} | 92 | HF_TOKEN: ${{ secrets.HF_TOKEN }} |
| 59 | run: | | 93 | run: | |
| 60 | - model=${{ matrix.model }} | ||
| 61 | - | ||
| 62 | - cd scripts/whisper | 94 | + src=sherpa-onnx-whisper-${{ matrix.model }} |
| 63 | 95 | ||
| 64 | git config --global user.email "csukuangfj@gmail.com" | 96 | git config --global user.email "csukuangfj@gmail.com" |
| 65 | git config --global user.name "Fangjun Kuang" | 97 | git config --global user.name "Fangjun Kuang" |
| 66 | 98 | ||
| 67 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface | 99 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface |
| 100 | + rm -rf huggingface/* | ||
| 68 | 101 | ||
| 69 | - cp *.onnx ./huggingface | ||
| 70 | - cp *.ort ./huggingface | ||
| 71 | - cp *tokens.txt ./huggingface | 102 | + cp -av $src/* ./huggingface/ |
| 72 | 103 | ||
| 73 | cd huggingface | 104 | cd huggingface |
| 74 | 105 | ||
| 75 | - if [[ $model == distil-medium.en ]]; then | ||
| 76 | - mkdir test_wavs | ||
| 77 | - cd test_wavs | ||
| 78 | - wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav | ||
| 79 | - wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav | ||
| 80 | - wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav | ||
| 81 | - wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt | ||
| 82 | - git add . | ||
| 83 | - cd .. | ||
| 84 | - fi | ||
| 85 | - | ||
| 86 | git status | 106 | git status |
| 87 | ls -lh | 107 | ls -lh |
| 88 | git lfs track "*.onnx" | 108 | git lfs track "*.onnx" |
| 89 | - git lfs track "*.ort" | 109 | + # git lfs track "*.ort" |
| 90 | git add . | 110 | git add . |
| 91 | git commit -m "upload ${{ matrix.model }}" | 111 | git commit -m "upload ${{ matrix.model }}" |
| 92 | git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main | 112 | git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main |
| @@ -107,6 +107,16 @@ jobs: | @@ -107,6 +107,16 @@ jobs: | ||
| 107 | name: release-static | 107 | name: release-static |
| 108 | path: build/bin/* | 108 | path: build/bin/* |
| 109 | 109 | ||
| 110 | + - name: Test offline Whisper | ||
| 111 | + shell: bash | ||
| 112 | + run: | | ||
| 113 | + export PATH=$PWD/build/bin:$PATH | ||
| 114 | + export EXE=sherpa-onnx-offline | ||
| 115 | + | ||
| 116 | + readelf -d build/bin/sherpa-onnx-offline | ||
| 117 | + | ||
| 118 | + .github/scripts/test-offline-whisper.sh | ||
| 119 | + | ||
| 110 | - name: Test online CTC | 120 | - name: Test online CTC |
| 111 | shell: bash | 121 | shell: bash |
| 112 | run: | | 122 | run: | |
| @@ -139,16 +149,6 @@ jobs: | @@ -139,16 +149,6 @@ jobs: | ||
| 139 | 149 | ||
| 140 | .github/scripts/test-online-paraformer.sh | 150 | .github/scripts/test-online-paraformer.sh |
| 141 | 151 | ||
| 142 | - - name: Test offline Whisper | ||
| 143 | - shell: bash | ||
| 144 | - run: | | ||
| 145 | - export PATH=$PWD/build/bin:$PATH | ||
| 146 | - export EXE=sherpa-onnx-offline | ||
| 147 | - | ||
| 148 | - readelf -d build/bin/sherpa-onnx-offline | ||
| 149 | - | ||
| 150 | - .github/scripts/test-offline-whisper.sh | ||
| 151 | - | ||
| 152 | - name: Test offline transducer | 152 | - name: Test offline transducer |
| 153 | shell: bash | 153 | shell: bash |
| 154 | run: | | 154 | run: | |
| @@ -93,13 +93,13 @@ jobs: | @@ -93,13 +93,13 @@ jobs: | ||
| 93 | 93 | ||
| 94 | .github/scripts/test-online-paraformer.sh | 94 | .github/scripts/test-online-paraformer.sh |
| 95 | 95 | ||
| 96 | - - name: Test offline Whisper for windows x86 | ||
| 97 | - shell: bash | ||
| 98 | - run: | | ||
| 99 | - export PATH=$PWD/build/bin/Release:$PATH | ||
| 100 | - export EXE=sherpa-onnx-offline.exe | ||
| 101 | - | ||
| 102 | - .github/scripts/test-offline-whisper.sh | 96 | + # - name: Test offline Whisper for windows x86 |
| 97 | + # shell: bash | ||
| 98 | + # run: | | ||
| 99 | + # export PATH=$PWD/build/bin/Release:$PATH | ||
| 100 | + # export EXE=sherpa-onnx-offline.exe | ||
| 101 | + # | ||
| 102 | + # .github/scripts/test-offline-whisper.sh | ||
| 103 | 103 | ||
| 104 | - name: Test offline CTC for windows x86 | 104 | - name: Test offline CTC for windows x86 |
| 105 | shell: bash | 105 | shell: bash |
| @@ -8,6 +8,9 @@ https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py | @@ -8,6 +8,9 @@ https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py | ||
| 8 | 8 | ||
| 9 | Thanks to https://github.com/TadaoYamaoka | 9 | Thanks to https://github.com/TadaoYamaoka |
| 10 | for making the onnx export script public. | 10 | for making the onnx export script public. |
| 11 | + | ||
| 12 | +Note that we have removed the 30 seconds constraint from whisper. You can | ||
| 13 | +use any T <= 30. | ||
| 11 | """ | 14 | """ |
| 12 | 15 | ||
| 13 | import argparse | 16 | import argparse |
| @@ -17,6 +20,7 @@ from typing import Any, Dict, Optional | @@ -17,6 +20,7 @@ from typing import Any, Dict, Optional | ||
| 17 | 20 | ||
| 18 | import onnx | 21 | import onnx |
| 19 | import torch | 22 | import torch |
| 23 | +import torch.nn.functional as F | ||
| 20 | from onnxruntime.quantization import QuantType, quantize_dynamic | 24 | from onnxruntime.quantization import QuantType, quantize_dynamic |
| 21 | from torch import Tensor, nn | 25 | from torch import Tensor, nn |
| 22 | 26 | ||
| @@ -65,6 +69,39 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]): | @@ -65,6 +69,39 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]): | ||
| 65 | onnx.save(model, filename) | 69 | onnx.save(model, filename) |
| 66 | 70 | ||
| 67 | 71 | ||
| 72 | +def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor): | ||
| 73 | + """ | ||
| 74 | + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) | ||
| 75 | + the mel spectrogram of the audio | ||
| 76 | + """ | ||
| 77 | + x = F.gelu(self.conv1(x)) | ||
| 78 | + x = F.gelu(self.conv2(x)) | ||
| 79 | + x = x.permute(0, 2, 1) | ||
| 80 | + | ||
| 81 | + if False: | ||
| 82 | + # This branch contains the original code | ||
| 83 | + assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" | ||
| 84 | + x = (x + self.positional_embedding).to(x.dtype) | ||
| 85 | + else: | ||
| 86 | + # This branch contains the actual changes | ||
| 87 | + assert ( | ||
| 88 | + x.shape[2] == self.positional_embedding.shape[1] | ||
| 89 | + ), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}" | ||
| 90 | + assert ( | ||
| 91 | + x.shape[1] == self.positional_embedding.shape[0] | ||
| 92 | + ), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}" | ||
| 93 | + x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype) | ||
| 94 | + | ||
| 95 | + for block in self.blocks: | ||
| 96 | + x = block(x) | ||
| 97 | + | ||
| 98 | + x = self.ln_post(x) | ||
| 99 | + return x | ||
| 100 | + | ||
| 101 | + | ||
| 102 | +AudioEncoder.forward = modified_audio_encoder_forward | ||
| 103 | + | ||
| 104 | + | ||
| 68 | class AudioEncoderTensorCache(nn.Module): | 105 | class AudioEncoderTensorCache(nn.Module): |
| 69 | def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder): | 106 | def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder): |
| 70 | super().__init__() | 107 | super().__init__() |
| @@ -279,6 +316,7 @@ def main(): | @@ -279,6 +316,7 @@ def main(): | ||
| 279 | model = whisper.load_model(filename) | 316 | model = whisper.load_model(filename) |
| 280 | else: | 317 | else: |
| 281 | model = whisper.load_model(name) | 318 | model = whisper.load_model(name) |
| 319 | + print(model.dims) | ||
| 282 | 320 | ||
| 283 | print( | 321 | print( |
| 284 | f"number of model parameters: {name}", | 322 | f"number of model parameters: {name}", |
| @@ -311,19 +349,20 @@ def main(): | @@ -311,19 +349,20 @@ def main(): | ||
| 311 | assert mel.shape == (batch_size, 80, 30 * 100) | 349 | assert mel.shape == (batch_size, 80, 30 * 100) |
| 312 | 350 | ||
| 313 | encoder = AudioEncoderTensorCache(model.encoder, model.decoder) | 351 | encoder = AudioEncoderTensorCache(model.encoder, model.decoder) |
| 352 | + | ||
| 314 | n_layer_cross_k, n_layer_cross_v = encoder(mel) | 353 | n_layer_cross_k, n_layer_cross_v = encoder(mel) |
| 315 | assert n_layer_cross_k.shape == ( | 354 | assert n_layer_cross_k.shape == ( |
| 316 | model.dims.n_text_layer, | 355 | model.dims.n_text_layer, |
| 317 | batch_size, | 356 | batch_size, |
| 318 | model.dims.n_audio_ctx, | 357 | model.dims.n_audio_ctx, |
| 319 | model.dims.n_text_state, | 358 | model.dims.n_text_state, |
| 320 | - ), n_layer_cross_k.shape | 359 | + ), (n_layer_cross_k.shape, model.dims) |
| 321 | assert n_layer_cross_v.shape == ( | 360 | assert n_layer_cross_v.shape == ( |
| 322 | model.dims.n_text_layer, | 361 | model.dims.n_text_layer, |
| 323 | batch_size, | 362 | batch_size, |
| 324 | model.dims.n_audio_ctx, | 363 | model.dims.n_audio_ctx, |
| 325 | model.dims.n_text_state, | 364 | model.dims.n_text_state, |
| 326 | - ), n_layer_cross_v.shape | 365 | + ), (n_layer_cross_v.shape, model.dims) |
| 327 | 366 | ||
| 328 | encoder_filename = f"{name}-encoder.onnx" | 367 | encoder_filename = f"{name}-encoder.onnx" |
| 329 | torch.onnx.export( | 368 | torch.onnx.export( |
| @@ -334,9 +373,9 @@ def main(): | @@ -334,9 +373,9 @@ def main(): | ||
| 334 | input_names=["mel"], | 373 | input_names=["mel"], |
| 335 | output_names=["n_layer_cross_k", "n_layer_cross_v"], | 374 | output_names=["n_layer_cross_k", "n_layer_cross_v"], |
| 336 | dynamic_axes={ | 375 | dynamic_axes={ |
| 337 | - "mel": {0: "n_audio"}, # n_audio is also known as batch_size | ||
| 338 | - "n_layer_cross_k": {1: "n_audio"}, | ||
| 339 | - "n_layer_cross_v": {1: "n_audio"}, | 376 | + "mel": {0: "n_audio", 2: "T"}, # n_audio is also known as batch_size |
| 377 | + "n_layer_cross_k": {1: "n_audio", 2: "T"}, | ||
| 378 | + "n_layer_cross_v": {1: "n_audio", 2: "T"}, | ||
| 340 | }, | 379 | }, |
| 341 | ) | 380 | ) |
| 342 | 381 | ||
| @@ -461,8 +500,8 @@ def main(): | @@ -461,8 +500,8 @@ def main(): | ||
| 461 | "tokens": {0: "n_audio", 1: "n_tokens"}, | 500 | "tokens": {0: "n_audio", 1: "n_tokens"}, |
| 462 | "in_n_layer_self_k_cache": {1: "n_audio"}, | 501 | "in_n_layer_self_k_cache": {1: "n_audio"}, |
| 463 | "in_n_layer_self_v_cache": {1: "n_audio"}, | 502 | "in_n_layer_self_v_cache": {1: "n_audio"}, |
| 464 | - "n_layer_cross_k": {1: "n_audio"}, | ||
| 465 | - "n_layer_cross_v": {1: "n_audio"}, | 503 | + "n_layer_cross_k": {1: "n_audio", 2: "T"}, |
| 504 | + "n_layer_cross_v": {1: "n_audio", 2: "T"}, | ||
| 466 | }, | 505 | }, |
| 467 | ) | 506 | ) |
| 468 | 507 |
| @@ -253,8 +253,21 @@ def compute_features(filename: str) -> torch.Tensor: | @@ -253,8 +253,21 @@ def compute_features(filename: str) -> torch.Tensor: | ||
| 253 | log_spec = torch.clamp(features, min=1e-10).log10() | 253 | log_spec = torch.clamp(features, min=1e-10).log10() |
| 254 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) | 254 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) |
| 255 | mel = (log_spec + 4.0) / 4.0 | 255 | mel = (log_spec + 4.0) / 4.0 |
| 256 | + # mel (T, 80) | ||
| 257 | + | ||
| 258 | + # We pad 50 frames at the end so that it is able to detect eot | ||
| 259 | + # You can use another value instead of 50. | ||
| 260 | + mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0) | ||
| 261 | + # Note that if it throws for a multilingual model, | ||
| 262 | + # please use a larger value, say 300 | ||
| 263 | + | ||
| 256 | target = 3000 | 264 | target = 3000 |
| 257 | - mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0) | 265 | + if mel.shape[0] > target: |
| 266 | + mel = mel[:target] | ||
| 267 | + | ||
| 268 | + # We don't need to pad it to 30 seconds now! | ||
| 269 | + # mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0) | ||
| 270 | + | ||
| 258 | mel = mel.t().unsqueeze(0) | 271 | mel = mel.t().unsqueeze(0) |
| 259 | 272 | ||
| 260 | return mel | 273 | return mel |
| @@ -115,7 +115,27 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | @@ -115,7 +115,27 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | ||
| 115 | 115 | ||
| 116 | NormalizeFeatures(f.data(), num_frames, feat_dim); | 116 | NormalizeFeatures(f.data(), num_frames, feat_dim); |
| 117 | 117 | ||
| 118 | - std::array<int64_t, 3> shape{1, max_num_frames, feat_dim}; | 118 | + // note that 50 is an experience value. |
| 119 | + // see also ../../scripts/whisper/test.py | ||
| 120 | + // | ||
| 121 | + // You can replace 50 by other values, say, 100. | ||
| 122 | + // | ||
| 123 | + // Since we have removed the 30 seconds constraint, we need | ||
| 124 | + // tail_padding_frames so that whisper is able to detect the eot token. | ||
| 125 | + int32_t tail_padding_frames = 50; | ||
| 126 | + if (model_->IsMultiLingual()) { | ||
| 127 | + // 300 is an experience value. If it throws, please use a larger value. | ||
| 128 | + tail_padding_frames = 300; | ||
| 129 | + } | ||
| 130 | + | ||
| 131 | + if (config_.model_config.whisper.tail_paddings > 0) { | ||
| 132 | + tail_padding_frames = config_.model_config.whisper.tail_paddings; | ||
| 133 | + } | ||
| 134 | + | ||
| 135 | + int32_t actual_frames = | ||
| 136 | + std::min(num_frames + tail_padding_frames, max_num_frames); | ||
| 137 | + | ||
| 138 | + std::array<int64_t, 3> shape{1, actual_frames, feat_dim}; | ||
| 119 | 139 | ||
| 120 | Ort::Value mel = Ort::Value::CreateTensor<float>( | 140 | Ort::Value mel = Ort::Value::CreateTensor<float>( |
| 121 | model_->Allocator(), shape.data(), shape.size()); | 141 | model_->Allocator(), shape.data(), shape.size()); |
| @@ -123,7 +143,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | @@ -123,7 +143,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | ||
| 123 | std::copy(f.begin(), f.end(), p_mel); | 143 | std::copy(f.begin(), f.end(), p_mel); |
| 124 | 144 | ||
| 125 | memset(p_mel + f.size(), 0, | 145 | memset(p_mel + f.size(), 0, |
| 126 | - (max_num_frames - num_frames) * feat_dim * sizeof(float)); | 146 | + (actual_frames - num_frames) * feat_dim * sizeof(float)); |
| 127 | mel = Transpose12(model_->Allocator(), &mel); | 147 | mel = Transpose12(model_->Allocator(), &mel); |
| 128 | 148 | ||
| 129 | try { | 149 | try { |
| @@ -32,6 +32,14 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) { | @@ -32,6 +32,14 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) { | ||
| 32 | "Valid values: transcribe, translate. " | 32 | "Valid values: transcribe, translate. " |
| 33 | "Note that for non-multilingual models, it supports " | 33 | "Note that for non-multilingual models, it supports " |
| 34 | "only 'transcribe'"); | 34 | "only 'transcribe'"); |
| 35 | + | ||
| 36 | + po->Register( | ||
| 37 | + "whisper-tail-paddings", &tail_paddings, | ||
| 38 | + "Suggest value: 50 for English models. 300 for multilingual models. " | ||
| 39 | + "Since we have removed the 30-second constraint, we need to add some " | ||
| 40 | + "tail padding frames " | ||
| 41 | + "so that whisper can detect the eot token. Leave it to -1 to use 50 for " | ||
| 42 | + "English models and 300 for multilingual models."); | ||
| 35 | } | 43 | } |
| 36 | 44 | ||
| 37 | bool OfflineWhisperModelConfig::Validate() const { | 45 | bool OfflineWhisperModelConfig::Validate() const { |
| @@ -63,7 +71,8 @@ std::string OfflineWhisperModelConfig::ToString() const { | @@ -63,7 +71,8 @@ std::string OfflineWhisperModelConfig::ToString() const { | ||
| 63 | os << "encoder=\"" << encoder << "\", "; | 71 | os << "encoder=\"" << encoder << "\", "; |
| 64 | os << "decoder=\"" << decoder << "\", "; | 72 | os << "decoder=\"" << decoder << "\", "; |
| 65 | os << "language=\"" << language << "\", "; | 73 | os << "language=\"" << language << "\", "; |
| 66 | - os << "task=\"" << task << "\")"; | 74 | + os << "task=\"" << task << "\", "; |
| 75 | + os << "tail_paddings=" << tail_paddings << ")"; | ||
| 67 | 76 | ||
| 68 | return os.str(); | 77 | return os.str(); |
| 69 | } | 78 | } |
| @@ -28,12 +28,26 @@ struct OfflineWhisperModelConfig { | @@ -28,12 +28,26 @@ struct OfflineWhisperModelConfig { | ||
| 28 | // Note: For non-multilingual models, it supports only "transcribe" | 28 | // Note: For non-multilingual models, it supports only "transcribe" |
| 29 | std::string task = "transcribe"; | 29 | std::string task = "transcribe"; |
| 30 | 30 | ||
| 31 | + // Number of tail padding frames. | ||
| 32 | + // | ||
| 33 | + // Since we remove the 30-second constraint, we need to add some paddings | ||
| 34 | + // at the end. | ||
| 35 | + // | ||
| 36 | + // Recommended values: | ||
| 37 | + // - 50 for English models | ||
| 38 | + // - 300 for multilingual models | ||
| 39 | + int32_t tail_paddings = -1; | ||
| 40 | + | ||
| 31 | OfflineWhisperModelConfig() = default; | 41 | OfflineWhisperModelConfig() = default; |
| 32 | OfflineWhisperModelConfig(const std::string &encoder, | 42 | OfflineWhisperModelConfig(const std::string &encoder, |
| 33 | const std::string &decoder, | 43 | const std::string &decoder, |
| 34 | const std::string &language, | 44 | const std::string &language, |
| 35 | - const std::string &task) | ||
| 36 | - : encoder(encoder), decoder(decoder), language(language), task(task) {} | 45 | + const std::string &task, int32_t tail_paddings) |
| 46 | + : encoder(encoder), | ||
| 47 | + decoder(decoder), | ||
| 48 | + language(language), | ||
| 49 | + task(task), | ||
| 50 | + tail_paddings(tail_paddings) {} | ||
| 37 | 51 | ||
| 38 | void Register(ParseOptions *po); | 52 | void Register(ParseOptions *po); |
| 39 | bool Validate() const; | 53 | bool Validate() const; |
| @@ -15,13 +15,14 @@ void PybindOfflineWhisperModelConfig(py::module *m) { | @@ -15,13 +15,14 @@ void PybindOfflineWhisperModelConfig(py::module *m) { | ||
| 15 | using PyClass = OfflineWhisperModelConfig; | 15 | using PyClass = OfflineWhisperModelConfig; |
| 16 | py::class_<PyClass>(*m, "OfflineWhisperModelConfig") | 16 | py::class_<PyClass>(*m, "OfflineWhisperModelConfig") |
| 17 | .def(py::init<const std::string &, const std::string &, | 17 | .def(py::init<const std::string &, const std::string &, |
| 18 | - const std::string &, const std::string &>(), | 18 | + const std::string &, const std::string &, int32_t>(), |
| 19 | py::arg("encoder"), py::arg("decoder"), py::arg("language"), | 19 | py::arg("encoder"), py::arg("decoder"), py::arg("language"), |
| 20 | - py::arg("task")) | 20 | + py::arg("task"), py::arg("tail_paddings") = -1) |
| 21 | .def_readwrite("encoder", &PyClass::encoder) | 21 | .def_readwrite("encoder", &PyClass::encoder) |
| 22 | .def_readwrite("decoder", &PyClass::decoder) | 22 | .def_readwrite("decoder", &PyClass::decoder) |
| 23 | .def_readwrite("language", &PyClass::language) | 23 | .def_readwrite("language", &PyClass::language) |
| 24 | .def_readwrite("task", &PyClass::task) | 24 | .def_readwrite("task", &PyClass::task) |
| 25 | + .def_readwrite("tail_paddings", &PyClass::tail_paddings) | ||
| 25 | .def("__str__", &PyClass::ToString); | 26 | .def("__str__", &PyClass::ToString); |
| 26 | } | 27 | } |
| 27 | 28 |
-
请 注册 或 登录 后发表评论