@@ -128,9 +128,12 @@ def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.Gener
128
128
if opts .n <= 0 : return []
129
129
if opts .weighted :
130
130
probs_np = probs_cpu .detach ().numpy ()
131
-
132
- if np .count_nonzero (probs_np ) <= opts .n :
133
- results = np .random .choice (a = tags_np , size = opts .n , replace = False )
131
+ num_nonzero = np .count_nonzero (probs_np )
132
+ if num_nonzero <= opts .n :
133
+ if num_nonzero > 0 :
134
+ results = np .random .choice (tags_np , num_nonzero , replace = False , p = probs_np )
135
+ else :
136
+ results = np .random .choice (tags_np , opts .n , replace = False )
134
137
else :
135
138
results = np .random .choice (a = tags_np , size = opts .n , replace = False , p = probs_np )
136
139
else :
@@ -148,9 +151,12 @@ def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.Gener
148
151
probs_np = probs .detach ().numpy ()
149
152
probs_np /= np .sum (probs_np )
150
153
probs_np = np .nan_to_num (probs_np )
151
-
152
- if np .count_nonzero (probs_np ) <= opts .n :
153
- results = np .random .choice (tags_np , opts .n , replace = False )
154
+ num_nonzero = np .count_nonzero (probs_np )
155
+ if num_nonzero <= opts .n :
156
+ if num_nonzero > 0 :
157
+ results = np .random .choice (tags_np , num_nonzero , replace = False , p = probs_np )
158
+ else :
159
+ results = np .random .choice (tags_np , opts .n , replace = False )
154
160
else :
155
161
results = np .random .choice (tags_np , opts .n , replace = False , p = probs_np )
156
162
else :
@@ -174,9 +180,12 @@ def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.Gener
174
180
probs_np = np .array ([sorted_probs [i ] for i in indices ])
175
181
probs_np /= np .sum (probs_np )
176
182
probs_np = np .nan_to_num (probs_np )
177
-
178
- if np .count_nonzero (probs_np ) <= opts .n :
179
- results = np .random .choice (tags_np , opts .n , replace = False )
183
+ num_nonzero = np .count_nonzero (probs_np )
184
+ if num_nonzero <= opts .n :
185
+ if num_nonzero > 0 :
186
+ results = np .random .choice (tags_np , num_nonzero , replace = False , p = probs_np )
187
+ else :
188
+ results = np .random .choice (tags_np , opts .n , replace = False )
180
189
else :
181
190
results = np .random .choice (tags_np , opts .n , replace = False , p = probs_np )
182
191
else :
0 commit comments