black and isort pass + some extra reshaping

otthorn 2021-04-13 11:30:31 +02:00
parent 42678ac5bb
commit 1ab064da7e

main.py
--- a/main.py
+++ b/main.py

@@ -7,49 +7,44 @@
 #
 # Copyright (c) 2021 Solal Nathan
 # Author: Solal "Otthorn" Nathan <otthorn@crans.org>
 # SPDX-License-Identifier: BSD-3-Clause
 #
 ###############################################################################

+import glob
+import json
+import math
 import os
 import sys
 import time
-import json
-import glob
-import torch
-import math
+
 import matplotlib.pyplot as plt
 import numpy as np
+import torch
 from PIL import Image
-from tqdm import tqdm
+from pytorch3d.renderer import (EmissionAbsorptionRaymarcher,
+                                FoVPerspectiveCameras, ImplicitRenderer,
+                                MonteCarloRaysampler, NDCGridRaysampler,
+                                RayBundle, ray_bundle_to_ray_points)
 # Data structures and functions for rendering
 from pytorch3d.structures import Volumes
 from pytorch3d.transforms import so3_exponential_map
-from pytorch3d.renderer import (
-    FoVPerspectiveCameras,
-    NDCGridRaysampler,
-    MonteCarloRaysampler,
-    EmissionAbsorptionRaymarcher,
-    ImplicitRenderer,
-    RayBundle,
-    ray_bundle_to_ray_points,
-)
+from tqdm import tqdm

 # add path for demo utils functions
-sys.path.append(os.path.abspath(''))
-from utils.plot_image_grid import image_grid
+sys.path.append(os.path.abspath(""))
 from utils.generate_cow_renders import generate_cow_renders
+from utils.plot_image_grid import image_grid

 # Intialize CUDA gpu
 device = torch.device("cuda:0")
 torch.cuda.set_device(device)

 # Generate dataset
-target_cameras, target_images, target_silhouettes = \
-        generate_cow_renders(num_views=40, azimuth_range=180)
-print(f'Generated {len(target_images)} images/silhouettes/cameras.')
+target_cameras, target_images, target_silhouettes = generate_cow_renders(
+    num_views=40, azimuth_range=180
+)
+print(f"Generated {len(target_images)} images/silhouettes/cameras.")

 ###############################################################################
 # Intitialize the implicit rendered
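
(Aside, not part of the diff: the reshaped `generate_cow_renders(...)` call above is purely cosmetic. A minimal sanity check of what it returns could look like the sketch below; the dataset size follows from `num_views=40`, while the exact image resolution is an assumption.)

# Illustrative sketch only -- relies on the repository's utils/ helpers.
target_cameras, target_images, target_silhouettes = generate_cow_renders(
    num_views=40, azimuth_range=180
)
assert len(target_cameras) == 40
assert target_images.shape[0] == target_silhouettes.shape[0] == 40
# target_images holds RGB renders (40, H, W, 3); target_silhouettes holds
# the matching alpha masks (40, H, W).
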
@@ -84,10 +79,10 @@ raysampler_grid = NDCGridRaysampler(
 # MonteCarloRaysampler generates a random subset
 # of `n_rays_per_image` rays emitted from the image plane.
 raysampler_mc = MonteCarloRaysampler(
-    min_x = -1.0,
-    max_x = 1.0,
-    min_y = -1.0,
-    max_y = 1.0,
+    min_x=-1.0,
+    max_x=1.0,
+    min_y=-1.0,
+    max_y=1.0,
     n_rays_per_image=750,
     n_pts_per_ray=128,
     min_depth=0.1,
@@ -104,16 +99,19 @@ raymarcher = EmissionAbsorptionRaymarcher()
 # Finally, instantiate the implicit renders
 # for both raysamplers.
 renderer_grid = ImplicitRenderer(
-    raysampler=raysampler_grid, raymarcher=raymarcher,
+    raysampler=raysampler_grid,
+    raymarcher=raymarcher,
 )
 renderer_mc = ImplicitRenderer(
-    raysampler=raysampler_mc, raymarcher=raymarcher,
+    raysampler=raysampler_mc,
+    raymarcher=raymarcher,
 )

 ###############################################################################
 # Define the NeRF model
 ###############################################################################

+
 class HarmonicEmbedding(torch.nn.Module):
     def __init__(self, n_harmonic_functions=60, omega0=0.1):
         """
@@ -133,15 +131,16 @@ class HarmonicEmbedding(torch.nn.Module):
                 ...
                 cos(2**self.n_harmonic_functions * x[..., i])
             ]

         Note that `x` is also premultiplied by `omega0` before
         evaluting the harmonic functions.
         """
         super().__init__()
         self.register_buffer(
-            'frequencies',
+            "frequencies",
             omega0 * (2.0 ** torch.arange(n_harmonic_functions)),
         )

     def forward(self, x):
         """
         Args:
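
(Aside, illustrative only: the `frequencies` buffer registered above drives the sin/cos expansion described in the docstring. One standalone way to write that expansion, consistent with the `embedding_dim = n_harmonic_functions * 2 * 3` used later in the model, is sketched below; the tensor sizes are made up for the example.)

import torch

n_harmonic_functions, omega0 = 60, 0.1
frequencies = omega0 * (2.0 ** torch.arange(n_harmonic_functions))

x = torch.randn(2, 1024, 3)  # a toy batch of 3D points
embed = (x[..., None] * frequencies).view(*x.shape[:-1], -1)
harmonic = torch.cat((embed.sin(), embed.cos()), dim=-1)
assert harmonic.shape == (2, 1024, n_harmonic_functions * 2 * 3)
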
@@ -164,15 +163,15 @@ class NeuralRadianceField(torch.nn.Module):
             n_hidden_neurons: The number of hidden units in the
                 fully connected layers of the MLPs of the model.
         """
         super().__init__()

         # The harmonic embedding layer converts input 3D coordinates
         # to a representation that is more suitable for
         # processing with a deep neural network.
         self.harmonic_embedding = HarmonicEmbedding(n_harmonic_functions)
         # The dimension of the harmonic embedding.
         embedding_dim = n_harmonic_functions * 2 * 3

         # self.mlp is a simple 2-layer multi-layer perceptron
         # which converts the input per-point harmonic embeddings
         # to a latent representation.
@@ -182,8 +181,8 @@ class NeuralRadianceField(torch.nn.Module):
             torch.nn.Softplus(beta=10.0),
             torch.nn.Linear(n_hidden_neurons, n_hidden_neurons),
             torch.nn.Softplus(beta=10.0),
         )

         # Given features predicted by self.mlp, self.color_layer
         # is responsible for predicting a 3-D per-point vector
         # that represents the RGB color of the point.
@@ -192,89 +191,80 @@ class NeuralRadianceField(torch.nn.Module):
             torch.nn.Softplus(beta=10.0),
             torch.nn.Linear(n_hidden_neurons, 3),
             torch.nn.Sigmoid(),
-            # To ensure that the colors correctly range between [0-1],
-            # the layer is terminated with a sigmoid layer.
+            # To ensure that the colors correctly range between [0-1], the
+            # layer is terminated with a sigmoid layer.
         )

-        # The density layer converts the features of self.mlp
-        # to a 1D density value representing the raw opacity
-        # of each point.
+        # The density layer converts the features of self.mlp to a 1D density
+        # value representing the raw opacity of each point.
         self.density_layer = torch.nn.Sequential(
             torch.nn.Linear(n_hidden_neurons, 1),
             torch.nn.Softplus(beta=10.0),
             # Sofplus activation ensures that the raw opacity
             # is a non-negative number.
         )

-        # We set the bias of the density layer to -1.5
-        # in order to initialize the opacities of the
-        # ray points to values close to 0.
-        # This is a crucial detail for ensuring convergence
-        # of the model.
+        # We set the bias of the density layer to -1.5 in order to initialize
+        # the opacities of the ray points to values close to 0. This is a
+        # crucial detail for ensuring convergence of the model.
         self.density_layer[0].bias.data[0] = -1.5

     def _get_densities(self, features):
         """
-        This function takes `features` predicted by `self.mlp`
-        and converts them to `raw_densities` with `self.density_layer`.
-        `raw_densities` are later mapped to [0-1] range with
-        1 - inverse exponential of `raw_densities`.
+        This function takes `features` predicted by `self.mlp` and converts
+        them to `raw_densities` with `self.density_layer`. `raw_densities` are
+        later mapped to [0-1] range with 1 - inverse exponential of
+        `raw_densities`.
         """
         raw_densities = self.density_layer(features)
         return 1 - (-raw_densities).exp()

     def _get_colors(self, features, rays_directions):
         """
-        This function takes per-point `features` predicted by `self.mlp`
-        and evaluates the color model in order to attach to each
-        point a 3D vector of its RGB color.
-        In order to represent viewpoint dependent effects,
-        before evaluating `self.color_layer`, `NeuralRadianceField`
-        concatenates to the `features` a harmonic embedding
-        of `ray_directions`, which are per-point directions
-        of point rays expressed as 3D l2-normalized vectors
-        in world coordinates.
+        This function takes per-point `features` predicted by `self.mlp` and
+        evaluates the color model in order to attach to each point a 3D vector
+        of its RGB color.
+        In order to represent viewpoint dependent effects, before evaluating
+        `self.color_layer`, `NeuralRadianceField` concatenates to the
+        `features` a harmonic embedding of `ray_directions`, which are
+        per-point directions of point rays expressed as 3D l2-normalized
+        vectors in world coordinates.
         """
         spatial_size = features.shape[:-1]

         # Normalize the ray_directions to unit l2 norm.
         rays_directions_normed = torch.nn.functional.normalize(
             rays_directions, dim=-1
         )

         # Obtain the harmonic embedding of the normalized ray directions.
-        rays_embedding = self.harmonic_embedding(
-            rays_directions_normed
-        )
+        rays_embedding = self.harmonic_embedding(rays_directions_normed)

         # Expand the ray directions tensor so that its spatial size
         # is equal to the size of features.
         rays_embedding_expand = rays_embedding[..., None, :].expand(
             *spatial_size, rays_embedding.shape[-1]
         )

         # Concatenate ray direction embeddings with
         # features and evaluate the color model.
         color_layer_input = torch.cat(
-            (features, rays_embedding_expand),
-            dim=-1
+            (features, rays_embedding_expand), dim=-1
         )
         return self.color_layer(color_layer_input)

     def forward(
         self,
         ray_bundle: RayBundle,
         **kwargs,
     ):
         """
-        The forward function accepts the parametrizations of
-        3D points sampled along projection rays. The forward
-        pass is responsible for attaching a 3D vector
-        and a 1D scalar representing the point's
-        RGB color and opacity respectively.
+        The forward function accepts the parametrizations of 3D points sampled
+        along projection rays. The forward pass is responsible for attaching a
+        3D vector and a 1D scalar representing the point's RGB color and
+        opacity respectively.

         Args:
             ray_bundle: A RayBundle object containing the following variables:
                 origins: A tensor of shape `(minibatch, ..., 3)` denoting the
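
(Aside, a quick numeric check of the -1.5 bias comment above, not part of the commit: with `Softplus(beta=10.0)` the raw density at initialization is tiny, so `1 - exp(-raw)` also starts near zero, i.e. the field begins almost fully transparent.)

import torch

raw = torch.nn.functional.softplus(torch.tensor(-1.5), beta=10.0)
opacity = 1 - (-raw).exp()
print(float(raw), float(opacity))  # both ~3e-08
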
@@ -294,44 +284,42 @@ class NeuralRadianceField(torch.nn.Module):
         # coordinates with `ray_bundle_to_ray_points`.
         rays_points_world = ray_bundle_to_ray_points(ray_bundle)
         # rays_points_world.shape = [minibatch x ... x 3]

         # For each 3D world coordinate, we obtain its harmonic embedding.
-        embeds = self.harmonic_embedding(
-            rays_points_world
-        )
+        embeds = self.harmonic_embedding(rays_points_world)
         # embeds.shape = [minibatch x ... x self.n_harmonic_functions*6]

         # self.mlp maps each harmonic embedding to a latent feature space.
         features = self.mlp(embeds)
         # features.shape = [minibatch x ... x n_hidden_neurons]

         # Finally, given the per-point features,
         # execute the density and color branches.
         rays_densities = self._get_densities(features)
         # rays_densities.shape = [minibatch x ... x 1]

         rays_colors = self._get_colors(features, ray_bundle.directions)
         # rays_colors.shape = [minibatch x ... x 3]

         return rays_densities, rays_colors

     def batched_forward(
         self,
         ray_bundle: RayBundle,
         n_batches: int = 16,
         **kwargs,
     ):
         """
-        This function is used to allow for memory efficient processing
-        of input rays. The input rays are first split to `n_batches`
-        chunks and passed through the `self.forward` function one at a time
-        in a for loop. Combined with disabling Pytorch gradient caching
-        (`torch.no_grad()`), this allows for rendering large batches
-        of rays that do not all fit into GPU memory in a single forward pass.
-        In our case, batched_forward is used to export a fully-sized render
-        of the radiance field for visualisation purposes.
+        This function is used to allow for memory efficient processing of input
+        rays. The input rays are first split to `n_batches` chunks and passed
+        through the `self.forward` function one at a time in a for loop.
+        Combined with disabling Pytorch gradient caching (`torch.no_grad()`),
+        this allows for rendering large batches of rays that do not all fit
+        into GPU memory in a single forward pass. In our case, batched_forward
+        is used to export a fully-sized render of the radiance field for
+        visualisation purposes.

         Args:
             ray_bundle: A RayBundle object containing the following variables:
                 origins: A tensor of shape `(minibatch, ..., 3)` denoting the
@@ -353,7 +341,7 @@ class NeuralRadianceField(torch.nn.Module):
         """
         # Parse out shapes needed for tensor reshaping in this function.
         n_pts_per_ray = ray_bundle.lengths.shape[-1]
         spatial_size = [*ray_bundle.origins.shape[:-1], n_pts_per_ray]

         # Split the rays to `n_batches` batches.
@@ -366,34 +354,44 @@ class NeuralRadianceField(torch.nn.Module):
                 RayBundle(
                     origins=ray_bundle.origins.view(-1, 3)[batch_idx],
                     directions=ray_bundle.directions.view(-1, 3)[batch_idx],
-                    lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[batch_idx],
+                    lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[
+                        batch_idx
+                    ],
                     xys=None,
                 )
-            ) for batch_idx in batches
+            )
+            for batch_idx in batches
         ]

         # Concatenate the per-batch rays_densities and rays_colors
         # and reshape according to the sizes of the inputs.
         rays_densities, rays_colors = [
             torch.cat(
-                [batch_output[output_i] for batch_output in batch_outputs], dim=0
-            ).view(*spatial_size, -1) for output_i in (0, 1)
+                [batch_output[output_i] for batch_output in batch_outputs],
+                dim=0,
+            ).view(*spatial_size, -1)
+            for output_i in (0, 1)
         ]
         return rays_densities, rays_colors


 ###############################################################################
 # Helper functions
 ###############################################################################
 def huber(x, y, scaling=0.1):
     """
     A helper function for evaluating the smooth L1 (huber) loss
     between the rendered silhouettes and colors.
     """
     diff_sq = (x - y) ** 2
-    loss = ((1 + diff_sq / (scaling**2)).clamp(1e-4).sqrt() - 1) * float(scaling)
+    loss = ((1 + diff_sq / (scaling ** 2)).clamp(1e-4).sqrt() - 1) * float(
+        scaling
+    )
     return loss


 def sample_images_at_mc_locs(target_images, sampled_rays_xy):
     """
     Given a set of Monte Carlo pixel locations `sampled_rays_xy`,
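
(Aside, not part of the diff: the `huber` helper above is roughly quadratic for residuals well below `scaling` and roughly `|diff| - scaling` beyond it. A tiny standalone check, restating the helper for convenience:)

import torch

def huber(x, y, scaling=0.1):
    diff_sq = (x - y) ** 2
    return ((1 + diff_sq / (scaling ** 2)).clamp(1e-4).sqrt() - 1) * float(scaling)

print(float(huber(torch.tensor(0.01), torch.tensor(0.0))))  # ~5.0e-04 ~= diff**2 / (2 * scaling)
print(float(huber(torch.tensor(1.00), torch.tensor(0.0))))  # ~0.905   ~= abs(diff) - scaling
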
@@ -407,39 +405,36 @@ def sample_images_at_mc_locs(target_images, sampled_rays_xy):
     ba = target_images.shape[0]
     dim = target_images.shape[-1]
     spatial_size = sampled_rays_xy.shape[1:-1]

-    # In order to sample target_images, we utilize
-    # the grid_sample function which implements a
-    # bilinear image sampler.
-    # Note that we have to invert the sign of the
-    # sampled ray positions to convert the NDC xy locations
-    # of the MonteCarloRaysampler to the coordinate
-    # convention of grid_sample.
+    # In order to sample target_images, we utilize the grid_sample function
+    # which implements a bilinear image sampler. Note that we have to invert
+    # the sign of the sampled ray positions to convert the NDC xy locations of
+    # the MonteCarloRaysampler to the coordinate convention of grid_sample.
     images_sampled = torch.nn.functional.grid_sample(
         target_images.permute(0, 3, 1, 2),
         -sampled_rays_xy.view(ba, -1, 1, 2),  # note the sign inversion
-        align_corners=True
+        align_corners=True,
     )
-    return images_sampled.permute(0, 2, 3, 1).view(
-        ba, *spatial_size, dim
-    )
+    return images_sampled.permute(0, 2, 3, 1).view(ba, *spatial_size, dim)


 def show_full_render(
-    neural_radiance_field, camera,
-    target_image, target_silhouette,
-    loss_history_color, loss_history_sil,
+    neural_radiance_field,
+    camera,
+    target_image,
+    target_silhouette,
+    loss_history_color,
+    loss_history_sil,
 ):
     """
-    This is a helper function for visualizing the
-    intermediate results of the learning.
-    Since the `NeuralRadianceField` suffers from
-    a large memory footprint, which does not allow to
-    render the full image grid in a single forward pass,
-    we utilize the `NeuralRadianceField.batched_forward`
-    function in combination with disabling the gradient caching.
-    This chunks the set of emitted rays to batches and
-    evaluates the implicit function on one-batch at a time
-    to prevent GPU memory overflow.
+    This is a helper function for visualizing the intermediate results of the
+    learning.
+    Since the `NeuralRadianceField` suffers from a large memory footprint,
+    which does not allow to render the full image grid in a single forward
+    pass, we utilize the `NeuralRadianceField.batched_forward` function in
+    combination with disabling the gradient caching. This chunks the set of
+    emitted rays to batches and evaluates the implicit function on one-batch at
+    a time to prevent GPU memory overflow.
     """

     # Prevent gradient caching.
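
(Aside on the sign inversion above, illustrative only: `grid_sample` addresses images with x/y in [-1, 1], where (-1, -1) picks the top-left corner when `align_corners=True`; that is the convention the negated NDC locations are being mapped into. A minimal sketch:)

import torch

img = torch.zeros(1, 1, 4, 4)
img[0, 0, 0, 0] = 1.0  # mark the top-left pixel

grid = torch.tensor([[[[-1.0, -1.0]]]])  # one sample at grid_sample's (-1, -1)
out = torch.nn.functional.grid_sample(img, grid, align_corners=True)
print(float(out))  # 1.0 -> (-1, -1) picked out the top-left pixel
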
@@ -448,43 +443,52 @@ def show_full_render(
     # batched_forward function of neural_radiance_field.
     rendered_image_silhouette, _ = renderer_grid(
         cameras=camera,
-        volumetric_function=neural_radiance_field.batched_forward
+        volumetric_function=neural_radiance_field.batched_forward,
     )

     # Split the rendering result to a silhouette render
     # and the image render.
-    rendered_image, rendered_silhouette = (
-        rendered_image_silhouette[0].split([3, 1], dim=-1)
-    )
+    rendered_image, rendered_silhouette = rendered_image_silhouette[
+        0
+    ].split([3, 1], dim=-1)

     # Generate plots.
     fig, ax = plt.subplots(2, 3, figsize=(15, 10))
     ax = ax.ravel()
     clamp_and_detach = lambda x: x.clamp(0.0, 1.0).cpu().detach().numpy()
-    ax[0].plot(list(range(len(loss_history_color))), loss_history_color, linewidth=1)
+    ax[0].plot(
+        list(range(len(loss_history_color))), loss_history_color, linewidth=1
+    )
     ax[1].imshow(clamp_and_detach(rendered_image))
     ax[2].imshow(clamp_and_detach(rendered_silhouette[..., 0]))
-    ax[3].plot(list(range(len(loss_history_sil))), loss_history_sil, linewidth=1)
+    ax[3].plot(
+        list(range(len(loss_history_sil))), loss_history_sil, linewidth=1
+    )
     ax[4].imshow(clamp_and_detach(target_image))
     ax[5].imshow(clamp_and_detach(target_silhouette))
     for ax_, title_ in zip(
         ax,
         (
-            "loss color", "rendered image", "rendered silhouette",
-            "loss silhouette", "target image", "target silhouette",
-        )
+            "loss color",
+            "rendered image",
+            "rendered silhouette",
+            "loss silhouette",
+            "target image",
+            "target silhouette",
+        ),
     ):
-        if not title_.startswith('loss'):
+        if not title_.startswith("loss"):
             ax_.grid("off")
             ax_.axis("off")
         ax_.set_title(title_)
-    fig.canvas.draw(); fig.show()
+    fig.canvas.draw()
+    fig.show()
     display.clear_output(wait=True)
     display.display(fig)
     return fig


 ###############################################################################
 # Fit the radiance field
 ###############################################################################

 # First move all relevant variables to the correct device.
@@ -521,7 +525,7 @@ for iteration in range(n_iter):
     # In case we reached the last 75% of iterations,
     # decrease the learning rate of the optimizer 10-fold.
     if iteration == round(n_iter * 0.75):
-        print('Decreasing LR 10-fold ...')
+        print("Decreasing LR 10-fold ...")
         optimizer = torch.optim.Adam(
             neural_radiance_field.parameters(), lr=lr * 0.1
         )
@@ -534,47 +538,52 @@ for iteration in range(n_iter):
     # Sample the minibatch of cameras.
     batch_cameras = FoVPerspectiveCameras(
-        R = target_cameras.R[batch_idx],
-        T = target_cameras.T[batch_idx],
-        znear = target_cameras.znear[batch_idx],
-        zfar = target_cameras.zfar[batch_idx],
-        aspect_ratio = target_cameras.aspect_ratio[batch_idx],
-        fov = target_cameras.fov[batch_idx],
-        device = device,
+        R=target_cameras.R[batch_idx],
+        T=target_cameras.T[batch_idx],
+        znear=target_cameras.znear[batch_idx],
+        zfar=target_cameras.zfar[batch_idx],
+        aspect_ratio=target_cameras.aspect_ratio[batch_idx],
+        fov=target_cameras.fov[batch_idx],
+        device=device,
     )

     # Evaluate the nerf model.
     rendered_images_silhouettes, sampled_rays = renderer_mc(
-        cameras=batch_cameras,
-        volumetric_function=neural_radiance_field
+        cameras=batch_cameras, volumetric_function=neural_radiance_field
     )
-    rendered_images, rendered_silhouettes = (
-        rendered_images_silhouettes.split([3, 1], dim=-1)
+    rendered_images, rendered_silhouettes = rendered_images_silhouettes.split(
+        [3, 1], dim=-1
     )

     # Compute the silhoutte error as the mean huber
     # loss between the predicted masks and the
     # sampled target silhouettes.
     silhouettes_at_rays = sample_images_at_mc_locs(
-        target_silhouettes[batch_idx, ..., None],
-        sampled_rays.xys
+        target_silhouettes[batch_idx, ..., None], sampled_rays.xys
     )
-    sil_err = huber(
-        rendered_silhouettes,
-        silhouettes_at_rays,
-    ).abs().mean()
+    sil_err = (
+        huber(
+            rendered_silhouettes,
+            silhouettes_at_rays,
+        )
+        .abs()
+        .mean()
+    )

     # Compute the color error as the mean huber
     # loss between the rendered colors and the
     # sampled target images.
     colors_at_rays = sample_images_at_mc_locs(
-        target_images[batch_idx],
-        sampled_rays.xys
+        target_images[batch_idx], sampled_rays.xys
     )
-    color_err = huber(
-        rendered_images,
-        colors_at_rays,
-    ).abs().mean()
+    color_err = (
+        huber(
+            rendered_images,
+            colors_at_rays,
+        )
+        .abs()
+        .mean()
+    )

     # The optimization loss is a simple
     # sum of the color and silhouette errors.
@@ -587,9 +596,9 @@ for iteration in range(n_iter):
     # Every 10 iterations, print the current values of the losses.
     if iteration % 10 == 0:
         print(
-            f'Iteration {iteration:05d}:'
-            + f' loss color = {float(color_err):1.2e}'
-            + f' loss silhouette = {float(sil_err):1.2e}'
+            f"Iteration {iteration:05d}:"
+            + f" loss color = {float(color_err):1.2e}"
+            + f" loss silhouette = {float(sil_err):1.2e}"
         )

     # Take the optimization step.
@@ -602,13 +611,13 @@ for iteration in range(n_iter):
         show_full_render(
             neural_radiance_field,
             FoVPerspectiveCameras(
-                R = target_cameras.R[show_idx],
-                T = target_cameras.T[show_idx],
-                znear = target_cameras.znear[show_idx],
-                zfar = target_cameras.zfar[show_idx],
-                aspect_ratio = target_cameras.aspect_ratio[show_idx],
-                fov = target_cameras.fov[show_idx],
-                device = device,
+                R=target_cameras.R[show_idx],
+                T=target_cameras.T[show_idx],
+                znear=target_cameras.znear[show_idx],
+                zfar=target_cameras.zfar[show_idx],
+                aspect_ratio=target_cameras.aspect_ratio[show_idx],
+                fov=target_cameras.fov[show_idx],
+                device=device,
             ),
             target_images[show_idx][0],
             target_silhouettes[show_idx][0],
@@ -617,17 +626,18 @@ for iteration in range(n_iter):
         )

 ###############################################################################
 # Visualizing the optimized neural radiance field
 ###############################################################################
-def generate_rotating_nerf(neural_radiance_field, n_frames = 50):
+
+def generate_rotating_nerf(neural_radiance_field, n_frames=50):
     logRs = torch.zeros(n_frames, 3, device=device)
     logRs[:, 1] = torch.linspace(-3.14, 3.14, n_frames, device=device)
     Rs = so3_exponential_map(logRs)
     Ts = torch.zeros(n_frames, 3, device=device)
     Ts[:, 2] = 2.7
     frames = []

-    print('Rendering rotating NeRF ...')
+    print("Rendering rotating NeRF ...")
     for R, T in zip(tqdm(Rs), Ts):
         camera = FoVPerspectiveCameras(
             R=R[None],
@@ -648,8 +658,18 @@ def generate_rotating_nerf(neural_radiance_field, n_frames = 50):
         )
     return torch.cat(frames)


 with torch.no_grad():
-    rotating_nerf_frames = generate_rotating_nerf(neural_radiance_field, n_frames=3*5)
+    rotating_nerf_frames = generate_rotating_nerf(
+        neural_radiance_field, n_frames=3 * 5
+    )

-image_grid(rotating_nerf_frames.clamp(0., 1.).cpu().numpy(), rows=3, cols=5, rgb=True, fill=True)
+image_grid(
+    rotating_nerf_frames.clamp(0.0, 1.0).cpu().numpy(),
+    rows=3,
+    cols=5,
+    rgb=True,
+    fill=True,
+)
 plt.show()
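
(Closing aside, illustrative only and not part of the commit: the rotating-camera trajectory in `generate_rotating_nerf` comes from `so3_exponential_map`, which turns axis-angle vectors into rotation matrices. For example, a log-rotation of magnitude pi about the y-axis maps to a half-turn:)

import torch
from pytorch3d.transforms import so3_exponential_map

logR = torch.tensor([[0.0, 3.14159265, 0.0]])  # ~pi about the y-axis
R = so3_exponential_map(logR)                   # shape (1, 3, 3)
print(R[0].round())  # approximately diag(-1, 1, -1), a 180 degree turn
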