3df1e2826d686e410aac85cfb6fc7f077689670b
[idea/community.git] / python / src / com / jetbrains / python / inspections / quickfix / PyMakeFunctionFromMethodQuickFix.java
1 /*
2  * Copyright 2000-2014 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.inspections.quickfix;
17
18 import com.intellij.codeInspection.LocalInspectionToolSession;
19 import com.intellij.codeInspection.LocalQuickFix;
20 import com.intellij.codeInspection.ProblemDescriptor;
21 import com.intellij.openapi.project.Project;
22 import com.intellij.psi.PsiElement;
23 import com.intellij.psi.PsiFile;
24 import com.intellij.psi.PsiNamedElement;
25 import com.intellij.psi.PsiReference;
26 import com.intellij.psi.util.PsiTreeUtil;
27 import com.intellij.usageView.UsageInfo;
28 import com.jetbrains.python.PyBundle;
29 import com.jetbrains.python.PyNames;
30 import com.jetbrains.python.codeInsight.imports.AddImportHelper;
31 import com.jetbrains.python.inspections.unresolvedReference.PyUnresolvedReferencesInspection;
32 import com.jetbrains.python.psi.*;
33 import com.jetbrains.python.refactoring.PyRefactoringUtil;
34 import org.jetbrains.annotations.NonNls;
35 import org.jetbrains.annotations.NotNull;
36
37 import java.util.Collections;
38 import java.util.List;
39
40 /**
41  * User: ktisha
42  */
43 public class PyMakeFunctionFromMethodQuickFix implements LocalQuickFix {
44   public PyMakeFunctionFromMethodQuickFix() {
45   }
46
47   @NotNull
48   public String getName() {
49     return PyBundle.message("QFIX.NAME.make.function");
50   }
51
52   @NonNls
53   @NotNull
54   public String getFamilyName() {
55     return getName();
56   }
57
58   public void applyFix(@NotNull final Project project, @NotNull final ProblemDescriptor descriptor) {
59     final PsiElement element = descriptor.getPsiElement();
60     final PyFunction problemFunction = PsiTreeUtil.getParentOfType(element, PyFunction.class);
61     if (problemFunction == null) return;
62     final PyClass containingClass = problemFunction.getContainingClass();
63     if (containingClass == null) return;
64
65     final List<UsageInfo> usages = PyRefactoringUtil.findUsages(problemFunction, false);
66     final PyParameter[] parameters = problemFunction.getParameterList().getParameters();
67     if (parameters.length > 0) {
68       parameters[0].delete();
69     }
70
71     PsiElement copy = problemFunction.copy();
72     problemFunction.delete();
73     final PsiElement parent = containingClass.getParent();
74     PyClass aClass = PsiTreeUtil.getTopmostParentOfType(containingClass, PyClass.class);
75     if (aClass == null)
76       aClass = containingClass;
77     copy = parent.addBefore(copy, aClass);
78
79     for (UsageInfo usage : usages) {
80       final PsiElement usageElement = usage.getElement();
81       if (usageElement instanceof PyReferenceExpression) {
82         final PsiFile usageFile = usageElement.getContainingFile();
83         updateUsage(copy, (PyReferenceExpression)usageElement, usageFile, !usageFile.equals(parent));
84       }
85     }
86   }
87
88   private static void updateUsage(@NotNull final PsiElement finalElement, @NotNull final PyReferenceExpression element,
89                                   @NotNull final PsiFile usageFile, boolean addImport) {
90     final PyExpression qualifier = element.getQualifier();
91     if (qualifier == null) return;
92     if (qualifier.getText().equals(PyNames.CANONICAL_SELF)) PyUtil.removeQualifier(element);
93     if (qualifier instanceof PyCallExpression) {              // remove qualifier A().m()
94       if (addImport)
95         AddImportHelper.addImport((PsiNamedElement)finalElement, usageFile, element);
96
97       PyUtil.removeQualifier(element);
98       removeFormerImport(usageFile, addImport);
99     }
100     else {
101       final PsiReference reference = qualifier.getReference();
102       if (reference == null) return;
103
104       final PsiElement resolved = reference.resolve();
105       if (resolved instanceof PyTargetExpression) {  // qualifier came from assignment  a = A(); a.m()
106         updateAssignment(element, resolved);
107       }
108       else if (resolved instanceof PyClass) {     //call with first instance argument A.m(A())
109         final PsiElement dot = qualifier.getNextSibling();
110         if (dot != null) dot.delete();
111         qualifier.delete();
112         updateArgumentList(element);
113       }
114     }
115   }
116
117   private static void removeFormerImport(@NotNull final PsiFile usageFile, boolean addImport) {
118     if (usageFile instanceof PyFile && addImport) {
119       final LocalInspectionToolSession session = new LocalInspectionToolSession(usageFile, 0, usageFile.getTextLength());
120       final PyUnresolvedReferencesInspection.Visitor visitor = new PyUnresolvedReferencesInspection.Visitor(null,
121                                                                                                             session,
122                                                                                                             Collections.<String>emptyList());
123       usageFile.accept(new PyRecursiveElementVisitor() {
124         @Override
125         public void visitPyElement(PyElement node) {
126           super.visitPyElement(node);
127           node.accept(visitor);
128         }
129       });
130
131       visitor.optimizeImports();
132     }
133   }
134
135   private static void updateAssignment(PyReferenceExpression element, @NotNull final PsiElement resolved) {
136     final PsiElement parent = resolved.getParent();
137     if (parent instanceof PyAssignmentStatement) {
138       final PyExpression value = ((PyAssignmentStatement)parent).getAssignedValue();
139       if (value instanceof PyCallExpression) {
140         final PyExpression callee = ((PyCallExpression)value).getCallee();
141         if (callee instanceof PyReferenceExpression) {
142           final PyExpression calleeQualifier = ((PyReferenceExpression)callee).getQualifier();
143           if (calleeQualifier != null) {
144             value.replace(calleeQualifier);
145           }
146           else {
147             PyUtil.removeQualifier(element);
148           }
149         }
150       }
151     }
152   }
153
154   private static void updateArgumentList(@NotNull final PyReferenceExpression element) {
155     final PyCallExpression callExpression = PsiTreeUtil.getParentOfType(element, PyCallExpression.class);
156     if (callExpression == null) return;
157     final PyArgumentList argumentList = callExpression.getArgumentList();
158     if (argumentList == null) return;
159     final PyExpression[] arguments = argumentList.getArguments();
160     if (arguments.length > 0) {
161       arguments[0].delete();
162     }
163   }
164 }