/*
 * Copyright (c) 2017, Per Rovegård <per@rovegard.se>
 *
 * Distributed under the MIT License (license terms are at http://per.mit-license.org, or in the LICENSE file at
 * https://github.com/provegard/tinyws/blob/master/LICENSE).
 */

package com.programmaticallyspeaking.tinyws;

import javax.net.ssl.*;
import java.io.*;
import java.net.*;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.*;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.*;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
 * A WebSocket server. Usage:
 *
 * <ol>
 * <li>Create an instance of this class.
 * <li>Add one or more handler factories using the {@link Server#addHandlerFactory(String, Supplier)} method.
 * <li>Start the server using {@link Server#start()}.
 * <li>Connect clients...
 * <li>Stop using {@link Server#stop()}.
 * </ol>
 *
 * The server implementation passes all tests of <a href="https://github.com/crossbario/autobahn-testsuite">
 * Autobahn|Testsuite</a> (version 0.10.9) except 12.* and 13.* (compression using the permessage-deflate extension).
 */
public class Server {
    /** Product name written into the {@code Server} header of every HTTP response this server sends. */
    public static final String ServerName = "TinyWS Server";
    /**
     * Version written next to {@link #ServerName} in the {@code Server} header. NOTE(review): "@VERSION@" looks like
     * a build-time placeholder - confirm that the build substitutes it.
     */
    public static final String ServerVersion = "@VERSION@";

    // Fixed GUID from RFC 6455 (section 1.3) that is combined with Sec-WebSocket-Key to produce Sec-WebSocket-Accept;
    // presumably consumed by createResponseKey, which is not visible in this part of the file.
    private static final String HANDSHAKE_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    // The only Sec-WebSocket-Version accepted in an opening handshake; also advertised in 400 responses.
    private static final int SupportedVersion = 13;

    // Runs the accept loop as well as one task per connected client.
    private final Executor mainExecutor;
    private final Options options;
    // Wrapper around options.logger that tolerates a missing logger and swallows exceptions thrown while logging.
    private final Logger logger;

    // Set by start() and reset to null by stop(); non-null therefore doubles as the "server is started" flag.
    private ServerSocket serverSocket;
    // Handler factories keyed by the exact request path (endpoint) they serve.
    private Map<String, Supplier<WebSocketHandler>> handlerFactories = new HashMap<>();

    // Serves requests whose endpoint has no handler factory, or whose factory returned a null handler.
    private FallbackHandler fallbackHandler = new DefaultFallbackHandler();

    /**
     * Creates a server instance without opening any socket; call {@link #start()} to begin accepting clients.
     *
     * @param mainExecutor the {@link Executor} instance that will be used to run the main listener task as well as
     *                     tasks for handling connected clients. Please note that each task will use excessive blocking
     *                     I/O, so use an appropriate executor.
     * @param options server options
     */
    public Server(Executor mainExecutor, Options options) {
        this.mainExecutor = mainExecutor;
        this.options = options;
        this.logger = guardedLogger(options);
    }

    /**
     * Builds the logger used internally by the server. It delegates to {@code options.logger}, treats a missing
     * logger as "nothing enabled", and never lets an exception thrown by the delegate escape into server code.
     */
    private static Logger guardedLogger(Options options) {
        return new Logger() {
            @Override
            public boolean isEnabledAt(LogLevel level) {
                if (options.logger == null) {
                    return false;
                }
                return options.logger.isEnabledAt(level);
            }

            @Override
            public void log(LogLevel level, String message, Throwable error) {
                if (!isEnabledAt(level)) {
                    return;
                }
                try {
                    options.logger.log(level, message, error);
                } catch (Exception ignore) {
                    // ignore logging errors
                }
            }
        };
    }

    /**
     * Logs a message whose text is produced on demand, so that expensive formatting is skipped entirely when the
     * given level is disabled.
     *
     * @param level the log level
     * @param msgFun supplies the message; only called when {@code level} is enabled
     */
    private void lazyLog(LogLevel level, Supplier<String> msgFun) {
        if (!logger.isEnabledAt(level)) {
            return;
        }
        String message = msgFun.get();
        logger.log(level, message, null);
    }

    /**
     * Registers a factory that creates a {@link WebSocketHandler} for each client connecting to a specific endpoint.
     * Handler factories must be added while the server is not started. The endpoint must match a requested resource
     * path exactly to be used, so the root handler factory must be registered for the "/" endpoint. Registering a
     * second factory for the same endpoint replaces the first.
     *
     * @param endpoint non-{@code null}, non-empty endpoint
     * @param handlerFactory a handler factory
     * @exception IllegalArgumentException if the endpoint is {@code null} or empty
     * @exception IllegalStateException if the server is currently started
     */
    public void addHandlerFactory(String endpoint, Supplier<WebSocketHandler> handlerFactory) {
        boolean endpointMissing = endpoint == null || endpoint.isEmpty();
        if (endpointMissing) {
            throw new IllegalArgumentException("Endpoint must be non-empty.");
        }
        if (serverSocket != null) {
            throw new IllegalStateException("Please add handler factories before starting the server.");
        }
        handlerFactories.put(endpoint, handlerFactory);
    }

    /**
     * Sets the handler used for requests to endpoints that have no registered handler factory. Must be called while
     * the server is not started.
     *
     * @param handler the fallback handler; {@code null} restores the default handler (which responds 404)
     * @exception IllegalStateException if the server is currently started
     */
    public void setFallbackHandler(FallbackHandler handler) {
        if (serverSocket != null) {
            throw new IllegalStateException("Please add fallback handler before starting the server.");
        }
        fallbackHandler = handler != null ? handler : new DefaultFallbackHandler();
    }

    /**
     * Starts listening for client connections, using the port specified in the options passed to the constructor. If
     * a backlog was not specified in the options, the Java-default backlog (50 for Java 8) is used.
     *
     * The server socket is created on the current thread, in the interest of fail-fast. The main executor is then
     * used to start the listening task. If the executor refuses that task, the server socket is closed again and
     * the executor's exception is rethrown, leaving the server in the stopped state.
     *
     * @exception IOException if creating the server socket fails
     * @exception GeneralSecurityException if an SSL related error occurs
     * @exception IllegalStateException if the server is already started
     */
    public void start() throws IOException, GeneralSecurityException {
        // A second start would silently orphan (and leak) the socket created by the first one.
        if (serverSocket != null) throw new IllegalStateException("The server has already been started.");

        ServerSocket socket = createServerSocket();
        serverSocket = socket;
        try {
            mainExecutor.execute(this::acceptInLoop);
        } catch (RuntimeException ex) {
            // E.g. RejectedExecutionException: nothing will ever accept on the socket, so don't leave it bound.
            serverSocket = null;
            try {
                socket.close();
            } catch (IOException closeEx) {
                ex.addSuppressed(closeEx);
            }
            throw ex;
        }
    }

    /**
     * Opens the listening socket described by the options: an SSL server socket obtained from
     * {@code options.sslContext} when SSL should be used, otherwise a plain {@link ServerSocket}.
     */
    private ServerSocket createServerSocket() throws IOException, GeneralSecurityException {
        // A backlog of 0 makes ServerSocket fall back to its own default (50).
        int backlog = options.backlog == null ? 0 : options.backlog;

        if (!options.shouldUseSSL()) {
            return new ServerSocket(options.port, backlog, options.address);
        }

        SSLServerSocketFactory factory = options.sslContext.getServerSocketFactory();
        return factory.createServerSocket(options.port, backlog, options.address);
    }

    /**
     * Stops listening for client connections. Does nothing if the server has already been stopped (or hasn't been
     * started to begin with). An {@link IOException} from closing the server socket is suppressed but logged at WARN
     * level to any supplied logger.
     */
    public void stop() {
        ServerSocket socket = serverSocket;
        if (socket == null) {
            return;
        }
        try {
            socket.close();
        } catch (IOException e) {
            logger.log(LogLevel.WARN, "Failed to close server socket.", e);
        }
        serverSocket = null;
    }

    /**
     * Accepts client connections until the server socket is closed, handing each connection to a
     * {@link ClientHandler} task on the main executor.
     *
     * A failure while setting up a single client (for example the client resetting the connection before
     * setTcpNoDelay, or the executor rejecting the task) only drops that client; the accept loop keeps running.
     * A SocketException thrown by accept() itself is taken to mean that the server socket was closed.
     */
    private void acceptInLoop() {
        // Read the field once: stop() may set it to null concurrently, which must not surface as an NPE here.
        ServerSocket listener = serverSocket;
        if (listener == null) {
            logger.log(LogLevel.DEBUG, "Server was stopped before the accept loop started.", null);
            return;
        }
        try {
            lazyLog(LogLevel.INFO, () -> "Receiving WebSocket clients at " + listener.getLocalSocketAddress());

            while (true) {
                Socket clientSocket = listener.accept();
                try {
                    // We need this on Linux. Without it the close frame sent just before closing the
                    // socket won't be seen by the WebSocket client.
                    clientSocket.setTcpNoDelay(true);

                    mainExecutor.execute(new ClientHandler(clientSocket));
                } catch (IOException | RuntimeException ex) {
                    // Only this client is affected; keep serving the others.
                    logger.log(LogLevel.WARN, "Failed to set up handling of a client socket; dropping it.", ex);
                    try {
                        clientSocket.close();
                    } catch (IOException ignored) {
                        // Best effort: the client is being dropped anyway.
                    }
                }
            }
        } catch (SocketException e) {
            logger.log(LogLevel.DEBUG, "Server socket was closed, probably because the server was stopped.", e);
        } catch (Exception ex) {
            logger.log(LogLevel.ERROR, "Error accepting a client socket.", ex);
        }
    }

    /**
     * Serves one accepted client connection. It reads the HTTP request and hands it to the fallback handler when no
     * WebSocket handler is available for the endpoint. Otherwise it performs the opening handshake and then reads and
     * dispatches frames until the connection is closed or fails. Instances run as tasks on the main executor.
     */
    private class ClientHandler implements Runnable {

        private final Socket clientSocket;
        private final OutputStream out;
        private final InputStream in;
        private final PayloadCoder payloadCoder;
        private final FrameWriter frameWriter;
        // Created by the endpoint's handler factory in communicate(); stays null for fallback-handled requests.
        private WebSocketHandler handler;
        private volatile boolean isClosed; // potentially set from handler thread

        ClientHandler(Socket clientSocket) throws IOException {
            this.clientSocket = clientSocket;
            out = clientSocket.getOutputStream();
            in = clientSocket.getInputStream();

            payloadCoder = new PayloadCoder();
            frameWriter = new FrameWriter(out, payloadCoder, options.maxFrameSize);
        }

        /**
         * Invokes {@code fun} with the handler, if there is one. Exceptions thrown by the handler are logged at ERROR
         * level and not propagated, so a misbehaving handler cannot break the frame loop.
         */
        private void invokeHandler(Consumer<WebSocketHandler> fun) {
            if (handler == null) return;
            try {
                fun.accept(handler);
            } catch (Exception ex) {
                logger.log(LogLevel.ERROR, "Handler invocation error.", ex);
            }
        }

        /**
         * Runs the connection and maps each kind of failure to the matching reaction. WebSocket-level errors get a
         * close frame. Bad handshakes get an HTTP error response. Unexpected errors trigger a handler failure
         * notification. The socket is closed afterwards in every case.
         */
        @Override
        public void run() {
            try {
                communicate();
            } catch (WebSocketClosure ex) {
                lazyLog(LogLevel.DEBUG, () -> String.format("Closing with code %d (%s)%s", ex.code, ex.reason,
                        ex.debugDetails != null ? (" because: " + ex.debugDetails) : ""));
                doIgnoringExceptions(() -> frameWriter.writeClose(ex.code, ex.reason));
                // If the connection was closed by the client, we expect onClosedByClient to have been invoked and
                // we must *not* invoke onClosedByServer since that would be a lie...
                if (!ex.closedByClient)
                    invokeHandler(h -> h.onClosedByServer(ex.code, ex.reason));
            } catch (MethodNotAllowedException ex) {
                lazyLog(LogLevel.WARN, () -> String.format("WebSocket client from %s used a non-allowed method: %s",
                            clientSocket.getRemoteSocketAddress(), ex.method));
                sendMethodNotAllowedResponse();
            } catch (IllegalArgumentException ex) {
                lazyLog(LogLevel.WARN, () -> String.format("WebSocket client from %s sent a malformed request: %s",
                        clientSocket.getRemoteSocketAddress(), ex.getMessage()));
                sendBadRequestResponse();
            } catch (FileNotFoundException ex) {
                lazyLog(LogLevel.WARN, () -> String.format("WebSocket client from %s requested an unknown endpoint.",
                        clientSocket.getRemoteSocketAddress()));
                sendNotFoundResponse();
            } catch (SocketException ex) {
                if (!isClosed) {
                    logger.log(LogLevel.ERROR, "Client socket error.", ex);
                    invokeHandler(h -> h.onFailure(ex));
                }
            } catch (Exception ex) {
                logger.log(LogLevel.ERROR, "Client communication error.", ex);
                invokeHandler(h -> h.onFailure(ex));
            } finally {
                // Release the socket even if one of the reactions above throws.
                abort();
            }
        }

        /** Closes the client socket once; later calls do nothing. */
        private void abort() {
            if (isClosed) return;
            doIgnoringExceptions(clientSocket::close);
            isClosed = true;
        }

        /**
         * Reads the request, performs the handshake and then loops reading frames. It returns normally only when the
         * request was given to the fallback handler; otherwise it ends by throwing.
         */
        private void communicate() throws IOException, NoSuchAlgorithmException {
            maybeLogSSLDetails();

            Headers headers = Headers.read(in, isSSL());
            String endpoint = headers.endpoint;

            Supplier<WebSocketHandler> handlerFactory = handlerFactories.get(endpoint);
            if (handlerFactory == null || (handler = handlerFactory.get()) == null) {
                fallbackHandler.handle(createConnection(headers));
                return;
            }

            if (!"GET".equals(headers.method)) throw new MethodNotAllowedException(headers.method);
            if (!headers.isProperUpgrade()) throw new IllegalArgumentException("Handshake has malformed upgrade.");
            if (headers.version() != SupportedVersion) throw new IllegalArgumentException("Bad version, must be: " + SupportedVersion);

            lazyLog(LogLevel.INFO, () -> String.format("New WebSocket client from %s at endpoint '%s'.",
                        clientSocket.getRemoteSocketAddress(), endpoint));

            String key = headers.key();
            if (key == null) throw new IllegalArgumentException("Missing Sec-WebSocket-Key in handshake.");

            String responseKey = createResponseKey(key);

            lazyLog(LogLevel.TRACE, () -> String.format("Opening handshake key is '%s', sending response key '%s'.", key, responseKey));

            sendHandshakeResponse(responseKey);

            // Only now is this a WebSocket connection. Notifying the handler earlier would let it write frames ahead
            // of the 101 response, and would report "opened" for handshakes that are subsequently rejected.
            invokeHandler(h -> h.onOpened(new WebSocketClientImpl(frameWriter, this::abort, headers)));

            List<Frame> frameBatch = new ArrayList<>();
            while (true) {
                frameBatch.add(Frame.read(in));
                handleBatch(frameBatch);
            }
        }

        /** Exposes the parsed request and the raw socket streams to the fallback handler. */
        private Connection createConnection(Headers headers) {
            return new Connection() {
                @Override public String method() { return headers.method; }
                @Override public URI uri() { return headers.uri; }
                @Override public InputStream inputStream() { return in; }
                @Override public OutputStream outputStream() { return out; }
                @Override public Iterable<String> headerNames() { return headers.headers.keySet(); }
                @Override public Optional<String> header(String name) {
                    if (name == null) throw new IllegalArgumentException("Header name must be non-null");
                    String value = headers.headers.get(name);
                    return Optional.ofNullable(value);
                }
                @Override
                public void sendResponse(int statusCode, String reason, Map<String, String> headers) {
                    if (reason == null) reason = "";
                    if (headers == null) headers = Collections.emptyMap();
                    ClientHandler.this.sendResponse(statusCode, reason, headers);
                }
            };
        }

        private boolean isSSL() {
            return clientSocket instanceof SSLSocket;
        }

        private void maybeLogSSLDetails() {
            if (isSSL()) {
                SSLSocket sslSocket = (SSLSocket) clientSocket;
                SSLSession sslSession = sslSocket.getSession();
                lazyLog(LogLevel.DEBUG, () -> String.format("SSL session uses protocol %s and cipher suite %s.",
                        sslSession.getProtocol(), sslSession.getCipherSuite()));
            }
        }

        /**
         * Processes the most recently read frame in the context of the message being assembled in {@code frameBatch}.
         * Complete messages and control frames are dispatched; the batch is cleared once a message completes.
         */
        private void handleBatch(List<Frame> frameBatch) throws IOException {
            Frame firstFrame = frameBatch.get(0);

            if (firstFrame.opCode == 0) throw WebSocketClosure.protocolError("Continuation frame with nothing to continue.");

            Frame lastOne = frameBatch.get(frameBatch.size() - 1);
            lazyLog(LogLevel.TRACE, lastOne::toString);
            if (firstFrame != lastOne) {
                if (lastOne.isControl()) {
                    // Interleaved control frame. Control frames are never fragmented (Frame.read rejects that), so
                    // it can be handled right away without disturbing the message being assembled.
                    frameBatch.remove(frameBatch.size() - 1);
                    handleResultFrame(lastOne);
                    return;
                }
                // Check every continuation fragment as it arrives, not only the final one.
                if (lastOne.opCode > 0) {
                    throw WebSocketClosure.protocolError("Continuation frame must have opcode 0.");
                }
            }
            if (!lastOne.isFin) return;

            Frame result = frameBatch.size() > 1 ? Frame.merge(frameBatch) : lastOne;

            frameBatch.clear();

            handleResultFrame(result);
        }

        /** Dispatches a complete message or control frame based on its opcode. */
        private void handleResultFrame(Frame result) throws IOException {
            switch (result.opCode) {
                case 1:
                    CharSequence data = payloadCoder.decode(result.payloadData);
                    invokeHandler(h -> h.onTextMessage(data));
                    break;
                case 2:
                    invokeHandler(h -> h.onBinaryData(result.payloadData));
                    break;
                case 8:
                    CloseData cd = result.toCloseData(payloadCoder);

                    if (cd.hasInvalidCode()) throw WebSocketClosure.protocolError("Invalid close frame code: " + cd.code);

                    // 1000 is normal close
                    int i = cd.code != null ? cd.code : 1000;

                    invokeHandler(h -> h.onClosedByClient(i, cd.reason));

                    throw WebSocketClosure.fromClient(i);
                case 9:
                    // Ping, send pong!
                    logger.log(LogLevel.TRACE, "Got ping frame, sending pong.", null);
                    frameWriter.writePong(result.payloadData);
                    break;
                case 10:
                    // Pong is ignored
                    logger.log(LogLevel.TRACE, "Ignoring unsolicited pong frame.", null);
                    break;
                default:
                    throw WebSocketClosure.protocolError("Invalid opcode: " + result.opCode);
            }
        }

        private void outputLine(PrintWriter writer, String data) {
            writer.print(data);
            writer.print("\r\n");
        }

        private void sendHandshakeResponse(String responseKey) {
            Map<String, String> headers = new HashMap<String, String>() {{
                put("Upgrade", "websocket");
                put("Connection", "upgrade");
                put("Sec-WebSocket-Accept", responseKey);
            }};
            sendEmptyResponseBeforeClose(101, "Switching Protocols", headers);
        }

        private void sendBadRequestResponse() {
            // Advertise supported version regardless of what was bad. A bit lazy, but simple.
            Map<String, String> headers = new HashMap<String, String>() {{
                put("Sec-WebSocket-Version", Integer.toString(SupportedVersion));
            }};
            sendEmptyResponseBeforeClose(400, "Bad Request", headers);
        }
        private void sendMethodNotAllowedResponse() {
            Map<String, String> headers = new HashMap<String, String>() {{
                put("Allow", "GET");
            }};
            sendEmptyResponseBeforeClose(405, "Method Not Allowed", headers);
        }
        private void sendNotFoundResponse() {
            sendEmptyResponseBeforeClose(404, "Not Found", Collections.emptyMap());
        }

        /**
         * Sends a body-less response. For final (non-1xx) status codes it adds "Connection: close" and
         * "Content-Length: 0".
         */
        private void sendEmptyResponseBeforeClose(int statusCode, String reason, Map<String, String> headers) {
            Map<String, String> allHeaders = new HashMap<>(headers);
            // Headers added when we don't do a connection upgrade to WebSocket!
            if (statusCode >= 200) {
                // https://tools.ietf.org/html/rfc7230#section-6.1
                allHeaders.put("Connection", "close");

                // https://tools.ietf.org/html/rfc7230#section-3.3.2
                allHeaders.put("Content-Length", "0");
            }

            sendResponse(statusCode, reason, allHeaders);
        }

        /** Writes the status line, the given headers and a Server header, followed by the blank line. */
        private void sendResponse(int statusCode, String reason, Map<String, String> headers) {
            // Use an explicit charset; the PrintWriter(OutputStream) constructor would use the platform default.
            PrintWriter writer = new PrintWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8), false);
            outputLine(writer, "HTTP/1.1 " + statusCode + " " + reason);
            for (Map.Entry<String, String> entry : headers.entrySet()) {
                outputLine(writer, entry.getKey() + ": " + entry.getValue());
            }

            // https://tools.ietf.org/html/rfc7231#section-7.4.2
            outputLine(writer, String.format("Server: %s %s", ServerName, ServerVersion));

            outputLine(writer, "");
            writer.flush();
            // Note: Do NOT close the writer, as the stream must remain open
        }
    }

    /**
     * A single WebSocket frame received from a client: its opcode, FIN flag and payload data.
     */
    static class Frame {

        final int opCode;
        final byte[] payloadData;
        final boolean isFin;

        @Override
        public String toString() {
            return String.format("Frame[opcode=%d, control=%b, payload length=%d, fragmented=%b]",
                    opCode, isControl(), payloadData.length, !isFin);
        }

        /** Control frames (close, ping, pong) have the high bit of the 4-bit opcode set. */
        boolean isControl() {
            return (opCode & 8) == 8;
        }

        private Frame(int opCode, byte[] payloadData, boolean isFin) {
            this.opCode = opCode;
            this.payloadData = payloadData;
            this.isFin = isFin;
        }

        private static int toUnsigned(byte b) {
            return b & 0xFF;
        }

        /**
         * Reads exactly {@code len} bytes into {@code target} (or a new array when {@code target} is null).
         *
         * @throws IOException if the stream ends before {@code len} bytes were read
         */
        private static byte[] readBytes(InputStream in, int len, byte[] target) throws IOException {
            assert target == null || target.length >= len : "readBytes target is too small";
            byte[] buf = target != null ? target : new byte[len];
            int totalRead = 0;
            int offs = 0;
            while (totalRead < len) {
                int readLen = in.read(buf, offs, len - offs);
                if (readLen < 0) break;
                totalRead += readLen;
                offs += readLen;
            }
            if (totalRead != len) throw new IOException("Expected to read " + len + " bytes but read " + totalRead);
            return buf;
        }

        /** Interprets {@code len} bytes starting at {@code offset} as a big-endian unsigned number. */
        private static long toLong(byte[] data, int offset, int len) {
            long result = 0;
            for (int i = offset, j = offset + len; i < j; i++) {
                result = (result << 8) + toUnsigned(data[i]);
            }
            return result;
        }

        /**
         * Reads one frame from the stream, validating header bits and control-frame constraints.
         *
         * @throws WebSocketClosure (1002) on protocol violations
         * @throws IOException if the stream ends prematurely
         */
        static Frame read(InputStream in) throws IOException {
            // We will read at most 8 bytes at any time, except the payload data.
            byte[] buf = new byte[8];

            // Read first 2 bytes
            readBytes(in, 2, buf);
            byte firstByte = buf[0];
            byte secondByte = buf[1];
            boolean isFin = (firstByte & 128) == 128;
            boolean hasZeroReserved = (firstByte & 112) == 0;
            if (!hasZeroReserved) throw WebSocketClosure.protocolError("Non-zero reserved bits in 1st byte: " + (firstByte & 112));
            int opCode = (firstByte & 15);
            boolean isControlFrame = (opCode & 8) == 8;
            boolean isMasked = (secondByte & 128) == 128;
            int len = (secondByte & 127);
            if (isControlFrame) {
                if (len > 125) throw WebSocketClosure.protocolError("Control frame length exceeding 125 bytes.");
                if (!isFin) throw WebSocketClosure.protocolError("Fragmented control frame.");
            }
            if (len == 126) {
                // 2 bytes of extended len
                long tmp = toLong(readBytes(in, 2, buf), 0, 2);
                len = (int) tmp;
            } else if (len == 127) {
                // 8 bytes of extended len. RFC 6455 requires the most significant bit to be 0; a set MSB makes tmp
                // negative, which would slip past the upper-bound check and truncate to a bogus int length.
                long tmp = toLong(readBytes(in, 8, buf), 0, 8);
                if (tmp < 0) throw WebSocketClosure.protocolError("Most significant bit of 64-bit frame length must be 0.");
                if (tmp > Integer.MAX_VALUE) throw WebSocketClosure.protocolError("Frame length greater than 0x7fffffff not supported.");
                len = (int) tmp;
            }
            byte[] maskingKey = isMasked ? readBytes(in, 4, buf) : null;
            byte[] payloadData = unmaskIfNeededInPlace(readBytes(in, len, null), maskingKey);
            return new Frame(opCode, payloadData, isFin);
        }

        /**
         * Parses the payload of a close frame into its optional status code and optional UTF-8 reason.
         *
         * @throws WebSocketClosure (1002) if the payload is exactly one byte, or (1007) if the reason is not UTF-8
         */
        CloseData toCloseData(PayloadCoder payloadCoder) throws WebSocketClosure {
            if (opCode != 8) throw new IllegalStateException("Not a close frame: " + opCode);
            if (payloadData.length == 0) return new CloseData(null, null);
            if (payloadData.length == 1) throw WebSocketClosure.protocolError("Invalid close frame payload length (1).");
            int code = (int) toLong(payloadData, 0, 2);
            CharSequence reason = payloadData.length > 2 ? payloadCoder.decode(payloadData, 2, payloadData.length - 2) : null;
            return new CloseData(code, reason != null ? reason.toString() : null);
        }

        /**
         * Combines the payloads of a fragmented message into one final frame that carries the opcode of the first
         * fragment.
         */
        static Frame merge(List<Frame> frameBatch) {
            int totalLength = frameBatch.stream().mapToInt(f -> f.payloadData.length).sum();
            byte[] allTheData = new byte[totalLength];
            int offs = 0;
            for (Frame frame : frameBatch) {
                System.arraycopy(frame.payloadData, 0, allTheData, offs, frame.payloadData.length);
                offs += frame.payloadData.length;
            }
            return new Frame(frameBatch.get(0).opCode, allTheData, true);
        }
    }

    /** Status code and reason parsed from a client's close frame; either may be absent ({@code null}). */
    private static class CloseData {
        private final Integer code;
        private final String reason;

        CloseData(Integer code, String reason) {
            this.code = code;
            this.reason = reason;
        }

        /**
         * Tells whether the status code may not appear in a close frame. An absent code is acceptable, as is
         * anything in 3000-4999. Below 3000, only 1000-1003 and 1007-1011 are accepted.
         */
        boolean hasInvalidCode() {
            if (code == null) {
                return false;
            }
            int value = code;
            if (value >= 3000 && value <= 4999) {
                return false;
            }
            if (value < 1000 || value > 1011) {
                return true;
            }
            switch (value) {
                case 1004:
                case 1005:
                case 1006:
                    return true;
                default:
                    return false;
            }
        }
    }

    static class Headers {
        private final Map<String, String> headers;
        final String endpoint;
        final String method;
        final String query;
        final String fragment;
        final URI uri;

        private Headers(Map<String, String> headers, URI uri, String method) {
            this.headers = headers;
            this.method = method;
            this.uri = uri;
            this.endpoint = uri.getPath();
            this.query = uri.getQuery();
            this.fragment = uri.getFragment();
        }

        boolean isProperUpgrade() {
            return "websocket".equalsIgnoreCase(headers.get("Upgrade")) && "Upgrade".equalsIgnoreCase(headers.get("Connection"));
        }
        int version() {
            String versionStr = headers.get("Sec-WebSocket-Version");
            try {
                return Integer.parseInt(versionStr);
            } catch (Exception ignore) {
                return 0;
            }
        }

        String key() { return headers.get("Sec-WebSocket-Key"); }
        String userAgent() { return headers.get("User-Agent"); }
        String host() { return headers.get("Host"); }

        static Headers read(InputStream in, boolean isSSL) throws IOException {
            BufferedReader reader = new BufferedReader(new InputStreamReader(in));
            String inputLine, method = "";
            String path = null;
            URI endpoint = null;
            Map<String, String> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
            boolean isFirstLine = true;
            while (!"".equals((inputLine = reader.readLine())) && inputLine != null) {
                if (isFirstLine) {
                    String[] parts = inputLine.split(" ", 3);
                    if (parts.length != 3) throw new IllegalArgumentException("Malformed 1st header line: " + inputLine);
                    method = parts[0];
                    if (!"HTTP/1.1".equals(parts[2])) throw new IllegalArgumentException("Only HTTP/1.1 is supported");
                    path = parts[1];
                    isFirstLine = false;
                }

                String[] keyValue = inputLine.split(":", 2);
                if (keyValue.length != 2) continue;

                headers.put(keyValue[0], keyValue[1].trim());
            }
            if (isFirstLine) throw new IllegalArgumentException("No HTTP headers received");

            // Figure out the entire endpoint
            try {
                String scheme = isSSL ? "https" : "http";
                String host = headers.get("host");
                if (host == null) host = "server";
                endpoint = new URI(scheme + "://" + host + path);
            } catch (URISyntaxException e) {
                throw new IllegalArgumentException("Invalid endpoint: " + path);
            }

            // Note: Do NOT close the reader, because the stream must remain open!
            return new Headers(headers, endpoint, method);
        }
    }

    /**
     * UTF-8 codec for text payloads and close reasons.
     *
     * CharsetEncoder/CharsetDecoder instances are not safe for concurrent use. This coder is shared between the thread
     * reading frames (decode) and any thread sending text through a FrameWriter (encode), so both operations are
     * synchronized.
     */
    static class PayloadCoder {
        private final Charset charset = StandardCharsets.UTF_8;
        private final CharsetDecoder decoder = charset.newDecoder();
        private final CharsetEncoder encoder = charset.newEncoder();

        /**
         * Decodes the whole array as UTF-8.
         *
         * @throws WebSocketClosure (1007) thrown if the data are not valid UTF-8
         */
        CharSequence decode(byte[] bytes) throws WebSocketClosure {
            return decode(bytes, 0, bytes.length);
        }

        /**
         * Decodes the given byte data as UTF-8 and returns the result as a string.
         *
         * @param bytes the byte array
         * @param offset offset into the array where to start decoding
         * @param len length of data to decode
         * @return the decoded string
         * @throws WebSocketClosure (1007) thrown if the data are not valid UTF-8
         */
        synchronized CharSequence decode(byte[] bytes, int offset, int len) throws WebSocketClosure {
            decoder.reset();
            try {
                return decoder.decode(ByteBuffer.wrap(bytes, offset, len));
            } catch (CharacterCodingException ex) {
                throw WebSocketClosure.invalidFramePayloadData();
            }
        }

        /**
         * Encodes the text as UTF-8 into a fresh, array-backed buffer whose limit is the encoded length.
         *
         * @throws CharacterCodingException if the text cannot be encoded (e.g. unpaired surrogates)
         */
        synchronized ByteBuffer encode(CharSequence s) throws CharacterCodingException {
            encoder.reset();
            ByteBuffer buf = encoder.encode(CharBuffer.wrap(s));
            assert buf.hasArray() : "Expected ByteBuffer to have an array";
            return buf;
        }
    }

    /**
     * Signals that the WebSocket connection should be closed with a given status code and reason. Because it is used
     * for control flow, not diagnostics, no stack trace is captured.
     */
    static class WebSocketClosure extends IOException {
        final int code;
        final String reason;
        final String debugDetails;
        final boolean closedByClient;

        private WebSocketClosure(int closeCode, String closeReason, String details, boolean byClient) {
            code = closeCode;
            reason = closeReason;
            debugDetails = details;
            closedByClient = byClient;
        }

        /** The client sent a close frame with {@code code}; the server answers with an empty reason. */
        static WebSocketClosure fromClient(int code) {
            final boolean byClient = true;
            return new WebSocketClosure(code, "", "Closed by client", byClient);
        }

        /** Close with 1002 (protocol error); {@code debugDetails} is for logging only. */
        static WebSocketClosure protocolError(String debugDetails) {
            final int protocolErrorCode = 1002;
            return new WebSocketClosure(protocolErrorCode, "Protocol error", debugDetails, false);
        }

        /** Close with 1007 (invalid frame payload data, e.g. text that is not valid UTF-8). */
        static WebSocketClosure invalidFramePayloadData() {
            final int invalidPayloadCode = 1007;
            return new WebSocketClosure(invalidPayloadCode, "Invalid frame payload data", null, false);
        }

        @Override
        public synchronized Throwable fillInStackTrace() {
            // Deliberately skip capturing a stack trace; it carries no useful information here.
            return this;
        }
    }

    static class FrameWriter {
        private final OutputStream out;
        private final PayloadCoder payloadCoder;
        private final int maxFrameSize;

        // Reusable array for writing length bytes
        private final byte[] lengthBytes = new byte[8];

        FrameWriter(OutputStream out, PayloadCoder payloadCoder, int maxFrameSize) {
            this.out = out;
            this.payloadCoder = payloadCoder;
            this.maxFrameSize = maxFrameSize;
        }

        void writeClose(int code, String reason) throws IOException {
            ByteBuffer buf = payloadCoder.encode(reason);
            int bufLen = buf.limit();
            byte[] combined = new byte[2 + bufLen];
            numberToBytes(code, 2, combined);
            buf.get(combined, 2, bufLen);
            writeFrame(8, combined);
        }

        void writeText(CharSequence text) throws IOException {
            ByteBuffer buf = payloadCoder.encode(text);
            writePossiblyFragmentedFrames(1, buf);
        }

        void writeBinary(byte[] data) throws IOException {
            writePossiblyFragmentedFrames(2, ByteBuffer.wrap(data));
        }

        void writePing(byte[] data) throws IOException {
            writeFrame(9, data);
        }

        void writePong(byte[] data) throws IOException {
            writeFrame(10, data);
        }

        /**
         * Writes a data message, splitting it into several frames when it is larger than {@code maxFrameSize}
         * (a {@code maxFrameSize} of 0 disables splitting).
         *
         * This method is synchronized on the same monitor as the frame-level {@code writeFrame}, so all fragments
         * of one message are written as an uninterrupted sequence. Without this, two threads sending fragmented
         * messages at the same time could interleave their data frames, which RFC 6455 section 5.4 forbids.
         *
         * https://tools.ietf.org/html/rfc6455#section-5.6 implies that a single frame may contain a UTF-8
         * sequence that by itself is invalid, as long as the entire message text is valid UTF-8, so split points
         * need not respect character boundaries.
         *
         * @param opCode the message opcode (sent on the first frame only; continuation frames use 0)
         * @param buf the message payload; its remaining bytes are sent and its position is left untouched
         * @throws IOException if writing to the stream fails
         */
        private synchronized void writePossiblyFragmentedFrames(int opCode, ByteBuffer buf) throws IOException {
            int bufLen = buf.remaining();
            byte[] data;
            if (buf.hasArray() && buf.arrayOffset() == 0 && buf.position() == 0) {
                // Fast path: the backing array starts exactly at the payload, so no copy is needed.
                data = buf.array();
            } else {
                // Read-only, direct or offset buffers: copy the payload without disturbing the caller's buffer.
                data = new byte[bufLen];
                buf.duplicate().get(data);
            }
            if (maxFrameSize == 0 || bufLen <= maxFrameSize) {
                writeFrame(opCode, data, bufLen, 0, bufLen);
            } else {
                for (int offset = 0; offset < bufLen; offset += maxFrameSize) {
                    int len = Math.min(bufLen - offset, maxFrameSize);
                    writeFrame(opCode, data, bufLen, offset, len);
                }
            }
        }

        /** Writes one unfragmented frame whose payload is all of {@code data} ({@code null} means empty). */
        private void writeFrame(int opCode, byte[] data) throws IOException {
            int payloadLen = 0;
            if (data != null) {
                payloadLen = data.length;
            }
            writeFrame(opCode, data, payloadLen, 0, payloadLen);
        }

        /**
         * Writes a single frame (header, optional extended length, payload) and flushes the stream. FrameWriter
         * may be used from several threads, so this method is synchronized.
         *
         * @param opCode the opcode of the message; only written on the first frame, later fragments use 0
         * @param data the array holding the payload, or {@code null} for an empty payload
         * @param totalLen total payload length of the whole message (larger than {@code len} when the message is
         *                 split across several frames)
         * @param offset start of this frame's payload within {@code data}
         * @param len number of payload bytes in this frame
         * @throws IOException thrown if writing to the socket fails
         */
        private synchronized void writeFrame(int opCode, byte[] data, int totalLen, int offset, int len) throws IOException {
            // Continuation frames carry opcode 0; only the first fragment names the real opcode.
            int header0 = offset == 0 ? opCode : 0;
            if (offset + len == totalLen) {
                header0 |= 0x80; // FIN: this frame ends the message
            }

            final int header1;
            final int lengthFieldSize;
            if (len >= 65536) {
                header1 = 127;        // 64-bit extended length follows
                lengthFieldSize = 8;
            } else if (len >= 126) {
                header1 = 126;        // 16-bit extended length follows
                lengthFieldSize = 2;
            } else {
                header1 = len;        // length fits in the 7-bit field
                lengthFieldSize = 0;
            }

            out.write(header0);
            out.write(header1);
            if (lengthFieldSize != 0) {
                numberToBytes(len, lengthFieldSize, lengthBytes);
                out.write(lengthBytes, 0, lengthFieldSize);
            }
            if (data != null) {
                out.write(data, offset, len);
            }
            out.flush();
        }
    }

    private static class WebSocketClientImpl implements WebSocketClient {

        private final FrameWriter writer;
        private final Runnable closeCallback;
        private final Headers headers;

        WebSocketClientImpl(FrameWriter writer, Runnable closeCallback, Headers headers) {
            this.writer = writer;
            this.closeCallback = closeCallback;
            this.headers = headers;
        }

        public void ping() throws IOException {
            writer.writePing(null);
        }

        public void close() {
            doIgnoringExceptions(() -> {
                writer.writeClose(1001, "Going Away");
                closeCallback.run();
            });
        }

        public void sendTextMessage(CharSequence text) throws IOException {
            if (text == null) throw new IllegalArgumentException("Cannot send null text");
            writer.writeText(text);
        }

        public void sendBinaryData(byte[] data) throws IOException {
            if (data == null) throw new IllegalArgumentException("Cannot send null data");
            writer.writeBinary(data);
        }

        public String userAgent() { return headers.userAgent(); }
        public String host() { return headers.host(); }
        public String query() { return headers.query; }
        public String fragment() { return headers.fragment; }
    }

    /**
     * Writes {@code number} as {@code len} bytes in big-endian (network) byte order. When {@code len} exceeds 4,
     * the extra high-order bytes come from arithmetic-shift sign extension (0 for non-negative numbers).
     *
     * @param number the value to encode
     * @param len how many bytes to produce; every index in {@code [0, len)} is overwritten
     * @param target the array to write into, or {@code null} to allocate a new array of size {@code len}
     * @return {@code target} if it was non-null, otherwise the newly allocated array
     */
    static byte[] numberToBytes(int number, int len, byte[] target) {
        assert target == null || target.length >= len : "numberToBytes target is too small";
        byte[] result = (target == null) ? new byte[len] : target;
        int remaining = number;
        int index = len;
        // Fill from the last index backwards: the least significant byte goes last.
        while (index-- > 0) {
            result[index] = (byte) remaining;
            remaining >>= 8;
        }
        return result;
    }

    /**
     * Computes the handshake response key for a client-supplied key: the Base64 encoding of the SHA-1 digest
     * of {@code key} followed by {@code HANDSHAKE_GUID}.
     *
     * The concatenated string is turned into bytes with an explicit US-ASCII charset. The key is Base64 text and
     * the GUID is ASCII, so the platform default charset, which is not guaranteed to be ASCII-compatible, must
     * not influence the digest.
     *
     * @param key the key sent by the client
     * @return the Base64-encoded SHA-1 digest
     * @throws NoSuchAlgorithmException if SHA-1 is unavailable
     */
    static String createResponseKey(String key) throws NoSuchAlgorithmException {
        MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
        byte[] rawBytes = (key + HANDSHAKE_GUID).getBytes(StandardCharsets.US_ASCII);
        byte[] result = sha1.digest(rawBytes);
        return Base64.getEncoder().encodeToString(result);
    }

    /**
     * Runs {@code runnable} on a best-effort basis: any {@link Exception} it throws is discarded, while
     * {@link Error}s still propagate.
     */
    private static void doIgnoringExceptions(RunnableThatThrows runnable) {
        try {
            runnable.run();
        } catch (Exception ignored) {
            // Intentionally swallowed: the caller explicitly opted into best-effort execution.
        }
    }
    /** A {@link Runnable}-like action whose {@code run} may throw checked exceptions. */
    @FunctionalInterface
    private interface RunnableThatThrows {
        void run() throws Exception;
    }

    /**
     * XORs {@code bytes} in place with the 4-byte {@code maskingKey}, cycling through the key (byte {@code i}
     * uses key byte {@code i mod 4}). A {@code null} key leaves the array untouched.
     *
     * @param bytes the data to unmask, modified in place
     * @param maskingKey the masking key (its first four bytes are used), or {@code null}
     * @return the same {@code bytes} array
     */
    static byte[] unmaskIfNeededInPlace(byte[] bytes, byte[] maskingKey) {
        if (maskingKey == null) {
            return bytes;
        }
        // Performance note: handling four bytes per iteration with the key bytes in locals is up to 4 times
        // faster than using only the per-byte tail loop. Using an IntBuffer is not faster.
        final byte k0 = maskingKey[0];
        final byte k1 = maskingKey[1];
        final byte k2 = maskingKey[2];
        final byte k3 = maskingKey[3];
        final int n = bytes.length;
        final int blockEnd = n & ~3; // largest multiple of 4 not exceeding n
        int pos = 0;
        while (pos < blockEnd) {
            bytes[pos] ^= k0;
            bytes[pos + 1] ^= k1;
            bytes[pos + 2] ^= k2;
            bytes[pos + 3] ^= k3;
            pos += 4;
        }
        for (; pos < n; pos++) {
            bytes[pos] ^= maskingKey[pos & 3];
        }
        return bytes;
    }

    /**
     * Indicates that a request used a method that is not allowed. The rejected method is exposed through
     * {@link #method} and is also included in the exception message, so logs show what was rejected.
     */
    static class MethodNotAllowedException extends IllegalArgumentException {
        /** The rejected method name. */
        final String method;

        public MethodNotAllowedException(String method) {
            super("Method not allowed: " + method);
            this.method = method;
        }
    }

    /**
     * Server options, configured using a fluent interface. Start with {@code Options.withPort(int)} since port is
     * required.
     */
    public static class Options {
        // Client connection queue size; null unless andBacklog has been called.
        Integer backlog;
        // TCP port to listen on.
        int port;
        // Logger; null by default.
        Logger logger;
        // Address to bind the server socket to; null by default.
        InetAddress address;
        // Maximum frame size; 0 unless andMaxFrameSize has been called.
        int maxFrameSize;
        // SSL context; non-null only after andSSL, which is what makes the server use SSL.
        SSLContext sslContext;

        private boolean shouldUseSSL() { return sslContext != null; }

        private Options(int port) {
            this.port = port;
        }

        /**
         * Creates new options with the given port.
         *
         * @param port the port to use when listening for WebSocket clients, in the range 0-65535
         * @return a new options instance
         * @throws IllegalArgumentException if the port is outside the range 0-65535
         */
        public static Options withPort(int port) {
            // Fail fast on an invalid TCP port instead of only when the server socket is created.
            if (port < 0 || port > 65535) {
                throw new IllegalArgumentException("Port must be in the range 0-65535, got " + port);
            }
            return new Options(port);
        }

        /**
         * Specifies the backlog size, i.e. the size of the client connection queue. If the queue is full, a client
         * connection is rejected.
         *
         * @param backlog the backlog size, which must be greater than 0
         * @return this options instance
         * @throws IllegalArgumentException if {@code backlog} is not greater than 0
         */
        public Options andBacklog(int backlog) {
            if (backlog <= 0) throw new IllegalArgumentException("Backlog must be > 0");
            this.backlog = backlog;
            return this;
        }

        /**
         * Specifies a logger.
         *
         * @param logger the logger instance
         * @return this options instance
         */
        public Options andLogger(Logger logger) {
            this.logger = logger;
            return this;
        }

        /**
         * Specifies the address to use when creating the server socket.
         *
         * @param address the address to bind to
         * @return this options instance
         */
        public Options andAddress(InetAddress address) {
            this.address = address;
            return this;
        }

        /**
         * Specifies the maximum frame size. The maximum frame size must be at least 126, as it doesn't make much
         * sense to create frame fragments smaller than that.
         *
         * @param size maximum frame size, at least 126
         * @return this options instance
         * @throws IllegalArgumentException if {@code size} is less than 126
         */
        public Options andMaxFrameSize(int size) {
            if (size <= 125) throw new IllegalArgumentException("Max frame size must be at least 126.");
            this.maxFrameSize = size;
            return this;
        }

        /**
         * Configures the server for SSL.
         *
         * @param sslContext the SSL context that creates an SSL socket
         * @return this options instance
         * @throws IllegalArgumentException if {@code sslContext} is {@code null}
         */
        public Options andSSL(SSLContext sslContext) {
            if (sslContext == null) throw new IllegalArgumentException("SSL context cannot be null.");
            this.sslContext = sslContext;
            return this;
        }
    }

    /**
     * Log level for logging.
     */
    public enum LogLevel {
        TRACE(0),
        DEBUG(10),
        INFO(20),
        WARN(50),
        ERROR(100);

        /**
         * Numeric level corresponding to this log level. ERROR has the highest level and TRACE has the lowest.
         */
        public final int level;

        LogLevel(int level) {
            this.level = level;
        }
    }

    /**
     * A minimal logging abstraction that the server can be configured with through {@code Options.andLogger}.
     */
    public interface Logger {
        /**
         * Reports whether messages at {@code level} should be logged.
         *
         * @param level the log level being queried
         * @return {@code true} to enable logging at {@code level}, {@code false} otherwise
         */
        boolean isEnabledAt(LogLevel level);

        /**
         * Records a message at the given level. This method is not called for a level at which
         * {@link #isEnabledAt(LogLevel)} returns {@code false}.
         *
         * @param level the severity of the message
         * @param message the text to log
         * @param error an optional associated error
         */
        void log(LogLevel level, String message, Throwable error);
    }

    /**
     * Represents a WebSocket client and exposes methods that makes it possible to interact with the client, as well
     * as methods for getting information about the client and how it requested the handled resource.
     *
     * Methods on this interface can be invoked from any thread.
     */
    public interface WebSocketClient {
        /**
         * Sends a ping to the client. This can be used to send keep-alive messages to the client.
         *
         * @throws IOException on I/O failure while sending the ping
         */
        void ping() throws IOException;

        /**
         * Performs a clean close of the connection to the client.
         */
        void close();

        /**
         * Sends a text message to the client.
         *
         * @param text the text to send
         * @throws IOException on I/O failure while sending
         */
        void sendTextMessage(CharSequence text) throws IOException;

        /**
         * Sends binary data to the client.
         *
         * @param data the data to send
         * @throws IOException on I/O failure while sending
         */
        void sendBinaryData(byte[] data) throws IOException;

        /**
         * Returns the value of the User-Agent header passed by the client when requesting a Websocket connection. If no
         * User-Agent header was present, returns {@code null}.
         *
         * @return the Host header value, or {@code null}
         */
        String userAgent();

        /**
         * Returns the value of the Host header passed by the client when requesting a Websocket connection. If no
         * Host header was present, returns {@code null}.
         *
         * @return the Host header value, or {@code null}
         */
        String host();

        /**
         * Returns the query (part after '?', not including the fragment) used by the client when requesting a WebSocket
         * connection. If no query was specified, returns {@code null}. If an empty query was specified, returns the
         * empty string. The '?' character is never included.
         *
         * @return a query string, or {@code null}
         * @see <a href="https://tools.ietf.org/html/rfc3986#section-3">Uniform Resource Identifier (URI): Generic Syntax</a>
         */
        String query();

        /**
         * Returns the fragment (part after '#') used by the client when requesting a WebSocket connection. If no
         * fragment was specified, returns {@code null}. If an empty fragment was specified, returns the empty string.
         * The '#' character is never included.
         *
         * @return a fragment string, or {@code null}
         * @see <a href="https://tools.ietf.org/html/rfc3986#section-3">Uniform Resource Identifier (URI): Generic Syntax</a>
         */
        String fragment();
    }

    /**
     * A connection to a non-WebSocket endpoint.
     */
    public interface Connection {
        /**
         * The method (e.g. GET or POST) that the client used to request the endpoint.
         */
        String method();

        /**
         * The URI of the endpoint, including any Host header sent by the client.
         */
        URI uri();

        /**
         * The input stream through which a fallback handler can read data sent by the client.
         */
        InputStream inputStream();

        /**
         * The output stream through which a fallback handler can send data to the client.
         */
        OutputStream outputStream();

        /**
         * All header names sent by the client.
         */
        Iterable<String> headerNames();

        /**
         * The value (if present) of a header sent by the client.
         *
         * @param name case-insensitive, non-{@code null} header name
         * @return an {@code Optional} with the header value if the specified header was sent by the client,
         * otherwise an empty {@code Optional}
         */
        Optional<String> header(String name);

        /**
         * Sends an HTTP response to the client. The server will add a 'Server' header, but otherwise the fallback
         * handler needs to set all headers correctly (e.g. Content-Length).
         *
         * @param statusCode the status code
         * @param reason the status message/reason (e.g. "OK" in "200 OK"); {@code null} is treated as the empty string
         * @param headers headers to send to the client, can be {@code null} to send no headers (except Server)
         */
        void sendResponse(int statusCode, String reason, Map<String, String> headers);
    }

    /**
     * Takes care of connections to endpoints that have no registered WebSocket handler. This allows those
     * endpoints to be served as, for example, regular HTTP endpoints.
     */
    public interface FallbackHandler {
        /**
         * Deals with a connection to an endpoint for which no WebSocket handler has been registered.
         *
         * @param connection gives access to the request details and the client's streams
         * @throws IOException if an I/O error occurs
         */
        void handle(Connection connection) throws IOException;
    }

    /**
     * The fallback handler a new server starts with: it rejects every request by throwing a
     * {@link FileNotFoundException} that names the requested path.
     */
    static class DefaultFallbackHandler implements FallbackHandler {
        @Override
        public void handle(Connection connection) throws IOException {
            String requestedPath = connection.uri().getPath();
            throw new FileNotFoundException("Unknown endpoint: " + requestedPath);
        }
    }

    /**
     * A handler for a WebSocket client connection. A new handler instance will be created for each connected client.
     * Handlers are invoked on the handler executor passed to the {@code Server} constructor.
     */
    public interface WebSocketHandler {
        /**
         * Invoked right after construction.
         *
         * @param client instance for interacting with the client
         */
        void onOpened(WebSocketClient client);

        /**
         * Invoked when the client closes the connection in an orderly fashion.
         *
         * @param code the close code the client used, if any, otherwise 1000 (Normal Closure).
         * @param reason the close reason the client used, if any. May be {@code null}.
         */
        void onClosedByClient(int code, String reason);

        /**
         * Invoked when the server closes the connection in an orderly fashion, most likely because of an error.
         *
         * @param code the close code sent to the client
         * @param reason the close reason sent to the client
         */
        void onClosedByServer(int code, String reason);

        /**
         * Invoked when the connection is closed abruptly because of an error.
         *
         * @param t the error that occurred
         */
        void onFailure(Throwable t);

        /**
         * Invoked when the client sends a text message.
         *
         * @param text the message
         */
        void onTextMessage(CharSequence text);

        /**
         * Invoked when the client sends binary data.
         *
         * @param data the data
         */
        void onBinaryData(byte[] data);
    }
}
