Actual source code: cupmstream.hpp
1: #ifndef PETSC_CUPMSTREAM_HPP
2: #define PETSC_CUPMSTREAM_HPP
4: #include <petsc/private/cupminterface.hpp>
6: #include "../segmentedmempool.hpp"
7: #include "cupmevent.hpp"
9: #if defined(__cplusplus)
10: namespace Petsc
11: {
13: namespace device
14: {
16: namespace cupm
17: {
// A bare wrapper around a cupmStream_t. It exists because we need to uniquely identify
// separate cupm streams. This lets the memory pool accelerate allocation calls, as it can
// simply pass back a pointer to memory that was used on the same stream. Otherwise it must
// either serialize with another stream or allocate a new chunk. The addresses of the objects
// do not suffice as identifiers since cupmStreams are very likely internally reused.
25: template <DeviceType T>
26: class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
27: using crtp_base_type = StreamBase<CUPMStream<T>>;
28: friend crtp_base_type;
30: public:
31: PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
33: using stream_type = cupmStream_t;
34: using id_type = typename crtp_base_type::id_type;
35: using event_type = CUPMEvent<T>;
36: using flag_type = unsigned int;
38: CUPMStream() noexcept = default;
40: PetscErrorCode destroy() noexcept;
41: PetscErrorCode create(flag_type) noexcept;
42: PetscErrorCode change_type(PetscStreamType) noexcept;
44: private:
45: stream_type stream_{};
46: id_type id_ = new_id_();
48: PETSC_NODISCARD static id_type new_id_() noexcept;
50: // CRTP implementations
51: PETSC_NODISCARD stream_type get_stream_() const noexcept;
52: PETSC_NODISCARD id_type get_id_() const noexcept;
53: PetscErrorCode record_event_(event_type &) const noexcept;
54: PetscErrorCode wait_for_(event_type &) const noexcept;
55: };
57: template <DeviceType T>
58: inline PetscErrorCode CUPMStream<T>::destroy() noexcept
59: {
60: PetscFunctionBegin;
61: if (stream_) {
62: PetscCallCUPM(cupmStreamDestroy(stream_));
63: stream_ = cupmStream_t{};
64: id_ = 0;
65: }
66: PetscFunctionReturn(PETSC_SUCCESS);
67: }
69: template <DeviceType T>
70: inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept
71: {
72: PetscFunctionBegin;
73: if (stream_) {
74: if (PetscDefined(USE_DEBUG)) {
75: flag_type current_flags;
77: PetscCallCUPM(cupmStreamGetFlags(stream_, ¤t_flags));
78: PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
79: }
80: PetscFunctionReturn(PETSC_SUCCESS);
81: }
82: PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
83: id_ = new_id_();
84: PetscFunctionReturn(PETSC_SUCCESS);
85: }
87: template <DeviceType T>
88: inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept
89: {
90: PetscFunctionBegin;
91: if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) {
92: PetscCall(destroy());
93: } else {
94: const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking;
96: if (stream_) {
97: flag_type flag;
99: PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
100: if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS);
101: PetscCall(destroy());
102: }
103: PetscCall(create(preferred));
104: }
105: PetscFunctionReturn(PETSC_SUCCESS);
106: }
108: template <DeviceType T>
109: inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept
110: {
111: static id_type id = 0;
112: return id++;
113: }
115: // CRTP implementations
116: template <DeviceType T>
117: inline typename CUPMStream<T>::stream_type CUPMStream<T>::get_stream_() const noexcept
118: {
119: return stream_;
120: }
122: template <DeviceType T>
123: inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept
124: {
125: return id_;
126: }
128: template <DeviceType T>
129: inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept
130: {
131: PetscFunctionBegin;
132: PetscCall(event.record(stream_));
133: PetscFunctionReturn(PETSC_SUCCESS);
134: }
136: template <DeviceType T>
137: inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept
138: {
139: PetscFunctionBegin;
140: PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
141: PetscFunctionReturn(PETSC_SUCCESS);
142: }
144: } // namespace cupm
146: } // namespace device
148: } // namespace Petsc
149: #endif // __cplusplus
151: #endif // PETSC_CUPMSTREAM_HPP