Add FullyShardedDataParallel (FSDP) #413

Merged: 51 commits merged on Feb 23, 2021
Changes from 1 commit
Commits (51)
a7be0d2
Add fairscale.utils.testing.DeviceAndTypeCheckModule
myleott Jan 28, 2021
e2ad716
Add fairscale.utils.containers
myleott Jan 28, 2021
4d6a5c9
Add ShardParamsDataParallel
myleott Jan 26, 2021
dd57e30
[test]: skip autocast-needed tests on torch < 1.6 (#34)
min-xu-ai Jan 29, 2021
80678fc
[mypy]: fairscale/utils/containers.py (#33)
min-xu-ai Jan 29, 2021
5fb36d1
[mypy]: fixed fairscale/utils/testing.py (#32)
min-xu-ai Jan 29, 2021
774e130
More docs (#35)
myleott Jan 29, 2021
b35a28a
[mypy]: fixed all the mypy errors (#37)
min-xu-ai Feb 1, 2021
92c550b
Sharded DDP: test cpu_offload arg (#40)
sshleifer Feb 1, 2021
5bb212f
Misc comments from @anj-s (#43)
myleott Feb 1, 2021
cbd243e
Only sync fp32_to_fp16 stream for the top-most (root) ShardParams wra…
myleott Feb 2, 2021
8db0cf6
Fix streams test (#45)
myleott Feb 2, 2021
bc7e337
move_grads_to_cpu defaults to same value as cpu_offload (#44)
sshleifer Feb 2, 2021
a1b3924
formatting change (#46)
min-xu-ai Feb 2, 2021
e10df73
Test that backward hooks are registered (#49)
sshleifer Feb 3, 2021
e139857
[test] add test for apply_to_tensors (#50)
min-xu-ai Feb 3, 2021
6f153b0
Test save/load state_dict V2 (#51)
sshleifer Feb 4, 2021
4114772
Replace x.view(-1) with torch.flatten(x) (#59)
myleott Feb 4, 2021
36b2d39
Add more comments + docstrings (#58)
myleott Feb 4, 2021
bc5190b
Rearrange dtype and device change in post-backward hook (#61)
myleott Feb 4, 2021
8a5f81c
Do reduce-scatter in a separate CUDA stream (#62)
myleott Feb 4, 2021
72c1f63
tests use spawn_for_all_world_sizes (#63)
sshleifer Feb 5, 2021
f481877
Fix state_dict bugs (#60)
sshleifer Feb 6, 2021
515411b
update comments to reflect where we are in stack (#69)
sshleifer Feb 7, 2021
ec4e75e
[CI] use parameterized.expand to make each test faster (#68)
sshleifer Feb 7, 2021
b1460d3
Fix delayed reduce_scatter test (#74)
myleott Feb 7, 2021
014ad05
add unit test pack/unpack kwargs (#65)
min-xu-ai Feb 7, 2021
0ca378b
Refactor param init and streams logic (#73)
myleott Feb 7, 2021
bebe7fd
Add test for NoGrad mode
myleott Feb 8, 2021
74b0223
Add all_gather stream and disable reshard_after_forward on root instance (#75)
myleott Feb 8, 2021
6797964
Leave buffers on self.compute_device (#67)
sshleifer Feb 9, 2021
0937594
Pad parameters inside: right before gather, scatter (#76)
sshleifer Feb 9, 2021
93670cb
Add no_sync() context manager (#77)
myleott Feb 10, 2021
1d0bf73
rename
sshleifer Feb 10, 2021
5fc1f12
Slightly faster execution when world_size == 1 (#81)
myleott Feb 11, 2021
3681242
merge new base (which is public/master) (#82)
min-xu-ai Feb 11, 2021
dfada29
Merge branch 'shard_params_ddp_base' into shard_params_ddp
myleott Feb 11, 2021
9fb974b
Merge branch 'shard_params_ddp_base' into shard_params_ddp
myleott Feb 11, 2021
7bd82d1
two small changes (#83)
min-xu-ai Feb 12, 2021
d8f3349
fixing uneven shard support and add uneven shard unit test (#80)
min-xu-ai Feb 14, 2021
366c38e
rename test file (#86)
min-xu-ai Feb 15, 2021
87d28c9
adding TrainingState enum for asserting purpose (#87)
min-xu-ai Feb 19, 2021
72b5967
Misc (#88)
myleott Feb 19, 2021
a0fe832
Bugfix for backward stream (#91)
myleott Feb 20, 2021
2abf21f
Small fix to activation checkpointing + FSDP (#92)
myleott Feb 22, 2021
977f459
Bugfix (+ add test) for no_sync before first forward (#93)
myleott Feb 22, 2021
43d1f73
FSDP.clip_grad_norm (#89)
sshleifer Feb 22, 2021
77dc364
Add ReduceScatterBucketer (#94)
myleott Feb 22, 2021
bf30e59
Merge branch 'oss-master' into shard_params_ddp
myleott Feb 22, 2021
ec8bf9f
Fix merge conflict
myleott Feb 22, 2021
3f2b7f1
Update docstring slightly
myleott Feb 22, 2021
Add all_gather stream and disable reshard_after_forward on root instance (#75)
myleott committed Feb 8, 2021
commit 74b0223b605c21e45ae9c626997b64f92cee025c
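
For orientation, a minimal sketch of how the wrapper changed in this commit is constructed and used (import path assumed from the file location in this diff; the class is still called ShardParamsDataParallel here and is renamed FullyShardedDataParallel later in the PR; sizes and flags are illustrative):

import torch
import torch.nn as nn
from fairscale.nn.data_parallel.shard_params_data_parallel import ShardParamsDataParallel

# Assumes torch.distributed is already initialized, e.g. via
# torch.distributed.init_process_group(backend="nccl").
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).cuda()
model = ShardParamsDataParallel(
    model,
    reshard_after_forward=True,  # free full params after forward to save memory
    mixed_precision=True,        # FP16 compute/communication, FP32 master shards
)
optim = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.randn(8, 512, device="cuda")
loss = model(x).sum()
loss.backward()  # gradients are reduce-scattered back to per-rank shards
optim.step()     # each rank updates only its own parameter shard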
73 changes: 47 additions & 26 deletions fairscale/nn/data_parallel/shard_params_data_parallel.py
@@ -58,7 +58,8 @@ class ShardParamsDataParallel(nn.Module):
module (nn.Module): module to be wrapped and sharded
process_group (Optional): process group for sharding
reshard_after_forward (bool, Optional): if ``True``, reshard parameters
after the forward pass. This saves memory but slows training.
after the forward pass. This saves memory but slows training. This
is only relevant when resharding individual layers.
mixed_precision (bool, Optional): if ``True``, inputs, activations and
gradients will be kept in FP16; computation and communication will
occur in FP16; and a (sharded) master copy of the model weights will
@@ -278,6 +279,11 @@ def _lazy_init(self) -> None:
self._set_is_root()
self._setup_streams()

# Don't free the full params for the outer-most (root) instance, since
# those params will be needed immediately after for the backward pass.
if self._is_root:
self.reshard_after_forward = False

@torch.no_grad()
def _init_param_attributes(self, p: Parameter) -> None:
"""
@@ -361,6 +367,8 @@ def _setup_streams(self) -> None:
return
# Stream to move main FP32 params (may be on CPU) to FP16 for forward.
self._streams["fp32_to_fp16"] = torch.cuda.Stream()
# Stream for all-gathering parameters.
self._streams["all_gather"] = torch.cuda.Stream()
# Stream for overlapping grad reduction with the backward pass.
self._streams["post_backward"] = torch.cuda.Stream()
# We share streams with all children instances, which allows them to
@@ -376,7 +384,10 @@ def _wait_for_previous_optim_step(self) -> None:
instance) needs to synchronize with the default stream to ensure the
previous optimizer step is done.
"""
self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
if self.mixed_precision:
self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
else:
self._streams["all_gather"].wait_stream(torch.cuda.current_stream())

def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._lazy_init()
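
The handshake in _wait_for_previous_optim_step above is the standard PyTorch stream pattern: a side stream waits on the default stream before it starts, and the default stream later waits on the side stream before consuming its results. A generic sketch (not fairscale code):

import torch

side = torch.cuda.Stream()
a = torch.randn(1024, device="cuda")

# Don't start side-stream work until everything already queued on the
# default stream (e.g. the previous optimizer step) comes first.
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    b = a * 2  # runs on `side`, can overlap with later default-stream work

# The default stream must wait before it reads `b`.
torch.cuda.current_stream().wait_stream(side)
print(b.sum().item())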
@@ -511,13 +522,16 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
if self.mixed_precision:
param.grad.data = param.grad.data.to(dtype=param.data.dtype)

# Optionally move gradients to CPU, typically used if one is running
# the optimizer on the CPU.
if self.move_grads_to_cpu:
param._cpu_grad.copy_(param.grad.data, non_blocking=True)
param.grad.data = param._cpu_grad

# Enqueue a callback at the end of the backward pass to ensure that all
# post-backward work has finished. We only need one callback.
if not self._post_backward_callback_queued:
# post-backward work has finished. We only need one callback and it only
# needs to be called from the outer-most (root) instance.
if self._is_root and not self._post_backward_callback_queued:
self._post_backward_callback_queued = True
Variable._execution_engine.queue_callback(self._wait_for_post_backward)
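
Queueing the callback relies on the autograd engine's private callback queue, invoked from inside a backward hook exactly as above; a self-contained sketch of the mechanism (illustrative, not fairscale code):

import torch
from torch.autograd import Variable

def post_backward_hook(grad):
    # Runs while the backward pass is executing; queue a callback that the
    # engine will invoke once the whole backward pass has finished.
    Variable._execution_engine.queue_callback(lambda: print("backward finished"))
    return grad

x = torch.randn(4, requires_grad=True)
x.register_hook(post_backward_hook)
(x * 2).sum().backward()  # prints "backward finished" after all grads are ready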

@@ -533,21 +547,23 @@ def _wait_for_post_backward(self) -> None:
@torch.no_grad()
def _rebuild_full_params(self) -> None:
"""Gather all shards of params."""
if self.mixed_precision and len(self._streams) > 0:
self._cast_fp32_param_shards_to_fp16()
with torch.cuda.stream(self._streams["all_gather"]):
if self.mixed_precision:
self._cast_fp32_param_shards_to_fp16()

for p in self.params:
if p._full_param.storage().size() != p._orig_size.numel():
# All-gather parameters
alloc_storage_(p._full_param, size=p._orig_size)
output_list = list(torch.flatten(p._full_param).chunk(self.world_size))
dist.all_gather(output_list, p.data, group=self.process_group)
for p in self.params:
if p._full_param.storage().size() != p._orig_size.numel():
# All-gather parameters
alloc_storage_(p._full_param, size=p._orig_size)
output_list = list(torch.flatten(p._full_param).chunk(self.world_size))
dist.all_gather(output_list, p.data, group=self.process_group)

p.data = p._full_param
p.grad = None
p.data = p._full_param
p.grad = None

if self.mixed_precision and len(self._streams) > 0:
self._free_fp16_param_shard([p])
if self.mixed_precision:
self._free_fp16_param_shard([p])
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])

@torch.no_grad()
def _use_full_params(self) -> None:
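
The gather issued inside the all_gather stream is the usual flat-parameter pattern: allocate the full flat tensor, view it as one chunk per rank, and let all_gather fill the chunks in place. A standalone sketch of the same idea (names and shapes are illustrative, not the fairscale API):

import torch
import torch.distributed as dist

def gather_full_param(shard: torch.Tensor, world_size: int, stream: torch.cuda.Stream) -> torch.Tensor:
    """All-gather a 1-D parameter shard into one flat full tensor on a side stream."""
    full = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
    with torch.cuda.stream(stream):
        # One view per rank; all_gather writes each rank's shard into its view.
        output_list = list(full.chunk(world_size))
        dist.all_gather(output_list, shard)
    # Consumers on the default stream must wait for the side stream.
    torch.cuda.current_stream().wait_stream(stream)
    return full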
@@ -561,15 +577,16 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
if params is None:
params = self.params
current_stream = torch.cuda.current_stream()
for p in params:
# There may be external references to the Tensor Storage that we
# can't modify, such as references that are created by
# ctx.save_for_backward in the forward pass. Thus when we unshard
# parameters, we should reuse the original Tensor Storage object
# and unshard it in-place. For now, just resize the Storage to 0 to
# save memory.
p._full_param.record_stream(current_stream)
free_storage_(p._full_param)
with torch.cuda.stream(self._streams["all_gather"]):
for p in params:
# There may be external references to the Tensor Storage that we
# can't modify, such as references that are created by
# ctx.save_for_backward in the forward pass. Thus when we
# unshard parameters, we should reuse the original Tensor
# Storage object and unshard it in-place. For now, just resize
# the Storage to 0 to save memory.
p._full_param.record_stream(current_stream)
free_storage_(p._full_param)

@torch.no_grad()
def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
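
record_stream is what makes the free above safe: it tells the caching allocator that the tensor is still in use on another stream, so its memory is not handed out again until that stream has caught up. The canonical PyTorch pattern, shown generically (the diff applies the same idea with the compute and all_gather streams):

import torch

side = torch.cuda.Stream()
x = torch.empty(1 << 20, device="cuda")  # allocated on the default stream

with torch.cuda.stream(side):
    y = x * 2          # `x` is read on the side stream
x.record_stream(side)  # mark `x` as in use on `side` for the caching allocator

del x  # safe: x's memory is not reallocated until `side` finishes the queued read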
@@ -588,7 +605,11 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
for p in params:
assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
p._fp16_shard.copy_(p._fp32_shard, non_blocking=True)
p._fp16_shard.copy_(
# If cpu_offload is True, this will be non-blocking because
# _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
p.data = p._fp16_shard
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

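
The comment added in the last hunk hinges on pinned host memory: a host-to-device copy can only be truly asynchronous when the CPU source is page-locked, and .to() is a no-op when the shard already lives on the GPU. A generic sketch of the cpu_offload case (not the fairscale code path):

import torch

fp32_cpu = torch.randn(1024, pin_memory=True)  # pinned, so the H2D copy can be async
fp16_gpu = torch.empty(1024, dtype=torch.float16, device="cuda")

# Move to GPU asynchronously, then downcast into the FP16 buffer in one copy_.
# With a non-pinned source, non_blocking=True silently degrades to a blocking copy.
fp16_gpu.copy_(fp32_cpu.to(fp16_gpu.device, non_blocking=True))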