|
OpenNN
Open-source neural networks library
|
#include <operators.h>
Public Member Functions | |
| void | set (Index heads_number, Index head_dimension, Index query_sequence_length, Index source_sequence_length, bool use_causal_mask, Type compute_dtype) |
| void | set_dropout_rate (float rate) |
| vector< pair< Shape, Type > > | forward_scratch_specs (Index batch_size) const |
| void | apply (const TensorView &query, const TensorView &key, const TensorView &value, const TensorView &source_input, TensorView &attention_weights, TensorView &attention_weights_dropped, TensorView &output, float *mask_scratch, bool is_training) |
| void | apply_delta (const TensorView &query, const TensorView &key, const TensorView &value, const TensorView &attention_output, const TensorView &attention_weights, const TensorView &attention_weights_dropped, const TensorView &output_gradient, TensorView &attention_weight_gradient, TensorView &query_gradient, TensorView &key_gradient, TensorView &value_gradient) const |
| void | to_JSON (JsonWriter &w) const override |
| void | from_JSON (const Json *parent) override |
| void | destroy_cuda () override |
| Attention () | |
| ~Attention () override | |
| Attention (Attention &&) noexcept | |
| Attention & | operator= (Attention &&) noexcept |
| Attention (const Attention &)=delete | |
| Attention & | operator= (const Attention &)=delete |
Public Member Functions inherited from opennn::Operator | |
| virtual | ~Operator ()=default |
| virtual vector< pair< Shape, Type > > | parameter_specs () const |
| virtual vector< pair< Shape, Type > > | state_specs () const |
| virtual void | link_parameters (const vector< TensorView > &) |
| virtual void | link_gradients (const vector< TensorView > &) |
| virtual void | link_states (const vector< TensorView > &) |
| virtual void | set_parameters_random () |
| virtual void | set_parameters_glorot () |
| virtual void | forward_propagate (ForwardPropagation &, size_t, bool) noexcept |
| virtual void | load_state_from_JSON (const Json *) |
Public Attributes | |
| Index | heads_number = 0 |
| Index | head_dimension = 0 |
| Index | query_sequence_length = 0 |
| Index | source_sequence_length = 0 |
| bool | use_causal_mask = false |
| Type | compute_dtype = Type::FP32 |
| MatrixR | causal_mask |
| Dropout | dropout |
Public Attributes inherited from opennn::Operator | |
| vector< size_t > | input_slots |
| vector< size_t > | output_slots |
| opennn::Attention::Attention | ( | ) |
|
override |
|
noexcept |
|
delete |
| void opennn::Attention::apply | ( | const TensorView & | query, |
| const TensorView & | key, | ||
| const TensorView & | value, | ||
| const TensorView & | source_input, | ||
| TensorView & | attention_weights, | ||
| TensorView & | attention_weights_dropped, | ||
| TensorView & | output, | ||
| float * | mask_scratch, | ||
| bool | is_training ) |
| void opennn::Attention::apply_delta | ( | const TensorView & | query, |
| const TensorView & | key, | ||
| const TensorView & | value, | ||
| const TensorView & | attention_output, | ||
| const TensorView & | attention_weights, | ||
| const TensorView & | attention_weights_dropped, | ||
| const TensorView & | output_gradient, | ||
| TensorView & | attention_weight_gradient, | ||
| TensorView & | query_gradient, | ||
| TensorView & | key_gradient, | ||
| TensorView & | value_gradient ) const |
|
override virtual
Reimplemented from opennn::Operator.
|
override virtual
Reimplemented from opennn::Operator.
| void opennn::Attention::set | ( | Index | heads_number, |
| Index | head_dimension, | ||
| Index | query_sequence_length, | ||
| Index | source_sequence_length, | ||
| bool | use_causal_mask, | ||
| Type | compute_dtype ) |
|
inline |
|
override virtual
Reimplemented from opennn::Operator.
| MatrixR opennn::Attention::causal_mask |
| Type opennn::Attention::compute_dtype = Type::FP32 |
| Dropout opennn::Attention::dropout |
| Index opennn::Attention::head_dimension = 0 |
| Index opennn::Attention::heads_number = 0 |
| Index opennn::Attention::query_sequence_length = 0 |
| Index opennn::Attention::source_sequence_length = 0 |
| bool opennn::Attention::use_causal_mask = false |