Skip to content

Commit bc784df

Browse files
committed
Tensor cuda is complete, and nn cuda is too
1 parent 0129cb1 commit bc784df

49 files changed

Lines changed: 1558 additions & 163 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Makefile

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,15 @@ tensor_cuda:
4646
@cmake --build build --target test_tensor_cuda
4747
@ctest --test-dir build -R "^TensorCUDA" --output-on-failure
4848

49-
.PHONY: nn
50-
nn:
51-
@cmake --build build --target test_nn
52-
@ctest --test-dir build -R "^NN" --output-on-failure
49+
.PHONY: nn_cpu
50+
nn_cpu:
51+
@cmake --build build --target test_nn_cpu
52+
@ctest --test-dir build -R "^NNCPU" --output-on-failure
53+
54+
.PHONY: nn_cuda
55+
nn_cuda:
56+
@cmake --build build --target test_nn_cuda
57+
@ctest --test-dir build -R "^NNCUDA" --output-on-failure
5358

5459
.PHONY: llama
5560
llama:

include/tensor/ops.hpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ template <typename T, typename D>
6666
Tensor<std::remove_const_t<T>, D> slice(const TensorView<T, D>& view, int dim, size_t start,
6767
size_t end);
6868

69+
template <typename T, typename D>
70+
Tensor<std::remove_const_t<T>, D> repeat_interleave(const TensorView<T, D>& view, int dim,
71+
size_t repeats);
72+
6973
template <typename T, typename D>
7074
Tensor<std::remove_const_t<T>, D> sum(const TensorView<T, D>& input, int dim, bool keepdim);
7175
template <typename T, typename D>
@@ -74,9 +78,17 @@ Tensor<std::remove_const_t<T>, D> max(const TensorView<T, D>& input, int dim, bo
7478
template <typename T, typename D>
7579
Tensor<int, D> argmax(const TensorView<T, D>& input, int dim, bool keepdim);
7680

81+
// copies
82+
83+
template <typename TIn, typename TOut, typename D>
84+
Tensor<TOut, D> to(const TensorView<TIn, D>& tensor);
85+
86+
template <typename T, typename D>
87+
Tensor<std::remove_const_t<T>, D> copy(const TensorView<T, D>& tensor);
88+
7789
// mutations
7890

7991
template <typename T, typename D>
80-
void replace_from_(Tensor<T, D>& out, const TensorView<T, D>& input);
92+
void replace_from_(Tensor<T, D>& destination, const TensorView<T, D>& source);
8193

8294
} // namespace tensor

include/tensor/storage.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ template <typename T> class TensorStorage<T, CUDA> {
132132
return data_;
133133
}
134134

135+
T operator[](size_t idx);
136+
const T operator[](size_t idx) const;
137+
135138
void resize(size_t size);
136139
void fill(T value);
137140
};
@@ -169,6 +172,9 @@ template <typename T> class TensorStorage<const T, CUDA> {
169172
return data_;
170173
}
171174

175+
T operator[](size_t idx);
176+
const T operator[](size_t idx) const;
177+
172178
void resize(size_t size);
173179
};
174180
#endif

include/tensor/tensor.hpp

Lines changed: 36 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,28 @@ template <DType T, Device D> struct TensorView {
134134
return std::span<const T>(data, data_size);
135135
}
136136

137+
#ifdef TENSOR_HAS_CUDA
138+
/// Reads the element at flat offset `idx` from device memory and returns it by value.
///
/// A single element is copied to the host with cudaMemcpy
/// (cudaMemcpyDeviceToHost), so every call performs a device-to-host transfer.
///
/// @param idx  Offset into the view's underlying buffer; valid range is [0, data_size).
/// @return     Host-side copy of the element.
/// @throws std::out_of_range   if `idx` is negative or not less than `data_size`.
/// @throws std::runtime_error  if the device-to-host copy reports an error.
T operator[](int idx) const
    requires std::same_as<D, device::CUDA>
{
    // Valid offsets are [0, data_size): idx == data_size is already one past the end.
    if (idx < 0 || static_cast<size_t>(idx) >= data_size) {
        throw std::out_of_range("cannot index past the tensor size");
    }

    // T is const-qualified for read-only views, so stage the copy through a
    // mutable, value-initialised temporary.
    std::remove_const_t<T> value{};
    const cudaError_t status =
        cudaMemcpy(&value, data + idx, sizeof(T), cudaMemcpyDeviceToHost);
    if (status != cudaSuccess) {
        // Surface the failure instead of returning whatever `value` happens to hold.
        throw std::runtime_error(cudaGetErrorString(status));
    }
    return value;
}
148+
#endif
149+
150+
/// Returns the element at flat offset `idx` of the view's underlying host buffer.
///
/// @param idx  Offset into the buffer; valid range is [0, data_size).
/// @return     Copy of the element.
/// @throws std::out_of_range  if `idx` is negative or not less than `data_size`.
T operator[](int idx) const
    requires std::same_as<D, device::CPU>
{
    // Valid offsets are [0, data_size): idx == data_size is already one past the end.
    if (idx < 0 || static_cast<size_t>(idx) >= data_size) {
        throw std::out_of_range("cannot index past the tensor size");
    }
    return data[idx];
}
158+
137159
[[nodiscard]] size_t total_elements() const {
138160
size_t out = 1;
139161
for (auto dim : shape) {
@@ -241,46 +263,6 @@ template <DType T, Device D> struct TensorView {
241263
transpose(0, 1);
242264
}
243265

244-
Tensor<std::remove_const_t<T>, D> repeat_interleave(size_t dim, size_t repeats) const {
245-
assert(dim < shape.size());
246-
247-
Shape temp_shape;
248-
Shape temp_stride;
249-
250-
for (size_t dim_ = 0; dim_ <= dim; ++dim_) {
251-
temp_shape.push_back(shape[dim_]);
252-
temp_stride.push_back(stride[dim_]);
253-
}
254-
255-
temp_shape.push_back(repeats);
256-
temp_stride.push_back(0);
257-
258-
for (size_t dim_ = dim + 1; dim_ < shape.size(); ++dim_) {
259-
temp_shape.push_back(shape[dim_]);
260-
temp_stride.push_back(stride[dim_]);
261-
}
262-
263-
size_t temp_size = 1;
264-
for (auto dim_ : temp_shape) {
265-
temp_size *= dim_;
266-
}
267-
268-
TensorView temp_view{data, temp_size, temp_shape, temp_stride};
269-
270-
Tensor<T, D> materialized = temp_view.copy();
271-
272-
Shape final_shape;
273-
for (size_t dim_ = 0; dim_ < shape.size(); ++dim_) {
274-
if (dim_ == dim) {
275-
final_shape.push_back(shape[dim_] * repeats); // Expanded dimension
276-
} else {
277-
final_shape.push_back(shape[dim_]);
278-
}
279-
}
280-
281-
return materialized.view().reshape(final_shape);
282-
}
283-
284266
[[nodiscard]] bool is_contiguous() const {
285267
if (shape.empty()) {
286268
return true;
@@ -296,7 +278,10 @@ template <DType T, Device D> struct TensorView {
296278
return true;
297279
}
298280

299-
template <DType OutT, typename Func> Tensor<OutT, D> map(Func func) const {
281+
template <DType OutT, typename Func>
282+
Tensor<OutT, D> map(Func func) const
283+
requires std::same_as<D, device::CPU>
284+
{
300285
Tensor<OutT, D> result{shape};
301286

302287
auto result_span = result.span();
@@ -317,7 +302,10 @@ template <DType T, Device D> struct TensorView {
317302
return result;
318303
}
319304

320-
template <typename Func> void each(Func func) const {
305+
template <typename Func>
306+
void each(Func func) const
307+
requires std::same_as<D, device::CPU>
308+
{
321309
size_t total_elems = total_elements();
322310

323311
for (size_t linear_idx = 0; linear_idx < total_elems; ++linear_idx) {
@@ -332,10 +320,6 @@ template <DType T, Device D> struct TensorView {
332320
}
333321
}
334322

335-
template <DType OutT> Tensor<OutT, D> to() const {
336-
return map<OutT>([](T val) { return static_cast<OutT>(val); });
337-
}
338-
339323
void check_for_nans() const {
340324
for (size_t i = 0; i < span().size(); ++i) {
341325
if (std::isnan(span()[i])) {
@@ -349,10 +333,6 @@ template <DType T, Device D> struct TensorView {
349333
}
350334
}
351335

352-
Tensor<std::remove_const_t<T>, D> copy() const {
353-
return map<std::remove_const_t<T>>([](T val) { return val; });
354-
}
355-
356336
Tensor<std::remove_const_t<T>, D> contiguous() const {
357337
Tensor<std::remove_const_t<T>, D> result{shape};
358338
auto dst_span = result.span();
@@ -406,18 +386,6 @@ template <DType T, Device D> struct TensorView {
406386
return out;
407387
}
408388

409-
Tensor<std::remove_const_t<T>, D> cos() const {
410-
return map<std::remove_const_t<T>>([](T val) { return std::cos(val); });
411-
}
412-
413-
Tensor<std::remove_const_t<T>, D> sin() const {
414-
return map<std::remove_const_t<T>>([](T val) { return std::sin(val); });
415-
}
416-
417-
Tensor<std::remove_const_t<T>, D> exp() const {
418-
return map<std::remove_const_t<T>>([](T val) { return std::exp(val); });
419-
}
420-
421389
T item() const {
422390
assert(data_size == 1);
423391
return data[0];
@@ -491,11 +459,6 @@ template <DType T, Device D> class Tensor {
491459
return TensorView<const T, D>{data(), size(), shape(), get_all_strides(shape())};
492460
}
493461

494-
// Copy to a new mutable tensor
495-
Tensor<std::remove_const_t<T>, D> copy() const {
496-
return view().copy();
497-
}
498-
499462
void fill_(T value)
500463
requires(!std::is_const_v<T>)
501464
{
@@ -547,17 +510,17 @@ template <DType T, Device D> class Tensor {
547510
span()[idx] = value;
548511
}
549512

550-
T item() const {
551-
assert(shape().size() == 0);
552-
return storage_.data()[0];
553-
}
554-
555513
/// Bounds-checked read of the element at flat index `idx`.
///
/// @param idx  Index into the tensor's storage; valid range is [0, size()).
/// @return     Copy of the stored element.
/// @throws std::out_of_range  if `idx` is negative or not less than `size()`.
T at(int idx) const {
    // Valid indices are [0, size()): idx == size() is already one past the end.
    if (idx < 0 || static_cast<size_t>(idx) >= size()) {
        throw std::out_of_range("cannot index past the tensor size");
    }
    return storage_[idx];
}
519+
520+
/// Returns the single value held by a zero-dimensional tensor.
///
/// Asserts that the shape has no dimensions, then reads element 0 through at().
T item() const {
    assert(shape().empty());
    return at(0);
}
561524
};
562525

563526
} // namespace tensor
@@ -619,7 +582,7 @@ template <tensor::DType T, tensor::Device D> struct fmt::formatter<tensor::Tenso
619582
const auto& strides = tensor_view.stride;
620583
if (dim == shape.size()) {
621584
// Base case: actually print one scalar
622-
return fmt::format_to(out, "{}", tensor_view.span()[offset]);
585+
return fmt::format_to(out, "{}", tensor_view[offset]);
623586
}
624587

625588
auto dim_size = shape[dim];

src/llama/grouped_query_attention.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ GroupedQueryAttention<T, D>::forward(const TensorView<T, D>& inputs,
107107
keys_v = keys.view();
108108

109109
// repeat-expand to (batch, [num_kv_groups * group_size], seq_len, head_dim)
110-
keys = keys_v.repeat_interleave(1, group_size);
111-
values = values_v.repeat_interleave(1, group_size);
110+
keys = repeat_interleave(keys_v, 1, group_size);
111+
values = repeat_interleave(values_v, 1, group_size);
112112

113113
auto transposed_keys_ = keys.view();
114114
transposed_keys_.transpose(2, 3); // (batch, [num_kv_groups*group_size], head_dim, kvs_len)

src/llama/kv_cache.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ KVCache<T, D>::forward(tensor::TensorView<T, D> new_keys, tensor::TensorView<T,
3030
all_keys = cat(already_cached_keys.view(), new_keys, 1);
3131
all_values = cat(already_cached_values.view(), new_values, 1);
3232
} else { // prefill
33-
all_keys = new_keys.copy();
34-
all_values = new_values.copy();
33+
all_keys = copy(new_keys);
34+
all_values = copy(new_values);
3535
}
3636

3737
replace_from_(k_cache, all_keys.view());

src/llama/rope.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ precompute_rope_values(size_t head_dim, float theta_base, size_t context_length)
1515

1616
// compute the inverse frequencies
1717
Tensor<int, D> range = arange<int, D>(0, head_dim, 2);
18-
auto range_float = range.view().template to<float>();
18+
auto range_float = to<int, float>(range.view());
1919

2020
auto scaled = div(range_float.view(), float(head_dim));
2121

@@ -47,7 +47,7 @@ precompute_rope_values(size_t head_dim, float theta_base, size_t context_length)
4747
// Medium frequency: smooth interpolation
4848
float smooth =
4949
(old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor);
50-
float scaled_inv_freq = ((1.0 - smooth) * (inv_f / factor)) + (smooth * inv_f);
50+
float scaled_inv_freq = ((1.0 - smooth) * (inv_f / factor)) + (smooth * inv_f); // NOLINT
5151
inv_freq_.span()[i] = scaled_inv_freq;
5252
}
5353
}
@@ -86,21 +86,18 @@ Tensor<std::remove_const_t<T>, D> RoPE<T, D>::forward(TensorView<T, D> inputs,
8686

8787
assert(head_dim % 2 == 0);
8888

89-
// Copy inputs to a tensor (stay in bfloat16)
90-
Tensor<T, D> inputs_t = inputs.copy();
91-
9289
// Slice cos/sin to the requested positions and convert them from float to T
9390
auto adj_cos_ = slice(cos.view(), 0, position_offset, position_offset + seq_len);
94-
auto adj_cos_bf16 = adj_cos_.view().template to<T>();
91+
auto adj_cos_bf16 = to<float, T>(adj_cos_.view());
9592
auto adj_cos = adj_cos_bf16.view().reshape({1, 1, seq_len, head_dim});
9693

9794
auto adj_sin_ = slice(sin.view(), 0, position_offset, position_offset + seq_len);
98-
auto adj_sin_bf16 = adj_sin_.view().template to<T>();
95+
auto adj_sin_bf16 = to<float, T>(adj_sin_.view());
9996
auto adj_sin = adj_sin_bf16.view().reshape({1, 1, seq_len, head_dim});
10097

10198
// Split input into halves
102-
auto first_half = slice(inputs_t.view(), -1, 0, head_dim / 2);
103-
auto second_half = slice(inputs_t.view(), -1, head_dim / 2, head_dim);
99+
auto first_half = slice(inputs, -1, 0, head_dim / 2);
100+
auto second_half = slice(inputs, -1, head_dim / 2, head_dim);
104101

105102
// Negate second half
106103
auto second_half_neg = mul(second_half.view(), T(-1.0));
@@ -109,7 +106,7 @@ Tensor<std::remove_const_t<T>, D> RoPE<T, D>::forward(TensorView<T, D> inputs,
109106
auto rotated = cat(second_half_neg.view(), first_half.view(), -1);
110107

111108
// Apply rotation: inputs * cos + rotated * sin
112-
auto input_cos = mul(inputs_t.view(), adj_cos.view());
109+
auto input_cos = mul(inputs, adj_cos.view());
113110
auto rotated_sin = mul(rotated.view(), adj_sin.view());
114111

115112
auto out = add(input_cos.view(), rotated_sin.view());

src/nn/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,19 @@ target_sources(nn_core
1919

2020
add_subdirectory(cpu)
2121

22+
if(APPLE)
23+
option(NN_BUILD_CUDA "Build nn_cuda library" OFF)
24+
else()
25+
option(NN_BUILD_CUDA "Build nn_cuda library" ON)
26+
endif()
27+
28+
if(NN_BUILD_CUDA)
29+
enable_language(CUDA)
30+
find_package(CUDAToolkit REQUIRED)
31+
add_subdirectory(cuda)
32+
target_compile_definitions(nn_core INTERFACE NN_HAS_CUDA)
33+
endif()
34+
2235
source_group(
2336
TREE "${PROJECT_SOURCE_DIR}/include"
2437
PREFIX "Header Files"

src/nn/cpu/softmax.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,29 @@ using namespace tensor;
77
template <typename T, typename D>
88
Tensor<std::remove_const_t<T>, D> Softmax::operator()(const TensorView<T, D>& input,
99
int dim) const {
10-
Tensor<float, D> f32 = input.template to<float>();
10+
Tensor<float, D> f32 = to<T, float>(input);
1111

12-
auto maxes = max(f32.view(), dim, true);
12+
auto maxes = tensor::max(f32.view(), dim, true);
13+
14+
fmt::println("MAXES: {}", maxes.view());
1315

1416
auto scaled = sub(f32.view(), maxes.view());
1517

16-
auto expd = scaled.view().exp();
18+
fmt::println("SCALED: {}", scaled.view());
19+
20+
auto expd = tensor::exp(scaled.view());
21+
22+
fmt::println("EXPD: {}", expd.view());
1723

1824
auto expd_sum = sum(expd.view(), dim, true);
1925

20-
auto out = div(expd.view(), expd_sum.view());
26+
fmt::println("EXPD SUM: {}", expd_sum.view());
27+
28+
auto out = tensor::div(expd.view(), expd_sum.view());
29+
30+
fmt::println("NORMALIZED: {}", out.view());
2131

22-
return out.view().template to<T>();
32+
return to<float, T>(out.view());
2333
}
2434

2535
template Tensor<bfloat16, CPU> Softmax::operator()(const TensorView<bfloat16, CPU>& input,

0 commit comments

Comments
 (0)