☕ reformat voxelisation
This commit is contained in:
parent
33adb79bf9
commit
67c023daba
1 changed files with 53 additions and 34 deletions
|
@ -7,8 +7,13 @@ import torch.distributed
|
|||
import trimesh
|
||||
from tqdm import tqdm
|
||||
|
||||
from ddp_train_nerf import (cleanup, config_parser, create_nerf, setup,
|
||||
setup_logger)
|
||||
from ddp_train_nerf import (
|
||||
cleanup,
|
||||
config_parser,
|
||||
create_nerf,
|
||||
setup,
|
||||
setup_logger,
|
||||
)
|
||||
|
||||
parser = config_parser()
|
||||
args = parser.parse_args()
|
||||
|
@ -21,7 +26,7 @@ args.rank = 0
|
|||
setup(args.rank, args.world_size)
|
||||
start, models = create_nerf(args.rank, args)
|
||||
|
||||
net_0 = models['net_0']
|
||||
net_0 = models["net_0"]
|
||||
|
||||
fg_far_depth = 1
|
||||
|
||||
|
@ -48,7 +53,10 @@ for idx, m in enumerate(net_0.modules()):
|
|||
# put everything on GPU
|
||||
device = "cuda"
|
||||
|
||||
def query_occupancy(position, embedder_position, embedder_viewdir, mlp_net, device="cuda"):
|
||||
|
||||
def query_occupancy(
|
||||
position, embedder_position, embedder_viewdir, mlp_net, device="cuda"
|
||||
):
|
||||
"""
|
||||
Given a position returns the occupancy probabily of the network.
|
||||
|
||||
|
@ -80,7 +88,9 @@ def query_occupancy(position, embedder_position, embedder_viewdir, mlp_net, devi
|
|||
ray_d = ray_d / ray_d_norm
|
||||
|
||||
# forge the input
|
||||
nn_input = torch.cat((fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1)
|
||||
nn_input = torch.cat(
|
||||
(fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1
|
||||
)
|
||||
|
||||
# forward the NN
|
||||
nn_raw = mlp_net(nn_input)
|
||||
|
@ -88,13 +98,20 @@ def query_occupancy(position, embedder_position, embedder_viewdir, mlp_net, devi
|
|||
|
||||
return sigma
|
||||
|
||||
|
||||
# annonymous function
|
||||
f = lambda x, y, z: query_occupancy(torch.tensor([x,y,z], dtype=torch.float32, device=device), fg_embedder_position, fg_embedder_viewdir, mlp_net)
|
||||
f = lambda x, y, z: query_occupancy(
|
||||
torch.tensor([x, y, z], dtype=torch.float32, device=device),
|
||||
fg_embedder_position,
|
||||
fg_embedder_viewdir,
|
||||
mlp_net,
|
||||
)
|
||||
|
||||
|
||||
def marching_cube_and_render(sigma_list, threshold):
|
||||
|
||||
vertices, triangles = mcubes.marching_cubes(sigma_list, threshold)
|
||||
mesh = trimesh.Trimesh(vertices / N - .5, triangles)
|
||||
mesh = trimesh.Trimesh(vertices / N - 0.5, triangles)
|
||||
mesh.show()
|
||||
|
||||
|
||||
|
@ -126,7 +143,9 @@ for x,y,z in tqdm(flat):
|
|||
# bg_position = torch.cat((position, torch.tensor([1], device=device)))
|
||||
|
||||
# concat the output of the embedding
|
||||
fg_input = torch.cat((fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1)
|
||||
fg_input = torch.cat(
|
||||
(fg_embedder_position(position), fg_embedder_viewdir(ray_d)), dim=-1
|
||||
)
|
||||
# bg_input = torch.cat((bg_embedder_position(bg_position), bg_embedder_viewdir(ray_d)), dim=-1)
|
||||
|
||||
# forward
|
||||
|
|
Loading…
Reference in a new issue