Commit 756223c

fix an issue with modality being given as prompt, with subsequent modalities decoded

1 parent 23e43de

2 files changed, 6 insertions(+), 3 deletions(-)

pyproject.toml (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.10.4"
+version = "0.10.5"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

transfusion_pytorch/transfusion.py (5 additions, 2 deletions)

@@ -1580,10 +1580,12 @@ def sample(
     if is_tensor(prompt) and prompt.dtype == torch.float: # is modality with type 0 implicit
         prompt = (0, prompt)

+    prompt_is_modality = isinstance(prompt, tuple)
+
     if is_tensor(prompt) and prompt.dtype in (torch.int, torch.long): # is text only prompt
         prompt = [prompt]

-    elif isinstance(prompt, tuple):
+    elif prompt_is_modality:
         modality_type, modality = prompt

         mod = self.get_modality_info(modality_type)
@@ -1625,7 +1627,8 @@ def sample(
     curr_length = 0
     curr_modality_id = None
     modality_shape = None
-    num_past_modalities = 0 # starts off with no modalities in output
+
+    num_past_modalities = int(prompt_is_modality) # either 0 or 1 (if the prompt given is a modality)

     text_is_greedy = text_temperature == 0.
     is_decoding_text = True # starts off with text decoding, and alternates with modalities depending on [som] tokens detected
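The logic of this fix can be sketched in isolation: when the prompt is itself a modality (a float tensor, with type 0 implicit, or an explicit `(type, tensor)` tuple), the counter of modalities already present in the output must start at 1 rather than 0, so the modalities decoded afterwards are indexed past it. A minimal, self-contained sketch follows; the helper names `normalize_prompt` and `initial_num_past_modalities` are hypothetical, and plain Python lists stand in for torch tensors (floats for a modality, ints for text tokens).

```python
# Hypothetical sketch of the prompt-normalization logic this commit fixes.
# Not the library's actual code: lists of floats/ints stand in for torch
# float/int tensors, and helper names are made up for illustration.

def normalize_prompt(prompt):
    """A list of floats is an implicit type-0 modality; a list of ints
    is a text-only prompt and stays as-is."""
    if isinstance(prompt, list) and all(isinstance(x, float) for x in prompt):
        prompt = (0, prompt)  # wrap as (modality_type, data), type 0 implicit

    # record this BEFORE any further rewrapping, mirroring the commit
    prompt_is_modality = isinstance(prompt, tuple)
    return prompt, prompt_is_modality

def initial_num_past_modalities(prompt_is_modality):
    # the core of the fix: start the counter at 1 when the prompt
    # already contributes one modality to the output, else at 0
    return int(prompt_is_modality)
```

Before the fix, the counter always started at 0, so the first modality decoded after a modality prompt reused the prompt's slot; computing the flag once up front also lets the later `elif` branch reuse it instead of re-checking `isinstance(prompt, tuple)`.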
