Support RK NPU for SenseVoice non-streaming ASR models (#2589)
This PR adds RK NPU support for SenseVoice non-streaming ASR models by implementing a new RKNN backend with greedy CTC decoding.

- Adds offline RKNN implementation for SenseVoice models including model loading, feature processing, and CTC decoding
- Introduces export tools to convert SenseVoice models from PyTorch to ONNX and then to RKNN format
- Implements provider-aware validation to prevent mismatched model and provider usage
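As a rough sketch, the conversion pipeline introduced here is driven by the two scripts added under scripts/sense-voice/rknn/ (the output file name follows the model-{N}-seconds.onnx pattern used in export-onnx.py below; the .rknn output name is the user's choice):

    # Step 1: PyTorch -> ONNX with a fixed input length (RKNN needs static shapes)
    python3 scripts/sense-voice/rknn/export-onnx.py --input-len-in-seconds 10

    # Step 2: ONNX -> RKNN for a specific Rockchip NPU
    python3 scripts/sense-voice/rknn/export-rknn.py \
      --target-platform rk3588 \
      --in-model model-10-seconds.onnx \
      --out-model model-10-seconds.rknn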
Showing 18 changed files with 1737 additions and 8 deletions.
.gitignore

@@ -152,3 +152,7 @@ vocab.json
 *.so
 sherpa-onnx-streaming-t-one-russian-2025-09-08
 sherpa-onnx-wenetspeech-yue-u2pp-conformer-ctc-zh-en-cantonese-int8-2025-09-10
+am.mvn
+*bpe.model
+config.yaml
+configuration.json
@@ -118,8 +118,11 @@ def display_params(params):
     os.system(f"cat {params['config']}")


+@torch.no_grad()
 def main():
     model, params = SenseVoiceSmall.from_pretrained(model="iic/SenseVoiceSmall", device="cpu")
+    model.eval()
+
     display_params(params)

     generate_tokens(params)
scripts/sense-voice/rknn/export-onnx.py (new file, mode 100755)
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
import os
from typing import Any, Dict, List, Tuple

import onnx
import sentencepiece as spm
import torch

from torch_model import SenseVoiceSmall


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--input-len-in-seconds",
        type=int,
        required=True,
        help="""RKNN does not support dynamic shape, so we need to hard-code
        how long the model can process.
        """,
    )
    return parser.parse_args()


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


def load_cmvn(filename) -> Tuple[List[float], List[float]]:
    neg_mean = None
    inv_stddev = None

    with open(filename) as f:
        for line in f:
            if not line.startswith("<LearnRateCoef>"):
                continue
            t = line.split()[3:-1]

            if neg_mean is None:
                neg_mean = list(map(lambda x: float(x), t))
            else:
                inv_stddev = list(map(lambda x: float(x), t))

    return neg_mean, inv_stddev


def generate_tokens(sp):
    with open("tokens.txt", "w", encoding="utf-8") as f:
        for i in range(sp.vocab_size()):
            f.write(f"{sp.id_to_piece(i)} {i}\n")
    print("saved to tokens.txt")


@torch.no_grad()
def main():
    args = get_args()
    print(vars(args))

    sp = spm.SentencePieceProcessor()
    sp.load("./chn_jpn_yue_eng_ko_spectok.bpe.model")
    vocab_size = sp.vocab_size()
    generate_tokens(sp)

    print("loading model")

    state_dict = torch.load("./model.pt")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    neg_mean, inv_stddev = load_cmvn("./am.mvn")

    neg_mean = torch.tensor(neg_mean, dtype=torch.float32)
    inv_stddev = torch.tensor(inv_stddev, dtype=torch.float32)

    model = SenseVoiceSmall(neg_mean=neg_mean, inv_stddev=inv_stddev)
    model.load_state_dict(state_dict)
    model.eval()
    del state_dict

    lfr_window_size = 7
    lfr_window_shift = 6

    # frame shift is 10ms, 1 second has about 100 feature frames
    input_len_in_seconds = int(args.input_len_in_seconds)
    num_frames = input_len_in_seconds * 100
    print("num_frames", num_frames)

    # num_input_frames is an approximate number
    num_input_frames = int(num_frames / lfr_window_shift + 0.5)
    print("num_input_frames", num_input_frames)

    x = torch.randn(1, num_input_frames, 560, dtype=torch.float32)

    language = 3
    text_norm = 15
    prompt = torch.tensor([language, 1, 2, text_norm], dtype=torch.int32)

    opset_version = 13
    filename = f"model-{input_len_in_seconds}-seconds.onnx"
    torch.onnx.export(
        model,
        (x, prompt),
        filename,
        opset_version=opset_version,
        input_names=["x", "prompt"],
        output_names=["logits"],
        dynamic_axes={},
    )

    model_author = os.environ.get("model_author", "iic")
    comment = os.environ.get("comment", "iic/SenseVoiceSmall")
    url = os.environ.get("url", "https://huggingface.co/FunAudioLLM/SenseVoiceSmall")

    meta_data = {
        "lfr_window_size": lfr_window_size,
        "lfr_window_shift": lfr_window_shift,
        "num_input_frames": num_input_frames,
        "normalize_samples": 0,  # input should be in the range [-32768, 32767]
        "model_type": "sense_voice_ctc",
        "version": "1",
        "model_author": model_author,
        "maintainer": "k2-fsa",
        "vocab_size": vocab_size,
        "comment": comment,
        "lang_auto": model.lid_dict["auto"],
        "lang_zh": model.lid_dict["zh"],
        "lang_en": model.lid_dict["en"],
        "lang_yue": model.lid_dict["yue"],  # cantonese
        "lang_ja": model.lid_dict["ja"],
        "lang_ko": model.lid_dict["ko"],
        "lang_nospeech": model.lid_dict["nospeech"],
        "with_itn": model.textnorm_dict["withitn"],
        "without_itn": model.textnorm_dict["woitn"],
        "url": url,
    }
    add_meta_data(filename=filename, meta_data=meta_data)


if __name__ == "__main__":
    torch.manual_seed(20250717)
    main()
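As a worked example of the frame arithmetic in main() above (not part of the script): with --input-len-in-seconds 10, the 10 ms frame shift yields about 1000 fbank frames, and the LFR window shift of 6 reduces that to 167 model input frames:

    input_len_in_seconds = 10
    num_frames = input_len_in_seconds * 100       # 10 ms frame shift -> ~100 frames/second
    num_input_frames = int(num_frames / 6 + 0.5)  # LFR window shift = 6
    print(num_input_frames)                       # 167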
scripts/sense-voice/rknn/export-rknn.py (new file, mode 100755)
#!/usr/bin/env python3
# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)

import argparse
import logging
from pathlib import Path

from rknn.api import RKNN

logging.basicConfig(level=logging.WARNING)

g_platforms = [
    #  "rv1103",
    #  "rv1103b",
    #  "rv1106",
    #  "rk2118",
    "rk3562",
    "rk3566",
    "rk3568",
    "rk3576",
    "rk3588",
]


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--target-platform",
        type=str,
        required=True,
        help=f"Supported values are: {','.join(g_platforms)}",
    )

    parser.add_argument(
        "--in-model",
        type=str,
        required=True,
        help="Path to the input onnx model",
    )

    parser.add_argument(
        "--out-model",
        type=str,
        required=True,
        help="Path to the output rknn model",
    )

    return parser


def get_meta_data(model: str):
    import onnxruntime

    session_opts = onnxruntime.SessionOptions()
    session_opts.inter_op_num_threads = 1
    session_opts.intra_op_num_threads = 1

    m = onnxruntime.InferenceSession(
        model,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    for i in m.get_inputs():
        print(i)

    print("-----")

    for i in m.get_outputs():
        print(i)
    print()

    meta = m.get_modelmeta().custom_metadata_map
    s = ""
    sep = ""
    for key, value in meta.items():
        if key in ("neg_mean", "inv_stddev"):
            continue
        s = s + sep + f"{key}={value}"
        sep = ";"
    assert len(s) < 1024, len(s)

    print("len(s)", len(s), s)

    return s


def export_rknn(rknn, filename):
    ret = rknn.export_rknn(filename)
    if ret != 0:
        exit(f"Export rknn model to {filename} failed!")


def init_model(filename: str, target_platform: str, custom_string=None):
    rknn = RKNN(verbose=False)

    rknn.config(
        optimization_level=0,
        target_platform=target_platform,
        custom_string=custom_string,
    )
    if not Path(filename).is_file():
        exit(f"{filename} does not exist")

    ret = rknn.load_onnx(model=filename)
    if ret != 0:
        exit(f"Load model {filename} failed!")

    ret = rknn.build(do_quantization=False)
    if ret != 0:
        exit(f"Build model {filename} failed!")

    return rknn


class RKNNModel:
    def __init__(
        self,
        model: str,
        target_platform: str,
    ):
        meta = get_meta_data(model)
        print(meta)

        self.model = init_model(
            model,
            target_platform=target_platform,
            custom_string=meta,
        )

    def export_rknn(self, model):
        export_rknn(self.model, model)

    def release(self):
        self.model.release()


def main():
    args = get_parser().parse_args()
    print(vars(args))

    model = RKNNModel(
        model=args.in_model,
        target_platform=args.target_platform,
    )

    model.export_rknn(
        model=args.out_model,
    )

    model.release()


if __name__ == "__main__":
    main()
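Note that get_meta_data() above flattens the ONNX metadata into a single key=value;key=value string, which is stored as the RKNN model's custom_string (the assert keeps it under RKNN's 1024-byte limit). A minimal sketch of turning such a string back into a dict, as a consumer of the .rknn file would need to do (illustrative helper, not part of this PR):

    def parse_custom_string(s: str) -> dict:
        # "lfr_window_size=7;lfr_window_shift=6;..." -> {"lfr_window_size": "7", ...}
        return dict(kv.split("=", 1) for kv in s.split(";") if kv)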
scripts/sense-voice/rknn/test_onnx.py (new file, mode 100755)
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
from typing import Tuple

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to model.onnx",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--wave",
        type=str,
        required=True,
        help="The input wave to be recognized",
    )

    parser.add_argument(
        "--language",
        type=str,
        default="auto",
        help="The language of the input wav file. Supported values: zh, en, ja, ko, yue, auto",
    )

    parser.add_argument(
        "--use-itn",
        type=int,
        default=0,
        help="1 to use inverse text normalization. 0 to not use inverse text normalization",
    )

    return parser.parse_args()


class OnnxModel:
    def __init__(self, filename):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.model.get_modelmeta().custom_metadata_map

        self.window_size = int(meta["lfr_window_size"])  # lfr_m
        self.window_shift = int(meta["lfr_window_shift"])  # lfr_n

        lang_zh = int(meta["lang_zh"])
        lang_en = int(meta["lang_en"])
        lang_ja = int(meta["lang_ja"])
        lang_ko = int(meta["lang_ko"])
        lang_yue = int(meta["lang_yue"])
        lang_auto = int(meta["lang_auto"])

        self.lang_id = {
            "zh": lang_zh,
            "en": lang_en,
            "ja": lang_ja,
            "ko": lang_ko,
            "yue": lang_yue,
            "auto": lang_auto,
        }
        self.with_itn = int(meta["with_itn"])
        self.without_itn = int(meta["without_itn"])

        self.max_len = self.model.get_inputs()[0].shape[1]

    def __call__(self, x, prompt):
        logits = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: prompt.numpy(),
            },
        )[0]

        return torch.from_numpy(logits)


def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    data, sample_rate = sf.read(
        filename,
        always_2d=True,
        dtype="float32",
    )
    data = data[:, 0]  # use only the first channel
    samples = np.ascontiguousarray(data)
    return samples, sample_rate


def load_tokens(filename):
    ans = dict()
    i = 0
    with open(filename, encoding="utf-8") as f:
        for line in f:
            ans[i] = line.strip().split()[0]
            i += 1
    return ans


def compute_feat(
    samples,
    sample_rate,
    max_len: int,
    window_size: int = 7,  # lfr_m
    window_shift: int = 6,  # lfr_n
):
    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.window_type = "hamming"
    opts.frame_opts.samp_freq = sample_rate
    opts.mel_opts.num_bins = 80

    online_fbank = knf.OnlineFbank(opts)
    online_fbank.accept_waveform(sample_rate, (samples * 32768).tolist())
    online_fbank.input_finished()

    features = np.stack(
        [online_fbank.get_frame(i) for i in range(online_fbank.num_frames_ready)]
    )
    assert features.data.contiguous is True
    assert features.dtype == np.float32, features.dtype

    T = (features.shape[0] - window_size) // window_shift + 1
    features = np.lib.stride_tricks.as_strided(
        features,
        shape=(T, features.shape[1] * window_size),
        strides=((window_shift * features.shape[1]) * 4, 4),
    )

    print("features.shape", features.shape)

    if features.shape[0] > max_len:
        features = features[:max_len]
    elif features.shape[0] < max_len:
        features = np.pad(
            features,
            ((0, max_len - features.shape[0]), (0, 0)),
            mode="constant",
            constant_values=0,
        )

    print("features.shape", features.shape)

    return features


def main():
    args = get_args()
    print(vars(args))
    samples, sample_rate = load_audio(args.wave)
    if sample_rate != 16000:
        import librosa

        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000

    model = OnnxModel(filename=args.model)

    features = compute_feat(
        samples=samples,
        sample_rate=sample_rate,
        max_len=model.max_len,
        window_size=model.window_size,
        window_shift=model.window_shift,
    )

    features = torch.from_numpy(features).unsqueeze(0)

    language = model.lang_id["auto"]
    if args.language in model.lang_id:
        language = model.lang_id[args.language]
    else:
        print(f"Invalid language: '{args.language}'")
        print("Use auto")

    if args.use_itn:
        text_norm = model.with_itn
    else:
        text_norm = model.without_itn

    prompt = torch.tensor([language, 1, 2, text_norm], dtype=torch.int32)

    logits = model(
        x=features,
        prompt=prompt,
    )

    idx = logits.squeeze(0).argmax(dim=-1)
    # idx is of shape (T,)
    idx = torch.unique_consecutive(idx)

    blank_id = 0
    idx = idx[idx != blank_id].tolist()

    tokens = load_tokens(args.tokens)
    text = "".join([tokens[i] for i in idx])

    text = text.replace("▁", " ")
    print(text)


if __name__ == "__main__":
    main()
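A sample invocation of this test script (file names are illustrative; tokens.txt and the ONNX model come from export-onnx.py above):

    python3 scripts/sense-voice/rknn/test_onnx.py \
      --model model-10-seconds.onnx \
      --tokens tokens.txt \
      --wave test.wav \
      --language auto \
      --use-itn 0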
scripts/sense-voice/rknn/torch_model.py (new file, mode 100644)
# This file is modified from
# https://github.com/modelscope/FunASR/blob/main/funasr/models/sense_voice/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalPositionEncoder(nn.Module):
    """ """

    def __init__(self, d_model=80, dropout_rate=0.1):
        # nn.Module must be initialized so this can be registered as a submodule
        super().__init__()

    def encode(
        self,
        positions: torch.Tensor = None,
        depth: int = None,
        dtype: torch.dtype = torch.float32,
    ):
        """
        Args:
          positions: (batch_size, )
        """
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        device = positions.device
        log_timescale_increment = torch.log(
            torch.tensor([10000], dtype=dtype, device=device)
        ) / (depth / 2 - 1)
        inv_timescales = torch.exp(
            torch.arange(depth / 2, device=device).type(dtype)
            * (-log_timescale_increment)
        )
        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
            inv_timescales, [1, 1, -1]
        )
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)

    def forward(self, x):
        batch_size, timesteps, input_dim = x.size()
        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)

        return x + position_encoding


class PositionwiseFeedForward(nn.Module):
    """Positionwise feed forward layer.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=None):
        super().__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        if activation is None:
            activation = torch.nn.ReLU()
        self.activation = activation

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))


class MultiHeadedAttentionSANM(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(
        self,
        n_head,
        in_feat,
        n_feat,
        dropout_rate,
        kernel_size,
        sanm_shfit=0,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    ):
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # padding
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)

        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder

            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)

            min_value = -float(
                "inf"
            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory


class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(
        self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None
    ):
        """Compute encoded features.

        Args:
            x_input (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).

        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
                return x, mask
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder


class LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()

    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)


class SenseVoiceEncoderSmall(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_size = 80 * 7
        self.output_size = 512
        self.attention_heads = 4
        self.linear_units = 2048
        self.num_blocks = 50
        self.tp_blocks = 20
        self.input_layer = "pe"
        self.pos_enc_class = "SinusoidalPositionEncoder"
        self.normalize_before = True
        self.kernel_size = 11
        self.sanm_shfit = 0
        self.concat_after = False
        self.positionwise_layer_type = "linear"
        self.positionwise_conv_kernel_size = 1
        self.padding_idx = -1
        self.selfattention_layer_type = "sanm"
        self.dropout_rate = 0.1
        self.attention_dropout_rate = 0.1

        self._output_size = self.output_size

        self.embed = SinusoidalPositionEncoder()

        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            self.output_size,
            self.linear_units,
            self.dropout_rate,
        )

        encoder_selfattn_layer = MultiHeadedAttentionSANM
        encoder_selfattn_layer_args0 = (
            self.attention_heads,
            self.input_size,
            self.output_size,
            self.attention_dropout_rate,
            self.kernel_size,
            self.sanm_shfit,
        )
        encoder_selfattn_layer_args = (
            self.attention_heads,
            self.output_size,
            self.output_size,
            self.attention_dropout_rate,
            self.kernel_size,
            self.sanm_shfit,
        )

        self.encoders0 = nn.ModuleList(
            [
                EncoderLayerSANM(
                    self.input_size,
                    self.output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                    positionwise_layer(*positionwise_layer_args),
                    self.dropout_rate,
                )
                for i in range(1)
            ]
        )

        self.encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    self.output_size,
                    self.output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    self.dropout_rate,
                )
                for i in range(self.num_blocks - 1)
            ]
        )

        self.tp_encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    self.output_size,
                    self.output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    self.dropout_rate,
                )
                for i in range(self.tp_blocks)
            ]
        )

        self.after_norm = LayerNorm(self.output_size)

        self.tp_norm = LayerNorm(self.output_size)

    def forward(
        self,
        xs_pad: torch.Tensor,
    ):
        masks = None

        xs_pad *= self.output_size**0.5

        xs_pad = self.embed(xs_pad)

        # forward encoder1
        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.after_norm(xs_pad)

        for layer_idx, encoder_layer in enumerate(self.tp_encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.tp_norm(xs_pad)
        return xs_pad


class CTC(nn.Module):
    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        ctc_type: str = "builtin",
        reduce: bool = True,
        ignore_nan_grad: bool = True,
        extra_linear: bool = True,
    ):
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate

        if extra_linear:
            self.ctc_lo = torch.nn.Linear(eprojs, odim)
        else:
            self.ctc_lo = None

    def softmax(self, hs_pad):
        """softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
        """
        if self.ctc_lo is not None:
            return F.softmax(self.ctc_lo(hs_pad), dim=2)
        else:
            return F.softmax(hs_pad, dim=2)

    def log_softmax(self, hs_pad):
        """log_softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        if self.ctc_lo is not None:
            return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
        else:
            return F.log_softmax(hs_pad, dim=2)

    def argmax(self, hs_pad):
        """argmax of frame activations

        Args:
            torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        if self.ctc_lo is not None:
            return torch.argmax(self.ctc_lo(hs_pad), dim=2)
        else:
            return torch.argmax(hs_pad, dim=2)


class SenseVoiceSmall(nn.Module):
    def __init__(self, neg_mean: torch.Tensor, inv_stddev: torch.Tensor):
        super().__init__()
        self.sos = 1
        self.eos = 2
        self.length_normalized_loss = True
        self.ignore_id = -1
        self.blank_id = 0
        self.input_size = 80 * 7
        self.vocab_size = 25055

        self.neg_mean = neg_mean.unsqueeze(0).unsqueeze(0)
        self.inv_stddev = inv_stddev.unsqueeze(0).unsqueeze(0)

        self.lid_dict = {
            "auto": 0,
            "zh": 3,
            "en": 4,
            "yue": 7,
            "ja": 11,
            "ko": 12,
            "nospeech": 13,
        }
        self.lid_int_dict = {
            24884: 3,
            24885: 4,
            24888: 7,
            24892: 11,
            24896: 12,
            24992: 13,
        }
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}

        self.emo_dict = {
            "unk": 25009,
            "happy": 25001,
            "sad": 25002,
            "angry": 25003,
            "neutral": 25004,
        }

        self.encoder = SenseVoiceEncoderSmall()
        self.ctc = CTC(
            odim=self.vocab_size,
            encoder_output_size=self.encoder.output_size,
        )
        self.embed = torch.nn.Embedding(
            7 + len(self.lid_dict) + len(self.textnorm_dict), self.input_size
        )

    def forward(self, x, prompt):
        input_query = self.embed(prompt).unsqueeze(0)

        # for export, we always assume x and self.neg_mean are on CPU
        x = (x + self.neg_mean) * self.inv_stddev
        x = torch.cat((input_query, x), dim=1)

        encoder_out = self.encoder(x)
        logits = self.ctc.ctc_lo(encoder_out)

        return logits
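For a quick shape sanity check of this model, a minimal sketch with randomly initialized weights (in practice, load model.pt and the CMVN stats as export-onnx.py does; the zero/one CMVN tensors here are placeholders):

    import torch
    from torch_model import SenseVoiceSmall

    neg_mean = torch.zeros(560)   # placeholder; use load_cmvn("./am.mvn") in practice
    inv_stddev = torch.ones(560)  # placeholder
    model = SenseVoiceSmall(neg_mean=neg_mean, inv_stddev=inv_stddev).eval()

    x = torch.randn(1, 167, 560)  # ~10 s of audio after LFR (see export-onnx.py)
    prompt = torch.tensor([3, 1, 2, 15], dtype=torch.int32)  # zh, sos, eos, woitn
    with torch.no_grad():
        logits = model(x, prompt)
    print(logits.shape)  # the 4 prompt frames are prepended: (1, 171, 25055)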
@@ -173,6 +173,8 @@ list(APPEND sources
 )
 if(SHERPA_ONNX_ENABLE_RKNN)
   list(APPEND sources
+    ./rknn/offline-ctc-greedy-search-decoder-rknn.cc
+    ./rknn/offline-sense-voice-model-rknn.cc
     ./rknn/online-stream-rknn.cc
     ./rknn/online-transducer-greedy-search-decoder-rknn.cc
     ./rknn/online-transducer-modified-beam-search-decoder-rknn.cc
@@ -7,6 +7,7 @@
 
 #include "sherpa-onnx/csrc/file-utils.h"
 #include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/text-utils.h"
 
 namespace sherpa_onnx {
 
@@ -57,10 +58,38 @@ void OfflineModelConfig::Register(ParseOptions *po) {
 }
 
 bool OfflineModelConfig::Validate() const {
+  // For RK NPU, we reinterpret num_threads:
+  //
+  // For RK3588 only
+  //   num_threads == 1  -> Select a core randomly
+  //   num_threads == 0  -> Use NPU core 0
+  //   num_threads == -1 -> Use NPU core 1
+  //   num_threads == -2 -> Use NPU core 2
+  //   num_threads == -3 -> Use NPU core 0 and core 1
+  //   num_threads == -4 -> Use NPU core 0, core 1, and core 2
+  if (provider != "rknn") {
     if (num_threads < 1) {
       SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
       return false;
     }
+    if (!sense_voice.model.empty() && (EndsWith(sense_voice.model, ".rknn"))) {
+      SHERPA_ONNX_LOGE(
+          "--provider is %s, which is not rknn, but you pass a rknn model "
+          "filename. model: '%s'",
+          provider.c_str(), sense_voice.model.c_str());
+      return false;
+    }
+  }
+
+  if (provider == "rknn") {
+    if (!sense_voice.model.empty() && (EndsWith(sense_voice.model, ".onnx"))) {
+      SHERPA_ONNX_LOGE(
+          "--provider is rknn, but you pass an onnx model "
+          "filename. model: '%s'",
+          sense_voice.model.c_str());
+      return false;
+    }
+  }
 
   if (!FileExists(tokens)) {
     SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str());
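For illustration, with this validation in place a mismatched combination is now rejected at configuration time instead of failing inside the runtime. A hypothetical invocation (flag names assumed from sherpa-onnx's existing command-line conventions, not verified against this PR):

    # Hypothetical: an .rknn model while --provider is left at "cpu"
    sherpa-onnx-offline --provider=cpu --sense-voice-model=model.rknn ...
    # fails with: --provider is cpu, which is not rknn, but you pass a rknn model filename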
@@ -35,10 +35,32 @@
 #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 
+#if SHERPA_ONNX_ENABLE_RKNN
+#include "sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h"
+#endif
+
 namespace sherpa_onnx {
 
 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     const OfflineRecognizerConfig &config) {
+  if (config.model_config.provider == "rknn") {
+#if SHERPA_ONNX_ENABLE_RKNN
+    if (config.model_config.sense_voice.model.empty()) {
+      SHERPA_ONNX_LOGE(
+          "Only SenseVoice models are currently supported "
+          "by rknn for non-streaming ASR. Fallback to CPU");
+    } else if (!config.model_config.sense_voice.model.empty()) {
+      return std::make_unique<OfflineRecognizerSenseVoiceRknnImpl>(config);
+    }
+#else
+    SHERPA_ONNX_LOGE(
+        "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
+        "want to use rknn.");
+    SHERPA_ONNX_EXIT(-1);
+    return nullptr;
+#endif
+  }
+
   if (!config.model_config.sense_voice.model.empty()) {
     return std::make_unique<OfflineRecognizerSenseVoiceImpl>(config);
   }
@@ -229,6 +251,24 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
 template <typename Manager>
 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     Manager *mgr, const OfflineRecognizerConfig &config) {
+  if (config.model_config.provider == "rknn") {
+#if SHERPA_ONNX_ENABLE_RKNN
+    if (config.model_config.sense_voice.model.empty()) {
+      SHERPA_ONNX_LOGE(
+          "Only SenseVoice models are currently supported "
+          "by rknn for non-streaming ASR. Fallback to CPU");
+    } else if (!config.model_config.sense_voice.model.empty()) {
+      return std::make_unique<OfflineRecognizerSenseVoiceRknnImpl>(mgr, config);
+    }
+#else
+    SHERPA_ONNX_LOGE(
+        "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
+        "want to use rknn.");
+    SHERPA_ONNX_EXIT(-1);
+    return nullptr;
+#endif
+  }
+
   if (!config.model_config.sense_voice.model.empty()) {
     return std::make_unique<OfflineRecognizerSenseVoiceImpl>(mgr, config);
   }
@@ -11,6 +11,7 @@
 #include <utility>
 #include <vector>
 
+#include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
 #include "sherpa-onnx/csrc/offline-model-config.h"
 #include "sherpa-onnx/csrc/offline-recognizer-impl.h"
@@ -21,7 +22,7 @@
 
 namespace sherpa_onnx {
 
-static OfflineRecognitionResult ConvertSenseVoiceResult(
+OfflineRecognitionResult ConvertSenseVoiceResult(
     const OfflineCtcDecoderResult &src, const SymbolTable &sym_table,
     int32_t frame_shift_ms, int32_t subsampling_factor) {
   OfflineRecognitionResult r;
@@ -72,7 +73,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
     } else {
       SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
                        config.decoding_method.c_str());
-      exit(-1);
+      SHERPA_ONNX_EXIT(-1);
     }
 
     InitFeatConfig();
@@ -93,7 +94,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
     } else {
       SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
                        config.decoding_method.c_str());
-      exit(-1);
+      SHERPA_ONNX_EXIT(-1);
     }
 
     InitFeatConfig();
@@ -37,7 +37,6 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     const OnlineRecognizerConfig &config) {
   if (config.model_config.provider_config.provider == "rknn") {
 #if SHERPA_ONNX_ENABLE_RKNN
-    // Currently, only zipformer v1 is suported for rknn
     if (config.model_config.transducer.encoder.empty() &&
         config.model_config.zipformer2_ctc.model.empty()) {
       SHERPA_ONNX_LOGE(
// sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include "sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

OfflineCtcDecoderResult OfflineCtcGreedySearchDecoderRknn::Decode(
    const float *logits, int32_t num_frames, int32_t vocab_size) {
  OfflineCtcDecoderResult ans;

  int64_t prev_id = -1;

  for (int32_t t = 0; t != num_frames; ++t) {
    auto y = static_cast<int64_t>(std::distance(
        static_cast<const float *>(logits),
        std::max_element(static_cast<const float *>(logits),
                         static_cast<const float *>(logits) + vocab_size)));
    logits += vocab_size;

    if (y != blank_id_ && y != prev_id) {
      ans.tokens.push_back(y);
      ans.timestamps.push_back(t);
    }
    prev_id = y;
  }  // for (int32_t t = 0; ...)

  return ans;
}

}  // namespace sherpa_onnx
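The decoder above implements standard CTC greedy search: take the per-frame argmax, collapse consecutive repeats, and drop blanks. A minimal Python sketch of the same rule, mirroring the C++ loop (illustrative only, not part of this PR):

    import numpy as np

    def greedy_ctc(logits: np.ndarray, blank_id: int = 0):
        """logits: (num_frames, vocab_size) -> (tokens, timestamps)."""
        tokens, timestamps, prev = [], [], -1
        for t, y in enumerate(logits.argmax(axis=-1)):
            if y != blank_id and y != prev:  # skip blanks and consecutive repeats
                tokens.append(int(y))
                timestamps.append(t)
            prev = y
        return tokens, timestamps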
// sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_

#include <vector>

#include "sherpa-onnx/csrc/offline-ctc-decoder.h"

namespace sherpa_onnx {

class OfflineCtcGreedySearchDecoderRknn {
 public:
  explicit OfflineCtcGreedySearchDecoderRknn(int32_t blank_id)
      : blank_id_(blank_id) {}

  OfflineCtcDecoderResult Decode(const float *logits, int32_t num_frames,
                                 int32_t vocab_size);

 private:
  int32_t blank_id_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_
| 1 | +// sherpa-onnx/csrc/offline-recognizer-sense-voice-rknn-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 14 | +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" | ||
| 15 | +#include "sherpa-onnx/csrc/offline-recognizer.h" | ||
| 16 | +#include "sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h" | ||
| 17 | +#include "sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h" | ||
| 18 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 19 | + | ||
| 20 | +namespace sherpa_onnx { | ||
| 21 | + | ||
| 22 | +// defined in ../offline-recognizer-sense-voice-impl.h | ||
| 23 | +OfflineRecognitionResult ConvertSenseVoiceResult( | ||
| 24 | + const OfflineCtcDecoderResult &src, const SymbolTable &sym_table, | ||
| 25 | + int32_t frame_shift_ms, int32_t subsampling_factor); | ||
| 26 | + | ||
| 27 | +class OfflineRecognizerSenseVoiceRknnImpl : public OfflineRecognizerImpl { | ||
| 28 | + public: | ||
| 29 | + explicit OfflineRecognizerSenseVoiceRknnImpl( | ||
| 30 | + const OfflineRecognizerConfig &config) | ||
| 31 | + : OfflineRecognizerImpl(config), | ||
| 32 | + config_(config), | ||
| 33 | + symbol_table_(config_.model_config.tokens), | ||
| 34 | + model_( | ||
| 35 | + std::make_unique<OfflineSenseVoiceModelRknn>(config.model_config)) { | ||
| 36 | + const auto &meta_data = model_->GetModelMetadata(); | ||
| 37 | + if (config.decoding_method == "greedy_search") { | ||
| 38 | + decoder_ = std::make_unique<OfflineCtcGreedySearchDecoderRknn>( | ||
| 39 | + meta_data.blank_id); | ||
| 40 | + } else { | ||
| 41 | + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", | ||
| 42 | + config.decoding_method.c_str()); | ||
| 43 | + SHERPA_ONNX_EXIT(-1); | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + InitFeatConfig(); | ||
| 47 | + } | ||
| 48 | + | ||
| 49 | + template <typename Manager> | ||
| 50 | + OfflineRecognizerSenseVoiceRknnImpl(Manager *mgr, | ||
| 51 | + const OfflineRecognizerConfig &config) | ||
| 52 | + : OfflineRecognizerImpl(mgr, config), | ||
| 53 | + config_(config), | ||
| 54 | + symbol_table_(mgr, config_.model_config.tokens), | ||
| 55 | + model_(std::make_unique<OfflineSenseVoiceModelRknn>( | ||
| 56 | + mgr, config.model_config)) { | ||
| 57 | + const auto &meta_data = model_->GetModelMetadata(); | ||
| 58 | + if (config.decoding_method == "greedy_search") { | ||
| 59 | + decoder_ = std::make_unique<OfflineCtcGreedySearchDecoderRknn>( | ||
| 60 | + meta_data.blank_id); | ||
| 61 | + } else { | ||
| 62 | + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", | ||
| 63 | + config.decoding_method.c_str()); | ||
| 64 | + SHERPA_ONNX_EXIT(-1); | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + InitFeatConfig(); | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + std::unique_ptr<OfflineStream> CreateStream() const override { | ||
| 71 | + return std::make_unique<OfflineStream>(config_.feat_config); | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + void DecodeStreams(OfflineStream **ss, int32_t n) const override { | ||
| 75 | + for (int32_t i = 0; i < n; ++i) { | ||
| 76 | + DecodeOneStream(ss[i]); | ||
| 77 | + } | ||
| 78 | + } | ||
| 79 | + | ||
| 80 | + OfflineRecognizerConfig GetConfig() const override { return config_; } | ||
| 81 | + | ||
| 82 | + private: | ||
| 83 | + void InitFeatConfig() { | ||
| 84 | + const auto &meta_data = model_->GetModelMetadata(); | ||
| 85 | + | ||
| 86 | + config_.feat_config.normalize_samples = meta_data.normalize_samples; | ||
| 87 | + config_.feat_config.window_type = "hamming"; | ||
| 88 | + config_.feat_config.high_freq = 0; | ||
| 89 | + config_.feat_config.snip_edges = true; | ||
| 90 | + } | ||
| 91 | + | ||
| 92 | + void DecodeOneStream(OfflineStream *s) const { | ||
| 93 | + const auto &meta_data = model_->GetModelMetadata(); | ||
| 94 | + | ||
| 95 | + std::vector<float> f = s->GetFrames(); | ||
| 96 | + | ||
| 97 | + int32_t language = 0; | ||
| 98 | + if (config_.model_config.sense_voice.language.empty()) { | ||
| 99 | + language = 0; | ||
| 100 | + } else if (meta_data.lang2id.count( | ||
| 101 | + config_.model_config.sense_voice.language)) { | ||
| 102 | + language = | ||
| 103 | + meta_data.lang2id.at(config_.model_config.sense_voice.language); | ||
| 104 | + } else { | ||
| 105 | + SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.", | ||
| 106 | + config_.model_config.sense_voice.language.c_str()); | ||
| 107 | + } | ||
| 108 | + | ||
| 109 | + int32_t text_norm = config_.model_config.sense_voice.use_itn | ||
| 110 | + ? meta_data.with_itn_id | ||
| 111 | + : meta_data.without_itn_id; | ||
| 112 | + | ||
| 113 | + std::vector<float> logits = model_->Run(std::move(f), language, text_norm); | ||
| 114 | + int32_t num_out_frames = logits.size() / meta_data.vocab_size; | ||
| 115 | + | ||
| 116 | + auto result = | ||
| 117 | + decoder_->Decode(logits.data(), num_out_frames, meta_data.vocab_size); | ||
| 118 | + | ||
| 119 | + int32_t frame_shift_ms = 10; | ||
| 120 | + int32_t subsampling_factor = meta_data.window_shift; | ||
| 121 | + auto r = ConvertSenseVoiceResult(result, symbol_table_, frame_shift_ms, | ||
| 122 | + subsampling_factor); | ||
| 123 | + | ||
| 124 | + r.text = ApplyInverseTextNormalization(std::move(r.text)); | ||
| 125 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 126 | + s->SetResult(r); | ||
| 127 | + } | ||
| 128 | + | ||
| 129 | + private: | ||
| 130 | + OfflineRecognizerConfig config_; | ||
| 131 | + SymbolTable symbol_table_; | ||
| 132 | + std::unique_ptr<OfflineSenseVoiceModelRknn> model_; | ||
| 133 | + std::unique_ptr<OfflineCtcGreedySearchDecoderRknn> decoder_; | ||
| 134 | +}; | ||
| 135 | + | ||
| 136 | +} // namespace sherpa_onnx | ||
| 137 | + | ||
| 138 | +#endif // SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_ |
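
`DecodeOneStream` passes `frame_shift_ms = 10` and `subsampling_factor = window_shift` (the LFR hop) into `ConvertSenseVoiceResult`, whose body lives in another file. A plausible reading of those two parameters — an assumption about the conversion, not the function's actual code — is the mapping below from a decoder-frame index to seconds:

```cpp
// Assumed conversion only: each post-LFR frame spans `subsampling_factor`
// fbank frames of `frame_shift_ms` milliseconds each. The real logic is in
// ConvertSenseVoiceResult, which is defined elsewhere.
#include <cstdint>

float FrameIndexToSeconds(int32_t t, int32_t frame_shift_ms,
                          int32_t subsampling_factor) {
  return t * subsampling_factor * frame_shift_ms / 1000.0f;
}

// Example: with frame_shift_ms = 10 and an LFR shift of 6, decoder frame 25
// maps to 25 * 6 * 10 / 1000 = 1.5 seconds.
```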
| 1 | +// sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <array> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 17 | +#if __OHOS__ | ||
| 18 | +#include "rawfile/raw_file_manager.h" | ||
| 19 | +#endif | ||
| 20 | + | ||
| 21 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 22 | +#include "sherpa-onnx/csrc/rknn/macros.h" | ||
| 23 | +#include "sherpa-onnx/csrc/rknn/utils.h" | ||
| 24 | + | ||
| 25 | +namespace sherpa_onnx { | ||
| 26 | + | ||
| 27 | +class OfflineSenseVoiceModelRknn::Impl { | ||
| 28 | + public: | ||
| 29 | + ~Impl() { | ||
| 30 | + auto ret = rknn_destroy(ctx_); | ||
| 31 | + if (ret != RKNN_SUCC) { | ||
| 32 | + SHERPA_ONNX_LOGE("Failed to destroy the context"); | ||
| 33 | + } | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | + explicit Impl(const OfflineModelConfig &config) : config_(config) { | ||
| 37 | + { | ||
| 38 | + auto buf = ReadFile(config_.sense_voice.model); | ||
| 39 | + Init(buf.data(), buf.size()); | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | + SetCoreMask(ctx_, config_.num_threads); | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + template <typename Manager> | ||
| 46 | + Impl(Manager *mgr, const OfflineModelConfig &config) : config_(config) { | ||
| 47 | + { | ||
| 48 | + auto buf = ReadFile(mgr, config_.sense_voice.model); | ||
| 49 | + Init(buf.data(), buf.size()); | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + SetCoreMask(ctx_, config_.num_threads); | ||
| 53 | + } | ||
| 54 | + | ||
| 55 | + const OfflineSenseVoiceModelMetaData &GetModelMetadata() const { | ||
| 56 | + return meta_data_; | ||
| 57 | + } | ||
| 58 | + | ||
| 59 | + std::vector<float> Run(std::vector<float> features, int32_t language, | ||
| 60 | + int32_t text_norm) { | ||
| 61 | + features = ApplyLFR(std::move(features)); | ||
| 62 | + | ||
| 63 | + std::vector<rknn_input> inputs(input_attrs_.size()); | ||
| 64 | + | ||
| 65 | + std::array<int32_t, 4> prompt{language, 1, 2, text_norm}; | ||
| 66 | + | ||
| 67 | + inputs[0].index = input_attrs_[0].index; | ||
| 68 | + inputs[0].type = RKNN_TENSOR_FLOAT32; | ||
| 69 | + inputs[0].fmt = input_attrs_[0].fmt; | ||
| 70 | + inputs[0].buf = reinterpret_cast<void *>(features.data()); | ||
| 71 | + inputs[0].size = features.size() * sizeof(float); | ||
| 72 | + | ||
| 73 | + inputs[1].index = input_attrs_[1].index; | ||
| 74 | + inputs[1].type = RKNN_TENSOR_INT32; | ||
| 75 | + inputs[1].fmt = input_attrs_[1].fmt; | ||
| 76 | + inputs[1].buf = reinterpret_cast<void *>(prompt.data()); | ||
| 77 | + inputs[1].size = prompt.size() * sizeof(int32_t); | ||
| 78 | + | ||
| 79 | + std::vector<float> out(output_attrs_[0].n_elems); | ||
| 80 | + | ||
| 81 | + std::vector<rknn_output> outputs(output_attrs_.size()); | ||
| 82 | + outputs[0].index = output_attrs_[0].index; | ||
| 83 | + outputs[0].is_prealloc = 1; | ||
| 84 | + outputs[0].want_float = 1; | ||
| 85 | + outputs[0].size = out.size() * sizeof(float); | ||
| 86 | + outputs[0].buf = reinterpret_cast<void *>(out.data()); | ||
| 87 | + | ||
| 88 | + rknn_context ctx = 0; | ||
| 89 | + auto ret = rknn_dup_context(&ctx_, &ctx); | ||
| 90 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the context"); | ||
| 91 | + | ||
| 92 | + ret = rknn_inputs_set(ctx, inputs.size(), inputs.data()); | ||
| 93 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs"); | ||
| 94 | + | ||
| 95 | + ret = rknn_run(ctx, nullptr); | ||
| 96 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model"); | ||
| 97 | + | ||
| 98 | + ret = rknn_outputs_get(ctx, outputs.size(), outputs.data(), nullptr); | ||
| 99 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output"); | ||
| 100 | + | ||
| 101 | + rknn_destroy(ctx); | ||
| 102 | + | ||
| 103 | + return out; | ||
| 104 | + } | ||
| 105 | + | ||
| 106 | + private: | ||
| 107 | + void Init(void *model_data, size_t model_data_length) { | ||
| 108 | + InitContext(model_data, model_data_length, config_.debug, &ctx_); | ||
| 109 | + | ||
| 110 | + InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_); | ||
| 111 | + | ||
| 112 | + rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug); | ||
| 113 | + | ||
| 114 | + auto meta = Parse(custom_string, config_.debug); | ||
| 115 | + | ||
| 116 | +#define SHERPA_ONNX_RKNN_READ_META_DATA_INT(dst, src_key) \ | ||
| 117 | + do { \ | ||
| 118 | + if (!meta.count(#src_key)) { \ | ||
| 119 | + SHERPA_ONNX_LOGE("'%s' does not exist in the custom_string", #src_key); \ | ||
| 120 | + SHERPA_ONNX_EXIT(-1); \ | ||
| 121 | + } \ | ||
| 122 | + \ | ||
| 123 | + dst = atoi(meta.at(#src_key).c_str()); \ | ||
| 124 | + } while (0) | ||
| 125 | + | ||
| 126 | + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.with_itn_id, with_itn); | ||
| 127 | + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.without_itn_id, without_itn); | ||
| 128 | + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.window_size, | ||
| 129 | + lfr_window_size); | ||
| 130 | + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.window_shift, | ||
| 131 | + lfr_window_shift); | ||
| 132 | + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.vocab_size, vocab_size); | ||
| 133 | + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.normalize_samples, | ||
| 134 | + normalize_samples); | ||
| 135 | + | ||
| 136 | + int32_t lang_auto = 0; | ||
| 137 | + int32_t lang_zh = 0; | ||
| 138 | + int32_t lang_en = 0; | ||
| 139 | + int32_t lang_ja = 0; | ||
| 140 | + int32_t lang_ko = 0; | ||
| 141 | + int32_t lang_yue = 0; | ||
| 142 | + | ||
| 143 | + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_auto, lang_auto); | ||
| 144 | + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_zh, lang_zh); | ||
| 145 | + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_en, lang_en); | ||
| 146 | + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_ja, lang_ja); | ||
| 147 | + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_ko, lang_ko); | ||
| 148 | + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_yue, lang_yue); | ||
| 149 | + | ||
| 150 | + meta_data_.lang2id = { | ||
| 151 | + {"auto", lang_auto}, {"zh", lang_zh}, {"en", lang_en}, | ||
| 152 | + {"ja", lang_ja}, {"ko", lang_ko}, {"yue", lang_yue}, | ||
| 153 | + }; | ||
| 154 | + | ||
| 155 | + // For RKNN models, neg_mean and inv_stddev are stored inside the model. | ||
| 156 | + | ||
| 157 | +#undef SHERPA_ONNX_RKNN_READ_META_DATA_INT | ||
| 158 | + | ||
| 159 | + num_input_frames_ = input_attrs_[0].dims[1]; // fixed by the static-shape graph | ||
| 160 | + } | ||
| 161 | + | ||
| 162 | + std::vector<float> ApplyLFR(std::vector<float> in) const { | ||
| 163 | + int32_t lfr_window_size = meta_data_.window_size; | ||
| 164 | + int32_t lfr_window_shift = meta_data_.window_shift; | ||
| 165 | + int32_t in_feat_dim = 80; // fbank feature dimension | ||
| 166 | + | ||
| 167 | + int32_t in_num_frames = in.size() / in_feat_dim; | ||
| 168 | + int32_t out_num_frames = | ||
| 169 | + (in_num_frames - lfr_window_size) / lfr_window_shift + 1; | ||
| 170 | + | ||
| 171 | + if (out_num_frames > num_input_frames_) { | ||
| 172 | + SHERPA_ONNX_LOGE( | ||
| 173 | + "Number of input frames %d is too large. Truncate it to %d frames.", | ||
| 174 | + out_num_frames, num_input_frames_); | ||
| 175 | + | ||
| 176 | + SHERPA_ONNX_LOGE( | ||
| 177 | + "Recognition result may be truncated/incomplete. Please select a " | ||
| 178 | + "model accepting longer audios."); | ||
| 179 | + | ||
| 180 | + out_num_frames = num_input_frames_; | ||
| 181 | + } | ||
| 182 | + | ||
| 183 | + int32_t out_feat_dim = in_feat_dim * lfr_window_size; | ||
| 184 | + | ||
| 185 | + std::vector<float> out(num_input_frames_ * out_feat_dim); | ||
| 186 | + | ||
| 187 | + const float *p_in = in.data(); | ||
| 188 | + float *p_out = out.data(); | ||
| 189 | + | ||
| 190 | + for (int32_t i = 0; i != out_num_frames; ++i) { | ||
| 191 | + std::copy(p_in, p_in + out_feat_dim, p_out); | ||
| 192 | + | ||
| 193 | + p_out += out_feat_dim; | ||
| 194 | + p_in += lfr_window_shift * in_feat_dim; | ||
| 195 | + } | ||
| 196 | + | ||
| 197 | + return out; | ||
| 198 | + } | ||
| 199 | + | ||
| 200 | + private: | ||
| 201 | + OfflineModelConfig config_; | ||
| 202 | + | ||
| 203 | + rknn_context ctx_ = 0; | ||
| 204 | + | ||
| 205 | + std::vector<rknn_tensor_attr> input_attrs_; | ||
| 206 | + std::vector<rknn_tensor_attr> output_attrs_; | ||
| 207 | + | ||
| 208 | + OfflineSenseVoiceModelMetaData meta_data_; | ||
| 209 | + int32_t num_input_frames_ = -1; | ||
| 210 | +}; | ||
| 211 | + | ||
| 212 | +OfflineSenseVoiceModelRknn::~OfflineSenseVoiceModelRknn() = default; | ||
| 213 | + | ||
| 214 | +OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn( | ||
| 215 | + const OfflineModelConfig &config) | ||
| 216 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 217 | + | ||
| 218 | +template <typename Manager> | ||
| 219 | +OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn( | ||
| 220 | + Manager *mgr, const OfflineModelConfig &config) | ||
| 221 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 222 | + | ||
| 223 | +std::vector<float> OfflineSenseVoiceModelRknn::Run(std::vector<float> features, | ||
| 224 | + int32_t language, | ||
| 225 | + int32_t text_norm) const { | ||
| 226 | + return impl_->Run(std::move(features), language, text_norm); | ||
| 227 | +} | ||
| 228 | + | ||
| 229 | +const OfflineSenseVoiceModelMetaData & | ||
| 230 | +OfflineSenseVoiceModelRknn::GetModelMetadata() const { | ||
| 231 | + return impl_->GetModelMetadata(); | ||
| 232 | +} | ||
| 233 | + | ||
| 234 | +#if __ANDROID_API__ >= 9 | ||
| 235 | +template OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn( | ||
| 236 | + AAssetManager *mgr, const OfflineModelConfig &config); | ||
| 237 | +#endif | ||
| 238 | + | ||
| 239 | +#if __OHOS__ | ||
| 240 | +template OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn( | ||
| 241 | + NativeResourceManager *mgr, const OfflineModelConfig &config); | ||
| 242 | +#endif | ||
| 243 | + | ||
| 244 | +} // namespace sherpa_onnx |
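
`ApplyLFR` stacks `lfr_window_size` consecutive 80-dim fbank frames into one wide frame and hops by `lfr_window_shift`; because the output vector is pre-sized to `num_input_frames_`, short inputs are implicitly zero-padded to the fixed length that the static-shape RKNN graph requires. A standalone sketch of the same transform, with names and sizes chosen for illustration (window = 7, shift = 6 are typical for this model family, but the real values come from the model metadata):

```cpp
// Standalone LFR (low frame rate) sketch; mirrors ApplyLFR above.
#include <algorithm>
#include <vector>

std::vector<float> ApplyLfr(const std::vector<float> &in, int feat_dim,
                            int window, int shift, int fixed_out_frames) {
  int in_frames = static_cast<int>(in.size()) / feat_dim;
  int out_frames = (in_frames - window) / shift + 1;
  // Truncate if the input is longer than the static graph allows.
  out_frames = std::min(out_frames, fixed_out_frames);

  int out_dim = feat_dim * window;
  // Pre-sizing to fixed_out_frames zero-pads the tail for the static graph.
  std::vector<float> out(
      static_cast<size_t>(fixed_out_frames) * out_dim, 0.0f);

  for (int i = 0; i < out_frames; ++i) {
    // Copy `window` consecutive input frames into one stacked output frame.
    std::copy(in.data() + i * shift * feat_dim,
              in.data() + i * shift * feat_dim + out_dim,
              out.data() + static_cast<size_t>(i) * out_dim);
  }
  return out;
}
```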
| 1 | +// sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "rknn_api.h" // NOLINT | ||
| 12 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +class OfflineSenseVoiceModelRknn { | ||
| 18 | + public: | ||
| 19 | + ~OfflineSenseVoiceModelRknn(); | ||
| 20 | + | ||
| 21 | + explicit OfflineSenseVoiceModelRknn(const OfflineModelConfig &config); | ||
| 22 | + | ||
| 23 | + template <typename Manager> | ||
| 24 | + OfflineSenseVoiceModelRknn(Manager *mgr, const OfflineModelConfig &config); | ||
| 25 | + | ||
| 26 | + /** | ||
| 27 | + * @param features A tensor of shape (num_frames, feature_dim) | ||
| 28 | + * before applying LFR. | ||
| 29 | + * @param language Language ID looked up from the model metadata's lang2id map. | ||
| 30 | + * @param text_norm Text-normalization ID: with_itn_id or without_itn_id. | ||
| 31 | + * @returns A tensor of shape (num_output_frames, vocab_size). | ||
| 32 | + */ | ||
| 33 | + std::vector<float> Run(std::vector<float> features, int32_t language, | ||
| 34 | + int32_t text_norm) const; | ||
| 35 | + | ||
| 36 | + const OfflineSenseVoiceModelMetaData &GetModelMetadata() const; | ||
| 37 | + | ||
| 38 | + private: | ||
| 39 | + class Impl; | ||
| 40 | + std::unique_ptr<Impl> impl_; | ||
| 41 | +}; | ||
| 42 | + | ||
| 43 | +} // namespace sherpa_onnx | ||
| 44 | + | ||
| 45 | +#endif // SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_ |
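
For orientation, here is a hypothetical caller sketch (not part of the PR) showing how this header's pieces fit together; the `"auto"` language key and the helper name are assumptions taken from the metadata fields populated in the .cc file:

```cpp
// Hypothetical usage sketch: `fbank` holds 80-dim features of shape
// (num_frames, 80), row-major, computed elsewhere.
#include <cstdint>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h"

std::vector<float> RunSenseVoice(const sherpa_onnx::OfflineModelConfig &config,
                                 std::vector<float> fbank) {
  sherpa_onnx::OfflineSenseVoiceModelRknn model(config);
  const auto &meta = model.GetModelMetadata();

  int32_t language = meta.lang2id.at("auto");  // or "zh", "en", "ja", ...
  int32_t text_norm = meta.without_itn_id;     // no inverse text normalization

  // Returned logits are (num_output_frames, vocab_size), flattened; feed
  // them to OfflineCtcGreedySearchDecoderRknn::Decode.
  return model.Run(std::move(fbank), language, text_norm);
}
```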
| @@ -55,8 +55,6 @@ class OnlineZipformerCtcModelRknn::Impl { | @@ -55,8 +55,6 @@ class OnlineZipformerCtcModelRknn::Impl { | ||
| 55 | SetCoreMask(ctx_, config_.num_threads); | 55 | SetCoreMask(ctx_, config_.num_threads); |
| 56 | } | 56 | } |
| 57 | 57 | ||
| 58 | - // TODO(fangjun): Support Android | ||
| 59 | - | ||
| 60 | std::vector<std::vector<uint8_t>> GetInitStates() const { | 58 | std::vector<std::vector<uint8_t>> GetInitStates() const { |
| 61 | // input_attrs_[0] is for the feature | 59 | // input_attrs_[0] is for the feature |
| 62 | // input_attrs_[1:] is for states | 60 | // input_attrs_[1:] is for states |
| @@ -89,8 +89,6 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -89,8 +89,6 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 89 | SetCoreMask(joiner_ctx_, config_.num_threads); | 89 | SetCoreMask(joiner_ctx_, config_.num_threads); |
| 90 | } | 90 | } |
| 91 | 91 | ||
| 92 | - // TODO(fangjun): Support Android | ||
| 93 | - | ||
| 94 | std::vector<std::vector<uint8_t>> GetEncoderInitStates() const { | 92 | std::vector<std::vector<uint8_t>> GetEncoderInitStates() const { |
| 95 | // encoder_input_attrs_[0] is for the feature | 93 | // encoder_input_attrs_[0] is for the feature |
| 96 | // encoder_input_attrs_[1:] is for states | 94 | // encoder_input_attrs_[1:] is for states |