diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index c8f75646fe..4db326a13b 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -263,7 +263,7 @@ def update_policy(self, data: DataProto): kl_penalty=self.config.kl_loss_type) kl_loss = masked_mean(kld, response_mask) - policy_loss = policy_loss - kl_loss * self.config.kl_loss_coef + policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef metrics['actor/kl_loss'] = kl_loss.detach().item() metrics['actor/kl_coef'] = self.config.kl_loss_coef