Fangjun Kuang
Committed by GitHub

Add C API for punctuation (#768)

... ... @@ -11,9 +11,22 @@ log() {
echo "SLID_EXE is $SLID_EXE"
echo "SID_EXE is $SID_EXE"
echo "AT_EXE is $AT_EXE"
echo "PUNCT_EXE is $PUNCT_EXE"
echo "PATH: $PATH"
log "------------------------------------------------------------"
log "Test adding punctuations "
log "------------------------------------------------------------"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
ls -lh
tar xf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
ls -lh sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
$PUNCT_EXE
rm -rf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
log "------------------------------------------------------------"
log "Test audio tagging "
log "------------------------------------------------------------"
... ...
... ... @@ -126,7 +126,7 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: build/bin/*
path: install/*
- name: Test offline punctuation
shell: bash
... ... @@ -143,6 +143,7 @@ jobs:
export SLID_EXE=spoken-language-identification-c-api
export SID_EXE=speaker-identification-c-api
export AT_EXE=audio-tagging-c-api
export PUNCT_EXE=add-punctuation-c-api
.github/scripts/test-c-api.sh
... ...
... ... @@ -122,6 +122,7 @@ jobs:
export SLID_EXE=spoken-language-identification-c-api
export SID_EXE=speaker-identification-c-api
export AT_EXE=audio-tagging-c-api
export PUNCT_EXE=add-punctuation-c-api
.github/scripts/test-c-api.sh
... ...
... ... @@ -89,6 +89,7 @@ jobs:
export SLID_EXE=spoken-language-identification-c-api.exe
export SID_EXE=speaker-identification-c-api.exe
export AT_EXE=audio-tagging-c-api.exe
export PUNCT_EXE=add-punctuation-c-api.exe
.github/scripts/test-c-api.sh
... ...
... ... @@ -89,6 +89,7 @@ jobs:
export SLID_EXE=spoken-language-identification-c-api.exe
export SID_EXE=speaker-identification-c-api.exe
export AT_EXE=audio-tagging-c-api.exe
export PUNCT_EXE=add-punctuation-c-api.exe
.github/scripts/test-c-api.sh
... ...
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.9.19")
set(SHERPA_ONNX_VERSION "1.9.21")
# Disable warning about
#
... ...
... ... @@ -21,6 +21,9 @@ target_link_libraries(streaming-hlg-decode-file-c-api sherpa-onnx-c-api)
add_executable(audio-tagging-c-api audio-tagging-c-api.c)
target_link_libraries(audio-tagging-c-api sherpa-onnx-c-api)
add_executable(add-punctuation-c-api add-punctuation-c-api.c)
target_link_libraries(add-punctuation-c-api sherpa-onnx-c-api)
if(SHERPA_ONNX_HAS_ALSA)
add_subdirectory(./asr-microphone-example)
elseif((UNIX AND NOT APPLE) OR LINUX)
... ...
// c-api-examples/add-punctuation-c-api.c
//
// Copyright (c) 2024 Xiaomi Corporation
// We assume you have pre-downloaded the model files for testing
// from https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models
//
// An example is given below:
//
// clang-format off
//
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
// tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
// rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
//
// clang-format on
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "sherpa-onnx/c-api/c-api.h"
int32_t main() {
SherpaOnnxOfflinePunctuationConfig config;
memset(&config, 0, sizeof(config));
// clang-format off
config.model.ct_transformer = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx";
// clang-format on
config.model.num_threads = 1;
config.model.debug = 1;
config.model.provider = "cpu";
const SherpaOnnxOfflinePunctuation *punct =
SherpaOnnxCreateOfflinePunctuation(&config);
if (!punct) {
fprintf(stderr,
"Failed to create OfflinePunctuation. Please check your config");
return -1;
}
const char *texts[] = {
"这是一个测试你好吗How are you我很好thank you are you ok谢谢你",
"我们都是木头人不会说话不会动",
"The African blogosphere is rapidly expanding bringing more voices "
"online in the form of commentaries opinions analyses rants and poetry",
};
int32_t n = sizeof(texts) / sizeof(const char *);
fprintf(stderr, "n: %d\n", n);
fprintf(stderr, "--------------------\n");
for (int32_t i = 0; i != n; ++i) {
const char *text_with_punct =
SherpaOfflinePunctuationAddPunct(punct, texts[i]);
fprintf(stderr, "Input text: %s\n", texts[i]);
fprintf(stderr, "Output text: %s\n", text_with_punct);
SherpaOfflinePunctuationFreeText(text_with_punct);
fprintf(stderr, "--------------------\n");
}
SherpaOnnxDestroyOfflinePunctuation(punct);
return 0;
};
... ...
#!/usr/bin/env python3
# Real-time speech recognition from a microphone with sherpa-onnx Python API
# with endpoint detection.
# This script uses a streaming paraformer
#
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#
# to download pre-trained models
import sys
from pathlib import Path
try:
import sounddevice as sd
except ImportError:
print("Please install sounddevice first. You can use")
print()
print(" pip install sounddevice")
print()
print("to install it")
sys.exit(-1)
import sherpa_onnx
def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html to download it"
)
def create_recognizer():
encoder = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx"
decoder = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx"
tokens = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt"
assert_file_exists(encoder)
assert_file_exists(decoder)
assert_file_exists(tokens)
recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
tokens=tokens,
encoder=encoder,
decoder=decoder,
num_threads=1,
sample_rate=16000,
feature_dim=80,
enable_endpoint_detection=True,
rule1_min_trailing_silence=2.4,
rule2_min_trailing_silence=1.2,
rule3_min_utterance_length=300, # it essentially disables this rule
)
return recognizer
def main():
devices = sd.query_devices()
if len(devices) == 0:
print("No microphone devices found")
sys.exit(0)
print(devices)
default_input_device_idx = sd.default.device[0]
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
recognizer = create_recognizer()
print("Started! Please speak")
# The model is using 16 kHz, we use 48 kHz here to demonstrate that
# sherpa-onnx will do resampling inside.
sample_rate = 48000
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
stream = recognizer.create_stream()
last_result = ""
segment_id = 0
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
while True:
samples, _ = s.read(samples_per_read) # a blocking read
samples = samples.reshape(-1)
stream.accept_waveform(sample_rate, samples)
while recognizer.is_ready(stream):
recognizer.decode_stream(stream)
is_endpoint = recognizer.is_endpoint(stream)
result = recognizer.get_result(stream)
if result and (last_result != result):
last_result = result
print("\r{}:{}".format(segment_id, result), end="", flush=True)
if is_endpoint:
if result:
print("\r{}:{}".format(segment_id, result), flush=True)
segment_id += 1
recognizer.reset(stream)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\nCaught Ctrl + C. Exiting")
... ...
... ... @@ -15,6 +15,7 @@
#include "sherpa-onnx/csrc/display.h"
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-punctuation.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
... ... @@ -1299,3 +1300,48 @@ void SherpaOnnxAudioTaggingFreeResults(
delete[] events;
}
struct SherpaOnnxOfflinePunctuation {
std::unique_ptr<sherpa_onnx::OfflinePunctuation> impl;
};
const SherpaOnnxOfflinePunctuation *SherpaOnnxCreateOfflinePunctuation(
const SherpaOnnxOfflinePunctuationConfig *config) {
sherpa_onnx::OfflinePunctuationConfig c;
c.model.ct_transformer = SHERPA_ONNX_OR(config->model.ct_transformer, "");
c.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
c.model.debug = config->model.debug;
c.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
if (c.model.debug) {
SHERPA_ONNX_LOGE("%s\n", c.ToString().c_str());
}
if (!c.Validate()) {
SHERPA_ONNX_LOGE("Errors in config");
return nullptr;
}
SherpaOnnxOfflinePunctuation *punct = new SherpaOnnxOfflinePunctuation;
punct->impl = std::make_unique<sherpa_onnx::OfflinePunctuation>(c);
return punct;
}
void SherpaOnnxDestroyOfflinePunctuation(
const SherpaOnnxOfflinePunctuation *punct) {
delete punct;
}
const char *SherpaOfflinePunctuationAddPunct(
const SherpaOnnxOfflinePunctuation *punct, const char *text) {
std::string text_with_punct = punct->impl->AddPunctuation(text);
char *ans = new char[text_with_punct.size() + 1];
std::copy(text_with_punct.begin(), text_with_punct.end(), ans);
ans[text_with_punct.size()] = 0;
return ans;
}
void SherpaOfflinePunctuationFreeText(const char *text) { delete[] text; }
... ...
... ... @@ -1149,6 +1149,41 @@ SherpaOnnxAudioTaggingCompute(const SherpaOnnxAudioTagging *tagger,
SHERPA_ONNX_API void SherpaOnnxAudioTaggingFreeResults(
const SherpaOnnxAudioEvent *const *p);
// ============================================================
// For punctuation
// ============================================================
SHERPA_ONNX_API typedef struct SherpaOnnxOfflinePunctuationModelConfig {
const char *ct_transformer;
int32_t num_threads;
int32_t debug; // true to print debug information of the model
const char *provider;
} SherpaOnnxOfflinePunctuationModelConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOfflinePunctuationConfig {
SherpaOnnxOfflinePunctuationModelConfig model;
} SherpaOnnxOfflinePunctuationConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOfflinePunctuation
SherpaOnnxOfflinePunctuation;
// The user has to invoke SherpaOnnxDestroyOfflinePunctuation()
// to free the returned pointer to avoid memory leak
SHERPA_ONNX_API const SherpaOnnxOfflinePunctuation *
SherpaOnnxCreateOfflinePunctuation(
const SherpaOnnxOfflinePunctuationConfig *config);
SHERPA_ONNX_API void SherpaOnnxDestroyOfflinePunctuation(
const SherpaOnnxOfflinePunctuation *punct);
// Add punctuations to the input text.
// The user has to invoke SherpaOfflinePunctuationFreeText()
// to free the returned pointer to avoid memory leak
SHERPA_ONNX_API const char *SherpaOfflinePunctuationAddPunct(
const SherpaOnnxOfflinePunctuation *punct, const char *text);
SHERPA_ONNX_API void SherpaOfflinePunctuationFreeText(const char *text);
#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
... ...
... ... @@ -134,25 +134,40 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
}
} // for (int32_t i = 0; i != num_segments; ++i)
std::string ans;
if (punctuations.empty()) {
return text + meta_data.id2punct[meta_data.dot_id];
}
std::vector<std::string> words_punct;
for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) {
if (i > tokens.size()) {
if (i >= tokens.size()) {
break;
}
const std::string &w = tokens[i];
if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) {
ans.push_back(' ');
std::string &w = tokens[i];
if (i > 0 && !(words_punct.back()[0] & 0x80) && !(w[0] & 0x80)) {
words_punct.push_back(" ");
}
ans.append(w);
words_punct.push_back(std::move(w));
if (punctuations[i] != meta_data.underline_id) {
ans.append(meta_data.id2punct[punctuations[i]]);
words_punct.push_back(meta_data.id2punct[punctuations[i]]);
}
}
if (ans.back() != meta_data.dot_id && ans.back() != meta_data.quest_id) {
ans.push_back(meta_data.dot_id);
if (words_punct.back() == meta_data.id2punct[meta_data.comma_id] ||
words_punct.back() == meta_data.id2punct[meta_data.pause_id]) {
words_punct.back() = meta_data.id2punct[meta_data.dot_id];
}
if (words_punct.back() != meta_data.id2punct[meta_data.dot_id] &&
words_punct.back() != meta_data.id2punct[meta_data.quest_id]) {
words_punct.push_back(meta_data.id2punct[meta_data.dot_id]);
}
std::string ans;
for (const auto &w : words_punct) {
ans.append(w);
}
return ans;
}
... ...
... ... @@ -4,6 +4,8 @@
#include "sherpa-onnx/python/csrc/offline-punctuation.h"
#include <string>
#include "sherpa-onnx/csrc/offline-punctuation.h"
namespace sherpa_onnx {
... ...