🐛 Describe the bug
I've recently been using the code provided in https://github.com/tlc4418/llm_optimization, which in turn uses trlx.
In doing so I encountered a bug that causes trlX to crash when saving a checkpoint, triggered by a recent change in DeepSpeed.
To reproduce, run this gist https://gist.github.com/JohannesAck/feb31ee5c491ca30771335296ec8b295 with `accelerate launch`, using a config that enables DeepSpeed. Saving then fails with:
```
Traceback (most recent call last):
  File "/workspaces/llm_optimization/crash_example.py", line 111, in <module>
    main(hparams)
  File "/workspaces/llm_optimization/crash_example.py", line 101, in main
    trlx.train(
  File "/usr/local/lib/python3.10/dist-packages/trlx/trlx.py", line 142, in train
    trainer.learn()
  File "/usr/local/lib/python3.10/dist-packages/trlx/trainer/accelerate_base_trainer.py", line 598, in learn
    self.save(directory)
  File "/usr/local/lib/python3.10/dist-packages/trlx/trainer/accelerate_base_trainer.py", line 312, in save
    self.accelerator.save_state(dst_dir, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 2944, in save_state
    model.save_checkpoint(output_dir, ckpt_id, **save_model_func_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 3105, in save_checkpoint
    self._save_checkpoint(save_dir,
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 3299, in _save_checkpoint
    module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2540, in module_state_dict
    sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
  File "/usr/local/lib/python3.10/dist-packages/trlx/models/modeling_ppo.py", line 460, in state_dict
    state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
TypeError: dict() got multiple values for keyword argument 'prefix'
```
This is caused by deepspeedai/DeepSpeed#5408, which changes the call to `state_dict` to pass keyword arguments instead of positional ones:
```diff
- sd = self.module.state_dict(destination, prefix, keep_vars)
+ sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
```
trlX, however, assumes that these arguments are passed positionally, so that `prefix` never appears in `**kwargs`. See `trlx/trlx/models/modeling_ppo.py`, lines 354 to 359 in 3340c2f:
```python
def state_dict(self, *args, heads_only=False, **kwargs):
    """
    Returns the state dictionary of the model. We add the state dictionary of the value head
    to the state dictionary of the wrapped model by prepending the key with `v_head.`.
    """
    state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
```
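On line 359, `dict(prefix="v_head.", **kwargs)` now effectively becomes `dict(prefix="v_head.", prefix="")`, because `kwargs` contains the `prefix` keyword forwarded by DeepSpeed. `dict()` therefore receives two values for `prefix` and raises the `TypeError` shown above.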
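The failure can be reproduced without trlX or DeepSpeed. The sketch below wraps the expression from line 359 in a hypothetical helper `build_head_kwargs` (not part of trlX) and calls it once with an empty `kwargs`, as in the old positional call, and once with the keywords DeepSpeed now forwards. The values `None`, `""` and `False` are assumed DeepSpeed defaults; only the empty `prefix` comes from the analysis above.

```python
def build_head_kwargs(kwargs):
    """Hypothetical helper evaluating the same expression as line 359 of
    trlx/models/modeling_ppo.py: dict(prefix="v_head.", **kwargs)."""
    return dict(prefix="v_head.", **kwargs)


# Before DeepSpeed#5408: destination/prefix/keep_vars were passed positionally,
# so they landed in *args and kwargs was empty. The expression succeeds.
print(build_head_kwargs({}))  # {'prefix': 'v_head.'}

# After DeepSpeed#5408: the same values arrive as keywords, so `prefix` is also
# in kwargs (the values here are assumed DeepSpeed defaults).
try:
    build_head_kwargs({"destination": None, "prefix": "", "keep_vars": False})
except TypeError as exc:
    # On Python 3.10 this matches the traceback:
    # dict() got multiple values for keyword argument 'prefix'
    print(exc)
```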
Workaround:
Downgrade DeepSpeed to a version below 0.14.1:
```
pip install 'deepspeed<0.14.1'
```
I'm not sure what the proper fix would be; simply ignoring the `prefix` argument doesn't sound great either. One option might be to ignore it if it is an empty string and raise an exception otherwise.
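A minimal sketch of that option, written as a hypothetical helper `merge_head_prefix` (not part of trlX): it drops an empty `prefix` keyword, raises for a non-empty one, and then applies `v_head.`. The choice of `ValueError` is an assumption. Inside trlX's `state_dict`, line 359 would then read `self.v_head.state_dict(*args, **merge_head_prefix(kwargs))`; this has not been tested against trlX itself.

```python
def merge_head_prefix(kwargs, head_prefix="v_head."):
    """Hypothetical helper: build the kwargs for the value head's state_dict().

    An empty outer `prefix` (as forwarded by DeepSpeed) is dropped; a non-empty
    one is rejected instead of being silently ignored.
    """
    kwargs = dict(kwargs)  # copy so the caller's dict is not mutated
    outer_prefix = kwargs.pop("prefix", "")
    if outer_prefix:
        raise ValueError(
            f"state_dict() got a non-empty prefix {outer_prefix!r}; "
            "combining it with the value-head prefix is not supported"
        )
    return dict(prefix=head_prefix, **kwargs)


# The keywords DeepSpeed forwards no longer collide with the head's prefix.
print(merge_head_prefix({"destination": None, "prefix": "", "keep_vars": False}))
# {'prefix': 'v_head.', 'destination': None, 'keep_vars': False}
```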
Hope this helps somebody!
Which trlX version are you using?
trlx=0.7.0
Additional system and package information
deepspeed=0.14.4