davidliu
Committed by GitHub

Handle reconnect response (#202)

* Update protos

* Handle reconnect response
@@ -11,7 +11,7 @@ buildscript { @@ -11,7 +11,7 @@ buildscript {
11 classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:1.7.10" 11 classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:1.7.10"
12 classpath "org.jetbrains.kotlin:kotlin-serialization:$kotlin_version" 12 classpath "org.jetbrains.kotlin:kotlin-serialization:$kotlin_version"
13 classpath "org.jetbrains.dokka:dokka-gradle-plugin:$dokka_version" 13 classpath "org.jetbrains.dokka:dokka-gradle-plugin:$dokka_version"
14 - classpath 'com.google.protobuf:protobuf-gradle-plugin:0.8.18' 14 + classpath 'com.google.protobuf:protobuf-gradle-plugin:0.8.19'
15 classpath "io.codearte.gradle.nexus:gradle-nexus-staging-plugin:0.30.0" 15 classpath "io.codearte.gradle.nexus:gradle-nexus-staging-plugin:0.30.0"
16 // NOTE: Do not place your application dependencies here; they belong 16 // NOTE: Do not place your application dependencies here; they belong
17 // in the individual module build.gradle files 17 // in the individual module build.gradle files
@@ -30,7 +30,15 @@ android { @@ -30,7 +30,15 @@ android {
30 } 30 }
31 31
32 sourceSets { 32 sourceSets {
33 - main.java.srcDirs += "${protobuf.generatedFilesBaseDir}/main/javalite" 33 + main {
  34 + proto {
  35 + srcDir generated.protoSrc
  36 + exclude '*/*.proto' // only use top-level protos.
  37 + }
  38 + java {
  39 + srcDir "${protobuf.generatedFilesBaseDir}/main/javalite"
  40 + }
  41 + }
34 } 42 }
35 43
36 testOptions { 44 testOptions {
@@ -112,7 +120,6 @@ dokkaHtml { @@ -112,7 +120,6 @@ dokkaHtml {
112 } 120 }
113 121
114 dependencies { 122 dependencies {
115 - protobuf files(generated.protoSrc)  
116 implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version" 123 implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version"
117 implementation deps.coroutines.lib 124 implementation deps.coroutines.lib
118 implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.1.0' 125 implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.1.0'
@@ -17,9 +17,9 @@ data class ConnectOptions( @@ -17,9 +17,9 @@ data class ConnectOptions(
17 /** 17 /**
18 * A user-provided RTCConfiguration to override options. 18 * A user-provided RTCConfiguration to override options.
19 * 19 *
20 - * Note: LiveKit requires [PeerConnection.SdpSemantics.UNIFIED_PLAN]  
21 - * and a mutable list should be provided for iceServers constructor.  
22 - * */ 20 + * Note: LiveKit requires [PeerConnection.SdpSemantics.UNIFIED_PLAN] and
  21 + * [PeerConnection.ContinualGatheringPolicy.GATHER_CONTINUALLY].
  22 + */
23 val rtcConfig: PeerConnection.RTCConfiguration? = null, 23 val rtcConfig: PeerConnection.RTCConfiguration? = null,
24 /** 24 /**
25 * capture and publish audio track on connect, defaults to false 25 * capture and publish audio track on connect, defaults to false
@@ -15,6 +15,7 @@ import kotlinx.coroutines.runBlocking @@ -15,6 +15,7 @@ import kotlinx.coroutines.runBlocking
15 import kotlinx.coroutines.sync.Mutex 15 import kotlinx.coroutines.sync.Mutex
16 import kotlinx.coroutines.sync.withLock 16 import kotlinx.coroutines.sync.withLock
17 import org.webrtc.* 17 import org.webrtc.*
  18 +import org.webrtc.PeerConnection.RTCConfiguration
18 import javax.inject.Named 19 import javax.inject.Named
19 20
20 /** 21 /**
@@ -137,6 +138,10 @@ constructor( @@ -137,6 +138,10 @@ constructor(
137 peerConnection.close() 138 peerConnection.close()
138 } 139 }
139 140
  141 + fun updateRTCConfig(config: RTCConfiguration) {
  142 + peerConnection.setConfiguration(config)
  143 + }
  144 +
140 @AssistedFactory 145 @AssistedFactory
141 interface Factory { 146 interface Factory {
142 fun create( 147 fun create(
@@ -15,6 +15,7 @@ import io.livekit.android.room.util.setLocalDescription @@ -15,6 +15,7 @@ import io.livekit.android.room.util.setLocalDescription
15 import io.livekit.android.util.CloseableCoroutineScope 15 import io.livekit.android.util.CloseableCoroutineScope
16 import io.livekit.android.util.Either 16 import io.livekit.android.util.Either
17 import io.livekit.android.util.LKLog 17 import io.livekit.android.util.LKLog
  18 +import io.livekit.android.webrtc.copy
18 import io.livekit.android.webrtc.isConnected 19 import io.livekit.android.webrtc.isConnected
19 import io.livekit.android.webrtc.isDisconnected 20 import io.livekit.android.webrtc.isDisconnected
20 import io.livekit.android.webrtc.toProtoSessionDescription 21 import io.livekit.android.webrtc.toProtoSessionDescription
@@ -22,7 +23,10 @@ import kotlinx.coroutines.* @@ -22,7 +23,10 @@ import kotlinx.coroutines.*
22 import kotlinx.coroutines.sync.Mutex 23 import kotlinx.coroutines.sync.Mutex
23 import livekit.LivekitModels 24 import livekit.LivekitModels
24 import livekit.LivekitRtc 25 import livekit.LivekitRtc
  26 +import livekit.LivekitRtc.JoinResponse
  27 +import livekit.LivekitRtc.ReconnectResponse
25 import org.webrtc.* 28 import org.webrtc.*
  29 +import org.webrtc.PeerConnection.RTCConfiguration
26 import java.net.ConnectException 30 import java.net.ConnectException
27 import java.nio.ByteBuffer 31 import java.nio.ByteBuffer
28 import javax.inject.Inject 32 import javax.inject.Inject
@@ -127,7 +131,7 @@ internal constructor( @@ -127,7 +131,7 @@ internal constructor(
127 token: String, 131 token: String,
128 options: ConnectOptions, 132 options: ConnectOptions,
129 roomOptions: RoomOptions 133 roomOptions: RoomOptions
130 - ): LivekitRtc.JoinResponse { 134 + ): JoinResponse {
131 coroutineScope.close() 135 coroutineScope.close()
132 coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher) 136 coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher)
133 sessionUrl = url 137 sessionUrl = url
@@ -142,7 +146,7 @@ internal constructor( @@ -142,7 +146,7 @@ internal constructor(
142 token: String, 146 token: String,
143 options: ConnectOptions, 147 options: ConnectOptions,
144 roomOptions: RoomOptions 148 roomOptions: RoomOptions
145 - ): LivekitRtc.JoinResponse { 149 + ): JoinResponse {
146 val joinResponse = client.join(url, token, options, roomOptions) 150 val joinResponse = client.join(url, token, options, roomOptions)
147 listener?.onJoinResponse(joinResponse) 151 listener?.onJoinResponse(joinResponse)
148 isClosed = false 152 isClosed = false
@@ -160,7 +164,7 @@ internal constructor( @@ -160,7 +164,7 @@ internal constructor(
160 return joinResponse 164 return joinResponse
161 } 165 }
162 166
163 - private fun configure(joinResponse: LivekitRtc.JoinResponse, connectOptions: ConnectOptions?) { 167 + private fun configure(joinResponse: JoinResponse, connectOptions: ConnectOptions) {
164 if (_publisher != null && _subscriber != null) { 168 if (_publisher != null && _subscriber != null) {
165 // already configured 169 // already configured
166 return 170 return
@@ -172,60 +176,8 @@ internal constructor( @@ -172,60 +176,8 @@ internal constructor(
172 null 176 null
173 } 177 }
174 178
175 - // update ICE servers before creating PeerConnection  
176 - val iceServers = run {  
177 - val servers = mutableListOf<PeerConnection.IceServer>()  
178 - for (serverInfo in joinResponse.iceServersList) {  
179 - val username = serverInfo.username ?: ""  
180 - val credential = serverInfo.credential ?: ""  
181 - servers.add(  
182 - PeerConnection.IceServer  
183 - .builder(serverInfo.urlsList)  
184 - .setUsername(username)  
185 - .setPassword(credential)  
186 - .createIceServer()  
187 - )  
188 - }  
189 -  
190 - if (servers.isEmpty()) {  
191 - servers.addAll(SignalClient.DEFAULT_ICE_SERVERS)  
192 - }  
193 - servers  
194 - }  
195 -  
196 // Setup peer connections 179 // Setup peer connections
197 - val rtcConfig = connectOptions?.rtcConfig?.apply {  
198 - val mergedServers = this.iceServers  
199 - if (connectOptions.iceServers != null) {  
200 - connectOptions.iceServers.forEach { server ->  
201 - if (!mergedServers.contains(server)) {  
202 - mergedServers.add(server)  
203 - }  
204 - }  
205 - }  
206 -  
207 - // Only use server-provided servers if user doesn't provide any.  
208 - if (mergedServers.isEmpty()) {  
209 - iceServers.forEach { server ->  
210 - if (!mergedServers.contains(server)) {  
211 - mergedServers.add(server)  
212 - }  
213 - }  
214 - }  
215 - }  
216 - ?: PeerConnection.RTCConfiguration(iceServers).apply {  
217 - sdpSemantics = PeerConnection.SdpSemantics.UNIFIED_PLAN  
218 - continualGatheringPolicy =  
219 - PeerConnection.ContinualGatheringPolicy.GATHER_CONTINUALLY  
220 - }  
221 -  
222 - if (joinResponse.hasClientConfiguration()) {  
223 - val clientConfig = joinResponse.clientConfiguration  
224 -  
225 - if (clientConfig.forceRelay == LivekitModels.ClientConfigSetting.ENABLED) {  
226 - rtcConfig.iceTransportsType = PeerConnection.IceTransportsType.RELAY  
227 - }  
228 - } 180 + val rtcConfig = makeRTCConfig(Either.Left(joinResponse), connectOptions)
229 181
230 _publisher?.close() 182 _publisher?.close()
231 _publisher = pctFactory.create( 183 _publisher = pctFactory.create(
@@ -398,12 +350,13 @@ internal constructor( @@ -398,12 +350,13 @@ internal constructor(
398 ReconnectType.FORCE_FULL_RECONNECT -> true 350 ReconnectType.FORCE_FULL_RECONNECT -> true
399 } 351 }
400 352
  353 + val connectOptions = connectOptions ?: ConnectOptions()
401 if (isFullReconnect) { 354 if (isFullReconnect) {
402 LKLog.v { "Attempting full reconnect." } 355 LKLog.v { "Attempting full reconnect." }
403 try { 356 try {
404 closeResources("Full Reconnecting") 357 closeResources("Full Reconnecting")
405 listener?.onFullReconnecting() 358 listener?.onFullReconnecting()
406 - joinImpl(url, token, connectOptions ?: ConnectOptions(), lastRoomOptions ?: RoomOptions()) 359 + joinImpl(url, token, connectOptions, lastRoomOptions ?: RoomOptions())
407 } catch (e: Exception) { 360 } catch (e: Exception) {
408 LKLog.w(e) { "Error during reconnection." } 361 LKLog.w(e) { "Error during reconnection." }
409 // reconnect failed, retry. 362 // reconnect failed, retry.
@@ -413,8 +366,13 @@ internal constructor( @@ -413,8 +366,13 @@ internal constructor(
413 LKLog.v { "Attempting soft reconnect." } 366 LKLog.v { "Attempting soft reconnect." }
414 subscriber.prepareForIceRestart() 367 subscriber.prepareForIceRestart()
415 try { 368 try {
416 - client.reconnect(url, token, participantSid)  
417 - // no join response for regular reconnects 369 + val response = client.reconnect(url, token, participantSid)
  370 + if (response is Either.Left) {
  371 + val reconnectResponse = response.value
  372 + val rtcConfig = makeRTCConfig(Either.Right(reconnectResponse), connectOptions)
  373 + _subscriber?.updateRTCConfig(rtcConfig)
  374 + _publisher?.updateRTCConfig(rtcConfig)
  375 + }
418 client.onReadyForResponses() 376 client.onReadyForResponses()
419 } catch (e: Exception) { 377 } catch (e: Exception) {
420 LKLog.w(e) { "Error during reconnection." } 378 LKLog.w(e) { "Error during reconnection." }
@@ -573,13 +531,87 @@ internal constructor( @@ -573,13 +531,87 @@ internal constructor(
573 } 531 }
574 } 532 }
575 533
  534 + private fun makeRTCConfig(
  535 + serverResponse: Either<JoinResponse, ReconnectResponse>,
  536 + connectOptions: ConnectOptions
  537 + ): RTCConfiguration {
  538 +
  539 + // Convert protobuf ice servers
  540 + val serverIceServers = run {
  541 + val servers = mutableListOf<PeerConnection.IceServer>()
  542 + val responseServers = when (serverResponse) {
  543 + is Either.Left -> serverResponse.value.iceServersList
  544 + is Either.Right -> serverResponse.value.iceServersList
  545 + }
  546 + for (serverInfo in responseServers) {
  547 + servers.add(serverInfo.toWebrtc())
  548 + }
  549 +
  550 + if (servers.isEmpty()) {
  551 + servers.addAll(SignalClient.DEFAULT_ICE_SERVERS)
  552 + }
  553 + servers
  554 + }
  555 +
  556 + val rtcConfig = connectOptions.rtcConfig?.copy()?.apply {
  557 + val mergedServers = iceServers.toMutableList()
  558 + if (connectOptions.iceServers != null) {
  559 + connectOptions.iceServers.forEach { server ->
  560 + if (!mergedServers.contains(server)) {
  561 + mergedServers.add(server)
  562 + }
  563 + }
  564 + }
  565 +
  566 + // Only use server-provided servers if user doesn't provide any.
  567 + if (mergedServers.isEmpty()) {
  568 + iceServers.forEach { server ->
  569 + if (!mergedServers.contains(server)) {
  570 + mergedServers.add(server)
  571 + }
  572 + }
  573 + }
  574 +
  575 + iceServers = mergedServers
  576 + }
  577 + ?: RTCConfiguration(serverIceServers).apply {
  578 + sdpSemantics = PeerConnection.SdpSemantics.UNIFIED_PLAN
  579 + continualGatheringPolicy =
  580 + PeerConnection.ContinualGatheringPolicy.GATHER_CONTINUALLY
  581 + }
  582 +
  583 + val clientConfig = when (serverResponse) {
  584 + is Either.Left -> {
  585 + if (serverResponse.value.hasClientConfiguration()) {
  586 + serverResponse.value.clientConfiguration
  587 + } else {
  588 + null
  589 + }
  590 + }
  591 + is Either.Right -> {
  592 + if (serverResponse.value.hasClientConfiguration()) {
  593 + serverResponse.value.clientConfiguration
  594 + } else {
  595 + null
  596 + }
  597 + }
  598 + }
  599 + if (clientConfig != null) {
  600 + if (clientConfig.forceRelay == LivekitModels.ClientConfigSetting.ENABLED) {
  601 + rtcConfig.iceTransportsType = PeerConnection.IceTransportsType.RELAY
  602 + }
  603 + }
  604 +
  605 + return rtcConfig
  606 + }
  607 +
576 internal interface Listener { 608 internal interface Listener {
577 fun onEngineConnected() 609 fun onEngineConnected()
578 fun onEngineReconnected() 610 fun onEngineReconnected()
579 fun onEngineReconnecting() 611 fun onEngineReconnecting()
580 fun onEngineDisconnected(reason: DisconnectReason) 612 fun onEngineDisconnected(reason: DisconnectReason)
581 fun onFailToConnect(error: Throwable) 613 fun onFailToConnect(error: Throwable)
582 - fun onJoinResponse(response: LivekitRtc.JoinResponse) 614 + fun onJoinResponse(response: JoinResponse)
583 fun onAddTrack(track: MediaStreamTrack, streams: Array<out MediaStream>) 615 fun onAddTrack(track: MediaStreamTrack, streams: Array<out MediaStream>)
584 fun onUpdateParticipants(updates: List<LivekitModels.ParticipantInfo>) 616 fun onUpdateParticipants(updates: List<LivekitModels.ParticipantInfo>)
585 fun onActiveSpeakersUpdate(speakers: List<LivekitModels.SpeakerInfo>) 617 fun onActiveSpeakersUpdate(speakers: List<LivekitModels.SpeakerInfo>)
@@ -837,4 +869,11 @@ enum class ReconnectType { @@ -837,4 +869,11 @@ enum class ReconnectType {
837 DEFAULT, 869 DEFAULT,
838 FORCE_SOFT_RECONNECT, 870 FORCE_SOFT_RECONNECT,
839 FORCE_FULL_RECONNECT; 871 FORCE_FULL_RECONNECT;
840 -}  
  872 +}
  873 +
  874 +internal fun LivekitRtc.ICEServer.toWebrtc() = PeerConnection.IceServer.builder(urlsList)
  875 + .setUsername(username ?: "")
  876 + .setPassword(credential ?: "")
  877 + .setTlsAlpnProtocols(emptyList())
  878 + .setTlsEllipticCurves(emptyList())
  879 + .createIceServer()
@@ -19,6 +19,8 @@ import kotlinx.serialization.encodeToString @@ -19,6 +19,8 @@ import kotlinx.serialization.encodeToString
19 import kotlinx.serialization.json.Json 19 import kotlinx.serialization.json.Json
20 import livekit.LivekitModels 20 import livekit.LivekitModels
21 import livekit.LivekitRtc 21 import livekit.LivekitRtc
  22 +import livekit.LivekitRtc.JoinResponse
  23 +import livekit.LivekitRtc.ReconnectResponse
22 import okhttp3.* 24 import okhttp3.*
23 import okio.ByteString 25 import okio.ByteString
24 import okio.ByteString.Companion.toByteString 26 import okio.ByteString.Companion.toByteString
@@ -56,7 +58,12 @@ constructor( @@ -56,7 +58,12 @@ constructor(
56 private var lastOptions: ConnectOptions? = null 58 private var lastOptions: ConnectOptions? = null
57 private var lastRoomOptions: RoomOptions? = null 59 private var lastRoomOptions: RoomOptions? = null
58 60
59 - private var joinContinuation: CancellableContinuation<Either<LivekitRtc.JoinResponse, Unit>>? = null 61 + // join will always return a JoinResponse.
  62 + // reconnect will return a ReconnectResponse or a Unit if a different response was received.
  63 + private var joinContinuation: CancellableContinuation<
  64 + Either<
  65 + JoinResponse,
  66 + Either<ReconnectResponse, Unit>>>? = null
60 private lateinit var coroutineScope: CloseableCoroutineScope 67 private lateinit var coroutineScope: CloseableCoroutineScope
61 68
62 private val requestFlowJobLock = Object() 69 private val requestFlowJobLock = Object()
@@ -80,7 +87,7 @@ constructor( @@ -80,7 +87,7 @@ constructor(
80 token: String, 87 token: String,
81 options: ConnectOptions = ConnectOptions(), 88 options: ConnectOptions = ConnectOptions(),
82 roomOptions: RoomOptions = RoomOptions(), 89 roomOptions: RoomOptions = RoomOptions(),
83 - ): LivekitRtc.JoinResponse { 90 + ): JoinResponse {
84 val joinResponse = connect(url, token, options, roomOptions) 91 val joinResponse = connect(url, token, options, roomOptions)
85 return (joinResponse as Either.Left).value 92 return (joinResponse as Either.Left).value
86 } 93 }
@@ -88,8 +95,8 @@ constructor( @@ -88,8 +95,8 @@ constructor(
88 /** 95 /**
89 * @throws Exception if fails to connect. 96 * @throws Exception if fails to connect.
90 */ 97 */
91 - suspend fun reconnect(url: String, token: String, participantSid: String?) {  
92 - connect( 98 + suspend fun reconnect(url: String, token: String, participantSid: String?): Either<ReconnectResponse, Unit> {
  99 + val reconnectResponse = connect(
93 url, 100 url,
94 token, 101 token,
95 (lastOptions ?: ConnectOptions()).copy() 102 (lastOptions ?: ConnectOptions()).copy()
@@ -99,14 +106,15 @@ constructor( @@ -99,14 +106,15 @@ constructor(
99 }, 106 },
100 lastRoomOptions ?: RoomOptions() 107 lastRoomOptions ?: RoomOptions()
101 ) 108 )
  109 + return (reconnectResponse as Either.Right).value
102 } 110 }
103 111
104 - suspend fun connect( 112 + private suspend fun connect(
105 url: String, 113 url: String,
106 token: String, 114 token: String,
107 options: ConnectOptions, 115 options: ConnectOptions,
108 roomOptions: RoomOptions 116 roomOptions: RoomOptions
109 - ): Either<LivekitRtc.JoinResponse, Unit> { 117 + ): Either<JoinResponse, Either<ReconnectResponse, Unit>> {
110 // Clean up any pre-existing connection. 118 // Clean up any pre-existing connection.
111 close(reason = "Starting new connection") 119 close(reason = "Starting new connection")
112 120
@@ -210,18 +218,6 @@ constructor( @@ -210,18 +218,6 @@ constructor(
210 } 218 }
211 219
212 //--------------------------------- WebSocket Listener --------------------------------------// 220 //--------------------------------- WebSocket Listener --------------------------------------//
213 - override fun onOpen(webSocket: WebSocket, response: Response) {  
214 - if (isReconnecting) {  
215 - // no need to wait for join response on reconnection.  
216 - isReconnecting = false  
217 - isConnected = true  
218 - joinContinuation?.resumeWith(Result.success(Either.Right(Unit)))  
219 - // Restart ping job with old settings.  
220 - startPingJob()  
221 - }  
222 - response.body?.close()  
223 - }  
224 -  
225 override fun onMessage(webSocket: WebSocket, text: String) { 221 override fun onMessage(webSocket: WebSocket, text: String) {
226 LKLog.w { "received JSON message, unsupported in this version." } 222 LKLog.w { "received JSON message, unsupported in this version." }
227 } 223 }
@@ -484,7 +480,9 @@ constructor( @@ -484,7 +480,9 @@ constructor(
484 LKLog.v { "response: $response" } 480 LKLog.v { "response: $response" }
485 481
486 if (!isConnected) { 482 if (!isConnected) {
487 - // Only handle joins if not connected. 483 + var shouldProcessMessage = false
  484 +
  485 + // Only handle certain messages if not connected.
488 if (response.hasJoin()) { 486 if (response.hasJoin()) {
489 isConnected = true 487 isConnected = true
490 startRequestQueue() 488 startRequestQueue()
@@ -500,10 +498,28 @@ constructor( @@ -500,10 +498,28 @@ constructor(
500 } else if (response.hasLeave()) { 498 } else if (response.hasLeave()) {
501 // Some reconnects may immediately send leave back without a join response first. 499 // Some reconnects may immediately send leave back without a join response first.
502 handleSignalResponseImpl(response) 500 handleSignalResponseImpl(response)
  501 + } else if (isReconnecting) {
  502 + // When reconnecting, any message received means signal reconnected.
  503 + // Newer servers will send a reconnect response first
  504 + isReconnecting = false
  505 + isConnected = true
  506 +
  507 + // Restart ping job with old settings.
  508 + startPingJob()
  509 +
  510 + if (response.hasReconnect()) {
  511 + joinContinuation?.resumeWith(Result.success(Either.Right(Either.Left(response.reconnect))))
  512 + } else {
  513 + joinContinuation?.resumeWith(Result.success(Either.Right(Either.Right(Unit))))
  514 + // Non-reconnect response, handle normally
  515 + shouldProcessMessage = true
  516 + }
503 } else { 517 } else {
504 LKLog.e { "Received response while not connected. $response" } 518 LKLog.e { "Received response while not connected. $response" }
505 } 519 }
506 - return 520 + if (!shouldProcessMessage) {
  521 + return
  522 + }
507 } 523 }
508 responseFlow.tryEmit(response) 524 responseFlow.tryEmit(response)
509 } 525 }
@@ -715,6 +731,7 @@ constructor( @@ -715,6 +731,7 @@ constructor(
715 } 731 }
716 } 732 }
717 733
  734 +@Suppress("EnumEntryName", "unused")
718 enum class ProtocolVersion(val value: Int) { 735 enum class ProtocolVersion(val value: Int) {
719 v1(1), 736 v1(1),
720 v2(2), 737 v2(2),
1 package io.livekit.android.webrtc 1 package io.livekit.android.webrtc
2 2
3 import org.webrtc.PeerConnection 3 import org.webrtc.PeerConnection
  4 +import org.webrtc.PeerConnection.RTCConfiguration
4 5
5 /** 6 /**
6 * Completed state is a valid state for a connected connection, so this should be used 7 * Completed state is a valid state for a connected connection, so this should be used
@@ -25,3 +26,55 @@ internal fun PeerConnection.PeerConnectionState.isDisconnected(): Boolean { @@ -25,3 +26,55 @@ internal fun PeerConnection.PeerConnectionState.isDisconnected(): Boolean {
25 else -> false 26 else -> false
26 } 27 }
27 } 28 }
  29 +
  30 +internal fun RTCConfiguration.copy(): RTCConfiguration {
  31 + val newConfig = RTCConfiguration(emptyList())
  32 + newConfig.copyFrom(this)
  33 + return newConfig
  34 +}
  35 +
  36 +internal fun RTCConfiguration.copyFrom(config: RTCConfiguration) {
  37 + iceTransportsType = config.iceTransportsType
  38 + iceServers = config.iceServers
  39 + bundlePolicy = config.bundlePolicy
  40 + certificate = config.certificate
  41 + rtcpMuxPolicy = config.rtcpMuxPolicy
  42 + tcpCandidatePolicy = config.tcpCandidatePolicy
  43 + candidateNetworkPolicy = config.candidateNetworkPolicy
  44 + audioJitterBufferMaxPackets = config.audioJitterBufferMaxPackets
  45 + audioJitterBufferFastAccelerate = config.audioJitterBufferFastAccelerate
  46 + iceConnectionReceivingTimeout = config.iceConnectionReceivingTimeout
  47 + iceBackupCandidatePairPingInterval = config.iceBackupCandidatePairPingInterval
  48 + keyType = config.keyType
  49 + continualGatheringPolicy = config.continualGatheringPolicy
  50 + iceCandidatePoolSize = config.iceCandidatePoolSize
  51 +
  52 + pruneTurnPorts = config.pruneTurnPorts
  53 + turnPortPrunePolicy = config.turnPortPrunePolicy
  54 + presumeWritableWhenFullyRelayed = config.presumeWritableWhenFullyRelayed
  55 + surfaceIceCandidatesOnIceTransportTypeChanged = config.surfaceIceCandidatesOnIceTransportTypeChanged
  56 + iceCheckIntervalStrongConnectivityMs = config.iceCheckIntervalStrongConnectivityMs
  57 + iceCheckIntervalWeakConnectivityMs = config.iceCheckIntervalWeakConnectivityMs
  58 + iceCheckMinInterval = config.iceCheckMinInterval
  59 + iceUnwritableTimeMs = config.iceUnwritableTimeMs
  60 + iceUnwritableMinChecks = config.iceUnwritableMinChecks
  61 + stunCandidateKeepaliveIntervalMs = config.stunCandidateKeepaliveIntervalMs
  62 + stableWritableConnectionPingIntervalMs = config.stableWritableConnectionPingIntervalMs
  63 + disableIPv6OnWifi = config.disableIPv6OnWifi
  64 + maxIPv6Networks = config.maxIPv6Networks
  65 + disableIpv6 = config.disableIpv6
  66 + enableDscp = config.enableDscp
  67 + enableCpuOveruseDetection = config.enableCpuOveruseDetection
  68 + suspendBelowMinBitrate = config.suspendBelowMinBitrate
  69 + screencastMinBitrate = config.screencastMinBitrate
  70 + combinedAudioVideoBwe = config.combinedAudioVideoBwe
  71 + networkPreference = config.networkPreference
  72 + sdpSemantics = config.sdpSemantics
  73 + turnCustomizer = config.turnCustomizer
  74 + activeResetSrtpParams = config.activeResetSrtpParams
  75 + allowCodecSwitching = config.allowCodecSwitching
  76 + cryptoOptions = config.cryptoOptions
  77 + turnLoggingId = config.turnLoggingId
  78 + enableImplicitRollback = config.enableImplicitRollback
  79 + offerExtmapAllowMixed = config.offerExtmapAllowMixed
  80 +}
@@ -7,7 +7,7 @@ private class MockNativePeerConnectionFactory : NativePeerConnectionFactory { @@ -7,7 +7,7 @@ private class MockNativePeerConnectionFactory : NativePeerConnectionFactory {
7 } 7 }
8 8
9 class MockPeerConnection( 9 class MockPeerConnection(
10 - val rtcConfig: RTCConfiguration, 10 + var rtcConfig: RTCConfiguration,
11 val observer: Observer? 11 val observer: Observer?
12 ) : PeerConnection(MockNativePeerConnectionFactory()) { 12 ) : PeerConnection(MockNativePeerConnectionFactory()) {
13 13
@@ -51,7 +51,8 @@ class MockPeerConnection( @@ -51,7 +51,8 @@ class MockPeerConnection(
51 override fun setAudioRecording(recording: Boolean) { 51 override fun setAudioRecording(recording: Boolean) {
52 } 52 }
53 53
54 - override fun setConfiguration(config: RTCConfiguration?): Boolean { 54 + override fun setConfiguration(config: RTCConfiguration): Boolean {
  55 + this.rtcConfig = config
55 return true 56 return true
56 } 57 }
57 58
@@ -28,6 +28,16 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -28,6 +28,16 @@ class RTCEngineMockE2ETest : MockE2ETest() {
28 } 28 }
29 29
30 @Test 30 @Test
  31 + fun iceServersSetOnJoin() = runTest {
  32 + connect()
  33 + val sentIceServers = SignalClientTest.JOIN.join.iceServersList
  34 + .map { it.toWebrtc() }
  35 + val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection
  36 +
  37 + assertEquals(sentIceServers, subPeerConnection.rtcConfig.iceServers)
  38 + }
  39 +
  40 + @Test
31 fun iceSubscriberConnect() = runTest { 41 fun iceSubscriberConnect() = runTest {
32 connect() 42 connect()
33 assertEquals( 43 assertEquals(
@@ -111,8 +121,8 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -111,8 +121,8 @@ class RTCEngineMockE2ETest : MockE2ETest() {
111 build() 121 build()
112 }) 122 })
113 123
114 - val pubPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection  
115 - assertEquals(PeerConnection.IceTransportsType.RELAY, pubPeerConnection.rtcConfig.iceTransportsType) 124 + val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection
  125 + assertEquals(PeerConnection.IceTransportsType.RELAY, subPeerConnection.rtcConfig.iceTransportsType)
116 } 126 }
117 127
118 fun participantIdOnReconnect() = runTest { 128 fun participantIdOnReconnect() = runTest {
1 package io.livekit.android.room 1 package io.livekit.android.room
2 2
3 import io.livekit.android.MockE2ETest 3 import io.livekit.android.MockE2ETest
4 -import io.livekit.android.assert.assertIsClassList  
5 -import io.livekit.android.events.EventCollector  
6 -import io.livekit.android.events.FlowCollector  
7 -import io.livekit.android.events.RoomEvent  
8 import io.livekit.android.mock.MockAudioStreamTrack 4 import io.livekit.android.mock.MockAudioStreamTrack
  5 +import io.livekit.android.mock.MockPeerConnection
9 import io.livekit.android.room.track.LocalAudioTrack 6 import io.livekit.android.room.track.LocalAudioTrack
10 -import io.livekit.android.util.flow  
11 import io.livekit.android.util.toPBByteString 7 import io.livekit.android.util.toPBByteString
12 -import junit.framework.Assert.assertEquals  
13 import kotlinx.coroutines.ExperimentalCoroutinesApi 8 import kotlinx.coroutines.ExperimentalCoroutinesApi
14 import kotlinx.coroutines.launch 9 import kotlinx.coroutines.launch
15 import livekit.LivekitRtc 10 import livekit.LivekitRtc
16 import org.junit.Assert 11 import org.junit.Assert
  12 +import org.junit.Assert.assertEquals
17 import org.junit.Test 13 import org.junit.Test
18 import org.junit.runner.RunWith 14 import org.junit.runner.RunWith
19 import org.robolectric.RobolectricTestRunner 15 import org.robolectric.RobolectricTestRunner
  16 +import org.webrtc.PeerConnection
20 17
21 /** 18 /**
22 * For tests that only target one reconnection type. 19 * For tests that only target one reconnection type.
@@ -37,6 +34,8 @@ class RoomReconnectionMockE2ETest : MockE2ETest() { @@ -37,6 +34,8 @@ class RoomReconnectionMockE2ETest : MockE2ETest() {
37 34
38 if (softReconnectParam == 0) { 35 if (softReconnectParam == 0) {
39 simulateMessageFromServer(SignalClientTest.JOIN) 36 simulateMessageFromServer(SignalClientTest.JOIN)
  37 + } else {
  38 + simulateMessageFromServer(SignalClientTest.RECONNECT)
40 } 39 }
41 } 40 }
42 } 41 }
@@ -66,6 +65,27 @@ class RoomReconnectionMockE2ETest : MockE2ETest() { @@ -66,6 +65,27 @@ class RoomReconnectionMockE2ETest : MockE2ETest() {
66 } 65 }
67 66
68 @Test 67 @Test
  68 + fun softReconnectConfiguration() = runTest {
  69 +
  70 + room.setReconnectionType(ReconnectType.FORCE_SOFT_RECONNECT)
  71 + connect()
  72 + prepareForReconnect()
  73 + disconnectPeerConnection()
  74 + // Wait so that the reconnect job properly starts first.
  75 + testScheduler.advanceTimeBy(1000)
  76 + connectPeerConnection()
  77 +
  78 + val rtcEngine = component.rtcEngine()
  79 + val rtcConfig = (rtcEngine.subscriber.peerConnection as MockPeerConnection).rtcConfig
  80 + assertEquals(PeerConnection.IceTransportsType.RELAY, rtcConfig.iceTransportsType)
  81 +
  82 + val sentIceServers = SignalClientTest.RECONNECT.reconnect.iceServersList
  83 + .map { server -> server.toWebrtc() }
  84 + assertEquals(sentIceServers, rtcConfig.iceServers)
  85 +
  86 + }
  87 +
  88 + @Test
69 fun fullReconnectRepublishesTracks() = runTest { 89 fun fullReconnectRepublishesTracks() = runTest {
70 room.setReconnectionType(ReconnectType.FORCE_FULL_RECONNECT) 90 room.setReconnectionType(ReconnectType.FORCE_FULL_RECONNECT)
71 connect() 91 connect()
@@ -39,6 +39,8 @@ class RoomReconnectionTypesMockE2ETest( @@ -39,6 +39,8 @@ class RoomReconnectionTypesMockE2ETest(
39 39
40 if (softReconnectParam == 0) { 40 if (softReconnectParam == 0) {
41 simulateMessageFromServer(SignalClientTest.JOIN) 41 simulateMessageFromServer(SignalClientTest.JOIN)
  42 + } else {
  43 + simulateMessageFromServer(SignalClientTest.RECONNECT)
42 } 44 }
43 } 45 }
44 } 46 }
@@ -11,7 +11,9 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi @@ -11,7 +11,9 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi
11 import kotlinx.coroutines.async 11 import kotlinx.coroutines.async
12 import kotlinx.serialization.json.Json 12 import kotlinx.serialization.json.Json
13 import livekit.LivekitModels 13 import livekit.LivekitModels
  14 +import livekit.LivekitModels.ClientConfiguration
14 import livekit.LivekitRtc 15 import livekit.LivekitRtc
  16 +import livekit.LivekitRtc.ICEServer
15 import okhttp3.* 17 import okhttp3.*
16 import org.junit.Assert.* 18 import org.junit.Assert.*
17 import org.junit.Before 19 import org.junit.Before
@@ -86,6 +88,7 @@ class SignalClientTest : BaseTest() { @@ -86,6 +88,7 @@ class SignalClientTest : BaseTest() {
86 } 88 }
87 89
88 client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request)) 90 client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
  91 + client.onMessage(wsFactory.ws, RECONNECT.toOkioByteString())
89 92
90 job.await() 93 job.await()
91 assertEquals(true, client.isConnected) 94 assertEquals(true, client.isConnected)
@@ -199,6 +202,7 @@ class SignalClientTest : BaseTest() { @@ -199,6 +202,7 @@ class SignalClientTest : BaseTest() {
199 202
200 val job = async { client.reconnect(EXAMPLE_URL, "", "participant_sid") } 203 val job = async { client.reconnect(EXAMPLE_URL, "", "participant_sid") }
201 client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request)) 204 client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
  205 + client.onMessage(wsFactory.ws, RECONNECT.toOkioByteString())
202 job.await() 206 job.await()
203 207
204 val ws = wsFactory.ws 208 val ws = wsFactory.ws
@@ -291,12 +295,35 @@ class SignalClientTest : BaseTest() { @@ -291,12 +295,35 @@ class SignalClientTest : BaseTest() {
291 } 295 }
292 participant = TestData.LOCAL_PARTICIPANT 296 participant = TestData.LOCAL_PARTICIPANT
293 subscriberPrimary = true 297 subscriberPrimary = true
  298 + addIceServers(with(ICEServer.newBuilder()) {
  299 + addUrls("stun:stun.join.com:19302")
  300 + username = "username"
  301 + credential = "credential"
  302 + build()
  303 + })
294 serverVersion = "0.15.2" 304 serverVersion = "0.15.2"
295 build() 305 build()
296 } 306 }
297 build() 307 build()
298 } 308 }
299 309
  310 + val RECONNECT = with(LivekitRtc.SignalResponse.newBuilder()) {
  311 + reconnect = with(LivekitRtc.ReconnectResponse.newBuilder()) {
  312 + addIceServers(with(ICEServer.newBuilder()) {
  313 + addUrls("stun:stun.reconnect.com:19302")
  314 + username = "username"
  315 + credential = "credential"
  316 + build()
  317 + })
  318 + clientConfiguration = with(ClientConfiguration.newBuilder()) {
  319 + forceRelay = LivekitModels.ClientConfigSetting.ENABLED
  320 + build()
  321 + }
  322 + build()
  323 + }
  324 + build()
  325 + }
  326 +
300 val OFFER = with(LivekitRtc.SignalResponse.newBuilder()) { 327 val OFFER = with(LivekitRtc.SignalResponse.newBuilder()) {
301 offer = with(LivekitRtc.SessionDescription.newBuilder()) { 328 offer = with(LivekitRtc.SessionDescription.newBuilder()) {
302 sdp = "remote_offer" 329 sdp = "remote_offer"
  1 +package io.livekit.android.webrtc
  2 +
  3 +import io.livekit.android.BaseTest
  4 +import org.junit.Assert.assertEquals
  5 +import org.junit.Assert.assertTrue
  6 +import org.junit.Test
  7 +import org.mockito.Mockito
  8 +import org.webrtc.PeerConnection.RTCConfiguration
  9 +
  10 +class RTCConfigurationTest : BaseTest() {
  11 +
  12 + @Test
  13 + fun copyTest() {
  14 + val originalConfig = RTCConfiguration(mutableListOf())
  15 + fillWithMockData(originalConfig)
  16 + val newConfig = originalConfig.copy()
  17 + newConfig::class.java
  18 + .declaredFields
  19 + .forEach { field ->
  20 + assertEquals("Failed on ${field.name}", field.get(originalConfig), field.get(newConfig))
  21 + }
  22 + }
  23 +
  24 + // Test to make sure the copy test is actually checking properly
  25 + @Test
  26 + fun copyFailureCheckTest() {
  27 + val originalConfig = RTCConfiguration(mutableListOf())
  28 + fillWithMockData(originalConfig)
  29 + val newConfig = originalConfig.copy()
  30 +
  31 + newConfig.activeResetSrtpParams = false
  32 +
  33 + var caughtError = false
  34 + try {
  35 + newConfig::class.java
  36 + .declaredFields
  37 + .forEach { field ->
  38 + assertEquals("Failed on ${field.name}", field.get(originalConfig), field.get(newConfig))
  39 + }
  40 + } catch (e: java.lang.AssertionError) {
  41 + // Error expected
  42 + caughtError = true
  43 + }
  44 +
  45 + assertTrue(caughtError)
  46 + }
  47 +
  48 + private fun fillWithMockData(config: RTCConfiguration) {
  49 + config::class.java
  50 + .declaredFields
  51 + .forEach { field ->
  52 + // Ignore iceServers.
  53 + if (field.name == "iceServers") {
  54 + return@forEach
  55 + }
  56 +
  57 + val value = field.get(config)
  58 + val newValue = if (value == null) {
  59 + when (field.type) {
  60 + Byte::class.javaObjectType -> 1.toByte()
  61 + Short::class.javaObjectType -> 1.toShort()
  62 + Int::class.javaObjectType -> 1
  63 + Long::class.javaObjectType -> 1.toLong()
  64 + Float::class.javaObjectType -> 1.toFloat()
  65 + Double::class.javaObjectType -> 1.toDouble()
  66 + Boolean::class.javaObjectType -> true
  67 + Char::class.javaObjectType -> 1.toChar()
  68 + String::class.javaObjectType -> "mock string"
  69 + else -> Mockito.mock(field.type)
  70 + }
  71 + } else {
  72 + when (value::class.javaObjectType) {
  73 + Byte::class.javaObjectType -> ((value as Byte) + 1).toByte()
  74 + Short::class.javaObjectType -> ((value as Short) + 1).toShort()
  75 + Int::class.javaObjectType -> (value as Int) + 1
  76 + Long::class.javaObjectType -> (value as Long) + 1
  77 + Float::class.javaObjectType -> (value as Float) + 1
  78 + Double::class.javaObjectType -> (value as Double) + 1
  79 + Boolean::class.javaObjectType -> !(value as Boolean)
  80 + Char::class.javaObjectType -> (value as Char) + 1
  81 + String::class.javaObjectType -> "mock string"
  82 + else -> Mockito.mock(field.type)
  83 + }
  84 + }
  85 + field.set(config, newValue)
  86 + }
  87 + }
  88 +}
1 -Subproject commit ec5189c9fd40d61cbc0b0ca43da6c71e7284ebcb 1 +Subproject commit a1819deeabe143b1af1bf375d84e52b152f784ac