Committed by GitHub
Add an example for computing RTF for streaming ASR. (#1501)
Showing 2 changed files with 114 additions and 0 deletions
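
The real-time factor (RTF) reported by the new example is the ratio of processing time to audio duration: RTF = elapsed_seconds / audio_duration_seconds. For instance, if decoding a 10-second wave takes 2 seconds of wall-clock time, the RTF is 2 / 10 = 0.2; values below 1.0 mean the recognizer runs faster than real time, and lower is better.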
cxx-api-examples/CMakeLists.txt

@@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR})
 add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc)
 target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api)

+add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc)
+target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api)
+
 add_executable(whisper-cxx-api ./whisper-cxx-api.cc)
 target_link_libraries(whisper-cxx-api sherpa-onnx-cxx-api)
cxx-api-examples/streaming-zipformer-rtf-cxx-api.cc (new file)

// cxx-api-examples/streaming-zipformer-rtf-cxx-api.cc
// Copyright (c) 2024 Xiaomi Corporation

//
// This file demonstrates how to compute the real-time factor (RTF) of a
// streaming Zipformer model with sherpa-onnx's C++ API.
//
// clang-format off
//
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
// tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
// rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
//
// clang-format on

#include <chrono>   // NOLINT
#include <cstdio>   // for printf
#include <cstdlib>  // for atoi
#include <iostream>
#include <string>

#include "sherpa-onnx/c-api/cxx-api.h"

int32_t main(int argc, char *argv[]) {
  // Optional command-line argument: the number of timed decoding runs.
  int32_t num_runs = 1;
  if (argc == 2) {
    num_runs = atoi(argv[1]);
    if (num_runs <= 0) {
      num_runs = 1;
    }
  }

  using namespace sherpa_onnx::cxx;  // NOLINT
  OnlineRecognizerConfig config;

  // please see
  // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
  config.model_config.transducer.encoder =
      "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/"
      "encoder-epoch-99-avg-1.int8.onnx";

  // Note: We recommend not using int8.onnx for the decoder.
  config.model_config.transducer.decoder =
      "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/"
      "decoder-epoch-99-avg-1.onnx";

  config.model_config.transducer.joiner =
      "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/"
      "joiner-epoch-99-avg-1.int8.onnx";

  config.model_config.tokens =
      "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt";

  config.model_config.num_threads = 1;

  std::cout << "Loading model\n";
  OnlineRecognizer recognizer = OnlineRecognizer::Create(config);
  if (!recognizer.Get()) {
    std::cerr << "Please check your config\n";
    return -1;
  }
  std::cout << "Loading model done\n";

  std::string wave_filename =
      "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/"
      "0.wav";
  Wave wave = ReadWave(wave_filename);
  if (wave.samples.empty()) {
    std::cerr << "Failed to read: '" << wave_filename << "'\n";
    return -1;
  }

  std::cout << "Start recognition\n";
  float total_elapsed_seconds = 0;
  OnlineRecognizerResult result;
  for (int32_t i = 0; i < num_runs; ++i) {
    const auto begin = std::chrono::steady_clock::now();

    OnlineStream stream = recognizer.CreateStream();
    stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
                          wave.samples.size());
    stream.InputFinished();

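    // All samples have been fed and InputFinished() marks the end of input,
    // so the loop below drains every frame that is ready and decodes the
    // whole utterance in one timed pass.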
    while (recognizer.IsReady(&stream)) {
      recognizer.Decode(&stream);
    }

    result = recognizer.GetResult(&stream);

    auto end = std::chrono::steady_clock::now();
    float elapsed_seconds =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
            .count() /
        1000.;
    printf("Run %d/%d, elapsed seconds: %.3f\n", i + 1, num_runs,
           elapsed_seconds);
    total_elapsed_seconds += elapsed_seconds;
  }
  float average_elapsed_seconds = total_elapsed_seconds / num_runs;
  float duration = wave.samples.size() / static_cast<float>(wave.sample_rate);
  float rtf = average_elapsed_seconds / duration;

  std::cout << "text: " << result.text << "\n";
  printf("Number of threads: %d\n", config.model_config.num_threads);
  printf("Duration: %.3fs\n", duration);
  printf("Total elapsed seconds: %.3fs\n", total_elapsed_seconds);
  printf("Num runs: %d\n", num_runs);
  printf("Elapsed seconds per run: %.3f/%d=%.3f\n", total_elapsed_seconds,
         num_runs, average_elapsed_seconds);
  printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n",
         average_elapsed_seconds, duration, rtf);

  return 0;
}
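
As a rough guide, the example might be built and run along the following lines. The CMake invocation and the binary location are assumptions about a standard sherpa-onnx build (see the project's documentation for the exact options); the model-download commands are the same ones listed in the file header above.

# Get the pretrained model used by the example (same commands as in the file header)
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2

# Build sherpa-onnx together with the C++ API examples
# (illustrative; the exact flags depend on your configuration)
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build -j

# Run 10 timed decodes of test_wavs/0.wav; the program prints the per-run
# time, the averaged time per run, and the resulting RTF.
./build/bin/streaming-zipformer-rtf-cxx-api 10

Using more runs smooths out per-run jitter in the averaged RTF.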