Committed by GitHub
Support multilingual whisper models (#274)
Showing 24 changed files with 692 additions and 73 deletions.
@@ -36,6 +36,9 @@ jobs:
           CIBW_ARCHS: "universal2"
           CIBW_BUILD_VERBOSITY: 3

+          # Don't repair macOS wheels
+          CIBW_REPAIR_WHEEL_COMMAND_MACOS: ""
+
       - name: Display wheels
         shell: bash
        run: |
@@ -16,7 +16,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [macos-latest]
-        model: ["tiny.en", "base.en", "small.en", "medium.en"]
+        model: ["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]

     steps:
       - uses: actions/checkout@v2
@@ -3,7 +3,7 @@ module non-streaming-decode-files
 go 1.12

 require (
-	github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
+	github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
 	github.com/spf13/pflag v1.0.5
 	github.com/youpy/go-wav v0.3.2
 )
@@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
 github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -4,6 +4,6 @@ go 1.12

 require (
 	github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
-	github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
+	github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
 	github.com/spf13/pflag v1.0.5
 )
@@ -1,12 +1,12 @@
 github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
 github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
 github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
 github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
@@ -3,7 +3,7 @@ module streaming-decode-files
 go 1.12

 require (
-	github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
+	github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
 	github.com/spf13/pflag v1.0.5
 	github.com/youpy/go-wav v0.3.2
 )
@@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
 github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -11,10 +11,12 @@ fun main() {
     // please refer to
     // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
     // to download pre-trained models
-    var modelConfig = OnlineTransducerModelConfig(
-        encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
-        decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
-        joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
+    var modelConfig = OnlineModelConfig(
+        transducer = OnlineTransducerModelConfig(
+            encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
+            decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
+            joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
+        ),
         tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
         numThreads = 1,
         debug = false,
@@ -41,19 +43,19 @@ fun main() {
     var objArray = WaveReader.readWaveFromFile(
         filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
     )
-    var samples : FloatArray = objArray[0] as FloatArray
-    var sampleRate : Int = objArray[1] as Int
+    var samples: FloatArray = objArray[0] as FloatArray
+    var sampleRate: Int = objArray[1] as Int

-    model.acceptWaveform(samples, sampleRate=sampleRate)
+    model.acceptWaveform(samples, sampleRate = sampleRate)
     while (model.isReady()) {
-      model.decode()
+        model.decode()
     }

     var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
-    model.acceptWaveform(tailPaddings, sampleRate=sampleRate)
+    model.acceptWaveform(tailPaddings, sampleRate = sampleRate)
     model.inputFinished()
     while (model.isReady()) {
-      model.decode()
+        model.decode()
     }

     println("results: ${model.text}")
@@ -234,6 +234,28 @@ def add_whisper_model_args(parser: argparse.ArgumentParser):
         help="Path to whisper decoder model",
     )

+    parser.add_argument(
+        "--whisper-language",
+        default="",
+        type=str,
+        help="""It specifies the spoken language in the input audio file.
+        Example values: en, fr, de, zh, jp.
+        Available languages for multilingual models can be found at
+        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+        If not specified, we infer the language from the input audio file.
+        """,
+    )
+
+    parser.add_argument(
+        "--whisper-task",
+        default="transcribe",
+        choices=["transcribe", "translate"],
+        type=str,
+        help="""For multilingual models, if you specify translate, the output
+        will be in English.
+        """,
+    )
+

 def add_model_args(parser: argparse.ArgumentParser):
     add_transducer_model_args(parser)
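The two new flags are plain argparse options; here is a minimal, self-contained sketch of the behavior they are given above (the parser below is a stand-in, not the one from the script):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--whisper-language", default="", type=str)
    parser.add_argument(
        "--whisper-task",
        default="transcribe",
        choices=["transcribe", "translate"],
        type=str,
    )

    args = parser.parse_args(["--whisper-language=de", "--whisper-task=translate"])
    print(args.whisper_language, args.whisper_task)  # de translate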
@@ -813,6 +835,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             tokens=args.tokens,
             num_threads=args.num_threads,
             decoding_method=args.decoding_method,
+            language=args.whisper_language,
+            task=args.whisper_task,
         )
     elif args.tdnn_model:
         assert_file_exists(args.tdnn_model)
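For context, a hedged sketch of how these two keyword arguments reach the recognizer from user code; it assumes sherpa_onnx.OfflineRecognizer.from_whisper accepts language and task as this PR adds, and the file paths are placeholders:

    import sherpa_onnx

    recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
        encoder="./sherpa-onnx-whisper-tiny/tiny-encoder.onnx",  # placeholder paths
        decoder="./sherpa-onnx-whisper-tiny/tiny-decoder.onnx",
        tokens="./sherpa-onnx-whisper-tiny/tiny-tokens.txt",
        num_threads=1,
        decoding_method="greedy_search",
        language="",        # empty: let the model detect the language
        task="transcribe",  # or "translate" for multilingual models
    )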
@@ -53,6 +53,7 @@ python3 ./python-api-examples/offline-decode-files.py \
   --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
   --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
   --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+  --whisper-task=transcribe \
   --num-threads=1 \
   ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
   ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
@@ -201,6 +202,28 @@ def get_args():
     )

     parser.add_argument(
+        "--whisper-language",
+        default="",
+        type=str,
+        help="""It specifies the spoken language in the input audio file.
+        Example values: en, fr, de, zh, jp.
+        Available languages for multilingual models can be found at
+        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+        If not specified, we infer the language from the input audio file.
+        """,
+    )
+
+    parser.add_argument(
+        "--whisper-task",
+        default="transcribe",
+        choices=["transcribe", "translate"],
+        type=str,
+        help="""For multilingual models, if you specify translate, the output
+        will be in English.
+        """,
+    )
+
+    parser.add_argument(
         "--decoding-method",
         type=str,
         default="greedy_search",
@@ -371,10 +394,10 @@ def main():
             decoder=args.whisper_decoder,
             tokens=args.tokens,
             num_threads=args.num_threads,
-            sample_rate=args.sample_rate,
-            feature_dim=args.feature_dim,
             decoding_method=args.decoding_method,
             debug=args.debug,
+            language=args.whisper_language,
+            task=args.whisper_task,
         )
     elif args.tdnn_model:
         assert_file_exists(args.tdnn_model)
@@ -11,6 +11,7 @@ for making the onnx export script public.
 """

 import argparse
+import os
 from pathlib import Path
 from typing import Any, Dict, Optional

@@ -250,6 +251,7 @@ def main():
     # write tokens

     tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
+
     model.eval()
     print(model.dims)
     audio = torch.rand(16000 * 2)
@@ -306,8 +308,12 @@ def main():
         "n_text_head": model.dims.n_text_head,
         "n_text_layer": model.dims.n_text_layer,
         "sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
-        "all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))),
-        "all_language_codes": ",".join(tokenizer.all_language_codes),
+        "all_language_tokens": ",".join(
+            list(map(str, tokenizer.all_language_tokens))
+        ),  # a list of ids
+        "all_language_codes": ",".join(
+            tokenizer.all_language_codes
+        ),  # e.g., en, de, zh, fr
         "sot": tokenizer.sot,
         "sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
         "eot": tokenizer.eot,
@@ -413,6 +419,9 @@ def main():
         },
     )

+    if 'large' in args.model:
+        # int8 quantization causes errors for large models, so skip it.
+        return
     # Generate int8 quantization models
     # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

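The step skipped for the large models is onnxruntime's dynamic int8 quantization. A sketch of that step per the linked docs; the exact arguments used by the script may differ, and the filenames are illustrative:

    from onnxruntime.quantization import QuantType, quantize_dynamic

    quantize_dynamic(
        model_input="tiny-encoder.onnx",
        model_output="tiny-encoder.int8.onnx",
        weight_type=QuantType.QInt8,
    )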
@@ -39,6 +39,24 @@ def get_args():
     )

     parser.add_argument(
+        "--language",
+        type=str,
+        help="""The actual spoken language in the audio.
+        Example values: en, de, zh, jp, fr.
+        If None, we will detect the language using the first 30s of the
+        input audio.
+        """,
+    )
+
+    parser.add_argument(
+        "--task",
+        choices=["transcribe", "translate"],
+        type=str,
+        default="transcribe",
+        help="Valid values are: transcribe, translate",
+    )
+
+    parser.add_argument(
         "sound_file",
         type=str,
         help="Path to the test wave",
@@ -74,12 +92,22 @@ class OnnxModel:
         self.sot = int(meta["sot"])
         self.eot = int(meta["eot"])
         self.translate = int(meta["translate"])
+        self.transcribe = int(meta["transcribe"])
         self.no_timestamps = int(meta["no_timestamps"])
         self.no_speech = int(meta["no_speech"])
         self.blank = int(meta["blank_id"])

         self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))

+        self.sot_sequence.append(self.no_timestamps)
+
+        self.all_language_tokens = list(
+            map(int, meta["all_language_tokens"].split(","))
+        )
+        self.all_language_codes = meta["all_language_codes"].split(",")
+        self.lang2id = dict(zip(self.all_language_codes, self.all_language_tokens))
+        self.id2lang = dict(zip(self.all_language_tokens, self.all_language_codes))
+
         self.is_multilingual = int(meta["is_multilingual"]) == 1

     def init_decoder(self, decoder: str):
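After this change, sot_sequence holds the complete decoding prompt. A toy view of the resulting layout, using the token ids the multilingual Whisper tokenizer usually assigns (examples, not guarantees):

    sot_sequence = [50258, 50259, 50359]  # <|startoftranscript|>, <|en|>, <|transcribe|>
    no_timestamps = 50363                 # <|notimestamps|>
    sot_sequence.append(no_timestamps)
    print(sot_sequence)                   # [sot, language, task, no_timestamps]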
@@ -164,6 +192,29 @@ class OnnxModel:
         # logits is changed in-place
         logits[self.translate] = float("-inf")

+    def detect_language(
+        self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor
+    ) -> int:
+        tokens = torch.tensor([[self.sot]], dtype=torch.int64)
+        offset = torch.zeros(1, dtype=torch.int64)
+        n_layer_self_k_cache, n_layer_self_v_cache = self.get_self_cache()
+
+        logits, n_layer_self_k_cache, n_layer_self_v_cache = self.run_decoder(
+            tokens=tokens,
+            n_layer_self_k_cache=n_layer_self_k_cache,
+            n_layer_self_v_cache=n_layer_self_v_cache,
+            n_layer_cross_k=n_layer_cross_k,
+            n_layer_cross_v=n_layer_cross_v,
+            offset=offset,
+        )
+        logits = logits.reshape(-1)
+        # use a boolean mask so that indexing suppresses every position
+        # that is not a language token
+        mask = torch.ones(logits.shape[0], dtype=torch.bool)
+        mask[self.all_language_tokens] = False
+        logits[mask] = float("-inf")
+        lang_id = logits.argmax().item()
+        print("detected language:", self.id2lang[lang_id])
+        return lang_id
+

 def load_tokens(filename):
     tokens = dict()
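detect_language runs the decoder for one step on [sot] and keeps only the logits of the language tokens. A standalone sketch of the masking trick with made-up numbers:

    import torch

    logits = torch.randn(51865)                  # one decoder step, vocab-sized
    all_language_tokens = [50259, 50260, 50261]  # e.g., <|en|>, <|zh|>, <|de|>
    mask = torch.ones(logits.shape[0], dtype=torch.bool)
    mask[all_language_tokens] = False
    logits[mask] = float("-inf")                 # suppress every non-language token
    lang_id = logits.argmax().item()             # always one of the language tokens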
@@ -200,7 +251,35 @@ def main():
     mel = mel.t().unsqueeze(0)

     model = OnnxModel(encoder, decoder)
+
     n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
+
+    if args.language is not None:
+        if model.is_multilingual is False and args.language != "en":
+            print(f"This model supports only English. Given: {args.language}")
+            return
+
+        if args.language not in model.lang2id:
+            print(f"Invalid language: {args.language}")
+            print(f"Valid values are: {list(model.lang2id.keys())}")
+            return
+
+        # [sot, lang, task, notimestamps]
+        model.sot_sequence[1] = model.lang2id[args.language]
+    elif model.is_multilingual is True:
+        print("detecting language")
+        lang = model.detect_language(n_layer_cross_k, n_layer_cross_v)
+        model.sot_sequence[1] = lang
+
+    if args.task is not None:
+        if model.is_multilingual is False and args.task != "transcribe":
+            print("This model supports only English. Please use --task=transcribe")
+            return
+        assert args.task in ["transcribe", "translate"], args.task
+
+        if args.task == "translate":
+            model.sot_sequence[2] = model.translate
+
     n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()

     tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
@@ -213,6 +292,7 @@ def main():
         n_layer_cross_v=n_layer_cross_v,
         offset=offset,
     )
+    offset += len(model.sot_sequence)
     # logits.shape (batch_size, tokens.shape[1], vocab_size)
     logits = logits[0, -1]
     model.suppress_tokens(logits, is_initial=True)
@@ -225,7 +305,6 @@ def main():
             break
         results.append(max_token_id.item())
         tokens = torch.tensor([[results[-1]]])
-        offset += 1

         logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
             tokens=tokens,
@@ -235,6 +314,7 @@ def main():
             n_layer_cross_v=n_layer_cross_v,
             offset=offset,
         )
+        offset += 1
         logits = logits[0, -1]
         model.suppress_tokens(logits, is_initial=False)
         max_token_id = logits.argmax(dim=-1)
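These three hunks keep offset, the decoder's position counter, in sync with the self-attention cache: it must jump by the prompt length after the first call and then advance by one per decoded token. A runnable toy of the invariant (all names are illustrative):

    sot_sequence = ["sot", "lang", "task", "no_timestamps"]  # 4 prompt tokens
    cache = []  # stands in for the self-attention KV cache

    def run_decoder(tokens, offset):
        assert offset == len(cache), "position ids would be misaligned"
        cache.extend(tokens)

    offset = 0
    run_decoder(sot_sequence, offset)  # prefill: positions 0..3
    offset += len(sot_sequence)
    for step in range(3):              # pretend we decode three tokens
        run_decoder([f"tok{step}"], offset)
        offset += 1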
@@ -37,7 +37,7 @@
     }                                                            \
                                                                  \
     dst = atoi(value.get());                                     \
-    if (dst <= 0) {                                              \
+    if (dst < 0) {                                               \
      SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key);  \
      exit(-1);                                                   \
    }                                                             \
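The relaxed check appears to be needed because some integer metadata fields read with this macro can legitimately be 0 (for example is_multilingual, which is 0 for the English-only models); only negative values stay invalid:

    # the old check rejected a valid zero, the new one accepts it
    print(int("0") <= 0)  # True  -> rejected before this change
    print(int("0") < 0)   # False -> accepted now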
@@ -77,6 +77,24 @@
     }                                                                    \
   } while (0)

+// Read a vector of strings
+#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key)              \
+  do {                                                                   \
+    auto value =                                                         \
+        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator);  \
+    if (!value) {                                                        \
+      SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key);    \
+      exit(-1);                                                          \
+    }                                                                    \
+    SplitStringToVector(value.get(), ",", false, &dst);                  \
+                                                                         \
+    if (dst.empty()) {                                                   \
+      SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!",         \
+                       value.get(), src_key);                            \
+      exit(-1);                                                          \
+    }                                                                    \
+  } while (0)
+
 // Read a string
 #define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key)                     \
   do {                                                                   \
@@ -23,21 +23,227 @@

 namespace sherpa_onnx {

+// Copy s to the returned string, dropping bytes that do not form a
+// valid UTF-8 byte sequence. Sequences of 2 to 6 bytes are accepted
+// when every continuation byte is in [0x80, 0xc0).
+static std::string FixInvalidUtf8(const std::string &s) {
+  int32_t s_size = s.size();
+
+  std::string ans;
+  ans.reserve(s_size);
+
+  for (int32_t i = 0; i < s_size;) {
+    uint8_t c = s[i];
+
+    // number of bytes in the sequence starting at s[i], deduced from
+    // the leading byte
+    int32_t num_bytes = 0;
+    if (c < 0x80) {
+      num_bytes = 1;
+    } else if (c >= 0xc0 && c < 0xe0) {
+      num_bytes = 2;
+    } else if (c >= 0xe0 && c < 0xf0) {
+      num_bytes = 3;
+    } else if (c >= 0xf0 && c < 0xf8) {
+      num_bytes = 4;
+    } else if (c >= 0xf8 && c < 0xfc) {
+      num_bytes = 5;
+    } else if (c >= 0xfc && c < 0xfe) {
+      num_bytes = 6;
+    } else {
+      // invalid leading byte
+      i += 1;
+      continue;
+    }
+
+    if (i + num_bytes > s_size) {
+      // not enough subsequent bytes. invalid!
+      i += 1;
+      continue;
+    }
+
+    bool valid = true;
+    for (int32_t k = 1; k != num_bytes; ++k) {
+      uint8_t next = s[i + k];
+      if (!(next >= 0x80 && next < 0xc0)) {
+        // not a continuation byte. invalid!
+        valid = false;
+        break;
+      }
+    }
+
+    if (!valid) {
+      i += 1;
+      continue;
+    }
+
+    // valid utf-8 sequence; copy it as a whole
+    ans.append(s, i, num_bytes);
+    i += num_bytes;
+  }
+
+  return ans;
+}
+
 static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
                                         const SymbolTable &sym_table) {
   OfflineRecognitionResult r;
   r.tokens.reserve(src.tokens.size());

+  std::string text;
   for (auto i : src.tokens) {
     if (!sym_table.contains(i)) {
       continue;
     }

     const auto &s = sym_table[i];
-    r.text += s;
+    text += s;
     r.tokens.push_back(s);
   }

+  // TODO(fangjun): Fix the following error in offline-stream.cc
+  //
+  // j["text"] = text;
+  //
+  // libc++abi: terminating with uncaught exception of type
+  // nlohmann::json_abi_v3_11_2::detail::type_error:
+  // [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86
+
+#if 0
+  r.text = FixInvalidUtf8(text);
+#else
+  r.text = text;
+#endif
+
   return r;
 }

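The invalid UTF-8 comes from Whisper's byte-level BPE vocabulary: a multi-byte character can be split across two tokens, and each half on its own is not valid UTF-8. A small Python illustration:

    s = "你好".encode("utf-8")        # six bytes, two 3-byte characters
    token_a, token_b = s[:4], s[4:]  # a hypothetical token boundary
    try:
        token_a.decode("utf-8")
    except UnicodeDecodeError as e:
        print(e)  # 'utf-8' codec can't decode byte 0xe5 ...: unexpected end of data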
@@ -51,8 +257,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
     symbol_table_.ApplyBase64Decode();

     if (config.decoding_method == "greedy_search") {
-      decoder_ =
-          std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
+      decoder_ = std::make_unique<OfflineWhisperGreedySearchDecoder>(
+          config_.model_config.whisper, model_.get());
     } else {
       SHERPA_ONNX_LOGE(
           "Only greedy_search is supported at present for whisper. Given %s",
@@ -101,6 +307,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
     mel = Transpose12(model_->Allocator(), &mel);

     auto cross_kv = model_->ForwardEncoder(std::move(mel));
+
     auto results =
         decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));

@@ -7,17 +7,106 @@
 #include <algorithm>
 #include <utility>

+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
 namespace sherpa_onnx {

+int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
+    Ort::Value &cross_k, Ort::Value &cross_v) const {  // NOLINT
+  int64_t token_val = model_->SOT();
+  std::array<int64_t, 2> token_shape{1, 1};
+
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  Ort::Value tokens = Ort::Value::CreateTensor(
+      memory_info, &token_val, 1, token_shape.data(), token_shape.size());
+
+  auto self_kv_cache = model_->GetInitialSelfKVCache();
+
+  std::array<int64_t, 1> offset_shape{1};
+  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
+      model_->Allocator(), offset_shape.data(), offset_shape.size());
+  *(offset.GetTensorMutableData<int64_t>()) = 0;
+
+  auto decoder_out = model_->ForwardDecoder(
+      std::move(tokens), std::move(self_kv_cache.first),
+      std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
+      std::move(offset));
+
+  cross_k = std::move(std::get<3>(decoder_out));
+  cross_v = std::move(std::get<4>(decoder_out));
+
+  const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
+  int32_t vocab_size = model_->VocabSize();
+  const auto &all_language_ids = model_->GetAllLanguageIDs();
+
+  int32_t lang_id = all_language_ids[0];
+  float this_logit = p_logits[lang_id];
+
+  for (int32_t i = 1; i != all_language_ids.size(); ++i) {
+    int32_t id = all_language_ids[i];
+    float p = p_logits[id];
+
+    if (p > this_logit) {
+      this_logit = p;
+      lang_id = id;
+    }
+  }
+#if 1
+  SHERPA_ONNX_LOGE("Detected language: %s",
+                   model_->GetID2Lang().at(lang_id).c_str());
+#endif
+
+  return lang_id;
+}
+
 std::vector<OfflineWhisperDecoderResult>
 OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
                                           Ort::Value cross_v) {
   auto memory_info =
       Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

-  auto self_kv_cache = model_->GetInitialSelfKVCache();
-
+  // For multilingual models, initial_tokens contains [sot, language, task]
+  //  - language is English by default
+  //  - task is transcribe by default
+  //
+  // For non-multilingual models, initial_tokens contains [sot]
   std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
+
+  if (model_->IsMultiLingual()) {
+    if (!config_.language.empty()) {
+      const auto &lang2id = model_->GetLang2ID();
+
+      if (!lang2id.count(config_.language)) {
+        SHERPA_ONNX_LOGE("Invalid language: %s", config_.language.c_str());
+        exit(-1);
+      }
+
+      int32_t lang_id = lang2id.at(config_.language);
+
+      // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
+      initial_tokens[1] = lang_id;
+    } else {
+      int32_t lang_id = DetectLanguage(cross_k, cross_v);
+
+      // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
+      initial_tokens[1] = lang_id;
+    }
+
+    if (config_.task == "translate") {
+      initial_tokens[2] = model_->Translate();
+    } else if (config_.task != "transcribe") {
+      // initial_tokens[2] is transcribe by default
+      SHERPA_ONNX_LOGE(
+          "Unsupported task: %s. Valid values are: transcribe, translate.",
+          config_.task.c_str());
+    }
+  }
+
+  initial_tokens.push_back(model_->NoTimeStampsToken());
+
   int32_t batch_size = 1;
   std::array<int64_t, 2> token_shape{
       batch_size, static_cast<int64_t>(initial_tokens.size())};
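A toy view of the prompt editing performed above for multilingual models, again with the usual multilingual tokenizer ids (illustrative, not authoritative):

    initial_tokens = [50258, 50259, 50359]  # [sot, <|en|>, <|transcribe|>]
    lang2id = {"en": 50259, "zh": 50260, "de": 50261}

    language, task = "de", "translate"
    initial_tokens[1] = lang2id[language]  # index 1 holds the language token
    if task == "translate":
        initial_tokens[2] = 50358          # <|translate|>
    initial_tokens.append(50363)           # <|notimestamps|> appended last
    print(initial_tokens)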
@@ -31,11 +120,16 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
       model_->Allocator(), offset_shape.data(), offset_shape.size());
   *(offset.GetTensorMutableData<int64_t>()) = 0;

+  auto self_kv_cache = model_->GetInitialSelfKVCache();
+
   auto decoder_out = model_->ForwardDecoder(
       std::move(tokens), std::move(self_kv_cache.first),
       std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
       std::move(offset));

+  *(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) =
+      initial_tokens.size();
+
   const auto &logits = std::get<0>(decoder_out);
   const float *p_logits = logits.GetTensorData<float>();

@@ -58,18 +152,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
     std::array<int64_t, 2> token_shape{1, 1};
     Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
         model_->Allocator(), token_shape.data(), token_shape.size());
+
     int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
     p_tokens[0] = max_token_id;

-    int64_t *p_offset =
-        std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
-
-    if (i == 0) {
-      *p_offset = initial_tokens.size();
-    } else {
-      *p_offset += 1;
-    }
-
     decoder_out = model_->ForwardDecoder(std::move(tokens),
                                          std::move(std::get<1>(decoder_out)),
                                          std::move(std::get<2>(decoder_out)),
@@ -77,6 +163,11 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
                                          std::move(std::get<4>(decoder_out)),
                                          std::move(std::get<5>(decoder_out)));

+    int64_t *p_offset =
+        std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
+
+    *p_offset += 1;
+
     const auto &logits = std::get<0>(decoder_out);
     const float *p_logits = logits.GetTensorData<float>();

@@ -85,6 +176,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
   }

   std::vector<OfflineWhisperDecoderResult> ans(1);
+
   ans[0].tokens = std::move(predicted_tokens);

   return ans;
@@ -8,19 +8,25 @@
 #include <vector>

 #include "sherpa-onnx/csrc/offline-whisper-decoder.h"
+#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
 #include "sherpa-onnx/csrc/offline-whisper-model.h"

 namespace sherpa_onnx {

 class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
  public:
-  explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
-      : model_(model) {}
+  OfflineWhisperGreedySearchDecoder(const OfflineWhisperModelConfig &config,
+                                    OfflineWhisperModel *model)
+      : config_(config), model_(model) {}

   std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
                                                   Ort::Value cross_v) override;

+  int32_t DetectLanguage(Ort::Value &cross_k,        // NOLINT
+                         Ort::Value &cross_v) const;  // NOLINT
+
  private:
+  OfflineWhisperModelConfig config_;
   OfflineWhisperModel *model_;  // not owned
 };

@@ -17,6 +17,21 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
   po->Register("whisper-decoder", &decoder,
                "Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
                "medium.en-decoder.onnx.");
+
+  po->Register(
+      "whisper-language", &language,
+      "The spoken language in the input audio file. Example values: "
+      "en, de, fr, zh, jp. If it is not given for a multilingual model, we "
+      "will infer the language from the input audio file. "
+      "Please refer to "
+      "https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10"
+      " for valid values. Note that for non-multilingual models, it supports "
+      "only 'en'");
+
+  po->Register("whisper-task", &task,
+               "Valid values: transcribe, translate. "
+               "Note that for non-multilingual models, it supports "
+               "only 'transcribe'");
 }

 bool OfflineWhisperModelConfig::Validate() const {
@@ -30,6 +45,14 @@ bool OfflineWhisperModelConfig::Validate() const {
     return false;
   }

+  if (task != "translate" && task != "transcribe") {
+    SHERPA_ONNX_LOGE(
+        "--whisper-task supports only translate and transcribe. Given: %s",
+        task.c_str());
+
+    return false;
+  }
+
   return true;
 }

@@ -38,7 +61,9 @@ std::string OfflineWhisperModelConfig::ToString() const {

   os << "OfflineWhisperModelConfig(";
   os << "encoder=\"" << encoder << "\", ";
-  os << "decoder=\"" << decoder << "\")";
+  os << "decoder=\"" << decoder << "\", ";
+  os << "language=\"" << language << "\", ";
+  os << "task=\"" << task << "\")";

   return os.str();
 }
@@ -14,10 +14,26 @@ struct OfflineWhisperModelConfig {
   std::string encoder;
   std::string decoder;

+  // Available languages can be found at
+  // https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+  //
+  // Note: For non-multilingual models, it supports only "en"
+  //
+  // If empty, we will infer it from the input audio file when
+  // the model is multilingual.
+  std::string language;
+
+  // Valid values are transcribe and translate
+  //
+  // Note: For non-multilingual models, it supports only "transcribe"
+  std::string task = "transcribe";
+
   OfflineWhisperModelConfig() = default;
   OfflineWhisperModelConfig(const std::string &encoder,
-                            const std::string &decoder)
-      : encoder(encoder), decoder(decoder) {}
+                            const std::string &decoder,
+                            const std::string &language,
+                            const std::string &task)
+      : encoder(encoder), decoder(decoder), language(language), task(task) {}

   void Register(ParseOptions *po);
   bool Validate() const;
| @@ -7,6 +7,7 @@ | @@ -7,6 +7,7 @@ | ||
| 7 | #include <algorithm> | 7 | #include <algorithm> |
| 8 | #include <string> | 8 | #include <string> |
| 9 | #include <tuple> | 9 | #include <tuple> |
| 10 | +#include <unordered_map> | ||
| 10 | #include <utility> | 11 | #include <utility> |
| 11 | 12 | ||
| 12 | #include "sherpa-onnx/csrc/macros.h" | 13 | #include "sherpa-onnx/csrc/macros.h" |
| @@ -88,10 +89,32 @@ class OfflineWhisperModel::Impl { | @@ -88,10 +89,32 @@ class OfflineWhisperModel::Impl { | ||
| 88 | 89 | ||
| 89 | const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; } | 90 | const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; } |
| 90 | 91 | ||
| 92 | + const std::vector<int32_t> &GetAllLanguageIDs() const { | ||
| 93 | + return all_language_tokens_; | ||
| 94 | + } | ||
| 95 | + | ||
| 96 | + const std::unordered_map<std::string, int32_t> &GetLang2ID() const { | ||
| 97 | + return lang2id_; | ||
| 98 | + } | ||
| 99 | + | ||
| 100 | + const std::unordered_map<int32_t, std::string> &GetID2Lang() const { | ||
| 101 | + return id2lang_; | ||
| 102 | + } | ||
| 103 | + | ||
| 104 | + int32_t NoTimeStampsToken() const { return no_timestamps_; } | ||
| 105 | + | ||
| 91 | int32_t EOT() const { return eot_; } | 106 | int32_t EOT() const { return eot_; } |
| 92 | 107 | ||
| 108 | + int32_t SOT() const { return sot_; } | ||
| 109 | + | ||
| 93 | int32_t TextCtx() const { return n_text_ctx_; } | 110 | int32_t TextCtx() const { return n_text_ctx_; } |
| 94 | 111 | ||
| 112 | + int32_t VocabSize() const { return n_vocab_; } | ||
| 113 | + | ||
| 114 | + int32_t Translate() const { return translate_; } | ||
| 115 | + | ||
| 116 | + bool IsMultiLingual() const { return is_multilingual_; } | ||
| 117 | + | ||
| 95 | private: | 118 | private: |
| 96 | void InitEncoder(void *model_data, size_t model_data_length) { | 119 | void InitEncoder(void *model_data, size_t model_data_length) { |
| 97 | encoder_sess_ = std::make_unique<Ort::Session>( | 120 | encoder_sess_ = std::make_unique<Ort::Session>( |
| @@ -116,13 +139,35 @@ class OfflineWhisperModel::Impl { | @@ -116,13 +139,35 @@ class OfflineWhisperModel::Impl { | ||
| 116 | SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer"); | 139 | SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer"); |
| 117 | SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx"); | 140 | SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx"); |
| 118 | SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state"); | 141 | SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state"); |
| 142 | + SHERPA_ONNX_READ_META_DATA(n_vocab_, "n_vocab"); | ||
| 119 | SHERPA_ONNX_READ_META_DATA(sot_, "sot"); | 143 | SHERPA_ONNX_READ_META_DATA(sot_, "sot"); |
| 120 | SHERPA_ONNX_READ_META_DATA(eot_, "eot"); | 144 | SHERPA_ONNX_READ_META_DATA(eot_, "eot"); |
| 121 | SHERPA_ONNX_READ_META_DATA(blank_, "blank_id"); | 145 | SHERPA_ONNX_READ_META_DATA(blank_, "blank_id"); |
| 122 | SHERPA_ONNX_READ_META_DATA(translate_, "translate"); | 146 | SHERPA_ONNX_READ_META_DATA(translate_, "translate"); |
| 147 | + SHERPA_ONNX_READ_META_DATA(transcribe_, "transcribe"); | ||
| 148 | + SHERPA_ONNX_READ_META_DATA(is_multilingual_, "is_multilingual"); | ||
| 123 | SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps"); | 149 | SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps"); |
| 124 | SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech"); | 150 | SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech"); |
| 125 | SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence"); | 151 | SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence"); |
| 152 | + | ||
| 153 | + if (is_multilingual_) { | ||
| 154 | + SHERPA_ONNX_READ_META_DATA_VEC(all_language_tokens_, | ||
| 155 | + "all_language_tokens"); | ||
| 156 | + SHERPA_ONNX_READ_META_DATA_VEC_STRING(all_language_codes_, | ||
| 157 | + "all_language_codes"); | ||
| 158 | + if (all_language_tokens_.size() != all_language_codes_.size()) { | ||
| 159 | + SHERPA_ONNX_LOGE("# lang_id: %d != # lang_code: %d", | ||
| 160 | + static_cast<int32_t>(all_language_tokens_.size()), | ||
| 161 | + static_cast<int32_t>(all_language_codes_.size())); | ||
| 162 | + exit(-1); | ||
| 163 | + } | ||
| 164 | + | ||
| 165 | + for (int32_t i = 0; | ||
| 166 | + i != static_cast<int32_t>(all_language_tokens_.size()); ++i) { | ||
| 167 | + lang2id_[all_language_codes_[i]] = all_language_tokens_[i]; | ||
| 168 | + id2lang_[all_language_tokens_[i]] = all_language_codes_[i]; | ||
| 169 | + } | ||
| 170 | + } | ||
| 126 | } | 171 | } |
| 127 | 172 | ||
| 128 | void InitDecoder(void *model_data, size_t model_data_length) { | 173 | void InitDecoder(void *model_data, size_t model_data_length) { |
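The metadata block above reads two parallel arrays, `all_language_tokens` and `all_language_codes`, verifies they have the same length, and zips them into the `lang2id_` / `id2lang_` lookup tables. A hedged Python sketch of the same bookkeeping; the three sample entries are illustrative, not the full Whisper inventory, though the IDs match the usual multilingual Whisper tokenizer layout (sot = 50258):

```python
# In the real code these lists come from the ONNX model's metadata
# ("all_language_tokens" / "all_language_codes"); shown here inline.
all_language_tokens = [50259, 50260, 50261]
all_language_codes = ["en", "zh", "de"]

# Same sanity check as the C++ above: the two arrays must be parallel.
assert len(all_language_tokens) == len(all_language_codes)

lang2id = dict(zip(all_language_codes, all_language_tokens))
id2lang = dict(zip(all_language_tokens, all_language_codes))

print(lang2id["zh"])   # 50260
print(id2lang[50261])  # de
```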
| @@ -157,16 +202,24 @@ class OfflineWhisperModel::Impl { | @@ -157,16 +202,24 @@ class OfflineWhisperModel::Impl { | ||
| 157 | std::vector<std::string> decoder_output_names_; | 202 | std::vector<std::string> decoder_output_names_; |
| 158 | std::vector<const char *> decoder_output_names_ptr_; | 203 | std::vector<const char *> decoder_output_names_ptr_; |
| 159 | 204 | ||
| 205 | + std::vector<int32_t> all_language_tokens_; | ||
| 206 | + std::vector<std::string> all_language_codes_; | ||
| 207 | + std::unordered_map<std::string, int32_t> lang2id_; | ||
| 208 | + std::unordered_map<int32_t, std::string> id2lang_; | ||
| 209 | + | ||
| 160 | // model meta data | 210 | // model meta data |
| 161 | int32_t n_text_layer_; | 211 | int32_t n_text_layer_; |
| 162 | int32_t n_text_ctx_; | 212 | int32_t n_text_ctx_; |
| 163 | int32_t n_text_state_; | 213 | int32_t n_text_state_; |
| 214 | + int32_t n_vocab_; | ||
| 164 | int32_t sot_; | 215 | int32_t sot_; |
| 165 | int32_t eot_; | 216 | int32_t eot_; |
| 166 | int32_t blank_; | 217 | int32_t blank_; |
| 167 | int32_t translate_; | 218 | int32_t translate_; |
| 219 | + int32_t transcribe_; | ||
| 168 | int32_t no_timestamps_; | 220 | int32_t no_timestamps_; |
| 169 | int32_t no_speech_; | 221 | int32_t no_speech_; |
| 222 | + int32_t is_multilingual_; | ||
| 170 | std::vector<int64_t> sot_sequence_; | 223 | std::vector<int64_t> sot_sequence_; |
| 171 | }; | 224 | }; |
| 172 | 225 | ||
| @@ -176,7 +229,7 @@ OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config) | @@ -176,7 +229,7 @@ OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config) | ||
| 176 | OfflineWhisperModel::~OfflineWhisperModel() = default; | 229 | OfflineWhisperModel::~OfflineWhisperModel() = default; |
| 177 | 230 | ||
| 178 | std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder( | 231 | std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder( |
| 179 | - Ort::Value features) { | 232 | + Ort::Value features) const { |
| 180 | return impl_->ForwardEncoder(std::move(features)); | 233 | return impl_->ForwardEncoder(std::move(features)); |
| 181 | } | 234 | } |
| 182 | 235 | ||
| @@ -187,14 +240,15 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens, | @@ -187,14 +240,15 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens, | ||
| 187 | Ort::Value n_layer_self_v_cache, | 240 | Ort::Value n_layer_self_v_cache, |
| 188 | Ort::Value n_layer_cross_k, | 241 | Ort::Value n_layer_cross_k, |
| 189 | Ort::Value n_layer_cross_v, | 242 | Ort::Value n_layer_cross_v, |
| 190 | - Ort::Value offset) { | 243 | + Ort::Value offset) const { |
| 191 | return impl_->ForwardDecoder( | 244 | return impl_->ForwardDecoder( |
| 192 | std::move(tokens), std::move(n_layer_self_k_cache), | 245 | std::move(tokens), std::move(n_layer_self_k_cache), |
| 193 | std::move(n_layer_self_v_cache), std::move(n_layer_cross_k), | 246 | std::move(n_layer_self_v_cache), std::move(n_layer_cross_k), |
| 194 | std::move(n_layer_cross_v), std::move(offset)); | 247 | std::move(n_layer_cross_v), std::move(offset)); |
| 195 | } | 248 | } |
| 196 | 249 | ||
| 197 | -std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() { | 250 | +std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() |
| 251 | + const { | ||
| 198 | return impl_->GetInitialSelfKVCache(); | 252 | return impl_->GetInitialSelfKVCache(); |
| 199 | } | 253 | } |
| 200 | 254 | ||
| @@ -206,8 +260,36 @@ const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const { | @@ -206,8 +260,36 @@ const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const { | ||
| 206 | return impl_->GetInitialTokens(); | 260 | return impl_->GetInitialTokens(); |
| 207 | } | 261 | } |
| 208 | 262 | ||
| 263 | +const std::vector<int32_t> &OfflineWhisperModel::GetAllLanguageIDs() const { | ||
| 264 | + return impl_->GetAllLanguageIDs(); | ||
| 265 | +} | ||
| 266 | + | ||
| 267 | +const std::unordered_map<std::string, int32_t> | ||
| 268 | + &OfflineWhisperModel::GetLang2ID() const { | ||
| 269 | + return impl_->GetLang2ID(); | ||
| 270 | +} | ||
| 271 | + | ||
| 272 | +const std::unordered_map<int32_t, std::string> | ||
| 273 | + &OfflineWhisperModel::GetID2Lang() const { | ||
| 274 | + return impl_->GetID2Lang(); | ||
| 275 | +} | ||
| 276 | + | ||
| 277 | +int32_t OfflineWhisperModel::NoTimeStampsToken() const { | ||
| 278 | + return impl_->NoTimeStampsToken(); | ||
| 279 | +} | ||
| 280 | + | ||
| 209 | int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); } | 281 | int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); } |
| 210 | 282 | ||
| 283 | +int32_t OfflineWhisperModel::SOT() const { return impl_->SOT(); } | ||
| 284 | + | ||
| 211 | int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); } | 285 | int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); } |
| 212 | 286 | ||
| 287 | +int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); } | ||
| 288 | + | ||
| 289 | +int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); } | ||
| 290 | + | ||
| 291 | +bool OfflineWhisperModel::IsMultiLingual() const { | ||
| 292 | + return impl_->IsMultiLingual(); | ||
| 293 | +} | ||
| 294 | + | ||
| 213 | } // namespace sherpa_onnx | 295 | } // namespace sherpa_onnx |
| @@ -5,7 +5,9 @@ | @@ -5,7 +5,9 @@ | ||
| 5 | #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_ | 5 | #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_ |
| 6 | 6 | ||
| 7 | #include <memory> | 7 | #include <memory> |
| 8 | +#include <string> | ||
| 8 | #include <tuple> | 9 | #include <tuple> |
| 10 | +#include <unordered_map> | ||
| 9 | #include <utility> | 11 | #include <utility> |
| 10 | #include <vector> | 12 | #include <vector> |
| 11 | 13 | ||
| @@ -30,7 +32,7 @@ class OfflineWhisperModel { | @@ -30,7 +32,7 @@ class OfflineWhisperModel { | ||
| 30 | * - n_layer_cross_v: A 4-D tensor of shape | 32 | * - n_layer_cross_v: A 4-D tensor of shape |
| 31 | * (n_text_layer, N, n_audio_ctx, n_text_state) | 33 | * (n_text_layer, N, n_audio_ctx, n_text_state) |
| 32 | */ | 34 | */ |
| 33 | - std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features); | 35 | + std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) const; |
| 34 | 36 | ||
| 35 | /** Run the decoder model. | 37 | /** Run the decoder model. |
| 36 | * | 38 | * |
| @@ -58,7 +60,9 @@ class OfflineWhisperModel { | @@ -58,7 +60,9 @@ class OfflineWhisperModel { | ||
| 58 | Ort::Value> | 60 | Ort::Value> |
| 59 | ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache, | 61 | ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache, |
| 60 | Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, | 62 | Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, |
| 61 | - Ort::Value n_layer_cross_v, Ort::Value offset); | 63 | + Ort::Value n_layer_cross_v, Ort::Value offset) const; |
| 64 | + | ||
| 65 | + int32_t DetectLanguage() const; | ||
| 62 | 66 | ||
| 63 | /** Return the initial self kv cache in a pair | 67 | /** Return the initial self kv cache in a pair |
| 64 | * - n_layer_self_k_cache A 4-D tensor of shape | 68 | * - n_layer_self_k_cache A 4-D tensor of shape |
| @@ -66,14 +70,23 @@ class OfflineWhisperModel { | @@ -66,14 +70,23 @@ class OfflineWhisperModel { | ||
| 66 | * - n_layer_self_v_cache A 4-D tensor of shape | 70 | * - n_layer_self_v_cache A 4-D tensor of shape |
| 67 | * (n_text_layer, N, n_audio_ctx, n_text_state). | 71 | * (n_text_layer, N, n_audio_ctx, n_text_state). |
| 68 | */ | 72 | */ |
| 69 | - std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache(); | 73 | + std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const; |
| 70 | const std::vector<int64_t> &GetInitialTokens() const; | 74 | const std::vector<int64_t> &GetInitialTokens() const; |
| 75 | + const std::vector<int32_t> &GetAllLanguageIDs() const; | ||
| 76 | + const std::unordered_map<std::string, int32_t> &GetLang2ID() const; | ||
| 77 | + const std::unordered_map<int32_t, std::string> &GetID2Lang() const; | ||
| 71 | 78 | ||
| 72 | /** Return an allocator for allocating memory | 79 | /** Return an allocator for allocating memory |
| 73 | */ | 80 | */ |
| 74 | OrtAllocator *Allocator() const; | 81 | OrtAllocator *Allocator() const; |
| 82 | + | ||
| 83 | + int32_t NoTimeStampsToken() const; | ||
| 75 | int32_t EOT() const; | 84 | int32_t EOT() const; |
| 85 | + int32_t SOT() const; | ||
| 76 | int32_t TextCtx() const; | 86 | int32_t TextCtx() const; |
| 87 | + int32_t VocabSize() const; | ||
| 88 | + int32_t Translate() const; | ||
| 89 | + bool IsMultiLingual() const; | ||
| 77 | 90 | ||
| 78 | private: | 91 | private: |
| 79 | class Impl; | 92 | class Impl; |
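The header also declares a new `DetectLanguage()` method whose body is not part of the hunks shown here. For Whisper models, language detection conventionally means running the decoder one step from the SOT token and taking the argmax over the language-token logits only — exactly what the new `SOT()` and `GetAllLanguageIDs()` accessors make possible. A sketch of that step, under the assumption that the C++ implementation follows openai/whisper's `detect_language`:

```python
import numpy as np

def detect_language(logits: np.ndarray, all_language_tokens: list) -> int:
    """Pick the most likely language token.

    logits: float decoder output of shape (vocab_size,) after feeding [SOT()].
    Returns a language token ID, which id2lang maps back to a code.
    """
    mask = np.full_like(logits, -np.inf)  # suppress every non-language token
    mask[all_language_tokens] = 0.0       # keep only the language-token logits
    return int(np.argmax(logits + mask))
```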
| @@ -14,10 +14,14 @@ namespace sherpa_onnx { | @@ -14,10 +14,14 @@ namespace sherpa_onnx { | ||
| 14 | void PybindOfflineWhisperModelConfig(py::module *m) { | 14 | void PybindOfflineWhisperModelConfig(py::module *m) { |
| 15 | using PyClass = OfflineWhisperModelConfig; | 15 | using PyClass = OfflineWhisperModelConfig; |
| 16 | py::class_<PyClass>(*m, "OfflineWhisperModelConfig") | 16 | py::class_<PyClass>(*m, "OfflineWhisperModelConfig") |
| 17 | - .def(py::init<const std::string &, const std::string &>(), | ||
| 18 | - py::arg("encoder"), py::arg("decoder")) | 17 | + .def(py::init<const std::string &, const std::string &, |
| 18 | + const std::string &, const std::string &>(), | ||
| 19 | + py::arg("encoder"), py::arg("decoder"), py::arg("language"), | ||
| 20 | + py::arg("task")) | ||
| 19 | .def_readwrite("encoder", &PyClass::encoder) | 21 | .def_readwrite("encoder", &PyClass::encoder) |
| 20 | .def_readwrite("decoder", &PyClass::decoder) | 22 | .def_readwrite("decoder", &PyClass::decoder) |
| 23 | + .def_readwrite("language", &PyClass::language) | ||
| 24 | + .def_readwrite("task", &PyClass::task) | ||
| 21 | .def("__str__", &PyClass::ToString); | 25 | .def("__str__", &PyClass::ToString); |
| 22 | } | 26 | } |
| 23 | 27 |
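Note that the rebound constructor takes four `py::arg`s with no defaults, so Python callers constructing the config directly must now pass `language` and `task` explicitly; the `def_readwrite` fields remain mutable afterwards. A short usage sketch, assuming the class is re-exported at the top level of the `sherpa_onnx` package and using placeholder file names:

```python
from sherpa_onnx import OfflineWhisperModelConfig  # assumes top-level re-export

cfg = OfflineWhisperModelConfig(
    encoder="medium-encoder.onnx",  # placeholder file names
    decoder="medium-decoder.onnx",
    language="de",
    task="transcribe",
)
cfg.task = "translate"  # def_readwrite keeps the fields mutable
print(cfg)              # __str__ is bound to ToString()
```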
| @@ -244,6 +244,8 @@ class OfflineRecognizer(object): | @@ -244,6 +244,8 @@ class OfflineRecognizer(object): | ||
| 244 | encoder: str, | 244 | encoder: str, |
| 245 | decoder: str, | 245 | decoder: str, |
| 246 | tokens: str, | 246 | tokens: str, |
| 247 | + language: str = "en", | ||
| 248 | + task: str = "transcribe", | ||
| 247 | num_threads: int = 1, | 249 | num_threads: int = 1, |
| 248 | decoding_method: str = "greedy_search", | 250 | decoding_method: str = "greedy_search", |
| 249 | debug: bool = False, | 251 | debug: bool = False, |
| @@ -268,6 +270,14 @@ class OfflineRecognizer(object): | @@ -268,6 +270,14 @@ class OfflineRecognizer(object): | ||
| 268 | 270 | ||
| 269 | symbol integer_id | 271 | symbol integer_id |
| 270 | 272 | ||
| 273 | + language: | ||
| 274 | + The spoken language in the audio file. Example values: en, de, zh, | ||
| 275 | + ja, fr. See https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 | ||
| 276 | + for all possible values. Note that for non-multilingual models, the | ||
| 277 | + only valid value is 'en'. | ||
| 278 | + task: | ||
| 279 | + Valid values are: transcribe, translate. Note that for | ||
| 280 | + non-multilingual models, the only valid value is 'transcribe'. | ||
| 271 | num_threads: | 281 | num_threads: |
| 272 | Number of threads for neural network computation. | 282 | Number of threads for neural network computation. |
| 273 | decoding_method: | 283 | decoding_method: |
| @@ -279,7 +289,12 @@ class OfflineRecognizer(object): | @@ -279,7 +289,12 @@ class OfflineRecognizer(object): | ||
| 279 | """ | 289 | """ |
| 280 | self = cls.__new__(cls) | 290 | self = cls.__new__(cls) |
| 281 | model_config = OfflineModelConfig( | 291 | model_config = OfflineModelConfig( |
| 282 | - whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder), | 292 | + whisper=OfflineWhisperModelConfig( |
| 293 | + encoder=encoder, | ||
| 294 | + decoder=decoder, | ||
| 295 | + language=language, | ||
| 296 | + task=task, | ||
| 297 | + ), | ||
| 283 | tokens=tokens, | 298 | tokens=tokens, |
| 284 | num_threads=num_threads, | 299 | num_threads=num_threads, |
| 285 | debug=debug, | 300 | debug=debug, |
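Putting it together, the high-level Python API now accepts the two new keyword arguments. An end-to-end sketch with placeholder model paths (download a multilingual model such as "base" to try it; the stream-based decode calls follow the existing `OfflineRecognizer` API):

```python
import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
    encoder="base-encoder.onnx",   # placeholder paths
    decoder="base-decoder.onnx",
    tokens="base-tokens.txt",
    language="de",       # or "" to let a multilingual model auto-detect
    task="transcribe",   # "translate" emits English text instead
)

stream = recognizer.create_stream()
# samples: 1-D float32 numpy array in [-1, 1], resampled to 16 kHz
# stream.accept_waveform(16000, samples)
# recognizer.decode_stream(stream)
# print(stream.result.text)
```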