temporary_fixes for gym-lowcostrobot

This commit is contained in:
Michel Aractingi 2024-10-23 00:24:07 +02:00
parent 04029f5e74
commit 9a5356d0ac
4 changed files with 31 additions and 29 deletions

View File

@@ -18,6 +18,11 @@ import numpy as np
import torch
from torch import Tensor
+##############################################
+### TODO this script is modified to hackathon purposes and should be reset after.
+##############################################
+PIXELS_KEY="image_front"
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """Convert environment observation to LeRobot format observation.
@@ -28,27 +33,23 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """
    # map to expected inputs for the policy
    return_observations = {}
-    if "pixels" in observations:
-        if isinstance(observations["pixels"], dict):
-            imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
-        else:
-            imgs = {"observation.image": observations["pixels"]}
+    #if PIXELS_KEY in observations:
+    #    if isinstance(observations[PIXELS_KEY], dict):
+    #        imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
+    #    else:
+    #        imgs = {"observation.image": observations["pixels"]}
+    imgs = {"observation.images.image_front": observations["image_front"]}
        for imgkey, img in imgs.items():
            img = torch.from_numpy(img)
            # sanity check that images are channel last
            _, h, w, c = img.shape
            assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
            # sanity check that images are uint8
            assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
            # convert to channel first of type float32 in range [0,1]
            img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
            img = img.type(torch.float32)
            img /= 255
            return_observations[imgkey] = img
    if "environment_state" in observations:
@@ -58,5 +59,5 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
    # requirement for "agent_pos"
-    return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
+    return_observations["observation.state"] = torch.from_numpy(observations["arm_qpos"]).float()
    return return_observations
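Taken together, the edits above hard-wire preprocess_observation to the gym-lowcostrobot observation keys: the front camera image is read from "image_front" and mapped to "observation.images.image_front", and the joint positions come from "arm_qpos" instead of "agent_pos". A less invasive alternative, sketched below under the assumption that the environment returns exactly those two keys (images as channel-last uint8 arrays), would be to remap the keys before calling the stock function; the helper name remap_lowcostrobot_obs is hypothetical and not part of this commit.

import numpy as np

# Hypothetical helper (assumption, not part of this commit): rename gym-lowcostrobot
# keys to the generic keys the unmodified preprocess_observation already understands.
def remap_lowcostrobot_obs(observations: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    remapped = dict(observations)  # shallow copy; arrays are shared
    if "image_front" in remapped:
        # a dict under "pixels" becomes "observation.images.<camera_name>" downstream
        remapped["pixels"] = {"image_front": remapped.pop("image_front")}
    if "arm_qpos" in remapped:
        remapped["agent_pos"] = remapped.pop("arm_qpos")
    return remapped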

View File

@@ -137,6 +137,8 @@ class TDMPCPolicy(
        if self._use_image:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.image"] = batch[self.input_image_key]
+            #TODO michel_aractingi temp fix to remove before merge
+            del batch[self.input_image_key]
        self._queues = populate_queues(self._queues, batch)
@@ -343,7 +345,7 @@ class TDMPCPolicy(
            batch[key] = batch[key].transpose(1, 0)
        action = batch["action"]  # (t, b, action_dim)
-        reward = batch["next.reward"]  # (t, b)
+        reward = batch["reward"]  # (t, b)
        observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
        # Apply random image augmentations.
@@ -420,7 +422,7 @@ class TDMPCPolicy(
            (
                temporal_loss_coeffs
                * F.mse_loss(reward_preds, reward, reduction="none")
-                * ~batch["next.reward_is_pad"]
+                * ~batch["reward_is_pad"]
                # `reward_preds` depends on the current observation and the actions.
                * ~batch["observation.state_is_pad"][0]
                * ~batch["action_is_pad"]
@@ -441,7 +443,7 @@ class TDMPCPolicy(
                * ~batch["observation.state_is_pad"][0]
                * ~batch["action_is_pad"]
                # q_targets depends on the reward and the next observations.
-                * ~batch["next.reward_is_pad"]
+                * ~batch["reward_is_pad"]
                * ~batch["observation.state_is_pad"][1:]
            )
            .sum(0)
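The TDMPC edits assume a hackathon dataset whose batches carry "reward" and "reward_is_pad" rather than LeRobot's usual "next.reward" and "next.reward_is_pad" keys. An alternative that leaves the policy untouched, sketched below as an assumption rather than anything in this commit, is to rename the keys once per batch in the training loop:

import torch

# Hypothetical adapter (assumption, not part of this commit): rename dataset batch keys
# so the unmodified TDMPCPolicy can keep reading "next.reward" / "next.reward_is_pad".
def rename_reward_keys(batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    renamed = dict(batch)  # shallow copy; tensors are shared, the key set is not
    for old_key, new_key in (("reward", "next.reward"), ("reward_is_pad", "next.reward_is_pad")):
        if old_key in renamed and new_key not in renamed:
            renamed[new_key] = renamed.pop(old_key)
    return renamed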

View File

@@ -591,7 +591,6 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le
    from_idx = dataset.episode_data_index["from"][episode].item()
    to_idx = dataset.episode_data_index["to"][episode].item()
    env.reset(seed=seeds[from_idx].item())
    logging.info("Replaying episode")
    say("Replaying episode", blocking=True)
    for idx in range(from_idx, to_idx):

View File

@@ -158,14 +158,14 @@ def rollout(
        action = action.to("cpu").numpy()
        assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
-        # Apply the next action.
+        # Apply the next action. TODO (michel_aractingi) temp fix
        observation, reward, terminated, truncated, info = env.step(action)
        if render_callback is not None:
            render_callback(env)
        # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
        # available of none of the envs finished.
-        if "final_info" in info:
+        if False and "final_info" in info:
            successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
        else:
            successes = [False] * env.num_envs
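The `if False and ...` guard simply skips the success lookup, presumably because the gym-lowcostrobot final info dicts do not expose "is_success"; every rollout is then reported as unsuccessful. A gentler variant, sketched below as an assumption and not something in this commit, would fall back to False per environment instead of disabling the branch outright:

from typing import Any

# Hypothetical variant (assumption, not part of this commit): tolerate envs whose final
# info dict lacks "is_success" instead of disabling the success check with `if False`.
def extract_successes(info: dict[str, Any], num_envs: int) -> list[bool]:
    if "final_info" not in info:
        return [False] * num_envs
    return [
        bool(final.get("is_success", False)) if final is not None else False
        for final in info["final_info"]
    ]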