From 33adb79bf939f3be6d53d740a45aacd3bd293d5c Mon Sep 17 00:00:00 2001 From: Solal Nathan Date: Wed, 23 Jun 2021 18:55:31 +0200 Subject: [PATCH] voxelisation and marching cubes algorithms --- voxelisation.py | 154 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 voxelisation.py diff --git a/voxelisation.py b/voxelisation.py new file mode 100644 index 0000000..086c03d --- /dev/null +++ b/voxelisation.py @@ -0,0 +1,154 @@ +import os + +import mcubes +import numpy as np +import torch +import torch.distributed +import trimesh +from tqdm import tqdm + +from ddp_train_nerf import (cleanup, config_parser, create_nerf, setup, + setup_logger) + +parser = config_parser() +args = parser.parse_args() + +# hardcode settings +args.world_size = 1 +args.rank = 0 + +# setup +setup(args.rank, args.world_size) +start, models = create_nerf(args.rank, args) + +net_0 = models['net_0'] + +fg_far_depth = 1 + +# weird way to do it, should be change if something better exists +for idx, m in enumerate(net_0.modules()): + #print(idx, "->", m) + + # foreground + if idx == 3: + fg_embedder_position = m + if idx == 4: + fg_embedder_viewdir = m + if idx == 5: + fg_mlp_net = m + + # background +# if idx == 40: +# bg_embedder_position = m +# if idx == 41: +# bg_embedder_viewdir = m +# if idx == 42: +# bg_mlp_net = m + +# put everything on GPU +device = "cuda" + +def query_occupancy(position, embedder_position, embedder_viewdir, mlp_net, device="cuda"): + """ + Given a position returns the occupancy probabily of the network. + + Given a poisition, appropriate embedders and the MLPNet, return the + corresponding occupancy. + + Parameters + ---------- + position : torch.tensor + A (x,y,z) tensor of the position to query + embedder_position, embedder_viewder : nerf_network.Embedder + Positional and view directions embedders + mlp_net : nerf_network.MLPNet + A simple MLP implementation written for NeRF + device : str, optional + The torch device, can be either `cpu` or `cuda` + + Returns + ------- + sigma : float + The occupancy at the given position. + """ + + # take a random ray direction as it does not matter for sigma + ray_d = torch.rand(3, device=device) + + # normalize ray direction + ray_d_norm = torch.norm(ray_d) + ray_d = ray_d / ray_d_norm + + # forge the input + nn_input = torch.cat((fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1) + + # forward the NN + nn_raw = mlp_net(nn_input) + sigma = float(nn_raw["sigma"]) + + return sigma + +# annonymous function +f = lambda x, y, z: query_occupancy(torch.tensor([x,y,z], dtype=torch.float32, device=device), fg_embedder_position, fg_embedder_viewdir, mlp_net) + +def marching_cube_and_render(sigma_list, threshold): + + vertices, triangles = mcubes.marching_cubes(sigma_list, threshold) + mesh = trimesh.Trimesh(vertices / N - .5, triangles) + mesh.show() + + +#position = torch.rand(3, device=device) +#position = torch.tensor([0.1, 0.1, 0.1], device=device) + +ray_d = torch.rand(3, device=device) +# normalize ray direction +ray_d_norm = torch.norm(ray_d) +ray_d = ray_d / ray_d_norm + + +N = 100 +t = np.linspace(-1, 1, N+1) + +query_pts = np.stack(np.meshgrid(t, t, t), -1).astype(np.float32) +#print(query_pts.shape) +sh = query_pts.shape +flat = query_pts.reshape([-1,3]) + +#raw_voxel = torch.zeros(N+1, N+1, N+1, 4) # N, D, H, W +fg_raw_voxel = torch.zeros(N+1, N+1, N+1) +#bg_raw_voxel = torch.zeros(N+1, N+1, N+1) + +i = 0 +for x,y,z in tqdm(flat): + + position = torch.tensor([x,y,z], device=device) + #bg_position = torch.cat((position, torch.tensor([1], device=device))) + + # concat the output of the embedding + fg_input = torch.cat((fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1) + #bg_input = torch.cat((bg_embedder_position(bg_position), bg_embedder_viewdir(ray_d)), dim=-1) + + # forward + fg_raw = fg_mlp_net(fg_input) + #bg_raw = bg_mlp_net(bg_input) + + #raw_voxel.append(position + float(nn_raw['sigma'])) + fg_sigma = float(fg_raw["sigma"]) + #bg_sigma = float(bg_raw["sigma"]) + + nx, ny, nz = np.unravel_index(i, (N+1, N+1, N+1)) + i += 1 # update index + #raw_voxel[unraveled_index] = torch.tensor([sigma, x, y, z]) + fg_raw_voxel[nx, ny, nz] = fg_sigma + #bg_raw_voxel[nx, ny, nz] = bg_sigma + + +fg_sigma = np.array(fg_raw_voxel) +#bg_sigma = np.array(bg_raw_voxel) +threshold = 0.5 + + +#vertices, triangles = mcubes.marching_cubes(sigma, threshold) +#mesh = trimesh.Trimesh(vertices / N - .5, triangles) +#mesh.show()