/*
 * Decompiled with CFR 0.152.
 */
package edu.emory.mathcs.restoretools.iterative.preconditioner;

import cern.colt.function.tint.IntComparator;
import cern.colt.matrix.AbstractMatrix3D;
import cern.colt.matrix.tfcomplex.FComplexMatrix3D;
import cern.colt.matrix.tfcomplex.impl.DenseFComplexMatrix3D;
import cern.colt.matrix.tfloat.FloatMatrix1D;
import cern.colt.matrix.tfloat.FloatMatrix3D;
import cern.colt.matrix.tfloat.algo.FloatSorting;
import cern.colt.matrix.tfloat.impl.DenseFloatMatrix1D;
import cern.colt.matrix.tfloat.impl.DenseFloatMatrix3D;
import cern.jet.math.tfcomplex.FComplexFunctions;
import cern.jet.math.tfloat.FloatFunctions;
import edu.emory.mathcs.restoretools.iterative.FloatCommon3D;
import edu.emory.mathcs.restoretools.iterative.IterativeEnums;
import edu.emory.mathcs.restoretools.iterative.preconditioner.FloatPreconditioner3D;
import edu.emory.mathcs.restoretools.iterative.psf.FloatPSFMatrix3D;
import edu.emory.mathcs.utils.ConcurrencyUtils;
import ij.IJ;
import java.util.concurrent.Future;

public class FFTFloatPreconditioner3D
implements FloatPreconditioner3D {
    private AbstractMatrix3D matdata;
    private float tol;
    private IterativeEnums.BoundaryType boundary;
    private int[] imSize;
    private int[] psfSize;
    private int[] padSize;

    public FFTFloatPreconditioner3D(FloatPSFMatrix3D PSFMatrix, FloatMatrix3D B, float tol) {
        this.tol = tol;
        this.boundary = PSFMatrix.getBoundary();
        this.imSize = new int[3];
        this.imSize[0] = B.slices();
        this.imSize[1] = B.rows();
        this.imSize[2] = B.columns();
        if (PSFMatrix.getType() == IterativeEnums.PSFType.INVARIANT) {
            this.psfSize = PSFMatrix.getInvPsfSize();
            this.padSize = PSFMatrix.getInvPadSize();
        } else {
            this.psfSize = PSFMatrix.getPSF().getSize();
            int[] minimal = new int[]{this.psfSize[0] + this.imSize[0], this.psfSize[1] + this.imSize[1], this.psfSize[2] + this.imSize[2]};
            switch (PSFMatrix.getResizing()) {
                case AUTO: {
                    int[] nextPowTwo = new int[]{!ConcurrencyUtils.isPowerOf2((int)minimal[0]) ? ConcurrencyUtils.nextPow2((int)minimal[0]) : minimal[0], !ConcurrencyUtils.isPowerOf2((int)minimal[1]) ? ConcurrencyUtils.nextPow2((int)minimal[1]) : minimal[1], !ConcurrencyUtils.isPowerOf2((int)minimal[2]) ? ConcurrencyUtils.nextPow2((int)minimal[2]) : minimal[2]};
                    if ((double)nextPowTwo[0] >= 1.5 * (double)minimal[0] || (double)nextPowTwo[1] >= 1.5 * (double)minimal[1] || (double)nextPowTwo[2] >= 1.5 * (double)minimal[2]) {
                        this.psfSize[0] = minimal[0];
                        this.psfSize[1] = minimal[1];
                        this.psfSize[2] = minimal[2];
                        break;
                    }
                    this.psfSize[0] = nextPowTwo[0];
                    this.psfSize[1] = nextPowTwo[1];
                    this.psfSize[2] = nextPowTwo[2];
                    break;
                }
                case MINIMAL: {
                    this.psfSize[0] = minimal[0];
                    this.psfSize[1] = minimal[1];
                    this.psfSize[2] = minimal[2];
                    break;
                }
                case NEXT_POWER_OF_TWO: {
                    this.psfSize[0] = minimal[0];
                    this.psfSize[1] = minimal[1];
                    this.psfSize[2] = minimal[2];
                    if (!ConcurrencyUtils.isPowerOf2((int)this.psfSize[0])) {
                        this.psfSize[0] = ConcurrencyUtils.nextPow2((int)this.psfSize[0]);
                    }
                    if (!ConcurrencyUtils.isPowerOf2((int)this.psfSize[1])) {
                        this.psfSize[1] = ConcurrencyUtils.nextPow2((int)this.psfSize[1]);
                    }
                    if (ConcurrencyUtils.isPowerOf2((int)this.psfSize[2])) break;
                    this.psfSize[2] = ConcurrencyUtils.nextPow2((int)this.psfSize[2]);
                }
            }
            this.padSize = new int[3];
            if (this.imSize[0] < this.psfSize[0]) {
                this.padSize[0] = (this.psfSize[0] - this.imSize[0] + 1) / 2;
            }
            if (this.imSize[1] < this.psfSize[1]) {
                this.padSize[1] = (this.psfSize[1] - this.imSize[1] + 1) / 2;
            }
            if (this.imSize[2] < this.psfSize[2]) {
                this.padSize[2] = (this.psfSize[2] - this.imSize[2] + 1) / 2;
            }
        }
        this.constructMatrix(PSFMatrix.getPSF().getImage(), B, PSFMatrix.getPSF().getCenter());
    }

    public float getTolerance() {
        return this.tol;
    }

    public FloatMatrix1D solve(FloatMatrix1D b, boolean transpose) {
        DenseFloatMatrix3D B = null;
        B = b.isView() ? new DenseFloatMatrix3D(this.imSize[0], this.imSize[1], this.imSize[2], (float[])b.copy().elements(), 0, 0, 0, this.imSize[1] * this.imSize[2], this.imSize[2], 1, false) : new DenseFloatMatrix3D(this.imSize[0], this.imSize[1], this.imSize[2], (float[])b.elements(), 0, 0, 0, this.imSize[1] * this.imSize[2], this.imSize[2], 1, false);
        B = this.solve((AbstractMatrix3D)B, transpose);
        return new DenseFloatMatrix1D((int)B.size(), (float[])B.elements(), 0, 1, false);
    }

    public FloatMatrix3D solve(AbstractMatrix3D B, boolean transpose) {
        switch (this.boundary) {
            case ZERO: {
                B = FloatCommon3D.padZero((FloatMatrix3D)B, this.psfSize[0], this.psfSize[1], this.psfSize[2]);
                break;
            }
            case PERIODIC: {
                B = FloatCommon3D.padPeriodic((FloatMatrix3D)B, this.psfSize[0], this.psfSize[1], this.psfSize[2]);
                break;
            }
            case REFLEXIVE: {
                B = FloatCommon3D.padReflexive((FloatMatrix3D)B, this.psfSize[0], this.psfSize[1], this.psfSize[2]);
            }
        }
        B = ((DenseFloatMatrix3D)B).getFft3();
        if (transpose) {
            ((FComplexMatrix3D)B).assign((FComplexMatrix3D)this.matdata, FComplexFunctions.multConjSecond);
        } else {
            ((FComplexMatrix3D)B).assign((FComplexMatrix3D)this.matdata, FComplexFunctions.mult);
        }
        ((DenseFComplexMatrix3D)B).ifft3(true);
        return ((FComplexMatrix3D)B).viewPart(this.padSize[0], this.padSize[1], this.padSize[2], this.imSize[0], this.imSize[1], this.imSize[2]).getRealPart();
    }

    private void constructMatrix(FloatMatrix3D[][][] PSFs, FloatMatrix3D B, int[][][][] center) {
        this.matdata = PSFs[0][0][0].like();
        int[] center1 = center[0][0][0];
        int slices = PSFs.length;
        int rows = PSFs[0].length;
        int columns = PSFs[0][0].length;
        int size = slices * rows * columns;
        for (int s = 0; s < slices; ++s) {
            for (int r = 0; r < rows; ++r) {
                for (int c = 0; c < columns; ++c) {
                    ((FloatMatrix3D)this.matdata).assign(PSFs[s][r][c], FloatFunctions.plus);
                }
            }
        }
        if (size != 1) {
            ((FloatMatrix3D)this.matdata).assign(FloatFunctions.div((float)size));
        }
        switch (this.boundary) {
            case ZERO: {
                B = FloatCommon3D.padZero(B, this.psfSize[0], this.psfSize[1], this.psfSize[2]);
                break;
            }
            case PERIODIC: {
                B = FloatCommon3D.padPeriodic(B, this.psfSize[0], this.psfSize[1], this.psfSize[2]);
                break;
            }
            case REFLEXIVE: {
                B = FloatCommon3D.padReflexive(B, this.psfSize[0], this.psfSize[1], this.psfSize[2]);
            }
        }
        this.precMatrixOnePsf(center1, B);
    }

    private void precMatrixOnePsf(int[] center, FloatMatrix3D Bpad) {
        int[] padSize = new int[]{Bpad.slices() - this.matdata.slices(), Bpad.rows() - this.matdata.rows(), Bpad.columns() - this.matdata.columns()};
        if (padSize[0] > 0 || padSize[1] > 0 || padSize[2] > 0) {
            this.matdata = FloatCommon3D.padZero((FloatMatrix3D)this.matdata, padSize, IterativeEnums.PaddingType.POST);
        }
        this.matdata = FloatCommon3D.circShift((FloatMatrix3D)this.matdata, center);
        this.matdata = ((DenseFloatMatrix3D)this.matdata).getFft3();
        FComplexMatrix3D E = ((FComplexMatrix3D)this.matdata).copy();
        E.assign(FComplexFunctions.abs);
        E = E.getRealPart();
        float[] maxAndLoc = ((FloatMatrix3D)E).getMaxLocation();
        final float maxE = maxAndLoc[0];
        if (this.tol == -1.0f) {
            IJ.showStatus((String)"Computing tolerance for preconditioner...");
            float[] minAndLoc = ((FloatMatrix3D)E).getMinLocation();
            float minE = minAndLoc[0];
            this.tol = maxE / minE < 100.0f ? 0.0f : this.defaultTol2((FloatMatrix3D)E, Bpad);
            IJ.showStatus((String)"Computing tolerance for preconditioner...done.");
        }
        final float[] one = new float[]{1.0f, 0.0f};
        if ((double)maxE != 1.0) {
            ((FComplexMatrix3D)this.matdata).assign(FComplexFunctions.div((float[])new float[]{maxE, 0.0f}));
        }
        int slices = E.slices();
        final int rows = E.rows();
        final int cols = E.columns();
        final float[] elementsE = (float[])((FloatMatrix3D)E).elements();
        final int zeroE = (int)((FloatMatrix3D)E).index(0, 0, 0);
        final int sliceStrideE = ((FloatMatrix3D)E).sliceStride();
        final int rowStrideE = ((FloatMatrix3D)E).rowStride();
        final int columnStrideE = ((FloatMatrix3D)E).columnStride();
        final float[] elementsM = (float[])((FComplexMatrix3D)this.matdata).elements();
        final int zeroM = (int)((FComplexMatrix3D)this.matdata).index(0, 0, 0);
        final int sliceStrideM = ((FComplexMatrix3D)this.matdata).sliceStride();
        final int rowStrideM = ((FComplexMatrix3D)this.matdata).rowStride();
        final int columnStrideM = ((FComplexMatrix3D)this.matdata).columnStride();
        int np = ConcurrencyUtils.getNumberOfThreads();
        if (np > 1 && slices * rows * cols >= ConcurrencyUtils.getThreadsBeginN_3D()) {
            Future[] futures = new Future[np];
            int k = slices / np;
            for (int j = 0; j < np; ++j) {
                final int startslice = j * k;
                final int stopslice = j == np - 1 ? slices : startslice + k;
                futures[j] = ConcurrencyUtils.submit((Runnable)new Runnable(){

                    public void run() {
                        float[] elem = new float[2];
                        if ((double)maxE != 1.0) {
                            for (int s = startslice; s < stopslice; ++s) {
                                for (int r = 0; r < rows; ++r) {
                                    for (int c = 0; c < cols; ++c) {
                                        int idxE = zeroE + s * sliceStrideE + r * rowStrideE + c * columnStrideE;
                                        int idxM = zeroM + s * sliceStrideM + r * rowStrideM + c * columnStrideM;
                                        elem[0] = elementsM[idxM];
                                        elem[1] = elementsM[idxM + 1];
                                        if (elementsE[idxE] >= FFTFloatPreconditioner3D.this.tol) {
                                            if ((double)elem[1] != 0.0) {
                                                float scalar;
                                                if (Math.abs(elem[0]) >= Math.abs(elem[1])) {
                                                    elem[0] = scalar = (float)(1.0 / (double)(elem[0] + elem[1] * (elem[1] / elem[0])));
                                                    elem[1] = scalar * (-elem[1] / elem[0]);
                                                } else {
                                                    scalar = (float)(1.0 / (double)(elem[0] * (elem[0] / elem[1]) + elem[1]));
                                                    elem[0] = scalar * (elem[0] / elem[1]);
                                                    elem[1] = -scalar;
                                                }
                                            } else {
                                                elem[0] = 1.0f / elem[0];
                                                elem[1] = 0.0f;
                                            }
                                            elem[0] = elem[0] * maxE;
                                            elem[1] = elem[1] * maxE;
                                            elementsM[idxM] = elem[0];
                                            elementsM[idxM + 1] = elem[1];
                                            continue;
                                        }
                                        elementsM[idxM] = one[0];
                                        elementsM[idxM + 1] = one[1];
                                    }
                                }
                            }
                        } else {
                            for (int s = startslice; s < stopslice; ++s) {
                                for (int r = 0; r < rows; ++r) {
                                    for (int c = 0; c < cols; ++c) {
                                        int idxE = zeroE + s * sliceStrideE + r * rowStrideE + c * columnStrideE;
                                        int idxM = zeroM + s * sliceStrideM + r * rowStrideM + c * columnStrideM;
                                        elem[0] = elementsM[idxM];
                                        elem[1] = elementsM[idxM + 1];
                                        if (elementsE[idxE] >= FFTFloatPreconditioner3D.this.tol) {
                                            if ((double)elem[1] != 0.0) {
                                                float scalar;
                                                if (Math.abs(elem[0]) >= Math.abs(elem[1])) {
                                                    elem[0] = scalar = (float)(1.0 / (double)(elem[0] + elem[1] * (elem[1] / elem[0])));
                                                    elem[1] = scalar * (-elem[1] / elem[0]);
                                                } else {
                                                    scalar = (float)(1.0 / (double)(elem[0] * (elem[0] / elem[1]) + elem[1]));
                                                    elem[0] = scalar * (elem[0] / elem[1]);
                                                    elem[1] = -scalar;
                                                }
                                            } else {
                                                elem[0] = 1.0f / elem[0];
                                                elem[1] = 0.0f;
                                            }
                                            elementsM[idxM] = elem[0];
                                            elementsM[idxM + 1] = elem[1];
                                            continue;
                                        }
                                        elementsM[idxM] = one[0];
                                        elementsM[idxM + 1] = one[1];
                                    }
                                }
                            }
                        }
                    }
                });
            }
            ConcurrencyUtils.waitForCompletion((Future[])futures);
        } else {
            float[] elem = new float[2];
            if ((double)maxE != 1.0) {
                for (int s = 0; s < slices; ++s) {
                    for (int r = 0; r < rows; ++r) {
                        for (int c = 0; c < cols; ++c) {
                            int idxE = zeroE + s * sliceStrideE + r * rowStrideE + c * columnStrideE;
                            int idxM = zeroM + s * sliceStrideM + r * rowStrideM + c * columnStrideM;
                            elem[0] = elementsM[idxM];
                            elem[1] = elementsM[idxM + 1];
                            if (elementsE[idxE] >= this.tol) {
                                if ((double)elem[1] != 0.0) {
                                    float scalar;
                                    if (Math.abs(elem[0]) >= Math.abs(elem[1])) {
                                        elem[0] = scalar = (float)(1.0 / (double)(elem[0] + elem[1] * (elem[1] / elem[0])));
                                        elem[1] = scalar * (-elem[1] / elem[0]);
                                    } else {
                                        scalar = (float)(1.0 / (double)(elem[0] * (elem[0] / elem[1]) + elem[1]));
                                        elem[0] = scalar * (elem[0] / elem[1]);
                                        elem[1] = -scalar;
                                    }
                                } else {
                                    elem[0] = 1.0f / elem[0];
                                    elem[1] = 0.0f;
                                }
                                elem[0] = elem[0] * maxE;
                                elem[1] = elem[1] * maxE;
                                elementsM[idxM] = elem[0];
                                elementsM[idxM + 1] = elem[1];
                                continue;
                            }
                            elementsM[idxM] = one[0];
                            elementsM[idxM + 1] = one[1];
                        }
                    }
                }
            } else {
                for (int s = 0; s < slices; ++s) {
                    for (int r = 0; r < rows; ++r) {
                        for (int c = 0; c < cols; ++c) {
                            int idxE = zeroE + s * sliceStrideE + r * rowStrideE + c * columnStrideE;
                            int idxM = zeroM + s * sliceStrideM + r * rowStrideM + c * columnStrideM;
                            elem[0] = elementsM[idxM];
                            elem[1] = elementsM[idxM + 1];
                            if (elementsE[idxE] >= this.tol) {
                                if ((double)elem[1] != 0.0) {
                                    float scalar;
                                    if (Math.abs(elem[0]) >= Math.abs(elem[1])) {
                                        elem[0] = scalar = (float)(1.0 / (double)(elem[0] + elem[1] * (elem[1] / elem[0])));
                                        elem[1] = scalar * (-elem[1] / elem[0]);
                                    } else {
                                        scalar = (float)(1.0 / (double)(elem[0] * (elem[0] / elem[1]) + elem[1]));
                                        elem[0] = scalar * (elem[0] / elem[1]);
                                        elem[1] = -scalar;
                                    }
                                } else {
                                    elem[0] = 1.0f / elem[0];
                                    elem[1] = 0.0f;
                                }
                                elementsM[idxM] = elem[0];
                                elementsM[idxM + 1] = elem[1];
                                continue;
                            }
                            elementsM[idxM] = one[0];
                            elementsM[idxM + 1] = one[1];
                        }
                    }
                }
            }
        }
    }

    private float defaultTol2(FloatMatrix3D E, FloatMatrix3D B) {
        int k;
        DenseFloatMatrix1D s = new DenseFloatMatrix1D((int)E.size());
        System.arraycopy((float[])E.elements(), 0, (float[])s.elements(), 0, (int)s.size());
        final float[] evalues = (float[])s.elements();
        IntComparator compDec = new IntComparator(){

            public int compare(int a, int b) {
                if (evalues[a] != evalues[a] || evalues[b] != evalues[b]) {
                    return FFTFloatPreconditioner3D.this.compareNaN(evalues[a], evalues[b]);
                }
                return evalues[a] < evalues[b] ? 1 : (evalues[a] == evalues[b] ? 0 : -1);
            }
        };
        int[] indices = FloatSorting.quickSort.sortIndex((FloatMatrix1D)s, compDec);
        s = s.viewSelection(indices);
        DenseFComplexMatrix3D Bhat = ((DenseFloatMatrix3D)B).getFft3();
        ((FComplexMatrix3D)Bhat).assign(FComplexFunctions.abs);
        Bhat = ((FComplexMatrix3D)Bhat).getRealPart();
        DenseFloatMatrix1D bhat = new DenseFloatMatrix1D((int)Bhat.size(), (float[])((FloatMatrix3D)Bhat).elements(), 0, 1, false);
        bhat = bhat.viewSelection(indices);
        bhat.assign(FloatFunctions.div((float)((float)Math.sqrt(B.size()))));
        int n = (int)s.size();
        float[] rho = new float[n - 1];
        rho[n - 2] = bhat.getQuick(n - 1) * bhat.getQuick(n - 1);
        DenseFloatMatrix1D G = new DenseFloatMatrix1D(n - 1);
        float[] elemsG = (float[])G.elements();
        elemsG[n - 2] = rho[n - 2];
        for (k = n - 2; k > 0; --k) {
            float bhatel = bhat.getQuick(k);
            rho[k - 1] = rho[k] + bhatel * bhatel;
            float temp1 = n - k;
            temp1 *= temp1;
            elemsG[k - 1] = rho[k - 1] / temp1;
        }
        for (k = 0; k < n - 3; ++k) {
            if (s.getQuick(k) != s.getQuick(k + 1)) continue;
            elemsG[k] = Float.POSITIVE_INFINITY;
        }
        return s.getQuick((int)G.getMinLocation()[1]);
    }

    private final int compareNaN(float a, float b) {
        if (a != a) {
            if (b != b) {
                return 0;
            }
            return 1;
        }
        return -1;
    }
}

