Skip to content

Commit f45016a

Browse files
committed
Update wd_like.py
1 parent 238318c commit f45016a

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

scripts/t2p/prompt_generator/wd_like.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,12 @@ def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.Gener
128128
if opts.n <= 0: return []
129129
if opts.weighted:
130130
probs_np = probs_cpu.detach().numpy()
131-
132-
if np.count_nonzero(probs_np) <= opts.n:
133-
results = np.random.choice(a=tags_np, size=opts.n, replace=False)
131+
num_nonzero = np.count_nonzero(probs_np)
132+
if num_nonzero <= opts.n:
133+
if num_nonzero > 0:
134+
results=np.random.choice(tags_np, num_nonzero, replace=False, p=probs_np)
135+
else:
136+
results = np.random.choice(tags_np, opts.n, replace=False)
134137
else:
135138
results = np.random.choice(a=tags_np, size=opts.n, replace=False, p=probs_np)
136139
else:
@@ -148,9 +151,12 @@ def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.Gener
148151
probs_np = probs.detach().numpy()
149152
probs_np /= np.sum(probs_np)
150153
probs_np = np.nan_to_num(probs_np)
151-
152-
if np.count_nonzero(probs_np) <= opts.n:
153-
results = np.random.choice(tags_np, opts.n, replace=False)
154+
num_nonzero = np.count_nonzero(probs_np)
155+
if num_nonzero <= opts.n:
156+
if num_nonzero > 0:
157+
results=np.random.choice(tags_np, num_nonzero, replace=False, p=probs_np)
158+
else:
159+
results = np.random.choice(tags_np, opts.n, replace=False)
154160
else:
155161
results = np.random.choice(tags_np, opts.n, replace=False, p=probs_np)
156162
else:
@@ -174,9 +180,12 @@ def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.Gener
174180
probs_np = np.array([sorted_probs[i] for i in indices])
175181
probs_np /= np.sum(probs_np)
176182
probs_np = np.nan_to_num(probs_np)
177-
178-
if np.count_nonzero(probs_np) <= opts.n:
179-
results = np.random.choice(tags_np, opts.n, replace=False)
183+
num_nonzero = np.count_nonzero(probs_np)
184+
if num_nonzero <= opts.n:
185+
if num_nonzero > 0:
186+
results=np.random.choice(tags_np, num_nonzero, replace=False, p=probs_np)
187+
else:
188+
results = np.random.choice(tags_np, opts.n, replace=False)
180189
else:
181190
results = np.random.choice(tags_np, opts.n, replace=False, p=probs_np)
182191
else:

0 commit comments

Comments
 (0)