Fangjun Kuang
Committed by GitHub

Add dart API for SenseVoice (#1159)

@@ -6,6 +6,10 @@ cd dart-api-examples @@ -6,6 +6,10 @@ cd dart-api-examples
6 6
7 pushd non-streaming-asr 7 pushd non-streaming-asr
8 8
  9 +echo '----------SenseVoice----------'
  10 +./run-sense-voice.sh
  11 +rm -rf sherpa-onnx-*
  12 +
9 echo '----------NeMo transducer----------' 13 echo '----------NeMo transducer----------'
10 ./run-nemo-transducer.sh 14 ./run-nemo-transducer.sh
11 rm -rf sherpa-onnx-* 15 rm -rf sherpa-onnx-*
@@ -11,4 +11,5 @@ This folder contains examples for non-streaming ASR with Dart API. @@ -11,4 +11,5 @@ This folder contains examples for non-streaming ASR with Dart API.
11 |[./bin/whisper.dart](./bin/whisper.dart)| Use whisper for speech recognition. See [./run-whisper.sh](./run-whisper.sh)| 11 |[./bin/whisper.dart](./bin/whisper.dart)| Use whisper for speech recognition. See [./run-whisper.sh](./run-whisper.sh)|
12 |[./bin/zipformer-transducer.dart](./bin/zipformer-transducer.dart)| Use a zipformer transducer for speech recognition. See [./run-zipformer-transducer.sh](./run-zipformer-transducer.sh)| 12 |[./bin/zipformer-transducer.dart](./bin/zipformer-transducer.dart)| Use a zipformer transducer for speech recognition. See [./run-zipformer-transducer.sh](./run-zipformer-transducer.sh)|
13 |[./bin/vad-with-paraformer.dart](./bin/vad-with-paraformer.dart)| Use a [silero-vad](https://github.com/snakers4/silero-vad) with paraformer for speech recognition. See [./run-vad-with-paraformer.sh](./run-vad-with-paraformer.sh)| 13 |[./bin/vad-with-paraformer.dart](./bin/vad-with-paraformer.dart)| Use a [silero-vad](https://github.com/snakers4/silero-vad) with paraformer for speech recognition. See [./run-vad-with-paraformer.sh](./run-vad-with-paraformer.sh)|
  14 +|[./bin/sense-voice.dart](./bin/sense-voice.dart)| Use a SenseVoice CTC model for speech recognition. See [./run-sense-voice.sh](./run-sense-voice.sh)|
14 15
  1 +// Copyright (c) 2024 Xiaomi Corporation
  2 +import 'dart:io';
  3 +import 'dart:typed_data';
  4 +
  5 +import 'package:args/args.dart';
  6 +import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
  7 +
  8 +import './init.dart';
  9 +
  10 +void main(List<String> arguments) async {
  11 + await initSherpaOnnx();
  12 +
  13 + final parser = ArgParser()
  14 + ..addOption('model', help: 'Path to the SenseVoice model')
  15 + ..addOption('tokens', help: 'Path to tokens.txt')
  16 + ..addOption('language',
  17 + help: 'auto, zh, en, ja, ko, yue, or leave it empty to use auto',
  18 + defaultsTo: '')
  19 + ..addOption('use-itn',
  20 + help: 'true to use inverse text normalization', defaultsTo: 'false')
  21 + ..addOption('input-wav', help: 'Path to input.wav to transcribe');
  22 +
  23 + final res = parser.parse(arguments);
  24 + if (res['model'] == null ||
  25 + res['tokens'] == null ||
  26 + res['input-wav'] == null) {
  27 + print(parser.usage);
  28 + exit(1);
  29 + }
  30 +
  31 + final model = res['model'] as String;
  32 + final tokens = res['tokens'] as String;
  33 + final inputWav = res['input-wav'] as String;
  34 + final language = res['language'] as String;
  35 + final useItn = (res['use-itn'] as String).toLowerCase() == 'true';
  36 +
  37 + final senseVoice = sherpa_onnx.OfflineSenseVoiceModelConfig(
  38 + model: model, language: language, useInverseTextNormalization: useItn);
  39 +
  40 + final modelConfig = sherpa_onnx.OfflineModelConfig(
  41 + senseVoice: senseVoice,
  42 + tokens: tokens,
  43 + debug: true,
  44 + numThreads: 1,
  45 + );
  46 + final config = sherpa_onnx.OfflineRecognizerConfig(model: modelConfig);
  47 + final recognizer = sherpa_onnx.OfflineRecognizer(config);
  48 +
  49 + final waveData = sherpa_onnx.readWave(inputWav);
  50 + final stream = recognizer.createStream();
  51 +
  52 + stream.acceptWaveform(
  53 + samples: waveData.samples, sampleRate: waveData.sampleRate);
  54 + recognizer.decode(stream);
  55 +
  56 + final result = recognizer.getResult(stream);
  57 + print(result.text);
  58 +
  59 + stream.free();
  60 + recognizer.free();
  61 +}
@@ -10,7 +10,7 @@ environment: @@ -10,7 +10,7 @@ environment:
10 10
11 # Add regular dependencies here. 11 # Add regular dependencies here.
12 dependencies: 12 dependencies:
13 - sherpa_onnx: ^1.10.16 13 + sherpa_onnx: ^1.10.17
14 path: ^1.9.0 14 path: ^1.9.0
15 args: ^2.5.0 15 args: ^2.5.0
16 16
  1 +#!/usr/bin/env bash
  2 +
  3 +set -ex
  4 +
  5 +dart pub get
  6 +
  7 +if [ ! -f ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt ]; then
  8 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
  9 + tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
  10 + rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
  11 +fi
  12 +
  13 +dart run \
  14 + ./bin/sense-voice.dart \
  15 + --model ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx \
  16 + --tokens ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt \
  17 + --use-itn true \
  18 + --input-wav ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav
@@ -11,7 +11,7 @@ environment: @@ -11,7 +11,7 @@ environment:
11 11
12 # Add regular dependencies here. 12 # Add regular dependencies here.
13 dependencies: 13 dependencies:
14 - sherpa_onnx: ^1.10.16 14 + sherpa_onnx: ^1.10.17
15 path: ^1.9.0 15 path: ^1.9.0
16 args: ^2.5.0 16 args: ^2.5.0
17 17
@@ -8,7 +8,7 @@ environment: @@ -8,7 +8,7 @@ environment:
8 8
9 # Add regular dependencies here. 9 # Add regular dependencies here.
10 dependencies: 10 dependencies:
11 - sherpa_onnx: ^1.10.16 11 + sherpa_onnx: ^1.10.17
12 path: ^1.9.0 12 path: ^1.9.0
13 args: ^2.5.0 13 args: ^2.5.0
14 14
@@ -9,7 +9,7 @@ environment: @@ -9,7 +9,7 @@ environment:
9 sdk: ^3.4.0 9 sdk: ^3.4.0
10 10
11 dependencies: 11 dependencies:
12 - sherpa_onnx: ^1.10.16 12 + sherpa_onnx: ^1.10.17
13 path: ^1.9.0 13 path: ^1.9.0
14 args: ^2.5.0 14 args: ^2.5.0
15 15
@@ -5,7 +5,7 @@ description: > @@ -5,7 +5,7 @@ description: >
5 5
6 publish_to: 'none' 6 publish_to: 'none'
7 7
8 -version: 1.10.16 8 +version: 1.10.17
9 9
10 topics: 10 topics:
11 - speech-recognition 11 - speech-recognition
@@ -30,7 +30,7 @@ dependencies: @@ -30,7 +30,7 @@ dependencies:
30 record: ^5.1.0 30 record: ^5.1.0
31 url_launcher: ^6.2.6 31 url_launcher: ^6.2.6
32 32
33 - sherpa_onnx: ^1.10.16 33 + sherpa_onnx: ^1.10.17
34 # sherpa_onnx: 34 # sherpa_onnx:
35 # path: ../../flutter/sherpa_onnx 35 # path: ../../flutter/sherpa_onnx
36 36
@@ -5,7 +5,7 @@ description: > @@ -5,7 +5,7 @@ description: >
5 5
6 publish_to: 'none' # Remove this line if you wish to publish to pub.dev 6 publish_to: 'none' # Remove this line if you wish to publish to pub.dev
7 7
8 -version: 1.10.16 8 +version: 1.10.17
9 9
10 environment: 10 environment:
11 sdk: '>=3.4.0 <4.0.0' 11 sdk: '>=3.4.0 <4.0.0'
@@ -17,7 +17,7 @@ dependencies: @@ -17,7 +17,7 @@ dependencies:
17 cupertino_icons: ^1.0.6 17 cupertino_icons: ^1.0.6
18 path_provider: ^2.1.3 18 path_provider: ^2.1.3
19 path: ^1.9.0 19 path: ^1.9.0
20 - sherpa_onnx: ^1.10.16 20 + sherpa_onnx: ^1.10.17
21 url_launcher: ^6.2.6 21 url_launcher: ^6.2.6
22 audioplayers: ^5.0.0 22 audioplayers: ^5.0.0
23 23
@@ -79,6 +79,23 @@ class OfflineTdnnModelConfig { @@ -79,6 +79,23 @@ class OfflineTdnnModelConfig {
79 final String model; 79 final String model;
80 } 80 }
81 81
  82 +class OfflineSenseVoiceModelConfig {
  83 + const OfflineSenseVoiceModelConfig({
  84 + this.model = '',
  85 + this.language = '',
  86 + this.useInverseTextNormalization = false,
  87 + });
  88 +
  89 + @override
  90 + String toString() {
  91 + return 'OfflineSenseVoiceModelConfig(model: $model, language: $language, useInverseTextNormalization: $useInverseTextNormalization)';
  92 + }
  93 +
  94 + final String model;
  95 + final String language;
  96 + final bool useInverseTextNormalization;
  97 +}
  98 +
82 class OfflineLMConfig { 99 class OfflineLMConfig {
83 const OfflineLMConfig({this.model = '', this.scale = 1.0}); 100 const OfflineLMConfig({this.model = '', this.scale = 1.0});
84 101
@@ -98,6 +115,7 @@ class OfflineModelConfig { @@ -98,6 +115,7 @@ class OfflineModelConfig {
98 this.nemoCtc = const OfflineNemoEncDecCtcModelConfig(), 115 this.nemoCtc = const OfflineNemoEncDecCtcModelConfig(),
99 this.whisper = const OfflineWhisperModelConfig(), 116 this.whisper = const OfflineWhisperModelConfig(),
100 this.tdnn = const OfflineTdnnModelConfig(), 117 this.tdnn = const OfflineTdnnModelConfig(),
  118 + this.senseVoice = const OfflineSenseVoiceModelConfig(),
101 required this.tokens, 119 required this.tokens,
102 this.numThreads = 1, 120 this.numThreads = 1,
103 this.debug = true, 121 this.debug = true,
@@ -110,7 +128,7 @@ class OfflineModelConfig { @@ -110,7 +128,7 @@ class OfflineModelConfig {
110 128
111 @override 129 @override
112 String toString() { 130 String toString() {
113 - return 'OfflineModelConfig(transducer: $transducer, paraformer: $paraformer, nemoCtc: $nemoCtc, whisper: $whisper, tdnn: $tdnn, tokens: $tokens, numThreads: $numThreads, debug: $debug, provider: $provider, modelType: $modelType, modelingUnit: $modelingUnit, bpeVocab: $bpeVocab, telespeechCtc: $telespeechCtc)'; 131 + return 'OfflineModelConfig(transducer: $transducer, paraformer: $paraformer, nemoCtc: $nemoCtc, whisper: $whisper, tdnn: $tdnn, senseVoice: $senseVoice, tokens: $tokens, numThreads: $numThreads, debug: $debug, provider: $provider, modelType: $modelType, modelingUnit: $modelingUnit, bpeVocab: $bpeVocab, telespeechCtc: $telespeechCtc)';
114 } 132 }
115 133
116 final OfflineTransducerModelConfig transducer; 134 final OfflineTransducerModelConfig transducer;
@@ -118,6 +136,7 @@ class OfflineModelConfig { @@ -118,6 +136,7 @@ class OfflineModelConfig {
118 final OfflineNemoEncDecCtcModelConfig nemoCtc; 136 final OfflineNemoEncDecCtcModelConfig nemoCtc;
119 final OfflineWhisperModelConfig whisper; 137 final OfflineWhisperModelConfig whisper;
120 final OfflineTdnnModelConfig tdnn; 138 final OfflineTdnnModelConfig tdnn;
  139 + final OfflineSenseVoiceModelConfig senseVoice;
121 140
122 final String tokens; 141 final String tokens;
123 final int numThreads; 142 final int numThreads;
@@ -219,6 +238,14 @@ class OfflineRecognizer { @@ -219,6 +238,14 @@ class OfflineRecognizer {
219 238
220 c.ref.model.tdnn.model = config.model.tdnn.model.toNativeUtf8(); 239 c.ref.model.tdnn.model = config.model.tdnn.model.toNativeUtf8();
221 240
  241 + c.ref.model.senseVoice.model = config.model.senseVoice.model.toNativeUtf8();
  242 +
  243 + c.ref.model.senseVoice.language =
  244 + config.model.senseVoice.language.toNativeUtf8();
  245 +
  246 + c.ref.model.senseVoice.useInverseTextNormalization =
  247 + config.model.senseVoice.useInverseTextNormalization ? 1 : 0;
  248 +
222 c.ref.model.tokens = config.model.tokens.toNativeUtf8(); 249 c.ref.model.tokens = config.model.tokens.toNativeUtf8();
223 250
224 c.ref.model.numThreads = config.model.numThreads; 251 c.ref.model.numThreads = config.model.numThreads;
@@ -254,6 +281,8 @@ class OfflineRecognizer { @@ -254,6 +281,8 @@ class OfflineRecognizer {
254 calloc.free(c.ref.model.modelType); 281 calloc.free(c.ref.model.modelType);
255 calloc.free(c.ref.model.provider); 282 calloc.free(c.ref.model.provider);
256 calloc.free(c.ref.model.tokens); 283 calloc.free(c.ref.model.tokens);
  284 + calloc.free(c.ref.model.senseVoice.language);
  285 + calloc.free(c.ref.model.senseVoice.model);
257 calloc.free(c.ref.model.tdnn.model); 286 calloc.free(c.ref.model.tdnn.model);
258 calloc.free(c.ref.model.whisper.task); 287 calloc.free(c.ref.model.whisper.task);
259 calloc.free(c.ref.model.whisper.language); 288 calloc.free(c.ref.model.whisper.language);
@@ -87,6 +87,14 @@ final class SherpaOnnxOfflineTdnnModelConfig extends Struct { @@ -87,6 +87,14 @@ final class SherpaOnnxOfflineTdnnModelConfig extends Struct {
87 external Pointer<Utf8> model; 87 external Pointer<Utf8> model;
88 } 88 }
89 89
  90 +final class SherpaOnnxOfflineSenseVoiceModelConfig extends Struct {
  91 + external Pointer<Utf8> model;
  92 + external Pointer<Utf8> language;
  93 +
  94 + @Int32()
  95 + external int useInverseTextNormalization;
  96 +}
  97 +
90 final class SherpaOnnxOfflineLMConfig extends Struct { 98 final class SherpaOnnxOfflineLMConfig extends Struct {
91 external Pointer<Utf8> model; 99 external Pointer<Utf8> model;
92 100
@@ -115,6 +123,8 @@ final class SherpaOnnxOfflineModelConfig extends Struct { @@ -115,6 +123,8 @@ final class SherpaOnnxOfflineModelConfig extends Struct {
115 external Pointer<Utf8> modelingUnit; 123 external Pointer<Utf8> modelingUnit;
116 external Pointer<Utf8> bpeVocab; 124 external Pointer<Utf8> bpeVocab;
117 external Pointer<Utf8> telespeechCtc; 125 external Pointer<Utf8> telespeechCtc;
  126 +
  127 + external SherpaOnnxOfflineSenseVoiceModelConfig senseVoice;
118 } 128 }
119 129
120 final class SherpaOnnxOfflineRecognizerConfig extends Struct { 130 final class SherpaOnnxOfflineRecognizerConfig extends Struct {
@@ -17,7 +17,7 @@ topics: @@ -17,7 +17,7 @@ topics:
17 - voice-activity-detection 17 - voice-activity-detection
18 18
19 # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec 19 # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec
20 -version: 1.10.16 20 +version: 1.10.17
21 21
22 homepage: https://github.com/k2-fsa/sherpa-onnx 22 homepage: https://github.com/k2-fsa/sherpa-onnx
23 23
@@ -30,19 +30,19 @@ dependencies: @@ -30,19 +30,19 @@ dependencies:
30 flutter: 30 flutter:
31 sdk: flutter 31 sdk: flutter
32 32
33 - sherpa_onnx_android: ^1.10.16 33 + sherpa_onnx_android: ^1.10.17
34 # path: ../sherpa_onnx_android 34 # path: ../sherpa_onnx_android
35 35
36 - sherpa_onnx_macos: ^1.10.16 36 + sherpa_onnx_macos: ^1.10.17
37 # path: ../sherpa_onnx_macos 37 # path: ../sherpa_onnx_macos
38 38
39 - sherpa_onnx_linux: ^1.10.16 39 + sherpa_onnx_linux: ^1.10.17
40 # path: ../sherpa_onnx_linux 40 # path: ../sherpa_onnx_linux
41 # 41 #
42 - sherpa_onnx_windows: ^1.10.16 42 + sherpa_onnx_windows: ^1.10.17
43 # path: ../sherpa_onnx_windows 43 # path: ../sherpa_onnx_windows
44 44
45 - sherpa_onnx_ios: ^1.10.16 45 + sherpa_onnx_ios: ^1.10.17
46 # sherpa_onnx_ios: 46 # sherpa_onnx_ios:
47 # path: ../sherpa_onnx_ios 47 # path: ../sherpa_onnx_ios
48 48
@@ -7,7 +7,7 @@ @@ -7,7 +7,7 @@
7 # https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c 7 # https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c
8 Pod::Spec.new do |s| 8 Pod::Spec.new do |s|
9 s.name = 'sherpa_onnx_ios' 9 s.name = 'sherpa_onnx_ios'
10 - s.version = '1.10.16' 10 + s.version = '1.10.17'
11 s.summary = 'A new Flutter FFI plugin project.' 11 s.summary = 'A new Flutter FFI plugin project.'
12 s.description = <<-DESC 12 s.description = <<-DESC
13 A new Flutter FFI plugin project. 13 A new Flutter FFI plugin project.
@@ -4,7 +4,7 @@ @@ -4,7 +4,7 @@
4 # 4 #
5 Pod::Spec.new do |s| 5 Pod::Spec.new do |s|
6 s.name = 'sherpa_onnx_macos' 6 s.name = 'sherpa_onnx_macos'
7 - s.version = '1.10.16' 7 + s.version = '1.10.17'
8 s.summary = 'sherpa-onnx Flutter FFI plugin project.' 8 s.summary = 'sherpa-onnx Flutter FFI plugin project.'
9 s.description = <<-DESC 9 s.description = <<-DESC
10 sherpa-onnx Flutter FFI plugin project. 10 sherpa-onnx Flutter FFI plugin project.
@@ -17,7 +17,7 @@ topics: @@ -17,7 +17,7 @@ topics:
17 - voice-activity-detection 17 - voice-activity-detection
18 18
19 # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec 19 # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec
20 -version: 1.10.16 20 +version: 1.10.17
21 21
22 homepage: https://github.com/k2-fsa/sherpa-onnx 22 homepage: https://github.com/k2-fsa/sherpa-onnx
23 23
@@ -13,14 +13,15 @@ namespace sherpa_onnx { @@ -13,14 +13,15 @@ namespace sherpa_onnx {
13 13
14 void CudaConfig::Register(ParseOptions *po) { 14 void CudaConfig::Register(ParseOptions *po) {
15 po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search, 15 po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search,
16 - "CuDNN convolution algrorithm search"); 16 + "CuDNN convolution algorithm search");
17 } 17 }
18 18
19 bool CudaConfig::Validate() const { 19 bool CudaConfig::Validate() const {
20 if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) { 20 if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) {
21 - SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not a valid option."  
22 - "Options : [1,3]. Check OnnxRT docs",  
23 - cudnn_conv_algo_search); 21 + SHERPA_ONNX_LOGE(
  22 + "cudnn_conv_algo_search: '%d' is not a valid option. "
  23 + "Options : [1,3]. Check OnnxRT docs",
  24 + cudnn_conv_algo_search);
24 return false; 25 return false;
25 } 26 }
26 return true; 27 return true;
@@ -37,41 +38,41 @@ std::string CudaConfig::ToString() const { @@ -37,41 +38,41 @@ std::string CudaConfig::ToString() const {
37 38
38 void TensorrtConfig::Register(ParseOptions *po) { 39 void TensorrtConfig::Register(ParseOptions *po) {
39 po->Register("trt-max-workspace-size", &trt_max_workspace_size, 40 po->Register("trt-max-workspace-size", &trt_max_workspace_size,
40 - "Set TensorRT EP GPU memory usage limit."); 41 + "Set TensorRT EP GPU memory usage limit.");
41 po->Register("trt-max-partition-iterations", &trt_max_partition_iterations, 42 po->Register("trt-max-partition-iterations", &trt_max_partition_iterations,
42 - "Limit partitioning iterations for model conversion."); 43 + "Limit partitioning iterations for model conversion.");
43 po->Register("trt-min-subgraph-size", &trt_min_subgraph_size, 44 po->Register("trt-min-subgraph-size", &trt_min_subgraph_size,
44 - "Set minimum size for subgraphs in partitioning."); 45 + "Set minimum size for subgraphs in partitioning.");
45 po->Register("trt-fp16-enable", &trt_fp16_enable, 46 po->Register("trt-fp16-enable", &trt_fp16_enable,
46 - "Enable FP16 precision for faster performance."); 47 + "Enable FP16 precision for faster performance.");
47 po->Register("trt-detailed-build-log", &trt_detailed_build_log, 48 po->Register("trt-detailed-build-log", &trt_detailed_build_log,
48 - "Enable detailed logging of build steps."); 49 + "Enable detailed logging of build steps.");
49 po->Register("trt-engine-cache-enable", &trt_engine_cache_enable, 50 po->Register("trt-engine-cache-enable", &trt_engine_cache_enable,
50 - "Enable caching of TensorRT engines."); 51 + "Enable caching of TensorRT engines.");
51 po->Register("trt-timing-cache-enable", &trt_timing_cache_enable, 52 po->Register("trt-timing-cache-enable", &trt_timing_cache_enable,
52 - "Enable use of timing cache to speed up builds."); 53 + "Enable use of timing cache to speed up builds.");
53 po->Register("trt-engine-cache-path", &trt_engine_cache_path, 54 po->Register("trt-engine-cache-path", &trt_engine_cache_path,
54 - "Set path to store cached TensorRT engines."); 55 + "Set path to store cached TensorRT engines.");
55 po->Register("trt-timing-cache-path", &trt_timing_cache_path, 56 po->Register("trt-timing-cache-path", &trt_timing_cache_path,
56 - "Set path for storing timing cache."); 57 + "Set path for storing timing cache.");
57 po->Register("trt-dump-subgraphs", &trt_dump_subgraphs, 58 po->Register("trt-dump-subgraphs", &trt_dump_subgraphs,
58 - "Dump optimized subgraphs for debugging."); 59 + "Dump optimized subgraphs for debugging.");
59 } 60 }
60 61
61 bool TensorrtConfig::Validate() const { 62 bool TensorrtConfig::Validate() const {
62 if (trt_max_workspace_size < 0) { 63 if (trt_max_workspace_size < 0) {
63 - SHERPA_ONNX_LOGE("trt_max_workspace_size: %lld is not valid.",  
64 - trt_max_workspace_size); 64 + SHERPA_ONNX_LOGE("trt_max_workspace_size: %ld is not valid.",
  65 + trt_max_workspace_size);
65 return false; 66 return false;
66 } 67 }
67 if (trt_max_partition_iterations < 0) { 68 if (trt_max_partition_iterations < 0) {
68 SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.", 69 SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.",
69 - trt_max_partition_iterations); 70 + trt_max_partition_iterations);
70 return false; 71 return false;
71 } 72 }
72 if (trt_min_subgraph_size < 0) { 73 if (trt_min_subgraph_size < 0) {
73 SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.", 74 SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.",
74 - trt_min_subgraph_size); 75 + trt_min_subgraph_size);
75 return false; 76 return false;
76 } 77 }
77 78
@@ -83,23 +84,19 @@ std::string TensorrtConfig::ToString() const { @@ -83,23 +84,19 @@ std::string TensorrtConfig::ToString() const {
83 84
84 os << "TensorrtConfig("; 85 os << "TensorrtConfig(";
85 os << "trt_max_workspace_size=" << trt_max_workspace_size << ", "; 86 os << "trt_max_workspace_size=" << trt_max_workspace_size << ", ";
86 - os << "trt_max_partition_iterations="  
87 - << trt_max_partition_iterations << ", "; 87 + os << "trt_max_partition_iterations=" << trt_max_partition_iterations << ", ";
88 os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", "; 88 os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", ";
89 - os << "trt_fp16_enable=\""  
90 - << (trt_fp16_enable? "True" : "False") << "\", "; 89 + os << "trt_fp16_enable=\"" << (trt_fp16_enable ? "True" : "False") << "\", ";
91 os << "trt_detailed_build_log=\"" 90 os << "trt_detailed_build_log=\""
92 - << (trt_detailed_build_log? "True" : "False") << "\", "; 91 + << (trt_detailed_build_log ? "True" : "False") << "\", ";
93 os << "trt_engine_cache_enable=\"" 92 os << "trt_engine_cache_enable=\""
94 - << (trt_engine_cache_enable? "True" : "False") << "\", ";  
95 - os << "trt_engine_cache_path=\""  
96 - << trt_engine_cache_path.c_str() << "\", "; 93 + << (trt_engine_cache_enable ? "True" : "False") << "\", ";
  94 + os << "trt_engine_cache_path=\"" << trt_engine_cache_path.c_str() << "\", ";
97 os << "trt_timing_cache_enable=\"" 95 os << "trt_timing_cache_enable=\""
98 - << (trt_timing_cache_enable? "True" : "False") << "\", ";  
99 - os << "trt_timing_cache_path=\""  
100 - << trt_timing_cache_path.c_str() << "\",";  
101 - os << "trt_dump_subgraphs=\""  
102 - << (trt_dump_subgraphs? "True" : "False") << "\" )"; 96 + << (trt_timing_cache_enable ? "True" : "False") << "\", ";
  97 + os << "trt_timing_cache_path=\"" << trt_timing_cache_path.c_str() << "\",";
  98 + os << "trt_dump_subgraphs=\"" << (trt_dump_subgraphs ? "True" : "False")
  99 + << "\" )";
103 return os.str(); 100 return os.str();
104 } 101 }
105 102