Add PyElementGenerator#createParameterList and PyElementGenerator#createParameterList
[idea/community.git] / python / src / com / jetbrains / python / inspections / quickfix / AddCallSuperQuickFix.java
1 /*
2  * Copyright 2000-2014 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.inspections.quickfix;
17
18 import com.intellij.codeInspection.LocalQuickFix;
19 import com.intellij.codeInspection.ProblemDescriptor;
20 import com.intellij.openapi.project.Project;
21 import com.intellij.openapi.util.Couple;
22 import com.intellij.openapi.util.text.StringUtil;
23 import com.intellij.psi.PsiElement;
24 import com.intellij.psi.util.PsiTreeUtil;
25 import com.intellij.util.containers.ContainerUtil;
26 import com.jetbrains.python.PyBundle;
27 import com.jetbrains.python.PyNames;
28 import com.jetbrains.python.PyTokenTypes;
29 import com.jetbrains.python.psi.*;
30 import com.jetbrains.python.psi.impl.PyPsiUtils;
31 import org.jetbrains.annotations.NonNls;
32 import org.jetbrains.annotations.NotNull;
33 import org.jetbrains.annotations.Nullable;
34
35 import java.util.*;
36
37 /**
38  * For:
39  * class B(A):
40  * def __init__(self):
41  * A.__init__(self)           #  inserted
42  * print "Constructor B was called"
43  * <p/>
44  * User: catherine
45  */
46 public class AddCallSuperQuickFix implements LocalQuickFix {
47
48   @NotNull
49   public String getName() {
50     return PyBundle.message("QFIX.add.super");
51   }
52
53   @NonNls
54   @NotNull
55   public String getFamilyName() {
56     return getName();
57   }
58
59   public void applyFix(@NotNull final Project project, @NotNull final ProblemDescriptor descriptor) {
60     final PyFunction problemFunction = PsiTreeUtil.getParentOfType(descriptor.getPsiElement(), PyFunction.class);
61     if (problemFunction == null) return;
62     final StringBuilder superCall = new StringBuilder();
63     final PyClass klass = problemFunction.getContainingClass();
64     if (klass == null) return;
65     final PyClass[] superClasses = klass.getSuperClasses();
66     if (superClasses.length == 0) return;
67
68     final PyClass superClass = superClasses[0];
69     final PyFunction superInit = superClass.findMethodByName(PyNames.INIT, true);
70     if (superInit == null) return;
71
72     final ParametersInfo origInfo = new ParametersInfo(problemFunction.getParameterList());
73     final ParametersInfo superInfo = new ParametersInfo(superInit.getParameterList());
74     final boolean addSelfToCall;
75
76     if (klass.isNewStyleClass(null)) {
77       addSelfToCall = false;
78       if (LanguageLevel.forElement(klass).isPy3K()) {
79         superCall.append("super().__init__(");
80       }
81       else {
82         superCall.append("super(").append(klass.getName()).append(", ").append(getSelfParameterName(origInfo)).append(").__init__(");
83       }
84     }
85     else {
86       addSelfToCall = true;
87       superCall.append(superClass.getName()).append(".__init__(");
88     }
89
90     final Couple<List<String>> couple = buildNewFunctionParamsAndSuperInitCallArgs(origInfo, superInfo, addSelfToCall);
91     final StringBuilder newParameters = new StringBuilder("(");
92     StringUtil.join(couple.getFirst(), ", ", newParameters);
93     newParameters.append(")");
94
95     StringUtil.join(couple.getSecond(), ", ", superCall);
96     superCall.append(")");
97
98     final PyElementGenerator generator = PyElementGenerator.getInstance(project);
99     final LanguageLevel languageLevel = LanguageLevel.forElement(problemFunction);
100     final PyStatement callSuperStatement = generator.createFromText(languageLevel, PyStatement.class, superCall.toString());
101     final PyParameterList newParameterList = generator.createParameterList(languageLevel, newParameters.toString());
102     problemFunction.getParameterList().replace(newParameterList);
103     final PyStatementList statementList = problemFunction.getStatementList();
104     PyUtil.addElementToStatementList(callSuperStatement, statementList, true);
105     PyPsiUtils.removeRedundantPass(statementList);
106   }
107
108   @NotNull
109   private static String getSelfParameterName(@NotNull ParametersInfo info) {
110     final PyParameter selfParameter = info.getSelfParameter();
111     if (selfParameter == null) {
112       return PyNames.CANONICAL_SELF;
113     }
114     return StringUtil.defaultIfEmpty(selfParameter.getName(), PyNames.CANONICAL_SELF);
115   }
116
117   @NotNull
118   private static Couple<List<String>> buildNewFunctionParamsAndSuperInitCallArgs(@NotNull ParametersInfo origInfo,
119                                                                                  @NotNull ParametersInfo superInfo,
120                                                                                  boolean addSelfToCall) {
121     final List<String> newFunctionParams = new ArrayList<String>();
122     final List<String> superCallArgs = new ArrayList<String>();
123
124     final PyParameter selfParameter = origInfo.getSelfParameter();
125     if (selfParameter != null && StringUtil.isNotEmpty(selfParameter.getName())) {
126       newFunctionParams.add(selfParameter.getText());
127     }
128     else {
129       newFunctionParams.add(PyNames.CANONICAL_SELF);
130     }
131
132     if (addSelfToCall) {
133       superCallArgs.add(getSelfParameterName(origInfo));
134     }
135
136     // Required parameters (not-keyword)
137     for (PyParameter param : origInfo.getRequiredParameters()) {
138       newFunctionParams.add(param.getText());
139     }
140     for (PyParameter param : superInfo.getRequiredParameters()) {
141       // Special case as if base class has constructor __init__((a, b), c) and
142       // subclass has constructor __init__(a, (b, c))
143       final PyTupleParameter tupleParam = param.getAsTuple();
144       if (tupleParam != null) {
145         final List<String> uniqueNames = collectParameterNames(tupleParam);
146         final boolean hasDuplicates = uniqueNames.removeAll(origInfo.getAllParameterNames());
147         if (hasDuplicates) {
148           newFunctionParams.addAll(uniqueNames);
149         }
150         else {
151           newFunctionParams.add(param.getText());
152         }
153         // Retain original structure of tuple parameter.
154         // Note that tuple parameters cannot have annotations or nested default values, so it's syntactically safe
155         superCallArgs.add(param.getText());
156       }
157       else {
158         if (!origInfo.getAllParameterNames().contains(param.getName())) {
159           newFunctionParams.add(param.getText());
160         }
161         superCallArgs.add(param.getName());
162       }
163     }
164
165     // Optional parameters (not-keyword)
166     for (PyParameter param : origInfo.getOptionalParameters()) {
167       newFunctionParams.add(param.getText());
168     }
169
170     // Pass parameters with default values to super class constructor, only if both functions contain them  
171     for (PyParameter param : superInfo.getOptionalParameters()) {
172       final PyTupleParameter tupleParam = param.getAsTuple();
173       if (tupleParam != null) {
174         if (origInfo.getAllParameterNames().containsAll(collectParameterNames(tupleParam))) {
175           final String paramText = tupleParam.getText();
176           final PsiElement equalSign = PyPsiUtils.getFirstChildOfType(param, PyTokenTypes.EQ);
177           if (equalSign != null) {
178             superCallArgs.add(paramText.substring(0, equalSign.getStartOffsetInParent()).trim());
179           }
180         }
181       }
182       else {
183         if (origInfo.getAllParameterNames().contains(param.getName())) {
184           superCallArgs.add(param.getName());
185         }
186       }
187     }
188
189     // Positional vararg
190     PyParameter starredParam = null;
191     if (origInfo.getPositionalContainerParameter() != null) {
192       starredParam = origInfo.getPositionalContainerParameter();
193     }
194     else if (superInfo.getPositionalContainerParameter() != null) {
195       starredParam = superInfo.getPositionalContainerParameter();
196     }
197     else if (origInfo.getSingleStarParameter() != null) {
198       starredParam = origInfo.getSingleStarParameter();
199     }
200     else if (superInfo.getSingleStarParameter() != null) {
201       starredParam = superInfo.getSingleStarParameter();
202     }
203     if (starredParam != null) {
204       newFunctionParams.add(starredParam.getText());
205       if (superInfo.getPositionalContainerParameter() != null) {
206         superCallArgs.add("*" + starredParam.getName());
207       }
208     }
209
210     // Required keyword-only parameters
211     boolean newSignatureContainsKeywordParams = false;
212     for (PyParameter param : origInfo.getRequiredKeywordOnlyParameters()) {
213       newFunctionParams.add(param.getText());
214       newSignatureContainsKeywordParams = true;
215     }
216     for (PyParameter param : superInfo.getRequiredKeywordOnlyParameters()) {
217       if (!origInfo.getAllParameterNames().contains(param.getName())) {
218         newFunctionParams.add(param.getText());
219         newSignatureContainsKeywordParams = true;
220       }
221       superCallArgs.add(param.getName() + "=" + param.getName());
222     }
223
224     // Optional keyword-only parameters
225     for (PyParameter param : origInfo.getOptionalKeywordOnlyParameters()) {
226       newFunctionParams.add(param.getText());
227       newSignatureContainsKeywordParams = true;
228     }
229     
230     // If '*' param is followed by nothing in result signature, remove it altogether 
231     if (starredParam instanceof PySingleStarParameter && !newSignatureContainsKeywordParams) {
232       newFunctionParams.remove(newFunctionParams.size() - 1);
233     }
234
235     for (PyParameter param : superInfo.getOptionalKeywordOnlyParameters()) {
236       if (origInfo.getAllParameterNames().contains(param.getName())) {
237         superCallArgs.add(param.getName() + "=" + param.getName());
238       }
239     }
240
241     // Keyword vararg
242     PyParameter doubleStarredParam = null;
243     if (origInfo.getKeywordContainerParameter() != null) {
244       doubleStarredParam = origInfo.getKeywordContainerParameter();
245     }
246     else if (superInfo.getKeywordContainerParameter() != null) {
247       doubleStarredParam = superInfo.getKeywordContainerParameter();
248     }
249     if (doubleStarredParam != null) {
250       newFunctionParams.add(doubleStarredParam.getText());
251       if (superInfo.getKeywordContainerParameter() != null) {
252         superCallArgs.add("**" + doubleStarredParam.getName());
253       }
254     }
255     return Couple.of(newFunctionParams, superCallArgs);
256   }
257
258   private static class ParametersInfo {
259
260     private final PyParameter mySelfParam;
261     /**
262      * Parameters without default value that come before first "*..." parameter.
263      */
264     private final List<PyParameter> myRequiredParams = new ArrayList<PyParameter>();
265     /**
266      * Parameters with default value that come before first "*..." parameter.
267      */
268     private final List<PyParameter> myOptionalParams = new ArrayList<PyParameter>();
269     /**
270      * Parameter of form "*args" (positional vararg), not the same as single "*".
271      */
272     private final PyParameter myPositionalContainerParam;
273     /**
274      * Parameter "*", that is used to delimit normal and keyword-only parameters.
275      */
276     private final PyParameter mySingleStarParam;
277     /**
278      * Parameters without default value that come after first "*..." parameter.
279      */
280     private final List<PyParameter> myRequiredKwOnlyParams = new ArrayList<PyParameter>();
281     /**
282      * Parameters with default value that come after first "*..." parameter.
283      */
284     private final List<PyParameter> myOptionalKwOnlyParams = new ArrayList<PyParameter>();
285     /**
286      * Parameter of form "**kwargs" (keyword vararg).
287      */
288     private final PyParameter myKeywordContainerParam;
289
290     private final Set<String> myAllParameterNames = new LinkedHashSet<String>();
291
292     public ParametersInfo(@NotNull PyParameterList parameterList) {
293       PyParameter positionalContainer = null;
294       PyParameter singleStarParam = null;
295       PyParameter keywordContainer = null;
296       PyParameter selfParam = null;
297
298       for (PyParameter param : parameterList.getParameters()) {
299         myAllParameterNames.addAll(collectParameterNames(param));
300
301         if (param.isSelf()) {
302           selfParam = param;
303         }
304         else if (param instanceof PySingleStarParameter) {
305           singleStarParam = param;
306         }
307         else if (param.getAsNamed() != null && param.getAsNamed().isKeywordContainer()) {
308           keywordContainer = param;
309         }
310         else if (param.getAsNamed() != null && param.getAsNamed().isPositionalContainer()) {
311           positionalContainer = param;
312         }
313         else if (param.getAsNamed() == null || !param.getAsNamed().isKeywordOnly()) {
314           if (param.hasDefaultValue()) {
315             myOptionalParams.add(param);
316           }
317           else {
318             myRequiredParams.add(param);
319           }
320         }
321         else {
322           if (param.hasDefaultValue()) {
323             myOptionalKwOnlyParams.add(param);
324           }
325           else {
326             myRequiredKwOnlyParams.add(param);
327           }
328         }
329       }
330
331       mySelfParam = selfParam;
332       myPositionalContainerParam = positionalContainer;
333       mySingleStarParam = singleStarParam;
334       myKeywordContainerParam = keywordContainer;
335     }
336
337     @Nullable
338     public PyParameter getSelfParameter() {
339       return mySelfParam;
340     }
341
342     @NotNull
343     public List<PyParameter> getRequiredParameters() {
344       return Collections.unmodifiableList(myRequiredParams);
345     }
346
347     @NotNull
348     public List<PyParameter> getOptionalParameters() {
349       return Collections.unmodifiableList(myOptionalParams);
350     }
351
352     @Nullable
353     public PyParameter getPositionalContainerParameter() {
354       return myPositionalContainerParam;
355     }
356
357     @Nullable
358     public PyParameter getSingleStarParameter() {
359       return mySingleStarParam;
360     }
361
362     @NotNull
363     public List<PyParameter> getRequiredKeywordOnlyParameters() {
364       return Collections.unmodifiableList(myRequiredKwOnlyParams);
365     }
366
367     @NotNull
368     public List<PyParameter> getOptionalKeywordOnlyParameters() {
369       return Collections.unmodifiableList(myOptionalKwOnlyParams);
370     }
371
372     @Nullable
373     public PyParameter getKeywordContainerParameter() {
374       return myKeywordContainerParam;
375     }
376
377     @NotNull
378     public Set<String> getAllParameterNames() {
379       return Collections.unmodifiableSet(myAllParameterNames);
380     }
381   }
382
383   @NotNull
384   private static List<String> collectParameterNames(@NotNull PyParameter param) {
385     final List<String> result = new ArrayList<String>();
386     collectParameterNames(param, result);
387     return result;
388   }
389
390
391   private static void collectParameterNames(@NotNull PyParameter param, @NotNull Collection<String> acc) {
392     final PyTupleParameter tupleParam = param.getAsTuple();
393     if (tupleParam != null) {
394       for (PyParameter subParam : tupleParam.getContents()) {
395         collectParameterNames(subParam, acc);
396       }
397     }
398     else {
399       ContainerUtil.addIfNotNull(acc, param.getName());
400     }
401   }
402 }