Actual source code: cupmcontext.hpp

  1: #ifndef PETSCDEVICECONTEXTCUPM_HPP
  2: #define PETSCDEVICECONTEXTCUPM_HPP

  4: #include <petsc/private/deviceimpl.h>
  5: #include <petsc/private/cupmsolverinterface.hpp>
  6: #include <petsc/private/logimpl.h>

  8: #include <petsc/private/cpp/array.hpp>

 10: #include "../segmentedmempool.hpp"
 11: #include "cupmallocator.hpp"
 12: #include "cupmstream.hpp"
 13: #include "cupmevent.hpp"

 15: #if defined(__cplusplus)

 17: namespace Petsc
 18: {

 20: namespace device
 21: {

 23: namespace cupm
 24: {

 26: namespace impl
 27: {

 29: template <DeviceType T>
 30: class DeviceContext : SolverInterface<T> {
 31: public:
 32:   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);

 34: private:
 35:   template <typename H, std::size_t>
 36:   struct HandleTag {
 37:     using type = H;
 38:   };

 40:   using stream_tag = HandleTag<cupmStream_t, 0>;
 41:   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
 42:   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;

 44:   using stream_type = CUPMStream<T>;
 45:   using event_type  = CUPMEvent<T>;

 47: public:
 48:   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
 49:   // header, but since we are using the power of templates it must be declared part of
 50:   // this class to have easy access the same typedefs. Technically one can make a
 51:   // templated struct outside the class but it's more code for the same result.
 52:   struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> {
 53:     stream_type stream{};
 54:     cupmEvent_t event{};
 55:     cupmEvent_t begin{}; // timer-only
 56:     cupmEvent_t end{};   // timer-only
 57:   #if PetscDefined(USE_DEBUG)
 58:     PetscBool timerInUse{};
 59:   #endif
 60:     cupmBlasHandle_t   blas{};
 61:     cupmSolverHandle_t solver{};

 63:     constexpr PetscDeviceContext_IMPLS() noexcept = default;

 65:     PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); }

 67:     PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; }

 69:     PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; }
 70:   };

 72: private:
 73:   static bool initialized_;

 75:   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
 76:   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;

 78:   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }

 80:   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }

 82:   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }

 84:   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }

 86:   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
 87:   // handles
 88:   static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }

 90:   static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
 91:   {
 92:     const auto dci    = impls_cast_(dctx);
 93:     auto      &handle = blashandles_[dctx->device->deviceId];

 95:     PetscFunctionBegin;
 96:     if (!handle) {
 97:       PetscLogEvent event;

 99:       PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
100:       PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
101:       for (auto i = 0; i < 3; ++i) {
102:         const auto cberr = cupmBlasCreate(handle.ptr_to());
103:         if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
104:         if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
105:         if (i != 2) {
106:           PetscCall(PetscSleep(3));
107:           continue;
108:         }
109:         PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
110:       }
111:       PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
112:       PetscCall(PetscLogEventResume_Internal(event));
113:     }
114:     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
115:     dci->blas = handle;
116:     PetscFunctionReturn(PETSC_SUCCESS);
117:   }

119:   static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
120:   {
121:     const auto dci    = impls_cast_(dctx);
122:     auto      &handle = solverhandles_[dctx->device->deviceId];

124:     PetscFunctionBegin;
125:     if (!handle) {
126:       PetscLogEvent event;

128:       PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
129:       PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
130:       for (auto i = 0; i < 3; ++i) {
131:         const auto cerr = cupmSolverCreate(&handle);
132:         if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
133:         if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
134:         if (i < 2) {
135:           PetscCall(PetscSleep(3));
136:           continue;
137:         }
138:         PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
139:       }
140:       PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
141:       PetscCall(PetscLogEventResume_Internal(event));
142:     }
143:     PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
144:     dci->solver = handle;
145:     PetscFunctionReturn(PETSC_SUCCESS);
146:   }

148:   static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
149:   {
150:     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;

152:     PetscFunctionBegin;
153:     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
154:                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
155:     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
156:     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
157:     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
158:     PetscFunctionReturn(PETSC_SUCCESS);
159:   }

161:   static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }

163:   static PetscErrorCode finalize_() noexcept
164:   {
165:     PetscFunctionBegin;
166:     for (auto &&handle : blashandles_) {
167:       if (handle) {
168:         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
169:         handle = nullptr;
170:       }
171:     }

173:     for (auto &&handle : solverhandles_) {
174:       if (handle) {
175:         PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
176:         handle = nullptr;
177:       }
178:     }
179:     initialized_ = false;
180:     PetscFunctionReturn(PETSC_SUCCESS);
181:   }

183:   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
184:   PETSC_NODISCARD static PoolType &default_pool_() noexcept
185:   {
186:     static PoolType pool;
187:     return pool;
188:   }

190:   static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
191:   {
192:     PetscFunctionBegin;
193:     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
194:     PetscFunctionReturn(PETSC_SUCCESS);
195:   }

197: public:
198:   // All of these functions MUST be static in order to be callable from C, otherwise they
199:   // get the implicit 'this' pointer tacked on
200:   static PetscErrorCode destroy(PetscDeviceContext) noexcept;
201:   static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
202:   static PetscErrorCode setUp(PetscDeviceContext) noexcept;
203:   static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
204:   static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
205:   static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
206:   template <typename Handle_t>
207:   static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
208:   static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
209:   static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
210:   static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
211:   static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
212:   static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
213:   static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
214:   static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
215:   static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
216:   static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;

218:   // not a PetscDeviceContext method, this registers the class
219:   static PetscErrorCode initialize(PetscDevice) noexcept;

221:   // clang-format off
222:   static constexpr _DeviceContextOps ops = {
223:     PetscDesignatedInitializer(destroy, destroy),
224:     PetscDesignatedInitializer(changestreamtype, changeStreamType),
225:     PetscDesignatedInitializer(setup, setUp),
226:     PetscDesignatedInitializer(query, query),
227:     PetscDesignatedInitializer(waitforcontext, waitForContext),
228:     PetscDesignatedInitializer(synchronize, synchronize),
229:     PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
230:     PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
231:     PetscDesignatedInitializer(getstreamhandle, getHandle<stream_tag>),
232:     PetscDesignatedInitializer(begintimer, beginTimer),
233:     PetscDesignatedInitializer(endtimer, endTimer),
234:     PetscDesignatedInitializer(memalloc, memAlloc),
235:     PetscDesignatedInitializer(memfree, memFree),
236:     PetscDesignatedInitializer(memcopy, memCopy),
237:     PetscDesignatedInitializer(memset, memSet),
238:     PetscDesignatedInitializer(createevent, createEvent),
239:     PetscDesignatedInitializer(recordevent, recordEvent),
240:     PetscDesignatedInitializer(waitforevent, waitForEvent)
241:   };
242:   // clang-format on
243: };

245: // not a PetscDeviceContext method, this initializes the CLASS
246: template <DeviceType T>
247: inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
248: {
249:   PetscFunctionBegin;
250:   if (PetscUnlikely(!initialized_)) {
251:     uint64_t      threshold = UINT64_MAX;
252:     cupmMemPool_t mempool;

254:     initialized_ = true;
255:     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
256:     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
257:     blashandles_.fill(nullptr);
258:     solverhandles_.fill(nullptr);
259:     PetscCall(PetscRegisterFinalize(finalize_));
260:   }
261:   PetscFunctionReturn(PETSC_SUCCESS);
262: }

264: template <DeviceType T>
265: inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
266: {
267:   PetscFunctionBegin;
268:   if (const auto dci = impls_cast_(dctx)) {
269:     PetscCall(dci->stream.destroy());
270:     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
271:     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
272:     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
273:     delete dci;
274:     dctx->data = nullptr;
275:   }
276:   PetscFunctionReturn(PETSC_SUCCESS);
277: }

279: template <DeviceType T>
280: inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
281: {
282:   const auto dci = impls_cast_(dctx);

284:   PetscFunctionBegin;
285:   PetscCall(dci->stream.destroy());
286:   // set these to null so they aren't usable until setup is called again
287:   dci->blas   = nullptr;
288:   dci->solver = nullptr;
289:   PetscFunctionReturn(PETSC_SUCCESS);
290: }

292: template <DeviceType T>
293: inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
294: {
295:   const auto dci   = impls_cast_(dctx);
296:   auto      &event = dci->event;

298:   PetscFunctionBegin;
299:   PetscCall(check_current_device_(dctx));
300:   PetscCall(dci->stream.change_type(dctx->streamType));
301:   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
302:   #if PetscDefined(USE_DEBUG)
303:   dci->timerInUse = PETSC_FALSE;
304:   #endif
305:   PetscFunctionReturn(PETSC_SUCCESS);
306: }

308: template <DeviceType T>
309: inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
310: {
311:   PetscFunctionBegin;
312:   PetscCall(check_current_device_(dctx));
313:   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
314:   case cupmSuccess:
315:     *idle = PETSC_TRUE;
316:     break;
317:   case cupmErrorNotReady:
318:     *idle = PETSC_FALSE;
319:     // reset the error
320:     cerr = cupmGetLastError();
321:     static_cast<void>(cerr);
322:     break;
323:   default:
324:     PetscCallCUPM(cerr);
325:     PetscUnreachable();
326:   }
327:   PetscFunctionReturn(PETSC_SUCCESS);
328: }

330: template <DeviceType T>
331: inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
332: {
333:   const auto dcib  = impls_cast_(dctxb);
334:   const auto event = dcib->event;

336:   PetscFunctionBegin;
337:   PetscCall(check_current_device_(dctxa, dctxb));
338:   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
339:   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
340:   PetscFunctionReturn(PETSC_SUCCESS);
341: }

343: template <DeviceType T>
344: inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
345: {
346:   auto idle = PETSC_TRUE;

348:   PetscFunctionBegin;
349:   PetscCall(query(dctx, &idle));
350:   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
351:   PetscFunctionReturn(PETSC_SUCCESS);
352: }

354: template <DeviceType T>
355: template <typename handle_t>
356: inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
357: {
358:   PetscFunctionBegin;
359:   PetscCall(initialize_handle_(handle_t{}, dctx));
360:   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
361:   PetscFunctionReturn(PETSC_SUCCESS);
362: }

364: template <DeviceType T>
365: inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
366: {
367:   const auto dci = impls_cast_(dctx);

369:   PetscFunctionBegin;
370:   PetscCall(check_current_device_(dctx));
371:   #if PetscDefined(USE_DEBUG)
372:   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
373:   dci->timerInUse = PETSC_TRUE;
374:   #endif
375:   if (!dci->begin) {
376:     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
377:     PetscCallCUPM(cupmEventCreate(&dci->begin));
378:     PetscCallCUPM(cupmEventCreate(&dci->end));
379:   }
380:   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
381:   PetscFunctionReturn(PETSC_SUCCESS);
382: }

384: template <DeviceType T>
385: inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
386: {
387:   float      gtime;
388:   const auto dci = impls_cast_(dctx);
389:   const auto end = dci->end;

391:   PetscFunctionBegin;
392:   PetscCall(check_current_device_(dctx));
393:   #if PetscDefined(USE_DEBUG)
394:   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
395:   dci->timerInUse = PETSC_FALSE;
396:   #endif
397:   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
398:   PetscCallCUPM(cupmEventSynchronize(end));
399:   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
400:   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
401:   PetscFunctionReturn(PETSC_SUCCESS);
402: }

404: template <DeviceType T>
405: inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
406: {
407:   const auto &stream = impls_cast_(dctx)->stream;

409:   PetscFunctionBegin;
410:   PetscCall(check_current_device_(dctx));
411:   PetscCall(check_memtype_(mtype, "allocating"));
412:   if (PetscMemTypeHost(mtype)) {
413:     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
414:   } else {
415:     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
416:   }
417:   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
418:   PetscFunctionReturn(PETSC_SUCCESS);
419: }

421: template <DeviceType T>
422: inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
423: {
424:   const auto &stream = impls_cast_(dctx)->stream;

426:   PetscFunctionBegin;
427:   PetscCall(check_current_device_(dctx));
428:   PetscCall(check_memtype_(mtype, "freeing"));
429:   if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
430:   if (PetscMemTypeHost(mtype)) {
431:     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
432:     // if ptr exists still exists the pool didn't own it
433:     if (*ptr) {
434:       auto registered = PETSC_FALSE, managed = PETSC_FALSE;

436:       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
437:       if (registered) {
438:         PetscCallCUPM(cupmFreeHost(*ptr));
439:       } else if (managed) {
440:         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
441:       }
442:     }
443:   } else {
444:     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
445:     // if ptr still exists the pool didn't own it
446:     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
447:   }
448:   PetscFunctionReturn(PETSC_SUCCESS);
449: }

451: template <DeviceType T>
452: inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
453: {
454:   const auto stream = impls_cast_(dctx)->stream.get_stream();

456:   PetscFunctionBegin;
457:   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
458:   if (mode == PETSC_DEVICE_COPY_HTOH) {
459:     const auto cerr = cupmStreamQuery(stream);

461:     // yes this is faster
462:     if (cerr == cupmSuccess) {
463:       PetscCall(PetscMemcpy(dest, src, n));
464:       PetscFunctionReturn(PETSC_SUCCESS);
465:     } else if (cerr == cupmErrorNotReady) {
466:       auto PETSC_UNUSED unused = cupmGetLastError();

468:       static_cast<void>(unused);
469:     } else {
470:       PetscCallCUPM(cerr);
471:     }
472:   }
473:   PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
474:   PetscFunctionReturn(PETSC_SUCCESS);
475: }

477: template <DeviceType T>
478: inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
479: {
480:   PetscFunctionBegin;
481:   PetscCall(check_current_device_(dctx));
482:   PetscCall(check_memtype_(mtype, "zeroing"));
483:   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
484:   PetscFunctionReturn(PETSC_SUCCESS);
485: }

487: template <DeviceType T>
488: inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
489: {
490:   PetscFunctionBegin;
491:   PetscCallCXX(event->data = new event_type());
492:   event->destroy = [](PetscEvent event) {
493:     PetscFunctionBegin;
494:     delete event_cast_(event);
495:     event->data = nullptr;
496:     PetscFunctionReturn(PETSC_SUCCESS);
497:   };
498:   PetscFunctionReturn(PETSC_SUCCESS);
499: }

501: template <DeviceType T>
502: inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
503: {
504:   PetscFunctionBegin;
505:   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
506:   PetscFunctionReturn(PETSC_SUCCESS);
507: }

509: template <DeviceType T>
510: inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
511: {
512:   PetscFunctionBegin;
513:   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
514:   PetscFunctionReturn(PETSC_SUCCESS);
515: }

517: // initialize the static member variables
518: template <DeviceType T>
519: bool DeviceContext<T>::initialized_ = false;

521: template <DeviceType T>
522: std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};

524: template <DeviceType T>
525: std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};

527: template <DeviceType T>
528: constexpr _DeviceContextOps DeviceContext<T>::ops;

530: } // namespace impl

532: // shorten this one up a bit (and instantiate the templates)
533: using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
534: using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;

536:   // shorthand for what is an EXTREMELY long name
537:   #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS

539: } // namespace cupm

541: } // namespace device

543: } // namespace Petsc

545: #endif // __cplusplus

#endif // PETSCDEVICECONTEXTCUPM_HPP