Fangjun Kuang
Committed by GitHub

Support whisper models (#238)

Showing 39 changed files with 1,835 additions and 51 deletions
  1 +name: export-whisper-to-onnx
  2 +
  3 +on:
  4 + workflow_dispatch:
  5 +
  6 +concurrency:
  7 + group: release-whisper-${{ github.ref }}
  8 + cancel-in-progress: true
  9 +
  10 +jobs:
  11 + release-whisper-models:
  12 + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
  13 + name: ${{ matrix.model }}
  14 + runs-on: ${{ matrix.os }}
  15 + strategy:
  16 + fail-fast: false
  17 + matrix:
  18 + os: [macos-latest]
  19 + model: ["tiny.en", "base.en", "small.en", "medium.en"]
  20 +
  21 + steps:
  22 + - uses: actions/checkout@v2
  23 +
  24 + - name: Install dependencies
  25 + shell: bash
  26 + run: |
  27 + python3 -m pip install openai-whisper torch onnxruntime onnx
  28 +
  29 + - name: export ${{ matrix.model }}
  30 + shell: bash
  31 + run: |
  32 + cd scripts/whisper
  33 + python3 ./export-onnx.py --model ${{ matrix.model }}
  34 + python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
  35 +
  36 + ls -lh
  37 +
  38 + ls -lh ~/.cache/whisper
  39 +
  40 + - name: Publish ${{ matrix.model }} to huggingface
  41 + shell: bash
  42 + env:
  43 + HF_TOKEN: ${{ secrets.HF_TOKEN }}
  44 + run: |
  45 + cd scripts/whisper
  46 +
  47 + git config --global user.email "csukuangfj@gmail.com"
  48 + git config --global user.name "Fangjun Kuang"
  49 +
  50 + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
  51 +
  52 + cp *.onnx ./huggingface
  53 + cp *.ort ./huggingface
  54 + cp *tokens.txt ./huggingface
  55 +
  56 + cd huggingface
  57 + git status
  58 + ls -lh
  59 + git lfs track "*.onnx"
  60 + git lfs track "*.ort"
  61 + git add .
  62 + git commit -m "upload ${{ matrix.model }}"
  63 + git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
@@ -23,14 +23,14 @@ on:
       - 'sherpa-onnx/jni/*'
 
 concurrency:
-  group: jni-${{ github.ref }}
+  group: run-java-test-${{ github.ref }}
   cancel-in-progress: true
 
 permissions:
   contents: read
 
 jobs:
-  jni:
+  run_java_test:
     runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false
@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)
 
-set(SHERPA_ONNX_VERSION "1.5.5")
+set(SHERPA_ONNX_VERSION "1.6.0")
 
 # Disable warning about
 #
@@ -1,9 +1,9 @@
 function(download_kaldi_native_fbank)
   include(FetchContent)
 
-  set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.17.tar.gz")
-  set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.17.tar.gz")
-  set(kaldi_native_fbank_HASH "SHA256=300dc282d51d738e70f194ef13a50bf4cf8d54a3b2686d75f7fc2fb821f8c1e6")
+  set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.1.tar.gz")
+  set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.1.tar.gz")
+  set(kaldi_native_fbank_HASH "SHA256=c7676f319fa97e8c8bca6018792de120895dcfe122fa9b4bff00f8f9165348e7")
 
   set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
   set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
   # If you don't have access to the Internet,
   # please pre-download kaldi-native-fbank
   set(possible_file_locations
-    $ENV{HOME}/Downloads/kaldi-native-fbank-1.17.tar.gz
-    ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.17.tar.gz
-    ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.17.tar.gz
-    /tmp/kaldi-native-fbank-1.17.tar.gz
-    /star-fj/fangjun/download/github/kaldi-native-fbank-1.17.tar.gz
+    $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.1.tar.gz
+    ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.1.tar.gz
+    ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.1.tar.gz
+    /tmp/kaldi-native-fbank-1.18.1.tar.gz
+    /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.1.tar.gz
   )
 
   foreach(f IN LISTS possible_file_locations)
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
 # Copyright (c) 2023 by manyeyes
+# Copyright (c) 2023 Xiaomi Corporation
 
 """
 This file demonstrates how to use sherpa-onnx Python API to transcribe
@@ -34,6 +35,27 @@ file(s) with a non-streaming model.
 
 (3) For CTC models from NeMo
 
+python3 ./python-api-examples/offline-decode-files.py \
+  --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
+  --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
+  --num-threads=2 \
+  --decoding-method=greedy_search \
+  --debug=false \
+  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
+  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
+  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
+
+(4) For Whisper models
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
+  --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
+  --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+  --num-threads=1 \
+  ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
+  ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
+  ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
+
 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/index.html
 to install sherpa-onnx and to download the pre-trained models
@@ -145,6 +167,20 @@ def get_args():
     )
 
     parser.add_argument(
+        "--whisper-encoder",
+        default="",
+        type=str,
+        help="Path to whisper encoder model",
+    )
+
+    parser.add_argument(
+        "--whisper-decoder",
+        default="",
+        type=str,
+        help="Path to whisper decoder model",
+    )
+
+    parser.add_argument(
         "--decoding-method",
         type=str,
         default="greedy_search",
@@ -247,6 +283,8 @@ def main():
     if args.encoder:
         assert len(args.paraformer) == 0, args.paraformer
         assert len(args.nemo_ctc) == 0, args.nemo_ctc
+        assert len(args.whisper_encoder) == 0, args.whisper_encoder
+        assert len(args.whisper_decoder) == 0, args.whisper_decoder
 
     contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
     if contexts:
@@ -271,6 +309,9 @@ def main():
         )
     elif args.paraformer:
         assert len(args.nemo_ctc) == 0, args.nemo_ctc
+        assert len(args.whisper_encoder) == 0, args.whisper_encoder
+        assert len(args.whisper_decoder) == 0, args.whisper_decoder
+
         assert_file_exists(args.paraformer)
 
         recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
@@ -283,6 +324,11 @@ def main():
             debug=args.debug,
         )
     elif args.nemo_ctc:
+        assert len(args.whisper_encoder) == 0, args.whisper_encoder
+        assert len(args.whisper_decoder) == 0, args.whisper_decoder
+
+        assert_file_exists(args.nemo_ctc)
+
         recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
             model=args.nemo_ctc,
             tokens=args.tokens,
@@ -292,6 +338,18 @@ def main():
             decoding_method=args.decoding_method,
             debug=args.debug,
         )
+    elif args.whisper_encoder:
+        assert_file_exists(args.whisper_encoder)
+        assert_file_exists(args.whisper_decoder)
+
+        recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
+            encoder=args.whisper_encoder,
+            decoder=args.whisper_decoder,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            decoding_method=args.decoding_method,
+            debug=args.debug,
+        )
     else:
         print("Please specify at least one model")
         return
  1 +*.onnx
  2 +*.config
  3 +*.ort
  4 +*-tokens.txt
  1 +# Introduction
  2 +
  3 +This folder contains code showing how to convert [Whisper][whisper] to onnx
  4 +and use onnxruntime to replace PyTorch for speech recognition.
  5 +
  6 +You can use [sherpa-onnx][sherpa-onnx] to run the converted model.
  7 +
  8 +[whisper]: https://github.com/openai/whisper
  9 +[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
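
For reference, decoding a wave file with a converted model through the Python API added in this PR looks roughly like the sketch below. It is not part of this commit: the model paths are placeholders taken from the docstring example in python-api-examples/offline-decode-files.py, the wave-reading helper is an assumption, and the stream calls follow the usual sherpa-onnx offline API used by that example.

```python
# Sketch only: decode one 16 kHz mono wave with a converted whisper model.
# Paths below are placeholders; adjust them to where the models were unpacked.
import wave

import numpy as np
import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
    encoder="./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx",
    decoder="./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx",
    tokens="./sherpa-onnx-whisper-base.en/base.en-tokens.txt",
    num_threads=1,
    decoding_method="greedy_search",
)

with wave.open("./sherpa-onnx-whisper-base.en/test_wavs/0.wav") as f:
    assert f.getframerate() == 16000, f.getframerate()
    assert f.getnchannels() == 1, f.getnchannels()
    # Convert int16 PCM to float32 in [-1, 1], as the recognizer expects.
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768

stream = recognizer.create_stream()
stream.accept_waveform(16000, samples)
recognizer.decode_stream(stream)
print(stream.result.text)
```

Note that the whisper recognizer in this PR rejects waves longer than 30 seconds, so keep test clips short.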
  1 +#!/usr/bin/env python3
  2 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +# flake8: noqa
  4 +
  5 +"""
  6 +Note: Code in this file is modified from
  7 +https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py
  8 +
  9 +Thanks to https://github.com/TadaoYamaoka
  10 +for making the onnx export script public.
  11 +"""
  12 +
  13 +import argparse
  14 +from pathlib import Path
  15 +from typing import Any, Dict, Optional
  16 +
  17 +import onnx
  18 +import torch
  19 +from onnxruntime.quantization import QuantType, quantize_dynamic
  20 +from torch import Tensor, nn
  21 +
  22 +import whisper
  23 +from whisper.model import (
  24 + AudioEncoder,
  25 + MultiHeadAttention,
  26 + ResidualAttentionBlock,
  27 + TextDecoder,
  28 +)
  29 +
  30 +
  31 +def get_args():
  32 + parser = argparse.ArgumentParser()
  33 + parser.add_argument(
  34 + "--model",
  35 + type=str,
  36 + required=True,
  37 + # fmt: off
  38 + choices=[
  39 + "tiny", "tiny.en", "base", "base.en",
  40 + "small", "small.en", "medium", "medium.en",
  41 + "large", "large-v1", "large-v2"],
  42 + # fmt: on
  43 + )
  44 + return parser.parse_args()
  45 +
  46 +
  47 +def add_meta_data(filename: str, meta_data: Dict[str, Any]):
  48 + """Add meta data to an ONNX model. It is changed in-place.
  49 +
  50 + Args:
  51 + filename:
  52 + Filename of the ONNX model to be changed.
  53 + meta_data:
  54 + Key-value pairs.
  55 + """
  56 + model = onnx.load(filename)
  57 + for key, value in meta_data.items():
  58 + meta = model.metadata_props.add()
  59 + meta.key = key
  60 + meta.value = str(value)
  61 +
  62 + onnx.save(model, filename)
  63 +
  64 +
  65 +class AudioEncoderTensorCache(nn.Module):
  66 + def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
  67 + super().__init__()
  68 + self.audioEncoder = inAudioEncoder
  69 + self.textDecoder = inTextDecoder
  70 +
  71 + def forward(self, x: Tensor):
  72 + audio_features = self.audioEncoder(x)
  73 +
  74 + n_layer_cross_k_list = []
  75 + n_layer_cross_v_list = []
  76 + for block in self.textDecoder.blocks:
  77 + n_layer_cross_k_list.append(block.cross_attn.key(audio_features))
  78 + n_layer_cross_v_list.append(block.cross_attn.value(audio_features))
  79 +
  80 + return torch.stack(n_layer_cross_k_list), torch.stack(n_layer_cross_v_list)
  81 +
  82 +
  83 +class MultiHeadAttentionCross(nn.Module):
  84 + def __init__(self, inMultiHeadAttention: MultiHeadAttention):
  85 + super().__init__()
  86 + self.multiHeadAttention = inMultiHeadAttention
  87 +
  88 + def forward(
  89 + self,
  90 + x: Tensor,
  91 + k: Tensor,
  92 + v: Tensor,
  93 + mask: Optional[Tensor] = None,
  94 + ):
  95 + q = self.multiHeadAttention.query(x)
  96 + wv, qk = self.multiHeadAttention.qkv_attention(q, k, v, mask)
  97 + return self.multiHeadAttention.out(wv)
  98 +
  99 +
  100 +class MultiHeadAttentionSelf(nn.Module):
  101 + def __init__(self, inMultiHeadAttention: MultiHeadAttention):
  102 + super().__init__()
  103 + self.multiHeadAttention = inMultiHeadAttention
  104 +
  105 + def forward(
  106 + self,
  107 + x: Tensor, # (b, n_ctx , n_state)
  108 + k_cache: Tensor, # (b, n_ctx_cache, n_state)
  109 + v_cache: Tensor, # (b, n_ctx_cache, n_state)
  110 + mask: Tensor,
  111 + ):
  112 + q = self.multiHeadAttention.query(x) # (b, n_ctx, n_state)
  113 + k = self.multiHeadAttention.key(x) # (b, n_ctx, n_state)
  114 + v = self.multiHeadAttention.value(x) # (b, n_ctx, n_state)
  115 +
  116 + k_cache[:, -k.shape[1] :, :] = k # (b, n_ctx_cache + n_ctx, n_state)
  117 + v_cache[:, -v.shape[1] :, :] = v # (b, n_ctx_cache + n_ctx, n_state)
  118 +
  119 + wv, qk = self.multiHeadAttention.qkv_attention(q, k_cache, v_cache, mask)
  120 + return self.multiHeadAttention.out(wv), k_cache, v_cache
  121 +
  122 +
  123 +class ResidualAttentionBlockTensorCache(nn.Module):
  124 + def __init__(self, inResidualAttentionBlock: ResidualAttentionBlock):
  125 + super().__init__()
  126 + self.originalBlock = inResidualAttentionBlock
  127 + self.attn = MultiHeadAttentionSelf(inResidualAttentionBlock.attn)
  128 + self.cross_attn = (
  129 + MultiHeadAttentionCross(inResidualAttentionBlock.cross_attn)
  130 + if inResidualAttentionBlock.cross_attn
  131 + else None
  132 + )
  133 +
  134 + def forward(
  135 + self,
  136 + x: Tensor,
  137 + self_k_cache: Tensor,
  138 + self_v_cache: Tensor,
  139 + cross_k: Tensor,
  140 + cross_v: Tensor,
  141 + mask: Tensor,
  142 + ):
  143 + self_attn_x, self_k_cache_updated, self_v_cache_updated = self.attn(
  144 + self.originalBlock.attn_ln(x), self_k_cache, self_v_cache, mask=mask
  145 + )
  146 + x = x + self_attn_x
  147 +
  148 + if self.cross_attn:
  149 + x = x + self.cross_attn(
  150 + self.originalBlock.cross_attn_ln(x), cross_k, cross_v
  151 + )
  152 +
  153 + x = x + self.originalBlock.mlp(self.originalBlock.mlp_ln(x))
  154 + return x, self_k_cache_updated, self_v_cache_updated
  155 +
  156 +
  157 +class TextDecoderTensorCache(nn.Module):
  158 + def __init__(self, inTextDecoder: TextDecoder, in_n_ctx: int):
  159 + super().__init__()
  160 + self.textDecoder = inTextDecoder
  161 + self.n_ctx = in_n_ctx
  162 +
  163 + self.blocks = []
  164 + for original_block in self.textDecoder.blocks:
  165 + self.blocks.append(ResidualAttentionBlockTensorCache(original_block))
  166 +
  167 + def forward(
  168 + self,
  169 + tokens: Tensor,
  170 + n_layer_self_k_cache: Tensor,
  171 + n_layer_self_v_cache: Tensor,
  172 + n_layer_cross_k: Tensor,
  173 + n_layer_cross_v: Tensor,
  174 + offset: Tensor,
  175 + ):
  176 + x = (
  177 + self.textDecoder.token_embedding(tokens)
  178 + + self.textDecoder.positional_embedding[
  179 + offset[0] : offset[0] + tokens.shape[-1]
  180 + ]
  181 + )
  182 + x = x.to(n_layer_cross_k[0].dtype)
  183 +
  184 + i = 0
  185 + for block in self.blocks:
  186 + self_k_cache = n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :]
  187 + self_v_cache = n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :]
  188 + x, self_k_cache, self_v_cache = block(
  189 + x,
  190 + self_k_cache=self_k_cache,
  191 + self_v_cache=self_v_cache,
  192 + cross_k=n_layer_cross_k[i],
  193 + cross_v=n_layer_cross_v[i],
  194 + mask=self.textDecoder.mask,
  195 + )
  196 + n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_k_cache
  197 + n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_v_cache
  198 + i += 1
  199 +
  200 + x = self.textDecoder.ln(x)
  201 +
  202 + logits = (
  203 + x
  204 + @ torch.transpose(self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1)
  205 + ).float()
  206 +
  207 + return logits, n_layer_self_k_cache, n_layer_self_v_cache
  208 +
  209 +
  210 +# ref: https://github.com/ggerganov/whisper.cpp/blob/master/models/convert-pt-to-ggml.py#L232
  211 +def convert_tokens(name, model):
  212 + whisper_dir = Path(whisper.__file__).parent
  213 + multilingual = model.is_multilingual
  214 + tokenizer = (
  215 + whisper_dir
  216 + / "assets"
  217 + / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
  218 + )
  219 + if not tokenizer.is_file():
  220 + raise ValueError(f"Cannot find {tokenizer}")
  221 +
  222 + # import base64
  223 +
  224 + with open(tokenizer, "r") as f:
  225 + contents = f.read()
  226 + # tokens = {
  227 + # base64.b64decode(token): int(rank)
  228 + # for token, rank in (line.split() for line in contents.splitlines() if line)
  229 + # }
  230 + tokens = {
  231 + token: int(rank)
  232 + for token, rank in (line.split() for line in contents.splitlines() if line)
  233 + }
  234 +
  235 + with open(f"{name}-tokens.txt", "w") as f:
  236 + for t, i in tokens.items():
  237 + f.write(f"{t} {i}\n")
  238 +
  239 +
  240 +@torch.no_grad()
  241 +def main():
  242 + args = get_args()
  243 + name = args.model
  244 +
  245 + opset_version = 13
  246 +
  247 + model = whisper.load_model(name)
  248 + convert_tokens(name=name, model=model)
  249 +
  250 + # write tokens
  251 +
  252 + tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
  253 + model.eval()
  254 + print(model.dims)
  255 + audio = torch.rand(16000 * 2)
  256 + audio = whisper.pad_or_trim(audio)
  257 + assert audio.shape == (16000 * 30,), audio.shape
  258 +
  259 + # make log-Mel spectrogram and move to the same device as the model
  260 + mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
  261 + batch_size = 1
  262 + assert mel.shape == (batch_size, 80, 30 * 100)
  263 +
  264 + encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
  265 + n_layer_cross_k, n_layer_cross_v = encoder(mel)
  266 + assert n_layer_cross_k.shape == (
  267 + model.dims.n_text_layer,
  268 + batch_size,
  269 + model.dims.n_audio_ctx,
  270 + model.dims.n_text_state,
  271 + ), n_layer_cross_k.shape
  272 + assert n_layer_cross_v.shape == (
  273 + model.dims.n_text_layer,
  274 + batch_size,
  275 + model.dims.n_audio_ctx,
  276 + model.dims.n_text_state,
  277 + ), n_layer_cross_v.shape
  278 +
  279 + encoder_filename = f"{name}-encoder.onnx"
  280 + torch.onnx.export(
  281 + encoder,
  282 + mel,
  283 + encoder_filename,
  284 + opset_version=opset_version,
  285 + input_names=["mel"],
  286 + output_names=["n_layer_cross_k", "n_layer_cross_v"],
  287 + dynamic_axes={
  288 + "mel": {0: "n_audio"}, # n_audio is also known as batch_size
  289 + "n_layer_cross_k": {1: "n_audio"},
  290 + "n_layer_cross_v": {1: "n_audio"},
  291 + },
  292 + )
  293 +
  294 + encoder_meta_data = {
  295 + "model_type": f"whisper-{name}",
  296 + "version": "1",
  297 + "maintainer": "k2-fsa",
  298 + "n_mels": model.dims.n_mels,
  299 + "n_audio_ctx": model.dims.n_audio_ctx,
  300 + "n_audio_state": model.dims.n_audio_state,
  301 + "n_audio_head": model.dims.n_audio_head,
  302 + "n_audio_layer": model.dims.n_audio_layer,
  303 + "n_vocab": model.dims.n_vocab,
  304 + "n_text_ctx": model.dims.n_text_ctx,
  305 + "n_text_state": model.dims.n_text_state,
  306 + "n_text_head": model.dims.n_text_head,
  307 + "n_text_layer": model.dims.n_text_layer,
  308 + "sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
  309 + "all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))),
  310 + "all_language_codes": ",".join(tokenizer.all_language_codes),
  311 + "sot": tokenizer.sot,
  312 + "sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
  313 + "eot": tokenizer.eot,
  314 + "blank_id": tokenizer.encode(" ")[0],
  315 + "is_multilingual": int(model.is_multilingual),
  316 + "no_speech": tokenizer.no_speech,
  317 + "non_speech_tokens": ",".join(list(map(str, tokenizer.non_speech_tokens))),
  318 + "transcribe": tokenizer.transcribe,
  319 + "translate": tokenizer.translate,
  320 + "sot_prev": tokenizer.sot_prev,
  321 + "sot_lm": tokenizer.sot_lm,
  322 + "no_timestamps": tokenizer.no_timestamps,
  323 + }
  324 + print(f"encoder_meta_data: {encoder_meta_data}")
  325 + add_meta_data(filename=encoder_filename, meta_data=encoder_meta_data)
  326 +
  327 + n_audio = mel.shape[0]
  328 + tokens = torch.tensor([[tokenizer.sot, tokenizer.sot, tokenizer.sot]] * n_audio).to(
  329 + mel.device
  330 + ) # [n_audio, 3]
  331 + decoder = TextDecoderTensorCache(model.decoder, model.dims.n_text_ctx)
  332 + n_layer_self_k_cache = torch.zeros(
  333 + (
  334 + len(model.decoder.blocks),
  335 + n_audio,
  336 + model.dims.n_text_ctx,
  337 + model.dims.n_text_state,
  338 + ),
  339 + device=mel.device,
  340 + )
  341 + n_layer_self_v_cache = torch.zeros(
  342 + (
  343 + len(model.decoder.blocks),
  344 + n_audio,
  345 + model.dims.n_text_ctx,
  346 + model.dims.n_text_state,
  347 + ),
  348 + device=mel.device,
  349 + )
  350 + offset = torch.zeros(1, dtype=torch.int64).to(mel.device)
  351 + logits, n_layer_self_k_cache, n_layer_self_v_cache = decoder(
  352 + tokens,
  353 + n_layer_self_k_cache,
  354 + n_layer_self_v_cache,
  355 + n_layer_cross_k,
  356 + n_layer_cross_v,
  357 + offset,
  358 + )
  359 + assert logits.shape == (n_audio, tokens.shape[1], model.dims.n_vocab)
  360 + assert n_layer_self_k_cache.shape == (
  361 + model.dims.n_text_layer,
  362 + n_audio,
  363 + model.dims.n_text_ctx,
  364 + model.dims.n_text_state,
  365 + )
  366 + assert n_layer_self_v_cache.shape == (
  367 + model.dims.n_text_layer,
  368 + n_audio,
  369 + model.dims.n_text_ctx,
  370 + model.dims.n_text_state,
  371 + )
  372 +
  373 + offset = torch.tensor([tokens.shape[1]], dtype=torch.int64).to(mel.device)
  374 + tokens = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
  375 +
  376 + logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = decoder(
  377 + tokens,
  378 + n_layer_self_k_cache,
  379 + n_layer_self_v_cache,
  380 + n_layer_cross_k,
  381 + n_layer_cross_v,
  382 + offset,
  383 + )
  384 +
  385 + decoder_filename = f"{name}-decoder.onnx"
  386 + torch.onnx.export(
  387 + decoder,
  388 + (
  389 + tokens,
  390 + n_layer_self_k_cache,
  391 + n_layer_self_v_cache,
  392 + n_layer_cross_k,
  393 + n_layer_cross_v,
  394 + offset,
  395 + ),
  396 + decoder_filename,
  397 + opset_version=opset_version,
  398 + input_names=[
  399 + "tokens",
  400 + "in_n_layer_self_k_cache",
  401 + "in_n_layer_self_v_cache",
  402 + "n_layer_cross_k",
  403 + "n_layer_cross_v",
  404 + "offset",
  405 + ],
  406 + output_names=["logits", "out_n_layer_self_k_cache", "out_n_layer_self_v_cache"],
  407 + dynamic_axes={
  408 + "tokens": {0: "n_audio", 1: "n_tokens"},
  409 + "in_n_layer_self_k_cache": {1: "n_audio"},
  410 + "in_n_layer_self_v_cache": {1: "n_audio"},
  411 + "n_layer_cross_k": {1: "n_audio"},
  412 + "n_layer_cross_v": {1: "n_audio"},
  413 + },
  414 + )
  415 +
  416 + # Generate int8 quantization models
  417 + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
  418 +
  419 + print("Generate int8 quantization models")
  420 +
  421 + encoder_filename_int8 = f"{name}-encoder.int8.onnx"
  422 + quantize_dynamic(
  423 + model_input=encoder_filename,
  424 + model_output=encoder_filename_int8,
  425 + op_types_to_quantize=["MatMul"],
  426 + weight_type=QuantType.QInt8,
  427 + )
  428 +
  429 + decoder_filename_int8 = f"{name}-decoder.int8.onnx"
  430 + quantize_dynamic(
  431 + model_input=decoder_filename,
  432 + model_output=decoder_filename_int8,
  433 + op_types_to_quantize=["MatMul"],
  434 + weight_type=QuantType.QInt8,
  435 + )
  436 +
  437 +
  438 +if __name__ == "__main__":
  439 + main()
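
The hyper-parameters and tokenizer attributes written by add_meta_data() above are what test.py below and the C++ runtime rely on at inference time. A small sketch, assuming the exported file follows the {name}-encoder.onnx pattern used by this script, of reading that metadata back with onnxruntime:

```python
# Sketch: read back the custom metadata written by add_meta_data().
import onnxruntime as ort

sess = ort.InferenceSession("tiny.en-encoder.onnx")
meta = sess.get_modelmeta().custom_metadata_map  # all values are strings
print(meta["model_type"])  # e.g. whisper-tiny.en
print(int(meta["n_text_layer"]), int(meta["n_text_ctx"]))
sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
print(sot_sequence)
```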
  1 +#!/usr/bin/env python3
  2 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +"""
  4 +Please first run ./export-onnx.py
  5 +before you run this script
  6 +"""
  7 +import base64
  8 +from typing import Tuple
  9 +
  10 +import kaldi_native_fbank as knf
  11 +import onnxruntime as ort
  12 +import torch
  13 +
  14 +import whisper
  15 +import argparse
  16 +
  17 +
  18 +def get_args():
  19 + parser = argparse.ArgumentParser()
  20 + parser.add_argument(
  21 + "--model",
  22 + type=str,
  23 + required=True,
  24 + # fmt: off
  25 + choices=[
  26 + "tiny", "tiny.en", "base", "base.en",
  27 + "small", "small.en", "medium", "medium.en",
  28 + "large", "large-v1", "large-v2"],
  29 + # fmt: on
  30 + )
  31 + return parser.parse_args()
  32 +
  33 +
  34 +class OnnxModel:
  35 + def __init__(
  36 + self,
  37 + encoder: str,
  38 + decoder: str,
  39 + ):
  40 + session_opts = ort.SessionOptions()
  41 + session_opts.inter_op_num_threads = 1
  42 + session_opts.intra_op_num_threads = 4
  43 +
  44 + self.session_opts = session_opts
  45 +
  46 + self.init_encoder(encoder)
  47 + self.init_decoder(decoder)
  48 +
  49 + def init_encoder(self, encoder: str):
  50 + self.encoder = ort.InferenceSession(
  51 + encoder,
  52 + sess_options=self.session_opts,
  53 + )
  54 +
  55 + meta = self.encoder.get_modelmeta().custom_metadata_map
  56 + self.n_text_layer = int(meta["n_text_layer"])
  57 + self.n_text_ctx = int(meta["n_text_ctx"])
  58 + self.n_text_state = int(meta["n_text_state"])
  59 + self.sot = int(meta["sot"])
  60 + self.eot = int(meta["eot"])
  61 + self.translate = int(meta["translate"])
  62 + self.no_timestamps = int(meta["no_timestamps"])
  63 + self.no_speech = int(meta["no_speech"])
  64 + self.blank = int(meta["blank_id"])
  65 +
  66 + self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
  67 +
  68 + self.is_multilingual = int(meta["is_multilingual"]) == 1
  69 +
  70 + def init_decoder(self, decoder: str):
  71 + self.decoder = ort.InferenceSession(
  72 + decoder,
  73 + sess_options=self.session_opts,
  74 + )
  75 +
  76 + def run_encoder(
  77 + self,
  78 + mel: torch.Tensor,
  79 + ) -> Tuple[torch.Tensor, torch.Tensor]:
  80 + n_layer_cross_k, n_layer_cross_v = self.encoder.run(
  81 + [
  82 + self.encoder.get_outputs()[0].name,
  83 + self.encoder.get_outputs()[1].name,
  84 + ],
  85 + {
  86 + self.encoder.get_inputs()[0].name: mel.numpy(),
  87 + },
  88 + )
  89 + return torch.from_numpy(n_layer_cross_k), torch.from_numpy(n_layer_cross_v)
  90 +
  91 + def run_decoder(
  92 + self,
  93 + tokens: torch.Tensor,
  94 + n_layer_self_k_cache: torch.Tensor,
  95 + n_layer_self_v_cache: torch.Tensor,
  96 + n_layer_cross_k: torch.Tensor,
  97 + n_layer_cross_v: torch.Tensor,
  98 + offset: torch.Tensor,
  99 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  100 + logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run(
  101 + [
  102 + self.decoder.get_outputs()[0].name,
  103 + self.decoder.get_outputs()[1].name,
  104 + self.decoder.get_outputs()[2].name,
  105 + ],
  106 + {
  107 + self.decoder.get_inputs()[0].name: tokens.numpy(),
  108 + self.decoder.get_inputs()[1].name: n_layer_self_k_cache.numpy(),
  109 + self.decoder.get_inputs()[2].name: n_layer_self_v_cache.numpy(),
  110 + self.decoder.get_inputs()[3].name: n_layer_cross_k.numpy(),
  111 + self.decoder.get_inputs()[4].name: n_layer_cross_v.numpy(),
  112 + self.decoder.get_inputs()[5].name: offset.numpy(),
  113 + },
  114 + )
  115 + return (
  116 + torch.from_numpy(logits),
  117 + torch.from_numpy(out_n_layer_self_k_cache),
  118 + torch.from_numpy(out_n_layer_self_v_cache),
  119 + )
  120 +
  121 + def get_self_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
  122 + batch_size = 1
  123 + n_layer_self_k_cache = torch.zeros(
  124 + self.n_text_layer,
  125 + batch_size,
  126 + self.n_text_ctx,
  127 + self.n_text_state,
  128 + )
  129 + n_layer_self_v_cache = torch.zeros(
  130 + self.n_text_layer,
  131 + batch_size,
  132 + self.n_text_ctx,
  133 + self.n_text_state,
  134 + )
  135 + return n_layer_self_k_cache, n_layer_self_v_cache
  136 +
  137 + def suppress_tokens(self, logits, is_initial: bool) -> None:
  138 + # suppress blank
  139 + if is_initial:
  140 + logits[self.eot] = float("-inf")
  141 + logits[self.blank] = float("-inf")
  142 +
  143 + # suppress <|notimestamps|>
  144 + logits[self.no_timestamps] = float("-inf")
  145 +
  146 + logits[self.sot] = float("-inf")
  147 + logits[self.no_speech] = float("-inf")
  148 +
  149 + # logits is changed in-place
  150 + logits[self.translate] = float("-inf")
  151 +
  152 +
  153 +def load_tokens(filename):
  154 + tokens = dict()
  155 + with open(filename, "r") as f:
  156 + for line in f:
  157 + t, i = line.split()
  158 + tokens[int(i)] = t
  159 + return tokens
  160 +
  161 +
  162 +def main():
  163 + args = get_args()
  164 + name = args.model
  165 +
  166 + encoder = f"./{name}-encoder.onnx"
  167 + decoder = f"./{name}-decoder.onnx"
  168 + audio = whisper.load_audio("0.wav")
  169 +
  170 + features = []
  171 + online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
  172 + online_whisper_fbank.accept_waveform(16000, audio)
  173 + online_whisper_fbank.input_finished()
  174 + for i in range(online_whisper_fbank.num_frames_ready):
  175 + f = online_whisper_fbank.get_frame(i)
  176 + f = torch.from_numpy(f)
  177 + features.append(f)
  178 +
  179 + features = torch.stack(features)
  180 +
  181 + log_spec = torch.clamp(features, min=1e-10).log10()
  182 + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
  183 + mel = (log_spec + 4.0) / 4.0
  184 + target = 3000
  185 + mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
  186 + mel = mel.t().unsqueeze(0)
  187 +
  188 + model = OnnxModel(encoder, decoder)
  189 + n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
  190 + n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
  191 +
  192 + tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
  193 + offset = torch.zeros(1, dtype=torch.int64)
  194 + logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
  195 + tokens=tokens,
  196 + n_layer_self_k_cache=n_layer_self_k_cache,
  197 + n_layer_self_v_cache=n_layer_self_v_cache,
  198 + n_layer_cross_k=n_layer_cross_k,
  199 + n_layer_cross_v=n_layer_cross_v,
  200 + offset=offset,
  201 + )
  202 + # logits.shape (batch_size, tokens.shape[1], vocab_size)
  203 + logits = logits[0, -1]
  204 + model.suppress_tokens(logits, is_initial=True)
  205 + # logits = logits.softmax(dim=-1)
  206 + # for greedy search, we don't need to compute softmax or log_softmax
  207 + max_token_id = logits.argmax(dim=-1)
  208 + results = []
  209 + for i in range(model.n_text_ctx):
  210 + if max_token_id == model.eot:
  211 + break
  212 + results.append(max_token_id.item())
  213 + tokens = torch.tensor([[results[-1]]])
  214 + offset += 1
  215 +
  216 + logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
  217 + tokens=tokens,
  218 + n_layer_self_k_cache=n_layer_self_k_cache,
  219 + n_layer_self_v_cache=n_layer_self_v_cache,
  220 + n_layer_cross_k=n_layer_cross_k,
  221 + n_layer_cross_v=n_layer_cross_v,
  222 + offset=offset,
  223 + )
  224 + logits = logits[0, -1]
  225 + model.suppress_tokens(logits, is_initial=False)
  226 + max_token_id = logits.argmax(dim=-1)
  227 + token_table = load_tokens(f"./{name}-tokens.txt")
  228 + s = b""
  229 + for i in results:
  230 + if i in token_table:
  231 + s += base64.b64decode(token_table[i])
  232 + else:
  233 + print("oov", i)
  234 +
  235 + print(s.decode().strip())
  236 + print(results)
  237 + print(model.sot_sequence)
  238 +
  239 +
  240 +if __name__ == "__main__":
  241 + main()
@@ -11,6 +11,7 @@ if(SHERPA_ONNX_ENABLE_PYTHON)
 endif()
 
 set(sources
+  base64-decode.cc
   cat.cc
   context-graph.cc
   endpoint.cc
@@ -35,6 +36,9 @@ set(sources
   offline-transducer-model-config.cc
   offline-transducer-model.cc
   offline-transducer-modified-beam-search-decoder.cc
+  offline-whisper-greedy-search-decoder.cc
+  offline-whisper-model-config.cc
+  offline-whisper-model.cc
   online-conformer-transducer-model.cc
   online-lm-config.cc
   online-lm.cc
@@ -50,12 +54,12 @@ set(sources
   online-zipformer-transducer-model.cc
   online-zipformer2-transducer-model.cc
   onnx-utils.cc
-  session.cc
   packed-sequence.cc
   pad-sequence.cc
   parse-options.cc
   provider.cc
   resample.cc
+  session.cc
   slice.cc
   stack.cc
   symbol-table.cc
  1 +// sherpa-onnx/csrc/base64-decode.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/base64-decode.h"
  6 +
  7 +#include "sherpa-onnx/csrc/macros.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +static int32_t Ord(char c) {
  12 + if (c >= 'A' && c <= 'Z') {
  13 + return c - 'A';
  14 + } else if (c >= 'a' && c <= 'z') {
  15 + return c - 'a' + ('Z' - 'A') + 1;
  16 + } else if (c >= '0' && c <= '9') {
  17 + return c - '0' + ('Z' - 'A') + ('z' - 'a') + 2;
  18 + } else if (c == '+') {
  19 + return 62;
  20 + } else if (c == '/') {
  21 + return 63;
  22 + }
  23 +
  24 + SHERPA_ONNX_LOGE("Unknown character %d, %c\n", c, c);
  25 +
  26 + exit(-1);
  27 +}
  28 +
  29 +// see
  30 +// https://github.com/ReneNyffenegger/cpp-base64/blob/master/base64.cpp#L243
  31 +std::string Base64Decode(const std::string &s) {
  32 + if (s.empty()) {
  33 + SHERPA_ONNX_LOGE("Empty string!");
  34 + exit(-1);
  35 + }
  36 +
  37 + int32_t n = s.size() / 4 * 3;
  38 +
  39 + std::string ans;
  40 + ans.reserve(n);
  41 +
  42 + int32_t i = 0;
  43 + while (i < static_cast<int32_t>(s.size())) {
  44 + if (s[i] == '=') {
  45 + return " ";
  46 + }
  47 +
  48 + int32_t first = (Ord(s[i]) << 2) + ((Ord(s[i + 1]) & 0x30) >> 4);
  49 + ans.push_back(first);
  50 +
  51 + if (i + 2 < static_cast<int32_t>(s.size()) && s[i + 2] != '=') {
  52 + int32_t second =
  53 + ((Ord(s[i + 1]) & 0x0f) << 4) + ((Ord(s[i + 2]) & 0x3c) >> 2);
  54 + ans.push_back(second);
  55 +
  56 + if (i + 3 < static_cast<int32_t>(s.size()) && s[i + 3] != '=') {
  57 + int32_t third = ((Ord(s[i + 2]) & 0x03) << 6) + Ord(s[i + 3]);
  58 + ans.push_back(third);
  59 + }
  60 + }
  61 + i += 4;
  62 + }
  63 +
  64 + return ans;
  65 +}
  66 +
  67 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/base64-decode.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_BASE64_DECODE_H_
  6 +#define SHERPA_ONNX_CSRC_BASE64_DECODE_H_
  7 +
  8 +#include <string>
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +/** @param s A base64 encoded string.
  13 + * @return Return the decoded string.
  14 + */
  15 +std::string Base64Decode(const std::string &s);
  16 +
  17 +} // namespace sherpa_onnx
  18 +
  19 +#endif // SHERPA_ONNX_CSRC_BASE64_DECODE_H_
@@ -1,4 +1,3 @@
-
 // sherpa-onnx/csrc/macros.h
 //
 // Copyright 2023 Xiaomi Corporation
@@ -14,6 +14,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
   transducer.Register(po);
   paraformer.Register(po);
   nemo_ctc.Register(po);
+  whisper.Register(po);
 
   po->Register("tokens", &tokens, "Path to tokens.txt");
 
@@ -28,7 +29,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
 
   po->Register("model-type", &model_type,
                "Specify it to reduce model initialization time. "
-               "Valid values are: transducer, paraformer, nemo_ctc. "
+               "Valid values are: transducer, paraformer, nemo_ctc, whisper."
                "All other values lead to loading the model twice.");
 }
 
@@ -51,6 +52,10 @@ bool OfflineModelConfig::Validate() const {
     return nemo_ctc.Validate();
   }
 
+  if (!whisper.encoder.empty()) {
+    return whisper.Validate();
+  }
+
   return transducer.Validate();
 }
 
@@ -61,6 +66,7 @@ std::string OfflineModelConfig::ToString() const {
   os << "transducer=" << transducer.ToString() << ", ";
   os << "paraformer=" << paraformer.ToString() << ", ";
   os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
+  os << "whisper=" << whisper.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
   os << "debug=" << (debug ? "True" : "False") << ", ";
@@ -9,6 +9,7 @@
 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
 #include "sherpa-onnx/csrc/offline-transducer-model-config.h"
+#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
 
 namespace sherpa_onnx {
 
@@ -16,6 +17,7 @@ struct OfflineModelConfig {
   OfflineTransducerModelConfig transducer;
   OfflineParaformerModelConfig paraformer;
   OfflineNemoEncDecCtcModelConfig nemo_ctc;
+  OfflineWhisperModelConfig whisper;
 
   std::string tokens;
   int32_t num_threads = 2;
@@ -37,11 +39,13 @@ struct OfflineModelConfig {
   OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
                      const OfflineParaformerModelConfig &paraformer,
                      const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
+                     const OfflineWhisperModelConfig &whisper,
                      const std::string &tokens, int32_t num_threads, bool debug,
                      const std::string &provider, const std::string &model_type)
       : transducer(transducer),
         paraformer(paraformer),
         nemo_ctc(nemo_ctc),
+        whisper(whisper),
         tokens(tokens),
         num_threads(num_threads),
         debug(debug),
@@ -16,7 +16,7 @@ void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) {
 
 bool OfflineNemoEncDecCtcModelConfig::Validate() const {
   if (!FileExists(model)) {
-    SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
+    SHERPA_ONNX_LOGE("NeMo model: %s does not exist", model.c_str());
     return false;
   }
 
@@ -15,7 +15,7 @@ void OfflineParaformerModelConfig::Register(ParseOptions *po) {
 
 bool OfflineParaformerModelConfig::Validate() const {
   if (!FileExists(model)) {
-    SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
+    SHERPA_ONNX_LOGE("Paraformer model %s does not exist", model.c_str());
     return false;
   }
 
@@ -11,6 +11,7 @@
 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 
@@ -26,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerParaformerImpl>(config);
   } else if (model_type == "nemo_ctc") {
     return std::make_unique<OfflineRecognizerCtcImpl>(config);
+  } else if (model_type == "whisper") {
+    return std::make_unique<OfflineRecognizerWhisperImpl>(config);
   } else {
     SHERPA_ONNX_LOGE(
         "Invalid model_type: %s. Trying to load the model to get its type",
@@ -43,6 +46,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     model_filename = config.model_config.paraformer.model;
   } else if (!config.model_config.nemo_ctc.model.empty()) {
     model_filename = config.model_config.nemo_ctc.model;
+  } else if (!config.model_config.whisper.encoder.empty()) {
+    model_filename = config.model_config.whisper.encoder;
   } else {
     SHERPA_ONNX_LOGE("Please provide a model");
     exit(-1);
@@ -77,6 +82,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
         "\n "
         "https://huggingface.co/csukuangfj/"
         "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
+        "\n "
+        "(3) Whisper"
         "\n");
     exit(-1);
   }
@@ -95,12 +102,17 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerCtcImpl>(config);
   }
 
+  if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
+    return std::make_unique<OfflineRecognizerWhisperImpl>(config);
+  }
+
   SHERPA_ONNX_LOGE(
       "\nUnsupported model_type: %s\n"
       "We support only the following model types at present: \n"
       " - Non-streaming transducer models from icefall\n"
      " - Non-streaming Paraformer models from FunASR\n"
-      " - EncDecCTCModelBPE models from NeMo\n",
+      " - EncDecCTCModelBPE models from NeMo\n"
+      " - Whisper models\n",
       model_type.c_str());
 
   exit(-1);
  1 +// sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
  7 +
  8 +#include <algorithm>
  9 +#include <cmath>
  10 +#include <memory>
  11 +#include <string>
  12 +#include <utility>
  13 +#include <vector>
  14 +
  15 +#include "sherpa-onnx/csrc/offline-model-config.h"
  16 +#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
  17 +#include "sherpa-onnx/csrc/offline-recognizer.h"
  18 +#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
  19 +#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h"
  20 +#include "sherpa-onnx/csrc/offline-whisper-model.h"
  21 +#include "sherpa-onnx/csrc/symbol-table.h"
  22 +#include "sherpa-onnx/csrc/transpose.h"
  23 +
  24 +namespace sherpa_onnx {
  25 +
  26 +static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
  27 + const SymbolTable &sym_table) {
  28 + OfflineRecognitionResult r;
  29 + r.tokens.reserve(src.tokens.size());
  30 +
  31 + for (auto i : src.tokens) {
  32 + if (!sym_table.contains(i)) {
  33 + continue;
  34 + }
  35 +
  36 + const auto &s = sym_table[i];
  37 + r.text += s;
  38 + r.tokens.push_back(s);
  39 + }
  40 +
  41 + return r;
  42 +}
  43 +
  44 +class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
  45 + public:
  46 + explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
  47 + : config_(config),
  48 + symbol_table_(config_.model_config.tokens),
  49 + model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
  50 + // tokens.txt from whisper is base64 encoded, so we need to decode it
  51 + symbol_table_.ApplyBase64Decode();
  52 +
  53 + if (config.decoding_method == "greedy_search") {
  54 + decoder_ =
  55 + std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
  56 + } else {
  57 + SHERPA_ONNX_LOGE(
  58 + "Only greedy_search is supported at present for whisper. Given %s",
  59 + config.decoding_method.c_str());
  60 + exit(-1);
  61 + }
  62 + }
  63 +
  64 + std::unique_ptr<OfflineStream> CreateStream() const override {
  65 + return std::make_unique<OfflineStream>(WhisperTag{});
  66 + }
  67 +
  68 + void DecodeStreams(OfflineStream **ss, int32_t n) const override {
  69 + // batch decoding is not implemented yet
  70 + for (int32_t i = 0; i != n; ++i) {
  71 + DecodeStream(ss[i]);
  72 + }
  73 + }
  74 +
  75 + private:
  76 + void DecodeStream(OfflineStream *s) const {
  77 + int32_t max_num_frames = 3000;
  78 + auto memory_info =
  79 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  80 +
  81 + int32_t feat_dim = s->FeatureDim();
  82 + std::vector<float> f = s->GetFrames();
  83 + int32_t num_frames = f.size() / feat_dim;
  84 +
  85 + if (num_frames > max_num_frames) {
  86 + SHERPA_ONNX_LOGE("Only waves less than 30 seconds are supported.");
  87 + exit(-1);
  88 + }
  89 +
  90 + NormalizeFeatures(f.data(), num_frames, feat_dim);
  91 +
  92 + std::array<int64_t, 3> shape{1, max_num_frames, feat_dim};
  93 +
  94 + Ort::Value mel = Ort::Value::CreateTensor<float>(
  95 + model_->Allocator(), shape.data(), shape.size());
  96 + float *p_mel = mel.GetTensorMutableData<float>();
  97 + std::copy(f.begin(), f.end(), p_mel);
  98 +
  99 + memset(p_mel + f.size(), 0,
  100 + (max_num_frames - num_frames) * feat_dim * sizeof(float));
  101 + mel = Transpose12(model_->Allocator(), &mel);
  102 +
  103 + auto cross_kv = model_->ForwardEncoder(std::move(mel));
  104 + auto results =
  105 + decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
  106 +
  107 + auto r = Convert(results[0], symbol_table_);
  108 + s->SetResult(r);
  109 + }
  110 +
  111 + private:
  112 + static void NormalizeFeatures(float *features, int32_t num_frames,
  113 + int32_t feat_dim) {
  114 + // log_spec = torch.clamp(features, min=1e-10).log10()
  115 + // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
  116 + // mel = (log_spec + 4.0) / 4.0
  117 +
  118 + int32_t n = num_frames * feat_dim;
  119 + float max_v = -1e20;
  120 + for (int32_t i = 0; i != n; ++i) {
  121 + float f = features[i];
  122 +
  123 + f = std::max<float>(f, 1e-10);
  124 + f = std::log10(f);
  125 +
  126 + max_v = std::max(f, max_v);
  127 +
  128 + features[i] = f;
  129 + }
  130 +
  131 + max_v -= 8;
  132 +
  133 + for (int32_t i = 0; i != n; ++i) {
  134 + float f = features[i];
  135 + f = std::max(f, max_v);
  136 +
  137 + f = (f + 4) / 4;
  138 +
  139 + features[i] = f;
  140 + }
  141 + }
  142 +
  143 + private:
  144 + OfflineRecognizerConfig config_;
  145 + SymbolTable symbol_table_;
  146 + std::unique_ptr<OfflineWhisperModel> model_;
  147 + std::unique_ptr<OfflineWhisperDecoder> decoder_;
  148 +};
  149 +
  150 +} // namespace sherpa_onnx
  151 +
  152 +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
@@ -86,6 +86,15 @@ class OfflineStream::Impl {
     fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
   }
 
+  Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
+      : context_graph_(context_graph) {
+    config_.normalize_samples = true;
+    opts_.frame_opts.samp_freq = 16000;
+    opts_.mel_opts.num_bins = 80;
+    whisper_fbank_ =
+        std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
+  }
+
   void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
     if (config_.normalize_samples) {
       AcceptWaveformImpl(sampling_rate, waveform, n);
@@ -117,20 +126,35 @@ class OfflineStream::Impl {
                                             lowpass_filter_width);
       std::vector<float> samples;
       resampler->Resample(waveform, n, true, &samples);
-      fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
-                             samples.size());
-      fbank_->InputFinished();
+
+      if (fbank_) {
+        fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
+                               samples.size());
+        fbank_->InputFinished();
+      } else {
+        whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq,
+                                       samples.data(), samples.size());
+        whisper_fbank_->InputFinished();
+      }
+
       return;
-    }
+    }  // if (sampling_rate != opts_.frame_opts.samp_freq)
 
-    fbank_->AcceptWaveform(sampling_rate, waveform, n);
-    fbank_->InputFinished();
+    if (fbank_) {
+      fbank_->AcceptWaveform(sampling_rate, waveform, n);
+      fbank_->InputFinished();
+    } else {
+      whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
+      whisper_fbank_->InputFinished();
+    }
   }
 
   int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
 
   std::vector<float> GetFrames() const {
-    int32_t n = fbank_->NumFramesReady();
+    int32_t n =
+        fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady();
+
     assert(n > 0 && "Please first call AcceptWaveform()");
 
     int32_t feature_dim = FeatureDim();
@@ -140,7 +164,8 @@ class OfflineStream::Impl {
     float *p = features.data();
 
     for (int32_t i = 0; i != n; ++i) {
-      const float *f = fbank_->GetFrame(i);
+      const float *f =
+          fbank_ ? fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i);
       std::copy(f, f + feature_dim, p);
       p += feature_dim;
     }
@@ -191,6 +216,7 @@ class OfflineStream::Impl {
  private:
  OfflineFeatureExtractorConfig config_;
  std::unique_ptr<knf::OnlineFbank> fbank_;
+  std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
  knf::FbankOptions opts_;
  OfflineRecognitionResult r_;
  ContextGraphPtr context_graph_;
@@ -201,6 +227,10 @@ OfflineStream::OfflineStream(
     ContextGraphPtr context_graph /*= nullptr*/)
     : impl_(std::make_unique<Impl>(config, context_graph)) {}
 
+OfflineStream::OfflineStream(WhisperTag tag,
+                             ContextGraphPtr context_graph /*= nullptr*/)
+    : impl_(std::make_unique<Impl>(tag, context_graph)) {}
+
 OfflineStream::~OfflineStream() = default;
 
 void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
@@ -65,10 +65,15 @@ struct OfflineFeatureExtractorConfig {
   void Register(ParseOptions *po);
 };
 
+struct WhisperTag {};
+
 class OfflineStream {
  public:
   explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
                          ContextGraphPtr context_graph = nullptr);
+
+  explicit OfflineStream(WhisperTag tag,
+                         ContextGraphPtr context_graph = nullptr);
   ~OfflineStream();
 
   /**
@@ -18,17 +18,20 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
 
 bool OfflineTransducerModelConfig::Validate() const {
   if (!FileExists(encoder_filename)) {
-    SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());
+    SHERPA_ONNX_LOGE("transducer encoder: %s does not exist",
+                     encoder_filename.c_str());
     return false;
   }
 
   if (!FileExists(decoder_filename)) {
-    SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str());
+    SHERPA_ONNX_LOGE("transducer decoder: %s does not exist",
+                     decoder_filename.c_str());
     return false;
   }
 
   if (!FileExists(joiner_filename)) {
-    SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str());
+    SHERPA_ONNX_LOGE("transducer joiner: %s does not exist",
+                     joiner_filename.c_str());
     return false;
   }
 
  1 +// sherpa-onnx/csrc/offline-whisper-decoder.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "onnxruntime_cxx_api.h" // NOLINT
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct OfflineWhisperDecoderResult {
  15 + /// The decoded token IDs
  16 + std::vector<int32_t> tokens;
  17 +};
  18 +
  19 +class OfflineWhisperDecoder {
  20 + public:
  21 + virtual ~OfflineWhisperDecoder() = default;
  22 +
  23 + /** Run decoding given the output from the whisper encoder model.
  24 + *
  25 + * @param n_layer_cross_k A 4-D tensor of shape
  26 + * (n_text_layer, N, n_audio_ctx, n_text_state).
  27 + * @param n_layer_cross_v A 4-D tensor of shape
  28 + * (n_text_layer, N, n_audio_ctx, n_text_state).
  29 + *
  30 + * @return Return a vector of size `N` containing the decoded results.
  31 + */
  32 + virtual std::vector<OfflineWhisperDecoderResult> Decode(
  33 + Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
  34 +};
  35 +
  36 +} // namespace sherpa_onnx
  37 +
  38 +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
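
To make the contract above concrete, here is a rough sketch, not part of the patch, of how a recognizer can drive this interface together with OfflineWhisperModel and the greedy-search implementation added below. Includes and error handling are omitted, and the actual recognizer wiring may differ:

    // Sketch only: assumes `model` was created from a valid OfflineModelConfig
    // and `features` is an Ort::Value of shape (1, 80, 3000) holding log-mel
    // features (batch size N = 1).
    std::vector<int32_t> RunWhisper(OfflineWhisperModel &model,
                                    Ort::Value features) {
      // The encoder runs once per utterance and yields the cross-attention k/v.
      std::pair<Ort::Value, Ort::Value> cross_kv =
          model.ForwardEncoder(std::move(features));

      // The decoder interface hides the token-by-token loop.
      OfflineWhisperGreedySearchDecoder decoder(&model);
      std::vector<OfflineWhisperDecoderResult> results =
          decoder.Decode(std::move(cross_kv.first), std::move(cross_kv.second));

      // N == 1, so there is exactly one result; its token ids are then mapped
      // to text through the (base64-decoded) symbol table.
      return results[0].tokens;
    }
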
  1 +// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h"
  6 +
  7 +#include <algorithm>
  8 +#include <utility>
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +std::vector<OfflineWhisperDecoderResult>
  13 +OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
  14 + Ort::Value cross_v) {
  15 + auto memory_info =
  16 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  17 +
  18 + auto self_kv_cache = model_->GetInitialSelfKVCache();
  19 +
  20 + std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
  21 + int32_t batch_size = 1;
  22 + std::array<int64_t, 2> token_shape{
  23 + batch_size, static_cast<int64_t>(initial_tokens.size())};
  24 +
  25 + Ort::Value tokens = Ort::Value::CreateTensor(
  26 + memory_info, initial_tokens.data(), initial_tokens.size(),
  27 + token_shape.data(), token_shape.size());
  28 +
  29 + std::array<int64_t, 1> offset_shape{1};
  30 + Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
  31 + model_->Allocator(), offset_shape.data(), offset_shape.size());
  32 + *(offset.GetTensorMutableData<int64_t>()) = 0;
  33 +
  34 + auto decoder_out = model_->ForwardDecoder(
  35 + std::move(tokens), std::move(self_kv_cache.first),
  36 + std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
  37 + std::move(offset));
  38 +
  39 + const auto &logits = std::get<0>(decoder_out);
  40 + const float *p_logits = logits.GetTensorData<float>();
  41 +
  42 + auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
  43 + int32_t vocab_size = logits_shape[2];
  44 +
  45 + int32_t max_token_id = static_cast<int32_t>(std::distance(
  46 + p_logits, std::max_element(p_logits, p_logits + vocab_size)));
  47 +
  48 + int32_t n_text_ctx = model_->TextCtx();
  49 +
  50 + std::vector<int32_t> predicted_tokens;
  51 + for (int32_t i = 0; i < n_text_ctx; ++i) {
  52 + if (max_token_id == model_->EOT()) {
  53 + break;
  54 + }
  55 +
  56 + predicted_tokens.push_back(max_token_id);
  57 +
  58 + std::array<int64_t, 2> token_shape{1, 1};
  59 + Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
  60 + model_->Allocator(), token_shape.data(), token_shape.size());
  61 + int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
  62 + p_tokens[0] = max_token_id;
  63 +
  64 + int64_t *p_offset =
  65 + std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
  66 +
  67 + if (i == 0) {
  68 + *p_offset = initial_tokens.size();
  69 + } else {
  70 + *p_offset += 1;
  71 + }
  72 +
  73 + decoder_out = model_->ForwardDecoder(std::move(tokens),
  74 + std::move(std::get<1>(decoder_out)),
  75 + std::move(std::get<2>(decoder_out)),
  76 + std::move(std::get<3>(decoder_out)),
  77 + std::move(std::get<4>(decoder_out)),
  78 + std::move(std::get<5>(decoder_out)));
  79 +
  80 + const auto &logits = std::get<0>(decoder_out);
  81 + const float *p_logits = logits.GetTensorData<float>();
  82 +
  83 + max_token_id = static_cast<int32_t>(std::distance(
  84 + p_logits, std::max_element(p_logits, p_logits + vocab_size)));
  85 + }
  86 +
  87 + std::vector<OfflineWhisperDecoderResult> ans(1);
  88 + ans[0].tokens = std::move(predicted_tokens);
  89 +
  90 + return ans;
  91 +}
  92 +
  93 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
  11 +#include "sherpa-onnx/csrc/offline-whisper-model.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
  16 + public:
  17 + explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
  18 + : model_(model) {}
  19 +
  20 + std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
  21 + Ort::Value cross_v) override;
  22 +
  23 + private:
  24 + OfflineWhisperModel *model_; // not owned
  25 +};
  26 +
  27 +} // namespace sherpa_onnx
  28 +
  29 +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
  1 +// sherpa-onnx/csrc/offline-whisper-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OfflineWhisperModelConfig::Register(ParseOptions *po) {
  13 + po->Register("whisper-encoder", &encoder,
  14 + "Path to the whisper ONNX encoder, e.g., tiny-encoder.onnx, "
  15 + "medium.en-encoder.onnx.");
  16 +
  17 + po->Register("whisper-decoder", &decoder,
  18 + "Path to the whisper ONNX decoder, e.g., tiny-decoder.onnx, "
  19 + "medium.en-decoder.onnx.");
  20 +}
  21 +
  22 +bool OfflineWhisperModelConfig::Validate() const {
  23 + if (!FileExists(encoder)) {
  24 + SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
  25 + return false;
  26 + }
  27 +
  28 + if (!FileExists(decoder)) {
  29 + SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
  30 + return false;
  31 + }
  32 +
  33 + return true;
  34 +}
  35 +
  36 +std::string OfflineWhisperModelConfig::ToString() const {
  37 + std::ostringstream os;
  38 +
  39 + os << "OfflineWhisperModelConfig(";
  40 + os << "encoder=\"" << encoder << "\", ";
  41 + os << "decoder=\"" << decoder << "\")";
  42 +
  43 + return os.str();
  44 +}
  45 +
  46 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-whisper-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +struct OfflineWhisperModelConfig {
  14 + std::string encoder;
  15 + std::string decoder;
  16 +
  17 + OfflineWhisperModelConfig() = default;
  18 + OfflineWhisperModelConfig(const std::string &encoder,
  19 + const std::string &decoder)
  20 + : encoder(encoder), decoder(decoder) {}
  21 +
  22 + void Register(ParseOptions *po);
  23 + bool Validate() const;
  24 +
  25 + std::string ToString() const;
  26 +};
  27 +
  28 +} // namespace sherpa_onnx
  29 +
  30 +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-whisper-model.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-whisper-model.h"
  6 +
  7 +#include <algorithm>
  8 +#include <string>
  9 +#include <tuple>
  10 +#include <utility>
  11 +
  12 +#include "sherpa-onnx/csrc/macros.h"
  13 +#include "sherpa-onnx/csrc/onnx-utils.h"
  14 +#include "sherpa-onnx/csrc/session.h"
  15 +#include "sherpa-onnx/csrc/text-utils.h"
  16 +
  17 +namespace sherpa_onnx {
  18 +
  19 +class OfflineWhisperModel::Impl {
  20 + public:
  21 + explicit Impl(const OfflineModelConfig &config)
  22 + : config_(config),
  23 + env_(ORT_LOGGING_LEVEL_ERROR),
  24 + sess_opts_(GetSessionOptions(config)),
  25 + allocator_{} {
  26 + {
  27 + auto buf = ReadFile(config.whisper.encoder);
  28 + InitEncoder(buf.data(), buf.size());
  29 + }
  30 +
  31 + {
  32 + auto buf = ReadFile(config.whisper.decoder);
  33 + InitDecoder(buf.data(), buf.size());
  34 + }
  35 + }
  36 +
  37 + std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) {
  38 + auto encoder_out = encoder_sess_->Run(
  39 + {}, encoder_input_names_ptr_.data(), &features, 1,
  40 + encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
  41 +
  42 + return {std::move(encoder_out[0]), std::move(encoder_out[1])};
  43 + }
  44 +
  45 + std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
  46 + Ort::Value>
  47 + ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
  48 + Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
  49 + Ort::Value n_layer_cross_v, Ort::Value offset) {
  50 + std::array<Ort::Value, 6> decoder_input = {std::move(tokens),
  51 + std::move(n_layer_self_k_cache),
  52 + std::move(n_layer_self_v_cache),
  53 + std::move(n_layer_cross_k),
  54 + std::move(n_layer_cross_v),
  55 + std::move(offset)};
  56 +
  57 + auto decoder_out = decoder_sess_->Run(
  58 + {}, decoder_input_names_ptr_.data(), decoder_input.data(),
  59 + decoder_input.size(), decoder_output_names_ptr_.data(),
  60 + decoder_output_names_ptr_.size());
  61 +
  62 + return {std::move(decoder_out[0]), std::move(decoder_out[1]),
  63 + std::move(decoder_out[2]), std::move(decoder_input[3]),
  64 + std::move(decoder_input[4]), std::move(decoder_input[5])};
  65 + }
  66 +
  67 + std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
  68 + std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
  69 +
  70 + Ort::Value n_layer_self_k_cache = Ort::Value::CreateTensor<float>(
  71 + Allocator(), shape.data(), shape.size());
  72 +
  73 + Ort::Value n_layer_self_v_cache = Ort::Value::CreateTensor<float>(
  74 + Allocator(), shape.data(), shape.size());
  75 +
  76 + auto n = shape[0] * shape[1] * shape[2] * shape[3];
  77 +
  78 + float *p_k = n_layer_self_k_cache.GetTensorMutableData<float>();
  79 + float *p_v = n_layer_self_v_cache.GetTensorMutableData<float>();
  80 +
  81 + memset(p_k, 0, sizeof(float) * n);
  82 + memset(p_v, 0, sizeof(float) * n);
  83 +
  84 + return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
  85 + }
  86 +
  87 + OrtAllocator *Allocator() const { return allocator_; }
  88 +
  89 + const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
  90 +
  91 + int32_t EOT() const { return eot_; }
  92 +
  93 + int32_t TextCtx() const { return n_text_ctx_; }
  94 +
  95 + private:
  96 + void InitEncoder(void *model_data, size_t model_data_length) {
  97 + encoder_sess_ = std::make_unique<Ort::Session>(
  98 + env_, model_data, model_data_length, sess_opts_);
  99 +
  100 + GetInputNames(encoder_sess_.get(), &encoder_input_names_,
  101 + &encoder_input_names_ptr_);
  102 +
  103 + GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
  104 + &encoder_output_names_ptr_);
  105 +
  106 + // get meta data
  107 + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
  108 + if (config_.debug) {
  109 + std::ostringstream os;
  110 + os << "---encoder---\n";
  111 + PrintModelMetadata(os, meta_data);
  112 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  113 + }
  114 +
  115 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  116 + SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
  117 + SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
  118 + SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
  119 + SHERPA_ONNX_READ_META_DATA(sot_, "sot");
  120 + SHERPA_ONNX_READ_META_DATA(eot_, "eot");
  121 + SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
  122 + SHERPA_ONNX_READ_META_DATA(translate_, "translate");
  123 + SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
  124 + SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
  125 + SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
  126 + }
  127 +
  128 + void InitDecoder(void *model_data, size_t model_data_length) {
  129 + decoder_sess_ = std::make_unique<Ort::Session>(
  130 + env_, model_data, model_data_length, sess_opts_);
  131 +
  132 + GetInputNames(decoder_sess_.get(), &decoder_input_names_,
  133 + &decoder_input_names_ptr_);
  134 +
  135 + GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
  136 + &decoder_output_names_ptr_);
  137 + }
  138 +
  139 + private:
  140 + OfflineModelConfig config_;
  141 + Ort::Env env_;
  142 + Ort::SessionOptions sess_opts_;
  143 + Ort::AllocatorWithDefaultOptions allocator_;
  144 +
  145 + std::unique_ptr<Ort::Session> encoder_sess_;
  146 + std::unique_ptr<Ort::Session> decoder_sess_;
  147 +
  148 + std::vector<std::string> encoder_input_names_;
  149 + std::vector<const char *> encoder_input_names_ptr_;
  150 +
  151 + std::vector<std::string> encoder_output_names_;
  152 + std::vector<const char *> encoder_output_names_ptr_;
  153 +
  154 + std::vector<std::string> decoder_input_names_;
  155 + std::vector<const char *> decoder_input_names_ptr_;
  156 +
  157 + std::vector<std::string> decoder_output_names_;
  158 + std::vector<const char *> decoder_output_names_ptr_;
  159 +
  160 + // model meta data
  161 + int32_t n_text_layer_;
  162 + int32_t n_text_ctx_;
  163 + int32_t n_text_state_;
  164 + int32_t sot_;
  165 + int32_t eot_;
  166 + int32_t blank_;
  167 + int32_t translate_;
  168 + int32_t no_timestamps_;
  169 + int32_t no_speech_;
  170 + std::vector<int64_t> sot_sequence_;
  171 +};
  172 +
  173 +OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
  174 + : impl_(std::make_unique<Impl>(config)) {}
  175 +
  176 +OfflineWhisperModel::~OfflineWhisperModel() = default;
  177 +
  178 +std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
  179 + Ort::Value features) {
  180 + return impl_->ForwardEncoder(std::move(features));
  181 +}
  182 +
  183 +std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
  184 + Ort::Value>
  185 +OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
  186 + Ort::Value n_layer_self_k_cache,
  187 + Ort::Value n_layer_self_v_cache,
  188 + Ort::Value n_layer_cross_k,
  189 + Ort::Value n_layer_cross_v,
  190 + Ort::Value offset) {
  191 + return impl_->ForwardDecoder(
  192 + std::move(tokens), std::move(n_layer_self_k_cache),
  193 + std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
  194 + std::move(n_layer_cross_v), std::move(offset));
  195 +}
  196 +
  197 +std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() {
  198 + return impl_->GetInitialSelfKVCache();
  199 +}
  200 +
  201 +OrtAllocator *OfflineWhisperModel::Allocator() const {
  202 + return impl_->Allocator();
  203 +}
  204 +
  205 +const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
  206 + return impl_->GetInitialTokens();
  207 +}
  208 +
  209 +int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); }
  210 +
  211 +int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
  212 +
  213 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-whisper-model.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <tuple>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "onnxruntime_cxx_api.h" // NOLINT
  13 +#include "sherpa-onnx/csrc/offline-model-config.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +class OfflineWhisperModel {
  18 + public:
  19 + explicit OfflineWhisperModel(const OfflineModelConfig &config);
  20 + ~OfflineWhisperModel();
  21 +
  22 + /** Run the encoder model.
  23 + *
  24 + * @param features A tensor of shape (N, C, T). It is changed in-place.
  25 + * C is 80 and T is 3000.
  26 + *
  27 + * @return Return a pair containing:
  28 + * - n_layer_cross_k: A 4-D tensor of shape
  29 + * (n_text_layer, N, n_audio_ctx, n_text_state)
  30 + * - n_layer_cross_v: A 4-D tensor of shape
  31 + * (n_text_layer, N, n_audio_ctx, n_text_state)
  32 + */
  33 + std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
  34 +
  35 + /** Run the decoder model.
  36 + *
  37 + * @param tokens A int64 tensor of shape (N, num_words)
  38 + * @param n_layer_self_k_cache A 4-D tensor of shape
  39 + * (n_text_layer, N, n_text_ctx, n_text_state).
  40 + * @param n_layer_self_v_cache A 4-D tensor of shape
  41 + * (n_text_layer, N, n_text_ctx, n_text_state).
  42 + * @param n_layer_cross_k A 4-D tensor of shape
  43 + * (n_text_layer, N, n_audio_ctx, n_text_state).
  44 + * @param n_layer_cross_v A 4-D tensor of shape
  45 + * (n_text_layer, N, n_audio_ctx, n_text_state).
  46 + * @param offset A int64 tensor of shape (N,)
  47 + *
  48 + * @return Return a tuple containing 6 tensors:
  49 + *
  50 + * - logits A 3-D tensor of shape (N, num_words, vocab_size)
  51 + * - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache
  52 + * - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache
  53 + * - out_n_layer_cross_k Same as n_layer_cross_k
  54 + * - out_n_layer_cross_v Same as n_layer_cross_v
  55 + * - out_offset Same as offset
  56 + */
  57 + std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
  58 + Ort::Value>
  59 + ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
  60 + Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
  61 + Ort::Value n_layer_cross_v, Ort::Value offset);
  62 +
  63 + /** Return the initial self kv cache as a pair:
  64 + * - n_layer_self_k_cache A 4-D tensor of shape
  65 + * (n_text_layer, N, n_text_ctx, n_text_state).
  66 + * - n_layer_self_v_cache A 4-D tensor of shape
  67 + * (n_text_layer, N, n_text_ctx, n_text_state).
  68 + */
  69 + std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
  70 + const std::vector<int64_t> &GetInitialTokens() const;
  71 +
  72 + /** Return an allocator for allocating memory
  73 + */
  74 + OrtAllocator *Allocator() const;
  75 + int32_t EOT() const;
  76 + int32_t TextCtx() const;
  77 +
  78 + private:
  79 + class Impl;
  80 + std::unique_ptr<Impl> impl_;
  81 +};
  82 +
  83 +} // namespace sherpa_onnx
  84 +
  85 +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
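
For a concrete reading of the shapes documented above, the numbers below use the metadata believed to be exported for tiny.en (n_text_layer = 4, n_text_ctx = 448, n_text_state = 384, n_audio_ctx = 1500); other model sizes differ, so treat this strictly as an illustration:

    // Back-of-the-envelope check of the per-utterance tensor sizes (N = 1).
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t n_text_layer = 4, n_text_ctx = 448, n_text_state = 384;
      const int64_t n_audio_ctx = 1500, N = 1;

      // GetInitialSelfKVCache() zero-fills two tensors of this many floats.
      const int64_t self_cache = n_text_layer * N * n_text_ctx * n_text_state;

      // ForwardEncoder() returns two cross-attention tensors of this many floats.
      const int64_t cross_kv = n_text_layer * N * n_audio_ctx * n_text_state;

      std::printf("self k/v cache : %lld floats each (~%.1f MB)\n",
                  static_cast<long long>(self_cache), self_cache * 4 / 1e6);
      std::printf("cross k/v      : %lld floats each (~%.1f MB)\n",
                  static_cast<long long>(cross_kv), cross_kv * 4 / 1e6);
      return 0;
    }
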
@@ -98,11 +98,15 @@ Usage: @@ -98,11 +98,15 @@ Usage:
98 ./bin/sherpa-onnx-microphone-offline \ 98 ./bin/sherpa-onnx-microphone-offline \
99 --tokens=/path/to/tokens.txt \ 99 --tokens=/path/to/tokens.txt \
100 --paraformer=/path/to/model.onnx \ 100 --paraformer=/path/to/model.onnx \
101 - --num-threads=2 \  
102 - --decoding-method=greedy_search 101 + --num-threads=1
103 102
104 -Default value for num_threads is 2.  
105 -Valid values for decoding_method: greedy_search. 103 +(3) Whisper models
  104 +
  105 + ./bin/sherpa-onnx-microphone-offline \
  106 + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
  107 + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
  108 + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
  109 + --num-threads=1
106 110
107 Please refer to 111 Please refer to
108 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 112 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -23,7 +23,7 @@ Usage: @@ -23,7 +23,7 @@ Usage:
23 --encoder=/path/to/encoder.onnx \ 23 --encoder=/path/to/encoder.onnx \
24 --decoder=/path/to/decoder.onnx \ 24 --decoder=/path/to/decoder.onnx \
25 --joiner=/path/to/joiner.onnx \ 25 --joiner=/path/to/joiner.onnx \
26 - --num-threads=2 \ 26 + --num-threads=1 \
27 --decoding-method=greedy_search \ 27 --decoding-method=greedy_search \
28 /path/to/foo.wav [bar.wav foobar.wav ...] 28 /path/to/foo.wav [bar.wav foobar.wav ...]
29 29
@@ -33,14 +33,22 @@ Usage: @@ -33,14 +33,22 @@ Usage:
33 ./bin/sherpa-onnx-offline \ 33 ./bin/sherpa-onnx-offline \
34 --tokens=/path/to/tokens.txt \ 34 --tokens=/path/to/tokens.txt \
35 --paraformer=/path/to/model.onnx \ 35 --paraformer=/path/to/model.onnx \
36 - --num-threads=2 \ 36 + --num-threads=1 \
37 --decoding-method=greedy_search \ 37 --decoding-method=greedy_search \
38 /path/to/foo.wav [bar.wav foobar.wav ...] 38 /path/to/foo.wav [bar.wav foobar.wav ...]
39 39
  40 +(3) Whisper models
  41 +
  42 + ./bin/sherpa-onnx-offline \
  43 + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
  44 + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
  45 + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
  46 + --num-threads=1 \
  47 + /path/to/foo.wav [bar.wav foobar.wav ...]
  48 +
  49 +
40 Note: It supports decoding multiple files in batches 50 Note: It supports decoding multiple files in batches
41 51
42 -Default value for num_threads is 2.  
43 -Valid values for decoding_method: greedy_search.  
44 foo.wav should be of single channel, 16-bit PCM encoded wave file; its 52 foo.wav should be of single channel, 16-bit PCM encoded wave file; its
45 sampling rate can be arbitrary and does not need to be 16kHz. 53 sampling rate can be arbitrary and does not need to be 16kHz.
46 54
@@ -55,6 +63,7 @@ for a list of pre-trained models to download. @@ -55,6 +63,7 @@ for a list of pre-trained models to download.
55 63
56 po.Read(argc, argv); 64 po.Read(argc, argv);
57 if (po.NumArgs() < 1) { 65 if (po.NumArgs() < 1) {
  66 + fprintf(stderr, "Error: Please provide at least 1 wave file.\n\n");
58 po.PrintUsage(); 67 po.PrintUsage();
59 exit(EXIT_FAILURE); 68 exit(EXIT_FAILURE);
60 } 69 }
@@ -9,6 +9,7 @@ @@ -9,6 +9,7 @@
9 #include <sstream> 9 #include <sstream>
10 #include <strstream> 10 #include <strstream>
11 11
  12 +#include "sherpa-onnx/csrc/base64-decode.h"
12 #include "sherpa-onnx/csrc/onnx-utils.h" 13 #include "sherpa-onnx/csrc/onnx-utils.h"
13 14
14 #if __ANDROID_API__ >= 9 15 #if __ANDROID_API__ >= 9
@@ -82,4 +83,12 @@ std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) { @@ -82,4 +83,12 @@ std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
82 return os << symbol_table.ToString(); 83 return os << symbol_table.ToString();
83 } 84 }
84 85
  86 +void SymbolTable::ApplyBase64Decode() {
  87 + sym2id_.clear();
  88 + for (auto &p : id2sym_) {
  89 + p.second = Base64Decode(p.second);
  90 + sym2id_[p.second] = p.first;
  91 + }
  92 +}
  93 +
85 } // namespace sherpa_onnx 94 } // namespace sherpa_onnx
@@ -45,6 +45,9 @@ class SymbolTable { @@ -45,6 +45,9 @@ class SymbolTable {
45 /// Return true if there is a given symbol in the symbol table. 45 /// Return true if there is a given symbol in the symbol table.
46 bool contains(const std::string &sym) const; 46 bool contains(const std::string &sym) const;
47 47
  48 + // for tokens.txt from Whisper
  49 + void ApplyBase64Decode();
  50 +
48 private: 51 private:
49 void Init(std::istream &is); 52 void Init(std::istream &is);
50 53
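
The motivation for ApplyBase64Decode(): Whisper's BPE symbols are raw byte sequences that may contain spaces or be invalid UTF-8 on their own, so the export script writes them base64-encoded to keep the whitespace-separated "symbol id" format of tokens.txt intact, and the symbol table decodes them after loading. An illustrative snippet, assuming Base64Decode() in sherpa-onnx/csrc/base64-decode.h has the obvious string-to-string signature:

    #include <cassert>
    #include <string>

    #include "sherpa-onnx/csrc/base64-decode.h"

    void Demo() {
      // A typical whisper tokens.txt line reads "IHRoZQ== 262", i.e.
      // base64(" the") followed by the token id (ids vary by model).
      std::string sym = sherpa_onnx::Base64Decode("IHRoZQ==");
      assert(sym == " the");  // note the leading space kept by BPE
    }
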
@@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx @@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx
11 offline-recognizer.cc 11 offline-recognizer.cc
12 offline-stream.cc 12 offline-stream.cc
13 offline-transducer-model-config.cc 13 offline-transducer-model-config.cc
  14 + offline-whisper-model-config.cc
14 online-lm-config.cc 15 online-lm-config.cc
15 online-recognizer.cc 16 online-recognizer.cc
16 online-stream.cc 17 online-stream.cc
@@ -11,6 +11,7 @@ @@ -11,6 +11,7 @@
11 #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" 11 #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
12 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" 12 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
13 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" 13 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
  14 +#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
14 15
15 namespace sherpa_onnx { 16 namespace sherpa_onnx {
16 17
@@ -18,22 +19,25 @@ void PybindOfflineModelConfig(py::module *m) { @@ -18,22 +19,25 @@ void PybindOfflineModelConfig(py::module *m) {
18 PybindOfflineTransducerModelConfig(m); 19 PybindOfflineTransducerModelConfig(m);
19 PybindOfflineParaformerModelConfig(m); 20 PybindOfflineParaformerModelConfig(m);
20 PybindOfflineNemoEncDecCtcModelConfig(m); 21 PybindOfflineNemoEncDecCtcModelConfig(m);
  22 + PybindOfflineWhisperModelConfig(m);
21 23
22 using PyClass = OfflineModelConfig; 24 using PyClass = OfflineModelConfig;
23 py::class_<PyClass>(*m, "OfflineModelConfig") 25 py::class_<PyClass>(*m, "OfflineModelConfig")
24 - .def(  
25 - py::init<const OfflineTransducerModelConfig &,  
26 - const OfflineParaformerModelConfig &,  
27 - const OfflineNemoEncDecCtcModelConfig &, const std::string &,  
28 - int32_t, bool, const std::string &, const std::string &>(),  
29 - py::arg("transducer") = OfflineTransducerModelConfig(),  
30 - py::arg("paraformer") = OfflineParaformerModelConfig(),  
31 - py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),  
32 - py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,  
33 - py::arg("provider") = "cpu", py::arg("model_type") = "") 26 + .def(py::init<const OfflineTransducerModelConfig &,
  27 + const OfflineParaformerModelConfig &,
  28 + const OfflineNemoEncDecCtcModelConfig &,
  29 + const OfflineWhisperModelConfig &, const std::string &,
  30 + int32_t, bool, const std::string &, const std::string &>(),
  31 + py::arg("transducer") = OfflineTransducerModelConfig(),
  32 + py::arg("paraformer") = OfflineParaformerModelConfig(),
  33 + py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
  34 + py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tokens"),
  35 + py::arg("num_threads"), py::arg("debug") = false,
  36 + py::arg("provider") = "cpu", py::arg("model_type") = "")
34 .def_readwrite("transducer", &PyClass::transducer) 37 .def_readwrite("transducer", &PyClass::transducer)
35 .def_readwrite("paraformer", &PyClass::paraformer) 38 .def_readwrite("paraformer", &PyClass::paraformer)
36 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) 39 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
  40 + .def_readwrite("whisper", &PyClass::whisper)
37 .def_readwrite("tokens", &PyClass::tokens) 41 .def_readwrite("tokens", &PyClass::tokens)
38 .def_readwrite("num_threads", &PyClass::num_threads) 42 .def_readwrite("num_threads", &PyClass::num_threads)
39 .def_readwrite("debug", &PyClass::debug) 43 .def_readwrite("debug", &PyClass::debug)
  1 +// sherpa-onnx/python/csrc/offline-whisper-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindOfflineWhisperModelConfig(py::module *m) {
  15 + using PyClass = OfflineWhisperModelConfig;
  16 + py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
  17 + .def(py::init<const std::string &, const std::string &>(),
  18 + py::arg("encoder"), py::arg("decoder"))
  19 + .def_readwrite("encoder", &PyClass::encoder)
  20 + .def_readwrite("decoder", &PyClass::decoder)
  21 + .def("__str__", &PyClass::ToString);
  22 +}
  23 +
  24 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-whisper-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineWhisperModelConfig(py::module *m);
  13 +
  14 +}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
1 # Copyright (c) 2023 by manyeyes 1 # Copyright (c) 2023 by manyeyes
  2 +# Copyright (c) 2023 Xiaomi Corporation
2 from pathlib import Path 3 from pathlib import Path
3 from typing import List, Optional 4 from typing import List, Optional
4 5
@@ -7,6 +8,7 @@ from _sherpa_onnx import ( @@ -7,6 +8,7 @@ from _sherpa_onnx import (
7 OfflineModelConfig, 8 OfflineModelConfig,
8 OfflineNemoEncDecCtcModelConfig, 9 OfflineNemoEncDecCtcModelConfig,
9 OfflineParaformerModelConfig, 10 OfflineParaformerModelConfig,
  11 + OfflineWhisperModelConfig,
10 ) 12 )
11 from _sherpa_onnx import OfflineRecognizer as _Recognizer 13 from _sherpa_onnx import OfflineRecognizer as _Recognizer
12 from _sherpa_onnx import ( 14 from _sherpa_onnx import (
@@ -69,7 +71,7 @@ class OfflineRecognizer(object): @@ -69,7 +71,7 @@ class OfflineRecognizer(object):
69 feature_dim: 71 feature_dim:
70 Dimension of the feature used to train the model. 72 Dimension of the feature used to train the model.
71 decoding_method: 73 decoding_method:
72 - Support only greedy_search for now. 74 + Valid values: greedy_search, modified_beam_search.
73 debug: 75 debug:
74 True to show debug messages. 76 True to show debug messages.
75 provider: 77 provider:
@@ -137,7 +139,7 @@ class OfflineRecognizer(object): @@ -137,7 +139,7 @@ class OfflineRecognizer(object):
137 feature_dim: 139 feature_dim:
138 Dimension of the feature used to train the model. 140 Dimension of the feature used to train the model.
139 decoding_method: 141 decoding_method:
140 - Valid values are greedy_search, modified_beam_search. 142 + Valid values are greedy_search.
141 debug: 143 debug:
142 True to show debug messages. 144 True to show debug messages.
143 provider: 145 provider:
@@ -185,14 +187,14 @@ class OfflineRecognizer(object): @@ -185,14 +187,14 @@ class OfflineRecognizer(object):
185 English, etc. 187 English, etc.
186 188
187 Args: 189 Args:
  190 + model:
  191 + Path to ``model.onnx``.
188 tokens: 192 tokens:
189 Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two 193 Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
190 columns:: 194 columns::
191 195
192 symbol integer_id 196 symbol integer_id
193 197
194 - model:  
195 - Path to ``model.onnx``.  
196 num_threads: 198 num_threads:
197 Number of threads for neural network computation. 199 Number of threads for neural network computation.
198 sample_rate: 200 sample_rate:
@@ -200,7 +202,7 @@ class OfflineRecognizer(object): @@ -200,7 +202,7 @@ class OfflineRecognizer(object):
200 feature_dim: 202 feature_dim:
201 Dimension of the feature used to train the model. 203 Dimension of the feature used to train the model.
202 decoding_method: 204 decoding_method:
203 - Valid values are greedy_search, modified_beam_search. 205 + Valid values are greedy_search.
204 debug: 206 debug:
205 True to show debug messages. 207 True to show debug messages.
206 provider: 208 provider:
@@ -229,6 +231,68 @@ class OfflineRecognizer(object): @@ -229,6 +231,68 @@ class OfflineRecognizer(object):
229 self.recognizer = _Recognizer(recognizer_config) 231 self.recognizer = _Recognizer(recognizer_config)
230 return self 232 return self
231 233
  234 + @classmethod
  235 + def from_whisper(
  236 + cls,
  237 + encoder: str,
  238 + decoder: str,
  239 + tokens: str,
  240 + num_threads: int,
  241 + decoding_method: str = "greedy_search",
  242 + debug: bool = False,
  243 + provider: str = "cpu",
  244 + ):
  245 + """
  246 + Please refer to
  247 + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
  248 + to download pre-trained whisper models of different sizes,
  249 + e.g., tiny, tiny.en, base, base.en, etc.
  250 +
  251 + Args:
  252 + encoder:
  253 + Path to the encoder model, e.g., tiny-encoder.onnx,
  254 + tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
  255 + decoder:
  256 + Path to the decoder model, e.g., tiny-decoder.onnx,
  257 + tiny-decoder.int8.onnx, tiny-decoder.ort, etc.
  258 + tokens:
  259 + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
  260 + columns::
  261 +
  262 + symbol integer_id
  263 +
  264 + num_threads:
  265 + Number of threads for neural network computation.
  266 + decoding_method:
  267 + Valid values: greedy_search.
  268 + debug:
  269 + True to show debug messages.
  270 + provider:
  271 + onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  272 + """
  273 + self = cls.__new__(cls)
  274 + model_config = OfflineModelConfig(
  275 + whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder),
  276 + tokens=tokens,
  277 + num_threads=num_threads,
  278 + debug=debug,
  279 + provider=provider,
  280 + model_type="whisper",
  281 + )
  282 +
  283 + feat_config = OfflineFeatureExtractorConfig(
  284 + sampling_rate=16000,
  285 + feature_dim=80,
  286 + )
  287 +
  288 + recognizer_config = OfflineRecognizerConfig(
  289 + feat_config=feat_config,
  290 + model_config=model_config,
  291 + decoding_method=decoding_method,
  292 + )
  293 + self.recognizer = _Recognizer(recognizer_config)
  294 + return self
  295 +
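
A short usage sketch for the new classmethod (paths follow the pre-trained model layout used in this PR's CLI examples; adjust them to wherever the model was downloaded):

    import sherpa_onnx

    recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
        encoder="./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx",
        decoder="./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx",
        tokens="./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt",
        num_threads=1,
    )

    # Decoding then follows the usual offline pattern:
    #   stream = recognizer.create_stream()
    #   stream.accept_waveform(sample_rate, samples)
    #   recognizer.decode_stream(stream)
    #   print(stream.result.text)
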
232 def create_stream(self, contexts_list: Optional[List[List[int]]] = None): 296 def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
233 if contexts_list is None: 297 if contexts_list is None:
234 return self.recognizer.create_stream() 298 return self.recognizer.create_stream()