37 const Shape& new_source_dimensions,
55 return (heads_number == 0) ? 0 : Index(embedding_dimension / heads_number);
82 const string& =
"multihead_attention_layer");
91 void set_dropout_rate(
float new_dropout_rate) { attention.set_dropout_rate(new_dropout_rate); }
101 Index embedding_dimension = 0;
102 Index heads_number = 0;
103 Index query_sequence_length = 0;
104 Index source_sequence_length = 0;
113 enum Forward {Input, Query, Key, AttentionWeights, AttentionWeightsDropped,
114 ConcatenatedAttentionOutputs, Value, TransposeScratch, Output};
119 AttentionWeightDelta,
121 ConcatenatedOutputDelta,
void set(Index=0, Index=0, Index=0, Index=0, bool=false, const string &="multihead_attention_layer")
Reconfigures the layer with new sequence lengths, embedding dimension, number of heads, and causal flag.
void on_compute_dtype_changed() override
Rebuilds projection operators when the compute dtype changes.
Index get_query_sequence_length() const
Definition multihead_attention_layer.h:47
Shape get_heads_shape(Index batch_size) const
Returns the per-head tensor shape used internally during attention.
Definition multihead_attention_layer.h:59
void set_input_shape(const Shape &) override
Updates the layer for a new input shape.
Shape get_concat_shape(Index batch_size) const
Returns the shape used when concatenating heads back to the embedding dimension.
Definition multihead_attention_layer.h:65
Shape get_output_shape() const override
Returns the output tensor shape.
Index get_heads_number() const
Definition multihead_attention_layer.h:50
Index get_embedding_dimension() const
Definition multihead_attention_layer.h:49
vector< TensorSpec > get_forward_specs(Index batch_size) const override
Returns the tensor specifications used during forward propagation.
MultiHeadAttention(const Shape &new_query_dimensions, const Shape &new_source_dimensions, Index=0, const string &={})
Constructs a cross-attention layer with separate query and source (key/value) sequences.
Index get_source_sequence_length() const
Definition multihead_attention_layer.h:48
void read_JSON_body(const Json *) override
Reads the layer configuration from a JSON node.
Shape get_input_shape() const override
Returns the input tensor shape.
void set_dropout_rate(float new_dropout_rate)
Sets the dropout rate applied to the attention weights.
Definition multihead_attention_layer.h:91
MultiHeadAttention(const Shape &=Shape({0, 0}), Index=0, const string &={})
Constructs a self-attention layer where queries and keys share the same sequence.
Index get_head_dimension() const
Returns the per-head dimension (embedding_dimension / heads_number), or 0 when heads_number is 0.
Definition multihead_attention_layer.h:53
vector< TensorSpec > get_backward_specs(Index batch_size) const override
Returns the tensor specifications used during back propagation.
void write_JSON_body(JsonWriter &) const override
Writes the layer configuration to a JSON writer.
Definition adaptive_moment_estimation.h:14
Scaled dot-product attention with optional causal mask and dropout.
Definition operators.h:642
Affine combination output = input * weights + bias (the dense matmul building block).
Definition operators.h:232
Reshapes (batch, heads, seq, head_dim) tensors back into (batch, seq, embed); no parameters.
Definition operators.h:839
Projects (input_features) into (heads * head_dim) and reshapes for multi-head attention.
Definition operators.h:579
Fixed-capacity small-vector describing tensor dimensions (rank up to MaxRank).
Definition tensor_utilities.h:42