Actual source code: cupmobject.hpp

  1: #ifndef PETSC_PRIVATE_CUPMOBJECT_HPP
  2: #define PETSC_PRIVATE_CUPMOBJECT_HPP

  4: #ifdef __cplusplus
  5: #include <petsc/private/deviceimpl.h>
  6: #include <petsc/private/cupmsolverinterface.hpp>

  8:   #include <cstring> // std::memset

 10: namespace
 11: {

 13: inline PetscErrorCode PetscStrFreeAllocpy(const char target[], char **dest) noexcept
 14: {
 15:   PetscFunctionBegin;
 17:   if (*dest) {
 19:     PetscCall(PetscFree(*dest));
 20:   }
 21:   PetscCall(PetscStrallocpy(target, dest));
 22:   PetscFunctionReturn(PETSC_SUCCESS);
 23: }

 25: } // namespace

 27: namespace Petsc
 28: {

 30: namespace device
 31: {

 33: namespace cupm
 34: {

 36: namespace impl
 37: {

 39: namespace
 40: {

 42: // ==========================================================================================
 43: // UseCUPMHostAllocGuard
 44: //
 45: // A simple RAII helper for PetscMallocSet[CUDA|HIP]Host(). it exists because integrating the
 46: // regular versions would be an enormous pain to square with the templated types...
 47: // ==========================================================================================
 48: template <DeviceType T>
 49: class UseCUPMHostAllocGuard : Interface<T> {
 50: public:
 51:   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);

 53:   UseCUPMHostAllocGuard(bool) noexcept;
 54:   ~UseCUPMHostAllocGuard() noexcept;

 56:   PETSC_NODISCARD bool value() const noexcept;

 58: private:
 59:     // would have loved to just do
 60:     //
 61:     // const auto oldmalloc = PetscTrMalloc;
 62:     //
 63:     // but in order to use auto the member needs to be static; in order to be static it must
 64:     // also be constexpr -- which in turn requires an initializer (also implicitly required by
 65:     // auto). But constexpr needs a constant expression initializer, so we can't initialize it
 66:     // with global (mutable) variables...
 67:   #define DECLTYPE_AUTO(left, right) decltype(right) left = right
 68:   const DECLTYPE_AUTO(oldmalloc_, PetscTrMalloc);
 69:   const DECLTYPE_AUTO(oldfree_, PetscTrFree);
 70:   const DECLTYPE_AUTO(oldrealloc_, PetscTrRealloc);
 71:   #undef DECLTYPE_AUTO
 72:   bool v_;
 73: };

 75: // ==========================================================================================
 76: // UseCUPMHostAllocGuard -- Public API
 77: // ==========================================================================================

 79: template <DeviceType T>
 80: inline UseCUPMHostAllocGuard<T>::UseCUPMHostAllocGuard(bool useit) noexcept : v_(useit)
 81: {
 82:   PetscFunctionBegin;
 83:   if (useit) {
 84:     // all unused arguments are un-named, this saves having to add PETSC_UNUSED to them all
 85:     PetscTrMalloc = [](std::size_t sz, PetscBool clear, int, const char *, const char *, void **ptr) {
 86:       PetscFunctionBegin;
 87:       PetscCallCUPM(cupmMallocHost(ptr, sz));
 88:       if (clear) std::memset(*ptr, 0, sz);
 89:       PetscFunctionReturn(PETSC_SUCCESS);
 90:     };
 91:     PetscTrFree = [](void *ptr, int, const char *, const char *) {
 92:       PetscFunctionBegin;
 93:       PetscCallCUPM(cupmFreeHost(ptr));
 94:       PetscFunctionReturn(PETSC_SUCCESS);
 95:     };
 96:     PetscTrRealloc = [](std::size_t, int, const char *, const char *, void **) {
 97:       // REVIEW ME: can be implemented by malloc->copy->free?
 98:       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "%s has no realloc()", cupmName());
 99:     };
100:   }
101:   PetscFunctionReturnVoid();
102: }

104: template <DeviceType T>
105: inline UseCUPMHostAllocGuard<T>::~UseCUPMHostAllocGuard() noexcept
106: {
107:   PetscFunctionBegin;
108:   if (value()) {
109:     PetscTrMalloc  = oldmalloc_;
110:     PetscTrFree    = oldfree_;
111:     PetscTrRealloc = oldrealloc_;
112:   }
113:   PetscFunctionReturnVoid();
114: }

116: template <DeviceType T>
117: inline bool UseCUPMHostAllocGuard<T>::value() const noexcept
118: {
119:   return v_;
120: }

122: } // anonymous namespace

124: template <DeviceType T, PetscMemType MemoryType, PetscMemoryAccessMode AccessMode>
125: class RestoreableArray : Interface<T> {
126: public:
127:   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);

129:   static constexpr auto memory_type = MemoryType;
130:   static constexpr auto access_type = AccessMode;

132:   using value_type        = PetscScalar;
133:   using pointer_type      = value_type *;
134:   using cupm_pointer_type = cupmScalar_t *;

136:   PETSC_NODISCARD pointer_type      data() const noexcept;
137:   PETSC_NODISCARD cupm_pointer_type cupmdata() const noexcept;

139:   operator pointer_type() const noexcept;
140:   // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
141:   // we make a dummy template parameter to allow SFINAE to nix it for us
142:   template <typename U = pointer_type, typename = util::enable_if_t<!std::is_same<U, cupm_pointer_type>::value>>
143:   operator cupm_pointer_type() const noexcept;

145: protected:
146:   constexpr explicit RestoreableArray(PetscDeviceContext) noexcept;

148:   value_type        *ptr_  = nullptr;
149:   PetscDeviceContext dctx_ = nullptr;
150: };

152: // ==========================================================================================
153: // RestoreableArray - Static Variables
154: // ==========================================================================================

156: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
157: const PetscMemType RestoreableArray<T, MT, MA>::memory_type;

159: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
160: const PetscMemoryAccessMode RestoreableArray<T, MT, MA>::access_type;

162: // ==========================================================================================
163: // RestoreableArray - Public API
164: // ==========================================================================================

166: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
167: constexpr inline RestoreableArray<T, MT, MA>::RestoreableArray(PetscDeviceContext dctx) noexcept : dctx_{dctx}
168: {
169: }

171: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
172: inline typename RestoreableArray<T, MT, MA>::pointer_type RestoreableArray<T, MT, MA>::data() const noexcept
173: {
174:   return ptr_;
175: }

177: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
178: inline typename RestoreableArray<T, MT, MA>::cupm_pointer_type RestoreableArray<T, MT, MA>::cupmdata() const noexcept
179: {
180:   return cupmScalarPtrCast(data());
181: }

183: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
184: inline RestoreableArray<T, MT, MA>::operator pointer_type() const noexcept
185: {
186:   return data();
187: }

189: // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
190: // we make a dummy template parameter to allow SFINAE to nix it for us
191: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
192: template <typename U, typename>
193: inline RestoreableArray<T, MT, MA>::operator cupm_pointer_type() const noexcept
194: {
195:   return cupmdata();
196: }

198: template <DeviceType T>
199: class CUPMObject : SolverInterface<T> {
200: protected:
201:   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);

203: private:
204:   // The final stop in the GetHandles_/GetFromHandles_ chain. This retrieves the various
205:   // compute handles and ensure the given PetscDeviceContext is of the right type
206:   static PetscErrorCode GetFromHandleDispatch_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
207:   static PetscErrorCode GetHandleDispatch_(PetscDeviceContext *, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;

209: protected:
210:   PETSC_NODISCARD static constexpr PetscRandomType PETSCDEVICERAND() noexcept;

212:   // Helper routines to retrieve various combinations of handles. The first set (GetHandles_)
213:   // gets a PetscDeviceContext along with it, while the second set (GetHandlesFrom_) assumes
214:   // you've gotten the PetscDeviceContext already, and retrieves the handles from it. All of
215:   // them check that the PetscDeviceContext is of the appropriate type
216:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t * = nullptr, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;

218:   // triple
219:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t *, cupmStream_t *) noexcept;
220:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *, cupmStream_t *) noexcept;

222:   // double
223:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *) noexcept;
224:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmStream_t *) noexcept;

226:   // single
227:   static PetscErrorCode GetHandles_(cupmBlasHandle_t *) noexcept;
228:   static PetscErrorCode GetHandles_(cupmSolverHandle_t *) noexcept;
229:   static PetscErrorCode GetHandles_(cupmStream_t *) noexcept;

231:   static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
232:   static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmSolverHandle_t *, cupmStream_t * = nullptr) noexcept;
233:   static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmStream_t *) noexcept;

235:   // disallow implicit conversion
236:   template <typename U>
237:   PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(U) noexcept = delete;
238:   // utility for using cupmHostAlloc()
239:   PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(bool) noexcept;
240:   PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(PetscBool) noexcept;

242:   // A debug check to ensure that a given pointer-memtype pairing taken from user-land is
243:   // actually correct. Errors on mismatch
244:   static PetscErrorCode CheckPointerMatchesMemType_(const void *, PetscMemType) noexcept;
245: };

247: template <DeviceType T>
248: inline constexpr PetscRandomType CUPMObject<T>::PETSCDEVICERAND() noexcept
249: {
250:   // REVIEW ME: HIP default rng?
251:   return T == DeviceType::CUDA ? PETSCCURAND : PETSCRANDER48;
252: }

254: template <DeviceType T>
255: inline PetscErrorCode CUPMObject<T>::GetFromHandleDispatch_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
256: {
257:   PetscFunctionBegin;
262:   if (PetscDefined(USE_DEBUG)) {
263:     PetscDeviceType dtype;

265:     PetscCall(PetscDeviceContextGetDeviceType(dctx, &dtype));
266:     PetscCheckCompatibleDeviceTypes(PETSC_DEVICE_CUPM(), -1, dtype, 1);
267:   }
268:   if (blas_handle) PetscCall(PetscDeviceContextGetBLASHandle_Internal(dctx, blas_handle));
269:   if (solver_handle) PetscCall(PetscDeviceContextGetSOLVERHandle_Internal(dctx, solver_handle));
270:   if (stream) PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, stream));
271:   PetscFunctionReturn(PETSC_SUCCESS);
272: }

274: template <DeviceType T>
275: inline PetscErrorCode CUPMObject<T>::GetHandleDispatch_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
276: {
277:   PetscDeviceContext dctx_loc = nullptr;

279:   PetscFunctionBegin;
280:   // silence uninitialized variable warnings
281:   if (dctx) *dctx = nullptr;
282:   PetscCall(PetscDeviceContextGetCurrentContext(&dctx_loc));
283:   PetscCall(GetFromHandleDispatch_(dctx_loc, blas_handle, solver_handle, stream));
284:   if (dctx) *dctx = dctx_loc;
285:   PetscFunctionReturn(PETSC_SUCCESS);
286: }

288: template <DeviceType T>
289: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
290: {
291:   return GetHandleDispatch_(dctx, blas_handle, solver_handle, stream);
292: }

294: template <DeviceType T>
295: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmStream_t *stream) noexcept
296: {
297:   return GetHandleDispatch_(dctx, blas_handle, nullptr, stream);
298: }

300: template <DeviceType T>
301: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
302: {
303:   return GetHandleDispatch_(dctx, nullptr, solver_handle, stream);
304: }

306: template <DeviceType T>
307: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmStream_t *stream) noexcept
308: {
309:   return GetHandleDispatch_(dctx, nullptr, nullptr, stream);
310: }

312: template <DeviceType T>
313: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmBlasHandle_t *handle) noexcept
314: {
315:   return GetHandleDispatch_(nullptr, handle, nullptr, nullptr);
316: }

318: template <DeviceType T>
319: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmSolverHandle_t *handle) noexcept
320: {
321:   return GetHandleDispatch_(nullptr, nullptr, handle, nullptr);
322: }

324: template <DeviceType T>
325: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmStream_t *stream) noexcept
326: {
327:   return GetHandleDispatch_(nullptr, nullptr, nullptr, stream);
328: }

330: template <DeviceType T>
331: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
332: {
333:   return GetFromHandleDispatch_(dctx, blas_handle, solver_handle, stream);
334: }

336: template <DeviceType T>
337: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
338: {
339:   return GetFromHandleDispatch_(dctx, nullptr, solver_handle, stream);
340: }

342: template <DeviceType T>
343: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmStream_t *stream) noexcept
344: {
345:   return GetFromHandleDispatch_(dctx, nullptr, nullptr, stream);
346: }

348: template <DeviceType T>
349: inline UseCUPMHostAllocGuard<T> CUPMObject<T>::UseCUPMHostAlloc(bool b) noexcept
350: {
351:   return {b};
352: }

354: template <DeviceType T>
355: inline UseCUPMHostAllocGuard<T> CUPMObject<T>::UseCUPMHostAlloc(PetscBool b) noexcept
356: {
357:   return UseCUPMHostAlloc(static_cast<bool>(b));
358: }

360: template <DeviceType T>
361: inline PetscErrorCode CUPMObject<T>::CheckPointerMatchesMemType_(const void *ptr, PetscMemType mtype) noexcept
362: {
363:   PetscFunctionBegin;
364:   if (PetscDefined(USE_DEBUG) && ptr) {
365:     PetscMemType ptr_mtype;

367:     PetscCall(PetscCUPMGetMemType(ptr, &ptr_mtype));
368:     if (mtype == PETSC_MEMTYPE_HOST) {
369:       PetscCheck(PetscMemTypeHost(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
370:     } else if (mtype == PETSC_MEMTYPE_DEVICE) {
371:       // generic "device" memory should only care if the actual memtype is also generically
372:       // "device"
373:       PetscCheck(PetscMemTypeDevice(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
374:     } else {
375:       PetscCheck(mtype == ptr_mtype, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
376:     }
377:   }
378:   PetscFunctionReturn(PETSC_SUCCESS);
379: }

// Boilerplate for classes deriving from CUPMObject<T>. It imports the solver-interface
// typedefs and re-exposes the protected CUPMObject helpers through using-declarations,
// presumably because names from a dependent base class are not found by unqualified lookup.
// The last line has no trailing semicolon, so invoke it as `PETSC_CUPMOBJECT_HEADER(T);`.
  #define PETSC_CUPMOBJECT_HEADER(T) \
    PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
    using ::Petsc::device::cupm::impl::CUPMObject<T>::PETSCDEVICERAND; \
    using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandles_; \
    using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandlesFrom_; \
    using ::Petsc::device::cupm::impl::CUPMObject<T>::UseCUPMHostAlloc; \
    using ::Petsc::device::cupm::impl::CUPMObject<T>::CheckPointerMatchesMemType_

389: } // namespace impl

391: } // namespace cupm

393: } // namespace device

395: } // namespace Petsc

397: #endif // __cplusplus

399: #endif // PETSC_PRIVATE_CUPMOBJECT_HPP