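"""Sample the foreground NeRF density (sigma) on a regular 3-D grid.

Loads the model through ``ddp_train_nerf.create_nerf`` in single-process mode
(world_size = 1), queries the foreground MLP on an (N + 1)^3 grid spanning
[-2, 2]^3, and pickles the resulting density volume so it can be meshed with
marching cubes.
"""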

import os
import pickle
import mcubes
import numpy as np
import torch
import torch.distributed
import trimesh
from tqdm import tqdm
from ddp_train_nerf import (
cleanup,
config_parser,
create_nerf,
setup,
setup_logger,
)
parser = config_parser()
args = parser.parse_args()

# Hard-coded single-process settings: this script is not run distributed.
args.world_size = 1
args.rank = 0

setup(args.rank, args.world_size)
start, models = create_nerf(args.rank, args)
net_0 = models["net_0"]
fg_far_depth = 1

# Pick the foreground sub-modules out of net_0 by their position in
# net_0.modules() (pre-order traversal: net_0 itself is index 0).  For the
# model layout this script was written against, indices 3, 4 and 5 are the
# foreground position embedder, view-direction embedder and MLP; the
# background counterparts were at indices 40, 41 and 42.
# NOTE(review): brittle -- any change to the network's structure or wrapping
# shifts these indices.  Prefer named attribute access if the model exposes
# one.
_net_0_modules = list(net_0.modules())
if len(_net_0_modules) < 6:
    raise RuntimeError(
        "net_0 has only %d modules; expected the foreground embedders and MLP "
        "at indices 3-5" % len(_net_0_modules)
    )
fg_embedder_position, fg_embedder_viewdir, fg_mlp_net = _net_0_modules[3:6]
del _net_0_modules

# All query tensors below are created on the GPU.
device = "cuda"
def query_occupancy(
    position, embedder_position, embedder_viewdir, mlp_net, device="cuda"
):
    """
    Return the occupancy (density, ``sigma``) predicted at a single position.

    Given a position, the appropriate embedders and the MLPNet, return the
    corresponding occupancy. The view direction fed to the network is a
    random unit vector, on the assumption that sigma does not depend on it.

    Parameters
    ----------
    position : torch.Tensor
        A (x, y, z) tensor of the position to query. It must already live on
        ``device`` because it is concatenated with tensors created there.
    embedder_position, embedder_viewdir : nerf_network.Embedder
        Positional and view-direction embedders.
    mlp_net : nerf_network.MLPNet
        A simple MLP implementation written for NeRF. Its output is indexed
        with ``"sigma"``.
    device : str, optional
        The torch device on which the random view direction is created; can
        be either `cpu` or `cuda`.

    Returns
    -------
    sigma : float
        The occupancy at the given position.
    """
    # Random unit view direction: only sigma is read, assumed view-independent.
    ray_d = torch.rand(3, device=device)
    ray_d = ray_d / torch.norm(ray_d)

    # Inference only, so avoid building an autograd graph.
    with torch.no_grad():
        nn_input = torch.cat(
            (embedder_position(position), embedder_viewdir(ray_d)), dim=-1
        )
        nn_raw = mlp_net(nn_input)
    return float(nn_raw["sigma"])
# Convenience wrapper: foreground occupancy at scalar coordinates (x, y, z).
def f(x, y, z):
    """Return the foreground sigma at the point ``(x, y, z)``.

    Thin wrapper around ``query_occupancy`` bound to the foreground embedders
    and MLP extracted from ``net_0``. The point is built as a float32 tensor
    on the module-level ``device``.
    """
    return query_occupancy(
        torch.tensor([x, y, z], dtype=torch.float32, device=device),
        fg_embedder_position,
        fg_embedder_viewdir,
        fg_mlp_net,
        device=device,
    )
def marching_cube_and_render(sigma_list, threshold, resolution=None, show=True):
    """
    Extract an iso-surface from a density volume and (optionally) display it.

    Parameters
    ----------
    sigma_list : numpy.ndarray
        3-D array of densities sampled on a regular grid.
    threshold : float
        Iso-level passed to ``mcubes.marching_cubes``.
    resolution : int or sequence of int, optional
        Number of grid intervals per axis used to normalise vertex
        coordinates. Defaults to ``shape - 1`` per axis of ``sigma_list``,
        which equals ``N`` for an (N + 1)^3 grid. Previously the global ``N``
        was always used.
    show : bool, optional
        Open the trimesh viewer (default True, as before).

    Returns
    -------
    trimesh.Trimesh
        The extracted mesh. With the default resolution its vertices lie in
        [-0.5, 0.5] along each axis.
    """
    sigma_list = np.asarray(sigma_list)
    vertices, triangles = mcubes.marching_cubes(sigma_list, threshold)
    if resolution is None:
        # Vertices come back in voxel-index coordinates (0 .. shape - 1 per
        # axis). Guard against a degenerate axis of length 1.
        resolution = np.maximum(np.asarray(sigma_list.shape, dtype=np.float64) - 1, 1)
    mesh = trimesh.Trimesh(vertices / resolution - 0.5, triangles)
    if show:
        mesh.show()
    return mesh
# One view direction shared by every query. Only sigma is read, and it is
# assumed not to depend on the view direction.
ray_d = torch.rand(3, device=device)
ray_d = ray_d / torch.norm(ray_d)

N = 256
# N + 1 samples per axis over [-2, 2]: a 4x4x4 cube centred on the origin,
# which comfortably contains a sphere of radius 1.0.
t = np.linspace(-2, 2, N + 1)
# indexing="ij" makes volume axes 0/1/2 follow x/y/z. The default "xy"
# indexing swaps x and y, which mirrors any mesh extracted from the volume.
query_pts = np.stack(np.meshgrid(t, t, t, indexing="ij"), -1).astype(np.float32)
flat = query_pts.reshape([-1, 3])

# Query the network in chunks rather than one point at a time (that would be
# (N + 1)^3, i.e. ~17M, separate forward passes for N = 256).
# NOTE(review): assumes the embedders and fg_mlp_net operate on the last
# dimension and accept a leading batch dimension, and that "sigma" holds one
# value per point -- confirm against nerf_network.
CHUNK_SIZE = 32768
sigma_chunks = []
with torch.no_grad():  # inference only; no autograd graph needed
    for begin in tqdm(range(0, flat.shape[0], CHUNK_SIZE)):
        positions = torch.from_numpy(flat[begin : begin + CHUNK_SIZE]).to(device)
        viewdirs = ray_d.expand(positions.shape[0], -1)
        fg_input = torch.cat(
            (fg_embedder_position(positions), fg_embedder_viewdir(viewdirs)),
            dim=-1,
        )
        fg_raw = fg_mlp_net(fg_input)
        sigma_chunks.append(fg_raw["sigma"].reshape(-1).cpu())
        # The background network could be sampled the same way, using a 4-D
        # position (x, y, z, 1).

# A row-major reshape restores the (N + 1)^3 grid layout of `flat`.
fg_raw_voxel = torch.cat(sigma_chunks).reshape(N + 1, N + 1, N + 1)
fg_sigma = fg_raw_voxel.numpy()

threshold = 0.5  # iso-level for marching cubes

# Save the raw voxel grid in pickle format.
with open(f"raw_voxel_{N}.pkl", "wb") as fd:
    pickle.dump(fg_sigma, fd)

# NOTE(review): `cleanup` is imported from ddp_train_nerf but never called;
# consider tearing down whatever `setup` initialised (confirm its signature).

# To mesh and view the volume:
# marching_cube_and_render(fg_sigma, threshold)