31 virtual vector<pair<Shape, Type>>
state_specs()
const {
return {}; }
57 void apply_cpu(
const vector<TensorView>& inputs,
TensorView& output);
59 void apply_gpu(
const vector<TensorView>& inputs,
TensorView& output);
61 void check(
const vector<TensorView>& inputs,
const TensorView& output)
const;
100 void apply_delta_cpu(
TensorView& delta) const;
101 void apply_delta_gpu(
TensorView& delta) const;
103 void ensure_mask(Index n);
160 void set(Index new_input_features, Index new_output_features,
Type new_weight_type =
Type::FP32);
176 bool accumulate_input_delta =
false)
const;
183 TensorView& input_delta,
bool accumulate_input_delta)
const;
185 TensorView& input_delta,
bool accumulate_input_delta)
const;
187#ifdef OPENNN_HAS_CUDA
188 mutable const LtMatmulPlan* fwd_plan_ =
nullptr;
189 mutable int fwd_total_rows_ = -1;
192 mutable const LtMatmulPlan* bwd_plan_ =
nullptr;
193 mutable int bwd_total_rows_ = -1;
194 mutable int bwd_io_dtype_ = -1;
213 void set(Index new_features,
float new_momentum = 0.1f);
250 bool inference_cache_dirty =
true;
252 mutable VectorR delta_scale_scratch;
257 void apply_training_cpu (
const TensorView& input,
260 void apply_training_gpu (
const TensorView& input,
298#ifdef OPENNN_HAS_CUDA
313 Index planned_batch_size = 0;
316 void set(Index input_h, Index input_w,
317 Index kernels_n, Index kernel_h, Index kernel_w, Index kernel_c,
318 Index row_stride, Index column_stride,
319 Index padding_h, Index padding_w,
356#ifdef OPENNN_HAS_CUDA
404 void apply_delta_cpu(
const TensorView& output_delta,
440 float* scratch)
const;
501 float scaling_factor() const;
548#ifdef OPENNN_HAS_CUDA
551 void apply_delta_gpu_unfused(
const TensorView& query,
563 mutable std::unique_ptr<SDPACache> sdpa_cache;
581#ifdef OPENNN_HAS_CUDA
585 void set(Index input_h, Index input_w, Index input_c,
586 Index pool_h, Index pool_w,
588 Index padding_h, Index padding_w,
616 void apply_delta_cpu(
const TensorView& output_delta,
652 void set(Index new_vocabulary_size, Index new_sequence_length, Index new_embedding_dimension);
716 void set(Index new_features);
738 void set(Index new_features);
Definition adaptive_moment_estimation.h:19
cublasLtEpilogue_t
Definition neural_network.h:85
@ CUBLASLT_EPILOGUE_DEFAULT
Definition neural_network.h:85
@ CUBLASLT_EPILOGUE_BIAS
Definition neural_network.h:85
float mean(const VectorR &)
Matrix< float, Dynamic, Dynamic, Layout > MatrixR
Definition neural_network.h:152
Matrix< float, Dynamic, 1 > VectorR
Definition neural_network.h:156
@ CUDA
Definition configuration.h:16
VectorI maximal_indices(const VectorR &, Index)
cudnnActivationMode_t
Definition neural_network.h:89
Type
Definition configuration.h:18
@ FP32
Definition configuration.h:18
cudnnConvolutionBwdFilterAlgo_t
Definition neural_network.h:94
@ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0
Definition neural_network.h:94
cudnnConvolutionBwdDataAlgo_t
Definition neural_network.h:93
@ CUDNN_CONVOLUTION_BWD_DATA_ALGO_0
Definition neural_network.h:93
cudnnConvolutionFwdAlgo_t
Definition neural_network.h:92
@ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
Definition neural_network.h:92
void * cudnnActivationDescriptor_t
Definition neural_network.h:102
void * cudnnPoolingDescriptor_t
Definition neural_network.h:101
void * cudnnConvolutionDescriptor_t
Definition neural_network.h:100
void * cudnnFilterDescriptor_t
Definition neural_network.h:99
static Function from_string(const string &name)
void apply_delta(const TensorView &outputs, TensorView &delta) const
Function
Definition operators.h:108
@ Sigmoid
Definition operators.h:108
@ Softmax
Definition operators.h:108
@ Identity
Definition operators.h:108
@ Tanh
Definition operators.h:108
@ ReLU
Definition operators.h:108
cudnnActivationDescriptor_t descriptor
Definition operators.h:117
~Activation() override
Definition operators.h:132
void destroy_cuda() override
static const string & to_string(Function function)
Function function
Definition operators.h:115
static const EnumMap< Function > & map()
void from_JSON(const Json *parent) override
Activation & operator=(const Activation &)=delete
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
void to_JSON(JsonWriter &w) const override
void set_function(const string &name)
void set_function(Function new_function)
void apply(TensorView &output)
static cudnnActivationMode_t to_cudnn_mode(Function function)
Activation(const Activation &)=delete
Definition operators.h:53
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
Index heads_number
Definition operators.h:445
void apply_delta(const TensorView &query, const TensorView &key, const TensorView &value, const TensorView &attention_output, const TensorView &attention_weights, const TensorView &attention_weights_dropped, const TensorView &output_gradient, TensorView &attention_weight_gradient, TensorView &query_gradient, TensorView &key_gradient, TensorView &value_gradient) const
Attention(Attention &&) noexcept
Type compute_dtype
Definition operators.h:450
vector< pair< Shape, Type > > forward_scratch_specs(Index batch_size) const
void destroy_cuda() override
void set_dropout_rate(float rate)
Definition operators.h:460
MatrixR causal_mask
Definition operators.h:452
void apply(const TensorView &query, const TensorView &key, const TensorView &value, const TensorView &source_input, TensorView &attention_weights, TensorView &attention_weights_dropped, TensorView &output, float *mask_scratch, bool is_training)
void to_JSON(JsonWriter &w) const override
void set(Index heads_number, Index head_dimension, Index query_sequence_length, Index source_sequence_length, bool use_causal_mask, Type compute_dtype)
bool use_causal_mask
Definition operators.h:449
void from_JSON(const Json *parent) override
Index query_sequence_length
Definition operators.h:447
Index head_dimension
Definition operators.h:446
Dropout dropout
Definition operators.h:454
Index source_sequence_length
Definition operators.h:448
Definition operators.h:199
bool active() const
Definition operators.h:211
TensorView gamma_gradient
Definition operators.h:208
void apply_training(const TensorView &input, TensorView &mean, TensorView &inverse_variance, TensorView &output)
void invalidate_inference_cache()
Definition operators.h:241
void apply_delta(const TensorView &input, const TensorView &mean, const TensorView &inverse_variance, TensorView &delta) const
TensorView gamma
Definition operators.h:203
void link_gradients(const vector< TensorView > &views) override
TensorView beta
Definition operators.h:204
vector< pair< Shape, Type > > parameter_specs() const override
TensorView running_variance
Definition operators.h:206
void apply_inference(const TensorView &input, TensorView &output)
void link_states(const vector< TensorView > &views) override
void link_parameters(const vector< TensorView > &views) override
vector< pair< Shape, Type > > state_specs() const override
void set(Index new_features, float new_momentum=0.1f)
TensorView running_mean
Definition operators.h:205
Index features
Definition operators.h:200
void load_state_from_JSON(const Json *parent) override
float momentum
Definition operators.h:201
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
void set_parameters_random() override
Definition operators.h:221
void set_parameters_glorot() override
Definition operators.h:222
void update_inference_cache()
void from_JSON(const Json *parent) override
void to_JSON(JsonWriter &w) const override
TensorView beta_gradient
Definition operators.h:209
Definition operators.h:685
Index features
Definition operators.h:689
vector< pair< Shape, Type > > state_specs() const override
Method method
Definition operators.h:688
Method
Definition operators.h:686
@ Bounding
Definition operators.h:686
@ NoBounding
Definition operators.h:686
void load_state_from_JSON(const Json *parent) override
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
TensorView upper
Definition operators.h:692
TensorView lower
Definition operators.h:691
void link_states(const vector< TensorView > &views) override
void set(Method new_method, Index new_features)
Definition tensor_utilities.h:144
Definition operators.h:147
vector< pair< Shape, Type > > parameter_specs() const override
TensorView weight_gradient
Definition operators.h:157
cublasLtEpilogue_t epilogue
Definition operators.h:152
void set(Index new_input_features, Index new_output_features, Type new_weight_type=Type::FP32)
void link_parameters(const vector< TensorView > &views) override
TensorView weights
Definition operators.h:154
TensorView bias
Definition operators.h:155
Type weight_type
Definition operators.h:150
void apply_delta(const TensorView &output_delta, const TensorView &input, TensorView &input_delta, bool accumulate_input_delta=false) const
void set_parameters_random() override
Index output_features
Definition operators.h:149
void apply(const TensorView &input, TensorView &output, cublasLtEpilogue_t epilogue=CUBLASLT_EPILOGUE_BIAS)
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
TensorView bias_gradient
Definition operators.h:158
void link_gradients(const vector< TensorView > &views) override
Index input_features
Definition operators.h:148
void set_parameters_glorot() override
Convolution(const Convolution &)=delete
Index kernel_channels
Definition operators.h:282
void link_gradients(const vector< TensorView > &views) override
Index padding_width
Definition operators.h:285
void apply_delta(const TensorView &input, const TensorView &output_delta, TensorView &input_delta) const
void set_parameters_random() override
Index input_height
Definition operators.h:276
TensorView weight_gradient
Definition operators.h:295
void destroy_cuda() override
Index padding_height
Definition operators.h:284
vector< pair< Shape, Type > > parameter_specs() const override
void set(Index input_h, Index input_w, Index kernels_n, Index kernel_h, Index kernel_w, Index kernel_c, Index row_stride, Index column_stride, Index padding_h, Index padding_w, Type compute_dtype)
Index kernel_width
Definition operators.h:281
TensorView bias
Definition operators.h:293
Type compute_dtype
Definition operators.h:287
TensorView weights
Definition operators.h:292
void apply(const TensorView &input, TensorView &output, cudnnActivationDescriptor_t fused_activation=nullptr)
Index kernel_height
Definition operators.h:280
~Convolution() override
Definition operators.h:331
void link_parameters(const vector< TensorView > &views) override
cudnnActivationDescriptor_t fused_activation
Definition operators.h:290
Convolution & operator=(const Convolution &)=delete
Index kernels_number
Definition operators.h:279
Index input_width
Definition operators.h:277
void set_parameters_glorot() override
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
TensorView bias_gradient
Definition operators.h:296
Definition operators.h:65
vector< size_t > save_slots
Definition operators.h:74
void destroy_cuda() override
void from_JSON(const Json *parent) override
float rate
Definition operators.h:66
void apply(TensorView &output)
Buffer mask
Definition operators.h:70
~Dropout() override
Definition operators.h:90
void to_JSON(JsonWriter &w) const override
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
Dropout(Dropout &&) noexcept=default
VectorR mask_cpu
Definition operators.h:68
void apply_delta(TensorView &delta) const
bool active() const
Definition operators.h:76
void set_rate(float new_rate)
Definition operators.h:637
void link_parameters(const vector< TensorView > &views) override
void set(Index new_vocabulary_size, Index new_sequence_length, Index new_embedding_dimension)
bool scale_embedding
Definition operators.h:642
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
Index sequence_length
Definition operators.h:639
vector< pair< Shape, Type > > state_specs() const override
void apply_delta(const TensorView &indices, const TensorView &output_delta) const
Index vocabulary_size
Definition operators.h:638
void init_positional_encoding()
void apply(const TensorView &indices, TensorView &output)
TensorView weights
Definition operators.h:647
void set_parameters_glorot() override
void link_states(const vector< TensorView > &views) override
float embedding_scale
Definition operators.h:645
Index embedding_dimension
Definition operators.h:640
vector< pair< Shape, Type > > parameter_specs() const override
void link_gradients(const vector< TensorView > &views) override
TensorView weight_gradient
Definition operators.h:650
bool add_positional_encoding
Definition operators.h:643
TensorView positional_encoding
Definition operators.h:648
void set_parameters_random() override
Definition operators.h:680
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
Definition forward_propagation.h:19
Definition operators.h:362
TensorView beta_gradient
Definition operators.h:370
void apply(const TensorView &input, TensorView &means, TensorView &standard_deviations, TensorView &normalized, TensorView &output)
Index sequence_length
Definition operators.h:363
TensorView gamma
Definition operators.h:366
vector< pair< Shape, Type > > parameter_specs() const override
Index embedding_dimension
Definition operators.h:364
TensorView gamma_gradient
Definition operators.h:369
TensorView beta
Definition operators.h:367
void set(Index sequence_length, Index embedding_dimension)
void set_parameters_glorot() override
Definition operators.h:379
void apply_delta(const TensorView &input, const TensorView &output_delta, const TensorView &means, const TensorView &standard_deviations, const TensorView &normalized, TensorView &input_delta) const
void link_gradients(const vector< TensorView > &views) override
void set_parameters_random() override
Definition operators.h:378
void link_parameters(const vector< TensorView > &views) override
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
Definition operators.h:415
void apply_delta(const TensorView &head_gradient, const TensorView &input, TensorView &input_gradient, bool accumulate, float *scratch) const
void link_parameters(const vector< TensorView > &views) override
Definition operators.h:425
void link_gradients(const vector< TensorView > &views) override
Definition operators.h:426
void set_parameters_glorot() override
Definition operators.h:429
Combination combination
Definition operators.h:416
vector< pair< Shape, Type > > parameter_specs() const override
Definition operators.h:424
void set(Index input_features, Index heads_number, Index head_dimension, Type compute_dtype)
Index head_dimension
Definition operators.h:419
Type compute_dtype
Definition operators.h:420
Index heads_number
Definition operators.h:418
void apply(const TensorView &input, TensorView &head_output, float *scratch)
Index input_features
Definition operators.h:417
void set_parameters_random() override
Definition operators.h:428
Definition operators.h:27
virtual void destroy_cuda()
Definition operators.h:46
virtual void link_states(const vector< TensorView > &)
Definition operators.h:35
virtual void link_parameters(const vector< TensorView > &)
Definition operators.h:33
virtual void from_JSON(const Json *)
Definition operators.h:43
virtual void to_JSON(JsonWriter &) const
Definition operators.h:42
virtual void load_state_from_JSON(const Json *)
Definition operators.h:44
vector< size_t > output_slots
Definition operators.h:49
virtual vector< pair< Shape, Type > > state_specs() const
Definition operators.h:31
virtual vector< pair< Shape, Type > > parameter_specs() const
Definition operators.h:30
virtual void forward_propagate(ForwardPropagation &, size_t, bool) noexcept
Definition operators.h:40
virtual void set_parameters_glorot()
Definition operators.h:38
virtual ~Operator()=default
vector< size_t > input_slots
Definition operators.h:48
virtual void link_gradients(const vector< TensorView > &)
Definition operators.h:34
virtual void set_parameters_random()
Definition operators.h:37
Definition operators.h:626
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
int method
Definition operators.h:627
~Pool() override
Definition operators.h:593
int method
Definition operators.h:579
Index pool_width
Definition operators.h:573
Pool(const Pool &)=delete
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
Pool & operator=(const Pool &)=delete
Index input_width
Definition operators.h:569
void apply_delta(const TensorView &input, const TensorView &output, const TensorView &output_delta, const TensorView &maximal_indices, TensorView &input_delta) const
Index input_height
Definition operators.h:568
Index row_stride
Definition operators.h:574
void set(Index input_h, Index input_w, Index input_c, Index pool_h, Index pool_w, Index row_stride, Index column_stride, Index padding_h, Index padding_w, int method)
void apply(const TensorView &input, TensorView &output, TensorView &maximal_indices, bool is_training)
Index pool_height
Definition operators.h:572
void destroy_cuda() override
Index column_stride
Definition operators.h:575
Index input_channels
Definition operators.h:570
Index padding_width
Definition operators.h:577
Index padding_height
Definition operators.h:576
Definition operators.h:705
TensorView scalers
Definition operators.h:714
Index features
Definition operators.h:706
TensorView minimums
Definition operators.h:710
void link_states(const vector< TensorView > &views) override
float min_range
Definition operators.h:707
void load_state_from_JSON(const Json *parent) override
void set(Index new_features)
TensorView standard_deviations
Definition operators.h:713
TensorView means
Definition operators.h:712
vector< pair< Shape, Type > > state_specs() const override
float max_range
Definition operators.h:708
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
TensorView maximums
Definition operators.h:711
Definition tensor_utilities.h:236
Definition operators.h:727
float min_range
Definition operators.h:729
Index features
Definition operators.h:728
TensorView maximums
Definition operators.h:733
TensorView standard_deviations
Definition operators.h:735
float max_range
Definition operators.h:730
void link_states(const vector< TensorView > &views) override
void load_state_from_JSON(const Json *parent) override
TensorView scalers
Definition operators.h:736
vector< pair< Shape, Type > > state_specs() const override
void forward_propagate(ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
TensorView minimums
Definition operators.h:732
void set(Index new_features)
TensorView means
Definition operators.h:734