Actual source code: cupmstream.hpp

  1: #ifndef PETSC_CUPMSTREAM_HPP
  2: #define PETSC_CUPMSTREAM_HPP

  4: #include <petsc/private/cupminterface.hpp>

  6: #include "../segmentedmempool.hpp"
  7: #include "cupmevent.hpp"

  9: #if defined(__cplusplus)
 10: namespace Petsc
 11: {

 13: namespace device
 14: {

 16: namespace cupm
 17: {

 19: // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely
 20: // identify separate cupm streams. This is so that the memory pool can accelerate allocation
 21: // calls as it can just pass back a pointer to memory that was used on the same
 22: // stream. Otherwise it must either serialize with another stream or allocate a new chunk.
 23: // Address of the objects does not suffice since cupmStreams are very likely internally reused.

 25: template <DeviceType T>
 26: class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
 27:   using crtp_base_type = StreamBase<CUPMStream<T>>;
 28:   friend crtp_base_type;

 30: public:
 31:   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);

 33:   using stream_type = cupmStream_t;
 34:   using id_type     = typename crtp_base_type::id_type;
 35:   using event_type  = CUPMEvent<T>;
 36:   using flag_type   = unsigned int;

 38:   CUPMStream() noexcept = default;

 40:   PetscErrorCode destroy() noexcept;
 41:   PetscErrorCode create(flag_type) noexcept;
 42:   PetscErrorCode change_type(PetscStreamType) noexcept;

 44: private:
 45:   stream_type stream_{};
 46:   id_type     id_ = new_id_();

 48:   PETSC_NODISCARD static id_type new_id_() noexcept;

 50:   // CRTP implementations
 51:   PETSC_NODISCARD stream_type get_stream_() const noexcept;
 52:   PETSC_NODISCARD id_type     get_id_() const noexcept;
 53:   PetscErrorCode              record_event_(event_type &) const noexcept;
 54:   PetscErrorCode              wait_for_(event_type &) const noexcept;
 55: };

 57: template <DeviceType T>
 58: inline PetscErrorCode CUPMStream<T>::destroy() noexcept
 59: {
 60:   PetscFunctionBegin;
 61:   if (stream_) {
 62:     PetscCallCUPM(cupmStreamDestroy(stream_));
 63:     stream_ = cupmStream_t{};
 64:     id_     = 0;
 65:   }
 66:   PetscFunctionReturn(PETSC_SUCCESS);
 67: }

 69: template <DeviceType T>
 70: inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept
 71: {
 72:   PetscFunctionBegin;
 73:   if (stream_) {
 74:     if (PetscDefined(USE_DEBUG)) {
 75:       flag_type current_flags;

 77:       PetscCallCUPM(cupmStreamGetFlags(stream_, &current_flags));
 78:       PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
 79:     }
 80:     PetscFunctionReturn(PETSC_SUCCESS);
 81:   }
 82:   PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
 83:   id_ = new_id_();
 84:   PetscFunctionReturn(PETSC_SUCCESS);
 85: }

 87: template <DeviceType T>
 88: inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept
 89: {
 90:   PetscFunctionBegin;
 91:   if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) {
 92:     PetscCall(destroy());
 93:   } else {
 94:     const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking;

 96:     if (stream_) {
 97:       flag_type flag;

 99:       PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
100:       if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS);
101:       PetscCall(destroy());
102:     }
103:     PetscCall(create(preferred));
104:   }
105:   PetscFunctionReturn(PETSC_SUCCESS);
106: }

108: template <DeviceType T>
109: inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept
110: {
111:   static id_type id = 0;
112:   return id++;
113: }

115: // CRTP implementations
116: template <DeviceType T>
117: inline typename CUPMStream<T>::stream_type CUPMStream<T>::get_stream_() const noexcept
118: {
119:   return stream_;
120: }

122: template <DeviceType T>
123: inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept
124: {
125:   return id_;
126: }

128: template <DeviceType T>
129: inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept
130: {
131:   PetscFunctionBegin;
132:   PetscCall(event.record(stream_));
133:   PetscFunctionReturn(PETSC_SUCCESS);
134: }

136: template <DeviceType T>
137: inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept
138: {
139:   PetscFunctionBegin;
140:   PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
141:   PetscFunctionReturn(PETSC_SUCCESS);
142: }

144: } // namespace cupm

146: } // namespace device

148: } // namespace Petsc
149: #endif // __cplusplus

151: #endif // PETSC_CUPMSTREAM_HPP