a32c4f8ed4b81794e8f3305fdb79c82902479601
[idea/community.git] / platform / projectModel-api / src / org / jetbrains / concurrency / AsyncPromise.kt
1 /*
2  * Copyright 2000-2016 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package org.jetbrains.concurrency
17
18 import com.intellij.openapi.diagnostic.Logger
19 import com.intellij.openapi.util.Getter
20 import com.intellij.util.Consumer
21 import com.intellij.util.Function
22 import org.jetbrains.concurrency.Promise.State
23 import java.util.*
24 import java.util.concurrent.CountDownLatch
25 import java.util.concurrent.TimeUnit
26 import java.util.concurrent.TimeoutException
27 import java.util.concurrent.atomic.AtomicReference
28
29 private val LOG = Logger.getInstance(AsyncPromise::class.java)
30
31 open class AsyncPromise<T> : Promise<T>, Getter<T> {
32   private val doneRef = AtomicReference<Consumer<in T>?>()
33   private val rejectedRef = AtomicReference<Consumer<in Throwable>?>()
34
35   private val stateRef = AtomicReference(State.PENDING)
36
37   // result object or error message
38   @Volatile private var result: Any? = null
39
40   override fun getState() = stateRef.get()!!
41
42   override fun done(done: Consumer<in T>): Promise<T> {
43     setHandler(doneRef, done, State.FULFILLED)
44     return this
45   }
46
47   override fun rejected(rejected: Consumer<Throwable>): Promise<T> {
48     setHandler(rejectedRef, rejected, State.REJECTED)
49     return this
50   }
51
52   @Suppress("UNCHECKED_CAST")
53   override fun get() = if (state == State.FULFILLED) result as T? else null
54
55   override fun <SUB_RESULT> then(handler: Function<in T, out SUB_RESULT>): Promise<SUB_RESULT> {
56     @Suppress("UNCHECKED_CAST")
57     when (state) {
58       State.PENDING -> {
59       }
60       State.FULFILLED -> return DonePromise<SUB_RESULT>(handler.`fun`(result as T?))
61       State.REJECTED -> return rejectedPromise(result as Throwable)
62     }
63
64     val promise = AsyncPromise<SUB_RESULT>()
65     addHandlers(Consumer({ result ->
66                            promise.catchError {
67                              if (handler is Obsolescent && handler.isObsolete) {
68                                promise.cancel()
69                              }
70                              else {
71                                promise.setResult(handler.`fun`(result))
72                              }
73                            }
74                          }), Consumer({ promise.setError(it) }))
75     return promise
76   }
77
78   override fun notify(child: AsyncPromise<in T>) {
79     LOG.assertTrue(child !== this)
80
81     when (state) {
82       State.PENDING -> {
83         addHandlers(Consumer({ child.catchError { child.setResult(it) } }), Consumer({ child.setError(it) }))
84       }
85       State.FULFILLED -> {
86         @Suppress("UNCHECKED_CAST")
87         child.setResult(result as T)
88       }
89       State.REJECTED -> {
90         child.setError((result as Throwable?)!!)
91       }
92     }
93   }
94
95   override fun <SUB_RESULT> thenAsync(handler: Function<in T, Promise<SUB_RESULT>>): Promise<SUB_RESULT> {
96     @Suppress("UNCHECKED_CAST")
97     when (state) {
98       State.PENDING -> {
99       }
100       State.FULFILLED -> return handler.`fun`(result as T?)
101       State.REJECTED -> return rejectedPromise(result as Throwable)
102     }
103
104     val promise = AsyncPromise<SUB_RESULT>()
105     val rejectedHandler = Consumer<Throwable>({ promise.setError(it) })
106     addHandlers(Consumer({
107                            promise.catchError {
108                              handler.`fun`(it)
109                                  .done { promise.catchError { promise.setResult(it) } }
110                                  .rejected(rejectedHandler)
111                            }
112                          }), rejectedHandler)
113     return promise
114   }
115
116   override fun processed(fulfilled: AsyncPromise<in T>): Promise<T> {
117     when (state) {
118       State.PENDING -> {
119         addHandlers(Consumer({ result -> fulfilled.catchError { fulfilled.setResult(result) } }), Consumer({ fulfilled.setError(it) }))
120       }
121       State.FULFILLED -> {
122         @Suppress("UNCHECKED_CAST")
123         fulfilled.setResult(result as T)
124       }
125       State.REJECTED -> {
126         fulfilled.setError((result as Throwable?)!!)
127       }
128     }
129     return this
130   }
131
132   private fun addHandlers(done: Consumer<T>, rejected: Consumer<Throwable>) {
133     setHandler(doneRef, done, State.FULFILLED)
134     setHandler(rejectedRef, rejected, State.REJECTED)
135   }
136
137   fun setResult(result: T?) {
138     if (!stateRef.compareAndSet(State.PENDING, State.FULFILLED)) {
139       return
140     }
141
142     this.result = result
143
144     val done = doneRef.getAndSet(null)
145     rejectedRef.set(null)
146
147     if (done != null && !isObsolete(done)) {
148       done.consume(result)
149     }
150   }
151
152   fun setError(error: String) = setError(createError(error))
153
154   fun cancel() {
155     setError(OBSOLETE_ERROR)
156   }
157
158   open fun setError(error: Throwable): Boolean {
159     if (!stateRef.compareAndSet(State.PENDING, State.REJECTED)) {
160       LOG.errorIfNotMessage(error)
161       return false
162     }
163
164     result = error
165
166     val rejected = rejectedRef.getAndSet(null)
167     doneRef.set(null)
168
169     if (rejected == null) {
170       LOG.errorIfNotMessage(error)
171     }
172     else if (!isObsolete(rejected)) {
173       rejected.consume(error)
174     }
175     return true
176   }
177
178   override fun processed(processed: Consumer<in T>): Promise<T> {
179     done(processed)
180     rejected { processed.consume(null) }
181     return this
182   }
183
184   override fun blockingGet(timeout: Int, timeUnit: TimeUnit): T? {
185     val latch = CountDownLatch(1)
186     processed { latch.countDown() }
187     if (!latch.await(timeout.toLong(), timeUnit)) {
188       throw TimeoutException()
189     }
190
191     @Suppress("UNCHECKED_CAST")
192     if (isRejected) {
193       throw (result as Throwable)
194     }
195     else {
196       return result as T?
197     }
198   }
199
200   private fun <T> setHandler(ref: AtomicReference<Consumer<in T>?>, newConsumer: Consumer<in T>, targetState: State) {
201     if (isObsolete(newConsumer)) {
202       return
203     }
204
205     if (state != State.PENDING) {
206       if (state == targetState) {
207         @Suppress("UNCHECKED_CAST")
208         newConsumer.consume(result as T?)
209       }
210       return
211     }
212
213     while (true) {
214       val oldConsumer = ref.get()
215       val newEffectiveConsumer = when (oldConsumer) {
216         null -> newConsumer
217         is CompoundConsumer<*> -> {
218           @Suppress("UNCHECKED_CAST")
219           val compoundConsumer = oldConsumer as CompoundConsumer<T>
220           var executed = true
221           synchronized(compoundConsumer) {
222             compoundConsumer.consumers?.let {
223               it.add(newConsumer)
224               executed = false
225             }
226           }
227
228           // clearHandlers was called - just execute newConsumer
229           if (executed) {
230             if (state == targetState) {
231               @Suppress("UNCHECKED_CAST")
232               newConsumer.consume(result as T?)
233             }
234             return
235           }
236
237           compoundConsumer
238         }
239         else -> CompoundConsumer(oldConsumer, newConsumer)
240       }
241
242       if (ref.compareAndSet(oldConsumer, newEffectiveConsumer)) {
243         break
244       }
245     }
246
247     if (state == targetState) {
248       ref.getAndSet(null)?.let {
249         @Suppress("UNCHECKED_CAST")
250         it.consume(result as T?)
251       }
252     }
253   }
254 }
255
256 private class CompoundConsumer<T>(c1: Consumer<in T>, c2: Consumer<in T>) : Consumer<T> {
257   var consumers: MutableList<Consumer<in T>>? = ArrayList()
258
259   init {
260     synchronized(this) {
261       consumers!!.add(c1)
262       consumers!!.add(c2)
263     }
264   }
265
266   override fun consume(t: T) {
267     val list = synchronized(this) {
268       val list = consumers
269       consumers = null
270       list
271     } ?: return
272
273     for (consumer in list) {
274       if (!isObsolete(consumer)) {
275         consumer.consume(t)
276       }
277     }
278   }
279 }
280
281 internal fun isObsolete(consumer: Consumer<*>?) = consumer is Obsolescent && consumer.isObsolete
282
283 inline fun <T> AsyncPromise<*>.catchError(runnable: () -> T): T? {
284   try {
285     return runnable()
286   }
287   catch (e: Throwable) {
288     setError(e)
289     return null
290   }
291 }