🎉 initial commit
This commit is contained in:
commit
002f47dd15
17 changed files with 17995 additions and 0 deletions
BIN
data/bundle_adjustment_final.png
Normal file
BIN
data/bundle_adjustment_final.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 43 KiB |
BIN
data/bundle_adjustment_initialization.png
Normal file
BIN
data/bundle_adjustment_initialization.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 48 KiB |
BIN
data/camera_graph.pth
Normal file
BIN
data/camera_graph.pth
Normal file
Binary file not shown.
6
data/cow_mesh/README.md
Normal file
6
data/cow_mesh/README.md
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
|
||||||
|
# Acknowledgements
|
||||||
|
|
||||||
|
Thank you to Keenan Crane for allowing the cow mesh model to be used freely in the public domain.
|
||||||
|
|
||||||
|
###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/
|
9
data/cow_mesh/cow.mtl
Executable file
9
data/cow_mesh/cow.mtl
Executable file
|
@ -0,0 +1,9 @@
|
||||||
|
newmtl material_1
|
||||||
|
map_Kd cow_texture.png
|
||||||
|
|
||||||
|
# Test colors
|
||||||
|
|
||||||
|
Ka 1.000 1.000 1.000 # white
|
||||||
|
Kd 1.000 1.000 1.000 # white
|
||||||
|
Ks 0.000 0.000 0.000 # black
|
||||||
|
Ns 10.0
|
12015
data/cow_mesh/cow.obj
Executable file
12015
data/cow_mesh/cow.obj
Executable file
File diff suppressed because it is too large
Load diff
BIN
data/cow_mesh/cow_texture.png
Executable file
BIN
data/cow_mesh/cow_texture.png
Executable file
Binary file not shown.
After Width: | Height: | Size: 77 KiB |
5046
data/teapot.obj
Normal file
5046
data/teapot.obj
Normal file
File diff suppressed because it is too large
Load diff
646
main.py
Normal file
646
main.py
Normal file
|
@ -0,0 +1,646 @@
|
||||||
|
# Inspired from Pytorch3D NeRF Tutorial
|
||||||
|
# https://github.com/facebookresearch/pytorch3d/blob/master/docs/tutorials/fit_simple_neural_radiance_field.ipynb
|
||||||
|
# Created on 2021-02-22
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import glob
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Data structures and functions for rendering
|
||||||
|
from pytorch3d.structures import Volumes
|
||||||
|
from pytorch3d.transforms import so3_exponential_map
|
||||||
|
from pytorch3d.renderer import (
|
||||||
|
FoVPerspectiveCameras,
|
||||||
|
NDCGridRaysampler,
|
||||||
|
MonteCarloRaysampler,
|
||||||
|
EmissionAbsorptionRaymarcher,
|
||||||
|
ImplicitRenderer,
|
||||||
|
RayBundle,
|
||||||
|
ray_bundle_to_ray_points,
|
||||||
|
)
|
||||||
|
|
||||||
|
# add path for demo utils functions
|
||||||
|
sys.path.append(os.path.abspath(''))
|
||||||
|
from utils.plot_image_grid import image_grid
|
||||||
|
from utils.generate_cow_renders import generate_cow_renders
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize the compute device: use the first CUDA GPU when one is
# available, otherwise fall back to the CPU so the script still starts
# on machines without CUDA (torch.cuda.set_device would raise there).
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
|
||||||
|
|
||||||
|
# Generate dataset
|
||||||
|
target_cameras, target_images, target_silhouettes = \
|
||||||
|
generate_cow_renders(num_views=40, azimuth_range=180)
|
||||||
|
print(f'Generated {len(target_images)} images/silhouettes/cameras.')
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Initialize the implicit renderers
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
# render_size describes the size of both sides of the
|
||||||
|
# rendered images in pixels. Since an advantage of
|
||||||
|
# Neural Radiance Fields are high quality renders
|
||||||
|
# with a significant amount of details, we render
|
||||||
|
# the implicit function at double the size of
|
||||||
|
# target images.
|
||||||
|
render_size = target_images.shape[1] * 2
|
||||||
|
|
||||||
|
# Our rendered scene is centered around (0,0,0)
|
||||||
|
# and is enclosed inside a bounding box
|
||||||
|
# whose side is roughly equal to 3.0 (world units).
|
||||||
|
volume_extent_world = 3.0
|
||||||
|
|
||||||
|
# 1) Instantiate the raysamplers.
|
||||||
|
|
||||||
|
# Here, NDCGridRaysampler generates a rectangular image
|
||||||
|
# grid of rays whose coordinates follow the PyTorch3d
|
||||||
|
# coordinate conventions.
|
||||||
|
raysampler_grid = NDCGridRaysampler(
|
||||||
|
image_height=render_size,
|
||||||
|
image_width=render_size,
|
||||||
|
n_pts_per_ray=128,
|
||||||
|
min_depth=0.1,
|
||||||
|
max_depth=volume_extent_world,
|
||||||
|
)
|
||||||
|
|
||||||
|
# MonteCarloRaysampler generates a random subset
|
||||||
|
# of `n_rays_per_image` rays emitted from the image plane.
|
||||||
|
raysampler_mc = MonteCarloRaysampler(
|
||||||
|
min_x = -1.0,
|
||||||
|
max_x = 1.0,
|
||||||
|
min_y = -1.0,
|
||||||
|
max_y = 1.0,
|
||||||
|
n_rays_per_image=750,
|
||||||
|
n_pts_per_ray=128,
|
||||||
|
min_depth=0.1,
|
||||||
|
max_depth=volume_extent_world,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) Instantiate the raymarcher.
|
||||||
|
# Here, we use the standard EmissionAbsorptionRaymarcher
|
||||||
|
# which marches along each ray in order to render
|
||||||
|
# the ray into a single 3D color vector
|
||||||
|
# and an opacity scalar.
|
||||||
|
raymarcher = EmissionAbsorptionRaymarcher()
|
||||||
|
|
||||||
|
# Finally, instantiate the implicit renders
|
||||||
|
# for both raysamplers.
|
||||||
|
renderer_grid = ImplicitRenderer(
|
||||||
|
raysampler=raysampler_grid, raymarcher=raymarcher,
|
||||||
|
)
|
||||||
|
renderer_mc = ImplicitRenderer(
|
||||||
|
raysampler=raysampler_mc, raymarcher=raymarcher,
|
||||||
|
)
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Define the NeRF model
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
class HarmonicEmbedding(torch.nn.Module):
    """
    Sine/cosine embedding of the last dimension of a tensor.

    Every feature `x[..., i]` is multiplied by each of the frequencies
    `f_k = omega0 * 2**k` for `k = 0, ..., n_harmonic_functions - 1`.
    The resulting phases are laid out feature by feature
    (all frequencies of feature 0, then all frequencies of feature 1, ...)
    and the output is the concatenation of their sines and their cosines:

        embedding = [
            sin(f_0 * x[..., 0]), ..., sin(f_{n-1} * x[..., 0]),
            sin(f_0 * x[..., 1]), ..., sin(f_{n-1} * x[..., dim-1]),
            cos(f_0 * x[..., 0]), ..., cos(f_{n-1} * x[..., dim-1]),
        ]
    """

    def __init__(self, n_harmonic_functions=60, omega0=0.1):
        """
        Args:
            n_harmonic_functions: number of frequencies per input feature.
            omega0: base frequency; frequency `k` equals `omega0 * 2**k`.
        """
        super().__init__()
        octaves = torch.arange(n_harmonic_functions)
        # Stored as a buffer so the frequencies follow the module across
        # `.to(device)` calls without becoming trainable parameters.
        self.register_buffer('frequencies', omega0 * (2.0 ** octaves))

    def forward(self, x):
        """
        Args:
            x: tensor of shape [..., dim]
        Returns:
            embedding: a harmonic embedding of `x`
                of shape [..., n_harmonic_functions * dim * 2]
        """
        leading_shape = x.shape[:-1]
        # [..., dim, 1] * [n_harmonic_functions] -> [..., dim, n_harmonic_functions]
        phases = x.unsqueeze(-1) * self.frequencies
        # Merge the last two axes: [..., dim * n_harmonic_functions]
        phases = phases.view(*leading_shape, -1)
        return torch.cat((torch.sin(phases), torch.cos(phases)), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class NeuralRadianceField(torch.nn.Module):
    """
    A small NeRF-style implicit function.

    Points sampled along rays are harmonically embedded, passed through a
    shared MLP, and then decoded by two heads: a density head producing a
    per-point opacity in [0, 1) and a color head producing a per-point RGB
    value in [0, 1] that additionally depends on the ray direction.
    """

    def __init__(self, n_harmonic_functions=60, n_hidden_neurons=256):
        """
        Args:
            n_harmonic_functions: The number of harmonic functions
                used to form the harmonic embedding of each point.
            n_hidden_neurons: The number of hidden units in the
                fully connected layers of the MLPs of the model.
        """
        super().__init__()

        # The harmonic embedding layer converts input 3D coordinates
        # to a representation that is more suitable for
        # processing with a deep neural network.
        self.harmonic_embedding = HarmonicEmbedding(n_harmonic_functions)

        # The dimension of the harmonic embedding:
        # (sin + cos) for each of the 3 input coordinates.
        embedding_dim = n_harmonic_functions * 2 * 3

        # self.mlp is a simple 2-layer multi-layer perceptron
        # which converts the input per-point harmonic embeddings
        # to a latent representation.
        # Note that we use Softplus activations instead of ReLU.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, n_hidden_neurons),
            torch.nn.Softplus(beta=10.0),
            torch.nn.Linear(n_hidden_neurons, n_hidden_neurons),
            torch.nn.Softplus(beta=10.0),
        )

        # Given features predicted by self.mlp, self.color_layer
        # is responsible for predicting a 3-D per-point vector
        # that represents the RGB color of the point.
        # Its input is the feature vector concatenated with the harmonic
        # embedding of the ray direction (hence `+ embedding_dim`).
        self.color_layer = torch.nn.Sequential(
            torch.nn.Linear(n_hidden_neurons + embedding_dim, n_hidden_neurons),
            torch.nn.Softplus(beta=10.0),
            torch.nn.Linear(n_hidden_neurons, 3),
            # To ensure that the colors correctly range between [0-1],
            # the layer is terminated with a sigmoid layer.
            torch.nn.Sigmoid(),
        )

        # The density layer converts the features of self.mlp
        # to a 1D density value representing the raw opacity
        # of each point.
        self.density_layer = torch.nn.Sequential(
            torch.nn.Linear(n_hidden_neurons, 1),
            # Softplus activation ensures that the raw opacity
            # is a non-negative number.
            torch.nn.Softplus(beta=10.0),
        )

        # We set the bias of the density layer to -1.5
        # in order to initialize the opacities of the
        # ray points to values close to 0.
        # This is a crucial detail for ensuring convergence
        # of the model. The bias has a single element, so a constant
        # fill is equivalent to writing element 0; nn.init does it
        # without going through the legacy `.data` attribute.
        torch.nn.init.constant_(self.density_layer[0].bias, -1.5)

    def _get_densities(self, features):
        """
        This function takes `features` predicted by `self.mlp`
        and converts them to `raw_densities` with `self.density_layer`.
        `raw_densities` are later mapped to [0-1] range with
        1 - inverse exponential of `raw_densities`.
        """
        raw_densities = self.density_layer(features)
        return 1 - (-raw_densities).exp()

    def _get_colors(self, features, rays_directions):
        """
        This function takes per-point `features` predicted by `self.mlp`
        and evaluates the color model in order to attach to each
        point a 3D vector of its RGB color.

        In order to represent viewpoint dependent effects,
        before evaluating `self.color_layer`, `NeuralRadianceField`
        concatenates to the `features` a harmonic embedding
        of `rays_directions`, which are per-ray directions
        that this method l2-normalizes before embedding them.
        """
        spatial_size = features.shape[:-1]

        # Normalize the rays_directions to unit l2 norm.
        rays_directions_normed = torch.nn.functional.normalize(
            rays_directions, dim=-1
        )

        # Obtain the harmonic embedding of the normalized ray directions.
        rays_embedding = self.harmonic_embedding(rays_directions_normed)

        # Directions are per-ray while features are per-point: insert a
        # points-per-ray axis and expand it so that the spatial size of
        # the embedding equals the spatial size of features.
        rays_embedding_expand = rays_embedding[..., None, :].expand(
            *spatial_size, rays_embedding.shape[-1]
        )

        # Concatenate ray direction embeddings with
        # features and evaluate the color model.
        color_layer_input = torch.cat(
            (features, rays_embedding_expand),
            dim=-1
        )
        return self.color_layer(color_layer_input)

    def forward(
        self,
        ray_bundle: RayBundle,
        **kwargs,
    ):
        """
        The forward function accepts the parametrizations of
        3D points sampled along projection rays. The forward
        pass is responsible for attaching a 3D vector
        and a 1D scalar representing the point's
        RGB color and opacity respectively.

        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
            **kwargs: accepted for interface compatibility with the
                renderer; not used by this method.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
                denoting the color of each ray point.
        """
        # We first convert the ray parametrizations to world
        # coordinates with `ray_bundle_to_ray_points`.
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)
        # rays_points_world.shape = [minibatch x ... x 3]

        # For each 3D world coordinate, we obtain its harmonic embedding.
        embeds = self.harmonic_embedding(rays_points_world)
        # embeds.shape = [minibatch x ... x n_harmonic_functions*6]

        # self.mlp maps each harmonic embedding to a latent feature space.
        features = self.mlp(embeds)
        # features.shape = [minibatch x ... x n_hidden_neurons]

        # Finally, given the per-point features,
        # execute the density and color branches.
        rays_densities = self._get_densities(features)
        # rays_densities.shape = [minibatch x ... x 1]

        rays_colors = self._get_colors(features, ray_bundle.directions)
        # rays_colors.shape = [minibatch x ... x 3]

        return rays_densities, rays_colors

    def batched_forward(
        self,
        ray_bundle: RayBundle,
        n_batches: int = 16,
        **kwargs,
    ):
        """
        This function is used to allow for memory efficient processing
        of input rays. The input rays are first split to `n_batches`
        chunks and passed through the `self.forward` function one at a time
        in a for loop. Combined with disabling Pytorch gradient caching
        (`torch.no_grad()`), this allows for rendering large batches
        of rays that do not all fit into GPU memory in a single forward pass.
        In our case, batched_forward is used to export a fully-sized render
        of the radiance field for visualisation purposes.

        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
            n_batches: Specifies the number of batches the input rays are split into.
                The larger the number of batches, the smaller the memory footprint
                and the lower the processing speed.
            **kwargs: accepted for interface compatibility with the
                renderer; not forwarded to `self.forward`.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
                denoting the color of each ray point.
        """
        # Parse out shapes needed for tensor reshaping in this function.
        n_pts_per_ray = ray_bundle.lengths.shape[-1]
        spatial_size = [*ray_bundle.origins.shape[:-1], n_pts_per_ray]

        # Split the (flattened) ray indices to `n_batches` batches.
        tot_samples = ray_bundle.origins.shape[:-1].numel()
        batches = torch.chunk(torch.arange(tot_samples), n_batches)

        # For each batch, execute the standard forward pass on the
        # corresponding subset of flattened rays.
        batch_outputs = [
            self.forward(
                RayBundle(
                    origins=ray_bundle.origins.view(-1, 3)[batch_idx],
                    directions=ray_bundle.directions.view(-1, 3)[batch_idx],
                    lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[batch_idx],
                    xys=None,
                )
            ) for batch_idx in batches
        ]

        # Concatenate the per-batch rays_densities and rays_colors
        # and reshape according to the sizes of the inputs.
        rays_densities, rays_colors = [
            torch.cat(
                [batch_output[output_i] for batch_output in batch_outputs], dim=0
            ).view(*spatial_size, -1) for output_i in (0, 1)
        ]
        return rays_densities, rays_colors
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Helper functions
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
def huber(x, y, scaling=0.1):
    """
    Element-wise smooth L1 (huber-like) penalty between `x` and `y`,
    used for comparing the rendered silhouettes and colors with
    their targets.

    Computes `scaling * (sqrt(1 + ((x - y) / scaling)**2) - 1)`
    and returns a tensor with the broadcast shape of `x` and `y`.
    """
    residual_sq = (x - y).pow(2)
    under_root = 1 + residual_sq / (scaling**2)
    return (under_root.clamp(1e-4).sqrt() - 1) * float(scaling)
|
||||||
|
|
||||||
|
def sample_images_at_mc_locs(target_images, sampled_rays_xy):
    """
    Bilinearly sample `target_images` at the 2D locations
    `sampled_rays_xy` produced by the Monte Carlo raysampler.

    This yields the ground truth values that correspond to the
    values rendered along the rays of `MonteCarloRaysampler`.

    Args:
        target_images: tensor indexed as [batch, height, width, channels].
        sampled_rays_xy: tensor of shape [batch, ..., 2] with the
            xy locations of the sampled rays.

    Returns:
        Tensor of shape [batch, ..., channels] with the sampled values.
    """
    n_images = target_images.shape[0]
    n_channels = target_images.shape[-1]
    rays_shape = sampled_rays_xy.shape[1:-1]

    # grid_sample expects channels-first images ...
    images_channels_first = target_images.permute(0, 3, 1, 2)
    # ... and a [batch, H_out, W_out, 2] grid. The sign of the ray
    # locations is inverted to convert the NDC xy convention of the
    # raysampler to the coordinate convention of grid_sample.
    sampling_grid = -sampled_rays_xy.view(n_images, -1, 1, 2)

    sampled = torch.nn.functional.grid_sample(
        images_channels_first,
        sampling_grid,
        align_corners=True
    )

    # Back to channels-last, restoring the spatial layout of the rays.
    sampled_channels_last = sampled.permute(0, 2, 3, 1)
    return sampled_channels_last.view(n_images, *rays_shape, n_channels)
|
||||||
|
|
||||||
|
def show_full_render(
    neural_radiance_field, camera,
    target_image, target_silhouette,
    loss_history_color, loss_history_sil,
):
    """
    This is a helper function for visualizing the
    intermediate results of the learning.

    Since the `NeuralRadianceField` suffers from
    a large memory footprint, which does not allow to
    render the full image grid in a single forward pass,
    we utilize the `NeuralRadianceField.batched_forward`
    function in combination with disabling the gradient caching.
    This chunks the set of emitted rays to batches and
    evaluates the implicit function on one-batch at a time
    to prevent GPU memory overflow.

    Args:
        neural_radiance_field: model whose `batched_forward` is used
            as the volumetric function of the grid renderer.
        camera: the camera to render from.
        target_image: ground truth image shown next to the render.
        target_silhouette: ground truth silhouette shown next to the render.
        loss_history_color: sequence of color loss values to plot.
        loss_history_sil: sequence of silhouette loss values to plot.

    Returns:
        The matplotlib figure holding the 2x3 grid of plots.
    """

    def to_displayable(x):
        # Clamp to the valid image range and move to a numpy array.
        return x.clamp(0.0, 1.0).cpu().detach().numpy()

    # Prevent gradient caching.
    with torch.no_grad():
        # Render using the grid renderer and the
        # batched_forward function of neural_radiance_field.
        rendered_image_silhouette, _ = renderer_grid(
            cameras=camera,
            volumetric_function=neural_radiance_field.batched_forward
        )
        # Split the rendering result to a silhouette render
        # and the image render.
        rendered_image, rendered_silhouette = (
            rendered_image_silhouette[0].split([3, 1], dim=-1)
        )

    # Generate plots.
    fig, ax = plt.subplots(2, 3, figsize=(15, 10))
    ax = ax.ravel()
    ax[0].plot(list(range(len(loss_history_color))), loss_history_color, linewidth=1)
    ax[1].imshow(to_displayable(rendered_image))
    ax[2].imshow(to_displayable(rendered_silhouette[..., 0]))
    ax[3].plot(list(range(len(loss_history_sil))), loss_history_sil, linewidth=1)
    ax[4].imshow(to_displayable(target_image))
    ax[5].imshow(to_displayable(target_silhouette))

    titles = (
        "loss color", "rendered image", "rendered silhouette",
        "loss silhouette", "target image", "target silhouette",
    )
    for ax_, title_ in zip(ax, titles):
        if not title_.startswith('loss'):
            # `grid` takes a boolean; the string "off" is truthy and
            # would switch the grid on in current Matplotlib releases.
            ax_.grid(False)
            ax_.axis("off")
        ax_.set_title(title_)

    fig.canvas.draw()
    fig.show()
    # The notebook-only `display.clear_output` / `display.display` calls
    # were removed: `display` is never imported in this script, so they
    # raised NameError. A short pause lets an interactive backend paint
    # the window while the training loop keeps running.
    plt.pause(0.001)
    return fig
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Fit the radiance field
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
# First move all relevant variables to the correct device.
|
||||||
|
renderer_grid = renderer_grid.to(device)
|
||||||
|
renderer_mc = renderer_mc.to(device)
|
||||||
|
target_cameras = target_cameras.to(device)
|
||||||
|
target_images = target_images.to(device)
|
||||||
|
target_silhouettes = target_silhouettes.to(device)
|
||||||
|
|
||||||
|
# Set the seed for reproducibility
|
||||||
|
torch.manual_seed(1)
|
||||||
|
|
||||||
|
# Instantiate the radiance field model.
|
||||||
|
neural_radiance_field = NeuralRadianceField().to(device)
|
||||||
|
|
||||||
|
# Instantiate the Adam optimizer. We set its master learning rate to 1e-3.
|
||||||
|
lr = 1e-3
|
||||||
|
optimizer = torch.optim.Adam(neural_radiance_field.parameters(), lr=lr)
|
||||||
|
|
||||||
|
# We sample 6 random cameras in a minibatch. Each camera
|
||||||
|
# emits `n_rays_per_image` rays (as configured for raysampler_mc above).
|
||||||
|
batch_size = 6
|
||||||
|
|
||||||
|
# 3000 iterations take ~20 min on a Tesla M40 and lead to
|
||||||
|
# reasonably sharp results. However, for the best possible
|
||||||
|
# results, we recommend setting n_iter=20000.
|
||||||
|
n_iter = 3000
|
||||||
|
|
||||||
|
# Init the loss history buffers.
|
||||||
|
loss_history_color, loss_history_sil = [], []
|
||||||
|
|
||||||
|
# The main optimization loop.
|
||||||
|
for iteration in range(n_iter):
|
||||||
|
# Once 75% of the iterations have been completed,
|
||||||
|
# decrease the learning rate of the optimizer 10-fold.
|
||||||
|
if iteration == round(n_iter * 0.75):
|
||||||
|
print('Decreasing LR 10-fold ...')
|
||||||
|
optimizer = torch.optim.Adam(
|
||||||
|
neural_radiance_field.parameters(), lr=lr * 0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Zero the optimizer gradient.
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Sample random batch indices.
|
||||||
|
batch_idx = torch.randperm(len(target_cameras))[:batch_size]
|
||||||
|
|
||||||
|
# Sample the minibatch of cameras.
|
||||||
|
batch_cameras = FoVPerspectiveCameras(
|
||||||
|
R = target_cameras.R[batch_idx],
|
||||||
|
T = target_cameras.T[batch_idx],
|
||||||
|
znear = target_cameras.znear[batch_idx],
|
||||||
|
zfar = target_cameras.zfar[batch_idx],
|
||||||
|
aspect_ratio = target_cameras.aspect_ratio[batch_idx],
|
||||||
|
fov = target_cameras.fov[batch_idx],
|
||||||
|
device = device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate the nerf model.
|
||||||
|
rendered_images_silhouettes, sampled_rays = renderer_mc(
|
||||||
|
cameras=batch_cameras,
|
||||||
|
volumetric_function=neural_radiance_field
|
||||||
|
)
|
||||||
|
rendered_images, rendered_silhouettes = (
|
||||||
|
rendered_images_silhouettes.split([3, 1], dim=-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute the silhouette error as the mean huber
|
||||||
|
# loss between the predicted masks and the
|
||||||
|
# sampled target silhouettes.
|
||||||
|
silhouettes_at_rays = sample_images_at_mc_locs(
|
||||||
|
target_silhouettes[batch_idx, ..., None],
|
||||||
|
sampled_rays.xys
|
||||||
|
)
|
||||||
|
sil_err = huber(
|
||||||
|
rendered_silhouettes,
|
||||||
|
silhouettes_at_rays,
|
||||||
|
).abs().mean()
|
||||||
|
|
||||||
|
# Compute the color error as the mean huber
|
||||||
|
# loss between the rendered colors and the
|
||||||
|
# sampled target images.
|
||||||
|
colors_at_rays = sample_images_at_mc_locs(
|
||||||
|
target_images[batch_idx],
|
||||||
|
sampled_rays.xys
|
||||||
|
)
|
||||||
|
color_err = huber(
|
||||||
|
rendered_images,
|
||||||
|
colors_at_rays,
|
||||||
|
).abs().mean()
|
||||||
|
|
||||||
|
# The optimization loss is a simple
|
||||||
|
# sum of the color and silhouette errors.
|
||||||
|
loss = color_err + sil_err
|
||||||
|
|
||||||
|
# Log the loss history.
|
||||||
|
loss_history_color.append(float(color_err))
|
||||||
|
loss_history_sil.append(float(sil_err))
|
||||||
|
|
||||||
|
# Every 10 iterations, print the current values of the losses.
|
||||||
|
if iteration % 10 == 0:
|
||||||
|
print(
|
||||||
|
f'Iteration {iteration:05d}:'
|
||||||
|
+ f' loss color = {float(color_err):1.2e}'
|
||||||
|
+ f' loss silhouette = {float(sil_err):1.2e}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Take the optimization step.
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# Visualize the full renders every 100 iterations.
|
||||||
|
if iteration % 100 == 0:
|
||||||
|
show_idx = torch.randperm(len(target_cameras))[:1]
|
||||||
|
show_full_render(
|
||||||
|
neural_radiance_field,
|
||||||
|
FoVPerspectiveCameras(
|
||||||
|
R = target_cameras.R[show_idx],
|
||||||
|
T = target_cameras.T[show_idx],
|
||||||
|
znear = target_cameras.znear[show_idx],
|
||||||
|
zfar = target_cameras.zfar[show_idx],
|
||||||
|
aspect_ratio = target_cameras.aspect_ratio[show_idx],
|
||||||
|
fov = target_cameras.fov[show_idx],
|
||||||
|
device = device,
|
||||||
|
),
|
||||||
|
target_images[show_idx][0],
|
||||||
|
target_silhouettes[show_idx][0],
|
||||||
|
loss_history_color,
|
||||||
|
loss_history_sil,
|
||||||
|
)
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Visualizing the optimized neural radiance field
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
def generate_rotating_nerf(neural_radiance_field, n_frames = 50):
    """
    Render `n_frames` views of the radiance field from cameras whose
    rotation sweeps from -3.14 to 3.14 radians about the second (Y) axis
    at a fixed translation of 2.7 along the third axis.

    Args:
        neural_radiance_field: model whose `batched_forward` is used
            as the volumetric function of the grid renderer.
        n_frames: number of views to render.

    Returns:
        The first three (color) channels of all rendered frames,
        concatenated along the first dimension.
    """
    # Axis-angle (log-rotation) vectors: only the Y component varies.
    log_rotations = torch.zeros(n_frames, 3, device=device)
    log_rotations[:, 1] = torch.linspace(-3.14, 3.14, n_frames, device=device)
    rotations = so3_exponential_map(log_rotations)

    # Every camera shares the same translation.
    translations = torch.zeros(n_frames, 3, device=device)
    translations[:, 2] = 2.7

    print('Rendering rotating NeRF ...')
    rendered_frames = []
    for rotation, translation in zip(tqdm(rotations), translations):
        # Intrinsics are borrowed from the first target camera.
        view_camera = FoVPerspectiveCameras(
            R=rotation[None],
            T=translation[None],
            znear=target_cameras.znear[0],
            zfar=target_cameras.zfar[0],
            aspect_ratio=target_cameras.aspect_ratio[0],
            fov=target_cameras.fov[0],
            device=device,
        )
        # Note that we again render with the grid raysampler
        # and the batched_forward function of neural_radiance_field.
        render_output = renderer_grid(
            cameras=view_camera,
            volumetric_function=neural_radiance_field.batched_forward,
        )
        # Keep only the color channels of the rendered images.
        rendered_frames.append(render_output[0][..., :3])
    return torch.cat(rendered_frames)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
rotating_nerf_frames = generate_rotating_nerf(neural_radiance_field, n_frames=3*5)
|
||||||
|
|
||||||
|
image_grid(rotating_nerf_frames.clamp(0., 1.).cpu().numpy(), rows=3, cols=5, rgb=True, fill=True)
|
||||||
|
plt.show()
|
4
utils/__init__.py
Normal file
4
utils/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
from .camera_visualization import get_camera_wireframe, plot_camera_scene, plot_cameras
|
||||||
|
from .plot_image_grid import image_grid
|
BIN
utils/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
utils/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/camera_visualization.cpython-38.pyc
Normal file
BIN
utils/__pycache__/camera_visualization.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/generate_cow_renders.cpython-38.pyc
Normal file
BIN
utils/__pycache__/generate_cow_renders.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/plot_image_grid.cpython-38.pyc
Normal file
BIN
utils/__pycache__/plot_image_grid.cpython-38.pyc
Normal file
Binary file not shown.
55
utils/camera_visualization.py
Normal file
55
utils/camera_visualization.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import
|
||||||
|
from pytorch3d.vis.plotly_vis import get_camera_wireframe
|
||||||
|
|
||||||
|
|
||||||
|
def plot_cameras(ax, cameras, color: str = "blue"):
    """
    Plots a set of `cameras` objects into the matplotlib axis `ax` with
    color `color`.

    Args:
        ax: a 3D matplotlib axis to draw into.
        cameras: a batch of cameras providing
            `get_world_to_view_transform()` and a `device` attribute.
        color: matplotlib color of the camera wireframes.

    Returns:
        A list with one line handle per camera.
    """
    # Place the canonical wireframe on the cameras' own device instead of
    # unconditionally calling `.cuda()`, so CPU cameras (and machines
    # without CUDA) work as well.
    cam_wires_canonical = get_camera_wireframe().to(cameras.device)[None]
    # The inverse of world-to-view maps the canonical wireframe into world space.
    cam_trans = cameras.get_world_to_view_transform().inverse()
    cam_wires_trans = cam_trans.transform_points(cam_wires_canonical)
    plot_handles = []
    for wire in cam_wires_trans:
        # the Z and Y axes are flipped intentionally here!
        x_, z_, y_ = wire.detach().cpu().numpy().T.astype(float)
        (h,) = ax.plot(x_, y_, z_, color=color, linewidth=0.3)
        plot_handles.append(h)
    return plot_handles
|
||||||
|
|
||||||
|
|
||||||
|
def plot_camera_scene(cameras, cameras_gt, status: str):
    """
    Plots a set of predicted cameras `cameras` and their corresponding
    ground truth locations `cameras_gt`. The plot is named with
    a string passed inside the `status` argument.

    Args:
        cameras: the batch of estimated cameras.
        cameras_gt: the batch of ground truth cameras.
        status: the title of the plot.

    Returns:
        The matplotlib figure that was created (it is also shown via
        `plt.show()` before returning).
    """
    fig = plt.figure()
    # `fig.gca(projection="3d")` relied on passing keyword arguments to `gca`,
    # which was deprecated in Matplotlib 3.4 and removed afterwards;
    # `add_subplot` is the supported way to create the 3D axis.
    ax = fig.add_subplot(projection="3d")
    ax.clear()
    ax.set_title(status)
    handle_cam = plot_cameras(ax, cameras, color="#FF7D1E")
    handle_cam_gt = plot_cameras(ax, cameras_gt, color="#812CE5")
    plot_radius = 3
    ax.set_xlim3d([-plot_radius, plot_radius])
    ax.set_ylim3d([3 - plot_radius, 3 + plot_radius])
    ax.set_zlim3d([-plot_radius, plot_radius])
    # plot_cameras swaps the Y and Z axes, hence the swapped labels below.
    ax.set_xlabel("x")
    ax.set_ylabel("z")
    ax.set_zlabel("y")
    # One representative handle per camera set is enough for the legend.
    labels_handles = {
        "Estimated cameras": handle_cam[0],
        "GT cameras": handle_cam_gt[0],
    }
    ax.legend(
        list(labels_handles.values()),
        list(labels_handles.keys()),
        loc="upper center",
        bbox_to_anchor=(0.5, 0),
    )
    plt.show()
    return fig
|
165
utils/generate_cow_renders.py
Normal file
165
utils/generate_cow_renders.py
Normal file
|
@ -0,0 +1,165 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Util function for loading meshes
|
||||||
|
from pytorch3d.io import load_objs_as_meshes
|
||||||
|
from pytorch3d.renderer import (
|
||||||
|
BlendParams,
|
||||||
|
FoVPerspectiveCameras,
|
||||||
|
MeshRasterizer,
|
||||||
|
MeshRenderer,
|
||||||
|
PointLights,
|
||||||
|
RasterizationSettings,
|
||||||
|
SoftPhongShader,
|
||||||
|
SoftSilhouetteShader,
|
||||||
|
look_at_view_transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# create the default data directory
|
||||||
|
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
DATA_DIR = os.path.join(current_dir, "..", "data", "cow_mesh")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_cow_renders(
    num_views: int = 40, data_dir: str = DATA_DIR, azimuth_range: float = 180
):
    """
    This function generates `num_views` renders of a cow mesh.
    The renders are generated from viewpoints sampled at uniformly distributed
    azimuth intervals. The elevation is kept constant so that the camera's
    vertical position coincides with the equator.

    For a more detailed explanation of this code, please refer to the
    docs/tutorials/fit_textured_mesh.ipynb notebook.

    Args:
        num_views: The number of generated renders.
        data_dir: The folder that contains the cow mesh files. If the cow mesh
            files do not exist in the folder, this function will automatically
            download them.
        azimuth_range: Half-width of the sampled azimuth interval. The azimuth
            angles are spaced uniformly in
            `[180 - azimuth_range, 180 + azimuth_range]`.

    Returns:
        cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
            images are rendered.
        images: A tensor of shape `(num_views, height, width, 3)` containing
            the rendered images.
        silhouettes: A tensor of shape `(num_views, height, width)` containing
            the rendered silhouettes.

    Raises:
        subprocess.CalledProcessError: if downloading a missing mesh file fails.
    """
    # Local import: only needed for the one-off download below.
    import subprocess

    # download the cow mesh files that are not present yet
    base_url = "https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/"
    missing_files = [
        fl
        for fl in ("cow.obj", "cow.mtl", "cow_texture.png")
        if not os.path.isfile(os.path.join(data_dir, fl))
    ]
    if missing_files:
        os.makedirs(data_dir, exist_ok=True)
        for fl in missing_files:
            # An argument list (no shell) keeps paths with spaces or shell
            # metacharacters intact; check=True surfaces a failed download
            # here instead of as a confusing mesh loading error later on.
            subprocess.run(["wget", "-P", data_dir, base_url + fl], check=True)

    # Setup
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    # Load obj file
    obj_filename = os.path.join(data_dir, "cow.obj")
    mesh = load_objs_as_meshes([obj_filename], device=device)

    # We scale normalize and center the target mesh to fit in a sphere of radius 1
    # centered at (0,0,0). (scale, center) will be used to bring the predicted mesh
    # to its original center and scale. Note that normalizing the target mesh
    # speeds up the optimization but is not necessary!
    verts = mesh.verts_packed()
    N = verts.shape[0]
    center = verts.mean(0)
    scale = max((verts - center).abs().max(0)[0])
    mesh.offset_verts_(-(center.expand(N, 3)))
    mesh.scale_verts_((1.0 / float(scale)))

    # Get a batch of viewing angles.
    elev = torch.linspace(0, 0, num_views)  # keep constant
    azim = torch.linspace(-azimuth_range, azimuth_range, num_views) + 180.0

    # Place a point light in front of the object. The front of the cow is
    # facing the -z direction.
    lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

    # Initialize an OpenGL perspective camera that represents a batch of different
    # viewing angles. All the cameras helper methods support mixed type inputs and
    # broadcasting. So we can view the camera from a distance of dist=2.7, and
    # then specify elevation and azimuth angles for each viewpoint as tensors.
    R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

    # Define the settings for rasterization and shading. Here we set the output
    # image to be of size 128X128. As we are rendering images for visualization
    # purposes only we will set faces_per_pixel=1 and blur_radius=0.0. Refer to
    # rasterize_meshes.py for explanations of these parameters. We also leave
    # bin_size and max_faces_per_bin to their default values of None, which sets
    # their values using heuristics and ensures that the faster coarse-to-fine
    # rasterization method is used. Refer to docs/notes/renderer.md for an
    # explanation of the difference between naive and coarse-to-fine rasterization.
    raster_settings = RasterizationSettings(
        image_size=128, blur_radius=0.0, faces_per_pixel=1
    )

    # Create a phong renderer by composing a rasterizer and a shader. The textured
    # phong shader will interpolate the texture uv coordinates for each vertex,
    # sample from a texture image and apply the Phong lighting model
    blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0))
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
        shader=SoftPhongShader(
            device=device, cameras=cameras, lights=lights, blend_params=blend_params
        ),
    )

    # Create a batch of meshes by repeating the cow mesh and associated textures.
    # Meshes has a useful `extend` method which allows us do this very easily.
    # This also extends the textures.
    meshes = mesh.extend(num_views)

    # Render the cow mesh from each viewing angle
    target_images = renderer(meshes, cameras=cameras, lights=lights)

    # Rasterization settings for silhouette rendering
    sigma = 1e-4
    raster_settings_silhouette = RasterizationSettings(
        image_size=128, blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma, faces_per_pixel=50
    )

    # Silhouette renderer
    renderer_silhouette = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras, raster_settings=raster_settings_silhouette
        ),
        shader=SoftSilhouetteShader(),
    )

    # Render silhouette images. The 4th channel (index 3) of the rendering
    # output is the alpha/silhouette channel
    silhouette_images = renderer_silhouette(meshes, cameras=cameras, lights=lights)

    # binary silhouettes
    silhouette_binary = (silhouette_images[..., 3] > 1e-4).float()

    return cameras, target_images[..., :3], silhouette_binary
|
49
utils/plot_image_grid.py
Normal file
49
utils/plot_image_grid.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def image_grid(
    images,
    rows=None,
    cols=None,
    fill: bool = True,
    show_axes: bool = False,
    rgb: bool = True,
):
    """
    A util function for plotting a grid of images.

    Args:
        images: (N, H, W, 4) array of RGBA images
        rows: number of rows in the grid
        cols: number of columns in the grid
        fill: boolean indicating if the space between images should be filled
        show_axes: boolean indicating if the axes of the plots should be visible
        rgb: boolean, If True, only RGB channels are plotted.
            If False, only the alpha channel is plotted.

    Returns:
        None

    Raises:
        ValueError: if exactly one of `rows` and `cols` is specified.
    """
    if (rows is None) != (cols is None):
        raise ValueError("Specify either both rows and cols or neither.")

    if rows is None:
        # default layout: a single column with one image per row
        rows = len(images)
        cols = 1

    gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
    # squeeze=False makes `plt.subplots` always return a 2D array of axes;
    # without it a 1x1 grid yields a bare Axes object that has no `ravel()`.
    fig, axarr = plt.subplots(
        rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9), squeeze=False
    )
    bleed = 0
    fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))

    for ax, im in zip(axarr.ravel(), images):
        if rgb:
            # only render RGB channels
            ax.imshow(im[..., :3])
        else:
            # only render Alpha channel
            ax.imshow(im[..., 3])
        if not show_axes:
            ax.set_axis_off()
|
Loading…
Reference in a new issue