// cuda_gemm.h — OpenNN open-source neural networks library.
1// OpenNN: Open Neural Networks Library
2// www.opennn.net
3//
4// C U D A G E M M M O D U L E
5//
6// Artificial Intelligence Techniques SL
7// artelnics@artelnics.com
8
#pragma once

#include "tensor_utilities.h"

namespace opennn
{

#ifdef OPENNN_HAS_CUDA

// RAII owner of the cuBLASLt descriptors and the heuristically chosen
// algorithm that together form one cached matmul plan.
// Move-only: copying would double-destroy the descriptor handles.
struct LtMatmulPlan
{
    cublasLtMatmulDesc_t op_desc = nullptr;
    cublasLtMatrixLayout_t a_desc = nullptr;
    cublasLtMatrixLayout_t b_desc = nullptr;
    cublasLtMatrixLayout_t c_desc = nullptr;
    cublasLtMatrixLayout_t d_desc = nullptr;
    cublasLtMatmulAlgo_t algo{};
    bool algo_valid = false;
    size_t workspace_size = 0; // bytes the chosen algo actually needs

    LtMatmulPlan() = default;
    LtMatmulPlan(const LtMatmulPlan&) = delete;
    LtMatmulPlan& operator=(const LtMatmulPlan&) = delete;

    // Move is implemented as swap: the moved-from plan takes our (initially
    // null) handles and releases whatever it now holds in its own destructor.
    LtMatmulPlan(LtMatmulPlan&& other) noexcept { *this = move(other); }
    LtMatmulPlan& operator=(LtMatmulPlan&& other) noexcept
    {
        swap(op_desc, other.op_desc);
        swap(a_desc, other.a_desc);
        swap(b_desc, other.b_desc);
        swap(c_desc, other.c_desc);
        swap(d_desc, other.d_desc);
        swap(algo, other.algo);
        swap(algo_valid, other.algo_valid);
        swap(workspace_size, other.workspace_size);
        return *this;
    }

    // Destroy in reverse creation order. Guard against null handles so that
    // destroying a default-constructed or moved-from plan never hands nullptr
    // to cuBLASLt (the destroy calls report CUBLAS_STATUS_INVALID_VALUE for
    // null descriptors, and the return values are ignored here anyway).
    ~LtMatmulPlan()
    {
        if(d_desc) cublasLtMatrixLayoutDestroy(d_desc);
        if(c_desc) cublasLtMatrixLayoutDestroy(c_desc);
        if(b_desc) cublasLtMatrixLayoutDestroy(b_desc);
        if(a_desc) cublasLtMatrixLayoutDestroy(a_desc);
        if(op_desc) cublasLtMatmulDescDestroy(op_desc);
    }
};
54
// Cache key identifying one GEMM configuration: problem shape, transpose
// flags, epilogue and the input/output data types. Enum values are stored as
// plain ints so the key stays trivially hashable and comparable.
struct LtMatmulPlanKey
{
    int m;
    int n;
    int k;
    int transA;
    int transB;
    int epilogue; // cublasLtEpilogue_t cast to int (e.g. BIAS, RELU_BIAS, BGRADA)
    int io_dtype; // cudaDataType_t for A and B (inputs)
    int out_dtype; // cudaDataType_t for C and D (outputs)

    // Two keys are equal only when every field matches.
    bool operator==(const LtMatmulPlanKey& rhs) const noexcept
    {
        return m == rhs.m
            && n == rhs.n
            && k == rhs.k
            && transA == rhs.transA
            && transB == rhs.transB
            && epilogue == rhs.epilogue
            && io_dtype == rhs.io_dtype
            && out_dtype == rhs.out_dtype;
    }
};
74
// Hash functor so LtMatmulPlanKey can index an unordered plan cache.
struct LtMatmulPlanKeyHash
{
    size_t operator()(const LtMatmulPlanKey& k) const noexcept
    {
        // The argument order feeds the fold in hash_combine, so it must stay
        // stable — reordering would silently invalidate cached hashes.
        return hash_combine(k.m, k.n, k.k,
                            k.transA, k.transB, k.epilogue,
                            k.io_dtype, k.out_dtype);
    }
};
84
// Upper bound (32 MiB) passed to cuBLASLt's heuristic search — limits which
// algorithms are considered. The actual VRAM allocated only grows to the max
// workspace the chosen algorithms reported needing (see ensure_cublas_lt_workspace).
constexpr size_t cublas_lt_workspace_search_bytes() { return size_t{32} << 20; }
89
// Grows the global cublasLt scratch buffer to at least `min_bytes`. Returns a
// pointer to it. Initial size is 0 — the buffer only grows when a plan whose
// chosen algorithm needs more workspace gets created.
void* ensure_cublas_lt_workspace(size_t min_bytes = 0);

// NOTE(review): the ensure_* helpers below appear to return grow-on-demand
// device scratch buffers sized for at least n_elements, mirroring the
// workspace above — confirm against their definitions in the .cu file.

// BF16 scratch; presumably staging for downcast copies of FP32 inputs.
__nv_bfloat16* ensure_bf16_input_scratch(Index n_elements);

// BF16 scratch; presumably staging for downcast copies of gradients.
__nv_bfloat16* ensure_bf16_gradient_scratch(Index n_elements);

// FP32 scratch; presumably staging for upcasts of reduced-precision data.
float* ensure_fp32_upcast_scratch(Index n_elements);

// FP32 scratch used during loss computation (semantics live in the .cu file).
float* get_loss_scratch(Index n_elements);

// Returns input's data as-is or converted to target_type — presumably via the
// scratch buffers above; verify exact casting rules in the implementation.
const void* maybe_cast(const TensorView& input, Type target_type);
104
105
// Returns an LtMatmulPlan for the given GEMM configuration, with FP32 as the
// default input and output data type.
// NOTE(review): the const-reference return suggests plans are cached (keyed
// by LtMatmulPlanKey) and built on first use — confirm lifetime of the
// returned reference against the implementation before holding it long-term.
const LtMatmulPlan& get_lt_gemm_plan(
    int m, int n, int k,
    cublasOperation_t transA,
    cublasOperation_t transB,
    cublasLtEpilogue_t epilogue,
    cudaDataType_t io_dtype = CUDA_R_32F,
    cudaDataType_t out_dtype = CUDA_R_32F);
113
// Thin wrapper over cublasGemmEx (cuBLAS column-major convention):
// C(m x n) = alpha * op(A) * op(B) + beta * C, accumulated in FP32.
// A/B/C may differ in data type; leading dimensions follow cuBLAS rules.
inline void gemm_cuda(cublasOperation_t transa, cublasOperation_t transb,
                      int m, int n, int k,
                      const void* A, cudaDataType_t Atype, int lda,
                      const void* B, cudaDataType_t Btype, int ldb,
                      void* C, cudaDataType_t Ctype, int ldc,
                      float alpha = 1.0f, float beta = 0.0f)
{
    // CUBLAS_COMPUTE_32F_FAST_TF32 is FP32-input only; for BF16 use plain CUBLAS_COMPUTE_32F.
    // NOTE(review): the extracted source dropped this ternary's arms (original
    // lines 123-124). Reconstructed from the comment above and the referenced
    // CUBLAS_COMPUTE_DTYPE constant (tensor_utilities.h) — confirm against the
    // repository.
    const cublasComputeType_t compute = (Atype == CUDA_R_16BF || Btype == CUDA_R_16BF)
        ? CUBLAS_COMPUTE_32F
        : CUBLAS_COMPUTE_DTYPE;

    CHECK_CUBLAS(cublasGemmEx(Backend::get_cublas_handle(),
                              transa, transb,
                              m, n, k,
                              &alpha,
                              A, Atype, lda,
                              B, Btype, ldb,
                              &beta,
                              C, Ctype, ldc,
                              compute,
                              CUBLAS_GEMM_DEFAULT));
}
136
// Batched GEMM over `batch_count` problems laid out at fixed strides:
// C_i = alpha * op(A_i) * op(B_i) + beta * C_i for i in [0, batch_count),
// where X_i = X + i * stride_x. A, B and C share one data type (io_dtype);
// accumulation is FP32. Column-major (cuBLAS) convention.
inline void gemm_strided_batched_cuda(cublasOperation_t transa, cublasOperation_t transb,
                                      int m, int n, int k,
                                      const void* A, int lda, long long stride_a,
                                      const void* B, int ldb, long long stride_b,
                                      void* C, int ldc, long long stride_c,
                                      int batch_count,
                                      cudaDataType_t io_dtype = CUDA_R_32F,
                                      float alpha = 1.0f, float beta = 0.0f)
{
    // BF16 inputs require plain CUBLAS_COMPUTE_32F (TF32 fast path is
    // FP32-input only — see gemm_cuda).
    // NOTE(review): the extracted source dropped this ternary's arms (original
    // lines 147-148); reconstructed to match gemm_cuda — confirm against the
    // repository.
    const cublasComputeType_t compute = (io_dtype == CUDA_R_16BF)
        ? CUBLAS_COMPUTE_32F
        : CUBLAS_COMPUTE_DTYPE;

    CHECK_CUBLAS(cublasGemmStridedBatchedEx(Backend::get_cublas_handle(),
                                            transa, transb,
                                            m, n, k,
                                            &alpha,
                                            A, io_dtype, lda, stride_a,
                                            B, io_dtype, ldb, stride_b,
                                            &beta,
                                            C, io_dtype, ldc, stride_c,
                                            batch_count,
                                            compute,
                                            CUBLAS_GEMM_DEFAULT));
}
161
#endif // OPENNN_HAS_CUDA

}
165
// OpenNN: Open Neural Networks Library.
// Copyright(C) 2005-2026 Artificial Intelligence Techniques, SL.
// Licensed under the GNU Lesser General Public License v2.1 or later.
static cublasHandle_t get_cublas_handle()
Definition tensor_utilities.h:403
size_t hash_combine(const Vs &... values)
Definition neural_network.h:388
Definition adaptive_moment_estimation.h:19
cublasLtEpilogue_t
Definition neural_network.h:85
constexpr cublasComputeType_t CUBLAS_COMPUTE_DTYPE
Definition tensor_utilities.h:43
Type
Definition configuration.h:18
cudaDataType_t
Definition neural_network.h:82
@ CUDA_R_16BF
Definition neural_network.h:82
@ CUDA_R_32F
Definition neural_network.h:82
cublasComputeType_t
Definition neural_network.h:83
@ CUBLAS_COMPUTE_32F
Definition neural_network.h:83
cublasOperation_t
Definition neural_network.h:84
Definition tensor_utilities.h:236
Definition neural_network.h:78