ae8746784999c48b97644ca76898f5f30538768e
[idea/community.git] / update-server-mock / src / main / java / org / jetbrains / updater / mock / Server.kt
1 /*
2  * Copyright (c) 2016 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package org.jetbrains.updater.mock
17
18 import com.sun.net.httpserver.Filter
19 import com.sun.net.httpserver.HttpExchange
20 import com.sun.net.httpserver.HttpHandler
21 import com.sun.net.httpserver.HttpServer
22 import java.io.OutputStream
23 import java.net.HttpURLConnection.HTTP_BAD_REQUEST
24 import java.net.HttpURLConnection.HTTP_OK
25 import java.net.InetSocketAddress
26 import java.time.ZonedDateTime
27 import java.time.format.DateTimeFormatter
28
29 class Server(private val port: Int, private val generator: Generator) {
30   private val server = HttpServer.create()
31
32   fun start() {
33     server.bind(InetSocketAddress("localhost", port), 0)
34     server.handle("/updates/updates.xml", HttpHandler { sendUpdatesXml(it) })
35     server.handle("/patches/", HttpHandler { sendPatch(it) })
36     server.handle("/", HttpHandler { sendText(it, "Mock Update Server") })
37     server.start()
38   }
39
40   private fun sendText(ex: HttpExchange, data: String, type: String = "text/plain", code: Int = HTTP_OK) {
41     val bytes = data.toByteArray()
42     ex.responseHeaders.add("Content-Type", "$type; charset=utf-8")
43     ex.sendResponseHeaders(code, bytes.size.toLong())
44     ex.responseBody.write(bytes)
45     ex.close()
46   }
47
48   private fun sendUpdatesXml(ex: HttpExchange) {
49     var build: String? = null
50     var eap = false
51     ex.requestURI.query?.splitToSequence('&')?.forEach {
52       val p = it.split('=', limit = 2)
53       when (p[0]) {
54         "build" -> build = if (p.size > 1) p[1] else null
55         "eap" -> eap = true
56       }
57     }
58     if (build == null) {
59       sendText(ex, "Parameter missing", code = HTTP_BAD_REQUEST)
60       return
61     }
62
63     val result = "([A-Z]+)-([0-9.]+)".toRegex().find(build!!)
64     val productCode = result?.groups?.get(1)?.value
65     val buildId = result?.groups?.get(2)?.value
66     if (productCode == null || buildId == null) {
67       sendText(ex, "Parameter malformed", code = HTTP_BAD_REQUEST)
68       return
69     }
70
71     val xml = generator.generateXml(productCode, buildId, eap)
72     sendText(ex, xml, "text/xml")
73   }
74
75   private fun sendPatch(ex: HttpExchange) {
76     if (!ex.requestURI.path.endsWith(".jar")) {
77       sendText(ex, "Request malformed", code = HTTP_BAD_REQUEST)
78       return
79     }
80
81     val patch = generator.generatePatch()
82     ex.responseHeaders.add("Content-Type", "binary/octet-stream")
83     ex.sendResponseHeaders(HTTP_OK, patch.size.toLong())
84     ex.responseBody.write(patch)
85     ex.close()
86   }
87 }
88
89 private fun HttpServer.handle(path: String, handler: HttpHandler) {
90   val ctx = createContext(path, handler)
91   ctx.filters += AccessLogFilter()
92 }
93
94 private class AccessLogFilter : Filter() {
95   companion object {
96     private val DTF = DateTimeFormatter.ofPattern("dd/MMM/yyyy:kk:mm:ss ZZ")
97   }
98
99   override fun description() = "Access Log Filter"
100
101   override fun doFilter(ex: HttpExchange, chain: Chain) {
102     val out = CountingOutputStream(ex.responseBody)
103     ex.setStreams(ex.requestBody, out)
104
105     try {
106       chain.doFilter(ex)
107       println("${ex.remoteAddress.address.hostAddress} - - [${DTF.format(ZonedDateTime.now())}]" +
108         " \"${ex.requestMethod} ${ex.requestURI}\" ${ex.responseCode} ${out.count}")
109     }
110     catch(e: Exception) {
111       e.printStackTrace()
112       ex.close()
113     }
114   }
115 }
116
117 private class CountingOutputStream(private val stream: OutputStream) : OutputStream() {
118   var count: Int = 0
119     private set(v) { field = v }
120
121   override fun write(b: Int) {
122     stream.write(b)
123     count += 1
124   }
125
126   override fun write(b: ByteArray, off: Int, len: Int) {
127     stream.write(b, off, len)
128     count += len
129   }
130 }