diff --git a/voxelisation.py b/voxelisation.py index 086c03d..2026263 100644 --- a/voxelisation.py +++ b/voxelisation.py @@ -7,8 +7,13 @@ import torch.distributed import trimesh from tqdm import tqdm -from ddp_train_nerf import (cleanup, config_parser, create_nerf, setup, - setup_logger) +from ddp_train_nerf import ( + cleanup, + config_parser, + create_nerf, + setup, + setup_logger, +) parser = config_parser() args = parser.parse_args() @@ -21,13 +26,13 @@ args.rank = 0 setup(args.rank, args.world_size) start, models = create_nerf(args.rank, args) -net_0 = models['net_0'] +net_0 = models["net_0"] fg_far_depth = 1 # weird way to do it, should be change if something better exists for idx, m in enumerate(net_0.modules()): - #print(idx, "->", m) + # print(idx, "->", m) # foreground if idx == 3: @@ -48,7 +53,10 @@ for idx, m in enumerate(net_0.modules()): # put everything on GPU device = "cuda" -def query_occupancy(position, embedder_position, embedder_viewdir, mlp_net, device="cuda"): + +def query_occupancy( + position, embedder_position, embedder_viewdir, mlp_net, device="cuda" +): """ Given a position returns the occupancy probabily of the network. @@ -74,13 +82,15 @@ def query_occupancy(position, embedder_position, embedder_viewdir, mlp_net, devi # take a random ray direction as it does not matter for sigma ray_d = torch.rand(3, device=device) - + # normalize ray direction ray_d_norm = torch.norm(ray_d) ray_d = ray_d / ray_d_norm # forge the input - nn_input = torch.cat((fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1) + nn_input = torch.cat( + (fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1 + ) # forward the NN nn_raw = mlp_net(nn_input) @@ -88,18 +98,25 @@ def query_occupancy(position, embedder_position, embedder_viewdir, mlp_net, devi return sigma + # annonymous function -f = lambda x, y, z: query_occupancy(torch.tensor([x,y,z], dtype=torch.float32, device=device), fg_embedder_position, fg_embedder_viewdir, mlp_net) +f = lambda x, y, z: query_occupancy( + torch.tensor([x, y, z], dtype=torch.float32, device=device), + fg_embedder_position, + fg_embedder_viewdir, + mlp_net, +) + def marching_cube_and_render(sigma_list, threshold): - + vertices, triangles = mcubes.marching_cubes(sigma_list, threshold) - mesh = trimesh.Trimesh(vertices / N - .5, triangles) + mesh = trimesh.Trimesh(vertices / N - 0.5, triangles) mesh.show() -#position = torch.rand(3, device=device) -#position = torch.tensor([0.1, 0.1, 0.1], device=device) +# position = torch.rand(3, device=device) +# position = torch.tensor([0.1, 0.1, 0.1], device=device) ray_d = torch.rand(3, device=device) # normalize ray direction @@ -108,47 +125,49 @@ ray_d = ray_d / ray_d_norm N = 100 -t = np.linspace(-1, 1, N+1) +t = np.linspace(-1, 1, N + 1) query_pts = np.stack(np.meshgrid(t, t, t), -1).astype(np.float32) -#print(query_pts.shape) +# print(query_pts.shape) sh = query_pts.shape -flat = query_pts.reshape([-1,3]) +flat = query_pts.reshape([-1, 3]) -#raw_voxel = torch.zeros(N+1, N+1, N+1, 4) # N, D, H, W -fg_raw_voxel = torch.zeros(N+1, N+1, N+1) -#bg_raw_voxel = torch.zeros(N+1, N+1, N+1) +# raw_voxel = torch.zeros(N+1, N+1, N+1, 4) # N, D, H, W +fg_raw_voxel = torch.zeros(N + 1, N + 1, N + 1) +# bg_raw_voxel = torch.zeros(N+1, N+1, N+1) i = 0 -for x,y,z in tqdm(flat): +for x, y, z in tqdm(flat): + + position = torch.tensor([x, y, z], device=device) + # bg_position = torch.cat((position, torch.tensor([1], device=device))) - position = torch.tensor([x,y,z], device=device) - #bg_position = torch.cat((position, torch.tensor([1], device=device))) - # concat the output of the embedding - fg_input = torch.cat((fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1) - #bg_input = torch.cat((bg_embedder_position(bg_position), bg_embedder_viewdir(ray_d)), dim=-1) + fg_input = torch.cat( + (fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1 + ) + # bg_input = torch.cat((bg_embedder_position(bg_position), bg_embedder_viewdir(ray_d)), dim=-1) # forward fg_raw = fg_mlp_net(fg_input) - #bg_raw = bg_mlp_net(bg_input) + # bg_raw = bg_mlp_net(bg_input) - #raw_voxel.append(position + float(nn_raw['sigma'])) + # raw_voxel.append(position + float(nn_raw['sigma'])) fg_sigma = float(fg_raw["sigma"]) - #bg_sigma = float(bg_raw["sigma"]) + # bg_sigma = float(bg_raw["sigma"]) - nx, ny, nz = np.unravel_index(i, (N+1, N+1, N+1)) + nx, ny, nz = np.unravel_index(i, (N + 1, N + 1, N + 1)) i += 1 # update index - #raw_voxel[unraveled_index] = torch.tensor([sigma, x, y, z]) + # raw_voxel[unraveled_index] = torch.tensor([sigma, x, y, z]) fg_raw_voxel[nx, ny, nz] = fg_sigma - #bg_raw_voxel[nx, ny, nz] = bg_sigma + # bg_raw_voxel[nx, ny, nz] = bg_sigma fg_sigma = np.array(fg_raw_voxel) -#bg_sigma = np.array(bg_raw_voxel) +# bg_sigma = np.array(bg_raw_voxel) threshold = 0.5 -#vertices, triangles = mcubes.marching_cubes(sigma, threshold) -#mesh = trimesh.Trimesh(vertices / N - .5, triangles) -#mesh.show() +# vertices, triangles = mcubes.marching_cubes(sigma, threshold) +# mesh = trimesh.Trimesh(vertices / N - .5, triangles) +# mesh.show()