v0.1 release to public

This commit is contained in:
bmild 2020-03-19 23:43:48 +00:00 committed by Yen-Chen Lin
commit bcb670095e
35 changed files with 2698 additions and 0 deletions

9
.gitignore vendored Normal file
View file

@ -0,0 +1,9 @@
**/.ipynb_checkpoints
**/__pycache__
*.png
*.mp4
*.npy
*.npz
*.dae
data/*
logs/*

0
.gitmodules vendored Normal file
View file

21
LICENSE Normal file
View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2020 bmild
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

164
README.md Normal file
View file

@ -0,0 +1,164 @@
# NeRF-pytorch
[NeRF](http://www.matthewtancik.com/nerf) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. Here are some videos generated by this repository (pre-trained models are provided below):
![](https://user-images.githubusercontent.com/7057863/78472232-cf374a00-7769-11ea-8871-0bc710951839.gif)
![](https://user-images.githubusercontent.com/7057863/78472235-d1010d80-7769-11ea-9be9-51365180e063.gif)
This project is a faithful PyTorch implementation of [NeRF](http://www.matthewtancik.com/nerf) that **reproduces** the results while running **1.3 times faster**. The code is tested to match authors' Tensorflow implementation [here](https://github.com/bmild/nerf) numerically.
## Installation
```
git clone https://github.com/yenchenlin/nerf-pytorch.git
cd nerf-pytorch
pip install -r requirements.txt
cd torchsearchsorted
pip install .
cd ../
```
<details>
<summary> Dependencies (click to expand) </summary>
## Dependencies
- PyTorch 1.4
- matplotlib
- numpy
- imageio
- imageio-ffmpeg
- configargparse
The LLFF data loader requires ImageMagick.
You will also need the [LLFF code](http://github.com/fyusion/llff) (and COLMAP) set up to compute poses if you want to run on your own real data.
</details>
## How To Run?
### Quick Start
Download data for two example datasets: `lego` and `fern`
```
bash download_example_data.sh
```
To train a low-res `lego` NeRF:
```
python run_nerf_torch.py --config configs/config_lego.txt
```
After training for 100k iterations (~4 hours on a single 2080 Ti), you can find the following video at `logs/lego_test/lego_test_spiral_100000_rgb.mp4`.
![](https://user-images.githubusercontent.com/7057863/78473103-9353b300-7770-11ea-98ed-6ba2d877b62c.gif)
---
To train a low-res `fern` NeRF:
```
python run_nerf_torch.py --config configs/config_fern.txt
```
After training for 200k iterations (~8 hours on a single 2080 Ti), you can find the following video at `logs/fern_test/fern_test_spiral_200000_rgb.mp4` and `logs/fern_test/fern_test_spiral_200000_disp.mp4`
![](https://user-images.githubusercontent.com/7057863/78473081-58ea1600-7770-11ea-92ce-2bbf6a3f9add.gif)
---
### More Datasets
To play with other scenes presented in the paper, download the data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). Place the downloaded dataset according to the following directory structure:
```
├── configs
│   ├── ...
│  
├── data
│   ├── nerf_llff_data
│   │   └── fern
│   │  └── flower # downloaded llff dataset
│   │  └── horns # downloaded llff dataset
| | └── ...
| ├── nerf_synthetic
| | └── lego
| | └── ship # downloaded synthetic dataset
| | └── ...
```
---
To train NeRF on different datasets:
```
python run_nerf_torch.py --config configs/config_{DATASET}.txt
```
replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.
---
To test NeRF trained on different datasets:
```
python run_nerf_torch.py --config configs/config_{DATASET}.txt --render_only
```
replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.
### Pre-trained Models
You can download the pre-trained models [here](https://drive.google.com/drive/folders/1jIr8dkvefrQmv737fFm2isiT6tqpbTbv?usp=sharing). Place the downloaded directory in `./logs` in order to test it later. See the following directory structure for an example:
```
├── logs
│   ├── fern_test
│   ├── flower_test # downloaded logs
│ ├── trex_test # downloaded logs
```
### Reproducibility
Tests that ensure the results of all functions and training loop match the official implentation are contained in a different branch `reproduce`. One can check it out and run the tests:
```
git checkout reproduce
py.test
```
## Method
[NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://tancik.com/nerf)
[Ben Mildenhall](https://people.eecs.berkeley.edu/~bmild/)\*<sup>1</sup>,
[Pratul P. Srinivasan](https://people.eecs.berkeley.edu/~pratul/)\*<sup>1</sup>,
[Matthew Tancik](http://tancik.com/)\*<sup>1</sup>,
[Jonathan T. Barron](http://jonbarron.info/)<sup>2</sup>,
[Ravi Ramamoorthi](http://cseweb.ucsd.edu/~ravir/)<sup>3</sup>,
[Ren Ng](https://www2.eecs.berkeley.edu/Faculty/Homepages/yirenng.html)<sup>1</sup> <br>
<sup>1</sup>UC Berkeley, <sup>2</sup>Google Research, <sup>3</sup>UC San Diego
\*denotes equal contribution
<img src='imgs/pipeline.jpg'/>
> A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume" so we can use volume rendering to differentiably render new views
## Citation
Kudos to the authors for their amazing results:
```
@misc{mildenhall2020nerf,
title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
year={2020},
eprint={2003.08934},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
However, if you find this implementation or pre-trained models helpful, please consider to cite:
```
@misc{lin2020nerfpytorch,
title={NeRF-pytorch},
author={Yen-Chen, Lin},
howpublished={\url{https://github.com/yenchenlin/nerf-pytorch/}},
year={2020}
}
```

15
configs/config_fern.txt Normal file
View file

@ -0,0 +1,15 @@
expname = fern_test
basedir = ./logs
datadir = ./data/nerf_llff_data/fern
dataset_type = llff
factor = 8
llffhold = 8
N_rand = 1024
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1e0

15
configs/config_flower.txt Normal file
View file

@ -0,0 +1,15 @@
expname = flower_test
basedir = ./logs
datadir = ./data/nerf_llff_data/flower
dataset_type = llff
factor = 8
llffhold = 8
N_rand = 1024
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1e0

View file

@ -0,0 +1,15 @@
expname = fortress_test
basedir = ./logs
datadir = ./data/nerf_llff_data/fortress
dataset_type = llff
factor = 8
llffhold = 8
N_rand = 1024
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1e0

15
configs/config_horns.txt Normal file
View file

@ -0,0 +1,15 @@
expname = horns_test
basedir = ./logs
datadir = ./data/nerf_llff_data/horns
dataset_type = llff
factor = 8
llffhold = 8
N_rand = 1024
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1e0

15
configs/config_lego.txt Normal file
View file

@ -0,0 +1,15 @@
expname = lego_test
basedir = ./logs
datadir = ./data/nerf_synthetic/lego
dataset_type = blender
half_res = True
N_samples = 64
N_importance = 64
use_viewdirs = True
white_bkgd = True
N_rand = 1024

15
configs/config_trex.txt Normal file
View file

@ -0,0 +1,15 @@
expname = trex_test
basedir = ./logs
datadir = ./data/nerf_llff_data/trex
dataset_type = llff
factor = 8
llffhold = 8
N_rand = 1024
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1e0

6
download_example_data.sh Normal file
View file

@ -0,0 +1,6 @@
wget https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz
mkdir -p data
cd data
wget https://people.eecs.berkeley.edu/~bmild/nerf/nerf_example_data.zip
unzip nerf_example_data.zip
cd ..

BIN
imgs/pipeline.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 342 KiB

91
load_blender.py Normal file
View file

@ -0,0 +1,91 @@
import os
import torch
import numpy as np
import imageio
import json
import torch.nn.functional as F
import cv2
trans_t = lambda t : torch.Tensor([
[1,0,0,0],
[0,1,0,0],
[0,0,1,t],
[0,0,0,1]]).float()
rot_phi = lambda phi : torch.Tensor([
[1,0,0,0],
[0,np.cos(phi),-np.sin(phi),0],
[0,np.sin(phi), np.cos(phi),0],
[0,0,0,1]]).float()
rot_theta = lambda th : torch.Tensor([
[np.cos(th),0,-np.sin(th),0],
[0,1,0,0],
[np.sin(th),0, np.cos(th),0],
[0,0,0,1]]).float()
def pose_spherical(theta, phi, radius):
c2w = trans_t(radius)
c2w = rot_phi(phi/180.*np.pi) @ c2w
c2w = rot_theta(theta/180.*np.pi) @ c2w
c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
return c2w
def load_blender_data(basedir, half_res=False, testskip=1):
splits = ['train', 'val', 'test']
metas = {}
for s in splits:
with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
metas[s] = json.load(fp)
all_imgs = []
all_poses = []
counts = [0]
for s in splits:
meta = metas[s]
imgs = []
poses = []
if s=='train' or testskip==0:
skip = 1
else:
skip = testskip
for frame in meta['frames'][::skip]:
fname = os.path.join(basedir, frame['file_path'] + '.png')
imgs.append(imageio.imread(fname))
poses.append(np.array(frame['transform_matrix']))
imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
poses = np.array(poses).astype(np.float32)
counts.append(counts[-1] + imgs.shape[0])
all_imgs.append(imgs)
all_poses.append(poses)
i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
imgs = np.concatenate(all_imgs, 0)
poses = np.concatenate(all_poses, 0)
H, W = imgs[0].shape[:2]
camera_angle_x = float(meta['camera_angle_x'])
focal = .5 * W / np.tan(.5 * camera_angle_x)
render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
if half_res:
H = H//2
W = W//2
focal = focal/2.
imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
for i, img in enumerate(imgs):
imgs_half_res[i] = cv2.resize(img, (H, W), interpolation=cv2.INTER_AREA)
imgs = imgs_half_res
# imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
return imgs, poses, render_poses, [H, W, focal], i_split

110
load_deepvoxels.py Normal file
View file

@ -0,0 +1,110 @@
import os
import numpy as np
import imageio
def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8):
def parse_intrinsics(filepath, trgt_sidelength, invert_y=False):
# Get camera intrinsics
with open(filepath, 'r') as file:
f, cx, cy = list(map(float, file.readline().split()))[:3]
grid_barycenter = np.array(list(map(float, file.readline().split())))
near_plane = float(file.readline())
scale = float(file.readline())
height, width = map(float, file.readline().split())
try:
world2cam_poses = int(file.readline())
except ValueError:
world2cam_poses = None
if world2cam_poses is None:
world2cam_poses = False
world2cam_poses = bool(world2cam_poses)
print(cx,cy,f,height,width)
cx = cx / width * trgt_sidelength
cy = cy / height * trgt_sidelength
f = trgt_sidelength / height * f
fx = f
if invert_y:
fy = -f
else:
fy = f
# Build the intrinsic matrices
full_intrinsic = np.array([[fx, 0., cx, 0.],
[0., fy, cy, 0],
[0., 0, 1, 0],
[0, 0, 0, 1]])
return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses
def load_pose(filename):
assert os.path.isfile(filename)
nums = open(filename).read().split()
return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32)
H = 512
W = 512
deepvoxels_base = '{}/train/{}/'.format(basedir, scene)
full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H)
print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses)
focal = full_intrinsic[0,0]
print(H, W, focal)
def dir2poses(posedir):
poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0)
transf = np.array([
[1,0,0,0],
[0,-1,0,0],
[0,0,-1,0],
[0,0,0,1.],
])
poses = poses @ transf
poses = poses[:,:3,:4].astype(np.float32)
return poses
posedir = os.path.join(deepvoxels_base, 'pose')
poses = dir2poses(posedir)
testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene))
testposes = testposes[::testskip]
valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene))
valposes = valposes[::testskip]
imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')]
imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32)
testimgd = '{}/test/{}/rgb'.format(basedir, scene)
imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')]
testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
valimgd = '{}/validation/{}/rgb'.format(basedir, scene)
imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')]
valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
all_imgs = [imgs, valimgs, testimgs]
counts = [0] + [x.shape[0] for x in all_imgs]
counts = np.cumsum(counts)
i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
imgs = np.concatenate(all_imgs, 0)
poses = np.concatenate([poses, valposes, testposes], 0)
render_poses = testposes
print(poses.shape, imgs.shape)
return imgs, poses, render_poses, [H,W,focal], i_split

319
load_llff.py Normal file
View file

@ -0,0 +1,319 @@
import numpy as np
import os, imageio
########## Slightly modified version of LLFF data loading code
########## see https://github.com/Fyusion/LLFF for original
def _minify(basedir, factors=[], resolutions=[]):
needtoload = False
for r in factors:
imgdir = os.path.join(basedir, 'images_{}'.format(r))
if not os.path.exists(imgdir):
needtoload = True
for r in resolutions:
imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
if not os.path.exists(imgdir):
needtoload = True
if not needtoload:
return
from shutil import copy
from subprocess import check_output
imgdir = os.path.join(basedir, 'images')
imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
imgdir_orig = imgdir
wd = os.getcwd()
for r in factors + resolutions:
if isinstance(r, int):
name = 'images_{}'.format(r)
resizearg = '{}%'.format(100./r)
else:
name = 'images_{}x{}'.format(r[1], r[0])
resizearg = '{}x{}'.format(r[1], r[0])
imgdir = os.path.join(basedir, name)
if os.path.exists(imgdir):
continue
print('Minifying', r, basedir)
os.makedirs(imgdir)
check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)
ext = imgs[0].split('.')[-1]
args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
print(args)
os.chdir(imgdir)
check_output(args, shell=True)
os.chdir(wd)
if ext != 'png':
check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
print('Removed duplicates')
print('Done')
def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):
poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])
bds = poses_arr[:, -2:].transpose([1,0])
img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
sh = imageio.imread(img0).shape
sfx = ''
if factor is not None:
sfx = '_{}'.format(factor)
_minify(basedir, factors=[factor])
factor = factor
elif height is not None:
factor = sh[0] / float(height)
width = int(sh[1] / factor)
_minify(basedir, resolutions=[[height, width]])
sfx = '_{}x{}'.format(width, height)
elif width is not None:
factor = sh[1] / float(width)
height = int(sh[0] / factor)
_minify(basedir, resolutions=[[height, width]])
sfx = '_{}x{}'.format(width, height)
else:
factor = 1
imgdir = os.path.join(basedir, 'images' + sfx)
if not os.path.exists(imgdir):
print( imgdir, 'does not exist, returning' )
return
imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
if poses.shape[-1] != len(imgfiles):
print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) )
return
sh = imageio.imread(imgfiles[0]).shape
poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
poses[2, 4, :] = poses[2, 4, :] * 1./factor
if not load_imgs:
return poses, bds
def imread(f):
if f.endswith('png'):
return imageio.imread(f, ignoregamma=True)
else:
return imageio.imread(f)
imgs = imgs = [imread(f)[...,:3]/255. for f in imgfiles]
imgs = np.stack(imgs, -1)
print('Loaded image data', imgs.shape, poses[:,-1,0])
return poses, bds, imgs
def normalize(x):
return x / np.linalg.norm(x)
def viewmatrix(z, up, pos):
vec2 = normalize(z)
vec1_avg = up
vec0 = normalize(np.cross(vec1_avg, vec2))
vec1 = normalize(np.cross(vec2, vec0))
m = np.stack([vec0, vec1, vec2, pos], 1)
return m
def ptstocam(pts, c2w):
tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0]
return tt
def poses_avg(poses):
hwf = poses[0, :3, -1:]
center = poses[:, :3, 3].mean(0)
vec2 = normalize(poses[:, :3, 2].sum(0))
up = poses[:, :3, 1].sum(0)
c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
return c2w
def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
render_poses = []
rads = np.array(list(rads) + [1.])
hwf = c2w[:,4:5]
for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads)
z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
return render_poses
def recenter_poses(poses):
poses_ = poses+0
bottom = np.reshape([0,0,0,1.], [1,4])
c2w = poses_avg(poses)
c2w = np.concatenate([c2w[:3,:4], bottom], -2)
bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1])
poses = np.concatenate([poses[:,:3,:4], bottom], -2)
poses = np.linalg.inv(c2w) @ poses
poses_[:,:3,:4] = poses[:,:3,:4]
poses = poses_
return poses
#####################
def spherify_poses(poses, bds):
p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1)
rays_d = poses[:,:3,2:3]
rays_o = poses[:,:3,3:4]
def min_line_dist(rays_o, rays_d):
A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1])
b_i = -A_i @ rays_o
pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0))
return pt_mindist
pt_mindist = min_line_dist(rays_o, rays_d)
center = pt_mindist
up = (poses[:,:3,3] - center).mean(0)
vec0 = normalize(up)
vec1 = normalize(np.cross([.1,.2,.3], vec0))
vec2 = normalize(np.cross(vec0, vec1))
pos = center
c2w = np.stack([vec1, vec2, vec0, pos], 1)
poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4])
rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1)))
sc = 1./rad
poses_reset[:,:3,3] *= sc
bds *= sc
rad *= sc
centroid = np.mean(poses_reset[:,:3,3], 0)
zh = centroid[2]
radcircle = np.sqrt(rad**2-zh**2)
new_poses = []
for th in np.linspace(0.,2.*np.pi, 120):
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
up = np.array([0,0,-1.])
vec2 = normalize(camorigin)
vec0 = normalize(np.cross(vec2, up))
vec1 = normalize(np.cross(vec2, vec0))
pos = camorigin
p = np.stack([vec0, vec1, vec2, pos], 1)
new_poses.append(p)
new_poses = np.stack(new_poses, 0)
new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1)
poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1)
return poses_reset, new_poses, bds
def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False):
poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x
print('Loaded', basedir, bds.min(), bds.max())
# Correct rotation matrix ordering and move variable dim to axis 0
poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
poses = np.moveaxis(poses, -1, 0).astype(np.float32)
imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
images = imgs
bds = np.moveaxis(bds, -1, 0).astype(np.float32)
# Rescale if bd_factor is provided
sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor)
poses[:,:3,3] *= sc
bds *= sc
if recenter:
poses = recenter_poses(poses)
if spherify:
poses, render_poses, bds = spherify_poses(poses, bds)
else:
c2w = poses_avg(poses)
print('recentered', c2w.shape)
print(c2w[:3,:4])
## Get spiral
# Get average pose
up = normalize(poses[:, :3, 1].sum(0))
# Find a reasonable "focus depth" for this dataset
close_depth, inf_depth = bds.min()*.9, bds.max()*5.
dt = .75
mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth))
focal = mean_dz
# Get radii for spiral path
shrink_factor = .8
zdelta = close_depth * .2
tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T
rads = np.percentile(np.abs(tt), 90, 0)
c2w_path = c2w
N_views = 120
N_rots = 2
if path_zflat:
# zloc = np.percentile(tt, 10, 0)[2]
zloc = -close_depth * .1
c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2]
rads[2] = 0.
N_rots = 1
N_views/=2
# Generate poses for spiral path
render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views)
render_poses = np.array(render_poses).astype(np.float32)
c2w = poses_avg(poses)
print('Data:')
print(poses.shape, images.shape, bds.shape)
dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1)
i_test = np.argmin(dists)
print('HOLDOUT view is', i_test)
images = images.astype(np.float32)
poses = poses.astype(np.float32)
return images, poses, bds, render_poses, i_test

6
requirements.txt Normal file
View file

@ -0,0 +1,6 @@
torch>=1.4.0
torchvision>=0.2.1
imageio
imageio-ffmpeg
matplotlib
configargparse

736
run_nerf.py Normal file
View file

@ -0,0 +1,736 @@
import os, sys
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import matplotlib.pyplot as plt
from run_nerf_helpers import *
from load_llff import load_llff_data
from load_deepvoxels import load_dv_data
from load_blender import load_blender_data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
np.random.seed(0)
DEBUG = False
def batchify(fn, chunk):
if chunk is None:
return fn
def ret(inputs):
return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
return ret
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
embedded = embed_fn(inputs_flat)
if viewdirs is not None:
input_dirs = viewdirs[:,None].expand(inputs.shape)
input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
embedded_dirs = embeddirs_fn(input_dirs_flat)
embedded = torch.cat([embedded, embedded_dirs], -1)
outputs_flat = batchify(fn, netchunk)(embedded)
outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
return outputs
def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
all_ret = {}
for i in range(0, rays_flat.shape[0], chunk):
ret = render_rays(rays_flat[i:i+chunk], **kwargs)
for k in ret:
if k not in all_ret:
all_ret[k] = []
all_ret[k].append(ret[k])
all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}
return all_ret
def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
near=0., far=1.,
use_viewdirs=False, c2w_staticcam=None,
**kwargs):
if c2w is not None:
# special case to render full image
rays_o, rays_d = get_rays(H, W, focal, c2w)
else:
# use provided ray batch
rays_o, rays_d = rays
if use_viewdirs:
# provide ray directions as input
viewdirs = rays_d
if c2w_staticcam is not None:
# special case to visualize effect of viewdirs
rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam)
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
viewdirs = torch.reshape(viewdirs, [-1,3]).float()
sh = rays_d.shape # [..., 3]
if ndc:
# for forward facing scenes
rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)
# Create ray batch
rays_o = torch.reshape(rays_o, [-1,3]).float()
rays_d = torch.reshape(rays_d, [-1,3]).float()
near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
rays = torch.cat([rays_o, rays_d, near, far], -1)
if use_viewdirs:
rays = torch.cat([rays, viewdirs], -1)
# Render and reshape
all_ret = batchify_rays(rays, chunk, **kwargs)
for k in all_ret:
k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
all_ret[k] = torch.reshape(all_ret[k], k_sh)
k_extract = ['rgb_map', 'disp_map', 'acc_map']
ret_list = [all_ret[k] for k in k_extract]
ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
return ret_list + [ret_dict]
def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
H, W, focal = hwf
if render_factor!=0:
# Render downsampled for speed
H = H//render_factor
W = W//render_factor
focal = focal/render_factor
rgbs = []
disps = []
t = time.time()
for i, c2w in enumerate(tqdm(render_poses)):
print(i, time.time() - t)
t = time.time()
rgb, disp, acc, _ = render(H, W, focal, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
rgbs.append(rgb.cpu().numpy())
disps.append(disp.cpu().numpy())
if i==0:
print(rgb.shape, disp.shape)
"""
if gt_imgs is not None and render_factor==0:
p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
print(p)
"""
if savedir is not None:
rgb8 = to8b(rgbs[-1])
filename = os.path.join(savedir, '{:03d}.png'.format(i))
imageio.imwrite(filename, rgb8)
rgbs = np.stack(rgbs, 0)
disps = np.stack(disps, 0)
return rgbs, disps
def create_nerf(args):
embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
input_ch_views = 0
embeddirs_fn = None
if args.use_viewdirs:
embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
output_ch = 5 if args.N_importance > 0 else 4
skips = [4]
model = NeRF(D=args.netdepth, W=args.netwidth,
input_ch=input_ch, output_ch=output_ch, skips=skips,
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
grad_vars = list(model.parameters())
model_fine = None
if args.N_importance > 0:
model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
input_ch=input_ch, output_ch=output_ch, skips=skips,
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
grad_vars += list(model_fine.parameters())
network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
embed_fn=embed_fn,
embeddirs_fn=embeddirs_fn,
netchunk=args.netchunk)
# Create optimizer
optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
start = 0
basedir = args.basedir
expname = args.expname
##########################
# Load checkpoints
if args.ft_path is not None and args.ft_path!='None':
ckpts = [args.ft_path]
else:
ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]
print('Found ckpts', ckpts)
if len(ckpts) > 0 and not args.no_reload:
ckpt_path = ckpts[-1]
print('Reloading from', ckpt_path)
ckpt = torch.load(ckpt_path)
start = ckpt['global_step'] + 1
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
# Load model
model.load_state_dict(ckpt['network_fn_state_dict'])
if model_fine is not None:
model_fine.load_state_dict(ckpt['network_fine_state_dict'])
##########################
render_kwargs_train = {
'network_query_fn' : network_query_fn,
'perturb' : args.perturb,
'N_importance' : args.N_importance,
'network_fine' : model_fine,
'N_samples' : args.N_samples,
'network_fn' : model,
'use_viewdirs' : args.use_viewdirs,
'white_bkgd' : args.white_bkgd,
'raw_noise_std' : args.raw_noise_std,
}
# NDC only good for LLFF-style forward facing data
if args.dataset_type != 'llff' or args.no_ndc:
print('Not ndc!')
render_kwargs_train['ndc'] = False
render_kwargs_train['lindisp'] = args.lindisp
render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
render_kwargs_test['perturb'] = False
render_kwargs_test['raw_noise_std'] = 0.
return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
""" A helper function for `render_rays`.
"""
raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)
dists = z_vals[...,1:] - z_vals[...,:-1]
dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples]
dists = dists * torch.norm(rays_d[...,None,:], dim=-1)
rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3]
noise = 0.
if raw_noise_std > 0.:
noise = torch.randn(raw[...,3].shape) * raw_noise_std
# Overwrite randomly sampled data if pytest
if pytest:
np.random.seed(0)
noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
noise = torch.Tensor(noise)
alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples]
# weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3]
depth_map = torch.sum(weights * z_vals, -1)
disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
acc_map = torch.sum(weights, -1)
if white_bkgd:
rgb_map = rgb_map + (1.-acc_map[...,None])
return rgb_map, disp_map, acc_map, weights, depth_map
def render_rays(ray_batch,
network_fn,
network_query_fn,
N_samples,
retraw=False,
lindisp=False,
perturb=0.,
N_importance=0,
network_fine=None,
white_bkgd=False,
raw_noise_std=0.,
verbose=False,
pytest=False):
N_rays = ray_batch.shape[0]
rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
near, far = bounds[...,0], bounds[...,1] # [-1,1]
t_vals = torch.linspace(0., 1., steps=N_samples)
if not lindisp:
z_vals = near * (1.-t_vals) + far * (t_vals)
else:
z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
z_vals = z_vals.expand([N_rays, N_samples])
if perturb > 0.:
# get intervals between samples
mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
upper = torch.cat([mids, z_vals[...,-1:]], -1)
lower = torch.cat([z_vals[...,:1], mids], -1)
# stratified samples in those intervals
t_rand = torch.rand(z_vals.shape)
# Pytest, overwrite u with numpy's fixed random numbers
if pytest:
np.random.seed(0)
t_rand = np.random.rand(*list(z_vals.shape))
t_rand = torch.Tensor(t_rand)
z_vals = lower + (upper - lower) * t_rand
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
# raw = run_network(pts)
raw = network_query_fn(pts, viewdirs, network_fn)
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
if N_importance > 0:
rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
z_samples = z_samples.detach()
z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]
run_fn = network_fn if network_fine is None else network_fine
# raw = run_network(pts, fn=run_fn)
raw = network_query_fn(pts, viewdirs, run_fn)
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
if retraw:
ret['raw'] = raw
if N_importance > 0:
ret['rgb0'] = rgb_map_0
ret['disp0'] = disp_map_0
ret['acc0'] = acc_map_0
ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays]
for k in ret:
if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
print(f"! [Numerical Error] {k} contains nan or inf.")
return ret
def config_parser():
import configargparse
parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True, help='config file path')
parser.add_argument("--expname", type=str, help='experiment name')
parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs')
parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory')
# training options
parser.add_argument("--netdepth", type=int, default=8, help='layers in network')
parser.add_argument("--netwidth", type=int, default=256, help='channels per layer')
parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network')
parser.add_argument("--netwidth_fine", type=int, default=256, help='channels per layer in fine network')
parser.add_argument("--N_rand", type=int, default=32*32*4, help='batch size (number of random rays per gradient step)')
parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
parser.add_argument("--lrate_decay", type=int, default=250, help='exponential learning rate decay (in 1000 steps)')
parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory')
parser.add_argument("--netchunk", type=int, default=1024*64, help='number of pts sent through network in parallel, decrease if running out of memory')
parser.add_argument("--no_batching", action='store_true', help='only take random rays from 1 image at a time')
parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network')
# rendering options
parser.add_argument("--N_samples", type=int, default=64, help='number of coarse samples per ray')
parser.add_argument("--N_importance", type=int, default=0, help='number of additional fine samples per ray')
parser.add_argument("--perturb", type=float, default=1., help='set to 0. for no jitter, 1. for jitter')
parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D')
parser.add_argument("--i_embed", type=int, default=0, help='set 0 for default positional encoding, -1 for none')
parser.add_argument("--multires", type=int, default=10, help='log2 of max freq for positional encoding (3D location)')
parser.add_argument("--multires_views", type=int, default=4, help='log2 of max freq for positional encoding (2D direction)')
parser.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
parser.add_argument("--render_only", action='store_true', help='do not optimize, reload weights and render out render_poses path')
parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path')
parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
# dataset options
parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / blender / deepvoxels')
parser.add_argument("--testskip", type=int, default=8, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
## deepvoxels flags
parser.add_argument("--shape", type=str, default='greek', help='options : armchair / cube / greek / vase')
## blender flags
parser.add_argument("--white_bkgd", action='store_true', help='set to render synthetic data on a white bkgd (always use for dvoxels)')
parser.add_argument("--half_res", action='store_true', help='load blender synthetic data at 400x400 instead of 800x800')
## llff flags
parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images')
parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)')
parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth')
parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes')
parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test set, paper uses 8')
# logging/saving options
parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin')
parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
parser.add_argument("--i_testset", type=int, default=50000, help='frequency of testset saving')
parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving')
return parser
def train():
parser = config_parser()
args = parser.parse_args()
# Load data
if args.dataset_type == 'llff':
images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
recenter=True, bd_factor=.75,
spherify=args.spherify)
hwf = poses[0,:3,-1]
poses = poses[:,:3,:4]
print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
if not isinstance(i_test, list):
i_test = [i_test]
if args.llffhold > 0:
print('Auto LLFF holdout,', args.llffhold)
i_test = np.arange(images.shape[0])[::args.llffhold]
i_val = i_test
i_train = np.array([i for i in np.arange(int(images.shape[0])) if
(i not in i_test and i not in i_val)])
print('DEFINING BOUNDS')
if args.no_ndc:
near = torch.min(bds) * .9
far = torch.max(bds) * 1.
else:
near = 0.
far = 1.
print('NEAR FAR', near, far)
elif args.dataset_type == 'blender':
images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
i_train, i_val, i_test = i_split
near = 2.
far = 6.
if args.white_bkgd:
images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
else:
images = images[...,:3]
elif args.dataset_type == 'deepvoxels':
images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
basedir=args.datadir,
testskip=args.testskip)
print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
i_train, i_val, i_test = i_split
hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1))
near = hemi_R-1.
far = hemi_R+1.
else:
print('Unknown dataset type', args.dataset_type, 'exiting')
return
# Cast intrinsics to right types
H, W, focal = hwf
H, W = int(H), int(W)
hwf = [H, W, focal]
if args.render_test:
render_poses = np.array(poses[i_test])
# Create log dir and copy the config file
basedir = args.basedir
expname = args.expname
os.makedirs(os.path.join(basedir, expname), exist_ok=True)
f = os.path.join(basedir, expname, 'args.txt')
with open(f, 'w') as file:
for arg in sorted(vars(args)):
attr = getattr(args, arg)
file.write('{} = {}\n'.format(arg, attr))
if args.config is not None:
f = os.path.join(basedir, expname, 'config.txt')
with open(f, 'w') as file:
file.write(open(args.config, 'r').read())
# Create nerf model
render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
global_step = start
bds_dict = {
'near' : near,
'far' : far,
}
render_kwargs_train.update(bds_dict)
render_kwargs_test.update(bds_dict)
# Move testing data to GPU
render_poses = torch.Tensor(render_poses).to(device)
# Short circuit if only rendering out from trained model
if args.render_only:
print('RENDER ONLY')
with torch.no_grad():
if args.render_test:
# render_test switches to test poses
images = images[i_test]
else:
# Default is smoother render_poses path
images = None
testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))
os.makedirs(testsavedir, exist_ok=True)
print('test poses shape', render_poses.shape)
rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
print('Done rendering', testsavedir)
imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
return
# Prepare raybatch tensor if batching random rays
N_rand = args.N_rand
use_batching = not args.no_batching
if use_batching:
# For random ray batching
print('get rays')
rays = np.stack([get_rays_np(H, W, focal, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
print('done, concats')
rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only
rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3]
rays_rgb = rays_rgb.astype(np.float32)
print('shuffle rays')
np.random.shuffle(rays_rgb)
print('done')
i_batch = 0
# Move training data to GPU
images = torch.Tensor(images).to(device)
poses = torch.Tensor(poses).to(device)
if use_batching:
rays_rgb = torch.Tensor(rays_rgb).to(device)
N_iters = 1000000
print('Begin')
print('TRAIN views are', i_train)
print('TEST views are', i_test)
print('VAL views are', i_val)
# Summary writers
# writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
for i in range(start, N_iters):
time0 = time.time()
# Sample random ray batch
if use_batching:
# Random over all images
batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
batch = torch.transpose(batch, 0, 1)
batch_rays, target_s = batch[:2], batch[2]
i_batch += N_rand
if i_batch >= rays_rgb.shape[0]:
print("Shuffle data after an epoch!")
rand_idx = torch.randperm(rays_rgb.shape[0])
rays_rgb = rays_rgb[rand_idx]
i_batch = 0
else:
# Random from one image
img_i = np.random.choice(i_train)
target = images[img_i]
pose = poses[img_i, :3,:4]
if N_rand is not None:
rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
batch_rays = torch.stack([rays_o, rays_d], 0)
target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
##### Core optimization loop #####
rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays,
verbose=i < 10, retraw=True,
**render_kwargs_train)
optimizer.zero_grad()
img_loss = img2mse(rgb, target_s)
trans = extras['raw'][...,-1]
loss = img_loss
psnr = mse2psnr(img_loss)
if 'rgb0' in extras:
img_loss0 = img2mse(extras['rgb0'], target_s)
loss = loss + img_loss0
psnr0 = mse2psnr(img_loss0)
loss.backward()
# NOTE: same as tf till here - 04/03/2020
optimizer.step()
# NOTE: IMPORTANT!
### update learning rate ###
decay_rate = 0.1
decay_steps = args.lrate_decay * 1000
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
for param_group in optimizer.param_groups:
param_group['lr'] = new_lrate
################################
dt = time.time()-time0
print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
##### end #####
# Rest is logging
if i%args.i_weights==0:
path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
torch.save({
'global_step': global_step,
'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, path)
print('Saved checkpoints at', path)
if i%args.i_video==0 and i > 0:
# Turn on testing mode
with torch.no_grad():
rgbs, disps = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
print('Done, saving', rgbs.shape, disps.shape)
moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)
# if args.use_viewdirs:
# render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
# with torch.no_grad():
# rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
# render_kwargs_test['c2w_staticcam'] = None
# imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)
if i%args.i_testset==0 and i > 0:
testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
os.makedirs(testsavedir, exist_ok=True)
print('test poses shape', poses[i_test].shape)
with torch.no_grad():
render_path(torch.Tensor(poses[i_test]).to(device), hwf, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
print('Saved test set')
"""
if i%args.i_print==0 or i < 10:
print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
print('iter time {:.05f}'.format(dt))
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
tf.contrib.summary.scalar('loss', loss)
tf.contrib.summary.scalar('psnr', psnr)
tf.contrib.summary.histogram('tran', trans)
if args.N_importance > 0:
tf.contrib.summary.scalar('psnr0', psnr0)
if i%args.i_img==0:
# Log a rendered validation view to Tensorboard
img_i=np.random.choice(i_val)
target = images[img_i]
pose = poses[img_i, :3,:4]
with torch.no_grad():
rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,
**render_kwargs_test)
psnr = mse2psnr(img2mse(rgb, target))
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis])
tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis])
tf.contrib.summary.scalar('psnr_holdout', psnr)
tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])
if args.N_importance > 0:
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis])
tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis])
"""
global_step += 1
if __name__=='__main__':
torch.set_default_tensor_type('torch.cuda.FloatTensor')
train()

242
run_nerf_helpers.py Normal file
View file

@ -0,0 +1,242 @@
import torch
torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# TODO: remove this dependency
from torchsearchsorted import searchsorted
# Misc
img2mse = lambda x, y : torch.mean((x - y) ** 2)
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
# Positional encoding (section 5.1)
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x : x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
if self.kwargs['log_sampling']:
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, i=0):
if i == -1:
return nn.Identity(), 3
embed_kwargs = {
'include_input' : True,
'input_dims' : 3,
'max_freq_log2' : multires-1,
'num_freqs' : multires,
'log_sampling' : True,
'periodic_fns' : [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj : eo.embed(x)
return embed, embedder_obj.out_dim
# Model
class NeRF(nn.Module):
def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
"""
"""
super(NeRF, self).__init__()
self.D = D
self.W = W
self.input_ch = input_ch
self.input_ch_views = input_ch_views
self.skips = skips
self.use_viewdirs = use_viewdirs
self.pts_linears = nn.ModuleList(
[nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
### Implementation according to the paper
# self.views_linears = nn.ModuleList(
# [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
if use_viewdirs:
self.feature_linear = nn.Linear(W, W)
self.alpha_linear = nn.Linear(W, 1)
self.rgb_linear = nn.Linear(W//2, 3)
else:
self.output_linear = nn.Linear(W, output_ch)
def forward(self, x):
input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
h = input_pts
for i, l in enumerate(self.pts_linears):
h = self.pts_linears[i](h)
h = F.relu(h)
if i in self.skips:
h = torch.cat([input_pts, h], -1)
if self.use_viewdirs:
alpha = self.alpha_linear(h)
feature = self.feature_linear(h)
h = torch.cat([feature, input_views], -1)
for i, l in enumerate(self.views_linears):
h = self.views_linears[i](h)
h = F.relu(h)
rgb = self.rgb_linear(h)
outputs = torch.cat([rgb, alpha], -1)
else:
outputs = self.output_linear(h)
return outputs
def load_weights_from_keras(self, weights):
assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
# Load pts_linears
for i in range(self.D):
idx_pts_linears = 2 * i
self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears]))
self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1]))
# Load feature_linear
idx_feature_linear = 2 * self.D
self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear]))
self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1]))
# Load views_linears
idx_views_linears = 2 * self.D + 2
self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears]))
self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1]))
# Load rgb_linear
idx_rbg_linear = 2 * self.D + 4
self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear]))
self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1]))
# Load alpha_linear
idx_alpha_linear = 2 * self.D + 6
self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear]))
self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1]))
# Ray helpers
def get_rays(H, W, focal, c2w):
i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
i = i.t()
j = j.t()
dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = c2w[:3,-1].expand(rays_d.shape)
return rays_o, rays_d
def get_rays_np(H, W, focal, c2w):
i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
return rays_o, rays_d
def ndc_rays(H, W, focal, near, rays_o, rays_d):
# Shift ray origins to near plane
t = -(near + rays_o[...,2]) / rays_d[...,2]
rays_o = rays_o + t[...,None] * rays_d
# Projection
o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
o2 = 1. + 2. * near / rays_o[...,2]
d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
d2 = -2. * near / rays_o[...,2]
rays_o = torch.stack([o0,o1,o2], -1)
rays_d = torch.stack([d0,d1,d2], -1)
return rays_o, rays_d
# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
# Get pdf
weights = weights + 1e-5 # prevent nans
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins))
# Take uniform samples
if det:
u = torch.linspace(0., 1., steps=N_samples)
u = u.expand(list(cdf.shape[:-1]) + [N_samples])
else:
u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
# Pytest, overwrite u with numpy's fixed random numbers
if pytest:
np.random.seed(0)
new_shape = list(cdf.shape[:-1]) + [N_samples]
if det:
u = np.linspace(0., 1., N_samples)
u = np.broadcast_to(u, new_shape)
else:
u = np.random.rand(*new_shape)
u = torch.Tensor(u)
# Invert CDF
u = u.contiguous()
inds = searchsorted(cdf, u, side='right')
below = torch.max(torch.zeros_like(inds-1), inds-1)
above = torch.min(cdf.shape[-1]-1 * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
# cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
# bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[...,1]-cdf_g[...,0])
denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
t = (u-cdf_g[...,0])/denom
samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
return samples

158
torchsearchsorted/.gitignore vendored Normal file
View file

@ -0,0 +1,158 @@
# Prerequisites
*.d
# Object files
*.o
*.ko
*.obj
*.elf
# Linker output
*.ilk
*.map
*.exp
# Precompiled Headers
*.gch
*.pch
# Libraries
*.lib
*.a
*.la
*.lo
# Shared objects (inc. Windows DLLs)
*.dll
*.so
*.so.*
*.dylib
# Executables
*.exe
*.out
*.app
*.i*86
*.x86_64
*.hex
# Debug files
*.dSYM/
*.su
*.idb
*.pdb
# Kernel Module Compile Results
*.mod*
*.cmd
.tmp_versions/
modules.order
Module.symvers
Mkfile.old
dkms.conf
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/

29
torchsearchsorted/LICENSE Normal file
View file

@ -0,0 +1,29 @@
BSD 3-Clause License
Copyright (c) 2019, Inria (Antoine Liutkus)
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,89 @@
# Pytorch Custom CUDA kernel for searchsorted
This repository is an implementation of the searchsorted function to work for pytorch CUDA Tensors. Initially derived from the great [C extension tutorial](https://github.com/chrischoy/pytorch-custom-cuda-tutorial), but totally changed since then because building C extensions is not available anymore on pytorch 1.0.
> Warnings:
> * only works with pytorch > v1.3 and CUDA >= v10.1
> * **NOTE** When using `searchsorted()` for practical applications, tensors need to be contiguous in memory. This can be easily achieved by calling `tensor.contiguous()` on the input tensors. Failing to do so _will_ lead to inconsistent results across applications.
## Description
Implements a function `searchsorted(a, v, out, side)` that works just like the [numpy version](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted) except that `a` and `v` are matrices.
* `a` is of shape either `(1, ncols_a)` or `(nrows, ncols_a)`, and is contiguous in memory (do `a.contiguous()` to ensure this).
* `v` is of shape either `(1, ncols_v)` or `(nrows, ncols_v)`, and is contiguous in memory (do `v.contiguous()` to ensure this).
* `out` is either `None` or of shape `(nrows, ncols_v)`. If provided and of the right shape, the result is put there. This is to avoid costly memory allocations if the user already did it. If provided, `out` should be contiguous in memory too (do `out.contiguous()` to ensure this).
* `side` is either "left" or "right". See the [numpy doc](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted). Please not that the current implementation *does not correctly handle this parameter*. Help welcome to improve the speed of [this PR](https://github.com/aliutkus/torchsearchsorted/pull/7)
the output is of size as `(nrows, ncols_v)`. If all input tensors are on GPU, a cuda version will be called. Otherwise, it will be on CPU.
**Disclaimers**
* This function has not been heavily tested. Use at your own risks
* When `a` is not sorted, the results vary from numpy's version. But I decided not to care about this because the function should not be called in this case.
* In some cases, the results vary from numpy's version. However, as far as I could see, this only happens when values are equal, which means we actually don't care about the order in which this value is added. I decided not to care about this also.
* vectors have to be contiguous for torchsearchsorted to give consistant results. use `.contiguous()` on all tensor arguments before calling
## Installation
Just `pip install .`, in the root folder of this repo. This will compile
and install the torchsearchsorted module.
be careful that sometimes, `nvcc` needs versions of `gcc` and `g++` that are older than those found by default on the system. If so, just create symbolic links to the right versions in your cuda/bin folder (where `nvcc` is)
For instance, on my machine, I had `gcc` and `g++` v9 installed, but `nvcc` required v8.
So I had to do:
> sudo apt-get install g++-8 gcc-8
> sudo ln -s /usr/bin/gcc-8 /usr/local/cuda-10.1/bin/gcc
> sudo ln -s /usr/bin/g++-8 /usr/local/cuda-10.1/bin/g++
be careful that you need pytorch to be installed on your system. The code was tested on pytorch v1.3
## Usage
Just import the torchsearchsorted package after installation. I typically do:
```
from torchsearchsorted import searchsorted
```
## Testing
Under the `examples` subfolder, you may:
1. try `python test.py` with `torch` available.
```
Looking for 50000x1000 values in 50000x300 entries
NUMPY: searchsorted in 4851.592ms
CPU: searchsorted in 4805.432ms
difference between CPU and NUMPY: 0.000
GPU: searchsorted in 1.055ms
difference between GPU and NUMPY: 0.000
Looking for 50000x1000 values in 50000x300 entries
NUMPY: searchsorted in 4333.964ms
CPU: searchsorted in 4753.958ms
difference between CPU and NUMPY: 0.000
GPU: searchsorted in 0.391ms
difference between GPU and NUMPY: 0.000
```
The first run comprises the time of allocation, while the second one does not.
2. You may also use the nice `benchmark.py` code written by [@baldassarreFe](https://github.com/baldassarreFe), that tests `searchsorted` on many runs:
```
Benchmark searchsorted:
- a [5000 x 300]
- v [5000 x 100]
- reporting fastest time of 20 runs
- each run executes searchsorted 100 times
Numpy: 4.6302046799100935
CPU: 5.041533078998327
CUDA: 0.0007955809123814106
```

View file

@ -0,0 +1,71 @@
import timeit
import torch
import numpy as np
from torchsearchsorted import searchsorted, numpy_searchsorted
B = 5_000
A = 300
V = 100
repeats = 20
number = 100
print(
f'Benchmark searchsorted:',
f'- a [{B} x {A}]',
f'- v [{B} x {V}]',
f'- reporting fastest time of {repeats} runs',
f'- each run executes searchsorted {number} times',
sep='\n',
end='\n\n'
)
def get_arrays():
a = np.sort(np.random.randn(B, A), axis=1)
v = np.random.randn(B, V)
out = np.empty_like(v, dtype=np.long)
return a, v, out
def get_tensors(device):
a = torch.sort(torch.randn(B, A, device=device), dim=1)[0]
v = torch.randn(B, V, device=device)
out = torch.empty(B, V, device=device, dtype=torch.long)
if torch.cuda.is_available():
torch.cuda.synchronize()
return a, v, out
def searchsorted_synchronized(a,v,out=None,side='left'):
out = searchsorted(a,v,out,side)
torch.cuda.synchronize()
return out
numpy = timeit.repeat(
stmt="numpy_searchsorted(a, v, side='left')",
setup="a, v, out = get_arrays()",
globals=globals(),
repeat=repeats,
number=number
)
print('Numpy: ', min(numpy), sep='\t')
cpu = timeit.repeat(
stmt="searchsorted(a, v, out, side='left')",
setup="a, v, out = get_tensors(device='cpu')",
globals=globals(),
repeat=repeats,
number=number
)
print('CPU: ', min(cpu), sep='\t')
if torch.cuda.is_available():
gpu = timeit.repeat(
stmt="searchsorted_synchronized(a, v, out, side='left')",
setup="a, v, out = get_tensors(device='cuda')",
globals=globals(),
repeat=repeats,
number=number
)
print('CUDA: ', min(gpu), sep='\t')

View file

@ -0,0 +1,66 @@
import torch
from torchsearchsorted import searchsorted, numpy_searchsorted
import time
if __name__ == '__main__':
# defining the number of tests
ntests = 2
# defining the problem dimensions
nrows_a = 50000
nrows_v = 50000
nsorted_values = 300
nvalues = 1000
# defines the variables. The first run will comprise allocation, the
# further ones will not
test_GPU = None
test_CPU = None
for ntest in range(ntests):
print("\nLooking for %dx%d values in %dx%d entries" % (nrows_v, nvalues,
nrows_a,
nsorted_values))
side = 'right'
# generate a matrix with sorted rows
a = torch.randn(nrows_a, nsorted_values, device='cpu')
a = torch.sort(a, dim=1)[0]
# generate a matrix of values to searchsort
v = torch.randn(nrows_v, nvalues, device='cpu')
# a = torch.tensor([[0., 1.]])
# v = torch.tensor([[1.]])
t0 = time.time()
test_NP = torch.tensor(numpy_searchsorted(a, v, side))
print('NUMPY: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
t0 = time.time()
test_CPU = searchsorted(a, v, test_CPU, side)
print('CPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
# compute the difference between both
error_CPU = torch.norm(test_NP.double()
- test_CPU.double()).numpy()
if error_CPU:
import ipdb; ipdb.set_trace()
print(' difference between CPU and NUMPY: %0.3f' % error_CPU)
if not torch.cuda.is_available():
print('CUDA is not available on this machine, cannot go further.')
continue
else:
# now do the CPU
a = a.to('cuda')
v = v.to('cuda')
torch.cuda.synchronize()
# launch searchsorted on those
t0 = time.time()
test_GPU = searchsorted(a, v, test_GPU, side)
torch.cuda.synchronize()
print('GPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
# compute the difference between both
error_CUDA = torch.norm(test_NP.to('cuda').double()
- test_GPU.double()).cpu().numpy()
print(' difference between GPU and NUMPY: %0.3f' % error_CUDA)

View file

@ -0,0 +1,41 @@
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME
from torch.utils.cpp_extension import CppExtension, CUDAExtension
# In any case, include the CPU version
modules = [
CppExtension('torchsearchsorted.cpu',
['src/cpu/searchsorted_cpu_wrapper.cpp']),
]
# If nvcc is available, add the CUDA extension
if CUDA_HOME:
modules.append(
CUDAExtension('torchsearchsorted.cuda',
['src/cuda/searchsorted_cuda_wrapper.cpp',
'src/cuda/searchsorted_cuda_kernel.cu'])
)
tests_require = [
'pytest',
]
# Now proceed to setup
setup(
name='torchsearchsorted',
version='1.1',
description='A searchsorted implementation for pytorch',
keywords='searchsorted',
author='Antoine Liutkus',
author_email='antoine.liutkus@inria.fr',
packages=find_packages(where='src'),
package_dir={"": "src"},
ext_modules=modules,
tests_require=tests_require,
extras_require={
'test': tests_require,
},
cmdclass={
'build_ext': BuildExtension
}
)

View file

@ -0,0 +1,126 @@
#include "searchsorted_cpu_wrapper.h"
#include <stdio.h>
template<typename scalar_t>
int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
{
/* Evaluates whether a[row,col] < val <= a[row, col+1]*/
if (col == ncol - 1)
{
// special case: we are on the right border
if (a[row * ncol + col] <= val){
return 1;}
else {
return -1;}
}
bool is_lower;
bool is_next_higher;
if (side_left) {
// a[row, col] < v <= a[row, col+1]
is_lower = (a[row * ncol + col] < val);
is_next_higher = (a[row*ncol + col + 1] >= val);
} else {
// a[row, col] <= v < a[row, col+1]
is_lower = (a[row * ncol + col] <= val);
is_next_higher = (a[row * ncol + col + 1] > val);
}
if (is_lower && is_next_higher) {
// we found the right spot
return 0;
} else if (is_lower) {
// answer is on the right side
return 1;
} else {
// answer is on the left side
return -1;
}
}
template<typename scalar_t>
int64_t binary_search(scalar_t*a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
{
/* Look for the value `val` within row `row` of matrix `a`, which
has `ncol` columns.
the `a` matrix is assumed sorted in increasing order, row-wise
returns:
* -1 if `val` is smaller than the smallest value found within that row of `a`
* `ncol` - 1 if `val` is larger than the largest element of that row of `a`
* Otherwise, return the column index `res` such that:
- a[row, col] < val <= a[row, col+1]. (if side_left), or
- a[row, col] < val <= a[row, col+1] (if not side_left).
*/
//start with left at 0 and right at number of columns of a
int64_t right = ncol;
int64_t left = 0;
while (right >= left) {
// take the midpoint of current left and right cursors
int64_t mid = left + (right-left)/2;
// check the relative position of val: are we good here ?
int rel_pos = eval(val, a, row, mid, ncol, side_left);
// we found the point
if(rel_pos == 0) {
return mid;
} else if (rel_pos > 0) {
if (mid==ncol-1){return ncol-1;}
// the answer is on the right side
left = mid;
} else {
if (mid==0){return -1;}
right = mid;
}
}
return -1;
}
void searchsorted_cpu_wrapper(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left)
{
// Get the dimensions
auto nrow_a = a.size(/*dim=*/0);
auto ncol_a = a.size(/*dim=*/1);
auto nrow_v = v.size(/*dim=*/0);
auto ncol_v = v.size(/*dim=*/1);
auto nrow_res = fmax(nrow_a, nrow_v);
//auto acc_v = v.accessor<float, 2>();
//auto acc_res = res.accessor<float, 2>();
AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cpu", [&] {
scalar_t* a_data = a.data_ptr<scalar_t>();
scalar_t* v_data = v.data_ptr<scalar_t>();
int64_t* res_data = res.data<int64_t>();
for (int64_t row = 0; row < nrow_res; row++)
{
for (int64_t col = 0; col < ncol_v; col++)
{
// get the value to look for
int64_t row_in_v = (nrow_v == 1) ? 0 : row;
int64_t row_in_a = (nrow_a == 1) ? 0 : row;
int64_t idx_in_v = row_in_v * ncol_v + col;
int64_t idx_in_res = row * ncol_v + col;
// apply binary search
res_data[idx_in_res] = (binary_search(a_data, row_in_a, v_data[idx_in_v], ncol_a, side_left) + 1);
}
}
});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("searchsorted_cpu_wrapper", &searchsorted_cpu_wrapper, "searchsorted (CPU)");
}

View file

@ -0,0 +1,12 @@
#ifndef _SEARCHSORTED_CPU
#define _SEARCHSORTED_CPU
#include <torch/extension.h>
void searchsorted_cpu_wrapper(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left);
#endif

View file

@ -0,0 +1,142 @@
#include "searchsorted_cuda_kernel.h"
template <typename scalar_t>
__device__
int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
{
/* Evaluates whether a[row,col] < val <= a[row, col+1]*/
if (col == ncol - 1)
{
// special case: we are on the right border
if (a[row * ncol + col] <= val){
return 1;}
else {
return -1;}
}
bool is_lower;
bool is_next_higher;
if (side_left) {
// a[row, col] < v <= a[row, col+1]
is_lower = (a[row * ncol + col] < val);
is_next_higher = (a[row*ncol + col + 1] >= val);
} else {
// a[row, col] <= v < a[row, col+1]
is_lower = (a[row * ncol + col] <= val);
is_next_higher = (a[row * ncol + col + 1] > val);
}
if (is_lower && is_next_higher) {
// we found the right spot
return 0;
} else if (is_lower) {
// answer is on the right side
return 1;
} else {
// answer is on the left side
return -1;
}
}
template <typename scalar_t>
__device__
int binary_search(scalar_t *a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
{
/* Look for the value `val` within row `row` of matrix `a`, which
has `ncol` columns.
the `a` matrix is assumed sorted in increasing order, row-wise
Returns
* -1 if `val` is smaller than the smallest value found within that row of `a`
* `ncol` - 1 if `val` is larger than the largest element of that row of `a`
* Otherwise, return the column index `res` such that:
- a[row, col] < val <= a[row, col+1]. (if side_left), or
- a[row, col] < val <= a[row, col+1] (if not side_left).
*/
//start with left at 0 and right at number of columns of a
int64_t right = ncol;
int64_t left = 0;
while (right >= left) {
// take the midpoint of current left and right cursors
int64_t mid = left + (right-left)/2;
// check the relative position of val: are we good here ?
int rel_pos = eval(val, a, row, mid, ncol, side_left);
// we found the point
if(rel_pos == 0) {
return mid;
} else if (rel_pos > 0) {
if (mid==ncol-1){return ncol-1;}
// the answer is on the right side
left = mid;
} else {
if (mid==0){return -1;}
right = mid;
}
}
return -1;
}
template <typename scalar_t>
__global__
void searchsorted_kernel(
int64_t *res,
scalar_t *a,
scalar_t *v,
int64_t nrow_res, int64_t nrow_a, int64_t nrow_v, int64_t ncol_a, int64_t ncol_v, bool side_left)
{
// get current row and column
int64_t row = blockIdx.y*blockDim.y+threadIdx.y;
int64_t col = blockIdx.x*blockDim.x+threadIdx.x;
// check whether we are outside the bounds of what needs be computed.
if ((row >= nrow_res) || (col >= ncol_v)) {
return;}
// get the value to look for
int64_t row_in_v = (nrow_v==1) ? 0: row;
int64_t row_in_a = (nrow_a==1) ? 0: row;
int64_t idx_in_v = row_in_v*ncol_v+col;
int64_t idx_in_res = row*ncol_v+col;
// apply binary search
res[idx_in_res] = binary_search(a, row_in_a, v[idx_in_v], ncol_a, side_left)+1;
}
void searchsorted_cuda(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left){
// Get the dimensions
auto nrow_a = a.size(/*dim=*/0);
auto nrow_v = v.size(/*dim=*/0);
auto ncol_a = a.size(/*dim=*/1);
auto ncol_v = v.size(/*dim=*/1);
auto nrow_res = fmax(double(nrow_a), double(nrow_v));
// prepare the kernel configuration
dim3 threads(ncol_v, nrow_res);
dim3 blocks(1, 1);
if (nrow_res*ncol_v > 1024){
threads.x = int(fmin(double(1024), double(ncol_v)));
threads.y = floor(1024/threads.x);
blocks.x = ceil(double(ncol_v)/double(threads.x));
blocks.y = ceil(double(nrow_res)/double(threads.y));
}
AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cuda", ([&] {
searchsorted_kernel<scalar_t><<<blocks, threads>>>(
res.data<int64_t>(),
a.data<scalar_t>(),
v.data<scalar_t>(),
nrow_res, nrow_a, nrow_v, ncol_a, ncol_v, side_left);
}));
}

View file

@ -0,0 +1,12 @@
#ifndef _SEARCHSORTED_CUDA_KERNEL
#define _SEARCHSORTED_CUDA_KERNEL
#include <torch/extension.h>
void searchsorted_cuda(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left);
#endif

View file

@ -0,0 +1,20 @@
#include "searchsorted_cuda_wrapper.h"
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
void searchsorted_cuda_wrapper(at::Tensor a, at::Tensor v, at::Tensor res, bool side_left)
{
CHECK_INPUT(a);
CHECK_INPUT(v);
CHECK_INPUT(res);
searchsorted_cuda(a, v, res, side_left);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("searchsorted_cuda_wrapper", &searchsorted_cuda_wrapper, "searchsorted (CUDA)");
}

View file

@ -0,0 +1,13 @@
#ifndef _SEARCHSORTED_CUDA_WRAPPER
#define _SEARCHSORTED_CUDA_WRAPPER
#include <torch/extension.h>
#include "searchsorted_cuda_kernel.h"
void searchsorted_cuda_wrapper(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left);
#endif

View file

@ -0,0 +1,2 @@
from .searchsorted import searchsorted
from .utils import numpy_searchsorted

View file

@ -0,0 +1,53 @@
from typing import Optional
import torch
# trying to import the CPU searchsorted
SEARCHSORTED_CPU_AVAILABLE = True
try:
from torchsearchsorted.cpu import searchsorted_cpu_wrapper
except ImportError:
SEARCHSORTED_CPU_AVAILABLE = False
# trying to import the CUDA searchsorted
SEARCHSORTED_GPU_AVAILABLE = True
try:
from torchsearchsorted.cuda import searchsorted_cuda_wrapper
except ImportError:
SEARCHSORTED_GPU_AVAILABLE = False
def searchsorted(a: torch.Tensor, v: torch.Tensor,
out: Optional[torch.LongTensor] = None,
side='left') -> torch.LongTensor:
assert len(a.shape) == 2, "input `a` must be 2-D."
assert len(v.shape) == 2, "input `v` mus(t be 2-D."
assert (a.shape[0] == v.shape[0]
or a.shape[0] == 1
or v.shape[0] == 1), ("`a` and `v` must have the same number of "
"rows or one of them must have only one ")
assert a.device == v.device, '`a` and `v` must be on the same device'
result_shape = (max(a.shape[0], v.shape[0]), v.shape[1])
if out is not None:
assert out.device == a.device, "`out` must be on the same device as `a`"
assert out.dtype == torch.long, "out.dtype must be torch.long"
assert out.shape == result_shape, ("If the output tensor is provided, "
"its shape must be correct.")
else:
out = torch.empty(result_shape, device=v.device, dtype=torch.long)
if a.is_cuda and not SEARCHSORTED_GPU_AVAILABLE:
raise Exception('torchsearchsorted on CUDA device is asked, but it seems '
'that it is not available. Please install it')
if not a.is_cuda and not SEARCHSORTED_CPU_AVAILABLE:
raise Exception('torchsearchsorted on CPU is not available. '
'Please install it.')
left_side = 1 if side=='left' else 0
if a.is_cuda:
searchsorted_cuda_wrapper(a, v, out, left_side)
else:
searchsorted_cpu_wrapper(a, v, out, left_side)
return out

View file

@ -0,0 +1,15 @@
import numpy as np
def numpy_searchsorted(a: np.ndarray, v: np.ndarray, side='left'):
"""Numpy version of searchsorted that works batch-wise on pytorch tensors
"""
nrows_a = a.shape[0]
(nrows_v, ncols_v) = v.shape
nrows_out = max(nrows_a, nrows_v)
out = np.empty((nrows_out, ncols_v), dtype=np.long)
def sel(data, row):
return data[0] if data.shape[0] == 1 else data[row]
for row in range(nrows_out):
out[row] = np.searchsorted(sel(a, row), sel(v, row), side=side)
return out

View file

@ -0,0 +1,11 @@
import pytest
import torch
devices = {'cpu': torch.device('cpu')}
if torch.cuda.is_available():
devices['cuda'] = torch.device('cuda:0')
@pytest.fixture(params=devices.values(), ids=devices.keys())
def device(request):
return request.param

View file

@ -0,0 +1,44 @@
import pytest
import torch
import numpy as np
from torchsearchsorted import searchsorted, numpy_searchsorted
from itertools import product, repeat
def test_searchsorted_output_dtype(device):
B = 100
A = 50
V = 12
a = torch.sort(torch.rand(B, V, device=device), dim=1)[0]
v = torch.rand(B, A, device=device)
out = searchsorted(a, v)
out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy())
assert out.dtype == torch.long
np.testing.assert_array_equal(out.cpu().numpy(), out_np)
out = torch.empty(v.shape, dtype=torch.long, device=device)
searchsorted(a, v, out)
assert out.dtype == torch.long
np.testing.assert_array_equal(out.cpu().numpy(), out_np)
Ba_val = [1, 100, 200]
Bv_val = [1, 100, 200]
A_val = [1, 50, 500]
V_val = [1, 12, 120]
side_val = ['left', 'right']
nrepeat = 100
@pytest.mark.parametrize('Ba,Bv,A,V,side', product(Ba_val, Bv_val, A_val, V_val, side_val))
def test_searchsorted_correct(Ba, Bv, A, V, side, device):
if Ba > 1 and Bv > 1 and Ba != Bv:
return
for test in range(nrepeat):
a = torch.sort(torch.rand(Ba, A, device=device), dim=1)[0]
v = torch.rand(Bv, V, device=device)
out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(),
side=side)
out = searchsorted(a, v, side=side).cpu().numpy()
np.testing.assert_array_equal(out, out_np)