12 lines
251 B
Python
12 lines
251 B
Python
|
import pytest
|
||
|
import torch
|
||
|
|
||
|
# Devices that device-parametrized tests run on, keyed by the short id used
# for each parametrization. CPU is always present and comes first; the first
# CUDA device ('cuda:0') is included only when torch reports CUDA as available.
devices = dict(
    cpu=torch.device('cpu'),
    **({'cuda': torch.device('cuda:0')} if torch.cuda.is_available() else {}),
)
|
||
|
|
||
|
|
||
|
@pytest.fixture(params=list(devices.values()), ids=list(devices))
def device(request):
    """Provide each entry of the module-level ``devices`` map to a test in turn.

    Any test that requests the ``device`` fixture is run once per entry in
    ``devices``, with the dict key used as the parametrization id.

    ``params`` and ``ids`` are both materialized as lists from the same dict,
    so they line up one-to-one. pytest documents ``ids`` as a sequence (or a
    callable), which a ``dict_keys`` view is not.

    Returns:
        The device value for the current parametrization (``request.param``).
    """
    return request.param
|