You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

16 lines
530 B
Python

import numpy as np
def numpy_searchsorted(a: np.ndarray, v: np.ndarray, side='left'):
    """Apply ``np.searchsorted`` row by row to 2-D NumPy arrays.

    For each row ``i`` this returns the indices at which the values in
    ``v[i]`` would be inserted into ``a[i]`` to keep it sorted -- a NumPy
    analogue of batched ``torch.searchsorted``. If either ``a`` or ``v`` has
    exactly one row, that row is reused for every output row (size-1
    broadcasting of the leading dimension, as in NumPy).

    Args:
        a: 2-D array; each row must be sorted in ascending order, as
            required by ``np.searchsorted``.
        v: 2-D array of shape ``(nrows_v, ncols_v)`` holding the values to
            locate.
        side: ``'left'`` or ``'right'``; forwarded to ``np.searchsorted``.

    Returns:
        ``int64`` array of shape ``(nrows_out, ncols_v)``, where
        ``nrows_out`` is the broadcast row count of ``a`` and ``v``.

    Raises:
        ValueError: if ``a`` and ``v`` have different row counts and neither
            of them has exactly one row.
    """
    nrows_a = a.shape[0]
    nrows_v, ncols_v = v.shape

    # Broadcast the leading (batch) dimension: a size-1 side stretches to
    # match the other side, including down to 0 rows. Using max() here
    # would wrongly produce 1 output row when the other side is empty.
    if nrows_a == 1:
        nrows_out = nrows_v
    elif nrows_v == 1 or nrows_a == nrows_v:
        nrows_out = nrows_a
    else:
        raise ValueError(
            f"cannot broadcast row counts {nrows_a} (a) and {nrows_v} (v)"
        )

    # np.long is absent in NumPy 1.24-1.26 and platform-dependent in 2.x;
    # int64 is portable and matches torch.long.
    out = np.empty((nrows_out, ncols_v), dtype=np.int64)
    for row in range(nrows_out):
        a_row = a[0] if nrows_a == 1 else a[row]
        v_row = v[0] if nrows_v == 1 else v[row]
        out[row] = np.searchsorted(a_row, v_row, side=side)
    return out