/*
 * Decompiled with CFR 0.152.
 */
package org.sonar.python.checks;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import org.sonar.check.Rule;
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionContext;
import org.sonar.plugins.python.api.symbols.ClassSymbol;
import org.sonar.plugins.python.api.symbols.Symbol;
import org.sonar.plugins.python.api.symbols.Usage;
import org.sonar.plugins.python.api.tree.CallExpression;
import org.sonar.plugins.python.api.tree.Expression;
import org.sonar.plugins.python.api.tree.HasSymbol;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.QualifiedExpression;
import org.sonar.plugins.python.api.tree.RegularArgument;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.python.checks.utils.Expressions;
import org.sonar.python.tree.TreeUtils;

@Rule(key="S6973")
public class SklearnEstimatorHyperparametersCheck
extends PythonSubscriptionCheck {
    private static final String MESSAGE = "Specify important hyperparameters when instantiating a Scikit-learn estimator.";
    private static final String LEARNING_RATE = "learning_rate";
    private static final String N_NEIGHBORS = "n_neighbors";
    private static final String KERNEL = "kernel";
    private static final String GAMMA = "gamma";
    private static final String C = "C";
    private static final Map<String, List<Param>> ESTIMATORS_AND_PARAMETERS_TO_CHECK = Map.ofEntries(Map.entry("sklearn.ensemble._weight_boosting.AdaBoostClassifier", List.of(new Param("learning_rate"))), Map.entry("sklearn.ensemble._weight_boosting.AdaBoostRegressor", List.of(new Param("learning_rate"))), Map.entry("sklearn.ensemble._gb.GradientBoostingClassifier", List.of(new Param("learning_rate"))), Map.entry("sklearn.ensemble._gb.GradientBoostingRegressor", List.of(new Param("learning_rate"))), Map.entry("sklearn.ensemble._hist_gradient_boosting.gradient_boosting.HistGradientBoostingClassifier", List.of(new Param("learning_rate"))), Map.entry("sklearn.ensemble._hist_gradient_boosting.gradient_boosting.HistGradientBoostingRegressor", List.of(new Param("learning_rate"))), Map.entry("sklearn.ensemble._forest.RandomForestClassifier", List.of(new Param("min_samples_leaf"), new Param("max_features"))), Map.entry("sklearn.ensemble._forest.RandomForestRegressor", List.of(new Param("min_samples_leaf"), new Param("max_features"))), Map.entry("sklearn.linear_model._coordinate_descent.ElasticNet", List.of(new Param("alpha", 0), new Param("l1_ratio"))), Map.entry("sklearn.neighbors._unsupervised.NearestNeighbors", List.of(new Param("n_neighbors", 0))), Map.entry("sklearn.neighbors._classification.KNeighborsClassifier", List.of(new Param("n_neighbors", 0))), Map.entry("sklearn.neighbors._regression.KNeighborsRegressor", List.of(new Param("n_neighbors", 0))), Map.entry("sklearn.svm._classes.NuSVC", List.of(new Param("nu"), new Param("kernel"), new Param("gamma"))), Map.entry("sklearn.svm._classes.NuSVR", List.of(new Param("C"), new Param("kernel"), new Param("gamma"))), Map.entry("sklearn.svm._classes.SVC", List.of(new Param("C"), new Param("kernel"), new Param("gamma"))), Map.entry("sklearn.svm._classes.SVR", List.of(new Param("C"), new Param("kernel"), new Param("gamma"))), Map.entry("sklearn.tree._classes.DecisionTreeClassifier", List.of(new Param("ccp_alpha"))), Map.entry("sklearn.tree._classes.DecisionTreeRegressor", List.of(new Param("ccp_alpha"))), Map.entry("sklearn.neural_network._multilayer_perceptron.MLPClassifier", List.of(new Param("hidden_layer_sizes", 0))), Map.entry("sklearn.neural_network._multilayer_perceptron.MLPRegressor", List.of(new Param("hidden_layer_sizes", 0))), Map.entry("sklearn.preprocessing._polynomial.PolynomialFeatures", List.of(new Param("degree", 0), new Param("interaction_only"))));
    private static final Set<String> SEARCH_CV_FQNS = Set.of("sklearn.model_selection._search.GridSearchCV", "sklearn.model_selection._search.RandomizedSearchCV", "sklearn.model_selection._search_successive_halving.HalvingRandomSearchCV", "sklearn.model_selection._search_successive_halving.HalvingGridSearchCV");
    private static final Set<String> PIPELINE_FQNS = Set.of("sklearn.pipeline.make_pipeline", "sklearn.pipeline.Pipeline");

    public void initialize(SubscriptionCheck.Context context) {
        context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, SklearnEstimatorHyperparametersCheck::checkEstimator);
    }

    private static void checkEstimator(SubscriptionContext ctx) {
        CallExpression callExpression = (CallExpression)ctx.syntaxNode();
        Symbol calleeSymbol = callExpression.calleeSymbol();
        Optional.ofNullable(calleeSymbol).filter(callee -> callee.is(new Symbol.Kind[]{Symbol.Kind.CLASS})).map(ClassSymbol.class::cast).map(Symbol::fullyQualifiedName).map(ESTIMATORS_AND_PARAMETERS_TO_CHECK::get).filter(parameters -> !SklearnEstimatorHyperparametersCheck.isDirectlyUsedInSearchCV(callExpression)).filter(parameters -> !SklearnEstimatorHyperparametersCheck.isSetParamsCalled(callExpression)).filter(parameters -> !SklearnEstimatorHyperparametersCheck.isPartOfPipelineAndSearchCV(callExpression)).filter(parameters -> SklearnEstimatorHyperparametersCheck.isMissingAHyperparameter(callExpression, parameters)).ifPresent(parameters -> ctx.addIssue((Tree)callExpression, MESSAGE));
    }

    private static boolean isMissingAHyperparameter(CallExpression callExpression, List<Param> parametersToCheck) {
        return parametersToCheck.stream().map(param -> param.position().map(position -> TreeUtils.nthArgumentOrKeyword((int)position, (String)param.name, (List)callExpression.arguments())).orElse(TreeUtils.argumentByKeyword((String)param.name, (List)callExpression.arguments()))).anyMatch(Objects::isNull);
    }

    private static boolean isDirectlyUsedInSearchCV(CallExpression callExpression) {
        return Optional.ofNullable(TreeUtils.firstAncestorOfKind((Tree)callExpression, (Tree.Kind[])new Tree.Kind[]{Tree.Kind.REGULAR_ARGUMENT})).flatMap(TreeUtils.toOptionalInstanceOfMapper(RegularArgument.class)).map(SklearnEstimatorHyperparametersCheck::isArgumentPartOfSearchCV).orElse(false);
    }

    private static boolean isPartOfPipelineAndSearchCV(CallExpression callExpression) {
        return Expressions.getAssignedName((Expression)callExpression).map(SklearnEstimatorHyperparametersCheck::isEstimatorUsedInSearchCV).or(() -> SklearnEstimatorHyperparametersCheck.getPipelineAssignement(callExpression).map(SklearnEstimatorHyperparametersCheck::isEstimatorUsedInSearchCV)).orElse(false);
    }

    private static Optional<Name> getPipelineAssignement(CallExpression callExpression) {
        return Optional.ofNullable(TreeUtils.firstAncestorOfKind((Tree)callExpression, (Tree.Kind[])new Tree.Kind[]{Tree.Kind.CALL_EXPR})).flatMap(TreeUtils.toOptionalInstanceOfMapper(CallExpression.class)).filter(callExp -> Optional.ofNullable(callExp.calleeSymbol()).map(Symbol::fullyQualifiedName).map(PIPELINE_FQNS::contains).orElse(false)).flatMap(Expressions::getAssignedName);
    }

    private static boolean isEstimatorUsedInSearchCV(Name estimator) {
        return Optional.ofNullable(estimator.symbol()).map(Symbol::usages).map(usages -> usages.stream().map(Usage::tree).map(Tree::parent).filter(parent -> parent.is(new Tree.Kind[]{Tree.Kind.REGULAR_ARGUMENT})).map(RegularArgument.class::cast).anyMatch(SklearnEstimatorHyperparametersCheck::isArgumentPartOfSearchCV)).orElse(false);
    }

    private static boolean isArgumentPartOfSearchCV(RegularArgument arg) {
        return Optional.ofNullable(TreeUtils.firstAncestorOfKind((Tree)arg, (Tree.Kind[])new Tree.Kind[]{Tree.Kind.CALL_EXPR})).flatMap(TreeUtils.toOptionalInstanceOfMapper(CallExpression.class)).map(CallExpression::calleeSymbol).map(Symbol::fullyQualifiedName).map(SEARCH_CV_FQNS::contains).orElse(false);
    }

    private static boolean isSetParamsCalled(CallExpression callExpression) {
        return Expressions.getAssignedName((Expression)callExpression).map(HasSymbol::symbol).map(Symbol::usages).map(SklearnEstimatorHyperparametersCheck::isUsedWithSetParams).orElse(false);
    }

    private static boolean isUsedWithSetParams(List<Usage> usages) {
        return usages.stream().map(Usage::tree).map(Tree::parent).filter(parent -> parent.is(new Tree.Kind[]{Tree.Kind.QUALIFIED_EXPR})).map(TreeUtils.toInstanceOfMapper(QualifiedExpression.class)).filter(Objects::nonNull).map(qExp -> qExp.name().name()).anyMatch("set_params"::equals);
    }

    private record Param(String name, Optional<Integer> position) {
        public Param(String name) {
            this(name, Optional.empty());
        }

        public Param(String name, int position) {
            this(name, Optional.of(position));
        }
    }
}

