正在显示
17 个修改的文件
包含
169 行增加
和
49 行删除
| @@ -6,6 +6,10 @@ cd dart-api-examples | @@ -6,6 +6,10 @@ cd dart-api-examples | ||
| 6 | 6 | ||
| 7 | pushd non-streaming-asr | 7 | pushd non-streaming-asr |
| 8 | 8 | ||
| 9 | +echo '----------SenseVoice----------' | ||
| 10 | +./run-sense-voice.sh | ||
| 11 | +rm -rf sherpa-onnx-* | ||
| 12 | + | ||
| 9 | echo '----------NeMo transducer----------' | 13 | echo '----------NeMo transducer----------' |
| 10 | ./run-nemo-transducer.sh | 14 | ./run-nemo-transducer.sh |
| 11 | rm -rf sherpa-onnx-* | 15 | rm -rf sherpa-onnx-* |
| @@ -11,4 +11,5 @@ This folder contains examples for non-streaming ASR with Dart API. | @@ -11,4 +11,5 @@ This folder contains examples for non-streaming ASR with Dart API. | ||
| 11 | |[./bin/whisper.dart](./bin/whisper.dart)| Use whisper for speech recognition. See [./run-whisper.sh](./run-whisper.sh)| | 11 | |[./bin/whisper.dart](./bin/whisper.dart)| Use whisper for speech recognition. See [./run-whisper.sh](./run-whisper.sh)| |
| 12 | |[./bin/zipformer-transducer.dart](./bin/zipformer-transducer.dart)| Use a zipformer transducer for speech recognition. See [./run-zipformer-transducer.sh](./run-zipformer-transducer.sh)| | 12 | |[./bin/zipformer-transducer.dart](./bin/zipformer-transducer.dart)| Use a zipformer transducer for speech recognition. See [./run-zipformer-transducer.sh](./run-zipformer-transducer.sh)| |
| 13 | |[./bin/vad-with-paraformer.dart](./bin/vad-with-paraformer.dart)| Use a [silero-vad](https://github.com/snakers4/silero-vad) with paraformer for speech recognition. See [./run-vad-with-paraformer.sh](./run-vad-with-paraformer.sh)| | 13 | |[./bin/vad-with-paraformer.dart](./bin/vad-with-paraformer.dart)| Use a [silero-vad](https://github.com/snakers4/silero-vad) with paraformer for speech recognition. See [./run-vad-with-paraformer.sh](./run-vad-with-paraformer.sh)| |
| 14 | +|[./bin/sense-voice.dart](./bin/sense-voice.dart)| Use a SenseVoice CTC model for speech recognition. See [./run-sense-voice.sh](./run-sense-voice.sh)| | ||
| 14 | 15 |
| 1 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 2 | +import 'dart:io'; | ||
| 3 | +import 'dart:typed_data'; | ||
| 4 | + | ||
| 5 | +import 'package:args/args.dart'; | ||
| 6 | +import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; | ||
| 7 | + | ||
| 8 | +import './init.dart'; | ||
| 9 | + | ||
| 10 | +void main(List<String> arguments) async { | ||
| 11 | + await initSherpaOnnx(); | ||
| 12 | + | ||
| 13 | + final parser = ArgParser() | ||
| 14 | + ..addOption('model', help: 'Path to the paraformer model') | ||
| 15 | + ..addOption('tokens', help: 'Path to tokens.txt') | ||
| 16 | + ..addOption('language', | ||
| 17 | + help: 'auto, zh, en, ja, ko, yue, or leave it empty to use auto', | ||
| 18 | + defaultsTo: '') | ||
| 19 | + ..addOption('use-itn', | ||
| 20 | + help: 'true to use inverse text normalization', defaultsTo: 'false') | ||
| 21 | + ..addOption('input-wav', help: 'Path to input.wav to transcribe'); | ||
| 22 | + | ||
| 23 | + final res = parser.parse(arguments); | ||
| 24 | + if (res['model'] == null || | ||
| 25 | + res['tokens'] == null || | ||
| 26 | + res['input-wav'] == null) { | ||
| 27 | + print(parser.usage); | ||
| 28 | + exit(1); | ||
| 29 | + } | ||
| 30 | + | ||
| 31 | + final model = res['model'] as String; | ||
| 32 | + final tokens = res['tokens'] as String; | ||
| 33 | + final inputWav = res['input-wav'] as String; | ||
| 34 | + final language = res['language'] as String; | ||
| 35 | + final useItn = (res['use-itn'] as String).toLowerCase() == 'true'; | ||
| 36 | + | ||
| 37 | + final senseVoice = sherpa_onnx.OfflineSenseVoiceModelConfig( | ||
| 38 | + model: model, language: language, useInverseTextNormalization: useItn); | ||
| 39 | + | ||
| 40 | + final modelConfig = sherpa_onnx.OfflineModelConfig( | ||
| 41 | + senseVoice: senseVoice, | ||
| 42 | + tokens: tokens, | ||
| 43 | + debug: true, | ||
| 44 | + numThreads: 1, | ||
| 45 | + ); | ||
| 46 | + final config = sherpa_onnx.OfflineRecognizerConfig(model: modelConfig); | ||
| 47 | + final recognizer = sherpa_onnx.OfflineRecognizer(config); | ||
| 48 | + | ||
| 49 | + final waveData = sherpa_onnx.readWave(inputWav); | ||
| 50 | + final stream = recognizer.createStream(); | ||
| 51 | + | ||
| 52 | + stream.acceptWaveform( | ||
| 53 | + samples: waveData.samples, sampleRate: waveData.sampleRate); | ||
| 54 | + recognizer.decode(stream); | ||
| 55 | + | ||
| 56 | + final result = recognizer.getResult(stream); | ||
| 57 | + print(result.text); | ||
| 58 | + | ||
| 59 | + stream.free(); | ||
| 60 | + recognizer.free(); | ||
| 61 | +} |
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +set -ex | ||
| 4 | + | ||
| 5 | +dart pub get | ||
| 6 | + | ||
| 7 | +if [ ! -f ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt ]; then | ||
| 8 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 9 | + tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 10 | + rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 11 | +fi | ||
| 12 | + | ||
| 13 | +dart run \ | ||
| 14 | + ./bin/sense-voice.dart \ | ||
| 15 | + --model ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx \ | ||
| 16 | + --tokens ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt \ | ||
| 17 | + --use-itn true \ | ||
| 18 | + --input-wav ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav |
| @@ -5,7 +5,7 @@ description: > | @@ -5,7 +5,7 @@ description: > | ||
| 5 | 5 | ||
| 6 | publish_to: 'none' | 6 | publish_to: 'none' |
| 7 | 7 | ||
| 8 | -version: 1.10.16 | 8 | +version: 1.10.17 |
| 9 | 9 | ||
| 10 | topics: | 10 | topics: |
| 11 | - speech-recognition | 11 | - speech-recognition |
| @@ -30,7 +30,7 @@ dependencies: | @@ -30,7 +30,7 @@ dependencies: | ||
| 30 | record: ^5.1.0 | 30 | record: ^5.1.0 |
| 31 | url_launcher: ^6.2.6 | 31 | url_launcher: ^6.2.6 |
| 32 | 32 | ||
| 33 | - sherpa_onnx: ^1.10.16 | 33 | + sherpa_onnx: ^1.10.17 |
| 34 | # sherpa_onnx: | 34 | # sherpa_onnx: |
| 35 | # path: ../../flutter/sherpa_onnx | 35 | # path: ../../flutter/sherpa_onnx |
| 36 | 36 |
| @@ -5,7 +5,7 @@ description: > | @@ -5,7 +5,7 @@ description: > | ||
| 5 | 5 | ||
| 6 | publish_to: 'none' # Remove this line if you wish to publish to pub.dev | 6 | publish_to: 'none' # Remove this line if you wish to publish to pub.dev |
| 7 | 7 | ||
| 8 | -version: 1.10.16 | 8 | +version: 1.10.17 |
| 9 | 9 | ||
| 10 | environment: | 10 | environment: |
| 11 | sdk: '>=3.4.0 <4.0.0' | 11 | sdk: '>=3.4.0 <4.0.0' |
| @@ -17,7 +17,7 @@ dependencies: | @@ -17,7 +17,7 @@ dependencies: | ||
| 17 | cupertino_icons: ^1.0.6 | 17 | cupertino_icons: ^1.0.6 |
| 18 | path_provider: ^2.1.3 | 18 | path_provider: ^2.1.3 |
| 19 | path: ^1.9.0 | 19 | path: ^1.9.0 |
| 20 | - sherpa_onnx: ^1.10.16 | 20 | + sherpa_onnx: ^1.10.17 |
| 21 | url_launcher: ^6.2.6 | 21 | url_launcher: ^6.2.6 |
| 22 | audioplayers: ^5.0.0 | 22 | audioplayers: ^5.0.0 |
| 23 | 23 |
| @@ -79,6 +79,23 @@ class OfflineTdnnModelConfig { | @@ -79,6 +79,23 @@ class OfflineTdnnModelConfig { | ||
| 79 | final String model; | 79 | final String model; |
| 80 | } | 80 | } |
| 81 | 81 | ||
| 82 | +class OfflineSenseVoiceModelConfig { | ||
| 83 | + const OfflineSenseVoiceModelConfig({ | ||
| 84 | + this.model = '', | ||
| 85 | + this.language = '', | ||
| 86 | + this.useInverseTextNormalization = false, | ||
| 87 | + }); | ||
| 88 | + | ||
| 89 | + @override | ||
| 90 | + String toString() { | ||
| 91 | + return 'OfflineSenseVoiceModelConfig(model: $model, language: $language, useInverseTextNormalization: $useInverseTextNormalization)'; | ||
| 92 | + } | ||
| 93 | + | ||
| 94 | + final String model; | ||
| 95 | + final String language; | ||
| 96 | + final bool useInverseTextNormalization; | ||
| 97 | +} | ||
| 98 | + | ||
| 82 | class OfflineLMConfig { | 99 | class OfflineLMConfig { |
| 83 | const OfflineLMConfig({this.model = '', this.scale = 1.0}); | 100 | const OfflineLMConfig({this.model = '', this.scale = 1.0}); |
| 84 | 101 | ||
| @@ -98,6 +115,7 @@ class OfflineModelConfig { | @@ -98,6 +115,7 @@ class OfflineModelConfig { | ||
| 98 | this.nemoCtc = const OfflineNemoEncDecCtcModelConfig(), | 115 | this.nemoCtc = const OfflineNemoEncDecCtcModelConfig(), |
| 99 | this.whisper = const OfflineWhisperModelConfig(), | 116 | this.whisper = const OfflineWhisperModelConfig(), |
| 100 | this.tdnn = const OfflineTdnnModelConfig(), | 117 | this.tdnn = const OfflineTdnnModelConfig(), |
| 118 | + this.senseVoice = const OfflineSenseVoiceModelConfig(), | ||
| 101 | required this.tokens, | 119 | required this.tokens, |
| 102 | this.numThreads = 1, | 120 | this.numThreads = 1, |
| 103 | this.debug = true, | 121 | this.debug = true, |
| @@ -110,7 +128,7 @@ class OfflineModelConfig { | @@ -110,7 +128,7 @@ class OfflineModelConfig { | ||
| 110 | 128 | ||
| 111 | @override | 129 | @override |
| 112 | String toString() { | 130 | String toString() { |
| 113 | - return 'OfflineModelConfig(transducer: $transducer, paraformer: $paraformer, nemoCtc: $nemoCtc, whisper: $whisper, tdnn: $tdnn, tokens: $tokens, numThreads: $numThreads, debug: $debug, provider: $provider, modelType: $modelType, modelingUnit: $modelingUnit, bpeVocab: $bpeVocab, telespeechCtc: $telespeechCtc)'; | 131 | + return 'OfflineModelConfig(transducer: $transducer, paraformer: $paraformer, nemoCtc: $nemoCtc, whisper: $whisper, tdnn: $tdnn, senseVoice: $senseVoice, tokens: $tokens, numThreads: $numThreads, debug: $debug, provider: $provider, modelType: $modelType, modelingUnit: $modelingUnit, bpeVocab: $bpeVocab, telespeechCtc: $telespeechCtc)'; |
| 114 | } | 132 | } |
| 115 | 133 | ||
| 116 | final OfflineTransducerModelConfig transducer; | 134 | final OfflineTransducerModelConfig transducer; |
| @@ -118,6 +136,7 @@ class OfflineModelConfig { | @@ -118,6 +136,7 @@ class OfflineModelConfig { | ||
| 118 | final OfflineNemoEncDecCtcModelConfig nemoCtc; | 136 | final OfflineNemoEncDecCtcModelConfig nemoCtc; |
| 119 | final OfflineWhisperModelConfig whisper; | 137 | final OfflineWhisperModelConfig whisper; |
| 120 | final OfflineTdnnModelConfig tdnn; | 138 | final OfflineTdnnModelConfig tdnn; |
| 139 | + final OfflineSenseVoiceModelConfig senseVoice; | ||
| 121 | 140 | ||
| 122 | final String tokens; | 141 | final String tokens; |
| 123 | final int numThreads; | 142 | final int numThreads; |
| @@ -219,6 +238,14 @@ class OfflineRecognizer { | @@ -219,6 +238,14 @@ class OfflineRecognizer { | ||
| 219 | 238 | ||
| 220 | c.ref.model.tdnn.model = config.model.tdnn.model.toNativeUtf8(); | 239 | c.ref.model.tdnn.model = config.model.tdnn.model.toNativeUtf8(); |
| 221 | 240 | ||
| 241 | + c.ref.model.senseVoice.model = config.model.senseVoice.model.toNativeUtf8(); | ||
| 242 | + | ||
| 243 | + c.ref.model.senseVoice.language = | ||
| 244 | + config.model.senseVoice.language.toNativeUtf8(); | ||
| 245 | + | ||
| 246 | + c.ref.model.senseVoice.useInverseTextNormalization = | ||
| 247 | + config.model.senseVoice.useInverseTextNormalization ? 1 : 0; | ||
| 248 | + | ||
| 222 | c.ref.model.tokens = config.model.tokens.toNativeUtf8(); | 249 | c.ref.model.tokens = config.model.tokens.toNativeUtf8(); |
| 223 | 250 | ||
| 224 | c.ref.model.numThreads = config.model.numThreads; | 251 | c.ref.model.numThreads = config.model.numThreads; |
| @@ -254,6 +281,8 @@ class OfflineRecognizer { | @@ -254,6 +281,8 @@ class OfflineRecognizer { | ||
| 254 | calloc.free(c.ref.model.modelType); | 281 | calloc.free(c.ref.model.modelType); |
| 255 | calloc.free(c.ref.model.provider); | 282 | calloc.free(c.ref.model.provider); |
| 256 | calloc.free(c.ref.model.tokens); | 283 | calloc.free(c.ref.model.tokens); |
| 284 | + calloc.free(c.ref.model.senseVoice.language); | ||
| 285 | + calloc.free(c.ref.model.senseVoice.model); | ||
| 257 | calloc.free(c.ref.model.tdnn.model); | 286 | calloc.free(c.ref.model.tdnn.model); |
| 258 | calloc.free(c.ref.model.whisper.task); | 287 | calloc.free(c.ref.model.whisper.task); |
| 259 | calloc.free(c.ref.model.whisper.language); | 288 | calloc.free(c.ref.model.whisper.language); |
| @@ -87,6 +87,14 @@ final class SherpaOnnxOfflineTdnnModelConfig extends Struct { | @@ -87,6 +87,14 @@ final class SherpaOnnxOfflineTdnnModelConfig extends Struct { | ||
| 87 | external Pointer<Utf8> model; | 87 | external Pointer<Utf8> model; |
| 88 | } | 88 | } |
| 89 | 89 | ||
| 90 | +final class SherpaOnnxOfflineSenseVoiceModelConfig extends Struct { | ||
| 91 | + external Pointer<Utf8> model; | ||
| 92 | + external Pointer<Utf8> language; | ||
| 93 | + | ||
| 94 | + @Int32() | ||
| 95 | + external int useInverseTextNormalization; | ||
| 96 | +} | ||
| 97 | + | ||
| 90 | final class SherpaOnnxOfflineLMConfig extends Struct { | 98 | final class SherpaOnnxOfflineLMConfig extends Struct { |
| 91 | external Pointer<Utf8> model; | 99 | external Pointer<Utf8> model; |
| 92 | 100 | ||
| @@ -115,6 +123,8 @@ final class SherpaOnnxOfflineModelConfig extends Struct { | @@ -115,6 +123,8 @@ final class SherpaOnnxOfflineModelConfig extends Struct { | ||
| 115 | external Pointer<Utf8> modelingUnit; | 123 | external Pointer<Utf8> modelingUnit; |
| 116 | external Pointer<Utf8> bpeVocab; | 124 | external Pointer<Utf8> bpeVocab; |
| 117 | external Pointer<Utf8> telespeechCtc; | 125 | external Pointer<Utf8> telespeechCtc; |
| 126 | + | ||
| 127 | + external SherpaOnnxOfflineSenseVoiceModelConfig senseVoice; | ||
| 118 | } | 128 | } |
| 119 | 129 | ||
| 120 | final class SherpaOnnxOfflineRecognizerConfig extends Struct { | 130 | final class SherpaOnnxOfflineRecognizerConfig extends Struct { |
| @@ -17,7 +17,7 @@ topics: | @@ -17,7 +17,7 @@ topics: | ||
| 17 | - voice-activity-detection | 17 | - voice-activity-detection |
| 18 | 18 | ||
| 19 | # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec | 19 | # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec |
| 20 | -version: 1.10.16 | 20 | +version: 1.10.17 |
| 21 | 21 | ||
| 22 | homepage: https://github.com/k2-fsa/sherpa-onnx | 22 | homepage: https://github.com/k2-fsa/sherpa-onnx |
| 23 | 23 | ||
| @@ -30,19 +30,19 @@ dependencies: | @@ -30,19 +30,19 @@ dependencies: | ||
| 30 | flutter: | 30 | flutter: |
| 31 | sdk: flutter | 31 | sdk: flutter |
| 32 | 32 | ||
| 33 | - sherpa_onnx_android: ^1.10.16 | 33 | + sherpa_onnx_android: ^1.10.17 |
| 34 | # path: ../sherpa_onnx_android | 34 | # path: ../sherpa_onnx_android |
| 35 | 35 | ||
| 36 | - sherpa_onnx_macos: ^1.10.16 | 36 | + sherpa_onnx_macos: ^1.10.17 |
| 37 | # path: ../sherpa_onnx_macos | 37 | # path: ../sherpa_onnx_macos |
| 38 | 38 | ||
| 39 | - sherpa_onnx_linux: ^1.10.16 | 39 | + sherpa_onnx_linux: ^1.10.17 |
| 40 | # path: ../sherpa_onnx_linux | 40 | # path: ../sherpa_onnx_linux |
| 41 | # | 41 | # |
| 42 | - sherpa_onnx_windows: ^1.10.16 | 42 | + sherpa_onnx_windows: ^1.10.17 |
| 43 | # path: ../sherpa_onnx_windows | 43 | # path: ../sherpa_onnx_windows |
| 44 | 44 | ||
| 45 | - sherpa_onnx_ios: ^1.10.16 | 45 | + sherpa_onnx_ios: ^1.10.17 |
| 46 | # sherpa_onnx_ios: | 46 | # sherpa_onnx_ios: |
| 47 | # path: ../sherpa_onnx_ios | 47 | # path: ../sherpa_onnx_ios |
| 48 | 48 |
| @@ -7,7 +7,7 @@ | @@ -7,7 +7,7 @@ | ||
| 7 | # https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c | 7 | # https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c |
| 8 | Pod::Spec.new do |s| | 8 | Pod::Spec.new do |s| |
| 9 | s.name = 'sherpa_onnx_ios' | 9 | s.name = 'sherpa_onnx_ios' |
| 10 | - s.version = '1.10.16' | 10 | + s.version = '1.10.17' |
| 11 | s.summary = 'A new Flutter FFI plugin project.' | 11 | s.summary = 'A new Flutter FFI plugin project.' |
| 12 | s.description = <<-DESC | 12 | s.description = <<-DESC |
| 13 | A new Flutter FFI plugin project. | 13 | A new Flutter FFI plugin project. |
| @@ -4,7 +4,7 @@ | @@ -4,7 +4,7 @@ | ||
| 4 | # | 4 | # |
| 5 | Pod::Spec.new do |s| | 5 | Pod::Spec.new do |s| |
| 6 | s.name = 'sherpa_onnx_macos' | 6 | s.name = 'sherpa_onnx_macos' |
| 7 | - s.version = '1.10.16' | 7 | + s.version = '1.10.17' |
| 8 | s.summary = 'sherpa-onnx Flutter FFI plugin project.' | 8 | s.summary = 'sherpa-onnx Flutter FFI plugin project.' |
| 9 | s.description = <<-DESC | 9 | s.description = <<-DESC |
| 10 | sherpa-onnx Flutter FFI plugin project. | 10 | sherpa-onnx Flutter FFI plugin project. |
| @@ -17,7 +17,7 @@ topics: | @@ -17,7 +17,7 @@ topics: | ||
| 17 | - voice-activity-detection | 17 | - voice-activity-detection |
| 18 | 18 | ||
| 19 | # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec | 19 | # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec |
| 20 | -version: 1.10.16 | 20 | +version: 1.10.17 |
| 21 | 21 | ||
| 22 | homepage: https://github.com/k2-fsa/sherpa-onnx | 22 | homepage: https://github.com/k2-fsa/sherpa-onnx |
| 23 | 23 |
| @@ -13,14 +13,15 @@ namespace sherpa_onnx { | @@ -13,14 +13,15 @@ namespace sherpa_onnx { | ||
| 13 | 13 | ||
| 14 | void CudaConfig::Register(ParseOptions *po) { | 14 | void CudaConfig::Register(ParseOptions *po) { |
| 15 | po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search, | 15 | po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search, |
| 16 | - "CuDNN convolution algrorithm search"); | 16 | + "CuDNN convolution algrorithm search"); |
| 17 | } | 17 | } |
| 18 | 18 | ||
| 19 | bool CudaConfig::Validate() const { | 19 | bool CudaConfig::Validate() const { |
| 20 | if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) { | 20 | if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) { |
| 21 | - SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not a valid option." | ||
| 22 | - "Options : [1,3]. Check OnnxRT docs", | ||
| 23 | - cudnn_conv_algo_search); | 21 | + SHERPA_ONNX_LOGE( |
| 22 | + "cudnn_conv_algo_search: '%d' is not a valid option." | ||
| 23 | + "Options : [1,3]. Check OnnxRT docs", | ||
| 24 | + cudnn_conv_algo_search); | ||
| 24 | return false; | 25 | return false; |
| 25 | } | 26 | } |
| 26 | return true; | 27 | return true; |
| @@ -37,41 +38,41 @@ std::string CudaConfig::ToString() const { | @@ -37,41 +38,41 @@ std::string CudaConfig::ToString() const { | ||
| 37 | 38 | ||
| 38 | void TensorrtConfig::Register(ParseOptions *po) { | 39 | void TensorrtConfig::Register(ParseOptions *po) { |
| 39 | po->Register("trt-max-workspace-size", &trt_max_workspace_size, | 40 | po->Register("trt-max-workspace-size", &trt_max_workspace_size, |
| 40 | - "Set TensorRT EP GPU memory usage limit."); | 41 | + "Set TensorRT EP GPU memory usage limit."); |
| 41 | po->Register("trt-max-partition-iterations", &trt_max_partition_iterations, | 42 | po->Register("trt-max-partition-iterations", &trt_max_partition_iterations, |
| 42 | - "Limit partitioning iterations for model conversion."); | 43 | + "Limit partitioning iterations for model conversion."); |
| 43 | po->Register("trt-min-subgraph-size", &trt_min_subgraph_size, | 44 | po->Register("trt-min-subgraph-size", &trt_min_subgraph_size, |
| 44 | - "Set minimum size for subgraphs in partitioning."); | 45 | + "Set minimum size for subgraphs in partitioning."); |
| 45 | po->Register("trt-fp16-enable", &trt_fp16_enable, | 46 | po->Register("trt-fp16-enable", &trt_fp16_enable, |
| 46 | - "Enable FP16 precision for faster performance."); | 47 | + "Enable FP16 precision for faster performance."); |
| 47 | po->Register("trt-detailed-build-log", &trt_detailed_build_log, | 48 | po->Register("trt-detailed-build-log", &trt_detailed_build_log, |
| 48 | - "Enable detailed logging of build steps."); | 49 | + "Enable detailed logging of build steps."); |
| 49 | po->Register("trt-engine-cache-enable", &trt_engine_cache_enable, | 50 | po->Register("trt-engine-cache-enable", &trt_engine_cache_enable, |
| 50 | - "Enable caching of TensorRT engines."); | 51 | + "Enable caching of TensorRT engines."); |
| 51 | po->Register("trt-timing-cache-enable", &trt_timing_cache_enable, | 52 | po->Register("trt-timing-cache-enable", &trt_timing_cache_enable, |
| 52 | - "Enable use of timing cache to speed up builds."); | 53 | + "Enable use of timing cache to speed up builds."); |
| 53 | po->Register("trt-engine-cache-path", &trt_engine_cache_path, | 54 | po->Register("trt-engine-cache-path", &trt_engine_cache_path, |
| 54 | - "Set path to store cached TensorRT engines."); | 55 | + "Set path to store cached TensorRT engines."); |
| 55 | po->Register("trt-timing-cache-path", &trt_timing_cache_path, | 56 | po->Register("trt-timing-cache-path", &trt_timing_cache_path, |
| 56 | - "Set path for storing timing cache."); | 57 | + "Set path for storing timing cache."); |
| 57 | po->Register("trt-dump-subgraphs", &trt_dump_subgraphs, | 58 | po->Register("trt-dump-subgraphs", &trt_dump_subgraphs, |
| 58 | - "Dump optimized subgraphs for debugging."); | 59 | + "Dump optimized subgraphs for debugging."); |
| 59 | } | 60 | } |
| 60 | 61 | ||
| 61 | bool TensorrtConfig::Validate() const { | 62 | bool TensorrtConfig::Validate() const { |
| 62 | if (trt_max_workspace_size < 0) { | 63 | if (trt_max_workspace_size < 0) { |
| 63 | - SHERPA_ONNX_LOGE("trt_max_workspace_size: %lld is not valid.", | ||
| 64 | - trt_max_workspace_size); | 64 | + SHERPA_ONNX_LOGE("trt_max_workspace_size: %ld is not valid.", |
| 65 | + trt_max_workspace_size); | ||
| 65 | return false; | 66 | return false; |
| 66 | } | 67 | } |
| 67 | if (trt_max_partition_iterations < 0) { | 68 | if (trt_max_partition_iterations < 0) { |
| 68 | SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.", | 69 | SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.", |
| 69 | - trt_max_partition_iterations); | 70 | + trt_max_partition_iterations); |
| 70 | return false; | 71 | return false; |
| 71 | } | 72 | } |
| 72 | if (trt_min_subgraph_size < 0) { | 73 | if (trt_min_subgraph_size < 0) { |
| 73 | SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.", | 74 | SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.", |
| 74 | - trt_min_subgraph_size); | 75 | + trt_min_subgraph_size); |
| 75 | return false; | 76 | return false; |
| 76 | } | 77 | } |
| 77 | 78 | ||
| @@ -83,23 +84,19 @@ std::string TensorrtConfig::ToString() const { | @@ -83,23 +84,19 @@ std::string TensorrtConfig::ToString() const { | ||
| 83 | 84 | ||
| 84 | os << "TensorrtConfig("; | 85 | os << "TensorrtConfig("; |
| 85 | os << "trt_max_workspace_size=" << trt_max_workspace_size << ", "; | 86 | os << "trt_max_workspace_size=" << trt_max_workspace_size << ", "; |
| 86 | - os << "trt_max_partition_iterations=" | ||
| 87 | - << trt_max_partition_iterations << ", "; | 87 | + os << "trt_max_partition_iterations=" << trt_max_partition_iterations << ", "; |
| 88 | os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", "; | 88 | os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", "; |
| 89 | - os << "trt_fp16_enable=\"" | ||
| 90 | - << (trt_fp16_enable? "True" : "False") << "\", "; | 89 | + os << "trt_fp16_enable=\"" << (trt_fp16_enable ? "True" : "False") << "\", "; |
| 91 | os << "trt_detailed_build_log=\"" | 90 | os << "trt_detailed_build_log=\"" |
| 92 | - << (trt_detailed_build_log? "True" : "False") << "\", "; | 91 | + << (trt_detailed_build_log ? "True" : "False") << "\", "; |
| 93 | os << "trt_engine_cache_enable=\"" | 92 | os << "trt_engine_cache_enable=\"" |
| 94 | - << (trt_engine_cache_enable? "True" : "False") << "\", "; | ||
| 95 | - os << "trt_engine_cache_path=\"" | ||
| 96 | - << trt_engine_cache_path.c_str() << "\", "; | 93 | + << (trt_engine_cache_enable ? "True" : "False") << "\", "; |
| 94 | + os << "trt_engine_cache_path=\"" << trt_engine_cache_path.c_str() << "\", "; | ||
| 97 | os << "trt_timing_cache_enable=\"" | 95 | os << "trt_timing_cache_enable=\"" |
| 98 | - << (trt_timing_cache_enable? "True" : "False") << "\", "; | ||
| 99 | - os << "trt_timing_cache_path=\"" | ||
| 100 | - << trt_timing_cache_path.c_str() << "\","; | ||
| 101 | - os << "trt_dump_subgraphs=\"" | ||
| 102 | - << (trt_dump_subgraphs? "True" : "False") << "\" )"; | 96 | + << (trt_timing_cache_enable ? "True" : "False") << "\", "; |
| 97 | + os << "trt_timing_cache_path=\"" << trt_timing_cache_path.c_str() << "\","; | ||
| 98 | + os << "trt_dump_subgraphs=\"" << (trt_dump_subgraphs ? "True" : "False") | ||
| 99 | + << "\" )"; | ||
| 103 | return os.str(); | 100 | return os.str(); |
| 104 | } | 101 | } |
| 105 | 102 |
-
请 注册 或 登录 后发表评论