*/
package com.intellij.updater.mock
-import com.sun.net.httpserver.Filter
-import com.sun.net.httpserver.HttpExchange
-import com.sun.net.httpserver.HttpHandler
import com.sun.net.httpserver.HttpServer
-import java.io.OutputStream
-import java.net.HttpURLConnection.HTTP_BAD_REQUEST
-import java.net.HttpURLConnection.HTTP_OK
+import java.net.HttpURLConnection.*
import java.net.InetSocketAddress
+import java.net.URI
import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter
class Server(private val port: Int, private val generator: Generator) {
private val server = HttpServer.create()
+ private val buildFormat = "([A-Z]+)-([0-9.]+)".toRegex()
+ private val tsFormat = DateTimeFormatter.ofPattern("dd/MMM/yyyy:kk:mm:ss ZZ")
fun start() {
server.bind(InetSocketAddress("localhost", port), 0)
- server.handle("/updates/updates.xml", HttpHandler { sendUpdatesXml(it) })
- server.handle("/patches/", HttpHandler { sendPatch(it) })
- server.handle("/", HttpHandler { sendText(it, "Mock Update Server") })
- server.start()
- }
-
- private fun sendText(ex: HttpExchange, data: String, type: String = "text/plain", code: Int = HTTP_OK) {
- val bytes = data.toByteArray()
- ex.responseHeaders.add("Content-Type", "$type; charset=utf-8")
- ex.sendResponseHeaders(code, bytes.size.toLong())
- ex.responseBody.write(bytes)
- ex.close()
- }
- private fun sendUpdatesXml(ex: HttpExchange) {
- var build: String? = null
- var eap = false
- ex.requestURI.query?.splitToSequence('&')?.forEach {
- val p = it.split('=', limit = 2)
- when (p[0]) {
- "build" -> build = if (p.size > 1) p[1] else null
- "eap" -> eap = true
+ server.createContext("/") { ex ->
+ val response = try {
+ process(ex.requestMethod, ex.requestURI)
+ }
+ catch(e: Exception) {
+ e.printStackTrace()
+ Response(HTTP_INTERNAL_ERROR, "Internal error")
}
- }
- if (build == null) {
- sendText(ex, "Parameter missing", code = HTTP_BAD_REQUEST)
- return
- }
-
- val result = "([A-Z]+)-([0-9.]+)".toRegex().find(build!!)
- val productCode = result?.groups?.get(1)?.value
- val buildId = result?.groups?.get(2)?.value
- if (productCode == null || buildId == null) {
- sendText(ex, "Parameter malformed", code = HTTP_BAD_REQUEST)
- return
- }
- val xml = generator.generateXml(productCode, buildId, eap)
- sendText(ex, xml, "text/xml")
- }
+ val contentType = if (response.type.startsWith("text/")) response.type + "; charset=utf-8" else response.type
+ ex.responseHeaders.add("Content-Type", contentType)
+ ex.sendResponseHeaders(response.code, response.bytes.size.toLong())
+ ex.responseBody.write(response.bytes)
+ ex.close()
- private fun sendPatch(ex: HttpExchange) {
- if (!ex.requestURI.path.endsWith(".jar")) {
- sendText(ex, "Request malformed", code = HTTP_BAD_REQUEST)
- return
+ println("${ex.remoteAddress.address.hostAddress} - - [${tsFormat.format(ZonedDateTime.now())}]" +
+ " \"${ex.requestMethod} ${ex.requestURI}\" ${ex.responseCode} ${response.bytes.size}")
}
- val patch = generator.generatePatch()
- ex.responseHeaders.add("Content-Type", "binary/octet-stream")
- ex.sendResponseHeaders(HTTP_OK, patch.size.toLong())
- ex.responseBody.write(patch)
- ex.close()
- }
-}
-
-private fun HttpServer.handle(path: String, handler: HttpHandler) {
- val ctx = createContext(path, handler)
- ctx.filters += AccessLogFilter()
-}
-
-private class AccessLogFilter : Filter() {
- companion object {
- private val DTF = DateTimeFormatter.ofPattern("dd/MMM/yyyy:kk:mm:ss ZZ")
+ server.start()
}
- override fun description() = "Access Log Filter"
-
- override fun doFilter(ex: HttpExchange, chain: Chain) {
- val out = CountingOutputStream(ex.responseBody)
- ex.setStreams(ex.requestBody, out)
-
- try {
- chain.doFilter(ex)
- println("${ex.remoteAddress.address.hostAddress} - - [${DTF.format(ZonedDateTime.now())}]" +
- " \"${ex.requestMethod} ${ex.requestURI}\" ${ex.responseCode} ${out.count}")
- }
- catch(e: Exception) {
- e.printStackTrace()
- ex.close()
+ private fun process(method: String, uri: URI): Response {
+ val path = uri.path
+ return when {
+ method != "GET" -> Response(HTTP_BAD_REQUEST, "Didn't get")
+ path == "/" -> Response(HTTP_OK, "Mock Update Server")
+ path == "/updates/updates.xml" -> xml(uri.query ?: "")
+ path.startsWith("/patches/") -> patch(path)
+ else -> Response(HTTP_NOT_FOUND, "Miss")
}
}
-}
-private class CountingOutputStream(private val stream: OutputStream) : OutputStream() {
- var count: Int = 0
- private set(v) { field = v }
+ private fun xml(query: String): Response {
+ val parameters = query.splitToSequence('&')
+ .filter { it.startsWith("build") || it.startsWith("eap") }
+ .map { it.split('=', limit = 2) }
+ .map { it[0] to if (it.size > 1) it[1] else "" }
+ .toMap()
+
+ val build = parameters["build"]
+ if (build != null) {
+ val match = buildFormat.find(build)
+ val productCode = match?.groups?.get(1)?.value
+ val buildId = match?.groups?.get(2)?.value
+ if (productCode != null && buildId != null) {
+ val xml = generator.generateXml(productCode, buildId, "eap" in parameters)
+ return Response(HTTP_OK, "text/xml", xml.toByteArray())
+ }
+ }
- override fun write(b: Int) {
- stream.write(b)
- count += 1
+ return Response(HTTP_BAD_REQUEST, "Bad parameters")
}
- override fun write(b: ByteArray, off: Int, len: Int) {
- stream.write(b, off, len)
- count += len
+ private fun patch(path: String): Response = when {
+ path.endsWith(".jar") -> Response(HTTP_OK, "binary/octet-stream", generator.generatePatch())
+ else -> Response(HTTP_BAD_REQUEST, "Bad path")
}
- override fun flush() = stream.flush()
-
- override fun close() = stream.close()
+ private class Response(val code: Int, val type: String, val bytes: ByteArray) {
+ constructor(code: Int, text: String) : this(code, "text/plain", text.toByteArray())
+ }
}
\ No newline at end of file