davidliu
Committed by GitHub

add some tests (#537)

* add some tests

* spotless
  1 +/*
  2 + * Copyright 2024 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.audio
  18 +
  19 +import org.junit.Assert.assertEquals
  20 +import org.junit.Assert.assertTrue
  21 +import org.junit.Test
  22 +import java.nio.ByteBuffer
  23 +
  24 +class AudioBufferCallbackDispatcherTest {
  25 +
  26 + @Test
  27 + fun callsThrough() {
  28 + val dispatcher = AudioBufferCallbackDispatcher()
  29 + val audioBuffer = ByteBuffer.allocateDirect(0)
  30 + var called = false
  31 + val callback = object : AudioBufferCallback {
  32 + override fun onBuffer(buffer: ByteBuffer, audioFormat: Int, channelCount: Int, sampleRate: Int, bytesRead: Int, captureTimeNs: Long): Long {
  33 + assertEquals(audioBuffer, buffer)
  34 + called = true
  35 + return captureTimeNs
  36 + }
  37 + }
  38 + dispatcher.bufferCallback = callback
  39 + dispatcher.onBuffer(
  40 + buffer = audioBuffer,
  41 + audioFormat = 0,
  42 + channelCount = 1,
  43 + sampleRate = 48000,
  44 + bytesRead = 0,
  45 + captureTimeNs = 0L,
  46 + )
  47 +
  48 + assertTrue(called)
  49 + }
  50 +}
  1 +/*
  2 + * Copyright 2024 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.audio
  18 +
  19 +import android.media.AudioFormat
  20 +import org.junit.Assert.assertEquals
  21 +import org.junit.Test
  22 +import java.nio.ByteBuffer
  23 +import java.nio.ByteOrder
  24 +
  25 +class MixerAudioBufferCallbackTest {
  26 +
  27 + @Test
  28 + fun mixesByte() {
  29 + val mixer = IncrementMixer()
  30 + val buffer = ByteBuffer.allocateDirect(1).order(ByteOrder.nativeOrder())
  31 + buffer.put(0, 0.toByte())
  32 + mixer.onBuffer(
  33 + buffer = buffer,
  34 + audioFormat = AudioFormat.ENCODING_PCM_8BIT,
  35 + channelCount = 1,
  36 + sampleRate = 1,
  37 + bytesRead = 1,
  38 + captureTimeNs = 0,
  39 + )
  40 +
  41 + assertEquals((0 + INCREMENT).toByte(), buffer.get(0))
  42 + }
  43 +
  44 + @Test
  45 + fun mixesShort() {
  46 + val mixer = IncrementMixer()
  47 + val buffer = ByteBuffer.allocateDirect(2).order(ByteOrder.nativeOrder())
  48 + val shortBuffer = buffer.asShortBuffer()
  49 + shortBuffer.put(0, 0.toShort())
  50 + mixer.onBuffer(
  51 + buffer = buffer,
  52 + audioFormat = AudioFormat.ENCODING_PCM_16BIT,
  53 + channelCount = 1,
  54 + sampleRate = 1,
  55 + bytesRead = 2,
  56 + captureTimeNs = 0,
  57 + )
  58 +
  59 + assertEquals((0 + INCREMENT).toShort(), shortBuffer.get(0))
  60 + }
  61 +
  62 + @Test
  63 + fun mixesFloat() {
  64 + val mixer = IncrementMixer()
  65 + val buffer = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder())
  66 + val floatBuffer = buffer.asFloatBuffer()
  67 + floatBuffer.put(0, 0.toFloat())
  68 + mixer.onBuffer(
  69 + buffer = buffer,
  70 + audioFormat = AudioFormat.ENCODING_PCM_FLOAT,
  71 + channelCount = 1,
  72 + sampleRate = 1,
  73 + bytesRead = 1,
  74 + captureTimeNs = 0,
  75 + )
  76 +
  77 + assertEquals((0 + INCREMENT).toFloat(), floatBuffer.get(0))
  78 + }
  79 +
  80 + companion object {
  81 + const val INCREMENT = 1
  82 + }
  83 +
  84 + class IncrementMixer : MixerAudioBufferCallback() {
  85 + override fun onBufferRequest(originalBuffer: ByteBuffer, audioFormat: Int, channelCount: Int, sampleRate: Int, bytesRead: Int, captureTimeNs: Long): BufferResponse? {
  86 + val byteBuffer = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder())
  87 +
  88 + when (audioFormat) {
  89 + AudioFormat.ENCODING_PCM_8BIT -> {
  90 + byteBuffer.put(0, INCREMENT.toByte())
  91 + }
  92 +
  93 + AudioFormat.ENCODING_PCM_16BIT -> {
  94 + byteBuffer.asShortBuffer().put(0, INCREMENT.toShort())
  95 + }
  96 +
  97 + AudioFormat.ENCODING_PCM_FLOAT -> {
  98 + byteBuffer.asFloatBuffer().put(0, INCREMENT.toFloat())
  99 + }
  100 +
  101 + AudioFormat.ENCODING_INVALID -> throw IllegalArgumentException("Bad audio format $audioFormat")
  102 + }
  103 +
  104 + return BufferResponse(byteBuffer)
  105 + }
  106 + }
  107 +}
  108 +
  109 +private fun getBytesPerSample(audioFormat: Int): Int {
  110 + return when (audioFormat) {
  111 + AudioFormat.ENCODING_PCM_8BIT -> 1
  112 + AudioFormat.ENCODING_PCM_16BIT, AudioFormat.ENCODING_IEC61937, AudioFormat.ENCODING_DEFAULT -> 2
  113 + AudioFormat.ENCODING_PCM_FLOAT -> 4
  114 + AudioFormat.ENCODING_INVALID -> throw IllegalArgumentException("Bad audio format $audioFormat")
  115 + else -> throw IllegalArgumentException("Bad audio format $audioFormat")
  116 + }
  117 +}