Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 4 additions & 36 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import gc
import collections
import itertools
from typing import Deque, Dict, Set, List, Tuple, Container, Optional
from typing import Deque, Dict, Set, List, Container, Optional
from contextlib import contextmanager
from dataclasses import dataclass, field

Expand All @@ -21,13 +21,13 @@
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
import deepspeed.runtime.zenflow.engine_stage3 as zf_engine_stage3
from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer
from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer, defragment
from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
Expand Down Expand Up @@ -655,38 +655,6 @@ def get_lr(self):
"""Return the current learning rate."""
return self.optimizer.param_groups[0]["lr"]

# TODO. factor out to a utility outside of stage3
@staticmethod
def defragment(tensors: List[Tensor]) -> Tensor:
    """Move the provided tensors into one contiguous flat buffer.

    All tensors must share a single dtype and a single device (asserted
    below). Each tensor is first staged into a host-side buffer and its
    device storage released, so that the single flat allocation made
    afterwards on the original device is less fragmented.

    Returns the flat device buffer; each input tensor's ``.data`` is
    re-pointed at a narrow view of that buffer.
    """
    assert len(set(t.dtype for t in tensors)) == 1
    assert len(set(t.device for t in tensors)) == 1

    # Host staging buffer sized to hold every tensor back-to-back.
    cpu_buffer = torch.empty(sum(p.numel() for p in tensors),
                             dtype=get_only_unique_item(t.dtype for t in tensors),
                             device="cpu")
    # (tensor, offset, numel) triples describing each tensor's slot in the flat buffer.
    tensor_infos: List[Tuple[Tensor, int, int]] = get_mapping_to_flat_buffer(tensors)
    orig_device = get_only_unique_item(t.device for t in tensors)

    offset = 0  # NOTE(review): dead store — the loop below rebinds `offset` from tensor_infos
    for tensor, offset, tensor_numel in tensor_infos:
        # move the tensor from device memory to host memory
        cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor)
        # Free the original device storage by re-pointing .data at an empty tensor.
        tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device)

    # Release the freed device blocks before making the single flat allocation.
    gc.collect()
    get_accelerator().empty_cache()

    # copy tensors (now flattened and contiguous) back to GPU
    device_buffer = cpu_buffer.to(orig_device)

    # restore device tensors as views into the flat buffer
    for tensor, offset, tensor_numel in tensor_infos:
        tensor.data = device_buffer.narrow(0, offset, tensor_numel)

    return device_buffer

def _get_param_coordinator(self):
return self.parameter_offload.get_param_coordinator()

Expand Down Expand Up @@ -834,7 +802,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
parameter_partitions = self._get_parameter_partitions()

# We need to keep the reference to this buffer to make sure you can free it in `offload_states`
self.lp_param_buffer = __class__.defragment(parameter_partitions)
self.lp_param_buffer = defragment(parameter_partitions)
self._set_fp16_partitioned_groups_flat()

else: # partitioned params offloaded to CPU when not in use
Expand Down
33 changes: 33 additions & 0 deletions deepspeed/runtime/zero/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# DeepSpeed Team

import os
import gc
from typing import List, Tuple

import torch
Expand All @@ -15,6 +16,7 @@
from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion
from deepspeed.utils.nvtx import instrument_w_nvtx
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.utils import get_only_unique_item

# ensure we only warn once, otherwise every iteration will trigger a warning
warned = False
Expand Down Expand Up @@ -200,3 +202,34 @@ def get_mapping_to_flat_buffer(tensors: List[torch.Tensor]) -> List[Tuple[torch.
offset += tensor_numel

return tensor_infos


def defragment(tensors: List[torch.Tensor]) -> torch.Tensor:
    """Move the provided tensors into one contiguous flat buffer.

    All tensors must share a single dtype and a single device (asserted
    below). Each tensor is first staged into a host-side buffer and its
    device storage released, so that the single flat allocation made
    afterwards on the original device is less fragmented.

    Args:
        tensors: Non-empty list of tensors, all of one dtype and one device.

    Returns:
        The flat device buffer. As a side effect each input tensor's
        ``.data`` is re-pointed at a narrow view of that buffer.
    """
    assert len(set(t.dtype for t in tensors)) == 1
    assert len(set(t.device for t in tensors)) == 1

    # Hoist the (unique) dtype/device lookups; the asserts above guarantee uniqueness.
    dtype = get_only_unique_item(t.dtype for t in tensors)
    orig_device = get_only_unique_item(t.device for t in tensors)

    # Host staging buffer sized to hold every tensor back-to-back.
    cpu_buffer = torch.empty(sum(p.numel() for p in tensors), dtype=dtype, device="cpu")
    # (tensor, offset, numel) triples describing each tensor's slot in the flat buffer.
    tensor_infos: List[Tuple[torch.Tensor, int, int]] = get_mapping_to_flat_buffer(tensors)

    # (Removed dead `offset = 0` store — the loop rebinds `offset` from tensor_infos.)
    for tensor, offset, tensor_numel in tensor_infos:
        # Move the tensor from device memory to host memory, then free the
        # original device storage by re-pointing .data at an empty tensor.
        cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor)
        tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device)

    # Release the freed device blocks before making the single flat allocation.
    gc.collect()
    get_accelerator().empty_cache()

    # Copy tensors (now flattened and contiguous) back to the original device.
    device_buffer = cpu_buffer.to(orig_device)

    # Restore device tensors as views into the flat buffer.
    for tensor, offset, tensor_numel in tensor_infos:
        tensor.data = device_buffer.narrow(0, offset, tensor_numel)

    return device_buffer