Actual source code: lrc.c


  2: #include <petsc/private/matimpl.h>

  4: PETSC_EXTERN PetscErrorCode VecGetRootType_Private(Vec, VecType *);

  6: typedef struct {
  7:   Mat A;            /* sparse matrix */
  8:   Mat U, V;         /* dense tall-skinny matrices */
  9:   Vec c;            /* sequential vector containing the diagonal of C */
 10:   Vec work1, work2; /* sequential vectors that hold partial products */
 11:   Vec xl, yl;       /* auxiliary sequential vectors for matmult operation */
 12: } Mat_LRC;

 14: static PetscErrorCode MatMult_LRC_kernel(Mat N, Vec x, Vec y, PetscBool transpose)
 15: {
 16:   Mat_LRC    *Na = (Mat_LRC *)N->data;
 17:   PetscMPIInt size;
 18:   Mat         U, V;

 20:   PetscFunctionBegin;
 21:   U = transpose ? Na->V : Na->U;
 22:   V = transpose ? Na->U : Na->V;
 23:   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)N), &size));
 24:   if (size == 1) {
 25:     PetscCall(MatMultHermitianTranspose(V, x, Na->work1));
 26:     if (Na->c) PetscCall(VecPointwiseMult(Na->work1, Na->c, Na->work1));
 27:     if (Na->A) {
 28:       if (transpose) {
 29:         PetscCall(MatMultTranspose(Na->A, x, y));
 30:       } else {
 31:         PetscCall(MatMult(Na->A, x, y));
 32:       }
 33:       PetscCall(MatMultAdd(U, Na->work1, y, y));
 34:     } else {
 35:       PetscCall(MatMult(U, Na->work1, y));
 36:     }
 37:   } else {
 38:     Mat                Uloc, Vloc;
 39:     Vec                yl, xl;
 40:     const PetscScalar *w1;
 41:     PetscScalar       *w2;
 42:     PetscInt           nwork;
 43:     PetscMPIInt        mpinwork;

 45:     xl = transpose ? Na->yl : Na->xl;
 46:     yl = transpose ? Na->xl : Na->yl;
 47:     PetscCall(VecGetLocalVector(y, yl));
 48:     PetscCall(MatDenseGetLocalMatrix(U, &Uloc));
 49:     PetscCall(MatDenseGetLocalMatrix(V, &Vloc));

 51:     /* multiply the local part of V with the local part of x */
 52:     PetscCall(VecGetLocalVectorRead(x, xl));
 53:     PetscCall(MatMultHermitianTranspose(Vloc, xl, Na->work1));
 54:     PetscCall(VecRestoreLocalVectorRead(x, xl));

 56:     /* form the sum of all the local multiplies: this is work2 = V'*x =
 57:        sum_{all processors} work1 */
 58:     PetscCall(VecGetArrayRead(Na->work1, &w1));
 59:     PetscCall(VecGetArrayWrite(Na->work2, &w2));
 60:     PetscCall(VecGetLocalSize(Na->work1, &nwork));
 61:     PetscCall(PetscMPIIntCast(nwork, &mpinwork));
 62:     PetscCall(MPIU_Allreduce(w1, w2, mpinwork, MPIU_SCALAR, MPIU_SUM, PetscObjectComm((PetscObject)N)));
 63:     PetscCall(VecRestoreArrayRead(Na->work1, &w1));
 64:     PetscCall(VecRestoreArrayWrite(Na->work2, &w2));

 66:     if (Na->c) { /* work2 = C*work2 */
 67:       PetscCall(VecPointwiseMult(Na->work2, Na->c, Na->work2));
 68:     }

 70:     if (Na->A) {
 71:       /* form y = A*x or A^t*x */
 72:       if (transpose) {
 73:         PetscCall(MatMultTranspose(Na->A, x, y));
 74:       } else {
 75:         PetscCall(MatMult(Na->A, x, y));
 76:       }
 77:       /* multiply-add y = y + U*work2 */
 78:       PetscCall(MatMultAdd(Uloc, Na->work2, yl, yl));
 79:     } else {
 80:       /* multiply y = U*work2 */
 81:       PetscCall(MatMult(Uloc, Na->work2, yl));
 82:     }

 84:     PetscCall(VecRestoreLocalVector(y, yl));
 85:   }
 86:   PetscFunctionReturn(PETSC_SUCCESS);
 87: }

 89: static PetscErrorCode MatMult_LRC(Mat N, Vec x, Vec y)
 90: {
 91:   PetscFunctionBegin;
 92:   PetscCall(MatMult_LRC_kernel(N, x, y, PETSC_FALSE));
 93:   PetscFunctionReturn(PETSC_SUCCESS);
 94: }

 96: static PetscErrorCode MatMultTranspose_LRC(Mat N, Vec x, Vec y)
 97: {
 98:   PetscFunctionBegin;
 99:   PetscCall(MatMult_LRC_kernel(N, x, y, PETSC_TRUE));
100:   PetscFunctionReturn(PETSC_SUCCESS);
101: }

103: static PetscErrorCode MatDestroy_LRC(Mat N)
104: {
105:   Mat_LRC *Na = (Mat_LRC *)N->data;

107:   PetscFunctionBegin;
108:   PetscCall(MatDestroy(&Na->A));
109:   PetscCall(MatDestroy(&Na->U));
110:   PetscCall(MatDestroy(&Na->V));
111:   PetscCall(VecDestroy(&Na->c));
112:   PetscCall(VecDestroy(&Na->work1));
113:   PetscCall(VecDestroy(&Na->work2));
114:   PetscCall(VecDestroy(&Na->xl));
115:   PetscCall(VecDestroy(&Na->yl));
116:   PetscCall(PetscFree(N->data));
117:   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatLRCGetMats_C", NULL));
118:   PetscFunctionReturn(PETSC_SUCCESS);
119: }

121: static PetscErrorCode MatLRCGetMats_LRC(Mat N, Mat *A, Mat *U, Vec *c, Mat *V)
122: {
123:   Mat_LRC *Na = (Mat_LRC *)N->data;

125:   PetscFunctionBegin;
126:   if (A) *A = Na->A;
127:   if (U) *U = Na->U;
128:   if (c) *c = Na->c;
129:   if (V) *V = Na->V;
130:   PetscFunctionReturn(PETSC_SUCCESS);
131: }

133: /*@
134:    MatLRCGetMats - Returns the constituents of an LRC matrix

136:    Collective

138:    Input Parameter:
139: .  N - matrix of type `MATLRC`

141:    Output Parameters:
142: +  A - the (sparse) matrix
143: .  U - first dense rectangular (tall and skinny) matrix
144: .  c - a sequential vector containing the diagonal of C
145: -  V - second dense rectangular (tall and skinny) matrix

147:    Level: intermediate

149:    Notes:
150:    The returned matrices need not be destroyed by the caller.

152:    `U`, `c`, `V` may be `NULL` if not needed

154: .seealso: [](ch_matrices), `Mat`, `MATLRC`, `MatCreateLRC()`
155: @*/
156: PetscErrorCode MatLRCGetMats(Mat N, Mat *A, Mat *U, Vec *c, Mat *V)
157: {
158:   PetscFunctionBegin;
159:   PetscUseMethod(N, "MatLRCGetMats_C", (Mat, Mat *, Mat *, Vec *, Mat *), (N, A, U, c, V));
160:   PetscFunctionReturn(PETSC_SUCCESS);
161: }

163: /*MC
  MATLRC - "lrc" - a matrix object that behaves like A + U*C*V'

  Note:
   The matrix A + U*C*V' is not formed! Rather, the matrix object performs the matrix-vector products `MatMult()` and
   `MatMultTranspose()` by first multiplying by A and then adding the low-rank term.

170:   Level: advanced

172: .seealso: [](ch_matrices), `Mat`, `MatCreateLRC()`, `MatMult()`, `MatLRCGetMats()`
173: M*/

175: /*@
176:    MatCreateLRC - Creates a new matrix object that behaves like A + U*C*V' of type `MATLRC`

178:    Collective

180:    Input Parameters:
181: +  A    - the (sparse) matrix (can be `NULL`)
182: .  U    - dense rectangular (tall and skinny) matrix
183: .  V    - dense rectangular (tall and skinny) matrix
184: -  c    - a vector containing the diagonal of C (can be `NULL`)

186:    Output Parameter:
187: .  N    - the matrix that represents A + U*C*V'

189:    Level: intermediate

191:    Notes:
192:    The matrix A + U*C*V' is not formed! Rather the new matrix
193:    object performs the matrix-vector product `MatMult()`, by first multiplying by
194:    A and then adding the other term.

196:    `C` is a diagonal matrix (represented as a vector) of order k,
197:    where k is the number of columns of both `U` and `V`.

199:    If `A` is `NULL` then the new object behaves like a low-rank matrix U*C*V'.

201:    Use `V`=`U` (or `V`=`NULL`) for a symmetric low-rank correction, A + U*C*U'.

203:    If `c` is `NULL` then the low-rank correction is just U*V'.
204:    If a sequential `c` vector is used for a parallel matrix,
205:    PETSc assumes that the values of the vector are consistently set across processors.

207: .seealso: [](ch_matrices), `Mat`, `MATLRC`, `MatLRCGetMats()`
208: @*/
209: PetscErrorCode MatCreateLRC(Mat A, Mat U, Vec c, Mat V, Mat *N)
210: {
211:   PetscBool   match;
212:   PetscInt    m, n, k, m1, n1, k1;
213:   Mat_LRC    *Na;
214:   Mat         Uloc;
215:   PetscMPIInt size, csize = 0;

217:   PetscFunctionBegin;
221:   if (V) {
223:     PetscCheckSameComm(U, 2, V, 4);
224:   }
225:   if (A) PetscCheckSameComm(A, 1, U, 2);

227:   if (!V) V = U;
228:   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)U, &match, MATSEQDENSE, MATMPIDENSE, ""));
229:   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_SUP, "Matrix U must be of type dense, found %s", ((PetscObject)U)->type_name);
230:   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)V, &match, MATSEQDENSE, MATMPIDENSE, ""));
231:   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_SUP, "Matrix V must be of type dense, found %s", ((PetscObject)V)->type_name);
232:   PetscCall(PetscStrcmp(U->defaultvectype, V->defaultvectype, &match));
233:   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_WRONG, "Matrix U and V must have the same VecType %s != %s", U->defaultvectype, V->defaultvectype);
234:   if (A) {
235:     PetscCall(PetscStrcmp(A->defaultvectype, U->defaultvectype, &match));
236:     PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_WRONG, "Matrix A and U must have the same VecType %s != %s", A->defaultvectype, U->defaultvectype);
237:   }

239:   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)U), &size));
240:   PetscCall(MatGetSize(U, NULL, &k));
241:   PetscCall(MatGetSize(V, NULL, &k1));
242:   PetscCheck(k == k1, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_INCOMP, "U and V have different number of columns (%" PetscInt_FMT " vs %" PetscInt_FMT ")", k, k1);
243:   PetscCall(MatGetLocalSize(U, &m, NULL));
244:   PetscCall(MatGetLocalSize(V, &n, NULL));
245:   if (A) {
246:     PetscCall(MatGetLocalSize(A, &m1, &n1));
247:     PetscCheck(m == m1, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Local dimensions of U %" PetscInt_FMT " and A %" PetscInt_FMT " do not match", m, m1);
248:     PetscCheck(n == n1, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Local dimensions of V %" PetscInt_FMT " and A %" PetscInt_FMT " do not match", n, n1);
249:   }
250:   if (c) {
251:     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)c), &csize));
252:     PetscCall(VecGetSize(c, &k1));
253:     PetscCheck(k == k1, PetscObjectComm((PetscObject)c), PETSC_ERR_ARG_INCOMP, "The length of c %" PetscInt_FMT " does not match the number of columns of U and V (%" PetscInt_FMT ")", k1, k);
254:     PetscCheck(csize == 1 || csize == size, PetscObjectComm((PetscObject)c), PETSC_ERR_ARG_INCOMP, "U and c must have the same communicator size %d != %d", size, csize);
255:   }

257:   PetscCall(MatCreate(PetscObjectComm((PetscObject)U), N));
258:   PetscCall(MatSetSizes(*N, m, n, PETSC_DECIDE, PETSC_DECIDE));
259:   PetscCall(MatSetVecType(*N, U->defaultvectype));
260:   PetscCall(PetscObjectChangeTypeName((PetscObject)*N, MATLRC));
261:   /* Flag matrix as symmetric if A is symmetric and U == V */
262:   PetscCall(MatSetOption(*N, MAT_SYMMETRIC, (PetscBool)((A ? A->symmetric == PETSC_BOOL3_TRUE : PETSC_TRUE) && U == V)));

264:   PetscCall(PetscNew(&Na));
265:   (*N)->data = (void *)Na;
266:   Na->A      = A;
267:   Na->U      = U;
268:   Na->c      = c;
269:   Na->V      = V;

271:   PetscCall(PetscObjectReference((PetscObject)A));
272:   PetscCall(PetscObjectReference((PetscObject)Na->U));
273:   PetscCall(PetscObjectReference((PetscObject)Na->V));
274:   PetscCall(PetscObjectReference((PetscObject)c));

276:   PetscCall(MatDenseGetLocalMatrix(Na->U, &Uloc));
277:   PetscCall(MatCreateVecs(Uloc, &Na->work1, NULL));
278:   if (size != 1) {
279:     Mat Vloc;

281:     if (Na->c && csize != 1) { /* scatter parallel vector to sequential */
282:       VecScatter sct;

284:       PetscCall(VecScatterCreateToAll(Na->c, &sct, &c));
285:       PetscCall(VecScatterBegin(sct, Na->c, c, INSERT_VALUES, SCATTER_FORWARD));
286:       PetscCall(VecScatterEnd(sct, Na->c, c, INSERT_VALUES, SCATTER_FORWARD));
287:       PetscCall(VecScatterDestroy(&sct));
288:       PetscCall(VecDestroy(&Na->c));
289:       Na->c = c;
290:     }
291:     PetscCall(MatDenseGetLocalMatrix(Na->V, &Vloc));
292:     PetscCall(VecDuplicate(Na->work1, &Na->work2));
293:     PetscCall(MatCreateVecs(Vloc, NULL, &Na->xl));
294:     PetscCall(MatCreateVecs(Uloc, NULL, &Na->yl));
295:   }

297:   /* Internally create a scaling vector if roottypes do not match */
298:   if (Na->c) {
299:     VecType rt1, rt2;

301:     PetscCall(VecGetRootType_Private(Na->work1, &rt1));
302:     PetscCall(VecGetRootType_Private(Na->c, &rt2));
303:     PetscCall(PetscStrcmp(rt1, rt2, &match));
304:     if (!match) {
305:       PetscCall(VecDuplicate(Na->c, &c));
306:       PetscCall(VecCopy(Na->c, c));
307:       PetscCall(VecDestroy(&Na->c));
308:       Na->c = c;
309:     }
310:   }

312:   (*N)->ops->destroy       = MatDestroy_LRC;
313:   (*N)->ops->mult          = MatMult_LRC;
314:   (*N)->ops->multtranspose = MatMultTranspose_LRC;

316:   (*N)->assembled    = PETSC_TRUE;
317:   (*N)->preallocated = PETSC_TRUE;

319:   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatLRCGetMats_C", MatLRCGetMats_LRC));
320:   PetscCall(MatSetUp(*N));
321:   PetscFunctionReturn(PETSC_SUCCESS);
322: }