@@ -658,14 +658,18 @@ def infotext(iteration=0, position_in_batch=0):
658
658
class StableDiffusionProcessingTxt2Img (StableDiffusionProcessing ):
659
659
sampler = None
660
660
661
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
    """Set up txt2img processing, including the optional hires pass.

    Args:
        enable_hr: Stored as ``self.enable_hr``; ``sample`` returns right
            after the first pass when this is false.
        denoising_strength: Stored as ``self.denoising_strength``.
        firstphase_width: Legacy first-pass width. Only honoured together
            with ``firstphase_height`` (see below); 0 means "not given".
        firstphase_height: Legacy first-pass height; 0 means "not given".
        hr_scale: Multiplier applied to ``self.width`` / ``self.height`` to
            obtain the hires target size in ``sample``.
        hr_upscaler: Name of the upscaler for the hires pass, or ``None``.
            ``sample`` looks it up in ``shared.latent_upscale_modes`` and
            ``shared.sd_upscalers``.
        **kwargs: Forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(**kwargs)
    self.enable_hr = enable_hr
    self.denoising_strength = denoising_strength
    self.hr_scale = hr_scale
    self.hr_upscaler = hr_upscaler

    if firstphase_width != 0 or firstphase_height != 0:
        print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)

        # Translate the legacy "first pass size" into the new scheme only
        # when BOTH dimensions are usable. A half-specified pair (one value
        # left at 0) would otherwise divide by zero or shrink the first
        # pass to a zero-sized image; in that case keep hr_scale as passed.
        # Assumes the parent __init__ has set self.width / self.height from
        # kwargs -- TODO confirm against StableDiffusionProcessing.
        if firstphase_width > 0 and firstphase_height > 0:
            # The requested size becomes the hires target: the first pass
            # runs at the legacy size and is scaled back up by hr_scale.
            self.hr_scale = self.width / firstphase_width
            self.width = firstphase_width
            self.height = firstphase_height
669
673
670
674
def init (self , all_prompts , all_seeds , all_subseeds ):
671
675
if self .enable_hr :
@@ -674,47 +678,29 @@ def init(self, all_prompts, all_seeds, all_subseeds):
674
678
else :
675
679
state .job_count = state .job_count * 2
676
680
677
- self .extra_generation_params ["First pass size" ] = f"{ self .firstphase_width } x{ self .firstphase_height } "
678
-
679
- if self .firstphase_width == 0 or self .firstphase_height == 0 :
680
- desired_pixel_count = 512 * 512
681
- actual_pixel_count = self .width * self .height
682
- scale = math .sqrt (desired_pixel_count / actual_pixel_count )
683
- self .firstphase_width = math .ceil (scale * self .width / 64 ) * 64
684
- self .firstphase_height = math .ceil (scale * self .height / 64 ) * 64
685
- firstphase_width_truncated = int (scale * self .width )
686
- firstphase_height_truncated = int (scale * self .height )
687
-
688
- else :
689
-
690
- width_ratio = self .width / self .firstphase_width
691
- height_ratio = self .height / self .firstphase_height
692
-
693
- if width_ratio > height_ratio :
694
- firstphase_width_truncated = self .firstphase_width
695
- firstphase_height_truncated = self .firstphase_width * self .height / self .width
696
- else :
697
- firstphase_width_truncated = self .firstphase_height * self .width / self .height
698
- firstphase_height_truncated = self .firstphase_height
699
-
700
- self .truncate_x = int (self .firstphase_width - firstphase_width_truncated ) // opt_f
701
- self .truncate_y = int (self .firstphase_height - firstphase_height_truncated ) // opt_f
681
+ self .extra_generation_params ["Hires upscale" ] = self .hr_scale
682
+ if self .hr_upscaler is not None :
683
+ self .extra_generation_params ["Hires upscaler" ] = self .hr_upscaler
702
684
703
685
def sample (self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength , prompts ):
704
686
self .sampler = sd_samplers .create_sampler (self .sampler_name , self .sd_model )
705
687
688
+ latent_scale_mode = shared .latent_upscale_modes .get (self .hr_upscaler , None ) if self .hr_upscaler is not None else shared .latent_upscale_default_mode
689
+ if self .enable_hr and latent_scale_mode is None :
690
+ assert len ([x for x in shared .sd_upscalers if x .name == self .hr_upscaler ]) > 0 , f"could not find upscaler named { self .hr_upscaler } "
691
+
692
+ x = create_random_tensors ([opt_C , self .height // opt_f , self .width // opt_f ], seeds = seeds , subseeds = subseeds , subseed_strength = self .subseed_strength , seed_resize_from_h = self .seed_resize_from_h , seed_resize_from_w = self .seed_resize_from_w , p = self )
693
+ samples = self .sampler .sample (self , x , conditioning , unconditional_conditioning , image_conditioning = self .txt2img_image_conditioning (x ))
694
+
706
695
if not self .enable_hr :
707
- x = create_random_tensors ([opt_C , self .height // opt_f , self .width // opt_f ], seeds = seeds , subseeds = subseeds , subseed_strength = self .subseed_strength , seed_resize_from_h = self .seed_resize_from_h , seed_resize_from_w = self .seed_resize_from_w , p = self )
708
- samples = self .sampler .sample (self , x , conditioning , unconditional_conditioning , image_conditioning = self .txt2img_image_conditioning (x ))
709
696
return samples
710
697
711
- x = create_random_tensors ([opt_C , self .firstphase_height // opt_f , self .firstphase_width // opt_f ], seeds = seeds , subseeds = subseeds , subseed_strength = self .subseed_strength , seed_resize_from_h = self .seed_resize_from_h , seed_resize_from_w = self .seed_resize_from_w , p = self )
712
- samples = self .sampler .sample (self , x , conditioning , unconditional_conditioning , image_conditioning = self .txt2img_image_conditioning (x , self .firstphase_width , self .firstphase_height ))
713
-
714
- samples = samples [:, :, self .truncate_y // 2 :samples .shape [2 ]- self .truncate_y // 2 , self .truncate_x // 2 :samples .shape [3 ]- self .truncate_x // 2 ]
698
+ target_width = int (self .width * self .hr_scale )
699
+ target_height = int (self .height * self .hr_scale )
715
700
716
- """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
717
701
def save_intermediate (image , index ):
702
+ """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
703
+
718
704
if not opts .save or self .do_not_save_samples or not opts .save_images_before_highres_fix :
719
705
return
720
706
@@ -723,11 +709,11 @@ def save_intermediate(image, index):
723
709
724
710
images .save_image (image , self .outpath_samples , "" , seeds [index ], prompts [index ], opts .samples_format , suffix = "-before-highres-fix" )
725
711
726
- if opts . use_scale_latent_for_hires_fix :
712
+ if latent_scale_mode is not None :
727
713
for i in range (samples .shape [0 ]):
728
714
save_intermediate (samples , i )
729
715
730
- samples = torch .nn .functional .interpolate (samples , size = (self . height // opt_f , self . width // opt_f ), mode = "bilinear" )
716
+ samples = torch .nn .functional .interpolate (samples , size = (target_height // opt_f , target_width // opt_f ), mode = latent_scale_mode )
731
717
732
718
# Avoid making the inpainting conditioning unless necessary as
733
719
# this does need some extra compute to decode / encode the image again.
@@ -747,7 +733,7 @@ def save_intermediate(image, index):
747
733
748
734
save_intermediate (image , i )
749
735
750
- image = images .resize_image (0 , image , self . width , self .height )
736
+ image = images .resize_image (0 , image , target_width , target_height , upscaler_name = self .hr_upscaler )
751
737
image = np .array (image ).astype (np .float32 ) / 255.0
752
738
image = np .moveaxis (image , 2 , 0 )
753
739
batch_images .append (image )
@@ -764,7 +750,7 @@ def save_intermediate(image, index):
764
750
765
751
self .sampler = sd_samplers .create_sampler (self .sampler_name , self .sd_model )
766
752
767
- noise = create_random_tensors (samples .shape [1 :], seeds = seeds , subseeds = subseeds , subseed_strength = subseed_strength , seed_resize_from_h = self . seed_resize_from_h , seed_resize_from_w = self . seed_resize_from_w , p = self )
753
+ noise = create_random_tensors (samples .shape [1 :], seeds = seeds , subseeds = subseeds , subseed_strength = subseed_strength , p = self )
768
754
769
755
# GC now before running the next img2img to prevent running out of memory
770
756
x = None
0 commit comments