Fangjun Kuang
Committed by GitHub

Support writing generated audio samples to wave files (#363)

@@ -93,6 +93,7 @@ list(APPEND sources @@ -93,6 +93,7 @@ list(APPEND sources
93 offline-tts-vits-model-config.cc 93 offline-tts-vits-model-config.cc
94 offline-tts-vits-model.cc 94 offline-tts-vits-model.cc
95 offline-tts.cc 95 offline-tts.cc
  96 + wave-writer.cc
96 ) 97 )
97 98
98 if(SHERPA_ONNX_ENABLE_CHECK) 99 if(SHERPA_ONNX_ENABLE_CHECK)
@@ -53,7 +53,7 @@ static std::unordered_map<std::string, int32_t> ReadTokens( @@ -53,7 +53,7 @@ static std::unordered_map<std::string, int32_t> ReadTokens(
53 exit(-1); 53 exit(-1);
54 } 54 }
55 #endif 55 #endif
56 - token2id.insert({sym, id}); 56 + token2id.insert({std::move(sym), id});
57 } 57 }
58 58
59 return token2id; 59 return token2id;
@@ -78,6 +78,7 @@ static std::vector<int32_t> ConvertTokensToIds( @@ -78,6 +78,7 @@ static std::vector<int32_t> ConvertTokensToIds(
78 Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, 78 Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
79 const std::string &punctuations) { 79 const std::string &punctuations) {
80 token2id_ = ReadTokens(tokens); 80 token2id_ = ReadTokens(tokens);
  81 + blank_ = token2id_.at(" ");
81 std::ifstream is(lexicon); 82 std::ifstream is(lexicon);
82 83
83 std::string word; 84 std::string word;
@@ -149,6 +150,11 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIds( @@ -149,6 +150,11 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
149 ans.insert(ans.end(), prefix.begin(), prefix.end()); 150 ans.insert(ans.end(), prefix.begin(), prefix.end());
150 ans.insert(ans.end(), token_ids.begin(), token_ids.end()); 151 ans.insert(ans.end(), token_ids.begin(), token_ids.end());
151 ans.insert(ans.end(), suffix.rbegin(), suffix.rend()); 152 ans.insert(ans.end(), suffix.rbegin(), suffix.rend());
  153 + ans.push_back(blank_);
  154 + }
  155 +
  156 + if (!ans.empty()) {
  157 + ans.resize(ans.size() - 1);
152 } 158 }
153 159
154 return ans; 160 return ans;
@@ -24,6 +24,7 @@ class Lexicon { @@ -24,6 +24,7 @@ class Lexicon {
24 std::unordered_map<std::string, std::vector<int32_t>> word2ids_; 24 std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
25 std::unordered_set<std::string> punctuations_; 25 std::unordered_set<std::string> punctuations_;
26 std::unordered_map<std::string, int32_t> token2id_; 26 std::unordered_map<std::string, int32_t> token2id_;
  27 + int32_t blank_; // ID for the blank token
27 }; 28 };
28 29
29 } // namespace sherpa_onnx 30 } // namespace sherpa_onnx
@@ -6,6 +6,7 @@ @@ -6,6 +6,7 @@
6 6
7 #include "sherpa-onnx/csrc/offline-tts.h" 7 #include "sherpa-onnx/csrc/offline-tts.h"
8 #include "sherpa-onnx/csrc/parse-options.h" 8 #include "sherpa-onnx/csrc/parse-options.h"
  9 +#include "sherpa-onnx/csrc/wave-writer.h"
9 10
10 int main(int32_t argc, char *argv[]) { 11 int main(int32_t argc, char *argv[]) {
11 const char *kUsageMessage = R"usage( 12 const char *kUsageMessage = R"usage(
@@ -15,13 +16,34 @@ Offline text-to-speech with sherpa-onnx @@ -15,13 +16,34 @@ Offline text-to-speech with sherpa-onnx
15 --vits-model /path/to/model.onnx \ 16 --vits-model /path/to/model.onnx \
16 --vits-lexicon /path/to/lexicon.txt \ 17 --vits-lexicon /path/to/lexicon.txt \
17 --vits-tokens /path/to/tokens.txt 18 --vits-tokens /path/to/tokens.txt
  19 + --output-filename ./generated.wav \
18 'some text within single quotes' 20 'some text within single quotes'
19 21
20 -It will generate a file test.wav. 22 +It will generate a file ./generated.wav as specified by --output-filename.
  23 +
  24 +You can download a test model from
  25 +https://huggingface.co/csukuangfj/vits-ljs
  26 +
  27 +For instance, you can use:
  28 +wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx
  29 +wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
  30 +wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
  31 +
  32 +./bin/sherpa-onnx-offline-tts \
  33 + --vits-model=./vits-ljs.onnx \
  34 + --vits-lexicon=./lexicon.txt \
  35 + --vits-tokens=./tokens.txt \
  36 + --output-filename=./generated.wav \
  37 + 'liliana, the most beautiful and lovely assistant of our team!'
21 )usage"; 38 )usage";
22 39
23 sherpa_onnx::ParseOptions po(kUsageMessage); 40 sherpa_onnx::ParseOptions po(kUsageMessage);
  41 + std::string output_filename = "./generated.wav";
  42 + po.Register("output-filename", &output_filename,
  43 + "Path to save the generated audio");
  44 +
24 sherpa_onnx::OfflineTtsConfig config; 45 sherpa_onnx::OfflineTtsConfig config;
  46 +
25 config.Register(&po); 47 config.Register(&po);
26 po.Read(argc, argv); 48 po.Read(argc, argv);
27 49
@@ -47,11 +69,15 @@ It will generate a file test.wav. @@ -47,11 +69,15 @@ It will generate a file test.wav.
47 sherpa_onnx::OfflineTts tts(config); 69 sherpa_onnx::OfflineTts tts(config);
48 auto audio = tts.Generate(po.GetArg(1)); 70 auto audio = tts.Generate(po.GetArg(1));
49 71
50 - std::ofstream os("t.pcm", std::ios::binary);  
51 - os.write(reinterpret_cast<const char *>(audio.samples.data()),  
52 - sizeof(float) * audio.samples.size()); 72 + bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
  73 + audio.samples.data(), audio.samples.size());
  74 + if (!ok) {
  75 + fprintf(stderr, "Failed to write wave to %s\n", output_filename.c_str());
  76 + exit(EXIT_FAILURE);
  77 + }
53 78
54 - // sox -t raw -r 22050 -b 32 -e floating-point -c 1 ./t.pcm ./t.wav 79 + fprintf(stderr, "The text is: %s\n", po.GetArg(1).c_str());
  80 + fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str());
55 81
56 return 0; 82 return 0;
57 } 83 }
  1 +// sherpa-onnx/csrc/wave-writer.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/wave-writer.h"
  6 +
  7 +#include <fstream>
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +namespace {
  15 +
// Header of a canonical PCM wave (RIFF) file.
// See http://soundfile.sapp.org/doc/WaveFormat/
//
// An instance of this struct is written to the output file verbatim, so its
// in-memory layout has to be identical to the on-disk layout.
//
// Note: We assume little endian here
// TODO(fangjun): Support big endian
struct WaveHeader {
  int32_t chunk_id;        // tag of the outermost chunk, i.e., "RIFF"
  int32_t chunk_size;      // 36 + subchunk2_size
  int32_t format;          // "WAVE"
  int32_t subchunk1_id;    // "fmt "
  int32_t subchunk1_size;  // 16 for PCM
  int16_t audio_format;    // 1 for PCM
  int16_t num_channels;
  int32_t sample_rate;
  int32_t byte_rate;    // sample_rate * num_channels * bits_per_sample / 8
  int16_t block_align;  // num_channels * bits_per_sample / 8
  int16_t bits_per_sample;
  int32_t subchunk2_id;    // a tag of this chunk
  int32_t subchunk2_size;  // size of subchunk2
};

// Any padding inserted by the compiler would end up inside the file and
// corrupt it, so fail the build instead of producing broken wave files.
static_assert(sizeof(WaveHeader) == 44,
              "WaveHeader must be exactly 44 bytes without any padding");
  35 +
  36 +} // namespace
  37 +
  38 +bool WriteWave(const std::string &filename, int32_t sampling_rate,
  39 + const float *samples, int32_t n) {
  40 + WaveHeader header;
  41 + header.chunk_id = 0x46464952; // FFIR
  42 + header.format = 0x45564157; // EVAW
  43 + header.subchunk1_id = 0x20746d66; // "fmt "
  44 + header.subchunk1_size = 16; // 16 for PCM
  45 + header.audio_format = 1; // PCM =1
  46 +
  47 + int32_t num_channels = 1;
  48 + int32_t bits_per_sample = 16; // int16_t
  49 + header.num_channels = num_channels;
  50 + header.sample_rate = sampling_rate;
  51 + header.byte_rate = sampling_rate * num_channels * bits_per_sample / 8;
  52 + header.block_align = num_channels * bits_per_sample / 8;
  53 + header.bits_per_sample = bits_per_sample;
  54 + header.subchunk2_id = 0x61746164; // atad
  55 + header.subchunk2_size = n * num_channels * bits_per_sample / 8;
  56 +
  57 + header.chunk_size = 36 + header.subchunk2_size;
  58 +
  59 + std::vector<int16_t> samples_int16(n);
  60 + for (int32_t i = 0; i != n; ++i) {
  61 + samples_int16[i] = samples[i] * 32676;
  62 + }
  63 +
  64 + std::ofstream os(filename, std::ios::binary);
  65 + if (!os) {
  66 + SHERPA_ONNX_LOGE("Failed to create %s", filename.c_str());
  67 + return false;
  68 + }
  69 +
  70 + os.write(reinterpret_cast<const char *>(&header), sizeof(header));
  71 + os.write(reinterpret_cast<const char *>(samples_int16.data()),
  72 + samples_int16.size() * sizeof(int16_t));
  73 +
  74 + if (!os) {
  75 + SHERPA_ONNX_LOGE("Write %s failed", filename.c_str());
  76 + return false;
  77 + }
  78 +
  79 + return true;
  80 +}
  81 +
  82 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/wave-writer.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_WAVE_WRITER_H_
  6 +#define SHERPA_ONNX_CSRC_WAVE_WRITER_H_
  7 +
  8 +#include <cstdint>
  9 +#include <string>
  10 +
  11 +namespace sherpa_onnx {
  12 +
// Save mono audio samples to a wave file.
//
// The samples are expected to lie in the range [-1, 1]. They are scaled to
// the int16_t range and stored as int16_t in the wave file.
//
// @param filename Destination path of the wave file.
// @param sampling_rate Sample rate of the audio.
// @param samples Pointer to the first sample.
// @param n Number of samples pointed to by `samples`.
// @return true if the file was written successfully; false otherwise.
bool WriteWave(const std::string &filename,
               int32_t sampling_rate,
               const float *samples,
               int32_t n);
  24 +
  25 +} // namespace sherpa_onnx
  26 +
  27 +#endif // SHERPA_ONNX_CSRC_WAVE_WRITER_H_