/*
 * Copyright © 2019 Benny Bottema (benny@bennybottema.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.bbottema.javasocksproxyserver;

import org.bbottema.javasocksproxyserver.auth.Authenticator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ServerSocketFactory;
import java.io.Closeable;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayList;
import java.util.List;

/**
 * SOCKS proxy server that listens on a single port and hands each accepted client connection to a
 * {@link ProxyHandler} running on its own thread.
 * <p>
 * {@link #start()} launches a background listener thread. {@link #stop()} stops every listener started by
 * this instance and closes its listening socket, so the port is released right away and a following
 * {@link #start()} can bind it again. Public methods are {@code synchronized} on this instance.
 */
public class SocksServer {

	private static final Logger LOGGER = LoggerFactory.getLogger(SocksServer.class);

	/** Port used by the no-argument constructor. */
	private static final int DEFAULT_LISTEN_PORT = 1080;

	private int listenPort;

	private ServerSocketFactory factory;
	private Authenticator authenticator = null;

	/**
	 * Listener loops started by {@link #start()} that have not been stopped or finished yet.
	 * Guarded by {@code this}.
	 */
	private final List<ServerProcess> runningProcesses = new ArrayList<>();

	/** Creates a server for port 1080 that uses {@link ServerSocketFactory#getDefault()}. */
	public SocksServer() {
		this(DEFAULT_LISTEN_PORT);
	}

	/**
	 * Creates a server for the given port that uses {@link ServerSocketFactory#getDefault()}.
	 *
	 * @param listenPort port to bind when {@link #start()} is called
	 */
	public SocksServer(int listenPort) {
		this(listenPort, ServerSocketFactory.getDefault());
	}

	/**
	 * Creates a server for the given port that creates its listening socket with {@code factory}.
	 *
	 * @param listenPort port to bind when {@link #start()} is called
	 * @param factory    factory used to create the listening {@link ServerSocket}
	 */
	public SocksServer(int listenPort, ServerSocketFactory factory) {
		this.listenPort = listenPort;
		this.factory = factory;
	}

	/**
	 * Sets the authenticator given to listeners started by later {@link #start()} calls. Listeners that are
	 * already running keep the authenticator they were started with.
	 *
	 * @param authenticator authenticator passed on to each {@link ProxyHandler}; the initial value is
	 *                      {@code null} (how a {@code null} value is treated is up to {@link ProxyHandler})
	 * @return this server, for chaining
	 */
	public synchronized SocksServer setAuthenticator(Authenticator authenticator) {
		this.authenticator = authenticator;
		return this;
	}

	/**
	 * Sets the listen port, resets the socket factory to {@link ServerSocketFactory#getDefault()} and starts
	 * the server.
	 *
	 * @deprecated pass the port to the constructor and call {@link #start()} instead.
	 */
	@Deprecated
	public synchronized void start(int port) {
		start(port, ServerSocketFactory.getDefault());
	}

	/**
	 * Sets the listen port and socket factory, then starts the server.
	 *
	 * @deprecated pass the port and factory to the constructor and call {@link #start()} instead.
	 */
	@Deprecated
	public synchronized void start(int port, ServerSocketFactory factory) {
		listenPort = port;
		this.factory = factory;
		start();
	}

	/**
	 * Starts a new background thread that listens on the configured port. It uses the current factory and
	 * authenticator. Calling this again without {@link #stop()} starts another listener.
	 *
	 * @return this server, for chaining
	 */
	public synchronized SocksServer start() {
		final ServerProcess process = new ServerProcess(listenPort, factory, authenticator);
		new Thread(process).start();
		// Registering after start() is safe: the new thread must take this lock before it can deregister.
		runningProcesses.add(process);
		return this;
	}

	/**
	 * Stops every listener started by this instance and closes its listening socket, which releases the
	 * port immediately. Does not wait for the listener threads to exit. Does not touch connections that
	 * were already handed to a {@link ProxyHandler}.
	 *
	 * @return this server, for chaining
	 */
	public synchronized SocksServer stop() {
		for (ServerProcess process : runningProcesses) {
			process.requestStop();
		}
		runningProcesses.clear();
		return this;
	}

	/** Closes {@code closeable} if it is non-null, logging (not propagating) any failure. */
	private static void closeQuietly(Closeable closeable) {
		if (closeable == null) {
			return;
		}
		try {
			closeable.close();
		} catch (IOException e) {
			LOGGER.debug("Ignoring failure while closing {}", closeable, e);
		}
	}

	/** One listener loop: binds a server socket and accepts clients until it is asked to stop. */
	private class ServerProcess implements Runnable {

		protected final int port;
		private final ServerSocketFactory serverSocketFactory;
		private final Authenticator authenticator;

		/** Set by {@link #requestStop()}; each process has its own flag, so a later start() cannot clear it. */
		private volatile boolean stopRequested = false;

		/** Listening socket once created, so {@link #requestStop()} can close it. Guarded by SocksServer.this. */
		private ServerSocket listenSocket;

		public ServerProcess(int port, ServerSocketFactory serverSocketFactory, Authenticator authenticator) {
			this.port = port;
			this.serverSocketFactory = serverSocketFactory;
			this.authenticator = authenticator;
		}

		/**
		 * Asks this loop to end. Closing the socket makes a blocked {@code accept()} fail at once instead
		 * of waiting for its timeout. The caller must hold the {@code SocksServer.this} lock.
		 */
		void requestStop() {
			stopRequested = true;
			closeQuietly(listenSocket);
		}

		@Override
		public void run() {
			LOGGER.debug("SOCKS server started...");
			try {
				handleClients(port);
				LOGGER.debug("SOCKS server stopped...");
			} catch (IOException e) {
				// createServerSocket/setSoTimeout failed: report the cause instead of hiding it at debug level.
				LOGGER.error("SOCKS server crashed...", e);
			} finally {
				synchronized (SocksServer.this) {
					runningProcesses.remove(this);
				}
			}
		}

		/**
		 * Binds the listening socket and accepts clients until {@link #requestStop()} is called. The socket
		 * is always closed before returning, even on failure.
		 *
		 * @param port port to bind
		 * @throws IOException if the socket cannot be created or configured
		 */
		protected void handleClients(int port) throws IOException {
			final ServerSocket serverSocket = serverSocketFactory.createServerSocket(port);
			synchronized (SocksServer.this) {
				if (stopRequested) {
					// stop() ran before this socket existed, so it could not close it for us.
					closeQuietly(serverSocket);
					return;
				}
				listenSocket = serverSocket;
			}
			try {
				// Bounded accept() so the loop re-checks the stop flag periodically.
				serverSocket.setSoTimeout(SocksConstants.LISTEN_TIMEOUT);
				LOGGER.debug("SOCKS server listening at port: {}", serverSocket.getLocalPort());
				while (!stopRequested) {
					handleNextClient(serverSocket);
				}
			} finally {
				closeQuietly(serverSocket);
			}
		}

		/**
		 * Waits (up to the socket's SO_TIMEOUT) for one client. On success it sets the client's read timeout
		 * and starts a {@link ProxyHandler} thread for it. If the hand-off fails, the client socket is
		 * closed, not leaked.
		 */
		private void handleNextClient(ServerSocket serverSocket) {
			Socket clientSocket = null;
			boolean handedOff = false;
			try {
				clientSocket = serverSocket.accept();
				clientSocket.setSoTimeout(SocksConstants.DEFAULT_SERVER_TIMEOUT);
				LOGGER.debug("Connection from : {}", Utils.getSocketInfo(clientSocket));
				new Thread(new ProxyHandler(clientSocket, authenticator)).start();
				handedOff = true;
			} catch (InterruptedIOException e) {
				// accept() timed out; the caller re-checks the stop flag and calls us again.
			} catch (Exception e) {
				// After requestStop() closes the socket, accept() fails by design; only log genuine errors.
				if (!stopRequested) {
					LOGGER.error(e.getMessage(), e);
				}
			} finally {
				if (!handedOff) {
					closeQuietly(clientSocket);
				}
			}
		}
	}
}