Committed by GitHub
Add non-streaming websocket server for python (#259)
Showing 24 changed files with 1,247 additions and 92 deletions.
| @@ -23,12 +23,12 @@ permissions: | @@ -23,12 +23,12 @@ permissions: | ||
| 23 | jobs: | 23 | jobs: |
| 24 | test_pip_install: | 24 | test_pip_install: |
| 25 | runs-on: ${{ matrix.os }} | 25 | runs-on: ${{ matrix.os }} |
| 26 | - name: Test pip install on ${{ matrix.os }} | 26 | + name: ${{ matrix.os }} ${{ matrix.python-version }} |
| 27 | strategy: | 27 | strategy: |
| 28 | fail-fast: false | 28 | fail-fast: false |
| 29 | matrix: | 29 | matrix: |
| 30 | os: [ubuntu-latest, windows-latest, macos-latest] | 30 | os: [ubuntu-latest, windows-latest, macos-latest] |
| 31 | - python-version: ["3.8", "3.9", "3.10"] | 31 | + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] |
| 32 | 32 | ||
| 33 | steps: | 33 | steps: |
| 34 | - uses: actions/checkout@v2 | 34 | - uses: actions/checkout@v2 |
| @@ -50,3 +50,15 @@ jobs: | @@ -50,3 +50,15 @@ jobs: | ||
| 50 | run: | | 50 | run: | |
| 51 | python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)" | 51 | python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)" |
| 52 | python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)" | 52 | python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)" |
| 53 | + | ||
| 54 | + sherpa-onnx --help | ||
| 55 | + sherpa-onnx-offline --help | ||
| 56 | + | ||
| 57 | + sherpa-onnx-microphone --help | ||
| 58 | + sherpa-onnx-microphone-offline --help | ||
| 59 | + | ||
| 60 | + sherpa-onnx-offline-websocket-server --help | ||
| 61 | + sherpa-onnx-offline-websocket-client --help | ||
| 62 | + | ||
| 63 | + sherpa-onnx-online-websocket-server --help | ||
| 64 | + sherpa-onnx-online-websocket-client --help |
| 1 | +name: Python offline websocket server | ||
| 2 | + | ||
| 3 | +on: | ||
| 4 | + push: | ||
| 5 | + branches: | ||
| 6 | + - master | ||
| 7 | + pull_request: | ||
| 8 | + branches: | ||
| 9 | + - master | ||
| 10 | + | ||
| 11 | +concurrency: | ||
| 12 | + group: python-offline-websocket-server-${{ github.ref }} | ||
| 13 | + cancel-in-progress: true | ||
| 14 | + | ||
| 15 | +permissions: | ||
| 16 | + contents: read | ||
| 17 | + | ||
| 18 | +jobs: | ||
| 19 | + python_offline_websocket_server: | ||
| 20 | + runs-on: ${{ matrix.os }} | ||
| 21 | + name: ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.model_type }} | ||
| 22 | + strategy: | ||
| 23 | + fail-fast: false | ||
| 24 | + matrix: | ||
| 25 | + os: [ubuntu-latest, windows-latest, macos-latest] | ||
| 26 | + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] | ||
| 27 | + model_type: ["transducer", "paraformer", "nemo_ctc", "whisper"] | ||
| 28 | + | ||
| 29 | + steps: | ||
| 30 | + - uses: actions/checkout@v2 | ||
| 31 | + with: | ||
| 32 | + fetch-depth: 0 | ||
| 33 | + | ||
| 34 | + - name: Setup Python ${{ matrix.python-version }} | ||
| 35 | + uses: actions/setup-python@v2 | ||
| 36 | + with: | ||
| 37 | + python-version: ${{ matrix.python-version }} | ||
| 38 | + | ||
| 39 | + - name: Install Python dependencies | ||
| 40 | + shell: bash | ||
| 41 | + run: | | ||
| 42 | + python3 -m pip install --upgrade pip numpy | ||
| 43 | + | ||
| 44 | + - name: Install sherpa-onnx | ||
| 45 | + shell: bash | ||
| 46 | + run: | | ||
| 47 | + python3 -m pip install --no-deps --verbose . | ||
| 48 | + python3 -m pip install websockets | ||
| 49 | + | ||
| 50 | + | ||
| 51 | + - name: Start server for transducer models | ||
| 52 | + if: matrix.model_type == 'transducer' | ||
| 53 | + shell: bash | ||
| 54 | + run: | | ||
| 55 | + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-06-26 | ||
| 56 | + cd sherpa-onnx-zipformer-en-2023-06-26 | ||
| 57 | + git lfs pull --include "*.onnx" | ||
| 58 | + cd .. | ||
| 59 | + | ||
| 60 | + python3 ./python-api-examples/non_streaming_server.py \ | ||
| 61 | + --encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \ | ||
| 62 | + --decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \ | ||
| 63 | + --joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \ | ||
| 64 | + --tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt & | ||
| 65 | + | ||
| 66 | + echo "sleep 10 seconds to wait for the server to start" | ||
| 67 | + sleep 10 | ||
| 68 | + | ||
| 69 | + - name: Start client for transducer models | ||
| 70 | + if: matrix.model_type == 'transducer' | ||
| 71 | + shell: bash | ||
| 72 | + run: | | ||
| 73 | + python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \ | ||
| 74 | + ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/0.wav \ | ||
| 75 | + ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/1.wav \ | ||
| 76 | + ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav | ||
| 77 | + | ||
| 78 | + python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \ | ||
| 79 | + ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/0.wav \ | ||
| 80 | + ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/1.wav \ | ||
| 81 | + ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav | ||
| 82 | + | ||
| 83 | + - name: Start server for paraformer models | ||
| 84 | + if: matrix.model_type == 'paraformer' | ||
| 85 | + shell: bash | ||
| 86 | + run: | | ||
| 87 | + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 | ||
| 88 | + cd sherpa-onnx-paraformer-zh-2023-03-28 | ||
| 89 | + git lfs pull --include "*.onnx" | ||
| 90 | + cd .. | ||
| 91 | + | ||
| 92 | + python3 ./python-api-examples/non_streaming_server.py \ | ||
| 93 | + --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \ | ||
| 94 | + --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt & | ||
| 95 | + | ||
| 96 | + echo "sleep 10 seconds to wait for the server to start" | ||
| 97 | + sleep 10 | ||
| 98 | + | ||
| 99 | + - name: Start client for paraformer models | ||
| 100 | + if: matrix.model_type == 'paraformer' | ||
| 101 | + shell: bash | ||
| 102 | + run: | | ||
| 103 | + python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \ | ||
| 104 | + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \ | ||
| 105 | + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \ | ||
| 106 | + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \ | ||
| 107 | + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav | ||
| 108 | + | ||
| 109 | + python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \ | ||
| 110 | + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \ | ||
| 111 | + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \ | ||
| 112 | + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \ | ||
| 113 | + ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav | ||
| 114 | + | ||
| 115 | + - name: Start server for nemo_ctc models | ||
| 116 | + if: matrix.model_type == 'nemo_ctc' | ||
| 117 | + shell: bash | ||
| 118 | + run: | | ||
| 119 | + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-medium | ||
| 120 | + cd sherpa-onnx-nemo-ctc-en-conformer-medium | ||
| 121 | + git lfs pull --include "*.onnx" | ||
| 122 | + cd .. | ||
| 123 | + | ||
| 124 | + python3 ./python-api-examples/non_streaming_server.py \ | ||
| 125 | + --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \ | ||
| 126 | + --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt & | ||
| 127 | + | ||
| 128 | + echo "sleep 10 seconds to wait for the server to start" | ||
| 129 | + sleep 10 | ||
| 130 | + | ||
| 131 | + - name: Start client for nemo_ctc models | ||
| 132 | + if: matrix.model_type == 'nemo_ctc' | ||
| 133 | + shell: bash | ||
| 134 | + run: | | ||
| 135 | + python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \ | ||
| 136 | + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \ | ||
| 137 | + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \ | ||
| 138 | + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav | ||
| 139 | + | ||
| 140 | + python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \ | ||
| 141 | + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \ | ||
| 142 | + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \ | ||
| 143 | + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav | ||
| 144 | + | ||
| 145 | + - name: Start server for whisper models | ||
| 146 | + if: matrix.model_type == 'whisper' | ||
| 147 | + shell: bash | ||
| 148 | + run: | | ||
| 149 | + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en | ||
| 150 | + cd sherpa-onnx-whisper-tiny.en | ||
| 151 | + git lfs pull --include "*.onnx" | ||
| 152 | + cd .. | ||
| 153 | + | ||
| 154 | + python3 ./python-api-examples/non_streaming_server.py \ | ||
| 155 | + --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \ | ||
| 156 | + --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ | ||
| 157 | + --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt & | ||
| 158 | + | ||
| 159 | + echo "sleep 10 seconds to wait for the server to start" | ||
| 160 | + sleep 10 | ||
| 161 | + | ||
| 162 | + - name: Start client for whisper models | ||
| 163 | + if: matrix.model_type == 'whisper' | ||
| 164 | + shell: bash | ||
| 165 | + run: | | ||
| 166 | + python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \ | ||
| 167 | + ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \ | ||
| 168 | + ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \ | ||
| 169 | + ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav | ||
| 170 | + | ||
| 171 | + python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \ | ||
| 172 | + ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \ | ||
| 173 | + ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \ | ||
| 174 | + ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav |
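Both new workflows start each server in the background and wait a fixed 10 seconds before running the clients. A minimal sketch of a more robust wait, assuming the server listens on its default port 6006 (this helper is not part of the PR):

```python
#!/usr/bin/env python3
# Hypothetical helper (not part of this PR): poll the server port instead of
# a fixed sleep. Assumes the default --port 6006 used by the servers above.
import socket
import sys
import time


def wait_for_port(host: str = "localhost", port: int = 6006, timeout: float = 30.0) -> None:
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=1.0):
                return  # the server is accepting connections
        except OSError:
            time.sleep(0.5)  # not listening yet; retry shortly
    sys.exit(f"Server did not start on {host}:{port} within {timeout} seconds")


if __name__ == "__main__":
    wait_for_port()
```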
| 1 | +name: Python online websocket server | ||
| 2 | + | ||
| 3 | +on: | ||
| 4 | + push: | ||
| 5 | + branches: | ||
| 6 | + - master | ||
| 7 | + pull_request: | ||
| 8 | + branches: | ||
| 9 | + - master | ||
| 10 | + | ||
| 11 | +concurrency: | ||
| 12 | + group: python-online-websocket-server-${{ github.ref }} | ||
| 13 | + cancel-in-progress: true | ||
| 14 | + | ||
| 15 | +permissions: | ||
| 16 | + contents: read | ||
| 17 | + | ||
| 18 | +jobs: | ||
| 19 | + python_online_websocket_server: | ||
| 20 | + runs-on: ${{ matrix.os }} | ||
| 21 | + name: ${{ matrix.os }} ${{ matrix.python-version }} | ||
| 22 | + strategy: | ||
| 23 | + fail-fast: false | ||
| 24 | + matrix: | ||
| 25 | + os: [ubuntu-latest, windows-latest, macos-latest] | ||
| 26 | + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] | ||
| 27 | + model_type: ["transducer"] | ||
| 28 | + | ||
| 29 | + steps: | ||
| 30 | + - uses: actions/checkout@v2 | ||
| 31 | + with: | ||
| 32 | + fetch-depth: 0 | ||
| 33 | + | ||
| 34 | + - name: Setup Python ${{ matrix.python-version }} | ||
| 35 | + uses: actions/setup-python@v2 | ||
| 36 | + with: | ||
| 37 | + python-version: ${{ matrix.python-version }} | ||
| 38 | + | ||
| 39 | + - name: Install Python dependencies | ||
| 40 | + shell: bash | ||
| 41 | + run: | | ||
| 42 | + python3 -m pip install --upgrade pip numpy | ||
| 43 | + | ||
| 44 | + - name: Install sherpa-onnx | ||
| 45 | + shell: bash | ||
| 46 | + run: | | ||
| 47 | + python3 -m pip install --no-deps --verbose . | ||
| 48 | + python3 -m pip install websockets | ||
| 49 | + | ||
| 50 | + | ||
| 51 | + - name: Start server for transducer models | ||
| 52 | + if: matrix.model_type == 'transducer' | ||
| 53 | + shell: bash | ||
| 54 | + run: | | ||
| 55 | + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26 | ||
| 56 | + cd sherpa-onnx-streaming-zipformer-en-2023-06-26 | ||
| 57 | + git lfs pull --include "*.onnx" | ||
| 58 | + cd .. | ||
| 59 | + | ||
| 60 | + python3 ./python-api-examples/streaming_server.py \ | ||
| 61 | + --encoder ./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-128.onnx \ | ||
| 62 | + --decoder ./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-128.onnx \ | ||
| 63 | + --joiner ./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-128.onnx \ | ||
| 64 | + --tokens ./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt & | ||
| 65 | + echo "sleep 10 seconds to wait for the server to start" | ||
| 66 | + sleep 10 | ||
| 67 | + | ||
| 68 | + - name: Start client for transducer models | ||
| 69 | + if: matrix.model_type == 'transducer' | ||
| 70 | + shell: bash | ||
| 71 | + run: | | ||
| 72 | + python3 ./python-api-examples/online-websocket-client-decode-file.py \ | ||
| 73 | + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav |
c-api-examples/README.md
0 → 100644
dotnet-examples/README.md
0 → 100644
go-api-examples/README.md
0 → 100644
python-api-examples/non_streaming_server.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2022-2023 Xiaomi Corp. | ||
| 3 | +""" | ||
| 4 | +A server for non-streaming speech recognition. Non-streaming means you send all | ||
| 5 | +the content of the audio at once for recognition. | ||
| 6 | + | ||
| 7 | +It supports multiple clients sending at the same time. | ||
| 8 | + | ||
| 9 | +Usage: | ||
| 10 | + ./non_streaming_server.py --help | ||
| 11 | + | ||
| 12 | +Please refer to | ||
| 13 | + | ||
| 14 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html | ||
| 15 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html | ||
| 16 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html | ||
| 17 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html | ||
| 18 | + | ||
| 19 | +for pre-trained models to download. | ||
| 20 | + | ||
| 21 | +Usage examples: | ||
| 22 | + | ||
| 23 | +(1) Use a non-streaming transducer model | ||
| 24 | + | ||
| 25 | +cd /path/to/sherpa-onnx | ||
| 26 | +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-06-26 | ||
| 27 | +cd sherpa-onnx-zipformer-en-2023-06-26 | ||
| 28 | +git lfs pull --include "*.onnx" | ||
| 29 | +cd .. | ||
| 30 | + | ||
| 31 | +python3 ./python-api-examples/non_streaming_server.py \ | ||
| 32 | + --encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \ | ||
| 33 | + --decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \ | ||
| 34 | + --joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \ | ||
| 35 | + --tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt | ||
| 36 | + | ||
| 37 | +(2) Use a non-streaming paraformer | ||
| 38 | + | ||
| 39 | +cd /path/to/sherpa-onnx | ||
| 40 | +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 | ||
| 41 | +cd sherpa-onnx-paraformer-zh-2023-03-28 | ||
| 42 | +git lfs pull --include "*.onnx" | ||
| 43 | +cd .. | ||
| 44 | + | ||
| 45 | +python3 ./python-api-examples/non_streaming_server.py \ | ||
| 46 | + --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \ | ||
| 47 | + --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt | ||
| 48 | + | ||
| 49 | +(3) Use a non-streaming CTC model from NeMo | ||
| 50 | + | ||
| 51 | +cd /path/to/sherpa-onnx | ||
| 52 | +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-medium | ||
| 53 | +cd sherpa-onnx-nemo-ctc-en-conformer-medium | ||
| 54 | +git lfs pull --include "*.onnx" | ||
| 55 | +cd .. | ||
| 56 | + | ||
| 57 | +python3 ./python-api-examples/non_streaming_server.py \ | ||
| 58 | + --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \ | ||
| 59 | + --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt | ||
| 60 | + | ||
| 61 | +(4) Use a Whisper model | ||
| 62 | + | ||
| 63 | +cd /path/to/sherpa-onnx | ||
| 64 | +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en | ||
| 65 | +cd sherpa-onnx-whisper-tiny.en | ||
| 66 | +git lfs pull --include "*.onnx" | ||
| 67 | +cd .. | ||
| 68 | + | ||
| 69 | +python3 ./python-api-examples/non_streaming_server.py \ | ||
| 70 | + --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \ | ||
| 71 | + --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ | ||
| 72 | + --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt | ||
| 73 | + | ||
| 74 | +---- | ||
| 75 | + | ||
| 76 | +To use a certificate so that you can use https, please use | ||
| 77 | + | ||
| 78 | +python3 ./python-api-examples/non_streaming_server.py \ | ||
| 79 | + --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \ | ||
| 80 | + --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ | ||
| 81 | + --certificate=/path/to/your/cert.pem | ||
| 82 | + | ||
| 83 | +If you don't have a certificate, please run: | ||
| 84 | + | ||
| 85 | + cd ./python-api-examples/web | ||
| 86 | + ./generate-certificate.py | ||
| 87 | + | ||
| 88 | +It will generate 3 files, one of which is the required `cert.pem`. | ||
| 89 | +""" # noqa | ||
| 90 | + | ||
| 91 | +import argparse | ||
| 92 | +import asyncio | ||
| 93 | +import http | ||
| 94 | +import logging | ||
| 95 | +import socket | ||
| 96 | +import ssl | ||
| 97 | +import sys | ||
| 98 | +import warnings | ||
| 99 | +from concurrent.futures import ThreadPoolExecutor | ||
| 100 | +from datetime import datetime | ||
| 101 | +from pathlib import Path | ||
| 102 | +from typing import Optional, Tuple | ||
| 103 | + | ||
| 104 | +import numpy as np | ||
| 105 | +import sherpa_onnx | ||
| 106 | + | ||
| 107 | +import websockets | ||
| 108 | + | ||
| 109 | +from http_server import HttpServer | ||
| 110 | + | ||
| 111 | + | ||
| 112 | +def setup_logger( | ||
| 113 | + log_filename: str, | ||
| 114 | + log_level: str = "info", | ||
| 115 | + use_console: bool = True, | ||
| 116 | +) -> None: | ||
| 117 | + """Setup log level. | ||
| 118 | + | ||
| 119 | + Args: | ||
| 120 | + log_filename: | ||
| 121 | + The filename to save the log. | ||
| 122 | + log_level: | ||
| 123 | + The log level to use, e.g., "debug", "info", "warning", "error", | ||
| 124 | + "critical" | ||
| 125 | + use_console: | ||
| 126 | + True to also print logs to console. | ||
| 127 | + """ | ||
| 128 | + now = datetime.now() | ||
| 129 | + date_time = now.strftime("%Y-%m-%d-%H-%M-%S") | ||
| 130 | + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" | ||
| 131 | + log_filename = f"{log_filename}-{date_time}.txt" | ||
| 132 | + | ||
| 133 | + Path(log_filename).parent.mkdir(parents=True, exist_ok=True) | ||
| 134 | + | ||
| 135 | + level = logging.ERROR | ||
| 136 | + if log_level == "debug": | ||
| 137 | + level = logging.DEBUG | ||
| 138 | + elif log_level == "info": | ||
| 139 | + level = logging.INFO | ||
| 140 | + elif log_level == "warning": | ||
| 141 | + level = logging.WARNING | ||
| 142 | + elif log_level == "critical": | ||
| 143 | + level = logging.CRITICAL | ||
| 144 | + | ||
| 145 | + logging.basicConfig( | ||
| 146 | + filename=log_filename, | ||
| 147 | + format=formatter, | ||
| 148 | + level=level, | ||
| 149 | + filemode="w", | ||
| 150 | + ) | ||
| 151 | + if use_console: | ||
| 152 | + console = logging.StreamHandler() | ||
| 153 | + console.setLevel(level) | ||
| 154 | + console.setFormatter(logging.Formatter(formatter)) | ||
| 155 | + logging.getLogger("").addHandler(console) | ||
| 156 | + | ||
| 157 | + | ||
| 158 | +def add_transducer_model_args(parser: argparse.ArgumentParser): | ||
| 159 | + parser.add_argument( | ||
| 160 | + "--encoder", | ||
| 161 | + default="", | ||
| 162 | + type=str, | ||
| 163 | + help="Path to the transducer encoder model", | ||
| 164 | + ) | ||
| 165 | + | ||
| 166 | + parser.add_argument( | ||
| 167 | + "--decoder", | ||
| 168 | + default="", | ||
| 169 | + type=str, | ||
| 170 | + help="Path to the transducer decoder model", | ||
| 171 | + ) | ||
| 172 | + | ||
| 173 | + parser.add_argument( | ||
| 174 | + "--joiner", | ||
| 175 | + default="", | ||
| 176 | + type=str, | ||
| 177 | + help="Path to the transducer joiner model", | ||
| 178 | + ) | ||
| 179 | + | ||
| 180 | + | ||
| 181 | +def add_paraformer_model_args(parser: argparse.ArgumentParser): | ||
| 182 | + parser.add_argument( | ||
| 183 | + "--paraformer", | ||
| 184 | + default="", | ||
| 185 | + type=str, | ||
| 186 | + help="Path to the model.onnx from Paraformer", | ||
| 187 | + ) | ||
| 188 | + | ||
| 189 | + | ||
| 190 | +def add_nemo_ctc_model_args(parser: argparse.ArgumentParser): | ||
| 191 | + parser.add_argument( | ||
| 192 | + "--nemo-ctc", | ||
| 193 | + default="", | ||
| 194 | + type=str, | ||
| 195 | + help="Path to the model.onnx from NeMo CTC", | ||
| 196 | + ) | ||
| 197 | + | ||
| 198 | + | ||
| 199 | +def add_whisper_model_args(parser: argparse.ArgumentParser): | ||
| 200 | + parser.add_argument( | ||
| 201 | + "--whisper-encoder", | ||
| 202 | + default="", | ||
| 203 | + type=str, | ||
| 204 | + help="Path to whisper encoder model", | ||
| 205 | + ) | ||
| 206 | + | ||
| 207 | + parser.add_argument( | ||
| 208 | + "--whisper-decoder", | ||
| 209 | + default="", | ||
| 210 | + type=str, | ||
| 211 | + help="Path to whisper decoder model", | ||
| 212 | + ) | ||
| 213 | + | ||
| 214 | + | ||
| 215 | +def add_model_args(parser: argparse.ArgumentParser): | ||
| 216 | + add_transducer_model_args(parser) | ||
| 217 | + add_paraformer_model_args(parser) | ||
| 218 | + add_nemo_ctc_model_args(parser) | ||
| 219 | + add_whisper_model_args(parser) | ||
| 220 | + | ||
| 221 | + parser.add_argument( | ||
| 222 | + "--tokens", | ||
| 223 | + type=str, | ||
| 224 | + help="Path to tokens.txt", | ||
| 225 | + ) | ||
| 226 | + | ||
| 227 | + parser.add_argument( | ||
| 228 | + "--num-threads", | ||
| 229 | + type=int, | ||
| 230 | + default=2, | ||
| 231 | + help="Number of threads to run the neural network model", | ||
| 232 | + ) | ||
| 233 | + | ||
| 234 | + parser.add_argument( | ||
| 235 | + "--provider", | ||
| 236 | + type=str, | ||
| 237 | + default="cpu", | ||
| 238 | + help="Valid values: cpu, cuda, coreml", | ||
| 239 | + ) | ||
| 240 | + | ||
| 241 | + | ||
| 242 | +def add_feature_config_args(parser: argparse.ArgumentParser): | ||
| 243 | + parser.add_argument( | ||
| 244 | + "--sample-rate", | ||
| 245 | + type=int, | ||
| 246 | + default=16000, | ||
| 247 | + help="Sample rate of the data used to train the model. ", | ||
| 248 | + ) | ||
| 249 | + | ||
| 250 | + parser.add_argument( | ||
| 251 | + "--feat-dim", | ||
| 252 | + type=int, | ||
| 253 | + default=80, | ||
| 254 | + help="Feature dimension of the model", | ||
| 255 | + ) | ||
| 256 | + | ||
| 257 | + | ||
| 258 | +def add_decoding_args(parser: argparse.ArgumentParser): | ||
| 259 | + parser.add_argument( | ||
| 260 | + "--decoding-method", | ||
| 261 | + type=str, | ||
| 262 | + default="greedy_search", | ||
| 263 | + help="""Decoding method to use. Current supported methods are: | ||
| 264 | + - greedy_search | ||
| 265 | + - modified_beam_search (for transducer models only) | ||
| 266 | + """, | ||
| 267 | + ) | ||
| 268 | + | ||
| 269 | + add_modified_beam_search_args(parser) | ||
| 270 | + | ||
| 271 | + | ||
| 272 | +def add_modified_beam_search_args(parser: argparse.ArgumentParser): | ||
| 273 | + parser.add_argument( | ||
| 274 | + "--max-active-paths", | ||
| 275 | + type=int, | ||
| 276 | + default=4, | ||
| 277 | + help="""Used only when --decoding-method is modified_beam_search. | ||
| 278 | + It specifies number of active paths to keep during decoding. | ||
| 279 | + """, | ||
| 280 | + ) | ||
| 281 | + | ||
| 282 | + | ||
| 283 | +def check_args(args): | ||
| 284 | + if not Path(args.tokens).is_file(): | ||
| 285 | + raise ValueError(f"{args.tokens} does not exist") | ||
| 286 | + | ||
| 287 | + if args.decoding_method not in ( | ||
| 288 | + "greedy_search", | ||
| 289 | + "modified_beam_search", | ||
| 290 | + ): | ||
| 291 | + raise ValueError(f"Unsupported decoding method {args.decoding_method}") | ||
| 292 | + | ||
| 293 | + if args.decoding_method == "modified_beam_search": | ||
| 294 | + assert args.max_active_paths > 0, args.max_active_paths | ||
| 295 | + assert Path(args.encoder).is_file(), args.encoder | ||
| 296 | + assert Path(args.decoder).is_file(), args.decoder | ||
| 297 | + assert Path(args.joiner).is_file(), args.joiner | ||
| 298 | + | ||
| 299 | + | ||
| 300 | +def get_args(): | ||
| 301 | + parser = argparse.ArgumentParser( | ||
| 302 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 303 | + ) | ||
| 304 | + | ||
| 305 | + add_model_args(parser) | ||
| 306 | + add_feature_config_args(parser) | ||
| 307 | + add_decoding_args(parser) | ||
| 308 | + | ||
| 309 | + parser.add_argument( | ||
| 310 | + "--port", | ||
| 311 | + type=int, | ||
| 312 | + default=6006, | ||
| 313 | + help="The server will listen on this port", | ||
| 314 | + ) | ||
| 315 | + | ||
| 316 | + parser.add_argument( | ||
| 317 | + "--max-batch-size", | ||
| 318 | + type=int, | ||
| 319 | + default=25, | ||
| 320 | + help="""Max batch size for computation. Note if there are not enough | ||
| 321 | + requests in the queue, it will wait for max_wait_ms time. After that, | ||
| 322 | + even if there are not enough requests, it still sends the | ||
| 323 | + available requests in the queue for computation. | ||
| 324 | + """, | ||
| 325 | + ) | ||
| 326 | + | ||
| 327 | + parser.add_argument( | ||
| 328 | + "--max-wait-ms", | ||
| 329 | + type=float, | ||
| 330 | + default=5, | ||
| 331 | + help="""Max time in millisecond to wait to build batches for inference. | ||
| 332 | + If there are not enough requests in the feature queue to build a batch | ||
| 333 | + of max_batch_size, it waits up to this time before fetching available | ||
| 334 | + requests for computation. | ||
| 335 | + """, | ||
| 336 | + ) | ||
| 337 | + | ||
| 338 | + parser.add_argument( | ||
| 339 | + "--nn-pool-size", | ||
| 340 | + type=int, | ||
| 341 | + default=1, | ||
| 342 | + help="Number of threads for NN computation and decoding.", | ||
| 343 | + ) | ||
| 344 | + | ||
| 345 | + parser.add_argument( | ||
| 346 | + "--max-message-size", | ||
| 347 | + type=int, | ||
| 348 | + default=(1 << 20), | ||
| 349 | + help="""Max message size in bytes. | ||
| 350 | + The max size per message cannot exceed this limit. | ||
| 351 | + """, | ||
| 352 | + ) | ||
| 353 | + | ||
| 354 | + parser.add_argument( | ||
| 355 | + "--max-queue-size", | ||
| 356 | + type=int, | ||
| 357 | + default=32, | ||
| 358 | + help="Max number of messages in the queue for each connection.", | ||
| 359 | + ) | ||
| 360 | + | ||
| 361 | + parser.add_argument( | ||
| 362 | + "--max-active-connections", | ||
| 363 | + type=int, | ||
| 364 | + default=500, | ||
| 365 | + help="""Maximum number of active connections. The server will refuse | ||
| 366 | + to accept new connections once the current number of active connections | ||
| 367 | + equals this limit. | ||
| 368 | + """, | ||
| 369 | + ) | ||
| 370 | + | ||
| 371 | + parser.add_argument( | ||
| 372 | + "--certificate", | ||
| 373 | + type=str, | ||
| 374 | + help="""Path to the X.509 certificate. You need it only if you want to | ||
| 375 | + use a secure websocket connection, i.e., use wss:// instead of ws://. | ||
| 376 | + You can use ./web/generate-certificate.py | ||
| 377 | + to generate the certificate `cert.pem`. | ||
| 378 | + Note ./web/generate-certificate.py will generate three files but you | ||
| 379 | + only need to pass the generated cert.pem to this option. | ||
| 380 | + """, | ||
| 381 | + ) | ||
| 382 | + | ||
| 383 | + parser.add_argument( | ||
| 384 | + "--doc-root", | ||
| 385 | + type=str, | ||
| 386 | + default="./python-api-examples/web", | ||
| 387 | + help="Path to the web root", | ||
| 388 | + ) | ||
| 389 | + | ||
| 390 | + return parser.parse_args() | ||
| 391 | + | ||
| 392 | + | ||
| 393 | +class NonStreamingServer: | ||
| 394 | + def __init__( | ||
| 395 | + self, | ||
| 396 | + recognizer: sherpa_onnx.OfflineRecognizer, | ||
| 397 | + max_batch_size: int, | ||
| 398 | + max_wait_ms: float, | ||
| 399 | + nn_pool_size: int, | ||
| 400 | + max_message_size: int, | ||
| 401 | + max_queue_size: int, | ||
| 402 | + max_active_connections: int, | ||
| 403 | + doc_root: str, | ||
| 404 | + certificate: Optional[str] = None, | ||
| 405 | + ): | ||
| 406 | + """ | ||
| 407 | + Args: | ||
| 408 | + recognizer: | ||
| 409 | + An instance of the sherpa_onnx.OfflineRecognizer. | ||
| 410 | + max_batch_size: | ||
| 411 | + Max batch size for inference. | ||
| 412 | + max_wait_ms: | ||
| 413 | + Max wait time in milliseconds in order to build a batch of | ||
| 414 | + `max_batch_size`. | ||
| 415 | + nn_pool_size: | ||
| 416 | + Number of threads for the thread pool that is used for NN | ||
| 417 | + computation and decoding. | ||
| 418 | + max_message_size: | ||
| 419 | + Max size in bytes per message. | ||
| 420 | + max_queue_size: | ||
| 421 | + Max number of messages in the queue for each connection. | ||
| 422 | + max_active_connections: | ||
| 423 | + Max number of active connections. Once the number of active clients | ||
| 424 | + equals this limit, the server refuses to accept new connections. | ||
| 425 | + doc_root: | ||
| 426 | + Path to the directory where files like index.html for the HTTP | ||
| 427 | + server are located. | ||
| 428 | + certificate: | ||
| 429 | + Optional. If not None, it will use secure websocket. | ||
| 430 | + You can use ./web/generate-certificate.py to generate | ||
| 431 | + it (the default generated filename is `cert.pem`). | ||
| 432 | + """ | ||
| 433 | + self.recognizer = recognizer | ||
| 434 | + | ||
| 435 | + self.certificate = certificate | ||
| 436 | + self.http_server = HttpServer(doc_root) | ||
| 437 | + | ||
| 438 | + self.nn_pool = ThreadPoolExecutor( | ||
| 439 | + max_workers=nn_pool_size, | ||
| 440 | + thread_name_prefix="nn", | ||
| 441 | + ) | ||
| 442 | + | ||
| 443 | + self.stream_queue = asyncio.Queue() | ||
| 444 | + | ||
| 445 | + self.max_wait_ms = max_wait_ms | ||
| 446 | + self.max_batch_size = max_batch_size | ||
| 447 | + self.max_message_size = max_message_size | ||
| 448 | + self.max_queue_size = max_queue_size | ||
| 449 | + self.max_active_connections = max_active_connections | ||
| 450 | + | ||
| 451 | + self.current_active_connections = 0 | ||
| 452 | + self.sample_rate = int(recognizer.config.feat_config.sampling_rate) | ||
| 453 | + | ||
| 454 | + async def process_request( | ||
| 455 | + self, | ||
| 456 | + path: str, | ||
| 457 | + request_headers: websockets.Headers, | ||
| 458 | + ) -> Optional[Tuple[http.HTTPStatus, websockets.Headers, bytes]]: | ||
| 459 | + if "sec-websocket-key" not in request_headers: | ||
| 460 | + # This is a normal HTTP request | ||
| 461 | + if path == "/": | ||
| 462 | + path = "/index.html" | ||
| 463 | + if path[-1] == "?": | ||
| 464 | + path = path[:-1] | ||
| 465 | + | ||
| 466 | + if path == "/streaming_record.html": | ||
| 467 | + response = r""" | ||
| 468 | +<!doctype html><html><head> | ||
| 469 | +<title>Speech recognition with next-gen Kaldi</title><body> | ||
| 470 | +<h2>Only | ||
| 471 | +<a href="/upload.html">/upload.html</a> | ||
| 472 | +and | ||
| 473 | +<a href="/offline_record.html">/offline_record.html</a> | ||
| 474 | +are available for the non-streaming server.</h2> | ||
| 475 | +<br/> | ||
| 476 | +<br/> | ||
| 477 | +Go back to <a href="/upload.html">/upload.html</a> | ||
| 478 | +or <a href="/offline_record.html">/offline_record.html</a> | ||
| 479 | +</body></head></html> | ||
| 480 | +""" | ||
| 481 | + found = True | ||
| 482 | + mime_type = "text/html" | ||
| 483 | + else: | ||
| 484 | + found, response, mime_type = self.http_server.process_request(path) | ||
| 485 | + if isinstance(response, str): | ||
| 486 | + response = response.encode("utf-8") | ||
| 487 | + | ||
| 488 | + if not found: | ||
| 489 | + status = http.HTTPStatus.NOT_FOUND | ||
| 490 | + else: | ||
| 491 | + status = http.HTTPStatus.OK | ||
| 492 | + header = {"Content-Type": mime_type} | ||
| 493 | + return status, header, response | ||
| 494 | + | ||
| 495 | + if self.current_active_connections < self.max_active_connections: | ||
| 496 | + self.current_active_connections += 1 | ||
| 497 | + return None | ||
| 498 | + | ||
| 499 | + # Refuse new connections | ||
| 500 | + status = http.HTTPStatus.SERVICE_UNAVAILABLE # 503 | ||
| 501 | + header = {"Hint": "The server is overloaded. Please retry later."} | ||
| 502 | + response = b"The server is busy. Please retry later." | ||
| 503 | + | ||
| 504 | + return status, header, response | ||
| 505 | + | ||
| 506 | + async def run(self, port: int): | ||
| 507 | + logging.info("started") | ||
| 508 | + | ||
| 509 | + task = asyncio.create_task(self.stream_consumer_task()) | ||
| 510 | + | ||
| 511 | + if self.certificate: | ||
| 512 | + logging.info(f"Using certificate: {self.certificate}") | ||
| 513 | + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) | ||
| 514 | + ssl_context.load_cert_chain(self.certificate) | ||
| 515 | + else: | ||
| 516 | + ssl_context = None | ||
| 517 | + logging.info("No certificate provided") | ||
| 518 | + | ||
| 519 | + async with websockets.serve( | ||
| 520 | + self.handle_connection, | ||
| 521 | + host="", | ||
| 522 | + port=port, | ||
| 523 | + max_size=self.max_message_size, | ||
| 524 | + max_queue=self.max_queue_size, | ||
| 525 | + process_request=self.process_request, | ||
| 526 | + ssl=ssl_context, | ||
| 527 | + ): | ||
| 528 | + ip_list = ["localhost"] | ||
| 529 | + if ssl_context: | ||
| 530 | + ip_list += ["0.0.0.0", "127.0.0.1"] | ||
| 531 | + ip_list.append(socket.gethostbyname(socket.gethostname())) | ||
| 532 | + | ||
| 533 | + proto = "http://" if ssl_context is None else "https://" | ||
| 534 | + s = "Please visit one of the following addresses:\n\n" | ||
| 535 | + for p in ip_list: | ||
| 536 | + s += " " + proto + p + f":{port}" "\n" | ||
| 537 | + logging.info(s) | ||
| 538 | + | ||
| 539 | + await asyncio.Future() # run forever | ||
| 540 | + | ||
| 541 | + await task # not reachable | ||
| 542 | + | ||
| 543 | + async def recv_audio_samples( | ||
| 544 | + self, | ||
| 545 | + socket: websockets.WebSocketServerProtocol, | ||
| 546 | + ) -> Tuple[Optional[np.ndarray], Optional[float]]: | ||
| 547 | + """Receive a tensor from the client. | ||
| 548 | + | ||
| 549 | + The message from the client is a **bytes** buffer. | ||
| 550 | + | ||
| 551 | + The first message can be either the string "Done", meaning the client | ||
| 552 | + won't send anything further, or a buffer containing at least 8 bytes. | ||
| 553 | + The first 4 bytes in little endian specify the sample | ||
| 554 | + rate of the audio samples; the next 4 bytes in little endian specify | ||
| 555 | + the number of bytes of audio data, which the client sends in this and | ||
| 556 | + subsequent messages. | ||
| 557 | + Since there is a limit on the message size imposed by the websocket | ||
| 558 | + protocol, the client may send the audio file in multiple messages if the | ||
| 559 | + audio file is very large. | ||
| 560 | + | ||
| 561 | + The second and remaining messages contain audio samples. | ||
| 562 | + | ||
| 563 | + Please refer to ./offline-websocket-client-decode-files-paralell.py | ||
| 564 | + and ./offline-websocket-client-decode-files-sequential.py | ||
| 565 | + for how the client sends the message. | ||
| 566 | + | ||
| 567 | + Args: | ||
| 568 | + socket: | ||
| 569 | + The socket for communicating with the client. | ||
| 570 | + Returns: | ||
| 571 | + Return a tuple containing: | ||
| 572 | + - 1-D np.float32 array containing the audio samples | ||
| 573 | + - sample rate of the audio samples | ||
| 574 | + or return (None, None) indicating the end of utterance. | ||
| 575 | + """ | ||
| 576 | + header = await socket.recv() | ||
| 577 | + if header == "Done": | ||
| 578 | + return None, None | ||
| 579 | + | ||
| 580 | + assert len(header) >= 8, ( | ||
| 581 | + "The first message should contain at least 8 bytes." | ||
| 582 | + + f"Given {len(header)}" | ||
| 583 | + ) | ||
| 584 | + | ||
| 585 | + sample_rate = int.from_bytes(header[:4], "little", signed=True) | ||
| 586 | + expected_num_bytes = int.from_bytes(header[4:8], "little", signed=True) | ||
| 587 | + | ||
| 588 | + received = [] | ||
| 589 | + num_received_bytes = 0 | ||
| 590 | + if len(header) > 8: | ||
| 591 | + received.append(header[8:]) | ||
| 592 | + num_received_bytes += len(header) - 8 | ||
| 593 | + | ||
| 594 | + if num_received_bytes < expected_num_bytes: | ||
| 595 | + async for message in socket: | ||
| 596 | + received.append(message) | ||
| 597 | + num_received_bytes += len(message) | ||
| 598 | + if num_received_bytes >= expected_num_bytes: | ||
| 599 | + break | ||
| 600 | + | ||
| 601 | + assert num_received_bytes == expected_num_bytes, ( | ||
| 602 | + num_received_bytes, | ||
| 603 | + expected_num_bytes, | ||
| 604 | + ) | ||
| 605 | + | ||
| 606 | + samples = b"".join(received) | ||
| 607 | + array = np.frombuffer(samples, dtype=np.float32) | ||
| 608 | + return array, sample_rate | ||
| 609 | + | ||
| 610 | + async def stream_consumer_task(self): | ||
| 611 | + """This function extracts streams from the queue, batches them up, sends | ||
| 612 | + them to the neural network model for computation and decoding. | ||
| 613 | + """ | ||
| 614 | + while True: | ||
| 615 | + if self.stream_queue.empty(): | ||
| 616 | + await asyncio.sleep(self.max_wait_ms / 1000) | ||
| 617 | + continue | ||
| 618 | + | ||
| 619 | + batch = [] | ||
| 620 | + try: | ||
| 621 | + while len(batch) < self.max_batch_size: | ||
| 622 | + item = self.stream_queue.get_nowait() | ||
| 623 | + | ||
| 624 | + batch.append(item) | ||
| 625 | + except asyncio.QueueEmpty: | ||
| 626 | + pass | ||
| 627 | + stream_list = [b[0] for b in batch] | ||
| 628 | + future_list = [b[1] for b in batch] | ||
| 629 | + | ||
| 630 | + loop = asyncio.get_running_loop() | ||
| 631 | + await loop.run_in_executor( | ||
| 632 | + self.nn_pool, | ||
| 633 | + self.recognizer.decode_streams, | ||
| 634 | + stream_list, | ||
| 635 | + ) | ||
| 636 | + | ||
| 637 | + for f in future_list: | ||
| 638 | + self.stream_queue.task_done() | ||
| 639 | + f.set_result(None) | ||
| 640 | + | ||
| 641 | + async def compute_and_decode( | ||
| 642 | + self, | ||
| 643 | + stream: sherpa_onnx.OfflineStream, | ||
| 644 | + ) -> None: | ||
| 645 | + """Put the stream into the queue and wait it to be processed by the | ||
| 646 | + consumer task. | ||
| 647 | + | ||
| 648 | + Args: | ||
| 649 | + stream: | ||
| 650 | + The stream to be processed. Note: It is changed in-place. | ||
| 651 | + """ | ||
| 652 | + loop = asyncio.get_running_loop() | ||
| 653 | + future = loop.create_future() | ||
| 654 | + await self.stream_queue.put((stream, future)) | ||
| 655 | + await future | ||
| 656 | + | ||
| 657 | + async def handle_connection( | ||
| 658 | + self, | ||
| 659 | + socket: websockets.WebSocketServerProtocol, | ||
| 660 | + ): | ||
| 661 | + """Receive audio samples from the client, process it, and sends | ||
| 662 | + deocoding result back to the client. | ||
| 663 | + | ||
| 664 | + Args: | ||
| 665 | + socket: | ||
| 666 | + The socket for communicating with the client. | ||
| 667 | + """ | ||
| 668 | + try: | ||
| 669 | + await self.handle_connection_impl(socket) | ||
| 670 | + except websockets.exceptions.ConnectionClosedError: | ||
| 671 | + logging.info(f"{socket.remote_address} disconnected") | ||
| 672 | + finally: | ||
| 673 | + # Decrement so that it can accept new connections | ||
| 674 | + self.current_active_connections -= 1 | ||
| 675 | + | ||
| 676 | + logging.info( | ||
| 677 | + f"Disconnected: {socket.remote_address}. " | ||
| 678 | + f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa | ||
| 679 | + ) | ||
| 680 | + | ||
| 681 | + async def handle_connection_impl( | ||
| 682 | + self, | ||
| 683 | + socket: websockets.WebSocketServerProtocol, | ||
| 684 | + ): | ||
| 685 | + """Receive audio samples from the client, process it, and send | ||
| 686 | + decoding results back to the client. | ||
| 687 | + | ||
| 688 | + Args: | ||
| 689 | + socket: | ||
| 690 | + The socket for communicating with the client. | ||
| 691 | + """ | ||
| 692 | + logging.info( | ||
| 693 | + f"Connected: {socket.remote_address}. " | ||
| 694 | + f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa | ||
| 695 | + ) | ||
| 696 | + | ||
| 697 | + while True: | ||
| 698 | + stream = self.recognizer.create_stream() | ||
| 699 | + samples, sample_rate = await self.recv_audio_samples(socket) | ||
| 700 | + if samples is None: | ||
| 701 | + break | ||
| 702 | + # stream.accept_samples() runs in the main thread | ||
| 703 | + | ||
| 704 | + stream.accept_waveform(sample_rate, samples) | ||
| 705 | + | ||
| 706 | + await self.compute_and_decode(stream) | ||
| 707 | + result = stream.result.text | ||
| 708 | + logging.info(f"result: {result}") | ||
| 709 | + | ||
| 710 | + if result: | ||
| 711 | + await socket.send(result) | ||
| 712 | + else: | ||
| 713 | + # If result is an empty string, still send something to the client; | ||
| 714 | + # otherwise the client receives no reply and will | ||
| 715 | + # wait for a response indefinitely. | ||
| 716 | + await socket.send("<EMPTY>") | ||
| 717 | + | ||
| 718 | + | ||
| 719 | +def assert_file_exists(filename: str): | ||
| 720 | + assert Path(filename).is_file(), ( | ||
| 721 | + f"{filename} does not exist!\n" | ||
| 722 | + "Please refer to " | ||
| 723 | + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" | ||
| 724 | + ) | ||
| 725 | + | ||
| 726 | + | ||
| 727 | +def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 728 | + if args.encoder: | ||
| 729 | + assert len(args.paraformer) == 0, args.paraformer | ||
| 730 | + assert len(args.nemo_ctc) == 0, args.nemo_ctc | ||
| 731 | + assert len(args.whisper_encoder) == 0, args.whisper_encoder | ||
| 732 | + assert len(args.whisper_decoder) == 0, args.whisper_decoder | ||
| 733 | + | ||
| 734 | + assert_file_exists(args.encoder) | ||
| 735 | + assert_file_exists(args.decoder) | ||
| 736 | + assert_file_exists(args.joiner) | ||
| 737 | + | ||
| 738 | + recognizer = sherpa_onnx.OfflineRecognizer.from_transducer( | ||
| 739 | + encoder=args.encoder, | ||
| 740 | + decoder=args.decoder, | ||
| 741 | + joiner=args.joiner, | ||
| 742 | + tokens=args.tokens, | ||
| 743 | + num_threads=args.num_threads, | ||
| 744 | + sample_rate=args.sample_rate, | ||
| 745 | + feature_dim=args.feat_dim, | ||
| 746 | + decoding_method=args.decoding_method, | ||
| 747 | + max_active_paths=args.max_active_paths, | ||
| 748 | + ) | ||
| 749 | + elif args.paraformer: | ||
| 750 | + assert len(args.nemo_ctc) == 0, args.nemo_ctc | ||
| 751 | + assert len(args.whisper_encoder) == 0, args.whisper_encoder | ||
| 752 | + assert len(args.whisper_decoder) == 0, args.whisper_decoder | ||
| 753 | + | ||
| 754 | + assert_file_exists(args.paraformer) | ||
| 755 | + | ||
| 756 | + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( | ||
| 757 | + paraformer=args.paraformer, | ||
| 758 | + tokens=args.tokens, | ||
| 759 | + num_threads=args.num_threads, | ||
| 760 | + sample_rate=args.sample_rate, | ||
| 761 | + feature_dim=args.feat_dim, | ||
| 762 | + decoding_method=args.decoding_method, | ||
| 763 | + ) | ||
| 764 | + elif args.nemo_ctc: | ||
| 765 | + assert len(args.whisper_encoder) == 0, args.whisper_encoder | ||
| 766 | + assert len(args.whisper_decoder) == 0, args.whisper_decoder | ||
| 767 | + | ||
| 768 | + assert_file_exists(args.nemo_ctc) | ||
| 769 | + | ||
| 770 | + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc( | ||
| 771 | + model=args.nemo_ctc, | ||
| 772 | + tokens=args.tokens, | ||
| 773 | + num_threads=args.num_threads, | ||
| 774 | + sample_rate=args.sample_rate, | ||
| 775 | + feature_dim=args.feat_dim, | ||
| 776 | + decoding_method=args.decoding_method, | ||
| 777 | + ) | ||
| 778 | + elif args.whisper_encoder: | ||
| 779 | + assert_file_exists(args.whisper_encoder) | ||
| 780 | + assert_file_exists(args.whisper_decoder) | ||
| 781 | + | ||
| 782 | + recognizer = sherpa_onnx.OfflineRecognizer.from_whisper( | ||
| 783 | + encoder=args.whisper_encoder, | ||
| 784 | + decoder=args.whisper_decoder, | ||
| 785 | + tokens=args.tokens, | ||
| 786 | + num_threads=args.num_threads, | ||
| 787 | + decoding_method=args.decoding_method, | ||
| 788 | + ) | ||
| 789 | + else: | ||
| 790 | + raise ValueError("Please specify at least one model") | ||
| 791 | + | ||
| 792 | + return recognizer | ||
| 793 | + | ||
| 794 | + | ||
| 795 | +def main(): | ||
| 796 | + args = get_args() | ||
| 797 | + logging.info(vars(args)) | ||
| 798 | + check_args(args) | ||
| 799 | + | ||
| 800 | + recognizer = create_recognizer(args) | ||
| 801 | + | ||
| 802 | + port = args.port | ||
| 803 | + max_wait_ms = args.max_wait_ms | ||
| 804 | + max_batch_size = args.max_batch_size | ||
| 805 | + nn_pool_size = args.nn_pool_size | ||
| 806 | + max_message_size = args.max_message_size | ||
| 807 | + max_queue_size = args.max_queue_size | ||
| 808 | + max_active_connections = args.max_active_connections | ||
| 809 | + certificate = args.certificate | ||
| 810 | + doc_root = args.doc_root | ||
| 811 | + | ||
| 812 | + if certificate and not Path(certificate).is_file(): | ||
| 813 | + raise ValueError(f"{certificate} does not exist") | ||
| 814 | + | ||
| 815 | + if not Path(doc_root).is_dir(): | ||
| 816 | + raise ValueError(f"Directory {doc_root} does not exist") | ||
| 817 | + | ||
| 818 | + non_streaming_server = NonStreamingServer( | ||
| 819 | + recognizer=recognizer, | ||
| 820 | + max_wait_ms=max_wait_ms, | ||
| 821 | + max_batch_size=max_batch_size, | ||
| 822 | + nn_pool_size=nn_pool_size, | ||
| 823 | + max_message_size=max_message_size, | ||
| 824 | + max_queue_size=max_queue_size, | ||
| 825 | + max_active_connections=max_active_connections, | ||
| 826 | + certificate=certificate, | ||
| 827 | + doc_root=doc_root, | ||
| 828 | + ) | ||
| 829 | + asyncio.run(non_streaming_server.run(port)) | ||
| 830 | + | ||
| 831 | + | ||
| 832 | +if __name__ == "__main__": | ||
| 833 | + log_filename = "log/log-non-streaming-server" | ||
| 834 | + setup_logger(log_filename) | ||
| 835 | + main() |
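The docstring of `recv_audio_samples()` above defines the wire format: the first message carries a 4-byte little-endian sample rate followed by a 4-byte little-endian payload size, the audio follows as float32 bytes, and the string "Done" ends the session. A minimal sketch of a client speaking this protocol, assuming the server runs at ws://localhost:6006 (the shipped offline websocket clients are the reference implementations):

```python
#!/usr/bin/env python3
# Minimal sketch of the wire protocol documented in recv_audio_samples().
# Assumes the server above is running at ws://localhost:6006; the shipped
# offline-websocket-client-decode-files-*.py scripts are the reference clients.
import asyncio

import numpy as np
import websockets


async def main() -> None:
    sample_rate = 16000
    # One second of silence as a stand-in for real audio samples.
    samples = np.zeros(sample_rate, dtype=np.float32)

    async with websockets.connect("ws://localhost:6006") as ws:
        # Header: 4-byte little-endian sample rate, then 4-byte little-endian
        # payload size in bytes; the float32 samples follow immediately.
        buf = sample_rate.to_bytes(4, byteorder="little")
        buf += (samples.size * 4).to_bytes(4, byteorder="little")
        buf += samples.tobytes()
        await ws.send(buf)

        print(await ws.recv())  # recognition result, or "<EMPTY>"
        await ws.send("Done")  # tell the server no more audio is coming


asyncio.run(main())
```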
| @@ -119,7 +119,13 @@ async def run( | @@ -119,7 +119,13 @@ async def run( | ||
| 119 | buf += (samples.size * 4).to_bytes(4, byteorder="little") | 119 | buf += (samples.size * 4).to_bytes(4, byteorder="little") |
| 120 | buf += samples.tobytes() | 120 | buf += samples.tobytes() |
| 121 | 121 | ||
| 122 | - await websocket.send(buf) | 122 | + payload_len = 10240 |
| 123 | + while len(buf) > payload_len: | ||
| 124 | + await websocket.send(buf[:payload_len]) | ||
| 125 | + buf = buf[payload_len:] | ||
| 126 | + | ||
| 127 | + if buf: | ||
| 128 | + await websocket.send(buf) | ||
| 123 | 129 | ||
| 124 | decoding_results = await websocket.recv() | 130 | decoding_results = await websocket.recv() |
| 125 | logging.info(f"{wave_filename}\n{decoding_results}") | 131 | logging.info(f"{wave_filename}\n{decoding_results}") |
| @@ -116,11 +116,18 @@ async def run( | @@ -116,11 +116,18 @@ async def run( | ||
| 116 | assert isinstance(sample_rate, int) | 116 | assert isinstance(sample_rate, int) |
| 117 | assert samples.dtype == np.float32, samples.dtype | 117 | assert samples.dtype == np.float32, samples.dtype |
| 118 | assert samples.ndim == 1, samples.ndim | 118 | assert samples.ndim == 1, samples.ndim |
| 119 | + | ||
| 119 | buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes | 120 | buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes |
| 120 | buf += (samples.size * 4).to_bytes(4, byteorder="little") | 121 | buf += (samples.size * 4).to_bytes(4, byteorder="little") |
| 121 | buf += samples.tobytes() | 122 | buf += samples.tobytes() |
| 122 | 123 | ||
| 123 | - await websocket.send(buf) | 124 | + payload_len = 10240 |
| 125 | + while len(buf) > payload_len: | ||
| 126 | + await websocket.send(buf[:payload_len]) | ||
| 127 | + buf = buf[payload_len:] | ||
| 128 | + | ||
| 129 | + if buf: | ||
| 130 | + await websocket.send(buf) | ||
| 124 | 131 | ||
| 125 | decoding_results = await websocket.recv() | 132 | decoding_results = await websocket.recv() |
| 126 | print(decoding_results) | 133 | print(decoding_results) |
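Both offline clients gain the same change: instead of one `websocket.send(buf)`, the buffer is split into 10240-byte messages. This keeps each websocket message well below the server's `--max-message-size` limit (1 << 20 bytes by default), so long recordings no longer fail. The pattern, as a reusable sketch:

```python
# Sketch of the chunked-send pattern added in both clients above. The
# 10240-byte payload size mirrors the diffs; any value below the server's
# --max-message-size (default 1 << 20 bytes) would work.
async def send_in_chunks(websocket, buf: bytes, payload_len: int = 10240) -> None:
    while len(buf) > payload_len:
        await websocket.send(buf[:payload_len])
        buf = buf[payload_len:]
    if buf:  # final, possibly shorter, chunk
        await websocket.send(buf)
```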
| @@ -15,10 +15,9 @@ Usage: | @@ -15,10 +15,9 @@ Usage: | ||
| 15 | 15 | ||
| 16 | (Note: You have to first start the server before starting the client) | 16 | (Note: You have to first start the server before starting the client) |
| 17 | 17 | ||
| 18 | -You can find the server at | 18 | +You can find the c++ server at |
| 19 | https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc | 19 | https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc |
| 20 | - | ||
| 21 | -Note: The server is implemented in C++. | 20 | +or use the python server ./python-api-examples/streaming_server.py |
| 22 | 21 | ||
| 23 | There is also a C++ version of the client. Please see | 22 | There is also a C++ version of the client. Please see |
| 24 | https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc | 23 | https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc |
| @@ -115,7 +114,8 @@ async def receive_results(socket: websockets.WebSocketServerProtocol): | @@ -115,7 +114,8 @@ async def receive_results(socket: websockets.WebSocketServerProtocol): | ||
| 115 | last_message = message | 114 | last_message = message |
| 116 | logging.info(message) | 115 | logging.info(message) |
| 117 | else: | 116 | else: |
| 118 | - return last_message | 117 | + break |
| 118 | + return last_message | ||
| 119 | 119 | ||
| 120 | 120 | ||
| 121 | async def run( | 121 | async def run( |
| @@ -142,6 +142,7 @@ async def run( | @@ -142,6 +142,7 @@ async def run( | ||
| 142 | 142 | ||
| 143 | await websocket.send(d) | 143 | await websocket.send(d) |
| 144 | 144 | ||
| 145 | + # Simulate streaming. You can remove the sleep if you want | ||
| 145 | await asyncio.sleep(seconds_per_message) # in seconds | 146 | await asyncio.sleep(seconds_per_message) # in seconds |
| 146 | 147 | ||
| 147 | start += samples_per_message | 148 | start += samples_per_message |
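The new comment notes that the sleep only simulates streaming: pacing each chunk by its audio duration replays the file in real time, which is what a live source would produce. A sketch of the arithmetic (variable names assumed from the diff, not the full client source):

```python
# Assumed values illustrating the real-time pacing in the client above.
sample_rate = 16000          # Hz
samples_per_message = 8000   # 0.5 s of 16 kHz audio per websocket message
seconds_per_message = samples_per_message / sample_rate  # 0.5 s of sleep

# Removing the sleep sends the file as fast as the network allows, which is
# fine for benchmarking but no longer resembles live streaming.
```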
| @@ -12,10 +12,9 @@ Usage: | @@ -12,10 +12,9 @@ Usage: | ||
| 12 | 12 | ||
| 13 | (Note: You have to first start the server before starting the client) | 13 | (Note: You have to first start the server before starting the client) |
| 14 | 14 | ||
| 15 | -You can find the server at | 15 | +You can find the C++ server at |
| 16 | https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc | 16 | https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc |
| 17 | - | ||
| 18 | -Note: The server is implemented in C++. | 17 | +or use the python server ./python-api-examples/streaming_server.py |
| 19 | 18 | ||
| 20 | There is also a C++ version of the client. Please see | 19 | There is also a C++ version of the client. Please see |
| 21 | https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc | 20 | https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc |
| @@ -13,11 +13,37 @@ Usage: | @@ -13,11 +13,37 @@ Usage: | ||
| 13 | 13 | ||
| 14 | Example: | 14 | Example: |
| 15 | 15 | ||
| 16 | +(1) Without a certificate | ||
| 17 | + | ||
| 16 | python3 ./python-api-examples/streaming_server.py \ | 18 | python3 ./python-api-examples/streaming_server.py \ |
| 17 | --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ | 19 | --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ |
| 18 | --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ | 20 | --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ |
| 19 | --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ | 21 | --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ |
| 20 | --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt | 22 | --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt |
| 23 | + | ||
| 24 | +(2) With a certificate | ||
| 25 | + | ||
| 26 | +(a) Generate a certificate first: | ||
| 27 | + | ||
| 28 | + cd python-api-examples/web | ||
| 29 | + ./generate-certificate.py | ||
| 30 | + cd ../.. | ||
| 31 | + | ||
| 32 | +(b) Start the server | ||
| 33 | + | ||
| 34 | +python3 ./python-api-examples/streaming_server.py \ | ||
| 35 | + --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ | ||
| 36 | + --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ | ||
| 37 | + --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ | ||
| 38 | + --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \ | ||
| 39 | + --certificate ./python-api-examples/web/cert.pem | ||
| 40 | + | ||
| 41 | +Please refer to | ||
| 42 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html | ||
| 43 | +to download pre-trained models. | ||
| 44 | + | ||
| 45 | +The model in the above help messages is from | ||
| 46 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english | ||
| 21 | """ | 47 | """ |
| 22 | 48 | ||
| 23 | import argparse | 49 | import argparse |
| @@ -35,6 +61,7 @@ from typing import List, Optional, Tuple | @@ -35,6 +61,7 @@ from typing import List, Optional, Tuple | ||
| 35 | import numpy as np | 61 | import numpy as np |
| 36 | import sherpa_onnx | 62 | import sherpa_onnx |
| 37 | import websockets | 63 | import websockets |
| 64 | + | ||
| 38 | from http_server import HttpServer | 65 | from http_server import HttpServer |
| 39 | 66 | ||
| 40 | 67 | ||
| @@ -269,8 +296,8 @@ def get_args(): | @@ -269,8 +296,8 @@ def get_args(): | ||
| 269 | parser.add_argument( | 296 | parser.add_argument( |
| 270 | "--num-threads", | 297 | "--num-threads", |
| 271 | type=int, | 298 | type=int, |
| 272 | - default=1, | ||
| 273 | - help="Sets the number of threads used for interop parallelism (e.g. in JIT interpreter) on CPU.", | 299 | + default=2, |
| 300 | + help="Number of threads to run the neural network model", | ||
| 274 | ) | 301 | ) |
| 275 | 302 | ||
| 276 | parser.add_argument( | 303 | parser.add_argument( |
| @@ -278,8 +305,10 @@ def get_args(): | @@ -278,8 +305,10 @@ def get_args(): | ||
| 278 | type=str, | 305 | type=str, |
| 279 | help="""Path to the X.509 certificate. You need it only if you want to | 306 | help="""Path to the X.509 certificate. You need it only if you want to |
| 280 | use a secure websocket connection, i.e., use wss:// instead of ws://. | 307 | use a secure websocket connection, i.e., use wss:// instead of ws://. |
| 281 | - You can use sherpa/bin/web/generate-certificate.py | 308 | + You can use ./web/generate-certificate.py |
| 282 | to generate the certificate `cert.pem`. | 309 | to generate the certificate `cert.pem`. |
| 310 | + Note ./web/generate-certificate.py will generate three files but you | ||
| 311 | + only need to pass the generated cert.pem to this option. | ||
| 283 | """, | 312 | """, |
| 284 | ) | 313 | ) |
| 285 | 314 | ||
| @@ -287,7 +316,7 @@ def get_args(): | @@ -287,7 +316,7 @@ def get_args(): | ||
| 287 | "--doc-root", | 316 | "--doc-root", |
| 288 | type=str, | 317 | type=str, |
| 289 | default="./python-api-examples/web", | 318 | default="./python-api-examples/web", |
| 290 | - help="""Path to the web root""", | 319 | + help="Path to the web root", |
| 291 | ) | 320 | ) |
| 292 | 321 | ||
| 293 | return parser.parse_args() | 322 | return parser.parse_args() |
| @@ -299,9 +328,9 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: | @@ -299,9 +328,9 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: | ||
| 299 | encoder=args.encoder_model, | 328 | encoder=args.encoder_model, |
| 300 | decoder=args.decoder_model, | 329 | decoder=args.decoder_model, |
| 301 | joiner=args.joiner_model, | 330 | joiner=args.joiner_model, |
| 302 | - num_threads=1, | ||
| 303 | - sample_rate=16000, | ||
| 304 | - feature_dim=80, | 331 | + num_threads=args.num_threads, |
| 332 | + sample_rate=args.sample_rate, | ||
| 333 | + feature_dim=args.feat_dim, | ||
| 305 | decoding_method=args.decoding_method, | 334 | decoding_method=args.decoding_method, |
| 306 | max_active_paths=args.num_active_paths, | 335 | max_active_paths=args.num_active_paths, |
| 307 | enable_endpoint_detection=args.use_endpoint != 0, | 336 | enable_endpoint_detection=args.use_endpoint != 0, |
| @@ -359,7 +388,7 @@ class StreamingServer(object): | @@ -359,7 +388,7 @@ class StreamingServer(object): | ||
| 359 | server locate. | 388 | server locate. |
| 360 | certificate: | 389 | certificate: |
| 361 | Optional. If not None, it will use secure websocket. | 390 | Optional. If not None, it will use secure websocket. |
| 362 | - You can use ./sherpa/bin/web/generate-certificate.py to generate | 391 | + You can use ./web/generate-certificate.py to generate |
| 363 | it (the default generated filename is `cert.pem`). | 392 | it (the default generated filename is `cert.pem`). |
| 364 | """ | 393 | """ |
| 365 | self.recognizer = recognizer | 394 | self.recognizer = recognizer |
| @@ -373,6 +402,7 @@ class StreamingServer(object): | @@ -373,6 +402,7 @@ class StreamingServer(object): | ||
| 373 | ) | 402 | ) |
| 374 | 403 | ||
| 375 | self.stream_queue = asyncio.Queue() | 404 | self.stream_queue = asyncio.Queue() |
| 405 | + | ||
| 376 | self.max_wait_ms = max_wait_ms | 406 | self.max_wait_ms = max_wait_ms |
| 377 | self.max_batch_size = max_batch_size | 407 | self.max_batch_size = max_batch_size |
| 378 | self.max_message_size = max_message_size | 408 | self.max_message_size = max_message_size |
| @@ -382,11 +412,10 @@ class StreamingServer(object): | @@ -382,11 +412,10 @@ class StreamingServer(object): | ||
| 382 | self.current_active_connections = 0 | 412 | self.current_active_connections = 0 |
| 383 | 413 | ||
| 384 | self.sample_rate = int(recognizer.config.feat_config.sampling_rate) | 414 | self.sample_rate = int(recognizer.config.feat_config.sampling_rate) |
| 385 | - self.decoding_method = recognizer.config.decoding_method | ||
| 386 | 415 | ||
| 387 | async def stream_consumer_task(self): | 416 | async def stream_consumer_task(self): |
| 388 | """This function extracts streams from the queue, batches them up, sends | 417 | """This function extracts streams from the queue, batches them up, sends |
| 389 | - them to the RNN-T model for computation and decoding. | 418 | + them to the neural network model for computation and decoding. |
| 390 | """ | 419 | """ |
| 391 | while True: | 420 | while True: |
| 392 | if self.stream_queue.empty(): | 421 | if self.stream_queue.empty(): |
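The docstring above describes a micro-batching loop. A generic sketch of that pattern, with `decode_batch` as a placeholder for the server's actual batched decode call:

```python
# Generic micro-batching sketch of the pattern stream_consumer_task
# describes; decode_batch is a placeholder, not the server's method.
import asyncio


def decode_batch(batch):
    """Placeholder for one batched call into the recognizer."""


async def stream_consumer(queue: asyncio.Queue, max_batch_size: int, max_wait_ms: float):
    while True:
        batch = [await queue.get()]  # block until at least one stream arrives
        while len(batch) < max_batch_size:
            try:
                # give late arrivals up to max_wait_ms to join this batch
                batch.append(await asyncio.wait_for(queue.get(), max_wait_ms / 1000))
            except asyncio.TimeoutError:
                break  # decode a partial batch rather than wait longer
        decode_batch(batch)
```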
| @@ -442,7 +471,22 @@ class StreamingServer(object): | @@ -442,7 +471,22 @@ class StreamingServer(object): | ||
| 442 | # This is a normal HTTP request | 471 | # This is a normal HTTP request |
| 443 | if path == "/": | 472 | if path == "/": |
| 444 | path = "/index.html" | 473 | path = "/index.html" |
| 445 | - found, response, mime_type = self.http_server.process_request(path) | 474 | + |
| 475 | + if path in ("/upload.html", "/offline_record.html"): | ||
| 476 | + response = r""" | ||
| 477 | +<!doctype html><html><head> | ||
| 478 | +<title>Speech recognition with next-gen Kaldi</title></head><body> | ||
| 479 | +<h2>Only /streaming_record.html is available for the streaming server.</h2> | ||
| 480 | +<br/> | ||
| 481 | +<br/> | ||
| 482 | +Go back to <a href="/streaming_record.html">/streaming_record.html</a> | ||
| 483 | +</body></html> | ||
| 484 | +""" | ||
| 485 | + found = True | ||
| 486 | + mime_type = "text/html" | ||
| 487 | + else: | ||
| 488 | + found, response, mime_type = self.http_server.process_request(path) | ||
| 489 | + | ||
| 446 | if isinstance(response, str): | 490 | if isinstance(response, str): |
| 447 | response = response.encode("utf-8") | 491 | response = response.encode("utf-8") |
| 448 | 492 | ||
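The hunk above extends the HTTP fallback inside the websockets `process_request` hook. A hedged sketch of that hook's general shape (the `/healthz` path and body are illustrative, not the server's):

```python
# Hedged sketch of a websockets process_request hook: return None to let
# the websocket handshake proceed, or a (status, headers, body) triple to
# answer the request as plain HTTP.
import http


async def process_request(path, request_headers):
    if path == "/healthz":
        return http.HTTPStatus.OK, [("Content-Type", "text/plain")], b"OK\n"
    return None  # fall through to the websocket handshake
```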
| @@ -484,12 +528,21 @@ class StreamingServer(object): | @@ -484,12 +528,21 @@ class StreamingServer(object): | ||
| 484 | process_request=self.process_request, | 528 | process_request=self.process_request, |
| 485 | ssl=ssl_context, | 529 | ssl=ssl_context, |
| 486 | ): | 530 | ): |
| 487 | - ip_list = ["0.0.0.0", "localhost", "127.0.0.1"] | ||
| 488 | - ip_list.append(socket.gethostbyname(socket.gethostname())) | 531 | + ip_list = ["localhost"] |
| 532 | + if ssl_context: | ||
| 533 | + ip_list += ["0.0.0.0", "127.0.0.1"] | ||
| 534 | + ip_list.append(socket.gethostbyname(socket.gethostname())) | ||
| 489 | proto = "http://" if ssl_context is None else "https://" | 535 | proto = "http://" if ssl_context is None else "https://" |
| 490 | s = "Please visit one of the following addresses:\n\n" | 536 | s = "Please visit one of the following addresses:\n\n" |
| 491 | for p in ip_list: | 537 | for p in ip_list: |
| 492 | s += " " + proto + p + f":{port}" "\n" | 538 | s += " " + proto + p + f":{port}" "\n" |
| 539 | + | ||
| 540 | + if not ssl_context: | ||
| 541 | + s += "\nSince you are not providing a certificate, you cannot " | ||
| 542 | + s += "use your microphone from within the browser using " | ||
| 543 | + s += "public IP addresses. Only localhost can be used." | ||
| 544 | + s += "You also cannot use 0.0.0.0 or 127.0.0.1" | ||
| 545 | + | ||
| 493 | logging.info(s) | 546 | logging.info(s) |
| 494 | 547 | ||
| 495 | await asyncio.Future() # run forever | 548 | await asyncio.Future() # run forever |
| @@ -525,7 +578,7 @@ class StreamingServer(object): | @@ -525,7 +578,7 @@ class StreamingServer(object): | ||
| 525 | socket: websockets.WebSocketServerProtocol, | 578 | socket: websockets.WebSocketServerProtocol, |
| 526 | ): | 579 | ): |
| 527 | """Receive audio samples from the client, process it, and send | 580 | """Receive audio samples from the client, process it, and send |
| 528 | - deocoding result back to the client. | 581 | + decoding result back to the client. |
| 529 | 582 | ||
| 530 | Args: | 583 | Args: |
| 531 | socket: | 584 | socket: |
| @@ -560,8 +613,6 @@ class StreamingServer(object): | @@ -560,8 +613,6 @@ class StreamingServer(object): | ||
| 560 | self.recognizer.reset(stream) | 613 | self.recognizer.reset(stream) |
| 561 | segment += 1 | 614 | segment += 1 |
| 562 | 615 | ||
| 563 | - print(message) | ||
| 564 | - | ||
| 565 | await socket.send(json.dumps(message)) | 616 | await socket.send(json.dumps(message)) |
| 566 | 617 | ||
| 567 | tail_padding = np.zeros(int(self.sample_rate * 0.3)).astype(np.float32) | 618 | tail_padding = np.zeros(int(self.sample_rate * 0.3)).astype(np.float32) |
| @@ -583,7 +634,7 @@ class StreamingServer(object): | @@ -583,7 +634,7 @@ class StreamingServer(object): | ||
| 583 | self, | 634 | self, |
| 584 | socket: websockets.WebSocketServerProtocol, | 635 | socket: websockets.WebSocketServerProtocol, |
| 585 | ) -> Optional[np.ndarray]: | 636 | ) -> Optional[np.ndarray]: |
| 586 | - """Receives a tensor from the client. | 637 | + """Receive a tensor from the client. |
| 587 | 638 | ||
| 588 | Each message contains either a bytes buffer containing audio samples | 639 | Each message contains either a bytes buffer containing audio samples |
| 589 | in 16 kHz or contains "Done" meaning the end of utterance. | 640 | in 16 kHz or contains "Done" meaning the end of utterance. |
| @@ -660,6 +711,6 @@ def main(): | @@ -660,6 +711,6 @@ def main(): | ||
| 660 | 711 | ||
| 661 | 712 | ||
| 662 | if __name__ == "__main__": | 713 | if __name__ == "__main__": |
| 663 | - log_filename = "log/log-streaming-zipformer" | 714 | + log_filename = "log/log-streaming-server" |
| 664 | setup_logger(log_filename) | 715 | setup_logger(log_filename) |
| 665 | main() | 716 | main() |
python-api-examples/web/README.md
Deleted
100644 → 0
| 1 | -# How to use | ||
| 2 | - | ||
| 3 | -```bash | ||
| 4 | -git clone https://github.com/k2-fsa/sherpa | ||
| 5 | - | ||
| 6 | -cd sherpa/sherpa/bin/web | ||
| 7 | -python3 -m http.server 6009 | ||
| 8 | -``` | ||
| 9 | -and then go to <http://localhost:6009> | ||
| 10 | - | ||
| 11 | -You will see a page like the following screenshot: | ||
| 12 | - | ||
| 13 | - | ||
| 14 | - | ||
| 15 | -If your server is listening at the port *6006* with address **localhost**, | ||
| 16 | -then you can either click **Upload**, **Streaming_Record** or **Offline_Record** to play with it. | ||
| 17 | - | ||
| 18 | -## File descriptions | ||
| 19 | - | ||
| 20 | -### ./css/bootstrap.min.css | ||
| 21 | - | ||
| 22 | -It is downloaded from https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/css/bootstrap.min.css | ||
| 23 | - | ||
| 24 | -### ./js/jquery-3.6.0.min.js | ||
| 25 | - | ||
| 26 | -It is downloaded from https://code.jquery.com/jquery-3.6.0.min.js | ||
| 27 | - | ||
| 28 | -### ./js/popper.min.js | ||
| 29 | - | ||
| 30 | -It is downloaded from https://cdn.jsdelivr.net/npm/popper.js@1.14.7/dist/umd/popper.min.js | ||
| 31 | - | ||
| 32 | -### ./js/bootstrap.min.js | ||
| 33 | - | ||
| 34 | -It is download from https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/js/bootstrap.min.js |
| @@ -35,8 +35,8 @@ Otherwise, you may get the following error from within your browser: | @@ -35,8 +35,8 @@ Otherwise, you may get the following error from within your browser: | ||
| 35 | 35 | ||
| 36 | 36 | ||
| 37 | def cert_gen( | 37 | def cert_gen( |
| 38 | - emailAddress="https://github.com/k2-fsa/k2", | ||
| 39 | - commonName="sherpa", | 38 | + emailAddress="https://github.com/k2-fsa/sherpa-onnx", |
| 39 | + commonName="sherpa-onnx", | ||
| 40 | countryName="CN", | 40 | countryName="CN", |
| 41 | localityName="k2-fsa", | 41 | localityName="k2-fsa", |
| 42 | stateOrProvinceName="k2-fsa", | 42 | stateOrProvinceName="k2-fsa", |
| @@ -70,17 +70,13 @@ def cert_gen( | @@ -70,17 +70,13 @@ def cert_gen( | ||
| 70 | cert.set_pubkey(k) | 70 | cert.set_pubkey(k) |
| 71 | cert.sign(k, "sha512") | 71 | cert.sign(k, "sha512") |
| 72 | with open(CERT_FILE, "wt") as f: | 72 | with open(CERT_FILE, "wt") as f: |
| 73 | - f.write( | ||
| 74 | - crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8") | ||
| 75 | - ) | 73 | + f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")) |
| 76 | with open(KEY_FILE, "wt") as f: | 74 | with open(KEY_FILE, "wt") as f: |
| 77 | f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8")) | 75 | f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8")) |
| 78 | 76 | ||
| 79 | with open(ALL_IN_ONE_FILE, "wt") as f: | 77 | with open(ALL_IN_ONE_FILE, "wt") as f: |
| 80 | f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8")) | 78 | f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8")) |
| 81 | - f.write( | ||
| 82 | - crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8") | ||
| 83 | - ) | 79 | + f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")) |
| 84 | print(f"Generated {CERT_FILE}") | 80 | print(f"Generated {CERT_FILE}") |
| 85 | print(f"Generated {KEY_FILE}") | 81 | print(f"Generated {KEY_FILE}") |
| 86 | print(f"Generated {ALL_IN_ONE_FILE}") | 82 | print(f"Generated {ALL_IN_ONE_FILE}") |
| @@ -53,7 +53,7 @@ | @@ -53,7 +53,7 @@ | ||
| 53 | </ul> | 53 | </ul> |
| 54 | 54 | ||
| 55 | Code is available at | 55 | Code is available at |
| 56 | - <a href="https://github.com/k2-fsa/sherpa"> https://github.com/k2-fsa/sherpa</a> | 56 | + <a href="https://github.com/k2-fsa/sherpa-onnx"> https://github.com/k2-fsa/sherpa-onnx</a> |
| 57 | 57 | ||
| 58 | <!-- Optional JavaScript --> | 58 | <!-- Optional JavaScript --> |
| 59 | <!-- jQuery first, then Popper.js, then Bootstrap JS --> | 59 | <!-- jQuery first, then Popper.js, then Bootstrap JS --> |
| @@ -60,6 +60,7 @@ const soundClips = document.getElementById('sound-clips'); | @@ -60,6 +60,7 @@ const soundClips = document.getElementById('sound-clips'); | ||
| 60 | const canvas = document.getElementById('canvas'); | 60 | const canvas = document.getElementById('canvas'); |
| 61 | const mainSection = document.querySelector('.container'); | 61 | const mainSection = document.querySelector('.container'); |
| 62 | 62 | ||
| 63 | +recordBtn.disabled = true; | ||
| 63 | stopBtn.disabled = true; | 64 | stopBtn.disabled = true; |
| 64 | 65 | ||
| 65 | window.onload = (event) => { | 66 | window.onload = (event) => { |
| @@ -95,9 +96,10 @@ clearBtn.onclick = function() { | @@ -95,9 +96,10 @@ clearBtn.onclick = function() { | ||
| 95 | }; | 96 | }; |
| 96 | 97 | ||
| 97 | function send_header(n) { | 98 | function send_header(n) { |
| 98 | - const header = new ArrayBuffer(4); | ||
| 99 | - new DataView(header).setInt32(0, n, true /* littleEndian */); | ||
| 100 | - socket.send(new Int32Array(header, 0, 1)); | 99 | + const header = new ArrayBuffer(8); |
| 100 | + new DataView(header).setInt32(0, expectedSampleRate, true /* littleEndian */); | ||
| 101 | + new DataView(header).setInt32(4, n, true /* littleEndian */); | ||
| 102 | + socket.send(new Int32Array(header, 0, 2)); | ||
| 101 | } | 103 | } |
| 102 | 104 | ||
| 103 | // copied/modified from https://mdn.github.io/web-dictaphone/ | 105 | // copied/modified from https://mdn.github.io/web-dictaphone/ |
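Both record pages now frame each utterance with an 8-byte little-endian header: an int32 sample rate followed by the int32 `n` from the page's existing `send_header(n)` call. A hedged Python client mirroring that framing; the URL, the float32 payload encoding, and the exact meaning of `n` beyond what the diff shows are assumptions:

```python
# Hedged client sketch of the new 8-byte header (int32 sample_rate,
# int32 n, little-endian). The ws:// URL and the float32 payload
# encoding are assumptions beyond this diff.
import asyncio
import struct

import numpy as np
import websockets


async def send(samples: np.ndarray, sample_rate: int = 16000) -> None:
    async with websockets.connect("ws://localhost:6006") as ws:
        payload = samples.astype(np.float32).tobytes()
        await ws.send(struct.pack("<ii", sample_rate, len(payload)))  # header
        await ws.send(payload)  # audio samples
        await ws.send("Done")   # end of utterance, per the server docstring
        print(await ws.recv())  # decoding result


# asyncio.run(send(np.zeros(16000, dtype=np.float32)))
```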
| @@ -88,6 +88,7 @@ const canvas = document.getElementById('canvas'); | @@ -88,6 +88,7 @@ const canvas = document.getElementById('canvas'); | ||
| 88 | const mainSection = document.querySelector('.container'); | 88 | const mainSection = document.querySelector('.container'); |
| 89 | 89 | ||
| 90 | stopBtn.disabled = true; | 90 | stopBtn.disabled = true; |
| 91 | +recordBtn.disabled = true; | ||
| 91 | 92 | ||
| 92 | let audioCtx; | 93 | let audioCtx; |
| 93 | const canvasCtx = canvas.getContext('2d'); | 94 | const canvasCtx = canvas.getContext('2d'); |
| @@ -74,9 +74,11 @@ connectBtn.onclick = function() { | @@ -74,9 +74,11 @@ connectBtn.onclick = function() { | ||
| 74 | }; | 74 | }; |
| 75 | 75 | ||
| 76 | function send_header(n) { | 76 | function send_header(n) { |
| 77 | - const header = new ArrayBuffer(4); | ||
| 78 | - new DataView(header).setInt32(0, n, true /* littleEndian */); | ||
| 79 | - socket.send(new Int32Array(header, 0, 1)); | 77 | + const header = new ArrayBuffer(8); |
| 78 | + // assume the uploaded wave is 16000 Hz | ||
| 79 | + new DataView(header).setInt32(0, 16000, true /* littleEndian */); | ||
| 80 | + new DataView(header).setInt32(4, n, true /* littleEndian */); | ||
| 81 | + socket.send(new Int32Array(header, 0, 2)); | ||
| 80 | } | 82 | } |
| 81 | 83 | ||
| 82 | function onFileChange() { | 84 | function onFileChange() { |
| @@ -33,9 +33,9 @@ | @@ -33,9 +33,9 @@ | ||
| 33 | <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button> | 33 | <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button> |
| 34 | </div> | 34 | </div> |
| 35 | <span class="input-group-text" id="ws-protocol">ws://</span> | 35 | <span class="input-group-text" id="ws-protocol">ws://</span> |
| 36 | - <input type="text" id="server-ip" class="form-control" placeholder="Sherpa server IP, e.g., localhost" aria-label="sherpa server IP"> | 36 | + <input type="text" id="server-ip" class="form-control" placeholder="Sherpa-onnx server IP, e.g., localhost" aria-label="sherpa-onnx server IP"> |
| 37 | <span class="input-group-text">:</span> | 37 | <span class="input-group-text">:</span> |
| 38 | - <input type="text" id="server-port" class="form-control" placeholder="Sherpa server port, e.g., 6006" aria-label="sherpa server port"> | 38 | + <input type="text" id="server-port" class="form-control" placeholder="Sherpa-onnx server port, e.g., 6006" aria-label="sherpa-onnx server port"> |
| 39 | </div> | 39 | </div> |
| 40 | 40 | ||
| 41 | <div class="row"> | 41 | <div class="row"> |
| @@ -33,9 +33,9 @@ | @@ -33,9 +33,9 @@ | ||
| 33 | <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button> | 33 | <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button> |
| 34 | </div> | 34 | </div> |
| 35 | <span class="input-group-text" id="ws-protocol">ws://</span> | 35 | <span class="input-group-text" id="ws-protocol">ws://</span> |
| 36 | - <input type="text" id="server-ip" class="form-control" placeholder="Sherpa server IP, e.g., localhost" aria-label="sherpa server IP"> | 36 | + <input type="text" id="server-ip" class="form-control" placeholder="Sherpa-onnx server IP, e.g., localhost" aria-label="sherpa-onnx server IP"> |
| 37 | <span class="input-group-text">:</span> | 37 | <span class="input-group-text">:</span> |
| 38 | - <input type="text" id="server-port" class="form-control" placeholder="Sherpa server port, e.g., 6006" aria-label="sherpa server port"> | 38 | + <input type="text" id="server-port" class="form-control" placeholder="Sherpa-onnx server port, e.g., 6006" aria-label="sherpa-onnx server port"> |
| 39 | </div> | 39 | </div> |
| 40 | 40 | ||
| 41 | <div class="row"> | 41 | <div class="row"> |
| @@ -32,9 +32,9 @@ | @@ -32,9 +32,9 @@ | ||
| 32 | <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button> | 32 | <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button> |
| 33 | </div> | 33 | </div> |
| 34 | <span class="input-group-text" id="ws-protocol">ws://</span> | 34 | <span class="input-group-text" id="ws-protocol">ws://</span> |
| 35 | - <input type="text" id="server-ip" class="form-control" placeholder="Sherpa server IP, e.g., localhost" aria-label="sherpa server IP"> | 35 | + <input type="text" id="server-ip" class="form-control" placeholder="Sherpa-onnx server IP, e.g., localhost" aria-label="sherpa-onnx server IP"> |
| 36 | <span class="input-group-text">:</span> | 36 | <span class="input-group-text">:</span> |
| 37 | - <input type="text" id="server-port" class="form-control" placeholder="Sherpa server port, e.g., 6006" aria-label="sherpa server port"> | 37 | + <input type="text" id="server-port" class="form-control" placeholder="Sherpa-onnx server port, e.g., 6006" aria-label="sherpa-onnx server port"> |
| 38 | </div> | 38 | </div> |
| 39 | 39 | ||
| 40 | <form> | 40 | <form> |
| 1 | from typing import Dict, List, Optional | 1 | from typing import Dict, List, Optional |
| 2 | 2 | ||
| 3 | -from _sherpa_onnx import Display | 3 | +from _sherpa_onnx import Display, OfflineStream, OnlineStream |
| 4 | 4 | ||
| 5 | -from .online_recognizer import OnlineRecognizer | ||
| 6 | -from .online_recognizer import OnlineStream | ||
| 7 | from .offline_recognizer import OfflineRecognizer | 5 | from .offline_recognizer import OfflineRecognizer |
| 8 | - | 6 | +from .online_recognizer import OnlineRecognizer |
| 9 | from .utils import encode_contexts | 7 | from .utils import encode_contexts |
| 10 | - | ||
| 11 | - | ||
| 12 | - |
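After this rework, per the hunk above, these symbols are all importable from the package top level; a quick smoke test, assuming a successful pip install:

```python
# Quick import smoke test of the reworked top-level exports; no model
# files are needed just to import.
import sherpa_onnx

print(sherpa_onnx.Display)
print(sherpa_onnx.OnlineStream, sherpa_onnx.OfflineStream)
print(sherpa_onnx.OnlineRecognizer, sherpa_onnx.OfflineRecognizer)
print(sherpa_onnx.encode_contexts)
```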
| @@ -41,6 +41,7 @@ class OfflineRecognizer(object): | @@ -41,6 +41,7 @@ class OfflineRecognizer(object): | ||
| 41 | sample_rate: int = 16000, | 41 | sample_rate: int = 16000, |
| 42 | feature_dim: int = 80, | 42 | feature_dim: int = 80, |
| 43 | decoding_method: str = "greedy_search", | 43 | decoding_method: str = "greedy_search", |
| 44 | + max_active_paths: int = 4, | ||
| 44 | context_score: float = 1.5, | 45 | context_score: float = 1.5, |
| 45 | debug: bool = False, | 46 | debug: bool = False, |
| 46 | provider: str = "cpu", | 47 | provider: str = "cpu", |
| @@ -72,6 +73,9 @@ class OfflineRecognizer(object): | @@ -72,6 +73,9 @@ class OfflineRecognizer(object): | ||
| 72 | Dimension of the feature used to train the model. | 73 | Dimension of the feature used to train the model. |
| 73 | decoding_method: | 74 | decoding_method: |
| 74 | Valid values: greedy_search, modified_beam_search. | 75 | Valid values: greedy_search, modified_beam_search. |
| 76 | + max_active_paths: | ||
| 77 | + Maximum number of active paths to keep. Used only when | ||
| 78 | + decoding_method is modified_beam_search. | ||
| 75 | debug: | 79 | debug: |
| 76 | True to show debug messages. | 80 | True to show debug messages. |
| 77 | provider: | 81 | provider: |
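A hedged usage sketch for the new parameter; the `from_transducer` factory name and the model file paths are assumptions based on the surrounding keyword arguments:

```python
# Hedged usage sketch for max_active_paths; the factory name and file
# paths are assumptions, not taken verbatim from this diff.
import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    tokens="tokens.txt",
    decoding_method="modified_beam_search",
    max_active_paths=4,  # consulted only by modified_beam_search
)
```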
| @@ -103,6 +107,7 @@ class OfflineRecognizer(object): | @@ -103,6 +107,7 @@ class OfflineRecognizer(object): | ||
| 103 | context_score=context_score, | 107 | context_score=context_score, |
| 104 | ) | 108 | ) |
| 105 | self.recognizer = _Recognizer(recognizer_config) | 109 | self.recognizer = _Recognizer(recognizer_config) |
| 110 | + self.config = recognizer_config | ||
| 106 | return self | 111 | return self |
| 107 | 112 | ||
| 108 | @classmethod | 113 | @classmethod |
| @@ -166,6 +171,7 @@ class OfflineRecognizer(object): | @@ -166,6 +171,7 @@ class OfflineRecognizer(object): | ||
| 166 | decoding_method=decoding_method, | 171 | decoding_method=decoding_method, |
| 167 | ) | 172 | ) |
| 168 | self.recognizer = _Recognizer(recognizer_config) | 173 | self.recognizer = _Recognizer(recognizer_config) |
| 174 | + self.config = recognizer_config | ||
| 169 | return self | 175 | return self |
| 170 | 176 | ||
| 171 | @classmethod | 177 | @classmethod |
| @@ -229,6 +235,7 @@ class OfflineRecognizer(object): | @@ -229,6 +235,7 @@ class OfflineRecognizer(object): | ||
| 229 | decoding_method=decoding_method, | 235 | decoding_method=decoding_method, |
| 230 | ) | 236 | ) |
| 231 | self.recognizer = _Recognizer(recognizer_config) | 237 | self.recognizer = _Recognizer(recognizer_config) |
| 238 | + self.config = recognizer_config | ||
| 232 | return self | 239 | return self |
| 233 | 240 | ||
| 234 | @classmethod | 241 | @classmethod |
| @@ -291,6 +298,7 @@ class OfflineRecognizer(object): | @@ -291,6 +298,7 @@ class OfflineRecognizer(object): | ||
| 291 | decoding_method=decoding_method, | 298 | decoding_method=decoding_method, |
| 292 | ) | 299 | ) |
| 293 | self.recognizer = _Recognizer(recognizer_config) | 300 | self.recognizer = _Recognizer(recognizer_config) |
| 301 | + self.config = recognizer_config | ||
| 294 | return self | 302 | return self |
| 295 | 303 | ||
| 296 | def create_stream(self, contexts_list: Optional[List[List[int]]] = None): | 304 | def create_stream(self, contexts_list: Optional[List[List[int]]] = None): |
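Each factory now stores its config on the instance, which lets callers such as the websocket servers read settings back off a recognizer, mirroring how the online server above reads `recognizer.config.feat_config.sampling_rate`. A small hedged helper illustrating the idea (the `feat_config.sampling_rate` attribute path for the offline config is an assumption):

```python
# Hedged helper showing why self.config is now stored: callers can read
# feature settings back off the recognizer instead of re-threading args.
def sampling_rate_of(recognizer) -> int:
    return int(recognizer.config.feat_config.sampling_rate)
```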