Committed by GitHub
Add fine-tuned whisper model on aishell (#565)
See also https://github.com/k2-fsa/icefall/pull/1466
Showing 3 changed files with 36 additions and 7 deletions.
The first change is to the GitHub Actions export workflow: the job moves from `macos-latest` to `ubuntu-latest`, the model matrix gains a `medium-aishell` entry, the fine-tuned checkpoint is downloaded before export, and the float32 ONNX files for this model plus the checkpoint itself are removed afterwards. The Hugging Face upload step also stops wiping the target repo for this model:

```diff
@@ -15,9 +15,9 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [macos-latest]
+        os: [ubuntu-latest]
         # model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "distil-large-v2"]
-        model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium"]
+        model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell"]
         python-version: ["3.8"]

     steps:
@@ -49,9 +49,19 @@ jobs:
         elif [[ $model == distil-small.en ]]; then
           wget -q -O distil-small-en-original-model.bin https://huggingface.co/distil-whisper/distil-small.en/resolve/main/original-model.bin
           ls -lh
+        elif [[ $model == medium-aishell ]]; then
+          wget -q -O medium-aishell.pt https://huggingface.co/yuekai/icefall_asr_aishell_whisper/resolve/main/exp_medium/whisper-medium-aishell1-epoch-10-avg-4.pt
+          ls -lh
         fi
         python3 ./export-onnx.py --model ${{ matrix.model }}
         # python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
+        #
+        if [[ $model == medium-aishell ]]; then
+          ls -lh *.onnx
+          rm -fv medium-aishell-encoder.onnx
+          rm -fv medium-aishell-decoder.onnx
+        fi
+

         ls -lh
@@ -59,6 +69,7 @@ jobs:
         ls -lh distil*original-model.bin || true
         rm -rf ~/.cache/whisper
         rm -f distil*original-model.bin
+        rm -f medium-aishell.pt

         src=sherpa-onnx-whisper-${{ matrix.model }}
@@ -132,7 +143,10 @@ jobs:
           git config --global user.name "Fangjun Kuang"

           GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
-          rm -rf huggingface/*
+
+          if [[ $model != medium-aishell ]]; then
+            rm -rf huggingface/*
+          fi

           if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
             mv $src.tar* ./huggingface
```
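To reproduce the download step outside CI, here is a minimal Python sketch; the URL and target filename are copied verbatim from the workflow, while using `urllib` instead of `wget` is an illustrative choice, not part of this PR:

```python
# Download the fine-tuned AIShell checkpoint the same way the workflow does.
# URL and filename are taken from the diff above; everything else is a sketch.
import urllib.request
from pathlib import Path

CKPT_URL = (
    "https://huggingface.co/yuekai/icefall_asr_aishell_whisper/resolve/main/"
    "exp_medium/whisper-medium-aishell1-epoch-10-avg-4.pt"
)

target = Path("medium-aishell.pt")
if not target.is_file():
    urllib.request.urlretrieve(CKPT_URL, str(target))
print(f"{target}: {target.stat().st_size / 2**20:.1f} MiB")
```

Note that after exporting, the workflow deletes the float32 `medium-aishell-encoder.onnx` and `medium-aishell-decoder.onnx` and keeps only the other exported files (presumably the int8 variants). The diff does not state the reason; file size is a plausible one.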
In `export-onnx.py`, `medium-aishell` becomes a valid `--model` choice, and `main()` learns to load the local checkpoint instead of downloading a stock OpenAI model:

```diff
@@ -44,7 +44,9 @@ def get_args():
             "tiny", "tiny.en", "base", "base.en",
             "small", "small.en", "medium", "medium.en",
             "large", "large-v1", "large-v2",
-            "distil-medium.en", "distil-small.en", "distil-large-v2"
+            "distil-medium.en", "distil-small.en", "distil-large-v2",
+            # for fine-tuned models from icefall
+            "medium-aishell",
         ],
         # fmt: on
     )
@@ -340,6 +342,19 @@ def main():
             """
         )
         model = whisper.load_model(filename)
+    elif name == "medium-aishell":
+        filename = "./medium-aishell.pt"
+        if not Path(filename).is_file():
+            raise ValueError(
+                """
+                Please go to https://huggingface.co/yuekai/icefall_asr_aishell_whisper/tree/main/exp_medium
+                to download whisper-medium-aishell1-epoch-10-avg-4.pt
+                You can use the following command to do that:
+
+                wget -O medium-aishell.pt https://huggingface.co/yuekai/icefall_asr_aishell_whisper/resolve/main/exp_medium/whisper-medium-aishell1-epoch-10-avg-4.pt
+                """
+            )
+        model = whisper.load_model(filename)
     else:
         model = whisper.load_model(name)
     print(model.dims)
```
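The new branch mirrors the existing distil-whisper handling: check that the checkpoint exists, then hand the path to `whisper.load_model()`, which accepts a filesystem path as well as an official model name. Stripped of the script's surroundings, the pattern is roughly this (a minimal sketch; it assumes the checkpoint is in a format `whisper.load_model()` accepts, which is what the PR relies on):

```python
from pathlib import Path

import whisper  # pip install openai-whisper

filename = "./medium-aishell.pt"
if not Path(filename).is_file():
    raise ValueError(f"{filename} not found; download it first, see the URL above")

# load_model() takes either an official model name or a checkpoint path;
# the latter is how the icefall fine-tuned model gets loaded.
model = whisper.load_model(filename)
print(model.dims)  # the same dims printout the exporter does
```

With the checkpoint in place, the export is the same command the CI matrix runs: `python3 ./export-onnx.py --model medium-aishell`.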
Finally, `compute_features()` pads more trailing frames onto the mel features. This also fixes an inconsistency: the old comment said 50 frames while the code actually padded 1000; both now read 1500:

```diff
@@ -257,9 +257,9 @@ def compute_features(filename: str) -> torch.Tensor:
     mel = (log_spec + 4.0) / 4.0
     # mel (T, 80)

-    # We pad 50 frames at the end so that it is able to detect eot
-    # You can use another value instead of 50.
-    mel = torch.nn.functional.pad(mel, (0, 0, 0, 1000), "constant", 0)
+    # We pad 1500 frames at the end so that it is able to detect eot
+    # You can use another value instead of 1500.
+    mel = torch.nn.functional.pad(mel, (0, 0, 0, 1500), "constant", 0)
     # Note that if it throws for a multilingual model,
     # please use a larger value, say 300
```
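The pad tuple in `torch.nn.functional.pad` reads from the last dimension backwards, which is easy to misparse; a small sketch of what `(0, 0, 0, 1500)` does to a `(T, 80)` mel (the shapes here are illustrative):

```python
import torch

mel = torch.randn(100, 80)  # (T, n_mels), as in compute_features()

# Pairs in the pad tuple apply to the last dimensions first:
# (0, 0, ...)    -> no padding on the 80 mel bins,
# (..., 0, 1500) -> 0 frames before and 1500 zero frames after, along time.
padded = torch.nn.functional.pad(mel, (0, 0, 0, 1500), "constant", 0)
print(padded.shape)  # torch.Size([1600, 80])
```

Per the comment in the diff, the trailing zero frames exist so that decoding can reach the end-of-transcript (eot) token rather than running out of features first.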