Replicate initialization from tensorflow
This commit is contained in:
parent
4eee8829d3
commit
a8aa8729c0
1 changed files with 15 additions and 6 deletions
|
@ -13,6 +13,15 @@ img2mse = lambda x, y : torch.mean((x - y) ** 2)
|
|||
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
|
||||
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
|
||||
|
||||
class DenseLayer(nn.Linear):
|
||||
def __init__(self, in_dim: int, out_dim: int, activation: str = "relu", *args, **kwargs) -> None:
|
||||
self.activation = activation
|
||||
super().__init__(in_dim, out_dim, *args, **kwargs)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation))
|
||||
if self.bias is not None:
|
||||
torch.nn.init.zeros_(self.bias)
|
||||
|
||||
# Positional encoding (section 5.1)
|
||||
class Embedder:
|
||||
|
@ -80,21 +89,21 @@ class NeRF(nn.Module):
|
|||
self.use_viewdirs = use_viewdirs
|
||||
|
||||
self.pts_linears = nn.ModuleList(
|
||||
[nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
|
||||
[DenseLayer(input_ch, W, activation="relu")] + [DenseLayer(W, W, activation="relu") if i not in self.skips else DenseLayer(W + input_ch, W, activation="relu") for i in range(D-1)])
|
||||
|
||||
### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
|
||||
self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
|
||||
self.views_linears = nn.ModuleList([DenseLayer(input_ch_views + W, W//2, activation="relu")])
|
||||
|
||||
### Implementation according to the paper
|
||||
# self.views_linears = nn.ModuleList(
|
||||
# [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
|
||||
|
||||
if use_viewdirs:
|
||||
self.feature_linear = nn.Linear(W, W)
|
||||
self.alpha_linear = nn.Linear(W, 1)
|
||||
self.rgb_linear = nn.Linear(W//2, 3)
|
||||
self.feature_linear = DenseLayer(W, W, activation="linear")
|
||||
self.alpha_linear = DenseLayer(W, 1, activation="linear")
|
||||
self.rgb_linear = DenseLayer(W//2, 3, activation="linear")
|
||||
else:
|
||||
self.output_linear = nn.Linear(W, output_ch)
|
||||
self.output_linear = DenseLayer(W, output_ch, activation="linear")
|
||||
|
||||
def forward(self, x):
|
||||
input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
|
||||
|
|
Loading…
Reference in a new issue