diff --git a/ddp_train_nerf.py b/ddp_train_nerf.py index a217fdc..8988318 100644 --- a/ddp_train_nerf.py +++ b/ddp_train_nerf.py @@ -49,7 +49,10 @@ def intersect_sphere(ray_o, ray_d): p = ray_o + d1.unsqueeze(-1) * ray_d # consider the case where the ray does not intersect the sphere ray_d_cos = 1. / torch.norm(ray_d, dim=-1) - d2 = torch.sqrt(1. - torch.sum(p * p, dim=-1)) * ray_d_cos + p_norm_sq = torch.sum(p * p, dim=-1) + if (p_norm_sq >= 1.).any(): + raise Exception('Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!') + d2 = torch.sqrt(1. - p_norm_sq) * ray_d_cos return d1 + d2