Add FullyShardedDataParallel (FSDP) #413

Merged
51 commits merged on Feb 23, 2021
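
For orientation, here is a minimal usage sketch of the wrapper this PR introduces. The import path, single-process setup, and hyperparameters are assumptions for illustration only and are not taken from this PR's diff; real training launches one process per GPU.

import os

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Single-process process group purely so the example runs; multi-GPU training
# would spawn one worker per device instead.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)

# Wrapping shards parameters, gradients, and optimizer state across ranks,
# instead of replicating them as DistributedDataParallel does.
model = FSDP(torch.nn.Linear(16, 16).cuda())
optim = torch.optim.SGD(model.parameters(), lr=0.1)  # construct the optimizer after wrapping

loss = model(torch.randn(8, 16, device="cuda")).sum()
loss.backward()
optim.step()
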
Changes from 1 commit
Commits (51)
a7be0d2
Add fairscale.utils.testing.DeviceAndTypeCheckModule
myleott Jan 28, 2021
e2ad716
Add fairscale.utils.containers
myleott Jan 28, 2021
4d6a5c9
Add ShardParamsDataParallel
myleott Jan 26, 2021
dd57e30
[test]: skip autocast-needed tests on torch < 1.6 (#34)
min-xu-ai Jan 29, 2021
80678fc
[mypy]: fairscale/utils/containers.py (#33)
min-xu-ai Jan 29, 2021
5fb36d1
[mypy]: fixed fairscale/utils/testing.py (#32)
min-xu-ai Jan 29, 2021
774e130
More docs (#35)
myleott Jan 29, 2021
b35a28a
[mypy]: fixed all the mypy errors (#37)
min-xu-ai Feb 1, 2021
92c550b
Sharded DDP: test cpu_offload arg (#40)
sshleifer Feb 1, 2021
5bb212f
Misc comments from @anj-s (#43)
myleott Feb 1, 2021
cbd243e
Only sync fp32_to_fp16 stream for the top-most (root) ShardParams wra…
myleott Feb 2, 2021
8db0cf6
Fix streams test (#45)
myleott Feb 2, 2021
bc7e337
move_grads_to_cpu defaults to same value as cpu_offload (#44)
sshleifer Feb 2, 2021
a1b3924
formatting change (#46)
min-xu-ai Feb 2, 2021
e10df73
Test that backward hooks are registered (#49)
sshleifer Feb 3, 2021
e139857
[test] add test for apply_to_tensors (#50)
min-xu-ai Feb 3, 2021
6f153b0
Test save/load state_dict V2 (#51)
sshleifer Feb 4, 2021
4114772
Replace x.view(-1) with torch.flatten(x) (#59)
myleott Feb 4, 2021
36b2d39
Add more comments + docstrings (#58)
myleott Feb 4, 2021
bc5190b
Rearrange dtype and device change in post-backward hook (#61)
myleott Feb 4, 2021
8a5f81c
Do reduce-scatter in a separate CUDA stream (#62)
myleott Feb 4, 2021
72c1f63
tests use spawn_for_all_world_sizes (#63)
sshleifer Feb 5, 2021
f481877
Fix state_dict bugs (#60)
sshleifer Feb 6, 2021
515411b
update comments to reflect where we are in stack (#69)
sshleifer Feb 7, 2021
ec4e75e
[CI] use parameterized.expand to make each test faster (#68)
sshleifer Feb 7, 2021
b1460d3
Fix delayed reduce_scatter test (#74)
myleott Feb 7, 2021
014ad05
add unit test pack/unpack kwargs (#65)
min-xu-ai Feb 7, 2021
0ca378b
Refactor param init and streams logic (#73)
myleott Feb 7, 2021
bebe7fd
Add test for NoGrad mode
myleott Feb 8, 2021
74b0223
Add all_gather stream and disable reshard_after_forward on root insta…
myleott Feb 8, 2021
6797964
Leave buffers on self.compute_device (#67)
sshleifer Feb 9, 2021
0937594
Pad parameters inside: right before gather, scatter (#76)
sshleifer Feb 9, 2021
93670cb
Add no_sync() context manager (#77)
myleott Feb 10, 2021
1d0bf73
rename
sshleifer Feb 10, 2021
5fc1f12
Slightly faster execution when world_size == 1 (#81)
myleott Feb 11, 2021
3681242
merge new base (which is public/master) (#82)
min-xu-ai Feb 11, 2021
dfada29
Merge branch 'shard_params_ddp_base' into shard_params_ddp
myleott Feb 11, 2021
9fb974b
Merge branch 'shard_params_ddp_base' into shard_params_ddp
myleott Feb 11, 2021
7bd82d1
two small changes (#83)
min-xu-ai Feb 12, 2021
d8f3349
fixing uneven shard support and add uneven shard unit test (#80)
min-xu-ai Feb 14, 2021
366c38e
rename test file (#86)
min-xu-ai Feb 15, 2021
87d28c9
adding TrainingState enum for asserting purpose (#87)
min-xu-ai Feb 19, 2021
72b5967
Misc (#88)
myleott Feb 19, 2021
a0fe832
Bugfix for backward stream (#91)
myleott Feb 20, 2021
2abf21f
Small fix to activation checkpointing + FSDP (#92)
myleott Feb 22, 2021
977f459
Bugfix (+ add test) for no_sync before first forward (#93)
myleott Feb 22, 2021
43d1f73
FSDP.clip_grad_norm (#89)
sshleifer Feb 22, 2021
77dc364
Add ReduceScatterBucketer (#94)
myleott Feb 22, 2021
bf30e59
Merge branch 'oss-master' into shard_params_ddp
myleott Feb 22, 2021
ec8bf9f
Fix merge conflict
myleott Feb 22, 2021
3f2b7f1
Update docstring slightly
myleott Feb 22, 2021
Add test for NoGrad mode
myleott committed Feb 8, 2021
commit bebe7fd6b3c3143d31f09171646d0f1d53ac1ac1
8 changes: 6 additions & 2 deletions fairscale/nn/data_parallel/shard_params_data_parallel.py
@@ -409,14 +409,18 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         # state is typically initialized lazily in ``optim.step()``.
         self._use_fp32_param_shard()
 
-        if torch.is_grad_enabled():
-            outputs = self._register_pre_backward_hooks(outputs)
+        # Register pre-backward hooks to all-gather the params for the backward
+        # pass (if needed).
+        outputs = self._register_pre_backward_hooks(outputs)
 
         return outputs
 
     def _register_pre_backward_hooks(self, outputs: Any) -> Any:
         """Register pre-backward hook to run before the wrapped module's
         backward. Hooks should be attached to all outputs from the forward."""
+        if not torch.is_grad_enabled():
+            return outputs  # don't register hooks if grad isn't enabled
+
         pre_backward_hook_has_run = [False]
 
         def _pre_backward_hook(*unused: Any) -> None:
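
The hunk above moves the torch.is_grad_enabled() check from forward() into _register_pre_backward_hooks(), so under no_grad the hook-registration path is skipped entirely. Below is a standalone sketch of that pattern, assuming hooks are attached to output tensors via Tensor.register_hook; the helper names are hypothetical and not the fairscale API.

from typing import Any, Callable

import torch


def register_pre_backward_hooks(outputs: Any, pre_backward: Callable[[], None]) -> Any:
    """Attach `pre_backward` so it runs once, right before backward reaches the outputs."""
    if not torch.is_grad_enabled():
        return outputs  # mirrors the early return added in this commit

    has_run = [False]  # run the callback only once per backward pass

    def _hook(*unused: Any) -> None:
        if not has_run[0]:
            has_run[0] = True
            pre_backward()  # e.g., re-gather the full (unsharded) parameters

    def _attach(t: torch.Tensor) -> torch.Tensor:
        if t.requires_grad:
            t.register_hook(_hook)
        return t

    if torch.is_tensor(outputs):
        return _attach(outputs)
    if isinstance(outputs, (list, tuple)):
        return type(outputs)(_attach(t) if torch.is_tensor(t) else t for t in outputs)
    return outputs

A wrapper would call something like this at the end of forward(), passing a callback that prepares the sharded parameters for the backward pass.
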
27 changes: 27 additions & 0 deletions tests/nn/data_parallel/test_shard_params_data_parallel.py
@@ -482,6 +482,33 @@ def _test_register_functions_called(self, rank, group, cuda_first=False):
         assert model._register_pre_backward_hooks.called
 
 
+class TestNoGrad(DistributedTest):
+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_transformer_parameterized(self, config):
+        test_fn = functools.partial(self._test_transformer, config=config)
+        spawn_and_init(test_fn)
+
+    @classmethod
+    def _test_transformer(self, rank, group, config):
+        autocast = config["mixed_precision"]
+
+        # Train model for a step
+        model = self.get_wrapped_model(group, cuda_first=False, config=config)
+        self._train_for_several_steps(model, 1, autocast)
+
+        model.eval()  # no dropout for this test
+
+        # Eval in standard mode (i.e., without no_grad)
+        input = model.module.get_input(torch.device("cuda"))
+        ref_output = model(*input)
+
+        # Eval with no_grad and compare
+        with torch.no_grad():
+            no_grad_output = model(*input)
+
+        assert objects_are_equal(ref_output, no_grad_output), "no_grad_output did not match ref_output"
+
+
 class TransformerWithSharedParams(nn.Module):
     def __init__(self, *unused_args, d_vocab=32, d_model=16, **unused_kwargs):
         super().__init__()
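
The new test relies on fairscale's DistributedTest and spawn_and_init test helpers. The following is a hedged, standalone sketch of that spawn-based pattern; the names and backend choice here are illustrative assumptions, not the project's actual harness.

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int, fn) -> None:
    # Each spawned process joins a fresh process group, runs the test body,
    # then tears the group down.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    group = dist.new_group(ranks=list(range(world_size)))
    try:
        fn(rank, group)  # e.g., a functools.partial of a test method taking (rank, group)
    finally:
        dist.destroy_process_group()


def spawn_and_run(fn, world_size: int = 2) -> None:
    # torch.multiprocessing.spawn passes the rank as the first argument to _worker.
    mp.spawn(_worker, args=(world_size, fn), nprocs=world_size, join=True)
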