83617ed4479b57bc83fc9680cadbf6c3c9a37c55
[idea/community.git] / python / src / com / jetbrains / python / inspections / quickfix / AddCallSuperQuickFix.java
1 /*
2  * Copyright 2000-2014 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.inspections.quickfix;
17
18 import com.intellij.codeInspection.LocalQuickFix;
19 import com.intellij.codeInspection.ProblemDescriptor;
20 import com.intellij.openapi.project.Project;
21 import com.intellij.openapi.util.Couple;
22 import com.intellij.openapi.util.text.StringUtil;
23 import com.intellij.psi.PsiElement;
24 import com.intellij.psi.util.PsiTreeUtil;
25 import com.intellij.util.containers.ContainerUtil;
26 import com.jetbrains.python.PyBundle;
27 import com.jetbrains.python.PyNames;
28 import com.jetbrains.python.PyTokenTypes;
29 import com.jetbrains.python.psi.*;
30 import com.jetbrains.python.psi.impl.PyPsiUtils;
31 import org.jetbrains.annotations.NonNls;
32 import org.jetbrains.annotations.NotNull;
33 import org.jetbrains.annotations.Nullable;
34
35 import java.util.*;
36
37 /**
38  * For:
39  * class B(A):
40  * def __init__(self):
41  * A.__init__(self)           #  inserted
42  * print "Constructor B was called"
43  * <p/>
44  * User: catherine
45  */
46 public class AddCallSuperQuickFix implements LocalQuickFix {
47
48   @NotNull
49   public String getName() {
50     return PyBundle.message("QFIX.add.super");
51   }
52
53   @NonNls
54   @NotNull
55   public String getFamilyName() {
56     return getName();
57   }
58
59   public void applyFix(@NotNull final Project project, @NotNull final ProblemDescriptor descriptor) {
60     final PyFunction problemFunction = PsiTreeUtil.getParentOfType(descriptor.getPsiElement(), PyFunction.class);
61     if (problemFunction == null) return;
62     final StringBuilder superCall = new StringBuilder();
63     final PyClass klass = problemFunction.getContainingClass();
64     if (klass == null) return;
65     final PyClass[] superClasses = klass.getSuperClasses();
66     if (superClasses.length == 0) return;
67
68     final PyClass superClass = superClasses[0];
69     final PyFunction superInit = superClass.findMethodByName(PyNames.INIT, true);
70     if (superInit == null) return;
71
72     final ParametersInfo origInfo = new ParametersInfo(problemFunction.getParameterList());
73     final ParametersInfo superInfo = new ParametersInfo(superInit.getParameterList());
74     final boolean addSelfToCall;
75
76     if (klass.isNewStyleClass(null)) {
77       addSelfToCall = false;
78       if (LanguageLevel.forElement(klass).isPy3K()) {
79         superCall.append("super().__init__(");
80       }
81       else {
82         superCall.append("super(").append(klass.getName()).append(", ").append(getSelfParameterName(origInfo)).append(").__init__(");
83       }
84     }
85     else {
86       addSelfToCall = true;
87       superCall.append(superClass.getName()).append(".__init__(");
88     }
89     final StringBuilder newFunction = new StringBuilder("def __init__(");
90
91     final Couple<List<String>> couple = buildNewFunctionParamsAndSuperInitCallArgs(origInfo, superInfo, addSelfToCall);
92     StringUtil.join(couple.getFirst(), ", ", newFunction);
93     newFunction.append(")");
94
95     if (problemFunction.getAnnotation() != null) {
96       newFunction.append(problemFunction.getAnnotation().getText());
97     }
98     newFunction.append(": pass");
99
100     StringUtil.join(couple.getSecond(), ", ", superCall);
101     superCall.append(")");
102
103     final PyElementGenerator generator = PyElementGenerator.getInstance(project);
104     final LanguageLevel languageLevel = LanguageLevel.forElement(problemFunction);
105     final PyStatement callSuperStatement = generator.createFromText(languageLevel, PyStatement.class, superCall.toString());
106     final PyParameterList newParameterList = generator.createFromText(languageLevel,
107                                                                       PyParameterList.class,
108                                                                       newFunction.toString(),
109                                                                       new int[]{0, 3});
110     problemFunction.getParameterList().replace(newParameterList);
111     final PyStatementList statementList = problemFunction.getStatementList();
112     PyUtil.addElementToStatementList(callSuperStatement, statementList, true);
113     PyPsiUtils.removeRedundantPass(statementList);
114   }
115
116   @NotNull
117   private static String getSelfParameterName(@NotNull ParametersInfo info) {
118     final PyParameter selfParameter = info.getSelfParameter();
119     if (selfParameter == null) {
120       return PyNames.CANONICAL_SELF;
121     }
122     return StringUtil.defaultIfEmpty(selfParameter.getName(), PyNames.CANONICAL_SELF);
123   }
124
125   @NotNull
126   private static Couple<List<String>> buildNewFunctionParamsAndSuperInitCallArgs(@NotNull ParametersInfo origInfo,
127                                                                                  @NotNull ParametersInfo superInfo,
128                                                                                  boolean addSelfToCall) {
129     final List<String> newFunctionParams = new ArrayList<String>();
130     final List<String> superCallArgs = new ArrayList<String>();
131
132     final PyParameter selfParameter = origInfo.getSelfParameter();
133     if (selfParameter != null && StringUtil.isNotEmpty(selfParameter.getName())) {
134       newFunctionParams.add(selfParameter.getText());
135     }
136     else {
137       newFunctionParams.add(PyNames.CANONICAL_SELF);
138     }
139
140     if (addSelfToCall) {
141       superCallArgs.add(getSelfParameterName(origInfo));
142     }
143
144     // Required parameters (not-keyword)
145     for (PyParameter param : origInfo.getRequiredParameters()) {
146       newFunctionParams.add(param.getText());
147     }
148     for (PyParameter param : superInfo.getRequiredParameters()) {
149       // Special case as if base class has constructor __init__((a, b), c) and
150       // subclass has constructor __init__(a, (b, c))
151       final PyTupleParameter tupleParam = param.getAsTuple();
152       if (tupleParam != null) {
153         final List<String> uniqueNames = collectParameterNames(tupleParam);
154         final boolean hasDuplicates = uniqueNames.removeAll(origInfo.getAllParameterNames());
155         if (hasDuplicates) {
156           newFunctionParams.addAll(uniqueNames);
157         }
158         else {
159           newFunctionParams.add(param.getText());
160         }
161         // Retain original structure of tuple parameter.
162         // Note that tuple parameters cannot have annotations or nested default values, so it's syntactically safe
163         superCallArgs.add(param.getText());
164       }
165       else {
166         if (!origInfo.getAllParameterNames().contains(param.getName())) {
167           newFunctionParams.add(param.getText());
168         }
169         superCallArgs.add(param.getName());
170       }
171     }
172
173     // Optional parameters (not-keyword)
174     for (PyParameter param : origInfo.getOptionalParameters()) {
175       newFunctionParams.add(param.getText());
176     }
177
178     // Pass parameters with default values to super class constructor, only if both functions contain them  
179     for (PyParameter param : superInfo.getOptionalParameters()) {
180       final PyTupleParameter tupleParam = param.getAsTuple();
181       if (tupleParam != null) {
182         if (origInfo.getAllParameterNames().containsAll(collectParameterNames(tupleParam))) {
183           final String paramText = tupleParam.getText();
184           final PsiElement equalSign = PyPsiUtils.getFirstChildOfType(param, PyTokenTypes.EQ);
185           if (equalSign != null) {
186             superCallArgs.add(paramText.substring(0, equalSign.getStartOffsetInParent()).trim());
187           }
188         }
189       }
190       else {
191         if (origInfo.getAllParameterNames().contains(param.getName())) {
192           superCallArgs.add(param.getName());
193         }
194       }
195     }
196
197     // Positional vararg
198     PyParameter starredParam = null;
199     if (origInfo.getPositionalContainerParameter() != null) {
200       starredParam = origInfo.getPositionalContainerParameter();
201     }
202     else if (superInfo.getPositionalContainerParameter() != null) {
203       starredParam = superInfo.getPositionalContainerParameter();
204     }
205     else if (origInfo.getSingleStarParameter() != null) {
206       starredParam = origInfo.getSingleStarParameter();
207     }
208     else if (superInfo.getSingleStarParameter() != null) {
209       starredParam = superInfo.getSingleStarParameter();
210     }
211     if (starredParam != null) {
212       newFunctionParams.add(starredParam.getText());
213       if (superInfo.getPositionalContainerParameter() != null) {
214         superCallArgs.add("*" + starredParam.getName());
215       }
216     }
217
218     // Required keyword-only parameters
219     boolean newSignatureContainsKeywordParams = false;
220     for (PyParameter param : origInfo.getRequiredKeywordOnlyParameters()) {
221       newFunctionParams.add(param.getText());
222       newSignatureContainsKeywordParams = true;
223     }
224     for (PyParameter param : superInfo.getRequiredKeywordOnlyParameters()) {
225       if (!origInfo.getAllParameterNames().contains(param.getName())) {
226         newFunctionParams.add(param.getText());
227         newSignatureContainsKeywordParams = true;
228       }
229       superCallArgs.add(param.getName() + "=" + param.getName());
230     }
231
232     // Optional keyword-only parameters
233     for (PyParameter param : origInfo.getOptionalKeywordOnlyParameters()) {
234       newFunctionParams.add(param.getText());
235       newSignatureContainsKeywordParams = true;
236     }
237     
238     // If '*' param is followed by nothing in result signature, remove it altogether 
239     if (starredParam instanceof PySingleStarParameter && !newSignatureContainsKeywordParams) {
240       newFunctionParams.remove(newFunctionParams.size() - 1);
241     }
242
243     for (PyParameter param : superInfo.getOptionalKeywordOnlyParameters()) {
244       if (origInfo.getAllParameterNames().contains(param.getName())) {
245         superCallArgs.add(param.getName() + "=" + param.getName());
246       }
247     }
248
249     // Keyword vararg
250     PyParameter doubleStarredParam = null;
251     if (origInfo.getKeywordContainerParameter() != null) {
252       doubleStarredParam = origInfo.getKeywordContainerParameter();
253     }
254     else if (superInfo.getKeywordContainerParameter() != null) {
255       doubleStarredParam = superInfo.getKeywordContainerParameter();
256     }
257     if (doubleStarredParam != null) {
258       newFunctionParams.add(doubleStarredParam.getText());
259       if (superInfo.getKeywordContainerParameter() != null) {
260         superCallArgs.add("**" + doubleStarredParam.getName());
261       }
262     }
263     return Couple.of(newFunctionParams, superCallArgs);
264   }
265
266   private static class ParametersInfo {
267
268     private final PyParameter mySelfParam;
269     /**
270      * Parameters without default value that come before first "*..." parameter.
271      */
272     private final List<PyParameter> myRequiredParams = new ArrayList<PyParameter>();
273     /**
274      * Parameters with default value that come before first "*..." parameter.
275      */
276     private final List<PyParameter> myOptionalParams = new ArrayList<PyParameter>();
277     /**
278      * Parameter of form "*args" (positional vararg), not the same as single "*".
279      */
280     private final PyParameter myPositionalContainerParam;
281     /**
282      * Parameter "*", that is used to delimit normal and keyword-only parameters.
283      */
284     private final PyParameter mySingleStarParam;
285     /**
286      * Parameters without default value that come after first "*..." parameter.
287      */
288     private final List<PyParameter> myRequiredKwOnlyParams = new ArrayList<PyParameter>();
289     /**
290      * Parameters with default value that come after first "*..." parameter.
291      */
292     private final List<PyParameter> myOptionalKwOnlyParams = new ArrayList<PyParameter>();
293     /**
294      * Parameter of form "**kwargs" (keyword vararg).
295      */
296     private final PyParameter myKeywordContainerParam;
297
298     private final Set<String> myAllParameterNames = new LinkedHashSet<String>();
299
300     public ParametersInfo(@NotNull PyParameterList parameterList) {
301       PyParameter positionalContainer = null;
302       PyParameter singleStarParam = null;
303       PyParameter keywordContainer = null;
304       PyParameter selfParam = null;
305
306       for (PyParameter param : parameterList.getParameters()) {
307         myAllParameterNames.addAll(collectParameterNames(param));
308
309         if (param.isSelf()) {
310           selfParam = param;
311         }
312         else if (param instanceof PySingleStarParameter) {
313           singleStarParam = param;
314         }
315         else if (param.getAsNamed() != null && param.getAsNamed().isKeywordContainer()) {
316           keywordContainer = param;
317         }
318         else if (param.getAsNamed() != null && param.getAsNamed().isPositionalContainer()) {
319           positionalContainer = param;
320         }
321         else if (param.getAsNamed() == null || !param.getAsNamed().isKeywordOnly()) {
322           if (param.hasDefaultValue()) {
323             myOptionalParams.add(param);
324           }
325           else {
326             myRequiredParams.add(param);
327           }
328         }
329         else {
330           if (param.hasDefaultValue()) {
331             myOptionalKwOnlyParams.add(param);
332           }
333           else {
334             myRequiredKwOnlyParams.add(param);
335           }
336         }
337       }
338
339       mySelfParam = selfParam;
340       myPositionalContainerParam = positionalContainer;
341       mySingleStarParam = singleStarParam;
342       myKeywordContainerParam = keywordContainer;
343     }
344
345     @Nullable
346     public PyParameter getSelfParameter() {
347       return mySelfParam;
348     }
349
350     @NotNull
351     public List<PyParameter> getRequiredParameters() {
352       return Collections.unmodifiableList(myRequiredParams);
353     }
354
355     @NotNull
356     public List<PyParameter> getOptionalParameters() {
357       return Collections.unmodifiableList(myOptionalParams);
358     }
359
360     @Nullable
361     public PyParameter getPositionalContainerParameter() {
362       return myPositionalContainerParam;
363     }
364
365     @Nullable
366     public PyParameter getSingleStarParameter() {
367       return mySingleStarParam;
368     }
369
370     @NotNull
371     public List<PyParameter> getRequiredKeywordOnlyParameters() {
372       return Collections.unmodifiableList(myRequiredKwOnlyParams);
373     }
374
375     @NotNull
376     public List<PyParameter> getOptionalKeywordOnlyParameters() {
377       return Collections.unmodifiableList(myOptionalKwOnlyParams);
378     }
379
380     @Nullable
381     public PyParameter getKeywordContainerParameter() {
382       return myKeywordContainerParam;
383     }
384
385     @NotNull
386     public Set<String> getAllParameterNames() {
387       return Collections.unmodifiableSet(myAllParameterNames);
388     }
389   }
390
391   @NotNull
392   private static List<String> collectParameterNames(@NotNull PyParameter param) {
393     final List<String> result = new ArrayList<String>();
394     collectParameterNames(param, result);
395     return result;
396   }
397
398
399   private static void collectParameterNames(@NotNull PyParameter param, @NotNull Collection<String> acc) {
400     final PyTupleParameter tupleParam = param.getAsTuple();
401     if (tupleParam != null) {
402       for (PyParameter subParam : tupleParam.getContents()) {
403         collectParameterNames(subParam, acc);
404       }
405     }
406     else {
407       ContainerUtil.addIfNotNull(acc, param.getName());
408     }
409   }
410 }