Showing 39 changed files with 1,835 additions and 51 deletions
+name: export-whisper-to-onnx
+
+on:
+  workflow_dispatch:
+
+concurrency:
+  group: release-whisper-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  release-whisper-models:
+    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
+    name: ${{ matrix.model }}
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [macos-latest]
+        model: ["tiny.en", "base.en", "small.en", "medium.en"]
+
+    steps:
+      - uses: actions/checkout@v2
+
+      - name: Install dependencies
+        shell: bash
+        run: |
+          python3 -m pip install openai-whisper torch onnxruntime onnx
+
+      - name: export ${{ matrix.model }}
+        shell: bash
+        run: |
+          cd scripts/whisper
+          python3 ./export-onnx.py --model ${{ matrix.model }}
+          python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
+
+          ls -lh
+
+          ls -lh ~/.cache/whisper
+
+      - name: Publish ${{ matrix.model }} to huggingface
+        shell: bash
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        run: |
+          cd scripts/whisper
+
+          git config --global user.email "csukuangfj@gmail.com"
+          git config --global user.name "Fangjun Kuang"
+
+          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
+
+          cp *.onnx ./huggingface
+          cp *.ort ./huggingface
+          cp *tokens.txt ./huggingface
+
+          cd huggingface
+          git status
+          ls -lh
+          git lfs track "*.onnx"
+          git lfs track "*.ort"
+          git add .
+          git commit -m "upload ${{ matrix.model }}"
+          git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
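Note (not part of the workflow above): before the exported files are uploaded, a quick sanity check that they load under onnxruntime can catch a broken export early. A minimal sketch, assuming the tiny.en model was exported in the previous step:

# Sketch only: confirm an exported model loads and exposes the expected I/O names.
import onnxruntime as ort

sess = ort.InferenceSession("tiny.en-encoder.onnx")
print([i.name for i in sess.get_inputs()])   # expected: ['mel']
print([o.name for o in sess.get_outputs()])  # expected: ['n_layer_cross_k', 'n_layer_cross_v']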
@@ -23,14 +23,14 @@ on:
       - 'sherpa-onnx/jni/*'
 
 concurrency:
-  group: jni-${{ github.ref }}
+  group: run-java-test-${{ github.ref }}
   cancel-in-progress: true
 
 permissions:
   contents: read
 
 jobs:
-  jni:
+  run_java_test:
     runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false
 function(download_kaldi_native_fbank)
   include(FetchContent)
 
-  set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.17.tar.gz")
-  set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.17.tar.gz")
-  set(kaldi_native_fbank_HASH "SHA256=300dc282d51d738e70f194ef13a50bf4cf8d54a3b2686d75f7fc2fb821f8c1e6")
+  set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.1.tar.gz")
+  set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.1.tar.gz")
+  set(kaldi_native_fbank_HASH "SHA256=c7676f319fa97e8c8bca6018792de120895dcfe122fa9b4bff00f8f9165348e7")
 
   set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
   set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
   # If you don't have access to the Internet,
   # please pre-download kaldi-native-fbank
   set(possible_file_locations
-    $ENV{HOME}/Downloads/kaldi-native-fbank-1.17.tar.gz
-    ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.17.tar.gz
-    ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.17.tar.gz
-    /tmp/kaldi-native-fbank-1.17.tar.gz
-    /star-fj/fangjun/download/github/kaldi-native-fbank-1.17.tar.gz
+    $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.1.tar.gz
+    ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.1.tar.gz
+    ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.1.tar.gz
+    /tmp/kaldi-native-fbank-1.18.1.tar.gz
+    /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.1.tar.gz
   )
 
   foreach(f IN LISTS possible_file_locations)
 #!/usr/bin/env python3
 #
 # Copyright (c) 2023 by manyeyes
+# Copyright (c) 2023 Xiaomi Corporation
 
 """
 This file demonstrates how to use sherpa-onnx Python API to transcribe
@@ -34,6 +35,27 @@ file(s) with a non-streaming model.
 
 (3) For CTC models from NeMo
 
+python3 ./python-api-examples/offline-decode-files.py \
+  --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
+  --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
+  --num-threads=2 \
+  --decoding-method=greedy_search \
+  --debug=false \
+  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
+  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
+  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
+
+(4) For Whisper models
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
+  --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
+  --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+  --num-threads=1 \
+  ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
+  ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
+  ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
+
 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/index.html
 to install sherpa-onnx and to download the pre-trained models
@@ -145,6 +167,20 @@ def get_args():
     )
 
     parser.add_argument(
+        "--whisper-encoder",
+        default="",
+        type=str,
+        help="Path to whisper encoder model",
+    )
+
+    parser.add_argument(
+        "--whisper-decoder",
+        default="",
+        type=str,
+        help="Path to whisper decoder model",
+    )
+
+    parser.add_argument(
         "--decoding-method",
         type=str,
         default="greedy_search",
@@ -247,6 +283,8 @@ def main():
     if args.encoder:
         assert len(args.paraformer) == 0, args.paraformer
         assert len(args.nemo_ctc) == 0, args.nemo_ctc
+        assert len(args.whisper_encoder) == 0, args.whisper_encoder
+        assert len(args.whisper_decoder) == 0, args.whisper_decoder
 
     contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
     if contexts:
@@ -271,6 +309,9 @@ def main():
         )
     elif args.paraformer:
         assert len(args.nemo_ctc) == 0, args.nemo_ctc
+        assert len(args.whisper_encoder) == 0, args.whisper_encoder
+        assert len(args.whisper_decoder) == 0, args.whisper_decoder
+
         assert_file_exists(args.paraformer)
 
         recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
@@ -283,6 +324,11 @@ def main():
             debug=args.debug,
         )
     elif args.nemo_ctc:
+        assert len(args.whisper_encoder) == 0, args.whisper_encoder
+        assert len(args.whisper_decoder) == 0, args.whisper_decoder
+
+        assert_file_exists(args.nemo_ctc)
+
         recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
             model=args.nemo_ctc,
             tokens=args.tokens,
@@ -292,6 +338,18 @@ def main():
             decoding_method=args.decoding_method,
             debug=args.debug,
         )
+    elif args.whisper_encoder:
+        assert_file_exists(args.whisper_encoder)
+        assert_file_exists(args.whisper_decoder)
+
+        recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
+            encoder=args.whisper_encoder,
+            decoder=args.whisper_decoder,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            decoding_method=args.decoding_method,
+            debug=args.debug,
+        )
     else:
         print("Please specify at least one model")
         return
scripts/whisper/.gitignore (new file, mode 100644)

scripts/whisper/README.md (new file, mode 100644)

+# Introduction
+
+This folder contains code showing how to convert [Whisper][whisper] to onnx
+and use onnxruntime to replace PyTorch for speech recognition.
+
+You can use [sherpa-onnx][sherpa-onnx] to run the converted model.
+
+[whisper]: https://github.com/openai/whisper
+[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
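For orientation, here is a minimal sketch of how a converted model is meant to be consumed from Python. It assumes sherpa-onnx is installed and that the exported files follow the {name}-encoder.onnx / {name}-decoder.onnx / {name}-tokens.txt naming used by export-onnx.py; see python-api-examples/offline-decode-files.py in this PR for the authoritative usage.

# Sketch only: decode one 16 kHz mono wave with a converted whisper model.
import wave

import numpy as np
import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
    encoder="./base.en-encoder.onnx",
    decoder="./base.en-decoder.onnx",
    tokens="./base.en-tokens.txt",
    num_threads=1,
    decoding_method="greedy_search",
)

with wave.open("0.wav") as f:
    assert f.getframerate() == 16000, "whisper models expect 16 kHz input"
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768  # scale to [-1, 1]

stream = recognizer.create_stream()
stream.accept_waveform(16000, samples)
recognizer.decode_stream(stream)
print(stream.result.text)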
scripts/whisper/export-onnx.py (new file, mode 100755)
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+# flake8: noqa
+
+"""
+Note: Code in this file is modified from
+https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py
+
+Thanks to https://github.com/TadaoYamaoka
+for making the onnx export script public.
+"""
+
+import argparse
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import onnx
+import torch
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from torch import Tensor, nn
+
+import whisper
+from whisper.model import (
+    AudioEncoder,
+    MultiHeadAttention,
+    ResidualAttentionBlock,
+    TextDecoder,
+)
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        # fmt: off
+        choices=[
+            "tiny", "tiny.en", "base", "base.en",
+            "small", "small.en", "medium", "medium.en",
+            "large", "large-v1", "large-v2"],
+        # fmt: on
+    )
+    return parser.parse_args()
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, Any]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = str(value)
+
+    onnx.save(model, filename)
+
+
+class AudioEncoderTensorCache(nn.Module):
+    def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
+        super().__init__()
+        self.audioEncoder = inAudioEncoder
+        self.textDecoder = inTextDecoder
+
+    def forward(self, x: Tensor):
+        audio_features = self.audioEncoder(x)
+
+        n_layer_cross_k_list = []
+        n_layer_cross_v_list = []
+        for block in self.textDecoder.blocks:
+            n_layer_cross_k_list.append(block.cross_attn.key(audio_features))
+            n_layer_cross_v_list.append(block.cross_attn.value(audio_features))
+
+        return torch.stack(n_layer_cross_k_list), torch.stack(n_layer_cross_v_list)
+
+
+class MultiHeadAttentionCross(nn.Module):
+    def __init__(self, inMultiHeadAttention: MultiHeadAttention):
+        super().__init__()
+        self.multiHeadAttention = inMultiHeadAttention
+
+    def forward(
+        self,
+        x: Tensor,
+        k: Tensor,
+        v: Tensor,
+        mask: Optional[Tensor] = None,
+    ):
+        q = self.multiHeadAttention.query(x)
+        wv, qk = self.multiHeadAttention.qkv_attention(q, k, v, mask)
+        return self.multiHeadAttention.out(wv)
+
+
+class MultiHeadAttentionSelf(nn.Module):
+    def __init__(self, inMultiHeadAttention: MultiHeadAttention):
+        super().__init__()
+        self.multiHeadAttention = inMultiHeadAttention
+
+    def forward(
+        self,
+        x: Tensor,  # (b, n_ctx, n_state)
+        k_cache: Tensor,  # (b, n_ctx_cache, n_state)
+        v_cache: Tensor,  # (b, n_ctx_cache, n_state)
+        mask: Tensor,
+    ):
+        q = self.multiHeadAttention.query(x)  # (b, n_ctx, n_state)
+        k = self.multiHeadAttention.key(x)  # (b, n_ctx, n_state)
+        v = self.multiHeadAttention.value(x)  # (b, n_ctx, n_state)
+
+        k_cache[:, -k.shape[1] :, :] = k  # (b, n_ctx_cache + n_ctx, n_state)
+        v_cache[:, -v.shape[1] :, :] = v  # (b, n_ctx_cache + n_ctx, n_state)
+
+        wv, qk = self.multiHeadAttention.qkv_attention(q, k_cache, v_cache, mask)
+        return self.multiHeadAttention.out(wv), k_cache, v_cache
+
+
+class ResidualAttentionBlockTensorCache(nn.Module):
+    def __init__(self, inResidualAttentionBlock: ResidualAttentionBlock):
+        super().__init__()
+        self.originalBlock = inResidualAttentionBlock
+        self.attn = MultiHeadAttentionSelf(inResidualAttentionBlock.attn)
+        self.cross_attn = (
+            MultiHeadAttentionCross(inResidualAttentionBlock.cross_attn)
+            if inResidualAttentionBlock.cross_attn
+            else None
+        )
+
+    def forward(
+        self,
+        x: Tensor,
+        self_k_cache: Tensor,
+        self_v_cache: Tensor,
+        cross_k: Tensor,
+        cross_v: Tensor,
+        mask: Tensor,
+    ):
+        self_attn_x, self_k_cache_updated, self_v_cache_updated = self.attn(
+            self.originalBlock.attn_ln(x), self_k_cache, self_v_cache, mask=mask
+        )
+        x = x + self_attn_x
+
+        if self.cross_attn:
+            x = x + self.cross_attn(
+                self.originalBlock.cross_attn_ln(x), cross_k, cross_v
+            )
+
+        x = x + self.originalBlock.mlp(self.originalBlock.mlp_ln(x))
+        return x, self_k_cache_updated, self_v_cache_updated
+
+
+class TextDecoderTensorCache(nn.Module):
+    def __init__(self, inTextDecoder: TextDecoder, in_n_ctx: int):
+        super().__init__()
+        self.textDecoder = inTextDecoder
+        self.n_ctx = in_n_ctx
+
+        self.blocks = []
+        for orginal_block in self.textDecoder.blocks:
+            self.blocks.append(ResidualAttentionBlockTensorCache(orginal_block))
+
+    def forward(
+        self,
+        tokens: Tensor,
+        n_layer_self_k_cache: Tensor,
+        n_layer_self_v_cache: Tensor,
+        n_layer_cross_k: Tensor,
+        n_layer_cross_v: Tensor,
+        offset: Tensor,
+    ):
+        x = (
+            self.textDecoder.token_embedding(tokens)
+            + self.textDecoder.positional_embedding[
+                offset[0] : offset[0] + tokens.shape[-1]
+            ]
+        )
+        x = x.to(n_layer_cross_k[0].dtype)
+
+        i = 0
+        for block in self.blocks:
+            self_k_cache = n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :]
+            self_v_cache = n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :]
+            x, self_k_cache, self_v_cache = block(
+                x,
+                self_k_cache=self_k_cache,
+                self_v_cache=self_v_cache,
+                cross_k=n_layer_cross_k[i],
+                cross_v=n_layer_cross_v[i],
+                mask=self.textDecoder.mask,
+            )
+            n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_k_cache
+            n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_v_cache
+            i += 1
+
+        x = self.textDecoder.ln(x)
+
+        logits = (
+            x
+            @ torch.transpose(self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1)
+        ).float()
+
+        return logits, n_layer_self_k_cache, n_layer_self_v_cache
+
+
+# ref: https://github.com/ggerganov/whisper.cpp/blob/master/models/convert-pt-to-ggml.py#L232
+def convert_tokens(name, model):
+    whisper_dir = Path(whisper.__file__).parent
+    multilingual = model.is_multilingual
+    tokenizer = (
+        whisper_dir
+        / "assets"
+        / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
+    )
+    if not tokenizer.is_file():
+        raise ValueError(f"Cannot find {tokenizer}")
+
+    # import base64
+
+    with open(tokenizer, "r") as f:
+        contents = f.read()
+        # tokens = {
+        #     base64.b64decode(token): int(rank)
+        #     for token, rank in (line.split() for line in contents.splitlines() if line)
+        # }
+        tokens = {
+            token: int(rank)
+            for token, rank in (line.split() for line in contents.splitlines() if line)
+        }
+
+    with open(f"{name}-tokens.txt", "w") as f:
+        for t, i in tokens.items():
+            f.write(f"{t} {i}\n")
+
+
+@torch.no_grad()
+def main():
+    args = get_args()
+    name = args.model
+
+    opset_version = 13
+
+    model = whisper.load_model(name)
+    convert_tokens(name=name, model=model)
+
+    # write tokens
+
+    tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
+    model.eval()
+    print(model.dims)
+    audio = torch.rand(16000 * 2)
+    audio = whisper.pad_or_trim(audio)
+    assert audio.shape == (16000 * 30,), audio.shape
+
+    # make log-Mel spectrogram and move to the same device as the model
+    mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
+    batch_size = 1
+    assert mel.shape == (batch_size, 80, 30 * 100)
+
+    encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
+    n_layer_cross_k, n_layer_cross_v = encoder(mel)
+    assert n_layer_cross_k.shape == (
+        model.dims.n_text_layer,
+        batch_size,
+        model.dims.n_audio_ctx,
+        model.dims.n_text_state,
+    ), n_layer_cross_k.shape
+    assert n_layer_cross_v.shape == (
+        model.dims.n_text_layer,
+        batch_size,
+        model.dims.n_audio_ctx,
+        model.dims.n_text_state,
+    ), n_layer_cross_v.shape
+
+    encoder_filename = f"{name}-encoder.onnx"
+    torch.onnx.export(
+        encoder,
+        mel,
+        encoder_filename,
+        opset_version=opset_version,
+        input_names=["mel"],
+        output_names=["n_layer_cross_k", "n_layer_cross_v"],
+        dynamic_axes={
+            "mel": {0: "n_audio"},  # n_audio is also known as batch_size
+            "n_layer_cross_k": {1: "n_audio"},
+            "n_layer_cross_v": {1: "n_audio"},
+        },
+    )
+
+    encoder_meta_data = {
+        "model_type": f"whisper-{name}",
+        "version": "1",
+        "maintainer": "k2-fsa",
+        "n_mels": model.dims.n_mels,
+        "n_audio_ctx": model.dims.n_audio_ctx,
+        "n_audio_state": model.dims.n_audio_state,
+        "n_audio_head": model.dims.n_audio_head,
+        "n_audio_layer": model.dims.n_audio_layer,
+        "n_vocab": model.dims.n_vocab,
+        "n_text_ctx": model.dims.n_text_ctx,
+        "n_text_state": model.dims.n_text_state,
+        "n_text_head": model.dims.n_text_head,
+        "n_text_layer": model.dims.n_text_layer,
+        "sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
+        "all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))),
+        "all_language_codes": ",".join(tokenizer.all_language_codes),
+        "sot": tokenizer.sot,
+        "sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
+        "eot": tokenizer.eot,
+        "blank_id": tokenizer.encode(" ")[0],
+        "is_multilingual": int(model.is_multilingual),
+        "no_speech": tokenizer.no_speech,
+        "non_speech_tokens": ",".join(list(map(str, tokenizer.non_speech_tokens))),
+        "transcribe": tokenizer.transcribe,
+        "translate": tokenizer.translate,
+        "sot_prev": tokenizer.sot_prev,
+        "sot_lm": tokenizer.sot_lm,
+        "no_timestamps": tokenizer.no_timestamps,
+    }
+    print(f"encoder_meta_data: {encoder_meta_data}")
+    add_meta_data(filename=encoder_filename, meta_data=encoder_meta_data)
+
+    n_audio = mel.shape[0]
+    tokens = torch.tensor([[tokenizer.sot, tokenizer.sot, tokenizer.sot]] * n_audio).to(
+        mel.device
+    )  # [n_audio, 3]
+    decoder = TextDecoderTensorCache(model.decoder, model.dims.n_text_ctx)
+    n_layer_self_k_cache = torch.zeros(
+        (
+            len(model.decoder.blocks),
+            n_audio,
+            model.dims.n_text_ctx,
+            model.dims.n_text_state,
+        ),
+        device=mel.device,
+    )
+    n_layer_self_v_cache = torch.zeros(
+        (
+            len(model.decoder.blocks),
+            n_audio,
+            model.dims.n_text_ctx,
+            model.dims.n_text_state,
+        ),
+        device=mel.device,
+    )
+    offset = torch.zeros(1, dtype=torch.int64).to(mel.device)
+    logits, n_layer_self_k_cache, n_layer_self_v_cache = decoder(
+        tokens,
+        n_layer_self_k_cache,
+        n_layer_self_v_cache,
+        n_layer_cross_k,
+        n_layer_cross_v,
+        offset,
+    )
+    assert logits.shape == (n_audio, tokens.shape[1], model.dims.n_vocab)
+    assert n_layer_self_k_cache.shape == (
+        model.dims.n_text_layer,
+        n_audio,
+        model.dims.n_text_ctx,
+        model.dims.n_text_state,
+    )
+    assert n_layer_self_v_cache.shape == (
+        model.dims.n_text_layer,
+        n_audio,
+        model.dims.n_text_ctx,
+        model.dims.n_text_state,
+    )
+
+    offset = torch.tensor([tokens.shape[1]], dtype=torch.int64).to(mel.device)
+    tokens = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
+
+    logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = decoder(
+        tokens,
+        n_layer_self_k_cache,
+        n_layer_self_v_cache,
+        n_layer_cross_k,
+        n_layer_cross_v,
+        offset,
+    )
+
+    decoder_filename = f"{name}-decoder.onnx"
+    torch.onnx.export(
+        decoder,
+        (
+            tokens,
+            n_layer_self_k_cache,
+            n_layer_self_v_cache,
+            n_layer_cross_k,
+            n_layer_cross_v,
+            offset,
+        ),
+        decoder_filename,
+        opset_version=opset_version,
+        input_names=[
+            "tokens",
+            "in_n_layer_self_k_cache",
+            "in_n_layer_self_v_cache",
+            "n_layer_cross_k",
+            "n_layer_cross_v",
+            "offset",
+        ],
+        output_names=["logits", "out_n_layer_self_k_cache", "out_n_layer_self_v_cache"],
+        dynamic_axes={
+            "tokens": {0: "n_audio", 1: "n_tokens"},
+            "in_n_layer_self_k_cache": {1: "n_audio"},
+            "in_n_layer_self_v_cache": {1: "n_audio"},
+            "n_layer_cross_k": {1: "n_audio"},
+            "n_layer_cross_v": {1: "n_audio"},
+        },
+    )
+
+    # Generate int8 quantization models
+    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
+    print("Generate int8 quantization models")
+
+    encoder_filename_int8 = f"{name}-encoder.int8.onnx"
+    quantize_dynamic(
+        model_input=encoder_filename,
+        model_output=encoder_filename_int8,
+        op_types_to_quantize=["MatMul"],
+        weight_type=QuantType.QInt8,
+    )
+
+    decoder_filename_int8 = f"{name}-decoder.int8.onnx"
+    quantize_dynamic(
+        model_input=decoder_filename,
+        model_output=decoder_filename_int8,
+        op_types_to_quantize=["MatMul"],
+        weight_type=QuantType.QInt8,
+    )
+
+
+if __name__ == "__main__":
+    main()
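A note on the token table: convert_tokens() above writes each entry of {name}-tokens.txt in its original base64-encoded form, one "token rank" pair per line, so consumers have to base64-decode the first column before joining tokens into text; that is what test.py below and the C++ Base64Decode helper added later in this PR do. A small sketch of the intended decoding, assuming tiny.en-tokens.txt has been generated:

# Sketch only: map predicted token ids back to text via the base64-encoded tokens.txt.
import base64

id2token = {}
with open("tiny.en-tokens.txt") as f:
    for line in f:
        token, idx = line.split()
        id2token[int(idx)] = token

def ids_to_text(ids):
    # Each table entry is base64-encoded UTF-8; decode before concatenating.
    return b"".join(base64.b64decode(id2token[i]) for i in ids).decode("utf-8")

print(ids_to_text(sorted(id2token)[:5]))  # smoke test on the first few entries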
scripts/whisper/requirements.txt (new file, mode 100644)

+openai-whisper

scripts/whisper/test.py (new file, mode 100755)
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+"""
+Please first run ./export-onnx.py
+before you run this script
+"""
+import base64
+from typing import Tuple
+
+import kaldi_native_fbank as knf
+import onnxruntime as ort
+import torch
+
+import whisper
+import argparse
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        # fmt: off
+        choices=[
+            "tiny", "tiny.en", "base", "base.en",
+            "small", "small.en", "medium", "medium.en",
+            "large", "large-v1", "large-v2"],
+        # fmt: on
+    )
+    return parser.parse_args()
+
+
+class OnnxModel:
+    def __init__(
+        self,
+        encoder: str,
+        decoder: str,
+    ):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 4
+
+        self.session_opts = session_opts
+
+        self.init_encoder(encoder)
+        self.init_decoder(decoder)
+
+    def init_encoder(self, encoder: str):
+        self.encoder = ort.InferenceSession(
+            encoder,
+            sess_options=self.session_opts,
+        )
+
+        meta = self.encoder.get_modelmeta().custom_metadata_map
+        self.n_text_layer = int(meta["n_text_layer"])
+        self.n_text_ctx = int(meta["n_text_ctx"])
+        self.n_text_state = int(meta["n_text_state"])
+        self.sot = int(meta["sot"])
+        self.eot = int(meta["eot"])
+        self.translate = int(meta["translate"])
+        self.no_timestamps = int(meta["no_timestamps"])
+        self.no_speech = int(meta["no_speech"])
+        self.blank = int(meta["blank_id"])
+
+        self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
+
+        self.is_multilingual = int(meta["is_multilingual"]) == 1
+
+    def init_decoder(self, decoder: str):
+        self.decoder = ort.InferenceSession(
+            decoder,
+            sess_options=self.session_opts,
+        )
+
+    def run_encoder(
+        self,
+        mel: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        n_layer_cross_k, n_layer_cross_v = self.encoder.run(
+            [
+                self.encoder.get_outputs()[0].name,
+                self.encoder.get_outputs()[1].name,
+            ],
+            {
+                self.encoder.get_inputs()[0].name: mel.numpy(),
+            },
+        )
+        return torch.from_numpy(n_layer_cross_k), torch.from_numpy(n_layer_cross_v)
+
+    def run_decoder(
+        self,
+        tokens: torch.Tensor,
+        n_layer_self_k_cache: torch.Tensor,
+        n_layer_self_v_cache: torch.Tensor,
+        n_layer_cross_k: torch.Tensor,
+        n_layer_cross_v: torch.Tensor,
+        offset: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run(
+            [
+                self.decoder.get_outputs()[0].name,
+                self.decoder.get_outputs()[1].name,
+                self.decoder.get_outputs()[2].name,
+            ],
+            {
+                self.decoder.get_inputs()[0].name: tokens.numpy(),
+                self.decoder.get_inputs()[1].name: n_layer_self_k_cache.numpy(),
+                self.decoder.get_inputs()[2].name: n_layer_self_v_cache.numpy(),
+                self.decoder.get_inputs()[3].name: n_layer_cross_k.numpy(),
+                self.decoder.get_inputs()[4].name: n_layer_cross_v.numpy(),
+                self.decoder.get_inputs()[5].name: offset.numpy(),
+            },
+        )
+        return (
+            torch.from_numpy(logits),
+            torch.from_numpy(out_n_layer_self_k_cache),
+            torch.from_numpy(out_n_layer_self_v_cache),
+        )
+
+    def get_self_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = 1
+        n_layer_self_k_cache = torch.zeros(
+            self.n_text_layer,
+            batch_size,
+            self.n_text_ctx,
+            self.n_text_state,
+        )
+        n_layer_self_v_cache = torch.zeros(
+            self.n_text_layer,
+            batch_size,
+            self.n_text_ctx,
+            self.n_text_state,
+        )
+        return n_layer_self_k_cache, n_layer_self_v_cache
+
+    def suppress_tokens(self, logits, is_initial: bool) -> None:
+        # suppress blank
+        if is_initial:
+            logits[self.eot] = float("-inf")
+            logits[self.blank] = float("-inf")
+
+        # suppress <|notimestamps|>
+        logits[self.no_timestamps] = float("-inf")
+
+        logits[self.sot] = float("-inf")
+        logits[self.no_speech] = float("-inf")
+
+        # logits is changed in-place
+        logits[self.translate] = float("-inf")
+
+
+def load_tokens(filename):
+    tokens = dict()
+    with open(filename, "r") as f:
+        for line in f:
+            t, i = line.split()
+            tokens[int(i)] = t
+    return tokens
+
+
+def main():
+    args = get_args()
+    name = args.model
+
+    encoder = f"./{name}-encoder.onnx"
+    decoder = f"./{name}-decoder.onnx"
+    audio = whisper.load_audio("0.wav")
+
+    features = []
+    online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
+    online_whisper_fbank.accept_waveform(16000, audio)
+    online_whisper_fbank.input_finished()
+    for i in range(online_whisper_fbank.num_frames_ready):
+        f = online_whisper_fbank.get_frame(i)
+        f = torch.from_numpy(f)
+        features.append(f)
+
+    features = torch.stack(features)
+
+    log_spec = torch.clamp(features, min=1e-10).log10()
+    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+    mel = (log_spec + 4.0) / 4.0
+    target = 3000
+    mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
+    mel = mel.t().unsqueeze(0)
+
+    model = OnnxModel(encoder, decoder)
+    n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
+    n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
+
+    tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
+    offset = torch.zeros(1, dtype=torch.int64)
+    logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
+        tokens=tokens,
+        n_layer_self_k_cache=n_layer_self_k_cache,
+        n_layer_self_v_cache=n_layer_self_v_cache,
+        n_layer_cross_k=n_layer_cross_k,
+        n_layer_cross_v=n_layer_cross_v,
+        offset=offset,
+    )
+    # logits.shape (batch_size, tokens.shape[1], vocab_size)
+    logits = logits[0, -1]
+    model.suppress_tokens(logits, is_initial=True)
+    # logits = logits.softmax(dim=-1)
+    # for greedy search, we don't need to compute softmax or log_softmax
+    max_token_id = logits.argmax(dim=-1)
+    results = []
+    for i in range(model.n_text_ctx):
+        if max_token_id == model.eot:
+            break
+        results.append(max_token_id.item())
+        tokens = torch.tensor([[results[-1]]])
+        offset += 1
+
+        logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
+            tokens=tokens,
+            n_layer_self_k_cache=n_layer_self_k_cache,
+            n_layer_self_v_cache=n_layer_self_v_cache,
+            n_layer_cross_k=n_layer_cross_k,
+            n_layer_cross_v=n_layer_cross_v,
+            offset=offset,
+        )
+        logits = logits[0, -1]
+        model.suppress_tokens(logits, is_initial=False)
+        max_token_id = logits.argmax(dim=-1)
+    token_table = load_tokens(f"./{name}-tokens.txt")
+    s = b""
+    for i in results:
+        if i in token_table:
+            s += base64.b64decode(token_table[i])
+        else:
+            print("oov", i)
+
+    print(s.decode().strip())
+    print(results)
+    print(model.sot_sequence)
+
+
+if __name__ == "__main__":
+    main()
@@ -11,6 +11,7 @@ if(SHERPA_ONNX_ENABLE_PYTHON)
 endif()
 
 set(sources
+  base64-decode.cc
   cat.cc
   context-graph.cc
   endpoint.cc
@@ -35,6 +36,9 @@ set(sources
   offline-transducer-model-config.cc
   offline-transducer-model.cc
   offline-transducer-modified-beam-search-decoder.cc
+  offline-whisper-greedy-search-decoder.cc
+  offline-whisper-model-config.cc
+  offline-whisper-model.cc
   online-conformer-transducer-model.cc
   online-lm-config.cc
   online-lm.cc
@@ -50,12 +54,12 @@ set(sources
   online-zipformer-transducer-model.cc
   online-zipformer2-transducer-model.cc
   onnx-utils.cc
-  session.cc
   packed-sequence.cc
   pad-sequence.cc
   parse-options.cc
   provider.cc
   resample.cc
+  session.cc
   slice.cc
   stack.cc
   symbol-table.cc
sherpa-onnx/csrc/base64-decode.cc (new file, mode 100644)
+// sherpa-onnx/csrc/base64-decode.cc
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/base64-decode.h"
+
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+static int32_t Ord(char c) {
+  if (c >= 'A' && c <= 'Z') {
+    return c - 'A';
+  } else if (c >= 'a' && c <= 'z') {
+    return c - 'a' + ('Z' - 'A') + 1;
+  } else if (c >= '0' && c <= '9') {
+    return c - '0' + ('Z' - 'A') + ('z' - 'a') + 2;
+  } else if (c == '+') {
+    return 62;
+  } else if (c == '/') {
+    return 63;
+  }
+
+  SHERPA_ONNX_LOGE("Unknown character %d, %c\n", c, c);
+
+  exit(-1);
+}
+
+// see
+// https://github.com/ReneNyffenegger/cpp-base64/blob/master/base64.cpp#L243
+std::string Base64Decode(const std::string &s) {
+  if (s.empty()) {
+    SHERPA_ONNX_LOGE("Empty string!");
+    exit(-1);
+  }
+
+  int32_t n = s.size() / 4 * 3;
+
+  std::string ans;
+  ans.reserve(n);
+
+  int32_t i = 0;
+  while (i < static_cast<int32_t>(s.size())) {
+    if (s[i] == '=') {
+      return " ";
+    }
+
+    int32_t first = (Ord(s[i]) << 2) + ((Ord(s[i + 1]) & 0x30) >> 4);
+    ans.push_back(first);
+
+    if (i + 2 < static_cast<int32_t>(s.size()) && s[i + 2] != '=') {
+      int32_t second =
+          ((Ord(s[i + 1]) & 0x0f) << 4) + ((Ord(s[i + 2]) & 0x3c) >> 2);
+      ans.push_back(second);
+
+      if (i + 3 < static_cast<int32_t>(s.size()) && s[i + 3] != '=') {
+        int32_t third = ((Ord(s[i + 2]) & 0x03) << 6) + Ord(s[i + 3]);
+        ans.push_back(third);
+      }
+    }
+    i += 4;
+  }
+
+  return ans;
+}
+
+}  // namespace sherpa_onnx
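Since the decoder above is hand-rolled, it helps to know the reference behaviour it is meant to reproduce; a quick cross-check (not part of the PR) against Python's standard base64 module:

# Sketch only: reference outputs for spot-checking the C++ Base64Decode().
import base64

for s in ["IHRo", "IHRoZQ==", "4oCc"]:  # plausible entries from a whisper tokens.txt
    print(s, "->", base64.b64decode(s))
# IHRo     -> b' th'
# IHRoZQ== -> b' the'
# 4oCc     -> b'\xe2\x80\x9c'  (a left double quotation mark in UTF-8)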
sherpa-onnx/csrc/base64-decode.h (new file, mode 100644)
+// sherpa-onnx/csrc/base64-decode.h
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_BASE64_DECODE_H_
+#define SHERPA_ONNX_CSRC_BASE64_DECODE_H_
+
+#include <string>
+
+namespace sherpa_onnx {
+
+/** @param s A base64 encoded string.
+ *  @return Return the decoded string.
+ */
+std::string Base64Decode(const std::string &s);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_BASE64_DECODE_H_
@@ -14,6 +14,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
   transducer.Register(po);
   paraformer.Register(po);
   nemo_ctc.Register(po);
+  whisper.Register(po);
 
   po->Register("tokens", &tokens, "Path to tokens.txt");
 
@@ -28,7 +29,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
 
   po->Register("model-type", &model_type,
                "Specify it to reduce model initialization time. "
-               "Valid values are: transducer, paraformer, nemo_ctc. "
+               "Valid values are: transducer, paraformer, nemo_ctc, whisper."
                "All other values lead to loading the model twice.");
 }
 
@@ -51,6 +52,10 @@ bool OfflineModelConfig::Validate() const {
     return nemo_ctc.Validate();
  }
 
+  if (!whisper.encoder.empty()) {
+    return whisper.Validate();
+  }
+
   return transducer.Validate();
 }
 
@@ -61,6 +66,7 @@ std::string OfflineModelConfig::ToString() const {
   os << "transducer=" << transducer.ToString() << ", ";
   os << "paraformer=" << paraformer.ToString() << ", ";
   os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
+  os << "whisper=" << whisper.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
   os << "debug=" << (debug ? "True" : "False") << ", ";
@@ -9,6 +9,7 @@
 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
 #include "sherpa-onnx/csrc/offline-transducer-model-config.h"
+#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
 
 namespace sherpa_onnx {
 
@@ -16,6 +17,7 @@ struct OfflineModelConfig {
   OfflineTransducerModelConfig transducer;
   OfflineParaformerModelConfig paraformer;
   OfflineNemoEncDecCtcModelConfig nemo_ctc;
+  OfflineWhisperModelConfig whisper;
 
   std::string tokens;
   int32_t num_threads = 2;
@@ -37,11 +39,13 @@ struct OfflineModelConfig {
   OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
                      const OfflineParaformerModelConfig &paraformer,
                      const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
+                     const OfflineWhisperModelConfig &whisper,
                      const std::string &tokens, int32_t num_threads, bool debug,
                      const std::string &provider, const std::string &model_type)
       : transducer(transducer),
        paraformer(paraformer),
        nemo_ctc(nemo_ctc),
+        whisper(whisper),
        tokens(tokens),
        num_threads(num_threads),
        debug(debug),
@@ -16,7 +16,7 @@ void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) {
 
 bool OfflineNemoEncDecCtcModelConfig::Validate() const {
   if (!FileExists(model)) {
-    SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
+    SHERPA_ONNX_LOGE("NeMo model: %s does not exist", model.c_str());
     return false;
   }
 
@@ -15,7 +15,7 @@ void OfflineParaformerModelConfig::Register(ParseOptions *po) {
 
 bool OfflineParaformerModelConfig::Validate() const {
   if (!FileExists(model)) {
-    SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
+    SHERPA_ONNX_LOGE("Paraformer model %s does not exist", model.c_str());
     return false;
   }
 
@@ -11,6 +11,7 @@
 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 
@@ -26,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerParaformerImpl>(config);
   } else if (model_type == "nemo_ctc") {
     return std::make_unique<OfflineRecognizerCtcImpl>(config);
+  } else if (model_type == "whisper") {
+    return std::make_unique<OfflineRecognizerWhisperImpl>(config);
   } else {
     SHERPA_ONNX_LOGE(
         "Invalid model_type: %s. Trying to load the model to get its type",
@@ -43,6 +46,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     model_filename = config.model_config.paraformer.model;
   } else if (!config.model_config.nemo_ctc.model.empty()) {
     model_filename = config.model_config.nemo_ctc.model;
+  } else if (!config.model_config.whisper.encoder.empty()) {
+    model_filename = config.model_config.whisper.encoder;
   } else {
     SHERPA_ONNX_LOGE("Please provide a model");
     exit(-1);
@@ -77,6 +82,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
         "\n  "
         "https://huggingface.co/csukuangfj/"
         "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
+        "\n  "
+        "(3) Whisper"
         "\n");
     exit(-1);
   }
@@ -95,12 +102,17 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerCtcImpl>(config);
   }
 
+  if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
+    return std::make_unique<OfflineRecognizerWhisperImpl>(config);
+  }
+
   SHERPA_ONNX_LOGE(
       "\nUnsupported model_type: %s\n"
       "We support only the following model types at present: \n"
       " - Non-streaming transducer models from icefall\n"
       " - Non-streaming Paraformer models from FunASR\n"
-      " - EncDecCTCModelBPE models from NeMo\n",
+      " - EncDecCTCModelBPE models from NeMo\n"
+      " - Whisper models\n",
      model_type.c_str());
 
   exit(-1);
+// sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer.h"
+#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
+#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h"
+#include "sherpa-onnx/csrc/offline-whisper-model.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
+#include "sherpa-onnx/csrc/transpose.h"
+
+namespace sherpa_onnx {
+
+static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
+                                        const SymbolTable &sym_table) {
+  OfflineRecognitionResult r;
+  r.tokens.reserve(src.tokens.size());
+
+  for (auto i : src.tokens) {
+    if (!sym_table.contains(i)) {
+      continue;
+    }
+
+    const auto &s = sym_table[i];
+    r.text += s;
+    r.tokens.push_back(s);
+  }
+
+  return r;
+}
+
+class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
+ public:
+  explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
+      : config_(config),
+        symbol_table_(config_.model_config.tokens),
+        model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
+    // tokens.txt from whisper is base64 encoded, so we need to decode it
+    symbol_table_.ApplyBase64Decode();
+
+    if (config.decoding_method == "greedy_search") {
+      decoder_ =
+          std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
+    } else {
+      SHERPA_ONNX_LOGE(
+          "Only greedy_search is supported at present for whisper. Given %s",
+          config.decoding_method.c_str());
+      exit(-1);
+    }
+  }
+
+  std::unique_ptr<OfflineStream> CreateStream() const override {
+    return std::make_unique<OfflineStream>(WhisperTag{});
+  }
+
+  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
+    // batch decoding is not implemented yet
+    for (int32_t i = 0; i != n; ++i) {
+      DecodeStream(ss[i]);
+    }
+  }
+
+ private:
+  void DecodeStream(OfflineStream *s) const {
+    int32_t max_num_frames = 3000;
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    int32_t feat_dim = s->FeatureDim();
+    std::vector<float> f = s->GetFrames();
+    int32_t num_frames = f.size() / feat_dim;
+
+    if (num_frames > max_num_frames) {
+      SHERPA_ONNX_LOGE("Only waves less than 30 seconds are supported.");
+      exit(-1);
+    }
+
+    NormalizeFeatures(f.data(), num_frames, feat_dim);
+
+    std::array<int64_t, 3> shape{1, max_num_frames, feat_dim};
+
+    Ort::Value mel = Ort::Value::CreateTensor<float>(
+        model_->Allocator(), shape.data(), shape.size());
+    float *p_mel = mel.GetTensorMutableData<float>();
+    std::copy(f.begin(), f.end(), p_mel);
+
+    memset(p_mel + f.size(), 0,
+           (max_num_frames - num_frames) * feat_dim * sizeof(float));
+    mel = Transpose12(model_->Allocator(), &mel);
+
+    auto cross_kv = model_->ForwardEncoder(std::move(mel));
+    auto results =
+        decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
+
+    auto r = Convert(results[0], symbol_table_);
+    s->SetResult(r);
+  }
+
+ private:
+  static void NormalizeFeatures(float *features, int32_t num_frames,
+                                int32_t feat_dim) {
+    // log_spec = torch.clamp(features, min=1e-10).log10()
+    // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+    // mel = (log_spec + 4.0) / 4.0
+
+    int32_t n = num_frames * feat_dim;
+    float max_v = -1e20;
+    for (int32_t i = 0; i != n; ++i) {
+      float f = features[i];
+
+      f = std::max<float>(f, 1e-10);
+      f = std::log10(f);
+
+      max_v = std::max(f, max_v);
+
+      features[i] = f;
+    }
+
+    max_v -= 8;
+
+    for (int32_t i = 0; i != n; ++i) {
+      float f = features[i];
+      f = std::max(f, max_v);
+
+      f = (f + 4) / 4;
+
+      features[i] = f;
+    }
+  }
+
+ private:
+  OfflineRecognizerConfig config_;
+  SymbolTable symbol_table_;
+  std::unique_ptr<OfflineWhisperModel> model_;
| 147 | + std::unique_ptr<OfflineWhisperDecoder> decoder_; | ||
| 148 | +}; | ||
| 149 | + | ||
| 150 | +} // namespace sherpa_onnx | ||
| 151 | + | ||
| 152 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_ |
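The NormalizeFeatures() helper above reproduces Whisper's log-mel post-processing, quoted in the comment as PyTorch code. As a quick illustration, here is a NumPy sketch of the same computation (illustrative only; the function and array names are not part of this patch):

    import numpy as np

    def normalize_features(features: np.ndarray) -> np.ndarray:
        # features: (num_frames, 80) power mel spectrogram from the fbank
        log_spec = np.log10(np.maximum(features, 1e-10))       # clamp(min=1e-10).log10()
        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)  # floor at global max - 8
        return (log_spec + 4.0) / 4.0                          # (log_spec + 4) / 4

The C++ version does the same thing in two passes over the flat feature buffer so that the global maximum is known before the flooring step.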
| @@ -86,6 +86,15 @@ class OfflineStream::Impl { | @@ -86,6 +86,15 @@ class OfflineStream::Impl { | ||
| 86 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | 86 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); |
| 87 | } | 87 | } |
| 88 | 88 | ||
| 89 | + Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph) | ||
| 90 | + : context_graph_(context_graph) { | ||
| 91 | + config_.normalize_samples = true; | ||
| 92 | + opts_.frame_opts.samp_freq = 16000; | ||
| 93 | + opts_.mel_opts.num_bins = 80; | ||
| 94 | + whisper_fbank_ = | ||
| 95 | + std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts); | ||
| 96 | + } | ||
| 97 | + | ||
| 89 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { | 98 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { |
| 90 | if (config_.normalize_samples) { | 99 | if (config_.normalize_samples) { |
| 91 | AcceptWaveformImpl(sampling_rate, waveform, n); | 100 | AcceptWaveformImpl(sampling_rate, waveform, n); |
| @@ -117,20 +126,35 @@ class OfflineStream::Impl { | @@ -117,20 +126,35 @@ class OfflineStream::Impl { | ||
| 117 | lowpass_filter_width); | 126 | lowpass_filter_width); |
| 118 | std::vector<float> samples; | 127 | std::vector<float> samples; |
| 119 | resampler->Resample(waveform, n, true, &samples); | 128 | resampler->Resample(waveform, n, true, &samples); |
| 120 | - fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(), | ||
| 121 | - samples.size()); | ||
| 122 | - fbank_->InputFinished(); | 129 | + |
| 130 | + if (fbank_) { | ||
| 131 | + fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(), | ||
| 132 | + samples.size()); | ||
| 133 | + fbank_->InputFinished(); | ||
| 134 | + } else { | ||
| 135 | + whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, | ||
| 136 | + samples.data(), samples.size()); | ||
| 137 | + whisper_fbank_->InputFinished(); | ||
| 138 | + } | ||
| 139 | + | ||
| 123 | return; | 140 | return; |
| 124 | - } | 141 | + } // if (sampling_rate != opts_.frame_opts.samp_freq) |
| 125 | 142 | ||
| 126 | - fbank_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 127 | - fbank_->InputFinished(); | 143 | + if (fbank_) { |
| 144 | + fbank_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 145 | + fbank_->InputFinished(); | ||
| 146 | + } else { | ||
| 147 | + whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 148 | + whisper_fbank_->InputFinished(); | ||
| 149 | + } | ||
| 128 | } | 150 | } |
| 129 | 151 | ||
| 130 | int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } | 152 | int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } |
| 131 | 153 | ||
| 132 | std::vector<float> GetFrames() const { | 154 | std::vector<float> GetFrames() const { |
| 133 | - int32_t n = fbank_->NumFramesReady(); | 155 | + int32_t n = |
| 156 | + fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady(); | ||
| 157 | + | ||
| 134 | assert(n > 0 && "Please first call AcceptWaveform()"); | 158 | assert(n > 0 && "Please first call AcceptWaveform()"); |
| 135 | 159 | ||
| 136 | int32_t feature_dim = FeatureDim(); | 160 | int32_t feature_dim = FeatureDim(); |
| @@ -140,7 +164,8 @@ class OfflineStream::Impl { | @@ -140,7 +164,8 @@ class OfflineStream::Impl { | ||
| 140 | float *p = features.data(); | 164 | float *p = features.data(); |
| 141 | 165 | ||
| 142 | for (int32_t i = 0; i != n; ++i) { | 166 | for (int32_t i = 0; i != n; ++i) { |
| 143 | - const float *f = fbank_->GetFrame(i); | 167 | + const float *f = |
| 168 | + fbank_ ? fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i); | ||
| 144 | std::copy(f, f + feature_dim, p); | 169 | std::copy(f, f + feature_dim, p); |
| 145 | p += feature_dim; | 170 | p += feature_dim; |
| 146 | } | 171 | } |
| @@ -191,6 +216,7 @@ class OfflineStream::Impl { | @@ -191,6 +216,7 @@ class OfflineStream::Impl { | ||
| 191 | private: | 216 | private: |
| 192 | OfflineFeatureExtractorConfig config_; | 217 | OfflineFeatureExtractorConfig config_; |
| 193 | std::unique_ptr<knf::OnlineFbank> fbank_; | 218 | std::unique_ptr<knf::OnlineFbank> fbank_; |
| 219 | + std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_; | ||
| 194 | knf::FbankOptions opts_; | 220 | knf::FbankOptions opts_; |
| 195 | OfflineRecognitionResult r_; | 221 | OfflineRecognitionResult r_; |
| 196 | ContextGraphPtr context_graph_; | 222 | ContextGraphPtr context_graph_; |
| @@ -201,6 +227,10 @@ OfflineStream::OfflineStream( | @@ -201,6 +227,10 @@ OfflineStream::OfflineStream( | ||
| 201 | ContextGraphPtr context_graph /*= nullptr*/) | 227 | ContextGraphPtr context_graph /*= nullptr*/) |
| 202 | : impl_(std::make_unique<Impl>(config, context_graph)) {} | 228 | : impl_(std::make_unique<Impl>(config, context_graph)) {} |
| 203 | 229 | ||
| 230 | +OfflineStream::OfflineStream(WhisperTag tag, | ||
| 231 | + ContextGraphPtr context_graph /*= nullptr*/) | ||
| 232 | + : impl_(std::make_unique<Impl>(tag, context_graph)) {} | ||
| 233 | + | ||
| 204 | OfflineStream::~OfflineStream() = default; | 234 | OfflineStream::~OfflineStream() = default; |
| 205 | 235 | ||
| 206 | void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, | 236 | void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, |
| @@ -65,10 +65,15 @@ struct OfflineFeatureExtractorConfig { | @@ -65,10 +65,15 @@ struct OfflineFeatureExtractorConfig { | ||
| 65 | void Register(ParseOptions *po); | 65 | void Register(ParseOptions *po); |
| 66 | }; | 66 | }; |
| 67 | 67 | ||
| 68 | +struct WhisperTag {}; | ||
| 69 | + | ||
| 68 | class OfflineStream { | 70 | class OfflineStream { |
| 69 | public: | 71 | public: |
| 70 | explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, | 72 | explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, |
| 71 | ContextGraphPtr context_graph = nullptr); | 73 | ContextGraphPtr context_graph = nullptr); |
| 74 | + | ||
| 75 | + explicit OfflineStream(WhisperTag tag, | ||
| 76 | + ContextGraphPtr context_graph = nullptr); | ||
| 72 | ~OfflineStream(); | 77 | ~OfflineStream(); |
| 73 | 78 | ||
| 74 | /** | 79 | /** |
| @@ -18,17 +18,20 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) { | @@ -18,17 +18,20 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) { | ||
| 18 | 18 | ||
| 19 | bool OfflineTransducerModelConfig::Validate() const { | 19 | bool OfflineTransducerModelConfig::Validate() const { |
| 20 | if (!FileExists(encoder_filename)) { | 20 | if (!FileExists(encoder_filename)) { |
| 21 | - SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str()); | 21 | + SHERPA_ONNX_LOGE("transducer encoder: %s does not exist", |
| 22 | + encoder_filename.c_str()); | ||
| 22 | return false; | 23 | return false; |
| 23 | } | 24 | } |
| 24 | 25 | ||
| 25 | if (!FileExists(decoder_filename)) { | 26 | if (!FileExists(decoder_filename)) { |
| 26 | - SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str()); | 27 | + SHERPA_ONNX_LOGE("transducer decoder: %s does not exist", |
| 28 | + decoder_filename.c_str()); | ||
| 27 | return false; | 29 | return false; |
| 28 | } | 30 | } |
| 29 | 31 | ||
| 30 | if (!FileExists(joiner_filename)) { | 32 | if (!FileExists(joiner_filename)) { |
| 31 | - SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str()); | 33 | + SHERPA_ONNX_LOGE("transducer joiner: %s does not exist", |
| 34 | + joiner_filename.c_str()); | ||
| 32 | return false; | 35 | return false; |
| 33 | } | 36 | } |
| 34 | 37 |
sherpa-onnx/csrc/offline-whisper-decoder.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-whisper-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct OfflineWhisperDecoderResult { | ||
| 15 | + /// The decoded token IDs | ||
| 16 | + std::vector<int32_t> tokens; | ||
| 17 | +}; | ||
| 18 | + | ||
| 19 | +class OfflineWhisperDecoder { | ||
| 20 | + public: | ||
| 21 | + virtual ~OfflineWhisperDecoder() = default; | ||
| 22 | + | ||
| 23 | + /** Run decoding given the output from the whisper encoder model. | ||
| 24 | + * | ||
| 25 | + * @param n_layer_cross_k A 4-D tensor of shape | ||
| 26 | + * (n_text_layer, N, n_audio_ctx, n_text_state). | ||
| 27 | + * @param n_layer_cross_v A 4-D tensor of shape | ||
| 28 | + * (n_text_layer, N, n_audio_ctx, n_text_state). | ||
| 29 | + * | ||
| 30 | + * @return Return a vector of size `N` containing the decoded results. | ||
| 31 | + */ | ||
| 32 | + virtual std::vector<OfflineWhisperDecoderResult> Decode( | ||
| 33 | + Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; | ||
| 34 | +}; | ||
| 35 | + | ||
| 36 | +} // namespace sherpa_onnx | ||
| 37 | + | ||
| 38 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ |
| 1 | +// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <utility> | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +std::vector<OfflineWhisperDecoderResult> | ||
| 13 | +OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | ||
| 14 | + Ort::Value cross_v) { | ||
| 15 | + auto memory_info = | ||
| 16 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 17 | + | ||
| 18 | + auto self_kv_cache = model_->GetInitialSelfKVCache(); | ||
| 19 | + | ||
| 20 | + std::vector<int64_t> initial_tokens = model_->GetInitialTokens(); | ||
| 21 | + int32_t batch_size = 1; | ||
| 22 | + std::array<int64_t, 2> token_shape{ | ||
| 23 | + batch_size, static_cast<int64_t>(initial_tokens.size())}; | ||
| 24 | + | ||
| 25 | + Ort::Value tokens = Ort::Value::CreateTensor( | ||
| 26 | + memory_info, initial_tokens.data(), initial_tokens.size(), | ||
| 27 | + token_shape.data(), token_shape.size()); | ||
| 28 | + | ||
| 29 | + std::array<int64_t, 1> offset_shape{1}; | ||
| 30 | + Ort::Value offset = Ort::Value::CreateTensor<int64_t>( | ||
| 31 | + model_->Allocator(), offset_shape.data(), offset_shape.size()); | ||
| 32 | + *(offset.GetTensorMutableData<int64_t>()) = 0; | ||
| 33 | + | ||
| 34 | + auto decoder_out = model_->ForwardDecoder( | ||
| 35 | + std::move(tokens), std::move(self_kv_cache.first), | ||
| 36 | + std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v), | ||
| 37 | + std::move(offset)); | ||
| 38 | + | ||
| 39 | + const auto &logits = std::get<0>(decoder_out); | ||
| 40 | + const float *p_logits = logits.GetTensorData<float>(); | ||
| 41 | + | ||
| 42 | + auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 43 | + int32_t vocab_size = logits_shape[2]; | ||
| 44 | + | ||
| 45 | + int32_t max_token_id = static_cast<int32_t>(std::distance( | ||
| 46 | + p_logits, std::max_element(p_logits, p_logits + vocab_size))); | ||
| 47 | + | ||
| 48 | + int32_t n_text_ctx = model_->TextCtx(); | ||
| 49 | + | ||
| 50 | + std::vector<int32_t> predicted_tokens; | ||
| 51 | + for (int32_t i = 0; i < n_text_ctx; ++i) { | ||
| 52 | + if (max_token_id == model_->EOT()) { | ||
| 53 | + break; | ||
| 54 | + } | ||
| 55 | + | ||
| 56 | + predicted_tokens.push_back(max_token_id); | ||
| 57 | + | ||
| 58 | + std::array<int64_t, 2> token_shape{1, 1}; | ||
| 59 | + Ort::Value tokens = Ort::Value::CreateTensor<int64_t>( | ||
| 60 | + model_->Allocator(), token_shape.data(), token_shape.size()); | ||
| 61 | + int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>(); | ||
| 62 | + p_tokens[0] = max_token_id; | ||
| 63 | + | ||
| 64 | + int64_t *p_offset = | ||
| 65 | + std::get<5>(decoder_out).GetTensorMutableData<int64_t>(); | ||
| 66 | + | ||
| 67 | + if (i == 0) { | ||
| 68 | + *p_offset = initial_tokens.size(); | ||
| 69 | + } else { | ||
| 70 | + *p_offset += 1; | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + decoder_out = model_->ForwardDecoder(std::move(tokens), | ||
| 74 | + std::move(std::get<1>(decoder_out)), | ||
| 75 | + std::move(std::get<2>(decoder_out)), | ||
| 76 | + std::move(std::get<3>(decoder_out)), | ||
| 77 | + std::move(std::get<4>(decoder_out)), | ||
| 78 | + std::move(std::get<5>(decoder_out))); | ||
| 79 | + | ||
| 80 | + const auto &logits = std::get<0>(decoder_out); | ||
| 81 | + const float *p_logits = logits.GetTensorData<float>(); | ||
| 82 | + | ||
| 83 | + max_token_id = static_cast<int32_t>(std::distance( | ||
| 84 | + p_logits, std::max_element(p_logits, p_logits + vocab_size))); | ||
| 85 | + } | ||
| 86 | + | ||
| 87 | + std::vector<OfflineWhisperDecoderResult> ans(1); | ||
| 88 | + ans[0].tokens = std::move(predicted_tokens); | ||
| 89 | + | ||
| 90 | + return ans; | ||
| 91 | +} | ||
| 92 | + | ||
| 93 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-whisper-decoder.h" | ||
| 11 | +#include "sherpa-onnx/csrc/offline-whisper-model.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { | ||
| 16 | + public: | ||
| 17 | + explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model) | ||
| 18 | + : model_(model) {} | ||
| 19 | + | ||
| 20 | + std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k, | ||
| 21 | + Ort::Value cross_v) override; | ||
| 22 | + | ||
| 23 | + private: | ||
| 24 | + OfflineWhisperModel *model_; // not owned | ||
| 25 | +}; | ||
| 26 | + | ||
| 27 | +} // namespace sherpa_onnx | ||
| 28 | + | ||
| 29 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_ |
| 1 | +// sherpa-onnx/csrc/offline-whisper-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-whisper-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OfflineWhisperModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("whisper-encoder", &encoder, | ||
| 14 | + "Path to onnx encoder of whisper, e.g., tiny-encoder.onnx, " | ||
| 15 | + "medium.en-encoder.onnx."); | ||
| 16 | + | ||
| 17 | + po->Register("whisper-decoder", &decoder, | ||
| 18 | + "Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, " | ||
| 19 | + "medium.en-decoder.onnx."); | ||
| 20 | +} | ||
| 21 | + | ||
| 22 | +bool OfflineWhisperModelConfig::Validate() const { | ||
| 23 | + if (!FileExists(encoder)) { | ||
| 24 | + SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str()); | ||
| 25 | + return false; | ||
| 26 | + } | ||
| 27 | + | ||
| 28 | + if (!FileExists(decoder)) { | ||
| 29 | + SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str()); | ||
| 30 | + return false; | ||
| 31 | + } | ||
| 32 | + | ||
| 33 | + return true; | ||
| 34 | +} | ||
| 35 | + | ||
| 36 | +std::string OfflineWhisperModelConfig::ToString() const { | ||
| 37 | + std::ostringstream os; | ||
| 38 | + | ||
| 39 | + os << "OfflineWhisperModelConfig("; | ||
| 40 | + os << "encoder=\"" << encoder << "\", "; | ||
| 41 | + os << "decoder=\"" << decoder << "\")"; | ||
| 42 | + | ||
| 43 | + return os.str(); | ||
| 44 | +} | ||
| 45 | + | ||
| 46 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-whisper-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OfflineWhisperModelConfig { | ||
| 14 | + std::string encoder; | ||
| 15 | + std::string decoder; | ||
| 16 | + | ||
| 17 | + OfflineWhisperModelConfig() = default; | ||
| 18 | + OfflineWhisperModelConfig(const std::string &encoder, | ||
| 19 | + const std::string &decoder) | ||
| 20 | + : encoder(encoder), decoder(decoder) {} | ||
| 21 | + | ||
| 22 | + void Register(ParseOptions *po); | ||
| 23 | + bool Validate() const; | ||
| 24 | + | ||
| 25 | + std::string ToString() const; | ||
| 26 | +}; | ||
| 27 | + | ||
| 28 | +} // namespace sherpa_onnx | ||
| 29 | + | ||
| 30 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ |
sherpa-onnx/csrc/offline-whisper-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-whisper-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-whisper-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <string> | ||
| 9 | +#include <tuple> | ||
| 10 | +#include <utility> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 13 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 14 | +#include "sherpa-onnx/csrc/session.h" | ||
| 15 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
| 19 | +class OfflineWhisperModel::Impl { | ||
| 20 | + public: | ||
| 21 | + explicit Impl(const OfflineModelConfig &config) | ||
| 22 | + : config_(config), | ||
| 23 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 24 | + sess_opts_(GetSessionOptions(config)), | ||
| 25 | + allocator_{} { | ||
| 26 | + { | ||
| 27 | + auto buf = ReadFile(config.whisper.encoder); | ||
| 28 | + InitEncoder(buf.data(), buf.size()); | ||
| 29 | + } | ||
| 30 | + | ||
| 31 | + { | ||
| 32 | + auto buf = ReadFile(config.whisper.decoder); | ||
| 33 | + InitDecoder(buf.data(), buf.size()); | ||
| 34 | + } | ||
| 35 | + } | ||
| 36 | + | ||
| 37 | + std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) { | ||
| 38 | + auto encoder_out = encoder_sess_->Run( | ||
| 39 | + {}, encoder_input_names_ptr_.data(), &features, 1, | ||
| 40 | + encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size()); | ||
| 41 | + | ||
| 42 | + return {std::move(encoder_out[0]), std::move(encoder_out[1])}; | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value, | ||
| 46 | + Ort::Value> | ||
| 47 | + ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache, | ||
| 48 | + Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, | ||
| 49 | + Ort::Value n_layer_cross_v, Ort::Value offset) { | ||
| 50 | + std::array<Ort::Value, 6> decoder_input = {std::move(tokens), | ||
| 51 | + std::move(n_layer_self_k_cache), | ||
| 52 | + std::move(n_layer_self_v_cache), | ||
| 53 | + std::move(n_layer_cross_k), | ||
| 54 | + std::move(n_layer_cross_v), | ||
| 55 | + std::move(offset)}; | ||
| 56 | + | ||
| 57 | + auto decoder_out = decoder_sess_->Run( | ||
| 58 | + {}, decoder_input_names_ptr_.data(), decoder_input.data(), | ||
| 59 | + decoder_input.size(), decoder_output_names_ptr_.data(), | ||
| 60 | + decoder_output_names_ptr_.size()); | ||
| 61 | + | ||
| 62 | + return {std::move(decoder_out[0]), std::move(decoder_out[1]), | ||
| 63 | + std::move(decoder_out[2]), std::move(decoder_input[3]), | ||
| 64 | + std::move(decoder_input[4]), std::move(decoder_input[5])}; | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() { | ||
| 68 | + std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_}; | ||
| 69 | + | ||
| 70 | + Ort::Value n_layer_self_k_cache = Ort::Value::CreateTensor<float>( | ||
| 71 | + Allocator(), shape.data(), shape.size()); | ||
| 72 | + | ||
| 73 | + Ort::Value n_layer_self_v_cache = Ort::Value::CreateTensor<float>( | ||
| 74 | + Allocator(), shape.data(), shape.size()); | ||
| 75 | + | ||
| 76 | + auto n = shape[0] * shape[1] * shape[2] * shape[3]; | ||
| 77 | + | ||
| 78 | + float *p_k = n_layer_self_k_cache.GetTensorMutableData<float>(); | ||
| 79 | + float *p_v = n_layer_self_v_cache.GetTensorMutableData<float>(); | ||
| 80 | + | ||
| 81 | + memset(p_k, 0, sizeof(float) * n); | ||
| 82 | + memset(p_v, 0, sizeof(float) * n); | ||
| 83 | + | ||
| 84 | + return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)}; | ||
| 85 | + } | ||
| 86 | + | ||
| 87 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 88 | + | ||
| 89 | + const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; } | ||
| 90 | + | ||
| 91 | + int32_t EOT() const { return eot_; } | ||
| 92 | + | ||
| 93 | + int32_t TextCtx() const { return n_text_ctx_; } | ||
| 94 | + | ||
| 95 | + private: | ||
| 96 | + void InitEncoder(void *model_data, size_t model_data_length) { | ||
| 97 | + encoder_sess_ = std::make_unique<Ort::Session>( | ||
| 98 | + env_, model_data, model_data_length, sess_opts_); | ||
| 99 | + | ||
| 100 | + GetInputNames(encoder_sess_.get(), &encoder_input_names_, | ||
| 101 | + &encoder_input_names_ptr_); | ||
| 102 | + | ||
| 103 | + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, | ||
| 104 | + &encoder_output_names_ptr_); | ||
| 105 | + | ||
| 106 | + // get meta data | ||
| 107 | + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); | ||
| 108 | + if (config_.debug) { | ||
| 109 | + std::ostringstream os; | ||
| 110 | + os << "---encoder---\n"; | ||
| 111 | + PrintModelMetadata(os, meta_data); | ||
| 112 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 113 | + } | ||
| 114 | + | ||
| 115 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 116 | + SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer"); | ||
| 117 | + SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx"); | ||
| 118 | + SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state"); | ||
| 119 | + SHERPA_ONNX_READ_META_DATA(sot_, "sot"); | ||
| 120 | + SHERPA_ONNX_READ_META_DATA(eot_, "eot"); | ||
| 121 | + SHERPA_ONNX_READ_META_DATA(blank_, "blank_id"); | ||
| 122 | + SHERPA_ONNX_READ_META_DATA(translate_, "translate"); | ||
| 123 | + SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps"); | ||
| 124 | + SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech"); | ||
| 125 | + SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence"); | ||
| 126 | + } | ||
| 127 | + | ||
| 128 | + void InitDecoder(void *model_data, size_t model_data_length) { | ||
| 129 | + decoder_sess_ = std::make_unique<Ort::Session>( | ||
| 130 | + env_, model_data, model_data_length, sess_opts_); | ||
| 131 | + | ||
| 132 | + GetInputNames(decoder_sess_.get(), &decoder_input_names_, | ||
| 133 | + &decoder_input_names_ptr_); | ||
| 134 | + | ||
| 135 | + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, | ||
| 136 | + &decoder_output_names_ptr_); | ||
| 137 | + } | ||
| 138 | + | ||
| 139 | + private: | ||
| 140 | + OfflineModelConfig config_; | ||
| 141 | + Ort::Env env_; | ||
| 142 | + Ort::SessionOptions sess_opts_; | ||
| 143 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 144 | + | ||
| 145 | + std::unique_ptr<Ort::Session> encoder_sess_; | ||
| 146 | + std::unique_ptr<Ort::Session> decoder_sess_; | ||
| 147 | + | ||
| 148 | + std::vector<std::string> encoder_input_names_; | ||
| 149 | + std::vector<const char *> encoder_input_names_ptr_; | ||
| 150 | + | ||
| 151 | + std::vector<std::string> encoder_output_names_; | ||
| 152 | + std::vector<const char *> encoder_output_names_ptr_; | ||
| 153 | + | ||
| 154 | + std::vector<std::string> decoder_input_names_; | ||
| 155 | + std::vector<const char *> decoder_input_names_ptr_; | ||
| 156 | + | ||
| 157 | + std::vector<std::string> decoder_output_names_; | ||
| 158 | + std::vector<const char *> decoder_output_names_ptr_; | ||
| 159 | + | ||
| 160 | + // model meta data | ||
| 161 | + int32_t n_text_layer_; | ||
| 162 | + int32_t n_text_ctx_; | ||
| 163 | + int32_t n_text_state_; | ||
| 164 | + int32_t sot_; | ||
| 165 | + int32_t eot_; | ||
| 166 | + int32_t blank_; | ||
| 167 | + int32_t translate_; | ||
| 168 | + int32_t no_timestamps_; | ||
| 169 | + int32_t no_speech_; | ||
| 170 | + std::vector<int64_t> sot_sequence_; | ||
| 171 | +}; | ||
| 172 | + | ||
| 173 | +OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config) | ||
| 174 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 175 | + | ||
| 176 | +OfflineWhisperModel::~OfflineWhisperModel() = default; | ||
| 177 | + | ||
| 178 | +std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder( | ||
| 179 | + Ort::Value features) { | ||
| 180 | + return impl_->ForwardEncoder(std::move(features)); | ||
| 181 | +} | ||
| 182 | + | ||
| 183 | +std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value, | ||
| 184 | + Ort::Value> | ||
| 185 | +OfflineWhisperModel::ForwardDecoder(Ort::Value tokens, | ||
| 186 | + Ort::Value n_layer_self_k_cache, | ||
| 187 | + Ort::Value n_layer_self_v_cache, | ||
| 188 | + Ort::Value n_layer_cross_k, | ||
| 189 | + Ort::Value n_layer_cross_v, | ||
| 190 | + Ort::Value offset) { | ||
| 191 | + return impl_->ForwardDecoder( | ||
| 192 | + std::move(tokens), std::move(n_layer_self_k_cache), | ||
| 193 | + std::move(n_layer_self_v_cache), std::move(n_layer_cross_k), | ||
| 194 | + std::move(n_layer_cross_v), std::move(offset)); | ||
| 195 | +} | ||
| 196 | + | ||
| 197 | +std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() { | ||
| 198 | + return impl_->GetInitialSelfKVCache(); | ||
| 199 | +} | ||
| 200 | + | ||
| 201 | +OrtAllocator *OfflineWhisperModel::Allocator() const { | ||
| 202 | + return impl_->Allocator(); | ||
| 203 | +} | ||
| 204 | + | ||
| 205 | +const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const { | ||
| 206 | + return impl_->GetInitialTokens(); | ||
| 207 | +} | ||
| 208 | + | ||
| 209 | +int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); } | ||
| 210 | + | ||
| 211 | +int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); } | ||
| 212 | + | ||
| 213 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-whisper-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-whisper-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <tuple> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 13 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +class OfflineWhisperModel { | ||
| 18 | + public: | ||
| 19 | + explicit OfflineWhisperModel(const OfflineModelConfig &config); | ||
| 20 | + ~OfflineWhisperModel(); | ||
| 21 | + | ||
| 22 | + /** Run the encoder model. | ||
| 23 | + * | ||
| 24 | + * @param features A tensor of shape (N, C, T). It is changed in-place. | ||
| 25 | + * C is 80 and T is 3000. | ||
| 26 | + * | ||
| 27 | + * @return Return a pair containing: | ||
| 28 | + * - n_layer_cross_k: A 4-D tensor of shape | ||
| 29 | + * (n_text_layer, N, n_audio_ctx, n_text_state) | ||
| 30 | + * - n_layer_cross_v: A 4-D tensor of shape | ||
| 31 | + * (n_text_layer, N, n_audio_ctx, n_text_state) | ||
| 32 | + */ | ||
| 33 | + std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features); | ||
| 34 | + | ||
| 35 | + /** Run the decoder model. | ||
| 36 | + * | ||
| 37 | + * @param tokens An int64 tensor of shape (N, num_words) | ||
| 38 | + * @param n_layer_self_k_cache A 4-D tensor of shape | ||
| 39 | + * (n_text_layer, N, n_text_ctx, n_text_state). | ||
| 40 | + * @param n_layer_self_v_cache A 4-D tensor of shape | ||
| 41 | + * (n_text_layer, N, n_text_ctx, n_text_state). | ||
| 42 | + * @param n_layer_cross_k A 4-D tensor of shape | ||
| 43 | + * (n_text_layer, N, n_audio_ctx, n_text_state). | ||
| 44 | + * @param n_layer_cross_v A 4-D tensor of shape | ||
| 45 | + * (n_text_layer, N, n_audio_ctx, n_text_state). | ||
| 46 | + * @param offset A int64 tensor of shape (N,) | ||
| 47 | + * | ||
| 48 | + * @return Return a tuple containing 6 tensors: | ||
| 49 | + * | ||
| 50 | + * - logits A 3-D tensor of shape (N, num_words, vocab_size) | ||
| 51 | + * - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache | ||
| 52 | + * - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache | ||
| 53 | + * - out_n_layer_cross_k Same as n_layer_cross_k | ||
| 54 | + * - out_n_layer_cross_v Same as n_layer_cross_v | ||
| 55 | + * - out_offset Same as offset | ||
| 56 | + */ | ||
| 57 | + std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value, | ||
| 58 | + Ort::Value> | ||
| 59 | + ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache, | ||
| 60 | + Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, | ||
| 61 | + Ort::Value n_layer_cross_v, Ort::Value offset); | ||
| 62 | + | ||
| 63 | + /** Return the initial self kv cache in a pair | ||
| 64 | + * - n_layer_self_k_cache A 4-D tensor of shape | ||
| 65 | + * (n_text_layer, N, n_audio_ctx, n_text_state). | ||
| 66 | + * - n_layer_self_v_cache A 4-D tensor of shape | ||
| 67 | + * (n_text_layer, N, n_audio_ctx, n_text_state). | ||
| 68 | + */ | ||
| 69 | + std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache(); | ||
| 70 | + const std::vector<int64_t> &GetInitialTokens() const; | ||
| 71 | + | ||
| 72 | + /** Return an allocator for allocating memory | ||
| 73 | + */ | ||
| 74 | + OrtAllocator *Allocator() const; | ||
| 75 | + int32_t EOT() const; | ||
| 76 | + int32_t TextCtx() const; | ||
| 77 | + | ||
| 78 | + private: | ||
| 79 | + class Impl; | ||
| 80 | + std::unique_ptr<Impl> impl_; | ||
| 81 | +}; | ||
| 82 | + | ||
| 83 | +} // namespace sherpa_onnx | ||
| 84 | + | ||
| 85 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_ |
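To sanity-check the tensor shapes and metadata documented above against an exported model, a few lines of onnxruntime are enough. This is only a debugging sketch; the file names are placeholders for whichever encoder/decoder you exported:

    import onnxruntime as ort

    encoder = ort.InferenceSession("base.en-encoder.onnx")  # placeholder path
    decoder = ort.InferenceSession("base.en-decoder.onnx")  # placeholder path

    for name, sess in (("encoder", encoder), ("decoder", decoder)):
        for t in sess.get_inputs():
            print(name, "input ", t.name, t.shape)
        for t in sess.get_outputs():
            print(name, "output", t.name, t.shape)

    # n_text_layer, n_text_ctx, n_text_state, sot_sequence, etc., as read by InitEncoder()
    print(encoder.get_modelmeta().custom_metadata_map)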
| @@ -98,11 +98,15 @@ Usage: | @@ -98,11 +98,15 @@ Usage: | ||
| 98 | ./bin/sherpa-onnx-microphone-offline \ | 98 | ./bin/sherpa-onnx-microphone-offline \ |
| 99 | --tokens=/path/to/tokens.txt \ | 99 | --tokens=/path/to/tokens.txt \ |
| 100 | --paraformer=/path/to/model.onnx \ | 100 | --paraformer=/path/to/model.onnx \ |
| 101 | - --num-threads=2 \ | ||
| 102 | - --decoding-method=greedy_search | 101 | + --num-threads=1 |
| 103 | 102 | ||
| 104 | -Default value for num_threads is 2. | ||
| 105 | -Valid values for decoding_method: greedy_search. | 103 | +(3) Whisper models |
| 104 | + | ||
| 105 | + ./bin/sherpa-onnx-microphone-offline \ | ||
| 106 | + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ | ||
| 107 | + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ | ||
| 108 | + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ | ||
| 109 | + --num-threads=1 | ||
| 106 | 110 | ||
| 107 | Please refer to | 111 | Please refer to |
| 108 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | 112 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html |
| @@ -23,7 +23,7 @@ Usage: | @@ -23,7 +23,7 @@ Usage: | ||
| 23 | --encoder=/path/to/encoder.onnx \ | 23 | --encoder=/path/to/encoder.onnx \ |
| 24 | --decoder=/path/to/decoder.onnx \ | 24 | --decoder=/path/to/decoder.onnx \ |
| 25 | --joiner=/path/to/joiner.onnx \ | 25 | --joiner=/path/to/joiner.onnx \ |
| 26 | - --num-threads=2 \ | 26 | + --num-threads=1 \ |
| 27 | --decoding-method=greedy_search \ | 27 | --decoding-method=greedy_search \ |
| 28 | /path/to/foo.wav [bar.wav foobar.wav ...] | 28 | /path/to/foo.wav [bar.wav foobar.wav ...] |
| 29 | 29 | ||
| @@ -33,14 +33,22 @@ Usage: | @@ -33,14 +33,22 @@ Usage: | ||
| 33 | ./bin/sherpa-onnx-offline \ | 33 | ./bin/sherpa-onnx-offline \ |
| 34 | --tokens=/path/to/tokens.txt \ | 34 | --tokens=/path/to/tokens.txt \ |
| 35 | --paraformer=/path/to/model.onnx \ | 35 | --paraformer=/path/to/model.onnx \ |
| 36 | - --num-threads=2 \ | 36 | + --num-threads=1 \ |
| 37 | --decoding-method=greedy_search \ | 37 | --decoding-method=greedy_search \ |
| 38 | /path/to/foo.wav [bar.wav foobar.wav ...] | 38 | /path/to/foo.wav [bar.wav foobar.wav ...] |
| 39 | 39 | ||
| 40 | +(3) Whisper models | ||
| 41 | + | ||
| 42 | + ./bin/sherpa-onnx-offline \ | ||
| 43 | + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ | ||
| 44 | + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ | ||
| 45 | + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ | ||
| 46 | + --num-threads=1 \ | ||
| 47 | + /path/to/foo.wav [bar.wav foobar.wav ...] | ||
| 48 | + | ||
| 49 | + | ||
| 40 | Note: It supports decoding multiple files in batches | 50 | Note: It supports decoding multiple files in batches |
| 41 | 51 | ||
| 42 | -Default value for num_threads is 2. | ||
| 43 | -Valid values for decoding_method: greedy_search. | ||
| 44 | foo.wav should be of single channel, 16-bit PCM encoded wave file; its | 52 | foo.wav should be of single channel, 16-bit PCM encoded wave file; its |
| 45 | sampling rate can be arbitrary and does not need to be 16kHz. | 53 | sampling rate can be arbitrary and does not need to be 16kHz. |
| 46 | 54 | ||
| @@ -55,6 +63,7 @@ for a list of pre-trained models to download. | @@ -55,6 +63,7 @@ for a list of pre-trained models to download. | ||
| 55 | 63 | ||
| 56 | po.Read(argc, argv); | 64 | po.Read(argc, argv); |
| 57 | if (po.NumArgs() < 1) { | 65 | if (po.NumArgs() < 1) { |
| 66 | + fprintf(stderr, "Error: Please provide at least 1 wave file.\n\n"); | ||
| 58 | po.PrintUsage(); | 67 | po.PrintUsage(); |
| 59 | exit(EXIT_FAILURE); | 68 | exit(EXIT_FAILURE); |
| 60 | } | 69 | } |
| @@ -9,6 +9,7 @@ | @@ -9,6 +9,7 @@ | ||
| 9 | #include <sstream> | 9 | #include <sstream> |
| 10 | #include <strstream> | 10 | #include <strstream> |
| 11 | 11 | ||
| 12 | +#include "sherpa-onnx/csrc/base64-decode.h" | ||
| 12 | #include "sherpa-onnx/csrc/onnx-utils.h" | 13 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 13 | 14 | ||
| 14 | #if __ANDROID_API__ >= 9 | 15 | #if __ANDROID_API__ >= 9 |
| @@ -82,4 +83,12 @@ std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) { | @@ -82,4 +83,12 @@ std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) { | ||
| 82 | return os << symbol_table.ToString(); | 83 | return os << symbol_table.ToString(); |
| 83 | } | 84 | } |
| 84 | 85 | ||
| 86 | +void SymbolTable::ApplyBase64Decode() { | ||
| 87 | + sym2id_.clear(); | ||
| 88 | + for (auto &p : id2sym_) { | ||
| 89 | + p.second = Base64Decode(p.second); | ||
| 90 | + sym2id_[p.second] = p.first; | ||
| 91 | + } | ||
| 92 | +} | ||
| 93 | + | ||
| 85 | } // namespace sherpa_onnx | 94 | } // namespace sherpa_onnx |
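ApplyBase64Decode() is needed because the tokens.txt exported for Whisper stores every symbol base64-encoded (byte-level BPE pieces can contain spaces and non-printable bytes). A small illustration with a made-up entry:

    import base64

    line = "IGhlbGxv 123"             # hypothetical row from a Whisper tokens.txt
    sym, idx = line.split()
    print(base64.b64decode(sym))      # b' hello': the symbol mapped to id 123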
| @@ -45,6 +45,9 @@ class SymbolTable { | @@ -45,6 +45,9 @@ class SymbolTable { | ||
| 45 | /// Return true if there is a given symbol in the symbol table. | 45 | /// Return true if there is a given symbol in the symbol table. |
| 46 | bool contains(const std::string &sym) const; | 46 | bool contains(const std::string &sym) const; |
| 47 | 47 | ||
| 48 | + // for tokens.txt from Whisper | ||
| 49 | + void ApplyBase64Decode(); | ||
| 50 | + | ||
| 48 | private: | 51 | private: |
| 49 | void Init(std::istream &is); | 52 | void Init(std::istream &is); |
| 50 | 53 |
| @@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx | @@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 11 | offline-recognizer.cc | 11 | offline-recognizer.cc |
| 12 | offline-stream.cc | 12 | offline-stream.cc |
| 13 | offline-transducer-model-config.cc | 13 | offline-transducer-model-config.cc |
| 14 | + offline-whisper-model-config.cc | ||
| 14 | online-lm-config.cc | 15 | online-lm-config.cc |
| 15 | online-recognizer.cc | 16 | online-recognizer.cc |
| 16 | online-stream.cc | 17 | online-stream.cc |
| @@ -11,6 +11,7 @@ | @@ -11,6 +11,7 @@ | ||
| 11 | #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" | 11 | #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" |
| 12 | #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" | 12 | #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" |
| 13 | #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" | 13 | #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" |
| 14 | +#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" | ||
| 14 | 15 | ||
| 15 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 16 | 17 | ||
| @@ -18,22 +19,25 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -18,22 +19,25 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 18 | PybindOfflineTransducerModelConfig(m); | 19 | PybindOfflineTransducerModelConfig(m); |
| 19 | PybindOfflineParaformerModelConfig(m); | 20 | PybindOfflineParaformerModelConfig(m); |
| 20 | PybindOfflineNemoEncDecCtcModelConfig(m); | 21 | PybindOfflineNemoEncDecCtcModelConfig(m); |
| 22 | + PybindOfflineWhisperModelConfig(m); | ||
| 21 | 23 | ||
| 22 | using PyClass = OfflineModelConfig; | 24 | using PyClass = OfflineModelConfig; |
| 23 | py::class_<PyClass>(*m, "OfflineModelConfig") | 25 | py::class_<PyClass>(*m, "OfflineModelConfig") |
| 24 | - .def( | ||
| 25 | - py::init<const OfflineTransducerModelConfig &, | ||
| 26 | - const OfflineParaformerModelConfig &, | ||
| 27 | - const OfflineNemoEncDecCtcModelConfig &, const std::string &, | ||
| 28 | - int32_t, bool, const std::string &, const std::string &>(), | ||
| 29 | - py::arg("transducer") = OfflineTransducerModelConfig(), | ||
| 30 | - py::arg("paraformer") = OfflineParaformerModelConfig(), | ||
| 31 | - py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), | ||
| 32 | - py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, | ||
| 33 | - py::arg("provider") = "cpu", py::arg("model_type") = "") | 26 | + .def(py::init<const OfflineTransducerModelConfig &, |
| 27 | + const OfflineParaformerModelConfig &, | ||
| 28 | + const OfflineNemoEncDecCtcModelConfig &, | ||
| 29 | + const OfflineWhisperModelConfig &, const std::string &, | ||
| 30 | + int32_t, bool, const std::string &, const std::string &>(), | ||
| 31 | + py::arg("transducer") = OfflineTransducerModelConfig(), | ||
| 32 | + py::arg("paraformer") = OfflineParaformerModelConfig(), | ||
| 33 | + py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), | ||
| 34 | + py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tokens"), | ||
| 35 | + py::arg("num_threads"), py::arg("debug") = false, | ||
| 36 | + py::arg("provider") = "cpu", py::arg("model_type") = "") | ||
| 34 | .def_readwrite("transducer", &PyClass::transducer) | 37 | .def_readwrite("transducer", &PyClass::transducer) |
| 35 | .def_readwrite("paraformer", &PyClass::paraformer) | 38 | .def_readwrite("paraformer", &PyClass::paraformer) |
| 36 | .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) | 39 | .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) |
| 40 | + .def_readwrite("whisper", &PyClass::whisper) | ||
| 37 | .def_readwrite("tokens", &PyClass::tokens) | 41 | .def_readwrite("tokens", &PyClass::tokens) |
| 38 | .def_readwrite("num_threads", &PyClass::num_threads) | 42 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 39 | .def_readwrite("debug", &PyClass::debug) | 43 | .def_readwrite("debug", &PyClass::debug) |
| 1 | +// sherpa-onnx/python/csrc/offline-whisper-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-whisper-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void PybindOfflineWhisperModelConfig(py::module *m) { | ||
| 15 | + using PyClass = OfflineWhisperModelConfig; | ||
| 16 | + py::class_<PyClass>(*m, "OfflineWhisperModelConfig") | ||
| 17 | + .def(py::init<const std::string &, const std::string &>(), | ||
| 18 | + py::arg("encoder"), py::arg("decoder")) | ||
| 19 | + .def_readwrite("encoder", &PyClass::encoder) | ||
| 20 | + .def_readwrite("decoder", &PyClass::decoder) | ||
| 21 | + .def("__str__", &PyClass::ToString); | ||
| 22 | +} | ||
| 23 | + | ||
| 24 | +} // namespace sherpa_onnx |
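With PybindOfflineWhisperModelConfig() registered and the whisper field added to OfflineModelConfig above, the binding can be exercised directly from Python. A minimal sketch; the file paths are placeholders:

    from _sherpa_onnx import OfflineModelConfig, OfflineWhisperModelConfig

    cfg = OfflineModelConfig(
        whisper=OfflineWhisperModelConfig(encoder="enc.onnx", decoder="dec.onnx"),
        tokens="tokens.txt",
        num_threads=1,
        model_type="whisper",
    )
    print(cfg.whisper)  # OfflineWhisperModelConfig(encoder="enc.onnx", decoder="dec.onnx")

In practice the high-level OfflineRecognizer.from_whisper() wrapper below builds this config for you.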
| 1 | +// sherpa-onnx/python/csrc/offline-whisper-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOfflineWhisperModelConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +}  // namespace sherpa_onnx | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ |
| 1 | # Copyright (c) 2023 by manyeyes | 1 | # Copyright (c) 2023 by manyeyes |
| 2 | +# Copyright (c) 2023 Xiaomi Corporation | ||
| 2 | from pathlib import Path | 3 | from pathlib import Path |
| 3 | from typing import List, Optional | 4 | from typing import List, Optional |
| 4 | 5 | ||
| @@ -7,6 +8,7 @@ from _sherpa_onnx import ( | @@ -7,6 +8,7 @@ from _sherpa_onnx import ( | ||
| 7 | OfflineModelConfig, | 8 | OfflineModelConfig, |
| 8 | OfflineNemoEncDecCtcModelConfig, | 9 | OfflineNemoEncDecCtcModelConfig, |
| 9 | OfflineParaformerModelConfig, | 10 | OfflineParaformerModelConfig, |
| 11 | + OfflineWhisperModelConfig, | ||
| 10 | ) | 12 | ) |
| 11 | from _sherpa_onnx import OfflineRecognizer as _Recognizer | 13 | from _sherpa_onnx import OfflineRecognizer as _Recognizer |
| 12 | from _sherpa_onnx import ( | 14 | from _sherpa_onnx import ( |
| @@ -69,7 +71,7 @@ class OfflineRecognizer(object): | @@ -69,7 +71,7 @@ class OfflineRecognizer(object): | ||
| 69 | feature_dim: | 71 | feature_dim: |
| 70 | Dimension of the feature used to train the model. | 72 | Dimension of the feature used to train the model. |
| 71 | decoding_method: | 73 | decoding_method: |
| 72 | - Support only greedy_search for now. | 74 | + Valid values: greedy_search, modified_beam_search. |
| 73 | debug: | 75 | debug: |
| 74 | True to show debug messages. | 76 | True to show debug messages. |
| 75 | provider: | 77 | provider: |
| @@ -137,7 +139,7 @@ class OfflineRecognizer(object): | @@ -137,7 +139,7 @@ class OfflineRecognizer(object): | ||
| 137 | feature_dim: | 139 | feature_dim: |
| 138 | Dimension of the feature used to train the model. | 140 | Dimension of the feature used to train the model. |
| 139 | decoding_method: | 141 | decoding_method: |
| 140 | - Valid values are greedy_search, modified_beam_search. | 142 | + Valid values: greedy_search. |
| 141 | debug: | 143 | debug: |
| 142 | True to show debug messages. | 144 | True to show debug messages. |
| 143 | provider: | 145 | provider: |
| @@ -185,14 +187,14 @@ class OfflineRecognizer(object): | @@ -185,14 +187,14 @@ class OfflineRecognizer(object): | ||
| 185 | English, etc. | 187 | English, etc. |
| 186 | 188 | ||
| 187 | Args: | 189 | Args: |
| 190 | + model: | ||
| 191 | + Path to ``model.onnx``. | ||
| 188 | tokens: | 192 | tokens: |
| 189 | Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two | 193 | Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two |
| 190 | columns:: | 194 | columns:: |
| 191 | 195 | ||
| 192 | symbol integer_id | 196 | symbol integer_id |
| 193 | 197 | ||
| 194 | - model: | ||
| 195 | - Path to ``model.onnx``. | ||
| 196 | num_threads: | 198 | num_threads: |
| 197 | Number of threads for neural network computation. | 199 | Number of threads for neural network computation. |
| 198 | sample_rate: | 200 | sample_rate: |
| @@ -200,7 +202,7 @@ class OfflineRecognizer(object): | @@ -200,7 +202,7 @@ class OfflineRecognizer(object): | ||
| 200 | feature_dim: | 202 | feature_dim: |
| 201 | Dimension of the feature used to train the model. | 203 | Dimension of the feature used to train the model. |
| 202 | decoding_method: | 204 | decoding_method: |
| 203 | - Valid values are greedy_search, modified_beam_search. | 205 | + Valid values: greedy_search. |
| 204 | debug: | 206 | debug: |
| 205 | True to show debug messages. | 207 | True to show debug messages. |
| 206 | provider: | 208 | provider: |
| @@ -229,6 +231,68 @@ class OfflineRecognizer(object): | @@ -229,6 +231,68 @@ class OfflineRecognizer(object): | ||
| 229 | self.recognizer = _Recognizer(recognizer_config) | 231 | self.recognizer = _Recognizer(recognizer_config) |
| 230 | return self | 232 | return self |
| 231 | 233 | ||
| 234 | + @classmethod | ||
| 235 | + def from_whisper( | ||
| 236 | + cls, | ||
| 237 | + encoder: str, | ||
| 238 | + decoder: str, | ||
| 239 | + tokens: str, | ||
| 240 | + num_threads: int, | ||
| 241 | + decoding_method: str = "greedy_search", | ||
| 242 | + debug: bool = False, | ||
| 243 | + provider: str = "cpu", | ||
| 244 | + ): | ||
| 245 | + """ | ||
| 246 | + Please refer to | ||
| 247 | + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_ | ||
| 248 | + to download pre-trained models for different kinds of whisper models, | ||
| 249 | + e.g., tiny, tiny.en, base, base.en, etc. | ||
| 250 | + | ||
| 251 | + Args: | ||
| 252 | + encoder: | ||
| 253 | + Path to the encoder model, e.g., tiny-encoder.onnx, | ||
| 254 | + tiny-encoder.int8.onnx, tiny-encoder.ort, etc. | ||
| 255 | + decoder: | ||
| 256 | + Path to the decoder model, e.g., tiny-decoder.onnx, | ||
| 257 | + tiny-decoder.int8.onnx, tiny-decoder.ort, etc. | ||
| 258 | + tokens: | ||
| 259 | + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two | ||
| 260 | + columns:: | ||
| 261 | + | ||
| 262 | + symbol integer_id | ||
| 263 | + | ||
| 264 | + num_threads: | ||
| 265 | + Number of threads for neural network computation. | ||
| 266 | + decoding_method: | ||
| 267 | + Valid values: greedy_search. | ||
| 268 | + debug: | ||
| 269 | + True to show debug messages. | ||
| 270 | + provider: | ||
| 271 | + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | ||
| 272 | + """ | ||
| 273 | + self = cls.__new__(cls) | ||
| 274 | + model_config = OfflineModelConfig( | ||
| 275 | + whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder), | ||
| 276 | + tokens=tokens, | ||
| 277 | + num_threads=num_threads, | ||
| 278 | + debug=debug, | ||
| 279 | + provider=provider, | ||
| 280 | + model_type="whisper", | ||
| 281 | + ) | ||
| 282 | + | ||
| 283 | + feat_config = OfflineFeatureExtractorConfig( | ||
| 284 | + sampling_rate=16000, | ||
| 285 | + feature_dim=80, | ||
| 286 | + ) | ||
| 287 | + | ||
| 288 | + recognizer_config = OfflineRecognizerConfig( | ||
| 289 | + feat_config=feat_config, | ||
| 290 | + model_config=model_config, | ||
| 291 | + decoding_method=decoding_method, | ||
| 292 | + ) | ||
| 293 | + self.recognizer = _Recognizer(recognizer_config) | ||
| 294 | + return self | ||
| 295 | + | ||
| 232 | def create_stream(self, contexts_list: Optional[List[List[int]]] = None): | 296 | def create_stream(self, contexts_list: Optional[List[List[int]]] = None): |
| 233 | if contexts_list is None: | 297 | if contexts_list is None: |
| 234 | return self.recognizer.create_stream() | 298 | return self.recognizer.create_stream() |
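A minimal end-to-end sketch of the new from_whisper() constructor. The model paths match the CLI examples above; loading the wave via soundfile and the decode_stream()/result accessors follow the existing offline Python API and are assumptions of this sketch, not part of the diff:

    import soundfile as sf
    import sherpa_onnx

    recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
        encoder="./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx",
        decoder="./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx",
        tokens="./sherpa-onnx-whisper-base.en/base.en-tokens.txt",
        num_threads=1,
    )

    samples, sample_rate = sf.read("foo.wav", dtype="float32")  # single-channel wave
    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)
    recognizer.decode_stream(stream)
    print(stream.result.text)

Waves longer than 30 seconds are rejected by the whisper recognizer, matching the 3000-frame limit in DecodeStream().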