/*
 * Decompiled with CFR 0.152.
 */
package org.biojava.nbio.structure.align.multiple.mc;

import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.Callable;
import org.biojava.nbio.structure.Atom;
import org.biojava.nbio.structure.StructureException;
import org.biojava.nbio.structure.align.multiple.Block;
import org.biojava.nbio.structure.align.multiple.BlockSet;
import org.biojava.nbio.structure.align.multiple.MultipleAlignment;
import org.biojava.nbio.structure.align.multiple.MultipleAlignmentEnsemble;
import org.biojava.nbio.structure.align.multiple.mc.MultipleMcParameters;
import org.biojava.nbio.structure.align.multiple.util.CoreSuperimposer;
import org.biojava.nbio.structure.align.multiple.util.MultipleAlignmentScorer;
import org.biojava.nbio.structure.align.multiple.util.MultipleAlignmentTools;
import org.biojava.nbio.structure.align.multiple.util.MultipleSuperimposer;
import org.biojava.nbio.structure.jama.Matrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Monte Carlo optimizer that refines a seed {@link MultipleAlignment}.
 * <p>
 * Each iteration applies one random move to the working alignment: shift a
 * row segment (40%), expand a Block (30%), shrink a Block (15%) or insert a
 * gap (15%). It then re-superimposes the structures and re-computes the
 * MC-score. Moves that do not lower the score are kept. Moves that lower it
 * are kept with a probability that shrinks as the iterations progress (see
 * {@link #probabilityFunction}); otherwise the previous state is restored.
 * The search stops after {@code convergenceSteps} consecutive rejected moves,
 * or after {@code 100 * convergenceSteps} iterations.
 * <p>
 * Implements {@link Callable} so the optimization can be submitted to an
 * executor. The working alignment comes from a clone of the seed's ensemble.
 */
public class MultipleMcOptimizer
implements Callable<MultipleAlignment> {
    private static final Logger logger = LoggerFactory.getLogger(MultipleMcOptimizer.class);
    /** Source of all random decisions; seeded from the parameters. */
    private Random rnd;
    /** Superimposes the structures after every move. */
    private MultipleSuperimposer imposer;
    /** Minimum number of structures that must have a residue in a column for it to be kept. */
    private int Rmin;
    /**
     * Minimum Block core length: seed Blocks shorter than this are dropped, and
     * moves that may shorten a core are refused once it is at or below this value.
     */
    private int Lmin;
    /** Number of consecutive rejected moves after which the search stops. */
    private int convergenceSteps;
    /** Scale constant of the acceptance probability (20 times the number of structures). */
    private double C;
    /** Scoring parameters forwarded to MultipleAlignmentScorer.getMCScore. */
    private double Gopen;
    private double Gextend;
    private double dCutoff;
    /** Working alignment being optimized. */
    private MultipleAlignment msa;
    /** Per structure: residue indices that are not aligned in any Block. */
    private List<SortedSet<Integer>> freePool;
    private List<Atom[]> atomArrays;
    /** Number of structures in the alignment. */
    private int size;
    /** Number of Blocks in the working alignment. */
    private int blockNr;
    /** MC-score of the current working alignment. */
    private double mcScore;
    private static final boolean history = false;
    private static final String pathToHistory = "McOptHistory.csv";
    // Buffers read by saveHistory(). Nothing in this class appends to them (history is
    // disabled); they are initialised only so that saveHistory() cannot throw an NPE.
    private final List<Integer> lengthHistory = new ArrayList<>();
    private final List<Double> rmsdHistory = new ArrayList<>();
    private final List<Double> scoreHistory = new ArrayList<>();

    /**
     * Creates an optimizer for a seed alignment.
     * <p>
     * Works on a clone of the seed's ensemble. Blocks whose core length is
     * below the minimum Block length are removed, together with any BlockSet
     * that becomes empty.
     *
     * @param seedAln   the alignment to optimize
     * @param params    optimization parameters (random seed, gap penalties, cutoffs, limits)
     * @param reference index of the reference structure used by the CoreSuperimposer
     * @throws IllegalArgumentException if no Block survives the length filter
     */
    public MultipleMcOptimizer(MultipleAlignment seedAln, MultipleMcParameters params, int reference) {
        // NOTE(review): this optimizes alignment 0 of the cloned ensemble, which equals the
        // seed only if the seed is the ensemble's first alignment - confirm with callers.
        MultipleAlignmentEnsemble e = seedAln.getEnsemble().clone();
        this.msa = e.getMultipleAlignment(0);
        this.atomArrays = this.msa.getAtomArrays();
        this.size = seedAln.size();
        this.rnd = new Random(params.getRandomSeed());
        this.Gopen = params.getGapOpen();
        this.Gextend = params.getGapExtension();
        this.dCutoff = params.getDistanceCutoff();
        this.imposer = new CoreSuperimposer(reference);
        if (params.getConvergenceSteps() == 0) {
            // Default: length of the shortest structure times the number of structures.
            List<Integer> lens = new ArrayList<>(this.size);
            for (int i = 0; i < this.size; ++i) {
                lens.add(this.atomArrays.get(i).length);
            }
            this.convergenceSteps = Collections.min(lens) * this.size;
        } else {
            this.convergenceSteps = params.getConvergenceSteps();
        }
        this.Rmin = params.getMinAlignedStructures() == 0
                ? Math.max(this.size / 3, 2)
                : Math.min(Math.max(params.getMinAlignedStructures(), 2), this.size);
        this.C = 20 * this.size;
        this.Lmin = params.getMinBlockLen();

        List<Block> toDelete = new ArrayList<>();
        for (Block b : this.msa.getBlocks()) {
            if (b.getCoreLength() < this.Lmin) {
                toDelete.add(b);
                logger.warn("Deleting a Block: coreLength < Lmin.");
            }
        }
        List<BlockSet> emptyBs = new ArrayList<>();
        for (Block b : toDelete) {
            for (BlockSet bs : this.msa.getBlockSets()) {
                bs.getBlocks().remove(b);
                if (bs.getBlocks().isEmpty()) {
                    emptyBs.add(bs);
                }
            }
        }
        for (BlockSet bs : emptyBs) {
            this.msa.getBlockSets().remove(bs);
        }
        this.blockNr = this.msa.getBlocks().size();
        if (this.blockNr < 1) {
            throw new IllegalArgumentException("Optimization: empty seed alignment, no Blocks found.");
        }
    }

    /** Runs {@link #optimize()}; allows submission to an executor. */
    @Override
    public MultipleAlignment call() throws Exception {
        return this.optimize();
    }

    /**
     * Prepares the search. It builds the per-structure free pools, drops
     * under-populated columns, superimposes the structures and computes the
     * starting MC-score.
     */
    private void initialize() throws StructureException {
        this.freePool = new ArrayList<>(this.size);
        for (int i = 0; i < this.size; ++i) {
            // Start with every residue index of structure i, then drop the aligned ones.
            SortedSet<Integer> free = new TreeSet<>();
            for (int k = 0; k < this.atomArrays.get(i).length; ++k) {
                free.add(k);
            }
            for (BlockSet bs : this.msa.getBlockSets()) {
                for (Block b : bs.getBlocks()) {
                    List<Integer> row = b.getAlignRes().get(i);
                    for (int l = 0; l < b.length(); ++l) {
                        Integer residue = row.get(l);
                        if (residue != null) {
                            free.remove(residue);
                        }
                    }
                }
            }
            this.freePool.add(free);
        }
        this.checkGaps();
        // msa.clear() is called before every superposition (presumably drops cached results).
        this.msa.clear();
        this.imposer.superimpose(this.msa);
        this.mcScore = MultipleAlignmentScorer.getMCScore(this.msa, this.Gopen, this.Gextend, this.dCutoff);
    }

    /**
     * Runs the Monte Carlo search and returns the optimized alignment.
     * <p>
     * The result is superimposed and scored with
     * {@code MultipleAlignmentScorer.calculateScores}. Its MC-score is stored
     * under the key {@code "MC-score"}.
     *
     * @return the optimized alignment
     * @throws StructureException if superposition or scoring fails
     */
    public MultipleAlignment optimize() throws StructureException {
        this.initialize();
        int conv = 0;
        int maxIter = this.convergenceSteps * 100;
        for (int i = 1; i < maxIter && conv < this.convergenceSteps; ++i) {
            // Snapshot the state so that a rejected move can be undone.
            MultipleAlignment lastMSA = this.msa.clone();
            List<SortedSet<Integer>> lastFreePool = new ArrayList<>(this.size);
            for (int k = 0; k < this.size; ++k) {
                lastFreePool.add(new TreeSet<>(this.freePool.get(k)));
            }
            double lastScore = this.mcScore;

            // Keep drawing moves until one actually modifies the alignment.
            boolean moved = false;
            while (!moved) {
                double move = this.rnd.nextDouble();
                if (move < 0.4) {
                    moved = this.shiftRow();
                    logger.debug("did shift");
                } else if (move < 0.7) {
                    moved = this.expandBlock();
                    logger.debug("did expand");
                } else if (move < 0.85) {
                    moved = this.shrinkBlock();
                    logger.debug("did shrink");
                } else {
                    moved = this.insertGap();
                    logger.debug("did insert gap");
                }
            }

            this.msa.clear();
            this.imposer.superimpose(this.msa);
            this.mcScore = MultipleAlignmentScorer.getMCScore(this.msa, this.Gopen, this.Gextend, this.dCutoff);
            double AS = this.mcScore - lastScore;
            double prob = 1.0;
            if (AS < 0.0) {
                // Score decreased: accept only with probability prob, otherwise roll back.
                prob = this.probabilityFunction(AS, i, maxIter);
                if (this.rnd.nextDouble() > prob) {
                    this.msa = lastMSA;
                    this.freePool = lastFreePool;
                    this.mcScore = lastScore;
                    ++conv;
                } else {
                    conv = 0;
                }
            } else {
                conv = 0;
            }
            logger.debug("Step: {}: --prob: {}, --score change: {}, --conv: {}", i, prob, AS, conv);
        }
        this.imposer.superimpose(this.msa);
        MultipleAlignmentScorer.calculateScores(this.msa);
        this.msa.putScore("MC-score", this.mcScore);
        return this.msa;
    }

    /**
     * Removes every column in which fewer than {@code Rmin} structures have a
     * residue. The removed residues are returned to the free pools.
     *
     * @return true if at least one column was removed
     */
    private boolean checkGaps() {
        boolean shrinkedAny = false;
        List<List<Integer>> shrinkColumns = new ArrayList<>();
        for (Block b : this.msa.getBlocks()) {
            List<Integer> shrinkCol = new ArrayList<>();
            for (int res = 0; res < b.length(); ++res) {
                int gapCount = 0;
                for (int su = 0; su < this.size; ++su) {
                    if (b.getAlignRes().get(su).get(res) == null) {
                        ++gapCount;
                    }
                }
                if (this.size - gapCount < this.Rmin) {
                    shrinkCol.add(res);
                }
            }
            shrinkColumns.add(shrinkCol);
        }
        for (int b = 0; b < this.blockNr; ++b) {
            Block bk = this.msa.getBlock(b);
            List<Integer> cols = shrinkColumns.get(b);
            // Remove from the last column backwards so earlier indices stay valid.
            for (int c = cols.size() - 1; c >= 0; --c) {
                // Must be a primitive int: List.remove(Integer) would remove by value, not by index.
                int column = cols.get(c);
                for (int str = 0; str < this.size; ++str) {
                    Integer residue = bk.getAlignRes().get(str).remove(column);
                    if (residue != null) {
                        this.freePool.get(str).add(residue);
                    }
                }
                shrinkedAny = true;
            }
        }
        return shrinkedAny;
    }

    /**
     * Replaces one aligned residue with a gap.
     * <p>
     * All positions are scanned in order. A position with a larger average
     * residue distance than the current best replaces it only if a coin flip
     * succeeds. Entries equal to -1 are skipped (presumably gap positions).
     *
     * @return false if the chosen Block's core is already at or below
     *         {@code Lmin}, or the chosen position is already a gap
     */
    private boolean insertGap() {
        Matrix residueDistances = MultipleAlignmentTools.getAverageResidueDistances(this.msa);
        double maxDist = Double.MIN_VALUE;
        int structure = 0;
        int block = 0;
        int position = 0;
        int column = 0;
        for (int b = 0; b < this.blockNr; ++b) {
            for (int col = 0; col < this.msa.getBlock(b).length(); ++col) {
                for (int str = 0; str < this.size; ++str) {
                    double dist = residueDistances.get(str, column);
                    if (dist != -1.0 && dist > maxDist && this.rnd.nextDouble() > 0.5) {
                        structure = str;
                        block = b;
                        position = col;
                        maxDist = dist;
                    }
                }
                ++column;
            }
        }
        Block bk = this.msa.getBlock(block);
        if (bk.getCoreLength() <= this.Lmin) {
            return false;
        }
        Integer residueL = bk.getAlignRes().get(structure).get(position);
        if (residueL == null) {
            return false;
        }
        this.freePool.get(structure).add(residueL);
        bk.getAlignRes().get(structure).set(position, null);
        this.checkGaps();
        return true;
    }

    /**
     * Modifies one row of a random Block around a random pivot column.
     * <p>
     * If the pivot is a gap, it is filled with a free residue lying between
     * its aligned neighbours. Otherwise the run of consecutive residues
     * containing the pivot is shifted one column to the right or to the left.
     *
     * @return true if the alignment was changed
     */
    private boolean shiftRow() {
        int str = this.rnd.nextInt(this.size);
        int rl = this.rnd.nextInt(2);
        int bk = this.rnd.nextInt(this.blockNr);
        int res = this.rnd.nextInt(this.msa.getBlock(bk).length());
        Block block = this.msa.getBlock(bk);
        if (block.getCoreLength() <= this.Lmin) {
            return false;
        }
        List<Integer> row = block.getAlignRes().get(str);
        SortedSet<Integer> free = this.freePool.get(str);
        if (row.get(res) == null) {
            return this.fillGap(row, free, res, block.length());
        }
        int left = segmentStart(row, res);
        int right = segmentEnd(row, res, block.length());
        if (rl == 0) {
            shiftSegmentRight(row, free, left, right);
        } else {
            shiftSegmentLeft(row, free, left, right);
        }
        this.checkGaps();
        return true;
    }

    /**
     * Fills the gap at column {@code res} of {@code row} with a free residue
     * that keeps the row ordered. The residue is adjacent to the only aligned
     * neighbour, or drawn at random between both neighbours.
     *
     * @return true if a residue was placed, false if no suitable free residue exists
     */
    private boolean fillGap(List<Integer> row, SortedSet<Integer> free, int res, int length) {
        int rightRes = res;
        while (row.get(rightRes) == null && rightRes < length - 1) {
            ++rightRes;
        }
        int leftRes = res;
        while (row.get(leftRes) == null && leftRes > 0) {
            --leftRes;
        }
        Integer leftVal = row.get(leftRes);
        Integer rightVal = row.get(rightRes);
        int candidate;
        if (leftVal == null && rightVal == null) {
            return false;
        } else if (leftVal == null) {
            candidate = rightVal - 1;
        } else if (rightVal == null) {
            candidate = leftVal + 1;
        } else {
            // Number of residue indices strictly between the two neighbours.
            int span = rightVal - leftVal - 1;
            if (span <= 0) {
                return false;
            }
            candidate = this.rnd.nextInt(span) + leftVal + 1;
        }
        if (!free.contains(candidate)) {
            return false;
        }
        row.set(res, candidate);
        free.remove(candidate);
        return true;
    }

    /** Returns the first column of the run of consecutive residues in {@code row} containing {@code res}. */
    private static int segmentStart(List<Integer> row, int res) {
        int left = res - 1;
        int prev = res;
        while (left >= 0 && row.get(left) != null && row.get(prev) <= row.get(left) + 1) {
            prev = left--;
        }
        return left + 1;
    }

    /** Returns the last column of the run of consecutive residues in {@code row} containing {@code res}. */
    private static int segmentEnd(List<Integer> row, int res, int length) {
        int right = res + 1;
        int prev = res;
        while (right < length && row.get(right) != null && row.get(prev) + 1 >= row.get(right)) {
            prev = right++;
        }
        return right - 1;
    }

    /**
     * Moves the segment {@code [left, right]} one column to the right. The last
     * residue is released to the pool. The freed first column receives the
     * preceding residue if it is free, otherwise a gap.
     */
    private static void shiftSegmentRight(List<Integer> row, SortedSet<Integer> free, int left, int right) {
        // Read the first residue before removing: left may equal right.
        Integer first = row.get(left);
        Integer last = row.remove(right);
        if (last != null) {
            free.add(last);
        }
        if (first != null && free.contains(first - 1)) {
            row.add(left, first - 1);
            free.remove(first - 1);
        } else {
            row.add(left, null);
        }
    }

    /**
     * Moves the segment {@code [left, right]} one column to the left. The
     * first residue is released to the pool. The freed last column receives
     * the following residue if it is free, otherwise a gap.
     */
    private static void shiftSegmentLeft(List<Integer> row, SortedSet<Integer> free, int left, int right) {
        Integer first = row.get(left);
        Integer last = row.get(right);
        if (last != null && free.contains(last + 1)) {
            row.add(right + 1, last + 1);
            free.remove(last + 1);
        } else {
            row.add(right + 1, null);
        }
        row.remove(left);
        if (first != null) {
            free.add(first);
        }
    }

    /**
     * Adds a column to a random Block, to the right or to the left of the
     * contiguous region around a random pivot column. Each structure
     * contributes the next (or previous) residue if it is free, otherwise a gap.
     *
     * @return true if at least {@code Rmin} structures contributed a residue;
     *         otherwise the under-populated column is removed again and false
     *         is returned
     */
    private boolean expandBlock() {
        int rl = this.rnd.nextInt(2);
        int bk = this.rnd.nextInt(this.blockNr);
        int res = this.rnd.nextInt(this.msa.getBlock(bk).length());
        Block block = this.msa.getBlock(bk);
        int gaps;
        if (rl == 0) {
            int rightBound = this.findExpansionColumnRight(block, res);
            gaps = this.insertColumn(block, rightBound, rightBound + 1, 1);
        } else {
            int leftBoundary = this.findExpansionColumnLeft(block, res);
            gaps = this.insertColumn(block, leftBoundary, leftBoundary, -1);
        }
        if (this.size - gaps >= this.Rmin) {
            return true;
        }
        // Too few residues were added: checkGaps() removes the new column again.
        this.checkGaps();
        return false;
    }

    /**
     * Scans right from {@code res} for the first column where at least
     * {@code Rmin} structures jump in sequence. A jump means the residue is
     * more than one past the previous aligned residue in that row. Returns the
     * column just before the jump, or the last column if there is none.
     */
    private int findExpansionColumnRight(Block block, int res) {
        int[] previousPos = new int[this.size];
        for (int str = 0; str < this.size; ++str) {
            previousPos[str] = -1;
        }
        for (int col = res; col < block.length(); ++col) {
            int noncontinuous = 0;
            for (int str = 0; str < this.size; ++str) {
                Integer residue = block.getAlignRes().get(str).get(col);
                if (residue == null) {
                    continue;
                }
                if (previousPos[str] != -1 && residue > previousPos[str] + 1) {
                    ++noncontinuous;
                }
                // Compare each column with its predecessor, not with the pivot column.
                previousPos[str] = residue;
            }
            if (col > res && noncontinuous >= this.Rmin) {
                return col - 1;
            }
        }
        return block.length() - 1;
    }

    /**
     * Mirror of {@link #findExpansionColumnRight}. Scans left from
     * {@code res} for the first column followed by a jump in at least
     * {@code Rmin} structures. Returns the column just after it, or 0 if
     * there is none.
     */
    private int findExpansionColumnLeft(Block block, int res) {
        int[] nextPos = new int[this.size];
        for (int str = 0; str < this.size; ++str) {
            nextPos[str] = -1;
        }
        for (int col = res; col >= 0; --col) {
            int noncontinuous = 0;
            for (int str = 0; str < this.size; ++str) {
                Integer residue = block.getAlignRes().get(str).get(col);
                if (residue == null) {
                    continue;
                }
                if (nextPos[str] != -1 && residue < nextPos[str] - 1) {
                    ++noncontinuous;
                }
                nextPos[str] = residue;
            }
            if (col < res && noncontinuous >= this.Rmin) {
                return col + 1;
            }
        }
        return 0;
    }

    /**
     * Inserts a new column at index {@code insertAt} in every row of the
     * Block. Each row receives {@code anchor + offset}, where {@code anchor}
     * is its residue at {@code anchorCol}, if that residue is free; otherwise
     * it receives a gap.
     *
     * @return the number of gaps inserted
     */
    private int insertColumn(Block block, int anchorCol, int insertAt, int offset) {
        int gaps = 0;
        for (int str = 0; str < this.size; ++str) {
            List<Integer> row = block.getAlignRes().get(str);
            SortedSet<Integer> free = this.freePool.get(str);
            Integer anchor = row.get(anchorCol);
            if (anchor != null && free.contains(anchor + offset)) {
                Integer added = anchor + offset;
                row.add(insertAt, added);
                free.remove(added);
            } else {
                row.add(insertAt, null);
                ++gaps;
            }
        }
        return gaps;
    }

    /**
     * Deletes one whole column.
     * <p>
     * Candidate columns are compared by their average residue distance,
     * skipping -1 entries. A larger value replaces the current choice only if
     * a coin flip succeeds. Columns without any distance average to NaN and
     * are never chosen. The removed residues are returned to the free pools.
     *
     * @return false if the chosen Block's core is already at or below {@code Lmin}
     */
    private boolean shrinkBlock() {
        Matrix residueDistances = MultipleAlignmentTools.getAverageResidueDistances(this.msa);
        double maxDist = Double.MIN_VALUE;
        int position = 0;
        int block = 0;
        int column = 0;
        for (int b = 0; b < this.msa.getBlocks().size(); ++b) {
            for (int col = 0; col < this.msa.getBlock(b).length(); ++col) {
                double sum = 0.0;
                int normalize = 0;
                for (int s = 0; s < this.size; ++s) {
                    double dist = residueDistances.get(s, column);
                    if (dist != -1.0) {
                        sum += dist;
                        ++normalize;
                    }
                }
                double average = sum / (double) normalize;
                if (average > maxDist && this.rnd.nextDouble() > 0.5) {
                    maxDist = average;
                    position = col;
                    block = b;
                }
                ++column;
            }
        }
        Block currentBlock = this.msa.getBlock(block);
        if (currentBlock.getCoreLength() <= this.Lmin) {
            return false;
        }
        for (int str = 0; str < this.size; ++str) {
            Integer residue = currentBlock.getAlignRes().get(str).remove(position);
            if (residue != null) {
                this.freePool.get(str).add(residue);
            }
        }
        return true;
    }

    /**
     * Probability of accepting a score-decreasing move.
     * <p>
     * Computes {@code (C + AS) / (m * C) * (1 - m / maxIter)}, clamped to
     * {@code [0, 1]}. The value shrinks as the iteration {@code m} grows and
     * as the score loss {@code AS} (negative) becomes larger.
     *
     * @param AS      score change of the move (negative)
     * @param m       current iteration (starts at 1)
     * @param maxIter maximum number of iterations
     * @return the acceptance probability
     */
    private double probabilityFunction(double AS, int m, int maxIter) {
        double prob = (this.C + AS) / ((double) m * this.C);
        double norm = 1.0 - (double) m * 1.0 / (double) maxIter;
        return Math.min(Math.max(prob * norm, 0.0), 1.0);
    }

    /**
     * Writes the recorded history as CSV (step, length, RMSD, score). The
     * writer is closed even if writing fails.
     *
     * @param filePath destination file
     * @throws IOException if the file cannot be written
     */
    private void saveHistory(String filePath) throws IOException {
        try (FileWriter writer = new FileWriter(filePath)) {
            writer.append("Step,Length,RMSD,Score\n");
            for (int i = 0; i < this.lengthHistory.size(); ++i) {
                writer.append(String.valueOf(i * 100));
                writer.append("," + this.lengthHistory.get(i));
                writer.append("," + this.rmsdHistory.get(i));
                writer.append("," + this.scoreHistory.get(i) + "\n");
            }
        }
    }
}

