davidliu
Committed by GitHub

Add pre-connect audio for use with agents (#666)

* Update protocol submodule to v1.38.0

* Basic Preconnect audio buffer implementation

* Add Participant.State and related events

* prerecording full implementation

* Fix outgoing byte datastreams incorrectly padding data

* Add pre-connect audio for use with agents
正在显示 22 个修改的文件 包含 806 行增加16 行删除
---
"client-sdk-android": minor
---
Add pre-connect audio for use with agents
See Room.withPreconnectAudio for details.
... ...
---
"client-sdk-android": patch
---
Fix outgoing datastreams incorrectly padding data
... ...
---
"client-sdk-android": minor
---
Add Participant.State and related events
... ...
/*
* Copyright 2023-2024 LiveKit, Inc.
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -26,6 +26,7 @@ import io.livekit.android.audio.AudioProcessorOptions
import io.livekit.android.audio.AudioSwitchHandler
import io.livekit.android.audio.NoAudioHandler
import io.livekit.android.room.Room
import io.livekit.android.room.track.LocalAudioTrack
import livekit.org.webrtc.EglBase
import livekit.org.webrtc.PeerConnectionFactory
import livekit.org.webrtc.VideoDecoderFactory
... ... @@ -33,6 +34,7 @@ import livekit.org.webrtc.VideoEncoderFactory
import livekit.org.webrtc.audio.AudioDeviceModule
import livekit.org.webrtc.audio.JavaAudioDeviceModule
import okhttp3.OkHttpClient
/**
* Overrides to replace LiveKit internally used components with custom implementations.
*/
... ... @@ -110,6 +112,11 @@ class AudioOptions(
* Called after default setup to allow for customizations on the [JavaAudioDeviceModule].
*
* Not used if [audioDeviceModule] is provided.
*
* Note: We require setting the [JavaAudioDeviceModule.Builder.setSamplesReadyCallback] to provide
* support for [LocalAudioTrack.addSink]. If you wish to grab the audio samples
* from the local microphone track, use [LocalAudioTrack.addSink] instead of setting your own
* callback.
*/
val javaAudioDeviceModuleCustomizer: ((builder: JavaAudioDeviceModule.Builder) -> Unit)? = null,
... ...
/*
* Copyright 2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.audio
import android.os.SystemClock
import io.livekit.android.audio.PreconnectAudioBuffer.Companion.DEFAULT_TOPIC
import io.livekit.android.audio.PreconnectAudioBuffer.Companion.TIMEOUT
import io.livekit.android.events.RoomEvent
import io.livekit.android.events.collect
import io.livekit.android.room.ConnectionState
import io.livekit.android.room.Room
import io.livekit.android.room.datastream.StreamBytesOptions
import io.livekit.android.room.participant.Participant
import io.livekit.android.util.LKLog
import io.livekit.android.util.flow
import kotlinx.coroutines.cancel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.takeWhile
import kotlinx.coroutines.launch
import livekit.org.webrtc.AudioTrackSink
import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import kotlin.math.min
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds
internal class PreconnectAudioBuffer
internal constructor(timeout: Duration) : AudioTrackSink {
companion object {
const val DEFAULT_TOPIC = "lk.agent.pre-connect-audio-buffer"
val TIMEOUT = 10.seconds
}
private val outputStreamLock = Any()
private val outputStream by lazy {
ByteArrayOutputStream()
}
private lateinit var collectedBytes: ByteArray
private val tempArray = ByteArray(1024)
private var initialTime = -1L
private var bitsPerSample = 16
private var sampleRate = 48000 // default sampleRate from JavaAudioDeviceModule
private var numberOfChannels = 1 // default channels from JavaAudioDeviceModule
private var isRecording = true
private val timeoutMs = timeout.inWholeMilliseconds
fun startRecording() {
isRecording = true
}
fun stopRecording() {
synchronized(outputStreamLock) {
if (isRecording) {
collectedBytes = outputStream.toByteArray()
isRecording = false
}
}
}
fun clear() {
stopRecording()
collectedBytes = ByteArray(0)
}
override fun onData(
audioData: ByteBuffer,
bitsPerSample: Int,
sampleRate: Int,
numberOfChannels: Int,
numberOfFrames: Int,
absoluteCaptureTimestampMs: Long,
) {
if (!isRecording) {
return
}
if (initialTime == -1L) {
initialTime = SystemClock.elapsedRealtime()
}
this.bitsPerSample = bitsPerSample
this.sampleRate = sampleRate
this.numberOfChannels = numberOfChannels
val currentTime = SystemClock.elapsedRealtime()
// Limit reached, don't buffer any more.
if (currentTime - initialTime > timeoutMs) {
return
}
audioData.rewind()
synchronized(outputStreamLock) {
if (audioData.hasArray()) {
outputStream.write(audioData.array())
} else {
while (audioData.hasRemaining()) {
val readBytes = min(tempArray.size, audioData.remaining())
audioData.get(tempArray, 0, readBytes)
outputStream.write(tempArray)
}
}
}
}
suspend fun sendAudioData(room: Room, trackSid: String?, agentIdentities: List<Participant.Identity>, topic: String = DEFAULT_TOPIC) {
if (agentIdentities.isEmpty()) {
return
}
val audioData = outputStream.toByteArray()
if (audioData.size <= 1024) {
LKLog.i { "Audio data size too small, nothing to send." }
return
}
val sender = room.localParticipant.streamBytes(
StreamBytesOptions(
topic = topic,
attributes = mapOf(
"sampleRate" to "${this.sampleRate}",
"channels" to "${this.numberOfChannels}",
"trackId" to (trackSid ?: ""),
),
destinationIdentities = agentIdentities,
totalSize = audioData.size.toLong(),
name = "preconnect-audio-buffer",
),
)
try {
sender.write(audioData)
sender.close()
} catch (e: Exception) {
sender.close(e.localizedMessage)
}
val samples = audioData.size / (numberOfChannels * bitsPerSample / 8)
val duration = samples.toFloat() / sampleRate
LKLog.i { "Sent ${duration}s (${audioData.size / 1024}KB) of audio data to ${agentIdentities.size} agent(s) (${agentIdentities.joinToString(",")})" }
}
}
/**
* Starts a pre-connect audio recording that will be sent to
* any agents that connect within the [timeout]. This speeds up
* preceived connection times, as the user can start speaking
* prior to actual connection with the agent.
*
* This will automatically be cleaned up when the room disconnects or the operation fails.
*
* Example:
* ```
* try {
* room.withPreconnectAudio {
* // Audio is being captured automatically
* // Perform any other (async) setup here
* val (url, token) = tokenService.fetchConnectionDetails()
* room.connect(
* url = url,
* token = token,
* )
* room.localParticipant.setMicrophoneEnabled(true)
* }
* } catch (e: Throwable) {
* Log.e(TAG, "Error!")
* }
* ```
* @param timeout the timeout for the remote participant to subscribe to the audio track.
* The room connection needs to be established and the remote participant needs to subscribe to the audio track
* before the timeout is reached. Otherwise, the audio stream will be flushed without sending.
* @param topic the topic to send the preconnect audio buffer to. By default this is configured for
* use with LiveKit Agents.
* @param onError The error handler to call when an error occurs while sending the audio buffer.
* @param operation The connection lambda to call with the pre-connect audio.
*
*/
suspend fun <T> Room.withPreconnectAudio(
timeout: Duration = TIMEOUT,
topic: String = DEFAULT_TOPIC,
onError: ((e: Exception) -> Unit)? = null,
operation: suspend () -> T,
) = coroutineScope {
isPrerecording = true
val audioTrack = localParticipant.getOrCreateDefaultAudioTrack()
val preconnectAudioBuffer = PreconnectAudioBuffer(timeout)
LKLog.v { "Starting preconnect audio buffer" }
preconnectAudioBuffer.startRecording()
audioTrack.addSink(preconnectAudioBuffer)
audioTrack.prewarm()
fun stopRecording() {
if (!isPrerecording) {
return
}
LKLog.v { "Stopping preconnect audio buffer" }
audioTrack.removeSink(preconnectAudioBuffer)
preconnectAudioBuffer.stopRecording()
isPrerecording = false
}
// Clear the preconnect audio buffer after the timeout to free memory.
launch {
delay(TIMEOUT)
preconnectAudioBuffer.clear()
}
val sentIdentities = mutableSetOf<Participant.Identity>()
launch {
suspend fun handleSendIfNeeded(participant: Participant) {
coroutineScope inner@{
engine::connectionState.flow
.takeWhile { it != ConnectionState.CONNECTED }
.collect()
val kind = participant.kind
val state = participant.state
val identity = participant.identity
if (sentIdentities.contains(identity) || kind != Participant.Kind.AGENT || state != Participant.State.ACTIVE || identity == null) {
return@inner
}
stopRecording()
launch {
try {
preconnectAudioBuffer.sendAudioData(
room = this@withPreconnectAudio,
trackSid = audioTrack.sid,
agentIdentities = listOf(identity),
topic = topic,
)
sentIdentities.add(identity)
} catch (e: Exception) {
LKLog.w(e) { "Error occurred while sending the audio preconnect data." }
onError?.invoke(e)
}
}
}
}
events.collect { event ->
when (event) {
is RoomEvent.LocalTrackSubscribed -> {
LKLog.i { "Local audio track has been subscribed to, stopping preconnect audio recording." }
stopRecording()
}
is RoomEvent.ParticipantConnected -> {
// agents may connect with ACTIVE state and not trigger a participant state changed.
handleSendIfNeeded(event.participant)
}
is RoomEvent.ParticipantStateChanged -> {
handleSendIfNeeded(event.participant)
}
is RoomEvent.Disconnected -> {
cancel()
}
else -> {
// Intentionally blank.
}
}
}
}
val retValue: T
try {
retValue = operation.invoke()
} catch (e: Exception) {
cancel()
throw e
}
return@coroutineScope retValue
}
... ...
... ... @@ -197,4 +197,13 @@ sealed class ParticipantEvent(open val participant: Participant) : Event() {
*/
val publication: TrackPublication?,
) : ParticipantEvent(participant)
/**
* A participant's state has changed.
*/
class StateChanged(
override val participant: Participant,
val newState: Participant.State,
val oldState: Participant.State,
) : ParticipantEvent(participant)
}
... ...
... ... @@ -271,6 +271,16 @@ sealed class RoomEvent(val room: Room) : Event() {
*/
val publication: TrackPublication?,
) : RoomEvent(room)
/**
* The state for a participant has changed.
*/
class ParticipantStateChanged(
room: Room,
val participant: Participant,
val newState: Participant.State,
val oldState: Participant.State,
) : RoomEvent(room)
}
enum class DisconnectReason {
... ... @@ -288,6 +298,7 @@ enum class DisconnectReason {
USER_UNAVAILABLE,
USER_REJECTED,
SIP_TRUNK_FAILURE,
CONNECTION_TIMEOUT,
}
/**
... ... @@ -308,6 +319,7 @@ fun LivekitModels.DisconnectReason?.convert(): DisconnectReason {
LivekitModels.DisconnectReason.USER_UNAVAILABLE -> DisconnectReason.USER_UNAVAILABLE
LivekitModels.DisconnectReason.USER_REJECTED -> DisconnectReason.USER_REJECTED
LivekitModels.DisconnectReason.SIP_TRUNK_FAILURE -> DisconnectReason.SIP_TRUNK_FAILURE
LivekitModels.DisconnectReason.CONNECTION_TIMEOUT -> DisconnectReason.CONNECTION_TIMEOUT
LivekitModels.DisconnectReason.UNKNOWN_REASON,
LivekitModels.DisconnectReason.UNRECOGNIZED,
null,
... ...
/*
* Copyright 2023-2024 LiveKit, Inc.
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -24,6 +24,9 @@ import io.livekit.android.room.track.ScreenSharePresets
import javax.inject.Inject
import javax.inject.Singleton
/**
* @suppress
*/
@Singleton
class DefaultsManager
@Inject
... ... @@ -34,4 +37,6 @@ constructor() {
var videoTrackPublishDefaults: VideoTrackPublishDefaults = VideoTrackPublishDefaults()
var screenShareTrackCaptureDefaults: LocalVideoTrackOptions = LocalVideoTrackOptions(isScreencast = true, captureParams = ScreenSharePresets.ORIGINAL.capture)
var screenShareTrackPublishDefaults: VideoTrackPublishDefaults = VideoTrackPublishDefaults(videoEncoding = ScreenSharePresets.ORIGINAL.encoding)
var isPrerecording: Boolean = false
}
... ...
... ... @@ -110,7 +110,7 @@ internal constructor(
* Reflects the combined connection state of SignalClient and primary PeerConnection.
*/
@FlowObservable
@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
@get:FlowObservable
var connectionState: ConnectionState by flowDelegate(ConnectionState.DISCONNECTED) { newVal, oldVal ->
if (newVal == oldVal) {
return@flowDelegate
... ... @@ -302,7 +302,8 @@ internal constructor(
RELIABLE_DATA_CHANNEL_LABEL,
reliableInit,
).also { dataChannel ->
dataChannel.registerObserver(DataChannelObserver(dataChannel))
reliableDataChannelManager = DataChannelManager(dataChannel, DataChannelObserver(dataChannel))
dataChannel.registerObserver(reliableDataChannelManager)
}
}
... ... @@ -315,7 +316,8 @@ internal constructor(
LOSSY_DATA_CHANNEL_LABEL,
lossyInit,
).also { dataChannel ->
dataChannel.registerObserver(DataChannelObserver(dataChannel))
lossyDataChannelManager = DataChannelManager(dataChannel, DataChannelObserver(dataChannel))
dataChannel.registerObserver(lossyDataChannelManager)
}
}
}
... ... @@ -1095,6 +1097,7 @@ internal constructor(
-> {
listener?.onRpcPacketReceived(dp)
}
LivekitModels.DataPacket.ValueCase.VALUE_NOT_SET,
null,
-> {
... ...
... ... @@ -76,7 +76,7 @@ class Room
@AssistedInject
constructor(
@Assisted private val context: Context,
private val engine: RTCEngine,
internal val engine: RTCEngine,
private val eglBase: EglBase,
localParticipantFactory: LocalParticipant.Factory,
private val defaultsManager: DefaultsManager,
... ... @@ -320,6 +320,8 @@ constructor(
private var transcriptionReceivedTimes = mutableMapOf<String, Long>()
internal var isPrerecording by defaultsManager::isPrerecording
private fun getCurrentRoomOptions(): RoomOptions =
RoomOptions(
adaptiveStream = adaptiveStream,
... ... @@ -678,6 +680,15 @@ constructor(
)
}
is ParticipantEvent.StateChanged -> {
RoomEvent.ParticipantStateChanged(
this@Room,
it.participant,
it.newState,
it.oldState,
)
}
else -> {
// do nothing
}
... ... @@ -808,6 +819,17 @@ constructor(
),
)
is ParticipantEvent.StateChanged -> {
eventBus.postEvent(
RoomEvent.ParticipantStateChanged(
room = this@Room,
participant = it.participant,
newState = it.newState,
oldState = it.oldState,
),
)
}
else -> {
// do nothing
}
... ...
... ... @@ -780,6 +780,10 @@ constructor(
-> {
LKLog.v { "empty messageCase!" }
}
LivekitRtc.SignalResponse.MessageCase.ROOM_MOVED -> {
// TODO
}
}
}
... ...
... ... @@ -45,4 +45,7 @@ interface StreamDestination<T> {
suspend fun close(reason: String?)
}
internal typealias DataChunker<T> = (data: T, chunkSize: Int) -> List<ByteArray>
/**
* @suppress
*/
typealias DataChunker<T> = (data: T, chunkSize: Int) -> List<ByteArray>
... ...
... ... @@ -24,6 +24,7 @@ import okio.Source
import okio.source
import java.io.InputStream
import java.util.Arrays
import kotlin.math.min
class ByteStreamSender(
val info: ByteStreamInfo,
... ... @@ -36,7 +37,16 @@ class ByteStreamSender(
private val byteDataChunker: DataChunker<ByteArray> = { data: ByteArray, chunkSize: Int ->
(data.indices step chunkSize)
.map { index -> Arrays.copyOfRange(data, index, index + chunkSize) }
.map { index ->
Arrays.copyOfRange(
/* original = */
data,
/* from = */
index,
/* to = */
min(index + chunkSize, data.size),
)
}
}
/**
... ...
... ... @@ -67,6 +67,7 @@ import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import livekit.LivekitModels
import livekit.LivekitModels.AudioTrackFeature
import livekit.LivekitModels.Codec
import livekit.LivekitModels.DataPacket
import livekit.LivekitModels.TrackInfo
... ... @@ -434,7 +435,7 @@ internal constructor(
options: AudioTrackPublishOptions = AudioTrackPublishOptions(
null,
audioTrackPublishDefaults,
),
).copy(preconnect = defaultsManager.isPrerecording),
publishListener: PublishListener? = null,
): Boolean {
val encodings = listOf(
... ... @@ -450,6 +451,7 @@ internal constructor(
requestConfig = {
disableDtx = !options.dtx
disableRed = !options.red
addAllAudioFeatures(options.getFeaturesList())
source = options.source?.toProto() ?: LivekitModels.TrackSource.MICROPHONE
},
encodings = encodings,
... ... @@ -459,7 +461,7 @@ internal constructor(
if (publication != null) {
val job = scope.launch {
track::features.flow.collect {
engine.updateLocalAudioTrack(publication.sid, it)
engine.updateLocalAudioTrack(publication.sid, it + options.getFeaturesList())
}
}
jobs[publication] = job
... ... @@ -1763,6 +1765,7 @@ data class AudioTrackPublishOptions(
override val red: Boolean = true,
override val source: Track.Source? = null,
override val stream: String? = null,
val preconnect: Boolean = false,
) : BaseAudioTrackPublishOptions(), TrackPublishOptions {
constructor(
name: String? = null,
... ... @@ -1777,6 +1780,17 @@ data class AudioTrackPublishOptions(
source = source,
stream = stream,
)
internal fun getFeaturesList(): Set<AudioTrackFeature> {
val features = mutableSetOf<AudioTrackFeature>()
if (!dtx) {
features.add(AudioTrackFeature.TF_NO_DTX)
}
if (preconnect) {
features.add(AudioTrackFeature.TF_PRECONNECT_BUFFER)
}
return features
}
}
data class ParticipantTrackPermission(
... ...
... ... @@ -93,6 +93,27 @@ open class Participant(
@VisibleForTesting set
/**
* The participant state.
*
* Changes can be observed by using [io.livekit.android.util.flow]
*/
@FlowObservable
@get:FlowObservable
var state: State by flowDelegate(State.UNKNOWN) { newState, oldState ->
if (newState != oldState) {
eventBus.postEvent(
ParticipantEvent.StateChanged(
participant = this,
newState = newState,
oldState = oldState,
),
scope,
)
}
}
@VisibleForTesting set
/**
* Changes can be observed by using [io.livekit.android.util.flow]
*/
@FlowObservable
... ... @@ -377,6 +398,7 @@ open class Participant(
permissions = ParticipantPermission.fromProto(info.permission)
}
attributes = info.attributesMap
state = State.fromProto(info.state)
}
override fun equals(other: Any?): Boolean {
... ... @@ -474,6 +496,34 @@ open class Participant(
}
}
}
enum class State {
// websocket' connected, but not offered yet
JOINING,
// server received client offer
JOINED,
// ICE connectivity established
ACTIVE,
// WS disconnected
DISCONNECTED,
UNKNOWN;
companion object {
fun fromProto(proto: LivekitModels.ParticipantInfo.State): State {
return when (proto) {
LivekitModels.ParticipantInfo.State.JOINING -> JOINING
LivekitModels.ParticipantInfo.State.JOINED -> JOINED
LivekitModels.ParticipantInfo.State.ACTIVE -> ACTIVE
LivekitModels.ParticipantInfo.State.DISCONNECTED -> DISCONNECTED
LivekitModels.ParticipantInfo.State.UNRECOGNIZED -> UNKNOWN
}
}
}
}
}
/**
... ...
/*
* Copyright 2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.test.mock.room.datastream.outgoing
import io.livekit.android.room.datastream.outgoing.DataChunker
import io.livekit.android.room.datastream.outgoing.StreamDestination
class MockStreamDestination<T>(val chunkSize: Int) : StreamDestination<T> {
override var isOpen: Boolean = true
val writtenChunks = mutableListOf<ByteArray>()
override suspend fun write(data: T, chunker: DataChunker<T>) {
val chunks = chunker.invoke(data, chunkSize)
for (chunk in chunks) {
writtenChunks.add(chunk)
}
}
override suspend fun close(reason: String?) {
isOpen = false
}
}
... ...
... ... @@ -14,10 +14,10 @@
* limitations under the License.
*/
package io.livekit.android.room
package io.livekit.android.room.datastream
import com.google.protobuf.ByteString
import io.livekit.android.room.datastream.StreamException
import io.livekit.android.room.RTCEngine
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.assert.assertIsClass
import io.livekit.android.test.mock.MockDataChannel
... ... @@ -38,7 +38,7 @@ import org.junit.Test
import java.nio.ByteBuffer
@OptIn(ExperimentalCoroutinesApi::class)
class RoomDataStreamMockE2ETest : MockE2ETest() {
class RoomIncomingDataStreamMockE2ETest : MockE2ETest() {
@Test
fun dataStream() = runTest {
connect()
... ...
/*
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.room.datastream
import com.google.protobuf.ByteString
import io.livekit.android.room.RTCEngine
import io.livekit.android.room.participant.Participant
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.mock.MockDataChannel
import io.livekit.android.test.mock.MockPeerConnection
import io.livekit.android.test.mock.TestData
import io.livekit.android.util.toOkioByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import livekit.LivekitModels
import livekit.LivekitRtc
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Test
@OptIn(ExperimentalCoroutinesApi::class)
class RoomOutgoingDataStreamMockE2ETest : MockE2ETest() {
private lateinit var pubDataChannel: MockDataChannel
override suspend fun connect(joinResponse: LivekitRtc.SignalResponse) {
super.connect(joinResponse)
val pubPeerConnection = component.rtcEngine().getPublisherPeerConnection() as MockPeerConnection
pubDataChannel = pubPeerConnection.dataChannels[RTCEngine.RELIABLE_DATA_CHANNEL_LABEL] as MockDataChannel
}
@Test
fun dataStream() = runTest {
connect()
// Remote participant to send data to
wsFactory.listener.onMessage(
wsFactory.ws,
TestData.PARTICIPANT_JOIN.toOkioByteString(),
)
val bytesToStream = ByteArray(100)
for (i in bytesToStream.indices) {
bytesToStream[i] = i.toByte()
}
val job = launch {
val sender = room.localParticipant.streamBytes(
StreamBytesOptions(
topic = "topic",
attributes = mapOf("hello" to "world"),
streamId = "stream_id",
destinationIdentities = listOf(Participant.Identity(TestData.REMOTE_PARTICIPANT.identity)),
name = "stream_name",
totalSize = bytesToStream.size.toLong(),
),
)
sender.write(bytesToStream)
sender.close()
}
job.join()
val buffers = pubDataChannel.sentBuffers
println(buffers)
assertEquals(3, buffers.size)
val headerPacket = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
assertTrue(headerPacket.hasStreamHeader())
with(headerPacket.streamHeader) {
assertTrue(hasByteHeader())
}
val payloadPacket = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[1].data))
assertTrue(payloadPacket.hasStreamChunk())
with(payloadPacket.streamChunk) {
assertEquals(100, content.size())
}
val trailerPacket = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[2].data))
assertTrue(trailerPacket.hasStreamTrailer())
with(trailerPacket.streamTrailer) {
assertTrue(reason.isNullOrEmpty())
}
}
}
... ...
/*
* Copyright 2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.room.datastream.outgoing
import io.livekit.android.room.datastream.ByteStreamInfo
import io.livekit.android.test.BaseTest
import io.livekit.android.test.mock.room.datastream.outgoing.MockStreamDestination
import kotlinx.coroutines.launch
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Test
import kotlin.math.roundToInt
class ByteStreamSenderTest : BaseTest() {
companion object {
val CHUNK_SIZE = 2048
}
@Test
fun sendsSmallBytes() = runTest {
val destination = MockStreamDestination<ByteArray>(CHUNK_SIZE)
val sender = ByteStreamSender(
info = createInfo(),
destination = destination,
)
val job = launch {
sender.write(ByteArray(100))
sender.close()
}
job.join()
assertFalse(destination.isOpen)
assertEquals(1, destination.writtenChunks.size)
assertEquals(100, destination.writtenChunks[0].size)
}
@Test
fun sendsLargeBytes() = runTest {
val destination = MockStreamDestination<ByteArray>(CHUNK_SIZE)
val sender = ByteStreamSender(
info = createInfo(),
destination = destination,
)
val bytes = ByteArray((CHUNK_SIZE * 1.5).roundToInt())
val job = launch {
sender.write(bytes)
sender.close()
}
job.join()
assertFalse(destination.isOpen)
assertEquals(2, destination.writtenChunks.size)
assertEquals(CHUNK_SIZE, destination.writtenChunks[0].size)
assertEquals(bytes.size - CHUNK_SIZE, destination.writtenChunks[1].size)
}
fun createInfo(): ByteStreamInfo = ByteStreamInfo(id = "stream_id", topic = "topic", timestampMs = 0, totalSize = null, attributes = mapOf(), mimeType = "", name = null)
}
... ...
/*
* Copyright 2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.room.datastream.outgoing
import io.livekit.android.room.datastream.TextStreamInfo
import io.livekit.android.test.BaseTest
import io.livekit.android.test.mock.room.datastream.outgoing.MockStreamDestination
import kotlinx.coroutines.launch
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertNotEquals
import org.junit.Test
class TextStreamSenderTest : BaseTest() {
companion object {
val CHUNK_SIZE = 20
}
@Test
fun sendsSingle() = runTest {
val destination = MockStreamDestination<String>(CHUNK_SIZE)
val sender = TextStreamSender(
info = createInfo(),
destination = destination,
)
val text = "abcdefghi"
val job = launch {
sender.write(text)
sender.close()
}
job.join()
assertFalse(destination.isOpen)
assertEquals(1, destination.writtenChunks.size)
assertEquals(text, destination.writtenChunks[0].decodeToString())
}
@Test
fun sendsChunks() = runTest {
val destination = MockStreamDestination<String>(CHUNK_SIZE)
val sender = TextStreamSender(
info = createInfo(),
destination = destination,
)
val text = with(StringBuilder()) {
for (i in 1..CHUNK_SIZE) {
append("abcdefghi")
}
toString()
}
val job = launch {
sender.write(text)
sender.close()
}
job.join()
assertFalse(destination.isOpen)
assertNotEquals(1, destination.writtenChunks.size)
val writtenString = with(StringBuilder()) {
for (chunk in destination.writtenChunks) {
append(chunk.decodeToString())
}
toString()
}
assertEquals(text, writtenString)
}
fun createInfo(): TextStreamInfo = TextStreamInfo(
id = "stream_id",
topic = "topic",
timestampMs = 0,
totalSize = null,
attributes = mapOf(),
operationType = TextStreamInfo.OperationType.CREATE,
version = 0,
replyToStreamId = null,
attachedStreamIds = listOf(),
generated = false,
)
}
... ...
/*
* Copyright 2023-2024 LiveKit, Inc.
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -58,7 +58,7 @@ class ParticipantTest {
assertEquals(INFO.name, participant.name)
assertEquals(Participant.Kind.fromProto(INFO.kind), participant.kind)
assertEquals(INFO.attributesMap, participant.attributes)
assertEquals(Participant.State.fromProto(INFO.state), participant.state)
assertEquals(INFO, participant.participantInfo)
}
... ... @@ -151,6 +151,23 @@ class ParticipantTest {
}
@Test
fun setStateChangedEvent() = runTest {
val eventCollector = EventCollector(participant.events, coroutineRule.scope)
participant.state = Participant.State.JOINED
val events = eventCollector.stopCollecting()
assertEquals(1, events.size)
assertEquals(true, events[0] is ParticipantEvent.StateChanged)
val event = events[0] as ParticipantEvent.StateChanged
assertEquals(participant, event.participant)
assertEquals(Participant.State.JOINED, event.newState)
assertEquals(Participant.State.UNKNOWN, event.oldState)
}
@Test
fun addTrackPublication() = runTest {
val audioPublication = TrackPublication(TRACK_INFO, null, participant)
participant.addTrackPublication(audioPublication)
... ... @@ -197,6 +214,7 @@ class ParticipantTest {
.setName("name")
.setKind(LivekitModels.ParticipantInfo.Kind.STANDARD)
.putAttributes("attribute", "value")
.setState(LivekitModels.ParticipantInfo.State.JOINED)
.build()
val TRACK_INFO = LivekitModels.TrackInfo.newBuilder()
... ...
Subproject commit 02ee5e6947593443d0dfc90cae0b27ce03b6c1fe
Subproject commit 499c17c48063582ac2af0a021827fab18356cc29
... ...