Actual source code: mpimattransposematmult.c


/*
  Defines matrix-matrix product routines for an MPIAIJ matrix A and an MPIDense matrix B:
          C = A^T * B
  The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
*/
  7: #include <../src/mat/impls/aij/seq/aij.h>
  8: #include <../src/mat/impls/aij/mpi/mpiaij.h>
  9: #include <../src/mat/impls/dense/mpi/mpidense.h>

 11: PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data)
 12: {
 13:   Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data;

 15:   PetscFunctionBegin;
 16:   PetscCall(MatDestroy(&atb->mA));
 17:   PetscCall(VecDestroy(&atb->bt));
 18:   PetscCall(VecDestroy(&atb->ct));
 19:   PetscCall(PetscFree(atb));
 20:   PetscFunctionReturn(PETSC_SUCCESS);
 21: }

 23: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);

 25: PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C)
 26: {
 27:   Mat_MatTransMatMult *atb;
 28:   PetscBool            cisdense;

 30:   PetscFunctionBegin;
 31:   MatCheckProduct(C, 4);
 32:   PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty");

 34:   /* create output dense matrix C = A^T*B */
 35:   PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N));
 36:   PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, ""));
 37:   if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name));
 38:   PetscCall(MatSetUp(C));

 40:   /* create additional data structure for the product */
 41:   PetscCall(PetscNew(&atb));
 42:   if (B->cmap->N) {
 43:     PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA));
 44:     if (!atb->mA->assembled) {
 45:       PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY));
 46:       PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY));
 47:     }
 48:     PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt));
 49:   }
 50:   C->product->data    = atb;
 51:   C->product->destroy = MatDestroy_MPIDense_MatTransMatMult;

 53:   C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
 54:   PetscFunctionReturn(PETSC_SUCCESS);
 55: }

 57: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C)
 58: {
 59:   const PetscScalar   *Barray, *ctarray;
 60:   PetscScalar         *Carray, *btarray;
 61:   PetscInt             i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc;
 62:   Mat_MatTransMatMult *atb;
 63:   Vec                  bt, ct;

 65:   PetscFunctionBegin;
 66:   MatCheckProduct(C, 3);
 67:   atb = (Mat_MatTransMatMult *)C->product->data;
 68:   PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct");
 69:   if (!BN) {
 70:     PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
 71:     PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
 72:     PetscFunctionReturn(PETSC_SUCCESS);
 73:   }
 74:   bt = atb->bt;
 75:   ct = atb->ct;

 77:   /* transpose local array of B, then copy it to vector bt */
 78:   PetscCall(MatDenseGetArrayRead(B, &Barray));
 79:   PetscCall(MatDenseGetLDA(B, &ldb));
 80:   PetscCall(VecGetArray(bt, &btarray));
 81:   for (j = 0; j < BN; j++)
 82:     for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i];
 83:   PetscCall(VecRestoreArray(bt, &btarray));
 84:   PetscCall(MatDenseRestoreArrayRead(B, &Barray));

 86:   /* compute ct = mA^T * cb */
 87:   PetscCall(MatMultTranspose(atb->mA, bt, ct));

 89:   /* transpose local array of ct to matrix C */
 90:   PetscCall(MatDenseGetArray(C, &Carray));
 91:   PetscCall(MatDenseGetLDA(C, &ldc));
 92:   PetscCall(VecGetArrayRead(ct, &ctarray));
 93:   for (j = 0; j < BN; j++)
 94:     for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j];
 95:   PetscCall(VecRestoreArrayRead(ct, &ctarray));
 96:   PetscCall(MatDenseRestoreArray(C, &Carray));
 97:   PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
 98:   PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
 99:   PetscFunctionReturn(PETSC_SUCCESS);
100: }