Fangjun Kuang
Committed by GitHub

Support distil-whisper (#411)

@@ -16,32 +16,49 @@ jobs: @@ -16,32 +16,49 @@ jobs:
16 fail-fast: false 16 fail-fast: false
17 matrix: 17 matrix:
18 os: [macos-latest] 18 os: [macos-latest]
19 - model: ["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"] 19 + model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
  20 + python-version: ["3.8"]
20 21
21 steps: 22 steps:
22 - uses: actions/checkout@v4 23 - uses: actions/checkout@v4
23 24
  25 + - name: Setup Python ${{ matrix.python-version }}
  26 + uses: actions/setup-python@v2
  27 + with:
  28 + python-version: ${{ matrix.python-version }}
  29 +
24 - name: Install dependencies 30 - name: Install dependencies
25 shell: bash 31 shell: bash
26 run: | 32 run: |
27 - python3 -m pip install openai-whisper torch onnxruntime onnx 33 + python3 -m pip install torch==1.13.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
  34 + python3 -m pip install openai-whisper==20230314 onnxruntime onnx
28 35
29 - name: export ${{ matrix.model }} 36 - name: export ${{ matrix.model }}
30 shell: bash 37 shell: bash
31 run: | 38 run: |
32 cd scripts/whisper 39 cd scripts/whisper
  40 + model=${{ matrix.model }}
  41 + echo "model: $model"
  42 + if [[ $model == distil-medium.en ]]; then
  43 + wget -q -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
  44 + ls -lh
  45 + fi
33 python3 ./export-onnx.py --model ${{ matrix.model }} 46 python3 ./export-onnx.py --model ${{ matrix.model }}
34 python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./ 47 python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
35 48
36 ls -lh 49 ls -lh
37 50
38 - ls -lh ~/.cache/whisper 51 + if [[ $model != distil-medium.en ]]; then
  52 + ls -lh ~/.cache/whisper
  53 + fi
39 54
40 - name: Publish ${{ matrix.model }} to huggingface 55 - name: Publish ${{ matrix.model }} to huggingface
41 shell: bash 56 shell: bash
42 env: 57 env:
43 HF_TOKEN: ${{ secrets.HF_TOKEN }} 58 HF_TOKEN: ${{ secrets.HF_TOKEN }}
44 run: | 59 run: |
  60 + model=${{ matrix.model }}
  61 +
45 cd scripts/whisper 62 cd scripts/whisper
46 63
47 git config --global user.email "csukuangfj@gmail.com" 64 git config --global user.email "csukuangfj@gmail.com"
@@ -54,6 +71,18 @@ jobs: @@ -54,6 +71,18 @@ jobs:
54 cp *tokens.txt ./huggingface 71 cp *tokens.txt ./huggingface
55 72
56 cd huggingface 73 cd huggingface
  74 +
  75 + if [[ $model == distil-medium.en ]]; then
  76 + mkdir test_wavs
  77 + cd test_wavs
  78 + wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
  79 + wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
  80 + wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
  81 + wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
  82 + git add .
  83 + cd ..
  84 + fi
  85 +
57 git status 86 git status
58 ls -lh 87 ls -lh
59 git lfs track "*.onnx" 88 git lfs track "*.onnx"
@@ -39,7 +39,9 @@ def get_args(): @@ -39,7 +39,9 @@ def get_args():
39 choices=[ 39 choices=[
40 "tiny", "tiny.en", "base", "base.en", 40 "tiny", "tiny.en", "base", "base.en",
41 "small", "small.en", "medium", "medium.en", 41 "small", "small.en", "medium", "medium.en",
42 - "large", "large-v1", "large-v2"], 42 + "large", "large-v1", "large-v2",
  43 + "distil-medium.en",
  44 + ],
43 # fmt: on 45 # fmt: on
44 ) 46 )
45 return parser.parse_args() 47 return parser.parse_args()
@@ -257,10 +259,27 @@ def convert_tokens(name, model): @@ -257,10 +259,27 @@ def convert_tokens(name, model):
257 def main(): 259 def main():
258 args = get_args() 260 args = get_args()
259 name = args.model 261 name = args.model
  262 + print(args)
  263 + print(name)
260 264
261 opset_version = 13 265 opset_version = 13
262 266
263 - model = whisper.load_model(name) 267 + if name == "distil-medium.en":
  268 + filename = "./distil-medium-en-original-model.bin"
  269 + if not Path(filename).is_file():
  270 + raise ValueError(
  271 + """
  272 + Please go to https://huggingface.co/distil-whisper/distil-medium.en
  273 + to download original-model.bin
  274 + You can use the following command to do that:
  275 +
  276 + wget -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
  277 + """
  278 + )
  279 + model = whisper.load_model(filename)
  280 + else:
  281 + model = whisper.load_model(name)
  282 +
264 print( 283 print(
265 f"number of model parameters: {name}", 284 f"number of model parameters: {name}",
266 sum(p.numel() for p in model.parameters()), 285 sum(p.numel() for p in model.parameters()),