lerobot/lerobot/common/policies/diffusion/model/module_attr_mixin.py

16 lines
331 B
Python

import torch.nn as nn
class ModuleAttrMixin(nn.Module):
    """``nn.Module`` base class that exposes ``device`` and ``dtype`` properties.

    Both properties report the attributes of the first tensor yielded by
    ``self.parameters()``. To guarantee that at least one parameter always
    exists, ``__init__`` registers an empty placeholder parameter.
    """

    def __init__(self):
        """Initialize the module and register the placeholder parameter."""
        super().__init__()
        # Zero-element parameter whose only job is to make ``parameters()``
        # non-empty. Because it is assigned as a module attribute it is a
        # registered parameter: it follows ``.to()`` / ``.double()`` etc. and
        # shows up in ``parameters()`` and ``state_dict()`` under the key
        # ``_dummy_variable``.
        self._dummy_variable = nn.Parameter()

    @property
    def device(self):
        """Return the ``torch.device`` of the module's first parameter."""
        # ``parameters()`` already returns an iterator, so ``next`` can be
        # applied to it directly.
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Return the ``torch.dtype`` of the module's first parameter."""
        return next(self.parameters()).dtype