temporary_fixes for gym-lowcostrobot
parent 04029f5e74
commit 9a5356d0ac
@@ -18,6 +18,11 @@ import numpy as np
 import torch
 from torch import Tensor
 
+##############################################
+### TODO this script is modified for hackathon purposes and should be reset after.
+##############################################
+
+PIXELS_KEY = "image_front"
 
 def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
     """Convert environment observation to LeRobot format observation.
@@ -28,27 +33,23 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
     """
     # map to expected inputs for the policy
     return_observations = {}
-    if "pixels" in observations:
-        if isinstance(observations["pixels"], dict):
-            imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
-        else:
-            imgs = {"observation.image": observations["pixels"]}
 
+    # if PIXELS_KEY in observations:
+    #     if isinstance(observations[PIXELS_KEY], dict):
+    #         imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
+    #     else:
+    #         imgs = {"observation.image": observations["pixels"]}
+    imgs = {"observation.images.image_front": observations["image_front"]}
     for imgkey, img in imgs.items():
         img = torch.from_numpy(img)
 
         # sanity check that images are channel last
         _, h, w, c = img.shape
         assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
 
         # sanity check that images are uint8
         assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
 
         # convert to channel first of type float32 in range [0,1]
         img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
         img = img.type(torch.float32)
         img /= 255
 
         return_observations[imgkey] = img
 
     if "environment_state" in observations:
@@ -58,5 +59,5 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
 
     # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
     # requirement for "agent_pos"
-    return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
+    return_observations["observation.state"] = torch.from_numpy(observations["arm_qpos"]).float()
     return return_observations
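Taken together, these three hunks hard-wire the function to gym-lowcostrobot's observation layout: the front-camera image is read from `image_front` and the joint positions from `arm_qpos` instead of the generic `pixels` / `agent_pos` keys. A minimal standalone sketch of what the patched function now produces; the image resolution and joint count below are assumptions for illustration, not values taken from this diff:

import einops
import numpy as np
import torch

# Hypothetical batched observation, shaped the way a vectorized gym-lowcostrobot env
# with a front camera might return it (keys from this diff, shapes assumed).
obs = {
    "image_front": np.zeros((1, 240, 320, 3), dtype=np.uint8),  # (b, h, w, c), uint8
    "arm_qpos": np.zeros((1, 6), dtype=np.float64),             # (b, num_joints)
}

# The core of the mapping the patched preprocess_observation now hard-codes:
# channel-last uint8 image -> channel-first float32 in [0, 1], qpos -> float32 state.
img = torch.from_numpy(obs["image_front"])
img = einops.rearrange(img, "b h w c -> b c h w").contiguous().float() / 255
state = torch.from_numpy(obs["arm_qpos"]).float()

batch = {"observation.images.image_front": img, "observation.state": state}
print(batch["observation.images.image_front"].shape)  # torch.Size([1, 3, 240, 320])
print(batch["observation.state"].dtype)               # torch.float32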
@@ -137,6 +137,8 @@ class TDMPCPolicy(
         if self._use_image:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
             batch["observation.image"] = batch[self.input_image_key]
+            # TODO michel_aractingi temp fix to remove before merge
+            del batch[self.input_image_key]
 
         self._queues = populate_queues(self._queues, batch)
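The two added lines alias the environment-specific image key onto the single `observation.image` key this TDMPC implementation expects, then drop the original so the policy's queues never see the env-specific key twice. A standalone sketch of the same pattern; the key names come from this diff, the tensors are placeholders:

import torch

batch = {
    "observation.images.image_front": torch.zeros(1, 3, 240, 320),
    "observation.state": torch.zeros(1, 6),
}
input_image_key = "observation.images.image_front"

batch = dict(batch)  # shallow copy so the caller's dict is not modified
batch["observation.image"] = batch[input_image_key]
del batch[input_image_key]  # temp fix: keep only the key the policy was written against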
@@ -343,7 +345,7 @@ class TDMPCPolicy(
                 batch[key] = batch[key].transpose(1, 0)
 
         action = batch["action"]  # (t, b, action_dim)
-        reward = batch["next.reward"]  # (t, b)
+        reward = batch["reward"]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
 
         # Apply random image augmentations.
@@ -420,7 +422,7 @@ class TDMPCPolicy(
             (
                 temporal_loss_coeffs
                 * F.mse_loss(reward_preds, reward, reduction="none")
-                * ~batch["next.reward_is_pad"]
+                * ~batch["reward_is_pad"]
                 # `reward_preds` depends on the current observation and the actions.
                 * ~batch["observation.state_is_pad"][0]
                 * ~batch["action_is_pad"]
@@ -441,7 +443,7 @@ class TDMPCPolicy(
                 * ~batch["observation.state_is_pad"][0]
                 * ~batch["action_is_pad"]
                 # q_targets depends on the reward and the next observations.
-                * ~batch["next.reward_is_pad"]
+                * ~batch["reward_is_pad"]
                 * ~batch["observation.state_is_pad"][1:]
             )
             .sum(0)
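The three TDMPC hunks above make one consistent substitution: rewards and their padding masks are read from `reward` / `reward_is_pad` rather than `next.reward` / `next.reward_is_pad`, presumably to match how the hackathon dataset names them. An alternative to patching the policy would be renaming the keys in the batch before calling it; a minimal sketch, where the helper itself is hypothetical and not part of this diff:

def rename_reward_keys(batch: dict) -> dict:
    """Map the `next.`-prefixed reward keys onto the un-prefixed names this patch expects."""
    batch = dict(batch)  # shallow copy; leave the caller's dict untouched
    for old, new in (("next.reward", "reward"), ("next.reward_is_pad", "reward_is_pad")):
        if old in batch and new not in batch:
            batch[new] = batch.pop(old)
    return batch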
@@ -591,7 +591,6 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le
     from_idx = dataset.episode_data_index["from"][episode].item()
     to_idx = dataset.episode_data_index["to"][episode].item()
-    env.reset(seed=seeds[from_idx].item())
 
     logging.info("Replaying episode")
     say("Replaying episode", blocking=True)
     for idx in range(from_idx, to_idx):
@@ -158,14 +158,14 @@ def rollout(
         action = action.to("cpu").numpy()
         assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
 
-        # Apply the next action.
+        # Apply the next action. TODO (michel_aractingi) temp fix
         observation, reward, terminated, truncated, info = env.step(action)
         if render_callback is not None:
             render_callback(env)
 
         # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
         # available if none of the envs finished.
-        if "final_info" in info:
+        if False and "final_info" in info:
             successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
         else:
             successes = [False] * env.num_envs
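With the guard changed to `if False and ...`, the success branch can never execute, so every rollout reports `successes = [False] * env.num_envs` and the evaluated success rate stays pinned at 0, presumably because gym-lowcostrobot does not write an `is_success` flag into its info dict. For reference, the disabled branch reads per-env success flags out of Gymnasium's vector-env `final_info` like this; a standalone sketch with made-up data:

# As a 2-env VectorEnv might report it after only the first env terminated.
info = {"final_info": [{"is_success": True}, None]}
successes = [fi["is_success"] if fi is not None else False for fi in info["final_info"]]
print(successes)  # [True, False]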