diff --git a/main.py b/main.py index 0845d58..6bddd0a 100644 --- a/main.py +++ b/main.py @@ -7,49 +7,44 @@ # # Copyright (c) 2021 Solal Nathan # Author: Solal "Otthorn" Nathan -# SPDX-License-Identifier: BSD-3-Clause +# SPDX-License-Identifier: BSD-3-Clause # ############################################################################### +import glob +import json +import math import os import sys import time -import json -import glob -import torch -import math + import matplotlib.pyplot as plt import numpy as np +import torch from PIL import Image -from tqdm import tqdm - +from pytorch3d.renderer import (EmissionAbsorptionRaymarcher, + FoVPerspectiveCameras, ImplicitRenderer, + MonteCarloRaysampler, NDCGridRaysampler, + RayBundle, ray_bundle_to_ray_points) # Data structures and functions for rendering from pytorch3d.structures import Volumes from pytorch3d.transforms import so3_exponential_map -from pytorch3d.renderer import ( - FoVPerspectiveCameras, - NDCGridRaysampler, - MonteCarloRaysampler, - EmissionAbsorptionRaymarcher, - ImplicitRenderer, - RayBundle, - ray_bundle_to_ray_points, -) +from tqdm import tqdm # add path for demo utils functions -sys.path.append(os.path.abspath('')) -from utils.plot_image_grid import image_grid +sys.path.append(os.path.abspath("")) from utils.generate_cow_renders import generate_cow_renders - +from utils.plot_image_grid import image_grid # Intialize CUDA gpu device = torch.device("cuda:0") torch.cuda.set_device(device) # Generate dataset -target_cameras, target_images, target_silhouettes = \ - generate_cow_renders(num_views=40, azimuth_range=180) -print(f'Generated {len(target_images)} images/silhouettes/cameras.') +target_cameras, target_images, target_silhouettes = generate_cow_renders( + num_views=40, azimuth_range=180 +) +print(f"Generated {len(target_images)} images/silhouettes/cameras.") ############################################################################### # Intitialize the implicit rendered @@ -84,10 +79,10 @@ raysampler_grid = NDCGridRaysampler( # MonteCarloRaysampler generates a random subset # of `n_rays_per_image` rays emitted from the image plane. raysampler_mc = MonteCarloRaysampler( - min_x = -1.0, - max_x = 1.0, - min_y = -1.0, - max_y = 1.0, + min_x=-1.0, + max_x=1.0, + min_y=-1.0, + max_y=1.0, n_rays_per_image=750, n_pts_per_ray=128, min_depth=0.1, @@ -104,16 +99,19 @@ raymarcher = EmissionAbsorptionRaymarcher() # Finally, instantiate the implicit renders # for both raysamplers. renderer_grid = ImplicitRenderer( - raysampler=raysampler_grid, raymarcher=raymarcher, + raysampler=raysampler_grid, + raymarcher=raymarcher, ) renderer_mc = ImplicitRenderer( - raysampler=raysampler_mc, raymarcher=raymarcher, + raysampler=raysampler_mc, + raymarcher=raymarcher, ) ############################################################################### # Define the NeRF model ############################################################################### + class HarmonicEmbedding(torch.nn.Module): def __init__(self, n_harmonic_functions=60, omega0=0.1): """ @@ -133,15 +131,16 @@ class HarmonicEmbedding(torch.nn.Module): ... cos(2**self.n_harmonic_functions * x[..., i]) ] - + Note that `x` is also premultiplied by `omega0` before evaluting the harmonic functions. """ super().__init__() self.register_buffer( - 'frequencies', + "frequencies", omega0 * (2.0 ** torch.arange(n_harmonic_functions)), ) + def forward(self, x): """ Args: @@ -164,15 +163,15 @@ class NeuralRadianceField(torch.nn.Module): n_hidden_neurons: The number of hidden units in the fully connected layers of the MLPs of the model. """ - + # The harmonic embedding layer converts input 3D coordinates # to a representation that is more suitable for # processing with a deep neural network. self.harmonic_embedding = HarmonicEmbedding(n_harmonic_functions) - + # The dimension of the harmonic embedding. embedding_dim = n_harmonic_functions * 2 * 3 - + # self.mlp is a simple 2-layer multi-layer perceptron # which converts the input per-point harmonic embeddings # to a latent representation. @@ -182,8 +181,8 @@ class NeuralRadianceField(torch.nn.Module): torch.nn.Softplus(beta=10.0), torch.nn.Linear(n_hidden_neurons, n_hidden_neurons), torch.nn.Softplus(beta=10.0), - ) - + ) + # Given features predicted by self.mlp, self.color_layer # is responsible for predicting a 3-D per-point vector # that represents the RGB color of the point. @@ -192,89 +191,80 @@ class NeuralRadianceField(torch.nn.Module): torch.nn.Softplus(beta=10.0), torch.nn.Linear(n_hidden_neurons, 3), torch.nn.Sigmoid(), - # To ensure that the colors correctly range between [0-1], - # the layer is terminated with a sigmoid layer. - ) - - # The density layer converts the features of self.mlp - # to a 1D density value representing the raw opacity - # of each point. + # To ensure that the colors correctly range between [0-1], the + # layer is terminated with a sigmoid layer. + ) + + # The density layer converts the features of self.mlp to a 1D density + # value representing the raw opacity of each point. self.density_layer = torch.nn.Sequential( torch.nn.Linear(n_hidden_neurons, 1), torch.nn.Softplus(beta=10.0), # Sofplus activation ensures that the raw opacity # is a non-negative number. ) - - # We set the bias of the density layer to -1.5 - # in order to initialize the opacities of the - # ray points to values close to 0. - # This is a crucial detail for ensuring convergence - # of the model. - self.density_layer[0].bias.data[0] = -1.5 - + + # We set the bias of the density layer to -1.5 in order to initialize + # the opacities of the ray points to values close to 0. This is a + # crucial detail for ensuring convergence of the model. + self.density_layer[0].bias.data[0] = -1.5 + def _get_densities(self, features): """ - This function takes `features` predicted by `self.mlp` - and converts them to `raw_densities` with `self.density_layer`. - `raw_densities` are later mapped to [0-1] range with - 1 - inverse exponential of `raw_densities`. + This function takes `features` predicted by `self.mlp` and converts + them to `raw_densities` with `self.density_layer`. `raw_densities` are + later mapped to [0-1] range with 1 - inverse exponential of + `raw_densities`. """ raw_densities = self.density_layer(features) return 1 - (-raw_densities).exp() - + def _get_colors(self, features, rays_directions): """ - This function takes per-point `features` predicted by `self.mlp` - and evaluates the color model in order to attach to each - point a 3D vector of its RGB color. - - In order to represent viewpoint dependent effects, - before evaluating `self.color_layer`, `NeuralRadianceField` - concatenates to the `features` a harmonic embedding - of `ray_directions`, which are per-point directions - of point rays expressed as 3D l2-normalized vectors - in world coordinates. + This function takes per-point `features` predicted by `self.mlp` and + evaluates the color model in order to attach to each point a 3D vector + of its RGB color. + + In order to represent viewpoint dependent effects, before evaluating + `self.color_layer`, `NeuralRadianceField` concatenates to the + `features` a harmonic embedding of `ray_directions`, which are + per-point directions of point rays expressed as 3D l2-normalized + vectors in world coordinates. """ spatial_size = features.shape[:-1] - + # Normalize the ray_directions to unit l2 norm. rays_directions_normed = torch.nn.functional.normalize( rays_directions, dim=-1 ) - + # Obtain the harmonic embedding of the normalized ray directions. - rays_embedding = self.harmonic_embedding( - rays_directions_normed - ) - + rays_embedding = self.harmonic_embedding(rays_directions_normed) + # Expand the ray directions tensor so that its spatial size # is equal to the size of features. rays_embedding_expand = rays_embedding[..., None, :].expand( *spatial_size, rays_embedding.shape[-1] ) - - # Concatenate ray direction embeddings with + + # Concatenate ray direction embeddings with # features and evaluate the color model. color_layer_input = torch.cat( - (features, rays_embedding_expand), - dim=-1 + (features, rays_embedding_expand), dim=-1 ) return self.color_layer(color_layer_input) - - + def forward( - self, + self, ray_bundle: RayBundle, **kwargs, ): """ - The forward function accepts the parametrizations of - 3D points sampled along projection rays. The forward - pass is responsible for attaching a 3D vector - and a 1D scalar representing the point's - RGB color and opacity respectively. - + The forward function accepts the parametrizations of 3D points sampled + along projection rays. The forward pass is responsible for attaching a + 3D vector and a 1D scalar representing the point's RGB color and + opacity respectively. + Args: ray_bundle: A RayBundle object containing the following variables: origins: A tensor of shape `(minibatch, ..., 3)` denoting the @@ -294,44 +284,42 @@ class NeuralRadianceField(torch.nn.Module): # coordinates with `ray_bundle_to_ray_points`. rays_points_world = ray_bundle_to_ray_points(ray_bundle) # rays_points_world.shape = [minibatch x ... x 3] - + # For each 3D world coordinate, we obtain its harmonic embedding. - embeds = self.harmonic_embedding( - rays_points_world - ) + embeds = self.harmonic_embedding(rays_points_world) # embeds.shape = [minibatch x ... x self.n_harmonic_functions*6] - + # self.mlp maps each harmonic embedding to a latent feature space. features = self.mlp(embeds) # features.shape = [minibatch x ... x n_hidden_neurons] - - # Finally, given the per-point features, + + # Finally, given the per-point features, # execute the density and color branches. - + rays_densities = self._get_densities(features) # rays_densities.shape = [minibatch x ... x 1] rays_colors = self._get_colors(features, ray_bundle.directions) # rays_colors.shape = [minibatch x ... x 3] - + return rays_densities, rays_colors - + def batched_forward( - self, + self, ray_bundle: RayBundle, n_batches: int = 16, - **kwargs, + **kwargs, ): """ - This function is used to allow for memory efficient processing - of input rays. The input rays are first split to `n_batches` - chunks and passed through the `self.forward` function one at a time - in a for loop. Combined with disabling Pytorch gradient caching - (`torch.no_grad()`), this allows for rendering large batches - of rays that do not all fit into GPU memory in a single forward pass. - In our case, batched_forward is used to export a fully-sized render - of the radiance field for visualisation purposes. - + This function is used to allow for memory efficient processing of input + rays. The input rays are first split to `n_batches` chunks and passed + through the `self.forward` function one at a time in a for loop. + Combined with disabling Pytorch gradient caching (`torch.no_grad()`), + this allows for rendering large batches of rays that do not all fit + into GPU memory in a single forward pass. In our case, batched_forward + is used to export a fully-sized render of the radiance field for + visualisation purposes. + Args: ray_bundle: A RayBundle object containing the following variables: origins: A tensor of shape `(minibatch, ..., 3)` denoting the @@ -353,7 +341,7 @@ class NeuralRadianceField(torch.nn.Module): """ # Parse out shapes needed for tensor reshaping in this function. - n_pts_per_ray = ray_bundle.lengths.shape[-1] + n_pts_per_ray = ray_bundle.lengths.shape[-1] spatial_size = [*ray_bundle.origins.shape[:-1], n_pts_per_ray] # Split the rays to `n_batches` batches. @@ -366,34 +354,44 @@ class NeuralRadianceField(torch.nn.Module): RayBundle( origins=ray_bundle.origins.view(-1, 3)[batch_idx], directions=ray_bundle.directions.view(-1, 3)[batch_idx], - lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[batch_idx], + lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[ + batch_idx + ], xys=None, ) - ) for batch_idx in batches + ) + for batch_idx in batches ] - + # Concatenate the per-batch rays_densities and rays_colors # and reshape according to the sizes of the inputs. rays_densities, rays_colors = [ torch.cat( - [batch_output[output_i] for batch_output in batch_outputs], dim=0 - ).view(*spatial_size, -1) for output_i in (0, 1) + [batch_output[output_i] for batch_output in batch_outputs], + dim=0, + ).view(*spatial_size, -1) + for output_i in (0, 1) ] return rays_densities, rays_colors + ############################################################################### # Helper functions ############################################################################### + def huber(x, y, scaling=0.1): """ A helper function for evaluating the smooth L1 (huber) loss between the rendered silhouettes and colors. """ diff_sq = (x - y) ** 2 - loss = ((1 + diff_sq / (scaling**2)).clamp(1e-4).sqrt() - 1) * float(scaling) + loss = ((1 + diff_sq / (scaling ** 2)).clamp(1e-4).sqrt() - 1) * float( + scaling + ) return loss + def sample_images_at_mc_locs(target_images, sampled_rays_xy): """ Given a set of Monte Carlo pixel locations `sampled_rays_xy`, @@ -407,39 +405,36 @@ def sample_images_at_mc_locs(target_images, sampled_rays_xy): ba = target_images.shape[0] dim = target_images.shape[-1] spatial_size = sampled_rays_xy.shape[1:-1] - # In order to sample target_images, we utilize - # the grid_sample function which implements a - # bilinear image sampler. - # Note that we have to invert the sign of the - # sampled ray positions to convert the NDC xy locations - # of the MonteCarloRaysampler to the coordinate - # convention of grid_sample. + # In order to sample target_images, we utilize the grid_sample function + # which implements a bilinear image sampler. Note that we have to invert + # the sign of the sampled ray positions to convert the NDC xy locations of + # the MonteCarloRaysampler to the coordinate convention of grid_sample. images_sampled = torch.nn.functional.grid_sample( target_images.permute(0, 3, 1, 2), -sampled_rays_xy.view(ba, -1, 1, 2), # note the sign inversion - align_corners=True - ) - return images_sampled.permute(0, 2, 3, 1).view( - ba, *spatial_size, dim + align_corners=True, ) + return images_sampled.permute(0, 2, 3, 1).view(ba, *spatial_size, dim) + def show_full_render( - neural_radiance_field, camera, - target_image, target_silhouette, - loss_history_color, loss_history_sil, + neural_radiance_field, + camera, + target_image, + target_silhouette, + loss_history_color, + loss_history_sil, ): """ - This is a helper function for visualizing the - intermediate results of the learning. - - Since the `NeuralRadianceField` suffers from - a large memory footprint, which does not allow to - render the full image grid in a single forward pass, - we utilize the `NeuralRadianceField.batched_forward` - function in combination with disabling the gradient caching. - This chunks the set of emitted rays to batches and - evaluates the implicit function on one-batch at a time - to prevent GPU memory overflow. + This is a helper function for visualizing the intermediate results of the + learning. + + Since the `NeuralRadianceField` suffers from a large memory footprint, + which does not allow to render the full image grid in a single forward + pass, we utilize the `NeuralRadianceField.batched_forward` function in + combination with disabling the gradient caching. This chunks the set of + emitted rays to batches and evaluates the implicit function on one-batch at + a time to prevent GPU memory overflow. """ # Prevent gradient caching. @@ -448,43 +443,52 @@ def show_full_render( # batched_forward function of neural_radiance_field. rendered_image_silhouette, _ = renderer_grid( cameras=camera, - volumetric_function=neural_radiance_field.batched_forward + volumetric_function=neural_radiance_field.batched_forward, ) # Split the rendering result to a silhouette render # and the image render. - rendered_image, rendered_silhouette = ( - rendered_image_silhouette[0].split([3, 1], dim=-1) - ) + rendered_image, rendered_silhouette = rendered_image_silhouette[ + 0 + ].split([3, 1], dim=-1) # Generate plots. fig, ax = plt.subplots(2, 3, figsize=(15, 10)) ax = ax.ravel() clamp_and_detach = lambda x: x.clamp(0.0, 1.0).cpu().detach().numpy() - ax[0].plot(list(range(len(loss_history_color))), loss_history_color, linewidth=1) + ax[0].plot( + list(range(len(loss_history_color))), loss_history_color, linewidth=1 + ) ax[1].imshow(clamp_and_detach(rendered_image)) ax[2].imshow(clamp_and_detach(rendered_silhouette[..., 0])) - ax[3].plot(list(range(len(loss_history_sil))), loss_history_sil, linewidth=1) + ax[3].plot( + list(range(len(loss_history_sil))), loss_history_sil, linewidth=1 + ) ax[4].imshow(clamp_and_detach(target_image)) ax[5].imshow(clamp_and_detach(target_silhouette)) for ax_, title_ in zip( ax, ( - "loss color", "rendered image", "rendered silhouette", - "loss silhouette", "target image", "target silhouette", - ) + "loss color", + "rendered image", + "rendered silhouette", + "loss silhouette", + "target image", + "target silhouette", + ), ): - if not title_.startswith('loss'): + if not title_.startswith("loss"): ax_.grid("off") ax_.axis("off") ax_.set_title(title_) - fig.canvas.draw(); fig.show() + fig.canvas.draw() + fig.show() display.clear_output(wait=True) display.display(fig) return fig ############################################################################### -# Fit the radiance field +# Fit the radiance field ############################################################################### # First move all relevant variables to the correct device. @@ -521,7 +525,7 @@ for iteration in range(n_iter): # In case we reached the last 75% of iterations, # decrease the learning rate of the optimizer 10-fold. if iteration == round(n_iter * 0.75): - print('Decreasing LR 10-fold ...') + print("Decreasing LR 10-fold ...") optimizer = torch.optim.Adam( neural_radiance_field.parameters(), lr=lr * 0.1 ) @@ -534,47 +538,52 @@ for iteration in range(n_iter): # Sample the minibatch of cameras. batch_cameras = FoVPerspectiveCameras( - R = target_cameras.R[batch_idx], - T = target_cameras.T[batch_idx], - znear = target_cameras.znear[batch_idx], - zfar = target_cameras.zfar[batch_idx], - aspect_ratio = target_cameras.aspect_ratio[batch_idx], - fov = target_cameras.fov[batch_idx], - device = device, + R=target_cameras.R[batch_idx], + T=target_cameras.T[batch_idx], + znear=target_cameras.znear[batch_idx], + zfar=target_cameras.zfar[batch_idx], + aspect_ratio=target_cameras.aspect_ratio[batch_idx], + fov=target_cameras.fov[batch_idx], + device=device, ) # Evaluate the nerf model. rendered_images_silhouettes, sampled_rays = renderer_mc( - cameras=batch_cameras, - volumetric_function=neural_radiance_field + cameras=batch_cameras, volumetric_function=neural_radiance_field ) - rendered_images, rendered_silhouettes = ( - rendered_images_silhouettes.split([3, 1], dim=-1) + rendered_images, rendered_silhouettes = rendered_images_silhouettes.split( + [3, 1], dim=-1 ) # Compute the silhoutte error as the mean huber # loss between the predicted masks and the # sampled target silhouettes. silhouettes_at_rays = sample_images_at_mc_locs( - target_silhouettes[batch_idx, ..., None], - sampled_rays.xys + target_silhouettes[batch_idx, ..., None], sampled_rays.xys + ) + sil_err = ( + huber( + rendered_silhouettes, + silhouettes_at_rays, + ) + .abs() + .mean() ) - sil_err = huber( - rendered_silhouettes, - silhouettes_at_rays, - ).abs().mean() # Compute the color error as the mean huber # loss between the rendered colors and the # sampled target images. colors_at_rays = sample_images_at_mc_locs( - target_images[batch_idx], - sampled_rays.xys + target_images[batch_idx], sampled_rays.xys + ) + color_err = ( + huber( + rendered_images, + colors_at_rays, + ) + .abs() + .mean() ) - color_err = huber( - rendered_images, - colors_at_rays, - ).abs().mean() # The optimization loss is a simple # sum of the color and silhouette errors. @@ -587,9 +596,9 @@ for iteration in range(n_iter): # Every 10 iterations, print the current values of the losses. if iteration % 10 == 0: print( - f'Iteration {iteration:05d}:' - + f' loss color = {float(color_err):1.2e}' - + f' loss silhouette = {float(sil_err):1.2e}' + f"Iteration {iteration:05d}:" + + f" loss color = {float(color_err):1.2e}" + + f" loss silhouette = {float(sil_err):1.2e}" ) # Take the optimization step. @@ -602,13 +611,13 @@ for iteration in range(n_iter): show_full_render( neural_radiance_field, FoVPerspectiveCameras( - R = target_cameras.R[show_idx], - T = target_cameras.T[show_idx], - znear = target_cameras.znear[show_idx], - zfar = target_cameras.zfar[show_idx], - aspect_ratio = target_cameras.aspect_ratio[show_idx], - fov = target_cameras.fov[show_idx], - device = device, + R=target_cameras.R[show_idx], + T=target_cameras.T[show_idx], + znear=target_cameras.znear[show_idx], + zfar=target_cameras.zfar[show_idx], + aspect_ratio=target_cameras.aspect_ratio[show_idx], + fov=target_cameras.fov[show_idx], + device=device, ), target_images[show_idx][0], target_silhouettes[show_idx][0], @@ -617,17 +626,18 @@ for iteration in range(n_iter): ) ############################################################################### -# Visualizing the optimized neural radiance field +# Visualizing the optimized neural radiance field ############################################################################### -def generate_rotating_nerf(neural_radiance_field, n_frames = 50): + +def generate_rotating_nerf(neural_radiance_field, n_frames=50): logRs = torch.zeros(n_frames, 3, device=device) logRs[:, 1] = torch.linspace(-3.14, 3.14, n_frames, device=device) Rs = so3_exponential_map(logRs) Ts = torch.zeros(n_frames, 3, device=device) Ts[:, 2] = 2.7 frames = [] - print('Rendering rotating NeRF ...') + print("Rendering rotating NeRF ...") for R, T in zip(tqdm(Rs), Ts): camera = FoVPerspectiveCameras( R=R[None], @@ -648,8 +658,18 @@ def generate_rotating_nerf(neural_radiance_field, n_frames = 50): ) return torch.cat(frames) + with torch.no_grad(): - rotating_nerf_frames = generate_rotating_nerf(neural_radiance_field, n_frames=3*5) + rotating_nerf_frames = generate_rotating_nerf( + neural_radiance_field, n_frames=3 * 5 + ) + +image_grid( + rotating_nerf_frames.clamp(0.0, 1.0).cpu().numpy(), + rows=3, + cols=5, + rgb=True, + fill=True, +) -image_grid(rotating_nerf_frames.clamp(0., 1.).cpu().numpy(), rows=3, cols=5, rgb=True, fill=True) plt.show()