Aero
Committed by GitHub

Add isolate_tts demo (#1529)

import 'dart:io';
import 'dart:isolate';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:media_kit/media_kit.dart';
import 'package:path/path.dart' as p;
import 'package:path_provider/path_provider.dart';
import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
import 'utils.dart';
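/// Payload passed to Isolate.spawn: the main isolate's SendPort plus the
/// RootIsolateToken that lets the background isolate use platform channels.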
class _IsolateTask<T> {
  final SendPort sendPort;
  RootIsolateToken? rootIsolateToken;

  _IsolateTask(this.sendPort, this.rootIsolateToken);
}
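/// Message envelope exchanged between the UI isolate and the TTS isolate.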
class _PortModel {
  final String method;
  final SendPort? sendPort;
  dynamic data;

  _PortModel({
    required this.method,
    this.sendPort,
    this.data,
  });
}
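/// Handles the main isolate keeps for the spawned TTS isolate.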
class _TtsManager {
  /// Port on which the main isolate receives messages.
  final ReceivePort receivePort;
  final Isolate isolate;
  final SendPort isolatePort;

  _TtsManager({
    required this.receivePort,
    required this.isolate,
    required this.isolatePort,
  });
}
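/// Runs sherpa-onnx offline TTS inside a background isolate so that model
/// loading and synthesis do not block the UI isolate.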
class IsolateTts {
  static late final _TtsManager _ttsManager;

  /// Send port of the background isolate.
  static SendPort get _sendPort => _ttsManager.isolatePort;

  static late sherpa_onnx.OfflineTts _tts;
  static late Player _player;

  static Future<void> init() async {
    ReceivePort port = ReceivePort();
    RootIsolateToken? rootIsolateToken = RootIsolateToken.instance;
    Isolate isolate = await Isolate.spawn(
      _isolateEntry,
      _IsolateTask(port.sendPort, rootIsolateToken),
      errorsAreFatal: false,
    );
    port.listen((msg) async {
      if (msg is SendPort) {
        // The first message from the isolate is its SendPort; keep it for later requests.
        _ttsManager = _TtsManager(receivePort: port, isolate: isolate, isolatePort: msg);
        return;
      }
    });
  }
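  /// Entry point of the background isolate: initializes the plugins, loads the
  /// model, and serves 'generate' requests arriving on its ReceivePort.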
  static Future<void> _isolateEntry(_IsolateTask task) async {
    if (task.rootIsolateToken != null) {
      BackgroundIsolateBinaryMessenger.ensureInitialized(task.rootIsolateToken!);
    }
    MediaKit.ensureInitialized();
    _player = Player();
    sherpa_onnx.initBindings();

    final receivePort = ReceivePort();
    task.sendPort.send(receivePort.sendPort);

    String modelDir = '';
    String modelName = '';
    String ruleFsts = '';
    String ruleFars = '';
    String lexicon = '';
    String dataDir = '';
    String dictDir = '';

    // Example 7
    // https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
    // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-melo-tts-zh_en.tar.bz2
    modelDir = 'vits-melo-tts-zh_en';
    modelName = 'model.onnx';
    lexicon = 'lexicon.txt';
    dictDir = 'vits-melo-tts-zh_en/dict';

    if (modelName == '') {
      throw Exception('You are supposed to select a model by changing the code before you run the app');
    }

    final Directory directory = await getApplicationDocumentsDirectory();
    modelName = p.join(directory.path, modelDir, modelName);

    if (ruleFsts != '') {
      final all = ruleFsts.split(',');
      var tmp = <String>[];
      for (final f in all) {
        tmp.add(p.join(directory.path, f));
      }
      ruleFsts = tmp.join(',');
    }

    if (ruleFars != '') {
      final all = ruleFars.split(',');
      var tmp = <String>[];
      for (final f in all) {
        tmp.add(p.join(directory.path, f));
      }
      ruleFars = tmp.join(',');
    }

    if (lexicon != '') {
      lexicon = p.join(directory.path, modelDir, lexicon);
    }

    if (dataDir != '') {
      dataDir = p.join(directory.path, dataDir);
    }

    if (dictDir != '') {
      dictDir = p.join(directory.path, dictDir);
    }
    final tokens = p.join(directory.path, modelDir, 'tokens.txt');

    final vits = sherpa_onnx.OfflineTtsVitsModelConfig(
      model: modelName,
      lexicon: lexicon,
      tokens: tokens,
      dataDir: dataDir,
      dictDir: dictDir,
    );

    final modelConfig = sherpa_onnx.OfflineTtsModelConfig(
      vits: vits,
      numThreads: 2,
      debug: true,
      provider: 'cpu',
    );

    final config = sherpa_onnx.OfflineTtsConfig(
      model: modelConfig,
      ruleFsts: ruleFsts,
      ruleFars: ruleFars,
      maxNumSenetences: 1,
    );
    // print(config);
    receivePort.listen((msg) async {
      if (msg is _PortModel) {
        switch (msg.method) {
          case 'generate':
            {
              _PortModel _v = msg;
              final stopwatch = Stopwatch();
              stopwatch.start();
              final audio = _tts.generate(text: _v.data['text'], sid: _v.data['sid'], speed: _v.data['speed']);
              final suffix = '-sid-${_v.data['sid']}-speed-${_v.data['speed'].toStringAsPrecision(2)}';
              final filename = await generateWaveFilename(suffix);

              final ok = sherpa_onnx.writeWave(
                filename: filename,
                samples: audio.samples,
                sampleRate: audio.sampleRate,
              );

              if (ok) {
                stopwatch.stop();
                double elapsed = stopwatch.elapsed.inMilliseconds.toDouble();
                double waveDuration = audio.samples.length.toDouble() / audio.sampleRate.toDouble();

                print('Saved to\n$filename\n'
                    'Elapsed: ${(elapsed / 1000).toStringAsPrecision(4)} s\n'
                    'Wave duration: ${waveDuration.toStringAsPrecision(4)} s\n'
                    'RTF: ${(elapsed / 1000).toStringAsPrecision(4)}/${waveDuration.toStringAsPrecision(4)} '
                    '= ${(elapsed / 1000 / waveDuration).toStringAsPrecision(3)} ');

                await _player.open(Media('file:///$filename'));
                await _player.play();
              }

              // Reply so that the caller's `await receivePort.first` completes.
              _v.sendPort?.send(ok);
            }
            break;
        }
      }
    });

    _tts = sherpa_onnx.OfflineTts(config);
  }
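  /// Asks the background isolate to synthesize [text] and waits for its reply.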
  static Future<void> generate({required String text, int sid = 0, double speed = 1.0}) async {
    ReceivePort receivePort = ReceivePort();
    _sendPort.send(_PortModel(
      method: 'generate',
      data: {'text': text, 'sid': sid, 'speed': speed},
      sendPort: receivePort.sendPort,
    ));
    await receivePort.first;
    receivePort.close();
  }
}
/// The demo page.
class IsolateTtsView extends StatefulWidget {
  const IsolateTtsView({super.key});

  @override
  State<IsolateTtsView> createState() => _IsolateTtsViewState();
}

class _IsolateTtsViewState extends State<IsolateTtsView> {
  @override
  void initState() {
    super.initState();
    IsolateTts.init();
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      body: Center(
        child: ElevatedButton(
          onPressed: () {
            IsolateTts.generate(text: '这是已退出的 isolate TTS');
          },
          child: Text('Isolate TTS'),
        ),
      ),
    );
  }
}
... ...
// Copyright (c) 2024 Xiaomi Corporation
import 'package:flutter/material.dart';
import './tts.dart';
import './info.dart';
import 'isolate_tts.dart';
void main() {
runApp(const MyApp());
... ... @@ -38,6 +39,7 @@ class _MyHomePageState extends State<MyHomePage> {
final List<Widget> _tabs = [
TtsScreen(),
InfoScreen(),
IsolateTtsView(),
];
@override
Widget build(BuildContext context) {
... ... @@ -62,6 +64,10 @@ class _MyHomePageState extends State<MyHomePage> {
icon: Icon(Icons.info),
label: 'Info',
),
BottomNavigationBarItem(
icon: Icon(Icons.multiline_chart),
label: 'Isolate',
),
],
),
);
... ...
... ... @@ -79,17 +79,16 @@ Future<sherpa_onnx.OfflineTts> createOfflineTts() async {
// Example 7
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
// https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-melo-tts-zh_en.tar.bz2
modelDir = 'vits-melo-tts-zh_en';
modelName = 'model.onnx';
lexicon = 'lexicon.txt';
dictDir = 'vits-melo-tts-zh_en/dict';
// ============================================================
// Please don't change the remaining part of this function
// ============================================================
if (modelName == '') {
throw Exception('You are supposed to select a model by changing the code before you run the app');
}
final Directory directory = await getApplicationDocumentsDirectory();
... ...
... ... @@ -77,9 +77,7 @@ class _TtsScreenState extends State<TtsScreen> {
onTapOutside: (PointerDownEvent event) {
FocusManager.instance.primaryFocus?.unfocus();
},
inputFormatters: <TextInputFormatter>[FilteringTextInputFormatter.digitsOnly]),
Slider(
// decoration: InputDecoration(
// labelText: "speech speed",
... ... @@ -108,125 +106,117 @@ class _TtsScreenState extends State<TtsScreen> {
},
),
const SizedBox(height: 5),
Row(mainAxisAlignment: MainAxisAlignment.center, children: <Widget>[
OutlinedButton(
child: Text("Generate"),
onPressed: () async {
await _init();
await _player?.stop();
setState(() {
_maxSpeakerID = _tts?.numSpeakers ?? 0;
if (_maxSpeakerID > 0) {
_maxSpeakerID -= 1;
}
});
if (_tts == null) {
_controller_hint.value = TextEditingValue(
text: 'Failed to initialize tts',
);
return;
}
_controller_hint.value = TextEditingValue(
text: '',
);
final text = _controller_text_input.text.trim();
if (text == '') {
_controller_hint.value = TextEditingValue(
text: 'Please first input your text to generate',
);
return;
}
final sid = int.tryParse(_controller_sid.text.trim()) ?? 0;
final stopwatch = Stopwatch();
stopwatch.start();
final audio = _tts!.generate(text: text, sid: sid, speed: _speed);
final suffix = '-sid-$sid-speed-${_speed.toStringAsPrecision(2)}';
final filename = await generateWaveFilename(suffix);
final ok = sherpa_onnx.writeWave(
filename: filename,
samples: audio.samples,
sampleRate: audio.sampleRate,
);
if (ok) {
stopwatch.stop();
double elapsed = stopwatch.elapsed.inMilliseconds.toDouble();
double waveDuration = audio.samples.length.toDouble() / audio.sampleRate.toDouble();
_controller_hint.value = TextEditingValue(
text: 'Saved to\n$filename\n'
'Elapsed: ${(elapsed / 1000).toStringAsPrecision(4)} s\n'
'Wave duration: ${waveDuration.toStringAsPrecision(4)} s\n'
'RTF: ${(elapsed / 1000).toStringAsPrecision(4)}/${waveDuration.toStringAsPrecision(4)} '
'= ${(elapsed / 1000 / waveDuration).toStringAsPrecision(3)} ',
);
_lastFilename = filename;
await _player?.play(DeviceFileSource(_lastFilename));
} else {
_controller_hint.value = TextEditingValue(
text: 'Failed to save generated audio',
);
}
},
),
const SizedBox(width: 5),
OutlinedButton(
child: Text("Clear"),
onPressed: () {
_controller_text_input.value = TextEditingValue(
text: '',
);
_controller_hint.value = TextEditingValue(
text: '',
);
},
),
const SizedBox(width: 5),
OutlinedButton(
child: Text("Play"),
onPressed: () async {
if (_lastFilename == '') {
_controller_hint.value = TextEditingValue(
text: 'No generated wave file found',
);
return;
}
await _player?.stop();
await _player?.play(DeviceFileSource(_lastFilename));
_controller_hint.value = TextEditingValue(
text: 'Playing\n$_lastFilename',
);
},
),
const SizedBox(width: 5),
OutlinedButton(
child: Text("Stop"),
onPressed: () async {
await _player?.stop();
_controller_hint.value = TextEditingValue(
text: '',
);
},
),
]),
const SizedBox(height: 5),
TextField(
decoration: InputDecoration(
... ...
... ... @@ -24,6 +24,12 @@ dependencies:
url_launcher: 6.2.6
url_launcher_linux: 3.1.0
audioplayers: ^5.0.0
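# media_kit is used by the isolate_tts demo to play the generated wave file.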
media_kit:
media_kit_libs_video:
flutter:
uses-material-design: true
assets:
- assets/vits-melo-tts-zh_en/
- assets/vits-melo-tts-zh_en/dict/
\ No newline at end of file
... ...