backup wip

This commit is contained in:
Alexander Soare 2024-04-10 11:28:44 +01:00
parent 91d6f5a6c3
commit 6e14b85747
3 changed files with 13 additions and 2 deletions

View File

@ -16,7 +16,7 @@ def preprocess_observation(observation, transform=None):
for imgkey, img in imgs.items(): for imgkey, img in imgs.items():
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
# convert to (b c h w) torch format # convert to (b c h w) torch format
img = einops.rearrange(img, "b h w c -> b c h w") img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
obs[imgkey] = img obs[imgkey] = img
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float() obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
@ -33,7 +33,9 @@ def postprocess_action(action, transform=None):
action = action.to("cpu") action = action.to("cpu")
# action is a batch (num_env,action_dim) instead of an item (action_dim), # action is a batch (num_env,action_dim) instead of an item (action_dim),
# we assume applying inverse transform on a batch works the same # we assume applying inverse transform on a batch works the same
action = apply_inverse_transform({"action": action}, transform)["action"].numpy() if transform is not None:
action = apply_inverse_transform({"action": action}, transform)["action"]
action = action.numpy()
assert ( assert (
action.ndim == 2 action.ndim == 2
), "we assume dimensions are respectively the number of parallel envs, action dimensions" ), "we assume dimensions are respectively the number of parallel envs, action dimensions"

View File

@ -183,7 +183,13 @@ class TDMPCPolicy(nn.Module):
def act(self, obs, t0=False, step=None): def act(self, obs, t0=False, step=None):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag.""" """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach() obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
# TODO(now): This is for compatibility with official weights. Remove.
# obs['rgb'] = obs['rgb'] * 255
# obs_ = torch.load('/tmp/obs.pth')
# out_ = torch.load('/tmp/out.pth')
# breakpoint()
z = self.model.encode(obs) z = self.model.encode(obs)
# breakpoint()
if self.cfg.mpc: if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step) a = self.plan(z, t0=t0, step=step)
else: else:

View File

@ -125,6 +125,7 @@ def eval_policy(
# apply inverse transform to unnormalize the action # apply inverse transform to unnormalize the action
action = postprocess_action(action, transform) action = postprocess_action(action, transform)
action = np.array([[0, 0, 0, 0]], dtype=np.float32)
# apply the next # apply the next
observation, reward, terminated, truncated, info = env.step(action) observation, reward, terminated, truncated, info = env.step(action)
@ -249,6 +250,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("Making transforms.") logging.info("Making transforms.")
# TODO(alexander-soare): Completely decouple datasets from evaluation. # TODO(alexander-soare): Completely decouple datasets from evaluation.
transform = make_dataset(cfg, stats_path=stats_path).transform transform = make_dataset(cfg, stats_path=stats_path).transform
# TODO(now)
transform = None
logging.info("Making environment.") logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)