davidliu
Committed by GitHub

Implement RPC (#578)

* Update protocol

* fix build errors

* Implement RPC

* tests

* spotless

* comment fixes
---
"client-sdk-android": minor
---
Implement RPC
... ...
/*
* Copyright 2023-2024 LiveKit, Inc.
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -273,6 +273,9 @@ enum class DisconnectReason {
MIGRATION,
SIGNAL_CLOSE,
ROOM_CLOSED,
USER_UNAVAILABLE,
USER_REJECTED,
SIP_TRUNK_FAILURE,
}
/**
... ... @@ -290,6 +293,9 @@ fun LivekitModels.DisconnectReason?.convert(): DisconnectReason {
LivekitModels.DisconnectReason.MIGRATION -> DisconnectReason.MIGRATION
LivekitModels.DisconnectReason.SIGNAL_CLOSE -> DisconnectReason.SIGNAL_CLOSE
LivekitModels.DisconnectReason.ROOM_CLOSED -> DisconnectReason.ROOM_CLOSED
LivekitModels.DisconnectReason.USER_UNAVAILABLE -> DisconnectReason.USER_UNAVAILABLE
LivekitModels.DisconnectReason.USER_REJECTED -> DisconnectReason.USER_REJECTED
LivekitModels.DisconnectReason.SIP_TRUNK_FAILURE -> DisconnectReason.SIP_TRUNK_FAILURE
LivekitModels.DisconnectReason.UNKNOWN_REASON,
LivekitModels.DisconnectReason.UNRECOGNIZED,
null,
... ...
/*
* Copyright 2023-2024 LiveKit, Inc.
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -19,6 +19,7 @@ package io.livekit.android.room
import android.os.SystemClock
import androidx.annotation.VisibleForTesting
import com.google.protobuf.ByteString
import com.vdurmont.semver4j.Semver
import io.livekit.android.ConnectOptions
import io.livekit.android.RoomOptions
import io.livekit.android.dagger.InjectionNames
... ... @@ -148,6 +149,9 @@ internal constructor(
private var lastRoomOptions: RoomOptions? = null
private var participantSid: String? = null
internal val serverVersion: Semver?
get() = client.serverVersion
private val publisherObserver = PublisherTransportObserver(this, client)
private val subscriberObserver = SubscriberTransportObserver(this, client)
... ... @@ -777,6 +781,7 @@ internal constructor(
fun onLocalTrackUnpublished(trackUnpublished: LivekitRtc.TrackUnpublishedResponse)
fun onTranscriptionReceived(transcription: LivekitModels.Transcription)
fun onLocalTrackSubscribed(trackSubscribed: LivekitRtc.TrackSubscribed)
fun onRpcPacketReceived(dp: LivekitModels.DataPacket)
}
companion object {
... ... @@ -792,7 +797,7 @@ internal constructor(
*/
@VisibleForTesting
const val LOSSY_DATA_CHANNEL_LABEL = "_lossy"
internal const val MAX_DATA_PACKET_SIZE = 15000
internal const val MAX_DATA_PACKET_SIZE = 15360 // 15 KB
private const val MAX_RECONNECT_RETRIES = 10
private const val MAX_RECONNECT_TIMEOUT = 60 * 1000
private const val MAX_ICE_CONNECT_TIMEOUT_MS = 20000
... ... @@ -1040,13 +1045,21 @@ internal constructor(
LivekitModels.DataPacket.ValueCase.RPC_ACK,
LivekitModels.DataPacket.ValueCase.RPC_RESPONSE,
-> {
// TODO
listener?.onRpcPacketReceived(dp)
}
LivekitModels.DataPacket.ValueCase.VALUE_NOT_SET,
null,
-> {
LKLog.v { "invalid value for data packet" }
}
LivekitModels.DataPacket.ValueCase.STREAM_HEADER -> {
// TODO
}
LivekitModels.DataPacket.ValueCase.STREAM_CHUNK -> {
// TODO
}
}
}
... ...
/*
* Copyright 2023-2024 LiveKit, Inc.
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -648,6 +648,8 @@ constructor(
mutableRemoteParticipants = newParticipants
eventBus.postEvent(RoomEvent.ParticipantDisconnected(this, removedParticipant), coroutineScope)
localParticipant.handleParticipantDisconnect(identity)
}
fun getParticipantBySid(sid: String): Participant? {
... ... @@ -1195,6 +1197,10 @@ constructor(
publication?.onTranscriptionReceived(event)
}
override fun onRpcPacketReceived(dp: LivekitModels.DataPacket) {
localParticipant.handleDataPacket(dp)
}
/**
* @suppress
*/
... ...
/*
* Copyright 2023-2024 LiveKit, Inc.
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -84,7 +84,7 @@ constructor(
private var currentWs: WebSocket? = null
private var isReconnecting: Boolean = false
var listener: Listener? = null
private var serverVersion: Semver? = null
internal var serverVersion: Semver? = null
private var lastUrl: String? = null
private var lastOptions: ConnectOptions? = null
private var lastRoomOptions: RoomOptions? = null
... ... @@ -841,6 +841,7 @@ constructor(
lastUrl = null
lastOptions = null
lastRoomOptions = null
serverVersion = null
}
interface Listener {
... ...
... ... @@ -21,6 +21,7 @@ import android.content.Context
import android.content.Intent
import androidx.annotation.VisibleForTesting
import com.google.protobuf.ByteString
import com.vdurmont.semver4j.Semver
import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject
... ... @@ -47,15 +48,21 @@ import io.livekit.android.room.track.VideoCaptureParameter
import io.livekit.android.room.track.VideoCodec
import io.livekit.android.room.track.VideoEncoding
import io.livekit.android.room.util.EncodingUtils
import io.livekit.android.rpc.RpcError
import io.livekit.android.util.LKLog
import io.livekit.android.util.byteLength
import io.livekit.android.util.flow
import io.livekit.android.webrtc.sortVideoCodecPreferences
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Job
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import livekit.LivekitModels
import livekit.LivekitModels.DataPacket
import livekit.LivekitRtc
import livekit.LivekitRtc.AddTrackRequest
import livekit.LivekitRtc.SimulcastCodec
... ... @@ -67,9 +74,15 @@ import livekit.org.webrtc.RtpTransceiver.RtpTransceiverInit
import livekit.org.webrtc.SurfaceTextureHelper
import livekit.org.webrtc.VideoCapturer
import livekit.org.webrtc.VideoProcessor
import java.util.Collections
import java.util.UUID
import javax.inject.Named
import kotlin.coroutines.resume
import kotlin.math.max
import kotlin.math.min
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
class LocalParticipant
@AssistedInject
... ... @@ -105,6 +118,10 @@ internal constructor(
private val jobs = mutableMapOf<Any, Job>()
private val rpcHandlers = Collections.synchronizedMap(mutableMapOf<String, RpcHandler>()) // methodName to handler
private val pendingAcks = Collections.synchronizedMap(mutableMapOf<String, PendingRpcAck>()) // requestId to pending ack
private val pendingResponses = Collections.synchronizedMap(mutableMapOf<String, PendingRpcResponse>()) // requestId to pending response
// For ensuring that only one caller can execute setTrackEnabled at a time.
// Without it, there's a potential to create multiple of the same source,
// Camera has deadlock issues with multiple CameraCapturers trying to activate/stop.
... ... @@ -714,8 +731,8 @@ internal constructor(
}
val kind = when (reliability) {
DataPublishReliability.RELIABLE -> LivekitModels.DataPacket.Kind.RELIABLE
DataPublishReliability.LOSSY -> LivekitModels.DataPacket.Kind.LOSSY
DataPublishReliability.RELIABLE -> DataPacket.Kind.RELIABLE
DataPublishReliability.LOSSY -> DataPacket.Kind.LOSSY
}
val packetBuilder = LivekitModels.UserPacket.newBuilder().apply {
payload = ByteString.copyFrom(data)
... ... @@ -727,7 +744,7 @@ internal constructor(
addAllDestinationIdentities(identities.map { it.value })
}
}
val dataPacket = LivekitModels.DataPacket.newBuilder()
val dataPacket = DataPacket.newBuilder()
.setUser(packetBuilder)
.setKind(kind)
.build()
... ... @@ -741,9 +758,8 @@ internal constructor(
* SipDTMF message using the provided code and digit, then encapsulates it
* in a DataPacket before sending it via the engine.
*
* Parameters:
* - code: an integer representing the DTMF signal code
* - digit: the string representing the DTMF digit (e.g., "1", "#", "*")
* @param code an integer representing the DTMF signal code
* @param digit the string representing the DTMF digit (e.g., "1", "#", "*")
*/
@Suppress("unused")
... ... @@ -764,6 +780,375 @@ internal constructor(
}
/**
* Establishes the participant as a receiver for calls of the specified RPC method.
* Will overwrite any existing callback for the same method.
*
* Example:
* ```kt
* room.localParticipant.registerRpcMethod("greet") { (requestId, callerIdentity, payload, responseTimeout) ->
* Log.i("TAG", "Received greeting from ${callerIdentity}: ${payload}")
*
* // Return a string
* "Hello, ${callerIdentity}!"
* }
* ```
*
* The handler receives an [RpcInvocationData] with the following parameters:
* - `requestId`: A unique identifier for this RPC request
* - `callerIdentity`: The identity of the RemoteParticipant who initiated the RPC call
* - `payload`: The data sent by the caller (as a string)
* - `responseTimeout`: The maximum time available to return a response
*
* The handler should return a string.
* If unable to respond within [RpcInvocationData.responseTimeout], the request will result in an error on the caller's side.
*
* You may throw errors of type [RpcError] with a string `message` in the handler,
* and they will be received on the caller's side with the message intact.
* Other errors thrown in your handler will not be transmitted as-is, and will instead arrive to the caller as `1500` ("Application Error").
*
* @param method The name of the indicated RPC method
* @param handler Will be invoked when an RPC request for this method is received
* @see RpcHandler
* @see RpcInvocationData
* @see performRpc
*/
@Suppress("RedundantSuspendModifier")
suspend fun registerRpcMethod(
method: String,
handler: RpcHandler,
) {
this.rpcHandlers[method] = handler
}
/**
* Unregisters a previously registered RPC method.
*
* @param method The name of the RPC method to unregister
*/
fun unregisterRpcMethod(
method: String,
) {
this.rpcHandlers.remove(method)
}
internal fun handleDataPacket(packet: DataPacket) {
when {
packet.hasRpcRequest() -> {
val rpcRequest = packet.rpcRequest
scope.launch {
handleIncomingRpcRequest(
callerIdentity = Identity(packet.participantIdentity),
requestId = rpcRequest.id,
method = rpcRequest.method,
payload = rpcRequest.payload,
responseTimeout = rpcRequest.responseTimeoutMs.toUInt().toLong().milliseconds,
version = rpcRequest.version,
)
}
}
packet.hasRpcResponse() -> {
val rpcResponse = packet.rpcResponse
var payload: String? = null
var error: RpcError? = null
if (rpcResponse.hasPayload()) {
payload = rpcResponse.payload
} else if (rpcResponse.hasError()) {
error = RpcError.fromProto(rpcResponse.error)
}
handleIncomingRpcResponse(
requestId = rpcResponse.requestId,
payload = payload,
error = error,
)
}
packet.hasRpcAck() -> {
val rpcAck = packet.rpcAck
handleIncomingRpcAck(rpcAck.requestId)
}
}
}
/**
* Initiate an RPC call to a remote participant
* @param destinationIdentity The identity of the destination participant.
* @param method The method name to call.
* @param payload The payload to pass to the method.
* @param responseTimeout Timeout for receiving a response after initial connection.
* Defaults to 10000. Max value of UInt.MAX_VALUE milliseconds.
* @return The response payload.
* @throws RpcError on failure. Details in [RpcError.message].
*/
suspend fun performRpc(
destinationIdentity: Identity,
method: String,
payload: String,
responseTimeout: Duration = 10.seconds,
): String = coroutineScope {
val maxRoundTripLatency = 2.seconds
if (payload.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) {
throw RpcError.BuiltinRpcError.REQUEST_PAYLOAD_TOO_LARGE.create()
}
val serverVersion = engine.serverVersion
?: throw RpcError.BuiltinRpcError.SEND_FAILED.create(data = "Not connected.")
if (serverVersion < Semver("1.8.0")) {
throw RpcError.BuiltinRpcError.UNSUPPORTED_SERVER.create()
}
val requestId = UUID.randomUUID().toString()
publishRpcRequest(
destinationIdentity = destinationIdentity,
requestId = requestId,
method = method,
payload = payload,
responseTimeout = responseTimeout - maxRoundTripLatency,
)
val responsePayload = suspendCancellableCoroutine { continuation ->
var ackTimeoutJob: Job? = null
var responseTimeoutJob: Job? = null
fun cleanup() {
ackTimeoutJob?.cancel()
responseTimeoutJob?.cancel()
pendingAcks.remove(requestId)
pendingResponses.remove(requestId)
}
continuation.invokeOnCancellation { cleanup() }
ackTimeoutJob = launch {
delay(maxRoundTripLatency)
val receivedAck = pendingAcks.remove(requestId) == null
if (!receivedAck) {
pendingResponses.remove(requestId)
continuation.cancel(RpcError.BuiltinRpcError.CONNECTION_TIMEOUT.create())
}
}
pendingAcks[requestId] = PendingRpcAck(
participantIdentity = destinationIdentity,
onResolve = { ackTimeoutJob.cancel() },
)
responseTimeoutJob = launch {
delay(responseTimeout)
val receivedResponse = pendingResponses.remove(requestId) == null
if (!receivedResponse) {
continuation.cancel(RpcError.BuiltinRpcError.RESPONSE_TIMEOUT.create())
}
}
pendingResponses[requestId] = PendingRpcResponse(
participantIdentity = destinationIdentity,
onResolve = { payload, error ->
if (pendingAcks.containsKey(requestId)) {
LKLog.i { "RPC response received before ack, id: $requestId" }
}
cleanup()
if (error != null) {
continuation.cancel(error)
} else {
continuation.resume(payload ?: "")
}
},
)
}
return@coroutineScope responsePayload
}
private suspend fun publishRpcRequest(
destinationIdentity: Identity,
requestId: String,
method: String,
payload: String,
responseTimeout: Duration = 10.seconds,
) {
if (payload.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) {
throw IllegalArgumentException("cannot publish data larger than " + RTCEngine.MAX_DATA_PACKET_SIZE)
}
val dataPacket = with(DataPacket.newBuilder()) {
addDestinationIdentities(destinationIdentity.value)
kind = DataPacket.Kind.RELIABLE
rpcRequest = with(LivekitModels.RpcRequest.newBuilder()) {
this.id = requestId
this.method = method
this.payload = payload
this.responseTimeoutMs = responseTimeout.inWholeMilliseconds.toUInt().toInt()
build()
}
build()
}
engine.sendData(dataPacket)
}
private suspend fun publishRpcResponse(
destinationIdentity: Identity,
requestId: String,
payload: String?,
error: RpcError?,
) {
if (payload.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) {
throw IllegalArgumentException("cannot publish data larger than " + RTCEngine.MAX_DATA_PACKET_SIZE)
}
val dataPacket = with(DataPacket.newBuilder()) {
addDestinationIdentities(destinationIdentity.value)
kind = DataPacket.Kind.RELIABLE
rpcResponse = with(LivekitModels.RpcResponse.newBuilder()) {
this.requestId = requestId
if (error != null) {
this.error = error.toProto()
} else {
this.payload = payload ?: ""
}
build()
}
build()
}
engine.sendData(dataPacket)
}
private suspend fun publishRpcAck(
destinationIdentity: Identity,
requestId: String,
) {
val dataPacket = with(DataPacket.newBuilder()) {
addDestinationIdentities(destinationIdentity.value)
kind = DataPacket.Kind.RELIABLE
rpcAck = with(LivekitModels.RpcAck.newBuilder()) {
this.requestId = requestId
build()
}
build()
}
engine.sendData(dataPacket)
}
private fun handleIncomingRpcAck(requestId: String) {
val handler = this.pendingAcks.remove(requestId)
if (handler != null) {
handler.onResolve()
} else {
LKLog.e { "Ack received for unexpected RPC request, id = $requestId" }
}
}
private fun handleIncomingRpcResponse(
requestId: String,
payload: String?,
error: RpcError?,
) {
val handler = this.pendingResponses.remove(requestId)
if (handler != null) {
handler.onResolve(payload, error)
} else {
LKLog.e { "Response received for unexpected RPC request, id = $requestId" }
}
}
private suspend fun handleIncomingRpcRequest(
callerIdentity: Identity,
requestId: String,
method: String,
payload: String,
responseTimeout: Duration,
version: Int,
) {
publishRpcAck(callerIdentity, requestId)
if (version != 1) {
publishRpcResponse(
destinationIdentity = callerIdentity,
requestId = requestId,
payload = null,
error = RpcError.BuiltinRpcError.UNSUPPORTED_VERSION.create(),
)
return
}
val handler = this.rpcHandlers[method]
if (handler == null) {
publishRpcResponse(
destinationIdentity = callerIdentity,
requestId = requestId,
payload = null,
error = RpcError.BuiltinRpcError.UNSUPPORTED_METHOD.create(),
)
return
}
var responseError: RpcError? = null
var responsePayload: String? = null
try {
val response = handler.invoke(
RpcInvocationData(
requestId = requestId,
callerIdentity = callerIdentity,
payload = payload,
responseTimeout = responseTimeout,
),
)
if (response.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) {
responseError = RpcError.BuiltinRpcError.RESPONSE_PAYLOAD_TOO_LARGE.create()
LKLog.w { "RPC Response payload too large for $method" }
} else {
responsePayload = response
}
} catch (e: Exception) {
if (e is RpcError) {
responseError = e
} else {
LKLog.w(e) { "Uncaught error returned by RPC handler for $method. Returning APPLICATION_ERROR instead." }
responseError = RpcError.BuiltinRpcError.APPLICATION_ERROR.create()
}
}
publishRpcResponse(
destinationIdentity = callerIdentity,
requestId = requestId,
payload = responsePayload,
error = responseError,
)
}
internal fun handleParticipantDisconnect(identity: Identity) {
synchronized(pendingAcks) {
val acksIterator = pendingAcks.iterator()
while (acksIterator.hasNext()) {
val (_, ack) = acksIterator.next()
if (ack.participantIdentity == identity) {
acksIterator.remove()
}
}
}
synchronized(pendingResponses) {
val responsesIterator = pendingResponses.iterator()
while (responsesIterator.hasNext()) {
val (_, response) = responsesIterator.next()
if (response.participantIdentity == identity) {
responsesIterator.remove()
response.onResolve(null, RpcError.BuiltinRpcError.RECIPIENT_DISCONNECTED.create())
}
}
}
}
/**
* @suppress
*/
@VisibleForTesting
... ... @@ -1232,3 +1617,42 @@ internal fun VideoTrackPublishOptions.hasBackupCodec(): Boolean {
private val backupCodecs = listOf(VideoCodec.VP8.codecName, VideoCodec.H264.codecName)
private fun isBackupCodec(codecName: String) = backupCodecs.contains(codecName)
/**
* A handler that processes an RPC request and returns a string
* that will be sent back to the requester.
*
* Throwing an [RpcError] will send the error back to the requester.
*
* @see [LocalParticipant.registerRpcMethod]
*/
typealias RpcHandler = suspend (RpcInvocationData) -> String
data class RpcInvocationData(
/**
* A unique identifier for this RPC request
*/
val requestId: String,
/**
* The identity of the RemoteParticipant who initiated the RPC call
*/
val callerIdentity: Participant.Identity,
/**
* The data sent by the caller (as a string)
*/
val payload: String,
/**
* The maximum time available to return a response
*/
val responseTimeout: Duration,
)
private data class PendingRpcAck(
val onResolve: () -> Unit,
val participantIdentity: Participant.Identity,
)
private data class PendingRpcResponse(
val onResolve: (payload: String?, error: RpcError?) -> Unit,
val participantIdentity: Participant.Identity,
)
... ...
/*
* Copyright 2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.rpc
import io.livekit.android.room.RTCEngine
import io.livekit.android.util.truncateBytes
import livekit.LivekitModels
/**
* Specialized error handling for RPC methods.
*
* Instances of this type, when thrown in a RPC method handler, will have their [message]
* serialized and sent across the wire. The sender will receive an equivalent error on the other side.
*
* Built-in types are included but developers may use any message string, with a max length of 256 bytes.
*/
data class RpcError(
/**
* The error code of the RPC call. Error codes 1001-1999 are reserved for built-in errors.
*
* See [RpcError.BuiltinRpcError] for built-in error information.
*/
val code: Int,
/**
* A message to include. Strings over 256 bytes will be truncated.
*/
override val message: String,
/**
* An optional data payload. Must be smaller than 15KB in size, or else will be truncated.
*/
val data: String = "",
) : Exception(message) {
enum class BuiltinRpcError(val code: Int, val message: String) {
APPLICATION_ERROR(1500, "Application error in method handler"),
CONNECTION_TIMEOUT(1501, "Connection timeout"),
RESPONSE_TIMEOUT(1502, "Response timeout"),
RECIPIENT_DISCONNECTED(1503, "Recipient disconnected"),
RESPONSE_PAYLOAD_TOO_LARGE(1504, "Response payload too large"),
SEND_FAILED(1505, "Failed to send"),
UNSUPPORTED_METHOD(1400, "Method not supported at destination"),
RECIPIENT_NOT_FOUND(1401, "Recipient not found"),
REQUEST_PAYLOAD_TOO_LARGE(1402, "Request payload too large"),
UNSUPPORTED_SERVER(1403, "RPC not supported by server"),
UNSUPPORTED_VERSION(1404, "Unsupported RPC version"),
;
fun create(data: String = ""): RpcError {
return RpcError(code, message, data)
}
}
companion object {
const val MAX_MESSAGE_BYTES = 256
fun fromProto(proto: LivekitModels.RpcError): RpcError {
return RpcError(
code = proto.code,
message = (proto.message ?: "").truncateBytes(MAX_MESSAGE_BYTES),
data = proto.data.truncateBytes(RTCEngine.MAX_DATA_PACKET_SIZE),
)
}
}
fun toProto(): LivekitModels.RpcError {
return with(LivekitModels.RpcError.newBuilder()) {
this.code = this@RpcError.code
this.message = this@RpcError.message
this.data = this@RpcError.data
build()
}
}
}
... ...
/*
* Copyright 2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.util
import okio.ByteString.Companion.encode
internal fun String?.byteLength(): Int {
if (this == null) {
return 0
}
return this.encode(Charsets.UTF_8).size
}
internal fun String.truncateBytes(maxBytes: Int): String {
if (this.byteLength() <= maxBytes) {
return this
}
var low = 0
var high = length
// Binary search for string that fits.
while (low < high) {
val mid = (low + high + 1) / 2
if (this.substring(0, mid).byteLength() <= maxBytes) {
low = mid
} else {
high = mid - 1
}
}
return substring(0, low)
}
... ...
/*
* Copyright 2023-2024 LiveKit, Inc.
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -69,7 +69,7 @@ abstract class MockE2ETest : BaseTest() {
room.release()
}
suspend fun connect(joinResponse: LivekitRtc.SignalResponse = TestData.JOIN) {
open suspend fun connect(joinResponse: LivekitRtc.SignalResponse = TestData.JOIN) {
connectSignal(joinResponse)
connectPeerConnection()
}
... ...
/*
* Copyright 2023-2024 LiveKit, Inc.
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -21,7 +21,7 @@ import livekit.org.webrtc.DataChannel
class MockDataChannel(private val label: String?) : DataChannel(1L) {
var observer: Observer? = null
var sentBuffers = mutableListOf<Buffer?>()
var sentBuffers = mutableListOf<Buffer>()
override fun registerObserver(observer: Observer?) {
this.observer = observer
}
... ... @@ -46,7 +46,7 @@ class MockDataChannel(private val label: String?) : DataChannel(1L) {
return 0
}
override fun send(buffer: Buffer?): Boolean {
override fun send(buffer: Buffer): Boolean {
sentBuffers.add(buffer)
return true
}
... ... @@ -56,4 +56,8 @@ class MockDataChannel(private val label: String?) : DataChannel(1L) {
override fun dispose() {
}
fun simulateBufferReceived(buffer: Buffer) {
observer?.onMessage(buffer)
}
}
... ...
/*
* Copyright 2023-2024 LiveKit, Inc.
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -18,6 +18,7 @@ package io.livekit.android.test.mock
import livekit.LivekitModels
import livekit.LivekitRtc
import java.util.UUID
object TestData {
... ... @@ -110,7 +111,7 @@ object TestData {
build()
},
)
serverVersion = "0.15.2"
serverVersion = "1.8.0"
build()
}
build()
... ... @@ -327,4 +328,16 @@ object TestData {
}
build()
}
val DATA_PACKET_RPC_REQUEST = with(LivekitModels.DataPacket.newBuilder()) {
participantIdentity = REMOTE_PARTICIPANT.identity
rpcRequest = with(LivekitModels.RpcRequest.newBuilder()) {
id = UUID.randomUUID().toString()
method = "hello"
payload = "hello world"
responseTimeoutMs = 10000
version = 1
build()
}
build()
}
}
... ...
/*
* Copyright 2023-2025 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.room.participant
import com.google.protobuf.ByteString
import io.livekit.android.room.RTCEngine
import io.livekit.android.rpc.RpcError
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.mock.MockDataChannel
import io.livekit.android.test.mock.MockPeerConnection
import io.livekit.android.test.mock.TestData
import io.livekit.android.test.mock.TestData.REMOTE_PARTICIPANT
import io.livekit.android.test.util.toDataChannelBuffer
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.launch
import livekit.LivekitModels
import livekit.LivekitRtc
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import kotlin.time.Duration.Companion.milliseconds
@ExperimentalCoroutinesApi
@RunWith(RobolectricTestRunner::class)
class RpcMockE2ETest : MockE2ETest() {
lateinit var pubDataChannel: MockDataChannel
lateinit var subDataChannel: MockDataChannel
companion object {
val ERROR = RpcError(
1,
"This is an error message.",
"This is an error payload.",
)
}
override suspend fun connect(joinResponse: LivekitRtc.SignalResponse) {
super.connect(joinResponse)
val pubPeerConnection = component.rtcEngine().getPublisherPeerConnection() as MockPeerConnection
pubDataChannel = pubPeerConnection.dataChannels[RTCEngine.RELIABLE_DATA_CHANNEL_LABEL] as MockDataChannel
val subPeerConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection
subDataChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL)
subPeerConnection.observer?.onDataChannel(subDataChannel)
}
private fun createAck(requestId: String) =
with(LivekitModels.DataPacket.newBuilder()) {
participantIdentity = REMOTE_PARTICIPANT.identity
rpcAck = with(LivekitModels.RpcAck.newBuilder()) {
this.requestId = requestId
build()
}
build()
}.toDataChannelBuffer()
private fun createResponse(requestId: String, payload: String? = null, error: RpcError? = null) = with(LivekitModels.DataPacket.newBuilder()) {
participantIdentity = REMOTE_PARTICIPANT.identity
rpcResponse = with(LivekitModels.RpcResponse.newBuilder()) {
this.requestId = requestId
if (error != null) {
this.error = error.toProto()
} else if (payload != null) {
this.payload = payload
}
build()
}
build()
}.toDataChannelBuffer()
@Test
fun handleRpcRequest() = runTest {
connect()
var methodCalled = false
room.localParticipant.registerRpcMethod("hello") {
methodCalled = true
"bye"
}
subDataChannel.simulateBufferReceived(TestData.DATA_PACKET_RPC_REQUEST.toDataChannelBuffer())
assertTrue(methodCalled)
coroutineRule.dispatcher.scheduler.advanceUntilIdle()
// Check that ack and response were sent
val buffers = pubDataChannel.sentBuffers
assertEquals(2, buffers.size)
val ackBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
val responseBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[1].data))
assertTrue(ackBuffer.hasRpcAck())
assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, ackBuffer.rpcAck.requestId)
assertTrue(responseBuffer.hasRpcResponse())
assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, responseBuffer.rpcResponse.requestId)
assertEquals("bye", responseBuffer.rpcResponse.payload)
}
@Test
fun handleRpcRequestWithError() = runTest {
connect()
var methodCalled = false
room.localParticipant.registerRpcMethod("hello") {
methodCalled = true
throw ERROR
}
subDataChannel.simulateBufferReceived(TestData.DATA_PACKET_RPC_REQUEST.toDataChannelBuffer())
assertTrue(methodCalled)
coroutineRule.dispatcher.scheduler.advanceUntilIdle()
// Check that ack and response were sent
val buffers = pubDataChannel.sentBuffers
assertEquals(2, buffers.size)
val ackBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
val responseBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[1].data))
assertTrue(ackBuffer.hasRpcAck())
assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, ackBuffer.rpcAck.requestId)
assertTrue(responseBuffer.hasRpcResponse())
assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, responseBuffer.rpcResponse.requestId)
assertEquals(ERROR, RpcError.fromProto(responseBuffer.rpcResponse.error))
}
@Test
fun handleRpcRequestWithNoHandler() = runTest {
connect()
subDataChannel.simulateBufferReceived(TestData.DATA_PACKET_RPC_REQUEST.toDataChannelBuffer())
coroutineRule.dispatcher.scheduler.advanceUntilIdle()
// Check that ack and response were sent
val buffers = pubDataChannel.sentBuffers
assertEquals(2, buffers.size)
val ackBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
val responseBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[1].data))
assertTrue(ackBuffer.hasRpcAck())
assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, ackBuffer.rpcAck.requestId)
assertTrue(responseBuffer.hasRpcResponse())
assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, responseBuffer.rpcResponse.requestId)
assertEquals(RpcError.BuiltinRpcError.UNSUPPORTED_METHOD.create(), RpcError.fromProto(responseBuffer.rpcResponse.error))
}
@Test
fun performRpc() = runTest {
connect()
val rpcJob = async {
room.localParticipant.performRpc(
destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity),
method = "hello",
payload = "hello world",
)
}
// Check that request was sent
val buffers = pubDataChannel.sentBuffers
assertEquals(1, buffers.size)
val requestBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
assertTrue(requestBuffer.hasRpcRequest())
assertEquals("hello", requestBuffer.rpcRequest.method)
assertEquals("hello world", requestBuffer.rpcRequest.payload)
val requestId = requestBuffer.rpcRequest.id
// receive ack and response
subDataChannel.simulateBufferReceived(createAck(requestId))
subDataChannel.simulateBufferReceived(createResponse(requestId, payload = "bye"))
coroutineRule.dispatcher.scheduler.advanceUntilIdle()
val response = rpcJob.await()
assertEquals("bye", response)
}
@Test
fun performRpcWithError() = runTest {
connect()
val rpcJob = async {
var expectedError: Exception? = null
try {
room.localParticipant.performRpc(
destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity),
method = "hello",
payload = "hello world",
)
} catch (e: Exception) {
expectedError = e
}
return@async expectedError
}
val buffers = pubDataChannel.sentBuffers
val requestBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
val requestId = requestBuffer.rpcRequest.id
// receive ack and response
subDataChannel.simulateBufferReceived(createAck(requestId))
subDataChannel.simulateBufferReceived(createResponse(requestId, error = ERROR))
coroutineRule.dispatcher.scheduler.advanceUntilIdle()
val receivedError = rpcJob.await()
assertEquals(ERROR, receivedError)
}
@Test
fun performRpcWithParticipantDisconnected() = runTest {
connect()
simulateMessageFromServer(TestData.PARTICIPANT_JOIN)
val rpcJob = async {
var expectedError: Exception? = null
try {
room.localParticipant.performRpc(
destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity),
method = "hello",
payload = "hello world",
)
} catch (e: Exception) {
expectedError = e
}
return@async expectedError
}
simulateMessageFromServer(TestData.PARTICIPANT_DISCONNECT)
coroutineRule.dispatcher.scheduler.advanceUntilIdle()
val error = rpcJob.await()
assertEquals(RpcError.BuiltinRpcError.RECIPIENT_DISCONNECTED.create(), error)
}
@Test
fun performRpcWithConnectionTimeoutError() = runTest {
connect()
val rpcJob = async {
var expectedError: Exception? = null
try {
room.localParticipant.performRpc(
destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity),
method = "hello",
payload = "hello world",
)
} catch (e: Exception) {
expectedError = e
}
return@async expectedError
}
coroutineRule.dispatcher.scheduler.advanceTimeBy(3000)
val error = rpcJob.await()
assertEquals(RpcError.BuiltinRpcError.CONNECTION_TIMEOUT.create(), error)
}
@Test
fun performRpcWithResponseTimeoutError() = runTest {
connect()
val rpcJob = async {
var expectedError: Exception? = null
try {
room.localParticipant.performRpc(
destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity),
method = "hello",
payload = "hello world",
)
} catch (e: Exception) {
expectedError = e
}
return@async expectedError
}
val buffers = pubDataChannel.sentBuffers
val requestBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
val requestId = requestBuffer.rpcRequest.id
// receive ack only
subDataChannel.simulateBufferReceived(createAck(requestId))
coroutineRule.dispatcher.scheduler.advanceTimeBy(15000)
val error = rpcJob.await()
assertEquals(RpcError.BuiltinRpcError.RESPONSE_TIMEOUT.create(), error)
}
@Test
fun uintMaxValueVerification() = runTest {
assertEquals(4_294_967_295L, UInt.MAX_VALUE.toLong())
}
/**
* Protobuf handles UInt32 as Java signed integers.
* This test verifies whether our conversion is properly sent over the wire.
*/
@Test
fun performRpcProtoUIntVerification() = runTest {
connect()
val rpcJob = launch {
room.localParticipant.performRpc(
destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity),
method = "hello",
payload = "hello world",
responseTimeout = UInt.MAX_VALUE.toLong().milliseconds,
)
}
val buffers = pubDataChannel.sentBuffers
val requestBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
val expectedResponseTimeout = UInt.MAX_VALUE - 2000u // 2000 comes from maxRoundTripLatency
val responseTimeout = requestBuffer.rpcRequest.responseTimeoutMs.toUInt()
assertEquals(expectedResponseTimeout, responseTimeout)
rpcJob.cancel()
}
/**
* Protobuf handles UInt32 as Java signed integers.
* This test verifies whether our conversion is properly sent over the wire.
*/
@Test
fun handleRpcProtoUIntVerification() = runTest {
connect()
var methodCalled = false
room.localParticipant.registerRpcMethod("hello") { invocationData ->
assertEquals(4_294_967_295L, invocationData.responseTimeout.inWholeMilliseconds)
methodCalled = true
"bye"
}
subDataChannel.simulateBufferReceived(
with(TestData.DATA_PACKET_RPC_REQUEST.toBuilder()) {
rpcRequest = with(rpcRequest.toBuilder()) {
responseTimeoutMs = UInt.MAX_VALUE.toInt()
build()
}
build()
}.toDataChannelBuffer(),
)
assertTrue(methodCalled)
}
}
... ...
Subproject commit a601adc5e9027820857a6d445b32a868b19d4184
Subproject commit 9e8d1e37c5eb4434424bc16c657c83e7dc63bc2a
... ...