@@ -134,6 +134,28 @@ template <DType T, Device D> struct TensorView {
134134 return std::span<const T>(data, data_size);
135135 }
136136
137+ #ifdef TENSOR_HAS_CUDA
138+ T operator [](int idx) const
139+ requires std::same_as<D, device::CUDA>
140+ {
141+ if (idx > data_size) {
142+ throw std::out_of_range (" cannot index past the tensor size" );
143+ }
144+ T value;
145+ cudaMemcpy (&value, data + idx, sizeof (T), cudaMemcpyDeviceToHost);
146+ return value;
147+ }
148+ #endif
149+
150+ T operator [](int idx) const
151+ requires std::same_as<D, device::CPU>
152+ {
153+ if (idx > data_size) {
154+ throw std::out_of_range (" cannot index past the tensor size" );
155+ }
156+ return *(data + idx);
157+ }
158+
137159 [[nodiscard]] size_t total_elements () const {
138160 size_t out = 1 ;
139161 for (auto dim : shape) {
@@ -241,46 +263,6 @@ template <DType T, Device D> struct TensorView {
241263 transpose (0 , 1 );
242264 }
243265
244- Tensor<std::remove_const_t <T>, D> repeat_interleave (size_t dim, size_t repeats) const {
245- assert (dim < shape.size ());
246-
247- Shape temp_shape;
248- Shape temp_stride;
249-
250- for (size_t dim_ = 0 ; dim_ <= dim; ++dim_) {
251- temp_shape.push_back (shape[dim_]);
252- temp_stride.push_back (stride[dim_]);
253- }
254-
255- temp_shape.push_back (repeats);
256- temp_stride.push_back (0 );
257-
258- for (size_t dim_ = dim + 1 ; dim_ < shape.size (); ++dim_) {
259- temp_shape.push_back (shape[dim_]);
260- temp_stride.push_back (stride[dim_]);
261- }
262-
263- size_t temp_size = 1 ;
264- for (auto dim_ : temp_shape) {
265- temp_size *= dim_;
266- }
267-
268- TensorView temp_view{data, temp_size, temp_shape, temp_stride};
269-
270- Tensor<T, D> materialized = temp_view.copy ();
271-
272- Shape final_shape;
273- for (size_t dim_ = 0 ; dim_ < shape.size (); ++dim_) {
274- if (dim_ == dim) {
275- final_shape.push_back (shape[dim_] * repeats); // Expanded dimension
276- } else {
277- final_shape.push_back (shape[dim_]);
278- }
279- }
280-
281- return materialized.view ().reshape (final_shape);
282- }
283-
284266 [[nodiscard]] bool is_contiguous () const {
285267 if (shape.empty ()) {
286268 return true ;
@@ -296,7 +278,10 @@ template <DType T, Device D> struct TensorView {
296278 return true ;
297279 }
298280
299- template <DType OutT, typename Func> Tensor<OutT, D> map (Func func) const {
281+ template <DType OutT, typename Func>
282+ Tensor<OutT, D> map (Func func) const
283+ requires std::same_as<D, device::CPU>
284+ {
300285 Tensor<OutT, D> result{shape};
301286
302287 auto result_span = result.span ();
@@ -317,7 +302,10 @@ template <DType T, Device D> struct TensorView {
317302 return result;
318303 }
319304
320- template <typename Func> void each (Func func) const {
305+ template <typename Func>
306+ void each (Func func) const
307+ requires std::same_as<D, device::CPU>
308+ {
321309 size_t total_elems = total_elements ();
322310
323311 for (size_t linear_idx = 0 ; linear_idx < total_elems; ++linear_idx) {
@@ -332,10 +320,6 @@ template <DType T, Device D> struct TensorView {
332320 }
333321 }
334322
335- template <DType OutT> Tensor<OutT, D> to () const {
336- return map<OutT>([](T val) { return static_cast <OutT>(val); });
337- }
338-
339323 void check_for_nans () const {
340324 for (size_t i = 0 ; i < span ().size (); ++i) {
341325 if (std::isnan (span ()[i])) {
@@ -349,10 +333,6 @@ template <DType T, Device D> struct TensorView {
349333 }
350334 }
351335
352- Tensor<std::remove_const_t <T>, D> copy () const {
353- return map<std::remove_const_t <T>>([](T val) { return val; });
354- }
355-
356336 Tensor<std::remove_const_t <T>, D> contiguous () const {
357337 Tensor<std::remove_const_t <T>, D> result{shape};
358338 auto dst_span = result.span ();
@@ -406,18 +386,6 @@ template <DType T, Device D> struct TensorView {
406386 return out;
407387 }
408388
409- Tensor<std::remove_const_t <T>, D> cos () const {
410- return map<std::remove_const_t <T>>([](T val) { return std::cos (val); });
411- }
412-
413- Tensor<std::remove_const_t <T>, D> sin () const {
414- return map<std::remove_const_t <T>>([](T val) { return std::sin (val); });
415- }
416-
417- Tensor<std::remove_const_t <T>, D> exp () const {
418- return map<std::remove_const_t <T>>([](T val) { return std::exp (val); });
419- }
420-
421389 T item () const {
422390 assert (data_size == 1 );
423391 return data[0 ];
@@ -491,11 +459,6 @@ template <DType T, Device D> class Tensor {
491459 return TensorView<const T, D>{data (), size (), shape (), get_all_strides (shape ())};
492460 }
493461
494- // Copy to a new mutable tensor
495- Tensor<std::remove_const_t <T>, D> copy () const {
496- return view ().copy ();
497- }
498-
499462 void fill_ (T value)
500463 requires(!std::is_const_v<T>)
501464 {
@@ -547,17 +510,17 @@ template <DType T, Device D> class Tensor {
547510 span ()[idx] = value;
548511 }
549512
550- T item () const {
551- assert (shape ().size () == 0 );
552- return storage_.data ()[0 ];
553- }
554-
555513 T at (int idx) const {
556514 if (idx > size ()) {
557515 throw std::out_of_range (" cannot index past the tensor size" );
558516 }
559517 return storage_[idx];
560518 }
519+
520+ T item () const {
521+ assert (shape ().size () == 0 );
522+ return at (0 );
523+ }
561524};
562525
563526} // namespace tensor
@@ -619,7 +582,7 @@ template <tensor::DType T, tensor::Device D> struct fmt::formatter<tensor::Tenso
619582 const auto & strides = tensor_view.stride ;
620583 if (dim == shape.size ()) {
621584 // Base case: actually print one scalar
622- return fmt::format_to (out, " {}" , tensor_view. span () [offset]);
585+ return fmt::format_to (out, " {}" , tensor_view[offset]);
623586 }
624587
625588 auto dim_size = shape[dim];
0 commit comments