package com.stringee.network.tcpclient;

import android.os.Build;
import android.util.Base64;
import android.util.Log;

import com.stringee.common.Utils;
import com.stringee.network.tcpclient.packet.Packet;
import com.stringee.network.tcpclient.ssl.SSLManager;

import org.json.JSONObject;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.security.cert.X509Certificate;

/**
 * @author Alex
 */

public class TcpClient extends Thread {

    private InetSocketAddress serverAddress;
    private Socket socket;
    private InputStream inputStream;
    private OutputStream outputStream;
    private int connectTimeout = 15000;// 15s
    private HandlerBase handler;
    private static final int READ_BUFFER_SIZE = 0x500000;// 1MB
    private final ByteBuffer readBuffer = ByteBuffer.allocateDirect(READ_BUFFER_SIZE);
    private long lastPacketReceivedTime = System.currentTimeMillis();

    //SSL
    private static SSLSocketFactory socketFactory = null;
    private boolean useSsl;
    private String sslAcceptedHostname;
    private static HostnameVerifier hostnameVerifier;
    //static
    private static final boolean sslDevelopmentMode = true;

    private boolean connectCompleted = false;
    private final boolean trustAll;
    private final List<String> publicKeys;

    private static final String TAG = "Stringee";

    public void sslInit(List<StringeeCertificate> certificates, boolean trustAll) {
        if (!sslDevelopmentMode) {
            socketFactory = (SSLSocketFactory) SSLSocketFactory.getDefault();
        } else {
            try {
                SSLManager sslManager = new SSLManager(certificates, trustAll);
                socketFactory = sslManager.getSocketFactory();
            } catch (NoSuchAlgorithmException | KeyManagementException ex) {
                Utils.reportException(TcpClient.class, ex);
            }
        }

        hostnameVerifier = HttpsURLConnection.getDefaultHostnameVerifier();
    }

    /**
     * TCP Client
     *
     * @param handler             handler
     * @param hostname            Server hostname
     * @param port                Server port
     * @param useSsl              true = use TLS/SSL
     * @param sslAcceptedHostname accept hostname
     */
    public TcpClient(HandlerBase handler, String hostname, int port, boolean useSsl, String sslAcceptedHostname, List<StringeeCertificate> certificates, boolean trustAll, List<String> publicKeys) {
        this(handler, new InetSocketAddress(hostname, port), useSsl, sslAcceptedHostname, certificates, trustAll, publicKeys);
    }

    /**
     * @param handler             handler
     * @param address             Server address
     * @param useSsl              true = use TLS/SSL
     * @param sslAcceptedHostname accept hostname
     */
    public TcpClient(HandlerBase handler, InetSocketAddress address, boolean useSsl, String sslAcceptedHostname, List<StringeeCertificate> certificates, boolean trustAll, List<String> publicKeys) {
        super("io-thread");
        this.serverAddress = address;
        this.handler = handler;
        this.useSsl = useSsl;
        this.sslAcceptedHostname = sslAcceptedHostname;
        this.trustAll = trustAll;
        this.publicKeys = publicKeys;

        if (useSsl) {
            if (socketFactory == null) {
                sslInit(certificates, trustAll);
            }

            try {
                socket = socketFactory.createSocket();
            } catch (IOException ex) {
                Logger.getLogger(TcpClient.class.getName()).log(Level.SEVERE, null, ex);
            }
        } else {
            socket = new Socket();
        }
    }

    /**
     * Connect to server
     *
     * @param timeout timeout
     * @return true: connected; false: NOT connected
     */
    public boolean connect(int timeout) {
        connectTimeout = timeout;

        try {
            Log.d(TAG, "+++++++Connecting SDK Server...");
            socket.connect(serverAddress, timeout);

            if (useSsl) {
                SSLSocket sslSocket = (SSLSocket) socket;
                sslSocket.addHandshakeCompletedListener(hce -> Log.d(TAG, "++++++++++++++++handshakeCompleted"));
                sslSocket.startHandshake();

                //check hostname
                Log.d(TAG, "PeerPrincipal" + sslSocket.getSession().getPeerPrincipal());
                boolean hostnameOk = false;
                if (trustAll || !Utils.isEmpty(publicKeys)) {
                    if (trustAll) {
                        hostnameOk = true;
                    } else {
                        try {
                            for (X509Certificate certificate : sslSocket.getSession().getPeerCertificateChain()) {
                                if (!Utils.isEmpty(publicKeys)) {
                                    for (int i = 0; i < publicKeys.size(); i++) {
                                        String publicKey = publicKeys.get(i);
                                        publicKey = publicKey.trim().replaceAll("\n", "");
                                        if (Base64.encodeToString(certificate.getPublicKey().getEncoded(), Base64.DEFAULT).trim().replaceAll("\n", "").equals(publicKey)) {
                                            hostnameOk = true;
                                            break;
                                        }
                                    }
                                }
                                if (hostnameOk) {
                                    break;
                                }
                            }
                        } catch (SSLPeerUnverifiedException e) {
                            throw new RuntimeException(e);
                        }
                    }
                } else {
                    hostnameOk = hostnameVerifier.verify(sslAcceptedHostname, sslSocket.getSession());
                }
                if (!hostnameOk) {
                    //close
                    if (socket.isConnected() && !socket.isClosed()) {
                        try {
                            socket.close();
                        } catch (IOException ex1) {
                            Logger.getLogger(TcpClient.class.getName()).log(Level.SEVERE, null, ex1);
                        }
                    }

                    //throw
                    throw new IOException("SSL/TLS - Hostname not valid: " + sslAcceptedHostname + "; PeerPrincipal: " + sslSocket.getSession().getPeerPrincipal());
                }
            }

            inputStream = socket.getInputStream();
            outputStream = socket.getOutputStream();

            connectCompleted = true;

            return true;
        } catch (IOException ex) {
            // Connection refused | Connection timed out | UnknownHostException | Handshake
            // (DNS) | loi get input/output stream
            Log.d(TAG, "+++++++Error 1: " + serverAddress.getHostName() + " :" + serverAddress.getPort());
            if (ex instanceof SSLHandshakeException) {
                if (!Utils.isEmpty(ex.getMessage())) {
                    if (ex.getMessage().startsWith("Unacceptable certificate")) {
                        handler.onUnacceptedCertificate();
                    }
                }
            }
            // Logger.getLogger(TcpClient.class.getName()).log(Level.SEVERE, null, ex);
        } catch (IllegalArgumentException ex) {
            Log.d(TAG, "+++++++Error 2: " + ex.getMessage());
            // Logger.getLogger(TcpClient.class.getName()).log(Level.SEVERE, null, ex);
        }

        return false;
    }

    @Override
    public void run() {
        boolean successfully = connect(connectTimeout);
        if (successfully && socket.isConnected()) {
            handler.onConnected(this);

            //read data from socket
            while (socket.isConnected()) {
                try {
                    byte[] buffer = new byte[4096];

                    int read = inputStream.read(buffer);
                    if (read > 0) {
                        readBuffer.put(buffer, 0, read);

                        readBuffer.flip();
                        handler.onRead(this, readBuffer);
                        readBuffer.compact();
                    } else if (read < 0) {
                        // end of the stream has been reached
                        // System.out.println("--------end of the stream has been reached (server
                        // close socket)-----");
                        break;
                    }
                } catch (IOException ex) {
                    Utils.reportException(TcpClient.class, ex);
                    break;
                }
            }
        }

        disconnect();
        handler.onDisconnected(this);
    }

    public boolean isConnected() {
        return socket != null && connectCompleted && socket.isConnected() && !socket.isClosed();
    }

    public void disconnect() {
        if (socket.isConnected() && !socket.isClosed()) {
            try {
                socket.close();
            } catch (IOException ex1) {
                Logger.getLogger(TcpClient.class.getName()).log(Level.SEVERE, null, ex1);
            }
        }
        if (inputStream != null) {
            try {
                inputStream.close();
            } catch (IOException ex) {
                Logger.getLogger(TcpClient.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        if (outputStream != null) {
            try {
                outputStream.close();
            } catch (IOException ex) {
                Logger.getLogger(TcpClient.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    public int getConnectTimeout() {
        return connectTimeout;
    }

    public void setConnectTimeout(int connectTimeout) {
        this.connectTimeout = connectTimeout;
    }

    public boolean send(byte[] bytes, int offset, int length) {
        if (outputStream == null) {
            return false;
        }

        if (socket.isConnected()) {
            try {
                outputStream.write(bytes, offset, length);
                return true;
            } catch (IOException ex) {
                Logger.getLogger(TcpClient.class.getName()).log(Level.SEVERE, null, ex);
            }
        }

        return false;
    }

    public boolean send(Packet packet) {
        if (packet.getJsonData() != null && packet.getLength() == 0) {
            JSONObject json = packet.getJsonData();
            if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.KITKAT) {
                packet.setData(json.toString().getBytes(StandardCharsets.UTF_8));
            } else {
                try {
                    packet.setData(json.toString().getBytes("UTF-8"));
                } catch (UnsupportedEncodingException e) {
                    return false;
                }
            }
        }

        ByteBuffer buff = ByteBuffer.allocate(packet.getLength() + Packet.HEADER_LENGTH);
        // HEADER_LENGTH
        buff.put(Packet.MAGIC);
        buff.putInt(packet.getLength());
        buff.putShort(packet.getService());
        buff.put(packet.getData());
        buff.flip();

        boolean sent = send(buff.array(), 0, buff.remaining());
        Log.d(TAG, "Send: " + packet);
        if (sent) {
            handler.onWriteCompleted(packet, this);
        }
        return sent;
    }

    public Socket getSocket() {
        return socket;
    }

    public HandlerBase getHandler() {
        return handler;
    }

    public void setHandler(HandlerBase handler) {
        this.handler = handler;
    }

    public long getLastPacketReceivedTime() {
        return lastPacketReceivedTime;
    }

    public void setLastPacketReceivedTime(long lastPacketReceivedTime) {
        this.lastPacketReceivedTime = lastPacketReceivedTime;
    }

    public String getLocalIp() {
        return socket.getLocalAddress().getHostAddress();
    }

    public boolean isConnectCompleted() {
        return connectCompleted;
    }

    public InetSocketAddress getServerAddress() {
        return serverAddress;
    }

    public void setServerAddress(InetSocketAddress serverAddress) {
        this.serverAddress = serverAddress;
    }

    public boolean isUseSsl() {
        return useSsl;
    }

    public void setUseSsl(boolean useSsl) {
        this.useSsl = useSsl;
    }

    public String getSslAcceptedHostname() {
        return sslAcceptedHostname;
    }

    public void setSslAcceptedHostname(String sslAcceptedHostname) {
        this.sslAcceptedHostname = sslAcceptedHostname;
    }
}
