PY-17392 Update specific PSI elements, instead of entire parameter and statement...
authorMikhail Golubev <mikhail.golubev@jetbrains.com>
Fri, 30 Oct 2015 12:39:30 +0000 (15:39 +0300)
committerMikhail Golubev <mikhail.golubev@jetbrains.com>
Fri, 30 Oct 2015 12:59:44 +0000 (15:59 +0300)
Helper method PyUtil#addElementToStatementList moves old statement list
on the next line (below the header of the parent statement), if
necessary, *before* the new statement is inserted. If it's done
afterwards and new statement is multiline, it might not be properly
indented.

python/src/com/jetbrains/python/inspections/quickfix/PyDefaultArgumentQuickFix.java
python/src/com/jetbrains/python/psi/PyUtil.java
python/testData/inspections/DefaultArgumentCommentsInsideParameters.py [new file with mode: 0644]
python/testData/inspections/DefaultArgumentCommentsInsideParameters_after.py [new file with mode: 0644]
python/testData/inspections/DefaultArgument_after.py
python/testSrc/com/jetbrains/python/PyQuickFixTest.java

index 4832014829651ddf63703a811871613d75f95263..4a0d61d4e8b41e75014919af90dd167052de0a03 100644 (file)
@@ -21,6 +21,7 @@ import com.intellij.openapi.project.Project;
 import com.intellij.psi.PsiElement;
 import com.intellij.psi.util.PsiTreeUtil;
 import com.jetbrains.python.PyBundle;
+import com.jetbrains.python.PyNames;
 import com.jetbrains.python.psi.*;
 import org.jetbrains.annotations.NotNull;
 
@@ -56,46 +57,16 @@ public class PyDefaultArgumentQuickFix implements LocalQuickFix {
     final PyFunction function = PsiTreeUtil.getParentOfType(defaultValue, PyFunction.class);
     assert param != null;
     final String defName = param.getName();
-    if (function != null) {
-      final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
-      final PyStatementList list = function.getStatementList();
-      final PyParameterList paramList = function.getParameterList();
-
-      final StringBuilder functionText = new StringBuilder("def " + function.getName() + "(");
-      final int size = paramList.getParameters().length;
-      for (int i = 0; i < size; i++) {
-        final PyParameter p = paramList.getParameters()[i];
-        if (p == param) {
-          functionText.append(defName).append("=None");
-        }
-        else {
-          functionText.append(p.getText());
-        }
-        if (i != size-1) {
-          functionText.append(", ");
-        }
-      }
+    if (function != null && defName != null) {
+      final PyElementGenerator generator = PyElementGenerator.getInstance(project);
+      final LanguageLevel languageLevel = LanguageLevel.forElement(function);
       
-      functionText.append("):\n\tif not ").append(defName).append(":\n\t\t").append(defName).append(" = ").append(defaultValue.getText());
-      final PyStatement[] statements = list.getStatements();
-      final PyStatement firstStatement = statements.length > 0 ? statements[0] : null;
-      final PyFunction newFunction = elementGenerator.createFromText(LanguageLevel.forElement(function), PyFunction.class,
-                                                                     functionText.toString());
-      if (firstStatement == null) {
-        function.replace(newFunction);
-      }
-      else {
-        final PyStatement ifStatement = newFunction.getStatementList().getStatements()[0];
-        final PyStringLiteralExpression docString = function.getDocStringExpression();
-        if (docString != null) {
-          list.addAfter(ifStatement, firstStatement);
-        }
-        else {
-          list.addBefore(ifStatement, firstStatement);
-        }
-        paramList.replace(elementGenerator.createFromText(LanguageLevel.forElement(defaultValue),
-                                                          PyFunction.class, functionText.toString()).getParameterList());
-      }
+      final PyNamedParameter newParam = generator.createParameter(defName, PyNames.NONE, null, languageLevel);
+      param.replace(newParam);
+
+      final String conditionalText = "if not " + defName + ":\n\t" + defName + " = " + defaultValue.getText() + "\n";
+      final PyIfStatement conditionalAssignment = generator.createFromText(languageLevel, PyIfStatement.class, conditionalText);
+      PyUtil.addElementToStatementList(conditionalAssignment, function.getStatementList(), true);
     }
   }
 }
index 883e63fac11b550e85911e3e7f7666e2de9e356a..fc3b8ff5e686c995a6b029ed3799bcd29f36568d 100644 (file)
@@ -1469,7 +1469,23 @@ public class PyUtil {
   public static PsiElement addElementToStatementList(@NotNull PsiElement element,
                                                      @NotNull PyStatementList statementList,
                                                      boolean toTheBeginning) {
-    final boolean statementListWasEmpty = statementList.getStatements().length == 0;
+    final PsiElement prevElem = PyPsiUtils.getPrevNonWhitespaceSibling(statementList);
+    // If statement list is on the same line as previous element (supposedly colon), move its only statement on the next line
+    if (prevElem != null && onSameLine(statementList, prevElem)) {
+      final PsiDocumentManager manager = PsiDocumentManager.getInstance(statementList.getProject());
+      final Document document = manager.getDocument(statementList.getContainingFile());
+      if (document != null) {
+        final PsiElement container = statementList.getParent();
+        manager.doPostponedOperationsAndUnblockDocument(document);
+        final String indentation = "\n" + PyIndentUtil.getElementIndent(statementList);
+        // If statement list was empty initially, we need to add some anchor statement ("pass"), so that preceding new line was not
+        // parsed as following entire StatementListContainer (e.g. function). It's going to be replaced anyway.
+        final String text = statementList.getStatements().length == 0 ? indentation + PyNames.PASS : indentation;
+        document.insertString(statementList.getTextRange().getStartOffset(), text);
+        manager.commitDocument(document);
+        statementList = ((PyStatementListContainer)container).getStatementList();
+      }
+    }
     final PsiElement firstChild = statementList.getFirstChild();
     if (firstChild == statementList.getLastChild() && firstChild instanceof PyPassStatement) {
       element = firstChild.replace(element);
@@ -1508,15 +1524,6 @@ public class PyUtil {
         element = statementList.add(element);
       }
     }
-    if (statementListWasEmpty) {
-      final PsiDocumentManager documentManager = PsiDocumentManager.getInstance(statementList.getProject());
-      final Document document = documentManager.getDocument(statementList.getContainingFile());
-      if (document != null) {
-        documentManager.doPostponedOperationsAndUnblockDocument(document);
-        document.insertString(statementList.getTextOffset(), "\n" + PyIndentUtil.getElementIndent(statementList));
-        documentManager.commitDocument(document);
-      }
-    }
     return element;
   }
 
diff --git a/python/testData/inspections/DefaultArgumentCommentsInsideParameters.py b/python/testData/inspections/DefaultArgumentCommentsInsideParameters.py
new file mode 100644 (file)
index 0000000..e57dcf7
--- /dev/null
@@ -0,0 +1,4 @@
+def func(x,  # comment
+         mutable=<warning descr="Default argument value is mutable">[<caret>]</warning>):
+    """Docstring."""
+    print(mutable)
\ No newline at end of file
diff --git a/python/testData/inspections/DefaultArgumentCommentsInsideParameters_after.py b/python/testData/inspections/DefaultArgumentCommentsInsideParameters_after.py
new file mode 100644 (file)
index 0000000..22bae04
--- /dev/null
@@ -0,0 +1,6 @@
+def func(x,  # comment
+         mutable=None):
+    """Docstring."""
+    if not mutable:
+        mutable = []
+    print(mutable)
\ No newline at end of file
index 8007f9411df1ed374da4e5d6ce420dcf6f56956b..2de1bb88ef9250100b3d0ed43c27f91701ca0e38 100644 (file)
@@ -1,4 +1,3 @@
 def foo(args=None):
     if not args:
-        args = []
-    pass
\ No newline at end of file
+        args = []
\ No newline at end of file
index c76a713877fbbf71fbf8d037782898da1db0cadc..e18ce4c133462239ff6b2e37911354edea4eb0f7 100644 (file)
@@ -342,6 +342,11 @@ public class PyQuickFixTest extends PyTestCase {
     doInspectionTest(PyDefaultArgumentInspection.class, PyBundle.message("QFIX.default.argument"), true, true);
   }
 
+  // PY-17392
+  public void testDefaultArgumentCommentsInsideParameters() {
+    doInspectionTest(PyDefaultArgumentInspection.class, PyBundle.message("QFIX.default.argument"), true, true);
+  }
+
   // PY-3125
   public void testArgumentEqualDefault() {
     doInspectionTest(PyArgumentEqualDefaultInspection.class, PyBundle.message("QFIX.remove.argument.equal.default"), true, true);