remove internal rendering hooks
This commit is contained in:
parent
d16f6a93b3
commit
b1ec3da035
|
@ -27,7 +27,6 @@ class AbstractEnv(EnvBase):
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.num_prev_obs = num_prev_obs
|
self.num_prev_obs = num_prev_obs
|
||||||
self.num_prev_action = num_prev_action
|
self.num_prev_action = num_prev_action
|
||||||
self._rendering_hooks = []
|
|
||||||
|
|
||||||
if pixels_only:
|
if pixels_only:
|
||||||
assert from_pixels
|
assert from_pixels
|
||||||
|
@ -45,16 +44,6 @@ class AbstractEnv(EnvBase):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||||
|
|
||||||
def register_rendering_hook(self, func):
|
|
||||||
self._rendering_hooks.append(func)
|
|
||||||
|
|
||||||
def call_rendering_hooks(self):
|
|
||||||
for func in self._rendering_hooks:
|
|
||||||
func(self)
|
|
||||||
|
|
||||||
def reset_rendering_hooks(self):
|
|
||||||
self._rendering_hooks = []
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def render(self, mode="rgb_array", width=640, height=480):
|
def render(self, mode="rgb_array", width=640, height=480):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -164,7 +164,6 @@ class AlohaEnv(AbstractEnv):
|
||||||
batch_size=[],
|
batch_size=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.call_rendering_hooks()
|
|
||||||
return td
|
return td
|
||||||
|
|
||||||
def _step(self, tensordict: TensorDict):
|
def _step(self, tensordict: TensorDict):
|
||||||
|
@ -189,8 +188,6 @@ class AlohaEnv(AbstractEnv):
|
||||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||||
obs = stacked_obs
|
obs = stacked_obs
|
||||||
|
|
||||||
self.call_rendering_hooks()
|
|
||||||
|
|
||||||
td = TensorDict(
|
td = TensorDict(
|
||||||
{
|
{
|
||||||
"observation": TensorDict(obs, batch_size=[]),
|
"observation": TensorDict(obs, batch_size=[]),
|
||||||
|
|
|
@ -116,7 +116,6 @@ class PushtEnv(AbstractEnv):
|
||||||
batch_size=[],
|
batch_size=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.call_rendering_hooks()
|
|
||||||
return td
|
return td
|
||||||
|
|
||||||
def _step(self, tensordict: TensorDict):
|
def _step(self, tensordict: TensorDict):
|
||||||
|
@ -139,8 +138,6 @@ class PushtEnv(AbstractEnv):
|
||||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||||
obs = stacked_obs
|
obs = stacked_obs
|
||||||
|
|
||||||
self.call_rendering_hooks()
|
|
||||||
|
|
||||||
td = TensorDict(
|
td = TensorDict(
|
||||||
{
|
{
|
||||||
"observation": TensorDict(obs, batch_size=[]),
|
"observation": TensorDict(obs, batch_size=[]),
|
||||||
|
|
|
@ -118,7 +118,6 @@ class SimxarmEnv(AbstractEnv):
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
self.call_rendering_hooks()
|
|
||||||
return td
|
return td
|
||||||
|
|
||||||
def _step(self, tensordict: TensorDict):
|
def _step(self, tensordict: TensorDict):
|
||||||
|
@ -152,8 +151,6 @@ class SimxarmEnv(AbstractEnv):
|
||||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||||
obs = stacked_obs
|
obs = stacked_obs
|
||||||
|
|
||||||
self.call_rendering_hooks()
|
|
||||||
|
|
||||||
td = TensorDict(
|
td = TensorDict(
|
||||||
{
|
{
|
||||||
"observation": self._format_raw_obs(raw_obs),
|
"observation": self._format_raw_obs(raw_obs),
|
||||||
|
|
|
@ -101,8 +101,6 @@ def eval_policy(
|
||||||
if return_first_video and i == 0:
|
if return_first_video and i == 0:
|
||||||
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
|
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
|
||||||
|
|
||||||
env.reset_rendering_hooks()
|
|
||||||
|
|
||||||
for thread in threads:
|
for thread in threads:
|
||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue