/*
 * Decompiled with CFR 0.152.
 */
package org.sonar.python.checks;

import java.util.Optional;
import javax.annotation.Nullable;
import org.sonar.check.Rule;
import org.sonar.plugins.python.api.PythonCheck;
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionCheck;
import org.sonar.plugins.python.api.quickfix.PythonQuickFix;
import org.sonar.plugins.python.api.quickfix.PythonTextEdit;
import org.sonar.plugins.python.api.symbols.FunctionSymbol;
import org.sonar.plugins.python.api.symbols.Symbol;
import org.sonar.plugins.python.api.tree.ArgList;
import org.sonar.plugins.python.api.tree.CallExpression;
import org.sonar.plugins.python.api.tree.ClassDef;
import org.sonar.plugins.python.api.tree.Expression;
import org.sonar.plugins.python.api.tree.FunctionDef;
import org.sonar.plugins.python.api.tree.QualifiedExpression;
import org.sonar.plugins.python.api.tree.RegularArgument;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.python.checks.utils.CheckUtils;
import org.sonar.python.quickfix.TextEditUtils;
import org.sonar.python.tree.TreeUtils;

@Rule(key="S6978")
public class TorchModuleShouldCallInitCheck
extends PythonSubscriptionCheck {
    private static final String TORCH_NN_MODULE = "torch.nn.modules.module.Module";
    private static final String MESSAGE = "Add a call to super().__init__().";
    private static final String SECONDARY_MESSAGE = "Inheritance happens here";
    public static final String QUICK_FIX_MESSAGE = "insert call to super constructor";

    public void initialize(SubscriptionCheck.Context context) {
        context.registerSyntaxNodeConsumer(Tree.Kind.FUNCDEF, ctx -> {
            FunctionDef funcDef = (FunctionDef)ctx.syntaxNode();
            ClassDef classDef = CheckUtils.getParentClassDef((Tree)funcDef);
            if (TorchModuleShouldCallInitCheck.isInheritingFromTorchModule(classDef) && TorchModuleShouldCallInitCheck.isConstructor(funcDef) && TorchModuleShouldCallInitCheck.isMissingSuperCall(funcDef)) {
                PythonCheck.PreciseIssue issue = ctx.addIssue((Tree)funcDef.name(), MESSAGE);
                issue.secondary((Tree)classDef.name(), SECONDARY_MESSAGE);
                TorchModuleShouldCallInitCheck.createQuickFix(funcDef).ifPresent(arg_0 -> ((PythonCheck.PreciseIssue)issue).addQuickFix(arg_0));
            }
        });
    }

    private static boolean isConstructor(FunctionDef funcDef) {
        FunctionSymbol symbol = TreeUtils.getFunctionSymbolFromDef((FunctionDef)funcDef);
        return symbol != null && "__init__".equals(symbol.name()) && funcDef.isMethodDefinition();
    }

    private static boolean isInheritingFromTorchModule(@Nullable ClassDef classDef) {
        if (classDef == null) {
            return false;
        }
        ArgList args = classDef.args();
        return args != null && args.arguments().stream().flatMap(TreeUtils.toStreamInstanceOfMapper(RegularArgument.class)).map(arg -> TorchModuleShouldCallInitCheck.getQualifiedName(arg.expression())).anyMatch(expr -> expr.filter(TORCH_NN_MODULE::equals).isPresent());
    }

    private static Optional<String> getQualifiedName(Expression node) {
        return TreeUtils.getSymbolFromTree((Tree)node).flatMap(symbol -> Optional.ofNullable(symbol.fullyQualifiedName()));
    }

    private static boolean isMissingSuperCall(FunctionDef funcDef) {
        ClassDef parentClassDef = CheckUtils.getParentClassDef((Tree)funcDef);
        return parentClassDef != null && !TreeUtils.hasDescendant((Tree)parentClassDef, t -> t.is(new Tree.Kind[]{Tree.Kind.CALL_EXPR}) && TorchModuleShouldCallInitCheck.isSuperConstructorCall((CallExpression)t));
    }

    private static boolean isSuperConstructorCall(CallExpression callExpr) {
        QualifiedExpression qualifiedCallee;
        Expression expression = callExpr.callee();
        return expression instanceof QualifiedExpression && TorchModuleShouldCallInitCheck.isSuperCall((qualifiedCallee = (QualifiedExpression)expression).qualifier()) && "__init__".equals(qualifiedCallee.name().name());
    }

    private static boolean isSuperCall(Expression qualifier) {
        if (qualifier instanceof CallExpression) {
            CallExpression callExpression = (CallExpression)qualifier;
            Symbol superSymbol = callExpression.calleeSymbol();
            return superSymbol != null && "super".equals(superSymbol.name());
        }
        return false;
    }

    private static Optional<PythonQuickFix> createQuickFix(FunctionDef functionDef) {
        if (functionDef.colon().line() == functionDef.body().firstToken().line()) {
            return Optional.empty();
        }
        PythonTextEdit pythonTextEdit = TextEditUtils.insertLineAfter((Tree)functionDef.colon(), (Tree)functionDef.body(), (String)"super().__init__()");
        return Optional.of(PythonQuickFix.newQuickFix((String)QUICK_FIX_MESSAGE, (PythonTextEdit[])new PythonTextEdit[]{pythonTextEdit}));
    }
}

