IDEA-CR-11871 partially fix possible race condition
authorVladimir Krivosheev <vladimir.krivosheev@jetbrains.com>
Wed, 27 Jul 2016 15:02:49 +0000 (17:02 +0200)
committerVladimir Krivosheev <vladimir.krivosheev@jetbrains.com>
Wed, 27 Jul 2016 15:02:49 +0000 (17:02 +0200)
platform/platform-api/src/org/jetbrains/concurrency/AsyncPromise.kt

index 91f13db02c968c55bf5242e58b5b3ca5e34e85b4..03bb4a69e895ca521bf817e498f1a30f052c2dd2 100644 (file)
@@ -20,6 +20,7 @@ import com.intellij.openapi.util.Getter
 import com.intellij.util.Consumer
 import com.intellij.util.Function
 import java.util.*
+import java.util.concurrent.atomic.AtomicReference
 
 private val LOG = Logger.getInstance(AsyncPromise::class.java)
 
@@ -30,30 +31,30 @@ open class AsyncPromise<T> : Promise<T>(), Getter<T> {
   @Volatile private var done: Consumer<in T>? = null
   @Volatile private var rejected: Consumer<in Throwable>? = null
 
-  @Volatile private var state: Promise.State = Promise.State.PENDING
+  private val state = AtomicReference(Promise.State.PENDING)
 
   // result object or error message
   @Volatile private var result: Any? = null
 
-  override fun getState() = state
+  override fun getState() = state.get()!!
 
   override fun done(done: Consumer<in T>): Promise<T> {
     if (isObsolete(done)) {
       return this
     }
 
-    when (state) {
+    when (state.get()!!) {
       Promise.State.PENDING -> {
+        this.done = setHandler(this.done, done, State.FULFILLED)
       }
       Promise.State.FULFILLED -> {
         @Suppress("UNCHECKED_CAST")
         done.consume(result as T?)
-        return this
       }
-      Promise.State.REJECTED -> return this
+      Promise.State.REJECTED -> {
+      }
     }
 
-    this.done = setHandler(this.done, done)
     return this
   }
 
@@ -62,26 +63,26 @@ open class AsyncPromise<T> : Promise<T>(), Getter<T> {
       return this
     }
 
-    when (state) {
+    when (state.get()!!) {
       Promise.State.PENDING -> {
+        this.rejected = setHandler(this.rejected, rejected, State.REJECTED)
+      }
+      Promise.State.FULFILLED -> {
       }
-      Promise.State.FULFILLED -> return this
       Promise.State.REJECTED -> {
         rejected.consume(result as Throwable?)
-        return this
       }
     }
 
-    this.rejected = setHandler(this.rejected, rejected)
     return this
   }
 
   @Suppress("UNCHECKED_CAST")
-  override fun get() = if (state == Promise.State.FULFILLED) result as T? else null
+  override fun get() = if (state.get() == Promise.State.FULFILLED) result as T? else null
 
   override fun <SUB_RESULT> then(fulfilled: Function<in T, out SUB_RESULT>): Promise<SUB_RESULT> {
     @Suppress("UNCHECKED_CAST")
-    when (state) {
+    when (state.get()!!) {
       Promise.State.PENDING -> {
       }
       Promise.State.FULFILLED -> return DonePromise<SUB_RESULT>(fulfilled.`fun`(result as T?))
@@ -105,26 +106,23 @@ open class AsyncPromise<T> : Promise<T>(), Getter<T> {
   override fun notify(child: AsyncPromise<in T>) {
     LOG.assertTrue(child !== this)
 
-    when (state) {
+    when (state.get()!!) {
       Promise.State.PENDING -> {
+        addHandlers(Consumer({ child.catchError { child.setResult(it) } }), Consumer({ child.setError(it) }))
       }
       Promise.State.FULFILLED -> {
         @Suppress("UNCHECKED_CAST")
         child.setResult(result as T)
-        return
       }
       Promise.State.REJECTED -> {
         child.setError((result as Throwable?)!!)
-        return
       }
     }
-
-    addHandlers(Consumer({ child.catchError { child.setResult(it) } }), Consumer({ child.setError(it) }))
   }
 
   override fun <SUB_RESULT> thenAsync(fulfilled: Function<in T, Promise<SUB_RESULT>>): Promise<SUB_RESULT> {
     @Suppress("UNCHECKED_CAST")
-    when (state) {
+    when (state.get()!!) {
       Promise.State.PENDING -> {
       }
       Promise.State.FULFILLED -> return fulfilled.`fun`(result as T?)
@@ -144,36 +142,32 @@ open class AsyncPromise<T> : Promise<T>(), Getter<T> {
   }
 
   override fun processed(fulfilled: AsyncPromise<in T>): Promise<T> {
-    when (state) {
+    when (state.get()!!) {
       Promise.State.PENDING -> {
+        addHandlers(Consumer({ result -> fulfilled.catchError { fulfilled.setResult(result) } }), Consumer({ fulfilled.setError(it) }))
       }
       Promise.State.FULFILLED -> {
         @Suppress("UNCHECKED_CAST")
         fulfilled.setResult(result as T)
-        return this
       }
       Promise.State.REJECTED -> {
         fulfilled.setError((result as Throwable?)!!)
-        return this
       }
     }
-
-    addHandlers(Consumer({ result -> fulfilled.catchError { fulfilled.setResult(result) } }), Consumer({ fulfilled.setError(it) }))
     return this
   }
 
   private fun addHandlers(done: Consumer<T>, rejected: Consumer<Throwable>) {
-    this.done = setHandler(this.done, done)
-    this.rejected = setHandler(this.rejected, rejected)
+    this.done = setHandler(this.done, done, State.FULFILLED)
+    this.rejected = setHandler(this.rejected, rejected, State.REJECTED)
   }
 
   fun setResult(result: T?) {
-    if (state != Promise.State.PENDING) {
+    if (!state.compareAndSet(Promise.State.PENDING, Promise.State.FULFILLED)) {
       return
     }
 
     this.result = result
-    state = Promise.State.FULFILLED
 
     val done = this.done
     clearHandlers()
@@ -182,21 +176,18 @@ open class AsyncPromise<T> : Promise<T>(), Getter<T> {
     }
   }
 
-  fun setError(error: String): Boolean {
-    return setError(Promise.createError(error))
-  }
+  fun setError(error: String) = setError(Promise.createError(error))
 
   fun cancel() {
     setError(OBSOLETE_ERROR)
   }
 
   open fun setError(error: Throwable): Boolean {
-    if (state != Promise.State.PENDING) {
+    if (!state.compareAndSet(Promise.State.PENDING, Promise.State.REJECTED)) {
       return false
     }
 
     result = error
-    state = Promise.State.REJECTED
 
     val rejected = this.rejected
     clearHandlers()
@@ -216,13 +207,38 @@ open class AsyncPromise<T> : Promise<T>(), Getter<T> {
 
   override fun processed(processed: Consumer<in T>): Promise<T> {
     done(processed)
-    rejected({ error -> processed.consume(null) })
+    rejected { processed.consume(null) }
     return this
   }
+
+  private fun <T> setHandler(oldConsumer: Consumer<in T>?, newConsumer: Consumer<in T>, targetState: State): Consumer<in T>? = when (oldConsumer) {
+    null -> newConsumer
+    is CompoundConsumer<*> -> {
+      @Suppress("UNCHECKED_CAST")
+      val compoundConsumer = oldConsumer as CompoundConsumer<T>
+      synchronized(compoundConsumer) {
+        compoundConsumer.consumers.let {
+          if (it == null) {
+            // clearHandlers was called - just execute newConsumer
+            if (state.get() == targetState) {
+              @Suppress("UNCHECKED_CAST")
+              newConsumer.consume(result as T?)
+            }
+            return null
+          }
+          else {
+            it.add(newConsumer)
+            return compoundConsumer
+          }
+        }
+      }
+    }
+    else -> CompoundConsumer(oldConsumer, newConsumer)
+  }
 }
 
 private class CompoundConsumer<T>(c1: Consumer<in T>, c2: Consumer<in T>) : Consumer<T> {
-  private var consumers: MutableList<Consumer<in T>>? = ArrayList()
+  var consumers: MutableList<Consumer<in T>>? = ArrayList()
 
   init {
     synchronized(this) {
@@ -247,23 +263,16 @@ private class CompoundConsumer<T>(c1: Consumer<in T>, c2: Consumer<in T>) : Cons
 
   fun add(consumer: Consumer<in T>) {
     synchronized(this) {
-      if (consumers != null) {
-        consumers!!.add(consumer)
+      consumers.let {
+        if (it == null) {
+          // it means that clearHandlers was called
+        }
+        consumers?.add(consumer)
       }
     }
   }
 }
 
-private fun <T> setHandler(oldConsumer: Consumer<in T>?, newConsumer: Consumer<in T>) = when (oldConsumer) {
-  null -> newConsumer
-  is CompoundConsumer<*> -> {
-    @Suppress("UNCHECKED_CAST")
-    (oldConsumer as CompoundConsumer<T>).add(newConsumer)
-    oldConsumer
-  }
-  else -> CompoundConsumer(oldConsumer, newConsumer)
-}
-
 internal fun isObsolete(consumer: Consumer<*>?) = consumer is Obsolescent && consumer.isObsolete
 
 inline fun <T> AsyncPromise<*>.catchError(runnable: () -> T): T? {