Actual source code: matmpidensecupm.hpp

#ifndef PETSCMATMPIDENSECUPM_HPP
#define PETSCMATMPIDENSECUPM_HPP

#include <petsc/private/matdensecupmimpl.h>
#include <../src/mat/impls/dense/mpi/mpidense.h>

#ifdef __cplusplus
#include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp>
#include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp>

namespace Petsc
{

namespace mat
{

namespace cupm
{

namespace impl
{

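// MatDense_MPI_CUPM provides the CUDA/HIP (CUPM) implementation of MATMPIDENSE. It derives
// from the shared MatDense_CUPM CRTP base and routes all per-rank storage and kernels
// through the sequential CUPM class acting on the local block (Mat_MPIDense::A); only the
// MPI-level glue is implemented here.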
template <device::cupm::DeviceType T>
class MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> {
public:
  MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>);

private:
  PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept;
  PETSC_NODISCARD static constexpr MatType       MATIMPLCUPM_() noexcept;

  static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept;

  template <bool to_host>
  static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept;

public:
  PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept;

  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept;
  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept;

  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept;
  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept;

  static PetscErrorCode Create(Mat) noexcept;

  static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept;
  static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept;
  static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept;

  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;
  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;

private:
  template <PetscMemType mtype, PetscMemoryAccessMode mode>
  static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept
  {
    return GetArray<mtype, mode>(m, p);
  }

  template <PetscMemType mtype, PetscMemoryAccessMode mode>
  static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept
  {
    return RestoreArray<mtype, mode>(m, p);
  }

public:
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept;

  static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept;
  static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept;
  static PetscErrorCode ResetArray(Mat) noexcept;

  static PetscErrorCode Shift(Mat, PetscScalar) noexcept;
};

} // namespace impl

namespace
{

// Declare this here so that the functions below can make use of it
template <device::cupm::DeviceType T>
inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept
{
  PetscFunctionBegin;
  PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate));
  PetscFunctionReturn(PETSC_SUCCESS);
}
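
// Usage sketch (illustrative, not part of the header; assumes a CUDA-enabled PETSc build
// and global sizes M, N defined by the caller). Passing nullptr for data lets PETSc
// allocate the device storage:
//
//   Mat A;
//   PetscCall(MatCreateMPIDenseCUPM<device::cupm::DeviceType::CUDA>(
//     PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, M, N, /* data */ nullptr, &A));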

} // anonymous namespace

namespace impl
{

// ==========================================================================================
// MatDense_MPI_CUPM -- Private API
// ==========================================================================================

template <device::cupm::DeviceType T>
inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept
{
  return static_cast<Mat_MPIDense *>(m->data);
}

template <device::cupm::DeviceType T>
inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept
{
  return MATMPIDENSECUPM();
}

// ==========================================================================================

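// Allocate device storage for the local sequential block, or adopt device_array when it is
// non-null. An existing block is retyped to MATSEQDENSECUPM and preallocated in place;
// otherwise a new block is created with the local row count and global column count.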
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  if (auto &mimplA = MatIMPLCast(A)->A) {
    PetscCall(MatSetType(mimplA, MATSEQDENSECUPM()));
    PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array));
  } else {
    PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

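// Shared body for both conversion directions, selected at compile time by to_host. It
// creates or reuses the destination matrix per MatReuse, swaps the default vector type and
// the type name, composes (or removes) the device-specific method overrides, converts the
// local block, and updates the offload mask.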
template <device::cupm::DeviceType T>
template <bool to_host>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  if (reuse == MAT_INITIAL_MATRIX) {
    PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat));
  } else if (reuse == MAT_REUSE_MATRIX) {
    PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN));
  }
  {
    const auto B    = *newmat;
    const auto pobj = PetscObjectCast(B);

    if (to_host) {
      PetscCall(BindToCPU(B, PETSC_TRUE));
    } else {
      PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
    }

    PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype));
    PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM()));

    // ============================================================
    // Composed Ops
    // ============================================================
    MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray);

    if (to_host) {
      if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A));
      B->offloadmask = PETSC_OFFLOAD_CPU;
    } else {
      if (auto &m_A = MatIMPLCast(B)->A) {
        PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A));
        B->offloadmask = PETSC_OFFLOAD_BOTH;
      } else {
        B->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
      }
      PetscCall(BindToCPU(B, PETSC_FALSE));
    }

    // ============================================================
    // Function Pointer Ops
    // ============================================================
    MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// MatDense_MPI_CUPM -- Public API
// ==========================================================================================

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C";
}

// ==========================================================================================

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatCreate_MPIDense(A));
  PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================

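// Toggle host/device execution for the matrix. Any checked-out column vector or submatrix
// must have been restored first; the column-vector accessors and the shift operation are
// then rebound to either the MPIDense host implementations or the CUPM ones.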
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept
{
  const auto mimpl = MatIMPLCast(A);
  const auto pobj  = PetscObjectCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost));
  A->boundtocpu = usehost;
  PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype));
  if (!usehost) {
    PetscBool iscupm;

    PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm));
    if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm));
    if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat));
  }

  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>);

  MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift);

  if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================

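// The device array accessors forward directly to the sequential CUPM block, which owns the
// storage for this rank's rows.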
template <device::cupm::DeviceType T>
template <PetscMemType, PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, access>(MatIMPLCast(A)->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <PetscMemType, PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================

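// Expose column col as an MPI CUPM vector without copying: the device array of the local
// block is fetched and the cached column vector cvec is pointed at offset col * lda within
// it. The matching RestoreColumnVec() must be called before another column is requested.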
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept
{
  using namespace vec::cupm;

  const auto mimpl   = MatIMPLCast(A);
  const auto mimpl_A = mimpl->A;
  const auto pobj    = PetscObjectCast(A);
  auto      &cvec    = mimpl->cvec;
  PetscInt   lda;

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  mimpl->vecinuse = col + 1;

  if (!cvec) PetscCall(VecCreateMPICUPMWithArray<T>(PetscObjectComm(pobj), A->rmap->bs, A->rmap->n, A->rmap->N, nullptr, &cvec));

  PetscCall(MatDenseGetLDA(mimpl_A, &lda));
  PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
  PetscCall(VecCUPMPlaceArrayAsync<T>(cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda)));

  if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(cvec));
  *v = cvec;
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept
{
  using namespace vec::cupm;

  const auto mimpl = MatIMPLCast(A);
  const auto cvec  = mimpl->cvec;

  PetscFunctionBegin;
  PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first");
  PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector");
  mimpl->vecinuse = 0;

  PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
  if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec));
  PetscCall(VecCUPMResetArrayAsync<T>(cvec));

  if (v) *v = nullptr;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================

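// The PlaceArray/ReplaceArray/ResetArray trio manages user-provided device arrays by
// forwarding to the corresponding MatDenseCUPM operations on the local block, after
// checking that no column vector or submatrix is currently checked out.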
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMResetArray<T>(mimpl->A));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================

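// Implement MatShift() on the backend: alpha is added to each locally owned diagonal entry
// (global rows [rstart, rend)) via a device unary transform.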
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Shift(Mat A, PetscScalar alpha) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(PetscInfo(A, "Performing Shift on backend\n"));
  PetscCall(DiagonalUnaryTransform(A, A->rmap->rstart, A->rmap->rend, A->cmap->N, dctx, device::cupm::functors::make_plus_equals(alpha)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

} // namespace impl

namespace
{

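// Communicator-size dispatch mirroring MatCreateDense(): parallel communicators get a
// MATMPIDENSECUPM matrix, single-rank communicators a MATSEQDENSECUPM one.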
template <device::cupm::DeviceType T>
inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  PetscCallMPI(MPI_Comm_size(comm, &size));
  if (size > 1) {
    PetscCall(MatCreateMPIDenseCUPM<T>(comm, m, n, M, N, data, A, dctx));
  } else {
    if (m == PETSC_DECIDE) m = M;
    if (n == PETSC_DECIDE) n = N;
    // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down
    // the line
    PetscCall(MatCreateSeqDenseCUPM<T>(comm, m, n, data, A, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
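
// Usage sketch (illustrative, not part of the header; assumes a CUDA-enabled PETSc build
// and caller-defined global sizes M, N):
//
//   Mat A;
//   PetscCall(MatCreateDenseCUPM<device::cupm::DeviceType::CUDA>(
//     PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, M, N, /* data */ nullptr, &A));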

} // anonymous namespace

} // namespace cupm

} // namespace mat

} // namespace Petsc

#endif // __cplusplus

#endif // PETSCMATMPIDENSECUPM_HPP