PY-17392 Cleanup in PyDefaultArgumentQuickFix
[idea/community.git] / python / src / com / jetbrains / python / inspections / quickfix / PyDefaultArgumentQuickFix.java
1 /*
2  * Copyright 2000-2014 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.inspections.quickfix;
17
18 import com.intellij.codeInspection.LocalQuickFix;
19 import com.intellij.codeInspection.ProblemDescriptor;
20 import com.intellij.openapi.project.Project;
21 import com.intellij.psi.PsiElement;
22 import com.intellij.psi.util.PsiTreeUtil;
23 import com.jetbrains.python.PyBundle;
24 import com.jetbrains.python.psi.*;
25 import org.jetbrains.annotations.NotNull;
26
27 /**
28  * User: catherine
29  *
30  * QuickFix to replace mutable default argument. For instance,
31  * def foo(args=[]):
32      pass
33  * replace with:
34  * def foo(args=None):
35      if not args: args = []
36      pass
37  */
38 public class PyDefaultArgumentQuickFix implements LocalQuickFix {
39
40   @Override
41   @NotNull
42   public String getName() {
43     return PyBundle.message("QFIX.default.argument");
44   }
45
46   @Override
47   @NotNull
48   public String getFamilyName() {
49     return getName();
50   }
51
52   @Override
53   public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
54     final PsiElement defaultValue = descriptor.getPsiElement();
55     final PyNamedParameter param = PsiTreeUtil.getParentOfType(defaultValue, PyNamedParameter.class);
56     final PyFunction function = PsiTreeUtil.getParentOfType(defaultValue, PyFunction.class);
57     assert param != null;
58     final String defName = param.getName();
59     if (function != null) {
60       final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
61       final PyStatementList list = function.getStatementList();
62       final PyParameterList paramList = function.getParameterList();
63
64       final StringBuilder functionText = new StringBuilder("def " + function.getName() + "(");
65       final int size = paramList.getParameters().length;
66       for (int i = 0; i < size; i++) {
67         final PyParameter p = paramList.getParameters()[i];
68         if (p == param) {
69           functionText.append(defName).append("=None");
70         }
71         else {
72           functionText.append(p.getText());
73         }
74         if (i != size-1) {
75           functionText.append(", ");
76         }
77       }
78       
79       functionText.append("):\n\tif not ").append(defName).append(":\n\t\t").append(defName).append(" = ").append(defaultValue.getText());
80       final PyStatement[] statements = list.getStatements();
81       final PyStatement firstStatement = statements.length > 0 ? statements[0] : null;
82       final PyFunction newFunction = elementGenerator.createFromText(LanguageLevel.forElement(function), PyFunction.class,
83                                                                      functionText.toString());
84       if (firstStatement == null) {
85         function.replace(newFunction);
86       }
87       else {
88         final PyStatement ifStatement = newFunction.getStatementList().getStatements()[0];
89         final PyStringLiteralExpression docString = function.getDocStringExpression();
90         if (docString != null) {
91           list.addAfter(ifStatement, firstStatement);
92         }
93         else {
94           list.addBefore(ifStatement, firstStatement);
95         }
96         paramList.replace(elementGenerator.createFromText(LanguageLevel.forElement(defaultValue),
97                                                           PyFunction.class, functionText.toString()).getParameterList());
98       }
99     }
100   }
101 }