Actual source code: mpimattransposematmult.c
/*
  Defines matrix-matrix product routines for an MPIAIJ matrix A and an MPIDense matrix B:
    C = A^T * B
  The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
*/
7: #include <../src/mat/impls/aij/seq/aij.h>
8: #include <../src/mat/impls/aij/mpi/mpiaij.h>
9: #include <../src/mat/impls/dense/mpi/mpidense.h>
11: PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data)
12: {
13: Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data;
15: PetscFunctionBegin;
16: PetscCall(MatDestroy(&atb->mA));
17: PetscCall(VecDestroy(&atb->bt));
18: PetscCall(VecDestroy(&atb->ct));
19: PetscCall(PetscFree(atb));
20: PetscFunctionReturn(PETSC_SUCCESS);
21: }
23: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);
25: PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C)
26: {
27: Mat_MatTransMatMult *atb;
28: PetscBool cisdense;
30: PetscFunctionBegin;
31: MatCheckProduct(C, 4);
32: PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty");
34: /* create output dense matrix C = A^T*B */
35: PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N));
36: PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, ""));
37: if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name));
38: PetscCall(MatSetUp(C));
40: /* create additional data structure for the product */
41: PetscCall(PetscNew(&atb));
42: if (B->cmap->N) {
43: PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA));
44: if (!atb->mA->assembled) {
45: PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY));
46: PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY));
47: }
48: PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt));
49: }
50: C->product->data = atb;
51: C->product->destroy = MatDestroy_MPIDense_MatTransMatMult;
53: C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
54: PetscFunctionReturn(PETSC_SUCCESS);
55: }
57: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C)
58: {
59: const PetscScalar *Barray, *ctarray;
60: PetscScalar *Carray, *btarray;
61: PetscInt i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc;
62: Mat_MatTransMatMult *atb;
63: Vec bt, ct;
65: PetscFunctionBegin;
66: MatCheckProduct(C, 3);
67: atb = (Mat_MatTransMatMult *)C->product->data;
68: PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct");
69: if (!BN) {
70: PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
71: PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
72: PetscFunctionReturn(PETSC_SUCCESS);
73: }
74: bt = atb->bt;
75: ct = atb->ct;
77: /* transpose local array of B, then copy it to vector bt */
78: PetscCall(MatDenseGetArrayRead(B, &Barray));
79: PetscCall(MatDenseGetLDA(B, &ldb));
80: PetscCall(VecGetArray(bt, &btarray));
81: for (j = 0; j < BN; j++)
82: for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i];
83: PetscCall(VecRestoreArray(bt, &btarray));
84: PetscCall(MatDenseRestoreArrayRead(B, &Barray));
86: /* compute ct = mA^T * cb */
87: PetscCall(MatMultTranspose(atb->mA, bt, ct));
89: /* transpose local array of ct to matrix C */
90: PetscCall(MatDenseGetArray(C, &Carray));
91: PetscCall(MatDenseGetLDA(C, &ldc));
92: PetscCall(VecGetArrayRead(ct, &ctarray));
93: for (j = 0; j < BN; j++)
94: for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j];
95: PetscCall(VecRestoreArrayRead(ct, &ctarray));
96: PetscCall(MatDenseRestoreArray(C, &Carray));
97: PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
98: PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
99: PetscFunctionReturn(PETSC_SUCCESS);
100: }