Cleanup: NotNull/Nullable
[idea/community.git] / java / java-analysis-impl / src / com / intellij / codeInspection / dataFlow / inference / ContractInferenceInterpreter.java
1 // Copyright 2000-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
2 package com.intellij.codeInspection.dataFlow.inference;
3
4 import com.intellij.codeInspection.dataFlow.ContractReturnValue;
5 import com.intellij.codeInspection.dataFlow.StandardMethodContract;
6 import com.intellij.codeInspection.dataFlow.StandardMethodContract.ValueConstraint;
7 import com.intellij.lang.LighterAST;
8 import com.intellij.lang.LighterASTNode;
9 import com.intellij.psi.JavaTokenType;
10 import com.intellij.psi.impl.source.tree.ElementType;
11 import com.intellij.psi.impl.source.tree.RecursiveLighterASTNodeWalkingVisitor;
12 import com.intellij.psi.tree.IElementType;
13 import com.intellij.psi.tree.TokenSet;
14 import com.intellij.util.containers.ContainerUtil;
15 import org.jetbrains.annotations.NotNull;
16 import org.jetbrains.annotations.Nullable;
17
18 import java.util.ArrayList;
19 import java.util.Arrays;
20 import java.util.BitSet;
21 import java.util.List;
22
23 import static com.intellij.codeInspection.dataFlow.ContractReturnValue.*;
24 import static com.intellij.codeInspection.dataFlow.StandardMethodContract.ValueConstraint.*;
25 import static com.intellij.psi.impl.source.JavaLightTreeUtil.*;
26 import static com.intellij.psi.impl.source.tree.JavaElementType.*;
27 import static com.intellij.psi.impl.source.tree.LightTreeUtil.firstChildOfType;
28 import static com.intellij.psi.impl.source.tree.LightTreeUtil.getChildrenOfType;
29 import static java.util.Collections.emptyList;
30 import static java.util.Collections.singletonList;
31
32 class ContractInferenceInterpreter {
33   private static final TokenSet UNARY_INCREMENT_DECREMENT = TokenSet.create(JavaTokenType.PLUSPLUS, JavaTokenType.MINUSMINUS);
34   private final LighterAST myTree;
35   private final LighterASTNode myMethod;
36   private final LighterASTNode myBody;
37
38   ContractInferenceInterpreter(LighterAST tree, LighterASTNode method, LighterASTNode body) {
39     myTree = tree;
40     myMethod = method;
41     myBody = body;
42   }
43
44   @NotNull
45   List<LighterASTNode> getParameters() {
46     LighterASTNode paramList = firstChildOfType(myTree, myMethod, PARAMETER_LIST);
47     return paramList != null ? getChildrenOfType(myTree, paramList, PARAMETER) : emptyList();
48   }
49
50   @NotNull
51   List<PreContract> inferContracts(List<LighterASTNode> statements) {
52     if (statements.isEmpty()) return emptyList();
53
54     if (statements.size() == 1) {
55       List<PreContract> result = handleSingleStatement(statements.get(0));
56       if (result != null) return result;
57     }
58
59     List<PreContract> contracts =
60       visitStatements(singletonList(StandardMethodContract.createConstraintArray(getParameters().size())), statements);
61     if (contracts.isEmpty()) {
62       ContractReturnValue value = getDefaultReturnValue(statements);
63       if (!value.isFail() && !value.equals(returnAny())) {
64         contracts = singletonList(
65           new KnownContract(StandardMethodContract.trivialContract(getParameters().size(), value)));
66       }
67     }
68     return contracts;
69   }
70
71   @Nullable
72   private List<PreContract> handleSingleStatement(LighterASTNode statement) {
73     if (statement.getTokenType() == RETURN_STATEMENT) {
74       LighterASTNode returned = findExpressionChild(myTree, statement);
75       return getLiteralConstraint(returned) != null ? emptyList() : handleDelegation(returned, false);
76     }
77     if (statement.getTokenType() == EXPRESSION_STATEMENT) {
78       LighterASTNode expr = findExpressionChild(myTree, statement);
79       return expr != null && expr.getTokenType() == METHOD_CALL_EXPRESSION ? handleDelegation(expr, false) : null;
80     }
81     return null;
82   }
83
84   @Nullable
85   private LighterASTNode getCodeBlock(@Nullable LighterASTNode parent) {
86     return firstChildOfType(myTree, parent, CODE_BLOCK);
87   }
88
89   @NotNull
90   static List<LighterASTNode> getStatements(@Nullable LighterASTNode codeBlock, LighterAST tree) {
91     return codeBlock == null ? emptyList() : getChildrenOfType(tree, codeBlock, ElementType.JAVA_STATEMENT_BIT_SET);
92   }
93
94   @Nullable
95   private List<PreContract> handleDelegation(@Nullable LighterASTNode expression, boolean negated) {
96     if (expression == null) return null;
97     if (expression.getTokenType() == PARENTH_EXPRESSION) {
98       return handleDelegation(findExpressionChild(myTree, expression), negated);
99     }
100
101     if (isNegationExpression(expression)) {
102       return handleDelegation(findExpressionChild(myTree, expression), !negated);
103     }
104
105     if (expression.getTokenType() == METHOD_CALL_EXPRESSION) {
106       return singletonList(new DelegationContract(ExpressionRange.create(expression, myBody.getStartOffset()), negated));
107     }
108
109     return null;
110   }
111
112   private boolean isNegationExpression(@Nullable LighterASTNode expression) {
113     return expression != null && expression.getTokenType() == PREFIX_EXPRESSION && firstChildOfType(myTree, expression, JavaTokenType.EXCL) != null;
114   }
115
116   private ContractReturnValue getDefaultReturnValue(List<LighterASTNode> statements) {
117     class ReturnValueVisitor extends RecursiveLighterASTNodeWalkingVisitor {
118       public ContractReturnValue returnValue = fail();
119
120       private BitSet assignedParameters;
121
122       ReturnValueVisitor() {
123         super(myTree);
124       }
125
126       @Override
127       public void visitNode(@NotNull LighterASTNode element) {
128         IElementType type = element.getTokenType();
129         if (type == CLASS || type == LAMBDA_EXPRESSION) return;
130         if (returnValue.equals(returnAny())) {
131           return;
132         }
133         if (type == ASSIGNMENT_EXPRESSION || type == POSTFIX_EXPRESSION ||
134             (type == PREFIX_EXPRESSION && firstChildOfType(myTree, element, UNARY_INCREMENT_DECREMENT) != null)) {
135           LighterASTNode expression = skipParenthesesCastsDown(myTree, findExpressionChild(myTree, element));
136           int paramIndex = resolveParameter(expression);
137           if (paramIndex >= 0) {
138             if (assignedParameters == null) {
139               assignedParameters = new BitSet();
140             }
141             assignedParameters.set(paramIndex);
142             if (returnValue.equals(returnParameter(paramIndex))) {
143               returnValue = returnAny();
144             }
145           }
146         }
147         if (type == RETURN_STATEMENT) {
148           LighterASTNode expression = findExpressionChild(myTree, element);
149           ContractReturnValue newReturnValue = expressionToReturnValue(expression);
150           if (returnValue.isFail()) {
151             returnValue = newReturnValue;
152           }
153           else if (!returnValue.equals(newReturnValue)) {
154             returnValue = returnAny();
155           }
156         }
157         super.visitNode(element);
158       }
159
160       @NotNull
161       private ContractReturnValue expressionToReturnValue(LighterASTNode expression) {
162         expression = skipParenthesesDown(myTree, expression);
163         if (expression == null) return returnAny();
164         IElementType type = expression.getTokenType();
165         if (type == NEW_EXPRESSION) {
166           return returnNew();
167         }
168         if (type == THIS_EXPRESSION) {
169           return returnThis();
170         }
171         if (type == REFERENCE_EXPRESSION) {
172           int paramIndex = resolveParameter(expression);
173           if (paramIndex >= 0 && (assignedParameters == null || !assignedParameters.get(paramIndex))) {
174             return returnParameter(paramIndex);
175           }
176         }
177         return returnAny();
178       }
179     }
180     ReturnValueVisitor visitor = new ReturnValueVisitor();
181     for (LighterASTNode statement : statements) {
182       visitor.visitNode(statement);
183     }
184     return visitor.returnValue;
185   }
186
187   @NotNull
188   private List<PreContract> visitExpression(final List<ValueConstraint[]> states, @Nullable LighterASTNode expr) {
189     if (expr == null) return emptyList();
190     if (states.isEmpty()) return emptyList();
191     if (states.size() > 300) return emptyList(); // too complex
192
193     IElementType type = expr.getTokenType();
194     if (type == POLYADIC_EXPRESSION || type == BINARY_EXPRESSION) {
195       return visitPolyadic(states, expr);
196     }
197
198     if (type == CONDITIONAL_EXPRESSION) {
199       List<LighterASTNode> children = getExpressionChildren(myTree, expr);
200       if (children.size() != 3) return emptyList();
201
202       List<PreContract> conditionResults = visitExpression(states, children.get(0));
203       return ContainerUtil.concat(
204         visitExpression(antecedentsReturning(conditionResults, returnTrue()), children.get(1)),
205         visitExpression(antecedentsReturning(conditionResults, returnFalse()), children.get(2)));
206     }
207
208
209     if (type == PARENTH_EXPRESSION) {
210       return visitExpression(states, findExpressionChild(myTree, expr));
211     }
212     if (type == TYPE_CAST_EXPRESSION) {
213       return visitExpression(states, findExpressionChild(myTree, expr));
214     }
215
216     if (isNegationExpression(expr)) {
217       return ContainerUtil.mapNotNull(visitExpression(states, findExpressionChild(myTree, expr)), PreContract::negate);
218     }
219
220     if (type == INSTANCE_OF_EXPRESSION) {
221       final int parameter = resolveParameter(findExpressionChild(myTree, expr));
222       if (parameter >= 0) {
223         return asPreContracts(ContainerUtil.mapNotNull(states, state -> contractWithConstraint(state, parameter, NULL_VALUE, returnFalse())));
224       }
225     }
226
227     if (type == NEW_EXPRESSION) {
228       return asPreContracts(toContracts(states, returnNew()));
229     }
230     if (type == THIS_EXPRESSION) {
231       return asPreContracts(toContracts(states, returnThis()));
232     }
233     if (type == METHOD_CALL_EXPRESSION) {
234       return singletonList(new MethodCallContract(ExpressionRange.create(expr, myBody.getStartOffset()),
235                                                               ContainerUtil.map(states, Arrays::asList)));
236     }
237
238     final ValueConstraint constraint = getLiteralConstraint(expr);
239     if (constraint != null) {
240       return asPreContracts(toContracts(states, constraint.asReturnValue()));
241     }
242
243     int paramIndex = resolveParameter(expr);
244     if (paramIndex >= 0) {
245       List<StandardMethodContract> result = ContainerUtil.newArrayList();
246       for (ValueConstraint[] state : states) {
247         if (state[paramIndex] == TRUE_VALUE || state[paramIndex] == FALSE_VALUE || state[paramIndex] == NULL_VALUE) {
248           // like "if(x == null) return x": no need to refer to parameter
249           result.add(new StandardMethodContract(state, state[paramIndex].asReturnValue()));
250         } else if (JavaTokenType.BOOLEAN_KEYWORD == getPrimitiveParameterType(paramIndex)) {
251           // if (boolValue) ...
252           ContainerUtil.addIfNotNull(result, contractWithConstraint(state, paramIndex, TRUE_VALUE, returnTrue()));
253           ContainerUtil.addIfNotNull(result, contractWithConstraint(state, paramIndex, FALSE_VALUE, returnFalse()));
254         } else {
255           result.add(new StandardMethodContract(state, returnParameter(paramIndex)));
256         }
257       }
258       return asPreContracts(result);
259     }
260
261     return emptyList();
262   }
263
264   @NotNull
265   private List<PreContract> visitPolyadic(List<ValueConstraint[]> states, @NotNull LighterASTNode expr) {
266     if (firstChildOfType(myTree, expr, JavaTokenType.PLUS) != null) {
267       return asPreContracts(ContainerUtil.map(states, s -> new StandardMethodContract(s, returnNotNull())));
268     }
269
270     List<LighterASTNode> operands = getExpressionChildren(myTree, expr);
271     if (operands.size() == 2) {
272       boolean equality = firstChildOfType(myTree, expr, JavaTokenType.EQEQ) != null;
273       if (equality || firstChildOfType(myTree, expr, JavaTokenType.NE) != null) {
274         return asPreContracts(visitEqualityComparison(states, operands.get(0), operands.get(1), equality));
275       }
276     }
277     boolean logicalAnd = firstChildOfType(myTree, expr, JavaTokenType.ANDAND) != null;
278     if (logicalAnd || firstChildOfType(myTree, expr, JavaTokenType.OROR) != null) {
279       return asPreContracts(visitLogicalOperation(operands, logicalAnd, states));
280     }
281     return emptyList();
282   }
283
284   @NotNull
285   private static List<PreContract> asPreContracts(List<StandardMethodContract> contracts) {
286     return ContainerUtil.map(contracts, KnownContract::new);
287   }
288
289   @Nullable
290   private static StandardMethodContract contractWithConstraint(ValueConstraint[] state,
291                                                                int parameter, ValueConstraint paramConstraint,
292                                                                ContractReturnValue returnValue) {
293     ValueConstraint[] newState = withConstraint(state, parameter, paramConstraint);
294     return newState == null ? null : new StandardMethodContract(newState, returnValue);
295   }
296
297   private List<StandardMethodContract> visitEqualityComparison(List<ValueConstraint[]> states,
298                                                                LighterASTNode op1,
299                                                                LighterASTNode op2,
300                                                                boolean equality) {
301     int parameter = resolveParameter(op1);
302     ValueConstraint constraint = getLiteralConstraint(op2);
303     if (parameter < 0 || constraint == null) {
304       parameter = resolveParameter(op2);
305       constraint = getLiteralConstraint(op1);
306     }
307     if (parameter >= 0 && constraint != null) {
308       List<StandardMethodContract> result = ContainerUtil.newArrayList();
309       for (ValueConstraint[] state : states) {
310         if (constraint == NOT_NULL_VALUE) {
311           if (getPrimitiveParameterType(parameter) == null) {
312             ContainerUtil.addIfNotNull(result, contractWithConstraint(state, parameter, NULL_VALUE, returnBoolean(!equality)));
313           }
314         } else {
315           ContainerUtil.addIfNotNull(result, contractWithConstraint(state, parameter, constraint, returnBoolean(equality)));
316           ContainerUtil.addIfNotNull(result, contractWithConstraint(state, parameter, constraint.negate(), returnBoolean(!equality)));
317         }
318       }
319       return result;
320     }
321     return emptyList();
322   }
323
324   @Nullable
325   private IElementType getPrimitiveParameterType(int paramIndex) {
326     LighterASTNode typeElement = firstChildOfType(myTree, getParameters().get(paramIndex), TYPE);
327     LighterASTNode primitive = firstChildOfType(myTree, typeElement, ElementType.PRIMITIVE_TYPE_BIT_SET);
328     return primitive == null ? null : primitive.getTokenType();
329   }
330
331   static List<StandardMethodContract> toContracts(List<ValueConstraint[]> states, ContractReturnValue constraint) {
332     return ContainerUtil.map(states, state -> new StandardMethodContract(state, constraint));
333   }
334
335   private List<StandardMethodContract> visitLogicalOperation(List<LighterASTNode> operands, boolean conjunction, List<ValueConstraint[]> states) {
336     BooleanReturnValue breakValue = returnBoolean(!conjunction);
337     List<StandardMethodContract> finalStates = ContainerUtil.newArrayList();
338     for (LighterASTNode operand : operands) {
339       List<PreContract> opResults = visitExpression(states, operand);
340       finalStates.addAll(ContainerUtil.filter(knownContracts(opResults), contract -> contract.getReturnValue() == breakValue));
341       states = antecedentsReturning(opResults, breakValue.negate());
342     }
343     finalStates.addAll(toContracts(states, breakValue.negate()));
344     return finalStates;
345   }
346
347   private static List<StandardMethodContract> knownContracts(List<PreContract> values) {
348     return ContainerUtil.mapNotNull(values, pc -> pc instanceof KnownContract ? ((KnownContract)pc).getContract() : null);
349   }
350
351   private static List<ValueConstraint[]> antecedentsReturning(List<PreContract> values, ContractReturnValue result) {
352     return ContainerUtil.mapNotNull(knownContracts(values),
353                                     contract -> contract.getReturnValue().equals(result) ?
354                                                 contract.getConstraints().toArray(new ValueConstraint[0]) : null);
355   }
356
357   private static class CodeBlockContracts {
358     List<PreContract> accumulated = new ArrayList<>();
359     List<ExpressionRange> varInitializers = new ArrayList<>();
360
361     void addAll(List<PreContract> contracts) {
362       if (contracts.isEmpty()) return;
363
364       if (varInitializers.isEmpty()) {
365         accumulated.addAll(contracts);
366       } else {
367         accumulated.add(new SideEffectFilter(varInitializers, contracts));
368       }
369     }
370
371     void registerDeclaration(@NotNull LighterASTNode declStatement, @NotNull LighterAST tree, int scopeStart) {
372       for (LighterASTNode var : getChildrenOfType(tree, declStatement, LOCAL_VARIABLE)) {
373         LighterASTNode initializer = findExpressionChild(tree, var);
374         if (initializer != null) {
375           varInitializers.add(ExpressionRange.create(initializer, scopeStart));
376         }
377       }
378     }
379   }
380
381   @NotNull
382   private List<PreContract> visitStatements(List<ValueConstraint[]> states, List<LighterASTNode> statements) {
383     CodeBlockContracts result = new CodeBlockContracts();
384     for (LighterASTNode statement : statements) {
385       IElementType type = statement.getTokenType();
386       if (type == BLOCK_STATEMENT) {
387         result.addAll(visitStatements(states, getStatements(getCodeBlock(statement), myTree)));
388       }
389       else if (type == IF_STATEMENT) {
390         List<PreContract> conditionResults = visitExpression(states, findExpressionChild(myTree, statement));
391
392         List<LighterASTNode> thenElse = getStatements(statement, myTree);
393         if (thenElse.size() > 0) {
394           result.addAll(visitStatements(antecedentsReturning(conditionResults, returnTrue()), singletonList(thenElse.get(0))));
395         }
396
397         List<ValueConstraint[]> falseStates = antecedentsReturning(conditionResults, returnFalse());
398         if (thenElse.size() > 1) {
399           result.addAll(visitStatements(falseStates, singletonList(thenElse.get(1))));
400         } else {
401           states = falseStates;
402           continue;
403         }
404       }
405       else if (type == WHILE_STATEMENT) {
406         states = antecedentsReturning(visitExpression(states, findExpressionChild(myTree, statement)), returnFalse());
407         continue;
408       }
409       else if (type == THROW_STATEMENT) {
410         result.addAll(asPreContracts(toContracts(states, fail())));
411       }
412       else if (type == RETURN_STATEMENT) {
413         result.addAll(visitExpression(states, findExpressionChild(myTree, statement)));
414       }
415       else if (type == ASSERT_STATEMENT) {
416         List<PreContract> conditionResults = visitExpression(states, findExpressionChild(myTree, statement));
417         result.addAll(asPreContracts(toContracts(antecedentsReturning(conditionResults, returnFalse()), fail())));
418       }
419       else if (type == DECLARATION_STATEMENT) {
420         result.registerDeclaration(statement, myTree, myBody.getStartOffset());
421         continue;
422       }
423       else if (type == DO_WHILE_STATEMENT) {
424         result.addAll(visitStatements(states, getStatements(statement, myTree)));
425       }
426
427       break; // visit only the first statement unless it's 'if' whose 'then' always returns and the next statement is effectively 'else'
428     }
429     return result.accumulated;
430   }
431
432   @Nullable
433   private ValueConstraint getLiteralConstraint(@Nullable LighterASTNode expr) {
434     if (expr != null && expr.getTokenType() == LITERAL_EXPRESSION) {
435       return getLiteralConstraint(myTree.getChildren(expr).get(0).getTokenType());
436     }
437     return null;
438   }
439
440   @NotNull
441   static ValueConstraint getLiteralConstraint(@NotNull IElementType literalTokenType) {
442     if (literalTokenType.equals(JavaTokenType.TRUE_KEYWORD)) return TRUE_VALUE;
443     if (literalTokenType.equals(JavaTokenType.FALSE_KEYWORD)) return FALSE_VALUE;
444     if (literalTokenType.equals(JavaTokenType.NULL_KEYWORD)) return NULL_VALUE;
445     return NOT_NULL_VALUE;
446   }
447
448   private int resolveParameter(@Nullable LighterASTNode expr) {
449     if (expr != null && expr.getTokenType() == REFERENCE_EXPRESSION && findExpressionChild(myTree, expr) == null) {
450       String name = getNameIdentifierText(myTree, expr);
451       if (name == null) return -1;
452
453       List<LighterASTNode> parameters = getParameters();
454       for (int i = 0; i < parameters.size(); i++) {
455         if (name.equals(getNameIdentifierText(myTree, parameters.get(i)))) {
456           return i;
457         }
458       }
459     }
460     return -1;
461   }
462
463   @Nullable
464   static ValueConstraint[] withConstraint(ValueConstraint[] constraints, int index, ValueConstraint constraint) {
465     if (constraints[index] == constraint) return constraints;
466
467     ValueConstraint negated = constraint.negate();
468     if (negated != constraint && constraints[index] == negated) {
469       return null;
470     }
471
472     ValueConstraint[] copy = constraints.clone();
473     copy[index] = constraint;
474     return copy;
475   }
476
477 }