Fangjun Kuang

Support whisper large/large-v1/large-v2/large-v3 and distil-large-v2 (#1114)

@@ -15,9 +15,9 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest]
-        # model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "distil-large-v2"]
-        model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell"]
+        os: [macos-latest]
+        model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell", "large", "large-v1", "large-v2", "distil-large-v2"]
+        # model: ["large", "large-v1", "large-v2", "large-v3", "distil-large-v2"]
         python-version: ["3.8"]

     steps:
@@ -32,7 +32,7 @@ jobs:
         shell: bash
         run: |
           python3 -m pip install torch==1.13.0 torchaudio==0.13.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
-          python3 -m pip install openai-whisper==20230314 onnxruntime onnx
+          python3 -m pip install openai-whisper==20231117 onnxruntime onnx soundfile librosa

       - name: export ${{ matrix.model }}
         shell: bash
@@ -62,7 +62,6 @@ jobs:
             rm -fv medium-aishell-decoder.onnx
           fi

-
           ls -lh

           ls -lh ~/.cache/whisper || true
@@ -74,7 +73,8 @@ jobs:
           src=sherpa-onnx-whisper-${{ matrix.model }}

           cd ..
-          mv whisper $src
+          mkdir $src
+          mv -v whisper/$model* $src/

           echo "------------------------------"

@@ -98,18 +98,15 @@ jobs:
           echo "--------------------"

           if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
-            #tar cvjf - $src | split --bytes=1024MB - $src.tar.bz2.
-            tar cvjf $src.tar.bz2 $src
-            split -b 1G $src.tar.bz2 $src.tar.bz2.
-            rm $src.tar.bz2
-            # cat $src.tar.gz.* | tar xjf -
+            echo "Don't release model to github for large models. $model"
           else
             tar cvjf $src.tar.bz2 $src
           fi
-          ls -lh

+          ls -lh

       - name: Release
+        if: matrix.model != 'large' && matrix.model != 'large-v1' && matrix.model != 'large-v2' && matrix.model != 'large-v3' && matrix.model != 'distil-large-v2'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true
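Note: GitHub caps individual release assets at 2 GiB, which is why the old tar + split workaround is dropped: the large-model archives are no longer attached to the GitHub release at all (see the `if:` guard added to the Release step above) and are published only to Hugging Face.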
@@ -119,19 +116,6 @@ jobs:
           repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
           tag: asr-models

-      - name: Test ${{ matrix.model }}
-        shell: bash
-        run: |
-          python3 -m pip install kaldi-native-fbank
-          git checkout .
-          model=${{ matrix.model }}
-          src=sherpa-onnx-whisper-$model
-          python3 scripts/whisper/test.py \
-            --encoder $src/$model-encoder.int8.onnx \
-            --decoder $src/$model-decoder.int8.onnx \
-            --tokens $src/$model-tokens.txt \
-            $src/test_wavs/0.wav
-
       - name: Publish ${{ matrix.model }} to huggingface
         shell: bash
         env:
@@ -144,27 +128,36 @@ jobs:

           export GIT_CLONE_PROTECTION_ACTIVE=false

-          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
+          export GIT_LFS_SKIP_SMUDGE=1
+
+          git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface

           if [[ $model != medium-aishell ]]; then
             rm -rf huggingface/*
           fi

-          if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
-            mv $src.tar* ./huggingface
-          else
-            cp -v $src/*.onnx ./huggingface
-            cp -v $src/*tokens* ./huggingface
-            cp -av $src/test_wavs ./huggingface
-          fi
+          cp -av $src/* ./huggingface/

           cd huggingface

           git status
           ls -lh
-          git lfs track "*gz*"
           git lfs track "*onnx*"
+          git lfs track "*weights*"

           git add .
           git commit -m "upload ${{ matrix.model }}"
           git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
+
+      - name: Test ${{ matrix.model }}
+        shell: bash
+        run: |
+          python3 -m pip install kaldi-native-fbank
+          git checkout .
+          model=${{ matrix.model }}
+          src=sherpa-onnx-whisper-$model
+          time python3 scripts/whisper/test.py \
+            --encoder $src/$model-encoder.onnx \
+            --decoder $src/$model-decoder.onnx \
+            --tokens $src/$model-tokens.txt \
+            $src/test_wavs/0.wav
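Note: the Test step now runs after the publish steps and decodes with the fp32 `$model-encoder.onnx`/`$model-decoder.onnx` pair instead of the int8 models, since int8 quantization is skipped for the large models (see `export-onnx.py` below); the added `time` makes the decoding cost visible in the CI log.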
@@ -1,5 +1,6 @@
-## 1.10.14 (to-be-released)
+## 1.10.14

+* Support whisper large v3
 * Update onnxruntime from v1.18.0 to v1.18.1
 * Fix invalid utf8 sequence from Whisper for Dart API.

@@ -11,7 +11,7 @@ project(sherpa-onnx)
 # ./nodejs-addon-examples
 # ./dart-api-examples/
 # ./CHANGELOG.md
-set(SHERPA_ONNX_VERSION "1.10.13")
+set(SHERPA_ONNX_VERSION "1.10.14")

 # Disable warning about
 #
@@ -1,9 +1,9 @@
 function(download_kaldi_native_fbank)
   include(FetchContent)

-  set(kaldi_native_fbank_URL  "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.3.tar.gz")
-  set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.3.tar.gz")
-  set(kaldi_native_fbank_HASH "SHA256=335fe1daf1b9bfb2a7b6bf03b64c4c4686c39077c57fb8058c02611981676638")
+  set(kaldi_native_fbank_URL  "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz")
+  set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz")
+  set(kaldi_native_fbank_HASH "SHA256=c6195b3cf374eef824644061d3c04f6b2a9267ae554169cbaa9865c89c1fe4f9")

   set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
   set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
   # If you don't have access to the Internet,
   # please pre-download kaldi-native-fbank
   set(possible_file_locations
-    $ENV{HOME}/Downloads/kaldi-native-fbank-1.19.3.tar.gz
-    ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.19.3.tar.gz
-    ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.19.3.tar.gz
-    /tmp/kaldi-native-fbank-1.19.3.tar.gz
-    /star-fj/fangjun/download/github/kaldi-native-fbank-1.19.3.tar.gz
+    $ENV{HOME}/Downloads/kaldi-native-fbank-1.20.0.tar.gz
+    ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.20.0.tar.gz
+    ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.20.0.tar.gz
+    /tmp/kaldi-native-fbank-1.20.0.tar.gz
+    /star-fj/fangjun/download/github/kaldi-native-fbank-1.20.0.tar.gz
   )

   foreach(f IN LISTS possible_file_locations)
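Note: the kaldi-native-fbank bump from v1.19.3 to v1.20.0 is what provides `WhisperFeatureOptions` with a configurable mel dimension (`dim`), used later in this commit by both `scripts/whisper/test.py` and `offline-stream.cc` to produce the 128-bin features that large-v3 expects.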
@@ -10,7 +10,7 @@ environment:

 # Add regular dependencies here.
 dependencies:
-  sherpa_onnx: ^1.10.13
+  sherpa_onnx: ^1.10.14
   path: ^1.9.0
   args: ^2.5.0

@@ -11,7 +11,7 @@ environment:

 # Add regular dependencies here.
 dependencies:
-  sherpa_onnx: ^1.10.13
+  sherpa_onnx: ^1.10.14
   path: ^1.9.0
   args: ^2.5.0

@@ -8,7 +8,7 @@ environment:

 # Add regular dependencies here.
 dependencies:
-  sherpa_onnx: ^1.10.13
+  sherpa_onnx: ^1.10.14
   path: ^1.9.0
   args: ^2.5.0

@@ -9,7 +9,7 @@ environment:
   sdk: ^3.4.0

 dependencies:
-  sherpa_onnx: ^1.10.13
+  sherpa_onnx: ^1.10.14
   path: ^1.9.0
   args: ^2.5.0

@@ -5,7 +5,7 @@ description: >

 publish_to: 'none'

-version: 1.10.13
+version: 1.10.14

 topics:
   - speech-recognition

@@ -30,7 +30,7 @@ dependencies:
   record: ^5.1.0
   url_launcher: ^6.2.6

-  sherpa_onnx: ^1.10.13
+  sherpa_onnx: ^1.10.14
   # sherpa_onnx:
   #   path: ../../flutter/sherpa_onnx

@@ -17,7 +17,7 @@ dependencies:
   cupertino_icons: ^1.0.6
   path_provider: ^2.1.3
   path: ^1.9.0
-  sherpa_onnx: ^1.10.13
+  sherpa_onnx: ^1.10.14
   url_launcher: ^6.2.6
   audioplayers: ^5.0.0

@@ -17,7 +17,7 @@ topics:
   - voice-activity-detection

 # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec
-version: 1.10.13
+version: 1.10.14

 homepage: https://github.com/k2-fsa/sherpa-onnx

@@ -30,19 +30,19 @@ dependencies:
   flutter:
     sdk: flutter

-  sherpa_onnx_android: ^1.10.13
+  sherpa_onnx_android: ^1.10.14
   # path: ../sherpa_onnx_android

-  sherpa_onnx_macos: ^1.10.13
+  sherpa_onnx_macos: ^1.10.14
   # path: ../sherpa_onnx_macos

-  sherpa_onnx_linux: ^1.10.13
+  sherpa_onnx_linux: ^1.10.14
   # path: ../sherpa_onnx_linux
   #
-  sherpa_onnx_windows: ^1.10.13
+  sherpa_onnx_windows: ^1.10.14
   # path: ../sherpa_onnx_windows

-  sherpa_onnx_ios: ^1.10.13
+  sherpa_onnx_ios: ^1.10.14
   # sherpa_onnx_ios:
   #   path: ../sherpa_onnx_ios

@@ -7,7 +7,7 @@
 # https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c
 Pod::Spec.new do |s|
   s.name             = 'sherpa_onnx_ios'
-  s.version          = '1.10.13'
+  s.version          = '1.10.14'
   s.summary          = 'A new Flutter FFI plugin project.'
   s.description      = <<-DESC
 A new Flutter FFI plugin project.
@@ -4,7 +4,7 @@
 #
 Pod::Spec.new do |s|
   s.name             = 'sherpa_onnx_macos'
-  s.version          = '1.10.13'
+  s.version          = '1.10.14'
   s.summary          = 'sherpa-onnx Flutter FFI plugin project.'
   s.description      = <<-DESC
 sherpa-onnx Flutter FFI plugin project.
@@ -1,5 +1,5 @@
 {
   "dependencies": {
-    "sherpa-onnx-node": "^1.10.13"
+    "sherpa-onnx-node": "^1.10.14"
   }
 }
@@ -17,7 +17,7 @@ topics:
   - voice-activity-detection

 # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec
-version: 1.10.13
+version: 1.10.14

 homepage: https://github.com/k2-fsa/sherpa-onnx

@@ -2,3 +2,9 @@
 *.config
 *.ort
 *-tokens.txt
+*.bias
+*.weights
+*.weight
+*.*embedding
+_Const*
+onnx__*
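Note: these ignore patterns match the ONNX external-data artifacts created for the large models: a single `*.weights` sidecar when saving with `all_tensors_to_one_file=True`, and (presumably from intermediate exports) per-tensor files with names like `onnx__MatMul_1234` or `_Const_...`, which is what `onnx__*` and `_Const*` cover.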
@@ -32,6 +32,9 @@ from whisper.model import (
     TextDecoder,
 )

+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+

 def get_args():
     parser = argparse.ArgumentParser()
@@ -43,8 +46,9 @@ def get_args():
         choices=[
             "tiny", "tiny.en", "base", "base.en",
             "small", "small.en", "medium", "medium.en",
-            "large", "large-v1", "large-v2",
+            "large", "large-v1", "large-v2", "large-v3",
             "distil-medium.en", "distil-small.en", "distil-large-v2",
+            # "distil-large-v3", # distil-large-v3 is not supported!
             # for fine-tuned models from icefall
             "medium-aishell",
         ],
@@ -63,11 +67,25 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
         Key-value pairs.
     """
     model = onnx.load(filename)
+
+    while len(model.metadata_props):
+        model.metadata_props.pop()
+
     for key, value in meta_data.items():
         meta = model.metadata_props.add()
         meta.key = key
         meta.value = str(value)

-    onnx.save(model, filename)
+    if "large" in filename:
+        external_filename = filename.split(".onnx")[0]
+        onnx.save(
+            model,
+            filename,
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            location=external_filename + ".weights",
+        )
+    else:
+        onnx.save(model, filename)


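Background: a single ONNX protobuf cannot exceed 2 GB, so the large checkpoints have to keep their weights in a companion file next to the `.onnx` graph. A minimal sketch of the save/load round trip, assuming a hypothetical `large-v3-encoder.onnx` already on disk:

```python
import onnx
import onnxruntime as ort

filename = "large-v3-encoder.onnx"  # hypothetical export with > 2 GB of weights

model = onnx.load(filename)  # external data, if any, is resolved automatically
onnx.save(
    model,
    filename,
    save_as_external_data=True,    # keep the graph in the .onnx file ...
    all_tensors_to_one_file=True,  # ... and all weights in one sidecar file
    location="large-v3-encoder.weights",  # path relative to the .onnx file
)

# onnxruntime looks the sidecar up relative to the model path, so the
# .onnx and .weights files must be distributed together
sess = ort.InferenceSession(filename, providers=["CPUExecutionProvider"])
```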
@@ -376,7 +394,9 @@ def main():

     # write tokens

-    tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
+    tokenizer = whisper.tokenizer.get_tokenizer(
+        model.is_multilingual, num_languages=model.num_languages
+    )

     model.eval()
     print(model.dims)
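Background for the `num_languages` argument: large-v3 grew the language set from 99 to 100 (Cantonese was added), which shifts every later special-token ID by one, so the tokenizer has to be built with the model's own `num_languages` or the exported token table would be off by one. A small sketch:

```python
import whisper

# 99 languages for large-v2 and earlier, 100 for large-v3
tok_v2 = whisper.tokenizer.get_tokenizer(True, num_languages=99)
tok_v3 = whisper.tokenizer.get_tokenizer(True, num_languages=100)

# the extra language token shifts the later special tokens,
# e.g. <|notimestamps|> moves from 50363 to 50364
print(tok_v2.no_timestamps, tok_v3.no_timestamps)
```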
@@ -384,10 +404,15 @@ def main():
     audio = whisper.pad_or_trim(audio)
     assert audio.shape == (16000 * 30,), audio.shape

-    # make log-Mel spectrogram and move to the same device as the model
-    mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
+    if args.model in ("large", "large-v3"):
+        n_mels = 128
+    else:
+        n_mels = 80
+    mel = (
+        whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device).unsqueeze(0)
+    )
     batch_size = 1
-    assert mel.shape == (batch_size, 80, 30 * 100)
+    assert mel.shape == (batch_size, n_mels, 30 * 100), mel.shape

     encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
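The mel dimension is the key difference between model generations: large-v3 consumes 128-bin log-mel features while everything earlier uses 80 bins, and in openai-whisper 20231117 the name `large` already resolves to large-v3, which is why both names map to 128 above. A minimal sketch of the feature-shape check, assuming any wav file `0.wav`:

```python
import whisper

model_name = "large-v3"  # hypothetical choice for illustration
n_mels = 128 if model_name in ("large", "large-v3") else 80

# whisper.load_audio shells out to ffmpeg and resamples to 16 kHz
audio = whisper.pad_or_trim(whisper.load_audio("0.wav"))
mel = whisper.log_mel_spectrogram(audio, n_mels=n_mels).unsqueeze(0)
assert mel.shape == (1, n_mels, 3000)  # 30 s of audio -> 3000 frames
```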
393 418
@@ -547,6 +572,17 @@ def main(): @@ -547,6 +572,17 @@ def main():
547 ) 572 )
548 573
549 if "large" in args.model: 574 if "large" in args.model:
  575 + decoder_external_filename = decoder_filename.split(".onnx")[0]
  576 + decoder_model = onnx.load(decoder_filename)
  577 + onnx.save(
  578 + decoder_model,
  579 + decoder_filename,
  580 + save_as_external_data=True,
  581 + all_tensors_to_one_file=True,
  582 + location=decoder_external_filename + ".weights",
  583 + )
  584 +
  585 + if "large" in args.model:
550 # it causes errors for large models, so skip it. 586 # it causes errors for large models, so skip it.
551 return 587 return
552 # Generate int8 quantization models 588 # Generate int8 quantization models
@@ -9,9 +9,10 @@ import base64
 from typing import Tuple

 import kaldi_native_fbank as knf
+import numpy as np
 import onnxruntime as ort
+import soundfile as sf
 import torch
-import torchaudio


 def get_args():
@@ -98,7 +99,6 @@ class OnnxModel:
         self.blank = int(meta["blank_id"])

         self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
-
         self.sot_sequence.append(self.no_timestamps)

         self.all_language_tokens = list(
@@ -226,7 +226,18 @@ def load_tokens(filename):
     return tokens


-def compute_features(filename: str) -> torch.Tensor:
+def load_audio(filename: str) -> Tuple[np.ndarray, int]:
+    data, sample_rate = sf.read(
+        filename,
+        always_2d=True,
+        dtype="float32",
+    )
+    data = data[:, 0]  # use only the first channel
+    samples = np.ascontiguousarray(data)
+    return samples, sample_rate
+
+
+def compute_features(filename: str, dim: int = 80) -> torch.Tensor:
     """
     Args:
       filename:
@@ -234,16 +245,18 @@ def compute_features(filename: str) -> torch.Tensor:
     Returns:
       Return a 1-D float32 tensor of shape (1, 80, 3000) containing the features.
     """
-    wave, sample_rate = torchaudio.load(filename)
-    audio = wave[0].contiguous()  # only use the first channel
+    wave, sample_rate = load_audio(filename)
     if sample_rate != 16000:
-        audio = torchaudio.functional.resample(
-            audio, orig_freq=sample_rate, new_freq=16000
-        )
+        import librosa
+
+        wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000)
+        sample_rate = 16000

     features = []
-    online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
-    online_whisper_fbank.accept_waveform(16000, audio.numpy())
+    opts = knf.WhisperFeatureOptions()
+    opts.dim = dim
+    online_whisper_fbank = knf.OnlineWhisperFbank(opts)
+    online_whisper_fbank.accept_waveform(16000, wave)
     online_whisper_fbank.input_finished()
     for i in range(online_whisper_fbank.num_frames_ready):
         f = online_whisper_fbank.get_frame(i)
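A standalone sketch of the new kaldi-native-fbank API used here (available from v1.20.0, per the cmake change above); the waveform is a placeholder:

```python
import kaldi_native_fbank as knf
import numpy as np

samples = np.zeros(16000, dtype=np.float32)  # 1 s of silence, 16 kHz mono

opts = knf.WhisperFeatureOptions()
opts.dim = 128  # 128 mel bins for large-v3, 80 for all earlier models

fbank = knf.OnlineWhisperFbank(opts)
fbank.accept_waveform(16000, samples)
fbank.input_finished()

frames = np.stack([fbank.get_frame(i) for i in range(fbank.num_frames_ready)])
print(frames.shape)  # (num_frames, 128)
```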
@@ -280,8 +293,9 @@ def compute_features(filename: str) -> torch.Tensor:
 def main():
     args = get_args()

-    mel = compute_features(args.sound_file)
     model = OnnxModel(args.encoder, args.decoder)
+    dim = 80 if "large-v3" not in args.encoder else 128
+    mel = compute_features(args.sound_file, dim=dim)

     n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)

@@ -313,6 +327,7 @@ def main():

     n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()

+    print(model.sot_sequence)
     tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
     offset = torch.zeros(1, dtype=torch.int64)
     logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
@@ -88,7 +88,9 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
   }

   std::unique_ptr<OfflineStream> CreateStream() const override {
-    return std::make_unique<OfflineStream>(WhisperTag{});
+    WhisperTag tag;
+    tag.dim = model_->FeatureDim();
+    return std::make_unique<OfflineStream>(tag);
   }

   void DecodeStreams(OfflineStream **ss, int32_t n) const override {
@@ -97,12 +97,16 @@ class OfflineStream::Impl {
     }
   }

-  explicit Impl(WhisperTag /*tag*/) {
+  explicit Impl(WhisperTag tag) {
     config_.normalize_samples = true;
     opts_.frame_opts.samp_freq = 16000;
-    opts_.mel_opts.num_bins = 80;  // not used
-    whisper_fbank_ =
-        std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
+    opts_.mel_opts.num_bins = tag.dim;
+
+    knf::WhisperFeatureOptions whisper_opts;
+    whisper_opts.frame_opts = opts_.frame_opts;
+    whisper_opts.dim = tag.dim;
+
+    whisper_fbank_ = std::make_unique<knf::OnlineWhisperFbank>(whisper_opts);
     config_.sampling_rate = opts_.frame_opts.samp_freq;
   }

@@ -35,7 +35,10 @@ struct OfflineRecognitionResult {
   std::string AsJsonString() const;
 };

-struct WhisperTag {};
+struct WhisperTag {
+  int32_t dim = 80;
+};
+
 struct CEDTag {};

 class OfflineStream {
@@ -217,6 +217,8 @@ class OfflineWhisperModel::Impl {

   int32_t VocabSize() const { return n_vocab_; }

+  int32_t FeatureDim() const { return n_mels_; }
+
   int32_t Translate() const { return translate_; }

   bool IsMultiLingual() const { return is_multilingual_; }
@@ -242,6 +244,7 @@ class OfflineWhisperModel::Impl {
   }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+   SHERPA_ONNX_READ_META_DATA(n_mels_, "n_mels");
    SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
    SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
    SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
@@ -316,6 +319,7 @@ class OfflineWhisperModel::Impl {
   std::unordered_map<int32_t, std::string> id2lang_;

   // model meta data
+  int32_t n_mels_ = 80;
   int32_t n_text_layer_ = 0;
   int32_t n_text_ctx_ = 0;
   int32_t n_text_state_ = 0;
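The C++ side discovers the feature dimension from the `n_mels` metadata written into the exported encoder, so one binary handles both 80- and 128-bin models. A rough Python equivalent of that lookup, assuming a hypothetical exported `large-v3-encoder.onnx`:

```python
import onnxruntime as ort

sess = ort.InferenceSession(
    "large-v3-encoder.onnx", providers=["CPUExecutionProvider"]
)
meta = sess.get_modelmeta().custom_metadata_map  # key -> string value
n_mels = int(meta.get("n_mels", 80))  # exports predating this commit lack the key
print(n_mels)  # 128 for large-v3, 80 otherwise
```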
@@ -414,6 +418,8 @@ int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }

 int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); }

+int32_t OfflineWhisperModel::FeatureDim() const { return impl_->FeatureDim(); }
+
 int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); }

 bool OfflineWhisperModel::IsMultiLingual() const {
@@ -102,6 +102,7 @@ class OfflineWhisperModel {
   int32_t SOT() const;
   int32_t TextCtx() const;
   int32_t VocabSize() const;
+  int32_t FeatureDim() const;
   int32_t Translate() const;
   bool IsMultiLingual() const;
