Fangjun Kuang
Committed by GitHub

Remove whisper dependency from the whisper Python example (#283)

@@ -4,15 +4,14 @@
 Please first run ./export-onnx.py
 before you run this script
 """
+import argparse
 import base64
 from typing import Tuple
 
 import kaldi_native_fbank as knf
 import onnxruntime as ort
 import torch
-
-import whisper
-import argparse
+import torchaudio
 
 
 def get_args():
@@ -225,16 +224,24 @@ def load_tokens(filename):
     return tokens
 
 
-def main():
-    args = get_args()
-    encoder = args.encoder
-    decoder = args.decoder
-
-    audio = whisper.load_audio(args.sound_file)
+def compute_features(filename: str) -> torch.Tensor:
+    """
+    Args:
+      filename:
+        Path to an audio file.
+    Returns:
+      Return a 3-D float32 tensor of shape (1, 80, 3000) containing the features.
+    """
+    wave, sample_rate = torchaudio.load(filename)
+    audio = wave[0].contiguous()  # only use the first channel
+    if sample_rate != 16000:
+        audio = torchaudio.functional.resample(
+            audio, orig_freq=sample_rate, new_freq=16000
+        )
 
     features = []
     online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
-    online_whisper_fbank.accept_waveform(16000, audio)
+    online_whisper_fbank.accept_waveform(16000, audio.numpy())
     online_whisper_fbank.input_finished()
     for i in range(online_whisper_fbank.num_frames_ready):
         f = online_whisper_fbank.get_frame(i)
@@ -250,7 +257,14 @@ def main():
         mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
     mel = mel.t().unsqueeze(0)
 
-    model = OnnxModel(encoder, decoder)
+    return mel
+
+
+def main():
+    args = get_args()
+
+    mel = compute_features(args.sound_file)
+    model = OnnxModel(args.encoder, args.decoder)
 
     n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
 