rm uneeded dir for torch search
This commit is contained in:
parent
56033b32f5
commit
71e3a13134
17 changed files with 0 additions and 904 deletions
158
torchsearchsorted/.gitignore
vendored
158
torchsearchsorted/.gitignore
vendored
|
@ -1,158 +0,0 @@
|
||||||
# Prerequisites
|
|
||||||
*.d
|
|
||||||
|
|
||||||
# Object files
|
|
||||||
*.o
|
|
||||||
*.ko
|
|
||||||
*.obj
|
|
||||||
*.elf
|
|
||||||
|
|
||||||
# Linker output
|
|
||||||
*.ilk
|
|
||||||
*.map
|
|
||||||
*.exp
|
|
||||||
|
|
||||||
# Precompiled Headers
|
|
||||||
*.gch
|
|
||||||
*.pch
|
|
||||||
|
|
||||||
# Libraries
|
|
||||||
*.lib
|
|
||||||
*.a
|
|
||||||
*.la
|
|
||||||
*.lo
|
|
||||||
|
|
||||||
# Shared objects (inc. Windows DLLs)
|
|
||||||
*.dll
|
|
||||||
*.so
|
|
||||||
*.so.*
|
|
||||||
*.dylib
|
|
||||||
|
|
||||||
# Executables
|
|
||||||
*.exe
|
|
||||||
*.out
|
|
||||||
*.app
|
|
||||||
*.i*86
|
|
||||||
*.x86_64
|
|
||||||
*.hex
|
|
||||||
|
|
||||||
# Debug files
|
|
||||||
*.dSYM/
|
|
||||||
*.su
|
|
||||||
*.idb
|
|
||||||
*.pdb
|
|
||||||
|
|
||||||
# Kernel Module Compile Results
|
|
||||||
*.mod*
|
|
||||||
*.cmd
|
|
||||||
.tmp_versions/
|
|
||||||
modules.order
|
|
||||||
Module.symvers
|
|
||||||
Mkfile.old
|
|
||||||
dkms.conf
|
|
||||||
|
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
|
||||||
__pycache__/
|
|
||||||
*.py[cod]
|
|
||||||
*$py.class
|
|
||||||
|
|
||||||
# C extensions
|
|
||||||
*.so
|
|
||||||
|
|
||||||
# Distribution / packaging
|
|
||||||
.Python
|
|
||||||
build/
|
|
||||||
develop-eggs/
|
|
||||||
dist/
|
|
||||||
downloads/
|
|
||||||
eggs/
|
|
||||||
.eggs/
|
|
||||||
lib/
|
|
||||||
lib64/
|
|
||||||
parts/
|
|
||||||
sdist/
|
|
||||||
var/
|
|
||||||
wheels/
|
|
||||||
*.egg-info/
|
|
||||||
.installed.cfg
|
|
||||||
*.egg
|
|
||||||
MANIFEST
|
|
||||||
|
|
||||||
# PyInstaller
|
|
||||||
# Usually these files are written by a python script from a template
|
|
||||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
||||||
*.manifest
|
|
||||||
*.spec
|
|
||||||
|
|
||||||
# Installer logs
|
|
||||||
pip-log.txt
|
|
||||||
pip-delete-this-directory.txt
|
|
||||||
|
|
||||||
# Unit test / coverage reports
|
|
||||||
htmlcov/
|
|
||||||
.tox/
|
|
||||||
.coverage
|
|
||||||
.coverage.*
|
|
||||||
.cache
|
|
||||||
nosetests.xml
|
|
||||||
coverage.xml
|
|
||||||
*.cover
|
|
||||||
.hypothesis/
|
|
||||||
.pytest_cache/
|
|
||||||
|
|
||||||
# Translations
|
|
||||||
*.mo
|
|
||||||
*.pot
|
|
||||||
|
|
||||||
# Django stuff:
|
|
||||||
*.log
|
|
||||||
local_settings.py
|
|
||||||
db.sqlite3
|
|
||||||
|
|
||||||
# Flask stuff:
|
|
||||||
instance/
|
|
||||||
.webassets-cache
|
|
||||||
|
|
||||||
# Scrapy stuff:
|
|
||||||
.scrapy
|
|
||||||
|
|
||||||
# Sphinx documentation
|
|
||||||
docs/_build/
|
|
||||||
|
|
||||||
# PyBuilder
|
|
||||||
target/
|
|
||||||
|
|
||||||
# Jupyter Notebook
|
|
||||||
.ipynb_checkpoints
|
|
||||||
|
|
||||||
# pyenv
|
|
||||||
.python-version
|
|
||||||
|
|
||||||
# celery beat schedule file
|
|
||||||
celerybeat-schedule
|
|
||||||
|
|
||||||
# SageMath parsed files
|
|
||||||
*.sage.py
|
|
||||||
|
|
||||||
# Environments
|
|
||||||
.env
|
|
||||||
.venv
|
|
||||||
env/
|
|
||||||
venv/
|
|
||||||
ENV/
|
|
||||||
env.bak/
|
|
||||||
venv.bak/
|
|
||||||
|
|
||||||
# Spyder project settings
|
|
||||||
.spyderproject
|
|
||||||
.spyproject
|
|
||||||
|
|
||||||
# Rope project settings
|
|
||||||
.ropeproject
|
|
||||||
|
|
||||||
# mkdocs documentation
|
|
||||||
/site
|
|
||||||
|
|
||||||
# mypy
|
|
||||||
.mypy_cache/
|
|
|
@ -1,29 +0,0 @@
|
||||||
BSD 3-Clause License
|
|
||||||
|
|
||||||
Copyright (c) 2019, Inria (Antoine Liutkus)
|
|
||||||
All rights reserved.
|
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without
|
|
||||||
modification, are permitted provided that the following conditions are met:
|
|
||||||
|
|
||||||
1. Redistributions of source code must retain the above copyright notice, this
|
|
||||||
list of conditions and the following disclaimer.
|
|
||||||
|
|
||||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
||||||
this list of conditions and the following disclaimer in the documentation
|
|
||||||
and/or other materials provided with the distribution.
|
|
||||||
|
|
||||||
3. Neither the name of the copyright holder nor the names of its
|
|
||||||
contributors may be used to endorse or promote products derived from
|
|
||||||
this software without specific prior written permission.
|
|
||||||
|
|
||||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
||||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
||||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
||||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
||||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
||||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
||||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
||||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
||||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
@ -1,89 +0,0 @@
|
||||||
# Pytorch Custom CUDA kernel for searchsorted
|
|
||||||
|
|
||||||
This repository is an implementation of the searchsorted function to work for pytorch CUDA Tensors. Initially derived from the great [C extension tutorial](https://github.com/chrischoy/pytorch-custom-cuda-tutorial), but totally changed since then because building C extensions is not available anymore on pytorch 1.0.
|
|
||||||
|
|
||||||
|
|
||||||
> Warnings:
|
|
||||||
> * only works with pytorch > v1.3 and CUDA >= v10.1
|
|
||||||
> * **NOTE** When using `searchsorted()` for practical applications, tensors need to be contiguous in memory. This can be easily achieved by calling `tensor.contiguous()` on the input tensors. Failing to do so _will_ lead to inconsistent results across applications.
|
|
||||||
|
|
||||||
## Description
|
|
||||||
|
|
||||||
Implements a function `searchsorted(a, v, out, side)` that works just like the [numpy version](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted) except that `a` and `v` are matrices.
|
|
||||||
* `a` is of shape either `(1, ncols_a)` or `(nrows, ncols_a)`, and is contiguous in memory (do `a.contiguous()` to ensure this).
|
|
||||||
* `v` is of shape either `(1, ncols_v)` or `(nrows, ncols_v)`, and is contiguous in memory (do `v.contiguous()` to ensure this).
|
|
||||||
* `out` is either `None` or of shape `(nrows, ncols_v)`. If provided and of the right shape, the result is put there. This is to avoid costly memory allocations if the user already did it. If provided, `out` should be contiguous in memory too (do `out.contiguous()` to ensure this).
|
|
||||||
* `side` is either "left" or "right". See the [numpy doc](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted). Please not that the current implementation *does not correctly handle this parameter*. Help welcome to improve the speed of [this PR](https://github.com/aliutkus/torchsearchsorted/pull/7)
|
|
||||||
|
|
||||||
the output is of size as `(nrows, ncols_v)`. If all input tensors are on GPU, a cuda version will be called. Otherwise, it will be on CPU.
|
|
||||||
|
|
||||||
|
|
||||||
**Disclaimers**
|
|
||||||
|
|
||||||
* This function has not been heavily tested. Use at your own risks
|
|
||||||
* When `a` is not sorted, the results vary from numpy's version. But I decided not to care about this because the function should not be called in this case.
|
|
||||||
* In some cases, the results vary from numpy's version. However, as far as I could see, this only happens when values are equal, which means we actually don't care about the order in which this value is added. I decided not to care about this also.
|
|
||||||
* vectors have to be contiguous for torchsearchsorted to give consistant results. use `.contiguous()` on all tensor arguments before calling
|
|
||||||
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
Just `pip install .`, in the root folder of this repo. This will compile
|
|
||||||
and install the torchsearchsorted module.
|
|
||||||
|
|
||||||
be careful that sometimes, `nvcc` needs versions of `gcc` and `g++` that are older than those found by default on the system. If so, just create symbolic links to the right versions in your cuda/bin folder (where `nvcc` is)
|
|
||||||
|
|
||||||
For instance, on my machine, I had `gcc` and `g++` v9 installed, but `nvcc` required v8.
|
|
||||||
So I had to do:
|
|
||||||
|
|
||||||
> sudo apt-get install g++-8 gcc-8
|
|
||||||
> sudo ln -s /usr/bin/gcc-8 /usr/local/cuda-10.1/bin/gcc
|
|
||||||
> sudo ln -s /usr/bin/g++-8 /usr/local/cuda-10.1/bin/g++
|
|
||||||
|
|
||||||
be careful that you need pytorch to be installed on your system. The code was tested on pytorch v1.3
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
Just import the torchsearchsorted package after installation. I typically do:
|
|
||||||
|
|
||||||
```
|
|
||||||
from torchsearchsorted import searchsorted
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
Under the `examples` subfolder, you may:
|
|
||||||
|
|
||||||
1. try `python test.py` with `torch` available.
|
|
||||||
|
|
||||||
```
|
|
||||||
Looking for 50000x1000 values in 50000x300 entries
|
|
||||||
NUMPY: searchsorted in 4851.592ms
|
|
||||||
CPU: searchsorted in 4805.432ms
|
|
||||||
difference between CPU and NUMPY: 0.000
|
|
||||||
GPU: searchsorted in 1.055ms
|
|
||||||
difference between GPU and NUMPY: 0.000
|
|
||||||
|
|
||||||
Looking for 50000x1000 values in 50000x300 entries
|
|
||||||
NUMPY: searchsorted in 4333.964ms
|
|
||||||
CPU: searchsorted in 4753.958ms
|
|
||||||
difference between CPU and NUMPY: 0.000
|
|
||||||
GPU: searchsorted in 0.391ms
|
|
||||||
difference between GPU and NUMPY: 0.000
|
|
||||||
```
|
|
||||||
The first run comprises the time of allocation, while the second one does not.
|
|
||||||
|
|
||||||
2. You may also use the nice `benchmark.py` code written by [@baldassarreFe](https://github.com/baldassarreFe), that tests `searchsorted` on many runs:
|
|
||||||
|
|
||||||
```
|
|
||||||
Benchmark searchsorted:
|
|
||||||
- a [5000 x 300]
|
|
||||||
- v [5000 x 100]
|
|
||||||
- reporting fastest time of 20 runs
|
|
||||||
- each run executes searchsorted 100 times
|
|
||||||
|
|
||||||
Numpy: 4.6302046799100935
|
|
||||||
CPU: 5.041533078998327
|
|
||||||
CUDA: 0.0007955809123814106
|
|
||||||
```
|
|
|
@ -1,71 +0,0 @@
|
||||||
import timeit
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torchsearchsorted import searchsorted, numpy_searchsorted
|
|
||||||
|
|
||||||
B = 5_000
|
|
||||||
A = 300
|
|
||||||
V = 100
|
|
||||||
|
|
||||||
repeats = 20
|
|
||||||
number = 100
|
|
||||||
|
|
||||||
print(
|
|
||||||
f'Benchmark searchsorted:',
|
|
||||||
f'- a [{B} x {A}]',
|
|
||||||
f'- v [{B} x {V}]',
|
|
||||||
f'- reporting fastest time of {repeats} runs',
|
|
||||||
f'- each run executes searchsorted {number} times',
|
|
||||||
sep='\n',
|
|
||||||
end='\n\n'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_arrays():
|
|
||||||
a = np.sort(np.random.randn(B, A), axis=1)
|
|
||||||
v = np.random.randn(B, V)
|
|
||||||
out = np.empty_like(v, dtype=np.long)
|
|
||||||
return a, v, out
|
|
||||||
|
|
||||||
|
|
||||||
def get_tensors(device):
|
|
||||||
a = torch.sort(torch.randn(B, A, device=device), dim=1)[0]
|
|
||||||
v = torch.randn(B, V, device=device)
|
|
||||||
out = torch.empty(B, V, device=device, dtype=torch.long)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
return a, v, out
|
|
||||||
|
|
||||||
def searchsorted_synchronized(a,v,out=None,side='left'):
|
|
||||||
out = searchsorted(a,v,out,side)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
return out
|
|
||||||
|
|
||||||
numpy = timeit.repeat(
|
|
||||||
stmt="numpy_searchsorted(a, v, side='left')",
|
|
||||||
setup="a, v, out = get_arrays()",
|
|
||||||
globals=globals(),
|
|
||||||
repeat=repeats,
|
|
||||||
number=number
|
|
||||||
)
|
|
||||||
print('Numpy: ', min(numpy), sep='\t')
|
|
||||||
|
|
||||||
cpu = timeit.repeat(
|
|
||||||
stmt="searchsorted(a, v, out, side='left')",
|
|
||||||
setup="a, v, out = get_tensors(device='cpu')",
|
|
||||||
globals=globals(),
|
|
||||||
repeat=repeats,
|
|
||||||
number=number
|
|
||||||
)
|
|
||||||
print('CPU: ', min(cpu), sep='\t')
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
gpu = timeit.repeat(
|
|
||||||
stmt="searchsorted_synchronized(a, v, out, side='left')",
|
|
||||||
setup="a, v, out = get_tensors(device='cuda')",
|
|
||||||
globals=globals(),
|
|
||||||
repeat=repeats,
|
|
||||||
number=number
|
|
||||||
)
|
|
||||||
print('CUDA: ', min(gpu), sep='\t')
|
|
|
@ -1,66 +0,0 @@
|
||||||
import torch
|
|
||||||
from torchsearchsorted import searchsorted, numpy_searchsorted
|
|
||||||
import time
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# defining the number of tests
|
|
||||||
ntests = 2
|
|
||||||
|
|
||||||
# defining the problem dimensions
|
|
||||||
nrows_a = 50000
|
|
||||||
nrows_v = 50000
|
|
||||||
nsorted_values = 300
|
|
||||||
nvalues = 1000
|
|
||||||
|
|
||||||
# defines the variables. The first run will comprise allocation, the
|
|
||||||
# further ones will not
|
|
||||||
test_GPU = None
|
|
||||||
test_CPU = None
|
|
||||||
|
|
||||||
for ntest in range(ntests):
|
|
||||||
print("\nLooking for %dx%d values in %dx%d entries" % (nrows_v, nvalues,
|
|
||||||
nrows_a,
|
|
||||||
nsorted_values))
|
|
||||||
|
|
||||||
side = 'right'
|
|
||||||
# generate a matrix with sorted rows
|
|
||||||
a = torch.randn(nrows_a, nsorted_values, device='cpu')
|
|
||||||
a = torch.sort(a, dim=1)[0]
|
|
||||||
# generate a matrix of values to searchsort
|
|
||||||
v = torch.randn(nrows_v, nvalues, device='cpu')
|
|
||||||
|
|
||||||
# a = torch.tensor([[0., 1.]])
|
|
||||||
# v = torch.tensor([[1.]])
|
|
||||||
|
|
||||||
t0 = time.time()
|
|
||||||
test_NP = torch.tensor(numpy_searchsorted(a, v, side))
|
|
||||||
print('NUMPY: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
|
|
||||||
t0 = time.time()
|
|
||||||
test_CPU = searchsorted(a, v, test_CPU, side)
|
|
||||||
print('CPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
|
|
||||||
# compute the difference between both
|
|
||||||
error_CPU = torch.norm(test_NP.double()
|
|
||||||
- test_CPU.double()).numpy()
|
|
||||||
if error_CPU:
|
|
||||||
import ipdb; ipdb.set_trace()
|
|
||||||
print(' difference between CPU and NUMPY: %0.3f' % error_CPU)
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print('CUDA is not available on this machine, cannot go further.')
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
# now do the CPU
|
|
||||||
a = a.to('cuda')
|
|
||||||
v = v.to('cuda')
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
# launch searchsorted on those
|
|
||||||
t0 = time.time()
|
|
||||||
test_GPU = searchsorted(a, v, test_GPU, side)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
print('GPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
|
|
||||||
|
|
||||||
# compute the difference between both
|
|
||||||
error_CUDA = torch.norm(test_NP.to('cuda').double()
|
|
||||||
- test_GPU.double()).cpu().numpy()
|
|
||||||
|
|
||||||
print(' difference between GPU and NUMPY: %0.3f' % error_CUDA)
|
|
|
@ -1,41 +0,0 @@
|
||||||
from setuptools import setup, find_packages
|
|
||||||
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME
|
|
||||||
from torch.utils.cpp_extension import CppExtension, CUDAExtension
|
|
||||||
|
|
||||||
# In any case, include the CPU version
|
|
||||||
modules = [
|
|
||||||
CppExtension('torchsearchsorted.cpu',
|
|
||||||
['src/cpu/searchsorted_cpu_wrapper.cpp']),
|
|
||||||
]
|
|
||||||
|
|
||||||
# If nvcc is available, add the CUDA extension
|
|
||||||
if CUDA_HOME:
|
|
||||||
modules.append(
|
|
||||||
CUDAExtension('torchsearchsorted.cuda',
|
|
||||||
['src/cuda/searchsorted_cuda_wrapper.cpp',
|
|
||||||
'src/cuda/searchsorted_cuda_kernel.cu'])
|
|
||||||
)
|
|
||||||
|
|
||||||
tests_require = [
|
|
||||||
'pytest',
|
|
||||||
]
|
|
||||||
|
|
||||||
# Now proceed to setup
|
|
||||||
setup(
|
|
||||||
name='torchsearchsorted',
|
|
||||||
version='1.1',
|
|
||||||
description='A searchsorted implementation for pytorch',
|
|
||||||
keywords='searchsorted',
|
|
||||||
author='Antoine Liutkus',
|
|
||||||
author_email='antoine.liutkus@inria.fr',
|
|
||||||
packages=find_packages(where='src'),
|
|
||||||
package_dir={"": "src"},
|
|
||||||
ext_modules=modules,
|
|
||||||
tests_require=tests_require,
|
|
||||||
extras_require={
|
|
||||||
'test': tests_require,
|
|
||||||
},
|
|
||||||
cmdclass={
|
|
||||||
'build_ext': BuildExtension
|
|
||||||
}
|
|
||||||
)
|
|
|
@ -1,126 +0,0 @@
|
||||||
#include "searchsorted_cpu_wrapper.h"
|
|
||||||
#include <stdio.h>
|
|
||||||
|
|
||||||
template<typename scalar_t>
|
|
||||||
int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
|
|
||||||
{
|
|
||||||
/* Evaluates whether a[row,col] < val <= a[row, col+1]*/
|
|
||||||
|
|
||||||
if (col == ncol - 1)
|
|
||||||
{
|
|
||||||
// special case: we are on the right border
|
|
||||||
if (a[row * ncol + col] <= val){
|
|
||||||
return 1;}
|
|
||||||
else {
|
|
||||||
return -1;}
|
|
||||||
}
|
|
||||||
bool is_lower;
|
|
||||||
bool is_next_higher;
|
|
||||||
|
|
||||||
if (side_left) {
|
|
||||||
// a[row, col] < v <= a[row, col+1]
|
|
||||||
is_lower = (a[row * ncol + col] < val);
|
|
||||||
is_next_higher = (a[row*ncol + col + 1] >= val);
|
|
||||||
} else {
|
|
||||||
// a[row, col] <= v < a[row, col+1]
|
|
||||||
is_lower = (a[row * ncol + col] <= val);
|
|
||||||
is_next_higher = (a[row * ncol + col + 1] > val);
|
|
||||||
}
|
|
||||||
if (is_lower && is_next_higher) {
|
|
||||||
// we found the right spot
|
|
||||||
return 0;
|
|
||||||
} else if (is_lower) {
|
|
||||||
// answer is on the right side
|
|
||||||
return 1;
|
|
||||||
} else {
|
|
||||||
// answer is on the left side
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename scalar_t>
|
|
||||||
int64_t binary_search(scalar_t*a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
|
|
||||||
{
|
|
||||||
/* Look for the value `val` within row `row` of matrix `a`, which
|
|
||||||
has `ncol` columns.
|
|
||||||
|
|
||||||
the `a` matrix is assumed sorted in increasing order, row-wise
|
|
||||||
|
|
||||||
returns:
|
|
||||||
* -1 if `val` is smaller than the smallest value found within that row of `a`
|
|
||||||
* `ncol` - 1 if `val` is larger than the largest element of that row of `a`
|
|
||||||
* Otherwise, return the column index `res` such that:
|
|
||||||
- a[row, col] < val <= a[row, col+1]. (if side_left), or
|
|
||||||
- a[row, col] < val <= a[row, col+1] (if not side_left).
|
|
||||||
*/
|
|
||||||
|
|
||||||
//start with left at 0 and right at number of columns of a
|
|
||||||
int64_t right = ncol;
|
|
||||||
int64_t left = 0;
|
|
||||||
|
|
||||||
while (right >= left) {
|
|
||||||
// take the midpoint of current left and right cursors
|
|
||||||
int64_t mid = left + (right-left)/2;
|
|
||||||
|
|
||||||
// check the relative position of val: are we good here ?
|
|
||||||
int rel_pos = eval(val, a, row, mid, ncol, side_left);
|
|
||||||
// we found the point
|
|
||||||
if(rel_pos == 0) {
|
|
||||||
return mid;
|
|
||||||
} else if (rel_pos > 0) {
|
|
||||||
if (mid==ncol-1){return ncol-1;}
|
|
||||||
// the answer is on the right side
|
|
||||||
left = mid;
|
|
||||||
} else {
|
|
||||||
if (mid==0){return -1;}
|
|
||||||
right = mid;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
void searchsorted_cpu_wrapper(
|
|
||||||
at::Tensor a,
|
|
||||||
at::Tensor v,
|
|
||||||
at::Tensor res,
|
|
||||||
bool side_left)
|
|
||||||
{
|
|
||||||
|
|
||||||
// Get the dimensions
|
|
||||||
auto nrow_a = a.size(/*dim=*/0);
|
|
||||||
auto ncol_a = a.size(/*dim=*/1);
|
|
||||||
auto nrow_v = v.size(/*dim=*/0);
|
|
||||||
auto ncol_v = v.size(/*dim=*/1);
|
|
||||||
|
|
||||||
auto nrow_res = fmax(nrow_a, nrow_v);
|
|
||||||
|
|
||||||
//auto acc_v = v.accessor<float, 2>();
|
|
||||||
//auto acc_res = res.accessor<float, 2>();
|
|
||||||
|
|
||||||
AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cpu", [&] {
|
|
||||||
|
|
||||||
scalar_t* a_data = a.data_ptr<scalar_t>();
|
|
||||||
scalar_t* v_data = v.data_ptr<scalar_t>();
|
|
||||||
int64_t* res_data = res.data<int64_t>();
|
|
||||||
|
|
||||||
for (int64_t row = 0; row < nrow_res; row++)
|
|
||||||
{
|
|
||||||
for (int64_t col = 0; col < ncol_v; col++)
|
|
||||||
{
|
|
||||||
// get the value to look for
|
|
||||||
int64_t row_in_v = (nrow_v == 1) ? 0 : row;
|
|
||||||
int64_t row_in_a = (nrow_a == 1) ? 0 : row;
|
|
||||||
|
|
||||||
int64_t idx_in_v = row_in_v * ncol_v + col;
|
|
||||||
int64_t idx_in_res = row * ncol_v + col;
|
|
||||||
|
|
||||||
// apply binary search
|
|
||||||
res_data[idx_in_res] = (binary_search(a_data, row_in_a, v_data[idx_in_v], ncol_a, side_left) + 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
m.def("searchsorted_cpu_wrapper", &searchsorted_cpu_wrapper, "searchsorted (CPU)");
|
|
||||||
}
|
|
|
@ -1,12 +0,0 @@
|
||||||
#ifndef _SEARCHSORTED_CPU
|
|
||||||
#define _SEARCHSORTED_CPU
|
|
||||||
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
void searchsorted_cpu_wrapper(
|
|
||||||
at::Tensor a,
|
|
||||||
at::Tensor v,
|
|
||||||
at::Tensor res,
|
|
||||||
bool side_left);
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -1,142 +0,0 @@
|
||||||
#include "searchsorted_cuda_kernel.h"
|
|
||||||
|
|
||||||
template <typename scalar_t>
|
|
||||||
__device__
|
|
||||||
int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
|
|
||||||
{
|
|
||||||
/* Evaluates whether a[row,col] < val <= a[row, col+1]*/
|
|
||||||
|
|
||||||
if (col == ncol - 1)
|
|
||||||
{
|
|
||||||
// special case: we are on the right border
|
|
||||||
if (a[row * ncol + col] <= val){
|
|
||||||
return 1;}
|
|
||||||
else {
|
|
||||||
return -1;}
|
|
||||||
}
|
|
||||||
bool is_lower;
|
|
||||||
bool is_next_higher;
|
|
||||||
|
|
||||||
if (side_left) {
|
|
||||||
// a[row, col] < v <= a[row, col+1]
|
|
||||||
is_lower = (a[row * ncol + col] < val);
|
|
||||||
is_next_higher = (a[row*ncol + col + 1] >= val);
|
|
||||||
} else {
|
|
||||||
// a[row, col] <= v < a[row, col+1]
|
|
||||||
is_lower = (a[row * ncol + col] <= val);
|
|
||||||
is_next_higher = (a[row * ncol + col + 1] > val);
|
|
||||||
}
|
|
||||||
if (is_lower && is_next_higher) {
|
|
||||||
// we found the right spot
|
|
||||||
return 0;
|
|
||||||
} else if (is_lower) {
|
|
||||||
// answer is on the right side
|
|
||||||
return 1;
|
|
||||||
} else {
|
|
||||||
// answer is on the left side
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename scalar_t>
|
|
||||||
__device__
|
|
||||||
int binary_search(scalar_t *a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
|
|
||||||
{
|
|
||||||
/* Look for the value `val` within row `row` of matrix `a`, which
|
|
||||||
has `ncol` columns.
|
|
||||||
|
|
||||||
the `a` matrix is assumed sorted in increasing order, row-wise
|
|
||||||
|
|
||||||
Returns
|
|
||||||
* -1 if `val` is smaller than the smallest value found within that row of `a`
|
|
||||||
* `ncol` - 1 if `val` is larger than the largest element of that row of `a`
|
|
||||||
* Otherwise, return the column index `res` such that:
|
|
||||||
- a[row, col] < val <= a[row, col+1]. (if side_left), or
|
|
||||||
- a[row, col] < val <= a[row, col+1] (if not side_left).
|
|
||||||
*/
|
|
||||||
|
|
||||||
//start with left at 0 and right at number of columns of a
|
|
||||||
int64_t right = ncol;
|
|
||||||
int64_t left = 0;
|
|
||||||
|
|
||||||
while (right >= left) {
|
|
||||||
// take the midpoint of current left and right cursors
|
|
||||||
int64_t mid = left + (right-left)/2;
|
|
||||||
|
|
||||||
// check the relative position of val: are we good here ?
|
|
||||||
int rel_pos = eval(val, a, row, mid, ncol, side_left);
|
|
||||||
// we found the point
|
|
||||||
if(rel_pos == 0) {
|
|
||||||
return mid;
|
|
||||||
} else if (rel_pos > 0) {
|
|
||||||
if (mid==ncol-1){return ncol-1;}
|
|
||||||
// the answer is on the right side
|
|
||||||
left = mid;
|
|
||||||
} else {
|
|
||||||
if (mid==0){return -1;}
|
|
||||||
right = mid;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename scalar_t>
|
|
||||||
__global__
|
|
||||||
void searchsorted_kernel(
|
|
||||||
int64_t *res,
|
|
||||||
scalar_t *a,
|
|
||||||
scalar_t *v,
|
|
||||||
int64_t nrow_res, int64_t nrow_a, int64_t nrow_v, int64_t ncol_a, int64_t ncol_v, bool side_left)
|
|
||||||
{
|
|
||||||
// get current row and column
|
|
||||||
int64_t row = blockIdx.y*blockDim.y+threadIdx.y;
|
|
||||||
int64_t col = blockIdx.x*blockDim.x+threadIdx.x;
|
|
||||||
|
|
||||||
// check whether we are outside the bounds of what needs be computed.
|
|
||||||
if ((row >= nrow_res) || (col >= ncol_v)) {
|
|
||||||
return;}
|
|
||||||
|
|
||||||
// get the value to look for
|
|
||||||
int64_t row_in_v = (nrow_v==1) ? 0: row;
|
|
||||||
int64_t row_in_a = (nrow_a==1) ? 0: row;
|
|
||||||
int64_t idx_in_v = row_in_v*ncol_v+col;
|
|
||||||
int64_t idx_in_res = row*ncol_v+col;
|
|
||||||
|
|
||||||
// apply binary search
|
|
||||||
res[idx_in_res] = binary_search(a, row_in_a, v[idx_in_v], ncol_a, side_left)+1;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void searchsorted_cuda(
|
|
||||||
at::Tensor a,
|
|
||||||
at::Tensor v,
|
|
||||||
at::Tensor res,
|
|
||||||
bool side_left){
|
|
||||||
|
|
||||||
// Get the dimensions
|
|
||||||
auto nrow_a = a.size(/*dim=*/0);
|
|
||||||
auto nrow_v = v.size(/*dim=*/0);
|
|
||||||
auto ncol_a = a.size(/*dim=*/1);
|
|
||||||
auto ncol_v = v.size(/*dim=*/1);
|
|
||||||
|
|
||||||
auto nrow_res = fmax(double(nrow_a), double(nrow_v));
|
|
||||||
|
|
||||||
// prepare the kernel configuration
|
|
||||||
dim3 threads(ncol_v, nrow_res);
|
|
||||||
dim3 blocks(1, 1);
|
|
||||||
if (nrow_res*ncol_v > 1024){
|
|
||||||
threads.x = int(fmin(double(1024), double(ncol_v)));
|
|
||||||
threads.y = floor(1024/threads.x);
|
|
||||||
blocks.x = ceil(double(ncol_v)/double(threads.x));
|
|
||||||
blocks.y = ceil(double(nrow_res)/double(threads.y));
|
|
||||||
}
|
|
||||||
|
|
||||||
AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cuda", ([&] {
|
|
||||||
searchsorted_kernel<scalar_t><<<blocks, threads>>>(
|
|
||||||
res.data<int64_t>(),
|
|
||||||
a.data<scalar_t>(),
|
|
||||||
v.data<scalar_t>(),
|
|
||||||
nrow_res, nrow_a, nrow_v, ncol_a, ncol_v, side_left);
|
|
||||||
}));
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,12 +0,0 @@
|
||||||
#ifndef _SEARCHSORTED_CUDA_KERNEL
|
|
||||||
#define _SEARCHSORTED_CUDA_KERNEL
|
|
||||||
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
void searchsorted_cuda(
|
|
||||||
at::Tensor a,
|
|
||||||
at::Tensor v,
|
|
||||||
at::Tensor res,
|
|
||||||
bool side_left);
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -1,20 +0,0 @@
|
||||||
#include "searchsorted_cuda_wrapper.h"
|
|
||||||
|
|
||||||
// C++ interface
|
|
||||||
|
|
||||||
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
|
|
||||||
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
|
||||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
|
||||||
|
|
||||||
void searchsorted_cuda_wrapper(at::Tensor a, at::Tensor v, at::Tensor res, bool side_left)
|
|
||||||
{
|
|
||||||
CHECK_INPUT(a);
|
|
||||||
CHECK_INPUT(v);
|
|
||||||
CHECK_INPUT(res);
|
|
||||||
|
|
||||||
searchsorted_cuda(a, v, res, side_left);
|
|
||||||
}
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
m.def("searchsorted_cuda_wrapper", &searchsorted_cuda_wrapper, "searchsorted (CUDA)");
|
|
||||||
}
|
|
|
@ -1,13 +0,0 @@
|
||||||
#ifndef _SEARCHSORTED_CUDA_WRAPPER
|
|
||||||
#define _SEARCHSORTED_CUDA_WRAPPER
|
|
||||||
|
|
||||||
#include <torch/extension.h>
|
|
||||||
#include "searchsorted_cuda_kernel.h"
|
|
||||||
|
|
||||||
void searchsorted_cuda_wrapper(
|
|
||||||
at::Tensor a,
|
|
||||||
at::Tensor v,
|
|
||||||
at::Tensor res,
|
|
||||||
bool side_left);
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -1,2 +0,0 @@
|
||||||
from .searchsorted import searchsorted
|
|
||||||
from .utils import numpy_searchsorted
|
|
|
@ -1,53 +0,0 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# trying to import the CPU searchsorted
|
|
||||||
SEARCHSORTED_CPU_AVAILABLE = True
|
|
||||||
try:
|
|
||||||
from torchsearchsorted.cpu import searchsorted_cpu_wrapper
|
|
||||||
except ImportError:
|
|
||||||
SEARCHSORTED_CPU_AVAILABLE = False
|
|
||||||
|
|
||||||
# trying to import the CUDA searchsorted
|
|
||||||
SEARCHSORTED_GPU_AVAILABLE = True
|
|
||||||
try:
|
|
||||||
from torchsearchsorted.cuda import searchsorted_cuda_wrapper
|
|
||||||
except ImportError:
|
|
||||||
SEARCHSORTED_GPU_AVAILABLE = False
|
|
||||||
|
|
||||||
|
|
||||||
def searchsorted(a: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: Optional[torch.LongTensor] = None,
|
|
||||||
side='left') -> torch.LongTensor:
|
|
||||||
assert len(a.shape) == 2, "input `a` must be 2-D."
|
|
||||||
assert len(v.shape) == 2, "input `v` mus(t be 2-D."
|
|
||||||
assert (a.shape[0] == v.shape[0]
|
|
||||||
or a.shape[0] == 1
|
|
||||||
or v.shape[0] == 1), ("`a` and `v` must have the same number of "
|
|
||||||
"rows or one of them must have only one ")
|
|
||||||
assert a.device == v.device, '`a` and `v` must be on the same device'
|
|
||||||
|
|
||||||
result_shape = (max(a.shape[0], v.shape[0]), v.shape[1])
|
|
||||||
if out is not None:
|
|
||||||
assert out.device == a.device, "`out` must be on the same device as `a`"
|
|
||||||
assert out.dtype == torch.long, "out.dtype must be torch.long"
|
|
||||||
assert out.shape == result_shape, ("If the output tensor is provided, "
|
|
||||||
"its shape must be correct.")
|
|
||||||
else:
|
|
||||||
out = torch.empty(result_shape, device=v.device, dtype=torch.long)
|
|
||||||
|
|
||||||
if a.is_cuda and not SEARCHSORTED_GPU_AVAILABLE:
|
|
||||||
raise Exception('torchsearchsorted on CUDA device is asked, but it seems '
|
|
||||||
'that it is not available. Please install it')
|
|
||||||
if not a.is_cuda and not SEARCHSORTED_CPU_AVAILABLE:
|
|
||||||
raise Exception('torchsearchsorted on CPU is not available. '
|
|
||||||
'Please install it.')
|
|
||||||
|
|
||||||
left_side = 1 if side=='left' else 0
|
|
||||||
if a.is_cuda:
|
|
||||||
searchsorted_cuda_wrapper(a, v, out, left_side)
|
|
||||||
else:
|
|
||||||
searchsorted_cpu_wrapper(a, v, out, left_side)
|
|
||||||
|
|
||||||
return out
|
|
|
@ -1,15 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def numpy_searchsorted(a: np.ndarray, v: np.ndarray, side='left'):
|
|
||||||
"""Numpy version of searchsorted that works batch-wise on pytorch tensors
|
|
||||||
"""
|
|
||||||
nrows_a = a.shape[0]
|
|
||||||
(nrows_v, ncols_v) = v.shape
|
|
||||||
nrows_out = max(nrows_a, nrows_v)
|
|
||||||
out = np.empty((nrows_out, ncols_v), dtype=np.long)
|
|
||||||
def sel(data, row):
|
|
||||||
return data[0] if data.shape[0] == 1 else data[row]
|
|
||||||
for row in range(nrows_out):
|
|
||||||
out[row] = np.searchsorted(sel(a, row), sel(v, row), side=side)
|
|
||||||
return out
|
|
|
@ -1,11 +0,0 @@
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
devices = {'cpu': torch.device('cpu')}
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
devices['cuda'] = torch.device('cuda:0')
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=devices.values(), ids=devices.keys())
|
|
||||||
def device(request):
|
|
||||||
return request.param
|
|
|
@ -1,44 +0,0 @@
|
||||||
import pytest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torchsearchsorted import searchsorted, numpy_searchsorted
|
|
||||||
from itertools import product, repeat
|
|
||||||
|
|
||||||
|
|
||||||
def test_searchsorted_output_dtype(device):
|
|
||||||
B = 100
|
|
||||||
A = 50
|
|
||||||
V = 12
|
|
||||||
|
|
||||||
a = torch.sort(torch.rand(B, V, device=device), dim=1)[0]
|
|
||||||
v = torch.rand(B, A, device=device)
|
|
||||||
|
|
||||||
out = searchsorted(a, v)
|
|
||||||
out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy())
|
|
||||||
assert out.dtype == torch.long
|
|
||||||
np.testing.assert_array_equal(out.cpu().numpy(), out_np)
|
|
||||||
|
|
||||||
out = torch.empty(v.shape, dtype=torch.long, device=device)
|
|
||||||
searchsorted(a, v, out)
|
|
||||||
assert out.dtype == torch.long
|
|
||||||
np.testing.assert_array_equal(out.cpu().numpy(), out_np)
|
|
||||||
|
|
||||||
Ba_val = [1, 100, 200]
|
|
||||||
Bv_val = [1, 100, 200]
|
|
||||||
A_val = [1, 50, 500]
|
|
||||||
V_val = [1, 12, 120]
|
|
||||||
side_val = ['left', 'right']
|
|
||||||
nrepeat = 100
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('Ba,Bv,A,V,side', product(Ba_val, Bv_val, A_val, V_val, side_val))
|
|
||||||
def test_searchsorted_correct(Ba, Bv, A, V, side, device):
|
|
||||||
if Ba > 1 and Bv > 1 and Ba != Bv:
|
|
||||||
return
|
|
||||||
for test in range(nrepeat):
|
|
||||||
a = torch.sort(torch.rand(Ba, A, device=device), dim=1)[0]
|
|
||||||
v = torch.rand(Bv, V, device=device)
|
|
||||||
out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(),
|
|
||||||
side=side)
|
|
||||||
out = searchsorted(a, v, side=side).cpu().numpy()
|
|
||||||
np.testing.assert_array_equal(out, out_np)
|
|
Loading…
Reference in a new issue