OpenNN
Open-source neural networks library
opennn::MultiHeadAttention Class Reference (final)

Scaled dot-product attention with multiple heads and learned linear projections. More...

#include <multihead_attention_layer.h>

Inheritance diagram for opennn::MultiHeadAttention:

Public Member Functions

 MultiHeadAttention (const Shape &input_shape=Shape({0, 0}), Index heads_number=0, const string &label=string())
 Constructs a self-attention layer.
 
 MultiHeadAttention (const Shape &new_query_dimensions, const Shape &new_source_dimensions, Index heads_number=0, const string &label=string())
 Constructs a cross-attention layer.
 
Shape get_input_shape () const override
 Returns the per-sample input shape.
 
Shape get_output_shape () const override
 Returns the per-sample output shape.
 
Index get_query_sequence_length () const
 Length of the query side sequence.
 
Index get_source_sequence_length () const
 Length of the source side sequence (equal to query length for self-attention).
 
Index get_embedding_dimension () const
 Width of the embedding (model) dimension.
 
Index get_heads_number () const
 Number of attention heads.
 
Index get_head_dimension () const
 Per-head feature dimension.
 
Shape get_heads_shape (Index batch_size) const
 Shape of the per-head attention scratch buffer.
 
Shape get_concat_shape (Index batch_size) const
 Shape of the concatenated attention output before projection.
 
vector< Operator * > get_operators () override
 Returns the active operators in pipeline order.
 
vector< pair< Shape, Type > > get_forward_specs (Index batch_size) const override
 Specifications of the forward intermediate buffers.
 
vector< pair< Shape, Type > > get_backward_specs (Index batch_size) const override
 Specifications of the backward intermediate buffers.
 
void set (Index query_sequence_length=0, Index source_sequence_length=0, Index embedding_dimension=0, Index heads_number=0, bool use_causal_mask=false, const string &label="multihead_attention_layer")
 Re-initializes the layer.
 
void set_input_shape (const Shape &new_input_shape) override
 Updates the input shape; rejects shapes whose rank is not 2.
 
void on_compute_dtype_changed () override
 Propagates a compute dtype change to all sub-operators.
 
void set_dropout_rate (float new_dropout_rate)
 Sets the dropout rate applied to attention weights.
 
void forward_propagate (ForwardPropagation &, size_t, bool) noexcept override
 Forward pass: Q/K/V projections, scaled dot-product attention, head concatenation, output projection.
 
void back_propagate (ForwardPropagation &, BackPropagation &, size_t) const noexcept override
 Backward pass through every operator in reverse order.
 
void read_JSON_body (const Json *) override
 Reads the layer-specific JSON body (heads, sequences, dimension, causal flag, dropout).
 
void write_JSON_body (JsonWriter &) const override
 Writes the layer-specific JSON body (heads, sequences, dimension, causal flag, dropout).
 
- Public Member Functions inherited from opennn::Layer
virtual ~Layer ()=default
 Virtual destructor; subclasses are owned via unique_ptr<Layer>.
 
const string & get_label () const
 Returns the user-assigned label of this layer.
 
const string & get_name () const
 Returns the canonical type name of this layer.
 
LayerType get_type () const
 Returns the LayerType enumerator for this layer.
 
virtual void set_output_shape (const Shape &)
 Sets the per-sample output shape of this layer.
 
void set_label (string new_label)
 Sets the human-readable label of this layer.
 
Index get_parameters_number () const
 Total number of trainable parameters in this layer.
 
virtual vector< pair< Shape, Type > > get_parameter_specs () const
 Specifications of the trainable parameter tensors owned by this layer.
 
virtual vector< pair< Shape, Type > > get_state_specs () const
 Specifications of the persistent state tensors of this layer.
 
vector< Shape > get_parameter_shapes () const
 Shape-only view of get_parameter_specs().
 
vector< Shape > get_state_shapes () const
 Shape-only view of get_state_specs().
 
vector< Shape > get_forward_shapes (Index b) const
 Shape-only view of get_forward_specs() for batch size b.
 
vector< Shape > get_backward_shapes (Index b) const
 Shape-only view of get_backward_specs() for batch size b.
 
vector< Type > get_parameter_dtypes () const
 Dtype-only view of get_parameter_specs().
 
vector< Type > get_forward_dtypes (Index b) const
 Dtype-only view of get_forward_specs() for batch size b.
 
vector< Type > get_backward_dtypes (Index b) const
 Dtype-only view of get_backward_specs() for batch size b.
 
virtual Activation::Function get_output_activation () const
 Activation function fused at the end of this layer, if any.
 
Index get_inputs_number () const
 Total number of scalar inputs per sample (product of input dims).
 
Index get_outputs_number () const
 Total number of scalar outputs per sample (product of output dims).
 
virtual void from_JSON (const JsonDocument &document)
 Loads the layer configuration (hyperparameters) from JSON.
 
virtual void load_state_from_JSON (const JsonDocument &document)
 Loads parameter and state tensors from a JSON document.
 
virtual void to_JSON (JsonWriter &writer) const
 Writes the layer configuration to JSON.
 
virtual void print () const
 Prints a human-readable summary of the layer to stdout.
 
bool get_is_trainable () const
 Whether this layer has trainable parameters.
 
Type get_compute_dtype () const
 Numerical type used for forward/backward computation.
 
void set_compute_dtype (Type new_compute_dtype)
 Sets the compute dtype and triggers on_compute_dtype_changed().
 
virtual float * link_parameters (float *pointer)
 Wires this layer's parameter TensorViews onto an external buffer.
 
virtual float * link_states (float *pointer)
 Wires this layer's state TensorViews onto an external buffer.
 
vector< TensorView > & get_parameter_views ()
 Mutable access to this layer's parameter TensorViews.
 
const vector< TensorView > & get_parameter_views () const
 Read-only access to this layer's parameter TensorViews.
 
vector< TensorView > & get_state_views ()
 Mutable access to this layer's state TensorViews.
 
const vector< TensorView > & get_state_views () const
 Read-only access to this layer's state TensorViews.
 
void redistribute_parameters_to_operators ()
 Forwards the current parameter views down to each composing Operator.
 
void redistribute_parameter_gradients_to_operators (vector< TensorView > &gradient_views)
 Forwards externally provided gradient views down to each Operator.
 
void redistribute_states_to_operators ()
 Forwards the current state views down to each composing Operator.
 

Static Public Member Functions

static bool is_self_attention (const vector< vector< TensorView > > &forward_views)
 Convenience predicate: true when the layer was wired in self-attention mode (single input view).
 
static const TensorView & get_query_input (const vector< vector< TensorView > > &forward_views)
 Returns the query side input view.
 
static const TensorView & get_source_input (const vector< vector< TensorView > > &forward_views)
 Returns the source side input view (key/value source).
 

Additional Inherited Members

- Protected Member Functions inherited from opennn::Layer
 Layer ()=default
 Default constructor; only invoked by subclasses.
 
float * link_views (float *pointer, const vector< Shape > &shapes, vector< TensorView > &views, const char *tag) const
 Builds views over a contiguous float buffer using shapes.
 
void distribute_to_operators (vector< TensorView > &views, void(Operator::*link)(const vector< TensorView > &), vector< pair< Shape, Type > >(Operator::*specs)() const)
 Generic helper used by the redistribute_*_to_operators() routines.
 
- Protected Attributes inherited from opennn::Layer
string label = "my_layer"
 User-visible label for this layer instance (default "my_layer").
 
string name = "layer"
 Canonical type name set by the subclass (e.g. "dense").
 
LayerType layer_type = LayerType::Dense
 Layer type tag set by the subclass.
 
bool is_trainable = true
 True if the layer has parameters that participate in training.
 
bool is_first_layer = false
 True if this layer is the network's input layer.
 
Type compute_dtype = Type::FP32
 Numerical type used for forward and backward computation.
 
vector< TensorView > parameters
 Parameter TensorViews bound to the network's parameter arena.
 
vector< TensorView > states
 State TensorViews bound to the network's state arena.
 
vector< unique_ptr< Layer > > layers
 Sub-layers, when this layer is itself a composite.
 

Detailed Description

Scaled dot-product attention with multiple heads and learned linear projections.

Wraps four Combination/MultiHeadProjection operators (query, key, value, output projection) and one Attention operator that performs the scaled dot-product attention with optional dropout.

Two input modes are supported:

  • Self-attention: a single rank-2 input is used as query, key and value.
  • Cross-attention: two rank-2 inputs (query side and source side); used in the decoder's encoder-decoder attention.
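The scaled dot-product attention at the core of this layer can be sketched in standalone C++. This is an illustrative single-head version of softmax(Q K^T / sqrt(d)) V only; the function name and flat row-major layout are this sketch's own conventions, and it includes none of OpenNN's operators, buffers, masking or dropout.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Illustrative single-head scaled dot-product attention:
// output = softmax(Q * K^T / sqrt(d)) * V
// Q is (q_len x d); K and V are (s_len x d); all row-major flat vectors.
std::vector<double> scaled_dot_product_attention(
    const std::vector<double>& Q, const std::vector<double>& K,
    const std::vector<double>& V, int q_len, int s_len, int d)
{
    std::vector<double> output(q_len * d, 0.0);

    for (int i = 0; i < q_len; ++i)
    {
        // Scaled scores of query row i against every source position.
        std::vector<double> scores(s_len);
        for (int j = 0; j < s_len; ++j)
        {
            double dot = 0.0;
            for (int k = 0; k < d; ++k) dot += Q[i*d + k] * K[j*d + k];
            scores[j] = dot / std::sqrt(static_cast<double>(d));
        }

        // Numerically stable softmax over the source axis.
        const double max_score = *std::max_element(scores.begin(), scores.end());
        double denominator = 0.0;
        for (double& s : scores) { s = std::exp(s - max_score); denominator += s; }

        // Output row i is the attention-weighted sum of the value rows.
        for (int j = 0; j < s_len; ++j)
            for (int k = 0; k < d; ++k)
                output[i*d + k] += (scores[j] / denominator) * V[j*d + k];
    }

    return output;
}
```

In the real layer this computation runs once per head on the projected Q/K/V slices, followed by head concatenation and the output projection.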

Constructor & Destructor Documentation

◆ MultiHeadAttention() [1/2]

opennn::MultiHeadAttention::MultiHeadAttention ( const Shape & input_shape = Shape({0, 0}),
Index heads_number = 0,
const string & label = string() )

Constructs a self-attention layer.

Parameters
input_shape    Per-sample input shape (sequence_length, embedding_dimension).
heads_number    Number of attention heads (must divide embedding_dimension).
label    Human-readable label assigned to this layer.

◆ MultiHeadAttention() [2/2]

opennn::MultiHeadAttention::MultiHeadAttention ( const Shape & new_query_dimensions,
const Shape & new_source_dimensions,
Index heads_number = 0,
const string & label = string() )

Constructs a cross-attention layer.

Parameters
new_query_dimensions    Query side input shape (query_sequence_length, embedding_dimension).
new_source_dimensions    Source side input shape (source_sequence_length, embedding_dimension).
heads_number    Number of attention heads.
label    Human-readable label assigned to this layer.

Member Function Documentation

◆ back_propagate()

void opennn::MultiHeadAttention::back_propagate ( ForwardPropagation & ,
BackPropagation & ,
size_t  ) const
override virtual noexcept

Backward pass through every operator in reverse order.

Receives the forward intermediates, the BackPropagation buffer and this layer's index inside the network.

Reimplemented from opennn::Layer.

◆ forward_propagate()

void opennn::MultiHeadAttention::forward_propagate ( ForwardPropagation & ,
size_t ,
bool  )
override virtual noexcept

Forward pass: Q/K/V projections, scaled dot-product attention, head concatenation, output projection.

Receives the ForwardPropagation buffer slice, this layer's index and the training flag.

Reimplemented from opennn::Layer.

◆ get_backward_specs()

vector< pair< Shape, Type > > opennn::MultiHeadAttention::get_backward_specs ( Index batch_size) const
override virtual

Specifications of the backward intermediate buffers.

Parameters
batch_size    Batch size used for sizing.
Returns
One spec per slot in the Backward enum.

Reimplemented from opennn::Layer.

◆ get_concat_shape()

Shape opennn::MultiHeadAttention::get_concat_shape ( Index batch_size) const
inline

Shape of the concatenated attention output before projection.

Parameters
batch_size    Batch size used for sizing.
Returns
(batch_size, query_sequence_length, heads_number, head_dimension).

◆ get_embedding_dimension()

Index opennn::MultiHeadAttention::get_embedding_dimension ( ) const
inline

Width of the embedding (model) dimension.

◆ get_forward_specs()

vector< pair< Shape, Type > > opennn::MultiHeadAttention::get_forward_specs ( Index batch_size) const
override virtual

Specifications of the forward intermediate buffers.

Parameters
batch_size    Batch size used for sizing.
Returns
One spec per slot in the Forward enum.

Reimplemented from opennn::Layer.

◆ get_head_dimension()

Index opennn::MultiHeadAttention::get_head_dimension ( ) const
inline

Per-head feature dimension.

Returns
embedding_dimension / heads_number, or 0 if heads_number is 0.
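The contract above amounts to an integer division with a zero guard. The helper below is this page's own illustration of that arithmetic, not OpenNN code:

```cpp
#include <cassert>

// Mirrors the documented contract of get_head_dimension():
// embedding_dimension / heads_number, or 0 when heads_number is 0.
long head_dimension(long embedding_dimension, long heads_number)
{
    return heads_number == 0 ? 0 : embedding_dimension / heads_number;
}
```

Since heads_number must divide embedding_dimension (see the self-attention constructor), a model dimension of 512 with 8 heads yields a head dimension of 64.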

◆ get_heads_number()

Index opennn::MultiHeadAttention::get_heads_number ( ) const
inline

Number of attention heads.

◆ get_heads_shape()

Shape opennn::MultiHeadAttention::get_heads_shape ( Index batch_size) const
inline

Shape of the per-head attention scratch buffer.

Parameters
batch_size    Batch size used for sizing.
Returns
(batch_size, heads_number, query_sequence_length, head_dimension).
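Note the axis order: the per-head buffer is heads-major, while the concatenated buffer returned by get_concat_shape() is sequence-major. A minimal sketch of that rearrangement for one sample (illustrative only, not OpenNN's actual kernel):

```cpp
#include <cassert>
#include <vector>

// Rearranges one sample's per-head attention output, laid out
// (heads, q_len, head_dim) as in get_heads_shape(), into the concatenated
// layout (q_len, heads * head_dim) that feeds the output projection,
// matching the axis order of get_concat_shape().
std::vector<double> concat_heads(const std::vector<double>& heads_buffer,
                                 int heads, int q_len, int head_dim)
{
    std::vector<double> concat(q_len * heads * head_dim);

    for (int h = 0; h < heads; ++h)
        for (int t = 0; t < q_len; ++t)
            for (int k = 0; k < head_dim; ++k)
                concat[(t * heads + h) * head_dim + k] =
                    heads_buffer[(h * q_len + t) * head_dim + k];

    return concat;
}
```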

◆ get_input_shape()

Shape opennn::MultiHeadAttention::get_input_shape ( ) const
override virtual

Returns the per-sample input shape.

Returns
(query_sequence_length, embedding_dimension); in cross-attention mode, the second (source) input is wired to the layer separately.

Implements opennn::Layer.

◆ get_operators()

vector< Operator * > opennn::MultiHeadAttention::get_operators ( )
override virtual

Returns the active operators in pipeline order.

Returns
Q/K/V projections, Attention, then output projection.

Reimplemented from opennn::Layer.

◆ get_output_shape()

Shape opennn::MultiHeadAttention::get_output_shape ( ) const
override virtual

Returns the per-sample output shape.

Returns
(query_sequence_length, embedding_dimension).

Implements opennn::Layer.

◆ get_query_input()

static const TensorView & opennn::MultiHeadAttention::get_query_input ( const vector< vector< TensorView > > & forward_views)
inline static

Returns the query side input view.

Parameters
forward_views    ForwardPropagation views[layer] for this layer.
Returns
Reference to the query input TensorView.

◆ get_query_sequence_length()

Index opennn::MultiHeadAttention::get_query_sequence_length ( ) const
inline

Length of the query side sequence.

◆ get_source_input()

static const TensorView & opennn::MultiHeadAttention::get_source_input ( const vector< vector< TensorView > > & forward_views)
inline static

Returns the source side input view (key/value source).

Parameters
forward_views    ForwardPropagation views[layer] for this layer.
Returns
The query input itself for self-attention; otherwise the second wired input.

◆ get_source_sequence_length()

Index opennn::MultiHeadAttention::get_source_sequence_length ( ) const
inline

Length of the source side sequence (equal to query length for self-attention).

◆ is_self_attention()

static bool opennn::MultiHeadAttention::is_self_attention ( const vector< vector< TensorView > > & forward_views)
inline static

Convenience predicate: true when the layer was wired in self-attention mode (single input view).

Parameters
forward_views    ForwardPropagation views[layer] for this layer.
Returns
True if self-attention, false if cross-attention.

◆ on_compute_dtype_changed()

void opennn::MultiHeadAttention::on_compute_dtype_changed ( )
inline override virtual

Propagates a compute dtype change to all sub-operators.

Reimplemented from opennn::Layer.

◆ read_JSON_body()

void opennn::MultiHeadAttention::read_JSON_body ( const Json * )
override virtual

Reads the layer-specific JSON body (heads, sequences, dimension, causal flag, dropout).

Reimplemented from opennn::Layer.

◆ set()

void opennn::MultiHeadAttention::set ( Index query_sequence_length = 0,
Index source_sequence_length = 0,
Index embedding_dimension = 0,
Index heads_number = 0,
bool use_causal_mask = false,
const string & label = "multihead_attention_layer" )

Re-initializes the layer.

Parameters
query_sequence_length    Length of the query sequence.
source_sequence_length    Length of the source sequence.
embedding_dimension    Embedding (model) dimension.
heads_number    Number of attention heads.
use_causal_mask    True to apply a causal mask in self-attention.
label    Human-readable label.

◆ set_dropout_rate()

void opennn::MultiHeadAttention::set_dropout_rate ( float new_dropout_rate)
inline

Sets the dropout rate applied to attention weights.

Parameters
new_dropout_rate    Probability of dropping each attention weight (0 disables dropout).

◆ set_input_shape()

void opennn::MultiHeadAttention::set_input_shape ( const Shape & new_input_shape)
inline override virtual

Updates the input shape; rejects shapes whose rank is not 2.

Parameters
new_input_shape    New per-sample input shape (sequence_length, embedding_dimension).

Reimplemented from opennn::Layer.

◆ write_JSON_body()

void opennn::MultiHeadAttention::write_JSON_body ( JsonWriter & ) const
override virtual

Writes the layer-specific JSON body (heads, sequences, dimension, causal flag, dropout).

Reimplemented from opennn::Layer.