blockingGet — do not call await if not pending
[idea/community.git] / platform / projectModel-api / src / org / jetbrains / concurrency / AsyncPromise.kt
1 /*
2  * Copyright 2000-2016 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package org.jetbrains.concurrency
17
18 import com.intellij.openapi.diagnostic.Logger
19 import com.intellij.openapi.util.Getter
20 import com.intellij.util.Consumer
21 import com.intellij.util.Function
22 import org.jetbrains.concurrency.Promise.State
23 import java.util.*
24 import java.util.concurrent.CountDownLatch
25 import java.util.concurrent.TimeUnit
26 import java.util.concurrent.TimeoutException
27 import java.util.concurrent.atomic.AtomicReference
28
29 private val LOG = Logger.getInstance(AsyncPromise::class.java)
30
31 open class AsyncPromise<T> : Promise<T>, Getter<T> {
32   private val doneRef = AtomicReference<Consumer<in T>?>()
33   private val rejectedRef = AtomicReference<Consumer<in Throwable>?>()
34
35   private val stateRef = AtomicReference(State.PENDING)
36
37   // result object or error message
38   @Volatile private var result: Any? = null
39
40   override fun getState() = stateRef.get()!!
41
42   override fun done(done: Consumer<in T>): Promise<T> {
43     setHandler(doneRef, done, State.FULFILLED)
44     return this
45   }
46
47   override fun rejected(rejected: Consumer<Throwable>): Promise<T> {
48     setHandler(rejectedRef, rejected, State.REJECTED)
49     return this
50   }
51
52   @Suppress("UNCHECKED_CAST")
53   override fun get() = if (state == State.FULFILLED) result as T? else null
54
55   override fun <SUB_RESULT> then(handler: Function<in T, out SUB_RESULT>): Promise<SUB_RESULT> {
56     @Suppress("UNCHECKED_CAST")
57     when (state) {
58       State.PENDING -> {
59       }
60       State.FULFILLED -> return DonePromise<SUB_RESULT>(handler.`fun`(result as T?))
61       State.REJECTED -> return rejectedPromise(result as Throwable)
62     }
63
64     val promise = AsyncPromise<SUB_RESULT>()
65     addHandlers(Consumer({ result ->
66                            promise.catchError {
67                              if (handler is Obsolescent && handler.isObsolete) {
68                                promise.cancel()
69                              }
70                              else {
71                                promise.setResult(handler.`fun`(result))
72                              }
73                            }
74                          }), Consumer({ promise.setError(it) }))
75     return promise
76   }
77
78   override fun notify(child: AsyncPromise<in T>) {
79     LOG.assertTrue(child !== this)
80
81     when (state) {
82       State.PENDING -> {
83         addHandlers(Consumer({ child.catchError { child.setResult(it) } }), Consumer({ child.setError(it) }))
84       }
85       State.FULFILLED -> {
86         @Suppress("UNCHECKED_CAST")
87         child.setResult(result as T)
88       }
89       State.REJECTED -> {
90         child.setError((result as Throwable?)!!)
91       }
92     }
93   }
94
95   override fun <SUB_RESULT> thenAsync(handler: Function<in T, Promise<SUB_RESULT>>): Promise<SUB_RESULT> {
96     @Suppress("UNCHECKED_CAST")
97     when (state) {
98       State.PENDING -> {
99       }
100       State.FULFILLED -> return handler.`fun`(result as T?)
101       State.REJECTED -> return rejectedPromise(result as Throwable)
102     }
103
104     val promise = AsyncPromise<SUB_RESULT>()
105     val rejectedHandler = Consumer<Throwable>({ promise.setError(it) })
106     addHandlers(Consumer({
107                            promise.catchError {
108                              handler.`fun`(it)
109                                  .done { promise.catchError { promise.setResult(it) } }
110                                  .rejected(rejectedHandler)
111                            }
112                          }), rejectedHandler)
113     return promise
114   }
115
116   override fun processed(fulfilled: AsyncPromise<in T>): Promise<T> {
117     when (state) {
118       State.PENDING -> {
119         addHandlers(Consumer({ result -> fulfilled.catchError { fulfilled.setResult(result) } }), Consumer({ fulfilled.setError(it) }))
120       }
121       State.FULFILLED -> {
122         @Suppress("UNCHECKED_CAST")
123         fulfilled.setResult(result as T)
124       }
125       State.REJECTED -> {
126         fulfilled.setError((result as Throwable?)!!)
127       }
128     }
129     return this
130   }
131
132   private fun addHandlers(done: Consumer<T>, rejected: Consumer<Throwable>) {
133     setHandler(doneRef, done, State.FULFILLED)
134     setHandler(rejectedRef, rejected, State.REJECTED)
135   }
136
137   fun setResult(result: T?) {
138     if (!stateRef.compareAndSet(State.PENDING, State.FULFILLED)) {
139       return
140     }
141
142     this.result = result
143
144     val done = doneRef.getAndSet(null)
145     rejectedRef.set(null)
146
147     if (done != null && !isObsolete(done)) {
148       done.consume(result)
149     }
150   }
151
152   fun setError(error: String) = setError(createError(error))
153
154   fun cancel() {
155     setError(OBSOLETE_ERROR)
156   }
157
158   open fun setError(error: Throwable): Boolean {
159     if (!stateRef.compareAndSet(State.PENDING, State.REJECTED)) {
160       LOG.errorIfNotMessage(error)
161       return false
162     }
163
164     result = error
165
166     val rejected = rejectedRef.getAndSet(null)
167     doneRef.set(null)
168
169     if (rejected == null) {
170       LOG.errorIfNotMessage(error)
171     }
172     else if (!isObsolete(rejected)) {
173       rejected.consume(error)
174     }
175     return true
176   }
177
178   override fun processed(processed: Consumer<in T>): Promise<T> {
179     done(processed)
180     rejected { processed.consume(null) }
181     return this
182   }
183
184   override fun blockingGet(timeout: Int, timeUnit: TimeUnit): T? {
185     if (isPending) {
186       val latch = CountDownLatch(1)
187       processed { latch.countDown() }
188       if (!latch.await(timeout.toLong(), timeUnit)) {
189         throw TimeoutException()
190       }
191     }
192
193     @Suppress("UNCHECKED_CAST")
194     if (isRejected) {
195       throw (result as Throwable)
196     }
197     else {
198       return result as T?
199     }
200   }
201
202   private fun <T> setHandler(ref: AtomicReference<Consumer<in T>?>, newConsumer: Consumer<in T>, targetState: State) {
203     if (isObsolete(newConsumer)) {
204       return
205     }
206
207     if (state != State.PENDING) {
208       if (state == targetState) {
209         @Suppress("UNCHECKED_CAST")
210         newConsumer.consume(result as T?)
211       }
212       return
213     }
214
215     while (true) {
216       val oldConsumer = ref.get()
217       val newEffectiveConsumer = when (oldConsumer) {
218         null -> newConsumer
219         is CompoundConsumer<*> -> {
220           @Suppress("UNCHECKED_CAST")
221           val compoundConsumer = oldConsumer as CompoundConsumer<T>
222           var executed = true
223           synchronized(compoundConsumer) {
224             compoundConsumer.consumers?.let {
225               it.add(newConsumer)
226               executed = false
227             }
228           }
229
230           // clearHandlers was called - just execute newConsumer
231           if (executed) {
232             if (state == targetState) {
233               @Suppress("UNCHECKED_CAST")
234               newConsumer.consume(result as T?)
235             }
236             return
237           }
238
239           compoundConsumer
240         }
241         else -> CompoundConsumer(oldConsumer, newConsumer)
242       }
243
244       if (ref.compareAndSet(oldConsumer, newEffectiveConsumer)) {
245         break
246       }
247     }
248
249     if (state == targetState) {
250       ref.getAndSet(null)?.let {
251         @Suppress("UNCHECKED_CAST")
252         it.consume(result as T?)
253       }
254     }
255   }
256 }
257
258 private class CompoundConsumer<T>(c1: Consumer<in T>, c2: Consumer<in T>) : Consumer<T> {
259   var consumers: MutableList<Consumer<in T>>? = ArrayList()
260
261   init {
262     synchronized(this) {
263       consumers!!.add(c1)
264       consumers!!.add(c2)
265     }
266   }
267
268   override fun consume(t: T) {
269     val list = synchronized(this) {
270       val list = consumers
271       consumers = null
272       list
273     } ?: return
274
275     for (consumer in list) {
276       if (!isObsolete(consumer)) {
277         consumer.consume(t)
278       }
279     }
280   }
281 }
282
283 internal fun isObsolete(consumer: Consumer<*>?) = consumer is Obsolescent && consumer.isObsolete
284
285 inline fun <T> AsyncPromise<*>.catchError(runnable: () -> T): T? {
286   try {
287     return runnable()
288   }
289   catch (e: Throwable) {
290     setError(e)
291     return null
292   }
293 }