Fangjun Kuang
Committed by GitHub

Add tail_paddings to Whisper C API. (#886)

1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR) 1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
2 project(sherpa-onnx) 2 project(sherpa-onnx)
3 3
4 -set(SHERPA_ONNX_VERSION "1.9.24") 4 +set(SHERPA_ONNX_VERSION "1.9.25")
5 5
6 # Disable warning about 6 # Disable warning about
7 # 7 #
@@ -29,6 +29,7 @@ function createOfflineRecognizer() { @@ -29,6 +29,7 @@ function createOfflineRecognizer() {
29 decoder: '', 29 decoder: '',
30 language: '', 30 language: '',
31 task: '', 31 task: '',
  32 + tailPaddings: -1,
32 }, 33 },
33 tdnn: { 34 tdnn: {
34 model: '', 35 model: '',
@@ -29,6 +29,7 @@ function createOfflineRecognizer() { @@ -29,6 +29,7 @@ function createOfflineRecognizer() {
29 decoder: '', 29 decoder: '',
30 language: '', 30 language: '',
31 task: '', 31 task: '',
  32 + tailPaddings: -1,
32 }, 33 },
33 tdnn: { 34 tdnn: {
34 model: '', 35 model: '',
@@ -32,6 +32,7 @@ function createOfflineRecognizer() { @@ -32,6 +32,7 @@ function createOfflineRecognizer() {
32 decoder: '', 32 decoder: '',
33 language: '', 33 language: '',
34 task: '', 34 task: '',
  35 + tailPaddings: -1,
35 }, 36 },
36 tdnn: { 37 tdnn: {
37 model: '', 38 model: '',
@@ -29,6 +29,7 @@ function createOfflineRecognizer() { @@ -29,6 +29,7 @@ function createOfflineRecognizer() {
29 decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx', 29 decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx',
30 language: '', 30 language: '',
31 task: 'transcribe', 31 task: 'transcribe',
  32 + tailPaddings: -1,
32 }, 33 },
33 tdnn: { 34 tdnn: {
34 model: '', 35 model: '',
@@ -301,6 +301,7 @@ namespace SherpaOnnx @@ -301,6 +301,7 @@ namespace SherpaOnnx
301 Decoder = ""; 301 Decoder = "";
302 Language = ""; 302 Language = "";
303 Task = "transcribe"; 303 Task = "transcribe";
  304 + TailPaddings = -1;
304 } 305 }
305 [MarshalAs(UnmanagedType.LPStr)] 306 [MarshalAs(UnmanagedType.LPStr)]
306 public string Encoder; 307 public string Encoder;
@@ -313,6 +314,8 @@ namespace SherpaOnnx @@ -313,6 +314,8 @@ namespace SherpaOnnx
313 314
314 [MarshalAs(UnmanagedType.LPStr)] 315 [MarshalAs(UnmanagedType.LPStr)]
315 public string Task; 316 public string Task;
  317 +
  318 + public int TailPaddings;
316 } 319 }
317 320
318 [StructLayout(LayoutKind.Sequential)] 321 [StructLayout(LayoutKind.Sequential)]
@@ -336,10 +336,11 @@ type OfflineNemoEncDecCtcModelConfig struct { @@ -336,10 +336,11 @@ type OfflineNemoEncDecCtcModelConfig struct {
336 } 336 }
337 337
338 type OfflineWhisperModelConfig struct { 338 type OfflineWhisperModelConfig struct {
339 - Encoder string  
340 - Decoder string  
341 - Language string  
342 - Task string 339 + Encoder string
  340 + Decoder string
  341 + Language string
  342 + Task string
  343 + TailPaddings int
343 } 344 }
344 345
345 type OfflineTdnnModelConfig struct { 346 type OfflineTdnnModelConfig struct {
@@ -441,6 +442,8 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer { @@ -441,6 +442,8 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer {
441 c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task) 442 c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task)
442 defer C.free(unsafe.Pointer(c.model_config.whisper.task)) 443 defer C.free(unsafe.Pointer(c.model_config.whisper.task))
443 444
  445 + c.model_config.whisper.tail_paddings = C.int(config.ModelConfig.Whisper.TailPaddings)
  446 +
444 c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model) 447 c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model)
445 defer C.free(unsafe.Pointer(c.model_config.tdnn.model)) 448 defer C.free(unsafe.Pointer(c.model_config.tdnn.model))
446 449
@@ -74,7 +74,8 @@ static SherpaOnnxOfflineWhisperModelConfig GetOfflineWhisperModelConfig( @@ -74,7 +74,8 @@ static SherpaOnnxOfflineWhisperModelConfig GetOfflineWhisperModelConfig(
74 SHERPA_ONNX_ASSIGN_ATTR_STR(encoder, encoder); 74 SHERPA_ONNX_ASSIGN_ATTR_STR(encoder, encoder);
75 SHERPA_ONNX_ASSIGN_ATTR_STR(decoder, decoder); 75 SHERPA_ONNX_ASSIGN_ATTR_STR(decoder, decoder);
76 SHERPA_ONNX_ASSIGN_ATTR_STR(language, language); 76 SHERPA_ONNX_ASSIGN_ATTR_STR(language, language);
77 - SHERPA_ONNX_ASSIGN_ATTR_STR(task, languagek); 77 + SHERPA_ONNX_ASSIGN_ATTR_STR(task, task);
  78 + SHERPA_ONNX_ASSIGN_ATTR_INT32(tail_paddings, tailPaddings);
78 79
79 return c; 80 return c;
80 } 81 }
@@ -341,6 +341,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( @@ -341,6 +341,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
341 recognizer_config.model_config.whisper.task = "transcribe"; 341 recognizer_config.model_config.whisper.task = "transcribe";
342 } 342 }
343 343
  344 + recognizer_config.model_config.whisper.tail_paddings =
  345 + SHERPA_ONNX_OR(config->model_config.whisper.tail_paddings, -1);
  346 +
344 recognizer_config.model_config.tdnn.model = 347 recognizer_config.model_config.tdnn.model =
345 SHERPA_ONNX_OR(config->model_config.tdnn.model, ""); 348 SHERPA_ONNX_OR(config->model_config.tdnn.model, "");
346 349
@@ -359,6 +359,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig { @@ -359,6 +359,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig {
359 const char *decoder; 359 const char *decoder;
360 const char *language; 360 const char *language;
361 const char *task; 361 const char *task;
  362 + int32_t tail_paddings;
362 } SherpaOnnxOfflineWhisperModelConfig; 363 } SherpaOnnxOfflineWhisperModelConfig;
363 364
364 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig { 365 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig {
@@ -314,13 +314,15 @@ func sherpaOnnxOfflineWhisperModelConfig( @@ -314,13 +314,15 @@ func sherpaOnnxOfflineWhisperModelConfig(
314 encoder: String = "", 314 encoder: String = "",
315 decoder: String = "", 315 decoder: String = "",
316 language: String = "", 316 language: String = "",
317 - task: String = "transcribe" 317 + task: String = "transcribe",
  318 + tailPaddings: Int = -1
318 ) -> SherpaOnnxOfflineWhisperModelConfig { 319 ) -> SherpaOnnxOfflineWhisperModelConfig {
319 return SherpaOnnxOfflineWhisperModelConfig( 320 return SherpaOnnxOfflineWhisperModelConfig(
320 encoder: toCPointer(encoder), 321 encoder: toCPointer(encoder),
321 decoder: toCPointer(decoder), 322 decoder: toCPointer(decoder),
322 language: toCPointer(language), 323 language: toCPointer(language),
323 - task: toCPointer(task) 324 + task: toCPointer(task),
  325 + tail_paddings: Int32(tailPaddings)
324 ) 326 )
325 } 327 }
326 328
@@ -453,6 +453,8 @@ function initSherpaOnnxOfflineWhisperModelConfig(config, Module) { @@ -453,6 +453,8 @@ function initSherpaOnnxOfflineWhisperModelConfig(config, Module) {
453 Module.setValue(ptr + 12, buffer + offset, 'i8*'); 453 Module.setValue(ptr + 12, buffer + offset, 'i8*');
454 offset += taskLen; 454 offset += taskLen;
455 455
  456 + Module.setValue(ptr + 16, config.tailPaddings || -1, 'i32');
  457 +
456 return { 458 return {
457 buffer: buffer, ptr: ptr, len: len, 459 buffer: buffer, ptr: ptr, len: len,
458 } 460 }
@@ -14,7 +14,7 @@ static_assert(sizeof(SherpaOnnxOfflineTransducerModelConfig) == 3 * 4, ""); @@ -14,7 +14,7 @@ static_assert(sizeof(SherpaOnnxOfflineTransducerModelConfig) == 3 * 4, "");
14 static_assert(sizeof(SherpaOnnxOfflineParaformerModelConfig) == 4, ""); 14 static_assert(sizeof(SherpaOnnxOfflineParaformerModelConfig) == 4, "");
15 15
16 static_assert(sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) == 4, ""); 16 static_assert(sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) == 4, "");
17 -static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 4 * 4, ""); 17 +static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 5 * 4, "");
18 static_assert(sizeof(SherpaOnnxOfflineTdnnModelConfig) == 4, ""); 18 static_assert(sizeof(SherpaOnnxOfflineTdnnModelConfig) == 4, "");
19 static_assert(sizeof(SherpaOnnxOfflineLMConfig) == 2 * 4, ""); 19 static_assert(sizeof(SherpaOnnxOfflineLMConfig) == 2 * 4, "");
20 20
@@ -80,6 +80,7 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) { @@ -80,6 +80,7 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) {
80 fprintf(stdout, "decoder: %s\n", whisper->decoder); 80 fprintf(stdout, "decoder: %s\n", whisper->decoder);
81 fprintf(stdout, "language: %s\n", whisper->language); 81 fprintf(stdout, "language: %s\n", whisper->language);
82 fprintf(stdout, "task: %s\n", whisper->task); 82 fprintf(stdout, "task: %s\n", whisper->task);
  83 + fprintf(stdout, "tail_paddings: %d\n", whisper->tail_paddings);
83 84
84 fprintf(stdout, "----------offline tdnn model config----------\n"); 85 fprintf(stdout, "----------offline tdnn model config----------\n");
85 fprintf(stdout, "model: %s\n", tdnn->model); 86 fprintf(stdout, "model: %s\n", tdnn->model);