/*
 * Decompiled with CFR 0.152.
 */
package org.biojava.nbio.structure.asa;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.OptionalDouble;
import java.util.TreeMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import javax.vecmath.Point3d;
import javax.vecmath.Tuple3d;
import javax.vecmath.Vector3d;
import org.biojava.nbio.structure.AminoAcid;
import org.biojava.nbio.structure.Atom;
import org.biojava.nbio.structure.Calc;
import org.biojava.nbio.structure.Element;
import org.biojava.nbio.structure.Group;
import org.biojava.nbio.structure.GroupType;
import org.biojava.nbio.structure.NucleotideImpl;
import org.biojava.nbio.structure.ResidueNumber;
import org.biojava.nbio.structure.Structure;
import org.biojava.nbio.structure.StructureTools;
import org.biojava.nbio.structure.asa.GroupAsa;
import org.biojava.nbio.structure.contact.Contact;
import org.biojava.nbio.structure.contact.Grid;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AsaCalculator {
    private static final Logger logger = LoggerFactory.getLogger(AsaCalculator.class);
    public static final int DEFAULT_N_SPHERE_POINTS = 1000;
    public static final double DEFAULT_PROBE_SIZE = 1.4;
    public static final int DEFAULT_NTHREADS = 1;
    private static final boolean DEFAULT_USE_SPATIAL_HASHING = true;
    public static final double TRIGONAL_CARBON_VDW = 1.76;
    public static final double TETRAHEDRAL_CARBON_VDW = 1.87;
    public static final double TRIGONAL_NITROGEN_VDW = 1.65;
    public static final double TETRAHEDRAL_NITROGEN_VDW = 1.5;
    public static final double SULFUR_VDW = 1.85;
    public static final double OXIGEN_VDW = 1.4;
    public static final double NUC_CARBON_VDW = 1.8;
    public static final double NUC_NITROGEN_VDW = 1.6;
    public static final double PHOSPHOROUS_VDW = 1.9;
    private final Point3d[] atomCoords;
    private final Atom[] atoms;
    private final double[] radii;
    private final double probe;
    private final int nThreads;
    private Vector3d[] spherePoints;
    private double cons;
    private IndexAndDistance[][] neighborIndices;
    private boolean useSpatialHashingForNeighbors;

    public AsaCalculator(Structure structure, double probe, int nSpherePoints, int nThreads, boolean hetAtoms) {
        this.atoms = StructureTools.getAllNonHAtomArray(structure, hetAtoms);
        this.atomCoords = Calc.atomsToPoints(this.atoms);
        this.probe = probe;
        this.nThreads = nThreads;
        this.useSpatialHashingForNeighbors = true;
        this.radii = new double[this.atomCoords.length];
        for (int i = 0; i < this.atomCoords.length; ++i) {
            this.radii[i] = AsaCalculator.getRadius(this.atoms[i]);
        }
        this.initSpherePoints(nSpherePoints);
    }

    public AsaCalculator(Atom[] atoms, double probe, int nSpherePoints, int nThreads) {
        this.atoms = atoms;
        this.atomCoords = Calc.atomsToPoints(atoms);
        this.probe = probe;
        this.nThreads = nThreads;
        this.useSpatialHashingForNeighbors = true;
        for (Atom atom : atoms) {
            if (atom.getElement() != Element.H) continue;
            throw new IllegalArgumentException("Can't calculate ASA for an array that contains Hydrogen atoms ");
        }
        this.radii = new double[atoms.length];
        for (int i = 0; i < atoms.length; ++i) {
            this.radii[i] = AsaCalculator.getRadius(atoms[i]);
        }
        this.initSpherePoints(nSpherePoints);
    }

    public AsaCalculator(Point3d[] atomCoords, double probe, int nSpherePoints, int nThreads, double radius) {
        this.atoms = null;
        this.atomCoords = atomCoords;
        this.probe = probe;
        this.nThreads = nThreads;
        this.useSpatialHashingForNeighbors = true;
        this.radii = new double[atomCoords.length];
        for (int i = 0; i < atomCoords.length; ++i) {
            this.radii[i] = radius;
        }
        this.initSpherePoints(nSpherePoints);
    }

    private void initSpherePoints(int nSpherePoints) {
        logger.debug("Will use {} sphere points", (Object)nSpherePoints);
        this.spherePoints = this.generateSpherePoints(nSpherePoints);
        this.cons = Math.PI * 4 / (double)nSpherePoints;
    }

    public GroupAsa[] getGroupAsas() {
        TreeMap<ResidueNumber, GroupAsa> asas = new TreeMap<ResidueNumber, GroupAsa>();
        double[] asasPerAtom = this.calculateAsas();
        for (int i = 0; i < this.atomCoords.length; ++i) {
            GroupAsa groupAsa;
            Group g = this.atoms[i].getGroup();
            if (!asas.containsKey(g.getResidueNumber())) {
                groupAsa = new GroupAsa(g);
                groupAsa.addAtomAsaU(asasPerAtom[i]);
                asas.put(g.getResidueNumber(), groupAsa);
                continue;
            }
            groupAsa = (GroupAsa)asas.get(g.getResidueNumber());
            groupAsa.addAtomAsaU(asasPerAtom[i]);
        }
        return asas.values().toArray(new GroupAsa[0]);
    }

    public double[] calculateAsas() {
        double[] asas = new double[this.atomCoords.length];
        long start = System.currentTimeMillis();
        if (this.useSpatialHashingForNeighbors) {
            logger.debug("Will use spatial hashing to find neighbors");
            this.neighborIndices = this.findNeighborIndicesSpatialHashing();
        } else {
            logger.debug("Will not use spatial hashing to find neighbors");
            this.neighborIndices = this.findNeighborIndices();
        }
        long end = System.currentTimeMillis();
        logger.debug("Took {} s to find neighbors", (Object)((double)(end - start) / 1000.0));
        start = System.currentTimeMillis();
        if (this.nThreads <= 1) {
            logger.debug("Will use 1 thread for ASA calculation");
            for (int i = 0; i < this.atomCoords.length; ++i) {
                asas[i] = this.calcSingleAsa(i);
            }
        } else {
            logger.debug("Will use {} threads for ASA calculation", (Object)this.nThreads);
            ExecutorService threadPool = Executors.newFixedThreadPool(this.nThreads);
            for (int i = 0; i < this.atomCoords.length; ++i) {
                threadPool.submit(new AsaCalcWorker(i, asas));
            }
            threadPool.shutdown();
            while (!threadPool.isTerminated()) {
            }
        }
        end = System.currentTimeMillis();
        logger.debug("Took {} s to calculate all {} atoms ASAs (excluding neighbors calculation)", (Object)((double)(end - start) / 1000.0), (Object)this.atomCoords.length);
        return asas;
    }

    void setUseSpatialHashingForNeighbors(boolean useSpatialHashingForNeighbors) {
        this.useSpatialHashingForNeighbors = useSpatialHashingForNeighbors;
    }

    private Vector3d[] generateSpherePoints(int nSpherePoints) {
        Vector3d[] points = new Vector3d[nSpherePoints];
        double inc = Math.PI * (3.0 - Math.sqrt(5.0));
        double offset = 2.0 / (double)nSpherePoints;
        for (int k = 0; k < nSpherePoints; ++k) {
            double y = (double)k * offset - 1.0 + offset / 2.0;
            double r = Math.sqrt(1.0 - y * y);
            double phi = (double)k * inc;
            points[k] = new Vector3d(Math.cos(phi) * r, y, Math.sin(phi) * r);
        }
        return points;
    }

    IndexAndDistance[][] findNeighborIndices() {
        int initialCapacity = 60;
        IndexAndDistance[][] nbsIndices = new IndexAndDistance[this.atomCoords.length][];
        for (int k = 0; k < this.atomCoords.length; ++k) {
            double radius = this.radii[k] + this.probe + this.probe;
            ArrayList<IndexAndDistance> thisNbIndices = new ArrayList<IndexAndDistance>(initialCapacity);
            for (int i = 0; i < this.atomCoords.length; ++i) {
                double dist;
                if (i == k || !((dist = this.atomCoords[i].distance(this.atomCoords[k])) < radius + this.radii[i])) continue;
                thisNbIndices.add(new IndexAndDistance(i, dist));
            }
            IndexAndDistance[] indicesArray = thisNbIndices.toArray(new IndexAndDistance[0]);
            nbsIndices[k] = indicesArray;
        }
        return nbsIndices;
    }

    /*
     * WARNING - void declaration
     */
    IndexAndDistance[][] findNeighborIndicesSpatialHashing() {
        void var5_8;
        int initialCapacity = 60;
        List<Contact> contactList = this.calcContacts();
        HashMap indices = new HashMap(this.atomCoords.length);
        for (Contact contact : contactList) {
            List<IndexAndDistance> jIndices;
            List<IndexAndDistance> iIndices;
            int n = contact.getI();
            int j = contact.getJ();
            if (!indices.containsKey(n)) {
                iIndices = new ArrayList(initialCapacity);
                indices.put(n, iIndices);
            } else {
                iIndices = (List)indices.get(n);
            }
            if (!indices.containsKey(j)) {
                jIndices = new ArrayList(initialCapacity);
                indices.put(j, jIndices);
            } else {
                jIndices = (List)indices.get(j);
            }
            double radius = this.radii[n] + this.probe + this.probe;
            double dist = contact.getDistance();
            if (!(dist < radius + this.radii[j])) continue;
            iIndices.add(new IndexAndDistance(j, dist));
            jIndices.add(new IndexAndDistance(n, dist));
        }
        IndexAndDistance[][] nbsIndices = new IndexAndDistance[this.atomCoords.length][];
        for (Map.Entry entry : indices.entrySet()) {
            List list = (List)entry.getValue();
            IndexAndDistance[] indexAndDistances = list.toArray(new IndexAndDistance[0]);
            nbsIndices[((Integer)entry.getKey()).intValue()] = indexAndDistances;
        }
        boolean bl = false;
        while (var5_8 < nbsIndices.length) {
            if (nbsIndices[var5_8] == null) {
                nbsIndices[var5_8] = new IndexAndDistance[0];
            }
            ++var5_8;
        }
        return nbsIndices;
    }

    Point3d[] getAtomCoords() {
        return this.atomCoords;
    }

    private List<Contact> calcContacts() {
        if (this.atomCoords.length == 0) {
            return new ArrayList<Contact>();
        }
        double maxRadius = 0.0;
        OptionalDouble optionalDouble = Arrays.stream(this.radii).max();
        if (optionalDouble.isPresent()) {
            maxRadius = optionalDouble.getAsDouble();
        }
        double cutoff = maxRadius + maxRadius + this.probe + this.probe;
        logger.debug("Max radius is {}, cutoff is {}", (Object)maxRadius, (Object)cutoff);
        Grid grid = new Grid(cutoff);
        grid.addCoords(this.atomCoords);
        return grid.getIndicesContacts();
    }

    private double calcSingleAsa(int i) {
        Point3d atom_i = this.atomCoords[i];
        int n_neighbor = this.neighborIndices[i].length;
        IndexAndDistance[] neighbor_indices = this.neighborIndices[i];
        Arrays.sort(neighbor_indices, Comparator.comparingDouble(o -> o.dist));
        double radius_i = this.probe + this.radii[i];
        int n_accessible_point = 0;
        int[] numDistsCalced = null;
        if (logger.isDebugEnabled()) {
            numDistsCalced = new int[n_neighbor];
        }
        double[] sqRadii = new double[n_neighbor];
        Vector3d[] aj_minus_ais = new Vector3d[n_neighbor];
        for (int nbArrayInd = 0; nbArrayInd < n_neighbor; ++nbArrayInd) {
            int j = neighbor_indices[nbArrayInd].index;
            double dist = neighbor_indices[nbArrayInd].dist;
            double radius_j = this.radii[j] + this.probe;
            sqRadii[nbArrayInd] = (dist * dist + radius_i * radius_i - radius_j * radius_j) / (2.0 * radius_i);
            Vector3d aj_minus_ai = new Vector3d((Tuple3d)this.atomCoords[j]);
            aj_minus_ai.sub((Tuple3d)atom_i);
            aj_minus_ais[nbArrayInd] = aj_minus_ai;
        }
        for (Vector3d point : this.spherePoints) {
            boolean is_accessible = true;
            for (int nbArrayInd = 0; nbArrayInd < n_neighbor; ++nbArrayInd) {
                double dotProd = aj_minus_ais[nbArrayInd].dot(point);
                if (numDistsCalced != null) {
                    int n = nbArrayInd;
                    numDistsCalced[n] = numDistsCalced[n] + 1;
                }
                if (!(dotProd > sqRadii[nbArrayInd])) continue;
                is_accessible = false;
                break;
            }
            if (!is_accessible) continue;
            ++n_accessible_point;
        }
        if (numDistsCalced != null) {
            int sum = 0;
            for (int numDistCalcedForJ : numDistsCalced) {
                sum += numDistCalcedForJ;
            }
            logger.debug("Number of sample points distances calculated for neighbors of i={} : average {}, all {}", new Object[]{i, (double)sum / (double)n_neighbor, numDistsCalced});
        }
        return this.cons * (double)n_accessible_point * radius_i * radius_i;
    }

    private static double getRadiusForAmino(AminoAcid amino, Atom atom) {
        if (atom.getElement().equals((Object)Element.H)) {
            return Element.H.getVDWRadius();
        }
        if (atom.getElement().equals((Object)Element.D)) {
            return Element.D.getVDWRadius();
        }
        String atomCode = atom.getName();
        char aa = amino.getAminoType().charValue();
        if (atom.getElement() == Element.O) {
            return 1.4;
        }
        if (atom.getElement() == Element.S) {
            return 1.85;
        }
        if (atom.getElement() == Element.N) {
            if ("NZ".equals(atomCode)) {
                return 1.5;
            }
            return 1.65;
        }
        if (atom.getElement() == Element.C) {
            if ("C".equals(atomCode) || "CE1".equals(atomCode) || "CE2".equals(atomCode) || "CE3".equals(atomCode) || "CH2".equals(atomCode) || "CZ".equals(atomCode) || "CZ2".equals(atomCode) || "CZ3".equals(atomCode)) {
                return 1.76;
            }
            if ("CA".equals(atomCode) || "CB".equals(atomCode) || "CE".equals(atomCode) || "CG1".equals(atomCode) || "CG2".equals(atomCode)) {
                return 1.87;
            }
            switch (aa) {
                case 'D': 
                case 'F': 
                case 'H': 
                case 'N': 
                case 'W': 
                case 'Y': {
                    return 1.76;
                }
                case 'I': 
                case 'K': 
                case 'L': 
                case 'M': 
                case 'P': 
                case 'R': {
                    return 1.87;
                }
                case 'E': 
                case 'Q': {
                    if ("CD".equals(atomCode)) {
                        return 1.76;
                    }
                    if (!"CG".equals(atomCode)) break;
                    return 1.87;
                }
            }
            logger.info("Unexpected carbon atom {} for aminoacid {}, assigning its standard vdw radius", (Object)atomCode, (Object)Character.valueOf(aa));
            return Element.C.getVDWRadius();
        }
        logger.debug("Unexpected atom {} for aminoacid {} ({}), assigning its standard vdw radius", new Object[]{atomCode, Character.valueOf(aa), amino.getPDBName()});
        return atom.getElement().getVDWRadius();
    }

    private static double getRadiusForNucl(NucleotideImpl nuc, Atom atom) {
        if (atom.getElement().equals((Object)Element.H)) {
            return Element.H.getVDWRadius();
        }
        if (atom.getElement().equals((Object)Element.D)) {
            return Element.D.getVDWRadius();
        }
        if (atom.getElement() == Element.C) {
            return 1.8;
        }
        if (atom.getElement() == Element.N) {
            return 1.6;
        }
        if (atom.getElement() == Element.P) {
            return 1.9;
        }
        if (atom.getElement() == Element.O) {
            return 1.4;
        }
        logger.info("Unexpected atom " + atom.getName() + " for nucleotide " + nuc.getPDBName() + ", assigning its standard vdw radius");
        return atom.getElement().getVDWRadius();
    }

    public static double getRadius(Atom atom) {
        if (atom.getElement() == null) {
            logger.warn("Unrecognised atom " + atom.getName() + " with serial " + atom.getPDBserial() + ", assigning the default vdw radius (Nitrogen vdw radius).");
            return Element.N.getVDWRadius();
        }
        Group res = atom.getGroup();
        if (res == null) {
            logger.warn("Unknown parent residue for atom " + atom.getName() + " with serial " + atom.getPDBserial() + ", assigning its default vdw radius");
            return atom.getElement().getVDWRadius();
        }
        GroupType type = res.getType();
        if (type == GroupType.AMINOACID) {
            return AsaCalculator.getRadiusForAmino((AminoAcid)res, atom);
        }
        if (type == GroupType.NUCLEOTIDE) {
            return AsaCalculator.getRadiusForNucl((NucleotideImpl)res, atom);
        }
        return atom.getElement().getVDWRadius();
    }

    static class IndexAndDistance {
        final int index;
        final double dist;

        IndexAndDistance(int index, double dist) {
            this.index = index;
            this.dist = dist;
        }
    }

    private class AsaCalcWorker
    implements Runnable {
        private final int i;
        private final double[] asas;

        private AsaCalcWorker(int i, double[] asas) {
            this.i = i;
            this.asas = asas;
        }

        @Override
        public void run() {
            this.asas[this.i] = AsaCalculator.this.calcSingleAsa(this.i);
        }
    }
}

