Actual source code: curand.c

  1: #include <petsc/private/deviceimpl.h>
  2: #include <petsc/private/randomimpl.h>
  3: #include <petscdevice_cuda.h>
  4: #include <curand.h>

  6: typedef struct {
  7:   curandGenerator_t gen;
  8: } PetscRandom_CURAND;

 10: PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
 11: {
 12:   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;

 14:   PetscFunctionBegin;
 15:   PetscCallCURAND(curandSetPseudoRandomGeneratorSeed(curand->gen, r->seed));
 16:   PetscFunctionReturn(PETSC_SUCCESS);
 17: }

 19: PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom, size_t, PetscReal *, PetscBool);

 21: PetscErrorCode PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
 22: {
 23:   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
 24:   size_t              nn     = n < 0 ? (size_t)(-2 * n) : n; /* handle complex case */

 26:   PetscFunctionBegin;
 27: #if defined(PETSC_USE_REAL_SINGLE)
 28:   PetscCallCURAND(curandGenerateUniform(curand->gen, val, nn));
 29: #else
 30:   PetscCallCURAND(curandGenerateUniformDouble(curand->gen, val, nn));
 31: #endif
 32:   if (r->iset) PetscCall(PetscRandomCurandScale_Private(r, nn, val, (PetscBool)(n < 0)));
 33:   PetscFunctionReturn(PETSC_SUCCESS);
 34: }

 36: PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
 37: {
 38:   PetscFunctionBegin;
 39: #if defined(PETSC_USE_COMPLEX)
 40:   /* pass negative size to flag complex scaling (if needed) */
 41:   PetscCall(PetscRandomGetValuesReal_CURAND(r, -n, (PetscReal *)val));
 42: #else
 43:   PetscCall(PetscRandomGetValuesReal_CURAND(r, n, val));
 44: #endif
 45:   PetscFunctionReturn(PETSC_SUCCESS);
 46: }

 48: PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
 49: {
 50:   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;

 52:   PetscFunctionBegin;
 53:   PetscCallCURAND(curandDestroyGenerator(curand->gen));
 54:   PetscCall(PetscFree(r->data));
 55:   PetscFunctionReturn(PETSC_SUCCESS);
 56: }

 58: static struct _PetscRandomOps PetscRandomOps_Values = {
 59:   PetscDesignatedInitializer(seed, PetscRandomSeed_CURAND),
 60:   PetscDesignatedInitializer(getvalue, NULL),
 61:   PetscDesignatedInitializer(getvaluereal, NULL),
 62:   PetscDesignatedInitializer(getvalues, PetscRandomGetValues_CURAND),
 63:   PetscDesignatedInitializer(getvaluesreal, PetscRandomGetValuesReal_CURAND),
 64:   PetscDesignatedInitializer(destroy, PetscRandomDestroy_CURAND),
 65: };

 67: /*MC
 68:    PETSCCURAND - access to the CUDA random number generator from a `PetscRandom` object

 70:   PETSc must be configured with the option `--with-cuda` to use this random number generator.

 72:   Level: beginner

 74: .seealso: `PetscRandomCreate()`, `PetscRandomSetType()`, `PetscRandomType`
 75: M*/

 77: PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
 78: {
 79:   PetscRandom_CURAND *curand;

 81:   PetscFunctionBegin;
 82:   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
 83:   PetscCall(PetscNew(&curand));
 84:   PetscCallCURAND(curandCreateGenerator(&curand->gen, CURAND_RNG_PSEUDO_DEFAULT));
 85:   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
 86:   PetscCallCURAND(curandSetGeneratorOrdering(curand->gen, CURAND_ORDERING_PSEUDO_SEEDED));
 87:   PetscCall(PetscMemcpy(r->ops, &PetscRandomOps_Values, sizeof(PetscRandomOps_Values)));
 88:   PetscCall(PetscObjectChangeTypeName((PetscObject)r, PETSCCURAND));
 89:   r->data = curand;
 90:   PetscCall(PetscRandomSeed_CURAND(r));
 91:   PetscFunctionReturn(PETSC_SUCCESS);
 92: }