backup wip

This commit is contained in:
Alexander Soare 2024-04-10 11:28:44 +01:00
parent 91d6f5a6c3
commit 6e14b85747
3 changed files with 13 additions and 2 deletions

View File

@ -16,7 +16,7 @@ def preprocess_observation(observation, transform=None):
for imgkey, img in imgs.items():
img = torch.from_numpy(img).float()
# convert to (b c h w) torch format
img = einops.rearrange(img, "b h w c -> b c h w")
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
obs[imgkey] = img
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
@ -33,7 +33,9 @@ def postprocess_action(action, transform=None):
action = action.to("cpu")
# action is a batch (num_env,action_dim) instead of an item (action_dim),
# we assume applying inverse transform on a batch works the same
action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
if transform is not None:
action = apply_inverse_transform({"action": action}, transform)["action"]
action = action.numpy()
assert (
action.ndim == 2
), "we assume dimensions are respectively the number of parallel envs, action dimensions"

View File

@ -183,7 +183,13 @@ class TDMPCPolicy(nn.Module):
def act(self, obs, t0=False, step=None):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
# TODO(now): This is for compatibility with official weights. Remove.
# obs['rgb'] = obs['rgb'] * 255
# obs_ = torch.load('/tmp/obs.pth')
# out_ = torch.load('/tmp/out.pth')
# breakpoint()
z = self.model.encode(obs)
# breakpoint()
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
else:

View File

@ -125,6 +125,7 @@ def eval_policy(
# apply inverse transform to unnormalize the action
action = postprocess_action(action, transform)
action = np.array([[0, 0, 0, 0]], dtype=np.float32)
# apply the next action
observation, reward, terminated, truncated, info = env.step(action)
@ -249,6 +250,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("Making transforms.")
# TODO(alexander-soare): Completely decouple datasets from evaluation.
transform = make_dataset(cfg, stats_path=stats_path).transform
# TODO(now): Remove this override; it discards the dataset transform assigned on the line above.
transform = None
logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)