davidliu

room reusable after disconnect

... ... @@ -34,7 +34,8 @@ class RTCEngine
internal constructor(
val client: SignalClient,
private val pctFactory: PeerConnectionTransport.Factory,
@Named(InjectionNames.DISPATCHER_IO) ioDispatcher: CoroutineDispatcher,
@Named(InjectionNames.DISPATCHER_IO)
private val ioDispatcher: CoroutineDispatcher,
) : SignalClient.Listener, DataChannel.Observer {
internal var listener: Listener? = null
... ... @@ -77,8 +78,20 @@ internal constructor(
private val publisherObserver = PublisherTransportObserver(this, client)
private val subscriberObserver = SubscriberTransportObserver(this, client)
internal lateinit var publisher: PeerConnectionTransport
internal lateinit var subscriber: PeerConnectionTransport
private var _publisher: PeerConnectionTransport? = null
internal val publisher: PeerConnectionTransport
get() {
return _publisher
?: throw UninitializedPropertyAccessException("publisher has not been initialized yet.")
}
private var _subscriber: PeerConnectionTransport? = null
internal val subscriber: PeerConnectionTransport
get() {
return _subscriber
?: throw UninitializedPropertyAccessException("subscriber has not been initialized yet.")
}
private var reliableDataChannel: DataChannel? = null
private var reliableDataChannelSub: DataChannel? = null
private var lossyDataChannel: DataChannel? = null
... ... @@ -89,13 +102,15 @@ internal constructor(
private var hasPublished = false
private val coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher)
private var coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher)
init {
client.listener = this
}
suspend fun join(url: String, token: String, options: ConnectOptions): LivekitRtc.JoinResponse {
coroutineScope.close()
coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher)
sessionUrl = url
sessionToken = token
val joinResponse = client.join(url, token, options)
... ... @@ -104,9 +119,8 @@ internal constructor(
isSubscriberPrimary = joinResponse.subscriberPrimary
if (!this::publisher.isInitialized) {
configure(joinResponse, options)
}
configure(joinResponse, options)
// create offer
if (!this.isSubscriberPrimary) {
negotiate()
... ... @@ -116,7 +130,7 @@ internal constructor(
}
private fun configure(joinResponse: LivekitRtc.JoinResponse, connectOptions: ConnectOptions?) {
if (this::publisher.isInitialized || this::subscriber.isInitialized) {
if (_publisher != null && _subscriber != null) {
// already configured
return
}
... ... @@ -160,13 +174,14 @@ internal constructor(
enableDtlsSrtp = true
}
publisher = pctFactory.create(
_publisher?.close()
_publisher = pctFactory.create(
rtcConfig,
publisherObserver,
publisherObserver,
)
subscriber = pctFactory.create(
_subscriber?.close()
_subscriber = pctFactory.create(
rtcConfig,
subscriberObserver,
null,
... ... @@ -248,10 +263,15 @@ internal constructor(
}
fun close() {
if (isClosed) {
return
}
isClosed = true
coroutineScope.close()
publisher.close()
subscriber.close()
_publisher?.close()
_publisher = null
_subscriber?.close()
_subscriber = null
client.close()
}
... ... @@ -318,8 +338,8 @@ internal constructor(
}
listener?.onEngineDisconnected("failed reconnecting.")
close()
listener?.onEngineDisconnected("failed reconnecting.")
}
reconnectingJob = job
... ... @@ -361,8 +381,8 @@ internal constructor(
return
}
if (!this::publisher.isInitialized) {
throw RoomException.ConnectException("Publisher is not connected!")
if (_publisher == null) {
throw RoomException.ConnectException("Publisher isn't setup yet! Is room not connected?!")
}
if (!publisher.peerConnection.isConnected() &&
... ... @@ -572,7 +592,7 @@ internal constructor(
// Signal error
override fun onError(error: Throwable) {
if (isClosed) {
if (connectionState == ConnectionState.CONNECTING) {
listener?.onFailToConnect(error)
}
}
... ...
... ... @@ -142,7 +142,7 @@ constructor(
}
coroutineScope = CoroutineScope(defaultDispatcher + SupervisorJob())
state = State.CONNECTING
this.connectOptions = connectOptions
connectOptions = options
val response = engine.join(url, token, options)
LKLog.i { "Connected to server, server version: ${response.serverVersion}, client version: ${Version.CLIENT_VERSION}" }
... ...
... ... @@ -9,6 +9,9 @@ import org.mockito.junit.MockitoJUnit
@ExperimentalCoroutinesApi
abstract class BaseTest {
// Uncomment to enable logging in tests.
//@get:Rule
//var loggingRule = LoggingRule()
@get:Rule
var mockitoRule = MockitoJUnit.rule()
... ...
... ... @@ -2,10 +2,12 @@ package io.livekit.android
import android.content.Context
import androidx.test.core.app.ApplicationProvider
import io.livekit.android.mock.MockPeerConnection
import io.livekit.android.mock.MockWebSocketFactory
import io.livekit.android.mock.dagger.DaggerTestLiveKitComponent
import io.livekit.android.mock.dagger.TestCoroutinesModule
import io.livekit.android.mock.dagger.TestLiveKitComponent
import io.livekit.android.room.PeerConnectionTransport
import io.livekit.android.room.Room
import io.livekit.android.room.SignalClientTest
import io.livekit.android.util.toOkioByteString
... ... @@ -15,14 +17,16 @@ import okhttp3.Protocol
import okhttp3.Request
import okhttp3.Response
import org.junit.Before
import org.webrtc.PeerConnection
@ExperimentalCoroutinesApi
abstract class MockE2ETest : BaseTest() {
internal lateinit var component: TestLiveKitComponent
lateinit var context: Context
lateinit var room: Room
lateinit var wsFactory: MockWebSocketFactory
internal lateinit var context: Context
internal lateinit var room: Room
internal lateinit var wsFactory: MockWebSocketFactory
internal lateinit var subscriber: PeerConnectionTransport
@Before
fun setup() {
... ... @@ -37,6 +41,11 @@ abstract class MockE2ETest : BaseTest() {
}
suspend fun connect() {
connectSignal()
connectPeerConnection()
}
suspend fun connectSignal() {
val job = coroutineRule.scope.launch {
room.connect(
url = SignalClientTest.EXAMPLE_URL,
... ... @@ -49,6 +58,13 @@ abstract class MockE2ETest : BaseTest() {
job.join()
}
suspend fun connectPeerConnection() {
subscriber = component.rtcEngine().subscriber
wsFactory.listener.onMessage(wsFactory.ws, SignalClientTest.OFFER.toOkioByteString())
val subPeerConnection = subscriber.peerConnection as MockPeerConnection
subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.CONNECTED)
}
fun createOpenResponse(request: Request): Response {
return Response.Builder()
.request(request)
... ...
... ... @@ -6,7 +6,6 @@ import io.livekit.android.mock.MockWebSocket
import io.livekit.android.util.LoggingRule
import io.livekit.android.util.toPBByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.runTest
import livekit.LivekitRtc
import org.junit.Assert
import org.junit.Before
... ... @@ -14,18 +13,12 @@ import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import org.webrtc.PeerConnection
import org.webrtc.SessionDescription
@ExperimentalCoroutinesApi
@RunWith(RobolectricTestRunner::class)
class RTCEngineMockE2ETest : MockE2ETest() {
@get:Rule
var loggingRule = LoggingRule()
lateinit var rtcEngine: RTCEngine
@Before
... ... @@ -36,11 +29,10 @@ class RTCEngineMockE2ETest : MockE2ETest() {
@Test
fun iceSubscriberConnect() = runTest {
connect()
val remoteOffer = SessionDescription(SessionDescription.Type.OFFER, "remote_offer")
rtcEngine.onOffer(remoteOffer)
Assert.assertEquals(remoteOffer, rtcEngine.subscriber.peerConnection.remoteDescription)
Assert.assertEquals(
SignalClientTest.OFFER.offer.sdp,
rtcEngine.subscriber.peerConnection.remoteDescription.description
)
val ws = wsFactory.ws as MockWebSocket
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
... ... @@ -52,9 +44,6 @@ class RTCEngineMockE2ETest : MockE2ETest() {
Assert.assertTrue(sentRequest.hasAnswer())
Assert.assertEquals(localAnswer.description, sentRequest.answer.sdp)
Assert.assertEquals(localAnswer.type.canonicalForm(), sentRequest.answer.type)
subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.CONNECTED)
Assert.assertEquals(ConnectionState.CONNECTED, rtcEngine.connectionState)
}
... ...
... ... @@ -3,6 +3,7 @@ package io.livekit.android.room
import android.net.Network
import io.livekit.android.MockE2ETest
import io.livekit.android.events.EventCollector
import io.livekit.android.events.FlowCollector
import io.livekit.android.events.RoomEvent
import io.livekit.android.mock.MockAudioStreamTrack
import io.livekit.android.mock.MockMediaStream
... ... @@ -10,6 +11,8 @@ import io.livekit.android.mock.TestData
import io.livekit.android.mock.createMediaStreamId
import io.livekit.android.room.participant.ConnectionQuality
import io.livekit.android.room.track.Track
import io.livekit.android.util.delegate
import io.livekit.android.util.flow
import io.livekit.android.util.toOkioByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
... ... @@ -25,7 +28,13 @@ class RoomMockE2ETest : MockE2ETest() {
@Test
fun connectTest() = runTest {
val collector = FlowCollector(room::state.flow, coroutineRule.scope)
connect()
val events = collector.stopCollecting()
Assert.assertEquals(3, events.size)
Assert.assertEquals(Room.State.DISCONNECTED, events[0])
Assert.assertEquals(Room.State.CONNECTING, events[1])
Assert.assertEquals(Room.State.CONNECTED, events[2])
}
@Test
... ... @@ -247,4 +256,11 @@ class RoomMockE2ETest : MockE2ETest() {
Assert.assertEquals(true, events[0] is RoomEvent.Disconnected)
}
@Test
fun reconnectAfterDisconnect() = runTest {
connect()
room.disconnect()
connect()
Assert.assertEquals(room.state, Room.State.CONNECTED)
}
}
\ No newline at end of file
... ...
... ... @@ -174,7 +174,7 @@ class SignalClientTest : BaseTest() {
val OFFER = with(LivekitRtc.SignalResponse.newBuilder()) {
offer = with(offerBuilder) {
sdp = ""
sdp = "remote_offer"
type = "offer"
build()
}
... ...
... ... @@ -7,9 +7,12 @@ import org.junit.runner.Description
import org.junit.runners.model.Statement
import timber.log.Timber
/**
* Add this rule to a test class to turn on logs.
*/
class LoggingRule : TestRule {
val logTree = object : Timber.Tree() {
val logTree = object : Timber.DebugTree() {
override fun log(priority: Int, tag: String?, message: String, t: Throwable?) {
val priorityChar = when (priority) {
Log.VERBOSE -> "v"
... ... @@ -32,8 +35,8 @@ class LoggingRule : TestRule {
override fun apply(base: Statement, description: Description?) = object : Statement() {
override fun evaluate() {
val oldLoggingLevel = LiveKit.loggingLevel
LiveKit.loggingLevel = LoggingLevel.VERBOSE
Timber.plant(logTree)
LiveKit.loggingLevel = LoggingLevel.VERBOSE
base.evaluate()
Timber.uproot(logTree)
LiveKit.loggingLevel = oldLoggingLevel
... ...
... ... @@ -248,13 +248,24 @@ class CallViewModel(
room.value?.localParticipant?.setTrackSubscriptionPermissions(mutablePermissionAllowed.value)
}
fun simulateMigration(){
fun simulateMigration() {
room.value?.sendSimulateScenario(
LivekitRtc.SimulateScenario.newBuilder()
.setMigration(true)
.build()
)
}
fun reconnect() {
room.value?.disconnect()
viewModelScope.launch {
room.value?.connect(
url,
token
)
}
}
}
private fun <T> LiveData<T>.hide(): LiveData<T> = this
... ...
... ... @@ -111,6 +111,7 @@ class CallActivity : AppCompatActivity() {
onExitClick = { finish() },
onSendMessage = { viewModel.sendData(it) },
onSimulateMigration = { viewModel.simulateMigration() },
fullReconnect = { viewModel.reconnect() },
)
}
}
... ... @@ -159,6 +160,7 @@ class CallActivity : AppCompatActivity() {
onSnackbarDismiss: () -> Unit = {},
onSendMessage: (String) -> Unit = {},
onSimulateMigration: () -> Unit = {},
fullReconnect: () -> Unit = {},
) {
AppTheme(darkTheme = true) {
ConstraintLayout(
... ... @@ -410,7 +412,8 @@ class CallActivity : AppCompatActivity() {
if (showDebugDialog) {
DebugMenuDialog(
onDismissRequest = { showDebugDialog = false },
simulateMigration = { onSimulateMigration() }
simulateMigration = { onSimulateMigration() },
fullReconnect = { fullReconnect() },
)
}
}
... ...
... ... @@ -17,7 +17,8 @@ import androidx.compose.ui.window.Dialog
@Composable
fun DebugMenuDialog(
onDismissRequest: () -> Unit = {},
simulateMigration: () -> Unit = {}
simulateMigration: () -> Unit = {},
fullReconnect: () -> Unit = {},
) {
Dialog(onDismissRequest = onDismissRequest) {
Column(
... ... @@ -36,6 +37,11 @@ fun DebugMenuDialog(
}) {
Text("Simulate Migration")
}
Button(onClick = {
fullReconnect()
}) {
Text("Reconnect to room")
}
}
}
}
\ No newline at end of file
... ...