diff --git a/utils/general_utils.py b/utils/general_utils.py index 541c082..ca9f62b 100644 --- a/utils/general_utils.py +++ b/utils/general_utils.py @@ -19,6 +19,21 @@ def inverse_sigmoid(x): return torch.log(x/(1-x)) def PILtoTorch(pil_image, resolution): + # When resizing RGBA, PIL pre-multiplies the resulting RGB with the resized alpha channel. This gives + # different training behaviors depending on whether the image is actually resized (via -r flag) or not. + # Moreover, the resized alpha is no longer a perfect binary image due to interpolation, which produces + # a significant amount of floaters along the edges. To fix this, we manually mask the RGB if the input + # is an RGBA, then we forget the alpha channel entirely. The multiplication of the rendered image with + # the alpha_mask during training thus becomes a no-op for RGBA. + if pil_image.mode == 'RGBA': + from PIL import Image + image_np = np.array(pil_image) + rgb_np = image_np[..., :3] + alpha_np = image_np[..., 3:] + masked_rgb_np = (rgb_np / 255.0) * (alpha_np / 255.0) + masked_rgb_np = np.clip(masked_rgb_np, 0.0, 1.0) + pil_image = Image.fromarray((masked_rgb_np * 255).astype(np.uint8)) + resized_image_PIL = pil_image.resize(resolution) resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 if len(resized_image.shape) == 3: