OpenNN
Open-source neural networks library
Loading...
Searching...
No Matches
cuda_gemm.h
Go to the documentation of this file.
1// OpenNN: Open Neural Networks Library
2// www.opennn.net
3//
4// C U D A G E M M M O D U L E
5//
6// Artificial Intelligence Techniques SL
7// artelnics@artelnics.com
8
9#pragma once
10
11#include "tensor_utilities.h"
12
13namespace opennn
14{
15
16#ifdef OPENNN_HAS_CUDA
17
18// C and D share one layout descriptor — they always have the same shape and
19// dtype in this codebase, and cuBLASLt accepts the same layout for both.
20struct LtMatmulPlan
21{
22 cublasLtMatmulDesc_t op_desc = nullptr;
23 cublasLtMatrixLayout_t a_desc = nullptr;
24 cublasLtMatrixLayout_t b_desc = nullptr;
25 cublasLtMatrixLayout_t cd_desc = nullptr;
26 cublasLtMatmulAlgo_t algo{};
27 bool algo_valid = false;
28 size_t workspace_size = 0;
29
30 LtMatmulPlan() = default;
31 LtMatmulPlan(const LtMatmulPlan&) = delete;
32 LtMatmulPlan& operator=(const LtMatmulPlan&) = delete;
33 LtMatmulPlan(LtMatmulPlan&& other) noexcept { *this = move(other); }
34 LtMatmulPlan& operator=(LtMatmulPlan&& other) noexcept
35 {
36 std::swap(op_desc, other.op_desc);
37 std::swap(a_desc, other.a_desc);
38 std::swap(b_desc, other.b_desc);
39 std::swap(cd_desc, other.cd_desc);
40 std::swap(algo, other.algo);
41 std::swap(algo_valid, other.algo_valid);
42 std::swap(workspace_size, other.workspace_size);
43 return *this;
44 }
45 ~LtMatmulPlan()
46 {
47 cublasLtMatrixLayoutDestroy(cd_desc);
48 cublasLtMatrixLayoutDestroy(b_desc);
49 cublasLtMatrixLayoutDestroy(a_desc);
50 cublasLtMatmulDescDestroy(op_desc);
51 }
52};
53
54struct LtMatmulPlanKey
55{
56 int m;
57 int n;
58 int k;
59 int transA;
60 int transB;
61 int epilogue; // cublasLtEpilogue_t cast to int (e.g. BIAS, RELU_BIAS, BGRADA)
62 int io_dtype; // cudaDataType_t for A and B (inputs)
63 int out_dtype; // cudaDataType_t for C and D (outputs)
64
65 bool operator==(const LtMatmulPlanKey&) const noexcept = default;
66};
67
68struct LtMatmulPlanKeyHash
69{
70 size_t operator()(const LtMatmulPlanKey& key) const noexcept
71 {
72 return hash_combine(key.m, key.n, key.k,
73 key.transA, key.transB, key.epilogue,
74 key.io_dtype, key.out_dtype);
75 }
76};
77
// Workspace ceiling (32 MiB) handed to cuBLASLt's heuristic search; it limits
// which algorithms are considered. This is not the amount allocated: the
// actual VRAM allocated only grows to the largest workspace the chosen
// algorithms reported they need (see ensure_cublas_lt_workspace).
constexpr size_t cublas_lt_workspace_search_bytes()
{
    constexpr size_t mebibyte = size_t{1} << 20;
    return 32 * mebibyte;
}
82
83namespace scratch
84{
85
86void* ensure_cublas_lt_workspace(size_t min_bytes = 0);
87
88bfloat16* ensure_bf16_input_scratch(Index n_elements);
89
90bfloat16* ensure_bf16_gradient_scratch(Index n_elements);
91
92float* ensure_fp32_upcast_scratch(Index n_elements);
93
94void* ensure_cudnn_conv_workspace(size_t min_bytes);
95
96}
97
98const void* data_for_gemm_dtype(const TensorView& input, Type target_type);
99
100const LtMatmulPlan& get_lt_gemm_plan(
101 int m, int n, int k,
102 cublasOperation_t transA,
103 cublasOperation_t transB,
104 cublasLtEpilogue_t epilogue,
105 cudaDataType_t io_dtype = CUDA_R_32F,
106 cudaDataType_t out_dtype = CUDA_R_32F);
107
108inline void run_lt_matmul(const LtMatmulPlan& plan,
109 const void* a_data, const void* b_data, void* c_data,
110 const void* bias_pointer)
111{
112 CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(plan.op_desc,
113 CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_pointer, sizeof(bias_pointer)));
114
115 CHECK_CUBLAS(cublasLtMatmul(Backend::get_cublas_lt_handle(),
116 plan.op_desc,
117 &one,
118 a_data, plan.a_desc,
119 b_data, plan.b_desc,
120 &zero,
121 c_data, plan.cd_desc,
122 c_data, plan.cd_desc,
123 plan.algo_valid ? &plan.algo : nullptr,
124 scratch::ensure_cublas_lt_workspace(plan.workspace_size), plan.workspace_size,
126}
127
128// CUBLAS_COMPUTE_DTYPE (= CUBLAS_COMPUTE_32F_FAST_TF32) is FP32-input only;
129// BF16 inputs require plain CUBLAS_COMPUTE_32F.
130inline cublasComputeType_t gemm_compute_type(cudaDataType_t a_type, cudaDataType_t b_type = CUDA_R_32F)
131{
132 return (a_type == CUDA_R_16BF || b_type == CUDA_R_16BF)
135}
136
137inline void gemm_cuda(cublasOperation_t transa, cublasOperation_t transb,
138 int m, int n, int k,
139 const void* A, cudaDataType_t Atype, int lda,
140 const void* B, cudaDataType_t Btype, int ldb,
141 void* C, cudaDataType_t Ctype, int ldc,
142 float alpha = 1.0f, float beta = 0.0f)
143{
144 const cublasComputeType_t compute = gemm_compute_type(Atype, Btype);
145 CHECK_CUBLAS(cublasGemmEx(Backend::get_cublas_handle(),
146 transa, transb,
147 m, n, k,
148 &alpha,
149 A, Atype, lda,
150 B, Btype, ldb,
151 &beta,
152 C, Ctype, ldc,
153 compute,
154 CUBLAS_GEMM_DEFAULT));
155}
156
157inline void gemm_strided_batched_cuda(cublasOperation_t transa, cublasOperation_t transb,
158 int m, int n, int k,
159 const void* A, int lda, long long stride_a,
160 const void* B, int ldb, long long stride_b,
161 void* C, int ldc, long long stride_c,
162 int batch_count,
163 cudaDataType_t io_dtype = CUDA_R_32F,
164 float alpha = 1.0f, float beta = 0.0f)
165{
166 const cublasComputeType_t compute = gemm_compute_type(io_dtype);
167 CHECK_CUBLAS(cublasGemmStridedBatchedEx(Backend::get_cublas_handle(),
168 transa, transb,
169 m, n, k,
170 &alpha,
171 A, io_dtype, lda, stride_a,
172 B, io_dtype, ldb, stride_b,
173 &beta,
174 C, io_dtype, ldc, stride_c,
175 batch_count,
176 compute,
177 CUBLAS_GEMM_DEFAULT));
178}
179
180#endif // OPENNN_HAS_CUDA
181
182}
183
184// OpenNN: Open Neural Networks Library.
185// Copyright(C) 2005-2026 Artificial Intelligence Techniques, SL.
186// Licensed under the GNU Lesser General Public License v2.1 or later.
static cudaStream_t get_compute_stream()
Default CUDA stream used by the compute backend.
Definition tensor_utilities.h:510
static cublasHandle_t get_cublas_handle()
Shared cuBLAS handle for legacy GEMM calls.
Definition tensor_utilities.h:504
static cublasLtHandle_t get_cublas_lt_handle()
Shared cuBLASLt handle for batched/tuned GEMMs.
Definition tensor_utilities.h:506
Definition adaptive_moment_estimation.h:14
size_t hash_combine(const Vs &... values)
Boost-style hash combine that mixes any number of hashable values into a single size_t.
Definition tensor_utilities.h:482
constexpr cublasComputeType_t CUBLAS_COMPUTE_DTYPE
Definition tensor_utilities.h:38
Type
Numeric precision used for training or inference tensors.
Definition configuration.h:20
__nv_bfloat16 bfloat16
Definition pch.h:145
cudaDataType_t
Definition pch.h:93
@ CUDA_R_32F
Definition pch.h:93
@ CUDA_R_16BF
Definition pch.h:93
cublasLtEpilogue_t
Definition pch.h:96
cublasOperation_t
Definition pch.h:95
cublasComputeType_t
Definition pch.h:94
@ CUBLAS_COMPUTE_32F
Definition pch.h:94
Non-owning view over a tensor: pointer, shape, and data type with rich reshape helpers.
Definition tensor_utilities.h:293