Better training interface
This commit is contained in:
parent
c3ccc0bdb3
commit
7158181ee4
1 changed files with 8 additions and 6 deletions
14
run_nerf.py
14
run_nerf.py
|
@ -8,7 +8,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
@ -673,7 +673,7 @@ def train():
|
||||||
rays_rgb = torch.Tensor(rays_rgb).to(device)
|
rays_rgb = torch.Tensor(rays_rgb).to(device)
|
||||||
|
|
||||||
|
|
||||||
N_iters = 1000000
|
N_iters = 200000 + 1
|
||||||
print('Begin')
|
print('Begin')
|
||||||
print('TRAIN views are', i_train)
|
print('TRAIN views are', i_train)
|
||||||
print('TEST views are', i_test)
|
print('TEST views are', i_test)
|
||||||
|
@ -682,7 +682,7 @@ def train():
|
||||||
# Summary writers
|
# Summary writers
|
||||||
# writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
|
# writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
|
||||||
|
|
||||||
for i in range(start, N_iters):
|
for i in trange(start, N_iters):
|
||||||
time0 = time.time()
|
time0 = time.time()
|
||||||
|
|
||||||
# Sample random ray batch
|
# Sample random ray batch
|
||||||
|
@ -745,7 +745,7 @@ def train():
|
||||||
################################
|
################################
|
||||||
|
|
||||||
dt = time.time()-time0
|
dt = time.time()-time0
|
||||||
print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
|
# print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
|
||||||
##### end #####
|
##### end #####
|
||||||
|
|
||||||
# Rest is logging
|
# Rest is logging
|
||||||
|
@ -784,11 +784,13 @@ def train():
|
||||||
print('Saved test set')
|
print('Saved test set')
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
if i%args.i_print==0 or i < 10:
|
if i%args.i_print==0 or i < 10:
|
||||||
|
tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
|
||||||
|
"""
|
||||||
print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
|
print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
|
||||||
print('iter time {:.05f}'.format(dt))
|
print('iter time {:.05f}'.format(dt))
|
||||||
|
|
||||||
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
|
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
|
||||||
tf.contrib.summary.scalar('loss', loss)
|
tf.contrib.summary.scalar('loss', loss)
|
||||||
tf.contrib.summary.scalar('psnr', psnr)
|
tf.contrib.summary.scalar('psnr', psnr)
|
||||||
|
|
Loading…
Reference in a new issue