1 // Copyright 2000-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
2 package com.jetbrains.python.refactoring;
4 import com.google.common.collect.Collections2;
5 import com.intellij.codeInsight.PsiEquivalenceUtil;
6 import com.intellij.ide.fileTemplates.FileTemplate;
7 import com.intellij.ide.fileTemplates.FileTemplateManager;
8 import com.intellij.openapi.project.Project;
9 import com.intellij.openapi.util.Comparing;
10 import com.intellij.openapi.util.Pair;
11 import com.intellij.openapi.util.TextRange;
12 import com.intellij.openapi.util.io.FileUtilRt;
13 import com.intellij.openapi.util.text.StringUtil;
14 import com.intellij.openapi.vfs.LocalFileSystem;
15 import com.intellij.openapi.vfs.VirtualFile;
16 import com.intellij.psi.*;
17 import com.intellij.psi.util.PsiTreeUtil;
18 import com.intellij.psi.util.PsiUtilCore;
19 import com.intellij.util.IncorrectOperationException;
20 import com.intellij.util.containers.ContainerUtil;
21 import com.jetbrains.NotNullPredicate;
22 import com.jetbrains.python.PyBundle;
23 import com.jetbrains.python.PyNames;
24 import com.jetbrains.python.psi.*;
25 import com.jetbrains.python.refactoring.classes.extractSuperclass.PyExtractSuperclassHelper;
26 import com.jetbrains.python.refactoring.classes.membersManager.PyMemberInfo;
27 import com.jetbrains.python.refactoring.introduce.IntroduceValidator;
28 import org.jetbrains.annotations.NotNull;
29 import org.jetbrains.annotations.Nullable;
32 import java.io.IOException;
34 import java.util.function.BiPredicate;
/**
 * Static helpers for Python refactorings: locating occurrences and selected expressions/statements,
 * choosing unique names, obtaining a target Python file and filtering member infos.
 * <p>
 * NOTE(review): this listing is a lossy excerpt of the Java source. The embedded original line numbers
 * skip (e.g. 37 -> 41), so closing braces, annotations and some statements are missing. Restore the
 * elided lines from upstream before treating this text as compilable code.
 */
36 public final class PyRefactoringUtil {
// Non-instantiable utility class: every member shown here is static.
37 private PyRefactoringUtil() {
/**
 * Collects the elements under {@code context} that are equivalent to {@code pattern},
 * as decided by {@link PsiEquivalenceUtil#areElementsEquivalent}.
 * <p>
 * If {@code pattern} carries a {@link PyReplaceExpressionUtil#SELECTION_BREAKS_AST_NODE} range, the part of
 * {@code pattern}'s text inside that range is also searched for in string literals made of a single string
 * node. Matching literals are reported and tagged with the range of the first match.
 * <p>
 * NOTE(review): several parts are elided in this excerpt (the embedded line numbers skip): the body of the
 * {@code PyParameter} branch, the guard on {@code start}, several closing braces and the final return of
 * {@code occurrences}. Verify against upstream.
 *
 * @param pattern element whose occurrences are searched for
 * @param context root whose descendants are visited; {@code null} yields an empty list
 * @return the collected occurrences
 */
41 public static List<PsiElement> getOccurrences(@NotNull final PsiElement pattern, @Nullable final PsiElement context) {
42 if (context == null) {
43 return Collections.emptyList();
45 final List<PsiElement> occurrences = new ArrayList<>();
46 final PyElementVisitor visitor = new PyElementVisitor() {
48 public void visitElement(@NotNull final PsiElement element) {
// Parameters get special handling (branch body elided; presumably they are skipped, so verify).
49 if (element instanceof PyParameter) {
52 if (PsiEquivalenceUtil.areElementsEquivalent(element, pattern)) {
53 occurrences.add(element);
// Partial selection inside a string: search other single-node literals for the selected text.
56 if (element instanceof PyStringLiteralExpression) {
57 final Pair<PsiElement, TextRange> selection = pattern.getUserData(PyReplaceExpressionUtil.SELECTION_BREAKS_AST_NODE);
58 if (selection != null) {
59 final String substring = selection.getSecond().substring(pattern.getText());
60 final PyStringLiteralExpression expr = (PyStringLiteralExpression)element;
61 final String text = element.getText();
62 if (text != null && expr.getStringNodes().size() == 1) {
63 final int start = text.indexOf(substring);
// Record where in this literal the selected text was found (first occurrence, via indexOf).
65 element.putUserData(PyReplaceExpressionUtil.SELECTION_BREAKS_AST_NODE, Pair.create(element, TextRange.from(start, substring.length())));
66 occurrences.add(element);
// Recurse into the children of the visited element.
72 element.acceptChildren(this);
75 context.acceptChildren(visitor);
/**
 * Returns the expression selected by a range spanning from {@code element1} to {@code element2}.
 * <p>
 * First the common parent of both elements is found and lifted to the nearest {@link PyElement}. If the
 * selection covers that parent expression exactly, the parent is returned.
 * <p>
 * If the parent is a {@link PyBinaryExpression} and the selection covers only part of it, the selected text
 * is re-parsed on its own. This succeeds when:
 * <ul>
 *   <li>the selected text parses as a binary expression without errors, and</li>
 *   <li>the rest of the parent still parses after the selection is replaced by a placeholder.</li>
 * </ul>
 * The re-parsed expression is then tagged with {@link PyReplaceExpressionUtil#SELECTION_BREAKS_AST_NODE}
 * = (parent, selected range within the parent's text).
 * <p>
 * Selections whose parent is not an expression, or that lie inside an import statement, are rejected.
 * <p>
 * NOTE(review): the rejecting statements and the final return statements are elided in this excerpt
 * (embedded line numbers skip). Presumably rejection means returning {@code null}; verify against upstream.
 *
 * @param project project used to obtain the {@link PyElementGenerator}
 * @param file file containing the selection; its text is used to extract the selected substring
 * @param element1 element at the start of the selection
 * @param element2 element at the end of the selection
 */
80 public static PyExpression getSelectedExpression(@NotNull final Project project,
81 @NotNull PsiFile file,
82 @NotNull final PsiElement element1,
83 @NotNull final PsiElement element2) {
84 PsiElement parent = PsiTreeUtil.findCommonParent(element1, element2);
// Lift a non-Python common parent to the nearest enclosing PyElement.
85 if (parent != null && !(parent instanceof PyElement)) {
86 parent = PsiTreeUtil.getParentOfType(parent, PyElement.class);
91 // Only expressions can be extracted: e.g. when the common parent is a PyIfPart of an if statement, reject the selection.
92 if (!(parent instanceof PyExpression)){
95 // We cannot extract anything within import statements
96 if (PsiTreeUtil.getParentOfType(parent, PyImportStatement.class, PyFromImportStatement.class) != null){
// The selection covers the whole parent expression exactly.
99 if ((element1 == PsiTreeUtil.getDeepestFirst(parent)) && (element2 == PsiTreeUtil.getDeepestLast(parent))) {
100 return (PyExpression) parent;
103 // Check if selection breaks AST node in binary expression
104 if (parent instanceof PyBinaryExpression) {
// Parse the raw selected text as the value of a dummy assignment to check it is a binary expression by itself.
105 final String selection = file.getText().substring(element1.getTextOffset(), element2.getTextOffset() + element2.getTextLength());
106 final PyElementGenerator generator = PyElementGenerator.getInstance(project);
107 final LanguageLevel langLevel = LanguageLevel.forElement(element1);
108 final PyExpression expression = generator.createFromText(langLevel, PyAssignmentStatement.class, "z=" + selection).getAssignedValue();
109 if (!(expression instanceof PyBinaryExpression) || PsiUtilCore.hasErrorElementChild(expression)) {
112 final String parentText = parent.getText();
// NOTE(review): because of the '- 1', the prefix stops, and textRange starts, one character before element1.
// A selection starting at the parent's first character also falls into the startOffset < 0 case. Confirm this is intended.
113 final int startOffset = element1.getTextOffset() - parent.getTextOffset() - 1;
114 if (startOffset < 0) {
117 final int endOffset = element2.getTextOffset() + element2.getTextLength() - parent.getTextOffset();
119 final String prefix = parentText.substring(0, startOffset);
120 final String suffix = parentText.substring(endOffset);
121 final TextRange textRange = TextRange.from(startOffset, endOffset - startOffset);
// Replace the selection with a placeholder identifier and require that the remaining text still parses.
122 final PsiElement fakeExpression = generator.createExpressionFromText(langLevel, prefix + "python" + suffix);
123 if (PsiUtilCore.hasErrorElementChild(fakeExpression)) {
127 expression.putUserData(PyReplaceExpressionUtil.SELECTION_BREAKS_AST_NODE, Pair.create(parent, textRange));
/**
 * Returns the expression selected by the range [{@code startOffset}, {@code endOffset}) of {@code file}.
 * A single whitespace element at either end of the range is skipped before delegating to
 * {@link #getSelectedExpression(Project, PsiFile, PsiElement, PsiElement)}.
 * <p>
 * NOTE(review): the statement executed when no element is found at one end is elided in this excerpt
 * (presumably {@code return null}). Verify against upstream.
 */
134 public static PsiElement findExpressionInRange(@NotNull final PsiFile file, int startOffset, int endOffset) {
135 PsiElement element1 = file.findElementAt(startOffset);
// endOffset is exclusive, hence the element lookup at endOffset - 1.
136 PsiElement element2 = file.findElementAt(endOffset - 1);
// Move the start past leading whitespace.
137 if (element1 instanceof PsiWhiteSpace) {
138 startOffset = element1.getTextRange().getEndOffset();
139 element1 = file.findElementAt(startOffset);
// Move the end before trailing whitespace.
141 if (element2 instanceof PsiWhiteSpace) {
142 endOffset = element2.getTextRange().getStartOffset();
143 element2 = file.findElementAt(endOffset - 1);
145 if (element1 == null || element2 == null) {
148 return getSelectedExpression(file.getProject(), file, element1, element2);
/**
 * Finds the sibling statements covered by the range [{@code startOffset}, {@code endOffset}) of {@code file}.
 * <p>
 * Whitespace and comments at either end of the range are skipped. Trailing comments skipped at the end are
 * appended to the result. An empty array is returned when any of the following holds:
 * <ul>
 *   <li>the range does not start at the start of a child of the common parent, or does not end at its end
 *       (tolerated at an end where whitespace/comments were skipped);</li>
 *   <li>the first or last selected child is a function or class;</li>
 *   <li>any collected element is not a statement, whitespace or comment.</li>
 * </ul>
 * <p>
 * NOTE(review): parts of the parent-climbing loop and of the child-collection loop are elided in this excerpt
 * (embedded line numbers skip). Verify against upstream.
 *
 * @return the selected elements, or {@link PsiElement#EMPTY_ARRAY} when the range is not a valid statement selection
 */
151 public static PsiElement @NotNull [] findStatementsInRange(@NotNull final PsiFile file, int startOffset, int endOffset) {
152 ArrayList<PsiElement> array = new ArrayList<>();
154 PsiElement element1 = file.findElementAt(startOffset);
155 PsiElement element2 = file.findElementAt(endOffset - 1);
156 PsiElement endComment = null;
// These flags relax the boundary-alignment checks below when whitespace/comments were skipped at that end.
158 boolean startsWithWhitespace = false;
159 boolean endsWithWhitespace = false;
// Skip leading whitespace.
160 if (element1 instanceof PsiWhiteSpace) {
161 startOffset = element1.getTextRange().getEndOffset();
162 element1 = file.findElementAt(startOffset);
163 startsWithWhitespace = true;
// Skip trailing whitespace backward.
165 if (element2 instanceof PsiWhiteSpace) {
166 element2 = PsiTreeUtil.skipWhitespacesBackward(element2);
167 endsWithWhitespace = true;
// Skip trailing comments, remembering the last one so it can be re-attached to the result.
169 while (element2 instanceof PsiComment) {
170 endComment = element2;
171 element2 = PsiTreeUtil.skipWhitespacesAndCommentsBackward(element2);
172 endsWithWhitespace = true;
// Skip leading comments.
175 while (element1 instanceof PsiComment) {
177 element1 = PsiTreeUtil.skipWhitespacesForward(element1);
178 startsWithWhitespace = true;
181 if (element1 == null || element2 == null) {
182 return PsiElement.EMPTY_ARRAY;
185 PsiElement parent = PsiTreeUtil.findCommonParent(element1, element2);
186 if (parent == null) {
187 return PsiElement.EMPTY_ARRAY;
// Climb from the common parent towards the enclosing statement container; reaching the file
// (or running out of parents) rejects the selection. Part of this loop is elided.
191 if (parent instanceof PyStatement) {
192 parent = parent.getParent();
195 if (parent instanceof PyStatementList) {
198 if (parent == null || parent instanceof PsiFile) {
199 return PsiElement.EMPTY_ARRAY;
201 parent = parent.getParent();
// Walk element1 up to a direct child of parent and check the range starts exactly at it.
204 if (!parent.equals(element1)) {
205 while (!parent.equals(element1.getParent())) {
206 element1 = element1.getParent();
209 if (startOffset != element1.getTextRange().getStartOffset() && !startsWithWhitespace) {
210 return PsiElement.EMPTY_ARRAY;
// Walk element2 up to a direct child of parent and check the range ends exactly at it.
213 if (!parent.equals(element2)) {
214 while (!parent.equals(element2.getParent())) {
215 element2 = element2.getParent();
218 if (endOffset != element2.getTextRange().getEndOffset() && !endsWithWhitespace) {
219 return PsiElement.EMPTY_ARRAY;
// Function and class definitions are not accepted at the boundaries of the selection.
222 if (element1 instanceof PyFunction || element1 instanceof PyClass) {
223 return PsiElement.EMPTY_ARRAY;
225 if (element2 instanceof PyFunction || element2 instanceof PyClass) {
226 return PsiElement.EMPTY_ARRAY;
// Collect parent's children from element1 through element2, ignoring whitespace (loop body partly elided).
229 PsiElement[] children = parent.getChildren();
231 boolean flag = false;
232 for (PsiElement child : children) {
233 if (child.equals(element1)) {
236 if (flag && !(child instanceof PsiWhiteSpace)) {
239 if (child.equals(element2)) {
// Re-attach the remembered trailing comment, then each following comment reached by skipping whitespace.
244 while (endComment instanceof PsiComment) {
245 array.add(endComment);
246 endComment = PsiTreeUtil.skipWhitespacesForward(endComment);
// Reject the selection if it contains anything other than statements, whitespace or comments.
249 for (PsiElement element : array) {
250 if (!(element instanceof PyStatement || element instanceof PsiWhiteSpace || element instanceof PsiComment)) {
251 return PsiElement.EMPTY_ARRAY;
254 return PsiUtilCore.toPsiElementArray(array);
/**
 * Returns {@code true} if both functions have equal names (compared with {@link Comparing#strEqual}) and the
 * same number of parameters. Parameter names, kinds and defaults are not compared.
 */
257 public static boolean areConflictingMethods(PyFunction pyFunction, PyFunction pyFunction1) {
258 final PyParameter[] firstParams = pyFunction.getParameterList().getParameters();
259 final PyParameter[] secondParams = pyFunction1.getParameterList().getParameters();
260 final String firstName = pyFunction.getName();
261 final String secondName = pyFunction1.getName();
// Only the name and the parameter count are compared.
263 return Comparing.strEqual(firstName, secondName) && firstParams.length == secondParams.length;
267 * Selects the shortest name generated using {@link NameSuggesterUtil#generateNamesByType(String)} that is unique inside the scope of scopeAnchor.
268 * If none of those names is suitable, a unique name is made by appending a number suffix.
270 * @param typeName initial type name for generator
271 * @param scopeAnchor PSI element used to determine correct scope
272 * @return unique name in the scope of scopeAnchor
275 public static String selectUniqueNameFromType(@NotNull String typeName, @NotNull PsiElement scopeAnchor) {
// Generate candidates from the type name (templateIsType = true) and validate them with isValidNewName.
276 return selectUniqueName(typeName, true, scopeAnchor, PyRefactoringUtil::isValidNewName);
280 * Selects the shortest name generated using {@link NameSuggesterUtil#generateNames(String)} that is unique inside the scope of scopeAnchor.
281 * If none of those names is suitable, a unique name is made by appending a number suffix.
283 * @param templateName initial template name for generator
284 * @param scopeAnchor PSI element used to determine correct scope
285 * @return unique name in the scope of scopeAnchor
288 public static String selectUniqueName(@NotNull String templateName, @NotNull PsiElement scopeAnchor) {
// Generate candidates from the template name (templateIsType = false) and validate them with isValidNewName.
289 return selectUniqueName(templateName, false, scopeAnchor, PyRefactoringUtil::isValidNewName);
/**
 * Like {@link #selectUniqueName(String, PsiElement)}, but candidate names are validated with {@code isValid}
 * instead of {@link #isValidNewName(String, PsiElement)}.
 *
 * @param templateName initial template name for the generator
 * @param scopeAnchor PSI element passed to {@code isValid} together with each candidate
 * @param isValid decides whether a candidate name is acceptable
 * @return a name accepted by {@code isValid}
 */
293 public static String selectUniqueName(@NotNull String templateName, @NotNull PsiElement scopeAnchor, @NotNull BiPredicate<String, PsiElement> isValid) {
294 return selectUniqueName(templateName, false, scopeAnchor, isValid);
/**
 * Picks the first generated suggestion accepted by {@code isValid}. If none is accepted, it falls back to
 * {@link #appendNumberUntilValid} applied to the first suggestion.
 * <p>
 * NOTE(review): the early return inside the loop is elided in this excerpt. Verify against upstream.
 *
 * @param templateName name fed to the generator
 * @param templateIsType {@code true} to generate with {@code NameSuggesterUtil.generateNamesByType},
 *                       {@code false} to use {@code NameSuggesterUtil.generateNames}
 * @param scopeAnchor PSI element passed to {@code isValid} together with each candidate
 * @param isValid decides whether a candidate name is acceptable
 */
298 private static String selectUniqueName(@NotNull String templateName, boolean templateIsType, @NotNull PsiElement scopeAnchor, @NotNull BiPredicate<String, PsiElement> isValid) {
299 final Collection<String> suggestions;
300 if (templateIsType) {
301 suggestions = NameSuggesterUtil.generateNamesByType(templateName);
304 suggestions = NameSuggesterUtil.generateNames(templateName);
306 for (String name : suggestions) {
307 if (isValid.test(name, scopeAnchor)) {
// Fallback: number the first (presumably shortest) suggestion. getFirstItem returns null for an empty
// collection; the suppression below assumes the generator always yields at least one suggestion (unverified here).
312 final String shortestName = ContainerUtil.getFirstItem(suggestions);
313 //noinspection ConstantConditions
314 return appendNumberUntilValid(shortestName, scopeAnchor, isValid);
318 * Appends increasing numbers starting from 1 to the name until it becomes unique within the scope of the scopeAnchor.
320 * @param name initial name
321 * @param scopeAnchor PSI element used to determine correct scope
322 * @param predicate used to test if suggested name is valid
323 * @return unique name in the scope, possibly with a number suffix appended
326 public static String appendNumberUntilValid(@NotNull String name, @NotNull PsiElement scopeAnchor, @NotNull BiPredicate<String, PsiElement> predicate) {
// NOTE(review): the declaration and increment of 'counter' and the final return of 'candidate' are elided in
// this excerpt (embedded line numbers skip 327 and 331-334). As shown, nothing changes 'counter' between iterations.
328 String candidate = name;
// Try the bare name first, then name + counter, until the predicate accepts the candidate.
329 while (!predicate.test(candidate, scopeAnchor)) {
330 candidate = name + counter;
/**
 * Returns {@code true} if {@code name} is neither already defined in the scope of {@code scopeAnchor}
 * (per {@link IntroduceValidator#isDefinedInScope}) nor reserved (per {@link PyNames#isReserved}).
 */
336 public static boolean isValidNewName(@NotNull String name, @NotNull PsiElement scopeAnchor) {
337 return !(IntroduceValidator.isDefinedInScope(name, scopeAnchor) || PyNames.isReserved(name));
/**
 * Returns the Python file at {@code path}. If needed, the file is created from the internal "Python Script"
 * file template via {@link PyExtractSuperclassHelper#placeFile}.
 * <p>
 * NOTE(review): the branch condition (presumably {@code vfile == null}), the enclosing {@code try}, part of the
 * {@code placeFile} arguments and the final return are elided in this excerpt. Verify against upstream.
 *
 * @param path file-system path of the target file
 * @param project project used for template lookup and PSI access
 * @return the Python file at {@code path}
 * @throws IncorrectOperationException if creating the file fails with an {@link IOException}, or if the
 *                                     resulting PSI is not a {@link PyFile}
 */
341 public static PyFile getOrCreateFile(String path, Project project) {
// Look up an existing virtual file for the path first.
342 final VirtualFile vfile = LocalFileSystem.getInstance().findFileByIoFile(new File(path));
345 final File file = new File(path);
// NOTE(review): Project#getBaseDir is deprecated in newer IntelliJ Platform versions; consider ProjectUtil.guessProjectDir.
347 final VirtualFile baseDir = project.getBaseDir();
// Render the internal "Python Script" template with NAME = file name without extension.
// content stays null when the template is missing.
348 final FileTemplateManager fileTemplateManager = FileTemplateManager.getInstance(project);
349 final FileTemplate template = fileTemplateManager.getInternalTemplate("Python Script");
350 final Properties properties = fileTemplateManager.getDefaultProperties();
351 properties.setProperty("NAME", FileUtilRt.getNameWithoutExtension(file.getName()));
352 final String content = (template != null) ? template.getText(properties) : null;
353 psi = PyExtractSuperclassHelper.placeFile(project,
354 StringUtil.notNullize(
356 baseDir != null ? baseDir
363 catch (IOException e) {
364 throw new IncorrectOperationException(String.format("Cannot create file '%s'", path), (Throwable)e);
// Presumably the branch for an already existing file: resolve its PSI.
368 psi = PsiManager.getInstance(project).findFile(vfile);
// Only Python files are acceptable targets.
370 if (!(psi instanceof PyFile)) {
371 throw new IncorrectOperationException(PyBundle.message(
372 "refactoring.move.module.members.error.cannot.place.elements.into.nonpython.file"));
378 * Filters out {@link PyMemberInfo}s that should not be displayed in the refactoring
379 * (such as the {@code object} class).
381 * @param pyMemberInfos collection to filter
382 * @return filtered collection (a live view of {@code pyMemberInfos}, see {@link Collections2#filter})
385 public static Collection<PyMemberInfo<PyElement>> filterOutObject(@NotNull final Collection<PyMemberInfo<PyElement>> pyMemberInfos) {
// ObjectPredicate(false) accepts member infos except those whose member is a PyClass named PyNames.OBJECT.
// Collections2.filter returns a live view backed by pyMemberInfos, not a copy.
386 return Collections2.filter(pyMemberInfos, new ObjectPredicate(false));
390 * Predicate matching {@link PyMemberInfo}s by whether their member is a {@link PyClass} named {@link PyNames#OBJECT}.
392 public static class ObjectPredicate extends NotNullPredicate<PyMemberInfo<PyElement>> {
// true: accept only "object" members; false: accept every other member.
393 private final boolean myAllowObjects;
396 * @param allowObjects allows only objects if true. Allows all but objects otherwise.
398 public ObjectPredicate(final boolean allowObjects) {
399 myAllowObjects = allowObjects;
// Accept the input exactly when its "object"-ness matches the configured mode.
403 public boolean applyNotNull(@NotNull final PyMemberInfo<PyElement> input) {
404 return myAllowObjects == isObject(input);
// "object" here means a PyClass whose name equals PyNames.OBJECT; the check is purely by name.
407 private static boolean isObject(@NotNull final PyMemberInfo<PyElement> classMemberInfo) {
408 final PyElement element = classMemberInfo.getMember();
409 return (element instanceof PyClass) && PyNames.OBJECT.equals(element.getName());
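415 * @deprecated use {@link PyPsiRefactoringUtil#addElementToStatementList(PsiElement, PyStatementList, boolean)} instead. TODO: remove in 2021.1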
419 public static PsiElement addElementToStatementList(@NotNull PsiElement element,
420 @NotNull PyStatementList statementList,
421 boolean toTheBeginning) {
// Backward-compatibility shim: forwards all arguments unchanged to PyPsiRefactoringUtil.
422 return PyPsiRefactoringUtil.addElementToStatementList(element, statementList, toTheBeginning);