/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.ocr.model.table.translator;

import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Utils;
import cn.smartjavaai.common.entity.Point;
import cn.smartjavaai.ocr.entity.OcrBox;
import cn.smartjavaai.ocr.entity.OcrItem;
import cn.smartjavaai.ocr.entity.TableStructureResult;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class TableStructTranslator
implements Translator<Image, TableStructureResult> {
    private final int maxLength = 488;
    private int height;
    private int width;
    private float scale = 1.0f;
    private float xScale;
    private float yScale;
    private List<String> dict;
    private String beg_str = "sos";
    private String end_str = "eos";
    private List<String> td_token = new ArrayList<String>();

    public void prepare(TranslatorContext ctx) throws IOException {
        Model model = ctx.getModel();
        try (InputStream is = model.getArtifact("table_structure_dict_ch.txt").openStream();){
            this.dict = Utils.readLines((InputStream)is, (boolean)false);
            this.dict.add(0, this.beg_str);
            if (this.dict.contains("<td>")) {
                this.dict.remove("<td>");
            }
            if (!this.dict.contains("<td></td>")) {
                this.dict.add("<td></td>");
            }
            this.dict.add(this.end_str);
        }
        this.td_token.add("<td>");
        this.td_token.add("<td");
        this.td_token.add("<td></td>");
    }

    public NDList processInput(TranslatorContext ctx, Image input) {
        NDArray img = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
        this.height = input.getHeight();
        this.width = input.getWidth();
        img = this.ResizeTableImage(img, this.height, this.width, 488);
        img = this.PaddingTableImage(ctx, img, 488);
        img = img.transpose(new int[]{2, 0, 1}).div((Number)255).flip(new int[]{0});
        img = NDImageUtils.normalize((NDArray)img, (float[])new float[]{0.485f, 0.456f, 0.406f}, (float[])new float[]{0.229f, 0.224f, 0.225f});
        img = img.expandDims(0);
        return new NDList(new NDArray[]{img});
    }

    public TableStructureResult processOutput(TranslatorContext ctx, NDList list) {
        NDArray bbox_preds = (NDArray)list.get(0);
        NDArray structure_probs = (NDArray)list.get(1);
        NDArray structure_idx = structure_probs.argMax(2);
        structure_probs = structure_probs.max(new int[]{2});
        ArrayList structure_batch_list = new ArrayList();
        ArrayList bbox_batch_list = new ArrayList();
        ArrayList result_score_list = new ArrayList();
        int beg_idx = this.dict.indexOf(this.beg_str);
        int end_idx = this.dict.indexOf(this.end_str);
        long batch_size = structure_idx.size(0);
        int batch_idx = 0;
        while ((long)batch_idx < batch_size) {
            ArrayList<String> structure_list = new ArrayList<String>();
            ArrayList<NDArray> bbox_list = new ArrayList<NDArray>();
            ArrayList<NDArray> score_list = new ArrayList<NDArray>();
            long len = structure_idx.get(new long[]{batch_idx}).size();
            int idx = 0;
            while ((long)idx < len) {
                int char_idx = (int)structure_idx.get(new long[]{batch_idx}).get(new long[]{idx}).toLongArray()[0];
                if (idx > 0 && char_idx == end_idx) break;
                String text = this.dict.get(char_idx);
                if (this.td_token.indexOf(text) > -1) {
                    NDArray bbox = bbox_preds.get(new long[]{batch_idx, idx});
                    bbox_list.add(bbox);
                }
                structure_list.add(text);
                score_list.add(structure_probs.get(new long[]{batch_idx, idx}));
                ++idx;
            }
            structure_batch_list.add(structure_list);
            bbox_batch_list.add(bbox_list);
            result_score_list.add(score_list);
            ++batch_idx;
        }
        List structure_str_list = (List)structure_batch_list.get(0);
        List bbox_list = (List)bbox_batch_list.get(0);
        List score_list = (List)result_score_list.get(0);
        structure_str_list.add(0, "<html>");
        structure_str_list.add(1, "<body>");
        structure_str_list.add(2, "<table>");
        structure_str_list.add("</table>");
        structure_str_list.add("</body>");
        structure_str_list.add("</html>");
        ArrayList<OcrItem> ocrItemList = new ArrayList<OcrItem>();
        for (int i = 0; i < bbox_list.size(); ++i) {
            NDArray box = (NDArray)bbox_list.get(i);
            float[] arr = new float[]{box.get(new NDIndex("0::2", new Object[0])).min().toFloatArray()[0], box.get(new NDIndex("1::2", new Object[0])).min().toFloatArray()[0], box.get(new NDIndex("0::2", new Object[0])).max().toFloatArray()[0], box.get(new NDIndex("1::2", new Object[0])).max().toFloatArray()[0]};
            Point topLeft = new Point((double)(arr[0] * this.xScale * (float)this.width), (double)(arr[1] * this.yScale * (float)this.height));
            Point topRight = new Point((double)(arr[2] * this.xScale * (float)this.width), (double)(arr[1] * this.yScale * (float)this.height));
            Point bottomRight = new Point((double)(arr[2] * this.xScale * (float)this.width), (double)(arr[3] * this.yScale * (float)this.height));
            Point bottomLeft = new Point((double)(arr[0] * this.xScale * (float)this.width), (double)(arr[3] * this.yScale * (float)this.height));
            OcrBox ocrBox = new OcrBox(topLeft, topRight, bottomRight, bottomLeft);
            float score = ((NDArray)score_list.get(i)).toFloatArray()[0];
            OcrItem item = new OcrItem();
            item.setOcrBox(ocrBox);
            item.setScore(score);
            ocrItemList.add(item);
        }
        return new TableStructureResult(ocrItemList, structure_str_list);
    }

    public Batchifier getBatchifier() {
        return null;
    }

    private NDArray ResizeTableImage(NDArray img, int height, int width, int maxLen) {
        int localMax = Math.max(height, width);
        float ratio = (float)maxLen * 1.0f / (float)localMax;
        int resize_h = (int)((float)height * ratio);
        int resize_w = (int)((float)width * ratio);
        this.scale = ratio;
        if (width > height) {
            this.xScale = 1.0f;
            this.yScale = (float)width / (float)height;
        } else {
            this.xScale = (float)height / (float)width;
            this.yScale = 1.0f;
        }
        img = NDImageUtils.resize((NDArray)img, (int)resize_w, (int)resize_h);
        return img;
    }

    private NDArray PaddingTableImage(TranslatorContext ctx, NDArray img, int maxLen) {
        NDArray paddingImg = ctx.getNDManager().zeros(new Shape(new long[]{maxLen, maxLen, 3L}), DataType.UINT8);
        paddingImg.set(new NDIndex("0:" + img.getShape().get(0) + ",0:" + img.getShape().get(1) + ",:", new Object[0]), img);
        return paddingImg;
    }
}

