
Online filter #93

Draft · wants to merge 34 commits into main

137 changes: 115 additions & 22 deletions compose_rl/algorithms/online/callback.py
@@ -54,13 +54,16 @@
add_right_padding,
compute_advantages,
dist_compute_masked_mean_and_var,
filter_resolved_outputs,
flatten,
get_decoded_sequence,
get_entropies,
get_log_probs,
mask_eos,
masked_mean,
masked_sum,
partition_batch,
stack_resolved_outputs,
switch_left_to_right_padding,
)

@@ -428,10 +431,10 @@ def __init__(
f'Per iteration using: {self.num_unique_prompts_per_iter} prompts.',
)

if self.num_unique_prompts_per_iter * self.generations_per_prompt != self.global_train_batch_size * self.num_batches_per_update:
raise ValueError(
f'{self.num_unique_prompts_per_iter=} * {self.generations_per_prompt=} must equal {self.global_train_batch_size=} * {self.num_batches_per_update=}',
)
# if self.num_unique_prompts_per_iter * self.generations_per_prompt != self.global_train_batch_size * self.num_batches_per_update:
# raise ValueError(
# f'{self.num_unique_prompts_per_iter=} * {self.generations_per_prompt=} must equal {self.global_train_batch_size=} * {self.num_batches_per_update=}',
# )

self.epochs_per_iteration = ensure_time(
var_config.get('epoch_per_iteration', 1),
@@ -441,7 +444,9 @@

# Programmatically setting the max buffer size instead of the yaml
var_config['buffer']['max_buffer_size'] = self.num_batches_per_update

self.buffer = MinibatchRolloutBuffer(var_config['buffer'])
self.global_sample_list = []

# Build the KL controller through registries
kl_ctl_name = var_config['kl_controller'].pop('kl_ctl_type')
@@ -470,6 +475,12 @@ def __init__(
train_config['python_log_level'].upper(),
)

self.same_reward_filter_threshold = var_config.get(
'same_reward_filter_threshold',
None,
)

self.use_kl_penalty = var_config.get('use_kl_penalty', True)
self.vllm_engines = None
self.num_vllm_engines = 0
self.vllm_tensor_parallel_size = var_config.get(
@@ -535,6 +546,8 @@ def init(self, state: State, logger: Logger):
if self.device_train_microbatch_size == 'auto': # type: ignore
raise ValueError('auto microbatching is not supported for PPO')

self.global_iter_batch_size = self.num_batches_per_update * self.global_train_batch_size

# The KL penalty in the reward should only exist if we aren't minimizing
# the KL directly in the loss.
kl_penalty_in_reward = True
@@ -550,6 +563,7 @@ def init(self, state: State, logger: Logger):
fsdp_config=self.non_train_fsdp_config,
precision=state.precision,
kl_penalty_in_reward=kl_penalty_in_reward,
use_kl_penalty=self.use_kl_penalty,
)

# This is needed to ensure PyTorch 2.4 checkpointing doesn't break
@@ -602,13 +616,69 @@ def after_load(self, state: State, logger: Logger):
def iteration_start(self, state: State, logger: Logger):
del logger # unused

batch = self._get_next_iter_prompts()
batch = state.device.batch_to_device(batch)

batch = self._get_next_iter_prompts(state)
if self.vllm_engines is not None:
self._update_inference_model(batch)

self._interact_with_env(batch)
num_env_interactions = 0
while len(self.buffer) < self.num_batches_per_update:
if num_env_interactions > 0:
batch = self._get_next_iter_prompts(state)

num_env_interactions += 1

# TODO: handle the case where we are not filtering.
# We do not do an all-gather, so this logic is slightly wrong right now.
self._interact_with_env(batch)

cur_global_samples = stack_resolved_outputs(
self.global_sample_list,
self.pad_token_idx,
)

bs = cur_global_samples['prompt_id'].shape[0]

log.info(f"Current global batch size is {bs}.")
log.info(
f"Current global iter batch size is {self.global_iter_batch_size}.",
)

if bs >= self.global_iter_batch_size:
log.info(
'We have enough samples, adding samples to the buffer.',
)
rank = dist.get_global_rank()
world_size = dist.get_world_size()

local_samples = {}
for key, value in cur_global_samples.items():
local_samples[key] = partition_batch(
value,
world_size=world_size,
rank=rank,
device_train_batch_size=self.device_train_batch_size,
)

local_bs = local_samples['prompt_id'].shape[0]
# Add the local samples to the buffer
for idx in range(local_bs // self.device_train_batch_size):
minibatch = self._extract_minibatch(
batch=local_samples,
idx=idx,
minibatch_size=self.device_train_batch_size,
)
self.buffer.add(minibatch)

log.info(
f"For iteration {self.iter_num}, we have {len(self.buffer)} samples in the buffer. Starting training.",
)
log.info(
f"It took {num_env_interactions} environment interactions to fill the buffer.",
)

# Making sure we correctly parsed the minibatches
assert len(self.buffer) >= self.num_batches_per_update

# Reset and initialize state train dataloader
log.warning(
'trainer._train_data_spec should be updated whenever the dataloader is updated',
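
Note: `partition_batch` is imported from compose_rl's utils above, but its implementation is not part of this diff. The sketch below shows the per-rank slicing it is assumed to perform here (a contiguous, device-batch-aligned block of the globally gathered rows for each rank); the function name, signature, and slicing rule are illustrative assumptions, not the library's actual code.

import torch


def partition_batch_sketch(
    value: torch.Tensor,
    world_size: int,
    rank: int,
    device_train_batch_size: int,
) -> torch.Tensor:
    """Give each rank an equal number of full device-sized minibatches."""
    # Only keep rows that fill complete minibatches on every rank.
    num_minibatches_per_rank = value.shape[0] // (world_size * device_train_batch_size)
    rows_per_rank = num_minibatches_per_rank * device_train_batch_size
    start = rank * rows_per_rank
    return value[start:start + rows_per_rank]

Under this assumption, `local_bs // self.device_train_batch_size` in the loop above equals `num_minibatches_per_rank`, so the buffer only ever receives whole minibatches.
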
@@ -635,16 +705,25 @@ def iteration_end(self, state: State, logger: Logger):
del logger # unused
self._log_generations_to_logger(state)
self._increment_rl_iter()

self.buffer.reset()

# A list of all samples across ranks
# These can be filtered or unfiltered
self.global_sample_list = []

self.buffer.set_state_dict(
self.train_prompt_loader.state_dict(), # pyright: ignore
0,
)

def _get_next_iter_prompts(self):
def _get_next_iter_prompts(self, state: State):
"""Gets the next iteration's batch of prompts."""
# Sample fewer batches for the Online RL iteration depending on the number of generations per prompt
n_unique_batches = self.num_unique_prompts_per_iter // self.global_train_batch_size
log.info(
f"Getting {n_unique_batches} unique batches of prompts for the current iteration.",
)

batches = [
self._get_single_batch_prompts() for _ in range(n_unique_batches)
]
@@ -696,7 +775,7 @@ def _get_next_iter_prompts(self):
# this is an edge case that we will not hit currently, but just handling it as needed
ret_batch[key] = curr_values

return ret_batch
return state.device.batch_to_device(ret_batch)

def _get_single_batch_prompts(self):
"""Gets a single batch of prompts from the dataloader."""
@@ -770,6 +849,7 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]):
)
padded_sequences.append(padded_sequence)
sequences = torch.cat(padded_sequences, dim=0)

# Add the prepared sequences to the batch again
batch['sequences'] = sequences

@@ -795,8 +875,6 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]):
f'Finished reward computation for the rollout in {total_reward_time:.4f} seconds.',
)

self.prompts_and_gens.extend(prompts_and_gens)

gen_batch_partial_outputs = (env_outputs, ref_outputs, all_rewards_dict)
# Resolve all of the partial outputs together
# and compute the global per-iteration batch advantage mean and variance
@@ -805,17 +883,32 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]):
gen_batch_partial_outputs,
)

# We need to split the resolved outputs into minibatches
for idx in range(bs // self.device_train_batch_size):
minibatch = self._extract_minibatch(
resolved_outputs,
idx,
self.device_train_batch_size,
if self.same_reward_filter_threshold is not None:
log.info(
f"in reward thresholding, trying to filter with: {self.same_reward_filter_threshold}",
)
self.buffer.add(minibatch)
start_time = time.time()
all_gathered_outputs = dist.all_gather_object(resolved_outputs)

# Making sure we correctly parsed the minibatches
assert len(self.buffer) == self.num_batches_per_update
log.info(
f"It took {time.time() - start_time} seconds to gather all resolved outputs.",
)

all_resolved_outputs = stack_resolved_outputs(
all_gathered_outputs,
self.pad_token_idx,
)

# Filter the resolved outputs based on the generation filtering values
resolved_outputs = filter_resolved_outputs(
all_resolved_outputs,
self.same_reward_filter_threshold,
)

self.global_sample_list.append(resolved_outputs)

# TODO: bcui fix
self.prompts_and_gens.extend(prompts_and_gens)

self.actor_critic.train()

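
Note: `stack_resolved_outputs` and `filter_resolved_outputs` are imported at the top of callback.py, but their implementations are not shown in this diff. The sketch below illustrates one way same-reward filtering could work: drop prompt groups whose generations all received (nearly) the same reward, since their group-relative advantages carry no learning signal. It assumes a scalar reward per generation and `prompt_id`/`rewards` keys in the resolved outputs; the function name and the spread-based test are illustrative, not the PR's actual code.

import torch


def filter_same_reward_groups(
    resolved_outputs: dict[str, torch.Tensor],
    threshold: float,
) -> dict[str, torch.Tensor]:
    """Drop generations whose prompt group has near-identical rewards."""
    rewards = resolved_outputs['rewards']
    prompt_ids = resolved_outputs['prompt_id']
    keep = torch.ones(rewards.shape[0], dtype=torch.bool)
    for pid in prompt_ids.unique():
        group = prompt_ids == pid
        # If the reward spread within a group is at or below the threshold,
        # the advantages are ~0 and the group adds no gradient signal.
        if (rewards[group].max() - rewards[group].min()) <= threshold:
            keep[group] = False
    return {key: value[keep] for key, value in resolved_outputs.items()}
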
5 changes: 5 additions & 0 deletions compose_rl/algorithms/online/model.py
@@ -289,6 +289,11 @@ def __init__(
self.kl_clip_range = kl_clip_range
self.entropy_loss_weight = entropy_loss_weight

print(
'length normalize policy loss: ',
self.length_normalize_policy_loss,
)

def forward(self, batch: MutableMapping):
ret_val = composer_online_rl_forward(
batch,
14 changes: 11 additions & 3 deletions compose_rl/algorithms/online/model_methods.py
@@ -288,11 +288,11 @@ def policy_loss(
policy_kl = utils.masked_sum(
policy_kl_dict[kl_estimator], # pyright: ignore
batch['action_mask'],
)
) / batch['max_gen_len']
online_ift_kl = utils.masked_sum(
online_ift_kl_dict[kl_estimator], # pyright: ignore
batch['action_mask'],
)
) / batch['max_gen_len']

ratio = torch.exp(online_log_probs - old_log_probs)
policy_loss_1 = -advantages * ratio
@@ -319,7 +319,15 @@ def policy_loss(
batch['action_mask'],
)
else:
policy_loss = utils.masked_sum(policy_loss, batch['action_mask'])
print(
'loss was: ',
utils.masked_sum(policy_loss, batch['action_mask']),
)
policy_loss = utils.masked_sum(
policy_loss,
batch['action_mask'],
) / batch['max_gen_len']
print('loss after now is: ', policy_loss)

policy_token_kl_logging_dict = {
f'token_kl/policy_token_kl_{k}_estimate':
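
Note: the policy_loss change above replaces the plain masked sum (and the summed KL terms) with a sum divided by `batch['max_gen_len']`. The snippet below is a small numerical sketch of what that normalization changes, using stand-in reductions rather than the actual `compose_rl.utils` helpers and assuming `max_gen_len` is a per-batch scalar.

import torch


def masked_sum(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    return (x * mask).sum()


def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    return (x * mask).sum() / mask.sum()


# Two responses: one 2 tokens long, one 4 tokens long, padded to max_gen_len = 4.
per_token_loss = torch.tensor([[0.5, 0.5, 0.0, 0.0],
                               [0.5, 0.5, 0.5, 0.5]])
action_mask = torch.tensor([[1., 1., 0., 0.],
                            [1., 1., 1., 1.]])
max_gen_len = 4

print(masked_mean(per_token_loss, action_mask))
# 0.50 -> normalizes by the number of generated tokens (6 here), so the
#         per-token weight depends on how long the responses are.
print(masked_sum(per_token_loss, action_mask) / max_gen_len)
# 0.75 -> normalizes by the fixed max_gen_len (4), so every token always
#         carries weight 1/max_gen_len regardless of response length.

The same fixed-length divisor is applied to `policy_kl` and `online_ift_kl` in the diff, which keeps the KL terms on the same scale as the normalized policy loss.
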