/*
 * Decompiled with CFR 0.152.
 */
package org.javacs.action;

import com.sun.source.tree.ClassTree;
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.LineMap;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.Tree;
import com.sun.source.util.JavacTask;
import com.sun.source.util.TreePath;
import com.sun.source.util.Trees;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.lang.model.element.Element;
import javax.lang.model.element.ElementKind;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.Modifier;
import javax.lang.model.element.TypeElement;
import javax.lang.model.element.VariableElement;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.util.Elements;
import javax.lang.model.util.Types;
import org.javacs.CompileTask;
import org.javacs.CompilerProvider;
import org.javacs.FindTypeDeclarationAt;
import org.javacs.action.FindMethodDeclarationAt;
import org.javacs.lsp.CodeAction;
import org.javacs.lsp.CodeActionParams;
import org.javacs.lsp.Diagnostic;
import org.javacs.lsp.Position;
import org.javacs.lsp.Range;
import org.javacs.lsp.TextEdit;
import org.javacs.lsp.WorkspaceEdit;
import org.javacs.rewrite.AddException;
import org.javacs.rewrite.AddImport;
import org.javacs.rewrite.AddSuppressWarningAnnotation;
import org.javacs.rewrite.ConvertFieldToBlock;
import org.javacs.rewrite.ConvertVariableToStatement;
import org.javacs.rewrite.CreateMissingMethod;
import org.javacs.rewrite.GenerateRecordConstructor;
import org.javacs.rewrite.ImplementAbstractMethods;
import org.javacs.rewrite.OverrideInheritedMethod;
import org.javacs.rewrite.RemoveClass;
import org.javacs.rewrite.RemoveException;
import org.javacs.rewrite.RemoveMethod;
import org.javacs.rewrite.Rewrite;

/**
 * Produces LSP code actions (quick fixes and refactorings) for a single Java source file.
 *
 * <p>{@link #codeActionsForCursor} offers actions based only on the cursor location;
 * {@link #codeActionForDiagnostics} offers quick fixes for the diagnostics attached to the request.
 * Both compile the file through the {@link CompilerProvider} given to the constructor, and every
 * offered action is computed eagerly by running its {@link Rewrite} against that provider.
 */
public class CodeActionProvider {
    private final CompilerProvider compiler;
    // Matches the start of a message such as "'java.io.IOException' is not thrown"; group 1 is the name.
    private static final Pattern NOT_THROWN_EXCEPTION = Pattern.compile("^'((\\w+\\.)*\\w+)' is not thrown");
    // Matches "unreported exception java.io.IOException"; group 1 is the exception name.
    private static final Pattern UNREPORTED_EXCEPTION = Pattern.compile("unreported exception ((\\w+\\.)*\\w+)");
    private static final Logger LOG = Logger.getLogger("main");

    /**
     * @param compiler compiles the inspected file, lists candidate types for imports, and applies rewrites
     */
    public CodeActionProvider(CompilerProvider compiler) {
        this.compiler = compiler;
    }

    /**
     * Returns the refactorings available at {@code params.range.start}, sorted by title.
     *
     * <p>The file is compiled once to decide which rewrites apply; the rewrites themselves are run
     * only after that compile task has been closed.
     *
     * @param params the LSP request; only the document URI and the range start are used
     * @return the applicable code actions, possibly empty
     */
    public List<CodeAction> codeActionsForCursor(CodeActionParams params) {
        LOG.info(String.format("Find code actions at %s(%d)...", params.textDocument.uri.getPath(), params.range.start.line + 1));
        Instant started = Instant.now();
        Path file = Paths.get(params.textDocument.uri);
        // A TreeMap keeps the resulting actions ordered by title.
        TreeMap<String, Rewrite> rewrites = new TreeMap<>();
        try (CompileTask task = this.compiler.compile(file)) {
            long elapsed = Duration.between(started, Instant.now()).toMillis();
            LOG.info(String.format("...compiled in %d ms", elapsed));
            long cursor = this.offsetOf(task, params.range.start);
            rewrites.putAll(this.overrideInheritedMethods(task, file, cursor));
        }
        List<CodeAction> actions = new ArrayList<>();
        for (Map.Entry<String, Rewrite> entry : rewrites.entrySet()) {
            actions.addAll(this.createQuickFix(entry.getKey(), entry.getValue()));
        }
        long elapsed = Duration.between(started, Instant.now()).toMillis();
        LOG.info(String.format("...created %d actions in %d ms", actions.size(), elapsed));
        return actions;
    }

    /**
     * Builds one "Override ..." rewrite for each method the class at {@code cursor} inherits and could
     * override.
     *
     * <p>Nothing is offered unless the text before the cursor on its line is blank and the cursor is
     * outside every method declaration. The following methods are skipped:
     * <ul>
     *   <li>final methods;</li>
     *   <li>static methods, which cannot be overridden by an instance method;</li>
     *   <li>methods declared by {@code java.lang.Object};</li>
     *   <li>methods declared by the class itself.</li>
     * </ul>
     *
     * @return rewrites keyed by action title; empty when the cursor position does not qualify
     */
    private Map<String, Rewrite> overrideInheritedMethods(CompileTask task, Path file, long cursor) {
        if (!this.isBlankLine(task.root(), cursor) || this.isInMethod(task, cursor)) {
            return Map.of();
        }
        ClassTree classTree = (ClassTree) new FindTypeDeclarationAt(task.task).scan(task.root(), Long.valueOf(cursor));
        if (classTree == null) {
            return Map.of();
        }
        Trees trees = Trees.instance(task.task);
        TypeElement classElement = (TypeElement) trees.getElement(trees.getPath(task.root(), classTree));
        Elements elements = task.task.getElements();
        TreeMap<String, Rewrite> actions = new TreeMap<>();
        for (Element element : elements.getAllMembers(classElement)) {
            if (element.getKind() != ElementKind.METHOD
                    || element.getModifiers().contains(Modifier.FINAL)
                    || element.getModifiers().contains(Modifier.STATIC)) {
                continue;
            }
            ExecutableElement method = (ExecutableElement) element;
            TypeElement methodSource = (TypeElement) method.getEnclosingElement();
            if (methodSource.getQualifiedName().contentEquals("java.lang.Object") || methodSource.equals(classElement)) {
                continue;
            }
            MethodPtr ptr = new MethodPtr(task.task, method);
            OverrideInheritedMethod rewrite = new OverrideInheritedMethod(ptr.className, ptr.methodName, ptr.erasedParameterTypes, file, (int) cursor);
            String title = "Override '" + method.getSimpleName() + "' from " + ptr.className;
            actions.put(title, rewrite);
        }
        return actions;
    }

    /** Returns true if {@code cursor} falls inside a method declaration of the compiled file. */
    private boolean isInMethod(CompileTask task, long cursor) {
        MethodTree method = (MethodTree) new FindMethodDeclarationAt(task.task).scan(task.root(), Long.valueOf(cursor));
        return method != null;
    }

    /**
     * Returns true if every character from the start of the cursor's line up to (not including) the
     * cursor is whitespace.
     */
    private boolean isBlankLine(CompilationUnitTree root, long cursor) {
        LineMap lines = root.getLineMap();
        long start = lines.getStartPosition(lines.getLineNumber(cursor));
        CharSequence contents = this.sourceText(root);
        for (long i = start; i < cursor; ++i) {
            if (!Character.isWhitespace(contents.charAt((int) i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns quick fixes for every diagnostic in {@code params.context.diagnostics}, using one
     * compilation of the file for all of them.
     *
     * @return the quick fixes for all diagnostics, in diagnostic order; possibly empty
     */
    public List<CodeAction> codeActionForDiagnostics(CodeActionParams params) {
        LOG.info(String.format("Check %d diagnostics for quick fixes...", params.context.diagnostics.size()));
        Instant started = Instant.now();
        Path file = Paths.get(params.textDocument.uri);
        try (CompileTask task = this.compiler.compile(file)) {
            List<CodeAction> actions = new ArrayList<>();
            for (Diagnostic d : params.context.diagnostics) {
                actions.addAll(this.codeActionForDiagnostic(task, file, d));
            }
            long elapsed = Duration.between(started, Instant.now()).toMillis();
            LOG.info(String.format("...created %d quick fixes in %d ms", actions.size(), elapsed));
            return actions;
        }
    }

    /**
     * Maps one diagnostic, by its {@code code}, to the quick fix(es) that address it.
     *
     * @return the fixes for {@code d}; empty when the code is missing, unrecognized, or the code
     *     element the fix needs (method, class, exception name) cannot be located
     */
    private List<CodeAction> codeActionForDiagnostic(CompileTask task, Path file, Diagnostic d) {
        // The LSP `code` field is optional, and switching on a null String throws NullPointerException.
        if (d.code == null) {
            return List.of();
        }
        switch (d.code) {
            case "unused_local": {
                ConvertVariableToStatement toStatement = new ConvertVariableToStatement(file, this.findPosition(task, d.range.start));
                return this.createQuickFix("Convert to statement", toStatement);
            }
            case "unused_field": {
                ConvertFieldToBlock toBlock = new ConvertFieldToBlock(file, this.findPosition(task, d.range.start));
                return this.createQuickFix("Convert to block", toBlock);
            }
            case "unused_class": {
                RemoveClass removeClass = new RemoveClass(file, this.findPosition(task, d.range.start));
                return this.createQuickFix("Remove class", removeClass);
            }
            case "unused_method": {
                MethodPtr unusedMethod = this.findMethod(task, d.range);
                if (unusedMethod == null) {
                    return List.of();
                }
                RemoveMethod removeMethod = new RemoveMethod(unusedMethod.className, unusedMethod.methodName, unusedMethod.erasedParameterTypes);
                return this.createQuickFix("Remove method", removeMethod);
            }
            case "unused_throws": {
                CharSequence shortExceptionName = this.extractRange(task, d.range);
                String notThrown = this.extractNotThrownExceptionName(d.message);
                MethodPtr methodWithExtraThrow = this.findMethod(task, d.range);
                // Without a method or an exception name there is nothing meaningful to remove.
                if (methodWithExtraThrow == null || notThrown.isEmpty()) {
                    return List.of();
                }
                RemoveException removeThrow = new RemoveException(methodWithExtraThrow.className, methodWithExtraThrow.methodName, methodWithExtraThrow.erasedParameterTypes, notThrown);
                return this.createQuickFix("Remove '" + shortExceptionName + "'", removeThrow);
            }
            case "compiler.warn.unchecked.call.mbr.of.raw.type": {
                MethodPtr warnedMethod = this.findMethod(task, d.range);
                if (warnedMethod == null) {
                    return List.of();
                }
                AddSuppressWarningAnnotation suppressWarning = new AddSuppressWarningAnnotation(warnedMethod.className, warnedMethod.methodName, warnedMethod.erasedParameterTypes);
                return this.createQuickFix("Suppress 'unchecked' warning", suppressWarning);
            }
            case "compiler.err.unreported.exception.need.to.catch.or.throw": {
                // The diagnostic may come from a field initializer or initializer block, which has no
                // enclosing method to add a `throws` clause to.
                MethodPtr needsThrow = this.findMethod(task, d.range);
                String exceptionName = this.extractExceptionName(d.message);
                if (needsThrow == null || exceptionName.isEmpty()) {
                    return List.of();
                }
                AddException addThrows = new AddException(needsThrow.className, needsThrow.methodName, needsThrow.erasedParameterTypes, exceptionName);
                return this.createQuickFix("Add 'throws'", addThrows);
            }
            case "compiler.err.cant.resolve":
            case "compiler.err.cant.resolve.location": {
                // Offer an import for every known public top-level type whose simple name matches.
                String suffix = "." + this.extractRange(task, d.range);
                List<CodeAction> allImports = new ArrayList<>();
                for (String qualifiedName : this.compiler.publicTopLevelTypes()) {
                    if (!qualifiedName.endsWith(suffix)) {
                        continue;
                    }
                    String title = "Import '" + qualifiedName + "'";
                    AddImport addImport = new AddImport(file, qualifiedName);
                    allImports.addAll(this.createQuickFix(title, addImport));
                }
                return allImports;
            }
            case "compiler.err.var.not.initialized.in.default.constructor": {
                String needsConstructor = this.findClassNeedingConstructor(task, d.range);
                if (needsConstructor == null) {
                    return List.of();
                }
                GenerateRecordConstructor generateConstructor = new GenerateRecordConstructor(needsConstructor);
                return this.createQuickFix("Generate constructor", generateConstructor);
            }
            case "compiler.err.does.not.override.abstract": {
                String missingAbstracts = this.findClass(task, d.range);
                if (missingAbstracts == null) {
                    return List.of();
                }
                ImplementAbstractMethods implementAbstracts = new ImplementAbstractMethods(missingAbstracts);
                return this.createQuickFix("Implement abstract methods", implementAbstracts);
            }
            case "compiler.err.cant.resolve.location.args": {
                CreateMissingMethod missingMethod = new CreateMissingMethod(file, this.findPosition(task, d.range.start));
                return this.createQuickFix("Create missing method", missingMethod);
            }
        }
        return List.of();
    }

    /**
     * Converts a zero-based LSP line/character position into a character offset in the compiled file.
     * (Both coordinates are shifted by one because {@link LineMap} counts from 1.)
     */
    private long offsetOf(CompileTask task, Position position) {
        return task.root().getLineMap().getPosition(position.line + 1, position.character + 1);
    }

    /** Same as {@link #offsetOf}, narrowed to {@code int} for rewrites that take int offsets. */
    private int findPosition(CompileTask task, Position position) {
        return (int) this.offsetOf(task, position);
    }

    /**
     * Returns the qualified name of the class declared at the start of {@code range}.
     * Returns null when no class is found or {@link #hasConstructor} reports true.
     */
    private String findClassNeedingConstructor(CompileTask task, Range range) {
        ClassTree type = this.findClassTree(task, range);
        if (type == null || this.hasConstructor(task, type)) {
            return null;
        }
        return this.qualifiedName(task, type);
    }

    /** Returns the qualified name of the class declared at the start of {@code range}, or null if none. */
    private String findClass(CompileTask task, Range range) {
        ClassTree type = this.findClassTree(task, range);
        if (type == null) {
            return null;
        }
        return this.qualifiedName(task, type);
    }

    /** Returns the class declaration found at the start of {@code range}, or null if none. */
    private ClassTree findClassTree(CompileTask task, Range range) {
        long position = this.offsetOf(task, range.start);
        return (ClassTree) new FindTypeDeclarationAt(task.task).scan(task.root(), Long.valueOf(position));
    }

    /** Resolves {@code tree} to its type element and returns its fully qualified name. */
    private String qualifiedName(CompileTask task, ClassTree tree) {
        Trees trees = Trees.instance(task.task);
        TreePath path = trees.getPath(task.root(), tree);
        TypeElement type = (TypeElement) trees.getElement(path);
        return type.getQualifiedName().toString();
    }

    /** Returns true if any member of {@code type} satisfies {@link #isConstructor}. */
    private boolean hasConstructor(CompileTask task, ClassTree type) {
        for (Tree member : type.getMembers()) {
            if (member instanceof MethodTree && this.isConstructor(task, (MethodTree) member)) {
                return true;
            }
        }
        return false;
    }

    /** Returns true for a method named {@code <init>} for which {@link #synthentic} is false. */
    private boolean isConstructor(CompileTask task, MethodTree method) {
        return method.getName().contentEquals("<init>") && !this.synthentic(task, method);
    }

    /**
     * Returns true when javac reports a start position (anything other than -1) for {@code method}.
     *
     * <p>NOTE(review): despite the name, this is true for methods that DO have a source position.
     * Inverting it looks tempting, but if javac assigns a position to its generated default
     * constructor, the inversion would stop the "Generate constructor" fix from being offered.
     * Verify against javac before changing.
     */
    private boolean synthentic(CompileTask task, MethodTree method) {
        return Trees.instance(task.task).getSourcePositions().getStartPosition(task.root(), method) != -1L;
    }

    /**
     * Identifies the method declaration enclosing the start of {@code range}.
     *
     * @return the method's pointer, or null when the position is not inside a method declaration
     *     (for example a field initializer or an initializer block)
     */
    private MethodPtr findMethod(CompileTask task, Range range) {
        long position = this.offsetOf(task, range.start);
        MethodTree tree = (MethodTree) new FindMethodDeclarationAt(task.task).scan(task.root(), Long.valueOf(position));
        if (tree == null) {
            // Trees.getPath rejects a null target, so bail out before asking for a path.
            return null;
        }
        Trees trees = Trees.instance(task.task);
        TreePath path = trees.getPath(task.root(), tree);
        ExecutableElement method = (ExecutableElement) trees.getElement(path);
        return new MethodPtr(task.task, method);
    }

    /** Extracts the exception name from an "is not thrown" message; logs and returns "" on no match. */
    private String extractNotThrownExceptionName(String message) {
        Matcher matcher = NOT_THROWN_EXCEPTION.matcher(message);
        if (!matcher.find()) {
            LOG.warning(String.format("`%s` doesn't match `%s`", message, NOT_THROWN_EXCEPTION));
            return "";
        }
        return matcher.group(1);
    }

    /** Extracts the exception name from an "unreported exception" message; logs and returns "" on no match. */
    private String extractExceptionName(String message) {
        Matcher matcher = UNREPORTED_EXCEPTION.matcher(message);
        if (!matcher.find()) {
            LOG.warning(String.format("`%s` doesn't match `%s`", message, UNREPORTED_EXCEPTION));
            return "";
        }
        return matcher.group(1);
    }

    /** Returns the source text covered by {@code range}. */
    private CharSequence extractRange(CompileTask task, Range range) {
        CharSequence contents = this.sourceText(task.root());
        int start = (int) this.offsetOf(task, range.start);
        int end = (int) this.offsetOf(task, range.end);
        return contents.subSequence(start, end);
    }

    /** Reads the full text of {@code root}'s source file, rethrowing any I/O failure unchecked. */
    private CharSequence sourceText(CompilationUnitTree root) {
        try {
            return root.getSourceFile().getCharContent(true);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs {@code rewrite} and wraps its edits in a single "quickfix" code action.
     *
     * @return a one-element list, or an empty list when the rewrite returns {@link Rewrite#CANCELLED}
     */
    private List<CodeAction> createQuickFix(String title, Rewrite rewrite) {
        Map<Path, TextEdit[]> edits = rewrite.rewrite(this.compiler);
        if (edits == Rewrite.CANCELLED) {
            return List.of();
        }
        CodeAction a = new CodeAction();
        a.kind = "quickfix";
        a.title = title;
        a.edit = new WorkspaceEdit();
        for (Map.Entry<Path, TextEdit[]> entry : edits.entrySet()) {
            a.edit.changes.put(entry.getKey().toUri(), List.of(entry.getValue()));
        }
        return List.of(a);
    }

    /**
     * Identifies a method by three parts, in the string form the rewrites take:
     * <ul>
     *   <li>the qualified name of its declaring type;</li>
     *   <li>its simple name;</li>
     *   <li>the {@code toString()} of each erased parameter type.</li>
     * </ul>
     */
    static class MethodPtr {
        final String className;
        final String methodName;
        final String[] erasedParameterTypes;

        MethodPtr(JavacTask task, ExecutableElement method) {
            Types types = task.getTypes();
            TypeElement parent = (TypeElement) method.getEnclosingElement();
            this.className = parent.getQualifiedName().toString();
            this.methodName = method.getSimpleName().toString();
            this.erasedParameterTypes = new String[method.getParameters().size()];
            for (int i = 0; i < this.erasedParameterTypes.length; ++i) {
                VariableElement param = method.getParameters().get(i);
                TypeMirror erased = types.erasure(param.asType());
                this.erasedParameterTypes[i] = erased.toString();
            }
        }
    }
}

