David Liu

some coroutines stuff

  1 +package io.livekit.android.dagger
  2 +
  3 +import dagger.Module
  4 +import dagger.Provides
  5 +import kotlinx.coroutines.Dispatchers
  6 +import javax.inject.Named
  7 +
  8 +@Module
  9 +class CoroutinesModule {
  10 + companion object {
  11 +
  12 +
  13 +
  14 + @Provides
  15 + @Named(InjectionNames.DISPATCHER_DEFAULT)
  16 + fun defaultDispatcher() = Dispatchers.Default
  17 +
  18 + @Provides
  19 + @Named(InjectionNames.DISPATCHER_IO)
  20 + fun ioDispatcher() = Dispatchers.IO
  21 +
  22 + @Provides
  23 + @Named(InjectionNames.DISPATCHER_MAIN)
  24 + fun mainDispatcher() = Dispatchers.Main
  25 +
  26 + @Provides
  27 + @Named(InjectionNames.DISPATCHER_UNCONFINED)
  28 + fun unconfinedDispatcher() = Dispatchers.Unconfined
  29 +
  30 + }
  31 +}
  1 +package io.livekit.android.dagger
  2 +
  3 +class InjectionNames {
  4 + companion object {
  5 +
  6 + const val DISPATCHER_DEFAULT = "dispatcher_default"
  7 + const val DISPATCHER_IO = "dispatcher_io";
  8 + const val DISPATCHER_MAIN = "dispatcher_main"
  9 + const val DISPATCHER_UNCONFINED = "dispatcher_unconfined"
  10 + }
  11 +}
@@ -2,6 +2,7 @@ package io.livekit.android.room @@ -2,6 +2,7 @@ package io.livekit.android.room
2 2
3 import com.github.ajalt.timberkt.Timber 3 import com.github.ajalt.timberkt.Timber
4 import com.google.protobuf.util.JsonFormat 4 import com.google.protobuf.util.JsonFormat
  5 +import io.livekit.android.room.track.Track
5 import kotlinx.serialization.decodeFromString 6 import kotlinx.serialization.decodeFromString
6 import kotlinx.serialization.encodeToString 7 import kotlinx.serialization.encodeToString
7 import kotlinx.serialization.json.Json 8 import kotlinx.serialization.json.Json
@@ -144,9 +145,9 @@ constructor( @@ -144,9 +145,9 @@ constructor(
144 sendRequest(request) 145 sendRequest(request)
145 } 146 }
146 147
147 - fun sendMuteTrack(trackSid: String, muted: Boolean) { 148 + fun sendMuteTrack(trackSid: Track.Sid, muted: Boolean) {
148 val muteRequest = Rtc.MuteTrackRequest.newBuilder() 149 val muteRequest = Rtc.MuteTrackRequest.newBuilder()
149 - .setSid(trackSid) 150 + .setSid(trackSid.sid)
150 .setMuted(muted) 151 .setMuted(muted)
151 .build() 152 .build()
152 153
@@ -157,9 +158,9 @@ constructor( @@ -157,9 +158,9 @@ constructor(
157 sendRequest(request) 158 sendRequest(request)
158 } 159 }
159 160
160 - fun sendAddTrack(cid: String, name: String, type: Model.TrackType) { 161 + fun sendAddTrack(cid: Track.Cid, name: String, type: Model.TrackType) {
161 val addTrackRequest = Rtc.AddTrackRequest.newBuilder() 162 val addTrackRequest = Rtc.AddTrackRequest.newBuilder()
162 - .setCid(cid) 163 + .setCid(cid.cid)
163 .setName(name) 164 .setName(name)
164 .setType(type) 165 .setType(type)
165 .build() 166 .build()
@@ -183,7 +184,6 @@ constructor( @@ -183,7 +184,6 @@ constructor(
183 Timber.d { "error sending request: $request" } 184 Timber.d { "error sending request: $request" }
184 throw IllegalStateException() 185 throw IllegalStateException()
185 } 186 }
186 -  
187 } 187 }
188 188
189 fun handleSignalResponse(response: Rtc.SignalResponse) { 189 fun handleSignalResponse(response: Rtc.SignalResponse) {
@@ -224,15 +224,18 @@ constructor( @@ -224,15 +224,18 @@ constructor(
224 } 224 }
225 } 225 }
226 226
  227 + fun close() {
  228 + TODO("Not yet implemented")
  229 + }
  230 +
227 interface Listener { 231 interface Listener {
228 fun onJoin(info: Rtc.JoinResponse) 232 fun onJoin(info: Rtc.JoinResponse)
229 -  
230 fun onAnswer(sessionDescription: SessionDescription) 233 fun onAnswer(sessionDescription: SessionDescription)
231 fun onOffer(sessionDescription: SessionDescription) 234 fun onOffer(sessionDescription: SessionDescription)
232 -  
233 fun onTrickle(candidate: IceCandidate, target: Rtc.SignalTarget) 235 fun onTrickle(candidate: IceCandidate, target: Rtc.SignalTarget)
234 - fun onLocalTrackPublished(trackPublished: Rtc.TrackPublishedResponse) 236 + fun onLocalTrackPublished(response: Rtc.TrackPublishedResponse)
235 fun onParticipantUpdate(updates: List<Model.ParticipantInfo>) 237 fun onParticipantUpdate(updates: List<Model.ParticipantInfo>)
  238 + fun onActiveSpeakersChanged(speakers: List<Rtc.SpeakerInfo>)
236 fun onClose(reason: String, code: Int) 239 fun onClose(reason: String, code: Int)
237 fun onError(error: Error) 240 fun onError(error: Error)
238 } 241 }
1 package io.livekit.android.room 1 package io.livekit.android.room
2 2
3 import android.content.Context 3 import android.content.Context
  4 +import com.github.ajalt.timberkt.Timber
  5 +import io.livekit.android.dagger.InjectionNames
4 import io.livekit.android.room.track.Track 6 import io.livekit.android.room.track.Track
  7 +import io.livekit.android.room.track.TrackException
  8 +import io.livekit.android.room.util.CoroutineSdpObserver
  9 +import io.livekit.android.util.CloseableCoroutineScope
  10 +import io.livekit.android.util.Either
  11 +import kotlinx.coroutines.CoroutineDispatcher
  12 +import kotlinx.coroutines.SupervisorJob
  13 +import kotlinx.coroutines.launch
5 import livekit.Model 14 import livekit.Model
6 import livekit.Rtc 15 import livekit.Rtc
7 import org.webrtc.* 16 import org.webrtc.*
8 import javax.inject.Inject 17 import javax.inject.Inject
  18 +import javax.inject.Named
9 import kotlin.coroutines.Continuation 19 import kotlin.coroutines.Continuation
  20 +import kotlin.coroutines.resume
  21 +import kotlin.coroutines.suspendCoroutine
10 22
11 23
12 class RTCEngine 24 class RTCEngine
@@ -15,6 +27,7 @@ constructor( @@ -15,6 +27,7 @@ constructor(
15 private val appContext: Context, 27 private val appContext: Context,
16 val client: RTCClient, 28 val client: RTCClient,
17 pctFactory: PeerConnectionTransport.Factory, 29 pctFactory: PeerConnectionTransport.Factory,
  30 + @Named(InjectionNames.DISPATCHER_IO) ioDispatcher: CoroutineDispatcher,
18 ) : RTCClient.Listener { 31 ) : RTCClient.Listener {
19 32
20 var listener: Listener? = null 33 var listener: Listener? = null
@@ -39,6 +52,7 @@ constructor( @@ -39,6 +52,7 @@ constructor(
39 52
40 private var privateDataChannel: DataChannel 53 private var privateDataChannel: DataChannel
41 54
  55 + private val coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher)
42 init { 56 init {
43 val rtcConfig = PeerConnection.RTCConfiguration(RTCClient.DEFAULT_ICE_SERVERS).apply { 57 val rtcConfig = PeerConnection.RTCConfiguration(RTCClient.DEFAULT_ICE_SERVERS).apply {
44 sdpSemantics = PeerConnection.SdpSemantics.UNIFIED_PLAN 58 sdpSemantics = PeerConnection.SdpSemantics.UNIFIED_PLAN
@@ -55,12 +69,66 @@ constructor( @@ -55,12 +69,66 @@ constructor(
55 ) 69 )
56 } 70 }
57 71
58 - suspend fun join(url: String, token: String, isSecure: Boolean) { 72 + fun join(url: String, token: String, isSecure: Boolean) {
59 client.join(url, token, isSecure) 73 client.join(url, token, isSecure)
60 } 74 }
61 75
  76 + suspend fun addTrack(cid: Track.Cid, name: String, kind: Model.TrackType): Model.TrackInfo {
  77 + if (pendingTrackResolvers[cid] != null) {
  78 + throw TrackException.DuplicateTrackException("Track with same ID $cid has already been published!")
  79 + }
  80 +
  81 + return suspendCoroutine { cont ->
  82 + pendingTrackResolvers[cid] = cont
  83 + client.sendAddTrack(cid, name, kind)
  84 + }
  85 + }
  86 +
  87 + fun updateMuteStatus(sid: Track.Sid, muted: Boolean) {
  88 + client.sendMuteTrack(sid, muted)
  89 + }
  90 +
  91 + fun close() {
  92 + publisher.close()
  93 + subscriber.close()
  94 + client.close()
  95 + }
  96 +
62 fun negotiate() { 97 fun negotiate() {
63 - TODO("Not yet implemented") 98 + coroutineScope.launch {
  99 + val offerObserver = CoroutineSdpObserver()
  100 + publisher.peerConnection.createOffer(offerObserver, OFFER_CONSTRAINTS)
  101 + val offerOutcome = offerObserver.awaitCreate()
  102 + val sdpOffer = when (offerOutcome) {
  103 + is Either.Left -> offerOutcome.value
  104 + is Either.Right -> {
  105 + Timber.d { "error creating offer: ${offerOutcome.value}" }
  106 + return@launch
  107 + }
  108 + }
  109 +
  110 + if (sdpOffer == null) {
  111 + Timber.d { "sdp is missing during negotiation?" }
  112 + return@launch
  113 + }
  114 +
  115 + val setObserver = CoroutineSdpObserver()
  116 + publisher.peerConnection.setLocalDescription(setObserver, sdpOffer)
  117 + val setOutcome = setObserver.awaitSet()
  118 + when (setOutcome) {
  119 + is Either.Left -> client.sendOffer(sdpOffer)
  120 + is Either.Right -> Timber.d { "error setting local description: ${setOutcome.value}" }
  121 + }
  122 + }
  123 + }
  124 +
  125 + private fun onRTCConnected() {
  126 + Timber.v { "RTC Connected" }
  127 + rtcConnected = true
  128 + pendingCandidates.forEach { candidate ->
  129 + client.sendCandidate(candidate, Rtc.SignalTarget.PUBLISHER)
  130 + }
  131 + pendingCandidates.clear()
64 } 132 }
65 133
66 interface Listener { 134 interface Listener {
@@ -69,17 +137,58 @@ constructor( @@ -69,17 +137,58 @@ constructor(
69 fun onPublishLocalTrack(cid: String, track: Model.TrackInfo) 137 fun onPublishLocalTrack(cid: String, track: Model.TrackInfo)
70 fun onAddDataChannel(channel: DataChannel) 138 fun onAddDataChannel(channel: DataChannel)
71 fun onUpdateParticipants(updates: Array<out Model.ParticipantInfo>) 139 fun onUpdateParticipants(updates: Array<out Model.ParticipantInfo>)
72 - fun onUpdateSpeakers(speakers: Array<out Rtc.SpeakerInfo>) 140 + fun onUpdateSpeakers(speakers: List<Rtc.SpeakerInfo>)
73 fun onDisconnect(reason: String) 141 fun onDisconnect(reason: String)
74 fun onFailToConnect(error: Error) 142 fun onFailToConnect(error: Error)
75 } 143 }
76 144
77 companion object { 145 companion object {
78 private const val PRIVATE_DATA_CHANNEL_LABEL = "_private" 146 private const val PRIVATE_DATA_CHANNEL_LABEL = "_private"
  147 +
  148 + private val OFFER_CONSTRAINTS = MediaConstraints().apply {
  149 + with(mandatory) {
  150 + add(MediaConstraints.KeyValuePair("OfferToReceiveAudio", "false"))
  151 + add(MediaConstraints.KeyValuePair("OfferToReceiveVideo", "false"))
  152 + }
  153 + }
  154 +
  155 + private val MEDIA_CONSTRAINTS = MediaConstraints()
  156 +
  157 + private val CONN_CONSTRAINTS = MediaConstraints().apply {
  158 + with(optional) {
  159 + add(MediaConstraints.KeyValuePair("DtlsSrtpKeyAgreement", "true"))
  160 + }
  161 + }
79 } 162 }
80 163
81 override fun onJoin(info: Rtc.JoinResponse) { 164 override fun onJoin(info: Rtc.JoinResponse) {
82 - TODO("Not yet implemented") 165 + joinResponse = info
  166 +
  167 + coroutineScope.launch {
  168 + val offerObserver = CoroutineSdpObserver()
  169 + publisher.peerConnection.createOffer(offerObserver, OFFER_CONSTRAINTS)
  170 + val offerOutcome = offerObserver.awaitCreate()
  171 + val sdpOffer = when (offerOutcome) {
  172 + is Either.Left -> offerOutcome.value
  173 + is Either.Right -> {
  174 + Timber.d { "error creating offer: ${offerOutcome.value}" }
  175 + return@launch
  176 + }
  177 + }
  178 +
  179 + if (sdpOffer == null) {
  180 + Timber.d { "sdp is missing during negotiation?" }
  181 + return@launch
  182 + }
  183 +
  184 + val setObserver = CoroutineSdpObserver()
  185 + publisher.peerConnection.setLocalDescription(setObserver, sdpOffer)
  186 + val setOutcome = setObserver.awaitSet()
  187 + when (setOutcome) {
  188 + is Either.Left -> client.sendOffer(sdpOffer)
  189 + is Either.Right -> Timber.d { "error setting local description: ${setOutcome.value}" }
  190 + }
  191 + }
83 } 192 }
84 193
85 override fun onAnswer(sessionDescription: SessionDescription) { 194 override fun onAnswer(sessionDescription: SessionDescription) {
@@ -94,14 +203,36 @@ constructor( @@ -94,14 +203,36 @@ constructor(
94 TODO("Not yet implemented") 203 TODO("Not yet implemented")
95 } 204 }
96 205
97 - override fun onLocalTrackPublished(trackPublished: Rtc.TrackPublishedResponse) {  
98 - TODO("Not yet implemented") 206 + override fun onLocalTrackPublished(response: Rtc.TrackPublishedResponse) {
  207 + val cid = response.cid ?: run {
  208 + Timber.e { "local track published with null cid?" }
  209 + return
  210 + }
  211 +
  212 + val track = response.track
  213 + if (track == null) {
  214 + Timber.d { "local track published with null track info?" }
  215 + }
  216 +
  217 + Timber.v { "local track published $cid" }
  218 + val cont = pendingTrackResolvers.remove(cid)
  219 + if (cont == null) {
  220 + Timber.d { "missing track resolver for: $cid" }
  221 + return
  222 + }
  223 + cont.resume(response.track)
  224 + listener?.onPublishLocalTrack(cid, track)
  225 +
99 } 226 }
100 227
101 override fun onParticipantUpdate(updates: List<Model.ParticipantInfo>) { 228 override fun onParticipantUpdate(updates: List<Model.ParticipantInfo>) {
102 TODO("Not yet implemented") 229 TODO("Not yet implemented")
103 } 230 }
104 231
  232 + override fun onActiveSpeakersChanged(speakers: List<Rtc.SpeakerInfo>) {
  233 + listener?.onUpdateSpeakers(speakers)
  234 + }
  235 +
105 override fun onClose(reason: String, code: Int) { 236 override fun onClose(reason: String, code: Int) {
106 TODO("Not yet implemented") 237 TODO("Not yet implemented")
107 } 238 }
@@ -44,16 +44,20 @@ class Track(name: String, state: State) { @@ -44,16 +44,20 @@ class Track(name: String, state: State) {
44 } 44 }
45 } 45 }
46 46
47 -sealed class TrackException(message: String?, cause: Throwable?) : Exception(message, cause) {  
48 - class InvalidTrackTypeException(message: String?, cause: Throwable?) : 47 +sealed class TrackException(message: String? = null, cause: Throwable? = null) :
  48 + Exception(message, cause) {
  49 + class InvalidTrackTypeException(message: String? = null, cause: Throwable? = null) :
49 TrackException(message, cause) 50 TrackException(message, cause)
50 51
51 - class DuplicateTrackException(message: String?, cause: Throwable?) : 52 + class DuplicateTrackException(message: String? = null, cause: Throwable? = null) :
52 TrackException(message, cause) 53 TrackException(message, cause)
53 54
54 - class InvalidTrackStateException(message: String?, cause: Throwable?) : 55 + class InvalidTrackStateException(message: String? = null, cause: Throwable? = null) :
55 TrackException(message, cause) 56 TrackException(message, cause)
56 57
57 - class MediaException(message: String?, cause: Throwable?) : TrackException(message, cause)  
58 - class PublishException(message: String?, cause: Throwable?) : TrackException(message, cause) 58 + class MediaException(message: String? = null, cause: Throwable? = null) :
  59 + TrackException(message, cause)
  60 +
  61 + class PublishException(message: String? = null, cause: Throwable? = null) :
  62 + TrackException(message, cause)
59 } 63 }
  1 +package io.livekit.android.room.util
  2 +
  3 +import io.livekit.android.util.Either
  4 +import org.webrtc.SdpObserver
  5 +import org.webrtc.SessionDescription
  6 +import kotlin.coroutines.Continuation
  7 +import kotlin.coroutines.resume
  8 +import kotlin.coroutines.suspendCoroutine
  9 +
  10 +class CoroutineSdpObserver : SdpObserver {
  11 + private var createOutcome: Either<SessionDescription?, String?>? = null
  12 + set(value) {
  13 + field = value
  14 + if (value != null) {
  15 + val conts = pendingCreate.toList()
  16 + pendingCreate.clear()
  17 + conts.forEach {
  18 + it.resume(value)
  19 + }
  20 + }
  21 + }
  22 + private var pendingCreate = mutableListOf<Continuation<Either<SessionDescription?, String?>>>()
  23 +
  24 + private var setOutcome: Either<Unit, String?>? = null
  25 + set(value) {
  26 + field = value
  27 + if (value != null) {
  28 + val conts = pendingSets.toList()
  29 + pendingSets.clear()
  30 + conts.forEach {
  31 + it.resume(value)
  32 + }
  33 + }
  34 + }
  35 + private var pendingSets = mutableListOf<Continuation<Either<Unit, String?>>>()
  36 +
  37 + override fun onCreateSuccess(sdp: SessionDescription?) {
  38 + createOutcome = Either.Left(sdp)
  39 + }
  40 +
  41 + override fun onSetSuccess() {
  42 + setOutcome = Either.Left(Unit)
  43 + }
  44 +
  45 + override fun onCreateFailure(message: String?) {
  46 + createOutcome = Either.Right(message)
  47 + }
  48 +
  49 + override fun onSetFailure(message: String?) {
  50 + setOutcome = Either.Right(message)
  51 + }
  52 +
  53 + suspend fun awaitCreate() = suspendCoroutine<Either<SessionDescription?, String?>> { cont ->
  54 + val curOutcome = createOutcome
  55 + if (curOutcome != null) {
  56 + cont.resume(curOutcome)
  57 + } else {
  58 + pendingCreate.add(cont)
  59 + }
  60 + }
  61 +
  62 + suspend fun awaitSet() = suspendCoroutine<Either<Unit, String?>> { cont ->
  63 + val curOutcome = setOutcome
  64 + if (curOutcome != null) {
  65 + cont.resume(curOutcome)
  66 + } else {
  67 + pendingSets.add(cont)
  68 + }
  69 + }
  70 +}
  1 +package io.livekit.android.util
  2 +
  3 +import kotlinx.coroutines.CoroutineScope
  4 +import kotlinx.coroutines.cancel
  5 +import java.io.Closeable
  6 +import kotlin.coroutines.CoroutineContext
  7 +
  8 +internal class CloseableCoroutineScope(context: CoroutineContext) : Closeable, CoroutineScope {
  9 + override val coroutineContext: CoroutineContext = context
  10 +
  11 + override fun close() {
  12 + coroutineContext.cancel()
  13 + }
  14 +}
  1 +package io.livekit.android.util
  2 +
  3 +sealed class Either<out A, out B> {
  4 + class Left<A>(val value: A) : Either<A, Nothing>()
  5 + class Right<B>(val value: B) : Either<Nothing, B>()
  6 +}