Sunday, October 22, 2017

arrayfire batch multiplication implementation

github repo

arrayfire batch matmul


Batched matrix multiplication for ArrayFire, implemented with cuBLAS.

supported data types and structure

The function names are af::matmul3CNN, af::matmul3CTN, af::matmul3CNT.

The '3' means "third dimension".

'N' means "normal" (no transpose), and 'T' means the Hermitian (conjugate) transpose.

See the cuBLAS gemm documentation for details.
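The following plain C++ sketch (no ArrayFire or cuBLAS needed) makes the 'N'/'T' convention concrete by computing op(X) for a single column-major complex matrix. The name apply_op is hypothetical, and treating 'T' as the conjugate transpose follows the description above. In cuBLAS terms this is CUBLAS_OP_C rather than CUBLAS_OP_T. It is an illustration only, not code from the repository.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

using cf = std::complex<float>;

// Hypothetical helper: returns op(X) for a column-major rows x cols matrix X.
// 'N' leaves X unchanged; 'T' is taken to be the conjugate (Hermitian)
// transpose, as described above.
std::vector<cf> apply_op(char op, const std::vector<cf>& x,
                         std::size_t rows, std::size_t cols) {
    assert(x.size() == rows * cols);
    if (op == 'N') return x;
    assert(op == 'T');
    // The result is cols x rows (column-major): result(j, i) = conj(x(i, j)).
    std::vector<cf> out(rows * cols);
    for (std::size_t j = 0; j < cols; ++j)
        for (std::size_t i = 0; i < rows; ++i)
            out[i * cols + j] = std::conj(x[j * rows + i]);
    return out;
}
```

For X = [1+i 3; 2 4-i], applying 'T' yields [1-i 2; 3 4+i]. The imaginary parts change sign, which is what distinguishes this operation from a plain transpose.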

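The batch multiplication works slice by slice along the third dimension: for each slice i, C(:,:,i) = op(A(:,:,i)) x op(B(:,:,i)), where op is the 'N' or 'T' operation selected by the function name. Without transposes (matmul3CNN), the shapes are:

- A (input matrix): (n x m x l)
- B (input matrix): (m x k x l)
- C (output matrix): (n x k x l)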
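The slice-by-slice computation can also be written as a small CPU reference in plain C++, which is useful for checking GPU results on small inputs. This is a sketch under stated assumptions. It covers only the NN case with c32-style std::complex<float> data, and it writes a fresh C instead of accumulating into an existing one. The name batch_matmul_nn is hypothetical; it is not the cuBLAS-based implementation from the repository.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

using cf = std::complex<float>;

// Hypothetical CPU reference for the NN batch product: for each slice s,
// C(:,:,s) = A(:,:,s) * B(:,:,s), with A n x m, B m x k, C n x k.
// Every slice is column-major, and slices follow one another in memory,
// which is how ArrayFire stores an (n x m x l) array.
std::vector<cf> batch_matmul_nn(const std::vector<cf>& a, const std::vector<cf>& b,
                                std::size_t n, std::size_t m, std::size_t k,
                                std::size_t l) {
    assert(a.size() == n * m * l && b.size() == m * k * l);
    std::vector<cf> c(n * k * l, cf(0.0f, 0.0f));
    for (std::size_t s = 0; s < l; ++s) {
        const cf* as = a.data() + s * n * m;  // first element of slice s of A
        const cf* bs = b.data() + s * m * k;  // first element of slice s of B
        cf* cs = c.data() + s * n * k;        // first element of slice s of C
        for (std::size_t j = 0; j < k; ++j)
            for (std::size_t i = 0; i < n; ++i) {
                cf sum(0.0f, 0.0f);
                for (std::size_t p = 0; p < m; ++p)
                    sum += as[p * n + i] * bs[j * m + p];  // A(i,p) * B(p,j)
                cs[j * n + i] = sum;                       // C(i,j)
            }
    }
    return c;
}
```

Using the data from the example below (two identical slices of A = {1, 2, 3, 4} and B = {5, 6, 7, 8}), both slices come out as 23, 34, 31, 46 in column-major order, which matches the expected output there.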

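Besides a cuBLAS handle (the first argument in the example below), each af::matmul3CXX function takes three arrays: A, B and C.

All three arrays must be allocated beforehand with the types and shapes below. The function writes the product into C and also returns it.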

| matrix | type                 | dimension   |
|--------|----------------------|-------------|
| A      | af::array (c32, c64) | (n x m x l) |
| B      | af::array (c32, c64) | (m x k x l) |
| C      | af::array (c32, c64) | (n x k x l) |
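Because the caller allocates C, a shape mismatch is an easy mistake to make. The sketch below shows a possible pre-call check for the NN shapes in the table. Dims3 and nn_shapes_ok are hypothetical helpers and are not part of the repository or of ArrayFire.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper type: the three extents of a 3-D array.
struct Dims3 {
    std::size_t d0, d1, d2;
};

// Checks the NN shapes from the table:
// A is (n x m x l), B is (m x k x l), and C must already be (n x k x l).
bool nn_shapes_ok(const Dims3& a, const Dims3& b, const Dims3& c) {
    return a.d1 == b.d0   // inner dimension m must agree
        && a.d2 == b.d2   // A and B have the same batch count l
        && c.d0 == a.d0   // C has n rows
        && c.d1 == b.d1   // C has k columns
        && c.d2 == a.d2;  // C has l slices
}
```

With real af::array objects, the three extents would come from x.dims(0), x.dims(1) and x.dims(2).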

precaution

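Real-valued matrices are not supported, and neither is a plain (non-conjugate) transpose of them. Inputs must be c32 or c64, so real data has to be converted first, as the example does with .as(c32).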
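Since only complex arrays are accepted, real data has to be promoted to complex, with imaginary parts set to zero. For such data, conjugation changes nothing, so a conjugate transpose gives the same values as an ordinary transpose. Assuming 'T' applies the conjugate transpose as described above, the T variants should therefore act as a plain transpose on promoted real inputs. The sketch below checks that identity in plain C++. The helpers to_complex, transpose and conj_transpose are hypothetical.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

using cf = std::complex<float>;

// Promote real values to complex ones with a zero imaginary part
// (the same idea as calling .as(c32) on an f32 af::array).
std::vector<cf> to_complex(const std::vector<float>& x) {
    return std::vector<cf>(x.begin(), x.end());
}

// Plain transpose of a column-major rows x cols real matrix (result: cols x rows).
std::vector<float> transpose(const std::vector<float>& x,
                             std::size_t rows, std::size_t cols) {
    assert(x.size() == rows * cols);
    std::vector<float> out(rows * cols);
    for (std::size_t j = 0; j < cols; ++j)
        for (std::size_t i = 0; i < rows; ++i)
            out[i * cols + j] = x[j * rows + i];
    return out;
}

// Conjugate (Hermitian) transpose of a column-major rows x cols complex matrix.
std::vector<cf> conj_transpose(const std::vector<cf>& x,
                               std::size_t rows, std::size_t cols) {
    assert(x.size() == rows * cols);
    std::vector<cf> out(rows * cols);
    for (std::size_t j = 0; j < cols; ++j)
        for (std::size_t i = 0; i < rows; ++i)
            out[i * cols + j] = std::conj(x[j * rows + i]);
    return out;
}
```

For any real matrix r, conj_transpose(to_complex(r), rows, cols) equals to_complex(transpose(r, rows, cols)). The only possible difference is the sign of a zero imaginary part, and -0.0f compares equal to 0.0f.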

example

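#include <cstdio>
#include <arrayfire.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
// Also include the header from the repository that declares
// af::matmul3CNN / af::matmul3CTN / af::matmul3CNT.

int main(void)
{
    // The batch functions call cuBLAS on the array data, so use ArrayFire's CUDA backend.
    af::setBackend(AF_BACKEND_CUDA);

    // Select the GPU before creating the cuBLAS handle; a handle is bound to the current device.
    cudaError_t cudaStat = cudaSetDevice(0);
    if (cudaStat != cudaSuccess) {
        fprintf(stderr, "cudaSetDevice failed! Do you have a CUDA-capable GPU installed?\n");
        return 1;
    }

    cublasHandle_t handle;
    if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "cublasCreate failed!\n");
        return 1;
    }

    // ArrayFire is column-major, so A = [1 3; 2 4] and B = [5 7; 6 8].
    float Adata[4] = { 1, 2, 3, 4 };
    float Bdata[4] = { 5, 6, 7, 8 };

    // Only complex arrays are supported, so convert the inputs to c32.
    auto A = af::array(2, 2, Adata, afHost).as(c32);
    auto B = af::array(2, 2, Bdata, afHost).as(c32);

    af_print(A);
    af_print(B);

    // Reference result from ArrayFire's own (non-batched) matmul.
    af_print(af::matmul(A, B));

    /*
    expect
        (23.0000,0.0000)          (31.0000,0.0000)
        (34.0000,0.0000)          (46.0000,0.0000)
    */

    // Batched matmul: stack two copies of A and B along the third dimension
    // and preallocate C with the matching (2 x 2 x 2) shape.
    auto batchA = af::tile(A, 1, 1, 2);
    auto batchB = af::tile(B, 1, 1, 2);
    auto batchC = af::constant(0.0f, 2, 2, 2, c32);

    // Each slice of batchC = matching slice of batchA x matching slice of batchB.
    af_print(af::matmul3CNN(handle, batchA, batchB, batchC)); // it returns batchC.

    /*
    expect
         (23.0000,0.0000)          (31.0000,0.0000)
         (34.0000,0.0000)          (46.0000,0.0000)
         (23.0000,0.0000)          (31.0000,0.0000)
         (34.0000,0.0000)          (46.0000,0.0000)

    */

    cublasDestroy(handle);
    return 0;
}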
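The expected output follows from ArrayFire's column-major layout. af::array(2, 2, Adata, afHost) fills the matrix column by column, so Adata = {1, 2, 3, 4} becomes [1 3; 2 4], not [1 2; 3 4]. A row-major reading would give [19 22; 43 50] for A x B instead of [23 31; 34 46]. The sketch below shows the index mapping and recomputes individual entries in plain C++. The helper names at_col_major and product_entry are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Element (i, j) of a column-major matrix with `rows` rows; this is the order
// in which af::array(rows, cols, ptr, afHost) reads the host buffer.
float at_col_major(const float* data, std::size_t rows, std::size_t i, std::size_t j) {
    return data[j * rows + i];
}

// Entry (i, j) of C = A * B, where A is n x m and B is m x k, both column-major.
float product_entry(const float* a, const float* b, std::size_t n, std::size_t m,
                    std::size_t i, std::size_t j) {
    float sum = 0.0f;
    for (std::size_t p = 0; p < m; ++p)
        sum += at_col_major(a, n, i, p) * at_col_major(b, m, p, j);  // A(i,p) * B(p,j)
    return sum;
}
```

For example, entry (0, 1) is A(0,0) * B(0,1) + A(0,1) * B(1,1) = 1 * 7 + 3 * 8 = 31, which matches the first row of the printed result.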