Fangjun Kuang
Committed by GitHub

Support RK NPU for SenseVoice non-streaming ASR models (#2589)

This PR adds RK NPU support for SenseVoice non-streaming ASR models by implementing a new RKNN backend with greedy CTC decoding.

- Adds an offline RKNN implementation for SenseVoice models, covering model loading, feature processing (LFR), and greedy CTC decoding (a short sketch of the decoding follows below)
- Introduces export scripts that convert SenseVoice models from PyTorch to ONNX and then from ONNX to RKNN format
- Implements provider-aware validation so that an ONNX model cannot be used with --provider=rknn and an RKNN model cannot be used with a non-rknn provider
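
For reference, the decoding used by the new RKNN backend is plain greedy CTC search: take the arg-max token for each output frame, collapse consecutive repeats, and drop blanks. The following NumPy sketch is illustrative only; the function name greedy_ctc_decode is an assumption, and blank_id = 0 follows the Python test script in this PR.

import numpy as np

def greedy_ctc_decode(logits, blank_id=0):
    # logits: (num_frames, vocab_size) array of per-frame scores
    ids = logits.argmax(axis=-1)  # best token id for every frame
    tokens = []
    prev = -1
    for t, y in enumerate(ids):
        # keep a token only if it differs from the previous frame
        # and is not the CTC blank symbol
        if y != blank_id and y != prev:
            tokens.append((int(y), t))  # (token id, frame index) for timestamps
        prev = y
    return tokens

The C++ decoder added in sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.cc performs the same loop directly over the logits buffer returned by the RKNN runtime.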
@@ -152,3 +152,7 @@ vocab.json
 *.so
 sherpa-onnx-streaming-t-one-russian-2025-09-08
 sherpa-onnx-wenetspeech-yue-u2pp-conformer-ctc-zh-en-cantonese-int8-2025-09-10
+am.mvn
+*bpe.model
+config.yaml
+configuration.json
@@ -118,8 +118,11 @@ def display_params(params):
     os.system(f"cat {params['config']}")


+@torch.no_grad()
 def main():
     model, params = SenseVoiceSmall.from_pretrained(model="iic/SenseVoiceSmall", device="cpu")
+    model.eval()
+
     display_params(params)

     generate_tokens(params)
  1 +#!/usr/bin/env python3
  2 +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +import argparse
  5 +import os
  6 +from typing import Any, Dict, List, Tuple
  7 +
  8 +import onnx
  9 +import sentencepiece as spm
  10 +import torch
  11 +
  12 +from torch_model import SenseVoiceSmall
  13 +
  14 +
  15 +def get_args():
  16 + parser = argparse.ArgumentParser(
  17 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  18 + )
  19 +
  20 + parser.add_argument(
  21 + "--input-len-in-seconds",
  22 + type=int,
  23 + required=True,
  24 + help="""RKNN does not support dynamic shape, so we need to hard-code
  25 + how long the model can process.
  26 + """,
  27 + )
  28 + return parser.parse_args()
  29 +
  30 +
  31 +def add_meta_data(filename: str, meta_data: Dict[str, Any]):
  32 + """Add meta data to an ONNX model. It is changed in-place.
  33 +
  34 + Args:
  35 + filename:
  36 + Filename of the ONNX model to be changed.
  37 + meta_data:
  38 + Key-value pairs.
  39 + """
  40 + model = onnx.load(filename)
  41 + while len(model.metadata_props):
  42 + model.metadata_props.pop()
  43 +
  44 + for key, value in meta_data.items():
  45 + meta = model.metadata_props.add()
  46 + meta.key = key
  47 + meta.value = str(value)
  48 +
  49 + onnx.save(model, filename)
  50 +
  51 +
  52 +def load_cmvn(filename) -> Tuple[List[float], List[float]]:
  53 + neg_mean = None
  54 + inv_stddev = None
  55 +
  56 + with open(filename) as f:
  57 + for line in f:
  58 + if not line.startswith("<LearnRateCoef>"):
  59 + continue
  60 + t = line.split()[3:-1]
  61 +
  62 + if neg_mean is None:
  63 + neg_mean = list(map(lambda x: float(x), t))
  64 + else:
  65 + inv_stddev = list(map(lambda x: float(x), t))
  66 +
  67 + return neg_mean, inv_stddev
  68 +
  69 +
  70 +def generate_tokens(sp):
  71 + with open("tokens.txt", "w", encoding="utf-8") as f:
  72 + for i in range(sp.vocab_size()):
  73 + f.write(f"{sp.id_to_piece(i)} {i}\n")
  74 + print("saved to tokens.txt")
  75 +
  76 +
  77 +@torch.no_grad()
  78 +def main():
  79 + args = get_args()
  80 + print(vars(args))
  81 +
  82 + sp = spm.SentencePieceProcessor()
  83 + sp.load("./chn_jpn_yue_eng_ko_spectok.bpe.model")
  84 + vocab_size = sp.vocab_size()
  85 + generate_tokens(sp)
  86 +
  87 + print("loading model")
  88 +
  89 + state_dict = torch.load("./model.pt")
  90 + if "state_dict" in state_dict:
  91 + state_dict = state_dict["state_dict"]
  92 +
  93 + neg_mean, inv_stddev = load_cmvn("./am.mvn")
  94 +
  95 + neg_mean = torch.tensor(neg_mean, dtype=torch.float32)
  96 + inv_stddev = torch.tensor(inv_stddev, dtype=torch.float32)
  97 +
  98 + model = SenseVoiceSmall(neg_mean=neg_mean, inv_stddev=inv_stddev)
  99 + model.load_state_dict(state_dict)
  100 + model.eval()
  101 + del state_dict
  102 +
  103 + lfr_window_size = 7
  104 + lfr_window_shift = 6
  105 +
  106 + # frame shift is 10ms, 1 second has about 100 feature frames
  107 + input_len_in_seconds = int(args.input_len_in_seconds)
  108 + num_frames = input_len_in_seconds * 100
  109 + print("num_frames", num_frames)
  110 +
  111 + # num_input_frames is an approximate number
  112 + num_input_frames = int(num_frames / lfr_window_shift + 0.5)
  113 + print("num_input_frames", num_input_frames)
  114 +
  115 + x = torch.randn(1, num_input_frames, 560, dtype=torch.float32)
  116 +
  117 + language = 3
  118 + text_norm = 15
  119 + prompt = torch.tensor([language, 1, 2, text_norm], dtype=torch.int32)
  120 +
  121 + opset_version = 13
  122 + filename = f"model-{input_len_in_seconds}-seconds.onnx"
  123 + torch.onnx.export(
  124 + model,
  125 + (x, prompt),
  126 + filename,
  127 + opset_version=opset_version,
  128 + input_names=["x", "prompt"],
  129 + output_names=["logits"],
  130 + dynamic_axes={},
  131 + )
  132 +
  133 + model_author = os.environ.get("model_author", "iic")
  134 + comment = os.environ.get("comment", "iic/SenseVoiceSmall")
  135 + url = os.environ.get("url", "https://huggingface.co/FunAudioLLM/SenseVoiceSmall")
  136 +
  137 + meta_data = {
  138 + "lfr_window_size": lfr_window_size,
  139 + "lfr_window_shift": lfr_window_shift,
  140 + "num_input_frames": num_input_frames,
  141 + "normalize_samples": 0, # input should be in the range [-32768, 32767]
  142 + "model_type": "sense_voice_ctc",
  143 + "version": "1",
  144 + "model_author": model_author,
  145 + "maintainer": "k2-fsa",
  146 + "vocab_size": vocab_size,
  147 + "comment": comment,
  148 + "lang_auto": model.lid_dict["auto"],
  149 + "lang_zh": model.lid_dict["zh"],
  150 + "lang_en": model.lid_dict["en"],
  151 + "lang_yue": model.lid_dict["yue"], # cantonese
  152 + "lang_ja": model.lid_dict["ja"],
  153 + "lang_ko": model.lid_dict["ko"],
  154 + "lang_nospeech": model.lid_dict["nospeech"],
  155 + "with_itn": model.textnorm_dict["withitn"],
  156 + "without_itn": model.textnorm_dict["woitn"],
  157 + "url": url,
  158 + }
  159 + add_meta_data(filename=filename, meta_data=meta_data)
  160 +
  161 +
  162 +if __name__ == "__main__":
  163 + torch.manual_seed(20250717)
  164 + main()
  1 +#!/usr/bin/env python3
  2 +# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
  3 +
  4 +import argparse
  5 +import logging
  6 +from pathlib import Path
  7 +
  8 +from rknn.api import RKNN
  9 +
  10 +logging.basicConfig(level=logging.WARNING)
  11 +
  12 +g_platforms = [
  13 + # "rv1103",
  14 + # "rv1103b",
  15 + # "rv1106",
  16 + # "rk2118",
  17 + "rk3562",
  18 + "rk3566",
  19 + "rk3568",
  20 + "rk3576",
  21 + "rk3588",
  22 +]
  23 +
  24 +
  25 +def get_parser():
  26 + parser = argparse.ArgumentParser(
  27 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  28 + )
  29 +
  30 + parser.add_argument(
  31 + "--target-platform",
  32 + type=str,
  33 + required=True,
  34 + help=f"Supported values are: {','.join(g_platforms)}",
  35 + )
  36 +
  37 + parser.add_argument(
  38 + "--in-model",
  39 + type=str,
  40 + required=True,
  41 + help="Path to the input onnx model",
  42 + )
  43 +
  44 + parser.add_argument(
  45 + "--out-model",
  46 + type=str,
  47 + required=True,
  48 + help="Path to the output rknn model",
  49 + )
  50 +
  51 + return parser
  52 +
  53 +
  54 +def get_meta_data(model: str):
  55 + import onnxruntime
  56 +
  57 + session_opts = onnxruntime.SessionOptions()
  58 + session_opts.inter_op_num_threads = 1
  59 + session_opts.intra_op_num_threads = 1
  60 +
  61 + m = onnxruntime.InferenceSession(
  62 + model,
  63 + sess_options=session_opts,
  64 + providers=["CPUExecutionProvider"],
  65 + )
  66 +
  67 + for i in m.get_inputs():
  68 + print(i)
  69 +
  70 + print("-----")
  71 +
  72 + for i in m.get_outputs():
  73 + print(i)
  74 + print()
  75 +
  76 + meta = m.get_modelmeta().custom_metadata_map
  77 + s = ""
  78 + sep = ""
  79 + for key, value in meta.items():
  80 + if key in ("neg_mean", "inv_stddev"):
  81 + continue
  82 + s = s + sep + f"{key}={value}"
  83 + sep = ";"
  84 + assert len(s) < 1024, len(s)
  85 +
  86 + print("len(s)", len(s), s)
  87 +
  88 + return s
  89 +
  90 +
  91 +def export_rknn(rknn, filename):
  92 + ret = rknn.export_rknn(filename)
  93 + if ret != 0:
  94 + exit(f"Export rknn model to {filename} failed!")
  95 +
  96 +
  97 +def init_model(filename: str, target_platform: str, custom_string=None):
  98 + rknn = RKNN(verbose=False)
  99 +
  100 + rknn.config(
  101 + optimization_level=0,
  102 + target_platform=target_platform,
  103 + custom_string=custom_string,
  104 + )
  105 + if not Path(filename).is_file():
  106 + exit(f"{filename} does not exist")
  107 +
  108 + ret = rknn.load_onnx(model=filename)
  109 + if ret != 0:
  110 + exit(f"Load model {filename} failed!")
  111 +
  112 + ret = rknn.build(do_quantization=False)
  113 + if ret != 0:
  114 + exit(f"Build model {filename} failed!")
  115 +
  116 + return rknn
  117 +
  118 +
  119 +class RKNNModel:
  120 + def __init__(
  121 + self,
  122 + model: str,
  123 + target_platform: str,
  124 + ):
  125 + meta = get_meta_data(model)
  126 + print(meta)
  127 +
  128 + self.model = init_model(
  129 + model,
  130 + target_platform=target_platform,
  131 + custom_string=meta,
  132 + )
  133 +
  134 + def export_rknn(self, model):
  135 + export_rknn(self.model, model)
  136 +
  137 + def release(self):
  138 + self.model.release()
  139 +
  140 +
  141 +def main():
  142 + args = get_parser().parse_args()
  143 + print(vars(args))
  144 +
  145 + model = RKNNModel(
  146 + model=args.in_model,
  147 + target_platform=args.target_platform,
  148 + )
  149 +
  150 + model.export_rknn(
  151 + model=args.out_model,
  152 + )
  153 +
  154 + model.release()
  155 +
  156 +
  157 +if __name__ == "__main__":
  158 + main()
  1 +#!/usr/bin/env python3
  2 +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +import argparse
  5 +from typing import Tuple
  6 +
  7 +import kaldi_native_fbank as knf
  8 +import numpy as np
  9 +import onnxruntime
  10 +import onnxruntime as ort
  11 +import soundfile as sf
  12 +import torch
  13 +
  14 +
  15 +def get_args():
  16 + parser = argparse.ArgumentParser(
  17 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  18 + )
  19 +
  20 + parser.add_argument(
  21 + "--model",
  22 + type=str,
  23 + required=True,
  24 + help="Path to model.onnx",
  25 + )
  26 +
  27 + parser.add_argument(
  28 + "--tokens",
  29 + type=str,
  30 + required=True,
  31 + help="Path to tokens.txt",
  32 + )
  33 +
  34 + parser.add_argument(
  35 + "--wave",
  36 + type=str,
  37 + required=True,
  38 + help="The input wave to be recognized",
  39 + )
  40 +
  41 + parser.add_argument(
  42 + "--language",
  43 + type=str,
  44 + default="auto",
  45 + help="the language of the input wav file. Supported values: zh, en, ja, ko, yue, auto",
  46 + )
  47 +
  48 + parser.add_argument(
  49 + "--use-itn",
  50 + type=int,
  51 + default=0,
  52 + help="1 to use inverse text normalization. 0 to not use inverse text normalization",
  53 + )
  54 +
  55 + return parser.parse_args()
  56 +
  57 +
  58 +class OnnxModel:
  59 + def __init__(self, filename):
  60 + session_opts = ort.SessionOptions()
  61 + session_opts.inter_op_num_threads = 1
  62 + session_opts.intra_op_num_threads = 1
  63 +
  64 + self.session_opts = session_opts
  65 +
  66 + self.model = ort.InferenceSession(
  67 + filename,
  68 + sess_options=self.session_opts,
  69 + providers=["CPUExecutionProvider"],
  70 + )
  71 +
  72 + meta = self.model.get_modelmeta().custom_metadata_map
  73 +
  74 + self.window_size = int(meta["lfr_window_size"]) # lfr_m
  75 + self.window_shift = int(meta["lfr_window_shift"]) # lfr_n
  76 +
  77 + lang_zh = int(meta["lang_zh"])
  78 + lang_en = int(meta["lang_en"])
  79 + lang_ja = int(meta["lang_ja"])
  80 + lang_ko = int(meta["lang_ko"])
  81 + lang_yue = int(meta["lang_yue"])
  82 + lang_auto = int(meta["lang_auto"])
  83 +
  84 + self.lang_id = {
  85 + "zh": lang_zh,
  86 + "en": lang_en,
  87 + "ja": lang_ja,
  88 + "ko": lang_ko,
  89 + "yue": lang_yue,
  90 + "auto": lang_auto,
  91 + }
  92 + self.with_itn = int(meta["with_itn"])
  93 + self.without_itn = int(meta["without_itn"])
  94 +
  95 + self.max_len = self.model.get_inputs()[0].shape[1]
  96 +
  97 + def __call__(self, x, prompt):
  98 + logits = self.model.run(
  99 + [
  100 + self.model.get_outputs()[0].name,
  101 + ],
  102 + {
  103 + self.model.get_inputs()[0].name: x.numpy(),
  104 + self.model.get_inputs()[1].name: prompt.numpy(),
  105 + },
  106 + )[0]
  107 +
  108 + return torch.from_numpy(logits)
  109 +
  110 +
  111 +def load_audio(filename: str) -> Tuple[np.ndarray, int]:
  112 + data, sample_rate = sf.read(
  113 + filename,
  114 + always_2d=True,
  115 + dtype="float32",
  116 + )
  117 + data = data[:, 0] # use only the first channel
  118 + samples = np.ascontiguousarray(data)
  119 + return samples, sample_rate
  120 +
  121 +
  122 +def load_tokens(filename):
  123 + ans = dict()
  124 + i = 0
  125 + with open(filename, encoding="utf-8") as f:
  126 + for line in f:
  127 + ans[i] = line.strip().split()[0]
  128 + i += 1
  129 + return ans
  130 +
  131 +
  132 +def compute_feat(
  133 + samples,
  134 + sample_rate,
  135 + max_len: int,
  136 + window_size: int = 7, # lfr_m
  137 + window_shift: int = 6, # lfr_n
  138 +):
  139 + opts = knf.FbankOptions()
  140 + opts.frame_opts.dither = 0
  141 + opts.frame_opts.snip_edges = False
  142 + opts.frame_opts.window_type = "hamming"
  143 + opts.frame_opts.samp_freq = sample_rate
  144 + opts.mel_opts.num_bins = 80
  145 +
  146 + online_fbank = knf.OnlineFbank(opts)
  147 + online_fbank.accept_waveform(sample_rate, (samples * 32768).tolist())
  148 + online_fbank.input_finished()
  149 +
  150 + features = np.stack(
  151 + [online_fbank.get_frame(i) for i in range(online_fbank.num_frames_ready)]
  152 + )
  153 + assert features.data.contiguous is True
  154 + assert features.dtype == np.float32, features.dtype
  155 +
  156 + T = (features.shape[0] - window_size) // window_shift + 1
  157 + features = np.lib.stride_tricks.as_strided(
  158 + features,
  159 + shape=(T, features.shape[1] * window_size),
  160 + strides=((window_shift * features.shape[1]) * 4, 4),
  161 + )
  162 +
  163 + print("features.shape", features.shape)
  164 +
  165 + if features.shape[0] > max_len:
  166 + features = features[:max_len]
  167 + elif features.shape[0] < max_len:
  168 + features = np.pad(
  169 + features,
  170 + ((0, max_len - features.shape[0]), (0, 0)),
  171 + mode="constant",
  172 + constant_values=0,
  173 + )
  174 +
  175 + print("features.shape", features.shape)
  176 +
  177 + return features
  178 +
  179 +
  180 +def main():
  181 + args = get_args()
  182 + print(vars(args))
  183 + samples, sample_rate = load_audio(args.wave)
  184 + if sample_rate != 16000:
  185 + import librosa
  186 +
  187 + samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
  188 + sample_rate = 16000
  189 +
  190 + model = OnnxModel(filename=args.model)
  191 +
  192 + features = compute_feat(
  193 + samples=samples,
  194 + sample_rate=sample_rate,
  195 + max_len=model.max_len,
  196 + window_size=model.window_size,
  197 + window_shift=model.window_shift,
  198 + )
  199 +
  200 + features = torch.from_numpy(features).unsqueeze(0)
  201 +
  202 + language = model.lang_id["auto"]
  203 + if args.language in model.lang_id:
  204 + language = model.lang_id[args.language]
  205 + else:
  206 + print(f"Invalid language: '{args.language}'")
  207 + print("Use auto")
  208 +
  209 + if args.use_itn:
  210 + text_norm = model.with_itn
  211 + else:
  212 + text_norm = model.without_itn
  213 +
  214 + prompt = torch.tensor([language, 1, 2, text_norm], dtype=torch.int32)
  215 +
  216 + logits = model(
  217 + x=features,
  218 + prompt=prompt,
  219 + )
  220 +
  221 + idx = logits.squeeze(0).argmax(dim=-1)
  222 + # idx is of shape (T,)
  223 + idx = torch.unique_consecutive(idx)
  224 +
  225 + blank_id = 0
  226 + idx = idx[idx != blank_id].tolist()
  227 +
  228 + tokens = load_tokens(args.tokens)
  229 + text = "".join([tokens[i] for i in idx])
  230 +
  231 + text = text.replace("▁", " ")
  232 + print(text)
  233 +
  234 +
  235 +if __name__ == "__main__":
  236 + main()
  1 +# This file is modified from
  2 +# https://github.com/modelscope/FunASR/blob/main/funasr/models/sense_voice/model.py
  3 +
  4 +import torch
  5 +import torch.nn
  6 +import torch.nn as nn
  7 +import torch.nn.functional as F
  8 +
  9 +
  10 +class SinusoidalPositionEncoder(nn.Module):
  11 + """ """
  12 +
  13 + def __init__(self, d_model=80, dropout_rate=0.1):
  14 + super().__init__()  # base nn.Module setup; this encoder has no learnable parameters
  15 +
  16 + def encode(
  17 + self,
  18 + positions: torch.Tensor = None,
  19 + depth: int = None,
  20 + dtype: torch.dtype = torch.float32,
  21 + ):
  22 + """
  23 + Args:
  24 + positions: (batch_size, )
  25 + """
  26 + batch_size = positions.size(0)
  27 + positions = positions.type(dtype)
  28 + device = positions.device
  29 + log_timescale_increment = torch.log(
  30 + torch.tensor([10000], dtype=dtype, device=device)
  31 + ) / (depth / 2 - 1)
  32 + inv_timescales = torch.exp(
  33 + torch.arange(depth / 2, device=device).type(dtype)
  34 + * (-log_timescale_increment)
  35 + )
  36 + inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
  37 + scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
  38 + inv_timescales, [1, 1, -1]
  39 + )
  40 + encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
  41 + return encoding.type(dtype)
  42 +
  43 + def forward(self, x):
  44 + batch_size, timesteps, input_dim = x.size()
  45 + positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
  46 + position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
  47 +
  48 + return x + position_encoding
  49 +
  50 +
  51 +class PositionwiseFeedForward(nn.Module):
  52 + """Positionwise feed forward layer.
  53 +
  54 + Args:
  55 + idim (int): Input dimension.
  56 + hidden_units (int): The number of hidden units.
  57 + dropout_rate (float): Dropout rate.
  58 +
  59 + """
  60 +
  61 + def __init__(self, idim, hidden_units, dropout_rate, activation=None):
  62 + super().__init__()
  63 + self.w_1 = torch.nn.Linear(idim, hidden_units)
  64 + self.w_2 = torch.nn.Linear(hidden_units, idim)
  65 + self.dropout = torch.nn.Dropout(dropout_rate)
  66 + if activation is None:
  67 + activation = torch.nn.ReLU()
  68 + self.activation = activation
  69 +
  70 + def forward(self, x):
  71 + """Forward function."""
  72 + return self.w_2(self.dropout(self.activation(self.w_1(x))))
  73 +
  74 +
  75 +class MultiHeadedAttentionSANM(nn.Module):
  76 + """Multi-Head Attention layer.
  77 +
  78 + Args:
  79 + n_head (int): The number of heads.
  80 + n_feat (int): The number of features.
  81 + dropout_rate (float): Dropout rate.
  82 +
  83 + """
  84 +
  85 + def __init__(
  86 + self,
  87 + n_head,
  88 + in_feat,
  89 + n_feat,
  90 + dropout_rate,
  91 + kernel_size,
  92 + sanm_shfit=0,
  93 + lora_list=None,
  94 + lora_rank=8,
  95 + lora_alpha=16,
  96 + lora_dropout=0.1,
  97 + ):
  98 + super().__init__()
  99 + assert n_feat % n_head == 0
  100 + # We assume d_v always equals d_k
  101 + self.d_k = n_feat // n_head
  102 + self.h = n_head
  103 + self.linear_out = nn.Linear(n_feat, n_feat)
  104 + self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
  105 + self.attn = None
  106 + self.dropout = nn.Dropout(p=dropout_rate)
  107 +
  108 + self.fsmn_block = nn.Conv1d(
  109 + n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
  110 + )
  111 + # padding
  112 + left_padding = (kernel_size - 1) // 2
  113 + if sanm_shfit > 0:
  114 + left_padding = left_padding + sanm_shfit
  115 + right_padding = kernel_size - 1 - left_padding
  116 + self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
  117 +
  118 + def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
  119 + b, t, d = inputs.size()
  120 + if mask is not None:
  121 + mask = torch.reshape(mask, (b, -1, 1))
  122 + if mask_shfit_chunk is not None:
  123 + mask = mask * mask_shfit_chunk
  124 + inputs = inputs * mask
  125 +
  126 + x = inputs.transpose(1, 2)
  127 + x = self.pad_fn(x)
  128 + x = self.fsmn_block(x)
  129 + x = x.transpose(1, 2)
  130 + x += inputs
  131 + x = self.dropout(x)
  132 + if mask is not None:
  133 + x = x * mask
  134 + return x
  135 +
  136 + def forward_qkv(self, x):
  137 + """Transform query, key and value.
  138 +
  139 + Args:
  140 + query (torch.Tensor): Query tensor (#batch, time1, size).
  141 + key (torch.Tensor): Key tensor (#batch, time2, size).
  142 + value (torch.Tensor): Value tensor (#batch, time2, size).
  143 +
  144 + Returns:
  145 + torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  146 + torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  147 + torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
  148 +
  149 + """
  150 + b, t, d = x.size()
  151 + q_k_v = self.linear_q_k_v(x)
  152 + q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
  153 + q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
  154 + 1, 2
  155 + ) # (batch, head, time1, d_k)
  156 + k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
  157 + 1, 2
  158 + ) # (batch, head, time2, d_k)
  159 + v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
  160 + 1, 2
  161 + ) # (batch, head, time2, d_k)
  162 +
  163 + return q_h, k_h, v_h, v
  164 +
  165 + def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
  166 + """Compute attention context vector.
  167 +
  168 + Args:
  169 + value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
  170 + scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
  171 + mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
  172 +
  173 + Returns:
  174 + torch.Tensor: Transformed value (#batch, time1, d_model)
  175 + weighted by the attention score (#batch, time1, time2).
  176 +
  177 + """
  178 + n_batch = value.size(0)
  179 + if mask is not None:
  180 + if mask_att_chunk_encoder is not None:
  181 + mask = mask * mask_att_chunk_encoder
  182 +
  183 + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
  184 +
  185 + min_value = -float(
  186 + "inf"
  187 + ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
  188 + scores = scores.masked_fill(mask, min_value)
  189 + attn = torch.softmax(scores, dim=-1).masked_fill(
  190 + mask, 0.0
  191 + ) # (batch, head, time1, time2)
  192 + else:
  193 + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
  194 +
  195 + p_attn = self.dropout(attn)
  196 + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
  197 + x = (
  198 + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
  199 + ) # (batch, time1, d_model)
  200 +
  201 + return self.linear_out(x) # (batch, time1, d_model)
  202 +
  203 + def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
  204 + """Compute scaled dot product attention.
  205 +
  206 + Args:
  207 + query (torch.Tensor): Query tensor (#batch, time1, size).
  208 + key (torch.Tensor): Key tensor (#batch, time2, size).
  209 + value (torch.Tensor): Value tensor (#batch, time2, size).
  210 + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  211 + (#batch, time1, time2).
  212 +
  213 + Returns:
  214 + torch.Tensor: Output tensor (#batch, time1, d_model).
  215 +
  216 + """
  217 + q_h, k_h, v_h, v = self.forward_qkv(x)
  218 + fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
  219 + q_h = q_h * self.d_k ** (-0.5)
  220 + scores = torch.matmul(q_h, k_h.transpose(-2, -1))
  221 + att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
  222 + return att_outs + fsmn_memory
  223 +
  224 +
  225 +class EncoderLayerSANM(nn.Module):
  226 + def __init__(
  227 + self,
  228 + in_size,
  229 + size,
  230 + self_attn,
  231 + feed_forward,
  232 + dropout_rate,
  233 + normalize_before=True,
  234 + concat_after=False,
  235 + stochastic_depth_rate=0.0,
  236 + ):
  237 + super().__init__()
  238 + self.self_attn = self_attn
  239 + self.feed_forward = feed_forward
  240 + self.norm1 = LayerNorm(in_size)
  241 + self.norm2 = LayerNorm(size)
  242 + self.dropout = nn.Dropout(dropout_rate)
  243 + self.in_size = in_size
  244 + self.size = size
  245 + self.normalize_before = normalize_before
  246 + self.concat_after = concat_after
  247 + if self.concat_after:
  248 + self.concat_linear = nn.Linear(size + size, size)
  249 + self.stochastic_depth_rate = stochastic_depth_rate
  250 + self.dropout_rate = dropout_rate
  251 +
  252 + def forward(
  253 + self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None
  254 + ):
  255 + """Compute encoded features.
  256 +
  257 + Args:
  258 + x_input (torch.Tensor): Input tensor (#batch, time, size).
  259 + mask (torch.Tensor): Mask tensor for the input (#batch, time).
  260 + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
  261 +
  262 + Returns:
  263 + torch.Tensor: Output tensor (#batch, time, size).
  264 + torch.Tensor: Mask tensor (#batch, time).
  265 +
  266 + """
  267 + skip_layer = False
  268 + # with stochastic depth, residual connection `x + f(x)` becomes
  269 + # `x <- x + 1 / (1 - p) * f(x)` at training time.
  270 + stoch_layer_coeff = 1.0
  271 + if self.training and self.stochastic_depth_rate > 0:
  272 + skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
  273 + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
  274 +
  275 + if skip_layer:
  276 + if cache is not None:
  277 + x = torch.cat([cache, x], dim=1)
  278 + return x, mask
  279 +
  280 + residual = x
  281 + if self.normalize_before:
  282 + x = self.norm1(x)
  283 +
  284 + if self.concat_after:
  285 + x_concat = torch.cat(
  286 + (
  287 + x,
  288 + self.self_attn(
  289 + x,
  290 + mask,
  291 + mask_shfit_chunk=mask_shfit_chunk,
  292 + mask_att_chunk_encoder=mask_att_chunk_encoder,
  293 + ),
  294 + ),
  295 + dim=-1,
  296 + )
  297 + if self.in_size == self.size:
  298 + x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
  299 + else:
  300 + x = stoch_layer_coeff * self.concat_linear(x_concat)
  301 + else:
  302 + if self.in_size == self.size:
  303 + x = residual + stoch_layer_coeff * self.dropout(
  304 + self.self_attn(
  305 + x,
  306 + mask,
  307 + mask_shfit_chunk=mask_shfit_chunk,
  308 + mask_att_chunk_encoder=mask_att_chunk_encoder,
  309 + )
  310 + )
  311 + else:
  312 + x = stoch_layer_coeff * self.dropout(
  313 + self.self_attn(
  314 + x,
  315 + mask,
  316 + mask_shfit_chunk=mask_shfit_chunk,
  317 + mask_att_chunk_encoder=mask_att_chunk_encoder,
  318 + )
  319 + )
  320 + return x, mask
  321 + if not self.normalize_before:
  322 + x = self.norm1(x)
  323 +
  324 + residual = x
  325 + if self.normalize_before:
  326 + x = self.norm2(x)
  327 + x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
  328 + if not self.normalize_before:
  329 + x = self.norm2(x)
  330 +
  331 + return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
  332 +
  333 +
  334 +class LayerNorm(nn.LayerNorm):
  335 + def __init__(self, *args, **kwargs):
  336 + super().__init__(*args, **kwargs)
  337 +
  338 + def forward(self, input):
  339 + output = F.layer_norm(
  340 + input.float(),
  341 + self.normalized_shape,
  342 + self.weight.float() if self.weight is not None else None,
  343 + self.bias.float() if self.bias is not None else None,
  344 + self.eps,
  345 + )
  346 + return output.type_as(input)
  347 +
  348 +
  349 +def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
  350 + if maxlen is None:
  351 + maxlen = lengths.max()
  352 + row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
  353 + matrix = torch.unsqueeze(lengths, dim=-1)
  354 + mask = row_vector < matrix
  355 + mask = mask.detach()
  356 +
  357 + return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
  358 +
  359 +
  360 +class SenseVoiceEncoderSmall(nn.Module):
  361 + def __init__(self):
  362 + super().__init__()
  363 + self.input_size = 80 * 7
  364 + self.output_size = 512
  365 + self.attention_heads = 4
  366 + self.linear_units = 2048
  367 + self.num_blocks = 50
  368 + self.tp_blocks = 20
  369 + self.input_layer = "pe"
  370 + self.pos_enc_class = "SinusoidalPositionEncoder"
  371 + self.normalize_before = True
  372 + self.kernel_size = 11
  373 + self.sanm_shfit = 0
  374 + self.concat_after = False
  375 + self.positionwise_layer_type = "linear"
  376 + self.positionwise_conv_kernel_size = 1
  377 + self.padding_idx = -1
  378 + self.selfattention_layer_type = "sanm"
  379 + self.dropout_rate = 0.1
  380 + self.attention_dropout_rate = 0.1
  381 +
  382 + self._output_size = self.output_size
  383 +
  384 + self.embed = SinusoidalPositionEncoder()
  385 +
  386 + positionwise_layer = PositionwiseFeedForward
  387 + positionwise_layer_args = (
  388 + self.output_size,
  389 + self.linear_units,
  390 + self.dropout_rate,
  391 + )
  392 +
  393 + encoder_selfattn_layer = MultiHeadedAttentionSANM
  394 + encoder_selfattn_layer_args0 = (
  395 + self.attention_heads,
  396 + self.input_size,
  397 + self.output_size,
  398 + self.attention_dropout_rate,
  399 + self.kernel_size,
  400 + self.sanm_shfit,
  401 + )
  402 + encoder_selfattn_layer_args = (
  403 + self.attention_heads,
  404 + self.output_size,
  405 + self.output_size,
  406 + self.attention_dropout_rate,
  407 + self.kernel_size,
  408 + self.sanm_shfit,
  409 + )
  410 +
  411 + self.encoders0 = nn.ModuleList(
  412 + [
  413 + EncoderLayerSANM(
  414 + self.input_size,
  415 + self.output_size,
  416 + encoder_selfattn_layer(*encoder_selfattn_layer_args0),
  417 + positionwise_layer(*positionwise_layer_args),
  418 + self.dropout_rate,
  419 + )
  420 + for i in range(1)
  421 + ]
  422 + )
  423 +
  424 + self.encoders = nn.ModuleList(
  425 + [
  426 + EncoderLayerSANM(
  427 + self.output_size,
  428 + self.output_size,
  429 + encoder_selfattn_layer(*encoder_selfattn_layer_args),
  430 + positionwise_layer(*positionwise_layer_args),
  431 + self.dropout_rate,
  432 + )
  433 + for i in range(self.num_blocks - 1)
  434 + ]
  435 + )
  436 +
  437 + self.tp_encoders = nn.ModuleList(
  438 + [
  439 + EncoderLayerSANM(
  440 + self.output_size,
  441 + self.output_size,
  442 + encoder_selfattn_layer(*encoder_selfattn_layer_args),
  443 + positionwise_layer(*positionwise_layer_args),
  444 + self.dropout_rate,
  445 + )
  446 + for i in range(self.tp_blocks)
  447 + ]
  448 + )
  449 +
  450 + self.after_norm = LayerNorm(self.output_size)
  451 +
  452 + self.tp_norm = LayerNorm(self.output_size)
  453 +
  454 + def forward(
  455 + self,
  456 + xs_pad: torch.Tensor,
  457 + ):
  458 + masks = None
  459 +
  460 + xs_pad *= self.output_size**0.5
  461 +
  462 + xs_pad = self.embed(xs_pad)
  463 +
  464 + # forward encoder1
  465 + for layer_idx, encoder_layer in enumerate(self.encoders0):
  466 + encoder_outs = encoder_layer(xs_pad, masks)
  467 + xs_pad, masks = encoder_outs[0], encoder_outs[1]
  468 +
  469 + for layer_idx, encoder_layer in enumerate(self.encoders):
  470 + encoder_outs = encoder_layer(xs_pad, masks)
  471 + xs_pad, masks = encoder_outs[0], encoder_outs[1]
  472 +
  473 + xs_pad = self.after_norm(xs_pad)
  474 +
  475 + for layer_idx, encoder_layer in enumerate(self.tp_encoders):
  476 + encoder_outs = encoder_layer(xs_pad, masks)
  477 + xs_pad, masks = encoder_outs[0], encoder_outs[1]
  478 +
  479 + xs_pad = self.tp_norm(xs_pad)
  480 + return xs_pad
  481 +
  482 +
  483 +class CTC(nn.Module):
  484 + def __init__(
  485 + self,
  486 + odim: int,
  487 + encoder_output_size: int,
  488 + dropout_rate: float = 0.0,
  489 + ctc_type: str = "builtin",
  490 + reduce: bool = True,
  491 + ignore_nan_grad: bool = True,
  492 + extra_linear: bool = True,
  493 + ):
  494 + super().__init__()
  495 + eprojs = encoder_output_size
  496 + self.dropout_rate = dropout_rate
  497 +
  498 + if extra_linear:
  499 + self.ctc_lo = torch.nn.Linear(eprojs, odim)
  500 + else:
  501 + self.ctc_lo = None
  502 +
  503 + def softmax(self, hs_pad):
  504 + """softmax of frame activations
  505 +
  506 + Args:
  507 + Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
  508 + Returns:
  509 + torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
  510 + """
  511 + if self.ctc_lo is not None:
  512 + return F.softmax(self.ctc_lo(hs_pad), dim=2)
  513 + else:
  514 + return F.softmax(hs_pad, dim=2)
  515 +
  516 + def log_softmax(self, hs_pad):
  517 + """log_softmax of frame activations
  518 +
  519 + Args:
  520 + Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
  521 + Returns:
  522 + torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
  523 + """
  524 + if self.ctc_lo is not None:
  525 + return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
  526 + else:
  527 + return F.log_softmax(hs_pad, dim=2)
  528 +
  529 + def argmax(self, hs_pad):
  530 + """argmax of frame activations
  531 +
  532 + Args:
  533 + torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
  534 + Returns:
  535 + torch.Tensor: argmax applied 2d tensor (B, Tmax)
  536 + """
  537 + if self.ctc_lo is not None:
  538 + return torch.argmax(self.ctc_lo(hs_pad), dim=2)
  539 + else:
  540 + return torch.argmax(hs_pad, dim=2)
  541 +
  542 +
  543 +class SenseVoiceSmall(nn.Module):
  544 + def __init__(self, neg_mean: torch.Tensor, inv_stddev: torch.Tensor):
  545 + super().__init__()
  546 + self.sos = 1
  547 + self.eos = 2
  548 + self.length_normalized_loss = True
  549 + self.ignore_id = -1
  550 + self.blank_id = 0
  551 + self.input_size = 80 * 7
  552 + self.vocab_size = 25055
  553 +
  554 + self.neg_mean = neg_mean.unsqueeze(0).unsqueeze(0)
  555 + self.inv_stddev = inv_stddev.unsqueeze(0).unsqueeze(0)
  556 +
  557 + self.lid_dict = {
  558 + "auto": 0,
  559 + "zh": 3,
  560 + "en": 4,
  561 + "yue": 7,
  562 + "ja": 11,
  563 + "ko": 12,
  564 + "nospeech": 13,
  565 + }
  566 + self.lid_int_dict = {
  567 + 24884: 3,
  568 + 24885: 4,
  569 + 24888: 7,
  570 + 24892: 11,
  571 + 24896: 12,
  572 + 24992: 13,
  573 + }
  574 + self.textnorm_dict = {"withitn": 14, "woitn": 15}
  575 + self.textnorm_int_dict = {25016: 14, 25017: 15}
  576 +
  577 + self.emo_dict = {
  578 + "unk": 25009,
  579 + "happy": 25001,
  580 + "sad": 25002,
  581 + "angry": 25003,
  582 + "neutral": 25004,
  583 + }
  584 +
  585 + self.encoder = SenseVoiceEncoderSmall()
  586 + self.ctc = CTC(
  587 + odim=self.vocab_size,
  588 + encoder_output_size=self.encoder.output_size,
  589 + )
  590 + self.embed = torch.nn.Embedding(
  591 + 7 + len(self.lid_dict) + len(self.textnorm_dict), self.input_size
  592 + )
  593 +
  594 + def forward(self, x, prompt):
  595 + input_query = self.embed(prompt).unsqueeze(0)
  596 +
  597 + # for export, we always assume x and self.neg_mean are on CPU
  598 + x = (x + self.neg_mean) * self.inv_stddev
  599 + x = torch.cat((input_query, x), dim=1)
  600 +
  601 + encoder_out = self.encoder(x)
  602 + logits = self.ctc.ctc_lo(encoder_out)
  603 +
  604 + return logits
@@ -173,6 +173,8 @@ list(APPEND sources
 )
 if(SHERPA_ONNX_ENABLE_RKNN)
   list(APPEND sources
+    ./rknn/offline-ctc-greedy-search-decoder-rknn.cc
+    ./rknn/offline-sense-voice-model-rknn.cc
     ./rknn/online-stream-rknn.cc
     ./rknn/online-transducer-greedy-search-decoder-rknn.cc
     ./rknn/online-transducer-modified-beam-search-decoder-rknn.cc
@@ -7,6 +7,7 @@

 #include "sherpa-onnx/csrc/file-utils.h"
 #include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/text-utils.h"

 namespace sherpa_onnx {

@@ -57,9 +58,37 @@ void OfflineModelConfig::Register(ParseOptions *po) {
 }

 bool OfflineModelConfig::Validate() const {
-  if (num_threads < 1) {
-    SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
-    return false;
+  // For RK NPU, we reinterpret num_threads:
+  //
+  // For RK3588 only
+  // num_threads == 1 -> Select a core randomly
+  // num_threads == 0 -> Use NPU core 0
+  // num_threads == -1 -> Use NPU core 1
+  // num_threads == -2 -> Use NPU core 2
+  // num_threads == -3 -> Use NPU core 0 and core 1
+  // num_threads == -4 -> Use NPU core 0, core 1, and core 2
+  if (provider != "rknn") {
+    if (num_threads < 1) {
+      SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
+      return false;
+    }
+    if (!sense_voice.model.empty() && (EndsWith(sense_voice.model, ".rknn"))) {
+      SHERPA_ONNX_LOGE(
+          "--provider is %s, which is not rknn, but you pass a rknn model "
+          "filename. model: '%s'",
+          provider.c_str(), sense_voice.model.c_str());
+      return false;
+    }
+  }
+
+  if (provider == "rknn") {
+    if (!sense_voice.model.empty() && (EndsWith(sense_voice.model, ".onnx"))) {
+      SHERPA_ONNX_LOGE(
+          "--provider is rknn, but you pass an onnx model "
+          "filename. model: '%s'",
+          sense_voice.model.c_str());
+      return false;
+    }
   }

   if (!FileExists(tokens)) {
@@ -35,10 +35,32 @@
 #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
 #include "sherpa-onnx/csrc/text-utils.h"

+#if SHERPA_ONNX_ENABLE_RKNN
+#include "sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h"
+#endif
+
 namespace sherpa_onnx {

 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     const OfflineRecognizerConfig &config) {
+  if (config.model_config.provider == "rknn") {
+#if SHERPA_ONNX_ENABLE_RKNN
+    if (config.model_config.sense_voice.model.empty()) {
+      SHERPA_ONNX_LOGE(
+          "Only SenseVoice models are currently supported "
+          "by rknn for non-streaming ASR. Fallback to CPU");
+    } else if (!config.model_config.sense_voice.model.empty()) {
+      return std::make_unique<OfflineRecognizerSenseVoiceRknnImpl>(config);
+    }
+#else
+    SHERPA_ONNX_LOGE(
+        "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
+        "want to use rknn.");
+    SHERPA_ONNX_EXIT(-1);
+    return nullptr;
+#endif
+  }
+
   if (!config.model_config.sense_voice.model.empty()) {
     return std::make_unique<OfflineRecognizerSenseVoiceImpl>(config);
   }
@@ -229,6 +251,24 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
 template <typename Manager>
 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     Manager *mgr, const OfflineRecognizerConfig &config) {
+  if (config.model_config.provider == "rknn") {
+#if SHERPA_ONNX_ENABLE_RKNN
+    if (config.model_config.sense_voice.model.empty()) {
+      SHERPA_ONNX_LOGE(
+          "Only SenseVoice models are currently supported "
+          "by rknn for non-streaming ASR. Fallback to CPU");
+    } else if (!config.model_config.sense_voice.model.empty()) {
+      return std::make_unique<OfflineRecognizerSenseVoiceRknnImpl>(mgr, config);
+    }
+#else
+    SHERPA_ONNX_LOGE(
+        "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
+        "want to use rknn.");
+    SHERPA_ONNX_EXIT(-1);
+    return nullptr;
+#endif
+  }
+
   if (!config.model_config.sense_voice.model.empty()) {
     return std::make_unique<OfflineRecognizerSenseVoiceImpl>(mgr, config);
   }
@@ -11,6 +11,7 @@
 #include <utility>
 #include <vector>

+#include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
 #include "sherpa-onnx/csrc/offline-model-config.h"
 #include "sherpa-onnx/csrc/offline-recognizer-impl.h"
@@ -21,7 +22,7 @@

 namespace sherpa_onnx {

-static OfflineRecognitionResult ConvertSenseVoiceResult(
+OfflineRecognitionResult ConvertSenseVoiceResult(
     const OfflineCtcDecoderResult &src, const SymbolTable &sym_table,
     int32_t frame_shift_ms, int32_t subsampling_factor) {
   OfflineRecognitionResult r;
@@ -72,7 +73,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
     } else {
       SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
                        config.decoding_method.c_str());
-      exit(-1);
+      SHERPA_ONNX_EXIT(-1);
     }

     InitFeatConfig();
@@ -93,7 +94,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
     } else {
       SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
                        config.decoding_method.c_str());
-      exit(-1);
+      SHERPA_ONNX_EXIT(-1);
     }

     InitFeatConfig();
@@ -37,7 +37,6 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     const OnlineRecognizerConfig &config) {
   if (config.model_config.provider_config.provider == "rknn") {
 #if SHERPA_ONNX_ENABLE_RKNN
-    // Currently, only zipformer v1 is suported for rknn
     if (config.model_config.transducer.encoder.empty() &&
         config.model_config.zipformer2_ctc.model.empty()) {
       SHERPA_ONNX_LOGE(
  1 +// sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h"
  6 +
  7 +#include <algorithm>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +OfflineCtcDecoderResult OfflineCtcGreedySearchDecoderRknn::Decode(
  16 + const float *logits, int32_t num_frames, int32_t vocab_size) {
  17 + OfflineCtcDecoderResult ans;
  18 +
  19 + int64_t prev_id = -1;
  20 +
  21 + for (int32_t t = 0; t != num_frames; ++t) {
  22 + auto y = static_cast<int64_t>(std::distance(
  23 + static_cast<const float *>(logits),
  24 + std::max_element(static_cast<const float *>(logits),
  25 + static_cast<const float *>(logits) + vocab_size)));
  26 + logits += vocab_size;
  27 +
  28 + if (y != blank_id_ && y != prev_id) {
  29 + ans.tokens.push_back(y);
  30 + ans.timestamps.push_back(t);
  31 + }
  32 + prev_id = y;
  33 + } // for (int32_t t = 0; ...)
  34 +
  35 + return ans;
  36 +}
  37 +
  38 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_
  6 +#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +class OfflineCtcGreedySearchDecoderRknn {
  15 + public:
  16 + explicit OfflineCtcGreedySearchDecoderRknn(int32_t blank_id)
  17 + : blank_id_(blank_id) {}
  18 +
  19 + OfflineCtcDecoderResult Decode(const float *logits, int32_t num_frames,
  20 + int32_t vocab_size);
  21 +
  22 + private:
  23 + int32_t blank_id_;
  24 +};
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_
  1 +// sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_
  7 +
  8 +#include <memory>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "sherpa-onnx/csrc/macros.h"
  13 +#include "sherpa-onnx/csrc/offline-model-config.h"
  14 +#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
  15 +#include "sherpa-onnx/csrc/offline-recognizer.h"
  16 +#include "sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h"
  17 +#include "sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h"
  18 +#include "sherpa-onnx/csrc/symbol-table.h"
  19 +
  20 +namespace sherpa_onnx {
  21 +
  22 +// defined in ../offline-recognizer-sense-voice-impl.h
  23 +OfflineRecognitionResult ConvertSenseVoiceResult(
  24 + const OfflineCtcDecoderResult &src, const SymbolTable &sym_table,
  25 + int32_t frame_shift_ms, int32_t subsampling_factor);
  26 +
  27 +class OfflineRecognizerSenseVoiceRknnImpl : public OfflineRecognizerImpl {
  28 + public:
  29 + explicit OfflineRecognizerSenseVoiceRknnImpl(
  30 + const OfflineRecognizerConfig &config)
  31 + : OfflineRecognizerImpl(config),
  32 + config_(config),
  33 + symbol_table_(config_.model_config.tokens),
  34 + model_(
  35 + std::make_unique<OfflineSenseVoiceModelRknn>(config.model_config)) {
  36 + const auto &meta_data = model_->GetModelMetadata();
  37 + if (config.decoding_method == "greedy_search") {
  38 + decoder_ = std::make_unique<OfflineCtcGreedySearchDecoderRknn>(
  39 + meta_data.blank_id);
  40 + } else {
  41 + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
  42 + config.decoding_method.c_str());
  43 + SHERPA_ONNX_EXIT(-1);
  44 + }
  45 +
  46 + InitFeatConfig();
  47 + }
  48 +
  49 + template <typename Manager>
  50 + OfflineRecognizerSenseVoiceRknnImpl(Manager *mgr,
  51 + const OfflineRecognizerConfig &config)
  52 + : OfflineRecognizerImpl(mgr, config),
  53 + config_(config),
  54 + symbol_table_(mgr, config_.model_config.tokens),
  55 + model_(std::make_unique<OfflineSenseVoiceModelRknn>(
  56 + mgr, config.model_config)) {
  57 + const auto &meta_data = model_->GetModelMetadata();
  58 + if (config.decoding_method == "greedy_search") {
  59 + decoder_ = std::make_unique<OfflineCtcGreedySearchDecoderRknn>(
  60 + meta_data.blank_id);
  61 + } else {
  62 + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
  63 + config.decoding_method.c_str());
  64 + SHERPA_ONNX_EXIT(-1);
  65 + }
  66 +
  67 + InitFeatConfig();
  68 + }
  69 +
  70 + std::unique_ptr<OfflineStream> CreateStream() const override {
  71 + return std::make_unique<OfflineStream>(config_.feat_config);
  72 + }
  73 +
  74 + void DecodeStreams(OfflineStream **ss, int32_t n) const override {
  75 + for (int32_t i = 0; i < n; ++i) {
  76 + DecodeOneStream(ss[i]);
  77 + }
  78 + }
  79 +
  80 + OfflineRecognizerConfig GetConfig() const override { return config_; }
  81 +
  82 + private:
  83 + void InitFeatConfig() {
  84 + const auto &meta_data = model_->GetModelMetadata();
  85 +
  86 + config_.feat_config.normalize_samples = meta_data.normalize_samples;
  87 + config_.feat_config.window_type = "hamming";
  88 + config_.feat_config.high_freq = 0;
  89 + config_.feat_config.snip_edges = true;
  90 + }
  91 +
  92 + void DecodeOneStream(OfflineStream *s) const {
  93 + const auto &meta_data = model_->GetModelMetadata();
  94 +
  95 + std::vector<float> f = s->GetFrames();
  96 +
  97 + int32_t language = 0;
  98 + if (config_.model_config.sense_voice.language.empty()) {
  99 + language = 0;
  100 + } else if (meta_data.lang2id.count(
  101 + config_.model_config.sense_voice.language)) {
  102 + language =
  103 + meta_data.lang2id.at(config_.model_config.sense_voice.language);
  104 + } else {
  105 + SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.",
  106 + config_.model_config.sense_voice.language.c_str());
  107 + }
  108 +
  109 + int32_t text_norm = config_.model_config.sense_voice.use_itn
  110 + ? meta_data.with_itn_id
  111 + : meta_data.without_itn_id;
  112 +
  113 + std::vector<float> logits = model_->Run(std::move(f), language, text_norm);
  114 + int32_t num_out_frames = logits.size() / meta_data.vocab_size;
  115 +
  116 + auto result =
  117 + decoder_->Decode(logits.data(), num_out_frames, meta_data.vocab_size);
  118 +
  119 + int32_t frame_shift_ms = 10;
  120 + int32_t subsampling_factor = meta_data.window_shift;
  121 + auto r = ConvertSenseVoiceResult(result, symbol_table_, frame_shift_ms,
  122 + subsampling_factor);
  123 +
  124 + r.text = ApplyInverseTextNormalization(std::move(r.text));
  125 + r.text = ApplyHomophoneReplacer(std::move(r.text));
  126 + s->SetResult(r);
  127 + }
  128 +
  129 + private:
  130 + OfflineRecognizerConfig config_;
  131 + SymbolTable symbol_table_;
  132 + std::unique_ptr<OfflineSenseVoiceModelRknn> model_;
  133 + std::unique_ptr<OfflineCtcGreedySearchDecoderRknn> decoder_;
  134 +};
  135 +
  136 +} // namespace sherpa_onnx
  137 +
  138 +#endif // SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_
  1 +// sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h"
  6 +
  7 +#include <algorithm>
  8 +#include <array>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
  17 +#if __OHOS__
  18 +#include "rawfile/raw_file_manager.h"
  19 +#endif
  20 +
  21 +#include "sherpa-onnx/csrc/file-utils.h"
  22 +#include "sherpa-onnx/csrc/rknn/macros.h"
  23 +#include "sherpa-onnx/csrc/rknn/utils.h"
  24 +
  25 +namespace sherpa_onnx {
  26 +
  27 +class OfflineSenseVoiceModelRknn::Impl {
  28 + public:
  29 + ~Impl() {
  30 + auto ret = rknn_destroy(ctx_);
  31 + if (ret != RKNN_SUCC) {
  32 + SHERPA_ONNX_LOGE("Failed to destroy the context");
  33 + }
  34 + }
  35 +
  36 + explicit Impl(const OfflineModelConfig &config) : config_(config) {
  37 + {
  38 + auto buf = ReadFile(config_.sense_voice.model);
  39 + Init(buf.data(), buf.size());
  40 + }
  41 +
  42 + SetCoreMask(ctx_, config_.num_threads);
  43 + }
  44 +
  45 + template <typename Manager>
  46 + Impl(Manager *mgr, const OfflineModelConfig &config) : config_(config) {
  47 + {
  48 + auto buf = ReadFile(mgr, config_.sense_voice.model);
  49 + Init(buf.data(), buf.size());
  50 + }
  51 +
  52 + SetCoreMask(ctx_, config_.num_threads);
  53 + }
  54 +
  55 + const OfflineSenseVoiceModelMetaData &GetModelMetadata() const {
  56 + return meta_data_;
  57 + }
  58 +
  59 + std::vector<float> Run(std::vector<float> features, int32_t language,
  60 + int32_t text_norm) {
  61 + features = ApplyLFR(std::move(features));
  62 +
  63 + std::vector<rknn_input> inputs(input_attrs_.size());
  64 +
  65 + std::array<int32_t, 4> prompt{language, 1, 2, text_norm};
  66 +
  67 + inputs[0].index = input_attrs_[0].index;
  68 + inputs[0].type = RKNN_TENSOR_FLOAT32;
  69 + inputs[0].fmt = input_attrs_[0].fmt;
  70 + inputs[0].buf = reinterpret_cast<void *>(features.data());
  71 + inputs[0].size = features.size() * sizeof(float);
  72 +
  73 + inputs[1].index = input_attrs_[1].index;
  74 + inputs[1].type = RKNN_TENSOR_INT32;
  75 + inputs[1].fmt = input_attrs_[1].fmt;
  76 + inputs[1].buf = reinterpret_cast<void *>(prompt.data());
  77 + inputs[1].size = prompt.size() * sizeof(int32_t);
  78 +
  79 + std::vector<float> out(output_attrs_[0].n_elems);
  80 +
  81 + std::vector<rknn_output> outputs(output_attrs_.size());
  82 + outputs[0].index = output_attrs_[0].index;
  83 + outputs[0].is_prealloc = 1;
  84 + outputs[0].want_float = 1;
  85 + outputs[0].size = out.size() * sizeof(float);
  86 + outputs[0].buf = reinterpret_cast<void *>(out.data());
  87 +
  88 + rknn_context ctx = 0;
  89 + auto ret = rknn_dup_context(&ctx_, &ctx);
  90 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the ctx");
  91 +
  92 + ret = rknn_inputs_set(ctx, inputs.size(), inputs.data());
  93 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");
  94 +
  95 + ret = rknn_run(ctx, nullptr);
  96 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");
  97 +
  98 + ret = rknn_outputs_get(ctx, outputs.size(), outputs.data(), nullptr);
  99 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");
  100 +
  101 + rknn_destroy(ctx);
  102 +
  103 + return out;
  104 + }
  105 +
  106 + private:
  107 + void Init(void *model_data, size_t model_data_length) {
  108 + InitContext(model_data, model_data_length, config_.debug, &ctx_);
  109 +
  110 + InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_);
  111 +
  112 + rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug);
  113 +
  114 + auto meta = Parse(custom_string, config_.debug);
  115 +
  116 +#define SHERPA_ONNX_RKNN_READ_META_DATA_INT(dst, src_key) \
  117 + do { \
  118 + if (!meta.count(#src_key)) { \
  119 + SHERPA_ONNX_LOGE("'%s' does not exist in the custom_string", #src_key); \
  120 + SHERPA_ONNX_EXIT(-1); \
  121 + } \
  122 + \
  123 + dst = atoi(meta.at(#src_key).c_str()); \
  124 + } while (0)
  125 +
  126 + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.with_itn_id, with_itn);
  127 + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.without_itn_id, without_itn);
  128 + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.window_size,
  129 + lfr_window_size);
  130 + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.window_shift,
  131 + lfr_window_shift);
  132 + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.vocab_size, vocab_size);
  133 + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.normalize_samples,
  134 + normalize_samples);
  135 +
  136 + int32_t lang_auto = 0;
  137 + int32_t lang_zh = 0;
  138 + int32_t lang_en = 0;
  139 + int32_t lang_ja = 0;
  140 + int32_t lang_ko = 0;
  141 + int32_t lang_yue = 0;
  142 +
  143 + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_auto, lang_auto);
  144 + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_zh, lang_zh);
  145 + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_en, lang_en);
  146 + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_ja, lang_ja);
  147 + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_ko, lang_ko);
  148 + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_yue, lang_yue);
  149 +
  150 + meta_data_.lang2id = {
  151 + {"auto", lang_auto}, {"zh", lang_zh}, {"en", lang_en},
  152 + {"ja", lang_ja}, {"ko", lang_ko}, {"yue", lang_yue},
  153 + };
  154 +
  155 + // for rknn models, neg_mean and inv_stddev are stored inside the model
  156 +
  157 +#undef SHERPA_ONNX_RKNN_READ_META_DATA_INT
  158 +
  159 + num_input_frames_ = input_attrs_[0].dims[1];
  160 + }
  161 +
  162 + std::vector<float> ApplyLFR(std::vector<float> in) const {
  163 + int32_t lfr_window_size = meta_data_.window_size;
  164 + int32_t lfr_window_shift = meta_data_.window_shift;
  165 + int32_t in_feat_dim = 80;
  166 +
  167 + int32_t in_num_frames = in.size() / in_feat_dim;
  168 + int32_t out_num_frames =
  169 + (in_num_frames - lfr_window_size) / lfr_window_shift + 1;
  170 +
  171 + if (out_num_frames > num_input_frames_) {
  172 + SHERPA_ONNX_LOGE(
  173 + "Number of input frames %d is too large. Truncate it to %d frames.",
  174 + out_num_frames, num_input_frames_);
  175 +
  176 + SHERPA_ONNX_LOGE(
  177 + "Recognition result may be truncated/incomplete. Please select a "
  178 + "model accepting longer audios.");
  179 +
  180 + out_num_frames = num_input_frames_;
  181 + }
  182 +
  183 + int32_t out_feat_dim = in_feat_dim * lfr_window_size;
  184 +
  185 + std::vector<float> out(num_input_frames_ * out_feat_dim);
  186 +
  187 + const float *p_in = in.data();
  188 + float *p_out = out.data();
  189 +
  190 + for (int32_t i = 0; i != out_num_frames; ++i) {
  191 + std::copy(p_in, p_in + out_feat_dim, p_out);
  192 +
  193 + p_out += out_feat_dim;
  194 + p_in += lfr_window_shift * in_feat_dim;
  195 + }
  196 +
  197 + return out;
  198 + }
  199 +
  200 + private:
  201 + OfflineModelConfig config_;
  202 +
  203 + rknn_context ctx_ = 0;
  204 +
  205 + std::vector<rknn_tensor_attr> input_attrs_;
  206 + std::vector<rknn_tensor_attr> output_attrs_;
  207 +
  208 + OfflineSenseVoiceModelMetaData meta_data_;
  209 + int32_t num_input_frames_ = -1;
  210 +};
  211 +
  212 +OfflineSenseVoiceModelRknn::~OfflineSenseVoiceModelRknn() = default;
  213 +
  214 +OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn(
  215 + const OfflineModelConfig &config)
  216 + : impl_(std::make_unique<Impl>(config)) {}
  217 +
  218 +template <typename Manager>
  219 +OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn(
  220 + Manager *mgr, const OfflineModelConfig &config)
  221 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  222 +
  223 +std::vector<float> OfflineSenseVoiceModelRknn::Run(std::vector<float> features,
  224 + int32_t language,
  225 + int32_t text_norm) const {
  226 + return impl_->Run(std::move(features), language, text_norm);
  227 +}
  228 +
  229 +const OfflineSenseVoiceModelMetaData &
  230 +OfflineSenseVoiceModelRknn::GetModelMetadata() const {
  231 + return impl_->GetModelMetadata();
  232 +}
  233 +
  234 +#if __ANDROID_API__ >= 9
  235 +template OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn(
  236 + AAssetManager *mgr, const OfflineModelConfig &config);
  237 +#endif
  238 +
  239 +#if __OHOS__
  240 +template OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn(
  241 + NativeResourceManager *mgr, const OfflineModelConfig &config);
  242 +#endif
  243 +
  244 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_
  5 +#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_
  6 +
  7 +#include <memory>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "rknn_api.h" // NOLINT
  12 +#include "sherpa-onnx/csrc/offline-model-config.h"
  13 +#include "sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +class OfflineSenseVoiceModelRknn {
  18 + public:
  19 + ~OfflineSenseVoiceModelRknn();
  20 +
  21 + explicit OfflineSenseVoiceModelRknn(const OfflineModelConfig &config);
  22 +
  23 + template <typename Manager>
  24 + OfflineSenseVoiceModelRknn(Manager *mgr, const OfflineModelConfig &config);
  25 +
  26 + /**
  27 + * @param features A tensor of shape (num_frames, feature_dim)
  28 + * before applying LFR.
  29 + * @param language
  30 + * @param text_norm
  31 + * @returns Return a tensor of shape (num_output_frames, vocab_size)
  32 + */
  33 + std::vector<float> Run(std::vector<float> features, int32_t language,
  34 + int32_t text_norm) const;
  35 +
  36 + const OfflineSenseVoiceModelMetaData &GetModelMetadata() const;
  37 +
  38 + private:
  39 + class Impl;
  40 + std::unique_ptr<Impl> impl_;
  41 +};
  42 +
  43 +} // namespace sherpa_onnx
  44 +
  45 +#endif // SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_
@@ -55,8 +55,6 @@ class OnlineZipformerCtcModelRknn::Impl {
     SetCoreMask(ctx_, config_.num_threads);
   }

-  // TODO(fangjun): Support Android
-
   std::vector<std::vector<uint8_t>> GetInitStates() const {
     // input_attrs_[0] is for the feature
     // input_attrs_[1:] is for states
@@ -89,8 +89,6 @@ class OnlineZipformerTransducerModelRknn::Impl {
     SetCoreMask(joiner_ctx_, config_.num_threads);
   }

-  // TODO(fangjun): Support Android
-
   std::vector<std::vector<uint8_t>> GetEncoderInitStates() const {
     // encoder_input_attrs_[0] is for the feature
     // encoder_input_attrs_[1:] is for states