backup wip
This commit is contained in:
parent
91d6f5a6c3
commit
6e14b85747
|
@ -16,7 +16,7 @@ def preprocess_observation(observation, transform=None):
|
|||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img).float()
|
||||
# convert to (b c h w) torch format
|
||||
img = einops.rearrange(img, "b h w c -> b c h w")
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
obs[imgkey] = img
|
||||
|
||||
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
|
||||
|
@ -33,7 +33,9 @@ def postprocess_action(action, transform=None):
|
|||
action = action.to("cpu")
|
||||
# action is a batch (num_env,action_dim) instead of an item (action_dim),
|
||||
# we assume applying inverse transform on a batch works the same
|
||||
action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
|
||||
if transform is not None:
|
||||
action = apply_inverse_transform({"action": action}, transform)["action"]
|
||||
action = action.numpy()
|
||||
assert (
|
||||
action.ndim == 2
|
||||
), "we assume dimensions are respectively the number of parallel envs, action dimensions"
|
||||
|
|
|
@ -183,7 +183,13 @@ class TDMPCPolicy(nn.Module):
|
|||
def act(self, obs, t0=False, step=None):
|
||||
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
|
||||
obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
|
||||
# TODO(now): This is for compatibility with official weights. Remove.
|
||||
# obs['rgb'] = obs['rgb'] * 255
|
||||
# obs_ = torch.load('/tmp/obs.pth')
|
||||
# out_ = torch.load('/tmp/out.pth')
|
||||
# breakpoint()
|
||||
z = self.model.encode(obs)
|
||||
# breakpoint()
|
||||
if self.cfg.mpc:
|
||||
a = self.plan(z, t0=t0, step=step)
|
||||
else:
|
||||
|
|
|
@ -125,6 +125,7 @@ def eval_policy(
|
|||
|
||||
# apply inverse transform to unnormalize the action
|
||||
action = postprocess_action(action, transform)
|
||||
action = np.array([[0, 0, 0, 0]], dtype=np.float32)
|
||||
|
||||
# apply the next
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
|
@ -249,6 +250,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
|||
logging.info("Making transforms.")
|
||||
# TODO(alexander-soare): Completely decouple datasets from evaluation.
|
||||
transform = make_dataset(cfg, stats_path=stats_path).transform
|
||||
# TODO(now)
|
||||
transform = None
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
|
||||
|
|
Loading…
Reference in New Issue