package com.github.azbh111.utils.java.lambda;


import com.github.azbh111.utils.java.exception.ExceptionUtils;

import java.lang.invoke.SerializedLambda;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Resolves the parameter and return types of serializable lambdas ({@code SFunction},
 * {@code SBiFunction}, {@code SConsumer}, {@code SBiConsumer}, {@code SSupplier}) from their
 * {@link SerializedLambda#getInstantiatedMethodType() instantiated method type}.
 *
 * @author: zyp
 * @date: 2020/12/17 11:46
 */
public class LambdaUtils {

    /**
     * Serialized form of NON-capturing lambdas, keyed by the lambda's generated class name.
     * Capturing lambdas are never stored: all instances created at one capture site share the
     * same class, but each instance carries its own captured arguments, so a cached entry would
     * hand back another instance's captured values (and keep them reachable forever).
     */
    private static final ConcurrentHashMap<String, SerializedLambda> CACHE = new ConcurrentHashMap<>();

    /**
     * One reference-type field descriptor (a single capturing group): an object type such as
     * {@code Ljava/lang/String;}, optionally as an array ({@code [Ljava/lang/String;}), or a
     * primitive array such as {@code [I}. {@code [^;]+} accepts any internal class name,
     * including non-ASCII identifiers, which {@code \w} would reject.
     */
    private static final String TYPE = "(\\[*L[^;]+;|\\[+[ZBCSIJFD])";

    /** {@code (P)R} */
    private static final Pattern S_FUNCTION_PATTERN = Pattern.compile("\\(" + TYPE + "\\)" + TYPE);
    /** {@code (P1P2)R} */
    private static final Pattern S_BI_FUNCTION_PATTERN = Pattern.compile("\\(" + TYPE + TYPE + "\\)" + TYPE);
    /** {@code (P)V} */
    private static final Pattern S_CONSUMER_PATTERN = Pattern.compile("\\(" + TYPE + "\\)V");
    /** {@code (P1P2)V} */
    private static final Pattern S_BI_CONSUMER_PATTERN = Pattern.compile("\\(" + TYPE + TYPE + "\\)V");
    /** {@code ()R} */
    private static final Pattern S_SUPPLIER_PATTERN = Pattern.compile("\\(\\)" + TYPE);

    /**
     * Resolves the parameter and return classes of a one-argument function lambda.
     *
     * @param lambda the lambda; may be {@code null}
     * @return the resolved types, or {@code null} when {@code lambda} is {@code null}
     * @throws Error if the instantiated method type cannot be parsed
     */
    public static <P, R> SFunctionInfo parse(SFunction<P, R> lambda) {
        if (lambda == null) {
            return null;
        }
        Class<?>[] types = resolveTypes(lambda, S_FUNCTION_PATTERN);
        return types == null ? null : new SFunctionInfo(types[0], types[1]);
    }

    /**
     * Resolves the two parameter classes and the return class of a two-argument function lambda.
     *
     * @param lambda the lambda; may be {@code null}
     * @return the resolved types, or {@code null} when {@code lambda} is {@code null}
     * @throws Error if the instantiated method type cannot be parsed
     */
    public static <P1, P2, R> SBiFunctionInfo parse(SBiFunction<P1, P2, R> lambda) {
        if (lambda == null) {
            return null;
        }
        Class<?>[] types = resolveTypes(lambda, S_BI_FUNCTION_PATTERN);
        return types == null ? null : new SBiFunctionInfo(types[0], types[1], types[2]);
    }

    /**
     * Resolves the parameter class of a one-argument consumer lambda.
     *
     * @param lambda the lambda; may be {@code null}
     * @return the resolved type, or {@code null} when {@code lambda} is {@code null}
     * @throws Error if the instantiated method type cannot be parsed
     */
    public static <P> SConsumerInfo parse(SConsumer<P> lambda) {
        if (lambda == null) {
            return null;
        }
        Class<?>[] types = resolveTypes(lambda, S_CONSUMER_PATTERN);
        return types == null ? null : new SConsumerInfo(types[0]);
    }

    /**
     * Resolves the two parameter classes of a two-argument consumer lambda.
     *
     * @param lambda the lambda; may be {@code null}
     * @return the resolved types, or {@code null} when {@code lambda} is {@code null}
     * @throws Error if the instantiated method type cannot be parsed
     */
    public static <P1, P2> SBiConsumerInfo parse(SBiConsumer<P1, P2> lambda) {
        if (lambda == null) {
            return null;
        }
        Class<?>[] types = resolveTypes(lambda, S_BI_CONSUMER_PATTERN);
        return types == null ? null : new SBiConsumerInfo(types[0], types[1]);
    }

    /**
     * Resolves the return class of a supplier lambda.
     *
     * @param lambda the lambda; may be {@code null}
     * @return the resolved type, or {@code null} when {@code lambda} is {@code null}
     * @throws Error if the instantiated method type cannot be parsed
     */
    public static <R> SSupplierInfo parse(SSupplier<R> lambda) {
        if (lambda == null) {
            return null;
        }
        Class<?>[] types = resolveTypes(lambda, S_SUPPLIER_PATTERN);
        return types == null ? null : new SSupplierInfo(types[0]);
    }

    /**
     * Matches the lambda's instantiated method type against {@code pattern} and loads the class
     * named by each capturing group, in group order.
     *
     * @param lambda  the lambda, non-null
     * @param pattern one of the descriptor patterns of this class (one group per type)
     * @return the loaded classes; {@code null} only if {@link ExceptionUtils#throwException}
     *         returns normally instead of rethrowing
     * @throws Error if the instantiated method type does not match {@code pattern}
     */
    private static Class<?>[] resolveTypes(SSerializable lambda, Pattern pattern) {
        String methodType = getSerializedLambda(lambda).getInstantiatedMethodType();
        Matcher m = pattern.matcher(methodType);
        if (!m.find()) {
            throw new Error("can not resolve instantiatedMethodType: " + methodType);
        }
        // Resolve through the loader that defined the lambda class (normally the capturing
        // class's loader). That loader can see the types in the lambda's signature, whereas
        // LambdaUtils' own loader may not (e.g. container or plugin class loaders).
        ClassLoader loader = lambda.getClass().getClassLoader();
        try {
            Class<?>[] types = new Class<?>[m.groupCount()];
            for (int i = 0; i < types.length; i++) {
                // initialize=true matches the previous Class.forName(String) behaviour.
                types[i] = Class.forName(toClassName(m.group(i + 1)), true, loader);
            }
            return types;
        } catch (Throwable x) {
            // NOTE(review): presumably rethrows x as-is (sneaky throw); the null return only
            // satisfies the compiler, mirroring the previous code. Confirm in ExceptionUtils.
            ExceptionUtils.throwException(x);
            return null;
        }
    }

    /**
     * Converts a field descriptor into the name format {@link Class#forName(String, boolean,
     * ClassLoader)} expects. {@code Lcom/foo/Bar;} becomes {@code com.foo.Bar}. Array
     * descriptors such as {@code [Lcom/foo/Bar;} or {@code [I} keep their descriptor form,
     * with dots instead of slashes.
     */
    private static String toClassName(String descriptor) {
        String dotted = descriptor.replace('/', '.');
        return dotted.charAt(0) == 'L' ? dotted.substring(1, dotted.length() - 1) : dotted;
    }

    /**
     * Obtains the {@link SerializedLambda} of a serializable lambda by invoking its synthetic
     * {@code writeReplace} method.
     *
     * <p>{@code implMethodSignature} may contain primitive types and carries no generic
     * information. {@code instantiatedMethodType} contains no bare primitive types (primitive
     * arrays such as {@code [I} can still appear) and also carries no generic information.
     *
     * <p>Results for non-capturing lambdas are cached per lambda class. Capturing lambdas are
     * re-read on every call so the returned captured arguments belong to {@code lambda} itself.
     *
     * @param lambda the lambda instance, non-null
     * @return the serialized form of {@code lambda}
     * @throws NullPointerException if {@code lambda} is {@code null}
     * @throws RuntimeException     if the class has no {@code writeReplace} method (not a lambda)
     * @throws Error                if invoking {@code writeReplace} fails
     */
    public static SerializedLambda getSerializedLambda(SSerializable lambda) {
        Objects.requireNonNull(lambda, "lambda");
        Class<?> lambdaClass = lambda.getClass();
        String key = lambdaClass.getName();
        SerializedLambda cached = CACHE.get(key);
        if (cached != null) {
            return cached;
        }
        Method writeReplaceMethod;
        try {
            writeReplaceMethod = lambdaClass.getDeclaredMethod("writeReplace");
        } catch (NoSuchMethodException e) {
            throw new RuntimeException(e);
        }
        // Extract the serialized lambda information from the serialization hook.
        boolean wasAccessible = writeReplaceMethod.isAccessible();
        writeReplaceMethod.setAccessible(true);
        SerializedLambda serializedLambda;
        try {
            serializedLambda = (SerializedLambda) writeReplaceMethod.invoke(lambda);
        } catch (IllegalAccessException | InvocationTargetException e) {
            throw new Error(e);
        } finally {
            // Restore even when invoke fails.
            writeReplaceMethod.setAccessible(wasAccessible);
        }
        if (serializedLambda.getCapturedArgCount() == 0) {
            SerializedLambda previous = CACHE.putIfAbsent(key, serializedLambda);
            if (previous != null) {
                return previous;
            }
        }
        return serializedLambda;
    }
}
