Actual source code: submat.c


  2: #include <petsc/private/matimpl.h>

  4: typedef struct {
  5:   IS         isrow, iscol;   /* rows and columns in submatrix, only used to check consistency */
  6:   Vec        lwork, rwork;   /* work vectors inside the scatters */
  7:   Vec        lwork2, rwork2; /* work vectors inside the scatters */
  8:   VecScatter lrestrict, rprolong;
  9:   Mat        A;
 10: } Mat_SubVirtual;

 12: static PetscErrorCode MatScale_SubMatrix(Mat N, PetscScalar a)
 13: {
 14:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 16:   PetscFunctionBegin;
 17:   PetscCall(MatScale(Na->A, a));
 18:   PetscFunctionReturn(PETSC_SUCCESS);
 19: }

 21: static PetscErrorCode MatShift_SubMatrix(Mat N, PetscScalar a)
 22: {
 23:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 25:   PetscFunctionBegin;
 26:   PetscCall(MatShift(Na->A, a));
 27:   PetscFunctionReturn(PETSC_SUCCESS);
 28: }

 30: static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N, Vec left, Vec right)
 31: {
 32:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 34:   PetscFunctionBegin;
 35:   if (right) {
 36:     PetscCall(VecZeroEntries(Na->rwork));
 37:     PetscCall(VecScatterBegin(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
 38:     PetscCall(VecScatterEnd(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
 39:   }
 40:   if (left) {
 41:     PetscCall(VecZeroEntries(Na->lwork));
 42:     PetscCall(VecScatterBegin(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
 43:     PetscCall(VecScatterEnd(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
 44:   }
 45:   PetscCall(MatDiagonalScale(Na->A, left ? Na->lwork : NULL, right ? Na->rwork : NULL));
 46:   PetscFunctionReturn(PETSC_SUCCESS);
 47: }

 49: static PetscErrorCode MatGetDiagonal_SubMatrix(Mat N, Vec d)
 50: {
 51:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 53:   PetscFunctionBegin;
 54:   PetscCall(MatGetDiagonal(Na->A, Na->rwork));
 55:   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE));
 56:   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE));
 57:   PetscFunctionReturn(PETSC_SUCCESS);
 58: }

 60: static PetscErrorCode MatMult_SubMatrix(Mat N, Vec x, Vec y)
 61: {
 62:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 64:   PetscFunctionBegin;
 65:   PetscCall(VecZeroEntries(Na->rwork));
 66:   PetscCall(VecScatterBegin(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
 67:   PetscCall(VecScatterEnd(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
 68:   PetscCall(MatMult(Na->A, Na->rwork, Na->lwork));
 69:   PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD));
 70:   PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD));
 71:   PetscFunctionReturn(PETSC_SUCCESS);
 72: }

 74: static PetscErrorCode MatMultAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3)
 75: {
 76:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 78:   PetscFunctionBegin;
 79:   PetscCall(VecZeroEntries(Na->rwork));
 80:   PetscCall(VecScatterBegin(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
 81:   PetscCall(VecScatterEnd(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
 82:   if (v1 == v2) {
 83:     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->rwork, Na->lwork));
 84:   } else if (v2 == v3) {
 85:     PetscCall(VecZeroEntries(Na->lwork));
 86:     PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
 87:     PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
 88:     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork, Na->lwork));
 89:   } else {
 90:     if (!Na->lwork2) {
 91:       PetscCall(VecDuplicate(Na->lwork, &Na->lwork2));
 92:     } else {
 93:       PetscCall(VecZeroEntries(Na->lwork2));
 94:     }
 95:     PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE));
 96:     PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE));
 97:     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork2, Na->lwork));
 98:   }
 99:   PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD));
100:   PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD));
101:   PetscFunctionReturn(PETSC_SUCCESS);
102: }

104: static PetscErrorCode MatMultTranspose_SubMatrix(Mat N, Vec x, Vec y)
105: {
106:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

108:   PetscFunctionBegin;
109:   PetscCall(VecZeroEntries(Na->lwork));
110:   PetscCall(VecScatterBegin(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
111:   PetscCall(VecScatterEnd(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
112:   PetscCall(MatMultTranspose(Na->A, Na->lwork, Na->rwork));
113:   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE));
114:   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE));
115:   PetscFunctionReturn(PETSC_SUCCESS);
116: }

118: static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3)
119: {
120:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

122:   PetscFunctionBegin;
123:   PetscCall(VecZeroEntries(Na->lwork));
124:   PetscCall(VecScatterBegin(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
125:   PetscCall(VecScatterEnd(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
126:   if (v1 == v2) {
127:     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->lwork, Na->rwork));
128:   } else if (v2 == v3) {
129:     PetscCall(VecZeroEntries(Na->rwork));
130:     PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
131:     PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
132:     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork, Na->rwork));
133:   } else {
134:     if (!Na->rwork2) {
135:       PetscCall(VecDuplicate(Na->rwork, &Na->rwork2));
136:     } else {
137:       PetscCall(VecZeroEntries(Na->rwork2));
138:     }
139:     PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD));
140:     PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD));
141:     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork2, Na->rwork));
142:   }
143:   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE));
144:   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE));
145:   PetscFunctionReturn(PETSC_SUCCESS);
146: }

148: static PetscErrorCode MatDestroy_SubMatrix(Mat N)
149: {
150:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

152:   PetscFunctionBegin;
153:   PetscCall(ISDestroy(&Na->isrow));
154:   PetscCall(ISDestroy(&Na->iscol));
155:   PetscCall(VecDestroy(&Na->lwork));
156:   PetscCall(VecDestroy(&Na->rwork));
157:   PetscCall(VecDestroy(&Na->lwork2));
158:   PetscCall(VecDestroy(&Na->rwork2));
159:   PetscCall(VecScatterDestroy(&Na->lrestrict));
160:   PetscCall(VecScatterDestroy(&Na->rprolong));
161:   PetscCall(MatDestroy(&Na->A));
162:   PetscCall(PetscFree(N->data));
163:   PetscFunctionReturn(PETSC_SUCCESS);
164: }

166: /*@
167:    MatCreateSubMatrixVirtual - Creates a virtual matrix `MATSUBMATRIX` that acts as a submatrix

169:    Collective

171:    Input Parameters:
172: +  A - matrix that we will extract a submatrix of
173: .  isrow - rows to be present in the submatrix
174: -  iscol - columns to be present in the submatrix

176:    Output Parameter:
177: .  newmat - new matrix

179:    Level: developer

181:    Note:
182:    Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.

184: .seealso: [](ch_matrices), `Mat`, `MATSUBMATRIX`, `MATLOCALREF`, `MatCreateLocalRef()`, `MatCreateSubMatrix()`, `MatSubMatrixVirtualUpdate()`
185: @*/
186: PetscErrorCode MatCreateSubMatrixVirtual(Mat A, IS isrow, IS iscol, Mat *newmat)
187: {
188:   Vec             left, right;
189:   PetscInt        m, n;
190:   Mat             N;
191:   Mat_SubVirtual *Na;

193:   PetscFunctionBegin;
198:   *newmat = NULL;

200:   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), &N));
201:   PetscCall(ISGetLocalSize(isrow, &m));
202:   PetscCall(ISGetLocalSize(iscol, &n));
203:   PetscCall(MatSetSizes(N, m, n, PETSC_DETERMINE, PETSC_DETERMINE));
204:   PetscCall(PetscObjectChangeTypeName((PetscObject)N, MATSUBMATRIX));

206:   PetscCall(PetscNew(&Na));
207:   N->data = (void *)Na;

209:   PetscCall(PetscObjectReference((PetscObject)isrow));
210:   PetscCall(PetscObjectReference((PetscObject)iscol));
211:   Na->isrow = isrow;
212:   Na->iscol = iscol;

214:   PetscCall(PetscFree(N->defaultvectype));
215:   PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype));
216:   /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
217:      the reference count of the context. This is a problem if A is already of type MATSHELL */
218:   PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A));

220:   N->ops->destroy          = MatDestroy_SubMatrix;
221:   N->ops->mult             = MatMult_SubMatrix;
222:   N->ops->multadd          = MatMultAdd_SubMatrix;
223:   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
224:   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
225:   N->ops->scale            = MatScale_SubMatrix;
226:   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
227:   N->ops->shift            = MatShift_SubMatrix;
228:   N->ops->convert          = MatConvert_Shell;
229:   N->ops->getdiagonal      = MatGetDiagonal_SubMatrix;

231:   PetscCall(MatSetBlockSizesFromMats(N, A, A));
232:   PetscCall(PetscLayoutSetUp(N->rmap));
233:   PetscCall(PetscLayoutSetUp(N->cmap));

235:   PetscCall(MatCreateVecs(A, &Na->rwork, &Na->lwork));
236:   PetscCall(MatCreateVecs(N, &right, &left));
237:   PetscCall(VecScatterCreate(Na->lwork, isrow, left, NULL, &Na->lrestrict));
238:   PetscCall(VecScatterCreate(right, NULL, Na->rwork, iscol, &Na->rprolong));
239:   PetscCall(VecDestroy(&left));
240:   PetscCall(VecDestroy(&right));
241:   PetscCall(MatSetUp(N));

243:   N->assembled = PETSC_TRUE;
244:   *newmat      = N;
245:   PetscFunctionReturn(PETSC_SUCCESS);
246: }

/*MC
   MATSUBMATRIX - "submatrix" - A matrix type that represents a virtual submatrix of a matrix

   Level: advanced

   Developer Note:
   The `MatType` is `MATSUBMATRIX`, but the associated routines have `SubMatrixVirtual` in their names; the `MatType` name should likely be changed to
   `MATSUBMATRIXVIRTUAL`.

.seealso: [](ch_matrices), `Mat`, `MatCreateSubMatrixVirtual()`, `MatSubMatrixVirtualUpdate()`, `MatCreateSubMatrix()`
M*/

260: /*@
261:    MatSubMatrixVirtualUpdate - Updates a `MATSUBMATRIX` virtual submatrix

263:    Collective

265:    Input Parameters:
266: +  N - submatrix to update
267: .  A - full matrix in the submatrix
268: .  isrow - rows in the update (same as the first time the submatrix was created)
269: -  iscol - columns in the update (same as the first time the submatrix was created)

271:    Level: developer

273:    Note:
274:    Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.

276: .seealso: [](ch_matrices), `Mat`, `MATSUBMATRIX`, `MatCreateSubMatrixVirtual()`
277: @*/
278: PetscErrorCode MatSubMatrixVirtualUpdate(Mat N, Mat A, IS isrow, IS iscol)
279: {
280:   PetscBool       flg;
281:   Mat_SubVirtual *Na;

283:   PetscFunctionBegin;
288:   PetscCall(PetscObjectTypeCompare((PetscObject)N, MATSUBMATRIX, &flg));
289:   PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Matrix has wrong type");

291:   Na = (Mat_SubVirtual *)N->data;
292:   PetscCall(ISEqual(isrow, Na->isrow, &flg));
293:   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different row indices");
294:   PetscCall(ISEqual(iscol, Na->iscol, &flg));
295:   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different column indices");

297:   PetscCall(PetscFree(N->defaultvectype));
298:   PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype));
299:   PetscCall(MatDestroy(&Na->A));
300:   /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
301:      the reference count of the context. This is a problem if A is already of type MATSHELL */
302:   PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A));
303:   PetscFunctionReturn(PETSC_SUCCESS);
304: }