davidliu
Committed by GitHub

Add TranscriptionReceived event to track publication (#449)

* Add TranscriptionReceived event to track publication

* spotless
/*
* Copyright 2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.events
import io.livekit.android.room.track.TrackPublication
import io.livekit.android.room.types.TranscriptionSegment
sealed class TrackPublicationEvent(val publication: TrackPublication) : Event() {
class TranscriptionReceived(
/**
* The applicable track publication these transcriptions apply to.
*/
publication: TrackPublication,
/**
* The transcription segments.
*/
val transcriptions: List<TranscriptionSegment>,
) : TrackPublicationEvent(publication)
}
... ...
... ... @@ -1022,7 +1022,7 @@ constructor(
)
eventBus.tryPostEvent(event)
participant?.onTranscriptionReceived(event)
// TODO: Emit for publication
publication?.onTranscriptionReceived(event)
}
/**
... ...
/*
* Copyright 2023 LiveKit, Inc.
* Copyright 2023-2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -16,6 +16,9 @@
package io.livekit.android.room.track
import io.livekit.android.events.BroadcastEventBus
import io.livekit.android.events.RoomEvent
import io.livekit.android.events.TrackPublicationEvent
import io.livekit.android.room.participant.Participant
import io.livekit.android.util.FlowObservable
import io.livekit.android.util.flowDelegate
... ... @@ -44,6 +47,9 @@ open class TrackPublication(
return trackInfo?.encryption ?: LivekitModels.Encryption.Type.NONE
}
protected val eventBus = BroadcastEventBus<TrackPublicationEvent>()
val events = eventBus.readOnly()
@FlowObservable
@get:FlowObservable
open var muted: Boolean by flowDelegate(false)
... ... @@ -87,4 +93,16 @@ open class TrackPublication(
trackInfo = info
}
internal fun onTranscriptionReceived(transcription: RoomEvent.TranscriptionReceived) {
if (transcription.publication != this) {
return
}
eventBus.tryPostEvent(
TrackPublicationEvent.TranscriptionReceived(
publication = this,
transcriptions = transcription.transcriptionSegments,
),
)
}
}
... ...
... ... @@ -316,6 +316,7 @@ object TestData {
startTime = 1
endTime = 10
final = true
trackId = LOCAL_AUDIO_TRACK.sid
build()
},
)
... ...
... ... @@ -16,10 +16,16 @@
package io.livekit.android.room
import io.livekit.android.events.ParticipantEvent
import io.livekit.android.events.RoomEvent
import io.livekit.android.events.TrackPublicationEvent
import io.livekit.android.room.participant.AudioTrackPublishOptions
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.room.track.Track
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.assert.assertIsClass
import io.livekit.android.test.events.EventCollector
import io.livekit.android.test.mock.MockAudioStreamTrack
import io.livekit.android.test.mock.MockDataChannel
import io.livekit.android.test.mock.MockPeerConnection
import io.livekit.android.test.mock.TestData
... ... @@ -33,26 +39,57 @@ class RoomTranscriptionMockE2ETest : MockE2ETest() {
@Test
fun transcriptionReceived() = runTest {
connect()
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
),
options = AudioTrackPublishOptions(
source = Track.Source.MICROPHONE,
),
)
val subPeerConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection
val subDataChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL)
subPeerConnection.observer?.onDataChannel(subDataChannel)
val collector = EventCollector(room.events, coroutineRule.scope)
val roomCollector = EventCollector(room.events, coroutineRule.scope)
val participantCollector = EventCollector(room.localParticipant.events, coroutineRule.scope)
val publicationCollector = EventCollector(room.localParticipant.getTrackPublication(Track.Source.MICROPHONE)!!.events, coroutineRule.scope)
val dataBuffer = TestData.DATA_PACKET_TRANSCRIPTION.toDataChannelBuffer()
subDataChannel.observer?.onMessage(dataBuffer)
val events = collector.stopCollecting()
assertEquals(1, events.size)
assertIsClass(RoomEvent.TranscriptionReceived::class.java, events[0])
val roomEvents = roomCollector.stopCollecting()
val participantEvents = participantCollector.stopCollecting()
val publicationEvents = publicationCollector.stopCollecting()
// Verify room events
run {
assertEquals(1, roomEvents.size)
assertIsClass(RoomEvent.TranscriptionReceived::class.java, roomEvents[0])
val event = roomEvents.first() as RoomEvent.TranscriptionReceived
assertEquals(room, event.room)
assertEquals(room.localParticipant, event.participant)
assertEquals(room.localParticipant.getTrackPublication(Track.Source.MICROPHONE)!!, event.publication)
val expectedSegment = TestData.DATA_PACKET_TRANSCRIPTION.transcription.getSegments(0)
val receivedSegment = event.transcriptionSegments.first()
assertEquals(expectedSegment.id, receivedSegment.id)
assertEquals(expectedSegment.text, receivedSegment.text)
}
val event = events.first() as RoomEvent.TranscriptionReceived
assertEquals(room, event.room)
assertEquals(room.localParticipant, event.participant)
// Verify participant events
run {
assertEquals(1, participantEvents.size)
assertIsClass(ParticipantEvent.TranscriptionReceived::class.java, participantEvents[0])
}
val expectedSegment = TestData.DATA_PACKET_TRANSCRIPTION.transcription.getSegments(0)
val receivedSegment = event.transcriptionSegments.first()
assertEquals(expectedSegment.id, receivedSegment.id)
assertEquals(expectedSegment.text, receivedSegment.text)
// Verify publication events
run {
assertEquals(1, publicationEvents.size)
assertIsClass(TrackPublicationEvent.TranscriptionReceived::class.java, publicationEvents[0])
}
}
}
... ...