Add FullyShardedDataParallel (FSDP) #413

Merged: 51 commits, Feb 23, 2021
Changes from 1 commit
Commits (51)
a7be0d2
Add fairscale.utils.testing.DeviceAndTypeCheckModule
myleott Jan 28, 2021
e2ad716
Add fairscale.utils.containers
myleott Jan 28, 2021
4d6a5c9
Add ShardParamsDataParallel
myleott Jan 26, 2021
dd57e30
[test]: skip autocast-needed tests on torch < 1.6 (#34)
min-xu-ai Jan 29, 2021
80678fc
[mypy]: fairscale/utils/containers.py (#33)
min-xu-ai Jan 29, 2021
5fb36d1
[mypy]: fixed fairscale/utils/testing.py (#32)
min-xu-ai Jan 29, 2021
774e130
More docs (#35)
myleott Jan 29, 2021
b35a28a
[mypy]: fixed all the mypy errors (#37)
min-xu-ai Feb 1, 2021
92c550b
Sharded DDP: test cpu_offload arg (#40)
sshleifer Feb 1, 2021
5bb212f
Misc comments from @anj-s (#43)
myleott Feb 1, 2021
cbd243e
Only sync fp32_to_fp16 stream for the top-most (root) ShardParams wra…
myleott Feb 2, 2021
8db0cf6
Fix streams test (#45)
myleott Feb 2, 2021
bc7e337
move_grads_to_cpu defaults to same value as cpu_offload (#44)
sshleifer Feb 2, 2021
a1b3924
formatting change (#46)
min-xu-ai Feb 2, 2021
e10df73
Test that backward hooks are registered (#49)
sshleifer Feb 3, 2021
e139857
[test] add test for apply_to_tensors (#50)
min-xu-ai Feb 3, 2021
6f153b0
Test save/load state_dict V2 (#51)
sshleifer Feb 4, 2021
4114772
Replace x.view(-1) with torch.flatten(x) (#59)
myleott Feb 4, 2021
36b2d39
Add more comments + docstrings (#58)
myleott Feb 4, 2021
bc5190b
Rearrange dtype and device change in post-backward hook (#61)
myleott Feb 4, 2021
8a5f81c
Do reduce-scatter in a separate CUDA stream (#62)
myleott Feb 4, 2021
72c1f63
tests use spawn_for_all_world_sizes (#63)
sshleifer Feb 5, 2021
f481877
Fix state_dict bugs (#60)
sshleifer Feb 6, 2021
515411b
update comments to reflect where we are in stack (#69)
sshleifer Feb 7, 2021
ec4e75e
[CI] use parameterized.expand to make each test faster (#68)
sshleifer Feb 7, 2021
b1460d3
Fix delayed reduce_scatter test (#74)
myleott Feb 7, 2021
014ad05
add unit test pack/unpack kwargs (#65)
min-xu-ai Feb 7, 2021
0ca378b
Refactor param init and streams logic (#73)
myleott Feb 7, 2021
bebe7fd
Add test for NoGrad mode
myleott Feb 8, 2021
74b0223
Add all_gather stream and disable reshard_after_forward on root insta…
myleott Feb 8, 2021
6797964
Leave buffers on self.compute_device (#67)
sshleifer Feb 9, 2021
0937594
Pad parameters inside: right before gather, scatter (#76)
sshleifer Feb 9, 2021
93670cb
Add no_sync() context manager (#77)
myleott Feb 10, 2021
1d0bf73
rename
sshleifer Feb 10, 2021
5fc1f12
Slightly faster execution when world_size == 1 (#81)
myleott Feb 11, 2021
3681242
merge new base (which is public/master) (#82)
min-xu-ai Feb 11, 2021
dfada29
Merge branch 'shard_params_ddp_base' into shard_params_ddp
myleott Feb 11, 2021
9fb974b
Merge branch 'shard_params_ddp_base' into shard_params_ddp
myleott Feb 11, 2021
7bd82d1
two small changes (#83)
min-xu-ai Feb 12, 2021
d8f3349
fixing uneven shard support and add uneven shard unit test (#80)
min-xu-ai Feb 14, 2021
366c38e
rename test file (#86)
min-xu-ai Feb 15, 2021
87d28c9
adding TrainingState enum for asserting purpose (#87)
min-xu-ai Feb 19, 2021
72b5967
Misc (#88)
myleott Feb 19, 2021
a0fe832
Bugfix for backward stream (#91)
myleott Feb 20, 2021
2abf21f
Small fix to activation checkpointing + FSDP (#92)
myleott Feb 22, 2021
977f459
Bugfix (+ add test) for no_sync before first forward (#93)
myleott Feb 22, 2021
43d1f73
FSDP.clip_grad_norm (#89)
sshleifer Feb 22, 2021
77dc364
Add ReduceScatterBucketer (#94)
myleott Feb 22, 2021
bf30e59
Merge branch 'oss-master' into shard_params_ddp
myleott Feb 22, 2021
ec8bf9f
Fix merge conflict
myleott Feb 22, 2021
3f2b7f1
Update docstring slightly
myleott Feb 22, 2021
Test save/load state_dict V2 (#51)
sshleifer committed Feb 4, 2021
commit 6f153b0666aab1c114924e5727f3148744d32efa
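For orientation, here is a minimal sketch of the save/load round trip this commit exercises, distilled from the test code in the diff below. It is illustrative only: it assumes a CUDA device and an already-initialized single-rank process group, and `flatten_parameters` is just one of the config keys the tests iterate over.

```python
import torch
import torch.nn as nn
from fairscale.nn.data_parallel import ShardParamsDataParallel

# Assumes torch.distributed is already initialized, e.g. (single rank, for illustration):
#   torch.distributed.init_process_group("nccl", init_method="tcp://localhost:29500",
#                                        rank=0, world_size=1)

model = ShardParamsDataParallel(nn.Linear(32, 32), flatten_parameters=False).cuda()

# Save and reload this rank's shard of the parameters.
local_state = model.local_state_dict()
model.load_local_state_dict(local_state)

# Per the tests below, the full state_dict() is gathered after a forward/backward
# pass, and is loaded into a fresh, unwrapped module that is then re-wrapped.
loss = model(torch.randn(4, 32, device="cuda")).sum()
loss.backward()
full_state = model.state_dict()

unwrapped = nn.Linear(32, 32)
unwrapped.load_state_dict(full_state)
new_model = ShardParamsDataParallel(unwrapped, flatten_parameters=False).cuda()
```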
2 changes: 2 additions & 0 deletions requirements-test.txt
@@ -13,3 +13,5 @@ pytest-cov == 2.10.0
pytest-mpi == 0.4
pytest-timeout == 1.4.2
mpi4py == 3.0.3
remote-pdb >= 2.1.0
parameterized >= 0.8.1
114 changes: 108 additions & 6 deletions tests/nn/data_parallel/test_shard_params_data_parallel.py
@@ -18,6 +18,8 @@
from fairscale.nn.data_parallel import ShardParamsDataParallel
from fairscale.utils.testing import DeviceAndTypeCheckModule, get_cycles_per_ms, objects_are_equal

# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4


class DistributedTest(unittest.TestCase):
def setUp(self):
@@ -36,6 +38,7 @@ def setUp(self):
def _train_for_several_steps(model, num_steps, autocast):
model_device = next(model.parameters()).device
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
# If you set the learning rate any higher, this implementation differs from DDP in the 5th decimal place
for _ in range(num_steps):
optim.zero_grad()
with torch.cuda.amp.autocast(enabled=autocast):
@@ -157,9 +160,7 @@ def test_transformer(self):
keys = ["reshard_after_forward", "mixed_precision", "flatten_parameters"]
for config in itertools.product([True, False], repeat=len(keys)):
config = dict(zip(keys, config))
- spawn_and_init(
-     functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config), world_size=2,
- )
+ spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config),)

def test_cpu_offload_and_cpu_grads(self):
for move_grads_choice in (True, None):
@@ -233,6 +234,109 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3
raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}" "\n\n{e}")


class TestSaveLoadLocalStateDict(DistributedTest):
def test_load_local_state_dict(self):
test_fn = functools.partial(self._load_local_and_train, {"flatten_parameters": False})
spawn_and_init(test_fn)

def test_local_state_dict_flatten_params_breaks(self):
test_fn_broken = functools.partial(self._load_local_and_train, {"flatten_parameters": True})
with self.assertRaises(Exception):
spawn_and_init(test_fn_broken)
# RuntimeError: Traceback [1]
# [1] https://gist.github.com/sshleifer/612d8eb02dbbf357d6133b2700e02f5e

def test_local_state_dict_odd_vocab_shape_breaks(self):
test_fn = functools.partial(self._load_local_and_train, {"flatten_parameters": False}, d_model=16, d_vocab=37)
with self.assertRaises(Exception):
spawn_and_init(test_fn)

@classmethod
def _load_local_and_train(cls, config, rank, group, d_model=32, d_vocab=32):
"""Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
model = ShardParamsDataParallel(
TransformerWithSharedParams(d_model=d_model, d_vocab=d_vocab), group, **config
).cuda()
state_1 = model.local_state_dict()
state_before_training = {k: v.cpu().clone() for k, v in state_1.items()}
model.load_local_state_dict(state_1)
state_1_weight = state_1["embed_tokens.weight"]

# This weight will be sharded since we access module.state_dict directly
state_1_module_weight = model.module.state_dict()["embed_tokens.weight"]
torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
cls._train_for_several_steps(model, 4, False)

state_2 = model.local_state_dict()
state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
model.load_local_state_dict(state_2)

assert state_1.keys() == state_2.keys()

# Assert that parameters were updated since before training
unchanged = []
for k in state_1:
if (state_before_training[k] == state_after_training[k]).all():
unchanged.append(k)
if unchanged:
raise AssertionError(f"params {unchanged} not changed after training")


class TestSaveLoadStateDict(DistributedTest):
def test_calling_state_dict_twice_breaks(self):
test_fn = functools.partial(self._test_calling_state_dict_twice_breaks, {"flatten_parameters": False})
spawn_and_init(test_fn)

@classmethod
def _test_calling_state_dict_twice_breaks(cls, config, rank, group):
ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
cls._train_for_several_steps(ddp_model, 1, False)
ddp_model.state_dict() # Succeeds
try:
ddp_model.state_dict()
assert False, "Second state_dict call succeeded"
except Exception:
pass

def test_state_dict_after_forward(self):
test_fn = functools.partial(self._test_module_state_dict, {"flatten_parameters": False})
spawn_and_init(test_fn)

@classmethod
def _test_module_state_dict(cls, config, rank, group):
ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
try:
ddp_model.state_dict()
assert False, "Calling state_dict before forward succeeded"
except Exception:
pass
cls._train_for_several_steps(ddp_model, 2, False)
state_1 = ddp_model.state_dict()
# You must make a new ShardParamsDataParallel instance to use module.load_state_dict
unwrapped_model = TransformerWithSharedParams()
unwrapped_model.load_state_dict(state_1)
new_ddp_model = ShardParamsDataParallel(unwrapped_model, group, **config).cuda()
cls._train_for_several_steps(new_ddp_model, 2, False)
try:
ddp_model.load_state_dict(new_ddp_model.state_dict())
assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
except Exception:
pass


def get_sharded_model():
sharded_model = ShardParamsDataParallel(
nn.Sequential(
nn.Linear(8, 100),
ShardParamsDataParallel(nn.Linear(100, 100)),
ShardParamsDataParallel(nn.Linear(100, 100)),
nn.Linear(100, 8),
)
)
return sharded_model


class TestHooks(DistributedTest):
# Feel free to modify these tests as the implementation changes.
# They aspire to make sure that backward hooks are registered and used
# (a generic illustration of this idea appears after the diff excerpt below).
@@ -279,11 +383,9 @@ def _test_register_functions_called(self, rank, group, cuda_first=False):


class TransformerWithSharedParams(nn.Module):
- def __init__(self, *args, **kwargs):
+ def __init__(self, *unused_args, d_vocab=32, d_model=16, **unused_kwargs):
super().__init__()
torch.manual_seed(0) # keep everything deterministic
- d_model = 16
- d_vocab = 32
self.embed_tokens = nn.Embedding(d_vocab, d_model)
self.transformer = nn.Transformer(
d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1,
(remainder of the diff is collapsed in this view)
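Since the TestHooks bodies are collapsed in this view, the following generic PyTorch snippet (plain torch, not fairscale code) illustrates the property those tests aim at: a hook registered on each parameter actually fires during the backward pass.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
fired = []

def make_hook(name):
    def hook(grad):
        fired.append(name)  # record that the hook ran for this parameter
        return grad         # return the gradient unchanged
    return hook

# torch.Tensor.register_hook attaches a hook that runs when the grad is computed.
for name, param in model.named_parameters():
    param.register_hook(make_hook(name))

model(torch.randn(2, 4)).sum().backward()
assert sorted(fired) == sorted(name for name, _ in model.named_parameters())
```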