Fangjun Kuang
Committed by GitHub

Refactor python examples (#67)

@@ -9,7 +9,7 @@ log() { @@ -9,7 +9,7 @@ log() {
9 } 9 }
10 10
11 11
12 -repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-lstm-en-2023-02-17 12 +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
13 13
14 log "Start testing ${repo_url}" 14 log "Start testing ${repo_url}"
15 repo=$(basename $repo_url) 15 repo=$(basename $repo_url)
@@ -30,4 +30,9 @@ ls -lh @@ -30,4 +30,9 @@ ls -lh
30 30
31 ls -lh $repo 31 ls -lh $repo
32 32
33 -python3 python-api-examples/decode-file.py 33 +python3 ./python-api-examples/decode-file.py \
  34 + --tokens=$repo/tokens.txt \
  35 + --encoder=$repo/encoder-epoch-99-avg-1.onnx \
  36 + --decoder=$repo/decoder-epoch-99-avg-1.onnx \
  37 + --joiner=$repo/joiner-epoch-99-avg-1.onnx \
  38 + --wave-filename=$repo/test_wavs/4.wav
@@ -33,3 +33,4 @@ decode-file @@ -33,3 +33,4 @@ decode-file
33 *.dylib 33 *.dylib
34 tokens.txt 34 tokens.txt
35 *.onnx 35 *.onnx
  36 +log.txt
1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR) 1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
2 project(sherpa-onnx) 2 project(sherpa-onnx)
3 3
4 -set(SHERPA_ONNX_VERSION "1.1") 4 +set(SHERPA_ONNX_VERSION "1.2")
5 5
6 # Disable warning about 6 # Disable warning about
7 # 7 #
@@ -9,27 +9,83 @@ https://k2-fsa.github.io/sherpa/onnx/index.html @@ -9,27 +9,83 @@ https://k2-fsa.github.io/sherpa/onnx/index.html
9 to install sherpa-onnx and to download the pre-trained models 9 to install sherpa-onnx and to download the pre-trained models
10 used in this file. 10 used in this file.
11 """ 11 """
12 -import wave 12 +import argparse
13 import time 13 import time
  14 +import wave
  15 +from pathlib import Path
14 16
15 import numpy as np 17 import numpy as np
16 import sherpa_onnx 18 import sherpa_onnx
17 19
18 20
  21 +def assert_file_exists(filename: str):
  22 + assert Path(
  23 + filename
  24 + ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
  25 +
  26 +
  27 +def get_args():
  28 + parser = argparse.ArgumentParser(
  29 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  30 + )
  31 +
  32 + parser.add_argument(
  33 + "--tokens",
  34 + type=str,
  35 + help="Path to tokens.txt",
  36 + )
  37 +
  38 + parser.add_argument(
  39 + "--encoder",
  40 + type=str,
  41 + help="Path to the encoder model",
  42 + )
  43 +
  44 + parser.add_argument(
  45 + "--decoder",
  46 + type=str,
  47 + help="Path to the decoder model",
  48 + )
  49 +
  50 + parser.add_argument(
  51 + "--joiner",
  52 + type=str,
  53 + help="Path to the joiner model",
  54 + )
  55 +
  56 + parser.add_argument(
  57 + "--wave-filename",
  58 + type=str,
  59 + help="""Path to the wave filename. Must be 16 kHz,
  60 + mono with 16-bit samples""",
  61 + )
  62 +
  63 + return parser.parse_args()
  64 +
  65 +
19 def main(): 66 def main():
20 sample_rate = 16000 67 sample_rate = 16000
21 - num_threads = 4 68 + num_threads = 2
  69 +
  70 + args = get_args()
  71 + assert_file_exists(args.encoder)
  72 + assert_file_exists(args.decoder)
  73 + assert_file_exists(args.joiner)
  74 + assert_file_exists(args.tokens)
  75 + if not Path(args.wave_filename).is_file():
  76 + print(f"{args.wave_filename} does not exist!")
  77 + return
  78 +
22 recognizer = sherpa_onnx.OnlineRecognizer( 79 recognizer = sherpa_onnx.OnlineRecognizer(
23 - tokens="./sherpa-onnx-lstm-en-2023-02-17/tokens.txt",  
24 - encoder="./sherpa-onnx-lstm-en-2023-02-17/encoder-epoch-99-avg-1.onnx",  
25 - decoder="./sherpa-onnx-lstm-en-2023-02-17/decoder-epoch-99-avg-1.onnx",  
26 - joiner="./sherpa-onnx-lstm-en-2023-02-17/joiner-epoch-99-avg-1.onnx", 80 + tokens=args.tokens,
  81 + encoder=args.encoder,
  82 + decoder=args.decoder,
  83 + joiner=args.joiner,
27 num_threads=num_threads, 84 num_threads=num_threads,
28 sample_rate=sample_rate, 85 sample_rate=sample_rate,
29 feature_dim=80, 86 feature_dim=80,
30 ) 87 )
31 - filename = "./sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav"  
32 - with wave.open(filename) as f: 88 + with wave.open(args.wave_filename) as f:
33 assert f.getframerate() == sample_rate, f.getframerate() 89 assert f.getframerate() == sample_rate, f.getframerate()
34 assert f.getnchannels() == 1, f.getnchannels() 90 assert f.getnchannels() == 1, f.getnchannels()
35 assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes 91 assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
@@ -7,7 +7,9 @@ @@ -7,7 +7,9 @@
7 # https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 7 # https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
8 # to download pre-trained models 8 # to download pre-trained models
9 9
  10 +import argparse
10 import sys 11 import sys
  12 +from pathlib import Path
11 13
12 try: 14 try:
13 import sounddevice as sd 15 import sounddevice as sd
@@ -22,18 +24,65 @@ except ImportError as e: @@ -22,18 +24,65 @@ except ImportError as e:
22 import sherpa_onnx 24 import sherpa_onnx
23 25
24 26
  27 +def assert_file_exists(filename: str):
  28 + assert Path(
  29 + filename
  30 + ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
  31 +
  32 +
  33 +def get_args():
  34 + parser = argparse.ArgumentParser(
  35 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  36 + )
  37 +
  38 + parser.add_argument(
  39 + "--tokens",
  40 + type=str,
  41 + help="Path to tokens.txt",
  42 + )
  43 +
  44 + parser.add_argument(
  45 + "--encoder",
  46 + type=str,
  47 + help="Path to the encoder model",
  48 + )
  49 +
  50 + parser.add_argument(
  51 + "--decoder",
  52 + type=str,
  53 + help="Path to the decoder model",
  54 + )
  55 +
  56 + parser.add_argument(
  57 + "--joiner",
  58 + type=str,
  59 + help="Path to the joiner model",
  60 + )
  61 +
  62 + parser.add_argument(
  63 + "--wave-filename",
  64 + type=str,
  65 + help="""Path to the wave filename. Must be 16 kHz,
  66 + mono with 16-bit samples""",
  67 + )
  68 +
  69 + return parser.parse_args()
  70 +
  71 +
25 def create_recognizer(): 72 def create_recognizer():
  73 + args = get_args()
  74 + assert_file_exists(args.encoder)
  75 + assert_file_exists(args.decoder)
  76 + assert_file_exists(args.joiner)
  77 + assert_file_exists(args.tokens)
26 # Please replace the model files if needed. 78 # Please replace the model files if needed.
27 # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 79 # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
28 # for download links. 80 # for download links.
29 recognizer = sherpa_onnx.OnlineRecognizer( 81 recognizer = sherpa_onnx.OnlineRecognizer(
30 - tokens="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt",  
31 - encoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx",  
32 - decoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx",  
33 - joiner="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx",  
34 - num_threads=4,  
35 - sample_rate=16000,  
36 - feature_dim=80, 82 + tokens=args.tokens,
  83 + encoder=args.encoder,
  84 + decoder=args.decoder,
  85 + joiner=args.joiner,
37 enable_endpoint_detection=True, 86 enable_endpoint_detection=True,
38 rule1_min_trailing_silence=2.4, 87 rule1_min_trailing_silence=2.4,
39 rule2_min_trailing_silence=1.2, 88 rule2_min_trailing_silence=1.2,
@@ -6,7 +6,9 @@ @@ -6,7 +6,9 @@
6 # https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 6 # https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
7 # to download pre-trained models 7 # to download pre-trained models
8 8
  9 +import argparse
9 import sys 10 import sys
  11 +from pathlib import Path
10 12
11 try: 13 try:
12 import sounddevice as sd 14 import sounddevice as sd
@@ -21,15 +23,65 @@ except ImportError as e: @@ -21,15 +23,65 @@ except ImportError as e:
21 import sherpa_onnx 23 import sherpa_onnx
22 24
23 25
  26 +def assert_file_exists(filename: str):
  27 + assert Path(
  28 + filename
  29 + ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
  30 +
  31 +
  32 +def get_args():
  33 + parser = argparse.ArgumentParser(
  34 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  35 + )
  36 +
  37 + parser.add_argument(
  38 + "--tokens",
  39 + type=str,
  40 + help="Path to tokens.txt",
  41 + )
  42 +
  43 + parser.add_argument(
  44 + "--encoder",
  45 + type=str,
  46 + help="Path to the encoder model",
  47 + )
  48 +
  49 + parser.add_argument(
  50 + "--decoder",
  51 + type=str,
  52 + help="Path to the decoder model",
  53 + )
  54 +
  55 + parser.add_argument(
  56 + "--joiner",
  57 + type=str,
  58 + help="Path to the joiner model",
  59 + )
  60 +
  61 + parser.add_argument(
  62 + "--wave-filename",
  63 + type=str,
  64 + help="""Path to the wave filename. Must be 16 kHz,
  65 + mono with 16-bit samples""",
  66 + )
  67 +
  68 + return parser.parse_args()
  69 +
  70 +
24 def create_recognizer(): 71 def create_recognizer():
  72 + args = get_args()
  73 + assert_file_exists(args.encoder)
  74 + assert_file_exists(args.decoder)
  75 + assert_file_exists(args.joiner)
  76 + assert_file_exists(args.tokens)
25 # Please replace the model files if needed. 77 # Please replace the model files if needed.
26 # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 78 # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
27 # for download links. 79 # for download links.
28 recognizer = sherpa_onnx.OnlineRecognizer( 80 recognizer = sherpa_onnx.OnlineRecognizer(
29 - tokens="./sherpa-onnx-lstm-en-2023-02-17/tokens.txt",  
30 - encoder="./sherpa-onnx-lstm-en-2023-02-17/encoder-epoch-99-avg-1.onnx",  
31 - decoder="./sherpa-onnx-lstm-en-2023-02-17/decoder-epoch-99-avg-1.onnx",  
32 - joiner="./sherpa-onnx-lstm-en-2023-02-17/joiner-epoch-99-avg-1.onnx", 81 + tokens=args.tokens,
  82 + encoder=args.encoder,
  83 + decoder=args.decoder,
  84 + joiner=args.joiner,
33 num_threads=4, 85 num_threads=4,
34 sample_rate=16000, 86 sample_rate=16000,
35 feature_dim=80, 87 feature_dim=80,
@@ -3,6 +3,7 @@ @@ -3,6 +3,7 @@
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 #include "sherpa-onnx/csrc/onnx-utils.h" 4 #include "sherpa-onnx/csrc/onnx-utils.h"
5 5
  6 +#include <algorithm>
6 #include <fstream> 7 #include <fstream>
7 #include <string> 8 #include <string>
8 #include <vector> 9 #include <vector>