Committed by GitHub

Support whisper large/large-v1/large-v2/large-v3 and distil-large-v2 (#1114)

Showing 23 changed files with 152 additions and 85 deletions.
GitHub Actions workflow (Whisper export):

@@ -15,9 +15,9 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest]
-        # model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "distil-large-v2"]
-        model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell"]
+        os: [macos-latest]
+        model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell", "large", "large-v1", "large-v2", "distil-large-v2"]
+        # model: ["large", "large-v1", "large-v2", "large-v3", "distil-large-v2"]
         python-version: ["3.8"]

     steps:
@@ -32,7 +32,7 @@ jobs:
         shell: bash
         run: |
           python3 -m pip install torch==1.13.0 torchaudio==0.13.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
-          python3 -m pip install openai-whisper==20230314 onnxruntime onnx
+          python3 -m pip install openai-whisper==20231117 onnxruntime onnx soundfile librosa

       - name: export ${{ matrix.model }}
         shell: bash
@@ -62,7 +62,6 @@ jobs:
             rm -fv medium-aishell-decoder.onnx
           fi

-
           ls -lh

           ls -lh ~/.cache/whisper || true
@@ -74,7 +73,8 @@ jobs:
           src=sherpa-onnx-whisper-${{ matrix.model }}

           cd ..
-          mv whisper $src
+          mkdir $src
+          mv -v whisper/$model* $src/

           echo "------------------------------"
@@ -97,19 +97,16 @@ jobs:
           ls -lh $src
           echo "--------------------"

-          if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
-            #tar cvjf - $src | split --bytes=1024MB - $src.tar.bz2.
-            tar cvjf $src.tar.bz2 $src
-            split -b 1G $src.tar.bz2 $src.tar.bz2.
-            rm $src.tar.bz2
-            # cat $src.tar.gz.* | tar xjf -
+          if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
+            echo "Don't release model to github for large models. $model"
           else
             tar cvjf $src.tar.bz2 $src
           fi
-          ls -lh

+          ls -lh

       - name: Release
+        if: matrix.model != 'large' && matrix.model != 'large-v1' && matrix.model != 'large-v2' && matrix.model != 'large-v3' && matrix.model != 'distil-large-v2'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true
@@ -119,19 +116,6 @@ jobs:
           repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
           tag: asr-models

-      - name: Test ${{ matrix.model }}
-        shell: bash
-        run: |
-          python3 -m pip install kaldi-native-fbank
-          git checkout .
-          model=${{ matrix.model }}
-          src=sherpa-onnx-whisper-$model
-          python3 scripts/whisper/test.py \
-            --encoder $src/$model-encoder.int8.onnx \
-            --decoder $src/$model-decoder.int8.onnx \
-            --tokens $src/$model-tokens.txt \
-            $src/test_wavs/0.wav
-
       - name: Publish ${{ matrix.model }} to huggingface
         shell: bash
         env:
@@ -144,27 +128,36 @@ jobs:

           export GIT_CLONE_PROTECTION_ACTIVE=false

-          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
+          export GIT_LFS_SKIP_SMUDGE=1
+
+          git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface

           if [[ $model != medium-aishell ]]; then
             rm -rf huggingface/*
           fi

-          if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
-            mv $src.tar* ./huggingface
-          else
-            cp -v $src/*.onnx ./huggingface
-            cp -v $src/*tokens* ./huggingface
-            cp -av $src/test_wavs ./huggingface
-          fi
+          cp -av $src/* ./huggingface/

           cd huggingface

           git status
           ls -lh
-          git lfs track "*gz*"
           git lfs track "*onnx*"
+          git lfs track "*weights*"

           git add .
           git commit -m "upload ${{ matrix.model }}"
           git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
+
+      - name: Test ${{ matrix.model }}
+        shell: bash
+        run: |
+          python3 -m pip install kaldi-native-fbank
+          git checkout .
+          model=${{ matrix.model }}
+          src=sherpa-onnx-whisper-$model
+          time python3 scripts/whisper/test.py \
+            --encoder $src/$model-encoder.onnx \
+            --decoder $src/$model-decoder.onnx \
+            --tokens $src/$model-tokens.txt \
+            $src/test_wavs/0.wav
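For a quick standalone sanity check of an exported pair before packaging and uploading, confirming that onnxruntime can load each file is usually enough. A minimal sketch, with paths assumed for illustration; note that for the large models the "<model>-encoder.onnx" file references an external "<model>-encoder.weights" file that must stay in the same directory:

# Minimal sketch (paths assumed): confirm that onnxruntime can load an
# exported encoder/decoder pair and print the expected input shapes.
import onnxruntime as ort

def check_model(path: str) -> None:
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    for i in sess.get_inputs():
        print(path, i.name, i.shape, i.type)

check_model("sherpa-onnx-whisper-tiny/tiny-encoder.onnx")
check_model("sherpa-onnx-whisper-tiny/tiny-decoder.onnx")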
CMakeLists.txt:

@@ -11,7 +11,7 @@ project(sherpa-onnx)
 # ./nodejs-addon-examples
 # ./dart-api-examples/
 # ./CHANGELOG.md
-set(SHERPA_ONNX_VERSION "1.10.13")
+set(SHERPA_ONNX_VERSION "1.10.14")

 # Disable warning about
 #
CMake script that downloads kaldi-native-fbank:

 function(download_kaldi_native_fbank)
   include(FetchContent)

-  set(kaldi_native_fbank_URL  "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.3.tar.gz")
-  set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.3.tar.gz")
-  set(kaldi_native_fbank_HASH "SHA256=335fe1daf1b9bfb2a7b6bf03b64c4c4686c39077c57fb8058c02611981676638")
+  set(kaldi_native_fbank_URL  "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz")
+  set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz")
+  set(kaldi_native_fbank_HASH "SHA256=c6195b3cf374eef824644061d3c04f6b2a9267ae554169cbaa9865c89c1fe4f9")

   set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
   set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
   # If you don't have access to the Internet,
   # please pre-download kaldi-native-fbank
   set(possible_file_locations
-    $ENV{HOME}/Downloads/kaldi-native-fbank-1.19.3.tar.gz
-    ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.19.3.tar.gz
-    ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.19.3.tar.gz
-    /tmp/kaldi-native-fbank-1.19.3.tar.gz
-    /star-fj/fangjun/download/github/kaldi-native-fbank-1.19.3.tar.gz
+    $ENV{HOME}/Downloads/kaldi-native-fbank-1.20.0.tar.gz
+    ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.20.0.tar.gz
+    ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.20.0.tar.gz
+    /tmp/kaldi-native-fbank-1.20.0.tar.gz
+    /star-fj/fangjun/download/github/kaldi-native-fbank-1.20.0.tar.gz
   )

   foreach(f IN LISTS possible_file_locations)
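The pinned SHA256 above must match the tarball that FetchContent fetches (or a pre-downloaded copy placed at one of the listed locations). A small sketch for checking a local download, file name assumed:

# Sketch: verify a pre-downloaded tarball against the hash pinned above.
import hashlib

def sha256sum(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

# Expected: c6195b3cf374eef824644061d3c04f6b2a9267ae554169cbaa9865c89c1fe4f9
print(sha256sum("kaldi-native-fbank-1.20.0.tar.gz"))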
pubspec.yaml (Flutter example app):

@@ -5,7 +5,7 @@ description: >

 publish_to: 'none'

-version: 1.10.13
+version: 1.10.14

 topics:
   - speech-recognition
@@ -30,7 +30,7 @@ dependencies:
   record: ^5.1.0
   url_launcher: ^6.2.6

-  sherpa_onnx: ^1.10.13
+  sherpa_onnx: ^1.10.14
   # sherpa_onnx:
   #   path: ../../flutter/sherpa_onnx

pubspec.yaml (another Flutter example):

@@ -17,7 +17,7 @@ dependencies:
   cupertino_icons: ^1.0.6
   path_provider: ^2.1.3
   path: ^1.9.0
-  sherpa_onnx: ^1.10.13
+  sherpa_onnx: ^1.10.14
   url_launcher: ^6.2.6
   audioplayers: ^5.0.0

pubspec.yaml (sherpa_onnx Flutter package):

@@ -17,7 +17,7 @@ topics:
   - voice-activity-detection

 # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec
-version: 1.10.13
+version: 1.10.14

 homepage: https://github.com/k2-fsa/sherpa-onnx

@@ -30,19 +30,19 @@ dependencies:
   flutter:
     sdk: flutter

-  sherpa_onnx_android: ^1.10.13
+  sherpa_onnx_android: ^1.10.14
   #  path: ../sherpa_onnx_android

-  sherpa_onnx_macos: ^1.10.13
+  sherpa_onnx_macos: ^1.10.14
   #  path: ../sherpa_onnx_macos

-  sherpa_onnx_linux: ^1.10.13
+  sherpa_onnx_linux: ^1.10.14
   #  path: ../sherpa_onnx_linux
   #
-  sherpa_onnx_windows: ^1.10.13
+  sherpa_onnx_windows: ^1.10.14
   #  path: ../sherpa_onnx_windows

-  sherpa_onnx_ios: ^1.10.13
+  sherpa_onnx_ios: ^1.10.14
   # sherpa_onnx_ios:
   #   path: ../sherpa_onnx_ios

sherpa_onnx_ios.podspec:

@@ -7,7 +7,7 @@
 # https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c
 Pod::Spec.new do |s|
   s.name             = 'sherpa_onnx_ios'
-  s.version          = '1.10.13'
+  s.version          = '1.10.14'
   s.summary          = 'A new Flutter FFI plugin project.'
   s.description      = <<-DESC
 A new Flutter FFI plugin project.
sherpa_onnx_macos.podspec:

@@ -4,7 +4,7 @@
 #
 Pod::Spec.new do |s|
   s.name             = 'sherpa_onnx_macos'
-  s.version          = '1.10.13'
+  s.version          = '1.10.14'
   s.summary          = 'sherpa-onnx Flutter FFI plugin project.'
   s.description      = <<-DESC
 sherpa-onnx Flutter FFI plugin project.
pubspec.yaml (another sherpa_onnx Flutter package):

@@ -17,7 +17,7 @@ topics:
   - voice-activity-detection

 # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec
-version: 1.10.13
+version: 1.10.14

 homepage: https://github.com/k2-fsa/sherpa-onnx

Whisper export script (scripts/whisper):

@@ -32,6 +32,9 @@ from whisper.model import (
     TextDecoder,
 )

+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+

 def get_args():
     parser = argparse.ArgumentParser()
@@ -43,8 +46,9 @@ def get_args():
         choices=[
             "tiny", "tiny.en", "base", "base.en",
             "small", "small.en", "medium", "medium.en",
-            "large", "large-v1", "large-v2",
+            "large", "large-v1", "large-v2", "large-v3",
             "distil-medium.en", "distil-small.en", "distil-large-v2",
+            # "distil-large-v3",  # distil-large-v3 is not supported!
             # for fine-tuned models from icefall
             "medium-aishell",
         ],
@@ -63,12 +67,26 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
         Key-value pairs.
     """
     model = onnx.load(filename)
+
+    while len(model.metadata_props):
+        model.metadata_props.pop()
+
     for key, value in meta_data.items():
         meta = model.metadata_props.add()
         meta.key = key
         meta.value = str(value)

-    onnx.save(model, filename)
+    if "large" in filename:
+        external_filename = filename.split(".onnx")[0]
+        onnx.save(
+            model,
+            filename,
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            location=external_filename + ".weights",
+        )
+    else:
+        onnx.save(model, filename)


 def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
@@ -376,7 +394,9 @@ def main():

     # write tokens

-    tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
+    tokenizer = whisper.tokenizer.get_tokenizer(
+        model.is_multilingual, num_languages=model.num_languages
+    )

     model.eval()
     print(model.dims)
@@ -384,10 +404,15 @@ def main():
     audio = whisper.pad_or_trim(audio)
     assert audio.shape == (16000 * 30,), audio.shape

-    # make log-Mel spectrogram and move to the same device as the model
-    mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
+    if args.model in ("large", "large-v3"):
+        n_mels = 128
+    else:
+        n_mels = 80
+    mel = (
+        whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device).unsqueeze(0)
+    )
     batch_size = 1
-    assert mel.shape == (batch_size, 80, 30 * 100)
+    assert mel.shape == (batch_size, n_mels, 30 * 100), mel.shape

     encoder = AudioEncoderTensorCache(model.encoder, model.decoder)

@@ -547,6 +572,17 @@ def main():
     )

     if "large" in args.model:
+        decoder_external_filename = decoder_filename.split(".onnx")[0]
+        decoder_model = onnx.load(decoder_filename)
+        onnx.save(
+            decoder_model,
+            decoder_filename,
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            location=decoder_external_filename + ".weights",
+        )
+
+    if "large" in args.model:
         # it causes errors for large models, so skip it.
         return
     # Generate int8 quantization models
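The export script saves the large models with save_as_external_data=True because a serialized ONNX model is a single protobuf message, which is capped at 2 GB; the tensor weights therefore go into a side "<name>.weights" file next to the .onnx file (this is also why the workflow now runs git lfs track "*weights*"). A minimal sketch of the round trip, file names assumed for illustration:

# Sketch (file names assumed): round trip of an ONNX model whose weights
# live in an external-data file.
import onnx

# Reads large-v3-encoder.weights automatically if it sits next to the model.
model = onnx.load("large-v3-encoder.onnx")

# Re-save to a new pair of files; "location" is relative to the .onnx path.
onnx.save(
    model,
    "large-v3-encoder.copy.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="large-v3-encoder.copy.weights",
)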
scripts/whisper/test.py:

@@ -9,9 +9,10 @@ import base64
 from typing import Tuple

 import kaldi_native_fbank as knf
+import numpy as np
 import onnxruntime as ort
+import soundfile as sf
 import torch
-import torchaudio


 def get_args():
@@ -98,7 +99,6 @@ class OnnxModel:
         self.blank = int(meta["blank_id"])

         self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
-
         self.sot_sequence.append(self.no_timestamps)

         self.all_language_tokens = list(
@@ -226,7 +226,18 @@ def load_tokens(filename):
     return tokens


-def compute_features(filename: str) -> torch.Tensor:
+def load_audio(filename: str) -> Tuple[np.ndarray, int]:
+    data, sample_rate = sf.read(
+        filename,
+        always_2d=True,
+        dtype="float32",
+    )
+    data = data[:, 0]  # use only the first channel
+    samples = np.ascontiguousarray(data)
+    return samples, sample_rate
+
+
+def compute_features(filename: str, dim: int = 80) -> torch.Tensor:
     """
     Args:
       filename:
@@ -234,16 +245,18 @@ def compute_features(filename: str) -> torch.Tensor:
     Returns:
       Return a 1-D float32 tensor of shape (1, 80, 3000) containing the features.
     """
-    wave, sample_rate = torchaudio.load(filename)
-    audio = wave[0].contiguous()  # only use the first channel
+    wave, sample_rate = load_audio(filename)
     if sample_rate != 16000:
-        audio = torchaudio.functional.resample(
-            audio, orig_freq=sample_rate, new_freq=16000
-        )
+        import librosa
+
+        wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000)
+        sample_rate = 16000

     features = []
-    online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
-    online_whisper_fbank.accept_waveform(16000, audio.numpy())
+    opts = knf.WhisperFeatureOptions()
+    opts.dim = dim
+    online_whisper_fbank = knf.OnlineWhisperFbank(opts)
+    online_whisper_fbank.accept_waveform(16000, wave)
     online_whisper_fbank.input_finished()
     for i in range(online_whisper_fbank.num_frames_ready):
         f = online_whisper_fbank.get_frame(i)
@@ -280,8 +293,9 @@ def compute_features(filename: str) -> torch.Tensor:
 def main():
     args = get_args()

-    mel = compute_features(args.sound_file)
     model = OnnxModel(args.encoder, args.decoder)
+    dim = 80 if "large-v3" not in args.encoder else 128
+    mel = compute_features(args.sound_file, dim=dim)

     n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)

@@ -313,6 +327,7 @@ def main():

     n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()

+    print(model.sot_sequence)
     tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
     offset = torch.zeros(1, dtype=torch.int64)
     logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
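For reference, the feature pipeline in test.py can be exercised on its own: read the wav with soundfile, resample with librosa only when needed, then compute the Whisper fbank with kaldi-native-fbank. A sketch assuming a mono 16 kHz file named 0.wav (the fbank dimension is 80 for most models and 128 for large-v3); whether get_frame returns a numpy array directly is assumed here:

# Sketch (assumes a mono 16 kHz file "0.wav"): the feature pipeline from
# test.py, standalone. opts.dim is 80 for most models, 128 for large-v3.
import kaldi_native_fbank as knf
import numpy as np
import soundfile as sf

samples, sample_rate = sf.read("0.wav", always_2d=True, dtype="float32")
samples = np.ascontiguousarray(samples[:, 0])  # first channel only
assert sample_rate == 16000, sample_rate

opts = knf.WhisperFeatureOptions()
opts.dim = 80
fbank = knf.OnlineWhisperFbank(opts)
fbank.accept_waveform(16000, samples)
fbank.input_finished()

frames = [fbank.get_frame(i) for i in range(fbank.num_frames_ready)]
mel = np.stack(frames)  # (num_frames, opts.dim)
print(mel.shape)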
offline-recognizer-whisper-impl.cc:

@@ -88,7 +88,9 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
   }

   std::unique_ptr<OfflineStream> CreateStream() const override {
-    return std::make_unique<OfflineStream>(WhisperTag{});
+    WhisperTag tag;
+    tag.dim = model_->FeatureDim();
+    return std::make_unique<OfflineStream>(tag);
   }

   void DecodeStreams(OfflineStream **ss, int32_t n) const override {
offline-stream.cc:

@@ -97,12 +97,16 @@ class OfflineStream::Impl {
     }
   }

-  explicit Impl(WhisperTag /*tag*/) {
+  explicit Impl(WhisperTag tag) {
     config_.normalize_samples = true;
     opts_.frame_opts.samp_freq = 16000;
-    opts_.mel_opts.num_bins = 80;  // not used
-    whisper_fbank_ =
-        std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
+    opts_.mel_opts.num_bins = tag.dim;
+
+    knf::WhisperFeatureOptions whisper_opts;
+    whisper_opts.frame_opts = opts_.frame_opts;
+    whisper_opts.dim = tag.dim;
+
+    whisper_fbank_ = std::make_unique<knf::OnlineWhisperFbank>(whisper_opts);
     config_.sampling_rate = opts_.frame_opts.samp_freq;
   }

offline-stream.h:

@@ -35,7 +35,10 @@ struct OfflineRecognitionResult {
   std::string AsJsonString() const;
 };

-struct WhisperTag {};
+struct WhisperTag {
+  int32_t dim = 80;
+};
+
 struct CEDTag {};

 class OfflineStream {
offline-whisper-model.cc:

@@ -217,6 +217,8 @@ class OfflineWhisperModel::Impl {

   int32_t VocabSize() const { return n_vocab_; }

+  int32_t FeatureDim() const { return n_mels_; }
+
   int32_t Translate() const { return translate_; }

   bool IsMultiLingual() const { return is_multilingual_; }
@@ -242,6 +244,7 @@ class OfflineWhisperModel::Impl {
   }

   Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+  SHERPA_ONNX_READ_META_DATA(n_mels_, "n_mels");
   SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
   SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
   SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
@@ -316,6 +319,7 @@ class OfflineWhisperModel::Impl {
   std::unordered_map<int32_t, std::string> id2lang_;

   // model meta data
+  int32_t n_mels_ = 80;
   int32_t n_text_layer_ = 0;
   int32_t n_text_ctx_ = 0;
   int32_t n_text_state_ = 0;
@@ -414,6 +418,8 @@ int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }

 int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); }

+int32_t OfflineWhisperModel::FeatureDim() const { return impl_->FeatureDim(); }
+
 int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); }

 bool OfflineWhisperModel::IsMultiLingual() const {
offline-whisper-model.h:

@@ -102,6 +102,7 @@ class OfflineWhisperModel {
   int32_t SOT() const;
   int32_t TextCtx() const;
   int32_t VocabSize() const;
+  int32_t FeatureDim() const;
   int32_t Translate() const;
   bool IsMultiLingual() const;

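The C++ changes above read "n_mels" from the encoder's ONNX metadata so the recognizer can pick the 80- or 128-bin frontend at runtime. The same metadata, written by add_meta_data in the export script, can be inspected from Python; the encoder file name below is an assumption for illustration:

# Sketch: inspect the metadata that the C++ runtime reads.
import onnxruntime as ort

sess = ort.InferenceSession(
    "large-v3-encoder.onnx", providers=["CPUExecutionProvider"]
)
meta = sess.get_modelmeta().custom_metadata_map
print(meta.get("n_mels"))  # "128" for large-v3, "80" for the other models
print(meta.get("n_text_layer"), meta.get("n_text_ctx"), meta.get("n_text_state"))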