Please register or log in to continue.
Fangjun Kuang
Committed by GitHub

Support distil-whisper (#411)

... ... @@ -16,32 +16,49 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest]
model: ["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
python-version: ["3.8"]
steps:
- uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
shell: bash
run: |
python3 -m pip install openai-whisper torch onnxruntime onnx
python3 -m pip install torch==1.13.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
python3 -m pip install openai-whisper==20230314 onnxruntime onnx
- name: export ${{ matrix.model }}
shell: bash
run: |
cd scripts/whisper
model=${{ matrix.model }}
echo "model: $model"
if [[ $model == distil-medium.en ]]; then
wget -q -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
ls -lh
fi
python3 ./export-onnx.py --model ${{ matrix.model }}
python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
ls -lh
ls -lh ~/.cache/whisper
if [[ $model != distil-medium.en ]]; then
ls -lh ~/.cache/whisper
fi
- name: Publish ${{ matrix.model }} to huggingface
shell: bash
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
model=${{ matrix.model }}
cd scripts/whisper
git config --global user.email "csukuangfj@gmail.com"
... ... @@ -54,6 +71,18 @@ jobs:
cp *tokens.txt ./huggingface
cd huggingface
if [[ $model == distil-medium.en ]]; then
mkdir test_wavs
cd test_wavs
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
git add .
cd ..
fi
git status
ls -lh
git lfs track "*.onnx"
... ...
... ... @@ -39,7 +39,9 @@ def get_args():
choices=[
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2"],
"large", "large-v1", "large-v2",
"distil-medium.en",
],
# fmt: on
)
return parser.parse_args()
... ... @@ -257,10 +259,27 @@ def convert_tokens(name, model):
def main():
args = get_args()
name = args.model
print(args)
print(name)
opset_version = 13
model = whisper.load_model(name)
# Distil-Whisper checkpoints are not distributed through openai-whisper's
# own download mechanism, so the original-model.bin must be fetched
# manually (the CI step above does this with wget) and loaded from a
# local path; every other model name is resolved by whisper.load_model()
# itself and cached under ~/.cache/whisper.
if name == "distil-medium.en":
    filename = "./distil-medium-en-original-model.bin"
    # Bug fix: `Path(filename)` is always truthy (a Path object, not an
    # existence check), so the original guard could never raise. Use
    # .is_file() to actually verify the checkpoint has been downloaded.
    if not Path(filename).is_file():
        raise ValueError(
            """
Please go to https://huggingface.co/distil-whisper/distil-medium.en
to download original-model.bin
You can use the following command to do that:
wget -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
"""
        )
    model = whisper.load_model(filename)
else:
    model = whisper.load_model(name)
print(
f"number of model parameters: {name}",
sum(p.numel() for p in model.parameters()),
... ...