/*
 * Decompiled with CFR 0.152.
 */
package org.javagrader;

import com.google.gson.Gson;
import com.google.gson.JsonIOException;
import java.io.PrintStream;
import java.lang.management.ManagementFactory;
import java.lang.reflect.Method;
import java.security.InvalidParameterException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.BiConsumer;
import org.javagrader.Allow;
import org.javagrader.Allows;
import org.javagrader.Classpath;
import org.javagrader.CustomGradingResult;
import org.javagrader.Forbid;
import org.javagrader.Forbids;
import org.javagrader.Grade;
import org.javagrader.PrintConstants;
import org.javagrader.RestrictedClassLoader;
import org.javagrader.TestClassResult;
import org.javagrader.TestMethodResult;
import org.javagrader.TestResultStatus;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.DynamicTestInvocationContext;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.InvocationInterceptor;
import org.junit.jupiter.api.extension.ReflectiveInvocationContext;
import org.junit.jupiter.api.extension.TestWatcher;
import org.junit.jupiter.api.function.Executable;
import org.junit.platform.commons.util.ReflectionUtils;
import org.opentest4j.TestAbortedException;

/**
 * JUnit 5 extension that grades the tests of classes or methods annotated with {@link Grade}.
 *
 * <p>For every test it:
 * <ul>
 *   <li>measures the CPU time consumed by the thread running the test callbacks and fails the test
 *       with a {@link TimeoutException} when it exceeds {@code @Grade(cpuTimeout)};</li>
 *   <li>when neither the test class nor the method carries a JUnit {@link Timeout}, additionally
 *       enforces a wall-clock limit of {@link #WALL_CLOCK_FACTOR} times the CPU timeout;</li>
 *   <li>unless imports are unrestricted ({@code @Grade(noRestrictedImport = true)} or an
 *       {@code @Allow("all")}), skips the normal invocation and re-runs the test on a fresh instance
 *       loaded through a {@link RestrictedClassLoader} built from the {@link Forbid}/{@link Allow}
 *       annotations;</li>
 *   <li>records the outcome of {@code @Grade}-annotated tests per top-level test class.</li>
 * </ul>
 * The collected results and the summed timeouts are printed from {@link #close()}, which JUnit calls
 * when the root extension store holding the registered instance is closed.
 *
 * <p>NOTE(review): results and timeout sums are kept in unsynchronised static fields; parallel test
 * execution is presumably not supported — confirm before enabling it.
 */
public class GraderExtension
implements BeforeTestExecutionCallback,
AfterTestExecutionCallback,
TestWatcher,
BeforeAllCallback,
ExtensionContext.Store.CloseableResource,
InvocationInterceptor {
    /** Guards the one-time registration performed in {@link #beforeAll}. */
    private static final Lock LOCK = new ReentrantLock();
    /** Set once an instance has been registered in the root store (guarded by {@link #LOCK}). */
    private static boolean started = false;
    /** {@code System.out} as it was when this class was initialised; restored before printing the report. */
    private static final PrintStream originalStdOut = System.out;
    private PrintConstants.PrintMode printMode = PrintConstants.PrintMode.RST;
    /** Accumulated CPU timeouts of all executed tests; printed by {@link #close()}. */
    private static Duration sumMaxCpuTimeout = Duration.ZERO;
    /** Accumulated wall-clock timeouts of all executed tests; printed by {@link #close()}. */
    private static Duration sumMaxTimeout = Duration.ZERO;
    // Keys of the per-test-method store entries written in beforeTestExecution and read back in afterTestExecution.
    private static final String START_TIME = "start time";
    private static final String CPU_TIMEOUT = "cpu timeout";
    private static final String TIMEOUT_UNIT = "timeout units";
    /** The wall-clock limit derived from a {@code @Grade} CPU timeout is this multiple of it. */
    private static final long WALL_CLOCK_FACTOR = 3L;
    /** Recorded results, keyed by the display name of the top-level test class. */
    private static final Map<String, TestClassResult> testClassResult = new HashMap<String, TestClassResult>();

    /**
     * Creates an extension configured by {@code builder}.
     *
     * @param builder source of the print mode
     */
    public GraderExtension(GraderBuilder builder) {
        this.printMode = builder.printMode;
    }

    /** Creates an extension that prints its report in {@link PrintConstants.PrintMode#RST} mode. */
    public GraderExtension() {
    }

    /** @return a new builder for a programmatically registered extension */
    public static GraderBuilder builder() {
        return new GraderBuilder();
    }

    /**
     * Sets how the final report is printed ({@code NONE} suppresses it).
     *
     * @param printMode the print mode to use
     * @return this extension, for chaining
     */
    public GraderExtension setPrintMode(PrintConstants.PrintMode printMode) {
        this.printMode = printMode;
        return this;
    }

    /** @return CPU time of the current thread in nanoseconds, as reported by the {@code ThreadMXBean} */
    private long getCpuTime() {
        return ManagementFactory.getThreadMXBean().getCurrentThreadCpuTime();
    }

    /**
     * Registers the first instance in the root store, so that JUnit calls {@link #close()} on it
     * (and thus prints the report) once, when the root context is closed.
     */
    @Override
    public void beforeAll(ExtensionContext context) {
        LOCK.lock();
        try {
            if (!started) {
                started = true;
                context.getRoot().getStore(ExtensionContext.Namespace.GLOBAL).put("any unique name", this);
            }
        } finally {
            LOCK.unlock();
        }
    }

    /**
     * Stores the applicable CPU timeout, its unit and the current thread CPU time, then adds this
     * test's timeouts to the running sums.
     */
    @Override
    public void beforeTestExecution(ExtensionContext context) {
        ExtensionContext.Store store = this.getStore(context);
        Grade g = this.getBestAnnotationForTimeout(context);
        if (g != null) {
            store.put(CPU_TIMEOUT, g.cpuTimeout());
            store.put(TIMEOUT_UNIT, g.unit());
        } else {
            // 0 disables the CPU check in afterTestExecution.
            store.put(CPU_TIMEOUT, 0L);
            store.put(TIMEOUT_UNIT, null);
        }
        store.put(START_TIME, this.getCpuTime());
        this.updateSumTimeouts(context);
    }

    /** Converts {@code value} expressed in {@code unit} to a {@link Duration}. */
    private Duration durationOf(long value, TimeUnit unit) {
        return Duration.ofNanos(unit.toNanos(value));
    }

    /**
     * Wall-clock limit derived from a {@code @Grade} CPU timeout: {@link #WALL_CLOCK_FACTOR} times it,
     * saturating at {@code Long.MAX_VALUE} instead of overflowing to a negative value.
     */
    private Duration wallClockTimeoutOf(Grade g) {
        long cpuTimeout = g.cpuTimeout();
        long timeout = cpuTimeout > Long.MAX_VALUE / WALL_CLOCK_FACTOR ? Long.MAX_VALUE : cpuTimeout * WALL_CLOCK_FACTOR;
        return this.durationOf(timeout, g.unit());
    }

    /**
     * Adds this test's CPU timeout to {@link #sumMaxCpuTimeout} and its wall-clock timeout (the JUnit
     * {@link Timeout} if present, else the value derived from the CPU timeout) to {@link #sumMaxTimeout}.
     */
    private void updateSumTimeouts(ExtensionContext context) {
        Grade g = this.getExistingGradeWithCpuTimeout(context);
        if (g != null) {
            sumMaxCpuTimeout = sumMaxCpuTimeout.plus(this.durationOf(g.cpuTimeout(), g.unit()));
        }
        Timeout t = this.getExistingGradeWithTimeout(context);
        if (t != null) {
            sumMaxTimeout = sumMaxTimeout.plus(this.durationOf(t.value(), t.unit()));
        } else if (g != null) {
            sumMaxTimeout = sumMaxTimeout.plus(this.wallClockTimeoutOf(g));
        }
    }

    /**
     * @return the method-level {@code @Grade} if it sets a CPU timeout (anything but {@code Long.MAX_VALUE}),
     *         otherwise the class-level one (possibly {@code null})
     */
    private Grade getBestAnnotationForTimeout(ExtensionContext context) {
        Grade gm = context.getRequiredTestMethod().getAnnotation(Grade.class);
        Grade gc = context.getRequiredTestClass().getAnnotation(Grade.class);
        if (gm == null || gm.cpuTimeout() == Long.MAX_VALUE) {
            return gc;
        }
        return gm;
    }

    /**
     * Fails the test if the CPU time consumed since {@link #beforeTestExecution} exceeds the stored
     * CPU timeout.
     *
     * <p>NOTE(review): only the CPU time of the thread running this callback is measured; a test body
     * run on another thread (e.g. {@code SEPARATE_THREAD}) is presumably not accounted for.
     *
     * @throws TimeoutException when the CPU timeout is exceeded
     */
    @Override
    public void afterTestExecution(ExtensionContext context) throws Exception {
        ExtensionContext.Store store = this.getStore(context);
        long startTime = store.remove(START_TIME, Long.class);
        long duration = this.getCpuTime() - startTime;
        long cpuTimeout = store.remove(CPU_TIMEOUT, Long.class);
        // Always removed so no stale entry is left behind, even when the check is disabled.
        TimeUnit unit = store.remove(TIMEOUT_UNIT, TimeUnit.class);
        if (cpuTimeout <= 0L) {
            return;
        }
        long excess = this.nanoSecondToTimeUnit(unit, duration) - cpuTimeout;
        if (excess > 0L) {
            throw new TimeoutException(String.format("Execution exceeded CPU timeout of %s by %s", GraderExtension.getTimeoutMessage(unit, cpuTimeout), GraderExtension.getTimeoutMessage(unit, excess)));
        }
    }

    /** Converts a nanosecond amount to {@code unit} (truncating). */
    private long nanoSecondToTimeUnit(TimeUnit unit, long value) {
        return unit.convert(value, TimeUnit.NANOSECONDS);
    }

    /**
     * Formats {@code value} with its lower-cased unit name, singularised when {@code value == 1}
     * (e.g. "1 second", "5 milliseconds"). Lower-casing uses {@link Locale#ROOT} so the output does not
     * depend on the default locale.
     */
    private static String getTimeoutMessage(TimeUnit unit, long value) {
        String label = unit.name().toLowerCase(Locale.ROOT);
        if (value == 1L && label.endsWith("s")) {
            label = label.substring(0, label.length() - 1);
        }
        return value + " " + label;
    }

    /** @return the store namespaced by this extension class and the current test method */
    private ExtensionContext.Store getStore(ExtensionContext context) {
        return context.getStore(ExtensionContext.Namespace.create(this.getClass(), context.getRequiredTestMethod()));
    }

    @Override
    public void testSuccessful(ExtensionContext context) {
        this.addTestResult(context, TestResultStatus.SUCCESS);
    }

    @Override
    public void testAborted(ExtensionContext context, Throwable cause) {
        this.addTestResult(context, TestResultStatus.ABORTED);
    }

    @Override
    public void testDisabled(ExtensionContext context, Optional<String> reason) {
        this.addTestResult(context, TestResultStatus.DISABLED);
    }

    /**
     * Records a failed test. A {@link CustomGradingResult} is honoured only when the test (class or
     * method) declares {@code @Grade(custom = true)}; a {@link TimeoutException} becomes
     * {@link TestResultStatus#TIMEOUT}; anything else is {@link TestResultStatus#FAIL}.
     */
    @Override
    public void testFailed(ExtensionContext context, Throwable cause) {
        // Compared by name rather than instanceof, presumably because the exception may be an instance of
        // a class defined by another class loader (e.g. the RestrictedClassLoader used in intercept).
        if (CustomGradingResult.class.getName().equals(cause.getClass().getName())) {
            Grade gradedMethod = context.getRequiredTestMethod().getAnnotation(Grade.class);
            Grade gradedClass = context.getRequiredTestClass().getAnnotation(Grade.class);
            boolean customAllowed = gradedClass != null && gradedClass.custom() || gradedMethod != null && gradedMethod.custom();
            if (customAllowed) {
                CustomGradingResult custom = CustomGradingResult.fromSerialization(cause);
                if (custom != null) {
                    this.addTestResult(context, custom.status, custom);
                } else {
                    System.out.println("Failed to interpret the custom grading result " + cause);
                    this.addTestResult(context, TestResultStatus.FAIL);
                }
            } else {
                System.out.println("WARNING: Received a CustomGradingResult exception while not expecting one.");
                System.out.println("If you are trying to solve this exercise: sadly, there is a protection against this ;-)");
                System.out.println("If you are the exercise creator, you probably forgot to put custom=true inside @Grade.");
                this.addTestResult(context, TestResultStatus.FAIL);
            }
        } else if (cause instanceof TimeoutException) {
            this.addTestResult(context, TestResultStatus.TIMEOUT);
        } else {
            this.addTestResult(context, TestResultStatus.FAIL);
        }
    }

    /**
     * Records the result of the current test, if its class or method is annotated with {@link Grade},
     * under the display name of its top-level test class.
     *
     * @param customGradingResult custom result reported by the test, or {@code null}
     */
    private void addTestResult(ExtensionContext context, TestResultStatus status, CustomGradingResult customGradingResult) {
        Method m = context.getRequiredTestMethod();
        Class<?> c = context.getRequiredTestClass();
        if (!c.isAnnotationPresent(Grade.class) && !m.isAnnotationPresent(Grade.class)) {
            return;
        }
        String methodName = this.getFactoryTestPrefix(context) + context.getDisplayName();
        String className = this.getClassDisplayName(context);
        TestMethodResult r = new TestMethodResult(methodName, m, status, customGradingResult);
        testClassResult.computeIfAbsent(className, name -> new TestClassResult(name, c)).addTestMethodResult(r);
    }

    private void addTestResult(ExtensionContext context, TestResultStatus status) {
        this.addTestResult(context, status, null);
    }

    /**
     * Walks up the context tree and returns the display name of the context directly below the root,
     * i.e. the outermost test class for regular, nested and dynamic tests.
     */
    private String getClassDisplayName(ExtensionContext context) {
        ExtensionContext parent = context.getParent().get();
        if (parent.getParent().isPresent()) {
            return this.getClassDisplayName(parent);
        }
        return context.getDisplayName();
    }

    /**
     * @return {@code "<parent display name> - "} when the test sits at least three levels below the
     *         root (e.g. a dynamic test under its factory method), otherwise the empty string
     */
    private String getFactoryTestPrefix(ExtensionContext context) {
        ExtensionContext parent = context.getParent().get();
        Optional<ExtensionContext> grandParent = parent.getParent();
        if (grandParent.isPresent() && grandParent.get().getParent().isPresent()) {
            return String.format("%s - ", parent.getDisplayName());
        }
        return "";
    }

    /** Restores the original stdout and prints the timeout sums and the grade table, unless mode is NONE. */
    @Override
    public void close() {
        if (this.printMode != PrintConstants.PrintMode.NONE) {
            System.setOut(originalStdOut);
            this.printSumTimeouts();
            this.printTable();
        }
    }

    private void printSumTimeouts() {
        System.out.println("Max timeout = " + this.formatDuration(sumMaxTimeout));
        System.out.println("Max cpu timeout = " + this.formatDuration(sumMaxCpuTimeout));
    }

    /** Formats {@code d} as {@code H:MM:SS}; sub-second parts are dropped. */
    private String formatDuration(Duration d) {
        int partialMinutes = (int)(d.toMinutes() % 60L);
        int partialSeconds = (int)(d.getSeconds() % 60L);
        return String.format("%d:%02d:%02d", d.toHours(), partialMinutes, partialSeconds);
    }

    /**
     * Prints every recorded class result followed by the totals.
     *
     * @throws IllegalArgumentException for an unsupported print mode
     */
    private void printTable() {
        System.out.println("--- GRADE ---");
        if (this.printMode == PrintConstants.PrintMode.RST) {
            System.out.println(".. csv-table::\n    :header: \"Test\", \"Status\", \"Grade\", \"Comment\"\n    :widths: auto\n    ");
        }
        double grade = 0.0;
        double maxGrade = 0.0;
        double gradeWithoutAborted = 0.0;
        double maxWithoutAborted = 0.0;
        for (TestClassResult test : testClassResult.values()) {
            System.out.println(test.format(this.printMode));
            grade += test.grade();
            maxGrade += test.maxGrade();
            // The same earned grade is counted in both totals; only the maximum excludes aborted tests.
            gradeWithoutAborted += test.grade();
            maxWithoutAborted += test.maxGradeWithoutAborted();
        }
        String prefix = PrintConstants.globalPrefix(this.printMode);
        String sep = PrintConstants.separator(this.printMode, PrintConstants.SeparatorsType.CONTENT);
        switch (this.printMode) {
            case NONE:
            case NORMAL:
                break;
            case RST: {
                String classPrefix = PrintConstants.separator(this.printMode, PrintConstants.SeparatorsType.CLASS_PREFIX);
                String classSuffix = PrintConstants.separator(this.printMode, PrintConstants.SeparatorsType.CLASS_SUFFIX);
                System.out.printf("%s%sTOTAL%s%s%s%s%s%s%n", prefix, classPrefix, classSuffix, sep, sep, "**", PrintConstants.formatGrade(grade, maxGrade), "**");
                System.out.printf("%s%sTOTAL WITHOUT ABORTED%s%s%s%s%s%s%n%n", prefix, classPrefix, classSuffix, sep, sep, "**", PrintConstants.formatGrade(gradeWithoutAborted, maxWithoutAborted), "**");
                break;
            }
            default:
                throw new IllegalArgumentException("Unrecognized printing mode " + this.printMode);
        }
        System.out.printf("TOTAL %s%n", PrintConstants.formatGrade(grade, maxGrade));
        // NOTE(review): labelled "IGNORED" here but "ABORTED" in the RST row above; kept as-is since output may be parsed.
        System.out.printf("TOTAL WITHOUT IGNORED %s%n", PrintConstants.formatGrade(gradeWithoutAborted, maxWithoutAborted));
        System.out.println("--- END GRADE ---");
    }

    @Override
    public void interceptTestMethod(InvocationInterceptor.Invocation<Void> invocation, ReflectiveInvocationContext<Method> invocationContext, ExtensionContext extensionContext) throws Throwable {
        this.intercept(invocation, invocationContext, extensionContext);
    }

    @Override
    public <T> T interceptTestFactoryMethod(InvocationInterceptor.Invocation<T> invocation, ReflectiveInvocationContext<Method> invocationContext, ExtensionContext extensionContext) throws Throwable {
        // Must name the interface: a bare `super.` would refer to Object, which has no such method.
        return InvocationInterceptor.super.interceptTestFactoryMethod(invocation, invocationContext, extensionContext);
    }

    @Override
    public void interceptDynamicTest(InvocationInterceptor.Invocation<Void> invocation, DynamicTestInvocationContext invocationContext, ExtensionContext extensionContext) throws Throwable {
        InvocationInterceptor.super.interceptDynamicTest(invocation, invocationContext, extensionContext);
    }

    @Override
    public void interceptTestTemplateMethod(InvocationInterceptor.Invocation<Void> invocation, ReflectiveInvocationContext<Method> invocationContext, ExtensionContext extensionContext) throws Throwable {
        this.intercept(invocation, invocationContext, extensionContext);
    }

    /**
     * Runs a test method or template invocation, either directly or through the restricted class
     * loader, under the wall-clock limit derived from {@code @Grade(cpuTimeout)}. That limit is only
     * applied when no JUnit {@link Timeout} is declared on the test class or method.
     */
    private void intercept(InvocationInterceptor.Invocation<Void> invocation, ReflectiveInvocationContext<Method> invocationContext, ExtensionContext extensionContext) throws Throwable {
        Timeout cTimeout = extensionContext.getRequiredTestClass().getAnnotation(Timeout.class);
        Timeout mTimeout = extensionContext.getRequiredTestMethod().getAnnotation(Timeout.class);
        Grade g = null;
        if (cTimeout == null && mTimeout == null) {
            g = this.getExistingGradeWithCpuTimeout(extensionContext);
        }
        if (this.areImportRestricted(extensionContext)) {
            // The JUnit-provided invocation is skipped; the test is re-run on an instance from the restricted loader.
            invocation.skip();
            this.invokeWithRestrictedClassLoader(invocationContext, extensionContext, g);
        } else if (g != null) {
            this.runWithWallClockTimeout(g, invocation::proceed);
        } else {
            invocation.proceed();
        }
    }

    /**
     * Executes {@code body} under the wall-clock limit derived from {@code g}: in the calling thread
     * for {@code INFERRED}/{@code SAME_THREAD}, preemptively in a separate thread for {@code SEPARATE_THREAD}.
     */
    private void runWithWallClockTimeout(Grade g, Executable body) {
        Duration duration = this.wallClockTimeoutOf(g);
        switch (g.threadMode()) {
            case INFERRED:
            case SAME_THREAD:
                Assertions.assertTimeout(duration, body);
                break;
            case SEPARATE_THREAD:
                Assertions.assertTimeoutPreemptively(duration, body);
                break;
        }
    }

    /**
     * Loads the test class through a {@link RestrictedClassLoader} built from the forbidden/allowed
     * sets, copies the invocation arguments into that loader and invokes the matching test method on
     * a fresh instance. The thread context class loader is swapped for the duration and always restored.
     *
     * <p>NOTE(review): the fresh instance is created via reflection only, so lifecycle methods
     * (e.g. {@code @BeforeEach}) presumably do not run on it — confirm this is intended.
     *
     * @param g grade providing the wall-clock limit, or {@code null} for none
     */
    private void invokeWithRestrictedClassLoader(ReflectiveInvocationContext<Method> invocationContext, ExtensionContext extensionContext, Grade g) throws Throwable {
        Classpath classpath = Classpath.current();
        Set<String> forbids = this.getForbidden(extensionContext);
        Set<String> allows = this.getAllowed(extensionContext);
        RestrictedClassLoader modifiedClassLoader = classpath.newClassloader(forbids, allows);
        ClassLoader currentThreadPreviousClassLoader = this.replaceCurrentThreadClassLoader(modifiedClassLoader);
        String className = extensionContext.getRequiredTestClass().getName();
        try {
            Class<?> testClass = loadTestClass(modifiedClassLoader, className);
            Method m = extensionContext.getRequiredTestMethod();
            String methodName = m.getName();
            Object testInstance = ReflectionUtils.newInstance(testClass);
            Class<?>[] paramTypes = m.getParameterTypes();
            for (Class<?> clazz : paramTypes) {
                if (modifiedClassLoader.isForbidden(clazz.getName())) {
                    throw new ClassNotFoundException("Failed to load the test instance as its contains a " + clazz.getName() + " parameter, which is forbidden");
                }
            }
            Optional<Method> method = this.findTestMethod(testInstance.getClass(), methodName, paramTypes);
            List<Object> convertedArgs = this.convertArguments(invocationContext.getArguments(), modifiedClassLoader, forbids);
            Executable body = () -> ReflectionUtils.invokeMethod(
                    method.orElseThrow(() -> new IllegalStateException("No test method named " + methodName + " for class " + testClass)),
                    testInstance,
                    convertedArgs.toArray());
            if (g != null) {
                this.runWithWallClockTimeout(g, body);
            } else {
                body.execute();
            }
        } catch (Exception e1) {
            // e1 may be an instance of a class defined by the restricted loader, so instanceof cannot be used.
            // Comparing names avoids resolving the class again, which could fail and mask the real error.
            if (TestAbortedException.class.getName().equals(e1.getClass().getName())) {
                throw new TestAbortedException(e1.getMessage(), e1);
            }
            throw e1;
        } finally {
            Thread.currentThread().setContextClassLoader(currentThreadPreviousClassLoader);
        }
    }

    /** Loads {@code className} from {@code loader}, converting a miss into an {@link IllegalStateException}. */
    private static Class<?> loadTestClass(ClassLoader loader, String className) {
        try {
            return loader.loadClass(className);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("Cannot load test class [" + className + "] from modified classloader, verify that you did not exclude a path containing the test", e);
        }
    }

    /**
     * Finds the method of {@code testClass} named {@code methodName} whose parameter types match
     * {@code paramTypes} by name (they come from a different class loader).
     *
     * @return the method, or empty when no such method exists
     * @throws ClassNotFoundException when reflecting over the class hits a class that cannot be loaded
     */
    private Optional<Method> findTestMethod(Class<?> testClass, String methodName, Class<?>[] paramTypes) throws ClassNotFoundException {
        try {
            List<Method> candidates = ReflectionUtils.findMethods(testClass,
                    p -> p.getName().equals(methodName) && this.hasSameParameterTypes(p.getParameterTypes(), paramTypes));
            if (candidates.isEmpty()) {
                return Optional.empty();
            }
            Method first = candidates.get(0);
            return ReflectionUtils.findMethod(testClass, first.getName(), first.getParameterTypes());
        } catch (NoClassDefFoundError e) {
            throw new ClassNotFoundException("Failed to load the test instance. It may contain a method with a forbidden parameter type", e);
        }
    }

    /** @return whether both parameter lists have the same length and pairwise {@link #isSameClass} types */
    private boolean hasSameParameterTypes(Class<?>[] candidate, Class<?>[] expected) {
        if (candidate.length != expected.length) {
            return false;
        }
        for (int i = 0; i < expected.length; ++i) {
            if (!this.isSameClass(candidate[i], expected[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Copies each invocation argument into an instance of the same-named class from {@code loader},
     * via a JSON round trip. {@code null} arguments stay {@code null}.
     *
     * @throws InvalidParameterException when Gson cannot convert an argument
     * @throws ClassNotFoundException    when an argument's class cannot be loaded by {@code loader}
     */
    private List<Object> convertArguments(List<Object> arguments, ClassLoader loader, Set<String> forbids) throws ClassNotFoundException {
        List<Object> converted = new ArrayList<>(arguments.size());
        for (Object arg : arguments) {
            if (arg == null) {
                converted.add(null);
                continue;
            }
            String toLoad = arg.getClass().getName();
            try {
                Class<?> caster = loader.loadClass(toLoad);
                converted.add(this.castObj(arg, caster));
            } catch (JsonIOException e) {
                // NOTE(review): the message mentions @Grade(noSecurity = true), while this class reads noRestrictedImport() — confirm the attribute name.
                String message = String.format("Failed to provide argument of class %s (failed conversion to student's class loader using Gson. Consider using @Grade(noSecurity = true), use a simpler object type as input (perhaps you has cyclic references) or refer to Gson doc to make it compatible)", toLoad);
                InvalidParameterException invalid = new InvalidParameterException(message);
                invalid.initCause(e);
                throw invalid;
            } catch (ClassNotFoundException e) {
                String message = forbids.contains(toLoad) ? String.format("Failed to load class %s as it is forbidden", toLoad) : String.format("Failed to load class %s for unknown reason. Consider using @Grade(noSecurity = true)", toLoad);
                throw new ClassNotFoundException(message, e);
            }
        }
        return converted;
    }

    /** @return the method-level, else class-level, {@code @Grade} whose cpuTimeout is set (not {@code Long.MAX_VALUE}); or {@code null} */
    private Grade getExistingGradeWithCpuTimeout(ExtensionContext extensionContext) {
        Grade mGrade = extensionContext.getRequiredTestMethod().getAnnotation(Grade.class);
        if (mGrade != null && mGrade.cpuTimeout() != Long.MAX_VALUE) {
            return mGrade;
        }
        Grade cGrade = extensionContext.getRequiredTestClass().getAnnotation(Grade.class);
        if (cGrade != null && cGrade.cpuTimeout() != Long.MAX_VALUE) {
            return cGrade;
        }
        return null;
    }

    /** @return {@code true} unless a {@code @Grade(noRestrictedImport = true)} or an {@code @Allow("all")} applies */
    private boolean areImportRestricted(ExtensionContext extensionContext) {
        Grade g = this.getExistingGradeWithoutRestrictedImport(extensionContext);
        if (g != null) {
            return false;
        }
        return !this.getAllowed(extensionContext).contains("all");
    }

    /** @return the method-level, else class-level, {@code @Grade} with {@code noRestrictedImport = true}; or {@code null} */
    private Grade getExistingGradeWithoutRestrictedImport(ExtensionContext extensionContext) {
        Grade mGrade = extensionContext.getRequiredTestMethod().getAnnotation(Grade.class);
        if (mGrade != null && mGrade.noRestrictedImport()) {
            return mGrade;
        }
        Grade cGrade = extensionContext.getRequiredTestClass().getAnnotation(Grade.class);
        if (cGrade != null && cGrade.noRestrictedImport()) {
            return cGrade;
        }
        return null;
    }

    /** @return the method-level, else class-level, JUnit {@link Timeout}; or {@code null} */
    private Timeout getExistingGradeWithTimeout(ExtensionContext extensionContext) {
        Timeout mTimeout = extensionContext.getRequiredTestMethod().getAnnotation(Timeout.class);
        if (mTimeout != null) {
            return mTimeout;
        }
        return extensionContext.getRequiredTestClass().getAnnotation(Timeout.class);
    }

    /** Installs {@code modifiedClassLoader} as the thread context class loader and returns the previous one. */
    private ClassLoader replaceCurrentThreadClassLoader(ClassLoader modifiedClassLoader) {
        ClassLoader currentThreadPreviousClassLoader = Thread.currentThread().getContextClassLoader();
        Thread.currentThread().setContextClassLoader(modifiedClassLoader);
        return currentThreadPreviousClassLoader;
    }

    /** @return the union of all {@link Forbid}/{@link Forbids} values on the test class and method */
    private Set<String> getForbidden(ExtensionContext extensionContext) {
        Class<?> testClass = extensionContext.getRequiredTestClass();
        Method testMethod = extensionContext.getRequiredTestMethod();
        Forbid cForbid = testClass.getAnnotation(Forbid.class);
        Forbid mForbid = testMethod.getAnnotation(Forbid.class);
        Forbids cForbids = testClass.getAnnotation(Forbids.class);
        Forbids mForbids = testMethod.getAnnotation(Forbids.class);
        Set<String> forbidden = new HashSet<String>();
        if (cForbid != null) {
            forbidden.add(cForbid.value());
        }
        if (mForbid != null) {
            forbidden.add(mForbid.value());
        }
        if (cForbids != null) {
            for (Forbid value : cForbids.value()) {
                forbidden.add(value.value());
            }
        }
        if (mForbids != null) {
            for (Forbid value : mForbids.value()) {
                forbidden.add(value.value());
            }
        }
        return forbidden;
    }

    /** @return the union of all {@link Allow}/{@link Allows} values on the test class and method */
    private Set<String> getAllowed(ExtensionContext extensionContext) {
        Class<?> testClass = extensionContext.getRequiredTestClass();
        Method testMethod = extensionContext.getRequiredTestMethod();
        Allow cAllow = testClass.getAnnotation(Allow.class);
        Allow mAllow = testMethod.getAnnotation(Allow.class);
        Allows cAllows = testClass.getAnnotation(Allows.class);
        Allows mAllows = testMethod.getAnnotation(Allows.class);
        Set<String> allowed = new HashSet<String>();
        if (cAllow != null) {
            allowed.add(cAllow.value());
        }
        if (mAllow != null) {
            allowed.add(mAllow.value());
        }
        if (cAllows != null) {
            for (Allow value : cAllows.value()) {
                allowed.add(value.value());
            }
        }
        if (mAllows != null) {
            for (Allow value : mAllows.value()) {
                allowed.add(value.value());
            }
        }
        return allowed;
    }

    /** Re-creates {@code o} as an instance of {@code target} by serialising it to JSON and back with Gson. */
    private <T> T castObj(Object o, Class<T> target) throws JsonIOException {
        Gson gson = new Gson();
        return gson.fromJson(gson.toJson(o), target);
    }

    /**
     * @return whether both classes are equal or print the same (i.e. have the same name), which lets
     *         classes defined by different class loaders match
     */
    private boolean isSameClass(Class<?> origin, Class<?> target) {
        return origin.equals(target) || origin.toString().equals(target.toString());
    }

    /** Builder for a programmatically registered {@link GraderExtension}. */
    public static class GraderBuilder {
        private PrintConstants.PrintMode printMode = PrintConstants.PrintMode.RST;

        /**
         * @param printMode how the final report is printed
         * @return this builder
         */
        public GraderBuilder printMode(PrintConstants.PrintMode printMode) {
            this.printMode = printMode;
            return this;
        }

        /** @return a new extension with the configured settings */
        public GraderExtension build() {
            return new GraderExtension(this);
        }
    }
}

