OpenNN
Open-source neural networks library
Loading...
Searching...
No Matches
tensor_utilities.h
Go to the documentation of this file.
1// OpenNN: Open Neural Networks Library
2// www.opennn.net
3//
4// T E N S O R U T I L I T I E S C L A S S H E A D E R
5//
6// Artificial Intelligence Techniques SL
7// artelnics@artelnics.com
8
9#pragma once
10
11#include "pch.h"
12#include "configuration.h"
13
14namespace opennn
15{
16
17static constexpr Index ALIGN_BYTES = EIGEN_MAX_ALIGN_BYTES;
18static constexpr Index ALIGN_ELEMENTS = ALIGN_BYTES / sizeof(float);
19
20[[nodiscard]] inline int to_int(Index value) { return static_cast<int>(value); }
21[[nodiscard]] inline float to_type(Index value) { return static_cast<float>(value); }
22
23[[nodiscard]] inline Index align_up(Index value, Index alignment)
24{
25 return value == 0 ? 0 : (value + alignment - 1) & ~(alignment - 1);
26}
27
28[[nodiscard]] inline Index get_aligned_size(Index size) { return align_up(size, ALIGN_ELEMENTS); }
29[[nodiscard]] inline Index get_aligned_bytes(Index n_bytes) { return align_up(n_bytes, ALIGN_BYTES); }
30[[nodiscard]] inline Index get_aligned_bytes(Index count, Type dtype) { return get_aligned_bytes(count * type_bytes(dtype)); }
31
32[[nodiscard]] inline bool is_aligned(const void* ptr)
33{
34 return reinterpret_cast<uintptr_t>(ptr) % ALIGN_BYTES == 0;
35}
36
39
/// Tensor dimensions stored inline in a fixed-size array (at most MaxRank
/// entries). Only the first `rank` entries of `dims` are meaningful.
struct Shape
{
    /// Capacity of the inline dimension array.
    static constexpr size_t MaxRank = 4;

    Index dims[MaxRank] = {0};
    size_t rank = 0;

    /// Rank-0 (empty) shape.
    Shape() noexcept = default;

    /// Shape of rank `new_rank` with every extent set to `value`.
    /// Throws runtime_error when `new_rank` exceeds MaxRank.
    Shape(size_t new_rank, Index value) : rank(new_rank)
    {
        if (new_rank > MaxRank)
            throw runtime_error(format("Shape: rank {} exceeds MaxRank={}.",
                                       new_rank, MaxRank));

        for (size_t axis = 0; axis < rank; ++axis)
            dims[axis] = value;
    }

    /// Shape whose extents are taken in order from `list`, e.g. Shape{rows, cols}.
    /// Throws runtime_error when the list holds more than MaxRank values.
    Shape(initializer_list<Index> list) : rank(list.size())
    {
        if (list.size() > MaxRank)
            throw runtime_error(format("Shape: initializer rank {} exceeds MaxRank={}.",
                                       list.size(), MaxRank));

        size_t axis = 0;
        for (const Index extent : list)
            dims[axis++] = extent;
    }

    // Iteration covers only the active dimensions [0, rank).
    [[nodiscard]] const Index* begin() const noexcept { return &dims[0]; }
    [[nodiscard]] const Index* end() const noexcept { return &dims[0] + rank; }

    // Unchecked element access: no bounds check against rank or MaxRank.
    [[nodiscard]] const Index& operator[](size_t i) const noexcept { return dims[i]; }
    Index& operator[](size_t i) noexcept { return dims[i]; }

    /// Last active dimension; throws runtime_error on a rank-0 shape.
    Index& back()
    {
        if (empty())
            throw runtime_error("Shape::back() on empty");
        return dims[rank - 1];
    }

    /// Const overload of back(); same empty-shape check.
    [[nodiscard]] const Index& back() const
    {
        if (empty())
            throw runtime_error("Shape::back() on empty");
        return dims[rank - 1];
    }

    /// True when the shape has no dimensions.
    [[nodiscard]] bool empty() const noexcept { return rank == 0; }

    /// dims[i] for an active axis, 0 for any axis at or beyond rank.
    [[nodiscard]] Index dim_or_zero(size_t i) const noexcept
    {
        return i >= rank ? Index(0) : dims[i];
    }

    /// Element count: product of the active extents, or 0 for a rank-0 shape.
    [[nodiscard]] Index size() const noexcept
    {
        if (empty())
            return 0;

        Index product = 1;
        for (const Index extent : *this)
            product *= extent;
        return product;
    }

    /// Drops all dimensions (rank becomes 0); the `dims` array is left untouched.
    void clear() noexcept { rank = 0; }

    /// Appends one dimension; silently ignored once rank has reached MaxRank.
    void push_back(Index value) noexcept
    {
        if (rank >= MaxRank)
            return;
        dims[rank] = value;
        ++rank;
    }

    /// Prints as "[ d0, d1, ... ]" ("[ ]" for a rank-0 shape).
    friend ostream& operator<<(ostream& os, const Shape& shape)
    {
        os << "[";
        const char* separator = " ";
        for (size_t axis = 0; axis < shape.rank; ++axis)
        {
            os << separator << shape.dims[axis];
            separator = ", ";
        }
        os << " ]";
        return os;
    }

    /// Equal when the ranks match and every active extent matches.
    [[nodiscard]] bool operator==(const Shape& other) const noexcept
    {
        if (rank != other.rank)
            return false;
        for (size_t axis = 0; axis < rank; ++axis)
            if (dims[axis] != other.dims[axis])
                return false;
        return true;
    }

    /// Appends `other`'s dimensions, keeping only as many as fit below MaxRank.
    Shape& append(const Shape& other)
    {
        const size_t room = MaxRank - rank;
        const size_t count = other.rank < room ? other.rank : room;
        for (size_t offset = 0; offset < count; ++offset)
            dims[rank + offset] = other.dims[offset];
        rank += count;
        return *this;
    }
};
114
121
122[[nodiscard]] inline Index get_aligned_size(const vector<TensorSpec>& specs)
123{
124 return transform_reduce(specs.begin(), specs.end(), Index(0), plus<>{},
125 [](const auto& spec) { return get_aligned_size(spec.shape.size()); });
126}
127
128[[nodiscard]] inline Index get_aligned_size(const vector<vector<TensorSpec>>& specs)
129{
130 return transform_reduce(specs.begin(), specs.end(), Index(0), plus<>{},
131 [](const auto& s) { return get_aligned_size(s); });
132}
133
134[[nodiscard]] inline Index get_aligned_bytes(const vector<TensorSpec>& specs)
135{
136 return transform_reduce(specs.begin(), specs.end(), Index(0), plus<>{},
137 [](const auto& spec) { return get_aligned_bytes(spec.shape.size(), spec.dtype); });
138}
139
140[[nodiscard]] inline Index get_aligned_bytes(const vector<vector<TensorSpec>>& specs)
141{
142 return transform_reduce(specs.begin(), specs.end(), Index(0), plus<>{},
143 [](const auto& s) { return get_aligned_bytes(s); });
144}
145
146[[nodiscard]] inline Index get_aligned_bytes(const vector<Shape>& shapes, Type dtype)
147{
148 return transform_reduce(shapes.begin(), shapes.end(), Index(0), plus<>{},
149 [dtype](const Shape& s) { return get_aligned_bytes(s.size(), dtype); });
150}
151
152[[nodiscard]] inline Index get_aligned_bytes(const vector<TensorSpec>& specs, Type dtype)
153{
154 return transform_reduce(specs.begin(), specs.end(), Index(0), plus<>{},
155 [dtype](const auto& spec) { return get_aligned_bytes(spec.shape.size(), dtype); });
156}
157
158[[nodiscard]] inline Index get_aligned_bytes(const vector<vector<TensorSpec>>& specs, Type dtype)
159{
160 return transform_reduce(specs.begin(), specs.end(), Index(0), plus<>{},
161 [dtype](const auto& s) { return get_aligned_bytes(s, dtype); });
162}
163
165struct Buffer
166{
167 void* data = nullptr;
168 Index bytes = 0;
170
172 template<typename T> [[nodiscard]] T* as() { return static_cast<T*>(data); }
174 template<typename T> [[nodiscard]] const T* as() const { return static_cast<const T*>(data); }
175
177 [[nodiscard]] Index size_in_floats() const { return bytes / Index(sizeof(float)); }
179 [[nodiscard]] bool empty() const { return bytes == 0; }
180
182 void resize_bytes(Index new_bytes, Device new_device_type)
183 {
184 if (new_bytes == bytes && device_type == new_device_type) return;
185 free_buffer();
186 if (new_bytes == 0) return;
187
188 data = alloc(new_device_type, new_bytes);
189 device_type = new_device_type;
190 bytes = new_bytes;
191 }
192
194 void grow_to(Index new_bytes)
195 {
196 if (new_bytes > bytes)
197 resize_bytes(new_bytes, device_type);
198 }
199
201 template<typename T>
202 T* ensure(Index n_elements)
203 {
204 grow_to(n_elements * Index(sizeof(T)));
205 return as<T>();
206 }
207
209 void setZero()
210 {
211 if (!data) return;
212#ifdef OPENNN_HAS_CUDA
214 {
215 CHECK_CUDA(cudaMemset(data, 0, bytes));
216 return;
217 }
218#endif
219 memset(data, 0, static_cast<size_t>(bytes));
220 }
221
#ifdef OPENNN_HAS_CUDA
    /// Moves the buffer's storage to `target`, copying its contents between host and
    /// device. The copy has finished when this returns: with a `stream` the copy is
    /// enqueued there and the stream is synchronized; otherwise a blocking cudaMemcpy
    /// is used. An empty buffer has nothing to copy and is only retargeted.
    void migrate_to(Device target, cudaStream_t stream = nullptr)
    {
        if (device_type == target) return;

        if (!data)
        {
            // Retarget so later grow_to()/ensure() allocate on `target` rather than
            // on the device the empty buffer was created for.
            device_type = target;
            return;
        }

        void* fresh = alloc(target, bytes);
        const cudaMemcpyKind kind = (target == Device::CUDA)
            ? cudaMemcpyHostToDevice : cudaMemcpyDeviceToHost;

        try
        {
            if (stream)
            {
                CHECK_CUDA(cudaMemcpyAsync(fresh, data, bytes, kind, stream));
                CHECK_CUDA(cudaStreamSynchronize(stream));
            }
            else
            {
                CHECK_CUDA(cudaMemcpy(fresh, data, bytes, kind));
            }
        }
        catch (...)
        {
            // NOTE(review): assumes CHECK_CUDA reports failure by throwing. Free the
            // new allocation so a failed copy does not leak it; the original storage
            // is left untouched.
            dealloc(target, fresh, bytes);
            throw;
        }

        dealloc(device_type, data, bytes);
        data = fresh;
        device_type = target;
    }
#endif
247
249 explicit Buffer(Device new_device_type = Device::CPU) noexcept : device_type(new_device_type) {}
250 Buffer(const Buffer&) = delete;
251 Buffer& operator=(const Buffer&) = delete;
252
253 Buffer(Buffer&& other) noexcept : Buffer() { swap(other); }
254 Buffer& operator=(Buffer&& other) noexcept { swap(other); return *this; }
255
256 ~Buffer() { free_buffer(); }
257
259 void swap(Buffer& other) noexcept
260 {
261 std::swap(data, other.data);
262 std::swap(bytes, other.bytes);
263 std::swap(device_type, other.device_type);
264 }
265
266private:
267 static void* alloc([[maybe_unused]] Device device_type, Index byte_count)
268 {
269#ifdef OPENNN_HAS_CUDA
270 if (device_type == Device::CUDA) { void* device_pointer = nullptr; CHECK_CUDA(cudaMalloc(&device_pointer, byte_count)); return device_pointer; }
271#endif
272 return Eigen::aligned_allocator<uint8_t>{}.allocate(static_cast<size_t>(byte_count));
273 }
274
275 static void dealloc([[maybe_unused]] Device device_type, void* pointer, Index byte_count)
276 {
277#ifdef OPENNN_HAS_CUDA
278 if (device_type == Device::CUDA) { cudaFree(pointer); return; }
279#endif
280 Eigen::aligned_allocator<uint8_t>{}.deallocate(static_cast<uint8_t*>(pointer), static_cast<size_t>(byte_count));
281 }
282
283 void free_buffer()
284 {
285 if (data) dealloc(device_type, data, bytes);
286 data = nullptr;
287 bytes = 0;
288 }
289};
290
293{
294 void* data = nullptr;
295
297
299
301 TensorView(void* new_data = nullptr, const Shape& new_shape = {},
302 Type new_dtype = Type::FP32) noexcept
303 : data(new_data), shape(new_shape), type(new_dtype) {}
304
306 [[nodiscard]] Index get_rank() const noexcept { return shape.rank; }
307
309 [[nodiscard]] Index size() const noexcept { return shape.size(); }
310
312 [[nodiscard]] Index byte_size() const noexcept { return size() * type_bytes(type); }
313
315 [[nodiscard]] bool empty() const noexcept { return shape.empty(); }
316
318 template<typename T>
319 [[nodiscard]] T* as() const noexcept
320 {
321 assert(data);
322 return reinterpret_cast<T*>(data);
323 }
324
326 [[nodiscard]] float* as_float() const noexcept
327 {
328 return reinterpret_cast<float*>(data);
329 }
330
332 [[nodiscard]] cudaDataType_t cuda_dtype() const noexcept { return to_cuda(type); }
333
335 template<typename F>
336 void dispatch(F&& fn) const
337 {
339 {
340 fn(typename decltype(info)::type{});
341 });
342 }
343
345 [[nodiscard]] TensorView reshape(const Shape& new_shape) const
346 { return TensorView(data, new_shape, type); }
347
349 [[nodiscard]] MatrixMap as_matrix() const
350 {
351 assert(shape.rank >= 2);
352 return MatrixMap(as<float>(), shape[0], shape.size() / shape[0]);
353 }
354
356 [[nodiscard]] MatrixMap as_matrix(Index batch_index) const
357 {
358 assert(shape.rank >= 2);
359 const Index rows = shape[shape.rank - 2];
360 const Index cols = shape[shape.rank - 1];
361 return MatrixMap(as<float>() + batch_index * rows * cols, rows, cols);
362 }
363
365 [[nodiscard]] MatrixMap as_flat_matrix() const
366 {
367 assert(shape.rank >= 1);
368 const Index cols = shape[shape.rank - 1];
369 return MatrixMap(as<float>(), shape.size() / cols, cols);
370 }
371
373 [[nodiscard]] MatrixMap as_flat_matrix(Index batch_index) const
374 {
375 assert(shape.rank >= 2);
376 const Index cols = shape[shape.rank - 1];
377 const Index rows = shape.size() / (shape[0] * cols);
378 return MatrixMap(as<float>() + batch_index * rows * cols, rows, cols);
379 }
380
382 [[nodiscard]] VectorMap as_vector() const
383 {
384 return VectorMap(as<float>(), shape.size());
385 }
386
388 template<int Rank>
389 [[nodiscard]] TensorMapR<Rank> as_tensor() const
390 {
391 assert(shape.rank == Rank);
392 Eigen::array<Index, Rank> dims;
393 copy_n(shape.dims, Rank, dims.begin());
394 return TensorMapR<Rank>(as<float>(), dims);
395 }
396
398 template<int Rank>
399 [[nodiscard]] TensorMapR<Rank> as_tensor(Index batch_index) const
400 {
401 assert(shape.rank == Rank + 1);
402 Eigen::array<Index, Rank> dims;
403 for (int i = 0; i < Rank; ++i) dims[i] = shape[i + 1];
404 const Index slice_size = shape.size() / shape[0];
405 return TensorMapR<Rank>(as<float>() + batch_index * slice_size, dims);
406 }
407
409 void fill(float value);
411 void setZero() { fill(0.0f); }
412
413#ifdef OPENNN_HAS_CUDA
414 void set_zero_async() const;
415
416 mutable shared_ptr<cudnnTensorStruct> descriptor_handle = nullptr;
417
418 cudnnTensorDescriptor_t get_descriptor() const
419 {
420 if (!descriptor_handle && !shape.empty())
421 set_descriptor(shape);
422 return descriptor_handle.get();
423 }
424
425private:
426 void set_descriptor(const Shape& shape) const
427 {
428 // NHWC layout: rank < 4 leading dims default to 1.
429 int batch_count = 1, channels = 1, height = 1, width = 1;
430 const size_t rank = shape.rank;
431 if (rank >= 1) channels = static_cast<int>(shape[rank - 1]);
432 if (rank >= 2) batch_count = static_cast<int>(shape[0]);
433 if (rank >= 3) width = static_cast<int>(shape[rank - 2]);
434 if (rank >= 4) height = static_cast<int>(shape[rank - 3]);
435
436 if (batch_count <= 0 || channels <= 0 || height <= 0 || width <= 0)
437 return;
438
439 if (!descriptor_handle)
440 {
442 CHECK_CUDNN(cudnnCreateTensorDescriptor(&raw_desc));
443
444 descriptor_handle = shared_ptr<cudnnTensorStruct>(raw_desc, [](cudnnTensorDescriptor_t descriptor) {
445 cudnnDestroyTensorDescriptor(descriptor);
446 });
447 }
448
449 CHECK_CUDNN(cudnnSetTensor4dDescriptor(descriptor_handle.get(), CUDNN_TENSOR_NHWC, to_cudnn(type), batch_count, channels, height, width));
450 }
451
452#endif
453
454};
455
456inline TensorView& view_at_slot_or(vector<TensorView>& views,
457 const vector<size_t>& slots, size_t i,
458 TensorView& fallback)
459{
460 return i < slots.size() ? views[slots[i]] : fallback;
461}
462
463inline TensorView& view_at_slot_or(vector<vector<TensorView>>& views,
464 const vector<size_t>& slots, size_t i,
465 TensorView& fallback)
466{
467 return i < slots.size() ? views[slots[i]][0] : fallback;
468}
469
470template<typename T, size_t N>
471using array = Eigen::array<T, N>;
472
474[[nodiscard]] string shape_to_string(const Shape&, const string& = " ");
476[[nodiscard]] Shape string_to_shape(const string&, const string& = " ");
477
// Boost-style hash combine: folds the std::hash of each argument, left to right,
// into one size_t seed starting from 0. Used for plan/graph cache keys
// (cuBLASLt, cuDNN SDPA). Returns 0 when called with no arguments.
template<typename... Vs>
[[nodiscard]] size_t hash_combine(const Vs&... values)
{
    size_t seed = 0;
    [[maybe_unused]] const auto mix = [&seed](size_t value_hash)
    {
        seed ^= value_hash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    };
    (mix(std::hash<Vs>{}(values)), ...);
    return seed;
}
488
490class Backend
491{
492public:
493
495 static Backend& instance();
496
498 ThreadPoolDevice* get_thread_pool_device();
499
501 void set_threads_number(int num_threads);
502
504 static cublasHandle_t get_cublas_handle() { return instance().cublas_handle; }
506 static cublasLtHandle_t get_cublas_lt_handle() { return instance().cublas_lt_handle; }
508 static cudnnHandle_t get_cudnn_handle() { return instance().cudnn_handle; }
510 static cudaStream_t get_compute_stream() { return instance().compute_stream; }
512 static cudnnOpTensorDescriptor_t get_operator_sum_descriptor() { return instance().operator_sum_descriptor; }
513
514private:
515 Backend();
516 ~Backend();
517
518 unique_ptr<ThreadPool> thread_pool;
519 unique_ptr<ThreadPoolDevice> thread_pool_device;
520
521 cublasHandle_t cublas_handle = nullptr;
522 cublasLtHandle_t cublas_lt_handle = nullptr;
523 cudnnHandle_t cudnn_handle = nullptr;
524 cudaStream_t compute_stream = nullptr;
525 cudnnOpTensorDescriptor_t operator_sum_descriptor = nullptr;
526};
527
529inline ThreadPoolDevice& get_device()
530{
532}
533
534inline void TensorView::fill(float value)
535{
536 if (!data) return;
537
538#ifdef OPENNN_HAS_CUDA
539 // Probe the pointer: set_parameters_random() may call fill() on host-resident
540 // TensorViews even when Device is already GPU.
541 cudaPointerAttributes attr{};
542 const cudaError_t err = cudaPointerGetAttributes(&attr, data);
543 const bool gpu_data = (err == cudaSuccess) && (attr.type == cudaMemoryTypeDevice);
544 if (err != cudaSuccess) cudaGetLastError(); // clear sticky error from CPU pointer probe
545
546 if (gpu_data)
547 {
548 if (value == 0.0f)
549 {
550 CHECK_CUDA(cudaMemset(data, 0, byte_size()));
551 return;
552 }
553
554 CHECK_CUDNN(cudnnSetTensor(Backend::get_cudnn_handle(),
555 get_descriptor(), data, &value));
556 return;
557 }
558#endif
559
560 assert(type == Type::FP32);
561 float* data_pointer = static_cast<float*>(data);
562 std::fill(data_pointer, data_pointer + size(), value);
563}
564
565#ifdef OPENNN_HAS_CUDA
566
567inline void TensorView::set_zero_async() const
568{
569 if (!data || byte_size() == 0) return;
570 CHECK_CUDA(cudaMemsetAsync(data, 0, byte_size(), Backend::get_compute_stream()));
571}
572
573inline const float one = 1.0f;
574inline const float zero = 0.0f;
575
576// Sync D2H copy of a contiguous device buffer into an FP32 host buffer,
577// upcasting BF16 -> FP32 inline. Blocks on `stream`. Throws on unsupported
578// dtype. Allocates a uint16_t staging buffer for BF16 sources, so callers on
579// hot paths may prefer an inline version with a pre-allocated staging buffer.
580void copy_device_to_host_float(const void* device_src, Type src_dtype,
581 Index element_count, float* host_dst,
582 cudaStream_t stream);
583
584#endif
585
586}
587
588// OpenNN: Open Neural Networks Library.
589// Copyright(C) 2005-2026 Artificial Intelligence Techniques, SL.
590// Licensed under the GNU Lesser General Public License v2.1 or later.
Process-wide singleton that owns the thread pool and the cuBLAS/cuDNN handles.
Definition tensor_utilities.h:491
static cudaStream_t get_compute_stream()
Default CUDA stream used by the compute backend.
Definition tensor_utilities.h:510
static Backend & instance()
Returns the global Backend instance.
static cudnnHandle_t get_cudnn_handle()
Shared cuDNN handle.
Definition tensor_utilities.h:508
static cublasHandle_t get_cublas_handle()
Shared cuBLAS handle for legacy GEMM calls.
Definition tensor_utilities.h:504
static cublasLtHandle_t get_cublas_lt_handle()
Shared cuBLASLt handle for batched/tuned GEMMs.
Definition tensor_utilities.h:506
static cudnnOpTensorDescriptor_t get_operator_sum_descriptor()
cuDNN op-tensor descriptor configured for elementwise sum.
Definition tensor_utilities.h:512
ThreadPoolDevice * get_thread_pool_device()
Returns the Eigen ThreadPoolDevice used for CPU tensor evaluations.
void set_threads_number(int num_threads)
Reconfigures the underlying thread pool to use num_threads workers.
Definition adaptive_moment_estimation.h:14
bool is_aligned(const void *ptr)
Definition tensor_utilities.h:32
size_t hash_combine(const Vs &... values)
Boost-style hash combine that mixes any number of hashable values into a single size_t.
Definition tensor_utilities.h:482
constexpr cudaDataType_t CUDA_REDUCTION_DTYPE
Definition tensor_utilities.h:37
float to_type(Index value)
Definition tensor_utilities.h:21
Index type_bytes(Type type) noexcept
Returns the byte size of one element of the given OpenNN Type.
Definition configuration.h:111
constexpr cublasComputeType_t CUBLAS_COMPUTE_DTYPE
Definition tensor_utilities.h:38
Device
Execution device selection for OpenNN runtime (auto-detected, CPU or CUDA GPU).
Definition configuration.h:17
@ CPU
Definition configuration.h:17
@ CUDA
Definition configuration.h:17
Index get_aligned_bytes(Index n_bytes)
Definition tensor_utilities.h:29
Eigen::array< T, N > array
Definition tensor_utilities.h:471
Shape string_to_shape(const string &, const string &=" ")
Parses a separator-joined string of dimensions into a Shape.
ThreadPoolDevice & get_device()
Convenience accessor for the global Eigen ThreadPoolDevice.
Definition tensor_utilities.h:529
cudnnDataType_t to_cudnn(Type type) noexcept
Returns the cuDNN data type matching the given OpenNN Type (Auto resolves to FP32).
Definition configuration.h:83
Index get_aligned_size(Index size)
Definition tensor_utilities.h:28
Type
Numeric precision used for training or inference tensors.
Definition configuration.h:20
@ FP32
Definition configuration.h:20
int to_int(Index value)
Definition tensor_utilities.h:20
void visit_type(Type t, F &&f)
Dispatches f with the TypeInfo of the runtime Type t (must be in Supported).
Definition configuration.h:59
string shape_to_string(const Shape &, const string &=" ")
Serializes a shape as a separator-joined string of dimensions.
cudaDataType_t to_cuda(Type type) noexcept
Returns the CUDA data type matching the given OpenNN Type (Auto resolves to FP32).
Definition configuration.h:97
static constexpr Index ALIGN_BYTES
Definition tensor_utilities.h:17
Index align_up(Index value, Index alignment)
Definition tensor_utilities.h:23
static constexpr Index ALIGN_ELEMENTS
Definition tensor_utilities.h:18
TensorView & view_at_slot_or(vector< TensorView > &views, const vector< size_t > &slots, size_t i, TensorView &fallback)
Definition tensor_utilities.h:456
cudnnTensorStruct * cudnnTensorDescriptor_t
Definition pch.h:109
Map< VectorR, AlignedMax > VectorMap
Definition pch.h:185
TensorMap< Tensor< float, Rank, Layout|AlignedMax >, AlignedMax > TensorMapR
Definition pch.h:201
void * cudnnHandle_t
Definition pch.h:86
void * cudaStream_t
Definition pch.h:82
cudaDataType_t
Definition pch.h:93
@ CUDA_R_32F
Definition pch.h:93
void * cublasLtHandle_t
Definition pch.h:85
void * cudnnOpTensorDescriptor_t
Definition pch.h:115
void * cublasHandle_t
Definition pch.h:84
#define EIGEN_MAX_ALIGN_BYTES
Definition pch.h:12
Map< MatrixR, Layout|AlignedMax > MatrixMap
Definition pch.h:186
cublasComputeType_t
Definition pch.h:94
@ CUBLAS_COMPUTE_32F_FAST_TF32
Definition pch.h:94
Buffer & operator=(Buffer &&other) noexcept
Definition tensor_utilities.h:254
bool empty() const
Returns true if no storage is allocated.
Definition tensor_utilities.h:179
Index size_in_floats() const
Capacity expressed in float elements.
Definition tensor_utilities.h:177
void swap(Buffer &other) noexcept
Swaps storage with another buffer.
Definition tensor_utilities.h:259
T * ensure(Index n_elements)
Ensures the buffer holds at least n_elements of T and returns a typed pointer.
Definition tensor_utilities.h:202
Device device_type
Definition tensor_utilities.h:169
void grow_to(Index new_bytes)
Grows the buffer to at least new_bytes; no-op if already large enough.
Definition tensor_utilities.h:194
Buffer(Buffer &&other) noexcept
Definition tensor_utilities.h:253
void setZero()
Zeros all bytes in the buffer (cudaMemset on device, memset on host).
Definition tensor_utilities.h:209
Index bytes
Definition tensor_utilities.h:168
void * data
Definition tensor_utilities.h:167
Buffer(Device new_device_type=Device::CPU) noexcept
Constructs an empty buffer targeting the given device type.
Definition tensor_utilities.h:249
Buffer & operator=(const Buffer &)=delete
~Buffer()
Definition tensor_utilities.h:256
Buffer(const Buffer &)=delete
void resize_bytes(Index new_bytes, Device new_device_type)
Resizes the buffer to new_bytes on new_device_type, freeing prior storage.
Definition tensor_utilities.h:182
T * as()
Reinterprets the buffer as a typed pointer (no bounds checking).
Definition tensor_utilities.h:172
const T * as() const
Reinterprets the buffer as a typed const pointer.
Definition tensor_utilities.h:174
Fixed-capacity small-vector describing tensor dimensions (rank up to MaxRank).
Definition tensor_utilities.h:42
Index & operator[](size_t i) noexcept
Definition tensor_utilities.h:71
const Index * begin() const noexcept
Definition tensor_utilities.h:68
Shape() noexcept=default
friend ostream & operator<<(ostream &os, const Shape &shape)
Definition tensor_utilities.h:92
Index dim_or_zero(size_t i) const noexcept
Returns dims[i] when i is in range, otherwise 0.
Definition tensor_utilities.h:79
static constexpr size_t MaxRank
Definition tensor_utilities.h:43
const Index & back() const
Definition tensor_utilities.h:74
void push_back(Index value) noexcept
Appends a dimension to the shape (silently no-op if already at MaxRank).
Definition tensor_utilities.h:90
bool empty() const noexcept
Definition tensor_utilities.h:76
size_t rank
Definition tensor_utilities.h:46
Shape(initializer_list< Index > list)
Builds a shape from a brace-enclosed list of dimensions.
Definition tensor_utilities.h:60
const Index & operator[](size_t i) const noexcept
Definition tensor_utilities.h:70
Index dims[MaxRank]
Definition tensor_utilities.h:45
const Index * end() const noexcept
Definition tensor_utilities.h:69
bool operator==(const Shape &other) const noexcept
Definition tensor_utilities.h:100
void clear() noexcept
Resets the shape to rank 0 without freeing storage.
Definition tensor_utilities.h:88
Shape & append(const Shape &other)
Appends another shape's dimensions to this one, stopping at MaxRank.
Definition tensor_utilities.h:106
Index size() const noexcept
Returns the number of elements (product of all dimensions).
Definition tensor_utilities.h:82
Index & back()
Definition tensor_utilities.h:73
Lightweight description of a tensor's shape and data type (no storage attached).
Definition tensor_utilities.h:117
Shape shape
Definition tensor_utilities.h:118
Type dtype
Definition tensor_utilities.h:119
Non-owning view over a tensor: pointer, shape, and data type with rich reshape helpers.
Definition tensor_utilities.h:293
MatrixMap as_matrix() const
Maps the view to an Eigen matrix: rows = first dim, cols = product of the rest.
Definition tensor_utilities.h:349
bool empty() const noexcept
Returns true if the shape is empty.
Definition tensor_utilities.h:315
void setZero()
Zeros every element of the view.
Definition tensor_utilities.h:411
Index size() const noexcept
Total element count.
Definition tensor_utilities.h:309
Index byte_size() const noexcept
Total byte count (size() * type_bytes(type)).
Definition tensor_utilities.h:312
MatrixMap as_flat_matrix() const
Maps the view to an Eigen matrix flattened across all leading dimensions.
Definition tensor_utilities.h:365
void fill(float value)
Sets every element of the view to the given value, dispatching CPU/GPU as needed.
Definition tensor_utilities.h:534
TensorMapR< Rank > as_tensor(Index batch_index) const
Maps a single batch slice of the view to an Eigen Tensor of rank Rank.
Definition tensor_utilities.h:399
Index get_rank() const noexcept
Number of dimensions in the view.
Definition tensor_utilities.h:306
VectorMap as_vector() const
Maps the view to a flat Eigen vector.
Definition tensor_utilities.h:382
cudaDataType_t cuda_dtype() const noexcept
Returns the CUDA data type tag corresponding to this view's dtype.
Definition tensor_utilities.h:332
float * as_float() const noexcept
Reinterprets the view's data as a float pointer.
Definition tensor_utilities.h:326
TensorView reshape(const Shape &new_shape) const
Returns a new view over the same memory with a different shape.
Definition tensor_utilities.h:345
Shape shape
Definition tensor_utilities.h:296
T * as() const noexcept
Reinterprets the view's data as a pointer to T (no type checking).
Definition tensor_utilities.h:319
TensorView(void *new_data=nullptr, const Shape &new_shape={}, Type new_dtype=Type::FP32) noexcept
Constructs a view from an external buffer, shape, and dtype.
Definition tensor_utilities.h:301
void * data
Definition tensor_utilities.h:294
TensorMapR< Rank > as_tensor() const
Maps the view to an Eigen Tensor of the given rank.
Definition tensor_utilities.h:389
void dispatch(F &&fn) const
Dispatches a callable on the concrete element type (FP32 or BF16).
Definition tensor_utilities.h:336
MatrixMap as_matrix(Index batch_index) const
Maps a single batch slice of the view to an Eigen matrix.
Definition tensor_utilities.h:356
MatrixMap as_flat_matrix(Index batch_index) const
Flat-matrix view of a single batch slice: dimensions 1 to rank-2 are flattened into rows and the last dimension becomes the columns.
Definition tensor_utilities.h:373
Type type
Definition tensor_utilities.h:298