Cleanup: NotNull/Nullable
[idea/community.git] / java / java-impl / src / com / intellij / codeInspection / streamMigration / FoldExpressionIntoStreamInspection.java
1 // Copyright 2000-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
2 package com.intellij.codeInspection.streamMigration;
3
4 import com.intellij.codeInspection.*;
5 import com.intellij.codeInspection.util.LambdaGenerationUtil;
6 import com.intellij.openapi.project.Project;
7 import com.intellij.psi.*;
8 import com.intellij.psi.codeStyle.JavaCodeStyleManager;
9 import com.intellij.psi.codeStyle.VariableKind;
10 import com.intellij.psi.tree.IElementType;
11 import com.intellij.psi.util.InheritanceUtil;
12 import com.intellij.psi.util.PsiLiteralUtil;
13 import com.intellij.psi.util.PsiTreeUtil;
14 import com.intellij.psi.util.PsiUtil;
15 import com.intellij.util.ArrayUtil;
16 import com.siyeh.ig.psiutils.*;
17 import one.util.streamex.IntStreamEx;
18 import one.util.streamex.StreamEx;
19 import org.jetbrains.annotations.Nls;
20 import org.jetbrains.annotations.NotNull;
21 import org.jetbrains.annotations.Nullable;
22
23 import java.util.ArrayList;
24 import java.util.Collections;
25 import java.util.List;
26 import java.util.Objects;
27
28 import static com.intellij.codeInsight.intention.impl.StreamRefactoringUtil.getMapOperationName;
29 import static com.intellij.util.ObjectUtils.tryCast;
30
31 public class FoldExpressionIntoStreamInspection extends AbstractBaseJavaLocalInspectionTool {
32   @NotNull
33   @Override
34   public PsiElementVisitor buildVisitor(@NotNull ProblemsHolder holder, boolean isOnTheFly) {
35     if (!PsiUtil.isLanguageLevel8OrHigher(holder.getFile())) {
36       return PsiElementVisitor.EMPTY_VISITOR;
37     }
38     return new JavaElementVisitor() {
39       @Override
40       public void visitPolyadicExpression(PsiPolyadicExpression expression) {
41         TerminalGenerator generator = getGenerator(expression);
42         if (generator == null) return;
43         List<PsiExpression> diff = extractDiff(generator, expression);
44         if (diff.isEmpty()) return;
45         if (!LambdaGenerationUtil.canBeUncheckedLambda(expression)) return;
46         boolean stringJoin = generator.isStringJoin(expression, diff);
47         String message = InspectionsBundle.message(stringJoin ?
48                                                    "inspection.fold.expression.into.string.display.name" :
49                                                    "inspection.fold.expression.into.stream.display.name");
50         holder.registerProblem(expression, message,
51                                new FoldExpressionIntoStreamFix(stringJoin));
52       }
53     };
54   }
55
56   private static List<PsiExpression> extractDiff(TerminalGenerator generator,
57                                                  PsiPolyadicExpression expression) {
58     EquivalenceChecker equivalence = EquivalenceChecker.getCanonicalPsiEquivalence();
59     PsiExpression[] operands = generator.getOperands(expression);
60     if (operands.length < 3) return Collections.emptyList();
61     List<PsiExpression> elements = new ArrayList<>();
62     for (int i = 1; i < operands.length; i++) {
63       if (!Objects.equals(operands[0].getType(), operands[i].getType())) return Collections.emptyList();
64       EquivalenceChecker.Match match = equivalence.expressionsMatch(operands[0], operands[i]);
65       PsiExpression left = null;
66       PsiExpression right = null;
67       if (match.isPartialMatch()) {
68         left = tryCast(match.getLeftDiff(), PsiExpression.class);
69         right = tryCast(match.getRightDiff(), PsiExpression.class);
70       }
71       else if (match.isExactMismatch() && generator.isDittoSupported()) {
72         left = operands[0];
73         right = operands[i];
74       }
75       if (left == null || right == null) return Collections.emptyList();
76       if (elements.isEmpty()) {
77         if (!StreamApiUtil.isSupportedStreamElement(left.getType()) || !ExpressionUtils.isSafelyRecomputableExpression(left)) {
78           return Collections.emptyList();
79         }
80         PsiBinaryExpression binOp = tryCast(PsiUtil.skipParenthesizedExprDown(operands[0]), PsiBinaryExpression.class);
81         if (binOp != null) {
82           if (ComparisonUtils.isComparison(binOp) &&
83               (left == binOp.getLOperand() && ExpressionUtils.isSafelyRecomputableExpression(binOp.getROperand())) ||
84               (left == binOp.getROperand() && ExpressionUtils.isSafelyRecomputableExpression(binOp.getLOperand()))) {
85             // Disable for simple comparison chains like "a == null && b == null && c == null":
86             // using Stream API here looks an overkill
87             return Collections.emptyList();
88           }
89         }
90         elements.add(left);
91       }
92       else if (elements.get(0) != left) {
93         return Collections.emptyList();
94       }
95       if (!Objects.equals(left.getType(), right.getType()) ||
96           !ExpressionUtils.isSafelyRecomputableExpression(right)) {
97         return Collections.emptyList();
98       }
99       elements.add(right);
100     }
101     return elements;
102   }
103
104   private interface TerminalGenerator {
105     default PsiExpression[] getOperands(PsiPolyadicExpression polyadicExpression) {
106       return polyadicExpression.getOperands();
107     }
108
109     default boolean isDittoSupported() {
110       return false;
111     }
112
113     @NotNull
114     String generateTerminal(PsiType elementType, String lambda, CommentTracker ct);
115
116     default boolean isStringJoin(PsiPolyadicExpression expression, List<? extends PsiExpression> diff) {
117       return false;
118     }
119   }
120
121   @Nullable
122   private static TerminalGenerator getGenerator(PsiPolyadicExpression polyadicExpression) {
123     IElementType tokenType = polyadicExpression.getOperationTokenType();
124     if (tokenType.equals(JavaTokenType.OROR)) {
125       return (elementType, lambda, ct) -> ".anyMatch(" + lambda + ")";
126     }
127     else if (tokenType.equals(JavaTokenType.ANDAND)) {
128       return (elementType, lambda, ct) -> ".allMatch(" + lambda + ")";
129     }
130     else if (tokenType.equals(JavaTokenType.PLUS)) {
131       PsiType type = polyadicExpression.getType();
132       if (type instanceof PsiPrimitiveType) {
133         if (!StreamApiUtil.isSupportedStreamElement(type)) return null;
134         return (elementType, lambda, ct) -> "." + getMapOperationName(elementType, type) + "(" + lambda + ").sum()";
135       }
136       if (!TypeUtils.isJavaLangString(type)) return null;
137       PsiExpression[] operands = polyadicExpression.getOperands();
138       String mapToString;
139       PsiType operandType = operands[0].getType();
140       if (!InheritanceUtil.isInheritor(operandType, "java.lang.CharSequence")) {
141         if (!StreamApiUtil.isSupportedStreamElement(operandType)) return null;
142         mapToString = "."+getMapOperationName(operandType, type)+"(String::valueOf)";
143       } else {
144         mapToString = "";
145       }
146       PsiExpression delimiter = null;
147       PsiExpression rest = null;
148       if (operands.length > 4 && ExpressionUtils.isSafelyRecomputableExpression(operands[1]) &&
149           IntStreamEx.range(1, operands.length, 2).elements(operands)
150                      .pairMap(EquivalenceChecker.getCanonicalPsiEquivalence()::expressionsAreEquivalent)
151                      .allMatch(Boolean.TRUE::equals)) {
152         delimiter = operands[1];
153         if (!InheritanceUtil.isInheritor(delimiter.getType(), "java.lang.CharSequence") &&
154             !(delimiter instanceof PsiLiteralExpression && PsiType.CHAR.equals(delimiter.getType()))) {
155           return null;
156         }
157         if (operands.length % 2 == 0) {
158           rest = ArrayUtil.getLastElement(operands);
159         }
160       }
161       return new JoiningTerminalGenerator(operandType, mapToString, delimiter, rest);
162     }
163     return null;
164   }
165
166   @NotNull
167   private static String mapToString(PsiType elementType, PsiType resultType, String lambda) {
168     return "." + getMapOperationName(elementType, resultType) + "(" + lambda + ")";
169   }
170
171   private static class FoldExpressionIntoStreamFix implements LocalQuickFix {
172     private final boolean myStringJoin;
173
174     private FoldExpressionIntoStreamFix(boolean stringJoin) {myStringJoin = stringJoin;}
175
176     @Nls
177     @NotNull
178     @Override
179     public String getFamilyName() {
180       return InspectionsBundle.message("inspection.fold.expression.fix.family.name");
181     }
182
183     @Nls(capitalization = Nls.Capitalization.Sentence)
184     @NotNull
185     @Override
186     public String getName() {
187       return InspectionsBundle.message(myStringJoin ?
188                                        "inspection.fold.expression.into.string.fix.name" :
189                                        "inspection.fold.expression.into.stream.fix.name");
190     }
191
192     @Override
193     public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
194       PsiPolyadicExpression expression = tryCast(descriptor.getStartElement(), PsiPolyadicExpression.class);
195       if (expression == null) return;
196       TerminalGenerator generator = getGenerator(expression);
197       if (generator == null) return;
198       List<PsiExpression> diffs = extractDiff(generator, expression);
199       if (diffs.isEmpty()) return;
200
201       PsiExpression[] operands = expression.getOperands();
202       PsiExpression firstExpression = diffs.get(0);
203       assert PsiTreeUtil.isAncestor(operands[0], firstExpression, false);
204       Object marker = new Object();
205       PsiTreeUtil.mark(firstExpression, marker);
206       CommentTracker ct = new CommentTracker();
207       PsiExpression operandCopy = (PsiExpression)ct.markUnchanged(operands[0]).copy();
208       PsiElement expressionCopy = PsiTreeUtil.releaseMark(operandCopy, marker);
209       if (expressionCopy == null) return;
210       PsiType elementType = firstExpression.getType();
211       String name = new VariableNameGenerator(expression, VariableKind.PARAMETER).byType(elementType).byName("v").generate(true);
212       PsiElementFactory factory = JavaPsiFacade.getElementFactory(project);
213       PsiExpression expressionCopyReplaced = (PsiExpression)expressionCopy.replace(factory.createExpressionFromText(name, expressionCopy));
214       if (operandCopy == expressionCopy) {
215         operandCopy = expressionCopyReplaced;
216       }
217       String operandCopyText = operandCopy.getText();
218       String lambda = operandCopyText.equals(name) ? null : name + "->" + operandCopyText;
219       String streamClass = StreamApiUtil.getStreamClassForType(elementType);
220       if (streamClass == null) return;
221       String source = streamClass + "." + (elementType instanceof PsiClassType ? "<" + elementType.getCanonicalText() + ">" : "")
222                       + "of" + StreamEx.of(diffs).map(ct::text).joining(",", "(", ")");
223       String fullStream = source + generator.generateTerminal(elementType, lambda, ct);
224       PsiElement result = ct.replaceAndRestoreComments(expression, fullStream);
225       cleanup(result);
226     }
227
228     private static void cleanup(PsiElement result) {
229       JavaCodeStyleManager codeStyleManager = JavaCodeStyleManager.getInstance(result.getProject());
230       result = SimplifyStreamApiCallChainsInspection.simplifyStreamExpressions(result, false);
231       LambdaCanBeMethodReferenceInspection.replaceAllLambdasWithMethodReferences(result);
232       result = codeStyleManager.shortenClassReferences(result);
233       RemoveRedundantTypeArgumentsUtil.removeRedundantTypeArguments(result);
234     }
235   }
236
237   private static class JoiningTerminalGenerator implements TerminalGenerator {
238     private final PsiType myOperandType;
239     private final String myMapToString;
240     private final PsiExpression myDelimiter;
241     private final PsiExpression myRest;
242
243     JoiningTerminalGenerator(PsiType operandType, String mapToString, PsiExpression delimiter, PsiExpression rest) {
244       myOperandType = operandType;
245       myMapToString = mapToString;
246       myDelimiter = delimiter;
247       myRest = rest;
248     }
249
250     @Override
251     public PsiExpression[] getOperands(PsiPolyadicExpression polyadicExpression) {
252       PsiExpression[] ops = polyadicExpression.getOperands();
253       return myDelimiter == null ? ops :
254              IntStreamEx.range(0, ops.length, 2).elements(ops).toArray(PsiExpression.EMPTY_ARRAY);
255     }
256
257     @Override
258     public boolean isDittoSupported() {
259       return myDelimiter != null;
260     }
261
262     @Override
263     public boolean isStringJoin(PsiPolyadicExpression expression, List<? extends PsiExpression> diff) {
264       if (!myMapToString.isEmpty()) return false;
265       PsiExpression[] operands = getOperands(expression);
266       return operands[0] == diff.get(0);
267     }
268
269     @NotNull
270     @Override
271     public String generateTerminal(PsiType elementType, String lambda, CommentTracker ct) {
272       String map = (lambda == null ? "" : mapToString(elementType, myOperandType, lambda)) + myMapToString;
273       return map +
274              ".collect(" + CommonClassNames.JAVA_UTIL_STREAM_COLLECTORS +
275              ".joining(" + getDelimiterText(ct) + "))" +
276              (myRest == null ? "" : "+" + ct.text(myRest));
277     }
278
279     @NotNull
280     private String getDelimiterText(CommentTracker ct) {
281       if (myDelimiter == null) {
282         return "";
283       }
284       String text = ct.text(myDelimiter);
285       if (text.startsWith("'")) {
286         return PsiLiteralUtil.stringForCharLiteral(text);
287       }
288       return text;
289     }
290   }
291 }