Committed by GitHub

Support reading multi-channel wave files with 8/16/32-bit encoded samples (#1258)

Showing 5 changed files with 148 additions and 42 deletions. The changes touch a shell test script, the CI workflow, the offline TTS frontend header, the wave reader, and the JNI OfflineRecognizer bindings.
Shell test script:

@@ -38,14 +38,28 @@ done
 
 
 # test wav reader for non-standard wav files
-curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/naudio.wav
-curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/junk-padding.wav
+waves=(
+  naudio.wav
+  junk-padding.wav
+  int8-1-channel-zh.wav
+  int8-2-channel-zh.wav
+  int8-4-channel-zh.wav
+  int16-1-channel-zh.wav
+  int16-2-channel-zh.wav
+  int32-1-channel-zh.wav
+  int32-2-channel-zh.wav
+  float32-1-channel-zh.wav
+  float32-2-channel-zh.wav
+)
+for w in ${waves[@]}; do
+  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/$w
 
-time $EXE \
-  --tokens=$repo/tokens.txt \
-  --sense-voice-model=$repo/model.int8.onnx \
-  ./naudio.wav \
-  ./junk-padding.wav
+  time $EXE \
+    --tokens=$repo/tokens.txt \
+    --sense-voice-model=$repo/model.int8.onnx \
+    $w
+  rm -v $w
+done
 
 rm -rf $repo
 
CI workflow:

@@ -143,35 +143,34 @@ jobs:
           name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
           path: install/*
 
-      - name: Test online punctuation
+      - name: Test offline CTC
         shell: bash
         run: |
           du -h -d1 .
           export PATH=$PWD/build/bin:$PATH
-          export EXE=sherpa-onnx-online-punctuation
+          export EXE=sherpa-onnx-offline
 
-          .github/scripts/test-online-punctuation.sh
+          .github/scripts/test-offline-ctc.sh
           du -h -d1 .
 
-      - name: Test offline transducer
+      - name: Test online punctuation
         shell: bash
         run: |
           du -h -d1 .
           export PATH=$PWD/build/bin:$PATH
-          export EXE=sherpa-onnx-offline
+          export EXE=sherpa-onnx-online-punctuation
 
-          .github/scripts/test-offline-transducer.sh
+          .github/scripts/test-online-punctuation.sh
           du -h -d1 .
 
-
-      - name: Test offline CTC
+      - name: Test offline transducer
         shell: bash
         run: |
           du -h -d1 .
           export PATH=$PWD/build/bin:$PATH
           export EXE=sherpa-onnx-offline
 
-          .github/scripts/test-offline-ctc.sh
+          .github/scripts/test-offline-transducer.sh
           du -h -d1 .
 
       - name: Test online transducer
offline-tts-frontend.h:

@@ -6,6 +6,7 @@
 #define SHERPA_ONNX_CSRC_OFFLINE_TTS_FRONTEND_H_
 #include <cstdint>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "sherpa-onnx/csrc/macros.h"
Wave reader:

@@ -50,6 +50,16 @@ struct WaveHeader {
 };
 static_assert(sizeof(WaveHeader) == 44);
 
+/*
+sox int16-1-channel-zh.wav -b 8 int8-1-channel-zh.wav
+
+sox int16-1-channel-zh.wav -c 2 int16-2-channel-zh.wav
+
+we use audacity to generate int32-1-channel-zh.wav and float32-1-channel-zh.wav
+because sox uses WAVE_FORMAT_EXTENSIBLE, which is not easy to support
+in sherpa-onnx.
+ */
+
 // Read a wave file of mono-channel.
 // Return its samples normalized to the range [-1, 1).
 std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
@@ -114,9 +124,18 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
   is.read(reinterpret_cast<char *>(&header.audio_format),
           sizeof(header.audio_format));
 
-  if (header.audio_format != 1) {  // 1 for PCM
+  if (header.audio_format != 1 && header.audio_format != 3) {
+    // 1 for integer PCM
+    // 3 for floating point PCM
+    // see https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
+    // and https://github.com/microsoft/DirectXTK/wiki/Wave-Formats
     SHERPA_ONNX_LOGE("Expected audio_format 1. Given: %d\n",
                      header.audio_format);
+
+    if (header.audio_format == static_cast<int16_t>(0xfffe)) {
+      SHERPA_ONNX_LOGE("We don't support WAVE_FORMAT_EXTENSIBLE files.");
+    }
+
     *is_ok = false;
     return {};
   }
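The check now accepts format tag 1 (integer PCM) and 3 (IEEE float PCM), and prints a dedicated hint when it sees WAVE_FORMAT_EXTENSIBLE (0xFFFE), the tag sox produces for some conversions, which is why a few of the test files were generated with Audacity instead. As an aside, the following standalone sketch is not part of the commit; it only illustrates where the format tag lives and how the three cases differ. It assumes a canonical header in which the "fmt " chunk starts right after the 12-byte RIFF header (so the tag sits at byte offset 20) and a little-endian host; non-canonical files with extra chunks before "fmt " are not handled.

#include <cstdint>
#include <fstream>
#include <iostream>

int main(int argc, char *argv[]) {
  if (argc != 2) {
    std::cerr << "Usage: " << argv[0] << " /path/to/foo.wav\n";
    return 1;
  }

  std::ifstream is(argv[1], std::ios::binary);
  // Canonical layout: "RIFF" + size + "WAVE" (12 bytes), then "fmt " + chunk
  // size (8 bytes); the 2-byte format tag follows at byte offset 20.
  is.seekg(20);

  uint16_t audio_format = 0;
  is.read(reinterpret_cast<char *>(&audio_format), sizeof(audio_format));
  if (!is) {
    std::cerr << "Failed to read the format tag\n";
    return 1;
  }

  switch (audio_format) {
    case 1:
      std::cout << "integer PCM (supported)\n";
      break;
    case 3:
      std::cout << "IEEE float PCM (supported)\n";
      break;
    case 0xfffe:
      std::cout << "WAVE_FORMAT_EXTENSIBLE (rejected with a dedicated message)\n";
      break;
    default:
      std::cout << "other format tag: " << audio_format << "\n";
      break;
  }
  return 0;
}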
@@ -125,10 +144,9 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
           sizeof(header.num_channels));
 
   if (header.num_channels != 1) {  // we support only single channel for now
-    SHERPA_ONNX_LOGE("Expected single channel. Given: %d\n",
-                     header.num_channels);
-    *is_ok = false;
-    return {};
+    SHERPA_ONNX_LOGE(
+        "Warning: %d channels are found. We only use the first channel.\n",
+        header.num_channels);
   }
 
   is.read(reinterpret_cast<char *>(&header.sample_rate),
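Multi-channel input is no longer a hard error because the decoding code further down keeps only the first channel of the interleaved frames. As a small illustration (not from the commit) of why indexing with i * num_channels selects that channel: WAV stores one frame per sample instant with the channels interleaved, so stereo int16 data is laid out as L0 R0 L1 R1 ...

#include <cstddef>
#include <cstdint>
#include <vector>

// Extract channel 0 from interleaved PCM frames.
// For stereo data {L0, R0, L1, R1, ...}, sample i of the first channel
// sits at index i * num_channels.
std::vector<int16_t> FirstChannel(const std::vector<int16_t> &interleaved,
                                  int32_t num_channels) {
  std::vector<int16_t> ans(interleaved.size() / num_channels);
  for (std::size_t i = 0; i != ans.size(); ++i) {
    ans[i] = interleaved[i * num_channels];
  }
  return ans;
}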
@@ -161,8 +179,9 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
     return {};
   }
 
-  if (header.bits_per_sample != 16) {  // we support only 16 bits per sample
-    SHERPA_ONNX_LOGE("Expected bits_per_sample 16. Given: %d\n",
+  if (header.bits_per_sample != 8 && header.bits_per_sample != 16 &&
+      header.bits_per_sample != 32) {
+    SHERPA_ONNX_LOGE("Expected bits_per_sample 8, 16 or 32. Given: %d\n",
                      header.bits_per_sample);
     *is_ok = false;
     return {};
@@ -199,21 +218,95 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
 
   *sampling_rate = header.sample_rate;
 
-  // header.subchunk2_size contains the number of bytes in the data.
-  // As we assume each sample contains two bytes, so it is divided by 2 here
-  std::vector<int16_t> samples(header.subchunk2_size / 2);
+  std::vector<float> ans;
 
-  is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
-  if (!is) {
+  if (header.bits_per_sample == 16 && header.audio_format == 1) {
+    // header.subchunk2_size contains the number of bytes in the data.
+    // As we assume each sample contains two bytes, so it is divided by 2 here
+    std::vector<int16_t> samples(header.subchunk2_size / 2);
+    SHERPA_ONNX_LOGE("%d samples, bytes: %d", (int)samples.size(),
+                     header.subchunk2_size);
+
+    is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
+    if (!is) {
+      SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
+      *is_ok = false;
+      return {};
+    }
+
+    ans.resize(samples.size() / header.num_channels);
+
+    // samples are interleaved
+    for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
+      ans[i] = samples[i * header.num_channels] / 32768.;
+    }
+  } else if (header.bits_per_sample == 8 && header.audio_format == 1) {
+    // number of samples == number of bytes for 8-bit encoded samples
+    //
+    // For 8-bit encoded samples, they are unsigned!
+    std::vector<uint8_t> samples(header.subchunk2_size);
+
+    is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
+    if (!is) {
+      SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
+      *is_ok = false;
+      return {};
+    }
+
+    ans.resize(samples.size() / header.num_channels);
+    for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
+      // Note(fangjun): We want to normalize each sample into the range [-1, 1]
+      // Since each original sample is in the range [0, 256], dividing
+      // them by 128 converts them to the range [0, 2];
+      // so after subtracting 1, we get the range [-1, 1]
+      //
+      ans[i] = samples[i * header.num_channels] / 128. - 1;
+    }
+  } else if (header.bits_per_sample == 32 && header.audio_format == 1) {
+    // 32 here is for int32
+    //
+    // header.subchunk2_size contains the number of bytes in the data.
+    // As we assume each sample contains 4 bytes, so it is divided by 4 here
+    std::vector<int32_t> samples(header.subchunk2_size / 4);
+
+    is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
+    if (!is) {
+      SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
+      *is_ok = false;
+      return {};
+    }
+
+    ans.resize(samples.size() / header.num_channels);
+    for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
+      ans[i] = static_cast<float>(samples[i * header.num_channels]) / (1 << 31);
+    }
+  } else if (header.bits_per_sample == 32 && header.audio_format == 3) {
+    // 32 here is for float32
+    //
+    // header.subchunk2_size contains the number of bytes in the data.
+    // As we assume each sample contains 4 bytes, so it is divided by 4 here
+    std::vector<float> samples(header.subchunk2_size / 4);
+
+    is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
+    if (!is) {
+      SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
+      *is_ok = false;
+      return {};
+    }
+
+    ans.resize(samples.size() / header.num_channels);
+    for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
+      ans[i] = samples[i * header.num_channels];
+    }
+  } else {
+    SHERPA_ONNX_LOGE(
+        "Unsupported %d bits per sample and audio format: %d. Supported values "
+        "are: 8, 16, 32.",
+        header.bits_per_sample, header.audio_format);
     *is_ok = false;
     return {};
   }
 
-  std::vector<float> ans(samples.size());
-  for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
-    ans[i] = samples[i] / 32768.;
-  }
-
   *is_ok = true;
   return ans;
 }
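Each branch above reads the raw data chunk, keeps the first channel, and normalizes it to roughly [-1, 1): unsigned 8-bit samples are divided by 128 and shifted down by 1, signed 16-bit samples are divided by 32768, signed 32-bit samples by 2^31, and float32 samples are passed through unchanged. A condensed sketch of these scalings (not from the commit; the 2^31 divisor is spelled here as a float literal rather than a signed shift):

#include <cstdint>

// Per-format scaling of a single sample to roughly [-1, 1).
inline float NormalizeU8(uint8_t s) {
  // [0, 255] -> [0, ~2) -> [-1, ~1)
  return s / 128.0f - 1.0f;
}

inline float NormalizeS16(int16_t s) {
  // [-32768, 32767] -> [-1, 1)
  return s / 32768.0f;
}

inline float NormalizeS32(int32_t s) {
  // [-2^31, 2^31 - 1] -> [-1, 1); 2147483648.0f == 2^31
  return static_cast<float>(s) / 2147483648.0f;
}

inline float NormalizeF32(float s) {
  // Float PCM is already stored in [-1, 1]
  return s;
}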
JNI bindings (OfflineRecognizer):

@@ -264,13 +264,9 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromFile(JNIEnv *env,
   return (jlong)model;
 }
 
-
 SHERPA_ONNX_EXTERN_C
-JNIEXPORT void JNICALL
-Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_setConfig(JNIEnv *env,
-                                                       jobject /*obj*/,
-                                                       jlong ptr,
-                                                       jobject _config) {
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_setConfig(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) {
   auto config = sherpa_onnx::GetOfflineConfig(env, _config);
   SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
 
@@ -350,9 +346,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env,
   // [3]: lang, jstring
   // [4]: emotion, jstring
   // [5]: event, jstring
-  env->SetObjectArrayElement(obj_arr, 3, env->NewStringUTF(result.lang.c_str()));
-  env->SetObjectArrayElement(obj_arr, 4, env->NewStringUTF(result.emotion.c_str()));
-  env->SetObjectArrayElement(obj_arr, 5, env->NewStringUTF(result.event.c_str()));
+  env->SetObjectArrayElement(obj_arr, 3,
+                             env->NewStringUTF(result.lang.c_str()));
+  env->SetObjectArrayElement(obj_arr, 4,
+                             env->NewStringUTF(result.emotion.c_str()));
+  env->SetObjectArrayElement(obj_arr, 5,
+                             env->NewStringUTF(result.event.c_str()));
 
   return obj_arr;
 }