package com.plugins.junk;

import com.android.build.gradle.BaseExtension;
import com.plugins.junk.field.FieldBean;
import com.plugins.junk.field.FieldCodeFactory;
import com.plugins.junk.method.create.MethodCodeFactory;
import com.plugins.junk.method.create.MethodBean;
import com.plugins.junk.method.insert.JunkCodeFactory;
import com.plugins.junk.utils.JavassistUtils;
import com.plugins.junk.utils.LogUtil;

import org.apache.commons.io.FileUtils;
import org.gradle.api.Project;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import javassist.CannotCompileException;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtMethod;
import javassist.Modifier;
import javassist.NotFoundException;
import javassist.bytecode.ClassFile;

public class CodeInjectUtil {
    private final ClassPool sClassPool = new ClassPool(true);
    private int mMethodSuccessCount = 0;
    private int mMethodCurIndex = 0;
    private int mCodeInjectPercentage;
    private float mCodeInjectMethodRatio;
    private float mCodeInjectFieldRatio;
    private String[] mCodeInjectWhiteList;

    public void injectCode(File baseClassPath, Project project, CodeExtension extension, Map<String, String> map) throws NotFoundException, CannotCompileException {
        try {
            //把类路径添加到classpool
            LogUtil.log("Class build path: " + baseClassPath.getPath());
            sClassPool.insertClassPath(baseClassPath.getPath());
        } catch (NotFoundException e) {
            e.printStackTrace();
        }
        //添加三方jar包
        addJarFile(map);
        //添加Android相关的类
        BaseExtension android = project.getExtensions().getByType(BaseExtension.class);
        sClassPool.insertClassPath(android.getBootClasspath().get(0).toString());
        LogUtil.log("Android libraries: " + android.getBootClasspath());
        mMethodSuccessCount = 0;
        mMethodCurIndex = 0;
        mCodeInjectPercentage = Math.round(extension.codeInjectPercentage / 10f);
        mCodeInjectMethodRatio = extension.codeInjectMethodRatio;
//        mCodeInjectFieldRatio = extension.codeInjectFieldRatio;
        mCodeInjectWhiteList = extension.codeInjectWhiteList;
        traverseFile(baseClassPath);
        LogUtil.log("Total method modified: " + mMethodSuccessCount);
    }

    /**
     * 添加三方jar包
     */
    private void addJarFile(Map<String, String> map) {
        LogUtil.log("library  size:" + map.size());
        map.forEach((s, s2) -> {
            try {
                LogUtil.log("library  name:%s , path: %s", s, s2);
                sClassPool.insertClassPath(s2);
            } catch (NotFoundException e) {
                e.printStackTrace();
            }
        });
    }

    private void traverseFile(File baseClassFile) {
        File[] files = baseClassFile.listFiles();
        for (File file : files) {
            if (file.isDirectory()) {    //若是目录，则递归
                if (file.getName().contains("META-INF")) {
                    LogUtil.log("文件夹 META-INF 跳过  :" + file.getName());
                    continue;
                }
                traverseFile(file);
            } else if (file.isFile()) {
                if (checkClassFile(file)) {
                    //代码插入
                    inject(file.getPath());
                }
            }
        }
    }

    /**
     * 这里真正实现对代码的注入
     */
    private void inject(String classFilePath) {
        try (FileInputStream is = new FileInputStream(classFilePath)) {
            ClassFile classFile = new ClassFile(new DataInputStream(new BufferedInputStream(is)));

            CtClass ctClass = sClassPool.get(classFile.getName());
            //解冻
            if (ctClass.isFrozen()) {
                ctClass.defrost();
            }
            if (!checkWhiteList(ctClass)) {
                //获取当前类的所有方法
                CtMethod[] ctMethods = ctClass.getDeclaredMethods();
                boolean isInterface = Modifier.isInterface(ctClass.getModifiers());
                //在类中创建变量
//                if (mCodeInjectFieldRatio > 0 && !isInterface) {
//                    createClassField(ctClass, ctMethods);
//                }
                //在类中创建方法
                if (mCodeInjectMethodRatio > 0) {
                    createClassMethod(ctClass, ctMethods, isInterface);
                }
                //在方法中插入代码
                if (mCodeInjectPercentage > 0 && !isInterface) {
                    injectClassMethod(ctClass, ctMethods);
                }
            }
            //保存class
            byte[] classBytes = ctClass.toBytecode();
            FileUtils.writeByteArrayToFile(new File(classFilePath), classBytes);
            //ctClass.writeFile(baseFilePath);//这个方法有bug,文件大于8K保存文件会损坏
            ctClass.detach();//释放

        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * 在类中创建成员变量
     * 按照方法比例， 变量比例，自定义数量
     *
     * @param ctClass
     */
    private void createClassField(CtClass ctClass, CtMethod[] ctMethods) {
        if (ctMethods == null)
            return;
        int count = (int) (ctMethods.length * mCodeInjectFieldRatio);
        if (count > 0) {
            List<CallCode> fields = new LinkedList<>();
            List<CallCode> staticFields = new LinkedList<>();
            //获取静态方法比例
            int staticCount = (int) (getClassMethodStaticPercentage(ctMethods) * mCodeInjectFieldRatio);
            for (int i = 0; i < count; i++) {
                boolean isStatic = staticCount > i;
                FieldBean fieldBean = FieldCodeFactory.getFieldCode(sClassPool, ctClass, isStatic);
                if (fieldBean != null) {
                    try {
                        ctClass.addField(fieldBean.getCtField());
                        if (isStatic) {
                            staticFields.add(fieldBean);
                        } else {
                            fields.add(fieldBean);
                        }
                    } catch (CannotCompileException e) {
                        LogUtil.error("[add field failed]: %s - %s", fieldBean.getCtField().getName(), e.getMessage());
                    }
                }
            }
            //把创建的方法插入到原有的方法中调用
            insertCreateCode(ctMethods, fields, staticFields, false);
            insertCreateCode(ctMethods, fields, staticFields, true);
        }
    }

    /**
     * 按照比例在类中创建方法
     *
     * @param ctClass
     * @param ctMethods
     */
    private void createClassMethod(CtClass ctClass, CtMethod[] ctMethods, boolean isInterface) {
        if (ctMethods != null) {
            //计算创建方法的比例
            int count = (int) (ctMethods.length * mCodeInjectMethodRatio);
            if (count > 0) {
                List<CallCode> methods = new LinkedList<>();
                List<CallCode> staticMethods = new LinkedList<>();
                //获取静态方法比例
                int staticCount = (int) (getClassMethodStaticPercentage(ctMethods) * mCodeInjectMethodRatio);
                for (int i = 0; i < count; i++) {
                    boolean isStatic = staticCount > i;
                    //创建方法并添加
                    MethodBean methodBean = MethodCodeFactory.getMethodCode(sClassPool, ctClass, isStatic, isInterface);
                    if (methodBean != null) {
                        try {
                            ctClass.addMethod(methodBean.getMethod());
                            if (isStatic) {
                                staticMethods.add(methodBean);
                            } else {
                                methods.add(methodBean);
                            }
                        } catch (CannotCompileException e) {
                            LogUtil.error("[add method failed]: %s - %s", methodBean.getMethod().getLongName(), e.getMessage());
                        }
                    }
                }
                if (!isInterface) {
                    //把创建的方法插入到原有的方法中调用
                    insertCreateCode(ctMethods, methods, staticMethods, false);
                    insertCreateCode(ctMethods, methods, staticMethods, true);
                }
            }
        }
    }


    /**
     * 把创建的代码插入到原有的方法中调用
     *
     * @param ctMethods
     * @param methods
     */
    private void insertCreateCode(CtMethod[] ctMethods, List<CallCode> methods, List<CallCode> staticMethods, boolean isRandom) {
        int allSize = methods.size() + staticMethods.size();
        int index = 0;
        int staticIndex = 0;
        int size = allSize;
        if (isRandom) {//打乱数组顺序
            Collections.shuffle(Arrays.asList(ctMethods));
        }
        for (int i = 0; i < size; i++) {
            //防止角标越界
            CtMethod ctMethod = ctMethods[i % ctMethods.length];
            //跳过abstract和native方
            if (isSkipMethod(ctMethod)) {
                continue;
            }
            try {
                if (Modifier.isStatic(ctMethod.getModifiers())) {
                    JavassistUtils.insertAt(ctMethod, staticMethods.get(staticIndex % staticMethods.size()).call());
                    staticIndex++;
                } else {
                    JavassistUtils.insertAt(ctMethod, methods.get(index % methods.size()).call());
                    index++;
                }
            } catch (CannotCompileException e) {
                //如果循环次数 >methods.size() * 2 就不再循环类
                if (size < (allSize * 2)) {
                    //如果插入方法报错就把size+1，保证每个创建的方法都可以插入
                    size++;
                }
                LogUtil.error("[insert create code failed]: %s - %s", ctMethod.getLongName(), e.getMessage());
            }
        }
    }

    private void injectClassMethod(CtClass ctClass, CtMethod[] ctMethods) {
        if (ctMethods != null) {
            for (CtMethod ctMethod : ctMethods) {
                if (isSkipMethod(ctMethod)) {
                    continue;
                }
                boolean success = true;
                //按照百分比插入代码
                if (mCodeInjectPercentage > mMethodCurIndex) {
                    success = JunkCodeFactory.insert(sClassPool, ctClass, ctMethod);
                    if (success) {
                        mMethodSuccessCount++;
                    }
                }
                if (success) {
                    mMethodCurIndex++;
                    if (mMethodCurIndex == 10)
                        mMethodCurIndex = 0;
                }
            }
        }
    }

    /**
     * 跳过特殊的方法
     *
     * @param ctMethod
     * @return
     */
    private boolean isSkipMethod(CtMethod ctMethod) {
        //跳过abstract和native方法
        if (Modifier.isAbstract(ctMethod.getModifiers()) || Modifier.isNative(ctMethod.getModifiers())) {
            return true;
        }
        //kotlin协程生成的invokeSuspend方法需要跳过，否则会报错
        if ("invokeSuspend".equals(ctMethod.getName())) {
            return true;
        }
        //kotlin 挂起方法需要跳过，否则会报错（suspend标记的方法会添加Continuation形式参数）
        try {
            for (CtClass parameterType : ctMethod.getParameterTypes()) {
                if ("kotlin.coroutines.Continuation".equals(parameterType.getName())) {
                    return true;
                }
            }
        } catch (NotFoundException e) {
        }
        return false;
    }

    /**
     * 检查CtClass是否在白名单中
     *
     * @param ctMethod
     * @return
     */
    private boolean checkWhiteList(CtClass ctMethod) {
        if (mCodeInjectWhiteList != null && mCodeInjectWhiteList.length > 0) {
            String name = ctMethod.getSimpleName();
            for (String s : mCodeInjectWhiteList) {
                //类及内部类
                if (name.equals(s) || name.startsWith(s + "$")) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * 过滤掉一些生成的类
     *
     * @param file
     * @return
     */
    private boolean checkClassFile(File file) {
        if (file.getName().contains("META-INF")) {
            LogUtil.log("META-INF跳过   :" + file.getName());
            return false;
        }
        if (file.isDirectory()) {
            LogUtil.log("文件夹跳过   :" + file.getName());
            return false;
        }
        if (file.getName().endsWith(".json")) {
            LogUtil.log("json文件跳过   :" + file.getName());
            return false;
        }
        String filePath = file.getPath();
        return !filePath.contains("R$") &&
                !filePath.contains("R.class") &&
                !filePath.contains("BuildConfig.class");
    }

    /**
     * 获取类中的静态方法比例
     *
     * @param ctMethods
     * @return 百分比 （10..15..20..50..100）
     */
    private int getClassMethodStaticPercentage(CtMethod[] ctMethods) {
        int staticCount = 0;
        for (CtMethod ctMethod : ctMethods) {
            if (Modifier.isStatic(ctMethod.getModifiers())) {
                staticCount++;
            }
        }
        return staticCount;
    }

}