manyeyes
Committed by GitHub

adding a python api for offline decode (#110)

  1 +#!/usr/bin/env python3
  2 +#
  3 +# Copyright (c) 2023 by manyeyes
  4 +
  5 +"""
  6 +This file demonstrates how to use sherpa-onnx Python API to transcribe
  7 +file(s) with a non-streaming model.
  8 +
  9 +paraformer Usage:
  10 + ./python-api-examples/offline-decode-files.py \
  11 + --tokens=/path/to/tokens.txt \
  12 + --paraformer=/path/to/paraformer.onnx \
  13 + --num-threads=2 \
  14 + --decoding-method=greedy_search \
  15 + --debug=false \
  16 + --sample-rate=16000 \
  17 + --feature-dim=80 \
  18 + /path/to/0.wav \
  19 + /path/to/1.wav
  20 +
  21 +transducer Usage:
  22 + ./python-api-examples/offline-decode-files.py \
  23 + --tokens=/path/to/tokens.txt \
  24 + --encoder=/path/to/encoder.onnx \
  25 + --decoder=/path/to/decoder.onnx \
  26 + --joiner=/path/to/joiner.onnx \
  27 + --num-threads=2 \
  28 + --decoding-method=greedy_search \
  29 + --debug=false \
  30 + --sample-rate=16000 \
  31 + --feature-dim=80 \
  32 + /path/to/0.wav \
  33 + /path/to/1.wav
  34 +
  35 +Please refer to
  36 +https://k2-fsa.github.io/sherpa/onnx/index.html
  37 +to install sherpa-onnx and to download the pre-trained models
  38 +used in this file.
  39 +"""
  40 +import argparse
  41 +import time
  42 +import wave
  43 +from pathlib import Path
  44 +from typing import Tuple
  45 +
  46 +import numpy as np
  47 +import sherpa_onnx
  48 +
def get_args():
    """Parse the command-line arguments of the offline decoding example.

    Returns:
      An ``argparse.Namespace`` with the attributes ``tokens``, ``encoder``,
      ``decoder``, ``joiner``, ``paraformer``, ``num_threads``,
      ``decoding_method``, ``debug``, ``sample_rate``, ``feature_dim`` and
      ``sound_files`` (a non-empty list of paths).
    """

    def str2bool(value: str) -> bool:
        # ``type=bool`` would call ``bool("false")``, which is True because
        # every non-empty string is truthy. Parse the text explicitly so that
        # ``--debug=false`` (as shown in the usage message) really disables
        # debugging. An empty string stays False, as it was with ``bool``.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "on", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "off", "0", ""):
            return False
        raise argparse.ArgumentTypeError(
            f"Expected a boolean value (true/false), got: {value!r}"
        )

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--encoder",
        default="",
        type=str,
        help="Path to the encoder model",
    )

    parser.add_argument(
        "--decoder",
        default="",
        type=str,
        help="Path to the decoder model",
    )

    parser.add_argument(
        "--joiner",
        default="",
        type=str,
        help="Path to the joiner model",
    )

    parser.add_argument(
        "--paraformer",
        default="",
        type=str,
        help="Path to the paraformer model",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Number of threads for neural network computation",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="Valid values are greedy_search and modified_beam_search",
    )

    parser.add_argument(
        "--debug",
        type=str2bool,
        default=False,
        help="True to show debug messages",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="Sample rate of the feature extractor. Must match the one "
        "expected by the model. Note: The input sound files can have a "
        "different sample rate from this argument.",
    )

    parser.add_argument(
        "--feature-dim",
        type=int,
        default=80,
        help="Feature dimension. Must match the one expected by the model",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        # Note the trailing space after "WAVE": adjacent literals are joined
        # without any separator.
        help="The input sound file(s) to decode. Each file must be of WAVE "
        "format with a single channel, and each sample has 16-bit, "
        "i.e., int16_t. "
        "The sample rate of the file can be arbitrary and does not need to "
        "be 16 kHz",
    )

    return parser.parse_args()
  134 +
  135 +
def assert_file_exists(filename: str):
    """Assert that ``filename`` refers to an existing regular file.

    Args:
      filename:
        Path to check.
    Raises:
      AssertionError: If the path is not a file. The message points to the
        page from which pre-trained models can be downloaded.
    """
    is_regular_file = Path(filename).is_file()
    assert is_regular_file, (
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
    )
  142 +
  143 +
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a single-channel, 16-bit wave file as normalized float samples.

    Args:
      wave_filename:
        Path to a wave file. It should be single channel and each sample should
        be 16-bit. Its sample rate does not need to be 16kHz.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples, which are
         normalized to the range [-1, 1].
       - sample rate of the wave file
    """
    with wave.open(wave_filename) as reader:
        num_channels = reader.getnchannels()
        assert num_channels == 1, num_channels

        bytes_per_sample = reader.getsampwidth()  # it is in bytes
        assert bytes_per_sample == 2, bytes_per_sample

        frame_rate = reader.getframerate()
        raw_bytes = reader.readframes(reader.getnframes())

    # Reinterpret the raw bytes as int16 and scale by 2**15 into [-1, 1].
    pcm = np.frombuffer(raw_bytes, dtype=np.int16)
    return pcm.astype(np.float32) / 32768, frame_rate
  167 +
def main():
    """Decode every input wave file with a non-streaming model.

    A transducer recognizer is built when ``--encoder`` is given; otherwise a
    paraformer recognizer is built. All files are decoded in one batch, the
    transcripts are printed, and the real time factor (RTF) is reported over
    the summed duration of all input files.
    """
    args = get_args()
    assert_file_exists(args.tokens)
    assert args.num_threads > 0, args.num_threads

    if args.encoder:
        assert_file_exists(args.encoder)
        assert_file_exists(args.decoder)
        assert_file_exists(args.joiner)
        # The two model families are mutually exclusive.
        assert len(args.paraformer) == 0, args.paraformer
        recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
            encoder=args.encoder,
            decoder=args.decoder,
            joiner=args.joiner,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feature_dim,
            decoding_method=args.decoding_method,
            debug=args.debug,
        )
    else:
        assert_file_exists(args.paraformer)
        recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
            paraformer=args.paraformer,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feature_dim,
            decoding_method=args.decoding_method,
            debug=args.debug,
        )

    print("Started!")
    start_time = time.time()

    streams = []
    total_duration = 0
    for wave_filename in args.sound_files:
        assert_file_exists(wave_filename)
        samples, sample_rate = read_wave(wave_filename)
        total_duration += len(samples) / sample_rate

        s = recognizer.create_stream()
        s.accept_waveform(sample_rate, samples)

        # Append 0.2 seconds of silence after the real samples.
        tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
        s.accept_waveform(sample_rate, tail_paddings)

        streams.append(s)

    recognizer.decode_streams(streams)
    results = [s.result.text for s in streams]
    end_time = time.time()
    print("Done!")

    for wave_filename, result in zip(args.sound_files, results):
        print(f"{wave_filename}\n{result}")
        print("-" * 10)

    elapsed_seconds = end_time - start_time
    # The elapsed time covers all files, so the RTF must be computed against
    # the summed duration, not the duration of the last file only.
    if total_duration > 0:
        rtf = elapsed_seconds / total_duration
    else:
        # Every input file was empty; avoid a ZeroDivisionError.
        rtf = float("nan")

    print(f"num_threads: {args.num_threads}")
    print(f"decoding_method: {args.decoding_method}")
    print(f"Wave duration: {total_duration:.3f} s")
    print(f"Elapsed time: {elapsed_seconds:.3f} s")
    print(
        f"Real time factor (RTF): "
        f"{elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
    )


if __name__ == "__main__":
    main()
@@ -16,20 +16,7 @@ @@ -16,20 +16,7 @@
16 16
17 namespace sherpa_onnx { 17 namespace sherpa_onnx {
18 18
19 -struct OfflineRecognitionResult {  
20 - // Recognition results.  
21 - // For English, it consists of space separated words.  
22 - // For Chinese, it consists of Chinese words without spaces.  
23 - std::string text;  
24 -  
25 - // Decoded results at the token level.  
26 - // For instance, for BPE-based models it consists of a list of BPE tokens.  
27 - std::vector<std::string> tokens;  
28 -  
29 - /// timestamps.size() == tokens.size()  
30 - /// timestamps[i] records the time in seconds when tokens[i] is decoded.  
31 - std::vector<float> timestamps;  
32 -}; 19 +struct OfflineRecognitionResult;
33 20
34 struct OfflineRecognizerConfig { 21 struct OfflineRecognizerConfig {
35 OfflineFeatureExtractorConfig feat_config; 22 OfflineFeatureExtractorConfig feat_config;
@@ -13,7 +13,21 @@ @@ -13,7 +13,21 @@
13 #include "sherpa-onnx/csrc/parse-options.h" 13 #include "sherpa-onnx/csrc/parse-options.h"
14 14
15 namespace sherpa_onnx { 15 namespace sherpa_onnx {
16 -struct OfflineRecognitionResult; 16 +
// Result of offline (non-streaming) recognition: the recognized text plus
// token-level details.
struct OfflineRecognitionResult {
  // Recognition results.
  // For English, it consists of space separated words.
  // For Chinese, it consists of Chinese words without spaces.
  std::string text;

  // Decoded results at the token level.
  // For instance, for BPE-based models it consists of a list of BPE tokens.
  std::vector<std::string> tokens;

  /// timestamps.size() == tokens.size()
  /// timestamps[i] records the time in seconds when tokens[i] is decoded.
  std::vector<float> timestamps;
};
17 31
18 struct OfflineFeatureExtractorConfig { 32 struct OfflineFeatureExtractorConfig {
19 // Sampling rate used by the feature extractor. If it is different from 33 // Sampling rate used by the feature extractor. If it is different from
@@ -4,6 +4,11 @@ pybind11_add_module(_sherpa_onnx @@ -4,6 +4,11 @@ pybind11_add_module(_sherpa_onnx
4 display.cc 4 display.cc
5 endpoint.cc 5 endpoint.cc
6 features.cc 6 features.cc
  7 + offline-model-config.cc
  8 + offline-paraformer-model-config.cc
  9 + offline-recognizer.cc
  10 + offline-stream.cc
  11 + offline-transducer-model-config.cc
7 online-recognizer.cc 12 online-recognizer.cc
8 online-stream.cc 13 online-stream.cc
9 online-transducer-model-config.cc 14 online-transducer-model-config.cc
  1 +// sherpa-onnx/python/csrc/offline-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 by manyeyes
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-model-config.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
  11 +#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
  12 +
  13 +#include "sherpa-onnx/csrc/offline-model-config.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +void PybindOfflineModelConfig(py::module *m) {
  18 + PybindOfflineTransducerModelConfig(m);
  19 + PybindOfflineParaformerModelConfig(m);
  20 +
  21 + using PyClass = OfflineModelConfig;
  22 + py::class_<PyClass>(*m, "OfflineModelConfig")
  23 + .def(py::init<OfflineTransducerModelConfig &,
  24 + OfflineParaformerModelConfig &,
  25 + const std::string &, int32_t, bool>(),
  26 + py::arg("transducer"), py::arg("paraformer"), py::arg("tokens"),
  27 + py::arg("num_threads"), py::arg("debug") = false)
  28 + .def_readwrite("transducer", &PyClass::transducer)
  29 + .def_readwrite("paraformer", &PyClass::paraformer)
  30 + .def_readwrite("tokens", &PyClass::tokens)
  31 + .def_readwrite("num_threads", &PyClass::num_threads)
  32 + .def_readwrite("debug", &PyClass::debug)
  33 + .def("__str__", &PyClass::ToString);
  34 +}
  35 +
  36 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-model-config.h
  2 +//
  3 +// Copyright (c) 2023 by manyeyes
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
namespace sherpa_onnx {

// Registers the Python bindings for OfflineModelConfig (and the transducer
// and paraformer model configs it contains) on module `m`.
void PybindOfflineModelConfig(py::module *m);

}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MODEL_CONFIG_H_
  1 +// sherpa-onnx/python/csrc/offline-paraformer-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 by manyeyes
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
  6 +
  7 +
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +void PybindOfflineParaformerModelConfig(py::module *m) {
  16 + using PyClass = OfflineParaformerModelConfig;
  17 + py::class_<PyClass>(*m, "OfflineParaformerModelConfig")
  18 + .def(py::init<const std::string &>(),
  19 + py::arg("model"))
  20 + .def_readwrite("model", &PyClass::model)
  21 + .def("__str__", &PyClass::ToString);
  22 +}
  23 +
  24 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-paraformer-model-config.h
  2 +//
  3 +// Copyright (c) 2023 by manyeyes
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
namespace sherpa_onnx {

// Registers the Python binding for OfflineParaformerModelConfig on module `m`.
void PybindOfflineParaformerModelConfig(py::module *m);

}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_
  1 +// sherpa-onnx/python/csrc/offline-recognizer.cc
  2 +//
  3 +// Copyright (c) 2023 by manyeyes
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-recognizer.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-recognizer.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +
  15 +
  16 +static void PybindOfflineRecognizerConfig(py::module *m) {
  17 + using PyClass = OfflineRecognizerConfig;
  18 + py::class_<PyClass>(*m, "OfflineRecognizerConfig")
  19 + .def(py::init<const OfflineFeatureExtractorConfig &,
  20 + const OfflineModelConfig &, const std::string &>(),
  21 + py::arg("feat_config"), py::arg("model_config"),
  22 + py::arg("decoding_method"))
  23 + .def_readwrite("feat_config", &PyClass::feat_config)
  24 + .def_readwrite("model_config", &PyClass::model_config)
  25 + .def_readwrite("decoding_method", &PyClass::decoding_method)
  26 + .def("__str__", &PyClass::ToString);
  27 +}
  28 +
  29 +void PybindOfflineRecognizer(py::module *m) {
  30 + PybindOfflineRecognizerConfig(m);
  31 +
  32 + using PyClass = OfflineRecognizer;
  33 + py::class_<PyClass>(*m, "OfflineRecognizer")
  34 + .def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
  35 + .def("create_stream", &PyClass::CreateStream)
  36 + .def("decode_stream", &PyClass::DecodeStream)
  37 + .def("decode_streams",
  38 + [](PyClass &self, std::vector<OfflineStream *> ss) {
  39 + self.DecodeStreams(ss.data(), ss.size());
  40 + });
  41 +}
  42 +
  43 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-recognizer.h
  2 +//
  3 +// Copyright (c) 2023 by manyeyes
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_RECOGNIZER_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_RECOGNIZER_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
namespace sherpa_onnx {

// Registers the Python bindings for OfflineRecognizerConfig and
// OfflineRecognizer on module `m`.
void PybindOfflineRecognizer(py::module *m);

}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_RECOGNIZER_H_
  1 +// sherpa-onnx/python/csrc/offline-stream.cc
  2 +//
  3 +// Copyright (c) 2023 by manyeyes
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-stream.h"
  6 +
  7 +#include "sherpa-onnx/csrc/offline-stream.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
// Docstring attached to the Python method OfflineStream.accept_waveform.
constexpr const char *kAcceptWaveformUsage = R"(
Process audio samples.

Args:
 sample_rate:
 Sample rate of the input samples. If it is different from the one
 expected by the model, we will do resampling inside.
 waveform:
 A 1-D float32 tensor containing audio samples. It must be normalized
 to the range [-1, 1].
)";
  22 +
  23 +static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
  24 + using PyClass = OfflineRecognitionResult;
  25 + py::class_<PyClass>(*m, "OfflineRecognitionResult")
  26 + .def_property_readonly("text",
  27 + [](const PyClass &self) { return self.text; })
  28 + .def_property_readonly("tokens",
  29 + [](const PyClass &self) { return self.tokens; })
  30 + .def_property_readonly(
  31 + "timestamps", [](const PyClass &self) { return self.timestamps; });
  32 +}
  33 +
  34 +
  35 +static void PybindOfflineFeatureExtractorConfig(py::module *m) {
  36 + using PyClass = OfflineFeatureExtractorConfig;
  37 + py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig")
  38 + .def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000,
  39 + py::arg("feature_dim") = 80)
  40 + .def_readwrite("sampling_rate", &PyClass::sampling_rate)
  41 + .def_readwrite("feature_dim", &PyClass::feature_dim)
  42 + .def("__str__", &PyClass::ToString);
  43 +}
  44 +
  45 +
  46 +void PybindOfflineStream(py::module *m) {
  47 + PybindOfflineFeatureExtractorConfig(m);
  48 + PybindOfflineRecognitionResult(m);
  49 +
  50 + using PyClass = OfflineStream;
  51 + py::class_<PyClass>(*m, "OfflineStream")
  52 + .def(
  53 + "accept_waveform",
  54 + [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
  55 + self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
  56 + },
  57 + py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)
  58 + .def_property_readonly("result", &PyClass::GetResult);
  59 +}
  60 +
  61 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-stream.h
  2 +//
  3 +// Copyright (c) 2023 by manyeyes
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_STREAM_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_STREAM_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
namespace sherpa_onnx {

// Registers the Python bindings for OfflineFeatureExtractorConfig,
// OfflineRecognitionResult and OfflineStream on module `m`.
void PybindOfflineStream(py::module *m);

}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_STREAM_H_
  1 +// sherpa-onnx/python/csrc/offline-transducer-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 by manyeyes
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
  6 +
  7 +
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +void PybindOfflineTransducerModelConfig(py::module *m) {
  16 + using PyClass = OfflineTransducerModelConfig;
  17 + py::class_<PyClass>(*m, "OfflineTransducerModelConfig")
  18 + .def(py::init<const std::string &, const std::string &,
  19 + const std::string &>(),
  20 + py::arg("encoder_filename"), py::arg("decoder_filename"),
  21 + py::arg("joiner_filename"))
  22 + .def_readwrite("encoder_filename", &PyClass::encoder_filename)
  23 + .def_readwrite("decoder_filename", &PyClass::decoder_filename)
  24 + .def_readwrite("joiner_filename", &PyClass::joiner_filename)
  25 + .def("__str__", &PyClass::ToString);
  26 +}
  27 +
  28 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-transducer-model-config.h
  2 +//
  3 +// Copyright (c) 2023 by manyeyes
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
namespace sherpa_onnx {

// Registers the Python binding for OfflineTransducerModelConfig on module `m`.
void PybindOfflineTransducerModelConfig(py::module *m);

}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_
@@ -11,10 +11,17 @@ @@ -11,10 +11,17 @@
11 #include "sherpa-onnx/python/csrc/online-stream.h" 11 #include "sherpa-onnx/python/csrc/online-stream.h"
12 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" 12 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
13 13
  14 +#include "sherpa-onnx/python/csrc/offline-model-config.h"
  15 +#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
  16 +#include "sherpa-onnx/python/csrc/offline-recognizer.h"
  17 +#include "sherpa-onnx/python/csrc/offline-stream.h"
  18 +#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
  19 +
14 namespace sherpa_onnx { 20 namespace sherpa_onnx {
15 21
16 PYBIND11_MODULE(_sherpa_onnx, m) { 22 PYBIND11_MODULE(_sherpa_onnx, m) {
17 m.doc() = "pybind11 binding of sherpa-onnx"; 23 m.doc() = "pybind11 binding of sherpa-onnx";
  24 +
18 PybindFeatures(&m); 25 PybindFeatures(&m);
19 PybindOnlineTransducerModelConfig(&m); 26 PybindOnlineTransducerModelConfig(&m);
20 PybindOnlineStream(&m); 27 PybindOnlineStream(&m);
@@ -22,6 +29,10 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -22,6 +29,10 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
22 PybindOnlineRecognizer(&m); 29 PybindOnlineRecognizer(&m);
23 30
24 PybindDisplay(&m); 31 PybindDisplay(&m);
  32 +
  33 + PybindOfflineStream(&m);
  34 + PybindOfflineModelConfig(&m);
  35 + PybindOfflineRecognizer(&m);
25 } 36 }
26 37
27 } // namespace sherpa_onnx 38 } // namespace sherpa_onnx
1 from _sherpa_onnx import Display 1 from _sherpa_onnx import Display
2 2
3 from .online_recognizer import OnlineRecognizer 3 from .online_recognizer import OnlineRecognizer
  4 +from .offline_recognizer import OfflineRecognizer
  1 +# Copyright (c) 2023 by manyeyes
  2 +from pathlib import Path
  3 +from typing import List
  4 +
  5 +from _sherpa_onnx import (
  6 + OfflineFeatureExtractorConfig,
  7 + OfflineRecognizer as _Recognizer,
  8 + OfflineRecognizerConfig,
  9 + OfflineStream,
  10 + OfflineModelConfig,
  11 + OfflineTransducerModelConfig,
  12 + OfflineParaformerModelConfig,
  13 +)
  14 +
  15 +
def _assert_file_exists(f: str):
    """Assert that ``f`` is the path of an existing regular file.

    Raises:
      AssertionError: If ``f`` is not a file.
    """
    is_regular_file = Path(f).is_file()
    assert is_regular_file, f"{f} does not exist"
  18 +
  19 +
class OfflineRecognizer(object):
    """A class for offline speech recognition.

    Instances are created with one of the alternate constructors
    :meth:`from_transducer` or :meth:`from_paraformer`. They wrap the
    underlying ``_sherpa_onnx.OfflineRecognizer`` in ``self.recognizer``.
    """

    @classmethod
    def _from_model_files(
        cls,
        transducer: OfflineTransducerModelConfig,
        paraformer: OfflineParaformerModelConfig,
        tokens: str,
        num_threads: int,
        sample_rate: int,
        feature_dim: int,
        decoding_method: str,
        debug: bool,
    ) -> "OfflineRecognizer":
        """Build a recognizer from already-constructed model configs.

        Shared by the public constructors so that the config assembly is
        written only once.
        """
        # There is no __init__; the instance is populated here.
        self = cls.__new__(cls)

        model_config = OfflineModelConfig(
            transducer=transducer,
            paraformer=paraformer,
            tokens=tokens,
            num_threads=num_threads,
            debug=debug,
        )

        feat_config = OfflineFeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )

        recognizer_config = OfflineRecognizerConfig(
            feat_config=feat_config,
            model_config=model_config,
            decoding_method=decoding_method,
        )
        self.recognizer = _Recognizer(recognizer_config)
        return self

    @classmethod
    def from_transducer(
        cls,
        encoder: str,
        decoder: str,
        joiner: str,
        tokens: str,
        num_threads: int,
        sample_rate: int = 16000,
        feature_dim: int = 80,
        decoding_method: str = "greedy_search",
        debug: bool = False,
    ):
        """Create a recognizer that uses a transducer model.

        Please refer to
        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
        to download pre-trained models for different languages, e.g., Chinese,
        English, etc.

        Args:
          tokens:
            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
            columns::

                symbol integer_id

          encoder:
            Path to ``encoder.onnx``.
          decoder:
            Path to ``decoder.onnx``.
          joiner:
            Path to ``joiner.onnx``.
          num_threads:
            Number of threads for neural network computation.
          sample_rate:
            Sample rate of the training data used to train the model.
          feature_dim:
            Dimension of the feature used to train the model.
          decoding_method:
            Valid values are greedy_search, modified_beam_search.
          debug:
            True to show debug messages.
        """
        return cls._from_model_files(
            transducer=OfflineTransducerModelConfig(
                encoder_filename=encoder,
                decoder_filename=decoder,
                joiner_filename=joiner,
            ),
            # The paraformer slot is left empty for a transducer model.
            paraformer=OfflineParaformerModelConfig(model=""),
            tokens=tokens,
            num_threads=num_threads,
            sample_rate=sample_rate,
            feature_dim=feature_dim,
            decoding_method=decoding_method,
            debug=debug,
        )

    @classmethod
    def from_paraformer(
        cls,
        paraformer: str,
        tokens: str,
        num_threads: int,
        sample_rate: int = 16000,
        feature_dim: int = 80,
        decoding_method: str = "greedy_search",
        debug: bool = False,
    ):
        """Create a recognizer that uses a paraformer model.

        Please refer to
        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
        to download pre-trained models for different languages, e.g., Chinese,
        English, etc.

        Args:
          tokens:
            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
            columns::

                symbol integer_id

          paraformer:
            Path to ``paraformer.onnx``.
          num_threads:
            Number of threads for neural network computation.
          sample_rate:
            Sample rate of the training data used to train the model.
          feature_dim:
            Dimension of the feature used to train the model.
          decoding_method:
            Valid values are greedy_search, modified_beam_search.
          debug:
            True to show debug messages.
        """
        return cls._from_model_files(
            # The transducer slot is left empty for a paraformer model.
            transducer=OfflineTransducerModelConfig(
                encoder_filename="",
                decoder_filename="",
                joiner_filename="",
            ),
            paraformer=OfflineParaformerModelConfig(model=paraformer),
            tokens=tokens,
            num_threads=num_threads,
            sample_rate=sample_rate,
            feature_dim=feature_dim,
            decoding_method=decoding_method,
            debug=debug,
        )

    def create_stream(self):
        """Return a new stream created by the underlying recognizer."""
        return self.recognizer.create_stream()

    def decode_stream(self, s: OfflineStream):
        """Decode a single stream with the underlying recognizer."""
        self.recognizer.decode_stream(s)

    def decode_streams(self, ss: List[OfflineStream]):
        """Decode a list of streams with the underlying recognizer."""
        self.recognizer.decode_streams(ss)
  167 +