diff --git a/model/warplayer.py b/model/warplayer.py index 21b0b90..f3fff44 100644 --- a/model/warplayer.py +++ b/model/warplayer.py @@ -1,19 +1,25 @@ import torch import torch.nn as nn -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if torch.cuda.is_available(): + device = torch.device("cuda") +elif torch.backends.mps.is_available(): + device = torch.device("mps") +else: + device = torch.device("cpu") backwarp_tenGrid = {} def warp(tenInput, tenFlow): - k = (str(tenFlow.device), str(tenFlow.size())) + flow_device = tenFlow.device + k = (str(flow_device), str(tenFlow.size())) if k not in backwarp_tenGrid: - tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( + tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=flow_device).view( 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) - tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( + tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=flow_device).view( 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) backwarp_tenGrid[k] = torch.cat( - [tenHorizontal, tenVertical], 1).to(device) + [tenHorizontal, tenVertical], 1).to(flow_device) tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)