package com.undefinedlabs.scope.rules.sql.provider.internal.mysql;

import com.undefinedlabs.scope.jdk.reflection.ReflectionContext;
import com.undefinedlabs.scope.rules.sql.model.PreparedStatementQuery;
import com.undefinedlabs.scope.rules.sql.model.PreparedStatementQueryParameter;
import com.undefinedlabs.scope.rules.sql.provider.PreparedStatementQueryProvider;
import com.undefinedlabs.scope.rules.sql.provider.internal.PreparedStatementQueryUtils;
import org.apache.commons.lang3.reflect.MethodUtils;

import java.nio.charset.StandardCharsets;
import java.sql.PreparedStatement;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * {@link PreparedStatementQueryProvider} for MySQL Connector/J statements that are instances of
 * {@code com.mysql.cj.jdbc.ClientPreparedStatement}.
 *
 * <p>The driver classes are never referenced at compile time: the statement type is resolved through
 * {@link ReflectionContext} and every driver method is invoked reflectively with {@link MethodUtils}.
 */
public enum MySQLClientPreparedStatementQueryProvider implements PreparedStatementQueryProvider {

    INSTANCE;

    /** Driver class name that {@link #create(PreparedStatement)} accepts. */
    public static final String CLIENT_PREPARED_STATEMENT_CLASS_NAME = "com.mysql.cj.jdbc.ClientPreparedStatement";
    /** Fully-qualified Connector/J class name; not referenced by this provider itself. */
    public static final String CALLABLE_PREPARED_STATEMENT_CLASS_NAME = "com.mysql.cj.jdbc.CallableStatement";
    /** Fully-qualified Connector/J class name; not referenced by this provider itself. */
    public static final String SERVER_PREPARED_STATEMENT_CLASS_NAME = "com.mysql.cj.jdbc.ServerPreparedStatement";

    /**
     * Builds a {@link PreparedStatementQuery} describing the given statement.
     *
     * @param preparedStatement the statement to inspect; may be {@code null}
     * @return {@link PreparedStatementQuery#EMPTY} when the statement is {@code null} or is not a
     *     {@code ClientPreparedStatement}; otherwise a query carrying the rendered SQL ({@code asSql}),
     *     the SQL method extracted from the parameterized SQL, and — only when at least one parameter
     *     is bound — the parameterized SQL and the ordered parameter map (keys built from 1-based indices)
     * @throws RuntimeException wrapping any failure of the reflective calls into the driver
     */
    @Override
    public PreparedStatementQuery create(PreparedStatement preparedStatement) {
        if (preparedStatement == null) {
            return PreparedStatementQuery.EMPTY;
        }
        final Class<?> clientPreparedStatementClass =
                ReflectionContext.INSTANCE.getScopeClass(CLIENT_PREPARED_STATEMENT_CLASS_NAME);
        // NOTE(review): assumes getScopeClass may return null when the driver class cannot be resolved;
        // in that case the statement cannot be a MySQL client statement, so report it as EMPTY.
        if (clientPreparedStatementClass == null || !clientPreparedStatementClass.isInstance(preparedStatement)) {
            return PreparedStatementQuery.EMPTY;
        }

        try {
            final String sqlPreparedStatement = (String) MethodUtils.invokeMethod(preparedStatement, "getPreparedSql");
            final String sql = (String) MethodUtils.invokeMethod(preparedStatement, "asSql");
            final String sqlMethod = PreparedStatementQueryUtils.INSTANCE.extractSqlMethod(sqlPreparedStatement);
            final Object queryBindings = MethodUtils.invokeMethod(preparedStatement, "getQueryBindings");
            final Map<String, PreparedStatementQueryParameter> parametersMap = extractParameters(queryBindings);

            final PreparedStatementQuery.Builder builder = PreparedStatementQuery.newBuilder();
            builder.withSqlStatement(sql);
            builder.withSqlMethod(sqlMethod);
            // Parameterized SQL and parameters are only reported when something is actually bound.
            builder.withSqlPreparedStatement(parametersMap.isEmpty() ? null : sqlPreparedStatement);
            builder.withSqlParameterMap(parametersMap.isEmpty() ? null : parametersMap);
            return builder.build();
        } catch (Exception e) {
            throw new RuntimeException(
                    "Failed to extract query from MySQL prepared statement of type "
                            + preparedStatement.getClass().getName(),
                    e);
        }
    }

    /**
     * Reads the bind values from the driver's query-bindings object, preserving parameter order.
     *
     * @param queryBindings result of {@code getQueryBindings()}; may be {@code null}
     * @return an insertion-ordered map from generated parameter key to parameter; empty when nothing is bound
     * @throws ReflectiveOperationException if a reflective driver call fails
     */
    private static Map<String, PreparedStatementQueryParameter> extractParameters(final Object queryBindings)
            throws ReflectiveOperationException {
        final Map<String, PreparedStatementQueryParameter> parametersMap = new LinkedHashMap<>();
        if (queryBindings == null) {
            return parametersMap;
        }
        final Object[] bindValues = (Object[]) MethodUtils.invokeMethod(queryBindings, "getBindValues");
        if (bindValues == null) {
            return parametersMap;
        }
        for (int i = 0; i < bindValues.length; i++) {
            final Object bindValue = bindValues[i];
            if (bindValue == null) {
                continue; // Nothing to describe; keys are index-based, so later parameters keep their keys.
            }
            final int paramIndex = i + 1; // SQL params start at 1.
            final String paramKey = PreparedStatementQueryUtils.INSTANCE.generateParamKey(paramIndex);
            parametersMap.put(paramKey, toParameter(bindValue));
        }
        return parametersMap;
    }

    /**
     * Converts one driver bind value into a {@link PreparedStatementQueryParameter}.
     *
     * @param bindValue a non-null element of {@code getBindValues()}
     * @return the parameter; type fields and the value are {@code null} when the driver reports none
     * @throws ReflectiveOperationException if a reflective driver call fails
     */
    private static PreparedStatementQueryParameter toParameter(final Object bindValue)
            throws ReflectiveOperationException {
        final Object type = MethodUtils.invokeMethod(bindValue, "getMysqlType");
        final String typeName = type == null ? null : (String) MethodUtils.invokeMethod(type, "getName");
        final String typeClassName = type == null ? null : (String) MethodUtils.invokeMethod(type, "getClassName");
        final byte[] value = (byte[]) MethodUtils.invokeMethod(bindValue, "getByteValue");
        // getByteValue can be null (e.g. unbound or streamed parameters — TODO confirm against driver);
        // previously this threw an NPE. Decode with an explicit charset instead of the platform default.
        // NOTE(review): assumes the connection encoding is UTF-8 compatible.
        final String valueAsString = value == null ? null : new String(value, StandardCharsets.UTF_8);
        return new PreparedStatementQueryParameter(typeName, typeClassName, valueAsString);
    }
}
