rm unneeded dir for torch search
This commit is contained in:
parent
56033b32f5
commit
71e3a13134
17 changed files with 0 additions and 904 deletions
158 torchsearchsorted/.gitignore vendored
@@ -1,158 +0,0 @@
# Prerequisites
*.d

# Object files
*.o
*.ko
*.obj
*.elf

# Linker output
*.ilk
*.map
*.exp

# Precompiled Headers
*.gch
*.pch

# Libraries
*.lib
*.a
*.la
*.lo

# Shared objects (inc. Windows DLLs)
*.dll
*.so
*.so.*
*.dylib

# Executables
*.exe
*.out
*.app
*.i*86
*.x86_64
*.hex

# Debug files
*.dSYM/
*.su
*.idb
*.pdb

# Kernel Module Compile Results
*.mod*
*.cmd
.tmp_versions/
modules.order
Module.symvers
Mkfile.old
dkms.conf


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
@@ -1,29 +0,0 @@
BSD 3-Clause License

Copyright (c) 2019, Inria (Antoine Liutkus)
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -1,89 +0,0 @@
# Pytorch Custom CUDA kernel for searchsorted

This repository is an implementation of the searchsorted function to work for pytorch CUDA Tensors. Initially derived from the great [C extension tutorial](https://github.com/chrischoy/pytorch-custom-cuda-tutorial), but totally changed since then, because building C extensions that way is no longer available on pytorch 1.0.

> Warnings:
> * only works with pytorch > v1.3 and CUDA >= v10.1
> * **NOTE** When using `searchsorted()` for practical applications, tensors need to be contiguous in memory. This can easily be achieved by calling `tensor.contiguous()` on the input tensors. Failing to do so _will_ lead to inconsistent results across applications.

## Description

Implements a function `searchsorted(a, v, out, side)` that works just like the [numpy version](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted) except that `a` and `v` are matrices.
* `a` is of shape either `(1, ncols_a)` or `(nrows, ncols_a)`, and is contiguous in memory (do `a.contiguous()` to ensure this).
* `v` is of shape either `(1, ncols_v)` or `(nrows, ncols_v)`, and is contiguous in memory (do `v.contiguous()` to ensure this).
* `out` is either `None` or of shape `(nrows, ncols_v)`. If provided and of the right shape, the result is written there, which avoids costly memory allocations if the user has already made them. If provided, `out` should be contiguous in memory too (do `out.contiguous()` to ensure this).
* `side` is either "left" or "right". See the [numpy doc](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted). Please note that the current implementation *does not correctly handle this parameter*. Help is welcome to improve the speed of [this PR](https://github.com/aliutkus/torchsearchsorted/pull/7).

The output is of shape `(nrows, ncols_v)`. If all input tensors are on GPU, the CUDA version is called; otherwise, the computation runs on CPU.
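
For instance, a minimal call (the shapes here are arbitrary, on CPU) could look like:

```
import torch
from torchsearchsorted import searchsorted

a = torch.sort(torch.randn(3, 5), dim=1)[0].contiguous()  # (nrows, ncols_a), rows sorted
v = torch.randn(3, 4).contiguous()                        # (nrows, ncols_v)
out = searchsorted(a, v, side='left')                     # LongTensor of shape (3, 4)
```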
**Disclaimers**

* This function has not been heavily tested. Use at your own risk.
* When `a` is not sorted, the results differ from numpy's version. I decided not to care about this, because the function should not be called in that case.
* In some cases, the results differ from numpy's version. However, as far as I could see, this only happens when values are equal, in which case we actually don't care about the order in which the value is inserted. I decided not to care about this either.
* Vectors have to be contiguous for torchsearchsorted to give consistent results. Use `.contiguous()` on all tensor arguments before calling.

## Installation

Just `pip install .`, in the root folder of this repo. This will compile
and install the torchsearchsorted module.

Be careful: sometimes, `nvcc` needs versions of `gcc` and `g++` that are older than those found by default on the system. If so, just create symbolic links to the right versions in your cuda/bin folder (where `nvcc` is).

For instance, on my machine, I had `gcc` and `g++` v9 installed, but `nvcc` required v8.
So I had to do:

> sudo apt-get install g++-8 gcc-8
> sudo ln -s /usr/bin/gcc-8 /usr/local/cuda-10.1/bin/gcc
> sudo ln -s /usr/bin/g++-8 /usr/local/cuda-10.1/bin/g++

Note that you need pytorch to be installed on your system. The code was tested on pytorch v1.3.
## Usage

Just import the torchsearchsorted package after installation. I typically do:

```
from torchsearchsorted import searchsorted
```
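
To avoid reallocating the result at each call, you can also preallocate `out` yourself, assuming `a` and `v` are set up as in the Description above:

```
out = torch.empty((max(a.shape[0], v.shape[0]), v.shape[1]),
                  dtype=torch.long, device=v.device)
searchsorted(a, v, out)  # the result is written into the preallocated `out`
```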
## Testing

Under the `examples` subfolder, you may:

1. try `python test.py` with `torch` available.

```
Looking for 50000x1000 values in 50000x300 entries
NUMPY: searchsorted in 4851.592ms
CPU: searchsorted in 4805.432ms
    difference between CPU and NUMPY: 0.000
GPU: searchsorted in 1.055ms
    difference between GPU and NUMPY: 0.000

Looking for 50000x1000 values in 50000x300 entries
NUMPY: searchsorted in 4333.964ms
CPU: searchsorted in 4753.958ms
    difference between CPU and NUMPY: 0.000
GPU: searchsorted in 0.391ms
    difference between GPU and NUMPY: 0.000
```

The first run includes allocation time, while the second one does not.

2. You may also use the nice `benchmark.py` code written by [@baldassarreFe](https://github.com/baldassarreFe), which tests `searchsorted` over many runs:

```
Benchmark searchsorted:
- a [5000 x 300]
- v [5000 x 100]
- reporting fastest time of 20 runs
- each run executes searchsorted 100 times

Numpy:	4.6302046799100935
CPU:	5.041533078998327
CUDA:	0.0007955809123814106
```
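
Both scripts assume they are run from the `examples` folder:

```
python test.py
python benchmark.py
```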
@@ -1,71 +0,0 @@
import timeit

import torch
import numpy as np
from torchsearchsorted import searchsorted, numpy_searchsorted

B = 5_000
A = 300
V = 100

repeats = 20
number = 100

print(
    f'Benchmark searchsorted:',
    f'- a [{B} x {A}]',
    f'- v [{B} x {V}]',
    f'- reporting fastest time of {repeats} runs',
    f'- each run executes searchsorted {number} times',
    sep='\n',
    end='\n\n'
)


def get_arrays():
    a = np.sort(np.random.randn(B, A), axis=1)
    v = np.random.randn(B, V)
    out = np.empty_like(v, dtype=np.int64)  # int64 indices (np.long is gone in recent numpy)
    return a, v, out


def get_tensors(device):
    a = torch.sort(torch.randn(B, A, device=device), dim=1)[0]
    v = torch.randn(B, V, device=device)
    out = torch.empty(B, V, device=device, dtype=torch.long)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return a, v, out


def searchsorted_synchronized(a, v, out=None, side='left'):
    out = searchsorted(a, v, out, side)
    torch.cuda.synchronize()
    return out


numpy = timeit.repeat(
    stmt="numpy_searchsorted(a, v, side='left')",
    setup="a, v, out = get_arrays()",
    globals=globals(),
    repeat=repeats,
    number=number
)
print('Numpy: ', min(numpy), sep='\t')

cpu = timeit.repeat(
    stmt="searchsorted(a, v, out, side='left')",
    setup="a, v, out = get_tensors(device='cpu')",
    globals=globals(),
    repeat=repeats,
    number=number
)
print('CPU: ', min(cpu), sep='\t')

if torch.cuda.is_available():
    gpu = timeit.repeat(
        stmt="searchsorted_synchronized(a, v, out, side='left')",
        setup="a, v, out = get_tensors(device='cuda')",
        globals=globals(),
        repeat=repeats,
        number=number
    )
    print('CUDA: ', min(gpu), sep='\t')
@@ -1,66 +0,0 @@
import torch
from torchsearchsorted import searchsorted, numpy_searchsorted
import time

if __name__ == '__main__':
    # defining the number of tests
    ntests = 2

    # defining the problem dimensions
    nrows_a = 50000
    nrows_v = 50000
    nsorted_values = 300
    nvalues = 1000

    # defines the variables. The first run will comprise allocation, the
    # further ones will not
    test_GPU = None
    test_CPU = None

    for ntest in range(ntests):
        print("\nLooking for %dx%d values in %dx%d entries" % (nrows_v, nvalues,
                                                               nrows_a,
                                                               nsorted_values))

        side = 'right'
        # generate a matrix with sorted rows
        a = torch.randn(nrows_a, nsorted_values, device='cpu')
        a = torch.sort(a, dim=1)[0]
        # generate a matrix of values to searchsort
        v = torch.randn(nrows_v, nvalues, device='cpu')

        # a = torch.tensor([[0., 1.]])
        # v = torch.tensor([[1.]])

        t0 = time.time()
        test_NP = torch.tensor(numpy_searchsorted(a, v, side))
        print('NUMPY: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
        t0 = time.time()
        test_CPU = searchsorted(a, v, test_CPU, side)
        print('CPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
        # compute the difference between both
        error_CPU = torch.norm(test_NP.double()
                               - test_CPU.double()).numpy()
        if error_CPU:
            import ipdb; ipdb.set_trace()
        print('    difference between CPU and NUMPY: %0.3f' % error_CPU)

        if not torch.cuda.is_available():
            print('CUDA is not available on this machine, cannot go further.')
            continue
        else:
            # now do the GPU
            a = a.to('cuda')
            v = v.to('cuda')
            torch.cuda.synchronize()
            # launch searchsorted on those
            t0 = time.time()
            test_GPU = searchsorted(a, v, test_GPU, side)
            torch.cuda.synchronize()
            print('GPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))

            # compute the difference between both
            error_CUDA = torch.norm(test_NP.to('cuda').double()
                                    - test_GPU.double()).cpu().numpy()

            print('    difference between GPU and NUMPY: %0.3f' % error_CUDA)
@@ -1,41 +0,0 @@
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME
from torch.utils.cpp_extension import CppExtension, CUDAExtension

# In any case, include the CPU version
modules = [
    CppExtension('torchsearchsorted.cpu',
                 ['src/cpu/searchsorted_cpu_wrapper.cpp']),
]

# If nvcc is available, add the CUDA extension
if CUDA_HOME:
    modules.append(
        CUDAExtension('torchsearchsorted.cuda',
                      ['src/cuda/searchsorted_cuda_wrapper.cpp',
                       'src/cuda/searchsorted_cuda_kernel.cu'])
    )

tests_require = [
    'pytest',
]

# Now proceed to setup
setup(
    name='torchsearchsorted',
    version='1.1',
    description='A searchsorted implementation for pytorch',
    keywords='searchsorted',
    author='Antoine Liutkus',
    author_email='antoine.liutkus@inria.fr',
    packages=find_packages(where='src'),
    package_dir={"": "src"},
    ext_modules=modules,
    tests_require=tests_require,
    extras_require={
        'test': tests_require,
    },
    cmdclass={
        'build_ext': BuildExtension
    }
)
@@ -1,126 +0,0 @@
#include "searchsorted_cpu_wrapper.h"
#include <stdio.h>

template<typename scalar_t>
int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
{
  /* Evaluates whether a[row, col] < val <= a[row, col+1] */

  if (col == ncol - 1)
  {
    // special case: we are on the right border
    if (a[row * ncol + col] <= val) {
      return 1;
    } else {
      return -1;
    }
  }
  bool is_lower;
  bool is_next_higher;

  if (side_left) {
    // a[row, col] < v <= a[row, col+1]
    is_lower = (a[row * ncol + col] < val);
    is_next_higher = (a[row * ncol + col + 1] >= val);
  } else {
    // a[row, col] <= v < a[row, col+1]
    is_lower = (a[row * ncol + col] <= val);
    is_next_higher = (a[row * ncol + col + 1] > val);
  }
  if (is_lower && is_next_higher) {
    // we found the right spot
    return 0;
  } else if (is_lower) {
    // answer is on the right side
    return 1;
  } else {
    // answer is on the left side
    return -1;
  }
}

template<typename scalar_t>
int64_t binary_search(scalar_t *a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
{
  /* Look for the value `val` within row `row` of matrix `a`, which
  has `ncol` columns.

  The `a` matrix is assumed sorted in increasing order, row-wise.

  Returns:
  * -1 if `val` is smaller than the smallest value found within that row of `a`
  * `ncol` - 1 if `val` is larger than the largest element of that row of `a`
  * Otherwise, the column index `res` such that:
    - a[row, res] < val <= a[row, res+1] (if side_left), or
    - a[row, res] <= val < a[row, res+1] (if not side_left).
  */

  // start with left at 0 and right at number of columns of a
  int64_t right = ncol;
  int64_t left = 0;

  while (right >= left) {
    // take the midpoint of current left and right cursors
    int64_t mid = left + (right - left) / 2;

    // check the relative position of val: are we good here?
    int rel_pos = eval(val, a, row, mid, ncol, side_left);
    // we found the point
    if (rel_pos == 0) {
      return mid;
    } else if (rel_pos > 0) {
      if (mid == ncol - 1) { return ncol - 1; }
      // the answer is on the right side
      left = mid;
    } else {
      if (mid == 0) { return -1; }
      right = mid;
    }
  }
  return -1;
}

void searchsorted_cpu_wrapper(
  at::Tensor a,
  at::Tensor v,
  at::Tensor res,
  bool side_left)
{
  // Get the dimensions
  auto nrow_a = a.size(/*dim=*/0);
  auto ncol_a = a.size(/*dim=*/1);
  auto nrow_v = v.size(/*dim=*/0);
  auto ncol_v = v.size(/*dim=*/1);

  auto nrow_res = fmax(nrow_a, nrow_v);

  //auto acc_v = v.accessor<float, 2>();
  //auto acc_res = res.accessor<float, 2>();

  AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cpu", [&] {

    scalar_t* a_data = a.data_ptr<scalar_t>();
    scalar_t* v_data = v.data_ptr<scalar_t>();
    int64_t* res_data = res.data_ptr<int64_t>();

    for (int64_t row = 0; row < nrow_res; row++)
    {
      for (int64_t col = 0; col < ncol_v; col++)
      {
        // get the value to look for
        int64_t row_in_v = (nrow_v == 1) ? 0 : row;
        int64_t row_in_a = (nrow_a == 1) ? 0 : row;

        int64_t idx_in_v = row_in_v * ncol_v + col;
        int64_t idx_in_res = row * ncol_v + col;

        // apply binary search
        res_data[idx_in_res] = (binary_search(a_data, row_in_a, v_data[idx_in_v], ncol_a, side_left) + 1);
      }
    }
  });
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("searchsorted_cpu_wrapper", &searchsorted_cpu_wrapper, "searchsorted (CPU)");
}
@@ -1,12 +0,0 @@
#ifndef _SEARCHSORTED_CPU
#define _SEARCHSORTED_CPU

#include <torch/extension.h>

void searchsorted_cpu_wrapper(
  at::Tensor a,
  at::Tensor v,
  at::Tensor res,
  bool side_left);

#endif
@@ -1,142 +0,0 @@
#include "searchsorted_cuda_kernel.h"

template <typename scalar_t>
__device__
int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
{
  /* Evaluates whether a[row, col] < val <= a[row, col+1] */

  if (col == ncol - 1)
  {
    // special case: we are on the right border
    if (a[row * ncol + col] <= val) {
      return 1;
    } else {
      return -1;
    }
  }
  bool is_lower;
  bool is_next_higher;

  if (side_left) {
    // a[row, col] < v <= a[row, col+1]
    is_lower = (a[row * ncol + col] < val);
    is_next_higher = (a[row * ncol + col + 1] >= val);
  } else {
    // a[row, col] <= v < a[row, col+1]
    is_lower = (a[row * ncol + col] <= val);
    is_next_higher = (a[row * ncol + col + 1] > val);
  }
  if (is_lower && is_next_higher) {
    // we found the right spot
    return 0;
  } else if (is_lower) {
    // answer is on the right side
    return 1;
  } else {
    // answer is on the left side
    return -1;
  }
}

template <typename scalar_t>
__device__
int binary_search(scalar_t *a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
{
  /* Look for the value `val` within row `row` of matrix `a`, which
  has `ncol` columns.

  The `a` matrix is assumed sorted in increasing order, row-wise.

  Returns:
  * -1 if `val` is smaller than the smallest value found within that row of `a`
  * `ncol` - 1 if `val` is larger than the largest element of that row of `a`
  * Otherwise, the column index `res` such that:
    - a[row, res] < val <= a[row, res+1] (if side_left), or
    - a[row, res] <= val < a[row, res+1] (if not side_left).
  */

  // start with left at 0 and right at number of columns of a
  int64_t right = ncol;
  int64_t left = 0;

  while (right >= left) {
    // take the midpoint of current left and right cursors
    int64_t mid = left + (right - left) / 2;

    // check the relative position of val: are we good here?
    int rel_pos = eval(val, a, row, mid, ncol, side_left);
    // we found the point
    if (rel_pos == 0) {
      return mid;
    } else if (rel_pos > 0) {
      if (mid == ncol - 1) { return ncol - 1; }
      // the answer is on the right side
      left = mid;
    } else {
      if (mid == 0) { return -1; }
      right = mid;
    }
  }
  return -1;
}

template <typename scalar_t>
__global__
void searchsorted_kernel(
  int64_t *res,
  scalar_t *a,
  scalar_t *v,
  int64_t nrow_res, int64_t nrow_a, int64_t nrow_v, int64_t ncol_a, int64_t ncol_v, bool side_left)
{
  // get current row and column
  int64_t row = blockIdx.y * blockDim.y + threadIdx.y;
  int64_t col = blockIdx.x * blockDim.x + threadIdx.x;

  // check whether we are outside the bounds of what needs to be computed.
  if ((row >= nrow_res) || (col >= ncol_v)) {
    return;
  }

  // get the value to look for
  int64_t row_in_v = (nrow_v == 1) ? 0 : row;
  int64_t row_in_a = (nrow_a == 1) ? 0 : row;
  int64_t idx_in_v = row_in_v * ncol_v + col;
  int64_t idx_in_res = row * ncol_v + col;

  // apply binary search
  res[idx_in_res] = binary_search(a, row_in_a, v[idx_in_v], ncol_a, side_left) + 1;
}


void searchsorted_cuda(
  at::Tensor a,
  at::Tensor v,
  at::Tensor res,
  bool side_left)
{
  // Get the dimensions
  auto nrow_a = a.size(/*dim=*/0);
  auto nrow_v = v.size(/*dim=*/0);
  auto ncol_a = a.size(/*dim=*/1);
  auto ncol_v = v.size(/*dim=*/1);

  auto nrow_res = fmax(double(nrow_a), double(nrow_v));

  // prepare the kernel configuration
  dim3 threads(ncol_v, nrow_res);
  dim3 blocks(1, 1);
  if (nrow_res * ncol_v > 1024) {
    threads.x = int(fmin(double(1024), double(ncol_v)));
    threads.y = floor(1024 / threads.x);
    blocks.x = ceil(double(ncol_v) / double(threads.x));
    blocks.y = ceil(double(nrow_res) / double(threads.y));
  }

  AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cuda", ([&] {
    searchsorted_kernel<scalar_t><<<blocks, threads>>>(
      res.data_ptr<int64_t>(),
      a.data_ptr<scalar_t>(),
      v.data_ptr<scalar_t>(),
      nrow_res, nrow_a, nrow_v, ncol_a, ncol_v, side_left);
  }));
}
@@ -1,12 +0,0 @@
#ifndef _SEARCHSORTED_CUDA_KERNEL
#define _SEARCHSORTED_CUDA_KERNEL

#include <torch/extension.h>

void searchsorted_cuda(
  at::Tensor a,
  at::Tensor v,
  at::Tensor res,
  bool side_left);

#endif
@@ -1,20 +0,0 @@
#include "searchsorted_cuda_wrapper.h"

// C++ interface

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

void searchsorted_cuda_wrapper(at::Tensor a, at::Tensor v, at::Tensor res, bool side_left)
{
  CHECK_INPUT(a);
  CHECK_INPUT(v);
  CHECK_INPUT(res);

  searchsorted_cuda(a, v, res, side_left);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("searchsorted_cuda_wrapper", &searchsorted_cuda_wrapper, "searchsorted (CUDA)");
}
@@ -1,13 +0,0 @@
#ifndef _SEARCHSORTED_CUDA_WRAPPER
#define _SEARCHSORTED_CUDA_WRAPPER

#include <torch/extension.h>
#include "searchsorted_cuda_kernel.h"

void searchsorted_cuda_wrapper(
  at::Tensor a,
  at::Tensor v,
  at::Tensor res,
  bool side_left);

#endif
@@ -1,2 +0,0 @@
from .searchsorted import searchsorted
from .utils import numpy_searchsorted
@@ -1,53 +0,0 @@
from typing import Optional

import torch

# trying to import the CPU searchsorted
SEARCHSORTED_CPU_AVAILABLE = True
try:
    from torchsearchsorted.cpu import searchsorted_cpu_wrapper
except ImportError:
    SEARCHSORTED_CPU_AVAILABLE = False

# trying to import the CUDA searchsorted
SEARCHSORTED_GPU_AVAILABLE = True
try:
    from torchsearchsorted.cuda import searchsorted_cuda_wrapper
except ImportError:
    SEARCHSORTED_GPU_AVAILABLE = False


def searchsorted(a: torch.Tensor, v: torch.Tensor,
                 out: Optional[torch.LongTensor] = None,
                 side='left') -> torch.LongTensor:
    assert len(a.shape) == 2, "input `a` must be 2-D."
    assert len(v.shape) == 2, "input `v` must be 2-D."
    assert (a.shape[0] == v.shape[0]
            or a.shape[0] == 1
            or v.shape[0] == 1), ("`a` and `v` must have the same number of "
                                  "rows, or one of them must have only one row")
    assert a.device == v.device, '`a` and `v` must be on the same device'

    result_shape = (max(a.shape[0], v.shape[0]), v.shape[1])
    if out is not None:
        assert out.device == a.device, "`out` must be on the same device as `a`"
        assert out.dtype == torch.long, "out.dtype must be torch.long"
        assert out.shape == result_shape, ("If the output tensor is provided, "
                                           "its shape must be correct.")
    else:
        out = torch.empty(result_shape, device=v.device, dtype=torch.long)

    if a.is_cuda and not SEARCHSORTED_GPU_AVAILABLE:
        raise Exception('torchsearchsorted was asked to run on a CUDA device, '
                        'but it seems the CUDA extension is not available. '
                        'Please install it.')
    if not a.is_cuda and not SEARCHSORTED_CPU_AVAILABLE:
        raise Exception('torchsearchsorted on CPU is not available. '
                        'Please install it.')

    left_side = 1 if side == 'left' else 0
    if a.is_cuda:
        searchsorted_cuda_wrapper(a, v, out, left_side)
    else:
        searchsorted_cpu_wrapper(a, v, out, left_side)

    return out
@@ -1,15 +0,0 @@
import numpy as np


def numpy_searchsorted(a: np.ndarray, v: np.ndarray, side='left'):
    """Numpy version of searchsorted that works batch-wise on pytorch tensors
    """
    nrows_a = a.shape[0]
    (nrows_v, ncols_v) = v.shape
    nrows_out = max(nrows_a, nrows_v)
    out = np.empty((nrows_out, ncols_v), dtype=np.int64)  # int64 indices (np.long is gone in recent numpy)

    def sel(data, row):
        return data[0] if data.shape[0] == 1 else data[row]

    for row in range(nrows_out):
        out[row] = np.searchsorted(sel(a, row), sel(v, row), side=side)
    return out
@@ -1,11 +0,0 @@
import pytest
import torch

devices = {'cpu': torch.device('cpu')}
if torch.cuda.is_available():
    devices['cuda'] = torch.device('cuda:0')


@pytest.fixture(params=devices.values(), ids=devices.keys())
def device(request):
    return request.param
@@ -1,44 +0,0 @@
import pytest

import torch
import numpy as np
from torchsearchsorted import searchsorted, numpy_searchsorted
from itertools import product


def test_searchsorted_output_dtype(device):
    B = 100
    A = 50
    V = 12

    a = torch.sort(torch.rand(B, V, device=device), dim=1)[0]
    v = torch.rand(B, A, device=device)

    out = searchsorted(a, v)
    out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy())
    assert out.dtype == torch.long
    np.testing.assert_array_equal(out.cpu().numpy(), out_np)

    out = torch.empty(v.shape, dtype=torch.long, device=device)
    searchsorted(a, v, out)
    assert out.dtype == torch.long
    np.testing.assert_array_equal(out.cpu().numpy(), out_np)


Ba_val = [1, 100, 200]
Bv_val = [1, 100, 200]
A_val = [1, 50, 500]
V_val = [1, 12, 120]
side_val = ['left', 'right']
nrepeat = 100


@pytest.mark.parametrize('Ba,Bv,A,V,side', product(Ba_val, Bv_val, A_val, V_val, side_val))
def test_searchsorted_correct(Ba, Bv, A, V, side, device):
    if Ba > 1 and Bv > 1 and Ba != Bv:
        return
    for test in range(nrepeat):
        a = torch.sort(torch.rand(Ba, A, device=device), dim=1)[0]
        v = torch.rand(Bv, V, device=device)
        out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(),
                                    side=side)
        out = searchsorted(a, v, side=side).cpu().numpy()
        np.testing.assert_array_equal(out, out_np)