Fangjun Kuang
Committed by GitHub

Add Python API (#31)

正在显示 51 个修改的文件,包含 967 行增加、57 行删除
  1 +[flake8]
  2 +show-source=true
  3 +statistics=true
  4 +max-line-length = 80
  5 +
  6 +exclude =
  7 + .git,
  8 + ./cmake,
#!/usr/bin/env bash
# Download a pretrained sherpa-onnx LSTM model from huggingface and
# run the Python API example on it. Intended to be invoked from CI.

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}


repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-lstm-en-2023-02-17

log "Start testing ${repo_url}"
repo=$(basename "$repo_url")
log "Download pretrained model and test-data from $repo_url"

# Skip LFS blobs on clone; pull only the *.onnx model files afterwards
GIT_LFS_SKIP_SMUDGE=1 git clone "$repo_url"
pushd "$repo"
git lfs pull --include "*.onnx"
popd

python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
sherpa_onnx_version=$(python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)")

echo "sherpa_onnx version: $sherpa_onnx_version"

pwd
ls -lh

ls -lh "$repo"

python3 python-api-examples/decode-file.py
@@ -47,7 +47,7 @@ jobs: @@ -47,7 +47,7 @@ jobs:
47 cd build 47 cd build
48 cmake -D CMAKE_BUILD_TYPE=Release .. 48 cmake -D CMAKE_BUILD_TYPE=Release ..
49 49
50 - - name: Build sherpa for macos 50 + - name: Build sherpa-onnx for macos
51 shell: bash 51 shell: bash
52 run: | 52 run: |
53 cd build 53 cd build
# Build the sherpa-onnx Python package from source and run the
# Python API test script (.github/scripts/test-python.sh) on a
# matrix of OSes and Python versions.
name: run-python-test

on:
  push:
    branches:
      - master
    # Only run when files affecting the Python build change
    paths:
      - '.github/workflows/run-python-test.yaml'
      - '.github/scripts/test-python.sh'
      - 'CMakeLists.txt'
      - 'cmake/**'
      - 'sherpa-onnx/csrc/*'
  pull_request:
    branches:
      - master
    paths:
      - '.github/workflows/run-python-test.yaml'
      - '.github/scripts/test-python.sh'
      - 'CMakeLists.txt'
      - 'cmake/**'
      - 'sherpa-onnx/csrc/*'

# Cancel an in-flight run when a newer commit for the same ref arrives
concurrency:
  group: run-python-test-${{ github.ref }}
  cancel-in-progress: true

permissions:
  contents: read

jobs:
  run-python-test:
    runs-on: ${{ matrix.os }}
    strategy:
      # Let all matrix combinations finish even if one fails
      fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
        python-version: ["3.7", "3.8", "3.9", "3.10"]

    steps:
      - uses: actions/checkout@v2
        with:
          # Full history; setup.py reads version info from the repo
          fetch-depth: 0

      - name: Setup Python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        shell: bash
        run: |
          python3 -m pip install --upgrade pip numpy

      # Builds the C++ extension via CMake (see cmake/cmake_extension.py)
      - name: Install sherpa-onnx
        shell: bash
        run: |
          python3 setup.py install

      - name: Test sherpa-onnx
        shell: bash
        run: |
          .github/scripts/test-python.sh
@@ -5,3 +5,6 @@ onnxruntime-* @@ -5,3 +5,6 @@ onnxruntime-*
5 icefall-* 5 icefall-*
6 run.sh 6 run.sh
7 sherpa-onnx-* 7 sherpa-onnx-*
  8 +__pycache__
  9 +dist/
  10 +sherpa_onnx.egg-info/
1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR) 1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
2 project(sherpa-onnx) 2 project(sherpa-onnx)
3 3
4 -set(SHERPA_ONNX_VERSION "1.0") 4 +set(SHERPA_ONNX_VERSION "1.1")
  5 +
  6 +# Disable warning about
  7 +#
  8 +# "The DOWNLOAD_EXTRACT_TIMESTAMP option was not given and policy CMP0135 is
  9 +# not set.
  10 +if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
  11 + cmake_policy(SET CMP0135 NEW)
  12 +endif()
  13 +
  14 +option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF)
  15 +option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF)
  16 +option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
5 17
6 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") 18 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
7 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") 19 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
@@ -20,13 +32,20 @@ endif() @@ -20,13 +32,20 @@ endif()
20 set(CMAKE_INSTALL_RPATH ${SHERPA_ONNX_RPATH_ORIGIN}) 32 set(CMAKE_INSTALL_RPATH ${SHERPA_ONNX_RPATH_ORIGIN})
21 set(CMAKE_BUILD_RPATH ${SHERPA_ONNX_RPATH_ORIGIN}) 33 set(CMAKE_BUILD_RPATH ${SHERPA_ONNX_RPATH_ORIGIN})
22 34
23 -set(BUILD_SHARED_LIBS ON) 35 +if(WIN32 AND BUILD_SHARED_LIBS)
  36 + message(STATUS "Set BUILD_SHARED_LIBS to OFF for windows")
  37 + set(BUILD_SHARED_LIBS OFF)
  38 +endif()
24 39
25 if(NOT CMAKE_BUILD_TYPE) 40 if(NOT CMAKE_BUILD_TYPE)
26 message(STATUS "No CMAKE_BUILD_TYPE given, default to Release") 41 message(STATUS "No CMAKE_BUILD_TYPE given, default to Release")
27 set(CMAKE_BUILD_TYPE Release) 42 set(CMAKE_BUILD_TYPE Release)
28 endif() 43 endif()
  44 +
29 message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") 45 message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
  46 +message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
  47 +message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
  48 +message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}")
30 49
31 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") 50 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
32 set(CMAKE_CXX_EXTENSIONS OFF) 51 set(CMAKE_CXX_EXTENSIONS OFF)
@@ -37,4 +56,12 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) @@ -37,4 +56,12 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
37 include(kaldi-native-fbank) 56 include(kaldi-native-fbank)
38 include(onnxruntime) 57 include(onnxruntime)
39 58
  59 +if(SHERPA_ONNX_ENABLE_PYTHON)
  60 + include(pybind11)
  61 +endif()
  62 +
  63 +if(SHERPA_ONNX_ENABLE_TESTS)
  64 + enable_testing()
  65 +endif()
  66 +
40 add_subdirectory(sherpa-onnx) 67 add_subdirectory(sherpa-onnx)
  1 +# cmake/cmake_extension.py
  2 +# Copyright (c) 2023 Xiaomi Corporation
  3 +#
  4 +# flake8: noqa
  5 +
  6 +import os
  7 +import platform
  8 +import sys
  9 +from pathlib import Path
  10 +
  11 +import setuptools
  12 +from setuptools.command.build_ext import build_ext
  13 +
  14 +
def is_for_pypi():
    """Return True iff the SHERPA_ONNX_IS_FOR_PYPI env variable is set."""
    return os.environ.get("SHERPA_ONNX_IS_FOR_PYPI", None) is not None
  18 +
  19 +
def is_macos():
    """Return True when running on macOS (Darwin)."""
    return "Darwin" == platform.system()
  22 +
  23 +
def is_windows():
    """Return True when running on Windows."""
    return "Windows" == platform.system()
  26 +
  27 +
try:
    from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

    class bdist_wheel(_bdist_wheel):
        # Decides whether the produced wheel is tagged as pure Python
        # (e.g. -py3-none-any) or as platform specific (e.g. -linux_x86_64).
        def finalize_options(self):
            _bdist_wheel.finalize_options(self)
            # In this case, the generated wheel has a name in the form
            # sherpa-xxx-pyxx-none-any.whl
            if is_for_pypi() and not is_macos():
                self.root_is_pure = True
            else:
                # The generated wheel has a name ending with
                # -linux_x86_64.whl
                self.root_is_pure = False

except ImportError:
    # wheel is not installed; setup.py then uses setuptools' default command
    bdist_wheel = None
  45 +
  46 +
def cmake_extension(name, *args, **kwargs) -> setuptools.Extension:
    """Create a sources-less Extension placeholder; CMake builds the binary."""
    kwargs["language"] = "c++"
    empty_sources = []
    return setuptools.Extension(name, empty_sources, *args, **kwargs)
  51 +
  52 +
class BuildExtension(build_ext):
    """Custom build_ext that delegates compilation to CMake.

    Rather than compiling the (empty) Extension sources, it configures and
    builds the whole CMake project and installs the resulting libraries
    into ``<build_lib>/sherpa_onnx`` so they ship inside the Python package.
    """

    def build_extension(self, ext: setuptools.extension.Extension):
        # build/temp.linux-x86_64-3.8
        os.makedirs(self.build_temp, exist_ok=True)

        # build/lib.linux-x86_64-3.8
        os.makedirs(self.build_lib, exist_ok=True)

        # CMake install prefix; artifacts placed here end up in the wheel
        install_dir = Path(self.build_lib).resolve() / "sherpa_onnx"

        # Repo root: this file lives in <root>/cmake/
        sherpa_onnx_dir = Path(__file__).parent.parent.resolve()

        cmake_args = os.environ.get("SHERPA_ONNX_CMAKE_ARGS", "")
        make_args = os.environ.get("SHERPA_ONNX_MAKE_ARGS", "")
        system_make_args = os.environ.get("MAKEFLAGS", "")

        if cmake_args == "":
            cmake_args = "-DCMAKE_BUILD_TYPE=Release"

        extra_cmake_args = f" -DCMAKE_INSTALL_PREFIX={install_dir} "
        # Shared libs on *nix, static on Windows (mirrors CMakeLists.txt)
        if not is_windows():
            extra_cmake_args += " -DBUILD_SHARED_LIBS=ON "
        else:
            extra_cmake_args += " -DBUILD_SHARED_LIBS=OFF "
        extra_cmake_args += " -DSHERPA_ONNX_ENABLE_PYTHON=ON "

        if "PYTHON_EXECUTABLE" not in cmake_args:
            print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
            cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}"

        cmake_args += extra_cmake_args

        if is_windows():
            build_cmd = f"""
        cmake {cmake_args} -B {self.build_temp} -S {sherpa_onnx_dir}
        cmake --build {self.build_temp} --target install --config Release -- -m
            """
            print(f"build command is:\n{build_cmd}")
            ret = os.system(
                f"cmake {cmake_args} -B {self.build_temp} -S {sherpa_onnx_dir}"
            )
            if ret != 0:
                # RuntimeError (not bare Exception) so callers can be selective
                raise RuntimeError("Failed to configure sherpa")

            # "-- -m" asks msbuild to build in parallel
            ret = os.system(
                f"cmake --build {self.build_temp} --target install --config Release -- -m"  # noqa
            )
            if ret != 0:
                raise RuntimeError("Failed to build and install sherpa")
        else:
            if make_args == "" and system_make_args == "":
                print("for fast compilation, run:")
                print('export SHERPA_ONNX_MAKE_ARGS="-j"; python setup.py install')
                print('Setting make_args to "-j4"')
                make_args = "-j4"

            build_cmd = f"""
                cd {self.build_temp}

                cmake {cmake_args} {sherpa_onnx_dir}

                make {make_args} install/strip
            """
            print(f"build command is:\n{build_cmd}")

            ret = os.system(build_cmd)
            if ret != 0:
                raise RuntimeError(
                    "\nBuild sherpa-onnx failed. Please check the error message.\n"
                    "You can ask for help by creating an issue on GitHub.\n"
                    "\nClick:\n\thttps://github.com/k2-fsa/sherpa-onnx/issues/new\n"  # noqa
                )
1 function(download_kaldi_native_fbank) 1 function(download_kaldi_native_fbank)
2 include(FetchContent) 2 include(FetchContent)
3 3
4 - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.12.tar.gz")  
5 - set(kaldi_native_fbank_HASH "SHA256=8f4dfc3f6ddb1adcd9ac0ae87743ebc6cbcae147aacf9d46e76fa54134e12b44") 4 + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.13.tar.gz")
  5 + set(kaldi_native_fbank_HASH "SHA256=1f4d228f9fe3e3e9f92a74a7eecd2489071a03982e4ba6d7c70fc5fa7444df57")
6 6
7 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) 7 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
8 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) 8 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -11,11 +11,11 @@ function(download_kaldi_native_fbank) @@ -11,11 +11,11 @@ function(download_kaldi_native_fbank)
11 # If you don't have access to the Internet, 11 # If you don't have access to the Internet,
12 # please pre-download kaldi-native-fbank 12 # please pre-download kaldi-native-fbank
13 set(possible_file_locations 13 set(possible_file_locations
14 - $ENV{HOME}/Downloads/kaldi-native-fbank-1.12.tar.gz  
15 - ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.12.tar.gz  
16 - ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.12.tar.gz  
17 - /tmp/kaldi-native-fbank-1.12.tar.gz  
18 - /star-fj/fangjun/download/github/kaldi-native-fbank-1.12.tar.gz 14 + $ENV{HOME}/Downloads/kaldi-native-fbank-1.13.tar.gz
  15 + ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.13.tar.gz
  16 + ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.13.tar.gz
  17 + /tmp/kaldi-native-fbank-1.13.tar.gz
  18 + /star-fj/fangjun/download/github/kaldi-native-fbank-1.13.tar.gz
19 ) 19 )
20 20
21 foreach(f IN LISTS possible_file_locations) 21 foreach(f IN LISTS possible_file_locations)
@@ -44,6 +44,7 @@ function(download_kaldi_native_fbank) @@ -44,6 +44,7 @@ function(download_kaldi_native_fbank)
44 INTERFACE 44 INTERFACE
45 ${kaldi_native_fbank_SOURCE_DIR}/ 45 ${kaldi_native_fbank_SOURCE_DIR}/
46 ) 46 )
  47 + install(TARGETS kaldi-native-fbank-core DESTINATION lib)
47 endfunction() 48 endfunction()
48 49
49 download_kaldi_native_fbank() 50 download_kaldi_native_fbank()
@@ -85,6 +85,7 @@ function(download_onnxruntime) @@ -85,6 +85,7 @@ function(download_onnxruntime)
85 message(STATUS "location_onnxruntime: ${location_onnxruntime}") 85 message(STATUS "location_onnxruntime: ${location_onnxruntime}")
86 86
87 add_library(onnxruntime SHARED IMPORTED) 87 add_library(onnxruntime SHARED IMPORTED)
  88 +
88 set_target_properties(onnxruntime PROPERTIES 89 set_target_properties(onnxruntime PROPERTIES
89 IMPORTED_LOCATION ${location_onnxruntime} 90 IMPORTED_LOCATION ${location_onnxruntime}
90 INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include" 91 INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
@@ -100,6 +101,18 @@ function(download_onnxruntime) @@ -100,6 +101,18 @@ function(download_onnxruntime)
100 ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE} 101 ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
101 ) 102 )
102 endif() 103 endif()
  104 +
  105 +
  106 + if(UNIX AND NOT APPLE)
  107 + file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/lib*")
  108 + elseif(APPLE)
  109 + file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/lib*dylib")
  110 + elseif(WIN32)
  111 + file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/*.dll")
  112 + endif()
  113 +
  114 + message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
  115 + install(FILES ${onnxruntime_lib_files} DESTINATION lib)
103 endfunction() 116 endfunction()
104 117
105 download_onnxruntime() 118 download_onnxruntime()
# Fetch pybind11 v2.10.2 (or pick up a pre-downloaded tarball) and add it
# to the build as a subdirectory so the pybind11_add_module() command
# becomes available.
function(download_pybind11)
  include(FetchContent)

  set(pybind11_URL "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.2.tar.gz")
  set(pybind11_HASH "SHA256=93bd1e625e43e03028a3ea7389bba5d3f9f2596abc074b068e70f4ef9b1314ae")

  # If you don't have access to the Internet,
  # please pre-download pybind11
  set(possible_file_locations
    $ENV{HOME}/Downloads/pybind11-2.10.2.tar.gz
    ${PROJECT_SOURCE_DIR}/pybind11-2.10.2.tar.gz
    ${PROJECT_BINARY_DIR}/pybind11-2.10.2.tar.gz
    /tmp/pybind11-2.10.2.tar.gz
    /star-fj/fangjun/download/github/pybind11-2.10.2.tar.gz
  )

  # Prefer a local copy over downloading; URL_HASH still verifies integrity
  foreach(f IN LISTS possible_file_locations)
    if(EXISTS ${f})
      set(pybind11_URL "file://${f}")
      break()
    endif()
  endforeach()

  FetchContent_Declare(pybind11
    URL ${pybind11_URL}
    URL_HASH ${pybind11_HASH}
  )

  FetchContent_GetProperties(pybind11)
  if(NOT pybind11_POPULATED)
    message(STATUS "Downloading pybind11 from ${pybind11_URL}")
    FetchContent_Populate(pybind11)
  endif()
  message(STATUS "pybind11 is downloaded to ${pybind11_SOURCE_DIR}")
  # EXCLUDE_FROM_ALL: only build pybind11 targets that we actually use
  add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR} EXCLUDE_FROM_ALL)
endfunction()

download_pybind11()
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This file demonstrates how to use sherpa-onnx Python API to recognize
  5 +a single file.
  6 +
  7 +Please refer to
  8 +https://k2-fsa.github.io/sherpa/onnx/index.html
  9 +to install sherpa-onnx and to download the pre-trained models
  10 +used in this file.
  11 +"""
  12 +import wave
  13 +import time
  14 +
  15 +import numpy as np
  16 +import sherpa_onnx
  17 +
  18 +
def main():
    """Decode a single 16 kHz mono wave file with a streaming recognizer."""
    sample_rate = 16000
    num_threads = 4
    recognizer = sherpa_onnx.OnlineRecognizer(
        tokens="./sherpa-onnx-lstm-en-2023-02-17/tokens.txt",
        encoder="./sherpa-onnx-lstm-en-2023-02-17/encoder-epoch-99-avg-1.onnx",
        decoder="./sherpa-onnx-lstm-en-2023-02-17/decoder-epoch-99-avg-1.onnx",
        joiner="./sherpa-onnx-lstm-en-2023-02-17/joiner-epoch-99-avg-1.onnx",
        num_threads=num_threads,
        sample_rate=sample_rate,
        feature_dim=80,
    )
    filename = "./sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav"
    with wave.open(filename) as f:
        assert f.getframerate() == sample_rate, f.getframerate()
        assert f.getnchannels() == 1, f.getnchannels()
        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
        raw_bytes = f.readframes(f.getnframes())

    # 16-bit signed PCM -> float32 in roughly [-1, 1)
    audio = np.frombuffer(raw_bytes, dtype=np.int16).astype(np.float32) / 32768

    duration = len(audio) / sample_rate

    start_time = time.time()
    print("Started!")

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, audio)

    # Feed 0.2 s of trailing silence so the model flushes its last frames
    tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
    stream.accept_waveform(sample_rate, tail_paddings)
    stream.input_finished()

    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)

    print(recognizer.get_result(stream))

    print("Done!")
    elapsed_seconds = time.time() - start_time
    rtf = elapsed_seconds / duration
    print(f"num_threads: {num_threads}")
    print(f"Wave duration: {duration:.3f} s")
    print(f"Elapsed time: {elapsed_seconds:.3f} s")
    print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")


if __name__ == "__main__":
    main()
  1 +#!/usr/bin/env python3
  2 +
  3 +import os
  4 +import re
  5 +import sys
  6 +from pathlib import Path
  7 +
  8 +import setuptools
  9 +
  10 +from cmake.cmake_extension import (
  11 + BuildExtension,
  12 + bdist_wheel,
  13 + cmake_extension,
  14 + is_windows,
  15 +)
  16 +
  17 +
def read_long_description():
    """Return the contents of README.md (used as the PyPI long description)."""
    with open("README.md", encoding="utf8") as f:
        return f.read()
  22 +
  23 +
def get_package_version():
    """Extract the package version from the top-level CMakeLists.txt.

    Reads ./CMakeLists.txt (relative to the current working directory),
    finds the line ``set(SHERPA_ONNX_VERSION "x.y.z")`` and returns the
    version string with surrounding quotes stripped.

    Raises:
      ValueError: if no SHERPA_ONNX_VERSION line is present, instead of
        failing with an opaque AttributeError on ``None``.
    """
    with open("CMakeLists.txt") as f:
        content = f.read()

    match = re.search(r"set\(SHERPA_ONNX_VERSION (.*)\)", content)
    if match is None:
        raise ValueError("SHERPA_ONNX_VERSION not found in CMakeLists.txt")
    latest_version = match.group(1).strip('"')
    return latest_version
  31 +
  32 +
package_name = "sherpa-onnx"

# Temporarily append __version__ to the package's __init__.py so the
# installed package exposes sherpa_onnx.__version__; the line is removed
# again after setup() returns (see the bottom of this file).
with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "a") as f:
    f.write(f"__version__ = '{get_package_version()}'\n")

install_requires = [
    "numpy",
]

setuptools.setup(
    name=package_name,
    python_requires=">=3.6",
    install_requires=install_requires,
    version=get_package_version(),
    author="The sherpa-onnx development team",
    author_email="dpovey@gmail.com",
    package_dir={
        "sherpa_onnx": "sherpa-onnx/python/sherpa_onnx",
    },
    packages=["sherpa_onnx"],
    url="https://github.com/k2-fsa/sherpa-onnx",
    long_description=read_long_description(),
    long_description_content_type="text/markdown",
    # Placeholder extension; BuildExtension runs CMake to produce _sherpa_onnx
    ext_modules=[cmake_extension("_sherpa_onnx")],
    cmdclass={"build_ext": BuildExtension, "bdist_wheel": bdist_wheel},
    zip_safe=False,
    classifiers=[
        "Programming Language :: C++",
        "Programming Language :: Python",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    license="Apache licensed, as found in the LICENSE file",
)

# Restore __init__.py: strip the __version__ line appended above so the
# source tree is left unmodified after building.
with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "r") as f:
    lines = f.readlines()

with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "w") as f:
    for line in lines:
        if "__version__" in line:
            # skip __version__ = "x.x.x"
            continue
        f.write(line)
1 add_subdirectory(csrc) 1 add_subdirectory(csrc)
  2 +if(SHERPA_ONNX_ENABLE_PYTHON)
  3 + add_subdirectory(python)
  4 +endif()
1 include_directories(${CMAKE_SOURCE_DIR}) 1 include_directories(${CMAKE_SOURCE_DIR})
2 2
3 -add_executable(sherpa-onnx 3 +add_library(sherpa-onnx-core
4 features.cc 4 features.cc
5 online-lstm-transducer-model.cc 5 online-lstm-transducer-model.cc
6 online-recognizer.cc 6 online-recognizer.cc
@@ -9,15 +9,21 @@ add_executable(sherpa-onnx @@ -9,15 +9,21 @@ add_executable(sherpa-onnx
9 online-transducer-model-config.cc 9 online-transducer-model-config.cc
10 online-transducer-model.cc 10 online-transducer-model.cc
11 onnx-utils.cc 11 onnx-utils.cc
12 - sherpa-onnx.cc  
13 symbol-table.cc 12 symbol-table.cc
14 wave-reader.cc 13 wave-reader.cc
15 ) 14 )
16 15
17 -target_link_libraries(sherpa-onnx 16 +target_link_libraries(sherpa-onnx-core
18 onnxruntime 17 onnxruntime
19 kaldi-native-fbank-core 18 kaldi-native-fbank-core
20 ) 19 )
21 20
22 -add_executable(sherpa-onnx-show-info show-onnx-info.cc)  
23 -target_link_libraries(sherpa-onnx-show-info onnxruntime) 21 +add_executable(sherpa-onnx sherpa-onnx.cc)
  22 +
  23 +target_link_libraries(sherpa-onnx sherpa-onnx-core)
  24 +if(NOT WIN32)
  25 + target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
  26 +endif()
  27 +
  28 +install(TARGETS sherpa-onnx-core DESTINATION lib)
  29 +install(TARGETS sherpa-onnx DESTINATION bin)
1 -// sherpa/csrc/features.cc 1 +// sherpa-onnx/csrc/features.cc
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 4
1 -// sherpa/csrc/features.h 1 +// sherpa-onnx/csrc/features.h
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 4
1 -// sherpa/csrc/online-lstm-transducer-model.cc 1 +// sherpa-onnx/csrc/online-lstm-transducer-model.cc
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" 4 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
@@ -232,7 +232,7 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() { @@ -232,7 +232,7 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
232 232
233 std::pair<Ort::Value, std::vector<Ort::Value>> 233 std::pair<Ort::Value, std::vector<Ort::Value>>
234 OnlineLstmTransducerModel::RunEncoder(Ort::Value features, 234 OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
235 - std::vector<Ort::Value> &states) { 235 + std::vector<Ort::Value> states) {
236 auto memory_info = 236 auto memory_info =
237 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 237 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
238 238
1 -// sherpa/csrc/online-lstm-transducer-model.h 1 +// sherpa-onnx/csrc/online-lstm-transducer-model.h
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 #ifndef SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_ 4 #ifndef SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
@@ -28,7 +28,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { @@ -28,7 +28,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
28 std::vector<Ort::Value> GetEncoderInitStates() override; 28 std::vector<Ort::Value> GetEncoderInitStates() override;
29 29
30 std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( 30 std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
31 - Ort::Value features, std::vector<Ort::Value> &states) override; 31 + Ort::Value features, std::vector<Ort::Value> states) override;
32 32
33 Ort::Value BuildDecoderInput( 33 Ort::Value BuildDecoderInput(
34 const std::vector<OnlineTransducerDecoderResult> &results) override; 34 const std::vector<OnlineTransducerDecoderResult> &results) override;
@@ -98,7 +98,7 @@ class OnlineRecognizer::Impl { @@ -98,7 +98,7 @@ class OnlineRecognizer::Impl {
98 98
99 auto states = model_->StackStates(states_vec); 99 auto states = model_->StackStates(states_vec);
100 100
101 - auto pair = model_->RunEncoder(std::move(x), states); 101 + auto pair = model_->RunEncoder(std::move(x), std::move(states));
102 102
103 decoder_->Decode(std::move(pair.first), &results); 103 decoder_->Decode(std::move(pair.first), &results);
104 104
@@ -23,6 +23,13 @@ struct OnlineRecognizerConfig { @@ -23,6 +23,13 @@ struct OnlineRecognizerConfig {
23 OnlineTransducerModelConfig model_config; 23 OnlineTransducerModelConfig model_config;
24 std::string tokens; 24 std::string tokens;
25 25
  26 + OnlineRecognizerConfig() = default;
  27 +
  28 + OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
  29 + const OnlineTransducerModelConfig &model_config,
  30 + const std::string &tokens)
  31 + : feat_config(feat_config), model_config(model_config), tokens(tokens) {}
  32 +
26 std::string ToString() const; 33 std::string ToString() const;
27 }; 34 };
28 35
1 -// sherpa/csrc/online-transducer-decoder.h 1 +// sherpa-onnx/csrc/online-transducer-decoder.h
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 4
1 -// sherpa/csrc/online-transducer-greedy-search-decoder.cc 1 +// sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 4
1 -// sherpa/csrc/online-transducer-greedy-search-decoder.h 1 +// sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 4
1 -// sherpa/csrc/online-transducer-model-config.cc 1 +// sherpa-onnx/csrc/online-transducer-model-config.cc
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 4 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
1 -// sherpa/csrc/online-transducer-model-config.h 1 +// sherpa-onnx/csrc/online-transducer-model-config.h
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 #ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ 4 #ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
@@ -15,6 +15,17 @@ struct OnlineTransducerModelConfig { @@ -15,6 +15,17 @@ struct OnlineTransducerModelConfig {
15 int32_t num_threads; 15 int32_t num_threads;
16 bool debug = false; 16 bool debug = false;
17 17
  18 + OnlineTransducerModelConfig() = default;
  19 + OnlineTransducerModelConfig(const std::string &encoder_filename,
  20 + const std::string &decoder_filename,
  21 + const std::string &joiner_filename,
  22 + int32_t num_threads, bool debug)
  23 + : encoder_filename(encoder_filename),
  24 + decoder_filename(decoder_filename),
  25 + joiner_filename(joiner_filename),
  26 + num_threads(num_threads),
  27 + debug(debug) {}
  28 +
18 std::string ToString() const; 29 std::string ToString() const;
19 }; 30 };
20 31
1 -// sherpa/csrc/online-transducer-model.cc 1 +// sherpa-onnx/csrc/online-transducer-model.cc
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 #include "sherpa-onnx/csrc/online-transducer-model.h" 4 #include "sherpa-onnx/csrc/online-transducer-model.h"
1 -// sherpa/csrc/online-transducer-model.h 1 +// sherpa-onnx/csrc/online-transducer-model.h
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 #ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_ 4 #ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
@@ -59,7 +59,7 @@ class OnlineTransducerModel { @@ -59,7 +59,7 @@ class OnlineTransducerModel {
59 */ 59 */
60 virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( 60 virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
61 Ort::Value features, 61 Ort::Value features,
62 - std::vector<Ort::Value> &states) = 0; // NOLINT 62 + std::vector<Ort::Value> states) = 0; // NOLINT
63 63
64 virtual Ort::Value BuildDecoderInput( 64 virtual Ort::Value BuildDecoderInput(
65 const std::vector<OnlineTransducerDecoderResult> &results) = 0; 65 const std::vector<OnlineTransducerDecoderResult> &results) = 0;
1 -// sherpa/csrc/onnx-utils.cc 1 +// sherpa-onnx/csrc/onnx-utils.cc
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 #include "sherpa-onnx/csrc/onnx-utils.h" 4 #include "sherpa-onnx/csrc/onnx-utils.h"
1 -// sherpa/csrc/onnx-utils.h 1 +// sherpa-onnx/csrc/onnx-utils.h
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 #ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_ 4 #ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_
1 -// sherpa-onnx/csrc/show-onnx-info.cc  
2 -//  
3 -// Copyright (c) 2022-2023 Xiaomi Corporation  
4 -  
5 -#include <iostream>  
6 -#include <sstream>  
7 -  
8 -#include "onnxruntime_cxx_api.h" // NOLINT  
9 -  
10 -int main() {  
11 - std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n";  
12 - std::vector<std::string> providers = Ort::GetAvailableProviders();  
13 - std::ostringstream os;  
14 - os << "Available providers: ";  
15 - std::string sep = "";  
16 - for (const auto &p : providers) {  
17 - os << sep << p;  
18 - sep = ", ";  
19 - }  
20 - std::cout << os.str() << "\n";  
21 - return 0;  
22 -}  
1 -// sherpa-onnx/csrc/symbol-table.cc 1 +// sherpa-onnx/csrc/symbol-table.h
2 // 2 //
3 // Copyright (c) 2022-2023 Xiaomi Corporation 3 // Copyright (c) 2022-2023 Xiaomi Corporation
4 4
1 -// sherpa/csrc/wave-reader.cc 1 +// sherpa-onnx/csrc/wave-reader.cc
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 4
1 -// sherpa/csrc/wave-reader.h 1 +// sherpa-onnx/csrc/wave-reader.h
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 4
  1 +add_subdirectory(csrc)
  2 +
  3 +if(SHERPA_ONNX_ENABLE_TESTS)
  4 + add_subdirectory(tests)
  5 +endif()
  1 +include_directories(${CMAKE_SOURCE_DIR})
  2 +
  3 +pybind11_add_module(_sherpa_onnx
  4 + features.cc
  5 + online-transducer-model-config.cc
  6 + sherpa-onnx.cc
  7 + online-stream.cc
  8 + online-recognizer.cc
  9 +)
  10 +
  11 +if(APPLE)
  12 + execute_process(
  13 + COMMAND "${PYTHON_EXECUTABLE}" -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())"
  14 + OUTPUT_STRIP_TRAILING_WHITESPACE
  15 + OUTPUT_VARIABLE PYTHON_SITE_PACKAGE_DIR
  16 + )
  17 + message(STATUS "PYTHON_SITE_PACKAGE_DIR: ${PYTHON_SITE_PACKAGE_DIR}")
  18 + target_link_libraries(_sherpa_onnx PRIVATE "-Wl,-rpath,${PYTHON_SITE_PACKAGE_DIR}")
  19 +endif()
  20 +
  21 +if(NOT WIN32)
  22 + target_link_libraries(_sherpa_onnx PRIVATE "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/sherpa_onnx/lib")
  23 +endif()
  24 +
  25 +target_link_libraries(_sherpa_onnx PRIVATE sherpa-onnx-core)
  26 +
  27 +install(TARGETS _sherpa_onnx
  28 + DESTINATION ../
  29 +)
  1 +// sherpa-onnx/python/csrc/features.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/features.h"
  6 +
  7 +#include "sherpa-onnx/csrc/features.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +static void PybindFeatureExtractorConfig(py::module *m) {
  12 + using PyClass = FeatureExtractorConfig;
  13 + py::class_<PyClass>(*m, "FeatureExtractorConfig")
  14 + .def(py::init<float, int32_t>(), py::arg("sampling_rate") = 16000,
  15 + py::arg("feature_dim") = 80)
  16 + .def_readwrite("sampling_rate", &PyClass::sampling_rate)
  17 + .def_readwrite("feature_dim", &PyClass::feature_dim)
  18 + .def("__str__", &PyClass::ToString);
  19 +}
  20 +
  21 +void PybindFeatures(py::module *m) { PybindFeatureExtractorConfig(m); }
  22 +
  23 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/features.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_FEATURES_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_FEATURES_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindFeatures(py::module *m);
  13 +
  14 +}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_FEATURES_H_
  1 +// sherpa-onnx/python/csrc/online-recognizer.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/online-recognizer.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/online-recognizer.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +static void PybindOnlineRecognizerResult(py::module *m) {
  15 + using PyClass = OnlineRecognizerResult;
  16 + py::class_<PyClass>(*m, "OnlineRecognizerResult")
  17 + .def_property_readonly("text", [](PyClass &self) { return self.text; });
  18 +}
  19 +
  20 +static void PybindOnlineRecognizerConfig(py::module *m) {
  21 + using PyClass = OnlineRecognizerConfig;
  22 + py::class_<PyClass>(*m, "OnlineRecognizerConfig")
  23 + .def(py::init<const FeatureExtractorConfig &,
  24 + const OnlineTransducerModelConfig &, const std::string &>(),
  25 + py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"))
  26 + .def_readwrite("feat_config", &PyClass::feat_config)
  27 + .def_readwrite("model_config", &PyClass::model_config)
  28 + .def_readwrite("tokens", &PyClass::tokens)
  29 + .def("__str__", &PyClass::ToString);
  30 +}
  31 +
  32 +void PybindOnlineRecognizer(py::module *m) {
  33 + PybindOnlineRecognizerResult(m);
  34 + PybindOnlineRecognizerConfig(m);
  35 +
  36 + using PyClass = OnlineRecognizer;
  37 + py::class_<PyClass>(*m, "OnlineRecognizer")
  38 + .def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
  39 + .def("create_stream", &PyClass::CreateStream)
  40 + .def("is_ready", &PyClass::IsReady)
  41 + .def("decode_stream", &PyClass::DecodeStream)
  42 + .def("decode_streams",
  43 + [](PyClass &self, std::vector<OnlineStream *> ss) {
  44 + self.DecodeStreams(ss.data(), ss.size());
  45 + })
  46 + .def("get_result", &PyClass::GetResult);
  47 +}
  48 +
  49 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/online-recognizer.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_RECOGNIZER_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_RECOGNIZER_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOnlineRecognizer(py::module *m);
  13 +
  14 +}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_RECOGNIZER_H_
  1 +// sherpa-onnx/python/csrc/online-stream.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/online-stream.h"
  6 +
  7 +#include "sherpa-onnx/csrc/online-stream.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +void PybindOnlineStream(py::module *m) {
  12 + using PyClass = OnlineStream;
  13 + py::class_<PyClass>(*m, "OnlineStream")
  14 + .def("accept_waveform",
  15 + [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
  16 + self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
  17 + })
  18 + .def("input_finished", &PyClass::InputFinished);
  19 +}
  20 +
  21 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/online-stream.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_STREAM_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_STREAM_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOnlineStream(py::module *m);
  13 +
  14 +}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_STREAM_H_
  1 +// sherpa-onnx/python/csrc/online-transducer-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-transducer-model-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void PybindOnlineTransducerModelConfig(py::module *m) {
  14 + using PyClass = OnlineTransducerModelConfig;
  15 + py::class_<PyClass>(*m, "OnlineTransducerModelConfig")
  16 + .def(py::init<const std::string &, const std::string &,
  17 + const std::string &, int32_t, bool>(),
  18 + py::arg("encoder_filename"), py::arg("decoder_filename"),
  19 + py::arg("joiner_filename"), py::arg("num_threads"),
  20 + py::arg("debug") = false)
  21 + .def_readwrite("encoder_filename", &PyClass::encoder_filename)
  22 + .def_readwrite("decoder_filename", &PyClass::decoder_filename)
  23 + .def_readwrite("joiner_filename", &PyClass::joiner_filename)
  24 + .def_readwrite("num_threads", &PyClass::num_threads)
  25 + .def_readwrite("debug", &PyClass::debug)
  26 + .def("__str__", &PyClass::ToString);
  27 +}
  28 +
  29 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/online-transducer-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOnlineTransducerModelConfig(py::module *m);
  13 +
  14 +}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
  1 +// sherpa-onnx/python/csrc/sherpa-onnx.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  6 +
  7 +#include "sherpa-onnx/python/csrc/features.h"
  8 +#include "sherpa-onnx/python/csrc/online-recognizer.h"
  9 +#include "sherpa-onnx/python/csrc/online-stream.h"
  10 +#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +PYBIND11_MODULE(_sherpa_onnx, m) {
  15 + m.doc() = "pybind11 binding of sherpa-onnx";
  16 + PybindFeatures(&m);
  17 + PybindOnlineTransducerModelConfig(&m);
  18 + PybindOnlineStream(&m);
  19 + PybindOnlineRecognizer(&m);
  20 +}
  21 +
  22 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/sherpa-onnx.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_
  7 +
  8 +#include "pybind11/numpy.h"
  9 +#include "pybind11/pybind11.h"
  10 +#include "pybind11/stl.h"
  11 +
  12 +namespace py = pybind11;
  13 +
  14 +#endif // SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_
  1 +from _sherpa_onnx import (
  2 + FeatureExtractorConfig,
  3 + OnlineRecognizerConfig,
  4 + OnlineStream,
  5 + OnlineTransducerModelConfig,
  6 +)
  7 +
  8 +from .online_recognizer import OnlineRecognizer
  1 +from pathlib import Path
  2 +from typing import List
  3 +
  4 +from _sherpa_onnx import (
  5 + OnlineStream,
  6 + OnlineTransducerModelConfig,
  7 + FeatureExtractorConfig,
  8 + OnlineRecognizerConfig,
  9 +)
  10 +from _sherpa_onnx import OnlineRecognizer as _Recognizer
  11 +
  12 +
  13 +def _assert_file_exists(f: str):
  14 + assert Path(f).is_file(), f"{f} does not exist"
  15 +
  16 +
  17 +class OnlineRecognizer(object):
  18 + """A class for streaming speech recognition."""
  19 +
  20 + def __init__(
  21 + self,
  22 + tokens: str,
  23 + encoder: str,
  24 + decoder: str,
  25 + joiner: str,
  26 + num_threads: int = 4,
  27 + sample_rate: float = 16000,
  28 + feature_dim: int = 80,
  29 + ):
  30 + """
  31 + Please refer to
  32 + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
  33 + to download pre-trained models for different languages, e.g., Chinese,
  34 + English, etc.
  35 +
  36 + Args:
  37 + tokens:
  38 + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
  39 + columns::
  40 +
  41 + symbol integer_id
  42 +
  43 + encoder:
  44 + Path to ``encoder.onnx``.
  45 + decoder:
  46 + Path to ``decoder.onnx``.
  47 + joiner:
  48 + Path to ``joiner.onnx``.
  49 + num_threads:
  50 + Number of threads for neural network computation.
  51 + sample_rate:
  52 + Sample rate of the training data used to train the model.
  53 + feature_dim:
  54 + Dimension of the feature used to train the model.
  55 + """
  56 + _assert_file_exists(tokens)
  57 + _assert_file_exists(encoder)
  58 + _assert_file_exists(decoder)
  59 + _assert_file_exists(joiner)
  60 +
  61 + assert num_threads > 0, num_threads
  62 +
  63 + model_config = OnlineTransducerModelConfig(
  64 + encoder_filename=encoder,
  65 + decoder_filename=decoder,
  66 + joiner_filename=joiner,
  67 + num_threads=num_threads,
  68 + )
  69 +
  70 + feat_config = FeatureExtractorConfig(
  71 + sampling_rate=sample_rate,
  72 + feature_dim=feature_dim,
  73 + )
  74 +
  75 + recognizer_config = OnlineRecognizerConfig(
  76 + feat_config=feat_config,
  77 + model_config=model_config,
  78 + tokens=tokens,
  79 + )
  80 +
  81 + self.recognizer = _Recognizer(recognizer_config)
  82 +
  83 + def create_stream(self):
  84 + return self.recognizer.create_stream()
  85 +
  86 + def decode_stream(self, s: OnlineStream):
  87 + self.recognizer.decode_stream(s)
  88 +
  89 + def decode_streams(self, ss: List[OnlineStream]):
  90 + self.recognizer.decode_streams(ss)
  91 +
  92 + def is_ready(self, s: OnlineStream) -> bool:
  93 + return self.recognizer.is_ready(s)
  94 +
  95 + def get_result(self, s: OnlineStream) -> str:
  96 + return self.recognizer.get_result(s).text
  1 +function(sherpa_onnx_add_py_test source)
  2 + get_filename_component(name ${source} NAME_WE)
  3 + set(name "${name}_py")
  4 +
  5 + add_test(NAME ${name}
  6 + COMMAND
  7 + "${PYTHON_EXECUTABLE}"
  8 + "${CMAKE_CURRENT_SOURCE_DIR}/${source}"
  9 + )
  10 +
  11 + get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY)
  12 +
  13 + set_property(TEST ${name}
  14 + PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_onnx_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}"
  15 + )
  16 +endfunction()
  17 +
  18 +# please sort the files in alphabetic order
  19 +set(py_test_files
  20 + test_feature_extractor_config.py
  21 + test_online_transducer_model_config.py
  22 +)
  23 +
  24 +foreach(source IN LISTS py_test_files)
  25 + sherpa_onnx_add_py_test(${source})
  26 +endforeach()
  27 +
  1 +# sherpa-onnx/python/tests/test_feature_extractor_config.py
  2 +#
  3 +# Copyright (c) 2023 Xiaomi Corporation
  4 +#
  5 +# To run this single test, use
  6 +#
  7 +# ctest --verbose -R test_feature_extractor_config_py
  8 +
  9 +import unittest
  10 +
  11 +import sherpa_onnx
  12 +
  13 +
  14 +class TestFeatureExtractorConfig(unittest.TestCase):
  15 + def test_default_constructor(self):
  16 + config = sherpa_onnx.FeatureExtractorConfig()
  17 + assert config.sampling_rate == 16000, config.sampling_rate
  18 + assert config.feature_dim == 80, config.feature_dim
  19 + print(config)
  20 +
  21 + def test_constructor(self):
  22 + config = sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40)
  23 + assert config.sampling_rate == 8000, config.sampling_rate
  24 + assert config.feature_dim == 40, config.feature_dim
  25 + print(config)
  26 +
  27 +
  28 +if __name__ == "__main__":
  29 + unittest.main()
  1 +# sherpa-onnx/python/tests/test_online_transducer_model_config.py
  2 +#
  3 +# Copyright (c) 2023 Xiaomi Corporation
  4 +#
  5 +# To run this single test, use
  6 +#
  7 +# ctest --verbose -R test_online_transducer_model_config_py
  8 +
  9 +import unittest
  10 +
  11 +import sherpa_onnx
  12 +
  13 +
  14 +class TestOnlineTransducerModelConfig(unittest.TestCase):
  15 + def test_constructor(self):
  16 + config = sherpa_onnx.OnlineTransducerModelConfig(
  17 + encoder_filename="encoder.onnx",
  18 + decoder_filename="decoder.onnx",
  19 + joiner_filename="joiner.onnx",
  20 + num_threads=8,
  21 + debug=True,
  22 + )
  23 + assert config.encoder_filename == "encoder.onnx", config.encoder_filename
  24 + assert config.decoder_filename == "decoder.onnx", config.decoder_filename
  25 + assert config.joiner_filename == "joiner.onnx", config.joiner_filename
  26 + assert config.num_threads == 8, config.num_threads
  27 + assert config.debug is True, config.debug
  28 + print(config)
  29 +
  30 +
  31 +if __name__ == "__main__":
  32 + unittest.main()