Client should release the buffer as soon as possible.
[idea/community.git] / platform / built-in-server / testSrc / BinaryRequestHandlerTest.kt
1 package org.jetbrains.ide
2
3 import io.netty.channel.ChannelInitializer
4 import io.netty.channel.Channel
5 import org.jetbrains.io.Decoder
6 import io.netty.channel.ChannelHandlerContext
7 import io.netty.buffer.ByteBuf
8 import com.intellij.util.Consumer
9 import java.util.UUID
10 import io.netty.channel.ChannelHandler
11 import io.netty.util.CharsetUtil
12 import org.jetbrains.io.ChannelExceptionHandler
13 import org.jetbrains.io.NettyUtil
14 import com.intellij.util.net.NetUtils
15 import io.netty.buffer.Unpooled
16 import junit.framework.TestCase
17 import org.junit.rules.RuleChain
18 import org.junit.Rule
19 import org.junit.Test
20 import org.jetbrains.io.MessageDecoder
21 import org.jetbrains.concurrency.AsyncPromise
22 import org.jetbrains.concurrency.Promise
23 import com.intellij.util.concurrency.Semaphore
24
/**
 * End-to-end test of binary request dispatch in the built-in server.
 *
 * The client connects to the built-in server and sends: the bytes 'C' and 'H', the UUID of [MyBinaryRequestHandler]
 * (most significant bits first), a 2-byte length and that many bytes of UTF-8 text.
 * [MyBinaryRequestHandler] replies with "got-" followed by the received text, and the client checks the reply.
 *
 * Strings are intentionally not handled in the most efficient way, because the test also exercises
 * `readContent` (client-side [Decoder]) and `readChars` (server-side [MessageDecoder]).
 */
public class BinaryRequestHandlerTest {
  private val fixtureManager = FixtureRule()

  private val _chain = RuleChain
      .outerRule(fixtureManager)

  Rule
  public fun getChain(): RuleChain = _chain

  Test
  public fun test() {
    val text = "Hello!"
    // Single source of truth for the reply: sizes the client-side read and is used in the final assertion.
    val expectedResponse = "got-" + text
    // Byte length, not char length: the reply arrives as UTF-8, so a char count would mis-size non-ASCII text.
    val responseLength = CharsetUtil.UTF_8.encode(expectedResponse).remaining()
    val result = AsyncPromise<String>()

    val bootstrap = NettyUtil.oioClientBootstrap().handler(object : ChannelInitializer<Channel>() {
      override fun initChannel(channel: Channel) {
        channel.pipeline().addLast(object : Decoder() {
          override fun messageReceived(context: ChannelHandlerContext, input: ByteBuf) {
            // readContent returns null until responseLength bytes are available
            val response = readContent(input, context, responseLength) {(buffer, context, isCumulateBuffer) -> buffer.toString(buffer.readerIndex(), responseLength, CharsetUtil.UTF_8) }
            if (response != null) {
              result.setResult(response)
            }
          }
        }, ChannelExceptionHandler.getInstance())
      }
    })

    val port = BuiltInServerManager.getInstance().waitForStart().getPort()
    val channel = bootstrap.connect(NetUtils.getLoopbackAddress(), port).syncUninterruptibly().channel()
    // Everything after a successful connect runs inside try, so the channel is closed even if sending fails.
    try {
      val buffer = channel.alloc().buffer()
      buffer.writeByte('C'.toInt())
      buffer.writeByte('H'.toInt())
      buffer.writeLong(MyBinaryRequestHandler.ID.getMostSignificantBits())
      buffer.writeLong(MyBinaryRequestHandler.ID.getLeastSignificantBits())

      val message = Unpooled.copiedBuffer(text, CharsetUtil.UTF_8)
      // length prefix, read back on the server with readUnsignedShort()
      buffer.writeShort(message.readableBytes())
      channel.write(buffer)
      channel.writeAndFlush(message).syncUninterruptibly()

      // NOTE(review): this callback may run on the thread that rejects the promise (not the test thread);
      // an AssertionError thrown there may not reach JUnit — confirm against AsyncPromise's callback threading.
      result.rejected(object : Consumer<Throwable> {
        override fun consume(error: Throwable) {
          TestCase.fail(error.getMessage())
        }
      })

      if (result.getState() == Promise.State.PENDING) {
        val semaphore = Semaphore()
        semaphore.down()
        result.processed { semaphore.up() }
        if (!semaphore.waitForUnsafe(5000)) {
          // fail() always throws, so no explicit return is needed
          TestCase.fail("Time limit exceeded")
        }
      }

      TestCase.assertEquals(expectedResponse, result.get())
    }
    finally {
      channel.close()
    }
  }

  /**
   * Test handler identified by [ID]: for every length-prefixed text it receives, it writes back "got-" + text.
   */
  class MyBinaryRequestHandler : BinaryRequestHandler() {
    class object {
      val ID = UUID.fromString("E5068DD6-1DB7-437C-A3FC-3CA53B6E1AC9")
    }

    override fun getId(): UUID {
      return ID
    }

    override fun getInboundHandler(context: ChannelHandlerContext): ChannelHandler {
      return MyDecoder()
    }

    /**
     * Two-state decoder: HEADER reads a 2-byte unsigned length into `contentLength`,
     * CONTENT reads the text via `readChars` and replies. It loops until not enough input is buffered.
     */
    private class MyDecoder : MessageDecoder() {
      private var state = State.HEADER

      private enum class State {
        HEADER
        CONTENT
      }

      override fun messageReceived(context: ChannelHandlerContext, input: ByteBuf) {
        while (true) {
          when (state) {
            State.HEADER -> {
              // null means fewer than 2 bytes are available yet; wait for more input
              val buffer = getBufferIfSufficient(input, 2, context)
              if (buffer == null) {
                return
              }

              contentLength = buffer.readUnsignedShort()
              state = State.CONTENT
            }

            State.CONTENT -> {
              // presumably reads contentLength bytes (defined in MessageDecoder); null means incomplete
              val messageText = readChars(input)
              if (messageText == null) {
                return
              }

              state = State.HEADER
              context.writeAndFlush(Unpooled.copiedBuffer("got-" + messageText, CharsetUtil.UTF_8))
            }
          }
        }
      }
    }
  }
}