/*
 * This file is part of MinecraftAuth - https://github.com/RaphiMC/MinecraftAuth
 * Copyright (C) 2022-2024 RK_01/RaphiMC and contributors
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 3 of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package net.raphimc.minecraftauth.step.msa;

import com.google.gson.JsonObject;
import net.lenni0451.commons.httpclient.HttpClient;
import net.lenni0451.commons.httpclient.content.impl.URLEncodedFormContent;
import net.lenni0451.commons.httpclient.requests.impl.PostRequest;
import net.raphimc.minecraftauth.MinecraftAuth;
import net.raphimc.minecraftauth.responsehandler.MsaResponseHandler;
import net.raphimc.minecraftauth.step.AbstractStep;
import net.raphimc.minecraftauth.util.JsonUtil;
import java.time.Instant;
import java.time.ZoneId;
import java.util.HashMap;
import java.util.Map;

public class StepMsaToken extends AbstractStep<MsaCodeStep.MsaCode, StepMsaToken.MsaToken> {
    public StepMsaToken(final AbstractStep<?, MsaCodeStep.MsaCode> prevStep) {
        super("msaToken", prevStep);
    }

    @Override
    public MsaToken applyStep(final HttpClient httpClient, final MsaCodeStep.MsaCode msaCode) throws Exception {
        return this.apply(httpClient, msaCode.getCode(), msaCode.getApplicationDetails().getRedirectUri() != null ? "authorization_code" : "refresh_token", msaCode);
    }

    @Override
    public MsaToken refresh(final HttpClient httpClient, final MsaToken msaToken) throws Exception {
        if (!msaToken.isExpired()) {
            return msaToken;
        } else if (msaToken.getRefreshToken() != null) {
            return this.apply(httpClient, msaToken.getRefreshToken(), "refresh_token", msaToken.getMsaCode());
        } else {
            return super.refresh(httpClient, msaToken);
        }
    }

    @Override
    public MsaToken fromJson(final JsonObject json) {
        final MsaCodeStep.MsaCode msaCode = this.prevStep != null ? this.prevStep.fromJson(json.getAsJsonObject(this.prevStep.name)) : null;
        return new MsaToken(json.get("expireTimeMs").getAsLong(), json.get("accessToken").getAsString(), JsonUtil.getStringOr(json, "refreshToken", null), JsonUtil.getStringOr(json, "idToken", null), msaCode);
    }

    @Override
    public JsonObject toJson(final MsaToken msaToken) {
        final JsonObject json = new JsonObject();
        json.addProperty("expireTimeMs", msaToken.expireTimeMs);
        json.addProperty("accessToken", msaToken.accessToken);
        json.addProperty("refreshToken", msaToken.refreshToken);
        json.addProperty("idToken", msaToken.idToken);
        if (this.prevStep != null) json.add(this.prevStep.name, this.prevStep.toJson(msaToken.msaCode));
        return json;
    }

    private MsaToken apply(final HttpClient httpClient, final String code, final String type, final MsaCodeStep.MsaCode msaCode) throws Exception {
        final MsaCodeStep.ApplicationDetails applicationDetails = msaCode.getApplicationDetails();
        MinecraftAuth.LOGGER.info("Getting MSA Token...");
        final Map<String, String> postData = new HashMap<>();
        postData.put("client_id", applicationDetails.getClientId());
        postData.put("scope", applicationDetails.getScope());
        postData.put("grant_type", type);
        if (type.equals("refresh_token")) {
            postData.put("refresh_token", code);
        } else {
            postData.put("code", code);
            postData.put("redirect_uri", applicationDetails.getRedirectUri());
        }
        if (applicationDetails.getClientSecret() != null) {
            postData.put("client_secret", applicationDetails.getClientSecret());
        }
        final PostRequest postRequest = new PostRequest(applicationDetails.getOAuthEnvironment().getTokenUrl());
        postRequest.setContent(new URLEncodedFormContent(postData));
        final JsonObject obj = httpClient.execute(postRequest, new MsaResponseHandler());
        final MsaToken msaToken = new MsaToken(System.currentTimeMillis() + obj.get("expires_in").getAsLong() * 1000, obj.get("access_token").getAsString(), JsonUtil.getStringOr(obj, "refresh_token", null), JsonUtil.getStringOr(obj, "id_token", null), msaCode);
        MinecraftAuth.LOGGER.info("Got MSA Token, expires: " + Instant.ofEpochMilli(msaToken.getExpireTimeMs()).atZone(ZoneId.systemDefault()));
        return msaToken;
    }


    public static final class MsaToken extends AbstractStep.StepResult<MsaCodeStep.MsaCode> {
        private final long expireTimeMs;
        private final String accessToken;
        private final String refreshToken;
        private final String idToken;
        private final MsaCodeStep.MsaCode msaCode;

        @Override
        protected MsaCodeStep.MsaCode prevResult() {
            return this.msaCode;
        }

        @Override
        public boolean isExpired() {
            return this.expireTimeMs <= System.currentTimeMillis();
        }

        public MsaToken(final long expireTimeMs, final String accessToken, final String refreshToken, final String idToken, final MsaCodeStep.MsaCode msaCode) {
            this.expireTimeMs = expireTimeMs;
            this.accessToken = accessToken;
            this.refreshToken = refreshToken;
            this.idToken = idToken;
            this.msaCode = msaCode;
        }

        public long getExpireTimeMs() {
            return this.expireTimeMs;
        }

        public String getAccessToken() {
            return this.accessToken;
        }

        public String getRefreshToken() {
            return this.refreshToken;
        }

        public String getIdToken() {
            return this.idToken;
        }

        public MsaCodeStep.MsaCode getMsaCode() {
            return this.msaCode;
        }

        @Override
        public String toString() {
            return "StepMsaToken.MsaToken(expireTimeMs=" + this.getExpireTimeMs() + ", accessToken=" + this.getAccessToken() + ", refreshToken=" + this.getRefreshToken() + ", idToken=" + this.getIdToken() + ", msaCode=" + this.getMsaCode() + ")";
        }

        @Override
        public boolean equals(final Object o) {
            if (o == this) return true;
            if (!(o instanceof StepMsaToken.MsaToken)) return false;
            final StepMsaToken.MsaToken other = (StepMsaToken.MsaToken) o;
            if (!other.canEqual((Object) this)) return false;
            if (this.getExpireTimeMs() != other.getExpireTimeMs()) return false;
            final Object this$accessToken = this.getAccessToken();
            final Object other$accessToken = other.getAccessToken();
            if (this$accessToken == null ? other$accessToken != null : !this$accessToken.equals(other$accessToken)) return false;
            final Object this$refreshToken = this.getRefreshToken();
            final Object other$refreshToken = other.getRefreshToken();
            if (this$refreshToken == null ? other$refreshToken != null : !this$refreshToken.equals(other$refreshToken)) return false;
            final Object this$idToken = this.getIdToken();
            final Object other$idToken = other.getIdToken();
            if (this$idToken == null ? other$idToken != null : !this$idToken.equals(other$idToken)) return false;
            final Object this$msaCode = this.getMsaCode();
            final Object other$msaCode = other.getMsaCode();
            if (this$msaCode == null ? other$msaCode != null : !this$msaCode.equals(other$msaCode)) return false;
            return true;
        }

        protected boolean canEqual(final Object other) {
            return other instanceof StepMsaToken.MsaToken;
        }

        @Override
        public int hashCode() {
            final int PRIME = 59;
            int result = 1;
            final long $expireTimeMs = this.getExpireTimeMs();
            result = result * PRIME + (int) ($expireTimeMs >>> 32 ^ $expireTimeMs);
            final Object $accessToken = this.getAccessToken();
            result = result * PRIME + ($accessToken == null ? 43 : $accessToken.hashCode());
            final Object $refreshToken = this.getRefreshToken();
            result = result * PRIME + ($refreshToken == null ? 43 : $refreshToken.hashCode());
            final Object $idToken = this.getIdToken();
            result = result * PRIME + ($idToken == null ? 43 : $idToken.hashCode());
            final Object $msaCode = this.getMsaCode();
            result = result * PRIME + ($msaCode == null ? 43 : $msaCode.hashCode());
            return result;
        }
    }
}
