/*
 * Decompiled with CFR 0.152.
 */
package g3501_3600.s3559_number_of_ways_to_assign_edge_weights_ii;

import java.util.ArrayList;
import java.util.List;

/**
 * Answers path-length queries on a tree. For each query {@code (u, v)} the result is
 * {@code 2^(d-1) mod 1_000_000_007}, where {@code d} is the number of edges on the path
 * between {@code u} and {@code v}. The result is {@code 0} when {@code u == v}.
 *
 * <p>NOTE(review): going by the package name this is LeetCode 3559. There, each of the
 * first {@code d - 1} path edges can be weighted freely and the last edge is forced by
 * the parity requirement. TODO confirm against the problem statement.
 *
 * <p>Path lengths come from lowest-common-ancestor (LCA) queries. Those use a
 * binary-lifting table rooted at node 0, which is rebuilt on every call to
 * {@link #assignEdgeWeights}.
 */
public class Solution {
    private static final int MOD = 1000000007;

    /** 0-indexed adjacency lists of the tree from the most recent call. */
    private List<List<Integer>> adj;
    /** {@code level[v]} is the depth of {@code v} below the root (node 0). */
    private int[] level;
    /** {@code jumps[v][j]} is the 2^j-th ancestor of {@code v}; the root is its own ancestor. */
    private int[][] jumps;

    /**
     * Fills {@code level} and the first column of {@code jumps} by breadth-first traversal
     * from {@code root}.
     *
     * <p>This is deliberately iterative. A recursive DFS recurses once per tree level, which
     * overflows the call stack on path-shaped trees with around 10^5 nodes.
     */
    private void mark(int root) {
        int n = this.adj.size();
        int[] queue = new int[n];
        boolean[] seen = new boolean[n];
        int head = 0;
        int tail = 0;
        queue[tail++] = root;
        seen[root] = true;
        this.jumps[root][0] = root; // root is its own parent, so lifting saturates there
        while (head < tail) {
            int node = queue[head++];
            for (int neigh : this.adj.get(node)) {
                if (seen[neigh]) {
                    continue; // the parent (or an already-visited node)
                }
                seen[neigh] = true;
                this.level[neigh] = this.level[node] + 1;
                this.jumps[neigh][0] = node;
                queue[tail++] = neigh;
            }
        }
    }

    /**
     * Returns the ancestor {@code diff} levels above {@code u}.
     *
     * <p>It uses the binary-lifting table built by the most recent
     * {@link #assignEdgeWeights} call. Climbing past the root stops at the root. A
     * non-positive {@code diff} returns {@code u} unchanged.
     *
     * @param u    0-indexed start node
     * @param diff number of levels to climb; its set bits must fit in the table width
     * @return the 0-indexed ancestor
     */
    public int lift(int u, int diff) {
        while (diff > 0) {
            // Take the 2^k jump for the lowest set bit k, then clear that bit.
            u = this.jumps[u][Integer.numberOfTrailingZeros(diff)];
            diff &= diff - 1;
        }
        return u;
    }

    /** Returns the lowest common ancestor of 0-indexed nodes {@code u} and {@code v}. */
    private int findLca(int u, int v) {
        if (this.level[u] > this.level[v]) {
            int tmp = u;
            u = v;
            v = tmp;
        }
        // Bring the deeper node up to the shallower node's depth.
        v = this.lift(v, this.level[v] - this.level[u]);
        if (u == v) {
            return u;
        }
        // Climb both nodes while their ancestors differ. They end just below the LCA.
        for (int j = this.jumps[u].length - 1; j >= 0; --j) {
            if (this.jumps[u][j] != this.jumps[v][j]) {
                u = this.jumps[u][j];
                v = this.jumps[v][j];
            }
        }
        return this.jumps[u][0];
    }

    /** Returns the number of edges on the path between 0-indexed nodes {@code a} and {@code b}. */
    private int findDist(int a, int b) {
        return this.level[a] + this.level[b] - 2 * this.level[this.findLca(a, b)];
    }

    /** Builds 0-indexed, undirected adjacency lists from 1-indexed edge pairs. */
    private static List<List<Integer>> buildAdjacency(int[][] edges, int n) {
        List<List<Integer>> lists = new ArrayList<>(n);
        for (int v = 0; v < n; ++v) {
            lists.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            int a = edge[0] - 1;
            int b = edge[1] - 1;
            lists.get(a).add(b);
            lists.get(b).add(a);
        }
        return lists;
    }

    /** Fills columns 1..m-1 of {@code jumps} from the parent column: 2^j = 2^(j-1) + 2^(j-1). */
    private void fillJumpTable(int n, int m) {
        for (int j = 1; j < m; ++j) {
            for (int v = 0; v < n; ++v) {
                this.jumps[v][j] = this.jumps[this.jumps[v][j - 1]][j - 1];
            }
        }
    }

    /**
     * Answers every query with {@code 2^(d-1) mod 1_000_000_007}, or {@code 0} when the
     * two query nodes coincide.
     *
     * <p>Here {@code d} is the number of edges on the path between the two query nodes.
     *
     * @param edges   1-indexed node pairs; assumed to form a tree over nodes
     *                {@code 1..edges.length + 1} (an empty array means a single node)
     * @param queries 1-indexed node pairs {@code {u, v}}
     * @return one answer per query, in query order
     */
    public int[] assignEdgeWeights(int[][] edges, int[][] queries) {
        int n = edges.length + 1;
        this.adj = buildAdjacency(edges, n);
        this.level = new int[n];

        // Depths are at most n - 1, so the table needs enough columns to write n - 1 in
        // binary. Use at least one column so jumps[v][0] always exists; this also covers
        // n == 1, where log(0) made the old floating-point formula produce a negative size.
        int m = Math.max(1, 32 - Integer.numberOfLeadingZeros(n - 1));
        this.jumps = new int[n][m];
        this.mark(0);
        this.fillJumpTable(n, m);

        // pow[k] = 2^k mod MOD for k up to n - 1 (path lengths never exceed n - 1).
        int[] pow = new int[n];
        pow[0] = 1;
        for (int k = 1; k < n; ++k) {
            pow[k] = (int) (pow[k - 1] * 2L % MOD);
        }

        int[] ans = new int[queries.length];
        for (int qi = 0; qi < queries.length; ++qi) {
            int d = this.findDist(queries[qi][0] - 1, queries[qi][1] - 1);
            ans[qi] = d > 0 ? pow[d - 1] : 0;
        }
        return ans;
    }
}

