/*
 * Decompiled with CFR 0.152.
 */
package graphql.schema.transform;

import graphql.PublicApi;
import graphql.schema.GraphQLEnumType;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLInputObjectField;
import graphql.schema.GraphQLInterfaceType;
import graphql.schema.GraphQLNamedSchemaElement;
import graphql.schema.GraphQLNamedType;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLSchema;
import graphql.schema.GraphQLSchemaElement;
import graphql.schema.GraphQLType;
import graphql.schema.GraphQLTypeVisitor;
import graphql.schema.GraphQLTypeVisitorStub;
import graphql.schema.GraphQLUnionType;
import graphql.schema.SchemaTransformer;
import graphql.schema.SchemaTraverser;
import graphql.schema.transform.VisibleFieldPredicate;
import graphql.schema.transform.VisibleFieldPredicateEnvironment;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Schema transformation that removes every field definition and input object field which the
 * supplied {@link VisibleFieldPredicate} reports as not visible, and then removes types that are
 * no longer needed once those fields are gone.
 *
 * <p>Optional hooks are run immediately before and after the transformation. The "after" hook is
 * only run when the transformation completes normally; it is not invoked if a step throws.
 */
@PublicApi
public class FieldVisibilitySchemaTransformation {
    private final VisibleFieldPredicate visibleFieldPredicate;
    private final Runnable beforeTransformationHook;
    private final Runnable afterTransformationHook;

    /**
     * Creates a transformation with no-op before/after hooks.
     *
     * @param visibleFieldPredicate decides, per field, whether the field stays in the schema
     */
    public FieldVisibilitySchemaTransformation(VisibleFieldPredicate visibleFieldPredicate) {
        this(visibleFieldPredicate, () -> {}, () -> {});
    }

    /**
     * Creates a transformation with explicit hooks.
     *
     * @param visibleFieldPredicate    decides, per field, whether the field stays in the schema
     * @param beforeTransformationHook run by {@link #apply(GraphQLSchema)} before any traversal starts
     * @param afterTransformationHook  run by {@link #apply(GraphQLSchema)} after the final schema is built
     */
    public FieldVisibilitySchemaTransformation(VisibleFieldPredicate visibleFieldPredicate, Runnable beforeTransformationHook, Runnable afterTransformationHook) {
        this.visibleFieldPredicate = visibleFieldPredicate;
        this.beforeTransformationHook = beforeTransformationHook;
        this.afterTransformationHook = afterTransformationHook;
    }

    /**
     * Applies the visibility predicate to the given schema and returns the transformed schema.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>collect the names of the query / mutation / subscription types (these are protected);</li>
     *   <li>run the "before" hook;</li>
     *   <li>record every type reachable from the root types of the input schema;</li>
     *   <li>delete the fields rejected by the predicate, remembering the declared type of each;</li>
     *   <li>record every type still reachable from the root types of the interim schema;</li>
     *   <li>delete object, enum, interface and union types that were reachable before but are
     *       not reachable any more;</li>
     *   <li>delete the remembered types that nothing else in the schema references;</li>
     *   <li>run the "after" hook.</li>
     * </ol>
     *
     * @param schema the schema to transform
     * @return the schema with invisible fields and no-longer-needed types removed
     */
    public final GraphQLSchema apply(GraphQLSchema schema) {
        Set<GraphQLType> markedForRemovalTypes = new HashSet<>();

        // Query, mutation and subscription types must never be removed.
        Set<String> protectedTypeNames = getRootTypes(schema).stream()
                .map(GraphQLObjectType::getName)
                .collect(Collectors.toSet());

        beforeTransformationHook.run();

        Set<GraphQLType> observedBeforeTransform = observeReachableTypes(schema);

        // Remove the fields the predicate rejects; their declared types become removal candidates.
        GraphQLSchema interimSchema = SchemaTransformer.transformSchema(
                schema, new FieldRemovalVisitor(visibleFieldPredicate, markedForRemovalTypes));

        Set<GraphQLType> observedAfterTransform = observeReachableTypes(interimSchema);

        // Remove types that were reachable from the roots before the field removal but are not any more.
        GraphQLSchema connectedSchema = SchemaTransformer.transformSchema(
                interimSchema,
                new TypeVisibilityVisitor(protectedTypeNames, observedBeforeTransform, observedAfterTransform));

        // Remove the candidate types that no remaining schema element references.
        GraphQLSchema finalSchema = removeUnreferencedTypes(markedForRemovalTypes, connectedSchema);

        afterTransformationHook.run();
        return finalSchema;
    }

    /**
     * Traverses the schema depth-first from its root types and returns every type the traversal
     * observed (see {@link TypeObservingVisitor}).
     */
    private Set<GraphQLType> observeReachableTypes(GraphQLSchema schema) {
        Set<GraphQLType> observed = new HashSet<>();
        new SchemaTraverser().depthFirst(new TypeObservingVisitor(observed, schema), getRootTypes(schema));
        return observed;
    }

    /**
     * Deletes from {@code connectedSchema} those candidate types that are not referenced by any
     * other schema element.
     *
     * <p>{@code markedForRemovalTypes} is mutated: every candidate that is still encountered while
     * traversing the schema (with the candidates stripped from its additional types) is taken out
     * of the set before the final deletion pass.
     *
     * @param markedForRemovalTypes candidate types; pruned in place to the ones actually deleted
     * @param connectedSchema       the schema to delete the candidates from
     * @return the schema without the unreferenced candidate types
     */
    private GraphQLSchema removeUnreferencedTypes(final Set<GraphQLType> markedForRemovalTypes, GraphQLSchema connectedSchema) {
        // Work on a copy whose additional types do not list the candidates, so that being listed
        // as an additional type alone does not count as "still referenced".
        GraphQLSchema withoutAdditionalTypes = connectedSchema.transform(builder -> {
            Set<GraphQLType> additionalTypes = new HashSet<>(connectedSchema.getAdditionalTypes());
            additionalTypes.removeAll(markedForRemovalTypes);
            builder.clearAdditionalTypes();
            builder.additionalTypes(additionalTypes);
        });

        // Run purely for its side effect on markedForRemovalTypes; the resulting schema is discarded.
        SchemaTransformer.transformSchema(withoutAdditionalTypes, new AdditionalTypeVisibilityVisitor(markedForRemovalTypes));

        // Whatever is still marked is not referenced elsewhere and can be deleted.
        return SchemaTransformer.transformSchema(connectedSchema, new GraphQLTypeVisitorStub() {

            @Override
            protected TraversalControl visitGraphQLType(GraphQLSchemaElement node, TraverserContext<GraphQLSchemaElement> context) {
                if (node instanceof GraphQLType && markedForRemovalTypes.contains(node)) {
                    return deleteNode(context);
                }
                return super.visitGraphQLType(node, context);
            }
        });
    }

    /**
     * Returns the query, subscription and mutation types of the schema, in that order, leaving out
     * any that the schema does not define.
     */
    private List<GraphQLObjectType> getRootTypes(GraphQLSchema schema) {
        return Stream.of(schema.getQueryType(), schema.getSubscriptionType(), schema.getMutationType())
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /**
     * Visitor that records every {@link GraphQLType} it is shown and, for each interface type,
     * also records the implementations the schema reports for that interface.
     */
    private static class TypeObservingVisitor
    extends GraphQLTypeVisitorStub {
        private final Set<GraphQLType> observedTypes;
        private final GraphQLSchema graphQLSchema;

        private TypeObservingVisitor(Set<GraphQLType> observedTypes, GraphQLSchema graphQLSchema) {
            this.observedTypes = observedTypes;
            this.graphQLSchema = graphQLSchema;
        }

        @Override
        protected TraversalControl visitGraphQLType(GraphQLSchemaElement node, TraverserContext<GraphQLSchemaElement> context) {
            if (node instanceof GraphQLType) {
                observedTypes.add((GraphQLType) node);
            }
            if (node instanceof GraphQLInterfaceType) {
                // Implementations are looked up through the schema, so record them explicitly.
                observedTypes.addAll(graphQLSchema.getImplementations((GraphQLInterfaceType) node));
            }
            return TraversalControl.CONTINUE;
        }
    }

    /**
     * Visitor that deletes field definitions and input object fields rejected by the predicate and
     * records the declared type of every deleted field in {@code removedTypes}.
     */
    private static class FieldRemovalVisitor
    extends GraphQLTypeVisitorStub {
        private final VisibleFieldPredicate visibilityPredicate;
        private final Set<GraphQLType> removedTypes;

        private FieldRemovalVisitor(VisibleFieldPredicate visibilityPredicate, Set<GraphQLType> removedTypes) {
            this.visibilityPredicate = visibilityPredicate;
            this.removedTypes = removedTypes;
        }

        @Override
        public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition definition, TraverserContext<GraphQLSchemaElement> context) {
            return visitField(definition, context);
        }

        @Override
        public TraversalControl visitGraphQLInputObjectField(GraphQLInputObjectField definition, TraverserContext<GraphQLSchemaElement> context) {
            return visitField(definition, context);
        }

        /**
         * Asks the predicate about {@code element} (together with its parent node) and, when it is
         * not visible, deletes it and remembers its type. Traversal always continues.
         */
        private TraversalControl visitField(GraphQLNamedSchemaElement element, TraverserContext<GraphQLSchemaElement> context) {
            VisibleFieldPredicateEnvironment environment =
                    new VisibleFieldPredicateEnvironment.VisibleFieldPredicateEnvironmentImpl(element, context.getParentNode());
            if (!visibilityPredicate.isVisible(environment)) {
                deleteNode(context);
                // NOTE(review): getType() is recorded as-is; if it can be a list/non-null wrapper
                // the wrapper, not the wrapped named type, is what gets marked - confirm intent.
                if (element instanceof GraphQLFieldDefinition) {
                    removedTypes.add(((GraphQLFieldDefinition) element).getType());
                } else if (element instanceof GraphQLInputObjectField) {
                    removedTypes.add(((GraphQLInputObjectField) element).getType());
                }
            }
            return TraversalControl.CONTINUE;
        }
    }

    /**
     * Visitor that deletes object, enum, interface and union types which were observed before the
     * field removal but were not observed afterwards. Types whose name is in
     * {@code protectedTypeNames} are never deleted.
     */
    private static class TypeVisibilityVisitor
    extends GraphQLTypeVisitorStub {
        private final Set<String> protectedTypeNames;
        private final Set<GraphQLType> observedBeforeTransform;
        private final Set<GraphQLType> observedAfterTransform;

        private TypeVisibilityVisitor(Set<String> protectedTypeNames, Set<GraphQLType> observedTypes, Set<GraphQLType> observedAfterTransform) {
            this.protectedTypeNames = protectedTypeNames;
            this.observedBeforeTransform = observedTypes;
            this.observedAfterTransform = observedAfterTransform;
        }

        @Override
        public TraversalControl visitGraphQLType(GraphQLSchemaElement node, TraverserContext<GraphQLSchemaElement> context) {
            boolean noLongerObserved = observedBeforeTransform.contains(node) && !observedAfterTransform.contains(node);
            if (noLongerObserved && isRemovableKind(node) && !isProtected(node)) {
                return deleteNode(context);
            }
            return TraversalControl.CONTINUE;
        }

        /** Only these kinds of type are ever deleted by this visitor. */
        private static boolean isRemovableKind(GraphQLSchemaElement node) {
            return node instanceof GraphQLObjectType
                    || node instanceof GraphQLEnumType
                    || node instanceof GraphQLInterfaceType
                    || node instanceof GraphQLUnionType;
        }

        /**
         * True when the node is a named type carrying one of the protected (root) type names.
         * Root types are the starting points of the "after" observation, so this is a defensive
         * guard that makes the protection explicit rather than a change in outcome.
         */
        private boolean isProtected(GraphQLSchemaElement node) {
            return node instanceof GraphQLNamedType
                    && protectedTypeNames.contains(((GraphQLNamedType) node).getName());
        }
    }

    /**
     * Visitor that un-marks candidate types: every named type it is shown is taken out of
     * {@code markedForRemovalTypes}. It never modifies the schema being traversed.
     */
    private static class AdditionalTypeVisibilityVisitor
    extends GraphQLTypeVisitorStub {
        private final Set<GraphQLType> markedForRemovalTypes;

        private AdditionalTypeVisibilityVisitor(Set<GraphQLType> markedForRemovalTypes) {
            this.markedForRemovalTypes = markedForRemovalTypes;
        }

        @Override
        public TraversalControl visitGraphQLType(GraphQLSchemaElement node, TraverserContext<GraphQLSchemaElement> context) {
            if (node instanceof GraphQLNamedType) {
                // Set.remove is a no-op when the type was never marked, so no contains() check is needed.
                markedForRemovalTypes.remove(node);
            }
            return TraversalControl.CONTINUE;
        }
    }
}

