remove internal rendering hooks

This commit is contained in:
Alexander Soare 2024-03-20 09:23:23 +00:00
parent d16f6a93b3
commit b1ec3da035
5 changed files with 0 additions and 22 deletions

View File

@ -27,7 +27,6 @@ class AbstractEnv(EnvBase):
self.image_size = image_size
self.num_prev_obs = num_prev_obs
self.num_prev_action = num_prev_action
self._rendering_hooks = []
if pixels_only:
assert from_pixels
@ -45,16 +44,6 @@ class AbstractEnv(EnvBase):
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
def register_rendering_hook(self, func):
self._rendering_hooks.append(func)
def call_rendering_hooks(self):
for func in self._rendering_hooks:
func(self)
def reset_rendering_hooks(self):
self._rendering_hooks = []
@abc.abstractmethod
def render(self, mode="rgb_array", width=640, height=480):
raise NotImplementedError()

View File

@ -164,7 +164,6 @@ class AlohaEnv(AbstractEnv):
batch_size=[],
)
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
@ -189,8 +188,6 @@ class AlohaEnv(AbstractEnv):
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),

View File

@ -116,7 +116,6 @@ class PushtEnv(AbstractEnv):
batch_size=[],
)
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
@ -139,8 +138,6 @@ class PushtEnv(AbstractEnv):
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),

View File

@ -118,7 +118,6 @@ class SimxarmEnv(AbstractEnv):
else:
raise NotImplementedError()
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
@ -152,8 +151,6 @@ class SimxarmEnv(AbstractEnv):
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),

View File

@ -101,8 +101,6 @@ def eval_policy(
if return_first_video and i == 0:
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
env.reset_rendering_hooks()
for thread in threads:
thread.join()