commit bcb670095e: v0.1 release to public
35 changed files with 2698 additions and 0 deletions
.gitignore (vendored, new file, 9 lines)
```
**/.ipynb_checkpoints
**/__pycache__
*.png
*.mp4
*.npy
*.npz
*.dae
data/*
logs/*
```
.gitmodules (vendored, new file, empty)
LICENSE (new file, 21 lines)
```
MIT License

Copyright (c) 2020 bmild

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
README.md (new file, 164 lines)

# NeRF-pytorch

[NeRF](http://www.matthewtancik.com/nerf) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. Here are some videos generated by this repository (pre-trained models are provided below):

![](https://user-images.githubusercontent.com/7057863/78472232-cf374a00-7769-11ea-8871-0bc710951839.gif)
![](https://user-images.githubusercontent.com/7057863/78472235-d1010d80-7769-11ea-9be9-51365180e063.gif)

This project is a faithful PyTorch implementation of [NeRF](http://www.matthewtancik.com/nerf) that **reproduces** the results while running **1.3 times faster**. The code is tested to match the authors' TensorFlow implementation [here](https://github.com/bmild/nerf) numerically.

## Installation

```
git clone https://github.com/yenchenlin/nerf-pytorch.git
cd nerf-pytorch
pip install -r requirements.txt
cd torchsearchsorted
pip install .
cd ../
```

<details>
<summary> Dependencies (click to expand) </summary>

## Dependencies
- PyTorch 1.4
- matplotlib
- numpy
- imageio
- imageio-ffmpeg
- configargparse

The LLFF data loader requires ImageMagick.

You will also need the [LLFF code](http://github.com/fyusion/llff) (and COLMAP) set up to compute poses if you want to run on your own real data.

</details>

## How To Run?

### Quick Start

Download data for the two example datasets, `lego` and `fern`:
```
bash download_example_data.sh
```

To train a low-res `lego` NeRF:
```
python run_nerf_torch.py --config configs/config_lego.txt
```
After training for 100k iterations (~4 hours on a single 2080 Ti), you can find the following video at `logs/lego_test/lego_test_spiral_100000_rgb.mp4`.

![](https://user-images.githubusercontent.com/7057863/78473103-9353b300-7770-11ea-98ed-6ba2d877b62c.gif)

---

To train a low-res `fern` NeRF:
```
python run_nerf_torch.py --config configs/config_fern.txt
```
After training for 200k iterations (~8 hours on a single 2080 Ti), you can find the following videos at `logs/fern_test/fern_test_spiral_200000_rgb.mp4` and `logs/fern_test/fern_test_spiral_200000_disp.mp4`.

![](https://user-images.githubusercontent.com/7057863/78473081-58ea1600-7770-11ea-92ce-2bbf6a3f9add.gif)

---

### More Datasets
To play with other scenes presented in the paper, download the data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). Place the downloaded dataset according to the following directory structure:
```
├── configs
│   ├── ...
│
├── data
│   ├── nerf_llff_data
│   │   └── fern
│   │   └── flower   # downloaded llff dataset
│   │   └── horns    # downloaded llff dataset
│   │   └── ...
│   ├── nerf_synthetic
│   │   └── lego
│   │   └── ship     # downloaded synthetic dataset
│   │   └── ...
```

---

To train NeRF on different datasets:

```
python run_nerf_torch.py --config configs/config_{DATASET}.txt
```

Replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.

---

To test NeRF trained on different datasets:

```
python run_nerf_torch.py --config configs/config_{DATASET}.txt --render_only
```

Replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.

### Pre-trained Models

You can download the pre-trained models [here](https://drive.google.com/drive/folders/1jIr8dkvefrQmv737fFm2isiT6tqpbTbv?usp=sharing). Place the downloaded directory in `./logs` in order to test it later. See the following directory structure for an example:

```
├── logs
│   ├── fern_test
│   ├── flower_test  # downloaded logs
│   ├── trex_test    # downloaded logs
```

### Reproducibility

Tests that ensure the results of all functions and the training loop match the official implementation live in the `reproduce` branch. Check it out and run the tests:
```
git checkout reproduce
py.test
```

## Method

[NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://tancik.com/nerf)
[Ben Mildenhall](https://people.eecs.berkeley.edu/~bmild/)\*<sup>1</sup>,
[Pratul P. Srinivasan](https://people.eecs.berkeley.edu/~pratul/)\*<sup>1</sup>,
[Matthew Tancik](http://tancik.com/)\*<sup>1</sup>,
[Jonathan T. Barron](http://jonbarron.info/)<sup>2</sup>,
[Ravi Ramamoorthi](http://cseweb.ucsd.edu/~ravir/)<sup>3</sup>,
[Ren Ng](https://www2.eecs.berkeley.edu/Faculty/Homepages/yirenng.html)<sup>1</sup> <br>
<sup>1</sup>UC Berkeley, <sup>2</sup>Google Research, <sup>3</sup>UC San Diego
\*denotes equal contribution

<img src='imgs/pipeline.jpg'/>

> A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume", so we can use volume rendering to differentiably render new views.
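The 5D-to-4D mapping described in the quote is easy to picture in code. Below is a minimal, hypothetical sketch of such a network in PyTorch; it is an illustration only, not the model this repo trains (the real `NeRF` module adds positional encoding, a skip connection, and a separate view-dependent color head):

```python
import torch
import torch.nn as nn

# Minimal sketch: map a 5D input (3D position + 2D viewing direction)
# to a 4D output (RGB color + density). Illustrative only.
class TinyRadianceField(nn.Module):
    def __init__(self, hidden=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(5, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 4),
        )

    def forward(self, x):
        out = self.mlp(x)                  # [..., 4]
        rgb = torch.sigmoid(out[..., :3])  # colors in [0, 1]
        sigma = torch.relu(out[..., 3:])   # non-negative density
        return rgb, sigma
```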
## Citation
Kudos to the authors for their amazing results:
```
@misc{mildenhall2020nerf,
    title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
    author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
    year={2020},
    eprint={2003.08934},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```

However, if you find this implementation or the pre-trained models helpful, please consider citing:
```
@misc{lin2020nerfpytorch,
    title={NeRF-pytorch},
    author={Yen-Chen, Lin},
    howpublished={\url{https://github.com/yenchenlin/nerf-pytorch/}},
    year={2020}
}
```
configs/config_fern.txt (new file, 15 lines)
```
expname = fern_test
basedir = ./logs
datadir = ./data/nerf_llff_data/fern
dataset_type = llff

factor = 8
llffhold = 8

N_rand = 1024
N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1e0
```
configs/config_flower.txt (new file, 15 lines)
```
expname = flower_test
basedir = ./logs
datadir = ./data/nerf_llff_data/flower
dataset_type = llff

factor = 8
llffhold = 8

N_rand = 1024
N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1e0
```
configs/config_fortress.txt (new file, 15 lines)
```
expname = fortress_test
basedir = ./logs
datadir = ./data/nerf_llff_data/fortress
dataset_type = llff

factor = 8
llffhold = 8

N_rand = 1024
N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1e0
```
configs/config_horns.txt (new file, 15 lines)
```
expname = horns_test
basedir = ./logs
datadir = ./data/nerf_llff_data/horns
dataset_type = llff

factor = 8
llffhold = 8

N_rand = 1024
N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1e0
```
configs/config_lego.txt (new file, 15 lines)
```
expname = lego_test
basedir = ./logs
datadir = ./data/nerf_synthetic/lego
dataset_type = blender

half_res = True

N_samples = 64
N_importance = 64

use_viewdirs = True

white_bkgd = True

N_rand = 1024
```
configs/config_trex.txt (new file, 15 lines)
```
expname = trex_test
basedir = ./logs
datadir = ./data/nerf_llff_data/trex
dataset_type = llff

factor = 8
llffhold = 8

N_rand = 1024
N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1e0
```
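Each config file above is read by configargparse: every `key = value` line fills the command-line flag of the same name declared in `config_parser()` in `run_nerf.py`. A minimal sketch of that mechanism (a trimmed, hypothetical parser, not the full one; `parse_known_args` is used so the keys this sketch does not declare are simply ignored):

```python
import configargparse

# Trimmed, hypothetical version of config_parser() from run_nerf.py:
# each "key = value" line in the config file fills the flag of the
# same name.
parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True)
parser.add_argument('--expname', type=str)
parser.add_argument('--N_rand', type=int, default=32*32*4)
parser.add_argument('--use_viewdirs', action='store_true')

args, _unknown = parser.parse_known_args(['--config', 'configs/config_lego.txt'])
print(args.expname, args.N_rand, args.use_viewdirs)  # lego_test 1024 True
```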
|
download_example_data.sh (new file, 6 lines)
```
wget https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz
mkdir -p data
cd data
wget https://people.eecs.berkeley.edu/~bmild/nerf/nerf_example_data.zip
unzip nerf_example_data.zip
cd ..
```
imgs/pipeline.jpg (binary, new file, 342 KiB; not shown)
load_blender.py (new file, 91 lines)
```
import os
import torch
import numpy as np
import imageio
import json
import torch.nn.functional as F
import cv2


trans_t = lambda t : torch.Tensor([
    [1,0,0,0],
    [0,1,0,0],
    [0,0,1,t],
    [0,0,0,1]]).float()

rot_phi = lambda phi : torch.Tensor([
    [1,0,0,0],
    [0,np.cos(phi),-np.sin(phi),0],
    [0,np.sin(phi), np.cos(phi),0],
    [0,0,0,1]]).float()

rot_theta = lambda th : torch.Tensor([
    [np.cos(th),0,-np.sin(th),0],
    [0,1,0,0],
    [np.sin(th),0, np.cos(th),0],
    [0,0,0,1]]).float()


def pose_spherical(theta, phi, radius):
    c2w = trans_t(radius)
    c2w = rot_phi(phi/180.*np.pi) @ c2w
    c2w = rot_theta(theta/180.*np.pi) @ c2w
    c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
    return c2w


def load_blender_data(basedir, half_res=False, testskip=1):
    splits = ['train', 'val', 'test']
    metas = {}
    for s in splits:
        with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    counts = [0]
    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []
        if s=='train' or testskip==0:
            skip = 1
        else:
            skip = testskip

        for frame in meta['frames'][::skip]:
            fname = os.path.join(basedir, frame['file_path'] + '.png')
            imgs.append(imageio.imread(fname))
            poses.append(np.array(frame['transform_matrix']))
        imgs = (np.array(imgs) / 255.).astype(np.float32)  # keep all 4 channels (RGBA)
        poses = np.array(poses).astype(np.float32)
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)

    i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]

    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate(all_poses, 0)

    H, W = imgs[0].shape[:2]
    camera_angle_x = float(meta['camera_angle_x'])
    focal = .5 * W / np.tan(.5 * camera_angle_x)

    render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)

    if half_res:
        H = H//2
        W = W//2
        focal = focal/2.

        imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
        for i, img in enumerate(imgs):
            # cv2.resize expects (width, height); H == W for the synthetic
            # Blender data, so this coincides with the original (H, W) call.
            imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
        imgs = imgs_half_res

    return imgs, poses, render_poses, [H, W, focal], i_split
```
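A hedged usage sketch of the loader above (the `lego` path matches `configs/config_lego.txt`; the shapes assume the standard 800x800 synthetic data and are illustrative):

```python
from load_blender import load_blender_data, pose_spherical

# Assumes the synthetic dataset sits at data/nerf_synthetic/lego,
# as in the README's directory layout.
imgs, poses, render_poses, hwf, i_split = load_blender_data(
    'data/nerf_synthetic/lego', half_res=True, testskip=8)

H, W, focal = hwf
i_train, i_val, i_test = i_split
print(imgs.shape)          # (N, 400, 400, 4): RGBA in [0, 1] at half res
print(poses.shape)         # (N, 4, 4): camera-to-world matrices
print(render_poses.shape)  # (40, 4, 4): one orbit at -30 deg elevation

# A single synthetic camera pose: 45 deg azimuth, -30 deg elevation, radius 4.
c2w = pose_spherical(45.0, -30.0, 4.0)
```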
|
load_deepvoxels.py (new file, 110 lines)
```
import os
import numpy as np
import imageio


def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8):

    def parse_intrinsics(filepath, trgt_sidelength, invert_y=False):
        # Get camera intrinsics
        with open(filepath, 'r') as file:
            f, cx, cy = list(map(float, file.readline().split()))[:3]
            grid_barycenter = np.array(list(map(float, file.readline().split())))
            near_plane = float(file.readline())
            scale = float(file.readline())
            height, width = map(float, file.readline().split())

            try:
                world2cam_poses = int(file.readline())
            except ValueError:
                world2cam_poses = None

        if world2cam_poses is None:
            world2cam_poses = False

        world2cam_poses = bool(world2cam_poses)

        print(cx,cy,f,height,width)

        # Rescale intrinsics from the stored resolution to the target side length
        cx = cx / width * trgt_sidelength
        cy = cy / height * trgt_sidelength
        f = trgt_sidelength / height * f

        fx = f
        if invert_y:
            fy = -f
        else:
            fy = f

        # Build the intrinsic matrices
        full_intrinsic = np.array([[fx, 0., cx, 0.],
                                   [0., fy, cy, 0],
                                   [0., 0, 1, 0],
                                   [0, 0, 0, 1]])

        return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses

    def load_pose(filename):
        assert os.path.isfile(filename)
        nums = open(filename).read().split()
        return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32)

    H = 512
    W = 512
    deepvoxels_base = '{}/train/{}/'.format(basedir, scene)

    full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H)
    print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses)
    focal = full_intrinsic[0,0]
    print(H, W, focal)

    def dir2poses(posedir):
        poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0)
        transf = np.array([
            [1,0,0,0],
            [0,-1,0,0],
            [0,0,-1,0],
            [0,0,0,1.],
        ])
        poses = poses @ transf
        poses = poses[:,:3,:4].astype(np.float32)
        return poses

    posedir = os.path.join(deepvoxels_base, 'pose')
    poses = dir2poses(posedir)
    testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene))
    testposes = testposes[::testskip]
    valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene))
    valposes = valposes[::testskip]

    imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')]
    imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32)

    testimgd = '{}/test/{}/rgb'.format(basedir, scene)
    imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')]
    testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)

    valimgd = '{}/validation/{}/rgb'.format(basedir, scene)
    imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')]
    valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)

    all_imgs = [imgs, valimgs, testimgs]
    counts = [0] + [x.shape[0] for x in all_imgs]
    counts = np.cumsum(counts)
    i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]

    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate([poses, valposes, testposes], 0)

    render_poses = testposes

    print(poses.shape, imgs.shape)

    return imgs, poses, render_poses, [H,W,focal], i_split
```
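The rescaling inside `parse_intrinsics` is a simple proportional change: when the images are resampled to `trgt_sidelength` pixels, the principal point and focal length scale by the same ratio. A small worked example with made-up numbers:

```python
# Hypothetical intrinsics file values, stored for 1024x1024 images,
# rescaled to the 512-pixel target used by load_dv_data:
f, cx, cy = 875.0, 512.0, 512.0
height = width = 1024.0
trgt_sidelength = 512

cx = cx / width * trgt_sidelength   # 256.0
cy = cy / height * trgt_sidelength  # 256.0
f = trgt_sidelength / height * f    # 437.5: the focal halves with the image
print(cx, cy, f)
```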
|
load_llff.py (new file, 319 lines)
```
import numpy as np
import os, imageio


########## Slightly modified version of LLFF data loading code
##########  see https://github.com/Fyusion/LLFF for original


def _minify(basedir, factors=[], resolutions=[]):
    needtoload = False
    for r in factors:
        imgdir = os.path.join(basedir, 'images_{}'.format(r))
        if not os.path.exists(imgdir):
            needtoload = True
    for r in resolutions:
        imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
        if not os.path.exists(imgdir):
            needtoload = True
    if not needtoload:
        return

    from shutil import copy
    from subprocess import check_output

    imgdir = os.path.join(basedir, 'images')
    imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
    imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
    imgdir_orig = imgdir

    wd = os.getcwd()

    for r in factors + resolutions:
        if isinstance(r, int):
            name = 'images_{}'.format(r)
            resizearg = '{}%'.format(100./r)
        else:
            name = 'images_{}x{}'.format(r[1], r[0])
            resizearg = '{}x{}'.format(r[1], r[0])
        imgdir = os.path.join(basedir, name)
        if os.path.exists(imgdir):
            continue

        print('Minifying', r, basedir)

        os.makedirs(imgdir)
        check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)

        ext = imgs[0].split('.')[-1]
        args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
        print(args)
        os.chdir(imgdir)
        check_output(args, shell=True)
        os.chdir(wd)

        if ext != 'png':
            check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
            print('Removed duplicates')
        print('Done')


def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):

    poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])
    bds = poses_arr[:, -2:].transpose([1,0])

    img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
            if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
    sh = imageio.imread(img0).shape

    sfx = ''

    if factor is not None:
        sfx = '_{}'.format(factor)
        _minify(basedir, factors=[factor])
    elif height is not None:
        factor = sh[0] / float(height)
        width = int(sh[1] / factor)
        _minify(basedir, resolutions=[[height, width]])
        sfx = '_{}x{}'.format(width, height)
    elif width is not None:
        factor = sh[1] / float(width)
        height = int(sh[0] / factor)
        _minify(basedir, resolutions=[[height, width]])
        sfx = '_{}x{}'.format(width, height)
    else:
        factor = 1

    imgdir = os.path.join(basedir, 'images' + sfx)
    if not os.path.exists(imgdir):
        print(imgdir, 'does not exist, returning')
        return

    imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
    if poses.shape[-1] != len(imgfiles):
        print('Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]))
        return

    sh = imageio.imread(imgfiles[0]).shape
    poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
    poses[2, 4, :] = poses[2, 4, :] * 1./factor

    if not load_imgs:
        return poses, bds

    def imread(f):
        if f.endswith('png'):
            return imageio.imread(f, ignoregamma=True)
        else:
            return imageio.imread(f)

    imgs = [imread(f)[...,:3]/255. for f in imgfiles]
    imgs = np.stack(imgs, -1)

    print('Loaded image data', imgs.shape, poses[:,-1,0])
    return poses, bds, imgs


def normalize(x):
    return x / np.linalg.norm(x)

def viewmatrix(z, up, pos):
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.stack([vec0, vec1, vec2, pos], 1)
    return m

def ptstocam(pts, c2w):
    tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0]
    return tt

def poses_avg(poses):

    hwf = poses[0, :3, -1:]

    center = poses[:, :3, 3].mean(0)
    vec2 = normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)

    return c2w


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
    render_poses = []
    rads = np.array(list(rads) + [1.])
    hwf = c2w[:,4:5]

    for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
        c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads)
        z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
        render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
    return render_poses


def recenter_poses(poses):

    poses_ = poses+0
    bottom = np.reshape([0,0,0,1.], [1,4])
    c2w = poses_avg(poses)
    c2w = np.concatenate([c2w[:3,:4], bottom], -2)
    bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1])
    poses = np.concatenate([poses[:,:3,:4], bottom], -2)

    poses = np.linalg.inv(c2w) @ poses
    poses_[:,:3,:4] = poses[:,:3,:4]
    poses = poses_
    return poses


#####################


def spherify_poses(poses, bds):

    p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1)

    rays_d = poses[:,:3,2:3]
    rays_o = poses[:,:3,3:4]

    def min_line_dist(rays_o, rays_d):
        A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1])
        b_i = -A_i @ rays_o
        pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0))
        return pt_mindist

    pt_mindist = min_line_dist(rays_o, rays_d)

    center = pt_mindist
    up = (poses[:,:3,3] - center).mean(0)

    vec0 = normalize(up)
    vec1 = normalize(np.cross([.1,.2,.3], vec0))
    vec2 = normalize(np.cross(vec0, vec1))
    pos = center
    c2w = np.stack([vec1, vec2, vec0, pos], 1)

    poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4])

    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1)))

    sc = 1./rad
    poses_reset[:,:3,3] *= sc
    bds *= sc
    rad *= sc

    centroid = np.mean(poses_reset[:,:3,3], 0)
    zh = centroid[2]
    radcircle = np.sqrt(rad**2-zh**2)
    new_poses = []

    for th in np.linspace(0.,2.*np.pi, 120):

        camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
        up = np.array([0,0,-1.])

        vec2 = normalize(camorigin)
        vec0 = normalize(np.cross(vec2, up))
        vec1 = normalize(np.cross(vec2, vec0))
        pos = camorigin
        p = np.stack([vec0, vec1, vec2, pos], 1)

        new_poses.append(p)

    new_poses = np.stack(new_poses, 0)

    new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1)
    poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1)

    return poses_reset, new_poses, bds


def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False):

    poses, bds, imgs = _load_data(basedir, factor=factor)  # factor=8 downsamples original imgs by 8x
    print('Loaded', basedir, bds.min(), bds.max())

    # Correct rotation matrix ordering and move variable dim to axis 0
    poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
    images = imgs
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)

    # Rescale if bd_factor is provided
    sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor)
    poses[:,:3,3] *= sc
    bds *= sc

    if recenter:
        poses = recenter_poses(poses)

    if spherify:
        poses, render_poses, bds = spherify_poses(poses, bds)

    else:

        c2w = poses_avg(poses)
        print('recentered', c2w.shape)
        print(c2w[:3,:4])

        ## Get spiral
        # Get average pose
        up = normalize(poses[:, :3, 1].sum(0))

        # Find a reasonable "focus depth" for this dataset
        close_depth, inf_depth = bds.min()*.9, bds.max()*5.
        dt = .75
        mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth))
        focal = mean_dz

        # Get radii for spiral path
        shrink_factor = .8
        zdelta = close_depth * .2
        tt = poses[:,:3,3]  # ptstocam(poses[:3,3,:].T, c2w).T
        rads = np.percentile(np.abs(tt), 90, 0)
        c2w_path = c2w
        N_views = 120
        N_rots = 2
        if path_zflat:
            # zloc = np.percentile(tt, 10, 0)[2]
            zloc = -close_depth * .1
            c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2]
            rads[2] = 0.
            N_rots = 1
            N_views /= 2

        # Generate poses for spiral path
        render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views)

    render_poses = np.array(render_poses).astype(np.float32)

    c2w = poses_avg(poses)
    print('Data:')
    print(poses.shape, images.shape, bds.shape)

    dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1)
    i_test = np.argmin(dists)
    print('HOLDOUT view is', i_test)

    images = images.astype(np.float32)
    poses = poses.astype(np.float32)

    return images, poses, bds, render_poses, i_test
```
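A hedged usage sketch of `load_llff_data` (the `fern` path matches `configs/config_fern.txt`; the printed shapes assume the 20-image fern capture and are illustrative):

```python
from load_llff import load_llff_data

# Assumes data/nerf_llff_data/fern as laid out in the README.
images, poses, bds, render_poses, i_test = load_llff_data(
    'data/nerf_llff_data/fern', factor=8, recenter=True,
    bd_factor=.75, spherify=False)

print(images.shape)        # (N, H, W, 3), e.g. roughly (20, 378, 504, 3) at factor 8
print(poses.shape)         # (N, 3, 5): 3x4 camera-to-world plus an [H, W, focal] column
print(bds.shape)           # (N, 2): per-image near/far scene bounds
print(render_poses.shape)  # (120, 3, 5): spiral path from render_path_spiral
print(i_test)              # index of the view closest to the average pose
```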
|
requirements.txt (new file, 6 lines)
```
torch>=1.4.0
torchvision>=0.2.1
imageio
imageio-ffmpeg
matplotlib
configargparse
```
run_nerf.py (new file, 736 lines; the diff view below is truncated partway through train())
```
import os, sys
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import matplotlib.pyplot as plt

from run_nerf_helpers import *

from load_llff import load_llff_data
from load_deepvoxels import load_dv_data
from load_blender import load_blender_data


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
np.random.seed(0)
DEBUG = False


def batchify(fn, chunk):
    if chunk is None:
        return fn
    def ret(inputs):
        return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
    return ret


def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):

    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = embed_fn(inputs_flat)

    if viewdirs is not None:
        input_dirs = viewdirs[:,None].expand(inputs.shape)
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = embeddirs_fn(input_dirs_flat)
        embedded = torch.cat([embedded, embedded_dirs], -1)

    outputs_flat = batchify(fn, netchunk)(embedded)
    outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
    return outputs


def batchify_rays(rays_flat, chunk=1024*32, **kwargs):

    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(rays_flat[i:i+chunk], **kwargs)
        for k in ret:
            if k not in all_ret:
                all_ret[k] = []
            all_ret[k].append(ret[k])

    all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}
    return all_ret


def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
           near=0., far=1.,
           use_viewdirs=False, c2w_staticcam=None,
           **kwargs):

    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = get_rays(H, W, focal, c2w)
    else:
        # use provided ray batch
        rays_o, rays_d = rays

    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # special case to visualize effect of viewdirs
            rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1,3]).float()

    sh = rays_d.shape  # [..., 3]
    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)

    # Create ray batch
    rays_o = torch.reshape(rays_o, [-1,3]).float()
    rays_d = torch.reshape(rays_d, [-1,3]).float()

    near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
    rays = torch.cat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)

    # Render and reshape
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]


def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):

    H, W, focal = hwf

    if render_factor!=0:
        # Render downsampled for speed
        H = H//render_factor
        W = W//render_factor
        focal = focal/render_factor

    rgbs = []
    disps = []

    t = time.time()
    for i, c2w in enumerate(tqdm(render_poses)):
        print(i, time.time() - t)
        t = time.time()
        rgb, disp, acc, _ = render(H, W, focal, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
        rgbs.append(rgb.cpu().numpy())
        disps.append(disp.cpu().numpy())
        if i==0:
            print(rgb.shape, disp.shape)

        """
        if gt_imgs is not None and render_factor==0:
            p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
            print(p)
        """

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    rgbs = np.stack(rgbs, 0)
    disps = np.stack(disps, 0)

    return rgbs, disps


def create_nerf(args):
    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
    output_ch = 5 if args.N_importance > 0 else 4
    skips = [4]
    model = NeRF(D=args.netdepth, W=args.netwidth,
                 input_ch=input_ch, output_ch=output_ch, skips=skips,
                 input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
    grad_vars = list(model.parameters())

    model_fine = None
    if args.N_importance > 0:
        model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
                          input_ch=input_ch, output_ch=output_ch, skips=skips,
                          input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
        grad_vars += list(model_fine.parameters())

    network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
                                                                         embed_fn=embed_fn,
                                                                         embeddirs_fn=embeddirs_fn,
                                                                         netchunk=args.netchunk)

    # Create optimizer
    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

    start = 0
    basedir = args.basedir
    expname = args.expname

    ##########################

    # Load checkpoints
    if args.ft_path is not None and args.ft_path!='None':
        ckpts = [args.ft_path]
    else:
        ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]

    print('Found ckpts', ckpts)
    if len(ckpts) > 0 and not args.no_reload:
        ckpt_path = ckpts[-1]
        print('Reloading from', ckpt_path)
        ckpt = torch.load(ckpt_path)

        start = ckpt['global_step'] + 1
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])

        # Load model
        model.load_state_dict(ckpt['network_fn_state_dict'])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt['network_fine_state_dict'])

    ##########################

    render_kwargs_train = {
        'network_query_fn' : network_query_fn,
        'perturb' : args.perturb,
        'N_importance' : args.N_importance,
        'network_fine' : model_fine,
        'N_samples' : args.N_samples,
        'network_fn' : model,
        'use_viewdirs' : args.use_viewdirs,
        'white_bkgd' : args.white_bkgd,
        'raw_noise_std' : args.raw_noise_std,
    }

    # NDC only good for LLFF-style forward facing data
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer


def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
    """ A helper function for `render_rays`.
    """
    raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)

    dists = z_vals[...,1:] - z_vals[...,:-1]
    dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1)  # [N_rays, N_samples]

    dists = dists * torch.norm(rays_d[...,None,:], dim=-1)

    rgb = torch.sigmoid(raw[...,:3])  # [N_rays, N_samples, 3]
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw[...,3].shape) * raw_noise_std

        # Overwrite randomly sampled data if pytest
        if pytest:
            np.random.seed(0)
            noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
            noise = torch.Tensor(noise)

    alpha = raw2alpha(raw[...,3] + noise, dists)  # [N_rays, N_samples]
    # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]

    depth_map = torch.sum(weights * z_vals, -1)
    disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
    acc_map = torch.sum(weights, -1)

    if white_bkgd:
        rgb_map = rgb_map + (1.-acc_map[...,None])

    return rgb_map, disp_map, acc_map, weights, depth_map


def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False):
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6]  # [N_rays, 3] each
    viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
    bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
    near, far = bounds[...,0], bounds[...,1]  # [-1,1]

    t_vals = torch.linspace(0., 1., steps=N_samples)
    if not lindisp:
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))

    z_vals = z_vals.expand([N_rays, N_samples])

    if perturb > 0.:
        # get intervals between samples
        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        upper = torch.cat([mids, z_vals[...,-1:]], -1)
        lower = torch.cat([z_vals[...,:1], mids], -1)
        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape)

        # Pytest, overwrite u with numpy's fixed random numbers
        if pytest:
            np.random.seed(0)
            t_rand = np.random.rand(*list(z_vals.shape))
            t_rand = torch.Tensor(t_rand)

        z_vals = lower + (upper - lower) * t_rand

    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]  # [N_rays, N_samples, 3]


    # raw = run_network(pts)
    raw = network_query_fn(pts, viewdirs, network_fn)
    rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

    if N_importance > 0:

        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

        z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]  # [N_rays, N_samples + N_importance, 3]

        run_fn = network_fn if network_fine is None else network_fine
        # raw = run_network(pts, fn=run_fn)
        raw = network_query_fn(pts, viewdirs, run_fn)

        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

    ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]

    for k in ret:
        if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
            print(f"! [Numerical Error] {k} contains nan or inf.")

    return ret


def config_parser():

    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True, help='config file path')
    parser.add_argument("--expname", type=str, help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs')
    parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory')

    # training options
    parser.add_argument("--netdepth", type=int, default=8, help='layers in network')
    parser.add_argument("--netwidth", type=int, default=256, help='channels per layer')
    parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network')
    parser.add_argument("--netwidth_fine", type=int, default=256, help='channels per layer in fine network')
    parser.add_argument("--N_rand", type=int, default=32*32*4, help='batch size (number of random rays per gradient step)')
    parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
    parser.add_argument("--lrate_decay", type=int, default=250, help='exponential learning rate decay (in 1000 steps)')
    parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory')
    parser.add_argument("--netchunk", type=int, default=1024*64, help='number of pts sent through network in parallel, decrease if running out of memory')
    parser.add_argument("--no_batching", action='store_true', help='only take random rays from 1 image at a time')
    parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
    parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network')

    # rendering options
    parser.add_argument("--N_samples", type=int, default=64, help='number of coarse samples per ray')
    parser.add_argument("--N_importance", type=int, default=0, help='number of additional fine samples per ray')
    parser.add_argument("--perturb", type=float, default=1., help='set to 0. for no jitter, 1. for jitter')
    parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D')
    parser.add_argument("--i_embed", type=int, default=0, help='set 0 for default positional encoding, -1 for none')
    parser.add_argument("--multires", type=int, default=10, help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--multires_views", type=int, default=4, help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended')

    parser.add_argument("--render_only", action='store_true', help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path')
    parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')

    # dataset options
    parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / blender / deepvoxels')
    parser.add_argument("--testskip", type=int, default=8, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

    ## deepvoxels flags
    parser.add_argument("--shape", type=str, default='greek', help='options : armchair / cube / greek / vase')

    ## blender flags
    parser.add_argument("--white_bkgd", action='store_true', help='set to render synthetic data on a white bkgd (always use for dvoxels)')
    parser.add_argument("--half_res", action='store_true', help='load blender synthetic data at 400x400 instead of 800x800')

    ## llff flags
    parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images')
    parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)')
    parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth')
    parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes')
    parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test set, paper uses 8')

    # logging/saving options
    parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric logging')
    parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=50000, help='frequency of testset saving')
    parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving')

    return parser


def train():

    parser = config_parser()
    args = parser.parse_args()

    # Load data

    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                                  recenter=True, bd_factor=.75,
                                                                  spherify=args.spherify)
        hwf = poses[0,:3,-1]
        poses = poses[:,:3,:4]
        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([i for i in np.arange(int(images.shape[0])) if
                            (i not in i_test and i not in i_val)])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = torch.min(bds) * .9
            far = torch.max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
        else:
            images = images[...,:3]

    elif args.dataset_type == 'deepvoxels':

        images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
                                                                 basedir=args.datadir,
                                                                 testskip=args.testskip)

        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1))
        near = hemi_R-1.
        far = hemi_R+1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
    global_step = start

    bds_dict = {
        'near' : near,
        'far' : far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        with torch.no_grad():
            if args.render_test:
                # render_test switches to test poses
                images = images[i_test]
            else:
                # Default is smoother render_poses path
                images = None

            testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)

            return

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack([get_rays_np(H, W, focal, p) for p in poses[:,:3,:4]], 0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:,None]], 1)  # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4])  # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0)  # train images only
        rays_rgb = np.reshape(rays_rgb, [-1,3,3])  # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)

        print('done')
        i_batch = 0

    # Move training data to GPU
    images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)


    N_iters = 1000000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))

    for i in range(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch+N_rand]  # [B, 2+1, 3*?]
            batch = torch.transpose(batch, 0, 1)
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3,:4]

            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)
                coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1)  # (H, W, 2)
                coords = torch.reshape(coords, [-1,2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                select_coords = coords[select_inds].long()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)

        #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays,
                                        verbose=i < 10, retraw=True,
                                        **render_kwargs_train)

        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][...,-1]
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)

        loss.backward()
        # (diff listing truncated here)
```
|
||||||
|
# NOTE: same as tf till here - 04/03/2020
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# NOTE: IMPORTANT!
|
||||||
|
### update learning rate ###
|
||||||
|
decay_rate = 0.1
|
||||||
|
decay_steps = args.lrate_decay * 1000
|
||||||
|
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = new_lrate
|
||||||
|
################################
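        # Explanatory note (not in the original file): the schedule above is a
        # closed-form exponential decay,
        #     lr(step) = lrate * 0.1 ** (step / (lrate_decay * 1000)),
        # so, assuming defaults of lrate=5e-4 and lrate_decay=250 (an assumption
        # here; check the argument parser), the rate drops 10x every 250k steps.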
        dt = time.time()-time0
        print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
        #####           end            #####

        # Rest is logging
        if i%args.i_weights==0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save({
                'global_step': global_step,
                'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
                'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, path)
            print('Saved checkpoints at', path)

        if i%args.i_video==0 and i > 0:
            # Turn on testing mode
            with torch.no_grad():
                rgbs, disps = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)

            # if args.use_viewdirs:
            #     render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
            #     with torch.no_grad():
            #         rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
            #     render_kwargs_test['c2w_staticcam'] = None
            #     imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)

        if i%args.i_testset==0 and i > 0:
            testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            with torch.no_grad():
                render_path(torch.Tensor(poses[i_test]).to(device), hwf, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
            print('Saved test set')

        """
        if i%args.i_print==0 or i < 10:

            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))
            with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
                tf.contrib.summary.scalar('loss', loss)
                tf.contrib.summary.scalar('psnr', psnr)
                tf.contrib.summary.histogram('tran', trans)
                if args.N_importance > 0:
                    tf.contrib.summary.scalar('psnr0', psnr0)

            if i%args.i_img==0:

                # Log a rendered validation view to Tensorboard
                img_i=np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3,:4]
                with torch.no_grad():
                    rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,
                                                    **render_kwargs_test)

                psnr = mse2psnr(img2mse(rgb, target))

                with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):

                    tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
                    tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis])
                    tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis])

                    tf.contrib.summary.scalar('psnr_holdout', psnr)
                    tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])

                if args.N_importance > 0:

                    with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                        tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
                        tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis])
                        tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis])
        """

        global_step += 1


if __name__=='__main__':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    train()
242
run_nerf_helpers.py
Normal file
@@ -0,0 +1,242 @@
import torch
torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# TODO: remove this dependency
from torchsearchsorted import searchsorted


# Misc
img2mse = lambda x, y : torch.mean((x - y) ** 2)
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)


# Positional encoding (section 5.1)
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x : x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(multires, i=0):
    if i == -1:
        return nn.Identity(), 3

    embed_kwargs = {
        'include_input' : True,
        'input_dims' : 3,
        'max_freq_log2' : multires-1,
        'num_freqs' : multires,
        'log_sampling' : True,
        'periodic_fns' : [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj : eo.embed(x)
    return embed, embedder_obj.out_dim
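# Explanatory note (not in the original file): with multires=10 (the paper's
# value for positions), out_dim = 3 + 3*2*10 = 63: the raw 3-D input plus
# sin/cos at 10 frequencies per coordinate. A minimal sketch of the returned
# callable (shapes illustrative):
#   embed_fn, input_ch = get_embedder(10)
#   encoded = embed_fn(torch.rand(1024, 3))  # -> (1024, 63)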
# Model
class NeRF(nn.Module):
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
        """A standard NeRF MLP: D fully-connected layers of width W, with a
        skip connection feeding the input positions back in at `skips`."""
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirs = use_viewdirs

        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])

        ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])

        ### Implementation according to the paper
        # self.views_linears = nn.ModuleList(
        #     [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])

        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W//2, 3)
        else:
            self.output_linear = nn.Linear(W, output_ch)

    def forward(self, x):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        if self.use_viewdirs:
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)
            h = torch.cat([feature, input_views], -1)

            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = F.relu(h)

            rgb = self.rgb_linear(h)
            outputs = torch.cat([rgb, alpha], -1)
        else:
            outputs = self.output_linear(h)

        return outputs

    def load_weights_from_keras(self, weights):
        assert self.use_viewdirs, "Not implemented if use_viewdirs=False"

        # Load pts_linears
        for i in range(self.D):
            idx_pts_linears = 2 * i
            self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears]))
            self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1]))

        # Load feature_linear
        idx_feature_linear = 2 * self.D
        self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear]))
        self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1]))

        # Load views_linears
        idx_views_linears = 2 * self.D + 2
        self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears]))
        self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1]))

        # Load rgb_linear
        idx_rgb_linear = 2 * self.D + 4
        self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rgb_linear]))
        self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rgb_linear+1]))

        # Load alpha_linear
        idx_alpha_linear = 2 * self.D + 6
        self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear]))
        self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1]))



# Ray helpers
def get_rays(H, W, focal, c2w):
    i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))  # pytorch's meshgrid has indexing='ij'
    i = i.t()
    j = j.t()
    dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = c2w[:3,-1].expand(rays_d.shape)
    return rays_o, rays_d


def get_rays_np(H, W, focal, c2w):
    i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
    dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
    return rays_o, rays_d


def ndc_rays(H, W, focal, near, rays_o, rays_d):
    # Shift ray origins to near plane
    t = -(near + rays_o[...,2]) / rays_d[...,2]
    rays_o = rays_o + t[...,None] * rays_d

    # Projection
    o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
    o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
    o2 = 1. + 2. * near / rays_o[...,2]

    d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
    d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
    d2 = -2. * near / rays_o[...,2]

    rays_o = torch.stack([o0,o1,o2], -1)
    rays_d = torch.stack([d0,d1,d2], -1)

    return rays_o, rays_d
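# Explanatory note (not in the original file): ndc_rays maps a forward-facing
# frustum into the cube [-1,1]^3. After shifting origins to the near plane,
# o_z = -near, so o2 = 1 + 2*near/o_z = -1 there, and o2 -> +1 as depth -> inf;
# uniform samples along NDC z therefore correspond to samples uniform in
# disparity, which is why this is used for the unbounded LLFF scenes.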
# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1)  # (batch, len(bins))

    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])

    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)

    # Invert CDF
    u = u.contiguous()
    inds = searchsorted(cdf, u, side='right')
    below = torch.max(torch.zeros_like(inds-1), inds-1)
    above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[...,1]-cdf_g[...,0])
    denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
    t = (u-cdf_g[...,0])/denom
    samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])

    return samples
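# Sketch of the intended call site (names assumed from NeRF's fine-sampling
# step, not defined in this file): given coarse depths z_vals and rendering
# weights along each ray,
#   z_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
#   z_samples = sample_pdf(z_mid, weights[...,1:-1], N_importance, det=True)
# draws N_importance extra depths concentrated where the coarse pass put mass.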
158
torchsearchsorted/.gitignore
vendored
Normal file
@@ -0,0 +1,158 @@
# Prerequisites
*.d

# Object files
*.o
*.ko
*.obj
*.elf

# Linker output
*.ilk
*.map
*.exp

# Precompiled Headers
*.gch
*.pch

# Libraries
*.lib
*.a
*.la
*.lo

# Shared objects (inc. Windows DLLs)
*.dll
*.so
*.so.*
*.dylib

# Executables
*.exe
*.out
*.app
*.i*86
*.x86_64
*.hex

# Debug files
*.dSYM/
*.su
*.idb
*.pdb

# Kernel Module Compile Results
*.mod*
*.cmd
.tmp_versions/
modules.order
Module.symvers
Mkfile.old
dkms.conf


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
29
torchsearchsorted/LICENSE
Normal file
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2019, Inria (Antoine Liutkus)
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
89
torchsearchsorted/README.md
Normal file
@@ -0,0 +1,89 @@
# Pytorch Custom CUDA kernel for searchsorted

This repository is an implementation of the searchsorted function to work for pytorch CUDA Tensors. Initially derived from the great [C extension tutorial](https://github.com/chrischoy/pytorch-custom-cuda-tutorial), but totally changed since then because building C extensions is not available anymore on pytorch 1.0.

> Warnings:
> * only works with pytorch > v1.3 and CUDA >= v10.1
> * **NOTE** When using `searchsorted()` for practical applications, tensors need to be contiguous in memory. This can be easily achieved by calling `tensor.contiguous()` on the input tensors. Failing to do so _will_ lead to inconsistent results across applications.

## Description

Implements a function `searchsorted(a, v, out, side)` that works just like the [numpy version](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted) except that `a` and `v` are matrices.
* `a` is of shape either `(1, ncols_a)` or `(nrows, ncols_a)`, and is contiguous in memory (do `a.contiguous()` to ensure this).
* `v` is of shape either `(1, ncols_v)` or `(nrows, ncols_v)`, and is contiguous in memory (do `v.contiguous()` to ensure this).
* `out` is either `None` or of shape `(nrows, ncols_v)`. If provided and of the right shape, the result is put there. This is to avoid costly memory allocations if the user already did it. If provided, `out` should be contiguous in memory too (do `out.contiguous()` to ensure this).
* `side` is either "left" or "right". See the [numpy doc](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted). Please note that the current implementation *does not correctly handle this parameter*. Help welcome to improve the speed of [this PR](https://github.com/aliutkus/torchsearchsorted/pull/7)

The output has shape `(nrows, ncols_v)`; a minimal call is sketched below. If all input tensors are on GPU, a cuda version will be called. Otherwise, it will be on CPU.
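A minimal sketch of the matrix semantics described above (shapes are illustrative; this assumes the package is installed as in the Installation section below):

```
import torch
from torchsearchsorted import searchsorted

a = torch.sort(torch.randn(5, 300), dim=1)[0]       # 5 sorted rows to search in
v = torch.randn(5, 100)                             # 100 query values per row
out = searchsorted(a.contiguous(), v.contiguous())  # LongTensor of shape (5, 100)
```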
**Disclaimers**

* This function has not been heavily tested. Use at your own risk.
* When `a` is not sorted, the results vary from numpy's version. But I decided not to care about this because the function should not be called in this case.
* In some cases, the results vary from numpy's version. However, as far as I could see, this only happens when values are equal, which means we actually don't care about the order in which this value is added. I decided not to care about this also.
* Vectors have to be contiguous for torchsearchsorted to give consistent results. Use `.contiguous()` on all tensor arguments before calling.


## Installation

Just `pip install .`, in the root folder of this repo. This will compile
and install the torchsearchsorted module.

Be careful that sometimes, `nvcc` needs versions of `gcc` and `g++` that are older than those found by default on the system. If so, just create symbolic links to the right versions in your cuda/bin folder (where `nvcc` is).

For instance, on my machine, I had `gcc` and `g++` v9 installed, but `nvcc` required v8.
So I had to do:

> sudo apt-get install g++-8 gcc-8
> sudo ln -s /usr/bin/gcc-8 /usr/local/cuda-10.1/bin/gcc
> sudo ln -s /usr/bin/g++-8 /usr/local/cuda-10.1/bin/g++

Be careful that you need pytorch to be installed on your system. The code was tested on pytorch v1.3.

## Usage

Just import the torchsearchsorted package after installation. I typically do:

```
from torchsearchsorted import searchsorted
```

## Testing

Under the `examples` subfolder, you may:

1. try `python test.py` with `torch` available.

```
Looking for 50000x1000 values in 50000x300 entries
NUMPY: searchsorted in 4851.592ms
CPU: searchsorted in 4805.432ms
  difference between CPU and NUMPY: 0.000
GPU: searchsorted in 1.055ms
  difference between GPU and NUMPY: 0.000

Looking for 50000x1000 values in 50000x300 entries
NUMPY: searchsorted in 4333.964ms
CPU: searchsorted in 4753.958ms
  difference between CPU and NUMPY: 0.000
GPU: searchsorted in 0.391ms
  difference between GPU and NUMPY: 0.000
```
The first run comprises the time of allocation, while the second one does not.

2. You may also use the nice `benchmark.py` code written by [@baldassarreFe](https://github.com/baldassarreFe), that tests `searchsorted` on many runs:

```
Benchmark searchsorted:
- a [5000 x 300]
- v [5000 x 100]
- reporting fastest time of 20 runs
- each run executes searchsorted 100 times

Numpy: 4.6302046799100935
CPU:   5.041533078998327
CUDA:  0.0007955809123814106
```
71
torchsearchsorted/examples/benchmark.py
Normal file
@@ -0,0 +1,71 @@
import timeit

import torch
import numpy as np
from torchsearchsorted import searchsorted, numpy_searchsorted

B = 5_000
A = 300
V = 100

repeats = 20
number = 100

print(
    f'Benchmark searchsorted:',
    f'- a [{B} x {A}]',
    f'- v [{B} x {V}]',
    f'- reporting fastest time of {repeats} runs',
    f'- each run executes searchsorted {number} times',
    sep='\n',
    end='\n\n'
)


def get_arrays():
    a = np.sort(np.random.randn(B, A), axis=1)
    v = np.random.randn(B, V)
    out = np.empty_like(v, dtype=np.long)
    return a, v, out


def get_tensors(device):
    a = torch.sort(torch.randn(B, A, device=device), dim=1)[0]
    v = torch.randn(B, V, device=device)
    out = torch.empty(B, V, device=device, dtype=torch.long)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return a, v, out

def searchsorted_synchronized(a, v, out=None, side='left'):
    out = searchsorted(a, v, out, side)
    torch.cuda.synchronize()
    return out
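# Note on the wrapper above (explanatory comment, not in the original file):
# CUDA kernel launches are asynchronous, so timing searchsorted() alone would
# only measure the launch overhead. torch.cuda.synchronize() blocks until the
# kernel has actually finished, which is what makes the CUDA timings below
# meaningful.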

numpy = timeit.repeat(
    stmt="numpy_searchsorted(a, v, side='left')",
    setup="a, v, out = get_arrays()",
    globals=globals(),
    repeat=repeats,
    number=number
)
print('Numpy: ', min(numpy), sep='\t')

cpu = timeit.repeat(
    stmt="searchsorted(a, v, out, side='left')",
    setup="a, v, out = get_tensors(device='cpu')",
    globals=globals(),
    repeat=repeats,
    number=number
)
print('CPU: ', min(cpu), sep='\t')

if torch.cuda.is_available():
    gpu = timeit.repeat(
        stmt="searchsorted_synchronized(a, v, out, side='left')",
        setup="a, v, out = get_tensors(device='cuda')",
        globals=globals(),
        repeat=repeats,
        number=number
    )
    print('CUDA: ', min(gpu), sep='\t')
66
torchsearchsorted/examples/test.py
Normal file
@@ -0,0 +1,66 @@
import torch
from torchsearchsorted import searchsorted, numpy_searchsorted
import time

if __name__ == '__main__':
    # defining the number of tests
    ntests = 2

    # defining the problem dimensions
    nrows_a = 50000
    nrows_v = 50000
    nsorted_values = 300
    nvalues = 1000

    # defines the variables. The first run will comprise allocation, the
    # further ones will not
    test_GPU = None
    test_CPU = None

    for ntest in range(ntests):
        print("\nLooking for %dx%d values in %dx%d entries" % (nrows_v, nvalues,
                                                               nrows_a,
                                                               nsorted_values))

        side = 'right'
        # generate a matrix with sorted rows
        a = torch.randn(nrows_a, nsorted_values, device='cpu')
        a = torch.sort(a, dim=1)[0]
        # generate a matrix of values to searchsort
        v = torch.randn(nrows_v, nvalues, device='cpu')

        # a = torch.tensor([[0., 1.]])
        # v = torch.tensor([[1.]])

        t0 = time.time()
        test_NP = torch.tensor(numpy_searchsorted(a, v, side))
        print('NUMPY: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
        t0 = time.time()
        test_CPU = searchsorted(a, v, test_CPU, side)
        print('CPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
        # compute the difference between both
        error_CPU = torch.norm(test_NP.double()
                               - test_CPU.double()).numpy()
        if error_CPU:
            import ipdb; ipdb.set_trace()
        print('  difference between CPU and NUMPY: %0.3f' % error_CPU)

        if not torch.cuda.is_available():
            print('CUDA is not available on this machine, cannot go further.')
            continue
        else:
            # now do the GPU version
            a = a.to('cuda')
            v = v.to('cuda')
            torch.cuda.synchronize()
            # launch searchsorted on those
            t0 = time.time()
            test_GPU = searchsorted(a, v, test_GPU, side)
            torch.cuda.synchronize()
            print('GPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))

            # compute the difference between both
            error_CUDA = torch.norm(test_NP.to('cuda').double()
                                    - test_GPU.double()).cpu().numpy()

            print('  difference between GPU and NUMPY: %0.3f' % error_CUDA)
41
torchsearchsorted/setup.py
Normal file
@@ -0,0 +1,41 @@
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME
from torch.utils.cpp_extension import CppExtension, CUDAExtension

# In any case, include the CPU version
modules = [
    CppExtension('torchsearchsorted.cpu',
                 ['src/cpu/searchsorted_cpu_wrapper.cpp']),
]

# If nvcc is available, add the CUDA extension
if CUDA_HOME:
    modules.append(
        CUDAExtension('torchsearchsorted.cuda',
                      ['src/cuda/searchsorted_cuda_wrapper.cpp',
                       'src/cuda/searchsorted_cuda_kernel.cu'])
    )

tests_require = [
    'pytest',
]

# Now proceed to setup
setup(
    name='torchsearchsorted',
    version='1.1',
    description='A searchsorted implementation for pytorch',
    keywords='searchsorted',
    author='Antoine Liutkus',
    author_email='antoine.liutkus@inria.fr',
    packages=find_packages(where='src'),
    package_dir={"": "src"},
    ext_modules=modules,
    tests_require=tests_require,
    extras_require={
        'test': tests_require,
    },
    cmdclass={
        'build_ext': BuildExtension
    }
)
126
torchsearchsorted/src/cpu/searchsorted_cpu_wrapper.cpp
Normal file
@@ -0,0 +1,126 @@
#include "searchsorted_cpu_wrapper.h"
#include <stdio.h>

template<typename scalar_t>
int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
{
  /* Evaluates whether a[row,col] < val <= a[row, col+1] */

  if (col == ncol - 1)
  {
    // special case: we are on the right border
    if (a[row * ncol + col] <= val){
      return 1;}
    else {
      return -1;}
  }
  bool is_lower;
  bool is_next_higher;

  if (side_left) {
    // a[row, col] < v <= a[row, col+1]
    is_lower = (a[row * ncol + col] < val);
    is_next_higher = (a[row*ncol + col + 1] >= val);
  } else {
    // a[row, col] <= v < a[row, col+1]
    is_lower = (a[row * ncol + col] <= val);
    is_next_higher = (a[row * ncol + col + 1] > val);
  }
  if (is_lower && is_next_higher) {
    // we found the right spot
    return 0;
  } else if (is_lower) {
    // answer is on the right side
    return 1;
  } else {
    // answer is on the left side
    return -1;
  }
}

template<typename scalar_t>
int64_t binary_search(scalar_t *a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
{
  /* Look for the value `val` within row `row` of matrix `a`, which
  has `ncol` columns.

  The `a` matrix is assumed sorted in increasing order, row-wise.

  Returns:
  * -1 if `val` is smaller than the smallest value found within that row of `a`
  * `ncol` - 1 if `val` is larger than the largest element of that row of `a`
  * Otherwise, the column index `res` such that:
    - a[row, col] < val <= a[row, col+1] (if side_left), or
    - a[row, col] <= val < a[row, col+1] (if not side_left).
  */

  // start with left at 0 and right at number of columns of a
  int64_t right = ncol;
  int64_t left = 0;

  while (right >= left) {
    // take the midpoint of current left and right cursors
    int64_t mid = left + (right-left)/2;

    // check the relative position of val: are we good here ?
    int rel_pos = eval(val, a, row, mid, ncol, side_left);
    // we found the point
    if(rel_pos == 0) {
      return mid;
    } else if (rel_pos > 0) {
      if (mid==ncol-1){return ncol-1;}
      // the answer is on the right side
      left = mid;
    } else {
      if (mid==0){return -1;}
      right = mid;
    }
  }
  return -1;
}
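// Worked example (explanatory comment, not in the original file): with
// a = {1, 3, 5} (ncol = 3), val = 3 and side_left = true, the search returns
// col = 0 since a[0] < 3 <= a[1]; the wrapper below then stores col + 1 = 1,
// matching numpy.searchsorted([1, 3, 5], 3, side='left').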

void searchsorted_cpu_wrapper(
  at::Tensor a,
  at::Tensor v,
  at::Tensor res,
  bool side_left)
{

  // Get the dimensions
  auto nrow_a = a.size(/*dim=*/0);
  auto ncol_a = a.size(/*dim=*/1);
  auto nrow_v = v.size(/*dim=*/0);
  auto ncol_v = v.size(/*dim=*/1);

  auto nrow_res = fmax(nrow_a, nrow_v);

  //auto acc_v = v.accessor<float, 2>();
  //auto acc_res = res.accessor<float, 2>();

  AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cpu", [&] {

    scalar_t* a_data = a.data_ptr<scalar_t>();
    scalar_t* v_data = v.data_ptr<scalar_t>();
    int64_t* res_data = res.data<int64_t>();

    for (int64_t row = 0; row < nrow_res; row++)
    {
      for (int64_t col = 0; col < ncol_v; col++)
      {
        // get the value to look for
        int64_t row_in_v = (nrow_v == 1) ? 0 : row;
        int64_t row_in_a = (nrow_a == 1) ? 0 : row;

        int64_t idx_in_v = row_in_v * ncol_v + col;
        int64_t idx_in_res = row * ncol_v + col;

        // apply binary search
        res_data[idx_in_res] = (binary_search(a_data, row_in_a, v_data[idx_in_v], ncol_a, side_left) + 1);
      }
    }
  });
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("searchsorted_cpu_wrapper", &searchsorted_cpu_wrapper, "searchsorted (CPU)");
}
12
torchsearchsorted/src/cpu/searchsorted_cpu_wrapper.h
Normal file
@@ -0,0 +1,12 @@
#ifndef _SEARCHSORTED_CPU
#define _SEARCHSORTED_CPU

#include <torch/extension.h>

void searchsorted_cpu_wrapper(
  at::Tensor a,
  at::Tensor v,
  at::Tensor res,
  bool side_left);

#endif
142
torchsearchsorted/src/cuda/searchsorted_cuda_kernel.cu
Normal file
@@ -0,0 +1,142 @@
#include "searchsorted_cuda_kernel.h"

template <typename scalar_t>
__device__
int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
{
  /* Evaluates whether a[row,col] < val <= a[row, col+1] */

  if (col == ncol - 1)
  {
    // special case: we are on the right border
    if (a[row * ncol + col] <= val){
      return 1;}
    else {
      return -1;}
  }
  bool is_lower;
  bool is_next_higher;

  if (side_left) {
    // a[row, col] < v <= a[row, col+1]
    is_lower = (a[row * ncol + col] < val);
    is_next_higher = (a[row*ncol + col + 1] >= val);
  } else {
    // a[row, col] <= v < a[row, col+1]
    is_lower = (a[row * ncol + col] <= val);
    is_next_higher = (a[row * ncol + col + 1] > val);
  }
  if (is_lower && is_next_higher) {
    // we found the right spot
    return 0;
  } else if (is_lower) {
    // answer is on the right side
    return 1;
  } else {
    // answer is on the left side
    return -1;
  }
}

template <typename scalar_t>
__device__
int binary_search(scalar_t *a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
{
  /* Look for the value `val` within row `row` of matrix `a`, which
  has `ncol` columns.

  The `a` matrix is assumed sorted in increasing order, row-wise.

  Returns:
  * -1 if `val` is smaller than the smallest value found within that row of `a`
  * `ncol` - 1 if `val` is larger than the largest element of that row of `a`
  * Otherwise, the column index `res` such that:
    - a[row, col] < val <= a[row, col+1] (if side_left), or
    - a[row, col] <= val < a[row, col+1] (if not side_left).
  */

  // start with left at 0 and right at number of columns of a
  int64_t right = ncol;
  int64_t left = 0;

  while (right >= left) {
    // take the midpoint of current left and right cursors
    int64_t mid = left + (right-left)/2;

    // check the relative position of val: are we good here ?
    int rel_pos = eval(val, a, row, mid, ncol, side_left);
    // we found the point
    if(rel_pos == 0) {
      return mid;
    } else if (rel_pos > 0) {
      if (mid==ncol-1){return ncol-1;}
      // the answer is on the right side
      left = mid;
    } else {
      if (mid==0){return -1;}
      right = mid;
    }
  }
  return -1;
}

template <typename scalar_t>
__global__
void searchsorted_kernel(
  int64_t *res,
  scalar_t *a,
  scalar_t *v,
  int64_t nrow_res, int64_t nrow_a, int64_t nrow_v, int64_t ncol_a, int64_t ncol_v, bool side_left)
{
  // get current row and column
  int64_t row = blockIdx.y*blockDim.y+threadIdx.y;
  int64_t col = blockIdx.x*blockDim.x+threadIdx.x;

  // check whether we are outside the bounds of what needs be computed.
  if ((row >= nrow_res) || (col >= ncol_v)) {
    return;}

  // get the value to look for
  int64_t row_in_v = (nrow_v==1) ? 0: row;
  int64_t row_in_a = (nrow_a==1) ? 0: row;
  int64_t idx_in_v = row_in_v*ncol_v+col;
  int64_t idx_in_res = row*ncol_v+col;

  // apply binary search
  res[idx_in_res] = binary_search(a, row_in_a, v[idx_in_v], ncol_a, side_left)+1;
}


void searchsorted_cuda(
  at::Tensor a,
  at::Tensor v,
  at::Tensor res,
  bool side_left){

  // Get the dimensions
  auto nrow_a = a.size(/*dim=*/0);
  auto nrow_v = v.size(/*dim=*/0);
  auto ncol_a = a.size(/*dim=*/1);
  auto ncol_v = v.size(/*dim=*/1);

  auto nrow_res = fmax(double(nrow_a), double(nrow_v));

  // prepare the kernel configuration
  dim3 threads(ncol_v, nrow_res);
  dim3 blocks(1, 1);
  if (nrow_res*ncol_v > 1024){
    threads.x = int(fmin(double(1024), double(ncol_v)));
    threads.y = floor(1024/threads.x);
    blocks.x = ceil(double(ncol_v)/double(threads.x));
    blocks.y = ceil(double(nrow_res)/double(threads.y));
  }
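  // Launch-geometry note (explanatory comment, not in the original file):
  // each thread handles one (row, col) entry of `res`, with col mapped to the
  // x dimension and row to the y dimension. The block is capped at 1024
  // threads (the usual CUDA per-block limit) and the grid is sized with
  // ceil() so that blocks.x*threads.x >= ncol_v and
  // blocks.y*threads.y >= nrow_res; the kernel's bounds check above discards
  // the overhang threads.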

  AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cuda", ([&] {
    searchsorted_kernel<scalar_t><<<blocks, threads>>>(
      res.data<int64_t>(),
      a.data<scalar_t>(),
      v.data<scalar_t>(),
      nrow_res, nrow_a, nrow_v, ncol_a, ncol_v, side_left);
  }));

}
12
torchsearchsorted/src/cuda/searchsorted_cuda_kernel.h
Normal file
@@ -0,0 +1,12 @@
#ifndef _SEARCHSORTED_CUDA_KERNEL
#define _SEARCHSORTED_CUDA_KERNEL

#include <torch/extension.h>

void searchsorted_cuda(
  at::Tensor a,
  at::Tensor v,
  at::Tensor res,
  bool side_left);

#endif
20
torchsearchsorted/src/cuda/searchsorted_cuda_wrapper.cpp
Normal file
@@ -0,0 +1,20 @@
#include "searchsorted_cuda_wrapper.h"

// C++ interface

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

void searchsorted_cuda_wrapper(at::Tensor a, at::Tensor v, at::Tensor res, bool side_left)
{
  CHECK_INPUT(a);
  CHECK_INPUT(v);
  CHECK_INPUT(res);

  searchsorted_cuda(a, v, res, side_left);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("searchsorted_cuda_wrapper", &searchsorted_cuda_wrapper, "searchsorted (CUDA)");
}
13
torchsearchsorted/src/cuda/searchsorted_cuda_wrapper.h
Normal file
@@ -0,0 +1,13 @@
#ifndef _SEARCHSORTED_CUDA_WRAPPER
#define _SEARCHSORTED_CUDA_WRAPPER

#include <torch/extension.h>
#include "searchsorted_cuda_kernel.h"

void searchsorted_cuda_wrapper(
  at::Tensor a,
  at::Tensor v,
  at::Tensor res,
  bool side_left);

#endif
2
torchsearchsorted/src/torchsearchsorted/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .searchsorted import searchsorted
from .utils import numpy_searchsorted
53
torchsearchsorted/src/torchsearchsorted/searchsorted.py
Normal file
@@ -0,0 +1,53 @@
from typing import Optional

import torch

# trying to import the CPU searchsorted
SEARCHSORTED_CPU_AVAILABLE = True
try:
    from torchsearchsorted.cpu import searchsorted_cpu_wrapper
except ImportError:
    SEARCHSORTED_CPU_AVAILABLE = False

# trying to import the CUDA searchsorted
SEARCHSORTED_GPU_AVAILABLE = True
try:
    from torchsearchsorted.cuda import searchsorted_cuda_wrapper
except ImportError:
    SEARCHSORTED_GPU_AVAILABLE = False


def searchsorted(a: torch.Tensor, v: torch.Tensor,
                 out: Optional[torch.LongTensor] = None,
                 side='left') -> torch.LongTensor:
    assert len(a.shape) == 2, "input `a` must be 2-D."
    assert len(v.shape) == 2, "input `v` must be 2-D."
    assert (a.shape[0] == v.shape[0]
            or a.shape[0] == 1
            or v.shape[0] == 1), ("`a` and `v` must have the same number of "
                                  "rows or one of them must have only one row.")
    assert a.device == v.device, '`a` and `v` must be on the same device'

    result_shape = (max(a.shape[0], v.shape[0]), v.shape[1])
    if out is not None:
        assert out.device == a.device, "`out` must be on the same device as `a`"
        assert out.dtype == torch.long, "out.dtype must be torch.long"
        assert out.shape == result_shape, ("If the output tensor is provided, "
                                           "its shape must be correct.")
    else:
        out = torch.empty(result_shape, device=v.device, dtype=torch.long)

    if a.is_cuda and not SEARCHSORTED_GPU_AVAILABLE:
        raise Exception('torchsearchsorted on CUDA device is asked, but it seems '
                        'that it is not available. Please install it')
    if not a.is_cuda and not SEARCHSORTED_CPU_AVAILABLE:
        raise Exception('torchsearchsorted on CPU is not available. '
                        'Please install it.')

    left_side = 1 if side=='left' else 0
    if a.is_cuda:
        searchsorted_cuda_wrapper(a, v, out, left_side)
    else:
        searchsorted_cpu_wrapper(a, v, out, left_side)

    return out
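# Example of reusing the output buffer across calls to avoid reallocation
# (explanatory comment, not in the original file; shapes are illustrative,
# mirroring how examples/test.py threads its test_CPU/test_GPU buffers):
#   a = torch.sort(torch.randn(8, 300), dim=1)[0]
#   v = torch.randn(8, 100)
#   out = None
#   for _ in range(10):
#       out = searchsorted(a, v, out)  # allocates once, then reuses `out`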
15
torchsearchsorted/src/torchsearchsorted/utils.py
Normal file
@@ -0,0 +1,15 @@
import numpy as np


def numpy_searchsorted(a: np.ndarray, v: np.ndarray, side='left'):
    """Numpy version of searchsorted that works batch-wise on pytorch tensors
    """
    nrows_a = a.shape[0]
    (nrows_v, ncols_v) = v.shape
    nrows_out = max(nrows_a, nrows_v)
    out = np.empty((nrows_out, ncols_v), dtype=np.long)
    def sel(data, row):
        return data[0] if data.shape[0] == 1 else data[row]
    for row in range(nrows_out):
        out[row] = np.searchsorted(sel(a, row), sel(v, row), side=side)
    return out
11
torchsearchsorted/test/conftest.py
Normal file
@@ -0,0 +1,11 @@
import pytest
import torch

devices = {'cpu': torch.device('cpu')}
if torch.cuda.is_available():
    devices['cuda'] = torch.device('cuda:0')


@pytest.fixture(params=devices.values(), ids=devices.keys())
def device(request):
    return request.param
44
torchsearchsorted/test/test_searchsorted.py
Normal file
@@ -0,0 +1,44 @@
import pytest

import torch
import numpy as np
from torchsearchsorted import searchsorted, numpy_searchsorted
from itertools import product, repeat


def test_searchsorted_output_dtype(device):
    B = 100
    A = 50
    V = 12

    a = torch.sort(torch.rand(B, V, device=device), dim=1)[0]
    v = torch.rand(B, A, device=device)

    out = searchsorted(a, v)
    out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy())
    assert out.dtype == torch.long
    np.testing.assert_array_equal(out.cpu().numpy(), out_np)

    out = torch.empty(v.shape, dtype=torch.long, device=device)
    searchsorted(a, v, out)
    assert out.dtype == torch.long
    np.testing.assert_array_equal(out.cpu().numpy(), out_np)

Ba_val = [1, 100, 200]
Bv_val = [1, 100, 200]
A_val = [1, 50, 500]
V_val = [1, 12, 120]
side_val = ['left', 'right']
nrepeat = 100

@pytest.mark.parametrize('Ba,Bv,A,V,side', product(Ba_val, Bv_val, A_val, V_val, side_val))
def test_searchsorted_correct(Ba, Bv, A, V, side, device):
    if Ba > 1 and Bv > 1 and Ba != Bv:
        return
    for test in range(nrepeat):
        a = torch.sort(torch.rand(Ba, A, device=device), dim=1)[0]
        v = torch.rand(Bv, V, device=device)
        out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(),
                                    side=side)
        out = searchsorted(a, v, side=side).cpu().numpy()
        np.testing.assert_array_equal(out, out_np)