|
|
|
@ -7,8 +7,13 @@ import torch.distributed
|
|
|
|
|
import trimesh
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
from ddp_train_nerf import (cleanup, config_parser, create_nerf, setup,
|
|
|
|
|
setup_logger)
|
|
|
|
|
from ddp_train_nerf import (
|
|
|
|
|
cleanup,
|
|
|
|
|
config_parser,
|
|
|
|
|
create_nerf,
|
|
|
|
|
setup,
|
|
|
|
|
setup_logger,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
parser = config_parser()
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
@ -21,13 +26,13 @@ args.rank = 0
|
|
|
|
|
setup(args.rank, args.world_size)
|
|
|
|
|
start, models = create_nerf(args.rank, args)
|
|
|
|
|
|
|
|
|
|
net_0 = models['net_0']
|
|
|
|
|
net_0 = models["net_0"]
|
|
|
|
|
|
|
|
|
|
fg_far_depth = 1
|
|
|
|
|
|
|
|
|
|
# weird way to do it, should be change if something better exists
|
|
|
|
|
for idx, m in enumerate(net_0.modules()):
|
|
|
|
|
#print(idx, "->", m)
|
|
|
|
|
# print(idx, "->", m)
|
|
|
|
|
|
|
|
|
|
# foreground
|
|
|
|
|
if idx == 3:
|
|
|
|
@ -48,7 +53,10 @@ for idx, m in enumerate(net_0.modules()):
|
|
|
|
|
# put everything on GPU
|
|
|
|
|
device = "cuda"
|
|
|
|
|
|
|
|
|
|
def query_occupancy(position, embedder_position, embedder_viewdir, mlp_net, device="cuda"):
|
|
|
|
|
|
|
|
|
|
def query_occupancy(
|
|
|
|
|
position, embedder_position, embedder_viewdir, mlp_net, device="cuda"
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Given a position returns the occupancy probabily of the network.
|
|
|
|
|
|
|
|
|
@ -74,13 +82,15 @@ def query_occupancy(position, embedder_position, embedder_viewdir, mlp_net, devi
|
|
|
|
|
|
|
|
|
|
# take a random ray direction as it does not matter for sigma
|
|
|
|
|
ray_d = torch.rand(3, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# normalize ray direction
|
|
|
|
|
ray_d_norm = torch.norm(ray_d)
|
|
|
|
|
ray_d = ray_d / ray_d_norm
|
|
|
|
|
|
|
|
|
|
# forge the input
|
|
|
|
|
nn_input = torch.cat((fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1)
|
|
|
|
|
nn_input = torch.cat(
|
|
|
|
|
(fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# forward the NN
|
|
|
|
|
nn_raw = mlp_net(nn_input)
|
|
|
|
@ -88,18 +98,25 @@ def query_occupancy(position, embedder_position, embedder_viewdir, mlp_net, devi
|
|
|
|
|
|
|
|
|
|
return sigma
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# annonymous function
|
|
|
|
|
f = lambda x, y, z: query_occupancy(torch.tensor([x,y,z], dtype=torch.float32, device=device), fg_embedder_position, fg_embedder_viewdir, mlp_net)
|
|
|
|
|
f = lambda x, y, z: query_occupancy(
|
|
|
|
|
torch.tensor([x, y, z], dtype=torch.float32, device=device),
|
|
|
|
|
fg_embedder_position,
|
|
|
|
|
fg_embedder_viewdir,
|
|
|
|
|
mlp_net,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def marching_cube_and_render(sigma_list, threshold):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vertices, triangles = mcubes.marching_cubes(sigma_list, threshold)
|
|
|
|
|
mesh = trimesh.Trimesh(vertices / N - .5, triangles)
|
|
|
|
|
mesh = trimesh.Trimesh(vertices / N - 0.5, triangles)
|
|
|
|
|
mesh.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#position = torch.rand(3, device=device)
|
|
|
|
|
#position = torch.tensor([0.1, 0.1, 0.1], device=device)
|
|
|
|
|
# position = torch.rand(3, device=device)
|
|
|
|
|
# position = torch.tensor([0.1, 0.1, 0.1], device=device)
|
|
|
|
|
|
|
|
|
|
ray_d = torch.rand(3, device=device)
|
|
|
|
|
# normalize ray direction
|
|
|
|
@ -108,47 +125,49 @@ ray_d = ray_d / ray_d_norm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
N = 100
|
|
|
|
|
t = np.linspace(-1, 1, N+1)
|
|
|
|
|
t = np.linspace(-1, 1, N + 1)
|
|
|
|
|
|
|
|
|
|
query_pts = np.stack(np.meshgrid(t, t, t), -1).astype(np.float32)
|
|
|
|
|
#print(query_pts.shape)
|
|
|
|
|
# print(query_pts.shape)
|
|
|
|
|
sh = query_pts.shape
|
|
|
|
|
flat = query_pts.reshape([-1,3])
|
|
|
|
|
flat = query_pts.reshape([-1, 3])
|
|
|
|
|
|
|
|
|
|
#raw_voxel = torch.zeros(N+1, N+1, N+1, 4) # N, D, H, W
|
|
|
|
|
fg_raw_voxel = torch.zeros(N+1, N+1, N+1)
|
|
|
|
|
#bg_raw_voxel = torch.zeros(N+1, N+1, N+1)
|
|
|
|
|
# raw_voxel = torch.zeros(N+1, N+1, N+1, 4) # N, D, H, W
|
|
|
|
|
fg_raw_voxel = torch.zeros(N + 1, N + 1, N + 1)
|
|
|
|
|
# bg_raw_voxel = torch.zeros(N+1, N+1, N+1)
|
|
|
|
|
|
|
|
|
|
i = 0
|
|
|
|
|
for x,y,z in tqdm(flat):
|
|
|
|
|
for x, y, z in tqdm(flat):
|
|
|
|
|
|
|
|
|
|
position = torch.tensor([x, y, z], device=device)
|
|
|
|
|
# bg_position = torch.cat((position, torch.tensor([1], device=device)))
|
|
|
|
|
|
|
|
|
|
position = torch.tensor([x,y,z], device=device)
|
|
|
|
|
#bg_position = torch.cat((position, torch.tensor([1], device=device)))
|
|
|
|
|
|
|
|
|
|
# concat the output of the embedding
|
|
|
|
|
fg_input = torch.cat((fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1)
|
|
|
|
|
#bg_input = torch.cat((bg_embedder_position(bg_position), bg_embedder_viewdir(ray_d)), dim=-1)
|
|
|
|
|
fg_input = torch.cat(
|
|
|
|
|
(fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1
|
|
|
|
|
)
|
|
|
|
|
# bg_input = torch.cat((bg_embedder_position(bg_position), bg_embedder_viewdir(ray_d)), dim=-1)
|
|
|
|
|
|
|
|
|
|
# forward
|
|
|
|
|
fg_raw = fg_mlp_net(fg_input)
|
|
|
|
|
#bg_raw = bg_mlp_net(bg_input)
|
|
|
|
|
# bg_raw = bg_mlp_net(bg_input)
|
|
|
|
|
|
|
|
|
|
#raw_voxel.append(position + float(nn_raw['sigma']))
|
|
|
|
|
# raw_voxel.append(position + float(nn_raw['sigma']))
|
|
|
|
|
fg_sigma = float(fg_raw["sigma"])
|
|
|
|
|
#bg_sigma = float(bg_raw["sigma"])
|
|
|
|
|
# bg_sigma = float(bg_raw["sigma"])
|
|
|
|
|
|
|
|
|
|
nx, ny, nz = np.unravel_index(i, (N+1, N+1, N+1))
|
|
|
|
|
nx, ny, nz = np.unravel_index(i, (N + 1, N + 1, N + 1))
|
|
|
|
|
i += 1 # update index
|
|
|
|
|
#raw_voxel[unraveled_index] = torch.tensor([sigma, x, y, z])
|
|
|
|
|
# raw_voxel[unraveled_index] = torch.tensor([sigma, x, y, z])
|
|
|
|
|
fg_raw_voxel[nx, ny, nz] = fg_sigma
|
|
|
|
|
#bg_raw_voxel[nx, ny, nz] = bg_sigma
|
|
|
|
|
# bg_raw_voxel[nx, ny, nz] = bg_sigma
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fg_sigma = np.array(fg_raw_voxel)
|
|
|
|
|
#bg_sigma = np.array(bg_raw_voxel)
|
|
|
|
|
# bg_sigma = np.array(bg_raw_voxel)
|
|
|
|
|
threshold = 0.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#vertices, triangles = mcubes.marching_cubes(sigma, threshold)
|
|
|
|
|
#mesh = trimesh.Trimesh(vertices / N - .5, triangles)
|
|
|
|
|
#mesh.show()
|
|
|
|
|
# vertices, triangles = mcubes.marching_cubes(sigma, threshold)
|
|
|
|
|
# mesh = trimesh.Trimesh(vertices / N - .5, triangles)
|
|
|
|
|
# mesh.show()
|
|
|
|
|