OpenNN
Open-source neural networks library
Loading...
Searching...
No Matches
opennn::Attention Member List

This is the complete list of members for opennn::Attention, including all inherited members. Each entry gives the member, then the class that declares it, followed by any Doxygen labels (inline, virtual) in parentheses.

apply(const TensorView &query, const TensorView &key, const TensorView &value, const TensorView &source_input, TensorView &attention_weights, TensorView &attention_weights_dropped, TensorView &output, float *mask_scratch, bool is_training) — opennn::Attention
apply_delta(const TensorView &query, const TensorView &key, const TensorView &value, const TensorView &attention_output, const TensorView &attention_weights, const TensorView &attention_weights_dropped, const TensorView &output_gradient, TensorView &attention_weight_gradient, TensorView &query_gradient, TensorView &key_gradient, TensorView &value_gradient) const — opennn::Attention
Attention() — opennn::Attention
Attention(Attention &&) noexcept — opennn::Attention
Attention(const Attention &)=delete — opennn::Attention
causal_mask — opennn::Attention
compute_dtype — opennn::Attention
destroy_cuda() override — opennn::Attention (virtual)
dropout — opennn::Attention
forward_propagate(ForwardPropagation &, size_t, bool) noexcept — opennn::Operator (inline, virtual)
forward_scratch_specs(Index batch_size) const — opennn::Attention
from_JSON(const Json *parent) override — opennn::Attention (virtual)
head_dimension — opennn::Attention
heads_number — opennn::Attention
input_slots — opennn::Operator
link_gradients(const vector< TensorView > &) — opennn::Operator (inline, virtual)
link_parameters(const vector< TensorView > &) — opennn::Operator (inline, virtual)
link_states(const vector< TensorView > &) — opennn::Operator (inline, virtual)
load_state_from_JSON(const Json *) — opennn::Operator (inline, virtual)
operator=(Attention &&) noexcept — opennn::Attention
operator=(const Attention &)=delete — opennn::Attention
output_slots — opennn::Operator
parameter_specs() const — opennn::Operator (inline, virtual)
query_sequence_length — opennn::Attention
set(Index heads_number, Index head_dimension, Index query_sequence_length, Index source_sequence_length, bool use_causal_mask, Type compute_dtype) — opennn::Attention
set_dropout_rate(float rate) — opennn::Attention (inline)
set_parameters_glorot() — opennn::Operator (inline, virtual)
set_parameters_random() — opennn::Operator (inline, virtual)
source_sequence_length — opennn::Attention
state_specs() const — opennn::Operator (inline, virtual)
to_JSON(JsonWriter &w) const override — opennn::Attention (virtual)
use_causal_mask — opennn::Attention
~Attention() override — opennn::Attention
~Operator()=default — opennn::Operator (virtual)