davidliu
Committed by GitHub

Implement RPC (#578)

* Update protocol

* fix build errors

* Implement RPC

* tests

* spotless

* comment fixes
  1 +---
  2 +"client-sdk-android": minor
  3 +---
  4 +
  5 +Implement RPC
1 /* 1 /*
2 - * Copyright 2023-2024 LiveKit, Inc. 2 + * Copyright 2023-2025 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -273,6 +273,9 @@ enum class DisconnectReason { @@ -273,6 +273,9 @@ enum class DisconnectReason {
273 MIGRATION, 273 MIGRATION,
274 SIGNAL_CLOSE, 274 SIGNAL_CLOSE,
275 ROOM_CLOSED, 275 ROOM_CLOSED,
  276 + USER_UNAVAILABLE,
  277 + USER_REJECTED,
  278 + SIP_TRUNK_FAILURE,
276 } 279 }
277 280
278 /** 281 /**
@@ -290,6 +293,9 @@ fun LivekitModels.DisconnectReason?.convert(): DisconnectReason { @@ -290,6 +293,9 @@ fun LivekitModels.DisconnectReason?.convert(): DisconnectReason {
290 LivekitModels.DisconnectReason.MIGRATION -> DisconnectReason.MIGRATION 293 LivekitModels.DisconnectReason.MIGRATION -> DisconnectReason.MIGRATION
291 LivekitModels.DisconnectReason.SIGNAL_CLOSE -> DisconnectReason.SIGNAL_CLOSE 294 LivekitModels.DisconnectReason.SIGNAL_CLOSE -> DisconnectReason.SIGNAL_CLOSE
292 LivekitModels.DisconnectReason.ROOM_CLOSED -> DisconnectReason.ROOM_CLOSED 295 LivekitModels.DisconnectReason.ROOM_CLOSED -> DisconnectReason.ROOM_CLOSED
  296 + LivekitModels.DisconnectReason.USER_UNAVAILABLE -> DisconnectReason.USER_UNAVAILABLE
  297 + LivekitModels.DisconnectReason.USER_REJECTED -> DisconnectReason.USER_REJECTED
  298 + LivekitModels.DisconnectReason.SIP_TRUNK_FAILURE -> DisconnectReason.SIP_TRUNK_FAILURE
293 LivekitModels.DisconnectReason.UNKNOWN_REASON, 299 LivekitModels.DisconnectReason.UNKNOWN_REASON,
294 LivekitModels.DisconnectReason.UNRECOGNIZED, 300 LivekitModels.DisconnectReason.UNRECOGNIZED,
295 null, 301 null,
1 /* 1 /*
2 - * Copyright 2023-2024 LiveKit, Inc. 2 + * Copyright 2023-2025 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ package io.livekit.android.room @@ -19,6 +19,7 @@ package io.livekit.android.room
19 import android.os.SystemClock 19 import android.os.SystemClock
20 import androidx.annotation.VisibleForTesting 20 import androidx.annotation.VisibleForTesting
21 import com.google.protobuf.ByteString 21 import com.google.protobuf.ByteString
  22 +import com.vdurmont.semver4j.Semver
22 import io.livekit.android.ConnectOptions 23 import io.livekit.android.ConnectOptions
23 import io.livekit.android.RoomOptions 24 import io.livekit.android.RoomOptions
24 import io.livekit.android.dagger.InjectionNames 25 import io.livekit.android.dagger.InjectionNames
@@ -148,6 +149,9 @@ internal constructor( @@ -148,6 +149,9 @@ internal constructor(
148 private var lastRoomOptions: RoomOptions? = null 149 private var lastRoomOptions: RoomOptions? = null
149 private var participantSid: String? = null 150 private var participantSid: String? = null
150 151
  152 + internal val serverVersion: Semver?
  153 + get() = client.serverVersion
  154 +
151 private val publisherObserver = PublisherTransportObserver(this, client) 155 private val publisherObserver = PublisherTransportObserver(this, client)
152 private val subscriberObserver = SubscriberTransportObserver(this, client) 156 private val subscriberObserver = SubscriberTransportObserver(this, client)
153 157
@@ -777,6 +781,7 @@ internal constructor( @@ -777,6 +781,7 @@ internal constructor(
777 fun onLocalTrackUnpublished(trackUnpublished: LivekitRtc.TrackUnpublishedResponse) 781 fun onLocalTrackUnpublished(trackUnpublished: LivekitRtc.TrackUnpublishedResponse)
778 fun onTranscriptionReceived(transcription: LivekitModels.Transcription) 782 fun onTranscriptionReceived(transcription: LivekitModels.Transcription)
779 fun onLocalTrackSubscribed(trackSubscribed: LivekitRtc.TrackSubscribed) 783 fun onLocalTrackSubscribed(trackSubscribed: LivekitRtc.TrackSubscribed)
  784 + fun onRpcPacketReceived(dp: LivekitModels.DataPacket)
780 } 785 }
781 786
782 companion object { 787 companion object {
@@ -792,7 +797,7 @@ internal constructor( @@ -792,7 +797,7 @@ internal constructor(
792 */ 797 */
793 @VisibleForTesting 798 @VisibleForTesting
794 const val LOSSY_DATA_CHANNEL_LABEL = "_lossy" 799 const val LOSSY_DATA_CHANNEL_LABEL = "_lossy"
795 - internal const val MAX_DATA_PACKET_SIZE = 15000 800 + internal const val MAX_DATA_PACKET_SIZE = 15360 // 15 KB
796 private const val MAX_RECONNECT_RETRIES = 10 801 private const val MAX_RECONNECT_RETRIES = 10
797 private const val MAX_RECONNECT_TIMEOUT = 60 * 1000 802 private const val MAX_RECONNECT_TIMEOUT = 60 * 1000
798 private const val MAX_ICE_CONNECT_TIMEOUT_MS = 20000 803 private const val MAX_ICE_CONNECT_TIMEOUT_MS = 20000
@@ -1040,13 +1045,21 @@ internal constructor( @@ -1040,13 +1045,21 @@ internal constructor(
1040 LivekitModels.DataPacket.ValueCase.RPC_ACK, 1045 LivekitModels.DataPacket.ValueCase.RPC_ACK,
1041 LivekitModels.DataPacket.ValueCase.RPC_RESPONSE, 1046 LivekitModels.DataPacket.ValueCase.RPC_RESPONSE,
1042 -> { 1047 -> {
1043 - // TODO 1048 + listener?.onRpcPacketReceived(dp)
1044 } 1049 }
1045 LivekitModels.DataPacket.ValueCase.VALUE_NOT_SET, 1050 LivekitModels.DataPacket.ValueCase.VALUE_NOT_SET,
1046 null, 1051 null,
1047 -> { 1052 -> {
1048 LKLog.v { "invalid value for data packet" } 1053 LKLog.v { "invalid value for data packet" }
1049 } 1054 }
  1055 +
  1056 + LivekitModels.DataPacket.ValueCase.STREAM_HEADER -> {
  1057 + // TODO
  1058 + }
  1059 +
  1060 + LivekitModels.DataPacket.ValueCase.STREAM_CHUNK -> {
  1061 + // TODO
  1062 + }
1050 } 1063 }
1051 } 1064 }
1052 1065
1 /* 1 /*
2 - * Copyright 2023-2024 LiveKit, Inc. 2 + * Copyright 2023-2025 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -648,6 +648,8 @@ constructor( @@ -648,6 +648,8 @@ constructor(
648 648
649 mutableRemoteParticipants = newParticipants 649 mutableRemoteParticipants = newParticipants
650 eventBus.postEvent(RoomEvent.ParticipantDisconnected(this, removedParticipant), coroutineScope) 650 eventBus.postEvent(RoomEvent.ParticipantDisconnected(this, removedParticipant), coroutineScope)
  651 +
  652 + localParticipant.handleParticipantDisconnect(identity)
651 } 653 }
652 654
653 fun getParticipantBySid(sid: String): Participant? { 655 fun getParticipantBySid(sid: String): Participant? {
@@ -1195,6 +1197,10 @@ constructor( @@ -1195,6 +1197,10 @@ constructor(
1195 publication?.onTranscriptionReceived(event) 1197 publication?.onTranscriptionReceived(event)
1196 } 1198 }
1197 1199
  1200 + override fun onRpcPacketReceived(dp: LivekitModels.DataPacket) {
  1201 + localParticipant.handleDataPacket(dp)
  1202 + }
  1203 +
1198 /** 1204 /**
1199 * @suppress 1205 * @suppress
1200 */ 1206 */
1 /* 1 /*
2 - * Copyright 2023-2024 LiveKit, Inc. 2 + * Copyright 2023-2025 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -84,7 +84,7 @@ constructor( @@ -84,7 +84,7 @@ constructor(
84 private var currentWs: WebSocket? = null 84 private var currentWs: WebSocket? = null
85 private var isReconnecting: Boolean = false 85 private var isReconnecting: Boolean = false
86 var listener: Listener? = null 86 var listener: Listener? = null
87 - private var serverVersion: Semver? = null 87 + internal var serverVersion: Semver? = null
88 private var lastUrl: String? = null 88 private var lastUrl: String? = null
89 private var lastOptions: ConnectOptions? = null 89 private var lastOptions: ConnectOptions? = null
90 private var lastRoomOptions: RoomOptions? = null 90 private var lastRoomOptions: RoomOptions? = null
@@ -841,6 +841,7 @@ constructor( @@ -841,6 +841,7 @@ constructor(
841 lastUrl = null 841 lastUrl = null
842 lastOptions = null 842 lastOptions = null
843 lastRoomOptions = null 843 lastRoomOptions = null
  844 + serverVersion = null
844 } 845 }
845 846
846 interface Listener { 847 interface Listener {
@@ -21,6 +21,7 @@ import android.content.Context @@ -21,6 +21,7 @@ import android.content.Context
21 import android.content.Intent 21 import android.content.Intent
22 import androidx.annotation.VisibleForTesting 22 import androidx.annotation.VisibleForTesting
23 import com.google.protobuf.ByteString 23 import com.google.protobuf.ByteString
  24 +import com.vdurmont.semver4j.Semver
24 import dagger.assisted.Assisted 25 import dagger.assisted.Assisted
25 import dagger.assisted.AssistedFactory 26 import dagger.assisted.AssistedFactory
26 import dagger.assisted.AssistedInject 27 import dagger.assisted.AssistedInject
@@ -47,15 +48,21 @@ import io.livekit.android.room.track.VideoCaptureParameter @@ -47,15 +48,21 @@ import io.livekit.android.room.track.VideoCaptureParameter
47 import io.livekit.android.room.track.VideoCodec 48 import io.livekit.android.room.track.VideoCodec
48 import io.livekit.android.room.track.VideoEncoding 49 import io.livekit.android.room.track.VideoEncoding
49 import io.livekit.android.room.util.EncodingUtils 50 import io.livekit.android.room.util.EncodingUtils
  51 +import io.livekit.android.rpc.RpcError
50 import io.livekit.android.util.LKLog 52 import io.livekit.android.util.LKLog
  53 +import io.livekit.android.util.byteLength
51 import io.livekit.android.util.flow 54 import io.livekit.android.util.flow
52 import io.livekit.android.webrtc.sortVideoCodecPreferences 55 import io.livekit.android.webrtc.sortVideoCodecPreferences
53 import kotlinx.coroutines.CoroutineDispatcher 56 import kotlinx.coroutines.CoroutineDispatcher
54 import kotlinx.coroutines.Job 57 import kotlinx.coroutines.Job
  58 +import kotlinx.coroutines.coroutineScope
  59 +import kotlinx.coroutines.delay
55 import kotlinx.coroutines.launch 60 import kotlinx.coroutines.launch
  61 +import kotlinx.coroutines.suspendCancellableCoroutine
56 import kotlinx.coroutines.sync.Mutex 62 import kotlinx.coroutines.sync.Mutex
57 import kotlinx.coroutines.sync.withLock 63 import kotlinx.coroutines.sync.withLock
58 import livekit.LivekitModels 64 import livekit.LivekitModels
  65 +import livekit.LivekitModels.DataPacket
59 import livekit.LivekitRtc 66 import livekit.LivekitRtc
60 import livekit.LivekitRtc.AddTrackRequest 67 import livekit.LivekitRtc.AddTrackRequest
61 import livekit.LivekitRtc.SimulcastCodec 68 import livekit.LivekitRtc.SimulcastCodec
@@ -67,9 +74,15 @@ import livekit.org.webrtc.RtpTransceiver.RtpTransceiverInit @@ -67,9 +74,15 @@ import livekit.org.webrtc.RtpTransceiver.RtpTransceiverInit
67 import livekit.org.webrtc.SurfaceTextureHelper 74 import livekit.org.webrtc.SurfaceTextureHelper
68 import livekit.org.webrtc.VideoCapturer 75 import livekit.org.webrtc.VideoCapturer
69 import livekit.org.webrtc.VideoProcessor 76 import livekit.org.webrtc.VideoProcessor
  77 +import java.util.Collections
  78 +import java.util.UUID
70 import javax.inject.Named 79 import javax.inject.Named
  80 +import kotlin.coroutines.resume
71 import kotlin.math.max 81 import kotlin.math.max
72 import kotlin.math.min 82 import kotlin.math.min
  83 +import kotlin.time.Duration
  84 +import kotlin.time.Duration.Companion.milliseconds
  85 +import kotlin.time.Duration.Companion.seconds
73 86
74 class LocalParticipant 87 class LocalParticipant
75 @AssistedInject 88 @AssistedInject
@@ -105,6 +118,10 @@ internal constructor( @@ -105,6 +118,10 @@ internal constructor(
105 118
106 private val jobs = mutableMapOf<Any, Job>() 119 private val jobs = mutableMapOf<Any, Job>()
107 120
  121 + private val rpcHandlers = Collections.synchronizedMap(mutableMapOf<String, RpcHandler>()) // methodName to handler
  122 + private val pendingAcks = Collections.synchronizedMap(mutableMapOf<String, PendingRpcAck>()) // requestId to pending ack
  123 + private val pendingResponses = Collections.synchronizedMap(mutableMapOf<String, PendingRpcResponse>()) // requestId to pending response
  124 +
108 // For ensuring that only one caller can execute setTrackEnabled at a time. 125 // For ensuring that only one caller can execute setTrackEnabled at a time.
109 // Without it, there's a potential to create multiple of the same source, 126 // Without it, there's a potential to create multiple of the same source,
110 // Camera has deadlock issues with multiple CameraCapturers trying to activate/stop. 127 // Camera has deadlock issues with multiple CameraCapturers trying to activate/stop.
@@ -714,8 +731,8 @@ internal constructor( @@ -714,8 +731,8 @@ internal constructor(
714 } 731 }
715 732
716 val kind = when (reliability) { 733 val kind = when (reliability) {
717 - DataPublishReliability.RELIABLE -> LivekitModels.DataPacket.Kind.RELIABLE  
718 - DataPublishReliability.LOSSY -> LivekitModels.DataPacket.Kind.LOSSY 734 + DataPublishReliability.RELIABLE -> DataPacket.Kind.RELIABLE
  735 + DataPublishReliability.LOSSY -> DataPacket.Kind.LOSSY
719 } 736 }
720 val packetBuilder = LivekitModels.UserPacket.newBuilder().apply { 737 val packetBuilder = LivekitModels.UserPacket.newBuilder().apply {
721 payload = ByteString.copyFrom(data) 738 payload = ByteString.copyFrom(data)
@@ -727,7 +744,7 @@ internal constructor( @@ -727,7 +744,7 @@ internal constructor(
727 addAllDestinationIdentities(identities.map { it.value }) 744 addAllDestinationIdentities(identities.map { it.value })
728 } 745 }
729 } 746 }
730 - val dataPacket = LivekitModels.DataPacket.newBuilder() 747 + val dataPacket = DataPacket.newBuilder()
731 .setUser(packetBuilder) 748 .setUser(packetBuilder)
732 .setKind(kind) 749 .setKind(kind)
733 .build() 750 .build()
@@ -741,9 +758,8 @@ internal constructor( @@ -741,9 +758,8 @@ internal constructor(
741 * SipDTMF message using the provided code and digit, then encapsulates it 758 * SipDTMF message using the provided code and digit, then encapsulates it
742 * in a DataPacket before sending it via the engine. 759 * in a DataPacket before sending it via the engine.
743 * 760 *
744 - * Parameters:  
745 - * - code: an integer representing the DTMF signal code  
746 - * - digit: the string representing the DTMF digit (e.g., "1", "#", "*") 761 + * @param code an integer representing the DTMF signal code
  762 + * @param digit the string representing the DTMF digit (e.g., "1", "#", "*")
747 */ 763 */
748 764
749 @Suppress("unused") 765 @Suppress("unused")
@@ -764,6 +780,375 @@ internal constructor( @@ -764,6 +780,375 @@ internal constructor(
764 } 780 }
765 781
766 /** 782 /**
  783 + * Establishes the participant as a receiver for calls of the specified RPC method.
  784 + * Will overwrite any existing callback for the same method.
  785 + *
  786 + * Example:
  787 + * ```kt
  788 + * room.localParticipant.registerRpcMethod("greet") { (requestId, callerIdentity, payload, responseTimeout) ->
  789 + * Log.i("TAG", "Received greeting from ${callerIdentity}: ${payload}")
  790 + *
  791 + * // Return a string
  792 + * "Hello, ${callerIdentity}!"
  793 + * }
  794 + * ```
  795 + *
  796 + * The handler receives an [RpcInvocationData] with the following parameters:
  797 + * - `requestId`: A unique identifier for this RPC request
  798 + * - `callerIdentity`: The identity of the RemoteParticipant who initiated the RPC call
  799 + * - `payload`: The data sent by the caller (as a string)
  800 + * - `responseTimeout`: The maximum time available to return a response
  801 + *
  802 + * The handler should return a string.
  803 + * If unable to respond within [RpcInvocationData.responseTimeout], the request will result in an error on the caller's side.
  804 + *
  805 + * You may throw errors of type [RpcError] with a string `message` in the handler,
  806 + * and they will be received on the caller's side with the message intact.
  807 + * Other errors thrown in your handler will not be transmitted as-is, and will instead arrive to the caller as `1500` ("Application Error").
  808 + *
  809 + * @param method The name of the indicated RPC method
  810 + * @param handler Will be invoked when an RPC request for this method is received
  811 + * @see RpcHandler
  812 + * @see RpcInvocationData
  813 + * @see performRpc
  814 + */
  815 + @Suppress("RedundantSuspendModifier")
  816 + suspend fun registerRpcMethod(
  817 + method: String,
  818 + handler: RpcHandler,
  819 + ) {
  820 + this.rpcHandlers[method] = handler
  821 + }
  822 +
  823 + /**
  824 + * Unregisters a previously registered RPC method.
  825 + *
  826 + * @param method The name of the RPC method to unregister
  827 + */
  828 + fun unregisterRpcMethod(
  829 + method: String,
  830 + ) {
  831 + this.rpcHandlers.remove(method)
  832 + }
  833 +
  834 + internal fun handleDataPacket(packet: DataPacket) {
  835 + when {
  836 + packet.hasRpcRequest() -> {
  837 + val rpcRequest = packet.rpcRequest
  838 + scope.launch {
  839 + handleIncomingRpcRequest(
  840 + callerIdentity = Identity(packet.participantIdentity),
  841 + requestId = rpcRequest.id,
  842 + method = rpcRequest.method,
  843 + payload = rpcRequest.payload,
  844 + responseTimeout = rpcRequest.responseTimeoutMs.toUInt().toLong().milliseconds,
  845 + version = rpcRequest.version,
  846 + )
  847 + }
  848 + }
  849 +
  850 + packet.hasRpcResponse() -> {
  851 + val rpcResponse = packet.rpcResponse
  852 + var payload: String? = null
  853 + var error: RpcError? = null
  854 +
  855 + if (rpcResponse.hasPayload()) {
  856 + payload = rpcResponse.payload
  857 + } else if (rpcResponse.hasError()) {
  858 + error = RpcError.fromProto(rpcResponse.error)
  859 + }
  860 + handleIncomingRpcResponse(
  861 + requestId = rpcResponse.requestId,
  862 + payload = payload,
  863 + error = error,
  864 + )
  865 + }
  866 +
  867 + packet.hasRpcAck() -> {
  868 + val rpcAck = packet.rpcAck
  869 + handleIncomingRpcAck(rpcAck.requestId)
  870 + }
  871 + }
  872 + }
  873 +
  874 + /**
  875 + * Initiate an RPC call to a remote participant
  876 + * @param destinationIdentity The identity of the destination participant.
  877 + * @param method The method name to call.
  878 + * @param payload The payload to pass to the method.
  879 + * @param responseTimeout Timeout for receiving a response after initial connection.
  880 + * Defaults to 10000. Max value of UInt.MAX_VALUE milliseconds.
  881 + * @return The response payload.
  882 + * @throws RpcError on failure. Details in [RpcError.message].
  883 + */
  884 + suspend fun performRpc(
  885 + destinationIdentity: Identity,
  886 + method: String,
  887 + payload: String,
  888 + responseTimeout: Duration = 10.seconds,
  889 + ): String = coroutineScope {
  890 + val maxRoundTripLatency = 2.seconds
  891 +
  892 + if (payload.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) {
  893 + throw RpcError.BuiltinRpcError.REQUEST_PAYLOAD_TOO_LARGE.create()
  894 + }
  895 +
  896 + val serverVersion = engine.serverVersion
  897 + ?: throw RpcError.BuiltinRpcError.SEND_FAILED.create(data = "Not connected.")
  898 +
  899 + if (serverVersion < Semver("1.8.0")) {
  900 + throw RpcError.BuiltinRpcError.UNSUPPORTED_SERVER.create()
  901 + }
  902 +
  903 + val requestId = UUID.randomUUID().toString()
  904 +
  905 + publishRpcRequest(
  906 + destinationIdentity = destinationIdentity,
  907 + requestId = requestId,
  908 + method = method,
  909 + payload = payload,
  910 + responseTimeout = responseTimeout - maxRoundTripLatency,
  911 + )
  912 +
  913 + val responsePayload = suspendCancellableCoroutine { continuation ->
  914 + var ackTimeoutJob: Job? = null
  915 + var responseTimeoutJob: Job? = null
  916 +
  917 + fun cleanup() {
  918 + ackTimeoutJob?.cancel()
  919 + responseTimeoutJob?.cancel()
  920 + pendingAcks.remove(requestId)
  921 + pendingResponses.remove(requestId)
  922 + }
  923 +
  924 + continuation.invokeOnCancellation { cleanup() }
  925 +
  926 + ackTimeoutJob = launch {
  927 + delay(maxRoundTripLatency)
  928 + val receivedAck = pendingAcks.remove(requestId) == null
  929 + if (!receivedAck) {
  930 + pendingResponses.remove(requestId)
  931 + continuation.cancel(RpcError.BuiltinRpcError.CONNECTION_TIMEOUT.create())
  932 + }
  933 + }
  934 + pendingAcks[requestId] = PendingRpcAck(
  935 + participantIdentity = destinationIdentity,
  936 + onResolve = { ackTimeoutJob.cancel() },
  937 + )
  938 +
  939 + responseTimeoutJob = launch {
  940 + delay(responseTimeout)
  941 + val receivedResponse = pendingResponses.remove(requestId) == null
  942 + if (!receivedResponse) {
  943 + continuation.cancel(RpcError.BuiltinRpcError.RESPONSE_TIMEOUT.create())
  944 + }
  945 + }
  946 +
  947 + pendingResponses[requestId] = PendingRpcResponse(
  948 + participantIdentity = destinationIdentity,
  949 + onResolve = { payload, error ->
  950 + if (pendingAcks.containsKey(requestId)) {
  951 + LKLog.i { "RPC response received before ack, id: $requestId" }
  952 + }
  953 + cleanup()
  954 +
  955 + if (error != null) {
  956 + continuation.cancel(error)
  957 + } else {
  958 + continuation.resume(payload ?: "")
  959 + }
  960 + },
  961 + )
  962 + }
  963 + return@coroutineScope responsePayload
  964 + }
  965 +
  966 + private suspend fun publishRpcRequest(
  967 + destinationIdentity: Identity,
  968 + requestId: String,
  969 + method: String,
  970 + payload: String,
  971 + responseTimeout: Duration = 10.seconds,
  972 + ) {
  973 + if (payload.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) {
  974 + throw IllegalArgumentException("cannot publish data larger than " + RTCEngine.MAX_DATA_PACKET_SIZE)
  975 + }
  976 +
  977 + val dataPacket = with(DataPacket.newBuilder()) {
  978 + addDestinationIdentities(destinationIdentity.value)
  979 + kind = DataPacket.Kind.RELIABLE
  980 + rpcRequest = with(LivekitModels.RpcRequest.newBuilder()) {
  981 + this.id = requestId
  982 + this.method = method
  983 + this.payload = payload
  984 + this.responseTimeoutMs = responseTimeout.inWholeMilliseconds.toUInt().toInt()
  985 + build()
  986 + }
  987 + build()
  988 + }
  989 +
  990 + engine.sendData(dataPacket)
  991 + }
  992 +
  993 + private suspend fun publishRpcResponse(
  994 + destinationIdentity: Identity,
  995 + requestId: String,
  996 + payload: String?,
  997 + error: RpcError?,
  998 + ) {
  999 + if (payload.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) {
  1000 + throw IllegalArgumentException("cannot publish data larger than " + RTCEngine.MAX_DATA_PACKET_SIZE)
  1001 + }
  1002 +
  1003 + val dataPacket = with(DataPacket.newBuilder()) {
  1004 + addDestinationIdentities(destinationIdentity.value)
  1005 + kind = DataPacket.Kind.RELIABLE
  1006 + rpcResponse = with(LivekitModels.RpcResponse.newBuilder()) {
  1007 + this.requestId = requestId
  1008 + if (error != null) {
  1009 + this.error = error.toProto()
  1010 + } else {
  1011 + this.payload = payload ?: ""
  1012 + }
  1013 + build()
  1014 + }
  1015 + build()
  1016 + }
  1017 +
  1018 + engine.sendData(dataPacket)
  1019 + }
  1020 +
  1021 + private suspend fun publishRpcAck(
  1022 + destinationIdentity: Identity,
  1023 + requestId: String,
  1024 + ) {
  1025 + val dataPacket = with(DataPacket.newBuilder()) {
  1026 + addDestinationIdentities(destinationIdentity.value)
  1027 + kind = DataPacket.Kind.RELIABLE
  1028 + rpcAck = with(LivekitModels.RpcAck.newBuilder()) {
  1029 + this.requestId = requestId
  1030 + build()
  1031 + }
  1032 + build()
  1033 + }
  1034 +
  1035 + engine.sendData(dataPacket)
  1036 + }
  1037 +
  1038 + private fun handleIncomingRpcAck(requestId: String) {
  1039 + val handler = this.pendingAcks.remove(requestId)
  1040 + if (handler != null) {
  1041 + handler.onResolve()
  1042 + } else {
  1043 + LKLog.e { "Ack received for unexpected RPC request, id = $requestId" }
  1044 + }
  1045 + }
  1046 +
  1047 + private fun handleIncomingRpcResponse(
  1048 + requestId: String,
  1049 + payload: String?,
  1050 + error: RpcError?,
  1051 + ) {
  1052 + val handler = this.pendingResponses.remove(requestId)
  1053 + if (handler != null) {
  1054 + handler.onResolve(payload, error)
  1055 + } else {
  1056 + LKLog.e { "Response received for unexpected RPC request, id = $requestId" }
  1057 + }
  1058 + }
  1059 +
  1060 + private suspend fun handleIncomingRpcRequest(
  1061 + callerIdentity: Identity,
  1062 + requestId: String,
  1063 + method: String,
  1064 + payload: String,
  1065 + responseTimeout: Duration,
  1066 + version: Int,
  1067 + ) {
  1068 + publishRpcAck(callerIdentity, requestId)
  1069 +
  1070 + if (version != 1) {
  1071 + publishRpcResponse(
  1072 + destinationIdentity = callerIdentity,
  1073 + requestId = requestId,
  1074 + payload = null,
  1075 + error = RpcError.BuiltinRpcError.UNSUPPORTED_VERSION.create(),
  1076 + )
  1077 + return
  1078 + }
  1079 +
  1080 + val handler = this.rpcHandlers[method]
  1081 +
  1082 + if (handler == null) {
  1083 + publishRpcResponse(
  1084 + destinationIdentity = callerIdentity,
  1085 + requestId = requestId,
  1086 + payload = null,
  1087 + error = RpcError.BuiltinRpcError.UNSUPPORTED_METHOD.create(),
  1088 + )
  1089 + return
  1090 + }
  1091 +
  1092 + var responseError: RpcError? = null
  1093 + var responsePayload: String? = null
  1094 +
  1095 + try {
  1096 + val response = handler.invoke(
  1097 + RpcInvocationData(
  1098 + requestId = requestId,
  1099 + callerIdentity = callerIdentity,
  1100 + payload = payload,
  1101 + responseTimeout = responseTimeout,
  1102 + ),
  1103 + )
  1104 +
  1105 + if (response.byteLength() > RTCEngine.MAX_DATA_PACKET_SIZE) {
  1106 + responseError = RpcError.BuiltinRpcError.RESPONSE_PAYLOAD_TOO_LARGE.create()
  1107 + LKLog.w { "RPC Response payload too large for $method" }
  1108 + } else {
  1109 + responsePayload = response
  1110 + }
  1111 + } catch (e: Exception) {
  1112 + if (e is RpcError) {
  1113 + responseError = e
  1114 + } else {
  1115 + LKLog.w(e) { "Uncaught error returned by RPC handler for $method. Returning APPLICATION_ERROR instead." }
  1116 + responseError = RpcError.BuiltinRpcError.APPLICATION_ERROR.create()
  1117 + }
  1118 + }
  1119 +
  1120 + publishRpcResponse(
  1121 + destinationIdentity = callerIdentity,
  1122 + requestId = requestId,
  1123 + payload = responsePayload,
  1124 + error = responseError,
  1125 + )
  1126 + }
  1127 +
  1128 + internal fun handleParticipantDisconnect(identity: Identity) {
  1129 + synchronized(pendingAcks) {
  1130 + val acksIterator = pendingAcks.iterator()
  1131 + while (acksIterator.hasNext()) {
  1132 + val (_, ack) = acksIterator.next()
  1133 + if (ack.participantIdentity == identity) {
  1134 + acksIterator.remove()
  1135 + }
  1136 + }
  1137 + }
  1138 +
  1139 + synchronized(pendingResponses) {
  1140 + val responsesIterator = pendingResponses.iterator()
  1141 + while (responsesIterator.hasNext()) {
  1142 + val (_, response) = responsesIterator.next()
  1143 + if (response.participantIdentity == identity) {
  1144 + responsesIterator.remove()
  1145 + response.onResolve(null, RpcError.BuiltinRpcError.RECIPIENT_DISCONNECTED.create())
  1146 + }
  1147 + }
  1148 + }
  1149 + }
  1150 +
  1151 + /**
767 * @suppress 1152 * @suppress
768 */ 1153 */
769 @VisibleForTesting 1154 @VisibleForTesting
@@ -1232,3 +1617,42 @@ internal fun VideoTrackPublishOptions.hasBackupCodec(): Boolean { @@ -1232,3 +1617,42 @@ internal fun VideoTrackPublishOptions.hasBackupCodec(): Boolean {
1232 1617
1233 private val backupCodecs = listOf(VideoCodec.VP8.codecName, VideoCodec.H264.codecName) 1618 private val backupCodecs = listOf(VideoCodec.VP8.codecName, VideoCodec.H264.codecName)
1234 private fun isBackupCodec(codecName: String) = backupCodecs.contains(codecName) 1619 private fun isBackupCodec(codecName: String) = backupCodecs.contains(codecName)
  1620 +
  1621 +/**
  1622 + * A handler that processes an RPC request and returns a string
  1623 + * that will be sent back to the requester.
  1624 + *
  1625 + * Throwing an [RpcError] will send the error back to the requester.
  1626 + *
  1627 + * @see [LocalParticipant.registerRpcMethod]
  1628 + */
  1629 +typealias RpcHandler = suspend (RpcInvocationData) -> String
  1630 +
  1631 +data class RpcInvocationData(
  1632 + /**
  1633 + * A unique identifier for this RPC request
  1634 + */
  1635 + val requestId: String,
  1636 + /**
  1637 + * The identity of the RemoteParticipant who initiated the RPC call
  1638 + */
  1639 + val callerIdentity: Participant.Identity,
  1640 + /**
  1641 + * The data sent by the caller (as a string)
  1642 + */
  1643 + val payload: String,
  1644 + /**
  1645 + * The maximum time available to return a response
  1646 + */
  1647 + val responseTimeout: Duration,
  1648 +)
  1649 +
  1650 +private data class PendingRpcAck(
  1651 + val onResolve: () -> Unit,
  1652 + val participantIdentity: Participant.Identity,
  1653 +)
  1654 +
  1655 +private data class PendingRpcResponse(
  1656 + val onResolve: (payload: String?, error: RpcError?) -> Unit,
  1657 + val participantIdentity: Participant.Identity,
  1658 +)
  1 +/*
  2 + * Copyright 2025 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.rpc
  18 +
  19 +import io.livekit.android.room.RTCEngine
  20 +import io.livekit.android.util.truncateBytes
  21 +import livekit.LivekitModels
  22 +
  23 +/**
  24 + * Specialized error handling for RPC methods.
  25 + *
  26 + * Instances of this type, when thrown in a RPC method handler, will have their [message]
  27 + * serialized and sent across the wire. The sender will receive an equivalent error on the other side.
  28 + *
  29 + * Built-in types are included but developers may use any message string, with a max length of 256 bytes.
  30 + */
  31 +data class RpcError(
  32 + /**
  33 + * The error code of the RPC call. Error codes 1001-1999 are reserved for built-in errors.
  34 + *
  35 + * See [RpcError.BuiltinRpcError] for built-in error information.
  36 + */
  37 + val code: Int,
  38 +
  39 + /**
  40 + * A message to include. Strings over 256 bytes will be truncated.
  41 + */
  42 + override val message: String,
  43 + /**
  44 + * An optional data payload. Must be smaller than 15KB in size, or else will be truncated.
  45 + */
  46 + val data: String = "",
  47 +) : Exception(message) {
  48 +
  49 + enum class BuiltinRpcError(val code: Int, val message: String) {
  50 + APPLICATION_ERROR(1500, "Application error in method handler"),
  51 + CONNECTION_TIMEOUT(1501, "Connection timeout"),
  52 + RESPONSE_TIMEOUT(1502, "Response timeout"),
  53 + RECIPIENT_DISCONNECTED(1503, "Recipient disconnected"),
  54 + RESPONSE_PAYLOAD_TOO_LARGE(1504, "Response payload too large"),
  55 + SEND_FAILED(1505, "Failed to send"),
  56 +
  57 + UNSUPPORTED_METHOD(1400, "Method not supported at destination"),
  58 + RECIPIENT_NOT_FOUND(1401, "Recipient not found"),
  59 + REQUEST_PAYLOAD_TOO_LARGE(1402, "Request payload too large"),
  60 + UNSUPPORTED_SERVER(1403, "RPC not supported by server"),
  61 + UNSUPPORTED_VERSION(1404, "Unsupported RPC version"),
  62 + ;
  63 +
  64 + fun create(data: String = ""): RpcError {
  65 + return RpcError(code, message, data)
  66 + }
  67 + }
  68 +
  69 + companion object {
  70 + const val MAX_MESSAGE_BYTES = 256
  71 +
  72 + fun fromProto(proto: LivekitModels.RpcError): RpcError {
  73 + return RpcError(
  74 + code = proto.code,
  75 + message = (proto.message ?: "").truncateBytes(MAX_MESSAGE_BYTES),
  76 + data = proto.data.truncateBytes(RTCEngine.MAX_DATA_PACKET_SIZE),
  77 + )
  78 + }
  79 + }
  80 +
  81 + fun toProto(): LivekitModels.RpcError {
  82 + return with(LivekitModels.RpcError.newBuilder()) {
  83 + this.code = this@RpcError.code
  84 + this.message = this@RpcError.message
  85 + this.data = this@RpcError.data
  86 + build()
  87 + }
  88 + }
  89 +}
  1 +/*
  2 + * Copyright 2025 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.util
  18 +
  19 +import okio.ByteString.Companion.encode
  20 +
  21 +internal fun String?.byteLength(): Int {
  22 + if (this == null) {
  23 + return 0
  24 + }
  25 + return this.encode(Charsets.UTF_8).size
  26 +}
  27 +
  28 +internal fun String.truncateBytes(maxBytes: Int): String {
  29 + if (this.byteLength() <= maxBytes) {
  30 + return this
  31 + }
  32 +
  33 + var low = 0
  34 + var high = length
  35 +
  36 + // Binary search for string that fits.
  37 + while (low < high) {
  38 + val mid = (low + high + 1) / 2
  39 + if (this.substring(0, mid).byteLength() <= maxBytes) {
  40 + low = mid
  41 + } else {
  42 + high = mid - 1
  43 + }
  44 + }
  45 +
  46 + return substring(0, low)
  47 +}
1 /* 1 /*
2 - * Copyright 2023-2024 LiveKit, Inc. 2 + * Copyright 2023-2025 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -69,7 +69,7 @@ abstract class MockE2ETest : BaseTest() { @@ -69,7 +69,7 @@ abstract class MockE2ETest : BaseTest() {
69 room.release() 69 room.release()
70 } 70 }
71 71
72 - suspend fun connect(joinResponse: LivekitRtc.SignalResponse = TestData.JOIN) { 72 + open suspend fun connect(joinResponse: LivekitRtc.SignalResponse = TestData.JOIN) {
73 connectSignal(joinResponse) 73 connectSignal(joinResponse)
74 connectPeerConnection() 74 connectPeerConnection()
75 } 75 }
1 /* 1 /*
2 - * Copyright 2023-2024 LiveKit, Inc. 2 + * Copyright 2023-2025 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@ import livekit.org.webrtc.DataChannel @@ -21,7 +21,7 @@ import livekit.org.webrtc.DataChannel
21 class MockDataChannel(private val label: String?) : DataChannel(1L) { 21 class MockDataChannel(private val label: String?) : DataChannel(1L) {
22 22
23 var observer: Observer? = null 23 var observer: Observer? = null
24 - var sentBuffers = mutableListOf<Buffer?>() 24 + var sentBuffers = mutableListOf<Buffer>()
25 override fun registerObserver(observer: Observer?) { 25 override fun registerObserver(observer: Observer?) {
26 this.observer = observer 26 this.observer = observer
27 } 27 }
@@ -46,7 +46,7 @@ class MockDataChannel(private val label: String?) : DataChannel(1L) { @@ -46,7 +46,7 @@ class MockDataChannel(private val label: String?) : DataChannel(1L) {
46 return 0 46 return 0
47 } 47 }
48 48
49 - override fun send(buffer: Buffer?): Boolean { 49 + override fun send(buffer: Buffer): Boolean {
50 sentBuffers.add(buffer) 50 sentBuffers.add(buffer)
51 return true 51 return true
52 } 52 }
@@ -56,4 +56,8 @@ class MockDataChannel(private val label: String?) : DataChannel(1L) { @@ -56,4 +56,8 @@ class MockDataChannel(private val label: String?) : DataChannel(1L) {
56 56
57 override fun dispose() { 57 override fun dispose() {
58 } 58 }
  59 +
  60 + fun simulateBufferReceived(buffer: Buffer) {
  61 + observer?.onMessage(buffer)
  62 + }
59 } 63 }
1 /* 1 /*
2 - * Copyright 2023-2024 LiveKit, Inc. 2 + * Copyright 2023-2025 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ package io.livekit.android.test.mock @@ -18,6 +18,7 @@ package io.livekit.android.test.mock
18 18
19 import livekit.LivekitModels 19 import livekit.LivekitModels
20 import livekit.LivekitRtc 20 import livekit.LivekitRtc
  21 +import java.util.UUID
21 22
22 object TestData { 23 object TestData {
23 24
@@ -110,7 +111,7 @@ object TestData { @@ -110,7 +111,7 @@ object TestData {
110 build() 111 build()
111 }, 112 },
112 ) 113 )
113 - serverVersion = "0.15.2" 114 + serverVersion = "1.8.0"
114 build() 115 build()
115 } 116 }
116 build() 117 build()
@@ -327,4 +328,16 @@ object TestData { @@ -327,4 +328,16 @@ object TestData {
327 } 328 }
328 build() 329 build()
329 } 330 }
  331 + val DATA_PACKET_RPC_REQUEST = with(LivekitModels.DataPacket.newBuilder()) {
  332 + participantIdentity = REMOTE_PARTICIPANT.identity
  333 + rpcRequest = with(LivekitModels.RpcRequest.newBuilder()) {
  334 + id = UUID.randomUUID().toString()
  335 + method = "hello"
  336 + payload = "hello world"
  337 + responseTimeoutMs = 10000
  338 + version = 1
  339 + build()
  340 + }
  341 + build()
  342 + }
330 } 343 }
  1 +/*
  2 + * Copyright 2023-2025 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.room.participant
  18 +
  19 +import com.google.protobuf.ByteString
  20 +import io.livekit.android.room.RTCEngine
  21 +import io.livekit.android.rpc.RpcError
  22 +import io.livekit.android.test.MockE2ETest
  23 +import io.livekit.android.test.mock.MockDataChannel
  24 +import io.livekit.android.test.mock.MockPeerConnection
  25 +import io.livekit.android.test.mock.TestData
  26 +import io.livekit.android.test.mock.TestData.REMOTE_PARTICIPANT
  27 +import io.livekit.android.test.util.toDataChannelBuffer
  28 +import kotlinx.coroutines.ExperimentalCoroutinesApi
  29 +import kotlinx.coroutines.async
  30 +import kotlinx.coroutines.launch
  31 +import livekit.LivekitModels
  32 +import livekit.LivekitRtc
  33 +import org.junit.Assert.assertEquals
  34 +import org.junit.Assert.assertTrue
  35 +import org.junit.Test
  36 +import org.junit.runner.RunWith
  37 +import org.robolectric.RobolectricTestRunner
  38 +import kotlin.time.Duration.Companion.milliseconds
  39 +
  40 +@ExperimentalCoroutinesApi
  41 +@RunWith(RobolectricTestRunner::class)
  42 +class RpcMockE2ETest : MockE2ETest() {
  43 +
  44 + lateinit var pubDataChannel: MockDataChannel
  45 + lateinit var subDataChannel: MockDataChannel
  46 +
  47 + companion object {
  48 + val ERROR = RpcError(
  49 + 1,
  50 + "This is an error message.",
  51 + "This is an error payload.",
  52 + )
  53 + }
  54 +
  55 + override suspend fun connect(joinResponse: LivekitRtc.SignalResponse) {
  56 + super.connect(joinResponse)
  57 +
  58 + val pubPeerConnection = component.rtcEngine().getPublisherPeerConnection() as MockPeerConnection
  59 + pubDataChannel = pubPeerConnection.dataChannels[RTCEngine.RELIABLE_DATA_CHANNEL_LABEL] as MockDataChannel
  60 +
  61 + val subPeerConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection
  62 + subDataChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL)
  63 + subPeerConnection.observer?.onDataChannel(subDataChannel)
  64 + }
  65 +
  66 + private fun createAck(requestId: String) =
  67 + with(LivekitModels.DataPacket.newBuilder()) {
  68 + participantIdentity = REMOTE_PARTICIPANT.identity
  69 + rpcAck = with(LivekitModels.RpcAck.newBuilder()) {
  70 + this.requestId = requestId
  71 + build()
  72 + }
  73 + build()
  74 + }.toDataChannelBuffer()
  75 +
  76 + private fun createResponse(requestId: String, payload: String? = null, error: RpcError? = null) = with(LivekitModels.DataPacket.newBuilder()) {
  77 + participantIdentity = REMOTE_PARTICIPANT.identity
  78 + rpcResponse = with(LivekitModels.RpcResponse.newBuilder()) {
  79 + this.requestId = requestId
  80 + if (error != null) {
  81 + this.error = error.toProto()
  82 + } else if (payload != null) {
  83 + this.payload = payload
  84 + }
  85 +
  86 + build()
  87 + }
  88 + build()
  89 + }.toDataChannelBuffer()
  90 +
  91 + @Test
  92 + fun handleRpcRequest() = runTest {
  93 + connect()
  94 +
  95 + var methodCalled = false
  96 + room.localParticipant.registerRpcMethod("hello") {
  97 + methodCalled = true
  98 + "bye"
  99 + }
  100 + subDataChannel.simulateBufferReceived(TestData.DATA_PACKET_RPC_REQUEST.toDataChannelBuffer())
  101 + assertTrue(methodCalled)
  102 +
  103 + coroutineRule.dispatcher.scheduler.advanceUntilIdle()
  104 +
  105 + // Check that ack and response were sent
  106 + val buffers = pubDataChannel.sentBuffers
  107 + assertEquals(2, buffers.size)
  108 +
  109 + val ackBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
  110 + val responseBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[1].data))
  111 +
  112 + assertTrue(ackBuffer.hasRpcAck())
  113 + assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, ackBuffer.rpcAck.requestId)
  114 +
  115 + assertTrue(responseBuffer.hasRpcResponse())
  116 + assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, responseBuffer.rpcResponse.requestId)
  117 + assertEquals("bye", responseBuffer.rpcResponse.payload)
  118 + }
  119 +
  120 + @Test
  121 + fun handleRpcRequestWithError() = runTest {
  122 + connect()
  123 +
  124 + var methodCalled = false
  125 + room.localParticipant.registerRpcMethod("hello") {
  126 + methodCalled = true
  127 + throw ERROR
  128 + }
  129 + subDataChannel.simulateBufferReceived(TestData.DATA_PACKET_RPC_REQUEST.toDataChannelBuffer())
  130 + assertTrue(methodCalled)
  131 +
  132 + coroutineRule.dispatcher.scheduler.advanceUntilIdle()
  133 +
  134 + // Check that ack and response were sent
  135 + val buffers = pubDataChannel.sentBuffers
  136 + assertEquals(2, buffers.size)
  137 +
  138 + val ackBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
  139 + val responseBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[1].data))
  140 +
  141 + assertTrue(ackBuffer.hasRpcAck())
  142 + assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, ackBuffer.rpcAck.requestId)
  143 +
  144 + assertTrue(responseBuffer.hasRpcResponse())
  145 + assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, responseBuffer.rpcResponse.requestId)
  146 + assertEquals(ERROR, RpcError.fromProto(responseBuffer.rpcResponse.error))
  147 + }
  148 +
  149 + @Test
  150 + fun handleRpcRequestWithNoHandler() = runTest {
  151 + connect()
  152 +
  153 + subDataChannel.simulateBufferReceived(TestData.DATA_PACKET_RPC_REQUEST.toDataChannelBuffer())
  154 +
  155 + coroutineRule.dispatcher.scheduler.advanceUntilIdle()
  156 +
  157 + // Check that ack and response were sent
  158 + val buffers = pubDataChannel.sentBuffers
  159 + assertEquals(2, buffers.size)
  160 +
  161 + val ackBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
  162 + val responseBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[1].data))
  163 +
  164 + assertTrue(ackBuffer.hasRpcAck())
  165 + assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, ackBuffer.rpcAck.requestId)
  166 +
  167 + assertTrue(responseBuffer.hasRpcResponse())
  168 + assertEquals(TestData.DATA_PACKET_RPC_REQUEST.rpcRequest.id, responseBuffer.rpcResponse.requestId)
  169 + assertEquals(RpcError.BuiltinRpcError.UNSUPPORTED_METHOD.create(), RpcError.fromProto(responseBuffer.rpcResponse.error))
  170 + }
  171 +
  172 + @Test
  173 + fun performRpc() = runTest {
  174 + connect()
  175 +
  176 + val rpcJob = async {
  177 + room.localParticipant.performRpc(
  178 + destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity),
  179 + method = "hello",
  180 + payload = "hello world",
  181 + )
  182 + }
  183 +
  184 + // Check that request was sent
  185 + val buffers = pubDataChannel.sentBuffers
  186 + assertEquals(1, buffers.size)
  187 +
  188 + val requestBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
  189 +
  190 + assertTrue(requestBuffer.hasRpcRequest())
  191 + assertEquals("hello", requestBuffer.rpcRequest.method)
  192 + assertEquals("hello world", requestBuffer.rpcRequest.payload)
  193 +
  194 + val requestId = requestBuffer.rpcRequest.id
  195 +
  196 + // receive ack and response
  197 + subDataChannel.simulateBufferReceived(createAck(requestId))
  198 + subDataChannel.simulateBufferReceived(createResponse(requestId, payload = "bye"))
  199 +
  200 + coroutineRule.dispatcher.scheduler.advanceUntilIdle()
  201 + val response = rpcJob.await()
  202 +
  203 + assertEquals("bye", response)
  204 + }
  205 +
  206 + @Test
  207 + fun performRpcWithError() = runTest {
  208 + connect()
  209 +
  210 + val rpcJob = async {
  211 + var expectedError: Exception? = null
  212 + try {
  213 + room.localParticipant.performRpc(
  214 + destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity),
  215 + method = "hello",
  216 + payload = "hello world",
  217 + )
  218 + } catch (e: Exception) {
  219 + expectedError = e
  220 + }
  221 + return@async expectedError
  222 + }
  223 +
  224 + val buffers = pubDataChannel.sentBuffers
  225 + val requestBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
  226 + val requestId = requestBuffer.rpcRequest.id
  227 +
  228 + // receive ack and response
  229 + subDataChannel.simulateBufferReceived(createAck(requestId))
  230 + subDataChannel.simulateBufferReceived(createResponse(requestId, error = ERROR))
  231 +
  232 + coroutineRule.dispatcher.scheduler.advanceUntilIdle()
  233 + val receivedError = rpcJob.await()
  234 +
  235 + assertEquals(ERROR, receivedError)
  236 + }
  237 +
  238 + @Test
  239 + fun performRpcWithParticipantDisconnected() = runTest {
  240 + connect()
  241 + simulateMessageFromServer(TestData.PARTICIPANT_JOIN)
  242 +
  243 + val rpcJob = async {
  244 + var expectedError: Exception? = null
  245 + try {
  246 + room.localParticipant.performRpc(
  247 + destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity),
  248 + method = "hello",
  249 + payload = "hello world",
  250 + )
  251 + } catch (e: Exception) {
  252 + expectedError = e
  253 + }
  254 + return@async expectedError
  255 + }
  256 +
  257 + simulateMessageFromServer(TestData.PARTICIPANT_DISCONNECT)
  258 +
  259 + coroutineRule.dispatcher.scheduler.advanceUntilIdle()
  260 + val error = rpcJob.await()
  261 +
  262 + assertEquals(RpcError.BuiltinRpcError.RECIPIENT_DISCONNECTED.create(), error)
  263 + }
  264 +
  265 + @Test
  266 + fun performRpcWithConnectionTimeoutError() = runTest {
  267 + connect()
  268 +
  269 + val rpcJob = async {
  270 + var expectedError: Exception? = null
  271 + try {
  272 + room.localParticipant.performRpc(
  273 + destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity),
  274 + method = "hello",
  275 + payload = "hello world",
  276 + )
  277 + } catch (e: Exception) {
  278 + expectedError = e
  279 + }
  280 + return@async expectedError
  281 + }
  282 +
  283 + coroutineRule.dispatcher.scheduler.advanceTimeBy(3000)
  284 +
  285 + val error = rpcJob.await()
  286 +
  287 + assertEquals(RpcError.BuiltinRpcError.CONNECTION_TIMEOUT.create(), error)
  288 + }
  289 +
  290 + @Test
  291 + fun performRpcWithResponseTimeoutError() = runTest {
  292 + connect()
  293 +
  294 + val rpcJob = async {
  295 + var expectedError: Exception? = null
  296 + try {
  297 + room.localParticipant.performRpc(
  298 + destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity),
  299 + method = "hello",
  300 + payload = "hello world",
  301 + )
  302 + } catch (e: Exception) {
  303 + expectedError = e
  304 + }
  305 + return@async expectedError
  306 + }
  307 +
  308 + val buffers = pubDataChannel.sentBuffers
  309 + val requestBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
  310 + val requestId = requestBuffer.rpcRequest.id
  311 +
  312 + // receive ack only
  313 + subDataChannel.simulateBufferReceived(createAck(requestId))
  314 +
  315 + coroutineRule.dispatcher.scheduler.advanceTimeBy(15000)
  316 +
  317 + val error = rpcJob.await()
  318 +
  319 + assertEquals(RpcError.BuiltinRpcError.RESPONSE_TIMEOUT.create(), error)
  320 + }
  321 +
  322 + @Test
  323 + fun uintMaxValueVerification() = runTest {
  324 + assertEquals(4_294_967_295L, UInt.MAX_VALUE.toLong())
  325 + }
  326 +
  327 + /**
  328 + * Protobuf handles UInt32 as Java signed integers.
  329 + * This test verifies whether our conversion is properly sent over the wire.
  330 + */
  331 + @Test
  332 + fun performRpcProtoUIntVerification() = runTest {
  333 + connect()
  334 + val rpcJob = launch {
  335 + room.localParticipant.performRpc(
  336 + destinationIdentity = Participant.Identity(REMOTE_PARTICIPANT.identity),
  337 + method = "hello",
  338 + payload = "hello world",
  339 + responseTimeout = UInt.MAX_VALUE.toLong().milliseconds,
  340 + )
  341 + }
  342 +
  343 + val buffers = pubDataChannel.sentBuffers
  344 + val requestBuffer = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffers[0].data))
  345 +
  346 + val expectedResponseTimeout = UInt.MAX_VALUE - 2000u // 2000 comes from maxRoundTripLatency
  347 + val responseTimeout = requestBuffer.rpcRequest.responseTimeoutMs.toUInt()
  348 + assertEquals(expectedResponseTimeout, responseTimeout)
  349 + rpcJob.cancel()
  350 + }
  351 +
  352 + /**
  353 + * Protobuf handles UInt32 as Java signed integers.
  354 + * This test verifies whether our conversion is properly sent over the wire.
  355 + */
  356 + @Test
  357 + fun handleRpcProtoUIntVerification() = runTest {
  358 + connect()
  359 +
  360 + var methodCalled = false
  361 + room.localParticipant.registerRpcMethod("hello") { invocationData ->
  362 + assertEquals(4_294_967_295L, invocationData.responseTimeout.inWholeMilliseconds)
  363 + methodCalled = true
  364 + "bye"
  365 + }
  366 + subDataChannel.simulateBufferReceived(
  367 + with(TestData.DATA_PACKET_RPC_REQUEST.toBuilder()) {
  368 + rpcRequest = with(rpcRequest.toBuilder()) {
  369 + responseTimeoutMs = UInt.MAX_VALUE.toInt()
  370 + build()
  371 + }
  372 + build()
  373 + }.toDataChannelBuffer(),
  374 + )
  375 + assertTrue(methodCalled)
  376 + }
  377 +}
1 -Subproject commit a601adc5e9027820857a6d445b32a868b19d4184 1 +Subproject commit 9e8d1e37c5eb4434424bc16c657c83e7dc63bc2a