Fangjun Kuang
Committed by GitHub

Support distil-small.en whisper (#472)

@@ -22,6 +22,8 @@ tiny @@ -22,6 +22,8 @@ tiny
22 base 22 base
23 small 23 small
24 medium 24 medium
  25 +distil-medium.en
  26 +distil-small.en
25 ) 27 )
26 28
27 for name in ${names[@]}; do 29 for name in ${names[@]}; do
@@ -15,8 +15,9 @@ jobs: @@ -15,8 +15,9 @@ jobs:
15 strategy: 15 strategy:
16 fail-fast: false 16 fail-fast: false
17 matrix: 17 matrix:
18 - os: [ubuntu-latest]  
19 - model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"] 18 + os: [macos-latest]
  19 + # model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "distil-large-v2"]
  20 + model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium"]
20 python-version: ["3.8"] 21 python-version: ["3.8"]
21 22
22 steps: 23 steps:
@@ -42,23 +43,33 @@ jobs: @@ -42,23 +43,33 @@ jobs:
42 if [[ $model == distil-medium.en ]]; then 43 if [[ $model == distil-medium.en ]]; then
43 wget -q -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin 44 wget -q -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
44 ls -lh 45 ls -lh
  46 + elif [[ $model == distil-large-v2 ]]; then
  47 + wget -q -O distil-large-v2-original-model.bin https://huggingface.co/distil-whisper/distil-large-v2/resolve/main/original-model.bin
  48 + ls -lh
  49 + elif [[ $model == distil-small.en ]]; then
  50 + wget -q -O distil-small-en-original-model.bin https://huggingface.co/distil-whisper/distil-small.en/resolve/main/original-model.bin
  51 + ls -lh
45 fi 52 fi
46 python3 ./export-onnx.py --model ${{ matrix.model }} 53 python3 ./export-onnx.py --model ${{ matrix.model }}
47 # python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./ 54 # python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
48 55
49 ls -lh 56 ls -lh
50 57
51 - if [[ $model != distil-medium.en ]]; then  
52 - ls -lh ~/.cache/whisper  
53 - fi 58 + ls -lh ~/.cache/whisper || true
  59 + ls -lh distil*original-model.bin || true
  60 + rm -rf ~/.cache/whisper
  61 + rm -f distil*original-model.bin
54 62
55 src=sherpa-onnx-whisper-${{ matrix.model }} 63 src=sherpa-onnx-whisper-${{ matrix.model }}
56 64
57 - mkdir $src  
58 - cp *.onnx $src/  
59 - cp *tokens.txt $src 65 + cd ..
  66 + mv whisper $src
  67 +
  68 + echo "------------------------------"
60 69
61 cd $src 70 cd $src
  71 + du -h -d1 .
  72 + ls -lh
62 mkdir -p test_wavs 73 mkdir -p test_wavs
63 cd test_wavs 74 cd test_wavs
64 wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav 75 wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
@@ -66,21 +77,32 @@ jobs: @@ -66,21 +77,32 @@ jobs:
66 wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav 77 wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
67 wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt 78 wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
68 cd ../.. 79 cd ../..
69 - mv $src ../.. 80 + mv $src ../
  81 + echo "pwd: $PWD"
70 82
71 - cd ../.. 83 + cd ../
72 echo "--------------------" 84 echo "--------------------"
73 ls -lh 85 ls -lh
74 ls -lh $src 86 ls -lh $src
75 echo "--------------------" 87 echo "--------------------"
76 88
77 - tar cjvf ./$src.tar.bz2 $src 89 + if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
  90 + #tar cvjf - $src | split --bytes=1024MB - $src.tar.bz2.
  91 + tar cvjf $src.tar.bz2 $src
  92 + split -b 1G $src.tar.bz2 $src.tar.bz2.
  93 + rm $src.tar.bz2
  94 + # cat $src.tar.gz.* | tar xjf -
  95 + else
  96 + tar cvjf $src.tar.bz2 $src
  97 + fi
  98 + ls -lh
  99 +
78 100
79 - name: Release 101 - name: Release
80 uses: svenstaro/upload-release-action@v2 102 uses: svenstaro/upload-release-action@v2
81 with: 103 with:
82 file_glob: true 104 file_glob: true
83 - file: ./*.tar.bz2 105 + file: ./*.tar*
84 overwrite: true 106 overwrite: true
85 repo_name: k2-fsa/sherpa-onnx 107 repo_name: k2-fsa/sherpa-onnx
86 repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} 108 repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
@@ -99,14 +121,21 @@ jobs: @@ -99,14 +121,21 @@ jobs:
99 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface 121 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
100 rm -rf huggingface/* 122 rm -rf huggingface/*
101 123
102 - cp -av $src/* ./huggingface/ 124 + if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
  125 + mv $src.tar* ./huggingface
  126 + else
  127 + cp -v $src/*.onnx ./huggingface
  128 + cp -v $src/*tokens* ./huggingface
  129 + cp -av $src/test_wavs ./huggingface
  130 + fi
103 131
104 cd huggingface 132 cd huggingface
105 133
106 git status 134 git status
107 ls -lh 135 ls -lh
108 - git lfs track "*.onnx"  
109 - # git lfs track "*.ort" 136 + git lfs track "*gz*"
  137 + git lfs track "*onnx*"
  138 +
110 git add . 139 git add .
111 git commit -m "upload ${{ matrix.model }}" 140 git commit -m "upload ${{ matrix.model }}"
112 git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main 141 git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
@@ -90,7 +90,7 @@ jobs: @@ -90,7 +90,7 @@ jobs:
90 ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav 90 ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav
91 91
92 - name: Start server for paraformer models 92 - name: Start server for paraformer models
93 - if: matrix.model_type == 'paraformer' 93 + if: matrix.model_type == 'paraformer' && matrix.os != 'windows-latest'
94 shell: bash 94 shell: bash
95 run: | 95 run: |
96 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en 96 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en
@@ -106,7 +106,7 @@ jobs: @@ -106,7 +106,7 @@ jobs:
106 sleep 10 106 sleep 10
107 107
108 - name: Start client for paraformer models 108 - name: Start client for paraformer models
109 - if: matrix.model_type == 'paraformer' 109 + if: matrix.model_type == 'paraformer' && matrix.os != 'windows-latest'
110 shell: bash 110 shell: bash
111 run: | 111 run: |
112 python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \ 112 python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR) 1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
2 project(sherpa-onnx) 2 project(sherpa-onnx)
3 3
4 -set(SHERPA_ONNX_VERSION "1.9.0") 4 +set(SHERPA_ONNX_VERSION "1.9.1")
5 5
6 # Disable warning about 6 # Disable warning about
7 # 7 #
@@ -44,7 +44,7 @@ def get_args(): @@ -44,7 +44,7 @@ def get_args():
44 "tiny", "tiny.en", "base", "base.en", 44 "tiny", "tiny.en", "base", "base.en",
45 "small", "small.en", "medium", "medium.en", 45 "small", "small.en", "medium", "medium.en",
46 "large", "large-v1", "large-v2", 46 "large", "large-v1", "large-v2",
47 - "distil-medium.en", 47 + "distil-medium.en", "distil-small.en", "distil-large-v2"
48 ], 48 ],
49 # fmt: on 49 # fmt: on
50 ) 50 )
@@ -314,6 +314,32 @@ def main(): @@ -314,6 +314,32 @@ def main():
314 """ 314 """
315 ) 315 )
316 model = whisper.load_model(filename) 316 model = whisper.load_model(filename)
  317 + elif name == "distil-large-v2":
  318 + filename = "./distil-large-v2-original-model.bin"
  319 + if not Path(filename).is_file():
  320 + raise ValueError(
  321 + """
  322 + Please go to https://huggingface.co/distil-whisper/distil-large-v2
  323 + to download original-model.bin
  324 + You can use the following command to do that:
  325 +
  326 + wget -O distil-large-v2-original-model.bin https://huggingface.co/distil-whisper/distil-large-v2/resolve/main/original-model.bin
  327 + """
  328 + )
  329 + model = whisper.load_model(filename)
  330 + elif name == "distil-small.en":
  331 + filename = "./distil-small-en-original-model.bin"
  332 + if not Path(filename).is_file():
  333 + raise ValueError(
  334 + """
  335 + Please go to https://huggingface.co/distil-whisper/distil-small.en
  336 + to download original-model.bin
  337 + You can use the following command to do that:
  338 +
  339 + wget -O distil-small-en-original-model.bin https://huggingface.co/distil-whisper/distil-small.en/resolve/main/original-model.bin
  340 + """
  341 + )
  342 + model = whisper.load_model(filename)
317 else: 343 else:
318 model = whisper.load_model(name) 344 model = whisper.load_model(name)
319 print(model.dims) 345 print(model.dims)
@@ -209,7 +209,7 @@ class OnnxModel: @@ -209,7 +209,7 @@ class OnnxModel:
209 logits = logits.reshape(-1) 209 logits = logits.reshape(-1)
210 mask = torch.ones(logits.shape[0], dtype=torch.int64) 210 mask = torch.ones(logits.shape[0], dtype=torch.int64)
211 mask[self.all_language_tokens] = 0 211 mask[self.all_language_tokens] = 0
212 - logits[mask] = float("-inf") 212 + logits[mask != 0] = float("-inf")
213 lang_id = logits.argmax().item() 213 lang_id = logits.argmax().item()
214 print("detected language: ", self.id2lang[lang_id]) 214 print("detected language: ", self.id2lang[lang_id])
215 return lang_id 215 return lang_id
@@ -263,7 +263,9 @@ def compute_features(filename: str) -> torch.Tensor: @@ -263,7 +263,9 @@ def compute_features(filename: str) -> torch.Tensor:
263 263
264 target = 3000 264 target = 3000
265 if mel.shape[0] > target: 265 if mel.shape[0] > target:
266 - mel = mel[:target] 266 + # -50 so that there are some zero tail paddings.
  267 + mel = mel[: target - 50]
  268 + mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)
267 269
268 # We don't need to pad it to 30 seconds now! 270 # We don't need to pad it to 30 seconds now!
269 # mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0) 271 # mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
@@ -106,11 +106,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { @@ -106,11 +106,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
106 std::vector<float> f = s->GetFrames(); 106 std::vector<float> f = s->GetFrames();
107 int32_t num_frames = f.size() / feat_dim; 107 int32_t num_frames = f.size() / feat_dim;
108 108
109 - if (num_frames > max_num_frames) { 109 + // we use 50 here so that there will be some zero tail paddings
  110 + if (num_frames >= max_num_frames - 50) {
110 SHERPA_ONNX_LOGE( 111 SHERPA_ONNX_LOGE(
111 "Only waves less than 30 seconds are supported. We process only the " 112 "Only waves less than 30 seconds are supported. We process only the "
112 "first 30 seconds and discard the remaining data"); 113 "first 30 seconds and discard the remaining data");
113 - num_frames = max_num_frames; 114 + num_frames = max_num_frames - 50;
114 } 115 }
115 116
116 NormalizeFeatures(f.data(), num_frames, feat_dim); 117 NormalizeFeatures(f.data(), num_frames, feat_dim);
@@ -140,7 +141,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { @@ -140,7 +141,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
140 Ort::Value mel = Ort::Value::CreateTensor<float>( 141 Ort::Value mel = Ort::Value::CreateTensor<float>(
141 model_->Allocator(), shape.data(), shape.size()); 142 model_->Allocator(), shape.data(), shape.size());
142 float *p_mel = mel.GetTensorMutableData<float>(); 143 float *p_mel = mel.GetTensorMutableData<float>();
143 - std::copy(f.begin(), f.end(), p_mel); 144 + std::copy(f.data(), f.data() + actual_frames * feat_dim, p_mel);
144 145
145 memset(p_mel + f.size(), 0, 146 memset(p_mel + f.size(), 0,
146 (actual_frames - num_frames) * feat_dim * sizeof(float)); 147 (actual_frames - num_frames) * feat_dim * sizeof(float));