Fangjun Kuang
Committed by GitHub

Add an example for computing RTF about streaming ASR. (#1501)

@@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR})
 add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc)
 target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api)

+add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc)
+target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api)
+
 add_executable(whisper-cxx-api ./whisper-cxx-api.cc)
 target_link_libraries(whisper-cxx-api sherpa-onnx-cxx-api)
  1 +// cxx-api-examples/streaming-zipformer-rtf-cxx-api.cc
  2 +// Copyright (c) 2024 Xiaomi Corporation
  3 +
  4 +//
  5 +// This file demonstrates how to use streaming Zipformer
  6 +// with sherpa-onnx's C++ API.
  7 +//
  8 +// clang-format off
  9 +//
  10 +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
  11 +// tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
  12 +// rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
  13 +//
  14 +// clang-format on
  15 +
#include <chrono>   // NOLINT
#include <cstdio>   // printf
#include <cstdlib>  // atoi
#include <iostream>
#include <string>

#include "sherpa-onnx/c-api/cxx-api.h"
  21 +
  22 +int32_t main(int argc, char *argv[]) {
  23 + int32_t num_runs = 1;
  24 + if (argc == 2) {
  25 + num_runs = atoi(argv[1]);
  26 + if (num_runs < 0) {
  27 + num_runs = 1;
  28 + }
  29 + }
  30 +
  31 + using namespace sherpa_onnx::cxx; // NOLINT
  32 + OnlineRecognizerConfig config;
  33 +
  34 + // please see
  35 + // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
  36 + config.model_config.transducer.encoder =
  37 + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/"
  38 + "encoder-epoch-99-avg-1.int8.onnx";
  39 +
  40 + // Note: We recommend not using int8.onnx for the decoder.
  41 + config.model_config.transducer.decoder =
  42 + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/"
  43 + "decoder-epoch-99-avg-1.onnx";
  44 +
  45 + config.model_config.transducer.joiner =
  46 + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/"
  47 + "joiner-epoch-99-avg-1.int8.onnx";
  48 +
  49 + config.model_config.tokens =
  50 + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt";
  51 +
  52 + config.model_config.num_threads = 1;
  53 +
  54 + std::cout << "Loading model\n";
  55 + OnlineRecognizer recongizer = OnlineRecognizer::Create(config);
  56 + if (!recongizer.Get()) {
  57 + std::cerr << "Please check your config\n";
  58 + return -1;
  59 + }
  60 + std::cout << "Loading model done\n";
  61 +
  62 + std::string wave_filename =
  63 + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/"
  64 + "0.wav";
  65 + Wave wave = ReadWave(wave_filename);
  66 + if (wave.samples.empty()) {
  67 + std::cerr << "Failed to read: '" << wave_filename << "'\n";
  68 + return -1;
  69 + }
  70 +
  71 + std::cout << "Start recognition\n";
  72 + float total_elapsed_seconds = 0;
  73 + OnlineRecognizerResult result;
  74 + for (int32_t i = 0; i < num_runs; ++i) {
  75 + const auto begin = std::chrono::steady_clock::now();
  76 +
  77 + OnlineStream stream = recongizer.CreateStream();
  78 + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
  79 + wave.samples.size());
  80 + stream.InputFinished();
  81 +
  82 + while (recongizer.IsReady(&stream)) {
  83 + recongizer.Decode(&stream);
  84 + }
  85 +
  86 + result = recongizer.GetResult(&stream);
  87 +
  88 + auto end = std::chrono::steady_clock::now();
  89 + float elapsed_seconds =
  90 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  91 + .count() /
  92 + 1000.;
  93 + printf("Run %d/%d, elapsed seconds: %.3f\n", i, num_runs, elapsed_seconds);
  94 + total_elapsed_seconds += elapsed_seconds;
  95 + }
  96 + float average_elapsed_secodns = total_elapsed_seconds / num_runs;
  97 + float duration = wave.samples.size() / static_cast<float>(wave.sample_rate);
  98 + float rtf = total_elapsed_seconds / num_runs / duration;
  99 +
  100 + std::cout << "text: " << result.text << "\n";
  101 + printf("Number of threads: %d\n", config.model_config.num_threads);
  102 + printf("Duration: %.3fs\n", duration);
  103 + printf("Total Elapsed seconds: %.3fs\n", total_elapsed_seconds);
  104 + printf("Num runs: %d\n", num_runs);
  105 + printf("Elapsed seconds per run: %.3f/%d=%.3f\n", total_elapsed_seconds,
  106 + num_runs, average_elapsed_secodns);
  107 + printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n",
  108 + average_elapsed_secodns, duration, rtf);
  109 +
  110 + return 0;
  111 +}