OpenNN
Open-source neural networks library
opennn::TransformerDecoder Class Reference

Drives token-by-token inference of a Transformer model with configurable sampling strategies.

#include <transformer_decoder.h>

Classes

struct  SamplingConfig
 Sampling parameters that control how the next token is drawn from the model output distribution.
 

Public Types

using TokenCallback = function<void(const string& token)>
 Callback invoked for each token emitted during streaming decoding.
 

Public Member Functions

 TransformerDecoder (Transformer &, const LanguageDataset &)
 Builds the decoder bound to a Transformer model and the language dataset providing its vocabulary.
 
 TransformerDecoder (const TransformerDecoder &)=delete
 
TransformerDecoder & operator= (const TransformerDecoder &)=delete
 
 ~TransformerDecoder ()=default
 
string decode (const string &source)
 Generates a completion for the given source using the default sampling configuration.
 
string decode (const string &source, const SamplingConfig &config)
 Generates a completion for the given source using the supplied sampling configuration.
 
string decode (const string &source, const TokenCallback &on_token)
 Generates a completion for the given source and invokes the callback for each emitted token.
 
string decode (const string &source, const SamplingConfig &config, const TokenCallback &on_token)
 Generates a completion using the given sampling configuration and streams tokens via the callback.
 
string decode_to_stream (const string &source, ostream &out)
 Generates a completion for the given source and writes each emitted token to the output stream.
 
string decode_to_stream (const string &source, const SamplingConfig &config, ostream &out)
 Generates a completion using the given sampling configuration and writes each token to the stream.
 
void chat ()
 Runs an interactive REPL: reads prompts from cin, streams predictions to cout, and exits on an empty line or end of input (Ctrl+D).
 
void chat (const SamplingConfig &config)
 Runs the interactive REPL with the supplied sampling configuration.
 

Detailed Description

Drives token-by-token inference of a Transformer model with configurable sampling strategies.
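
This page does not list the fields of SamplingConfig, so the following is only a generic sketch of what "drawing the next token from the model output distribution" can involve. The function names and the temperature parameter are assumptions made for illustration and are not taken from OpenNN. The sketch turns raw scores (logits) into probabilities with a temperature-scaled softmax, then picks the most probable token (greedy decoding).

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper (not an OpenNN function): temperature-scaled softmax.
// Lower temperatures sharpen the distribution; higher ones flatten it.
std::vector<double> softmax_with_temperature(const std::vector<double>& logits,
                                             double temperature)
{
    assert(!logits.empty() && temperature > 0.0);

    // Subtract the maximum logit before exponentiating for numerical stability.
    const double max_logit = *std::max_element(logits.begin(), logits.end());

    std::vector<double> probabilities(logits.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < logits.size(); ++i)
    {
        probabilities[i] = std::exp((logits[i] - max_logit) / temperature);
        sum += probabilities[i];
    }
    for (double& p : probabilities) p /= sum;  // normalize so the values sum to 1

    return probabilities;
}

// Hypothetical helper: greedy choice, i.e. the index of the most probable token.
std::size_t greedy_pick(const std::vector<double>& probabilities)
{
    assert(!probabilities.empty());
    return static_cast<std::size_t>(
        std::max_element(probabilities.begin(), probabilities.end()) - probabilities.begin());
}
```

A stochastic strategy would replace greedy_pick with a draw from the probabilities, for example via std::discrete_distribution.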

Member Typedef Documentation

◆ TokenCallback

using opennn::TransformerDecoder::TokenCallback = function<void(const string& token)>

Callback invoked for each token emitted during streaming decoding.
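
As an illustration of the callback contract, the sketch below declares an alias with the same shape as TokenCallback and passes a lambda to a stand-in emitter. The emit_tokens and collect_example functions are hypothetical and only imitate a streaming decoder; that the returned string is the concatenation of the emitted tokens is likewise an assumption.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Same shape as the TokenCallback alias documented above.
using TokenCallback = std::function<void(const std::string& token)>;

// Hypothetical stand-in for a streaming decoder (not part of OpenNN):
// it "emits" a fixed list of tokens, calling the callback once per token.
std::string emit_tokens(const std::vector<std::string>& tokens,
                        const TokenCallback& on_token)
{
    // Calling an empty std::function throws std::bad_function_call.
    assert(on_token && "on_token must hold a callable");

    std::string completion;
    for (const std::string& token : tokens)
    {
        on_token(token);      // hand the token to the caller as soon as it exists
        completion += token;  // and accumulate the full text for the return value
    }
    return completion;
}

// Usage: a capturing lambda collects the streamed tokens into a vector.
std::vector<std::string> collect_example()
{
    std::vector<std::string> seen;
    emit_tokens({"Hel", "lo", "!"},
                [&seen](const std::string& token) { seen.push_back(token); });
    return seen;
}
```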

Constructor & Destructor Documentation

◆ TransformerDecoder() [1/2]

opennn::TransformerDecoder::TransformerDecoder ( Transformer & ,
const LanguageDataset &  )

Builds the decoder bound to a Transformer model and the language dataset providing its vocabulary.

◆ TransformerDecoder() [2/2]

opennn::TransformerDecoder::TransformerDecoder ( const TransformerDecoder & )
delete

◆ ~TransformerDecoder()

opennn::TransformerDecoder::~TransformerDecoder ( )
default
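
The constructor takes the model by non-const reference and the copy operations are deleted. One plausible reading, which this page does not confirm, is that the decoder keeps references to both arguments, so they must outlive it. The sketch below shows that pattern with hypothetical stand-in types (Model, Vocabulary, BoundDecoder); none of them are OpenNN classes.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <type_traits>

struct Model { std::string name; };       // hypothetical stand-in for Transformer
struct Vocabulary { std::size_t size; };  // hypothetical stand-in for LanguageDataset

// Hypothetical decoder that borrows its collaborators instead of owning them.
class BoundDecoder
{
public:
    BoundDecoder(Model& new_model, const Vocabulary& new_vocabulary)
        : model(new_model), vocabulary(new_vocabulary) {}

    // Copying is disabled, mirroring the deleted members documented above.
    BoundDecoder(const BoundDecoder&) = delete;
    BoundDecoder& operator=(const BoundDecoder&) = delete;
    ~BoundDecoder() = default;

    const std::string& model_name() const { return model.name; }
    std::size_t vocabulary_size() const { return vocabulary.size; }

private:
    Model& model;                  // borrowed: must outlive the decoder
    const Vocabulary& vocabulary;  // borrowed: must outlive the decoder
};

static_assert(!std::is_copy_constructible<BoundDecoder>::value, "copy construction is deleted");
static_assert(!std::is_copy_assignable<BoundDecoder>::value, "copy assignment is deleted");

// Usage: because the decoder holds a reference, it sees later changes to the model.
bool observes_model_changes()
{
    Model model{"tiny"};
    Vocabulary vocabulary{10};
    BoundDecoder decoder(model, vocabulary);
    assert(decoder.vocabulary_size() == 10);

    model.name = "renamed";
    return decoder.model_name() == "renamed";
}
```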

Member Function Documentation

◆ chat() [1/2]

void opennn::TransformerDecoder::chat ( )

Runs an interactive REPL: reads prompts from cin, streams predictions to cout, and exits on an empty line or end of input (Ctrl+D).
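
The loop described here can be sketched with standard streams. The run_repl function and its respond parameter are hypothetical; they stand in for the decoder and are not OpenNN code. Taking the streams as parameters, rather than using cin and cout directly, is a choice made only so the sketch can be exercised with string streams.

```cpp
#include <cassert>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>

// Hypothetical REPL loop (not the OpenNN implementation).
// Reads one prompt per line and writes the response for each prompt.
// Returns the number of prompts that were answered.
int run_repl(std::istream& in, std::ostream& out,
             const std::function<std::string(const std::string&)>& respond)
{
    assert(respond && "respond must hold a callable");

    int handled = 0;
    std::string prompt;

    // std::getline fails at end of input (Ctrl+D on a Unix terminal), ending the loop.
    while (std::getline(in, prompt))
    {
        if (prompt.empty()) break;  // an empty line also ends the session

        out << respond(prompt) << '\n';
        ++handled;
    }
    return handled;
}
```

Calling run_repl(std::cin, std::cout, respond) would give the interactive behavior described above.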

◆ chat() [2/2]

void opennn::TransformerDecoder::chat ( const SamplingConfig & config)

Runs the interactive REPL with the supplied sampling configuration.

◆ decode() [1/4]

string opennn::TransformerDecoder::decode ( const string & source)

Generates a completion for the given source using the default sampling configuration.

◆ decode() [2/4]

string opennn::TransformerDecoder::decode ( const string & source,
const SamplingConfig & config )

Generates a completion for the given source using the supplied sampling configuration.

◆ decode() [3/4]

string opennn::TransformerDecoder::decode ( const string & source,
const SamplingConfig & config,
const TokenCallback & on_token )

Generates a completion using the given sampling configuration and streams tokens via the callback.

◆ decode() [4/4]

string opennn::TransformerDecoder::decode ( const string & source,
const TokenCallback & on_token )

Generates a completion for the given source and invokes the callback for each emitted token.

◆ decode_to_stream() [1/2]

string opennn::TransformerDecoder::decode_to_stream ( const string & source,
const SamplingConfig & config,
ostream & out )

Generates a completion using the given sampling configuration and writes each token to the stream.

◆ decode_to_stream() [2/2]

string opennn::TransformerDecoder::decode_to_stream ( const string & source,
ostream & out )

Generates a completion for the given source and writes each emitted token to the output stream.
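
A stream-writing overload can be expressed in terms of a callback-taking one by passing a lambda that writes each token to the stream. Whether OpenNN implements decode_to_stream this way is not stated on this page; fake_decode and fake_decode_to_stream below are hypothetical stand-ins that treat whitespace-separated words as "tokens".

```cpp
#include <cassert>
#include <functional>
#include <ostream>
#include <sstream>
#include <string>

using TokenCallback = std::function<void(const std::string& token)>;

// Hypothetical stand-in for the callback overload of decode():
// emits each whitespace-separated word of the source, followed by a space.
std::string fake_decode(const std::string& source, const TokenCallback& on_token)
{
    assert(on_token && "on_token must hold a callable");

    std::string completion;
    std::istringstream words(source);
    std::string word;
    while (words >> word)
    {
        const std::string token = word + ' ';
        on_token(token);
        completion += token;
    }
    return completion;
}

// Hypothetical stream overload: forwards to the callback overload.
std::string fake_decode_to_stream(const std::string& source, std::ostream& out)
{
    return fake_decode(source, [&out](const std::string& token) {
        out << token << std::flush;  // flush so each token is visible immediately
    });
}
```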

◆ operator=()

TransformerDecoder & opennn::TransformerDecoder::operator= ( const TransformerDecoder & )
delete