davidliu
Committed by GitHub

Websocket Ping/Pong (#144)

* Update protocol submodule

* fix compile

* Remove obsolete when safe operator

* ping pong implementation

* Add in backup websocket ping/pong

* pingpong test

* lower okhttp ping interval to 20 seconds for now
@@ -9,6 +9,7 @@ import io.livekit.android.stats.AndroidNetworkInfo @@ -9,6 +9,7 @@ import io.livekit.android.stats.AndroidNetworkInfo
9 import io.livekit.android.stats.NetworkInfo 9 import io.livekit.android.stats.NetworkInfo
10 import okhttp3.OkHttpClient 10 import okhttp3.OkHttpClient
11 import okhttp3.WebSocket 11 import okhttp3.WebSocket
  12 +import java.util.concurrent.TimeUnit
12 import javax.inject.Named 13 import javax.inject.Named
13 import javax.inject.Singleton 14 import javax.inject.Singleton
14 15
@@ -21,6 +22,10 @@ object WebModule { @@ -21,6 +22,10 @@ object WebModule {
21 @Nullable 22 @Nullable
22 okHttpClientOverride: OkHttpClient? 23 okHttpClientOverride: OkHttpClient?
23 ): OkHttpClient { 24 ): OkHttpClient {
  25 + OkHttpClient.Builder()
  26 + .pingInterval(20, TimeUnit.SECONDS)
  27 + .build()
  28 +
24 return okHttpClientOverride ?: OkHttpClient() 29 return okHttpClientOverride ?: OkHttpClient()
25 } 30 }
26 31
@@ -11,7 +11,6 @@ import io.livekit.android.stats.getClientInfo @@ -11,7 +11,6 @@ import io.livekit.android.stats.getClientInfo
11 import io.livekit.android.util.CloseableCoroutineScope 11 import io.livekit.android.util.CloseableCoroutineScope
12 import io.livekit.android.util.Either 12 import io.livekit.android.util.Either
13 import io.livekit.android.util.LKLog 13 import io.livekit.android.util.LKLog
14 -import io.livekit.android.util.safe  
15 import io.livekit.android.webrtc.toProtoSessionDescription 14 import io.livekit.android.webrtc.toProtoSessionDescription
16 import kotlinx.coroutines.* 15 import kotlinx.coroutines.*
17 import kotlinx.coroutines.flow.MutableSharedFlow 16 import kotlinx.coroutines.flow.MutableSharedFlow
@@ -26,6 +25,7 @@ import okio.ByteString.Companion.toByteString @@ -26,6 +25,7 @@ import okio.ByteString.Companion.toByteString
26 import org.webrtc.IceCandidate 25 import org.webrtc.IceCandidate
27 import org.webrtc.PeerConnection 26 import org.webrtc.PeerConnection
28 import org.webrtc.SessionDescription 27 import org.webrtc.SessionDescription
  28 +import java.util.*
29 import javax.inject.Inject 29 import javax.inject.Inject
30 import javax.inject.Named 30 import javax.inject.Named
31 import javax.inject.Singleton 31 import javax.inject.Singleton
@@ -65,6 +65,11 @@ constructor( @@ -65,6 +65,11 @@ constructor(
65 65
66 private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE) 66 private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE)
67 67
  68 + private var pingJob: Job? = null
  69 + private var pongJob: Job? = null
  70 + private var pingTimeoutDurationMillis: Long = 0
  71 + private var pingIntervalDurationMillis: Long = 0
  72 +
68 var connectionState: ConnectionState = ConnectionState.DISCONNECTED 73 var connectionState: ConnectionState = ConnectionState.DISCONNECTED
69 74
70 /** 75 /**
@@ -205,6 +210,8 @@ constructor( @@ -205,6 +210,8 @@ constructor(
205 isReconnecting = false 210 isReconnecting = false
206 isConnected = true 211 isConnected = true
207 joinContinuation?.resumeWith(Result.success(Either.Right(Unit))) 212 joinContinuation?.resumeWith(Result.success(Either.Right(Unit)))
  213 + // Restart ping job with old settings.
  214 + startPingJob()
208 } 215 }
209 response.body?.close() 216 response.body?.close()
210 } 217 }
@@ -259,19 +266,23 @@ constructor( @@ -259,19 +266,23 @@ constructor(
259 } 266 }
260 267
261 val wasConnected = isConnected 268 val wasConnected = isConnected
262 - isConnected = false  
263 269
264 if (wasConnected) { 270 if (wasConnected) {
265 handleWebSocketClose( 271 handleWebSocketClose(
266 reason = reason ?: response?.toString() ?: t.localizedMessage ?: "websocket failure", 272 reason = reason ?: response?.toString() ?: t.localizedMessage ?: "websocket failure",
267 - code = response?.code ?: 500 273 + code = response?.code ?: CLOSE_REASON_WEBSOCKET_FAILURE
268 ) 274 )
269 } 275 }
270 } 276 }
271 277
272 private fun handleWebSocketClose(reason: String, code: Int) { 278 private fun handleWebSocketClose(reason: String, code: Int) {
273 LKLog.v { "websocket closed" } 279 LKLog.v { "websocket closed" }
  280 + isConnected = false
274 listener?.onClose(reason, code) 281 listener?.onClose(reason, code)
  282 + requestFlow.resetReplayCache()
  283 + responseFlow.resetReplayCache()
  284 + pingJob?.cancel()
  285 + pongJob?.cancel()
275 } 286 }
276 287
277 //------------------------------- End WebSocket Listener ------------------------------------// 288 //------------------------------- End WebSocket Listener ------------------------------------//
@@ -471,6 +482,9 @@ constructor( @@ -471,6 +482,9 @@ constructor(
471 if (response.hasJoin()) { 482 if (response.hasJoin()) {
472 isConnected = true 483 isConnected = true
473 startRequestQueue() 484 startRequestQueue()
  485 + pingTimeoutDurationMillis = response.join.pingTimeout.toLong() * 1000
  486 + pingIntervalDurationMillis = response.join.pingInterval.toLong() * 1000
  487 + startPingJob()
474 try { 488 try {
475 serverVersion = Semver(response.join.serverVersion) 489 serverVersion = Semver(response.join.serverVersion)
476 } catch (t: Throwable) { 490 } catch (t: Throwable) {
@@ -537,7 +551,7 @@ constructor( @@ -537,7 +551,7 @@ constructor(
537 } 551 }
538 LivekitRtc.SignalResponse.MessageCase.SUBSCRIBED_QUALITY_UPDATE -> { 552 LivekitRtc.SignalResponse.MessageCase.SUBSCRIBED_QUALITY_UPDATE -> {
539 val versionToIgnoreUpTo = Semver("0.15.1") 553 val versionToIgnoreUpTo = Semver("0.15.1")
540 - if (serverVersion?.compareTo(versionToIgnoreUpTo) ?: 1 <= 0) { 554 + if ((serverVersion?.compareTo(versionToIgnoreUpTo) ?: 1) <= 0) {
541 return 555 return
542 } 556 }
543 listener?.onSubscribedQualityUpdate(response.subscribedQualityUpdate) 557 listener?.onSubscribedQualityUpdate(response.subscribedQualityUpdate)
@@ -551,11 +565,48 @@ constructor( @@ -551,11 +565,48 @@ constructor(
551 LivekitRtc.SignalResponse.MessageCase.TRACK_UNPUBLISHED -> { 565 LivekitRtc.SignalResponse.MessageCase.TRACK_UNPUBLISHED -> {
552 listener?.onLocalTrackUnpublished(response.trackUnpublished) 566 listener?.onLocalTrackUnpublished(response.trackUnpublished)
553 } 567 }
  568 + LivekitRtc.SignalResponse.MessageCase.PONG -> {
  569 + resetPingTimeout()
  570 + }
554 LivekitRtc.SignalResponse.MessageCase.MESSAGE_NOT_SET, 571 LivekitRtc.SignalResponse.MessageCase.MESSAGE_NOT_SET,
555 null -> { 572 null -> {
556 LKLog.v { "empty messageCase!" } 573 LKLog.v { "empty messageCase!" }
557 } 574 }
558 - }.safe() 575 + }
  576 + }
  577 +
  578 + private fun startPingJob() {
  579 + if (pingJob == null && pingIntervalDurationMillis != 0L) {
  580 + pingJob = coroutineScope.launch {
  581 + while (true) {
  582 + delay(pingIntervalDurationMillis)
  583 +
  584 + val pingTimestamp = Date().time
  585 + val pingRequest = LivekitRtc.SignalRequest.newBuilder()
  586 + .setPing(pingTimestamp)
  587 + .build()
  588 + LKLog.v { "Sending ping: $pingTimestamp" }
  589 + sendRequest(pingRequest)
  590 + startPingTimeout(pingTimestamp)
  591 + }
  592 + }
  593 + }
  594 + }
  595 +
  596 + private fun startPingTimeout(timestamp: Long) {
  597 + if (pongJob != null) {
  598 + return
  599 + }
  600 + pongJob = coroutineScope.launch {
  601 + delay(pingTimeoutDurationMillis)
  602 + LKLog.d { "Ping timeout reached for ping sent at $timestamp." }
  603 + currentWs?.close(CLOSE_REASON_PING_TIMEOUT, "Ping timeout")
  604 + }
  605 + }
  606 +
  607 + private fun resetPingTimeout() {
  608 + pongJob?.cancel()
  609 + pongJob = null
559 } 610 }
560 611
561 /** 612 /**
@@ -563,14 +614,19 @@ constructor( @@ -563,14 +614,19 @@ constructor(
563 * 614 *
564 * Can be reused afterwards. 615 * Can be reused afterwards.
565 */ 616 */
566 - fun close(code: Int = 1000, reason: String = "Normal Closure") { 617 + fun close(code: Int = CLOSE_REASON_NORMAL_CLOSURE, reason: String = "Normal Closure") {
567 LKLog.v(Exception()) { "Closing SignalClient: code = $code, reason = $reason" } 618 LKLog.v(Exception()) { "Closing SignalClient: code = $code, reason = $reason" }
568 isConnected = false 619 isConnected = false
569 isReconnecting = false 620 isReconnecting = false
570 - requestFlowJob = null  
571 if (::coroutineScope.isInitialized) { 621 if (::coroutineScope.isInitialized) {
572 coroutineScope.close() 622 coroutineScope.close()
573 } 623 }
  624 + requestFlowJob?.cancel()
  625 + requestFlowJob = null
  626 + pingJob?.cancel()
  627 + pingJob = null
  628 + pongJob?.cancel()
  629 + pongJob = null
574 currentWs?.close(code, reason) 630 currentWs?.close(code, reason)
575 currentWs = null 631 currentWs = null
576 joinContinuation?.cancel() 632 joinContinuation?.cancel()
@@ -640,6 +696,9 @@ constructor( @@ -640,6 +696,9 @@ constructor(
640 // iceServer("stun:stun3.l.google.com:19302"), 696 // iceServer("stun:stun3.l.google.com:19302"),
641 // iceServer("stun:stun4.l.google.com:19302"), 697 // iceServer("stun:stun4.l.google.com:19302"),
642 ) 698 )
  699 + const val CLOSE_REASON_NORMAL_CLOSURE = 1000
  700 + const val CLOSE_REASON_PING_TIMEOUT = 3000
  701 + const val CLOSE_REASON_WEBSOCKET_FAILURE = 3500
643 } 702 }
644 } 703 }
645 704
1 -@file:Suppress("unused")  
2 -  
3 -package io.livekit.android.util  
4 -  
5 -/**  
6 - * Forces a when expression to be exhaustive only.  
7 - */  
8 -internal fun Unit.safe() {}  
9 -  
10 -/**  
11 - * Forces a when expression to be exhaustive only.  
12 - */  
13 -internal fun Nothing?.safe() {}  
14 -  
15 -/**  
16 - * Forces a when expression to be exhaustive only.  
17 - */  
18 -internal fun Any?.safe() {}  
@@ -13,7 +13,7 @@ import kotlinx.serialization.json.Json @@ -13,7 +13,7 @@ import kotlinx.serialization.json.Json
13 import livekit.LivekitModels 13 import livekit.LivekitModels
14 import livekit.LivekitRtc 14 import livekit.LivekitRtc
15 import okhttp3.* 15 import okhttp3.*
16 -import org.junit.Assert 16 +import org.junit.Assert.*
17 import org.junit.Before 17 import org.junit.Before
18 import org.junit.Test 18 import org.junit.Test
19 import org.mockito.Mock 19 import org.mockito.Mock
@@ -68,7 +68,6 @@ class SignalClientTest : BaseTest() { @@ -68,7 +68,6 @@ class SignalClientTest : BaseTest() {
68 68
69 @Test 69 @Test
70 fun joinAndResponse() = runTest { 70 fun joinAndResponse() = runTest {
71 - println("dispatcher = ${this.coroutineContext}")  
72 val job = async { 71 val job = async {
73 client.join(EXAMPLE_URL, "") 72 client.join(EXAMPLE_URL, "")
74 } 73 }
@@ -76,8 +75,8 @@ class SignalClientTest : BaseTest() { @@ -76,8 +75,8 @@ class SignalClientTest : BaseTest() {
76 connectWebsocketAndJoin() 75 connectWebsocketAndJoin()
77 76
78 val response = job.await() 77 val response = job.await()
79 - Assert.assertEquals(true, client.isConnected)  
80 - Assert.assertEquals(response, JOIN.join) 78 + assertEquals(true, client.isConnected)
  79 + assertEquals(response, JOIN.join)
81 } 80 }
82 81
83 @Test 82 @Test
@@ -89,7 +88,7 @@ class SignalClientTest : BaseTest() { @@ -89,7 +88,7 @@ class SignalClientTest : BaseTest() {
89 client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request)) 88 client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
90 89
91 job.await() 90 job.await()
92 - Assert.assertEquals(true, client.isConnected) 91 + assertEquals(true, client.isConnected)
93 } 92 }
94 93
95 @Test 94 @Test
@@ -106,7 +105,7 @@ class SignalClientTest : BaseTest() { @@ -106,7 +105,7 @@ class SignalClientTest : BaseTest() {
106 client.onFailure(wsFactory.ws, Exception(), null) 105 client.onFailure(wsFactory.ws, Exception(), null)
107 job.await() 106 job.await()
108 107
109 - Assert.assertTrue(failed) 108 + assertTrue(failed)
110 } 109 }
111 110
112 @Test 111 @Test
@@ -165,12 +164,12 @@ class SignalClientTest : BaseTest() { @@ -165,12 +164,12 @@ class SignalClientTest : BaseTest() {
165 164
166 val ws = wsFactory.ws 165 val ws = wsFactory.ws
167 166
168 - Assert.assertEquals(1, ws.sentRequests.size) 167 + assertEquals(1, ws.sentRequests.size)
169 val sentRequest = LivekitRtc.SignalRequest.newBuilder() 168 val sentRequest = LivekitRtc.SignalRequest.newBuilder()
170 .mergeFrom(ws.sentRequests[0].toPBByteString()) 169 .mergeFrom(ws.sentRequests[0].toPBByteString())
171 .build() 170 .build()
172 171
173 - Assert.assertTrue(sentRequest.hasMute()) 172 + assertTrue(sentRequest.hasMute())
174 } 173 }
175 174
176 @Test 175 @Test
@@ -184,12 +183,12 @@ class SignalClientTest : BaseTest() { @@ -184,12 +183,12 @@ class SignalClientTest : BaseTest() {
184 job.await() 183 job.await()
185 184
186 val ws = wsFactory.ws 185 val ws = wsFactory.ws
187 - Assert.assertEquals(3, ws.sentRequests.size) 186 + assertEquals(3, ws.sentRequests.size)
188 val sentRequest = LivekitRtc.SignalRequest.newBuilder() 187 val sentRequest = LivekitRtc.SignalRequest.newBuilder()
189 .mergeFrom(ws.sentRequests[0].toPBByteString()) 188 .mergeFrom(ws.sentRequests[0].toPBByteString())
190 .build() 189 .build()
191 190
192 - Assert.assertTrue(sentRequest.hasMute()) 191 + assertTrue(sentRequest.hasMute())
193 } 192 }
194 193
195 @Test 194 @Test
@@ -205,16 +204,78 @@ class SignalClientTest : BaseTest() { @@ -205,16 +204,78 @@ class SignalClientTest : BaseTest() {
205 val ws = wsFactory.ws 204 val ws = wsFactory.ws
206 205
207 // Wait until peer connection is connected to send requests. 206 // Wait until peer connection is connected to send requests.
208 - Assert.assertEquals(0, ws.sentRequests.size) 207 + assertEquals(0, ws.sentRequests.size)
209 208
210 client.onPCConnected() 209 client.onPCConnected()
211 210
212 - Assert.assertEquals(3, ws.sentRequests.size) 211 + assertEquals(3, ws.sentRequests.size)
213 val sentRequest = LivekitRtc.SignalRequest.newBuilder() 212 val sentRequest = LivekitRtc.SignalRequest.newBuilder()
214 .mergeFrom(ws.sentRequests[0].toPBByteString()) 213 .mergeFrom(ws.sentRequests[0].toPBByteString())
215 .build() 214 .build()
216 215
217 - Assert.assertTrue(sentRequest.hasMute()) 216 + assertTrue(sentRequest.hasMute())
  217 + }
  218 +
  219 + @Test
  220 + fun pingTest() = runTest {
  221 +
  222 + val joinResponseWithPing = with(JOIN.toBuilder()) {
  223 + join = with(join.toBuilder()) {
  224 + pingInterval = 10
  225 + pingTimeout = 20
  226 + build()
  227 + }
  228 + build()
  229 + }
  230 +
  231 + val job = async {
  232 + client.join(EXAMPLE_URL, "")
  233 + }
  234 + client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
  235 + client.onMessage(wsFactory.ws, joinResponseWithPing.toOkioByteString())
  236 + job.await()
  237 + val originalWs = wsFactory.ws
  238 + assertFalse(originalWs.isClosed)
  239 +
  240 + testScheduler.advanceTimeBy(15 * 1000)
  241 + assertTrue(originalWs.sentRequests.any { requestString ->
  242 + val sentRequest = LivekitRtc.SignalRequest.newBuilder()
  243 + .mergeFrom(requestString.toPBByteString())
  244 + .build()
  245 +
  246 + return@any sentRequest.hasPing()
  247 + })
  248 +
  249 + client.onMessage(wsFactory.ws, PONG.toOkioByteString())
  250 +
  251 + testScheduler.advanceTimeBy(10 * 1000)
  252 + assertFalse(originalWs.isClosed)
  253 + }
  254 +
  255 + @Test
  256 + fun pingTimeoutTest() = runTest {
  257 +
  258 + val joinResponseWithPing = with(JOIN.toBuilder()) {
  259 + join = with(join.toBuilder()) {
  260 + pingInterval = 10
  261 + pingTimeout = 20
  262 + build()
  263 + }
  264 + build()
  265 + }
  266 +
  267 + val job = async {
  268 + client.join(EXAMPLE_URL, "")
  269 + }
  270 + client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
  271 + client.onMessage(wsFactory.ws, joinResponseWithPing.toOkioByteString())
  272 + job.await()
  273 + val originalWs = wsFactory.ws
  274 + assertFalse(originalWs.isClosed)
  275 +
  276 + testScheduler.advanceUntilIdle()
  277 +
  278 + assertTrue(originalWs.isClosed)
218 } 279 }
219 280
220 // mock data 281 // mock data
@@ -373,6 +434,12 @@ class SignalClientTest : BaseTest() { @@ -373,6 +434,12 @@ class SignalClientTest : BaseTest() {
373 refreshToken = "refresh_token" 434 refreshToken = "refresh_token"
374 build() 435 build()
375 } 436 }
  437 +
  438 + val PONG = with(LivekitRtc.SignalResponse.newBuilder()) {
  439 + pong = 1L
  440 + build()
  441 + }
  442 +
376 val LEAVE = with(LivekitRtc.SignalResponse.newBuilder()) { 443 val LEAVE = with(LivekitRtc.SignalResponse.newBuilder()) {
377 leave = with(LivekitRtc.LeaveRequest.newBuilder()) { 444 leave = with(LivekitRtc.LeaveRequest.newBuilder()) {
378 build() 445 build()
1 -Subproject commit 51a8116f88b2c88ee1492c6a4d512b4611400918 1 +Subproject commit 6ec04e9ca47ebad2f3426be543fb6cbeef58c2b5