
Crash when using save_state with deepspeed: model.state_dict functions incompatible with new deepspeed. #596

@JohannesAck


🐛 Describe the bug

I've recently been using the code provided in https://github.com/tlc4418/llm_optimization, which in turn uses trlx.

In doing so, I encountered a bug that makes trlx crash when saving a checkpoint. It is caused by a recent change in DeepSpeed.

To reproduce, run the script at https://gist.github.com/JohannesAck/feb31ee5c491ca30771335296ec8b295 with DeepSpeed enabled, i.e. via `accelerate launch` with an Accelerate config that enables DeepSpeed. It crashes with:

```
Traceback (most recent call last):
  File "/workspaces/llm_optimization/crash_example.py", line 111, in <module>
    main(hparams)
  File "/workspaces/llm_optimization/crash_example.py", line 101, in main
    trlx.train(
  File "/usr/local/lib/python3.10/dist-packages/trlx/trlx.py", line 142, in train
    trainer.learn()
  File "/usr/local/lib/python3.10/dist-packages/trlx/trainer/accelerate_base_trainer.py", line 598, in learn
    self.save(directory)
  File "/usr/local/lib/python3.10/dist-packages/trlx/trainer/accelerate_base_trainer.py", line 312, in save
    self.accelerator.save_state(dst_dir, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 2944, in save_state
    model.save_checkpoint(output_dir, ckpt_id, **save_model_func_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 3105, in save_checkpoint
    self._save_checkpoint(save_dir,
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 3299, in _save_checkpoint
    module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2540, in module_state_dict
    sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
  File "/usr/local/lib/python3.10/dist-packages/trlx/models/modeling_ppo.py", line 460, in state_dict
    state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
TypeError: dict() got multiple values for keyword argument 'prefix'
```

This is caused by DeepSpeed PR deepspeedai/DeepSpeed#5408, which changes the call to `state_dict` to pass keyword arguments instead of positional arguments:

```diff
- sd = self.module.state_dict(destination, prefix, keep_vars)
+ sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
```

trlx, however, assumes that these arguments arrive positionally, and injects its own `prefix` keyword when forwarding them to the value head:

```python
def state_dict(self, *args, heads_only=False, **kwargs):
    """
    Returns the state dictionary of the model. We add the state dictionary of the value head
    to the state dictionary of the wrapped model by prepending the key with `v_head.`.
    """
    state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
```

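In L359, now that DeepSpeed passes `prefix=""` as a keyword, `dict(prefix="v_head.", **kwargs)` effectively becomes `dict(prefix="v_head.", prefix="", ...)`. `prefix` therefore receives two values, which raises the `TypeError` above.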
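The clash can be reproduced without torch or DeepSpeed. The sketch below uses two hypothetical stand-in functions (`head_state_dict` in place of `self.v_head.state_dict`, `wrapper_state_dict` in place of the trlx wrapper) that copy only the argument-merging pattern from the excerpt above; it is an illustration, not trlx code.

```python
def head_state_dict(*args, **kwargs):
    # Hypothetical stand-in for self.v_head.state_dict(...): echo what it receives.
    return args, kwargs


def wrapper_state_dict(*args, heads_only=False, **kwargs):
    # Same merge pattern as the trlx excerpt: inject prefix="v_head." into kwargs.
    return head_state_dict(*args, **dict(prefix="v_head.", **kwargs))


# Old DeepSpeed call style (positional): prefix sits in *args, so there is no clash.
print(wrapper_state_dict(None, "", False))  # ((None, '', False), {'prefix': 'v_head.'})

# New DeepSpeed call style (keywords): prefix arrives in **kwargs and clashes.
try:
    wrapper_state_dict(destination=None, prefix="", keep_vars=False)
except TypeError as exc:
    print(exc)  # message reports multiple values for keyword argument 'prefix'
```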

Workaround:

Downgrade DeepSpeed to a version below 0.14.1:

```shell
pip install 'deepspeed<0.14.1'
```

I'm not sure what the proper solution would be here; simply ignoring the `prefix` argument doesn't sound great either. One option might be to ignore it if it is an empty string and raise an exception otherwise.
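A minimal sketch of that option, assuming a hypothetical helper `v_head_state_dict_kwargs` (not part of trlx) that builds the keyword arguments for `self.v_head.state_dict(...)`: an empty `prefix` is replaced by `"v_head."`, while a non-empty one raises instead of being silently dropped. Other keywords such as `destination` and `keep_vars` pass through unchanged.

```python
from typing import Any, Dict


def v_head_state_dict_kwargs(kwargs: Dict[str, Any], head_prefix: str = "v_head.") -> Dict[str, Any]:
    """Hypothetical helper: kwargs to forward to self.v_head.state_dict(...).

    An empty ``prefix`` (as DeepSpeed passes for the top-level module) is
    replaced by ``head_prefix``; any other ``prefix`` raises a ValueError.
    """
    merged = dict(kwargs)  # copy so the caller's kwargs are not mutated
    outer_prefix = merged.pop("prefix", "")
    if outer_prefix:
        raise ValueError(
            f"Unsupported non-empty prefix {outer_prefix!r} for the value head state dict"
        )
    merged["prefix"] = head_prefix
    return merged


# Keyword-style call as made by newer DeepSpeed versions:
print(v_head_state_dict_kwargs({"destination": None, "prefix": "", "keep_vars": False}))
```

Inside the wrapper, the failing line would then read `state_dict = self.v_head.state_dict(*args, **v_head_state_dict_kwargs(kwargs))`. An alternative is to compose the prefixes (`outer_prefix + "v_head."`), which is how `torch.nn.Module.state_dict` builds keys for child modules; whether that matches how trlx reloads these keys has not been checked here.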

Hope this helps somebody!

Which trlX version are you using?

trlx=0.7.0

Additional system and package information

deepspeed=0.14.4

Labels: bug (Something isn't working)
