first commit
This commit is contained in:
commit
5fbe15ff24
58 changed files with 4470 additions and 0 deletions
8
.idea/.gitignore
vendored
Normal file
8
.idea/.gitignore
vendored
Normal file
|
@ -0,0 +1,8 @@
|
|||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
15
.idea/deployment.xml
Normal file
15
.idea/deployment.xml
Normal file
|
@ -0,0 +1,15 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="nerf_bg_latest_ddp">
|
||||
<serverData>
|
||||
<paths name="nerf_bg_latest_ddp">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/home/zhangka2/gernot_experi/nerf_bg_latest_ddp" local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
</serverData>
|
||||
<option name="myAutoUpload" value="ALWAYS" />
|
||||
</component>
|
||||
</project>
|
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
|
@ -0,0 +1,6 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
7
.idea/misc.xml
Normal file
7
.idea/misc.xml
Normal file
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="JavaScriptSettings">
|
||||
<option name="languageLevel" value="ES6" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (nerf)" project-jdk-type="Python SDK" />
|
||||
</project>
|
8
.idea/modules.xml
Normal file
8
.idea/modules.xml
Normal file
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/nerf_bg_latest_ddp.iml" filepath="$PROJECT_DIR$/.idea/nerf_bg_latest_ddp.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
13
.idea/nerf_bg_latest_ddp.iml
Normal file
13
.idea/nerf_bg_latest_ddp.iml
Normal file
|
@ -0,0 +1,13 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.7 (nerf)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
<option name="renderExternalDocumentation" value="true" />
|
||||
</component>
|
||||
</module>
|
7
.idea/other.xml
Normal file
7
.idea/other.xml
Normal file
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PySciProjectComponent">
|
||||
<option name="PY_SCI_VIEW" value="true" />
|
||||
<option name="PY_SCI_VIEW_SUGGESTED" value="true" />
|
||||
</component>
|
||||
</project>
|
14
.idea/webServers.xml
Normal file
14
.idea/webServers.xml
Normal file
|
@ -0,0 +1,14 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="WebServers">
|
||||
<option name="servers">
|
||||
<webServer id="1d5a6596-7d8b-45af-9960-9d4d014e6bbe" name="nerf_bg_latest_ddp">
|
||||
<fileTransfer accessType="SFTP" host="isl-iam2.rr.intel.com" port="22" sshConfigId="d5ddaa0b-8e6c-4721-ad8e-298dc2859ce7" sshConfig="intel_cluster" keyPair="true">
|
||||
<advancedOptions>
|
||||
<advancedOptions dataProtectionLevel="Private" passiveMode="true" shareSSLContext="true" />
|
||||
</advancedOptions>
|
||||
</fileTransfer>
|
||||
</webServer>
|
||||
</option>
|
||||
</component>
|
||||
</project>
|
0
README.md
Normal file
0
README.md
Normal file
BIN
__pycache__/utils.cpython-37.pyc
Normal file
BIN
__pycache__/utils.cpython-37.pyc
Normal file
Binary file not shown.
162
conda_env_nerf.yml
Normal file
162
conda_env_nerf.yml
Normal file
|
@ -0,0 +1,162 @@
|
|||
name: nerf
|
||||
channels:
|
||||
- pytorch
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=conda_forge
|
||||
- _openmp_mutex=4.5=0_gnu
|
||||
- absl-py=0.9.0=py37_0
|
||||
- astor=0.7.1=py_0
|
||||
- boost=1.72.0=py37h9de70de_0
|
||||
- boost-cpp=1.72.0=h8e57a91_0
|
||||
- bzip2=1.0.8=h516909a_2
|
||||
- c-ares=1.15.0=h516909a_1001
|
||||
- ca-certificates=2020.4.5.1=hecc5488_0
|
||||
- cairo=1.16.0=hfb77d84_1002
|
||||
- certifi=2020.4.5.1=py37hc8dfbb8_0
|
||||
- cffi=1.13.2=py37h8022711_0
|
||||
- cloudpickle=1.3.0=py_0
|
||||
- cudatoolkit=10.0.130=0
|
||||
- cycler=0.10.0=py_2
|
||||
- cytoolz=0.10.1=py37h516909a_0
|
||||
- dask-core=2.10.1=py_0
|
||||
- dbus=1.13.6=he372182_0
|
||||
- decorator=4.4.1=py_0
|
||||
- expat=2.2.9=he1b5a44_2
|
||||
- ffmpeg=4.1.3=h167e202_0
|
||||
- fontconfig=2.13.1=h86ecdb6_1001
|
||||
- freetype=2.10.0=he983fc9_1
|
||||
- gast=0.3.3=py_0
|
||||
- gettext=0.19.8.1=hc5be6a0_1002
|
||||
- giflib=5.2.1=h516909a_1
|
||||
- glib=2.58.3=py37h6f030ca_1002
|
||||
- gmp=6.2.0=he1b5a44_1
|
||||
- gnutls=3.6.5=hd3a4fd2_1002
|
||||
- graphite2=1.3.13=hf484d3e_1000
|
||||
- grpcio=1.23.0=py37hb0870dc_1
|
||||
- gst-plugins-base=1.14.5=h0935bb2_2
|
||||
- gstreamer=1.14.5=h36ae1b5_2
|
||||
- h5py=2.10.0=nompi_py37h513d04c_102
|
||||
- harfbuzz=2.4.0=h9f30f68_3
|
||||
- hdf5=1.10.5=nompi_h3c11f04_1104
|
||||
- icu=64.2=he1b5a44_1
|
||||
- ilmbase=2.4.1=h8b12597_0
|
||||
- imageio=2.6.1=py37_0
|
||||
- intel-openmp=2020.0=166
|
||||
- jasper=1.900.1=h07fcdf6_1006
|
||||
- joblib=0.14.1=py_0
|
||||
- jpeg=9c=h14c3975_1001
|
||||
- keras-applications=1.0.8=py_1
|
||||
- keras-preprocessing=1.1.0=py_0
|
||||
- kiwisolver=1.1.0=py37hc9558a2_0
|
||||
- krb5=1.16.4=h2fd8d38_0
|
||||
- lame=3.100=h14c3975_1001
|
||||
- ld_impl_linux-64=2.33.1=h53a641e_8
|
||||
- libblas=3.8.0=14_openblas
|
||||
- libcblas=3.8.0=14_openblas
|
||||
- libclang=9.0.1=default_hde54327_0
|
||||
- libcurl=7.68.0=hda55be3_0
|
||||
- libedit=3.1.20170329=hf8c457e_1001
|
||||
- libffi=3.2.1=he1b5a44_1006
|
||||
- libgcc-ng=9.2.0=h24d8f2e_2
|
||||
- libgfortran-ng=7.3.0=hdf63c60_5
|
||||
- libgomp=9.2.0=h24d8f2e_2
|
||||
- libiconv=1.15=h516909a_1005
|
||||
- liblapack=3.8.0=14_openblas
|
||||
- liblapacke=3.8.0=14_openblas
|
||||
- libllvm9=9.0.1=hc9558a2_0
|
||||
- libopenblas=0.3.7=h5ec1e0e_6
|
||||
- libopencv=4.2.0=py37_2
|
||||
- libpng=1.6.37=hed695b0_0
|
||||
- libprotobuf=3.8.0=h8b12597_0
|
||||
- libssh2=1.8.2=h22169c7_2
|
||||
- libstdcxx-ng=9.2.0=hdf63c60_2
|
||||
- libtiff=4.1.0=hc3755c2_3
|
||||
- libuuid=2.32.1=h14c3975_1000
|
||||
- libwebp=1.0.2=h56121f0_5
|
||||
- libxcb=1.13=h14c3975_1002
|
||||
- libxkbcommon=0.10.0=he1b5a44_0
|
||||
- libxml2=2.9.10=hee79883_0
|
||||
- lz4-c=1.8.3=he1b5a44_1001
|
||||
- markdown=3.2.1=py_0
|
||||
- matplotlib-base=3.1.3=py37h250f245_0
|
||||
- mkl=2020.0=166
|
||||
- mock=3.0.5=py37_0
|
||||
- ncurses=6.1=hf484d3e_1002
|
||||
- nettle=3.4.1=h1bed415_1002
|
||||
- networkx=2.4=py_0
|
||||
- ninja=1.10.0=hc9558a2_0
|
||||
- nspr=4.25=he1b5a44_0
|
||||
- nss=3.47=he751ad9_0
|
||||
- numpy=1.18.1=py37h95a1406_0
|
||||
- olefile=0.46=py_0
|
||||
- opencv=4.2.0=py37_2
|
||||
- openh264=1.8.0=hdbcaa40_1000
|
||||
- openimageio=2.1.13=hf311ebb_0
|
||||
- openssl=1.1.1g=h516909a_0
|
||||
- pandas=1.0.1=py37hb3f55d8_0
|
||||
- patsy=0.5.1=py_0
|
||||
- pcre=8.44=he1b5a44_0
|
||||
- pillow=6.2.1=py37hd70f55b_1
|
||||
- pip=20.0.2=py_2
|
||||
- pixman=0.38.0=h516909a_1003
|
||||
- protobuf=3.8.0=py37he1b5a44_2
|
||||
- pthread-stubs=0.4=h14c3975_1001
|
||||
- py-opencv=4.2.0=py37h5ca1d4c_2
|
||||
- py-openimageio=2.1.13=py37hf311ebb_0
|
||||
- pycparser=2.19=py37_1
|
||||
- pyparsing=2.4.6=py_0
|
||||
- python=3.7.6=h357f687_2
|
||||
- python-dateutil=2.8.1=py_0
|
||||
- python_abi=3.7=1_cp37m
|
||||
- pytorch=1.0.1=py3.7_cuda10.0.130_cudnn7.4.2_2
|
||||
- pytz=2019.3=py_0
|
||||
- pywavelets=1.1.1=py37hc1659b7_0
|
||||
- qt=5.12.5=hd8c4c69_1
|
||||
- readline=8.0=hf8c457e_0
|
||||
- scikit-image=0.16.2=py37hb3f55d8_0
|
||||
- scikit-learn=0.22.1=py37hcdab131_1
|
||||
- scipy=1.4.1=py37h921218d_0
|
||||
- seaborn=0.10.0=py_1
|
||||
- setuptools=45.2.0=py37_0
|
||||
- six=1.14.0=py37_0
|
||||
- sqlite=3.30.1=hcee41ef_0
|
||||
- statsmodels=0.11.0=py37h516909a_0
|
||||
- tensorboard=1.13.1=py37_0
|
||||
- tensorboardx=2.0=py_0
|
||||
- tensorflow=1.13.1=h5ece82f_5
|
||||
- tensorflow-base=1.13.1=py37h5ece82f_5
|
||||
- tensorflow-estimator=1.13.0=py_0
|
||||
- termcolor=1.1.0=py_2
|
||||
- tk=8.6.10=hed695b0_0
|
||||
- toolz=0.10.0=py_0
|
||||
- torchvision=0.2.2=py_3
|
||||
- tornado=6.0.3=py37h516909a_4
|
||||
- tqdm=4.42.1=py_0
|
||||
- werkzeug=1.0.0=py_0
|
||||
- wheel=0.34.2=py_1
|
||||
- x264=1!152.20180806=h14c3975_0
|
||||
- xorg-kbproto=1.0.7=h14c3975_1002
|
||||
- xorg-libice=1.0.10=h516909a_0
|
||||
- xorg-libsm=1.2.3=h84519dc_1000
|
||||
- xorg-libx11=1.6.9=h516909a_0
|
||||
- xorg-libxau=1.0.9=h14c3975_0
|
||||
- xorg-libxdmcp=1.1.3=h516909a_0
|
||||
- xorg-libxext=1.3.4=h516909a_0
|
||||
- xorg-libxrender=0.9.10=h516909a_1002
|
||||
- xorg-renderproto=0.11.1=h14c3975_1002
|
||||
- xorg-xextproto=7.3.0=h14c3975_1002
|
||||
- xorg-xproto=7.0.31=h14c3975_1007
|
||||
- xz=5.2.4=h14c3975_1001
|
||||
- zlib=1.2.11=h516909a_1006
|
||||
- zstd=1.4.4=h3b9ef0a_1
|
||||
- pip:
|
||||
- configargparse==1.2.3
|
||||
- future==0.18.2
|
||||
- imageio-ffmpeg==0.4.1
|
||||
- openexr==1.3.2
|
||||
- pyexr==0.3.7
|
||||
- pymcubes==0.1.0
|
||||
- pyquaternion==0.9.5
|
||||
prefix: /home/kz298/anaconda3/envs/nerf
|
48
configs/lf_data/lf_africa.txt
Normal file
48
configs/lf_data/lf_africa.txt
Normal file
|
@ -0,0 +1,48 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/lf_data/lf_nerf
|
||||
scene = africa
|
||||
expname = africa_ddp
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
N_rand = 512
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 50000000
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
# chunk_size = 16384
|
||||
chunk_size = 4096
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
48
configs/lf_data/lf_basket.txt
Normal file
48
configs/lf_data/lf_basket.txt
Normal file
|
@ -0,0 +1,48 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/lf_data/lf_nerf
|
||||
scene = basket
|
||||
expname = basket_ddp
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
N_rand = 512
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 50000000
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
# chunk_size = 16384
|
||||
chunk_size = 4096
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
50
configs/lf_data/lf_ship.txt
Normal file
50
configs/lf_data/lf_ship.txt
Normal file
|
@ -0,0 +1,50 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/lf_data/lf_nerf
|
||||
scene = ship
|
||||
expname = ship_ddp
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
# N_rand = 512
|
||||
N_rand = 1024
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 50000000
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
# chunk_size = 16384
|
||||
# chunk_size = 4096
|
||||
chunk_size = 8192
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
48
configs/lf_data/lf_torch.txt
Normal file
48
configs/lf_data/lf_torch.txt
Normal file
|
@ -0,0 +1,48 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/lf_data/lf_nerf
|
||||
scene = torch
|
||||
expname = torch_ddp
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
N_rand = 512
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 50000000
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
# chunk_size = 16384
|
||||
chunk_size = 4096
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
49
configs/tanks_and_temples/tat_intermediate_m60.txt
Normal file
49
configs/tanks_and_temples/tat_intermediate_m60.txt
Normal file
|
@ -0,0 +1,49 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
|
||||
scene = tat_intermediate_M60
|
||||
expname = tat_intermediate_M60_bg_carve_latest
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
# N_rand = 4096
|
||||
N_rand = 2048
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 500000
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
chunk_size = 16384
|
||||
# chunk_size = 8192
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
49
configs/tanks_and_temples/tat_intermediate_playground.txt
Normal file
49
configs/tanks_and_temples/tat_intermediate_playground.txt
Normal file
|
@ -0,0 +1,49 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
|
||||
scene = tat_intermediate_Playground
|
||||
expname = tat_intermediate_Playground_bg_carve_latest
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
# N_rand = 4096
|
||||
N_rand = 2048
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 500000
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
chunk_size = 16384
|
||||
# chunk_size = 8192
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
|
@ -0,0 +1,48 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
|
||||
scene = tat_intermediate_Playground
|
||||
expname = tat_intermediate_Playground_ddp_bignet
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
N_rand = 256
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 500000
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
# chunk_size = 16384
|
||||
chunk_size = 4096
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 512
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
49
configs/tanks_and_temples/tat_intermediate_train.txt
Normal file
49
configs/tanks_and_temples/tat_intermediate_train.txt
Normal file
|
@ -0,0 +1,49 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
|
||||
scene = tat_intermediate_Train
|
||||
expname = tat_intermediate_Train_bg_carve_latest
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
# N_rand = 4096
|
||||
N_rand = 2048
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 500000
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
chunk_size = 16384
|
||||
# chunk_size = 8192
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
53
configs/tanks_and_temples/tat_training_truck.txt
Normal file
53
configs/tanks_and_temples/tat_training_truck.txt
Normal file
|
@ -0,0 +1,53 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
|
||||
scene = tat_training_Truck
|
||||
expname = tat_training_Truck_ddp_implicit
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
# N_rand = 512
|
||||
N_rand = 1024
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 50000000
|
||||
|
||||
### implicit
|
||||
use_implicit = True
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
# chunk_size = 16384
|
||||
# chunk_size = 4096
|
||||
chunk_size = 8192
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
48
configs/tanks_and_temples/tat_training_truck_bignet.txt
Normal file
48
configs/tanks_and_temples/tat_training_truck_bignet.txt
Normal file
|
@ -0,0 +1,48 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
|
||||
scene = tat_training_Truck
|
||||
expname = tat_training_Truck_ddp_bignet
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
N_rand = 256
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 500000
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
# chunk_size = 16384
|
||||
chunk_size = 4096
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 512
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
47
configs/tanks_and_temples/tat_training_truck_subset.txt
Normal file
47
configs/tanks_and_temples/tat_training_truck_subset.txt
Normal file
|
@ -0,0 +1,47 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
|
||||
scene = tat_training_Truck_subset
|
||||
expname = tat_training_Truck_subset_bg_carvenew
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 250001
|
||||
N_rand = 2048
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 500000
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,64
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
chunk_size = 8192
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = False
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
|
@ -0,0 +1,54 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere_sparse
|
||||
scene = tat_intermediate_Playground
|
||||
expname = tat_intermediate_Playground_ddp_sparse_addcarve
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
# N_rand = 4096
|
||||
N_rand = 2048
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 500000
|
||||
|
||||
### implicit
|
||||
use_implicit = True
|
||||
load_min_depth = True
|
||||
regularize_weight = 0.1
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
chunk_size = 16384
|
||||
# chunk_size = 8192
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
|
@ -0,0 +1,54 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere_sparse
|
||||
scene = tat_intermediate_Playground
|
||||
expname = tat_intermediate_Playground_ddp_sparse_addparam
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
# N_rand = 4096
|
||||
N_rand = 2048
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 500000
|
||||
|
||||
### implicit
|
||||
use_implicit = True
|
||||
load_min_depth = False
|
||||
regularize_weight = 0.
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
chunk_size = 16384
|
||||
# chunk_size = 8192
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
|
@ -0,0 +1,54 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere_sparse
|
||||
scene = tat_intermediate_Playground
|
||||
expname = tat_intermediate_Playground_ddp_sparse_addregularize_pretrain
|
||||
basedir = ./logs
|
||||
config = /home/zhangka2/gernot_experi/nerf_bg_latest_ddp/logs/tat_intermediate_Playground_ddp_sparse_addparam/model_210000.pth
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
# N_rand = 4096
|
||||
N_rand = 2048
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 500000
|
||||
|
||||
### implicit
|
||||
use_implicit = True
|
||||
load_min_depth = False
|
||||
regularize_weight = 0.1
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
chunk_size = 16384
|
||||
# chunk_size = 8192
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
|
@ -0,0 +1,55 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere_sparse
|
||||
scene = tat_training_Truck
|
||||
expname = tat_training_Truck_ddp_sparse_addcarve
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
# N_rand = 512
|
||||
N_rand = 1024
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 50000000
|
||||
|
||||
### implicit
|
||||
use_implicit = True
|
||||
load_min_depth = True
|
||||
regularize_weight = 0.1
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
# chunk_size = 16384
|
||||
# chunk_size = 4096
|
||||
chunk_size = 8192
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
|
@ -0,0 +1,55 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere_sparse
|
||||
scene = tat_training_Truck
|
||||
expname = tat_training_Truck_ddp_sparse_addparam
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = None
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
# N_rand = 512
|
||||
N_rand = 1024
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 50000000
|
||||
|
||||
### implicit
|
||||
use_implicit = True
|
||||
load_min_depth = False
|
||||
regularize_weight = 0.
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
# chunk_size = 16384
|
||||
# chunk_size = 4096
|
||||
chunk_size = 8192
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
|
@ -0,0 +1,55 @@
|
|||
### INPUT
|
||||
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere_sparse
|
||||
scene = tat_training_Truck
|
||||
expname = tat_training_Truck_ddp_sparse_addregularize_pretrain
|
||||
basedir = ./logs
|
||||
config = None
|
||||
ckpt_path = /home/zhangka2/gernot_experi/nerf_bg_latest_ddp/logs/tat_training_Truck_ddp_sparse_addparam/model_245000.pth
|
||||
no_reload = False
|
||||
testskip = 1
|
||||
|
||||
### TRAINING
|
||||
N_iters = 1250001
|
||||
# N_rand = 512
|
||||
N_rand = 1024
|
||||
lrate = 0.0005
|
||||
lrate_decay_factor = 0.1
|
||||
lrate_decay_steps = 50000000
|
||||
|
||||
### implicit
|
||||
use_implicit = True
|
||||
load_min_depth = False
|
||||
regularize_weight = 0.1
|
||||
|
||||
### CASCADE
|
||||
cascade_level = 2
|
||||
cascade_samples = 64,128
|
||||
near_depth = 0.
|
||||
far_depth = 1.
|
||||
|
||||
### TESTING
|
||||
render_only = False
|
||||
render_test = False
|
||||
render_train = False
|
||||
# chunk_size = 16384
|
||||
# chunk_size = 4096
|
||||
chunk_size = 8192
|
||||
|
||||
### RENDERING
|
||||
det = False
|
||||
max_freq_log2 = 10
|
||||
max_freq_log2_viewdirs = 4
|
||||
netdepth = 8
|
||||
netwidth = 256
|
||||
raw_noise_std = 1.0
|
||||
N_iters_perturb = 1000
|
||||
inv_uniform = False
|
||||
use_viewdirs = True
|
||||
white_bkgd = False
|
||||
|
||||
### CONSOLE AND TENSORBOARD
|
||||
i_img = 2000
|
||||
i_print = 100
|
||||
i_testset = 5000000
|
||||
i_video = 5000000
|
||||
i_weights = 5000
|
94
data_loader_split.py
Normal file
94
data_loader_split.py
Normal file
|
@ -0,0 +1,94 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import imageio
|
||||
import logging
|
||||
from nerf_sample_ray_split import RaySamplerSingleImage
|
||||
import glob
|
||||
|
||||
logger = logging.getLogger(__package__)
|
||||
|
||||
########################################################################################################################
|
||||
# camera coordinate system: x-->right, y-->down, z-->scene (opencv/colmap convention)
|
||||
# poses is camera-to-world
|
||||
########################################################################################################################
|
||||
def find_files(dir, exts):
    """Collect the files in *dir* whose names match any glob pattern in *exts*.

    :param dir: directory to search (non-recursive)
    :param exts: glob patterns, e.g. ['*.png', '*.jpg']
    :return: sorted list of matching paths; empty list if *dir* does not exist
    """
    if not os.path.isdir(dir):
        return []
    matches = []
    for pattern in exts:
        matches.extend(glob.glob(os.path.join(dir, pattern)))
    return sorted(matches)
|
||||
|
||||
|
||||
def load_data_split(basedir, scene, split, skip=1, try_load_min_depth=True):
    """Build one RaySamplerSingleImage per camera found under basedir/scene/split.

    :param basedir: dataset root directory
    :param scene: scene name (sub-directory of basedir)
    :param split: split name, e.g. 'train' / 'validation' / 'test'
    :param skip: keep only every skip-th view
    :param try_load_min_depth: if True, also attach per-view min-depth maps when present
    :return: list of RaySamplerSingleImage, one per kept view
    """
    def parse_txt(filename):
        # a 4x4 matrix stored as 16 whitespace-separated numbers
        assert os.path.isfile(filename)
        nums = open(filename).read().split()
        return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32)

    split_dir = '{}/{}/{}'.format(basedir, scene, split)
    # camera parameters
    intrinsics_files = find_files('{}/intrinsics'.format(split_dir), exts=['*.txt'])
    pose_files = find_files('{}/pose'.format(split_dir), exts=['*.txt'])

    logger.info('raw intrinsics_files: {}'.format(len(intrinsics_files)))
    logger.info('raw pose_files: {}'.format(len(pose_files)))

    intrinsics_files = intrinsics_files[::skip]
    pose_files = pose_files[::skip]
    cam_cnt = len(pose_files)

    # img files (optional: a test split may ship poses only)
    img_files = find_files('{}/rgb'.format(split_dir), exts=['*.png', '*.jpg'])
    if len(img_files) > 0:
        logger.info('raw img_files: {}'.format(len(img_files)))
        img_files = img_files[::skip]
        assert(len(img_files) == cam_cnt)
    else:
        img_files = [None, ] * cam_cnt

    # mask files (optional)
    mask_files = find_files('{}/mask'.format(split_dir), exts=['*.png', '*.jpg'])
    if len(mask_files) > 0:
        logger.info('raw mask_files: {}'.format(len(mask_files)))
        mask_files = mask_files[::skip]
        assert(len(mask_files) == cam_cnt)
    else:
        mask_files = [None, ] * cam_cnt

    # min depth files (optional)
    mindepth_files = find_files('{}/min_depth'.format(split_dir), exts=['*.png', '*.jpg'])
    if try_load_min_depth and len(mindepth_files) > 0:
        logger.info('raw mindepth_files: {}'.format(len(mindepth_files)))
        mindepth_files = mindepth_files[::skip]
        assert(len(mindepth_files) == cam_cnt)
    else:
        mindepth_files = [None, ] * cam_cnt

    # max depth is shared by the whole split; read it once instead of once per camera
    try:
        max_depth = float(open('{}/max_depth.txt'.format(split_dir)).readline().strip())
    except (OSError, ValueError):
        # file absent or malformed -> no far-depth bound
        max_depth = None

    # assume all images have the same size; probe the first training image
    train_imgfile = find_files('{}/{}/train/rgb'.format(basedir, scene), exts=['*.png', '*.jpg'])[0]
    train_im = imageio.imread(train_imgfile)
    H, W = train_im.shape[:2]

    ray_samplers = []
    for i in range(cam_cnt):
        intrinsics = parse_txt(intrinsics_files[i])
        pose = parse_txt(pose_files[i])  # camera-to-world

        ray_samplers.append(RaySamplerSingleImage(H=H, W=W, intrinsics=intrinsics, c2w=pose,
                                                  img_path=img_files[i],
                                                  mask_path=mask_files[i],
                                                  min_depth_path=mindepth_files[i],
                                                  max_depth=max_depth))

    logger.info('Split {}, # views: {}'.format(split, cam_cnt))

    return ray_samplers
|
131
data_verifier.py
Normal file
131
data_verifier.py
Normal file
|
@ -0,0 +1,131 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
## pip install opencv-python=3.4.2.17 opencv-contrib-python==3.4.2.17
|
||||
|
||||
|
||||
def skew(x):
    """Return the 3x3 skew-symmetric (cross-product) matrix of the 3-vector *x*."""
    a, b, c = x[0], x[1], x[2]
    return np.array([[0, -c, b],
                     [c, 0, -a],
                     [-b, a, 0]])
|
||||
|
||||
|
||||
def two_view_geometry(intrinsics1, extrinsics1, intrinsics2, extrinsics2):
    '''
    Compute the essential and fundamental matrices relating two calibrated views.

    :param intrinsics1: 4 by 4 matrix
    :param extrinsics1: 4 by 4 W2C matrix
    :param intrinsics2: 4 by 4 matrix
    :param extrinsics2: 4 by 4 W2C matrix
    :return: (E, F, relative_pose) with relative_pose mapping camera-1 to camera-2 coords
    '''
    # pose of camera 1 expressed in camera 2's frame
    relative_pose = extrinsics2.dot(np.linalg.inv(extrinsics1))
    R = relative_pose[:3, :3]
    T = relative_pose[:3, 3]
    # essential matrix E = [T]_x R
    E = skew(T).dot(R)
    # fundamental matrix F = K2^-T E K1^-1 (upper-left 3x3 intrinsics only)
    K1_inv = np.linalg.inv(intrinsics1[:3, :3])
    K2_inv = np.linalg.inv(intrinsics2[:3, :3])
    F = K2_inv.T.dot(E).dot(K1_inv)

    return E, F, relative_pose
|
||||
|
||||
|
||||
def drawpointslines(img1, img2, lines1, pts2, color):
    '''
    Draw the epipolar lines *lines1* on img1 and the matching points *pts2* on img2.

    :param img1: grayscale image on which the epilines are drawn
    :param img2: grayscale image on which the corresponding points are drawn
    :param lines1: iterable of epiline coefficients (a, b, c) with a*x + b*y + c = 0
    :param pts2: iterable of integer (x, y) point coordinates
    :param color: iterable of BGR colors, one per line/point pair
    :return: (img1, img2) converted to BGR with the overlays drawn
    '''
    # fix: the original unpacked `r, c = img1.shape` and then shadowed `r` with the
    # loop variable; only the width is actually needed, to clip lines to the image
    _, width = img1.shape
    img1 = cv2.cvtColor(img1, cv2.COLOR_GRAY2BGR)
    img2 = cv2.cvtColor(img2, cv2.COLOR_GRAY2BGR)
    for line, pt2, cl in zip(lines1, pts2, color):
        a, b, c = line
        # intersect the line with the left (x=0) and right (x=width) image borders
        x0, y0 = map(int, [0, -c/b])
        x1, y1 = map(int, [width, -(c + a*width)/b])
        cl = tuple(cl.tolist())
        img1 = cv2.line(img1, (x0, y0), (x1, y1), cl, 1)
        img2 = cv2.circle(img2, tuple(pt2), 5, cl, -1)
    return img1, img2
|
||||
|
||||
|
||||
def epipolar(coord1, F, img1, img2):
    """Visualize epipolar geometry for a pair of views.

    :param coord1: 2xN array of point coordinates in img1
    :param F: 3x3 fundamental matrix mapping img1 points to epilines in img2
    :param img1, img2: grayscale images
    :return: the two annotated images concatenated side by side (img1 left)
    """
    pts1 = coord1.astype(int).T
    # one random BGR color per point/line pair
    color = np.random.randint(0, high=255, size=(len(pts1), 3))
    # epilines in the second image induced by the points of the first image
    lines2 = cv2.computeCorrespondEpilines(pts1.reshape(-1, 1, 2), 1, F)
    lines2 = lines2.reshape(-1, 3)
    img_with_lines, img_with_pts = drawpointslines(img2, img1, lines2, pts1, color)

    return np.concatenate((img_with_pts, img_with_lines), axis=1)
|
||||
|
||||
|
||||
def verify_data(img1, img2, intrinsics1, extrinsics1, intrinsics2, extrinsics2):
    """Sanity-check a camera pair by drawing epipolar lines for detected keypoints.

    :param img1, img2: BGR images of the same scene from two views
    :param intrinsics1, extrinsics1: 4x4 intrinsics and W2C matrix of view 1
    :param intrinsics2, extrinsics2: 4x4 intrinsics and W2C matrix of view 2
    :return: side-by-side epipolar visualization produced by epipolar()
    """
    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
    img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)

    E, F, relative_pose = two_view_geometry(intrinsics1, extrinsics1,
                                            intrinsics2, extrinsics2)

    # detect keypoints with ORB and keep at most 20 of them
    orb = cv2.ORB_create()
    kp1 = orb.detect(img1, None)
    coord1 = np.array([[kp.pt[0], kp.pt[1]] for kp in kp1[:20]]).T

    return epipolar(coord1, F, img1, img2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Standalone pose-verification script: select image pairs with a moderate
    # baseline and write epipolar-line overlays for visual inspection.
    from data_loader import load_data
    from run_nerf import config_parser
    from nerf_sample_ray import parse_camera
    import os

    parser = config_parser()
    args = parser.parse_args()
    print(args)

    data = load_data(args.datadir, args.scene, testskip=1)

    all_imgs = data['images']
    all_cameras = data['cameras']
    all_intrinsics = []
    all_extrinsics = []  # W2C
    for i in range(all_cameras.shape[0]):
        W, H, intrinsics, extrinsics = parse_camera(all_cameras[i])
        all_intrinsics.append(intrinsics)
        # presumably parse_camera returns camera-to-world here, inverted to get
        # world-to-camera -- TODO confirm against nerf_sample_ray.parse_camera
        all_extrinsics.append(np.linalg.inv(extrinsics))

    #### arbitrarily select 10 pairs of images to verify pose
    out_dir = os.path.join(args.basedir, args.expname, 'data_verify')
    print(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    def calc_angles(c2w_1, c2w_2):
        # angle (degrees) between the two camera-center directions seen from the origin
        c1 = c2w_1[:3, 3:4]
        c2 = c2w_2[:3, 3:4]

        c1 = c1 / np.linalg.norm(c1)
        c2 = c2 / np.linalg.norm(c2)
        return np.rad2deg(np.arccos(np.dot(c1.T, c2)))

    images_verify = []
    for i in range(10):
        # rejection-sample a pair whose angular separation is in (5, 10) degrees
        # NOTE(review): loops forever if the dataset contains no such pair
        while True:
            idx1, idx2 = np.random.choice(len(all_imgs), (2,), replace=False)

            angle = calc_angles(np.linalg.inv(all_extrinsics[idx1]),
                                np.linalg.inv(all_extrinsics[idx2]))
            if angle > 5. and angle < 10.:
                break

        # images are stored in [0, 1]; convert to uint8 for OpenCV
        im = verify_data(np.uint8(all_imgs[idx1]*255.), np.uint8(all_imgs[idx2]*255.),
                         all_intrinsics[idx1], all_extrinsics[idx1],
                         all_intrinsics[idx2], all_extrinsics[idx2])
        cv2.imwrite(os.path.join(out_dir, '{:03d}.png'.format(i)), im)
|
155
ddp_model.py
Normal file
155
ddp_model.py
Normal file
|
@ -0,0 +1,155 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
# import torch.nn.functional as F
|
||||
# import numpy as np
|
||||
from utils import TINY_NUMBER, HUGE_NUMBER
|
||||
from collections import OrderedDict
|
||||
from nerf_network import Embedder, MLPNet
|
||||
|
||||
|
||||
######################################################################################
|
||||
# wrapper to simplify the use of nerfnet
|
||||
######################################################################################
|
||||
def depth2pts_outside(ray_o, ray_d, depth):
    '''
    Map background samples (outside the unit sphere) to 4D points (x', y', z', 1/r).

    ray_o, ray_d: [..., 3]
    depth: [...]; inverse of distance to sphere origin

    Returns (pts, depth_real): pts is [..., 4] -- the unit direction of the sample
    point concatenated with its inverse distance -- and depth_real is the
    conventional depth of the sample along the ray.
    '''
    # depth of the point on the ray closest to the sphere center
    # note: d1 becomes negative if this mid point is behind camera
    d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
    p_mid = ray_o + d1.unsqueeze(-1) * ray_d
    p_mid_norm = torch.norm(p_mid, dim=-1)
    # 1/|ray_d| converts euclidean distances to ray-parameter units
    ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
    # half-chord: distance from the mid point to where the ray exits the unit sphere
    d2 = torch.sqrt(1. - p_mid_norm * p_mid_norm) * ray_d_cos
    p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d

    # rotate the sphere exit point towards the sample at inverse distance `depth`,
    # around the axis normal to the plane spanned by ray_o and p_sphere
    rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
    rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
    phi = torch.asin(p_mid_norm)
    theta = torch.asin(p_mid_norm * depth)  # depth is inside [0, 1]
    rot_angle = (phi - theta).unsqueeze(-1)  # [..., 1]

    # now rotate p_sphere
    # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
    p_sphere_new = p_sphere * torch.cos(rot_angle) + \
                   torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
                   rot_axis * torch.sum(rot_axis*p_sphere, dim=-1, keepdim=True) * (1.-torch.cos(rot_angle))
    # renormalize to stay exactly on the unit sphere despite numeric drift
    p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)
    pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)

    # now calculate conventional depth
    depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1
    return pts, depth_real
|
||||
|
||||
|
||||
class NerfNet(nn.Module):
    # Two-part scene representation split at the unit sphere: one MLP models the
    # foreground inside the sphere, a second MLP models the background using the
    # inverted-sphere parameterization (x', y', z', 1/r) -- see depth2pts_outside.
    def __init__(self, args):
        '''
        :param args: parsed config; fields read here:
            max_freq_log2 / max_freq_log2_viewdirs: number of positional-encoding
                frequencies for positions / view directions
            netdepth, netwidth: MLP depth and width
            use_viewdirs: if True, condition the predicted color on view direction
            use_implicit: forwarded to MLPNet (semantics defined in nerf_network)
        '''
        super().__init__()
        # foreground
        self.fg_embedder_position = Embedder(input_dim=3,
                                             max_freq_log2=args.max_freq_log2 - 1,
                                             N_freqs=args.max_freq_log2)
        self.fg_embedder_viewdir = Embedder(input_dim=3,
                                            max_freq_log2=args.max_freq_log2_viewdirs - 1,
                                            N_freqs=args.max_freq_log2_viewdirs)
        self.fg_net = MLPNet(D=args.netdepth, W=args.netwidth,
                             input_ch=self.fg_embedder_position.out_dim,
                             input_ch_viewdirs=self.fg_embedder_viewdir.out_dim,
                             use_viewdirs=args.use_viewdirs,
                             use_implicit=args.use_implicit)
        # background; bg_pt is (x, y, z, 1/r)
        self.bg_embedder_position = Embedder(input_dim=4,
                                             max_freq_log2=args.max_freq_log2 - 1,
                                             N_freqs=args.max_freq_log2)
        self.bg_embedder_viewdir = Embedder(input_dim=3,
                                            max_freq_log2=args.max_freq_log2_viewdirs - 1,
                                            N_freqs=args.max_freq_log2_viewdirs)
        self.bg_net = MLPNet(D=args.netdepth, W=args.netwidth,
                             input_ch=self.bg_embedder_position.out_dim,
                             input_ch_viewdirs=self.bg_embedder_viewdir.out_dim,
                             use_viewdirs=args.use_viewdirs,
                             use_implicit=args.use_implicit)

    def forward(self, ray_o, ray_d, fg_z_max, fg_z_vals, bg_z_vals):
        '''
        :param ray_o, ray_d: [..., 3]; ray_d need not be normalized
        :param fg_z_max: [...,]; depth at which each ray exits the unit sphere
        :param fg_z_vals, bg_z_vals: [..., N_samples]; bg_z_vals are inverse
            distances in [0, 1] (0 = infinitely far, 1 = on the sphere)
        :return OrderedDict: composited 'rgb'/'diffuse_rgb', per-sample weights for
            importance sampling, per-part rgb/depth maps, and 'bg_lambda'
        '''
        # print(ray_o.shape, ray_d.shape, fg_z_max.shape, fg_z_vals.shape, bg_z_vals.shape)
        ray_d_norm = torch.norm(ray_d, dim=-1, keepdim=True)  # [..., 1]
        viewdirs = ray_d / ray_d_norm  # [..., 3]
        dots_sh = list(ray_d.shape[:-1])

        ######### render foreground
        N_samples = fg_z_vals.shape[-1]
        fg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
        fg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
        fg_viewdirs = viewdirs.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
        fg_pts = fg_ray_o + fg_z_vals.unsqueeze(-1) * fg_ray_d
        input = torch.cat((self.fg_embedder_position(fg_pts),
                           self.fg_embedder_viewdir(fg_viewdirs)), dim=-1)
        fg_raw = self.fg_net(input)
        # alpha blending
        fg_dists = fg_z_vals[..., 1:] - fg_z_vals[..., :-1]
        # account for view directions: scale depth deltas by |ray_d| to get metric distances
        fg_dists = ray_d_norm * torch.cat((fg_dists, fg_z_max.unsqueeze(-1) - fg_z_vals[..., -1:]), dim=-1)  # [..., N_samples]
        fg_alpha = 1. - torch.exp(-fg_raw['sigma'] * fg_dists)  # [..., N_samples]
        T = torch.cumprod(1. - fg_alpha + TINY_NUMBER, dim=-1)  # [..., N_samples]
        # transmittance remaining after the last foreground sample scales the background
        bg_lambda = T[..., -1]
        T = torch.cat((torch.ones_like(T[..., 0:1]), T[..., :-1]), dim=-1)  # [..., N_samples]
        fg_weights = fg_alpha * T  # [..., N_samples]
        fg_rgb_map = torch.sum(fg_weights.unsqueeze(-1) * fg_raw['rgb'], dim=-2)  # [..., 3]
        fg_diffuse_rgb_map = torch.sum(fg_weights.unsqueeze(-1) * fg_raw['diffuse_rgb'], dim=-2)  # [..., 3]
        fg_depth_map = torch.sum(fg_weights * fg_z_vals, dim=-1)  # [...,]

        # render background
        N_samples = bg_z_vals.shape[-1]
        bg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
        bg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
        bg_viewdirs = viewdirs.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
        bg_pts, _ = depth2pts_outside(bg_ray_o, bg_ray_d, bg_z_vals)  # [..., N_samples, 4]
        input = torch.cat((self.bg_embedder_position(bg_pts),
                           self.bg_embedder_viewdir(bg_viewdirs)), dim=-1)
        # near_depth: physical far; far_depth: physical near
        # flip so compositing runs front-to-back in physical space
        input = torch.flip(input, dims=[-2,])
        bg_z_vals = torch.flip(bg_z_vals, dims=[-1,])  # 1--->0
        bg_dists = bg_z_vals[..., :-1] - bg_z_vals[..., 1:]
        bg_dists = torch.cat((bg_dists, HUGE_NUMBER * torch.ones_like(bg_dists[..., 0:1])), dim=-1)  # [..., N_samples]
        bg_raw = self.bg_net(input)
        bg_alpha = 1. - torch.exp(-bg_raw['sigma'] * bg_dists)  # [..., N_samples]
        # Eq. (3): T
        # maths show weights, and summation of weights along a ray, are always inside [0, 1]
        T = torch.cumprod(1. - bg_alpha + TINY_NUMBER, dim=-1)[..., :-1]  # [..., N_samples-1]
        T = torch.cat((torch.ones_like(T[..., 0:1]), T), dim=-1)  # [..., N_samples]
        bg_weights = bg_alpha * T  # [..., N_samples]
        bg_rgb_map = torch.sum(bg_weights.unsqueeze(-1) * bg_raw['rgb'], dim=-2)  # [..., 3]
        bg_diffuse_rgb_map = torch.sum(bg_weights.unsqueeze(-1) * bg_raw['diffuse_rgb'], dim=-2)  # [..., 3]
        bg_depth_map = torch.sum(bg_weights * bg_z_vals, dim=-1)  # [...,]

        # composite foreground and background
        bg_rgb_map = bg_lambda.unsqueeze(-1) * bg_rgb_map
        bg_diffuse_rgb_map = bg_lambda.unsqueeze(-1) * bg_diffuse_rgb_map
        bg_depth_map = bg_lambda * bg_depth_map
        rgb_map = fg_rgb_map + bg_rgb_map
        diffuse_rgb_map = fg_diffuse_rgb_map + bg_diffuse_rgb_map

        ret = OrderedDict([('rgb', rgb_map),  # loss
                           ('diffuse_rgb', diffuse_rgb_map),  # regularize
                           ('fg_weights', fg_weights),  # importance sampling
                           ('bg_weights', bg_weights),  # importance sampling
                           ('fg_rgb', fg_rgb_map),  # below are for logging
                           ('fg_depth', fg_depth_map),
                           ('bg_rgb', bg_rgb_map),
                           ('bg_depth', bg_depth_map),
                           ('bg_lambda', bg_lambda)])
        return ret
|
621
ddp_run_nerf.py
Normal file
621
ddp_run_nerf.py
Normal file
|
@ -0,0 +1,621 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim
|
||||
import torch.distributed
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import torch.multiprocessing
|
||||
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from ddp_model import NerfNet
|
||||
import time
|
||||
|
||||
from data_loader_split import load_data_split
|
||||
|
||||
import numpy as np
|
||||
from tensorboardX import SummaryWriter
|
||||
from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__package__)
|
||||
|
||||
|
||||
def setup_logger():
    """Configure the package logger to print INFO-and-above messages to the console."""
    pkg_logger = logging.getLogger(__package__)
    pkg_logger.setLevel(logging.INFO)

    # console handler; its own threshold stays at DEBUG, the logger level filters
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(
        logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s'))

    pkg_logger.addHandler(console)
|
||||
|
||||
|
||||
def intersect_sphere(ray_o, ray_d):
    '''
    ray_o, ray_d: [..., 3]
    compute the depth of the intersection point between this ray and unit sphere
    '''
    # ray parameter of the point closest to the sphere center
    # (negative when that point lies behind the camera)
    t_mid = -torch.sum(ray_o * ray_d, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
    closest = ray_o + t_mid.unsqueeze(-1) * ray_d
    # consider the case where the ray does not intersect the sphere
    inv_norm = 1. / torch.norm(ray_d, dim=-1)
    # half-chord length, converted from euclidean distance to ray-parameter units
    half_chord = torch.sqrt(1. - torch.sum(closest * closest, dim=-1)) * inv_norm

    return t_mid + half_chord
|
||||
|
||||
|
||||
def perturb_samples(z_vals):
    """Stratified jitter: re-draw each depth uniformly within its own interval.

    z_vals: [..., N_samples], assumed sorted along the last dimension.
    Returns a tensor of the same shape; the i-th sample lies between the
    midpoints bracketing the original i-th sample (endpoints are clamped).
    """
    midpoints = (z_vals[..., 1:] + z_vals[..., :-1]) * .5
    hi = torch.cat([midpoints, z_vals[..., -1:]], dim=-1)
    lo = torch.cat([z_vals[..., 0:1], midpoints], dim=-1)
    # one uniform draw per interval
    u = torch.rand_like(z_vals)
    return lo + (hi - lo) * u  # [N_rays, N_samples]
|
||||
|
||||
|
||||
def sample_pdf(bins, weights, N_samples, det=False):
    '''
    Inverse-transform sampling from a piecewise-constant PDF over depth bins.

    :param bins: tensor of shape [..., M+1], M is the number of bins
    :param weights: tensor of shape [..., M]; unnormalized per-bin probability mass
    :param N_samples: number of samples along each ray
    :param det: if True, will perform deterministic sampling (evenly spaced in CDF space)
    :return: [..., N_samples]
    '''
    # Get pdf
    weights = weights + TINY_NUMBER  # prevent nans
    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)  # [..., M]
    cdf = torch.cumsum(pdf, dim=-1)  # [..., M]
    cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1)  # [..., M+1]

    # Take uniform samples
    dots_sh = list(weights.shape[:-1])
    M = weights.shape[-1]

    min_cdf = 0.00
    max_cdf = 1.00  # prevent outlier samples

    if det:
        u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device)
        u = u.view([1]*len(dots_sh) + [N_samples]).expand(dots_sh + [N_samples,])  # [..., N_samples]
    else:
        sh = dots_sh + [N_samples]
        u = torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf  # [..., N_samples]

    # Invert CDF
    # count how many CDF knots each u passes -> index of the bin edge just above u
    # [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,]
    above_inds = torch.sum(u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1).long()

    # random sample inside each bin
    below_inds = torch.clamp(above_inds-1, min=0)
    inds_g = torch.stack((below_inds, above_inds), dim=-1)  # [..., N_samples, 2]

    # gather the bracketing CDF values and bin edges for every sample
    cdf = cdf.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])  # [..., N_samples, M+1]
    cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g)  # [..., N_samples, 2]

    bins = bins.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])  # [..., N_samples, M+1]
    bins_g = torch.gather(input=bins, dim=-1, index=inds_g)  # [..., N_samples, 2]

    # fix numeric issue: avoid dividing by a (near-)zero CDF span
    denom = cdf_g[..., 1] - cdf_g[..., 0]  # [..., N_samples]
    denom = torch.where(denom<TINY_NUMBER, torch.ones_like(denom), denom)
    # fractional position of u inside its bin, then lerp between the bin edges
    t = (u - cdf_g[..., 0]) / denom

    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + TINY_NUMBER)

    return samples
|
||||
|
||||
|
||||
def render_single_image(rank, world_size, models, ray_sampler, chunk_size):
    '''
    Render one full image cooperatively across all DDP processes.

    :param rank: this process's rank; the image's rays are partitioned across ranks
    :param world_size: total number of processes
    :param models: dict holding 'cascade_level', 'cascade_samples' and the
        per-level networks under 'net_{m}'
    :param ray_sampler: provides all rays of the image via get_all()
    :param chunk_size: rays per forward pass (bounds peak GPU memory)
    :return: on rank 0, a list (one entry per cascade level) of OrderedDicts of
        full-image tensors of shape (H, W, ...); on all other ranks, None
    '''
    ##### parallel rendering of a single image
    ray_batch = ray_sampler.get_all()
    # split into ranks; make sure different processes don't overlap
    # (the last rank takes the remainder when the count doesn't divide evenly)
    rank_split_sizes = [ray_batch['ray_d'].shape[0] // world_size, ] * world_size
    rank_split_sizes[-1] = ray_batch['ray_d'].shape[0] - sum(rank_split_sizes[:-1])
    for key in ray_batch:
        if torch.is_tensor(ray_batch[key]):
            ray_batch[key] = torch.split(ray_batch[key], rank_split_sizes)[rank].to(rank)

    # split into chunks and render inside each process
    ray_batch_split = OrderedDict()
    for key in ray_batch:
        if torch.is_tensor(ray_batch[key]):
            ray_batch_split[key] = torch.split(ray_batch[key], chunk_size)

    # forward and backward
    ret_merge_chunk = [OrderedDict() for _ in range(models['cascade_level'])]
    for s in range(len(ray_batch_split['ray_d'])):
        ray_o = ray_batch_split['ray_o'][s]
        ray_d = ray_batch_split['ray_d'][s]
        min_depth = ray_batch_split['min_depth'][s]

        dots_sh = list(ray_d.shape[:-1])
        for m in range(models['cascade_level']):
            net = models['net_{}'.format(m)]
            # sample depths
            N_samples = models['cascade_samples'][m]
            if m == 0:
                # coarse level: uniform samples between min_depth and the sphere exit
                # foreground depth
                fg_far_depth = intersect_sphere(ray_o, ray_d)  # [...,]
                # fg_near_depth = 0.18 * torch.ones_like(fg_far_depth)
                fg_near_depth = min_depth  # [..., 3] NOTE(review): shape comment looks wrong; appears to be per-ray [...,] -- confirm
                step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
                fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1)  # [..., N_samples]

                # background depth: uniform in inverse-distance space [0, 1]
                bg_depth = torch.linspace(0., 1., N_samples).view(
                    [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank)

                # delete unused memory
                del fg_near_depth
                del step
                torch.cuda.empty_cache()
            else:
                # fine levels: importance-sample from the previous level's weights
                # sample pdf and concat with earlier samples
                fg_weights = ret['fg_weights'].clone().detach()
                fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1])  # [..., N_samples-1]
                fg_weights = fg_weights[..., 1:-1]  # [..., N_samples-2]
                fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
                                              N_samples=N_samples, det=True)  # [..., N_samples]
                fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))

                # sample pdf and concat with earlier samples
                bg_weights = ret['bg_weights'].clone().detach()
                bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
                bg_weights = bg_weights[..., 1:-1]  # [..., N_samples-2]
                bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
                                              N_samples=N_samples, det=True)  # [..., N_samples]
                bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))

                # delete unused memory
                del fg_weights
                del fg_depth_mid
                del fg_depth_samples
                del bg_weights
                del bg_depth_mid
                del bg_depth_samples
                torch.cuda.empty_cache()

            # inference only; no gradients needed
            with torch.no_grad():
                ret = net(ray_o, ray_d, fg_far_depth, fg_depth, bg_depth)

            # accumulate per-chunk outputs on CPU (weights are kept on GPU for the
            # next cascade level and are not logged)
            for key in ret:
                if key not in ['fg_weights', 'bg_weights']:
                    if torch.is_tensor(ret[key]):
                        if key not in ret_merge_chunk[m]:
                            ret_merge_chunk[m][key] = [ret[key].cpu(), ]
                        else:
                            ret_merge_chunk[m][key].append(ret[key].cpu())

                        ret[key] = None

            # clean unused memory
            torch.cuda.empty_cache()

    # merge results from different chunks
    for m in range(len(ret_merge_chunk)):
        for key in ret_merge_chunk[m]:
            ret_merge_chunk[m][key] = torch.cat(ret_merge_chunk[m][key], dim=0)

    # merge results from different processes
    if rank == 0:
        ret_merge_rank = [OrderedDict() for _ in range(len(ret_merge_chunk))]
        for m in range(len(ret_merge_chunk)):
            for key in ret_merge_chunk[m]:
                # generate tensors to store results from other processes
                sh = list(ret_merge_chunk[m][key].shape[1:])
                ret_merge_rank[m][key] = [torch.zeros(*[size,]+sh, dtype=torch.float32) for size in rank_split_sizes]
                torch.distributed.gather(ret_merge_chunk[m][key], ret_merge_rank[m][key])
                # reassemble the full image and drop trailing singleton dims
                ret_merge_rank[m][key] = torch.cat(ret_merge_rank[m][key], dim=0).reshape(
                    (ray_sampler.H, ray_sampler.W, -1)).squeeze()
                # print(m, key, ret_merge_rank[m][key].shape)
    else:  # send results to main process
        for m in range(len(ret_merge_chunk)):
            for key in ret_merge_chunk[m]:
                torch.distributed.gather(ret_merge_chunk[m][key])

    # only rank 0 program returns
    if rank == 0:
        return ret_merge_rank
    else:
        return None
|
||||
|
||||
|
||||
def log_view_to_tb(writer, global_step, log_data, gt_img, mask, prefix=''):
    """Write the ground-truth image and per-cascade-level renderings to tensorboard.

    :param writer: tensorboardX SummaryWriter
    :param global_step: training iteration used as the tensorboard step
    :param log_data: list (one entry per cascade level) of dicts of HWC tensors
    :param gt_img: ground-truth image as a numpy HWC array
    :param mask: optional mask forwarded to the colorize helper
    :param prefix: tag prefix, e.g. 'val/' or 'train/'
    """
    writer.add_image(prefix + 'rgb_gt', img_HWC2CHW(torch.from_numpy(gt_img)), global_step)

    def _log_rgb(tag, im):
        # clamp just in case diffuse+specular>1
        writer.add_image(tag, torch.clamp(img_HWC2CHW(im), min=0., max=1.), global_step)

    def _log_map(tag, data, cmap):
        im = img_HWC2CHW(colorize(data, cmap_name=cmap, append_cbar=True,
                                  mask=mask))
        writer.add_image(tag, im, global_step)

    for level, data in enumerate(log_data):
        tag_base = prefix + 'level_{}/'.format(level)
        _log_rgb(tag_base + 'rgb', data['rgb'])

        _log_rgb(tag_base + 'fg_rgb', data['fg_rgb'])
        _log_map(tag_base + 'fg_depth', data['fg_depth'], 'jet')

        _log_rgb(tag_base + 'bg_rgb', data['bg_rgb'])
        _log_map(tag_base + 'bg_depth', data['bg_depth'], 'jet')
        _log_map(tag_base + 'bg_lambda', data['bg_lambda'], 'hot')
|
||||
|
||||
|
||||
def setup(rank, world_size):
    """Join the distributed process group; all ranks rendezvous at localhost:12355."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group using the CPU-friendly gloo backend
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
|
||||
|
||||
|
||||
def cleanup():
    # tear down the process group created by setup()
    torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def ddp_train_nerf(rank, args):
|
||||
###### set up multi-processing
|
||||
setup(rank, args.world_size)
|
||||
###### set up logger
|
||||
logger = logging.getLogger(__package__)
|
||||
setup_logger()
|
||||
|
||||
###### decide chunk size according to gpu memory
|
||||
logger.info('gpu_mem: {}'.format(torch.cuda.get_device_properties(rank).total_memory))
|
||||
if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14:
|
||||
logger.info('setting batch size according to 24G gpu')
|
||||
args.N_rand = 1024
|
||||
args.chunk_size = 8192
|
||||
else:
|
||||
logger.info('setting batch size according to 12G gpu')
|
||||
args.N_rand = 512
|
||||
args.chunk_size = 4096
|
||||
|
||||
###### Create log dir and copy the config file
|
||||
if rank == 0:
|
||||
os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True)
|
||||
f = os.path.join(args.basedir, args.expname, 'args.txt')
|
||||
with open(f, 'w') as file:
|
||||
for arg in sorted(vars(args)):
|
||||
attr = getattr(args, arg)
|
||||
file.write('{} = {}\n'.format(arg, attr))
|
||||
if args.config is not None:
|
||||
f = os.path.join(args.basedir, args.expname, 'config.txt')
|
||||
with open(f, 'w') as file:
|
||||
file.write(open(args.config, 'r').read())
|
||||
torch.distributed.barrier()
|
||||
|
||||
ray_samplers = load_data_split(args.datadir, args.scene, split='train', try_load_min_depth=args.load_min_depth)
|
||||
val_ray_samplers = load_data_split(args.datadir, args.scene, split='validation', try_load_min_depth=args.load_min_depth)
|
||||
|
||||
###### create network and wrap in ddp; each process should do this
|
||||
# fix random seed just to make sure the network is initialized with same weights at different processes
|
||||
torch.manual_seed(777)
|
||||
# very important!!! otherwise it might introduce extra memory in rank=0 gpu
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
models = OrderedDict()
|
||||
models['cascade_level'] = args.cascade_level
|
||||
models['cascade_samples'] = [int(x.strip()) for x in args.cascade_samples.split(',')]
|
||||
for m in range(models['cascade_level']):
|
||||
net = NerfNet(args).to(rank)
|
||||
net = DDP(net, device_ids=[rank], output_device=rank)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lrate)
|
||||
models['net_{}'.format(m)] = net
|
||||
models['optim_{}'.format(m)] = optim
|
||||
|
||||
start = -1
|
||||
|
||||
###### load pretrained weights; each process should do this
|
||||
if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)):
|
||||
ckpts = [args.ckpt_path]
|
||||
else:
|
||||
ckpts = [os.path.join(args.basedir, args.expname, f)
|
||||
for f in sorted(os.listdir(os.path.join(args.basedir, args.expname))) if f.endswith('.pth')]
|
||||
def path2iter(path):
|
||||
tmp = os.path.basename(path)[:-4]
|
||||
idx = tmp.rfind('_')
|
||||
return int(tmp[idx + 1:])
|
||||
ckpts = sorted(ckpts, key=path2iter)
|
||||
logger.info('Found ckpts: {}'.format(ckpts))
|
||||
if len(ckpts) > 0 and not args.no_reload:
|
||||
fpath = ckpts[-1]
|
||||
logger.info('Reloading from: {}'.format(fpath))
|
||||
start = path2iter(fpath)
|
||||
# configure map_location properly for different processes
|
||||
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
|
||||
to_load = torch.load(fpath, map_location=map_location)
|
||||
for m in range(models['cascade_level']):
|
||||
for name in ['net_{}'.format(m), 'optim_{}'.format(m)]:
|
||||
models[name].load_state_dict(to_load[name])
|
||||
models[name].load_state_dict(to_load[name])
|
||||
|
||||
##### important!!!
|
||||
# make sure different processes sample different rays
|
||||
np.random.seed((rank + 1) * 777)
|
||||
# make sure different processes have different perturbations in depth samples
|
||||
torch.manual_seed((rank + 1) * 777)
|
||||
|
||||
##### only main process should do the logging
|
||||
if rank == 0:
|
||||
writer = SummaryWriter(os.path.join(args.basedir, 'summaries', args.expname))
|
||||
|
||||
# start training
|
||||
what_val_to_log = 0 # helper variable for parallel rendering of a image
|
||||
what_train_to_log = 0
|
||||
for global_step in range(start+1, start+1+args.N_iters):
|
||||
time0 = time.time()
|
||||
scalars_to_log = OrderedDict()
|
||||
### Start of core optimization loop
|
||||
scalars_to_log['resolution'] = ray_samplers[0].resolution_level
|
||||
# randomly sample rays and move to device
|
||||
i = np.random.randint(low=0, high=len(ray_samplers))
|
||||
ray_batch = ray_samplers[i].random_sample(args.N_rand, center_crop=False)
|
||||
for key in ray_batch:
|
||||
if torch.is_tensor(ray_batch[key]):
|
||||
ray_batch[key] = ray_batch[key].to(rank)
|
||||
|
||||
# forward and backward
|
||||
dots_sh = list(ray_batch['ray_d'].shape[:-1]) # number of rays
|
||||
all_rets = [] # results on different cascade levels
|
||||
for m in range(models['cascade_level']):
|
||||
optim = models['optim_{}'.format(m)]
|
||||
net = models['net_{}'.format(m)]
|
||||
|
||||
# sample depths
|
||||
N_samples = models['cascade_samples'][m]
|
||||
if m == 0:
|
||||
# foreground depth
|
||||
fg_far_depth = intersect_sphere(ray_batch['ray_o'], ray_batch['ray_d']) # [...,]
|
||||
# fg_near_depth = 0.18 * torch.ones_like(fg_far_depth)
|
||||
fg_near_depth = ray_batch['min_depth'] # [..., 3]
|
||||
step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
|
||||
fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1) # [..., N_samples]
|
||||
fg_depth = perturb_samples(fg_depth) # random perturbation during training
|
||||
|
||||
# background depth
|
||||
bg_depth = torch.linspace(0., 1., N_samples).view(
|
||||
[1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank)
|
||||
bg_depth = perturb_samples(bg_depth) # random perturbation during training
|
||||
else:
|
||||
# sample pdf and concat with earlier samples
|
||||
fg_weights = ret['fg_weights'].clone().detach()
|
||||
fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1]) # [..., N_samples-1]
|
||||
fg_weights = fg_weights[..., 1:-1] # [..., N_samples-2]
|
||||
fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
|
||||
N_samples=N_samples, det=False) # [..., N_samples]
|
||||
fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))
|
||||
|
||||
# sample pdf and concat with earlier samples
|
||||
bg_weights = ret['bg_weights'].clone().detach()
|
||||
bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
|
||||
bg_weights = bg_weights[..., 1:-1] # [..., N_samples-2]
|
||||
bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
|
||||
N_samples=N_samples, det=False) # [..., N_samples]
|
||||
bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))
|
||||
|
||||
optim.zero_grad()
|
||||
ret = net(ray_batch['ray_o'], ray_batch['ray_d'], fg_far_depth, fg_depth, bg_depth)
|
||||
all_rets.append(ret)
|
||||
|
||||
rgb_gt = ray_batch['rgb'].to(rank)
|
||||
loss = img2mse(ret['rgb'], rgb_gt)
|
||||
scalars_to_log['level_{}/loss'.format(m)] = loss.item()
|
||||
scalars_to_log['level_{}/pnsr'.format(m)] = mse2psnr(loss.item())
|
||||
# regularize sigma with photo-consistency
|
||||
diffuse_loss = img2mse(ret['diffuse_rgb'], rgb_gt)
|
||||
scalars_to_log['level_{}/diffuse_loss'.format(m)] = diffuse_loss.item()
|
||||
scalars_to_log['level_{}/diffuse_psnr'.format(m)] = mse2psnr(diffuse_loss.item())
|
||||
loss = (1. - args.regularize_weight) * loss + args.regularize_weight * diffuse_loss
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
# # clean unused memory
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
### end of core optimization loop
|
||||
dt = time.time() - time0
|
||||
scalars_to_log['iter_time'] = dt
|
||||
|
||||
### only main process should do the logging
|
||||
if rank == 0 and (global_step % args.i_print == 0 or global_step < 10):
|
||||
logstr = '{} step: {} '.format(args.expname, global_step)
|
||||
for k in scalars_to_log:
|
||||
logstr += ' {}: {:.6f}'.format(k, scalars_to_log[k])
|
||||
writer.add_scalar(k, scalars_to_log[k], global_step)
|
||||
logger.info(logstr)
|
||||
|
||||
### each process should do this; but only main process merges the results
|
||||
if global_step % args.i_img == 0 or global_step == start+1:
|
||||
#### critical: make sure each process is working on the same random image
|
||||
time0 = time.time()
|
||||
idx = what_val_to_log % len(val_ray_samplers)
|
||||
log_data = render_single_image(rank, args.world_size, models, val_ray_samplers[idx], args.chunk_size)
|
||||
what_val_to_log += 1
|
||||
dt = time.time() - time0
|
||||
if rank == 0: # only main process should do this
|
||||
logger.info('Logged a random validation view in {} seconds'.format(dt))
|
||||
log_view_to_tb(writer, global_step, log_data, gt_img=val_ray_samplers[idx].get_img(), mask=None, prefix='val/')
|
||||
|
||||
time0 = time.time()
|
||||
idx = what_train_to_log % len(ray_samplers)
|
||||
log_data = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size)
|
||||
what_train_to_log += 1
|
||||
dt = time.time() - time0
|
||||
if rank == 0: # only main process should do this
|
||||
logger.info('Logged a random training view in {} seconds'.format(dt))
|
||||
log_view_to_tb(writer, global_step, log_data, gt_img=ray_samplers[idx].get_img(), mask=None, prefix='train/')
|
||||
|
||||
log_data = None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if rank == 0 and (global_step % args.i_weights == 0 and global_step > 0):
|
||||
# saving checkpoints and logging
|
||||
fpath = os.path.join(args.basedir, args.expname, 'model_{:06d}.pth'.format(global_step))
|
||||
to_save = OrderedDict()
|
||||
for m in range(models['cascade_level']):
|
||||
name = 'net_{}'.format(m)
|
||||
to_save[name] = models[name].state_dict()
|
||||
|
||||
name = 'optim_{}'.format(m)
|
||||
to_save[name] = models[name].state_dict()
|
||||
torch.save(to_save, fpath)
|
||||
|
||||
# clean up for multi-processing
|
||||
cleanup()
|
||||
|
||||
|
||||
def config_parser():
    """Build the command-line / config-file argument parser.

    Uses configargparse, so every option below can also be supplied through
    the file passed via --config.

    Returns:
        configargparse.ArgumentParser: parser holding all train/test options.
    """
    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True, help='config file path')
    parser.add_argument("--expname", type=str, help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs')

    # dataset options
    parser.add_argument("--datadir", type=str, default=None, help='input data directory')
    parser.add_argument("--scene", type=str, default=None, help='scene name')
    parser.add_argument("--testskip", type=int, default=8,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

    # model size
    parser.add_argument("--netdepth", type=int, default=8, help='layers in coarse network')
    parser.add_argument("--netwidth", type=int, default=256, help='channels per layer in coarse network')
    parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D')

    # checkpoints
    parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
    parser.add_argument("--ckpt_path", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')

    # batch size
    parser.add_argument("--N_rand", type=int, default=32 * 32 * 2,
                        help='batch size (number of random rays per gradient step)')
    parser.add_argument("--chunk_size", type=int, default=1024 * 8,
                        help='number of rays processed in parallel, decrease if running out of memory')

    # iterations
    parser.add_argument("--N_iters", type=int, default=250001,
                        help='number of iterations')

    parser.add_argument("--render_splits", type=str, default='test',
                        help='splits to render')

    # cascade training
    parser.add_argument("--cascade_level", type=int, default=2,
                        help='number of cascade levels')
    parser.add_argument("--cascade_samples", type=str, default='64,64',
                        help='samples at each level')
    parser.add_argument("--devices", type=str, default='0,1',
                        help='cuda device for each level')
    parser.add_argument("--bg_devices", type=str, default='0,2',
                        help='cuda device for the background of each level')

    # FIX: default used to be the string '-1'; use a real int so sentinel
    # comparisons like `args.world_size == -1` do not rely on argparse's
    # string-default coercion.
    parser.add_argument("--world_size", type=int, default=-1,
                        help='number of processes')

    # mixed precison training
    parser.add_argument("--opt_level", type=str, default='O1',
                        help='mixed precison training')

    parser.add_argument("--near_depth", type=float, default=0.1,
                        help='near depth plane')
    parser.add_argument("--far_depth", type=float, default=50.,
                        help='far depth plane')

    # learning rate options
    parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
    parser.add_argument("--lrate_decay_factor", type=float, default=0.1,
                        help='decay learning rate by a factor every specified number of steps')
    parser.add_argument("--lrate_decay_steps", type=int, default=5000,
                        help='decay learning rate by a factor every specified number of steps')

    # rendering options
    parser.add_argument("--inv_uniform", action='store_true',
                        help='if True, will uniformly sample inverse depths')
    parser.add_argument("--det", action='store_true', help='deterministic sampling for coarse and fine samples')
    parser.add_argument("--max_freq_log2", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--max_freq_log2_viewdirs", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--N_iters_perturb", type=int, default=1000,
                        help='perturb and center-crop at first 1000 iterations to prevent training from getting stuck')
    parser.add_argument("--raw_noise_std", type=float, default=1.,
                        help='std dev of noise added to regularize sigma output, 1e0 recommended')
    parser.add_argument("--white_bkgd", action='store_true',
                        help='apply the trick to avoid fitting to white background')

    # use implicit
    parser.add_argument("--use_implicit", action='store_true', help='whether to use implicit regularization')
    parser.add_argument("--regularize_weight", type=float, default=0.5,
                        help='regularizing weight of auxiliary loss')
    parser.add_argument("--load_min_depth", action='store_true', help='whether to load min depth')

    # no training; render only
    parser.add_argument("--render_only", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--render_train", action='store_true', help='render the training set')
    parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path')

    # no training; extract mesh only
    parser.add_argument("--mesh_only", action='store_true',
                        help='do not optimize, extract mesh from pretrained model')
    parser.add_argument("--N_pts", type=int, default=256,
                        help='voxel resolution; N_pts * N_pts * N_pts')
    parser.add_argument("--mesh_thres", type=str, default='10,20,30,40,50',
                        help='threshold(s) for mesh extraction; can use multiple thresholds')

    # logging/saving options
    parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin')
    parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=50000, help='frequency of testset saving')
    parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving')

    return parser
|
||||
|
||||
|
||||
def train():
    """Parse options and launch one DDP training worker per GPU."""
    parser = config_parser()
    args = parser.parse_args()
    logger.info(parser.format_values())

    # a world_size of -1 is the sentinel for "use every visible GPU"
    if args.world_size == -1:
        args.world_size = torch.cuda.device_count()
        logger.info('Using # gpus: {}'.format(args.world_size))

    # fork one process per rank; ddp_train_nerf receives (rank, args)
    torch.multiprocessing.spawn(
        ddp_train_nerf, args=(args,), nprocs=args.world_size, join=True)
|
||||
|
||||
|
||||
# script entry point: configure console logging, then launch DDP training
if __name__ == '__main__':
    setup_logger()
    train()
|
||||
|
||||
|
391
ddp_test_nerf.py
Normal file
391
ddp_test_nerf.py
Normal file
|
@ -0,0 +1,391 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim
|
||||
import torch.distributed
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import torch.multiprocessing
|
||||
import numpy as np
|
||||
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from ddp_model import NerfNet
|
||||
import time
|
||||
from data_loader_split import load_data_split
|
||||
from utils import mse2psnr, img_HWC2CHW, colorize, colorize_np, TINY_NUMBER, to8b
|
||||
import imageio
|
||||
from ddp_run_nerf import config_parser
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__package__)
|
||||
|
||||
|
||||
def setup_logger():
    """Route package log records at INFO level and above to the console.

    The package logger itself is left at DEBUG so additional handlers can
    opt into more verbose output.
    """
    # console handler: INFO and above, with timestamp/level/name prefix
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(
        logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s'))

    # attach to the package-level logger shared by all modules in this package
    pkg_logger = logging.getLogger(__package__)
    pkg_logger.setLevel(logging.DEBUG)
    pkg_logger.addHandler(console)
|
||||
|
||||
|
||||
def intersect_sphere(ray_o, ray_d):
    '''
    ray_o, ray_d: [..., 3]
    Return the depth (ray parameter) at which each ray exits the unit sphere.
    '''
    # depth of the point on the ray closest to the origin; negative when
    # that midpoint lies behind the camera
    mid_depth = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
    closest_pt = ray_o + mid_depth.unsqueeze(-1) * ray_d

    # half-chord length inside the sphere, converted from Euclidean units
    # to ray-parameter units via 1 / |ray_d|
    inv_len = 1. / torch.norm(ray_d, dim=-1)
    half_chord = torch.sqrt(1. - torch.sum(closest_pt * closest_pt, dim=-1)) * inv_len
    # NOTE(review): a ray that misses the unit sphere makes the sqrt argument
    # negative, yielding NaN — identical to the original behavior.

    return mid_depth + half_chord
|
||||
|
||||
|
||||
def perturb_samples(z_vals):
    '''
    z_vals: [..., N_samples] sorted depth values
    Jitter every sample uniformly within the interval bounded by the
    midpoints to its neighbors (first/last samples keep their outer edge),
    preserving the sample ordering.
    '''
    midpoints = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
    hi = torch.cat([midpoints, z_vals[..., -1:]], dim=-1)   # upper edge per bin
    lo = torch.cat([z_vals[..., 0:1], midpoints], dim=-1)   # lower edge per bin
    # one uniform draw per sample, stretched onto its own bin
    jitter = torch.rand_like(z_vals)
    return lo + (hi - lo) * jitter
|
||||
|
||||
|
||||
def sample_pdf(bins, weights, N_samples, det=False):
    '''
    Inverse-transform sampling of depths from the piecewise-constant PDF
    defined by per-bin weights (hierarchical sampling as in NeRF).

    :param bins: tensor of shape [..., M+1], M is the number of bins
    :param weights: tensor of shape [..., M]
    :param N_samples: number of samples along each ray
    :param det: if True, will perform deterministic sampling
    :return: [..., N_samples]
    '''
    # Get pdf
    weights = weights + TINY_NUMBER      # prevent nans
    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)    # [..., M]
    cdf = torch.cumsum(pdf, dim=-1)                             # [..., M]
    # prepend a zero so the CDF has M+1 knots spanning [0, 1], aligned with bins
    cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1)     # [..., M+1]

    # Take uniform samples
    dots_sh = list(weights.shape[:-1])
    M = weights.shape[-1]

    min_cdf = 0.00
    max_cdf = 1.00      # prevent outlier samples

    if det:
        # evenly spaced CDF values -> deterministic stratified samples
        u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device)
        u = u.view([1] * len(dots_sh) + [N_samples]).expand(dots_sh + [N_samples, ])    # [..., N_samples]
    else:
        # i.i.d. uniform CDF values in [min_cdf, max_cdf]
        sh = dots_sh + [N_samples]
        u = torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf         # [..., N_samples]

    # Invert CDF
    # count how many CDF knots each u meets or exceeds; that count is the
    # index of the knot just above u
    # [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,]
    above_inds = torch.sum(u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1).long()

    # random sample inside each bin
    below_inds = torch.clamp(above_inds - 1, min=0)
    inds_g = torch.stack((below_inds, above_inds), dim=-1)      # [..., N_samples, 2]

    # gather the pair of CDF values bracketing each u
    cdf = cdf.unsqueeze(-2).expand(dots_sh + [N_samples, M + 1])    # [..., N_samples, M+1]
    cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g)           # [..., N_samples, 2]

    # gather the pair of bin edges bracketing each u
    bins = bins.unsqueeze(-2).expand(dots_sh + [N_samples, M + 1])  # [..., N_samples, M+1]
    bins_g = torch.gather(input=bins, dim=-1, index=inds_g)         # [..., N_samples, 2]

    # fix numeric issue: a near-zero-width CDF segment would divide by ~0;
    # substituting 1 maps u onto the lower edge of that bin instead
    denom = cdf_g[..., 1] - cdf_g[..., 0]       # [..., N_samples]
    denom = torch.where(denom < TINY_NUMBER, torch.ones_like(denom), denom)
    # fractional position of u within its bracketing CDF segment
    t = (u - cdf_g[..., 0]) / denom

    # linear interpolation between the bracketing bin edges
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + TINY_NUMBER)

    return samples
|
||||
|
||||
|
||||
def render_single_image(rank, world_size, models, ray_sampler, chunk_size):
    '''
    Render one full image cooperatively across all DDP processes.

    Each rank renders a contiguous slice of the image's rays in chunks of
    `chunk_size`; rank 0 then gathers every slice and reassembles the
    per-level [H, W, ...] result buffers.

    :param rank: this process's rank (also used as its cuda device index)
    :param world_size: total number of participating processes
    :param models: OrderedDict with 'cascade_level', 'cascade_samples' and
                   the per-level 'net_{m}' modules
    :param ray_sampler: provides all rays of one image via get_all()
    :param chunk_size: number of rays per forward pass
    :return: on rank 0, a list (one OrderedDict per cascade level) of
             [H, W, ...] tensors; None on every other rank
    '''
    ##### parallel rendering of a single image
    ray_batch = ray_sampler.get_all()
    # split into ranks; make sure different processes don't overlap
    # (the last rank absorbs the remainder when the ray count is not
    # divisible by world_size)
    rank_split_sizes = [ray_batch['ray_d'].shape[0] // world_size, ] * world_size
    rank_split_sizes[-1] = ray_batch['ray_d'].shape[0] - sum(rank_split_sizes[:-1])
    for key in ray_batch:
        if torch.is_tensor(ray_batch[key]):
            ray_batch[key] = torch.split(ray_batch[key], rank_split_sizes)[rank].to(rank)

    # split into chunks and render inside each process
    ray_batch_split = OrderedDict()
    for key in ray_batch:
        if torch.is_tensor(ray_batch[key]):
            ray_batch_split[key] = torch.split(ray_batch[key], chunk_size)

    # forward and backward
    ret_merge_chunk = [OrderedDict() for _ in range(models['cascade_level'])]
    for s in range(len(ray_batch_split['ray_d'])):
        ray_o = ray_batch_split['ray_o'][s]
        ray_d = ray_batch_split['ray_d'][s]
        min_depth = ray_batch_split['min_depth'][s]

        dots_sh = list(ray_d.shape[:-1])
        for m in range(models['cascade_level']):
            net = models['net_{}'.format(m)]
            # sample depths
            N_samples = models['cascade_samples'][m]
            if m == 0:
                # coarsest level: uniform foreground samples between the
                # per-ray near depth and the ray's exit from the unit sphere
                # foreground depth
                fg_far_depth = intersect_sphere(ray_o, ray_d)  # [...,]
                # fg_near_depth = 0.18 * torch.ones_like(fg_far_depth)
                fg_near_depth = min_depth  # NOTE(review): used like [...,], not [..., 3] as the original comment said
                step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
                fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1)  # [..., N_samples]

                # background depth: uniform in the [0, 1] range of the
                # inverted-sphere background parameterization
                bg_depth = torch.linspace(0., 1., N_samples).view(
                    [1, ] * len(dots_sh) + [N_samples, ]).expand(dots_sh + [N_samples, ]).to(rank)

                # delete unused memory
                del fg_near_depth
                del step
                torch.cuda.empty_cache()
            else:
                # finer levels: importance-sample depths from the previous
                # level's weights, then merge with the earlier samples
                # sample pdf and concat with earlier samples
                fg_weights = ret['fg_weights'].clone().detach()
                fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1])    # [..., N_samples-1]
                fg_weights = fg_weights[..., 1:-1]                              # [..., N_samples-2]
                fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
                                              N_samples=N_samples, det=True)    # [..., N_samples]
                fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))

                # sample pdf and concat with earlier samples
                bg_weights = ret['bg_weights'].clone().detach()
                bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
                bg_weights = bg_weights[..., 1:-1]                              # [..., N_samples-2]
                bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
                                              N_samples=N_samples, det=True)    # [..., N_samples]
                bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))

                # delete unused memory
                del fg_weights
                del fg_depth_mid
                del fg_depth_samples
                del bg_weights
                del bg_depth_mid
                del bg_depth_samples
                torch.cuda.empty_cache()

            # inference only; no gradients needed at render time
            with torch.no_grad():
                ret = net(ray_o, ray_d, fg_far_depth, fg_depth, bg_depth)

            # accumulate this chunk's outputs on the cpu; the weight tensors
            # are kept only inside `ret` for the next cascade level
            for key in ret:
                if key not in ['fg_weights', 'bg_weights']:
                    if torch.is_tensor(ret[key]):
                        if key not in ret_merge_chunk[m]:
                            ret_merge_chunk[m][key] = [ret[key].cpu(), ]
                        else:
                            ret_merge_chunk[m][key].append(ret[key].cpu())

                    ret[key] = None

            # clean unused memory
            torch.cuda.empty_cache()

    # merge results from different chunks
    for m in range(len(ret_merge_chunk)):
        for key in ret_merge_chunk[m]:
            ret_merge_chunk[m][key] = torch.cat(ret_merge_chunk[m][key], dim=0)

    # merge results from different processes
    if rank == 0:
        ret_merge_rank = [OrderedDict() for _ in range(len(ret_merge_chunk))]
        for m in range(len(ret_merge_chunk)):
            for key in ret_merge_chunk[m]:
                # generate tensors to store results from other processes
                sh = list(ret_merge_chunk[m][key].shape[1:])
                ret_merge_rank[m][key] = [torch.zeros(*[size, ] + sh, dtype=torch.float32) for size in rank_split_sizes]
                # rank 0 receives one tensor per rank into the list above
                torch.distributed.gather(ret_merge_chunk[m][key], ret_merge_rank[m][key])
                # concatenate the rank slices back into one [H, W, ...] image
                ret_merge_rank[m][key] = torch.cat(ret_merge_rank[m][key], dim=0).reshape(
                    (ray_sampler.H, ray_sampler.W, -1)).squeeze()
                # print(m, key, ret_merge_rank[m][key].shape)
    else:  # send results to main process
        for m in range(len(ret_merge_chunk)):
            for key in ret_merge_chunk[m]:
                torch.distributed.gather(ret_merge_chunk[m][key])

    # only rank 0 program returns
    if rank == 0:
        return ret_merge_rank
    else:
        return None
|
||||
|
||||
|
||||
def setup(rank, world_size, master_addr='localhost', master_port='12355'):
    """Join this process to the default distributed process group.

    Generalized: the rendezvous address/port are now parameters (defaults
    preserve the previous hard-coded values), so multiple jobs on one host
    can pick distinct ports.

    :param rank: rank of this process within the group
    :param world_size: total number of processes in the group
    :param master_addr: hostname/IP of the rank-0 rendezvous endpoint
    :param master_port: TCP port of the rendezvous endpoint (string)
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    # initialize the process group (gloo backend, matching the original)
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
|
||||
|
||||
|
||||
def cleanup():
    """Tear down the default process group created by setup()."""
    torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def ddp_test_nerf(rank, args):
    '''
    Per-process worker that renders the requested splits with a trained model.

    Spawned once per GPU by test(); `rank` doubles as the cuda device index.
    Every process builds the model, loads the newest checkpoint, and renders
    its share of each image; only rank 0 writes the merged images to disk.
    '''
    ###### set up multi-processing
    setup(rank, args.world_size)
    ###### set up logger
    logger = logging.getLogger(__package__)
    setup_logger()

    ###### decide chunk size according to gpu memory
    if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14:
        logger.info('setting batch size according to 24G gpu')
        args.N_rand = 1024
        args.chunk_size = 8192
    else:
        logger.info('setting batch size according to 12G gpu')
        args.N_rand = 512
        args.chunk_size = 4096

    ###### create network and wrap in ddp; each process should do this
    # fix random seed just to make sure the network is initialized with same weights at different processes
    torch.manual_seed(777)
    # very important!!! otherwise it might introduce extra memory in rank=0 gpu
    torch.cuda.set_device(rank)

    models = OrderedDict()
    models['cascade_level'] = args.cascade_level
    models['cascade_samples'] = [int(x.strip()) for x in args.cascade_samples.split(',')]
    for m in range(models['cascade_level']):
        net = NerfNet(args).to(rank)
        net = DDP(net, device_ids=[rank], output_device=rank)
        # optimizer exists only so checkpointed optimizer state can be loaded
        optim = torch.optim.Adam(net.parameters(), lr=args.lrate)
        models['net_{}'.format(m)] = net
        models['optim_{}'.format(m)] = optim

    start = -1  # iteration of the loaded checkpoint; -1 if none is found

    ###### load pretrained weights; each process should do this
    if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)):
        ckpts = [args.ckpt_path]
    else:
        ckpts = [os.path.join(args.basedir, args.expname, f)
                 for f in sorted(os.listdir(os.path.join(args.basedir, args.expname))) if f.endswith('.pth')]

    def path2iter(path):
        # checkpoint files are named like 'model_123456.pth'
        tmp = os.path.basename(path)[:-4]
        idx = tmp.rfind('_')
        return int(tmp[idx + 1:])
    ckpts = sorted(ckpts, key=path2iter)
    logger.info('Found ckpts: {}'.format(ckpts))
    if len(ckpts) > 0 and not args.no_reload:
        fpath = ckpts[-1]
        logger.info('Reloading from: {}'.format(fpath))
        start = path2iter(fpath)
        # configure map_location properly for different processes
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        to_load = torch.load(fpath, map_location=map_location)
        for m in range(models['cascade_level']):
            for name in ['net_{}'.format(m), 'optim_{}'.format(m)]:
                models[name].load_state_dict(to_load[name])
                # NOTE(review): the repeated call below is redundant (loading
                # the same state twice is a no-op) — looks like a paste error
                models[name].load_state_dict(to_load[name])

    render_splits = [x.strip() for x in args.render_splits.strip().split(',')]
    # start testing
    for split in render_splits:
        out_dir = os.path.join(args.basedir, args.expname,
                               'render_{}_{:06d}'.format(split, start))
        if rank == 0:
            os.makedirs(out_dir, exist_ok=True)

        ###### load data and create ray samplers; each process should do this
        ray_samplers = load_data_split(args.datadir, args.scene, split, try_load_min_depth=args.load_min_depth)
        for idx in range(len(ray_samplers)):
            ### each process should do this; but only main process merges the results
            fname = '{:06d}.png'.format(idx)
            if ray_samplers[idx].img_path is not None:
                fname = os.path.basename(ray_samplers[idx].img_path)

            # resume-friendly: skip frames already rendered by a previous run
            if os.path.isfile(os.path.join(out_dir, fname)):
                logger.info('Skipping {}'.format(fname))
                continue

            time0 = time.time()
            ret = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size)
            dt = time.time() - time0
            if rank == 0:  # only main process should do this

                logger.info('Rendered {} in {} seconds'.format(fname, dt))

                # only save last level
                im = ret[-1]['rgb'].numpy()
                # compute psnr if ground-truth is available
                if ray_samplers[idx].img_path is not None:
                    gt_im = ray_samplers[idx].get_img()
                    psnr = mse2psnr(np.mean((gt_im - im) * (gt_im - im)))
                    logger.info('{}: psnr={}'.format(fname, psnr))

                im = to8b(im)
                imageio.imwrite(os.path.join(out_dir, fname), im)

                # im = ret[-1]['diffuse_rgb'].numpy()
                # im = to8b(im)
                # imageio.imwrite(os.path.join(out_dir, 'diffuse_' + fname), im)

                # foreground-only rendering
                im = ret[-1]['fg_rgb'].numpy()
                im = to8b(im)
                imageio.imwrite(os.path.join(out_dir, 'fg_' + fname), im)

                # background-only rendering
                im = ret[-1]['bg_rgb'].numpy()
                im = to8b(im)
                imageio.imwrite(os.path.join(out_dir, 'bg_' + fname), im)

                # depth maps are colorized with a jet colormap plus colorbar
                im = ret[-1]['fg_depth'].numpy()
                im = colorize_np(im, cmap_name='jet', append_cbar=True)
                im = to8b(im)
                imageio.imwrite(os.path.join(out_dir, 'fg_depth_' + fname), im)

                im = ret[-1]['bg_depth'].numpy()
                im = colorize_np(im, cmap_name='jet', append_cbar=True)
                im = to8b(im)
                imageio.imwrite(os.path.join(out_dir, 'bg_depth_' + fname), im)

            torch.cuda.empty_cache()

    # clean up for multi-processing
    cleanup()
|
||||
|
||||
|
||||
def test():
    """Parse options and launch one DDP rendering worker per GPU."""
    parser = config_parser()
    args = parser.parse_args()
    logger.info(parser.format_values())

    # a world_size of -1 is the sentinel for "use every visible GPU"
    if args.world_size == -1:
        args.world_size = torch.cuda.device_count()
        logger.info('Using # gpus: {}'.format(args.world_size))

    # fork one process per rank; ddp_test_nerf receives (rank, args)
    torch.multiprocessing.spawn(
        ddp_test_nerf, args=(args,), nprocs=args.world_size, join=True)
|
||||
|
||||
|
||||
# script entry point: configure console logging, then render the splits
if __name__ == '__main__':
    setup_logger()
    test()
|
||||
|
||||
|
167
nerf_network.py
Normal file
167
nerf_network.py
Normal file
|
@ -0,0 +1,167 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
# import torch.nn.functional as F
|
||||
# import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__package__)
|
||||
|
||||
class Embedder(nn.Module):
    """NeRF-style positional encoding: maps each input channel to a bank of
    sin/cos features at geometrically spaced frequencies, optionally keeping
    the raw input as the leading channels."""

    def __init__(self, input_dim, max_freq_log2, N_freqs,
                 log_sampling=True, include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):
        '''
        :param input_dim: dimension of input to be embedded
        :param max_freq_log2: log2 of max freq; min freq is 1 by default
        :param N_freqs: number of frequency bands
        :param log_sampling: if True, frequency bands are linerly sampled in log-space
        :param include_input: if True, raw input is included in the embedding
        :param periodic_fns: periodic functions used to embed input
        '''
        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns

        # one output channel per (input dim, frequency, periodic fn) triple,
        # plus the raw input channels when requested
        self.out_dim = self.input_dim * N_freqs * len(self.periodic_fns)
        if self.include_input:
            self.out_dim += self.input_dim

        if log_sampling:
            bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
        else:
            bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)
        # store plain python floats so forward() does cheap scalar multiplies
        self.freq_bands = bands.numpy().tolist()

    def forward(self, input):
        '''
        :param input: tensor of shape [..., self.input_dim]
        :return: tensor of shape [..., self.out_dim]
        '''
        assert (input.shape[-1] == self.input_dim)

        # layout: [raw input?] then, per frequency, one block per periodic fn
        pieces = [input] if self.include_input else []
        for freq in self.freq_bands:
            for p_fn in self.periodic_fns:
                pieces.append(p_fn(input * freq))
        out = torch.cat(pieces, dim=-1)

        assert (out.shape[-1] == self.out_dim)
        return out
|
||||
|
||||
# default tensorflow initialization of linear layers
|
||||
def weights_init(m):
    """Xavier-uniform init for linear layers (tensorflow's default scheme);
    biases are zeroed. Non-linear modules are left untouched, so this can be
    passed to nn.Module.apply()."""
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight.data)
    if m.bias is not None:
        nn.init.zeros_(m.bias.data)
|
||||
|
||||
|
||||
class MLPNet(nn.Module):
    # NeRF-style MLP: maps an embedded position (and optionally an embedded view
    # direction) to density (sigma) plus diffuse and specular color heads.
    # NOTE(review): skips=[4] is a mutable default argument; it is only read here,
    # never mutated, so it is harmless — but a tuple would be safer. Confirm before changing.
    def __init__(self, D=8, W=256, input_ch=3, input_ch_viewdirs=3, skips=[4], use_viewdirs=False
                 , use_implicit=False):
        '''
        :param D: network depth
        :param W: network width
        :param input_ch: input channels for encodings of (x, y, z)
        :param input_ch_viewdirs: input channels for encodings of view directions
        :param skips: skip connection in network
        :param use_viewdirs: if True, will use the view directions as input
        :param use_implicit: if True, forward() returns the specular head alone as
            'rgb' (implicit regularization mode) instead of diffuse + specular
        '''
        super().__init__()

        self.use_implicit = use_implicit
        if self.use_implicit:
            logger.info('Using implicit regularization as well!')

        self.input_ch = input_ch
        self.input_ch_viewdirs = input_ch_viewdirs
        self.use_viewdirs = use_viewdirs
        self.skips = skips

        # trunk MLP over the position encoding; after a skip layer, the raw
        # position encoding is re-concatenated, so the next layer's fan-in grows
        self.base_layers = []
        dim = self.input_ch
        for i in range(D):
            self.base_layers.append(
                nn.Sequential(nn.Linear(dim, W), nn.ReLU())
            )
            dim = W
            if i in self.skips and i != (D-1):  # skip connection after i^th layer
                dim += input_ch
        self.base_layers = nn.ModuleList(self.base_layers)
        # self.base_layers.apply(weights_init)        # xavier init

        # density head; positivity enforced in forward() via torch.abs
        sigma_layers = [nn.Linear(dim, 1), ]       # sigma must be positive
        self.sigma_layers = nn.Sequential(*sigma_layers)
        # self.sigma_layers.apply(weights_init)      # xavier init

        base_dim = dim
        # diffuse color head (view-independent)
        diffuse_rgb_layers = []
        dim = base_dim
        for i in range(1):
            diffuse_rgb_layers.append(nn.Linear(dim, W))
            diffuse_rgb_layers.append(nn.ReLU())
            dim = W
        diffuse_rgb_layers.append(nn.Linear(dim, 3))
        diffuse_rgb_layers.append(nn.Sigmoid())
        self.diffuse_rgb_layers = nn.Sequential(*diffuse_rgb_layers)
        # self.diffuse_rgb_layers.apply(weights_init)

        # specular color head (view-dependent): trunk features are first remapped
        # to 256 dims, then concatenated with the view-direction encoding
        specular_rgb_layers = []
        dim = base_dim
        base_remap_layers = [nn.Linear(dim, 256), ]
        self.base_remap_layers = nn.Sequential(*base_remap_layers)
        # self.base_remap_layers.apply(weights_init)

        dim = 256 + self.input_ch_viewdirs
        for i in range(1):
            specular_rgb_layers.append(nn.Linear(dim, W))
            specular_rgb_layers.append(nn.ReLU())
            dim = W
        specular_rgb_layers.append(nn.Linear(dim, 3))
        specular_rgb_layers.append(nn.Sigmoid())     # rgb values are normalized to [0, 1]
        self.specular_rgb_layers = nn.Sequential(*specular_rgb_layers)
        # self.specular_rgb_layers.apply(weights_init)

    def forward(self, input):
        '''
        :param input: [..., input_ch+input_ch_viewdirs] — position encoding first,
            view-direction encoding last
        :return: OrderedDict with 'rgb' [..., 3], 'diffuse_rgb' [..., 3],
            'sigma' [...,]
        '''
        input_pts = input[..., :self.input_ch]

        # trunk with skip connections: re-inject the position encoding
        # before every layer listed in self.skips
        base = self.base_layers[0](input_pts)
        for i in range(len(self.base_layers)-1):
            if i in self.skips:
                base = torch.cat((input_pts, base), dim=-1)
            base = self.base_layers[i+1](base)

        sigma = self.sigma_layers(base)
        sigma = torch.abs(sigma)  # enforce non-negative density

        diffuse_rgb = self.diffuse_rgb_layers(base)

        base_remap = self.base_remap_layers(base)
        input_viewdirs = input[..., -self.input_ch_viewdirs:]
        specular_rgb = self.specular_rgb_layers(torch.cat((base_remap, input_viewdirs), dim=-1))

        if self.use_implicit:
            rgb = specular_rgb
        else:
            # NOTE: this sum can exceed 1; downstream logging clamps it
            rgb = diffuse_rgb + specular_rgb

        ret = OrderedDict([('rgb', rgb),
                           ('diffuse_rgb', diffuse_rgb),
                           ('sigma', sigma.squeeze(-1))])
        return ret
|
249
nerf_sample_ray_split.py
Normal file
249
nerf_sample_ray_split.py
Normal file
|
@ -0,0 +1,249 @@
|
|||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
import cv2
|
||||
import imageio
|
||||
|
||||
########################################################################################################################
|
||||
# ray batch sampling
|
||||
########################################################################################################################
|
||||
|
||||
def parse_camera(params):
    '''
    Unpack a flat camera-parameter vector.

    :param params: flat array laid out as [H, W, intrinsics(16 values), c2w(16 values)]
    :return: (W, H, 4x4 intrinsics, 4x4 camera-to-world) — both matrices float32
    '''
    H, W = params[0], params[1]
    K = params[2:18].reshape((4, 4)).astype(np.float32)
    pose = params[18:34].reshape((4, 4)).astype(np.float32)
    return int(W), int(H), K, pose
|
||||
|
||||
|
||||
def get_rays_single_image(H, W, intrinsics, c2w):
    '''
    Generate one ray per pixel for a full image.

    :param H: image height
    :param W: image width
    :param intrinsics: 4 by 4 intrinsic matrix
    :param c2w: 4 by 4 camera-to-world extrinsic matrix
    :return: (rays_o [H*W, 3], rays_d [H*W, 3], depth [H*W,]); depth is the
        z-coordinate of the world origin in camera frame, repeated per ray
    '''
    # pixel-center coordinates in homogeneous form, shape (3, H*W)
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    uv1 = np.stack((u.reshape(-1).astype(np.float32) + 0.5,   # add half pixel
                    v.reshape(-1).astype(np.float32) + 0.5,
                    np.ones(H * W, dtype=np.float32)), axis=0)

    # back-project through the intrinsics, then rotate into world space
    dirs = np.linalg.inv(intrinsics[:3, :3]).dot(uv1)
    dirs = c2w[:3, :3].dot(dirs).T                             # (H*W, 3)

    # all rays share the camera center as origin
    origins = np.tile(c2w[:3, 3].reshape((1, 3)), (dirs.shape[0], 1))

    depth = np.linalg.inv(c2w)[2, 3] * np.ones((origins.shape[0],), dtype=np.float32)

    return origins, dirs, depth
|
||||
|
||||
|
||||
class RaySamplerSingleImage(object):
    '''
    Holds per-pixel rays (origin, direction, camera depth) for one image and
    serves them either all at once (get_all) or as random batches (random_sample).
    Image / mask / min-depth files are loaded and resized lazily whenever the
    resolution level changes.
    '''
    def __init__(self, H, W, intrinsics, c2w,
                 img_path=None,
                 resolution_level=1,
                 mask_path=None,
                 min_depth_path=None,
                 max_depth=None):
        '''
        :param H: image height at full resolution
        :param W: image width at full resolution
        :param intrinsics: 4x4 intrinsic matrix at full resolution
        :param c2w: 4x4 camera-to-world matrix
        :param img_path: optional path to the rgb image
        :param resolution_level: integer downsampling factor (1 = full resolution)
        :param mask_path: optional path to a mask image
        :param min_depth_path: optional path to a per-pixel minimum-depth map
        :param max_depth: scale for the min-depth map; used only when
            min_depth_path is given
        '''
        super().__init__()
        self.W_orig = W
        self.H_orig = H
        self.intrinsics_orig = intrinsics
        self.c2w_mat = c2w

        self.img_path = img_path
        self.mask_path = mask_path
        self.min_depth_path = min_depth_path
        self.max_depth = max_depth

        # -1 guarantees the first set_resolution_level() call does the real loading
        self.resolution_level = -1
        self.set_resolution_level(resolution_level)

    def set_resolution_level(self, resolution_level):
        '''
        Re-derive size/intrinsics, reload pixel data and regenerate rays for the
        given downsampling factor. No-op when the level is unchanged (cache).
        '''
        if resolution_level != self.resolution_level:
            self.resolution_level = resolution_level
            self.W = self.W_orig // resolution_level
            self.H = self.H_orig // resolution_level
            # focal lengths and principal point scale with resolution
            self.intrinsics = np.copy(self.intrinsics_orig)
            self.intrinsics[:2, :3] /= resolution_level
            # only load image at this time
            if self.img_path is not None:
                self.img = imageio.imread(self.img_path).astype(np.float32) / 255.
                self.img = cv2.resize(self.img, (self.W, self.H), interpolation=cv2.INTER_AREA)
                self.img = self.img.reshape((-1, 3))  # flattened to (H*W, 3)
            else:
                self.img = None

            if self.mask_path is not None:
                self.mask = imageio.imread(self.mask_path).astype(np.float32) / 255.
                # nearest-neighbor keeps the mask binary after resizing
                self.mask = cv2.resize(self.mask, (self.W, self.H), interpolation=cv2.INTER_NEAREST)
                self.mask = self.mask.reshape((-1))
            else:
                self.mask = None

            if self.min_depth_path is not None:
                # map [0, 1] png values to (0, max_depth]; +1e-4 avoids zero depth
                self.min_depth = imageio.imread(self.min_depth_path).astype(np.float32) / 255. * self.max_depth + 1e-4
                self.min_depth = cv2.resize(self.min_depth, (self.W, self.H), interpolation=cv2.INTER_LINEAR)
                self.min_depth = self.min_depth.reshape((-1))
            else:
                self.min_depth = None

            self.rays_o, self.rays_d, self.depth = get_rays_single_image(self.H, self.W,
                                                                         self.intrinsics, self.c2w_mat)

    def get_img(self):
        '''Return the loaded rgb image reshaped to (H, W, 3), or None if no image was loaded.'''
        if self.img is not None:
            return self.img.reshape((self.H, self.W, 3))
        else:
            return None

    def get_all(self):
        '''
        Return every ray of the image as an OrderedDict of torch tensors with
        keys ray_o/ray_d/depth/rgb/mask/min_depth; unloaded entries stay None.
        '''
        if self.min_depth is not None:
            min_depth = self.min_depth
        else:
            # default near bound when no min-depth map was provided
            min_depth = 1e-4 * np.ones_like(self.rays_d[..., 0])

        ret = OrderedDict([
            ('ray_o', self.rays_o),
            ('ray_d', self.rays_d),
            ('depth', self.depth),
            ('rgb', self.img),
            ('mask', self.mask),
            ('min_depth', min_depth)
        ])
        # return torch tensors
        for k in ret:
            if ret[k] is not None:
                ret[k] = torch.from_numpy(ret[k])
        return ret

    def random_sample(self, N_rand, center_crop=False):
        '''
        Sample N_rand rays uniformly without replacement, optionally restricted
        to the central half of the image.

        :param N_rand: number of rays to be casted
        :return: OrderedDict of torch tensors (same keys as get_all), each with
            a leading dimension of N_rand
        '''
        if center_crop:
            half_H = self.H // 2
            half_W = self.W // 2
            quad_H = half_H // 2
            quad_W = half_W // 2

            # pixel coordinates of the central crop
            u, v = np.meshgrid(np.arange(half_W-quad_W, half_W+quad_W),
                               np.arange(half_H-quad_H, half_H+quad_H))
            u = u.reshape(-1)
            v = v.reshape(-1)

            select_inds = np.random.choice(u.shape[0], size=(N_rand,), replace=False)

            # Convert back to flat indices into the original image
            select_inds = v[select_inds] * self.W + u[select_inds]
        else:
            # Random from the whole image
            select_inds = np.random.choice(self.H*self.W, size=(N_rand,), replace=False)

        rays_o = self.rays_o[select_inds, :]    # [N_rand, 3]
        rays_d = self.rays_d[select_inds, :]    # [N_rand, 3]
        depth = self.depth[select_inds]         # [N_rand, ]

        if self.img is not None:
            rgb = self.img[select_inds, :]          # [N_rand, 3]
        else:
            rgb = None

        if self.mask is not None:
            mask = self.mask[select_inds]
        else:
            mask = None

        if self.min_depth is not None:
            min_depth = self.min_depth[select_inds]
        else:
            min_depth = 1e-4 * np.ones_like(rays_d[..., 0])

        ret = OrderedDict([
            ('ray_o', rays_o),
            ('ray_d', rays_d),
            ('depth', depth),
            ('rgb', rgb),
            ('mask', mask),
            ('min_depth', min_depth)
        ])
        # return torch tensors
        for k in ret:
            if ret[k] is not None:
                ret[k] = torch.from_numpy(ret[k])

        return ret
|
||||
|
||||
# def random_sample_patches(self, N_patch, r_patch=16, center_crop=False):
|
||||
# '''
|
||||
# :param N_patch: number of patches to be sampled
|
||||
# :param r_patch: patch size will be (2*r_patch+1)*(2*r_patch+1)
|
||||
# :return:
|
||||
# '''
|
||||
# # even size patch
|
||||
# # offsets to center pixels
|
||||
# u, v = np.meshgrid(np.arange(-r_patch, r_patch),
|
||||
# np.arange(-r_patch, r_patch))
|
||||
# u = u.reshape(-1)
|
||||
# v = v.reshape(-1)
|
||||
# offsets = v * self.W + u
|
||||
|
||||
# # center pixel coordinates
|
||||
# u_min = r_patch
|
||||
# u_max = self.W - r_patch
|
||||
# v_min = r_patch
|
||||
# v_max = self.H - r_patch
|
||||
# if center_crop:
|
||||
# u_min = self.W // 4 + r_patch
|
||||
# u_max = self.W - self.W // 4 - r_patch
|
||||
# v_min = self.H // 4 + r_patch
|
||||
# v_max = self.H - self.H // 4 - r_patch
|
||||
|
||||
# u, v = np.meshgrid(np.arange(u_min, u_max, r_patch),
|
||||
# np.arange(v_min, v_max, r_patch))
|
||||
# u = u.reshape(-1)
|
||||
# v = v.reshape(-1)
|
||||
|
||||
# select_inds = np.random.choice(u.shape[0], size=(N_patch,), replace=False)
|
||||
# # Convert back to original image
|
||||
# select_inds = v[select_inds] * self.W + u[select_inds]
|
||||
|
||||
# # pick patches
|
||||
# select_inds = np.stack([select_inds + shift for shift in offsets], axis=1)
|
||||
# select_inds = select_inds.reshape(-1)
|
||||
|
||||
# rays_o = self.rays_o[select_inds, :] # [N_rand, 3]
|
||||
# rays_d = self.rays_d[select_inds, :] # [N_rand, 3]
|
||||
# depth = self.depth[select_inds] # [N_rand, ]
|
||||
|
||||
# if self.img is not None:
|
||||
# rgb = self.img[select_inds, :] # [N_rand, 3]
|
||||
|
||||
# # ### debug
|
||||
# # import imageio
|
||||
# # imgs = rgb.reshape((N_patch, r_patch*2, r_patch*2, -1))
|
||||
# # for kk in range(imgs.shape[0]):
|
||||
# # imageio.imwrite('./debug_{}.png'.format(kk), imgs[kk])
|
||||
# # ###
|
||||
# else:
|
||||
# rgb = None
|
||||
|
||||
# ret = OrderedDict([
|
||||
# ('ray_o', rays_o),
|
||||
# ('ray_d', rays_d),
|
||||
# ('depth', depth),
|
||||
# ('rgb', rgb)
|
||||
# ])
|
||||
|
||||
# # return torch tensors
|
||||
# for k in ret:
|
||||
# ret[k] = torch.from_numpy(ret[k])
|
||||
|
||||
# return ret
|
181
old_scripts/data_loader.py
Normal file
181
old_scripts/data_loader.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import imageio
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__package__)
|
||||
|
||||
########################################################################################################################
|
||||
# camera coordinate system: x-->right, y-->down, z-->scene (opencv/colmap convention)
|
||||
# poses is camera-to-world
|
||||
########################################################################################################################
|
||||
|
||||
def load_data(basedir, scene, testskip=8):
    '''
    Load images, masks, min-depth maps and cameras for the train / validation /
    test splits of a scene.

    Camera convention: x-->right, y-->down, z-->scene (opencv/colmap); poses
    are camera-to-world.

    Fixes vs the previous version:
    - test/validation masks and min-depth maps are now subsampled by `testskip`
      like the images; before, they were loaded in full, so `all_masks` and
      `all_min_depths` could misalign with `all_imgs` whenever masks existed.
    - one matrix-loading helper instead of the identical dir2poses/dir2intrinsics pair.

    :param basedir: dataset root directory
    :param scene: scene sub-directory name
    :param testskip: keep every testskip-th view of the test/validation splits
    :return: OrderedDict with images/masks/paths/min_depths/cameras,
        i_train/i_val/i_test index lists and optional render_cams
    '''
    def parse_txt(filename):
        # read a whitespace-separated 4x4 matrix from a text file
        assert os.path.isfile(filename)
        nums = open(filename).read().split()
        return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32)

    def dir2mats(matdir):
        # stack every *.txt 4x4 matrix in a directory (sorted order);
        # used for both pose and intrinsics directories
        mats = np.stack(
            [parse_txt(os.path.join(matdir, f)) for f in sorted(os.listdir(matdir)) if f.endswith('txt')], 0)
        return mats.astype(np.float32)

    def load_split(split, skip):
        # load rgb images (required) plus masks and min-depth maps (optional)
        # for one split; every modality is subsampled by `skip` so lists align
        imgd = '{}/{}/{}/rgb'.format(basedir, scene, split)
        imgfiles = ['{}/{}'.format(imgd, f)
                    for f in sorted(os.listdir(imgd)) if f.endswith('png') or f.endswith('jpg')]
        imgs = [imageio.imread(f).astype(np.float32)[..., :3] / 255. for f in imgfiles]
        imgfiles = imgfiles[::skip]
        imgs = imgs[::skip]

        maskd = '{}/{}/{}/mask'.format(basedir, scene, split)
        if os.path.isdir(maskd):
            logger.info('Loading mask from: {}'.format(maskd))
            maskfiles = ['{}/{}'.format(maskd, f)
                         for f in sorted(os.listdir(maskd)) if f.endswith('png') or f.endswith('jpg')]
            masks = [imageio.imread(f).astype(np.float32) / 255. for f in maskfiles][::skip]
        else:
            masks = [None for im in imgs]

        # load min_depth map, rescaled by the split's max_depth; +1e-4 avoids zero depth
        min_depthd = '{}/{}/{}/min_depth'.format(basedir, scene, split)
        if os.path.isdir(min_depthd):
            logger.info('Loading min_depth from: {}'.format(min_depthd))
            max_depth = float(open('{}/{}/{}/max_depth.txt'.format(basedir, scene, split)).readline().strip())
            min_depthfiles = ['{}/{}'.format(min_depthd, f)
                              for f in sorted(os.listdir(min_depthd)) if f.endswith('png') or f.endswith('jpg')]
            min_depths = [imageio.imread(f).astype(np.float32) / 255. * max_depth + 1e-4
                          for f in min_depthfiles][::skip]
        else:
            min_depths = [None for im in imgs]

        return imgfiles, imgs, masks, min_depths

    # cameras: intrinsics + poses per split (test/validation subsampled)
    intrinsics = dir2mats('{}/{}/train/intrinsics'.format(basedir, scene))
    testintrinsics = dir2mats('{}/{}/test/intrinsics'.format(basedir, scene))[::testskip]
    valintrinsics = dir2mats('{}/{}/validation/intrinsics'.format(basedir, scene))[::testskip]
    print(intrinsics.shape, testintrinsics.shape, valintrinsics.shape)

    poses = dir2mats('{}/{}/train/pose'.format(basedir, scene))
    testposes = dir2mats('{}/{}/test/pose'.format(basedir, scene))[::testskip]
    valposes = dir2mats('{}/{}/validation/pose'.format(basedir, scene))[::testskip]
    print(poses.shape, testposes.shape, valposes.shape)

    imgfiles, imgs, masks, min_depths = load_split('train', 1)
    testimgfiles, testimgs, testmasks, test_min_depths = load_split('test', testskip)
    valimgfiles, valimgs, valmasks, val_min_depths = load_split('validation', testskip)

    # data format for training/testing: train first, then validation, then test
    print(len(imgs), len(testimgs), len(valimgs))
    all_imgs = imgs + valimgs + testimgs
    all_masks = masks + valmasks + testmasks
    all_min_depths = min_depths + val_min_depths + test_min_depths
    all_paths = imgfiles + valimgfiles + testimgfiles

    counts = [0] + [len(x) for x in [imgs, valimgs, testimgs]]
    counts = np.cumsum(counts)
    i_split = [list(np.arange(counts[i], counts[i+1])) for i in range(3)]

    intrinsics = np.concatenate([intrinsics, valintrinsics, testintrinsics], 0)
    poses = np.concatenate([poses, valposes, testposes], 0)
    img_sizes = np.stack([np.array(x.shape[:2]) for x in all_imgs], axis=0)  # [H, W]
    cnt = len(all_imgs)
    # each camera row: [H, W, intrinsics(16), c2w pose(16)]
    all_cams = np.concatenate((img_sizes.astype(dtype=np.float32), intrinsics.reshape((cnt, -1)), poses.reshape((cnt, -1))), axis=1)

    if os.path.isdir('{}/{}/camera_path/intrinsics'.format(basedir, scene)):
        camera_path_intrinsics = dir2mats('{}/{}/camera_path/intrinsics'.format(basedir, scene))
        camera_path_poses = dir2mats('{}/{}/camera_path/pose'.format(basedir, scene))
        # render cameras reuse the first training image size
        # (assumes centered principal points — TODO confirm against the dataset)
        H = all_cams[0, 0]
        W = all_cams[0, 1]
        img_sizes = np.stack((np.ones_like(camera_path_intrinsics[:, 1, 2])*H,
                              np.ones_like(camera_path_intrinsics[:, 0, 2])*W), axis=1)  # [H, W]

        cnt = len(camera_path_intrinsics)
        render_cams = np.concatenate(
            (img_sizes.astype(dtype=np.float32), camera_path_intrinsics.reshape((cnt, -1)), camera_path_poses.reshape((cnt, -1))),
            axis=1)
    else:
        render_cams = None

    print(all_cams.shape)

    data = OrderedDict([('images', all_imgs),
                        ('masks', all_masks),
                        ('paths', all_paths),
                        ('min_depths', all_min_depths),
                        ('cameras', all_cams),
                        ('i_train', i_split[0]),
                        ('i_val', i_split[1]),
                        ('i_test', i_split[2]),
                        ('render_cams', render_cams)])

    logger.info('Data statistics:')
    logger.info('\t # of training views: {}'.format(len(data['i_train'])))
    logger.info('\t # of validation views: {}'.format(len(data['i_val'])))
    logger.info('\t # of test views: {}'.format(len(data['i_test'])))
    if data['render_cams'] is not None:
        logger.info('\t # of render cameras: {}'.format(len(data['render_cams'])))

    return data
|
617
old_scripts/ddp_run_nerf.py
Normal file
617
old_scripts/ddp_run_nerf.py
Normal file
|
@ -0,0 +1,617 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim
|
||||
import torch.distributed
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import torch.multiprocessing
|
||||
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from ddp_model import NerfNet
|
||||
import time
|
||||
# from data_loader import load_data
|
||||
# from nerf_sample_ray import RaySamplerSingleImage
|
||||
|
||||
from data_loader_split import load_data_split
|
||||
|
||||
import numpy as np
|
||||
from tensorboardX import SummaryWriter
|
||||
from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__package__)
|
||||
|
||||
|
||||
def setup_logger():
    """Attach a DEBUG-level console handler with a timestamped format to the package logger."""
    pkg_logger = logging.getLogger(__package__)
    pkg_logger.setLevel(logging.DEBUG)

    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(
        logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s'))

    pkg_logger.addHandler(console)
|
||||
|
||||
|
||||
def intersect_sphere(ray_o, ray_d):
    '''
    Depth at which each ray exits the unit sphere.

    ray_o, ray_d: [..., 3]
    :return: [...,] ray-parameter depth of the far intersection with the unit sphere
    '''
    # depth of the point on the ray closest to the origin
    # (negative if that point lies behind the camera)
    d_mid = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
    closest = ray_o + d_mid.unsqueeze(-1) * ray_d
    # half-chord from the closest point to the sphere surface, converted from
    # euclidean units to ray-parameter units; NaN if the ray misses the sphere
    inv_norm = 1. / torch.norm(ray_d, dim=-1)
    d_half = torch.sqrt(1. - torch.sum(closest * closest, dim=-1)) * inv_norm
    return d_mid + d_half
|
||||
|
||||
|
||||
def perturb_samples(z_vals):
    """Jitter each depth sample uniformly within its bin (stratified sampling)."""
    # bin edges: midpoints between neighbors, clamped at the original endpoints
    mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
    hi = torch.cat([mids, z_vals[..., -1:]], dim=-1)
    lo = torch.cat([z_vals[..., 0:1], mids], dim=-1)
    # one uniform draw per bin, shape [N_rays, N_samples]
    return lo + (hi - lo) * torch.rand_like(z_vals)
|
||||
|
||||
|
||||
def sample_pdf(bins, weights, N_samples, det=False):
    '''
    Inverse-transform sampling: draw depths distributed according to per-bin
    weights (hierarchical importance sampling along a ray).

    :param bins: tensor of shape [..., M+1], M is the number of bins
    :param weights: tensor of shape [..., M]
    :param N_samples: number of samples along each ray
    :param det: if True, will perform deterministic sampling
    :return: [..., N_samples]
    '''
    # Get pdf
    weights = weights + TINY_NUMBER      # prevent nans
    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)    # [..., M]
    cdf = torch.cumsum(pdf, dim=-1)                             # [..., M]
    # prepend 0 so the cdf spans the M+1 bin edges
    cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1)     # [..., M+1]

    # Take uniform samples in cdf space
    dots_sh = list(weights.shape[:-1])
    M = weights.shape[-1]

    min_cdf = 0.00
    max_cdf = 1.00       # prevent outlier samples

    if det:
        # evenly spaced cdf values -> reproducible sample positions
        u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device)
        u = u.view([1]*len(dots_sh) + [N_samples]).expand(dots_sh + [N_samples,])   # [..., N_samples]
    else:
        sh = dots_sh + [N_samples]
        u = torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf        # [..., N_samples]

    # Invert CDF: for each u, count how many cdf edges it passes
    # [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,]
    above_inds = torch.sum(u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1).long()

    # gather the (below, above) edge pair bracketing each sample
    below_inds = torch.clamp(above_inds-1, min=0)
    inds_g = torch.stack((below_inds, above_inds), dim=-1)     # [..., N_samples, 2]

    cdf = cdf.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])   # [..., N_samples, M+1]
    cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g)       # [..., N_samples, 2]

    bins = bins.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])    # [..., N_samples, M+1]
    bins_g = torch.gather(input=bins, dim=-1, index=inds_g)  # [..., N_samples, 2]

    # fix numeric issue: guard against zero-width cdf intervals
    denom = cdf_g[..., 1] - cdf_g[..., 0]      # [..., N_samples]
    denom = torch.where(denom<TINY_NUMBER, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom

    # linear interpolation inside the bracketing bin
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + TINY_NUMBER)

    return samples
|
||||
|
||||
|
||||
def render_single_image(rank, world_size, models, ray_sampler, chunk_size):
    '''
    Render one full image with all DDP processes cooperating.

    The image's rays are split evenly across ranks, each rank renders its share
    in chunks of `chunk_size` through every cascade level, and rank 0 gathers
    and reassembles the full image.

    :param rank: this process's rank (also used as the CUDA device index)
    :param world_size: total number of processes
    :param models: dict with 'cascade_level', 'cascade_samples' and 'net_{m}' entries
    :param ray_sampler: RaySamplerSingleImage-like object providing get_all(), .H, .W
    :param chunk_size: number of rays per forward chunk
    :return: on rank 0, a list (one OrderedDict per cascade level) of [H, W, ...]
        tensors; on other ranks, None
    '''
    ##### parallel rendering of a single image
    ray_batch = ray_sampler.get_all()
    # split into ranks; make sure different processes don't overlap
    # (last rank absorbs the remainder)
    rank_split_sizes = [ray_batch['ray_d'].shape[0] // world_size, ] * world_size
    rank_split_sizes[-1] = ray_batch['ray_d'].shape[0] - sum(rank_split_sizes[:-1])
    for key in ray_batch:
        if torch.is_tensor(ray_batch[key]):
            ray_batch[key] = torch.split(ray_batch[key], rank_split_sizes)[rank].to(rank)

    # split this rank's share into chunks to bound peak memory
    ray_batch_split = OrderedDict()
    for key in ray_batch:
        if torch.is_tensor(ray_batch[key]):
            ray_batch_split[key] = torch.split(ray_batch[key], chunk_size)

    # forward each chunk through every cascade level (inference only, no_grad below)
    ret_merge_chunk = [OrderedDict() for _ in range(models['cascade_level'])]
    for s in range(len(ray_batch_split['ray_d'])):
        ray_o = ray_batch_split['ray_o'][s]
        ray_d = ray_batch_split['ray_d'][s]
        min_depth = ray_batch_split['min_depth'][s]

        dots_sh = list(ray_d.shape[:-1])
        for m in range(models['cascade_level']):
            net = models['net_{}'.format(m)]
            # sample depths
            N_samples = models['cascade_samples'][m]
            if m == 0:
                # coarse level: uniform foreground samples between the per-ray
                # min depth and the unit-sphere exit depth
                fg_far_depth = intersect_sphere(ray_o, ray_d)  # [...,]
                # fg_near_depth = 0.18 * torch.ones_like(fg_far_depth)
                fg_near_depth = min_depth  # [..., 3]
                step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
                fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1)  # [..., N_samples]

                # background depth: uniform in the inverted [0, 1] parameterization
                bg_depth = torch.linspace(0., 1., N_samples).view(
                    [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank)

                # delete unused memory
                del fg_near_depth
                del step
                torch.cuda.empty_cache()
            else:
                # fine level: importance-sample from the previous level's weights
                # (ret carries over from the previous loop iteration)
                fg_weights = ret['fg_weights'].clone().detach()
                fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1])    # [..., N_samples-1]
                fg_weights = fg_weights[..., 1:-1]                              # [..., N_samples-2]
                fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
                                              N_samples=N_samples, det=True)    # [..., N_samples]
                fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))

                # same importance sampling for the background depths
                bg_weights = ret['bg_weights'].clone().detach()
                bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
                bg_weights = bg_weights[..., 1:-1]                              # [..., N_samples-2]
                bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
                                              N_samples=N_samples, det=True)    # [..., N_samples]
                bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))

                # delete unused memory
                del fg_weights
                del fg_depth_mid
                del fg_depth_samples
                del bg_weights
                del bg_depth_mid
                del bg_depth_samples
                torch.cuda.empty_cache()

            with torch.no_grad():
                ret = net(ray_o, ray_d, fg_far_depth, fg_depth, bg_depth)

            # accumulate per-level outputs on cpu; weights are kept in `ret`
            # because the next level's sampling still needs them
            for key in ret:
                if key not in ['fg_weights', 'bg_weights']:
                    if torch.is_tensor(ret[key]):
                        if key not in ret_merge_chunk[m]:
                            ret_merge_chunk[m][key] = [ret[key].cpu(), ]
                        else:
                            ret_merge_chunk[m][key].append(ret[key].cpu())

                        ret[key] = None

            # clean unused memory
            torch.cuda.empty_cache()

    # merge results from different chunks
    for m in range(len(ret_merge_chunk)):
        for key in ret_merge_chunk[m]:
            ret_merge_chunk[m][key] = torch.cat(ret_merge_chunk[m][key], dim=0)

    # merge results from different processes
    if rank == 0:
        ret_merge_rank = [OrderedDict() for _ in range(len(ret_merge_chunk))]
        for m in range(len(ret_merge_chunk)):
            for key in ret_merge_chunk[m]:
                # generate tensors to store results from other processes
                sh = list(ret_merge_chunk[m][key].shape[1:])
                ret_merge_rank[m][key] = [torch.zeros(*[size,]+sh, dtype=torch.float32) for size in rank_split_sizes]
                torch.distributed.gather(ret_merge_chunk[m][key], ret_merge_rank[m][key])
                # stitch the rank pieces back into an [H, W, ...] image
                ret_merge_rank[m][key] = torch.cat(ret_merge_rank[m][key], dim=0).reshape(
                    (ray_sampler.H, ray_sampler.W, -1)).squeeze()
                # print(m, key, ret_merge_rank[m][key].shape)
    else:  # send results to main process
        for m in range(len(ret_merge_chunk)):
            for key in ret_merge_chunk[m]:
                torch.distributed.gather(ret_merge_chunk[m][key])

    # only rank 0 program returns
    if rank == 0:
        return ret_merge_rank
    else:
        return None
|
||||
|
||||
|
||||
def log_view_to_tb(writer, global_step, log_data, gt_img, mask, prefix=''):
    """Write the ground-truth image plus per-cascade-level renders (rgb, fg/bg
    rgb, fg/bg depth, background lambda) to tensorboard."""
    writer.add_image(prefix + 'rgb_gt', img_HWC2CHW(torch.from_numpy(gt_img)), global_step)

    def log_rgb(tag, im):
        # clamp just in case diffuse+specular > 1
        writer.add_image(tag, torch.clamp(img_HWC2CHW(im), min=0., max=1.), global_step)

    def log_colorized(tag, im, cmap):
        writer.add_image(tag,
                         img_HWC2CHW(colorize(im, cmap_name=cmap, append_cbar=True, mask=mask)),
                         global_step)

    for m, level in enumerate(log_data):
        tag = prefix + 'level_{}/'.format(m)
        log_rgb(tag + 'rgb', level['rgb'])
        log_rgb(tag + 'fg_rgb', level['fg_rgb'])
        log_colorized(tag + 'fg_depth', level['fg_depth'], 'jet')
        log_rgb(tag + 'bg_rgb', level['bg_rgb'])
        log_colorized(tag + 'bg_depth', level['bg_depth'], 'jet')
        log_colorized(tag + 'bg_lambda', level['bg_lambda'], 'hot')
|
||||
|
||||
|
||||
def setup(rank, world_size):
    """Initialize the distributed process group for this worker.

    :param rank: rank of this process (also its cuda device index)
    :param world_size: total number of processes
    """
    # Respect an externally configured rendezvous (e.g. from a cluster
    # launcher); only fall back to the single-machine defaults if unset.
    # The original unconditionally overwrote MASTER_ADDR/MASTER_PORT.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    # initialize the process group
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
|
||||
|
||||
|
||||
def cleanup():
    """Tear down the distributed process group created by setup()."""
    torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def ddp_train_nerf(rank, args):
    """Per-process DDP training loop for NeRF with fg/bg cascade levels.

    Spawned once per GPU by torch.multiprocessing.spawn; `rank` indexes both
    the process and the cuda device it drives. Only rank 0 writes the config
    dump, tensorboard summaries and checkpoints; all ranks train and take part
    in the distributed validation renders.

    Fix vs. original: each state dict was loaded twice when resuming from a
    checkpoint; the redundant second load_state_dict call is removed. The
    large commented-out legacy data-loading block is dropped.
    """
    ###### set up multi-processing
    setup(rank, args.world_size)
    ###### set up logger
    logger = logging.getLogger(__package__)
    setup_logger()

    ###### create log dir and copy the config file (rank 0 only)
    if rank == 0:
        os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True)
        f = os.path.join(args.basedir, args.expname, 'args.txt')
        with open(f, 'w') as file:
            for arg in sorted(vars(args)):
                attr = getattr(args, arg)
                file.write('{} = {}\n'.format(arg, attr))
        if args.config is not None:
            f = os.path.join(args.basedir, args.expname, 'config.txt')
            with open(f, 'w') as file:
                file.write(open(args.config, 'r').read())
    # make sure the log dir exists before any other rank proceeds
    torch.distributed.barrier()

    ###### load data and create ray samplers; each process should do this
    ray_samplers = load_data_split(args.datadir, args.scene, split='train')
    val_ray_samplers = load_data_split(args.datadir, args.scene, split='validation')

    ###### create network and wrap in ddp; each process should do this
    # fix random seed just to make sure the network is initialized with same weights at different processes
    torch.manual_seed(777)
    # very important!!! otherwise it might introduce extra memory in rank=0 gpu
    torch.cuda.set_device(rank)

    models = OrderedDict()
    models['cascade_level'] = args.cascade_level
    models['cascade_samples'] = [int(x.strip()) for x in args.cascade_samples.split(',')]
    for m in range(models['cascade_level']):
        net = NerfNet(args).to(rank)
        net = DDP(net, device_ids=[rank], output_device=rank)
        optim = torch.optim.Adam(net.parameters(), lr=args.lrate)
        models['net_{}'.format(m)] = net
        models['optim_{}'.format(m)] = optim

    start = -1

    ###### load pretrained weights; each process should do this
    if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)):
        ckpts = [args.ckpt_path]
    else:
        ckpts = [os.path.join(args.basedir, args.expname, f)
                 for f in sorted(os.listdir(os.path.join(args.basedir, args.expname))) if f.endswith('.pth')]

    def path2iter(path):
        # checkpoints are named model_<iteration>.pth; extract <iteration>
        tmp = os.path.basename(path)[:-4]
        idx = tmp.rfind('_')
        return int(tmp[idx + 1:])

    ckpts = sorted(ckpts, key=path2iter)
    logger.info('Found ckpts: {}'.format(ckpts))
    if len(ckpts) > 0 and not args.no_reload:
        fpath = ckpts[-1]
        logger.info('Reloading from: {}'.format(fpath))
        start = path2iter(fpath)
        # configure map_location properly for different processes
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        to_load = torch.load(fpath, map_location=map_location)
        for m in range(models['cascade_level']):
            for name in ['net_{}'.format(m), 'optim_{}'.format(m)]:
                models[name].load_state_dict(to_load[name])

    ##### important!!!
    # make sure different processes sample different rays
    np.random.seed((rank + 1) * 777)
    # make sure different processes have different perturbations in depth samples
    torch.manual_seed((rank + 1) * 777)

    ##### only main process should do the logging
    if rank == 0:
        writer = SummaryWriter(os.path.join(args.basedir, 'summaries', args.expname))

    # start training
    what_val_to_log = 0  # helper variable for parallel rendering of a image
    what_train_to_log = 0
    for global_step in range(start + 1, start + 1 + args.N_iters):
        time0 = time.time()
        scalars_to_log = OrderedDict()
        ### start of core optimization loop
        scalars_to_log['resolution'] = ray_samplers[0].resolution_level
        # randomly sample rays and move to device
        i = np.random.randint(low=0, high=len(ray_samplers))
        ray_batch = ray_samplers[i].random_sample(args.N_rand, center_crop=False)
        for key in ray_batch:
            if torch.is_tensor(ray_batch[key]):
                ray_batch[key] = ray_batch[key].to(rank)

        # forward and backward
        dots_sh = list(ray_batch['ray_d'].shape[:-1])  # number of rays
        all_rets = []  # results on different cascade levels
        for m in range(models['cascade_level']):
            optim = models['optim_{}'.format(m)]
            net = models['net_{}'.format(m)]

            # sample depths
            N_samples = models['cascade_samples'][m]
            if m == 0:
                # level 0: uniform samples between near depth and sphere intersection
                fg_far_depth = intersect_sphere(ray_batch['ray_o'], ray_batch['ray_d'])  # [...,]
                fg_near_depth = ray_batch['min_depth']  # [...,]
                step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
                fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1)  # [..., N_samples]
                fg_depth = perturb_samples(fg_depth)  # random perturbation during training

                # background depth: uniform in the inverted-sphere parameterization
                bg_depth = torch.linspace(0., 1., N_samples).view(
                    [1, ] * len(dots_sh) + [N_samples, ]).expand(dots_sh + [N_samples, ]).to(rank)
                bg_depth = perturb_samples(bg_depth)  # random perturbation during training
            else:
                # later levels: importance-sample the pdf of the previous level's
                # weights and concat with earlier samples
                fg_weights = ret['fg_weights'].clone().detach()
                fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1])  # [..., N_samples-1]
                fg_weights = fg_weights[..., 1:-1]  # [..., N_samples-2]
                fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
                                              N_samples=N_samples, det=False)  # [..., N_samples]
                fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))

                bg_weights = ret['bg_weights'].clone().detach()
                bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
                bg_weights = bg_weights[..., 1:-1]  # [..., N_samples-2]
                bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
                                              N_samples=N_samples, det=False)  # [..., N_samples]
                bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))

            optim.zero_grad()
            ret = net(ray_batch['ray_o'], ray_batch['ray_d'], fg_far_depth, fg_depth, bg_depth)
            all_rets.append(ret)

            rgb_gt = ray_batch['rgb'].to(rank)
            loss = img2mse(ret['rgb'], rgb_gt)
            scalars_to_log['level_{}/loss'.format(m)] = loss.item()
            scalars_to_log['level_{}/pnsr'.format(m)] = mse2psnr(loss.item())
            # regularize sigma with photo-consistency
            loss = loss + img2mse(ret['diffuse_rgb'], rgb_gt)
            loss.backward()
            optim.step()

        ### end of core optimization loop
        dt = time.time() - time0
        scalars_to_log['iter_time'] = dt

        ### only main process should do the logging
        if rank == 0 and (global_step % args.i_print == 0 or global_step < 10):
            logstr = '{} step: {} '.format(args.expname, global_step)
            for k in scalars_to_log:
                logstr += ' {}: {:.6f}'.format(k, scalars_to_log[k])
                writer.add_scalar(k, scalars_to_log[k], global_step)
            logger.info(logstr)

        ### each process should do this; but only main process merges the results
        if global_step % args.i_img == 0 or global_step == start + 1:
            #### critical: make sure each process is working on the same random image
            time0 = time.time()
            idx = what_val_to_log % len(val_ray_samplers)
            log_data = render_single_image(rank, args.world_size, models, val_ray_samplers[idx], args.chunk_size)
            what_val_to_log += 1
            dt = time.time() - time0
            if rank == 0:  # only main process should do this
                logger.info('Logged a random validation view in {} seconds'.format(dt))
                log_view_to_tb(writer, global_step, log_data, gt_img=val_ray_samplers[idx].img_orig, mask=None, prefix='val/')

            time0 = time.time()
            idx = what_train_to_log % len(ray_samplers)
            log_data = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size)
            what_train_to_log += 1
            dt = time.time() - time0
            if rank == 0:  # only main process should do this
                logger.info('Logged a random training view in {} seconds'.format(dt))
                log_view_to_tb(writer, global_step, log_data, gt_img=ray_samplers[idx].img_orig, mask=None, prefix='train/')

            # drop the rendered image and release cached gpu memory
            log_data = None
            torch.cuda.empty_cache()

        if rank == 0 and (global_step % args.i_weights == 0 and global_step > 0):
            # saving checkpoints and logging
            fpath = os.path.join(args.basedir, args.expname, 'model_{:06d}.pth'.format(global_step))
            to_save = OrderedDict()
            for m in range(models['cascade_level']):
                name = 'net_{}'.format(m)
                to_save[name] = models[name].state_dict()

                name = 'optim_{}'.format(m)
                to_save[name] = models[name].state_dict()
            torch.save(to_save, fpath)

    # clean up for multi-processing
    cleanup()
|
||||
|
||||
|
||||
def config_parser():
    """Build the configargparse parser holding every training/rendering option.

    :return: a configargparse.ArgumentParser (call .parse_args() on it)
    """
    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True, help='config file path')
    parser.add_argument("--expname", type=str, help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs')

    # dataset options
    parser.add_argument("--datadir", type=str, default=None, help='input data directory')
    parser.add_argument("--scene", type=str, default=None, help='scene name')
    parser.add_argument("--testskip", type=int, default=8,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

    # model size
    parser.add_argument("--netdepth", type=int, default=8, help='layers in coarse network')
    parser.add_argument("--netwidth", type=int, default=256, help='channels per layer in coarse network')
    parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D')

    # checkpoints
    parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
    parser.add_argument("--ckpt_path", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')

    # batch size
    parser.add_argument("--N_rand", type=int, default=32 * 32 * 2,
                        help='batch size (number of random rays per gradient step)')
    parser.add_argument("--chunk_size", type=int, default=1024 * 8,
                        help='number of rays processed in parallel, decrease if running out of memory')

    # iterations
    parser.add_argument("--N_iters", type=int, default=250001,
                        help='number of iterations')

    # cascade training
    parser.add_argument("--cascade_level", type=int, default=2,
                        help='number of cascade levels')
    parser.add_argument("--cascade_samples", type=str, default='64,64',
                        help='samples at each level')
    parser.add_argument("--devices", type=str, default='0,1',
                        help='cuda device for each level')
    parser.add_argument("--bg_devices", type=str, default='0,2',
                        help='cuda device for the background of each level')

    # number of DDP processes; -1 is the sentinel meaning "use all visible gpus".
    # Use an int default (the original passed the string '-1' and relied on
    # argparse converting string defaults through `type`).
    parser.add_argument("--world_size", type=int, default=-1,
                        help='number of processes')

    # mixed precison training
    parser.add_argument("--opt_level", type=str, default='O1',
                        help='mixed precison training')

    parser.add_argument("--near_depth", type=float, default=0.1,
                        help='near depth plane')
    parser.add_argument("--far_depth", type=float, default=50.,
                        help='far depth plane')

    # learning rate options
    parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
    parser.add_argument("--lrate_decay_factor", type=float, default=0.1,
                        help='decay learning rate by a factor every specified number of steps')
    parser.add_argument("--lrate_decay_steps", type=int, default=5000,
                        help='decay learning rate by a factor every specified number of steps')

    # rendering options
    parser.add_argument("--inv_uniform", action='store_true',
                        help='if True, will uniformly sample inverse depths')
    parser.add_argument("--det", action='store_true', help='deterministic sampling for coarse and fine samples')
    parser.add_argument("--max_freq_log2", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--max_freq_log2_viewdirs", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--N_iters_perturb", type=int, default=1000,
                        help='perturb and center-crop at first 1000 iterations to prevent training from getting stuck')
    parser.add_argument("--raw_noise_std", type=float, default=1.,
                        help='std dev of noise added to regularize sigma output, 1e0 recommended')
    parser.add_argument("--white_bkgd", action='store_true',
                        help='apply the trick to avoid fitting to white background')

    # no training; render only
    parser.add_argument("--render_only", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--render_train", action='store_true', help='render the training set')
    parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path')

    # no training; extract mesh only
    parser.add_argument("--mesh_only", action='store_true',
                        help='do not optimize, extract mesh from pretrained model')
    parser.add_argument("--N_pts", type=int, default=256,
                        help='voxel resolution; N_pts * N_pts * N_pts')
    parser.add_argument("--mesh_thres", type=str, default='10,20,30,40,50',
                        help='threshold(s) for mesh extraction; can use multiple thresholds')

    # logging/saving options
    parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin')
    parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=50000, help='frequency of testset saving')
    parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving')

    return parser
|
||||
|
||||
|
||||
def train():
    """Parse command-line/config options and launch one DDP worker per GPU."""
    parser = config_parser()
    args = parser.parse_args()
    logger.info(parser.format_values())

    # -1 is the sentinel for "use every visible gpu"
    if args.world_size == -1:
        args.world_size = torch.cuda.device_count()
        logger.info('Using # gpus: {}'.format(args.world_size))
    torch.multiprocessing.spawn(ddp_train_nerf, args=(args,), nprocs=args.world_size, join=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # configure logging in the parent process before spawning DDP workers
    setup_logger()
    train()
|
||||
|
||||
|
231
old_scripts/nerf_sample_ray.py
Normal file
231
old_scripts/nerf_sample_ray.py
Normal file
|
@ -0,0 +1,231 @@
|
|||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
import cv2
|
||||
|
||||
|
||||
########################################################################################################################
|
||||
# ray batch sampling
|
||||
########################################################################################################################
|
||||
|
||||
def parse_camera(params):
    """Unpack a flat camera-parameter vector.

    Layout of `params`: [H, W, 16 intrinsic entries, 16 camera-to-world entries].

    :param params: flat numpy array of length >= 34
    :return: (W, H, intrinsics, c2w) where W/H are ints and both matrices are
             float32 arrays of shape (4, 4)
    """
    H, W = params[0], params[1]
    intrinsics = np.asarray(params[2:18], dtype=np.float32).reshape(4, 4)
    c2w = np.asarray(params[18:34], dtype=np.float32).reshape(4, 4)
    return int(W), int(H), intrinsics, c2w
|
||||
|
||||
|
||||
def get_rays_single_image(H, W, intrinsics, c2w):
    '''
    Cast one ray through the center of every pixel of an H x W image.

    :param H: image height
    :param W: image width
    :param intrinsics: 4 by 4 intrinsic matrix
    :param c2w: 4 by 4 camera to world extrinsic matrix
    :return: (rays_o, rays_d, depth) with shapes (H*W, 3), (H*W, 3), (H*W,)
    '''
    # pixel-center homogeneous coordinates, flattened row-major
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    u = u.reshape(-1).astype(np.float32) + 0.5  # add half pixel
    v = v.reshape(-1).astype(np.float32) + 0.5
    pixels = np.stack((u, v, np.ones_like(u)), axis=0)  # (3, H*W)

    # unproject through the inverse intrinsics, then rotate into world space
    dirs = np.linalg.inv(intrinsics[:3, :3]).dot(pixels)
    dirs = c2w[:3, :3].dot(dirs)
    dirs = dirs.transpose((1, 0))  # (H*W, 3)

    # every ray starts at the camera center
    origins = np.tile(c2w[:3, 3].reshape((1, 3)), (dirs.shape[0], 1))  # (H*W, 3)

    # camera-frame z of the world origin, replicated per ray
    depth = np.linalg.inv(c2w)[2, 3] * np.ones((origins.shape[0],), dtype=np.float32)  # (H*W,)

    return origins, dirs, depth
|
||||
|
||||
|
||||
class RaySamplerSingleImage(object):
    """Holds the rays of one posed image and serves full or random ray batches."""

    def __init__(self, cam_params, img_path=None, img=None, resolution_level=1, mask=None, min_depth=None):
        super().__init__()
        self.W_orig, self.H_orig, self.intrinsics_orig, self.c2w_mat = parse_camera(cam_params)

        self.img_path = img_path
        self.img_orig = img
        self.mask_orig = mask
        self.min_depth_orig = min_depth

        # sentinel so the first set_resolution_level() call always runs
        self.resolution_level = -1
        self.set_resolution_level(resolution_level)

    def set_resolution_level(self, resolution_level):
        """Downscale image/mask/min-depth and regenerate rays; no-op if unchanged."""
        if resolution_level == self.resolution_level:
            return
        self.resolution_level = resolution_level
        self.W = self.W_orig // resolution_level
        self.H = self.H_orig // resolution_level
        # scale the intrinsics to match the downsampled resolution
        self.intrinsics = np.copy(self.intrinsics_orig)
        self.intrinsics[:2, :3] /= resolution_level

        if self.img_orig is None:
            self.img = None
        else:
            resized = cv2.resize(self.img_orig, (self.W, self.H), interpolation=cv2.INTER_AREA)
            self.img = resized.reshape((-1, 3))

        if self.mask_orig is None:
            self.mask = None
        else:
            resized = cv2.resize(self.mask_orig, (self.W, self.H), interpolation=cv2.INTER_NEAREST)
            self.mask = resized.reshape((-1))

        if self.min_depth_orig is None:
            self.min_depth = None
        else:
            resized = cv2.resize(self.min_depth_orig, (self.W, self.H), interpolation=cv2.INTER_LINEAR)
            self.min_depth = resized.reshape((-1))

        self.rays_o, self.rays_d, self.depth = get_rays_single_image(self.H, self.W,
                                                                     self.intrinsics, self.c2w_mat)

    def _to_torch(self, ret):
        # convert the numpy entries to torch tensors in place, skipping None
        for k in ret:
            if ret[k] is not None:
                ret[k] = torch.from_numpy(ret[k])
        return ret

    def get_all(self):
        """Return every ray of the image as an OrderedDict of torch tensors."""
        min_depth = self.min_depth
        if min_depth is None:
            # fall back to a tiny uniform near plane
            min_depth = 1e-4 * np.ones_like(self.rays_d[..., 0])

        return self._to_torch(OrderedDict([
            ('ray_o', self.rays_o),
            ('ray_d', self.rays_d),
            ('depth', self.depth),
            ('rgb', self.img),
            ('mask', self.mask),
            ('min_depth', min_depth)
        ]))

    def random_sample(self, N_rand, center_crop=False):
        '''
        :param N_rand: number of rays to be casted
        :param center_crop: if True, sample only from the central half of the image
        :return: OrderedDict of torch tensors for the sampled rays
        '''
        if center_crop:
            half_H = self.H // 2
            half_W = self.W // 2
            quad_H = half_H // 2
            quad_W = half_W // 2

            # pixel coordinates of the central crop
            u, v = np.meshgrid(np.arange(half_W - quad_W, half_W + quad_W),
                               np.arange(half_H - quad_H, half_H + quad_H))
            u = u.reshape(-1)
            v = v.reshape(-1)

            select_inds = np.random.choice(u.shape[0], size=(N_rand,), replace=False)
            # convert crop coordinates back to flat image indices
            select_inds = v[select_inds] * self.W + u[select_inds]
        else:
            # uniform sampling over the whole image
            select_inds = np.random.choice(self.H * self.W, size=(N_rand,), replace=False)

        rays_o = self.rays_o[select_inds, :]  # [N_rand, 3]
        rays_d = self.rays_d[select_inds, :]  # [N_rand, 3]
        depth = self.depth[select_inds]       # [N_rand, ]

        rgb = self.img[select_inds, :] if self.img is not None else None
        mask = self.mask[select_inds] if self.mask is not None else None
        if self.min_depth is not None:
            min_depth = self.min_depth[select_inds]
        else:
            min_depth = 1e-4 * np.ones_like(rays_d[..., 0])

        return self._to_torch(OrderedDict([
            ('ray_o', rays_o),
            ('ray_d', rays_d),
            ('depth', depth),
            ('rgb', rgb),
            ('mask', mask),
            ('min_depth', min_depth)
        ]))
|
||||
|
||||
# def random_sample_patches(self, N_patch, r_patch=16, center_crop=False):
|
||||
# '''
|
||||
# :param N_patch: number of patches to be sampled
|
||||
# :param r_patch: patch size will be (2*r_patch+1)*(2*r_patch+1)
|
||||
# :return:
|
||||
# '''
|
||||
# # even size patch
|
||||
# # offsets to center pixels
|
||||
# u, v = np.meshgrid(np.arange(-r_patch, r_patch),
|
||||
# np.arange(-r_patch, r_patch))
|
||||
# u = u.reshape(-1)
|
||||
# v = v.reshape(-1)
|
||||
# offsets = v * self.W + u
|
||||
|
||||
# # center pixel coordinates
|
||||
# u_min = r_patch
|
||||
# u_max = self.W - r_patch
|
||||
# v_min = r_patch
|
||||
# v_max = self.H - r_patch
|
||||
# if center_crop:
|
||||
# u_min = self.W // 4 + r_patch
|
||||
# u_max = self.W - self.W // 4 - r_patch
|
||||
# v_min = self.H // 4 + r_patch
|
||||
# v_max = self.H - self.H // 4 - r_patch
|
||||
|
||||
# u, v = np.meshgrid(np.arange(u_min, u_max, r_patch),
|
||||
# np.arange(v_min, v_max, r_patch))
|
||||
# u = u.reshape(-1)
|
||||
# v = v.reshape(-1)
|
||||
|
||||
# select_inds = np.random.choice(u.shape[0], size=(N_patch,), replace=False)
|
||||
# # Convert back to original image
|
||||
# select_inds = v[select_inds] * self.W + u[select_inds]
|
||||
|
||||
# # pick patches
|
||||
# select_inds = np.stack([select_inds + shift for shift in offsets], axis=1)
|
||||
# select_inds = select_inds.reshape(-1)
|
||||
|
||||
# rays_o = self.rays_o[select_inds, :] # [N_rand, 3]
|
||||
# rays_d = self.rays_d[select_inds, :] # [N_rand, 3]
|
||||
# depth = self.depth[select_inds] # [N_rand, ]
|
||||
|
||||
# if self.img is not None:
|
||||
# rgb = self.img[select_inds, :] # [N_rand, 3]
|
||||
|
||||
# # ### debug
|
||||
# # import imageio
|
||||
# # imgs = rgb.reshape((N_patch, r_patch*2, r_patch*2, -1))
|
||||
# # for kk in range(imgs.shape[0]):
|
||||
# # imageio.imwrite('./debug_{}.png'.format(kk), imgs[kk])
|
||||
# # ###
|
||||
# else:
|
||||
# rgb = None
|
||||
|
||||
# ret = OrderedDict([
|
||||
# ('ray_o', rays_o),
|
||||
# ('ray_d', rays_d),
|
||||
# ('depth', depth),
|
||||
# ('rgb', rgb)
|
||||
# ])
|
||||
|
||||
# # return torch tensors
|
||||
# for k in ret:
|
||||
# ret[k] = torch.from_numpy(ret[k])
|
||||
|
||||
# return ret
|
19
render_all.sh
Executable file
19
render_all.sh
Executable file
|
@ -0,0 +1,19 @@
|
|||
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C pascal
#SBATCH --mem=40G
#SBATCH --time=24:00:00
#SBATCH --output=slurm_%A.out
#SBATCH --qos=high

# Slurm launcher: run ddp_test_nerf.py to render the tat_training_truck scene.
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR

#$PYTHON -u $CODE_DIR/ddp_test_nerf.py --config $CODE_DIR/configs/lf_data/lf_africa.txt


$PYTHON -u $CODE_DIR/ddp_test_nerf.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck.txt
|
24
render_tat_training_truck.sh
Executable file
24
render_tat_training_truck.sh
Executable file
|
@ -0,0 +1,24 @@
|
|||
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:3
#SBATCH -c 8
#SBATCH -C turing
#SBATCH --mem=16G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out


# Slurm launcher: render a camera path and images for tat_intermediate_playground
# (the addregularize truck variants are kept below, commented out).
PYTHON=/home/zhangka2/anaconda3/envs/nerf/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest
echo $CODE_DIR

#$PYTHON -u $CODE_DIR/run_nerf.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_addregularize.txt
#$PYTHON -u $CODE_DIR/nerf_render_path.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_addregularize.txt
#$PYTHON -u $CODE_DIR/nerf_render_image.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_addregularize.txt

$PYTHON -u $CODE_DIR/nerf_render_path.py --config $CODE_DIR/configs/tanks_and_temples/tat_intermediate_playground.txt
$PYTHON -u $CODE_DIR/nerf_render_image.py --config $CODE_DIR/configs/tanks_and_temples/tat_intermediate_playground.txt

#$PYTHON -u $CODE_DIR/nerf_render_path.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_addregularize.txt
#$PYTHON -u $CODE_DIR/nerf_render_image.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_addregularize.txt
|
16
sparse_playground.sh
Executable file
16
sparse_playground.sh
Executable file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=60G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out


# Slurm launcher: train the sparse playground scene (addparam config).
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR

$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples_sparse/tat_intermediate_playground_addparam.txt
|
16
sparse_playground_addcarve.sh
Executable file
16
sparse_playground_addcarve.sh
Executable file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=60G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out


# Slurm launcher: train the sparse playground scene (addcarve config).
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR

$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples_sparse/tat_intermediate_playground_addcarve.txt
|
16
sparse_playground_addregularize.sh
Executable file
16
sparse_playground_addregularize.sh
Executable file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=60G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out


# Slurm launcher: train the sparse playground scene (addregularize config).
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR

$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples_sparse/tat_intermediate_playground_addregularize.txt
|
16
sparse_truck.sh
Executable file
16
sparse_truck.sh
Executable file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=60G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out


# Slurm launcher: train the sparse truck scene (addparam config).
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR

$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples_sparse/tat_training_truck_addparam.txt
|
16
sparse_truck_addcarve.sh
Executable file
16
sparse_truck_addcarve.sh
Executable file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:4
#SBATCH -c 10
####SBATCH -C turing
#SBATCH --mem=60G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out


# Slurm launcher: train the sparse truck scene (addcarve config).
# NOTE: the turing constraint is disabled above (extra # characters).
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR

$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples_sparse/tat_training_truck_addcarve.txt
|
16
sparse_truck_addregularize.sh
Executable file
16
sparse_truck_addregularize.sh
Executable file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=60G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out


# Slurm launcher: train the sparse truck scene (addregularize config).
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR

$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples_sparse/tat_training_truck_addregularize.txt
|
16
train_lf_africa.sh
Executable file
16
train_lf_africa.sh
Executable file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:8
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=80G
#SBATCH --time=24:00:00
#SBATCH --output=slurm_%A.out
#SBATCH --qos=high

# Slurm launcher: train the light-field africa scene with ddp_run_nerf.py.
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR

$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/lf_data/lf_africa.txt
|
17
train_lf_basket.sh
Executable file
17
train_lf_basket.sh
Executable file
|
@ -0,0 +1,17 @@
|
|||
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:8
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=100G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out

######## #SBATCH --qos=high

# Slurm launcher: train the light-field basket scene with ddp_run_nerf.py.
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR

$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/lf_data/lf_basket.txt
|
16
train_lf_ship.sh
Executable file
16
train_lf_ship.sh
Executable file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=80G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
#SBATCH --qos=normal

# Slurm launcher: train the light-field ship scene with ddp_run_nerf.py.
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR

$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/lf_data/lf_ship.txt
|
16
train_lf_torch.sh
Executable file
16
train_lf_torch.sh
Executable file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:8
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=80G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
#SBATCH --qos=normal

# Slurm batch script: DDP training run for the lf_torch config (8 GPUs).
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo "$CODE_DIR"

# Quote expansions so the script stays correct if paths ever contain spaces.
"$PYTHON" -u "$CODE_DIR/ddp_run_nerf.py" --config "$CODE_DIR/configs/lf_data/lf_torch.txt"
|
18
train_tat_intermediate_m60.sh
Executable file
18
train_tat_intermediate_m60.sh
Executable file
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:3
#SBATCH -c 8
#SBATCH -C turing
#SBATCH --mem=16G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out

# Slurm batch script: train on tat_intermediate_m60, then render images
# with the same config.
PYTHON=/home/zhangka2/anaconda3/envs/nerf/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest
echo "$CODE_DIR"

# Quote expansions so the script stays correct if paths ever contain spaces.
"$PYTHON" -u "$CODE_DIR/run_nerf.py" --config "$CODE_DIR/configs/tanks_and_temples/tat_intermediate_m60.txt"
"$PYTHON" -u "$CODE_DIR/nerf_render_image.py" --config "$CODE_DIR/configs/tanks_and_temples/tat_intermediate_m60.txt"
|
||||
|
18
train_tat_intermediate_playground.sh
Executable file
18
train_tat_intermediate_playground.sh
Executable file
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:3
#SBATCH -c 8
#SBATCH -C turing
#SBATCH --mem=16G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out

# Slurm batch script: train on tat_intermediate_playground, then render
# images with the same config.
PYTHON=/home/zhangka2/anaconda3/envs/nerf/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest
echo "$CODE_DIR"

# Quote expansions so the script stays correct if paths ever contain spaces.
"$PYTHON" -u "$CODE_DIR/run_nerf.py" --config "$CODE_DIR/configs/tanks_and_temples/tat_intermediate_playground.txt"
"$PYTHON" -u "$CODE_DIR/nerf_render_image.py" --config "$CODE_DIR/configs/tanks_and_temples/tat_intermediate_playground.txt"
|
||||
|
15
train_tat_intermediate_playground_bignet.sh
Executable file
15
train_tat_intermediate_playground_bignet.sh
Executable file
|
@ -0,0 +1,15 @@
|
|||
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:8
#SBATCH -c 25
#SBATCH -C turing
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out

# Slurm batch script: DDP training run for the playground "bignet" config
# (8 GPUs; no explicit --mem request).
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo "$CODE_DIR"

# Quote expansions so the script stays correct if paths ever contain spaces.
"$PYTHON" -u "$CODE_DIR/ddp_run_nerf.py" --config "$CODE_DIR/configs/tanks_and_temples/tat_intermediate_playground_bignet.txt"
|
18
train_tat_intermediate_train.sh
Executable file
18
train_tat_intermediate_train.sh
Executable file
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:3
#SBATCH -c 8
#SBATCH -C turing
#SBATCH --mem=16G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out

# Slurm batch script: train on tat_intermediate_train, then render images
# with the same config.
PYTHON=/home/zhangka2/anaconda3/envs/nerf/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest
echo "$CODE_DIR"

# Quote expansions so the script stays correct if paths ever contain spaces.
"$PYTHON" -u "$CODE_DIR/run_nerf.py" --config "$CODE_DIR/configs/tanks_and_temples/tat_intermediate_train.txt"
"$PYTHON" -u "$CODE_DIR/nerf_render_image.py" --config "$CODE_DIR/configs/tanks_and_temples/tat_intermediate_train.txt"
|
||||
|
16
train_tat_training_truck.sh
Executable file
16
train_tat_training_truck.sh
Executable file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=50G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out

# Slurm batch script: DDP training run for the tat_training_truck config (4 GPUs).
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo "$CODE_DIR"

# Quote expansions so the script stays correct if paths ever contain spaces.
"$PYTHON" -u "$CODE_DIR/ddp_run_nerf.py" --config "$CODE_DIR/configs/tanks_and_temples/tat_training_truck.txt"
|
15
train_tat_training_truck_bignet.sh
Executable file
15
train_tat_training_truck_bignet.sh
Executable file
|
@ -0,0 +1,15 @@
|
|||
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:8
#SBATCH -c 25
#SBATCH -C turing
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out

# Slurm batch script: DDP training run for the truck "bignet" config
# (8 GPUs; no explicit --mem request).
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo "$CODE_DIR"

# Quote expansions so the script stays correct if paths ever contain spaces.
"$PYTHON" -u "$CODE_DIR/ddp_run_nerf.py" --config "$CODE_DIR/configs/tanks_and_temples/tat_training_truck_bignet.txt"
|
19
train_tat_truck_sphere_subset.sh
Executable file
19
train_tat_truck_sphere_subset.sh
Executable file
|
@ -0,0 +1,19 @@
|
|||
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:3
#SBATCH -c 8
#SBATCH -C turing
#SBATCH --mem=16G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
#SBATCH --exclude=isl-gpu17

# Slurm batch script: train on tat_training_truck_subset, then render images
# with the same config.
# NOTE(review): unlike the other train scripts this one runs python without
# -u (unbuffered output) — confirm whether that is intentional.
PYTHON=/home/zhangka2/anaconda3/envs/nerf/bin/python

CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg
echo "$CODE_DIR"

# Quote expansions so the script stays correct if paths ever contain spaces.
"$PYTHON" "$CODE_DIR/run_nerf.py" --config "$CODE_DIR/configs/tanks_and_temples/tat_training_truck_subset.txt"
"$PYTHON" "$CODE_DIR/nerf_render_image.py" --config "$CODE_DIR/configs/tanks_and_temples/tat_training_truck_subset.txt"
|
||||
|
206
utils.py
Normal file
206
utils.py
Normal file
|
@ -0,0 +1,206 @@
|
|||
import torch
|
||||
# import torch.nn as nn
|
||||
# import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
HUGE_NUMBER = 1e10
|
||||
TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision
|
||||
|
||||
|
||||
# Misc utils
|
||||
# work on tensors
|
||||
# img2mse = lambda x, y: torch.mean((x - y) * (x - y))
|
||||
def img2mse(x, y, mask=None):
    """Mean squared error between images x and y.

    If mask is given, only masked pixels contribute: the squared error is
    weighted per pixel and normalized by the number of masked elements
    (mask count times the channel dimension), with TINY_NUMBER guarding
    against division by zero.
    """
    diff = x - y
    sq_err = diff * diff
    if mask is None:
        return torch.mean(sq_err)
    # Broadcast the mask across the trailing (channel) dimension.
    weighted = sq_err * mask.unsqueeze(-1)
    denom = torch.sum(mask) * x.shape[-1] + TINY_NUMBER
    return torch.sum(weighted) / denom
|
||||
|
||||
# PEP 8 (E731): use def instead of assigning a lambda to a name —
# same callables, but with proper names in tracebacks.
def img_HWC2CHW(x):
    """Permute an image tensor from (H, W, C) to (C, H, W)."""
    return x.permute(2, 0, 1)


def gray2rgb(x):
    """Expand an (H, W) grayscale tensor to (H, W, 3) by repeating a new channel dim."""
    return x.unsqueeze(2).repeat(1, 1, 3)
|
||||
|
||||
|
||||
def normalize(x):
    """Linearly rescale x to [0, 1].

    TINY_NUMBER in the denominator keeps the result finite when x is constant
    (min == max).
    """
    # Renamed locals: the original shadowed the builtins `min` and `max`.
    x_min = x.min()
    x_max = x.max()

    return (x - x_min) / ((x_max - x_min) + TINY_NUMBER)
|
||||
|
||||
|
||||
# PEP 8 (E731): use def instead of assigning a lambda to a name.
def to8b(x):
    """Convert a float array (expected range [0, 1]) to uint8 in [0, 255]."""
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)


# gray2rgb = lambda x: np.tile(x[:,:,np.newaxis], (1, 1, 3))
def mse2psnr(x):
    """Convert MSE to PSNR in dB; TINY_NUMBER avoids log(0)."""
    return -10. * np.log(x + TINY_NUMBER) / np.log(10.)
|
||||
|
||||
#
|
||||
# def normalize(x):
|
||||
# x_min, x_max = np.percentile(x, (0.5, 99.5))
|
||||
# x = np.clip(x, x_min, x_max)
|
||||
# x = (x - x_min) / (x_max - x_min)
|
||||
# return x
|
||||
|
||||
|
||||
########################################################################################################################
|
||||
#
|
||||
########################################################################################################################
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
from matplotlib.figure import Figure
|
||||
import matplotlib as mpl
|
||||
from matplotlib import cm
|
||||
import cv2
|
||||
|
||||
|
||||
def get_vertical_colorbar(h, vmin, vmax, cmap_name='jet', label=None):
    '''
    Render a vertical matplotlib colorbar as an RGB float image.

    :param h: target height in pixels; the rendered bar is resized to match
    :param vmin: minimum value of the colormap range
    :param vmax: maximum value of the colormap range
    :param cmap_name: matplotlib colormap name
    :param label: optional text label drawn next to the bar
    :return: float32 RGB image with values in [0, 1] and height h
    '''
    # Render off-screen via the Agg canvas; figure size/dpi fix the raw raster size.
    fig = Figure(figsize=(1.2, 8), dpi=100)
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)

    # Do some plotting.
    ax = fig.add_subplot(111)
    cmap = cm.get_cmap(cmap_name)  # NOTE(review): cm.get_cmap is deprecated in newer matplotlib
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    # Six evenly spaced ticks across [vmin, vmax], labeled to two decimals.
    tick_cnt = 6
    tick_loc = np.linspace(vmin, vmax, tick_cnt)
    cb1 = mpl.colorbar.ColorbarBase(ax, cmap=cmap,
                                    norm=norm,
                                    ticks=tick_loc,
                                    orientation='vertical')

    tick_label = ['{:3.2f}'.format(x) for x in tick_loc]
    cb1.set_ticklabels(tick_label)

    cb1.ax.tick_params(labelsize=18, rotation=0)

    if label is not None:
        cb1.set_label(label)

    fig.tight_layout()

    # # debug
    # fig.savefig("debug3.png")

    # Rasterize the figure into an RGBA byte buffer.
    canvas.draw()
    s, (width, height) = canvas.print_to_buffer()

    im = np.frombuffer(s, np.uint8).reshape((height, width, 4))

    # Drop the alpha channel and convert to float in [0, 1].
    im = im[:, :, :3].astype(np.float32) / 255.
    # Resize (preserving aspect ratio) so the bar matches the requested height.
    if h != im.shape[0]:
        w = int(im.shape[1] / im.shape[0] * h)
        im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)

    return im
|
||||
|
||||
|
||||
# def colorize_np(x, cmap_name='jet', append_cbar=False):
|
||||
# vmin = x.min()
|
||||
# vmax = x.max() + TINY_NUMBER
|
||||
# x = (x - vmin) / (vmax - vmin)
|
||||
# # x = np.clip(x, 0., 1.)
|
||||
|
||||
# cmap = cm.get_cmap(cmap_name)
|
||||
# x_new = cmap(x)[:, :, :3]
|
||||
|
||||
# cbar = get_vertical_colorbar(h=x.shape[0], vmin=vmin, vmax=vmax, cmap_name=cmap_name)
|
||||
|
||||
# if append_cbar:
|
||||
# x_new = np.concatenate((x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1)
|
||||
# return x_new
|
||||
# else:
|
||||
# return x_new, cbar
|
||||
|
||||
|
||||
# # tensor
|
||||
# def colorize(x, cmap_name='jet', append_cbar=False):
|
||||
# x = x.numpy()
|
||||
# x, cbar = colorize_np(x, cmap_name)
|
||||
|
||||
# if append_cbar:
|
||||
# x = np.concatenate((x, np.zeros_like(x[:, :5, :]), cbar), axis=1)
|
||||
|
||||
# x = torch.from_numpy(x)
|
||||
# return x
|
||||
|
||||
def colorize_np(x, cmap_name='jet', mask=None, append_cbar=False):
    """
    Map a 2D scalar array to an RGB image via a matplotlib colormap.

    :param x: 2D numpy array of scalars
    :param cmap_name: matplotlib colormap name
    :param mask: optional boolean array; False entries are painted black and
                 excluded from the value range used for normalization
    :param append_cbar: if True, return one image with a vertical colorbar
                        appended on the right; otherwise return (image, cbar)
    :return: x_new if append_cbar else (x_new, cbar)
    """
    # Work on a copy: the original implementation wrote vmin into masked-out
    # entries in place, silently mutating the caller's array.
    x = np.array(x, copy=True)
    if mask is not None:
        # vmin, vmax = np.percentile(x[mask], (1, 99))
        vmin = np.min(x[mask])
        vmax = np.max(x[mask])
        # Nudge vmin down so valid values never map exactly to 0.
        vmin = vmin - np.abs(vmin) * 0.01
        x[np.logical_not(mask)] = vmin
        x = np.clip(x, vmin, vmax)
        # print(vmin, vmax)
    else:
        vmin = x.min()
        vmax = x.max() + TINY_NUMBER

    # Normalize to [0, 1] before applying the colormap.
    x = (x - vmin) / (vmax - vmin)
    # x = np.clip(x, 0., 1.)

    cmap = cm.get_cmap(cmap_name)
    x_new = cmap(x)[:, :, :3]

    if mask is not None:
        # Blacken masked-out pixels.
        mask = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * mask + np.zeros_like(x_new) * (1. - mask)

    cbar = get_vertical_colorbar(h=x.shape[0], vmin=vmin, vmax=vmax, cmap_name=cmap_name)

    if append_cbar:
        # 5-pixel black gap between image and colorbar.
        x_new = np.concatenate((x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1)
        return x_new
    else:
        return x_new, cbar
|
||||
|
||||
|
||||
# tensor
|
||||
# tensor
def colorize(x, cmap_name='jet', append_cbar=False, mask=None):
    """Tensor wrapper around colorize_np.

    Converts the (CPU) tensor x — and optional mask — to numpy, colorizes,
    optionally appends the colorbar on the right, and returns a tensor.
    """
    x = x.numpy()
    if mask is not None:
        # np.bool was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin bool is the documented replacement.
        mask = mask.numpy().astype(dtype=bool)
    x, cbar = colorize_np(x, cmap_name, mask)

    if append_cbar:
        # 5-pixel black gap between image and colorbar.
        x = np.concatenate((x, np.zeros_like(x[:, :5, :]), cbar), axis=1)

    x = torch.from_numpy(x)
    return x
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Stitch the rendered .png frames in img_dir into a video for inspection.
    import os
    import imageio

    img_dir = '/phoenix/S7/kz298/latest_work/nerf/logs/dtu_scan9_3_nearfar/renderonly_train_200001'

    all_imgs = [imageio.imread(os.path.join(img_dir, name))
                for name in sorted(os.listdir(img_dir))
                if name[-4:] == '.png']

    imageio.mimwrite(os.path.join(img_dir, 'video.mp4'), all_imgs, fps=3, quality=8)
|
Loading…
Reference in a new issue