/*
 * Decompiled with CFR 0.152.
 */
package org.sonar.python.checks;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import org.sonar.check.Rule;
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionCheck;
import org.sonar.plugins.python.api.symbols.Symbol;
import org.sonar.plugins.python.api.symbols.Usage;
import org.sonar.plugins.python.api.tree.CallExpression;
import org.sonar.plugins.python.api.tree.Expression;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.QualifiedExpression;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.plugins.python.api.types.InferredType;
import org.sonar.python.tree.TreeUtils;

@Rule(key="S6982")
public class TorchModuleModeShouldBeSetAfterLoadingCheck
extends PythonSubscriptionCheck {
    private static final Set<String> STATE_SETTING_FUNCTION_FQNS = Set.of("eval", "train");
    private static final String LOAD_STATE_DICT_NAME = "load_state_dict";
    private static final String MESSAGE = "Set the module in training or evaluation mode.";

    public void initialize(SubscriptionCheck.Context context) {
        context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, ctx -> {
            CallExpression callExpr = (CallExpression)ctx.syntaxNode();
            List<Usage> receiverUsages = TorchModuleModeShouldBeSetAfterLoadingCheck.getForwardUsagesOfReceiver(callExpr);
            if (TorchModuleModeShouldBeSetAfterLoadingCheck.isLoadStateDictCall(callExpr) && !TorchModuleModeShouldBeSetAfterLoadingCheck.hasEvalOrTrainUsage(receiverUsages) && !TorchModuleModeShouldBeSetAfterLoadingCheck.isModelPassedOn(receiverUsages)) {
                ctx.addIssue((Tree)callExpr.callee(), MESSAGE);
            }
        });
    }

    private static boolean isLoadStateDictCall(CallExpression callExpr) {
        Expression expression = callExpr.callee();
        if (expression instanceof QualifiedExpression) {
            QualifiedExpression qualifiedExpr = (QualifiedExpression)expression;
            InferredType qualifierType = qualifiedExpr.qualifier().type();
            boolean isModule = qualifierType.mustBeOrExtend("torch.nn.modules.module.Module") || qualifierType.mustBeOrExtend("torch.nn.Module");
            return isModule && LOAD_STATE_DICT_NAME.equals(qualifiedExpr.name().name());
        }
        return false;
    }

    private static List<Usage> getForwardUsagesOfReceiver(CallExpression callExpr) {
        List usages = TorchModuleModeShouldBeSetAfterLoadingCheck.getFunctionCallReceiverName(callExpr).flatMap(name -> Optional.ofNullable(name.symbol())).map(Symbol::usages).orElse(Collections.emptyList());
        return usages.stream().filter(usage -> usage.tree().firstToken().line() > callExpr.firstToken().line()).toList();
    }

    private static Optional<Name> getFunctionCallReceiverName(CallExpression callExpr) {
        return Optional.ofNullable(callExpr.callee()).flatMap(TreeUtils.toOptionalInstanceOfMapper(QualifiedExpression.class)).flatMap(qualifiedExpr -> Optional.ofNullable(qualifiedExpr.qualifier())).flatMap(TreeUtils.toOptionalInstanceOfMapper(Name.class));
    }

    private static boolean hasEvalOrTrainUsage(List<Usage> usages) {
        return usages.stream().anyMatch(TorchModuleModeShouldBeSetAfterLoadingCheck::isEvalOrTrain);
    }

    private static boolean isEvalOrTrain(Usage usage) {
        Tree callTree = TreeUtils.firstAncestorOfKind((Tree)usage.tree(), (Tree.Kind[])new Tree.Kind[]{Tree.Kind.CALL_EXPR});
        if (callTree != null) {
            CallExpression usageCall = (CallExpression)callTree;
            Symbol usageCallSymbol = usageCall.calleeSymbol();
            return usageCallSymbol != null && STATE_SETTING_FUNCTION_FQNS.contains(usageCallSymbol.name());
        }
        return false;
    }

    private static boolean isModelPassedOn(List<Usage> usages) {
        return usages.stream().anyMatch(TorchModuleModeShouldBeSetAfterLoadingCheck::isPassingModel);
    }

    private static boolean isPassingModel(Usage usage) {
        return TreeUtils.firstAncestorOfKind((Tree)usage.tree(), (Tree.Kind[])new Tree.Kind[]{Tree.Kind.CALL_EXPR}) != null;
    }
}

