Fangjun Kuang
Committed by GitHub

Add non-streaming websocket server for python (#259)

... ... @@ -23,12 +23,12 @@ permissions:
jobs:
test_pip_install:
runs-on: ${{ matrix.os }}
name: Test pip install on ${{ matrix.os }}
name: ${{ matrix.os }} ${{ matrix.python-version }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
... ... @@ -50,3 +50,15 @@ jobs:
run: |
python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)"
sherpa-onnx --help
sherpa-onnx-offline --help
sherpa-onnx-microphone --help
sherpa-onnx-microphone-offline --help
sherpa-onnx-offline-websocket-server --help
sherpa-onnx-offline-websocket-client --help
sherpa-onnx-online-websocket-server --help
sherpa-onnx-online-websocket-client --help
... ...
name: Python offline websocket server
on:
push:
branches:
- master
pull_request:
branches:
- master
concurrency:
group: python-offline-websocket-server-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
python_offline_websocket_server:
runs-on: ${{ matrix.os }}
name: ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.model_type }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
model_type: ["transducer", "paraformer", "nemo_ctc", "whisper"]
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
shell: bash
run: |
python3 -m pip install --upgrade pip numpy
- name: Install sherpa-onnx
shell: bash
run: |
python3 -m pip install --no-deps --verbose .
python3 -m pip install websockets
- name: Start server for transducer models
if: matrix.model_type == 'transducer'
shell: bash
run: |
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-06-26
cd sherpa-onnx-zipformer-en-2023-06-26
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \
--decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \
--joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \
--tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt &
echo "sleep 10 seconds to wait the server start"
sleep 10
- name: Start client for transducer models
if: matrix.model_type == 'transducer'
shell: bash
run: |
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/0.wav \
./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/1.wav \
./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav
python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/0.wav \
./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/1.wav \
./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav
- name: Start server for paraformer models
if: matrix.model_type == 'paraformer'
shell: bash
run: |
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
cd sherpa-onnx-paraformer-zh-2023-03-28
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
--tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt &
echo "sleep 10 seconds to wait the server start"
sleep 10
- name: Start client for paraformer models
if: matrix.model_type == 'paraformer'
shell: bash
run: |
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
- name: Start server for nemo_ctc models
if: matrix.model_type == 'nemo_ctc'
shell: bash
run: |
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-medium
cd sherpa-onnx-nemo-ctc-en-conformer-medium
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
--tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt &
echo "sleep 10 seconds to wait the server start"
sleep 10
- name: Start client for nemo_ctc models
if: matrix.model_type == 'nemo_ctc'
shell: bash
run: |
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
- name: Start server for whisper models
if: matrix.model_type == 'whisper'
shell: bash
run: |
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en
cd sherpa-onnx-whisper-tiny.en
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
--tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt &
echo "sleep 10 seconds to wait the server start"
sleep 10
- name: Start client for whisper models
if: matrix.model_type == 'whisper'
shell: bash
run: |
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \
./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \
./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav
python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \
./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \
./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav
... ...
name: Python online websocket server
on:
push:
branches:
- master
pull_request:
branches:
- master
concurrency:
group: python-online-websocket-server-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
python_online_websocket_server:
runs-on: ${{ matrix.os }}
name: ${{ matrix.os }} ${{ matrix.python-version }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
model_type: ["transducer"]
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
shell: bash
run: |
python3 -m pip install --upgrade pip numpy
- name: Install sherpa-onnx
shell: bash
run: |
python3 -m pip install --no-deps --verbose .
python3 -m pip install websockets
- name: Start server for transducer models
if: matrix.model_type == 'transducer'
shell: bash
run: |
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26
cd sherpa-onnx-streaming-zipformer-en-2023-06-26
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/streaming_server.py \
--encoder ./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-128.onnx \
--decoder ./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-128.onnx \
--joiner ./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-128.onnx \
--tokens ./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt &
echo "sleep 10 seconds to wait the server start"
sleep 10
- name: Start client for transducer models
if: matrix.model_type == 'transducer'
shell: bash
run: |
python3 ./python-api-examples/online-websocket-client-decode-file.py \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav
... ...
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.7.1")
set(SHERPA_ONNX_VERSION "1.7.2")
# Disable warning about
#
... ...
# Introduction
This folder contains C API examples for [sherpa-onnx][sherpa-onnx].
Please refer to the documentation
https://k2-fsa.github.io/sherpa/onnx/c-api/index.html
for details.
[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
... ...
# Introduction
This folder contains C# API examples for [sherpa-onnx][sherpa-onnx].
Please refer to the documentation
https://k2-fsa.github.io/sherpa/onnx/csharp-api/index.html
for details.
[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
... ...
# Introduction
This folder contains Go API examples for [sherpa-onnx][sherpa-onnx].
Please refer to the documentation
https://k2-fsa.github.io/sherpa/onnx/go-api/index.html
for details.
[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
... ...
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp.
"""
A server for non-streaming speech recognition. Non-streaming means you send all
the content of the audio at once for recognition.
It supports multiple clients sending at the same time.
Usage:
./non_streaming_server.py --help
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html
for pre-trained models to download.
Usage examples:
(1) Use a non-streaming transducer model
cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-06-26
cd sherpa-onnx-zipformer-en-2023-06-26
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \
--decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \
--joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \
--tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt
(2) Use a non-streaming paraformer
cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
cd sherpa-onnx-paraformer-zh-2023-03-28
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
--tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt
(3) Use a non-streaming CTC model from NeMo
cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-medium
cd sherpa-onnx-nemo-ctc-en-conformer-medium
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
--tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt
(4) Use a Whisper model
cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en
cd sherpa-onnx-whisper-tiny.en
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
--tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
----
To use a certificate so that you can use https, please use
python3 ./python-api-examples/non_streaming_server.py \
--whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
--certificate=/path/to/your/cert.pem
If you don't have a certificate, please run:
cd ./python-api-examples/web
./generate-certificate.py
It will generate 3 files, one of which is the required `cert.pem`.
""" # noqa
import argparse
import asyncio
import http
import logging
import socket
import ssl
import sys
import warnings
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import sherpa_onnx
import websockets
from http_server import HttpServer
def setup_logger(
log_filename: str,
log_level: str = "info",
use_console: bool = True,
) -> None:
"""Setup log level.
Args:
log_filename:
The filename to save the log.
log_level:
The log level to use, e.g., "debug", "info", "warning", "error",
"critical"
use_console:
True to also print logs to console.
"""
now = datetime.now()
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
log_filename = f"{log_filename}-{date_time}.txt"
Path(log_filename).parent.mkdir(parents=True, exist_ok=True)
level = logging.ERROR
if log_level == "debug":
level = logging.DEBUG
elif log_level == "info":
level = logging.INFO
elif log_level == "warning":
level = logging.WARNING
elif log_level == "critical":
level = logging.CRITICAL
logging.basicConfig(
filename=log_filename,
format=formatter,
level=level,
filemode="w",
)
if use_console:
console = logging.StreamHandler()
console.setLevel(level)
console.setFormatter(logging.Formatter(formatter))
logging.getLogger("").addHandler(console)
def add_transducer_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--encoder",
default="",
type=str,
help="Path to the transducer encoder model",
)
parser.add_argument(
"--decoder",
default="",
type=str,
help="Path to the transducer decoder model",
)
parser.add_argument(
"--joiner",
default="",
type=str,
help="Path to the transducer joiner model",
)
def add_paraformer_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--paraformer",
default="",
type=str,
help="Path to the model.onnx from Paraformer",
)
def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--nemo-ctc",
default="",
type=str,
help="Path to the model.onnx from NeMo CTC",
)
def add_whisper_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--whisper-encoder",
default="",
type=str,
help="Path to whisper encoder model",
)
parser.add_argument(
"--whisper-decoder",
default="",
type=str,
help="Path to whisper decoder model",
)
def add_model_args(parser: argparse.ArgumentParser):
add_transducer_model_args(parser)
add_paraformer_model_args(parser)
add_nemo_ctc_model_args(parser)
add_whisper_model_args(parser)
parser.add_argument(
"--tokens",
type=str,
help="Path to tokens.txt",
)
parser.add_argument(
"--num-threads",
type=int,
default=2,
help="Number of threads to run the neural network model",
)
parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)
def add_feature_config_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="Sample rate of the data used to train the model. ",
)
parser.add_argument(
"--feat-dim",
type=int,
default=80,
help="Feature dimension of the model",
)
def add_decoding_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Decoding method to use. Current supported methods are:
- greedy_search
- modified_beam_search (for transducer models only)
""",
)
add_modified_beam_search_args(parser)
def add_modified_beam_search_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--max-active-paths",
type=int,
default=4,
help="""Used only when --decoding-method is modified_beam_search.
It specifies number of active paths to keep during decoding.
""",
)
def check_args(args):
if not Path(args.tokens).is_file():
raise ValueError(f"{args.tokens} does not exist")
if args.decoding_method not in (
"greedy_search",
"modified_beam_search",
):
raise ValueError(f"Unsupported decoding method {args.decoding_method}")
if args.decoding_method == "modified_beam_search":
assert args.num_active_paths > 0, args.num_active_paths
assert Path(args.encoder).is_file(), args.encoder
assert Path(args.decoder).is_file(), args.decoder
assert Path(args.joiner).is_file(), args.joiner
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
add_model_args(parser)
add_feature_config_args(parser)
add_decoding_args(parser)
parser.add_argument(
"--port",
type=int,
default=6006,
help="The server will listen on this port",
)
parser.add_argument(
"--max-batch-size",
type=int,
default=25,
help="""Max batch size for computation. Note if there are not enough
requests in the queue, it will wait for max_wait_ms time. After that,
even if there are not enough requests, it still sends the
available requests in the queue for computation.
""",
)
parser.add_argument(
"--max-wait-ms",
type=float,
default=5,
help="""Max time in millisecond to wait to build batches for inference.
If there are not enough requests in the feature queue to build a batch
of max_batch_size, it waits up to this time before fetching available
requests for computation.
""",
)
parser.add_argument(
"--nn-pool-size",
type=int,
default=1,
help="Number of threads for NN computation and decoding.",
)
parser.add_argument(
"--max-message-size",
type=int,
default=(1 << 20),
help="""Max message size in bytes.
The max size per message cannot exceed this limit.
""",
)
parser.add_argument(
"--max-queue-size",
type=int,
default=32,
help="Max number of messages in the queue for each connection.",
)
parser.add_argument(
"--max-active-connections",
type=int,
default=500,
help="""Maximum number of active connections. The server will refuse
to accept new connections once the current number of active connections
equals to this limit.
""",
)
parser.add_argument(
"--certificate",
type=str,
help="""Path to the X.509 certificate. You need it only if you want to
use a secure websocket connection, i.e., use wss:// instead of ws://.
You can use ./web/generate-certificate.py
to generate the certificate `cert.pem`.
Note ./web/generate-certificate.py will generate three files but you
only need to pass the generated cert.pem to this option.
""",
)
parser.add_argument(
"--doc-root",
type=str,
default="./python-api-examples/web",
help="Path to the web root",
)
return parser.parse_args()
class NonStreamingServer:
def __init__(
self,
recognizer: sherpa_onnx.OfflineRecognizer,
max_batch_size: int,
max_wait_ms: float,
nn_pool_size: int,
max_message_size: int,
max_queue_size: int,
max_active_connections: int,
doc_root: str,
certificate: Optional[str] = None,
):
"""
Args:
recognizer:
An instance of the sherpa_onnx.OfflineRecognizer.
max_batch_size:
Max batch size for inference.
max_wait_ms:
Max wait time in milliseconds in order to build a batch of
`max_batch_size`.
nn_pool_size:
Number of threads for the thread pool that is used for NN
computation and decoding.
max_message_size:
Max size in bytes per message.
max_queue_size:
Max number of messages in the queue for each connection.
max_active_connections:
Max number of active connections. Once number of active client
equals to this limit, the server refuses to accept new connections.
doc_root:
Path to the directory where files like index.html for the HTTP
server locate.
certificate:
Optional. If not None, it will use secure websocket.
You can use ./web/generate-certificate.py to generate
it (the default generated filename is `cert.pem`).
"""
self.recognizer = recognizer
self.certificate = certificate
self.http_server = HttpServer(doc_root)
self.nn_pool = ThreadPoolExecutor(
max_workers=nn_pool_size,
thread_name_prefix="nn",
)
self.stream_queue = asyncio.Queue()
self.max_wait_ms = max_wait_ms
self.max_batch_size = max_batch_size
self.max_message_size = max_message_size
self.max_queue_size = max_queue_size
self.max_active_connections = max_active_connections
self.current_active_connections = 0
self.sample_rate = int(recognizer.config.feat_config.sampling_rate)
async def process_request(
self,
path: str,
request_headers: websockets.Headers,
) -> Optional[Tuple[http.HTTPStatus, websockets.Headers, bytes]]:
if "sec-websocket-key" not in request_headers:
# This is a normal HTTP request
if path == "/":
path = "/index.html"
if path[-1] == "?":
path = path[:-1]
if path == "/streaming_record.html":
response = r"""
<!doctype html><html><head>
<title>Speech recognition with next-gen Kaldi</title><body>
<h2>Only
<a href="/upload.html">/upload.html</a>
and
<a href="/offline_record.html">/offline_record.html</a>
is available for the non-streaming server.<h2>
<br/>
<br/>
Go back to <a href="/upload.html">/upload.html</a>
or <a href="/offline_record.html">/offline_record.html</a>
</body></head></html>
"""
found = True
mime_type = "text/html"
else:
found, response, mime_type = self.http_server.process_request(path)
if isinstance(response, str):
response = response.encode("utf-8")
if not found:
status = http.HTTPStatus.NOT_FOUND
else:
status = http.HTTPStatus.OK
header = {"Content-Type": mime_type}
return status, header, response
if self.current_active_connections < self.max_active_connections:
self.current_active_connections += 1
return None
# Refuse new connections
status = http.HTTPStatus.SERVICE_UNAVAILABLE # 503
header = {"Hint": "The server is overloaded. Please retry later."}
response = b"The server is busy. Please retry later."
return status, header, response
async def run(self, port: int):
logging.info("started")
task = asyncio.create_task(self.stream_consumer_task())
if self.certificate:
logging.info(f"Using certificate: {self.certificate}")
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(self.certificate)
else:
ssl_context = None
logging.info("No certificate provided")
async with websockets.serve(
self.handle_connection,
host="",
port=port,
max_size=self.max_message_size,
max_queue=self.max_queue_size,
process_request=self.process_request,
ssl=ssl_context,
):
ip_list = ["localhost"]
if ssl_context:
ip_list += ["0.0.0.0", "127.0.0.1"]
ip_list.append(socket.gethostbyname(socket.gethostname()))
proto = "http://" if ssl_context is None else "https://"
s = "Please visit one of the following addresses:\n\n"
for p in ip_list:
s += " " + proto + p + f":{port}" "\n"
logging.info(s)
await asyncio.Future() # run forever
await task # not reachable
async def recv_audio_samples(
self,
socket: websockets.WebSocketServerProtocol,
) -> Tuple[Optional[np.ndarray], Optional[float]]:
"""Receive a tensor from the client.
The message from the client is a **bytes** buffer.
The first message can be either "Done" meaning the client won't send
anything in the future or it can be a buffer containing 8 bytes.
The first 4 bytes in little endian specifies the sample
rate of the audio samples; the second 4 bytes in little endian specifies
the number of bytes in the audio file, which will be sent by the client
in the subsequent messages.
Since there is a limit in the message size posed by the websocket
protocol, the client may send the audio file in multiple messages if the
audio file is very large.
The second and remaining messages contain audio samples.
Please refer to ./offline-websocket-client-decode-files-paralell.py
and ./offline-websocket-client-decode-files-sequential.py
for how the client sends the message.
Args:
socket:
The socket for communicating with the client.
Returns:
Return a containing:
- 1-D np.float32 array containing the audio samples
- sample rate of the audio samples
or return (None, None) indicating the end of utterance.
"""
header = await socket.recv()
if header == "Done":
return None, None
assert len(header) >= 8, (
"The first message should contain at least 8 bytes."
+ f"Given {len(header)}"
)
sample_rate = int.from_bytes(header[:4], "little", signed=True)
expected_num_bytes = int.from_bytes(header[4:8], "little", signed=True)
received = []
num_received_bytes = 0
if len(header) > 8:
received.append(header[8:])
num_received_bytes += len(header) - 8
if num_received_bytes < expected_num_bytes:
async for message in socket:
received.append(message)
num_received_bytes += len(message)
if num_received_bytes >= expected_num_bytes:
break
assert num_received_bytes == expected_num_bytes, (
num_received_bytes,
expected_num_bytes,
)
samples = b"".join(received)
array = np.frombuffer(samples, dtype=np.float32)
return array, sample_rate
async def stream_consumer_task(self):
"""This function extracts streams from the queue, batches them up, sends
them to the RNN-T model for computation and decoding.
"""
while True:
if self.stream_queue.empty():
await asyncio.sleep(self.max_wait_ms / 1000)
continue
batch = []
try:
while len(batch) < self.max_batch_size:
item = self.stream_queue.get_nowait()
batch.append(item)
except asyncio.QueueEmpty:
pass
stream_list = [b[0] for b in batch]
future_list = [b[1] for b in batch]
loop = asyncio.get_running_loop()
await loop.run_in_executor(
self.nn_pool,
self.recognizer.decode_streams,
stream_list,
)
for f in future_list:
self.stream_queue.task_done()
f.set_result(None)
async def compute_and_decode(
self,
stream: sherpa_onnx.OfflineStream,
) -> None:
"""Put the stream into the queue and wait it to be processed by the
consumer task.
Args:
stream:
The stream to be processed. Note: It is changed in-place.
"""
loop = asyncio.get_running_loop()
future = loop.create_future()
await self.stream_queue.put((stream, future))
await future
async def handle_connection(
self,
socket: websockets.WebSocketServerProtocol,
):
"""Receive audio samples from the client, process it, and sends
deocoding result back to the client.
Args:
socket:
The socket for communicating with the client.
"""
try:
await self.handle_connection_impl(socket)
except websockets.exceptions.ConnectionClosedError:
logging.info(f"{socket.remote_address} disconnected")
finally:
# Decrement so that it can accept new connections
self.current_active_connections -= 1
logging.info(
f"Disconnected: {socket.remote_address}. "
f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa
)
async def handle_connection_impl(
self,
socket: websockets.WebSocketServerProtocol,
):
"""Receive audio samples from the client, process it, and send
decoding results back to the client.
Args:
socket:
The socket for communicating with the client.
"""
logging.info(
f"Connected: {socket.remote_address}. "
f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa
)
while True:
stream = self.recognizer.create_stream()
samples, sample_rate = await self.recv_audio_samples(socket)
if samples is None:
break
# stream.accept_samples() runs in the main thread
stream.accept_waveform(sample_rate, samples)
await self.compute_and_decode(stream)
result = stream.result.text
logging.info(f"result: {result}")
if result:
await socket.send(result)
else:
# If result is an empty string, send something to the client.
# Otherwise, socket.send() is a no-op and the client will
# wait for a reply indefinitely.
await socket.send("<EMPTY>")
def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)
def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
max_active_paths=args.max_active_paths,
)
elif args.paraformer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert_file_exists(args.paraformer)
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=args.paraformer,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
)
elif args.nemo_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert_file_exists(args.nemo_ctc)
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
model=args.nemo_ctc,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
)
elif args.whisper_encoder:
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=args.whisper_encoder,
decoder=args.whisper_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
)
else:
raise ValueError("Please specify at least one model")
return recognizer
def main():
args = get_args()
logging.info(vars(args))
check_args(args)
recognizer = create_recognizer(args)
port = args.port
max_wait_ms = args.max_wait_ms
max_batch_size = args.max_batch_size
nn_pool_size = args.nn_pool_size
max_message_size = args.max_message_size
max_queue_size = args.max_queue_size
max_active_connections = args.max_active_connections
certificate = args.certificate
doc_root = args.doc_root
if certificate and not Path(certificate).is_file():
raise ValueError(f"{certificate} does not exist")
if not Path(doc_root).is_dir():
raise ValueError(f"Directory {doc_root} does not exist")
non_streaming_server = NonStreamingServer(
recognizer=recognizer,
max_wait_ms=max_wait_ms,
max_batch_size=max_batch_size,
nn_pool_size=nn_pool_size,
max_message_size=max_message_size,
max_queue_size=max_queue_size,
max_active_connections=max_active_connections,
certificate=certificate,
doc_root=doc_root,
)
asyncio.run(non_streaming_server.run(port))
if __name__ == "__main__":
log_filename = "log/log-non-streaming-server"
setup_logger(log_filename)
main()
... ...
... ... @@ -119,6 +119,12 @@ async def run(
buf += (samples.size * 4).to_bytes(4, byteorder="little")
buf += samples.tobytes()
payload_len = 10240
while len(buf) > payload_len:
await websocket.send(buf[:payload_len])
buf = buf[payload_len:]
if buf:
await websocket.send(buf)
decoding_results = await websocket.recv()
... ...
... ... @@ -116,10 +116,17 @@ async def run(
assert isinstance(sample_rate, int)
assert samples.dtype == np.float32, samples.dtype
assert samples.ndim == 1, samples.dim
buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes
buf += (samples.size * 4).to_bytes(4, byteorder="little")
buf += samples.tobytes()
payload_len = 10240
while len(buf) > payload_len:
await websocket.send(buf[:payload_len])
buf = buf[payload_len:]
if buf:
await websocket.send(buf)
decoding_results = await websocket.recv()
... ...
... ... @@ -15,10 +15,9 @@ Usage:
(Note: You have to first start the server before starting the client)
You can find the server at
You can find the c++ server at
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
Note: The server is implemented in C++.
or use the python server ./python-api-examples/streaming_server.py
There is also a C++ version of the client. Please see
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
... ... @@ -115,6 +114,7 @@ async def receive_results(socket: websockets.WebSocketServerProtocol):
last_message = message
logging.info(message)
else:
break
return last_message
... ... @@ -142,6 +142,7 @@ async def run(
await websocket.send(d)
# Simulate streaming. You can remove the sleep if you want
await asyncio.sleep(seconds_per_message) # in seconds
start += samples_per_message
... ...
... ... @@ -12,10 +12,9 @@ Usage:
(Note: You have to first start the server before starting the client)
You can find the server at
You can find the C++ server at
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
Note: The server is implemented in C++.
or use the python server ./python-api-examples/streaming_server.py
There is also a C++ version of the client. Please see
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
... ...
... ... @@ -13,11 +13,37 @@ Usage:
Example:
(1) Without a certificate
python3 ./python-api-examples/streaming_server.py \
--encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
--decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
--joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
(2) With a certificate
(a) Generate a certificate first:
cd python-api-examples/web
./generate-certificate.py
cd ../..
(b) Start the server
python3 ./python-api-examples/streaming_server.py \
--encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
--decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
--joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
--certificate ./python-api-examples/web/cert.pem
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
to download pre-trained models.
The model in the above help messages is from
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
"""
import argparse
... ... @@ -35,6 +61,7 @@ from typing import List, Optional, Tuple
import numpy as np
import sherpa_onnx
import websockets
from http_server import HttpServer
... ... @@ -269,8 +296,8 @@ def get_args():
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Sets the number of threads used for interop parallelism (e.g. in JIT interpreter) on CPU.",
default=2,
help="Number of threads to run the neural network model",
)
parser.add_argument(
... ... @@ -278,8 +305,10 @@ def get_args():
type=str,
help="""Path to the X.509 certificate. You need it only if you want to
use a secure websocket connection, i.e., use wss:// instead of ws://.
You can use sherpa/bin/web/generate-certificate.py
You can use ./web/generate-certificate.py
to generate the certificate `cert.pem`.
Note ./web/generate-certificate.py will generate three files but you
only need to pass the generated cert.pem to this option.
""",
)
... ... @@ -287,7 +316,7 @@ def get_args():
"--doc-root",
type=str,
default="./python-api-examples/web",
help="""Path to the web root""",
help="Path to the web root",
)
return parser.parse_args()
... ... @@ -299,9 +328,9 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
encoder=args.encoder_model,
decoder=args.decoder_model,
joiner=args.joiner_model,
num_threads=1,
sample_rate=16000,
feature_dim=80,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
max_active_paths=args.num_active_paths,
enable_endpoint_detection=args.use_endpoint != 0,
... ... @@ -359,7 +388,7 @@ class StreamingServer(object):
server locate.
certificate:
Optional. If not None, it will use secure websocket.
You can use ./sherpa/bin/web/generate-certificate.py to generate
You can use ./web/generate-certificate.py to generate
it (the default generated filename is `cert.pem`).
"""
self.recognizer = recognizer
... ... @@ -373,6 +402,7 @@ class StreamingServer(object):
)
self.stream_queue = asyncio.Queue()
self.max_wait_ms = max_wait_ms
self.max_batch_size = max_batch_size
self.max_message_size = max_message_size
... ... @@ -382,11 +412,10 @@ class StreamingServer(object):
self.current_active_connections = 0
self.sample_rate = int(recognizer.config.feat_config.sampling_rate)
self.decoding_method = recognizer.config.decoding_method
async def stream_consumer_task(self):
"""This function extracts streams from the queue, batches them up, sends
them to the RNN-T model for computation and decoding.
them to the neural network model for computation and decoding.
"""
while True:
if self.stream_queue.empty():
... ... @@ -442,7 +471,22 @@ class StreamingServer(object):
# This is a normal HTTP request
if path == "/":
path = "/index.html"
if path in ("/upload.html", "/offline_record.html"):
response = r"""
<!doctype html><html><head>
<title>Speech recognition with next-gen Kaldi</title><body>
<h2>Only /streaming_record.html is available for the streaming server.<h2>
<br/>
<br/>
Go back to <a href="/streaming_record.html">/streaming_record.html</a>
</body></head></html>
"""
found = True
mime_type = "text/html"
else:
found, response, mime_type = self.http_server.process_request(path)
if isinstance(response, str):
response = response.encode("utf-8")
... ... @@ -484,12 +528,21 @@ class StreamingServer(object):
process_request=self.process_request,
ssl=ssl_context,
):
ip_list = ["0.0.0.0", "localhost", "127.0.0.1"]
ip_list = ["localhost"]
if ssl_context:
ip_list += ["0.0.0.0", "127.0.0.1"]
ip_list.append(socket.gethostbyname(socket.gethostname()))
proto = "http://" if ssl_context is None else "https://"
s = "Please visit one of the following addresses:\n\n"
for p in ip_list:
s += " " + proto + p + f":{port}" "\n"
if not ssl_context:
s += "\nSince you are not providing a certificate, you cannot "
s += "use your microphone from within the browser using "
s += "public IP addresses. Only localhost can be used."
s += "You also cannot use 0.0.0.0 or 127.0.0.1"
logging.info(s)
await asyncio.Future() # run forever
... ... @@ -525,7 +578,7 @@ class StreamingServer(object):
socket: websockets.WebSocketServerProtocol,
):
"""Receive audio samples from the client, process it, and send
deocoding result back to the client.
decoding result back to the client.
Args:
socket:
... ... @@ -560,8 +613,6 @@ class StreamingServer(object):
self.recognizer.reset(stream)
segment += 1
print(message)
await socket.send(json.dumps(message))
tail_padding = np.zeros(int(self.sample_rate * 0.3)).astype(np.float32)
... ... @@ -583,7 +634,7 @@ class StreamingServer(object):
self,
socket: websockets.WebSocketServerProtocol,
) -> Optional[np.ndarray]:
"""Receives a tensor from the client.
"""Receive a tensor from the client.
Each message contains either a bytes buffer containing audio samples
in 16 kHz or contains "Done" meaning the end of utterance.
... ... @@ -660,6 +711,6 @@ def main():
if __name__ == "__main__":
log_filename = "log/log-streaming-zipformer"
log_filename = "log/log-streaming-server"
setup_logger(log_filename)
main()
... ...
# How to use
```bash
git clone https://github.com/k2-fsa/sherpa
cd sherpa/sherpa/bin/web
python3 -m http.server 6009
```
and then go to <http://localhost:6009>
You will see a page like the following screenshot:
![Screenshot if you visit http://localhost:6009](./pic/web-ui.png)
If your server is listening at the port *6006* with address **localhost**,
then you can either click **Upload**, **Streaming_Record** or **Offline_Record** to play with it.
## File descriptions
### ./css/bootstrap.min.css
It is downloaded from https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/css/bootstrap.min.css
### ./js/jquery-3.6.0.min.js
It is downloaded from https://code.jquery.com/jquery-3.6.0.min.js
### ./js/popper.min.js
It is downloaded from https://cdn.jsdelivr.net/npm/popper.js@1.14.7/dist/umd/popper.min.js
### ./js/bootstrap.min.js
It is download from https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/js/bootstrap.min.js
... ... @@ -35,8 +35,8 @@ Otherwise, you may get the following error from within you browser:
def cert_gen(
emailAddress="https://github.com/k2-fsa/k2",
commonName="sherpa",
emailAddress="https://github.com/k2-fsa/sherpa-onnx",
commonName="sherpa-onnx",
countryName="CN",
localityName="k2-fsa",
stateOrProvinceName="k2-fsa",
... ... @@ -70,17 +70,13 @@ def cert_gen(
cert.set_pubkey(k)
cert.sign(k, "sha512")
with open(CERT_FILE, "wt") as f:
f.write(
crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")
)
f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
with open(KEY_FILE, "wt") as f:
f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8"))
with open(ALL_IN_ONE_FILE, "wt") as f:
f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8"))
f.write(
crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")
)
f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
print(f"Generated {CERT_FILE}")
print(f"Generated {KEY_FILE}")
print(f"Generated {ALL_IN_ONE_FILE}")
... ...
... ... @@ -53,7 +53,7 @@
</ul>
Code is available at
<a href="https://github.com/k2-fsa/sherpa"> https://github.com/k2-fsa/sherpa</a>
<a href="https://github.com/k2-fsa/sherpa-onnx"> https://github.com/k2-fsa/sherpa-onnx</a>
<!-- Optional JavaScript -->
<!-- jQuery first, then Popper.js, then Bootstrap JS -->
... ...
... ... @@ -60,6 +60,7 @@ const soundClips = document.getElementById('sound-clips');
const canvas = document.getElementById('canvas');
const mainSection = document.querySelector('.container');
recordBtn.disabled = true;
stopBtn.disabled = true;
window.onload = (event) => {
... ... @@ -95,9 +96,10 @@ clearBtn.onclick = function() {
};
function send_header(n) {
const header = new ArrayBuffer(4);
new DataView(header).setInt32(0, n, true /* littleEndian */);
socket.send(new Int32Array(header, 0, 1));
const header = new ArrayBuffer(8);
new DataView(header).setInt32(0, expectedSampleRate, true /* littleEndian */);
new DataView(header).setInt32(4, n, true /* littleEndian */);
socket.send(new Int32Array(header, 0, 2));
}
// copied/modified from https://mdn.github.io/web-dictaphone/
... ...
... ... @@ -88,6 +88,7 @@ const canvas = document.getElementById('canvas');
const mainSection = document.querySelector('.container');
stopBtn.disabled = true;
recordBtn.disabled = true;
let audioCtx;
const canvasCtx = canvas.getContext('2d');
... ...
... ... @@ -74,9 +74,11 @@ connectBtn.onclick = function() {
};
function send_header(n) {
const header = new ArrayBuffer(4);
new DataView(header).setInt32(0, n, true /* littleEndian */);
socket.send(new Int32Array(header, 0, 1));
const header = new ArrayBuffer(8);
// assume the uploaded wave is 16000 Hz
new DataView(header).setInt32(0, 16000, true /* littleEndian */);
new DataView(header).setInt32(4, n, true /* littleEndian */);
socket.send(new Int32Array(header, 0, 2));
}
function onFileChange() {
... ...
... ... @@ -33,9 +33,9 @@
<button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button>
</div>
<span class="input-group-text" id="ws-protocol">ws://</span>
<input type="text" id="server-ip" class="form-control" placeholder="Sherpa server IP, e.g., localhost" aria-label="sherpa server IP">
<input type="text" id="server-ip" class="form-control" placeholder="Sherpa-onnx server IP, e.g., localhost" aria-label="sherpa-onnx server IP">
<span class="input-group-text">:</span>
<input type="text" id="server-port" class="form-control" placeholder="Sherpa server port, e.g., 6006" aria-label="sherpa server port">
<input type="text" id="server-port" class="form-control" placeholder="Sherpa-onnx server port, e.g., 6006" aria-label="sherpa-onnx server port">
</div>
<div class="row">
... ...
... ... @@ -33,9 +33,9 @@
<button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button>
</div>
<span class="input-group-text" id="ws-protocol">ws://</span>
<input type="text" id="server-ip" class="form-control" placeholder="Sherpa server IP, e.g., localhost" aria-label="sherpa server IP">
<input type="text" id="server-ip" class="form-control" placeholder="Sherpa-onnx server IP, e.g., localhost" aria-label="sherpa-onnx server IP">
<span class="input-group-text">:</span>
<input type="text" id="server-port" class="form-control" placeholder="Sherpa server port, e.g., 6006" aria-label="sherpa server port">
<input type="text" id="server-port" class="form-control" placeholder="Sherpa-onnx server port, e.g., 6006" aria-label="sherpa-onnx server port">
</div>
<div class="row">
... ...
... ... @@ -32,9 +32,9 @@
<button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button>
</div>
<span class="input-group-text" id="ws-protocol">ws://</span>
<input type="text" id="server-ip" class="form-control" placeholder="Sherpa server IP, e.g., localhost" aria-label="sherpa server IP">
<input type="text" id="server-ip" class="form-control" placeholder="Sherpa-onnx server IP, e.g., localhost" aria-label="sherpa-onnx server IP">
<span class="input-group-text">:</span>
<input type="text" id="server-port" class="form-control" placeholder="Sherpa server port, e.g., 6006" aria-label="sherpa server port">
<input type="text" id="server-port" class="form-control" placeholder="Sherpa-onnx server port, e.g., 6006" aria-label="sherpa-onnx server port">
</div>
<form>
... ...
from typing import Dict, List, Optional
from _sherpa_onnx import Display
from _sherpa_onnx import Display, OfflineStream, OnlineStream
from .online_recognizer import OnlineRecognizer
from .online_recognizer import OnlineStream
from .offline_recognizer import OfflineRecognizer
from .online_recognizer import OnlineRecognizer
from .utils import encode_contexts
... ...
... ... @@ -41,6 +41,7 @@ class OfflineRecognizer(object):
sample_rate: int = 16000,
feature_dim: int = 80,
decoding_method: str = "greedy_search",
max_active_paths: int = 4,
context_score: float = 1.5,
debug: bool = False,
provider: str = "cpu",
... ... @@ -72,6 +73,9 @@ class OfflineRecognizer(object):
Dimension of the feature used to train the model.
decoding_method:
Valid values: greedy_search, modified_beam_search.
max_active_paths:
Maximum number of active paths to keep. Used only when
decoding_method is modified_beam_search.
debug:
True to show debug messages.
provider:
... ... @@ -103,6 +107,7 @@ class OfflineRecognizer(object):
context_score=context_score,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
@classmethod
... ... @@ -166,6 +171,7 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
@classmethod
... ... @@ -229,6 +235,7 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
@classmethod
... ... @@ -291,6 +298,7 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
... ...