Committed by
GitHub
Add tail_paddings to Whisper C API. (#886)
正在显示
13 个修改的文件
包含
29 行增加
和
9 行删除
| @@ -29,6 +29,7 @@ function createOfflineRecognizer() { | @@ -29,6 +29,7 @@ function createOfflineRecognizer() { | ||
| 29 | decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx', | 29 | decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx', |
| 30 | language: '', | 30 | language: '', |
| 31 | task: 'transcribe', | 31 | task: 'transcribe', |
| 32 | + tailPaddings: -1, | ||
| 32 | }, | 33 | }, |
| 33 | tdnn: { | 34 | tdnn: { |
| 34 | model: '', | 35 | model: '', |
| @@ -301,6 +301,7 @@ namespace SherpaOnnx | @@ -301,6 +301,7 @@ namespace SherpaOnnx | ||
| 301 | Decoder = ""; | 301 | Decoder = ""; |
| 302 | Language = ""; | 302 | Language = ""; |
| 303 | Task = "transcribe"; | 303 | Task = "transcribe"; |
| 304 | + TailPaddings = -1; | ||
| 304 | } | 305 | } |
| 305 | [MarshalAs(UnmanagedType.LPStr)] | 306 | [MarshalAs(UnmanagedType.LPStr)] |
| 306 | public string Encoder; | 307 | public string Encoder; |
| @@ -313,6 +314,8 @@ namespace SherpaOnnx | @@ -313,6 +314,8 @@ namespace SherpaOnnx | ||
| 313 | 314 | ||
| 314 | [MarshalAs(UnmanagedType.LPStr)] | 315 | [MarshalAs(UnmanagedType.LPStr)] |
| 315 | public string Task; | 316 | public string Task; |
| 317 | + | ||
| 318 | + public int TailPaddings; | ||
| 316 | } | 319 | } |
| 317 | 320 | ||
| 318 | [StructLayout(LayoutKind.Sequential)] | 321 | [StructLayout(LayoutKind.Sequential)] |
| @@ -336,10 +336,11 @@ type OfflineNemoEncDecCtcModelConfig struct { | @@ -336,10 +336,11 @@ type OfflineNemoEncDecCtcModelConfig struct { | ||
| 336 | } | 336 | } |
| 337 | 337 | ||
| 338 | type OfflineWhisperModelConfig struct { | 338 | type OfflineWhisperModelConfig struct { |
| 339 | - Encoder string | ||
| 340 | - Decoder string | ||
| 341 | - Language string | ||
| 342 | - Task string | 339 | + Encoder string |
| 340 | + Decoder string | ||
| 341 | + Language string | ||
| 342 | + Task string | ||
| 343 | + TailPaddings int | ||
| 343 | } | 344 | } |
| 344 | 345 | ||
| 345 | type OfflineTdnnModelConfig struct { | 346 | type OfflineTdnnModelConfig struct { |
| @@ -441,6 +442,8 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer { | @@ -441,6 +442,8 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer { | ||
| 441 | c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task) | 442 | c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task) |
| 442 | defer C.free(unsafe.Pointer(c.model_config.whisper.task)) | 443 | defer C.free(unsafe.Pointer(c.model_config.whisper.task)) |
| 443 | 444 | ||
| 445 | + c.model_config.whisper.tail_paddings = C.int(config.ModelConfig.Whisper.TailPaddings) | ||
| 446 | + | ||
| 444 | c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model) | 447 | c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model) |
| 445 | defer C.free(unsafe.Pointer(c.model_config.tdnn.model)) | 448 | defer C.free(unsafe.Pointer(c.model_config.tdnn.model)) |
| 446 | 449 |
| @@ -74,7 +74,8 @@ static SherpaOnnxOfflineWhisperModelConfig GetOfflineWhisperModelConfig( | @@ -74,7 +74,8 @@ static SherpaOnnxOfflineWhisperModelConfig GetOfflineWhisperModelConfig( | ||
| 74 | SHERPA_ONNX_ASSIGN_ATTR_STR(encoder, encoder); | 74 | SHERPA_ONNX_ASSIGN_ATTR_STR(encoder, encoder); |
| 75 | SHERPA_ONNX_ASSIGN_ATTR_STR(decoder, decoder); | 75 | SHERPA_ONNX_ASSIGN_ATTR_STR(decoder, decoder); |
| 76 | SHERPA_ONNX_ASSIGN_ATTR_STR(language, language); | 76 | SHERPA_ONNX_ASSIGN_ATTR_STR(language, language); |
| 77 | - SHERPA_ONNX_ASSIGN_ATTR_STR(task, languagek); | 77 | + SHERPA_ONNX_ASSIGN_ATTR_STR(task, task); |
| 78 | + SHERPA_ONNX_ASSIGN_ATTR_INT32(tail_paddings, tailPaddings); | ||
| 78 | 79 | ||
| 79 | return c; | 80 | return c; |
| 80 | } | 81 | } |
| @@ -341,6 +341,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | @@ -341,6 +341,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | ||
| 341 | recognizer_config.model_config.whisper.task = "transcribe"; | 341 | recognizer_config.model_config.whisper.task = "transcribe"; |
| 342 | } | 342 | } |
| 343 | 343 | ||
| 344 | + recognizer_config.model_config.whisper.tail_paddings = | ||
| 345 | + SHERPA_ONNX_OR(config->model_config.whisper.tail_paddings, -1); | ||
| 346 | + | ||
| 344 | recognizer_config.model_config.tdnn.model = | 347 | recognizer_config.model_config.tdnn.model = |
| 345 | SHERPA_ONNX_OR(config->model_config.tdnn.model, ""); | 348 | SHERPA_ONNX_OR(config->model_config.tdnn.model, ""); |
| 346 | 349 |
| @@ -359,6 +359,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig { | @@ -359,6 +359,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig { | ||
| 359 | const char *decoder; | 359 | const char *decoder; |
| 360 | const char *language; | 360 | const char *language; |
| 361 | const char *task; | 361 | const char *task; |
| 362 | + int32_t tail_paddings; | ||
| 362 | } SherpaOnnxOfflineWhisperModelConfig; | 363 | } SherpaOnnxOfflineWhisperModelConfig; |
| 363 | 364 | ||
| 364 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig { | 365 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig { |
| @@ -314,13 +314,15 @@ func sherpaOnnxOfflineWhisperModelConfig( | @@ -314,13 +314,15 @@ func sherpaOnnxOfflineWhisperModelConfig( | ||
| 314 | encoder: String = "", | 314 | encoder: String = "", |
| 315 | decoder: String = "", | 315 | decoder: String = "", |
| 316 | language: String = "", | 316 | language: String = "", |
| 317 | - task: String = "transcribe" | 317 | + task: String = "transcribe", |
| 318 | + tailPaddings: Int = -1 | ||
| 318 | ) -> SherpaOnnxOfflineWhisperModelConfig { | 319 | ) -> SherpaOnnxOfflineWhisperModelConfig { |
| 319 | return SherpaOnnxOfflineWhisperModelConfig( | 320 | return SherpaOnnxOfflineWhisperModelConfig( |
| 320 | encoder: toCPointer(encoder), | 321 | encoder: toCPointer(encoder), |
| 321 | decoder: toCPointer(decoder), | 322 | decoder: toCPointer(decoder), |
| 322 | language: toCPointer(language), | 323 | language: toCPointer(language), |
| 323 | - task: toCPointer(task) | 324 | + task: toCPointer(task), |
| 325 | + tail_paddings: Int32(tailPaddings) | ||
| 324 | ) | 326 | ) |
| 325 | } | 327 | } |
| 326 | 328 |
| @@ -453,6 +453,8 @@ function initSherpaOnnxOfflineWhisperModelConfig(config, Module) { | @@ -453,6 +453,8 @@ function initSherpaOnnxOfflineWhisperModelConfig(config, Module) { | ||
| 453 | Module.setValue(ptr + 12, buffer + offset, 'i8*'); | 453 | Module.setValue(ptr + 12, buffer + offset, 'i8*'); |
| 454 | offset += taskLen; | 454 | offset += taskLen; |
| 455 | 455 | ||
| 456 | + Module.setValue(ptr + 16, config.tailPaddings || -1, 'i32'); | ||
| 457 | + | ||
| 456 | return { | 458 | return { |
| 457 | buffer: buffer, ptr: ptr, len: len, | 459 | buffer: buffer, ptr: ptr, len: len, |
| 458 | } | 460 | } |
| @@ -14,7 +14,7 @@ static_assert(sizeof(SherpaOnnxOfflineTransducerModelConfig) == 3 * 4, ""); | @@ -14,7 +14,7 @@ static_assert(sizeof(SherpaOnnxOfflineTransducerModelConfig) == 3 * 4, ""); | ||
| 14 | static_assert(sizeof(SherpaOnnxOfflineParaformerModelConfig) == 4, ""); | 14 | static_assert(sizeof(SherpaOnnxOfflineParaformerModelConfig) == 4, ""); |
| 15 | 15 | ||
| 16 | static_assert(sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) == 4, ""); | 16 | static_assert(sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) == 4, ""); |
| 17 | -static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 4 * 4, ""); | 17 | +static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 5 * 4, ""); |
| 18 | static_assert(sizeof(SherpaOnnxOfflineTdnnModelConfig) == 4, ""); | 18 | static_assert(sizeof(SherpaOnnxOfflineTdnnModelConfig) == 4, ""); |
| 19 | static_assert(sizeof(SherpaOnnxOfflineLMConfig) == 2 * 4, ""); | 19 | static_assert(sizeof(SherpaOnnxOfflineLMConfig) == 2 * 4, ""); |
| 20 | 20 | ||
| @@ -80,6 +80,7 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) { | @@ -80,6 +80,7 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) { | ||
| 80 | fprintf(stdout, "decoder: %s\n", whisper->decoder); | 80 | fprintf(stdout, "decoder: %s\n", whisper->decoder); |
| 81 | fprintf(stdout, "language: %s\n", whisper->language); | 81 | fprintf(stdout, "language: %s\n", whisper->language); |
| 82 | fprintf(stdout, "task: %s\n", whisper->task); | 82 | fprintf(stdout, "task: %s\n", whisper->task); |
| 83 | + fprintf(stdout, "tail_paddings: %d\n", whisper->tail_paddings); | ||
| 83 | 84 | ||
| 84 | fprintf(stdout, "----------offline tdnn model config----------\n"); | 85 | fprintf(stdout, "----------offline tdnn model config----------\n"); |
| 85 | fprintf(stdout, "model: %s\n", tdnn->model); | 86 | fprintf(stdout, "model: %s\n", tdnn->model); |
-
请 注册 或 登录 后发表评论