Fangjun Kuang
Committed by GitHub

Fix displaying streaming speech recognition results for Python. (#2196)

@@ -14,7 +14,7 @@ project(sherpa-onnx) @@ -14,7 +14,7 @@ project(sherpa-onnx)
14 # Remember to update 14 # Remember to update
15 # ./CHANGELOG.md 15 # ./CHANGELOG.md
16 # ./new-release.sh 16 # ./new-release.sh
17 -set(SHERPA_ONNX_VERSION "1.11.5") 17 +set(SHERPA_ONNX_VERSION "1.11.6")
18 18
19 # Disable warning about 19 # Disable warning about
20 # 20 #
@@ -11,8 +11,8 @@ @@ -11,8 +11,8 @@
11 # to download pre-trained models 11 # to download pre-trained models
12 12
13 import argparse 13 import argparse
14 -import sys  
15 from pathlib import Path 14 from pathlib import Path
  15 +
16 import sherpa_onnx 16 import sherpa_onnx
17 17
18 18
@@ -202,8 +202,8 @@ def main(): @@ -202,8 +202,8 @@ def main():
202 202
203 stream = recognizer.create_stream() 203 stream = recognizer.create_stream()
204 204
205 - last_result = ""  
206 - segment_id = 0 205 + display = sherpa_onnx.Display()
  206 +
207 while True: 207 while True:
208 samples = alsa.read(samples_per_read) # a blocking read 208 samples = alsa.read(samples_per_read) # a blocking read
209 stream.accept_waveform(sample_rate, samples) 209 stream.accept_waveform(sample_rate, samples)
@@ -214,13 +214,14 @@ def main(): @@ -214,13 +214,14 @@ def main():
214 214
215 result = recognizer.get_result(stream) 215 result = recognizer.get_result(stream)
216 216
217 - if result and (last_result != result):  
218 - last_result = result  
219 - print("\r{}:{}".format(segment_id, result), end="", flush=True) 217 + display.update_text(result)
  218 + display.display()
  219 +
220 if is_endpoint: 220 if is_endpoint:
221 if result: 221 if result:
222 - print("\r{}:{}".format(segment_id, result), flush=True)  
223 - segment_id += 1 222 + display.finalize_current_sentence()
  223 + display.display()
  224 +
224 recognizer.reset(stream) 225 recognizer.reset(stream)
225 226
226 227
@@ -192,8 +192,8 @@ def main(): @@ -192,8 +192,8 @@ def main():
192 192
193 stream = recognizer.create_stream() 193 stream = recognizer.create_stream()
194 194
195 - last_result = ""  
196 - segment_id = 0 195 + display = sherpa_onnx.Display()
  196 +
197 with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: 197 with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
198 while True: 198 while True:
199 samples, _ = s.read(samples_per_read) # a blocking read 199 samples, _ = s.read(samples_per_read) # a blocking read
@@ -206,13 +206,14 @@ def main(): @@ -206,13 +206,14 @@ def main():
206 206
207 result = recognizer.get_result(stream) 207 result = recognizer.get_result(stream)
208 208
209 - if result and (last_result != result):  
210 - last_result = result  
211 - print("\r{}:{}".format(segment_id, result), end="", flush=True) 209 + display.update_text(result)
  210 + display.display()
  211 +
212 if is_endpoint: 212 if is_endpoint:
213 if result: 213 if result:
214 - print("\r{}:{}".format(segment_id, result), flush=True)  
215 - segment_id += 1 214 + display.finalize_current_sentence()
  215 + display.display()
  216 +
216 recognizer.reset(stream) 217 recognizer.reset(stream)
217 218
218 219
@@ -192,8 +192,7 @@ def main(): @@ -192,8 +192,7 @@ def main():
192 192
193 stream = recognizer.create_stream() 193 stream = recognizer.create_stream()
194 194
195 - last_result = ""  
196 - segment_id = 0 195 + display = sherpa_onnx.Display()
197 196
198 print("Started!") 197 print("Started!")
199 while True: 198 while True:
@@ -213,13 +212,14 @@ def main(): @@ -213,13 +212,14 @@ def main():
213 212
214 result = recognizer.get_result(stream) 213 result = recognizer.get_result(stream)
215 214
216 - if result and (last_result != result):  
217 - last_result = result  
218 - print("\r{}:{}".format(segment_id, result), end="", flush=True) 215 + display.update_text(result)
  216 + display.display()
  217 +
219 if is_endpoint: 218 if is_endpoint:
220 if result: 219 if result:
221 - print("\r{}:{}".format(segment_id, result), flush=True)  
222 - segment_id += 1 220 + display.finalize_current_sentence()
  221 + display.display()
  222 +
223 recognizer.reset(stream) 223 recognizer.reset(stream)
224 224
225 225
@@ -74,8 +74,8 @@ def main(): @@ -74,8 +74,8 @@ def main():
74 74
75 stream = recognizer.create_stream() 75 stream = recognizer.create_stream()
76 76
77 - last_result = ""  
78 - segment_id = 0 77 + display = sherpa_onnx.Display()
  78 +
79 with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: 79 with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
80 while True: 80 while True:
81 samples, _ = s.read(samples_per_read) # a blocking read 81 samples, _ = s.read(samples_per_read) # a blocking read
@@ -88,13 +88,14 @@ def main(): @@ -88,13 +88,14 @@ def main():
88 88
89 result = recognizer.get_result(stream) 89 result = recognizer.get_result(stream)
90 90
91 - if result and (last_result != result):  
92 - last_result = result  
93 - print("\r{}:{}".format(segment_id, result), end="", flush=True) 91 + display.update_text(result)
  92 + display.display()
  93 +
94 if is_endpoint: 94 if is_endpoint:
95 if result: 95 if result:
96 - print("\r{}:{}".format(segment_id, result), flush=True)  
97 - segment_id += 1 96 + display.finalize_current_sentence()
  97 + display.display()
  98 +
98 recognizer.reset(stream) 99 recognizer.reset(stream)
99 100
100 101
@@ -46,7 +46,6 @@ python3 ./python-api-examples/two-pass-speech-recognition-from-microphone.py \ @@ -46,7 +46,6 @@ python3 ./python-api-examples/two-pass-speech-recognition-from-microphone.py \
46 import argparse 46 import argparse
47 import sys 47 import sys
48 from pathlib import Path 48 from pathlib import Path
49 -from typing import List  
50 49
51 import numpy as np 50 import numpy as np
52 51
@@ -375,8 +374,7 @@ def main(): @@ -375,8 +374,7 @@ def main():
375 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms 374 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
376 stream = first_recognizer.create_stream() 375 stream = first_recognizer.create_stream()
377 376
378 - last_result = ""  
379 - segment_id = 0 377 + display = sherpa_onnx.Display()
380 378
381 sample_buffers = [] 379 sample_buffers = []
382 with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: 380 with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
@@ -395,14 +393,8 @@ def main(): @@ -395,14 +393,8 @@ def main():
395 result = first_recognizer.get_result(stream) 393 result = first_recognizer.get_result(stream)
396 result = result.lower().strip() 394 result = result.lower().strip()
397 395
398 - if last_result != result:  
399 - print(  
400 - "\r{}:{}".format(segment_id, " " * len(last_result)),  
401 - end="",  
402 - flush=True,  
403 - )  
404 - last_result = result  
405 - print("\r{}:{}".format(segment_id, result), end="", flush=True) 396 + display.update_text(result)
  397 + display.display()
406 398
407 if is_endpoint: 399 if is_endpoint:
408 if result: 400 if result:
@@ -419,14 +411,9 @@ def main(): @@ -419,14 +411,9 @@ def main():
419 sample_rate=sample_rate, 411 sample_rate=sample_rate,
420 ) 412 )
421 result = result.lower().strip() 413 result = result.lower().strip()
422 -  
423 - print(  
424 - "\r{}:{}".format(segment_id, " " * len(last_result)),  
425 - end="",  
426 - flush=True,  
427 - )  
428 - print("\r{}:{}".format(segment_id, result), flush=True)  
429 - segment_id += 1 414 + display.update_text(result)
  415 + display.finalize_current_sentence()
  416 + display.display()
430 else: 417 else:
431 sample_buffers = [] 418 sample_buffers = []
432 419
@@ -6,7 +6,6 @@ from _sherpa_onnx import ( @@ -6,7 +6,6 @@ from _sherpa_onnx import (
6 AudioTaggingModelConfig, 6 AudioTaggingModelConfig,
7 CircularBuffer, 7 CircularBuffer,
8 DenoisedAudio, 8 DenoisedAudio,
9 - Display,  
10 FastClustering, 9 FastClustering,
11 FastClusteringConfig, 10 FastClusteringConfig,
12 OfflinePunctuation, 11 OfflinePunctuation,
@@ -48,6 +47,7 @@ from _sherpa_onnx import ( @@ -48,6 +47,7 @@ from _sherpa_onnx import (
48 write_wave, 47 write_wave,
49 ) 48 )
50 49
  50 +from .display import Display
51 from .keyword_spotter import KeywordSpotter 51 from .keyword_spotter import KeywordSpotter
52 from .offline_recognizer import OfflineRecognizer 52 from .offline_recognizer import OfflineRecognizer
53 from .online_recognizer import OnlineRecognizer 53 from .online_recognizer import OnlineRecognizer
  1 +# Copyright (c) 2025 Xiaomi Corporation
  2 +import os
  3 +from time import gmtime, strftime
  4 +
  5 +
def get_current_time():
    """Return the current time as a "YYYY-MM-DD HH:MM:SS" string.

    Note: uses gmtime(), so the timestamp is in UTC, not local time.
    """
    time_format = "%Y-%m-%d %H:%M:%S"
    return strftime(time_format, gmtime())
  8 +
  9 +
def clear_console():
    """Wipe the terminal screen ("cls" on Windows, "clear" elsewhere)."""
    command = "cls" if os.name == "nt" else "clear"
    os.system(command)
  12 +
  13 +
class Display:
    """Console renderer for streaming speech recognition results.

    Keeps a history of finalized sentences (each paired with the UTC
    timestamp at which it was finalized) plus the text of the sentence
    currently being recognized, and redraws the whole screen on demand.
    """

    def __init__(self):
        # list of (timestamp, text) pairs for finalized sentences
        self.sentences = []
        # partial text of the sentence currently being recognized
        self.currentText = ""

    def update_text(self, text):
        """Replace the in-progress text with the latest partial result."""
        self.currentText = text

    def finalize_current_sentence(self):
        """Move non-blank in-progress text into the history, then clear it."""
        pending = self.currentText
        if pending.strip():
            self.sentences.append((get_current_time(), pending))

        self.currentText = ""

    def display(self):
        """Clear the console and redraw header, history, and current text."""
        clear_console()
        print("=== Speech Recognition with Next-gen Kaldi ===")
        print("Time:", get_current_time())
        print("-" * 30)

        # history of finalized sentences, numbered from 1
        if self.sentences:
            for idx, (when, text) in enumerate(self.sentences, start=1):
                print(f"[{when}] {idx}. {text}")
            print("-" * 30)

        # only show the "Recognizing" line when there is non-blank text
        if self.currentText.strip():
            print("Recognizing:", self.currentText)