/*
 * Decompiled with CFR 0.152.
 */
package org.sonar.python.checks;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import org.sonar.check.Rule;
import org.sonar.plugins.python.api.PythonCheck;
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
import org.sonar.plugins.python.api.PythonVersionUtils;
import org.sonar.plugins.python.api.SubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionContext;
import org.sonar.plugins.python.api.TriBool;
import org.sonar.plugins.python.api.symbols.Symbol;
import org.sonar.plugins.python.api.tree.AssignmentStatement;
import org.sonar.plugins.python.api.tree.BinaryExpression;
import org.sonar.plugins.python.api.tree.CallExpression;
import org.sonar.plugins.python.api.tree.CompoundAssignmentStatement;
import org.sonar.plugins.python.api.tree.Expression;
import org.sonar.plugins.python.api.tree.ExpressionList;
import org.sonar.plugins.python.api.tree.HasSymbol;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.QualifiedExpression;
import org.sonar.plugins.python.api.tree.RegularArgument;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.plugins.python.api.types.v2.PythonType;
import org.sonar.python.checks.cdk.WeakSSLProtocolCheckPart;
import org.sonar.python.checks.utils.Expressions;
import org.sonar.python.semantic.v2.SymbolV2;
import org.sonar.python.semantic.v2.UsageV2;
import org.sonar.python.tree.TreeUtils;
import org.sonar.python.types.v2.TypeCheckBuilder;

@Rule(key="S4423")
public class WeakSSLProtocolCheck
extends PythonSubscriptionCheck {
    private static final List<String> WEAK_PROTOCOL_CONSTANTS = Arrays.asList("ssl.PROTOCOL_SSLv2", "ssl.PROTOCOL_SSLv3", "ssl.PROTOCOL_SSLv23", "ssl.PROTOCOL_TLSv1", "ssl.PROTOCOL_TLSv1_1", "OpenSSL.SSL.SSLv2_METHOD", "OpenSSL.SSL.SSLv3_METHOD", "OpenSSL.SSL.SSLv23_METHOD", "OpenSSL.SSL.TLSv1_METHOD", "OpenSSL.SSL.TLSv1_1_METHOD");
    private static final Set<String> SSL_CONTEXT_DEPENDENT_PROTOCOLS = Set.of("ssl.PROTOCOL_TLS_CLIENT", "ssl.PROTOCOL_TLS_SERVER", "ssl.PROTOCOL_TLS");
    private static final Set<String> OPENSSL_DEFAULT_TLS_METHODS = Set.of("OpenSSL.SSL.TLS_METHOD", "OpenSSL.SSL.TLS_SERVER_METHOD", "OpenSSL.SSL.TLS_CLIENT_METHOD");
    private static final Set<String> SAFE_VERSION_NAMES = Set.of("TLSv1_2", "TLSv1_3", "MAXIMUM_SUPPORTED");
    private static final Set<String> REQUIRED_SECURITY_FLAGS = Set.of("OP_NO_SSLv2", "OP_NO_SSLv3", "OP_NO_TLSv1", "OP_NO_TLSv1_1");
    private static final Set<String> DEFAULT_PURPOSES = Set.of("ssl.Purpose.CLIENT_AUTH", "ssl.Purpose.SERVER_AUTH");
    private static final String WEAK_PROTOCOL_MESSAGE = "Change this code to use a stronger protocol.";
    private static final String WEAK_PROTOCOL_MESSAGE_PYTHON_310 = "Use a stronger protocol, or upgrade to Python 3.10+ which uses secure defaults.";
    private TypeCheckBuilder createDefaultContextTypeCheckBuilder;
    private TypeCheckBuilder sslSSLContextTypeCheckBuilder;
    private TypeCheckBuilder openSSLContextTypeCheckBuilder;

    public void initialize(SubscriptionCheck.Context context) {
        context.registerSyntaxNodeConsumer(Tree.Kind.FILE_INPUT, ctx -> {
            this.createDefaultContextTypeCheckBuilder = ctx.typeChecker().typeCheckBuilder().isTypeWithName("ssl.create_default_context");
            this.sslSSLContextTypeCheckBuilder = ctx.typeChecker().typeCheckBuilder().isTypeWithName("ssl.SSLContext");
            this.openSSLContextTypeCheckBuilder = ctx.typeChecker().typeCheckBuilder().isTypeWithName("OpenSSL.SSL.Context");
        });
        context.registerSyntaxNodeConsumer(Tree.Kind.NAME, WeakSSLProtocolCheck::checkName);
        context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, this::checkCallExpression);
        new WeakSSLProtocolCheckPart().initialize(context);
    }

    private static void checkName(SubscriptionContext ctx) {
        Name name = (Name)ctx.syntaxNode();
        Optional.of(name).map(HasSymbol::symbol).map(Symbol::fullyQualifiedName).filter(WEAK_PROTOCOL_CONSTANTS::contains).ifPresent(fqn -> ctx.addIssue((Tree)name, WEAK_PROTOCOL_MESSAGE));
    }

    private void checkCallExpression(SubscriptionContext ctx) {
        CallExpression callExpression = (CallExpression)ctx.syntaxNode();
        Expression callee = callExpression.callee();
        PythonType pythonType = callee.typeV2();
        if (this.isSslContextWithDefaultProtocols(pythonType, callExpression)) {
            WeakSSLProtocolCheck.checkSSLContext(ctx, (Tree)callExpression);
        } else if (this.isOpenSSLContextWithDefaultTLSMethods(pythonType, callExpression)) {
            WeakSSLProtocolCheck.checkOpenSSLContext(ctx, (Tree)callExpression);
        }
    }

    private boolean isSslContextWithDefaultProtocols(PythonType pythonType, CallExpression callExpression) {
        return this.createDefaultContextTypeCheckBuilder.check(pythonType) == TriBool.TRUE && WeakSSLProtocolCheck.hasDefaultFirstArgument(callExpression, "purpose", DEFAULT_PURPOSES) || this.sslSSLContextTypeCheckBuilder.check(pythonType) == TriBool.TRUE && WeakSSLProtocolCheck.hasDefaultFirstArgument(callExpression, "protocol", SSL_CONTEXT_DEPENDENT_PROTOCOLS);
    }

    private boolean isOpenSSLContextWithDefaultTLSMethods(PythonType pythonType, CallExpression callExpression) {
        return this.openSSLContextTypeCheckBuilder.check(pythonType) == TriBool.TRUE && WeakSSLProtocolCheck.hasDefaultFirstArgument(callExpression, "method", OPENSSL_DEFAULT_TLS_METHODS);
    }

    private static boolean hasDefaultFirstArgument(CallExpression callExpr, String keyword, Set<String> allowedValues) {
        RegularArgument arg = TreeUtils.nthArgumentOrKeyword((int)0, (String)keyword, (List)callExpr.arguments());
        if (arg == null) {
            return true;
        }
        return Optional.of(arg.expression()).filter(HasSymbol.class::isInstance).map(HasSymbol.class::cast).map(HasSymbol::symbol).map(Symbol::fullyQualifiedName).filter(allowedValues::contains).isPresent();
    }

    private static void checkSSLContext(SubscriptionContext ctx, Tree tree) {
        WeakSSLProtocolCheck.getContextSymbol(tree).ifPresentOrElse(contextSymbol -> WeakSSLProtocolCheck.checkSSLContextSymbol(ctx, contextSymbol, tree), () -> {
            if (!WeakSSLProtocolCheck.isSafePythonVersion(ctx)) {
                ctx.addIssue(tree, WEAK_PROTOCOL_MESSAGE_PYTHON_310);
            }
        });
    }

    private static void checkOpenSSLContext(SubscriptionContext ctx, Tree tree) {
        WeakSSLProtocolCheck.getContextSymbol(tree).ifPresentOrElse(contextSymbol -> WeakSSLProtocolCheck.checkOpenSSLContextSymbol(ctx, contextSymbol, tree), () -> ctx.addIssue(tree, WEAK_PROTOCOL_MESSAGE));
    }

    private static void checkSSLContextSymbol(SubscriptionContext ctx, SymbolV2 contextSymbol, Tree locationForIssue) {
        boolean isUnsafeContext = WeakSSLProtocolCheck.isUnsafeDefaultContext(ctx, contextSymbol);
        Optional<AssignmentStatement> unsafeMaximumVersionStatement = WeakSSLProtocolCheck.findUnsafeMaximumVersionStatement(contextSymbol);
        if (isUnsafeContext && unsafeMaximumVersionStatement.isPresent()) {
            PythonCheck.PreciseIssue issue = ctx.addIssue(locationForIssue, WEAK_PROTOCOL_MESSAGE);
            issue.secondary((Tree)unsafeMaximumVersionStatement.get(), "Unsafe maximum version specified here");
        } else if (unsafeMaximumVersionStatement.isPresent()) {
            ctx.addIssue((Tree)unsafeMaximumVersionStatement.get(), WEAK_PROTOCOL_MESSAGE);
        } else if (isUnsafeContext) {
            ctx.addIssue(locationForIssue, WEAK_PROTOCOL_MESSAGE_PYTHON_310);
        }
    }

    private static void checkOpenSSLContextSymbol(SubscriptionContext ctx, SymbolV2 contextSymbol, Tree locationForIssue) {
        if (!WeakSSLProtocolCheck.isSecurelyConfiguredOpenSSLContext(contextSymbol)) {
            ctx.addIssue(locationForIssue, WEAK_PROTOCOL_MESSAGE);
        }
    }

    private static boolean isUnsafeDefaultContext(SubscriptionContext ctx, SymbolV2 contextSymbol) {
        return !WeakSSLProtocolCheck.isSafePythonVersion(ctx) && !WeakSSLProtocolCheck.isSecurelyConfigured(contextSymbol);
    }

    private static boolean isSafePythonVersion(SubscriptionContext ctx) {
        return PythonVersionUtils.areSourcePythonVersionsGreaterOrEqualThan((Set)ctx.sourcePythonVersions(), (PythonVersionUtils.Version)PythonVersionUtils.Version.V_310);
    }

    private static Optional<SymbolV2> getContextSymbol(Tree tree) {
        return Optional.ofNullable(TreeUtils.firstAncestorOfKind((Tree)tree, (Tree.Kind[])new Tree.Kind[]{Tree.Kind.ASSIGNMENT_STMT})).map(AssignmentStatement.class::cast).map(as -> (Expression)((ExpressionList)as.lhsExpressions().get(0)).expressions().get(0)).filter(Name.class::isInstance).map(Name.class::cast).map(Name::symbolV2);
    }

    private static boolean isSecurelyConfigured(SymbolV2 symbolV2) {
        Set<String> securityFlags = WeakSSLProtocolCheck.collectSecurityFlags(symbolV2, "options");
        return symbolV2.usages().stream().anyMatch(u -> WeakSSLProtocolCheck.isSettingSafeMinimumVersion(u.tree())) || securityFlags.containsAll(REQUIRED_SECURITY_FLAGS);
    }

    private static boolean isSecurelyConfiguredOpenSSLContext(SymbolV2 symbolV2) {
        return WeakSSLProtocolCheck.isSecureThroughMinProtoVersion(symbolV2) || WeakSSLProtocolCheck.isSecureThroughSetOptions(symbolV2);
    }

    private static boolean isSecureThroughMinProtoVersion(SymbolV2 symbolV2) {
        return symbolV2.usages().stream().anyMatch(u -> {
            CallExpression callExpression = (CallExpression)TreeUtils.firstAncestorOfKind((Tree)u.tree(), (Tree.Kind[])new Tree.Kind[]{Tree.Kind.CALL_EXPR});
            if (callExpression == null) {
                return false;
            }
            Symbol symbol = callExpression.calleeSymbol();
            if (symbol == null || !"set_min_proto_version".equals(symbol.name())) {
                return false;
            }
            return callExpression.arguments().stream().filter(RegularArgument.class::isInstance).map(RegularArgument.class::cast).map(RegularArgument::expression).filter(HasSymbol.class::isInstance).map(HasSymbol.class::cast).map(HasSymbol::symbol).filter(Objects::nonNull).map(Symbol::fullyQualifiedName).filter(Objects::nonNull).anyMatch(fqn -> fqn.contains("TLS1_2_VERSION") || fqn.contains("TLS1_3_VERSION"));
        });
    }

    private static boolean isSecureThroughSetOptions(SymbolV2 symbolV2) {
        Set<String> securityFlags = WeakSSLProtocolCheck.collectOpenSSLSecurityFlags(symbolV2);
        return securityFlags.containsAll(REQUIRED_SECURITY_FLAGS);
    }

    private static Set<String> collectOpenSSLSecurityFlags(SymbolV2 symbolV2) {
        HashSet<String> securityFlags = new HashSet<String>();
        symbolV2.usages().stream().map(UsageV2::tree).map(Tree::parent).filter(QualifiedExpression.class::isInstance).map(QualifiedExpression.class::cast).filter(qe -> "set_options".equals(qe.name().name())).map(Tree::parent).filter(CallExpression.class::isInstance).map(CallExpression.class::cast).forEach(call -> {
            if (!call.arguments().isEmpty()) {
                TreeUtils.nthArgumentOrKeywordOptional((int)0, (String)"options", (List)call.arguments()).map(RegularArgument::expression).ifPresent(expression -> WeakSSLProtocolCheck.collectSecurityFlagsFromExpression(expression, securityFlags));
            }
        });
        return securityFlags;
    }

    private static Set<String> collectSecurityFlags(SymbolV2 symbolV2, String propertyName) {
        HashSet<String> securityFlags = new HashSet<String>();
        symbolV2.usages().stream().map(UsageV2::tree).map(Tree::parent).filter(QualifiedExpression.class::isInstance).map(QualifiedExpression.class::cast).filter(qe -> propertyName.equals(qe.name().name())).map(qe -> TreeUtils.firstAncestorOfKind((Tree)qe, (Tree.Kind[])new Tree.Kind[]{Tree.Kind.COMPOUND_ASSIGNMENT})).filter(CompoundAssignmentStatement.class::isInstance).map(CompoundAssignmentStatement.class::cast).map(CompoundAssignmentStatement::rhsExpression).forEach(rhs -> WeakSSLProtocolCheck.collectSecurityFlagsFromExpression(rhs, securityFlags));
        return securityFlags;
    }

    private static void collectSecurityFlagsFromExpression(Expression expression, Set<String> securityFlags) {
        if ((expression = Expressions.removeParentheses(expression)) instanceof HasSymbol) {
            HasSymbol hasSymbol = (HasSymbol)expression;
            Optional.ofNullable(hasSymbol.symbol()).map(Symbol::fullyQualifiedName).ifPresent(fqn -> REQUIRED_SECURITY_FLAGS.stream().filter(fqn::contains).forEach(securityFlags::add));
        } else if (expression instanceof BinaryExpression) {
            BinaryExpression binaryExpression = (BinaryExpression)expression;
            WeakSSLProtocolCheck.collectSecurityFlagsFromExpression(binaryExpression.leftOperand(), securityFlags);
            WeakSSLProtocolCheck.collectSecurityFlagsFromExpression(binaryExpression.rightOperand(), securityFlags);
        }
    }

    private static boolean isSettingSafeMinimumVersion(Tree tree) {
        return WeakSSLProtocolCheck.findVersionStatement(tree, "minimum_version", WeakSSLProtocolCheck::containsSafeVersion).isPresent();
    }

    private static Optional<AssignmentStatement> findUnsafeMaximumVersionStatement(SymbolV2 symbolV2) {
        return symbolV2.usages().stream().map(u -> WeakSSLProtocolCheck.findVersionStatement(u.tree(), "maximum_version", fqn -> !WeakSSLProtocolCheck.containsSafeVersion(fqn))).filter(Optional::isPresent).map(Optional::get).findFirst();
    }

    private static boolean containsSafeVersion(String fullyQualifiedName) {
        return SAFE_VERSION_NAMES.stream().anyMatch(fullyQualifiedName::contains);
    }

    private static Optional<AssignmentStatement> findVersionStatement(Tree tree, String versionProperty, Predicate<String> versionPredicate) {
        return Optional.ofNullable(TreeUtils.firstAncestorOfKind((Tree)tree, (Tree.Kind[])new Tree.Kind[]{Tree.Kind.ASSIGNMENT_STMT})).map(AssignmentStatement.class::cast).filter(a -> WeakSSLProtocolCheck.isSettingVersionProperty(a, versionProperty)).filter(a -> Optional.of(a.assignedValue()).filter(HasSymbol.class::isInstance).map(HasSymbol.class::cast).map(HasSymbol::symbol).map(Symbol::fullyQualifiedName).filter(versionPredicate).isPresent());
    }

    private static boolean isSettingVersionProperty(AssignmentStatement assignment, String versionProperty) {
        return assignment.lhsExpressions().stream().flatMap(lhsExpr -> lhsExpr.expressions().stream()).filter(expr -> expr.is(new Tree.Kind[]{Tree.Kind.QUALIFIED_EXPR})).map(QualifiedExpression.class::cast).anyMatch(qexpr -> versionProperty.equals(qexpr.name().name()));
    }
}

