OpenNN
Open-source neural networks library
Loading...
Searching...
No Matches
opennn::EmbeddingLookup Struct Reference

#include <operators.h>

Inheritance diagram for opennn::EmbeddingLookup:
[legend]

Public Member Functions

void set (Index new_vocabulary_size, Index new_sequence_length, Index new_embedding_dimension)
 
vector< pair< Shape, Type > > parameter_specs () const override
 
vector< pair< Shape, Type > > state_specs () const override
 
void link_parameters (const vector< TensorView > &views) override
 
void link_gradients (const vector< TensorView > &views) override
 
void link_states (const vector< TensorView > &views) override
 
void set_parameters_random () override
 
void set_parameters_glorot () override
 
void init_positional_encoding ()
 
void forward_propagate (ForwardPropagation &fp, size_t layer, bool is_training) noexcept override
 
void apply (const TensorView &indices, TensorView &output)
 
void apply_delta (const TensorView &indices, const TensorView &output_delta) const
 
- Public Member Functions inherited from opennn::Operator
virtual ~Operator ()=default
 
virtual void to_JSON (JsonWriter &) const
 
virtual void from_JSON (const Json *)
 
virtual void load_state_from_JSON (const Json *)
 
virtual void destroy_cuda ()
 

Public Attributes

Index vocabulary_size = 0
 
Index sequence_length = 0
 
Index embedding_dimension = 0
 
bool scale_embedding = false
 
bool add_positional_encoding = false
 
float embedding_scale = 1.0f
 
TensorView weights
 
TensorView positional_encoding
 
TensorView weight_gradient
 
- Public Attributes inherited from opennn::Operator
vector< size_t > input_slots
 
vector< size_t > output_slots
 

Member Function Documentation

◆ apply()

void opennn::EmbeddingLookup::apply ( const TensorView & indices,
TensorView & output )

◆ apply_delta()

void opennn::EmbeddingLookup::apply_delta ( const TensorView & indices,
const TensorView & output_delta ) const

◆ forward_propagate()

void opennn::EmbeddingLookup::forward_propagate ( ForwardPropagation & fp,
size_t layer,
bool is_training )
[override] [virtual] [noexcept]

Reimplemented from opennn::Operator.

◆ init_positional_encoding()

void opennn::EmbeddingLookup::init_positional_encoding ( )

◆ link_gradients()

void opennn::EmbeddingLookup::link_gradients ( const vector< TensorView > & views)
[override] [virtual]

Reimplemented from opennn::Operator.

◆ link_parameters()

void opennn::EmbeddingLookup::link_parameters ( const vector< TensorView > & views)
[override] [virtual]

Reimplemented from opennn::Operator.

◆ link_states()

void opennn::EmbeddingLookup::link_states ( const vector< TensorView > & views)
[override] [virtual]

Reimplemented from opennn::Operator.

◆ parameter_specs()

vector< pair< Shape, Type > > opennn::EmbeddingLookup::parameter_specs ( ) const
[override] [virtual]

Reimplemented from opennn::Operator.

◆ set()

void opennn::EmbeddingLookup::set ( Index new_vocabulary_size,
Index new_sequence_length,
Index new_embedding_dimension )

◆ set_parameters_glorot()

void opennn::EmbeddingLookup::set_parameters_glorot ( )
[override] [virtual]

Reimplemented from opennn::Operator.

◆ set_parameters_random()

void opennn::EmbeddingLookup::set_parameters_random ( )
[override] [virtual]

Reimplemented from opennn::Operator.

◆ state_specs()

vector< pair< Shape, Type > > opennn::EmbeddingLookup::state_specs ( ) const
[override] [virtual]

Reimplemented from opennn::Operator.

Member Data Documentation

◆ add_positional_encoding

bool opennn::EmbeddingLookup::add_positional_encoding = false

◆ embedding_dimension

Index opennn::EmbeddingLookup::embedding_dimension = 0

◆ embedding_scale

float opennn::EmbeddingLookup::embedding_scale = 1.0f

◆ positional_encoding

TensorView opennn::EmbeddingLookup::positional_encoding

◆ scale_embedding

bool opennn::EmbeddingLookup::scale_embedding = false

◆ sequence_length

Index opennn::EmbeddingLookup::sequence_length = 0

◆ vocabulary_size

Index opennn::EmbeddingLookup::vocabulary_size = 0

◆ weight_gradient

TensorView opennn::EmbeddingLookup::weight_gradient

◆ weights

TensorView opennn::EmbeddingLookup::weights