Merge remote-tracking branch 'origin/master'
[idea/community.git] / python / src / com / jetbrains / python / inspections / quickfix / PyDefaultArgumentQuickFix.java
1 /*
2  * Copyright 2000-2013 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.inspections.quickfix;
17
18 import com.intellij.codeInspection.LocalQuickFix;
19 import com.intellij.codeInspection.ProblemDescriptor;
20 import com.intellij.openapi.project.Project;
21 import com.intellij.psi.PsiElement;
22 import com.intellij.psi.util.PsiTreeUtil;
23 import com.jetbrains.python.PyBundle;
24 import com.jetbrains.python.psi.*;
25 import org.jetbrains.annotations.NotNull;
26
27 /**
28  * User: catherine
29  *
30  * QuickFix to replace mutable default argument. For instance,
31  * def foo(args=[]):
32      pass
33  * replace with:
34  * def foo(args=None):
35      if not args: args = []
36      pass
37  */
38 public class PyDefaultArgumentQuickFix implements LocalQuickFix {
39
40   @Override
41   @NotNull
42   public String getName() {
43     return PyBundle.message("QFIX.default.argument");
44   }
45
46   @Override
47   @NotNull
48   public String getFamilyName() {
49     return getName();
50   }
51
52   @Override
53   public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
54     PsiElement defaultValue = descriptor.getPsiElement();
55     PyNamedParameter param = PsiTreeUtil.getParentOfType(defaultValue, PyNamedParameter.class);
56     PyFunction function = PsiTreeUtil.getParentOfType(defaultValue, PyFunction.class);
57     assert param != null;
58     String defName = param.getName();
59     if (function != null) {
60       PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
61       PyStatementList list = function.getStatementList();
62       PyParameterList paramList = function.getParameterList();
63
64       final StringBuilder functionText = new StringBuilder("def " + function.getName() + "(");
65       int size = paramList.getParameters().length;
66       for (int i = 0; i != size; ++i) {
67         PyParameter p = paramList.getParameters()[i];
68         if (p == param)
69           functionText.append(defName).append("=None");
70         else
71           functionText.append(p.getText());
72         if (i != size-1)
73           functionText.append(", ");
74       }
75       functionText.append("):\n\tif not ").append(defName).append(":\n\t\t").append(defName).append(" = ").append(defaultValue.getText());
76       final PyStatement[] statements = list.getStatements();
77       PyStatement firstStatement = statements.length > 0 ? statements[0] : null;
78       PyFunction newFunction = elementGenerator.createFromText(LanguageLevel.forElement(function), PyFunction.class,
79                                                                functionText.toString());
80       if (firstStatement == null) {
81         function.replace(newFunction);
82       }
83       else {
84         final PyStatement ifStatement = newFunction.getStatementList().getStatements()[0];
85         PyStringLiteralExpression docString = function.getDocStringExpression();
86         if (docString != null)
87           list.addAfter(ifStatement, firstStatement);
88         else {
89           list.addBefore(ifStatement, firstStatement);
90         }
91         paramList.replace(elementGenerator.createFromText(LanguageLevel.forElement(defaultValue),
92                                                           PyFunction.class, functionText.toString()).getParameterList());
93       }
94     }
95   }
96 }