27 static constexpr Index
bytes = Index(
sizeof(
float));
28 static constexpr const char*
name =
"FP32";
37 static constexpr const char*
name =
"BF16";
45 static constexpr Index
bytes = Index(1);
46 static constexpr const char*
name =
"INT8";
49template<
Type... Supported,
typename F>
54 if (!matched)
throw runtime_error(
"visit_type: unsupported Type value");
57template<
Type... Supported,
typename F>
60 visit_type<Supported...>(t_in, [&](
auto in_info)
62 visit_type<Supported...>(t_out, [&](
auto out_info)
115 static Configuration configuration;
116 return configuration;
130 return cached_resolved;
132 return resolve_slow();
145 const Resolved& resolve_slow()
const;
151 mutable Resolved cached_resolved;
152 mutable std::atomic<bool> cache_valid{
false};
Definition configuration.h:103
const Resolved & resolve() const
Definition configuration.h:127
bool is_cpu() const
Definition configuration.h:136
Device get_device() const
Definition configuration.h:123
Type get_training_type() const
Definition configuration.h:124
bool is_bf16_inference() const
Definition configuration.h:139
bool is_bf16_training() const
Definition configuration.h:138
bool is_gpu() const
Definition configuration.h:135
Type get_inference_type() const
Definition configuration.h:125
static Configuration & instance()
Definition configuration.h:113
void set(Device new_device=Device::Auto, Type new_training_type=Type::Auto, Type new_inference_type=Type::Auto)
Definition adaptive_moment_estimation.h:19
cudnnDataType_t
Definition neural_network.h:88
@ CUDNN_DATA_INT8
Definition neural_network.h:88
@ CUDNN_DATA_FLOAT
Definition neural_network.h:88
@ CUDNN_DATA_BFLOAT16
Definition neural_network.h:88
Index type_bytes(Type type) noexcept
Definition configuration.h:91
Device
Definition configuration.h:16
@ Auto
Definition configuration.h:16
@ CPU
Definition configuration.h:16
@ CUDA
Definition configuration.h:16
cudnnDataType_t to_cudnn(Type type) noexcept
Definition configuration.h:69
Type
Definition configuration.h:18
@ Auto
Definition configuration.h:18
@ FP32
Definition configuration.h:18
@ INT8
Definition configuration.h:18
@ BF16
Definition configuration.h:18
cudaDataType_t
Definition neural_network.h:82
@ CUDA_R_8I
Definition neural_network.h:82
@ CUDA_R_16BF
Definition neural_network.h:82
@ CUDA_R_32F
Definition neural_network.h:82
void visit_type(Type t, F &&f)
Definition configuration.h:50
cudaDataType_t to_cuda(Type type) noexcept
Definition configuration.h:80
void visit_type_pair(Type t_in, Type t_out, F &&f)
Definition configuration.h:58
Definition configuration.h:107
Type training_type
Definition configuration.h:109
Type inference_type
Definition configuration.h:110
Device device
Definition configuration.h:108
__nv_bfloat16 type
Definition configuration.h:33
static constexpr Index bytes
Definition configuration.h:36
static constexpr cudaDataType_t cuda
Definition configuration.h:35
static constexpr cudnnDataType_t cudnn
Definition configuration.h:34
static constexpr const char * name
Definition configuration.h:37
float type
Definition configuration.h:24
static constexpr Index bytes
Definition configuration.h:27
static constexpr cudnnDataType_t cudnn
Definition configuration.h:25
static constexpr const char * name
Definition configuration.h:28
static constexpr cudaDataType_t cuda
Definition configuration.h:26
static constexpr cudaDataType_t cuda
Definition configuration.h:44
int8_t type
Definition configuration.h:42
static constexpr cudnnDataType_t cudnn
Definition configuration.h:43
static constexpr const char * name
Definition configuration.h:46
static constexpr Index bytes
Definition configuration.h:45
Definition configuration.h:20
Definition neural_network.h:78
Device device
Definition neural_network.h:108
Type training_type
Definition neural_network.h:109
Type inference_type
Definition neural_network.h:110