/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.ocr.model.table;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import cn.smartjavaai.common.entity.DetectionRectangle;
import cn.smartjavaai.common.entity.R;
import cn.smartjavaai.common.utils.FileUtils;
import cn.smartjavaai.common.utils.ImageUtils;
import cn.smartjavaai.common.utils.OpenCVUtils;
import cn.smartjavaai.ocr.config.OcrRecOptions;
import cn.smartjavaai.ocr.entity.OcrBox;
import cn.smartjavaai.ocr.entity.OcrInfo;
import cn.smartjavaai.ocr.entity.OcrItem;
import cn.smartjavaai.ocr.entity.TableStructureResult;
import cn.smartjavaai.ocr.exception.OcrException;
import cn.smartjavaai.ocr.model.common.detect.OcrCommonDetModel;
import cn.smartjavaai.ocr.model.common.direction.OcrDirectionModel;
import cn.smartjavaai.ocr.model.common.recognize.OcrCommonRecModel;
import cn.smartjavaai.ocr.model.table.TableStructureModel;
import cn.smartjavaai.ocr.utils.ConvertHtml2Excel;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import javax.imageio.ImageIO;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.poi.hssf.usermodel.HSSFWorkbook;
import org.opencv.core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TableRecognizer {
    private static final Logger log = LoggerFactory.getLogger(TableRecognizer.class);
    private OcrCommonDetModel textDetector;
    private TableStructureModel tableStructureModel;
    private OcrCommonRecModel textRecModel;
    private OcrDirectionModel directionModel;

    private TableRecognizer(Builder builder) {
        this.tableStructureModel = builder.tableStructureModel;
        this.textRecModel = builder.textRecModel;
        this.directionModel = builder.directionModel;
        this.textDetector = builder.textDetector;
        this.textRecModel.setTextDetModel(this.textDetector);
        this.textRecModel.setDirectionModel(this.directionModel);
    }

    public static Builder builder() {
        return new Builder();
    }

    public TableRecognizer withTextRecModel(OcrCommonRecModel textRecModel) {
        this.textRecModel = textRecModel;
        return this;
    }

    public TableRecognizer withStructureModel(TableStructureModel tableStructureModel) {
        this.tableStructureModel = tableStructureModel;
        return this;
    }

    public R<TableStructureResult> recognize(Image image) {
        R<TableStructureResult> result = this.tableStructureModel.detect(image);
        if (!result.isSuccess()) {
            return R.fail((Integer)result.getCode(), (String)result.getMessage());
        }
        boolean enableDirectionCorrect = this.directionModel != null;
        OcrRecOptions options = new OcrRecOptions(enableDirectionCorrect, false);
        OcrInfo ocrInfo = this.textRecModel.recognize(image, options);
        List<String> tableContentList = this.buildTable((TableStructureResult)result.getData(), ocrInfo);
        String html = this.convertHtml(((TableStructureResult)result.getData()).getTableTagList(), tableContentList);
        ((TableStructureResult)result.getData()).setHtml(html);
        return result;
    }

    public R<TableStructureResult> recognize(BufferedImage image) {
        if (!ImageUtils.isImageValid((BufferedImage)image)) {
            return R.fail((R.Status)R.Status.INVALID_IMAGE);
        }
        Image img = null;
        try {
            img = ImageFactory.getInstance().fromImage((Object)OpenCVUtils.image2Mat((BufferedImage)image));
            R<TableStructureResult> r = this.recognize(img);
            return r;
        }
        catch (Exception e) {
            throw new OcrException(e);
        }
        finally {
            if (Objects.nonNull(img)) {
                ((Mat)img.getWrappedImage()).release();
            }
        }
    }

    public R<TableStructureResult> recognize(String imagePath) {
        if (!FileUtils.isFileExists((String)imagePath)) {
            return R.fail((R.Status)R.Status.FILE_NOT_FOUND);
        }
        Image img = null;
        try {
            img = ImageFactory.getInstance().fromFile(Paths.get(imagePath, new String[0]));
            R<TableStructureResult> r = this.recognize(img);
            return r;
        }
        catch (IOException e) {
            throw new OcrException("\u65e0\u6548\u7684\u56fe\u7247", e);
        }
        finally {
            if (Objects.nonNull(img)) {
                ((Mat)img.getWrappedImage()).release();
            }
        }
    }

    public R<TableStructureResult> recognize(byte[] imageData) {
        if (Objects.isNull(imageData)) {
            return R.fail((R.Status)R.Status.INVALID_IMAGE);
        }
        try {
            BufferedImage image = ImageIO.read(new ByteArrayInputStream(imageData));
            return this.recognize(image);
        }
        catch (IOException e) {
            throw new OcrException("\u9519\u8bef\u7684\u56fe\u50cf", e);
        }
    }

    public void drawTable(TableStructureResult tableStructureResult, BufferedImage image, String savePath) {
        if (Objects.isNull(tableStructureResult) || CollectionUtils.isEmpty(tableStructureResult.getTableTagList())) {
            throw new OcrException("\u8868\u683c\u7ed3\u6784\u4e3a\u7a7a");
        }
        for (int i = 0; i < tableStructureResult.getOcrItemList().size(); ++i) {
            OcrItem item = tableStructureResult.getOcrItemList().get(i);
            DetectionRectangle detectionRectangle = item.getOcrBox().toDetectionRectangle();
            ImageUtils.drawImageRectWithText((BufferedImage)image, (DetectionRectangle)detectionRectangle, (String)(i + ""), (Color)Color.RED);
        }
        ImageUtils.saveImage((BufferedImage)image, (String)savePath);
    }

    public static String removeStyleBlock(String html) {
        String lowerHtml = html.toLowerCase();
        int styleStart = lowerHtml.indexOf("<style");
        if (styleStart == -1) {
            return html;
        }
        int styleEnd = lowerHtml.indexOf("</style>", styleStart);
        if (styleEnd == -1) {
            return html;
        }
        return html.substring(0, styleStart) + html.substring(styleEnd += "</style>".length());
    }

    public void exportExcel(String html, String savePath) {
        try {
            String content = TableRecognizer.removeStyleBlock(html);
            content = content.replace("<html><body>", "");
            content = content.replace("</body></html>", "");
            HSSFWorkbook workbook = ConvertHtml2Excel.table2Excel(content);
            workbook.write(new File(savePath));
        }
        catch (Exception e) {
            throw new OcrException("\u5bfc\u51faexcel\u5931\u8d25\uff0c\u8bf7\u68c0\u67e5\u8868\u7ed3\u6784\u662f\u5426\u8bc6\u522b\u6b63\u786e");
        }
    }

    public List<String> buildTable(TableStructureResult tableStructureResult, OcrInfo ocrInfo) {
        ConcurrentHashMap matched = new ConcurrentHashMap();
        List<OcrItem> ocrItems = ocrInfo.getOcrItemList();
        for (int i = 0; i < ocrItems.size(); ++i) {
            OcrBox ocrBox = ocrItems.get(i).getOcrBox();
            int[] box_1 = new int[]{(int)ocrBox.getTopLeft().getX(), (int)ocrBox.getTopLeft().getY(), (int)ocrBox.getBottomRight().getX(), (int)ocrBox.getBottomRight().getY()};
            ArrayList<Pair<Float, Float>> distances = new ArrayList<Pair<Float, Float>>();
            for (OcrItem cell : tableStructureResult.getOcrItemList()) {
                OcrBox cellBox = cell.getOcrBox();
                int[] box_2 = new int[]{(int)cellBox.getTopLeft().getX(), (int)cellBox.getTopLeft().getY(), (int)cellBox.getBottomRight().getX(), (int)cellBox.getBottomRight().getY()};
                float distance = this.distance(box_1, box_2);
                float iou = 1.0f - this.computeIou(box_1, box_2);
                distances.add((Pair<Float, Float>)Pair.of((Object)Float.valueOf(distance), (Object)Float.valueOf(iou)));
            }
            Pair<Float, Float> nearest = this.sorted(distances);
            int id = 0;
            for (int idx = 0; idx < distances.size(); ++idx) {
                Pair current = (Pair)distances.get(idx);
                if (((Float)current.getLeft()).floatValue() != ((Float)nearest.getLeft()).floatValue() || ((Float)current.getRight()).floatValue() != ((Float)nearest.getRight()).floatValue()) continue;
                id = idx;
                break;
            }
            if (!matched.containsKey(id)) {
                ArrayList<Integer> textIds = new ArrayList<Integer>();
                textIds.add(i);
                matched.put(id, textIds);
                continue;
            }
            ((List)matched.get(id)).add(i);
        }
        ArrayList<String> cell_contents = new ArrayList<String>();
        ArrayList<Double> probs = new ArrayList<Double>();
        for (int i = 0; i < tableStructureResult.getOcrItemList().size(); ++i) {
            List textIds = (List)matched.get(i);
            ArrayList<String> contents = new ArrayList<String>();
            String content = "";
            if (textIds != null) {
                for (Integer id : textIds) {
                    contents.add(ocrItems.get(id).getText());
                }
                content = StringUtils.join(contents, (String)" ");
            }
            cell_contents.add(content);
            probs.add(-1.0);
        }
        return cell_contents;
    }

    private int distance(int[] box_1, int[] box_2) {
        int x1 = box_1[0];
        int y1 = box_1[1];
        int x2 = box_1[2];
        int y2 = box_1[3];
        int x3 = box_2[0];
        int y3 = box_2[1];
        int x4 = box_2[2];
        int y4 = box_2[3];
        int dis = Math.abs(x3 - x1) + Math.abs(y3 - y1) + Math.abs(x4 - x2) + Math.abs(y4 - y2);
        int dis_2 = Math.abs(x3 - x1) + Math.abs(y3 - y1);
        int dis_3 = Math.abs(x4 - x2) + Math.abs(y4 - y2);
        return dis + Math.min(dis_2, dis_3);
    }

    private float computeIou(int[] rec1, int[] rec2) {
        int S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]);
        int S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]);
        int sum_area = S_rec1 + S_rec2;
        int left_line = Math.max(rec1[1], rec2[1]);
        int right_line = Math.min(rec1[3], rec2[3]);
        int top_line = Math.max(rec1[0], rec2[0]);
        int bottom_line = Math.min(rec1[2], rec2[2]);
        if (left_line >= right_line || top_line >= bottom_line) {
            return 0.0f;
        }
        float intersect = (right_line - left_line) * (bottom_line - top_line);
        return intersect / ((float)sum_area - intersect) * 1.0f;
    }

    private Pair<Float, Float> sorted(List<Pair<Float, Float>> distances) {
        Comparator<Pair<Float, Float>> comparator = new Comparator<Pair<Float, Float>>(){

            @Override
            public int compare(Pair<Float, Float> a1, Pair<Float, Float> a2) {
                if (((Float)a1.getRight()).floatValue() > ((Float)a2.getRight()).floatValue()) {
                    return 1;
                }
                if (((Float)a1.getRight()).floatValue() == ((Float)a2.getRight()).floatValue()) {
                    if (((Float)a1.getLeft()).floatValue() > ((Float)a2.getLeft()).floatValue()) {
                        return 1;
                    }
                    return -1;
                }
                return -1;
            }
        };
        ArrayList newDistances = new ArrayList();
        CollectionUtils.addAll(newDistances, (Object[])new Object[distances.size()]);
        Collections.copy(newDistances, distances);
        Collections.sort(newDistances, comparator);
        return (Pair)newDistances.get(0);
    }

    public String convertHtml(List<String> pred_structures, List<String> cell_contents) {
        StringBuffer html = new StringBuffer();
        html.append("<style>\n");
        html.append("table { border-collapse: collapse; }\n");
        html.append("td, th, table { border: 1px solid black; padding: 5px; }\n");
        html.append("</style>\n");
        int td_index = 0;
        for (String tag : pred_structures) {
            if (tag.contains("<td></td>")) {
                String content = cell_contents.get(td_index);
                html.append("<td>");
                html.append(content);
                html.append("</td>");
                ++td_index;
                continue;
            }
            html.append(tag);
        }
        return html.toString();
    }

    public static class Builder {
        private TableStructureModel tableStructureModel;
        private OcrCommonRecModel textRecModel;
        private OcrDirectionModel directionModel;
        private OcrCommonDetModel textDetector;

        public Builder withStructureModel(TableStructureModel model) {
            this.tableStructureModel = model;
            return this;
        }

        public Builder withTextRecModel(OcrCommonRecModel model) {
            this.textRecModel = model;
            return this;
        }

        public Builder withDirectionModel(OcrDirectionModel model) {
            this.directionModel = model;
            return this;
        }

        public Builder withTextDetModel(OcrCommonDetModel model) {
            this.textDetector = model;
            return this;
        }

        public TableRecognizer build() {
            if (this.tableStructureModel == null) {
                throw new IllegalStateException("tableStructureModel \u672a\u8bbe\u7f6e");
            }
            if (this.textDetector == null) {
                throw new IllegalStateException("textDetector \u672a\u8bbe\u7f6e");
            }
            if (this.textRecModel == null) {
                throw new IllegalStateException("textRecModel \u672a\u8bbe\u7f6e");
            }
            return new TableRecognizer(this);
        }
    }
}

