OpenNN
Open-source neural networks library
Loading...
Searching...
No Matches
multihead_attention_layer.h
Go to the documentation of this file.
1// OpenNN: Open Neural Networks Library
2// www.opennn.net
3//
4// M U L T I H E A D   A T T E N T I O N   L A Y E R   C L A S S   H E A D E R
5//
6// Artificial Intelligence Techniques SL
7// artelnics@artelnics.com
8
9#pragma once
10
11#include "layer.h"
12#include "operators.h"
13#include "math_utilities.h"
14
15namespace opennn
16{
17
/// Multi-head attention layer: a final subclass of Layer.
///
/// Stores four dimensions (embedding size, number of heads, query sequence length,
/// source sequence length), owns the operator members listed in the private section,
/// and declares the Forward/Backward enumerations at the end of the class.
/// Only declarations and small inline accessors live here; every other member
/// function is defined outside this header.
19class MultiHeadAttention final : public Layer
20{
21public:
22
    /// Constructs a self-attention layer where queries and keys share the same sequence.
    /// Every argument is defaulted (Shape({0, 0}), 0 and an empty string), so this
    /// overload also serves as the default constructor.
    /// NOTE(review): the parameters are unnamed here; the Index is presumably the
    /// number of heads and the string the layer label - confirm against the definition.
27 MultiHeadAttention(const Shape& = Shape({0, 0}),
28 Index = 0,
29 const string& = {});
30
    /// Constructs a cross-attention layer with separate query and source (key/value) sequences.
    /// @param new_query_dimensions  Dimensions for the query side.
    /// @param new_source_dimensions Dimensions for the source (key/value) side.
    /// NOTE(review): the layout expected inside each Shape is not visible in this header,
    /// and the trailing unnamed Index/string are presumably heads number and label - confirm.
36 MultiHeadAttention(const Shape& new_query_dimensions,
37 const Shape& new_source_dimensions,
38 Index = 0,
39 const string& = {});
40
    /// Returns the input tensor shape.
42 Shape get_input_shape() const override;
43
    /// Returns the output tensor shape.
45 Shape get_output_shape() const override;
46
    // Plain accessors: each returns the corresponding private member unchanged.
47 Index get_query_sequence_length() const { return query_sequence_length; }
48 Index get_source_sequence_length() const { return source_sequence_length; }
49 Index get_embedding_dimension() const { return embedding_dimension; }
50 Index get_heads_number() const { return heads_number; }
51
    /// Returns the per-head dimension (embedding_dimension / heads_number).
    /// Returns 0 when heads_number is 0, which guards the division below.
    /// NOTE(review): assuming Index is an integral type, the division truncates, so an
    /// embedding dimension that is not a multiple of heads_number loses the remainder.
53 Index get_head_dimension() const
54 {
55 return (heads_number == 0) ? 0 : Index(embedding_dimension / heads_number);
56 }
57
    /// Returns the per-head tensor shape used internally during attention:
    /// {batch_size, heads_number, query_sequence_length, head dimension}.
    /// Note that the sequence axis always uses the query length, not the source length.
59 Shape get_heads_shape(Index batch_size) const
60 {
61 return {batch_size, heads_number, query_sequence_length, get_head_dimension()};
62 }
63
    /// Returns the shape used when concatenating heads back to the embedding dimension:
    /// {batch_size, query_sequence_length, heads_number, head dimension}.
    /// Same extents as get_heads_shape() with the heads and sequence axes swapped.
65 Shape get_concat_shape(Index batch_size) const
66 {
67 return {batch_size, query_sequence_length, heads_number, get_head_dimension()};
68 }
69
    /// Returns the tensor specifications used during forward propagation.
71 vector<TensorSpec> get_forward_specs(Index batch_size) const override;
72
    /// Returns the tensor specifications used during back propagation.
74 vector<TensorSpec> get_backward_specs(Index batch_size) const override;
75
    /// Reconfigures the layer with new sequence, embedding, head sizes and causal flag.
    /// All arguments are defaulted; the string defaults to "multihead_attention_layer".
    /// NOTE(review): the parameters are unnamed, so the order of the four Index values
    /// cannot be read from this declaration - check the definition before calling
    /// positionally. The bool is presumably the causal flag.
77 void set(Index = 0,
78 Index = 0,
79 Index = 0,
80 Index = 0,
81 bool = false,
82 const string& = "multihead_attention_layer");
83
    /// Updates the layer for a new input shape.
85 void set_input_shape(const Shape&) override;
86
    /// Rebuilds projection operators when the compute dtype changes.
88 void on_compute_dtype_changed() override;
89
    /// Sets the dropout rate applied to the attention weights.
    /// Forwards the value unchanged to the `attention` operator member.
91 void set_dropout_rate(float new_dropout_rate) { attention.set_dropout_rate(new_dropout_rate); }
92
    /// Reads the layer configuration from a JSON node.
94 void read_JSON_body(const Json*) override;
95
    /// Writes the layer configuration to a JSON writer.
97 void write_JSON_body(JsonWriter&) const override;
98
99private:
100
    // Layer dimensions; all start at 0 until a constructor or set() assigns them.
101 Index embedding_dimension = 0;
102 Index heads_number = 0;
103 Index query_sequence_length = 0;
104 Index source_sequence_length = 0;
105
    // Operator members. The summaries below are taken from the operators.h briefs;
    // NOTE(review): the brief-to-type mapping is inferred from the wording - confirm.
    //  - query/key/value projections: project input features into (heads * head_dim)
    //    and reshape for multi-head attention.
    //  - output_projection: affine combination, output = input * weights + bias.
    //  - attention: scaled dot-product attention with optional causal mask and dropout.
    //  - merge: reshapes (batch, heads, seq, head_dim) back into (batch, seq, embed).
106 MultiHeadProjectionOp query_projection;
107 MultiHeadProjectionOp key_projection;
108 MultiHeadProjectionOp value_projection;
109 CombinationOp output_projection;
110 AttentionOp attention;
111 MergeOp merge;
112
    // Names for the forward-pass tensors. Values are implicit (Input = 0, Query = 1, ...),
    // so reordering or inserting enumerators changes every later value.
    // NOTE(review): presumably these index the vector returned by get_forward_specs() -
    // confirm against the implementation before changing the order.
113 enum Forward {Input, Query, Key, AttentionWeights, AttentionWeightsDropped,
114 ConcatenatedAttentionOutputs, Value, TransposeScratch, Output};
    // Names for the backward-pass tensors, again with implicit values starting at 0.
    // NOTE(review): presumably these index the vector returned by get_backward_specs() - confirm.
115 enum Backward {
116 OutputDelta,
117 InputQueryDelta, // final dInput query, embed shape
118 InputSourceDelta, // final dInput source, embed shape
119 AttentionWeightDelta, // unfused attention scratch
120 ValueHeadDelta, // dV, head shape
121 ConcatenatedOutputDelta, // dConcat, embed shape
122 QueryHeadDelta, // dQ, head shape
123 KeyHeadDelta // dK, head shape
124 };
125};
126
127}
128
129// OpenNN: Open Neural Networks Library.
130// Copyright(C) 2005-2026 Artificial Intelligence Techniques, SL.
131// Licensed under the GNU Lesser General Public License v2.1 or later.
Definition json.h:85
Definition json.h:23
Layer()=default
void set(Index=0, Index=0, Index=0, Index=0, bool=false, const string &="multihead_attention_layer")
Reconfigures the layer with new sequence, embedding, head sizes and causal flag.
void on_compute_dtype_changed() override
Rebuilds projection operators when the compute dtype changes.
Index get_query_sequence_length() const
Definition multihead_attention_layer.h:47
Shape get_heads_shape(Index batch_size) const
Returns the per-head tensor shape used internally during attention.
Definition multihead_attention_layer.h:59
void set_input_shape(const Shape &) override
Updates the layer for a new input shape.
Shape get_concat_shape(Index batch_size) const
Returns the shape used when concatenating heads back to the embedding dimension.
Definition multihead_attention_layer.h:65
Shape get_output_shape() const override
Returns the output tensor shape.
Index get_heads_number() const
Definition multihead_attention_layer.h:50
Index get_embedding_dimension() const
Definition multihead_attention_layer.h:49
vector< TensorSpec > get_forward_specs(Index batch_size) const override
Returns the tensor specifications used during forward propagation.
MultiHeadAttention(const Shape &new_query_dimensions, const Shape &new_source_dimensions, Index=0, const string &={})
Constructs a cross-attention layer with separate query and source (key/value) sequences.
Index get_source_sequence_length() const
Definition multihead_attention_layer.h:48
void read_JSON_body(const Json *) override
Reads the layer configuration from a JSON node.
Shape get_input_shape() const override
Returns the input tensor shape.
void set_dropout_rate(float new_dropout_rate)
Sets the dropout rate applied to the attention weights.
Definition multihead_attention_layer.h:91
MultiHeadAttention(const Shape &=Shape({0, 0}), Index=0, const string &={})
Constructs a self-attention layer where queries and keys share the same sequence.
Index get_head_dimension() const
Returns the per-head dimension (embedding_dimension / heads_number), or 0 when heads_number is 0.
Definition multihead_attention_layer.h:53
vector< TensorSpec > get_backward_specs(Index batch_size) const override
Returns the tensor specifications used during back propagation.
void write_JSON_body(JsonWriter &) const override
Writes the layer configuration to a JSON writer.
Definition adaptive_moment_estimation.h:14
Scaled dot-product attention with optional causal mask and dropout.
Definition operators.h:642
Affine combination output = input * weights + bias (the dense matmul building block).
Definition operators.h:232
Reshapes (batch, heads, seq, head_dim) tensors back into (batch, seq, embed); no parameters.
Definition operators.h:839
Projects (input_features) into (heads * head_dim) and reshapes for multi-head attention.
Definition operators.h:579
Fixed-capacity small-vector describing tensor dimensions (rank up to MaxRank).
Definition tensor_utilities.h:42