Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-08-16 00:28:52 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-08-16 00:28:52 +0800
Commit
f709c95c5f34985ca3c7d48282744fbd22a6e260
f709c95c
1 parent
496c5dd7
Support multilingual whisper models (#274)
隐藏空白字符变更
内嵌
并排对比
正在显示
24 个修改的文件
包含
692 行增加
和
73 行删除
.github/workflows/build-wheels-macos.yaml
.github/workflows/export-whisper-to-onnx.yaml
CMakeLists.txt
go-api-examples/non-streaming-decode-files/go.mod
go-api-examples/non-streaming-decode-files/go.sum
go-api-examples/real-time-speech-recognition-from-microphone/go.mod
go-api-examples/real-time-speech-recognition-from-microphone/go.sum
go-api-examples/streaming-decode-files/go.mod
go-api-examples/streaming-decode-files/go.sum
kotlin-api-examples/Main.kt
python-api-examples/non_streaming_server.py
python-api-examples/offline-decode-files.py
scripts/whisper/export-onnx.py
scripts/whisper/test.py
sherpa-onnx/csrc/macros.h
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
sherpa-onnx/csrc/offline-whisper-model-config.cc
sherpa-onnx/csrc/offline-whisper-model-config.h
sherpa-onnx/csrc/offline-whisper-model.cc
sherpa-onnx/csrc/offline-whisper-model.h
sherpa-onnx/python/csrc/offline-whisper-model-config.cc
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
.github/workflows/build-wheels-macos.yaml
查看文件 @
f709c95
...
...
@@ -36,6 +36,9 @@ jobs:
CIBW_ARCHS
:
"
universal2"
CIBW_BUILD_VERBOSITY
:
3
# Don't repair macOS wheels
CIBW_REPAIR_WHEEL_COMMAND_MACOS
:
"
"
-
name
:
Display wheels
shell
:
bash
run
:
|
...
...
.github/workflows/export-whisper-to-onnx.yaml
查看文件 @
f709c95
...
...
@@ -16,7 +16,7 @@ jobs:
fail-fast
:
false
matrix
:
os
:
[
macos-latest
]
model
:
[
"
tiny.en"
,
"
base.en"
,
"
small.en"
,
"
medium.en"
]
model
:
[
"
tiny.en"
,
"
base.en"
,
"
small.en"
,
"
medium.en"
,
"
tiny"
,
"
base"
,
"
small"
,
"
medium"
,
"
large"
,
"
large-v1"
,
"
large-v2"
]
steps
:
-
uses
:
actions/checkout@v2
...
...
CMakeLists.txt
查看文件 @
f709c95
cmake_minimum_required
(
VERSION 3.13 FATAL_ERROR
)
project
(
sherpa-onnx
)
set
(
SHERPA_ONNX_VERSION
"1.7.
6
"
)
set
(
SHERPA_ONNX_VERSION
"1.7.
7
"
)
# Disable warning about
#
...
...
go-api-examples/non-streaming-decode-files/go.mod
查看文件 @
f709c95
...
...
@@ -3,7 +3,7 @@ module non-streaming-decode-files
go 1.12
require (
github.com/k2-fsa/sherpa-onnx-go v1.
5.5
-alpha.1
github.com/k2-fsa/sherpa-onnx-go v1.
7.6
-alpha.1
github.com/spf13/pflag v1.0.5
github.com/youpy/go-wav v0.3.2
)
...
...
go-api-examples/non-streaming-decode-files/go.sum
查看文件 @
f709c95
...
...
@@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
...
...
go-api-examples/real-time-speech-recognition-from-microphone/go.mod
查看文件 @
f709c95
...
...
@@ -4,6 +4,6 @@ go 1.12
require (
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
github.com/k2-fsa/sherpa-onnx-go v1.
5.5
-alpha.1
github.com/k2-fsa/sherpa-onnx-go v1.
7.6
-alpha.1
github.com/spf13/pflag v1.0.5
)
...
...
go-api-examples/real-time-speech-recognition-from-microphone/go.sum
查看文件 @
f709c95
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
...
...
go-api-examples/streaming-decode-files/go.mod
查看文件 @
f709c95
...
...
@@ -3,7 +3,7 @@ module streaming-decode-files
go 1.12
require (
github.com/k2-fsa/sherpa-onnx-go v1.
5.5
-alpha.1
github.com/k2-fsa/sherpa-onnx-go v1.
7.6
-alpha.1
github.com/spf13/pflag v1.0.5
github.com/youpy/go-wav v0.3.2
)
...
...
go-api-examples/streaming-decode-files/go.sum
查看文件 @
f709c95
...
...
@@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
...
...
kotlin-api-examples/Main.kt
查看文件 @
f709c95
...
...
@@ -11,10 +11,12 @@ fun main() {
// please refer to
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// to dowload pre-trained models
var modelConfig = OnlineTransducerModelConfig(
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
var modelConfig = OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
),
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
numThreads = 1,
debug = false,
...
...
@@ -41,19 +43,19 @@ fun main() {
var objArray = WaveReader.readWaveFromFile(
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
)
var samples : FloatArray = objArray[0] as FloatArray
var sampleRate : Int = objArray[1] as Int
var samples: FloatArray = objArray[0] as FloatArray
var sampleRate: Int = objArray[1] as Int
model.acceptWaveform(samples, sampleRate
=
sampleRate)
model.acceptWaveform(samples, sampleRate
=
sampleRate)
while (model.isReady()) {
model.decode()
model.decode()
}
var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
model.acceptWaveform(tailPaddings, sampleRate
=
sampleRate)
model.acceptWaveform(tailPaddings, sampleRate
=
sampleRate)
model.inputFinished()
while (model.isReady()) {
model.decode()
model.decode()
}
println("results: ${model.text}")
...
...
python-api-examples/non_streaming_server.py
查看文件 @
f709c95
...
...
@@ -234,6 +234,28 @@ def add_whisper_model_args(parser: argparse.ArgumentParser):
help
=
"Path to whisper decoder model"
,
)
parser
.
add_argument
(
"--whisper-language"
,
default
=
""
,
type
=
str
,
help
=
"""It specifies the spoken language in the input audio file.
Example values: en, fr, de, zh, jp.
Available languages for multilingual models can be found at
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
If not specified, we infer the language from the input audio file.
"""
,
)
parser
.
add_argument
(
"--whisper-task"
,
default
=
"transcribe"
,
choices
=
[
"transcribe"
,
"translate"
],
type
=
str
,
help
=
"""For multilingual models, if you specify translate, the output
will be in English.
"""
,
)
def
add_model_args
(
parser
:
argparse
.
ArgumentParser
):
add_transducer_model_args
(
parser
)
...
...
@@ -813,6 +835,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
tokens
=
args
.
tokens
,
num_threads
=
args
.
num_threads
,
decoding_method
=
args
.
decoding_method
,
language
=
args
.
whisper_language
,
task
=
args
.
whisper_task
,
)
elif
args
.
tdnn_model
:
assert_file_exists
(
args
.
tdnn_model
)
...
...
python-api-examples/offline-decode-files.py
查看文件 @
f709c95
...
...
@@ -53,6 +53,7 @@ python3 ./python-api-examples/offline-decode-files.py \
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx
\
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx
\
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt
\
--whisper-task=transcribe
\
--num-threads=1
\
./sherpa-onnx-whisper-base.en/test_wavs/0.wav
\
./sherpa-onnx-whisper-base.en/test_wavs/1.wav
\
...
...
@@ -201,6 +202,28 @@ def get_args():
)
parser
.
add_argument
(
"--whisper-language"
,
default
=
""
,
type
=
str
,
help
=
"""It specifies the spoken language in the input audio file.
Example values: en, fr, de, zh, jp.
Available languages for multilingual models can be found at
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
If not specified, we infer the language from the input audio file.
"""
,
)
parser
.
add_argument
(
"--whisper-task"
,
default
=
"transcribe"
,
choices
=
[
"transcribe"
,
"translate"
],
type
=
str
,
help
=
"""For multilingual models, if you specify translate, the output
will be in English.
"""
,
)
parser
.
add_argument
(
"--decoding-method"
,
type
=
str
,
default
=
"greedy_search"
,
...
...
@@ -371,10 +394,10 @@ def main():
decoder
=
args
.
whisper_decoder
,
tokens
=
args
.
tokens
,
num_threads
=
args
.
num_threads
,
sample_rate
=
args
.
sample_rate
,
feature_dim
=
args
.
feature_dim
,
decoding_method
=
args
.
decoding_method
,
debug
=
args
.
debug
,
language
=
args
.
whisper_language
,
task
=
args
.
whisper_task
,
)
elif
args
.
tdnn_model
:
assert_file_exists
(
args
.
tdnn_model
)
...
...
scripts/whisper/export-onnx.py
查看文件 @
f709c95
...
...
@@ -11,6 +11,7 @@ for making the onnx export script public.
"""
import
argparse
import
os
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Optional
...
...
@@ -250,6 +251,7 @@ def main():
# write tokens
tokenizer
=
whisper
.
tokenizer
.
get_tokenizer
(
model
.
is_multilingual
)
model
.
eval
()
print
(
model
.
dims
)
audio
=
torch
.
rand
(
16000
*
2
)
...
...
@@ -306,8 +308,12 @@ def main():
"n_text_head"
:
model
.
dims
.
n_text_head
,
"n_text_layer"
:
model
.
dims
.
n_text_layer
,
"sot_sequence"
:
","
.
join
(
list
(
map
(
str
,
tokenizer
.
sot_sequence
))),
"all_language_tokens"
:
","
.
join
(
list
(
map
(
str
,
tokenizer
.
all_language_tokens
))),
"all_language_codes"
:
","
.
join
(
tokenizer
.
all_language_codes
),
"all_language_tokens"
:
","
.
join
(
list
(
map
(
str
,
tokenizer
.
all_language_tokens
))
),
# a list of ids
"all_language_codes"
:
","
.
join
(
tokenizer
.
all_language_codes
),
# e.g., en, de, zh, fr
"sot"
:
tokenizer
.
sot
,
"sot_index"
:
tokenizer
.
sot_sequence
.
index
(
tokenizer
.
sot
),
"eot"
:
tokenizer
.
eot
,
...
...
@@ -413,6 +419,9 @@ def main():
},
)
if
'large'
in
args
.
model
:
# it causes errors for large models, so skip it.
return
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
...
...
scripts/whisper/test.py
查看文件 @
f709c95
...
...
@@ -39,6 +39,24 @@ def get_args():
)
parser
.
add_argument
(
"--language"
,
type
=
str
,
help
=
"""The actual spoken language in the audio.
Example values, en, de, zh, jp, fr.
If None, we will detect the language using the first 30s of the
input audio
"""
,
)
parser
.
add_argument
(
"--task"
,
choices
=
[
"transcribe"
,
"translate"
],
type
=
str
,
default
=
"transcribe"
,
help
=
"Valid values are: transcribe, translate"
,
)
parser
.
add_argument
(
"sound_file"
,
type
=
str
,
help
=
"Path to the test wave"
,
...
...
@@ -74,12 +92,22 @@ class OnnxModel:
self
.
sot
=
int
(
meta
[
"sot"
])
self
.
eot
=
int
(
meta
[
"eot"
])
self
.
translate
=
int
(
meta
[
"translate"
])
self
.
transcribe
=
int
(
meta
[
"transcribe"
])
self
.
no_timestamps
=
int
(
meta
[
"no_timestamps"
])
self
.
no_speech
=
int
(
meta
[
"no_speech"
])
self
.
blank
=
int
(
meta
[
"blank_id"
])
self
.
sot_sequence
=
list
(
map
(
int
,
meta
[
"sot_sequence"
]
.
split
(
","
)))
self
.
sot_sequence
.
append
(
self
.
no_timestamps
)
self
.
all_language_tokens
=
list
(
map
(
int
,
meta
[
"all_language_tokens"
]
.
split
(
","
))
)
self
.
all_language_codes
=
meta
[
"all_language_codes"
]
.
split
(
","
)
self
.
lang2id
=
dict
(
zip
(
self
.
all_language_codes
,
self
.
all_language_tokens
))
self
.
id2lang
=
dict
(
zip
(
self
.
all_language_tokens
,
self
.
all_language_codes
))
self
.
is_multilingual
=
int
(
meta
[
"is_multilingual"
])
==
1
def
init_decoder
(
self
,
decoder
:
str
):
...
...
@@ -164,6 +192,29 @@ class OnnxModel:
# logits is changed in-place
logits
[
self
.
translate
]
=
float
(
"-inf"
)
def
detect_language
(
self
,
n_layer_cross_k
:
torch
.
Tensor
,
n_layer_cross_v
:
torch
.
Tensor
)
->
int
:
tokens
=
torch
.
tensor
([[
self
.
sot
]],
dtype
=
torch
.
int64
)
offset
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int64
)
n_layer_self_k_cache
,
n_layer_self_v_cache
=
self
.
get_self_cache
()
logits
,
n_layer_self_k_cache
,
n_layer_self_v_cache
=
self
.
run_decoder
(
tokens
=
tokens
,
n_layer_self_k_cache
=
n_layer_self_k_cache
,
n_layer_self_v_cache
=
n_layer_self_v_cache
,
n_layer_cross_k
=
n_layer_cross_k
,
n_layer_cross_v
=
n_layer_cross_v
,
offset
=
offset
,
)
logits
=
logits
.
reshape
(
-
1
)
mask
=
torch
.
ones
(
logits
.
shape
[
0
],
dtype
=
torch
.
int64
)
mask
[
self
.
all_language_tokens
]
=
0
logits
[
mask
]
=
float
(
"-inf"
)
lang_id
=
logits
.
argmax
()
.
item
()
print
(
"detected language: "
,
self
.
id2lang
[
lang_id
])
return
lang_id
def
load_tokens
(
filename
):
tokens
=
dict
()
...
...
@@ -200,7 +251,35 @@ def main():
mel
=
mel
.
t
()
.
unsqueeze
(
0
)
model
=
OnnxModel
(
encoder
,
decoder
)
n_layer_cross_k
,
n_layer_cross_v
=
model
.
run_encoder
(
mel
)
if
args
.
language
is
not
None
:
if
model
.
is_multilingual
is
False
and
args
.
language
!=
"en"
:
print
(
f
"This model supports only English. Given: {args.language}"
)
return
if
args
.
language
not
in
model
.
lang2id
:
print
(
f
"Invalid language: {args.language}"
)
print
(
f
"Valid values are: {list(model.lang2id.keys())}"
)
return
# [sot, lang, task, notimestamps]
model
.
sot_sequence
[
1
]
=
model
.
lang2id
[
args
.
language
]
elif
model
.
is_multilingual
is
True
:
print
(
"detecting language"
)
lang
=
model
.
detect_language
(
n_layer_cross_k
,
n_layer_cross_v
)
model
.
sot_sequence
[
1
]
=
lang
if
args
.
task
is
not
None
:
if
model
.
is_multilingual
is
False
and
args
.
task
!=
"transcribe"
:
print
(
"This model supports only English. Please use --task=transcribe"
)
return
assert
args
.
task
in
[
"transcribe"
,
"translate"
],
args
.
task
if
args
.
task
==
"translate"
:
model
.
sot_sequence
[
2
]
=
model
.
translate
n_layer_self_k_cache
,
n_layer_self_v_cache
=
model
.
get_self_cache
()
tokens
=
torch
.
tensor
([
model
.
sot_sequence
],
dtype
=
torch
.
int64
)
...
...
@@ -213,6 +292,7 @@ def main():
n_layer_cross_v
=
n_layer_cross_v
,
offset
=
offset
,
)
offset
+=
len
(
model
.
sot_sequence
)
# logits.shape (batch_size, tokens.shape[1], vocab_size)
logits
=
logits
[
0
,
-
1
]
model
.
suppress_tokens
(
logits
,
is_initial
=
True
)
...
...
@@ -225,7 +305,6 @@ def main():
break
results
.
append
(
max_token_id
.
item
())
tokens
=
torch
.
tensor
([[
results
[
-
1
]]])
offset
+=
1
logits
,
n_layer_self_k_cache
,
n_layer_self_v_cache
=
model
.
run_decoder
(
tokens
=
tokens
,
...
...
@@ -235,6 +314,7 @@ def main():
n_layer_cross_v
=
n_layer_cross_v
,
offset
=
offset
,
)
offset
+=
1
logits
=
logits
[
0
,
-
1
]
model
.
suppress_tokens
(
logits
,
is_initial
=
False
)
max_token_id
=
logits
.
argmax
(
dim
=-
1
)
...
...
sherpa-onnx/csrc/macros.h
查看文件 @
f709c95
...
...
@@ -37,7 +37,7 @@
} \
\
dst = atoi(value.get()); \
if (dst <
= 0) {
\
if (dst <
0) {
\
SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \
exit(-1); \
} \
...
...
@@ -77,6 +77,24 @@
} \
} while (0)
// read a vector of strings
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
exit(-1); \
} \
SplitStringToVector(value.get(), ",", false, &dst); \
\
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \
src_key); \
exit(-1); \
} \
} while (0)
// Read a string
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
do { \
...
...
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
查看文件 @
f709c95
...
...
@@ -23,21 +23,227 @@
namespace
sherpa_onnx
{
static
std
::
string
FixInvalidUtf8
(
const
std
::
string
&
s
)
{
int32_t
s_size
=
s
.
size
();
std
::
string
ans
;
ans
.
reserve
(
s_size
);
for
(
int32_t
i
=
0
;
i
<
s_size
;)
{
uint8_t
c
=
s
[
i
];
if
(
c
<
0x80
)
{
// valid
ans
.
append
(
1
,
c
);
++
i
;
continue
;
}
else
if
((
c
>=
0xc0
)
&&
(
c
<
0xe0
))
{
// beginning of two bytes
if
((
i
+
1
)
>
(
s_size
-
1
))
{
// no subsequent byte. invalid!
i
+=
1
;
continue
;
}
uint8_t
next
=
s
[
i
+
1
];
if
(
!
(
next
>=
0x80
&&
next
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
// valid 2-byte utf-8
ans
.
append
(
1
,
c
);
ans
.
append
(
1
,
next
);
i
+=
2
;
continue
;
}
else
if
((
c
>=
0xe0
)
&&
(
c
<
0xf0
))
{
// beginning of 3 bytes
if
((
i
+
2
)
>
(
s_size
-
1
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next
=
s
[
i
+
1
];
if
(
!
(
next
>=
0x80
&&
next
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next2
=
s
[
i
+
2
];
if
(
!
(
next2
>=
0x80
&&
next2
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
ans
.
append
(
1
,
c
);
ans
.
append
(
1
,
next
);
ans
.
append
(
1
,
next2
);
i
+=
3
;
continue
;
}
else
if
((
c
>=
0xf0
)
&&
(
c
<
0xf8
))
{
// 4 bytes
if
((
i
+
3
)
>
(
s_size
-
1
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next
=
s
[
i
+
1
];
if
(
!
(
next
>=
0x80
&&
next
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next2
=
s
[
i
+
2
];
if
(
!
(
next2
>=
0x80
&&
next2
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next3
=
s
[
i
+
3
];
if
(
!
(
next3
>=
0x80
&&
next3
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
ans
.
append
(
1
,
c
);
ans
.
append
(
1
,
next
);
ans
.
append
(
1
,
next2
);
ans
.
append
(
1
,
next3
);
i
+=
4
;
continue
;
}
else
if
((
c
>=
0xf8
)
&&
(
c
<
0xfc
))
{
// 5 bytes
if
((
i
+
4
)
>
(
s_size
-
1
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next
=
s
[
i
+
1
];
if
(
!
(
next
>=
0x80
&&
next
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next2
=
s
[
i
+
2
];
if
(
!
(
next2
>=
0x80
&&
next2
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next3
=
s
[
i
+
3
];
if
(
!
(
next3
>=
0x80
&&
next3
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next4
=
s
[
i
+
4
];
if
(
!
(
next4
>=
0x80
&&
next4
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
ans
.
append
(
1
,
c
);
ans
.
append
(
1
,
next
);
ans
.
append
(
1
,
next2
);
ans
.
append
(
1
,
next3
);
ans
.
append
(
1
,
next4
);
i
+=
5
;
continue
;
}
else
if
((
c
>=
0xfc
)
&&
(
c
<
0xfe
))
{
// 6 bytes
if
((
i
+
5
)
>
(
s_size
-
1
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next
=
s
[
i
+
1
];
if
(
!
(
next
>=
0x80
&&
next
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next2
=
s
[
i
+
2
];
if
(
!
(
next2
>=
0x80
&&
next2
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next3
=
s
[
i
+
3
];
if
(
!
(
next3
>=
0x80
&&
next3
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next4
=
s
[
i
+
4
];
if
(
!
(
next4
>=
0x80
&&
next4
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
uint8_t
next5
=
s
[
i
+
5
];
if
(
!
(
next5
>=
0x80
&&
next5
<
0xc0
))
{
// invalid
i
+=
1
;
continue
;
}
ans
.
append
(
1
,
c
);
ans
.
append
(
1
,
next
);
ans
.
append
(
1
,
next2
);
ans
.
append
(
1
,
next3
);
ans
.
append
(
1
,
next4
);
ans
.
append
(
1
,
next5
);
i
+=
6
;
continue
;
}
else
{
i
+=
1
;
}
}
return
ans
;
}
static
OfflineRecognitionResult
Convert
(
const
OfflineWhisperDecoderResult
&
src
,
const
SymbolTable
&
sym_table
)
{
OfflineRecognitionResult
r
;
r
.
tokens
.
reserve
(
src
.
tokens
.
size
());
std
::
string
text
;
for
(
auto
i
:
src
.
tokens
)
{
if
(
!
sym_table
.
contains
(
i
))
{
continue
;
}
const
auto
&
s
=
sym_table
[
i
];
r
.
text
+=
s
;
text
+=
s
;
r
.
tokens
.
push_back
(
s
);
}
// TODO(fangjun): Fix the following error in offline-stream.cc
//
// j["text"] = text;
// libc++abi: terminating with uncaught exception of type
// nlohmann::json_abi_v3_11_2::detail::type_error:
// [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86
#if 0
r.text = FixInvalidUtf8(text);
#else
r
.
text
=
text
;
#endif
return
r
;
}
...
...
@@ -51,8 +257,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
symbol_table_
.
ApplyBase64Decode
();
if
(
config
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OfflineWhisperGreedySearchDecoder
>
(
model_
.
get
());
decoder_
=
std
::
make_unique
<
OfflineWhisperGreedySearchDecoder
>
(
config_
.
model_config
.
whisper
,
model_
.
get
());
}
else
{
SHERPA_ONNX_LOGE
(
"Only greedy_search is supported at present for whisper. Given %s"
,
...
...
@@ -101,6 +307,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
mel
=
Transpose12
(
model_
->
Allocator
(),
&
mel
);
auto
cross_kv
=
model_
->
ForwardEncoder
(
std
::
move
(
mel
));
auto
results
=
decoder_
->
Decode
(
std
::
move
(
cross_kv
.
first
),
std
::
move
(
cross_kv
.
second
));
...
...
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
查看文件 @
f709c95
...
...
@@ -7,17 +7,106 @@
#include <algorithm>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace
sherpa_onnx
{
int32_t
OfflineWhisperGreedySearchDecoder
::
DetectLanguage
(
Ort
::
Value
&
cross_k
,
Ort
::
Value
&
cross_v
)
const
{
// NOLINT
int64_t
token_val
=
model_
->
SOT
();
std
::
array
<
int64_t
,
2
>
token_shape
{
1
,
1
};
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
Ort
::
Value
tokens
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
&
token_val
,
1
,
token_shape
.
data
(),
token_shape
.
size
());
auto
self_kv_cache
=
model_
->
GetInitialSelfKVCache
();
std
::
array
<
int64_t
,
1
>
offset_shape
{
1
};
Ort
::
Value
offset
=
Ort
::
Value
::
CreateTensor
<
int64_t
>
(
model_
->
Allocator
(),
offset_shape
.
data
(),
offset_shape
.
size
());
*
(
offset
.
GetTensorMutableData
<
int64_t
>
())
=
0
;
auto
decoder_out
=
model_
->
ForwardDecoder
(
std
::
move
(
tokens
),
std
::
move
(
self_kv_cache
.
first
),
std
::
move
(
self_kv_cache
.
second
),
std
::
move
(
cross_k
),
std
::
move
(
cross_v
),
std
::
move
(
offset
));
cross_k
=
std
::
move
(
std
::
get
<
3
>
(
decoder_out
));
cross_v
=
std
::
move
(
std
::
get
<
4
>
(
decoder_out
));
const
float
*
p_logits
=
std
::
get
<
0
>
(
decoder_out
).
GetTensorData
<
float
>
();
int32_t
vocab_size
=
model_
->
VocabSize
();
const
auto
&
all_language_ids
=
model_
->
GetAllLanguageIDs
();
int32_t
lang_id
=
all_language_ids
[
0
];
float
this_logit
=
p_logits
[
lang_id
];
for
(
int32_t
i
=
1
;
i
!=
all_language_ids
.
size
();
++
i
)
{
int32_t
id
=
all_language_ids
[
i
];
float
p
=
p_logits
[
id
];
if
(
p
>
this_logit
)
{
this_logit
=
p
;
lang_id
=
id
;
}
}
#if 1
SHERPA_ONNX_LOGE
(
"Detected language: %s"
,
model_
->
GetID2Lang
().
at
(
lang_id
).
c_str
());
#endif
return
lang_id
;
}
std
::
vector
<
OfflineWhisperDecoderResult
>
OfflineWhisperGreedySearchDecoder
::
Decode
(
Ort
::
Value
cross_k
,
Ort
::
Value
cross_v
)
{
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
auto
self_kv_cache
=
model_
->
GetInitialSelfKVCache
();
// For multilingual models, initial_tokens contains [sot, language, task]
// - language is English by default
// - task is transcribe by default
//
// For non-multilingual models, initial_tokens contains [sot]
std
::
vector
<
int64_t
>
initial_tokens
=
model_
->
GetInitialTokens
();
if
(
model_
->
IsMultiLingual
())
{
if
(
!
config_
.
language
.
empty
())
{
const
auto
&
lang2id
=
model_
->
GetLang2ID
();
if
(
!
lang2id
.
count
(
config_
.
language
))
{
SHERPA_ONNX_LOGE
(
"Invalid language: %s"
,
config_
.
language
.
c_str
());
exit
(
-
1
);
}
int32_t
lang_id
=
lang2id
.
at
(
config_
.
language
);
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
initial_tokens
[
1
]
=
lang_id
;
}
else
{
int32_t
lang_id
=
DetectLanguage
(
cross_k
,
cross_v
);
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
initial_tokens
[
1
]
=
lang_id
;
}
if
(
config_
.
task
==
"translate"
)
{
initial_tokens
[
2
]
=
model_
->
Translate
();
}
else
if
(
config_
.
task
!=
"transcribe"
)
{
// initial_tokens[2] is transcribe by default
SHERPA_ONNX_LOGE
(
"Unsupported task: %s. Valid values are: transcribe, translate."
,
config_
.
task
.
c_str
());
}
}
initial_tokens
.
push_back
(
model_
->
NoTimeStampsToken
());
int32_t
batch_size
=
1
;
std
::
array
<
int64_t
,
2
>
token_shape
{
batch_size
,
static_cast
<
int64_t
>
(
initial_tokens
.
size
())};
...
...
@@ -31,11 +120,16 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
model_
->
Allocator
(),
offset_shape
.
data
(),
offset_shape
.
size
());
*
(
offset
.
GetTensorMutableData
<
int64_t
>
())
=
0
;
auto
self_kv_cache
=
model_
->
GetInitialSelfKVCache
();
auto
decoder_out
=
model_
->
ForwardDecoder
(
std
::
move
(
tokens
),
std
::
move
(
self_kv_cache
.
first
),
std
::
move
(
self_kv_cache
.
second
),
std
::
move
(
cross_k
),
std
::
move
(
cross_v
),
std
::
move
(
offset
));
*
(
std
::
get
<
5
>
(
decoder_out
).
GetTensorMutableData
<
int64_t
>
())
=
initial_tokens
.
size
();
const
auto
&
logits
=
std
::
get
<
0
>
(
decoder_out
);
const
float
*
p_logits
=
logits
.
GetTensorData
<
float
>
();
...
...
@@ -58,18 +152,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
std
::
array
<
int64_t
,
2
>
token_shape
{
1
,
1
};
Ort
::
Value
tokens
=
Ort
::
Value
::
CreateTensor
<
int64_t
>
(
model_
->
Allocator
(),
token_shape
.
data
(),
token_shape
.
size
());
int64_t
*
p_tokens
=
tokens
.
GetTensorMutableData
<
int64_t
>
();
p_tokens
[
0
]
=
max_token_id
;
int64_t
*
p_offset
=
std
::
get
<
5
>
(
decoder_out
).
GetTensorMutableData
<
int64_t
>
();
if
(
i
==
0
)
{
*
p_offset
=
initial_tokens
.
size
();
}
else
{
*
p_offset
+=
1
;
}
decoder_out
=
model_
->
ForwardDecoder
(
std
::
move
(
tokens
),
std
::
move
(
std
::
get
<
1
>
(
decoder_out
)),
std
::
move
(
std
::
get
<
2
>
(
decoder_out
)),
...
...
@@ -77,6 +163,11 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
std
::
move
(
std
::
get
<
4
>
(
decoder_out
)),
std
::
move
(
std
::
get
<
5
>
(
decoder_out
)));
int64_t
*
p_offset
=
std
::
get
<
5
>
(
decoder_out
).
GetTensorMutableData
<
int64_t
>
();
*
p_offset
+=
1
;
const
auto
&
logits
=
std
::
get
<
0
>
(
decoder_out
);
const
float
*
p_logits
=
logits
.
GetTensorData
<
float
>
();
...
...
@@ -85,6 +176,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
}
std
::
vector
<
OfflineWhisperDecoderResult
>
ans
(
1
);
ans
[
0
].
tokens
=
std
::
move
(
predicted_tokens
);
return
ans
;
...
...
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
查看文件 @
f709c95
...
...
@@ -8,19 +8,25 @@
#include <vector>
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"
namespace
sherpa_onnx
{
// Greedy-search decoder for the offline (non-streaming) Whisper recognizer.
// It implements the OfflineWhisperDecoder interface; the decoding logic
// lives in the corresponding .cc file.
class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
 public:
  // Construct without a config; config_ is default-constructed.
  explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
      : model_(model) {}

  // Construct with a whisper model config (e.g., language/task options).
  OfflineWhisperGreedySearchDecoder(const OfflineWhisperModelConfig &config,
                                    OfflineWhisperModel *model)
      : config_(config), model_(model) {}

  // Decode one utterance from the encoder's cross-attention key/value
  // tensors. Returns the decoding result(s).
  std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
                                                  Ort::Value cross_v) override;

  // Determine the language token from the encoder outputs.
  // NOTE(review): exact return semantics (language token id vs. index) are
  // defined in the .cc file — confirm there.
  int32_t DetectLanguage(Ort::Value &cross_k,   // NOLINT
                         Ort::Value &cross_v) const;  // NOLINT

 private:
  OfflineWhisperModelConfig config_;
  OfflineWhisperModel *model_;  // not owned
};
...
...
sherpa-onnx/csrc/offline-whisper-model-config.cc
查看文件 @
f709c95
...
...
@@ -17,6 +17,21 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
po
->
Register
(
"whisper-decoder"
,
&
decoder
,
"Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
"medium.en-decoder.onnx."
);
po
->
Register
(
"whisper-language"
,
&
language
,
"The spoke language in the input audio file. Example values: "
"en, de, fr, zh, jp. If it is not given for a multilingual model, we will"
" infer the language from the input audio file. "
"Please refer to "
"https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10"
" for valid values. Note that for non-multilingual models, it supports "
"only 'en'"
);
po
->
Register
(
"whisper-task"
,
&
task
,
"Valid values: transcribe, translate. "
"Note that for non-multilingual models, it supports "
"only 'transcribe'"
);
}
bool
OfflineWhisperModelConfig
::
Validate
()
const
{
...
...
@@ -30,6 +45,14 @@ bool OfflineWhisperModelConfig::Validate() const {
return
false
;
}
if
(
task
!=
"translate"
&&
task
!=
"transcribe"
)
{
SHERPA_ONNX_LOGE
(
"--whisper-task supports only translate and transcribe. Given: %s"
,
task
.
c_str
());
return
false
;
}
return
true
;
}
...
...
@@ -38,7 +61,9 @@ std::string OfflineWhisperModelConfig::ToString() const {
os
<<
"OfflineWhisperModelConfig("
;
os
<<
"encoder=
\"
"
<<
encoder
<<
"
\"
, "
;
os
<<
"decoder=
\"
"
<<
decoder
<<
"
\"
)"
;
os
<<
"decoder=
\"
"
<<
decoder
<<
"
\"
, "
;
os
<<
"language=
\"
"
<<
language
<<
"
\"
, "
;
os
<<
"task=
\"
"
<<
task
<<
"
\"
)"
;
return
os
.
str
();
}
...
...
sherpa-onnx/csrc/offline-whisper-model-config.h
查看文件 @
f709c95
...
...
@@ -14,10 +14,26 @@ struct OfflineWhisperModelConfig {
std
::
string
encoder
;
std
::
string
decoder
;
// Available languages can be found at
// https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
//
// Note: For non-multilingual models, it supports only "en"
//
// If empty, we will infer it from the input audio file when
// the model is multilingual.
std
::
string
language
;
// Valid values are transcribe and translate
//
// Note: For non-multilingual models, it supports only "transcribe"
std
::
string
task
=
"transcribe"
;
OfflineWhisperModelConfig
()
=
default
;
OfflineWhisperModelConfig
(
const
std
::
string
&
encoder
,
const
std
::
string
&
decoder
)
:
encoder
(
encoder
),
decoder
(
decoder
)
{}
const
std
::
string
&
decoder
,
const
std
::
string
&
language
,
const
std
::
string
&
task
)
:
encoder
(
encoder
),
decoder
(
decoder
),
language
(
language
),
task
(
task
)
{}
void
Register
(
ParseOptions
*
po
);
bool
Validate
()
const
;
...
...
sherpa-onnx/csrc/offline-whisper-model.cc
查看文件 @
f709c95
...
...
@@ -7,6 +7,7 @@
#include <algorithm>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
...
...
@@ -88,10 +89,32 @@ class OfflineWhisperModel::Impl {
const
std
::
vector
<
int64_t
>
&
GetInitialTokens
()
const
{
return
sot_sequence_
;
}
const
std
::
vector
<
int32_t
>
&
GetAllLanguageIDs
()
const
{
return
all_language_tokens_
;
}
const
std
::
unordered_map
<
std
::
string
,
int32_t
>
&
GetLang2ID
()
const
{
return
lang2id_
;
}
const
std
::
unordered_map
<
int32_t
,
std
::
string
>
&
GetID2Lang
()
const
{
return
id2lang_
;
}
int32_t
NoTimeStampsToken
()
const
{
return
no_timestamps_
;
}
int32_t
EOT
()
const
{
return
eot_
;
}
int32_t
SOT
()
const
{
return
sot_
;
}
int32_t
TextCtx
()
const
{
return
n_text_ctx_
;
}
int32_t
VocabSize
()
const
{
return
n_vocab_
;
}
int32_t
Translate
()
const
{
return
translate_
;
}
bool
IsMultiLingual
()
const
{
return
is_multilingual_
;
}
private
:
void
InitEncoder
(
void
*
model_data
,
size_t
model_data_length
)
{
encoder_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
...
...
@@ -116,13 +139,35 @@ class OfflineWhisperModel::Impl {
SHERPA_ONNX_READ_META_DATA
(
n_text_layer_
,
"n_text_layer"
);
SHERPA_ONNX_READ_META_DATA
(
n_text_ctx_
,
"n_text_ctx"
);
SHERPA_ONNX_READ_META_DATA
(
n_text_state_
,
"n_text_state"
);
SHERPA_ONNX_READ_META_DATA
(
n_vocab_
,
"n_vocab"
);
SHERPA_ONNX_READ_META_DATA
(
sot_
,
"sot"
);
SHERPA_ONNX_READ_META_DATA
(
eot_
,
"eot"
);
SHERPA_ONNX_READ_META_DATA
(
blank_
,
"blank_id"
);
SHERPA_ONNX_READ_META_DATA
(
translate_
,
"translate"
);
SHERPA_ONNX_READ_META_DATA
(
transcribe_
,
"transcribe"
);
SHERPA_ONNX_READ_META_DATA
(
is_multilingual_
,
"is_multilingual"
);
SHERPA_ONNX_READ_META_DATA
(
no_timestamps_
,
"no_timestamps"
);
SHERPA_ONNX_READ_META_DATA
(
no_speech_
,
"no_speech"
);
SHERPA_ONNX_READ_META_DATA_VEC
(
sot_sequence_
,
"sot_sequence"
);
if
(
is_multilingual_
)
{
SHERPA_ONNX_READ_META_DATA_VEC
(
all_language_tokens_
,
"all_language_tokens"
);
SHERPA_ONNX_READ_META_DATA_VEC_STRING
(
all_language_codes_
,
"all_language_codes"
);
if
(
all_language_tokens_
.
size
()
!=
all_language_codes_
.
size
())
{
SHERPA_ONNX_LOGE
(
"# lang_id: %d != # lang_code: %d"
,
static_cast
<
int32_t
>
(
all_language_tokens_
.
size
()),
static_cast
<
int32_t
>
(
all_language_codes_
.
size
()));
exit
(
-
1
);
}
for
(
int32_t
i
=
0
;
i
!=
static_cast
<
int32_t
>
(
all_language_tokens_
.
size
());
++
i
)
{
lang2id_
[
all_language_codes_
[
i
]]
=
all_language_tokens_
[
i
];
id2lang_
[
all_language_tokens_
[
i
]]
=
all_language_codes_
[
i
];
}
}
}
void
InitDecoder
(
void
*
model_data
,
size_t
model_data_length
)
{
...
...
@@ -157,16 +202,24 @@ class OfflineWhisperModel::Impl {
std
::
vector
<
std
::
string
>
decoder_output_names_
;
std
::
vector
<
const
char
*>
decoder_output_names_ptr_
;
std
::
vector
<
int32_t
>
all_language_tokens_
;
std
::
vector
<
std
::
string
>
all_language_codes_
;
std
::
unordered_map
<
std
::
string
,
int32_t
>
lang2id_
;
std
::
unordered_map
<
int32_t
,
std
::
string
>
id2lang_
;
// model meta data
int32_t
n_text_layer_
;
int32_t
n_text_ctx_
;
int32_t
n_text_state_
;
int32_t
n_vocab_
;
int32_t
sot_
;
int32_t
eot_
;
int32_t
blank_
;
int32_t
translate_
;
int32_t
transcribe_
;
int32_t
no_timestamps_
;
int32_t
no_speech_
;
int32_t
is_multilingual_
;
std
::
vector
<
int64_t
>
sot_sequence_
;
};
...
...
@@ -176,7 +229,7 @@ OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
OfflineWhisperModel
::~
OfflineWhisperModel
()
=
default
;
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
OfflineWhisperModel
::
ForwardEncoder
(
Ort
::
Value
features
)
{
Ort
::
Value
features
)
const
{
return
impl_
->
ForwardEncoder
(
std
::
move
(
features
));
}
...
...
@@ -187,14 +240,15 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
Ort
::
Value
n_layer_self_v_cache
,
Ort
::
Value
n_layer_cross_k
,
Ort
::
Value
n_layer_cross_v
,
Ort
::
Value
offset
)
{
Ort
::
Value
offset
)
const
{
return
impl_
->
ForwardDecoder
(
std
::
move
(
tokens
),
std
::
move
(
n_layer_self_k_cache
),
std
::
move
(
n_layer_self_v_cache
),
std
::
move
(
n_layer_cross_k
),
std
::
move
(
n_layer_cross_v
),
std
::
move
(
offset
));
}
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
OfflineWhisperModel
::
GetInitialSelfKVCache
()
{
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
OfflineWhisperModel
::
GetInitialSelfKVCache
()
const
{
return
impl_
->
GetInitialSelfKVCache
();
}
...
...
@@ -206,8 +260,36 @@ const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
return
impl_
->
GetInitialTokens
();
}
// Pimpl forwarder: token ids of the languages listed in the model metadata
// (populated only for multilingual models; see Impl::InitEncoder).
const std::vector<int32_t> &OfflineWhisperModel::GetAllLanguageIDs() const {
  return impl_->GetAllLanguageIDs();
}
// Pimpl forwarder: map from language code (e.g., "en") to its token id.
const std::unordered_map<std::string, int32_t> &OfflineWhisperModel::GetLang2ID()
    const {
  return impl_->GetLang2ID();
}
// Pimpl forwarder: map from language token id back to its language code.
const std::unordered_map<int32_t, std::string> &OfflineWhisperModel::GetID2Lang()
    const {
  return impl_->GetID2Lang();
}
// Pimpl forwarder: value of the "no_timestamps" entry in the model metadata.
int32_t OfflineWhisperModel::NoTimeStampsToken() const {
  return impl_->NoTimeStampsToken();
}
// Pimpl forwarder: value of the "eot" entry in the model metadata.
int32_t OfflineWhisperModel::EOT() const {
  return impl_->EOT();
}
// Pimpl forwarder: value of the "sot" entry in the model metadata.
int32_t OfflineWhisperModel::SOT() const {
  return impl_->SOT();
}
// Pimpl forwarder: value of the "n_text_ctx" entry in the model metadata.
int32_t OfflineWhisperModel::TextCtx() const {
  return impl_->TextCtx();
}
// Pimpl forwarder: value of the "n_vocab" entry in the model metadata.
int32_t OfflineWhisperModel::VocabSize() const {
  return impl_->VocabSize();
}
// Pimpl forwarder: value of the "translate" entry in the model metadata.
int32_t OfflineWhisperModel::Translate() const {
  return impl_->Translate();
}
// Pimpl forwarder: true if the "is_multilingual" metadata entry is non-zero.
bool OfflineWhisperModel::IsMultiLingual() const {
  return impl_->IsMultiLingual();
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-whisper-model.h
查看文件 @
f709c95
...
...
@@ -5,7 +5,9 @@
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
...
...
@@ -30,7 +32,7 @@ class OfflineWhisperModel {
* - n_layer_cross_v: A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state)
*/
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
ForwardEncoder
(
Ort
::
Value
features
);
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
ForwardEncoder
(
Ort
::
Value
features
)
const
;
/** Run the decoder model.
*
...
...
@@ -58,7 +60,9 @@ class OfflineWhisperModel {
Ort
::
Value
>
ForwardDecoder
(
Ort
::
Value
tokens
,
Ort
::
Value
n_layer_self_k_cache
,
Ort
::
Value
n_layer_self_v_cache
,
Ort
::
Value
n_layer_cross_k
,
Ort
::
Value
n_layer_cross_v
,
Ort
::
Value
offset
);
Ort
::
Value
n_layer_cross_v
,
Ort
::
Value
offset
)
const
;
int32_t
DetectLanguage
()
const
;
/** Return the initial self kv cache in a pair
* - n_layer_self_k_cache A 4-D tensor of shape
...
...
@@ -66,14 +70,23 @@ class OfflineWhisperModel {
* - n_layer_self_v_cache A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
*/
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
GetInitialSelfKVCache
();
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
GetInitialSelfKVCache
()
const
;
const
std
::
vector
<
int64_t
>
&
GetInitialTokens
()
const
;
const
std
::
vector
<
int32_t
>
&
GetAllLanguageIDs
()
const
;
const
std
::
unordered_map
<
std
::
string
,
int32_t
>
&
GetLang2ID
()
const
;
const
std
::
unordered_map
<
int32_t
,
std
::
string
>
&
GetID2Lang
()
const
;
/** Return an allocator for allocating memory
*/
OrtAllocator
*
Allocator
()
const
;
int32_t
NoTimeStampsToken
()
const
;
int32_t
EOT
()
const
;
int32_t
SOT
()
const
;
int32_t
TextCtx
()
const
;
int32_t
VocabSize
()
const
;
int32_t
Translate
()
const
;
bool
IsMultiLingual
()
const
;
private
:
class
Impl
;
...
...
sherpa-onnx/python/csrc/offline-whisper-model-config.cc
查看文件 @
f709c95
...
...
@@ -14,10 +14,14 @@ namespace sherpa_onnx {
// Exposes OfflineWhisperModelConfig to Python as "OfflineWhisperModelConfig".
//
// Two constructors are bound: (encoder, decoder) and
// (encoder, decoder, language, task). All four fields are read/write
// attributes, and __str__ delegates to OfflineWhisperModelConfig::ToString().
void PybindOfflineWhisperModelConfig(py::module *m) {
  using PyClass = OfflineWhisperModelConfig;
  py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
      .def(py::init<const std::string &, const std::string &>(),
           py::arg("encoder"), py::arg("decoder"))
      .def(py::init<const std::string &, const std::string &,
                    const std::string &, const std::string &>(),
           py::arg("encoder"), py::arg("decoder"), py::arg("language"),
           py::arg("task"))
      .def_readwrite("encoder", &PyClass::encoder)
      .def_readwrite("decoder", &PyClass::decoder)
      .def_readwrite("language", &PyClass::language)
      .def_readwrite("task", &PyClass::task)
      .def("__str__", &PyClass::ToString);
}
...
...
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
查看文件 @
f709c95
...
...
@@ -244,6 +244,8 @@ class OfflineRecognizer(object):
encoder
:
str
,
decoder
:
str
,
tokens
:
str
,
language
:
str
=
"en"
,
task
:
str
=
"transcribe"
,
num_threads
:
int
=
1
,
decoding_method
:
str
=
"greedy_search"
,
debug
:
bool
=
False
,
...
...
@@ -268,6 +270,14 @@ class OfflineRecognizer(object):
symbol integer_id
language:
The spoken language in the audio file. Example values: en, de, zh,
ja, fr. See https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
for all possible values. Note that for non-multilingual models, the
only valid value is 'en'.
task:
Valid values are: transcribe, translate. Note that for
non-multilingual models, the only valid value is 'transcribe'.
num_threads:
Number of threads for neural network computation.
decoding_method:
...
...
@@ -279,7 +289,12 @@ class OfflineRecognizer(object):
"""
self
=
cls
.
__new__
(
cls
)
model_config
=
OfflineModelConfig
(
whisper
=
OfflineWhisperModelConfig
(
encoder
=
encoder
,
decoder
=
decoder
),
whisper
=
OfflineWhisperModelConfig
(
encoder
=
encoder
,
decoder
=
decoder
,
language
=
language
,
task
=
task
,
),
tokens
=
tokens
,
num_threads
=
num_threads
,
debug
=
debug
,
...
...
请
注册
或
登录
后发表评论