Fangjun Kuang
Committed by GitHub

Add CI test for Whisper models (#239)

#!/usr/bin/env bash

# CI test script for offline Whisper models.
# Requires the environment variable EXE to name the sherpa-onnx-offline
# binary (set by the workflow steps that invoke this script).
#
# -e: abort on any unhandled command failure
# -u: treat unset variables (e.g. a missing $EXE) as errors
# -o pipefail: a pipeline fails if any stage fails
set -euo pipefail

# Print a message prefixed with a timestamp and the caller's
# file:line:function location.  This function is from espnet.
log() {
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
  10 +
echo "EXE is $EXE"
echo "PATH: $PATH"

# Fail early if the binary under test is not on PATH
# ('command -v' is the portable replacement for 'which').
command -v "$EXE"

# Models to test; the larger ones are commented out to keep CI fast.
names=(
tiny.en
base.en
# small.en
# medium.en
)

for name in "${names[@]}"; do
  log "------------------------------------------------------------"
  log "Run $name"
  log "------------------------------------------------------------"

  repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-whisper-$name
  log "Start testing ${repo_url}"
  repo=$(basename "$repo_url")
  log "Download pretrained model and test-data from $repo_url"

  # Clone without downloading LFS payloads, then fetch only the model files.
  GIT_LFS_SKIP_SMUDGE=1 git clone "$repo_url"
  pushd "$repo"
  git lfs pull --include "*.onnx"
  git lfs pull --include "*.ort"
  ls -lh *.{onnx,ort}
  popd

  # Run the same recognition test for every model format:
  # fp32/int8 weights in both onnx and ort containers.
  for ext in onnx int8.onnx ort int8.ort; do
    log "test $ext"

    time "$EXE" \
      --tokens="$repo/${name}-tokens.txt" \
      --whisper-encoder="$repo/${name}-encoder.$ext" \
      --whisper-decoder="$repo/${name}-decoder.$ext" \
      --num-threads=2 \
      "$repo/test_wavs/0.wav" \
      "$repo/test_wavs/1.wav" \
      "$repo/test_wavs/8k.wav"
  done

  # Free disk space before testing the next model.
  rm -rf -- "$repo"
done
@@ -84,6 +84,14 @@ jobs: @@ -84,6 +84,14 @@ jobs:
84 file build/bin/sherpa-onnx 84 file build/bin/sherpa-onnx
85 readelf -d build/bin/sherpa-onnx 85 readelf -d build/bin/sherpa-onnx
86 86
  87 + - name: Test offline Whisper
  88 + shell: bash
  89 + run: |
  90 + export PATH=$PWD/build/bin:$PATH
  91 + export EXE=sherpa-onnx-offline
  92 +
  93 + .github/scripts/test-offline-whisper.sh
  94 +
87 - name: Test offline CTC 95 - name: Test offline CTC
88 shell: bash 96 shell: bash
89 run: | 97 run: |
@@ -82,6 +82,14 @@ jobs: @@ -82,6 +82,14 @@ jobs:
82 otool -L build/bin/sherpa-onnx 82 otool -L build/bin/sherpa-onnx
83 otool -l build/bin/sherpa-onnx 83 otool -l build/bin/sherpa-onnx
84 84
  85 + - name: Test offline Whisper
  86 + shell: bash
  87 + run: |
  88 + export PATH=$PWD/build/bin:$PATH
  89 + export EXE=sherpa-onnx-offline
  90 +
  91 + .github/scripts/test-offline-whisper.sh
  92 +
85 - name: Test offline CTC 93 - name: Test offline CTC
86 shell: bash 94 shell: bash
87 run: | 95 run: |
@@ -74,6 +74,14 @@ jobs: @@ -74,6 +74,14 @@ jobs:
74 74
75 ls -lh ./bin/Release/sherpa-onnx.exe 75 ls -lh ./bin/Release/sherpa-onnx.exe
76 76
  77 + - name: Test offline Whisper for windows x64
  78 + shell: bash
  79 + run: |
  80 + export PATH=$PWD/build/bin/Release:$PATH
  81 + export EXE=sherpa-onnx-offline.exe
  82 +
  83 + .github/scripts/test-offline-whisper.sh
  84 +
77 - name: Test offline CTC for windows x64 85 - name: Test offline CTC for windows x64
78 shell: bash 86 shell: bash
79 run: | 87 run: |
@@ -75,6 +75,14 @@ jobs: @@ -75,6 +75,14 @@ jobs:
75 75
76 ls -lh ./bin/Release/sherpa-onnx.exe 76 ls -lh ./bin/Release/sherpa-onnx.exe
77 77
  78 + - name: Test offline Whisper for windows x64
  79 + shell: bash
  80 + run: |
  81 + export PATH=$PWD/build/bin/Release:$PATH
  82 + export EXE=sherpa-onnx-offline.exe
  83 +
  84 + .github/scripts/test-offline-whisper.sh
  85 +
78 - name: Test offline CTC for windows x64 86 - name: Test offline CTC for windows x64
79 shell: bash 87 shell: bash
80 run: | 88 run: |
@@ -73,6 +73,14 @@ jobs: @@ -73,6 +73,14 @@ jobs:
73 73
74 ls -lh ./bin/Release/sherpa-onnx.exe 74 ls -lh ./bin/Release/sherpa-onnx.exe
75 75
  76 + - name: Test offline Whisper for windows x86
  77 + shell: bash
  78 + run: |
  79 + export PATH=$PWD/build/bin/Release:$PATH
  80 + export EXE=sherpa-onnx-offline.exe
  81 +
  82 + .github/scripts/test-offline-whisper.sh
  83 +
76 - name: Test offline CTC for windows x86 84 - name: Test offline CTC for windows x86
77 shell: bash 85 shell: bash
78 run: | 86 run: |
@@ -5,5 +5,9 @@ and use onnxruntime to replace PyTorch for speech recognition. @@ -5,5 +5,9 @@ and use onnxruntime to replace PyTorch for speech recognition.
5 5
6 You can use [sherpa-onnx][sherpa-onnx] to run the converted model. 6 You can use [sherpa-onnx][sherpa-onnx] to run the converted model.
7 7
  8 +Please see
  9 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/export-onnx.html
  10 +for details.
  11 +
8 [whisper]: https://github.com/openai/whisper 12 [whisper]: https://github.com/openai/whisper
9 [sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx 13 [sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
@@ -18,15 +18,30 @@ import argparse @@ -18,15 +18,30 @@ import argparse
18 def get_args(): 18 def get_args():
19 parser = argparse.ArgumentParser() 19 parser = argparse.ArgumentParser()
20 parser.add_argument( 20 parser.add_argument(
21 - "--model", 21 + "--encoder",
22 type=str, 22 type=str,
23 required=True, 23 required=True,
24 - # fmt: off  
25 - choices=[  
26 - "tiny", "tiny.en", "base", "base.en",  
27 - "small", "small.en", "medium", "medium.en",  
28 - "large", "large-v1", "large-v2"],  
29 - # fmt: on 24 + help="Path to the encoder",
  25 + )
  26 +
  27 + parser.add_argument(
  28 + "--decoder",
  29 + type=str,
  30 + required=True,
  31 + help="Path to the decoder",
  32 + )
  33 +
  34 + parser.add_argument(
  35 + "--tokens",
  36 + type=str,
  37 + required=True,
  38 + help="Path to the tokens",
  39 + )
  40 +
  41 + parser.add_argument(
  42 + "sound_file",
  43 + type=str,
  44 + help="Path to the test wave",
30 ) 45 )
31 return parser.parse_args() 46 return parser.parse_args()
32 47
@@ -161,11 +176,10 @@ def load_tokens(filename): @@ -161,11 +176,10 @@ def load_tokens(filename):
161 176
162 def main(): 177 def main():
163 args = get_args() 178 args = get_args()
164 - name = args.model 179 + encoder = args.encoder
  180 + decoder = args.decoder
165 181
166 - encoder = f"./{name}-encoder.onnx"  
167 - decoder = f"./{name}-decoder.onnx"  
168 - audio = whisper.load_audio("0.wav") 182 + audio = whisper.load_audio(args.sound_file)
169 183
170 features = [] 184 features = []
171 online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions()) 185 online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
@@ -224,17 +238,13 @@ def main(): @@ -224,17 +238,13 @@ def main():
224 logits = logits[0, -1] 238 logits = logits[0, -1]
225 model.suppress_tokens(logits, is_initial=False) 239 model.suppress_tokens(logits, is_initial=False)
226 max_token_id = logits.argmax(dim=-1) 240 max_token_id = logits.argmax(dim=-1)
227 - token_table = load_tokens(f"./{name}-tokens.txt") 241 + token_table = load_tokens(args.tokens)
228 s = b"" 242 s = b""
229 for i in results: 243 for i in results:
230 if i in token_table: 244 if i in token_table:
231 s += base64.b64decode(token_table[i]) 245 s += base64.b64decode(token_table[i])
232 - else:  
233 - print("oov", i)  
234 246
235 print(s.decode().strip()) 247 print(s.decode().strip())
236 - print(results)  
237 - print(model.sot_sequence)  
238 248
239 249
240 if __name__ == "__main__": 250 if __name__ == "__main__":