Actual source code: cupmobject.hpp
1: #ifndef PETSC_PRIVATE_CUPMOBJECT_HPP
2: #define PETSC_PRIVATE_CUPMOBJECT_HPP
4: #ifdef __cplusplus
5: #include <petsc/private/deviceimpl.h>
6: #include <petsc/private/cupmsolverinterface.hpp>
8: #include <cstring> // std::memset
10: namespace
11: {
13: inline PetscErrorCode PetscStrFreeAllocpy(const char target[], char **dest) noexcept
14: {
15: PetscFunctionBegin;
17: if (*dest) {
19: PetscCall(PetscFree(*dest));
20: }
21: PetscCall(PetscStrallocpy(target, dest));
22: PetscFunctionReturn(PETSC_SUCCESS);
23: }
25: } // namespace
27: namespace Petsc
28: {
30: namespace device
31: {
33: namespace cupm
34: {
36: namespace impl
37: {
39: namespace
40: {
42: // ==========================================================================================
43: // UseCUPMHostAllocGuard
44: //
45: // A simple RAII helper for PetscMallocSet[CUDA|HIP]Host(). it exists because integrating the
46: // regular versions would be an enormous pain to square with the templated types...
47: // ==========================================================================================
48: template <DeviceType T>
49: class UseCUPMHostAllocGuard : Interface<T> {
50: public:
51: PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
53: UseCUPMHostAllocGuard(bool) noexcept;
54: ~UseCUPMHostAllocGuard() noexcept;
56: PETSC_NODISCARD bool value() const noexcept;
58: private:
59: // would have loved to just do
60: //
61: // const auto oldmalloc = PetscTrMalloc;
62: //
63: // but in order to use auto the member needs to be static; in order to be static it must
64: // also be constexpr -- which in turn requires an initializer (also implicitly required by
65: // auto). But constexpr needs a constant expression initializer, so we can't initialize it
66: // with global (mutable) variables...
67: #define DECLTYPE_AUTO(left, right) decltype(right) left = right
68: const DECLTYPE_AUTO(oldmalloc_, PetscTrMalloc);
69: const DECLTYPE_AUTO(oldfree_, PetscTrFree);
70: const DECLTYPE_AUTO(oldrealloc_, PetscTrRealloc);
71: #undef DECLTYPE_AUTO
72: bool v_;
73: };
75: // ==========================================================================================
76: // UseCUPMHostAllocGuard -- Public API
77: // ==========================================================================================
79: template <DeviceType T>
80: inline UseCUPMHostAllocGuard<T>::UseCUPMHostAllocGuard(bool useit) noexcept : v_(useit)
81: {
82: PetscFunctionBegin;
83: if (useit) {
84: // all unused arguments are un-named, this saves having to add PETSC_UNUSED to them all
85: PetscTrMalloc = [](std::size_t sz, PetscBool clear, int, const char *, const char *, void **ptr) {
86: PetscFunctionBegin;
87: PetscCallCUPM(cupmMallocHost(ptr, sz));
88: if (clear) std::memset(*ptr, 0, sz);
89: PetscFunctionReturn(PETSC_SUCCESS);
90: };
91: PetscTrFree = [](void *ptr, int, const char *, const char *) {
92: PetscFunctionBegin;
93: PetscCallCUPM(cupmFreeHost(ptr));
94: PetscFunctionReturn(PETSC_SUCCESS);
95: };
96: PetscTrRealloc = [](std::size_t, int, const char *, const char *, void **) {
97: // REVIEW ME: can be implemented by malloc->copy->free?
98: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "%s has no realloc()", cupmName());
99: };
100: }
101: PetscFunctionReturnVoid();
102: }
104: template <DeviceType T>
105: inline UseCUPMHostAllocGuard<T>::~UseCUPMHostAllocGuard() noexcept
106: {
107: PetscFunctionBegin;
108: if (value()) {
109: PetscTrMalloc = oldmalloc_;
110: PetscTrFree = oldfree_;
111: PetscTrRealloc = oldrealloc_;
112: }
113: PetscFunctionReturnVoid();
114: }
116: template <DeviceType T>
117: inline bool UseCUPMHostAllocGuard<T>::value() const noexcept
118: {
119: return v_;
120: }
122: } // anonymous namespace
124: template <DeviceType T, PetscMemType MemoryType, PetscMemoryAccessMode AccessMode>
125: class RestoreableArray : Interface<T> {
126: public:
127: PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
129: static constexpr auto memory_type = MemoryType;
130: static constexpr auto access_type = AccessMode;
132: using value_type = PetscScalar;
133: using pointer_type = value_type *;
134: using cupm_pointer_type = cupmScalar_t *;
136: PETSC_NODISCARD pointer_type data() const noexcept;
137: PETSC_NODISCARD cupm_pointer_type cupmdata() const noexcept;
139: operator pointer_type() const noexcept;
140: // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
141: // we make a dummy template parameter to allow SFINAE to nix it for us
142: template <typename U = pointer_type, typename = util::enable_if_t<!std::is_same<U, cupm_pointer_type>::value>>
143: operator cupm_pointer_type() const noexcept;
145: protected:
146: constexpr explicit RestoreableArray(PetscDeviceContext) noexcept;
148: value_type *ptr_ = nullptr;
149: PetscDeviceContext dctx_ = nullptr;
150: };
152: // ==========================================================================================
153: // RestoreableArray - Static Variables
154: // ==========================================================================================
156: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
157: const PetscMemType RestoreableArray<T, MT, MA>::memory_type;
159: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
160: const PetscMemoryAccessMode RestoreableArray<T, MT, MA>::access_type;
162: // ==========================================================================================
163: // RestoreableArray - Public API
164: // ==========================================================================================
166: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
167: constexpr inline RestoreableArray<T, MT, MA>::RestoreableArray(PetscDeviceContext dctx) noexcept : dctx_{dctx}
168: {
169: }
171: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
172: inline typename RestoreableArray<T, MT, MA>::pointer_type RestoreableArray<T, MT, MA>::data() const noexcept
173: {
174: return ptr_;
175: }
177: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
178: inline typename RestoreableArray<T, MT, MA>::cupm_pointer_type RestoreableArray<T, MT, MA>::cupmdata() const noexcept
179: {
180: return cupmScalarPtrCast(data());
181: }
183: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
184: inline RestoreableArray<T, MT, MA>::operator pointer_type() const noexcept
185: {
186: return data();
187: }
189: // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
190: // we make a dummy template parameter to allow SFINAE to nix it for us
191: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
192: template <typename U, typename>
193: inline RestoreableArray<T, MT, MA>::operator cupm_pointer_type() const noexcept
194: {
195: return cupmdata();
196: }
198: template <DeviceType T>
199: class CUPMObject : SolverInterface<T> {
200: protected:
201: PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
203: private:
204: // The final stop in the GetHandles_/GetFromHandles_ chain. This retrieves the various
205: // compute handles and ensure the given PetscDeviceContext is of the right type
206: static PetscErrorCode GetFromHandleDispatch_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
207: static PetscErrorCode GetHandleDispatch_(PetscDeviceContext *, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
209: protected:
210: PETSC_NODISCARD static constexpr PetscRandomType PETSCDEVICERAND() noexcept;
212: // Helper routines to retrieve various combinations of handles. The first set (GetHandles_)
213: // gets a PetscDeviceContext along with it, while the second set (GetHandlesFrom_) assumes
214: // you've gotten the PetscDeviceContext already, and retrieves the handles from it. All of
215: // them check that the PetscDeviceContext is of the appropriate type
216: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t * = nullptr, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
218: // triple
219: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t *, cupmStream_t *) noexcept;
220: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
222: // double
223: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *) noexcept;
224: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmStream_t *) noexcept;
226: // single
227: static PetscErrorCode GetHandles_(cupmBlasHandle_t *) noexcept;
228: static PetscErrorCode GetHandles_(cupmSolverHandle_t *) noexcept;
229: static PetscErrorCode GetHandles_(cupmStream_t *) noexcept;
231: static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
232: static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmSolverHandle_t *, cupmStream_t * = nullptr) noexcept;
233: static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmStream_t *) noexcept;
235: // disallow implicit conversion
236: template <typename U>
237: PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(U) noexcept = delete;
238: // utility for using cupmHostAlloc()
239: PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(bool) noexcept;
240: PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(PetscBool) noexcept;
242: // A debug check to ensure that a given pointer-memtype pairing taken from user-land is
243: // actually correct. Errors on mismatch
244: static PetscErrorCode CheckPointerMatchesMemType_(const void *, PetscMemType) noexcept;
245: };
247: template <DeviceType T>
248: inline constexpr PetscRandomType CUPMObject<T>::PETSCDEVICERAND() noexcept
249: {
250: // REVIEW ME: HIP default rng?
251: return T == DeviceType::CUDA ? PETSCCURAND : PETSCRANDER48;
252: }
254: template <DeviceType T>
255: inline PetscErrorCode CUPMObject<T>::GetFromHandleDispatch_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
256: {
257: PetscFunctionBegin;
262: if (PetscDefined(USE_DEBUG)) {
263: PetscDeviceType dtype;
265: PetscCall(PetscDeviceContextGetDeviceType(dctx, &dtype));
266: PetscCheckCompatibleDeviceTypes(PETSC_DEVICE_CUPM(), -1, dtype, 1);
267: }
268: if (blas_handle) PetscCall(PetscDeviceContextGetBLASHandle_Internal(dctx, blas_handle));
269: if (solver_handle) PetscCall(PetscDeviceContextGetSOLVERHandle_Internal(dctx, solver_handle));
270: if (stream) PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, stream));
271: PetscFunctionReturn(PETSC_SUCCESS);
272: }
274: template <DeviceType T>
275: inline PetscErrorCode CUPMObject<T>::GetHandleDispatch_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
276: {
277: PetscDeviceContext dctx_loc = nullptr;
279: PetscFunctionBegin;
280: // silence uninitialized variable warnings
281: if (dctx) *dctx = nullptr;
282: PetscCall(PetscDeviceContextGetCurrentContext(&dctx_loc));
283: PetscCall(GetFromHandleDispatch_(dctx_loc, blas_handle, solver_handle, stream));
284: if (dctx) *dctx = dctx_loc;
285: PetscFunctionReturn(PETSC_SUCCESS);
286: }
288: template <DeviceType T>
289: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
290: {
291: return GetHandleDispatch_(dctx, blas_handle, solver_handle, stream);
292: }
294: template <DeviceType T>
295: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmStream_t *stream) noexcept
296: {
297: return GetHandleDispatch_(dctx, blas_handle, nullptr, stream);
298: }
300: template <DeviceType T>
301: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
302: {
303: return GetHandleDispatch_(dctx, nullptr, solver_handle, stream);
304: }
306: template <DeviceType T>
307: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmStream_t *stream) noexcept
308: {
309: return GetHandleDispatch_(dctx, nullptr, nullptr, stream);
310: }
312: template <DeviceType T>
313: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmBlasHandle_t *handle) noexcept
314: {
315: return GetHandleDispatch_(nullptr, handle, nullptr, nullptr);
316: }
318: template <DeviceType T>
319: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmSolverHandle_t *handle) noexcept
320: {
321: return GetHandleDispatch_(nullptr, nullptr, handle, nullptr);
322: }
324: template <DeviceType T>
325: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmStream_t *stream) noexcept
326: {
327: return GetHandleDispatch_(nullptr, nullptr, nullptr, stream);
328: }
330: template <DeviceType T>
331: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
332: {
333: return GetFromHandleDispatch_(dctx, blas_handle, solver_handle, stream);
334: }
336: template <DeviceType T>
337: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
338: {
339: return GetFromHandleDispatch_(dctx, nullptr, solver_handle, stream);
340: }
342: template <DeviceType T>
343: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmStream_t *stream) noexcept
344: {
345: return GetFromHandleDispatch_(dctx, nullptr, nullptr, stream);
346: }
348: template <DeviceType T>
349: inline UseCUPMHostAllocGuard<T> CUPMObject<T>::UseCUPMHostAlloc(bool b) noexcept
350: {
351: return {b};
352: }
354: template <DeviceType T>
355: inline UseCUPMHostAllocGuard<T> CUPMObject<T>::UseCUPMHostAlloc(PetscBool b) noexcept
356: {
357: return UseCUPMHostAlloc(static_cast<bool>(b));
358: }
360: template <DeviceType T>
361: inline PetscErrorCode CUPMObject<T>::CheckPointerMatchesMemType_(const void *ptr, PetscMemType mtype) noexcept
362: {
363: PetscFunctionBegin;
364: if (PetscDefined(USE_DEBUG) && ptr) {
365: PetscMemType ptr_mtype;
367: PetscCall(PetscCUPMGetMemType(ptr, &ptr_mtype));
368: if (mtype == PETSC_MEMTYPE_HOST) {
369: PetscCheck(PetscMemTypeHost(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
370: } else if (mtype == PETSC_MEMTYPE_DEVICE) {
371: // generic "device" memory should only care if the actual memtype is also generically
372: // "device"
373: PetscCheck(PetscMemTypeDevice(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
374: } else {
375: PetscCheck(mtype == ptr_mtype, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
376: }
377: }
378: PetscFunctionReturn(PETSC_SUCCESS);
379: }
// For use inside a class deriving from CUPMObject<T>: pulls in the CUPM solver interface
// typedefs and using-declares the listed CUPMObject<T> helpers into the enclosing class. The
// final using-declaration carries no semicolon, so invocations are written with a trailing ';'.
#define PETSC_CUPMOBJECT_HEADER(T) \
  PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::PETSCDEVICERAND; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandles_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandlesFrom_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::UseCUPMHostAlloc; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::CheckPointerMatchesMemType_
389: } // namespace impl
391: } // namespace cupm
393: } // namespace device
395: } // namespace Petsc
397: #endif // __cplusplus
399: #endif // PETSC_PRIVATE_CUPMOBJECT_HPP