Actual source code: sfhip.hip.cpp

  1: #include <../src/vec/is/sf/impls/basic/sfpack.h>
  2: #include <petscpkg_version.h>

  4: /* Workaround for compilation issues on the Spock system: disable complex support */
  5: #undef PETSC_HAVE_COMPLEX

  7: /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
  8: __device__ static inline PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid)
  9: {
 10:   PetscInt        i, j, k, m, n, r;
 11:   const PetscInt *offset, *start, *dx, *dy, *X, *Y;

 13:   n      = opt[0];
 14:   offset = opt + 1;
 15:   start  = opt + n + 2;
 16:   dx     = opt + 2 * n + 2;
 17:   dy     = opt + 3 * n + 2;
 18:   X      = opt + 5 * n + 2;
 19:   Y      = opt + 6 * n + 2;
 20:   for (r = 0; r < n; r++) {
 21:     if (tid < offset[r + 1]) break;
 22:   }
 23:   m = (tid - offset[r]);
 24:   k = m / (dx[r] * dy[r]);
 25:   j = (m - k * dx[r] * dy[r]) / dx[r];
 26:   i = m - k * dx[r] * dy[r] - j * dx[r];

 28:   return (start[r] + k * X[r] * Y[r] + j * X[r] + i);
 29: }
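
To make the opt[] layout concrete, here is a minimal standalone host-side sketch (a hypothetical program, not part of PETSc) of the same mapping for a single 3D subdomain. The field offsets mirror MapTidToIndex() above: [n | offset(n+1) | start(n) | dx(n) | dy(n) | dz(n, unused there) | X(n) | Y(n)].

#include <cstdio>

static int MapTidToIndexHost(const int *opt, int tid)
{
  int        n      = opt[0];
  const int *offset = opt + 1, *start = opt + n + 2;
  const int *dx = opt + 2 * n + 2, *dy = opt + 3 * n + 2;
  const int *X = opt + 5 * n + 2, *Y = opt + 6 * n + 2;
  int        r = 0;
  while (tid >= offset[r + 1]) r++;    /* find the subdomain containing tid */
  int m = tid - offset[r];             /* rank of tid within that subdomain */
  int k = m / (dx[r] * dy[r]);         /* unflatten m into (i,j,k) of the dx*dy patch */
  int j = (m - k * dx[r] * dy[r]) / dx[r];
  int i = m - k * dx[r] * dy[r] - j * dx[r];
  return start[r] + k * X[r] * Y[r] + j * X[r] + i; /* re-flatten into the enclosing X*Y block */
}

int main()
{
  /* One subdomain: a 2x2x1 patch starting at index 5 inside a 4x3x1 block */
  const int opt[] = {1,       /* n            */
                     0, 4,    /* offset[0..1] */
                     5,       /* start        */
                     2, 2, 1, /* dx, dy, dz   */
                     4, 3};   /* X, Y         */
  for (int t = 0; t < 4; t++) printf("tid %d -> index %d\n", t, MapTidToIndexHost(opt, t));
  return 0; /* prints 5, 6, 9, 10: i varies fastest, then j with stride X */
}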

 31: /*====================================================================================*/
 32: /*  Templated HIP kernels for pack/unpack. The Op can be regular or atomic           */
 33: /*====================================================================================*/

 35: /* Suppose a user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI datatype made of 16 PetscReals; then
 36:    <Type> is PetscReal, which is the primitive type we operate on;
 37:    <bs>   is 16, which says <unit> contains 16 primitive types;
 38:    <BS>   is 8, which is the maximal SIMD width at which we will try to vectorize operations on <unit>;
 39:    <EQ>   is 0, which is (bs == BS ? 1 : 0).

 41:  If instead <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, making MBS below a compile-time constant.
 42:  For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, and the inner for-loops below are fully unrolled.
 43: */
 44: template <class Type, PetscInt BS, PetscInt EQ>
 45: __global__ static void d_Pack(PetscInt bs, PetscInt count, PetscInt start, const PetscInt *opt, const PetscInt *idx, const Type *data, Type *buf)
 46: {
 47:   PetscInt       i, s, t, tid = blockIdx.x * blockDim.x + threadIdx.x;
 48:   const PetscInt grid_size = gridDim.x * blockDim.x;
 49:   const PetscInt M         = (EQ) ? 1 : bs / BS; /* If EQ, then M=1 enables compiler's const-propagation */
 50:   const PetscInt MBS       = M * BS;             /* MBS=bs. We turn MBS into a compile-time const when EQ=1. */

 52:   for (; tid < count; tid += grid_size) {
 53:     /* opt != NULL ==> idx == NULL, i.e., the indices have patterns but are not contiguous;
 54:        opt == NULL && idx == NULL ==> the indices are contiguous;
 55:      */
 56:     t = (opt ? MapTidToIndex(opt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
 57:     s = tid * MBS;
 58:     for (i = 0; i < MBS; i++) buf[s + i] = data[t + i];
 59:   }
 60: }
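
The EQ trick above is worth seeing in isolation: when EQ=1 the compiler can const-propagate M and MBS and fully unroll the copy loop. A minimal host-side sketch (hypothetical names, no HIP required):

#include <cstdio>

template <int BS, int EQ>
static void pack_one(int bs, const double *data, double *buf)
{
  const int M   = EQ ? 1 : bs / BS; /* EQ=1: M is the literal constant 1 */
  const int MBS = M * BS;           /* EQ=1: MBS == BS, known at compile time */
  for (int i = 0; i < MBS; i++) buf[i] = data[i];
}

int main()
{
  double src[16], dst[16];
  for (int i = 0; i < 16; i++) src[i] = i;
  pack_one<8, 0>(16, src, dst); /* bs=16 != BS=8: MBS=16 is computed at run time */
  pack_one<1, 1>(1, src, dst);  /* VecScatter case: MBS=1, loop fully unrolled    */
  printf("%g %g\n", dst[0], dst[15]); /* prints 0 15 */
  return 0;
}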

 62: template <class Type, class Op, PetscInt BS, PetscInt EQ>
 63: __global__ static void d_UnpackAndOp(PetscInt bs, PetscInt count, PetscInt start, const PetscInt *opt, const PetscInt *idx, Type *data, const Type *buf)
 64: {
 65:   PetscInt       i, s, t, tid = blockIdx.x * blockDim.x + threadIdx.x;
 66:   const PetscInt grid_size = gridDim.x * blockDim.x;
 67:   const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
 68:   Op             op;

 70:   for (; tid < count; tid += grid_size) {
 71:     t = (opt ? MapTidToIndex(opt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
 72:     s = tid * MBS;
 73:     for (i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
 74:   }
 75: }

 77: template <class Type, class Op, PetscInt BS, PetscInt EQ>
 78: __global__ static void d_FetchAndOp(PetscInt bs, PetscInt count, PetscInt rootstart, const PetscInt *rootopt, const PetscInt *rootidx, Type *rootdata, Type *leafbuf)
 79: {
 80:   PetscInt       i, l, r, tid = blockIdx.x * blockDim.x + threadIdx.x;
 81:   const PetscInt grid_size = gridDim.x * blockDim.x;
 82:   const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
 83:   Op             op;

 85:   for (; tid < count; tid += grid_size) {
 86:     r = (rootopt ? MapTidToIndex(rootopt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
 87:     l = tid * MBS;
 88:     for (i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
 89:   }
 90: }

 92: template <class Type, class Op, PetscInt BS, PetscInt EQ>
 93: __global__ static void d_ScatterAndOp(PetscInt bs, PetscInt count, PetscInt srcx, PetscInt srcy, PetscInt srcX, PetscInt srcY, PetscInt srcStart, const PetscInt *srcIdx, const Type *src, PetscInt dstx, PetscInt dsty, PetscInt dstX, PetscInt dstY, PetscInt dstStart, const PetscInt *dstIdx, Type *dst)
 94: {
 95:   PetscInt       i, j, k, s, t, tid = blockIdx.x * blockDim.x + threadIdx.x;
 96:   const PetscInt grid_size = gridDim.x * blockDim.x;
 97:   const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
 98:   Op             op;

100:   for (; tid < count; tid += grid_size) {
101:     if (!srcIdx) { /* src is either contiguous or 3D */
102:       k = tid / (srcx * srcy);
103:       j = (tid - k * srcx * srcy) / srcx;
104:       i = tid - k * srcx * srcy - j * srcx;
105:       s = srcStart + k * srcX * srcY + j * srcX + i;
106:     } else {
107:       s = srcIdx[tid];
108:     }

110:     if (!dstIdx) { /* dst is either contiguous or 3D */
111:       k = tid / (dstx * dsty);
112:       j = (tid - k * dstx * dsty) / dstx;
113:       i = tid - k * dstx * dsty - j * dstx;
114:       t = dstStart + k * dstX * dstY + j * dstX + i;
115:     } else {
116:       t = dstIdx[tid];
117:     }

119:     s *= MBS;
120:     t *= MBS;
121:     for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
122:   }
123: }

125: template <class Type, class Op, PetscInt BS, PetscInt EQ>
126: __global__ static void d_FetchAndOpLocal(PetscInt bs, PetscInt count, PetscInt rootstart, const PetscInt *rootopt, const PetscInt *rootidx, Type *rootdata, PetscInt leafstart, const PetscInt *leafopt, const PetscInt *leafidx, const Type *leafdata, Type *leafupdate)
127: {
128:   PetscInt       i, l, r, tid = blockIdx.x * blockDim.x + threadIdx.x;
129:   const PetscInt grid_size = gridDim.x * blockDim.x;
130:   const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
131:   Op             op;

133:   for (; tid < count; tid += grid_size) {
134:     r = (rootopt ? MapTidToIndex(rootopt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
135:     l = (leafopt ? MapTidToIndex(leafopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
136:     for (i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
137:   }
138: }

140: /*====================================================================================*/
141: /*                             Regular operations on device                           */
142: /*====================================================================================*/
143: template <typename Type>
144: struct Insert {
145:   __device__ Type operator()(Type &x, Type y) const
146:   {
147:     Type old = x;
148:     x        = y;
149:     return old;
150:   }
151: };
152: template <typename Type>
153: struct Add {
154:   __device__ Type operator()(Type &x, Type y) const
155:   {
156:     Type old = x;
157:     x += y;
158:     return old;
159:   }
160: };
161: template <typename Type>
162: struct Mult {
163:   __device__ Type operator()(Type &x, Type y) const
164:   {
165:     Type old = x;
166:     x *= y;
167:     return old;
168:   }
169: };
170: template <typename Type>
171: struct Min {
172:   __device__ Type operator()(Type &x, Type y) const
173:   {
174:     Type old = x;
175:     x        = PetscMin(x, y);
176:     return old;
177:   }
178: };
179: template <typename Type>
180: struct Max {
181:   __device__ Type operator()(Type &x, Type y) const
182:   {
183:     Type old = x;
184:     x        = PetscMax(x, y);
185:     return old;
186:   }
187: };
188: template <typename Type>
189: struct LAND {
190:   __device__ Type operator()(Type &x, Type y) const
191:   {
192:     Type old = x;
193:     x        = x && y;
194:     return old;
195:   }
196: };
197: template <typename Type>
198: struct LOR {
199:   __device__ Type operator()(Type &x, Type y) const
200:   {
201:     Type old = x;
202:     x        = x || y;
203:     return old;
204:   }
205: };
206: template <typename Type>
207: struct LXOR {
208:   __device__ Type operator()(Type &x, Type y) const
209:   {
210:     Type old = x;
211:     x        = !x != !y;
212:     return old;
213:   }
214: };
215: template <typename Type>
216: struct BAND {
217:   __device__ Type operator()(Type &x, Type y) const
218:   {
219:     Type old = x;
220:     x        = x & y;
221:     return old;
222:   }
223: };
224: template <typename Type>
225: struct BOR {
226:   __device__ Type operator()(Type &x, Type y) const
227:   {
228:     Type old = x;
229:     x        = x | y;
230:     return old;
231:   }
232: };
233: template <typename Type>
234: struct BXOR {
235:   __device__ Type operator()(Type &x, Type y) const
236:   {
237:     Type old = x;
238:     x        = x ^ y;
239:     return old;
240:   }
241: };
242: template <typename Type>
243: struct Minloc {
244:   __device__ Type operator()(Type &x, Type y) const
245:   {
246:     Type old = x;
247:     if (y.a < x.a) x = y;
248:     else if (y.a == x.a) x.b = min(x.b, y.b);
249:     return old;
250:   }
251: };
252: template <typename Type>
253: struct Maxloc {
254:   __device__ Type operator()(Type &x, Type y) const
255:   {
256:     Type old = x;
257:     if (y.a > x.a) x = y;
258:     else if (y.a == x.a) x.b = min(x.b, y.b); /* See MPI MAXLOC */
259:     return old;
260:   }
261: };
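
These functors all follow the same contract: op(x, y) combines y into x and returns the old value of x, which is what d_FetchAndOp stores back into the leaf buffer. A host-side sketch of that contract (hypothetical standalone code; HostAdd mirrors Add<Type> minus the __device__ qualifier):

#include <cstdio>

template <typename Type>
struct HostAdd {
  Type operator()(Type &x, Type y) const
  {
    Type old = x;
    x += y;
    return old; /* fetch-and-op: the caller receives the pre-update value */
  }
};

int main()
{
  double          root = 10.0;
  HostAdd<double> op;
  double          fetched = op(root, 2.5);
  printf("root=%g fetched=%g\n", root, fetched); /* root=12.5 fetched=10 */
  return 0;
}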

263: /*====================================================================================*/
264: /*                             Atomic operations on device                            */
265: /*====================================================================================*/

267: /*
268:   Atomic Insert (exchange) operations

270:   See the CUDA version
271: */
272: #if PETSC_PKG_HIP_VERSION_LT(4, 4, 0)
273: __device__ static double atomicExch(double *address, double val)
274: {
275:   return __longlong_as_double(atomicExch((ullint *)address, __double_as_longlong(val)));
276: }
277: #endif

279: __device__ static llint atomicExch(llint *address, llint val)
280: {
281:   return (llint)(atomicExch((ullint *)address, (ullint)val));
282: }

284: template <typename Type>
285: struct AtomicInsert {
286:   __device__ Type operator()(Type &x, Type y) const { return atomicExch(&x, y); }
287: };

289: #if defined(PETSC_HAVE_COMPLEX)
290:   #if defined(PETSC_USE_REAL_DOUBLE)
291: template <>
292: struct AtomicInsert<PetscComplex> {
293:   __device__ PetscComplex operator()(PetscComplex &x, PetscComplex y) const
294:   {
295:     PetscComplex         old, *z = &old;
296:     double              *xp = (double *)&x, *yp = (double *)&y;
297:     AtomicInsert<double> op;
298:     z[0] = op(xp[0], yp[0]);
299:     z[1] = op(xp[1], yp[1]);
300:     return old; /* The returned value may not be atomic. It can be a mix of two ops. The caller should discard it. */
301:   }
302: };
303:   #elif defined(PETSC_USE_REAL_SINGLE)
304: template <>
305: struct AtomicInsert<PetscComplex> {
306:   __device__ PetscComplex operator()(PetscComplex &x, PetscComplex y) const
307:   {
308:     double              *xp = (double *)&x, *yp = (double *)&y;
309:     AtomicInsert<double> op;
310:     return op(xp[0], yp[0]);
311:   }
312: };
313:   #endif
314: #endif

316: /*
317:   Atomic add operations

319: */
320: __device__ static llint atomicAdd(llint *address, llint val)
321: {
322:   return (llint)atomicAdd((ullint *)address, (ullint)val);
323: }

325: template <typename Type>
326: struct AtomicAdd {
327:   __device__ Type operator()(Type &x, Type y) const { return atomicAdd(&x, y); }
328: };

330: template <>
331: struct AtomicAdd<double> {
332:   __device__ double operator()(double &x, double y) const
333:   {
334:     /* The CUDA version does more checks that may be needed here as well */
335:     return atomicAdd(&x, y);
336:   }
337: };

339: template <>
340: struct AtomicAdd<float> {
341:   __device__ float operator()(float &x, float y) const
342:   {
343:     /* The CUDA version does more checks that may be needed here as well */
344:     return atomicAdd(&x, y);
345:   }
346: };

348: #if defined(PETSC_HAVE_COMPLEX)
349: template <>
350: struct AtomicAdd<PetscComplex> {
351:   __device__ PetscComplex operator()(PetscComplex &x, PetscComplex y) const
352:   {
353:     PetscComplex         old, *z = &old;
354:     PetscReal           *xp = (PetscReal *)&x, *yp = (PetscReal *)&y;
355:     AtomicAdd<PetscReal> op;
356:     z[0] = op(xp[0], yp[0]);
357:     z[1] = op(xp[1], yp[1]);
358:     return old; /* The returned value may not be atomic. It can be a mix of two ops. The caller should discard it. */
359:   }
360: };
361: #endif

363: /*
364:   Atomic Mult operations:

366:   HIP has no atomicMult at all, so we build our own with atomicCAS
367:  */
368: #if defined(PETSC_USE_REAL_DOUBLE)
369: __device__ static double atomicMult(double *address, double val)
370: {
371:   ullint *address_as_ull = (ullint *)(address);
372:   ullint  old            = *address_as_ull, assumed;
373:   do {
374:     assumed = old;
375:     /* Other threads can access and modify value of *address_as_ull after the read above and before the write below */
376:     old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val * __longlong_as_double(assumed)));
377:   } while (assumed != old);
378:   return __longlong_as_double(old);
379: }
380: #elif defined(PETSC_USE_REAL_SINGLE)
381: __device__ static float atomicMult(float *address, float val)
382: {
383:   int *address_as_int = (int *)(address);
384:   int  old            = *address_as_int, assumed;
385:   do {
386:     assumed = old;
387:     old     = atomicCAS(address_as_int, assumed, __float_as_int(val * __int_as_float(assumed)));
388:   } while (assumed != old);
389:   return __int_as_float(old);
390: }
391: #endif

393: __device__ static int atomicMult(int *address, int val)
394: {
395:   int *address_as_int = (int *)(address);
396:   int  old            = *address_as_int, assumed;
397:   do {
398:     assumed = old;
399:     old     = atomicCAS(address_as_int, assumed, val * assumed);
400:   } while (assumed != old);
401:   return (int)old;
402: }

404: __device__ static llint atomicMult(llint *address, llint val)
405: {
406:   ullint *address_as_ull = (ullint *)(address);
407:   ullint  old            = *address_as_ull, assumed;
408:   do {
409:     assumed = old;
410:     old     = atomicCAS(address_as_ull, assumed, (ullint)(val * (llint)assumed));
411:   } while (assumed != old);
412:   return (llint)old;
413: }

415: template <typename Type>
416: struct AtomicMult {
417:   __device__ Type operator()(Type &x, Type y) const { return atomicMult(&x, y); }
418: };

420: /*
421:   Atomic Min/Max operations

423:   See CUDA version for comments.
424:  */
425: #if PETSC_PKG_HIP_VERSION_LT(4, 4, 0)
426:   #if defined(PETSC_USE_REAL_DOUBLE)
427: __device__ static double atomicMin(double *address, double val)
428: {
429:   ullint *address_as_ull = (ullint *)(address);
430:   ullint  old            = *address_as_ull, assumed;
431:   do {
432:     assumed = old;
433:     old     = atomicCAS(address_as_ull, assumed, __double_as_longlong(PetscMin(val, __longlong_as_double(assumed))));
434:   } while (assumed != old);
435:   return __longlong_as_double(old);
436: }

438: __device__ static double atomicMax(double *address, double val)
439: {
440:   ullint *address_as_ull = (ullint *)(address);
441:   ullint  old            = *address_as_ull, assumed;
442:   do {
443:     assumed = old;
444:     old     = atomicCAS(address_as_ull, assumed, __double_as_longlong(PetscMax(val, __longlong_as_double(assumed))));
445:   } while (assumed != old);
446:   return __longlong_as_double(old);
447: }
448:   #elif defined(PETSC_USE_REAL_SINGLE)
449: __device__ static float atomicMin(float *address, float val)
450: {
451:   int *address_as_int = (int *)(address);
452:   int  old            = *address_as_int, assumed;
453:   do {
454:     assumed = old;
455:     old     = atomicCAS(address_as_int, assumed, __float_as_int(PetscMin(val, __int_as_float(assumed))));
456:   } while (assumed != old);
457:   return __int_as_float(old);
458: }

460: __device__ static float atomicMax(float *address, float val)
461: {
462:   int *address_as_int = (int *)(address);
463:   int  old            = *address_as_int, assumed;
464:   do {
465:     assumed = old;
466:     old     = atomicCAS(address_as_int, assumed, __float_as_int(PetscMax(val, __int_as_float(assumed))));
467:   } while (assumed != old);
468:   return __int_as_float(old);
469: }
470:   #endif
471: #endif

473: /* As of ROCm 3.10, llint atomicMin/Max(llint*, llint) is not supported */
474: __device__ static llint atomicMin(llint *address, llint val)
475: {
476:   ullint *address_as_ull = (ullint *)(address);
477:   ullint  old            = *address_as_ull, assumed;
478:   do {
479:     assumed = old;
480:     old     = atomicCAS(address_as_ull, assumed, (ullint)(PetscMin(val, (llint)assumed)));
481:   } while (assumed != old);
482:   return (llint)old;
483: }

485: __device__ static llint atomicMax(llint *address, llint val)
486: {
487:   ullint *address_as_ull = (ullint *)(address);
488:   ullint  old            = *address_as_ull, assumed;
489:   do {
490:     assumed = old;
491:     old     = atomicCAS(address_as_ull, assumed, (ullint)(PetscMax(val, (llint)assumed)));
492:   } while (assumed != old);
493:   return (llint)old;
494: }

496: template <typename Type>
497: struct AtomicMin {
498:   __device__ Type operator()(Type &x, Type y) const { return atomicMin(&x, y); }
499: };
500: template <typename Type>
501: struct AtomicMax {
502:   __device__ Type operator()(Type &x, Type y) const { return atomicMax(&x, y); }
503: };

505: /*
506:   Atomic bitwise operations
507:   As of ROCm 3.10, the llint atomicAnd/Or/Xor(llint*, llint) overloads are not supported
508: */

510: __device__ static llint atomicAnd(llint *address, llint val)
511: {
512:   ullint *address_as_ull = (ullint *)(address);
513:   ullint  old            = *address_as_ull, assumed;
514:   do {
515:     assumed = old;
516:     old     = atomicCAS(address_as_ull, assumed, (ullint)(val & (llint)assumed));
517:   } while (assumed != old);
518:   return (llint)old;
519: }
520: __device__ static llint atomicOr(llint *address, llint val)
521: {
522:   ullint *address_as_ull = (ullint *)(address);
523:   ullint  old            = *address_as_ull, assumed;
524:   do {
525:     assumed = old;
526:     old     = atomicCAS(address_as_ull, assumed, (ullint)(val | (llint)assumed));
527:   } while (assumed != old);
528:   return (llint)old;
529: }

531: __device__ static llint atomicXor(llint *address, llint val)
532: {
533:   ullint *address_as_ull = (ullint *)(address);
534:   ullint  old            = *address_as_ull, assumed;
535:   do {
536:     assumed = old;
537:     old     = atomicCAS(address_as_ull, assumed, (ullint)(val ^ (llint)assumed));
538:   } while (assumed != old);
539:   return (llint)old;
540: }

542: template <typename Type>
543: struct AtomicBAND {
544:   __device__ Type operator()(Type &x, Type y) const { return atomicAnd(&x, y); }
545: };
546: template <typename Type>
547: struct AtomicBOR {
548:   __device__ Type operator()(Type &x, Type y) const { return atomicOr(&x, y); }
549: };
550: template <typename Type>
551: struct AtomicBXOR {
552:   __device__ Type operator()(Type &x, Type y) const { return atomicXor(&x, y); }
553: };

555: /*
556:   Atomic logical operations:

558:   Neither CUDA nor HIP provides atomic logical operations, so we emulate them for integer types with atomicCAS.
559: */

561: /* A template declared without a definition makes any instantiation that does not match one of the given specializations
562:    a compile-time error, which is what we want since we only support 32-bit and 64-bit integers.
563:  */
564: template <typename Type, class Op, int size /* sizeof(Type) */>
565: struct AtomicLogical;
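
A minimal sketch of this declared-but-undefined-template technique (hypothetical names): only the sizes given explicit specializations can be instantiated.

template <typename Type, int size>
struct OnlySized; /* primary template: declared, never defined */

template <typename Type>
struct OnlySized<Type, 4> {
  Type value; /* 32-bit types: fine */
};

/* OnlySized<short, 2> s;  -- error: incomplete type, caught at compile time */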

567: template <typename Type, class Op>
568: struct AtomicLogical<Type, Op, 4> {
569:   __device__ Type operator()(Type &x, Type y) const
570:   {
571:     int *address_as_int = (int *)(&x);
572:     int  old            = *address_as_int, assumed;
573:     Op   op;
574:     do {
575:       assumed = old;
576:       old     = atomicCAS(address_as_int, assumed, (int)(op((Type)assumed, y)));
577:     } while (assumed != old);
578:     return (Type)old;
579:   }
580: };

582: template <typename Type, class Op>
583: struct AtomicLogical<Type, Op, 8> {
584:   __device__ Type operator()(Type &x, Type y) const
585:   {
586:     ullint *address_as_ull = (ullint *)(&x);
587:     ullint  old            = *address_as_ull, assumed;
588:     Op      op;
589:     do {
590:       assumed = old;
591:       old     = atomicCAS(address_as_ull, assumed, (ullint)(op((Type)assumed, y)));
592:     } while (assumed != old);
593:     return (Type)old;
594:   }
595: };

597: /* Note that land/lor/lxor below differ from LAND etc. above: they take arguments by value and return the result of the op (not the old value) */
598: template <typename Type>
599: struct land {
600:   __device__ Type operator()(Type x, Type y) { return x && y; }
601: };
602: template <typename Type>
603: struct lor {
604:   __device__ Type operator()(Type x, Type y) { return x || y; }
605: };
606: template <typename Type>
607: struct lxor {
608:   __device__ Type operator()(Type x, Type y) { return (!x != !y); }
609: };

611: template <typename Type>
612: struct AtomicLAND {
613:   __device__ Type operator()(Type &x, Type y) const
614:   {
615:     AtomicLogical<Type, land<Type>, sizeof(Type)> op;
616:     return op(x, y);
617:   }
618: };
619: template <typename Type>
620: struct AtomicLOR {
621:   __device__ Type operator()(Type &x, Type y) const
622:   {
623:     AtomicLogical<Type, lor<Type>, sizeof(Type)> op;
624:     return op(x, y);
625:   }
626: };
627: template <typename Type>
628: struct AtomicLXOR {
629:   __device__ Type operator()(Type &x, Type y) const
630:   {
631:     AtomicLogical<Type, lxor<Type>, sizeof(Type)> op;
632:     return op(x, y);
633:   }
634: };

636: /*====================================================================================*/
637: /*  Wrapper functions for the HIP kernels. Function pointers are stored in 'link'     */
638: /*====================================================================================*/
639: template <typename Type, PetscInt BS, PetscInt EQ>
640: static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data, void *buf)
641: {
642:   PetscInt        nthreads = 256;
643:   PetscInt        nblocks  = (count + nthreads - 1) / nthreads;
644:   const PetscInt *iarray   = opt ? opt->array : NULL;

646:   PetscFunctionBegin;
647:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
648:   nblocks = PetscMin(nblocks, link->maxResidentThreadsPerGPU / nthreads);
649:   hipLaunchKernelGGL(HIP_KERNEL_NAME(d_Pack<Type, BS, EQ>), dim3(nblocks), dim3(nthreads), 0, link->stream, link->bs, count, start, iarray, idx, (const Type *)data, (Type *)buf);
650:   PetscCallHIP(hipGetLastError());
651:   PetscFunctionReturn(PETSC_SUCCESS);
652: }
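
The launch-size logic in these wrappers caps the block count at what the GPU can keep resident; the grid-stride loop in the kernels then covers the remaining elements. A host-only sketch with illustrative numbers (in the real code, maxResidentThreadsPerGPU is queried from the device in PetscSFLinkSetUp_HIP below):

#include <cstdio>

int main()
{
  const long long nthreads                 = 256;
  const long long maxResidentThreadsPerGPU = 120 * 2048; /* e.g., 120 CUs x 2048 resident threads */
  const long long count                    = 100000000;  /* elements to pack */
  long long       nblocks                  = (count + nthreads - 1) / nthreads;
  const long long cap                      = maxResidentThreadsPerGPU / nthreads;
  if (nblocks > cap) nblocks = cap; /* cap the grid; the grid-stride loop handles the rest */
  printf("launch %lld blocks x %lld threads; each thread strides over ~%lld elements\n", nblocks, nthreads, (count + nblocks * nthreads - 1) / (nblocks * nthreads));
  return 0;
}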

654: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
655: static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, const void *buf)
656: {
657:   PetscInt        nthreads = 256;
658:   PetscInt        nblocks  = (count + nthreads - 1) / nthreads;
659:   const PetscInt *iarray   = opt ? opt->array : NULL;

661:   PetscFunctionBegin;
662:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
663:   nblocks = PetscMin(nblocks, link->maxResidentThreadsPerGPU / nthreads);
664:   hipLaunchKernelGGL(HIP_KERNEL_NAME(d_UnpackAndOp<Type, Op, BS, EQ>), dim3(nblocks), dim3(nthreads), 0, link->stream, link->bs, count, start, iarray, idx, (Type *)data, (const Type *)buf);
665:   PetscCallHIP(hipGetLastError());
666:   PetscFunctionReturn(PETSC_SUCCESS);
667: }

669: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
670: static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf)
671: {
672:   PetscInt        nthreads = 256;
673:   PetscInt        nblocks  = (count + nthreads - 1) / nthreads;
674:   const PetscInt *iarray   = opt ? opt->array : NULL;

676:   PetscFunctionBegin;
677:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
678:   nblocks = PetscMin(nblocks, link->maxResidentThreadsPerGPU / nthreads);
679:   hipLaunchKernelGGL(HIP_KERNEL_NAME(d_FetchAndOp<Type, Op, BS, EQ>), dim3(nblocks), dim3(nthreads), 0, link->stream, link->bs, count, start, iarray, idx, (Type *)data, (Type *)buf);
680:   PetscCallHIP(hipGetLastError());
681:   PetscFunctionReturn(PETSC_SUCCESS);
682: }

684: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
685: static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst)
686: {
687:   PetscInt nthreads = 256;
688:   PetscInt nblocks  = (count + nthreads - 1) / nthreads;
689:   PetscInt srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;

691:   PetscFunctionBegin;
692:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
693:   nblocks = PetscMin(nblocks, link->maxResidentThreadsPerGPU / nthreads);

695:   /* The 3D shape of the source subdomain may differ from that of the destination, which makes it difficult to use a 3D grid and block */
696:   if (srcOpt) {
697:     srcx     = srcOpt->dx[0];
698:     srcy     = srcOpt->dy[0];
699:     srcX     = srcOpt->X[0];
700:     srcY     = srcOpt->Y[0];
701:     srcStart = srcOpt->start[0];
702:     srcIdx   = NULL;
703:   } else if (!srcIdx) {
704:     srcx = srcX = count;
705:     srcy = srcY = 1;
706:   }

708:   if (dstOpt) {
709:     dstx     = dstOpt->dx[0];
710:     dsty     = dstOpt->dy[0];
711:     dstX     = dstOpt->X[0];
712:     dstY     = dstOpt->Y[0];
713:     dstStart = dstOpt->start[0];
714:     dstIdx   = NULL;
715:   } else if (!dstIdx) {
716:     dstx = dstX = count;
717:     dsty = dstY = 1;
718:   }

720:   hipLaunchKernelGGL(HIP_KERNEL_NAME(d_ScatterAndOp<Type, Op, BS, EQ>), dim3(nblocks), dim3(nthreads), 0, link->stream, link->bs, count, srcx, srcy, srcX, srcY, srcStart, srcIdx, (const Type *)src, dstx, dsty, dstX, dstY, dstStart, dstIdx, (Type *)dst);
721:   PetscCallHIP(hipGetLastError());
722:   PetscFunctionReturn(PETSC_SUCCESS);
723: }

725: /* Specialization for Insert since we may use hipMemcpyAsync */
726: template <typename Type, PetscInt BS, PetscInt EQ>
727: static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst)
728: {
729:   PetscFunctionBegin;
730:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
731:   /* src and dst are contiguous */
732:   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
733:     PetscCallHIP(hipMemcpyAsync((Type *)dst + dstStart * link->bs, (const Type *)src + srcStart * link->bs, count * link->unitbytes, hipMemcpyDeviceToDevice, link->stream));
734:   } else {
735:     PetscCall(ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
736:   }
737:   PetscFunctionReturn(PETSC_SUCCESS);
738: }

740: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
741: static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata, void *leafupdate)
742: {
743:   PetscInt        nthreads = 256;
744:   PetscInt        nblocks  = (count + nthreads - 1) / nthreads;
745:   const PetscInt *rarray   = rootopt ? rootopt->array : NULL;
746:   const PetscInt *larray   = leafopt ? leafopt->array : NULL;

748:   PetscFunctionBegin;
749:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
750:   nblocks = PetscMin(nblocks, link->maxResidentThreadsPerGPU / nthreads);
751:   hipLaunchKernelGGL(HIP_KERNEL_NAME(d_FetchAndOpLocal<Type, Op, BS, EQ>), dim3(nblocks), dim3(nthreads), 0, link->stream, link->bs, count, rootstart, rarray, rootidx, (Type *)rootdata, leafstart, larray, leafidx, (const Type *)leafdata, (Type *)leafupdate);
752:   PetscCallHIP(hipGetLastError());
753:   PetscFunctionReturn(PETSC_SUCCESS);
754: }

756: /*====================================================================================*/
757: /*  Init various types and instantiate pack/unpack function pointers                  */
758: /*====================================================================================*/
759: template <typename Type, PetscInt BS, PetscInt EQ>
760: static void PackInit_RealType(PetscSFLink link)
761: {
762:   /* Pack/unpack for remote communication */
763:   link->d_Pack            = Pack<Type, BS, EQ>;
764:   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
765:   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
766:   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
767:   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
768:   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
769:   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;

771:   /* Scatter for local communication */
772:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
773:   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
774:   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
775:   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
776:   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
777:   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;

779:   /* Atomic versions when there are data-race possibilities */
780:   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
781:   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
782:   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
783:   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
784:   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
785:   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicAdd<Type>, BS, EQ>;

787:   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
788:   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
789:   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
790:   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
791:   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
792:   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicAdd<Type>, BS, EQ>;
793: }

795: /* We use a templated class here so that it can be specialized for char-sized integers */
796: template <typename Type, PetscInt BS, PetscInt EQ, PetscInt size /*sizeof(Type)*/>
797: struct PackInit_IntegerType_Atomic {
798:   static void Init(PetscSFLink link)
799:   {
800:     link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
801:     link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
802:     link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
803:     link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
804:     link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
805:     link->da_UnpackAndLAND   = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
806:     link->da_UnpackAndLOR    = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
807:     link->da_UnpackAndLXOR   = UnpackAndOp<Type, AtomicLXOR<Type>, BS, EQ>;
808:     link->da_UnpackAndBAND   = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
809:     link->da_UnpackAndBOR    = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
810:     link->da_UnpackAndBXOR   = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
811:     link->da_FetchAndAdd     = FetchAndOp<Type, AtomicAdd<Type>, BS, EQ>;

813:     link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
814:     link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
815:     link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
816:     link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
817:     link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
818:     link->da_ScatterAndLAND   = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
819:     link->da_ScatterAndLOR    = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
820:     link->da_ScatterAndLXOR   = ScatterAndOp<Type, AtomicLXOR<Type>, BS, EQ>;
821:     link->da_ScatterAndBAND   = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
822:     link->da_ScatterAndBOR    = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
823:     link->da_ScatterAndBXOR   = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
824:     link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicAdd<Type>, BS, EQ>;
825:   }
826: };

828: /* See the CUDA version */
829: template <typename Type, PetscInt BS, PetscInt EQ>
830: struct PackInit_IntegerType_Atomic<Type, BS, EQ, 1> {
831:   static void Init(PetscSFLink link)
832:   { /* Nothing to do; the function pointers are left NULL since atomics on 1-byte types are not supported */
833:   }
834: };

836: template <typename Type, PetscInt BS, PetscInt EQ>
837: static void PackInit_IntegerType(PetscSFLink link)
838: {
839:   link->d_Pack            = Pack<Type, BS, EQ>;
840:   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
841:   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
842:   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
843:   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
844:   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
845:   link->d_UnpackAndLAND   = UnpackAndOp<Type, LAND<Type>, BS, EQ>;
846:   link->d_UnpackAndLOR    = UnpackAndOp<Type, LOR<Type>, BS, EQ>;
847:   link->d_UnpackAndLXOR   = UnpackAndOp<Type, LXOR<Type>, BS, EQ>;
848:   link->d_UnpackAndBAND   = UnpackAndOp<Type, BAND<Type>, BS, EQ>;
849:   link->d_UnpackAndBOR    = UnpackAndOp<Type, BOR<Type>, BS, EQ>;
850:   link->d_UnpackAndBXOR   = UnpackAndOp<Type, BXOR<Type>, BS, EQ>;
851:   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;

853:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
854:   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
855:   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
856:   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
857:   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
858:   link->d_ScatterAndLAND   = ScatterAndOp<Type, LAND<Type>, BS, EQ>;
859:   link->d_ScatterAndLOR    = ScatterAndOp<Type, LOR<Type>, BS, EQ>;
860:   link->d_ScatterAndLXOR   = ScatterAndOp<Type, LXOR<Type>, BS, EQ>;
861:   link->d_ScatterAndBAND   = ScatterAndOp<Type, BAND<Type>, BS, EQ>;
862:   link->d_ScatterAndBOR    = ScatterAndOp<Type, BOR<Type>, BS, EQ>;
863:   link->d_ScatterAndBXOR   = ScatterAndOp<Type, BXOR<Type>, BS, EQ>;
864:   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
865:   PackInit_IntegerType_Atomic<Type, BS, EQ, sizeof(Type)>::Init(link);
866: }

868: #if defined(PETSC_HAVE_COMPLEX)
869: template <typename Type, PetscInt BS, PetscInt EQ>
870: static void PackInit_ComplexType(PetscSFLink link)
871: {
872:   link->d_Pack            = Pack<Type, BS, EQ>;
873:   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
874:   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
875:   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
876:   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;

878:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
879:   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
880:   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
881:   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;

883:   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
884:   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
885:   link->da_UnpackAndMult   = NULL; /* Not implemented yet */
886:   link->da_FetchAndAdd     = NULL; /* Return value of atomicAdd on complex is not atomic */

888:   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
889:   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
890: }
891: #endif

893: typedef signed char   SignedChar;
894: typedef unsigned char UnsignedChar;
895: typedef struct {
896:   int a;
897:   int b;
898: } PairInt;
899: typedef struct {
900:   PetscInt a;
901:   PetscInt b;
902: } PairPetscInt;

904: template <typename Type>
905: static void PackInit_PairType(PetscSFLink link)
906: {
907:   link->d_Pack            = Pack<Type, 1, 1>;
908:   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>;
909:   link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>;
910:   link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>;

912:   link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>;
913:   link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>;
914:   link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>;
915:   /* Atomics for pair types are not implemented yet */
916: }
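
A host-side sketch of the MPI MAXLOC tie-breaking rule implemented by Maxloc above: on equal values, the smaller index wins (hypothetical standalone code; Pair mirrors PairInt).

#include <algorithm>
#include <cstdio>

struct Pair {
  int a; /* value */
  int b; /* index */
};

static Pair maxloc(Pair x, Pair y)
{
  if (y.a > x.a) return y;
  if (y.a == x.a) x.b = std::min(x.b, y.b); /* tie: keep the smaller index */
  return x;
}

int main()
{
  Pair p = maxloc({7, 3}, {7, 1});
  printf("a=%d b=%d\n", p.a, p.b); /* a=7 b=1: tie broken toward index 1 */
  return 0;
}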

918: template <typename Type, PetscInt BS, PetscInt EQ>
919: static void PackInit_DumbType(PetscSFLink link)
920: {
921:   link->d_Pack             = Pack<Type, BS, EQ>;
922:   link->d_UnpackAndInsert  = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
923:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
924:   /* Atomics for dumb types are not implemented yet */
925: }

927: /* Some device-specific utilities */
928: static PetscErrorCode PetscSFLinkSyncDevice_HIP(PetscSFLink link)
929: {
930:   PetscFunctionBegin;
931:   PetscCallHIP(hipDeviceSynchronize());
932:   PetscFunctionReturn(PETSC_SUCCESS);
933: }

935: static PetscErrorCode PetscSFLinkSyncStream_HIP(PetscSFLink link)
936: {
937:   PetscFunctionBegin;
938:   PetscCallHIP(hipStreamSynchronize(link->stream));
939:   PetscFunctionReturn(PETSC_SUCCESS);
940: }

942: static PetscErrorCode PetscSFLinkMemcpy_HIP(PetscSFLink link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n)
943: {
944:   PetscFunctionBegin;
945:   enum hipMemcpyKind kinds[2][2] = {
946:     {hipMemcpyHostToHost,   hipMemcpyHostToDevice  },
947:     {hipMemcpyDeviceToHost, hipMemcpyDeviceToDevice}
948:   };

950:   if (n) {
951:     if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) { /* Handle HostToHost separately so that pure-CPU code does not call the HIP runtime */
952:       PetscCall(PetscMemcpy(dst, src, n));
953:     } else {
954:       int stype = PetscMemTypeDevice(srcmtype) ? 1 : 0;
955:       int dtype = PetscMemTypeDevice(dstmtype) ? 1 : 0;
956:       PetscCallHIP(hipMemcpyAsync(dst, src, n, kinds[stype][dtype], link->stream));
957:     }
958:   }
959:   PetscFunctionReturn(PETSC_SUCCESS);
960: }
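
The 2x2 kinds table above is indexed as kinds[src][dst] with 0 = host and 1 = device. A host-only sketch of the dispatch (no HIP required; the strings stand in for the hipMemcpyKind values):

#include <cstdio>

int main()
{
  const char *kinds[2][2] = {
    {"HostToHost",   "HostToDevice"  },
    {"DeviceToHost", "DeviceToDevice"}
  };
  for (int s = 0; s < 2; s++)
    for (int d = 0; d < 2; d++) printf("src=%-4s dst=%-4s -> hipMemcpy%s\n", s ? "dev" : "host", d ? "dev" : "host", kinds[s][d]);
  return 0;
}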

962: PetscErrorCode PetscSFMalloc_HIP(PetscMemType mtype, size_t size, void **ptr)
963: {
964:   PetscFunctionBegin;
965:   if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
966:   else if (PetscMemTypeDevice(mtype)) {
967:     PetscCall(PetscDeviceInitialize(PETSC_DEVICE_HIP));
968:     PetscCallHIP(hipMalloc(ptr, size));
969:   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
970:   PetscFunctionReturn(PETSC_SUCCESS);
971: }

973: PetscErrorCode PetscSFFree_HIP(PetscMemType mtype, void *ptr)
974: {
975:   PetscFunctionBegin;
976:   if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
977:   else if (PetscMemTypeDevice(mtype)) PetscCallHIP(hipFree(ptr));
978:   else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
979:   PetscFunctionReturn(PETSC_SUCCESS);
980: }

982: /* Destructor used when the link uses MPI for communication on a HIP device */
983: static PetscErrorCode PetscSFLinkDestroy_MPI_HIP(PetscSF sf, PetscSFLink link)
984: {
985:   PetscFunctionBegin;
986:   for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
987:     PetscCallHIP(hipFree(link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
988:     PetscCallHIP(hipFree(link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
989:   }
990:   PetscFunctionReturn(PETSC_SUCCESS);
991: }

993: /*====================================================================================*/
994: /*                Main driver to init MPI datatype on device                          */
995: /*====================================================================================*/

997: /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what is needed on the device */
998: PetscErrorCode PetscSFLinkSetUp_HIP(PetscSF sf, PetscSFLink link, MPI_Datatype unit)
999: {
1000:   PetscInt  nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
1001:   PetscBool is2Int, is2PetscInt;
1002: #if defined(PETSC_HAVE_COMPLEX)
1003:   PetscInt nPetscComplex = 0;
1004: #endif

1006:   PetscFunctionBegin;
1007:   if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS);
1008:   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
1009:   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
1010:   /* MPI_CHAR is treated below as a dumb type that does not support reduction, per the MPI standard */
1011:   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
1012:   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
1013:   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
1014: #if defined(PETSC_HAVE_COMPLEX)
1015:   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
1016: #endif
1017:   PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
1018:   PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));

1020:   if (is2Int) {
1021:     PackInit_PairType<PairInt>(link);
1022:   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
1023:     PackInit_PairType<PairPetscInt>(link);
1024:   } else if (nPetscReal) {
1025: #if !defined(PETSC_HAVE_DEVICE)
1026:     if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
1027:     else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
1028:     else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
1029:     else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
1030:     else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
1031:     else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
1032:     else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
1033:     else if (nPetscReal % 1 == 0)
1034: #endif
1035:       PackInit_RealType<PetscReal, 1, 0>(link);
1036:   } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
1037: #if !defined(PETSC_HAVE_DEVICE)
1038:     if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
1039:     else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
1040:     else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
1041:     else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
1042:     else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
1043:     else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
1044:     else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
1045:     else if (nPetscInt % 1 == 0)
1046: #endif
1047:       PackInit_IntegerType<llint, 1, 0>(link);
1048:   } else if (nInt) {
1049: #if !defined(PETSC_HAVE_DEVICE)
1050:     if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
1051:     else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
1052:     else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
1053:     else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
1054:     else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
1055:     else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
1056:     else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
1057:     else if (nInt % 1 == 0)
1058: #endif
1059:       PackInit_IntegerType<int, 1, 0>(link);
1060:   } else if (nSignedChar) {
1061: #if !defined(PETSC_HAVE_DEVICE)
1062:     if (nSignedChar == 8) PackInit_IntegerType<SignedChar, 8, 1>(link);
1063:     else if (nSignedChar % 8 == 0) PackInit_IntegerType<SignedChar, 8, 0>(link);
1064:     else if (nSignedChar == 4) PackInit_IntegerType<SignedChar, 4, 1>(link);
1065:     else if (nSignedChar % 4 == 0) PackInit_IntegerType<SignedChar, 4, 0>(link);
1066:     else if (nSignedChar == 2) PackInit_IntegerType<SignedChar, 2, 1>(link);
1067:     else if (nSignedChar % 2 == 0) PackInit_IntegerType<SignedChar, 2, 0>(link);
1068:     else if (nSignedChar == 1) PackInit_IntegerType<SignedChar, 1, 1>(link);
1069:     else if (nSignedChar % 1 == 0)
1070: #endif
1071:       PackInit_IntegerType<SignedChar, 1, 0>(link);
1072:   } else if (nUnsignedChar) {
1073: #if !defined(PETSC_HAVE_DEVICE)
1074:     if (nUnsignedChar == 8) PackInit_IntegerType<UnsignedChar, 8, 1>(link);
1075:     else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<UnsignedChar, 8, 0>(link);
1076:     else if (nUnsignedChar == 4) PackInit_IntegerType<UnsignedChar, 4, 1>(link);
1077:     else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<UnsignedChar, 4, 0>(link);
1078:     else if (nUnsignedChar == 2) PackInit_IntegerType<UnsignedChar, 2, 1>(link);
1079:     else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<UnsignedChar, 2, 0>(link);
1080:     else if (nUnsignedChar == 1) PackInit_IntegerType<UnsignedChar, 1, 1>(link);
1081:     else if (nUnsignedChar % 1 == 0)
1082: #endif
1083:       PackInit_IntegerType<UnsignedChar, 1, 0>(link);
1084: #if defined(PETSC_HAVE_COMPLEX)
1085:   } else if (nPetscComplex) {
1086:   #if !defined(PETSC_HAVE_DEVICE)
1087:     if (nPetscComplex == 8) PackInit_ComplexType<PetscComplex, 8, 1>(link);
1088:     else if (nPetscComplex % 8 == 0) PackInit_ComplexType<PetscComplex, 8, 0>(link);
1089:     else if (nPetscComplex == 4) PackInit_ComplexType<PetscComplex, 4, 1>(link);
1090:     else if (nPetscComplex % 4 == 0) PackInit_ComplexType<PetscComplex, 4, 0>(link);
1091:     else if (nPetscComplex == 2) PackInit_ComplexType<PetscComplex, 2, 1>(link);
1092:     else if (nPetscComplex % 2 == 0) PackInit_ComplexType<PetscComplex, 2, 0>(link);
1093:     else if (nPetscComplex == 1) PackInit_ComplexType<PetscComplex, 1, 1>(link);
1094:     else if (nPetscComplex % 1 == 0)
1095:   #endif
1096:       PackInit_ComplexType<PetscComplex, 1, 0>(link);
1097: #endif
1098:   } else {
1099:     MPI_Aint lb, nbyte;
1100:     PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte));
1101:     PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb);
1102:     if (nbyte % sizeof(int)) { /* If the type size is not a multiple of sizeof(int) */
1103: #if !defined(PETSC_HAVE_DEVICE)
1104:       if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
1105:       else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
1106:       else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
1107:       else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
1108:       else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
1109:       else if (nbyte % 1 == 0)
1110: #endif
1111:         PackInit_DumbType<char, 1, 0>(link);
1112:     } else {
1113:       nInt = nbyte / sizeof(int);
1114: #if !defined(PETSC_HAVE_DEVICE)
1115:       if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
1116:       else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
1117:       else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
1118:       else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
1119:       else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
1120:       else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
1121:       else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
1122:       else if (nInt % 1 == 0)
1123: #endif
1124:         PackInit_DumbType<int, 1, 0>(link);
1125:     }
1126:   }

1128:   if (!sf->maxResidentThreadsPerGPU) { /* Not initialized */
1129:     int                    device;
1130:     struct hipDeviceProp_t props;
1131:     PetscCallHIP(hipGetDevice(&device));
1132:     PetscCallHIP(hipGetDeviceProperties(&props, device));
1133:     sf->maxResidentThreadsPerGPU = props.maxThreadsPerMultiProcessor * props.multiProcessorCount;
1134:   }
1135:   link->maxResidentThreadsPerGPU = sf->maxResidentThreadsPerGPU;

1137:   link->stream       = PetscDefaultHipStream;
1138:   link->Destroy      = PetscSFLinkDestroy_MPI_HIP;
1139:   link->SyncDevice   = PetscSFLinkSyncDevice_HIP;
1140:   link->SyncStream   = PetscSFLinkSyncStream_HIP;
1141:   link->Memcpy       = PetscSFLinkMemcpy_HIP;
1142:   link->deviceinited = PETSC_TRUE;
1143:   PetscFunctionReturn(PETSC_SUCCESS);
1144: }
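
The eight-way ladder repeated above for each primitive type picks the widest BS in {8,4,2,1} that divides the primitive count, with EQ=1 exactly when the count equals BS. Note the ladders are compiled only when PETSC_HAVE_DEVICE is undefined (presumably to limit template instantiations in device builds), so device builds fall through to the generic <Type, 1, 0> case. A distilled host-only sketch of the selection (hypothetical names):

#include <cstdio>

static void select_bs_eq(int n) /* n = number of primitives in the unit */
{
  int BS, EQ;
  if (n == 8) BS = 8, EQ = 1;
  else if (n % 8 == 0) BS = 8, EQ = 0;
  else if (n == 4) BS = 4, EQ = 1;
  else if (n % 4 == 0) BS = 4, EQ = 0;
  else if (n == 2) BS = 2, EQ = 1;
  else if (n % 2 == 0) BS = 2, EQ = 0;
  else if (n == 1) BS = 1, EQ = 1;
  else BS = 1, EQ = 0;
  printf("n=%2d -> BS=%d EQ=%d\n", n, BS, EQ);
}

int main()
{
  select_bs_eq(16); /* -> BS=8 EQ=0 (the 16-PetscReal example at the top of this file) */
  select_bs_eq(8);  /* -> BS=8 EQ=1 */
  select_bs_eq(3);  /* -> BS=1 EQ=0 */
  return 0;
}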