Actual source code: math2opussampler.hpp

  1: #include <petscmat.h>
  2: #include <h2opus.h>

  4: #ifndef __MATH2OPUS_HPP

  7: class PetscMatrixSampler : public HMatrixSampler {
  8: protected:
  9:   Mat                                                                    A;
 10:   typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, H2Opus_Real>::type HRealVector;
 11:   typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, int>::type         HIntVector;
 12:   HIntVector                                                             hindexmap;
 13:   HRealVector                                                            hbuffer_in, hbuffer_out;
 14:   #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
 15:   H2OpusDeviceVector<int>         dindexmap;
 16:   H2OpusDeviceVector<H2Opus_Real> dbuffer_in, dbuffer_out;
 17:   #endif
 18:   bool                  gpusampling;
 19:   h2opusComputeStream_t stream;

 21: private:
 22:   void Init();
 23:   void VerifyBuffers(int);
 24:   void PermuteBuffersIn(int, H2Opus_Real *, H2Opus_Real **, H2Opus_Real *, H2Opus_Real **);
 25:   void PermuteBuffersOut(int, H2Opus_Real *);

 27: public:
 28:   PetscMatrixSampler();
 29:   PetscMatrixSampler(Mat);
 30:   ~PetscMatrixSampler();
 31:   void         SetSamplingMat(Mat);
 32:   void         SetIndexMap(int, int *);
 33:   void         SetGPUSampling(bool);
 34:   void         SetStream(h2opusComputeStream_t);
 35:   virtual void sample(H2Opus_Real *, H2Opus_Real *, int);
 36:   Mat          GetSamplingMat() { return A; }
 37: };

 39: void PetscMatrixSampler::Init()
 40: {
 41:   this->A           = NULL;
 42:   this->gpusampling = false;
 43:   this->stream      = NULL;
 44: }

 46: PetscMatrixSampler::PetscMatrixSampler()
 47: {
 48:   Init();
 49: }

 51: PetscMatrixSampler::PetscMatrixSampler(Mat A)
 52: {
 53:   Init();
 54:   SetSamplingMat(A);
 55: }

 57: void PetscMatrixSampler::SetSamplingMat(Mat A)
 58: {
 59:   PetscMPIInt size = 1;

 61:   if (A) PetscCallVoid(static_cast<PetscErrorCode>(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size)));
 62:   if (size > 1) PetscCallVoid(PETSC_ERR_SUP);
 63:   PetscCallVoid(PetscObjectReference((PetscObject)A));
 64:   PetscCallVoid(MatDestroy(&this->A));
 65:   this->A = A;
 66: }

 68: void PetscMatrixSampler::SetStream(h2opusComputeStream_t stream)
 69: {
 70:   this->stream = stream;
 71: }

 73: void PetscMatrixSampler::SetIndexMap(int n, int *indexmap)
 74: {
 75:   copyVector(this->hindexmap, indexmap, n, H2OPUS_HWTYPE_CPU);
 76:   #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
 77:   copyVector(this->dindexmap, indexmap, n, H2OPUS_HWTYPE_CPU);
 78:   #endif
 79: }

 81: void PetscMatrixSampler::VerifyBuffers(int nv)
 82: {
 83:   if (this->hindexmap.size()) {
 84:     size_t n = this->hindexmap.size();
 85:     if (!this->gpusampling) {
 86:       if (hbuffer_in.size() < (size_t)n * nv) hbuffer_in.resize(n * nv);
 87:       if (hbuffer_out.size() < (size_t)n * nv) hbuffer_out.resize(n * nv);
 88:     } else {
 89:   #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
 90:       if (dbuffer_in.size() < (size_t)n * nv) dbuffer_in.resize(n * nv);
 91:       if (dbuffer_out.size() < (size_t)n * nv) dbuffer_out.resize(n * nv);
 92:   #endif
 93:     }
 94:   }
 95: }

 97: void PetscMatrixSampler::PermuteBuffersIn(int nv, H2Opus_Real *v, H2Opus_Real **w, H2Opus_Real *ov, H2Opus_Real **ow)
 98: {
 99:   *w  = v;
100:   *ow = ov;
101:   VerifyBuffers(nv);
102:   if (this->hindexmap.size()) {
103:     size_t n = this->hindexmap.size();
104:     if (!this->gpusampling) {
105:       permute_vectors(v, this->hbuffer_in.data(), n, nv, this->hindexmap.data(), 1, H2OPUS_HWTYPE_CPU, this->stream);
106:       *w  = this->hbuffer_in.data();
107:       *ow = this->hbuffer_out.data();
108:     } else {
109:   #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
110:       permute_vectors(v, this->dbuffer_in.data(), n, nv, this->dindexmap.data(), 1, H2OPUS_HWTYPE_GPU, this->stream);
111:       *w  = this->dbuffer_in.data();
112:       *ow = this->dbuffer_out.data();
113:   #endif
114:     }
115:   }
116: }

118: void PetscMatrixSampler::PermuteBuffersOut(int nv, H2Opus_Real *v)
119: {
120:   VerifyBuffers(nv);
121:   if (this->hindexmap.size()) {
122:     size_t n = this->hindexmap.size();
123:     if (!this->gpusampling) {
124:       permute_vectors(this->hbuffer_out.data(), v, n, nv, this->hindexmap.data(), 0, H2OPUS_HWTYPE_CPU, this->stream);
125:     } else {
126:   #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
127:       permute_vectors(this->dbuffer_out.data(), v, n, nv, this->dindexmap.data(), 0, H2OPUS_HWTYPE_GPU, this->stream);
128:   #endif
129:     }
130:   }
131: }

133: void PetscMatrixSampler::SetGPUSampling(bool gpusampling)
134: {
135:   this->gpusampling = gpusampling;
136: }

138: PetscMatrixSampler::~PetscMatrixSampler()
139: {
140:   PetscCallVoid(MatDestroy(&A));
141: }

143: void PetscMatrixSampler::sample(H2Opus_Real *x, H2Opus_Real *y, int samples)
144: {
145:   MPI_Comm     comm = PetscObjectComm((PetscObject)this->A);
146:   Mat          X = NULL, Y = NULL;
147:   PetscInt     M, N, m, n;
148:   H2Opus_Real *px, *py;

150:   if (!this->A) PetscCallVoid(PETSC_ERR_PLIB);
151:   PetscCallVoid(MatGetSize(this->A, &M, &N));
152:   PetscCallVoid(MatGetLocalSize(this->A, &m, &n));
153:   PetscCallVoid(PetscObjectGetComm((PetscObject)A, &comm));
154:   PermuteBuffersIn(samples, x, &px, y, &py);
155:   if (!this->gpusampling) {
156:     PetscCallVoid(MatCreateDense(comm, n, PETSC_DECIDE, N, samples, px, &X));
157:     PetscCallVoid(MatCreateDense(comm, m, PETSC_DECIDE, M, samples, py, &Y));
158:   } else {
159:   #if defined(PETSC_HAVE_CUDA)
160:     PetscCallVoid(MatCreateDenseCUDA(comm, n, PETSC_DECIDE, N, samples, px, &X));
161:     PetscCallVoid(MatCreateDenseCUDA(comm, m, PETSC_DECIDE, M, samples, py, &Y));
162:   #endif
163:   }
164:   PetscCallVoid(MatMatMult(this->A, X, MAT_REUSE_MATRIX, PETSC_DEFAULT, &Y));
165:   #if defined(PETSC_HAVE_CUDA)
166:   if (this->gpusampling) {
167:     const PetscScalar *dummy;
168:     PetscCallVoid(MatDenseCUDAGetArrayRead(Y, &dummy));
169:     PetscCallVoid(MatDenseCUDARestoreArrayRead(Y, &dummy));
170:   }
171:   #endif
172:   PermuteBuffersOut(samples, y);
173:   PetscCallVoid(MatDestroy(&X));
174:   PetscCallVoid(MatDestroy(&Y));
175: }

177: #endif