lerobot/lerobot/common/envs/abstract.py

81 lines
2.2 KiB
Python

import abc
from collections import deque
from typing import Optional
from tensordict import TensorDict
from torchrl.envs import EnvBase
class AbstractEnv(EnvBase):
def __init__(
self,
task,
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=1,
num_prev_action=0,
):
super().__init__(device=device, batch_size=[])
self.task = task
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
self.num_prev_obs = num_prev_obs
self.num_prev_action = num_prev_action
self._rendering_hooks = []
if pixels_only:
assert from_pixels
if from_pixels:
assert image_size
self._make_env()
self._make_spec()
self._current_seed = self.set_seed(seed)
if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0:
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
def register_rendering_hook(self, func):
self._rendering_hooks.append(func)
def call_rendering_hooks(self):
for func in self._rendering_hooks:
func(self)
def reset_rendering_hooks(self):
self._rendering_hooks = []
@abc.abstractmethod
def render(self, mode="rgb_array", width=640, height=480):
raise NotImplementedError()
@abc.abstractmethod
def _reset(self, tensordict: Optional[TensorDict] = None):
raise NotImplementedError()
@abc.abstractmethod
def _step(self, tensordict: TensorDict):
raise NotImplementedError()
@abc.abstractmethod
def _make_env(self):
raise NotImplementedError()
@abc.abstractmethod
def _make_spec(self):
raise NotImplementedError()
@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError()