You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
45 lines
1.4 KiB
Python
45 lines
1.4 KiB
Python
4 years ago
|
import pytest
|
||
|
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
from torchsearchsorted import searchsorted, numpy_searchsorted
|
||
|
from itertools import product, repeat
|
||
|
|
||
|
|
||
|
def test_searchsorted_output_dtype(device):
|
||
|
B = 100
|
||
|
A = 50
|
||
|
V = 12
|
||
|
|
||
|
a = torch.sort(torch.rand(B, V, device=device), dim=1)[0]
|
||
|
v = torch.rand(B, A, device=device)
|
||
|
|
||
|
out = searchsorted(a, v)
|
||
|
out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy())
|
||
|
assert out.dtype == torch.long
|
||
|
np.testing.assert_array_equal(out.cpu().numpy(), out_np)
|
||
|
|
||
|
out = torch.empty(v.shape, dtype=torch.long, device=device)
|
||
|
searchsorted(a, v, out)
|
||
|
assert out.dtype == torch.long
|
||
|
np.testing.assert_array_equal(out.cpu().numpy(), out_np)
|
||
|
|
||
|
Ba_val = [1, 100, 200]
|
||
|
Bv_val = [1, 100, 200]
|
||
|
A_val = [1, 50, 500]
|
||
|
V_val = [1, 12, 120]
|
||
|
side_val = ['left', 'right']
|
||
|
nrepeat = 100
|
||
|
|
||
|
@pytest.mark.parametrize('Ba,Bv,A,V,side', product(Ba_val, Bv_val, A_val, V_val, side_val))
|
||
|
def test_searchsorted_correct(Ba, Bv, A, V, side, device):
|
||
|
if Ba > 1 and Bv > 1 and Ba != Bv:
|
||
|
return
|
||
|
for test in range(nrepeat):
|
||
|
a = torch.sort(torch.rand(Ba, A, device=device), dim=1)[0]
|
||
|
v = torch.rand(Bv, V, device=device)
|
||
|
out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(),
|
||
|
side=side)
|
||
|
out = searchsorted(a, v, side=side).cpu().numpy()
|
||
|
np.testing.assert_array_equal(out, out_np)
|