import com.intellij.psi.PsiElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.jetbrains.python.PyBundle;
+import com.jetbrains.python.PyNames;
import com.jetbrains.python.psi.*;
import org.jetbrains.annotations.NotNull;
final PyFunction function = PsiTreeUtil.getParentOfType(defaultValue, PyFunction.class);
assert param != null;
final String defName = param.getName();
- if (function != null) {
- final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
- final PyStatementList list = function.getStatementList();
- final PyParameterList paramList = function.getParameterList();
-
- final StringBuilder functionText = new StringBuilder("def " + function.getName() + "(");
- final int size = paramList.getParameters().length;
- for (int i = 0; i < size; i++) {
- final PyParameter p = paramList.getParameters()[i];
- if (p == param) {
- functionText.append(defName).append("=None");
- }
- else {
- functionText.append(p.getText());
- }
- if (i != size-1) {
- functionText.append(", ");
- }
- }
+ if (function != null && defName != null) {
+ final PyElementGenerator generator = PyElementGenerator.getInstance(project);
+ final LanguageLevel languageLevel = LanguageLevel.forElement(function);
- functionText.append("):\n\tif not ").append(defName).append(":\n\t\t").append(defName).append(" = ").append(defaultValue.getText());
- final PyStatement[] statements = list.getStatements();
- final PyStatement firstStatement = statements.length > 0 ? statements[0] : null;
- final PyFunction newFunction = elementGenerator.createFromText(LanguageLevel.forElement(function), PyFunction.class,
- functionText.toString());
- if (firstStatement == null) {
- function.replace(newFunction);
- }
- else {
- final PyStatement ifStatement = newFunction.getStatementList().getStatements()[0];
- final PyStringLiteralExpression docString = function.getDocStringExpression();
- if (docString != null) {
- list.addAfter(ifStatement, firstStatement);
- }
- else {
- list.addBefore(ifStatement, firstStatement);
- }
- paramList.replace(elementGenerator.createFromText(LanguageLevel.forElement(defaultValue),
- PyFunction.class, functionText.toString()).getParameterList());
- }
+ final PyNamedParameter newParam = generator.createParameter(defName, PyNames.NONE, null, languageLevel);
+ param.replace(newParam);
+
+ final String conditionalText = "if not " + defName + ":\n\t" + defName + " = " + defaultValue.getText() + "\n";
+ final PyIfStatement conditionalAssignment = generator.createFromText(languageLevel, PyIfStatement.class, conditionalText);
+ PyUtil.addElementToStatementList(conditionalAssignment, function.getStatementList(), true);
}
}
}
public static PsiElement addElementToStatementList(@NotNull PsiElement element,
@NotNull PyStatementList statementList,
boolean toTheBeginning) {
- final boolean statementListWasEmpty = statementList.getStatements().length == 0;
+ final PsiElement prevElem = PyPsiUtils.getPrevNonWhitespaceSibling(statementList);
+ // If statement list is on the same line as previous element (supposedly colon), move its only statement on the next line
+ if (prevElem != null && onSameLine(statementList, prevElem)) {
+ final PsiDocumentManager manager = PsiDocumentManager.getInstance(statementList.getProject());
+ final Document document = manager.getDocument(statementList.getContainingFile());
+ if (document != null) {
+ final PsiElement container = statementList.getParent();
+ manager.doPostponedOperationsAndUnblockDocument(document);
+ final String indentation = "\n" + PyIndentUtil.getElementIndent(statementList);
+ // If statement list was empty initially, we need to add some anchor statement ("pass"), so that preceding new line was not
+ // parsed as following entire StatementListContainer (e.g. function). It's going to be replaced anyway.
+ final String text = statementList.getStatements().length == 0 ? indentation + PyNames.PASS : indentation;
+ document.insertString(statementList.getTextRange().getStartOffset(), text);
+ manager.commitDocument(document);
+ statementList = ((PyStatementListContainer)container).getStatementList();
+ }
+ }
final PsiElement firstChild = statementList.getFirstChild();
if (firstChild == statementList.getLastChild() && firstChild instanceof PyPassStatement) {
element = firstChild.replace(element);
element = statementList.add(element);
}
}
- if (statementListWasEmpty) {
- final PsiDocumentManager documentManager = PsiDocumentManager.getInstance(statementList.getProject());
- final Document document = documentManager.getDocument(statementList.getContainingFile());
- if (document != null) {
- documentManager.doPostponedOperationsAndUnblockDocument(document);
- document.insertString(statementList.getTextOffset(), "\n" + PyIndentUtil.getElementIndent(statementList));
- documentManager.commitDocument(document);
- }
- }
return element;
}