davidliu
Committed by GitHub

Separate DataChannel Observer into independent objects (#327)

* Separate DataChannel Observer into independent objects

unregisterObserver directly deletes the natively wrapped observer.
Previous implementation should be fine, but moving to independent for memory safety

* Spotless
@@ -65,7 +65,7 @@ internal constructor( @@ -65,7 +65,7 @@ internal constructor(
65 private val pctFactory: PeerConnectionTransport.Factory, 65 private val pctFactory: PeerConnectionTransport.Factory,
66 @Named(InjectionNames.DISPATCHER_IO) 66 @Named(InjectionNames.DISPATCHER_IO)
67 private val ioDispatcher: CoroutineDispatcher, 67 private val ioDispatcher: CoroutineDispatcher,
68 -) : SignalClient.Listener, DataChannel.Observer { 68 +) : SignalClient.Listener {
69 internal var listener: Listener? = null 69 internal var listener: Listener? = null
70 70
71 /** 71 /**
@@ -218,7 +218,7 @@ internal constructor( @@ -218,7 +218,7 @@ internal constructor(
218 LOSSY_DATA_CHANNEL_LABEL -> lossyDataChannelSub = dataChannel 218 LOSSY_DATA_CHANNEL_LABEL -> lossyDataChannelSub = dataChannel
219 else -> return@onDataChannel 219 else -> return@onDataChannel
220 } 220 }
221 - dataChannel.registerObserver(this) 221 + dataChannel.registerObserver(DataChannelObserver(dataChannel))
222 } 222 }
223 223
224 subscriberObserver.connectionChangeListener = connectionStateListener 224 subscriberObserver.connectionChangeListener = connectionStateListener
@@ -239,7 +239,9 @@ internal constructor( @@ -239,7 +239,9 @@ internal constructor(
239 createDataChannel( 239 createDataChannel(
240 RELIABLE_DATA_CHANNEL_LABEL, 240 RELIABLE_DATA_CHANNEL_LABEL,
241 reliableInit, 241 reliableInit,
242 - ).apply { registerObserver(this@RTCEngine) } 242 + ).also { dataChannel ->
  243 + dataChannel.registerObserver(DataChannelObserver(dataChannel))
  244 + }
243 } 245 }
244 246
245 val lossyInit = DataChannel.Init() 247 val lossyInit = DataChannel.Init()
@@ -249,7 +251,9 @@ internal constructor( @@ -249,7 +251,9 @@ internal constructor(
249 createDataChannel( 251 createDataChannel(
250 LOSSY_DATA_CHANNEL_LABEL, 252 LOSSY_DATA_CHANNEL_LABEL,
251 lossyInit, 253 lossyInit,
252 - ).apply { registerObserver(this@RTCEngine) } 254 + ).also { dataChannel ->
  255 + dataChannel.registerObserver(DataChannelObserver(dataChannel))
  256 + }
253 } 257 }
254 } 258 }
255 259
@@ -684,8 +688,11 @@ internal constructor( @@ -684,8 +688,11 @@ internal constructor(
684 } 688 }
685 689
686 companion object { 690 companion object {
687 - private const val RELIABLE_DATA_CHANNEL_LABEL = "_reliable"  
688 - private const val LOSSY_DATA_CHANNEL_LABEL = "_lossy" 691 + @VisibleForTesting
  692 + internal const val RELIABLE_DATA_CHANNEL_LABEL = "_reliable"
  693 +
  694 + @VisibleForTesting
  695 + internal const val LOSSY_DATA_CHANNEL_LABEL = "_lossy"
689 internal const val MAX_DATA_PACKET_SIZE = 15000 696 internal const val MAX_DATA_PACKET_SIZE = 15000
690 private const val MAX_RECONNECT_RETRIES = 10 697 private const val MAX_RECONNECT_RETRIES = 10
691 private const val MAX_RECONNECT_TIMEOUT = 60 * 1000 698 private const val MAX_RECONNECT_TIMEOUT = 60 * 1000
@@ -883,13 +890,13 @@ internal constructor( @@ -883,13 +890,13 @@ internal constructor(
883 890
884 // --------------------------------- DataChannel.Observer ------------------------------------// 891 // --------------------------------- DataChannel.Observer ------------------------------------//
885 892
886 - override fun onBufferedAmountChange(previousAmount: Long) { 893 + fun onBufferedAmountChange(dataChannel: DataChannel, previousAmount: Long) {
887 } 894 }
888 895
889 - override fun onStateChange() { 896 + fun onStateChange(dataChannel: DataChannel) {
890 } 897 }
891 898
892 - override fun onMessage(buffer: DataChannel.Buffer?) { 899 + fun onMessage(dataChannel: DataChannel, buffer: DataChannel.Buffer?) {
893 if (buffer == null) { 900 if (buffer == null) {
894 return 901 return
895 } 902 }
@@ -911,6 +918,20 @@ internal constructor( @@ -911,6 +918,20 @@ internal constructor(
911 } 918 }
912 } 919 }
913 920
  921 + private inner class DataChannelObserver(val dataChannel: DataChannel) : DataChannel.Observer {
  922 + override fun onBufferedAmountChange(p0: Long) {
  923 + this@RTCEngine.onBufferedAmountChange(dataChannel, p0)
  924 + }
  925 +
  926 + override fun onStateChange() {
  927 + this@RTCEngine.onStateChange(dataChannel)
  928 + }
  929 +
  930 + override fun onMessage(p0: DataChannel.Buffer) {
  931 + this@RTCEngine.onMessage(dataChannel, p0)
  932 + }
  933 + }
  934 +
914 fun sendSyncState( 935 fun sendSyncState(
915 subscription: LivekitRtc.UpdateSubscription, 936 subscription: LivekitRtc.UpdateSubscription,
916 publishedTracks: List<LivekitRtc.TrackPublishedResponse>, 937 publishedTracks: List<LivekitRtc.TrackPublishedResponse>,
@@ -35,9 +35,12 @@ import okhttp3.Request @@ -35,9 +35,12 @@ import okhttp3.Request
35 import okhttp3.Response 35 import okhttp3.Response
36 import okio.ByteString 36 import okio.ByteString
37 import org.junit.Before 37 import org.junit.Before
  38 +import org.junit.runner.RunWith
  39 +import org.robolectric.RobolectricTestRunner
38 import org.webrtc.PeerConnection 40 import org.webrtc.PeerConnection
39 41
40 @ExperimentalCoroutinesApi 42 @ExperimentalCoroutinesApi
  43 +@RunWith(RobolectricTestRunner::class)
41 abstract class MockE2ETest : BaseTest() { 44 abstract class MockE2ETest : BaseTest() {
42 45
43 internal lateinit var component: TestLiveKitComponent 46 internal lateinit var component: TestLiveKitComponent
  1 +/*
  2 + * Copyright 2023 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.room
  18 +
  19 +import com.google.protobuf.ByteString
  20 +import io.livekit.android.MockE2ETest
  21 +import io.livekit.android.assert.assertIsClass
  22 +import io.livekit.android.events.EventCollector
  23 +import io.livekit.android.events.RoomEvent
  24 +import io.livekit.android.mock.MockDataChannel
  25 +import io.livekit.android.mock.MockPeerConnection
  26 +import kotlinx.coroutines.ExperimentalCoroutinesApi
  27 +import livekit.LivekitModels.DataPacket
  28 +import livekit.LivekitModels.UserPacket
  29 +import org.junit.Assert.assertEquals
  30 +import org.junit.Test
  31 +import org.webrtc.DataChannel
  32 +import java.nio.ByteBuffer
  33 +
  34 +@OptIn(ExperimentalCoroutinesApi::class)
  35 +class RoomDataMockE2ETest : MockE2ETest() {
  36 + @Test
  37 + fun dataReceivedEvent() = runTest {
  38 + connect()
  39 + val subPeerConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection
  40 + val subDataChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL)
  41 + subPeerConnection.observer?.onDataChannel(subDataChannel)
  42 +
  43 + val collector = EventCollector(room.events, coroutineRule.scope)
  44 + val dataPacket = with(DataPacket.newBuilder()) {
  45 + user = with(UserPacket.newBuilder()) {
  46 + payload = ByteString.copyFrom("hello", Charsets.UTF_8)
  47 + build()
  48 + }
  49 + build()
  50 + }
  51 + val dataBuffer = DataChannel.Buffer(
  52 + ByteBuffer.wrap(dataPacket.toByteArray()),
  53 + true
  54 + )
  55 +
  56 + subDataChannel.observer?.onMessage(dataBuffer)
  57 + val events = collector.stopCollecting()
  58 +
  59 + assertEquals(1, events.size)
  60 + assertIsClass(RoomEvent.DataReceived::class.java, events[0])
  61 +
  62 + val event = events[0] as RoomEvent.DataReceived
  63 + assertEquals("hello", event.data.decodeToString())
  64 + }
  65 +}