diff --git a/Source/kronecker/GB_kron.c b/Source/kronecker/GB_kron.c
index 5c462da512..818b08ff93 100644
--- a/Source/kronecker/GB_kron.c
+++ b/Source/kronecker/GB_kron.c
@@ -23,11 +23,71 @@
 GB_Matrix_free (&T) ; \
 }
 
+#define GBI(Ai,p,avlen) ((Ai == NULL) ? ((p) % (avlen)) : Ai [p])
+
+#define GBB(Ab,p) ((Ab == NULL) ? 1 : Ab [p])
+
+#define GBP(Ap,k,avlen) ((Ap == NULL) ? ((k) * (avlen)) : Ap [k])
+
+#define GBH(Ah,k) ((Ah == NULL) ? (k) : Ah [k])
+
 #include "kronecker/GB_kron.h"
 #include "mxm/GB_mxm.h"
 #include "transpose/GB_transpose.h"
 #include "mask/GB_accum_mask.h"
+// look up entry (row,col) of A; on success set *p to the entry's offset in
+// A->x (0 when A is iso) and return true; return false if the entry is absent
+static bool GB_lookup_xoffset (
+    GrB_Index* p,
+    GrB_Matrix A,
+    GrB_Index row,
+    GrB_Index col
+)
+{
+    GrB_Index vector = A->is_csc ? col : row ;
+    GrB_Index coord = A->is_csc ? row : col ;
+
+    if (A->p == NULL)
+    {
+        // full or bitmap: offset is computed directly, gated by the bitmap
+        GrB_Index offset = vector * A->vlen + coord ;
+        if (A->b == NULL || ((int8_t*)A->b)[offset])
+        {
+            *p = A->iso ? 0 : offset ;
+            return true ;
+        }
+        return false ;
+    }
+
+    int64_t start, end ;
+    bool res ;
+
+    if (A->h == NULL)
+    {
+        // sparse: binary search for coord in the vector's index list
+        start = A->p_is_32 ? ((uint32_t*)A->p)[vector] : ((uint64_t*)A->p)[vector] ;
+        end = A->p_is_32 ? ((uint32_t*)A->p)[vector + 1] : ((uint64_t*)A->p)[vector + 1] ;
+        end-- ;
+        if (start > end) return false ;
+        res = GB_binary_search(coord, A->i, A->i_is_32, &start, &end) ;
+        if (res) { *p = A->iso ? 0 : start ; }
+        return res ;
+    }
+    else
+    {
+        // hypersparse: locate the vector in A->h first
+        start = 0 ; end = A->nvec - 1 ;  // FIX: only A->nvec entries of A->h are valid; A->plen is the allocated size and slots past nvec-1 are undefined
+        res = GB_binary_search(vector, A->h, A->j_is_32, &start, &end) ;
+        if (!res) return false ;
+        int64_t k = start ;
+        start = A->p_is_32 ? ((uint32_t*)A->p)[k] : ((uint64_t*)A->p)[k] ;
+        end = A->p_is_32 ? ((uint32_t*)A->p)[k+1] : ((uint64_t*)A->p)[k+1] ;
+        end-- ;
+        if (start > end) return false ;
+        res = GB_binary_search(coord, A->i, A->i_is_32, &start, &end) ;
+        if (res) { *p = A->iso ? 
0 : start ; } + return res ; + } +} + +#include "emult/GB_emult.h" + GrB_Info GB_kron // C = accum (C, kron(A,B)) ( GrB_Matrix C, // input/output matrix for results @@ -104,6 +164,314 @@ GrB_Info GB_kron // C = accum (C, kron(A,B)) // quick return if an empty mask is complemented GB_RETURN_IF_QUICK_MASK (C, C_replace, M, Mask_comp, Mask_struct) ; + // check if it's possible to apply mask immediately in kron + // TODO: make MT of same CSR/CSC format as C + + GrB_Matrix MT; + if (M != NULL && !Mask_comp) + { + // iterate over mask, count how many elements will be present in MT + // initialize MT->p + + GB_MATRIX_WAIT(M); + + size_t allocated = 0 ; + bool MT_hypersparse = (A->h != NULL) || (B->h != NULL); + int64_t centries ; + uint64_t nvecs ; + centries = 0 ; + nvecs = 0 ; + + uint32_t* MTp32 = NULL ; uint64_t* MTp64 = NULL ; + MTp32 = M->p_is_32 ? GB_calloc_memory (M->vdim + 1, sizeof(uint32_t), &allocated) : NULL ; + MTp64 = M->p_is_32 ? NULL : GB_calloc_memory (M->vdim + 1, sizeof(uint64_t), &allocated) ; + if (MTp32 == NULL && MTp64 == NULL) + { + OUT_OF_MEM_p: + GB_FREE_WORKSPACE ; + return GrB_OUT_OF_MEMORY ; + } + + GrB_Type MTtype = op->ztype ; + const size_t MTsize = MTtype->size ; + GB_void MTscalar [GB_VLA(MTsize)] ; + bool MTiso = GB_emult_iso (MTscalar, MTtype, A, B, op) ; + + GB_Mp_DECLARE(Mp, ) ; + GB_Mp_PTR(Mp, M) ; + + GB_Mh_DECLARE(Mh, ) ; + GB_Mh_PTR(Mh, M) ; + + GB_Mi_DECLARE(Mi, ) ; + GB_Mi_PTR(Mi, M) ; + + GB_cast_function cast_A = NULL ; + GB_cast_function cast_B = NULL ; + + cast_A = GB_cast_factory (op->xtype->code, A->type->code) ; + cast_B = GB_cast_factory (op->ytype->code, B->type->code) ; + + int64_t vlen = M->vlen ; + #pragma omp parallel + { + GrB_Index offset ; + + #pragma omp for reduction(+:nvecs) + for (GrB_Index k = 0 ; k < M->nvec ; k++) + { + GrB_Index j = Mh32 ? GBH (Mh32, k) : GBH (Mh64, k) ; + + int64_t pA_start = Mp32 ? GBP (Mp32, k, vlen) : GBP(Mp64, k, vlen) ; + int64_t pA_end = Mp32 ? 
GBP (Mp32, k+1, vlen) : GBP(Mp64, k+1, vlen) ; + bool nonempty = false ; + for (GrB_Index p = pA_start ; p < pA_end ; p++) + { + if (!GBB (M->b, p)) continue ; + + int64_t i = Mi32 ? GBI (Mi32, p, vlen) : GBI (Mi64, p, vlen) ; + GrB_Index Mrow = M->is_csc ? i : j ; GrB_Index Mcol = M->is_csc ? j : i ; + + // extract elements from A and B, increment MTp + + if (Mask_struct || (M->iso ? ((int8_t*)M->x)[0] : ((int8_t*)M->x)[p])) + { + GrB_Index arow = A_transpose ? (Mcol / bncols) : (Mrow / bnrows); + GrB_Index acol = A_transpose ? (Mrow / bnrows) : (Mcol / bncols); + + GrB_Index brow = B_transpose ? (Mcol % bncols) : (Mrow % bnrows); + GrB_Index bcol = B_transpose ? (Mrow % bnrows) : (Mcol % bncols); + + bool code = GB_lookup_xoffset(&offset, A, arow, acol) ; + if (!code) + { + continue; + } + + code = GB_lookup_xoffset(&offset, B, brow, bcol) ; + if (!code) + { + continue; + } + + if (M->p_is_32) + { + (MTp32[j])++ ; + } + else + { + (MTp64[j])++ ; + } + nonempty = true ; + } + } + if (nonempty) nvecs++ ; + } + } + + // GB_cumsum for MT->p + + double work = M->vdim ; + int nthreads_max = GB_Context_nthreads_max ( ) ; + double chunk = GB_Context_chunk ( ) ; + int cumsum_threads = GB_nthreads (work, chunk, nthreads_max) ; + M->p_is_32 ? GB_cumsum(MTp32, M->p_is_32, M->vdim, NULL, cumsum_threads, Werk) : + GB_cumsum(MTp64, M->p_is_32, M->vdim, NULL, cumsum_threads, Werk) ; + + centries = M->p_is_32 ? MTp32[M->vdim] : MTp64[M->vdim] ; + + uint32_t* MTi32 = NULL ; uint64_t* MTi64 = NULL; + MTi32 = M->i_is_32 ? GB_malloc_memory (centries, sizeof(uint32_t), &allocated) : NULL ; + MTi64 = M->i_is_32 ? 
NULL : GB_malloc_memory (centries, sizeof(uint64_t), &allocated) ; + + if (centries > 0 && MTi32 == NULL && MTi64 == NULL) + { + OUT_OF_MEM_i: + if (M->p_is_32) { GB_free_memory (&MTp32, (M->vdim + 1) * sizeof(uint32_t)) ; } + else { GB_free_memory (&MTp64, (M->vdim + 1) * sizeof(uint64_t)) ; } + goto OUT_OF_MEM_p ; + } + + void* MTx = NULL ; + if (!MTiso) + { + MTx = GB_malloc_memory (centries, op->ztype->size, &allocated) ; + } + else + { + MTx = GB_malloc_memory (1, op->ztype->size, &allocated) ; + if (MTx == NULL) goto OUT_OF_MEM_x ; + memcpy (MTx, MTscalar, MTsize) ; + } + + if (centries > 0 && MTx == NULL) + { + OUT_OF_MEM_x: + if (M->i_is_32) { GB_free_memory (&MTi32, centries * sizeof(uint32_t)) ; } + else { GB_free_memory (&MTi64, centries * sizeof (uint64_t)) ; } + goto OUT_OF_MEM_i ; + } + + #pragma omp parallel + { + GrB_Index offset ; + GB_void a_elem[op->xtype->size] ; + GB_void b_elem[op->ytype->size] ; + + #pragma omp for + for (GrB_Index k = 0 ; k < M->nvec ; k++) + { + GrB_Index j = Mh32 ? GBH (Mh32, k) : GBH (Mh64, k) ; + + int64_t pA_start = Mp32 ? GBP (Mp32, k, vlen) : GBP(Mp64, k, vlen) ; + int64_t pA_end = Mp32 ? GBP (Mp32, k+1, vlen) : GBP(Mp64, k+1, vlen) ; + GrB_Index pos = M->p_is_32 ? MTp32[j] : MTp64[j] ; + for (GrB_Index p = pA_start ; p < pA_end ; p++) + { + if (!GBB (M->b, p)) continue ; + + int64_t i = Mi32 ? GBI (Mi32, p, vlen) : GBI (Mi64, p, vlen) ; + GrB_Index Mrow = M->is_csc ? i : j ; GrB_Index Mcol = M->is_csc ? j : i ; + + // extract elements from A and B, + // initialize offset in MTi and MTx, + // get result of op, place it in MTx + + if (Mask_struct || (M->iso ? ((int8_t*)M->x)[0] : ((int8_t*)M->x)[p])) + { + GrB_Index arow = A_transpose ? (Mcol / bncols) : (Mrow / bnrows); + GrB_Index acol = A_transpose ? (Mrow / bnrows) : (Mcol / bncols); + + GrB_Index brow = B_transpose ? (Mcol % bncols) : (Mrow % bnrows); + GrB_Index bcol = B_transpose ? 
(Mrow % bnrows) : (Mcol % bncols); + + bool code = GB_lookup_xoffset (&offset, A, arow, acol) ; + if (!code) + { + continue; + } + if (!MTiso) + cast_A (a_elem, A->x + offset * A->type->size, A->type->size) ; + + code = GB_lookup_xoffset (&offset, B, brow, bcol) ; + if (!code) + { + continue; + } + if (!MTiso) + cast_B (b_elem, B->x + offset * B->type->size, B->type->size) ; + + if (!MTiso) + { + if (op->binop_function) + { + op->binop_function (MTx + op->ztype->size * pos, a_elem, b_elem) ; + } + else + { + GrB_Index ix, iy, jx, jy ; + ix = A_transpose ? acol : arow ; + iy = A_transpose ? arow : acol ; + jx = B_transpose ? bcol : brow ; + jy = B_transpose ? brow : bcol ; + op->idxbinop_function (MTx + op->ztype->size * pos, a_elem, ix, iy, + b_elem, jx, jy, op->theta) ; + } + } + + if (M->i_is_32) { MTi32[pos] = i ; } else { MTi64[pos] = i ; } + pos++ ; + } + } + } + } + + #undef GBI + #undef GBB + #undef GBP + #undef GBH + + // initialize other fields of MT properly + + MT = NULL ; + GrB_Info MTalloc = GB_new_bix (&MT, op->ztype, vlen, M->vdim, GB_ph_null, M->is_csc, + GxB_SPARSE, true, M->hyper_switch, M->vdim, centries, true, MTiso, + M->p_is_32, M->j_is_32, M->i_is_32) ; + if (MTalloc != GrB_SUCCESS) + { + if (MTiso) { GB_free_memory (&MTx, op->ztype->size) ; } + else { GB_free_memory (&MTx, centries * op->ztype->size) ; } + goto OUT_OF_MEM_x ; + } + + GB_MATRIX_WAIT(MT) ; + + GB_free_memory (&MT->i, MT->i_size) ; + GB_free_memory (&MT->x, MT->x_size) ; + + MT->p = M->p_is_32 ? (void*)MTp32 : (void*)MTp64 ; + MT->i = M->i_is_32 ? (void*)MTi32 : (void*)MTi64 ; + MT->x = MTx ; + + MT->p_size = (M->p_is_32 ? sizeof(uint32_t) : sizeof(uint64_t)) * (M->vdim + 1) ; + MT->i_size = ((M->i_is_32 ? sizeof(uint32_t) : sizeof(uint64_t)) * centries) ; + MT->x_size = MT->iso ? 
op->ztype->size : op->ztype->size * centries ; + MT->magic = GB_MAGIC ; + MT->nvals = centries ; + MT->nvec_nonempty = nvecs ; + + // transpose and convert to hyper if needed + + if (MT->is_csc != C->is_csc) + { + GrB_Info MTtranspose = GB_transpose_in_place (MT, true, Werk) ; + if (MTtranspose != GrB_SUCCESS) + { + GB_FREE_WORKSPACE ; + GB_Matrix_free (&MT) ; + return MTtranspose ; + } + } + + if (MT_hypersparse) + { + uint32_t* MTh32 = NULL ; uint64_t* MTh64 = NULL ; + if (MT->j_is_32) + { + MTh32 = GB_malloc_memory (MT->vdim, sizeof(uint32_t), &allocated) ; + } + else + { + MTh64 = GB_malloc_memory (MT->vdim, sizeof(uint64_t), &allocated) ; + } + + if (MTh32 == NULL && MTh64 == NULL) + { + GB_FREE_WORKSPACE ; + GB_Matrix_free (&MT) ; + return GrB_OUT_OF_MEMORY ; + } + + #pragma omp parallel for + for (GrB_Index i = 0; i < MT->vdim; i++) + { + if (MT->j_is_32) { MTh32[i] = i ; } else { MTh64[i] = i ; } + } + + MT->h = MTh32 ? (void*)MTh32 : (void*)MTh64 ; + + GrB_Info MThyperprune = GB_hyper_prune (MT, Werk) ; + if (MThyperprune != GrB_SUCCESS) + { + GB_FREE_WORKSPACE ; + GB_Matrix_free (&MT) ; + return MThyperprune ; + } + } + + return (GB_accum_mask (C, M, NULL, accum, &MT, C_replace, Mask_comp, Mask_struct, Werk)) ; + } + //-------------------------------------------------------------------------- // transpose A and B if requested //-------------------------------------------------------------------------- @@ -153,7 +521,7 @@ GrB_Info GB_kron // C = accum (C, kron(A,B)) GB_CLEAR_MATRIX_HEADER (T, &T_header) ; GB_OK (GB_kroner (T, T_is_csc, op, flipij, A_transpose ? AT : A, A_is_pattern, - B_transpose ? BT : B, B_is_pattern, Werk)) ; + B_transpose ? 
BT : B, B_is_pattern, M, Mask_comp, Mask_struct, Werk)) ;
 
     GB_FREE_WORKSPACE ;
     ASSERT_MATRIX_OK (T, "T = kron(A,B)", GB0) ;
diff --git a/Source/kronecker/GB_kron.h b/Source/kronecker/GB_kron.h
index a6c1a3f3ae..d1628b0899 100644
--- a/Source/kronecker/GB_kron.h
+++ b/Source/kronecker/GB_kron.h
@@ -37,6 +37,9 @@ GrB_Info GB_kroner // C = kron (A,B)
     bool A_is_pattern, // true if values of A are not used
     const GrB_Matrix B, // input matrix
     bool B_is_pattern, // true if values of B are not used
+    const GrB_Matrix Mask,
+    const bool Mask_comp,
+    const bool Mask_struct,
     GB_Werk Werk
 ) ;
diff --git a/Source/kronecker/GB_kroner.c b/Source/kronecker/GB_kroner.c
index afbb655ee4..4fb265b887 100644
--- a/Source/kronecker/GB_kroner.c
+++ b/Source/kronecker/GB_kroner.c
@@ -39,6 +39,9 @@ GrB_Info GB_kroner // C = kron (A,B)
     bool A_is_pattern, // true if values of A are not used
     const GrB_Matrix B_in, // input matrix
     bool B_is_pattern, // true if values of B are not used
+    const GrB_Matrix Mask,
+    const bool Mask_comp,
+    const bool Mask_struct,
     GB_Werk Werk
 )
 {
diff --git a/Test/test226.m b/Test/test226.m
index 41f77fd3d3..4778425a8a 100644
--- a/Test/test226.m
+++ b/Test/test226.m
@@ -9,6 +9,8 @@
 A.matrix = sprand (5, 10, 0.4) ;
 B.matrix = ones (3, 2) ;
 B.iso = true ;
+M.matrix = sprandn (15, 20, 0.2) ~= 0 ;
+MT.matrix = sprandn (9, 4, 0.2) ~= 0 ;   % fix: sprandn is (m,n,density); the stray 20 invoked the 4-arg (m,n,density,rc) form with density 20
 
 mult.opname = 'times' ;
 mult.optype = 'double' ;
@@ -18,14 +20,26 @@
 C2 = GB_spec_kron (Cin, [ ], [ ], mult, A, B, [ ]) ;
 GB_spec_compare (C1, C2) ;
 
+C1 = GB_mex_kron (Cin, M, [ ], mult, A, B, [ ]) ;
+C2 = GB_spec_kron (Cin, M, [ ], mult, A, B, [ ]) ;
+GB_spec_compare (C1, C2) ;
+
 C1 = GB_mex_kron (Cin, [ ], [ ], mult, B, A, [ ]) ;
 C2 = GB_spec_kron (Cin, [ ], [ ], mult, B, A, [ ]) ;
 GB_spec_compare (C1, C2) ;
 
+C1 = GB_mex_kron (Cin, M, [ ], mult, B, A, [ ]) ;
+C2 = GB_spec_kron (Cin, M, [ ], mult, B, A, [ ]) ;
+GB_spec_compare (C1, C2) ;
+
 Cin = sparse (9, 4) ;
 C1 = GB_mex_kron (Cin, [ ], [ ], mult, B, B, [ ]) ;
 C2 = GB_spec_kron 
(Cin, [ ], [ ], mult, B, B, [ ]) ; GB_spec_compare (C1, C2) ; +C1 = GB_mex_kron (Cin, MT, [ ], mult, B, B, [ ]) ; +C2 = GB_spec_kron (Cin, MT, [ ], mult, B, B, [ ]) ; +GB_spec_compare (C1, C2) ; + fprintf ('\ntest226: all tests passed\n') ; diff --git a/Test/test227.m b/Test/test227.m index 98e75da2e2..50dd0153d1 100644 --- a/Test/test227.m +++ b/Test/test227.m @@ -16,6 +16,16 @@ dnt = struct ( 'inp1', 'tran' ) ; dtt = struct ( 'inp0', 'tran', 'inp1', 'tran' ) ; +dnn.mask = 'default' ; +dtn.mask = 'default' ; +dnt.mask = 'default' ; +dtt.mask = 'default' ; + +dnn.outp = 'default' ; +dtn.outp = 'default' ; +dnt.outp = 'default' ; +dtt.outp = 'default' ; + types = { 'int32', 'int64', 'single', 'double' } ; am = 5 ; @@ -23,13 +33,18 @@ bm = 4 ; bn = 2 ; -Ax = sparse (100 * sprandn (am,an, 0.5)) ; -Bx = sparse (100 * sprandn (bm,bn, 0.5)) ; +Ax_temp = 100 * sprandn (am, an, 0.5); +Bx_temp = 100 * sprandn (bm, bn, 0.5); + +Ax = sparse(round(Ax_temp)); +Bx = sparse(round(Bx_temp)); + cm = am * bm ; cn = an * bn ; Cx = sparse (cm,cn) ; -AT = Ax' ; -BT = Bx' ; +Maskmat = sprandn (cm,cn,0.2) ~= 0 ; +ATmat = Ax' ; +BTmat = Bx' ; for k2 = [4 7 45:52 ] for k1 = 1:4 @@ -64,9 +79,19 @@ B.is_csc = B_is_csc ; clear C - C.matrix = Cx ; + C.matrix = sparse (cm,cn) ; C.is_csc = C_is_csc ; + clear AT + AT.matrix = ATmat ; + AT.is_hyper = A_is_hyper ; + AT.is_csc = A.is_csc ; + + clear BT + BT.matrix = BTmat ; + BT.is_hyper = B_is_hyper ; + BT.is_csc = B.is_csc ; + %--------------------------------------- % kron(A,B) %--------------------------------------- @@ -103,6 +128,44 @@ C1 = GB_mex_kron (C, [ ], [ ], op, AT, BT, dtt) ; GB_spec_compare (C0, C1) ; + % tests with Mask + for Mask_is_hyper = 0:1 + for Mask_is_csc = 0:1 + fprintf('*') + + A.is_csc = A_is_csc ; + B.is_csc = B_is_csc ; + A.is_hyper = A_is_hyper ; + B.is_hyper = B_is_hyper ; + + clear M + M.matrix = Maskmat ; + M.is_hyper = Mask_is_hyper ; + M.is_csc = Mask_is_csc; + C.is_csc = C_is_csc ; + + % kron(A, B) with Mask + C0 
= GB_spec_kron (C, M, [ ], op, A, B, dnn) ; + fprintf('#') ; + C1 = GB_mex_kron (C, M, [ ], op, A, B, dnn) ; + GB_spec_compare(C0, C1) ; + + % kron(A', B) with Mask + C0 = GB_spec_kron (C, M, [ ], op, AT, B, dtn) ; + C1 = GB_mex_kron (C, M, [ ], op, AT, B, dtn) ; + GB_spec_compare (C0, C1) ; + + % kron(A, B') with Mask + C0 = GB_spec_kron (C, M, [ ], op, A, BT, dnt) ; + C1 = GB_mex_kron (C, M, [ ], op, A, BT, dnt) ; + GB_spec_compare (C0, C1) ; + + % kron(A', B') with Mask + C0 = GB_spec_kron (C, M, [ ], op, AT, BT, dtt) ; + C1 = GB_mex_kron (C, M, [ ], op, AT, BT, dtt) ; + GB_spec_compare (C0, C1) ; + end + end end end end