/*
 * Decompiled with CFR 0.152.
 */
package org.biojava.nbio.structure.align.quaternary;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.vecmath.Matrix4d;
import org.biojava.nbio.structure.Atom;
import org.biojava.nbio.structure.Calc;
import org.biojava.nbio.structure.Structure;
import org.biojava.nbio.structure.StructureException;
import org.biojava.nbio.structure.StructureTools;
import org.biojava.nbio.structure.align.multiple.BlockImpl;
import org.biojava.nbio.structure.align.multiple.BlockSetImpl;
import org.biojava.nbio.structure.align.multiple.MultipleAlignmentEnsembleImpl;
import org.biojava.nbio.structure.align.multiple.MultipleAlignmentImpl;
import org.biojava.nbio.structure.align.multiple.util.MultipleAlignmentScorer;
import org.biojava.nbio.structure.align.multiple.util.ReferenceSuperimposer;
import org.biojava.nbio.structure.align.quaternary.QsAlignParameters;
import org.biojava.nbio.structure.align.quaternary.QsAlignResult;
import org.biojava.nbio.structure.cluster.Subunit;
import org.biojava.nbio.structure.cluster.SubunitCluster;
import org.biojava.nbio.structure.cluster.SubunitClusterer;
import org.biojava.nbio.structure.cluster.SubunitClustererParameters;
import org.biojava.nbio.structure.cluster.SubunitExtractor;
import org.biojava.nbio.structure.contact.Pair;
import org.biojava.nbio.structure.geometry.SuperPositions;
import org.biojava.nbio.structure.geometry.UnitQuaternions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Quaternary structure alignment: matches the subunits of two groups (for
 * example, the subunits extracted from two structures) and superimposes them.
 * <p>
 * The algorithm clusters the subunits of each group and pairs clusters across
 * the two groups. It then takes every matched cluster in turn as the reference
 * and greedily pairs individual subunits whose superposition passes the
 * centroid-distance, orientation and RMSD cutoffs of {@link QsAlignParameters}.
 * The candidate with the most matched subunits is kept. Ties are broken by the
 * lower RMSD.
 */
public class QsAlign {

    private static final Logger logger = LoggerFactory.getLogger(QsAlign.class);

    /**
     * Extracts the subunits of two structures and aligns them.
     *
     * @param s1 first structure
     * @param s2 second structure
     * @param cParams clustering parameters; their minimum sequence length
     *        settings also control which subunits are extracted
     * @param aParams alignment cutoffs and the aligner used to match clusters
     * @return the best subunit matching found and its alignment
     * @throws StructureException propagated from extraction, clustering,
     *         cluster merging or superposition
     */
    public static QsAlignResult align(Structure s1, Structure s2, SubunitClustererParameters cParams,
            QsAlignParameters aParams) throws StructureException {
        return align(extractSubunits(s1, cParams), extractSubunits(s2, cParams), cParams, aParams);
    }

    /**
     * Aligns two groups of subunits.
     *
     * @param s1 subunits of the first group
     * @param s2 subunits of the second group
     * @param cParams clustering parameters, also used when merging clusters
     *        across groups (RMSD and coverage thresholds)
     * @param aParams alignment cutoffs and the aligner used to match clusters
     * @return the best subunit matching found; its subunit map is keyed by
     *         index in {@code s1} and valued by index in {@code s2}
     * @throws StructureException propagated from clustering, cluster merging
     *         or superposition
     */
    public static QsAlignResult align(List<Subunit> s1, List<Subunit> s2, SubunitClustererParameters cParams,
            QsAlignParameters aParams) throws StructureException {
        QsAlignResult result = new QsAlignResult(s1, s2);

        // STEP 1: cluster each group of subunits independently
        List<SubunitCluster> c1 = SubunitClusterer.cluster(s1, cParams);
        List<SubunitCluster> c2 = SubunitClusterer.cluster(s2, cParams);

        // STEP 2: pair clusters between the groups (index in c1 -> index in c2)
        Map<Integer, Integer> clusterMap = matchClusters(c1, c2, cParams, aParams);
        logger.info("Cluster Map: " + clusterMap);
        result.setClusters(c1);

        // STEP 3: use every matched cluster in turn as the superposition reference
        for (int globalKey : clusterMap.keySet()) {
            Map<Integer, Map<Integer, Integer>> clustSubunitMap =
                    matchSubunits(globalKey, clusterMap, c1, c2, aParams);
            logger.info("Cluster Subunit Map: " + clustSubunitMap);

            Map<Integer, Integer> subunitMap = new HashMap<Integer, Integer>();
            MultipleAlignmentImpl msa = buildAlignment(s1, s2, c1, clustSubunitMap, subunitMap);
            keepIfBetter(result, subunitMap, msa);
        }
        return result;
    }

    /** Extracts the subunits of {@code s} using the length thresholds in {@code cParams}. */
    private static List<Subunit> extractSubunits(Structure s, SubunitClustererParameters cParams)
            throws StructureException {
        return SubunitExtractor.extractSubunits(s, cParams.getAbsoluteMinimumSequenceLength(),
                cParams.getMinimumSequenceLengthFraction(), cParams.getMinimumSequenceLength());
    }

    /**
     * Greedily pairs each cluster of {@code c1} with the first still-unpaired
     * cluster of {@code c2} for which {@code mergeStructure} succeeds.
     * <p>
     * NOTE(review): the rest of this class assumes that a successful
     * {@code mergeStructure} appends the {@code c2} cluster's subunits to the
     * {@code c1} cluster. Confirm this against {@code SubunitCluster}.
     *
     * @return map from cluster index in {@code c1} to cluster index in {@code c2}
     */
    private static Map<Integer, Integer> matchClusters(List<SubunitCluster> c1, List<SubunitCluster> c2,
            SubunitClustererParameters cParams, QsAlignParameters aParams) throws StructureException {
        Map<Integer, Integer> clusterMap = new HashMap<Integer, Integer>();
        for (int i = 0; i < c1.size(); ++i) {
            for (int j = 0; j < c2.size(); ++j) {
                if (clusterMap.containsValue(j)) {
                    continue; // c2 cluster j is already paired
                }
                if (c1.get(i).mergeStructure(c2.get(j), cParams.getRmsdThreshold(),
                        cParams.getCoverageThreshold(), aParams.getAligner())) {
                    clusterMap.put(i, j);
                    break; // c1 cluster i is now paired
                }
            }
        }
        return clusterMap;
    }

    /**
     * Matches subunits inside every paired cluster, taking cluster
     * {@code globalKey} as the reference.
     * <p>
     * The reference is seeded with its first group-1 subunit paired to its
     * first group-2 subunit. It is processed first so that its additional
     * matches refine the transform before the other clusters are matched.
     * Each candidate pair is tested under the transform that superimposes all
     * pairs accepted so far.
     *
     * @return map from cluster index in {@code c1} to a map of subunit indices
     *         within that (merged) cluster, group-1 index -> group-2 index
     */
    private static Map<Integer, Map<Integer, Integer>> matchSubunits(int globalKey,
            Map<Integer, Integer> clusterMap, List<SubunitCluster> c1, List<SubunitCluster> c2,
            QsAlignParameters aParams) throws StructureException {
        Map<Integer, Map<Integer, Integer>> clustSubunitMap = new HashMap<Integer, Map<Integer, Integer>>();

        // Seed the reference cluster with a single subunit pair
        SubunitCluster refCluster1 = c1.get(globalKey);
        SubunitCluster refCluster2 = c2.get(clusterMap.get(globalKey));
        Map<Integer, Integer> seed = new HashMap<Integer, Integer>();
        seed.put(0, refCluster1.size() - refCluster2.size());
        clustSubunitMap.put(globalKey, seed);

        // Visit the reference cluster first, then the remaining clusters
        List<Integer> keys = new ArrayList<Integer>(clusterMap.keySet());
        keys.remove(Integer.valueOf(globalKey));
        keys.add(0, globalKey);

        for (int key : keys) {
            // The reference keeps extending its seeded map; other clusters start empty
            Map<Integer, Integer> subunitMap = key == globalKey
                    ? clustSubunitMap.get(key)
                    : new HashMap<Integer, Integer>();

            SubunitCluster cluster = c1.get(key);
            // Indices [0, offset) are treated as group-1 subunits and
            // [offset, size) as group-2 subunits of the merged cluster
            int offset = cluster.size() - c2.get(clusterMap.get(key)).size();

            for (int i = 0; i < offset; ++i) {
                for (int j = offset; j < cluster.size() && !subunitMap.containsKey(i); ++j) {
                    if (subunitMap.containsValue(j)) {
                        continue; // group-2 subunit j already paired
                    }
                    Matrix4d transform = getTransformForClusterSubunitMap(c1, clustSubunitMap);
                    if (isSubunitMatch(cluster, i, j, key, transform, aParams)) {
                        subunitMap.put(i, j);
                    }
                }
            }
            clustSubunitMap.put(key, subunitMap);
        }
        return clustSubunitMap;
    }

    /**
     * Tests whether subunits {@code i} and {@code j} of {@code cluster} match
     * once {@code transform} is applied to subunit {@code j}.
     * <p>
     * The checks run in order and stop at the first failure: centroid distance
     * against {@code getdCutoff()}, orientation angle against
     * {@code getMaxOrientationAngle()}, then RMSD against {@code getMaxRmsd()}.
     *
     * @param key cluster index, used only in log messages
     * @return {@code true} if all three checks pass
     */
    private static boolean isSubunitMatch(SubunitCluster cluster, int i, int j, int key, Matrix4d transform,
            QsAlignParameters aParams) throws StructureException {
        Atom[] atoms1 = cluster.getAlignedAtomsSubunit(i);
        Atom[] atoms2 = cluster.getAlignedAtomsSubunit(j);

        Atom centr1 = Calc.getCentroid(atoms1);
        Atom centr2 = Calc.getCentroid(atoms2);
        Calc.transform(centr2, transform);
        double dCentroid = Calc.getDistance(centr1, centr2);
        if (dCentroid > aParams.getdCutoff()) {
            if (logger.isDebugEnabled()) {
                logger.debug(String.format("Subunit matching %d vs %d of cluster %d could not be matched, "
                        + "because centroid distance is %.2f", i, j, key, dCentroid));
            }
            return false;
        }

        // Transform a copy so the cluster's own atoms are left in place
        Atom[] atoms2c = StructureTools.cloneAtomArray(atoms2);
        Calc.transform(atoms2c, transform);

        double qOrient = UnitQuaternions.orientationAngle(Calc.atomsToPoints(atoms1),
                Calc.atomsToPoints(atoms2c), false);
        // Take the smaller of the angle and |2*pi - angle|
        qOrient = Math.min(Math.abs(Math.PI * 2 - qOrient), qOrient);
        if (qOrient > aParams.getMaxOrientationAngle()) {
            if (logger.isDebugEnabled()) {
                logger.debug(String.format("Subunit matching %d vs %d of cluster %d could not be matched, "
                        + "because orientation metric is %.2f", i, j, key, qOrient));
            }
            return false;
        }

        double rmsd = Calc.rmsd(atoms1, atoms2c);
        if (rmsd > aParams.getMaxRmsd()) {
            if (logger.isDebugEnabled()) {
                logger.debug(String.format("Subunit matching %d vs %d of cluster %d could not be matched, "
                        + "because RMSD is %.2f", i, j, key, rmsd));
            }
            return false;
        }

        logger.info(String.format("Subunit matching %d vs %d of cluster %d with centroid distance %.2f, "
                + "orientation metric %.2f and RMSD %.2f", i, j, key, dCentroid, qOrient, rmsd));
        return true;
    }

    /**
     * Builds a two-row multiple alignment from all matched subunit pairs,
     * superimposes it and computes its scores.
     * <p>
     * For each pair, the representative atoms of the original subunits are
     * concatenated per group. The cluster's aligned residue indices are shifted
     * by the number of atoms already concatenated.
     *
     * @param subunitMap output parameter; receives the pairs as index in
     *        {@code s1} -> index in {@code s2}
     * @return the superimposed and scored alignment
     */
    private static MultipleAlignmentImpl buildAlignment(List<Subunit> s1, List<Subunit> s2,
            List<SubunitCluster> c1, Map<Integer, Map<Integer, Integer>> clustSubunitMap,
            Map<Integer, Integer> subunitMap) throws StructureException {
        List<Integer> alignRes1 = new ArrayList<Integer>();
        List<Integer> alignRes2 = new ArrayList<Integer>();
        List<Atom> atomArray1 = new ArrayList<Atom>();
        List<Atom> atomArray2 = new ArrayList<Atom>();

        for (Map.Entry<Integer, Map<Integer, Integer>> clusterEntry : clustSubunitMap.entrySet()) {
            SubunitCluster cluster = c1.get(clusterEntry.getKey());
            List<List<Integer>> clusterEqrs = cluster.getMultipleAlignment().getBlock(0).getAlignRes();

            for (Map.Entry<Integer, Integer> pair : clusterEntry.getValue().entrySet()) {
                int i = pair.getKey();
                int j = pair.getValue();
                int orig1 = s1.indexOf(cluster.getSubunits().get(i));
                int orig2 = s2.indexOf(cluster.getSubunits().get(j));

                appendShifted(alignRes1, clusterEqrs.get(i), atomArray1.size());
                appendShifted(alignRes2, clusterEqrs.get(j), atomArray2.size());

                atomArray1.addAll(Arrays.asList(s1.get(orig1).getRepresentativeAtoms()));
                atomArray2.addAll(Arrays.asList(s2.get(orig2).getRepresentativeAtoms()));
                subunitMap.put(orig1, orig2);
            }
        }

        MultipleAlignmentImpl msa = new MultipleAlignmentImpl();
        msa.setEnsemble(new MultipleAlignmentEnsembleImpl());
        msa.getEnsemble().setAtomArrays(Arrays.asList(
                atomArray1.toArray(new Atom[atomArray1.size()]),
                atomArray2.toArray(new Atom[atomArray2.size()])));

        BlockSetImpl bs = new BlockSetImpl(msa);
        BlockImpl b = new BlockImpl(bs);
        List<List<Integer>> alignRes = new ArrayList<List<Integer>>(2);
        alignRes.add(alignRes1);
        alignRes.add(alignRes2);
        b.setAlignRes(alignRes);

        new ReferenceSuperimposer().superimpose(msa);
        MultipleAlignmentScorer.calculateScores(msa);
        return msa;
    }

    /**
     * Appends every index of {@code eqrs} to {@code target}, shifted by {@code shift}.
     * Null entries (gaps) are appended as null instead of throwing.
     */
    private static void appendShifted(List<Integer> target, List<Integer> eqrs, int shift) {
        for (Integer eqr : eqrs) {
            target.add(eqr == null ? null : eqr + shift);
        }
    }

    /**
     * Stores the candidate in {@code result} if it matches more subunits.
     * On an equal count, the candidate is stored if no alignment is stored yet
     * or if its RMSD is strictly lower.
     */
    private static void keepIfBetter(QsAlignResult result, Map<Integer, Integer> subunitMap,
            MultipleAlignmentImpl msa) {
        int candidate = subunitMap.size();
        int best = result.getSubunitMap().size();
        if (candidate > best) {
            result.setSubunitMap(subunitMap);
            result.setAlignment(msa);
            logger.info("Better result found: " + result.toString());
        } else if (candidate == best) {
            if (result.getAlignment() == null) {
                result.setSubunitMap(subunitMap);
                result.setAlignment(msa);
            } else if (msa.getScore("RMSD") < result.getRmsd()) {
                result.setSubunitMap(subunitMap);
                result.setAlignment(msa);
                logger.info("Better result found: " + result.toString());
            }
        }
    }

    /**
     * Collects the aligned atoms of every matched subunit pair. The first array
     * holds the group-1 subunits and the second the group-2 subunits, in
     * corresponding order.
     */
    private static Pair<Atom[]> getAlignedAtomsForClusterSubunitMap(List<SubunitCluster> clusters,
            Map<Integer, Map<Integer, Integer>> clusterSubunitMap) {
        List<Atom> atomArray1 = new ArrayList<Atom>();
        List<Atom> atomArray2 = new ArrayList<Atom>();
        for (Map.Entry<Integer, Map<Integer, Integer>> clusterEntry : clusterSubunitMap.entrySet()) {
            SubunitCluster cluster = clusters.get(clusterEntry.getKey());
            for (Map.Entry<Integer, Integer> pair : clusterEntry.getValue().entrySet()) {
                atomArray1.addAll(Arrays.asList(cluster.getAlignedAtomsSubunit(pair.getKey())));
                atomArray2.addAll(Arrays.asList(cluster.getAlignedAtomsSubunit(pair.getValue())));
            }
        }
        return new Pair<Atom[]>(atomArray1.toArray(new Atom[atomArray1.size()]),
                atomArray2.toArray(new Atom[atomArray2.size()]));
    }

    /**
     * Computes the superposition of all currently matched subunit pairs, as
     * returned by {@code SuperPositions.superpose} with group 1 as the first
     * argument and group 2 as the second.
     */
    private static Matrix4d getTransformForClusterSubunitMap(List<SubunitCluster> clusters,
            Map<Integer, Map<Integer, Integer>> clusterSubunitMap) throws StructureException {
        Pair<Atom[]> pair = getAlignedAtomsForClusterSubunitMap(clusters, clusterSubunitMap);
        return SuperPositions.superpose(Calc.atomsToPoints(pair.getFirst()),
                Calc.atomsToPoints(pair.getSecond()));
    }
}

