Cleanup: NotNull/Nullable
[idea/community.git] / java / java-impl / src / com / intellij / codeInspection / SimplifyOptionalCallChainsInspection.java
1 // Copyright 2000-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
2 package com.intellij.codeInspection;
3
4 import com.intellij.codeInsight.Nullability;
5 import com.intellij.codeInspection.dataFlow.*;
6 import com.intellij.codeInspection.dataFlow.value.DfaFactMapValue;
7 import com.intellij.codeInspection.dataFlow.value.DfaValue;
8 import com.intellij.codeInspection.util.LambdaGenerationUtil;
9 import com.intellij.codeInspection.util.OptionalRefactoringUtil;
10 import com.intellij.codeInspection.util.OptionalUtil;
11 import com.intellij.openapi.project.Project;
12 import com.intellij.pom.java.LanguageLevel;
13 import com.intellij.psi.*;
14 import com.intellij.psi.search.searches.ReferencesSearch;
15 import com.intellij.psi.util.PsiExpressionTrimRenderer;
16 import com.intellij.psi.util.PsiTreeUtil;
17 import com.intellij.psi.util.PsiUtil;
18 import com.intellij.refactoring.util.LambdaRefactoringUtil;
19 import com.siyeh.ig.callMatcher.CallHandler;
20 import com.siyeh.ig.callMatcher.CallMapper;
21 import com.siyeh.ig.callMatcher.CallMatcher;
22 import com.siyeh.ig.psiutils.*;
23 import org.jetbrains.annotations.Contract;
24 import org.jetbrains.annotations.Nls;
25 import org.jetbrains.annotations.NotNull;
26 import org.jetbrains.annotations.Nullable;
27
28 import java.util.Arrays;
29 import java.util.List;
30 import java.util.Objects;
31 import java.util.Optional;
32 import java.util.function.Function;
33 import java.util.regex.Pattern;
34
35 import static com.intellij.codeInspection.util.OptionalUtil.*;
36 import static com.intellij.psi.CommonClassNames.JAVA_UTIL_OPTIONAL;
37 import static com.intellij.util.ObjectUtils.tryCast;
38
39
40 public class SimplifyOptionalCallChainsInspection extends AbstractBaseJavaLocalInspectionTool {
41   private static final CallMatcher OPTIONAL_OR_ELSE =
42     CallMatcher.instanceCall(JAVA_UTIL_OPTIONAL, "orElse").parameterCount(1);
43   private static final CallMatcher OPTIONAL_GET =
44     CallMatcher.instanceCall(JAVA_UTIL_OPTIONAL, "get").parameterCount(0);
45   private static final CallMatcher OPTIONAL_OR_ELSE_GET =
46     CallMatcher.instanceCall(JAVA_UTIL_OPTIONAL, "orElseGet").parameterCount(1);
47   private static final CallMatcher OPTIONAL_OR_ELSE_OR_ELSE_GET = CallMatcher.anyOf(
48     OPTIONAL_OR_ELSE,
49     OPTIONAL_OR_ELSE_GET
50   );
51   private static final CallMatcher OPTIONAL_MAP =
52     CallMatcher.instanceCall(JAVA_UTIL_OPTIONAL, "map").parameterCount(1);
53   private static final CallMatcher OPTIONAL_OF_NULLABLE =
54     CallMatcher.staticCall(JAVA_UTIL_OPTIONAL, "ofNullable").parameterCount(1);
55   private static final CallMatcher OPTIONAL_OF_OF_NULLABLE =
56     CallMatcher.staticCall(JAVA_UTIL_OPTIONAL, "ofNullable", "of").parameterCount(1);
57   private static final CallMatcher OPTIONAL_IS_PRESENT =
58     CallMatcher.anyOf(
59       CallMatcher.exactInstanceCall(JAVA_UTIL_OPTIONAL, "isPresent").parameterCount(0),
60       CallMatcher.exactInstanceCall(OPTIONAL_INT, "isPresent").parameterCount(0),
61       CallMatcher.exactInstanceCall(OPTIONAL_LONG, "isPresent").parameterCount(0),
62       CallMatcher.exactInstanceCall(OPTIONAL_DOUBLE, "isPresent").parameterCount(0)
63     );
64   private static final CallMatcher OPTIONAL_IF_PRESENT =
65     CallMatcher.anyOf(
66       CallMatcher.exactInstanceCall(JAVA_UTIL_OPTIONAL, "ifPresent").parameterCount(1),
67       CallMatcher.exactInstanceCall(OPTIONAL_INT, "ifPresent").parameterCount(1),
68       CallMatcher.exactInstanceCall(OPTIONAL_LONG, "ifPresent").parameterCount(1),
69       CallMatcher.exactInstanceCall(OPTIONAL_DOUBLE, "ifPresent").parameterCount(1)
70     );
71   private static final CallMatcher OPTIONAL_IS_EMPTY =
72     CallMatcher.anyOf(
73       CallMatcher.exactInstanceCall(JAVA_UTIL_OPTIONAL, "isEmpty").parameterCount(0),
74       CallMatcher.exactInstanceCall(OPTIONAL_INT, "isEmpty").parameterCount(0),
75       CallMatcher.exactInstanceCall(OPTIONAL_LONG, "isEmpty").parameterCount(0),
76       CallMatcher.exactInstanceCall(OPTIONAL_DOUBLE, "isEmpty").parameterCount(0)
77     );
78
79
80   private static final CallMapper<OptionalSimplificationFix> ourMapper;
81
82   static {
83     List<ChainSimplificationCase<?>> cases = Arrays.asList(
84       new IfPresentFoldedCase(),
85       new MapUnwrappingCase(),
86       new OrElseNonNullCase(OrElseType.OrElse),
87       new OrElseNonNullCase(OrElseType.OrElseGet),
88       new FlipPresentOrEmptyCase(true),
89       new FlipPresentOrEmptyCase(false),
90       new OrElseReturnCase(OrElseType.OrElse),
91       new OrElseReturnCase(OrElseType.OrElseGet),
92       new RewrappingCase(RewrappingCase.Type.OptionalGet),
93       new RewrappingCase(RewrappingCase.Type.OrElseNull),
94       new MapOrElseCase(OrElseType.OrElseGet),
95       new MapOrElseCase(OrElseType.OrElse)
96     );
97     ourMapper = new CallMapper<>();
98     for (ChainSimplificationCase<?> theCase : cases) {
99       CallHandler<OptionalSimplificationFix> handler = CallHandler.of(theCase.getMatcher(), theCase);
100       ourMapper.register(handler);
101     }
102   }
103
104   @NotNull
105   @Override
106   public PsiElementVisitor buildVisitor(@NotNull ProblemsHolder holder, boolean isOnTheFly) {
107     LanguageLevel level = PsiUtil.getLanguageLevel(holder.getFile());
108     if (level.isLessThan(LanguageLevel.JDK_1_8)) {
109       return PsiElementVisitor.EMPTY_VISITOR;
110     }
111     return new OptionalChainVisitor(level) {
112       @Override
113       protected void handleSimplification(@NotNull PsiMethodCallExpression call, @NotNull OptionalSimplificationFix fix) {
114         PsiElement element = call.getMethodExpression().getReferenceNameElement();
115         holder.registerProblem(element != null ? element : call, fix.getDescription(), fix);
116       }
117     };
118   }
119
120   @Nullable
121   private static <T> OptionalSimplificationFix getFix(PsiMethodCallExpression call, ChainSimplificationCase<T> inspection) {
122     T context = inspection.extractContext(call.getProject(), call);
123     if (context == null) return null;
124     String name = inspection.getName(context);
125     String description = inspection.getDescription(context);
126     return new OptionalSimplificationFix(inspection, name, description);
127   }
128
129   private static <T> void handleSimplification(ChainSimplificationCase<T> inspection, Project project, PsiMethodCallExpression call) {
130     if (!inspection.getMatcher().matches(call)) return;
131     T context = inspection.extractContext(project, call);
132     if (context != null) {
133       inspection.apply(project, call, context);
134     }
135   }
136
137
138   @Nullable
139   private static PsiLambdaExpression getLambda(PsiExpression initializer) {
140     PsiExpression expression = PsiUtil.skipParenthesizedExprDown(initializer);
141     if (expression instanceof PsiLambdaExpression) {
142       return (PsiLambdaExpression)expression;
143     }
144     if (expression instanceof PsiMethodReferenceExpression) {
145       PsiMethodReferenceExpression methodRef = (PsiMethodReferenceExpression)expression;
146       PsiLambdaExpression lambda = LambdaRefactoringUtil.createLambda(methodRef, true);
147       if (lambda != null) {
148         LambdaRefactoringUtil.specifyLambdaParameterTypes(methodRef.getFunctionalInterfaceType(), lambda);
149         return lambda;
150       }
151     }
152     return null;
153   }
154
155   /**
156    * @return argument expression in case of absence of optional value
157    */
158   private static PsiExpression getOrElseArgument(PsiMethodCallExpression call, OrElseType type) {
159     if (type == OrElseType.OrElse) {
160       return call.getArgumentList().getExpressions()[0];
161     }
162     if (type == OrElseType.OrElseGet) {
163       PsiLambdaExpression lambda = getLambda(call.getArgumentList().getExpressions()[0]);
164       if (lambda == null || !lambda.getParameterList().isEmpty()) return null;
165       return LambdaUtil.extractSingleExpressionFromBody(lambda.getBody());
166     }
167     return null;
168   }
169
170   private static CallMatcher getMatcherByType(OrElseType type) {
171     if (type == OrElseType.OrElse) {
172       return OPTIONAL_OR_ELSE;
173     }
174     if (type == OrElseType.OrElseGet) {
175       return OPTIONAL_OR_ELSE_GET;
176     }
177     throw new IllegalStateException();
178   }
179
180   /*
181     if(optValue != null) {return optValue;} else {return "default";}
182     or
183     return optValue == null? "default" : optValue;
184      */
185   @Nullable
186   private static PsiExpression extractConditionalDefaultValue(@NotNull PsiStatement statement, @NotNull PsiVariable optValue) {
187     if (statement instanceof PsiIfStatement) {
188       PsiIfStatement ifStatement = (PsiIfStatement)statement;
189       PsiExpression condition = ifStatement.getCondition();
190       if (condition == null) return null;
191       PsiExpression thenExpr = getReturnExpression(ifStatement.getThenBranch());
192       PsiExpression elseExpr = getReturnExpression(ifStatement.getElseBranch());
193       if (thenExpr == null || elseExpr == null) return null;
194       return extractConditionalDefaultValue(thenExpr, elseExpr, condition, optValue);
195     }
196     else if (statement instanceof PsiReturnStatement) {
197       PsiExpression returnValue = ((PsiReturnStatement)statement).getReturnValue();
198       PsiConditionalExpression ternary = tryCast(PsiUtil.skipParenthesizedExprDown(returnValue), PsiConditionalExpression.class);
199       if (ternary == null) return null;
200       PsiExpression thenExpression = ternary.getThenExpression();
201       PsiExpression elseExpression = ternary.getElseExpression();
202       if (thenExpression == null || elseExpression == null) return null;
203       return extractConditionalDefaultValue(thenExpression, elseExpression, ternary.getCondition(), optValue);
204     }
205     return null;
206   }
207
208   @Contract("null -> null")
209   @Nullable
210   private static PsiExpression getReturnExpression(@Nullable PsiStatement block) {
211     if (block == null) return null;
212     PsiStatement statement = ControlFlowUtils.stripBraces(block);
213     PsiReturnStatement returnStatement = tryCast(statement, PsiReturnStatement.class);
214     if (returnStatement == null) return null;
215     return returnStatement.getReturnValue();
216   }
217
218   @Nullable
219   private static PsiExpression extractConditionalDefaultValue(@NotNull PsiExpression thenExpr,
220                                                               @NotNull PsiExpression elseExpr,
221                                                               @NotNull PsiExpression condition,
222                                                               @NotNull PsiVariable optValue) {
223     PsiVariable nullChecked = ExpressionUtils.getVariableFromNullComparison(condition, true);
224     boolean inverted = false;
225     if (nullChecked == null) {
226       nullChecked = ExpressionUtils.getVariableFromNullComparison(condition, false);
227       if (nullChecked == null) return null;
228       inverted = true;
229     }
230     if (!nullChecked.equals(optValue) || !ExpressionUtils.isReferenceTo(inverted ? thenExpr : elseExpr, optValue)) return null;
231     PsiExpression defaultExpression = inverted ? elseExpr : thenExpr;
232     if (VariableAccessUtils.variableIsUsed(optValue, defaultExpression)) return null;
233     return defaultExpression;
234   }
235
236   /**
237    * Optional.orElse and Optional.orElseGet have similar semantics and can be handled together.
238    * This enum represents what kind of method we are handling now.
239    */
240   private enum OrElseType {
241     OrElse,
242     OrElseGet
243   }
244
245   /**
246    * Stateless component, that can suggest simplification for call chain
247    *
248    * @param <C> context of the simplification
249    */
250   private interface ChainSimplificationCase<C> extends Function<PsiMethodCallExpression, OptionalSimplificationFix> {
251     @Override
252     default OptionalSimplificationFix apply(PsiMethodCallExpression expression) {
253       return getFix(expression, this);
254     }
255
256     @NotNull
257     String getName(@NotNull C context);
258
259     @NotNull
260     String getDescription(@NotNull C context);
261
262     /**
263      * Gathers context for handling simplification
264      * Called only if call matches to matcher returned by getMatcher call.
265      */
266     @Nullable
267     C extractContext(@NotNull Project project, @NotNull PsiMethodCallExpression call);
268
269     void apply(@NotNull Project project, @NotNull PsiMethodCallExpression call, @NotNull C context);
270
271     default boolean isAvailable(@NotNull Project project, @NotNull PsiMethodCallExpression call) {
272       return extractContext(project, call) != null;
273     }
274
275     @NotNull
276     CallMatcher getMatcher();
277
278     default boolean isAppropriateLanguageLevel(@NotNull LanguageLevel level) {
279       return true;
280     }
281   }
282
283   private static abstract class OptionalChainVisitor extends JavaElementVisitor {
284     private final LanguageLevel myLevel;
285
286     private OptionalChainVisitor(LanguageLevel level) {
287       myLevel = level;
288     }
289
290     @Override
291     public void visitMethodCallExpression(PsiMethodCallExpression expression) {
292       super.visitMethodCallExpression(expression);
293       Optional<OptionalSimplificationFix> fix = ourMapper
294         .mapAll(expression)
295         .filter(f -> f.myInspection.isAppropriateLanguageLevel(myLevel))
296         .findAny();
297       if (!fix.isPresent()) return;
298       handleSimplification(expression, fix.get());
299     }
300
301     protected abstract void handleSimplification(@NotNull PsiMethodCallExpression call, @NotNull OptionalSimplificationFix fix);
302   }
303
304   private static class MapOrElseCase extends BasicSimplificationInspection {
305     private final OrElseType myType;
306
307     private MapOrElseCase(OrElseType type) {myType = type;}
308
309     @Nullable
310     @Override
311     public StringReplacement extractContext(@NotNull Project project, @NotNull PsiMethodCallExpression call) {
312       PsiExpression falseArg = getOrElseArgument(call, myType);
313       if (falseArg == null) return null;
314       PsiMethodCallExpression qualifierCall = MethodCallUtils.getQualifierMethodCall(call);
315       if (!OPTIONAL_MAP.test(qualifierCall)) return null;
316       PsiLambdaExpression lambda = getLambda(qualifierCall.getArgumentList().getExpressions()[0]);
317       if (lambda == null) return null;
318       PsiExpression trueArg = LambdaUtil.extractSingleExpressionFromBody(lambda.getBody());
319       if (trueArg == null) return null;
320       PsiParameter[] parameters = lambda.getParameterList().getParameters();
321       if (parameters.length != 1) return null;
322       PsiExpression qualifier = qualifierCall.getMethodExpression().getQualifierExpression();
323       if (qualifier == null) return null;
324       String opt = qualifier.getText();
325       PsiParameter parameter = parameters[0];
326       boolean useOrElseGet = myType == OrElseType.OrElseGet;
327       String proposed = OptionalRefactoringUtil.generateOptionalUnwrap(opt, parameter, trueArg, falseArg, call.getType(), useOrElseGet);
328       String canonicalOrElse;
329       if (useOrElseGet && !ExpressionUtils.isSafelyRecomputableExpression(falseArg)) {
330         canonicalOrElse = ".orElseGet(() -> " + falseArg.getText() + ")";
331       }
332       else {
333         canonicalOrElse = ".orElse(" + falseArg.getText() + ")";
334       }
335       String canonical = opt + ".map(" + LambdaUtil.createLambda(parameter, trueArg) + ")" + canonicalOrElse;
336       if (proposed.length() < canonical.length()) {
337         String displayCode;
338         if (proposed.equals(opt)) {
339           displayCode = "";
340         }
341         else if (opt.length() > 10) {
342           // should be a parseable expression
343           opt = "(($))";
344           String template =
345             OptionalRefactoringUtil.generateOptionalUnwrap(opt, parameter, trueArg, falseArg, call.getType(), useOrElseGet);
346           displayCode =
347             PsiExpressionTrimRenderer
348               .render(JavaPsiFacade.getElementFactory(parameter.getProject()).createExpressionFromText(template, call));
349           displayCode = displayCode.replaceFirst(Pattern.quote(opt), "..");
350         }
351         else {
352           displayCode =
353             PsiExpressionTrimRenderer
354               .render(JavaPsiFacade.getElementFactory(parameter.getProject()).createExpressionFromText(proposed, call));
355         }
356         String message = displayCode.isEmpty() ? "Remove redundant steps from optional chain" :
357                          "Simplify optional chain to '" + displayCode + "'";
358         String description = "Optional chain can be simplified";
359         return new StringReplacement(proposed, message, description);
360       }
361       return null;
362     }
363
364     @NotNull
365     @Override
366     public CallMatcher getMatcher () {
367       return getMatcherByType(myType);
368     }
369   }
370
371   static class OptionalSimplificationFix implements LocalQuickFix {
372     private final ChainSimplificationCase<?> myInspection;
373     private final String myName;
374     private final String myDescription;
375
376     OptionalSimplificationFix(ChainSimplificationCase<?> inspection, String name, String description) {
377       myInspection = inspection;
378       this.myName = name;
379       myDescription = description;
380     }
381
382     @Nls(capitalization = Nls.Capitalization.Sentence)
383     @NotNull
384     @Override
385     public String getFamilyName() {
386       return myName;
387     }
388
389     @Override
390     public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
391       PsiMethodCallExpression call = PsiTreeUtil.getParentOfType(descriptor.getStartElement(), PsiMethodCallExpression.class, false);
392       //PsiMethodCallExpression call = tryCast(descriptor.getStartElement(), PsiMethodCallExpression.class);
393       handleSimplification(myInspection, project, call);
394     }
395
396     String getDescription() {
397       return myDescription;
398     }
399   }
400
401   private abstract static class BasicSimplificationInspection
402     implements ChainSimplificationCase<BasicSimplificationInspection.StringReplacement> {
403     @NotNull
404     @Override
405     public String getName(@NotNull StringReplacement context) {
406       return context.myMessage;
407     }
408
409     @NotNull
410     @Override
411     public String getDescription(@NotNull StringReplacement context) {
412       return context.myDescription;
413     }
414
415
416     @Override
417     public void apply(@NotNull Project project, @NotNull PsiMethodCallExpression call, @NotNull StringReplacement context) {
418       PsiExpression replacementExpression = JavaPsiFacade.getElementFactory(project).createExpressionFromText(context.myReplacement, call);
419       PsiElement result = call.replace(replacementExpression);
420       LambdaCanBeMethodReferenceInspection.replaceAllLambdasWithMethodReferences(result);
421       RemoveRedundantTypeArgumentsUtil.removeRedundantTypeArguments(result);
422     }
423
424     protected static class StringReplacement {
425       private final String myReplacement;
426       private final String myMessage;
427       private final String myDescription;
428
429       StringReplacement(String replacement, String message, String description) {
430         myReplacement = replacement;
431         myMessage = message;
432         myDescription = description;
433       }
434     }
435   }
436
437   private static class RewrappingCase implements ChainSimplificationCase<RewrappingCase.Context> {
438     private final CallMatcher myWrapper;
439     private final Type myType;
440
441     private RewrappingCase(Type type) {
442       myType = type;
443       if (myType == Type.OrElseNull) {
444         myWrapper = OPTIONAL_OF_NULLABLE;
445       } else {
446         myWrapper = OPTIONAL_OF_OF_NULLABLE;
447       }
448     }
449
450     @NotNull
451     @Override
452     public String getName(@NotNull Context context) {
453       return "Unwrap";
454     }
455
456     @NotNull
457     @Override
458     public String getDescription(@NotNull Context context) {
459       return "Unnecessary Optional rewrapping";
460     }
461
462     @Nullable
463     @Override
464     public Context extractContext(@NotNull Project project, @NotNull PsiMethodCallExpression call) {
465       PsiElement parent = PsiUtil.skipParenthesizedExprUp(call.getParent());
466       if (!(parent instanceof PsiExpressionList)) return null;
467       PsiMethodCallExpression parentCall = tryCast(parent.getParent(), PsiMethodCallExpression.class);
468       if (!myWrapper.test(parentCall)) return null;
469       PsiExpression qualifier = call.getMethodExpression().getQualifierExpression();
470       if (qualifier == null ||
471           !EquivalenceChecker.getCanonicalPsiEquivalence().typesAreEquivalent(qualifier.getType(), parentCall.getType())) {
472         return null;
473       }
474       if ("get".equals(call.getMethodExpression().getReferenceName())) {
475         SpecialFieldValue fact = CommonDataflow.getExpressionFact(qualifier, DfaFactType.SPECIAL_FIELD_VALUE);
476         if (DfaFactType.NULLABILITY.fromDfaValue(SpecialField.OPTIONAL_VALUE.extract(fact)) != DfaNullability.NOT_NULL) return null;
477       }
478       return new Context(qualifier, parentCall);
479     }
480
481     @Override
482     public void apply(@NotNull Project project, @NotNull PsiMethodCallExpression call, @NotNull Context context) {
483       PsiElement result = context.myCallToReplace.replace(context.myQualifier);
484       LambdaCanBeMethodReferenceInspection.replaceAllLambdasWithMethodReferences(result);
485       RemoveRedundantTypeArgumentsUtil.removeRedundantTypeArguments(result);
486     }
487
488     @NotNull
489     @Override
490     public CallMatcher getMatcher() {
491       if (myType == Type.OptionalGet) {
492         return OPTIONAL_GET;
493       }
494       return OPTIONAL_OR_ELSE_OR_ELSE_GET;
495     }
496
497     private static class Context {
498       private final PsiExpression myQualifier;
499       private final PsiExpression myCallToReplace;
500
501       private Context(PsiExpression qualifier, PsiExpression callToReplace) {
502         myQualifier = qualifier;
503         myCallToReplace = callToReplace;
504       }
505     }
506
507     enum Type {
508       OrElseNull,
509       OptionalGet
510     }
511   }
512
513   private static class OrElseReturnCase implements ChainSimplificationCase<OrElseReturnCase.Context> {
514     private final OrElseType myType;
515
516     private OrElseReturnCase(OrElseType type) {myType = type;}
517
518     @NotNull
519     @Override
520     public String getName(@NotNull Context context) {
521       String method = context.myIsSimple ? "orElse" : "orElseGet";
522       return "Replace null check with " + method + "(" + PsiExpressionTrimRenderer.render(context.myDefaultExpression) + ")";
523     }
524
525     @NotNull
526     @Override
527     public String getDescription(@NotNull Context context) {
528       return "Null check can be eliminated";
529     }
530
531     @Nullable
532     @Override
533     public Context extractContext(@NotNull Project project, @NotNull PsiMethodCallExpression call) {
534       PsiExpression falseArg = getOrElseArgument(call, myType);
535       if (!ExpressionUtils.isNullLiteral(falseArg)) return null;
536       PsiLocalVariable returnVar = PsiTreeUtil.getParentOfType(call, PsiLocalVariable.class, true);
537       if (returnVar == null) return null;
538       PsiStatement nextStatement =
539         tryCast(PsiTreeUtil.skipWhitespacesForward(returnVar.getParent()), PsiStatement.class);
540       if (nextStatement == null) return null;
541       PsiExpression defaultValue = extractConditionalDefaultValue(nextStatement, returnVar);
542       boolean isSimple = ExpressionUtils.isSafelyRecomputableExpression(defaultValue);
543       if (defaultValue == null || (!isSimple && !LambdaGenerationUtil.canBeUncheckedLambda(defaultValue))) return null;
544       PsiType type = defaultValue.getType();
545       PsiType methodCallReturnValue = call.getMethodExpression().getType();
546       if (type == null || methodCallReturnValue == null || !methodCallReturnValue.isAssignableFrom(type)) return null;
547       return new Context(call, defaultValue, nextStatement, isSimple);
548     }
549
550     @Override
551     public void apply(@NotNull Project project, @NotNull PsiMethodCallExpression call, @NotNull Context context) {
552       PsiExpression receiver = context.myOrElseCall.getMethodExpression().getQualifierExpression();
553       if (receiver == null) return;
554       String methodWithArg = context.myIsSimple
555                              ? ".orElse(" + context.myDefaultExpression.getText() + ")"
556                              : ".orElseGet(()->" + context.myDefaultExpression.getText() + ")";
557       String expressionText;
558       expressionText = receiver.getText() + methodWithArg;
559       PsiStatement finalStatement =
560         JavaPsiFacade.getElementFactory(project).createStatementFromText("return " + expressionText + ";", receiver);
561       PsiStatement current = PsiTreeUtil.getParentOfType(context.myOrElseCall, PsiStatement.class, false);
562       if (current == null) return;
563       PsiElement result = new CommentTracker().replaceAndRestoreComments(current, finalStatement);
564       new CommentTracker().deleteAndRestoreComments(context.myNextStatement);
565       LambdaCanBeMethodReferenceInspection.replaceAllLambdasWithMethodReferences(result);
566     }
567
568     @NotNull
569     @Override
570     public CallMatcher getMatcher() {
571       if (myType == OrElseType.OrElse) {
572         return OPTIONAL_OR_ELSE;
573       }
574       return OPTIONAL_OR_ELSE_GET;
575     }
576
577     private static class Context {
578       @NotNull private final PsiMethodCallExpression myOrElseCall;
579       @NotNull private final PsiExpression myDefaultExpression;
580       @NotNull private final PsiStatement myNextStatement;
581       private final boolean myIsSimple;
582
583       private Context(@NotNull PsiMethodCallExpression call,
584                       @NotNull PsiExpression defaultExpression,
585                       @NotNull PsiStatement nextStatement, boolean simple) {
586         myOrElseCall = call;
587         myDefaultExpression = defaultExpression;
588         myNextStatement = nextStatement;
589         myIsSimple = simple;
590       }
591     }
592   }
593
594   private static class FlipPresentOrEmptyCase implements ChainSimplificationCase<FlipPresentOrEmptyCase.Context> {
595     // Type of the inspection (may be either present or empty)
596     private final boolean myIsPresent;
597
598     private FlipPresentOrEmptyCase(boolean present) {myIsPresent = present;}
599
600     @NotNull
601     @Override
602     public String getName(@NotNull Context context) {
603       return CommonQuickFixBundle.message("fix.replace.with.x", context.myReplacement + "()");
604     }
605
606     @NotNull
607     @Override
608     public String getDescription(@NotNull Context context) {
609       return "'" + context.myReplacement + "()' can be used instead";
610     }
611
612     @Nullable
613     @Override
614     public Context extractContext(@NotNull Project project, @NotNull PsiMethodCallExpression call) {
615       if (!BoolUtils.isNegated(call)) return null;
616       PsiElement nameElement = call.getMethodExpression().getReferenceNameElement();
617       if (nameElement == null) return null;
618       if (myIsPresent) {
619         return new Context("isEmpty");
620       }
621       return new Context("isPresent");
622     }
623
624     @Override
625     public boolean isAppropriateLanguageLevel(@NotNull LanguageLevel level) {
626       return level.isAtLeast(LanguageLevel.JDK_11);
627     }
628
629     @Override
630     public void apply(@NotNull Project project, @NotNull PsiMethodCallExpression call, @NotNull Context context) {
631       PsiPrefixExpression negation = tryCast(PsiUtil.skipParenthesizedExprUp(call.getParent()), PsiPrefixExpression.class);
632       if (negation == null || BoolUtils.getNegated(negation) != call) return;
633       ExpressionUtils.bindCallTo(call, context.myReplacement);
634       new CommentTracker().replaceAndRestoreComments(negation, call);
635     }
636
637     @NotNull
638     @Override
639     public CallMatcher getMatcher() {
640       if (myIsPresent) {
641         return OPTIONAL_IS_PRESENT;
642       }
643       return OPTIONAL_IS_EMPTY;
644     }
645
646     private static class Context {
647       private final String myReplacement;
648
649       private Context(String replacement) {myReplacement = replacement;}
650     }
651   }
652
653   private static class OrElseNonNullCase implements ChainSimplificationCase<OrElseNonNullCase.Context> {
654     private final OrElseType myType;
655
656     private OrElseNonNullCase(OrElseType type) {myType = type;}
657
658     @NotNull
659     @Override
660     public String getName(@NotNull Context context) {
661       return "Replace null check with ifPresent()";
662     }
663
664     @NotNull
665     @Override
666     public String getDescription(@NotNull Context context) {
667       return "Null check can be eliminated with 'ifPresent'";
668     }
669
670     @Nullable
671     @Override
672     public Context extractContext(@NotNull Project project, @NotNull PsiMethodCallExpression call) {
673       PsiExpression orElseArgument = getOrElseArgument(call, myType);
674       if (!ExpressionUtils.isNullLiteral(orElseArgument)) return null;
675       PsiLocalVariable returnVar = tryCast(call.getParent(), PsiLocalVariable.class);
676       if (returnVar == null) return null;
677       PsiStatement statement = PsiTreeUtil.getParentOfType(returnVar, PsiStatement.class, true);
678       if (statement == null) return null;
679       PsiStatement nextStatement =
680         tryCast(PsiTreeUtil.skipWhitespacesForward(returnVar.getParent()), PsiStatement.class);
681       if (nextStatement == null) return null;
682       PsiExpression lambdaExpr = extractMappingExpression(nextStatement, returnVar);
683       if (!LambdaGenerationUtil.canBeUncheckedLambda(lambdaExpr)) return null;
684       if (!ReferencesSearch.search(returnVar).allMatch(reference ->
685                                                          PsiTreeUtil.isAncestor(statement, reference.getElement(), false) ||
686                                                          PsiTreeUtil.isAncestor(nextStatement, reference.getElement(), false))) {
687         return null;
688       }
689       return new Context(lambdaExpr, nextStatement, statement, returnVar, call);
690     }
691
692     @Override
693     public void apply(@NotNull Project project, @NotNull PsiMethodCallExpression call, @NotNull Context context) {
694       PsiExpression receiver = context.myOrElseCall.getMethodExpression().getQualifierExpression();
695       if(receiver == null) return;
696       String statementText = receiver.getText() + ".ifPresent(" + LambdaUtil.createLambda(context.myVariable, context.myAction) + ");";
697       PsiStatement finalStatement = JavaPsiFacade.getElementFactory(project).createStatementFromText(statementText, context.myStatement);
698       PsiElement result = context.myStatement.replace(finalStatement);
699       context.myConditionStatement.delete();
700       LambdaCanBeMethodReferenceInspection.replaceAllLambdasWithMethodReferences(result);
701     }
702
703     @NotNull
704     @Override
705     public CallMatcher getMatcher() {
706       return getMatcherByType(myType);
707     }
708
709
710
711     /**
712      * if(optValue != null) {
713      * System.out.println(optValue);
714      * }
715      **/
716     @Nullable
717     private static PsiExpression extractMappingExpression(@NotNull PsiStatement statement, @NotNull PsiVariable optValue) {
718       PsiIfStatement ifStatement = tryCast(statement, PsiIfStatement.class);
719       if (ifStatement == null) return null;
720       if (ifStatement.getElseBranch() != null) return null;
721       PsiExpression condition = ifStatement.getCondition();
722       if (condition == null) return null;
723       if (ExpressionUtils.getVariableFromNullComparison(condition, false) != optValue) return null;
724
725       PsiStatement thenStatement = ControlFlowUtils.stripBraces(ifStatement.getThenBranch());
726       PsiExpressionStatement expressionStatement = tryCast(thenStatement, PsiExpressionStatement.class);
727       if (expressionStatement == null) return null;
728       return expressionStatement.getExpression();
729     }
730
731     private static class Context {
732       private final @NotNull PsiExpression myAction;
733       private final @NotNull PsiStatement myConditionStatement;
734       private final @NotNull PsiStatement myStatement;
735       private final @NotNull PsiVariable myVariable;
736       private final @NotNull PsiMethodCallExpression myOrElseCall;
737
738       private Context(@NotNull PsiExpression action,
739                       @NotNull PsiStatement conditionStatement,
740                       @NotNull PsiStatement statement,
741                       @NotNull PsiVariable variable,
742                       @NotNull PsiMethodCallExpression call) {
743         myAction = action;
744         myConditionStatement = conditionStatement;
745         myStatement = statement;
746         myVariable = variable;
747         myOrElseCall = call;
748       }
749     }
750   }
751
752
753   /**
754    * Converts
755    * <pre>opt.map(a -> a.getOptional().orElse(null)).ifPresent(System.out::println);</pre>
756    * into
757    * <pre>opt.flatMap(Fra::getOptional).ifPresent(System.out::println);</pre>
758    */
759   private static class MapUnwrappingCase implements ChainSimplificationCase<MapUnwrappingCase.Context> {
760     @NotNull
761     @Override
762     public String getName(@NotNull Context context) {
763       return CommonQuickFixBundle.message("fix.replace.map.with.flat.map.name");
764     }
765
766     @NotNull
767     @Override
768     public String getDescription(@NotNull Context context) {
769       return CommonQuickFixBundle.message("fix.replace.map.with.flat.map.description");
770     }
771
772     @Nullable
773     @Override
774     public Context extractContext(@NotNull Project project, @NotNull PsiMethodCallExpression call) {
775       PsiLambdaExpression lambda = getLambda(call.getArgumentList().getExpressions()[0]);
776       if (lambda == null) return null;
777       PsiParameter[] parameters = lambda.getParameterList().getParameters();
778       if (parameters.length != 1) return null;
779       PsiParameter mapLambdaParameter = parameters[0];
780       PsiExpression argument = LambdaUtil.extractSingleExpressionFromBody(lambda.getBody());
781       PsiMethodCallExpression insideLambdaCall = tryCast(argument, PsiMethodCallExpression.class);
782       if (insideLambdaCall == null) return null;
783       PsiExpression optionalQualifier = insideLambdaCall.getMethodExpression().getQualifierExpression();
784       if (optionalQualifier == null) return null;
785       if (!OPTIONAL_OR_ELSE.test(insideLambdaCall)) {
786         if (!OPTIONAL_GET.test(insideLambdaCall)) return null;
787         PsiExpression qualifier = insideLambdaCall.getMethodExpression().getQualifierExpression();
788         if (!isPresentOptional(qualifier)) return null;
789         return new Context(optionalQualifier, call, mapLambdaParameter);
790       }
791       if (!ExpressionUtils.isNullLiteral(insideLambdaCall.getArgumentList().getExpressions()[0])) return null;
792       return new Context(optionalQualifier, call, mapLambdaParameter);
793     }
794
795     @Override
796     public void apply(@NotNull Project project, @NotNull PsiMethodCallExpression call, @NotNull Context context) {
797       CommentTracker ct = new CommentTracker();
798       String text = ct.text(context.myMapLambdaParameter) + " ->" + ct.text(context.myOptionalExpression);
799       PsiExpression qualifier = context.myMapCall.getMethodExpression().getQualifierExpression();
800       String callReplacement = Objects.requireNonNull(qualifier).getText() + ".flatMap(" + text + ")";
801       PsiElement result = ct.replaceAndRestoreComments(context.myMapCall, callReplacement);
802       LambdaCanBeMethodReferenceInspection.replaceAllLambdasWithMethodReferences(result);
803     }
804
805     @NotNull
806     @Override
807     public CallMatcher getMatcher() {
808       return OPTIONAL_MAP;
809     }
810
811     private static boolean isPresentOptional(PsiExpression optionalExpression) {
812       SpecialFieldValue fact = CommonDataflow.getExpressionFact(optionalExpression, DfaFactType.SPECIAL_FIELD_VALUE);
813       DfaValue value = SpecialField.OPTIONAL_VALUE.extract(fact);
814       if (!(value instanceof DfaFactMapValue)) return false;
815       return DfaNullability.toNullability(DfaFactType.NULLABILITY.fromDfaValue(value)) == Nullability.NOT_NULL;
816     }
817
818     private static class Context {
819       private final PsiExpression myOptionalExpression;
820       private final PsiMethodCallExpression myMapCall;
821       private final PsiParameter myMapLambdaParameter;
822
823       private Context(PsiExpression expression, PsiMethodCallExpression call, PsiParameter parameter) {
824         myOptionalExpression = expression;
825         myMapCall = call;
826         myMapLambdaParameter = parameter;
827       }
828     }
829   }
830
831   private static class IfPresentFoldedCase implements ChainSimplificationCase<IfPresentFoldedCase.Context> {
832     @NotNull
833     @Override
834     public String getName(@NotNull Context context) {
835       return CommonQuickFixBundle.message("fix.eliminate.folded.if.present.name");
836     }
837
838     @NotNull
839     @Override
840     public String getDescription(@NotNull Context context) {
841       return CommonQuickFixBundle.message("fix.eliminate.folded.if.present.description");
842     }
843
844     @Nullable
845     @Override
846     public Context extractContext(@NotNull Project project, @NotNull PsiMethodCallExpression call) {
847       PsiExpression outerIfPresentQualifier = call.getMethodExpression().getQualifierExpression();
848       PsiMethodCallExpression qualifierCall = tryCast(outerIfPresentQualifier, PsiMethodCallExpression.class);
849
850       PsiLambdaExpression outerIfPresentArgument = tryCast(call.getArgumentList().getExpressions()[0], PsiLambdaExpression.class);
851       if (outerIfPresentArgument == null) return null;
852       if (outerIfPresentArgument.getParameterList().getParametersCount() != 1) return null;
853       PsiParameter parameter = outerIfPresentArgument.getParameterList().getParameters()[0];
854       if (parameter == null) return null;
855       String outerIfPresentParameterName = parameter.getName();
856       if (outerIfPresentParameterName == null) return null;
857       PsiExpression outerIfPresentBodyExpr = LambdaUtil.extractSingleExpressionFromBody(outerIfPresentArgument.getBody());
858       PsiMethodCallExpression outerIfPresentBody = tryCast(outerIfPresentBodyExpr, PsiMethodCallExpression.class);
859       if (!OPTIONAL_IF_PRESENT.test(outerIfPresentBody)) return null;
860       PsiExpression innerIfPresentQualifier = outerIfPresentBody.getMethodExpression().getQualifierExpression();
861       PsiExpression nonTrivialQualifier = ExpressionUtils.isReferenceTo(innerIfPresentQualifier, parameter) ? null : innerIfPresentQualifier;
862       PsiExpression innerIfPresentArgument = outerIfPresentBody.getArgumentList().getExpressions()[0];
863
864       PsiMethodCallExpression mapBefore = null;
865       if (OPTIONAL_MAP.test(qualifierCall)) {
866         // case when map(Value::getOptional).ifPresent(p -> p.ifPresent(...))
867         if (isOptionalTypeParameter(qualifierCall.getType())) {
868           mapBefore = qualifierCall;
869         }
870       }
871       return new Context(mapBefore, nonTrivialQualifier, outerIfPresentParameterName, innerIfPresentArgument);
872     }
873
874     private static boolean isOptionalTypeParameter(@Nullable PsiType type) {
875       PsiClassType classType = tryCast(type, PsiClassType.class);
876       if (classType == null) return false;
877       if (classType.getParameterCount() != 1) return false;
878       PsiType typeParameter = classType.getParameters()[0];
879       PsiClass parameterClass = PsiUtil.resolveClassInClassTypeOnly(typeParameter);
880       if (parameterClass == null) return false;
881       return JAVA_UTIL_OPTIONAL.equals(parameterClass.getQualifiedName());
882     }
883
884     @Override
885     public void apply(@NotNull Project project, @NotNull PsiMethodCallExpression call, @NotNull Context context) {
886       PsiMethodCallExpression mapBefore = context.myMapBefore;
887       CommentTracker ct = new CommentTracker();
888       StringBuilder sb = new StringBuilder();
889       PsiExpression qualifer = call.getMethodExpression().getQualifierExpression();
890       assert qualifer != null;
891       sb.append(ct.text(qualifer)).append(".");
892       if (mapBefore != null) {
893         PsiExpression mapArgument = mapBefore.getArgumentList().getExpressions()[0];
894         sb.append("flatMap(").append(ct.text(mapArgument)).append(").");
895       }
896       PsiExpression lambdaBodyAfter = context.myMapLambdaBodyAfter;
897       if (lambdaBodyAfter != null) {
898         sb.append("flatMap(").append(context.myOuterIfPresentVarName).append("->").append(ct.text(lambdaBodyAfter)).append(").");
899       }
900       sb.append("ifPresent(").append(ct.text(context.myInnerIfPresentArgument)).append(")");
901       PsiElement result = ct.replaceAndRestoreComments(call, sb.toString());
902       LambdaCanBeMethodReferenceInspection.replaceAllLambdasWithMethodReferences(result);
903     }
904
905     @NotNull
906     @Override
907     public CallMatcher getMatcher() {
908       return OPTIONAL_IF_PRESENT;
909     }
910
911     class Context {
912       @Nullable PsiMethodCallExpression myMapBefore;
913       @Nullable PsiExpression myMapLambdaBodyAfter;
914       @NotNull String myOuterIfPresentVarName;
915       @NotNull PsiExpression myInnerIfPresentArgument;
916
917       Context(@Nullable PsiMethodCallExpression mapBefore,
918               @Nullable PsiExpression mapLambdaBodyAfter,
919               @NotNull String outerIfPresentVarName,
920               @NotNull PsiExpression innerIfPresentArgument) {
921         myMapBefore = mapBefore;
922         myMapLambdaBodyAfter = mapLambdaBodyAfter;
923         myOuterIfPresentVarName = outerIfPresentVarName;
924         myInnerIfPresentArgument = innerIfPresentArgument;
925       }
926     }
927   }
928 }