From a8aa8729c0b5c32b6980e79c07d3fc274c146b66 Mon Sep 17 00:00:00 2001
From: Kevin James Matzen
Date: Tue, 19 Jan 2021 20:54:50 +0000
Subject: [PATCH] Replicate initialization from tensorflow

---
 run_nerf_helpers.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/run_nerf_helpers.py b/run_nerf_helpers.py
index 3c68271..ca270e6 100644
--- a/run_nerf_helpers.py
+++ b/run_nerf_helpers.py
@@ -13,6 +13,15 @@ img2mse = lambda x, y : torch.mean((x - y) ** 2)
 mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
 to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
 
+class DenseLayer(nn.Linear):
+    def __init__(self, in_dim: int, out_dim: int, activation: str = "relu", *args, **kwargs) -> None:
+        self.activation = activation
+        super().__init__(in_dim, out_dim, *args, **kwargs)
+
+    def reset_parameters(self) -> None:
+        torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation))
+        if self.bias is not None:
+            torch.nn.init.zeros_(self.bias)
 
 # Positional encoding (section 5.1)
 class Embedder:
@@ -80,21 +89,21 @@ class NeRF(nn.Module):
         self.use_viewdirs = use_viewdirs
 
         self.pts_linears = nn.ModuleList(
-            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
+            [DenseLayer(input_ch, W, activation="relu")] + [DenseLayer(W, W, activation="relu") if i not in self.skips else DenseLayer(W + input_ch, W, activation="relu") for i in range(D-1)])
 
         ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
-        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
+        self.views_linears = nn.ModuleList([DenseLayer(input_ch_views + W, W//2, activation="relu")])
 
         ### Implementation according to the paper
         # self.views_linears = nn.ModuleList(
         #     [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
 
         if use_viewdirs:
-            self.feature_linear = nn.Linear(W, W)
-            self.alpha_linear = nn.Linear(W, 1)
-            self.rgb_linear = nn.Linear(W//2, 3)
+            self.feature_linear = DenseLayer(W, W, activation="linear")
+            self.alpha_linear = DenseLayer(W, 1, activation="linear")
+            self.rgb_linear = DenseLayer(W//2, 3, activation="linear")
         else:
-            self.output_linear = nn.Linear(W, output_ch)
+            self.output_linear = DenseLayer(W, output_ch, activation="linear")
 
     def forward(self, x):
         input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
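
For review purposes, a minimal sanity-check sketch of what the patched initialization does. It assumes the patch is applied and the repo's dependencies are installed so run_nerf_helpers imports; the 256/3 layer sizes are arbitrary. With activation="relu" the weights should land inside the Xavier-uniform bound scaled by calculate_gain("relu"), with activation="linear" inside the unscaled bound, and biases should start at zero:

import torch
from run_nerf_helpers import DenseLayer

torch.manual_seed(0)

for activation, layer in [("relu", DenseLayer(256, 256, activation="relu")),
                          ("linear", DenseLayer(256, 3, activation="linear"))]:
    # xavier_uniform_ samples from U(-b, b) with b = gain * sqrt(6 / (fan_in + fan_out)),
    # where gain = calculate_gain(activation); reset_parameters also zeroes the bias.
    gain = torch.nn.init.calculate_gain(activation)
    bound = gain * (6.0 / (layer.in_features + layer.out_features)) ** 0.5
    assert layer.weight.abs().max() <= bound
    assert torch.all(layer.bias == 0)
    print(activation, "max |w| =", float(layer.weight.abs().max()), "<= bound =", round(bound, 4))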