Committed by
GitHub
Remove whisper dependency from the whisper Python example (#283)
正在显示 1 个修改的文件，包含 25 行增加和 11 行删除
| @@ -4,15 +4,14 @@ | @@ -4,15 +4,14 @@ | ||
| 4 | Please first run ./export-onnx.py | 4 | Please first run ./export-onnx.py |
| 5 | before you run this script | 5 | before you run this script |
| 6 | """ | 6 | """ |
| 7 | +import argparse | ||
| 7 | import base64 | 8 | import base64 |
| 8 | from typing import Tuple | 9 | from typing import Tuple |
| 9 | 10 | ||
| 10 | import kaldi_native_fbank as knf | 11 | import kaldi_native_fbank as knf |
| 11 | import onnxruntime as ort | 12 | import onnxruntime as ort |
| 12 | import torch | 13 | import torch |
| 13 | - | ||
| 14 | -import whisper | ||
| 15 | -import argparse | 14 | +import torchaudio |
| 16 | 15 | ||
| 17 | 16 | ||
| 18 | def get_args(): | 17 | def get_args(): |
| @@ -225,16 +224,24 @@ def load_tokens(filename): | @@ -225,16 +224,24 @@ def load_tokens(filename): | ||
| 225 | return tokens | 224 | return tokens |
| 226 | 225 | ||
| 227 | 226 | ||
| 228 | -def main(): | ||
| 229 | - args = get_args() | ||
| 230 | - encoder = args.encoder | ||
| 231 | - decoder = args.decoder | ||
| 232 | - | ||
| 233 | - audio = whisper.load_audio(args.sound_file) | 227 | +def compute_features(filename: str) -> torch.Tensor: |
| 228 | + """ | ||
| 229 | + Args: | ||
| 230 | + filename: | ||
| 231 | + Path to an audio file. | ||
| 232 | + Returns: | ||
| 233 | Return a 3-D float32 tensor of shape (1, 80, 3000) containing the features. | ||
| 234 | + """ | ||
| 235 | + wave, sample_rate = torchaudio.load(filename) | ||
| 236 | + audio = wave[0].contiguous() # only use the first channel | ||
| 237 | + if sample_rate != 16000: | ||
| 238 | + audio = torchaudio.functional.resample( | ||
| 239 | + audio, orig_freq=sample_rate, new_freq=16000 | ||
| 240 | + ) | ||
| 234 | 241 | ||
| 235 | features = [] | 242 | features = [] |
| 236 | online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions()) | 243 | online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions()) |
| 237 | - online_whisper_fbank.accept_waveform(16000, audio) | 244 | + online_whisper_fbank.accept_waveform(16000, audio.numpy()) |
| 238 | online_whisper_fbank.input_finished() | 245 | online_whisper_fbank.input_finished() |
| 239 | for i in range(online_whisper_fbank.num_frames_ready): | 246 | for i in range(online_whisper_fbank.num_frames_ready): |
| 240 | f = online_whisper_fbank.get_frame(i) | 247 | f = online_whisper_fbank.get_frame(i) |
| @@ -250,7 +257,14 @@ def main(): | @@ -250,7 +257,14 @@ def main(): | ||
| 250 | mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0) | 257 | mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0) |
| 251 | mel = mel.t().unsqueeze(0) | 258 | mel = mel.t().unsqueeze(0) |
| 252 | 259 | ||
| 253 | - model = OnnxModel(encoder, decoder) | 260 | + return mel |
| 261 | + | ||
| 262 | + | ||
| 263 | +def main(): | ||
| 264 | + args = get_args() | ||
| 265 | + | ||
| 266 | + mel = compute_features(args.sound_file) | ||
| 267 | + model = OnnxModel(args.encoder, args.decoder) | ||
| 254 | 268 | ||
| 255 | n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel) | 269 | n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel) |
| 256 | 270 |
-
请 注册 或 登录 后发表评论