☕ black and isort pass + some extra reshaping
This commit is contained in:
parent
42678ac5bb
commit
1ab064da7e
1 changed files with 208 additions and 188 deletions
396
main.py
396
main.py
|
@ -7,49 +7,44 @@
|
||||||
#
|
#
|
||||||
# Copyright (c) 2021 Solal Nathan
|
# Copyright (c) 2021 Solal Nathan
|
||||||
# Author: Solal "Otthorn" Nathan <otthorn@crans.org>
|
# Author: Solal "Otthorn" Nathan <otthorn@crans.org>
|
||||||
# SPDX-License-Identifier: BSD-3-Clause
|
# SPDX-License-Identifier: BSD-3-Clause
|
||||||
#
|
#
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import json
|
|
||||||
import glob
|
|
||||||
import torch
|
|
||||||
import math
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from pytorch3d.renderer import (EmissionAbsorptionRaymarcher,
|
||||||
|
FoVPerspectiveCameras, ImplicitRenderer,
|
||||||
|
MonteCarloRaysampler, NDCGridRaysampler,
|
||||||
|
RayBundle, ray_bundle_to_ray_points)
|
||||||
# Data structures and functions for rendering
|
# Data structures and functions for rendering
|
||||||
from pytorch3d.structures import Volumes
|
from pytorch3d.structures import Volumes
|
||||||
from pytorch3d.transforms import so3_exponential_map
|
from pytorch3d.transforms import so3_exponential_map
|
||||||
from pytorch3d.renderer import (
|
from tqdm import tqdm
|
||||||
FoVPerspectiveCameras,
|
|
||||||
NDCGridRaysampler,
|
|
||||||
MonteCarloRaysampler,
|
|
||||||
EmissionAbsorptionRaymarcher,
|
|
||||||
ImplicitRenderer,
|
|
||||||
RayBundle,
|
|
||||||
ray_bundle_to_ray_points,
|
|
||||||
)
|
|
||||||
|
|
||||||
# add path for demo utils functions
|
# add path for demo utils functions
|
||||||
sys.path.append(os.path.abspath(''))
|
sys.path.append(os.path.abspath(""))
|
||||||
from utils.plot_image_grid import image_grid
|
|
||||||
from utils.generate_cow_renders import generate_cow_renders
|
from utils.generate_cow_renders import generate_cow_renders
|
||||||
|
from utils.plot_image_grid import image_grid
|
||||||
|
|
||||||
# Intialize CUDA gpu
|
# Intialize CUDA gpu
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
|
|
||||||
# Generate dataset
|
# Generate dataset
|
||||||
target_cameras, target_images, target_silhouettes = \
|
target_cameras, target_images, target_silhouettes = generate_cow_renders(
|
||||||
generate_cow_renders(num_views=40, azimuth_range=180)
|
num_views=40, azimuth_range=180
|
||||||
print(f'Generated {len(target_images)} images/silhouettes/cameras.')
|
)
|
||||||
|
print(f"Generated {len(target_images)} images/silhouettes/cameras.")
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Intitialize the implicit rendered
|
# Intitialize the implicit rendered
|
||||||
|
@ -84,10 +79,10 @@ raysampler_grid = NDCGridRaysampler(
|
||||||
# MonteCarloRaysampler generates a random subset
|
# MonteCarloRaysampler generates a random subset
|
||||||
# of `n_rays_per_image` rays emitted from the image plane.
|
# of `n_rays_per_image` rays emitted from the image plane.
|
||||||
raysampler_mc = MonteCarloRaysampler(
|
raysampler_mc = MonteCarloRaysampler(
|
||||||
min_x = -1.0,
|
min_x=-1.0,
|
||||||
max_x = 1.0,
|
max_x=1.0,
|
||||||
min_y = -1.0,
|
min_y=-1.0,
|
||||||
max_y = 1.0,
|
max_y=1.0,
|
||||||
n_rays_per_image=750,
|
n_rays_per_image=750,
|
||||||
n_pts_per_ray=128,
|
n_pts_per_ray=128,
|
||||||
min_depth=0.1,
|
min_depth=0.1,
|
||||||
|
@ -104,16 +99,19 @@ raymarcher = EmissionAbsorptionRaymarcher()
|
||||||
# Finally, instantiate the implicit renders
|
# Finally, instantiate the implicit renders
|
||||||
# for both raysamplers.
|
# for both raysamplers.
|
||||||
renderer_grid = ImplicitRenderer(
|
renderer_grid = ImplicitRenderer(
|
||||||
raysampler=raysampler_grid, raymarcher=raymarcher,
|
raysampler=raysampler_grid,
|
||||||
|
raymarcher=raymarcher,
|
||||||
)
|
)
|
||||||
renderer_mc = ImplicitRenderer(
|
renderer_mc = ImplicitRenderer(
|
||||||
raysampler=raysampler_mc, raymarcher=raymarcher,
|
raysampler=raysampler_mc,
|
||||||
|
raymarcher=raymarcher,
|
||||||
)
|
)
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Define the NeRF model
|
# Define the NeRF model
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
class HarmonicEmbedding(torch.nn.Module):
|
class HarmonicEmbedding(torch.nn.Module):
|
||||||
def __init__(self, n_harmonic_functions=60, omega0=0.1):
|
def __init__(self, n_harmonic_functions=60, omega0=0.1):
|
||||||
"""
|
"""
|
||||||
|
@ -133,15 +131,16 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||||
...
|
...
|
||||||
cos(2**self.n_harmonic_functions * x[..., i])
|
cos(2**self.n_harmonic_functions * x[..., i])
|
||||||
]
|
]
|
||||||
|
|
||||||
Note that `x` is also premultiplied by `omega0` before
|
Note that `x` is also premultiplied by `omega0` before
|
||||||
evaluting the harmonic functions.
|
evaluting the harmonic functions.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
'frequencies',
|
"frequencies",
|
||||||
omega0 * (2.0 ** torch.arange(n_harmonic_functions)),
|
omega0 * (2.0 ** torch.arange(n_harmonic_functions)),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -164,15 +163,15 @@ class NeuralRadianceField(torch.nn.Module):
|
||||||
n_hidden_neurons: The number of hidden units in the
|
n_hidden_neurons: The number of hidden units in the
|
||||||
fully connected layers of the MLPs of the model.
|
fully connected layers of the MLPs of the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# The harmonic embedding layer converts input 3D coordinates
|
# The harmonic embedding layer converts input 3D coordinates
|
||||||
# to a representation that is more suitable for
|
# to a representation that is more suitable for
|
||||||
# processing with a deep neural network.
|
# processing with a deep neural network.
|
||||||
self.harmonic_embedding = HarmonicEmbedding(n_harmonic_functions)
|
self.harmonic_embedding = HarmonicEmbedding(n_harmonic_functions)
|
||||||
|
|
||||||
# The dimension of the harmonic embedding.
|
# The dimension of the harmonic embedding.
|
||||||
embedding_dim = n_harmonic_functions * 2 * 3
|
embedding_dim = n_harmonic_functions * 2 * 3
|
||||||
|
|
||||||
# self.mlp is a simple 2-layer multi-layer perceptron
|
# self.mlp is a simple 2-layer multi-layer perceptron
|
||||||
# which converts the input per-point harmonic embeddings
|
# which converts the input per-point harmonic embeddings
|
||||||
# to a latent representation.
|
# to a latent representation.
|
||||||
|
@ -182,8 +181,8 @@ class NeuralRadianceField(torch.nn.Module):
|
||||||
torch.nn.Softplus(beta=10.0),
|
torch.nn.Softplus(beta=10.0),
|
||||||
torch.nn.Linear(n_hidden_neurons, n_hidden_neurons),
|
torch.nn.Linear(n_hidden_neurons, n_hidden_neurons),
|
||||||
torch.nn.Softplus(beta=10.0),
|
torch.nn.Softplus(beta=10.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Given features predicted by self.mlp, self.color_layer
|
# Given features predicted by self.mlp, self.color_layer
|
||||||
# is responsible for predicting a 3-D per-point vector
|
# is responsible for predicting a 3-D per-point vector
|
||||||
# that represents the RGB color of the point.
|
# that represents the RGB color of the point.
|
||||||
|
@ -192,89 +191,80 @@ class NeuralRadianceField(torch.nn.Module):
|
||||||
torch.nn.Softplus(beta=10.0),
|
torch.nn.Softplus(beta=10.0),
|
||||||
torch.nn.Linear(n_hidden_neurons, 3),
|
torch.nn.Linear(n_hidden_neurons, 3),
|
||||||
torch.nn.Sigmoid(),
|
torch.nn.Sigmoid(),
|
||||||
# To ensure that the colors correctly range between [0-1],
|
# To ensure that the colors correctly range between [0-1], the
|
||||||
# the layer is terminated with a sigmoid layer.
|
# layer is terminated with a sigmoid layer.
|
||||||
)
|
)
|
||||||
|
|
||||||
# The density layer converts the features of self.mlp
|
# The density layer converts the features of self.mlp to a 1D density
|
||||||
# to a 1D density value representing the raw opacity
|
# value representing the raw opacity of each point.
|
||||||
# of each point.
|
|
||||||
self.density_layer = torch.nn.Sequential(
|
self.density_layer = torch.nn.Sequential(
|
||||||
torch.nn.Linear(n_hidden_neurons, 1),
|
torch.nn.Linear(n_hidden_neurons, 1),
|
||||||
torch.nn.Softplus(beta=10.0),
|
torch.nn.Softplus(beta=10.0),
|
||||||
# Sofplus activation ensures that the raw opacity
|
# Sofplus activation ensures that the raw opacity
|
||||||
# is a non-negative number.
|
# is a non-negative number.
|
||||||
)
|
)
|
||||||
|
|
||||||
# We set the bias of the density layer to -1.5
|
# We set the bias of the density layer to -1.5 in order to initialize
|
||||||
# in order to initialize the opacities of the
|
# the opacities of the ray points to values close to 0. This is a
|
||||||
# ray points to values close to 0.
|
# crucial detail for ensuring convergence of the model.
|
||||||
# This is a crucial detail for ensuring convergence
|
self.density_layer[0].bias.data[0] = -1.5
|
||||||
# of the model.
|
|
||||||
self.density_layer[0].bias.data[0] = -1.5
|
|
||||||
|
|
||||||
def _get_densities(self, features):
|
def _get_densities(self, features):
|
||||||
"""
|
"""
|
||||||
This function takes `features` predicted by `self.mlp`
|
This function takes `features` predicted by `self.mlp` and converts
|
||||||
and converts them to `raw_densities` with `self.density_layer`.
|
them to `raw_densities` with `self.density_layer`. `raw_densities` are
|
||||||
`raw_densities` are later mapped to [0-1] range with
|
later mapped to [0-1] range with 1 - inverse exponential of
|
||||||
1 - inverse exponential of `raw_densities`.
|
`raw_densities`.
|
||||||
"""
|
"""
|
||||||
raw_densities = self.density_layer(features)
|
raw_densities = self.density_layer(features)
|
||||||
return 1 - (-raw_densities).exp()
|
return 1 - (-raw_densities).exp()
|
||||||
|
|
||||||
def _get_colors(self, features, rays_directions):
|
def _get_colors(self, features, rays_directions):
|
||||||
"""
|
"""
|
||||||
This function takes per-point `features` predicted by `self.mlp`
|
This function takes per-point `features` predicted by `self.mlp` and
|
||||||
and evaluates the color model in order to attach to each
|
evaluates the color model in order to attach to each point a 3D vector
|
||||||
point a 3D vector of its RGB color.
|
of its RGB color.
|
||||||
|
|
||||||
In order to represent viewpoint dependent effects,
|
In order to represent viewpoint dependent effects, before evaluating
|
||||||
before evaluating `self.color_layer`, `NeuralRadianceField`
|
`self.color_layer`, `NeuralRadianceField` concatenates to the
|
||||||
concatenates to the `features` a harmonic embedding
|
`features` a harmonic embedding of `ray_directions`, which are
|
||||||
of `ray_directions`, which are per-point directions
|
per-point directions of point rays expressed as 3D l2-normalized
|
||||||
of point rays expressed as 3D l2-normalized vectors
|
vectors in world coordinates.
|
||||||
in world coordinates.
|
|
||||||
"""
|
"""
|
||||||
spatial_size = features.shape[:-1]
|
spatial_size = features.shape[:-1]
|
||||||
|
|
||||||
# Normalize the ray_directions to unit l2 norm.
|
# Normalize the ray_directions to unit l2 norm.
|
||||||
rays_directions_normed = torch.nn.functional.normalize(
|
rays_directions_normed = torch.nn.functional.normalize(
|
||||||
rays_directions, dim=-1
|
rays_directions, dim=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Obtain the harmonic embedding of the normalized ray directions.
|
# Obtain the harmonic embedding of the normalized ray directions.
|
||||||
rays_embedding = self.harmonic_embedding(
|
rays_embedding = self.harmonic_embedding(rays_directions_normed)
|
||||||
rays_directions_normed
|
|
||||||
)
|
|
||||||
|
|
||||||
# Expand the ray directions tensor so that its spatial size
|
# Expand the ray directions tensor so that its spatial size
|
||||||
# is equal to the size of features.
|
# is equal to the size of features.
|
||||||
rays_embedding_expand = rays_embedding[..., None, :].expand(
|
rays_embedding_expand = rays_embedding[..., None, :].expand(
|
||||||
*spatial_size, rays_embedding.shape[-1]
|
*spatial_size, rays_embedding.shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Concatenate ray direction embeddings with
|
# Concatenate ray direction embeddings with
|
||||||
# features and evaluate the color model.
|
# features and evaluate the color model.
|
||||||
color_layer_input = torch.cat(
|
color_layer_input = torch.cat(
|
||||||
(features, rays_embedding_expand),
|
(features, rays_embedding_expand), dim=-1
|
||||||
dim=-1
|
|
||||||
)
|
)
|
||||||
return self.color_layer(color_layer_input)
|
return self.color_layer(color_layer_input)
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: RayBundle,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
The forward function accepts the parametrizations of
|
The forward function accepts the parametrizations of 3D points sampled
|
||||||
3D points sampled along projection rays. The forward
|
along projection rays. The forward pass is responsible for attaching a
|
||||||
pass is responsible for attaching a 3D vector
|
3D vector and a 1D scalar representing the point's RGB color and
|
||||||
and a 1D scalar representing the point's
|
opacity respectively.
|
||||||
RGB color and opacity respectively.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A RayBundle object containing the following variables:
|
ray_bundle: A RayBundle object containing the following variables:
|
||||||
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
|
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
|
||||||
|
@ -294,44 +284,42 @@ class NeuralRadianceField(torch.nn.Module):
|
||||||
# coordinates with `ray_bundle_to_ray_points`.
|
# coordinates with `ray_bundle_to_ray_points`.
|
||||||
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
||||||
# rays_points_world.shape = [minibatch x ... x 3]
|
# rays_points_world.shape = [minibatch x ... x 3]
|
||||||
|
|
||||||
# For each 3D world coordinate, we obtain its harmonic embedding.
|
# For each 3D world coordinate, we obtain its harmonic embedding.
|
||||||
embeds = self.harmonic_embedding(
|
embeds = self.harmonic_embedding(rays_points_world)
|
||||||
rays_points_world
|
|
||||||
)
|
|
||||||
# embeds.shape = [minibatch x ... x self.n_harmonic_functions*6]
|
# embeds.shape = [minibatch x ... x self.n_harmonic_functions*6]
|
||||||
|
|
||||||
# self.mlp maps each harmonic embedding to a latent feature space.
|
# self.mlp maps each harmonic embedding to a latent feature space.
|
||||||
features = self.mlp(embeds)
|
features = self.mlp(embeds)
|
||||||
# features.shape = [minibatch x ... x n_hidden_neurons]
|
# features.shape = [minibatch x ... x n_hidden_neurons]
|
||||||
|
|
||||||
# Finally, given the per-point features,
|
# Finally, given the per-point features,
|
||||||
# execute the density and color branches.
|
# execute the density and color branches.
|
||||||
|
|
||||||
rays_densities = self._get_densities(features)
|
rays_densities = self._get_densities(features)
|
||||||
# rays_densities.shape = [minibatch x ... x 1]
|
# rays_densities.shape = [minibatch x ... x 1]
|
||||||
|
|
||||||
rays_colors = self._get_colors(features, ray_bundle.directions)
|
rays_colors = self._get_colors(features, ray_bundle.directions)
|
||||||
# rays_colors.shape = [minibatch x ... x 3]
|
# rays_colors.shape = [minibatch x ... x 3]
|
||||||
|
|
||||||
return rays_densities, rays_colors
|
return rays_densities, rays_colors
|
||||||
|
|
||||||
def batched_forward(
|
def batched_forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: RayBundle,
|
||||||
n_batches: int = 16,
|
n_batches: int = 16,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This function is used to allow for memory efficient processing
|
This function is used to allow for memory efficient processing of input
|
||||||
of input rays. The input rays are first split to `n_batches`
|
rays. The input rays are first split to `n_batches` chunks and passed
|
||||||
chunks and passed through the `self.forward` function one at a time
|
through the `self.forward` function one at a time in a for loop.
|
||||||
in a for loop. Combined with disabling Pytorch gradient caching
|
Combined with disabling Pytorch gradient caching (`torch.no_grad()`),
|
||||||
(`torch.no_grad()`), this allows for rendering large batches
|
this allows for rendering large batches of rays that do not all fit
|
||||||
of rays that do not all fit into GPU memory in a single forward pass.
|
into GPU memory in a single forward pass. In our case, batched_forward
|
||||||
In our case, batched_forward is used to export a fully-sized render
|
is used to export a fully-sized render of the radiance field for
|
||||||
of the radiance field for visualisation purposes.
|
visualisation purposes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A RayBundle object containing the following variables:
|
ray_bundle: A RayBundle object containing the following variables:
|
||||||
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
|
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
|
||||||
|
@ -353,7 +341,7 @@ class NeuralRadianceField(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Parse out shapes needed for tensor reshaping in this function.
|
# Parse out shapes needed for tensor reshaping in this function.
|
||||||
n_pts_per_ray = ray_bundle.lengths.shape[-1]
|
n_pts_per_ray = ray_bundle.lengths.shape[-1]
|
||||||
spatial_size = [*ray_bundle.origins.shape[:-1], n_pts_per_ray]
|
spatial_size = [*ray_bundle.origins.shape[:-1], n_pts_per_ray]
|
||||||
|
|
||||||
# Split the rays to `n_batches` batches.
|
# Split the rays to `n_batches` batches.
|
||||||
|
@ -366,34 +354,44 @@ class NeuralRadianceField(torch.nn.Module):
|
||||||
RayBundle(
|
RayBundle(
|
||||||
origins=ray_bundle.origins.view(-1, 3)[batch_idx],
|
origins=ray_bundle.origins.view(-1, 3)[batch_idx],
|
||||||
directions=ray_bundle.directions.view(-1, 3)[batch_idx],
|
directions=ray_bundle.directions.view(-1, 3)[batch_idx],
|
||||||
lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[batch_idx],
|
lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[
|
||||||
|
batch_idx
|
||||||
|
],
|
||||||
xys=None,
|
xys=None,
|
||||||
)
|
)
|
||||||
) for batch_idx in batches
|
)
|
||||||
|
for batch_idx in batches
|
||||||
]
|
]
|
||||||
|
|
||||||
# Concatenate the per-batch rays_densities and rays_colors
|
# Concatenate the per-batch rays_densities and rays_colors
|
||||||
# and reshape according to the sizes of the inputs.
|
# and reshape according to the sizes of the inputs.
|
||||||
rays_densities, rays_colors = [
|
rays_densities, rays_colors = [
|
||||||
torch.cat(
|
torch.cat(
|
||||||
[batch_output[output_i] for batch_output in batch_outputs], dim=0
|
[batch_output[output_i] for batch_output in batch_outputs],
|
||||||
).view(*spatial_size, -1) for output_i in (0, 1)
|
dim=0,
|
||||||
|
).view(*spatial_size, -1)
|
||||||
|
for output_i in (0, 1)
|
||||||
]
|
]
|
||||||
return rays_densities, rays_colors
|
return rays_densities, rays_colors
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Helper functions
|
# Helper functions
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
def huber(x, y, scaling=0.1):
|
def huber(x, y, scaling=0.1):
|
||||||
"""
|
"""
|
||||||
A helper function for evaluating the smooth L1 (huber) loss
|
A helper function for evaluating the smooth L1 (huber) loss
|
||||||
between the rendered silhouettes and colors.
|
between the rendered silhouettes and colors.
|
||||||
"""
|
"""
|
||||||
diff_sq = (x - y) ** 2
|
diff_sq = (x - y) ** 2
|
||||||
loss = ((1 + diff_sq / (scaling**2)).clamp(1e-4).sqrt() - 1) * float(scaling)
|
loss = ((1 + diff_sq / (scaling ** 2)).clamp(1e-4).sqrt() - 1) * float(
|
||||||
|
scaling
|
||||||
|
)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def sample_images_at_mc_locs(target_images, sampled_rays_xy):
|
def sample_images_at_mc_locs(target_images, sampled_rays_xy):
|
||||||
"""
|
"""
|
||||||
Given a set of Monte Carlo pixel locations `sampled_rays_xy`,
|
Given a set of Monte Carlo pixel locations `sampled_rays_xy`,
|
||||||
|
@ -407,39 +405,36 @@ def sample_images_at_mc_locs(target_images, sampled_rays_xy):
|
||||||
ba = target_images.shape[0]
|
ba = target_images.shape[0]
|
||||||
dim = target_images.shape[-1]
|
dim = target_images.shape[-1]
|
||||||
spatial_size = sampled_rays_xy.shape[1:-1]
|
spatial_size = sampled_rays_xy.shape[1:-1]
|
||||||
# In order to sample target_images, we utilize
|
# In order to sample target_images, we utilize the grid_sample function
|
||||||
# the grid_sample function which implements a
|
# which implements a bilinear image sampler. Note that we have to invert
|
||||||
# bilinear image sampler.
|
# the sign of the sampled ray positions to convert the NDC xy locations of
|
||||||
# Note that we have to invert the sign of the
|
# the MonteCarloRaysampler to the coordinate convention of grid_sample.
|
||||||
# sampled ray positions to convert the NDC xy locations
|
|
||||||
# of the MonteCarloRaysampler to the coordinate
|
|
||||||
# convention of grid_sample.
|
|
||||||
images_sampled = torch.nn.functional.grid_sample(
|
images_sampled = torch.nn.functional.grid_sample(
|
||||||
target_images.permute(0, 3, 1, 2),
|
target_images.permute(0, 3, 1, 2),
|
||||||
-sampled_rays_xy.view(ba, -1, 1, 2), # note the sign inversion
|
-sampled_rays_xy.view(ba, -1, 1, 2), # note the sign inversion
|
||||||
align_corners=True
|
align_corners=True,
|
||||||
)
|
|
||||||
return images_sampled.permute(0, 2, 3, 1).view(
|
|
||||||
ba, *spatial_size, dim
|
|
||||||
)
|
)
|
||||||
|
return images_sampled.permute(0, 2, 3, 1).view(ba, *spatial_size, dim)
|
||||||
|
|
||||||
|
|
||||||
def show_full_render(
|
def show_full_render(
|
||||||
neural_radiance_field, camera,
|
neural_radiance_field,
|
||||||
target_image, target_silhouette,
|
camera,
|
||||||
loss_history_color, loss_history_sil,
|
target_image,
|
||||||
|
target_silhouette,
|
||||||
|
loss_history_color,
|
||||||
|
loss_history_sil,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This is a helper function for visualizing the
|
This is a helper function for visualizing the intermediate results of the
|
||||||
intermediate results of the learning.
|
learning.
|
||||||
|
|
||||||
Since the `NeuralRadianceField` suffers from
|
Since the `NeuralRadianceField` suffers from a large memory footprint,
|
||||||
a large memory footprint, which does not allow to
|
which does not allow to render the full image grid in a single forward
|
||||||
render the full image grid in a single forward pass,
|
pass, we utilize the `NeuralRadianceField.batched_forward` function in
|
||||||
we utilize the `NeuralRadianceField.batched_forward`
|
combination with disabling the gradient caching. This chunks the set of
|
||||||
function in combination with disabling the gradient caching.
|
emitted rays to batches and evaluates the implicit function on one-batch at
|
||||||
This chunks the set of emitted rays to batches and
|
a time to prevent GPU memory overflow.
|
||||||
evaluates the implicit function on one-batch at a time
|
|
||||||
to prevent GPU memory overflow.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Prevent gradient caching.
|
# Prevent gradient caching.
|
||||||
|
@ -448,43 +443,52 @@ def show_full_render(
|
||||||
# batched_forward function of neural_radiance_field.
|
# batched_forward function of neural_radiance_field.
|
||||||
rendered_image_silhouette, _ = renderer_grid(
|
rendered_image_silhouette, _ = renderer_grid(
|
||||||
cameras=camera,
|
cameras=camera,
|
||||||
volumetric_function=neural_radiance_field.batched_forward
|
volumetric_function=neural_radiance_field.batched_forward,
|
||||||
)
|
)
|
||||||
# Split the rendering result to a silhouette render
|
# Split the rendering result to a silhouette render
|
||||||
# and the image render.
|
# and the image render.
|
||||||
rendered_image, rendered_silhouette = (
|
rendered_image, rendered_silhouette = rendered_image_silhouette[
|
||||||
rendered_image_silhouette[0].split([3, 1], dim=-1)
|
0
|
||||||
)
|
].split([3, 1], dim=-1)
|
||||||
|
|
||||||
# Generate plots.
|
# Generate plots.
|
||||||
fig, ax = plt.subplots(2, 3, figsize=(15, 10))
|
fig, ax = plt.subplots(2, 3, figsize=(15, 10))
|
||||||
ax = ax.ravel()
|
ax = ax.ravel()
|
||||||
clamp_and_detach = lambda x: x.clamp(0.0, 1.0).cpu().detach().numpy()
|
clamp_and_detach = lambda x: x.clamp(0.0, 1.0).cpu().detach().numpy()
|
||||||
ax[0].plot(list(range(len(loss_history_color))), loss_history_color, linewidth=1)
|
ax[0].plot(
|
||||||
|
list(range(len(loss_history_color))), loss_history_color, linewidth=1
|
||||||
|
)
|
||||||
ax[1].imshow(clamp_and_detach(rendered_image))
|
ax[1].imshow(clamp_and_detach(rendered_image))
|
||||||
ax[2].imshow(clamp_and_detach(rendered_silhouette[..., 0]))
|
ax[2].imshow(clamp_and_detach(rendered_silhouette[..., 0]))
|
||||||
ax[3].plot(list(range(len(loss_history_sil))), loss_history_sil, linewidth=1)
|
ax[3].plot(
|
||||||
|
list(range(len(loss_history_sil))), loss_history_sil, linewidth=1
|
||||||
|
)
|
||||||
ax[4].imshow(clamp_and_detach(target_image))
|
ax[4].imshow(clamp_and_detach(target_image))
|
||||||
ax[5].imshow(clamp_and_detach(target_silhouette))
|
ax[5].imshow(clamp_and_detach(target_silhouette))
|
||||||
for ax_, title_ in zip(
|
for ax_, title_ in zip(
|
||||||
ax,
|
ax,
|
||||||
(
|
(
|
||||||
"loss color", "rendered image", "rendered silhouette",
|
"loss color",
|
||||||
"loss silhouette", "target image", "target silhouette",
|
"rendered image",
|
||||||
)
|
"rendered silhouette",
|
||||||
|
"loss silhouette",
|
||||||
|
"target image",
|
||||||
|
"target silhouette",
|
||||||
|
),
|
||||||
):
|
):
|
||||||
if not title_.startswith('loss'):
|
if not title_.startswith("loss"):
|
||||||
ax_.grid("off")
|
ax_.grid("off")
|
||||||
ax_.axis("off")
|
ax_.axis("off")
|
||||||
ax_.set_title(title_)
|
ax_.set_title(title_)
|
||||||
fig.canvas.draw(); fig.show()
|
fig.canvas.draw()
|
||||||
|
fig.show()
|
||||||
display.clear_output(wait=True)
|
display.clear_output(wait=True)
|
||||||
display.display(fig)
|
display.display(fig)
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Fit the radiance field
|
# Fit the radiance field
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
# First move all relevant variables to the correct device.
|
# First move all relevant variables to the correct device.
|
||||||
|
@ -521,7 +525,7 @@ for iteration in range(n_iter):
|
||||||
# In case we reached the last 75% of iterations,
|
# In case we reached the last 75% of iterations,
|
||||||
# decrease the learning rate of the optimizer 10-fold.
|
# decrease the learning rate of the optimizer 10-fold.
|
||||||
if iteration == round(n_iter * 0.75):
|
if iteration == round(n_iter * 0.75):
|
||||||
print('Decreasing LR 10-fold ...')
|
print("Decreasing LR 10-fold ...")
|
||||||
optimizer = torch.optim.Adam(
|
optimizer = torch.optim.Adam(
|
||||||
neural_radiance_field.parameters(), lr=lr * 0.1
|
neural_radiance_field.parameters(), lr=lr * 0.1
|
||||||
)
|
)
|
||||||
|
@ -534,47 +538,52 @@ for iteration in range(n_iter):
|
||||||
|
|
||||||
# Sample the minibatch of cameras.
|
# Sample the minibatch of cameras.
|
||||||
batch_cameras = FoVPerspectiveCameras(
|
batch_cameras = FoVPerspectiveCameras(
|
||||||
R = target_cameras.R[batch_idx],
|
R=target_cameras.R[batch_idx],
|
||||||
T = target_cameras.T[batch_idx],
|
T=target_cameras.T[batch_idx],
|
||||||
znear = target_cameras.znear[batch_idx],
|
znear=target_cameras.znear[batch_idx],
|
||||||
zfar = target_cameras.zfar[batch_idx],
|
zfar=target_cameras.zfar[batch_idx],
|
||||||
aspect_ratio = target_cameras.aspect_ratio[batch_idx],
|
aspect_ratio=target_cameras.aspect_ratio[batch_idx],
|
||||||
fov = target_cameras.fov[batch_idx],
|
fov=target_cameras.fov[batch_idx],
|
||||||
device = device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evaluate the nerf model.
|
# Evaluate the nerf model.
|
||||||
rendered_images_silhouettes, sampled_rays = renderer_mc(
|
rendered_images_silhouettes, sampled_rays = renderer_mc(
|
||||||
cameras=batch_cameras,
|
cameras=batch_cameras, volumetric_function=neural_radiance_field
|
||||||
volumetric_function=neural_radiance_field
|
|
||||||
)
|
)
|
||||||
rendered_images, rendered_silhouettes = (
|
rendered_images, rendered_silhouettes = rendered_images_silhouettes.split(
|
||||||
rendered_images_silhouettes.split([3, 1], dim=-1)
|
[3, 1], dim=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute the silhoutte error as the mean huber
|
# Compute the silhoutte error as the mean huber
|
||||||
# loss between the predicted masks and the
|
# loss between the predicted masks and the
|
||||||
# sampled target silhouettes.
|
# sampled target silhouettes.
|
||||||
silhouettes_at_rays = sample_images_at_mc_locs(
|
silhouettes_at_rays = sample_images_at_mc_locs(
|
||||||
target_silhouettes[batch_idx, ..., None],
|
target_silhouettes[batch_idx, ..., None], sampled_rays.xys
|
||||||
sampled_rays.xys
|
)
|
||||||
|
sil_err = (
|
||||||
|
huber(
|
||||||
|
rendered_silhouettes,
|
||||||
|
silhouettes_at_rays,
|
||||||
|
)
|
||||||
|
.abs()
|
||||||
|
.mean()
|
||||||
)
|
)
|
||||||
sil_err = huber(
|
|
||||||
rendered_silhouettes,
|
|
||||||
silhouettes_at_rays,
|
|
||||||
).abs().mean()
|
|
||||||
|
|
||||||
# Compute the color error as the mean huber
|
# Compute the color error as the mean huber
|
||||||
# loss between the rendered colors and the
|
# loss between the rendered colors and the
|
||||||
# sampled target images.
|
# sampled target images.
|
||||||
colors_at_rays = sample_images_at_mc_locs(
|
colors_at_rays = sample_images_at_mc_locs(
|
||||||
target_images[batch_idx],
|
target_images[batch_idx], sampled_rays.xys
|
||||||
sampled_rays.xys
|
)
|
||||||
|
color_err = (
|
||||||
|
huber(
|
||||||
|
rendered_images,
|
||||||
|
colors_at_rays,
|
||||||
|
)
|
||||||
|
.abs()
|
||||||
|
.mean()
|
||||||
)
|
)
|
||||||
color_err = huber(
|
|
||||||
rendered_images,
|
|
||||||
colors_at_rays,
|
|
||||||
).abs().mean()
|
|
||||||
|
|
||||||
# The optimization loss is a simple
|
# The optimization loss is a simple
|
||||||
# sum of the color and silhouette errors.
|
# sum of the color and silhouette errors.
|
||||||
|
@ -587,9 +596,9 @@ for iteration in range(n_iter):
|
||||||
# Every 10 iterations, print the current values of the losses.
|
# Every 10 iterations, print the current values of the losses.
|
||||||
if iteration % 10 == 0:
|
if iteration % 10 == 0:
|
||||||
print(
|
print(
|
||||||
f'Iteration {iteration:05d}:'
|
f"Iteration {iteration:05d}:"
|
||||||
+ f' loss color = {float(color_err):1.2e}'
|
+ f" loss color = {float(color_err):1.2e}"
|
||||||
+ f' loss silhouette = {float(sil_err):1.2e}'
|
+ f" loss silhouette = {float(sil_err):1.2e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Take the optimization step.
|
# Take the optimization step.
|
||||||
|
@ -602,13 +611,13 @@ for iteration in range(n_iter):
|
||||||
show_full_render(
|
show_full_render(
|
||||||
neural_radiance_field,
|
neural_radiance_field,
|
||||||
FoVPerspectiveCameras(
|
FoVPerspectiveCameras(
|
||||||
R = target_cameras.R[show_idx],
|
R=target_cameras.R[show_idx],
|
||||||
T = target_cameras.T[show_idx],
|
T=target_cameras.T[show_idx],
|
||||||
znear = target_cameras.znear[show_idx],
|
znear=target_cameras.znear[show_idx],
|
||||||
zfar = target_cameras.zfar[show_idx],
|
zfar=target_cameras.zfar[show_idx],
|
||||||
aspect_ratio = target_cameras.aspect_ratio[show_idx],
|
aspect_ratio=target_cameras.aspect_ratio[show_idx],
|
||||||
fov = target_cameras.fov[show_idx],
|
fov=target_cameras.fov[show_idx],
|
||||||
device = device,
|
device=device,
|
||||||
),
|
),
|
||||||
target_images[show_idx][0],
|
target_images[show_idx][0],
|
||||||
target_silhouettes[show_idx][0],
|
target_silhouettes[show_idx][0],
|
||||||
|
@ -617,17 +626,18 @@ for iteration in range(n_iter):
|
||||||
)
|
)
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Visualizing the optimized neural radiance field
|
# Visualizing the optimized neural radiance field
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
def generate_rotating_nerf(neural_radiance_field, n_frames = 50):
|
|
||||||
|
def generate_rotating_nerf(neural_radiance_field, n_frames=50):
|
||||||
logRs = torch.zeros(n_frames, 3, device=device)
|
logRs = torch.zeros(n_frames, 3, device=device)
|
||||||
logRs[:, 1] = torch.linspace(-3.14, 3.14, n_frames, device=device)
|
logRs[:, 1] = torch.linspace(-3.14, 3.14, n_frames, device=device)
|
||||||
Rs = so3_exponential_map(logRs)
|
Rs = so3_exponential_map(logRs)
|
||||||
Ts = torch.zeros(n_frames, 3, device=device)
|
Ts = torch.zeros(n_frames, 3, device=device)
|
||||||
Ts[:, 2] = 2.7
|
Ts[:, 2] = 2.7
|
||||||
frames = []
|
frames = []
|
||||||
print('Rendering rotating NeRF ...')
|
print("Rendering rotating NeRF ...")
|
||||||
for R, T in zip(tqdm(Rs), Ts):
|
for R, T in zip(tqdm(Rs), Ts):
|
||||||
camera = FoVPerspectiveCameras(
|
camera = FoVPerspectiveCameras(
|
||||||
R=R[None],
|
R=R[None],
|
||||||
|
@ -648,8 +658,18 @@ def generate_rotating_nerf(neural_radiance_field, n_frames = 50):
|
||||||
)
|
)
|
||||||
return torch.cat(frames)
|
return torch.cat(frames)
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
rotating_nerf_frames = generate_rotating_nerf(neural_radiance_field, n_frames=3*5)
|
|
||||||
|
|
||||||
image_grid(rotating_nerf_frames.clamp(0., 1.).cpu().numpy(), rows=3, cols=5, rgb=True, fill=True)
|
with torch.no_grad():
|
||||||
|
rotating_nerf_frames = generate_rotating_nerf(
|
||||||
|
neural_radiance_field, n_frames=3 * 5
|
||||||
|
)
|
||||||
|
|
||||||
|
image_grid(
|
||||||
|
rotating_nerf_frames.clamp(0.0, 1.0).cpu().numpy(),
|
||||||
|
rows=3,
|
||||||
|
cols=5,
|
||||||
|
rgb=True,
|
||||||
|
fill=True,
|
||||||
|
)
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
Loading…
Reference in a new issue