nerf_plus_plus/ddp_train_nerf.py
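"""Distributed (DDP) training script for NeRF++.

One worker process is spawned per GPU. Each process trains a cascade of
foreground/background NeRF networks on randomly sampled rays; rank 0
additionally logs metrics and rendered views to TensorBoard and saves
checkpoints.
"""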

import json
import logging
import os
import time
from collections import OrderedDict

import numpy as np
import torch
import torch.distributed
import torch.multiprocessing
import torch.nn as nn
import torch.optim
from tensorboardX import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

from data_loader_split import load_data_split
from ddp_model import NerfNetWithAutoExpo
from utils import TINY_NUMBER, colorize, img2mse, img_HWC2CHW, mse2psnr

logger = logging.getLogger(__package__)
def setup_logger():
# create logger
logger = logging.getLogger(__package__)
# logger.setLevel(logging.DEBUG)
logger.setLevel(logging.INFO)
# create console handler and set level to debug
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
# create formatter
formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
# add formatter to ch
ch.setFormatter(formatter)
# add ch to logger
logger.addHandler(ch)
def intersect_sphere(ray_o, ray_d):
"""
ray_o, ray_d: [..., 3]
compute the depth of the intersection point between this ray and unit sphere
"""
# note: d1 becomes negative if this mid point is behind camera
d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
p = ray_o + d1.unsqueeze(-1) * ray_d
# consider the case where the ray does not intersect the sphere
ray_d_cos = 1.0 / torch.norm(ray_d, dim=-1)
p_norm_sq = torch.sum(p * p, dim=-1)
if (p_norm_sq >= 1.0).any():
raise Exception(
"Not all your cameras are bounded by the unit sphere; "
"please make sure the cameras are normalized properly!"
)
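    # half chord length sqrt(1 - |p|^2), converted to ray-parameter units by
    # 1 / |ray_d|; d1 + d2 is then the depth of the far sphere intersection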
d2 = torch.sqrt(1.0 - p_norm_sq) * ray_d_cos
return d1 + d2
def perturb_samples(z_vals):
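    """
    Stratified jitter: draw one uniform sample inside each interval between
    neighbouring z_vals so training does not always see the same fixed depths.
    z_vals: [..., N_samples]; returns a tensor of the same shape.
    """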
# get intervals between samples
mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)
lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)
# uniform samples in those intervals
t_rand = torch.rand_like(z_vals)
z_vals = lower + (upper - lower) * t_rand # [N_rays, N_samples]
return z_vals
def sample_pdf(bins, weights, N_samples, det=False):
"""
:param bins: tensor of shape [..., M+1], M is the number of bins
:param weights: tensor of shape [..., M]
:param N_samples: number of samples along each ray
:param det: if True, will perform deterministic sampling
:return: [..., N_samples]
"""
# Get pdf
weights = weights + TINY_NUMBER # prevent nans
pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # [..., M]
cdf = torch.cumsum(pdf, dim=-1) # [..., M]
cdf = torch.cat(
[torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1
) # [..., M+1]
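    # inverse transform sampling: uniform samples u are mapped through the
    # inverse CDF below, so bins with larger weights receive more samples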
# Take uniform samples
dots_sh = list(weights.shape[:-1])
M = weights.shape[-1]
min_cdf = 0.00
max_cdf = 1.00 # prevent outlier samples
if det:
u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device)
        u = u.view([1] * len(dots_sh) + [N_samples]).expand(
            dots_sh + [N_samples]
        )  # [..., N_samples]
else:
sh = dots_sh + [N_samples]
u = (
torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf
) # [..., N_samples]
# Invert CDF
# [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,]
above_inds = torch.sum(
u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1
).long()
# random sample inside each bin
below_inds = torch.clamp(above_inds - 1, min=0)
inds_g = torch.stack(
(below_inds, above_inds), dim=-1
) # [..., N_samples, 2]
cdf = cdf.unsqueeze(-2).expand(
dots_sh + [N_samples, M + 1]
) # [..., N_samples, M+1]
cdf_g = torch.gather(
input=cdf, dim=-1, index=inds_g
) # [..., N_samples, 2]
bins = bins.unsqueeze(-2).expand(
dots_sh + [N_samples, M + 1]
) # [..., N_samples, M+1]
bins_g = torch.gather(
input=bins, dim=-1, index=inds_g
) # [..., N_samples, 2]
# fix numeric issue
denom = cdf_g[..., 1] - cdf_g[..., 0] # [..., N_samples]
denom = torch.where(denom < TINY_NUMBER, torch.ones_like(denom), denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + t * (
bins_g[..., 1] - bins_g[..., 0] + TINY_NUMBER
)
return samples
def render_single_image(rank, world_size, models, ray_sampler, chunk_size):
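    """
    Render one full image cooperatively: the image rays are split evenly
    across ranks, each rank renders its share in chunks of chunk_size rays,
    and rank 0 gathers the per-pixel outputs and reshapes them to [H, W, ...].
    Only rank 0 returns the merged results; all other ranks return None.
    """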
##### parallel rendering of a single image
ray_batch = ray_sampler.get_all()
if (ray_batch["ray_d"].shape[0] // world_size) * world_size != ray_batch[
"ray_d"
].shape[0]:
raise Exception(
"Number of pixels in the image is not divisible by the number of GPUs!\n\t# pixels: {}\n\t# GPUs: {}".format(
ray_batch["ray_d"].shape[0], world_size
)
)
# split into ranks; make sure different processes don't overlap
rank_split_sizes = [
ray_batch["ray_d"].shape[0] // world_size,
] * world_size
rank_split_sizes[-1] = ray_batch["ray_d"].shape[0] - sum(
rank_split_sizes[:-1]
)
for key in ray_batch:
if torch.is_tensor(ray_batch[key]):
ray_batch[key] = torch.split(ray_batch[key], rank_split_sizes)[
rank
].to(rank)
# split into chunks and render inside each process
ray_batch_split = OrderedDict()
for key in ray_batch:
if torch.is_tensor(ray_batch[key]):
ray_batch_split[key] = torch.split(ray_batch[key], chunk_size)
# forward and backward
ret_merge_chunk = [OrderedDict() for _ in range(models["cascade_level"])]
for s in range(len(ray_batch_split["ray_d"])):
ray_o = ray_batch_split["ray_o"][s]
ray_d = ray_batch_split["ray_d"][s]
min_depth = ray_batch_split["min_depth"][s]
dots_sh = list(ray_d.shape[:-1])
for m in range(models["cascade_level"]):
net = models["net_{}".format(m)]
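            # coarse-to-fine cascade: level 0 samples depths uniformly, later
            # levels importance-sample from the previous level's weights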
# sample depths
N_samples = models["cascade_samples"][m]
if m == 0:
# foreground depth
fg_far_depth = intersect_sphere(ray_o, ray_d) # [...,]
fg_near_depth = min_depth # [..., ]
step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
fg_depth = torch.stack(
[fg_near_depth + i * step for i in range(N_samples)],
dim=-1,
) # [..., N_samples]
# background depth
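                # background depths are uniform in [0, 1]; the background
                # network maps them to points outside the unit sphere using
                # NeRF++'s inverted-sphere parameterization (see ddp_model)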
                bg_depth = (
                    torch.linspace(0.0, 1.0, N_samples)
                    .view([1] * len(dots_sh) + [N_samples])
                    .expand(dots_sh + [N_samples])
                    .to(rank)
                )
# delete unused memory
del fg_near_depth
del step
torch.cuda.empty_cache()
else:
# sample pdf and concat with earlier samples
fg_weights = ret["fg_weights"].clone().detach()
fg_depth_mid = 0.5 * (
fg_depth[..., 1:] + fg_depth[..., :-1]
) # [..., N_samples-1]
fg_weights = fg_weights[..., 1:-1] # [..., N_samples-2]
fg_depth_samples = sample_pdf(
bins=fg_depth_mid,
weights=fg_weights,
N_samples=N_samples,
det=True,
) # [..., N_samples]
fg_depth, _ = torch.sort(
torch.cat((fg_depth, fg_depth_samples), dim=-1)
)
# sample pdf and concat with earlier samples
bg_weights = ret["bg_weights"].clone().detach()
bg_depth_mid = 0.5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
bg_weights = bg_weights[..., 1:-1] # [..., N_samples-2]
bg_depth_samples = sample_pdf(
bins=bg_depth_mid,
weights=bg_weights,
N_samples=N_samples,
det=True,
) # [..., N_samples]
bg_depth, _ = torch.sort(
torch.cat((bg_depth, bg_depth_samples), dim=-1)
)
# delete unused memory
del fg_weights
del fg_depth_mid
del fg_depth_samples
del bg_weights
del bg_depth_mid
del bg_depth_samples
torch.cuda.empty_cache()
with torch.no_grad():
ret = net(ray_o, ray_d, fg_far_depth, fg_depth, bg_depth)
for key in ret:
if key not in ["fg_weights", "bg_weights"]:
if torch.is_tensor(ret[key]):
if key not in ret_merge_chunk[m]:
ret_merge_chunk[m][key] = [
ret[key].cpu(),
]
else:
ret_merge_chunk[m][key].append(ret[key].cpu())
ret[key] = None
# clean unused memory
torch.cuda.empty_cache()
# merge results from different chunks
for m in range(len(ret_merge_chunk)):
for key in ret_merge_chunk[m]:
ret_merge_chunk[m][key] = torch.cat(ret_merge_chunk[m][key], dim=0)
# merge results from different processes
if rank == 0:
ret_merge_rank = [OrderedDict() for _ in range(len(ret_merge_chunk))]
for m in range(len(ret_merge_chunk)):
for key in ret_merge_chunk[m]:
# generate tensors to store results from other processes
sh = list(ret_merge_chunk[m][key].shape[1:])
                ret_merge_rank[m][key] = [
                    torch.zeros(*([size] + sh), dtype=torch.float32)
                    for size in rank_split_sizes
                ]
torch.distributed.gather(
ret_merge_chunk[m][key], ret_merge_rank[m][key]
)
ret_merge_rank[m][key] = (
torch.cat(ret_merge_rank[m][key], dim=0)
.reshape((ray_sampler.H, ray_sampler.W, -1))
.squeeze()
)
# print(m, key, ret_merge_rank[m][key].shape)
else: # send results to main process
for m in range(len(ret_merge_chunk)):
for key in ret_merge_chunk[m]:
torch.distributed.gather(ret_merge_chunk[m][key])
    # only the rank-0 process returns results
if rank == 0:
return ret_merge_rank
else:
return None
def log_view_to_tb(writer, global_step, log_data, gt_img, mask, prefix=""):
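    """
    Write the ground-truth image and, for every cascade level, the predicted
    rgb / fg_rgb / bg_rgb images plus colorized fg/bg depth and bg_lambda
    maps to TensorBoard under the given prefix.
    """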
rgb_im = img_HWC2CHW(torch.from_numpy(gt_img))
writer.add_image(prefix + "rgb_gt", rgb_im, global_step)
for m in range(len(log_data)):
rgb_im = img_HWC2CHW(log_data[m]["rgb"])
rgb_im = torch.clamp(
rgb_im, min=0.0, max=1.0
) # just in case diffuse+specular>1
writer.add_image(
prefix + "level_{}/rgb".format(m), rgb_im, global_step
)
rgb_im = img_HWC2CHW(log_data[m]["fg_rgb"])
rgb_im = torch.clamp(
rgb_im, min=0.0, max=1.0
) # just in case diffuse+specular>1
writer.add_image(
prefix + "level_{}/fg_rgb".format(m), rgb_im, global_step
)
depth = log_data[m]["fg_depth"]
depth_im = img_HWC2CHW(
colorize(depth, cmap_name="jet", append_cbar=True, mask=mask)
)
writer.add_image(
prefix + "level_{}/fg_depth".format(m), depth_im, global_step
)
rgb_im = img_HWC2CHW(log_data[m]["bg_rgb"])
rgb_im = torch.clamp(
rgb_im, min=0.0, max=1.0
) # just in case diffuse+specular>1
writer.add_image(
prefix + "level_{}/bg_rgb".format(m), rgb_im, global_step
)
depth = log_data[m]["bg_depth"]
depth_im = img_HWC2CHW(
colorize(depth, cmap_name="jet", append_cbar=True, mask=mask)
)
writer.add_image(
prefix + "level_{}/bg_depth".format(m), depth_im, global_step
)
bg_lambda = log_data[m]["bg_lambda"]
bg_lambda_im = img_HWC2CHW(
colorize(bg_lambda, cmap_name="hot", append_cbar=True, mask=mask)
)
writer.add_image(
prefix + "level_{}/bg_lambda".format(m), bg_lambda_im, global_step
)
def setup(rank, world_size):
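    """Initialize the default process group (gloo backend) for this rank."""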
os.environ["MASTER_ADDR"] = "localhost"
# port = np.random.randint(12355, 12399)
# os.environ['MASTER_PORT'] = '{}'.format(port)
os.environ["MASTER_PORT"] = "12355"
# initialize the process group
torch.distributed.init_process_group(
"gloo", rank=rank, world_size=world_size
)
def cleanup():
torch.distributed.destroy_process_group()
def create_nerf(rank, args):
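    """
    Build one DDP-wrapped NerfNetWithAutoExpo network and its Adam optimizer
    per cascade level, optionally reload the most recent checkpoint, and
    return (start_iteration, models). Every process calls this with the same
    manual seed so all ranks start from identical weights.
    """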
###### create network and wrap in ddp; each process should do this
    # fix the random seed so the network is initialized with the same
    # weights in every process
torch.manual_seed(777)
    # important: bind this process to its own GPU, otherwise extra memory
    # may end up allocated on the rank-0 GPU
torch.cuda.set_device(rank)
models = OrderedDict()
models["cascade_level"] = args.cascade_level
models["cascade_samples"] = [
int(x.strip()) for x in args.cascade_samples.split(",")
]
for m in range(models["cascade_level"]):
img_names = None
if args.optim_autoexpo:
# load training image names for autoexposure
f = os.path.join(args.basedir, args.expname, "train_images.json")
with open(f) as file:
img_names = json.load(file)
net = NerfNetWithAutoExpo(
args, optim_autoexpo=args.optim_autoexpo, img_names=img_names
).to(rank)
net = DDP(
net,
device_ids=[rank],
output_device=rank,
find_unused_parameters=True,
)
# net = DDP(net, device_ids=[rank], output_device=rank)
optim = torch.optim.Adam(net.parameters(), lr=args.lrate)
models["net_{}".format(m)] = net
models["optim_{}".format(m)] = optim
start = -1
###### load pretrained weights; each process should do this
if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)):
ckpts = [args.ckpt_path]
else:
ckpts = [
os.path.join(args.basedir, args.expname, f)
for f in sorted(
os.listdir(os.path.join(args.basedir, args.expname))
)
if f.endswith(".pth")
]
def path2iter(path):
tmp = os.path.basename(path)[:-4]
idx = tmp.rfind("_")
return int(tmp[idx + 1 :])
ckpts = sorted(ckpts, key=path2iter)
logger.info("Found ckpts: {}".format(ckpts))
if len(ckpts) > 0 and not args.no_reload:
fpath = ckpts[-1]
logger.info("Reloading from: {}".format(fpath))
start = path2iter(fpath)
# configure map_location properly for different processes
map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
to_load = torch.load(fpath, map_location=map_location)
for m in range(models["cascade_level"]):
for name in ["net_{}".format(m), "optim_{}".format(m)]:
models[name].load_state_dict(to_load[name])
return start, models
def ddp_train_nerf(rank, args):
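    """
    Per-process training worker (spawned once per rank): sets up DDP, loads
    the train/validation splits, runs the cascaded ray-sampling optimization
    loop, and on rank 0 handles TensorBoard logging and checkpoint saving.
    """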
###### set up multi-processing
setup(rank, args.world_size)
###### set up logger
logger = logging.getLogger(__package__)
setup_logger()
###### decide chunk size according to gpu memory
logger.info(
"gpu_mem: {}".format(
torch.cuda.get_device_properties(rank).total_memory
)
)
if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14:
logger.info("setting batch size according to 24G gpu")
args.N_rand = 1024
args.chunk_size = 8192
else:
logger.info("setting batch size according to 12G gpu")
args.N_rand = 512
args.chunk_size = 4096
###### Create log dir and copy the config file
if rank == 0:
os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True)
f = os.path.join(args.basedir, args.expname, "args.txt")
with open(f, "w") as file:
for arg in sorted(vars(args)):
attr = getattr(args, arg)
file.write("{} = {}\n".format(arg, attr))
        if args.config is not None:
            f = os.path.join(args.basedir, args.expname, "config.txt")
            with open(f, "w") as file, open(args.config, "r") as src:
                file.write(src.read())
torch.distributed.barrier()
ray_samplers = load_data_split(
args.datadir,
args.scene,
split="train",
try_load_min_depth=args.load_min_depth,
)
val_ray_samplers = load_data_split(
args.datadir,
args.scene,
split="validation",
try_load_min_depth=args.load_min_depth,
skip=args.testskip,
)
# write training image names for autoexposure
if args.optim_autoexpo:
f = os.path.join(args.basedir, args.expname, "train_images.json")
with open(f, "w") as file:
img_names = [
ray_samplers[i].img_path for i in range(len(ray_samplers))
]
json.dump(img_names, file, indent=2)
###### create network and wrap in ddp; each process should do this
start, models = create_nerf(rank, args)
##### important!!!
# make sure different processes sample different rays
np.random.seed((rank + 1) * 777)
# make sure different processes have different perturbations in depth samples
torch.manual_seed((rank + 1) * 777)
##### only main process should do the logging
if rank == 0:
writer = SummaryWriter(
os.path.join(args.basedir, "summaries", args.expname)
)
# start training
    what_val_to_log = 0  # helper variable for parallel rendering of an image
what_train_to_log = 0
for global_step in tqdm(range(start + 1, start + 1 + args.N_iters)):
time0 = time.time()
scalars_to_log = OrderedDict()
### Start of core optimization loop
scalars_to_log["resolution"] = ray_samplers[0].resolution_level
# randomly sample rays and move to device
i = np.random.randint(low=0, high=len(ray_samplers))
ray_batch = ray_samplers[i].random_sample(
args.N_rand, center_crop=False
)
for key in ray_batch:
if torch.is_tensor(ray_batch[key]):
ray_batch[key] = ray_batch[key].to(rank)
# forward and backward
dots_sh = list(ray_batch["ray_d"].shape[:-1]) # number of rays
all_rets = [] # results on different cascade levels
for m in range(models["cascade_level"]):
optim = models["optim_{}".format(m)]
net = models["net_{}".format(m)]
# sample depths
N_samples = models["cascade_samples"][m]
if m == 0:
# foreground depth
fg_far_depth = intersect_sphere(
ray_batch["ray_o"], ray_batch["ray_d"]
) # [...,]
fg_near_depth = ray_batch["min_depth"] # [..., ]
step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
fg_depth = torch.stack(
[fg_near_depth + i * step for i in range(N_samples)],
dim=-1,
) # [..., N_samples]
fg_depth = perturb_samples(
fg_depth
) # random perturbation during training
# background depth
                bg_depth = (
                    torch.linspace(0.0, 1.0, N_samples)
                    .view([1] * len(dots_sh) + [N_samples])
                    .expand(dots_sh + [N_samples])
                    .to(rank)
                )
bg_depth = perturb_samples(
bg_depth
) # random perturbation during training
else:
# sample pdf and concat with earlier samples
fg_weights = ret["fg_weights"].clone().detach()
fg_depth_mid = 0.5 * (
fg_depth[..., 1:] + fg_depth[..., :-1]
) # [..., N_samples-1]
fg_weights = fg_weights[..., 1:-1] # [..., N_samples-2]
fg_depth_samples = sample_pdf(
bins=fg_depth_mid,
weights=fg_weights,
N_samples=N_samples,
det=False,
) # [..., N_samples]
fg_depth, _ = torch.sort(
torch.cat((fg_depth, fg_depth_samples), dim=-1)
)
# sample pdf and concat with earlier samples
bg_weights = ret["bg_weights"].clone().detach()
bg_depth_mid = 0.5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
bg_weights = bg_weights[..., 1:-1] # [..., N_samples-2]
bg_depth_samples = sample_pdf(
bins=bg_depth_mid,
weights=bg_weights,
N_samples=N_samples,
det=False,
) # [..., N_samples]
bg_depth, _ = torch.sort(
torch.cat((bg_depth, bg_depth_samples), dim=-1)
)
optim.zero_grad()
ret = net(
ray_batch["ray_o"],
ray_batch["ray_d"],
fg_far_depth,
fg_depth,
bg_depth,
img_name=ray_batch["img_name"],
)
all_rets.append(ret)
rgb_gt = ray_batch["rgb"].to(rank)
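            # with autoexposure, undo the predicted per-image scale/shift
            # before the MSE; the extra loss term keeps scale near 1 and
            # shift near 0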
if "autoexpo" in ret:
scale, shift = ret["autoexpo"]
                scalars_to_log["level_{}/autoexpo_scale".format(m)] = scale.item()
                scalars_to_log["level_{}/autoexpo_shift".format(m)] = shift.item()
# rgb_gt = scale * rgb_gt + shift
rgb_pred = (ret["rgb"] - shift) / scale
rgb_loss = img2mse(rgb_pred, rgb_gt)
loss = rgb_loss + args.lambda_autoexpo * (
torch.abs(scale - 1.0) + torch.abs(shift)
)
else:
rgb_loss = img2mse(ret["rgb"], rgb_gt)
loss = rgb_loss
scalars_to_log["level_{}/loss".format(m)] = rgb_loss.item()
scalars_to_log["level_{}/pnsr".format(m)] = mse2psnr(
rgb_loss.item()
)
loss.backward()
optim.step()
# # clean unused memory
# torch.cuda.empty_cache()
### end of core optimization loop
dt = time.time() - time0
scalars_to_log["iter_time"] = dt
### only main process should do the logging
if rank == 0 and (global_step % args.i_print == 0 or global_step < 10):
logstr = "{} step: {} ".format(args.expname, global_step)
for k in scalars_to_log:
logstr += " {}: {:.6f}".format(k, scalars_to_log[k])
writer.add_scalar(k, scalars_to_log[k], global_step)
logger.info(logstr)
### each process should do this; but only main process merges the results
if global_step % args.i_img == 0 or global_step == start + 1:
#### critical: make sure each process is working on the same random image
time0 = time.time()
idx = what_val_to_log % len(val_ray_samplers)
log_data = render_single_image(
rank,
args.world_size,
models,
val_ray_samplers[idx],
args.chunk_size,
)
what_val_to_log += 1
dt = time.time() - time0
if rank == 0: # only main process should do this
logger.info(
"Logged a random validation view in {} seconds".format(dt)
)
log_view_to_tb(
writer,
global_step,
log_data,
gt_img=val_ray_samplers[idx].get_img(),
mask=None,
prefix="val/",
)
time0 = time.time()
idx = what_train_to_log % len(ray_samplers)
log_data = render_single_image(
rank,
args.world_size,
models,
ray_samplers[idx],
args.chunk_size,
)
what_train_to_log += 1
dt = time.time() - time0
if rank == 0: # only main process should do this
logger.info(
"Logged a random training view in {} seconds".format(dt)
)
log_view_to_tb(
writer,
global_step,
log_data,
gt_img=ray_samplers[idx].get_img(),
mask=None,
prefix="train/",
)
del log_data
torch.cuda.empty_cache()
if rank == 0 and (
global_step % args.i_weights == 0 and global_step > 0
):
# saving checkpoints and logging
fpath = os.path.join(
args.basedir,
args.expname,
"model_{:06d}.pth".format(global_step),
)
to_save = OrderedDict()
for m in range(models["cascade_level"]):
name = "net_{}".format(m)
to_save[name] = models[name].state_dict()
name = "optim_{}".format(m)
to_save[name] = models[name].state_dict()
torch.save(to_save, fpath)
# clean up for multi-processing
cleanup()
def config_parser():
import configargparse
parser = configargparse.ArgumentParser()
parser.add_argument(
"--config",
is_config_file=True,
help="config file path",
)
parser.add_argument(
"--expname",
type=str,
help="experiment name",
)
parser.add_argument(
"--basedir",
type=str,
default="./logs/",
help="where to store ckpts and logs",
)
# dataset options
parser.add_argument(
"--datadir",
type=str,
default=None,
help="input data directory",
)
parser.add_argument(
"--scene",
type=str,
default=None,
help="scene name",
)
parser.add_argument(
"--testskip",
type=int,
default=8,
help="will load 1/N images from test/val sets, useful for large datasets like deepvoxels",
)
# model size
parser.add_argument(
"--netdepth",
type=int,
default=8,
help="layers in coarse network",
)
parser.add_argument(
"--netwidth",
type=int,
default=256,
help="channels per layer in coarse network",
)
parser.add_argument(
"--use_viewdirs",
action="store_true",
help="use full 5D input instead of 3D",
)
# checkpoints
parser.add_argument(
"--no_reload",
action="store_true",
help="do not reload weights from saved ckpt",
)
parser.add_argument(
"--ckpt_path",
type=str,
default=None,
help="specific weights npy file to reload for coarse network",
)
# batch size
parser.add_argument(
"--N_rand",
type=int,
default=32 * 32 * 2,
help="batch size (number of random rays per gradient step)",
)
parser.add_argument(
"--chunk_size",
type=int,
default=1024 * 8,
help="number of rays processed in parallel, decrease if running out of memory",
)
# iterations
parser.add_argument(
"--N_iters",
type=int,
default=250001,
help="number of iterations",
)
# render only
parser.add_argument(
"--render_splits",
type=str,
default="test",
help="splits to render",
)
# cascade training
parser.add_argument(
"--cascade_level",
type=int,
default=2,
help="number of cascade levels",
)
parser.add_argument(
"--cascade_samples",
type=str,
default="64,64",
help="samples at each level",
)
# multiprocess learning
    parser.add_argument(
        "--world_size",
        type=int,
        default=-1,
        help="number of processes (GPUs); -1 uses every available GPU",
    )
# optimize autoexposure
parser.add_argument(
"--optim_autoexpo",
action="store_true",
help="optimize autoexposure parameters",
)
parser.add_argument(
"--lambda_autoexpo",
type=float,
default=1.0,
help="regularization weight for autoexposure",
)
# learning rate options
parser.add_argument(
"--lrate",
type=float,
default=5e-4,
help="learning rate",
)
    parser.add_argument(
        "--lrate_decay_factor",
        type=float,
        default=0.1,
        help="factor by which the learning rate is decayed",
    )
    parser.add_argument(
        "--lrate_decay_steps",
        type=int,
        default=5000,
        help="decay the learning rate by lrate_decay_factor every this many steps",
    )
# rendering options
parser.add_argument(
"--det",
action="store_true",
help="deterministic sampling for coarse and fine samples",
)
parser.add_argument(
"--max_freq_log2",
type=int,
default=10,
help="log2 of max freq for positional encoding (3D location)",
)
parser.add_argument(
"--max_freq_log2_viewdirs",
type=int,
default=4,
help="log2 of max freq for positional encoding (2D direction)",
)
parser.add_argument(
"--load_min_depth",
action="store_true",
help="whether to load min depth",
)
# logging/saving options
parser.add_argument(
"--i_print",
type=int,
default=100,
help="frequency of console printout and metric loggin",
)
parser.add_argument(
"--i_img",
type=int,
default=500,
help="frequency of tensorboard image logging",
)
parser.add_argument(
"--i_weights",
type=int,
default=10000,
help="frequency of weight ckpt saving",
)
return parser
def train():
parser = config_parser()
args = parser.parse_args()
logger.info(parser.format_values())
if args.world_size == -1:
args.world_size = torch.cuda.device_count()
logger.info("Using # gpus: {}".format(args.world_size))
torch.multiprocessing.spawn(
ddp_train_nerf, args=(args,), nprocs=args.world_size, join=True
)
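

# Example launch (hypothetical config path; any configargparse-style file
# providing the options defined above will work):
#   python ddp_train_nerf.py --config configs/my_scene.txt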
if __name__ == "__main__":
setup_logger()
train()