Fangjun Kuang
Committed by GitHub

Support multilingual whisper models (#274)
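This commit threads two new options, --whisper-language and --whisper-task, from the ONNX export script through the C++ runtime, the CLI flags, and the Python API, so the multilingual checkpoints (tiny, base, small, medium, large, large-v1, large-v2) can transcribe non-English audio or translate it to English. A minimal sketch of the extended Python entry point, assuming a multilingual export; the model paths are placeholders:

import sherpa_onnx

# language="" asks the runtime to detect the language from the audio;
# task may be "transcribe" or "translate" (translate outputs English).
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
    encoder="./sherpa-onnx-whisper-tiny/tiny-encoder.onnx",  # placeholder
    decoder="./sherpa-onnx-whisper-tiny/tiny-decoder.onnx",  # placeholder
    tokens="./sherpa-onnx-whisper-tiny/tiny-tokens.txt",     # placeholder
    language="de",
    task="transcribe",
)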

@@ -36,6 +36,9 @@ jobs:
       CIBW_ARCHS: "universal2"
       CIBW_BUILD_VERBOSITY: 3
 
+      # Don't repair macOS wheels
+      CIBW_REPAIR_WHEEL_COMMAND_MACOS: ""
+
     - name: Display wheels
       shell: bash
       run: |
@@ -16,7 +16,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [macos-latest]
-        model: ["tiny.en", "base.en", "small.en", "medium.en"]
+        model: ["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
 
     steps:
       - uses: actions/checkout@v2
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)
 
-set(SHERPA_ONNX_VERSION "1.7.6")
+set(SHERPA_ONNX_VERSION "1.7.7")
 
 # Disable warning about
 #
@@ -3,7 +3,7 @@ module non-streaming-decode-files
 go 1.12
 
 require (
-    github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
+    github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
     github.com/spf13/pflag v1.0.5
     github.com/youpy/go-wav v0.3.2
 )
@@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
 github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -4,6 +4,6 @@ go 1.12
 
 require (
     github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
-    github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
+    github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
     github.com/spf13/pflag v1.0.5
 )
 github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
 github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
 github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
 github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
@@ -3,7 +3,7 @@ module streaming-decode-files
 go 1.12
 
 require (
-    github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
+    github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
     github.com/spf13/pflag v1.0.5
     github.com/youpy/go-wav v0.3.2
 )
@@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
 github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -11,10 +11,12 @@ fun main() {
     // please refer to
     // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
     // to download pre-trained models
-    var modelConfig = OnlineTransducerModelConfig(
+    var modelConfig = OnlineModelConfig(
+        transducer = OnlineTransducerModelConfig(
         encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
         decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
         joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
+        ),
         tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
         numThreads = 1,
         debug = false,
@@ -41,16 +43,16 @@ fun main() {
     var objArray = WaveReader.readWaveFromFile(
         filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
     )
-    var samples : FloatArray = objArray[0] as FloatArray
-    var sampleRate : Int = objArray[1] as Int
+    var samples: FloatArray = objArray[0] as FloatArray
+    var sampleRate: Int = objArray[1] as Int
 
-    model.acceptWaveform(samples, sampleRate=sampleRate)
+    model.acceptWaveform(samples, sampleRate = sampleRate)
     while (model.isReady()) {
         model.decode()
     }
 
     var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
-    model.acceptWaveform(tailPaddings, sampleRate=sampleRate)
+    model.acceptWaveform(tailPaddings, sampleRate = sampleRate)
     model.inputFinished()
     while (model.isReady()) {
         model.decode()
@@ -234,6 +234,28 @@ def add_whisper_model_args(parser: argparse.ArgumentParser):
         help="Path to whisper decoder model",
     )
 
+    parser.add_argument(
+        "--whisper-language",
+        default="",
+        type=str,
+        help="""It specifies the spoken language in the input audio file.
+        Example values: en, fr, de, zh, ja.
+        Available languages for multilingual models can be found at
+        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+        If not specified, we infer the language from the input audio file.
+        """,
+    )
+
+    parser.add_argument(
+        "--whisper-task",
+        default="transcribe",
+        choices=["transcribe", "translate"],
+        type=str,
+        help="""For multilingual models, if you specify translate, the output
+        will be in English.
+        """,
+    )
+
 
 def add_model_args(parser: argparse.ArgumentParser):
     add_transducer_model_args(parser)
@@ -813,6 +835,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             tokens=args.tokens,
             num_threads=args.num_threads,
             decoding_method=args.decoding_method,
+            language=args.whisper_language,
+            task=args.whisper_task,
         )
     elif args.tdnn_model:
         assert_file_exists(args.tdnn_model)
@@ -53,6 +53,7 @@ python3 ./python-api-examples/offline-decode-files.py \
   --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
   --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
   --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+  --whisper-task=transcribe \
   --num-threads=1 \
   ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
   ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
@@ -201,6 +202,28 @@ def get_args():
     )
 
     parser.add_argument(
+        "--whisper-language",
+        default="",
+        type=str,
+        help="""It specifies the spoken language in the input audio file.
+        Example values: en, fr, de, zh, ja.
+        Available languages for multilingual models can be found at
+        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+        If not specified, we infer the language from the input audio file.
+        """,
+    )
+
+    parser.add_argument(
+        "--whisper-task",
+        default="transcribe",
+        choices=["transcribe", "translate"],
+        type=str,
+        help="""For multilingual models, if you specify translate, the output
+        will be in English.
+        """,
+    )
+
+    parser.add_argument(
         "--decoding-method",
         type=str,
         default="greedy_search",
@@ -371,10 +394,10 @@ def main():
             decoder=args.whisper_decoder,
             tokens=args.tokens,
             num_threads=args.num_threads,
-            sample_rate=args.sample_rate,
-            feature_dim=args.feature_dim,
             decoding_method=args.decoding_method,
             debug=args.debug,
+            language=args.whisper_language,
+            task=args.whisper_task,
         )
     elif args.tdnn_model:
         assert_file_exists(args.tdnn_model)
@@ -11,6 +11,7 @@ for making the onnx export script public.
 """
 
 import argparse
+import os
 from pathlib import Path
 from typing import Any, Dict, Optional
 
@@ -250,6 +251,7 @@ def main():
     # write tokens
 
     tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
+
     model.eval()
     print(model.dims)
     audio = torch.rand(16000 * 2)
@@ -306,8 +308,12 @@ def main():
         "n_text_head": model.dims.n_text_head,
         "n_text_layer": model.dims.n_text_layer,
         "sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
-        "all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))),
-        "all_language_codes": ",".join(tokenizer.all_language_codes),
+        "all_language_tokens": ",".join(
+            list(map(str, tokenizer.all_language_tokens))
+        ),  # a list of ids
+        "all_language_codes": ",".join(
+            tokenizer.all_language_codes
+        ),  # e.g., en, de, zh, fr
         "sot": tokenizer.sot,
         "sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
         "eot": tokenizer.eot,
@@ -413,6 +419,9 @@ def main():
         },
     )
 
+    if 'large' in args.model:
+        # int8 quantization causes errors for large models, so skip it.
+        return
     # Generate int8 quantization models
     # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
 
@@ -39,6 +39,24 @@ def get_args():
     )
 
     parser.add_argument(
+        "--language",
+        type=str,
+        help="""The actual spoken language in the audio.
+        Example values: en, de, zh, ja, fr.
+        If None, we will detect the language using the first 30s of the
+        input audio.
+        """,
+    )
+
+    parser.add_argument(
+        "--task",
+        choices=["transcribe", "translate"],
+        type=str,
+        default="transcribe",
+        help="Valid values are: transcribe, translate",
+    )
+
+    parser.add_argument(
         "sound_file",
         type=str,
         help="Path to the test wave",
@@ -74,12 +92,22 @@ class OnnxModel:
         self.sot = int(meta["sot"])
         self.eot = int(meta["eot"])
         self.translate = int(meta["translate"])
+        self.transcribe = int(meta["transcribe"])
         self.no_timestamps = int(meta["no_timestamps"])
         self.no_speech = int(meta["no_speech"])
         self.blank = int(meta["blank_id"])
 
         self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
 
+        self.sot_sequence.append(self.no_timestamps)
+
+        self.all_language_tokens = list(
+            map(int, meta["all_language_tokens"].split(","))
+        )
+        self.all_language_codes = meta["all_language_codes"].split(",")
+        self.lang2id = dict(zip(self.all_language_codes, self.all_language_tokens))
+        self.id2lang = dict(zip(self.all_language_tokens, self.all_language_codes))
+
         self.is_multilingual = int(meta["is_multilingual"]) == 1
 
     def init_decoder(self, decoder: str):
@@ -164,6 +192,29 @@ class OnnxModel:
         # logits is changed in-place
         logits[self.translate] = float("-inf")
 
+    def detect_language(
+        self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor
+    ) -> int:
+        tokens = torch.tensor([[self.sot]], dtype=torch.int64)
+        offset = torch.zeros(1, dtype=torch.int64)
+        n_layer_self_k_cache, n_layer_self_v_cache = self.get_self_cache()
+
+        logits, n_layer_self_k_cache, n_layer_self_v_cache = self.run_decoder(
+            tokens=tokens,
+            n_layer_self_k_cache=n_layer_self_k_cache,
+            n_layer_self_v_cache=n_layer_self_v_cache,
+            n_layer_cross_k=n_layer_cross_k,
+            n_layer_cross_v=n_layer_cross_v,
+            offset=offset,
+        )
+        logits = logits.reshape(-1)
+        mask = torch.ones(logits.shape[0], dtype=torch.int64)
+        mask[self.all_language_tokens] = 0
+        logits[mask] = float("-inf")
+        lang_id = logits.argmax().item()
+        print("detected language: ", self.id2lang[lang_id])
+        return lang_id
+
 
 def load_tokens(filename):
     tokens = dict()
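For reference, a standalone illustration of the masking trick that detect_language uses above: every logit except those of the language tokens is forced to -inf, so argmax can only return a language token. Note that the index tensor needs dtype bool for logits[mask] to act as a boolean mask (an int64 tensor would be treated as plain indices); the token ids below are made up for the illustration.

import torch

vocab_size = 10
all_language_tokens = [3, 5, 7]  # made-up ids, just for the illustration

logits = torch.randn(vocab_size)
mask = torch.ones(vocab_size, dtype=torch.bool)
mask[all_language_tokens] = False  # keep only the language positions
logits[mask] = float("-inf")

lang_id = logits.argmax().item()  # guaranteed to be 3, 5, or 7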
@@ -200,7 +251,35 @@ def main():
     mel = mel.t().unsqueeze(0)
 
     model = OnnxModel(encoder, decoder)
+
     n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
+
+    if args.language is not None:
+        if model.is_multilingual is False and args.language != "en":
+            print(f"This model supports only English. Given: {args.language}")
+            return
+
+        if args.language not in model.lang2id:
+            print(f"Invalid language: {args.language}")
+            print(f"Valid values are: {list(model.lang2id.keys())}")
+            return
+
+        # [sot, lang, task, notimestamps]
+        model.sot_sequence[1] = model.lang2id[args.language]
+    elif model.is_multilingual is True:
+        print("detecting language")
+        lang = model.detect_language(n_layer_cross_k, n_layer_cross_v)
+        model.sot_sequence[1] = lang
+
+    if args.task is not None:
+        if model.is_multilingual is False and args.task != "transcribe":
+            print("This model supports only English. Please use --task=transcribe")
+            return
+        assert args.task in ["transcribe", "translate"], args.task
+
+        if args.task == "translate":
+            model.sot_sequence[2] = model.translate
+
     n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
 
     tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
@@ -213,6 +292,7 @@ def main():
         n_layer_cross_v=n_layer_cross_v,
         offset=offset,
     )
+    offset += len(model.sot_sequence)
     # logits.shape (batch_size, tokens.shape[1], vocab_size)
     logits = logits[0, -1]
     model.suppress_tokens(logits, is_initial=True)
@@ -225,7 +305,6 @@ def main():
             break
         results.append(max_token_id.item())
         tokens = torch.tensor([[results[-1]]])
-        offset += 1
 
         logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
             tokens=tokens,
@@ -235,6 +314,7 @@ def main():
             n_layer_cross_v=n_layer_cross_v,
             offset=offset,
         )
+        offset += 1
         logits = logits[0, -1]
         model.suppress_tokens(logits, is_initial=False)
         max_token_id = logits.argmax(dim=-1)
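The offset changes above establish a single invariant: after every decoder call, offset equals the number of tokens already stored in the self-attention KV cache (the whole sot_sequence after the first call, one more token per step afterwards). A self-contained toy of that bookkeeping; run_decoder and pick_next are stand-ins for the real ONNX calls:

def run_decoder(tokens, offset):  # stand-in; the real call also threads KV caches
    return [0.0, 0.0, 1.0, 0.0]  # fake logits

def pick_next(logits):  # stand-in for argmax with token suppression
    return 2

sot_sequence, eot = [1, 3, 0], 2
offset = 0

logits = run_decoder(sot_sequence, offset)  # first call feeds the whole prefix
offset += len(sot_sequence)  # the cache now holds the prefix

while True:
    token = pick_next(logits)
    if token == eot:
        break
    logits = run_decoder([token], offset)
    offset += 1  # one more token cached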
@@ -37,7 +37,7 @@
     }                                                            \
                                                                  \
     dst = atoi(value.get());                                     \
-    if (dst <= 0) {                                              \
+    if (dst < 0) {                                               \
       SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \
       exit(-1);                                                  \
     }                                                            \
@@ -77,6 +77,24 @@
     }                                                                   \
   } while (0)
 
+// Read a vector of strings
+#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key)             \
+  do {                                                                  \
+    auto value =                                                        \
+        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
+    if (!value) {                                                       \
+      SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key);   \
+      exit(-1);                                                         \
+    }                                                                   \
+    SplitStringToVector(value.get(), ",", false, &dst);                 \
+                                                                        \
+    if (dst.empty()) {                                                  \
+      SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!",        \
+                       value.get(), src_key);                           \
+      exit(-1);                                                         \
+    }                                                                   \
+  } while (0)
+
 // Read a string
 #define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
   do {                                               \
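The new macro is the reading half of a round trip: export-onnx.py joins the language codes with commas into a single metadata string, and SHERPA_ONNX_READ_META_DATA_VEC_STRING splits it back and rejects an empty result. The same round trip in a few lines of Python, for orientation:

# Writing side (export-onnx.py): join the codes into one metadata value.
all_language_codes = ["en", "de", "zh", "fr"]
meta_value = ",".join(all_language_codes)

# Reading side (the macro above): split on "," and reject an empty vector.
parsed = [s for s in meta_value.split(",") if s]
assert parsed, "Invalid value: empty vector!"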
@@ -23,21 +23,227 @@
 
 namespace sherpa_onnx {
 
+static std::string FixInvalidUtf8(const std::string &s) {
+  int32_t s_size = s.size();
+
+  std::string ans;
+  ans.reserve(s_size);
+
+  for (int32_t i = 0; i < s_size;) {
+    uint8_t c = s[i];
+    if (c < 0x80) {
+      // valid
+      ans.append(1, c);
+      ++i;
+      continue;
+    } else if ((c >= 0xc0) && (c < 0xe0)) {
+      // beginning of two bytes
+      if ((i + 1) > (s_size - 1)) {
+        // no subsequent byte. invalid!
+        i += 1;
+        continue;
+      }
+      uint8_t next = s[i + 1];
+      if (!(next >= 0x80 && next < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+      // valid 2-byte utf-8
+      ans.append(1, c);
+      ans.append(1, next);
+      i += 2;
+      continue;
+    } else if ((c >= 0xe0) && (c < 0xf0)) {
+      // beginning of 3 bytes
+      if ((i + 2) > (s_size - 1)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next = s[i + 1];
+      if (!(next >= 0x80 && next < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next2 = s[i + 2];
+      if (!(next2 >= 0x80 && next2 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      ans.append(1, c);
+      ans.append(1, next);
+      ans.append(1, next2);
+      i += 3;
+      continue;
+    } else if ((c >= 0xf0) && (c < 0xf8)) {
+      // 4 bytes
+      if ((i + 3) > (s_size - 1)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next = s[i + 1];
+      if (!(next >= 0x80 && next < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next2 = s[i + 2];
+      if (!(next2 >= 0x80 && next2 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next3 = s[i + 3];
+      if (!(next3 >= 0x80 && next3 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+      ans.append(1, c);
+      ans.append(1, next);
+      ans.append(1, next2);
+      ans.append(1, next3);
+      i += 4;
+      continue;
+    } else if ((c >= 0xf8) && (c < 0xfc)) {
+      // 5 bytes
+      if ((i + 4) > (s_size - 1)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next = s[i + 1];
+      if (!(next >= 0x80 && next < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next2 = s[i + 2];
+      if (!(next2 >= 0x80 && next2 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next3 = s[i + 3];
+      if (!(next3 >= 0x80 && next3 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next4 = s[i + 4];
+      if (!(next4 >= 0x80 && next4 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+      ans.append(1, c);
+      ans.append(1, next);
+      ans.append(1, next2);
+      ans.append(1, next3);
+      ans.append(1, next4);
+      i += 5;
+      continue;
+    } else if ((c >= 0xfc) && (c < 0xfe)) {
+      // 6 bytes
+      if ((i + 5) > (s_size - 1)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next = s[i + 1];
+      if (!(next >= 0x80 && next < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next2 = s[i + 2];
+      if (!(next2 >= 0x80 && next2 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next3 = s[i + 3];
+      if (!(next3 >= 0x80 && next3 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next4 = s[i + 4];
+      if (!(next4 >= 0x80 && next4 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next5 = s[i + 5];
+      if (!(next5 >= 0x80 && next5 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+      ans.append(1, c);
+      ans.append(1, next);
+      ans.append(1, next2);
+      ans.append(1, next3);
+      ans.append(1, next4);
+      ans.append(1, next5);
+      i += 6;
+      continue;
+    } else {
+      i += 1;
+    }
+  }
+  return ans;
+}
+
 static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
                                         const SymbolTable &sym_table) {
   OfflineRecognitionResult r;
   r.tokens.reserve(src.tokens.size());
 
+  std::string text;
   for (auto i : src.tokens) {
     if (!sym_table.contains(i)) {
       continue;
     }
 
     const auto &s = sym_table[i];
-    r.text += s;
+    text += s;
     r.tokens.push_back(s);
   }
 
+  // TODO(fangjun): Fix the following error in offline-stream.cc
+  //
+  // j["text"] = text;
+
+  // libc++abi: terminating with uncaught exception of type
+  // nlohmann::json_abi_v3_11_2::detail::type_error:
+  // [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86
+
+#if 0
+  r.text = FixInvalidUtf8(text);
+#else
+  r.text = text;
+#endif
+
   return r;
 }
 
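FixInvalidUtf8 walks the byte string and drops any byte that does not start or complete a well-formed multi-byte sequence; Whisper's BPE tokens can split a multibyte character, so a truncated hypothesis can end mid-character and trigger the nlohmann::json error quoted in the TODO. A rough Python equivalent of the same repair, assuming that silently dropping invalid bytes is acceptable (Python's codec is slightly stricter: it also rejects the 5- and 6-byte forms the C++ code still accepts):

def fix_invalid_utf8(data: bytes) -> str:
    # errors="ignore" drops every byte that does not belong to a
    # complete, valid UTF-8 sequence.
    return data.decode("utf-8", errors="ignore")

truncated = "日本語".encode("utf-8")[:-1]  # last character cut in half
print(fix_invalid_utf8(truncated))  # the partial character is dropped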
@@ -51,8 +257,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
     symbol_table_.ApplyBase64Decode();
 
     if (config.decoding_method == "greedy_search") {
-      decoder_ =
-          std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
+      decoder_ = std::make_unique<OfflineWhisperGreedySearchDecoder>(
+          config_.model_config.whisper, model_.get());
     } else {
       SHERPA_ONNX_LOGE(
           "Only greedy_search is supported at present for whisper. Given %s",
@@ -101,6 +307,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
     mel = Transpose12(model_->Allocator(), &mel);
 
     auto cross_kv = model_->ForwardEncoder(std::move(mel));
+
     auto results =
         decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
 
@@ -7,17 +7,106 @@
 #include <algorithm>
 #include <utility>
 
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
 namespace sherpa_onnx {
 
+int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
+    Ort::Value &cross_k, Ort::Value &cross_v) const {  // NOLINT
+  int64_t token_val = model_->SOT();
+  std::array<int64_t, 2> token_shape{1, 1};
+
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  Ort::Value tokens = Ort::Value::CreateTensor(
+      memory_info, &token_val, 1, token_shape.data(), token_shape.size());
+
+  auto self_kv_cache = model_->GetInitialSelfKVCache();
+
+  std::array<int64_t, 1> offset_shape{1};
+  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
+      model_->Allocator(), offset_shape.data(), offset_shape.size());
+  *(offset.GetTensorMutableData<int64_t>()) = 0;
+
+  auto decoder_out = model_->ForwardDecoder(
+      std::move(tokens), std::move(self_kv_cache.first),
+      std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
+      std::move(offset));
+
+  cross_k = std::move(std::get<3>(decoder_out));
+  cross_v = std::move(std::get<4>(decoder_out));
+
+  const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
+  int32_t vocab_size = model_->VocabSize();
+  const auto &all_language_ids = model_->GetAllLanguageIDs();
+
+  int32_t lang_id = all_language_ids[0];
+  float this_logit = p_logits[lang_id];
+
+  for (int32_t i = 1; i != all_language_ids.size(); ++i) {
+    int32_t id = all_language_ids[i];
+    float p = p_logits[id];
+
+    if (p > this_logit) {
+      this_logit = p;
+      lang_id = id;
+    }
+  }
+#if 1
+  SHERPA_ONNX_LOGE("Detected language: %s",
+                   model_->GetID2Lang().at(lang_id).c_str());
+#endif
+
+  return lang_id;
+}
+
 std::vector<OfflineWhisperDecoderResult>
 OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
                                           Ort::Value cross_v) {
   auto memory_info =
       Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
 
-  auto self_kv_cache = model_->GetInitialSelfKVCache();
-
+  // For multilingual models, initial_tokens contains [sot, language, task]
+  //   - language is English by default
+  //   - task is transcribe by default
+  //
+  // For non-multilingual models, initial_tokens contains [sot]
   std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
+
+  if (model_->IsMultiLingual()) {
+    if (!config_.language.empty()) {
+      const auto &lang2id = model_->GetLang2ID();
+
+      if (!lang2id.count(config_.language)) {
+        SHERPA_ONNX_LOGE("Invalid language: %s", config_.language.c_str());
+        exit(-1);
+      }
+
+      int32_t lang_id = lang2id.at(config_.language);
+
+      // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
+      initial_tokens[1] = lang_id;
+    } else {
+      int32_t lang_id = DetectLanguage(cross_k, cross_v);
+
+      // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
+      initial_tokens[1] = lang_id;
+    }
+
+    if (config_.task == "translate") {
+      initial_tokens[2] = model_->Translate();
+    } else if (config_.task != "transcribe") {
+      // initial_tokens[2] is transcribe by default
+      SHERPA_ONNX_LOGE(
+          "Unsupported task: %s. Valid values are: transcribe, translate.",
+          config_.task.c_str());
+    }
+  }
+
+  initial_tokens.push_back(model_->NoTimeStampsToken());
+
   int32_t batch_size = 1;
   std::array<int64_t, 2> token_shape{
       batch_size, static_cast<int64_t>(initial_tokens.size())};
@@ -31,11 +120,16 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
       model_->Allocator(), offset_shape.data(), offset_shape.size());
   *(offset.GetTensorMutableData<int64_t>()) = 0;
 
+  auto self_kv_cache = model_->GetInitialSelfKVCache();
+
   auto decoder_out = model_->ForwardDecoder(
       std::move(tokens), std::move(self_kv_cache.first),
       std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
       std::move(offset));
 
+  *(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) =
+      initial_tokens.size();
+
   const auto &logits = std::get<0>(decoder_out);
   const float *p_logits = logits.GetTensorData<float>();
 
@@ -58,18 +152,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
     std::array<int64_t, 2> token_shape{1, 1};
     Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
         model_->Allocator(), token_shape.data(), token_shape.size());
+
     int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
     p_tokens[0] = max_token_id;
 
-    int64_t *p_offset =
-        std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
-
-    if (i == 0) {
-      *p_offset = initial_tokens.size();
-    } else {
-      *p_offset += 1;
-    }
-
     decoder_out = model_->ForwardDecoder(std::move(tokens),
                                          std::move(std::get<1>(decoder_out)),
                                          std::move(std::get<2>(decoder_out)),
@@ -77,6 +163,11 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
                                          std::move(std::get<4>(decoder_out)),
                                          std::move(std::get<5>(decoder_out)));
 
+    int64_t *p_offset =
+        std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
+
+    *p_offset += 1;
+
     const auto &logits = std::get<0>(decoder_out);
     const float *p_logits = logits.GetTensorData<float>();
 
@@ -85,6 +176,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
   }
 
   std::vector<OfflineWhisperDecoderResult> ans(1);
+
   ans[0].tokens = std::move(predicted_tokens);
 
   return ans;
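Taken together, the new decoder prepares the start-of-transcript prefix like this: multilingual models use [sot, language, task], the language slot is filled either from --whisper-language or from DetectLanguage, the task slot is switched to the translate token on request, and no_timestamps is appended last. A runnable Python summary of that control flow; FakeModel is a made-up stand-in, though the token ids shown are the usual multilingual Whisper values:

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class FakeModel:  # made-up stand-in for OfflineWhisperModel
    initial_tokens: List[int] = field(default_factory=lambda: [50258, 50259, 50359])
    lang2id: Dict[str, int] = field(default_factory=lambda: {"en": 50259, "de": 50261})
    translate_token: int = 50358
    no_timestamps_token: int = 50363
    is_multilingual: bool = True

def build_initial_tokens(model: FakeModel, language: str, task: str) -> List[int]:
    tokens = list(model.initial_tokens)  # [sot, lang, task]; just [sot] for *.en
    if model.is_multilingual:
        if language:
            tokens[1] = model.lang2id[language]
        # else: the C++ code calls DetectLanguage() on the encoder output here
        if task == "translate":
            tokens[2] = model.translate_token
    tokens.append(model.no_timestamps_token)
    return tokens

print(build_initial_tokens(FakeModel(), language="de", task="translate"))
# -> [50258, 50261, 50358, 50363]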
@@ -8,19 +8,25 @@
 #include <vector>
 
 #include "sherpa-onnx/csrc/offline-whisper-decoder.h"
+#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
 #include "sherpa-onnx/csrc/offline-whisper-model.h"
 
 namespace sherpa_onnx {
 
 class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
  public:
-  explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
-      : model_(model) {}
+  OfflineWhisperGreedySearchDecoder(const OfflineWhisperModelConfig &config,
+                                    OfflineWhisperModel *model)
+      : config_(config), model_(model) {}
 
   std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
                                                   Ort::Value cross_v) override;
 
+  int32_t DetectLanguage(Ort::Value &cross_k,        // NOLINT
+                         Ort::Value &cross_v) const;  // NOLINT
+
  private:
+  OfflineWhisperModelConfig config_;
   OfflineWhisperModel *model_;  // not owned
 };
 
@@ -17,6 +17,21 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
   po->Register("whisper-decoder", &decoder,
                "Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
                "medium.en-decoder.onnx.");
+
+  po->Register(
+      "whisper-language", &language,
+      "The spoken language in the input audio file. Example values: "
+      "en, de, fr, zh, ja. If it is not given for a multilingual model, we "
+      "will infer the language from the input audio file. "
+      "Please refer to "
+      "https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10"
+      " for valid values. Note that for non-multilingual models, it supports "
+      "only 'en'.");
+
+  po->Register("whisper-task", &task,
+               "Valid values: transcribe, translate. "
+               "Note that for non-multilingual models, it supports "
+               "only 'transcribe'.");
 }
 
 bool OfflineWhisperModelConfig::Validate() const {
@@ -30,6 +45,14 @@ bool OfflineWhisperModelConfig::Validate() const {
     return false;
   }
 
+  if (task != "translate" && task != "transcribe") {
+    SHERPA_ONNX_LOGE(
+        "--whisper-task supports only translate and transcribe. Given: %s",
+        task.c_str());
+
+    return false;
+  }
+
   return true;
 }
 
@@ -38,7 +61,9 @@ std::string OfflineWhisperModelConfig::ToString() const {
 
   os << "OfflineWhisperModelConfig(";
   os << "encoder=\"" << encoder << "\", ";
-  os << "decoder=\"" << decoder << "\")";
+  os << "decoder=\"" << decoder << "\", ";
+  os << "language=\"" << language << "\", ";
+  os << "task=\"" << task << "\")";
 
   return os.str();
 }
@@ -14,10 +14,26 @@ struct OfflineWhisperModelConfig {
   std::string encoder;
   std::string decoder;
 
+  // Available languages can be found at
+  // https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+  //
+  // Note: For non-multilingual models, it supports only "en"
+  //
+  // If empty, we will infer it from the input audio file when
+  // the model is multilingual.
+  std::string language;
+
+  // Valid values are transcribe and translate
+  //
+  // Note: For non-multilingual models, it supports only "transcribe"
+  std::string task = "transcribe";
+
   OfflineWhisperModelConfig() = default;
   OfflineWhisperModelConfig(const std::string &encoder,
-                            const std::string &decoder)
-      : encoder(encoder), decoder(decoder) {}
+                            const std::string &decoder,
+                            const std::string &language,
+                            const std::string &task)
+      : encoder(encoder), decoder(decoder), language(language), task(task) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
@@ -7,6 +7,7 @@
 #include <algorithm>
 #include <string>
 #include <tuple>
+#include <unordered_map>
 #include <utility>
 
 #include "sherpa-onnx/csrc/macros.h"
@@ -88,10 +89,32 @@ class OfflineWhisperModel::Impl {
 
   const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
 
+  const std::vector<int32_t> &GetAllLanguageIDs() const {
+    return all_language_tokens_;
+  }
+
+  const std::unordered_map<std::string, int32_t> &GetLang2ID() const {
+    return lang2id_;
+  }
+
+  const std::unordered_map<int32_t, std::string> &GetID2Lang() const {
+    return id2lang_;
+  }
+
+  int32_t NoTimeStampsToken() const { return no_timestamps_; }
+
   int32_t EOT() const { return eot_; }
 
+  int32_t SOT() const { return sot_; }
+
   int32_t TextCtx() const { return n_text_ctx_; }
 
+  int32_t VocabSize() const { return n_vocab_; }
+
+  int32_t Translate() const { return translate_; }
+
+  bool IsMultiLingual() const { return is_multilingual_; }
+
  private:
   void InitEncoder(void *model_data, size_t model_data_length) {
     encoder_sess_ = std::make_unique<Ort::Session>(
@@ -116,13 +139,35 @@ class OfflineWhisperModel::Impl {
     SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
     SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
     SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
+    SHERPA_ONNX_READ_META_DATA(n_vocab_, "n_vocab");
     SHERPA_ONNX_READ_META_DATA(sot_, "sot");
     SHERPA_ONNX_READ_META_DATA(eot_, "eot");
     SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
     SHERPA_ONNX_READ_META_DATA(translate_, "translate");
+    SHERPA_ONNX_READ_META_DATA(transcribe_, "transcribe");
+    SHERPA_ONNX_READ_META_DATA(is_multilingual_, "is_multilingual");
     SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
     SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
     SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
+
+    if (is_multilingual_) {
+      SHERPA_ONNX_READ_META_DATA_VEC(all_language_tokens_,
+                                     "all_language_tokens");
+      SHERPA_ONNX_READ_META_DATA_VEC_STRING(all_language_codes_,
+                                            "all_language_codes");
+      if (all_language_tokens_.size() != all_language_codes_.size()) {
+        SHERPA_ONNX_LOGE("# lang_id: %d != # lang_code: %d",
+                         static_cast<int32_t>(all_language_tokens_.size()),
+                         static_cast<int32_t>(all_language_codes_.size()));
+        exit(-1);
+      }
+
+      for (int32_t i = 0;
+           i != static_cast<int32_t>(all_language_tokens_.size()); ++i) {
+        lang2id_[all_language_codes_[i]] = all_language_tokens_[i];
+        id2lang_[all_language_tokens_[i]] = all_language_codes_[i];
+      }
+    }
   }
 
   void InitDecoder(void *model_data, size_t model_data_length) {
@@ -157,16 +202,24 @@ class OfflineWhisperModel::Impl {
   std::vector<std::string> decoder_output_names_;
   std::vector<const char *> decoder_output_names_ptr_;
 
+  std::vector<int32_t> all_language_tokens_;
+  std::vector<std::string> all_language_codes_;
+  std::unordered_map<std::string, int32_t> lang2id_;
+  std::unordered_map<int32_t, std::string> id2lang_;
+
   // model meta data
   int32_t n_text_layer_;
   int32_t n_text_ctx_;
   int32_t n_text_state_;
+  int32_t n_vocab_;
   int32_t sot_;
   int32_t eot_;
   int32_t blank_;
   int32_t translate_;
+  int32_t transcribe_;
   int32_t no_timestamps_;
   int32_t no_speech_;
+  int32_t is_multilingual_;
   std::vector<int64_t> sot_sequence_;
 };
 
@@ -176,7 +229,7 @@ OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
 OfflineWhisperModel::~OfflineWhisperModel() = default;
 
 std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
-    Ort::Value features) {
+    Ort::Value features) const {
   return impl_->ForwardEncoder(std::move(features));
 }
 
@@ -187,14 +240,15 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
                                     Ort::Value n_layer_self_v_cache,
                                     Ort::Value n_layer_cross_k,
                                     Ort::Value n_layer_cross_v,
-                                    Ort::Value offset) {
+                                    Ort::Value offset) const {
   return impl_->ForwardDecoder(
       std::move(tokens), std::move(n_layer_self_k_cache),
       std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
       std::move(n_layer_cross_v), std::move(offset));
 }
 
-std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() {
+std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
+    const {
   return impl_->GetInitialSelfKVCache();
 }
 
@@ -206,8 +260,36 @@ const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
   return impl_->GetInitialTokens();
 }
 
+const std::vector<int32_t> &OfflineWhisperModel::GetAllLanguageIDs() const {
+  return impl_->GetAllLanguageIDs();
+}
+
+const std::unordered_map<std::string, int32_t>
+    &OfflineWhisperModel::GetLang2ID() const {
+  return impl_->GetLang2ID();
+}
+
+const std::unordered_map<int32_t, std::string>
+    &OfflineWhisperModel::GetID2Lang() const {
+  return impl_->GetID2Lang();
+}
+
+int32_t OfflineWhisperModel::NoTimeStampsToken() const {
+  return impl_->NoTimeStampsToken();
+}
+
 int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); }
 
+int32_t OfflineWhisperModel::SOT() const { return impl_->SOT(); }
+
 int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
 
+int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); }
+
+int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); }
+
+bool OfflineWhisperModel::IsMultiLingual() const {
+  return impl_->IsMultiLingual();
+}
+
 }  // namespace sherpa_onnx
@@ -5,7 +5,9 @@
 #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
 
 #include <memory>
+#include <string>
 #include <tuple>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 
@@ -30,7 +32,7 @@ class OfflineWhisperModel {
    *  - n_layer_cross_v: A 4-D tensor of shape
    *                     (n_text_layer, N, n_audio_ctx, n_text_state)
    */
-  std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
+  std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) const;
 
   /** Run the decoder model.
    *
@@ -58,7 +60,9 @@ class OfflineWhisperModel {
              Ort::Value>
   ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
                  Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
-                 Ort::Value n_layer_cross_v, Ort::Value offset);
+                 Ort::Value n_layer_cross_v, Ort::Value offset) const;
+
+  int32_t DetectLanguage() const;
 
   /** Return the initial self kv cache in a pair
    *  - n_layer_self_k_cache A 4-D tensor of shape
@@ -66,14 +70,23 @@ class OfflineWhisperModel {
    *  - n_layer_self_v_cache A 4-D tensor of shape
    *                         (n_text_layer, N, n_audio_ctx, n_text_state).
    */
-  std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
+  std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const;
   const std::vector<int64_t> &GetInitialTokens() const;
+  const std::vector<int32_t> &GetAllLanguageIDs() const;
+  const std::unordered_map<std::string, int32_t> &GetLang2ID() const;
+  const std::unordered_map<int32_t, std::string> &GetID2Lang() const;
 
   /** Return an allocator for allocating memory
    */
   OrtAllocator *Allocator() const;
+
+  int32_t NoTimeStampsToken() const;
   int32_t EOT() const;
+  int32_t SOT() const;
   int32_t TextCtx() const;
+  int32_t VocabSize() const;
+  int32_t Translate() const;
+  bool IsMultiLingual() const;
 
  private:
   class Impl;
@@ -14,10 +14,14 @@ namespace sherpa_onnx {
 void PybindOfflineWhisperModelConfig(py::module *m) {
   using PyClass = OfflineWhisperModelConfig;
   py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
-      .def(py::init<const std::string &, const std::string &>(),
-           py::arg("encoder"), py::arg("decoder"))
+      .def(py::init<const std::string &, const std::string &,
+                    const std::string &, const std::string &>(),
+           py::arg("encoder"), py::arg("decoder"), py::arg("language"),
+           py::arg("task"))
       .def_readwrite("encoder", &PyClass::encoder)
      .def_readwrite("decoder", &PyClass::decoder)
+      .def_readwrite("language", &PyClass::language)
+      .def_readwrite("task", &PyClass::task)
       .def("__str__", &PyClass::ToString);
 }
 
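With the extended binding, the config is constructible directly from Python; all four arguments are keyword arguments of the py::init above. A small sketch, assuming the class is re-exported at the package top level like the other configs (paths are placeholders):

import sherpa_onnx

config = sherpa_onnx.OfflineWhisperModelConfig(
    encoder="encoder.onnx",  # placeholder
    decoder="decoder.onnx",  # placeholder
    language="",             # empty: detect the language (multilingual models)
    task="transcribe",
)
print(config)  # __str__ is bound to OfflineWhisperModelConfig::ToString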
@@ -244,6 +244,8 @@ class OfflineRecognizer(object):
         encoder: str,
         decoder: str,
         tokens: str,
+        language: str = "en",
+        task: str = "transcribe",
         num_threads: int = 1,
         decoding_method: str = "greedy_search",
         debug: bool = False,
@@ -268,6 +270,14 @@ class OfflineRecognizer(object):
 
             symbol integer_id
 
+          language:
+            The spoken language in the audio file. Example values: en, de, zh,
+            ja, fr. See https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+            for all possible values. Note that for non-multilingual models, the
+            only valid value is 'en'.
+          task:
+            Valid values are: transcribe, translate. Note that for
+            non-multilingual models, the only valid value is 'transcribe'.
           num_threads:
             Number of threads for neural network computation.
           decoding_method:
@@ -279,7 +289,12 @@ class OfflineRecognizer(object):
         """
         self = cls.__new__(cls)
         model_config = OfflineModelConfig(
-            whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder),
+            whisper=OfflineWhisperModelConfig(
+                encoder=encoder,
+                decoder=decoder,
+                language=language,
+                task=task,
+            ),
             tokens=tokens,
             num_threads=num_threads,
             debug=debug,
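End to end, the two parameters flow from from_whisper into OfflineWhisperModelConfig and from there into the C++ greedy-search decoder. A usage sketch of the finished API, assuming a multilingual export and the soundfile package for reading the wave (paths are placeholders):

import sherpa_onnx
import soundfile as sf

recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
    encoder="./sherpa-onnx-whisper-tiny/tiny-encoder.onnx",  # placeholder
    decoder="./sherpa-onnx-whisper-tiny/tiny-decoder.onnx",  # placeholder
    tokens="./sherpa-onnx-whisper-tiny/tiny-tokens.txt",     # placeholder
    language="",        # empty: detect the language from the audio
    task="translate",   # output English for non-English speech
)

samples, sample_rate = sf.read("test.wav", dtype="float32")
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, samples)
recognizer.decode_stream(stream)
print(stream.result.text)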