2 * Copyright 2000-2013 JetBrains s.r.o.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
16 package com.jetbrains.python.inspections.quickfix;
18 import com.intellij.codeInspection.LocalQuickFix;
19 import com.intellij.codeInspection.ProblemDescriptor;
20 import com.intellij.openapi.project.Project;
21 import com.intellij.psi.PsiElement;
22 import com.intellij.psi.util.PsiTreeUtil;
23 import com.jetbrains.python.PyBundle;
24 import com.jetbrains.python.psi.*;
25 import org.jetbrains.annotations.NotNull;
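30 * QuickFix that replaces a mutable default argument with None and initializes it in the function body instead.
35 * For instance, def foo(args=[]) becomes def foo(args=None) whose body starts with: if not args: args = []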
38 public class PyDefaultArgumentQuickFix implements LocalQuickFix {
/**
 * Returns the name of this quick fix: the PyBundle message stored under the key
 * "QFIX.default.argument".
 */
42 public String getName() {
43 return PyBundle.message("QFIX.default.argument");
48 public String getFamilyName() {
/**
 * Rewrites the function that declares the flagged default value. The parameter owning that
 * default is given a default of None. A guard of the form "if not NAME: NAME = original default"
 * is added to the function body.
 *
 * @param project    used to obtain the PyElementGenerator that parses the generated code
 * @param descriptor problem descriptor whose PSI element is treated as the parameter's
 *                   default-value expression
 */
53 public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
54 PsiElement defaultValue = descriptor.getPsiElement();
// Find the parameter that owns this default value and the function that declares it.
55 PyNamedParameter param = PsiTreeUtil.getParentOfType(defaultValue, PyNamedParameter.class);
56 PyFunction function = PsiTreeUtil.getParentOfType(defaultValue, PyFunction.class);
// NOTE(review): getParentOfType can return null, and no null check on param is visible before
// this dereference. Confirm one exists (e.g. an assert) or add it.
58 String defName = param.getName();
59 if (function != null) {
60 PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
61 PyStatementList list = function.getStatementList();
62 PyParameterList paramList = function.getParameterList();
// Build the source of a throwaway function "foo". Its parameter list mirrors the real one, with the
// flagged parameter written as NAME=None; its body is the guard that restores the original default.
64 final StringBuilder functionText = new StringBuilder("def foo(");
65 int size = paramList.getParameters().length;
// NOTE(review): getParameters() is called again on every iteration below. Hoisting the array into
// a local would avoid the repeated calls.
66 for (int i = 0; i != size; ++i) {
67 PyParameter p = paramList.getParameters()[i];
// NOTE(review): the condition choosing between the two appends below (presumably p == param,
// with an else) is not visible in this excerpt. Confirm.
69 functionText.append(defName).append("=None");
71 functionText.append(p.getText());
// Parameter separator; presumably guarded so no ", " follows the last parameter. Confirm.
73 functionText.append(", ");
// Close the parameter list and emit the body guard, copying the default's source text verbatim.
// NOTE(review): "if not NAME" also overwrites explicitly passed falsy arguments (e.g. a caller's own
// empty list); "if NAME is None" would only replace the omitted default.
75 functionText.append("):\n\tif not ").append(defName).append(":\n\t\t").append(defName).append(" = ").append(defaultValue.getText());
76 final PyStatement[] statements = list.getStatements();
77 PyStatement firstStatement = statements.length > 0 ? statements[0] : null;
78 PyFunction newFunction = elementGenerator.createFromText(LanguageLevel.forElement(function), PyFunction.class,
79 functionText.toString());
// Empty body: the whole function is swapped for the generated one.
// NOTE(review): the generated function is named "foo", so this appears to rename the user's function
// and drop anything not in the generated text (e.g. decorators). Confirm.
80 if (firstStatement == null) {
81 function.replace(newFunction);
// Otherwise, copy only the generated guard statement into the existing body.
84 final PyStatement ifStatement = newFunction.getStatementList().getStatements()[0];
85 PyStringLiteralExpression docString = function.getDocStringExpression();
// With a docstring, the guard is inserted after the first statement (presumably the docstring
// itself), so the docstring stays first. Otherwise it is inserted before the first statement. The
// else selecting the addBefore call is not visible in this excerpt. Confirm.
86 if (docString != null)
87 list.addAfter(ifStatement, firstStatement);
89 list.addBefore(ifStatement, firstStatement);
// Swap the real parameter list for the generated one, so the flagged parameter's default becomes None.
// NOTE(review): this re-parses the same text instead of reusing newFunction.getParameterList(). It also
// uses LanguageLevel.forElement(defaultValue), whereas newFunction above used forElement(function).
91 paramList.replace(elementGenerator.createFromText(LanguageLevel.forElement(defaultValue),
92 PyFunction.class, functionText.toString()).getParameterList());