OpenNN
Open-source neural networks library
Loading...
Searching...
No Matches
embedding_layer.h
Go to the documentation of this file.
1// OpenNN: Open Neural Networks Library
2// www.opennn.net
3//
4//   E M B E D D I N G   L A Y E R   C L A S S   H E A D E R
5//
6// Artificial Intelligence Techniques SL
7// artelnics@artelnics.com
8
14
15#pragma once
16
17#include "layer.h"
18#include "operators.h"
19#include "forward_propagation.h"
20#include "back_propagation.h"
21
22namespace opennn
23{
24
37class Embedding final : public Layer
38{
39public:
40
48 Embedding(const Shape& input_shape = {0, 0},
49 Index embedding_dimension = 0,
50 const string& label = "embedding_layer");
51
53 Shape get_input_shape() const override { return {sequence_length}; }
58 Shape get_output_shape() const override;
59
61 Index get_vocabulary_size() const { return vocabulary_size; }
63 Index get_sequence_length() const { return sequence_length; }
65 Index get_embedding_dimension() const { return embedding_dimension; }
66
71 vector<Operator*> get_operators() override;
72
78 vector<pair<Shape, Type>> get_forward_specs(Index batch_size) const override;
79
87 void set(Index vocabulary_size = 0,
88 Index sequence_length = 0,
89 Index embedding_dimension = 0,
90 const string& label = "embedding_layer");
91
97 void set_scale_embedding(bool enabled) { embedding_lookup.scale_embedding = enabled; }
102 void set_add_positional_encoding(bool enabled) { embedding_lookup.add_positional_encoding = enabled; }
107 void set_dropout_rate(float rate) { dropout.set_rate(rate); }
108
116 void back_propagate(ForwardPropagation&, BackPropagation&, size_t) const noexcept override;
117
122 void read_JSON_body(const Json*) override;
127 void write_JSON_body(JsonWriter&) const override;
128
129private:
130
132 Index vocabulary_size = 0;
134 Index sequence_length = 0;
136 Index embedding_dimension = 0;
137
139 EmbeddingLookup embedding_lookup;
141 Dropout dropout;
142
144 enum Forward {Input, Output};
146 enum Backward {OutputDelta};
147};
148
149}
150
151// OpenNN: Open Neural Networks Library.
152// Copyright(C) 2005-2026 Artificial Intelligence Techniques, SL.
153// Licensed under the GNU Lesser General Public License v2.1 or later.
void set_scale_embedding(bool enabled)
Enables Transformer-style sqrt(d_model) scaling on the embedding table output.
Definition embedding_layer.h:97
void back_propagate(ForwardPropagation &, BackPropagation &, size_t) const noexcept override
Backward pass: scatters output gradients into the embedding table rows referenced by the input ids.
void set_dropout_rate(float rate)
Sets the dropout rate applied at the layer output.
Definition embedding_layer.h:107
void read_JSON_body(const Json *) override
Reads the layer-specific JSON body (vocabulary size, sequence length, embedding dimension,...
Shape get_input_shape() const override
Returns the per-sample input shape (sequence_length,).
Definition embedding_layer.h:53
Index get_sequence_length() const
Sequence length expected at input.
Definition embedding_layer.h:63
void set_add_positional_encoding(bool enabled)
Enables addition of a sinusoidal positional encoding after lookup.
Definition embedding_layer.h:102
vector< pair< Shape, Type > > get_forward_specs(Index batch_size) const override
Specifications of the forward intermediate buffers.
vector< Operator * > get_operators() override
Returns the active operators in pipeline order.
Embedding(const Shape &input_shape={0, 0}, Index embedding_dimension=0, const string &label="embedding_layer")
Constructs an Embedding layer.
void set(Index vocabulary_size=0, Index sequence_length=0, Index embedding_dimension=0, const string &label="embedding_layer")
Re-initializes the layer.
Index get_vocabulary_size() const
Number of distinct tokens in the vocabulary.
Definition embedding_layer.h:61
Index get_embedding_dimension() const
Width of each embedding vector.
Definition embedding_layer.h:65
Shape get_output_shape() const override
Returns the per-sample output shape.
void write_JSON_body(JsonWriter &) const override
Writes the layer-specific JSON body (vocabulary size, sequence length, embedding dimension,...
Definition json.h:84
Definition json.h:22
Layer()=default
Default constructor; only invoked by subclasses.
string label
User-visible label for this layer instance (default "my_layer").
Definition layer.h:469
Declares the Layer abstract base class and the LayerType enumeration.
Definition adaptive_moment_estimation.h:19
Definition back_propagation.h:26
Definition operators.h:65
Definition operators.h:637
Definition forward_propagation.h:19
Definition tensor_utilities.h:46