/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.ocr.model.plate.translator;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Landmark;
import ai.djl.modality.cv.output.Point;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import cn.smartjavaai.common.utils.LetterBoxUtils;
import cn.smartjavaai.common.utils.NMSUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class Yolov7PlateDetectTranslator
implements Translator<Image, DetectedObjects> {
    private int inputSize = 640;
    private float minConfThreshold = 0.3f;
    private float iouThreshold = 0.5f;
    private float confThreshold = 0.0f;
    private int imageWidth;
    private int imageHeight;
    private int topK;
    private LetterBoxUtils.ResizeResult letterBoxResult;

    public Yolov7PlateDetectTranslator(Map<String, ?> arguments) {
        this.confThreshold = arguments.containsKey("confThreshold") ? (float)Integer.parseInt(arguments.get("confThreshold").toString()) : 0.3f;
        this.iouThreshold = arguments.containsKey("iouThreshold") ? (float)Integer.parseInt(arguments.get("iouThreshold").toString()) : 0.5f;
        this.topK = arguments.containsKey("topk") ? Integer.parseInt(arguments.get("topk").toString()) : 100;
    }

    public NDList processInput(TranslatorContext ctx, Image input) {
        NDManager manager = ctx.getNDManager();
        NDArray array = input.toNDArray(manager, Image.Flag.COLOR);
        this.imageWidth = (int)array.getShape().get(1);
        this.imageHeight = (int)array.getShape().get(0);
        this.letterBoxResult = LetterBoxUtils.letterbox((NDManager)manager, (NDArray)array, (int)this.inputSize, (int)this.inputSize, (float)114.0f, (LetterBoxUtils.PaddingPosition)LetterBoxUtils.PaddingPosition.CENTER);
        array = this.letterBoxResult.image;
        array = array.toType(DataType.FLOAT32, false).div((Number)Float.valueOf(255.0f));
        array = array.transpose(new int[]{2, 0, 1});
        return new NDList(new NDArray[]{array.expandDims(0)});
    }

    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        NDManager manager = ctx.getNDManager();
        int num_cls = 2;
        NDArray dets = list.singletonOrThrow();
        NDArray dets0 = dets.get(new long[]{0L});
        NDArray conf = dets0.get(":, 4", new Object[0]);
        NDArray mask = conf.gt((Number)Float.valueOf(this.minConfThreshold));
        NDArray detsFiltered = dets0.get(mask);
        NDArray clsLogits = detsFiltered.get(":, 5:7", new Object[0]);
        NDArray confFiltered = detsFiltered.get(":, 4", new Object[0]).reshape(new long[]{-1L, 1L});
        clsLogits = clsLogits.mul(confFiltered);
        NDArray jointScore = clsLogits.max(new int[]{1});
        NDArray jointMask = jointScore.gt((Number)Float.valueOf(this.confThreshold));
        detsFiltered = detsFiltered.get(jointMask);
        clsLogits = clsLogits.get(jointMask);
        NDArray xywh = detsFiltered.get(":, 0:4", new Object[0]);
        NDArray halfWH = xywh.get(":, 2:4", new Object[0]).div((Number)2);
        NDArray xy1 = xywh.get(":, 0:2", new Object[0]).sub(halfWH);
        NDArray xy2 = xywh.get(":, 0:2", new Object[0]).add(halfWH);
        NDArray boxes = NDArrays.concat((NDList)new NDList(new NDArray[]{xy1, xy2}), (int)1);
        NDArray scores = clsLogits.max(new int[]{1}, true);
        NDArray indices = clsLogits.argMax(1).reshape(new long[]{-1L, 1L}).toType(DataType.FLOAT32, false);
        NDArray keyPoints = NDArrays.concat((NDList)new NDList(new NDArray[]{detsFiltered.get(":, 7:8", new Object[0]), detsFiltered.get(":, 8:9", new Object[0]), detsFiltered.get(":, 10:11", new Object[0]), detsFiltered.get(":, 11:12", new Object[0]), detsFiltered.get(":, 13:14", new Object[0]), detsFiltered.get(":, 14:15", new Object[0]), detsFiltered.get(":, 16:17", new Object[0]), detsFiltered.get(":, 17:18", new Object[0])}), (int)1);
        NDArray output = NDArrays.concat((NDList)new NDList(new NDArray[]{boxes, scores, keyPoints, indices}), (int)1);
        int[] keepIndices = NMSUtils.nms((NDArray)boxes, (NDArray)scores.squeeze(), (float)this.iouThreshold);
        NDArray kept = output.get(manager.create(keepIndices));
        if (keepIndices.length > this.topK) {
            int[] topkIndices = new int[this.topK];
            System.arraycopy(keepIndices, 0, topkIndices, 0, this.topK);
            keepIndices = topkIndices;
        }
        NDArray restored = LetterBoxUtils.restoreBox((NDArray)kept, (float)this.letterBoxResult.r, (float)this.letterBoxResult.left, (float)this.letterBoxResult.top, (int)5, (int)8);
        ArrayList<String> classNames = new ArrayList<String>();
        ArrayList<Double> probabilities = new ArrayList<Double>();
        ArrayList<Landmark> boundingBoxes = new ArrayList<Landmark>();
        float[] flatData = restored.toFloatArray();
        long[] shape = restored.getShape().getShape();
        int rows = (int)shape[0];
        int cols = (int)shape[1];
        float[][] data = new float[rows][cols];
        for (int i = 0; i < rows; ++i) {
            System.arraycopy(flatData, i * cols, data[i], 0, cols);
        }
        for (float[] row : data) {
            float x1 = row[0];
            float y1 = row[1];
            float x2 = row[2];
            float y2 = row[3];
            float score = row[4];
            int classIndex = (int)row[13];
            double prob = score;
            String className = classIndex == 0 ? "single" : "double";
            double rectX = x1 / (float)this.imageWidth;
            double rectY = y1 / (float)this.imageHeight;
            double rectW = (x2 - x1) / (float)this.imageWidth;
            double rectH = (y2 - y1) / (float)this.imageHeight;
            ArrayList<Point> pointsSrc = new ArrayList<Point>();
            pointsSrc.add(new Point((double)row[5], (double)row[6]));
            pointsSrc.add(new Point((double)row[7], (double)row[8]));
            pointsSrc.add(new Point((double)row[9], (double)row[10]));
            pointsSrc.add(new Point((double)row[11], (double)row[12]));
            Landmark box = new Landmark(rectX, rectY, rectW, rectH, pointsSrc);
            classNames.add(className);
            probabilities.add(prob);
            boundingBoxes.add(box);
        }
        DetectedObjects detectedObjects = new DetectedObjects(classNames, probabilities, boundingBoxes);
        return detectedObjects;
    }

    public Batchifier getBatchifier() {
        return null;
    }
}

