Actual source code: cupmblasinterface.hpp
1: #ifndef PETSCCUPMBLASINTERFACE_HPP
2: #define PETSCCUPMBLASINTERFACE_HPP
4: #if defined(__cplusplus)
5: #include <petsc/private/cupminterface.hpp>
6: #include <petsc/private/petscadvancedmacros.h>
8: #include <limits> // std::numeric_limits
10: namespace Petsc
11: {
13: namespace device
14: {
16: namespace cupm
17: {
19: namespace impl
20: {
// PetscCallCUPMBLAS() - Check the status produced by a cupmBlas call and raise a PETSc error
// if it is not CUPMBLAS_STATUS_SUCCESS
//
// input param:
// ... - the expression (normally a cupmBlasXXX() call) whose value is a cupmBlasError_t
//
// notes:
// The error is raised with SETERRQ() on PETSC_COMM_SELF, so this may only be used inside
// functions from which SETERRQ() may be called.
//
// CUPMBLAS_STATUS_NOT_INITIALIZED and CUPMBLAS_STATUS_ALLOC_FAILED are reported as
// PETSC_ERR_GPU_RESOURCE whenever PetscDeviceInitialized(PETSC_DEVICE_CUPM()) holds, every
// other failure is reported as PETSC_ERR_GPU.
//
// requires cupmBlasError_t, the CUPMBLAS_STATUS_XXX values, cupmBlasName(),
// cupmBlasGetErrorName() and PETSC_DEVICE_CUPM() to be visible where the macro is expanded.
#define PetscCallCUPMBLAS(...) \
  do { \
    const cupmBlasError_t cupmblas_stat_p_ = __VA_ARGS__; \
    if (PetscUnlikely(cupmblas_stat_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
      const bool cupmblas_resource_p_ = ((cupmblas_stat_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cupmblas_stat_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM()); \
      if (cupmblas_resource_p_) { \
        SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, \
                "%s error %d (%s). Reports not initialized or alloc failed; " \
                "this indicates the GPU may have run out resources", \
                cupmBlasName(), static_cast<PetscErrorCode>(cupmblas_stat_p_), cupmBlasGetErrorName(cupmblas_stat_p_)); \
      } \
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cupmblas_stat_p_), cupmBlasGetErrorName(cupmblas_stat_p_)); \
    } \
  } while (0)
// PetscCallCUPMBLASAbort() - Same check as PetscCallCUPMBLAS() but the error is raised with
// SETERRABORT() instead of SETERRQ()
//
// input params:
// comm - the communicator handed to SETERRABORT()
// ...  - the expression (normally a cupmBlasXXX() call) whose value is a cupmBlasError_t
//
// notes:
// CUPMBLAS_STATUS_NOT_INITIALIZED and CUPMBLAS_STATUS_ALLOC_FAILED are reported as
// PETSC_ERR_GPU_RESOURCE whenever PetscDeviceInitialized(PETSC_DEVICE_CUPM()) holds, every
// other failure is reported as PETSC_ERR_GPU.
//
// requires cupmBlasError_t, the CUPMBLAS_STATUS_XXX values, cupmBlasName(),
// cupmBlasGetErrorName() and PETSC_DEVICE_CUPM() to be visible where the macro is expanded.
#define PetscCallCUPMBLASAbort(comm, ...) \
  do { \
    const cupmBlasError_t cupmblas_stat_abort_p_ = __VA_ARGS__; \
    if (PetscUnlikely(cupmblas_stat_abort_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
      const bool cupmblas_resource_abort_p_ = ((cupmblas_stat_abort_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cupmblas_stat_abort_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM()); \
      if (cupmblas_resource_abort_p_) { \
        SETERRABORT(comm, PETSC_ERR_GPU_RESOURCE, \
                    "%s error %d (%s). Reports not initialized or alloc failed; " \
                    "this indicates the GPU may have run out resources", \
                    cupmBlasName(), static_cast<PetscErrorCode>(cupmblas_stat_abort_p_), cupmBlasGetErrorName(cupmblas_stat_abort_p_)); \
      } \
      SETERRABORT(comm, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cupmblas_stat_abort_p_), cupmBlasGetErrorName(cupmblas_stat_abort_p_)); \
    } \
  } while (0)
50: // given cupmBlas<T>axpy() then
51: // T = PETSC_CUPBLAS_FP_TYPE
52: // given cupmBlas<T><u>nrm2() then
53: // T = PETSC_CUPMBLAS_FP_INPUT_TYPE
54: // u = PETSC_CUPMBLAS_FP_RETURN_TYPE
55: #if PetscDefined(USE_COMPLEX)
56: #if PetscDefined(USE_REAL_SINGLE)
57: #define PETSC_CUPMBLAS_FP_TYPE_U C
58: #define PETSC_CUPMBLAS_FP_TYPE_L c
59: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U S
60: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L s
61: #elif PetscDefined(USE_REAL_DOUBLE)
62: #define PETSC_CUPMBLAS_FP_TYPE_U Z
63: #define PETSC_CUPMBLAS_FP_TYPE_L z
64: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U D
65: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L d
66: #endif
67: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
68: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
69: #else
70: #if PetscDefined(USE_REAL_SINGLE)
71: #define PETSC_CUPMBLAS_FP_TYPE_U S
72: #define PETSC_CUPMBLAS_FP_TYPE_L s
73: #elif PetscDefined(USE_REAL_DOUBLE)
74: #define PETSC_CUPMBLAS_FP_TYPE_U D
75: #define PETSC_CUPMBLAS_FP_TYPE_L d
76: #endif
77: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
78: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
79: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U
80: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L
81: #endif // USE_COMPLEX
83: #if !defined(PETSC_CUPMBLAS_FP_TYPE_U) && !PetscDefined(USE_REAL___FLOAT128)
84: #error "Unsupported floating-point type for CUDA/HIP BLAS"
85: #endif
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED() - Build the name of a "modified" blas
// function, i.e. one whose return type differs from its input type
//
// input param:
// func - base suffix of the blas function, e.g. nrm2
//
// notes:
// The result is the concatenation <input letter><return letter><func>.
//
// requires PETSC_CUPMBLAS_FP_INPUT_TYPE to be defined as the blas floating point input type
// letter ("S" for real/complex single, "D" for real/complex double).
//
// requires PETSC_CUPMBLAS_FP_RETURN_TYPE to be defined as the blas floating point output type
// letter ("c" for complex single, "z" for complex double and <absolutely nothing> for real
// single/double).
//
// Note the mixed case: the input letter is upper-case while the return letter is lower-case.
// In their infinite wisdom nvidia/amd have made the upper-case vs lower-case scheme
// infuriatingly inconsistent...
//
// example usage:
// #define PETSC_CUPMBLAS_FP_INPUT_TYPE S
// #define PETSC_CUPMBLAS_FP_RETURN_TYPE
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Snrm2
//
// #define PETSC_CUPMBLAS_FP_INPUT_TYPE D
// #define PETSC_CUPMBLAS_FP_RETURN_TYPE z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Dznrm2
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(func) \
  PetscConcat(PetscConcat(PETSC_CUPMBLAS_FP_INPUT_TYPE, PETSC_CUPMBLAS_FP_RETURN_TYPE), func)
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE() - Build the name of Iamax and Iamin,
// which need their own builder because they are both extra special
//
// input param:
// func - base suffix of the blas function, either amax or amin
//
// notes:
// The macro name literally stands for "I" ## "floating point type", and the result is the
// concatenation I<type letter><func>.
//
// requires PETSC_CUPMBLAS_FP_TYPE_L to be defined as the lower-case blas floating point type
// letter ("c" for complex single, "z" for complex double, "s" for real single, and "d" for
// real double).
//
// example usage:
// #define PETSC_CUPMBLAS_FP_TYPE_L s
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Isamax
//
// #define PETSC_CUPMBLAS_FP_TYPE_L z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amin) -> Izamin
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(func) \
  PetscConcat(I, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_L, func))
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD() - Build the name of a "standard" blas
// function, i.e. <type letter><func>
//
// input param:
// func - base suffix of the blas function, e.g. axpy, scal
//
// notes:
// requires PETSC_CUPMBLAS_FP_TYPE to be defined as the blas floating-point letter ("C" for
// complex single, "Z" for complex double, "S" for real single, "D" for real double).
//
// example usage:
// #define PETSC_CUPMBLAS_FP_TYPE S
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Saxpy
//
// #define PETSC_CUPMBLAS_FP_TYPE Z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Zaxpy
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(func) \
  PetscConcat(PETSC_CUPMBLAS_FP_TYPE, func)
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT() - Alias a CUDA/HIP blas function whose suffix
// differs from the suffix we want to expose
//
// input params:
// MACRO_SUFFIX - suffix of one of the blas function name builder macros above, i.e.
//                STANDARD, MODIFIED or IFPTYPE
// our_suffix   - the suffix of the alias function, the alias is named cupmBlasX<our_suffix>
// their_suffix - the suffix of the function being aliased
//
// notes:
// requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas function
// prefix. requires any other definitions needed by the selected builder macro to also be
// defined. The alias itself is generated by PETSC_CUPM_ALIAS_FUNCTION(), see that macro for
// the exact expansion.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX cublas
// #define PETSC_CUPMBLAS_FP_TYPE C
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD,dot,dotc) ->
// template <typename... T>
// static constexpr auto cupmBlasXdot(T&&... args) *noexcept and returntype detection*
// {
//   return cublasCdotc(std::forward<T>(args)...);
// }
#define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, our_suffix, their_suffix) \
  PETSC_CUPM_ALIAS_FUNCTION( \
    PetscConcat(cupmBlasX, our_suffix), \
    PetscConcat(PETSC_CUPMBLAS_PREFIX, PetscConcat(PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_, MACRO_SUFFIX)(their_suffix)))
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION() - Alias a CUDA/HIP blas function whose suffix is the
// one we want to expose
//
// input params:
// MACRO_SUFFIX - suffix of one of the blas function name builder macros above, i.e.
//                STANDARD, MODIFIED or IFPTYPE
// suffix       - the common suffix between CUDA and HIP of the alias function
//
// notes:
// see PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(), this macro just calls that one with
// "suffix" as both "our_suffix" and "their_suffix"
#define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MACRO_SUFFIX, suffix) \
  PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, suffix, suffix)
// PETSC_CUPMBLAS_ALIAS_FUNCTION() - Alias a CUDA/HIP blas library function which carries no
// floating-point type letter, e.g. the handle management routines
//
// input params:
// suffix - the common suffix between CUDA and HIP of the alias function, the alias is named
//          cupmBlas<suffix>
//
// notes:
// requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas library
// prefix. The alias itself is generated by PETSC_CUPM_ALIAS_FUNCTION(), see that macro for
// the precise expansion.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX hipblas
// PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) ->
// template <typename... T>
// static constexpr auto cupmBlasCreate(T&&... args) *noexcept and returntype detection*
// {
//   return hipblasCreate(std::forward<T>(args)...);
// }
#define PETSC_CUPMBLAS_ALIAS_FUNCTION(suffix) \
  PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlas, suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, suffix))
212: template <DeviceType>
213: struct BlasInterfaceImpl;
// cupmBlasHandleWrapper - Thin wrapper that gives a raw library handle its own C++ type
//
// template params:
// T - the raw handle type being wrapped
// I - an index whose only job is to make two wrappers around the same T distinct types
//
// notes:
// Exists because HIP (for whatever godforsaken reason) has elected to define both their
// hipblasHandle_t and hipsolverHandle_t as void *. So we cannot disambiguate them for overload
// resolution and hence need to wrap their types in this mess.
//
// The wrapper converts implicitly from and to T, so it can be handed directly to routines
// expecting the raw handle. Routines that fill in a handle through a pointer can be given
// ptr_to().
template <typename T, std::size_t I>
class cupmBlasHandleWrapper {
  T handle_{};

public:
  // wraps a value-initialized handle
  constexpr cupmBlasHandleWrapper() noexcept = default;

  // deliberately not explicit so that a raw handle converts to its wrapper
  constexpr cupmBlasHandleWrapper(T raw) noexcept : handle_(std::move(raw)) { static_assert(std::is_standard_layout<cupmBlasHandleWrapper>::value, ""); }

  // resets the wrapped handle to null
  cupmBlasHandleWrapper &operator=(std::nullptr_t) noexcept
  {
    handle_ = nullptr;
    return *this;
  }

  // yields (a copy of) the wrapped handle
  operator T() const { return handle_; }

  // address of the wrapped handle
  T       *ptr_to() { return &handle_; }
  const T *ptr_to() const { return &handle_; }
};
239: #if PetscDefined(HAVE_CUDA)
240: #define PETSC_CUPMBLAS_PREFIX cublas
241: #define PETSC_CUPMBLAS_PREFIX_U CUBLAS
242: #define PETSC_CUPMBLAS_FP_TYPE PETSC_CUPMBLAS_FP_TYPE_U
243: #define PETSC_CUPMBLAS_FP_INPUT_TYPE PETSC_CUPMBLAS_FP_INPUT_TYPE_U
244: #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
245: template <>
246: struct BlasInterfaceImpl<DeviceType::CUDA> : Interface<DeviceType::CUDA> {
247: // typedefs
248: using cupmBlasHandle_t = cupmBlasHandleWrapper<cublasHandle_t, 0>;
249: using cupmBlasError_t = cublasStatus_t;
250: using cupmBlasInt_t = int;
251: using cupmBlasPointerMode_t = cublasPointerMode_t;
253: // values
254: static const auto CUPMBLAS_STATUS_SUCCESS = CUBLAS_STATUS_SUCCESS;
255: static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = CUBLAS_STATUS_NOT_INITIALIZED;
256: static const auto CUPMBLAS_STATUS_ALLOC_FAILED = CUBLAS_STATUS_ALLOC_FAILED;
257: static const auto CUPMBLAS_POINTER_MODE_HOST = CUBLAS_POINTER_MODE_HOST;
258: static const auto CUPMBLAS_POINTER_MODE_DEVICE = CUBLAS_POINTER_MODE_DEVICE;
259: static const auto CUPMBLAS_OP_T = CUBLAS_OP_T;
260: static const auto CUPMBLAS_OP_N = CUBLAS_OP_N;
261: static const auto CUPMBLAS_OP_C = CUBLAS_OP_C;
262: static const auto CUPMBLAS_FILL_MODE_LOWER = CUBLAS_FILL_MODE_LOWER;
263: static const auto CUPMBLAS_FILL_MODE_UPPER = CUBLAS_FILL_MODE_UPPER;
264: static const auto CUPMBLAS_SIDE_LEFT = CUBLAS_SIDE_LEFT;
265: static const auto CUPMBLAS_DIAG_NON_UNIT = CUBLAS_DIAG_NON_UNIT;
267: // utility functions
268: PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
269: PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
270: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
271: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
272: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
273: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)
275: // level 1 BLAS
276: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
277: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
278: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
279: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
280: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
281: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
282: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
283: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)
285: // level 2 BLAS
286: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
288: // level 3 BLAS
289: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
290: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsm)
292: // BLAS extensions
293: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
295: PETSC_NODISCARD static const char *cupmBlasGetErrorName(cupmBlasError_t status) noexcept { return PetscCUBLASGetErrorName(status); }
296: };
297: #undef PETSC_CUPMBLAS_PREFIX
298: #undef PETSC_CUPMBLAS_PREFIX_U
299: #undef PETSC_CUPMBLAS_FP_TYPE
300: #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
301: #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
302: #endif // PetscDefined(HAVE_CUDA)
304: #if PetscDefined(HAVE_HIP)
305: #define PETSC_CUPMBLAS_PREFIX hipblas
306: #define PETSC_CUPMBLAS_PREFIX_U HIPBLAS
307: #define PETSC_CUPMBLAS_FP_TYPE PETSC_CUPMBLAS_FP_TYPE_U
308: #define PETSC_CUPMBLAS_FP_INPUT_TYPE PETSC_CUPMBLAS_FP_INPUT_TYPE_U
309: #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
310: template <>
311: struct BlasInterfaceImpl<DeviceType::HIP> : Interface<DeviceType::HIP> {
312: // typedefs
313: using cupmBlasHandle_t = cupmBlasHandleWrapper<hipblasHandle_t, 0>;
314: using cupmBlasError_t = hipblasStatus_t;
315: using cupmBlasInt_t = int; // rocblas will have its own
316: using cupmBlasPointerMode_t = hipblasPointerMode_t;
318: // values
319: static const auto CUPMBLAS_STATUS_SUCCESS = HIPBLAS_STATUS_SUCCESS;
320: static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = HIPBLAS_STATUS_NOT_INITIALIZED;
321: static const auto CUPMBLAS_STATUS_ALLOC_FAILED = HIPBLAS_STATUS_ALLOC_FAILED;
322: static const auto CUPMBLAS_POINTER_MODE_HOST = HIPBLAS_POINTER_MODE_HOST;
323: static const auto CUPMBLAS_POINTER_MODE_DEVICE = HIPBLAS_POINTER_MODE_DEVICE;
324: static const auto CUPMBLAS_OP_T = HIPBLAS_OP_T;
325: static const auto CUPMBLAS_OP_N = HIPBLAS_OP_N;
326: static const auto CUPMBLAS_OP_C = HIPBLAS_OP_C;
327: static const auto CUPMBLAS_FILL_MODE_LOWER = HIPBLAS_FILL_MODE_LOWER;
328: static const auto CUPMBLAS_FILL_MODE_UPPER = HIPBLAS_FILL_MODE_UPPER;
329: static const auto CUPMBLAS_SIDE_LEFT = HIPBLAS_SIDE_LEFT;
330: static const auto CUPMBLAS_DIAG_NON_UNIT = HIPBLAS_DIAG_NON_UNIT;
332: // utility functions
333: PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
334: PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
335: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
336: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
337: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
338: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)
340: // level 1 BLAS
341: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
342: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
343: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
344: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
345: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
346: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
347: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
348: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)
350: // level 2 BLAS
351: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
353: // level 3 BLAS
354: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
355: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsm)
357: // BLAS extensions
358: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
360: PETSC_NODISCARD static const char *cupmBlasGetErrorName(cupmBlasError_t status) noexcept { return PetscHIPBLASGetErrorName(status); }
361: };
362: #undef PETSC_CUPMBLAS_PREFIX
363: #undef PETSC_CUPMBLAS_PREFIX_U
364: #undef PETSC_CUPMBLAS_FP_TYPE
365: #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
366: #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
367: #endif // PetscDefined(HAVE_HIP)
// PETSC_CUPMBLAS_IMPL_CLASS_HEADER() - Pull the names of BlasInterfaceImpl<T> into the class
// in which the macro is expanded
//
// input param:
// T - the DeviceType selecting the BlasInterfaceImpl specialization
//
// notes:
// Also expands PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T). The final using-declaration
// deliberately carries no semicolon, the caller is expected to supply it.
#define PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T) \
  PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
  /* types */ \
  using cupmBlasHandle_t      = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasHandle_t; \
  using cupmBlasError_t       = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasError_t; \
  using cupmBlasInt_t         = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasInt_t; \
  using cupmBlasPointerMode_t = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasPointerMode_t; \
  /* status values */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_SUCCESS; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_NOT_INITIALIZED; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_ALLOC_FAILED; \
  /* pointer mode values */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_POINTER_MODE_HOST; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_POINTER_MODE_DEVICE; \
  /* operation, fill mode, side and diagonal values */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_T; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_N; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_C; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_FILL_MODE_LOWER; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_FILL_MODE_UPPER; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_SIDE_LEFT; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_DIAG_NON_UNIT; \
  /* introspection and utility functions */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetErrorName; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasCreate; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasDestroy; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetStream; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasSetStream; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetPointerMode; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasSetPointerMode; \
  /* level 1 BLAS */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXaxpy; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXscal; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdot; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdotu; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXswap; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXnrm2; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXamax; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXasum; \
  /* level 2 BLAS */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgemv; \
  /* level 3 BLAS */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgemm; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtrsm; \
  /* BLAS extensions */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgeam
415: // The actual interface class
416: template <DeviceType T>
417: struct BlasInterface : BlasInterfaceImpl<T> {
418: PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T);
420: PETSC_NODISCARD static constexpr const char *cupmBlasName() noexcept { return T == DeviceType::CUDA ? "cuBLAS" : "hipBLAS"; }
422: static PetscErrorCode PetscCUPMBlasSetPointerModeFromPointer(cupmBlasHandle_t handle, const void *ptr) noexcept
423: {
424: auto mtype = PETSC_MEMTYPE_HOST;
426: PetscFunctionBegin;
427: PetscCall(PetscCUPMGetMemType(ptr, &mtype));
428: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST));
429: PetscFunctionReturn(PETSC_SUCCESS);
430: }
432: static PetscErrorCode checkCupmBlasIntCast(PetscInt x) noexcept
433: {
434: PetscFunctionBegin;
435: PetscCheck((std::is_same<PetscInt, cupmBlasInt_t>::value) || (x <= std::numeric_limits<cupmBlasInt_t>::max()), PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "%" PetscInt_FMT " is too big for %s, which may be restricted to 32-bit integers", x, cupmBlasName());
436: PetscCheck(x >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Passing negative integer (%" PetscInt_FMT ") to %s routine", x, cupmBlasName());
437: PetscFunctionReturn(PETSC_SUCCESS);
438: }
440: static PetscErrorCode PetscCUPMBlasIntCast(PetscInt x, cupmBlasInt_t *y) noexcept
441: {
442: PetscFunctionBegin;
443: *y = static_cast<cupmBlasInt_t>(x);
444: PetscCall(checkCupmBlasIntCast(x));
445: PetscFunctionReturn(PETSC_SUCCESS);
446: }
447: };
// PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING() - Pull the complete BLAS interface into
// the class in which the macro is expanded
//
// input param:
// T - the DeviceType selecting the BlasInterface instantiation
//
// notes:
// Expands PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T) and additionally imports the helpers defined by
// BlasInterface<T> itself. The final using-declaration deliberately carries no semicolon,
// the caller is expected to supply it.
#define PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(T) \
  PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T); \
  /* helpers added by BlasInterface<T> */ \
  using ::Petsc::device::cupm::impl::BlasInterface<T>::cupmBlasName; \
  using ::Petsc::device::cupm::impl::BlasInterface<T>::PetscCUPMBlasSetPointerModeFromPointer; \
  using ::Petsc::device::cupm::impl::BlasInterface<T>::checkCupmBlasIntCast; \
  using ::Petsc::device::cupm::impl::BlasInterface<T>::PetscCUPMBlasIntCast
456: #if PetscDefined(HAVE_CUDA)
457: extern template struct BlasInterface<DeviceType::CUDA>;
458: #endif
460: #if PetscDefined(HAVE_HIP)
461: extern template struct BlasInterface<DeviceType::HIP>;
462: #endif
464: } // namespace impl
466: } // namespace cupm
468: } // namespace device
470: } // namespace Petsc
472: #endif // defined(__cplusplus)
474: #endif // PETSCCUPMBLASINTERFACE_HPP