Actual source code: vecmpicupm.hpp

#ifndef PETSCVECMPICUPM_HPP
#define PETSCVECMPICUPM_HPP

#if defined(__cplusplus)
#include <petsc/private/veccupmimpl.h>
#include <../src/vec/vec/impls/seq/cupm/vecseqcupm.hpp>
#include <../src/vec/vec/impls/mpi/pvecimpl.h>
#include <petsc/private/sfimpl.h>

namespace Petsc
{

namespace vec
{

namespace cupm
{

namespace impl
{

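// The MPI (multi-rank) vector implementation backing the VECMPICUDA and VECMPIHIP
// types. The CRTP base Vec_CUPMBase provides the host/device memory management and
// synchronization machinery; this class layers the MPI-parallel logic on top of the
// single-rank kernels supplied by VecSeq_CUPM.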
template <device::cupm::DeviceType T>
class VecMPI_CUPM : public Vec_CUPMBase<T, VecMPI_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecMPI_CUPM<T>);
  using VecSeq_T = VecSeq_CUPM<T>;

private:
  PETSC_NODISCARD static Vec_MPI          *VecIMPLCast_(Vec) noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;

  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  static PetscErrorCode CreateMPICUPM_(Vec, PetscDeviceContext, PetscBool /*allocate_missing*/ = PETSC_TRUE, PetscInt /*nghost*/ = 0, PetscScalar * /*host_array*/ = nullptr, PetscScalar * /*device_array*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateMPICUPM(MPI_Comm, PetscInt, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateMPICUPMWithArrays(MPI_Comm, PetscInt, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
};

template <device::cupm::DeviceType T>
inline Vec_MPI *VecMPI_CUPM<T>::VecIMPLCast_(Vec v) noexcept
{
  return static_cast<Vec_MPI *>(v->data);
}

template <device::cupm::DeviceType T>
inline constexpr VecType VecMPI_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECMPICUPM();
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  return VecDestroy_MPI(v);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_MPI(v);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_MPI(v, a);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt nghost, PetscScalar *) noexcept
{
  PetscFunctionBegin;
  if (alloc_missing) *alloc_missing = PETSC_TRUE;
  // note: host_array is always ignored; we never create it as part of the construction
  // sequence for VecMPI since we always want to either allocate it ourselves with
  // pinned memory or set it in Initialize_CUPMBase()
  PetscCall(VecCreate_MPI_Private(v, PETSC_FALSE, nghost, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::CreateMPICUPM_(Vec v, PetscDeviceContext dctx, PetscBool allocate_missing, PetscInt nghost, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, nghost));
  PetscCall(Initialize_CUPMBase(v, allocate_missing, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ================================================================================== //
//                                                                                    //
//                                  public methods                                    //
//                                                                                    //
// ================================================================================== //

// ================================================================================== //
//                             constructors/destructors                               //

// VecCreateMPICUPM()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::CreateMPICUPM(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  PetscCall(Create_CUPMBase(comm, bs, n, N, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecCreateMPICUPMWithArray[s]()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::CreateMPICUPMWithArrays(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // CreateMPICUPM_() is called!
  PetscCall(CreateMPICUPM(comm, bs, n, N, v, PETSC_FALSE));
  PetscCall(CreateMPICUPM_(*v, dctx, PETSC_FALSE, 0, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

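// For illustration, a hypothetical caller in a CUDA build that wraps an existing device
// buffer without copying it. VecCreateMPICUDAWithArray() is the public front end, which
// should bottom out in CreateMPICUPMWithArrays() above; d_buf, n, and N are assumptions
// for the sketch:
//
//   PetscScalar *d_buf; // device memory of local length n, e.g. from cudaMalloc()
//   Vec          x;
//
//   PetscCall(VecCreateMPICUDAWithArray(PETSC_COMM_WORLD, 1, n, N, d_buf, &x));
//   // ... use x; the vector never takes ownership of d_buf ...
//   PetscCall(VecDestroy(&x));
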
// v->ops->duplicate
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
{
  const auto         vimpl  = VecIMPLCast(v);
  const auto         nghost = vimpl->nghost;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // does not call VecSetType(); we set up the data structures ourselves
  PetscCall(Duplicate_CUPMBase(v, y, dctx, [=](Vec z) { return CreateMPICUPM_(z, dctx, PETSC_FALSE, nghost); }));

  /* save local representation of the parallel vector (and scatter) if it exists */
  if (const auto locrep = vimpl->localrep) {
    const auto   yimpl   = VecIMPLCast(*y);
    auto        &ylocrep = yimpl->localrep;
    PetscScalar *array;

    PetscCall(VecGetArray(*y, &array));
    PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, std::abs(v->map->bs), v->map->n + nghost, array, &ylocrep));
    PetscCall(VecRestoreArray(*y, &array));
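    // copy the (single) ops struct wholesale so the new local representation uses the
    // same method table as the original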
    PetscCall(PetscArraycpy(ylocrep->ops, locrep->ops, 1));
    if (auto &scatter = (yimpl->localupdate = vimpl->localupdate)) PetscCall(PetscObjectReference(PetscObjectCast(scatter)));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->bindtocpu
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

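  // (re)install the method table: with usehost true the host-side VecXXX_MPI
  // implementations are used, otherwise the device-aware versions defined by this class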
  VecSetOp_CUPM(dot, VecDot_MPI, Dot);
  VecSetOp_CUPM(mdot, VecMDot_MPI, MDot);
  VecSetOp_CUPM(norm, VecNorm_MPI, Norm);
  VecSetOp_CUPM(tdot, VecTDot_MPI, TDot);
  VecSetOp_CUPM(resetarray, VecResetArray_MPI, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_MPI, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(max, VecMax_MPI, Max);
  VecSetOp_CUPM(min, VecMin_MPI, Min);
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ================================================================================== //
//                                   compute methods                                  //

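// Each operation below computes its local contribution with the corresponding
// VecSeq_CUPM kernel and defers the MPI reduction to the VecXXX_MPI_Default() helpers
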
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Norm(Vec v, NormType type, PetscReal *z) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecNorm_MPI_Default(v, type, z, VecSeq_T::Norm));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Dot(Vec x, Vec y, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecXDot_MPI_Default(x, y, z, VecSeq_T::Dot));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::TDot(Vec x, Vec y, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecXDot_MPI_Default(x, y, z, VecSeq_T::TDot));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::MDot(Vec x, PetscInt nv, const Vec y[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecMXDot_MPI_Default(x, nv, y, z, VecSeq_T::MDot));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::DotNorm2(Vec x, Vec y, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecDotNorm2_MPI_Default(x, y, dp, nm, VecSeq_T::DotNorm2));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Max(Vec x, PetscInt *idx, PetscReal *z) noexcept
{
  const MPI_Op ops[] = {MPIU_MAXLOC, MPIU_MAX};

  PetscFunctionBegin;
  PetscCall(VecMinMax_MPI_Default(x, idx, z, VecSeq_T::Max, ops));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Min(Vec x, PetscInt *idx, PetscReal *z) noexcept
{
  const MPI_Op ops[] = {MPIU_MINLOC, MPIU_MIN};

  PetscFunctionBegin;
  PetscCall(VecMinMax_MPI_Default(x, idx, z, VecSeq_T::Min, ops));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::SetPreallocationCOO(Vec x, PetscCount ncoo, const PetscInt coo_i[]) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecSetPreallocationCOO_MPI(x, ncoo, coo_i));
  // both the host (Vec_MPI) and device (Vec_CUPM) structures must exist for this to work
  PetscCall(VecCUPMAllocateCheck_(x));
  {
    const auto vcu  = VecCUPMCast(x);
    const auto vmpi = VecIMPLCast(x);

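    // each make_coo_pair() entry mirrors a host-side COO setup array (owned by the
    // Vec_MPI struct) into its device-side counterpart in the Vec_CUPM struct; the
    // second array lists the send/receive staging buffers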
    // clang-format off
    PetscCall(
      SetPreallocationCOO_CUPMBase(
        x, ncoo, coo_i,
        util::make_array(
          make_coo_pair(vcu->imap2_d, vmpi->imap2, vmpi->nnz2),
          make_coo_pair(vcu->jmap2_d, vmpi->jmap2, vmpi->nnz2 + 1),
          make_coo_pair(vcu->perm2_d, vmpi->perm2, vmpi->recvlen),
          make_coo_pair(vcu->Cperm_d, vmpi->Cperm, vmpi->sendlen)
        ),
        util::make_array(
          make_coo_pair(vcu->sendbuf_d, vmpi->sendbuf, vmpi->sendlen),
          make_coo_pair(vcu->recvbuf_d, vmpi->recvbuf, vmpi->recvlen)
        )
      )
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace kernels
{

namespace
{

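// gather the values destined for other ranks into a contiguous send buffer:
// buf[i] = vv[perm[i]]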
PETSC_KERNEL_DECL void pack_coo_values(const PetscScalar *PETSC_RESTRICT vv, PetscCount nnz, const PetscCount *PETSC_RESTRICT perm, PetscScalar *PETSC_RESTRICT buf)
{
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(nnz, [=](PetscCount i) { buf[i] = vv[perm[i]]; });
  return;
}

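// scatter-add values received from other ranks into the local array xv; entry i lands
// at xv[imap2[i]], accumulating the received values selected via jmap2 and perm2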
PETSC_KERNEL_DECL void add_remote_coo_values(const PetscScalar *PETSC_RESTRICT vv, PetscCount nnz2, const PetscCount *PETSC_RESTRICT imap2, const PetscCount *PETSC_RESTRICT jmap2, const PetscCount *PETSC_RESTRICT perm2, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(vv, nnz2, jmap2, perm2, ADD_VALUES, xv, [=](PetscCount i) { return imap2[i]; });
  return;
}

} // namespace

  #if PetscDefined(USING_HCC)
namespace do_not_use
{

// Needed to silence clang warning:
//
// warning: function 'FUNCTION NAME' is not needed and will not be emitted
//
// The warning is silly, since the function *is* used, however the host compiler does not
// appear to see this. Likely because the function using it is in a template.
//
// This warning appeared in clang-11, and still persists as of clang-15 (21/02/2023)
inline void silence_warning_function_pack_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)pack_coo_values;
}

inline void silence_warning_function_add_remote_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_remote_coo_values;
}

} // namespace do_not_use
  #endif

} // namespace kernels

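// The flow below: (1) stage a device copy of the user's values if they live on the
// host, (2) pack the entries owned by other ranks and start reducing them through the
// PetscSF, (3) add the locally owned entries while that communication is in flight,
// (4) add the received remote entries, then synchronize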
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  PetscDeviceContext dctx;
  PetscMemType       v_memtype;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &v_memtype));
  {
    const auto vmpi      = VecIMPLCast(x);
    const auto vcu       = VecCUPMCast(x);
    const auto sf        = vmpi->coo_sf;
    const auto sendbuf_d = vcu->sendbuf_d;
    const auto recvbuf_d = vcu->recvbuf_d;
    const auto xv        = imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data();
    auto       vv        = const_cast<PetscScalar *>(v);

    if (PetscMemTypeHost(v_memtype)) {
      const auto size = vmpi->coo_n;

      /* the user's v[] resides on the host, so stage a device copy for the kernels below */
      PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
      PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
    }

    /* Pack the entries to be sent to remote ranks */
    if (const auto sendlen = vmpi->sendlen) {
      PetscCall(PetscCUPMLaunchKernel1D(sendlen, 0, stream, kernels::pack_coo_values, vv, sendlen, vcu->Cperm_d, sendbuf_d));
      // need to sync up here since we are about to send this to petscsf
      // REVIEW ME: no we don't, sf just needs to learn to use PetscDeviceContext
      PetscCallCUPM(cupmStreamSynchronize(stream));
    }

    PetscCall(PetscSFReduceWithMemTypeBegin(sf, MPIU_SCALAR, PETSC_MEMTYPE_CUPM(), sendbuf_d, PETSC_MEMTYPE_CUPM(), recvbuf_d, MPI_REPLACE));

    if (const auto n = x->map->n) PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, xv));

    PetscCall(PetscSFReduceEnd(sf, MPIU_SCALAR, sendbuf_d, recvbuf_d, MPI_REPLACE));

    /* Add the received remote entries */
    if (const auto nnz2 = vmpi->nnz2) PetscCall(PetscCUPMLaunchKernel1D(nnz2, 0, stream, kernels::add_remote_coo_values, recvbuf_d, nnz2, vcu->imap2_d, vcu->jmap2_d, vcu->perm2_d, xv));

    if (PetscMemTypeHost(v_memtype)) PetscCall(PetscDeviceFree(dctx, vv));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

} // namespace impl

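// Namespace-level convenience wrappers over the impl:: class, presumably instantiated
// by the CUDA- and HIP-specific front ends with the matching DeviceType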
namespace
{

template <device::cupm::DeviceType T>
inline PetscErrorCode VecCreateMPICUPMAsync(MPI_Comm comm, PetscInt n, PetscInt N, Vec *v) noexcept
{
  PetscFunctionBegin;
  PetscCall(impl::VecMPI_CUPM<T>::CreateMPICUPM(comm, 0, n, N, v, PETSC_TRUE));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecCreateMPICUPMWithArrays(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
{
  PetscFunctionBegin;
  PetscCall(impl::VecMPI_CUPM<T>::CreateMPICUPMWithArrays(comm, bs, n, N, cpuarray, gpuarray, v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecCreateMPICUPMWithArray(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar gpuarray[], Vec *v) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCreateMPICUPMWithArrays<T>(comm, bs, n, N, nullptr, gpuarray, v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

} // anonymous namespace

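// For illustration only: a hypothetical end-user call chain in a CUDA build, where the
// public VecCreateMPICUDA() is expected to reach VecCreateMPICUPMAsync<DeviceType::CUDA>()
// above (nrm, n, and N are assumptions for the sketch):
//
//   Vec       x;
//   PetscReal nrm;
//
//   PetscCall(VecCreateMPICUDA(PETSC_COMM_WORLD, n, N, &x));
//   PetscCall(VecSet(x, 1.0));           // executes on the device
//   PetscCall(VecNorm(x, NORM_2, &nrm)); // local VecSeq_CUPM kernel + MPI reduction
//   PetscCall(VecDestroy(&x));
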
} // namespace cupm

} // namespace vec

} // namespace Petsc

#endif // __cplusplus

#endif // PETSCVECMPICUPM_HPP