// cuBLASLt matmul operation descriptor. In this listing it is released with
// cublasLtMatmulDescDestroy and receives the bias pointer attribute in run_lt_matmul.
22 cublasLtMatmulDesc_t op_desc =
nullptr;
// Matrix layout handles, each released with cublasLtMatrixLayoutDestroy.
// NOTE(review): a_desc / b_desc presumably describe the two GEMM operands —
// confirm where the plan is built (not visible here).
23 cublasLtMatrixLayout_t a_desc =
nullptr;
24 cublasLtMatrixLayout_t b_desc =
nullptr;
// Layout that run_lt_matmul passes twice, both times paired with c_data.
25 cublasLtMatrixLayout_t cd_desc =
nullptr;
// Algorithm selection. run_lt_matmul passes its address only when algo_valid
// is true and passes nullptr otherwise.
26 cublasLtMatmulAlgo_t algo{};
27 bool algo_valid =
false;
// Byte count handed to ensure_cublas_lt_workspace and forwarded as the
// workspace size in run_lt_matmul.
28 size_t workspace_size = 0;
// Default construction relies on the member initializers: all descriptor
// handles null, algo_valid false, workspace_size zero.
30 LtMatmulPlan() =
default;
// Copy construction and copy assignment are disabled; a plan can only be moved.
// NOTE(review): presumably because the descriptor handles must be destroyed
// exactly once — confirm against the destructor.
31 LtMatmulPlan(
const LtMatmulPlan&) =
delete;
32 LtMatmulPlan& operator=(
const LtMatmulPlan&) =
delete;
// Move constructor. The new object starts from the default member initializers
// (null handles, algo_valid == false, workspace_size == 0) and then delegates
// to the move assignment operator, whose visible statements swap every field
// with `other`.
// `std::move` is qualified so the rvalue cast does not depend on a
// using-directive for namespace std being in scope; the swaps in this class
// already use the qualified `std::swap`.
33 LtMatmulPlan(LtMatmulPlan&& other) noexcept { *this = std::move(other); }
34 LtMatmulPlan& operator=(LtMatmulPlan&& other)
noexcept
36 std::swap(op_desc, other.op_desc);
37 std::swap(a_desc, other.a_desc);
38 std::swap(b_desc, other.b_desc);
39 std::swap(cd_desc, other.cd_desc);
40 std::swap(algo, other.algo);
41 std::swap(algo_valid, other.algo_valid);
42 std::swap(workspace_size, other.workspace_size);
47 cublasLtMatrixLayoutDestroy(cd_desc);
48 cublasLtMatrixLayoutDestroy(b_desc);
49 cublasLtMatrixLayoutDestroy(a_desc);
50 cublasLtMatmulDescDestroy(op_desc);
// Defaulted equality: two keys compare equal when all of their members compare
// equal, in declaration order. Declared noexcept.
// NOTE(review): presumably paired with LtMatmulPlanKeyHash for use as a hash-map
// key — confirm at the container declaration.
65 bool operator==(
const LtMatmulPlanKey&)
const noexcept =
default;
68struct LtMatmulPlanKeyHash
70 size_t operator()(
const LtMatmulPlanKey& key)
const noexcept
73 key.transA, key.transB, key.epilogue,
74 key.io_dtype, key.out_dtype);
// Returns a fixed byte count of 32 * 1024 * 1024 (32 MiB), usable in constant
// expressions.
// NOTE(review): the name suggests this is the workspace size offered during
// cuBLASLt algorithm search — confirm at the call site, which is not visible here.
81constexpr size_t cublas_lt_workspace_search_bytes() {
return 32ull * 1024 * 1024; }
// Declarations only; the definitions are not visible here.
// run_lt_matmul calls scratch::ensure_cublas_lt_workspace(plan.workspace_size)
// and passes the returned pointer as the cuBLASLt workspace.
// NOTE(review): presumably returns a buffer of at least min_bytes, with the
// default of 0 meaning "whatever is currently allocated" — confirm in the definition.
86void* ensure_cublas_lt_workspace(
size_t min_bytes = 0);
// NOTE(review): the three helpers below presumably return scratch buffers
// holding at least n_elements elements of the returned pointer's element type
// (bfloat16 or float) — confirm capacity and lifetime rules in the definitions.
88bfloat16* ensure_bf16_input_scratch(Index n_elements);
90bfloat16* ensure_bf16_gradient_scratch(Index n_elements);
92float* ensure_fp32_upcast_scratch(Index n_elements);
// NOTE(review): presumably the cuDNN convolution counterpart of the cuBLASLt
// workspace helper; unlike that one, min_bytes has no default here.
94void* ensure_cudnn_conv_workspace(
size_t min_bytes);
// Declaration only. Takes a tensor view and a target Type, and returns an
// untyped read-only pointer.
// NOTE(review): presumably yields the view's data in target_type, converting
// into a scratch buffer when the stored type differs — confirm in the
// definition, including how long the returned pointer stays valid.
98const void* data_for_gemm_dtype(
const TensorView& input,
Type target_type);
100const LtMatmulPlan& get_lt_gemm_plan(
108inline void run_lt_matmul(
const LtMatmulPlan& plan,
109 const void* a_data,
const void* b_data,
void* c_data,
110 const void* bias_pointer)
112 CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(plan.op_desc,
113 CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_pointer,
sizeof(bias_pointer)));
121 c_data, plan.cd_desc,
122 c_data, plan.cd_desc,
123 plan.algo_valid ? &plan.algo :
nullptr,
124 scratch::ensure_cublas_lt_workspace(plan.workspace_size), plan.workspace_size,
142 float alpha = 1.0f,
float beta = 0.0f)
154 CUBLAS_GEMM_DEFAULT));
159 const void* A,
int lda,
long long stride_a,
160 const void* B,
int ldb,
long long stride_b,
161 void* C,
int ldc,
long long stride_c,
164 float alpha = 1.0f,
float beta = 0.0f)
171 A, io_dtype, lda, stride_a,
172 B, io_dtype, ldb, stride_b,
174 C, io_dtype, ldc, stride_c,
177 CUBLAS_GEMM_DEFAULT));
static cudaStream_t get_compute_stream()
Default CUDA stream used by the compute backend.
Definition tensor_utilities.h:510
static cublasHandle_t get_cublas_handle()
Shared cuBLAS handle for legacy GEMM calls.
Definition tensor_utilities.h:504
static cublasLtHandle_t get_cublas_lt_handle()
Shared cuBLASLt handle for batched/tuned GEMMs.
Definition tensor_utilities.h:506
Definition adaptive_moment_estimation.h:14
size_t hash_combine(const Vs &... values)
Boost-style hash combine that mixes any number of hashable values into a single size_t.
Definition tensor_utilities.h:482
constexpr cublasComputeType_t CUBLAS_COMPUTE_DTYPE
Definition tensor_utilities.h:38
Type
Numeric precision used for training or inference tensors.
Definition configuration.h:20
__nv_bfloat16 bfloat16
Definition pch.h:145
cudaDataType_t
Definition pch.h:93
@ CUDA_R_32F
Definition pch.h:93
@ CUDA_R_16BF
Definition pch.h:93
cublasLtEpilogue_t
Definition pch.h:96
cublasOperation_t
Definition pch.h:95
cublasComputeType_t
Definition pch.h:94
@ CUBLAS_COMPUTE_32F
Definition pch.h:94
Non-owning view over a tensor: pointer, shape, and data type with rich reshape helpers.
Definition tensor_utilities.h:293