/*
 * Decompiled with CFR 0.152.
 */
package org.biojava.nbio.structure.align.multiple;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.vecmath.Matrix4d;
import org.biojava.nbio.structure.Atom;
import org.biojava.nbio.structure.Calc;
import org.biojava.nbio.structure.SVDSuperimposer;
import org.biojava.nbio.structure.StructureException;
import org.biojava.nbio.structure.align.multiple.Block;
import org.biojava.nbio.structure.align.multiple.BlockSet;
import org.biojava.nbio.structure.align.multiple.MultipleAlignment;
import org.biojava.nbio.structure.jama.Matrix;

/**
 * Static helpers that compute quality scores for a {@link MultipleAlignment}.
 * <p>
 * Every public method first builds the transformed, column-aligned atom arrays
 * of the alignment (see {@link #transformAtoms(MultipleAlignment)}) and then
 * derives a score from them. Gaps are represented by {@code null} atoms and are
 * skipped by all scores.
 */
public class MultipleAlignmentScorer {
    // Score names. Only RMSD and AVG_TMSCORE are stored by this class itself
    // (see calculateScores); the others are exposed for callers.
    public static final String PROBABILITY = "Probability";
    public static final String CE_SCORE = "CEscore";
    public static final String RMSD = "RMSD";
    public static final String AVG_TMSCORE = "Avg-TMscore";
    public static final String CEMC_SCORE = "CEMCscore";
    public static final String REF_RMSD = "Ref-RMSD";
    public static final String REF_TMSCORE = "Ref-TMscore";

    /** Numerator of the per-residue term of the CEMC score. */
    private static final double CEMC_RESIDUE_SCALE = 20.0;
    /** Amount subtracted from the CEMC score for every gap that is opened. */
    private static final double CEMC_GAP_OPEN_PENALTY = 10.0;
    /** Amount subtracted from the CEMC score for every position extending a gap. */
    private static final double CEMC_GAP_EXTEND_PENALTY = 5.0;
    /** Sentinel stored in the CEMC distance matrix for "no distance recorded". */
    private static final double CEMC_UNSET = -1.0;

    /**
     * Computes the RMSD and the average TM-score of the alignment and stores
     * them in the alignment under the keys {@link #RMSD} and
     * {@link #AVG_TMSCORE}. The atoms are transformed only once for both scores.
     *
     * @param alignment the alignment to score; receives the two scores
     * @throws StructureException if the TM-score calculation fails
     */
    public static void calculateScores(MultipleAlignment alignment) throws StructureException {
        List<Atom[]> transformed = transformAtoms(alignment);
        alignment.putScore(RMSD, getRMSD(transformed));
        List<Integer> lengths = getStructureLengths(alignment);
        alignment.putScore(AVG_TMSCORE, getAvgTMScore(transformed, lengths));
    }

    /**
     * Builds, for every structure of the alignment, an array with one entry per
     * alignment column holding a transformed clone of the aligned atom, or
     * {@code null} where the structure has a gap.
     * <p>
     * The columns of all blocks of all block sets are concatenated in iteration
     * order. Each block set is transformed with its own matrix for the structure
     * when it has one, otherwise with the alignment-level matrix.
     *
     * @param alignment the alignment whose atoms are extracted
     * @return one atom array of length {@code alignment.length()} per structure
     * @throws NullPointerException if the alignment has no ensemble
     * @throws IllegalStateException if a block does not contain exactly one row
     *         per structure
     */
    public static List<Atom[]> transformAtoms(MultipleAlignment alignment) {
        if (alignment.getEnsemble() == null) {
            throw new NullPointerException("No ensemble set for this alignment");
        }
        List<Atom[]> atomArrays = alignment.getEnsemble().getAtomArrays();
        List<Atom[]> transformed = new ArrayList<Atom[]>(atomArrays.size());
        for (int i = 0; i < atomArrays.size(); ++i) {
            // Alignment-level transformation, used as fallback for block sets
            // that do not define their own.
            Matrix4d transform = null;
            if (alignment.getTransformations() != null) {
                transform = alignment.getTransformations().get(i);
            }
            Atom[] curr = atomArrays.get(i);
            Atom[] transformedAtoms = new Atom[alignment.length()];
            int transformedAtomsLength = 0;
            for (BlockSet bs : alignment.getBlockSets()) {
                Atom[] blocksetAtoms = new Atom[bs.length()];
                // Running column offset inside the block set. Indexing with the
                // per-block column alone would make every block after the first
                // overwrite the columns of the previous ones.
                int blockPos = 0;
                for (Block blk : bs.getBlocks()) {
                    if (blk.size() != atomArrays.size()) {
                        throw new IllegalStateException(String.format("Mismatched block length. Expected %d structures, found %d.", atomArrays.size(), blk.size()));
                    }
                    for (int j = 0; j < blk.length(); ++j) {
                        Integer alignedPos = blk.getAlignRes().get(i).get(j);
                        if (alignedPos != null) {
                            // Clone so that transforming does not touch the ensemble atoms.
                            blocksetAtoms[blockPos] = (Atom) curr[alignedPos].clone();
                        }
                        // Gaps keep their (null) slot so columns stay aligned.
                        ++blockPos;
                    }
                }
                Matrix4d blockTrans = null;
                if (bs.getTransformations() != null) {
                    blockTrans = bs.getTransformations().get(i);
                }
                if (blockTrans == null) {
                    blockTrans = transform;
                }
                // NOTE(review): blockTrans is still null here when neither the
                // block set nor the alignment defines a transformation; how
                // Calc.transform reacts to that is not visible from this file.
                for (Atom a : blocksetAtoms) {
                    if (a != null) {
                        Calc.transform(a, blockTrans);
                    }
                    transformedAtoms[transformedAtomsLength] = a;
                    ++transformedAtomsLength;
                }
            }
            assert (transformedAtomsLength == alignment.length());
            transformed.add(transformedAtoms);
        }
        return transformed;
    }

    /**
     * Computes the RMSD of the alignment over all pairs of structures.
     *
     * @param alignment the alignment to score
     * @return the RMSD, or {@code NaN} if no pair of atoms could be compared
     */
    public static double getRMSD(MultipleAlignment alignment) {
        return getRMSD(transformAtoms(alignment));
    }

    /**
     * RMSD over the transformed atom arrays. For every non-gap atom, the values
     * returned by {@code Calc.getDistanceFast} to the atoms of the following
     * rows in the same column are averaged; the result is the square root of
     * the mean of those per-atom averages.
     * <p>
     * NOTE(review): assumes {@code Calc.getDistanceFast} returns a squared
     * distance, as the accumulator names suggest - confirm.
     *
     * @param transformed one column-aligned atom array per structure
     * @return the RMSD, or {@code NaN} if nothing could be compared
     */
    private static double getRMSD(List<Atom[]> transformed) {
        double sumSqDist = 0.0;
        int comparisons = 0;
        for (int r1 = 0; r1 < transformed.size(); ++r1) {
            Atom[] refRow = transformed.get(r1);
            for (int c = 0; c < refRow.length; ++c) {
                Atom refAtom = refRow[c];
                if (refAtom == null) {
                    continue;
                }
                double nonNullSqDist = 0.0;
                int nonNullLength = 0;
                // Only later rows: each unordered pair is visited once.
                for (int r2 = r1 + 1; r2 < transformed.size(); ++r2) {
                    Atom atom = transformed.get(r2)[c];
                    if (atom == null) {
                        continue;
                    }
                    nonNullSqDist += Calc.getDistanceFast(refAtom, atom);
                    ++nonNullLength;
                }
                if (nonNullLength > 0) {
                    ++comparisons;
                    sumSqDist += nonNullSqDist / (double) nonNullLength;
                }
            }
        }
        // 0.0 / 0 yields NaN when there was nothing to compare.
        return Math.sqrt(sumSqDist / (double) comparisons);
    }

    /**
     * Computes the RMSD of all structures against one reference structure.
     *
     * @param alignment the alignment to score
     * @param reference index of the reference structure
     * @return the RMSD, or {@code NaN} if no pair of atoms could be compared
     */
    public static double getRefRMSD(MultipleAlignment alignment, int reference) {
        return getRefRMSD(transformAtoms(alignment), reference);
    }

    /**
     * RMSD against a reference row. For every non-gap atom of the reference,
     * the values returned by {@code Calc.getDistanceFast} to the atoms of all
     * other rows in the same column are averaged; the result is the square root
     * of the mean of those per-column averages.
     *
     * @param transformed one column-aligned atom array per structure
     * @param reference index of the reference row
     * @return the RMSD, or {@code NaN} if nothing could be compared
     */
    private static double getRefRMSD(List<Atom[]> transformed, int reference) {
        double sumSqDist = 0.0;
        int totalLength = 0;
        Atom[] refRow = transformed.get(reference);
        for (int c = 0; c < refRow.length; ++c) {
            Atom refAtom = refRow[c];
            if (refAtom == null) {
                continue;
            }
            double nonNullSqDist = 0.0;
            int nonNullLength = 0;
            for (int r = 0; r < transformed.size(); ++r) {
                if (r == reference) {
                    continue;
                }
                Atom atom = transformed.get(r)[c];
                if (atom == null) {
                    continue;
                }
                nonNullSqDist += Calc.getDistanceFast(refAtom, atom);
                ++nonNullLength;
            }
            if (nonNullLength > 0) {
                ++totalLength;
                sumSqDist += nonNullSqDist / (double) nonNullLength;
            }
        }
        // 0.0 / 0 yields NaN when there was nothing to compare.
        return Math.sqrt(sumSqDist / (double) totalLength);
    }

    /**
     * Computes the average TM-score over all pairs of structures.
     *
     * @param alignment the alignment to score
     * @return the mean pairwise TM-score, or {@code NaN} if there are fewer
     *         than two structures
     * @throws StructureException if the TM-score calculation fails
     */
    public static double getAvgTMScore(MultipleAlignment alignment) throws StructureException {
        List<Atom[]> transformed = transformAtoms(alignment);
        return getAvgTMScore(transformed, getStructureLengths(alignment));
    }

    /**
     * Mean TM-score over all unordered pairs of rows.
     *
     * @param transformed one column-aligned atom array per structure
     * @param lengths full atom-array length of every structure, in the same order
     * @return the mean pairwise TM-score, or {@code NaN} without any pair
     * @throws IllegalArgumentException if the two lists differ in size
     * @throws StructureException if the TM-score calculation fails
     */
    private static double getAvgTMScore(List<Atom[]> transformed, List<Integer> lengths) throws StructureException {
        if (transformed.size() != lengths.size()) {
            throw new IllegalArgumentException("Input sizes differ.");
        }
        double sumTM = 0.0;
        int comparisons = 0;
        for (int r1 = 0; r1 < transformed.size(); ++r1) {
            for (int r2 = r1 + 1; r2 < transformed.size(); ++r2) {
                sumTM += getPairTMScore(transformed.get(r1), transformed.get(r2), lengths.get(r1), lengths.get(r2));
                ++comparisons;
            }
        }
        return sumTM / (double) comparisons;
    }

    /**
     * Computes the average TM-score of all structures against one reference
     * structure.
     *
     * @param alignment the alignment to score
     * @param reference index of the reference structure
     * @return the mean TM-score against the reference, or {@code NaN} if the
     *         reference is the only structure
     * @throws StructureException if the TM-score calculation fails
     */
    public static double getRefTMScore(MultipleAlignment alignment, int reference) throws StructureException {
        List<Atom[]> transformed = transformAtoms(alignment);
        return getRefTMScore(transformed, getStructureLengths(alignment), reference);
    }

    /**
     * Mean TM-score of every row against the reference row.
     *
     * @param transformed one column-aligned atom array per structure
     * @param lengths full atom-array length of every structure, in the same order
     * @param reference index of the reference row
     * @return the mean TM-score, or {@code NaN} without any other row
     * @throws IllegalArgumentException if the two lists differ in size
     * @throws StructureException if the TM-score calculation fails
     */
    private static double getRefTMScore(List<Atom[]> transformed, List<Integer> lengths, int reference) throws StructureException {
        if (transformed.size() != lengths.size()) {
            throw new IllegalArgumentException("Input sizes differ");
        }
        double sumTM = 0.0;
        int comparisons = 0;
        // Looked up before the loop so an invalid reference index fails even
        // when there is no other row to compare with.
        Atom[] refRow = transformed.get(reference);
        for (int r = 0; r < transformed.size(); ++r) {
            if (r == reference) {
                continue;
            }
            sumTM += getPairTMScore(refRow, transformed.get(r), lengths.get(reference), lengths.get(r));
            ++comparisons;
        }
        return sumTM / (double) comparisons;
    }

    /**
     * Computes the CEMC score of the alignment. The distance normalization
     * {@code d0} is derived from the length of the shortest atom array of the
     * ensemble.
     *
     * @param alignment the alignment to score
     * @return the CEMC score
     * @throws StructureException declared for callers; not thrown by the code
     *         visible in this class
     */
    public static double getCEMCScore(MultipleAlignment alignment) throws StructureException {
        List<Atom[]> transformed = transformAtoms(alignment);
        int minLen = Integer.MAX_VALUE;
        for (Atom[] atoms : alignment.getEnsemble().getAtomArrays()) {
            minLen = Math.min(minLen, atoms.length);
        }
        double d0 = 1.24 * Math.cbrt((double) minLen - 15.0) - 1.8;
        return getCEMCScore(transformed, d0);
    }

    /**
     * CEMC score over the transformed atom arrays: the sum of a per-residue
     * term {@code 20 / (1 + (d / d0)^2)} minus gap penalties (10 per opened
     * gap, 5 per position extending a gap). {@code d} is the accumulated
     * distance of a residue to the other residues of its column, divided by
     * the number of residues of that column that have a recorded distance.
     *
     * @param transformed one column-aligned atom array per structure; must not
     *        be empty
     * @param d0 distance normalization constant
     * @return the CEMC score
     * @throws StructureException declared for callers; not thrown by the code
     *         visible in this class
     */
    private static double getCEMCScore(List<Atom[]> transformed, double d0) throws StructureException {
        int size = transformed.size();
        int length = transformed.get(0).length;
        // One cell per residue; CEMC_UNSET marks gaps and residues alone in their column.
        Matrix residueDistances = new Matrix(size, length, CEMC_UNSET);
        int gapOpen = 0;
        int gapExtend = 0;

        // Pass 1: count gaps and accumulate pairwise distances per residue.
        for (int r1 = 0; r1 < size; ++r1) {
            Atom[] row = transformed.get(r1);
            boolean gapped = false;
            for (int c = 0; c < row.length; ++c) {
                Atom refAtom = row[c];
                if (refAtom == null) {
                    if (gapped) {
                        ++gapExtend;
                    } else {
                        gapped = true;
                        ++gapOpen;
                    }
                    continue;
                }
                gapped = false;
                for (int r2 = r1 + 1; r2 < size; ++r2) {
                    Atom atom = transformed.get(r2)[c];
                    if (atom == null) {
                        continue;
                    }
                    double distance = Calc.getDistance(refAtom, atom);
                    // The distance counts for both residues of the pair.
                    accumulateDistance(residueDistances, r1, c, distance);
                    accumulateDistance(residueDistances, r2, c, distance);
                }
            }
        }

        // Pass 2: divide every recorded value by the number of recorded
        // residues in its column.
        for (int c = 0; c < length; ++c) {
            int nonNullRes = 0;
            for (int r = 0; r < size; ++r) {
                if (residueDistances.get(r, c) != CEMC_UNSET) {
                    ++nonNullRes;
                }
            }
            for (int r = 0; r < size; ++r) {
                if (residueDistances.get(r, c) != CEMC_UNSET) {
                    residueDistances.set(r, c, residueDistances.get(r, c) / (double) nonNullRes);
                }
            }
        }

        // Pass 3: sum the per-residue scores.
        double scoreMC = 0.0;
        for (int row = 0; row < size; ++row) {
            for (int col = 0; col < length; ++col) {
                if (residueDistances.get(row, col) == CEMC_UNSET) {
                    continue;
                }
                double d1 = residueDistances.get(row, col);
                scoreMC += CEMC_RESIDUE_SCALE / (1.0 + d1 * d1 / (d0 * d0));
            }
        }
        return scoreMC - ((double) gapOpen * CEMC_GAP_OPEN_PENALTY + (double) gapExtend * CEMC_GAP_EXTEND_PENALTY);
    }

    /**
     * Adds a distance to one cell of the CEMC distance matrix.
     * <p>
     * NOTE(review): the first distance written to a cell is stored as
     * {@code 1.0 + distance}, not {@code distance}. This is kept exactly as in
     * the original implementation because it affects the resulting scores;
     * confirm whether the extra 1.0 is intended.
     */
    private static void accumulateDistance(Matrix residueDistances, int row, int col, double distance) {
        if (residueDistances.get(row, col) == CEMC_UNSET) {
            residueDistances.set(row, col, 1.0 + distance);
        } else {
            residueDistances.set(row, col, residueDistances.get(row, col) + distance);
        }
    }

    /** Collects the length of every atom array of the alignment's ensemble, in order. */
    private static List<Integer> getStructureLengths(MultipleAlignment alignment) {
        List<Atom[]> atomArrays = alignment.getEnsemble().getAtomArrays();
        List<Integer> lengths = new ArrayList<Integer>(atomArrays.size());
        for (Atom[] atoms : atomArrays) {
            lengths.add(atoms.length);
        }
        return lengths;
    }

    /**
     * TM-score of two column-aligned rows: keeps only the columns where both
     * rows have an atom and passes the two compacted arrays, together with the
     * full structure lengths, to {@code SVDSuperimposer.getTMScore}.
     */
    private static double getPairTMScore(Atom[] first, Atom[] second, int firstLength, int secondLength) throws StructureException {
        int len = first.length;
        Atom[] ref = new Atom[len];
        Atom[] aln = new Atom[len];
        int nonNullLen = 0;
        for (int c = 0; c < len; ++c) {
            if (first[c] == null || second[c] == null) {
                continue;
            }
            ref[nonNullLen] = first[c];
            aln[nonNullLen] = second[c];
            ++nonNullLen;
        }
        if (nonNullLen < len) {
            // Trim the unused tail so only real atom pairs are scored.
            ref = Arrays.copyOf(ref, nonNullLen);
            aln = Arrays.copyOf(aln, nonNullLen);
        }
        return SVDSuperimposer.getTMScore(ref, aln, firstLength, secondLength);
    }
}

