#!/usr/bin/env python3
# Copyright    2025  Xiaomi Corp.        (authors: Fangjun Kuang)


import re
import time
from typing import Dict, List

import jieba
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
from misaki import zh

try:
    from piper_phonemize import phonemize_espeak
except Exception as ex:
    raise RuntimeError(
        f"{ex}\nPlease run\n"
        "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
    )


def show(filename):
    """Print the input and output signatures of the ONNX model at ``filename``.

    Inputs are printed first, then a ``-----`` separator line, then outputs.
    """
    opts = ort.SessionOptions()
    # Raise the minimum log severity so onnxruntime's lower-level log noise
    # is suppressed while inspecting the model.
    opts.log_severity_level = 3
    session = ort.InferenceSession(filename, opts)

    for node_arg in session.get_inputs():
        print(node_arg)
    print("-----")
    for node_arg in session.get_outputs():
        print(node_arg)


"""
NodeArg(name='tokens', type='tensor(int64)', shape=[1, 'sequence_length'])
NodeArg(name='style', type='tensor(float)', shape=[1, 256])
NodeArg(name='speed', type='tensor(float)', shape=[1])
-----
NodeArg(name='audio', type='tensor(float)', shape=['audio_length'])
"""


def load_voices(speaker_names: List[str], dim: List[int], voices_bin: str):
    """Load per-speaker style embeddings from a raw binary file.

    The file's bytes are reinterpreted as native-endian float32 and reshaped
    to ``(len(speaker_names), *dim)``; row ``k`` belongs to
    ``speaker_names[k]``.

    Args:
      speaker_names: Speaker names, in the order their rows appear in the file.
      dim: Shape of a single speaker's embedding.
      voices_bin: Path to the raw float32 voice file.

    Returns:
      A dict mapping each speaker name to its embedding (a view into the
      loaded array), in ``speaker_names`` order.
    """
    raw_bytes = np.fromfile(voices_bin, dtype="uint8")
    table = raw_bytes.view(np.float32).reshape(len(speaker_names), *dim)
    print("embedding.shape", table.shape)
    # Iterating the array yields its first-axis rows, paired with the names.
    return dict(zip(speaker_names, table))


def load_tokens(filename: str) -> Dict[str, int]:
    """Read a token table mapping each token to its integer id.

    Each non-blank line is either ``<token> <id>`` or just ``<id>``. The
    single-field form is the space token: its text is whitespace and so
    disappears when the line is split, leaving only the id.

    Args:
      filename: Path to the UTF-8 token table.

    Returns:
      A dict mapping token text to id.

    Raises:
      ValueError: if a line has more than two fields or a non-integer id.
    """
    ans = dict()
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.strip().split()
            if not fields:
                # Tolerate blank lines instead of failing on them.
                continue
            if len(fields) == 2:
                token, idx = fields
                ans[token] = int(idx)
            elif len(fields) == 1:
                ans[" "] = int(fields[0])
            else:
                # Explicit exception rather than `assert`, which `-O` strips.
                raise ValueError(
                    f"Invalid line in {filename}: expected '<token> <id>' or "
                    f"'<id>', got {line!r}"
                )
    return ans


def load_lexicon(filename: str) -> Dict[str, str]:
    """Load one or more lexicon files into a word -> token-string mapping.

    Each non-blank line is ``<word> <tok1> <tok2> ...``. The tokens are
    concatenated with no separator, so a value is one string whose
    characters are the individual tokens. If a word appears in several
    files, the entry from the later file wins.

    Args:
      filename: Comma-separated list of UTF-8 lexicon paths.

    Returns:
      A dict mapping each word to its concatenated token string.

    Raises:
      ValueError: if a non-blank line has a word but no tokens.
    """
    ans = dict()
    for lexicon in filename.split(","):
        print(lexicon)
        with open(lexicon, encoding="utf-8") as f:
            for line in f:
                # Split on any whitespace so space- and tab-separated
                # lexicons both parse.
                fields = line.split(maxsplit=1)
                if not fields:
                    continue  # blank line
                if len(fields) != 2:
                    raise ValueError(
                        f"Invalid line in {lexicon}: no tokens for {line!r}"
                    )
                w, tokens = fields
                ans[w] = "".join(tokens.split())
    return ans


class OnnxModel:
    """Mixed Chinese/English text front end plus runner for a Kokoro ONNX model.

    Text is turned into a string whose characters are phoneme tokens.
    ASCII runs are treated as English: lexicon lookup first, then espeak-ng
    (``phonemize_espeak``) as a fallback. CJK runs are segmented with jieba
    and looked up word by word, then character by character.
    """

    # Matches runs of CJK Unified Ideographs (U+4E00..U+9FFF) or runs of
    # ASCII. Characters in neither range match nothing and are dropped.
    # NOTE(review): this includes full-width punctuation such as "，" and "。".
    _SEGMENT_RE = re.compile("[\u4E00-\u9FFF]+|[\u0000-\u007f]+")

    # Punctuation peeled off the start/end of an English word and emitted as
    # its own token. Only ASCII members can reach the English path, because
    # non-ASCII characters never appear inside an ASCII segment.
    _PUNCTUATIONS = ';:,.!?-…()"“”'

    def __init__(self, model_filename: str, tokens: str, lexicon: str, voices_bin: str):
        """Create a CPU inference session and load the front-end resources.

        Args:
          model_filename: Path to the ONNX model. Its custom metadata must
            provide ``style_dim`` and ``speaker_names``, and may provide
            ``sample_rate``.
          tokens: Path to the token table (see ``load_tokens``).
          lexicon: Comma-separated lexicon paths (see ``load_lexicon``).
          voices_bin: Path to the raw voice-embedding file (see ``load_voices``).
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        self.token2id = load_tokens(tokens)
        self.word2tokens = load_lexicon(lexicon)

        meta = self.model.get_modelmeta().custom_metadata_map
        print(meta)
        dim = list(map(int, meta["style_dim"].split(",")))
        speaker_names = meta["speaker_names"].split(",")
        self.voices = load_voices(
            speaker_names=speaker_names, dim=dim, voices_bin=voices_bin
        )
        print(list(self.voices.keys()))

        # Use the rate recorded in the model metadata. Previously it was
        # immediately overwritten by a hard-coded 24000; that value now
        # serves only as the fallback when the key is missing.
        self.sample_rate = int(meta.get("sample_rate", 24000))

        # The style table is indexed by token count, so the largest usable
        # count is the table's first dimension minus one.
        self.max_len = self.voices[next(iter(self.voices))].shape[0] - 1

    def __call__(self, text: str, voice: str, speed: float = 1.0):
        """Synthesize ``text`` with the given voice.

        Args:
          text: Input text, which may mix Chinese and English.
          voice: A key of ``self.voices`` (one of the metadata speaker names).
          speed: Value fed to the model's speed input (default 1.0).

        Returns:
          The model's first output (the generated audio).

        Raises:
          KeyError: if ``voice`` is not a known speaker.
        """
        tokens = self._text_to_tokens(text)
        token_ids = self._tokens_to_ids(tokens)[: self.max_len]

        # The style vector is selected by the (truncated) token count.
        style = self.voices[voice][len(token_ids)]

        # Pad the sequence with id 0 on both ends.
        token_ids = np.array([[0, *token_ids, 0]], dtype=np.int64)
        speed_arr = np.array([speed], dtype=np.float32)

        inputs = self.model.get_inputs()
        audio = self.model.run(
            [self.model.get_outputs()[0].name],
            {
                inputs[0].name: token_ids,
                inputs[1].name: style,
                inputs[2].name: speed_arr,
            },
        )[0]
        return audio

    def _text_to_tokens(self, text: str) -> str:
        """Lower-case ``text`` and convert it to a phoneme-token string."""
        parts = []
        for segment in self._SEGMENT_RE.findall(text.lower()):
            # Each segment is entirely ASCII or entirely CJK, so its first
            # character decides the branch. Use <= so U+007F, which the
            # regex's ASCII class includes, is treated as ASCII.
            if ord(segment[0]) <= 0x7F:
                parts.append(self._english_tokens(segment))
            else:
                parts.append(self._chinese_tokens(segment))
        return "".join(parts)

    def _english_tokens(self, segment: str) -> str:
        """Convert an ASCII segment to tokens, one whitespace-separated word at a time."""
        punctuations = self._PUNCTUATIONS
        out = []
        for word in segment.split():
            # Each leading punctuation mark becomes its own token plus a space.
            start = 0
            while start < len(word) and word[start] in punctuations:
                out.append(word[start] + " ")
                start += 1
            word = word[start:]
            if not word:
                continue  # the word was punctuation only

            # Keep all trailing punctuation and append it after the word's
            # phonemes. Previously a stem missing from the lexicon was
            # silently dropped here instead of falling back to espeak-ng.
            end = len(word)
            while end > 0 and word[end - 1] in punctuations:
                end -= 1
            core, suffix = word[:end], word[end:]
            out.append(self._word_tokens(core) + suffix + " ")
        return "".join(out)

    def _word_tokens(self, word: str) -> str:
        """Return lexicon tokens for ``word``, falling back to espeak-ng."""
        if word in self.word2tokens:
            return self.word2tokens[word]
        print(f"Use espeak-ng for word {word}")
        # Join every returned sentence; an empty result yields "" instead
        # of an IndexError.
        sentences = phonemize_espeak(word, "en-us")
        return "".join("".join(s) for s in sentences)

    def _chinese_tokens(self, segment: str) -> str:
        """Convert a CJK segment to tokens via jieba words, then single characters."""
        out = []
        for word in jieba.cut(segment):
            if word in self.word2tokens:
                out.append(self.word2tokens[word])
                continue
            for ch in word:
                if ch in self.word2tokens:
                    out.append(self.word2tokens[ch])
                else:
                    print(f"skip {ch}")
        return "".join(out)

    def _tokens_to_ids(self, tokens: str) -> List[int]:
        """Map each token character to its id, skipping unknown tokens.

        Tokens absent from the token table (e.g. an espeak-ng symbol the
        model does not know) are reported and skipped instead of raising
        KeyError.
        """
        ids = []
        for t in tokens:
            if t in self.token2id:
                ids.append(self.token2id[t])
            else:
                print(f"skip unknown token {t!r}")
        return ids


def main():
    """Synthesize a mixed Chinese/English sentence, save it, and report the RTF."""
    m = OnnxModel(
        model_filename="./kokoro.onnx",
        tokens="./tokens.txt",
        lexicon="./lexicon-gb-en.txt,./lexicon-zh.txt",
        voices_bin="./voices.bin",
    )
    # OnnxModel lower-cases its input itself, so no pre-processing is needed.
    text = "来听一听, 这个是什么口音? How are you doing? Are you ok? Thank you! 你觉得中英文说得如何呢?"

    voice = "bf_alice"
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    audio = m(text, voice=voice)
    elapsed_seconds = time.perf_counter() - start

    audio_duration = len(audio) / m.sample_rate
    # Guard against an empty waveform so the RTF cannot divide by zero.
    real_time_factor = (
        elapsed_seconds / audio_duration if audio_duration > 0 else float("inf")
    )

    filename = f"kokoro_v1.0_{voice}_zh_en.wav"
    sf.write(
        filename,
        audio,
        samplerate=m.sample_rate,
        subtype="PCM_16",
    )
    print(f" Saved to {filename}")
    print(f" Elapsed seconds: {elapsed_seconds:.3f}")
    print(f" Audio duration in seconds: {audio_duration:.3f}")
    print(f" RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")


if __name__ == "__main__":
    main()