# online_recognizer.py
from pathlib import Path
from typing import List

from _sherpa_onnx import (
    EndpointConfig,
    FeatureExtractorConfig,
    OnlineRecognizer as _Recognizer,
    OnlineRecognizerConfig,
    OnlineStream,
    OnlineTransducerModelConfig,
)


def _assert_file_exists(f: str):
    """Assert that ``f`` is the path of an existing regular file.

    Args:
      f:
        Filesystem path to check.

    Raises:
      AssertionError:
        If ``Path(f).is_file()`` is false. The message has the form
        ``"<f> does not exist"``.
    """
    # NOTE: this is a plain ``assert``, so the check is skipped entirely when
    # Python runs with ``-O``.
    is_regular_file = Path(f).is_file()
    assert is_regular_file, f"{f} does not exist"


class OnlineRecognizer(object):
    """A class for streaming speech recognition.

    This is a thin Python wrapper: the constructor assembles the
    configuration objects exposed by ``_sherpa_onnx`` and every other method
    forwards to the wrapped ``_sherpa_onnx.OnlineRecognizer`` instance stored
    in ``self.recognizer``.
    """

    def __init__(
        self,
        tokens: str,
        encoder: str,
        decoder: str,
        joiner: str,
        num_threads: int = 4,
        sample_rate: float = 16000,
        feature_dim: int = 80,
        enable_endpoint_detection: bool = False,
        # The three endpoint thresholds are durations in seconds, so they are
        # annotated as ``float`` (which also accepts ``int`` arguments).
        rule1_min_trailing_silence: float = 2.4,
        rule2_min_trailing_silence: float = 1.2,
        rule3_min_utterance_length: float = 20,
    ):
        """
        Please refer to
        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
        to download pre-trained models for different languages, e.g., Chinese,
        English, etc.

        Args:
          tokens:
            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
            columns::

                symbol integer_id

          encoder:
            Path to ``encoder.onnx``.
          decoder:
            Path to ``decoder.onnx``.
          joiner:
            Path to ``joiner.onnx``.
          num_threads:
            Number of threads for neural network computation. Must be
            positive.
          sample_rate:
            Sample rate of the training data used to train the model.
          feature_dim:
            Dimension of the feature used to train the model.
          enable_endpoint_detection:
            True to enable endpoint detection. False to disable endpoint
            detection.
          rule1_min_trailing_silence:
            Used only when enable_endpoint_detection is True. If the duration
            of trailing silence in seconds is larger than this value, we assume
            an endpoint is detected.
          rule2_min_trailing_silence:
            Used only when enable_endpoint_detection is True. If we have decoded
            something that is nonsilence and if the duration of trailing silence
            in seconds is larger than this value, we assume an endpoint is
            detected.
          rule3_min_utterance_length:
            Used only when enable_endpoint_detection is True. If the utterance
            length in seconds is larger than this value, we assume an endpoint
            is detected.

        Raises:
          AssertionError:
            If any of ``tokens``, ``encoder``, ``decoder`` or ``joiner`` is
            not an existing file, or if ``num_threads`` is not positive.
        """
        # Validate every input path before handing anything to the extension.
        _assert_file_exists(tokens)
        _assert_file_exists(encoder)
        _assert_file_exists(decoder)
        _assert_file_exists(joiner)

        assert num_threads > 0, num_threads

        model_config = OnlineTransducerModelConfig(
            encoder_filename=encoder,
            decoder_filename=decoder,
            joiner_filename=joiner,
            num_threads=num_threads,
        )

        feat_config = FeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )

        # The endpoint rules are always passed along; whether they are used is
        # controlled by ``enable_endpoint`` below.
        endpoint_config = EndpointConfig(
            rule1_min_trailing_silence=rule1_min_trailing_silence,
            rule2_min_trailing_silence=rule2_min_trailing_silence,
            rule3_min_utterance_length=rule3_min_utterance_length,
        )

        recognizer_config = OnlineRecognizerConfig(
            feat_config=feat_config,
            model_config=model_config,
            tokens=tokens,
            endpoint_config=endpoint_config,
            enable_endpoint=enable_endpoint_detection,
        )

        self.recognizer = _Recognizer(recognizer_config)

    def create_stream(self):
        """Return a new stream created by the wrapped recognizer.

        The return value is whatever ``self.recognizer.create_stream()``
        returns (presumably an ``OnlineStream`` -- TODO confirm against the
        extension module).
        """
        return self.recognizer.create_stream()

    def decode_stream(self, s: OnlineStream) -> None:
        """Forward ``s`` to the wrapped recognizer's ``decode_stream``.

        Args:
          s:
            The stream to decode.
        """
        self.recognizer.decode_stream(s)

    def decode_streams(self, ss: List[OnlineStream]) -> None:
        """Forward ``ss`` to the wrapped recognizer's ``decode_streams``.

        Args:
          ss:
            The streams to decode.
        """
        self.recognizer.decode_streams(ss)

    def is_ready(self, s: OnlineStream) -> bool:
        """Return the wrapped recognizer's ``is_ready`` answer for ``s``.

        Args:
          s:
            The stream to query.
        """
        return self.recognizer.is_ready(s)

    def get_result(self, s: OnlineStream) -> str:
        """Return the recognized text for ``s``.

        The ``text`` attribute of the wrapped recognizer's result is returned
        with leading and trailing whitespace removed.

        Args:
          s:
            The stream whose result is requested.
        """
        return self.recognizer.get_result(s).text.strip()

    def is_endpoint(self, s: OnlineStream) -> bool:
        """Return the wrapped recognizer's ``is_endpoint`` answer for ``s``.

        Args:
          s:
            The stream to query.
        """
        return self.recognizer.is_endpoint(s)

    def reset(self, s: OnlineStream) -> bool:
        """Forward ``s`` to the wrapped recognizer's ``reset``.

        Returns whatever ``self.recognizer.reset(s)`` returns.
        NOTE(review): annotated as ``bool``; confirm that the extension's
        ``reset`` actually returns a value.

        Args:
          s:
            The stream to reset.
        """
        return self.recognizer.reset(s)