You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
2.0 KiB
Python
54 lines
2.0 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
|
|
# trying to import the CPU searchsorted
|
|
SEARCHSORTED_CPU_AVAILABLE = True
|
|
try:
|
|
from torchsearchsorted.cpu import searchsorted_cpu_wrapper
|
|
except ImportError:
|
|
SEARCHSORTED_CPU_AVAILABLE = False
|
|
|
|
# trying to import the CUDA searchsorted
|
|
SEARCHSORTED_GPU_AVAILABLE = True
|
|
try:
|
|
from torchsearchsorted.cuda import searchsorted_cuda_wrapper
|
|
except ImportError:
|
|
SEARCHSORTED_GPU_AVAILABLE = False
|
|
|
|
|
|
def searchsorted(a: torch.Tensor, v: torch.Tensor,
|
|
out: Optional[torch.LongTensor] = None,
|
|
side='left') -> torch.LongTensor:
|
|
assert len(a.shape) == 2, "input `a` must be 2-D."
|
|
assert len(v.shape) == 2, "input `v` mus(t be 2-D."
|
|
assert (a.shape[0] == v.shape[0]
|
|
or a.shape[0] == 1
|
|
or v.shape[0] == 1), ("`a` and `v` must have the same number of "
|
|
"rows or one of them must have only one ")
|
|
assert a.device == v.device, '`a` and `v` must be on the same device'
|
|
|
|
result_shape = (max(a.shape[0], v.shape[0]), v.shape[1])
|
|
if out is not None:
|
|
assert out.device == a.device, "`out` must be on the same device as `a`"
|
|
assert out.dtype == torch.long, "out.dtype must be torch.long"
|
|
assert out.shape == result_shape, ("If the output tensor is provided, "
|
|
"its shape must be correct.")
|
|
else:
|
|
out = torch.empty(result_shape, device=v.device, dtype=torch.long)
|
|
|
|
if a.is_cuda and not SEARCHSORTED_GPU_AVAILABLE:
|
|
raise Exception('torchsearchsorted on CUDA device is asked, but it seems '
|
|
'that it is not available. Please install it')
|
|
if not a.is_cuda and not SEARCHSORTED_CPU_AVAILABLE:
|
|
raise Exception('torchsearchsorted on CPU is not available. '
|
|
'Please install it.')
|
|
|
|
left_side = 1 if side=='left' else 0
|
|
if a.is_cuda:
|
|
searchsorted_cuda_wrapper(a, v, out, left_side)
|
|
else:
|
|
searchsorted_cpu_wrapper(a, v, out, left_side)
|
|
|
|
return out
|