Fangjun Kuang
Committed by GitHub

Add non-streaming websocket server for python (#259)

@@ -23,12 +23,12 @@ permissions:
 jobs:
   test_pip_install:
     runs-on: ${{ matrix.os }}
-    name: Test pip install on ${{ matrix.os }}
+    name: ${{ matrix.os }} ${{ matrix.python-version }}
     strategy:
       fail-fast: false
       matrix:
         os: [ubuntu-latest, windows-latest, macos-latest]
-        python-version: ["3.8", "3.9", "3.10"]
+        python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]

     steps:
       - uses: actions/checkout@v2
@@ -50,3 +50,15 @@ jobs:
         run: |
           python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
           python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)"
+
+          sherpa-onnx --help
+          sherpa-onnx-offline --help
+
+          sherpa-onnx-microphone --help
+          sherpa-onnx-microphone-offline --help
+
+          sherpa-onnx-offline-websocket-server --help
+          sherpa-onnx-offline-websocket-client --help
+
+          sherpa-onnx-online-websocket-server --help
+          sherpa-onnx-online-websocket-client --help
  1 +name: Python offline websocket server
  2 +
  3 +on:
  4 + push:
  5 + branches:
  6 + - master
  7 + pull_request:
  8 + branches:
  9 + - master
  10 +
  11 +concurrency:
  12 + group: python-offline-websocket-server-${{ github.ref }}
  13 + cancel-in-progress: true
  14 +
  15 +permissions:
  16 + contents: read
  17 +
  18 +jobs:
  19 + python_offline_websocket_server:
  20 + runs-on: ${{ matrix.os }}
  21 + name: ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.model_type }}
  22 + strategy:
  23 + fail-fast: false
  24 + matrix:
  25 + os: [ubuntu-latest, windows-latest, macos-latest]
  26 + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
  27 + model_type: ["transducer", "paraformer", "nemo_ctc", "whisper"]
  28 +
  29 + steps:
  30 + - uses: actions/checkout@v2
  31 + with:
  32 + fetch-depth: 0
  33 +
  34 + - name: Setup Python ${{ matrix.python-version }}
  35 + uses: actions/setup-python@v2
  36 + with:
  37 + python-version: ${{ matrix.python-version }}
  38 +
  39 + - name: Install Python dependencies
  40 + shell: bash
  41 + run: |
  42 + python3 -m pip install --upgrade pip numpy
  43 +
  44 + - name: Install sherpa-onnx
  45 + shell: bash
  46 + run: |
  47 + python3 -m pip install --no-deps --verbose .
  48 + python3 -m pip install websockets
  49 +
  50 +
  51 + - name: Start server for transducer models
  52 + if: matrix.model_type == 'transducer'
  53 + shell: bash
  54 + run: |
  55 + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-06-26
  56 + cd sherpa-onnx-zipformer-en-2023-06-26
  57 + git lfs pull --include "*.onnx"
  58 + cd ..
  59 +
  60 + python3 ./python-api-examples/non_streaming_server.py \
  61 + --encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \
  62 + --decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \
  63 + --joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \
  64 + --tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt &
  65 +
  66 + echo "sleep 10 seconds to wait for the server to start"
  67 + sleep 10
  68 +
  69 + - name: Start client for transducer models
  70 + if: matrix.model_type == 'transducer'
  71 + shell: bash
  72 + run: |
  73 + python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
  74 + ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/0.wav \
  75 + ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/1.wav \
  76 + ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav
  77 +
  78 + python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
  79 + ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/0.wav \
  80 + ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/1.wav \
  81 + ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav
  82 +
  83 + - name: Start server for paraformer models
  84 + if: matrix.model_type == 'paraformer'
  85 + shell: bash
  86 + run: |
  87 + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
  88 + cd sherpa-onnx-paraformer-zh-2023-03-28
  89 + git lfs pull --include "*.onnx"
  90 + cd ..
  91 +
  92 + python3 ./python-api-examples/non_streaming_server.py \
  93 + --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
  94 + --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt &
  95 +
  96 + echo "sleep 10 seconds to wait for the server to start"
  97 + sleep 10
  98 +
  99 + - name: Start client for paraformer models
  100 + if: matrix.model_type == 'paraformer'
  101 + shell: bash
  102 + run: |
  103 + python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
  104 + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
  105 + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
  106 + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
  107 + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
  108 +
  109 + python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
  110 + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
  111 + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
  112 + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
  113 + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
  114 +
  115 + - name: Start server for nemo_ctc models
  116 + if: matrix.model_type == 'nemo_ctc'
  117 + shell: bash
  118 + run: |
  119 + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-medium
  120 + cd sherpa-onnx-nemo-ctc-en-conformer-medium
  121 + git lfs pull --include "*.onnx"
  122 + cd ..
  123 +
  124 + python3 ./python-api-examples/non_streaming_server.py \
  125 + --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
  126 + --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt &
  127 +
  128 + echo "sleep 10 seconds to wait for the server to start"
  129 + sleep 10
  130 +
  131 + - name: Start client for nemo_ctc models
  132 + if: matrix.model_type == 'nemo_ctc'
  133 + shell: bash
  134 + run: |
  135 + python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
  136 + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \
  137 + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
  138 + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
  139 +
  140 + python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
  141 + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \
  142 + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
  143 + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
  144 +
  145 + - name: Start server for whisper models
  146 + if: matrix.model_type == 'whisper'
  147 + shell: bash
  148 + run: |
  149 + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en
  150 + cd sherpa-onnx-whisper-tiny.en
  151 + git lfs pull --include "*.onnx"
  152 + cd ..
  153 +
  154 + python3 ./python-api-examples/non_streaming_server.py \
  155 + --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \
  156 + --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
  157 + --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt &
  158 +
  159 + echo "sleep 10 seconds to wait for the server to start"
  160 + sleep 10
  161 +
  162 + - name: Start client for whisper models
  163 + if: matrix.model_type == 'whisper'
  164 + shell: bash
  165 + run: |
  166 + python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
  167 + ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \
  168 + ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \
  169 + ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav
  170 +
  171 + python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
  172 + ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \
  173 + ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \
  174 + ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav
  1 +name: Python online websocket server
  2 +
  3 +on:
  4 + push:
  5 + branches:
  6 + - master
  7 + pull_request:
  8 + branches:
  9 + - master
  10 +
  11 +concurrency:
  12 + group: python-online-websocket-server-${{ github.ref }}
  13 + cancel-in-progress: true
  14 +
  15 +permissions:
  16 + contents: read
  17 +
  18 +jobs:
  19 + python_online_websocket_server:
  20 + runs-on: ${{ matrix.os }}
  21 + name: ${{ matrix.os }} ${{ matrix.python-version }}
  22 + strategy:
  23 + fail-fast: false
  24 + matrix:
  25 + os: [ubuntu-latest, windows-latest, macos-latest]
  26 + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
  27 + model_type: ["transducer"]
  28 +
  29 + steps:
  30 + - uses: actions/checkout@v2
  31 + with:
  32 + fetch-depth: 0
  33 +
  34 + - name: Setup Python ${{ matrix.python-version }}
  35 + uses: actions/setup-python@v2
  36 + with:
  37 + python-version: ${{ matrix.python-version }}
  38 +
  39 + - name: Install Python dependencies
  40 + shell: bash
  41 + run: |
  42 + python3 -m pip install --upgrade pip numpy
  43 +
  44 + - name: Install sherpa-onnx
  45 + shell: bash
  46 + run: |
  47 + python3 -m pip install --no-deps --verbose .
  48 + python3 -m pip install websockets
  49 +
  50 +
  51 + - name: Start server for transducer models
  52 + if: matrix.model_type == 'transducer'
  53 + shell: bash
  54 + run: |
  55 + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26
  56 + cd sherpa-onnx-streaming-zipformer-en-2023-06-26
  57 + git lfs pull --include "*.onnx"
  58 + cd ..
  59 +
  60 + python3 ./python-api-examples/streaming_server.py \
  61 + --encoder ./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-128.onnx \
  62 + --decoder ./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-128.onnx \
  63 + --joiner ./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-128.onnx \
  64 + --tokens ./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt &
  65 + echo "sleep 10 seconds to wait for the server to start"
  66 + sleep 10
  67 +
  68 + - name: Start client for transducer models
  69 + if: matrix.model_type == 'transducer'
  70 + shell: bash
  71 + run: |
  72 + python3 ./python-api-examples/online-websocket-client-decode-file.py \
  73 + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)

-set(SHERPA_ONNX_VERSION "1.7.1")
+set(SHERPA_ONNX_VERSION "1.7.2")

 # Disable warning about
 #
  1 +# Introduction
  2 +
  3 +This folder contains C API examples for [sherpa-onnx][sherpa-onnx].
  4 +
  5 +Please refer to the documentation
  6 +https://k2-fsa.github.io/sherpa/onnx/c-api/index.html
  7 +for details.
  8 +
  9 +[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
  1 +# Introduction
  2 +
  3 +This folder contains C# API examples for [sherpa-onnx][sherpa-onnx].
  4 +
  5 +Please refer to the documentation
  6 +https://k2-fsa.github.io/sherpa/onnx/csharp-api/index.html
  7 +for details.
  8 +
  9 +[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
  1 +# Introduction
  2 +
  3 +This folder contains Go API examples for [sherpa-onnx][sherpa-onnx].
  4 +
  5 +Please refer to the documentation
  6 +https://k2-fsa.github.io/sherpa/onnx/go-api/index.html
  7 +for details.
  8 +
  9 +[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
  1 +#!/usr/bin/env python3
  2 +# Copyright 2022-2023 Xiaomi Corp.
  3 +"""
  4 +A server for non-streaming speech recognition. Non-streaming means you send all
  5 +the content of the audio at once for recognition.
  6 +
  7 +It supports multiple clients sending at the same time.
  8 +
  9 +Usage:
  10 + ./non_streaming_server.py --help
  11 +
  12 +Please refer to
  13 +
  14 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
  15 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
  16 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
  17 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html
  18 +
  19 +for pre-trained models to download.
  20 +
  21 +Usage examples:
  22 +
  23 +(1) Use a non-streaming transducer model
  24 +
  25 +cd /path/to/sherpa-onnx
  26 +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-06-26
  27 +cd sherpa-onnx-zipformer-en-2023-06-26
  28 +git lfs pull --include "*.onnx"
  29 +cd ..
  30 +
  31 +python3 ./python-api-examples/non_streaming_server.py \
  32 + --encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \
  33 + --decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \
  34 + --joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \
  35 + --tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt
  36 +
  37 +(2) Use a non-streaming paraformer
  38 +
  39 +cd /path/to/sherpa-onnx
  40 +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
  41 +cd sherpa-onnx-paraformer-zh-2023-03-28
  42 +git lfs pull --include "*.onnx"
  43 +cd ..
  44 +
  45 +python3 ./python-api-examples/non_streaming_server.py \
  46 + --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
  47 + --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt
  48 +
  49 +(3) Use a non-streaming CTC model from NeMo
  50 +
  51 +cd /path/to/sherpa-onnx
  52 +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-medium
  53 +cd sherpa-onnx-nemo-ctc-en-conformer-medium
  54 +git lfs pull --include "*.onnx"
  55 +cd ..
  56 +
  57 +python3 ./python-api-examples/non_streaming_server.py \
  58 + --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
  59 + --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt
  60 +
  61 +(4) Use a Whisper model
  62 +
  63 +cd /path/to/sherpa-onnx
  64 +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en
  65 +cd sherpa-onnx-whisper-tiny.en
  66 +git lfs pull --include "*.onnx"
  67 +cd ..
  68 +
  69 +python3 ./python-api-examples/non_streaming_server.py \
  70 + --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \
  71 + --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
  72 + --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
  73 +
  74 +----
  75 +
  76 +To use a certificate so that you can use https, please use
  77 +
  78 +python3 ./python-api-examples/non_streaming_server.py \
  79 + --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \
  80 + --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
  81 + --certificate=/path/to/your/cert.pem
  82 +
  83 +If you don't have a certificate, please run:
  84 +
  85 + cd ./python-api-examples/web
  86 + ./generate-certificate.py
  87 +
  88 +It will generate 3 files, one of which is the required `cert.pem`.
  89 +""" # noqa
  90 +
  91 +import argparse
  92 +import asyncio
  93 +import http
  94 +import logging
  95 +import socket
  96 +import ssl
  97 +import sys
  98 +import warnings
  99 +from concurrent.futures import ThreadPoolExecutor
  100 +from datetime import datetime
  101 +from pathlib import Path
  102 +from typing import Optional, Tuple
  103 +
  104 +import numpy as np
  105 +import sherpa_onnx
  106 +
  107 +import websockets
  108 +
  109 +from http_server import HttpServer
  110 +
  111 +
  112 +def setup_logger(
  113 + log_filename: str,
  114 + log_level: str = "info",
  115 + use_console: bool = True,
  116 +) -> None:
  117 + """Setup log level.
  118 +
  119 + Args:
  120 + log_filename:
  121 + The filename to save the log.
  122 + log_level:
  123 + The log level to use, e.g., "debug", "info", "warning", "error",
  124 + "critical"
  125 + use_console:
  126 + True to also print logs to console.
  127 + """
  128 + now = datetime.now()
  129 + date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
  130 + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
  131 + log_filename = f"{log_filename}-{date_time}.txt"
  132 +
  133 + Path(log_filename).parent.mkdir(parents=True, exist_ok=True)
  134 +
  135 + level = logging.ERROR
  136 + if log_level == "debug":
  137 + level = logging.DEBUG
  138 + elif log_level == "info":
  139 + level = logging.INFO
  140 + elif log_level == "warning":
  141 + level = logging.WARNING
  142 + elif log_level == "critical":
  143 + level = logging.CRITICAL
  144 +
  145 + logging.basicConfig(
  146 + filename=log_filename,
  147 + format=formatter,
  148 + level=level,
  149 + filemode="w",
  150 + )
  151 + if use_console:
  152 + console = logging.StreamHandler()
  153 + console.setLevel(level)
  154 + console.setFormatter(logging.Formatter(formatter))
  155 + logging.getLogger("").addHandler(console)
  156 +
  157 +
  158 +def add_transducer_model_args(parser: argparse.ArgumentParser):
  159 + parser.add_argument(
  160 + "--encoder",
  161 + default="",
  162 + type=str,
  163 + help="Path to the transducer encoder model",
  164 + )
  165 +
  166 + parser.add_argument(
  167 + "--decoder",
  168 + default="",
  169 + type=str,
  170 + help="Path to the transducer decoder model",
  171 + )
  172 +
  173 + parser.add_argument(
  174 + "--joiner",
  175 + default="",
  176 + type=str,
  177 + help="Path to the transducer joiner model",
  178 + )
  179 +
  180 +
  181 +def add_paraformer_model_args(parser: argparse.ArgumentParser):
  182 + parser.add_argument(
  183 + "--paraformer",
  184 + default="",
  185 + type=str,
  186 + help="Path to the model.onnx from Paraformer",
  187 + )
  188 +
  189 +
  190 +def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
  191 + parser.add_argument(
  192 + "--nemo-ctc",
  193 + default="",
  194 + type=str,
  195 + help="Path to the model.onnx from NeMo CTC",
  196 + )
  197 +
  198 +
  199 +def add_whisper_model_args(parser: argparse.ArgumentParser):
  200 + parser.add_argument(
  201 + "--whisper-encoder",
  202 + default="",
  203 + type=str,
  204 + help="Path to whisper encoder model",
  205 + )
  206 +
  207 + parser.add_argument(
  208 + "--whisper-decoder",
  209 + default="",
  210 + type=str,
  211 + help="Path to whisper decoder model",
  212 + )
  213 +
  214 +
  215 +def add_model_args(parser: argparse.ArgumentParser):
  216 + add_transducer_model_args(parser)
  217 + add_paraformer_model_args(parser)
  218 + add_nemo_ctc_model_args(parser)
  219 + add_whisper_model_args(parser)
  220 +
  221 + parser.add_argument(
  222 + "--tokens",
  223 + type=str,
  224 + help="Path to tokens.txt",
  225 + )
  226 +
  227 + parser.add_argument(
  228 + "--num-threads",
  229 + type=int,
  230 + default=2,
  231 + help="Number of threads to run the neural network model",
  232 + )
  233 +
  234 + parser.add_argument(
  235 + "--provider",
  236 + type=str,
  237 + default="cpu",
  238 + help="Valid values: cpu, cuda, coreml",
  239 + )
  240 +
  241 +
  242 +def add_feature_config_args(parser: argparse.ArgumentParser):
  243 + parser.add_argument(
  244 + "--sample-rate",
  245 + type=int,
  246 + default=16000,
  247 + help="Sample rate of the data used to train the model. ",
  248 + )
  249 +
  250 + parser.add_argument(
  251 + "--feat-dim",
  252 + type=int,
  253 + default=80,
  254 + help="Feature dimension of the model",
  255 + )
  256 +
  257 +
  258 +def add_decoding_args(parser: argparse.ArgumentParser):
  259 + parser.add_argument(
  260 + "--decoding-method",
  261 + type=str,
  262 + default="greedy_search",
  263 + help="""Decoding method to use. Current supported methods are:
  264 + - greedy_search
  265 + - modified_beam_search (for transducer models only)
  266 + """,
  267 + )
  268 +
  269 + add_modified_beam_search_args(parser)
  270 +
  271 +
  272 +def add_modified_beam_search_args(parser: argparse.ArgumentParser):
  273 + parser.add_argument(
  274 + "--max-active-paths",
  275 + type=int,
  276 + default=4,
  277 + help="""Used only when --decoding-method is modified_beam_search.
  278 + It specifies number of active paths to keep during decoding.
  279 + """,
  280 + )
  281 +
  282 +
  283 +def check_args(args):
  284 + if not Path(args.tokens).is_file():
  285 + raise ValueError(f"{args.tokens} does not exist")
  286 +
  287 + if args.decoding_method not in (
  288 + "greedy_search",
  289 + "modified_beam_search",
  290 + ):
  291 + raise ValueError(f"Unsupported decoding method {args.decoding_method}")
  292 +
  293 + if args.decoding_method == "modified_beam_search":
  294 + assert args.max_active_paths > 0, args.max_active_paths
  295 + assert Path(args.encoder).is_file(), args.encoder
  296 + assert Path(args.decoder).is_file(), args.decoder
  297 + assert Path(args.joiner).is_file(), args.joiner
  298 +
  299 +
  300 +def get_args():
  301 + parser = argparse.ArgumentParser(
  302 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  303 + )
  304 +
  305 + add_model_args(parser)
  306 + add_feature_config_args(parser)
  307 + add_decoding_args(parser)
  308 +
  309 + parser.add_argument(
  310 + "--port",
  311 + type=int,
  312 + default=6006,
  313 + help="The server will listen on this port",
  314 + )
  315 +
  316 + parser.add_argument(
  317 + "--max-batch-size",
  318 + type=int,
  319 + default=25,
  320 + help="""Max batch size for computation. Note if there are not enough
  321 + requests in the queue, it will wait for max_wait_ms time. After that,
  322 + even if there are not enough requests, it still sends the
  323 + available requests in the queue for computation.
  324 + """,
  325 + )
  326 +
  327 + parser.add_argument(
  328 + "--max-wait-ms",
  329 + type=float,
  330 + default=5,
  331 + help="""Max time in millisecond to wait to build batches for inference.
  332 + If there are not enough requests in the feature queue to build a batch
  333 + of max_batch_size, it waits up to this time before fetching available
  334 + requests for computation.
  335 + """,
  336 + )
  337 +
  338 + parser.add_argument(
  339 + "--nn-pool-size",
  340 + type=int,
  341 + default=1,
  342 + help="Number of threads for NN computation and decoding.",
  343 + )
  344 +
  345 + parser.add_argument(
  346 + "--max-message-size",
  347 + type=int,
  348 + default=(1 << 20),
  349 + help="""Max message size in bytes.
  350 + The max size per message cannot exceed this limit.
  351 + """,
  352 + )
  353 +
  354 + parser.add_argument(
  355 + "--max-queue-size",
  356 + type=int,
  357 + default=32,
  358 + help="Max number of messages in the queue for each connection.",
  359 + )
  360 +
  361 + parser.add_argument(
  362 + "--max-active-connections",
  363 + type=int,
  364 + default=500,
  365 + help="""Maximum number of active connections. The server will refuse
  366 + to accept new connections once the current number of active connections
  367 + equals this limit.
  368 + """,
  369 + )
  370 +
  371 + parser.add_argument(
  372 + "--certificate",
  373 + type=str,
  374 + help="""Path to the X.509 certificate. You need it only if you want to
  375 + use a secure websocket connection, i.e., use wss:// instead of ws://.
  376 + You can use ./web/generate-certificate.py
  377 + to generate the certificate `cert.pem`.
  378 + Note ./web/generate-certificate.py will generate three files but you
  379 + only need to pass the generated cert.pem to this option.
  380 + """,
  381 + )
  382 +
  383 + parser.add_argument(
  384 + "--doc-root",
  385 + type=str,
  386 + default="./python-api-examples/web",
  387 + help="Path to the web root",
  388 + )
  389 +
  390 + return parser.parse_args()
  391 +
  392 +
  393 +class NonStreamingServer:
  394 + def __init__(
  395 + self,
  396 + recognizer: sherpa_onnx.OfflineRecognizer,
  397 + max_batch_size: int,
  398 + max_wait_ms: float,
  399 + nn_pool_size: int,
  400 + max_message_size: int,
  401 + max_queue_size: int,
  402 + max_active_connections: int,
  403 + doc_root: str,
  404 + certificate: Optional[str] = None,
  405 + ):
  406 + """
  407 + Args:
  408 + recognizer:
  409 + An instance of the sherpa_onnx.OfflineRecognizer.
  410 + max_batch_size:
  411 + Max batch size for inference.
  412 + max_wait_ms:
  413 + Max wait time in milliseconds in order to build a batch of
  414 + `max_batch_size`.
  415 + nn_pool_size:
  416 + Number of threads for the thread pool that is used for NN
  417 + computation and decoding.
  418 + max_message_size:
  419 + Max size in bytes per message.
  420 + max_queue_size:
  421 + Max number of messages in the queue for each connection.
  422 + max_active_connections:
  423 + Max number of active connections. Once the number of active clients
  424 + equals this limit, the server refuses to accept new connections.
  425 + doc_root:
  426 + Path to the directory where files like index.html for the HTTP
  427 + server are located.
  428 + certificate:
  429 + Optional. If not None, it will use secure websocket.
  430 + You can use ./web/generate-certificate.py to generate
  431 + it (the default generated filename is `cert.pem`).
  432 + """
  433 + self.recognizer = recognizer
  434 +
  435 + self.certificate = certificate
  436 + self.http_server = HttpServer(doc_root)
  437 +
  438 + self.nn_pool = ThreadPoolExecutor(
  439 + max_workers=nn_pool_size,
  440 + thread_name_prefix="nn",
  441 + )
  442 +
  443 + self.stream_queue = asyncio.Queue()
  444 +
  445 + self.max_wait_ms = max_wait_ms
  446 + self.max_batch_size = max_batch_size
  447 + self.max_message_size = max_message_size
  448 + self.max_queue_size = max_queue_size
  449 + self.max_active_connections = max_active_connections
  450 +
  451 + self.current_active_connections = 0
  452 + self.sample_rate = int(recognizer.config.feat_config.sampling_rate)
  453 +
  454 + async def process_request(
  455 + self,
  456 + path: str,
  457 + request_headers: websockets.Headers,
  458 + ) -> Optional[Tuple[http.HTTPStatus, websockets.Headers, bytes]]:
  459 + if "sec-websocket-key" not in request_headers:
  460 + # This is a normal HTTP request
  461 + if path == "/":
  462 + path = "/index.html"
  463 + if path[-1] == "?":
  464 + path = path[:-1]
  465 +
  466 + if path == "/streaming_record.html":
  467 + response = r"""
  468 +<!doctype html><html><head>
  469 +<title>Speech recognition with next-gen Kaldi</title><body>
  470 +<h2>Only
  471 +<a href="/upload.html">/upload.html</a>
  472 +and
  473 +<a href="/offline_record.html">/offline_record.html</a>
  474 +are available for the non-streaming server.</h2>
  475 +<br/>
  476 +<br/>
  477 +Go back to <a href="/upload.html">/upload.html</a>
  478 +or <a href="/offline_record.html">/offline_record.html</a>
  479 +</body></head></html>
  480 +"""
  481 + found = True
  482 + mime_type = "text/html"
  483 + else:
  484 + found, response, mime_type = self.http_server.process_request(path)
  485 + if isinstance(response, str):
  486 + response = response.encode("utf-8")
  487 +
  488 + if not found:
  489 + status = http.HTTPStatus.NOT_FOUND
  490 + else:
  491 + status = http.HTTPStatus.OK
  492 + header = {"Content-Type": mime_type}
  493 + return status, header, response
  494 +
  495 + if self.current_active_connections < self.max_active_connections:
  496 + self.current_active_connections += 1
  497 + return None
  498 +
  499 + # Refuse new connections
  500 + status = http.HTTPStatus.SERVICE_UNAVAILABLE # 503
  501 + header = {"Hint": "The server is overloaded. Please retry later."}
  502 + response = b"The server is busy. Please retry later."
  503 +
  504 + return status, header, response
  505 +
  506 + async def run(self, port: int):
  507 + logging.info("started")
  508 +
  509 + task = asyncio.create_task(self.stream_consumer_task())
  510 +
  511 + if self.certificate:
  512 + logging.info(f"Using certificate: {self.certificate}")
  513 + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  514 + ssl_context.load_cert_chain(self.certificate)
  515 + else:
  516 + ssl_context = None
  517 + logging.info("No certificate provided")
  518 +
  519 + async with websockets.serve(
  520 + self.handle_connection,
  521 + host="",
  522 + port=port,
  523 + max_size=self.max_message_size,
  524 + max_queue=self.max_queue_size,
  525 + process_request=self.process_request,
  526 + ssl=ssl_context,
  527 + ):
  528 + ip_list = ["localhost"]
  529 + if ssl_context:
  530 + ip_list += ["0.0.0.0", "127.0.0.1"]
  531 + ip_list.append(socket.gethostbyname(socket.gethostname()))
  532 +
  533 + proto = "http://" if ssl_context is None else "https://"
  534 + s = "Please visit one of the following addresses:\n\n"
  535 + for p in ip_list:
  536 + s += " " + proto + p + f":{port}" "\n"
  537 + logging.info(s)
  538 +
  539 + await asyncio.Future() # run forever
  540 +
  541 + await task # not reachable
  542 +
  543 + async def recv_audio_samples(
  544 + self,
  545 + socket: websockets.WebSocketServerProtocol,
  546 + ) -> Tuple[Optional[np.ndarray], Optional[float]]:
  547 + """Receive a tensor from the client.
  548 +
  549 + The message from the client is a **bytes** buffer.
  550 +
  551 + The first message is either "Done", meaning the client won't send
  552 + anything else, or a buffer containing at least 8 bytes.
  553 + The first 4 bytes in little endian specify the sample
  554 + rate of the audio samples; the second 4 bytes in little endian specify
  555 + the number of bytes in the audio file, which the client will send
  556 + in the subsequent messages.
  557 + Since there is a limit on the message size imposed by the websocket
  558 + server, the client may send the audio file in multiple messages if the
  559 + audio file is very large.
  560 +
  561 + The second and remaining messages contain audio samples.
  562 +
  563 + Please refer to ./offline-websocket-client-decode-files-paralell.py
  564 + and ./offline-websocket-client-decode-files-sequential.py
  565 + for how the client sends the message.
  566 +
  567 + Args:
  568 + socket:
  569 + The socket for communicating with the client.
  570 + Returns:
  571 + Return a tuple containing:
  572 + - 1-D np.float32 array containing the audio samples
  573 + - sample rate of the audio samples
  574 + or return (None, None) indicating the end of utterance.
  575 + """
  576 + header = await socket.recv()
  577 + if header == "Done":
  578 + return None, None
  579 +
  580 + assert len(header) >= 8, (
  581 + "The first message should contain at least 8 bytes."
  582 + + f"Given {len(header)}"
  583 + )
  584 +
  585 + sample_rate = int.from_bytes(header[:4], "little", signed=True)
  586 + expected_num_bytes = int.from_bytes(header[4:8], "little", signed=True)
  587 +
  588 + received = []
  589 + num_received_bytes = 0
  590 + if len(header) > 8:
  591 + received.append(header[8:])
  592 + num_received_bytes += len(header) - 8
  593 +
  594 + if num_received_bytes < expected_num_bytes:
  595 + async for message in socket:
  596 + received.append(message)
  597 + num_received_bytes += len(message)
  598 + if num_received_bytes >= expected_num_bytes:
  599 + break
  600 +
  601 + assert num_received_bytes == expected_num_bytes, (
  602 + num_received_bytes,
  603 + expected_num_bytes,
  604 + )
  605 +
  606 + samples = b"".join(received)
  607 + array = np.frombuffer(samples, dtype=np.float32)
  608 + return array, sample_rate
  609 +
  610 + async def stream_consumer_task(self):
  611 + """This function extracts streams from the queue, batches them up, sends
  612 + them to the neural network model for computation and decoding.
  613 + """
  614 + while True:
  615 + if self.stream_queue.empty():
  616 + await asyncio.sleep(self.max_wait_ms / 1000)
  617 + continue
  618 +
  619 + batch = []
  620 + try:
  621 + while len(batch) < self.max_batch_size:
  622 + item = self.stream_queue.get_nowait()
  623 +
  624 + batch.append(item)
  625 + except asyncio.QueueEmpty:
  626 + pass
  627 + stream_list = [b[0] for b in batch]
  628 + future_list = [b[1] for b in batch]
  629 +
  630 + loop = asyncio.get_running_loop()
  631 + await loop.run_in_executor(
  632 + self.nn_pool,
  633 + self.recognizer.decode_streams,
  634 + stream_list,
  635 + )
  636 +
  637 + for f in future_list:
  638 + self.stream_queue.task_done()
  639 + f.set_result(None)
  640 +
  641 + async def compute_and_decode(
  642 + self,
  643 + stream: sherpa_onnx.OfflineStream,
  644 + ) -> None:
  645 + """Put the stream into the queue and wait it to be processed by the
  646 + consumer task.
  647 +
  648 + Args:
  649 + stream:
  650 + The stream to be processed. Note: It is changed in-place.
  651 + """
  652 + loop = asyncio.get_running_loop()
  653 + future = loop.create_future()
  654 + await self.stream_queue.put((stream, future))
  655 + await future
  656 +
  657 + async def handle_connection(
  658 + self,
  659 + socket: websockets.WebSocketServerProtocol,
  660 + ):
  661 + """Receive audio samples from the client, process it, and sends
  662 + deocoding result back to the client.
  663 +
  664 + Args:
  665 + socket:
  666 + The socket for communicating with the client.
  667 + """
  668 + try:
  669 + await self.handle_connection_impl(socket)
  670 + except websockets.exceptions.ConnectionClosedError:
  671 + logging.info(f"{socket.remote_address} disconnected")
  672 + finally:
  673 + # Decrement so that it can accept new connections
  674 + self.current_active_connections -= 1
  675 +
  676 + logging.info(
  677 + f"Disconnected: {socket.remote_address}. "
  678 + f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa
  679 + )
  680 +
  681 + async def handle_connection_impl(
  682 + self,
  683 + socket: websockets.WebSocketServerProtocol,
  684 + ):
  685 + """Receive audio samples from the client, process it, and send
  686 + decoding results back to the client.
  687 +
  688 + Args:
  689 + socket:
  690 + The socket for communicating with the client.
  691 + """
  692 + logging.info(
  693 + f"Connected: {socket.remote_address}. "
  694 + f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa
  695 + )
  696 +
  697 + while True:
  698 + stream = self.recognizer.create_stream()
  699 + samples, sample_rate = await self.recv_audio_samples(socket)
  700 + if samples is None:
  701 + break
  702 + # stream.accept_waveform() runs in the main thread
  703 +
  704 + stream.accept_waveform(sample_rate, samples)
  705 +
  706 + await self.compute_and_decode(stream)
  707 + result = stream.result.text
  708 + logging.info(f"result: {result}")
  709 +
  710 + if result:
  711 + await socket.send(result)
  712 + else:
  713 + # If result is an empty string, send something to the client.
  714 + # Otherwise, socket.send() is a no-op and the client will
  715 + # wait for a reply indefinitely.
  716 + await socket.send("<EMPTY>")
  717 +
  718 +
  719 +def assert_file_exists(filename: str):
  720 + assert Path(filename).is_file(), (
  721 + f"{filename} does not exist!\n"
  722 + "Please refer to "
  723 + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
  724 + )
  725 +
  726 +
  727 +def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
  728 + if args.encoder:
  729 + assert len(args.paraformer) == 0, args.paraformer
  730 + assert len(args.nemo_ctc) == 0, args.nemo_ctc
  731 + assert len(args.whisper_encoder) == 0, args.whisper_encoder
  732 + assert len(args.whisper_decoder) == 0, args.whisper_decoder
  733 +
  734 + assert_file_exists(args.encoder)
  735 + assert_file_exists(args.decoder)
  736 + assert_file_exists(args.joiner)
  737 +
  738 + recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
  739 + encoder=args.encoder,
  740 + decoder=args.decoder,
  741 + joiner=args.joiner,
  742 + tokens=args.tokens,
  743 + num_threads=args.num_threads,
  744 + sample_rate=args.sample_rate,
  745 + feature_dim=args.feat_dim,
  746 + decoding_method=args.decoding_method,
  747 + max_active_paths=args.max_active_paths,
  748 + )
  749 + elif args.paraformer:
  750 + assert len(args.nemo_ctc) == 0, args.nemo_ctc
  751 + assert len(args.whisper_encoder) == 0, args.whisper_encoder
  752 + assert len(args.whisper_decoder) == 0, args.whisper_decoder
  753 +
  754 + assert_file_exists(args.paraformer)
  755 +
  756 + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
  757 + paraformer=args.paraformer,
  758 + tokens=args.tokens,
  759 + num_threads=args.num_threads,
  760 + sample_rate=args.sample_rate,
  761 + feature_dim=args.feat_dim,
  762 + decoding_method=args.decoding_method,
  763 + )
  764 + elif args.nemo_ctc:
  765 + assert len(args.whisper_encoder) == 0, args.whisper_encoder
  766 + assert len(args.whisper_decoder) == 0, args.whisper_decoder
  767 +
  768 + assert_file_exists(args.nemo_ctc)
  769 +
  770 + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
  771 + model=args.nemo_ctc,
  772 + tokens=args.tokens,
  773 + num_threads=args.num_threads,
  774 + sample_rate=args.sample_rate,
  775 + feature_dim=args.feat_dim,
  776 + decoding_method=args.decoding_method,
  777 + )
  778 + elif args.whisper_encoder:
  779 + assert_file_exists(args.whisper_encoder)
  780 + assert_file_exists(args.whisper_decoder)
  781 +
  782 + recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
  783 + encoder=args.whisper_encoder,
  784 + decoder=args.whisper_decoder,
  785 + tokens=args.tokens,
  786 + num_threads=args.num_threads,
  787 + decoding_method=args.decoding_method,
  788 + )
  789 + else:
  790 + raise ValueError("Please specify at least one model")
  791 +
  792 + return recognizer
  793 +
  794 +
  795 +def main():
  796 + args = get_args()
  797 + logging.info(vars(args))
  798 + check_args(args)
  799 +
  800 + recognizer = create_recognizer(args)
  801 +
  802 + port = args.port
  803 + max_wait_ms = args.max_wait_ms
  804 + max_batch_size = args.max_batch_size
  805 + nn_pool_size = args.nn_pool_size
  806 + max_message_size = args.max_message_size
  807 + max_queue_size = args.max_queue_size
  808 + max_active_connections = args.max_active_connections
  809 + certificate = args.certificate
  810 + doc_root = args.doc_root
  811 +
  812 + if certificate and not Path(certificate).is_file():
  813 + raise ValueError(f"{certificate} does not exist")
  814 +
  815 + if not Path(doc_root).is_dir():
  816 + raise ValueError(f"Directory {doc_root} does not exist")
  817 +
  818 + non_streaming_server = NonStreamingServer(
  819 + recognizer=recognizer,
  820 + max_wait_ms=max_wait_ms,
  821 + max_batch_size=max_batch_size,
  822 + nn_pool_size=nn_pool_size,
  823 + max_message_size=max_message_size,
  824 + max_queue_size=max_queue_size,
  825 + max_active_connections=max_active_connections,
  826 + certificate=certificate,
  827 + doc_root=doc_root,
  828 + )
  829 + asyncio.run(non_streaming_server.run(port))
  830 +
  831 +
  832 +if __name__ == "__main__":
  833 + log_filename = "log/log-non-streaming-server"
  834 + setup_logger(log_filename)
  835 + main()
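The wire protocol implemented by `recv_audio_samples()` above is small enough to exercise by hand. Below is a minimal client sketch, not a file from the repo: it packs the 8-byte little-endian header (sample rate, then payload size in bytes), sends the float32 samples, reads back the decoded text, and signals "Done". The localhost URI, the default port 6006, and the synthetic silence are assumptions for the sketch; see offline-websocket-client-decode-files-paralell.py and offline-websocket-client-decode-files-sequential.py for the full clients.

```python
# Illustrative client for the protocol implemented by recv_audio_samples()
# above; assumes the server runs with the default --port 6006.
import asyncio

import numpy as np
import websockets


async def decode(samples: np.ndarray, sample_rate: int = 16000) -> str:
    assert samples.dtype == np.float32, samples.dtype

    # 8-byte header: sample rate, then payload size in bytes (little endian)
    buf = sample_rate.to_bytes(4, byteorder="little")
    buf += (samples.size * 4).to_bytes(4, byteorder="little")
    buf += samples.tobytes()

    async with websockets.connect("ws://localhost:6006") as ws:
        await ws.send(buf)  # may also be split across several messages
        result = await ws.recv()  # decoded text, or "<EMPTY>"
        await ws.send("Done")  # tell the server we are finished
        return result


# One second of silence stands in for a real wave file here.
print(asyncio.run(decode(np.zeros(16000, dtype=np.float32))))
```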
@@ -119,7 +119,13 @@ async def run(
     buf += (samples.size * 4).to_bytes(4, byteorder="little")
     buf += samples.tobytes()

-    await websocket.send(buf)
+    payload_len = 10240
+    while len(buf) > payload_len:
+        await websocket.send(buf[:payload_len])
+        buf = buf[payload_len:]
+
+    if buf:
+        await websocket.send(buf)

     decoding_results = await websocket.recv()
     logging.info(f"{wave_filename}\n{decoding_results}")
@@ -116,11 +116,18 @@ async def run(
     assert isinstance(sample_rate, int)
     assert samples.dtype == np.float32, samples.dtype
     assert samples.ndim == 1, samples.ndim
+
     buf = sample_rate.to_bytes(4, byteorder="little")  # 4 bytes
     buf += (samples.size * 4).to_bytes(4, byteorder="little")
     buf += samples.tobytes()

-    await websocket.send(buf)
+    payload_len = 10240
+    while len(buf) > payload_len:
+        await websocket.send(buf[:payload_len])
+        buf = buf[payload_len:]
+
+    if buf:
+        await websocket.send(buf)

     decoding_results = await websocket.recv()
     print(decoding_results)
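Both clients now split the payload because the server caps each incoming websocket message via --max-message-size (1 << 20 bytes by default in non_streaming_server.py above), and `recv_audio_samples()` keeps reading until the byte count announced in the 8-byte header has arrived, so the split is transparent to the server. Here is the same logic as a standalone helper; the function name and the default chunk size are illustrative, not part of the repo:

```python
# The same chunking logic as a reusable helper; name and default chunk
# size are illustrative, not part of the repo.
async def send_in_chunks(websocket, buf: bytes, payload_len: int = 10240) -> None:
    """Send buf in websocket messages of at most payload_len bytes.

    The server reads messages until the byte count announced in the
    8-byte header has arrived, so the split is transparent to it.
    """
    while len(buf) > payload_len:
        await websocket.send(buf[:payload_len])
        buf = buf[payload_len:]

    if buf:
        await websocket.send(buf)
```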
@@ -15,10 +15,9 @@ Usage:

 (Note: You have to first start the server before starting the client)

-You can find the server at
+You can find the C++ server at
 https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
-
-Note: The server is implemented in C++.
+or use the Python server ./python-api-examples/streaming_server.py

 There is also a C++ version of the client. Please see
 https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
@@ -115,7 +114,8 @@ async def receive_results(socket: websockets.WebSocketServerProtocol):
             last_message = message
             logging.info(message)
         else:
-            return last_message
+            break
+    return last_message


 async def run(
@@ -142,6 +142,7 @@ async def run(

         await websocket.send(d)

+        # Simulate streaming. You can remove the sleep if you want
         await asyncio.sleep(seconds_per_message)  # in seconds

         start += samples_per_message
@@ -12,10 +12,9 @@ Usage:

 (Note: You have to first start the server before starting the client)

-You can find the server at
+You can find the C++ server at
 https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
-
-Note: The server is implemented in C++.
+or use the Python server ./python-api-examples/streaming_server.py

 There is also a C++ version of the client. Please see
 https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
@@ -13,11 +13,37 @@ Usage:

 Example:

+(1) Without a certificate
+
 python3 ./python-api-examples/streaming_server.py \
   --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
   --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
   --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
   --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
+
+(2) With a certificate
+
+(a) Generate a certificate first:
+
+    cd python-api-examples/web
+    ./generate-certificate.py
+    cd ../..
+
+(b) Start the server:
+
+python3 ./python-api-examples/streaming_server.py \
+  --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
+  --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
+  --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
+  --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
+  --certificate ./python-api-examples/web/cert.pem
+
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
+to download pre-trained models.
+
+The model used in the examples above is from
+https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
 """

 import argparse
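For the "(2) With a certificate" case above, a quick way to check the wss:// endpoint from Python is sketched below. The cert.pem produced by generate-certificate.py is self-signed, so the sketch disables verification; that is only acceptable for a local test. The port is an assumption taken from the defaults used elsewhere in these examples.

```python
# Hypothetical connectivity check for the wss:// endpoint; do not disable
# certificate verification against servers you do not control.
import asyncio
import ssl

import websockets


async def main():
    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE  # cert.pem is self-signed

    # Assumed port; match the --port the server was started with.
    async with websockets.connect("wss://localhost:6006", ssl=ssl_context):
        print("TLS handshake with the streaming server succeeded")


asyncio.run(main())
```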
@@ -35,6 +61,7 @@ from typing import List, Optional, Tuple
 import numpy as np
 import sherpa_onnx
 import websockets
+
 from http_server import HttpServer


@@ -269,8 +296,8 @@ def get_args():
     parser.add_argument(
         "--num-threads",
         type=int,
-        default=1,
-        help="Sets the number of threads used for interop parallelism (e.g. in JIT interpreter) on CPU.",
+        default=2,
+        help="Number of threads to run the neural network model",
     )

     parser.add_argument(
@@ -278,8 +305,10 @@ def get_args():
         type=str,
         help="""Path to the X.509 certificate. You need it only if you want to
         use a secure websocket connection, i.e., use wss:// instead of ws://.
-        You can use sherpa/bin/web/generate-certificate.py
+        You can use ./web/generate-certificate.py
         to generate the certificate `cert.pem`.
+        Note ./web/generate-certificate.py will generate three files but you
+        only need to pass the generated cert.pem to this option.
         """,
     )

@@ -287,7 +316,7 @@ def get_args():
         "--doc-root",
         type=str,
         default="./python-api-examples/web",
-        help="""Path to the web root""",
+        help="Path to the web root",
     )

     return parser.parse_args()
@@ -299,9 +328,9 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
         encoder=args.encoder_model,
         decoder=args.decoder_model,
         joiner=args.joiner_model,
-        num_threads=1,
-        sample_rate=16000,
-        feature_dim=80,
+        num_threads=args.num_threads,
+        sample_rate=args.sample_rate,
+        feature_dim=args.feat_dim,
         decoding_method=args.decoding_method,
         max_active_paths=args.num_active_paths,
         enable_endpoint_detection=args.use_endpoint != 0,
@@ -359,7 +388,7 @@ class StreamingServer(object):
             server locate.
           certificate:
             Optional. If not None, it will use secure websocket.
-            You can use ./sherpa/bin/web/generate-certificate.py to generate
+            You can use ./web/generate-certificate.py to generate
             it (the default generated filename is `cert.pem`).
         """
         self.recognizer = recognizer
@@ -373,6 +402,7 @@ class StreamingServer(object):
         )

         self.stream_queue = asyncio.Queue()
+
         self.max_wait_ms = max_wait_ms
         self.max_batch_size = max_batch_size
         self.max_message_size = max_message_size
@@ -382,11 +412,10 @@ class StreamingServer(object):
         self.current_active_connections = 0

         self.sample_rate = int(recognizer.config.feat_config.sampling_rate)
-        self.decoding_method = recognizer.config.decoding_method

     async def stream_consumer_task(self):
         """This function extracts streams from the queue, batches them up, sends
-        them to the RNN-T model for computation and decoding.
+        them to the neural network model for computation and decoding.
         """
         while True:
             if self.stream_queue.empty():
@@ -442,7 +471,22 @@ class StreamingServer(object):
             # This is a normal HTTP request
             if path == "/":
                 path = "/index.html"
-            found, response, mime_type = self.http_server.process_request(path)
+
+            if path in ("/upload.html", "/offline_record.html"):
+                response = r"""
+<!doctype html><html><head>
+<title>Speech recognition with next-gen Kaldi</title><body>
+<h2>Only /streaming_record.html is available for the streaming server.</h2>
+<br/>
+<br/>
+Go back to <a href="/streaming_record.html">/streaming_record.html</a>
+</body></head></html>
+"""
+                found = True
+                mime_type = "text/html"
+            else:
+                found, response, mime_type = self.http_server.process_request(path)
+
             if isinstance(response, str):
                 response = response.encode("utf-8")

@@ -484,12 +528,21 @@ class StreamingServer(object):
             process_request=self.process_request,
             ssl=ssl_context,
         ):
-            ip_list = ["0.0.0.0", "localhost", "127.0.0.1"]
-            ip_list.append(socket.gethostbyname(socket.gethostname()))
+            ip_list = ["localhost"]
+            if ssl_context:
+                ip_list += ["0.0.0.0", "127.0.0.1"]
+                ip_list.append(socket.gethostbyname(socket.gethostname()))
             proto = "http://" if ssl_context is None else "https://"
             s = "Please visit one of the following addresses:\n\n"
             for p in ip_list:
                 s += " " + proto + p + f":{port}" "\n"
+
+            if not ssl_context:
+                s += "\nSince you are not providing a certificate, you cannot "
+                s += "use your microphone from within the browser using "
+                s += "public IP addresses. Only localhost can be used. "
+                s += "You also cannot use 0.0.0.0 or 127.0.0.1"
+
             logging.info(s)

             await asyncio.Future()  # run forever
@@ -525,7 +578,7 @@ class StreamingServer(object):
         socket: websockets.WebSocketServerProtocol,
     ):
         """Receive audio samples from the client, process it, and send
-        deocoding result back to the client.
+        decoding result back to the client.

         Args:
             socket:
@@ -560,8 +613,6 @@ class StreamingServer(object):
                 self.recognizer.reset(stream)
                 segment += 1

-        print(message)
-
         await socket.send(json.dumps(message))

         tail_padding = np.zeros(int(self.sample_rate * 0.3)).astype(np.float32)
@@ -583,7 +634,7 @@ class StreamingServer(object):
         self,
         socket: websockets.WebSocketServerProtocol,
     ) -> Optional[np.ndarray]:
-        """Receives a tensor from the client.
+        """Receive a tensor from the client.

         Each message contains either a bytes buffer containing audio samples
         in 16 kHz or contains "Done" meaning the end of utterance.
@@ -660,6 +711,6 @@ def main():


 if __name__ == "__main__":
-    log_filename = "log/log-streaming-zipformer"
+    log_filename = "log/log-streaming-server"
     setup_logger(log_filename)
     main()
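For comparison with the non-streaming protocol above, the streaming server expects each websocket message to be a raw buffer of 16 kHz audio samples, with "Done" marking the end of the utterance, and it replies with JSON-encoded results. A minimal client sketch under those assumptions follows; the chunk size, port, and use of silence are illustrative, and the full client shipped with the repo is ./python-api-examples/online-websocket-client-decode-file.py.

```python
# Minimal streaming-client sketch; assumes the server from
# streaming_server.py listens on ws://localhost:6006 and that samples
# are 16 kHz float32, as in online-websocket-client-decode-file.py.
import asyncio
import json

import numpy as np
import websockets


async def main():
    async with websockets.connect("ws://localhost:6006") as ws:
        samples = np.zeros(2 * 16000, dtype=np.float32)  # 2 s of silence
        chunk = 1600  # 0.1 s of audio per message
        for start in range(0, samples.size, chunk):
            await ws.send(samples[start : start + chunk].tobytes())
            await asyncio.sleep(0.1)  # simulate real-time capture

        await ws.send("Done")  # end of utterance

        # Results arrive as JSON messages; the loop ends when the server
        # closes the connection (assumed behavior after "Done").
        async for message in ws:
            print(json.loads(message))


asyncio.run(main())
```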
1 -# How to use  
2 -  
3 -```bash  
4 -git clone https://github.com/k2-fsa/sherpa  
5 -  
6 -cd sherpa/sherpa/bin/web  
7 -python3 -m http.server 6009  
8 -```  
9 -and then go to <http://localhost:6009>  
10 -  
11 -You will see a page like the following screenshot:  
12 -  
13 -![Screenshot if you visit http://localhost:6009](./pic/web-ui.png)  
14 -  
15 -If your server is listening at the port *6006* with address **localhost**,  
16 -then you can either click **Upload**, **Streaming_Record** or **Offline_Record** to play with it.  
17 -  
18 -## File descriptions  
19 -  
20 -### ./css/bootstrap.min.css  
21 -  
22 -It is downloaded from https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/css/bootstrap.min.css  
23 -  
24 -### ./js/jquery-3.6.0.min.js  
25 -  
26 -It is downloaded from https://code.jquery.com/jquery-3.6.0.min.js  
27 -  
28 -### ./js/popper.min.js  
29 -  
30 -It is downloaded from https://cdn.jsdelivr.net/npm/popper.js@1.14.7/dist/umd/popper.min.js  
31 -  
32 -### ./js/bootstrap.min.js  
33 -  
 34 -It is downloaded from https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/js/bootstrap.min.js
@@ -35,8 +35,8 @@ Otherwise, you may get the following error from within you browser: @@ -35,8 +35,8 @@ Otherwise, you may get the following error from within you browser:
35 35
36 36
37 def cert_gen( 37 def cert_gen(
38 - emailAddress="https://github.com/k2-fsa/k2",  
39 - commonName="sherpa", 38 + emailAddress="https://github.com/k2-fsa/sherpa-onnx",
  39 + commonName="sherpa-onnx",
40 countryName="CN", 40 countryName="CN",
41 localityName="k2-fsa", 41 localityName="k2-fsa",
42 stateOrProvinceName="k2-fsa", 42 stateOrProvinceName="k2-fsa",
@@ -70,17 +70,13 @@ def cert_gen( @@ -70,17 +70,13 @@ def cert_gen(
70 cert.set_pubkey(k) 70 cert.set_pubkey(k)
71 cert.sign(k, "sha512") 71 cert.sign(k, "sha512")
72 with open(CERT_FILE, "wt") as f: 72 with open(CERT_FILE, "wt") as f:
73 - f.write(  
74 - crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")  
75 - ) 73 + f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
76 with open(KEY_FILE, "wt") as f: 74 with open(KEY_FILE, "wt") as f:
77 f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8")) 75 f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8"))
78 76
79 with open(ALL_IN_ONE_FILE, "wt") as f: 77 with open(ALL_IN_ONE_FILE, "wt") as f:
80 f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8")) 78 f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8"))
81 - f.write(  
82 - crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")  
83 - ) 79 + f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
84 print(f"Generated {CERT_FILE}") 80 print(f"Generated {CERT_FILE}")
85 print(f"Generated {KEY_FILE}") 81 print(f"Generated {KEY_FILE}")
86 print(f"Generated {ALL_IN_ONE_FILE}") 82 print(f"Generated {ALL_IN_ONE_FILE}")
@@ -53,7 +53,7 @@ @@ -53,7 +53,7 @@
53 </ul> 53 </ul>
54 54
55 Code is available at 55 Code is available at
56 - <a href="https://github.com/k2-fsa/sherpa"> https://github.com/k2-fsa/sherpa</a> 56 + <a href="https://github.com/k2-fsa/sherpa-onnx"> https://github.com/k2-fsa/sherpa-onnx</a>
57 57
58 <!-- Optional JavaScript --> 58 <!-- Optional JavaScript -->
59 <!-- jQuery first, then Popper.js, then Bootstrap JS --> 59 <!-- jQuery first, then Popper.js, then Bootstrap JS -->
@@ -60,6 +60,7 @@ const soundClips = document.getElementById('sound-clips'); @@ -60,6 +60,7 @@ const soundClips = document.getElementById('sound-clips');
60 const canvas = document.getElementById('canvas'); 60 const canvas = document.getElementById('canvas');
61 const mainSection = document.querySelector('.container'); 61 const mainSection = document.querySelector('.container');
62 62
  63 +recordBtn.disabled = true;
63 stopBtn.disabled = true; 64 stopBtn.disabled = true;
64 65
65 window.onload = (event) => { 66 window.onload = (event) => {
@@ -95,9 +96,10 @@ clearBtn.onclick = function() { @@ -95,9 +96,10 @@ clearBtn.onclick = function() {
95 }; 96 };
96 97
97 function send_header(n) { 98 function send_header(n) {
98 - const header = new ArrayBuffer(4);  
99 - new DataView(header).setInt32(0, n, true /* littleEndian */);  
100 - socket.send(new Int32Array(header, 0, 1)); 99 + const header = new ArrayBuffer(8);
  100 + new DataView(header).setInt32(0, expectedSampleRate, true /* littleEndian */);
  101 + new DataView(header).setInt32(4, n, true /* littleEndian */);
  102 + socket.send(new Int32Array(header, 0, 2));
101 } 103 }
102 104
103 // copied/modified from https://mdn.github.io/web-dictaphone/ 105 // copied/modified from https://mdn.github.io/web-dictaphone/
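With this change the client header grows from 4 to 8 bytes: a little-endian int32 sample rate at offset 0, followed by the little-endian int32 length field `n` that was already being sent. On the Python side the pair could be unpacked like this (a sketch; the function name is made up):

```python
import struct


def parse_header(buf: bytes):
    # two little-endian int32 values: the sample rate, then the
    # client-reported length n (same meaning as in send_header() above)
    sample_rate, n = struct.unpack("<ii", buf[:8])
    return sample_rate, n
```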
@@ -88,6 +88,7 @@ const canvas = document.getElementById('canvas'); @@ -88,6 +88,7 @@ const canvas = document.getElementById('canvas');
88 const mainSection = document.querySelector('.container'); 88 const mainSection = document.querySelector('.container');
89 89
90 stopBtn.disabled = true; 90 stopBtn.disabled = true;
  91 +recordBtn.disabled = true;
91 92
92 let audioCtx; 93 let audioCtx;
93 const canvasCtx = canvas.getContext('2d'); 94 const canvasCtx = canvas.getContext('2d');
@@ -74,9 +74,11 @@ connectBtn.onclick = function() { @@ -74,9 +74,11 @@ connectBtn.onclick = function() {
74 }; 74 };
75 75
76 function send_header(n) { 76 function send_header(n) {
77 - const header = new ArrayBuffer(4);  
78 - new DataView(header).setInt32(0, n, true /* littleEndian */);  
79 - socket.send(new Int32Array(header, 0, 1)); 77 + const header = new ArrayBuffer(8);
  78 + // assume the uploaded wave is 16000 Hz
  79 + new DataView(header).setInt32(0, 16000, true /* littleEndian */);
  80 + new DataView(header).setInt32(4, n, true /* littleEndian */);
  81 + socket.send(new Int32Array(header, 0, 2));
80 } 82 }
81 83
82 function onFileChange() { 84 function onFileChange() {
@@ -33,9 +33,9 @@ @@ -33,9 +33,9 @@
33 <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button> 33 <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button>
34 </div> 34 </div>
35 <span class="input-group-text" id="ws-protocol">ws://</span> 35 <span class="input-group-text" id="ws-protocol">ws://</span>
36 - <input type="text" id="server-ip" class="form-control" placeholder="Sherpa server IP, e.g., localhost" aria-label="sherpa server IP"> 36 + <input type="text" id="server-ip" class="form-control" placeholder="Sherpa-onnx server IP, e.g., localhost" aria-label="sherpa-onnx server IP">
37 <span class="input-group-text">:</span> 37 <span class="input-group-text">:</span>
38 - <input type="text" id="server-port" class="form-control" placeholder="Sherpa server port, e.g., 6006" aria-label="sherpa server port"> 38 + <input type="text" id="server-port" class="form-control" placeholder="Sherpa-onnx server port, e.g., 6006" aria-label="sherpa-onnx server port">
39 </div> 39 </div>
40 40
41 <div class="row"> 41 <div class="row">
@@ -33,9 +33,9 @@ @@ -33,9 +33,9 @@
33 <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button> 33 <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button>
34 </div> 34 </div>
35 <span class="input-group-text" id="ws-protocol">ws://</span> 35 <span class="input-group-text" id="ws-protocol">ws://</span>
36 - <input type="text" id="server-ip" class="form-control" placeholder="Sherpa server IP, e.g., localhost" aria-label="sherpa server IP"> 36 + <input type="text" id="server-ip" class="form-control" placeholder="Sherpa-onnx server IP, e.g., localhost" aria-label="sherpa-onnx server IP">
37 <span class="input-group-text">:</span> 37 <span class="input-group-text">:</span>
38 - <input type="text" id="server-port" class="form-control" placeholder="Sherpa server port, e.g., 6006" aria-label="sherpa server port"> 38 + <input type="text" id="server-port" class="form-control" placeholder="Sherpa-onnx server port, e.g., 6006" aria-label="sherpa-onnx server port">
39 </div> 39 </div>
40 40
41 <div class="row"> 41 <div class="row">
@@ -32,9 +32,9 @@ @@ -32,9 +32,9 @@
32 <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button> 32 <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button>
33 </div> 33 </div>
34 <span class="input-group-text" id="ws-protocol">ws://</span> 34 <span class="input-group-text" id="ws-protocol">ws://</span>
35 - <input type="text" id="server-ip" class="form-control" placeholder="Sherpa server IP, e.g., localhost" aria-label="sherpa server IP"> 35 + <input type="text" id="server-ip" class="form-control" placeholder="Sherpa-onnx server IP, e.g., localhost" aria-label="sherpa-onnx server IP">
36 <span class="input-group-text">:</span> 36 <span class="input-group-text">:</span>
37 - <input type="text" id="server-port" class="form-control" placeholder="Sherpa server port, e.g., 6006" aria-label="sherpa server port"> 37 + <input type="text" id="server-port" class="form-control" placeholder="Sherpa-onnx server port, e.g., 6006" aria-label="sherpa-onnx server port">
38 </div> 38 </div>
39 39
40 <form> 40 <form>
1 from typing import Dict, List, Optional 1 from typing import Dict, List, Optional
2 2
3 -from _sherpa_onnx import Display 3 +from _sherpa_onnx import Display, OfflineStream, OnlineStream
4 4
5 -from .online_recognizer import OnlineRecognizer  
6 -from .online_recognizer import OnlineStream  
7 from .offline_recognizer import OfflineRecognizer 5 from .offline_recognizer import OfflineRecognizer
8 - 6 +from .online_recognizer import OnlineRecognizer
9 from .utils import encode_contexts 7 from .utils import encode_contexts
10 -  
11 -  
12 -  
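After this cleanup the streaming and non-streaming classes all live directly under the package namespace. A quick smoke test of the re-exported names (exactly the symbols imported above):

```python
# every name below is re-exported by the package __init__.py shown above
from sherpa_onnx import (
    Display,
    OfflineRecognizer,
    OfflineStream,
    OnlineRecognizer,
    OnlineStream,
)
```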
@@ -41,6 +41,7 @@ class OfflineRecognizer(object): @@ -41,6 +41,7 @@ class OfflineRecognizer(object):
41 sample_rate: int = 16000, 41 sample_rate: int = 16000,
42 feature_dim: int = 80, 42 feature_dim: int = 80,
43 decoding_method: str = "greedy_search", 43 decoding_method: str = "greedy_search",
  44 + max_active_paths: int = 4,
44 context_score: float = 1.5, 45 context_score: float = 1.5,
45 debug: bool = False, 46 debug: bool = False,
46 provider: str = "cpu", 47 provider: str = "cpu",
@@ -72,6 +73,9 @@ class OfflineRecognizer(object): @@ -72,6 +73,9 @@ class OfflineRecognizer(object):
72 Dimension of the feature used to train the model. 73 Dimension of the feature used to train the model.
73 decoding_method: 74 decoding_method:
74 Valid values: greedy_search, modified_beam_search. 75 Valid values: greedy_search, modified_beam_search.
  76 + max_active_paths:
  77 + Maximum number of active paths to keep. Used only when
  78 + decoding_method is modified_beam_search.
75 debug: 79 debug:
76 True to show debug messages. 80 True to show debug messages.
77 provider: 81 provider:
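As the new docstring says, `max_active_paths` bounds the beam during `modified_beam_search` and is ignored by `greedy_search`. A sketch of how it is passed in; the factory method name and the model file names here are illustrative assumptions, not part of this diff:

```python
import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    encoder="encoder.onnx",  # placeholder model files
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    tokens="tokens.txt",
    decoding_method="modified_beam_search",
    max_active_paths=4,  # consulted only for modified_beam_search
)
```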
@@ -103,6 +107,7 @@ class OfflineRecognizer(object): @@ -103,6 +107,7 @@ class OfflineRecognizer(object):
103 context_score=context_score, 107 context_score=context_score,
104 ) 108 )
105 self.recognizer = _Recognizer(recognizer_config) 109 self.recognizer = _Recognizer(recognizer_config)
  110 + self.config = recognizer_config
106 return self 111 return self
107 112
108 @classmethod 113 @classmethod
@@ -166,6 +171,7 @@ class OfflineRecognizer(object): @@ -166,6 +171,7 @@ class OfflineRecognizer(object):
166 decoding_method=decoding_method, 171 decoding_method=decoding_method,
167 ) 172 )
168 self.recognizer = _Recognizer(recognizer_config) 173 self.recognizer = _Recognizer(recognizer_config)
  174 + self.config = recognizer_config
169 return self 175 return self
170 176
171 @classmethod 177 @classmethod
@@ -229,6 +235,7 @@ class OfflineRecognizer(object): @@ -229,6 +235,7 @@ class OfflineRecognizer(object):
229 decoding_method=decoding_method, 235 decoding_method=decoding_method,
230 ) 236 )
231 self.recognizer = _Recognizer(recognizer_config) 237 self.recognizer = _Recognizer(recognizer_config)
  238 + self.config = recognizer_config
232 return self 239 return self
233 240
234 @classmethod 241 @classmethod
@@ -291,6 +298,7 @@ class OfflineRecognizer(object): @@ -291,6 +298,7 @@ class OfflineRecognizer(object):
291 decoding_method=decoding_method, 298 decoding_method=decoding_method,
292 ) 299 )
293 self.recognizer = _Recognizer(recognizer_config) 300 self.recognizer = _Recognizer(recognizer_config)
  301 + self.config = recognizer_config
294 return self 302 return self
295 303
296 def create_stream(self, contexts_list: Optional[List[List[int]]] = None): 304 def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
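Keeping `recognizer_config` on the instance (the repeated `self.config = recognizer_config` additions above) lets callers inspect the settings a recognizer was built with, which is handy when a server wants to log or validate its configuration. Continuing from the `recognizer` constructed in the previous sketch, end-to-end use then looks roughly like this; `decode_stream` and `stream.result.text` follow the package's usual pattern, and the silence input is a placeholder:

```python
import numpy as np

# placeholder input: 1 s of silence at 16 kHz
samples = np.zeros(16000, dtype=np.float32)
stream = recognizer.create_stream()
stream.accept_waveform(16000, samples)
recognizer.decode_stream(stream)
print(stream.result.text)  # the recognition result
print(recognizer.config)   # the stored recognizer_config from this change
```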