Minor bugfix in sample_pdf
This commit is contained in:
parent
f913df5946
commit
581ea38d78
1 changed files with 2 additions and 2 deletions
|
@ -225,7 +225,7 @@ def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
|
||||||
u = u.contiguous()
|
u = u.contiguous()
|
||||||
inds = searchsorted(cdf, u, side='right')
|
inds = searchsorted(cdf, u, side='right')
|
||||||
below = torch.max(torch.zeros_like(inds-1), inds-1)
|
below = torch.max(torch.zeros_like(inds-1), inds-1)
|
||||||
above = torch.min(cdf.shape[-1]-1 * torch.ones_like(inds), inds)
|
above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
|
||||||
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
|
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
|
||||||
|
|
||||||
# cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
|
# cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
|
||||||
|
@ -239,4 +239,4 @@ def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
|
||||||
t = (u-cdf_g[...,0])/denom
|
t = (u-cdf_g[...,0])/denom
|
||||||
samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
|
samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
Loading…
Reference in a new issue