Integrate pusht env from diffusion
This commit is contained in:
parent
302b78962c
commit
6c867d78ef
|
@ -5,7 +5,6 @@ import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pygame
|
import pygame
|
||||||
import pymunk
|
import pymunk
|
||||||
import shapely.geometry as sg
|
|
||||||
import torch
|
import torch
|
||||||
import torchrl
|
import torchrl
|
||||||
import tqdm
|
import tqdm
|
||||||
|
@ -16,29 +15,16 @@ from torchrl.data.replay_buffers.writers import Writer
|
||||||
|
|
||||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||||
|
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
|
||||||
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||||
|
|
||||||
# as define in env
|
# as define in env
|
||||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||||
|
|
||||||
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
|
|
||||||
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||||
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
|
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||||
|
|
||||||
|
|
||||||
def pymunk_to_shapely(body, shapes):
|
|
||||||
geoms = []
|
|
||||||
for shape in shapes:
|
|
||||||
if isinstance(shape, pymunk.shapes.Poly):
|
|
||||||
verts = [body.local_to_world(v) for v in shape.get_vertices()]
|
|
||||||
verts += [verts[0]]
|
|
||||||
geoms.append(sg.Polygon(verts))
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unsupported shape type {type(shape)}")
|
|
||||||
geom = sg.MultiPolygon(geoms)
|
|
||||||
return geom
|
|
||||||
|
|
||||||
|
|
||||||
def get_goal_pose_body(pose):
|
def get_goal_pose_body(pose):
|
||||||
mass = 1
|
mass = 1
|
||||||
inertia = pymunk.moment_for_box(mass, (50, 100))
|
inertia = pymunk.moment_for_box(mass, (50, 100))
|
||||||
|
@ -62,8 +48,10 @@ def add_tee(
|
||||||
angle,
|
angle,
|
||||||
scale=30,
|
scale=30,
|
||||||
color="LightSlateGray",
|
color="LightSlateGray",
|
||||||
mask=DEFAULT_TEE_MASK,
|
mask=None,
|
||||||
):
|
):
|
||||||
|
if mask is None:
|
||||||
|
mask = pymunk.ShapeFilter.ALL_MASKS()
|
||||||
mass = 1
|
mass = 1
|
||||||
length = 4
|
length = 4
|
||||||
vertices1 = [
|
vertices1 = [
|
||||||
|
|
|
@ -18,7 +18,7 @@ def make_env(cfg, transform=None):
|
||||||
kwargs["task"] = cfg.env.task
|
kwargs["task"] = cfg.env.task
|
||||||
clsfunc = SimxarmEnv
|
clsfunc = SimxarmEnv
|
||||||
elif cfg.env.name == "pusht":
|
elif cfg.env.name == "pusht":
|
||||||
from lerobot.common.envs.pusht import PushtEnv
|
from lerobot.common.envs.pusht.pusht import PushtEnv
|
||||||
|
|
||||||
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
|
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,6 @@ from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||||
from lerobot.common.utils import set_seed
|
from lerobot.common.utils import set_seed
|
||||||
|
|
||||||
_has_gym = importlib.util.find_spec("gym") is not None
|
_has_gym = importlib.util.find_spec("gym") is not None
|
||||||
_has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _has_gym
|
|
||||||
|
|
||||||
|
|
||||||
class PushtEnv(EnvBase):
|
class PushtEnv(EnvBase):
|
||||||
|
@ -45,17 +44,15 @@ class PushtEnv(EnvBase):
|
||||||
if from_pixels:
|
if from_pixels:
|
||||||
assert image_size
|
assert image_size
|
||||||
|
|
||||||
if not _has_diffpolicy:
|
|
||||||
raise ImportError("Cannot import diffusion_policy.")
|
|
||||||
if not _has_gym:
|
if not _has_gym:
|
||||||
raise ImportError("Cannot import gym.")
|
raise ImportError("Cannot import gym.")
|
||||||
|
|
||||||
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
|
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
|
||||||
# from diffusion_policy.env.pusht.pusht_env import PushTEnv
|
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
|
||||||
|
|
||||||
if not from_pixels:
|
if not from_pixels:
|
||||||
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
||||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv
|
||||||
|
|
||||||
self._env = PushTImageEnv(render_size=self.image_size)
|
self._env = PushTImageEnv(render_size=self.image_size)
|
||||||
|
|
|
@ -0,0 +1,378 @@
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import pygame
|
||||||
|
import pymunk
|
||||||
|
import pymunk.pygame_util
|
||||||
|
import shapely.geometry as sg
|
||||||
|
import skimage.transform as st
|
||||||
|
from gym import spaces
|
||||||
|
from pymunk.vec2d import Vec2d
|
||||||
|
|
||||||
|
from lerobot.common.envs.pusht.pymunk_override import DrawOptions
|
||||||
|
|
||||||
|
|
||||||
|
def pymunk_to_shapely(body, shapes):
|
||||||
|
geoms = []
|
||||||
|
for shape in shapes:
|
||||||
|
if isinstance(shape, pymunk.shapes.Poly):
|
||||||
|
verts = [body.local_to_world(v) for v in shape.get_vertices()]
|
||||||
|
verts += [verts[0]]
|
||||||
|
geoms.append(sg.Polygon(verts))
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported shape type {type(shape)}")
|
||||||
|
geom = sg.MultiPolygon(geoms)
|
||||||
|
return geom
|
||||||
|
|
||||||
|
|
||||||
|
class PushTEnv(gym.Env):
|
||||||
|
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
|
||||||
|
reward_range = (0.0, 1.0)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
legacy=False,
|
||||||
|
block_cog=None,
|
||||||
|
damping=None,
|
||||||
|
render_action=True,
|
||||||
|
render_size=96,
|
||||||
|
reset_to_state=None,
|
||||||
|
):
|
||||||
|
self._seed = None
|
||||||
|
self.seed()
|
||||||
|
self.window_size = ws = 512 # The size of the PyGame window
|
||||||
|
self.render_size = render_size
|
||||||
|
self.sim_hz = 100
|
||||||
|
# Local controller params.
|
||||||
|
self.k_p, self.k_v = 100, 20 # PD control.z
|
||||||
|
self.control_hz = self.metadata["video.frames_per_second"]
|
||||||
|
# legcay set_state for data compatibility
|
||||||
|
self.legacy = legacy
|
||||||
|
|
||||||
|
# agent_pos, block_pos, block_angle
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=np.array([0, 0, 0, 0, 0], dtype=np.float64),
|
||||||
|
high=np.array([ws, ws, ws, ws, np.pi * 2], dtype=np.float64),
|
||||||
|
shape=(5,),
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
|
||||||
|
# positional goal for agent
|
||||||
|
self.action_space = spaces.Box(
|
||||||
|
low=np.array([0, 0], dtype=np.float64),
|
||||||
|
high=np.array([ws, ws], dtype=np.float64),
|
||||||
|
shape=(2,),
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.block_cog = block_cog
|
||||||
|
self.damping = damping
|
||||||
|
self.render_action = render_action
|
||||||
|
|
||||||
|
"""
|
||||||
|
If human-rendering is used, `self.window` will be a reference
|
||||||
|
to the window that we draw to. `self.clock` will be a clock that is used
|
||||||
|
to ensure that the environment is rendered at the correct framerate in
|
||||||
|
human-mode. They will remain `None` until human-mode is used for the
|
||||||
|
first time.
|
||||||
|
"""
|
||||||
|
self.window = None
|
||||||
|
self.clock = None
|
||||||
|
self.screen = None
|
||||||
|
|
||||||
|
self.space = None
|
||||||
|
self.teleop = None
|
||||||
|
self.render_buffer = None
|
||||||
|
self.latest_action = None
|
||||||
|
self.reset_to_state = reset_to_state
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
seed = self._seed
|
||||||
|
self._setup()
|
||||||
|
if self.block_cog is not None:
|
||||||
|
self.block.center_of_gravity = self.block_cog
|
||||||
|
if self.damping is not None:
|
||||||
|
self.space.damping = self.damping
|
||||||
|
|
||||||
|
# use legacy RandomState for compatibility
|
||||||
|
state = self.reset_to_state
|
||||||
|
if state is None:
|
||||||
|
rs = np.random.RandomState(seed=seed)
|
||||||
|
state = np.array(
|
||||||
|
[
|
||||||
|
rs.randint(50, 450),
|
||||||
|
rs.randint(50, 450),
|
||||||
|
rs.randint(100, 400),
|
||||||
|
rs.randint(100, 400),
|
||||||
|
rs.randn() * 2 * np.pi - np.pi,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self._set_state(state)
|
||||||
|
|
||||||
|
observation = self._get_obs()
|
||||||
|
return observation
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
dt = 1.0 / self.sim_hz
|
||||||
|
self.n_contact_points = 0
|
||||||
|
n_steps = self.sim_hz // self.control_hz
|
||||||
|
if action is not None:
|
||||||
|
self.latest_action = action
|
||||||
|
for _ in range(n_steps):
|
||||||
|
# Step PD control.
|
||||||
|
# self.agent.velocity = self.k_p * (act - self.agent.position) # P control works too.
|
||||||
|
acceleration = self.k_p * (action - self.agent.position) + self.k_v * (
|
||||||
|
Vec2d(0, 0) - self.agent.velocity
|
||||||
|
)
|
||||||
|
self.agent.velocity += acceleration * dt
|
||||||
|
|
||||||
|
# Step physics.
|
||||||
|
self.space.step(dt)
|
||||||
|
|
||||||
|
# compute reward
|
||||||
|
goal_body = self._get_goal_pose_body(self.goal_pose)
|
||||||
|
goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
|
||||||
|
block_geom = pymunk_to_shapely(self.block, self.block.shapes)
|
||||||
|
|
||||||
|
intersection_area = goal_geom.intersection(block_geom).area
|
||||||
|
goal_area = goal_geom.area
|
||||||
|
coverage = intersection_area / goal_area
|
||||||
|
reward = np.clip(coverage / self.success_threshold, 0, 1)
|
||||||
|
done = coverage > self.success_threshold
|
||||||
|
|
||||||
|
observation = self._get_obs()
|
||||||
|
info = self._get_info()
|
||||||
|
|
||||||
|
return observation, reward, done, info
|
||||||
|
|
||||||
|
def render(self, mode):
|
||||||
|
return self._render_frame(mode)
|
||||||
|
|
||||||
|
def teleop_agent(self):
|
||||||
|
TeleopAgent = collections.namedtuple("TeleopAgent", ["act"])
|
||||||
|
|
||||||
|
def act(obs):
|
||||||
|
act = None
|
||||||
|
mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
|
||||||
|
if self.teleop or (mouse_position - self.agent.position).length < 30:
|
||||||
|
self.teleop = True
|
||||||
|
act = mouse_position
|
||||||
|
return act
|
||||||
|
|
||||||
|
return TeleopAgent(act)
|
||||||
|
|
||||||
|
def _get_obs(self):
|
||||||
|
obs = np.array(
|
||||||
|
tuple(self.agent.position) + tuple(self.block.position) + (self.block.angle % (2 * np.pi),)
|
||||||
|
)
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def _get_goal_pose_body(self, pose):
|
||||||
|
mass = 1
|
||||||
|
inertia = pymunk.moment_for_box(mass, (50, 100))
|
||||||
|
body = pymunk.Body(mass, inertia)
|
||||||
|
# preserving the legacy assignment order for compatibility
|
||||||
|
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
|
||||||
|
body.position = pose[:2].tolist()
|
||||||
|
body.angle = pose[2]
|
||||||
|
return body
|
||||||
|
|
||||||
|
def _get_info(self):
|
||||||
|
n_steps = self.sim_hz // self.control_hz
|
||||||
|
n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
|
||||||
|
info = {
|
||||||
|
"pos_agent": np.array(self.agent.position),
|
||||||
|
"vel_agent": np.array(self.agent.velocity),
|
||||||
|
"block_pose": np.array(list(self.block.position) + [self.block.angle]),
|
||||||
|
"goal_pose": self.goal_pose,
|
||||||
|
"n_contacts": n_contact_points_per_step,
|
||||||
|
}
|
||||||
|
return info
|
||||||
|
|
||||||
|
def _render_frame(self, mode):
|
||||||
|
if self.window is None and mode == "human":
|
||||||
|
pygame.init()
|
||||||
|
pygame.display.init()
|
||||||
|
self.window = pygame.display.set_mode((self.window_size, self.window_size))
|
||||||
|
if self.clock is None and mode == "human":
|
||||||
|
self.clock = pygame.time.Clock()
|
||||||
|
|
||||||
|
canvas = pygame.Surface((self.window_size, self.window_size))
|
||||||
|
canvas.fill((255, 255, 255))
|
||||||
|
self.screen = canvas
|
||||||
|
|
||||||
|
draw_options = DrawOptions(canvas)
|
||||||
|
|
||||||
|
# Draw goal pose.
|
||||||
|
goal_body = self._get_goal_pose_body(self.goal_pose)
|
||||||
|
for shape in self.block.shapes:
|
||||||
|
goal_points = [
|
||||||
|
pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface)
|
||||||
|
for v in shape.get_vertices()
|
||||||
|
]
|
||||||
|
goal_points += [goal_points[0]]
|
||||||
|
pygame.draw.polygon(canvas, self.goal_color, goal_points)
|
||||||
|
|
||||||
|
# Draw agent and block.
|
||||||
|
self.space.debug_draw(draw_options)
|
||||||
|
|
||||||
|
if mode == "human":
|
||||||
|
# The following line copies our drawings from `canvas` to the visible window
|
||||||
|
self.window.blit(canvas, canvas.get_rect())
|
||||||
|
pygame.event.pump()
|
||||||
|
pygame.display.update()
|
||||||
|
|
||||||
|
# the clock is already ticked during in step for "human"
|
||||||
|
|
||||||
|
img = np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))
|
||||||
|
img = cv2.resize(img, (self.render_size, self.render_size))
|
||||||
|
if self.render_action and self.latest_action is not None:
|
||||||
|
action = np.array(self.latest_action)
|
||||||
|
coord = (action / 512 * 96).astype(np.int32)
|
||||||
|
marker_size = int(8 / 96 * self.render_size)
|
||||||
|
thickness = int(1 / 96 * self.render_size)
|
||||||
|
cv2.drawMarker(
|
||||||
|
img,
|
||||||
|
coord,
|
||||||
|
color=(255, 0, 0),
|
||||||
|
markerType=cv2.MARKER_CROSS,
|
||||||
|
markerSize=marker_size,
|
||||||
|
thickness=thickness,
|
||||||
|
)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.window is not None:
|
||||||
|
pygame.display.quit()
|
||||||
|
pygame.quit()
|
||||||
|
|
||||||
|
def seed(self, seed=None):
|
||||||
|
if seed is None:
|
||||||
|
seed = np.random.randint(0, 25536)
|
||||||
|
self._seed = seed
|
||||||
|
self.np_random = np.random.default_rng(seed)
|
||||||
|
|
||||||
|
def _handle_collision(self, arbiter, space, data):
|
||||||
|
self.n_contact_points += len(arbiter.contact_point_set.points)
|
||||||
|
|
||||||
|
def _set_state(self, state):
|
||||||
|
if isinstance(state, np.ndarray):
|
||||||
|
state = state.tolist()
|
||||||
|
pos_agent = state[:2]
|
||||||
|
pos_block = state[2:4]
|
||||||
|
rot_block = state[4]
|
||||||
|
self.agent.position = pos_agent
|
||||||
|
# setting angle rotates with respect to center of mass
|
||||||
|
# therefore will modify the geometric position
|
||||||
|
# if not the same as CoM
|
||||||
|
# therefore should be modified first.
|
||||||
|
if self.legacy:
|
||||||
|
# for compatibility with legacy data
|
||||||
|
self.block.position = pos_block
|
||||||
|
self.block.angle = rot_block
|
||||||
|
else:
|
||||||
|
self.block.angle = rot_block
|
||||||
|
self.block.position = pos_block
|
||||||
|
|
||||||
|
# Run physics to take effect
|
||||||
|
self.space.step(1.0 / self.sim_hz)
|
||||||
|
|
||||||
|
def _set_state_local(self, state_local):
|
||||||
|
agent_pos_local = state_local[:2]
|
||||||
|
block_pose_local = state_local[2:]
|
||||||
|
tf_img_obj = st.AffineTransform(translation=self.goal_pose[:2], rotation=self.goal_pose[2])
|
||||||
|
tf_obj_new = st.AffineTransform(translation=block_pose_local[:2], rotation=block_pose_local[2])
|
||||||
|
tf_img_new = st.AffineTransform(matrix=tf_img_obj.params @ tf_obj_new.params)
|
||||||
|
agent_pos_new = tf_img_new(agent_pos_local)
|
||||||
|
new_state = np.array(list(agent_pos_new[0]) + list(tf_img_new.translation) + [tf_img_new.rotation])
|
||||||
|
self._set_state(new_state)
|
||||||
|
return new_state
|
||||||
|
|
||||||
|
def _setup(self):
|
||||||
|
self.space = pymunk.Space()
|
||||||
|
self.space.gravity = 0, 0
|
||||||
|
self.space.damping = 0
|
||||||
|
self.teleop = False
|
||||||
|
self.render_buffer = []
|
||||||
|
|
||||||
|
# Add walls.
|
||||||
|
walls = [
|
||||||
|
self._add_segment((5, 506), (5, 5), 2),
|
||||||
|
self._add_segment((5, 5), (506, 5), 2),
|
||||||
|
self._add_segment((506, 5), (506, 506), 2),
|
||||||
|
self._add_segment((5, 506), (506, 506), 2),
|
||||||
|
]
|
||||||
|
self.space.add(*walls)
|
||||||
|
|
||||||
|
# Add agent, block, and goal zone.
|
||||||
|
self.agent = self.add_circle((256, 400), 15)
|
||||||
|
self.block = self.add_tee((256, 300), 0)
|
||||||
|
self.goal_color = pygame.Color("LightGreen")
|
||||||
|
self.goal_pose = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||||
|
|
||||||
|
# Add collision handling
|
||||||
|
self.collision_handeler = self.space.add_collision_handler(0, 0)
|
||||||
|
self.collision_handeler.post_solve = self._handle_collision
|
||||||
|
self.n_contact_points = 0
|
||||||
|
|
||||||
|
self.max_score = 50 * 100
|
||||||
|
self.success_threshold = 0.95 # 95% coverage.
|
||||||
|
|
||||||
|
def _add_segment(self, a, b, radius):
|
||||||
|
shape = pymunk.Segment(self.space.static_body, a, b, radius)
|
||||||
|
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
|
||||||
|
return shape
|
||||||
|
|
||||||
|
def add_circle(self, position, radius):
|
||||||
|
body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
|
||||||
|
body.position = position
|
||||||
|
body.friction = 1
|
||||||
|
shape = pymunk.Circle(body, radius)
|
||||||
|
shape.color = pygame.Color("RoyalBlue")
|
||||||
|
self.space.add(body, shape)
|
||||||
|
return body
|
||||||
|
|
||||||
|
def add_box(self, position, height, width):
|
||||||
|
mass = 1
|
||||||
|
inertia = pymunk.moment_for_box(mass, (height, width))
|
||||||
|
body = pymunk.Body(mass, inertia)
|
||||||
|
body.position = position
|
||||||
|
shape = pymunk.Poly.create_box(body, (height, width))
|
||||||
|
shape.color = pygame.Color("LightSlateGray")
|
||||||
|
self.space.add(body, shape)
|
||||||
|
return body
|
||||||
|
|
||||||
|
def add_tee(self, position, angle, scale=30, color="LightSlateGray", mask=None):
|
||||||
|
if mask is None:
|
||||||
|
mask = pymunk.ShapeFilter.ALL_MASKS()
|
||||||
|
mass = 1
|
||||||
|
length = 4
|
||||||
|
vertices1 = [
|
||||||
|
(-length * scale / 2, scale),
|
||||||
|
(length * scale / 2, scale),
|
||||||
|
(length * scale / 2, 0),
|
||||||
|
(-length * scale / 2, 0),
|
||||||
|
]
|
||||||
|
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||||
|
vertices2 = [
|
||||||
|
(-scale / 2, scale),
|
||||||
|
(-scale / 2, length * scale),
|
||||||
|
(scale / 2, length * scale),
|
||||||
|
(scale / 2, scale),
|
||||||
|
]
|
||||||
|
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||||
|
body = pymunk.Body(mass, inertia1 + inertia2)
|
||||||
|
shape1 = pymunk.Poly(body, vertices1)
|
||||||
|
shape2 = pymunk.Poly(body, vertices2)
|
||||||
|
shape1.color = pygame.Color(color)
|
||||||
|
shape2.color = pygame.Color(color)
|
||||||
|
shape1.filter = pymunk.ShapeFilter(mask=mask)
|
||||||
|
shape2.filter = pymunk.ShapeFilter(mask=mask)
|
||||||
|
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
|
||||||
|
body.position = position
|
||||||
|
body.angle = angle
|
||||||
|
body.friction = 1
|
||||||
|
self.space.add(body, shape1, shape2)
|
||||||
|
return body
|
|
@ -0,0 +1,55 @@
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from gym import spaces
|
||||||
|
|
||||||
|
from lerobot.common.envs.pusht.pusht_env import PushTEnv
|
||||||
|
|
||||||
|
|
||||||
|
class PushTImageEnv(PushTEnv):
|
||||||
|
metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}
|
||||||
|
|
||||||
|
def __init__(self, legacy=False, block_cog=None, damping=None, render_size=96):
|
||||||
|
super().__init__(
|
||||||
|
legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False
|
||||||
|
)
|
||||||
|
ws = self.window_size
|
||||||
|
self.observation_space = spaces.Dict(
|
||||||
|
{
|
||||||
|
"image": spaces.Box(low=0, high=1, shape=(3, render_size, render_size), dtype=np.float32),
|
||||||
|
"agent_pos": spaces.Box(low=0, high=ws, shape=(2,), dtype=np.float32),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.render_cache = None
|
||||||
|
|
||||||
|
def _get_obs(self):
|
||||||
|
img = super()._render_frame(mode="rgb_array")
|
||||||
|
|
||||||
|
agent_pos = np.array(self.agent.position)
|
||||||
|
img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
|
||||||
|
obs = {"image": img_obs, "agent_pos": agent_pos}
|
||||||
|
|
||||||
|
# draw action
|
||||||
|
if self.latest_action is not None:
|
||||||
|
action = np.array(self.latest_action)
|
||||||
|
coord = (action / 512 * 96).astype(np.int32)
|
||||||
|
marker_size = int(8 / 96 * self.render_size)
|
||||||
|
thickness = int(1 / 96 * self.render_size)
|
||||||
|
cv2.drawMarker(
|
||||||
|
img,
|
||||||
|
coord,
|
||||||
|
color=(255, 0, 0),
|
||||||
|
markerType=cv2.MARKER_CROSS,
|
||||||
|
markerSize=marker_size,
|
||||||
|
thickness=thickness,
|
||||||
|
)
|
||||||
|
self.render_cache = img
|
||||||
|
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def render(self, mode):
|
||||||
|
assert mode == "rgb_array"
|
||||||
|
|
||||||
|
if self.render_cache is None:
|
||||||
|
self._get_obs()
|
||||||
|
|
||||||
|
return self.render_cache
|
|
@ -0,0 +1,244 @@
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# pymunk
|
||||||
|
# Copyright (c) 2007-2016 Victor Blomqvist
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
"""This submodule contains helper functions to help with quick prototyping
|
||||||
|
using pymunk together with pygame.
|
||||||
|
|
||||||
|
Intended to help with debugging and prototyping, not for actual production use
|
||||||
|
in a full application. The methods contained in this module is opinionated
|
||||||
|
about your coordinate system and not in any way optimized.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__docformat__ = "reStructuredText"
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DrawOptions",
|
||||||
|
"get_mouse_pos",
|
||||||
|
"to_pygame",
|
||||||
|
"from_pygame",
|
||||||
|
# "lighten",
|
||||||
|
"positive_y_is_up",
|
||||||
|
]
|
||||||
|
|
||||||
|
from typing import Sequence, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pygame
|
||||||
|
import pymunk
|
||||||
|
from pymunk.space_debug_draw_options import SpaceDebugColor
|
||||||
|
from pymunk.vec2d import Vec2d
|
||||||
|
|
||||||
|
positive_y_is_up: bool = False
|
||||||
|
"""Make increasing values of y point upwards.
|
||||||
|
|
||||||
|
When True::
|
||||||
|
|
||||||
|
y
|
||||||
|
^
|
||||||
|
| . (3, 3)
|
||||||
|
|
|
||||||
|
| . (2, 2)
|
||||||
|
|
|
||||||
|
+------ > x
|
||||||
|
|
||||||
|
When False::
|
||||||
|
|
||||||
|
+------ > x
|
||||||
|
|
|
||||||
|
| . (2, 2)
|
||||||
|
|
|
||||||
|
| . (3, 3)
|
||||||
|
v
|
||||||
|
y
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class DrawOptions(pymunk.SpaceDebugDrawOptions):
|
||||||
|
def __init__(self, surface: pygame.Surface) -> None:
|
||||||
|
"""Draw a pymunk.Space on a pygame.Surface object.
|
||||||
|
|
||||||
|
Typical usage::
|
||||||
|
|
||||||
|
>>> import pymunk
|
||||||
|
>>> surface = pygame.Surface((10,10))
|
||||||
|
>>> space = pymunk.Space()
|
||||||
|
>>> options = pymunk.pygame_util.DrawOptions(surface)
|
||||||
|
>>> space.debug_draw(options)
|
||||||
|
|
||||||
|
You can control the color of a shape by setting shape.color to the color
|
||||||
|
you want it drawn in::
|
||||||
|
|
||||||
|
>>> c = pymunk.Circle(None, 10)
|
||||||
|
>>> c.color = pygame.Color("pink")
|
||||||
|
|
||||||
|
See pygame_util.demo.py for a full example
|
||||||
|
|
||||||
|
Since pygame uses a coordinate system where y points down (in contrast
|
||||||
|
to many other cases), you either have to make the physics simulation
|
||||||
|
with Pymunk also behave in that way, or flip everything when you draw.
|
||||||
|
|
||||||
|
The easiest is probably to just make the simulation behave the same
|
||||||
|
way as Pygame does. In that way all coordinates used are in the same
|
||||||
|
orientation and easy to reason about::
|
||||||
|
|
||||||
|
>>> space = pymunk.Space()
|
||||||
|
>>> space.gravity = (0, -1000)
|
||||||
|
>>> body = pymunk.Body()
|
||||||
|
>>> body.position = (0, 0) # will be positioned in the top left corner
|
||||||
|
>>> space.debug_draw(options)
|
||||||
|
|
||||||
|
To flip the drawing its possible to set the module property
|
||||||
|
:py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
|
||||||
|
the simulation upside down before drawing::
|
||||||
|
|
||||||
|
>>> positive_y_is_up = True
|
||||||
|
>>> body = pymunk.Body()
|
||||||
|
>>> body.position = (0, 0)
|
||||||
|
>>> # Body will be position in bottom left corner
|
||||||
|
|
||||||
|
:Parameters:
|
||||||
|
surface : pygame.Surface
|
||||||
|
Surface that the objects will be drawn on
|
||||||
|
"""
|
||||||
|
self.surface = surface
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def draw_circle(
|
||||||
|
self,
|
||||||
|
pos: Vec2d,
|
||||||
|
angle: float,
|
||||||
|
radius: float,
|
||||||
|
outline_color: SpaceDebugColor,
|
||||||
|
fill_color: SpaceDebugColor,
|
||||||
|
) -> None:
|
||||||
|
p = to_pygame(pos, self.surface)
|
||||||
|
|
||||||
|
pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
|
||||||
|
pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius - 4), 0)
|
||||||
|
|
||||||
|
# circle_edge = pos + Vec2d(radius, 0).rotated(angle)
|
||||||
|
# p2 = to_pygame(circle_edge, self.surface)
|
||||||
|
# line_r = 2 if radius > 20 else 1
|
||||||
|
# pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)
|
||||||
|
|
||||||
|
def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
|
||||||
|
p1 = to_pygame(a, self.surface)
|
||||||
|
p2 = to_pygame(b, self.surface)
|
||||||
|
|
||||||
|
pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])
|
||||||
|
|
||||||
|
def draw_fat_segment(
|
||||||
|
self,
|
||||||
|
a: Tuple[float, float],
|
||||||
|
b: Tuple[float, float],
|
||||||
|
radius: float,
|
||||||
|
outline_color: SpaceDebugColor,
|
||||||
|
fill_color: SpaceDebugColor,
|
||||||
|
) -> None:
|
||||||
|
p1 = to_pygame(a, self.surface)
|
||||||
|
p2 = to_pygame(b, self.surface)
|
||||||
|
|
||||||
|
r = round(max(1, radius * 2))
|
||||||
|
pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
|
||||||
|
if r > 2:
|
||||||
|
orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
|
||||||
|
if orthog[0] == 0 and orthog[1] == 0:
|
||||||
|
return
|
||||||
|
scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
|
||||||
|
orthog[0] = round(orthog[0] * scale)
|
||||||
|
orthog[1] = round(orthog[1] * scale)
|
||||||
|
points = [
|
||||||
|
(p1[0] - orthog[0], p1[1] - orthog[1]),
|
||||||
|
(p1[0] + orthog[0], p1[1] + orthog[1]),
|
||||||
|
(p2[0] + orthog[0], p2[1] + orthog[1]),
|
||||||
|
(p2[0] - orthog[0], p2[1] - orthog[1]),
|
||||||
|
]
|
||||||
|
pygame.draw.polygon(self.surface, fill_color.as_int(), points)
|
||||||
|
pygame.draw.circle(
|
||||||
|
self.surface,
|
||||||
|
fill_color.as_int(),
|
||||||
|
(round(p1[0]), round(p1[1])),
|
||||||
|
round(radius),
|
||||||
|
)
|
||||||
|
pygame.draw.circle(
|
||||||
|
self.surface,
|
||||||
|
fill_color.as_int(),
|
||||||
|
(round(p2[0]), round(p2[1])),
|
||||||
|
round(radius),
|
||||||
|
)
|
||||||
|
|
||||||
|
def draw_polygon(
|
||||||
|
self,
|
||||||
|
verts: Sequence[Tuple[float, float]],
|
||||||
|
radius: float,
|
||||||
|
outline_color: SpaceDebugColor,
|
||||||
|
fill_color: SpaceDebugColor,
|
||||||
|
) -> None:
|
||||||
|
ps = [to_pygame(v, self.surface) for v in verts]
|
||||||
|
ps += [ps[0]]
|
||||||
|
|
||||||
|
radius = 2
|
||||||
|
pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)
|
||||||
|
|
||||||
|
if radius > 0:
|
||||||
|
for i in range(len(verts)):
|
||||||
|
a = verts[i]
|
||||||
|
b = verts[(i + 1) % len(verts)]
|
||||||
|
self.draw_fat_segment(a, b, radius, fill_color, fill_color)
|
||||||
|
|
||||||
|
def draw_dot(self, size: float, pos: Tuple[float, float], color: SpaceDebugColor) -> None:
|
||||||
|
p = to_pygame(pos, self.surface)
|
||||||
|
pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]:
|
||||||
|
"""Get position of the mouse pointer in pymunk coordinates."""
|
||||||
|
p = pygame.mouse.get_pos()
|
||||||
|
return from_pygame(p, surface)
|
||||||
|
|
||||||
|
|
||||||
|
def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
|
||||||
|
"""Convenience method to convert pymunk coordinates to pygame surface
|
||||||
|
local coordinates.
|
||||||
|
|
||||||
|
Note that in case positive_y_is_up is False, this function won't actually do
|
||||||
|
anything except converting the point to integers.
|
||||||
|
"""
|
||||||
|
if positive_y_is_up:
|
||||||
|
return round(p[0]), surface.get_height() - round(p[1])
|
||||||
|
else:
|
||||||
|
return round(p[0]), round(p[1])
|
||||||
|
|
||||||
|
|
||||||
|
def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
|
||||||
|
"""Convenience method to convert pygame surface local coordinates to
|
||||||
|
pymunk coordinates
|
||||||
|
"""
|
||||||
|
return to_pygame(p, surface)
|
||||||
|
|
||||||
|
|
||||||
|
def light_color(color: SpaceDebugColor):
|
||||||
|
color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
|
||||||
|
color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
|
||||||
|
return color
|
|
@ -0,0 +1,84 @@
|
||||||
|
import torch
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
|
||||||
|
class EMAModel:
|
||||||
|
"""
|
||||||
|
Exponential Moving Average of models weights
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
@crowsonkb's notes on EMA Warmup:
|
||||||
|
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
||||||
|
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
|
||||||
|
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
|
||||||
|
at 215.4k steps).
|
||||||
|
Args:
|
||||||
|
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||||
|
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
||||||
|
min_value (float): The minimum EMA decay rate. Default: 0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.averaged_model = model
|
||||||
|
self.averaged_model.eval()
|
||||||
|
self.averaged_model.requires_grad_(False)
|
||||||
|
|
||||||
|
self.update_after_step = update_after_step
|
||||||
|
self.inv_gamma = inv_gamma
|
||||||
|
self.power = power
|
||||||
|
self.min_value = min_value
|
||||||
|
self.max_value = max_value
|
||||||
|
|
||||||
|
self.decay = 0.0
|
||||||
|
self.optimization_step = 0
|
||||||
|
|
||||||
|
def get_decay(self, optimization_step):
|
||||||
|
"""
|
||||||
|
Compute the decay factor for the exponential moving average.
|
||||||
|
"""
|
||||||
|
step = max(0, optimization_step - self.update_after_step - 1)
|
||||||
|
value = 1 - (1 + step / self.inv_gamma) ** -self.power
|
||||||
|
|
||||||
|
if step <= 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
return max(self.min_value, min(value, self.max_value))
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, new_model):
|
||||||
|
self.decay = self.get_decay(self.optimization_step)
|
||||||
|
|
||||||
|
# old_all_dataptrs = set()
|
||||||
|
# for param in new_model.parameters():
|
||||||
|
# data_ptr = param.data_ptr()
|
||||||
|
# if data_ptr != 0:
|
||||||
|
# old_all_dataptrs.add(data_ptr)
|
||||||
|
|
||||||
|
# all_dataptrs = set()
|
||||||
|
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
|
||||||
|
for param, ema_param in zip(
|
||||||
|
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
|
||||||
|
):
|
||||||
|
# iterative over immediate parameters only.
|
||||||
|
if isinstance(param, dict):
|
||||||
|
raise RuntimeError("Dict parameter not supported")
|
||||||
|
|
||||||
|
# data_ptr = param.data_ptr()
|
||||||
|
# if data_ptr != 0:
|
||||||
|
# all_dataptrs.add(data_ptr)
|
||||||
|
|
||||||
|
if isinstance(module, _BatchNorm):
|
||||||
|
# skip batchnorms
|
||||||
|
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
|
||||||
|
elif not param.requires_grad:
|
||||||
|
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
|
||||||
|
else:
|
||||||
|
ema_param.mul_(self.decay)
|
||||||
|
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
|
||||||
|
|
||||||
|
# verify that iterating over module and then parameters is identical to parameters recursively.
|
||||||
|
# assert old_all_dataptrs == all_dataptrs
|
||||||
|
self.optimization_step += 1
|
|
@ -2,6 +2,36 @@ from typing import Callable, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
|
||||||
|
def get_resnet(name, weights=None, **kwargs):
|
||||||
|
"""
|
||||||
|
name: resnet18, resnet34, resnet50
|
||||||
|
weights: "IMAGENET1K_V1", "r3m"
|
||||||
|
"""
|
||||||
|
# load r3m weights
|
||||||
|
if (weights == "r3m") or (weights == "R3M"):
|
||||||
|
return get_r3m(name=name, **kwargs)
|
||||||
|
|
||||||
|
func = getattr(torchvision.models, name)
|
||||||
|
resnet = func(weights=weights, **kwargs)
|
||||||
|
resnet.fc = torch.nn.Identity()
|
||||||
|
return resnet
|
||||||
|
|
||||||
|
|
||||||
|
def get_r3m(name, **kwargs):
|
||||||
|
"""
|
||||||
|
name: resnet18, resnet34, resnet50
|
||||||
|
"""
|
||||||
|
import r3m
|
||||||
|
|
||||||
|
r3m.device = "cpu"
|
||||||
|
model = r3m.load_r3m(name)
|
||||||
|
r3m_model = model.module
|
||||||
|
resnet_model = r3m_model.convnet
|
||||||
|
resnet_model = resnet_model.to("cpu")
|
||||||
|
return resnet_model
|
||||||
|
|
||||||
|
|
||||||
def dict_apply(
|
def dict_apply(
|
||||||
|
|
|
@ -74,7 +74,6 @@ noise_scheduler:
|
||||||
prediction_type: epsilon # or sample
|
prediction_type: epsilon # or sample
|
||||||
|
|
||||||
obs_encoder:
|
obs_encoder:
|
||||||
# _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
|
|
||||||
shape_meta: ${shape_meta}
|
shape_meta: ${shape_meta}
|
||||||
# resize_shape: null
|
# resize_shape: null
|
||||||
# crop_shape: [76, 76]
|
# crop_shape: [76, 76]
|
||||||
|
@ -85,12 +84,12 @@ obs_encoder:
|
||||||
imagenet_norm: True
|
imagenet_norm: True
|
||||||
|
|
||||||
rgb_model:
|
rgb_model:
|
||||||
_target_: diffusion_policy.model.vision.model_getter.get_resnet
|
_target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet
|
||||||
name: resnet18
|
name: resnet18
|
||||||
weights: null
|
weights: null
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
|
_target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
|
||||||
update_after_step: 0
|
update_after_step: 0
|
||||||
inv_gamma: 1.0
|
inv_gamma: 1.0
|
||||||
power: 0.75
|
power: 0.75
|
||||||
|
|
|
@ -3,7 +3,7 @@ from tensordict import TensorDict
|
||||||
from torchrl.envs.utils import check_env_specs, step_mdp
|
from torchrl.envs.utils import check_env_specs, step_mdp
|
||||||
|
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.envs.pusht import PushtEnv
|
from lerobot.common.envs.pusht.pusht import PushtEnv
|
||||||
from lerobot.common.envs.simxarm import SimxarmEnv
|
from lerobot.common.envs.simxarm import SimxarmEnv
|
||||||
|
|
||||||
from .utils import init_config
|
from .utils import init_config
|
||||||
|
|
Loading…
Reference in New Issue