davidliu
Committed by GitHub

Add fps to subscribe setting (#207)

@@ -374,6 +374,7 @@ constructor( @@ -374,6 +374,7 @@ constructor(
374 disabled: Boolean, 374 disabled: Boolean,
375 videoDimensions: Track.Dimensions?, 375 videoDimensions: Track.Dimensions?,
376 videoQuality: LivekitModels.VideoQuality?, 376 videoQuality: LivekitModels.VideoQuality?,
  377 + fps: Int?,
377 ) { 378 ) {
378 val trackSettings = LivekitRtc.UpdateTrackSettings.newBuilder() 379 val trackSettings = LivekitRtc.UpdateTrackSettings.newBuilder()
379 .addTrackSids(sid) 380 .addTrackSids(sid)
@@ -388,6 +389,10 @@ constructor( @@ -388,6 +389,10 @@ constructor(
388 // default to HIGH 389 // default to HIGH
389 quality = LivekitModels.VideoQuality.HIGH 390 quality = LivekitModels.VideoQuality.HIGH
390 } 391 }
  392 +
  393 + if(fps != null){
  394 + setFps(fps)
  395 + }
391 } 396 }
392 397
393 val request = LivekitRtc.SignalRequest.newBuilder() 398 val request = LivekitRtc.SignalRequest.newBuilder()
@@ -53,6 +53,7 @@ class RemoteTrackPublication( @@ -53,6 +53,7 @@ class RemoteTrackPublication(
53 private var disabled: Boolean = false 53 private var disabled: Boolean = false
54 private var videoQuality: LivekitModels.VideoQuality? = LivekitModels.VideoQuality.HIGH 54 private var videoQuality: LivekitModels.VideoQuality? = LivekitModels.VideoQuality.HIGH
55 private var videoDimensions: Track.Dimensions? = null 55 private var videoDimensions: Track.Dimensions? = null
  56 + private var fps: Int? = null
56 57
57 var subscriptionAllowed: Boolean = true 58 var subscriptionAllowed: Boolean = true
58 internal set 59 internal set
@@ -128,10 +129,12 @@ class RemoteTrackPublication( @@ -128,10 +129,12 @@ class RemoteTrackPublication(
128 } 129 }
129 130
130 /** 131 /**
131 - * for tracks that support simulcasting, directly adjust subscribed quality 132 + * For tracks that support simulcasting, directly adjust subscribed quality
132 * 133 *
133 - * this indicates the highest quality the client can accept. if network bandwidth does not  
134 - * allow, server will automatically reduce quality to optimize for uninterrupted video 134 + * This indicates the highest quality the client can accept. If network bandwidth does not
  135 + * allow, server will automatically reduce quality to optimize for uninterrupted video.
  136 + *
  137 + * Will override previous calls to [setVideoDimensions].
135 */ 138 */
136 fun setVideoQuality(quality: LivekitModels.VideoQuality) { 139 fun setVideoQuality(quality: LivekitModels.VideoQuality) {
137 if (isAutoManaged 140 if (isAutoManaged
@@ -148,6 +151,8 @@ class RemoteTrackPublication( @@ -148,6 +151,8 @@ class RemoteTrackPublication(
148 151
149 /** 152 /**
150 * Update the dimensions that the server will use for determining the video quality to send down. 153 * Update the dimensions that the server will use for determining the video quality to send down.
  154 + *
  155 + * Will override previous calls to [setVideoQuality].
151 */ 156 */
152 fun setVideoDimensions(dimensions: Track.Dimensions) { 157 fun setVideoDimensions(dimensions: Track.Dimensions) {
153 if (isAutoManaged 158 if (isAutoManaged
@@ -163,6 +168,22 @@ class RemoteTrackPublication( @@ -163,6 +168,22 @@ class RemoteTrackPublication(
163 sendUpdateTrackSettings.invoke() 168 sendUpdateTrackSettings.invoke()
164 } 169 }
165 170
  171 + /**
  172 + * Update the fps that the server will use for determining the video quality to send down.
  173 + */
  174 + fun setVideoFps(fps: Int?) {
  175 + if (isAutoManaged
  176 + || !subscribed
  177 + || this.fps == fps
  178 + || track !is VideoTrack
  179 + ) {
  180 + return
  181 + }
  182 +
  183 + this.fps = fps
  184 + sendUpdateTrackSettings.invoke()
  185 + }
  186 +
166 private fun handleVisibilityChanged(isVisible: Boolean) { 187 private fun handleVisibilityChanged(isVisible: Boolean) {
167 disabled = !isVisible 188 disabled = !isVisible
168 sendUpdateTrackSettings.invoke() 189 sendUpdateTrackSettings.invoke()
@@ -194,7 +215,8 @@ class RemoteTrackPublication( @@ -194,7 +215,8 @@ class RemoteTrackPublication(
194 sid, 215 sid,
195 disabled, 216 disabled,
196 videoDimensions, 217 videoDimensions,
197 - videoQuality 218 + videoQuality,
  219 + fps
198 ) 220 )
199 } 221 }
200 222
@@ -5,7 +5,7 @@ import org.webrtc.MediaStream @@ -5,7 +5,7 @@ import org.webrtc.MediaStream
5 import org.webrtc.VideoTrack 5 import org.webrtc.VideoTrack
6 6
7 fun createMediaStreamId(participantSid: String, trackSid: String) = 7 fun createMediaStreamId(participantSid: String, trackSid: String) =
8 - "${TestData.REMOTE_PARTICIPANT.sid}|${TestData.REMOTE_AUDIO_TRACK.sid}" 8 + "${participantSid}|${trackSid}"
9 9
10 class MockMediaStream(private val id: String = "id") : MediaStream(1L) { 10 class MockMediaStream(private val id: String = "id") : MediaStream(1L) {
11 11
@@ -5,11 +5,14 @@ import org.webrtc.VideoTrack @@ -5,11 +5,14 @@ import org.webrtc.VideoTrack
5 5
6 class MockVideoStreamTrack( 6 class MockVideoStreamTrack(
7 val id: String = "id", 7 val id: String = "id",
8 - val kind: String = AUDIO_TRACK_KIND, 8 + val kind: String = VIDEO_TRACK_KIND,
9 var enabled: Boolean = true, 9 var enabled: Boolean = true,
10 var state: State = State.LIVE, 10 var state: State = State.LIVE,
11 ) : VideoTrack(1L) { 11 ) : VideoTrack(1L) {
12 val sinks = mutableSetOf<VideoSink>() 12 val sinks = mutableSetOf<VideoSink>()
  13 +
  14 + private var shouldReceive = true
  15 +
13 override fun id(): String = id 16 override fun id(): String = id
14 17
15 override fun kind(): String = kind 18 override fun kind(): String = kind
@@ -21,6 +24,12 @@ class MockVideoStreamTrack( @@ -21,6 +24,12 @@ class MockVideoStreamTrack(
21 return true 24 return true
22 } 25 }
23 26
  27 + override fun shouldReceive() = shouldReceive
  28 +
  29 + override fun setShouldReceive(shouldReceive: Boolean) {
  30 + this.shouldReceive = shouldReceive
  31 + }
  32 +
24 override fun state(): State { 33 override fun state(): State {
25 return state 34 return state
26 } 35 }
@@ -47,4 +47,8 @@ class MockWebSocket( @@ -47,4 +47,8 @@ class MockWebSocket(
47 mutableSentRequests.add(bytes) 47 mutableSentRequests.add(bytes)
48 return !isClosed 48 return !isClosed
49 } 49 }
  50 +
  51 + fun clearRequests() {
  52 + mutableSentRequests.clear()
  53 + }
50 } 54 }
@@ -16,6 +16,12 @@ object TestData { @@ -16,6 +16,12 @@ object TestData {
16 build() 16 build()
17 } 17 }
18 18
  19 + val REMOTE_VIDEO_TRACK = with(LivekitModels.TrackInfo.newBuilder()) {
  20 + sid = "remote_video_track_sid"
  21 + type = LivekitModels.TrackType.VIDEO
  22 + build()
  23 + }
  24 +
19 val LOCAL_PARTICIPANT = with(LivekitModels.ParticipantInfo.newBuilder()) { 25 val LOCAL_PARTICIPANT = with(LivekitModels.ParticipantInfo.newBuilder()) {
20 sid = "local_participant_sid" 26 sid = "local_participant_sid"
21 identity = "local_participant_identity" 27 identity = "local_participant_identity"
@@ -44,6 +50,7 @@ object TestData { @@ -44,6 +50,7 @@ object TestData {
44 build() 50 build()
45 } 51 }
46 addTracks(REMOTE_AUDIO_TRACK) 52 addTracks(REMOTE_AUDIO_TRACK)
  53 + addTracks(REMOTE_VIDEO_TRACK)
47 build() 54 build()
48 } 55 }
49 56
@@ -17,7 +17,6 @@ import io.livekit.android.room.track.Track @@ -17,7 +17,6 @@ import io.livekit.android.room.track.Track
17 import io.livekit.android.util.flow 17 import io.livekit.android.util.flow
18 import io.livekit.android.util.toOkioByteString 18 import io.livekit.android.util.toOkioByteString
19 import junit.framework.Assert.assertEquals 19 import junit.framework.Assert.assertEquals
20 -import junit.framework.Assert.assertNull  
21 import kotlinx.coroutines.ExperimentalCoroutinesApi 20 import kotlinx.coroutines.ExperimentalCoroutinesApi
22 import kotlinx.coroutines.launch 21 import kotlinx.coroutines.launch
23 import org.junit.Assert 22 import org.junit.Assert
@@ -137,9 +136,14 @@ class RoomMockE2ETest : MockE2ETest() { @@ -137,9 +136,14 @@ class RoomMockE2ETest : MockE2ETest() {
137 simulateMessageFromServer(SignalClientTest.PARTICIPANT_JOIN) 136 simulateMessageFromServer(SignalClientTest.PARTICIPANT_JOIN)
138 val events = eventCollector.stopCollecting() 137 val events = eventCollector.stopCollecting()
139 138
140 - Assert.assertEquals(2, events.size)  
141 - Assert.assertEquals(true, events[0] is RoomEvent.ParticipantConnected)  
142 - Assert.assertEquals(true, events[1] is RoomEvent.TrackPublished) 139 + assertIsClassList(
  140 + listOf(
  141 + RoomEvent.ParticipantConnected::class.java,
  142 + RoomEvent.TrackPublished::class.java,
  143 + RoomEvent.TrackPublished::class.java,
  144 + ),
  145 + events
  146 + )
143 } 147 }
144 148
145 @Test 149 @Test
@@ -157,9 +161,14 @@ class RoomMockE2ETest : MockE2ETest() { @@ -157,9 +161,14 @@ class RoomMockE2ETest : MockE2ETest() {
157 ) 161 )
158 val events = eventCollector.stopCollecting() 162 val events = eventCollector.stopCollecting()
159 163
160 - Assert.assertEquals(2, events.size)  
161 - Assert.assertEquals(true, events[0] is RoomEvent.TrackUnpublished)  
162 - Assert.assertEquals(true, events[1] is RoomEvent.ParticipantDisconnected) 164 + assertIsClassList(
  165 + listOf(
  166 + RoomEvent.TrackUnpublished::class.java,
  167 + RoomEvent.TrackUnpublished::class.java,
  168 + RoomEvent.ParticipantDisconnected::class.java,
  169 + ),
  170 + events
  171 + )
163 } 172 }
164 173
165 @Test 174 @Test
@@ -16,7 +16,6 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi @@ -16,7 +16,6 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi
16 import kotlinx.coroutines.flow.MutableSharedFlow 16 import kotlinx.coroutines.flow.MutableSharedFlow
17 import kotlinx.coroutines.flow.SharedFlow 17 import kotlinx.coroutines.flow.SharedFlow
18 import kotlinx.coroutines.test.runTest 18 import kotlinx.coroutines.test.runTest
19 -import org.junit.Assert  
20 import org.junit.Assert.* 19 import org.junit.Assert.*
21 import org.junit.Before 20 import org.junit.Before
22 import org.junit.Rule 21 import org.junit.Rule
@@ -211,11 +210,16 @@ class RoomTest { @@ -211,11 +210,16 @@ class RoomTest {
211 room.onEngineDisconnected(DisconnectReason.CLIENT_INITIATED) 210 room.onEngineDisconnected(DisconnectReason.CLIENT_INITIATED)
212 val events = eventCollector.stopCollecting() 211 val events = eventCollector.stopCollecting()
213 212
214 - assertEquals(4, events.size)  
215 - assertEquals(true, events[0] is RoomEvent.TrackUnsubscribed)  
216 - assertEquals(true, events[1] is RoomEvent.TrackUnpublished)  
217 - assertEquals(true, events[2] is RoomEvent.ParticipantDisconnected)  
218 - assertEquals(true, events[3] is RoomEvent.Disconnected)  
219 - Assert.assertTrue(room.remoteParticipants.isEmpty()) 213 + assertIsClassList(
  214 + listOf(
  215 + RoomEvent.TrackUnsubscribed::class.java,
  216 + RoomEvent.TrackUnpublished::class.java,
  217 + RoomEvent.TrackUnpublished::class.java,
  218 + RoomEvent.ParticipantDisconnected::class.java,
  219 + RoomEvent.Disconnected::class.java
  220 + ),
  221 + events
  222 + )
  223 + assertTrue(room.remoteParticipants.isEmpty())
220 } 224 }
221 } 225 }
  1 +package io.livekit.android.room.track
  2 +
  3 +import io.livekit.android.MockE2ETest
  4 +import io.livekit.android.mock.*
  5 +import io.livekit.android.room.SignalClientTest
  6 +import io.livekit.android.util.toOkioByteString
  7 +import io.livekit.android.util.toPBByteString
  8 +import kotlinx.coroutines.ExperimentalCoroutinesApi
  9 +import kotlinx.coroutines.test.advanceUntilIdle
  10 +import livekit.LivekitModels
  11 +import livekit.LivekitModels.VideoQuality
  12 +import livekit.LivekitRtc
  13 +import org.junit.Assert.assertEquals
  14 +import org.junit.Assert.assertTrue
  15 +import org.junit.Test
  16 +import org.junit.runner.RunWith
  17 +import org.robolectric.RobolectricTestRunner
  18 +
  19 +
  20 +@ExperimentalCoroutinesApi
  21 +@RunWith(RobolectricTestRunner::class)
  22 +class RemoteTrackPublicationTest : MockE2ETest() {
  23 +
  24 + @Test
  25 + fun trackSetting() = runTest {
  26 + room.adaptiveStream = false
  27 +
  28 + connect()
  29 +
  30 + wsFactory.listener.onMessage(
  31 + wsFactory.ws,
  32 + SignalClientTest.PARTICIPANT_JOIN.toOkioByteString()
  33 + )
  34 +
  35 + room.onAddTrack(
  36 + MockVideoStreamTrack(),
  37 + arrayOf(
  38 + MockMediaStream(
  39 + id = createMediaStreamId(
  40 + TestData.REMOTE_PARTICIPANT.sid,
  41 + TestData.REMOTE_VIDEO_TRACK.sid
  42 + )
  43 + )
  44 + )
  45 + )
  46 +
  47 + advanceUntilIdle()
  48 + wsFactory.ws.clearRequests()
  49 +
  50 + val remoteVideoPub = room.remoteParticipants.values.first()
  51 + .videoTracks.first()
  52 + .first as RemoteTrackPublication
  53 +
  54 + remoteVideoPub.setVideoQuality(VideoQuality.LOW)
  55 + remoteVideoPub.setVideoFps(100)
  56 +
  57 + advanceUntilIdle()
  58 +
  59 + val lastRequest = LivekitRtc.SignalRequest.newBuilder()
  60 + .mergeFrom(wsFactory.ws.sentRequests.last().toPBByteString())
  61 + .build()
  62 +
  63 + assertTrue(lastRequest.hasTrackSetting())
  64 + assertEquals(100, lastRequest.trackSetting.fps)
  65 + assertEquals(VideoQuality.LOW, lastRequest.trackSetting.quality)
  66 + }
  67 +}