remove internal rendering hooks

Alexander Soare 2024-03-20 09:23:23 +00:00
parent d16f6a93b3
commit b1ec3da035
5 changed files with 0 additions and 22 deletions


@@ -27,7 +27,6 @@ class AbstractEnv(EnvBase):
         self.image_size = image_size
         self.num_prev_obs = num_prev_obs
         self.num_prev_action = num_prev_action
-        self._rendering_hooks = []

         if pixels_only:
             assert from_pixels
@@ -45,16 +44,6 @@ class AbstractEnv(EnvBase):
         raise NotImplementedError()
         # self._prev_action_queue = deque(maxlen=self.num_prev_action)

-    def register_rendering_hook(self, func):
-        self._rendering_hooks.append(func)
-
-    def call_rendering_hooks(self):
-        for func in self._rendering_hooks:
-            func(self)
-
-    def reset_rendering_hooks(self):
-        self._rendering_hooks = []
-
     @abc.abstractmethod
     def render(self, mode="rgb_array", width=640, height=480):
         raise NotImplementedError()

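For context, the deleted methods implement a small observer/callback pattern on the base environment: callers register functions that receive the env instance, and the concrete envs' _reset()/_step() invoke them on every transition (see the per-env diffs below). A minimal, self-contained sketch of that pattern follows; HookDemoEnv is a dummy stand-in, not the repository's AbstractEnv, which needs a real simulator behind render().

import numpy as np

class HookDemoEnv:
    # Dummy stand-in for AbstractEnv: only the three hook methods deleted above
    # and a fake render() are reproduced, so the usage pattern runs end to end.
    def __init__(self):
        self._rendering_hooks = []

    def register_rendering_hook(self, func):
        self._rendering_hooks.append(func)

    def call_rendering_hooks(self):
        for func in self._rendering_hooks:
            func(self)

    def reset_rendering_hooks(self):
        self._rendering_hooks = []

    def render(self, mode="rgb_array", width=640, height=480):
        return np.zeros((height, width, 3), dtype=np.uint8)  # fake frame

    def step(self):
        # The real _reset()/_step() called call_rendering_hooks() just before
        # returning their TensorDict.
        self.call_rendering_hooks()

frames = []
env = HookDemoEnv()
env.register_rendering_hook(lambda e: frames.append(e.render()))
for _ in range(3):
    env.step()
env.reset_rendering_hooks()  # eval_policy cleared hooks after each batch of rollouts
print(len(frames))  # -> 3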

@@ -164,7 +164,6 @@ class AlohaEnv(AbstractEnv):
             batch_size=[],
         )

-        self.call_rendering_hooks()
         return td

     def _step(self, tensordict: TensorDict):
@@ -189,8 +188,6 @@ class AlohaEnv(AbstractEnv):
                 stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
             obs = stacked_obs

-        self.call_rendering_hooks()
-
         td = TensorDict(
             {
                 "observation": TensorDict(obs, batch_size=[]),

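With the hooks gone, each environment's _reset()/_step() simply builds its TensorDict and returns it. A toy, self-contained sketch of that tail end of _step(); the helper name build_step_output is illustrative only, and the real methods also populate reward/done and stack previous observations.

import torch
from tensordict import TensorDict

def build_step_output(obs: dict) -> TensorDict:
    # Illustrative stand-in for the end of _step() after this commit: wrap the
    # observation and return immediately; the call_rendering_hooks() call that
    # previously sat between construction and return is removed.
    return TensorDict(
        {"observation": TensorDict(obs, batch_size=[])},
        batch_size=[],
    )

td = build_step_output({"state": torch.zeros(4)})
print(td)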

@@ -116,7 +116,6 @@ class PushtEnv(AbstractEnv):
             batch_size=[],
         )

-        self.call_rendering_hooks()
         return td

     def _step(self, tensordict: TensorDict):
@@ -139,8 +138,6 @@ class PushtEnv(AbstractEnv):
                 stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
             obs = stacked_obs

-        self.call_rendering_hooks()
-
         td = TensorDict(
             {
                 "observation": TensorDict(obs, batch_size=[]),


@@ -118,7 +118,6 @@ class SimxarmEnv(AbstractEnv):
         else:
             raise NotImplementedError()

-        self.call_rendering_hooks()
         return td

     def _step(self, tensordict: TensorDict):
@@ -152,8 +151,6 @@ class SimxarmEnv(AbstractEnv):
                 stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
             obs = stacked_obs

-        self.call_rendering_hooks()
-
         td = TensorDict(
             {
                 "observation": self._format_raw_obs(raw_obs),


@@ -101,8 +101,6 @@ def eval_policy(
         if return_first_video and i == 0:
             first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)

-        env.reset_rendering_hooks()
-
     for thread in threads:
         thread.join()
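Note that this commit only deletes the hook plumbing; it does not show what, if anything, eval_policy uses instead for video capture. One common alternative, shown purely as an illustration and not taken from this repository, is to render explicitly inside the rollout loop. The sketch below is a generic TorchRL-style rollout, not the actual eval_policy code; it assumes `env` exposes the render(mode, width, height) signature from AbstractEnv above and that `policy` is a TensorDictModule that writes an "action" entry.

import numpy as np
import torch

def rollout_with_frames(env, policy, max_steps=100):
    # Illustrative replacement for the removed rendering hooks: render a frame
    # explicitly after reset and after every step, instead of having the env
    # call back into registered hooks.
    frames = []
    td = env.reset()
    frames.append(env.render(mode="rgb_array"))
    for _ in range(max_steps):
        with torch.no_grad():
            td = policy(td)          # adds "action"
        td = env.step(td)            # adds "next" (observation, reward, done)
        frames.append(env.render(mode="rgb_array"))
        if td["next", "done"].any():
            break
        td = td["next"]              # hand-rolled step; TorchRL's step_mdp() is the usual way
    # (T, H, W, C) -> (T, C, H, W), matching the transpose used for first_video above.
    return np.stack(frames).transpose(0, 3, 1, 2)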