Fangjun Kuang
Committed by GitHub

Support multilingual whisper models (#274)

... ... @@ -36,6 +36,9 @@ jobs:
CIBW_ARCHS: "universal2"
CIBW_BUILD_VERBOSITY: 3
# Don't repair macOS wheels
CIBW_REPAIR_WHEEL_COMMAND_MACOS: ""
- name: Display wheels
shell: bash
run: |
... ...
... ... @@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest]
model: ["tiny.en", "base.en", "small.en", "medium.en"]
model: ["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
steps:
- uses: actions/checkout@v2
... ...
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.7.6")
set(SHERPA_ONNX_VERSION "1.7.7")
# Disable warning about
#
... ...
... ... @@ -3,7 +3,7 @@ module non-streaming-decode-files
go 1.12
require (
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
github.com/spf13/pflag v1.0.5
github.com/youpy/go-wav v0.3.2
)
... ...
... ... @@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
... ...
... ... @@ -4,6 +4,6 @@ go 1.12
require (
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
github.com/spf13/pflag v1.0.5
)
... ...
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
... ...
... ... @@ -3,7 +3,7 @@ module streaming-decode-files
go 1.12
require (
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
github.com/spf13/pflag v1.0.5
github.com/youpy/go-wav v0.3.2
)
... ...
... ... @@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
... ...
... ... @@ -11,10 +11,12 @@ fun main() {
// please refer to
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// to dowload pre-trained models
var modelConfig = OnlineTransducerModelConfig(
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
var modelConfig = OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
),
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
numThreads = 1,
debug = false,
... ... @@ -41,19 +43,19 @@ fun main() {
var objArray = WaveReader.readWaveFromFile(
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
)
var samples : FloatArray = objArray[0] as FloatArray
var sampleRate : Int = objArray[1] as Int
var samples: FloatArray = objArray[0] as FloatArray
var sampleRate: Int = objArray[1] as Int
model.acceptWaveform(samples, sampleRate=sampleRate)
model.acceptWaveform(samples, sampleRate = sampleRate)
while (model.isReady()) {
model.decode()
model.decode()
}
var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
model.acceptWaveform(tailPaddings, sampleRate=sampleRate)
model.acceptWaveform(tailPaddings, sampleRate = sampleRate)
model.inputFinished()
while (model.isReady()) {
model.decode()
model.decode()
}
println("results: ${model.text}")
... ...
... ... @@ -234,6 +234,28 @@ def add_whisper_model_args(parser: argparse.ArgumentParser):
help="Path to whisper decoder model",
)
parser.add_argument(
"--whisper-language",
default="",
type=str,
help="""It specifies the spoken language in the input audio file.
Example values: en, fr, de, zh, jp.
Available languages for multilingual models can be found at
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
If not specified, we infer the language from the input audio file.
""",
)
parser.add_argument(
"--whisper-task",
default="transcribe",
choices=["transcribe", "translate"],
type=str,
help="""For multilingual models, if you specify translate, the output
will be in English.
""",
)
def add_model_args(parser: argparse.ArgumentParser):
add_transducer_model_args(parser)
... ... @@ -813,6 +835,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
language=args.whisper_language,
task=args.whisper_task,
)
elif args.tdnn_model:
assert_file_exists(args.tdnn_model)
... ...
... ... @@ -53,6 +53,7 @@ python3 ./python-api-examples/offline-decode-files.py \
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
--whisper-task=transcribe \
--num-threads=1 \
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
... ... @@ -201,6 +202,28 @@ def get_args():
)
parser.add_argument(
"--whisper-language",
default="",
type=str,
help="""It specifies the spoken language in the input audio file.
Example values: en, fr, de, zh, jp.
Available languages for multilingual models can be found at
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
If not specified, we infer the language from the input audio file.
""",
)
parser.add_argument(
"--whisper-task",
default="transcribe",
choices=["transcribe", "translate"],
type=str,
help="""For multilingual models, if you specify translate, the output
will be in English.
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
... ... @@ -371,10 +394,10 @@ def main():
decoder=args.whisper_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
language=args.whisper_language,
task=args.whisper_task,
)
elif args.tdnn_model:
assert_file_exists(args.tdnn_model)
... ...
... ... @@ -11,6 +11,7 @@ for making the onnx export script public.
"""
import argparse
import os
from pathlib import Path
from typing import Any, Dict, Optional
... ... @@ -250,6 +251,7 @@ def main():
# write tokens
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
model.eval()
print(model.dims)
audio = torch.rand(16000 * 2)
... ... @@ -306,8 +308,12 @@ def main():
"n_text_head": model.dims.n_text_head,
"n_text_layer": model.dims.n_text_layer,
"sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
"all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))),
"all_language_codes": ",".join(tokenizer.all_language_codes),
"all_language_tokens": ",".join(
list(map(str, tokenizer.all_language_tokens))
), # a list of ids
"all_language_codes": ",".join(
tokenizer.all_language_codes
), # e.g., en, de, zh, fr
"sot": tokenizer.sot,
"sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
"eot": tokenizer.eot,
... ... @@ -413,6 +419,9 @@ def main():
},
)
if 'large' in args.model:
# it causes errors for large models, so skip it.
return
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
... ...
... ... @@ -39,6 +39,24 @@ def get_args():
)
parser.add_argument(
"--language",
type=str,
help="""The actual spoken language in the audio.
Example values, en, de, zh, jp, fr.
If None, we will detect the language using the first 30s of the
input audio
""",
)
parser.add_argument(
"--task",
choices=["transcribe", "translate"],
type=str,
default="transcribe",
help="Valid values are: transcribe, translate",
)
parser.add_argument(
"sound_file",
type=str,
help="Path to the test wave",
... ... @@ -74,12 +92,22 @@ class OnnxModel:
self.sot = int(meta["sot"])
self.eot = int(meta["eot"])
self.translate = int(meta["translate"])
self.transcribe = int(meta["transcribe"])
self.no_timestamps = int(meta["no_timestamps"])
self.no_speech = int(meta["no_speech"])
self.blank = int(meta["blank_id"])
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
self.sot_sequence.append(self.no_timestamps)
self.all_language_tokens = list(
map(int, meta["all_language_tokens"].split(","))
)
self.all_language_codes = meta["all_language_codes"].split(",")
self.lang2id = dict(zip(self.all_language_codes, self.all_language_tokens))
self.id2lang = dict(zip(self.all_language_tokens, self.all_language_codes))
self.is_multilingual = int(meta["is_multilingual"]) == 1
def init_decoder(self, decoder: str):
... ... @@ -164,6 +192,29 @@ class OnnxModel:
# logits is changed in-place
logits[self.translate] = float("-inf")
def detect_language(
self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor
) -> int:
tokens = torch.tensor([[self.sot]], dtype=torch.int64)
offset = torch.zeros(1, dtype=torch.int64)
n_layer_self_k_cache, n_layer_self_v_cache = self.get_self_cache()
logits, n_layer_self_k_cache, n_layer_self_v_cache = self.run_decoder(
tokens=tokens,
n_layer_self_k_cache=n_layer_self_k_cache,
n_layer_self_v_cache=n_layer_self_v_cache,
n_layer_cross_k=n_layer_cross_k,
n_layer_cross_v=n_layer_cross_v,
offset=offset,
)
logits = logits.reshape(-1)
mask = torch.ones(logits.shape[0], dtype=torch.int64)
mask[self.all_language_tokens] = 0
logits[mask] = float("-inf")
lang_id = logits.argmax().item()
print("detected language: ", self.id2lang[lang_id])
return lang_id
def load_tokens(filename):
tokens = dict()
... ... @@ -200,7 +251,35 @@ def main():
mel = mel.t().unsqueeze(0)
model = OnnxModel(encoder, decoder)
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
if args.language is not None:
if model.is_multilingual is False and args.language != "en":
print(f"This model supports only English. Given: {args.language}")
return
if args.language not in model.lang2id:
print(f"Invalid language: {args.language}")
print(f"Valid values are: {list(model.lang2id.keys())}")
return
# [sot, lang, task, notimestamps]
model.sot_sequence[1] = model.lang2id[args.language]
elif model.is_multilingual is True:
print("detecting language")
lang = model.detect_language(n_layer_cross_k, n_layer_cross_v)
model.sot_sequence[1] = lang
if args.task is not None:
if model.is_multilingual is False and args.task != "transcribe":
print("This model supports only English. Please use --task=transcribe")
return
assert args.task in ["transcribe", "translate"], args.task
if args.task == "translate":
model.sot_sequence[2] = model.translate
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
... ... @@ -213,6 +292,7 @@ def main():
n_layer_cross_v=n_layer_cross_v,
offset=offset,
)
offset += len(model.sot_sequence)
# logits.shape (batch_size, tokens.shape[1], vocab_size)
logits = logits[0, -1]
model.suppress_tokens(logits, is_initial=True)
... ... @@ -225,7 +305,6 @@ def main():
break
results.append(max_token_id.item())
tokens = torch.tensor([[results[-1]]])
offset += 1
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
tokens=tokens,
... ... @@ -235,6 +314,7 @@ def main():
n_layer_cross_v=n_layer_cross_v,
offset=offset,
)
offset += 1
logits = logits[0, -1]
model.suppress_tokens(logits, is_initial=False)
max_token_id = logits.argmax(dim=-1)
... ...
... ... @@ -37,7 +37,7 @@
} \
\
dst = atoi(value.get()); \
if (dst <= 0) { \
if (dst < 0) { \
SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \
exit(-1); \
} \
... ... @@ -77,6 +77,24 @@
} \
} while (0)
// read a vector of strings
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
exit(-1); \
} \
SplitStringToVector(value.get(), ",", false, &dst); \
\
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \
src_key); \
exit(-1); \
} \
} while (0)
// Read a string
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
do { \
... ...
... ... @@ -23,21 +23,227 @@
namespace sherpa_onnx {
static std::string FixInvalidUtf8(const std::string &s) {
int32_t s_size = s.size();
std::string ans;
ans.reserve(s_size);
for (int32_t i = 0; i < s_size;) {
uint8_t c = s[i];
if (c < 0x80) {
// valid
ans.append(1, c);
++i;
continue;
} else if ((c >= 0xc0) && (c < 0xe0)) {
// beginning of two bytes
if ((i + 1) > (s_size - 1)) {
// no subsequent byte. invalid!
i += 1;
continue;
}
uint8_t next = s[i + 1];
if (!(next >= 0x80 && next < 0xc0)) {
// invalid
i += 1;
continue;
}
// valid 2-byte utf-8
ans.append(1, c);
ans.append(1, next);
i += 2;
continue;
} else if ((c >= 0xe0) && (c < 0xf0)) {
// beginning of 3 bytes
if ((i + 2) > (s_size - 1)) {
// invalid
i += 1;
continue;
}
uint8_t next = s[i + 1];
if (!(next >= 0x80 && next < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next2 = s[i + 2];
if (!(next2 >= 0x80 && next2 < 0xc0)) {
// invalid
i += 1;
continue;
}
ans.append(1, c);
ans.append(1, next);
ans.append(1, next2);
i += 3;
continue;
} else if ((c >= 0xf0) && (c < 0xf8)) {
// 4 bytes
if ((i + 3) > (s_size - 1)) {
// invalid
i += 1;
continue;
}
uint8_t next = s[i + 1];
if (!(next >= 0x80 && next < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next2 = s[i + 2];
if (!(next2 >= 0x80 && next2 < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next3 = s[i + 3];
if (!(next3 >= 0x80 && next3 < 0xc0)) {
// invalid
i += 1;
continue;
}
ans.append(1, c);
ans.append(1, next);
ans.append(1, next2);
ans.append(1, next3);
i += 4;
continue;
} else if ((c >= 0xf8) && (c < 0xfc)) {
// 5 bytes
if ((i + 4) > (s_size - 1)) {
// invalid
i += 1;
continue;
}
uint8_t next = s[i + 1];
if (!(next >= 0x80 && next < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next2 = s[i + 2];
if (!(next2 >= 0x80 && next2 < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next3 = s[i + 3];
if (!(next3 >= 0x80 && next3 < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next4 = s[i + 4];
if (!(next4 >= 0x80 && next4 < 0xc0)) {
// invalid
i += 1;
continue;
}
ans.append(1, c);
ans.append(1, next);
ans.append(1, next2);
ans.append(1, next3);
ans.append(1, next4);
i += 5;
continue;
} else if ((c >= 0xfc) && (c < 0xfe)) {
// 6 bytes
if ((i + 5) > (s_size - 1)) {
// invalid
i += 1;
continue;
}
uint8_t next = s[i + 1];
if (!(next >= 0x80 && next < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next2 = s[i + 2];
if (!(next2 >= 0x80 && next2 < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next3 = s[i + 3];
if (!(next3 >= 0x80 && next3 < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next4 = s[i + 4];
if (!(next4 >= 0x80 && next4 < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next5 = s[i + 5];
if (!(next5 >= 0x80 && next5 < 0xc0)) {
// invalid
i += 1;
continue;
}
ans.append(1, c);
ans.append(1, next);
ans.append(1, next2);
ans.append(1, next3);
ans.append(1, next4);
ans.append(1, next5);
i += 6;
continue;
} else {
i += 1;
}
}
return ans;
}
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
const SymbolTable &sym_table) {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
std::string text;
for (auto i : src.tokens) {
if (!sym_table.contains(i)) {
continue;
}
const auto &s = sym_table[i];
r.text += s;
text += s;
r.tokens.push_back(s);
}
// TODO(fangjun): Fix the following error in offline-stream.cc
//
// j["text"] = text;
// libc++abi: terminating with uncaught exception of type
// nlohmann::json_abi_v3_11_2::detail::type_error:
// [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86
#if 0
r.text = FixInvalidUtf8(text);
#else
r.text = text;
#endif
return r;
}
... ... @@ -51,8 +257,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
symbol_table_.ApplyBase64Decode();
if (config.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
decoder_ = std::make_unique<OfflineWhisperGreedySearchDecoder>(
config_.model_config.whisper, model_.get());
} else {
SHERPA_ONNX_LOGE(
"Only greedy_search is supported at present for whisper. Given %s",
... ... @@ -101,6 +307,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
mel = Transpose12(model_->Allocator(), &mel);
auto cross_kv = model_->ForwardEncoder(std::move(mel));
auto results =
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
... ...
... ... @@ -7,17 +7,106 @@
#include <algorithm>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
Ort::Value &cross_k, Ort::Value &cross_v) const { // NOLINT
int64_t token_val = model_->SOT();
std::array<int64_t, 2> token_shape{1, 1};
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
Ort::Value tokens = Ort::Value::CreateTensor(
memory_info, &token_val, 1, token_shape.data(), token_shape.size());
auto self_kv_cache = model_->GetInitialSelfKVCache();
std::array<int64_t, 1> offset_shape{1};
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
model_->Allocator(), offset_shape.data(), offset_shape.size());
*(offset.GetTensorMutableData<int64_t>()) = 0;
auto decoder_out = model_->ForwardDecoder(
std::move(tokens), std::move(self_kv_cache.first),
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
std::move(offset));
cross_k = std::move(std::get<3>(decoder_out));
cross_v = std::move(std::get<4>(decoder_out));
const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
int32_t vocab_size = model_->VocabSize();
const auto &all_language_ids = model_->GetAllLanguageIDs();
int32_t lang_id = all_language_ids[0];
float this_logit = p_logits[lang_id];
for (int32_t i = 1; i != all_language_ids.size(); ++i) {
int32_t id = all_language_ids[i];
float p = p_logits[id];
if (p > this_logit) {
this_logit = p;
lang_id = id;
}
}
#if 1
SHERPA_ONNX_LOGE("Detected language: %s",
model_->GetID2Lang().at(lang_id).c_str());
#endif
return lang_id;
}
std::vector<OfflineWhisperDecoderResult>
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
Ort::Value cross_v) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
auto self_kv_cache = model_->GetInitialSelfKVCache();
// For multilingual models, initial_tokens contains [sot, language, task]
// - language is English by default
// - task is transcribe by default
//
// For non-multilingual models, initial_tokens contains [sot]
std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
if (model_->IsMultiLingual()) {
if (!config_.language.empty()) {
const auto &lang2id = model_->GetLang2ID();
if (!lang2id.count(config_.language)) {
SHERPA_ONNX_LOGE("Invalid language: %s", config_.language.c_str());
exit(-1);
}
int32_t lang_id = lang2id.at(config_.language);
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
initial_tokens[1] = lang_id;
} else {
int32_t lang_id = DetectLanguage(cross_k, cross_v);
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
initial_tokens[1] = lang_id;
}
if (config_.task == "translate") {
initial_tokens[2] = model_->Translate();
} else if (config_.task != "transcribe") {
// initial_tokens[2] is transcribe by default
SHERPA_ONNX_LOGE(
"Unsupported task: %s. Valid values are: transcribe, translate.",
config_.task.c_str());
}
}
initial_tokens.push_back(model_->NoTimeStampsToken());
int32_t batch_size = 1;
std::array<int64_t, 2> token_shape{
batch_size, static_cast<int64_t>(initial_tokens.size())};
... ... @@ -31,11 +120,16 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
model_->Allocator(), offset_shape.data(), offset_shape.size());
*(offset.GetTensorMutableData<int64_t>()) = 0;
auto self_kv_cache = model_->GetInitialSelfKVCache();
auto decoder_out = model_->ForwardDecoder(
std::move(tokens), std::move(self_kv_cache.first),
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
std::move(offset));
*(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) =
initial_tokens.size();
const auto &logits = std::get<0>(decoder_out);
const float *p_logits = logits.GetTensorData<float>();
... ... @@ -58,18 +152,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
std::array<int64_t, 2> token_shape{1, 1};
Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
model_->Allocator(), token_shape.data(), token_shape.size());
int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
p_tokens[0] = max_token_id;
int64_t *p_offset =
std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
if (i == 0) {
*p_offset = initial_tokens.size();
} else {
*p_offset += 1;
}
decoder_out = model_->ForwardDecoder(std::move(tokens),
std::move(std::get<1>(decoder_out)),
std::move(std::get<2>(decoder_out)),
... ... @@ -77,6 +163,11 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
std::move(std::get<4>(decoder_out)),
std::move(std::get<5>(decoder_out)));
int64_t *p_offset =
std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
*p_offset += 1;
const auto &logits = std::get<0>(decoder_out);
const float *p_logits = logits.GetTensorData<float>();
... ... @@ -85,6 +176,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
}
std::vector<OfflineWhisperDecoderResult> ans(1);
ans[0].tokens = std::move(predicted_tokens);
return ans;
... ...
... ... @@ -8,19 +8,25 @@
#include <vector>
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"
namespace sherpa_onnx {
class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
public:
explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
: model_(model) {}
OfflineWhisperGreedySearchDecoder(const OfflineWhisperModelConfig &config,
OfflineWhisperModel *model)
: config_(config), model_(model) {}
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
Ort::Value cross_v) override;
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
Ort::Value &cross_v) const; // NOLINT
private:
OfflineWhisperModelConfig config_;
OfflineWhisperModel *model_; // not owned
};
... ...
... ... @@ -17,6 +17,21 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
po->Register("whisper-decoder", &decoder,
"Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
"medium.en-decoder.onnx.");
po->Register(
"whisper-language", &language,
"The spoke language in the input audio file. Example values: "
"en, de, fr, zh, jp. If it is not given for a multilingual model, we will"
" infer the language from the input audio file. "
"Please refer to "
"https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10"
" for valid values. Note that for non-multilingual models, it supports "
"only 'en'");
po->Register("whisper-task", &task,
"Valid values: transcribe, translate. "
"Note that for non-multilingual models, it supports "
"only 'transcribe'");
}
bool OfflineWhisperModelConfig::Validate() const {
... ... @@ -30,6 +45,14 @@ bool OfflineWhisperModelConfig::Validate() const {
return false;
}
if (task != "translate" && task != "transcribe") {
SHERPA_ONNX_LOGE(
"--whisper-task supports only translate and transcribe. Given: %s",
task.c_str());
return false;
}
return true;
}
... ... @@ -38,7 +61,9 @@ std::string OfflineWhisperModelConfig::ToString() const {
os << "OfflineWhisperModelConfig(";
os << "encoder=\"" << encoder << "\", ";
os << "decoder=\"" << decoder << "\")";
os << "decoder=\"" << decoder << "\", ";
os << "language=\"" << language << "\", ";
os << "task=\"" << task << "\")";
return os.str();
}
... ...
... ... @@ -14,10 +14,26 @@ struct OfflineWhisperModelConfig {
std::string encoder;
std::string decoder;
// Available languages can be found at
// https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
//
// Note: For non-multilingual models, it supports only "en"
//
// If empty, we will infer it from the input audio file when
// the model is multilingual.
std::string language;
// Valid values are transcribe and translate
//
// Note: For non-multilingual models, it supports only "transcribe"
std::string task = "transcribe";
OfflineWhisperModelConfig() = default;
OfflineWhisperModelConfig(const std::string &encoder,
const std::string &decoder)
: encoder(encoder), decoder(decoder) {}
const std::string &decoder,
const std::string &language,
const std::string &task)
: encoder(encoder), decoder(decoder), language(language), task(task) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -7,6 +7,7 @@
#include <algorithm>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
... ... @@ -88,10 +89,32 @@ class OfflineWhisperModel::Impl {
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
const std::vector<int32_t> &GetAllLanguageIDs() const {
return all_language_tokens_;
}
const std::unordered_map<std::string, int32_t> &GetLang2ID() const {
return lang2id_;
}
const std::unordered_map<int32_t, std::string> &GetID2Lang() const {
return id2lang_;
}
int32_t NoTimeStampsToken() const { return no_timestamps_; }
int32_t EOT() const { return eot_; }
int32_t SOT() const { return sot_; }
int32_t TextCtx() const { return n_text_ctx_; }
int32_t VocabSize() const { return n_vocab_; }
int32_t Translate() const { return translate_; }
bool IsMultiLingual() const { return is_multilingual_; }
private:
void InitEncoder(void *model_data, size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(
... ... @@ -116,13 +139,35 @@ class OfflineWhisperModel::Impl {
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
SHERPA_ONNX_READ_META_DATA(n_vocab_, "n_vocab");
SHERPA_ONNX_READ_META_DATA(sot_, "sot");
SHERPA_ONNX_READ_META_DATA(eot_, "eot");
SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
SHERPA_ONNX_READ_META_DATA(translate_, "translate");
SHERPA_ONNX_READ_META_DATA(transcribe_, "transcribe");
SHERPA_ONNX_READ_META_DATA(is_multilingual_, "is_multilingual");
SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
if (is_multilingual_) {
SHERPA_ONNX_READ_META_DATA_VEC(all_language_tokens_,
"all_language_tokens");
SHERPA_ONNX_READ_META_DATA_VEC_STRING(all_language_codes_,
"all_language_codes");
if (all_language_tokens_.size() != all_language_codes_.size()) {
SHERPA_ONNX_LOGE("# lang_id: %d != # lang_code: %d",
static_cast<int32_t>(all_language_tokens_.size()),
static_cast<int32_t>(all_language_codes_.size()));
exit(-1);
}
for (int32_t i = 0;
i != static_cast<int32_t>(all_language_tokens_.size()); ++i) {
lang2id_[all_language_codes_[i]] = all_language_tokens_[i];
id2lang_[all_language_tokens_[i]] = all_language_codes_[i];
}
}
}
void InitDecoder(void *model_data, size_t model_data_length) {
... ... @@ -157,16 +202,24 @@ class OfflineWhisperModel::Impl {
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
std::vector<int32_t> all_language_tokens_;
std::vector<std::string> all_language_codes_;
std::unordered_map<std::string, int32_t> lang2id_;
std::unordered_map<int32_t, std::string> id2lang_;
// model meta data
int32_t n_text_layer_;
int32_t n_text_ctx_;
int32_t n_text_state_;
int32_t n_vocab_;
int32_t sot_;
int32_t eot_;
int32_t blank_;
int32_t translate_;
int32_t transcribe_;
int32_t no_timestamps_;
int32_t no_speech_;
int32_t is_multilingual_;
std::vector<int64_t> sot_sequence_;
};
... ... @@ -176,7 +229,7 @@ OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
OfflineWhisperModel::~OfflineWhisperModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
Ort::Value features) {
Ort::Value features) const {
return impl_->ForwardEncoder(std::move(features));
}
... ... @@ -187,14 +240,15 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
Ort::Value n_layer_self_v_cache,
Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v,
Ort::Value offset) {
Ort::Value offset) const {
return impl_->ForwardDecoder(
std::move(tokens), std::move(n_layer_self_k_cache),
std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
std::move(n_layer_cross_v), std::move(offset));
}
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() {
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
const {
return impl_->GetInitialSelfKVCache();
}
... ... @@ -206,8 +260,36 @@ const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
return impl_->GetInitialTokens();
}
const std::vector<int32_t> &OfflineWhisperModel::GetAllLanguageIDs() const {
return impl_->GetAllLanguageIDs();
}
const std::unordered_map<std::string, int32_t>
&OfflineWhisperModel::GetLang2ID() const {
return impl_->GetLang2ID();
}
const std::unordered_map<int32_t, std::string>
&OfflineWhisperModel::GetID2Lang() const {
return impl_->GetID2Lang();
}
int32_t OfflineWhisperModel::NoTimeStampsToken() const {
return impl_->NoTimeStampsToken();
}
int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); }
int32_t OfflineWhisperModel::SOT() const { return impl_->SOT(); }
int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); }
int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); }
bool OfflineWhisperModel::IsMultiLingual() const {
return impl_->IsMultiLingual();
}
} // namespace sherpa_onnx
... ...
... ... @@ -5,7 +5,9 @@
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
... ... @@ -30,7 +32,7 @@ class OfflineWhisperModel {
* - n_layer_cross_v: A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state)
*/
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) const;
/** Run the decoder model.
*
... ... @@ -58,7 +60,9 @@ class OfflineWhisperModel {
Ort::Value>
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v, Ort::Value offset);
Ort::Value n_layer_cross_v, Ort::Value offset) const;
int32_t DetectLanguage() const;
/** Return the initial self kv cache in a pair
* - n_layer_self_k_cache A 4-D tensor of shape
... ... @@ -66,14 +70,23 @@ class OfflineWhisperModel {
* - n_layer_self_v_cache A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
*/
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const;
const std::vector<int64_t> &GetInitialTokens() const;
const std::vector<int32_t> &GetAllLanguageIDs() const;
const std::unordered_map<std::string, int32_t> &GetLang2ID() const;
const std::unordered_map<int32_t, std::string> &GetID2Lang() const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
int32_t NoTimeStampsToken() const;
int32_t EOT() const;
int32_t SOT() const;
int32_t TextCtx() const;
int32_t VocabSize() const;
int32_t Translate() const;
bool IsMultiLingual() const;
private:
class Impl;
... ...
... ... @@ -14,10 +14,14 @@ namespace sherpa_onnx {
void PybindOfflineWhisperModelConfig(py::module *m) {
using PyClass = OfflineWhisperModelConfig;
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
.def(py::init<const std::string &, const std::string &>(),
py::arg("encoder"), py::arg("decoder"))
.def(py::init<const std::string &, const std::string &,
const std::string &, const std::string &>(),
py::arg("encoder"), py::arg("decoder"), py::arg("language"),
py::arg("task"))
.def_readwrite("encoder", &PyClass::encoder)
.def_readwrite("decoder", &PyClass::decoder)
.def_readwrite("language", &PyClass::language)
.def_readwrite("task", &PyClass::task)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -244,6 +244,8 @@ class OfflineRecognizer(object):
encoder: str,
decoder: str,
tokens: str,
language: str = "en",
task: str = "transcribe",
num_threads: int = 1,
decoding_method: str = "greedy_search",
debug: bool = False,
... ... @@ -268,6 +270,14 @@ class OfflineRecognizer(object):
symbol integer_id
language:
The spoken language in the audio file. Example values: en, de, zh,
jp, fr. See https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
for all possible values. Note that for non-multilingual models, the
only valid value is 'en'.
task:
Valid values are: transcribe, translate. Note that for
non-multilingual models, the only valid value is 'transcribe'.
num_threads:
Number of threads for neural network computation.
decoding_method:
... ... @@ -279,7 +289,12 @@ class OfflineRecognizer(object):
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder),
whisper=OfflineWhisperModelConfig(
encoder=encoder,
decoder=decoder,
language=language,
task=task,
),
tokens=tokens,
num_threads=num_threads,
debug=debug,
... ...