OpenNN
Open-source neural networks library
Loading...
Searching...
No Matches
flatten_layer.h
Go to the documentation of this file.
1// OpenNN: Open Neural Networks Library
2// www.opennn.net
3//
4// F L A T T E N   L A Y E R   C L A S S   H E A D E R
5//
6// Artificial Intelligence Techniques SL
7// artelnics@artelnics.com
8
14
15#pragma once
16
17#include "layer.h"
18#include "operators.h"
19
20namespace opennn
21{
22
/// Flatten layer: a final Layer subclass that stores a per-sample input shape
/// and exposes a single Flat operator through get_operators().
35class Flatten final : public Layer
36{
37public:
38
    /// Constructs a Flatten layer.
    /// @param input_shape Per-sample input shape; defaults to an empty Shape.
44 Flatten(const Shape& input_shape = {});
45
    /// Returns the per-sample input shape (a copy of the stored member).
47 Shape get_input_shape() const override { return input_shape; }
    /// Returns the per-sample output shape, built from the single value input_shape.size().
    /// NOTE(review): this is only a flattened shape if Shape::size() yields the total
    /// element count rather than the number of dimensions - confirm in tensor_utilities.h.
52 Shape get_output_shape() const override { return { input_shape.size() }; }
53
    /// Returns the single Flat operator that implements this layer.
    /// The returned pointer is the address of the member `flat`, so it is non-owning
    /// and only valid while this layer object is alive.
55 vector<Operator*> get_operators() override { return {&flat}; }
56
    /// Re-initializes the layer.
    /// The parameter is presumably the new per-sample input shape (set_input_shape()
    /// forwards its argument here) - the definition is not in this header.
62 void set(const Shape&);
63
    /// Updates the input shape; equivalent to calling set().
    /// @param new_input_shape Shape forwarded unchanged to set().
68 void set_input_shape(const Shape& new_input_shape) override { set(new_input_shape); }
69
    /// Backward pass: reshapes the output gradient back to the input shape.
    /// Declared const noexcept; defined outside this header.
    /// NOTE(review): the meaning of the trailing size_t argument is not visible here
    /// (possibly a layer index) - confirm against the Layer base class.
76 void back_propagate(ForwardPropagation&, BackPropagation&, size_t) const noexcept override;
77
78
79private:
80
    /// Per-sample input shape, returned by get_input_shape().
82 Shape input_shape;
83
    /// The Flat operator handed out by get_operators().
85 Flat flat;
86
    /// Named indices Input = 0, Output = 1.
    /// Presumably slot positions in the forward-propagation buffers - TODO confirm in the .cpp.
88 enum Forward {Input, Output};
    /// Named indices OutputDelta = 0, InputDelta = 1.
    /// Presumably slot positions in the back-propagation buffers - TODO confirm in the .cpp.
90 enum Backward {OutputDelta, InputDelta};
91};
92
93}
94
95// OpenNN: Open Neural Networks Library.
96// Copyright(C) 2005-2026 Artificial Intelligence Techniques, SL.
97// Licensed under the GNU Lesser General Public License v2.1 or later.
Shape get_input_shape() const override
Returns the per-sample input shape.
Definition flatten_layer.h:47
void set(const Shape &)
Re-initializes the layer.
Flatten(const Shape &input_shape={})
Constructs a Flatten layer.
void back_propagate(ForwardPropagation &, BackPropagation &, size_t) const noexcept override
Backward pass: reshapes the output gradient back to the input shape.
void set_input_shape(const Shape &new_input_shape) override
Updates the input shape; equivalent to calling set().
Definition flatten_layer.h:68
Shape get_output_shape() const override
Returns the per-sample output shape.
Definition flatten_layer.h:52
vector< Operator * > get_operators() override
Returns the single Flat operator that implements this layer.
Definition flatten_layer.h:55
Layer()=default
Default constructor; only invoked by subclasses.
Declares the Layer abstract base class and the LayerType enumeration.
Definition adaptive_moment_estimation.h:19
Definition back_propagation.h:26
Definition operators.h:680
Definition forward_propagation.h:19
Definition tensor_utilities.h:46