/*
 * Copyright (c) 2020 pCloud AG
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.pcloud.networking.client;

import static com.pcloud.utils.IOUtils.closeQuietly;

import com.pcloud.networking.protocol.BytesReader;
import com.pcloud.networking.protocol.BytesWriter;
import com.pcloud.networking.protocol.DataSource;
import com.pcloud.networking.protocol.ForwardingProtocolRequestWriter;
import com.pcloud.networking.protocol.ForwardingProtocolResponseReader;
import com.pcloud.networking.protocol.ProtocolRequestWriter;
import com.pcloud.networking.protocol.ProtocolResponseReader;
import com.pcloud.networking.protocol.TypeToken;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

import okio.BufferedSink;

/**
 * An {@link ApiChannel} backed by a single {@link Connection} obtained from a
 * {@link ConnectionProvider}.
 * <p>
 * The connection's sink and source are wrapped in counting writer/reader decorators. These track
 * how many requests and responses have been started and completed. The counters let
 * {@link #close(boolean)} decide whether the underlying connection is clean enough to be handed
 * back to the provider for reuse, or whether it must be closed instead.
 * <p>
 * Every writer/reader operation first checks that the channel has not been closed. If it has,
 * the operation fails with a {@link ClosedChannelException}. Closing may race with reads/writes
 * running on other threads. The closed flag and the counters are atomic, and they are re-checked
 * around state transitions to tolerate such races.
 */
class RealApiChannel implements ApiChannel {

    private final ConnectionProvider connectionProvider;
    private Connection connection;
    private final CountingProtocolRequestWriter writer;
    private final CountingProtocolResponseReader reader;
    private final Endpoint endpoint;
    private final AtomicBoolean closed = new AtomicBoolean(false);

    /**
     * Obtains a connection for {@code endpoint} from {@code connectionProvider} and wraps its
     * sink and source in counting writer/reader decorators.
     *
     * @param connectionProvider the provider the connection is obtained from and, on a clean
     *                           close, recycled back to
     * @param endpoint           the endpoint to obtain a connection for
     * @throws IOException if obtaining the connection fails
     */
    RealApiChannel(ConnectionProvider connectionProvider, Endpoint endpoint) throws IOException {
        this.connectionProvider = connectionProvider;
        this.connection = connectionProvider.obtainConnection(endpoint);
        // Report the endpoint the connection actually targets, as given by the connection itself.
        this.endpoint = connection.endpoint();
        this.writer = new CountingProtocolRequestWriter(new BytesWriter(connection.sink()), this);
        this.reader = new CountingProtocolResponseReader(new BytesReader(connection.source()), this);
    }

    /**
     * @return the endpoint reported by the underlying connection
     */
    @Override
    public Endpoint endpoint() {
        return endpoint;
    }

    /**
     * @return the counting, close-aware response reader of this channel
     */
    @Override
    public ProtocolResponseReader reader() {
        return reader;
    }

    /**
     * @return the counting, close-aware request writer of this channel
     */
    @Override
    public ProtocolRequestWriter writer() {
        return writer;
    }

    /**
     * @return {@code true} until {@link #close(boolean)} has been invoked
     */
    @Override
    public boolean isOpen() {
        return !closed.get();
    }

    /**
     * Reports whether no request or response is currently in flight.
     * <p>
     * The channel is idle when every written request has been fully completed, every
     * completed request has a fully consumed response, and no further request or response
     * has been started.
     *
     * @return {@code true} if the channel holds no partially written request and no
     * unread or partially read response
     */
    @Override
    public boolean isIdle() {
        // Snapshot the completion counters and require them to be equal. Then re-read the
        // "started" counters and require them to match the snapshots.
        // This rejects a request that has begun but has not yet been ended, so a half-written
        // request is never mistaken for an idle channel.
        // It also catches any request or response begun while this check was running.
        long sentRequests = writer.getCompletedRequests();
        long completedResponses = reader.getCompletedResponses();
        return sentRequests == completedResponses
                && writer.getStartedRequests() == sentRequests
                && reader.getStartedResponses() == completedResponses;
    }

    /**
     * Closes the channel. Only the first invocation has any effect.
     * <p>
     * If {@code attemptConnectionReuse} is {@code true} and the channel is idle
     * (see {@link #isIdle()}), the connection is returned to the {@link ConnectionProvider}.
     * Otherwise it is closed quietly.
     *
     * @param attemptConnectionReuse whether to try recycling the connection instead of closing it
     */
    @Override
    public void close(boolean attemptConnectionReuse) {
        if (closed.compareAndSet(false, true)) {
            if (attemptConnectionReuse && isIdle()) {
                connectionProvider.recycleConnection(connection);
            } else {
                closeQuietly(connection);
            }
            connection = null;
        }
    }

    /**
     * @throws ClosedChannelException if this channel has already been closed
     */
    private void checkNotClosed() throws ClosedChannelException {
        if (closed.get()) {
            throw new ClosedChannelException();
        }
    }

    /**
     * Request writer decorator that rejects every operation once the owning channel is closed.
     * It counts started ({@link #beginRequest()}) and completed ({@link #endRequest()}) requests.
     * Closing it closes the owning channel.
     */
    private static class CountingProtocolRequestWriter extends ForwardingProtocolRequestWriter {

        private final RealApiChannel apiChannel;
        private final AtomicLong startedRequests = new AtomicLong(0L);
        private final AtomicLong completedRequests = new AtomicLong(0L);

        CountingProtocolRequestWriter(ProtocolRequestWriter delegate, RealApiChannel apiChannel) {
            super(delegate);
            this.apiChannel = apiChannel;
        }

        @Override
        public ProtocolRequestWriter beginRequest() throws IOException {
            apiChannel.checkNotClosed();
            startRequest();
            // Check again to avoid any races if another thread
            // is running RealApiChannel.close() at this moment.
            apiChannel.checkNotClosed();
            super.beginRequest();
            return this;
        }

        @Override
        public ProtocolRequestWriter writeData(DataSource source) throws IOException {
            apiChannel.checkNotClosed();
            super.writeData(source);
            return this;
        }

        @Override
        public ProtocolRequestWriter writeMethodName(String name) throws IOException {
            apiChannel.checkNotClosed();
            super.writeMethodName(name);
            return this;
        }

        @Override
        public ProtocolRequestWriter writeName(String name) throws IOException {
            apiChannel.checkNotClosed();
            super.writeName(name);
            return this;
        }

        @Override
        public ProtocolRequestWriter writeValue(Object value) throws IOException {
            apiChannel.checkNotClosed();
            super.writeValue(value);
            return this;
        }

        @Override
        public ProtocolRequestWriter writeValue(String value) throws IOException {
            apiChannel.checkNotClosed();
            super.writeValue(value);
            return this;
        }

        @Override
        public ProtocolRequestWriter writeValue(double value) throws IOException {
            apiChannel.checkNotClosed();
            super.writeValue(value);
            return this;
        }

        @Override
        public ProtocolRequestWriter writeValue(float value) throws IOException {
            apiChannel.checkNotClosed();
            super.writeValue(value);
            return this;
        }

        @Override
        public ProtocolRequestWriter writeValue(long value) throws IOException {
            apiChannel.checkNotClosed();
            super.writeValue(value);
            return this;
        }

        @Override
        public ProtocolRequestWriter writeValue(boolean value) throws IOException {
            apiChannel.checkNotClosed();
            super.writeValue(value);
            return this;
        }

        /**
         * Closes the owning channel, which decides how to release the held resources.
         */
        @Override
        public void close() {
            apiChannel.close();
        }

        @Override
        public ProtocolRequestWriter endRequest() throws IOException {
            apiChannel.checkNotClosed();
            // Mark the request as complete before actually completing it. This avoids a race
            // between closing the channel and finishing the write to the underlying sink.
            // If the channel is closed in between, the response counters still lag behind, so
            // the connection will not be reused unless every request/response pair is finished.
            completeRequest();
            // Check again for closure to avoid touching the underlying writer if already closed.
            apiChannel.checkNotClosed();
            super.endRequest();
            return this;
        }

        /**
         * @return the number of requests for which {@link #beginRequest()} was called
         */
        public long getStartedRequests() {
            return startedRequests.get();
        }

        /**
         * @return the number of requests for which {@link #endRequest()} was called
         */
        public long getCompletedRequests() {
            return completedRequests.get();
        }

        private void startRequest() {
            startedRequests.incrementAndGet();
        }

        private void completeRequest() {
            completedRequests.incrementAndGet();
        }
    }

    /**
     * Response reader decorator that rejects every operation once the owning channel is closed.
     * It counts started and completed responses.
     * <p>
     * A response counts as started on {@link #beginResponse()}. It counts as completed either
     * when {@link #endResponse()} reports no trailing data, or after that trailing data has been
     * read through one of the {@code readData} methods. Closing the reader closes the owning
     * channel.
     */
    private static class CountingProtocolResponseReader extends ForwardingProtocolResponseReader {

        private final RealApiChannel apiChannel;
        private final AtomicLong startedResponses;
        private final AtomicLong completedResponses;

        CountingProtocolResponseReader(ProtocolResponseReader delegate, RealApiChannel apiChannel) {
            this(delegate, apiChannel, 0L, 0L);
        }

        private CountingProtocolResponseReader(
                ProtocolResponseReader delegate,
                RealApiChannel apiChannel,
                long startedResponses,
                long completedResponses) {
            super(delegate);
            this.apiChannel = apiChannel;
            this.startedResponses = new AtomicLong(startedResponses);
            this.completedResponses = new AtomicLong(completedResponses);
        }

        @Override
        public long beginResponse() throws IOException {
            apiChannel.checkNotClosed();
            startResponse();
            // Check again to avoid any races if another thread is running `RealApiChannel.close()`
            // at this moment. If the channel is closed by the time we get here, checking again
            // ensures that the underlying ProtocolResponseReader will not be touched.
            apiChannel.checkNotClosed();
            return super.beginResponse();
        }

        @Override
        public boolean endResponse() throws IOException {
            apiChannel.checkNotClosed();
            boolean hasData = super.endResponse();
            // A response followed by data is only complete once that data has been read.
            if (!hasData) {
                completeResponse();
            }
            return hasData;
        }

        @Override
        public void readData(BufferedSink sink) throws IOException {
            apiChannel.checkNotClosed();
            super.readData(sink);
            completeResponse();
        }

        @Override
        public void readData(OutputStream outputStream) throws IOException {
            apiChannel.checkNotClosed();
            super.readData(outputStream);
            completeResponse();
        }

        @Override
        public TypeToken peek() throws IOException {
            apiChannel.checkNotClosed();
            return super.peek();
        }

        @Override
        public void beginObject() throws IOException {
            apiChannel.checkNotClosed();
            super.beginObject();
        }

        @Override
        public void beginArray() throws IOException {
            apiChannel.checkNotClosed();
            super.beginArray();
        }

        @Override
        public void endArray() throws IOException {
            apiChannel.checkNotClosed();
            super.endArray();
        }

        @Override
        public void endObject() throws IOException {
            apiChannel.checkNotClosed();
            super.endObject();
        }

        @Override
        public boolean readBoolean() throws IOException {
            apiChannel.checkNotClosed();
            return super.readBoolean();
        }

        @Override
        public String readString() throws IOException {
            apiChannel.checkNotClosed();
            return super.readString();
        }

        @Override
        public long readNumber() throws IOException {
            apiChannel.checkNotClosed();
            return super.readNumber();
        }

        @Override
        public void close() {
            // Let the ApiChannel decide how to release held resources.
            apiChannel.close();
        }

        @Override
        public boolean hasNext() throws IOException {
            apiChannel.checkNotClosed();
            return super.hasNext();
        }

        @Override
        public void skipValue() throws IOException {
            apiChannel.checkNotClosed();
            super.skipValue();
        }

        /**
         * Returns the delegate's peeking reader wrapped in a new counting decorator.
         * <p>
         * The new decorator's counters start from this reader's current values. They are
         * separate instances, so activity on the peeking reader does not change this reader's
         * counts.
         */
        @Override
        public ProtocolResponseReader newPeekingReader() {
            return new CountingProtocolResponseReader(
                    super.newPeekingReader(),
                    apiChannel,
                    getStartedResponses(),
                    getCompletedResponses());
        }

        /**
         * @return the number of responses for which {@link #beginResponse()} was called
         */
        public long getStartedResponses() {
            return startedResponses.get();
        }

        /**
         * @return the number of responses fully consumed, including any trailing data
         */
        public long getCompletedResponses() {
            return completedResponses.get();
        }

        private void startResponse() {
            startedResponses.incrementAndGet();
        }

        private void completeResponse() {
            completedResponses.incrementAndGet();
        }
    }
}
