This commit is contained in:
Minh Nguyen 2025-03-19 14:56:32 +01:00 committed by GitHub
commit 31bd7a5bbc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -19,6 +19,21 @@ def inverse_sigmoid(x):
return torch.log(x/(1-x))
def PILtoTorch(pil_image, resolution):
# When resizing RGBA, PIL pre-multiplies the resulting RGB with the resized alpha channel. This gives
# different training behaviors depending on whether the image is actually resized (via -r flag) or not.
# Moreover, the resized alpha is no longer a perfect binary image due to interpolation, which produces
# a significant amount of floaters along the edges. To fix this, we manually mask the RGB if the input
# is an RGBA, then we forget the alpha channel entirely. The multiplication of the rendered image with
# the alpha_mask during training thus becomes a no-op for RGBA.
if pil_image.mode == 'RGBA':
from PIL import Image
image_np = np.array(pil_image)
rgb_np = image_np[..., :3]
alpha_np = image_np[..., 3:]
masked_rgb_np = (rgb_np / 255.0) * (alpha_np / 255.0)
masked_rgb_np = np.clip(masked_rgb_np, 0.0, 1.0)
pil_image = Image.fromarray((masked_rgb_np * 255).astype(np.uint8))
resized_image_PIL = pil_image.resize(resolution)
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
if len(resized_image.shape) == 3: