Actual source code: cupmblasinterface.hpp

  1: #ifndef PETSCCUPMBLASINTERFACE_HPP
  2: #define PETSCCUPMBLASINTERFACE_HPP

  4: #if defined(__cplusplus)
  5: #include <petsc/private/cupminterface.hpp>
  6: #include <petsc/private/petscadvancedmacros.h>

  8:   #include <limits> // std::numeric_limits

 10: namespace Petsc
 11: {

 13: namespace device
 14: {

 16: namespace cupm
 17: {

 19: namespace impl
 20: {

 22:   #define PetscCallCUPMBLAS(...) \
 23:     do { \
 24:       const cupmBlasError_t cberr_p_ = __VA_ARGS__; \
 25:       if (PetscUnlikely(cberr_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
 26:         if (((cberr_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \
 27:           SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, \
 28:                   "%s error %d (%s). Reports not initialized or alloc failed; " \
 29:                   "this indicates the GPU may have run out resources", \
 30:                   cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
 31:         } \
 32:         SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
 33:       } \
 34:     } while (0)

 36:   #define PetscCallCUPMBLASAbort(comm, ...) \
 37:     do { \
 38:       const cupmBlasError_t cberr_abort_p_ = __VA_ARGS__; \
 39:       if (PetscUnlikely(cberr_abort_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
 40:         if (((cberr_abort_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_abort_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \
 41:           SETERRABORT(comm, PETSC_ERR_GPU_RESOURCE, \
 42:                       "%s error %d (%s). Reports not initialized or alloc failed; " \
 43:                       "this indicates the GPU may have run out resources", \
 44:                       cupmBlasName(), static_cast<PetscErrorCode>(cberr_abort_p_), cupmBlasGetErrorName(cberr_abort_p_)); \
 45:         } \
 46:         SETERRABORT(comm, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cberr_abort_p_), cupmBlasGetErrorName(cberr_abort_p_)); \
 47:       } \
 48:     } while (0)

 50:   // given cupmBlas<T>axpy() then
 51:   // T = PETSC_CUPBLAS_FP_TYPE
 52:   // given cupmBlas<T><u>nrm2() then
 53:   // T = PETSC_CUPMBLAS_FP_INPUT_TYPE
 54:   // u = PETSC_CUPMBLAS_FP_RETURN_TYPE
 55:   #if PetscDefined(USE_COMPLEX)
 56:     #if PetscDefined(USE_REAL_SINGLE)
 57:       #define PETSC_CUPMBLAS_FP_TYPE_U       C
 58:       #define PETSC_CUPMBLAS_FP_TYPE_L       c
 59:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U S
 60:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L s
 61:     #elif PetscDefined(USE_REAL_DOUBLE)
 62:       #define PETSC_CUPMBLAS_FP_TYPE_U       Z
 63:       #define PETSC_CUPMBLAS_FP_TYPE_L       z
 64:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U D
 65:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L d
 66:     #endif
 67:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
 68:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
 69:   #else
 70:     #if PetscDefined(USE_REAL_SINGLE)
 71:       #define PETSC_CUPMBLAS_FP_TYPE_U S
 72:       #define PETSC_CUPMBLAS_FP_TYPE_L s
 73:     #elif PetscDefined(USE_REAL_DOUBLE)
 74:       #define PETSC_CUPMBLAS_FP_TYPE_U D
 75:       #define PETSC_CUPMBLAS_FP_TYPE_L d
 76:     #endif
 77:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
 78:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
 79:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U
 80:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L
 81:   #endif // USE_COMPLEX

 83:   #if !defined(PETSC_CUPMBLAS_FP_TYPE_U) && !PetscDefined(USE_REAL___FLOAT128)
 84:     #error "Unsupported floating-point type for CUDA/HIP BLAS"
 85:   #endif

 87:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED() - Helper macro to build a "modified"
 88:   // blas function whose return type does not match the input type
 89:   //
 90:   // input param:
 91:   // func - base suffix of the blas function, e.g. nrm2
 92:   //
 93:   // notes:
 94:   // requires PETSC_CUPMBLAS_FP_INPUT_TYPE to be defined as the blas floating point input type
 95:   // letter ("S" for real/complex single, "D" for real/complex double).
 96:   //
 97:   // requires PETSC_CUPMBLAS_FP_RETURN_TYPE to be defined as the blas floating point output type
 98:   // letter ("c" for complex single, "z" for complex double and <absolutely nothing> for real
 99:   // single/double).
100:   //
101:   // In their infinite wisdom nvidia/amd have made the upper-case vs lower-case scheme
102:   // infuriatingly inconsistent...
103:   //
104:   // example usage:
105:   // #define PETSC_CUPMBLAS_FP_INPUT_TYPE  S
106:   // #define PETSC_CUPMBLAS_FP_RETURN_TYPE
107:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Snrm2
108:   //
109:   // #define PETSC_CUPMBLAS_FP_INPUT_TYPE  D
110:   // #define PETSC_CUPMBLAS_FP_RETURN_TYPE z
111:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Dznrm2
112:   #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(func) PetscConcat(PetscConcat(PETSC_CUPMBLAS_FP_INPUT_TYPE, PETSC_CUPMBLAS_FP_RETURN_TYPE), func)

114:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE() - Helper macro to build Iamax and Iamin
115:   // because they are both extra special
116:   //
117:   // input param:
118:   // func - base suffix of the blas function, either amax or amin
119:   //
120:   // notes:
121:   // The macro name literally stands for "I" ## "floating point type" because shockingly enough,
122:   // that's what it does.
123:   //
124:   // requires PETSC_CUPMBLAS_FP_TYPE_L to be defined as the lower-case blas floating point input type
125:   // letter ("s" for complex single, "z" for complex double, "s" for real single, and "d" for
126:   // real double).
127:   //
128:   // example usage:
129:   // #define PETSC_CUPMBLAS_FP_TYPE_L s
130:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Isamax
131:   //
132:   // #define PETSC_CUPMBLAS_FP_TYPE_L z
133:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amin) -> Izamin
134:   #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(func) PetscConcat(I, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_L, func))

136:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD() - Helper macro to build a "standard"
137:   // blas function name
138:   //
139:   // input param:
140:   // func - base suffix of the blas function, e.g. axpy, scal
141:   //
142:   // notes:
143:   // requires PETSC_CUPMBLAS_FP_TYPE to be defined as the blas floating-point letter ("C" for
144:   // complex single, "Z" for complex double, "S" for real single, "D" for real double).
145:   //
146:   // example usage:
147:   // #define PETSC_CUPMBLAS_FP_TYPE S
148:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Saxpy
149:   //
150:   // #define PETSC_CUPMBLAS_FP_TYPE Z
151:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Zaxpy
152:   #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(func) PetscConcat(PETSC_CUPMBLAS_FP_TYPE, func)

154:   // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT() - In case CUDA/HIP don't agree with our suffix
155:   // one can provide both here
156:   //
157:   // input params:
158:   // MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
159:   // IFPTYPE
160:   // our_suffix   - the suffix of the alias function
161:   // their_suffix - the suffix of the function being aliased
162:   //
163:   // notes:
164:   // requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas function
165:   // prefix. requires any other specific definitions required by the specific builder macro to
166:   // also be defined. See PETSC_CUPM_ALIAS_FUNCTION_EXACT() for the exact expansion of the
167:   // function alias.
168:   //
169:   // example usage:
170:   // #define PETSC_CUPMBLAS_PREFIX  cublas
171:   // #define PETSC_CUPMBLAS_FP_TYPE C
172:   // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD,dot,dotc) ->
173:   // template <typename... T>
174:   // static constexpr auto cupmBlasXdot(T&&... args) *noexcept and returntype detection*
175:   // {
176:   //   return cublasCdotc(std::forward<T>(args)...);
177:   // }
178:   #define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, our_suffix, their_suffix) \
179:     PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlasX, our_suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, PetscConcat(PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_, MACRO_SUFFIX)(their_suffix)))

181:   // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION() - Alias a CUDA/HIP blas function
182:   //
183:   // input params:
184:   // MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
185:   // IFPTYPE
186:   // suffix       - the common suffix between CUDA and HIP of the alias function
187:   //
188:   // notes:
189:   // see PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(), this macro just calls that one with "suffix" as
190:   // "our_prefix" and "their_prefix"
191:   #define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MACRO_SUFFIX, suffix) PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, suffix, suffix)

193:   // PETSC_CUPMBLAS_ALIAS_FUNCTION() - Alias a CUDA/HIP library function
194:   //
195:   // input params:
196:   // suffix - the common suffix between CUDA and HIP of the alias function
197:   //
198:   // notes:
199:   // requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas library
200:   // prefix. see PETSC_CUPMM_ALIAS_FUNCTION_EXACT() for the precise expansion of this macro.
201:   //
202:   // example usage:
203:   // #define PETSC_CUPMBLAS_PREFIX hipblas
204:   // PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) ->
205:   // template <typename... T>
206:   // static constexpr auto cupmBlasCreate(T&&... args) *noexcept and returntype detection*
207:   // {
208:   //   return hipblasCreate(std::forward<T>(args)...);
209:   // }
210:   #define PETSC_CUPMBLAS_ALIAS_FUNCTION(suffix) PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlas, suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, suffix))

212: template <DeviceType>
213: struct BlasInterfaceImpl;

// cupmBlasHandleWrapper - Distinctly-typed wrapper around a raw blas library handle
//
// Exists because HIP (for whatever godforsaken reason) has elected to define both their
// hipblasHandle_t and hipsolverHandle_t as void *. So we cannot disambiguate them for overload
// resolution and hence need to wrap their types in this mess.
//
// template params:
// T - the raw handle type being wrapped
// I - integer tag, wrappers with different tags are different types even for identical T
//
// notes:
// Converts implicitly from and to T, so a wrapper may be used wherever the raw handle is
// expected. The constructor and the nullptr assignment already promise not to throw, the
// accessors make the same promise so that the whole class is uniformly noexcept.
template <typename T, std::size_t I>
class cupmBlasHandleWrapper {
public:
  constexpr cupmBlasHandleWrapper() noexcept = default;
  // deliberately not explicit, see notes above
  constexpr cupmBlasHandleWrapper(T h) noexcept : handle_(std::move(h))
  {
    // lives in the constructor because the class type must be complete for the check
    static_assert(std::is_standard_layout<cupmBlasHandleWrapper<T, I>>::value, "");
  }

  // forgets the wrapped handle, nothing is destroyed or released by this
  cupmBlasHandleWrapper &operator=(std::nullptr_t) noexcept
  {
    handle_ = nullptr;
    return *this;
  }

  // hands out a copy of the raw handle
  operator T() const noexcept { return handle_; }

  // address of the raw handle, for routines which write the handle through a pointer
  const T *ptr_to() const noexcept { return &handle_; }
  T       *ptr_to() noexcept { return &handle_; }

private:
  T handle_{};
};

239:   #if PetscDefined(HAVE_CUDA)
240:     #define PETSC_CUPMBLAS_PREFIX         cublas
241:     #define PETSC_CUPMBLAS_PREFIX_U       CUBLAS
242:     #define PETSC_CUPMBLAS_FP_TYPE        PETSC_CUPMBLAS_FP_TYPE_U
243:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE  PETSC_CUPMBLAS_FP_INPUT_TYPE_U
244:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
245: template <>
246: struct BlasInterfaceImpl<DeviceType::CUDA> : Interface<DeviceType::CUDA> {
247:   // typedefs
248:   using cupmBlasHandle_t      = cupmBlasHandleWrapper<cublasHandle_t, 0>;
249:   using cupmBlasError_t       = cublasStatus_t;
250:   using cupmBlasInt_t         = int;
251:   using cupmBlasPointerMode_t = cublasPointerMode_t;

253:   // values
254:   static const auto CUPMBLAS_STATUS_SUCCESS         = CUBLAS_STATUS_SUCCESS;
255:   static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = CUBLAS_STATUS_NOT_INITIALIZED;
256:   static const auto CUPMBLAS_STATUS_ALLOC_FAILED    = CUBLAS_STATUS_ALLOC_FAILED;
257:   static const auto CUPMBLAS_POINTER_MODE_HOST      = CUBLAS_POINTER_MODE_HOST;
258:   static const auto CUPMBLAS_POINTER_MODE_DEVICE    = CUBLAS_POINTER_MODE_DEVICE;
259:   static const auto CUPMBLAS_OP_T                   = CUBLAS_OP_T;
260:   static const auto CUPMBLAS_OP_N                   = CUBLAS_OP_N;
261:   static const auto CUPMBLAS_OP_C                   = CUBLAS_OP_C;
262:   static const auto CUPMBLAS_FILL_MODE_LOWER        = CUBLAS_FILL_MODE_LOWER;
263:   static const auto CUPMBLAS_FILL_MODE_UPPER        = CUBLAS_FILL_MODE_UPPER;
264:   static const auto CUPMBLAS_SIDE_LEFT              = CUBLAS_SIDE_LEFT;
265:   static const auto CUPMBLAS_DIAG_NON_UNIT          = CUBLAS_DIAG_NON_UNIT;

267:   // utility functions
268:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
269:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
270:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
271:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
272:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
273:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)

275:   // level 1 BLAS
276:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
277:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
278:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
279:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
280:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
281:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
282:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
283:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)

285:   // level 2 BLAS
286:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)

288:   // level 3 BLAS
289:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
290:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsm)

292:   // BLAS extensions
293:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)

295:   PETSC_NODISCARD static const char *cupmBlasGetErrorName(cupmBlasError_t status) noexcept { return PetscCUBLASGetErrorName(status); }
296: };
297:     #undef PETSC_CUPMBLAS_PREFIX
298:     #undef PETSC_CUPMBLAS_PREFIX_U
299:     #undef PETSC_CUPMBLAS_FP_TYPE
300:     #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
301:     #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
302:   #endif // PetscDefined(HAVE_CUDA)

304:   #if PetscDefined(HAVE_HIP)
305:     #define PETSC_CUPMBLAS_PREFIX         hipblas
306:     #define PETSC_CUPMBLAS_PREFIX_U       HIPBLAS
307:     #define PETSC_CUPMBLAS_FP_TYPE        PETSC_CUPMBLAS_FP_TYPE_U
308:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE  PETSC_CUPMBLAS_FP_INPUT_TYPE_U
309:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
310: template <>
311: struct BlasInterfaceImpl<DeviceType::HIP> : Interface<DeviceType::HIP> {
312:   // typedefs
313:   using cupmBlasHandle_t      = cupmBlasHandleWrapper<hipblasHandle_t, 0>;
314:   using cupmBlasError_t       = hipblasStatus_t;
315:   using cupmBlasInt_t         = int; // rocblas will have its own
316:   using cupmBlasPointerMode_t = hipblasPointerMode_t;

318:   // values
319:   static const auto CUPMBLAS_STATUS_SUCCESS         = HIPBLAS_STATUS_SUCCESS;
320:   static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = HIPBLAS_STATUS_NOT_INITIALIZED;
321:   static const auto CUPMBLAS_STATUS_ALLOC_FAILED    = HIPBLAS_STATUS_ALLOC_FAILED;
322:   static const auto CUPMBLAS_POINTER_MODE_HOST      = HIPBLAS_POINTER_MODE_HOST;
323:   static const auto CUPMBLAS_POINTER_MODE_DEVICE    = HIPBLAS_POINTER_MODE_DEVICE;
324:   static const auto CUPMBLAS_OP_T                   = HIPBLAS_OP_T;
325:   static const auto CUPMBLAS_OP_N                   = HIPBLAS_OP_N;
326:   static const auto CUPMBLAS_OP_C                   = HIPBLAS_OP_C;
327:   static const auto CUPMBLAS_FILL_MODE_LOWER        = HIPBLAS_FILL_MODE_LOWER;
328:   static const auto CUPMBLAS_FILL_MODE_UPPER        = HIPBLAS_FILL_MODE_UPPER;
329:   static const auto CUPMBLAS_SIDE_LEFT              = HIPBLAS_SIDE_LEFT;
330:   static const auto CUPMBLAS_DIAG_NON_UNIT          = HIPBLAS_DIAG_NON_UNIT;

332:   // utility functions
333:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
334:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
335:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
336:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
337:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
338:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)

340:   // level 1 BLAS
341:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
342:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
343:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
344:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
345:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
346:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
347:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
348:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)

350:   // level 2 BLAS
351:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)

353:   // level 3 BLAS
354:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
355:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsm)

357:   // BLAS extensions
358:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)

360:   PETSC_NODISCARD static const char *cupmBlasGetErrorName(cupmBlasError_t status) noexcept { return PetscHIPBLASGetErrorName(status); }
361: };
362:     #undef PETSC_CUPMBLAS_PREFIX
363:     #undef PETSC_CUPMBLAS_PREFIX_U
364:     #undef PETSC_CUPMBLAS_FP_TYPE
365:     #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
366:     #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
367:   #endif // PetscDefined(HAVE_HIP)

369:   #define PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T) \
370:     PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
371:     /* introspection */ \
372:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetErrorName; \
373:     /* types */ \
374:     using cupmBlasHandle_t      = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasHandle_t; \
375:     using cupmBlasError_t       = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasError_t; \
376:     using cupmBlasInt_t         = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasInt_t; \
377:     using cupmBlasPointerMode_t = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasPointerMode_t; \
378:     /* values */ \
379:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_SUCCESS; \
380:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_NOT_INITIALIZED; \
381:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_ALLOC_FAILED; \
382:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_POINTER_MODE_HOST; \
383:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_POINTER_MODE_DEVICE; \
384:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_T; \
385:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_N; \
386:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_C; \
387:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_FILL_MODE_LOWER; \
388:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_FILL_MODE_UPPER; \
389:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_SIDE_LEFT; \
390:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_DIAG_NON_UNIT; \
391:     /* utility functions */ \
392:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasCreate; \
393:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasDestroy; \
394:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetStream; \
395:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasSetStream; \
396:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetPointerMode; \
397:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasSetPointerMode; \
398:     /* level 1 BLAS */ \
399:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXaxpy; \
400:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXscal; \
401:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdot; \
402:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdotu; \
403:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXswap; \
404:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXnrm2; \
405:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXamax; \
406:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXasum; \
407:     /* level 2 BLAS */ \
408:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgemv; \
409:     /* level 3 BLAS */ \
410:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgemm; \
411:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtrsm; \
412:     /* BLAS extensions */ \
413:     using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgeam

415: // The actual interface class
416: template <DeviceType T>
417: struct BlasInterface : BlasInterfaceImpl<T> {
418:   PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T);

420:   PETSC_NODISCARD static constexpr const char *cupmBlasName() noexcept { return T == DeviceType::CUDA ? "cuBLAS" : "hipBLAS"; }

422:   static PetscErrorCode PetscCUPMBlasSetPointerModeFromPointer(cupmBlasHandle_t handle, const void *ptr) noexcept
423:   {
424:     auto mtype = PETSC_MEMTYPE_HOST;

426:     PetscFunctionBegin;
427:     PetscCall(PetscCUPMGetMemType(ptr, &mtype));
428:     PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST));
429:     PetscFunctionReturn(PETSC_SUCCESS);
430:   }

432:   static PetscErrorCode checkCupmBlasIntCast(PetscInt x) noexcept
433:   {
434:     PetscFunctionBegin;
435:     PetscCheck((std::is_same<PetscInt, cupmBlasInt_t>::value) || (x <= std::numeric_limits<cupmBlasInt_t>::max()), PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "%" PetscInt_FMT " is too big for %s, which may be restricted to 32-bit integers", x, cupmBlasName());
436:     PetscCheck(x >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Passing negative integer (%" PetscInt_FMT ") to %s routine", x, cupmBlasName());
437:     PetscFunctionReturn(PETSC_SUCCESS);
438:   }

440:   static PetscErrorCode PetscCUPMBlasIntCast(PetscInt x, cupmBlasInt_t *y) noexcept
441:   {
442:     PetscFunctionBegin;
443:     *y = static_cast<cupmBlasInt_t>(x);
444:     PetscCall(checkCupmBlasIntCast(x));
445:     PetscFunctionReturn(PETSC_SUCCESS);
446:   }
447: };

449:   #define PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(T) \
450:     PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T); \
451:     using ::Petsc::device::cupm::impl::BlasInterface<T>::cupmBlasName; \
452:     using ::Petsc::device::cupm::impl::BlasInterface<T>::PetscCUPMBlasSetPointerModeFromPointer; \
453:     using ::Petsc::device::cupm::impl::BlasInterface<T>::checkCupmBlasIntCast; \
454:     using ::Petsc::device::cupm::impl::BlasInterface<T>::PetscCUPMBlasIntCast

456:   #if PetscDefined(HAVE_CUDA)
457: extern template struct BlasInterface<DeviceType::CUDA>;
458:   #endif

460:   #if PetscDefined(HAVE_HIP)
461: extern template struct BlasInterface<DeviceType::HIP>;
462:   #endif

464: } // namespace impl

466: } // namespace cupm

468: } // namespace device

470: } // namespace Petsc

472: #endif // defined(__cplusplus)

474: #endif // PETSCCUPMBLASINTERFACE_HPP