Actual source code: vecseqcupm.hpp

  1: #ifndef PETSCVECSEQCUPM_HPP
  2: #define PETSCVECSEQCUPM_HPP

  4: #include <petsc/private/veccupmimpl.h>

  6: #if defined(__cplusplus)
  7: #include <petsc/private/randomimpl.h>

  9: #include <petsc/private/cpp/utility.hpp>

 11:   #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
 12:   #include "../src/sys/objects/device/impls/cupm/kernels.hpp"

 14:   #if PetscDefined(USE_COMPLEX)
 15:     #include <thrust/transform_reduce.h>
 16:   #endif
 17:   #include <thrust/transform.h>
 18:   #include <thrust/reduce.h>
 19:   #include <thrust/functional.h>
 20:   #include <thrust/tuple.h>
 21:   #include <thrust/device_ptr.h>
 22:   #include <thrust/iterator/zip_iterator.h>
 23:   #include <thrust/iterator/counting_iterator.h>
 24:   #include <thrust/inner_product.h>

 26: namespace Petsc
 27: {

 29: namespace vec
 30: {

 32: namespace cupm
 33: {

 35: namespace impl
 36: {

 38: // ==========================================================================================
 39: // VecSeq_CUPM
 40: // ==========================================================================================

// VecSeq_CUPM<T>: sequential vector implementation for the CUPM device backend T.
// It inherits privately (no access specifier is given) from Vec_CUPMBase and passes itself as the
// base's second template argument. Presumably this lets the base call back into the *_IMPL_ hooks
// declared below (TODO confirm against Vec_CUPMBase).
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // access to the Vec_Seq implementation struct stored in v->data, and this type's name
  PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;

  // implementation hooks, mostly thin forwards to the host VecSeq routines
  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bumps the object state of a vector that is empty on this rank but not globally
  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers
  static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  // kernel launch helpers: the index_sequence overloads launch one fused kernel for a fixed
  // number of vectors; the <int N> overloads do the same for N vectors and advance the caller's
  // running index (the trailing PetscInt &)
  template <std::size_t... Idx>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers
  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode Reciprocal(Vec) noexcept;
  static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode Set(Vec, PetscScalar) noexcept;
  static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode Copy(Vec, Vec) noexcept;
  static PetscErrorCode Swap(Vec, Vec) noexcept;
  static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Conjugate(Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
};

120: // ==========================================================================================
121: // VecSeq_CUPM - Private API
122: // ==========================================================================================

124: template <device::cupm::DeviceType T>
125: inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
126: {
127:   return static_cast<Vec_Seq *>(v->data);
128: }

// Compile-time name of this implementation's vector type: the sequential CUPM type, VECSEQCUPM().
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}

// Destroy hook: forwards directly to the host sequential implementation, VecDestroy_Seq().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  return VecDestroy_Seq(v);
}

// Reset-array hook: forwards directly to the host sequential implementation, VecResetArray_Seq().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}

// Place-array hook: forwards directly to the host sequential implementation, VecPlaceArray_Seq().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}

154: template <device::cupm::DeviceType T>
155: inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
156: {
157:   PetscMPIInt size;

159:   PetscFunctionBegin;
160:   if (alloc_missing) *alloc_missing = PETSC_FALSE;
161:   PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
162:   PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
163:   PetscCall(VecCreate_Seq_Private(v, host_array));
164:   PetscFunctionReturn(PETSC_SUCCESS);
165: }

167: // for functions with an early return based one vec size we still need to artificially bump the
168: // object state. This is to prevent the following:
169: //
170: // 0. Suppose you have a Vec {
171: //   rank 0: [0],
172: //   rank 1: [<empty>]
173: // }
174: // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
175: // 2. Vec enters e.g. VecSet(10)
176: // 3. rank 1 has local size 0 and bails immediately
177: // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
178: // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
179: // 6. Vec enters VecNorm(), and calls VecNormAvailable()
180: // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
181: // 8. rank 0 has object state = 1, not equal to stash, continues to impl function
182: // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
183: template <device::cupm::DeviceType T>
184: inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
185: {
186:   PetscFunctionBegin;
187:   if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
188:   PetscFunctionReturn(PETSC_SUCCESS);
189: }

// Common core of the create routines. It first sets up the host (VecSeq) part through the base
// class' VecCreate_IMPL_Private(), passing host_array. It then initializes the CUPM part via
// Initialize_CUPMBase(), passing host_array and device_array. Both arrays default to nullptr in
// the declaration, presumably meaning "allocate it yourself" (TODO confirm in Vec_CUPMBase).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

200: template <device::cupm::DeviceType T>
201: template <typename BinaryFuncT>
202: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
203: {
204:   PetscFunctionBegin;
205:   if (const auto n = zout->map->n) {
206:     PetscDeviceContext dctx;
207:     cupmStream_t       stream;

209:     PetscCall(GetHandles_(&dctx, &stream));
210:     // clang-format off
211:     PetscCallThrust(
212:       const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());

214:       THRUST_CALL(
215:         thrust::transform,
216:         stream,
217:         dxptr, dxptr + n,
218:         thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
219:         thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
220:         std::forward<BinaryFuncT>(binary)
221:       )
222:     );
223:     // clang-format on
224:     PetscCall(PetscLogFlops(n));
225:     PetscCall(PetscDeviceContextSynchronize(dctx));
226:   } else {
227:     PetscCall(MaybeIncrementEmptyLocalVec(zout));
228:   }
229:   PetscFunctionReturn(PETSC_SUCCESS);
230: }

232: template <device::cupm::DeviceType T>
233: template <typename UnaryFuncT>
234: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept
235: {
236:   const auto inplace = !yin || (xinout == yin);

238:   PetscFunctionBegin;
239:   if (const auto n = xinout->map->n) {
240:     PetscDeviceContext dctx;
241:     cupmStream_t       stream;
242:     const auto         apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) {
243:       PetscFunctionBegin;
244:       // clang-format off
245:       PetscCallThrust(
246:         const auto xptr = thrust::device_pointer_cast(xinout);

248:         THRUST_CALL(
249:           thrust::transform,
250:           stream,
251:           xptr, xptr + n,
252:           (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr,
253:           std::forward<UnaryFuncT>(unary)
254:         )
255:       );
256:       PetscFunctionReturn(PETSC_SUCCESS);
257:     };

259:     PetscCall(GetHandles_(&dctx, &stream));
260:     if (inplace) {
261:       PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
262:     } else {
263:       PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
264:     }
265:     PetscCall(PetscLogFlops(n));
266:     PetscCall(PetscDeviceContextSynchronize(dctx));
267:   } else {
268:     if (inplace) {
269:       PetscCall(MaybeIncrementEmptyLocalVec(xinout));
270:     } else {
271:       PetscCall(MaybeIncrementEmptyLocalVec(yin));
272:     }
273:   }
274:   PetscFunctionReturn(PETSC_SUCCESS);
275: }

277: // ==========================================================================================
278: // VecSeq_CUPM - Public API - Constructors
279: // ==========================================================================================

// VecCreateSeqCUPM()
//
// Creates a sequential CUPM vector of size n with block size bs on comm. n is passed to
// Create_CUPMBase() twice, presumably as both the local and the global size, since a sequential
// vector owns all of its entries (TODO confirm argument order in Create_CUPMBase). call_set_type is
// forwarded unchanged.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}

290: // VecCreateSeqCUPMWithArrays()
291: template <device::cupm::DeviceType T>
292: inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
293: {
294:   PetscDeviceContext dctx;

296:   PetscFunctionBegin;
297:   PetscCall(GetHandles_(&dctx));
298:   // do NOT call VecSetType(), otherwise ops->create() -> create() ->
299:   // CreateSeqCUPM_() is called!
300:   PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
301:   PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
302:   PetscFunctionReturn(PETSC_SUCCESS);
303: }

// v->ops->duplicate
//
// Creates *y as a duplicate of v by delegating to Duplicate_CUPMBase() with the current device
// context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

317: // ==========================================================================================
318: // VecSeq_CUPM - Public API - Utility
319: // ==========================================================================================

// v->ops->bindtocpu
//
// Binds (usehost = PETSC_TRUE) or unbinds v to/from the host.
// - First runs the shared base-class logic in BindToCPU_CUPMBase().
// - Then switches each op below between its host variant (the VecXXX_Seq routine) and its device
//   variant (the member function) via VecSetOp_CUPM. The macro is defined elsewhere; presumably it
//   reads v and usehost from this scope (TODO confirm).
// - mtdot/mtdot_local are always set to the host VecMTDot_Seq, regardless of usehost.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, Dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(conjugate, VecConjugate_Seq, Conjugate);
  VecSetOp_CUPM(max, VecMax_Seq, Max);
  VecSetOp_CUPM(min, VecMin_Seq, Min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
  PetscFunctionReturn(PETSC_SUCCESS);
}

347: // ==========================================================================================
348: // VecSeq_CUPM - Public API - Mutators
349: // ==========================================================================================

// v->ops->getlocalvector or v->ops->getlocalvectorread
//
// Makes w a local (sequential) view of v's data; access selects read-only or read-write access for
// the fallback path. v must be VECSEQCUPM or VECMPICUPM.
//
// If w is a VECSEQCUPM, its own host array, device array and CUPM struct (w->spptr) are released
// first. Then one of two paths is taken:
//  - v->petscnative is set and w is a VECSEQCUPM: w's data struct is freed, and w directly borrows
//    v's data, offload mask, pinned-memory flag and spptr. w's object state is then increased.
//  - otherwise: w's host array pointer is obtained from v with VecGetArrayRead()/VecGetArray(), and
//    w's offload mask is set to PETSC_OFFLOAD_CPU. If w is a VECSEQCUPM, DeviceAllocateCheck_() is
//    also run on it, presumably to (re)allocate device storage.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // release w's own host storage
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        // Resets w->pinned_memory to PETSC_FALSE and hands its old value to UseCUPMHostAlloc.
        // Presumably this is a scope guard that makes the PetscFree() below use the CUPM host
        // (pinned) deallocator while `useit` is alive (TODO confirm).
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    // release w's own device storage and CUPM struct
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // w borrows v's internals wholesale; RestoreLocalVector() hands them back
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->restorelocalvector or v->ops->restorelocalvectorread
//
// Counterpart of GetLocalVector<access>(v, w). v must be VECSEQCUPM or VECMPICUPM.
//  - v->petscnative is set and w is a VECSEQCUPM: v takes back the data, offload mask,
//    pinned-memory flag and spptr. w's data and spptr are set to nullptr and its offload mask to
//    PETSC_OFFLOAD_UNALLOCATED.
//  - otherwise: the host array obtained from v is returned with VecRestoreArrayRead()/
//    VecRestoreArray(). If w is a VECSEQCUPM with a CUPM struct, its device array and spptr are
//    freed.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    // drop the device storage that GetLocalVector() set up for w
    if (w->spptr && wisseqcupm) {
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

443: // ==========================================================================================
444: // VecSeq_CUPM - Public API - Compute Methods
445: // ==========================================================================================

447: // v->ops->aypx
448: template <device::cupm::DeviceType T>
449: inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) noexcept
450: {
451:   const auto         n    = static_cast<cupmBlasInt_t>(yin->map->n);
452:   const auto         sync = n != 0;
453:   PetscDeviceContext dctx;

455:   PetscFunctionBegin;
456:   PetscCall(GetHandles_(&dctx));
457:   if (alpha == PetscScalar(0.0)) {
458:     cupmStream_t stream;

460:     PetscCall(GetHandlesFrom_(dctx, &stream));
461:     PetscCall(PetscLogGpuTimeBegin());
462:     PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
463:     PetscCall(PetscLogGpuTimeEnd());
464:   } else if (n) {
465:     const auto       alphaIsOne = alpha == PetscScalar(1.0);
466:     const auto       calpha     = cupmScalarPtrCast(&alpha);
467:     cupmBlasHandle_t cupmBlasHandle;

469:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
470:     {
471:       const auto yptr = DeviceArrayReadWrite(dctx, yin);
472:       const auto xptr = DeviceArrayRead(dctx, xin);

474:       PetscCall(PetscLogGpuTimeBegin());
475:       if (alphaIsOne) {
476:         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
477:       } else {
478:         const auto one = cupmScalarCast(1.0);

480:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
481:         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
482:       }
483:       PetscCall(PetscLogGpuTimeEnd());
484:     }
485:     PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
486:   }
487:   if (sync) PetscCall(PetscDeviceContextSynchronize(dctx));
488:   PetscFunctionReturn(PETSC_SUCCESS);
489: }

491: // v->ops->axpy
492: template <device::cupm::DeviceType T>
493: inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept
494: {
495:   PetscBool xiscupm;

497:   PetscFunctionBegin;
498:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
499:   if (xiscupm) {
500:     const auto         n = static_cast<cupmBlasInt_t>(yin->map->n);
501:     PetscDeviceContext dctx;
502:     cupmBlasHandle_t   cupmBlasHandle;

504:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
505:     PetscCall(PetscLogGpuTimeBegin());
506:     PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
507:     PetscCall(PetscLogGpuTimeEnd());
508:     PetscCall(PetscLogGpuFlops(2 * n));
509:     PetscCall(PetscDeviceContextSynchronize(dctx));
510:   } else {
511:     PetscCall(VecAXPY_Seq(yin, alpha, xin));
512:   }
513:   PetscFunctionReturn(PETSC_SUCCESS);
514: }

516: // v->ops->pointwisedivide
517: template <device::cupm::DeviceType T>
518: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec win, Vec xin, Vec yin) noexcept
519: {
520:   PetscFunctionBegin;
521:   if (xin->boundtocpu || yin->boundtocpu) {
522:     PetscCall(VecPointwiseDivide_Seq(win, xin, yin));
523:   } else {
524:     // note order of arguments! xin and yin are read, win is written!
525:     PetscCall(PointwiseBinary_(thrust::divides<PetscScalar>{}, xin, yin, win));
526:   }
527:   PetscFunctionReturn(PETSC_SUCCESS);
528: }

530: // v->ops->pointwisemult
531: template <device::cupm::DeviceType T>
532: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec win, Vec xin, Vec yin) noexcept
533: {
534:   PetscFunctionBegin;
535:   if (xin->boundtocpu || yin->boundtocpu) {
536:     PetscCall(VecPointwiseMult_Seq(win, xin, yin));
537:   } else {
538:     // note order of arguments! xin and yin are read, win is written!
539:     PetscCall(PointwiseBinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win));
540:   }
541:   PetscFunctionReturn(PETSC_SUCCESS);
542: }

544: namespace detail
545: {

547: struct reciprocal {
548:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept
549:   {
550:     // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
551:     // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
552:     // everything in PetscScalar...
553:     return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
554:   }
555: };

557: } // namespace detail

// v->ops->reciprocal
//
// In place: xin[i] = 1 / xin[i], leaving zero entries as zero (see detail::reciprocal).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::reciprocal{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}

568: // v->ops->waxpy
569: template <device::cupm::DeviceType T>
570: inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
571: {
572:   PetscFunctionBegin;
573:   if (alpha == PetscScalar(0.0)) {
574:     PetscCall(Copy(yin, win));
575:   } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
576:     PetscDeviceContext dctx;
577:     cupmBlasHandle_t   cupmBlasHandle;
578:     cupmStream_t       stream;

580:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream));
581:     {
582:       const auto wptr = DeviceArrayWrite(dctx, win);

584:       PetscCall(PetscLogGpuTimeBegin());
585:       PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
586:       PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
587:       PetscCall(PetscLogGpuTimeEnd());
588:     }
589:     PetscCall(PetscLogGpuFlops(2 * n));
590:     PetscCall(PetscDeviceContextSynchronize(dctx));
591:   }
592:   PetscFunctionReturn(PETSC_SUCCESS);
593: }

namespace kernels
{

// MAXPY_kernel: for every i in [0, size), xptr[i] += sum_j aptr[j] * yptr_j[i], fusing the
// N = sizeof...(Args) y-vectors into a single pass over x.
//   size - number of entries in x and in each y
//   xptr - x, read and written in place
//   aptr - the N coefficients (read from device code, so it must be device-accessible)
//   yptr - the N y-vector pointers, passed as a parameter pack
template <typename... Args>
PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  // NOTE(review): only threads with tx < N load a coefficient, so this assumes every block has at
  // least N threads -- confirm against the launch configuration
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    // these may look the same but give different results!
    // (presumably floating-point rounding: the disabled branch sums the products first and then
    // adds them to x[i], while the active branch accumulates starting from x[i])
  #if 0
    PetscScalar sum = 0.0;

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
  #else
    auto sum = xptr[i];

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] = sum;
  #endif
  });
  return;
}

} // namespace kernels

namespace detail
{

// repeat_type<U, I>::type is always U; the std::size_t parameter exists only to be consumed by a
// pack expansion. For an index pack Idx...,
//   typename repeat_type<MyType, Idx>::type...
// therefore expands to MyType, MyType, ..., MyType (sizeof...(Idx) times).
template <typename U, std::size_t /* ignored */>
struct repeat_type {
  using type = U;
};

} // namespace detail

// Launches one MAXPY_kernel over `size` entries of xptr for the sizeof...(Idx) vectors yin[Idx]...,
// with coefficients aptr (read by the kernel, so it must be device-accessible).
// repeat_type turns the index pack into a matching pack of `const PetscScalar *` kernel template
// arguments, and DeviceArrayRead(...).data() supplies the device pointer of each y.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

664: template <device::cupm::DeviceType T>
665: template <int N>
666: inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
667: {
668:   PetscFunctionBegin;
669:   PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
670:   yidx += N;
671:   PetscFunctionReturn(PETSC_SUCCESS);
672: }

// v->ops->maxpy
//
// Computes x += sum_{i < nv} alpha[i] * y[i] on the device.
// - The nv host coefficients are first copied into a temporary device buffer (d_alpha).
// - The y vectors are consumed in chunks: 8 at a time while at least 8 remain, then the remainder
//   (1..7) in one final launch. Each chunk is one fused MAXPY_kernel launch.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  const auto         n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt     yidx    = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncrease()) before we bail
    if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // each dispatch advances yidx by the number of vectors it consumed
      switch (nv - yidx) {
      case 7:
        PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    // NOTE(review): if a dispatch above fails, PetscCall returns early and d_alpha is not freed
    PetscCall(PetscDeviceFree(dctx, d_alpha));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

731: template <device::cupm::DeviceType T>
732: inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
733: {
734:   PetscFunctionBegin;
735:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
736:     PetscDeviceContext dctx;
737:     cupmBlasHandle_t   cupmBlasHandle;

739:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
740:     // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
741:     // second
742:     PetscCall(PetscLogGpuTimeBegin());
743:     PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
744:     PetscCall(PetscLogGpuTimeEnd());
745:     PetscCall(PetscLogGpuFlops(2 * n - 1));
746:   } else {
747:     *z = 0.0;
748:   }
749:   PetscFunctionReturn(PETSC_SUCCESS);
750: }

  // Work-group constants for the MDot kernels. MDot_kernel sizes its shared memory as
  // N * MDOT_WORKGROUP_SIZE and indexes it as threadIdx.x + i * MDOT_WORKGROUP_SIZE. It therefore
  // assumes blockDim.x <= MDOT_WORKGROUP_SIZE. MDOT_WORKGROUP_NUM is presumably the number of blocks
  // launched -- TODO confirm against the launch site.
  #define MDOT_WORKGROUP_NUM  128
  #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM

755: namespace kernels
756: {

758: PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
759: {
760:   const auto group_entries = (size - 1) / gridDim.x + 1;
761:   // for very small vectors, a group should still do some work
762:   return group_entries ? group_entries : 1;
763: }

765: template <typename... ConstPetscScalarPointer>
766: PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
767: {
768:   constexpr int      N        = sizeof...(ConstPetscScalarPointer);
769:   const PetscScalar *ylocal[] = {y...};
770:   PetscScalar        sumlocal[N];

772:   PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

774:   // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
775:   // types, so each of these go on separate lines...
776:   const auto tx       = threadIdx.x;
777:   const auto bx       = blockIdx.x;
778:   const auto bdx      = blockDim.x;
779:   const auto gdx      = gridDim.x;
780:   const auto worksize = EntriesPerGroup(size);
781:   const auto begin    = tx + bx * worksize;
782:   const auto end      = min((bx + 1) * worksize, size);

784:   #pragma unroll
785:   for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

787:   for (auto i = begin; i < end; i += bdx) {
788:     const auto xi = x[i]; // load only once from global memory!

790:   #pragma unroll
791:     for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
792:   }

794:   #pragma unroll
795:   for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

797:   // parallel reduction
798:   for (auto stride = bdx / 2; stride > 0; stride /= 2) {
799:     __syncthreads();
800:     if (tx < stride) {
801:   #pragma unroll
802:       for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
803:     }
804:   }
805:   // bottom N threads per block write to global memory
806:   // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
807:   // writes to the same sections in the above loop that it is about to read from below, but
808:   // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
809:   __syncthreads();
810:   if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
811:   return;
812: }

814: namespace
815: {

817: PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
818: {
819:   int         local_i = 0;
820:   PetscScalar local_results[8];

822:   // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
823:   //
824:   // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
825:   // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
826:   // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
827:   //  |  ______________________________________________________/
828:   //  | /            <- MDOT_WORKGROUP_NUM ->
829:   //  |/
830:   //  +
831:   //  v
832:   // *-*-*
833:   // | | | ...
834:   // *-*-*
835:   //
836:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
837:     PetscScalar z_sum = 0;

839:     for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
840:     local_results[local_i++] = z_sum;
841:   });
842:   // if we needed more than 1 workgroup to handle the vector we should sync since other threads
843:   // may currently be reading from results
844:   if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
845:   // Local buffer is now written to global memory
846:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
847:     const auto j = --local_i;

849:     if (j >= 0) results[i] = local_results[j];
850:   });
851:   return;
852: }

854: } // namespace

856: } // namespace kernels

858: template <device::cupm::DeviceType T>
859: template <std::size_t... Idx>
860: inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
861: {
862:   PetscFunctionBegin;
863:   // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
864:   // 128 blocks of 128 threads every time which may be wasteful
865:   // clang-format off
866:   PetscCallCUPM(
867:     cupmLaunchKernel(
868:       kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
869:       MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
870:       xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
871:     )
872:   );
873:   // clang-format on
874:   PetscFunctionReturn(PETSC_SUCCESS);
875: }

877: template <device::cupm::DeviceType T>
878: template <int N>
879: inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
880: {
881:   PetscFunctionBegin;
882:   PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
883:   yidx += N;
884:   PetscFunctionReturn(PETSC_SUCCESS);
885: }

887: template <device::cupm::DeviceType T>
888: inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
889: {
890:   // the largest possible size of a batch
891:   constexpr PetscInt batchsize = 8;
892:   // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
893:   // do not create substreams. Note we don't create more than 8 streams, in practice we could
894:   // not get more parallelism with higher numbers.
895:   const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
896:   const auto n               = xin->map->n;
897:   // number of vectors that we handle via the batches. note any singletons are handled by
898:   // cublas, hence the nv-1.
899:   const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
900:   const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
901:   PetscScalar *d_results;
902:   cupmStream_t stream;

904:   PetscFunctionBegin;
905:   PetscCall(GetHandlesFrom_(dctx, &stream));
906:   // allocate scratchpad memory for the results of individual work groups
907:   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
908:   {
909:     const auto          xptr       = DeviceArrayRead(dctx, xin);
910:     PetscInt            yidx       = 0;
911:     auto                subidx     = 0;
912:     auto                cur_stream = stream;
913:     auto                cur_ctx    = dctx;
914:     PetscDeviceContext *sub        = nullptr;
915:     PetscStreamType     stype;

917:     // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
918:     // sub. Ideally the parent context should also join in on the fork, but it is extremely
919:     // fiddly to do so presently
920:     PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
921:     if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
922:     // If we have a globally blocking stream create nonblocking streams instead (as we can
923:     // locally exploit the parallelism). Otherwise use the prescribed stream type.
924:     PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
925:     PetscCall(PetscLogGpuTimeBegin());
926:     do {
927:       if (num_sub_streams) {
928:         cur_ctx = sub[subidx++ % num_sub_streams];
929:         PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
930:       }
931:       // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
932:       // it is very likely better to do 4+5 rather than 8+1
933:       switch (nv - yidx) {
934:       case 7:
935:         PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
936:         break;
937:       case 6:
938:         PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
939:         break;
940:       case 5:
941:         PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
942:         break;
943:       case 4:
944:         PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
945:         break;
946:       case 3:
947:         PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
948:         break;
949:       case 2:
950:         PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
951:         break;
952:       case 1: {
953:         cupmBlasHandle_t cupmBlasHandle;

955:         PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
956:         PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
957:         ++yidx;
958:       } break;
959:       default: // 8 or more
960:         PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
961:         break;
962:       }
963:     } while (yidx < nv);
964:     PetscCall(PetscLogGpuTimeEnd());
965:     PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
966:   }

968:   PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
969:   // copy result of device reduction to host
970:   PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
971:   // do these now while final reduction is in flight
972:   PetscCall(PetscLogFlops(nwork));
973:   PetscCall(PetscDeviceFree(dctx, d_results));
974:   PetscFunctionReturn(PETSC_SUCCESS);
975: }

  // the MDot launch-geometry macros are only meant for the kernels and MDot_ above
  #undef MDOT_WORKGROUP_NUM
  #undef MDOT_WORKGROUP_SIZE

980: template <device::cupm::DeviceType T>
981: inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
982: {
983:   // probably not worth it to run more than 8 of these at a time?
984:   const auto          n_sub = PetscMin(nv, 8);
985:   const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
986:   const auto          xptr  = DeviceArrayRead(dctx, xin);
987:   PetscScalar        *d_z;
988:   PetscDeviceContext *subctx;
989:   cupmStream_t        stream;

991:   PetscFunctionBegin;
992:   PetscCall(GetHandlesFrom_(dctx, &stream));
993:   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
994:   PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
995:   PetscCall(PetscLogGpuTimeBegin());
996:   for (PetscInt i = 0; i < nv; ++i) {
997:     const auto            sub = subctx[i % n_sub];
998:     cupmBlasHandle_t      handle;
999:     cupmBlasPointerMode_t old_mode;

1001:     PetscCall(GetHandlesFrom_(sub, &handle));
1002:     PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
1003:     if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
1004:     PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
1005:     if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
1006:   }
1007:   PetscCall(PetscLogGpuTimeEnd());
1008:   PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
1009:   PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
1010:   PetscCall(PetscDeviceFree(dctx, d_z));
1011:   // REVIEW ME: flops?????
1012:   PetscFunctionReturn(PETSC_SUCCESS);
1013: }

1015: // v->ops->mdot
1016: template <device::cupm::DeviceType T>
1017: inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
1018: {
1019:   PetscFunctionBegin;
1020:   if (PetscUnlikely(nv == 1)) {
1021:     // dot handles nv = 0 correctly
1022:     PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z));
1023:   } else if (const auto n = xin->map->n) {
1024:     PetscDeviceContext dctx;

1026:     PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
1027:     PetscCall(GetHandles_(&dctx));
1028:     PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
1029:     // REVIEW ME: double count of flops??
1030:     PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
1031:     PetscCall(PetscDeviceContextSynchronize(dctx));
1032:   } else {
1033:     PetscCall(PetscArrayzero(z, nv));
1034:   }
1035:   PetscFunctionReturn(PETSC_SUCCESS);
1036: }

1038: // v->ops->set
1039: template <device::cupm::DeviceType T>
1040: inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept
1041: {
1042:   const auto         n = xin->map->n;
1043:   PetscDeviceContext dctx;
1044:   cupmStream_t       stream;

1046:   PetscFunctionBegin;
1047:   PetscCall(GetHandles_(&dctx, &stream));
1048:   {
1049:     const auto xptr = DeviceArrayWrite(dctx, xin);

1051:     if (alpha == PetscScalar(0.0)) {
1052:       PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1053:     } else {
1054:       const auto dptr = thrust::device_pointer_cast(xptr.data());

1056:       PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
1057:     }
1058:     if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
1059:   }
1060:   PetscFunctionReturn(PETSC_SUCCESS);
1061: }

1063: // v->ops->scale
1064: template <device::cupm::DeviceType T>
1065: inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept
1066: {
1067:   PetscFunctionBegin;
1068:   if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
1069:   if (PetscUnlikely(alpha == PetscScalar(0.0))) {
1070:     PetscCall(Set(xin, alpha));
1071:   } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1072:     PetscDeviceContext dctx;
1073:     cupmBlasHandle_t   cupmBlasHandle;

1075:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1076:     PetscCall(PetscLogGpuTimeBegin());
1077:     PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
1078:     PetscCall(PetscLogGpuTimeEnd());
1079:     PetscCall(PetscLogGpuFlops(n));
1080:     PetscCall(PetscDeviceContextSynchronize(dctx));
1081:   } else {
1082:     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1083:   }
1084:   PetscFunctionReturn(PETSC_SUCCESS);
1085: }

1087: // v->ops->tdot
1088: template <device::cupm::DeviceType T>
1089: inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept
1090: {
1091:   PetscFunctionBegin;
1092:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1093:     PetscDeviceContext dctx;
1094:     cupmBlasHandle_t   cupmBlasHandle;

1096:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1097:     PetscCall(PetscLogGpuTimeBegin());
1098:     PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1099:     PetscCall(PetscLogGpuTimeEnd());
1100:     PetscCall(PetscLogGpuFlops(2 * n - 1));
1101:   } else {
1102:     *z = 0.0;
1103:   }
1104:   PetscFunctionReturn(PETSC_SUCCESS);
1105: }

// v->ops->copy
//
// Copies the entries of xin into yout. The memcpy direction is chosen from where xin's data
// currently lives (xin->offloadmask) and where yout's data lives (yout->offloadmask):
//   - yout unallocated and of a CUPM type: device-to-device if xin's data is on the device,
//     otherwise host-to-host;
//   - yout unallocated (non-CUPM type) or on the CPU: host-to-host if xin has host data,
//     otherwise device-to-host;
//   - yout on the GPU or both: device-to-device if xin has device data, otherwise host-to-device.
// Copying a vector onto itself is a no-op, and an empty vector only goes through
// MaybeIncrementEmptyLocalVec().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto               mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
  #if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
  #endif
    case PETSC_OFFLOAD_CPU:
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    // NOTE(review): in both branches below, xptr is a raw pointer taken from a DeviceArrayRead /
    // HostArrayRead temporary that is destroyed at the end of its initializer, i.e. the read
    // access is given back before the memcpy is issued. This presumably relies on the array
    // storage staying alive afterwards -- confirm.
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      // destination is yout's device array, write-only since it is fully overwritten
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      // destination is yout's host array, obtained through the generic VecGetArrayWrite()
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      PetscCall(VecGetArrayWrite(yout, &yptr));
      // GPU time is only logged when the device is actually involved
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    // the copies above were issued asynchronously on the context's stream; wait for them here
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

1177: // v->ops->swap
1178: template <device::cupm::DeviceType T>
1179: inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept
1180: {
1181:   PetscFunctionBegin;
1182:   if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
1183:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1184:     PetscDeviceContext dctx;
1185:     cupmBlasHandle_t   cupmBlasHandle;

1187:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1188:     PetscCall(PetscLogGpuTimeBegin());
1189:     PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
1190:     PetscCall(PetscLogGpuTimeEnd());
1191:     PetscCall(PetscDeviceContextSynchronize(dctx));
1192:   } else {
1193:     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1194:     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1195:   }
1196:   PetscFunctionReturn(PETSC_SUCCESS);
1197: }

1199: // v->ops->axpby
1200: template <device::cupm::DeviceType T>
1201: inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
1202: {
1203:   PetscFunctionBegin;
1204:   if (alpha == PetscScalar(0.0)) {
1205:     PetscCall(Scale(yin, beta));
1206:   } else if (beta == PetscScalar(1.0)) {
1207:     PetscCall(AXPY(yin, alpha, xin));
1208:   } else if (alpha == PetscScalar(1.0)) {
1209:     PetscCall(AYPX(yin, beta, xin));
1210:   } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
1211:     const auto         betaIsZero = beta == PetscScalar(0.0);
1212:     const auto         aptr       = cupmScalarPtrCast(&alpha);
1213:     PetscDeviceContext dctx;
1214:     cupmBlasHandle_t   cupmBlasHandle;

1216:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1217:     {
1218:       const auto xptr = DeviceArrayRead(dctx, xin);

1220:       if (betaIsZero /* beta = 0 */) {
1221:         // here we can get away with purely write-only as we memcpy into it first
1222:         const auto   yptr = DeviceArrayWrite(dctx, yin);
1223:         cupmStream_t stream;

1225:         PetscCall(GetHandlesFrom_(dctx, &stream));
1226:         PetscCall(PetscLogGpuTimeBegin());
1227:         PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
1228:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
1229:       } else {
1230:         const auto yptr = DeviceArrayReadWrite(dctx, yin);

1232:         PetscCall(PetscLogGpuTimeBegin());
1233:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
1234:         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
1235:       }
1236:     }
1237:     PetscCall(PetscLogGpuTimeEnd());
1238:     PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
1239:     PetscCall(PetscDeviceContextSynchronize(dctx));
1240:   } else {
1241:     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1242:   }
1243:   PetscFunctionReturn(PETSC_SUCCESS);
1244: }

1246: // v->ops->axpbypcz
1247: template <device::cupm::DeviceType T>
1248: inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
1249: {
1250:   PetscFunctionBegin;
1251:   if (gamma != PetscScalar(1.0)) PetscCall(Scale(zin, gamma));
1252:   PetscCall(AXPY(zin, alpha, xin));
1253:   PetscCall(AXPY(zin, beta, yin));
1254:   PetscFunctionReturn(PETSC_SUCCESS);
1255: }

1257: // v->ops->norm
1258: template <device::cupm::DeviceType T>
1259: inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept
1260: {
1261:   PetscDeviceContext dctx;
1262:   cupmBlasHandle_t   cupmBlasHandle;

1264:   PetscFunctionBegin;
1265:   PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1266:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1267:     const auto xptr      = DeviceArrayRead(dctx, xin);
1268:     PetscInt   flopCount = 0;

1270:     PetscCall(PetscLogGpuTimeBegin());
1271:     switch (type) {
1272:     case NORM_1_AND_2:
1273:     case NORM_1:
1274:       PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
1275:       flopCount = std::max(n - 1, 0);
1276:       if (type == NORM_1) break;
1277:       ++z; // fall-through
1278:   #if PETSC_CPP_VERSION >= 17
1279:       [[fallthrough]];
1280:   #endif
1281:     case NORM_2:
1282:     case NORM_FROBENIUS:
1283:       PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
1284:       flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
1285:       break;
1286:     case NORM_INFINITY: {
1287:       cupmBlasInt_t max_loc = 0;
1288:       PetscScalar   xv      = 0.;
1289:       cupmStream_t  stream;

1291:       PetscCall(GetHandlesFrom_(dctx, &stream));
1292:       PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
1293:       PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
1294:       *z = PetscAbsScalar(xv);
1295:       // REVIEW ME: flopCount = ???
1296:     } break;
1297:     }
1298:     PetscCall(PetscLogGpuTimeEnd());
1299:     PetscCall(PetscLogGpuFlops(flopCount));
1300:   } else {
1301:     z[0]                    = 0.0;
1302:     z[type == NORM_1_AND_2] = 0.0;
1303:   }
1304:   PetscFunctionReturn(PETSC_SUCCESS);
1305: }

1307: namespace detail
1308: {

1310: struct dotnorm2_mult {
1311:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
1312:   {
1313:     const auto conjt = PetscConj(t);

1315:     return {s * conjt, t * conjt};
1316:   }
1317: };

1319: // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
1320: // would do it myself but now I am worried that they do so on purpose...
1321: struct dotnorm2_tuple_plus {
1322:   using value_type = thrust::tuple<PetscScalar, PetscScalar>;

1324:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
1325: };

1327: } // namespace detail

1329: // v->ops->dotnorm2
1330: template <device::cupm::DeviceType T>
1331: inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
1332: {
1333:   PetscDeviceContext dctx;
1334:   cupmStream_t       stream;

1336:   PetscFunctionBegin;
1337:   PetscCall(GetHandles_(&dctx, &stream));
1338:   {
1339:     PetscScalar dpt = 0.0, nmt = 0.0;
1340:     const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

1342:     // clang-format off
1343:     PetscCallThrust(
1344:       thrust::tie(*dp, *nm) = THRUST_CALL(
1345:         thrust::inner_product,
1346:         stream,
1347:         sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
1348:         thrust::make_tuple(dpt, nmt),
1349:         detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
1350:       );
1351:     );
1352:     // clang-format on
1353:   }
1354:   PetscFunctionReturn(PETSC_SUCCESS);
1355: }

1357: namespace detail
1358: {

1360: struct conjugate {
1361:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
1362: };

1364: } // namespace detail

1366: // v->ops->conjugate
1367: template <device::cupm::DeviceType T>
1368: inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept
1369: {
1370:   PetscFunctionBegin;
1371:   if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin));
1372:   PetscFunctionReturn(PETSC_SUCCESS);
1373: }

1375: namespace detail
1376: {

1378: struct real_part {
1379:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }

1381:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
1382: };

1384: // deriving from Operator allows us to "store" an instance of the operator in the class but
1385: // also take advantage of empty base class optimization if the operator is stateless
1386: template <typename Operator>
1387: class tuple_compare : Operator {
1388: public:
1389:   using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
1390:   using operator_type = Operator;

1392:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
1393:   {
1394:     if (op_()(y.get<0>(), x.get<0>())) {
1395:       // if y is strictly greater/less than x, return y
1396:       return y;
1397:     } else if (y.get<0>() == x.get<0>()) {
1398:       // if equal, prefer lower index
1399:       return y.get<1>() < x.get<1>() ? y : x;
1400:     }
1401:     // otherwise return x
1402:     return x;
1403:   }

1405: private:
1406:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
1407: };

1409: } // namespace detail

// Common core of Min() and Max().
//
// *m must be seeded by the caller with the starting value of the reduction; it is overwritten with
// the extreme real part found. If p is non-NULL, *p is set to -1 first (this is what an empty
// vector returns), and the reduction runs over (value, index) pairs with tuple_ftr so that *p
// receives the index of the extreme value. Otherwise the plain values are reduced with unary_ftr.
// In complex builds the real part of each entry is used (via detail::real_part).
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
      // needed to:
      // 1. switch between transform_reduce and reduce
      // 2. strip the real_part functor from the arguments
      // i.e. complex builds map every element through real_part before reducing, while real
      // builds reduce the elements directly and drop the real_part argument.
  #if PetscDefined(USE_COMPLEX)
    #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
  #else
    #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
  #endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // pair every entry with its index so the reduction can also report the location
        // (vptr is const, so the std::move below effectively copies it)
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
        // the incoming (*m, *p) pair seeds the reduction
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // value-only reduction seeded with the incoming *m
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
  #undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(PETSC_SUCCESS);
}

1468: // v->ops->max
1469: template <device::cupm::DeviceType T>
1470: inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept
1471: {
1472:   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
1473:   using unary_functor = thrust::maximum<PetscReal>;

1475:   PetscFunctionBegin;
1476:   *m = PETSC_MIN_REAL;
1477:   // use {} constructor syntax otherwise most vexing parse
1478:   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
1479:   PetscFunctionReturn(PETSC_SUCCESS);
1480: }

1482: // v->ops->min
1483: template <device::cupm::DeviceType T>
1484: inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept
1485: {
1486:   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
1487:   using unary_functor = thrust::minimum<PetscReal>;

1489:   PetscFunctionBegin;
1490:   *m = PETSC_MAX_REAL;
1491:   // use {} constructor syntax otherwise most vexing parse
1492:   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
1493:   PetscFunctionReturn(PETSC_SUCCESS);
1494: }

1496: // v->ops->sum
1497: template <device::cupm::DeviceType T>
1498: inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept
1499: {
1500:   PetscFunctionBegin;
1501:   if (const auto n = v->map->n) {
1502:     PetscDeviceContext dctx;
1503:     cupmStream_t       stream;

1505:     PetscCall(GetHandles_(&dctx, &stream));
1506:     const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
1507:     // REVIEW ME: why not cupmBlasXasum()?
1508:     PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
1509:     // REVIEW ME: must be at least n additions
1510:     PetscCall(PetscLogGpuFlops(n));
1511:   } else {
1512:     *sum = 0.0;
1513:   }
1514:   PetscFunctionReturn(PETSC_SUCCESS);
1515: }

1517: template <device::cupm::DeviceType T>
1518: inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept
1519: {
1520:   PetscFunctionBegin;
1521:   PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v));
1522:   PetscFunctionReturn(PETSC_SUCCESS);
1523: }

1525: template <device::cupm::DeviceType T>
1526: inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept
1527: {
1528:   PetscFunctionBegin;
1529:   if (const auto n = v->map->n) {
1530:     PetscBool          iscurand;
1531:     PetscDeviceContext dctx;

1533:     PetscCall(GetHandles_(&dctx));
1534:     PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
1535:     if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
1536:     else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
1537:   } else {
1538:     PetscCall(MaybeIncrementEmptyLocalVec(v));
1539:   }
1540:   // REVIEW ME: flops????
1541:   // REVIEW ME: Timing???
1542:   PetscFunctionReturn(PETSC_SUCCESS);
1543: }

1545: // v->ops->setpreallocation
1546: template <device::cupm::DeviceType T>
1547: inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
1548: {
1549:   PetscDeviceContext dctx;

1551:   PetscFunctionBegin;
1552:   PetscCall(GetHandles_(&dctx));
1553:   PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
1554:   PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
1555:   PetscFunctionReturn(PETSC_SUCCESS);
1556: }

namespace kernels
{

// Device-side body shared by the COO "add values" kernels.
//
// For each index i handed out by grid_stride_1D(n, ...) (presumably every i in [0, n), spread
// over the launched threads -- see that helper), computes
//
//   sum = vv[perm[jmap[i]]] + ... + vv[perm[jmap[i + 1] - 1]]
//
// and stores it into xv[xvindex(i)]: assigned when imode == INSERT_VALUES, added otherwise.
//
// Parameters:
//   vv      - COO input values, indexed indirectly through perm
//   n       - number of destination entries to process
//   jmap    - per-entry offsets into perm; read at i and i + 1, so presumably n + 1 long
//   perm    - indices into vv for each contribution k
//   imode   - INSERT_VALUES overwrites the destination, any other mode accumulates into it
//   xv      - destination array
//   xvindex - maps the entry index i to its position in xv
template <typename F>
PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
{
  // [=]: the lambda holds its own copies of vv, jmap, perm, imode, xv and xvindex
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
    const auto  end = jmap[i + 1];
    const auto  idx = xvindex(i);
    PetscScalar sum = 0.0;

    // gather all contributions belonging to entry i
    for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

    if (imode == INSERT_VALUES) {
      xv[idx] = sum;
    } else {
      xv[idx] += sum;
    }
  });
  return;
}

namespace
{

// Kernel entry point used by VecSeq_CUPM<T>::SetValuesCOO(): destination index equals the entry
// index (identity xvindex), i.e. xv[i] receives the contributions for entry i.
PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
  return;
}

} // namespace

  #if PetscDefined(USING_HCC)
namespace do_not_use
{

// Needed to silence clang warning:
//
// warning: function 'FUNCTION NAME' is not needed and will not be emitted
//
// The warning is spurious, since the function *is* used; however, the host compiler does not
// appear to see this, likely because the function using it is in a template.
//
// This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
//
// Each function below merely odr-references one kernel so the host compiler sees a use; none of
// them is intended to be called.
inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
{
  (void)sum_kernel;
}

inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_coo_values;
}

} // namespace do_not_use
  #endif

} // namespace kernels

1618: // v->ops->setvaluescoo
1619: template <device::cupm::DeviceType T>
1620: inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
1621: {
1622:   auto               vv = const_cast<PetscScalar *>(v);
1623:   PetscMemType       memtype;
1624:   PetscDeviceContext dctx;
1625:   cupmStream_t       stream;

1627:   PetscFunctionBegin;
1628:   PetscCall(GetHandles_(&dctx, &stream));
1629:   PetscCall(PetscGetMemType(v, &memtype));
1630:   if (PetscMemTypeHost(memtype)) {
1631:     const auto size = VecIMPLCast(x)->coo_n;

1633:     // If user gave v[] in host, we might need to copy it to device if any
1634:     PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
1635:     PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
1636:   }

1638:   if (const auto n = x->map->n) {
1639:     const auto vcu = VecCUPMCast(x);

1641:     PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
1642:   } else {
1643:     PetscCall(MaybeIncrementEmptyLocalVec(x));
1644:   }

1646:   if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
1647:   PetscCall(PetscDeviceContextSynchronize(dctx));
1648:   PetscFunctionReturn(PETSC_SUCCESS);
1649: }

1651: } // namespace impl

1653: // ==========================================================================================
1654: // VecSeq_CUPM - Implementations
1655: // ==========================================================================================

1657: namespace
1658: {

1660: template <device::cupm::DeviceType T>
1661: inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
1662: {
1663:   PetscFunctionBegin;
1665:   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
1666:   PetscFunctionReturn(PETSC_SUCCESS);
1667: }

1669: template <device::cupm::DeviceType T>
1670: inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
1671: {
1672:   PetscFunctionBegin;
1675:   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
1676:   PetscFunctionReturn(PETSC_SUCCESS);
1677: }

1679: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
1680: inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
1681: {
1682:   PetscFunctionBegin;
1685:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1686:   PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
1687:   PetscFunctionReturn(PETSC_SUCCESS);
1688: }

1690: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
1691: inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
1692: {
1693:   PetscFunctionBegin;
1695:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1696:   PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
1697:   PetscFunctionReturn(PETSC_SUCCESS);
1698: }

1700: template <device::cupm::DeviceType T>
1701: inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
1702: {
1703:   PetscFunctionBegin;
1704:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
1705:   PetscFunctionReturn(PETSC_SUCCESS);
1706: }

1708: template <device::cupm::DeviceType T>
1709: inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
1710: {
1711:   PetscFunctionBegin;
1712:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
1713:   PetscFunctionReturn(PETSC_SUCCESS);
1714: }

1716: template <device::cupm::DeviceType T>
1717: inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
1718: {
1719:   PetscFunctionBegin;
1720:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
1721:   PetscFunctionReturn(PETSC_SUCCESS);
1722: }

1724: template <device::cupm::DeviceType T>
1725: inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
1726: {
1727:   PetscFunctionBegin;
1728:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
1729:   PetscFunctionReturn(PETSC_SUCCESS);
1730: }

1732: template <device::cupm::DeviceType T>
1733: inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
1734: {
1735:   PetscFunctionBegin;
1736:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
1737:   PetscFunctionReturn(PETSC_SUCCESS);
1738: }

1740: template <device::cupm::DeviceType T>
1741: inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
1742: {
1743:   PetscFunctionBegin;
1744:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
1745:   PetscFunctionReturn(PETSC_SUCCESS);
1746: }

1748: template <device::cupm::DeviceType T>
1749: inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
1750: {
1751:   PetscFunctionBegin;
1753:   PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
1754:   PetscFunctionReturn(PETSC_SUCCESS);
1755: }

1757: template <device::cupm::DeviceType T>
1758: inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
1759: {
1760:   PetscFunctionBegin;
1762:   PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
1763:   PetscFunctionReturn(PETSC_SUCCESS);
1764: }

1766: template <device::cupm::DeviceType T>
1767: inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
1768: {
1769:   PetscFunctionBegin;
1771:   PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
1772:   PetscFunctionReturn(PETSC_SUCCESS);
1773: }

1775: } // anonymous namespace

1777: } // namespace cupm

1779: } // namespace vec

1781: } // namespace Petsc

1783: #endif // __cplusplus

1785: #endif // PETSCVECSEQCUPM_HPP