OpenNN
Open-source neural networks library
Loading...
Searching...
No Matches
opennn::Attention Struct Reference

#include <operators.h>

Inheritance diagram for opennn::Attention:
[legend]

Public Member Functions

void set (Index heads_number, Index head_dimension, Index query_sequence_length, Index source_sequence_length, bool use_causal_mask, Type compute_dtype)
 
void set_dropout_rate (float rate)
 
vector< pair< Shape, Type > > forward_scratch_specs (Index batch_size) const
 
void apply (const TensorView &query, const TensorView &key, const TensorView &value, const TensorView &source_input, TensorView &attention_weights, TensorView &attention_weights_dropped, TensorView &output, float *mask_scratch, bool is_training)
 
void apply_delta (const TensorView &query, const TensorView &key, const TensorView &value, const TensorView &attention_output, const TensorView &attention_weights, const TensorView &attention_weights_dropped, const TensorView &output_gradient, TensorView &attention_weight_gradient, TensorView &query_gradient, TensorView &key_gradient, TensorView &value_gradient) const
 
void to_JSON (JsonWriter &w) const override
 
void from_JSON (const Json *parent) override
 
void destroy_cuda () override
 
 Attention ()
 
 ~Attention () override
 
 Attention (Attention &&) noexcept
 
Attention & operator= (Attention &&) noexcept
 
 Attention (const Attention &)=delete
 
Attention & operator= (const Attention &)=delete
 
- Public Member Functions inherited from opennn::Operator
virtual ~Operator ()=default
 
virtual vector< pair< Shape, Type > > parameter_specs () const
 
virtual vector< pair< Shape, Type > > state_specs () const
 
virtual void link_parameters (const vector< TensorView > &)
 
virtual void link_gradients (const vector< TensorView > &)
 
virtual void link_states (const vector< TensorView > &)
 
virtual void set_parameters_random ()
 
virtual void set_parameters_glorot ()
 
virtual void forward_propagate (ForwardPropagation &, size_t, bool) noexcept
 
virtual void load_state_from_JSON (const Json *)
 

Public Attributes

Index heads_number = 0
 
Index head_dimension = 0
 
Index query_sequence_length = 0
 
Index source_sequence_length = 0
 
bool use_causal_mask = false
 
Type compute_dtype = Type::FP32
 
MatrixR causal_mask
 
Dropout dropout
 
- Public Attributes inherited from opennn::Operator
vector< size_t > input_slots
 
vector< size_t > output_slots
 

Constructor & Destructor Documentation

◆ Attention() [1/3]

opennn::Attention::Attention ( )

◆ ~Attention()

opennn::Attention::~Attention ( )
override

◆ Attention() [2/3]

opennn::Attention::Attention ( Attention && )
noexcept

◆ Attention() [3/3]

opennn::Attention::Attention ( const Attention & )
delete

Member Function Documentation

◆ apply()

void opennn::Attention::apply ( const TensorView & query,
const TensorView & key,
const TensorView & value,
const TensorView & source_input,
TensorView & attention_weights,
TensorView & attention_weights_dropped,
TensorView & output,
float * mask_scratch,
bool is_training )

◆ apply_delta()

void opennn::Attention::apply_delta ( const TensorView & query,
const TensorView & key,
const TensorView & value,
const TensorView & attention_output,
const TensorView & attention_weights,
const TensorView & attention_weights_dropped,
const TensorView & output_gradient,
TensorView & attention_weight_gradient,
TensorView & query_gradient,
TensorView & key_gradient,
TensorView & value_gradient ) const

◆ destroy_cuda()

void opennn::Attention::destroy_cuda ( )
override virtual

Reimplemented from opennn::Operator.

◆ forward_scratch_specs()

vector< pair< Shape, Type > > opennn::Attention::forward_scratch_specs ( Index batch_size) const

◆ from_JSON()

void opennn::Attention::from_JSON ( const Json * parent)
override virtual

Reimplemented from opennn::Operator.

◆ operator=() [1/2]

Attention & opennn::Attention::operator= ( Attention && )
noexcept

◆ operator=() [2/2]

Attention & opennn::Attention::operator= ( const Attention & )
delete

◆ set()

void opennn::Attention::set ( Index heads_number,
Index head_dimension,
Index query_sequence_length,
Index source_sequence_length,
bool use_causal_mask,
Type compute_dtype )

◆ set_dropout_rate()

void opennn::Attention::set_dropout_rate ( float rate)
inline

◆ to_JSON()

void opennn::Attention::to_JSON ( JsonWriter & w) const
override virtual

Reimplemented from opennn::Operator.

Member Data Documentation

◆ causal_mask

MatrixR opennn::Attention::causal_mask

◆ compute_dtype

Type opennn::Attention::compute_dtype = Type::FP32

◆ dropout

Dropout opennn::Attention::dropout

◆ head_dimension

Index opennn::Attention::head_dimension = 0

◆ heads_number

Index opennn::Attention::heads_number = 0

◆ query_sequence_length

Index opennn::Attention::query_sequence_length = 0

◆ source_sequence_length

Index opennn::Attention::source_sequence_length = 0

◆ use_causal_mask

bool opennn::Attention::use_causal_mask = false