davidliu
Committed by GitHub

Sanitize response handling and clear out queued requests on disconnect (#387)

* Sanitize response handling and clear out queued requests upon disconnection

* cleanup test code

* spotless

* More sanitization

* Fix tests
@@ -84,13 +84,19 @@ constructor( @@ -84,13 +84,19 @@ constructor(
84 >? = null 84 >? = null
85 private lateinit var coroutineScope: CloseableCoroutineScope 85 private lateinit var coroutineScope: CloseableCoroutineScope
86 86
  87 + /**
  88 + * @see [startRequestQueue]
  89 + */
  90 + private val requestFlow = MutableSharedFlow<LivekitRtc.SignalRequest>(Int.MAX_VALUE)
87 private val requestFlowJobLock = Object() 91 private val requestFlowJobLock = Object()
88 private var requestFlowJob: Job? = null 92 private var requestFlowJob: Job? = null
89 - private val requestFlow = MutableSharedFlow<LivekitRtc.SignalRequest>(Int.MAX_VALUE)  
90 93
  94 + /**
  95 + * @see [onReadyForResponses]
  96 + */
  97 + private val responseFlow = MutableSharedFlow<Pair<WebSocket, LivekitRtc.SignalResponse>>(Int.MAX_VALUE)
91 private val responseFlowJobLock = Object() 98 private val responseFlowJobLock = Object()
92 private var responseFlowJob: Job? = null 99 private var responseFlowJob: Job? = null
93 - private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE)  
94 100
95 private var pingJob: Job? = null 101 private var pingJob: Job? = null
96 private var pongJob: Job? = null 102 private var pongJob: Job? = null
@@ -137,7 +143,7 @@ constructor( @@ -137,7 +143,7 @@ constructor(
137 roomOptions: RoomOptions, 143 roomOptions: RoomOptions,
138 ): Either<JoinResponse, Either<ReconnectResponse, Unit>> { 144 ): Either<JoinResponse, Either<ReconnectResponse, Unit>> {
139 // Clean up any pre-existing connection. 145 // Clean up any pre-existing connection.
140 - close(reason = "Starting new connection") 146 + close(reason = "Starting new connection", shouldClearQueuedRequests = false)
141 147
142 val wsUrlString = "$url/rtc" + createConnectionParams(token, getClientInfo(), options, roomOptions) 148 val wsUrlString = "$url/rtc" + createConnectionParams(token, getClientInfo(), options, roomOptions)
143 isReconnecting = options.reconnect 149 isReconnecting = options.reconnect
@@ -210,9 +216,9 @@ constructor( @@ -210,9 +216,9 @@ constructor(
210 synchronized(responseFlowJobLock) { 216 synchronized(responseFlowJobLock) {
211 if (responseFlowJob == null) { 217 if (responseFlowJob == null) {
212 responseFlowJob = coroutineScope.launch { 218 responseFlowJob = coroutineScope.launch {
213 - responseFlow.collect { 219 + responseFlow.collect { (ws, response) ->
214 responseFlow.resetReplayCache() 220 responseFlow.resetReplayCache()
215 - handleSignalResponseImpl(it) 221 + handleSignalResponseImpl(ws, response)
216 } 222 }
217 } 223 }
218 } 224 }
@@ -246,19 +252,31 @@ constructor( @@ -246,19 +252,31 @@ constructor(
246 252
247 // --------------------------------- WebSocket Listener --------------------------------------// 253 // --------------------------------- WebSocket Listener --------------------------------------//
248 override fun onMessage(webSocket: WebSocket, text: String) { 254 override fun onMessage(webSocket: WebSocket, text: String) {
  255 + if (webSocket != currentWs) {
  256 + // Possibly message from old websocket, discard.
  257 + return
  258 + }
  259 +
249 LKLog.w { "received JSON message, unsupported in this version." } 260 LKLog.w { "received JSON message, unsupported in this version." }
250 } 261 }
251 262
252 override fun onMessage(webSocket: WebSocket, bytes: ByteString) { 263 override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
  264 + if (webSocket != currentWs) {
  265 + // Possibly message from old websocket, discard.
  266 + return
  267 + }
253 val byteArray = bytes.toByteArray() 268 val byteArray = bytes.toByteArray()
254 val signalResponseBuilder = LivekitRtc.SignalResponse.newBuilder() 269 val signalResponseBuilder = LivekitRtc.SignalResponse.newBuilder()
255 .mergeFrom(byteArray) 270 .mergeFrom(byteArray)
256 val response = signalResponseBuilder.build() 271 val response = signalResponseBuilder.build()
257 272
258 - handleSignalResponse(response) 273 + handleSignalResponse(webSocket, response)
259 } 274 }
260 275
261 override fun onClosed(webSocket: WebSocket, code: Int, reason: String) { 276 override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
  277 + if (webSocket != currentWs) {
  278 + return
  279 + }
262 handleWebSocketClose(reason, code) 280 handleWebSocketClose(reason, code)
263 } 281 }
264 282
@@ -267,6 +285,9 @@ constructor( @@ -267,6 +285,9 @@ constructor(
267 } 285 }
268 286
269 override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { 287 override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
  288 + if (webSocket != currentWs) {
  289 + return
  290 + }
270 var reason: String? = null 291 var reason: String? = null
271 try { 292 try {
272 lastUrl?.let { 293 lastUrl?.let {
@@ -553,7 +574,11 @@ constructor( @@ -553,7 +574,11 @@ constructor(
553 } 574 }
554 } 575 }
555 576
556 - private fun handleSignalResponse(response: LivekitRtc.SignalResponse) { 577 + private fun handleSignalResponse(ws: WebSocket, response: LivekitRtc.SignalResponse) {
  578 + if (ws != currentWs) {
  579 + return
  580 + }
  581 +
557 LKLog.v { "response: $response" } 582 LKLog.v { "response: $response" }
558 583
559 if (!isConnected) { 584 if (!isConnected) {
@@ -574,7 +599,7 @@ constructor( @@ -574,7 +599,7 @@ constructor(
574 joinContinuation?.resumeWith(Result.success(Either.Left(response.join))) 599 joinContinuation?.resumeWith(Result.success(Either.Left(response.join)))
575 } else if (response.hasLeave()) { 600 } else if (response.hasLeave()) {
576 // Some reconnects may immediately send leave back without a join response first. 601 // Some reconnects may immediately send leave back without a join response first.
577 - handleSignalResponseImpl(response) 602 + handleSignalResponseImpl(ws, response)
578 } else if (isReconnecting) { 603 } else if (isReconnecting) {
579 // When reconnecting, any message received means signal reconnected. 604 // When reconnecting, any message received means signal reconnected.
580 // Newer servers will send a reconnect response first 605 // Newer servers will send a reconnect response first
@@ -598,10 +623,15 @@ constructor( @@ -598,10 +623,15 @@ constructor(
598 return 623 return
599 } 624 }
600 } 625 }
601 - responseFlow.tryEmit(response) 626 + responseFlow.tryEmit(ws to response)
602 } 627 }
603 628
604 - private fun handleSignalResponseImpl(response: LivekitRtc.SignalResponse) { 629 + private fun handleSignalResponseImpl(ws: WebSocket, response: LivekitRtc.SignalResponse) {
  630 + if (ws != currentWs) {
  631 + LKLog.v { "received message from old websocket, discarding." }
  632 + return
  633 + }
  634 +
605 when (response.messageCase) { 635 when (response.messageCase) {
606 LivekitRtc.SignalResponse.MessageCase.ANSWER -> { 636 LivekitRtc.SignalResponse.MessageCase.ANSWER -> {
607 val sd = fromProtoSessionDescription(response.answer) 637 val sd = fromProtoSessionDescription(response.answer)
@@ -738,7 +768,7 @@ constructor( @@ -738,7 +768,7 @@ constructor(
738 * 768 *
739 * Can be reused afterwards. 769 * Can be reused afterwards.
740 */ 770 */
741 - fun close(code: Int = CLOSE_REASON_NORMAL_CLOSURE, reason: String = "Normal Closure") { 771 + fun close(code: Int = CLOSE_REASON_NORMAL_CLOSURE, reason: String = "Normal Closure", shouldClearQueuedRequests: Boolean = true) {
742 LKLog.v(Exception()) { "Closing SignalClient: code = $code, reason = $reason" } 772 LKLog.v(Exception()) { "Closing SignalClient: code = $code, reason = $reason" }
743 isConnected = false 773 isConnected = false
744 isReconnecting = false 774 isReconnecting = false
@@ -757,8 +787,9 @@ constructor( @@ -757,8 +787,9 @@ constructor(
757 currentWs = null 787 currentWs = null
758 joinContinuation?.cancel() 788 joinContinuation?.cancel()
759 joinContinuation = null 789 joinContinuation = null
760 - // TODO: support calling this from connect without wiping any queued requests.  
761 - // requestFlow.resetReplayCache() 790 + if (shouldClearQueuedRequests) {
  791 + requestFlow.resetReplayCache()
  792 + }
762 responseFlow.resetReplayCache() 793 responseFlow.resetReplayCache()
763 lastUrl = null 794 lastUrl = null
764 lastOptions = null 795 lastOptions = null
@@ -35,6 +35,7 @@ import okhttp3.Protocol @@ -35,6 +35,7 @@ import okhttp3.Protocol
35 import okhttp3.Request 35 import okhttp3.Request
36 import okhttp3.Response 36 import okhttp3.Response
37 import okio.ByteString 37 import okio.ByteString
  38 +import org.junit.After
38 import org.junit.Before 39 import org.junit.Before
39 import org.junit.runner.RunWith 40 import org.junit.runner.RunWith
40 import org.robolectric.RobolectricTestRunner 41 import org.robolectric.RobolectricTestRunner
@@ -60,6 +61,11 @@ abstract class MockE2ETest : BaseTest() { @@ -60,6 +61,11 @@ abstract class MockE2ETest : BaseTest() {
60 wsFactory = component.websocketFactory() 61 wsFactory = component.websocketFactory()
61 } 62 }
62 63
  64 + @After
  65 + fun tearDown() {
  66 + room.release()
  67 + }
  68 +
63 suspend fun connect(joinResponse: LivekitRtc.SignalResponse = SignalClientTest.JOIN) { 69 suspend fun connect(joinResponse: LivekitRtc.SignalResponse = SignalClientTest.JOIN) {
64 connectSignal(joinResponse) 70 connectSignal(joinResponse)
65 connectPeerConnection() 71 connectPeerConnection()
1 /* 1 /*
2 - * Copyright 2023 LiveKit, Inc. 2 + * Copyright 2023-2024 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -52,7 +52,6 @@ class MockWebSocketFactory : WebSocket.Factory { @@ -52,7 +52,6 @@ class MockWebSocketFactory : WebSocket.Factory {
52 this.listener = listener 52 this.listener = listener
53 this.request = request 53 this.request = request
54 54
55 - onOpen?.invoke(this)  
56 return ws 55 return ws
57 } 56 }
58 57
1 /* 1 /*
2 - * Copyright 2023 LiveKit, Inc. 2 + * Copyright 2023-2024 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -20,14 +20,11 @@ import dagger.Module @@ -20,14 +20,11 @@ import dagger.Module
20 import dagger.Provides 20 import dagger.Provides
21 import io.livekit.android.dagger.InjectionNames 21 import io.livekit.android.dagger.InjectionNames
22 import kotlinx.coroutines.CoroutineDispatcher 22 import kotlinx.coroutines.CoroutineDispatcher
23 -import kotlinx.coroutines.ExperimentalCoroutinesApi  
24 -import kotlinx.coroutines.test.TestCoroutineDispatcher  
25 import javax.inject.Named 23 import javax.inject.Named
26 24
27 @Module 25 @Module
28 class TestCoroutinesModule( 26 class TestCoroutinesModule(
29 - @OptIn(ExperimentalCoroutinesApi::class)  
30 - val coroutineDispatcher: CoroutineDispatcher = TestCoroutineDispatcher(), 27 + private val coroutineDispatcher: CoroutineDispatcher,
31 ) { 28 ) {
32 29
33 @Provides 30 @Provides
1 /* 1 /*
2 - * Copyright 2023 LiveKit, Inc. 2 + * Copyright 2023-2024 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -47,7 +47,7 @@ internal interface TestLiveKitComponent : LiveKitComponent { @@ -47,7 +47,7 @@ internal interface TestLiveKitComponent : LiveKitComponent {
47 interface Factory { 47 interface Factory {
48 fun create( 48 fun create(
49 @BindsInstance appContext: Context, 49 @BindsInstance appContext: Context,
50 - coroutinesModule: TestCoroutinesModule = TestCoroutinesModule(), 50 + coroutinesModule: TestCoroutinesModule,
51 ): TestLiveKitComponent 51 ): TestLiveKitComponent
52 } 52 }
53 } 53 }
@@ -38,19 +38,17 @@ import org.robolectric.RobolectricTestRunner @@ -38,19 +38,17 @@ import org.robolectric.RobolectricTestRunner
38 @RunWith(RobolectricTestRunner::class) 38 @RunWith(RobolectricTestRunner::class)
39 class RoomReconnectionMockE2ETest : MockE2ETest() { 39 class RoomReconnectionMockE2ETest : MockE2ETest() {
40 40
41 - private fun prepareForReconnect() {  
42 - wsFactory.onOpen = {  
43 - wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))  
44 - val softReconnectParam = wsFactory.request.url  
45 - .queryParameter(SignalClient.CONNECT_QUERY_RECONNECT)  
46 - ?.toIntOrNull()  
47 - ?: 0  
48 -  
49 - if (softReconnectParam == 0) {  
50 - simulateMessageFromServer(SignalClientTest.JOIN)  
51 - } else {  
52 - simulateMessageFromServer(SignalClientTest.RECONNECT)  
53 - } 41 + private fun reconnectWebsocket() {
  42 + wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
  43 + val softReconnectParam = wsFactory.request.url
  44 + .queryParameter(SignalClient.CONNECT_QUERY_RECONNECT)
  45 + ?.toIntOrNull()
  46 + ?: 0
  47 +
  48 + if (softReconnectParam == 0) {
  49 + simulateMessageFromServer(SignalClientTest.JOIN)
  50 + } else {
  51 + simulateMessageFromServer(SignalClientTest.RECONNECT)
54 } 52 }
55 } 53 }
56 54
@@ -59,10 +57,10 @@ class RoomReconnectionMockE2ETest : MockE2ETest() { @@ -59,10 +57,10 @@ class RoomReconnectionMockE2ETest : MockE2ETest() {
59 room.setReconnectionType(ReconnectType.FORCE_SOFT_RECONNECT) 57 room.setReconnectionType(ReconnectType.FORCE_SOFT_RECONNECT)
60 58
61 connect() 59 connect()
62 - prepareForReconnect()  
63 disconnectPeerConnection() 60 disconnectPeerConnection()
64 // Wait so that the reconnect job properly starts first. 61 // Wait so that the reconnect job properly starts first.
65 testScheduler.advanceTimeBy(1000) 62 testScheduler.advanceTimeBy(1000)
  63 + reconnectWebsocket()
66 connectPeerConnection() 64 connectPeerConnection()
67 65
68 testScheduler.advanceUntilIdle() 66 testScheduler.advanceUntilIdle()
@@ -82,10 +80,10 @@ class RoomReconnectionMockE2ETest : MockE2ETest() { @@ -82,10 +80,10 @@ class RoomReconnectionMockE2ETest : MockE2ETest() {
82 fun softReconnectConfiguration() = runTest { 80 fun softReconnectConfiguration() = runTest {
83 room.setReconnectionType(ReconnectType.FORCE_SOFT_RECONNECT) 81 room.setReconnectionType(ReconnectType.FORCE_SOFT_RECONNECT)
84 connect() 82 connect()
85 - prepareForReconnect()  
86 disconnectPeerConnection() 83 disconnectPeerConnection()
87 // Wait so that the reconnect job properly starts first. 84 // Wait so that the reconnect job properly starts first.
88 testScheduler.advanceTimeBy(1000) 85 testScheduler.advanceTimeBy(1000)
  86 + reconnectWebsocket()
89 connectPeerConnection() 87 connectPeerConnection()
90 88
91 val rtcConfig = getSubscriberPeerConnection().rtcConfig 89 val rtcConfig = getSubscriberPeerConnection().rtcConfig
@@ -109,10 +107,10 @@ class RoomReconnectionMockE2ETest : MockE2ETest() { @@ -109,10 +107,10 @@ class RoomReconnectionMockE2ETest : MockE2ETest() {
109 ), 107 ),
110 ) 108 )
111 109
112 - prepareForReconnect()  
113 disconnectPeerConnection() 110 disconnectPeerConnection()
114 // Wait so that the reconnect job properly starts first. 111 // Wait so that the reconnect job properly starts first.
115 testScheduler.advanceTimeBy(1000) 112 testScheduler.advanceTimeBy(1000)
  113 + reconnectWebsocket()
116 connectPeerConnection() 114 connectPeerConnection()
117 115
118 testScheduler.advanceUntilIdle() 116 testScheduler.advanceUntilIdle()
@@ -45,19 +45,17 @@ class RoomReconnectionTypesMockE2ETest( @@ -45,19 +45,17 @@ class RoomReconnectionTypesMockE2ETest(
45 ) 45 )
46 } 46 }
47 47
48 - private fun prepareForReconnect() {  
49 - wsFactory.onOpen = {  
50 - wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))  
51 - val softReconnectParam = wsFactory.request.url  
52 - .queryParameter(SignalClient.CONNECT_QUERY_RECONNECT)  
53 - ?.toIntOrNull()  
54 - ?: 0  
55 -  
56 - if (softReconnectParam == 0) {  
57 - simulateMessageFromServer(SignalClientTest.JOIN)  
58 - } else {  
59 - simulateMessageFromServer(SignalClientTest.RECONNECT)  
60 - } 48 + private fun reconnectWebsocket() {
  49 + wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
  50 + val softReconnectParam = wsFactory.request.url
  51 + .queryParameter(SignalClient.CONNECT_QUERY_RECONNECT)
  52 + ?.toIntOrNull()
  53 + ?: 0
  54 +
  55 + if (softReconnectParam == 0) {
  56 + simulateMessageFromServer(SignalClientTest.JOIN)
  57 + } else {
  58 + simulateMessageFromServer(SignalClientTest.RECONNECT)
61 } 59 }
62 } 60 }
63 61
@@ -111,10 +109,10 @@ class RoomReconnectionTypesMockE2ETest( @@ -111,10 +109,10 @@ class RoomReconnectionTypesMockE2ETest(
111 109
112 val eventCollector = EventCollector(room.events, coroutineRule.scope) 110 val eventCollector = EventCollector(room.events, coroutineRule.scope)
113 val stateCollector = FlowCollector(room::state.flow, coroutineRule.scope) 111 val stateCollector = FlowCollector(room::state.flow, coroutineRule.scope)
114 - prepareForReconnect()  
115 disconnectPeerConnection() 112 disconnectPeerConnection()
116 // Wait so that the reconnect job properly starts first. 113 // Wait so that the reconnect job properly starts first.
117 testScheduler.advanceTimeBy(1000) 114 testScheduler.advanceTimeBy(1000)
  115 + reconnectWebsocket()
118 connectPeerConnection() 116 connectPeerConnection()
119 117
120 testScheduler.advanceUntilIdle() 118 testScheduler.advanceUntilIdle()
@@ -138,10 +136,10 @@ class RoomReconnectionTypesMockE2ETest( @@ -138,10 +136,10 @@ class RoomReconnectionTypesMockE2ETest(
138 136
139 val eventCollector = EventCollector(room.events, coroutineRule.scope) 137 val eventCollector = EventCollector(room.events, coroutineRule.scope)
140 val stateCollector = FlowCollector(room::state.flow, coroutineRule.scope) 138 val stateCollector = FlowCollector(room::state.flow, coroutineRule.scope)
141 - prepareForReconnect()  
142 wsFactory.ws.cancel() 139 wsFactory.ws.cancel()
143 // Wait so that the reconnect job properly starts first. 140 // Wait so that the reconnect job properly starts first.
144 testScheduler.advanceTimeBy(1000) 141 testScheduler.advanceTimeBy(1000)
  142 + reconnectWebsocket()
145 connectPeerConnection() 143 connectPeerConnection()
146 144
147 testScheduler.advanceUntilIdle() 145 testScheduler.advanceUntilIdle()