Replicate initialization from tensorflow

master
Kevin James Matzen 3 years ago
parent 4eee8829d3
commit a8aa8729c0

@ -13,6 +13,15 @@ img2mse = lambda x, y : torch.mean((x - y) ** 2)
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
class DenseLayer(nn.Linear):
def __init__(self, in_dim: int, out_dim: int, activation: str = "relu", *args, **kwargs) -> None:
self.activation = activation
super().__init__(in_dim, out_dim, *args, **kwargs)
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation))
if self.bias is not None:
torch.nn.init.zeros_(self.bias)
# Positional encoding (section 5.1)
class Embedder:
@ -80,21 +89,21 @@ class NeRF(nn.Module):
self.use_viewdirs = use_viewdirs
self.pts_linears = nn.ModuleList(
[nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
[DenseLayer(input_ch, W, activation="relu")] + [DenseLayer(W, W, activation="relu") if i not in self.skips else DenseLayer(W + input_ch, W, activation="relu") for i in range(D-1)])
### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
self.views_linears = nn.ModuleList([DenseLayer(input_ch_views + W, W//2, activation="relu")])
### Implementation according to the paper
# self.views_linears = nn.ModuleList(
# [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
if use_viewdirs:
self.feature_linear = nn.Linear(W, W)
self.alpha_linear = nn.Linear(W, 1)
self.rgb_linear = nn.Linear(W//2, 3)
self.feature_linear = DenseLayer(W, W, activation="linear")
self.alpha_linear = DenseLayer(W, 1, activation="linear")
self.rgb_linear = DenseLayer(W//2, 3, activation="linear")
else:
self.output_linear = nn.Linear(W, output_ch)
self.output_linear = DenseLayer(W, output_ch, activation="linear")
def forward(self, x):
input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)

Loading…
Cancel
Save