Fangjun Kuang
Committed by GitHub

Validate input sid (#369)
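The rule this patch enforces: a model that reports 0 speakers (a single-speaker VITS model) accepts only sid 0, and a model that reports N speakers accepts sid in [0, N - 1]; any other sid makes Generate() log an error and return an empty GeneratedAudio, which the Python example and the C++ binary below now check for. A minimal standalone sketch of that rule follows; ValidateSid is a hypothetical helper written only for illustration and is not part of the sherpa-onnx API.

#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the check added to OfflineTtsVitsImpl::Generate():
// returns true if sid is acceptable for a model with num_speakers speakers.
static bool ValidateSid(int32_t num_speakers, int64_t sid) {
  if (num_speakers == 0) {
    return sid == 0;  // single-speaker model: only sid 0 is valid
  }
  return sid >= 0 && sid < num_speakers;  // multi-speaker: 0 <= sid < num_speakers
}

int main() {
  std::printf("%d\n", ValidateSid(0, 0));  // 1: single-speaker model, sid 0
  std::printf("%d\n", ValidateSid(0, 3));  // 0: single-speaker model, sid != 0
  std::printf("%d\n", ValidateSid(5, 4));  // 1: 5-speaker model, sid in [0, 4]
  std::printf("%d\n", ValidateSid(5, 5));  // 0: 5-speaker model, sid out of range
  return 0;
}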

@@ -20,7 +20,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
 option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
 option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
 option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
-option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is ON on Linux" ON)
+option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON)
 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
... ...
@@ -124,6 +124,11 @@ def main():
     start = time.time()
     audio = tts.generate(args.text, sid=args.sid)
     end = time.time()
+    if len(audio.samples) == 0:
+        print("Error in generating audios. Please read previous error messages.")
+        return
     elapsed_seconds = end - start
     audio_duration = len(audio.samples) / audio.sample_rate
     real_time_factor = elapsed_seconds / audio_duration
... ...
@@ -104,9 +104,17 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
   std::vector<int64_t> ans;
-  ans.push_back(token2id_.at("sil"));
+  auto sil = token2id_.at("sil");
+  auto eos = token2id_.at("eos");
+  ans.push_back(sil);
   for (const auto &w : words) {
+    if (punctuations_.count(w)) {
+      ans.push_back(sil);
+      continue;
+    }
     if (!word2ids_.count(w)) {
       SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
       continue;
@@ -115,8 +123,8 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
     const auto &token_ids = word2ids_.at(w);
     ans.insert(ans.end(), token_ids.begin(), token_ids.end());
   }
-  ans.push_back(token2id_.at("sil"));
-  ans.push_back(token2id_.at("eos"));
+  ans.push_back(sil);
+  ans.push_back(eos);
   return ans;
 }
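The two hunks above change how Chinese text is converted to token IDs: the IDs of "sil" and "eos" are looked up once, each punctuation symbol now contributes an extra "sil" (a pause) instead of being skipped as an OOV word, and the sequence is still wrapped in a leading "sil" and a trailing "sil" plus "eos". A small standalone sketch of that flow follows; the token IDs and the tiny lexicon are invented for illustration, whereas the real Lexicon class reads them from the model's lexicon and tokens files.

#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
  // Made-up token table and lexicon, standing in for token2id_ / word2ids_.
  std::unordered_map<std::string, int64_t> token2id = {{"sil", 0}, {"eos", 1}};
  std::unordered_map<std::string, std::vector<int64_t>> word2ids = {
      {"你好", {10, 11}}, {"世界", {12, 13}}};
  std::unordered_set<std::string> punctuations = {"，", "。"};

  std::vector<std::string> words = {"你好", "，", "世界"};

  auto sil = token2id.at("sil");
  auto eos = token2id.at("eos");

  std::vector<int64_t> ans;
  ans.push_back(sil);  // leading silence
  for (const auto &w : words) {
    if (punctuations.count(w)) {
      ans.push_back(sil);  // punctuation becomes a pause
      continue;
    }
    auto it = word2ids.find(w);
    if (it == word2ids.end()) {
      continue;  // OOV words are skipped (the real code also logs them)
    }
    ans.insert(ans.end(), it->second.begin(), it->second.end());
  }
  ans.push_back(sil);  // trailing silence
  ans.push_back(eos);  // end of sentence

  for (auto i : ans) std::printf("%lld ", static_cast<long long>(i));
  std::printf("\n");  // expected: 0 10 11 0 12 13 0 1
  return 0;
}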
@@ -126,6 +134,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
   ToLowerCase(&text);
   std::vector<std::string> words = SplitUtf8(text);
+  int32_t blank = token2id_.at(" ");
   std::vector<int64_t> ans;
   for (const auto &w : words) {
@@ -141,12 +150,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
     const auto &token_ids = word2ids_.at(w);
     ans.insert(ans.end(), token_ids.begin(), token_ids.end());
-    if (blank_ != -1) {
-      ans.push_back(blank_);
-    }
+    ans.push_back(blank);
   }
-  if (blank_ != -1 && !ans.empty()) {
+  if (!ans.empty()) {
     // remove the last blank
     ans.resize(ans.size() - 1);
   }
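For the English path, the per-object blank_ member and its -1 fallback go away (see the next two hunks); the ID of the space token is now looked up locally with token2id_.at(" "), appended after every word, and the final trailing blank is trimmed. A standalone sketch of this join-with-blank pattern, with invented IDs and lexicon entries:

#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Invented stand-ins for token2id_ and word2ids_.
  std::unordered_map<std::string, int64_t> token2id = {{" ", 3}};
  std::unordered_map<std::string, std::vector<int64_t>> word2ids = {
      {"hello", {20, 21}}, {"world", {22, 23}}};

  std::vector<std::string> words = {"hello", "world"};
  int64_t blank = token2id.at(" ");  // std::unordered_map::at throws if " " is absent

  std::vector<int64_t> ans;
  for (const auto &w : words) {
    auto it = word2ids.find(w);
    if (it == word2ids.end()) {
      continue;  // OOV words are skipped
    }
    ans.insert(ans.end(), it->second.begin(), it->second.end());
    ans.push_back(blank);  // blank after every word
  }
  if (!ans.empty()) {
    ans.resize(ans.size() - 1);  // drop the blank after the last word
  }

  for (auto i : ans) std::printf("%lld ", static_cast<long long>(i));
  std::printf("\n");  // expected: 20 21 3 22 23
  return 0;
}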
@@ -156,9 +163,6 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
 void Lexicon::InitTokens(const std::string &tokens) {
   token2id_ = ReadTokens(tokens);
-  if (token2id_.count(" ")) {
-    blank_ = token2id_.at(" ");
-  }
 }

 void Lexicon::InitLanguage(const std::string &_lang) {
... ...
@@ -44,7 +44,6 @@ class Lexicon {
   std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
   std::unordered_set<std::string> punctuations_;
   std::unordered_map<std::string, int32_t> token2id_;
-  int32_t blank_ = -1;  // ID for the blank token
   Language language_;
   //
 };
... ...
@@ -25,6 +25,23 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
   GeneratedAudio Generate(const std::string &text,
                           int64_t sid = 0) const override {
+    int32_t num_speakers = model_->NumSpeakers();
+    if (num_speakers == 0 && sid != 0) {
+      SHERPA_ONNX_LOGE(
+          "This is a single-speaker model and supports only sid 0. Given sid: "
+          "%d",
+          sid);
+      return {};
+    }
+
+    if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) {
+      SHERPA_ONNX_LOGE(
+          "This model contains only %d speakers. sid should be in the range "
+          "[%d, %d]. Given: %d",
+          num_speakers, 0, num_speakers - 1, sid);
+      return {};
+    }
+
     std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
     if (x.empty()) {
       SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
... ...
@@ -85,6 +85,7 @@ class OfflineTtsVitsModel::Impl {
   std::string Punctuations() const { return punctuations_; }
   std::string Language() const { return language_; }
+  int32_t NumSpeakers() const { return num_speakers_; }

  private:
   void Init(void *model_data, size_t model_data_length) {
@@ -107,7 +108,7 @@ class OfflineTtsVitsModel::Impl {
     Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
     SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
     SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
-    SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
+    SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers");
     SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
     SHERPA_ONNX_READ_META_DATA_STR(language_, "language");
   }
@@ -128,7 +129,7 @@ class OfflineTtsVitsModel::Impl {
   int32_t sample_rate_;
   int32_t add_blank_;
-  int32_t n_speakers_;
+  int32_t num_speakers_;
   std::string punctuations_;
   std::string language_;
 };
@@ -152,4 +153,8 @@ std::string OfflineTtsVitsModel::Punctuations() const {
 std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); }

+int32_t OfflineTtsVitsModel::NumSpeakers() const {
+  return impl_->NumSpeakers();
+}
+
 }  // namespace sherpa_onnx
... ...
@@ -39,6 +39,7 @@ class OfflineTtsVitsModel {
   std::string Punctuations() const;
   std::string Language() const;
+  int32_t NumSpeakers() const;

  private:
   class Impl;
... ...
@@ -81,6 +81,12 @@ or detailes.
   sherpa_onnx::OfflineTts tts(config);
   auto audio = tts.Generate(po.GetArg(1), sid);
+  if (audio.samples.empty()) {
+    fprintf(
+        stderr,
+        "Error in generating audios. Please read previous error messages.\n");
+    exit(EXIT_FAILURE);
+  }

   bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
                                    audio.samples.data(), audio.samples.size());
... ...