davidliu
Committed by GitHub

Add TranscriptionReceived event to track publication (#449)

* Add TranscriptionReceived event to track publication

* spotless
  1 +/*
  2 + * Copyright 2024 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.events
  18 +
  19 +import io.livekit.android.room.track.TrackPublication
  20 +import io.livekit.android.room.types.TranscriptionSegment
  21 +
  22 +sealed class TrackPublicationEvent(val publication: TrackPublication) : Event() {
  23 + class TranscriptionReceived(
  24 + /**
  25 + * The applicable track publication these transcriptions apply to.
  26 + */
  27 + publication: TrackPublication,
  28 + /**
  29 + * The transcription segments.
  30 + */
  31 + val transcriptions: List<TranscriptionSegment>,
  32 + ) : TrackPublicationEvent(publication)
  33 +}
@@ -1022,7 +1022,7 @@ constructor( @@ -1022,7 +1022,7 @@ constructor(
1022 ) 1022 )
1023 eventBus.tryPostEvent(event) 1023 eventBus.tryPostEvent(event)
1024 participant?.onTranscriptionReceived(event) 1024 participant?.onTranscriptionReceived(event)
1025 - // TODO: Emit for publication 1025 + publication?.onTranscriptionReceived(event)
1026 } 1026 }
1027 1027
1028 /** 1028 /**
1 /* 1 /*
2 - * Copyright 2023 LiveKit, Inc. 2 + * Copyright 2023-2024 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -16,6 +16,9 @@ @@ -16,6 +16,9 @@
16 16
17 package io.livekit.android.room.track 17 package io.livekit.android.room.track
18 18
  19 +import io.livekit.android.events.BroadcastEventBus
  20 +import io.livekit.android.events.RoomEvent
  21 +import io.livekit.android.events.TrackPublicationEvent
19 import io.livekit.android.room.participant.Participant 22 import io.livekit.android.room.participant.Participant
20 import io.livekit.android.util.FlowObservable 23 import io.livekit.android.util.FlowObservable
21 import io.livekit.android.util.flowDelegate 24 import io.livekit.android.util.flowDelegate
@@ -44,6 +47,9 @@ open class TrackPublication( @@ -44,6 +47,9 @@ open class TrackPublication(
44 return trackInfo?.encryption ?: LivekitModels.Encryption.Type.NONE 47 return trackInfo?.encryption ?: LivekitModels.Encryption.Type.NONE
45 } 48 }
46 49
  50 + protected val eventBus = BroadcastEventBus<TrackPublicationEvent>()
  51 + val events = eventBus.readOnly()
  52 +
47 @FlowObservable 53 @FlowObservable
48 @get:FlowObservable 54 @get:FlowObservable
49 open var muted: Boolean by flowDelegate(false) 55 open var muted: Boolean by flowDelegate(false)
@@ -87,4 +93,16 @@ open class TrackPublication( @@ -87,4 +93,16 @@ open class TrackPublication(
87 93
88 trackInfo = info 94 trackInfo = info
89 } 95 }
  96 +
  97 + internal fun onTranscriptionReceived(transcription: RoomEvent.TranscriptionReceived) {
  98 + if (transcription.publication != this) {
  99 + return
  100 + }
  101 + eventBus.tryPostEvent(
  102 + TrackPublicationEvent.TranscriptionReceived(
  103 + publication = this,
  104 + transcriptions = transcription.transcriptionSegments,
  105 + ),
  106 + )
  107 + }
90 } 108 }
@@ -316,6 +316,7 @@ object TestData { @@ -316,6 +316,7 @@ object TestData {
316 startTime = 1 316 startTime = 1
317 endTime = 10 317 endTime = 10
318 final = true 318 final = true
  319 + trackId = LOCAL_AUDIO_TRACK.sid
319 build() 320 build()
320 }, 321 },
321 ) 322 )
@@ -16,10 +16,16 @@ @@ -16,10 +16,16 @@
16 16
17 package io.livekit.android.room 17 package io.livekit.android.room
18 18
  19 +import io.livekit.android.events.ParticipantEvent
19 import io.livekit.android.events.RoomEvent 20 import io.livekit.android.events.RoomEvent
  21 +import io.livekit.android.events.TrackPublicationEvent
  22 +import io.livekit.android.room.participant.AudioTrackPublishOptions
  23 +import io.livekit.android.room.track.LocalAudioTrack
  24 +import io.livekit.android.room.track.Track
20 import io.livekit.android.test.MockE2ETest 25 import io.livekit.android.test.MockE2ETest
21 import io.livekit.android.test.assert.assertIsClass 26 import io.livekit.android.test.assert.assertIsClass
22 import io.livekit.android.test.events.EventCollector 27 import io.livekit.android.test.events.EventCollector
  28 +import io.livekit.android.test.mock.MockAudioStreamTrack
23 import io.livekit.android.test.mock.MockDataChannel 29 import io.livekit.android.test.mock.MockDataChannel
24 import io.livekit.android.test.mock.MockPeerConnection 30 import io.livekit.android.test.mock.MockPeerConnection
25 import io.livekit.android.test.mock.TestData 31 import io.livekit.android.test.mock.TestData
@@ -33,26 +39,57 @@ class RoomTranscriptionMockE2ETest : MockE2ETest() { @@ -33,26 +39,57 @@ class RoomTranscriptionMockE2ETest : MockE2ETest() {
33 @Test 39 @Test
34 fun transcriptionReceived() = runTest { 40 fun transcriptionReceived() = runTest {
35 connect() 41 connect()
  42 + room.localParticipant.publishAudioTrack(
  43 + LocalAudioTrack(
  44 + "",
  45 + MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
  46 + ),
  47 + options = AudioTrackPublishOptions(
  48 + source = Track.Source.MICROPHONE,
  49 + ),
  50 + )
36 val subPeerConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection 51 val subPeerConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection
37 val subDataChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL) 52 val subDataChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL)
38 subPeerConnection.observer?.onDataChannel(subDataChannel) 53 subPeerConnection.observer?.onDataChannel(subDataChannel)
39 54
40 - val collector = EventCollector(room.events, coroutineRule.scope) 55 + val roomCollector = EventCollector(room.events, coroutineRule.scope)
  56 + val participantCollector = EventCollector(room.localParticipant.events, coroutineRule.scope)
  57 + val publicationCollector = EventCollector(room.localParticipant.getTrackPublication(Track.Source.MICROPHONE)!!.events, coroutineRule.scope)
  58 +
41 val dataBuffer = TestData.DATA_PACKET_TRANSCRIPTION.toDataChannelBuffer() 59 val dataBuffer = TestData.DATA_PACKET_TRANSCRIPTION.toDataChannelBuffer()
42 60
43 subDataChannel.observer?.onMessage(dataBuffer) 61 subDataChannel.observer?.onMessage(dataBuffer)
44 - val events = collector.stopCollecting()  
45 62
46 - assertEquals(1, events.size)  
47 - assertIsClass(RoomEvent.TranscriptionReceived::class.java, events[0]) 63 + val roomEvents = roomCollector.stopCollecting()
  64 + val participantEvents = participantCollector.stopCollecting()
  65 + val publicationEvents = publicationCollector.stopCollecting()
  66 +
  67 + // Verify room events
  68 + run {
  69 + assertEquals(1, roomEvents.size)
  70 + assertIsClass(RoomEvent.TranscriptionReceived::class.java, roomEvents[0])
48 71
49 - val event = events.first() as RoomEvent.TranscriptionReceived 72 + val event = roomEvents.first() as RoomEvent.TranscriptionReceived
50 assertEquals(room, event.room) 73 assertEquals(room, event.room)
51 assertEquals(room.localParticipant, event.participant) 74 assertEquals(room.localParticipant, event.participant)
  75 + assertEquals(room.localParticipant.getTrackPublication(Track.Source.MICROPHONE)!!, event.publication)
52 76
53 val expectedSegment = TestData.DATA_PACKET_TRANSCRIPTION.transcription.getSegments(0) 77 val expectedSegment = TestData.DATA_PACKET_TRANSCRIPTION.transcription.getSegments(0)
54 val receivedSegment = event.transcriptionSegments.first() 78 val receivedSegment = event.transcriptionSegments.first()
55 assertEquals(expectedSegment.id, receivedSegment.id) 79 assertEquals(expectedSegment.id, receivedSegment.id)
56 assertEquals(expectedSegment.text, receivedSegment.text) 80 assertEquals(expectedSegment.text, receivedSegment.text)
57 } 81 }
  82 +
  83 + // Verify participant events
  84 + run {
  85 + assertEquals(1, participantEvents.size)
  86 + assertIsClass(ParticipantEvent.TranscriptionReceived::class.java, participantEvents[0])
  87 + }
  88 +
  89 + // Verify publication events
  90 + run {
  91 + assertEquals(1, publicationEvents.size)
  92 + assertIsClass(TrackPublicationEvent.TranscriptionReceived::class.java, publicationEvents[0])
  93 + }
  94 + }
58 } 95 }