/*
 * Decompiled with CFR 0.152.
 */
package g1601_1700.s1632_rank_transform_of_a_matrix;

import java.util.Arrays;

public class Solution {
    /** Selects the matrix value, stored (signed) in the high 32 bits of a packed cell key. */
    private static final long VALUE_MASK = 0xFFFFFFFF00000000L;

    /** Selects the row-major cell index, stored in the low 32 bits of a packed cell key. */
    private static final long INDEX_MASK = 0xFFFFFFFFL;

    /**
     * Replaces every element of {@code matrix} with its rank.
     *
     * <p>Ranks start at 1. Within a row or column, a larger value gets a larger rank and equal
     * values share a rank. Subject to that, each rank is as small as possible.
     *
     * <p>The transform is done in place: the argument array is overwritten and returned.
     *
     * @param matrix a rectangular matrix (every row as long as {@code matrix[0]}); may have zero
     *     rows
     * @return {@code matrix}, now holding ranks
     * @throws ArithmeticException if the number of cells does not fit in an {@code int}
     */
    public int[][] matrixRankTransform(int[][] matrix) {
        int rowCount = matrix.length;
        if (rowCount == 0) {
            return matrix;
        }
        int colCount = matrix[0].length;

        // Pack the value (high 32 bits, signed) and the row-major cell index (low 32 bits,
        // non-negative) so that a plain numeric sort orders cells by value. A linear index is
        // used instead of separate 16-bit row/column fields, so that dimensions above 65535
        // still decode correctly.
        long[] keys = new long[Math.multiplyExact(rowCount, colCount)];
        int keyCount = 0;
        for (int r = 0; r < rowCount; r++) {
            for (int c = 0; c < colCount; c++) {
                keys[keyCount++] = ((long) matrix[r][c] << 32) | (long) (r * colCount + c);
            }
        }
        Arrays.sort(keys);

        int[] rows = new int[rowCount]; // highest rank assigned so far in each row
        int[] cols = new int[colCount]; // highest rank assigned so far in each column
        // Union-find over rows [0, rowCount) and columns [rowCount, rowCount + colCount).
        // A negative entry marks a root and holds the negated candidate rank of its component.
        // It is allocated once and restored to all -1 after each group.
        int[] ufind = new int[rowCount + colCount];
        Arrays.fill(ufind, -1);

        int start = 0;
        while (start < keyCount) {
            long value = keys[start] & VALUE_MASK;
            int end = start + 1;
            while (end < keyCount && (keys[end] & VALUE_MASK) == value) {
                end++;
            }
            doGroup(matrix, keys, start, end, rows, cols, ufind);
            start = end;
        }
        return matrix;
    }

    /**
     * Assigns ranks to the cells {@code keys[startIdx, endIdx)}, which all hold the same value.
     *
     * <p>Cells in the group that share a row or column, directly or transitively, must receive
     * the same rank. Rows and columns are therefore merged with union-find. Each component's rank
     * is one more than the largest rank already recorded for any of its rows or columns.
     *
     * @param matrix matrix whose group cells are overwritten with their ranks
     * @param keys sorted packed cell keys (value in the high bits, row-major index in the low bits)
     * @param startIdx first key of the group (inclusive)
     * @param endIdx end of the group (exclusive); greater than {@code startIdx}
     * @param rows per-row highest rank so far; updated for the group's rows
     * @param cols per-column highest rank so far; updated for the group's columns
     * @param ufind scratch union-find array; all {@code -1} on entry and restored to all
     *     {@code -1} on exit
     */
    private void doGroup(
            int[][] matrix, long[] keys, int startIdx, int endIdx,
            int[] rows, int[] cols, int[] ufind) {
        int rowCount = rows.length;
        int colCount = cols.length;

        if (startIdx + 1 == endIdx) {
            // A lone cell is tied to nothing else in its group, so no union-find is needed.
            int cell = (int) (keys[startIdx] & INDEX_MASK);
            int r = cell / colCount;
            int c = cell % colCount;
            int rank = Math.max(rows[r], cols[c]) + 1;
            rows[r] = rank;
            cols[c] = rank;
            matrix[r][c] = rank;
            return;
        }

        // Pass 1: connect each cell's row with its column and track the component's rank.
        for (int i = startIdx; i < endIdx; i++) {
            int cell = (int) (keys[i] & INDEX_MASK);
            int r = cell / colCount;
            int c = cell % colCount;
            int pr = getIdx(ufind, r);
            int pc = getIdx(ufind, rowCount + c);
            if (pr != pc) {
                // Roots store negated ranks, so min() keeps the largest candidate rank.
                ufind[pr] = Math.min(
                        Math.min(ufind[pr], ufind[pc]), -Math.max(rows[r], cols[c]) - 1);
                ufind[pc] = pr;
            }
            // If pr == pc, rows[r] and cols[c] were both folded in when r and c first joined.
        }

        // Pass 2: every cell takes the rank of its component.
        for (int i = startIdx; i < endIdx; i++) {
            int cell = (int) (keys[i] & INDEX_MASK);
            int r = cell / colCount;
            int c = cell % colCount;
            int rank = -ufind[getIdx(ufind, r)];
            rows[r] = rank;
            cols[c] = rank;
            matrix[r][c] = rank;
        }

        // Pass 3: only this group's rows and columns were touched; reset them for the next group.
        for (int i = startIdx; i < endIdx; i++) {
            int cell = (int) (keys[i] & INDEX_MASK);
            ufind[cell / colCount] = -1;
            ufind[rowCount + cell % colCount] = -1;
        }
    }

    /**
     * Returns the root of {@code idx} in {@code ufind} and compresses the path to it.
     *
     * <p>The lookup is iterative, so long parent chains cannot overflow the call stack.
     *
     * @param ufind union-find array in which negative entries mark roots
     * @param idx node to look up
     * @return index of the root of {@code idx}'s component
     */
    private int getIdx(int[] ufind, int idx) {
        int root = idx;
        while (ufind[root] >= 0) {
            root = ufind[root];
        }
        int node = idx;
        while (node != root) {
            int next = ufind[node];
            ufind[node] = root;
            node = next;
        }
        return root;
    }
}

