正在显示
11 个修改的文件
包含
170 行增加
和
19 行删除
| @@ -126,6 +126,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET) | @@ -126,6 +126,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET) | ||
| 126 | include(asio) | 126 | include(asio) |
| 127 | endif() | 127 | endif() |
| 128 | 128 | ||
| 129 | +include(json) | ||
| 130 | + | ||
| 129 | add_subdirectory(sherpa-onnx) | 131 | add_subdirectory(sherpa-onnx) |
| 130 | 132 | ||
| 131 | if(SHERPA_ONNX_ENABLE_C_API) | 133 | if(SHERPA_ONNX_ENABLE_C_API) |
cmake/json.cmake
0 → 100644
| 1 | +function(download_json) | ||
| 2 | + include(FetchContent) | ||
| 3 | + | ||
| 4 | + set(json_URL "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz") | ||
| 5 | + set(json_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/json-3.11.2.tar.gz") | ||
| 6 | + set(json_HASH "SHA256=d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273") | ||
| 7 | + | ||
| 8 | + # If you don't have access to the Internet, | ||
| 9 | + # please pre-download json | ||
| 10 | + set(possible_file_locations | ||
| 11 | + $ENV{HOME}/Downloads/json-3.11.2.tar.gz | ||
| 12 | + ${PROJECT_SOURCE_DIR}/json-3.11.2.tar.gz | ||
| 13 | + ${PROJECT_BINARY_DIR}/json-3.11.2.tar.gz | ||
| 14 | + /tmp/json-3.11.2.tar.gz | ||
| 15 | + /star-fj/fangjun/download/github/json-3.11.2.tar.gz | ||
| 16 | + ) | ||
| 17 | + | ||
| 18 | + foreach(f IN LISTS possible_file_locations) | ||
| 19 | + if(EXISTS ${f}) | ||
| 20 | + set(json_URL "${f}") | ||
| 21 | + file(TO_CMAKE_PATH "${json_URL}" json_URL) | ||
| 22 | + set(json_URL2) | ||
| 23 | + break() | ||
| 24 | + endif() | ||
| 25 | + endforeach() | ||
| 26 | + | ||
| 27 | + FetchContent_Declare(json | ||
| 28 | + URL | ||
| 29 | + ${json_URL} | ||
| 30 | + ${json_URL2} | ||
| 31 | + URL_HASH ${json_HASH} | ||
| 32 | + ) | ||
| 33 | + | ||
| 34 | + FetchContent_GetProperties(json) | ||
| 35 | + if(NOT json_POPULATED) | ||
| 36 | + message(STATUS "Downloading json from ${json_URL}") | ||
| 37 | + FetchContent_Populate(json) | ||
| 38 | + endif() | ||
| 39 | + message(STATUS "json is downloaded to ${json_SOURCE_DIR}") | ||
| 40 | + include_directories(${json_SOURCE_DIR}/include) | ||
| 41 | + # Use #include "nlohmann/json.hpp" | ||
| 42 | +endfunction() | ||
| 43 | + | ||
| 44 | +download_json() |
| @@ -8,11 +8,13 @@ | @@ -8,11 +8,13 @@ | ||
| 8 | #include <assert.h> | 8 | #include <assert.h> |
| 9 | 9 | ||
| 10 | #include <algorithm> | 10 | #include <algorithm> |
| 11 | +#include <iomanip> | ||
| 11 | #include <memory> | 12 | #include <memory> |
| 12 | #include <sstream> | 13 | #include <sstream> |
| 13 | #include <utility> | 14 | #include <utility> |
| 14 | #include <vector> | 15 | #include <vector> |
| 15 | 16 | ||
| 17 | +#include "nlohmann/json.hpp" | ||
| 16 | #include "sherpa-onnx/csrc/file-utils.h" | 18 | #include "sherpa-onnx/csrc/file-utils.h" |
| 17 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" | 19 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" |
| 18 | #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" | 20 | #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" |
| @@ -22,16 +24,56 @@ | @@ -22,16 +24,56 @@ | ||
| 22 | 24 | ||
| 23 | namespace sherpa_onnx { | 25 | namespace sherpa_onnx { |
| 24 | 26 | ||
| 27 | +std::string OnlineRecognizerResult::AsJsonString() const { | ||
| 28 | + using json = nlohmann::json; | ||
| 29 | + json j; | ||
| 30 | + j["text"] = text; | ||
| 31 | + j["tokens"] = tokens; | ||
| 32 | + j["start_time"] = start_time; | ||
| 33 | +#if 1 | ||
| 34 | + // This branch chooses number of decimal points to keep in | ||
| 35 | + // the return json string | ||
| 36 | + std::ostringstream os; | ||
| 37 | + os << "["; | ||
| 38 | + std::string sep = ""; | ||
| 39 | + for (auto t : timestamps) { | ||
| 40 | + os << sep << std::fixed << std::setprecision(2) << t; | ||
| 41 | + sep = ", "; | ||
| 42 | + } | ||
| 43 | + os << "]"; | ||
| 44 | + j["timestamps"] = os.str(); | ||
| 45 | +#else | ||
| 46 | + j["timestamps"] = timestamps; | ||
| 47 | +#endif | ||
| 48 | + | ||
| 49 | + j["segment"] = segment; | ||
| 50 | + j["is_final"] = is_final; | ||
| 51 | + | ||
| 52 | + return j.dump(); | ||
| 53 | +} | ||
| 54 | + | ||
| 25 | static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | 55 | static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, |
| 26 | - const SymbolTable &sym_table) { | ||
| 27 | - std::string text; | ||
| 28 | - for (auto t : src.tokens) { | ||
| 29 | - text += sym_table[t]; | 56 | + const SymbolTable &sym_table, |
| 57 | + int32_t frame_shift_ms, | ||
| 58 | + int32_t subsampling_factor) { | ||
| 59 | + OnlineRecognizerResult r; | ||
| 60 | + r.tokens.reserve(src.tokens.size()); | ||
| 61 | + r.timestamps.reserve(src.tokens.size()); | ||
| 62 | + | ||
| 63 | + for (auto i : src.tokens) { | ||
| 64 | + auto sym = sym_table[i]; | ||
| 65 | + | ||
| 66 | + r.text.append(sym); | ||
| 67 | + r.tokens.push_back(std::move(sym)); | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; | ||
| 71 | + for (auto t : src.timestamps) { | ||
| 72 | + float time = frame_shift_s * t; | ||
| 73 | + r.timestamps.push_back(time); | ||
| 30 | } | 74 | } |
| 31 | 75 | ||
| 32 | - OnlineRecognizerResult ans; | ||
| 33 | - ans.text = std::move(text); | ||
| 34 | - return ans; | 76 | + return r; |
| 35 | } | 77 | } |
| 36 | 78 | ||
| 37 | void OnlineRecognizerConfig::Register(ParseOptions *po) { | 79 | void OnlineRecognizerConfig::Register(ParseOptions *po) { |
| @@ -169,7 +211,10 @@ class OnlineRecognizer::Impl { | @@ -169,7 +211,10 @@ class OnlineRecognizer::Impl { | ||
| 169 | OnlineTransducerDecoderResult decoder_result = s->GetResult(); | 211 | OnlineTransducerDecoderResult decoder_result = s->GetResult(); |
| 170 | decoder_->StripLeadingBlanks(&decoder_result); | 212 | decoder_->StripLeadingBlanks(&decoder_result); |
| 171 | 213 | ||
| 172 | - return Convert(decoder_result, sym_); | 214 | + // TODO(fangjun): Remember to change these constants if needed |
| 215 | + int32_t frame_shift_ms = 10; | ||
| 216 | + int32_t subsampling_factor = 4; | ||
| 217 | + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor); | ||
| 173 | } | 218 | } |
| 174 | 219 | ||
| 175 | bool IsEndpoint(OnlineStream *s) const { | 220 | bool IsEndpoint(OnlineStream *s) const { |
| @@ -7,6 +7,7 @@ | @@ -7,6 +7,7 @@ | ||
| 7 | 7 | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | #include <string> | 9 | #include <string> |
| 10 | +#include <vector> | ||
| 10 | 11 | ||
| 11 | #if __ANDROID_API__ >= 9 | 12 | #if __ANDROID_API__ >= 9 |
| 12 | #include "android/asset_manager.h" | 13 | #include "android/asset_manager.h" |
| @@ -22,10 +23,45 @@ | @@ -22,10 +23,45 @@ | ||
| 22 | namespace sherpa_onnx { | 23 | namespace sherpa_onnx { |
| 23 | 24 | ||
| 24 | struct OnlineRecognizerResult { | 25 | struct OnlineRecognizerResult { |
| 26 | + /// Recognition results. | ||
| 27 | + /// For English, it consists of space separated words. | ||
| 28 | + /// For Chinese, it consists of Chinese words without spaces. | ||
| 29 | + /// Example 1: "hello world" | ||
| 30 | + /// Example 2: "你好世界" | ||
| 25 | std::string text; | 31 | std::string text; |
| 26 | 32 | ||
| 27 | - // TODO(fangjun): Add a method to return a json string | ||
| 28 | - std::string ToString() const { return text; } | 33 | + /// Decoded results at the token level. |
| 34 | + /// For instance, for BPE-based models it consists of a list of BPE tokens. | ||
| 35 | + std::vector<std::string> tokens; | ||
| 36 | + | ||
| 37 | + /// timestamps.size() == tokens.size() | ||
| 38 | + /// timestamps[i] records the time in seconds when tokens[i] is decoded. | ||
| 39 | + std::vector<float> timestamps; | ||
| 40 | + | ||
| 41 | + /// ID of this segment | ||
| 42 | + /// When an endpoint is detected, it is incremented | ||
| 43 | + int32_t segment = 0; | ||
| 44 | + | ||
| 45 | + /// Starting frame of this segment. | ||
| 46 | + /// When an endpoint is detected, it will change | ||
| 47 | + float start_time = 0; | ||
| 48 | + | ||
| 49 | + /// True if this is the last segment. | ||
| 50 | + bool is_final = false; | ||
| 51 | + | ||
| 52 | + /** Return a json string. | ||
| 53 | + * | ||
| 54 | + * The returned string contains: | ||
| 55 | + * { | ||
| 56 | + * "text": "The recognition result", | ||
| 57 | + * "tokens": [x, x, x], | ||
| 58 | + * "timestamps": [x, x, x], | ||
| 59 | + * "segment": x, | ||
| 60 | + * "start_time": x, | ||
| 61 | + * "is_final": true|false | ||
| 62 | + * } | ||
| 63 | + */ | ||
| 64 | + std::string AsJsonString() const; | ||
| 29 | }; | 65 | }; |
| 30 | 66 | ||
| 31 | struct OnlineRecognizerConfig { | 67 | struct OnlineRecognizerConfig { |
| @@ -34,6 +34,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( | @@ -34,6 +34,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( | ||
| 34 | 34 | ||
| 35 | hyps = other.hyps; | 35 | hyps = other.hyps; |
| 36 | 36 | ||
| 37 | + frame_offset = other.frame_offset; | ||
| 38 | + timestamps = other.timestamps; | ||
| 39 | + | ||
| 37 | return *this; | 40 | return *this; |
| 38 | } | 41 | } |
| 39 | 42 | ||
| @@ -54,6 +57,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( | @@ -54,6 +57,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( | ||
| 54 | decoder_out = std::move(other.decoder_out); | 57 | decoder_out = std::move(other.decoder_out); |
| 55 | hyps = std::move(other.hyps); | 58 | hyps = std::move(other.hyps); |
| 56 | 59 | ||
| 60 | + frame_offset = other.frame_offset; | ||
| 61 | + timestamps = std::move(other.timestamps); | ||
| 62 | + | ||
| 57 | return *this; | 63 | return *this; |
| 58 | } | 64 | } |
| 59 | 65 |
| @@ -13,12 +13,18 @@ | @@ -13,12 +13,18 @@ | ||
| 13 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 14 | 14 | ||
| 15 | struct OnlineTransducerDecoderResult { | 15 | struct OnlineTransducerDecoderResult { |
| 16 | + /// Number of frames after subsampling we have decoded so far | ||
| 17 | + int32_t frame_offset = 0; | ||
| 18 | + | ||
| 16 | /// The decoded token IDs so far | 19 | /// The decoded token IDs so far |
| 17 | std::vector<int64_t> tokens; | 20 | std::vector<int64_t> tokens; |
| 18 | 21 | ||
| 19 | /// number of trailing blank frames decoded so far | 22 | /// number of trailing blank frames decoded so far |
| 20 | int32_t num_trailing_blanks = 0; | 23 | int32_t num_trailing_blanks = 0; |
| 21 | 24 | ||
| 25 | + /// timestamps[i] contains the output frame index where tokens[i] is decoded. | ||
| 26 | + std::vector<int32_t> timestamps; | ||
| 27 | + | ||
| 22 | // Cache decoder_out for endpointing | 28 | // Cache decoder_out for endpointing |
| 23 | Ort::Value decoder_out; | 29 | Ort::Value decoder_out; |
| 24 | 30 |
| @@ -102,16 +102,18 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -102,16 +102,18 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 102 | 102 | ||
| 103 | bool emitted = false; | 103 | bool emitted = false; |
| 104 | for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) { | 104 | for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) { |
| 105 | + auto &r = (*result)[i]; | ||
| 105 | auto y = static_cast<int32_t>(std::distance( | 106 | auto y = static_cast<int32_t>(std::distance( |
| 106 | static_cast<const float *>(p_logit), | 107 | static_cast<const float *>(p_logit), |
| 107 | std::max_element(static_cast<const float *>(p_logit), | 108 | std::max_element(static_cast<const float *>(p_logit), |
| 108 | static_cast<const float *>(p_logit) + vocab_size))); | 109 | static_cast<const float *>(p_logit) + vocab_size))); |
| 109 | if (y != 0) { | 110 | if (y != 0) { |
| 110 | emitted = true; | 111 | emitted = true; |
| 111 | - (*result)[i].tokens.push_back(y); | ||
| 112 | - (*result)[i].num_trailing_blanks = 0; | 112 | + r.tokens.push_back(y); |
| 113 | + r.timestamps.push_back(t + r.frame_offset); | ||
| 114 | + r.num_trailing_blanks = 0; | ||
| 113 | } else { | 115 | } else { |
| 114 | - ++(*result)[i].num_trailing_blanks; | 116 | + ++r.num_trailing_blanks; |
| 115 | } | 117 | } |
| 116 | } | 118 | } |
| 117 | if (emitted) { | 119 | if (emitted) { |
| @@ -121,6 +123,11 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -121,6 +123,11 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 121 | } | 123 | } |
| 122 | 124 | ||
| 123 | UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result); | 125 | UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result); |
| 126 | + | ||
| 127 | + // Update frame_offset | ||
| 128 | + for (auto &r : *result) { | ||
| 129 | + r.frame_offset += num_frames; | ||
| 130 | + } | ||
| 124 | } | 131 | } |
| 125 | 132 | ||
| 126 | } // namespace sherpa_onnx | 133 | } // namespace sherpa_onnx |
| @@ -87,6 +87,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks( | @@ -87,6 +87,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks( | ||
| 87 | 87 | ||
| 88 | std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end()); | 88 | std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end()); |
| 89 | r->tokens = std::move(tokens); | 89 | r->tokens = std::move(tokens); |
| 90 | + r->timestamps = std::move(hyp.timestamps); | ||
| 90 | r->num_trailing_blanks = hyp.num_trailing_blanks; | 91 | r->num_trailing_blanks = hyp.num_trailing_blanks; |
| 91 | } | 92 | } |
| 92 | 93 | ||
| @@ -148,6 +149,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -148,6 +149,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 148 | float *p_logit = logit.GetTensorMutableData<float>(); | 149 | float *p_logit = logit.GetTensorMutableData<float>(); |
| 149 | 150 | ||
| 150 | for (int32_t b = 0; b < batch_size; ++b) { | 151 | for (int32_t b = 0; b < batch_size; ++b) { |
| 152 | + int32_t frame_offset = (*result)[b].frame_offset; | ||
| 151 | int32_t start = hyps_num_split[b]; | 153 | int32_t start = hyps_num_split[b]; |
| 152 | int32_t end = hyps_num_split[b + 1]; | 154 | int32_t end = hyps_num_split[b + 1]; |
| 153 | LogSoftmax(p_logit, vocab_size, (end - start)); | 155 | LogSoftmax(p_logit, vocab_size, (end - start)); |
| @@ -162,6 +164,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -162,6 +164,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 162 | Hypothesis new_hyp = prev[hyp_index]; | 164 | Hypothesis new_hyp = prev[hyp_index]; |
| 163 | if (new_token != 0) { | 165 | if (new_token != 0) { |
| 164 | new_hyp.ys.push_back(new_token); | 166 | new_hyp.ys.push_back(new_token); |
| 167 | + new_hyp.timestamps.push_back(t + frame_offset); | ||
| 165 | new_hyp.num_trailing_blanks = 0; | 168 | new_hyp.num_trailing_blanks = 0; |
| 166 | } else { | 169 | } else { |
| 167 | ++new_hyp.num_trailing_blanks; | 170 | ++new_hyp.num_trailing_blanks; |
| @@ -177,10 +180,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -177,10 +180,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 177 | for (int32_t b = 0; b != batch_size; ++b) { | 180 | for (int32_t b = 0; b != batch_size; ++b) { |
| 178 | auto &hyps = cur[b]; | 181 | auto &hyps = cur[b]; |
| 179 | auto best_hyp = hyps.GetMostProbable(true); | 182 | auto best_hyp = hyps.GetMostProbable(true); |
| 183 | + auto &r = (*result)[b]; | ||
| 180 | 184 | ||
| 181 | - (*result)[b].hyps = std::move(hyps); | ||
| 182 | - (*result)[b].tokens = std::move(best_hyp.ys); | ||
| 183 | - (*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks; | 185 | + r.hyps = std::move(hyps); |
| 186 | + r.tokens = std::move(best_hyp.ys); | ||
| 187 | + r.num_trailing_blanks = best_hyp.num_trailing_blanks; | ||
| 188 | + r.frame_offset += num_frames; | ||
| 184 | } | 189 | } |
| 185 | } | 190 | } |
| 186 | 191 |
| @@ -196,7 +196,7 @@ void OnlineWebsocketDecoder::Decode() { | @@ -196,7 +196,7 @@ void OnlineWebsocketDecoder::Decode() { | ||
| 196 | auto result = recognizer_->GetResult(c->s.get()); | 196 | auto result = recognizer_->GetResult(c->s.get()); |
| 197 | 197 | ||
| 198 | asio::post(server_->GetConnectionContext(), | 198 | asio::post(server_->GetConnectionContext(), |
| 199 | - [this, hdl = c->hdl, str = result.ToString()]() { | 199 | + [this, hdl = c->hdl, str = result.AsJsonString()]() { |
| 200 | server_->Send(hdl, str); | 200 | server_->Send(hdl, str); |
| 201 | }); | 201 | }); |
| 202 | active_.erase(c->hdl); | 202 | active_.erase(c->hdl); |
| @@ -102,7 +102,7 @@ for a list of pre-trained models to download. | @@ -102,7 +102,7 @@ for a list of pre-trained models to download. | ||
| 102 | recognizer.DecodeStream(s.get()); | 102 | recognizer.DecodeStream(s.get()); |
| 103 | } | 103 | } |
| 104 | 104 | ||
| 105 | - std::string text = recognizer.GetResult(s.get()).text; | 105 | + std::string text = recognizer.GetResult(s.get()).AsJsonString(); |
| 106 | 106 | ||
| 107 | fprintf(stderr, "Done!\n"); | 107 | fprintf(stderr, "Done!\n"); |
| 108 | 108 |
| @@ -434,7 +434,7 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult( | @@ -434,7 +434,7 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult( | ||
| 434 | sherpa_onnx::OnlineStream *s = | 434 | sherpa_onnx::OnlineStream *s = |
| 435 | reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr); | 435 | reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr); |
| 436 | sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s); | 436 | sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s); |
| 437 | - return env->NewStringUTF(result.ToString().c_str()); | 437 | + return env->NewStringUTF(result.text.c_str()); |
| 438 | } | 438 | } |
| 439 | 439 | ||
| 440 | SHERPA_ONNX_EXTERN_C | 440 | SHERPA_ONNX_EXTERN_C |
-
请 注册 或 登录 后发表评论