davidliu
Committed by GitHub

Race condition fixes (#315)

* Add more mutexes surrounding peerconnection usage

* proper locking for response flow job

* Fix race condition in CoroutineSdpObserver

* Lock down peer connection usage

* spotless apply

* Exclude ReentrantMutex from spotless

* Fix tests

* spotless apply
@@ -53,14 +53,17 @@ subprojects { @@ -53,14 +53,17 @@ subprojects {
53 // through git history (see "license" section below) 53 // through git history (see "license" section below)
54 licenseHeaderFile rootProject.file("LicenseHeaderFile.txt") 54 licenseHeaderFile rootProject.file("LicenseHeaderFile.txt")
55 removeUnusedImports() 55 removeUnusedImports()
  56 + toggleOffOn()
56 } 57 }
57 kotlin { 58 kotlin {
58 target("src/*/java/**/*.kt") 59 target("src/*/java/**/*.kt")
  60 + targetExclude("src/*/java/**/ReentrantMutex.kt") // Different license
59 ktlint("0.50.0") 61 ktlint("0.50.0")
60 .setEditorConfigPath("$rootDir/.editorconfig") 62 .setEditorConfigPath("$rootDir/.editorconfig")
61 licenseHeaderFile(rootProject.file("LicenseHeaderFile.txt")) 63 licenseHeaderFile(rootProject.file("LicenseHeaderFile.txt"))
62 .named('license') 64 .named('license')
63 endWithNewline() 65 endWithNewline()
  66 + toggleOffOn()
64 } 67 }
65 } 68 }
66 } 69 }
  1 +/*
  2 +This is free and unencumbered software released into the public domain.
  3 +
  4 +Anyone is free to copy, modify, publish, use, compile, sell, or
  5 +distribute this software, either in source code form or as a compiled
  6 +binary, for any purpose, commercial or non-commercial, and by any
  7 +means.
  8 +
  9 +In jurisdictions that recognize copyright laws, the author or authors
  10 +of this software dedicate any and all copyright interest in the
  11 +software to the public domain. We make this dedication for the benefit
  12 +of the public at large and to the detriment of our heirs and
  13 +successors. We intend this dedication to be an overt act of
  14 +relinquishment in perpetuity of all present and future rights to this
  15 +software under copyright law.
  16 +
  17 +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  18 +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  19 +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
  20 +IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
  21 +OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
  22 +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
  23 +OTHER DEALINGS IN THE SOFTWARE.
  24 +
  25 +For more information, please refer to <https://unlicense.org>
  26 +
  27 +Original is at https://gist.github.com/elizarov/9a48b9709ffd508909d34fab6786acfe
  28 +*/
  29 +
  30 +package io.livekit.android.coroutines
  31 +
  32 +import kotlinx.coroutines.sync.Mutex
  33 +import kotlinx.coroutines.sync.withLock
  34 +import kotlinx.coroutines.withContext
  35 +import kotlin.coroutines.CoroutineContext
  36 +import kotlin.coroutines.coroutineContext
  37 +
  38 +internal suspend fun <T> Mutex.withReentrantLock(block: suspend () -> T): T {
  39 + val key = ReentrantMutexContextKey(this)
  40 + // call block directly when this mutex is already locked in the context
  41 + if (coroutineContext[key] != null) return block()
  42 + // otherwise add it to the context and lock the mutex
  43 + return withContext(ReentrantMutexContextElement(key)) {
  44 + withLock { block() }
  45 + }
  46 +}
  47 +
  48 +internal class ReentrantMutexContextElement(
  49 + override val key: ReentrantMutexContextKey,
  50 +) : CoroutineContext.Element
  51 +
  52 +internal data class ReentrantMutexContextKey(
  53 + val mutex: Mutex,
  54 +) : CoroutineContext.Key<ReentrantMutexContextElement>
@@ -21,6 +21,7 @@ import android.javax.sdp.SdpFactory @@ -21,6 +21,7 @@ import android.javax.sdp.SdpFactory
21 import dagger.assisted.Assisted 21 import dagger.assisted.Assisted
22 import dagger.assisted.AssistedFactory 22 import dagger.assisted.AssistedFactory
23 import dagger.assisted.AssistedInject 23 import dagger.assisted.AssistedInject
  24 +import io.livekit.android.coroutines.withReentrantLock
24 import io.livekit.android.dagger.InjectionNames 25 import io.livekit.android.dagger.InjectionNames
25 import io.livekit.android.room.util.* 26 import io.livekit.android.room.util.*
26 import io.livekit.android.util.Either 27 import io.livekit.android.util.Either
@@ -32,15 +33,20 @@ import io.livekit.android.webrtc.getExts @@ -32,15 +33,20 @@ import io.livekit.android.webrtc.getExts
32 import io.livekit.android.webrtc.getFmtps 33 import io.livekit.android.webrtc.getFmtps
33 import io.livekit.android.webrtc.getMsid 34 import io.livekit.android.webrtc.getMsid
34 import io.livekit.android.webrtc.getRtps 35 import io.livekit.android.webrtc.getRtps
  36 +import io.livekit.android.webrtc.isConnected
35 import kotlinx.coroutines.CoroutineDispatcher 37 import kotlinx.coroutines.CoroutineDispatcher
36 import kotlinx.coroutines.CoroutineScope 38 import kotlinx.coroutines.CoroutineScope
37 import kotlinx.coroutines.SupervisorJob 39 import kotlinx.coroutines.SupervisorJob
38 import kotlinx.coroutines.runBlocking 40 import kotlinx.coroutines.runBlocking
39 import kotlinx.coroutines.sync.Mutex 41 import kotlinx.coroutines.sync.Mutex
40 -import kotlinx.coroutines.sync.withLock  
41 import org.webrtc.* 42 import org.webrtc.*
42 import org.webrtc.PeerConnection.RTCConfiguration 43 import org.webrtc.PeerConnection.RTCConfiguration
  44 +import org.webrtc.PeerConnection.SignalingState
  45 +import java.util.concurrent.atomic.AtomicBoolean
43 import javax.inject.Named 46 import javax.inject.Named
  47 +import kotlin.contracts.ExperimentalContracts
  48 +import kotlin.contracts.InvocationKind
  49 +import kotlin.contracts.contract
44 import kotlin.math.roundToLong 50 import kotlin.math.roundToLong
45 51
46 /** 52 /**
@@ -58,7 +64,7 @@ constructor( @@ -58,7 +64,7 @@ constructor(
58 private val sdpFactory: SdpFactory, 64 private val sdpFactory: SdpFactory,
59 ) { 65 ) {
60 private val coroutineScope = CoroutineScope(ioDispatcher + SupervisorJob()) 66 private val coroutineScope = CoroutineScope(ioDispatcher + SupervisorJob())
61 - internal val peerConnection: PeerConnection = connectionFactory.createPeerConnection( 67 + private val peerConnection: PeerConnection = connectionFactory.createPeerConnection(
62 config, 68 config,
63 pcObserver, 69 pcObserver,
64 ) ?: throw IllegalStateException("peer connection creation failed?") 70 ) ?: throw IllegalStateException("peer connection creation failed?")
@@ -70,6 +76,7 @@ constructor( @@ -70,6 +76,7 @@ constructor(
70 private val mutex = Mutex() 76 private val mutex = Mutex()
71 77
72 private var trackBitrates = mutableMapOf<TrackBitrateInfoKey, TrackBitrateInfo>() 78 private var trackBitrates = mutableMapOf<TrackBitrateInfoKey, TrackBitrateInfo>()
  79 + private var isClosed = AtomicBoolean(false)
73 80
74 interface Listener { 81 interface Listener {
75 fun onOffer(sd: SessionDescription) 82 fun onOffer(sd: SessionDescription)
@@ -77,7 +84,7 @@ constructor( @@ -77,7 +84,7 @@ constructor(
77 84
78 fun addIceCandidate(candidate: IceCandidate) { 85 fun addIceCandidate(candidate: IceCandidate) {
79 runBlocking { 86 runBlocking {
80 - mutex.withLock { 87 + withNotClosedLock {
81 if (peerConnection.remoteDescription != null && !restartingIce) { 88 if (peerConnection.remoteDescription != null && !restartingIce) {
82 peerConnection.addIceCandidate(candidate) 89 peerConnection.addIceCandidate(candidate)
83 } else { 90 } else {
@@ -87,17 +94,24 @@ constructor( @@ -87,17 +94,24 @@ constructor(
87 } 94 }
88 } 95 }
89 96
  97 + suspend fun <T> withPeerConnection(action: suspend PeerConnection.() -> T): T? {
  98 + return withNotClosedLock {
  99 + action(peerConnection)
  100 + }
  101 + }
  102 +
90 suspend fun setRemoteDescription(sd: SessionDescription): Either<Unit, String?> { 103 suspend fun setRemoteDescription(sd: SessionDescription): Either<Unit, String?> {
91 - val result = peerConnection.setRemoteDescription(sd)  
92 - if (result is Either.Left) {  
93 - mutex.withLock { 104 + val result = withNotClosedLock {
  105 + val result = peerConnection.setRemoteDescription(sd)
  106 + if (result is Either.Left) {
94 pendingCandidates.forEach { pending -> 107 pendingCandidates.forEach { pending ->
95 peerConnection.addIceCandidate(pending) 108 peerConnection.addIceCandidate(pending)
96 } 109 }
97 pendingCandidates.clear() 110 pendingCandidates.clear()
98 restartingIce = false 111 restartingIce = false
99 } 112 }
100 - } 113 + result
  114 + } ?: Either.Right("PCT is closed.")
101 115
102 if (this.renegotiate) { 116 if (this.renegotiate) {
103 this.renegotiate = false 117 this.renegotiate = false
@@ -115,60 +129,65 @@ constructor( @@ -115,60 +129,65 @@ constructor(
115 } 129 }
116 } 130 }
117 131
118 - suspend fun createAndSendOffer(constraints: MediaConstraints = MediaConstraints()) { 132 + private suspend fun createAndSendOffer(constraints: MediaConstraints = MediaConstraints()) {
119 if (listener == null) { 133 if (listener == null) {
120 return 134 return
121 } 135 }
122 136
123 - val iceRestart =  
124 - constraints.findConstraint(MediaConstraintKeys.ICE_RESTART) == MediaConstraintKeys.TRUE  
125 - if (iceRestart) {  
126 - LKLog.d { "restarting ice" }  
127 - restartingIce = true  
128 - } 137 + var finalSdp: SessionDescription? = null
129 138
130 - if (this.peerConnection.signalingState() == PeerConnection.SignalingState.HAVE_LOCAL_OFFER) {  
131 - // we're waiting for the peer to accept our offer, so we'll just wait  
132 - // the only exception to this is when ICE restart is needed  
133 - val curSd = peerConnection.remoteDescription  
134 - if (iceRestart && curSd != null) {  
135 - // TODO: handle when ICE restart is needed but we don't have a remote description  
136 - // the best thing to do is to recreate the peerconnection  
137 - peerConnection.setRemoteDescription(curSd)  
138 - } else {  
139 - renegotiate = true  
140 - return 139 + // TODO: This is a potentially long lock hold. May need to break up.
  140 + withNotClosedLock {
  141 + val iceRestart =
  142 + constraints.findConstraint(MediaConstraintKeys.ICE_RESTART) == MediaConstraintKeys.TRUE
  143 + if (iceRestart) {
  144 + LKLog.d { "restarting ice" }
  145 + restartingIce = true
141 } 146 }
142 - }  
143 147
144 - // actually negotiate  
145 - LKLog.d { "starting to negotiate" }  
146 - val sdpOffer = when (val outcome = peerConnection.createOffer(constraints)) {  
147 - is Either.Left -> outcome.value  
148 - is Either.Right -> {  
149 - LKLog.d { "error creating offer: ${outcome.value}" }  
150 - return 148 + if (this.peerConnection.signalingState() == SignalingState.HAVE_LOCAL_OFFER) {
  149 + // we're waiting for the peer to accept our offer, so we'll just wait
  150 + // the only exception to this is when ICE restart is needed
  151 + val curSd = peerConnection.remoteDescription
  152 + if (iceRestart && curSd != null) {
  153 + // TODO: handle when ICE restart is needed but we don't have a remote description
  154 + // the best thing to do is to recreate the peerconnection
  155 + peerConnection.setRemoteDescription(curSd)
  156 + } else {
  157 + renegotiate = true
  158 + return@withNotClosedLock
  159 + }
151 } 160 }
152 - }  
153 161
154 - // munge sdp  
155 - val sdpDescription = sdpFactory.createSessionDescription(sdpOffer.description)  
156 -  
157 - val mediaDescs = sdpDescription.getMediaDescriptions(true)  
158 - for (mediaDesc in mediaDescs) {  
159 - if (mediaDesc !is MediaDescription) {  
160 - continue 162 + // actually negotiate
  163 + val sdpOffer = when (val outcome = peerConnection.createOffer(constraints)) {
  164 + is Either.Left -> outcome.value
  165 + is Either.Right -> {
  166 + LKLog.d { "error creating offer: ${outcome.value}" }
  167 + return@withNotClosedLock
  168 + }
161 } 169 }
162 - if (mediaDesc.media.mediaType == "audio") {  
163 - // TODO  
164 - } else if (mediaDesc.media.mediaType == "video") {  
165 - ensureVideoDDExtensionForSVC(mediaDesc)  
166 - ensureCodecBitrates(mediaDesc, trackBitrates = trackBitrates) 170 +
  171 + // munge sdp
  172 + val sdpDescription = sdpFactory.createSessionDescription(sdpOffer.description)
  173 +
  174 + val mediaDescs = sdpDescription.getMediaDescriptions(true)
  175 + for (mediaDesc in mediaDescs) {
  176 + if (mediaDesc !is MediaDescription) {
  177 + continue
  178 + }
  179 + if (mediaDesc.media.mediaType == "audio") {
  180 + // TODO
  181 + } else if (mediaDesc.media.mediaType == "video") {
  182 + ensureVideoDDExtensionForSVC(mediaDesc)
  183 + ensureCodecBitrates(mediaDesc, trackBitrates = trackBitrates)
  184 + }
167 } 185 }
  186 + finalSdp = setMungedSdp(sdpOffer, sdpDescription.toString())
  187 + }
  188 + if (finalSdp != null) {
  189 + listener.onOffer(finalSdp!!)
168 } 190 }
169 -  
170 - val finalSdp = setMungedSdp(sdpOffer, sdpDescription.toString())  
171 - listener.onOffer(finalSdp)  
172 } 191 }
173 192
174 private suspend fun setMungedSdp(sdp: SessionDescription, mungedDescription: String, remote: Boolean = false): SessionDescription { 193 private suspend fun setMungedSdp(sdp: SessionDescription, mungedDescription: String, remote: Boolean = false): SessionDescription {
@@ -233,12 +252,27 @@ constructor( @@ -233,12 +252,27 @@ constructor(
233 restartingIce = true 252 restartingIce = true
234 } 253 }
235 254
236 - fun close() {  
237 - peerConnection.dispose() 255 + fun isClosed() = isClosed.get()
  256 +
  257 + fun closeBlocking() {
  258 + runBlocking {
  259 + close()
  260 + }
  261 + }
  262 +
  263 + suspend fun close() {
  264 + withNotClosedLock {
  265 + isClosed.set(true)
  266 + peerConnection.dispose()
  267 + }
238 } 268 }
239 269
240 fun updateRTCConfig(config: RTCConfiguration) { 270 fun updateRTCConfig(config: RTCConfiguration) {
241 - peerConnection.setConfiguration(config) 271 + runBlocking {
  272 + withNotClosedLock {
  273 + peerConnection.setConfiguration(config)
  274 + }
  275 + }
242 } 276 }
243 277
244 fun registerTrackBitrateInfo(cid: String, trackBitrateInfo: TrackBitrateInfo) { 278 fun registerTrackBitrateInfo(cid: String, trackBitrateInfo: TrackBitrateInfo) {
@@ -249,6 +283,44 @@ constructor( @@ -249,6 +283,44 @@ constructor(
249 trackBitrates[TrackBitrateInfoKey.Transceiver(transceiver)] = trackBitrateInfo 283 trackBitrates[TrackBitrateInfoKey.Transceiver(transceiver)] = trackBitrateInfo
250 } 284 }
251 285
  286 + suspend fun isConnected(): Boolean {
  287 + return withNotClosedLock {
  288 + peerConnection.isConnected()
  289 + } ?: false
  290 + }
  291 +
  292 + suspend fun iceConnectionState(): PeerConnection.IceConnectionState {
  293 + return withNotClosedLock {
  294 + peerConnection.iceConnectionState()
  295 + } ?: PeerConnection.IceConnectionState.CLOSED
  296 + }
  297 +
  298 + suspend fun connectionState(): PeerConnection.PeerConnectionState {
  299 + return withNotClosedLock {
  300 + peerConnection.connectionState()
  301 + } ?: PeerConnection.PeerConnectionState.CLOSED
  302 + }
  303 +
  304 + suspend fun signalingState(): SignalingState {
  305 + return withNotClosedLock {
  306 + peerConnection.signalingState()
  307 + } ?: SignalingState.CLOSED
  308 + }
  309 +
  310 + @OptIn(ExperimentalContracts::class)
  311 + private suspend inline fun <T> withNotClosedLock(crossinline action: suspend () -> T): T? {
  312 + contract { callsInPlace(action, InvocationKind.AT_MOST_ONCE) }
  313 + if (isClosed()) {
  314 + return null
  315 + }
  316 + return mutex.withReentrantLock {
  317 + if (isClosed()) {
  318 + return@withReentrantLock null
  319 + }
  320 + return@withReentrantLock action()
  321 + }
  322 + }
  323 +
252 @AssistedFactory 324 @AssistedFactory
253 interface Factory { 325 interface Factory {
254 fun create( 326 fun create(
@@ -296,7 +368,7 @@ internal fun ensureVideoDDExtensionForSVC(mediaDesc: MediaDescription) { @@ -296,7 +368,7 @@ internal fun ensureVideoDDExtensionForSVC(mediaDesc: MediaDescription) {
296 } 368 }
297 } 369 }
298 370
299 -/* The svc codec (av1/vp9) would use a very low bitrate at the begining and 371 +/* The svc codec (av1/vp9) would use a very low bitrate at the beginning and
300 increase slowly by the bandwidth estimator until it reach the target bitrate. The 372 increase slowly by the bandwidth estimator until it reach the target bitrate. The
301 process commonly cost more than 10 seconds cause subscriber will get blur video at 373 process commonly cost more than 10 seconds cause subscriber will get blur video at
302 the first few seconds. So we use a 70% of target bitrate here as the start bitrate to 374 the first few seconds. So we use a 70% of target bitrate here as the start bitrate to
@@ -17,6 +17,7 @@ @@ -17,6 +17,7 @@
17 package io.livekit.android.room 17 package io.livekit.android.room
18 18
19 import android.os.SystemClock 19 import android.os.SystemClock
  20 +import androidx.annotation.VisibleForTesting
20 import com.google.protobuf.ByteString 21 import com.google.protobuf.ByteString
21 import io.livekit.android.ConnectOptions 22 import io.livekit.android.ConnectOptions
22 import io.livekit.android.RoomOptions 23 import io.livekit.android.RoomOptions
@@ -31,6 +32,7 @@ import io.livekit.android.room.util.setLocalDescription @@ -31,6 +32,7 @@ import io.livekit.android.room.util.setLocalDescription
31 import io.livekit.android.util.CloseableCoroutineScope 32 import io.livekit.android.util.CloseableCoroutineScope
32 import io.livekit.android.util.Either 33 import io.livekit.android.util.Either
33 import io.livekit.android.util.LKLog 34 import io.livekit.android.util.LKLog
  35 +import io.livekit.android.webrtc.RTCStatsGetter
34 import io.livekit.android.webrtc.copy 36 import io.livekit.android.webrtc.copy
35 import io.livekit.android.webrtc.isConnected 37 import io.livekit.android.webrtc.isConnected
36 import io.livekit.android.webrtc.isDisconnected 38 import io.livekit.android.webrtc.isDisconnected
@@ -114,18 +116,8 @@ internal constructor( @@ -114,18 +116,8 @@ internal constructor(
114 private val publisherObserver = PublisherTransportObserver(this, client) 116 private val publisherObserver = PublisherTransportObserver(this, client)
115 private val subscriberObserver = SubscriberTransportObserver(this, client) 117 private val subscriberObserver = SubscriberTransportObserver(this, client)
116 118
117 - private var _publisher: PeerConnectionTransport? = null  
118 - internal val publisher: PeerConnectionTransport  
119 - get() {  
120 - return _publisher  
121 - ?: throw UninitializedPropertyAccessException("publisher has not been initialized yet.")  
122 - }  
123 - private var _subscriber: PeerConnectionTransport? = null  
124 - internal val subscriber: PeerConnectionTransport  
125 - get() {  
126 - return _subscriber  
127 - ?: throw UninitializedPropertyAccessException("subscriber has not been initialized yet.")  
128 - } 119 + private var publisher: PeerConnectionTransport? = null
  120 + private var subscriber: PeerConnectionTransport? = null
129 121
130 private var reliableDataChannel: DataChannel? = null 122 private var reliableDataChannel: DataChannel? = null
131 private var reliableDataChannelSub: DataChannel? = null 123 private var reliableDataChannelSub: DataChannel? = null
@@ -181,8 +173,8 @@ internal constructor( @@ -181,8 +173,8 @@ internal constructor(
181 return joinResponse 173 return joinResponse
182 } 174 }
183 175
184 - private fun configure(joinResponse: JoinResponse, connectOptions: ConnectOptions) {  
185 - if (_publisher != null && _subscriber != null) { 176 + private suspend fun configure(joinResponse: JoinResponse, connectOptions: ConnectOptions) {
  177 + if (publisher != null && subscriber != null) {
186 // already configured 178 // already configured
187 return 179 return
188 } 180 }
@@ -196,14 +188,14 @@ internal constructor( @@ -196,14 +188,14 @@ internal constructor(
196 // Setup peer connections 188 // Setup peer connections
197 val rtcConfig = makeRTCConfig(Either.Left(joinResponse), connectOptions) 189 val rtcConfig = makeRTCConfig(Either.Left(joinResponse), connectOptions)
198 190
199 - _publisher?.close()  
200 - _publisher = pctFactory.create( 191 + publisher?.close()
  192 + publisher = pctFactory.create(
201 rtcConfig, 193 rtcConfig,
202 publisherObserver, 194 publisherObserver,
203 publisherObserver, 195 publisherObserver,
204 ) 196 )
205 - _subscriber?.close()  
206 - _subscriber = pctFactory.create( 197 + subscriber?.close()
  198 + subscriber = pctFactory.create(
207 rtcConfig, 199 rtcConfig,
208 subscriberObserver, 200 subscriberObserver,
209 null, 201 null,
@@ -243,19 +235,22 @@ internal constructor( @@ -243,19 +235,22 @@ internal constructor(
243 // data channels 235 // data channels
244 val reliableInit = DataChannel.Init() 236 val reliableInit = DataChannel.Init()
245 reliableInit.ordered = true 237 reliableInit.ordered = true
246 - reliableDataChannel = publisher.peerConnection.createDataChannel(  
247 - RELIABLE_DATA_CHANNEL_LABEL,  
248 - reliableInit,  
249 - )  
250 - reliableDataChannel!!.registerObserver(this) 238 + reliableDataChannel = publisher?.withPeerConnection {
  239 + createDataChannel(
  240 + RELIABLE_DATA_CHANNEL_LABEL,
  241 + reliableInit,
  242 + ).apply { registerObserver(this@RTCEngine) }
  243 + }
  244 +
251 val lossyInit = DataChannel.Init() 245 val lossyInit = DataChannel.Init()
252 lossyInit.ordered = true 246 lossyInit.ordered = true
253 lossyInit.maxRetransmits = 0 247 lossyInit.maxRetransmits = 0
254 - lossyDataChannel = publisher.peerConnection.createDataChannel(  
255 - LOSSY_DATA_CHANNEL_LABEL,  
256 - lossyInit,  
257 - )  
258 - lossyDataChannel!!.registerObserver(this) 248 + lossyDataChannel = publisher?.withPeerConnection {
  249 + createDataChannel(
  250 + LOSSY_DATA_CHANNEL_LABEL,
  251 + lossyInit,
  252 + ).apply { registerObserver(this@RTCEngine) }
  253 + }
259 } 254 }
260 255
261 /** 256 /**
@@ -277,11 +272,13 @@ internal constructor( @@ -277,11 +272,13 @@ internal constructor(
277 } 272 }
278 } 273 }
279 274
280 - internal fun createSenderTransceiver( 275 + internal suspend fun createSenderTransceiver(
281 rtcTrack: MediaStreamTrack, 276 rtcTrack: MediaStreamTrack,
282 transInit: RtpTransceiverInit, 277 transInit: RtpTransceiverInit,
283 ): RtpTransceiver? { 278 ): RtpTransceiver? {
284 - return publisher.peerConnection.addTransceiver(rtcTrack, transInit) 279 + return publisher?.withPeerConnection {
  280 + addTransceiver(rtcTrack, transInit)
  281 + }
285 } 282 }
286 283
287 fun updateSubscriptionPermissions( 284 fun updateSubscriptionPermissions(
@@ -301,15 +298,15 @@ internal constructor( @@ -301,15 +298,15 @@ internal constructor(
301 } 298 }
302 LKLog.v { "Close - $reason" } 299 LKLog.v { "Close - $reason" }
303 isClosed = true 300 isClosed = true
  301 + reconnectingJob?.cancel()
  302 + reconnectingJob = null
  303 + coroutineScope.close()
304 hasPublished = false 304 hasPublished = false
305 sessionUrl = null 305 sessionUrl = null
306 sessionToken = null 306 sessionToken = null
307 connectOptions = null 307 connectOptions = null
308 lastRoomOptions = null 308 lastRoomOptions = null
309 participantSid = null 309 participantSid = null
310 - reconnectingJob?.cancel()  
311 - reconnectingJob = null  
312 - coroutineScope.close()  
313 closeResources(reason) 310 closeResources(reason)
314 connectionState = ConnectionState.DISCONNECTED 311 connectionState = ConnectionState.DISCONNECTED
315 } 312 }
@@ -317,10 +314,10 @@ internal constructor( @@ -317,10 +314,10 @@ internal constructor(
317 private fun closeResources(reason: String) { 314 private fun closeResources(reason: String) {
318 publisherObserver.connectionChangeListener = null 315 publisherObserver.connectionChangeListener = null
319 subscriberObserver.connectionChangeListener = null 316 subscriberObserver.connectionChangeListener = null
320 - _publisher?.close()  
321 - _publisher = null  
322 - _subscriber?.close()  
323 - _subscriber = null 317 + publisher?.closeBlocking()
  318 + publisher = null
  319 + subscriber?.closeBlocking()
  320 + subscriber = null
324 321
325 fun DataChannel?.completeDispose() { 322 fun DataChannel?.completeDispose() {
326 this?.unregisterObserver() 323 this?.unregisterObserver()
@@ -366,6 +363,15 @@ internal constructor( @@ -366,6 +363,15 @@ internal constructor(
366 363
367 val reconnectStartTime = SystemClock.elapsedRealtime() 364 val reconnectStartTime = SystemClock.elapsedRealtime()
368 for (retries in 0 until MAX_RECONNECT_RETRIES) { 365 for (retries in 0 until MAX_RECONNECT_RETRIES) {
  366 + if (retries != 0) {
  367 + yield()
  368 + }
  369 +
  370 + if (isClosed) {
  371 + LKLog.v { "RTCEngine closed, aborting reconnection" }
  372 + break
  373 + }
  374 +
369 var startDelay = 100 + retries.toLong() * retries * 500 375 var startDelay = 100 + retries.toLong() * retries * 500
370 if (startDelay > 5000) { 376 if (startDelay > 5000) {
371 startDelay = 5000 377 startDelay = 5000
@@ -395,14 +401,14 @@ internal constructor( @@ -395,14 +401,14 @@ internal constructor(
395 } 401 }
396 } else { 402 } else {
397 LKLog.v { "Attempting soft reconnect." } 403 LKLog.v { "Attempting soft reconnect." }
398 - subscriber.prepareForIceRestart() 404 + subscriber?.prepareForIceRestart()
399 try { 405 try {
400 val response = client.reconnect(url, token, participantSid) 406 val response = client.reconnect(url, token, participantSid)
401 if (response is Either.Left) { 407 if (response is Either.Left) {
402 val reconnectResponse = response.value 408 val reconnectResponse = response.value
403 val rtcConfig = makeRTCConfig(Either.Right(reconnectResponse), connectOptions) 409 val rtcConfig = makeRTCConfig(Either.Right(reconnectResponse), connectOptions)
404 - _subscriber?.updateRTCConfig(rtcConfig)  
405 - _publisher?.updateRTCConfig(rtcConfig) 410 + subscriber?.updateRTCConfig(rtcConfig)
  411 + publisher?.updateRTCConfig(rtcConfig)
406 } 412 }
407 client.onReadyForResponses() 413 client.onReadyForResponses()
408 } catch (e: Exception) { 414 } catch (e: Exception) {
@@ -420,11 +426,17 @@ internal constructor( @@ -420,11 +426,17 @@ internal constructor(
420 negotiatePublisher() 426 negotiatePublisher()
421 } 427 }
422 } 428 }
  429 +
  430 + if (isClosed) {
  431 + LKLog.v { "RTCEngine closed, aborting reconnection" }
  432 + break
  433 + }
  434 +
423 // wait until ICE connected 435 // wait until ICE connected
424 val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS 436 val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS
425 if (hasPublished) { 437 if (hasPublished) {
426 while (SystemClock.elapsedRealtime() < endTime) { 438 while (SystemClock.elapsedRealtime() < endTime) {
427 - if (publisher.peerConnection.connectionState().isConnected()) { 439 + if (publisher?.isConnected() == true) {
428 LKLog.v { "publisher reconnected to ICE" } 440 LKLog.v { "publisher reconnected to ICE" }
429 break 441 break
430 } 442 }
@@ -432,8 +444,13 @@ internal constructor( @@ -432,8 +444,13 @@ internal constructor(
432 } 444 }
433 } 445 }
434 446
  447 + if (isClosed) {
  448 + LKLog.v { "RTCEngine closed, aborting reconnection" }
  449 + break
  450 + }
  451 +
435 while (SystemClock.elapsedRealtime() < endTime) { 452 while (SystemClock.elapsedRealtime() < endTime) {
436 - if (subscriber.peerConnection.connectionState().isConnected()) { 453 + if (subscriber?.isConnected() == true) {
437 LKLog.v { "reconnected to ICE" } 454 LKLog.v { "reconnected to ICE" }
438 connectionState = ConnectionState.CONNECTED 455 connectionState = ConnectionState.CONNECTED
439 break 456 break
@@ -441,8 +458,12 @@ internal constructor( @@ -441,8 +458,12 @@ internal constructor(
441 delay(100) 458 delay(100)
442 } 459 }
443 460
  461 + if (isClosed) {
  462 + LKLog.v { "RTCEngine closed, aborting reconnection" }
  463 + break
  464 + }
444 if (connectionState == ConnectionState.CONNECTED && 465 if (connectionState == ConnectionState.CONNECTED &&
445 - (!hasPublished || publisher.peerConnection.connectionState().isConnected()) 466 + (!hasPublished || publisher?.isConnected() == true)
446 ) { 467 ) {
447 client.onPCConnected() 468 client.onPCConnected()
448 listener?.onPostReconnect(isFullReconnect) 469 listener?.onPostReconnect(isFullReconnect)
@@ -475,7 +496,7 @@ internal constructor( @@ -475,7 +496,7 @@ internal constructor(
475 hasPublished = true 496 hasPublished = true
476 497
477 coroutineScope.launch { 498 coroutineScope.launch {
478 - publisher.negotiate(getPublisherOfferConstraints()) 499 + publisher?.negotiate?.invoke(getPublisherOfferConstraints())
479 } 500 }
480 } 501 }
481 502
@@ -498,12 +519,12 @@ internal constructor( @@ -498,12 +519,12 @@ internal constructor(
498 return 519 return
499 } 520 }
500 521
501 - if (_publisher == null) { 522 + if (publisher == null) {
502 throw RoomException.ConnectException("Publisher isn't setup yet! Is room not connected?!") 523 throw RoomException.ConnectException("Publisher isn't setup yet! Is room not connected?!")
503 } 524 }
504 525
505 - if (!publisher.peerConnection.isConnected() &&  
506 - publisher.peerConnection.iceConnectionState() != PeerConnection.IceConnectionState.CHECKING 526 + if (publisher?.isConnected() != true &&
  527 + publisher?.iceConnectionState() != PeerConnection.IceConnectionState.CHECKING
507 ) { 528 ) {
508 // start negotiation 529 // start negotiation
509 this.negotiatePublisher() 530 this.negotiatePublisher()
@@ -517,7 +538,7 @@ internal constructor( @@ -517,7 +538,7 @@ internal constructor(
517 // wait until publisher ICE connected 538 // wait until publisher ICE connected
518 val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS 539 val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS
519 while (SystemClock.elapsedRealtime() < endTime) { 540 while (SystemClock.elapsedRealtime() < endTime) {
520 - if (this.publisher.peerConnection.isConnected() && targetChannel.state() == DataChannel.State.OPEN) { 541 + if (publisher?.isConnected() == true && targetChannel.state() == DataChannel.State.OPEN) {
521 return 542 return
522 } 543 }
523 delay(50) 544 delay(50)
@@ -676,10 +697,11 @@ internal constructor( @@ -676,10 +697,11 @@ internal constructor(
676 // ---------------------------------- SignalClient.Listener --------------------------------------// 697 // ---------------------------------- SignalClient.Listener --------------------------------------//
677 698
678 override fun onAnswer(sessionDescription: SessionDescription) { 699 override fun onAnswer(sessionDescription: SessionDescription) {
679 - LKLog.v { "received server answer: ${sessionDescription.type}, ${publisher.peerConnection.signalingState()}" } 700 + val signalingState = runBlocking { publisher?.signalingState() }
  701 + LKLog.v { "received server answer: ${sessionDescription.type}, $signalingState" }
680 coroutineScope.launch { 702 coroutineScope.launch {
681 LKLog.i { sessionDescription.toString() } 703 LKLog.i { sessionDescription.toString() }
682 - when (val outcome = publisher.setRemoteDescription(sessionDescription)) { 704 + when (val outcome = publisher?.setRemoteDescription(sessionDescription)) {
683 is Either.Left -> { 705 is Either.Left -> {
684 // do nothing. 706 // do nothing.
685 } 707 }
@@ -687,49 +709,71 @@ internal constructor( @@ -687,49 +709,71 @@ internal constructor(
687 is Either.Right -> { 709 is Either.Right -> {
688 LKLog.e { "error setting remote description for answer: ${outcome.value} " } 710 LKLog.e { "error setting remote description for answer: ${outcome.value} " }
689 } 711 }
  712 +
  713 + else -> {
  714 + LKLog.w { "publisher is null, can't set remote description." }
  715 + }
690 } 716 }
691 } 717 }
692 } 718 }
693 719
694 override fun onOffer(sessionDescription: SessionDescription) { 720 override fun onOffer(sessionDescription: SessionDescription) {
695 - LKLog.v { "received server offer: ${sessionDescription.type}, ${subscriber.peerConnection.signalingState()}" } 721 + val signalingState = runBlocking { publisher?.signalingState() }
  722 + LKLog.v { "received server offer: ${sessionDescription.type}, $signalingState" }
696 coroutineScope.launch { 723 coroutineScope.launch {
697 - run<Unit> {  
698 - when (  
699 - val outcome =  
700 - subscriber.setRemoteDescription(sessionDescription)  
701 - ) {  
702 - is Either.Right -> {  
703 - LKLog.e { "error setting remote description for answer: ${outcome.value} " }  
704 - return@launch 724 + // TODO: This is a potentially very long lock hold. May need to break up.
  725 + val answer = subscriber?.withPeerConnection {
  726 + run {
  727 + when (
  728 + val outcome =
  729 + subscriber?.setRemoteDescription(sessionDescription)
  730 + ) {
  731 + is Either.Right -> {
  732 + LKLog.e { "error setting remote description for answer: ${outcome.value} " }
  733 + return@withPeerConnection null
  734 + }
  735 +
  736 + else -> {}
705 } 737 }
  738 + }
706 739
707 - else -> {} 740 + if (isClosed) {
  741 + return@withPeerConnection null
708 } 742 }
709 - }  
710 743
711 - val answer = run {  
712 - when (val outcome = subscriber.peerConnection.createAnswer(MediaConstraints())) {  
713 - is Either.Left -> outcome.value  
714 - is Either.Right -> {  
715 - LKLog.e { "error creating answer: ${outcome.value}" }  
716 - return@launch 744 + val answer = run {
  745 + when (val outcome = createAnswer(MediaConstraints())) {
  746 + is Either.Left -> outcome.value
  747 + is Either.Right -> {
  748 + LKLog.e { "error creating answer: ${outcome.value}" }
  749 + return@withPeerConnection null
  750 + }
717 } 751 }
718 } 752 }
719 - }  
720 753
721 - run<Unit> {  
722 - when (val outcome = subscriber.peerConnection.setLocalDescription(answer)) {  
723 - is Either.Right -> {  
724 - LKLog.e { "error setting local description for answer: ${outcome.value}" }  
725 - return@launch 754 + if (isClosed) {
  755 + return@withPeerConnection null
  756 + }
  757 +
  758 + run<Unit> {
  759 + when (val outcome = setLocalDescription(answer)) {
  760 + is Either.Right -> {
  761 + LKLog.e { "error setting local description for answer: ${outcome.value}" }
  762 + return@withPeerConnection null
  763 + }
  764 +
  765 + else -> {}
726 } 766 }
  767 + }
727 768
728 - else -> {} 769 + if (isClosed) {
  770 + return@withPeerConnection null
729 } 771 }
  772 + return@withPeerConnection answer
  773 + }
  774 + answer?.let {
  775 + client.sendAnswer(it)
730 } 776 }
731 -  
732 - client.sendAnswer(answer)  
733 } 777 }
734 } 778 }
735 779
@@ -737,14 +781,15 @@ internal constructor( @@ -737,14 +781,15 @@ internal constructor(
737 LKLog.v { "received ice candidate from peer: $candidate, $target" } 781 LKLog.v { "received ice candidate from peer: $candidate, $target" }
738 when (target) { 782 when (target) {
739 LivekitRtc.SignalTarget.PUBLISHER -> { 783 LivekitRtc.SignalTarget.PUBLISHER -> {
740 - if (_publisher != null) {  
741 - publisher.addIceCandidate(candidate)  
742 - } else {  
743 - LKLog.w { "received candidate for publisher when we don't have one. ignoring." }  
744 - } 784 + publisher?.addIceCandidate(candidate)
  785 + ?: LKLog.w { "received candidate for publisher when we don't have one. ignoring." }
  786 + }
  787 +
  788 + LivekitRtc.SignalTarget.SUBSCRIBER -> {
  789 + subscriber?.addIceCandidate(candidate)
  790 + ?: LKLog.w { "received candidate for subscriber when we don't have one. ignoring." }
745 } 791 }
746 792
747 - LivekitRtc.SignalTarget.SUBSCRIBER -> subscriber.addIceCandidate(candidate)  
748 else -> LKLog.i { "unknown ice candidate target?" } 793 else -> LKLog.i { "unknown ice candidate target?" }
749 } 794 }
750 } 795 }
@@ -866,7 +911,9 @@ internal constructor( @@ -866,7 +911,9 @@ internal constructor(
866 subscription: LivekitRtc.UpdateSubscription, 911 subscription: LivekitRtc.UpdateSubscription,
867 publishedTracks: List<LivekitRtc.TrackPublishedResponse>, 912 publishedTracks: List<LivekitRtc.TrackPublishedResponse>,
868 ) { 913 ) {
869 - val answer = subscriber.peerConnection.localDescription?.toProtoSessionDescription() 914 + val answer = runBlocking {
  915 + subscriber?.withPeerConnection { localDescription?.toProtoSessionDescription() }
  916 + }
870 917
871 val dataChannelInfos = LivekitModels.DataPacket.Kind.values() 918 val dataChannelInfos = LivekitModels.DataPacket.Kind.values()
872 .toList() 919 .toList()
@@ -892,12 +939,66 @@ internal constructor( @@ -892,12 +939,66 @@ internal constructor(
892 } 939 }
893 940
894 fun getPublisherRTCStats(callback: RTCStatsCollectorCallback) { 941 fun getPublisherRTCStats(callback: RTCStatsCollectorCallback) {
895 - _publisher?.peerConnection?.getStats(callback) ?: callback.onStatsDelivered(RTCStatsReport(0, emptyMap())) 942 + runBlocking {
  943 + publisher?.withPeerConnection { getStats(callback) }
  944 + ?: callback.onStatsDelivered(RTCStatsReport(0, emptyMap()))
  945 + }
896 } 946 }
897 947
898 fun getSubscriberRTCStats(callback: RTCStatsCollectorCallback) { 948 fun getSubscriberRTCStats(callback: RTCStatsCollectorCallback) {
899 - _subscriber?.peerConnection?.getStats(callback) ?: callback.onStatsDelivered(RTCStatsReport(0, emptyMap())) 949 + runBlocking {
  950 + subscriber?.withPeerConnection { getStats(callback) }
  951 + ?: callback.onStatsDelivered(RTCStatsReport(0, emptyMap()))
  952 + }
900 } 953 }
  954 +
  955 + fun createStatsGetter(sender: RtpSender): RTCStatsGetter {
  956 + val p = publisher
  957 + return { statsCallback: RTCStatsCollectorCallback ->
  958 + runBlocking {
  959 + p?.withPeerConnection {
  960 + getStats(sender, statsCallback)
  961 + } ?: statsCallback.onStatsDelivered(RTCStatsReport(0, emptyMap()))
  962 + }
  963 + }
  964 + }
  965 +
  966 + fun createStatsGetter(receiver: RtpReceiver): RTCStatsGetter {
  967 + val p = subscriber
  968 + return { statsCallback: RTCStatsCollectorCallback ->
  969 + runBlocking {
  970 + p?.withPeerConnection {
  971 + getStats(receiver, statsCallback)
  972 + } ?: statsCallback.onStatsDelivered(RTCStatsReport(0, emptyMap()))
  973 + }
  974 + }
  975 + }
  976 +
  977 + internal fun registerTrackBitrateInfo(cid: String, trackBitrateInfo: TrackBitrateInfo) {
  978 + publisher?.registerTrackBitrateInfo(cid, trackBitrateInfo)
  979 + }
  980 +
  981 + internal fun removeTrack(rtcTrack: MediaStreamTrack) {
  982 + runBlocking {
  983 + publisher?.withPeerConnection {
  984 + val senders = this.senders
  985 + for (sender in senders) {
  986 + val t = sender.track() ?: continue
  987 + if (t.id() == rtcTrack.id()) {
  988 + this@withPeerConnection.removeTrack(sender)
  989 + }
  990 + }
  991 + }
  992 + }
  993 + }
  994 +
  995 + @VisibleForTesting
  996 + internal suspend fun getPublisherPeerConnection() =
  997 + publisher?.withPeerConnection { this }!!
  998 +
  999 + @VisibleForTesting
  1000 + internal suspend fun getSubscriberPeerConnection() =
  1001 + subscriber?.withPeerConnection { this }!!
901 } 1002 }
902 1003
903 /** 1004 /**
@@ -42,7 +42,6 @@ import io.livekit.android.util.FlowObservable @@ -42,7 +42,6 @@ import io.livekit.android.util.FlowObservable
42 import io.livekit.android.util.LKLog 42 import io.livekit.android.util.LKLog
43 import io.livekit.android.util.flowDelegate 43 import io.livekit.android.util.flowDelegate
44 import io.livekit.android.util.invoke 44 import io.livekit.android.util.invoke
45 -import io.livekit.android.webrtc.createStatsGetter  
46 import io.livekit.android.webrtc.getFilteredStats 45 import io.livekit.android.webrtc.getFilteredStats
47 import kotlinx.coroutines.* 46 import kotlinx.coroutines.*
48 import livekit.LivekitModels 47 import livekit.LivekitModels
@@ -708,7 +707,7 @@ constructor( @@ -708,7 +707,7 @@ constructor(
708 trackSid = track.id() 707 trackSid = track.id()
709 } 708 }
710 val participant = getOrCreateRemoteParticipant(participantSid) 709 val participant = getOrCreateRemoteParticipant(participantSid)
711 - val statsGetter = createStatsGetter(engine.subscriber.peerConnection, receiver) 710 + val statsGetter = engine.createStatsGetter(receiver)
712 participant.addSubscribedMediaTrack( 711 participant.addSubscribedMediaTrack(
713 track, 712 track,
714 trackSid!!, 713 trackSid!!,
@@ -88,6 +88,8 @@ constructor( @@ -88,6 +88,8 @@ constructor(
88 private var requestFlowJob: Job? = null 88 private var requestFlowJob: Job? = null
89 private val requestFlow = MutableSharedFlow<LivekitRtc.SignalRequest>(Int.MAX_VALUE) 89 private val requestFlow = MutableSharedFlow<LivekitRtc.SignalRequest>(Int.MAX_VALUE)
90 90
  91 + private val responseFlowJobLock = Object()
  92 + private var responseFlowJob: Job? = null
91 private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE) 93 private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE)
92 94
93 private var pingJob: Job? = null 95 private var pingJob: Job? = null
@@ -202,10 +204,17 @@ constructor( @@ -202,10 +204,17 @@ constructor(
202 * Should be called after resolving the join message. 204 * Should be called after resolving the join message.
203 */ 205 */
204 fun onReadyForResponses() { 206 fun onReadyForResponses() {
205 - coroutineScope.launch {  
206 - responseFlow.collect {  
207 - responseFlow.resetReplayCache()  
208 - handleSignalResponseImpl(it) 207 + if (responseFlowJob != null) {
  208 + return
  209 + }
  210 + synchronized(responseFlowJobLock) {
  211 + if (responseFlowJob == null) {
  212 + responseFlowJob = coroutineScope.launch {
  213 + responseFlow.collect {
  214 + responseFlow.resetReplayCache()
  215 + handleSignalResponseImpl(it)
  216 + }
  217 + }
209 } 218 }
210 } 219 }
211 } 220 }
@@ -378,7 +387,7 @@ constructor( @@ -378,7 +387,7 @@ constructor(
378 type: LivekitModels.TrackType, 387 type: LivekitModels.TrackType,
379 builder: LivekitRtc.AddTrackRequest.Builder = LivekitRtc.AddTrackRequest.newBuilder(), 388 builder: LivekitRtc.AddTrackRequest.Builder = LivekitRtc.AddTrackRequest.newBuilder(),
380 ) { 389 ) {
381 - var encryptionType = lastRoomOptions?.e2eeOptions?.encryptionType ?: LivekitModels.Encryption.Type.NONE 390 + val encryptionType = lastRoomOptions?.e2eeOptions?.encryptionType ?: LivekitModels.Encryption.Type.NONE
382 val addTrackRequest = builder 391 val addTrackRequest = builder
383 .setCid(cid) 392 .setCid(cid)
384 .setName(name) 393 .setName(name)
@@ -731,6 +740,8 @@ constructor( @@ -731,6 +740,8 @@ constructor(
731 } 740 }
732 requestFlowJob?.cancel() 741 requestFlowJob?.cancel()
733 requestFlowJob = null 742 requestFlowJob = null
  743 + responseFlowJob?.cancel()
  744 + responseFlowJob = null
734 pingJob?.cancel() 745 pingJob?.cancel()
735 pingJob = null 746 pingJob = null
736 pongJob?.cancel() 747 pongJob?.cancel()
@@ -34,7 +34,6 @@ import io.livekit.android.room.isSVCCodec @@ -34,7 +34,6 @@ import io.livekit.android.room.isSVCCodec
34 import io.livekit.android.room.track.* 34 import io.livekit.android.room.track.*
35 import io.livekit.android.room.util.EncodingUtils 35 import io.livekit.android.room.util.EncodingUtils
36 import io.livekit.android.util.LKLog 36 import io.livekit.android.util.LKLog
37 -import io.livekit.android.webrtc.createStatsGetter  
38 import io.livekit.android.webrtc.sortVideoCodecPreferences 37 import io.livekit.android.webrtc.sortVideoCodecPreferences
39 import kotlinx.coroutines.CoroutineDispatcher 38 import kotlinx.coroutines.CoroutineDispatcher
40 import kotlinx.coroutines.launch 39 import kotlinx.coroutines.launch
@@ -393,12 +392,12 @@ internal constructor( @@ -393,12 +392,12 @@ internal constructor(
393 return false 392 return false
394 } 393 }
395 394
396 - track.statsGetter = createStatsGetter(engine.publisher.peerConnection, transceiver.sender) 395 + track.statsGetter = engine.createStatsGetter(transceiver.sender)
397 396
398 // Handle trackBitrates 397 // Handle trackBitrates
399 if (encodings.isNotEmpty()) { 398 if (encodings.isNotEmpty()) {
400 if (options is VideoTrackPublishOptions && isSVCCodec(options.videoCodec) && encodings.firstOrNull()?.maxBitrateBps != null) { 399 if (options is VideoTrackPublishOptions && isSVCCodec(options.videoCodec) && encodings.firstOrNull()?.maxBitrateBps != null) {
401 - engine.publisher.registerTrackBitrateInfo( 400 + engine.registerTrackBitrateInfo(
402 cid = cid, 401 cid = cid,
403 TrackBitrateInfo( 402 TrackBitrateInfo(
404 codec = options.videoCodec, 403 codec = options.videoCodec,
@@ -556,13 +555,7 @@ internal constructor( @@ -556,13 +555,7 @@ internal constructor(
556 tracks = tracks.toMutableMap().apply { remove(sid) } 555 tracks = tracks.toMutableMap().apply { remove(sid) }
557 556
558 if (engine.connectionState == ConnectionState.CONNECTED) { 557 if (engine.connectionState == ConnectionState.CONNECTED) {
559 - val senders = engine.publisher.peerConnection.senders  
560 - for (sender in senders) {  
561 - val t = sender.track() ?: continue  
562 - if (t.id() == track.rtcTrack.id()) {  
563 - engine.publisher.peerConnection.removeTrack(sender)  
564 - }  
565 - } 558 + engine.removeTrack(track.rtcTrack)
566 } 559 }
567 if (stopOnUnpublish) { 560 if (stopOnUnpublish) {
568 track.stop() 561 track.stop()
@@ -17,6 +17,10 @@ @@ -17,6 +17,10 @@
17 package io.livekit.android.room.util 17 package io.livekit.android.room.util
18 18
19 import io.livekit.android.util.Either 19 import io.livekit.android.util.Either
  20 +import kotlinx.coroutines.runBlocking
  21 +import kotlinx.coroutines.suspendCancellableCoroutine
  22 +import kotlinx.coroutines.sync.Mutex
  23 +import kotlinx.coroutines.sync.withLock
20 import org.webrtc.MediaConstraints 24 import org.webrtc.MediaConstraints
21 import org.webrtc.PeerConnection 25 import org.webrtc.PeerConnection
22 import org.webrtc.SdpObserver 26 import org.webrtc.SdpObserver
@@ -26,26 +30,47 @@ import kotlin.coroutines.resume @@ -26,26 +30,47 @@ import kotlin.coroutines.resume
26 import kotlin.coroutines.suspendCoroutine 30 import kotlin.coroutines.suspendCoroutine
27 31
28 open class CoroutineSdpObserver : SdpObserver { 32 open class CoroutineSdpObserver : SdpObserver {
  33 +
  34 + private val stateLock = Mutex()
29 private var createOutcome: Either<SessionDescription, String?>? = null 35 private var createOutcome: Either<SessionDescription, String?>? = null
30 set(value) { 36 set(value) {
31 - field = value 37 + val conts = runBlocking {
  38 + stateLock.withLock {
  39 + field = value
  40 + if (value != null) {
  41 + val conts = pendingCreate.toList()
  42 + pendingCreate.clear()
  43 + conts
  44 + } else {
  45 + null
  46 + }
  47 + }
  48 + }
32 if (value != null) { 49 if (value != null) {
33 - val conts = pendingCreate.toList()  
34 - pendingCreate.clear()  
35 - conts.forEach { 50 + conts?.forEach {
36 it.resume(value) 51 it.resume(value)
37 } 52 }
38 } 53 }
39 } 54 }
  55 +
40 private var pendingCreate = mutableListOf<Continuation<Either<SessionDescription, String?>>>() 56 private var pendingCreate = mutableListOf<Continuation<Either<SessionDescription, String?>>>()
41 57
42 private var setOutcome: Either<Unit, String?>? = null 58 private var setOutcome: Either<Unit, String?>? = null
43 set(value) { 59 set(value) {
44 - field = value 60 + val conts = runBlocking {
  61 + stateLock.withLock {
  62 + field = value
  63 + if (value != null) {
  64 + val conts = pendingSets.toList()
  65 + pendingSets.clear()
  66 + conts
  67 + } else {
  68 + null
  69 + }
  70 + }
  71 + }
45 if (value != null) { 72 if (value != null) {
46 - val conts = pendingSets.toList()  
47 - pendingSets.clear()  
48 - conts.forEach { 73 + conts?.forEach {
49 it.resume(value) 74 it.resume(value)
50 } 75 }
51 } 76 }
@@ -72,21 +97,41 @@ open class CoroutineSdpObserver : SdpObserver { @@ -72,21 +97,41 @@ open class CoroutineSdpObserver : SdpObserver {
72 setOutcome = Either.Right(message) 97 setOutcome = Either.Right(message)
73 } 98 }
74 99
75 - suspend fun awaitCreate() = suspendCoroutine { cont ->  
76 - val curOutcome = createOutcome  
77 - if (curOutcome != null) {  
78 - cont.resume(curOutcome) 100 + suspend fun awaitCreate() = suspendCancellableCoroutine { cont ->
  101 + val unlockedOutcome = createOutcome
  102 + if (unlockedOutcome != null) {
  103 + cont.resume(unlockedOutcome)
79 } else { 104 } else {
80 - pendingCreate.add(cont) 105 + runBlocking {
  106 + stateLock.lock()
  107 + val lockedOutcome = createOutcome
  108 + if (lockedOutcome != null) {
  109 + stateLock.unlock()
  110 + cont.resume(lockedOutcome)
  111 + } else {
  112 + pendingCreate.add(cont)
  113 + stateLock.unlock()
  114 + }
  115 + }
81 } 116 }
82 } 117 }
83 118
84 suspend fun awaitSet() = suspendCoroutine { cont -> 119 suspend fun awaitSet() = suspendCoroutine { cont ->
85 - val curOutcome = setOutcome  
86 - if (curOutcome != null) {  
87 - cont.resume(curOutcome) 120 + val unlockedOutcome = setOutcome
  121 + if (unlockedOutcome != null) {
  122 + cont.resume(unlockedOutcome)
88 } else { 123 } else {
89 - pendingSets.add(cont) 124 + runBlocking {
  125 + stateLock.lock()
  126 + val lockedOutcome = setOutcome
  127 + if (lockedOutcome != null) {
  128 + stateLock.unlock()
  129 + cont.resume(lockedOutcome)
  130 + } else {
  131 + pendingSets.add(cont)
  132 + stateLock.unlock()
  133 + }
  134 + }
90 } 135 }
91 } 136 }
92 } 137 }
@@ -19,12 +19,9 @@ package io.livekit.android.webrtc @@ -19,12 +19,9 @@ package io.livekit.android.webrtc
19 import io.livekit.android.util.LKLog 19 import io.livekit.android.util.LKLog
20 import kotlinx.coroutines.suspendCancellableCoroutine 20 import kotlinx.coroutines.suspendCancellableCoroutine
21 import org.webrtc.MediaStreamTrack 21 import org.webrtc.MediaStreamTrack
22 -import org.webrtc.PeerConnection  
23 import org.webrtc.RTCStats 22 import org.webrtc.RTCStats
24 import org.webrtc.RTCStatsCollectorCallback 23 import org.webrtc.RTCStatsCollectorCallback
25 import org.webrtc.RTCStatsReport 24 import org.webrtc.RTCStatsReport
26 -import org.webrtc.RtpReceiver  
27 -import org.webrtc.RtpSender  
28 import kotlin.coroutines.resume 25 import kotlin.coroutines.resume
29 26
30 /** 27 /**
@@ -174,13 +171,3 @@ suspend fun RTCStatsGetter.getStats(): RTCStatsReport = suspendCancellableCorout @@ -174,13 +171,3 @@ suspend fun RTCStatsGetter.getStats(): RTCStatsReport = suspendCancellableCorout
174 } 171 }
175 this.invoke(listener) 172 this.invoke(listener)
176 } 173 }
177 -  
178 -fun createStatsGetter(peerConnection: PeerConnection, sender: RtpSender): RTCStatsGetter =  
179 - { statsCallback: RTCStatsCollectorCallback ->  
180 - peerConnection.getStats(sender, statsCallback)  
181 - }  
182 -  
183 -fun createStatsGetter(peerConnection: PeerConnection, receiver: RtpReceiver): RTCStatsGetter =  
184 - { statsCallback: RTCStatsCollectorCallback ->  
185 - peerConnection.getStats(receiver, statsCallback)  
186 - }  
@@ -24,7 +24,6 @@ import io.livekit.android.mock.MockWebSocketFactory @@ -24,7 +24,6 @@ import io.livekit.android.mock.MockWebSocketFactory
24 import io.livekit.android.mock.dagger.DaggerTestLiveKitComponent 24 import io.livekit.android.mock.dagger.DaggerTestLiveKitComponent
25 import io.livekit.android.mock.dagger.TestCoroutinesModule 25 import io.livekit.android.mock.dagger.TestCoroutinesModule
26 import io.livekit.android.mock.dagger.TestLiveKitComponent 26 import io.livekit.android.mock.dagger.TestLiveKitComponent
27 -import io.livekit.android.room.PeerConnectionTransport  
28 import io.livekit.android.room.Room 27 import io.livekit.android.room.Room
29 import io.livekit.android.room.SignalClientTest 28 import io.livekit.android.room.SignalClientTest
30 import io.livekit.android.util.toOkioByteString 29 import io.livekit.android.util.toOkioByteString
@@ -45,7 +44,6 @@ abstract class MockE2ETest : BaseTest() { @@ -45,7 +44,6 @@ abstract class MockE2ETest : BaseTest() {
45 internal lateinit var context: Context 44 internal lateinit var context: Context
46 internal lateinit var room: Room 45 internal lateinit var room: Room
47 internal lateinit var wsFactory: MockWebSocketFactory 46 internal lateinit var wsFactory: MockWebSocketFactory
48 - internal lateinit var subscriber: PeerConnectionTransport  
49 47
50 @Before 48 @Before
51 fun mocksSetup() { 49 fun mocksSetup() {
@@ -77,16 +75,26 @@ abstract class MockE2ETest : BaseTest() { @@ -77,16 +75,26 @@ abstract class MockE2ETest : BaseTest() {
77 job.join() 75 job.join()
78 } 76 }
79 77
80 - fun connectPeerConnection() {  
81 - subscriber = component.rtcEngine().subscriber 78 + suspend fun getSubscriberPeerConnection() =
  79 + component
  80 + .rtcEngine()
  81 + .getSubscriberPeerConnection() as MockPeerConnection
  82 +
  83 + suspend fun getPublisherPeerConnection() =
  84 + component
  85 + .rtcEngine()
  86 + .getPublisherPeerConnection() as MockPeerConnection
  87 +
  88 + suspend fun connectPeerConnection() {
82 simulateMessageFromServer(SignalClientTest.OFFER) 89 simulateMessageFromServer(SignalClientTest.OFFER)
83 - val subPeerConnection = subscriber.peerConnection as MockPeerConnection 90 + val subPeerConnection = getSubscriberPeerConnection()
84 subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.CONNECTED) 91 subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.CONNECTED)
85 } 92 }
86 93
87 - fun disconnectPeerConnection() {  
88 - subscriber = component.rtcEngine().subscriber  
89 - val subPeerConnection = subscriber.peerConnection as MockPeerConnection 94 + suspend fun disconnectPeerConnection() {
  95 + val subPeerConnection = component
  96 + .rtcEngine()
  97 + .getSubscriberPeerConnection() as MockPeerConnection
90 subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED) 98 subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
91 } 99 }
92 100
@@ -16,10 +16,14 @@ @@ -16,10 +16,14 @@
16 16
17 package io.livekit.android.mock 17 package io.livekit.android.mock
18 18
  19 +import com.google.protobuf.MessageLite
19 import io.livekit.android.util.toOkioByteString 20 import io.livekit.android.util.toOkioByteString
20 import io.livekit.android.util.toPBByteString 21 import io.livekit.android.util.toPBByteString
21 import livekit.LivekitModels 22 import livekit.LivekitModels
22 -import livekit.LivekitRtc 23 +import livekit.LivekitRtc.LeaveRequest
  24 +import livekit.LivekitRtc.SignalRequest
  25 +import livekit.LivekitRtc.SignalResponse
  26 +import livekit.LivekitRtc.TrackPublishedResponse
23 import okhttp3.Request 27 import okhttp3.Request
24 import okhttp3.WebSocket 28 import okhttp3.WebSocket
25 import okhttp3.WebSocketListener 29 import okhttp3.WebSocketListener
@@ -42,34 +46,80 @@ class MockWebSocketFactory : WebSocket.Factory { @@ -42,34 +46,80 @@ class MockWebSocketFactory : WebSocket.Factory {
42 lateinit var listener: WebSocketListener 46 lateinit var listener: WebSocketListener
43 override fun newWebSocket(request: Request, listener: WebSocketListener): WebSocket { 47 override fun newWebSocket(request: Request, listener: WebSocketListener): WebSocket {
44 this.ws = MockWebSocket(request, listener) { byteString -> 48 this.ws = MockWebSocket(request, listener) { byteString ->
45 - val signalRequest = LivekitRtc.SignalRequest.parseFrom(byteString.toPBByteString())  
46 - if (signalRequest.hasAddTrack()) { 49 + val signalRequest = SignalRequest.parseFrom(byteString.toPBByteString())
  50 + handleSignalRequest(signalRequest)
  51 + }
  52 + this.listener = listener
  53 + this.request = request
  54 +
  55 + onOpen?.invoke(this)
  56 + return ws
  57 + }
  58 +
  59 + private val signalRequestHandlers = mutableListOf<SignalRequestHandler>(
  60 + { signalRequest -> defaultHandleSignalRequest(signalRequest) },
  61 + )
  62 +
  63 + fun registerSignalRequestHandler(handler: SignalRequestHandler) {
  64 + signalRequestHandlers.add(0, handler)
  65 + }
  66 +
  67 + private fun handleSignalRequest(signalRequest: SignalRequest) {
  68 + for (handler in signalRequestHandlers) {
  69 + if (handler.invoke(signalRequest)) {
  70 + break
  71 + }
  72 + }
  73 + }
  74 +
  75 + private fun defaultHandleSignalRequest(signalRequest: SignalRequest): Boolean {
  76 + when (signalRequest.messageCase) {
  77 + SignalRequest.MessageCase.ADD_TRACK -> {
47 val addTrack = signalRequest.addTrack 78 val addTrack = signalRequest.addTrack
48 - val trackPublished = with(LivekitRtc.SignalResponse.newBuilder()) {  
49 - trackPublished = with(LivekitRtc.TrackPublishedResponse.newBuilder()) { 79 + val trackPublished = with(SignalResponse.newBuilder()) {
  80 + trackPublished = with(TrackPublishedResponse.newBuilder()) {
50 cid = addTrack.cid 81 cid = addTrack.cid
51 - if (addTrack.type == LivekitModels.TrackType.AUDIO) {  
52 - track = TestData.LOCAL_AUDIO_TRACK 82 + track = if (addTrack.type == LivekitModels.TrackType.AUDIO) {
  83 + TestData.LOCAL_AUDIO_TRACK
53 } else { 84 } else {
54 - track = TestData.LOCAL_VIDEO_TRACK 85 + TestData.LOCAL_VIDEO_TRACK
55 } 86 }
56 build() 87 build()
57 } 88 }
58 build() 89 build()
59 } 90 }
60 - this.listener.onMessage(this.ws, trackPublished.toOkioByteString()) 91 + receiveMessage(trackPublished)
  92 + return true
61 } 93 }
62 - }  
63 - this.listener = listener  
64 - this.request = request  
65 94
66 - onOpen?.invoke(this)  
67 - return ws 95 + SignalRequest.MessageCase.LEAVE -> {
  96 + val leaveResponse = with(SignalResponse.newBuilder()) {
  97 + leave = with(LeaveRequest.newBuilder()) {
  98 + canReconnect = false
  99 + reason = LivekitModels.DisconnectReason.CLIENT_INITIATED
  100 + build()
  101 + }
  102 + build()
  103 + }
  104 + receiveMessage(leaveResponse)
  105 + return true
  106 + }
  107 +
  108 + else -> {
  109 + return false
  110 + }
  111 + }
68 } 112 }
69 113
70 var onOpen: ((MockWebSocketFactory) -> Unit)? = null 114 var onOpen: ((MockWebSocketFactory) -> Unit)? = null
71 115
  116 + fun receiveMessage(message: MessageLite) {
  117 + receiveMessage(message.toOkioByteString())
  118 + }
  119 +
72 fun receiveMessage(byteString: ByteString) { 120 fun receiveMessage(byteString: ByteString) {
73 listener.onMessage(ws, byteString) 121 listener.onMessage(ws, byteString)
74 } 122 }
75 } 123 }
  124 +
  125 +typealias SignalRequestHandler = (SignalRequest) -> Boolean
@@ -17,7 +17,6 @@ @@ -17,7 +17,6 @@
17 package io.livekit.android.room 17 package io.livekit.android.room
18 18
19 import io.livekit.android.MockE2ETest 19 import io.livekit.android.MockE2ETest
20 -import io.livekit.android.mock.MockPeerConnection  
21 import io.livekit.android.util.toOkioByteString 20 import io.livekit.android.util.toOkioByteString
22 import io.livekit.android.util.toPBByteString 21 import io.livekit.android.util.toPBByteString
23 import kotlinx.coroutines.ExperimentalCoroutinesApi 22 import kotlinx.coroutines.ExperimentalCoroutinesApi
@@ -47,7 +46,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -47,7 +46,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
47 connect() 46 connect()
48 val sentIceServers = SignalClientTest.JOIN.join.iceServersList 47 val sentIceServers = SignalClientTest.JOIN.join.iceServersList
49 .map { it.toWebrtc() } 48 .map { it.toWebrtc() }
50 - val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection 49 + val subPeerConnection = getSubscriberPeerConnection()
51 50
52 assertEquals(sentIceServers, subPeerConnection.rtcConfig.iceServers) 51 assertEquals(sentIceServers, subPeerConnection.rtcConfig.iceServers)
53 } 52 }
@@ -57,7 +56,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -57,7 +56,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
57 connect() 56 connect()
58 assertEquals( 57 assertEquals(
59 SignalClientTest.OFFER.offer.sdp, 58 SignalClientTest.OFFER.offer.sdp,
60 - rtcEngine.subscriber.peerConnection.remoteDescription.description, 59 + getSubscriberPeerConnection().remoteDescription?.description,
61 ) 60 )
62 61
63 val ws = wsFactory.ws 62 val ws = wsFactory.ws
@@ -65,7 +64,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -65,7 +64,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
65 .mergeFrom(ws.sentRequests[0].toPBByteString()) 64 .mergeFrom(ws.sentRequests[0].toPBByteString())
66 .build() 65 .build()
67 66
68 - val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection 67 + val subPeerConnection = getSubscriberPeerConnection()
69 val localAnswer = subPeerConnection.localDescription ?: throw IllegalStateException("no answer was created.") 68 val localAnswer = subPeerConnection.localDescription ?: throw IllegalStateException("no answer was created.")
70 Assert.assertTrue(sentRequest.hasAnswer()) 69 Assert.assertTrue(sentRequest.hasAnswer())
71 assertEquals(localAnswer.description, sentRequest.answer.sdp) 70 assertEquals(localAnswer.description, sentRequest.answer.sdp)
@@ -88,7 +87,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -88,7 +87,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
88 connect() 87 connect()
89 val oldWs = wsFactory.ws 88 val oldWs = wsFactory.ws
90 89
91 - val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection 90 + val subPeerConnection = getSubscriberPeerConnection()
92 subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED) 91 subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
93 testScheduler.advanceTimeBy(1000) 92 testScheduler.advanceTimeBy(1000)
94 93
@@ -101,7 +100,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -101,7 +100,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
101 connect() 100 connect()
102 val oldWs = wsFactory.ws 101 val oldWs = wsFactory.ws
103 102
104 - val pubPeerConnection = rtcEngine.publisher.peerConnection as MockPeerConnection 103 + val pubPeerConnection = getPublisherPeerConnection()
105 pubPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED) 104 pubPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
106 testScheduler.advanceTimeBy(1000) 105 testScheduler.advanceTimeBy(1000)
107 106
@@ -138,7 +137,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -138,7 +137,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
138 }, 137 },
139 ) 138 )
140 139
141 - val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection 140 + val subPeerConnection = getSubscriberPeerConnection()
142 assertEquals(PeerConnection.IceTransportsType.RELAY, subPeerConnection.rtcConfig.iceTransportsType) 141 assertEquals(PeerConnection.IceTransportsType.RELAY, subPeerConnection.rtcConfig.iceTransportsType)
143 } 142 }
144 143
@@ -34,8 +34,12 @@ import io.livekit.android.room.track.Track @@ -34,8 +34,12 @@ import io.livekit.android.room.track.Track
34 import io.livekit.android.util.flow 34 import io.livekit.android.util.flow
35 import io.livekit.android.util.toOkioByteString 35 import io.livekit.android.util.toOkioByteString
36 import junit.framework.Assert.assertEquals 36 import junit.framework.Assert.assertEquals
  37 +import kotlinx.coroutines.CoroutineScope
  38 +import kotlinx.coroutines.Dispatchers
37 import kotlinx.coroutines.ExperimentalCoroutinesApi 39 import kotlinx.coroutines.ExperimentalCoroutinesApi
  40 +import kotlinx.coroutines.SupervisorJob
38 import kotlinx.coroutines.launch 41 import kotlinx.coroutines.launch
  42 +import livekit.LivekitRtc
39 import org.junit.Assert 43 import org.junit.Assert
40 import org.junit.Test 44 import org.junit.Test
41 import org.junit.runner.RunWith 45 import org.junit.runner.RunWith
@@ -336,6 +340,52 @@ class RoomMockE2ETest : MockE2ETest() { @@ -336,6 +340,52 @@ class RoomMockE2ETest : MockE2ETest() {
336 ) 340 )
337 341
338 val eventCollector = EventCollector(room.events, coroutineRule.scope) 342 val eventCollector = EventCollector(room.events, coroutineRule.scope)
  343 +
  344 + wsFactory.listener.onMessage(
  345 + wsFactory.ws,
  346 + SignalClientTest.LEAVE.toOkioByteString(),
  347 + )
  348 + room.disconnect()
  349 + val events = eventCollector.stopCollecting()
  350 +
  351 + assertEquals(2, events.size)
  352 + assertEquals(true, events[0] is RoomEvent.TrackUnpublished)
  353 + assertEquals(true, events[1] is RoomEvent.Disconnected)
  354 + }
  355 +
  356 + /**
  357 + *
  358 + */
  359 + @Test
  360 + fun disconnectWithTracks() = runTest {
  361 + connect()
  362 +
  363 + val differentThread = CoroutineScope(Dispatchers.IO + SupervisorJob())
  364 + wsFactory.registerSignalRequestHandler {
  365 + if (it.hasLeave()) {
  366 + differentThread.launch {
  367 + val leaveResponse = with(LivekitRtc.SignalResponse.newBuilder()) {
  368 + leave = with(LivekitRtc.LeaveRequest.newBuilder()) {
  369 + canReconnect = false
  370 + reason = livekit.LivekitModels.DisconnectReason.CLIENT_INITIATED
  371 + build()
  372 + }
  373 + build()
  374 + }
  375 + wsFactory.receiveMessage(leaveResponse)
  376 + }
  377 + return@registerSignalRequestHandler true
  378 + }
  379 + return@registerSignalRequestHandler false
  380 + }
  381 + room.localParticipant.publishAudioTrack(
  382 + LocalAudioTrack(
  383 + "",
  384 + MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
  385 + ),
  386 + )
  387 +
  388 + val eventCollector = EventCollector(room.events, coroutineRule.scope)
339 room.disconnect() 389 room.disconnect()
340 val events = eventCollector.stopCollecting() 390 val events = eventCollector.stopCollecting()
341 391
@@ -18,7 +18,6 @@ package io.livekit.android.room @@ -18,7 +18,6 @@ package io.livekit.android.room
18 18
19 import io.livekit.android.MockE2ETest 19 import io.livekit.android.MockE2ETest
20 import io.livekit.android.mock.MockAudioStreamTrack 20 import io.livekit.android.mock.MockAudioStreamTrack
21 -import io.livekit.android.mock.MockPeerConnection  
22 import io.livekit.android.room.track.LocalAudioTrack 21 import io.livekit.android.room.track.LocalAudioTrack
23 import io.livekit.android.util.toPBByteString 22 import io.livekit.android.util.toPBByteString
24 import kotlinx.coroutines.ExperimentalCoroutinesApi 23 import kotlinx.coroutines.ExperimentalCoroutinesApi
@@ -89,8 +88,7 @@ class RoomReconnectionMockE2ETest : MockE2ETest() { @@ -89,8 +88,7 @@ class RoomReconnectionMockE2ETest : MockE2ETest() {
89 testScheduler.advanceTimeBy(1000) 88 testScheduler.advanceTimeBy(1000)
90 connectPeerConnection() 89 connectPeerConnection()
91 90
92 - val rtcEngine = component.rtcEngine()  
93 - val rtcConfig = (rtcEngine.subscriber.peerConnection as MockPeerConnection).rtcConfig 91 + val rtcConfig = getSubscriberPeerConnection().rtcConfig
94 assertEquals(PeerConnection.IceTransportsType.RELAY, rtcConfig.iceTransportsType) 92 assertEquals(PeerConnection.IceTransportsType.RELAY, rtcConfig.iceTransportsType)
95 93
96 val sentIceServers = SignalClientTest.RECONNECT.reconnect.iceServersList 94 val sentIceServers = SignalClientTest.RECONNECT.reconnect.iceServersList
@@ -23,7 +23,6 @@ import io.livekit.android.events.ParticipantEvent @@ -23,7 +23,6 @@ import io.livekit.android.events.ParticipantEvent
23 import io.livekit.android.events.RoomEvent 23 import io.livekit.android.events.RoomEvent
24 import io.livekit.android.mock.MockAudioStreamTrack 24 import io.livekit.android.mock.MockAudioStreamTrack
25 import io.livekit.android.mock.MockEglBase 25 import io.livekit.android.mock.MockEglBase
26 -import io.livekit.android.mock.MockPeerConnection  
27 import io.livekit.android.mock.MockVideoCapturer 26 import io.livekit.android.mock.MockVideoCapturer
28 import io.livekit.android.mock.MockVideoStreamTrack 27 import io.livekit.android.mock.MockVideoStreamTrack
29 import io.livekit.android.room.DefaultsManager 28 import io.livekit.android.room.DefaultsManager
@@ -176,7 +175,7 @@ class LocalParticipantMockE2ETest : MockE2ETest() { @@ -176,7 +175,7 @@ class LocalParticipantMockE2ETest : MockE2ETest() {
176 175
177 room.localParticipant.publishVideoTrack(track = createLocalTrack()) 176 room.localParticipant.publishVideoTrack(track = createLocalTrack())
178 177
179 - val peerConnection = component.rtcEngine().publisher.peerConnection 178 + val peerConnection = getPublisherPeerConnection()
180 val transceiver = peerConnection.transceivers.first() 179 val transceiver = peerConnection.transceivers.first()
181 180
182 Mockito.verify(transceiver).setCodecPreferences( 181 Mockito.verify(transceiver).setCodecPreferences(
@@ -195,7 +194,7 @@ class LocalParticipantMockE2ETest : MockE2ETest() { @@ -195,7 +194,7 @@ class LocalParticipantMockE2ETest : MockE2ETest() {
195 194
196 room.localParticipant.publishVideoTrack(track = createLocalTrack()) 195 room.localParticipant.publishVideoTrack(track = createLocalTrack())
197 196
198 - val peerConnection = component.rtcEngine().publisher.peerConnection 197 + val peerConnection = getPublisherPeerConnection()
199 val transceiver = peerConnection.transceivers.first() 198 val transceiver = peerConnection.transceivers.first()
200 199
201 Mockito.verify(transceiver).setCodecPreferences( 200 Mockito.verify(transceiver).setCodecPreferences(
@@ -236,7 +235,7 @@ class LocalParticipantMockE2ETest : MockE2ETest() { @@ -236,7 +235,7 @@ class LocalParticipantMockE2ETest : MockE2ETest() {
236 val vp8Codec = addTrackRequest.simulcastCodecsList[1] 235 val vp8Codec = addTrackRequest.simulcastCodecsList[1]
237 assertEquals("vp8", vp8Codec.codec) 236 assertEquals("vp8", vp8Codec.codec)
238 237
239 - val publisherConn = component.rtcEngine().publisher.peerConnection as MockPeerConnection 238 + val publisherConn = getPublisherPeerConnection()
240 239
241 assertEquals(1, publisherConn.transceivers.size) 240 assertEquals(1, publisherConn.transceivers.size)
242 Mockito.verify(publisherConn.transceivers.first()).setCodecPreferences( 241 Mockito.verify(publisherConn.transceivers.first()).setCodecPreferences(