45 const string&
label =
"normalization_layer_3d");
73 void set(Index sequence_length = 0,
74 Index embedding_dimension = 0,
75 const string&
label =
"normalization_layer_3d");
83 if (new_input_shape.
rank >= 2)
85 sequence_length = new_input_shape[0];
86 embedding_dimension = new_input_shape[1];
107 Index sequence_length = 0;
109 Index embedding_dimension = 0;
115 enum Forward {Input, Means, StandardDeviations, NormalizedInput, Output};
117 enum Backward {OutputDelta, InputDelta};
Layer()=default
Default constructor; only invoked by subclasses.
string label
User-visible label for this layer instance (default "my_layer").
Definition layer.h:469
Normalization3d(const Shape &input_shape=Shape({0, 0}), const string &label="normalization_layer_3d")
Constructs a Normalization3d layer.
Shape get_output_shape() const override
Returns the per-sample output shape (same as input).
Index get_embedding_dimension() const
Embedding dimension along which normalization is applied.
Definition normalization_layer_3d.h:55
Shape get_input_shape() const override
Returns the per-sample input shape (sequence_length, embedding_dimension).
vector< pair< Shape, Type > > get_forward_specs(Index batch_size) const override
Specifications of the forward intermediate buffers.
vector< Operator * > get_operators() override
Returns the single LayerNorm operator.
Index get_sequence_length() const
Sequence length of the input.
Definition normalization_layer_3d.h:53
void set_input_shape(const Shape &new_input_shape) override
Updates the input shape (sequence_length, embedding_dimension).
Definition normalization_layer_3d.h:81
void back_propagate(ForwardPropagation &, BackPropagation &, size_t) const noexcept override
Backward pass through layer normalization.
void read_JSON_body(const Json *) override
Reads the layer-specific JSON body (sequence length and embedding dimension).
void set(Index sequence_length=0, Index embedding_dimension=0, const string &label="normalization_layer_3d")
Re-initializes the layer.
Declares the Layer abstract base class and the LayerType enumeration.
Definition adaptive_moment_estimation.h:19
Definition back_propagation.h:26
Definition forward_propagation.h:19
Definition operators.h:362
Definition tensor_utilities.h:46
size_t rank
Definition tensor_utilities.h:50