temporary_fixes for gym-lowcostrobot
parent 04029f5e74 · commit 9a5356d0ac
@@ -18,6 +18,11 @@ import numpy as np
 import torch
 from torch import Tensor
 
+##############################################
+### TODO this script is modified to hackathon purposes and should be reset after.
+##############################################
+
+PIXELS_KEY="image_front"
 
 def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
     """Convert environment observation to LeRobot format observation.
@@ -28,27 +33,23 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
     """
     # map to expected inputs for the policy
    return_observations = {}
-    if "pixels" in observations:
-        if isinstance(observations["pixels"], dict):
-            imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
-        else:
-            imgs = {"observation.image": observations["pixels"]}
+    #if PIXELS_KEY in observations:
+    # if isinstance(observations[PIXELS_KEY], dict):
+    # imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
+    # else:
+    # imgs = {"observation.image": observations["pixels"]}
+    imgs = {"observation.images.image_front": observations["image_front"]}
 
     for imgkey, img in imgs.items():
         img = torch.from_numpy(img)
 
         # sanity check that images are channel last
         _, h, w, c = img.shape
         assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
 
         # sanity check that images are uint8
         assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
 
         # convert to channel first of type float32 in range [0,1]
         img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
         img = img.type(torch.float32)
         img /= 255
 
         return_observations[imgkey] = img
 
     if "environment_state" in observations:
@@ -58,5 +59,5 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
 
     # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
     # requirement for "agent_pos"
-    return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
+    return_observations["observation.state"] = torch.from_numpy(observations["arm_qpos"]).float()
     return return_observations
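The two hunks above re-route preprocess_observation from the generic "pixels"/"agent_pos" keys to the gym-lowcostrobot keys "image_front" and "arm_qpos". Below is a minimal, self-contained sketch of the mapping the patched function is expected to perform; it is not the committed code, and the image resolution and joint count in the example input are illustrative assumptions.

import einops
import numpy as np
import torch


def preprocess_observation_sketch(observations: dict) -> dict:
    out = {}
    # front camera: (batch, H, W, C) uint8 -> (batch, C, H, W) float32 in [0, 1]
    img = torch.from_numpy(observations["image_front"])
    _, h, w, c = img.shape
    assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
    assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
    img = einops.rearrange(img, "b h w c -> b c h w").contiguous().float() / 255
    out["observation.images.image_front"] = img
    # arm_qpos stands in for agent_pos as the proprioceptive state
    out["observation.state"] = torch.from_numpy(observations["arm_qpos"]).float()
    return out


# hypothetical batched observation; resolution and joint count are assumptions
obs = {
    "image_front": np.zeros((1, 240, 320, 3), dtype=np.uint8),
    "arm_qpos": np.zeros((1, 6), dtype=np.float32),
}
print({k: tuple(v.shape) for k, v in preprocess_observation_sketch(obs).items()})
# {'observation.images.image_front': (1, 3, 240, 320), 'observation.state': (1, 6)}

The front camera thus ends up under "observation.images.image_front" and the joint positions under "observation.state", which is the layout the policy hunks below rely on.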
@@ -137,6 +137,8 @@ class TDMPCPolicy(
         if self._use_image:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
             batch["observation.image"] = batch[self.input_image_key]
+            #TODO michel_aractingi temp fix to remove before merge
+            del batch[self.input_image_key]
 
         self._queues = populate_queues(self._queues, batch)
 
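The select_action hunk above copies the configured camera key onto the single "observation.image" slot and then deletes the source key, so the queue bookkeeping only ever sees the canonical name. A small sketch of that pattern; the value of input_image_key here is an assumption (mirroring the key produced by the patched preprocess_observation), not read from the repo config.

# minimal sketch of the key-remapping temp fix with placeholder values
input_image_key = "observation.images.image_front"  # assumed, not from the config

batch = {input_image_key: "IMG_TENSOR", "observation.state": "STATE_TENSOR"}

batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[input_image_key]
del batch[input_image_key]  # temp fix: drop the source key before queueing

print(sorted(batch))  # ['observation.image', 'observation.state']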
@@ -343,7 +345,7 @@ class TDMPCPolicy(
                 batch[key] = batch[key].transpose(1, 0)
 
         action = batch["action"]  # (t, b, action_dim)
-        reward = batch["next.reward"]  # (t, b)
+        reward = batch["reward"]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
 
         # Apply random image augmentations.
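This hunk only renames the reward key from "next.reward" to "reward", but it sits in the part of the update that expects time-first tensors, as the surrounding transpose and the "(t, b, ...)" comments indicate. For reference, a short sketch of that layout with made-up batch size, horizon, and action dimension:

import torch

# hypothetical batch-first data: (b, t, ...) with b=8, t=5, action_dim=6
batch = {
    "action": torch.zeros(8, 5, 6),
    "reward": torch.zeros(8, 5),
}
# transpose to time-first, matching the context line above
for key in batch:
    batch[key] = batch[key].transpose(1, 0)

action = batch["action"]  # (t, b, action_dim)
reward = batch["reward"]  # (t, b)
print(action.shape, reward.shape)  # torch.Size([5, 8, 6]) torch.Size([5, 8])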
@@ -420,7 +422,7 @@ class TDMPCPolicy(
             (
                 temporal_loss_coeffs
                 * F.mse_loss(reward_preds, reward, reduction="none")
-                * ~batch["next.reward_is_pad"]
+                * ~batch["reward_is_pad"]
                 # `reward_preds` depends on the current observation and the actions.
                 * ~batch["observation.state_is_pad"][0]
                 * ~batch["action_is_pad"]
@@ -441,7 +443,7 @@ class TDMPCPolicy(
                 * ~batch["observation.state_is_pad"][0]
                 * ~batch["action_is_pad"]
                 # q_targets depends on the reward and the next observations.
-                * ~batch["next.reward_is_pad"]
+                * ~batch["reward_is_pad"]
                 * ~batch["observation.state_is_pad"][1:]
             )
             .sum(0)
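Both loss hunks above swap "next.reward_is_pad" for "reward_is_pad"; in each case the padding mask multiplies an element-wise loss together with temporal discount coefficients before the sum over time. A stripped-down sketch of that masking pattern, where the tensor shapes and the 0.9 discount are assumptions rather than values from the repo:

import torch
import torch.nn.functional as F

t, b = 5, 8  # hypothetical horizon and batch size
reward_preds = torch.randn(t, b)
reward = torch.randn(t, b)
reward_is_pad = torch.zeros(t, b, dtype=torch.bool)  # True where the step is padding

# per-step discount weights, broadcast over the batch dimension
temporal_loss_coeffs = 0.9 ** torch.arange(t).unsqueeze(-1)

reward_loss = (
    (
        temporal_loss_coeffs
        * F.mse_loss(reward_preds, reward, reduction="none")
        * ~reward_is_pad  # padded steps contribute zero to the loss
    )
    .sum(0)
    .mean()
)
print(reward_loss)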
@@ -591,7 +591,6 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le
    from_idx = dataset.episode_data_index["from"][episode].item()
    to_idx = dataset.episode_data_index["to"][episode].item()
    env.reset(seed=seeds[from_idx].item())
 
    logging.info("Replaying episode")
    say("Replaying episode", blocking=True)
    for idx in range(from_idx, to_idx):
@@ -158,14 +158,14 @@ def rollout(
         action = action.to("cpu").numpy()
         assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
 
-        # Apply the next action.
+        # Apply the next action. TODO (michel_aractingi) temp fix
         observation, reward, terminated, truncated, info = env.step(action)
         if render_callback is not None:
             render_callback(env)
 
         # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
         # available of none of the envs finished.
-        if "final_info" in info:
+        if False and "final_info" in info:
             successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
         else:
             successes = [False] * env.num_envs
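The rollout hunk disables success extraction by short-circuiting the "final_info" check with `if False and ...`, so every environment in the batch is reported as unsuccessful. A tiny sketch of the effect, using a hypothetical info dict in the gymnasium VectorEnv convention:

# hypothetical step info for a 2-env vectorized environment
info = {"final_info": [{"is_success": True}, None]}
num_envs = 2

if False and "final_info" in info:  # temp fix: guard is short-circuited, never taken
    successes = [i["is_success"] if i is not None else False for i in info["final_info"]]
else:
    successes = [False] * num_envs

print(successes)  # [False, False]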