2 * Copyright 2000-2014 JetBrains s.r.o.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
16 package com.jetbrains.python.inspections.quickfix;
18 import com.intellij.codeInspection.LocalQuickFix;
19 import com.intellij.codeInspection.ProblemDescriptor;
20 import com.intellij.openapi.project.Project;
21 import com.intellij.psi.PsiElement;
22 import com.intellij.psi.util.PsiTreeUtil;
23 import com.jetbrains.python.PyBundle;
24 import com.jetbrains.python.psi.*;
25 import org.jetbrains.annotations.NotNull;
30 * QuickFix to replace mutable default argument. For instance,
35 if not args: args = []
38 public class PyDefaultArgumentQuickFix implements LocalQuickFix {
42 public String getName() {
43 return PyBundle.message("QFIX.default.argument");
48 public String getFamilyName() {
53 public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
54 final PsiElement defaultValue = descriptor.getPsiElement();
55 final PyNamedParameter param = PsiTreeUtil.getParentOfType(defaultValue, PyNamedParameter.class);
56 final PyFunction function = PsiTreeUtil.getParentOfType(defaultValue, PyFunction.class);
58 final String defName = param.getName();
59 if (function != null) {
60 final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
61 final PyStatementList list = function.getStatementList();
62 final PyParameterList paramList = function.getParameterList();
64 final StringBuilder functionText = new StringBuilder("def " + function.getName() + "(");
65 final int size = paramList.getParameters().length;
66 for (int i = 0; i < size; i++) {
67 final PyParameter p = paramList.getParameters()[i];
69 functionText.append(defName).append("=None");
72 functionText.append(p.getText());
75 functionText.append(", ");
79 functionText.append("):\n\tif not ").append(defName).append(":\n\t\t").append(defName).append(" = ").append(defaultValue.getText());
80 final PyStatement[] statements = list.getStatements();
81 final PyStatement firstStatement = statements.length > 0 ? statements[0] : null;
82 final PyFunction newFunction = elementGenerator.createFromText(LanguageLevel.forElement(function), PyFunction.class,
83 functionText.toString());
84 if (firstStatement == null) {
85 function.replace(newFunction);
88 final PyStatement ifStatement = newFunction.getStatementList().getStatements()[0];
89 final PyStringLiteralExpression docString = function.getDocStringExpression();
90 if (docString != null) {
91 list.addAfter(ifStatement, firstStatement);
94 list.addBefore(ifStatement, firstStatement);
96 paramList.replace(elementGenerator.createFromText(LanguageLevel.forElement(defaultValue),
97 PyFunction.class, functionText.toString()).getParameterList());