Committed by
GitHub
Fix c api (#2545)
This PR fixes the C API by adding proper support for durations in offline recognition results. It addresses problems introduced by a previous PR, in which the durations field was added to the C API struct but was not handled properly across all language bindings.

Key changes:
- Adds durations field handling across multiple language bindings (Swift, Kotlin, Java, C#)
- Fixes the field ordering in the C API struct to ensure ABI compatibility
- Updates the JNI implementation to properly extract and pass durations data
正在显示 14 个修改的文件，包含 172 行增加和 27 行删除
| @@ -40,7 +40,15 @@ function(download_cppinyin) | @@ -40,7 +40,15 @@ function(download_cppinyin) | ||
| 40 | if(NOT cppinyin_POPULATED) | 40 | if(NOT cppinyin_POPULATED) |
| 41 | message(STATUS "Downloading cppinyin ${cppinyin_URL}") | 41 | message(STATUS "Downloading cppinyin ${cppinyin_URL}") |
| 42 | FetchContent_Populate(cppinyin) | 42 | FetchContent_Populate(cppinyin) |
| 43 | + | ||
| 44 | + file(REMOVE ${cppinyin_SOURCE_DIR}/CMakeLists.txt) | ||
| 45 | + configure_file( | ||
| 46 | + ${CMAKE_SOURCE_DIR}/cmake/cppinyin.patch | ||
| 47 | + ${cppinyin_SOURCE_DIR}/CMakeLists.txt | ||
| 48 | + COPYONLY | ||
| 49 | + ) | ||
| 43 | endif() | 50 | endif() |
| 51 | + | ||
| 44 | message(STATUS "cppinyin is downloaded to ${cppinyin_SOURCE_DIR}") | 52 | message(STATUS "cppinyin is downloaded to ${cppinyin_SOURCE_DIR}") |
| 45 | 53 | ||
| 46 | if(BUILD_SHARED_LIBS) | 54 | if(BUILD_SHARED_LIBS) |
cmake/cppinyin.patch
0 → 100644
| 1 | +cmake_minimum_required(VERSION 3.12 FATAL_ERROR) | ||
| 2 | + | ||
| 3 | +project(cppinyin) | ||
| 4 | + | ||
| 5 | +set(CPPINYIN_VERSION "0.10") | ||
| 6 | + | ||
| 7 | +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") | ||
| 8 | +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") | ||
| 9 | +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin") | ||
| 10 | + | ||
| 11 | +set(CMAKE_SKIP_BUILD_RPATH FALSE) | ||
| 12 | +set(BUILD_RPATH_USE_ORIGIN TRUE) | ||
| 13 | +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) | ||
| 14 | + | ||
| 15 | +if(NOT APPLE) | ||
| 16 | + set(CPPINYIN_RPATH_ORIGIN "$ORIGIN") | ||
| 17 | +else() | ||
| 18 | + set(CPPINYIN_RPATH_ORIGIN "@loader_path") | ||
| 19 | +endif() | ||
| 20 | + | ||
| 21 | +set(CMAKE_INSTALL_RPATH ${CPPINYIN_RPATH_ORIGIN}) | ||
| 22 | +set(CMAKE_BUILD_RPATH ${CPPINYIN_RPATH_ORIGIN}) | ||
| 23 | + | ||
| 24 | +option(CPPINYIN_ENABLE_TESTS "Whether to build tests" OFF) | ||
| 25 | +option(CPPINYIN_BUILD_PYTHON "Whether to build Python" OFF) | ||
| 26 | +option(BUILD_SHARED_LIBS "Whether to build shared libraries" ON) | ||
| 27 | + | ||
| 28 | +if(NOT CMAKE_BUILD_TYPE) | ||
| 29 | + message(STATUS "No CMAKE_BUILD_TYPE given, default to Release") | ||
| 30 | + set(CMAKE_BUILD_TYPE Release) | ||
| 31 | +endif() | ||
| 32 | + | ||
| 33 | +list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) | ||
| 34 | +list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) | ||
| 35 | + | ||
| 36 | +include(CheckCXXCompilerFlag) | ||
| 37 | +if(NOT WIN32) | ||
| 38 | + check_cxx_compiler_flag("-std=c++14" CPPINYIN_COMPILER_SUPPORTS_CXX14) | ||
| 39 | +else() | ||
| 40 | + # windows x86 or x86_64 | ||
| 41 | + check_cxx_compiler_flag("/std:c++14" CPPINYIN_COMPILER_SUPPORTS_CXX14) | ||
| 42 | +endif() | ||
| 43 | +if(NOT CPPINYIN_COMPILER_SUPPORTS_CXX14) | ||
| 44 | + message(FATAL_ERROR " | ||
| 45 | + cppinyin requires a compiler supporting at least C++14. | ||
| 46 | + If you are using GCC, please upgrade it to at least version 7.0. | ||
| 47 | + If you are using Clang, please upgrade it to at least version 3.4.") | ||
| 48 | +endif() | ||
| 49 | + | ||
| 50 | +if(NOT CMAKE_CXX_STANDARD) | ||
| 51 | + set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") | ||
| 52 | +endif() | ||
| 53 | +set(CMAKE_CXX_EXTENSIONS OFF) | ||
| 54 | +message(STATUS "C++ Standard version: ${CMAKE_CXX_STANDARD}") | ||
| 55 | + | ||
| 56 | +if(CPPINYIN_BUILD_PYTHON) | ||
| 57 | + include(pybind11) | ||
| 58 | +endif() | ||
| 59 | + | ||
| 60 | +include_directories(${CMAKE_SOURCE_DIR}) | ||
| 61 | + | ||
| 62 | +if(WIN32) | ||
| 63 | + # disable various warnings for MSVC | ||
| 64 | + # 4244: 'initializing': conversion from 'float' to 'int32_t', | ||
| 65 | + # 4267: 'argument': conversion from 'size_t' to 'uint32_t', possible loss of data | ||
| 66 | + set(disabled_warnings | ||
| 67 | + /wd4244 | ||
| 68 | + /wd4267 | ||
| 69 | + ) | ||
| 70 | + message(STATUS "Disabled warnings: ${disabled_warnings}") | ||
| 71 | + foreach(w IN LISTS disabled_warnings) | ||
| 72 | + string(APPEND CMAKE_CXX_FLAGS " ${w} ") | ||
| 73 | + endforeach() | ||
| 74 | +endif() | ||
| 75 | + | ||
| 76 | +if(CPPINYIN_ENABLE_TESTS) | ||
| 77 | + include(googletest) | ||
| 78 | + enable_testing() | ||
| 79 | +endif() | ||
| 80 | + | ||
| 81 | +add_subdirectory(cppinyin) |
| @@ -52,7 +52,7 @@ function process_windows_x64() { | @@ -52,7 +52,7 @@ function process_windows_x64() { | ||
| 52 | cd t | 52 | cd t |
| 53 | curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-wheels/resolve/main/cpu/$SHERPA_ONNX_VERSION/$windows_x64_wheel_filename | 53 | curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-wheels/resolve/main/cpu/$SHERPA_ONNX_VERSION/$windows_x64_wheel_filename |
| 54 | unzip $windows_x64_wheel_filename | 54 | unzip $windows_x64_wheel_filename |
| 55 | - cp -v sherpa_onnx-${SHERPA_ONNX_VERSION}.data/data/bin/*.dll ../windows | 55 | + cp -v sherpa_onnx/lib/*.dll ../windows |
| 56 | cd .. | 56 | cd .. |
| 57 | rm -rf t | 57 | rm -rf t |
| 58 | } | 58 | } |
| @@ -77,6 +77,21 @@ namespace SherpaOnnx | @@ -77,6 +77,21 @@ namespace SherpaOnnx | ||
| 77 | } | 77 | } |
| 78 | } | 78 | } |
| 79 | 79 | ||
| 80 | + unsafe | ||
| 81 | + { | ||
| 82 | + if (impl.Durations != IntPtr.Zero) | ||
| 83 | + { | ||
| 84 | + float *d = (float*)impl.Durations; | ||
| 85 | + _durations = new float[impl.Count]; | ||
| 86 | + fixed (float* f = _durations) | ||
| 87 | + { | ||
| 88 | + for (int k = 0; k < impl.Count; k++) | ||
| 89 | + { | ||
| 90 | + f[k] = d[k]; | ||
| 91 | + } | ||
| 92 | + } | ||
| 93 | + } | ||
| 94 | + } | ||
| 80 | } | 95 | } |
| 81 | 96 | ||
| 82 | [StructLayout(LayoutKind.Sequential)] | 97 | [StructLayout(LayoutKind.Sequential)] |
| @@ -86,6 +101,7 @@ namespace SherpaOnnx | @@ -86,6 +101,7 @@ namespace SherpaOnnx | ||
| 86 | public IntPtr Timestamps; | 101 | public IntPtr Timestamps; |
| 87 | public int Count; | 102 | public int Count; |
| 88 | public IntPtr Tokens; | 103 | public IntPtr Tokens; |
| 104 | + public IntPtr Durations; | ||
| 89 | } | 105 | } |
| 90 | 106 | ||
| 91 | private String _text; | 107 | private String _text; |
| @@ -96,5 +112,8 @@ namespace SherpaOnnx | @@ -96,5 +112,8 @@ namespace SherpaOnnx | ||
| 96 | 112 | ||
| 97 | private float[] _timestamps; | 113 | private float[] _timestamps; |
| 98 | public float[] Timestamps => _timestamps; | 114 | public float[] Timestamps => _timestamps; |
| 115 | + | ||
| 116 | + private float[] _durations; | ||
| 117 | + public float[] Durations => _durations; | ||
| 99 | } | 118 | } |
| 100 | } | 119 | } |
| @@ -134,7 +134,7 @@ if [ ! -f $src_dir/windows-x64/sherpa-onnx-c-api.dll ]; then | @@ -134,7 +134,7 @@ if [ ! -f $src_dir/windows-x64/sherpa-onnx-c-api.dll ]; then | ||
| 134 | curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-wheels/resolve/main/cpu/$SHERPA_ONNX_VERSION/$windows_x64_wheel_filename | 134 | curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-wheels/resolve/main/cpu/$SHERPA_ONNX_VERSION/$windows_x64_wheel_filename |
| 135 | fi | 135 | fi |
| 136 | unzip $windows_x64_wheel_filename | 136 | unzip $windows_x64_wheel_filename |
| 137 | - cp -v sherpa_onnx-${SHERPA_ONNX_VERSION}.data/data/bin/*.dll ../ | 137 | + cp -v sherpa_onnx/lib/*.dll ../ |
| 138 | cd .. | 138 | cd .. |
| 139 | 139 | ||
| 140 | rm -rf wheel | 140 | rm -rf wheel |
| @@ -141,7 +141,7 @@ function windows() { | @@ -141,7 +141,7 @@ function windows() { | ||
| 141 | wget -q https://huggingface.co/csukuangfj/sherpa-onnx-wheels/resolve/main/cpu/$SHERPA_ONNX_VERSION/sherpa_onnx_core-${SHERPA_ONNX_VERSION}-py3-none-win_amd64.whl | 141 | wget -q https://huggingface.co/csukuangfj/sherpa-onnx-wheels/resolve/main/cpu/$SHERPA_ONNX_VERSION/sherpa_onnx_core-${SHERPA_ONNX_VERSION}-py3-none-win_amd64.whl |
| 142 | unzip ./sherpa_onnx_core-${SHERPA_ONNX_VERSION}-py3-none-win_amd64.whl | 142 | unzip ./sherpa_onnx_core-${SHERPA_ONNX_VERSION}-py3-none-win_amd64.whl |
| 143 | 143 | ||
| 144 | - cp -v sherpa_onnx_core-${SHERPA_ONNX_VERSION}.data/data/Scripts/*.dll $dst | 144 | + cp -v sherpa_onnx/lib/*.dll $dst |
| 145 | 145 | ||
| 146 | cd .. | 146 | cd .. |
| 147 | rm -rf t | 147 | rm -rf t |
| @@ -153,7 +153,7 @@ function windows() { | @@ -153,7 +153,7 @@ function windows() { | ||
| 153 | wget -q https://huggingface.co/csukuangfj/sherpa-onnx-wheels/resolve/main/cpu/$SHERPA_ONNX_VERSION/sherpa_onnx_core-${SHERPA_ONNX_VERSION}-py3-none-win32.whl | 153 | wget -q https://huggingface.co/csukuangfj/sherpa-onnx-wheels/resolve/main/cpu/$SHERPA_ONNX_VERSION/sherpa_onnx_core-${SHERPA_ONNX_VERSION}-py3-none-win32.whl |
| 154 | unzip ./sherpa_onnx_core-${SHERPA_ONNX_VERSION}-py3-none-win32.whl | 154 | unzip ./sherpa_onnx_core-${SHERPA_ONNX_VERSION}-py3-none-win32.whl |
| 155 | 155 | ||
| 156 | - cp -v sherpa_onnx_core-${SHERPA_ONNX_VERSION}.data/data/Scripts/*.dll $dst | 156 | + cp -v sherpa_onnx/lib/*.dll $dst |
| 157 | 157 | ||
| 158 | cd .. | 158 | cd .. |
| 159 | rm -rf t | 159 | rm -rf t |
| @@ -614,10 +614,6 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { | @@ -614,10 +614,6 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { | ||
| 614 | // It is NULL if the model does not support timestamps | 614 | // It is NULL if the model does not support timestamps |
| 615 | float *timestamps; | 615 | float *timestamps; |
| 616 | 616 | ||
| 617 | - // Pointer to continuous memory which holds durations (in seconds) for each token | ||
| 618 | - // It is NULL if the model does not support durations | ||
| 619 | - float *durations; | ||
| 620 | - | ||
| 621 | // number of entries in timestamps | 617 | // number of entries in timestamps |
| 622 | int32_t count; | 618 | int32_t count; |
| 623 | 619 | ||
| @@ -651,6 +647,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { | @@ -651,6 +647,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { | ||
| 651 | 647 | ||
| 652 | // return event. | 648 | // return event. |
| 653 | const char *event; | 649 | const char *event; |
| 650 | + | ||
| 651 | + // Pointer to continuous memory which holds durations (in seconds) for each | ||
| 652 | + // token. It is NULL if the model does not support durations | ||
| 653 | + float *durations; | ||
| 654 | } SherpaOnnxOfflineRecognizerResult; | 654 | } SherpaOnnxOfflineRecognizerResult; |
| 655 | 655 | ||
| 656 | /// Get the result of the offline stream. | 656 | /// Get the result of the offline stream. |
| @@ -356,22 +356,28 @@ OfflineRecognizerResult OfflineRecognizer::GetResult( | @@ -356,22 +356,28 @@ OfflineRecognizerResult OfflineRecognizer::GetResult( | ||
| 356 | ans.event = r->event ? r->event : ""; | 356 | ans.event = r->event ? r->event : ""; |
| 357 | } | 357 | } |
| 358 | 358 | ||
| 359 | + if (r->durations) { | ||
| 360 | + ans.durations.resize(r->count); | ||
| 361 | + std::copy(r->durations, r->durations + r->count, ans.durations.data()); | ||
| 362 | + } | ||
| 363 | + | ||
| 359 | SherpaOnnxDestroyOfflineRecognizerResult(r); | 364 | SherpaOnnxDestroyOfflineRecognizerResult(r); |
| 360 | 365 | ||
| 361 | return ans; | 366 | return ans; |
| 362 | } | 367 | } |
| 363 | 368 | ||
| 364 | -std::shared_ptr<OfflineRecognizerResult> OfflineRecognizer::GetResultPtr(const OfflineStream *s) const | ||
| 365 | -{ | 369 | +std::shared_ptr<OfflineRecognizerResult> OfflineRecognizer::GetResultPtr( |
| 370 | + const OfflineStream *s) const { | ||
| 366 | auto r = SherpaOnnxGetOfflineStreamResult(s->Get()); | 371 | auto r = SherpaOnnxGetOfflineStreamResult(s->Get()); |
| 367 | 372 | ||
| 368 | - OfflineRecognizerResult* ans = new OfflineRecognizerResult; | 373 | + OfflineRecognizerResult *ans = new OfflineRecognizerResult; |
| 369 | if (r) { | 374 | if (r) { |
| 370 | ans->text = r->text; | 375 | ans->text = r->text; |
| 371 | 376 | ||
| 372 | if (r->timestamps) { | 377 | if (r->timestamps) { |
| 373 | ans->timestamps.resize(r->count); | 378 | ans->timestamps.resize(r->count); |
| 374 | - std::copy(r->timestamps, r->timestamps + r->count, ans->timestamps.data()); | 379 | + std::copy(r->timestamps, r->timestamps + r->count, |
| 380 | + ans->timestamps.data()); | ||
| 375 | } | 381 | } |
| 376 | 382 | ||
| 377 | ans->tokens.resize(r->count); | 383 | ans->tokens.resize(r->count); |
| @@ -484,8 +490,7 @@ std::shared_ptr<GeneratedAudio> OfflineTts::Generate2( | @@ -484,8 +490,7 @@ std::shared_ptr<GeneratedAudio> OfflineTts::Generate2( | ||
| 484 | ans->samples = std::move(audio.samples); | 490 | ans->samples = std::move(audio.samples); |
| 485 | ans->sample_rate = audio.sample_rate; | 491 | ans->sample_rate = audio.sample_rate; |
| 486 | 492 | ||
| 487 | - return std::shared_ptr<GeneratedAudio>(ans, | ||
| 488 | - [](GeneratedAudio *p) { delete p; }); | 493 | + return std::shared_ptr<GeneratedAudio>(ans); |
| 489 | } | 494 | } |
| 490 | 495 | ||
| 491 | KeywordSpotter KeywordSpotter::Create(const KeywordSpotterConfig &config) { | 496 | KeywordSpotter KeywordSpotter::Create(const KeywordSpotterConfig &config) { |
| @@ -594,7 +599,6 @@ KeywordResult KeywordSpotter::GetResult(const OnlineStream *s) const { | @@ -594,7 +599,6 @@ KeywordResult KeywordSpotter::GetResult(const OnlineStream *s) const { | ||
| 594 | return ans; | 599 | return ans; |
| 595 | } | 600 | } |
| 596 | 601 | ||
| 597 | - | ||
| 598 | void KeywordSpotter::Reset(const OnlineStream *s) const { | 602 | void KeywordSpotter::Reset(const OnlineStream *s) const { |
| 599 | SherpaOnnxResetKeywordStream(p_, s->Get()); | 603 | SherpaOnnxResetKeywordStream(p_, s->Get()); |
| 600 | } | 604 | } |
| @@ -753,11 +757,10 @@ SpeechSegment VoiceActivityDetector::Front() const { | @@ -753,11 +757,10 @@ SpeechSegment VoiceActivityDetector::Front() const { | ||
| 753 | return segment; | 757 | return segment; |
| 754 | } | 758 | } |
| 755 | 759 | ||
| 756 | -std::shared_ptr<SpeechSegment> VoiceActivityDetector::FrontPtr() const | ||
| 757 | -{ | 760 | +std::shared_ptr<SpeechSegment> VoiceActivityDetector::FrontPtr() const { |
| 758 | auto f = SherpaOnnxVoiceActivityDetectorFront(p_); | 761 | auto f = SherpaOnnxVoiceActivityDetectorFront(p_); |
| 759 | 762 | ||
| 760 | - SpeechSegment* segment = new SpeechSegment; | 763 | + SpeechSegment *segment = new SpeechSegment; |
| 761 | segment->start = f->start; | 764 | segment->start = f->start; |
| 762 | segment->samples = std::vector<float>{f->samples, f->samples + f->n}; | 765 | segment->samples = std::vector<float>{f->samples, f->samples + f->n}; |
| 763 | 766 | ||
| @@ -824,7 +827,8 @@ bool FileExists(const std::string &filename) { | @@ -824,7 +827,8 @@ bool FileExists(const std::string &filename) { | ||
| 824 | // ============================================================ | 827 | // ============================================================ |
| 825 | // For Offline Punctuation | 828 | // For Offline Punctuation |
| 826 | // ============================================================ | 829 | // ============================================================ |
| 827 | -OfflinePunctuation OfflinePunctuation::Create(const OfflinePunctuationConfig &config) { | 830 | +OfflinePunctuation OfflinePunctuation::Create( |
| 831 | + const OfflinePunctuationConfig &config) { | ||
| 828 | struct SherpaOnnxOfflinePunctuationConfig c; | 832 | struct SherpaOnnxOfflinePunctuationConfig c; |
| 829 | memset(&c, 0, sizeof(c)); | 833 | memset(&c, 0, sizeof(c)); |
| 830 | c.model.ct_transformer = config.model.ct_transformer.c_str(); | 834 | c.model.ct_transformer = config.model.ct_transformer.c_str(); |
| @@ -832,12 +836,13 @@ OfflinePunctuation OfflinePunctuation::Create(const OfflinePunctuationConfig &co | @@ -832,12 +836,13 @@ OfflinePunctuation OfflinePunctuation::Create(const OfflinePunctuationConfig &co | ||
| 832 | c.model.debug = config.model.debug; | 836 | c.model.debug = config.model.debug; |
| 833 | c.model.provider = config.model.provider.c_str(); | 837 | c.model.provider = config.model.provider.c_str(); |
| 834 | 838 | ||
| 835 | - const SherpaOnnxOfflinePunctuation *punct = SherpaOnnxCreateOfflinePunctuation(&c); | 839 | + const SherpaOnnxOfflinePunctuation *punct = |
| 840 | + SherpaOnnxCreateOfflinePunctuation(&c); | ||
| 836 | return OfflinePunctuation(punct); | 841 | return OfflinePunctuation(punct); |
| 837 | } | 842 | } |
| 838 | 843 | ||
| 839 | OfflinePunctuation::OfflinePunctuation(const SherpaOnnxOfflinePunctuation *p) | 844 | OfflinePunctuation::OfflinePunctuation(const SherpaOnnxOfflinePunctuation *p) |
| 840 | - : MoveOnly<OfflinePunctuation, SherpaOnnxOfflinePunctuation>(p) {} | 845 | + : MoveOnly<OfflinePunctuation, SherpaOnnxOfflinePunctuation>(p) {} |
| 841 | 846 | ||
| 842 | void OfflinePunctuation::Destroy(const SherpaOnnxOfflinePunctuation *p) const { | 847 | void OfflinePunctuation::Destroy(const SherpaOnnxOfflinePunctuation *p) const { |
| 843 | SherpaOnnxDestroyOfflinePunctuation(p); | 848 | SherpaOnnxDestroyOfflinePunctuation(p); |
| @@ -319,6 +319,9 @@ struct SHERPA_ONNX_API OfflineRecognizerResult { | @@ -319,6 +319,9 @@ struct SHERPA_ONNX_API OfflineRecognizerResult { | ||
| 319 | std::string lang; | 319 | std::string lang; |
| 320 | std::string emotion; | 320 | std::string emotion; |
| 321 | std::string event; | 321 | std::string event; |
| 322 | + | ||
| 323 | + // non-empty only for TDT models | ||
| 324 | + std::vector<float> durations; | ||
| 322 | }; | 325 | }; |
| 323 | 326 | ||
| 324 | class SHERPA_ONNX_API OfflineStream | 327 | class SHERPA_ONNX_API OfflineStream |
| @@ -349,8 +352,9 @@ class SHERPA_ONNX_API OfflineRecognizer | @@ -349,8 +352,9 @@ class SHERPA_ONNX_API OfflineRecognizer | ||
| 349 | 352 | ||
| 350 | OfflineRecognizerResult GetResult(const OfflineStream *s) const; | 353 | OfflineRecognizerResult GetResult(const OfflineStream *s) const; |
| 351 | 354 | ||
| 352 | - std::shared_ptr<OfflineRecognizerResult> GetResultPtr(const OfflineStream *s) const; | ||
| 353 | - | 355 | + std::shared_ptr<OfflineRecognizerResult> GetResultPtr( |
| 356 | + const OfflineStream *s) const; | ||
| 357 | + | ||
| 354 | void SetConfig(const OfflineRecognizerConfig &config) const; | 358 | void SetConfig(const OfflineRecognizerConfig &config) const; |
| 355 | 359 | ||
| 356 | private: | 360 | private: |
| @@ -61,7 +61,8 @@ public class OfflineRecognizer { | @@ -61,7 +61,8 @@ public class OfflineRecognizer { | ||
| 61 | String lang = (String) arr[3]; | 61 | String lang = (String) arr[3]; |
| 62 | String emotion = (String) arr[4]; | 62 | String emotion = (String) arr[4]; |
| 63 | String event = (String) arr[5]; | 63 | String event = (String) arr[5]; |
| 64 | - return new OfflineRecognizerResult(text, tokens, timestamps, lang, emotion, event); | 64 | + float[] durations = (float[]) arr[6]; |
| 65 | + return new OfflineRecognizerResult(text, tokens, timestamps, lang, emotion, event, durations); | ||
| 65 | } | 66 | } |
| 66 | 67 | ||
| 67 | private native void delete(long ptr); | 68 | private native void delete(long ptr); |
| @@ -9,14 +9,16 @@ public class OfflineRecognizerResult { | @@ -9,14 +9,16 @@ public class OfflineRecognizerResult { | ||
| 9 | private final String lang; | 9 | private final String lang; |
| 10 | private final String emotion; | 10 | private final String emotion; |
| 11 | private final String event; | 11 | private final String event; |
| 12 | + private final float[] durations; | ||
| 12 | 13 | ||
| 13 | - public OfflineRecognizerResult(String text, String[] tokens, float[] timestamps, String lang, String emotion, String event) { | 14 | + public OfflineRecognizerResult(String text, String[] tokens, float[] timestamps, String lang, String emotion, String event, float[] durations) { |
| 14 | this.text = text; | 15 | this.text = text; |
| 15 | this.tokens = tokens; | 16 | this.tokens = tokens; |
| 16 | this.timestamps = timestamps; | 17 | this.timestamps = timestamps; |
| 17 | this.lang = lang; | 18 | this.lang = lang; |
| 18 | this.emotion = emotion; | 19 | this.emotion = emotion; |
| 19 | this.event = event; | 20 | this.event = event; |
| 21 | + this.durations = durations; | ||
| 20 | } | 22 | } |
| 21 | 23 | ||
| 22 | public String getText() { | 24 | public String getText() { |
| @@ -42,4 +44,8 @@ public class OfflineRecognizerResult { | @@ -42,4 +44,8 @@ public class OfflineRecognizerResult { | ||
| 42 | public String getEvent() { | 44 | public String getEvent() { |
| 43 | return event; | 45 | return event; |
| 44 | } | 46 | } |
| 47 | + | ||
| 48 | + public float[] getDurations() { | ||
| 49 | + return durations; | ||
| 50 | + } | ||
| 45 | } | 51 | } |
| @@ -509,8 +509,9 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env, | @@ -509,8 +509,9 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env, | ||
| 509 | // [3]: lang, jstring | 509 | // [3]: lang, jstring |
| 510 | // [4]: emotion, jstring | 510 | // [4]: emotion, jstring |
| 511 | // [5]: event, jstring | 511 | // [5]: event, jstring |
| 512 | + // [6]: durations, array of float | ||
| 512 | jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( | 513 | jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( |
| 513 | - 6, env->FindClass("java/lang/Object"), nullptr); | 514 | + 7, env->FindClass("java/lang/Object"), nullptr); |
| 514 | 515 | ||
| 515 | jstring text = env->NewStringUTF(result.text.c_str()); | 516 | jstring text = env->NewStringUTF(result.text.c_str()); |
| 516 | env->SetObjectArrayElement(obj_arr, 0, text); | 517 | env->SetObjectArrayElement(obj_arr, 0, text); |
| @@ -543,5 +544,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env, | @@ -543,5 +544,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env, | ||
| 543 | env->SetObjectArrayElement(obj_arr, 5, | 544 | env->SetObjectArrayElement(obj_arr, 5, |
| 544 | env->NewStringUTF(result.event.c_str())); | 545 | env->NewStringUTF(result.event.c_str())); |
| 545 | 546 | ||
| 547 | + // [6]: durations, array of float | ||
| 548 | + jfloatArray durations_arr = env->NewFloatArray(result.durations.size()); | ||
| 549 | + env->SetFloatArrayRegion(durations_arr, 0, result.durations.size(), | ||
| 550 | + result.durations.data()); | ||
| 551 | + | ||
| 552 | + env->SetObjectArrayElement(obj_arr, 6, durations_arr); | ||
| 553 | + | ||
| 546 | return obj_arr; | 554 | return obj_arr; |
| 547 | } | 555 | } |
| @@ -9,6 +9,9 @@ data class OfflineRecognizerResult( | @@ -9,6 +9,9 @@ data class OfflineRecognizerResult( | ||
| 9 | val lang: String, | 9 | val lang: String, |
| 10 | val emotion: String, | 10 | val emotion: String, |
| 11 | val event: String, | 11 | val event: String, |
| 12 | + | ||
| 13 | + // valid only for TDT models | ||
| 14 | + val durations: FloatArray, | ||
| 12 | ) | 15 | ) |
| 13 | 16 | ||
| 14 | data class OfflineTransducerModelConfig( | 17 | data class OfflineTransducerModelConfig( |
| @@ -139,13 +142,15 @@ class OfflineRecognizer( | @@ -139,13 +142,15 @@ class OfflineRecognizer( | ||
| 139 | val lang = objArray[3] as String | 142 | val lang = objArray[3] as String |
| 140 | val emotion = objArray[4] as String | 143 | val emotion = objArray[4] as String |
| 141 | val event = objArray[5] as String | 144 | val event = objArray[5] as String |
| 145 | + val durations = objArray[6] as FloatArray | ||
| 142 | return OfflineRecognizerResult( | 146 | return OfflineRecognizerResult( |
| 143 | text = text, | 147 | text = text, |
| 144 | tokens = tokens, | 148 | tokens = tokens, |
| 145 | timestamps = timestamps, | 149 | timestamps = timestamps, |
| 146 | lang = lang, | 150 | lang = lang, |
| 147 | emotion = emotion, | 151 | emotion = emotion, |
| 148 | - event = event | 152 | + event = event, |
| 153 | + durations = durations, | ||
| 149 | ) | 154 | ) |
| 150 | } | 155 | } |
| 151 | 156 |
| @@ -539,6 +539,11 @@ class SherpaOnnxOfflineRecongitionResult { | @@ -539,6 +539,11 @@ class SherpaOnnxOfflineRecongitionResult { | ||
| 539 | return (0..<result.pointee.count).map { p[Int($0)] } | 539 | return (0..<result.pointee.count).map { p[Int($0)] } |
| 540 | }() | 540 | }() |
| 541 | 541 | ||
| 542 | + private lazy var _durations: [Float] = { | ||
| 543 | + guard let p = result.pointee.durations else { return [] } | ||
| 544 | + return (0..<result.pointee.count).map { p[Int($0)] } | ||
| 545 | + }() | ||
| 546 | + | ||
| 542 | private lazy var _lang: String = { | 547 | private lazy var _lang: String = { |
| 543 | guard let cstr = result.pointee.lang else { return "" } | 548 | guard let cstr = result.pointee.lang else { return "" } |
| 544 | return String(cString: cstr) | 549 | return String(cString: cstr) |
| @@ -561,6 +566,9 @@ class SherpaOnnxOfflineRecongitionResult { | @@ -561,6 +566,9 @@ class SherpaOnnxOfflineRecongitionResult { | ||
| 561 | var count: Int { Int(result.pointee.count) } | 566 | var count: Int { Int(result.pointee.count) } |
| 562 | var timestamps: [Float] { _timestamps } | 567 | var timestamps: [Float] { _timestamps } |
| 563 | 568 | ||
| 569 | + // Non-empty for TDT models. Empty for all non-TDT models | ||
| 570 | + var durations: [Float] { _durations } | ||
| 571 | + | ||
| 564 | // For SenseVoice models, it can be zh, en, ja, yue, ko | 572 | // For SenseVoice models, it can be zh, en, ja, yue, ko |
| 565 | // where zh is for Chinese | 573 | // where zh is for Chinese |
| 566 | // en is for English | 574 | // en is for English |
-
请 注册 或 登录 后发表评论