20 cublasLtMatmulDesc_t op_desc =
nullptr;
21 cublasLtMatrixLayout_t a_desc =
nullptr;
22 cublasLtMatrixLayout_t b_desc =
nullptr;
23 cublasLtMatrixLayout_t c_desc =
nullptr;
24 cublasLtMatrixLayout_t d_desc =
nullptr;
25 cublasLtMatmulAlgo_t algo{};
26 bool algo_valid =
false;
27 size_t workspace_size = 0;
29 LtMatmulPlan() =
default;
30 LtMatmulPlan(
const LtMatmulPlan&) =
delete;
31 LtMatmulPlan& operator=(
const LtMatmulPlan&) =
delete;
32 LtMatmulPlan(LtMatmulPlan&& other)
noexcept { *
this = move(other); }
33 LtMatmulPlan& operator=(LtMatmulPlan&& other)
noexcept
35 swap(op_desc, other.op_desc);
36 swap(a_desc, other.a_desc);
37 swap(b_desc, other.b_desc);
38 swap(c_desc, other.c_desc);
39 swap(d_desc, other.d_desc);
40 swap(algo, other.algo);
41 swap(algo_valid, other.algo_valid);
42 swap(workspace_size, other.workspace_size);
47 cublasLtMatrixLayoutDestroy(d_desc);
48 cublasLtMatrixLayoutDestroy(c_desc);
49 cublasLtMatrixLayoutDestroy(b_desc);
50 cublasLtMatrixLayoutDestroy(a_desc);
51 cublasLtMatmulDescDestroy(op_desc);
66 bool operator==(
const LtMatmulPlanKey& other)
const noexcept
68 return m == other.m && n == other.n && k == other.k
69 && transA == other.transA && transB == other.transB
70 && epilogue == other.epilogue
71 && io_dtype == other.io_dtype && out_dtype == other.out_dtype;
75struct LtMatmulPlanKeyHash
77 size_t operator()(
const LtMatmulPlanKey& key)
const noexcept
80 key.transA, key.transB, key.epilogue,
81 key.io_dtype, key.out_dtype);
// Returns the compile-time constant 32 MiB (32 * 1024 * 1024 = 33,554,432 bytes).
// The `ull` suffix makes the multiplication happen in unsigned long long, so
// it cannot overflow a 32-bit int.
// NOTE(review): presumably the workspace size offered to cuBLASLt when
// searching for an algorithm -- confirm at the call sites.
constexpr size_t cublas_lt_workspace_search_bytes()
{
    return 32ull * 1024 * 1024;
}
93void* ensure_cublas_lt_workspace(
size_t min_bytes = 0);
99float* ensure_fp32_upcast_scratch(Index n_elements);
101float* get_loss_scratch(Index n_elements);
106const LtMatmulPlan& get_lt_gemm_plan(
119 float alpha = 1.0f,
float beta = 0.0f)
134 CUBLAS_GEMM_DEFAULT));
139 const void* A,
int lda,
long long stride_a,
140 const void* B,
int ldb,
long long stride_b,
141 void* C,
int ldc,
long long stride_c,
144 float alpha = 1.0f,
float beta = 0.0f)
153 A, io_dtype, lda, stride_a,
154 B, io_dtype, ldb, stride_b,
156 C, io_dtype, ldc, stride_c,
159 CUBLAS_GEMM_DEFAULT));
static cublasHandle_t get_cublas_handle()
Definition tensor_utilities.h:403
size_t hash_combine(const Vs &... values)
Definition neural_network.h:388
Definition adaptive_moment_estimation.h:19
cublasLtEpilogue_t
Definition neural_network.h:85
constexpr cublasComputeType_t CUBLAS_COMPUTE_DTYPE
Definition tensor_utilities.h:43
Type
Definition configuration.h:18
cudaDataType_t
Definition neural_network.h:82
@ CUDA_R_16BF
Definition neural_network.h:82
@ CUDA_R_32F
Definition neural_network.h:82
cublasComputeType_t
Definition neural_network.h:83
@ CUBLAS_COMPUTE_32F
Definition neural_network.h:83
cublasOperation_t
Definition neural_network.h:84
Definition tensor_utilities.h:236
Definition neural_network.h:78