Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
370 changes: 369 additions & 1 deletion Source/kronecker/GB_kron.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,71 @@
GB_Matrix_free (&T) ; \
}

// Local pattern-accessor helpers, following the GB_kernel conventions:
//   GBI: row index of the entry at position p; a full matrix has no Ai
//        array, so the index is p modulo the vector length
//   GBB: bitmap presence test at position p; always present if Ab is NULL
//   GBP: start offset of vector k in Ap; full matrices use k * vector length
//   GBH: vector id of the k-th stored vector; k itself when not hypersparse
// NOTE(review): the generated-kernel headers in GraphBLAS define macros with
// these same names -- confirm this translation unit sees no redefinition.
#define GBI(Ai,p,avlen) ((Ai) ? (Ai) [p] : ((p) % (avlen)))

#define GBB(Ab,p) ((Ab) ? (Ab) [p] : 1)

#define GBP(Ap,k,avlen) ((Ap) ? (Ap) [k] : ((k) * (avlen)))

#define GBH(Ah,k) ((Ah) ? (Ah) [k] : (k))

#include "kronecker/GB_kron.h"
#include "mxm/GB_mxm.h"
#include "transpose/GB_transpose.h"
#include "mask/GB_accum_mask.h"

//------------------------------------------------------------------------------
// GB_lookup_xoffset: find the offset of entry (row,col) in A->x
//------------------------------------------------------------------------------

// Returns true and sets *p to the offset of entry (row,col) in A->x (always
// 0 when A is iso); returns false when no entry is present there.  Handles
// all four sparsity formats: full and bitmap (A->p == NULL), sparse
// (A->h == NULL), and hypersparse.  The sparse and hypersparse cases differ
// only in how the vector position k is found, so the final intra-vector
// binary search is shared instead of duplicated per branch.
// NOTE(review): assumes A has no pending tuples, zombies, or jumbled
// vectors -- confirm the caller has finished A before using this lookup.

static bool GB_lookup_xoffset (
    GrB_Index* p,       // output: offset into A->x (0 if A is iso)
    GrB_Matrix A,       // matrix to search (not modified)
    GrB_Index row,
    GrB_Index col
)
{
    // map (row,col) to (vector id, index within the vector) per orientation
    GrB_Index vector = A->is_csc ? col : row ;
    GrB_Index coord = A->is_csc ? row : col ;

    if (A->p == NULL)
    {
        // full or bitmap: the offset is computed directly
        GrB_Index offset = vector * A->vlen + coord ;
        if (A->b != NULL && !((int8_t*)A->b)[offset])
        {
            // bitmap says the entry is not present
            return false ;
        }
        *p = A->iso ? 0 : offset ;
        return true ;
    }

    // sparse or hypersparse: find k, the position of 'vector' in A->p
    int64_t k ;
    if (A->h == NULL)
    {
        // sparse: every vector is stored, so k is the vector id itself
        k = (int64_t) vector ;
    }
    else
    {
        // hypersparse: binary search the hyperlist A->h for 'vector'
        int64_t kleft = 0, kright = A->plen - 1 ;
        if (!GB_binary_search (vector, A->h, A->j_is_32, &kleft, &kright))
        {
            // 'vector' holds no entries at all
            return false ;
        }
        k = kleft ;
    }

    // binary search A->i [start..end] for 'coord'
    int64_t start = A->p_is_32 ? ((uint32_t*)A->p)[k] : ((uint64_t*)A->p)[k] ;
    int64_t end = A->p_is_32 ? ((uint32_t*)A->p)[k+1] : ((uint64_t*)A->p)[k+1] ;
    end-- ;
    if (start > end)
    {
        // vector k is empty
        return false ;
    }
    if (!GB_binary_search (coord, A->i, A->i_is_32, &start, &end))
    {
        return false ;
    }
    *p = A->iso ? 0 : start ;
    return true ;
}

#include "emult/GB_emult.h"

GrB_Info GB_kron // C<M> = accum (C, kron(A,B))
(
GrB_Matrix C, // input/output matrix for results
Expand Down Expand Up @@ -104,6 +164,314 @@ GrB_Info GB_kron // C<M> = accum (C, kron(A,B))
// quick return if an empty mask is complemented
GB_RETURN_IF_QUICK_MASK (C, C_replace, M, Mask_comp, Mask_struct) ;

// check if it's possible to apply mask immediately in kron
// TODO: make MT of same CSR/CSC format as C

//--------------------------------------------------------------------------
// masked fast path: when a non-complemented mask M is present, construct
// MT = kron(A,B) restricted to the pattern of M, in two passes over M
// (count entries per vector, then fill indices/values), and hand MT off to
// GB_accum_mask.  Skipped entirely when M is NULL or the mask is
// complemented; control then falls through to the unmasked path below.
//--------------------------------------------------------------------------
GrB_Matrix MT;
if (M != NULL && !Mask_comp)
{
// iterate over mask, count how many elements will be present in MT
// initialize MT->p

// finish any pending work on M before reading its pattern directly.
// NOTE(review): A and B are not waited on here -- this assumes they carry
// no pending tuples/zombies when GB_lookup_xoffset reads them; confirm
// the caller (or earlier code in GB_kron) guarantees that.
GB_MATRIX_WAIT(M);

size_t allocated = 0 ;
bool MT_hypersparse = (A->h != NULL) || (B->h != NULL);
int64_t centries ;
uint64_t nvecs ;
centries = 0 ;
nvecs = 0 ;

// per-vector entry counts for MT, sized vdim+1 and zeroed so GB_cumsum
// can turn them into the MT->p offsets in place; integer width follows M
uint32_t* MTp32 = NULL ; uint64_t* MTp64 = NULL ;
MTp32 = M->p_is_32 ? GB_calloc_memory (M->vdim + 1, sizeof(uint32_t), &allocated) : NULL ;
MTp64 = M->p_is_32 ? NULL : GB_calloc_memory (M->vdim + 1, sizeof(uint64_t), &allocated) ;
if (MTp32 == NULL && MTp64 == NULL)
{
// backward-goto cleanup chain: OUT_OF_MEM_x -> OUT_OF_MEM_i ->
// OUT_OF_MEM_p, freeing in reverse order of allocation
OUT_OF_MEM_p:
GB_FREE_WORKSPACE ;
return GrB_OUT_OF_MEMORY ;
}

// if kron(A,B) is iso, MTscalar receives the single shared value
GrB_Type MTtype = op->ztype ;
const size_t MTsize = MTtype->size ;
GB_void MTscalar [GB_VLA(MTsize)] ;
bool MTiso = GB_emult_iso (MTscalar, MTtype, A, B, op) ;

// unpack M's p/h/i arrays into 32/64-bit typed views
GB_Mp_DECLARE(Mp, ) ;
GB_Mp_PTR(Mp, M) ;

GB_Mh_DECLARE(Mh, ) ;
GB_Mh_PTR(Mh, M) ;

GB_Mi_DECLARE(Mi, ) ;
GB_Mi_PTR(Mi, M) ;

// typecast functions from A/B storage types to the operator inputs
GB_cast_function cast_A = NULL ;
GB_cast_function cast_B = NULL ;

cast_A = GB_cast_factory (op->xtype->code, A->type->code) ;
cast_B = GB_cast_factory (op->ytype->code, B->type->code) ;

// pass 1: for each mask entry that is set, probe A and B; count hits
// per vector j.  The MTp[j]++ updates are race-free because each vector
// j belongs to exactly one iteration k of the parallel loop.
int64_t vlen = M->vlen ;
#pragma omp parallel
{
GrB_Index offset ;

#pragma omp for reduction(+:nvecs)
for (GrB_Index k = 0 ; k < M->nvec ; k++)
{
GrB_Index j = Mh32 ? GBH (Mh32, k) : GBH (Mh64, k) ;

int64_t pA_start = Mp32 ? GBP (Mp32, k, vlen) : GBP(Mp64, k, vlen) ;
int64_t pA_end = Mp32 ? GBP (Mp32, k+1, vlen) : GBP(Mp64, k+1, vlen) ;
bool nonempty = false ;
for (GrB_Index p = pA_start ; p < pA_end ; p++)
{
if (!GBB (M->b, p)) continue ;

int64_t i = Mi32 ? GBI (Mi32, p, vlen) : GBI (Mi64, p, vlen) ;
GrB_Index Mrow = M->is_csc ? i : j ; GrB_Index Mcol = M->is_csc ? j : i ;

// extract elements from A and B, increment MTp

// NOTE(review): a valued mask is read as int8_t here, which assumes
// M->x holds bool-sized values; GraphBLAS masks can have any type
// (cf. GB_MCAST) -- confirm callers guarantee a bool mask or
// Mask_struct on this path.
if (Mask_struct || (M->iso ? ((int8_t*)M->x)[0] : ((int8_t*)M->x)[p]))
{
// map the mask coordinate (Mrow,Mcol) of C back to the
// contributing entries of A (block index) and B (within-block)
GrB_Index arow = A_transpose ? (Mcol / bncols) : (Mrow / bnrows);
GrB_Index acol = A_transpose ? (Mrow / bnrows) : (Mcol / bncols);

GrB_Index brow = B_transpose ? (Mcol % bncols) : (Mrow % bnrows);
GrB_Index bcol = B_transpose ? (Mrow % bnrows) : (Mcol % bncols);

// only existence matters in this pass, so 'offset' is discarded
bool code = GB_lookup_xoffset(&offset, A, arow, acol) ;
if (!code)
{
continue;
}

code = GB_lookup_xoffset(&offset, B, brow, bcol) ;
if (!code)
{
continue;
}

if (M->p_is_32)
{
(MTp32[j])++ ;
}
else
{
(MTp64[j])++ ;
}
nonempty = true ;
}
}
if (nonempty) nvecs++ ;
}
}

// GB_cumsum for MT->p

// cumulative sum converts per-vector counts into MT->p offsets;
// MTp[vdim] then holds the total entry count
double work = M->vdim ;
int nthreads_max = GB_Context_nthreads_max ( ) ;
double chunk = GB_Context_chunk ( ) ;
int cumsum_threads = GB_nthreads (work, chunk, nthreads_max) ;
M->p_is_32 ? GB_cumsum(MTp32, M->p_is_32, M->vdim, NULL, cumsum_threads, Werk) :
GB_cumsum(MTp64, M->p_is_32, M->vdim, NULL, cumsum_threads, Werk) ;

centries = M->p_is_32 ? MTp32[M->vdim] : MTp64[M->vdim] ;

// row indices of MT; a zero-entry result legitimately leaves both NULL
uint32_t* MTi32 = NULL ; uint64_t* MTi64 = NULL;
MTi32 = M->i_is_32 ? GB_malloc_memory (centries, sizeof(uint32_t), &allocated) : NULL ;
MTi64 = M->i_is_32 ? NULL : GB_malloc_memory (centries, sizeof(uint64_t), &allocated) ;

if (centries > 0 && MTi32 == NULL && MTi64 == NULL)
{
OUT_OF_MEM_i:
if (M->p_is_32) { GB_free_memory (&MTp32, (M->vdim + 1) * sizeof(uint32_t)) ; }
else { GB_free_memory (&MTp64, (M->vdim + 1) * sizeof(uint64_t)) ; }
goto OUT_OF_MEM_p ;
}

// values of MT: one entry per result, or a single shared value if iso
void* MTx = NULL ;
if (!MTiso)
{
MTx = GB_malloc_memory (centries, op->ztype->size, &allocated) ;
}
else
{
MTx = GB_malloc_memory (1, op->ztype->size, &allocated) ;
if (MTx == NULL) goto OUT_OF_MEM_x ;
memcpy (MTx, MTscalar, MTsize) ;
}

if (centries > 0 && MTx == NULL)
{
OUT_OF_MEM_x:
if (M->i_is_32) { GB_free_memory (&MTi32, centries * sizeof(uint32_t)) ; }
else { GB_free_memory (&MTi64, centries * sizeof (uint64_t)) ; }
goto OUT_OF_MEM_i ;
}

// pass 2: same traversal and same filter as pass 1, so the entries land
// exactly in the [MTp[j], MTp[j+1]) slots reserved by the cumsum.
// a_elem/b_elem are per-thread scratch for the typecast operands.
#pragma omp parallel
{
GrB_Index offset ;
GB_void a_elem[op->xtype->size] ;
GB_void b_elem[op->ytype->size] ;

#pragma omp for
for (GrB_Index k = 0 ; k < M->nvec ; k++)
{
GrB_Index j = Mh32 ? GBH (Mh32, k) : GBH (Mh64, k) ;

int64_t pA_start = Mp32 ? GBP (Mp32, k, vlen) : GBP(Mp64, k, vlen) ;
int64_t pA_end = Mp32 ? GBP (Mp32, k+1, vlen) : GBP(Mp64, k+1, vlen) ;
// pos is a thread-local cursor; MTp itself is not modified here
GrB_Index pos = M->p_is_32 ? MTp32[j] : MTp64[j] ;
for (GrB_Index p = pA_start ; p < pA_end ; p++)
{
if (!GBB (M->b, p)) continue ;

int64_t i = Mi32 ? GBI (Mi32, p, vlen) : GBI (Mi64, p, vlen) ;
GrB_Index Mrow = M->is_csc ? i : j ; GrB_Index Mcol = M->is_csc ? j : i ;

// extract elements from A and B,
// initialize offset in MTi and MTx,
// get result of op, place it in MTx

if (Mask_struct || (M->iso ? ((int8_t*)M->x)[0] : ((int8_t*)M->x)[p]))
{
GrB_Index arow = A_transpose ? (Mcol / bncols) : (Mrow / bnrows);
GrB_Index acol = A_transpose ? (Mrow / bnrows) : (Mcol / bncols);

GrB_Index brow = B_transpose ? (Mcol % bncols) : (Mrow % bnrows);
GrB_Index bcol = B_transpose ? (Mrow % bnrows) : (Mcol % bncols);

bool code = GB_lookup_xoffset (&offset, A, arow, acol) ;
if (!code)
{
continue;
}
if (!MTiso)
cast_A (a_elem, A->x + offset * A->type->size, A->type->size) ;

// 'offset' is reused for B; A's value was already cast out above
code = GB_lookup_xoffset (&offset, B, brow, bcol) ;
if (!code)
{
continue;
}
if (!MTiso)
cast_B (b_elem, B->x + offset * B->type->size, B->type->size) ;

if (!MTiso)
{
if (op->binop_function)
{
op->binop_function (MTx + op->ztype->size * pos, a_elem, b_elem) ;
}
else
{
// positional (index-binary) op: pass the untransposed
// coordinates of the A and B entries
GrB_Index ix, iy, jx, jy ;
ix = A_transpose ? acol : arow ;
iy = A_transpose ? arow : acol ;
jx = B_transpose ? bcol : brow ;
jy = B_transpose ? brow : bcol ;
op->idxbinop_function (MTx + op->ztype->size * pos, a_elem, ix, iy,
b_elem, jx, jy, op->theta) ;
}
}

// row indices inherit M's order within the vector, so MT->i
// stays sorted without an extra sort pass
if (M->i_is_32) { MTi32[pos] = i ; } else { MTi64[pos] = i ; }
pos++ ;
}
}
}
}

#undef GBI
#undef GBB
#undef GBP
#undef GBH

// initialize other fields of MT properly

// NOTE(review): GB_new_bix allocates MT->i and MT->x only to have them
// freed and replaced below -- consider a header-only construction to
// avoid the throwaway allocation.
MT = NULL ;
GrB_Info MTalloc = GB_new_bix (&MT, op->ztype, vlen, M->vdim, GB_ph_null, M->is_csc,
GxB_SPARSE, true, M->hyper_switch, M->vdim, centries, true, MTiso,
M->p_is_32, M->j_is_32, M->i_is_32) ;
if (MTalloc != GrB_SUCCESS)
{
if (MTiso) { GB_free_memory (&MTx, op->ztype->size) ; }
else { GB_free_memory (&MTx, centries * op->ztype->size) ; }
goto OUT_OF_MEM_x ;
}

GB_MATRIX_WAIT(MT) ;

// discard the placeholder arrays and install the ones built above
GB_free_memory (&MT->i, MT->i_size) ;
GB_free_memory (&MT->x, MT->x_size) ;

MT->p = M->p_is_32 ? (void*)MTp32 : (void*)MTp64 ;
MT->i = M->i_is_32 ? (void*)MTi32 : (void*)MTi64 ;
MT->x = MTx ;

MT->p_size = (M->p_is_32 ? sizeof(uint32_t) : sizeof(uint64_t)) * (M->vdim + 1) ;
MT->i_size = ((M->i_is_32 ? sizeof(uint32_t) : sizeof(uint64_t)) * centries) ;
MT->x_size = MT->iso ? op->ztype->size : op->ztype->size * centries ;
MT->magic = GB_MAGIC ;
MT->nvals = centries ;
MT->nvec_nonempty = nvecs ;

// transpose and convert to hyper if needed

if (MT->is_csc != C->is_csc)
{
GrB_Info MTtranspose = GB_transpose_in_place (MT, true, Werk) ;
if (MTtranspose != GrB_SUCCESS)
{
GB_FREE_WORKSPACE ;
GB_Matrix_free (&MT) ;
return MTtranspose ;
}
}

// build an identity hyperlist and let GB_hyper_prune drop empty vectors
if (MT_hypersparse)
{
uint32_t* MTh32 = NULL ; uint64_t* MTh64 = NULL ;
if (MT->j_is_32)
{
MTh32 = GB_malloc_memory (MT->vdim, sizeof(uint32_t), &allocated) ;
}
else
{
MTh64 = GB_malloc_memory (MT->vdim, sizeof(uint64_t), &allocated) ;
}

if (MTh32 == NULL && MTh64 == NULL)
{
GB_FREE_WORKSPACE ;
GB_Matrix_free (&MT) ;
return GrB_OUT_OF_MEMORY ;
}

#pragma omp parallel for
for (GrB_Index i = 0; i < MT->vdim; i++)
{
if (MT->j_is_32) { MTh32[i] = i ; } else { MTh64[i] = i ; }
}

// NOTE(review): MT->h is installed but MT->h_size is never set, unlike
// p_size/i_size/x_size above -- a later free of MT->h may use a wrong
// size; confirm and set MT->h_size accordingly.
MT->h = MTh32 ? (void*)MTh32 : (void*)MTh64 ;

GrB_Info MThyperprune = GB_hyper_prune (MT, Werk) ;
if (MThyperprune != GrB_SUCCESS)
{
GB_FREE_WORKSPACE ;
GB_Matrix_free (&MT) ;
return MThyperprune ;
}
}

// hand MT to GB_accum_mask, which consumes (frees) it via &MT
return (GB_accum_mask (C, M, NULL, accum, &MT, C_replace, Mask_comp, Mask_struct, Werk)) ;
}

//--------------------------------------------------------------------------
// transpose A and B if requested
//--------------------------------------------------------------------------
Expand Down Expand Up @@ -153,7 +521,7 @@ GrB_Info GB_kron // C<M> = accum (C, kron(A,B))
GB_CLEAR_MATRIX_HEADER (T, &T_header) ;
GB_OK (GB_kroner (T, T_is_csc, op, flipij,
A_transpose ? AT : A, A_is_pattern,
B_transpose ? BT : B, B_is_pattern, Werk)) ;
B_transpose ? BT : B, B_is_pattern, M, Mask_comp, Mask_struct, Werk)) ;

GB_FREE_WORKSPACE ;
ASSERT_MATRIX_OK (T, "T = kron(A,B)", GB0) ;
Expand Down
Loading