# nerf_plus_plus/voxelisation.py
# Snapshot: 2021-06-23 18:57:41 +02:00 — 173 lines, 4.1 KiB, Python
import os
import mcubes
import numpy as np
import torch
import torch.distributed
import trimesh
from tqdm import tqdm
from ddp_train_nerf import (
cleanup,
config_parser,
create_nerf,
setup,
setup_logger,
)
# Parse the standard NeRF++ command-line configuration.
parser = config_parser()
args = parser.parse_args()
# hardcode settings: force a single, non-distributed process
args.world_size = 1
args.rank = 0
# setup distributed process state, then build the NeRF model(s)
setup(args.rank, args.world_size)
start, models = create_nerf(args.rank, args)
net_0 = models["net_0"]
fg_far_depth = 1  # NOTE(review): assigned but not used in the visible code — confirm
# Pick the foreground embedders / MLP out of net_0 by their position in
# the module tree.  Weird way to do it (the indices depend on the exact
# submodule ordering of net_0); should be changed if something better exists.
for idx, m in enumerate(net_0.modules()):
    # print(idx, "->", m)
    # foreground
    if idx == 3:
        fg_embedder_position = m
    if idx == 4:
        fg_embedder_viewdir = m
    if idx == 5:
        fg_mlp_net = m
    # background
    # if idx == 40:
    # bg_embedder_position = m
    # if idx == 41:
    # bg_embedder_viewdir = m
    # if idx == 42:
    # bg_mlp_net = m
# put everything on GPU
device = "cuda"
def query_occupancy(
    position, embedder_position, embedder_viewdir, mlp_net, device="cuda"
):
    """
    Return the occupancy (sigma) predicted by the network at a position.

    Given a position, the appropriate embedders and the MLPNet, embed the
    inputs, forward the network once, and return the raw sigma value.

    Parameters
    ----------
    position : torch.Tensor
        A (x, y, z) tensor of the position to query.
    embedder_position, embedder_viewdir : nerf_network.Embedder
        Positional and view-direction embedders.
    mlp_net : nerf_network.MLPNet
        A simple MLP implementation written for NeRF.
    device : str, optional
        The torch device, can be either `cpu` or `cuda`.

    Returns
    -------
    sigma : float
        The occupancy at the given position.
    """
    # Sigma does not depend on the view direction, so any (normalized)
    # random direction is good enough here.
    ray_d = torch.rand(3, device=device)
    ray_d = ray_d / torch.norm(ray_d)
    # BUG FIX: embed with the embedders passed as arguments — the original
    # used the module-level fg_embedder_* globals, silently ignoring the
    # `embedder_position` / `embedder_viewdir` parameters.
    nn_input = torch.cat(
        (embedder_position(position), embedder_viewdir(ray_d)), dim=-1
    )
    # Forward the network; it returns a dict containing a "sigma" entry.
    nn_raw = mlp_net(nn_input)
    return float(nn_raw["sigma"])
# Convenience wrapper: foreground occupancy at scalar coordinates (x, y, z).
def f(x, y, z):
    """Query the foreground network's sigma at the point (x, y, z)."""
    return query_occupancy(
        torch.tensor([x, y, z], dtype=torch.float32, device=device),
        fg_embedder_position,
        fg_embedder_viewdir,
        # BUG FIX: was `mlp_net`, a name that is never defined at module
        # level (NameError on first call); the extracted MLP is fg_mlp_net.
        fg_mlp_net,
    )
def marching_cube_and_render(sigma_list, threshold):
    """Extract an iso-surface from a sigma grid and display it.

    Runs marching cubes on ``sigma_list`` at the given ``threshold`` and
    opens an interactive trimesh viewer.  NOTE(review): relies on the
    module-level grid resolution ``N`` to rescale vertices into the
    [-0.5, 0.5] cube — ``N`` must be defined before calling.
    """
    verts, tris = mcubes.marching_cubes(sigma_list, threshold)
    scaled_verts = verts / N - 0.5
    trimesh.Trimesh(scaled_verts, tris).show()
# position = torch.rand(3, device=device)
# position = torch.tensor([0.1, 0.1, 0.1], device=device)
# One fixed random view direction reused for the whole sweep (sigma is
# assumed not to depend on it).
ray_d = torch.rand(3, device=device)
# normalize ray direction
ray_d_norm = torch.norm(ray_d)
ray_d = ray_d / ray_d_norm
# Sample a regular (N+1)^3 grid over the cube [-1, 1]^3.
N = 100
t = np.linspace(-1, 1, N + 1)
query_pts = np.stack(np.meshgrid(t, t, t), -1).astype(np.float32)
# print(query_pts.shape)
sh = query_pts.shape  # NOTE(review): kept but unused below — confirm
flat = query_pts.reshape([-1, 3])  # (..., 3) grid flattened to one point per row
# raw_voxel = torch.zeros(N+1, N+1, N+1, 4) # N, D, H, W
# One sigma value per grid point (foreground network only).
fg_raw_voxel = torch.zeros(N + 1, N + 1, N + 1)
# bg_raw_voxel = torch.zeros(N+1, N+1, N+1)
i = 0  # flat index, unraveled back to (nx, ny, nz) inside the loop
for x, y, z in tqdm(flat):
    position = torch.tensor([x, y, z], device=device)
    # bg_position = torch.cat((position, torch.tensor([1], device=device)))
    # concat the output of the embedding
    fg_input = torch.cat(
        (fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1
    )
    # bg_input = torch.cat((bg_embedder_position(bg_position), bg_embedder_viewdir(ray_d)), dim=-1)
    # forward the foreground network; it returns a dict with a "sigma" entry
    fg_raw = fg_mlp_net(fg_input)
    # bg_raw = bg_mlp_net(bg_input)
    # raw_voxel.append(position + float(nn_raw['sigma']))
    fg_sigma = float(fg_raw["sigma"])
    # bg_sigma = float(bg_raw["sigma"])
    # Map the flat index back to 3-D grid coordinates and store sigma there.
    nx, ny, nz = np.unravel_index(i, (N + 1, N + 1, N + 1))
    i += 1  # update index
    # raw_voxel[unraveled_index] = torch.tensor([sigma, x, y, z])
    fg_raw_voxel[nx, ny, nz] = fg_sigma
    # bg_raw_voxel[nx, ny, nz] = bg_sigma
# Convert the filled voxel grid to numpy for marching cubes.
# NOTE(review): this rebinds fg_sigma from the per-point float in the loop
# above to the whole grid array.
fg_sigma = np.array(fg_raw_voxel)
# bg_sigma = np.array(bg_raw_voxel)
threshold = 0.5
# vertices, triangles = mcubes.marching_cubes(sigma, threshold)
# mesh = trimesh.Trimesh(vertices / N - .5, triangles)
# mesh.show()