OpenNN
Open-source neural networks library
Loading...
Searching...
No Matches
language_dataset.h
Go to the documentation of this file.
1// OpenNN: Open Neural Networks Library
2// www.opennn.net
3//
4// L A N G U A G E D A T A S E T C L A S S H E A D E R
5//
6// Artificial Intelligence Techniques SL
7// artelnics@artelnics.com
8
14
15#pragma once
16
17#include "dataset.h"
18
19namespace opennn
20{
21
34class LanguageDataset final : public Dataset
35{
36
37public:
38
43 LanguageDataset(const filesystem::path& path = "");
50 LanguageDataset(const Index samples_number,
51 Index maximum_input_sequence_length,
52 Index maximum_target_sequence_length);
53
55 const vector<string>& get_input_vocabulary() const { return input_vocabulary; }
57 const vector<string>& get_target_vocabulary() const { return target_vocabulary; }
58
60 Index get_input_vocabulary_size() const { return input_vocabulary.size(); }
62 Index get_target_vocabulary_size() const { return target_vocabulary.size(); }
63
65 Index get_maximum_input_sequence_length() const { return maximum_input_sequence_length; }
67 Index get_maximum_target_sequence_length() const { return maximum_target_sequence_length; }
68
73 void set_input_vocabulary(const vector<string>& new_vocabulary) { input_vocabulary = new_vocabulary; }
78 void set_target_vocabulary(const vector<string>& new_vocabulary) { target_vocabulary = new_vocabulary; }
79
81 void read_csv() override;
82
90 void create_vocabulary(const vector<vector<string>>&, vector<string>&) const;
91
98 void encode_input(const vector<vector<string>>&);
105 void encode_decoder_target_sequence_to_sequence(const vector<vector<string>>&);
111 void encode_target_classification(const vector<vector<string>>&);
112
116 void from_JSON(const JsonDocument&) override;
120 void to_JSON(JsonWriter&) const override;
121
123 inline static const string PAD_TOKEN = "[PAD]";
125 inline static const string UNK_TOKEN = "[UNK]";
127 inline static const string START_TOKEN = "[START]";
129 inline static const string END_TOKEN = "[END]";
130
132 inline static const float UNK_INDEX = 1.0f;
134 inline static const float START_INDEX = 2.0f;
136 inline static const float END_INDEX = 3.0f;
137
139 inline static const vector<string> reserved_tokens = {PAD_TOKEN, UNK_TOKEN, START_TOKEN, END_TOKEN};
140
141private:
142
148 unordered_map<string, Index> create_vocabulary_map(const vector<string>& vocabulary);
149
151 vector<string> input_vocabulary;
153 vector<string> target_vocabulary;
154
156 Index maximum_input_sequence_length = 0;
158 Index maximum_target_sequence_length = 0;
159
161 Index minimum_token_frequency = 1;
163 Index maximum_vocabulary_size = 20000;
164};
165
166}
167
168// OpenNN: Open Neural Networks Library.
169// Copyright(C) 2005-2026 Artificial Intelligence Techniques, SL.
170// Licensed under the GNU Lesser General Public License v2.1 or later.
Dataset(const Index samples_number=0, const Shape &input_shape={0}, const Shape &target_shape={0})
Constructs an empty dataset of given dimensions.
Definition json.h:71
Definition json.h:84
static const float UNK_INDEX
Numeric id used for unknown tokens.
Definition language_dataset.h:132
LanguageDataset(const filesystem::path &path="")
Constructs a LanguageDataset, optionally loading from CSV.
const vector< string > & get_input_vocabulary() const
Read-only access to the input-side vocabulary.
Definition language_dataset.h:55
const vector< string > & get_target_vocabulary() const
Read-only access to the target-side vocabulary.
Definition language_dataset.h:57
LanguageDataset(const Index samples_number, Index maximum_input_sequence_length, Index maximum_target_sequence_length)
Constructs an empty LanguageDataset of given dimensions.
void set_input_vocabulary(const vector< string > &new_vocabulary)
Replaces the input vocabulary.
Definition language_dataset.h:73
void encode_decoder_target_sequence_to_sequence(const vector< vector< string > > &)
Encodes decoder target tokens for sequence-to-sequence training.
static const float START_INDEX
Numeric id used for the start-of-sequence token.
Definition language_dataset.h:134
void to_JSON(JsonWriter &) const override
Writes dataset metadata and vocabularies to a streaming JSON writer.
void encode_target_classification(const vector< vector< string > > &)
Encodes target tokens for classification training.
void from_JSON(const JsonDocument &) override
Loads dataset metadata and vocabularies from a parsed JSON document.
static const string UNK_TOKEN
Unknown token text (id 1).
Definition language_dataset.h:125
static const vector< string > reserved_tokens
Special tokens reserved at the start of every vocabulary.
Definition language_dataset.h:139
static const float END_INDEX
Numeric id used for the end-of-sequence token.
Definition language_dataset.h:136
void set_target_vocabulary(const vector< string > &new_vocabulary)
Replaces the target vocabulary.
Definition language_dataset.h:78
void create_vocabulary(const vector< vector< string > > &, vector< string > &) const
Builds a vocabulary from a list of tokenized samples.
void encode_input(const vector< vector< string > > &)
Encodes input tokens into ids and stores them in the dataset.
static const string PAD_TOKEN
Padding token text (id 0).
Definition language_dataset.h:123
Index get_target_vocabulary_size() const
Number of distinct tokens in the target vocabulary.
Definition language_dataset.h:62
static const string START_TOKEN
Start-of-sequence token text (id 2).
Definition language_dataset.h:127
void read_csv() override
Reads tokens from the configured CSV file into the dataset.
Index get_maximum_input_sequence_length() const
Maximum input sequence length supported.
Definition language_dataset.h:65
Index get_input_vocabulary_size() const
Number of distinct tokens in the input vocabulary.
Definition language_dataset.h:60
Index get_maximum_target_sequence_length() const
Maximum target sequence length supported.
Definition language_dataset.h:67
static const string END_TOKEN
End-of-sequence token text (id 3).
Definition language_dataset.h:129
Declares the Dataset class and the SampleRole enum.
Definition adaptive_moment_estimation.h:19