Actual source code: cupmcontext.hpp
1: #ifndef PETSCDEVICECONTEXTCUPM_HPP
2: #define PETSCDEVICECONTEXTCUPM_HPP
4: #include <petsc/private/deviceimpl.h>
5: #include <petsc/private/cupmsolverinterface.hpp>
6: #include <petsc/private/logimpl.h>
8: #include <petsc/private/cpp/array.hpp>
10: #include "../segmentedmempool.hpp"
11: #include "cupmallocator.hpp"
12: #include "cupmstream.hpp"
13: #include "cupmevent.hpp"
15: #if defined(__cplusplus)
17: namespace Petsc
18: {
20: namespace device
21: {
23: namespace cupm
24: {
26: namespace impl
27: {
29: template <DeviceType T>
30: class DeviceContext : SolverInterface<T> {
31: public:
32: PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
34: private:
35: template <typename H, std::size_t>
36: struct HandleTag {
37: using type = H;
38: };
40: using stream_tag = HandleTag<cupmStream_t, 0>;
41: using blas_tag = HandleTag<cupmBlasHandle_t, 1>;
42: using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
44: using stream_type = CUPMStream<T>;
45: using event_type = CUPMEvent<T>;
47: public:
48: // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
49: // header, but since we are using the power of templates it must be declared part of
50: // this class to have easy access the same typedefs. Technically one can make a
51: // templated struct outside the class but it's more code for the same result.
52: struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> {
53: stream_type stream{};
54: cupmEvent_t event{};
55: cupmEvent_t begin{}; // timer-only
56: cupmEvent_t end{}; // timer-only
57: #if PetscDefined(USE_DEBUG)
58: PetscBool timerInUse{};
59: #endif
60: cupmBlasHandle_t blas{};
61: cupmSolverHandle_t solver{};
63: constexpr PetscDeviceContext_IMPLS() noexcept = default;
65: PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); }
67: PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; }
69: PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; }
70: };
72: private:
73: static bool initialized_;
75: static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_;
76: static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
78: PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
80: PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
82: PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
84: PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
86: // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
87: // handles
88: static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }
90: static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
91: {
92: const auto dci = impls_cast_(dctx);
93: auto &handle = blashandles_[dctx->device->deviceId];
95: PetscFunctionBegin;
96: if (!handle) {
97: PetscLogEvent event;
99: PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
100: PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
101: for (auto i = 0; i < 3; ++i) {
102: const auto cberr = cupmBlasCreate(handle.ptr_to());
103: if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
104: if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
105: if (i != 2) {
106: PetscCall(PetscSleep(3));
107: continue;
108: }
109: PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
110: }
111: PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
112: PetscCall(PetscLogEventResume_Internal(event));
113: }
114: PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
115: dci->blas = handle;
116: PetscFunctionReturn(PETSC_SUCCESS);
117: }
119: static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
120: {
121: const auto dci = impls_cast_(dctx);
122: auto &handle = solverhandles_[dctx->device->deviceId];
124: PetscFunctionBegin;
125: if (!handle) {
126: PetscLogEvent event;
128: PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
129: PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
130: for (auto i = 0; i < 3; ++i) {
131: const auto cerr = cupmSolverCreate(&handle);
132: if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
133: if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
134: if (i < 2) {
135: PetscCall(PetscSleep(3));
136: continue;
137: }
138: PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
139: }
140: PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
141: PetscCall(PetscLogEventResume_Internal(event));
142: }
143: PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
144: dci->solver = handle;
145: PetscFunctionReturn(PETSC_SUCCESS);
146: }
148: static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
149: {
150: const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
152: PetscFunctionBegin;
153: PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
154: PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
155: PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
156: PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
157: PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
158: PetscFunctionReturn(PETSC_SUCCESS);
159: }
161: static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
163: static PetscErrorCode finalize_() noexcept
164: {
165: PetscFunctionBegin;
166: for (auto &&handle : blashandles_) {
167: if (handle) {
168: PetscCallCUPMBLAS(cupmBlasDestroy(handle));
169: handle = nullptr;
170: }
171: }
173: for (auto &&handle : solverhandles_) {
174: if (handle) {
175: PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
176: handle = nullptr;
177: }
178: }
179: initialized_ = false;
180: PetscFunctionReturn(PETSC_SUCCESS);
181: }
183: template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
184: PETSC_NODISCARD static PoolType &default_pool_() noexcept
185: {
186: static PoolType pool;
187: return pool;
188: }
190: static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
191: {
192: PetscFunctionBegin;
193: PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
194: PetscFunctionReturn(PETSC_SUCCESS);
195: }
197: public:
198: // All of these functions MUST be static in order to be callable from C, otherwise they
199: // get the implicit 'this' pointer tacked on
200: static PetscErrorCode destroy(PetscDeviceContext) noexcept;
201: static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
202: static PetscErrorCode setUp(PetscDeviceContext) noexcept;
203: static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
204: static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
205: static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
206: template <typename Handle_t>
207: static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
208: static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
209: static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
210: static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
211: static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
212: static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
213: static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
214: static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
215: static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
216: static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
218: // not a PetscDeviceContext method, this registers the class
219: static PetscErrorCode initialize(PetscDevice) noexcept;
221: // clang-format off
222: static constexpr _DeviceContextOps ops = {
223: PetscDesignatedInitializer(destroy, destroy),
224: PetscDesignatedInitializer(changestreamtype, changeStreamType),
225: PetscDesignatedInitializer(setup, setUp),
226: PetscDesignatedInitializer(query, query),
227: PetscDesignatedInitializer(waitforcontext, waitForContext),
228: PetscDesignatedInitializer(synchronize, synchronize),
229: PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
230: PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
231: PetscDesignatedInitializer(getstreamhandle, getHandle<stream_tag>),
232: PetscDesignatedInitializer(begintimer, beginTimer),
233: PetscDesignatedInitializer(endtimer, endTimer),
234: PetscDesignatedInitializer(memalloc, memAlloc),
235: PetscDesignatedInitializer(memfree, memFree),
236: PetscDesignatedInitializer(memcopy, memCopy),
237: PetscDesignatedInitializer(memset, memSet),
238: PetscDesignatedInitializer(createevent, createEvent),
239: PetscDesignatedInitializer(recordevent, recordEvent),
240: PetscDesignatedInitializer(waitforevent, waitForEvent)
241: };
242: // clang-format on
243: };
245: // not a PetscDeviceContext method, this initializes the CLASS
246: template <DeviceType T>
247: inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
248: {
249: PetscFunctionBegin;
250: if (PetscUnlikely(!initialized_)) {
251: uint64_t threshold = UINT64_MAX;
252: cupmMemPool_t mempool;
254: initialized_ = true;
255: PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
256: PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
257: blashandles_.fill(nullptr);
258: solverhandles_.fill(nullptr);
259: PetscCall(PetscRegisterFinalize(finalize_));
260: }
261: PetscFunctionReturn(PETSC_SUCCESS);
262: }
264: template <DeviceType T>
265: inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
266: {
267: PetscFunctionBegin;
268: if (const auto dci = impls_cast_(dctx)) {
269: PetscCall(dci->stream.destroy());
270: if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
271: if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
272: if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
273: delete dci;
274: dctx->data = nullptr;
275: }
276: PetscFunctionReturn(PETSC_SUCCESS);
277: }
279: template <DeviceType T>
280: inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
281: {
282: const auto dci = impls_cast_(dctx);
284: PetscFunctionBegin;
285: PetscCall(dci->stream.destroy());
286: // set these to null so they aren't usable until setup is called again
287: dci->blas = nullptr;
288: dci->solver = nullptr;
289: PetscFunctionReturn(PETSC_SUCCESS);
290: }
292: template <DeviceType T>
293: inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
294: {
295: const auto dci = impls_cast_(dctx);
296: auto &event = dci->event;
298: PetscFunctionBegin;
299: PetscCall(check_current_device_(dctx));
300: PetscCall(dci->stream.change_type(dctx->streamType));
301: if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
302: #if PetscDefined(USE_DEBUG)
303: dci->timerInUse = PETSC_FALSE;
304: #endif
305: PetscFunctionReturn(PETSC_SUCCESS);
306: }
308: template <DeviceType T>
309: inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
310: {
311: PetscFunctionBegin;
312: PetscCall(check_current_device_(dctx));
313: switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
314: case cupmSuccess:
315: *idle = PETSC_TRUE;
316: break;
317: case cupmErrorNotReady:
318: *idle = PETSC_FALSE;
319: // reset the error
320: cerr = cupmGetLastError();
321: static_cast<void>(cerr);
322: break;
323: default:
324: PetscCallCUPM(cerr);
325: PetscUnreachable();
326: }
327: PetscFunctionReturn(PETSC_SUCCESS);
328: }
330: template <DeviceType T>
331: inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
332: {
333: const auto dcib = impls_cast_(dctxb);
334: const auto event = dcib->event;
336: PetscFunctionBegin;
337: PetscCall(check_current_device_(dctxa, dctxb));
338: PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
339: PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
340: PetscFunctionReturn(PETSC_SUCCESS);
341: }
343: template <DeviceType T>
344: inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
345: {
346: auto idle = PETSC_TRUE;
348: PetscFunctionBegin;
349: PetscCall(query(dctx, &idle));
350: if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
351: PetscFunctionReturn(PETSC_SUCCESS);
352: }
354: template <DeviceType T>
355: template <typename handle_t>
356: inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
357: {
358: PetscFunctionBegin;
359: PetscCall(initialize_handle_(handle_t{}, dctx));
360: *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
361: PetscFunctionReturn(PETSC_SUCCESS);
362: }
364: template <DeviceType T>
365: inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
366: {
367: const auto dci = impls_cast_(dctx);
369: PetscFunctionBegin;
370: PetscCall(check_current_device_(dctx));
371: #if PetscDefined(USE_DEBUG)
372: PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
373: dci->timerInUse = PETSC_TRUE;
374: #endif
375: if (!dci->begin) {
376: PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
377: PetscCallCUPM(cupmEventCreate(&dci->begin));
378: PetscCallCUPM(cupmEventCreate(&dci->end));
379: }
380: PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
381: PetscFunctionReturn(PETSC_SUCCESS);
382: }
384: template <DeviceType T>
385: inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
386: {
387: float gtime;
388: const auto dci = impls_cast_(dctx);
389: const auto end = dci->end;
391: PetscFunctionBegin;
392: PetscCall(check_current_device_(dctx));
393: #if PetscDefined(USE_DEBUG)
394: PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
395: dci->timerInUse = PETSC_FALSE;
396: #endif
397: PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
398: PetscCallCUPM(cupmEventSynchronize(end));
399: PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end));
400: *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
401: PetscFunctionReturn(PETSC_SUCCESS);
402: }
404: template <DeviceType T>
405: inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
406: {
407: const auto &stream = impls_cast_(dctx)->stream;
409: PetscFunctionBegin;
410: PetscCall(check_current_device_(dctx));
411: PetscCall(check_memtype_(mtype, "allocating"));
412: if (PetscMemTypeHost(mtype)) {
413: PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
414: } else {
415: PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
416: }
417: if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
418: PetscFunctionReturn(PETSC_SUCCESS);
419: }
421: template <DeviceType T>
422: inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
423: {
424: const auto &stream = impls_cast_(dctx)->stream;
426: PetscFunctionBegin;
427: PetscCall(check_current_device_(dctx));
428: PetscCall(check_memtype_(mtype, "freeing"));
429: if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
430: if (PetscMemTypeHost(mtype)) {
431: PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
432: // if ptr exists still exists the pool didn't own it
433: if (*ptr) {
434: auto registered = PETSC_FALSE, managed = PETSC_FALSE;
436: PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed));
437: if (registered) {
438: PetscCallCUPM(cupmFreeHost(*ptr));
439: } else if (managed) {
440: PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
441: }
442: }
443: } else {
444: PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
445: // if ptr still exists the pool didn't own it
446: if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
447: }
448: PetscFunctionReturn(PETSC_SUCCESS);
449: }
451: template <DeviceType T>
452: inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
453: {
454: const auto stream = impls_cast_(dctx)->stream.get_stream();
456: PetscFunctionBegin;
457: // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
458: if (mode == PETSC_DEVICE_COPY_HTOH) {
459: const auto cerr = cupmStreamQuery(stream);
461: // yes this is faster
462: if (cerr == cupmSuccess) {
463: PetscCall(PetscMemcpy(dest, src, n));
464: PetscFunctionReturn(PETSC_SUCCESS);
465: } else if (cerr == cupmErrorNotReady) {
466: auto PETSC_UNUSED unused = cupmGetLastError();
468: static_cast<void>(unused);
469: } else {
470: PetscCallCUPM(cerr);
471: }
472: }
473: PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
474: PetscFunctionReturn(PETSC_SUCCESS);
475: }
477: template <DeviceType T>
478: inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
479: {
480: PetscFunctionBegin;
481: PetscCall(check_current_device_(dctx));
482: PetscCall(check_memtype_(mtype, "zeroing"));
483: PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
484: PetscFunctionReturn(PETSC_SUCCESS);
485: }
487: template <DeviceType T>
488: inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
489: {
490: PetscFunctionBegin;
491: PetscCallCXX(event->data = new event_type());
492: event->destroy = [](PetscEvent event) {
493: PetscFunctionBegin;
494: delete event_cast_(event);
495: event->data = nullptr;
496: PetscFunctionReturn(PETSC_SUCCESS);
497: };
498: PetscFunctionReturn(PETSC_SUCCESS);
499: }
501: template <DeviceType T>
502: inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
503: {
504: PetscFunctionBegin;
505: PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
506: PetscFunctionReturn(PETSC_SUCCESS);
507: }
509: template <DeviceType T>
510: inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
511: {
512: PetscFunctionBegin;
513: PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
514: PetscFunctionReturn(PETSC_SUCCESS);
515: }
517: // initialize the static member variables
518: template <DeviceType T>
519: bool DeviceContext<T>::initialized_ = false;
521: template <DeviceType T>
522: std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
524: template <DeviceType T>
525: std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
527: template <DeviceType T>
528: constexpr _DeviceContextOps DeviceContext<T>::ops;
530: } // namespace impl
// shorten this one up a bit (aliases for the CUDA and HIP specializations)
533: using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
534: using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>;
536: // shorthand for what is an EXTREMELY long name
537: #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
539: } // namespace cupm
541: } // namespace device
543: } // namespace Petsc
545: #endif // __cplusplus
#endif // PETSCDEVICECONTEXTCUPM_HPP