davidliu
Committed by GitHub

Refactor sendData to return Result instead of throwing exceptions (#703)

* Switch away from throwing exceptions internally where not needed to avoid crashes.

* changeset
正在显示 18 个修改的文件 包含 324 行增加135 行删除
---
"client-sdk-android": minor
---
Refactor some internal data message sending methods to use Result instead of throwing Exceptions to
fix crashes.
... ...
... ... @@ -148,7 +148,10 @@ internal constructor(timeout: Duration) : AudioTrackSink {
)
try {
sender.write(audioData)
val result = sender.write(audioData)
if (result.isFailure) {
result.exceptionOrNull()?.let { throw it }
}
sender.close()
} catch (e: Exception) {
sender.close(e.localizedMessage)
... ...
... ... @@ -17,6 +17,7 @@
package io.livekit.android.room
import android.os.SystemClock
import androidx.annotation.CheckResult
import androidx.annotation.VisibleForTesting
import com.google.protobuf.ByteString
import com.vdurmont.semver4j.Semver
... ... @@ -82,7 +83,6 @@ import livekit.org.webrtc.RtpSender
import livekit.org.webrtc.RtpTransceiver
import livekit.org.webrtc.RtpTransceiver.RtpTransceiverInit
import livekit.org.webrtc.SessionDescription
import java.net.ConnectException
import java.nio.ByteBuffer
import javax.inject.Inject
import javax.inject.Named
... ... @@ -445,7 +445,7 @@ internal constructor(
* reconnect Signal and PeerConnections
*/
@Synchronized
@VisibleForTesting
@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
fun reconnect() {
if (reconnectingJob?.isActive == true) {
LKLog.d { "Reconnection is already in progress" }
... ... @@ -625,22 +625,32 @@ internal constructor(
}
}
internal suspend fun sendData(dataPacket: LivekitModels.DataPacket) {
ensurePublisherConnected(dataPacket.kind)
@CheckResult
internal suspend fun sendData(dataPacket: LivekitModels.DataPacket): Result<Unit> {
try {
ensurePublisherConnected(dataPacket.kind)
val buf = DataChannel.Buffer(
ByteBuffer.wrap(dataPacket.toByteArray()),
true,
)
val buf = DataChannel.Buffer(
ByteBuffer.wrap(dataPacket.toByteArray()),
true,
)
val channel = dataChannelForKind(dataPacket.kind)
?: throw TrackException.PublishException("channel not established for ${dataPacket.kind.name}")
val channel = dataChannelForKind(dataPacket.kind)
?: throw RoomException.ConnectException("channel not established for ${dataPacket.kind.name}")
channel.send(buf)
channel.send(buf)
} catch (e: Exception) {
return Result.failure(e)
}
return Result.success(Unit)
}
internal suspend fun waitForBufferStatusLow(kind: LivekitModels.DataPacket.Kind) {
ensurePublisherConnected(kind)
try {
ensurePublisherConnected(kind)
} catch (e: Exception) {
return
}
val manager = when (kind) {
LivekitModels.DataPacket.Kind.RELIABLE -> reliableDataChannelManager
LivekitModels.DataPacket.Kind.LOSSY -> lossyDataChannelManager
... ... @@ -650,11 +660,12 @@ internal constructor(
}
if (manager == null) {
throw IllegalStateException("Not connected!")
return
}
manager.waitForBufferedAmountLow(DATA_CHANNEL_LOW_THRESHOLD.toLong())
}
@Throws(exceptionClasses = [RoomException.ConnectException::class])
private suspend fun ensurePublisherConnected(kind: LivekitModels.DataPacket.Kind) {
if (!isSubscriberPrimary) {
return
... ... @@ -671,7 +682,7 @@ internal constructor(
this.negotiatePublisher()
}
val targetChannel = dataChannelForKind(kind) ?: throw IllegalArgumentException("Unknown data packet kind!")
val targetChannel = dataChannelForKind(kind) ?: throw RoomException.ConnectException("Publisher isn't setup yet! Is room not connected?!")
if (targetChannel.state() == DataChannel.State.OPEN) {
return
}
... ... @@ -685,14 +696,14 @@ internal constructor(
delay(50)
}
throw ConnectException("could not establish publisher connection")
throw RoomException.ConnectException("could not establish publisher connection")
}
private fun dataChannelForKind(kind: LivekitModels.DataPacket.Kind) =
when (kind) {
LivekitModels.DataPacket.Kind.RELIABLE -> reliableDataChannel
LivekitModels.DataPacket.Kind.LOSSY -> lossyDataChannel
else -> null
LivekitModels.DataPacket.Kind.UNRECOGNIZED -> throw IllegalArgumentException("Unknown data packet kind!")
}
private fun getPublisherOfferConstraints(): MediaConstraints {
... ... @@ -1137,6 +1148,7 @@ internal constructor(
val dataChannelInfos = LivekitModels.DataPacket.Kind.values()
.toList()
.filterNot { it == LivekitModels.DataPacket.Kind.UNRECOGNIZED }
.mapNotNull { kind -> dataChannelForKind(kind) }
.map { dataChannel ->
LivekitRtc.DataChannelInfo.newBuilder()
... ...
... ... @@ -17,13 +17,48 @@
package io.livekit.android.room.datastream
sealed class StreamException(message: String? = null) : Exception(message) {
/**
* Unable to open a stream with the same ID as an existing open stream.
*/
class AlreadyOpenedException : StreamException()
class AbnormalEndException(message: String?) : StreamException(message)
/**
* Stream closed abnormally by remote participant.
*/
class AbnormalEndException(message: String? = null) : StreamException(message)
/**
* Incoming chunk data could not be decoded.
*/
class DecodeFailedException : StreamException()
/**
* Length exceeded total length specified in stream header.
*/
class LengthExceededException : StreamException()
/**
* Length is less than total length specified in stream header.
*/
class IncompleteException : StreamException()
class TerminatedException : StreamException()
/**
* Stream terminated before completion.
*/
class TerminatedException(message: String? = null) : StreamException(message)
/**
* Cannot perform operations on an unknown stream.
*/
class UnknownStreamException : StreamException()
/**
* Given destination URL is not a directory.
*/
class NotDirectoryException : StreamException()
/**
* Unable to read information about the file to send.
*/
class FileInfoUnavailableException : StreamException()
}
... ...
... ... @@ -16,21 +16,31 @@
package io.livekit.android.room.datastream.outgoing
import androidx.annotation.CheckResult
import io.livekit.android.room.datastream.StreamException
abstract class BaseStreamSender<T>(
internal val destination: StreamDestination<T>,
) {
suspend fun write(data: T) {
val isOpen: Boolean
get() = destination.isOpen
/**
* Write to the stream.
*/
@CheckResult
suspend fun write(data: T): Result<Unit> {
if (!destination.isOpen) {
throw StreamException.TerminatedException()
return Result.failure(StreamException.TerminatedException())
}
writeImpl(data)
return writeImpl(data)
}
internal abstract suspend fun writeImpl(data: T)
@CheckResult
internal abstract suspend fun writeImpl(data: T): Result<Unit>
suspend fun close(reason: String? = null) {
destination.close(reason)
}
... ... @@ -41,7 +51,9 @@ abstract class BaseStreamSender<T>(
*/
interface StreamDestination<T> {
val isOpen: Boolean
suspend fun write(data: T, chunker: DataChunker<T>)
@CheckResult
suspend fun write(data: T, chunker: DataChunker<T>): Result<Unit>
suspend fun close(reason: String?)
}
... ...
... ... @@ -30,8 +30,9 @@ class ByteStreamSender(
val info: ByteStreamInfo,
destination: StreamDestination<ByteArray>,
) : BaseStreamSender<ByteArray>(destination = destination) {
override suspend fun writeImpl(data: ByteArray) {
destination.write(data, byteDataChunker)
override suspend fun writeImpl(data: ByteArray): Result<Unit> {
return destination.write(data, byteDataChunker)
}
}
... ... @@ -50,9 +51,7 @@ private val byteDataChunker: DataChunker<ByteArray> = { data: ByteArray, chunkSi
}
/**
* Reads the file and writes it to the data stream.
*
* @throws
* Reads the file from [filePath] and writes it to the data stream.
*/
suspend fun ByteStreamSender.writeFile(filePath: String) {
write(FileSystem.SYSTEM.source(filePath.toPath()))
... ... @@ -68,14 +67,23 @@ suspend fun ByteStreamSender.write(input: InputStream) {
/**
* Reads the source and sends it to the data stream.
*/
suspend fun ByteStreamSender.write(source: Source) {
suspend fun ByteStreamSender.write(source: Source): Result<Unit> {
val buffer = Buffer()
while (true) {
val readLen = source.read(buffer, 4096)
if (readLen == -1L) {
break
}
try {
val readLen = source.read(buffer, 4096)
if (readLen == -1L) {
break
}
write(buffer.readByteArray())
val result = write(buffer.readByteArray())
if (result.isFailure) {
return result
}
} catch (e: Exception) {
return Result.failure(e)
}
}
return Result.success(Unit)
}
... ...
... ... @@ -16,6 +16,7 @@
package io.livekit.android.room.datastream.outgoing
import androidx.annotation.CheckResult
import com.google.protobuf.ByteString
import io.livekit.android.room.RTCEngine
import io.livekit.android.room.datastream.ByteStreamInfo
... ... @@ -36,11 +37,15 @@ import javax.inject.Inject
interface OutgoingDataStreamManager {
/**
* Start sending a stream of text. Call [TextStreamSender.close] when finished sending.
*
* @throws StreamException if the stream failed to open.
*/
suspend fun streamText(options: StreamTextOptions = StreamTextOptions()): TextStreamSender
/**
* Start sending a stream of bytes. Call [ByteStreamSender.close] when finished sending.
*
* @throws StreamException if the stream failed to open.
*/
suspend fun streamBytes(options: StreamBytesOptions): ByteStreamSender
}
... ... @@ -63,10 +68,11 @@ constructor(
private val openStreams = Collections.synchronizedMap(mutableMapOf<String, Descriptor>())
@CheckResult
private suspend fun openStream(
info: StreamInfo,
destinationIdentities: List<Participant.Identity> = emptyList(),
) {
): Result<Unit> {
if (openStreams.containsKey(info.id)) {
throw StreamException.AlreadyOpenedException()
}
... ... @@ -112,15 +118,20 @@ constructor(
build()
}
engine.sendData(headerPacket)
val result = engine.sendData(headerPacket)
if (result.isFailure) {
return result
}
val descriptor = Descriptor(info, destinationIdentityStrings)
openStreams[info.id] = descriptor
LKLog.d { "Opened send stream ${info.id}" }
return Result.success(Unit)
}
private suspend fun sendChunk(streamId: String, dataChunk: ByteArray) {
@CheckResult
private suspend fun sendChunk(streamId: String, dataChunk: ByteArray): Result<Unit> {
val descriptor = openStreams[streamId] ?: throw StreamException.UnknownStreamException()
val nextChunkIndex = descriptor.nextChunkIndex.getAndIncrement()
... ... @@ -137,7 +148,7 @@ constructor(
}
engine.waitForBufferStatusLow(DataPacket.Kind.RELIABLE)
engine.sendData(chunkPacket)
return engine.sendData(chunkPacket)
}
private suspend fun closeStream(streamId: String, reason: String? = null) {
... ... @@ -157,7 +168,12 @@ constructor(
}
engine.waitForBufferStatusLow(DataPacket.Kind.RELIABLE)
engine.sendData(trailerPacket)
val result = engine.sendData(trailerPacket)
if (result.isFailure) {
// Log close failure only for now.
LKLog.w(result.exceptionOrNull()) { "Error when closing stream!" }
}
openStreams.remove(streamId)
LKLog.d { "Closed send stream $streamId" }
... ... @@ -178,7 +194,11 @@ constructor(
)
val streamId = options.streamId
openStream(streamInfo, options.destinationIdentities)
val result = openStream(streamInfo, options.destinationIdentities)
if (result.isFailure) {
throw result.exceptionOrNull() ?: StreamException.TerminatedException("Unknown failure when opening the stream!")
}
val destination = ManagerStreamDestination<String>(streamId)
return TextStreamSender(
... ... @@ -199,8 +219,11 @@ constructor(
)
val streamId = options.streamId
openStream(streamInfo, options.destinationIdentities)
val result = openStream(streamInfo, options.destinationIdentities)
if (result.isFailure) {
throw result.exceptionOrNull() ?: StreamException.TerminatedException("Unknown failure when opening the stream!")
}
val destination = ManagerStreamDestination<ByteArray>(streamId)
return ByteStreamSender(
streamInfo,
... ... @@ -212,12 +235,19 @@ constructor(
override val isOpen: Boolean
get() = openStreams.contains(streamId)
override suspend fun write(data: T, chunker: DataChunker<T>) {
override suspend fun write(data: T, chunker: DataChunker<T>): Result<Unit> {
if (!isOpen) {
return Result.failure(StreamException.TerminatedException("Stream is closed!"))
}
val chunks = chunker.invoke(data, RTCEngine.MAX_DATA_PACKET_SIZE)
for (chunk in chunks) {
sendChunk(streamId, chunk)
val result = sendChunk(streamId, chunk)
if (result.isFailure) {
return result
}
}
return Result.success(Unit)
}
override suspend fun close(reason: String?) {
... ...
... ... @@ -23,8 +23,8 @@ class TextStreamSender(
val info: TextStreamInfo,
destination: StreamDestination<String>,
) : BaseStreamSender<String>(destination) {
override suspend fun writeImpl(data: String) {
destination.write(data, stringChunker)
override suspend fun writeImpl(data: String): Result<Unit> {
return destination.write(data, stringChunker)
}
}
... ...
/*
* Copyright 2024 LiveKit, Inc.
* Copyright 2024-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -68,7 +68,10 @@ private suspend fun collectPublisherMetrics(room: Room, rtcEngine: RTCEngine) {
}
try {
rtcEngine.sendData(dataPacket)
val result = rtcEngine.sendData(dataPacket)
result.exceptionOrNull()?.let {
throw it
}
} catch (e: Exception) {
LKLog.i(e) { "Error sending metrics: " }
}
... ... @@ -98,7 +101,10 @@ private suspend fun collectSubscriberMetrics(room: Room, rtcEngine: RTCEngine) {
}
try {
rtcEngine.sendData(dataPacket)
val result = rtcEngine.sendData(dataPacket)
result.exceptionOrNull()?.let {
throw it
}
} catch (e: Exception) {
LKLog.i(e) { "Error sending metrics: " }
}
... ...
... ... @@ -19,6 +19,7 @@ package io.livekit.android.room.participant
import android.Manifest
import android.content.Context
import android.content.Intent
import androidx.annotation.CheckResult
import androidx.annotation.VisibleForTesting
import com.google.protobuf.ByteString
import com.vdurmont.semver4j.Semver
... ... @@ -62,7 +63,6 @@ import kotlinx.coroutines.Job
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.sync.Mutex
... ... @@ -315,7 +315,8 @@ internal constructor(
*
* For screenshare audio, a [ScreenAudioCapturer] can be used.
*
* @param mediaProjectionPermissionResultData The resultData returned from launching
* @param screenCaptureParams When enabling the screenshare, this must be provided with
* [ScreenCaptureParams.mediaProjectionPermissionResultData] containing resultData returned from launching
* [MediaProjectionManager.createScreenCaptureIntent()](https://developer.android.com/reference/android/media/projection/MediaProjectionManager#createScreenCaptureIntent()).
* @throws IllegalArgumentException if attempting to enable screenshare without [mediaProjectionPermissionResultData]
* @see Room.screenShareTrackCaptureDefaults
... ... @@ -904,14 +905,17 @@ internal constructor(
* @param reliability for delivery guarantee, use RELIABLE. for fastest delivery without guarantee, use LOSSY
* @param topic the topic under which the message was published
* @param identities list of participant identities to deliver the payload, null to deliver to everyone
*
* @return A [Result] that succeeds if the publish succeeded, or a failure containing the exception.
*/
@Suppress("unused")
@CheckResult
suspend fun publishData(
data: ByteArray,
reliability: DataPublishReliability = DataPublishReliability.RELIABLE,
topic: String? = null,
identities: List<Identity>? = null,
) {
): Result<Unit> {
if (data.size > RTCEngine.MAX_DATA_PACKET_SIZE) {
throw IllegalArgumentException("cannot publish data larger than " + RTCEngine.MAX_DATA_PACKET_SIZE)
}
... ... @@ -935,7 +939,7 @@ internal constructor(
.setKind(kind)
.build()
engine.sendData(dataPacket)
return engine.sendData(dataPacket)
}
/**
... ... @@ -946,13 +950,16 @@ internal constructor(
*
* @param code an integer representing the DTMF signal code
* @param digit the string representing the DTMF digit (e.g., "1", "#", "*")
*
* @return A [Result] that succeeds if the publish succeeded, or a failure containing the exception.
*/
@Suppress("unused")
@CheckResult
suspend fun publishDtmf(
code: Int,
digit: String,
) {
): Result<Unit> {
val sipDTMF = LivekitModels.SipDTMF.newBuilder().setCode(code)
.setDigit(digit)
.build()
... ... @@ -962,7 +969,7 @@ internal constructor(
.setKind(LivekitModels.DataPacket.Kind.RELIABLE)
.build()
engine.sendData(dataPacket)
return engine.sendData(dataPacket)
}
/**
... ... @@ -998,7 +1005,6 @@ internal constructor(
* @see RpcInvocationData
* @see performRpc
*/
@Suppress("RedundantSuspendModifier")
override suspend fun registerRpcMethod(
method: String,
handler: RpcHandler,
... ... @@ -1088,7 +1094,7 @@ internal constructor(
val requestId = UUID.randomUUID().toString()
publishRpcRequest(
val result = publishRpcRequest(
destinationIdentity = destinationIdentity,
requestId = requestId,
method = method,
... ... @@ -1096,6 +1102,12 @@ internal constructor(
responseTimeout = responseTimeout - maxRoundTripLatency,
)
if (result.isFailure) {
val exception = result.exceptionOrNull() as? RpcError
?: RpcError.BuiltinRpcError.SEND_FAILED.create(data = "Error while sending rpc request.", cause = result.exceptionOrNull())
throw exception
}
val responsePayload = suspendCancellableCoroutine { continuation ->
var ackTimeoutJob: Job? = null
var responseTimeoutJob: Job? = null
... ... @@ -1149,13 +1161,25 @@ internal constructor(
return@coroutineScope responsePayload
}
@CheckResult
private suspend fun rpcSendData(dataPacket: DataPacket): Result<Unit> {
val result = engine.sendData(dataPacket)
return if (result.isFailure) {
Result.failure(RpcError.BuiltinRpcError.SEND_FAILED.create(cause = result.exceptionOrNull()))
} else {
result
}
}
@CheckResult
private suspend fun publishRpcRequest(
destinationIdentity: Identity,
requestId: String,
method: String,
payload: String,
responseTimeout: Duration = 10.seconds,
) {
): Result<Unit> {
if (payload.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) {
throw IllegalArgumentException("cannot publish data larger than " + RTCEngine.MAX_DATA_PACKET_SIZE)
}
... ... @@ -1174,15 +1198,16 @@ internal constructor(
build()
}
engine.sendData(dataPacket)
return rpcSendData(dataPacket)
}
@CheckResult
private suspend fun publishRpcResponse(
destinationIdentity: Identity,
requestId: String,
payload: String?,
error: RpcError?,
) {
): Result<Unit> {
if (payload.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) {
throw IllegalArgumentException("cannot publish data larger than " + RTCEngine.MAX_DATA_PACKET_SIZE)
}
... ... @@ -1202,13 +1227,14 @@ internal constructor(
build()
}
engine.sendData(dataPacket)
return rpcSendData(dataPacket)
}
@CheckResult
private suspend fun publishRpcAck(
destinationIdentity: Identity,
requestId: String,
) {
): Result<Unit> {
val dataPacket = with(DataPacket.newBuilder()) {
addDestinationIdentities(destinationIdentity.value)
kind = DataPacket.Kind.RELIABLE
... ... @@ -1219,7 +1245,7 @@ internal constructor(
build()
}
engine.sendData(dataPacket)
return rpcSendData(dataPacket)
}
private fun handleIncomingRpcAck(requestId: String) {
... ... @@ -1252,7 +1278,12 @@ internal constructor(
responseTimeout: Duration,
version: Int,
) {
publishRpcAck(callerIdentity, requestId)
publishRpcAck(callerIdentity, requestId).also { result ->
if (result.isFailure) {
LKLog.w(result.exceptionOrNull()) { "Error sending ack for request $requestId." }
return
}
}
if (version != RpcManager.RPC_VERSION) {
publishRpcResponse(
... ... @@ -1260,7 +1291,12 @@ internal constructor(
requestId = requestId,
payload = null,
error = RpcError.BuiltinRpcError.UNSUPPORTED_VERSION.create(),
)
).also { result ->
if (result.isFailure) {
LKLog.w(result.exceptionOrNull()) { "Error sending error response for request $requestId." }
}
}
return
}
... ... @@ -1272,7 +1308,12 @@ internal constructor(
requestId = requestId,
payload = null,
error = RpcError.BuiltinRpcError.UNSUPPORTED_METHOD.create(),
)
).also { result ->
if (result.isFailure) {
LKLog.w(result.exceptionOrNull()) { "Error sending error response for request $requestId." }
}
}
return
}
... ... @@ -1309,7 +1350,11 @@ internal constructor(
requestId = requestId,
payload = responsePayload,
error = responseError,
)
).also { result ->
if (result.isFailure) {
LKLog.w(result.exceptionOrNull()) { "Error sending error response for request $requestId." }
}
}
}
internal fun handleParticipantDisconnect(identity: Identity) {
... ...
... ... @@ -44,6 +44,12 @@ data class RpcError(
* An optional data payload. Must be smaller than 15KB in size, or else will be truncated.
*/
val data: String = "",
/**
* The local cause of the error, if any. This will not be passed over the wire to the remote.
*/
override val cause: Throwable? = null,
) : Exception(message) {
enum class BuiltinRpcError(val code: Int, val message: String) {
... ... @@ -61,8 +67,8 @@ data class RpcError(
UNSUPPORTED_VERSION(1404, "Unsupported RPC version"),
;
fun create(data: String = ""): RpcError {
return RpcError(code, message, data)
fun create(data: String = "", cause: Throwable? = null): RpcError {
return RpcError(code, message, data, cause)
}
}
... ...
... ... @@ -74,6 +74,7 @@ class DataChannelManager(
return
}
disposed = true
bufferedAmount = 0
}
executeOnRTCThread {
dataChannel.unregisterObserver()
... ...
... ... @@ -27,7 +27,6 @@ class MockDataChannel(private val label: String?) : DataChannel(1L) {
}
override fun unregisterObserver() {
observer = null
}
override fun label(): String? {
... ...
... ... @@ -16,6 +16,7 @@
package io.livekit.android.test.mock.room.datastream.outgoing
import io.livekit.android.room.datastream.StreamException
import io.livekit.android.room.datastream.outgoing.DataChunker
import io.livekit.android.room.datastream.outgoing.StreamDestination
... ... @@ -23,12 +24,18 @@ class MockStreamDestination<T>(val chunkSize: Int) : StreamDestination<T> {
override var isOpen: Boolean = true
val writtenChunks = mutableListOf<ByteArray>()
override suspend fun write(data: T, chunker: DataChunker<T>) {
override suspend fun write(data: T, chunker: DataChunker<T>): Result<Unit> {
if (!isOpen) {
return Result.failure(StreamException.TerminatedException())
}
val chunks = chunker.invoke(data, chunkSize)
for (chunk in chunks) {
writtenChunks.add(chunk)
}
return Result.success(Unit)
}
override suspend fun close(reason: String?) {
... ...
/*
* Copyright 2023-2024 LiveKit, Inc.
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -35,8 +35,6 @@ import io.livekit.android.test.mock.createMediaStreamId
import io.livekit.android.test.mock.room.track.createMockLocalAudioTrack
import io.livekit.android.util.flow
import io.livekit.android.util.toOkioByteString
import junit.framework.Assert.assertEquals
import junit.framework.Assert.assertNotNull
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
... ... @@ -44,6 +42,8 @@ import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.launch
import livekit.LivekitRtc
import org.junit.Assert
import org.junit.Assert.assertEquals
import org.junit.Assert.assertNotNull
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito
... ...
... ... @@ -25,10 +25,10 @@ import io.livekit.android.test.mock.MockPeerConnection
import io.livekit.android.test.mock.TestData
import io.livekit.android.util.toOkioByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import livekit.LivekitModels
import livekit.LivekitRtc
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
import org.junit.Test
... ... @@ -58,22 +58,19 @@ class RoomOutgoingDataStreamMockE2ETest : MockE2ETest() {
for (i in bytesToStream.indices) {
bytesToStream[i] = i.toByte()
}
val job = launch {
val sender = room.localParticipant.streamBytes(
StreamBytesOptions(
topic = "topic",
attributes = mapOf("hello" to "world"),
streamId = "stream_id",
destinationIdentities = listOf(Participant.Identity(TestData.REMOTE_PARTICIPANT.identity)),
name = "stream_name",
totalSize = bytesToStream.size.toLong(),
),
)
sender.write(bytesToStream)
sender.close()
}
job.join()
val sender = room.localParticipant.streamBytes(
StreamBytesOptions(
topic = "topic",
attributes = mapOf("hello" to "world"),
streamId = "stream_id",
destinationIdentities = listOf(Participant.Identity(TestData.REMOTE_PARTICIPANT.identity)),
name = "stream_name",
totalSize = bytesToStream.size.toLong(),
),
)
assertTrue(sender.write(bytesToStream).isSuccess)
sender.close()
assertFalse(sender.isOpen)
val buffers = pubDataChannel.sentBuffers
... ... @@ -111,25 +108,23 @@ class RoomOutgoingDataStreamMockE2ETest : MockE2ETest() {
)
val text = "test_text"
val job = launch {
val sender = room.localParticipant.streamText(
StreamTextOptions(
topic = "topic",
attributes = mapOf("hello" to "world"),
streamId = "stream_id",
destinationIdentities = listOf(Participant.Identity(TestData.REMOTE_PARTICIPANT.identity)),
operationType = TextStreamInfo.OperationType.CREATE,
version = 0,
attachedStreamIds = emptyList(),
replyToStreamId = null,
totalSize = 3,
),
)
sender.write(text)
sender.close()
}
val sender = room.localParticipant.streamText(
StreamTextOptions(
topic = "topic",
attributes = mapOf("hello" to "world"),
streamId = "stream_id",
destinationIdentities = listOf(Participant.Identity(TestData.REMOTE_PARTICIPANT.identity)),
operationType = TextStreamInfo.OperationType.CREATE,
version = 0,
attachedStreamIds = emptyList(),
replyToStreamId = null,
totalSize = 3,
),
)
assertTrue(sender.write(text).isSuccess)
sender.close()
job.join()
assertFalse(sender.isOpen)
val buffers = pubDataChannel.sentBuffers
... ...
... ... @@ -19,12 +19,14 @@ package io.livekit.android.room.datastream.outgoing
import io.livekit.android.room.datastream.ByteStreamInfo
import io.livekit.android.test.BaseTest
import io.livekit.android.test.mock.room.datastream.outgoing.MockStreamDestination
import kotlinx.coroutines.launch
import kotlinx.coroutines.ExperimentalCoroutinesApi
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
import org.junit.Test
import kotlin.math.roundToInt
@OptIn(ExperimentalCoroutinesApi::class)
class ByteStreamSenderTest : BaseTest() {
companion object {
... ... @@ -39,12 +41,9 @@ class ByteStreamSenderTest : BaseTest() {
destination = destination,
)
val job = launch {
sender.write(ByteArray(100))
sender.close()
}
job.join()
val result = sender.write(ByteArray(100))
assertTrue(result.isSuccess)
sender.close()
assertFalse(destination.isOpen)
assertEquals(1, destination.writtenChunks.size)
... ... @@ -61,12 +60,9 @@ class ByteStreamSenderTest : BaseTest() {
val bytes = ByteArray((CHUNK_SIZE * 1.5).roundToInt())
val job = launch {
sender.write(bytes)
sender.close()
}
job.join()
val result = sender.write(bytes)
assertTrue(result.isSuccess)
sender.close()
assertFalse(destination.isOpen)
assertEquals(2, destination.writtenChunks.size)
... ... @@ -74,5 +70,18 @@ class ByteStreamSenderTest : BaseTest() {
assertEquals(bytes.size - CHUNK_SIZE, destination.writtenChunks[1].size)
}
@Test
fun writeFailsAfterClose() = runTest {
val destination = MockStreamDestination<ByteArray>(CHUNK_SIZE)
val sender = ByteStreamSender(
info = createInfo(),
destination = destination,
)
assertTrue(sender.write(ByteArray(100)).isSuccess)
sender.close()
assertTrue(sender.write(ByteArray(100)).isFailure)
}
fun createInfo(): ByteStreamInfo = ByteStreamInfo(id = "stream_id", topic = "topic", timestampMs = 0, totalSize = null, attributes = mapOf(), mimeType = "", name = null)
}
... ...
... ... @@ -19,12 +19,14 @@ package io.livekit.android.room.datastream.outgoing
import io.livekit.android.room.datastream.TextStreamInfo
import io.livekit.android.test.BaseTest
import io.livekit.android.test.mock.room.datastream.outgoing.MockStreamDestination
import kotlinx.coroutines.launch
import kotlinx.coroutines.ExperimentalCoroutinesApi
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertNotEquals
import org.junit.Assert.assertTrue
import org.junit.Test
@OptIn(ExperimentalCoroutinesApi::class)
class TextStreamSenderTest : BaseTest() {
companion object {
... ... @@ -40,12 +42,9 @@ class TextStreamSenderTest : BaseTest() {
)
val text = "abcdefghi"
val job = launch {
sender.write(text)
sender.close()
}
job.join()
val result = sender.write(text)
assertTrue(result.isSuccess)
sender.close()
assertFalse(destination.isOpen)
assertEquals(1, destination.writtenChunks.size)
... ... @@ -67,12 +66,9 @@ class TextStreamSenderTest : BaseTest() {
toString()
}
val job = launch {
sender.write(text)
sender.close()
}
job.join()
val result = sender.write(text)
assertTrue(result.isSuccess)
sender.close()
assertFalse(destination.isOpen)
assertNotEquals(1, destination.writtenChunks.size)
... ... @@ -87,6 +83,25 @@ class TextStreamSenderTest : BaseTest() {
assertEquals(text, writtenString)
}
@Test
fun writeFailsAfterClose() = runTest {
val destination = MockStreamDestination<String>(CHUNK_SIZE)
val sender = TextStreamSender(
info = createInfo(),
destination = destination,
)
val text = "abcdefghi"
assertTrue(sender.write(text).isSuccess)
sender.close()
assertTrue(sender.write(text).isFailure)
assertFalse(destination.isOpen)
assertEquals(1, destination.writtenChunks.size)
assertEquals(text, destination.writtenChunks[0].decodeToString())
}
fun createInfo(): TextStreamInfo = TextStreamInfo(
id = "stream_id",
topic = "topic",
... ...