OpenNN
Open-source neural networks library
Loading...
Searching...
No Matches
transformer_decoder.h
Go to the documentation of this file.
1// OpenNN: Open Neural Networks Library
2// www.opennn.net
3//
4// T R A N S F O R M E R   D E C O D E R   C L A S S   H E A D E R
5//
6// Artificial Intelligence Techniques SL
7// artelnics@artelnics.com
8
9#pragma once
10
11#include <functional>
12
13#include "forward_propagation.h"
14#include "language_dataset.h"
15#include "standard_networks.h"
16
17namespace opennn
18{
19
22{
23public:
24
27 {
28 float temperature = 1.0f;
29 Index top_k = 0;
30 float top_p = 1.0f;
31 float repetition_penalty = 1.0f;
32 Index maximum_tokens = 0;
33 };
34
36 using TokenCallback = function<void(const string& token)>;
37
43
45 string decode(const string& source);
46
48 string decode(const string& source, const SamplingConfig& config);
49
51 string decode(const string& source, const TokenCallback& on_token);
52
54 string decode(const string& source, const SamplingConfig& config, const TokenCallback& on_token);
55
57 string decode_to_stream(const string& source, ostream& out);
58
60 string decode_to_stream(const string& source, const SamplingConfig& config, ostream& out);
61
63 void chat();
64
66 void chat(const SamplingConfig& config);
67
68private:
69
70 Transformer& transformer;
71 const LanguageDataset& language_dataset;
72
73 Buffer arena{Device::CUDA};
74 TensorView source_ids_device;
75 TensorView target_ids_device;
76
77 unique_ptr<ForwardPropagation> forward_propagation;
78
79 vector<TensorView> inputs;
80
81 Tensor2 source_ids;
82 Tensor2 target_ids;
83 vector<Index> history;
84 VectorR distribution;
85 vector<uint16_t> bf16_staging;
86
87 Index decoder_embedding_layer_index = -1;
88 Index encoder_embedding_layer_index = -1;
89 Index encoder_last_layer_index = -1;
90 Index decoder_stack_first_layer_index = -1;
91 Index output_projection_layer_index = -1;
92
93 void identify_layer_ranges();
94 void encode_source(const string& source);
95 Index decode_step(Index step_index, const SamplingConfig& config);
96 void reset_per_prompt_state();
97 string assemble_output_string() const;
98};
99
100}
101
102// OpenNN: Open Neural Networks Library.
103// Copyright(C) 2005-2026 Artificial Intelligence Techniques, SL.
104// Licensed under the GNU Lesser General Public License v2.1 or later.
Token-based language dataset with input/target vocabularies and binary token cache.
Definition language_dataset.h:19
string decode_to_stream(const string &source, const SamplingConfig &config, ostream &out)
Generates a completion using the given sampling configuration and writes each token to the stream.
TransformerDecoder(const TransformerDecoder &)=delete
string decode_to_stream(const string &source, ostream &out)
Generates a completion for the given source and writes each emitted token to the output stream.
string decode(const string &source, const SamplingConfig &config, const TokenCallback &on_token)
Generates a completion using the given sampling configuration and streams tokens via the callback.
void chat()
Runs an interactive REPL: reads prompts from cin, streams predictions to cout, and exits on an empty line or on a further condition that is truncated in this listing.
function< void(const string &token)> TokenCallback
Callback invoked for each token emitted during streaming decoding.
Definition transformer_decoder.h:36
TransformerDecoder(Transformer &, const LanguageDataset &)
Builds the decoder bound to a Transformer model and the language dataset providing its vocabulary.
string decode(const string &source)
Generates a completion for the given source using the default sampling configuration.
string decode(const string &source, const SamplingConfig &config)
Generates a completion for the given source using the supplied sampling configuration.
void chat(const SamplingConfig &config)
Runs the interactive REPL with the supplied sampling configuration.
TransformerDecoder & operator=(const TransformerDecoder &)=delete
string decode(const string &source, const TokenCallback &on_token)
Generates a completion for the given source and invokes the callback for each emitted token.
Factory encoder-decoder Transformer neural network for sequence-to-sequence tasks.
Definition standard_networks.h:123
Definition adaptive_moment_estimation.h:14
@ CUDA
Definition configuration.h:17
Matrix< float, Dynamic, 1 > VectorR
Definition pch.h:181
Tensor< float, 2, Layout|AlignedMax > Tensor2
Definition pch.h:189
Owning raw byte buffer that lives on CPU or CUDA memory, with aligned (re)allocation.
Definition tensor_utilities.h:166
Non-owning view over a tensor: pointer, shape, and data type with rich reshape helpers.
Definition tensor_utilities.h:293
Sampling parameters that control how the next token is drawn from the model output distribution.
Definition transformer_decoder.h:27
float temperature
Definition transformer_decoder.h:28
float top_p
Definition transformer_decoder.h:30
float repetition_penalty
Definition transformer_decoder.h:31
Index maximum_tokens
Definition transformer_decoder.h:32
Index top_k
Definition transformer_decoder.h:29