-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathtensor.h
More file actions
163 lines (127 loc) · 4.69 KB
/
tensor.h
File metadata and controls
163 lines (127 loc) · 4.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
// Copyright (c) 2025, IST Austria, developed by Erik Schultheis
// SPDX-License-Identifier: Apache-2.0
//
#ifndef LLMQ_SRC_UTILS_TENSOR_H
#define LLMQ_SRC_UTILS_TENSOR_H
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include "dtype.h"
#include "utils.h"
constexpr int MAX_TENSOR_DIM = 5;
void throw_dtype_mismatch(ETensorDType expected, ETensorDType actual);
//! \brief The Tensor class represents a contiguous view on memory that is associated
//! with a specific data type and shape.
//!
//! Tensor is a plain aggregate with no constructors or destructor: it never allocates
//! or frees anything. It records a raw pointer to the element data (`Data`), an
//! optional pointer to a statistics buffer (`Stats`), the element dtype, the shape
//! and a device index. Only the first `Rank` entries of `Sizes` belong to the shape;
//! `from_pointer` sets the remaining entries to 1.
struct Tensor {
    //! Element data type. Value-initialized so that a default-constructed Tensor
    //! does not carry an indeterminate dtype.
    ETensorDType DType{};
    //! Extent of each dimension; zero-initialized by default for the same reason.
    std::array<long, MAX_TENSOR_DIM> Sizes{};
    //! First element of the data; nullptr marks an empty tensor (see empty()).
    std::byte* Data = nullptr;
    //! Optional statistics buffer: abs_max() returns Stats, scale() returns Stats + 1.
    float* Stats = nullptr;
    //! Number of meaningful entries in Sizes.
    int Rank = 0;
    //! Device index of the data. NOTE(review): -1 presumably means "host/unassigned" — confirm.
    int Device = -1;

    //! True when no data pointer is set.
    [[nodiscard]] constexpr bool empty() const {
        return Data == nullptr;
    }

    //! Size of the element data in bytes: nelem() * get_dtype_size(DType).
    [[nodiscard]] constexpr std::size_t bytes() const {
        return nelem() * get_dtype_size(DType);
    }

    //! Product of the first Rank sizes; 1 for a rank-0 tensor.
    [[nodiscard]] constexpr std::size_t nelem() const {
        std::size_t sz = 1;
        for(int i = 0; i < Rank; ++i) {
            sz *= Sizes[i];
        }
        return sz;
    }

    //! A Tensor converts to true exactly when it is not empty().
    constexpr explicit operator bool() const {
        return !empty();
    }

    //! this is a debugging function, copying the requested element from the GPU to the CPU
    //! for easier printing. Do **not** use it for anything else!
    template<class TargetType>
    TargetType at(long index) const;

    //! this is a debugging function, printing a few consecutive elements of the tensor.
    //! Do **not** use it for anything else!
    void print_sample(long offset, long count=10) const;

    //! Returns Data reinterpreted as `const TargetType*`.
    //! Reports a mismatch through throw_dtype_mismatch when TargetType's dtype differs
    //! from DType; throws std::logic_error when the tensor is null.
    template<class TargetType>
    [[nodiscard]] const TargetType* get() const {
        if(dtype_from_type<TargetType> != DType) {
            throw_dtype_mismatch(dtype_from_type<TargetType>, DType);
        }
        if(Data == nullptr) {
            throw std::logic_error("Tensor is null");
        }
        return reinterpret_cast<const TargetType*>(Data);
    }

    //! Mutable counterpart of the const get(); same checks.
    template<class TargetType>
    [[nodiscard]] TargetType* get() {
        if(dtype_from_type<TargetType> != DType) {
            throw_dtype_mismatch(dtype_from_type<TargetType>, DType);
        }
        if(Data == nullptr) {
            throw std::logic_error("Tensor is null");
        }
        return reinterpret_cast<TargetType*>(Data);
    }

    //! Like `get`, but returns nullptr for a null tensor. The dtype is only checked
    //! when Data is non-null.
    template<class TargetType>
    [[nodiscard]] TargetType* get_optional() {
        if(Data == nullptr) { return nullptr; }
        if(dtype_from_type<TargetType> != DType) {
            throw_dtype_mismatch(dtype_from_type<TargetType>, DType);
        }
        return reinterpret_cast<TargetType*>(Data);
    }

    //! Const counterpart of get_optional(); same semantics.
    template<class TargetType>
    [[nodiscard]] const TargetType* get_optional() const {
        if(Data == nullptr) { return nullptr; }
        if(dtype_from_type<TargetType> != DType) {
            throw_dtype_mismatch(dtype_from_type<TargetType>, DType);
        }
        return reinterpret_cast<const TargetType*>(Data);
    }

    //! Builds a Tensor over existing memory `ptr` without taking ownership.
    //! \param ptr    start of the element data (stored as-is in Data)
    //! \param device device index stored in Device
    //! \param dtype  element type stored in DType
    //! \param shape  any container with size()/begin()/end() holding the extents;
    //!               its length becomes Rank, and unused trailing Sizes are set to 1
    //! \throws std::runtime_error if shape has more than MAX_TENSOR_DIM entries.
    //! The returned tensor has no statistics buffer (Stats == nullptr).
    template<typename Container>
    static Tensor from_pointer(std::byte* ptr, int device, ETensorDType dtype, const Container& shape)
    {
        // Explicit cast: compare size_t with size_t rather than relying on an
        // implicit signed/unsigned conversion.
        if(shape.size() > static_cast<std::size_t>(MAX_TENSOR_DIM)) {
            throw std::runtime_error("Tensor rank too large");
        }
        // NOTE(review): narrow<> comes from utils.h; presumably a checked conversion.
        int rank = narrow<int>(shape.size());
        std::array<long, MAX_TENSOR_DIM> sizes{};
        std::copy(shape.begin(), shape.end(), sizes.begin());
        std::fill(sizes.begin() + shape.size(), sizes.end(), 1);
        return Tensor{dtype, sizes, ptr, nullptr, rank, device};
    }

    //! Pointer to the first statistics slot (Stats itself); nullptr if Stats is unset.
    //! const-qualified because it does not modify the Tensor, so it works on const tensors.
    float* abs_max() const {
        return Stats;
    }

    //! Pointer to the second statistics slot (Stats + 1), or nullptr if Stats is unset.
    float* scale() const {
        if(Stats == nullptr)
            return nullptr;
        return Stats + 1;
    }
};
void fill_zero(Tensor& dst, cudaStream_t stream);
Tensor slice(const Tensor& src, int dim, long start, long end);
//! \brief A Tensor that is one shard of a larger, logically global tensor.
//!
//! The inherited Tensor fields describe the local shard. GlobalShape records the shape
//! of the full tensor (padded with 1s up to MAX_TENSOR_DIM). ShardIndex and NumShards
//! identify which of how many shards this is.
class TensorShard : public Tensor {
public:
    TensorShard() = default;

    //! Wraps a plain Tensor. Intentionally implicit; defined out of line.
    TensorShard(const Tensor& src); // implicit

    //! Creates shard `idx` of `num` whose local data is `src` and whose full tensor has
    //! shape `global_shape` (any container with size()/begin()/end()).
    //! \throws std::runtime_error if global_shape has more than MAX_TENSOR_DIM entries.
    //! \throws std::logic_error if num <= 0 or idx is outside [0, num).
    //! \throws std::logic_error if the global shape does not hold exactly
    //!         num * src.nelem() elements.
    template<typename Container>
    TensorShard(const Tensor& src, int idx, int num, const Container& global_shape)
        : Tensor(src), GlobalShape{}, ShardIndex(idx), NumShards(num) {
        // GlobalShape is a fixed-size array: without this guard a longer shape would be
        // copied past its end (and begin() + size() below would be out of range).
        if(global_shape.size() > static_cast<std::size_t>(MAX_TENSOR_DIM)) {
            throw std::runtime_error("Tensor rank too large");
        }
        if(num <= 0 || idx < 0 || idx >= num) {
            throw std::logic_error("Invalid shard index");
        }
        std::copy(global_shape.begin(), global_shape.end(), GlobalShape.begin());
        std::fill(GlobalShape.begin() + global_shape.size(), GlobalShape.end(), 1);
        // NumShards > 0 is guaranteed above, so the unsigned conversion is value-preserving.
        if(global_nelem() != src.nelem() * static_cast<std::size_t>(NumShards)) {
            throw std::logic_error("Invalid global shape");
        }
    }

    //! Element count of the full (global) tensor. Defined out of line.
    [[nodiscard]] std::size_t global_nelem() const;
    //! Offset of this shard within the global tensor. Defined out of line.
    //! NOTE(review): presumably measured in elements — confirm in the definition.
    [[nodiscard]] std::ptrdiff_t shard_offset() const;

    //! Shape of the full tensor; zero-initialized for a default-constructed shard.
    std::array<long, MAX_TENSOR_DIM> GlobalShape{};
    //! Index of this shard; defaults to 0 so default construction is deterministic.
    int ShardIndex = 0;
    //! Total number of shards; defaults to 1 (an unsharded tensor).
    int NumShards = 1;
};

//! Presumably returns shard `idx` of `num` for `src` as a TensorShard.
//! NOTE(review): defined out of line; whether it copies or views src is not visible here.
TensorShard shard_view(const Tensor& src, int idx, int num);
#endif //LLMQ_SRC_UTILS_TENSOR_H