davidliu
Committed by GitHub

Fix deadlock caused by multiple concurrent setCameraEnabled calls (#472)

* Fix deadlock caused by multiple concurrent setCameraEnabled calls

* tests

* spotless
@@ -19,13 +19,13 @@ package io.livekit.android.memory @@ -19,13 +19,13 @@ package io.livekit.android.memory
19 import livekit.org.webrtc.SurfaceTextureHelper 19 import livekit.org.webrtc.SurfaceTextureHelper
20 import java.io.Closeable 20 import java.io.Closeable
21 21
22 -internal class SurfaceTextureHelperCloser(private val surfaceTextureHelper: SurfaceTextureHelper) : Closeable { 22 +internal class SurfaceTextureHelperCloser(private val surfaceTextureHelper: SurfaceTextureHelper?) : Closeable {
23 private var isClosed = false 23 private var isClosed = false
24 override fun close() { 24 override fun close() {
25 if (!isClosed) { 25 if (!isClosed) {
26 isClosed = true 26 isClosed = true
27 - surfaceTextureHelper.stopListening()  
28 - surfaceTextureHelper.dispose() 27 + surfaceTextureHelper?.stopListening()
  28 + surfaceTextureHelper?.dispose()
29 } 29 }
30 } 30 }
31 } 31 }
@@ -52,6 +52,8 @@ import io.livekit.android.webrtc.sortVideoCodecPreferences @@ -52,6 +52,8 @@ import io.livekit.android.webrtc.sortVideoCodecPreferences
52 import kotlinx.coroutines.CoroutineDispatcher 52 import kotlinx.coroutines.CoroutineDispatcher
53 import kotlinx.coroutines.Job 53 import kotlinx.coroutines.Job
54 import kotlinx.coroutines.launch 54 import kotlinx.coroutines.launch
  55 +import kotlinx.coroutines.sync.Mutex
  56 +import kotlinx.coroutines.sync.withLock
55 import livekit.LivekitModels 57 import livekit.LivekitModels
56 import livekit.LivekitRtc 58 import livekit.LivekitRtc
57 import livekit.LivekitRtc.AddTrackRequest 59 import livekit.LivekitRtc.AddTrackRequest
@@ -99,6 +101,12 @@ internal constructor( @@ -99,6 +101,12 @@ internal constructor(
99 101
100 private val jobs = mutableMapOf<Any, Job>() 102 private val jobs = mutableMapOf<Any, Job>()
101 103
  104 + // For ensuring that only one caller can execute setTrackEnabled at a time.
  105 + // Without it, there's a potential to create multiple of the same source,
  106 + // Camera has deadlock issues with multiple CameraCapturers trying to activate/stop.
  107 + private val sourcePubLocks = Track.Source.values()
  108 + .associate { source -> source to Mutex() }
  109 +
102 /** 110 /**
103 * Creates an audio track, recording audio through the microphone with the given [options]. 111 * Creates an audio track, recording audio through the microphone with the given [options].
104 * 112 *
@@ -228,57 +236,60 @@ internal constructor( @@ -228,57 +236,60 @@ internal constructor(
228 enabled: Boolean, 236 enabled: Boolean,
229 mediaProjectionPermissionResultData: Intent? = null, 237 mediaProjectionPermissionResultData: Intent? = null,
230 ) { 238 ) {
231 - val pub = getTrackPublication(source)  
232 - if (enabled) {  
233 - if (pub != null) {  
234 - pub.muted = false  
235 -  
236 - if (source == Track.Source.CAMERA && pub.track is LocalVideoTrack) {  
237 - (pub.track as? LocalVideoTrack)?.startCapture()  
238 - }  
239 - } else {  
240 - when (source) {  
241 - Track.Source.CAMERA -> {  
242 - val track = createVideoTrack()  
243 - track.startCapture()  
244 - publishVideoTrack(track) 239 + val pubLock = sourcePubLocks[source]!!
  240 + pubLock.withLock {
  241 + val pub = getTrackPublication(source)
  242 + if (enabled) {
  243 + if (pub != null) {
  244 + pub.muted = false
  245 + if (source == Track.Source.CAMERA && pub.track is LocalVideoTrack) {
  246 + (pub.track as? LocalVideoTrack)?.startCapture()
245 } 247 }
  248 + } else {
  249 + when (source) {
  250 + Track.Source.CAMERA -> {
  251 + val track = createVideoTrack()
  252 + track.startCapture()
  253 + publishVideoTrack(track)
  254 + }
246 255
247 - Track.Source.MICROPHONE -> {  
248 - val track = createAudioTrack()  
249 - publishAudioTrack(track)  
250 - } 256 + Track.Source.MICROPHONE -> {
  257 + val track = createAudioTrack()
  258 + publishAudioTrack(track)
  259 + }
251 260
252 - Track.Source.SCREEN_SHARE -> {  
253 - if (mediaProjectionPermissionResultData == null) {  
254 - throw IllegalArgumentException("Media Projection permission result data is required to create a screen share track.") 261 + Track.Source.SCREEN_SHARE -> {
  262 + if (mediaProjectionPermissionResultData == null) {
  263 + throw IllegalArgumentException("Media Projection permission result data is required to create a screen share track.")
  264 + }
  265 + val track =
  266 + createScreencastTrack(mediaProjectionPermissionResultData = mediaProjectionPermissionResultData)
  267 + track.startForegroundService(null, null)
  268 + track.startCapture()
  269 + publishVideoTrack(track)
255 } 270 }
256 - val track =  
257 - createScreencastTrack(mediaProjectionPermissionResultData = mediaProjectionPermissionResultData)  
258 - track.startForegroundService(null, null)  
259 - track.startCapture()  
260 - publishVideoTrack(track)  
261 - }  
262 271
263 - else -> {  
264 - LKLog.w { "Attempting to enable an unknown source, ignoring." } 272 + else -> {
  273 + LKLog.w { "Attempting to enable an unknown source, ignoring." }
  274 + }
265 } 275 }
266 } 276 }
267 - }  
268 - } else {  
269 - pub?.track?.let { track ->  
270 - // screenshare cannot be muted, unpublish instead  
271 - if (pub.source == Track.Source.SCREEN_SHARE) {  
272 - unpublishTrack(track)  
273 - } else {  
274 - pub.muted = true  
275 -  
276 - // Release camera session so other apps can use.  
277 - if (pub.source == Track.Source.CAMERA && track is LocalVideoTrack) {  
278 - track.stopCapture() 277 + } else {
  278 + pub?.track?.let { track ->
  279 + // screenshare cannot be muted, unpublish instead
  280 + if (pub.source == Track.Source.SCREEN_SHARE) {
  281 + unpublishTrack(track)
  282 + } else {
  283 + pub.muted = true
  284 +
  285 + // Release camera session so other apps can use.
  286 + if (pub.source == Track.Source.CAMERA && track is LocalVideoTrack) {
  287 + track.stopCapture()
  288 + }
279 } 289 }
280 } 290 }
281 } 291 }
  292 + return@withLock
282 } 293 }
283 } 294 }
284 295
@@ -484,6 +495,7 @@ internal constructor( @@ -484,6 +495,7 @@ internal constructor(
484 options = options, 495 options = options,
485 ) 496 )
486 addTrackPublication(publication) 497 addTrackPublication(publication)
  498 + LKLog.e { "add track publication $publication" }
487 499
488 publishListener?.onPublishSuccess(publication) 500 publishListener?.onPublishSuccess(publication)
489 internalListener?.onTrackPublished(publication, this) 501 internalListener?.onTrackPublished(publication, this)
  1 +/*
  2 + * Copyright 2024 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.test.mock
  18 +
  19 +import livekit.org.webrtc.AudioSource
  20 +
  21 +class MockAudioSource : AudioSource(100L)
@@ -17,7 +17,6 @@ @@ -17,7 +17,6 @@
17 package io.livekit.android.test.mock 17 package io.livekit.android.test.mock
18 18
19 import com.google.protobuf.MessageLite 19 import com.google.protobuf.MessageLite
20 -import io.livekit.android.test.util.toOkioByteString  
21 import io.livekit.android.test.util.toPBByteString 20 import io.livekit.android.test.util.toPBByteString
22 import io.livekit.android.util.toOkioByteString 21 import io.livekit.android.util.toOkioByteString
23 import livekit.LivekitModels 22 import livekit.LivekitModels
@@ -56,14 +55,21 @@ class MockWebSocketFactory : WebSocket.Factory { @@ -56,14 +55,21 @@ class MockWebSocketFactory : WebSocket.Factory {
56 return ws 55 return ws
57 } 56 }
58 57
59 - private val signalRequestHandlers = mutableListOf<SignalRequestHandler>(  
60 - { signalRequest -> defaultHandleSignalRequest(signalRequest) },  
61 - ) 58 + val defaultSignalRequestHandler: SignalRequestHandler = { signalRequest -> defaultHandleSignalRequest(signalRequest) }
62 59
  60 + private val signalRequestHandlers = mutableListOf(defaultSignalRequestHandler)
  61 +
  62 + /**
  63 + * Adds a handler to the front of the list.
  64 + */
63 fun registerSignalRequestHandler(handler: SignalRequestHandler) { 65 fun registerSignalRequestHandler(handler: SignalRequestHandler) {
64 signalRequestHandlers.add(0, handler) 66 signalRequestHandlers.add(0, handler)
65 } 67 }
66 68
  69 + fun unregisterSignalRequestHandler(handler: SignalRequestHandler) {
  70 + signalRequestHandlers.remove(handler)
  71 + }
  72 +
67 private fun handleSignalRequest(signalRequest: SignalRequest) { 73 private fun handleSignalRequest(signalRequest: SignalRequest) {
68 for (handler in signalRequestHandlers) { 74 for (handler in signalRequestHandlers) {
69 if (handler.invoke(signalRequest)) { 75 if (handler.invoke(signalRequest)) {
  1 +/*
  2 + * Copyright 2024 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.test.mock.camera
  18 +
  19 +import android.content.Context
  20 +import io.livekit.android.room.track.LocalVideoTrackOptions
  21 +import io.livekit.android.room.track.video.CameraCapturerUtils
  22 +import io.livekit.android.room.track.video.CameraEventsDispatchHandler
  23 +import livekit.org.webrtc.CameraEnumerationAndroid
  24 +import livekit.org.webrtc.CameraEnumerator
  25 +import livekit.org.webrtc.CameraVideoCapturer
  26 +import livekit.org.webrtc.CapturerObserver
  27 +import livekit.org.webrtc.SurfaceTextureHelper
  28 +import livekit.org.webrtc.VideoCapturer
  29 +
  30 +class MockCameraProvider : CameraCapturerUtils.CameraProvider {
  31 +
  32 + companion object {
  33 + fun register() {
  34 + CameraCapturerUtils.registerCameraProvider(MockCameraProvider())
  35 + }
  36 + }
  37 +
  38 + private val enumerator by lazy { MockCameraEnumerator() }
  39 +
  40 + override val cameraVersion: Int = 100
  41 +
  42 + override fun provideEnumerator(context: Context): CameraEnumerator {
  43 + return enumerator
  44 + }
  45 +
  46 + override fun provideCapturer(context: Context, options: LocalVideoTrackOptions, eventsHandler: CameraEventsDispatchHandler): VideoCapturer {
  47 + return enumerator.createCapturer(options.deviceId, eventsHandler)
  48 + }
  49 +
  50 + override fun isSupported(context: Context): Boolean {
  51 + return true
  52 + }
  53 +}
  54 +
  55 +class MockCameraEnumerator : CameraEnumerator {
  56 + override fun getDeviceNames(): Array<String> {
  57 + return arrayOf("camera")
  58 + }
  59 +
  60 + override fun isFrontFacing(deviceName: String): Boolean {
  61 + return true
  62 + }
  63 +
  64 + override fun isBackFacing(deviceName: String): Boolean {
  65 + return false
  66 + }
  67 +
  68 + override fun getSupportedFormats(p0: String): MutableList<CameraEnumerationAndroid.CaptureFormat> {
  69 + return mutableListOf(
  70 + CameraEnumerationAndroid.CaptureFormat(480, 640, 30, 30),
  71 + )
  72 + }
  73 +
  74 + override fun createCapturer(deviceName: String?, eventsHandler: CameraVideoCapturer.CameraEventsHandler): CameraVideoCapturer {
  75 + return MockCameraVideoCapturer()
  76 + }
  77 +}
  78 +
  79 +class MockCameraVideoCapturer : CameraVideoCapturer {
  80 + override fun initialize(p0: SurfaceTextureHelper?, p1: Context?, p2: CapturerObserver?) {
  81 + }
  82 +
  83 + override fun startCapture(p0: Int, p1: Int, p2: Int) {
  84 + }
  85 +
  86 + override fun stopCapture() {
  87 + }
  88 +
  89 + override fun changeCaptureFormat(p0: Int, p1: Int, p2: Int) {
  90 + }
  91 +
  92 + override fun dispose() {
  93 + }
  94 +
  95 + override fun isScreencast(): Boolean = false
  96 +
  97 + override fun switchCamera(p0: CameraVideoCapturer.CameraSwitchHandler?) {
  98 + }
  99 +
  100 + override fun switchCamera(p0: CameraVideoCapturer.CameraSwitchHandler?, p1: String?) {
  101 + }
  102 +}
@@ -17,8 +17,13 @@ @@ -17,8 +17,13 @@
17 package io.livekit.android.test.util 17 package io.livekit.android.test.util
18 18
19 import com.google.protobuf.ByteString 19 import com.google.protobuf.ByteString
  20 +import livekit.LivekitRtc
20 import okio.ByteString.Companion.toByteString 21 import okio.ByteString.Companion.toByteString
21 22
22 fun com.google.protobuf.ByteString.toOkioByteString() = toByteArray().toByteString() 23 fun com.google.protobuf.ByteString.toOkioByteString() = toByteArray().toByteString()
23 24
24 fun okio.ByteString.toPBByteString() = ByteString.copyFrom(toByteArray()) 25 fun okio.ByteString.toPBByteString() = ByteString.copyFrom(toByteArray())
  26 +
  27 +fun okio.ByteString.toSignalRequest() = LivekitRtc.SignalRequest.newBuilder()
  28 + .mergeFrom(toPBByteString())
  29 + .build()
@@ -16,6 +16,8 @@ @@ -16,6 +16,8 @@
16 16
17 package livekit.org.webrtc 17 package livekit.org.webrtc
18 18
  19 +import io.livekit.android.test.mock.MockAudioSource
  20 +import io.livekit.android.test.mock.MockAudioStreamTrack
19 import io.livekit.android.test.mock.MockPeerConnection 21 import io.livekit.android.test.mock.MockPeerConnection
20 import io.livekit.android.test.mock.MockVideoSource 22 import io.livekit.android.test.mock.MockVideoSource
21 import io.livekit.android.test.mock.MockVideoStreamTrack 23 import io.livekit.android.test.mock.MockVideoStreamTrack
@@ -30,6 +32,14 @@ class MockPeerConnectionFactory : PeerConnectionFactory(1L) { @@ -30,6 +32,14 @@ class MockPeerConnectionFactory : PeerConnectionFactory(1L) {
30 return MockPeerConnection(rtcConfig, observer) 32 return MockPeerConnection(rtcConfig, observer)
31 } 33 }
32 34
  35 + override fun createAudioSource(constraints: MediaConstraints?): AudioSource {
  36 + return MockAudioSource()
  37 + }
  38 +
  39 + override fun createAudioTrack(id: String, source: AudioSource?): AudioTrack {
  40 + return MockAudioStreamTrack(id = id)
  41 + }
  42 +
33 override fun createVideoSource(isScreencast: Boolean, alignTimestamps: Boolean): VideoSource { 43 override fun createVideoSource(isScreencast: Boolean, alignTimestamps: Boolean): VideoSource {
34 return MockVideoSource() 44 return MockVideoSource()
35 } 45 }
@@ -16,6 +16,10 @@ @@ -16,6 +16,10 @@
16 16
17 package io.livekit.android.room.participant 17 package io.livekit.android.room.participant
18 18
  19 +import android.Manifest
  20 +import android.app.Application
  21 +import android.content.Context
  22 +import androidx.test.core.app.ApplicationProvider
19 import io.livekit.android.audio.AudioProcessorInterface 23 import io.livekit.android.audio.AudioProcessorInterface
20 import io.livekit.android.events.ParticipantEvent 24 import io.livekit.android.events.ParticipantEvent
21 import io.livekit.android.events.RoomEvent 25 import io.livekit.android.events.RoomEvent
@@ -36,9 +40,14 @@ import io.livekit.android.test.mock.MockEglBase @@ -36,9 +40,14 @@ import io.livekit.android.test.mock.MockEglBase
36 import io.livekit.android.test.mock.MockVideoCapturer 40 import io.livekit.android.test.mock.MockVideoCapturer
37 import io.livekit.android.test.mock.MockVideoStreamTrack 41 import io.livekit.android.test.mock.MockVideoStreamTrack
38 import io.livekit.android.test.mock.TestData 42 import io.livekit.android.test.mock.TestData
  43 +import io.livekit.android.test.mock.camera.MockCameraProvider
39 import io.livekit.android.test.util.toPBByteString 44 import io.livekit.android.test.util.toPBByteString
40 import io.livekit.android.util.toOkioByteString 45 import io.livekit.android.util.toOkioByteString
  46 +import kotlinx.coroutines.CoroutineScope
41 import kotlinx.coroutines.ExperimentalCoroutinesApi 47 import kotlinx.coroutines.ExperimentalCoroutinesApi
  48 +import kotlinx.coroutines.Job
  49 +import kotlinx.coroutines.cancel
  50 +import kotlinx.coroutines.launch
42 import kotlinx.coroutines.test.advanceUntilIdle 51 import kotlinx.coroutines.test.advanceUntilIdle
43 import livekit.LivekitModels 52 import livekit.LivekitModels
44 import livekit.LivekitModels.AudioTrackFeature 53 import livekit.LivekitModels.AudioTrackFeature
@@ -57,6 +66,7 @@ import org.mockito.Mockito @@ -57,6 +66,7 @@ import org.mockito.Mockito
57 import org.mockito.Mockito.mock 66 import org.mockito.Mockito.mock
58 import org.mockito.kotlin.argThat 67 import org.mockito.kotlin.argThat
59 import org.robolectric.RobolectricTestRunner 68 import org.robolectric.RobolectricTestRunner
  69 +import org.robolectric.Shadows
60 import java.nio.ByteBuffer 70 import java.nio.ByteBuffer
61 71
62 @ExperimentalCoroutinesApi 72 @ExperimentalCoroutinesApi
@@ -111,6 +121,49 @@ class LocalParticipantMockE2ETest : MockE2ETest() { @@ -111,6 +121,49 @@ class LocalParticipantMockE2ETest : MockE2ETest() {
111 } 121 }
112 122
113 @Test 123 @Test
  124 + fun setTrackEnabledIsSynchronizedSingleSource() = runTest {
  125 + connect()
  126 +
  127 + val context = ApplicationProvider.getApplicationContext<Context>()
  128 + val shadowApplication = Shadows.shadowOf(context as Application)
  129 + shadowApplication.grantPermissions(Manifest.permission.RECORD_AUDIO)
  130 + wsFactory.unregisterSignalRequestHandler(wsFactory.defaultSignalRequestHandler)
  131 + wsFactory.ws.clearRequests()
  132 +
  133 + val backgroundScope = CoroutineScope(coroutineContext + Job())
  134 + try {
  135 + backgroundScope.launch { room.localParticipant.setMicrophoneEnabled(true) }
  136 + backgroundScope.launch { room.localParticipant.setMicrophoneEnabled(true) }
  137 +
  138 + assertEquals(1, wsFactory.ws.sentRequests.size)
  139 + } finally {
  140 + backgroundScope.cancel()
  141 + }
  142 + }
  143 +
  144 + @Test
  145 + fun setTrackEnabledIsSynchronizedMultipleSource() = runTest {
  146 + connect()
  147 +
  148 + MockCameraProvider.register()
  149 + val context = ApplicationProvider.getApplicationContext<Context>()
  150 + val shadowApplication = Shadows.shadowOf(context as Application)
  151 + shadowApplication.grantPermissions(Manifest.permission.RECORD_AUDIO, Manifest.permission.CAMERA)
  152 + wsFactory.unregisterSignalRequestHandler(wsFactory.defaultSignalRequestHandler)
  153 + wsFactory.ws.clearRequests()
  154 +
  155 + val backgroundScope = CoroutineScope(coroutineContext + Job())
  156 + try {
  157 + backgroundScope.launch { room.localParticipant.setMicrophoneEnabled(true) }
  158 + backgroundScope.launch { room.localParticipant.setCameraEnabled(true) }
  159 +
  160 + assertEquals(2, wsFactory.ws.sentRequests.size)
  161 + } finally {
  162 + backgroundScope.cancel()
  163 + }
  164 + }
  165 +
  166 + @Test
114 fun publishVideoTrackRequest() = runTest { 167 fun publishVideoTrackRequest() = runTest {
115 connect() 168 connect()
116 wsFactory.ws.clearRequests() 169 wsFactory.ws.clearRequests()