# Source code for lagom.utils.dtype
import numpy as np
import torch
def tensorify(x, device):
    """Convert ``x`` to a :class:`torch.Tensor` located on ``device``.

    Tensors keep their dtype and are only moved when their device differs.
    Anything else (NumPy array, Python scalar, nested list, ...) is converted
    through NumPy and cast to float32.

    Args:
        x: a tensor, a NumPy array, or anything ``np.asarray`` accepts.
        device: target device, e.g. ``'cpu'``, ``'cuda:0'`` or a ``torch.device``.

    Returns:
        torch.Tensor: the converted tensor. For a writable float32 NumPy array
        with non-negative strides and ``device='cpu'``, the result shares
        memory with the input.
    """
    if torch.is_tensor(x):
        # Compare as strings so both str and torch.device arguments work.
        if str(x.device) != str(device):
            x = x.to(device)
        return x
    # np.asarray is a no-op for ndarrays, so one path serves every non-tensor input.
    arr = np.asarray(x)
    # torch.from_numpy rejects negative strides (e.g. arr[::-1]) and warns on
    # read-only buffers (e.g. np.broadcast_to); copy into a fresh C-ordered array.
    if not arr.flags.writeable or any(stride < 0 for stride in arr.strides):
        arr = arr.copy()
    return torch.from_numpy(arr).float().to(device)
def numpify(x, dtype):
    """Convert ``x`` to a NumPy array of the given ``dtype``.

    Args:
        x: a tensor (on any device, with or without autograd history), a NumPy
            array, or anything ``np.asarray`` accepts.
        dtype: target NumPy dtype, e.g. ``np.float32``.

    Returns:
        np.ndarray: an array with dtype ``dtype``. For tensor and ndarray inputs
        this is always a newly allocated copy, never a view of ``x``.
    """
    if torch.is_tensor(x):
        # detach() drops autograd history; cpu() is needed because
        # Tensor.numpy() only works on CPU tensors.
        x = x.detach().cpu()
        # NumPy has no bfloat16, so Tensor.numpy() raises for it; widening to
        # float32 is lossless.
        if x.dtype == torch.bfloat16:
            x = x.float()
        return x.numpy().astype(dtype)
    if isinstance(x, np.ndarray):
        return x.astype(dtype)
    return np.asarray(x, dtype=dtype)