Skip to content

Commit ef27a18

Browse files
committed
Hires fix rework
1 parent fd4461d commit ef27a18

File tree

7 files changed

+96
-60
lines changed

7 files changed

+96
-60
lines changed

modules/generation_parameters_copypaste.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22
import io
3+
import math
34
import os
45
import re
56
from pathlib import Path
@@ -164,6 +165,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
164165
return None
165166

166167

168+
def restore_old_hires_fix_params(res):
    """Convert the legacy "First pass size" infotext fields into the reworked hires fix fields.

    Old infotexts carry the final image size in ``Size-1``/``Size-2`` and the
    first pass size in ``First pass size-1``/``First pass size-2``. This
    function rewrites *res* in place so that ``Size-1``/``Size-2`` hold the
    first pass size and ``Hires upscale`` holds the factor that scales the
    first pass width back up to the originally recorded width.

    Args:
        res: dict of parsed infotext fields; modified in place.

    Returns:
        None. *res* is left untouched when the legacy fields are absent, when
        any involved size is not an integer, or when the recorded final size
        is not positive.
    """

    firstpass_width = res.get('First pass size-1', None)
    firstpass_height = res.get('First pass size-2', None)

    if firstpass_width is None or firstpass_height is None:
        return

    try:
        firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
        width = int(res.get("Size-1", 512))
        height = int(res.get("Size-2", 512))
    except (TypeError, ValueError):
        # malformed legacy fields: ignore them instead of failing the whole infotext parse
        return

    if width <= 0 or height <= 0:
        # no meaningful scale can be derived, and the auto-calculation below would divide by zero
        return

    if firstpass_width <= 0 or firstpass_height <= 0:
        # old algorithm for auto-calculating first pass size: aim for roughly
        # 512x512 pixels at the same aspect ratio, rounded up to multiples of 64
        desired_pixel_count = 512 * 512
        actual_pixel_count = width * height
        scale = math.sqrt(desired_pixel_count / actual_pixel_count)
        firstpass_width = math.ceil(scale * width / 64) * 64
        firstpass_height = math.ceil(scale * height / 64) * 64

    # both first pass dimensions are positive here, so the width ratio is always defined
    res['Size-1'] = firstpass_width
    res['Size-2'] = firstpass_height
    res['Hires upscale'] = width / firstpass_width
195+
196+
167197
def parse_generation_parameters(x: str):
168198
"""parses generation parameters string, the one you see in text field under the picture in UI:
169199
```
@@ -221,6 +251,8 @@ def parse_generation_parameters(x: str):
221251
hypernet_hash = res.get("Hypernet hash", None)
222252
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
223253

254+
restore_old_hires_fix_params(res)
255+
224256
return res
225257

226258

modules/images.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
230230
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
231231

232232

233-
def resize_image(resize_mode, im, width, height):
233+
def resize_image(resize_mode, im, width, height, upscaler_name=None):
234+
"""
235+
Resizes an image with the specified resize_mode, width, and height.
236+
237+
Args:
238+
resize_mode: The mode to use when resizing the image.
239+
0: Resize the image to the specified width and height.
240+
1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
241+
2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
242+
im: The image to resize.
243+
width: The width to resize the image to.
244+
height: The height to resize the image to.
245+
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
246+
"""
247+
248+
upscaler_name = upscaler_name or opts.upscaler_for_img2img
249+
234250
def resize(im, w, h):
235-
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
251+
if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
236252
return im.resize((w, h), resample=LANCZOS)
237253

238254
scale = max(w / im.width, h / im.height)
239255

240256
if scale > 1.0:
241-
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
242-
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
257+
upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
258+
assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
243259

244260
upscaler = upscalers[0]
245261
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)

modules/processing.py

Lines changed: 27 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -658,14 +658,18 @@ def infotext(iteration=0, position_in_batch=0):
658658
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
659659
sampler = None
660660

661-
def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
661+
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
662662
super().__init__(**kwargs)
663663
self.enable_hr = enable_hr
664664
self.denoising_strength = denoising_strength
665-
self.firstphase_width = firstphase_width
666-
self.firstphase_height = firstphase_height
667-
self.truncate_x = 0
668-
self.truncate_y = 0
665+
self.hr_scale = hr_scale
666+
self.hr_upscaler = hr_upscaler
667+
668+
if firstphase_width != 0 or firstphase_height != 0:
669+
print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
670+
self.hr_scale = self.width / firstphase_width
671+
self.width = firstphase_width
672+
self.height = firstphase_height
669673

670674
def init(self, all_prompts, all_seeds, all_subseeds):
671675
if self.enable_hr:
@@ -674,47 +678,29 @@ def init(self, all_prompts, all_seeds, all_subseeds):
674678
else:
675679
state.job_count = state.job_count * 2
676680

677-
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
678-
679-
if self.firstphase_width == 0 or self.firstphase_height == 0:
680-
desired_pixel_count = 512 * 512
681-
actual_pixel_count = self.width * self.height
682-
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
683-
self.firstphase_width = math.ceil(scale * self.width / 64) * 64
684-
self.firstphase_height = math.ceil(scale * self.height / 64) * 64
685-
firstphase_width_truncated = int(scale * self.width)
686-
firstphase_height_truncated = int(scale * self.height)
687-
688-
else:
689-
690-
width_ratio = self.width / self.firstphase_width
691-
height_ratio = self.height / self.firstphase_height
692-
693-
if width_ratio > height_ratio:
694-
firstphase_width_truncated = self.firstphase_width
695-
firstphase_height_truncated = self.firstphase_width * self.height / self.width
696-
else:
697-
firstphase_width_truncated = self.firstphase_height * self.width / self.height
698-
firstphase_height_truncated = self.firstphase_height
699-
700-
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
701-
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
681+
self.extra_generation_params["Hires upscale"] = self.hr_scale
682+
if self.hr_upscaler is not None:
683+
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
702684

703685
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
704686
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
705687

688+
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_default_mode
689+
if self.enable_hr and latent_scale_mode is None:
690+
assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
691+
692+
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
693+
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
694+
706695
if not self.enable_hr:
707-
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
708-
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
709696
return samples
710697

711-
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
712-
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
713-
714-
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
698+
target_width = int(self.width * self.hr_scale)
699+
target_height = int(self.height * self.hr_scale)
715700

716-
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
717701
def save_intermediate(image, index):
702+
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
703+
718704
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
719705
return
720706

@@ -723,11 +709,11 @@ def save_intermediate(image, index):
723709

724710
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
725711

726-
if opts.use_scale_latent_for_hires_fix:
712+
if latent_scale_mode is not None:
727713
for i in range(samples.shape[0]):
728714
save_intermediate(samples, i)
729715

730-
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
716+
samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)
731717

732718
# Avoid making the inpainting conditioning unless necessary as
733719
# this does need some extra compute to decode / encode the image again.
@@ -747,7 +733,7 @@ def save_intermediate(image, index):
747733

748734
save_intermediate(image, i)
749735

750-
image = images.resize_image(0, image, self.width, self.height)
736+
image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
751737
image = np.array(image).astype(np.float32) / 255.0
752738
image = np.moveaxis(image, 2, 0)
753739
batch_images.append(image)
@@ -764,7 +750,7 @@ def save_intermediate(image, index):
764750

765751
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
766752

767-
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
753+
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
768754

769755
# GC now before running the next img2img to prevent running out of memory
770756
x = None

modules/shared.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,6 @@ def list_samplers():
327327
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
328328
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
329329
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
330-
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
331330
}))
332331

333332
options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -545,6 +544,12 @@ def reorder(self):
545544
if os.path.exists(config_filename):
546545
opts.load(config_filename)
547546

547+
latent_upscale_default_mode = "Latent"
548+
latent_upscale_modes = {
549+
"Latent": "bilinear",
550+
"Latent (nearest)": "nearest",
551+
}
552+
548553
sd_upscalers = []
549554

550555
sd_model = None

modules/txt2img.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from modules.ui import plaintext_to_html
99

1010

11-
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
11+
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args):
1212
p = StableDiffusionProcessingTxt2Img(
1313
sd_model=shared.sd_model,
1414
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -33,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
3333
tiling=tiling,
3434
enable_hr=enable_hr,
3535
denoising_strength=denoising_strength if enable_hr else None,
36-
firstphase_width=firstphase_width if enable_hr else None,
37-
firstphase_height=firstphase_height if enable_hr else None,
36+
hr_scale=hr_scale,
37+
hr_upscaler=hr_upscaler,
3838
)
3939

4040
p.scripts = modules.scripts.scripts_txt2img

modules/ui.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -684,11 +684,11 @@ def create_ui():
684684
with gr.Row():
685685
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
686686
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
687-
enable_hr = gr.Checkbox(label='Highres. fix', value=False, elem_id="txt2img_enable_hr")
687+
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
688688

689689
with gr.Row(visible=False) as hr_options:
690-
firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0, elem_id="txt2img_firstphase_width")
691-
firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0, elem_id="txt2img_firstphase_height")
690+
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
691+
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
692692
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
693693

694694
with gr.Row(equal_height=True):
@@ -729,8 +729,8 @@ def create_ui():
729729
width,
730730
enable_hr,
731731
denoising_strength,
732-
firstphase_width,
733-
firstphase_height,
732+
hr_scale,
733+
hr_upscaler,
734734
] + custom_inputs,
735735

736736
outputs=[
@@ -762,7 +762,6 @@ def create_ui():
762762
outputs=[hr_options],
763763
)
764764

765-
766765
txt2img_paste_fields = [
767766
(txt2img_prompt, "Prompt"),
768767
(txt2img_negative_prompt, "Negative prompt"),
@@ -781,8 +780,8 @@ def create_ui():
781780
(denoising_strength, "Denoising strength"),
782781
(enable_hr, lambda d: "Denoising strength" in d),
783782
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
784-
(firstphase_width, "First pass size-1"),
785-
(firstphase_height, "First pass size-2"),
783+
(hr_scale, "Hires upscale"),
784+
(hr_upscaler, "Hires upscaler"),
786785
*modules.scripts.scripts_txt2img.infotext_fields
787786
]
788787
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)

scripts/xy_grid.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def str_permutations(x):
202202
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
203203
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
204204
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
205-
AxisOption("Upscale latent space for hires.", str, apply_upscale_latent_space, format_value_add_label, None),
205+
AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
206206
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
207207
AxisOption("VAE", str, apply_vae, format_value_add_label, None),
208208
AxisOption("Styles", str, apply_styles, format_value_add_label, None),
@@ -267,7 +267,6 @@ def __enter__(self):
267267
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
268268
self.hypernetwork = opts.sd_hypernetwork
269269
self.model = shared.sd_model
270-
self.use_scale_latent_for_hires_fix = opts.use_scale_latent_for_hires_fix
271270
self.vae = opts.sd_vae
272271

273272
def __exit__(self, exc_type, exc_value, tb):
@@ -278,7 +277,6 @@ def __exit__(self, exc_type, exc_value, tb):
278277
hypernetwork.apply_strength()
279278

280279
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
281-
opts.data["use_scale_latent_for_hires_fix"] = self.use_scale_latent_for_hires_fix
282280

283281

284282
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")

0 commit comments

Comments
 (0)