Fangjun Kuang
Committed by GitHub

Remove the 30-second constraint from whisper. (#471)

@@ -16,8 +16,12 @@ which $EXE
 names=(
 tiny.en
 base.en
-# small.en
-# medium.en
+small.en
+medium.en
+tiny
+base
+small
+medium
 )
 
 for name in ${names[@]}; do
@@ -33,8 +37,8 @@ for name in ${names[@]}; do
   GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
   pushd $repo
   git lfs pull --include "*.onnx"
-  git lfs pull --include "*.ort"
-  ls -lh *.{onnx,ort}
+  # git lfs pull --include "*.ort"
+  ls -lh *.onnx
   popd
 
   log "test fp32 onnx"
@@ -43,6 +47,7 @@ for name in ${names[@]}; do
     --tokens=$repo/${name}-tokens.txt \
     --whisper-encoder=$repo/${name}-encoder.onnx \
     --whisper-decoder=$repo/${name}-decoder.onnx \
+    --whisper-tail-paddings=500 \
     --num-threads=2 \
     $repo/test_wavs/0.wav \
     $repo/test_wavs/1.wav \
@@ -54,28 +59,7 @@ for name in ${names[@]}; do
     --tokens=$repo/${name}-tokens.txt \
     --whisper-encoder=$repo/${name}-encoder.int8.onnx \
     --whisper-decoder=$repo/${name}-decoder.int8.onnx \
-    --num-threads=2 \
-    $repo/test_wavs/0.wav \
-    $repo/test_wavs/1.wav \
-    $repo/test_wavs/8k.wav
-
-  log "test fp32 ort"
-
-  time $EXE \
-    --tokens=$repo/${name}-tokens.txt \
-    --whisper-encoder=$repo/${name}-encoder.ort \
-    --whisper-decoder=$repo/${name}-decoder.ort \
-    --num-threads=2 \
-    $repo/test_wavs/0.wav \
-    $repo/test_wavs/1.wav \
-    $repo/test_wavs/8k.wav
-
-  log "test int8 ort"
-
-  time $EXE \
-    --tokens=$repo/${name}-tokens.txt \
-    --whisper-encoder=$repo/${name}-encoder.int8.ort \
-    --whisper-decoder=$repo/${name}-decoder.int8.ort \
+    --whisper-tail-paddings=500 \
     --num-threads=2 \
     $repo/test_wavs/0.wav \
     $repo/test_wavs/1.wav \
@@ -15,7 +15,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [macos-latest]
+        os: [ubuntu-latest]
        model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
         python-version: ["3.8"]
 
@@ -44,7 +44,7 @@ jobs:
             ls -lh
           fi
           python3 ./export-onnx.py --model ${{ matrix.model }}
-          python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
+          # python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
 
           ls -lh
 
@@ -52,41 +52,61 @@ jobs:
             ls -lh ~/.cache/whisper
           fi
 
+          src=sherpa-onnx-whisper-${{ matrix.model }}
+
+          mkdir $src
+          cp *.onnx $src/
+          cp *tokens.txt $src
+
+          cd $src
+          mkdir -p test_wavs
+          cd test_wavs
+          wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
+          wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
+          wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
+          wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
+          cd ../..
+          mv $src ../..
+
+          cd ../..
+          echo "--------------------"
+          ls -lh
+          ls -lh $src
+          echo "--------------------"
+
+          tar cjvf ./$src.tar.bz2 $src
+
+      - name: Release
+        uses: svenstaro/upload-release-action@v2
+        with:
+          file_glob: true
+          file: ./*.tar.bz2
+          overwrite: true
+          repo_name: k2-fsa/sherpa-onnx
+          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+          tag: asr-models
+
       - name: Publish ${{ matrix.model }} to huggingface
         shell: bash
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
         run: |
-          model=${{ matrix.model }}
-
-          cd scripts/whisper
+          src=sherpa-onnx-whisper-${{ matrix.model }}
 
           git config --global user.email "csukuangfj@gmail.com"
           git config --global user.name "Fangjun Kuang"
 
           GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
+          rm -rf huggingface/*
 
-          cp *.onnx ./huggingface
-          cp *.ort ./huggingface
-          cp *tokens.txt ./huggingface
+          cp -av $src/* ./huggingface/
 
           cd huggingface
 
-          if [[ $model == distil-medium.en ]]; then
-            mkdir test_wavs
-            cd test_wavs
-            wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
-            wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
-            wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
-            wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
-            git add .
-            cd ..
-          fi
-
           git status
           ls -lh
           git lfs track "*.onnx"
-          git lfs track "*.ort"
+          # git lfs track "*.ort"
           git add .
           git commit -m "upload ${{ matrix.model }}"
           git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
@@ -107,6 +107,16 @@ jobs:
           name: release-static
           path: build/bin/*
 
+      - name: Test offline Whisper
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline
+
+          readelf -d build/bin/sherpa-onnx-offline
+
+          .github/scripts/test-offline-whisper.sh
+
       - name: Test online CTC
         shell: bash
         run: |
@@ -139,16 +149,6 @@ jobs:
 
           .github/scripts/test-online-paraformer.sh
 
-      - name: Test offline Whisper
-        shell: bash
-        run: |
-          export PATH=$PWD/build/bin:$PATH
-          export EXE=sherpa-onnx-offline
-
-          readelf -d build/bin/sherpa-onnx-offline
-
-          .github/scripts/test-offline-whisper.sh
-
       - name: Test offline transducer
         shell: bash
         run: |
@@ -93,13 +93,13 @@ jobs:
 
           .github/scripts/test-online-paraformer.sh
 
-      - name: Test offline Whisper for windows x86
-        shell: bash
-        run: |
-          export PATH=$PWD/build/bin/Release:$PATH
-          export EXE=sherpa-onnx-offline.exe
-
-          .github/scripts/test-offline-whisper.sh
+      # - name: Test offline Whisper for windows x86
+      #   shell: bash
+      #   run: |
+      #     export PATH=$PWD/build/bin/Release:$PATH
+      #     export EXE=sherpa-onnx-offline.exe
+      #
+      #     .github/scripts/test-offline-whisper.sh
 
       - name: Test offline CTC for windows x86
         shell: bash
@@ -8,6 +8,9 @@ https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py
 
 Thanks to https://github.com/TadaoYamaoka
 for making the onnx export script public.
+
+Note that we have removed the 30-second constraint from whisper. You can
+use any T <= 30 seconds.
 """
 
 import argparse
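
The docstring note above is the heart of this commit: the exported encoder no
longer insists on exactly 30 seconds (3000 mel frames) of input. As a minimal
sketch of the frame arithmetic involved (the constants are standard Whisper
values, not something introduced by this diff): mel frames come every 10 ms,
and the audio encoder's second conv layer has stride 2, so T seconds of audio
yields about 100 * T mel frames and 50 * T encoder positions, bounded by
n_audio_ctx = 1500.

    def whisper_frame_counts(t_seconds: float):
        # 16 kHz audio with a hop size of 160 samples -> 100 mel frames/second
        num_mel_frames = round(100 * t_seconds)
        # conv2 of the audio encoder has stride 2
        num_encoder_frames = num_mel_frames // 2
        # the positional embedding has n_audio_ctx = 1500 rows (30 seconds)
        assert num_encoder_frames <= 1500
        return num_mel_frames, num_encoder_frames

    print(whisper_frame_counts(5.0))   # (500, 250)
    print(whisper_frame_counts(30.0))  # (3000, 1500)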
@@ -17,6 +20,7 @@ from typing import Any, Dict, Optional
 
 import onnx
 import torch
+import torch.nn.functional as F
 from onnxruntime.quantization import QuantType, quantize_dynamic
 from torch import Tensor, nn
 
@@ -65,6 +69,39 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
     onnx.save(model, filename)
 
 
+def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
+    """
+    x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+        the mel spectrogram of the audio
+    """
+    x = F.gelu(self.conv1(x))
+    x = F.gelu(self.conv2(x))
+    x = x.permute(0, 2, 1)
+
+    if False:
+        # This branch contains the original code
+        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
+        x = (x + self.positional_embedding).to(x.dtype)
+    else:
+        # This branch contains the actual changes
+        assert (
+            x.shape[2] == self.positional_embedding.shape[1]
+        ), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
+        assert (
+            x.shape[1] <= self.positional_embedding.shape[0]
+        ), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
+        x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
+
+    for block in self.blocks:
+        x = block(x)
+
+    x = self.ln_post(x)
+    return x
+
+
+AudioEncoder.forward = modified_audio_encoder_forward
+
+
 class AudioEncoderTensorCache(nn.Module):
     def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
         super().__init__()
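
The crucial line in the patched forward is the prefix slice of the positional
embedding. A standalone sketch of just that step (the shapes below are
illustrative; 1500 and 384 are n_audio_ctx and n_audio_state of the tiny
model):

    import torch

    n_audio_ctx, n_audio_state = 1500, 384  # tiny model dims, for illustration
    positional_embedding = torch.randn(n_audio_ctx, n_audio_state)

    # 5.5 seconds of audio -> 550 mel frames -> 275 encoder frames
    x = torch.randn(1, 275, n_audio_state)
    assert x.shape[1] <= positional_embedding.shape[0]
    x = x + positional_embedding[: x.shape[1]]  # add only the first 275 rows
    print(x.shape)  # torch.Size([1, 275, 384])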
@@ -279,6 +316,7 @@ def main():
         model = whisper.load_model(filename)
     else:
         model = whisper.load_model(name)
+    print(model.dims)
 
     print(
         f"number of model parameters: {name}",
@@ -311,19 +349,20 @@ def main():
     assert mel.shape == (batch_size, 80, 30 * 100)
 
     encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
+
     n_layer_cross_k, n_layer_cross_v = encoder(mel)
     assert n_layer_cross_k.shape == (
         model.dims.n_text_layer,
         batch_size,
         model.dims.n_audio_ctx,
         model.dims.n_text_state,
-    ), n_layer_cross_k.shape
+    ), (n_layer_cross_k.shape, model.dims)
     assert n_layer_cross_v.shape == (
         model.dims.n_text_layer,
         batch_size,
         model.dims.n_audio_ctx,
         model.dims.n_text_state,
-    ), n_layer_cross_v.shape
+    ), (n_layer_cross_v.shape, model.dims)
 
     encoder_filename = f"{name}-encoder.onnx"
     torch.onnx.export(
@@ -334,9 +373,9 @@ def main():
         input_names=["mel"],
         output_names=["n_layer_cross_k", "n_layer_cross_v"],
         dynamic_axes={
-            "mel": {0: "n_audio"},  # n_audio is also known as batch_size
-            "n_layer_cross_k": {1: "n_audio"},
-            "n_layer_cross_v": {1: "n_audio"},
+            "mel": {0: "n_audio", 2: "T"},  # n_audio is also known as batch_size
+            "n_layer_cross_k": {1: "n_audio", 2: "T"},
+            "n_layer_cross_v": {1: "n_audio", 2: "T"},
         },
     )
 
@@ -461,8 +500,8 @@ def main():
             "tokens": {0: "n_audio", 1: "n_tokens"},
             "in_n_layer_self_k_cache": {1: "n_audio"},
             "in_n_layer_self_v_cache": {1: "n_audio"},
-            "n_layer_cross_k": {1: "n_audio"},
-            "n_layer_cross_v": {1: "n_audio"},
+            "n_layer_cross_k": {1: "n_audio", 2: "T"},
+            "n_layer_cross_v": {1: "n_audio", 2: "T"},
         },
     )
 
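
The added 2: "T" entries give the mel input and the cross-attention caches a
dynamic time axis in the exported graphs. A quick sanity check, assuming you
have already run export-onnx.py for tiny.en so that tiny.en-encoder.onnx
exists:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("tiny.en-encoder.onnx")
    for num_frames in (3000, 1000, 500):  # no longer has to be exactly 3000
        mel = np.zeros((1, 80, num_frames), dtype=np.float32)
        k, v = sess.run(None, {"mel": mel})
        print(num_frames, k.shape, v.shape)  # axis 2 equals num_frames // 2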
@@ -253,8 +253,21 @@ def compute_features(filename: str) -> torch.Tensor:
     log_spec = torch.clamp(features, min=1e-10).log10()
     log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
     mel = (log_spec + 4.0) / 4.0
+    # mel is of shape (T, 80)
+
+    # We pad 50 frames at the end so that whisper is able to detect the eot
+    # token. You can use another value instead of 50.
+    mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)
+    # Note that if it throws for a multilingual model,
+    # please use a larger value, say 300
+
     target = 3000
-    mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
+    if mel.shape[0] > target:
+        mel = mel[:target]
+
+    # We don't need to pad it to 30 seconds now!
+    # mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
+
     mel = mel.t().unsqueeze(0)
 
     return mel
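
A worked example of the new padding logic for a 5-second, 16 kHz wav (a
sketch, not part of the diff):

    import torch

    num_samples = 5 * 16000
    mel = torch.zeros(num_samples // 160, 80)  # 500 frames of 10 ms, 80 bins
    mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)
    assert mel.shape == (550, 80)  # it was always (3000, 80) before this commit

    target = 3000
    if mel.shape[0] > target:  # only inputs close to 30 seconds are truncated
        mel = mel[:target]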
@@ -115,7 +115,27 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
 
     NormalizeFeatures(f.data(), num_frames, feat_dim);
 
-    std::array<int64_t, 3> shape{1, max_num_frames, feat_dim};
+    // Note that 50 is an empirical value.
+    // See also ../../scripts/whisper/test.py
+    //
+    // You can replace 50 by other values, say, 100.
+    //
+    // Since we have removed the 30-second constraint, we need
+    // tail_padding_frames so that whisper is able to detect the eot token.
+    int32_t tail_padding_frames = 50;
+    if (model_->IsMultiLingual()) {
+      // 300 is an empirical value. If it throws, please use a larger value.
+      tail_padding_frames = 300;
+    }
+
+    if (config_.model_config.whisper.tail_paddings > 0) {
+      tail_padding_frames = config_.model_config.whisper.tail_paddings;
+    }
+
+    int32_t actual_frames =
+        std::min(num_frames + tail_padding_frames, max_num_frames);
+
+    std::array<int64_t, 3> shape{1, actual_frames, feat_dim};
 
     Ort::Value mel = Ort::Value::CreateTensor<float>(
         model_->Allocator(), shape.data(), shape.size());
@@ -123,7 +143,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
     std::copy(f.begin(), f.end(), p_mel);
 
     memset(p_mel + f.size(), 0,
-           (max_num_frames - num_frames) * feat_dim * sizeof(float));
+           (actual_frames - num_frames) * feat_dim * sizeof(float));
     mel = Transpose12(model_->Allocator(), &mel);
 
     try {
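
The precedence in the C++ above is: an explicit --whisper-tail-paddings value
wins, otherwise the default depends on whether the model is multilingual, and
the total is always clamped to the model's maximum. The same logic restated as
a Python sketch (the function and argument names are invented for
illustration):

    def resolve_num_frames(num_frames: int, is_multilingual: bool,
                           tail_paddings: int = -1,
                           max_num_frames: int = 3000) -> int:
        # empirical defaults: 50 frames for English models, 300 for multilingual
        tail_padding_frames = 300 if is_multilingual else 50
        if tail_paddings > 0:  # user override via --whisper-tail-paddings
            tail_padding_frames = tail_paddings
        # never exceed what the positional embedding can cover
        return min(num_frames + tail_padding_frames, max_num_frames)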
@@ -32,6 +32,14 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
       "Valid values: transcribe, translate. "
       "Note that for non-multilingual models, it supports "
       "only 'transcribe'");
+
+  po->Register(
+      "whisper-tail-paddings", &tail_paddings,
+      "Suggested values: 50 for English models, 300 for multilingual models. "
+      "Since we have removed the 30-second constraint, we need to add some "
+      "tail padding frames "
+      "so that whisper can detect the eot token. Leave it at -1 to use 50 for "
+      "English models and 300 for multilingual models.");
 }
 
 bool OfflineWhisperModelConfig::Validate() const {
@@ -63,7 +71,8 @@ std::string OfflineWhisperModelConfig::ToString() const {
   os << "encoder=\"" << encoder << "\", ";
   os << "decoder=\"" << decoder << "\", ";
   os << "language=\"" << language << "\", ";
-  os << "task=\"" << task << "\")";
+  os << "task=\"" << task << "\", ";
+  os << "tail_paddings=" << tail_paddings << ")";
 
   return os.str();
 }
@@ -28,12 +28,26 @@ struct OfflineWhisperModelConfig {
   // Note: For non-multilingual models, it supports only "transcribe"
   std::string task = "transcribe";
 
+  // Number of tail padding frames.
+  //
+  // Since we have removed the 30-second constraint, we need to add some
+  // padding at the end.
+  //
+  // Recommended values:
+  //   - 50 for English models
+  //   - 300 for multilingual models
+  int32_t tail_paddings = -1;
+
   OfflineWhisperModelConfig() = default;
   OfflineWhisperModelConfig(const std::string &encoder,
                             const std::string &decoder,
                             const std::string &language,
-                            const std::string &task)
-      : encoder(encoder), decoder(decoder), language(language), task(task) {}
+                            const std::string &task, int32_t tail_paddings)
+      : encoder(encoder),
+        decoder(decoder),
+        language(language),
+        task(task),
+        tail_paddings(tail_paddings) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
@@ -15,13 +15,14 @@ void PybindOfflineWhisperModelConfig(py::module *m) {
   using PyClass = OfflineWhisperModelConfig;
   py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
       .def(py::init<const std::string &, const std::string &,
-                    const std::string &, const std::string &>(),
+                    const std::string &, const std::string &, int32_t>(),
           py::arg("encoder"), py::arg("decoder"), py::arg("language"),
-           py::arg("task"))
+           py::arg("task"), py::arg("tail_paddings") = -1)
       .def_readwrite("encoder", &PyClass::encoder)
       .def_readwrite("decoder", &PyClass::decoder)
       .def_readwrite("language", &PyClass::language)
       .def_readwrite("task", &PyClass::task)
+      .def_readwrite("tail_paddings", &PyClass::tail_paddings)
       .def("__str__", &PyClass::ToString);
 }
 
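
With this binding in place, tail_paddings is settable from Python as well. A
usage sketch, assuming OfflineWhisperModelConfig is re-exported at the package
top level like the other offline configs (the model file names are
placeholders):

    import sherpa_onnx

    config = sherpa_onnx.OfflineWhisperModelConfig(
        encoder="tiny.en-encoder.onnx",
        decoder="tiny.en-decoder.onnx",
        language="",
        task="transcribe",
        tail_paddings=-1,  # -1: 50 for English models, 300 for multilingual
    )
    print(config)  # ToString() now includes tail_paddings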