"""Sample the foreground density of a trained NeRF++ model on a regular grid.

The model described by the usual ``ddp_train_nerf`` config is loaded, its
foreground MLP is queried for the volume density (sigma) on an
``(N + 1)**3`` grid spanning ``[-1, 1]**3``, and the result is kept in
``fg_sigma`` so it can be meshed with ``marching_cube_and_render``.
"""
import os

import mcubes
import numpy as np
import torch
import torch.distributed
import trimesh
from tqdm import tqdm

from ddp_train_nerf import (
    cleanup,
    config_parser,
    create_nerf,
    setup,
    setup_logger,
)

# Torch device used for all queries.
device = "cuda"


def _random_unit_direction(device="cuda"):
    """Return a random unit-length 3-vector on *device*.

    The MLP input needs a view-direction embedding, but the direction does
    not matter for sigma, so any valid unit vector will do.
    """
    ray_d = torch.rand(3, device=device)
    return ray_d / torch.norm(ray_d)


def query_occupancy(
    position, embedder_position, embedder_viewdir, mlp_net, device="cuda"
):
    """Return the volume density (sigma) predicted by a NeRF MLP at a point.

    Parameters
    ----------
    position : torch.Tensor
        An ``(x, y, z)`` tensor of the position to query, already on *device*.
    embedder_position, embedder_viewdir : nerf_network.Embedder
        Positional and view-direction embedders matching *mlp_net*.
    mlp_net : nerf_network.MLPNet
        The NeRF MLP; its output must be a mapping with a ``"sigma"`` entry.
    device : str, optional
        Torch device for the (random) viewing direction, ``cpu`` or ``cuda``.

    Returns
    -------
    sigma : float
        The density at *position*.
    """
    ray_d = _random_unit_direction(device)
    # Inference only: do not build an autograd graph for every query.
    with torch.no_grad():
        nn_input = torch.cat(
            (embedder_position(position), embedder_viewdir(ray_d)), dim=-1
        )
        nn_raw = mlp_net(nn_input)
    return float(nn_raw["sigma"])


def query_density_grid(
    embedder_position,
    embedder_viewdir,
    mlp_net,
    n=100,
    bound=1.0,
    chunk_size=65536,
    device="cuda",
):
    """Evaluate sigma on a regular ``(n + 1)**3`` grid over ``[-bound, bound]**3``.

    Parameters
    ----------
    embedder_position, embedder_viewdir : nerf_network.Embedder
        Positional and view-direction embedders matching *mlp_net*.
    mlp_net : nerf_network.MLPNet
        The NeRF MLP; its output must be a mapping with a ``"sigma"`` entry.
    n : int, optional
        Number of grid cells per axis (the grid has ``n + 1`` samples per axis).
    bound : float, optional
        Half-extent of the sampled cube.
    chunk_size : int, optional
        Number of points sent through the network per forward pass.
    device : str, optional
        Torch device the networks live on.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(n + 1, n + 1, n + 1)`` indexed ``[ix, iy, iz]``.
    """
    t = np.linspace(-bound, bound, n + 1)
    # indexing="ij" keeps axis 0 = x, axis 1 = y, axis 2 = z. The default "xy"
    # indexing swaps the first two axes, which mirrors the extracted mesh.
    query_pts = np.stack(np.meshgrid(t, t, t, indexing="ij"), -1).astype(np.float32)
    flat = torch.from_numpy(query_pts.reshape(-1, 3)).to(device)
    # One shared direction for all points; it does not affect sigma.
    ray_d = _random_unit_direction(device)

    # NOTE(review): assumes the embedders and MLP operate on the last dimension
    # and broadcast over a leading batch dimension, returning sigma of shape
    # (B,) or (B, 1) -- true for nerf++'s Embedder/MLPNet; confirm for others.
    chunks = []
    with torch.no_grad():
        for begin in tqdm(range(0, flat.shape[0], chunk_size)):
            pts = flat[begin:begin + chunk_size]
            dirs = ray_d.expand(pts.shape[0], 3)
            nn_input = torch.cat(
                (embedder_position(pts), embedder_viewdir(dirs)), dim=-1
            )
            nn_raw = mlp_net(nn_input)
            chunks.append(nn_raw["sigma"].reshape(-1).float().cpu())
    return torch.cat(chunks).reshape(n + 1, n + 1, n + 1).numpy()


def marching_cube_and_render(sigma_list, threshold, show=True):
    """Extract the *threshold* iso-surface of a density grid and display it.

    Parameters
    ----------
    sigma_list : array_like
        Cubic density grid of shape ``(N + 1, N + 1, N + 1)``.
    threshold : float
        Iso-value passed to marching cubes.
    show : bool, optional
        Open trimesh's interactive viewer on the mesh (default ``True``).

    Returns
    -------
    trimesh.Trimesh
        The extracted mesh, vertices rescaled from voxel indices ``0..N`` to
        ``[-0.5, 0.5]``.
    """
    sigma_list = np.asarray(sigma_list)
    # Derive N from the grid itself instead of a module-level global.
    n = sigma_list.shape[0] - 1
    vertices, triangles = mcubes.marching_cubes(sigma_list, threshold)
    mesh = trimesh.Trimesh(vertices / n - 0.5, triangles)
    if show:
        mesh.show()
    return mesh


if __name__ == "__main__":
    parser = config_parser()
    args = parser.parse_args()

    # hardcode settings: a single, non-distributed process
    args.world_size = 1
    args.rank = 0

    setup(args.rank, args.world_size)
    try:
        _, models = create_nerf(args.rank, args)
        net_0 = models["net_0"]

        # HACK: sub-modules are picked by their position in net_0.modules();
        # replace with attribute access if a stable one exists. Indices 3-5
        # are the foreground embedders/MLP; 40-42 were used for the
        # background ones in an earlier experiment.
        modules = list(net_0.modules())
        fg_embedder_position = modules[3]
        fg_embedder_viewdir = modules[4]
        fg_mlp_net = modules[5]

        def f(x, y, z):
            """Return the foreground sigma at ``(x, y, z)``."""
            return query_occupancy(
                torch.tensor([x, y, z], dtype=torch.float32, device=device),
                fg_embedder_position,
                fg_embedder_viewdir,
                fg_mlp_net,
                device=device,
            )

        N = 100
        fg_sigma = query_density_grid(
            fg_embedder_position,
            fg_embedder_viewdir,
            fg_mlp_net,
            n=N,
            device=device,
        )

        threshold = 0.5
        # To view the mesh: marching_cube_and_render(fg_sigma, threshold)
    finally:
        # Release the process group created by setup().
        cleanup()