PY-20026 PyClassImpl#mroLinearize() properly handles duplicate base classes
authorMikhail Golubev <mikhail.golubev@jetbrains.com>
Wed, 13 Jul 2016 14:16:32 +0000 (17:16 +0300)
committerMikhail Golubev <mikhail.golubev@jetbrains.com>
Fri, 15 Jul 2016 15:39:11 +0000 (18:39 +0300)
The exception could happen in this case because we used the same cached
result of MRO linearization twice without defencive copying. Then later,
as a side effect of that, in mroMerge() we deleted one "head" from
several sequences simultaneously, hence the IndexOutOfBoundsException.

python/src/com/jetbrains/python/psi/impl/PyClassImpl.java
python/testData/codeInsight/classMRO/DuplicatedBaseClasses.py [new file with mode: 0644]
python/testSrc/com/jetbrains/python/codeInsight/PyClassMROTest.java

index e70052348e0fcda7856f23db8201a20c305f4d06..928d60736ab9b382e1e0e2844387651d8f605a2e 100644 (file)
@@ -437,16 +437,17 @@ public class PyClassImpl extends PyBaseElementImpl<PyClassStub> implements PyCla
       }
       return computed.get();
     }
-    cache.put(type, Ref.<List<PyClassLikeType>>create());
+    cache.put(type, Ref.create());
     List<PyClassLikeType> result = null;
     try {
-      final List<PyClassLikeType> bases = type.getSuperClassTypes(context);
-      final List<List<PyClassLikeType>> lines = new ArrayList<List<PyClassLikeType>>();
+      final List<PyClassLikeType> bases = removeNotNullDuplicates(type.getSuperClassTypes(context));
+      final List<List<PyClassLikeType>> lines = new ArrayList<>();
       for (PyClassLikeType base : bases) {
         if (base != null) {
           final List<PyClassLikeType> baseClassMRO = mroLinearize(base, true, context, cache);
           if (!baseClassMRO.isEmpty()) {
-            lines.add(baseClassMRO);
+            // mroMerge() updates passed MRO lists internally
+            lines.add(new LinkedList<>(baseClassMRO));
           }
         }
       }
@@ -457,6 +458,7 @@ public class PyClassImpl extends PyBaseElementImpl<PyClassStub> implements PyCla
       if (addThisType) {
         result.add(0, type);
       }
+      result = Collections.unmodifiableList(result);
     }
     finally {
       cache.put(type, Ref.create(result));
@@ -464,6 +466,22 @@ public class PyClassImpl extends PyBaseElementImpl<PyClassStub> implements PyCla
     return result;
   }
 
+  @NotNull
+  private static <T> List<T> removeNotNullDuplicates(@NotNull List<T> list) {
+    final Set<T> distinct = new HashSet<>();
+    final List<T> result = new ArrayList<>();
+    for (T elem : list) {
+      if (elem != null) {
+        final boolean isUnique = distinct.add(elem);
+        if (!isUnique) {
+          continue;
+        }
+      }
+      result.add(elem);
+    }
+    return result;
+  }
+
   @Override
   @NotNull
   public PyFunction[] getMethods() {
@@ -1366,11 +1384,11 @@ public class PyClassImpl extends PyBaseElementImpl<PyClassStub> implements PyCla
     PyPsiUtils.assertValid(this);
     final PyType thisType = context.getType(this);
     if (thisType instanceof PyClassLikeType) {
-      final PyClassLikeType thisClassLikeType = (PyClassLikeType)thisType;
-      final List<PyClassLikeType> ancestorTypes =
-        mroLinearize(thisClassLikeType, false, context, new HashMap<PyClassLikeType, Ref<List<PyClassLikeType>>>());
+      final List<PyClassLikeType> ancestorTypes = mroLinearize((PyClassLikeType)thisType, false, context, new HashMap<>());
       if (isOverriddenMRO(ancestorTypes, context)) {
-        ancestorTypes.add(null);
+        final ArrayList<PyClassLikeType> withNull = new ArrayList<>(ancestorTypes);
+        withNull.add(null);
+        return withNull;
       }
       return ancestorTypes;
     }
diff --git a/python/testData/codeInsight/classMRO/DuplicatedBaseClasses.py b/python/testData/codeInsight/classMRO/DuplicatedBaseClasses.py
new file mode 100644 (file)
index 0000000..dd481f9
--- /dev/null
@@ -0,0 +1,6 @@
+class Base(object):
+    pass
+
+
+class MyClass(Base, Base):
+    pass
index 5d053da791b116f569d74096f2773bb8cd65edfc..f82cf215e5fbe264891a69222e617b1c5612f6b8 100644 (file)
@@ -110,6 +110,11 @@ public class PyClassMROTest extends PyTestCase {
     assertOrderedEquals(classNames, Arrays.asList(mro));
   }
 
+  // PY-20026
+  public void testDuplicatedBaseClasses() {
+    assertMRO(getClass("MyClass"), "Base", "object");
+  }
+
   @NotNull
   public PyClass getClass(@NotNull String name) {
     myFixture.configureByFile(getPath(getTestName(false)));