David Liu

Fix up media track equals detector

... ... @@ -4,6 +4,8 @@
<bytecodeTargetLevel target="1.8">
<module name="livekit-android.livekit-android-sdk" target="11" />
<module name="livekit-android.sample-app" target="11" />
<module name="livekit-android.sample-app-common" target="11" />
<module name="livekit-android.sample-app-compose" target="11" />
</bytecodeTargetLevel>
</component>
</project>
\ No newline at end of file
... ...
... ... @@ -53,6 +53,7 @@ ext {
versions = [
androidx_core : "1.6.0",
androidx_lifecycle: "2.4.0",
autoService : '1.0.1',
dagger : "2.27",
groupie : "2.9.0",
junit : "4.13.2",
... ... @@ -64,9 +65,12 @@ ext {
protoSrc: "$projectDir/protocol",
]
deps = [
auto : [
'service' : "com.google.auto.service:auto-service:${versions.autoService}",
'serviceAnnotations': "com.google.auto.service:auto-service-annotations:${versions.autoService}",
],
kotlinx_coroutines: "org.jetbrains.kotlinx:kotlinx-coroutines-android:1.5.2",
timber : "com.github.ajalt:timberkt:1.5.1",
// lint
lint : "com.android.tools.lint:lint:${versions.lint}",
lintApi : "com.android.tools.lint:lint-api:${versions.lint}",
... ...
... ... @@ -118,6 +118,9 @@ dependencies {
implementation deps.timber
implementation 'com.vdurmont:semver4j:3.1.0'
lintChecks project(':livekit-lint')
lintPublish project(':livekit-lint')
testImplementation 'junit:junit:4.13.2'
testImplementation 'org.robolectric:robolectric:4.6'
testImplementation 'org.mockito:mockito-core:4.0.0'
... ...
plugins {
id 'java-library'
id 'kotlin'
id 'kotlin-kapt'
}
java {
... ... @@ -9,15 +10,22 @@ java {
}
dependencies {
compileOnly "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version"
// used for lint rules
compileOnly deps.lintApi
compileOnly deps.lintChecks
compileOnly deps.lintTests
// Handle creating manifests for lint checker
compileOnly deps.auto.serviceAnnotations
kapt deps.auto.service
// test lint
testImplementation deps.lint
testImplementation deps.lintTests
compileOnly "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version"
// test runners
testImplementation deps.junit
testImplementation deps.junitJupiterApi
... ...
... ... @@ -3,7 +3,10 @@ package io.livekit.lint
import com.android.tools.lint.client.api.IssueRegistry
import com.android.tools.lint.detector.api.CURRENT_API
import com.android.tools.lint.detector.api.Issue
import com.google.auto.service.AutoService
@Suppress("UnstableApiUsage", "unused")
@AutoService(value = [IssueRegistry::class])
class IssueRegistry : IssueRegistry() {
override val api: Int = CURRENT_API
... ...
... ... @@ -20,14 +20,14 @@ package io.livekit.lint
import com.android.tools.lint.client.api.UElementHandler
import com.android.tools.lint.detector.api.*
import com.intellij.psi.CommonClassNames.JAVA_LANG_OBJECT
import com.intellij.psi.PsiClassType
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiMethod
import org.jetbrains.uast.*
import org.jetbrains.uast.util.isMethodCall
import org.jetbrains.uast.UBinaryExpression
import org.jetbrains.uast.UCallExpression
import org.jetbrains.uast.UastBinaryOperator
/** Checks related to DiffUtil computation. */
/**
* Detects MediaStreamTrack.equals() usage. This is generally a mistake and should not be used.
*/
class MediaTrackEqualsDetector : Detector(), SourceCodeScanner {
override fun getApplicableUastTypes() =
... ... @@ -46,138 +46,21 @@ class MediaTrackEqualsDetector : Detector(), SourceCodeScanner {
}
}
private fun defaultEquals(context: JavaContext, node: UElement): Boolean {
val resolved: PsiMethod?
when (node) {
is UBinaryExpression -> {
resolved = node.resolveOperator()
if (resolved == null) {
val left = node.leftOperand.getExpressionType() as? PsiClassType
return defaultEquals(context, left)
}
}
is UCallExpression -> {
resolved = node.takeIf { it.isMethodCall() }
?.resolve()
}
is UParenthesizedExpression -> {
return defaultEquals(context, node.expression)
}
else -> {
// We don't know any better
return false
}
}
return resolved?.containingClass?.qualifiedName == MEDIA_STREAM_TRACK
}
private fun defaultEquals(
context: JavaContext,
type: PsiClassType?
): Boolean {
val cls = type?.resolve() ?: return false
if (isKotlin(cls) && (context.evaluator.isSealed(cls) || context.evaluator.isData(cls))) {
// Sealed class doesn't guarantee that it defines equals/hashCode
// but it's likely (we'd need to go look at each inner class)
return false
}
for (m in cls.findMethodsByName("equals", true)) {
if (m is PsiMethod) {
val parameters = m.parameterList.parameters
if (parameters.size == 1 &&
parameters[0].type.canonicalText == MEDIA_STREAM_TRACK
) {
return m.containingClass?.qualifiedName == MEDIA_STREAM_TRACK
}
}
}
return false
}
private fun checkCall(context: JavaContext, node: UCallExpression) {
if (defaultEquals(context, node)) {
// Within cast or instanceof check which implies a more specific type
// which provides an equals implementation?
if (withinCastWithEquals(context, node)) {
return
}
val message = DEFAULT_MSG
val location = context.getCallLocation(
node,
includeReceiver = false,
includeArguments = true
)
context.report(ISSUE, node, location, message)
}
}
/**
* Is this .equals() call within another if check which checks
* instanceof on a more specific type than we're calling equals on?
* If so, does that more specific type define its own equals?
*
* Also handle an implicit check via short circuit evaluation; e.g.
* something like "return a is A && b is B && a.equals(b)".
*/
private fun withinCastWithEquals(context: JavaContext, node: UExpression): Boolean {
var parent = skipParenthesizedExprUp(node.uastParent)
if (parent is UQualifiedReferenceExpression) {
parent = skipParenthesizedExprUp(parent.uastParent)
}
val target: PsiElement? = when (node) {
is UCallExpression -> node.receiver?.tryResolve()
is UBinaryExpression -> node.leftOperand.tryResolve()
else -> null
}
if (parent is UPolyadicExpression && parent.operator == UastBinaryOperator.LOGICAL_AND) {
val operands = parent.operands
for (operand in operands) {
if (operand === node) {
break
}
if (isCastWithEquals(context, operand, target)) {
return true
}
}
}
val ifStatement = node.getParentOfType<UElement>(UIfExpression::class.java, false, UMethod::class.java)
as? UIfExpression ?: return false
val condition = ifStatement.condition
return isCastWithEquals(context, condition, target)
}
private fun isCastWithEquals(context: JavaContext, node: UExpression, target: PsiElement?): Boolean {
when {
node is UBinaryExpressionWithType -> {
if (target != null) {
val resolved = node.operand.tryResolve()
// Unfortunately in some scenarios isEquivalentTo returns false for equal instances
//noinspection LintImplPsiEquals
if (resolved != null && !(target == resolved || target.isEquivalentTo(resolved))) {
return false
}
}
return !defaultEquals(context, node.type as? PsiClassType)
}
node is UPolyadicExpression && node.operator == UastBinaryOperator.LOGICAL_AND -> {
for (operand in node.operands) {
if (isCastWithEquals(context, operand, target)) {
return true
}
}
}
node is UParenthesizedExpression -> {
return isCastWithEquals(context, node.expression, target)
if (node.methodName == "equals") {
val left = node.receiverType ?: return
val right = node.valueArguments.takeIf { it.isNotEmpty() }
?.get(0)
?.getExpressionType()
?: return
if (left is PsiClassType && right is PsiClassType
&& (left.canonicalText == MEDIA_STREAM_TRACK || right.canonicalText == MEDIA_STREAM_TRACK)
) {
val message = DEFAULT_MSG
val location = context.getLocation(node)
context.report(ISSUE, node, location, message)
}
}
return false
}
private fun checkExpression(context: JavaContext, node: UBinaryExpression) {
... ... @@ -187,27 +70,13 @@ class MediaTrackEqualsDetector : Detector(), SourceCodeScanner {
val left = node.leftOperand.getExpressionType() ?: return
val right = node.rightOperand.getExpressionType() ?: return
if (left is PsiClassType && right is PsiClassType
&& (left.className == "MediaStreamTrack" || right.className == "MediaStreamTrack")
&& (left.canonicalText == MEDIA_STREAM_TRACK || right.canonicalText == MEDIA_STREAM_TRACK)
) {
if (node.operator == UastBinaryOperator.EQUALS) {
if (defaultEquals(context, node)) {
if (withinCastWithEquals(context, node)) {
return
}
val message = DEFAULT_MSG
val location = node.operatorIdentifier?.let {
context.getLocation(it)
} ?: context.getLocation(node)
context.report(ISSUE, node, location, message)
}
} else {
val message = DEFAULT_MSG
val location = node.operatorIdentifier?.let {
context.getLocation(it)
} ?: context.getLocation(node)
context.report(ISSUE, node, location, message)
}
val message = DEFAULT_MSG
val location = node.operatorIdentifier?.let {
context.getLocation(it)
} ?: context.getLocation(node)
context.report(ISSUE, node, location, message)
}
}
}
... ... @@ -224,17 +93,16 @@ class MediaTrackEqualsDetector : Detector(), SourceCodeScanner {
@JvmField
val ISSUE = Issue.create(
id = "MediaTrackEqualsDetector",
briefDescription = "Suspicious DiffUtil Equality",
briefDescription = "Suspicious MediaStreamTrack Equality",
explanation = """
`areContentsTheSame` is used by `DiffUtil` to produce diffs. If the \
method is implemented incorrectly, such as using identity equals \
instead of equals, or calling equals on a class that has not implemented \
it, weird visual artifacts can occur.
MediaStreamTrack does not implement `equals`, and therefore cannot be relied upon.
Additionally, many MediaStreamTrack objects may exist for the same underlying stream,
and therefore the identity operator `===` is unreliable.
""",
category = Category.CORRECTNESS,
priority = 4,
androidSpecific = true,
moreInfo = "https://issuetracker.google.com/116789824",
moreInfo = "https://github.com/livekit/client-sdk-android/commit/01152f2ac01dae59759383d587cdc21035718b8e",
severity = Severity.ERROR,
implementation = IMPLEMENTATION
)
... ...
... ... @@ -4,7 +4,7 @@ package io.livekit.lint
import com.android.tools.lint.checks.infrastructure.TestFile
import com.android.tools.lint.checks.infrastructure.TestFiles.java
import com.android.tools.lint.checks.infrastructure.TestLintTask
import com.android.tools.lint.checks.infrastructure.TestFiles.kotlin
import com.android.tools.lint.checks.infrastructure.TestLintTask.lint
import org.junit.Test
... ... @@ -131,25 +131,85 @@ class MediaTrackEqualsDetectorTest {
.run()
.expectClean()
}
}
fun TestLintTask.mediaStreamTrack(): TestLintTask {
return this.files(
java(
"""
package org.webrtc;
class MediaStreamTrack {
int getId(){
return 0;
}
}
"""
).indented()
)
@Test
fun kotlinMediaTrackEqualityOperator() {
lint()
.allowMissingSdk()
.files(
mediaStreamTrack(),
kotlin(
"""
package foo
import org.webrtc.MediaStreamTrack
class Example {
fun foo() : Boolean {
val a = MediaStreamTrack()
val b = MediaStreamTrack()
return a === b;
}
}"""
).indented()
)
.issues(MediaTrackEqualsDetector.ISSUE)
.run()
.expectErrorCount(1)
}
@Test
fun kotlinMediaTrackIdentityEqualityOperator() {
lint()
.allowMissingSdk()
.files(
mediaStreamTrack(),
kotlin(
"""
package foo
import org.webrtc.MediaStreamTrack
class Example {
fun foo() : Boolean {
val a = MediaStreamTrack()
val b = MediaStreamTrack()
return a == b
}
}"""
).indented()
)
.issues(MediaTrackEqualsDetector.ISSUE)
.run()
.expectErrorCount(1)
}
@Test
fun kotlinProperMediaTrackEquality() {
lint()
.allowMissingSdk()
.files(
mediaStreamTrack(),
kotlin(
"""
package foo
import org.webrtc.MediaStreamTrack
class Example {
fun foo() : Boolean {
val a = MediaStreamTrack()
val b = MediaStreamTrack()
return a.id() == b.id()
}
}"""
).indented()
)
.issues(MediaTrackEqualsDetector.ISSUE)
.run()
.expectClean()
}
}
fun Any.mediaStreamTrack(): TestFile {
fun mediaStreamTrack(): TestFile {
return java(
"""
package org.webrtc;
... ...
package org.webrtc;
class MediaStreamTrack {
int getId() {
return 0;
}
}
\ No newline at end of file