David Liu

FlowDelegateUsageDetector

... ... @@ -20,194 +20,332 @@ package io.livekit.lint
import com.android.tools.lint.client.api.UElementHandler
import com.android.tools.lint.detector.api.*
import com.intellij.psi.CommonClassNames.JAVA_LANG_OBJECT
import com.intellij.psi.PsiClassType
import com.intellij.psi.JavaElementVisitor
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiJavaCodeReferenceElement
import com.intellij.psi.PsiMethod
import org.jetbrains.uast.*
import org.jetbrains.uast.util.isMethodCall
import org.jetbrains.uast.kotlin.KotlinUQualifiedReferenceExpression
/** Checks related to DiffUtil computation. */
class FlowDelegateUsageDetector : Detector(), SourceCodeScanner {
override fun visitReference(
context: JavaContext,
visitor: JavaElementVisitor?,
reference: PsiJavaCodeReferenceElement,
referenced: PsiElement
) {
super.visitReference(context, visitor, reference, referenced)
}
override fun visitReference(context: JavaContext, reference: UReferenceExpression, referenced: PsiElement) {
// Check if we're actually trying to access the flow delegate
val referencedMethod = referenced as? PsiMethod ?: return
if (referenced.name != "getFlow" || referencedMethod.containingClass?.qualifiedName != "io.livekit.android.util.FlowObservableKt") {
return
}
// This should get the property we're trying to receive the flow from.
val receiver = ((reference.uastParent as? KotlinUQualifiedReferenceExpression)
?.receiver as? UCallableReferenceExpression)
?: return
// This should get the original class associated with the property.
val className = receiver.qualifierType?.canonicalText
val psiClass = if (className != null) context.evaluator.findClass(className) else null
val psiField = psiClass?.findFieldByName("${receiver.callableName}\$delegate", true)
val isAnnotated = psiField?.hasAnnotation("io.livekit.android.util.FlowObservable") ?: false
if (!isAnnotated) {
val message = DEFAULT_MSG
val location = context.getLocation(reference)
context.report(ISSUE, reference, location, message)
}
}
override fun getApplicableReferenceNames(): List<String>? =
listOf("flow")
override fun visitMethodCall(context: JavaContext, node: UCallExpression, method: PsiMethod) {
super.visitMethodCall(context, node, method)
}
override fun getApplicableUastTypes() =
listOf(UBinaryExpression::class.java, UCallExpression::class.java)
listOf(
UBinaryExpression::class.java,
UCallExpression::class.java,
UAnnotation::class.java,
UArrayAccessExpression::class.java,
UBinaryExpressionWithType::class.java,
UBlockExpression::class.java,
UBreakExpression::class.java,
UCallableReferenceExpression::class.java,
UCatchClause::class.java,
UClass::class.java,
UClassLiteralExpression::class.java,
UContinueExpression::class.java,
UDeclaration::class.java,
UDoWhileExpression::class.java,
UElement::class.java,
UEnumConstant::class.java,
UExpression::class.java,
UExpressionList::class.java,
UField::class.java,
UFile::class.java,
UForEachExpression::class.java,
UForExpression::class.java,
UIfExpression::class.java,
UImportStatement::class.java,
UClassInitializer::class.java,
ULabeledExpression::class.java,
ULambdaExpression::class.java,
ULiteralExpression::class.java,
ULocalVariable::class.java,
UMethod::class.java,
UObjectLiteralExpression::class.java,
UParameter::class.java,
UParenthesizedExpression::class.java,
UPolyadicExpression::class.java,
UPostfixExpression::class.java,
UPrefixExpression::class.java,
UQualifiedReferenceExpression::class.java,
UReturnExpression::class.java,
USimpleNameReferenceExpression::class.java,
USuperExpression::class.java,
USwitchClauseExpression::class.java,
USwitchExpression::class.java,
UThisExpression::class.java,
UThrowExpression::class.java,
UTryExpression::class.java,
UTypeReferenceExpression::class.java,
UUnaryExpression::class.java,
UVariable::class.java,
UWhileExpression::class.java,
UYieldExpression::class.java,
)
override fun createUastHandler(context: JavaContext): UElementHandler? {
return object : UElementHandler() {
val context = context
override fun visitBinaryExpression(node: UBinaryExpression) {
checkExpression(context, node)
println(0)
}
override fun visitCallExpression(node: UCallExpression) {
checkCall(context, node)
node.classReference
node.methodName
node.methodIdentifier
node.receiverType
node.receiver
node.kind
node.valueArguments
node.valueArgumentCount
node.typeArguments
node.typeArgumentCount
node.returnType
println(1)
}
}
}
private fun defaultEquals(context: JavaContext, node: UElement): Boolean {
val resolved: PsiMethod?
override fun visitAnnotation(node: UAnnotation) {
println(0)
}
when (node) {
is UBinaryExpression -> {
resolved = node.resolveOperator()
if (resolved == null) {
val left = node.leftOperand.getExpressionType() as? PsiClassType
return defaultEquals(context, left)
}
override fun visitArrayAccessExpression(node: UArrayAccessExpression) {
println(0)
}
is UCallExpression -> {
resolved = node.takeIf { it.isMethodCall() }
?.resolve()
override fun visitBinaryExpressionWithType(node: UBinaryExpressionWithType) {
println(0)
}
is UParenthesizedExpression -> {
return defaultEquals(context, node.expression)
override fun visitBlockExpression(node: UBlockExpression) {
println(0)
}
else -> {
// We don't know any better
return false
override fun visitBreakExpression(node: UBreakExpression) {
println(0)
}
}
return resolved?.containingClass?.qualifiedName == MEDIA_STREAM_TRACK
}
override fun visitCallableReferenceExpression(node: UCallableReferenceExpression) {
println(0)
}
private fun defaultEquals(
context: JavaContext,
type: PsiClassType?
): Boolean {
val cls = type?.resolve() ?: return false
if (isKotlin(cls) && (context.evaluator.isSealed(cls) || context.evaluator.isData(cls))) {
// Sealed class doesn't guarantee that it defines equals/hashCode
// but it's likely (we'd need to go look at each inner class)
return false
}
override fun visitCatchClause(node: UCatchClause) {
println(0)
}
for (m in cls.findMethodsByName("equals", true)) {
if (m is PsiMethod) {
val parameters = m.parameterList.parameters
if (parameters.size == 1 &&
parameters[0].type.canonicalText == MEDIA_STREAM_TRACK
) {
return m.containingClass?.qualifiedName == MEDIA_STREAM_TRACK
}
override fun visitClass(node: UClass) {
println(0)
}
}
return false
}
override fun visitClassLiteralExpression(node: UClassLiteralExpression) {
println(0)
}
private fun checkCall(context: JavaContext, node: UCallExpression) {
if (defaultEquals(context, node)) {
// Within cast or instanceof check which implies a more specific type
// which provides an equals implementation?
if (withinCastWithEquals(context, node)) {
return
override fun visitContinueExpression(node: UContinueExpression) {
println(0)
}
val message = DEFAULT_MSG
val location = context.getCallLocation(
node,
includeReceiver = false,
includeArguments = true
)
context.report(ISSUE, node, location, message)
}
}
override fun visitDeclaration(node: UDeclaration) {
println(0)
}
/**
* Is this .equals() call within another if check which checks
* instanceof on a more specific type than we're calling equals on?
* If so, does that more specific type define its own equals?
*
* Also handle an implicit check via short circuit evaluation; e.g.
* something like "return a is A && b is B && a.equals(b)".
*/
private fun withinCastWithEquals(context: JavaContext, node: UExpression): Boolean {
var parent = skipParenthesizedExprUp(node.uastParent)
if (parent is UQualifiedReferenceExpression) {
parent = skipParenthesizedExprUp(parent.uastParent)
}
val target: PsiElement? = when (node) {
is UCallExpression -> node.receiver?.tryResolve()
is UBinaryExpression -> node.leftOperand.tryResolve()
else -> null
}
override fun visitDeclarationsExpression(node: UDeclarationsExpression) {
println(0)
}
if (parent is UPolyadicExpression && parent.operator == UastBinaryOperator.LOGICAL_AND) {
val operands = parent.operands
for (operand in operands) {
if (operand === node) {
break
}
if (isCastWithEquals(context, operand, target)) {
return true
}
override fun visitDoWhileExpression(node: UDoWhileExpression) {
println(0)
}
}
val ifStatement = node.getParentOfType<UElement>(UIfExpression::class.java, false, UMethod::class.java)
as? UIfExpression ?: return false
val condition = ifStatement.condition
return isCastWithEquals(context, condition, target)
}
private fun isCastWithEquals(context: JavaContext, node: UExpression, target: PsiElement?): Boolean {
when {
node is UBinaryExpressionWithType -> {
if (target != null) {
val resolved = node.operand.tryResolve()
// Unfortunately in some scenarios isEquivalentTo returns false for equal instances
//noinspection LintImplPsiEquals
if (resolved != null && !(target == resolved || target.isEquivalentTo(resolved))) {
return false
}
}
return !defaultEquals(context, node.type as? PsiClassType)
}
node is UPolyadicExpression && node.operator == UastBinaryOperator.LOGICAL_AND -> {
for (operand in node.operands) {
if (isCastWithEquals(context, operand, target)) {
return true
}
}
}
node is UParenthesizedExpression -> {
return isCastWithEquals(context, node.expression, target)
override fun visitElement(node: UElement) {
println(0)
}
override fun visitEnumConstant(node: UEnumConstant) {
println(0)
}
override fun visitExpression(node: UExpression) {
println(0)
}
override fun visitExpressionList(node: UExpressionList) {
println(0)
}
override fun visitField(node: UField) {
println(0)
}
override fun visitFile(node: UFile) {
println(0)
}
override fun visitForEachExpression(node: UForEachExpression) {
println(0)
}
override fun visitForExpression(node: UForExpression) {
println(0)
}
override fun visitIfExpression(node: UIfExpression) {
println(0)
}
override fun visitImportStatement(node: UImportStatement) {
println(0)
}
override fun visitInitializer(node: UClassInitializer) {
println(0)
}
override fun visitLabeledExpression(node: ULabeledExpression) {
println(0)
}
override fun visitLambdaExpression(node: ULambdaExpression) {
println(0)
}
override fun visitLiteralExpression(node: ULiteralExpression) {
println(0)
}
override fun visitLocalVariable(node: ULocalVariable) {
println(0)
}
override fun visitMethod(node: UMethod) {
println(0)
}
override fun visitObjectLiteralExpression(node: UObjectLiteralExpression) {
println(0)
}
override fun visitParameter(node: UParameter) {
println(0)
}
override fun visitParenthesizedExpression(node: UParenthesizedExpression) {
println(0)
}
override fun visitPolyadicExpression(node: UPolyadicExpression) {
println(0)
}
override fun visitPostfixExpression(node: UPostfixExpression) {
println(0)
}
override fun visitPrefixExpression(node: UPrefixExpression) {
println(0)
}
override fun visitQualifiedReferenceExpression(node: UQualifiedReferenceExpression) {
println(0)
}
override fun visitReturnExpression(node: UReturnExpression) {
println(0)
}
override fun visitSimpleNameReferenceExpression(node: USimpleNameReferenceExpression) {
println(0)
//(((node as KotlinUSimpleReferenceExpression).uastParent as KotlinUQualifiedReferenceExpression).receiver as KotlinUCallableReferenceExpression).qualifierType?.canonicalText
}
override fun visitSuperExpression(node: USuperExpression) {
println(0)
}
override fun visitSwitchClauseExpression(node: USwitchClauseExpression) {
println(0)
}
override fun visitSwitchExpression(node: USwitchExpression) {
println(0)
}
override fun visitThisExpression(node: UThisExpression) {
println(0)
}
override fun visitThrowExpression(node: UThrowExpression) {
println(0)
}
override fun visitTryExpression(node: UTryExpression) {
println(0)
}
override fun visitTypeReferenceExpression(node: UTypeReferenceExpression) {
println(0)
}
override fun visitUnaryExpression(node: UUnaryExpression) {
println(0)
}
override fun visitVariable(node: UVariable) {
println(0)
}
override fun visitWhileExpression(node: UWhileExpression) {
println(0)
}
}
return false
}
private fun checkExpression(context: JavaContext, node: UBinaryExpression) {
if (node.operator == UastBinaryOperator.IDENTITY_EQUALS ||
node.operator == UastBinaryOperator.EQUALS
) {
val left = node.leftOperand.getExpressionType() ?: return
val right = node.rightOperand.getExpressionType() ?: return
if (left is PsiClassType && right is PsiClassType
&& (left.className == "MediaStreamTrack" || right.className == "MediaStreamTrack")
) {
if (node.operator == UastBinaryOperator.EQUALS) {
if (defaultEquals(context, node)) {
if (withinCastWithEquals(context, node)) {
return
}
val message = DEFAULT_MSG
val location = node.operatorIdentifier?.let {
context.getLocation(it)
} ?: context.getLocation(node)
context.report(ISSUE, node, location, message)
}
} else {
val message = DEFAULT_MSG
val location = node.operatorIdentifier?.let {
context.getLocation(it)
} ?: context.getLocation(node)
context.report(ISSUE, node, location, message)
}
override fun visitYieldExpression(node: UYieldExpression) {
println(0)
}
}
}
... ... @@ -216,20 +354,18 @@ class FlowDelegateUsageDetector : Detector(), SourceCodeScanner {
private const val MEDIA_STREAM_TRACK = "org.webrtc.MediaStreamTrack"
private const val DEFAULT_MSG =
"Suspicious equality check: MediaStreamTracks should not be checked for equality. Check id() instead."
"Incorrect flow property usage: Only properties marked with the @FlowObservable annotation can be observed using `io.livekit.android.util.flow`."
private val IMPLEMENTATION =
Implementation(FlowDelegateUsageDetector::class.java, Scope.JAVA_FILE_SCOPE)
@JvmField
val ISSUE = Issue.create(
id = "MediaTrackEqualsDetector",
briefDescription = "Suspicious DiffUtil Equality",
id = "FlowDelegateUsageDetector",
briefDescription = "flow on a non-@FlowObservable property",
explanation = """
`areContentsTheSame` is used by `DiffUtil` to produce diffs. If the \
method is implemented incorrectly, such as using identity equals \
instead of equals, or calling equals on a class that has not implemented \
it, weird visual artifacts can occur.
Only properties marked with the @FlowObservable annotation can be observed using
`io.livekit.android.util.flow`.
""",
category = Category.CORRECTNESS,
priority = 4,
... ...
... ... @@ -2,135 +2,114 @@
package io.livekit.lint
import com.android.tools.lint.checks.infrastructure.LintDetectorTest.bytes
import com.android.tools.lint.checks.infrastructure.TestFile
import com.android.tools.lint.checks.infrastructure.TestFiles.java
import com.android.tools.lint.checks.infrastructure.TestLintTask
import com.android.tools.lint.checks.infrastructure.TestFiles.kotlin
import com.android.tools.lint.checks.infrastructure.TestLintTask.lint
import org.junit.Test
class FlowDelegateUsageDetectorTest {
@Test
fun objectEquals() {
fun normalFlowAccess() {
lint()
.allowMissingSdk()
.files(
java(
flowAccess(),
kotlin(
"""
package foo;
class Example {
public boolean foo() {
Object a = new Object();
Object b = new Object();
return a.equals(b);
}
}"""
package foo
import io.livekit.android.util.FlowObservable
import io.livekit.android.util.flow
import io.livekit.android.util.flowDelegate
class Example {
@field:FlowObservable
val value: Int by flowDelegate(0)
fun foo() {
::value.flow
return
}
}"""
).indented()
)
.issues(MediaTrackEqualsDetector.ISSUE)
.issues(FlowDelegateUsageDetector.ISSUE)
.run()
.expectClean()
}
@Test
fun objectEqualityOperator() {
fun nonAnnotatedFlowAccess() {
lint()
.allowMissingSdk()
.files(
java(
flowAccess(),
kotlin(
"""
package foo;
class Example {
public boolean foo() {
Object a = new Object();
Object b = new Object();
return a == b;
}
}"""
package foo
import io.livekit.android.util.FlowObservable
import io.livekit.android.util.flow
import io.livekit.android.util.flowDelegate
class Example {
val value: Int by flowDelegate(0)
fun foo() {
this::value.flow
return
}
}"""
).indented()
)
.issues(MediaTrackEqualsDetector.ISSUE)
.run()
.expectClean()
}
@Test
fun mediaTrackEquals() {
lint()
.allowMissingSdk()
.files(
mediaStreamTrack(),
java(
"""
package foo;
import org.webrtc.MediaStreamTrack;
class Example {
public boolean foo() {
MediaStreamTrack a = new MediaStreamTrack();
MediaStreamTrack b = new MediaStreamTrack();
return a.equals(b);
}
}"""
).indented()
)
.issues(MediaTrackEqualsDetector.ISSUE)
.issues(FlowDelegateUsageDetector.ISSUE)
.run()
.expectErrorCount(1)
}
}
@Test
fun mediaTrackEqualityOperator() {
lint()
.allowMissingSdk()
.files(
mediaStreamTrack(),
java(
"""
package foo;
import org.webrtc.MediaStreamTrack;
class Example {
public boolean foo() {
ABC a = new ABC();
MediaStreamTrack b = new MediaStreamTrack();
a.equals(b);
return a == b;
fun flowAccess(): TestFile {
return kotlin(
"""
package io.livekit.android.util
import kotlin.reflect.KProperty
import kotlin.reflect.KProperty0
internal val <T> KProperty0<T>.delegate: Any?
get() { getDelegate() }
@Suppress("UNCHECKED_CAST")
val <T> KProperty0<T>.flow: StateFlow<T>
get() = delegate as StateFlow<T>
@Target(AnnotationTarget.PROPERTY)
@Retention(AnnotationRetention.SOURCE)
@MustBeDocumented
annotation class FlowObservable
class MutableStateFlowDelegate<T>
internal constructor(
private val flow: MutableStateFlow<T>,
private val onSetValue: ((newValue: T, oldValue: T) -> Unit)? = null
) : MutableStateFlow<T> by flow {
operator fun getValue(thisRef: Any?, property: KProperty<*>): T {
return flow.value
}
public boolean equals(Object o){
return false;
operator fun setValue(thisRef: Any?, property: KProperty<*>, value: T) {
val oldValue = flow.value
flow.value = value
onSetValue?.invoke(value, oldValue)
}
}"""
).indented()
)
.issues(MediaTrackEqualsDetector.ISSUE)
.run()
.expectErrorCount(1)
}
@Test
fun properMediaTrackEquality() {
lint()
.allowMissingSdk()
.files(
mediaStreamTrack(),
java(
"""
package foo;
}
public fun <T> flowDelegate(
initialValue: T,
onSetValue: ((newValue: T, oldValue: T) -> Unit)? = null
): MutableStateFlowDelegate<T> {
return MutableStateFlowDelegate(MutableStateFlow(initialValue), onSetValue)
}
class Example {
public boolean foo() {
MediaStreamTrack a = new MediaStreamTrack();
MediaStreamTrack b = new MediaStreamTrack();
return a.getId() == b.getId();
}
}"""
).indented()
)
.issues(MediaTrackEqualsDetector.ISSUE)
.run()
.expectClean()
}
interface StateFlow<out T> {
val value: T
}
class MutableStateFlow<T>(override var value: T) : StateFlow<T>
"""
).indented()
}
\ No newline at end of file
... ...