davidliu
Committed by GitHub

Queue requests when reconnecting (#59)

* Queue requests when reconnecting

* more tests
... ... @@ -5,7 +5,9 @@ import io.livekit.android.ConnectOptions
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.room.participant.ParticipantTrackPermission
import io.livekit.android.room.track.TrackException
import io.livekit.android.room.util.*
import io.livekit.android.room.util.MediaConstraintKeys
import io.livekit.android.room.util.createAnswer
import io.livekit.android.room.util.setLocalDescription
import io.livekit.android.util.CloseableCoroutineScope
import io.livekit.android.util.Either
import io.livekit.android.util.LKLog
... ... @@ -132,7 +134,7 @@ internal constructor(
if (!this.isSubscriberPrimary) {
negotiate()
}
client.onReady()
client.onReadyForResponses()
return joinResponse
}
... ... @@ -350,7 +352,7 @@ internal constructor(
try {
client.reconnect(url, token)
// no join response for regular reconnects
client.onReady()
client.onReadyForResponses()
} catch (e: Exception) {
LKLog.w(e) { "Error during reconnection." }
// ws reconnect failed, retry.
... ... @@ -378,6 +380,7 @@ internal constructor(
}
if (connectionState == ConnectionState.CONNECTED) {
client.onPCConnected()
listener?.onPostReconnect(isFullReconnect)
return@launch
}
... ...
package io.livekit.android.room
import android.net.Uri
import com.google.protobuf.util.JsonFormat
import com.vdurmont.semver4j.Semver
import io.livekit.android.ConnectOptions
import io.livekit.android.Version
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.room.participant.ParticipantTrackPermission
import io.livekit.android.room.track.Track
... ... @@ -60,6 +58,10 @@ constructor(
private var joinContinuation: CancellableContinuation<Either<LivekitRtc.JoinResponse, Unit>>? = null
private lateinit var coroutineScope: CloseableCoroutineScope
private val requestFlowJobLock = Object()
private var requestFlowJob: Job? = null
private val requestFlow = MutableSharedFlow<LivekitRtc.SignalRequest>(Int.MAX_VALUE)
private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE)
/**
... ... @@ -150,7 +152,7 @@ constructor(
* Should be called after resolving the join message.
*/
@OptIn(ExperimentalCoroutinesApi::class)
fun onReady() {
fun onReadyForResponses() {
coroutineScope.launch {
responseFlow.collect {
responseFlow.resetReplayCache()
... ... @@ -159,6 +161,31 @@ constructor(
}
}
@OptIn(ExperimentalCoroutinesApi::class)
private fun startRequestQueue() {
if (requestFlowJob != null) {
return
}
synchronized(requestFlowJobLock) {
if (requestFlowJob == null) {
requestFlowJob = coroutineScope.launch {
requestFlow.collect {
requestFlow.resetReplayCache()
sendRequestImpl(it)
}
}
}
}
}
/**
* On reconnection, SignalClient waits until the peer connection is established to send messages.
* Call this method when it is connected.
*/
fun onPCConnected() {
startRequestQueue()
}
//--------------------------------- WebSocket Listener --------------------------------------//
override fun onOpen(webSocket: WebSocket, response: Response) {
if (isReconnecting) {
... ... @@ -402,6 +429,16 @@ constructor(
}
private fun sendRequest(request: LivekitRtc.SignalRequest) {
val skipQueue = skipQueueTypes.contains(request.messageCase)
if (skipQueue) {
sendRequestImpl(request)
} else {
requestFlow.tryEmit(request)
}
}
private fun sendRequestImpl(request: LivekitRtc.SignalRequest) {
LKLog.v { "sending request: $request" }
if (!isConnected || currentWs == null) {
LKLog.w { "not connected, could not send request $request" }
... ... @@ -428,6 +465,7 @@ constructor(
// Only handle joins if not connected.
if (response.hasJoin()) {
isConnected = true
startRequestQueue()
try {
serverVersion = Semver(response.join.serverVersion)
} catch (t: Throwable) {
... ... @@ -439,9 +477,7 @@ constructor(
}
return
}
coroutineScope.launch {
responseFlow.tryEmit(response)
}
responseFlow.tryEmit(response)
}
private fun handleSignalResponseImpl(response: LivekitRtc.SignalResponse) {
... ... @@ -519,6 +555,7 @@ constructor(
fun close(code: Int = 1000, reason: String = "Normal Closure") {
isConnected = false
isReconnecting = false
requestFlowJob = null
if (::coroutineScope.isInitialized) {
coroutineScope.close()
}
... ... @@ -564,6 +601,14 @@ constructor(
const val PROTOCOL_VERSION = 6
const val SDK_TYPE = "android"
private val skipQueueTypes = listOf(
LivekitRtc.SignalRequest.MessageCase.SYNC_STATE,
LivekitRtc.SignalRequest.MessageCase.TRICKLE,
LivekitRtc.SignalRequest.MessageCase.OFFER,
LivekitRtc.SignalRequest.MessageCase.ANSWER,
LivekitRtc.SignalRequest.MessageCase.SIMULATE
)
private fun iceServer(url: String) =
PeerConnection.IceServer.builder(url).createIceServer()
... ...
... ... @@ -8,7 +8,7 @@ class MockWebSocketFactory : WebSocket.Factory {
/**
* The most recently created [WebSocket].
*/
lateinit var ws: WebSocket
lateinit var ws: MockWebSocket
/**
* The request used to create [ws]
... ...
... ... @@ -33,7 +33,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
rtcEngine.subscriber.peerConnection.remoteDescription.description
)
val ws = wsFactory.ws as MockWebSocket
val ws = wsFactory.ws
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
.mergeFrom(ws.sentRequests[0].toPBByteString())
.build()
... ...
... ... @@ -5,6 +5,7 @@ import io.livekit.android.BaseTest
import io.livekit.android.mock.MockWebSocketFactory
import io.livekit.android.mock.TestData
import io.livekit.android.util.toOkioByteString
import io.livekit.android.util.toPBByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.serialization.json.Json
... ... @@ -130,7 +131,7 @@ class SignalClientTest : BaseTest() {
client.onMessage(wsFactory.ws, OFFER.toOkioByteString())
job.await()
client.onReady()
client.onReadyForResponses()
Mockito.verify(listener)
.onOffer(argThat { type == SessionDescription.Type.OFFER && description == OFFER.offer.sdp })
}
... ... @@ -153,6 +154,68 @@ class SignalClientTest : BaseTest() {
.onClose(any(), any())
}
@Test
fun sendRequest() = runTest {
val job = async { client.join(EXAMPLE_URL, "") }
connectWebsocketAndJoin()
job.await()
client.sendMuteTrack("sid", true)
val ws = wsFactory.ws
Assert.assertEquals(1, ws.sentRequests.size)
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
.mergeFrom(ws.sentRequests[0].toPBByteString())
.build()
Assert.assertTrue(sentRequest.hasMute())
}
@Test
fun queuedRequests() = runTest {
client.sendMuteTrack("sid", true)
client.sendMuteTrack("sid", true)
client.sendMuteTrack("sid", true)
val job = async { client.join(EXAMPLE_URL, "") }
connectWebsocketAndJoin()
job.await()
val ws = wsFactory.ws
Assert.assertEquals(3, ws.sentRequests.size)
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
.mergeFrom(ws.sentRequests[0].toPBByteString())
.build()
Assert.assertTrue(sentRequest.hasMute())
}
@Test
fun queuedRequestsWhileReconnecting() = runTest {
client.sendMuteTrack("sid", true)
client.sendMuteTrack("sid", true)
client.sendMuteTrack("sid", true)
val job = async { client.reconnect(EXAMPLE_URL, "") }
client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
job.await()
val ws = wsFactory.ws
// Wait until peer connection is connected to send requests.
Assert.assertEquals(0, ws.sentRequests.size)
client.onPCConnected()
Assert.assertEquals(3, ws.sentRequests.size)
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
.mergeFrom(ws.sentRequests[0].toPBByteString())
.build()
Assert.assertTrue(sentRequest.hasMute())
}
// mock data
companion object {
const val EXAMPLE_URL = "ws://www.example.com"
... ...