/*
 * Decompiled with CFR 0.152.
 */
package javaforce.cl;

import javaforce.JFLog;
import javaforce.cl.CL;

/**
 * Small collection of float compute routines that run as OpenCL kernels
 * through the {@link CL} binding.
 *
 * <p>Typical use: call {@link #init(int)} once, call any of the compute
 * methods as often as needed, then call {@link #uninit()} to release the
 * kernels and the context. Every compute method reports failure by returning
 * {@code false} (after logging the reason through {@link JFLog}).
 *
 * <p>Instances keep native handles in plain fields and perform no
 * synchronization of their own.
 */
public class Compute {
    /** Size in bytes of one {@code float} element in a device buffer. */
    private static final int FLOAT_BYTES = 4;

    /** Largest element count whose size in bytes still fits in an {@code int}. */
    private static final int MAX_FLOATS = Integer.MAX_VALUE / FLOAT_BYTES;

    /** OpenCL program containing the three kernels used by this class. */
    private static final String KERNEL_SOURCE =
        "kernel void array_square(global float* input, global float* output)  {    int i = get_global_id(0);    output[i] = input[i] * input[i];  };\n"
        + "kernel void array_mult(global float* input0, global float* input1, global float* output)  {    int i = get_global_id(0);    output[i] = input0[i] * input1[i];  };\n"
        + "kernel void matrix_mult(const int M, const int N, const int K, global float* A, global float* B, global float* C)  {    const int globalRow = get_global_id(0);    const int globalCol = get_global_id(1);    float acc = 0.0f;    for (int k=0;k<K;k++) {      acc += A[k*M + globalRow] * B[globalCol*K + k];    }    C[globalCol*M + globalRow] = acc;  }\n";

    /** CL binding; {@code null} whenever this object is not initialized. */
    private CL cl;
    /** Handle returned by {@code CL.create}; 0 when not initialized. */
    private long ctx;
    /** Kernel handles returned by {@code CL.kernel}; 0 when not initialized. */
    private long k_array_square;
    private long k_array_mult;
    private long k_matrix_mult;

    /**
     * Releases any previous state, then builds the kernel program and looks
     * up the three kernels.
     *
     * @param type passed unchanged to {@code CL.create}
     *             (presumably selects the device type - confirm against CL)
     * @return {@code false} if no {@link CL} instance is available,
     *         {@code true} otherwise
     */
    public boolean init(int type) {
        uninit();
        cl = CL.getInstance();
        if (cl == null) {
            return false;
        }
        // NOTE(review): the handles returned by create()/kernel() are not
        // validated here; confirm how CL reports a failed build/lookup.
        ctx = cl.create(KERNEL_SOURCE, type);
        k_array_square = cl.kernel(ctx, "array_square");
        k_array_mult = cl.kernel(ctx, "array_mult");
        k_matrix_mult = cl.kernel(ctx, "matrix_mult");
        return true;
    }

    /**
     * Frees the kernels and closes the context if {@link #init(int)} obtained
     * a {@link CL} instance. Calling it when not initialized does nothing.
     */
    public void uninit() {
        if (cl == null) {
            return;
        }
        cl.freeKernel(ctx, k_array_square);
        cl.freeKernel(ctx, k_array_mult);
        cl.freeKernel(ctx, k_matrix_mult);
        cl.close(ctx);
        cl = null;
        // Drop the released handles so they cannot be reused by accident.
        ctx = 0;
        k_array_square = 0;
        k_array_mult = 0;
        k_matrix_mult = 0;
    }

    /**
     * Computes {@code b[i] = a[i] * a[i]} for every element.
     *
     * @param a input values
     * @param b receives the results; must have the same length as {@code a}.
     *          It is only written when the kernel executed successfully.
     * @return {@code true} if the kernel executed; {@code false} on invalid
     *         arguments, when not initialized, or when execution failed
     */
    public boolean array_square(float[] a, float[] b) {
        if (a == null) {
            JFLog.log("Compute:array_square:a invalid");
            return false;
        }
        int size = a.length;
        if (b == null || b.length != size) {
            JFLog.log("Compute:array_square:b invalid");
            return false;
        }
        if (!isReady("array_square") || !fitsInBuffer("array_square", size)) {
            return false;
        }
        int bytes = FLOAT_BYTES * size;
        long input0 = cl.createWriteBuffer(ctx, bytes);
        long output = cl.createReadBuffer(ctx, bytes);
        try {
            cl.writeBuffer(ctx, input0, a);
            cl.setArg(ctx, k_array_square, 0, input0);
            cl.setArg(ctx, k_array_square, 1, output);
            boolean ok = cl.execute(ctx, k_array_square, size);
            if (ok) {
                // Only copy back when the kernel actually produced results.
                cl.readBuffer(ctx, output, b);
            }
            return ok;
        } finally {
            // Release device buffers even if a CL call above throws.
            cl.freeBuffer(ctx, input0);
            cl.freeBuffer(ctx, output);
        }
    }

    /**
     * Computes the element-wise product {@code c[i] = a[i] * b[i]}.
     *
     * @param a first input
     * @param b second input; must have the same length as {@code a}
     * @param c receives the results; must have the same length as {@code a}.
     *          It is only written when the kernel executed successfully.
     * @return {@code true} if the kernel executed; {@code false} on invalid
     *         arguments, when not initialized, or when execution failed
     */
    public boolean array_mult(float[] a, float[] b, float[] c) {
        if (a == null) {
            JFLog.log("Compute:array_mult:a invalid");
            return false;
        }
        int size = a.length;
        if (b == null || b.length != size) {
            JFLog.log("Compute:array_mult:b invalid");
            return false;
        }
        if (c == null || c.length != size) {
            JFLog.log("Compute:array_mult:c invalid");
            return false;
        }
        if (!isReady("array_mult") || !fitsInBuffer("array_mult", size)) {
            return false;
        }
        int bytes = FLOAT_BYTES * size;
        long input0 = cl.createWriteBuffer(ctx, bytes);
        long input1 = cl.createWriteBuffer(ctx, bytes);
        long output = cl.createReadBuffer(ctx, bytes);
        try {
            cl.writeBuffer(ctx, input0, a);
            cl.writeBuffer(ctx, input1, b);
            cl.setArg(ctx, k_array_mult, 0, input0);
            cl.setArg(ctx, k_array_mult, 1, input1);
            cl.setArg(ctx, k_array_mult, 2, output);
            boolean ok = cl.execute(ctx, k_array_mult, size);
            if (ok) {
                cl.readBuffer(ctx, output, c);
            }
            return ok;
        } finally {
            cl.freeBuffer(ctx, input0);
            cl.freeBuffer(ctx, input1);
            cl.freeBuffer(ctx, output);
        }
    }

    /**
     * Multiplies two matrices on the device.
     *
     * <p>As indexed by the kernel, all three matrices are stored
     * column-major: element (row, k) of A is {@code a[k*as + row]}, element
     * (k, col) of B is {@code b[col*ks + k]}, and element (row, col) of the
     * result is written to {@code c[col*as + row]}.
     *
     * @param as number of rows of A and of the result (kernel argument M)
     * @param bs number of columns of B and of the result (kernel argument N)
     * @param ks shared dimension (kernel argument K)
     * @param a  A, length {@code as * ks}
     * @param b  B, length {@code bs * ks}
     * @param c  receives the result, length {@code as * bs}. It is only
     *           written when the kernel executed successfully.
     * @return {@code true} if the kernel executed; {@code false} on invalid
     *         arguments, when not initialized, or when execution failed
     */
    public boolean matrix_mult(int as, int bs, int ks, float[] a, float[] b, float[] c) {
        // Two negative dimensions would multiply to a plausible positive size.
        if (as < 0 || bs < 0 || ks < 0) {
            JFLog.log("Compute:matrix_mult:dimensions invalid");
            return false;
        }
        // Sizes are computed in long so the products cannot wrap around.
        long a_size = (long) as * ks;
        if (a == null || a.length != a_size) {
            JFLog.log("Compute:matrix_mult:a invalid");
            return false;
        }
        long b_size = (long) bs * ks;
        if (b == null || b.length != b_size) {
            JFLog.log("Compute:matrix_mult:b invalid");
            return false;
        }
        long c_size = (long) as * bs;
        if (c == null || c.length != c_size) {
            JFLog.log("Compute:matrix_mult:c invalid");
            return false;
        }
        if (!isReady("matrix_mult")
                || !fitsInBuffer("matrix_mult", a.length)
                || !fitsInBuffer("matrix_mult", b.length)
                || !fitsInBuffer("matrix_mult", c.length)) {
            return false;
        }
        long input0 = cl.createWriteBuffer(ctx, FLOAT_BYTES * a.length);
        long input1 = cl.createWriteBuffer(ctx, FLOAT_BYTES * b.length);
        long output = cl.createReadBuffer(ctx, FLOAT_BYTES * c.length);
        try {
            cl.writeBuffer(ctx, input0, a);
            cl.writeBuffer(ctx, input1, b);
            cl.setArg(ctx, k_matrix_mult, 0, as);
            cl.setArg(ctx, k_matrix_mult, 1, bs);
            cl.setArg(ctx, k_matrix_mult, 2, ks);
            cl.setArg(ctx, k_matrix_mult, 3, input0);
            cl.setArg(ctx, k_matrix_mult, 4, input1);
            cl.setArg(ctx, k_matrix_mult, 5, output);
            boolean ok = cl.execute2(ctx, k_matrix_mult, as, bs);
            if (ok) {
                cl.readBuffer(ctx, output, c);
            }
            return ok;
        } finally {
            cl.freeBuffer(ctx, input0);
            cl.freeBuffer(ctx, input1);
            cl.freeBuffer(ctx, output);
        }
    }

    /** Logs and returns {@code false} when no CL instance is held. */
    private boolean isReady(String method) {
        if (cl == null) {
            JFLog.log("Compute:" + method + ":not initialized");
            return false;
        }
        return true;
    }

    /** Logs and returns {@code false} when {@code count} floats exceed an int-sized byte length. */
    private static boolean fitsInBuffer(String method, int count) {
        if (count > MAX_FLOATS) {
            JFLog.log("Compute:" + method + ":array too large");
            return false;
        }
        return true;
    }
}

