Add possibility for the policy to provide a sequence of actions to the env
This commit is contained in:
parent
4c400b41a5
commit
fddd9f0311
|
@ -157,13 +157,20 @@ class PushtEnv(EnvBase):
|
||||||
|
|
||||||
def _step(self, tensordict: TensorDict):
|
def _step(self, tensordict: TensorDict):
|
||||||
td = tensordict
|
td = tensordict
|
||||||
# remove batch dim
|
action = td["action"].numpy()
|
||||||
action = td["action"].squeeze(0).numpy()
|
|
||||||
# step expects shape=(4,) so we pad if necessary
|
# step expects shape=(4,) so we pad if necessary
|
||||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||||
sum_reward = 0
|
sum_reward = 0
|
||||||
for _ in range(self.frame_skip):
|
|
||||||
raw_obs, reward, done, info = self._env.step(action)
|
if action.ndim == 1:
|
||||||
|
action = action.repeat(self.frame_skip, 1)
|
||||||
|
else:
|
||||||
|
if self.frame_skip > 1:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
num_action_steps = action.shape[0]
|
||||||
|
for i in range(num_action_steps):
|
||||||
|
raw_obs, reward, done, info = self._env.step(action[i])
|
||||||
sum_reward += reward
|
sum_reward += reward
|
||||||
|
|
||||||
obs = self._format_raw_obs(raw_obs)
|
obs = self._format_raw_obs(raw_obs)
|
||||||
|
|
|
@ -12,8 +12,6 @@ from diffusion_policy.model.vision.model_getter import get_resnet
|
||||||
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||||
from .multi_image_obs_encoder import MultiImageObsEncoder
|
from .multi_image_obs_encoder import MultiImageObsEncoder
|
||||||
|
|
||||||
FIRST_ACTION = 0
|
|
||||||
|
|
||||||
|
|
||||||
class DiffusionPolicy(nn.Module):
|
class DiffusionPolicy(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -110,8 +108,7 @@ class DiffusionPolicy(nn.Module):
|
||||||
}
|
}
|
||||||
out = self.diffusion.predict_action(obs_dict)
|
out = self.diffusion.predict_action(obs_dict)
|
||||||
|
|
||||||
# TODO(rcadene): add possibility to return >1 timesteps
|
action = out["action"].squeeze(0)
|
||||||
action = out["action"].squeeze(0)[FIRST_ACTION]
|
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def update(self, replay_buffer, step):
|
def update(self, replay_buffer, step):
|
||||||
|
|
Loading…
Reference in New Issue