0659a9c0e9a32fad0e10db4d4f47a9bd9681807d
[idea/community.git] / platform / built-in-server / testSrc / BinaryRequestHandlerTest.kt
1 package org.jetbrains.ide
2
3 import io.netty.channel.ChannelInitializer
4 import io.netty.channel.Channel
5 import org.jetbrains.io.Decoder
6 import io.netty.channel.ChannelHandlerContext
7 import io.netty.buffer.ByteBuf
8 import com.intellij.util.Consumer
9 import java.util.UUID
10 import io.netty.channel.ChannelHandler
11 import com.intellij.openapi.util.AsyncResult
12 import io.netty.util.CharsetUtil
13 import org.jetbrains.io.ChannelExceptionHandler
14 import org.jetbrains.io.NettyUtil
15 import com.intellij.util.net.NetUtils
16 import io.netty.buffer.Unpooled
17 import junit.framework.TestCase
18 import org.junit.rules.RuleChain
19 import org.junit.Rule
20 import org.junit.Test
21 import org.jetbrains.io.MessageDecoder
22
23 public class BinaryRequestHandlerTest {
24   private val fixtureManager = FixtureRule()
25
26   private val _chain = RuleChain
27       .outerRule(fixtureManager)
28
29   Rule
30   public fun getChain(): RuleChain = _chain
31
32   Test
33   public fun test() {
34     val text = "Hello!"
35     val result = AsyncResult<String>()
36
37     val bootstrap = NettyUtil.oioClientBootstrap().handler(object : ChannelInitializer<Channel>() {
38       override fun initChannel(channel: Channel) {
39         channel.pipeline().addLast(object : Decoder() {
40           override fun messageReceived(context: ChannelHandlerContext, input: ByteBuf) {
41             val requiredLength = 4 + text.length()
42             val response = readContent(input, context, requiredLength) {(buffer, context, isCumulateBuffer) -> buffer.toString(buffer.readerIndex(), requiredLength, CharsetUtil.UTF_8) }
43             if (response != null) {
44               result.setDone(response)
45             }
46           }
47         }, ChannelExceptionHandler.getInstance())
48       }
49     })
50
51     val port = BuiltInServerManager.getInstance().waitForStart().getPort()
52     val channel = bootstrap.connect(NetUtils.getLoopbackAddress(), port).syncUninterruptibly().channel()
53     val buffer = channel.alloc().buffer()
54     buffer.writeByte('C'.toInt())
55     buffer.writeByte('H'.toInt())
56     buffer.writeLong(MyBinaryRequestHandler.ID.getMostSignificantBits())
57     buffer.writeLong(MyBinaryRequestHandler.ID.getLeastSignificantBits())
58
59     val message = Unpooled.copiedBuffer(text, CharsetUtil.UTF_8)
60     buffer.writeShort(message.readableBytes())
61
62     channel.write(buffer)
63     channel.writeAndFlush(message).syncUninterruptibly()
64
65     try {
66       result.doWhenRejected(object : Consumer<String> {
67         override fun consume(error: String) {
68           TestCase.fail(error)
69         }
70       })
71
72       TestCase.assertEquals("got-" + text, result.getResultSync(5000))
73     }
74     finally {
75       channel.close()
76     }
77   }
78
79   class MyBinaryRequestHandler : BinaryRequestHandler() {
80     class object {
81       val ID = UUID.fromString("E5068DD6-1DB7-437C-A3FC-3CA53B6E1AC9")
82     }
83
84     override fun getId(): UUID {
85       return ID
86     }
87
88     override fun getInboundHandler(context: ChannelHandlerContext): ChannelHandler {
89       return MyDecoder()
90     }
91
92     private class MyDecoder : MessageDecoder() {
93       private var state = State.HEADER
94
95       private enum class State {
96         HEADER
97         CONTENT
98       }
99
100       override fun messageReceived(context: ChannelHandlerContext, input: ByteBuf) {
101         while (true) {
102           when (state) {
103             State.HEADER -> {
104               val buffer = getBufferIfSufficient(input, 2, context)
105               if (buffer == null) {
106                 return
107               }
108
109               contentLength = buffer.readUnsignedShort()
110               state = State.CONTENT
111             }
112
113             State.CONTENT -> {
114               val messageText = readChars(input)
115               if (messageText == null) {
116                 return
117               }
118
119               state = State.HEADER
120               context.writeAndFlush(Unpooled.copiedBuffer("got-" + messageText, CharsetUtil.UTF_8))
121             }
122           }
123         }
124       }
125     }
126   }
127 }