davidliu
Committed by GitHub

Permissions API (#37)

* permissions API

* naming

* don't add track if disallowed

* hold on to track for now

* room event for subscription permission update

* update test

* language cleanup

* Move track permission handling to RemoteParticipant to allow emission of ParticipantEvent

* Subscription status enum

* revert preventing subscription to disallowed track

* keep subscribe boolean and add extra subscriptionStatus

* fix build
package io.livekit.android.events
import io.livekit.android.room.Room
import io.livekit.android.room.participant.LocalParticipant
import io.livekit.android.room.participant.Participant
import io.livekit.android.room.participant.RemoteParticipant
... ... @@ -107,4 +108,14 @@ sealed class ParticipantEvent(open val participant: Participant) : Event() {
val trackPublication: TrackPublication,
val streamState: Track.StreamState
) : ParticipantEvent(participant)
/**
* A remote track's subscription permissions have changed.
*/
class TrackSubscriptionPermissionChanged(
override val participant: RemoteParticipant,
val trackPublication: RemoteTrackPublication,
val subscriptionAllowed: Boolean
) : ParticipantEvent(participant)
}
\ No newline at end of file
... ...
... ... @@ -6,6 +6,7 @@ import io.livekit.android.room.participant.LocalParticipant
import io.livekit.android.room.participant.Participant
import io.livekit.android.room.participant.RemoteParticipant
import io.livekit.android.room.track.LocalTrackPublication
import io.livekit.android.room.track.RemoteTrackPublication
import io.livekit.android.room.track.Track
import io.livekit.android.room.track.TrackPublication
... ... @@ -131,6 +132,16 @@ sealed class RoomEvent(val room: Room) : Event() {
) : RoomEvent(room)
/**
* A remote track's subscription permissions have changed.
*/
class TrackSubscriptionPermissionChanged(
room: Room,
val participant: RemoteParticipant,
val trackPublication: RemoteTrackPublication,
val subscriptionAllowed: Boolean
) : RoomEvent(room)
/**
* Received data published by another participant
*/
class DataReceived(room: Room, val data: ByteArray, val participant: RemoteParticipant) : RoomEvent(room)
... ...
... ... @@ -3,6 +3,7 @@ package io.livekit.android.room
import android.os.SystemClock
import io.livekit.android.ConnectOptions
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.room.participant.ParticipantTrackPermission
import io.livekit.android.room.track.TrackException
import io.livekit.android.room.util.*
import io.livekit.android.util.CloseableCoroutineScope
... ... @@ -229,6 +230,13 @@ internal constructor(
}
}
fun updateSubscriptionPermissions(
allParticipants: Boolean,
participantTrackPermissions: List<ParticipantTrackPermission>
) {
client.sendUpdateSubscriptionPermissions(allParticipants, participantTrackPermissions)
}
fun updateMuteStatus(sid: String, muted: Boolean) {
client.sendMuteTrack(sid, muted)
}
... ... @@ -386,6 +394,7 @@ internal constructor(
fun onUserPacket(packet: LivekitModels.UserPacket, kind: LivekitModels.DataPacket.Kind)
fun onStreamStateUpdate(streamStates: List<LivekitRtc.StreamStateInfo>)
fun onSubscribedQualityUpdate(subscribedQualityUpdate: LivekitRtc.SubscribedQualityUpdate)
fun onSubscriptionPermissionUpdate(subscriptionPermissionUpdate: LivekitRtc.SubscriptionPermissionUpdate)
}
companion object {
... ... @@ -531,6 +540,10 @@ internal constructor(
listener?.onSubscribedQualityUpdate(subscribedQualityUpdate)
}
override fun onSubscriptionPermissionUpdate(subscriptionPermissionUpdate: LivekitRtc.SubscriptionPermissionUpdate) {
listener?.onSubscriptionPermissionUpdate(subscriptionPermissionUpdate)
}
//--------------------------------- DataChannel.Observer ------------------------------------//
override fun onBufferedAmountChange(previousAmount: Long) {
... ...
... ... @@ -226,6 +226,14 @@ constructor(
it.streamState
)
)
is ParticipantEvent.TrackSubscriptionPermissionChanged -> eventBus.postEvent(
RoomEvent.TrackSubscriptionPermissionChanged(
this@Room,
it.participant,
it.trackPublication,
it.subscriptionAllowed
)
)
}
}
}
... ... @@ -483,6 +491,11 @@ constructor(
localParticipant.handleSubscribedQualityUpdate(subscribedQualityUpdate)
}
override fun onSubscriptionPermissionUpdate(subscriptionPermissionUpdate: LivekitRtc.SubscriptionPermissionUpdate) {
val participant = getParticipant(subscriptionPermissionUpdate.participantSid) as? RemoteParticipant ?: return
participant.onSubscriptionPermissionUpdate(subscriptionPermissionUpdate)
}
/**
* @suppress
*/
... ...
... ... @@ -5,6 +5,7 @@ import com.vdurmont.semver4j.Semver
import io.livekit.android.ConnectOptions
import io.livekit.android.Version
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.room.participant.ParticipantTrackPermission
import io.livekit.android.room.track.Track
import io.livekit.android.util.CloseableCoroutineScope
import io.livekit.android.util.Either
... ... @@ -329,6 +330,21 @@ constructor(
sendRequest(request)
}
fun sendUpdateSubscriptionPermissions(
allParticipants: Boolean,
participantTrackPermissions: List<ParticipantTrackPermission>
) {
val update = LivekitRtc.UpdateSubscriptionPermissions.newBuilder()
.setAllParticipants(allParticipants)
.addAllTrackPermissions(participantTrackPermissions.map { it.toProto() })
val request = LivekitRtc.SignalRequest.newBuilder()
.setSubscriptionPermissions(update)
.build()
sendRequest(request)
}
fun sendLeave() {
val request = LivekitRtc.SignalRequest.newBuilder()
.setLeave(LivekitRtc.LeaveRequest.newBuilder().build())
... ... @@ -433,7 +449,7 @@ constructor(
listener?.onSubscribedQualityUpdate(response.subscribedQualityUpdate)
}
LivekitRtc.SignalResponse.MessageCase.SUBSCRIPTION_PERMISSION_UPDATE -> {
// TODO
listener?.onSubscriptionPermissionUpdate(response.subscriptionPermissionUpdate)
}
LivekitRtc.SignalResponse.MessageCase.MESSAGE_NOT_SET,
null -> {
... ... @@ -463,6 +479,7 @@ constructor(
fun onError(error: Throwable)
fun onStreamStateUpdate(streamStates: List<LivekitRtc.StreamStateInfo>)
fun onSubscribedQualityUpdate(subscribedQualityUpdate: LivekitRtc.SubscribedQualityUpdate)
fun onSubscriptionPermissionUpdate(subscriptionPermissionUpdate: LivekitRtc.SubscriptionPermissionUpdate)
}
companion object {
... ...
... ... @@ -419,6 +419,30 @@ internal constructor(
}
}
/**
* Control who can subscribe to LocalParticipant's published tracks.
*
* By default, all participants can subscribe. This allows fine-grained control over
* who is able to subscribe at a participant and track level.
*
* Note: if access is given at a track-level (i.e. both [allParticipantsAllowed] and
* [ParticipantTrackPermission.allTracksAllowed] are false), any newer published tracks
* will not grant permissions to any participants and will require a subsequent
* permissions update to allow subscription.
*
* @param allParticipantsAllowed Allows all participants to subscribe all tracks.
* Takes precedence over [participantTrackPermissions] if set to true.
* By default this is set to true.
* @param participantTrackPermissions Full list of individual permissions per
* participant/track. Any omitted participants will not receive any permissions.
*/
fun setTrackSubscriptionPermissions(
allParticipantsAllowed: Boolean,
participantTrackPermissions: List<ParticipantTrackPermission> = emptyList()
) {
engine.updateSubscriptionPermissions(allParticipantsAllowed, participantTrackPermissions)
}
fun unpublishTrack(track: Track) {
val publication = localTrackPublications.firstOrNull { it.track == track }
if (publication === null) {
... ... @@ -616,4 +640,29 @@ data class AudioTrackPublishOptions(
base.audioBitrate,
base.dtx
)
}
data class ParticipantTrackPermission(
/**
* The participant id this permission applies to.
*/
val participantSid: String,
/**
* If set to true, the target participant can subscribe to all tracks from the local participant.
*
* Takes precedence over [allowedTrackSids].
*/
val allTracksAllowed: Boolean,
/**
* The list of track ids that the target participant can subscribe to.
*/
val allowedTrackSids: List<String> = emptyList()
) {
fun toProto(): LivekitRtc.TrackPermission {
return LivekitRtc.TrackPermission.newBuilder()
.setParticipantSid(participantSid)
.setAllTracks(allTracksAllowed)
.addAllTrackSids(allowedTrackSids)
.build()
}
}
\ No newline at end of file
... ...
... ... @@ -10,6 +10,7 @@ import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import livekit.LivekitModels
import livekit.LivekitRtc
import org.webrtc.AudioTrack
import org.webrtc.MediaStreamTrack
import org.webrtc.VideoTrack
... ... @@ -19,8 +20,8 @@ class RemoteParticipant(
identity: String? = null,
val signalClient: SignalClient,
private val ioDispatcher: CoroutineDispatcher,
defaultdispatcher: CoroutineDispatcher,
) : Participant(sid, identity, defaultdispatcher) {
defaultDispatcher: CoroutineDispatcher,
) : Participant(sid, identity, defaultDispatcher) {
/**
* @suppress
*/
... ... @@ -28,18 +29,18 @@ class RemoteParticipant(
info: LivekitModels.ParticipantInfo,
signalClient: SignalClient,
ioDispatcher: CoroutineDispatcher,
defaultdispatcher: CoroutineDispatcher,
defaultDispatcher: CoroutineDispatcher,
) : this(
info.sid,
info.identity,
signalClient,
ioDispatcher,
defaultdispatcher
defaultDispatcher
) {
updateFromInfo(info)
}
private val coroutineScope = CloseableCoroutineScope(SupervisorJob())
private val coroutineScope = CloseableCoroutineScope(defaultDispatcher + SupervisorJob())
fun getTrackPublication(sid: String): RemoteTrackPublication? = tracks[sid] as? RemoteTrackPublication
... ... @@ -98,17 +99,8 @@ class RemoteParticipant(
triesLeft: Int = 20
) {
val publication = getTrackPublication(sid)
val track: Track = when (val kind = mediaTrack.kind()) {
KIND_AUDIO -> AudioTrack(rtcTrack = mediaTrack as AudioTrack, name = "")
KIND_VIDEO -> RemoteVideoTrack(
rtcTrack = mediaTrack as VideoTrack,
name = "",
autoManageVideo = autoManageVideo,
dispatcher = ioDispatcher
)
else -> throw TrackException.InvalidTrackTypeException("invalid track type: $kind")
}
// We may receive subscribed tracks before publications come in. Retry until then.
if (publication == null) {
if (triesLeft == 0) {
val message = "Could not find published track with sid: $sid"
... ... @@ -127,6 +119,17 @@ class RemoteParticipant(
return
}
val track: Track = when (val kind = mediaTrack.kind()) {
KIND_AUDIO -> AudioTrack(rtcTrack = mediaTrack as AudioTrack, name = "")
KIND_VIDEO -> RemoteVideoTrack(
rtcTrack = mediaTrack as VideoTrack,
name = "",
autoManageVideo = autoManageVideo,
dispatcher = ioDispatcher
)
else -> throw TrackException.InvalidTrackTypeException("invalid track type: $kind")
}
publication.track = track
track.name = publication.name
track.sid = publication.sid
... ... @@ -158,6 +161,19 @@ class RemoteParticipant(
}
}
internal fun onSubscriptionPermissionUpdate(subscriptionPermissionUpdate: LivekitRtc.SubscriptionPermissionUpdate) {
val pub = tracks[subscriptionPermissionUpdate.trackSid] as? RemoteTrackPublication ?: return
if (pub.subscriptionAllowed != subscriptionPermissionUpdate.allowed) {
pub.subscriptionAllowed = subscriptionPermissionUpdate.allowed
eventBus.postEvent(
ParticipantEvent.TrackSubscriptionPermissionChanged(this, pub, pub.subscriptionAllowed),
coroutineScope
)
}
}
// Internal methods just for posting events.
internal fun onDataReceived(data: ByteArray) {
listener?.onDataReceived(data, this)
... ...
... ... @@ -42,20 +42,6 @@ class RemoteTrackPublication(
}
}
private fun handleVisibilityChanged(trackEvent: TrackEvent.VisibilityChanged) {
disabled = !trackEvent.isVisible
sendUpdateTrackSettings.invoke()
}
private fun handleVideoDimensionsChanged(trackEvent: TrackEvent.VideoDimensionsChanged) {
videoDimensions = trackEvent.newDimensions
sendUpdateTrackSettings.invoke()
}
private fun handleStreamStateChanged(trackEvent: TrackEvent.StreamStateChanged) {
participant.get()?.onTrackStreamStateChanged(trackEvent)
}
private var trackJob: Job? = null
private var unsubscribed: Boolean = false
... ... @@ -63,16 +49,36 @@ class RemoteTrackPublication(
private var videoQuality: LivekitModels.VideoQuality? = LivekitModels.VideoQuality.HIGH
private var videoDimensions: Track.Dimensions? = null
var subscriptionAllowed: Boolean = true
internal set
val isAutoManaged: Boolean
get() = (track as? RemoteVideoTrack)?.autoManageVideo ?: false
/**
* Returns true if track is subscribed, and ready for playback
*
* @see [subscriptionStatus]
*/
override val subscribed: Boolean
get() {
if (unsubscribed) {
if (unsubscribed || !subscriptionAllowed) {
return false
}
return super.subscribed
}
val subscriptionStatus: SubscriptionStatus
get() {
return if (!unsubscribed || track == null) {
SubscriptionStatus.UNSUBSCRIBED
} else if (!subscriptionAllowed) {
SubscriptionStatus.SUBSCRIBED_AND_NOT_ALLOWED
} else {
SubscriptionStatus.SUBSCRIBED
}
}
override var muted: Boolean = false
set(v) {
if (field == v) {
... ... @@ -88,12 +94,11 @@ class RemoteTrackPublication(
}
/**
* subscribe or unsubscribe from this track
* Subscribe or unsubscribe from this track
*/
fun setSubscribed(subscribed: Boolean) {
unsubscribed = !subscribed
val participant = this.participant.get() as? RemoteParticipant ?: return
participant.signalClient.sendUpdateSubscription(sid, !unsubscribed)
}
... ... @@ -147,6 +152,20 @@ class RemoteTrackPublication(
sendUpdateTrackSettings.invoke()
}
private fun handleVisibilityChanged(trackEvent: TrackEvent.VisibilityChanged) {
disabled = !trackEvent.isVisible
sendUpdateTrackSettings.invoke()
}
private fun handleVideoDimensionsChanged(trackEvent: TrackEvent.VideoDimensionsChanged) {
videoDimensions = trackEvent.newDimensions
sendUpdateTrackSettings.invoke()
}
private fun handleStreamStateChanged(trackEvent: TrackEvent.StreamStateChanged) {
participant.get()?.onTrackStreamStateChanged(trackEvent)
}
// Debounce just in case multiple settings get changed at once.
private val sendUpdateTrackSettings = debounce<Unit, Unit>(100L, CoroutineScope(ioDispatcher)) {
sendUpdateTrackSettingsImpl()
... ... @@ -162,4 +181,21 @@ class RemoteTrackPublication(
videoQuality
)
}
enum class SubscriptionStatus {
/**
* Has a valid track, receiving data.
*/
SUBSCRIBED,
/**
* Has a track, but no data will be received due to permissions.
*/
SUBSCRIBED_AND_NOT_ALLOWED,
/**
* Not subscribed.
*/
UNSUBSCRIBED
}
}
\ No newline at end of file
... ...
... ... @@ -4,6 +4,9 @@ import org.webrtc.AudioTrack
import org.webrtc.MediaStream
import org.webrtc.VideoTrack
fun createMediaStreamId(participantSid: String, trackSid: String) =
"${TestData.REMOTE_PARTICIPANT.sid}|${TestData.REMOTE_AUDIO_TRACK.sid}"
class MockMediaStream(private val id: String = "id") : MediaStream(1L) {
override fun addTrack(track: AudioTrack): Boolean {
... ...
... ... @@ -9,10 +9,10 @@ import kotlinx.coroutines.test.TestCoroutineDispatcher
import javax.inject.Named
@Module
object TestCoroutinesModule {
class TestCoroutinesModule(
@OptIn(ExperimentalCoroutinesApi::class)
val coroutineDispatcher: CoroutineDispatcher = TestCoroutineDispatcher()
) {
@Provides
@Named(InjectionNames.DISPATCHER_DEFAULT)
... ...
... ... @@ -23,6 +23,6 @@ interface TestLiveKitComponent : LiveKitComponent {
@Component.Factory
interface Factory {
fun create(@BindsInstance appContext: Context): TestLiveKitComponent
fun create(@BindsInstance appContext: Context, coroutinesModule: TestCoroutinesModule = TestCoroutinesModule()): TestLiveKitComponent
}
}
\ No newline at end of file
... ...
... ... @@ -4,16 +4,17 @@ import android.content.Context
import androidx.test.core.app.ApplicationProvider
import io.livekit.android.coroutines.TestCoroutineRule
import io.livekit.android.events.EventCollector
import io.livekit.android.events.ParticipantEvent
import io.livekit.android.events.RoomEvent
import io.livekit.android.mock.*
import io.livekit.android.mock.dagger.DaggerTestLiveKitComponent
import io.livekit.android.mock.dagger.TestCoroutinesModule
import io.livekit.android.room.participant.ConnectionQuality
import io.livekit.android.room.track.Track
import io.livekit.android.util.toOkioByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.runBlockingTest
import livekit.LivekitRtc
import org.junit.Assert
import org.junit.Before
import org.junit.Rule
... ... @@ -41,7 +42,7 @@ class RoomMockE2ETest {
context = ApplicationProvider.getApplicationContext()
val component = DaggerTestLiveKitComponent
.factory()
.create(context)
.create(context, TestCoroutinesModule(coroutineRule.dispatcher))
room = component.roomFactory()
.create(context)
... ... @@ -55,9 +56,10 @@ class RoomMockE2ETest {
token = "",
)
}
wsFactory.listener.onMessage(wsFactory.ws, SignalClientTest.JOIN.toOkioByteString())
// PeerTransport negotiation is on a debounce delay.
coroutineRule.dispatcher.advanceTimeBy(1000L)
runBlockingTest {
job.join()
}
... ... @@ -69,7 +71,7 @@ class RoomMockE2ETest {
}
@Test
fun connectFailureProperlyContinues(){
fun connectFailureProperlyContinues() {
var didThrowException = false
val job = coroutineRule.scope.launch {
... ... @@ -91,6 +93,7 @@ class RoomMockE2ETest {
Assert.assertTrue(didThrowException)
}
@Test
fun roomUpdateTest() {
connect()
... ... @@ -203,7 +206,14 @@ class RoomMockE2ETest {
// add track.
room.onAddTrack(
MockAudioStreamTrack(),
arrayOf(MockMediaStream(id = "${TestData.REMOTE_PARTICIPANT.sid}|${TestData.REMOTE_AUDIO_TRACK.sid}"))
arrayOf(
MockMediaStream(
id = createMediaStreamId(
TestData.REMOTE_PARTICIPANT.sid,
TestData.REMOTE_AUDIO_TRACK.sid
)
)
)
)
val eventCollector = EventCollector(room.events, coroutineRule.scope)
wsFactory.listener.onMessage(
... ... @@ -220,6 +230,41 @@ class RoomMockE2ETest {
}
@Test
fun trackSubscriptionPermissionChanged() {
connect()
wsFactory.listener.onMessage(
wsFactory.ws,
SignalClientTest.PARTICIPANT_JOIN.toOkioByteString()
)
room.onAddTrack(
MockAudioStreamTrack(),
arrayOf(
MockMediaStream(
id = createMediaStreamId(
TestData.REMOTE_PARTICIPANT.sid,
TestData.REMOTE_AUDIO_TRACK.sid
)
)
)
)
val eventCollector = EventCollector(room.events, coroutineRule.scope)
wsFactory.listener.onMessage(
wsFactory.ws,
SignalClientTest.SUBSCRIPTION_PERMISSION_UPDATE.toOkioByteString()
)
val events = eventCollector.stopCollecting()
Assert.assertEquals(1, events.size)
Assert.assertEquals(true, events[0] is RoomEvent.TrackSubscriptionPermissionChanged)
val event = events[0] as RoomEvent.TrackSubscriptionPermissionChanged
Assert.assertEquals(TestData.REMOTE_PARTICIPANT.sid, event.participant.sid)
Assert.assertEquals(TestData.REMOTE_AUDIO_TRACK.sid, event.trackPublication.sid)
Assert.assertEquals(false, event.subscriptionAllowed)
}
@Test
fun leave() {
connect()
val eventCollector = EventCollector(room.events, coroutineRule.scope)
... ...
... ... @@ -241,6 +241,16 @@ class SignalClientTest {
}
build()
}
val SUBSCRIPTION_PERMISSION_UPDATE = with(LivekitRtc.SignalResponse.newBuilder()) {
subscriptionPermissionUpdate = with(LivekitRtc.SubscriptionPermissionUpdate.newBuilder()) {
participantSid = TestData.REMOTE_PARTICIPANT.sid
trackSid = TestData.REMOTE_AUDIO_TRACK.sid
allowed = false
build()
}
build()
}
val LEAVE = with(LivekitRtc.SignalResponse.newBuilder()) {
leave = with(leaveBuilder) {
build()
... ...
... ... @@ -24,7 +24,7 @@ class RemoteParticipantTest {
"sid",
signalClient = signalClient,
ioDispatcher = coroutineRule.dispatcher,
defaultdispatcher = coroutineRule.dispatcher,
defaultDispatcher = coroutineRule.dispatcher,
)
}
... ... @@ -38,7 +38,7 @@ class RemoteParticipantTest {
info,
signalClient,
ioDispatcher = coroutineRule.dispatcher,
defaultdispatcher = coroutineRule.dispatcher,
defaultDispatcher = coroutineRule.dispatcher,
)
assertEquals(1, participant.tracks.values.size)
... ...