davidliu
Committed by GitHub

Websocket Ping/Pong (#144)

* Update protocol submodule

* fix compile

* Remove obsolete when safe operator

* ping pong implementation

* Add in backup websocket ping/pong

* pingpong test

* lower okhttp ping interval to 20 seconds for now
... ... @@ -9,6 +9,7 @@ import io.livekit.android.stats.AndroidNetworkInfo
import io.livekit.android.stats.NetworkInfo
import okhttp3.OkHttpClient
import okhttp3.WebSocket
import java.util.concurrent.TimeUnit
import javax.inject.Named
import javax.inject.Singleton
... ... @@ -21,6 +22,10 @@ object WebModule {
@Nullable
okHttpClientOverride: OkHttpClient?
): OkHttpClient {
OkHttpClient.Builder()
.pingInterval(20, TimeUnit.SECONDS)
.build()
return okHttpClientOverride ?: OkHttpClient()
}
... ...
... ... @@ -11,7 +11,6 @@ import io.livekit.android.stats.getClientInfo
import io.livekit.android.util.CloseableCoroutineScope
import io.livekit.android.util.Either
import io.livekit.android.util.LKLog
import io.livekit.android.util.safe
import io.livekit.android.webrtc.toProtoSessionDescription
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.MutableSharedFlow
... ... @@ -26,6 +25,7 @@ import okio.ByteString.Companion.toByteString
import org.webrtc.IceCandidate
import org.webrtc.PeerConnection
import org.webrtc.SessionDescription
import java.util.*
import javax.inject.Inject
import javax.inject.Named
import javax.inject.Singleton
... ... @@ -65,6 +65,11 @@ constructor(
private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE)
private var pingJob: Job? = null
private var pongJob: Job? = null
private var pingTimeoutDurationMillis: Long = 0
private var pingIntervalDurationMillis: Long = 0
var connectionState: ConnectionState = ConnectionState.DISCONNECTED
/**
... ... @@ -205,6 +210,8 @@ constructor(
isReconnecting = false
isConnected = true
joinContinuation?.resumeWith(Result.success(Either.Right(Unit)))
// Restart ping job with old settings.
startPingJob()
}
response.body?.close()
}
... ... @@ -259,19 +266,23 @@ constructor(
}
val wasConnected = isConnected
isConnected = false
if (wasConnected) {
handleWebSocketClose(
reason = reason ?: response?.toString() ?: t.localizedMessage ?: "websocket failure",
code = response?.code ?: 500
code = response?.code ?: CLOSE_REASON_WEBSOCKET_FAILURE
)
}
}
private fun handleWebSocketClose(reason: String, code: Int) {
LKLog.v { "websocket closed" }
isConnected = false
listener?.onClose(reason, code)
requestFlow.resetReplayCache()
responseFlow.resetReplayCache()
pingJob?.cancel()
pongJob?.cancel()
}
//------------------------------- End WebSocket Listener ------------------------------------//
... ... @@ -471,6 +482,9 @@ constructor(
if (response.hasJoin()) {
isConnected = true
startRequestQueue()
pingTimeoutDurationMillis = response.join.pingTimeout.toLong() * 1000
pingIntervalDurationMillis = response.join.pingInterval.toLong() * 1000
startPingJob()
try {
serverVersion = Semver(response.join.serverVersion)
} catch (t: Throwable) {
... ... @@ -537,7 +551,7 @@ constructor(
}
LivekitRtc.SignalResponse.MessageCase.SUBSCRIBED_QUALITY_UPDATE -> {
val versionToIgnoreUpTo = Semver("0.15.1")
if (serverVersion?.compareTo(versionToIgnoreUpTo) ?: 1 <= 0) {
if ((serverVersion?.compareTo(versionToIgnoreUpTo) ?: 1) <= 0) {
return
}
listener?.onSubscribedQualityUpdate(response.subscribedQualityUpdate)
... ... @@ -551,11 +565,48 @@ constructor(
LivekitRtc.SignalResponse.MessageCase.TRACK_UNPUBLISHED -> {
listener?.onLocalTrackUnpublished(response.trackUnpublished)
}
LivekitRtc.SignalResponse.MessageCase.PONG -> {
resetPingTimeout()
}
LivekitRtc.SignalResponse.MessageCase.MESSAGE_NOT_SET,
null -> {
LKLog.v { "empty messageCase!" }
}
}.safe()
}
}
private fun startPingJob() {
if (pingJob == null && pingIntervalDurationMillis != 0L) {
pingJob = coroutineScope.launch {
while (true) {
delay(pingIntervalDurationMillis)
val pingTimestamp = Date().time
val pingRequest = LivekitRtc.SignalRequest.newBuilder()
.setPing(pingTimestamp)
.build()
LKLog.v { "Sending ping: $pingTimestamp" }
sendRequest(pingRequest)
startPingTimeout(pingTimestamp)
}
}
}
}
private fun startPingTimeout(timestamp: Long) {
if (pongJob != null) {
return
}
pongJob = coroutineScope.launch {
delay(pingTimeoutDurationMillis)
LKLog.d { "Ping timeout reached for ping sent at $timestamp." }
currentWs?.close(CLOSE_REASON_PING_TIMEOUT, "Ping timeout")
}
}
private fun resetPingTimeout() {
pongJob?.cancel()
pongJob = null
}
/**
... ... @@ -563,14 +614,19 @@ constructor(
*
* Can be reused afterwards.
*/
fun close(code: Int = 1000, reason: String = "Normal Closure") {
fun close(code: Int = CLOSE_REASON_NORMAL_CLOSURE, reason: String = "Normal Closure") {
LKLog.v(Exception()) { "Closing SignalClient: code = $code, reason = $reason" }
isConnected = false
isReconnecting = false
requestFlowJob = null
if (::coroutineScope.isInitialized) {
coroutineScope.close()
}
requestFlowJob?.cancel()
requestFlowJob = null
pingJob?.cancel()
pingJob = null
pongJob?.cancel()
pongJob = null
currentWs?.close(code, reason)
currentWs = null
joinContinuation?.cancel()
... ... @@ -640,6 +696,9 @@ constructor(
// iceServer("stun:stun3.l.google.com:19302"),
// iceServer("stun:stun4.l.google.com:19302"),
)
const val CLOSE_REASON_NORMAL_CLOSURE = 1000
const val CLOSE_REASON_PING_TIMEOUT = 3000
const val CLOSE_REASON_WEBSOCKET_FAILURE = 3500
}
}
... ...
@file:Suppress("unused")
package io.livekit.android.util
/**
* Forces a when expression to be exhaustive only.
*/
internal fun Unit.safe() {}
/**
* Forces a when expression to be exhaustive only.
*/
internal fun Nothing?.safe() {}
/**
* Forces a when expression to be exhaustive only.
*/
internal fun Any?.safe() {}
\ No newline at end of file
... ... @@ -13,7 +13,7 @@ import kotlinx.serialization.json.Json
import livekit.LivekitModels
import livekit.LivekitRtc
import okhttp3.*
import org.junit.Assert
import org.junit.Assert.*
import org.junit.Before
import org.junit.Test
import org.mockito.Mock
... ... @@ -68,7 +68,6 @@ class SignalClientTest : BaseTest() {
@Test
fun joinAndResponse() = runTest {
println("dispatcher = ${this.coroutineContext}")
val job = async {
client.join(EXAMPLE_URL, "")
}
... ... @@ -76,8 +75,8 @@ class SignalClientTest : BaseTest() {
connectWebsocketAndJoin()
val response = job.await()
Assert.assertEquals(true, client.isConnected)
Assert.assertEquals(response, JOIN.join)
assertEquals(true, client.isConnected)
assertEquals(response, JOIN.join)
}
@Test
... ... @@ -89,7 +88,7 @@ class SignalClientTest : BaseTest() {
client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
job.await()
Assert.assertEquals(true, client.isConnected)
assertEquals(true, client.isConnected)
}
@Test
... ... @@ -106,7 +105,7 @@ class SignalClientTest : BaseTest() {
client.onFailure(wsFactory.ws, Exception(), null)
job.await()
Assert.assertTrue(failed)
assertTrue(failed)
}
@Test
... ... @@ -165,12 +164,12 @@ class SignalClientTest : BaseTest() {
val ws = wsFactory.ws
Assert.assertEquals(1, ws.sentRequests.size)
assertEquals(1, ws.sentRequests.size)
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
.mergeFrom(ws.sentRequests[0].toPBByteString())
.build()
Assert.assertTrue(sentRequest.hasMute())
assertTrue(sentRequest.hasMute())
}
@Test
... ... @@ -184,12 +183,12 @@ class SignalClientTest : BaseTest() {
job.await()
val ws = wsFactory.ws
Assert.assertEquals(3, ws.sentRequests.size)
assertEquals(3, ws.sentRequests.size)
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
.mergeFrom(ws.sentRequests[0].toPBByteString())
.build()
Assert.assertTrue(sentRequest.hasMute())
assertTrue(sentRequest.hasMute())
}
@Test
... ... @@ -205,16 +204,78 @@ class SignalClientTest : BaseTest() {
val ws = wsFactory.ws
// Wait until peer connection is connected to send requests.
Assert.assertEquals(0, ws.sentRequests.size)
assertEquals(0, ws.sentRequests.size)
client.onPCConnected()
Assert.assertEquals(3, ws.sentRequests.size)
assertEquals(3, ws.sentRequests.size)
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
.mergeFrom(ws.sentRequests[0].toPBByteString())
.build()
Assert.assertTrue(sentRequest.hasMute())
assertTrue(sentRequest.hasMute())
}
@Test
fun pingTest() = runTest {
val joinResponseWithPing = with(JOIN.toBuilder()) {
join = with(join.toBuilder()) {
pingInterval = 10
pingTimeout = 20
build()
}
build()
}
val job = async {
client.join(EXAMPLE_URL, "")
}
client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
client.onMessage(wsFactory.ws, joinResponseWithPing.toOkioByteString())
job.await()
val originalWs = wsFactory.ws
assertFalse(originalWs.isClosed)
testScheduler.advanceTimeBy(15 * 1000)
assertTrue(originalWs.sentRequests.any { requestString ->
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
.mergeFrom(requestString.toPBByteString())
.build()
return@any sentRequest.hasPing()
})
client.onMessage(wsFactory.ws, PONG.toOkioByteString())
testScheduler.advanceTimeBy(10 * 1000)
assertFalse(originalWs.isClosed)
}
@Test
fun pingTimeoutTest() = runTest {
val joinResponseWithPing = with(JOIN.toBuilder()) {
join = with(join.toBuilder()) {
pingInterval = 10
pingTimeout = 20
build()
}
build()
}
val job = async {
client.join(EXAMPLE_URL, "")
}
client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
client.onMessage(wsFactory.ws, joinResponseWithPing.toOkioByteString())
job.await()
val originalWs = wsFactory.ws
assertFalse(originalWs.isClosed)
testScheduler.advanceUntilIdle()
assertTrue(originalWs.isClosed)
}
// mock data
... ... @@ -373,6 +434,12 @@ class SignalClientTest : BaseTest() {
refreshToken = "refresh_token"
build()
}
val PONG = with(LivekitRtc.SignalResponse.newBuilder()) {
pong = 1L
build()
}
val LEAVE = with(LivekitRtc.SignalResponse.newBuilder()) {
leave = with(LivekitRtc.LeaveRequest.newBuilder()) {
build()
... ...
Subproject commit 51a8116f88b2c88ee1492c6a4d512b4611400918
Subproject commit 6ec04e9ca47ebad2f3426be543fb6cbeef58c2b5
... ...