44804e0ddee0f42bbea562d8ac7c00becc4709e4
[idea/community.git] / python / src / com / jetbrains / python / refactoring / PyRefactoringUtil.java
1 // Copyright 2000-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
2 package com.jetbrains.python.refactoring;
3
4 import com.google.common.collect.Collections2;
5 import com.intellij.codeInsight.PsiEquivalenceUtil;
6 import com.intellij.ide.fileTemplates.FileTemplate;
7 import com.intellij.ide.fileTemplates.FileTemplateManager;
8 import com.intellij.openapi.project.Project;
9 import com.intellij.openapi.util.Comparing;
10 import com.intellij.openapi.util.Pair;
11 import com.intellij.openapi.util.TextRange;
12 import com.intellij.openapi.util.io.FileUtilRt;
13 import com.intellij.openapi.util.text.StringUtil;
14 import com.intellij.openapi.vfs.LocalFileSystem;
15 import com.intellij.openapi.vfs.VirtualFile;
16 import com.intellij.psi.*;
17 import com.intellij.psi.util.PsiTreeUtil;
18 import com.intellij.psi.util.PsiUtilCore;
19 import com.intellij.util.IncorrectOperationException;
20 import com.intellij.util.containers.ContainerUtil;
21 import com.jetbrains.NotNullPredicate;
22 import com.jetbrains.python.PyBundle;
23 import com.jetbrains.python.PyNames;
24 import com.jetbrains.python.psi.*;
25 import com.jetbrains.python.refactoring.classes.extractSuperclass.PyExtractSuperclassHelper;
26 import com.jetbrains.python.refactoring.classes.membersManager.PyMemberInfo;
27 import com.jetbrains.python.refactoring.introduce.IntroduceValidator;
28 import org.jetbrains.annotations.NotNull;
29 import org.jetbrains.annotations.Nullable;
30
31 import java.io.File;
32 import java.io.IOException;
33 import java.util.*;
34 import java.util.function.BiPredicate;
35
36 public final class PyRefactoringUtil {
37   private PyRefactoringUtil() {
38   }
39
40   @NotNull
41   public static List<PsiElement> getOccurrences(@NotNull final PsiElement pattern, @Nullable final PsiElement context) {
42     if (context == null) {
43       return Collections.emptyList();
44     }
45     final List<PsiElement> occurrences = new ArrayList<>();
46     final PyElementVisitor visitor = new PyElementVisitor() {
47       @Override
48       public void visitElement(@NotNull final PsiElement element) {
49         if (element instanceof PyParameter) {
50           return;
51         }
52         if (PsiEquivalenceUtil.areElementsEquivalent(element, pattern)) {
53           occurrences.add(element);
54           return;
55         }
56         if (element instanceof PyStringLiteralExpression) {
57           final Pair<PsiElement, TextRange> selection = pattern.getUserData(PyReplaceExpressionUtil.SELECTION_BREAKS_AST_NODE);
58           if (selection != null) {
59             final String substring = selection.getSecond().substring(pattern.getText());
60             final PyStringLiteralExpression expr = (PyStringLiteralExpression)element;
61             final String text = element.getText();
62             if (text != null && expr.getStringNodes().size() == 1) {
63               final int start = text.indexOf(substring);
64               if (start >= 0) {
65                 element.putUserData(PyReplaceExpressionUtil.SELECTION_BREAKS_AST_NODE, Pair.create(element, TextRange.from(start, substring.length())));
66                 occurrences.add(element);
67                 return;
68               }
69             }
70           }
71         }
72         element.acceptChildren(this);
73       }
74     };
75     context.acceptChildren(visitor);
76     return occurrences;
77   }
78
79   @Nullable
80   public static PyExpression getSelectedExpression(@NotNull final Project project,
81                                                    @NotNull PsiFile file,
82                                                    @NotNull final PsiElement element1,
83                                                    @NotNull final PsiElement element2) {
84     PsiElement parent = PsiTreeUtil.findCommonParent(element1, element2);
85     if (parent != null && !(parent instanceof PyElement)) {
86       parent = PsiTreeUtil.getParentOfType(parent, PyElement.class);
87     }
88     if (parent == null) {
89       return null;
90     }
91     // If it is PyIfPart for example, parent if statement, we should deny
92     if (!(parent instanceof PyExpression)){
93       return null;
94     }
95     // We cannot extract anything within import statements
96     if (PsiTreeUtil.getParentOfType(parent, PyImportStatement.class, PyFromImportStatement.class) != null){
97       return null;
98     }
99     if ((element1 == PsiTreeUtil.getDeepestFirst(parent)) && (element2 == PsiTreeUtil.getDeepestLast(parent))) {
100       return (PyExpression) parent;
101     }
102
103     // Check if selection breaks AST node in binary expression
104     if (parent instanceof PyBinaryExpression) {
105       final String selection = file.getText().substring(element1.getTextOffset(), element2.getTextOffset() + element2.getTextLength());
106       final PyElementGenerator generator = PyElementGenerator.getInstance(project);
107       final LanguageLevel langLevel = LanguageLevel.forElement(element1);
108       final PyExpression expression = generator.createFromText(langLevel, PyAssignmentStatement.class, "z=" + selection).getAssignedValue();
109       if (!(expression instanceof PyBinaryExpression) || PsiUtilCore.hasErrorElementChild(expression)) {
110         return null;
111       }
112       final String parentText = parent.getText();
113       final int startOffset = element1.getTextOffset() - parent.getTextOffset() - 1;
114       if (startOffset < 0) {
115         return null;
116       }
117       final int endOffset = element2.getTextOffset() + element2.getTextLength() - parent.getTextOffset();
118
119       final String prefix = parentText.substring(0, startOffset);
120       final String suffix = parentText.substring(endOffset);
121       final TextRange textRange = TextRange.from(startOffset, endOffset - startOffset);
122       final PsiElement fakeExpression = generator.createExpressionFromText(langLevel, prefix + "python" + suffix);
123       if (PsiUtilCore.hasErrorElementChild(fakeExpression)) {
124         return null;
125       }
126
127       expression.putUserData(PyReplaceExpressionUtil.SELECTION_BREAKS_AST_NODE, Pair.create(parent, textRange));
128       return expression;
129     }
130     return null;
131   }
132
133   @Nullable
134   public static PsiElement findExpressionInRange(@NotNull final PsiFile file, int startOffset, int endOffset) {
135     PsiElement element1 = file.findElementAt(startOffset);
136     PsiElement element2 = file.findElementAt(endOffset - 1);
137     if (element1 instanceof PsiWhiteSpace) {
138       startOffset = element1.getTextRange().getEndOffset();
139       element1 = file.findElementAt(startOffset);
140     }
141     if (element2 instanceof PsiWhiteSpace) {
142       endOffset = element2.getTextRange().getStartOffset();
143       element2 = file.findElementAt(endOffset - 1);
144     }
145     if (element1 == null || element2 == null) {
146       return null;
147     }
148     return getSelectedExpression(file.getProject(), file, element1, element2);
149   }
150
151   public static PsiElement @NotNull [] findStatementsInRange(@NotNull final PsiFile file, int startOffset, int endOffset) {
152     ArrayList<PsiElement> array = new ArrayList<>();
153
154     PsiElement element1 = file.findElementAt(startOffset);
155     PsiElement element2 = file.findElementAt(endOffset - 1);
156     PsiElement endComment = null;
157
158     boolean startsWithWhitespace = false;
159     boolean endsWithWhitespace = false;
160     if (element1 instanceof PsiWhiteSpace) {
161       startOffset = element1.getTextRange().getEndOffset();
162       element1 = file.findElementAt(startOffset);
163       startsWithWhitespace = true;
164     }
165     if (element2 instanceof PsiWhiteSpace) {
166       element2 = PsiTreeUtil.skipWhitespacesBackward(element2);
167       endsWithWhitespace = true;
168     }
169     while (element2 instanceof PsiComment) {
170       endComment = element2;
171       element2 = PsiTreeUtil.skipWhitespacesAndCommentsBackward(element2);
172       endsWithWhitespace = true;
173     }
174
175     while (element1 instanceof PsiComment) {
176       array.add(element1);
177       element1 = PsiTreeUtil.skipWhitespacesForward(element1);
178       startsWithWhitespace = true;
179     }
180
181     if (element1 == null || element2 == null) {
182       return PsiElement.EMPTY_ARRAY;
183     }
184
185     PsiElement parent = PsiTreeUtil.findCommonParent(element1, element2);
186     if (parent == null) {
187       return PsiElement.EMPTY_ARRAY;
188     }
189
190     while (true) {
191       if (parent instanceof PyStatement) {
192         parent = parent.getParent();
193         break;
194       }
195       if (parent instanceof PyStatementList) {
196         break;
197       }
198       if (parent == null || parent instanceof PsiFile) {
199         return PsiElement.EMPTY_ARRAY;
200       }
201       parent = parent.getParent();
202     }
203
204     if (!parent.equals(element1)) {
205       while (!parent.equals(element1.getParent())) {
206         element1 = element1.getParent();
207       }
208     }
209     if (startOffset != element1.getTextRange().getStartOffset() && !startsWithWhitespace) {
210       return PsiElement.EMPTY_ARRAY;
211     }
212
213     if (!parent.equals(element2)) {
214       while (!parent.equals(element2.getParent())) {
215         element2 = element2.getParent();
216       }
217     }
218     if (endOffset != element2.getTextRange().getEndOffset() && !endsWithWhitespace) {
219       return PsiElement.EMPTY_ARRAY;
220     }
221
222     if (element1 instanceof PyFunction || element1 instanceof PyClass) {
223       return PsiElement.EMPTY_ARRAY;
224     }
225     if (element2 instanceof PyFunction || element2 instanceof PyClass) {
226       return PsiElement.EMPTY_ARRAY;
227     }
228
229     PsiElement[] children = parent.getChildren();
230
231     boolean flag = false;
232     for (PsiElement child : children) {
233       if (child.equals(element1)) {
234         flag = true;
235       }
236       if (flag && !(child instanceof PsiWhiteSpace)) {
237         array.add(child);
238       }
239       if (child.equals(element2)) {
240         break;
241       }
242     }
243
244     while (endComment instanceof PsiComment) {
245       array.add(endComment);
246       endComment = PsiTreeUtil.skipWhitespacesForward(endComment);
247     }
248
249     for (PsiElement element : array) {
250       if (!(element instanceof PyStatement || element instanceof PsiWhiteSpace || element instanceof PsiComment)) {
251         return PsiElement.EMPTY_ARRAY;
252       }
253     }
254     return PsiUtilCore.toPsiElementArray(array);
255   }
256
257   public static boolean areConflictingMethods(PyFunction pyFunction, PyFunction pyFunction1) {
258     final PyParameter[] firstParams = pyFunction.getParameterList().getParameters();
259     final PyParameter[] secondParams = pyFunction1.getParameterList().getParameters();
260     final String firstName = pyFunction.getName();
261     final String secondName = pyFunction1.getName();
262
263     return Comparing.strEqual(firstName, secondName) && firstParams.length == secondParams.length;
264   }
265
266   /**
267    * Selects the shortest unique name inside the scope of scopeAnchor generated using {@link NameSuggesterUtil#generateNamesByType(String)}.
268    * If none of those names is suitable, unique names is made by appending number suffix.
269    *
270    * @param typeName    initial type name for generator
271    * @param scopeAnchor PSI element used to determine correct scope
272    * @return unique name in the scope of scopeAnchor
273    */
274   @NotNull
275   public static String selectUniqueNameFromType(@NotNull String typeName, @NotNull PsiElement scopeAnchor) {
276     return selectUniqueName(typeName, true, scopeAnchor, PyRefactoringUtil::isValidNewName);
277   }
278
279   /**
280    * Selects the shortest unique name inside the scope of scopeAnchor generated using {@link NameSuggesterUtil#generateNames(String)}.
281    * If none of those names is suitable, unique names is made by appending number suffix.
282    *
283    * @param templateName initial template name for generator
284    * @param scopeAnchor  PSI element used to determine correct scope
285    * @return unique name in the scope of scopeAnchor
286    */
287   @NotNull
288   public static String selectUniqueName(@NotNull String templateName, @NotNull PsiElement scopeAnchor) {
289     return selectUniqueName(templateName, false, scopeAnchor, PyRefactoringUtil::isValidNewName);
290   }
291
292   @NotNull
293   public static String selectUniqueName(@NotNull String templateName, @NotNull PsiElement scopeAnchor, @NotNull BiPredicate<String, PsiElement> isValid) {
294     return selectUniqueName(templateName, false, scopeAnchor, isValid);
295   }
296
297   @NotNull
298   private static String selectUniqueName(@NotNull String templateName, boolean templateIsType, @NotNull PsiElement scopeAnchor, @NotNull BiPredicate<String, PsiElement> isValid) {
299     final Collection<String> suggestions;
300     if (templateIsType) {
301       suggestions = NameSuggesterUtil.generateNamesByType(templateName);
302     }
303     else {
304       suggestions = NameSuggesterUtil.generateNames(templateName);
305     }
306     for (String name : suggestions) {
307       if (isValid.test(name, scopeAnchor)) {
308         return name;
309       }
310     }
311
312     final String shortestName = ContainerUtil.getFirstItem(suggestions);
313     //noinspection ConstantConditions
314     return appendNumberUntilValid(shortestName, scopeAnchor, isValid);
315   }
316
317   /**
318    * Appends increasing numbers starting from 1 to the name until it becomes unique within the scope of the scopeAnchor.
319    *
320    * @param name        initial name
321    * @param scopeAnchor PSI element used to determine correct scope
322    * @param predicate used to test if suggested name is valid
323    * @return unique name in the scope probably with number suffix appended
324    */
325   @NotNull
326   public static String appendNumberUntilValid(@NotNull String name, @NotNull PsiElement scopeAnchor, @NotNull BiPredicate<String, PsiElement> predicate) {
327     int counter = 1;
328     String candidate = name;
329     while (!predicate.test(candidate, scopeAnchor)) {
330       candidate = name + counter;
331       counter++;
332     }
333     return candidate;
334   }
335
336   public static boolean isValidNewName(@NotNull String name, @NotNull PsiElement scopeAnchor) {
337     return !(IntroduceValidator.isDefinedInScope(name, scopeAnchor) || PyNames.isReserved(name));
338   }
339
340   @NotNull
341   public static PyFile getOrCreateFile(String path, Project project) {
342     final VirtualFile vfile = LocalFileSystem.getInstance().findFileByIoFile(new File(path));
343     final PsiFile psi;
344     if (vfile == null) {
345       final File file = new File(path);
346       try {
347         final VirtualFile baseDir = project.getBaseDir();
348         final FileTemplateManager fileTemplateManager = FileTemplateManager.getInstance(project);
349         final FileTemplate template = fileTemplateManager.getInternalTemplate("Python Script");
350         final Properties properties = fileTemplateManager.getDefaultProperties();
351         properties.setProperty("NAME", FileUtilRt.getNameWithoutExtension(file.getName()));
352         final String content = (template != null) ? template.getText(properties) : null;
353         psi = PyExtractSuperclassHelper.placeFile(project,
354                                                   StringUtil.notNullize(
355                                                     file.getParent(),
356                                                     baseDir != null ? baseDir
357                                                       .getPath() : "."
358                                                   ),
359                                                   file.getName(),
360                                                   content
361         );
362       }
363       catch (IOException e) {
364         throw new IncorrectOperationException(String.format("Cannot create file '%s'", path), (Throwable)e);
365       }
366     }
367     else {
368       psi = PsiManager.getInstance(project).findFile(vfile);
369     }
370     if (!(psi instanceof PyFile)) {
371       throw new IncorrectOperationException(PyBundle.message(
372         "refactoring.move.module.members.error.cannot.place.elements.into.nonpython.file"));
373     }
374     return (PyFile)psi;
375   }
376
377   /**
378    * Filters out {@link PyMemberInfo}
379    * that should not be displayed in this refactoring (like object)
380    *
381    * @param pyMemberInfos collection to sort
382    * @return sorted collection
383    */
384   @NotNull
385   public static Collection<PyMemberInfo<PyElement>> filterOutObject(@NotNull final Collection<PyMemberInfo<PyElement>> pyMemberInfos) {
386     return Collections2.filter(pyMemberInfos, new ObjectPredicate(false));
387   }
388
389   /**
390    * Filters only PyClass object (new class)
391    */
392   public static class ObjectPredicate extends NotNullPredicate<PyMemberInfo<PyElement>> {
393     private final boolean myAllowObjects;
394
395     /**
396      * @param allowObjects allows only objects if true. Allows all but objects otherwise.
397      */
398     public ObjectPredicate(final boolean allowObjects) {
399       myAllowObjects = allowObjects;
400     }
401
402     @Override
403     public boolean applyNotNull(@NotNull final PyMemberInfo<PyElement> input) {
404       return myAllowObjects == isObject(input);
405     }
406
407     private static boolean isObject(@NotNull final PyMemberInfo<PyElement> classMemberInfo) {
408       final PyElement element = classMemberInfo.getMember();
409       return (element instanceof PyClass) && PyNames.OBJECT.equals(element.getName());
410     }
411   }
412
413
414   /**
415    * @deprecated TODO: remove in 2021.1
416    */
417   @NotNull
418   @Deprecated
419   public static PsiElement addElementToStatementList(@NotNull PsiElement element,
420                                                      @NotNull PyStatementList statementList,
421                                                      boolean toTheBeginning) {
422     return PyPsiRefactoringUtil.addElementToStatementList(element, statementList, toTheBeginning);
423   }
424 }