继续操作前请注册或者登录。
davidliu
Committed by GitHub

Fix memory leak caused by disconnecting before connect finished (#386)

* State locking for Room and RTC engine around critical spots

* Cancel connect job if invoking coroutine is cancelled

* cleanup

* Clean up test logs

* revert stress test changes to sample apps
@@ -144,7 +144,7 @@ constructor( @@ -144,7 +144,7 @@ constructor(
144 restartingIce = true 144 restartingIce = true
145 } 145 }
146 146
147 - if (this.peerConnection.signalingState() == SignalingState.HAVE_LOCAL_OFFER) { 147 + if (peerConnection.signalingState() == SignalingState.HAVE_LOCAL_OFFER) {
148 // we're waiting for the peer to accept our offer, so we'll just wait 148 // we're waiting for the peer to accept our offer, so we'll just wait
149 // the only exception to this is when ICE restart is needed 149 // the only exception to this is when ICE restart is needed
150 val curSd = peerConnection.remoteDescription 150 val curSd = peerConnection.remoteDescription
@@ -313,7 +313,7 @@ constructor( @@ -313,7 +313,7 @@ constructor(
313 } 313 }
314 314
315 @OptIn(ExperimentalContracts::class) 315 @OptIn(ExperimentalContracts::class)
316 - private suspend inline fun <T> launchRTCIfNotClosed(noinline action: suspend () -> T): T? { 316 + private suspend inline fun <T> launchRTCIfNotClosed(noinline action: suspend CoroutineScope.() -> T): T? {
317 contract { callsInPlace(action, InvocationKind.AT_MOST_ONCE) } 317 contract { callsInPlace(action, InvocationKind.AT_MOST_ONCE) }
318 if (isClosed()) { 318 if (isClosed()) {
319 return null 319 return null
@@ -35,13 +35,17 @@ import io.livekit.android.util.FlowObservable @@ -35,13 +35,17 @@ import io.livekit.android.util.FlowObservable
35 import io.livekit.android.util.LKLog 35 import io.livekit.android.util.LKLog
36 import io.livekit.android.util.flowDelegate 36 import io.livekit.android.util.flowDelegate
37 import io.livekit.android.util.nullSafe 37 import io.livekit.android.util.nullSafe
  38 +import io.livekit.android.util.withCheckLock
38 import io.livekit.android.webrtc.RTCStatsGetter 39 import io.livekit.android.webrtc.RTCStatsGetter
39 import io.livekit.android.webrtc.copy 40 import io.livekit.android.webrtc.copy
40 import io.livekit.android.webrtc.isConnected 41 import io.livekit.android.webrtc.isConnected
41 import io.livekit.android.webrtc.isDisconnected 42 import io.livekit.android.webrtc.isDisconnected
42 import io.livekit.android.webrtc.peerconnection.executeBlockingOnRTCThread 43 import io.livekit.android.webrtc.peerconnection.executeBlockingOnRTCThread
  44 +import io.livekit.android.webrtc.peerconnection.launchBlockingOnRTCThread
43 import io.livekit.android.webrtc.toProtoSessionDescription 45 import io.livekit.android.webrtc.toProtoSessionDescription
44 import kotlinx.coroutines.* 46 import kotlinx.coroutines.*
  47 +import kotlinx.coroutines.sync.Mutex
  48 +import kotlinx.coroutines.sync.withLock
45 import livekit.LivekitModels 49 import livekit.LivekitModels
46 import livekit.LivekitRtc 50 import livekit.LivekitRtc
47 import livekit.LivekitRtc.JoinResponse 51 import livekit.LivekitRtc.JoinResponse
@@ -134,6 +138,12 @@ internal constructor( @@ -134,6 +138,12 @@ internal constructor(
134 138
135 private var coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher) 139 private var coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher)
136 140
  141 + /**
  142 + * Note: If this lock is ever used in conjunction with the RTC thread,
  143 + * this must be grabbed on the RTC thread to prevent deadlocks.
  144 + */
  145 + private var configurationLock = Mutex()
  146 +
137 init { 147 init {
138 client.listener = this 148 client.listener = this
139 } 149 }
@@ -158,8 +168,10 @@ internal constructor( @@ -158,8 +168,10 @@ internal constructor(
158 token: String, 168 token: String,
159 options: ConnectOptions, 169 options: ConnectOptions,
160 roomOptions: RoomOptions, 170 roomOptions: RoomOptions,
161 - ): JoinResponse { 171 + ): JoinResponse = coroutineScope {
162 val joinResponse = client.join(url, token, options, roomOptions) 172 val joinResponse = client.join(url, token, options, roomOptions)
  173 + ensureActive()
  174 +
163 listener?.onJoinResponse(joinResponse) 175 listener?.onJoinResponse(joinResponse)
164 isClosed = false 176 isClosed = false
165 listener?.onSignalConnected(false) 177 listener?.onSignalConnected(false)
@@ -169,19 +181,25 @@ internal constructor( @@ -169,19 +181,25 @@ internal constructor(
169 configure(joinResponse, options) 181 configure(joinResponse, options)
170 182
171 // create offer 183 // create offer
172 - if (!this.isSubscriberPrimary) { 184 + if (!isSubscriberPrimary) {
173 negotiatePublisher() 185 negotiatePublisher()
174 } 186 }
175 client.onReadyForResponses() 187 client.onReadyForResponses()
176 - return joinResponse 188 +
  189 + return@coroutineScope joinResponse
177 } 190 }
178 191
179 private suspend fun configure(joinResponse: JoinResponse, connectOptions: ConnectOptions) { 192 private suspend fun configure(joinResponse: JoinResponse, connectOptions: ConnectOptions) {
  193 + launchBlockingOnRTCThread {
  194 + configurationLock.withCheckLock(
  195 + {
  196 + ensureActive()
180 if (publisher != null && subscriber != null) { 197 if (publisher != null && subscriber != null) {
181 // already configured 198 // already configured
182 - return 199 + return@launchBlockingOnRTCThread
183 } 200 }
184 - 201 + },
  202 + ) {
185 participantSid = if (joinResponse.hasParticipant()) { 203 participantSid = if (joinResponse.hasParticipant()) {
186 joinResponse.participant.sid 204 joinResponse.participant.sid
187 } else { 205 } else {
@@ -235,6 +253,7 @@ internal constructor( @@ -235,6 +253,7 @@ internal constructor(
235 publisherObserver.connectionChangeListener = connectionStateListener 253 publisherObserver.connectionChangeListener = connectionStateListener
236 } 254 }
237 255
  256 + ensureActive()
238 // data channels 257 // data channels
239 val reliableInit = DataChannel.Init() 258 val reliableInit = DataChannel.Init()
240 reliableInit.ordered = true 259 reliableInit.ordered = true
@@ -247,6 +266,7 @@ internal constructor( @@ -247,6 +266,7 @@ internal constructor(
247 } 266 }
248 } 267 }
249 268
  269 + ensureActive()
250 val lossyInit = DataChannel.Init() 270 val lossyInit = DataChannel.Init()
251 lossyInit.ordered = true 271 lossyInit.ordered = true
252 lossyInit.maxRetransmits = 0 272 lossyInit.maxRetransmits = 0
@@ -259,6 +279,8 @@ internal constructor( @@ -259,6 +279,8 @@ internal constructor(
259 } 279 }
260 } 280 }
261 } 281 }
  282 + }
  283 + }
262 284
263 /** 285 /**
264 * @param builder an optional builder to include other parameters related to the track 286 * @param builder an optional builder to include other parameters related to the track
@@ -327,6 +349,8 @@ internal constructor( @@ -327,6 +349,8 @@ internal constructor(
327 349
328 private fun closeResources(reason: String) { 350 private fun closeResources(reason: String) {
329 executeBlockingOnRTCThread { 351 executeBlockingOnRTCThread {
  352 + runBlocking {
  353 + configurationLock.withLock {
330 publisherObserver.connectionChangeListener = null 354 publisherObserver.connectionChangeListener = null
331 subscriberObserver.connectionChangeListener = null 355 subscriberObserver.connectionChangeListener = null
332 publisher?.closeBlocking() 356 publisher?.closeBlocking()
@@ -339,6 +363,7 @@ internal constructor( @@ -339,6 +363,7 @@ internal constructor(
339 this?.close() 363 this?.close()
340 this?.dispose() 364 this?.dispose()
341 } 365 }
  366 +
342 reliableDataChannel?.completeDispose() 367 reliableDataChannel?.completeDispose()
343 reliableDataChannel = null 368 reliableDataChannel = null
344 reliableDataChannelSub?.completeDispose() 369 reliableDataChannelSub?.completeDispose()
@@ -349,6 +374,8 @@ internal constructor( @@ -349,6 +374,8 @@ internal constructor(
349 lossyDataChannelSub = null 374 lossyDataChannelSub = null
350 isSubscriberPrimary = false 375 isSubscriberPrimary = false
351 } 376 }
  377 + }
  378 + }
352 client.close(reason = reason) 379 client.close(reason = reason)
353 } 380 }
354 381
@@ -49,6 +49,8 @@ import io.livekit.android.webrtc.getFilteredStats @@ -49,6 +49,8 @@ import io.livekit.android.webrtc.getFilteredStats
49 import kotlinx.coroutines.* 49 import kotlinx.coroutines.*
50 import kotlinx.coroutines.flow.filterNotNull 50 import kotlinx.coroutines.flow.filterNotNull
51 import kotlinx.coroutines.flow.first 51 import kotlinx.coroutines.flow.first
  52 +import kotlinx.coroutines.sync.Mutex
  53 +import kotlinx.coroutines.sync.withLock
52 import kotlinx.serialization.Serializable 54 import kotlinx.serialization.Serializable
53 import livekit.LivekitModels 55 import livekit.LivekitModels
54 import livekit.LivekitRtc 56 import livekit.LivekitRtc
@@ -243,6 +245,8 @@ constructor( @@ -243,6 +245,8 @@ constructor(
243 private var hasLostConnectivity: Boolean = false 245 private var hasLostConnectivity: Boolean = false
244 private var connectOptions: ConnectOptions = ConnectOptions() 246 private var connectOptions: ConnectOptions = ConnectOptions()
245 247
  248 + private var stateLock = Mutex()
  249 +
246 private fun getCurrentRoomOptions(): RoomOptions = 250 private fun getCurrentRoomOptions(): RoomOptions =
247 RoomOptions( 251 RoomOptions(
248 adaptiveStream = adaptiveStream, 252 adaptiveStream = adaptiveStream,
@@ -260,15 +264,32 @@ constructor( @@ -260,15 +264,32 @@ constructor(
260 * @param url 264 * @param url
261 * @param token 265 * @param token
262 * @param options 266 * @param options
  267 + *
  268 + * @throws IllegalStateException when connect is attempted while the room is not disconnected.
  269 + * @throws Exception when connection fails
263 */ 270 */
264 @Throws(Exception::class) 271 @Throws(Exception::class)
265 - suspend fun connect(url: String, token: String, options: ConnectOptions = ConnectOptions()) {  
266 - if (this::coroutineScope.isInitialized) { 272 + suspend fun connect(url: String, token: String, options: ConnectOptions = ConnectOptions()) = coroutineScope {
  273 + if (state != State.DISCONNECTED) {
  274 + throw IllegalStateException("Room.connect attempted while room is not disconnected!")
  275 + }
  276 + val roomOptions: RoomOptions
  277 + stateLock.withLock {
  278 + if (state != State.DISCONNECTED) {
  279 + throw IllegalStateException("Room.connect attempted while room is not disconnected!")
  280 + }
  281 + if (::coroutineScope.isInitialized) {
  282 + val job = coroutineScope.coroutineContext.job
267 coroutineScope.cancel() 283 coroutineScope.cancel()
  284 + job.join()
268 } 285 }
  286 +
  287 + state = State.CONNECTING
  288 + connectOptions = options
  289 +
269 coroutineScope = CoroutineScope(defaultDispatcher + SupervisorJob()) 290 coroutineScope = CoroutineScope(defaultDispatcher + SupervisorJob())
270 291
271 - val roomOptions = getCurrentRoomOptions() 292 + roomOptions = getCurrentRoomOptions()
272 293
273 // Setup local participant. 294 // Setup local participant.
274 localParticipant.reinitialize() 295 localParticipant.reinitialize()
@@ -319,9 +340,6 @@ constructor( @@ -319,9 +340,6 @@ constructor(
319 } 340 }
320 } 341 }
321 342
322 - state = State.CONNECTING  
323 - connectOptions = options  
324 -  
325 if (roomOptions.e2eeOptions != null) { 343 if (roomOptions.e2eeOptions != null) {
326 e2eeManager = e2EEManagerFactory.create(roomOptions.e2eeOptions.keyProvider).apply { 344 e2eeManager = e2EEManagerFactory.create(roomOptions.e2eeOptions.keyProvider).apply {
327 setup(this@Room) { event -> 345 setup(this@Room) { event ->
@@ -331,7 +349,14 @@ constructor( @@ -331,7 +349,14 @@ constructor(
331 } 349 }
332 } 350 }
333 } 351 }
  352 + }
334 353
  354 + // Use an empty coroutineExceptionHandler since we want to
  355 + // rethrow all throwables from the connect job.
  356 + val emptyCoroutineExceptionHandler = CoroutineExceptionHandler { _, _ -> }
  357 + val connectJob = coroutineScope.launch(
  358 + ioDispatcher + emptyCoroutineExceptionHandler,
  359 + ) {
335 engine.join(url, token, options, roomOptions) 360 engine.join(url, token, options, roomOptions)
336 val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager 361 val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
337 val networkRequest = NetworkRequest.Builder() 362 val networkRequest = NetworkRequest.Builder()
@@ -339,16 +364,35 @@ constructor( @@ -339,16 +364,35 @@ constructor(
339 .build() 364 .build()
340 cm.registerNetworkCallback(networkRequest, networkCallback) 365 cm.registerNetworkCallback(networkRequest, networkCallback)
341 366
  367 + ensureActive()
342 if (options.audio) { 368 if (options.audio) {
343 val audioTrack = localParticipant.createAudioTrack() 369 val audioTrack = localParticipant.createAudioTrack()
344 localParticipant.publishAudioTrack(audioTrack) 370 localParticipant.publishAudioTrack(audioTrack)
345 } 371 }
  372 + ensureActive()
346 if (options.video) { 373 if (options.video) {
347 val videoTrack = localParticipant.createVideoTrack() 374 val videoTrack = localParticipant.createVideoTrack()
348 localParticipant.publishVideoTrack(videoTrack) 375 localParticipant.publishVideoTrack(videoTrack)
349 } 376 }
350 } 377 }
351 378
  379 + val outerHandler = coroutineContext.job.invokeOnCompletion { cause ->
  380 + // Cancel connect job if invoking coroutine is cancelled.
  381 + if (cause is CancellationException) {
  382 + connectJob.cancel(cause)
  383 + }
  384 + }
  385 +
  386 + var error: Throwable? = null
  387 + connectJob.invokeOnCompletion { cause ->
  388 + outerHandler.dispose()
  389 + error = cause
  390 + }
  391 + connectJob.join()
  392 +
  393 + error?.let { throw it }
  394 + }
  395 +
352 /** 396 /**
353 * Disconnect from the room. 397 * Disconnect from the room.
354 */ 398 */
@@ -592,28 +636,15 @@ constructor( @@ -592,28 +636,15 @@ constructor(
592 engine.reconnect() 636 engine.reconnect()
593 } 637 }
594 638
595 - /**  
596 - * Removes all participants and tracks from the room.  
597 - */  
598 - private fun cleanupRoom() {  
599 - e2eeManager?.cleanUp()  
600 - e2eeManager = null  
601 - localParticipant.cleanup()  
602 - remoteParticipants.keys.toMutableSet() // copy keys to avoid concurrent modifications.  
603 - .forEach { sid -> handleParticipantDisconnect(sid) }  
604 -  
605 - sid = null  
606 - metadata = null  
607 - name = null  
608 - isRecording = false  
609 - sidToIdentity.clear()  
610 - }  
611 -  
612 private fun handleDisconnect(reason: DisconnectReason) { 639 private fun handleDisconnect(reason: DisconnectReason) {
613 if (state == State.DISCONNECTED) { 640 if (state == State.DISCONNECTED) {
614 return 641 return
615 } 642 }
616 - 643 + runBlocking {
  644 + stateLock.withLock {
  645 + if (state == State.DISCONNECTED) {
  646 + return@runBlocking
  647 + }
617 try { 648 try {
618 val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager 649 val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
619 cm.unregisterNetworkCallback(networkCallback) 650 cm.unregisterNetworkCallback(networkCallback)
@@ -628,11 +659,28 @@ constructor( @@ -628,11 +659,28 @@ constructor(
628 localParticipant.dispose() 659 localParticipant.dispose()
629 660
630 // Ensure all observers see the disconnected before closing scope. 661 // Ensure all observers see the disconnected before closing scope.
631 - runBlocking {  
632 eventBus.postEvent(RoomEvent.Disconnected(this@Room, null, reason), coroutineScope).join() 662 eventBus.postEvent(RoomEvent.Disconnected(this@Room, null, reason), coroutineScope).join()
633 - }  
634 coroutineScope.cancel() 663 coroutineScope.cancel()
635 } 664 }
  665 + }
  666 + }
  667 +
  668 + /**
  669 + * Removes all participants and tracks from the room.
  670 + */
  671 + private fun cleanupRoom() {
  672 + e2eeManager?.cleanUp()
  673 + e2eeManager = null
  674 + localParticipant.cleanup()
  675 + remoteParticipants.keys.toMutableSet() // copy keys to avoid concurrent modifications.
  676 + .forEach { sid -> handleParticipantDisconnect(sid) }
  677 +
  678 + sid = null
  679 + metadata = null
  680 + name = null
  681 + isRecording = false
  682 + sidToIdentity.clear()
  683 + }
636 684
637 private fun sendSyncState() { 685 private fun sendSyncState() {
638 // Whether we're sending subscribed tracks or tracks to unsubscribe. 686 // Whether we're sending subscribed tracks or tracks to unsubscribe.
  1 +/*
  2 + * Copyright 2024 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.util
  18 +
  19 +import kotlinx.coroutines.sync.Mutex
  20 +import kotlinx.coroutines.sync.withLock
  21 +
  22 +/**
  23 + * Applies a double-checked lock before running [action].
  24 + */
  25 +suspend inline fun <T> Mutex.withCheckLock(check: () -> Unit, action: () -> T): T {
  26 + check()
  27 + return withLock {
  28 + check()
  29 + action()
  30 + }
  31 +}
1 /* 1 /*
2 - * Copyright 2023 LiveKit, Inc. 2 + * Copyright 2023-2024 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ package io.livekit.android.webrtc.peerconnection @@ -18,6 +18,7 @@ package io.livekit.android.webrtc.peerconnection
18 18
19 import androidx.annotation.VisibleForTesting 19 import androidx.annotation.VisibleForTesting
20 import kotlinx.coroutines.CoroutineDispatcher 20 import kotlinx.coroutines.CoroutineDispatcher
  21 +import kotlinx.coroutines.CoroutineScope
21 import kotlinx.coroutines.asCoroutineDispatcher 22 import kotlinx.coroutines.asCoroutineDispatcher
22 import kotlinx.coroutines.async 23 import kotlinx.coroutines.async
23 import kotlinx.coroutines.coroutineScope 24 import kotlinx.coroutines.coroutineScope
@@ -41,7 +42,7 @@ private val threadFactory = object : ThreadFactory { @@ -41,7 +42,7 @@ private val threadFactory = object : ThreadFactory {
41 } 42 }
42 } 43 }
43 44
44 -// var only for testing purposes, do not alter! 45 +// var only for testing purposes, do not alter in production!
45 private var executor = Executors.newSingleThreadExecutor(threadFactory) 46 private var executor = Executors.newSingleThreadExecutor(threadFactory)
46 private var rtcDispatcher: CoroutineDispatcher = executor.asCoroutineDispatcher() 47 private var rtcDispatcher: CoroutineDispatcher = executor.asCoroutineDispatcher()
47 48
@@ -82,12 +83,12 @@ fun <T> executeBlockingOnRTCThread(action: () -> T): T { @@ -82,12 +83,12 @@ fun <T> executeBlockingOnRTCThread(action: () -> T): T {
82 * is generally not thread safe, so all actions relating to 83 * is generally not thread safe, so all actions relating to
83 * peer connection objects should go through the RTC thread. 84 * peer connection objects should go through the RTC thread.
84 */ 85 */
85 -suspend fun <T> launchBlockingOnRTCThread(action: suspend () -> T): T = coroutineScope { 86 +suspend fun <T> launchBlockingOnRTCThread(action: suspend CoroutineScope.() -> T): T = coroutineScope {
86 return@coroutineScope if (Thread.currentThread().name.startsWith(EXECUTOR_THREADNAME_PREFIX)) { 87 return@coroutineScope if (Thread.currentThread().name.startsWith(EXECUTOR_THREADNAME_PREFIX)) {
87 - action() 88 + this.action()
88 } else { 89 } else {
89 async(rtcDispatcher) { 90 async(rtcDispatcher) {
90 - action() 91 + this.action()
91 }.await() 92 }.await()
92 } 93 }
93 } 94 }