/*
 * Decompiled with CFR 0.152.
 */
package org.sonar.python.checks;

import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.sonar.check.Rule;
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionContext;
import org.sonar.plugins.python.api.symbols.Symbol;
import org.sonar.plugins.python.api.symbols.Usage;
import org.sonar.plugins.python.api.tree.Argument;
import org.sonar.plugins.python.api.tree.AssignmentStatement;
import org.sonar.plugins.python.api.tree.CallExpression;
import org.sonar.plugins.python.api.tree.DictionaryLiteral;
import org.sonar.plugins.python.api.tree.Expression;
import org.sonar.plugins.python.api.tree.ExpressionList;
import org.sonar.plugins.python.api.tree.KeyValuePair;
import org.sonar.plugins.python.api.tree.ListLiteral;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.QualifiedExpression;
import org.sonar.plugins.python.api.tree.RegularArgument;
import org.sonar.plugins.python.api.tree.StringLiteral;
import org.sonar.plugins.python.api.tree.SubscriptionExpression;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.plugins.python.api.tree.Tuple;
import org.sonar.python.checks.utils.Expressions;
import org.sonar.python.tree.TreeUtils;

/**
 * Rule S5659: reports Python code that uses a JWT without verifying its signature.
 *
 * <p>Every call expression whose callee has a fully qualified name is inspected, and an issue is raised for:
 * <ul>
 *   <li>{@code jwt.decode} / {@code jose.jws.verify} called with a falsy {@code verify=} keyword argument;</li>
 *   <li>{@code python_jwt.process_jwt} / {@code jwt.process_jwt} when the enclosing function (or, at top
 *       level, the whole file) contains no call to a {@code verify_jwt} function;</li>
 *   <li>the "get unverified header/headers/claims" helpers, unless their result is read through at least one
 *       string-literal subscript or {@code .get(...)} call and every such key is an allowed header name;</li>
 *   <li>{@code jwt.decode} / {@code jose.jwt.decode} called with an {@code options=} argument that maps
 *       {@code "verify_signature"} to a falsy value.</li>
 * </ul>
 */
@Rule(key = "S5659")
public class JwtVerificationCheck extends PythonSubscriptionCheck {
    private static final String MESSAGE = "Don't use a JWT token without verifying its signature.";
    /** Functions that are reported when no {@link #VERIFY_JWT_FQNS} call appears in the same scope. */
    private static final Set<String> PROCESS_JWT_FQNS = Set.of("python_jwt.process_jwt", "jwt.process_jwt");
    private static final Set<String> VERIFY_JWT_FQNS = Set.of("python_jwt.verify_jwt", "jwt.verify_jwt");
    /**
     * Header names that may be read from an unverified token without raising an issue. These are the
     * key-identification header parameters of RFC 7515 section 4.1. "x5t#S256" is the spec spelling; the
     * historical "xt#256" entry is kept so that code accepted before is still accepted.
     */
    private static final Set<String> ALLOWED_KEYS_ACCESS =
        Set.of("jku", "jwk", "kid", "x5u", "x5c", "x5t", "x5t#S256", "xt#256");
    /** Functions for which a falsy {@code verify=} keyword argument is reported. */
    private static final Set<String> WHERE_VERIFY_KWARG_SHOULD_BE_TRUE_FQNS = Set.of("jwt.decode", "jose.jws.verify");
    /** Functions whose first argument is reported unless only allowed header keys are read from the result. */
    private static final Set<String> UNVERIFIED_FQNS = Set.of(
        "jwt.get_unverified_header",
        "jose.jwt.get_unverified_header",
        "jose.jwt.get_unverified_headers",
        "jose.jws.get_unverified_header",
        "jose.jws.get_unverified_headers",
        "jose.jwt.get_unverified_claims",
        "jose.jws.get_unverified_claims");
    private static final String VERIFY_SIGNATURE_KEYWORD = "verify_signature";
    /** Functions whose {@code options=} argument is checked for a falsy {@code "verify_signature"} entry. */
    public static final Set<String> VERIFY_SIGNATURE_OPTION_SUPPORTING_FUNCTION_FQNS = Set.of("jose.jwt.decode", "jwt.decode");

    @Override
    public void initialize(SubscriptionCheck.Context context) {
        context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, JwtVerificationCheck::verifyCallExpression);
    }

    /**
     * Entry point for every call expression: dispatches on the callee's fully qualified name.
     *
     * @param ctx subscription context whose syntax node is a {@link CallExpression}
     */
    private static void verifyCallExpression(SubscriptionContext ctx) {
        CallExpression call = (CallExpression) ctx.syntaxNode();
        Symbol calleeSymbol = call.calleeSymbol();
        if (calleeSymbol == null || calleeSymbol.fullyQualifiedName() == null) {
            return;
        }
        String calleeFqn = calleeSymbol.fullyQualifiedName();
        List<Argument> arguments = call.arguments();

        if (WHERE_VERIFY_KWARG_SHOULD_BE_TRUE_FQNS.contains(calleeFqn)) {
            RegularArgument verifyArg = TreeUtils.argumentByKeyword("verify", arguments);
            if (verifyArg != null && Expressions.isFalsy(verifyArg.expression())) {
                ctx.addIssue(verifyArg, MESSAGE);
                // "jwt.decode" is also subject to the options check below; one issue per call is enough.
                return;
            }
        } else if (PROCESS_JWT_FQNS.contains(calleeFqn)) {
            // The nearest enclosing function, or the file itself for top-level code.
            Tree scriptOrFunction = TreeUtils.firstAncestorOfKind(call, Tree.Kind.FILE_INPUT, Tree.Kind.FUNCDEF);
            if (scriptOrFunction != null && !TreeUtils.hasDescendant(scriptOrFunction, JwtVerificationCheck::isCallToVerifyJwt)) {
                ctx.addIssue(call, MESSAGE);
            }
        } else if (UNVERIFIED_FQNS.contains(calleeFqn) && !accessOnlyAllowedHeaderKeys(call)) {
            Optional.ofNullable(TreeUtils.nthArgumentOrKeyword(0, "", arguments))
                .flatMap(TreeUtils.toOptionalInstanceOfMapper(RegularArgument.class))
                .map(RegularArgument::expression)
                .ifPresent(jwtArgument -> ctx.addIssue(jwtArgument, MESSAGE));
        }

        if (VERIFY_SIGNATURE_OPTION_SUPPORTING_FUNCTION_FQNS.contains(calleeFqn)) {
            Optional.ofNullable(TreeUtils.argumentByKeyword("options", arguments))
                .map(RegularArgument::expression)
                .filter(JwtVerificationCheck::isListOrDictWithSensitiveEntry)
                .ifPresent(options -> ctx.addIssue(options, MESSAGE));
        }
    }

    /**
     * Tells whether {@code expression} is an options mapping that sets {@code "verify_signature"} to a falsy
     * value. Supported forms: a dict literal, a list of 2-tuples, a {@code dict(...)} call, or a name whose
     * single assigned value is one of those.
     */
    private static boolean isListOrDictWithSensitiveEntry(@Nullable Expression expression) {
        if (expression == null) {
            return false;
        }
        if (expression.is(Tree.Kind.NAME)) {
            return isListOrDictWithSensitiveEntry(Expressions.singleAssignedNonNameValue((Name) expression).orElse(null));
        }
        if (expression.is(Tree.Kind.DICTIONARY_LITERAL)) {
            return hasFalsyVerifySignatureEntry((DictionaryLiteral) expression);
        }
        if (expression.is(Tree.Kind.LIST_LITERAL)) {
            return hasFalsyVerifySignatureEntry((ListLiteral) expression);
        }
        if (expression.is(Tree.Kind.CALL_EXPR)) {
            CallExpression callExpression = (CallExpression) expression;
            return isCallToDict(callExpression) && hasIllegalDictKWArgument(callExpression);
        }
        return false;
    }

    /** {@code dict(verify_signature=<falsy>)}: the call has a falsy {@code verify_signature=} keyword argument. */
    private static boolean hasIllegalDictKWArgument(CallExpression expression) {
        RegularArgument verifySignature = TreeUtils.argumentByKeyword(VERIFY_SIGNATURE_KEYWORD, expression.arguments());
        return verifySignature != null && Expressions.isFalsy(verifySignature.expression());
    }

    /** Whether the callee of {@code expression} resolves to the builtin {@code dict}. */
    private static boolean isCallToDict(CallExpression expression) {
        Symbol callee = expression.calleeSymbol();
        return callee != null && "dict".equals(callee.fullyQualifiedName());
    }

    /** {@code {"verify_signature": <falsy>, ...}}: some key/value pair disables signature verification. */
    private static boolean hasFalsyVerifySignatureEntry(DictionaryLiteral dictionaryLiteral) {
        return dictionaryLiteral.elements().stream()
            .filter(KeyValuePair.class::isInstance)
            .map(KeyValuePair.class::cast)
            .filter(keyValuePair -> isVerifySignatureKey(keyValuePair.key()))
            .map(KeyValuePair::value)
            .anyMatch(Expressions::isFalsy);
    }

    /** {@code [("verify_signature", <falsy>), ...]}: some 2-tuple disables signature verification. */
    private static boolean hasFalsyVerifySignatureEntry(ListLiteral listLiteral) {
        return listLiteral.elements().expressions().stream()
            .filter(Tuple.class::isInstance)
            .map(Tuple.class::cast)
            .map(Tuple::elements)
            .filter(pair -> pair.size() == 2)
            .filter(pair -> isVerifySignatureKey(pair.get(0)))
            .map(pair -> pair.get(1))
            .anyMatch(Expressions::isFalsy);
    }

    /** Whether {@code key} is the string literal {@code "verify_signature"} (quotes and prefix ignored). */
    private static boolean isVerifySignatureKey(Expression key) {
        return key.is(Tree.Kind.STRING_LITERAL) && VERIFY_SIGNATURE_KEYWORD.equals(((StringLiteral) key).trimmedQuotesValue());
    }

    /** Whether {@code t} is a call whose callee resolves to one of {@link #VERIFY_JWT_FQNS}. */
    private static boolean isCallToVerifyJwt(Tree t) {
        return TreeUtils.toOptionalInstanceOf(CallExpression.class, t)
            .map(CallExpression::calleeSymbol)
            .map(Symbol::fullyQualifiedName)
            .filter(VERIFY_JWT_FQNS::contains)
            .isPresent();
    }

    /**
     * Whether the result of an "unverified" call is only used to read allowed header keys.
     *
     * <p>Keys read directly on the call result ({@code f(t)["kid"]}, {@code f(t).get("kid")}) are always
     * considered. When the enclosing assignment binds a single name, keys read later on that name are
     * considered too.
     *
     * @return {@code true} only if at least one key access was found and all accessed keys are allowed
     */
    private static boolean accessOnlyAllowedHeaderKeys(CallExpression call) {
        Stream<StringLiteral> headerKeysAccessedDirectly = accessToHeaderKeyWithoutAssignment(call);
        Tree assignment = TreeUtils.firstAncestorOfKind(call, Tree.Kind.ASSIGNMENT_STMT);
        if (assignment == null) {
            return areStringLiteralsPartOfAllowedKeys(headerKeysAccessedDirectly);
        }
        List<Expression> lhsExpressions = ((AssignmentStatement) assignment).lhsExpressions().stream()
            .map(ExpressionList::expressions)
            .flatMap(List::stream)
            .toList();
        Symbol assignedSymbol = singleAssignedSymbol(lhsExpressions);
        if (assignedSymbol == null) {
            // e.g. `self.kid = jwt.get_unverified_header(token)["kid"]`: nothing local can be tracked, so judge
            // the keys read directly on the call result, exactly as when there is no assignment at all.
            return areStringLiteralsPartOfAllowedKeys(headerKeysAccessedDirectly);
        }
        Stream<StringLiteral> keysReadThroughAssignedName = Stream.concat(
            usagesAccessedByGet(assignedSymbol, call),
            usagesAccessedBySubscription(assignedSymbol, call));
        return areStringLiteralsPartOfAllowedKeys(Stream.concat(keysReadThroughAssignedName, headerKeysAccessedDirectly));
    }

    /** The symbol of the only left-hand-side expression when it is a plain name; {@code null} otherwise. */
    @Nullable
    private static Symbol singleAssignedSymbol(List<Expression> lhsExpressions) {
        if (lhsExpressions.size() != 1 || !lhsExpressions.get(0).is(Tree.Kind.NAME)) {
            return null;
        }
        return ((Name) lhsExpressions.get(0)).symbol();
    }

    /** Non-empty and every literal's unquoted value is in {@link #ALLOWED_KEYS_ACCESS}. Consumes the stream. */
    private static boolean areStringLiteralsPartOfAllowedKeys(Stream<StringLiteral> literals) {
        List<StringLiteral> literalList = literals.toList();
        return !literalList.isEmpty() && literalList.stream().allMatch(str -> ALLOWED_KEYS_ACCESS.contains(str.trimmedQuotesValue()));
    }

    /** String-literal keys read directly on {@code call}'s result via {@code .get(...)} or subscription. */
    private static Stream<StringLiteral> accessToHeaderKeyWithoutAssignment(CallExpression call) {
        Stream<StringLiteral> keysFromGet =
            getStringLiteralArguments(getArgumentsFromCallExpr(getCallExprWhereDictIsAccessedWithGet(Stream.of(call.parent()))));
        Stream<StringLiteral> keysFromSubscription = getSubscriptsStringLiteral(getSubscriptions(Stream.of(call.parent())));
        return Stream.concat(keysFromGet, keysFromSubscription);
    }

    /** String-literal arguments of {@code symbol.get(...)} calls located after {@code call}. */
    private static Stream<StringLiteral> usagesAccessedByGet(Symbol symbol, CallExpression call) {
        Stream<Tree> usageParents = getForwardUsages(symbol, call).map(Usage::tree).map(Tree::parent);
        return getStringLiteralArguments(getArgumentsFromCallExpr(getCallExprWhereDictIsAccessedWithGet(usageParents)));
    }

    /** All arguments of all given calls, flattened. */
    private static Stream<Argument> getArgumentsFromCallExpr(Stream<CallExpression> callExprs) {
        return callExprs.map(CallExpression::arguments).flatMap(List::stream);
    }

    /**
     * Usages of {@code symbol} that start on a line strictly after the callee of {@code call}. The strict
     * comparison excludes the binding usage on the assignment line itself.
     * NOTE(review): later usages on the same line as the call are excluded as well.
     */
    private static Stream<Usage> getForwardUsages(Symbol symbol, CallExpression call) {
        int callLine = call.callee().firstToken().line();
        return symbol.usages().stream().filter(usage -> usage.tree().firstToken().line() > callLine);
    }

    /** Maps each {@code x.get} qualified expression that is the callee of a call to that call. */
    private static Stream<CallExpression> getCallExprWhereDictIsAccessedWithGet(Stream<Tree> parentQualifiedExpr) {
        return parentQualifiedExpr
            .filter(parent -> parent.is(Tree.Kind.QUALIFIED_EXPR))
            .flatMap(TreeUtils.toStreamInstanceOfMapper(QualifiedExpression.class))
            .filter(expr -> "get".equals(expr.name().name()))
            .filter(expr -> expr.parent().is(Tree.Kind.CALL_EXPR))
            .map(Tree::parent)
            .flatMap(TreeUtils.toStreamInstanceOfMapper(CallExpression.class));
    }

    /** Keeps regular (non-unpacking) arguments whose value is a string literal. */
    private static Stream<StringLiteral> getStringLiteralArguments(Stream<Argument> arguments) {
        return arguments
            .filter(arg -> arg.is(Tree.Kind.REGULAR_ARGUMENT))
            .flatMap(TreeUtils.toStreamInstanceOfMapper(RegularArgument.class))
            .map(RegularArgument::expression)
            .flatMap(TreeUtils.toStreamInstanceOfMapper(StringLiteral.class));
    }

    /** String-literal subscripts of {@code symbol[...]} expressions located after {@code call}. */
    private static Stream<StringLiteral> usagesAccessedBySubscription(Symbol symbol, CallExpression call) {
        Stream<Tree> usageParents = getForwardUsages(symbol, call).map(Usage::tree).map(Tree::parent);
        return getSubscriptsStringLiteral(getSubscriptions(usageParents));
    }

    /** Keeps only subscription expressions. */
    private static Stream<SubscriptionExpression> getSubscriptions(Stream<Tree> subscriptions) {
        return subscriptions
            .filter(subscription -> subscription.is(Tree.Kind.SUBSCRIPTION))
            .flatMap(TreeUtils.toStreamInstanceOfMapper(SubscriptionExpression.class));
    }

    /** String-literal subscripts of the given subscription expressions, flattened. */
    private static Stream<StringLiteral> getSubscriptsStringLiteral(Stream<SubscriptionExpression> subscriptions) {
        return subscriptions
            .map(SubscriptionExpression::subscripts)
            .map(ExpressionList::expressions)
            .flatMap(List::stream)
            .flatMap(TreeUtils.toStreamInstanceOfMapper(StringLiteral.class));
    }
}

