Merge pull request #16 from Cadene/user/aliberts/2024_03_09_integrate_diffusion_policy

Integrate diffusion policy
This commit is contained in:
Remi 2024-03-10 17:02:16 +01:00 committed by GitHub
commit d4ea4f0ad1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 3642 additions and 58 deletions

View File

@ -69,10 +69,7 @@ jobs:
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
run: |
poetry install --no-interaction --no-root
git clone https://github.com/real-stanford/diffusion_policy
cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/
run: poetry install --no-interaction --no-root
- name: Save cached venv
if: |
steps.restore-dependencies-cache.outputs.cache-hit != 'true' &&

3
.gitignore vendored
View File

@ -1,6 +1,3 @@
# Custom
diffusion_policy
# Logging
logs
tmp

View File

@ -1,4 +1,4 @@
exclude: ^(data/|tests/|diffusion_policy/)
exclude: ^(data/|tests/)
default_language_version:
python: python3.10
repos:

View File

@ -24,12 +24,6 @@ mkdir ~/tmp
export TMPDIR='~/tmp'
```
Install `diffusion_policy` #HACK
```
# from this directory
git clone https://github.com/real-stanford/diffusion_policy
cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/
```
## Usage

View File

@ -8,8 +8,6 @@ import pymunk
import torch
import torchrl
import tqdm
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage
@ -17,11 +15,12 @@ from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
from lerobot.common.datasets.utils import download_and_extract_zip
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
# as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
@ -49,8 +48,10 @@ def add_tee(
angle,
scale=30,
color="LightSlateGray",
mask=DEFAULT_TEE_MASK,
mask=None,
):
if mask is None:
mask = pymunk.ShapeFilter.ALL_MASKS()
mass = 1
length = 4
vertices1 = [

View File

@ -18,7 +18,7 @@ def make_env(cfg, transform=None):
kwargs["task"] = cfg.env.task
clsfunc = SimxarmEnv
elif cfg.env.name == "pusht":
from lerobot.common.envs.pusht import PushtEnv
from lerobot.common.envs.pusht.env import PushtEnv
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."

View File

@ -17,7 +17,6 @@ from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.utils import set_seed
_has_gym = importlib.util.find_spec("gym") is not None
_has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _has_gym
class PushtEnv(EnvBase):
@ -45,17 +44,15 @@ class PushtEnv(EnvBase):
if from_pixels:
assert image_size
if not _has_diffpolicy:
raise ImportError("Cannot import diffusion_policy.")
if not _has_gym:
raise ImportError("Cannot import gym.")
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
# from diffusion_policy.env.pusht.pusht_env import PushTEnv
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
if not from_pixels:
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv
self._env = PushTImageEnv(render_size=self.image_size)

View File

@ -0,0 +1,378 @@
import collections
import cv2
import gym
import numpy as np
import pygame
import pymunk
import pymunk.pygame_util
import shapely.geometry as sg
import skimage.transform as st
from gym import spaces
from pymunk.vec2d import Vec2d
from lerobot.common.envs.pusht.pymunk_override import DrawOptions
def pymunk_to_shapely(body, shapes):
geoms = []
for shape in shapes:
if isinstance(shape, pymunk.shapes.Poly):
verts = [body.local_to_world(v) for v in shape.get_vertices()]
verts += [verts[0]]
geoms.append(sg.Polygon(verts))
else:
raise RuntimeError(f"Unsupported shape type {type(shape)}")
geom = sg.MultiPolygon(geoms)
return geom
class PushTEnv(gym.Env):
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
reward_range = (0.0, 1.0)
def __init__(
self,
legacy=False,
block_cog=None,
damping=None,
render_action=True,
render_size=96,
reset_to_state=None,
):
self._seed = None
self.seed()
self.window_size = ws = 512 # The size of the PyGame window
self.render_size = render_size
self.sim_hz = 100
# Local controller params.
self.k_p, self.k_v = 100, 20 # PD control.z
self.control_hz = self.metadata["video.frames_per_second"]
# legcay set_state for data compatibility
self.legacy = legacy
# agent_pos, block_pos, block_angle
self.observation_space = spaces.Box(
low=np.array([0, 0, 0, 0, 0], dtype=np.float64),
high=np.array([ws, ws, ws, ws, np.pi * 2], dtype=np.float64),
shape=(5,),
dtype=np.float64,
)
# positional goal for agent
self.action_space = spaces.Box(
low=np.array([0, 0], dtype=np.float64),
high=np.array([ws, ws], dtype=np.float64),
shape=(2,),
dtype=np.float64,
)
self.block_cog = block_cog
self.damping = damping
self.render_action = render_action
"""
If human-rendering is used, `self.window` will be a reference
to the window that we draw to. `self.clock` will be a clock that is used
to ensure that the environment is rendered at the correct framerate in
human-mode. They will remain `None` until human-mode is used for the
first time.
"""
self.window = None
self.clock = None
self.screen = None
self.space = None
self.teleop = None
self.render_buffer = None
self.latest_action = None
self.reset_to_state = reset_to_state
def reset(self):
seed = self._seed
self._setup()
if self.block_cog is not None:
self.block.center_of_gravity = self.block_cog
if self.damping is not None:
self.space.damping = self.damping
# use legacy RandomState for compatibility
state = self.reset_to_state
if state is None:
rs = np.random.RandomState(seed=seed)
state = np.array(
[
rs.randint(50, 450),
rs.randint(50, 450),
rs.randint(100, 400),
rs.randint(100, 400),
rs.randn() * 2 * np.pi - np.pi,
]
)
self._set_state(state)
observation = self._get_obs()
return observation
def step(self, action):
dt = 1.0 / self.sim_hz
self.n_contact_points = 0
n_steps = self.sim_hz // self.control_hz
if action is not None:
self.latest_action = action
for _ in range(n_steps):
# Step PD control.
# self.agent.velocity = self.k_p * (act - self.agent.position) # P control works too.
acceleration = self.k_p * (action - self.agent.position) + self.k_v * (
Vec2d(0, 0) - self.agent.velocity
)
self.agent.velocity += acceleration * dt
# Step physics.
self.space.step(dt)
# compute reward
goal_body = self._get_goal_pose_body(self.goal_pose)
goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
block_geom = pymunk_to_shapely(self.block, self.block.shapes)
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage = intersection_area / goal_area
reward = np.clip(coverage / self.success_threshold, 0, 1)
done = coverage > self.success_threshold
observation = self._get_obs()
info = self._get_info()
return observation, reward, done, info
def render(self, mode):
return self._render_frame(mode)
def teleop_agent(self):
TeleopAgent = collections.namedtuple("TeleopAgent", ["act"])
def act(obs):
act = None
mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
if self.teleop or (mouse_position - self.agent.position).length < 30:
self.teleop = True
act = mouse_position
return act
return TeleopAgent(act)
def _get_obs(self):
obs = np.array(
tuple(self.agent.position) + tuple(self.block.position) + (self.block.angle % (2 * np.pi),)
)
return obs
def _get_goal_pose_body(self, pose):
mass = 1
inertia = pymunk.moment_for_box(mass, (50, 100))
body = pymunk.Body(mass, inertia)
# preserving the legacy assignment order for compatibility
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
body.position = pose[:2].tolist()
body.angle = pose[2]
return body
def _get_info(self):
n_steps = self.sim_hz // self.control_hz
n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
info = {
"pos_agent": np.array(self.agent.position),
"vel_agent": np.array(self.agent.velocity),
"block_pose": np.array(list(self.block.position) + [self.block.angle]),
"goal_pose": self.goal_pose,
"n_contacts": n_contact_points_per_step,
}
return info
def _render_frame(self, mode):
if self.window is None and mode == "human":
pygame.init()
pygame.display.init()
self.window = pygame.display.set_mode((self.window_size, self.window_size))
if self.clock is None and mode == "human":
self.clock = pygame.time.Clock()
canvas = pygame.Surface((self.window_size, self.window_size))
canvas.fill((255, 255, 255))
self.screen = canvas
draw_options = DrawOptions(canvas)
# Draw goal pose.
goal_body = self._get_goal_pose_body(self.goal_pose)
for shape in self.block.shapes:
goal_points = [
pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface)
for v in shape.get_vertices()
]
goal_points += [goal_points[0]]
pygame.draw.polygon(canvas, self.goal_color, goal_points)
# Draw agent and block.
self.space.debug_draw(draw_options)
if mode == "human":
# The following line copies our drawings from `canvas` to the visible window
self.window.blit(canvas, canvas.get_rect())
pygame.event.pump()
pygame.display.update()
# the clock is already ticked during in step for "human"
img = np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))
img = cv2.resize(img, (self.render_size, self.render_size))
if self.render_action and self.latest_action is not None:
action = np.array(self.latest_action)
coord = (action / 512 * 96).astype(np.int32)
marker_size = int(8 / 96 * self.render_size)
thickness = int(1 / 96 * self.render_size)
cv2.drawMarker(
img,
coord,
color=(255, 0, 0),
markerType=cv2.MARKER_CROSS,
markerSize=marker_size,
thickness=thickness,
)
return img
def close(self):
if self.window is not None:
pygame.display.quit()
pygame.quit()
def seed(self, seed=None):
if seed is None:
seed = np.random.randint(0, 25536)
self._seed = seed
self.np_random = np.random.default_rng(seed)
def _handle_collision(self, arbiter, space, data):
self.n_contact_points += len(arbiter.contact_point_set.points)
def _set_state(self, state):
if isinstance(state, np.ndarray):
state = state.tolist()
pos_agent = state[:2]
pos_block = state[2:4]
rot_block = state[4]
self.agent.position = pos_agent
# setting angle rotates with respect to center of mass
# therefore will modify the geometric position
# if not the same as CoM
# therefore should be modified first.
if self.legacy:
# for compatibility with legacy data
self.block.position = pos_block
self.block.angle = rot_block
else:
self.block.angle = rot_block
self.block.position = pos_block
# Run physics to take effect
self.space.step(1.0 / self.sim_hz)
def _set_state_local(self, state_local):
agent_pos_local = state_local[:2]
block_pose_local = state_local[2:]
tf_img_obj = st.AffineTransform(translation=self.goal_pose[:2], rotation=self.goal_pose[2])
tf_obj_new = st.AffineTransform(translation=block_pose_local[:2], rotation=block_pose_local[2])
tf_img_new = st.AffineTransform(matrix=tf_img_obj.params @ tf_obj_new.params)
agent_pos_new = tf_img_new(agent_pos_local)
new_state = np.array(list(agent_pos_new[0]) + list(tf_img_new.translation) + [tf_img_new.rotation])
self._set_state(new_state)
return new_state
def _setup(self):
self.space = pymunk.Space()
self.space.gravity = 0, 0
self.space.damping = 0
self.teleop = False
self.render_buffer = []
# Add walls.
walls = [
self._add_segment((5, 506), (5, 5), 2),
self._add_segment((5, 5), (506, 5), 2),
self._add_segment((506, 5), (506, 506), 2),
self._add_segment((5, 506), (506, 506), 2),
]
self.space.add(*walls)
# Add agent, block, and goal zone.
self.agent = self.add_circle((256, 400), 15)
self.block = self.add_tee((256, 300), 0)
self.goal_color = pygame.Color("LightGreen")
self.goal_pose = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
# Add collision handling
self.collision_handeler = self.space.add_collision_handler(0, 0)
self.collision_handeler.post_solve = self._handle_collision
self.n_contact_points = 0
self.max_score = 50 * 100
self.success_threshold = 0.95 # 95% coverage.
def _add_segment(self, a, b, radius):
shape = pymunk.Segment(self.space.static_body, a, b, radius)
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
return shape
def add_circle(self, position, radius):
body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
body.position = position
body.friction = 1
shape = pymunk.Circle(body, radius)
shape.color = pygame.Color("RoyalBlue")
self.space.add(body, shape)
return body
def add_box(self, position, height, width):
mass = 1
inertia = pymunk.moment_for_box(mass, (height, width))
body = pymunk.Body(mass, inertia)
body.position = position
shape = pymunk.Poly.create_box(body, (height, width))
shape.color = pygame.Color("LightSlateGray")
self.space.add(body, shape)
return body
def add_tee(self, position, angle, scale=30, color="LightSlateGray", mask=None):
if mask is None:
mask = pymunk.ShapeFilter.ALL_MASKS()
mass = 1
length = 4
vertices1 = [
(-length * scale / 2, scale),
(length * scale / 2, scale),
(length * scale / 2, 0),
(-length * scale / 2, 0),
]
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
vertices2 = [
(-scale / 2, scale),
(-scale / 2, length * scale),
(scale / 2, length * scale),
(scale / 2, scale),
]
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
body = pymunk.Body(mass, inertia1 + inertia2)
shape1 = pymunk.Poly(body, vertices1)
shape2 = pymunk.Poly(body, vertices2)
shape1.color = pygame.Color(color)
shape2.color = pygame.Color(color)
shape1.filter = pymunk.ShapeFilter(mask=mask)
shape2.filter = pymunk.ShapeFilter(mask=mask)
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
body.position = position
body.angle = angle
body.friction = 1
self.space.add(body, shape1, shape2)
return body

View File

@ -0,0 +1,55 @@
import cv2
import numpy as np
from gym import spaces
from lerobot.common.envs.pusht.pusht_env import PushTEnv
class PushTImageEnv(PushTEnv):
metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}
def __init__(self, legacy=False, block_cog=None, damping=None, render_size=96):
super().__init__(
legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False
)
ws = self.window_size
self.observation_space = spaces.Dict(
{
"image": spaces.Box(low=0, high=1, shape=(3, render_size, render_size), dtype=np.float32),
"agent_pos": spaces.Box(low=0, high=ws, shape=(2,), dtype=np.float32),
}
)
self.render_cache = None
def _get_obs(self):
img = super()._render_frame(mode="rgb_array")
agent_pos = np.array(self.agent.position)
img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
obs = {"image": img_obs, "agent_pos": agent_pos}
# draw action
if self.latest_action is not None:
action = np.array(self.latest_action)
coord = (action / 512 * 96).astype(np.int32)
marker_size = int(8 / 96 * self.render_size)
thickness = int(1 / 96 * self.render_size)
cv2.drawMarker(
img,
coord,
color=(255, 0, 0),
markerType=cv2.MARKER_CROSS,
markerSize=marker_size,
thickness=thickness,
)
self.render_cache = img
return obs
def render(self, mode):
assert mode == "rgb_array"
if self.render_cache is None:
self._get_obs()
return self.render_cache

View File

@ -0,0 +1,244 @@
# ----------------------------------------------------------------------------
# pymunk
# Copyright (c) 2007-2016 Victor Blomqvist
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ----------------------------------------------------------------------------
"""This submodule contains helper functions to help with quick prototyping
using pymunk together with pygame.
Intended to help with debugging and prototyping, not for actual production use
in a full application. The methods contained in this module is opinionated
about your coordinate system and not in any way optimized.
"""
__docformat__ = "reStructuredText"
__all__ = [
"DrawOptions",
"get_mouse_pos",
"to_pygame",
"from_pygame",
# "lighten",
"positive_y_is_up",
]
from typing import Sequence, Tuple
import numpy as np
import pygame
import pymunk
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
positive_y_is_up: bool = False
"""Make increasing values of y point upwards.
When True::
y
^
| . (3, 3)
|
| . (2, 2)
|
+------ > x
When False::
+------ > x
|
| . (2, 2)
|
| . (3, 3)
v
y
"""
class DrawOptions(pymunk.SpaceDebugDrawOptions):
def __init__(self, surface: pygame.Surface) -> None:
"""Draw a pymunk.Space on a pygame.Surface object.
Typical usage::
>>> import pymunk
>>> surface = pygame.Surface((10,10))
>>> space = pymunk.Space()
>>> options = pymunk.pygame_util.DrawOptions(surface)
>>> space.debug_draw(options)
You can control the color of a shape by setting shape.color to the color
you want it drawn in::
>>> c = pymunk.Circle(None, 10)
>>> c.color = pygame.Color("pink")
See pygame_util.demo.py for a full example
Since pygame uses a coordinate system where y points down (in contrast
to many other cases), you either have to make the physics simulation
with Pymunk also behave in that way, or flip everything when you draw.
The easiest is probably to just make the simulation behave the same
way as Pygame does. In that way all coordinates used are in the same
orientation and easy to reason about::
>>> space = pymunk.Space()
>>> space.gravity = (0, -1000)
>>> body = pymunk.Body()
>>> body.position = (0, 0) # will be positioned in the top left corner
>>> space.debug_draw(options)
To flip the drawing its possible to set the module property
:py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
the simulation upside down before drawing::
>>> positive_y_is_up = True
>>> body = pymunk.Body()
>>> body.position = (0, 0)
>>> # Body will be position in bottom left corner
:Parameters:
surface : pygame.Surface
Surface that the objects will be drawn on
"""
self.surface = surface
super().__init__()
def draw_circle(
self,
pos: Vec2d,
angle: float,
radius: float,
outline_color: SpaceDebugColor,
fill_color: SpaceDebugColor,
) -> None:
p = to_pygame(pos, self.surface)
pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius - 4), 0)
# circle_edge = pos + Vec2d(radius, 0).rotated(angle)
# p2 = to_pygame(circle_edge, self.surface)
# line_r = 2 if radius > 20 else 1
# pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)
def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
p1 = to_pygame(a, self.surface)
p2 = to_pygame(b, self.surface)
pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])
def draw_fat_segment(
self,
a: Tuple[float, float],
b: Tuple[float, float],
radius: float,
outline_color: SpaceDebugColor,
fill_color: SpaceDebugColor,
) -> None:
p1 = to_pygame(a, self.surface)
p2 = to_pygame(b, self.surface)
r = round(max(1, radius * 2))
pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
if r > 2:
orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
if orthog[0] == 0 and orthog[1] == 0:
return
scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
orthog[0] = round(orthog[0] * scale)
orthog[1] = round(orthog[1] * scale)
points = [
(p1[0] - orthog[0], p1[1] - orthog[1]),
(p1[0] + orthog[0], p1[1] + orthog[1]),
(p2[0] + orthog[0], p2[1] + orthog[1]),
(p2[0] - orthog[0], p2[1] - orthog[1]),
]
pygame.draw.polygon(self.surface, fill_color.as_int(), points)
pygame.draw.circle(
self.surface,
fill_color.as_int(),
(round(p1[0]), round(p1[1])),
round(radius),
)
pygame.draw.circle(
self.surface,
fill_color.as_int(),
(round(p2[0]), round(p2[1])),
round(radius),
)
def draw_polygon(
self,
verts: Sequence[Tuple[float, float]],
radius: float,
outline_color: SpaceDebugColor,
fill_color: SpaceDebugColor,
) -> None:
ps = [to_pygame(v, self.surface) for v in verts]
ps += [ps[0]]
radius = 2
pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)
if radius > 0:
for i in range(len(verts)):
a = verts[i]
b = verts[(i + 1) % len(verts)]
self.draw_fat_segment(a, b, radius, fill_color, fill_color)
def draw_dot(self, size: float, pos: Tuple[float, float], color: SpaceDebugColor) -> None:
p = to_pygame(pos, self.surface)
pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)
def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]:
"""Get position of the mouse pointer in pymunk coordinates."""
p = pygame.mouse.get_pos()
return from_pygame(p, surface)
def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
"""Convenience method to convert pymunk coordinates to pygame surface
local coordinates.
Note that in case positive_y_is_up is False, this function won't actually do
anything except converting the point to integers.
"""
if positive_y_is_up:
return round(p[0]), surface.get_height() - round(p[1])
else:
return round(p[0]), round(p[1])
def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
"""Convenience method to convert pygame surface local coordinates to
pymunk coordinates
"""
return to_pygame(p, surface)
def light_color(color: SpaceDebugColor):
color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
return color

View File

@ -5,11 +5,33 @@ import torch.nn.functional as F # noqa: N812
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from einops import reduce
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
from lerobot.common.policies.diffusion.model.mask_generator import LowdimMaskGenerator
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
from lerobot.common.policies.diffusion.model.normalizer import LinearNormalizer
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
class BaseImagePolicy(ModuleAttrMixin):
# init accepts keyword argument shape_meta, see config/task/*_image.yaml
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
obs_dict:
str: B,To,*
return: B,Ta,Da
"""
raise NotImplementedError()
# reset state for stateful policies
def reset(self):
pass
# ========== training ===========
# no standard training interface except setting normalizer
def set_normalizer(self, normalizer: LinearNormalizer):
raise NotImplementedError()
class DiffusionUnetImagePolicy(BaseImagePolicy):

View File

@ -0,0 +1,286 @@
import logging
from typing import Union
import einops
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d
from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb
logger = logging.getLogger(__name__)
class ConditionalResidualBlock1D(nn.Module):
def __init__(
self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False
):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
]
)
# FiLM modulation https://arxiv.org/abs/1709.07871
# predicts per-channel scale and bias
cond_channels = out_channels
if cond_predict_scale:
cond_channels = out_channels * 2
self.cond_predict_scale = cond_predict_scale
self.out_channels = out_channels
self.cond_encoder = nn.Sequential(
nn.Mish(),
nn.Linear(cond_dim, cond_channels),
Rearrange("batch t -> batch t 1"),
)
# make sure dimensions compatible
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x, cond):
"""
x : [ batch_size x in_channels x horizon ]
cond : [ batch_size x cond_dim]
returns:
out : [ batch_size x out_channels x horizon ]
"""
out = self.blocks[0](x)
embed = self.cond_encoder(cond)
if self.cond_predict_scale:
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
scale = embed[:, 0, ...]
bias = embed[:, 1, ...]
out = scale * out + bias
else:
out = out + embed
out = self.blocks[1](out)
out = out + self.residual_conv(x)
return out
class ConditionalUnet1D(nn.Module):
def __init__(
self,
input_dim,
local_cond_dim=None,
global_cond_dim=None,
diffusion_step_embed_dim=256,
down_dims=None,
kernel_size=3,
n_groups=8,
cond_predict_scale=False,
):
super().__init__()
if down_dims is None:
down_dims = [256, 512, 1024]
all_dims = [input_dim] + list(down_dims)
start_dim = down_dims[0]
dsed = diffusion_step_embed_dim
diffusion_step_encoder = nn.Sequential(
SinusoidalPosEmb(dsed),
nn.Linear(dsed, dsed * 4),
nn.Mish(),
nn.Linear(dsed * 4, dsed),
)
cond_dim = dsed
if global_cond_dim is not None:
cond_dim += global_cond_dim
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
local_cond_encoder = None
if local_cond_dim is not None:
_, dim_out = in_out[0]
dim_in = local_cond_dim
local_cond_encoder = nn.ModuleList(
[
# down encoder
ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
# up encoder
ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
]
)
mid_dim = all_dims[-1]
self.mid_modules = nn.ModuleList(
[
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
]
)
down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
down_modules.append(
nn.ModuleList(
[
ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
Downsample1d(dim_out) if not is_last else nn.Identity(),
]
)
)
up_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
up_modules.append(
nn.ModuleList(
[
ConditionalResidualBlock1D(
dim_out * 2,
dim_in,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
ConditionalResidualBlock1D(
dim_in,
dim_in,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
Upsample1d(dim_in) if not is_last else nn.Identity(),
]
)
)
final_conv = nn.Sequential(
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
nn.Conv1d(start_dim, input_dim, 1),
)
self.diffusion_step_encoder = diffusion_step_encoder
self.local_cond_encoder = local_cond_encoder
self.up_modules = up_modules
self.down_modules = down_modules
self.final_conv = final_conv
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
local_cond=None,
global_cond=None,
**kwargs,
):
"""
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
local_cond: (B,T,local_cond_dim)
global_cond: (B,global_cond_dim)
output: (B,T,input_dim)
"""
sample = einops.rearrange(sample, "b h t -> b t h")
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
global_feature = self.diffusion_step_encoder(timesteps)
if global_cond is not None:
global_feature = torch.cat([global_feature, global_cond], axis=-1)
# encode local features
h_local = []
if local_cond is not None:
local_cond = einops.rearrange(local_cond, "b h t -> b t h")
resnet, resnet2 = self.local_cond_encoder
x = resnet(local_cond, global_feature)
h_local.append(x)
x = resnet2(local_cond, global_feature)
h_local.append(x)
x = sample
h = []
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
x = resnet(x, global_feature)
if idx == 0 and len(h_local) > 0:
x = x + h_local[0]
x = resnet2(x, global_feature)
h.append(x)
x = downsample(x)
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, global_feature)
# The correct condition should be:
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
# However this change will break compatibility with published checkpoints.
# Therefore it is left as a comment.
if idx == len(self.up_modules) and len(h_local) > 0:
x = x + h_local[1]
x = resnet2(x, global_feature)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, "b t h -> b h t")
return x

View File

@ -0,0 +1,47 @@
import torch.nn as nn
# from einops.layers.torch import Rearrange
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
# def test():
# cb = Conv1dBlock(256, 128, kernel_size=3)
# x = torch.zeros((1,256,16))
# o = cb(x)

View File

@ -0,0 +1,294 @@
import torch
import torch.nn as nn
import torchvision.transforms.functional as ttf
import lerobot.common.policies.diffusion.model.tensor_utils as tu
class CropRandomizer(nn.Module):
"""
Randomly sample crops at input, and then average across crop features at output.
"""
def __init__(
self,
input_shape,
crop_height,
crop_width,
num_crops=1,
pos_enc=False,
):
"""
Args:
input_shape (tuple, list): shape of input (not including batch dimension)
crop_height (int): crop height
crop_width (int): crop width
num_crops (int): number of random crops to take
pos_enc (bool): if True, add 2 channels to the output to encode the spatial
location of the cropped pixels in the source image
"""
super().__init__()
assert len(input_shape) == 3 # (C, H, W)
assert crop_height < input_shape[1]
assert crop_width < input_shape[2]
self.input_shape = input_shape
self.crop_height = crop_height
self.crop_width = crop_width
self.num_crops = num_crops
self.pos_enc = pos_enc
def output_shape_in(self, input_shape=None):
"""
Function to compute output shape from inputs to this module. Corresponds to
the @forward_in operation, where raw inputs (usually observation modalities)
are passed in.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
# outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
# the number of crops are reshaped into the batch dimension, increasing the batch
# size from B to B * N
out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
return [out_c, self.crop_height, self.crop_width]
def output_shape_out(self, input_shape=None):
"""
Function to compute output shape from inputs to this module. Corresponds to
the @forward_out operation, where processed inputs (usually encoded observation
modalities) are passed in.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
# since the forward_out operation splits [B * N, ...] -> [B, N, ...]
# and then pools to result in [B, ...], only the batch dimension changes,
# and so the other dimensions retain their shape.
return list(input_shape)
def forward_in(self, inputs):
"""
Samples N random crops for each input in the batch, and then reshapes
inputs to [B * N, ...].
"""
assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions
if self.training:
# generate random crops
out, _ = sample_random_image_crops(
images=inputs,
crop_height=self.crop_height,
crop_width=self.crop_width,
num_crops=self.num_crops,
pos_enc=self.pos_enc,
)
# [B, N, ...] -> [B * N, ...]
return tu.join_dimensions(out, 0, 1)
else:
# take center crop during eval
out = ttf.center_crop(img=inputs, output_size=(self.crop_height, self.crop_width))
if self.num_crops > 1:
B, C, H, W = out.shape # noqa: N806
out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W)
# [B * N, ...]
return out
def forward_out(self, inputs):
"""
Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
to result in shape [B, ...] to make sure the network output is consistent with
what would have happened if there were no randomization.
"""
if self.num_crops <= 1:
return inputs
else:
batch_size = inputs.shape[0] // self.num_crops
out = tu.reshape_dimensions(
inputs, begin_axis=0, end_axis=0, target_dims=(batch_size, self.num_crops)
)
return out.mean(dim=1)
def forward(self, inputs):
return self.forward_in(inputs)
def __repr__(self):
"""Pretty print network."""
header = "{}".format(str(self.__class__.__name__))
msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
self.input_shape, self.crop_height, self.crop_width, self.num_crops
)
return msg
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
"""
Crops images at the locations specified by @crop_indices. Crops will be
taken across all channels.
Args:
images (torch.Tensor): batch of images of shape [..., C, H, W]
crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
N is the number of crops to take per image and each entry corresponds
to the pixel height and width of where to take the crop. Note that
the indices can also be of shape [..., 2] if only 1 crop should
be taken per image. Leading dimensions must be consistent with
@images argument. Each index specifies the top left of the crop.
Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
H and W are the height and width of @images and CH and CW are
@crop_height and @crop_width.
crop_height (int): height of crop to take
crop_width (int): width of crop to take
Returns:
crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width]
"""
# make sure length of input shapes is consistent
assert crop_indices.shape[-1] == 2
ndim_im_shape = len(images.shape)
ndim_indices_shape = len(crop_indices.shape)
assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2)
# maybe pad so that @crop_indices is shape [..., N, 2]
is_padded = False
if ndim_im_shape == ndim_indices_shape + 2:
crop_indices = crop_indices.unsqueeze(-2)
is_padded = True
# make sure leading dimensions between images and indices are consistent
assert images.shape[:-3] == crop_indices.shape[:-2]
device = images.device
image_c, image_h, image_w = images.shape[-3:]
num_crops = crop_indices.shape[-2]
# make sure @crop_indices are in valid range
assert (crop_indices[..., 0] >= 0).all().item()
assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
assert (crop_indices[..., 1] >= 0).all().item()
assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
# convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
# 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
crop_ind_grid_h = torch.arange(crop_height).to(device)
crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1)
# 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
crop_ind_grid_w = torch.arange(crop_width).to(device)
crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0)
# combine into shape [CH, CW, 2]
crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
# Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
# After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
# shape array that tells us which pixels from the corresponding source image to grab.
grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2]
all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape)
# For using @torch.gather, convert to flat indices from 2D indices, and also
# repeat across the channel dimension. To get flat index of each pixel to grab for
# each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1] # shape [..., N, CH, CW]
all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3) # shape [..., N, C, CH, CW]
all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2) # shape [..., N, C, CH * CW]
# Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
# [..., N, C, CH * CW] -> [..., N, C, CH, CW]
reshape_axis = len(crops.shape) - 1
crops = tu.reshape_dimensions(
crops, begin_axis=reshape_axis, end_axis=reshape_axis, target_dims=(crop_height, crop_width)
)
if is_padded:
# undo padding -> [..., C, CH, CW]
crops = crops.squeeze(-4)
return crops
def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False):
"""
For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
@images.
Args:
images (torch.Tensor): batch of images of shape [..., C, H, W]
crop_height (int): height of crop to take
crop_width (int): width of crop to take
num_crops (n): number of crops to sample
pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
encoding of the original source pixel locations. This means that the
output crops will contain information about where in the source image
it was sampled from.
Returns:
crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
"""
device = images.device
# maybe add 2 channels of spatial encoding to the source image
source_im = images
if pos_enc:
# spatial encoding [y, x] in [0, 1]
h, w = source_im.shape[-2:]
pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
pos_y = pos_y.float().to(device) / float(h)
pos_x = pos_x.float().to(device) / float(w)
position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W]
# unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
leading_shape = source_im.shape[:-3]
position_enc = position_enc[(None,) * len(leading_shape)]
position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
# concat across channel dimension with input
source_im = torch.cat((source_im, position_enc), dim=-3)
# make sure sample boundaries ensure crops are fully within the images
image_c, image_h, image_w = source_im.shape[-3:]
max_sample_h = image_h - crop_height
max_sample_w = image_w - crop_width
# Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
# Each gets @num_crops samples - typically this will just be the batch dimension (B), so
# we will sample [B, N] indices, but this supports having more than one leading dimension,
# or possibly no leading dimension.
#
# Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2]
crops = crop_image_from_indices(
images=source_im,
crop_indices=crop_inds,
crop_height=crop_height,
crop_width=crop_width,
)
return crops, crop_inds

View File

@ -0,0 +1,41 @@
import torch
import torch.nn as nn
class DictOfTensorMixin(nn.Module):
def __init__(self, params_dict=None):
super().__init__()
if params_dict is None:
params_dict = nn.ParameterDict()
self.params_dict = params_dict
@property
def device(self):
return next(iter(self.parameters())).device
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
def dfs_add(dest, keys, value: torch.Tensor):
if len(keys) == 1:
dest[keys[0]] = value
return
if keys[0] not in dest:
dest[keys[0]] = nn.ParameterDict()
dfs_add(dest[keys[0]], keys[1:], value)
def load_dict(state_dict, prefix):
out_dict = nn.ParameterDict()
for key, value in state_dict.items():
value: torch.Tensor
if key.startswith(prefix):
param_keys = key[len(prefix) :].split(".")[1:]
# if len(param_keys) == 0:
# import pdb; pdb.set_trace()
dfs_add(out_dict, param_keys, value.clone())
return out_dict
self.params_dict = load_dict(state_dict, prefix + "params_dict")
self.params_dict.requires_grad_(False)
return

View File

@ -0,0 +1,84 @@
import torch
from torch.nn.modules.batchnorm import _BatchNorm
class EMAModel:
"""
Exponential Moving Average of models weights
"""
def __init__(
self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
self.averaged_model = model
self.averaged_model.eval()
self.averaged_model.requires_grad_(False)
self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.max_value = max_value
self.decay = 0.0
self.optimization_step = 0
def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - self.update_after_step - 1)
value = 1 - (1 + step / self.inv_gamma) ** -self.power
if step <= 0:
return 0.0
return max(self.min_value, min(value, self.max_value))
@torch.no_grad()
def step(self, new_model):
self.decay = self.get_decay(self.optimization_step)
# old_all_dataptrs = set()
# for param in new_model.parameters():
# data_ptr = param.data_ptr()
# if data_ptr != 0:
# old_all_dataptrs.add(data_ptr)
# all_dataptrs = set()
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
for param, ema_param in zip(
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
):
# iterative over immediate parameters only.
if isinstance(param, dict):
raise RuntimeError("Dict parameter not supported")
# data_ptr = param.data_ptr()
# if data_ptr != 0:
# all_dataptrs.add(data_ptr)
if isinstance(module, _BatchNorm):
# skip batchnorms
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
elif not param.requires_grad:
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
else:
ema_param.mul_(self.decay)
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
# verify that iterating over module and then parameters is identical to parameters recursively.
# assert old_all_dataptrs == all_dataptrs
self.optimization_step += 1

View File

@ -0,0 +1,46 @@
from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, Optimizer, Optional, SchedulerType, Union
def get_scheduler(
name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
**kwargs,
):
"""
Added kwargs vs diffuser's original implementation
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer, **kwargs)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs
)

View File

@ -0,0 +1,65 @@
import torch
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
class LowdimMaskGenerator(ModuleAttrMixin):
def __init__(
self,
action_dim,
obs_dim,
# obs mask setup
max_n_obs_steps=2,
fix_obs_steps=True,
# action mask
action_visible=False,
):
super().__init__()
self.action_dim = action_dim
self.obs_dim = obs_dim
self.max_n_obs_steps = max_n_obs_steps
self.fix_obs_steps = fix_obs_steps
self.action_visible = action_visible
@torch.no_grad()
def forward(self, shape, seed=None):
device = self.device
B, T, D = shape # noqa: N806
assert (self.action_dim + self.obs_dim) == D
# create all tensors on this device
rng = torch.Generator(device=device)
if seed is not None:
rng = rng.manual_seed(seed)
# generate dim mask
dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device)
is_action_dim = dim_mask.clone()
is_action_dim[..., : self.action_dim] = True
is_obs_dim = ~is_action_dim
# generate obs mask
if self.fix_obs_steps:
obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device)
else:
obs_steps = torch.randint(
low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device
)
steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T)
obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
obs_mask = obs_mask & is_obs_dim
# generate action mask
if self.action_visible:
action_steps = torch.maximum(
obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device)
)
action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
action_mask = action_mask & is_action_dim
mask = obs_mask
if self.action_visible:
mask = mask | action_mask
return mask

View File

@ -0,0 +1,15 @@
import torch.nn as nn
class ModuleAttrMixin(nn.Module):
def __init__(self):
super().__init__()
self._dummy_variable = nn.Parameter()
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype

View File

@ -5,9 +5,9 @@ import torch
import torch.nn as nn
import torchvision
from diffusion_policy.common.pytorch_util import replace_submodules
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.model.vision.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class MultiImageObsEncoder(ModuleAttrMixin):

View File

@ -0,0 +1,358 @@
from typing import Dict, Union
import numpy as np
import torch
import torch.nn as nn
import zarr
from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
class LinearNormalizer(DictOfTensorMixin):
avaliable_modes = ["limits", "gaussian"]
@torch.no_grad()
def fit(
self,
data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
last_n_dims=1,
dtype=torch.float32,
mode="limits",
output_max=1.0,
output_min=-1.0,
range_eps=1e-4,
fit_offset=True,
):
if isinstance(data, dict):
for key, value in data.items():
self.params_dict[key] = _fit(
value,
last_n_dims=last_n_dims,
dtype=dtype,
mode=mode,
output_max=output_max,
output_min=output_min,
range_eps=range_eps,
fit_offset=fit_offset,
)
else:
self.params_dict["_default"] = _fit(
data,
last_n_dims=last_n_dims,
dtype=dtype,
mode=mode,
output_max=output_max,
output_min=output_min,
range_eps=range_eps,
fit_offset=fit_offset,
)
def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
return self.normalize(x)
def __getitem__(self, key: str):
return SingleFieldLinearNormalizer(self.params_dict[key])
def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"):
self.params_dict[key] = value.params_dict
def _normalize_impl(self, x, forward=True):
if isinstance(x, dict):
result = {}
for key, value in x.items():
params = self.params_dict[key]
result[key] = _normalize(value, params, forward=forward)
return result
else:
if "_default" not in self.params_dict:
raise RuntimeError("Not initialized")
params = self.params_dict["_default"]
return _normalize(x, params, forward=forward)
def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
return self._normalize_impl(x, forward=True)
def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
return self._normalize_impl(x, forward=False)
def get_input_stats(self) -> Dict:
if len(self.params_dict) == 0:
raise RuntimeError("Not initialized")
if len(self.params_dict) == 1 and "_default" in self.params_dict:
return self.params_dict["_default"]["input_stats"]
result = {}
for key, value in self.params_dict.items():
if key != "_default":
result[key] = value["input_stats"]
return result
def get_output_stats(self, key="_default"):
input_stats = self.get_input_stats()
if "min" in input_stats:
# no dict
return dict_apply(input_stats, self.normalize)
result = {}
for key, group in input_stats.items():
this_dict = {}
for name, value in group.items():
this_dict[name] = self.normalize({key: value})[key]
result[key] = this_dict
return result
class SingleFieldLinearNormalizer(DictOfTensorMixin):
avaliable_modes = ["limits", "gaussian"]
@torch.no_grad()
def fit(
self,
data: Union[torch.Tensor, np.ndarray, zarr.Array],
last_n_dims=1,
dtype=torch.float32,
mode="limits",
output_max=1.0,
output_min=-1.0,
range_eps=1e-4,
fit_offset=True,
):
self.params_dict = _fit(
data,
last_n_dims=last_n_dims,
dtype=dtype,
mode=mode,
output_max=output_max,
output_min=output_min,
range_eps=range_eps,
fit_offset=fit_offset,
)
@classmethod
def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
obj = cls()
obj.fit(data, **kwargs)
return obj
@classmethod
def create_manual(
cls,
scale: Union[torch.Tensor, np.ndarray],
offset: Union[torch.Tensor, np.ndarray],
input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]],
):
def to_tensor(x):
if not isinstance(x, torch.Tensor):
x = torch.from_numpy(x)
x = x.flatten()
return x
# check
for x in [offset] + list(input_stats_dict.values()):
assert x.shape == scale.shape
assert x.dtype == scale.dtype
params_dict = nn.ParameterDict(
{
"scale": to_tensor(scale),
"offset": to_tensor(offset),
"input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)),
}
)
return cls(params_dict)
@classmethod
def create_identity(cls, dtype=torch.float32):
scale = torch.tensor([1], dtype=dtype)
offset = torch.tensor([0], dtype=dtype)
input_stats_dict = {
"min": torch.tensor([-1], dtype=dtype),
"max": torch.tensor([1], dtype=dtype),
"mean": torch.tensor([0], dtype=dtype),
"std": torch.tensor([1], dtype=dtype),
}
return cls.create_manual(scale, offset, input_stats_dict)
def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
return _normalize(x, self.params_dict, forward=True)
def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
return _normalize(x, self.params_dict, forward=False)
def get_input_stats(self):
return self.params_dict["input_stats"]
def get_output_stats(self):
return dict_apply(self.params_dict["input_stats"], self.normalize)
def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
return self.normalize(x)
def _fit(
data: Union[torch.Tensor, np.ndarray, zarr.Array],
last_n_dims=1,
dtype=torch.float32,
mode="limits",
output_max=1.0,
output_min=-1.0,
range_eps=1e-4,
fit_offset=True,
):
assert mode in ["limits", "gaussian"]
assert last_n_dims >= 0
assert output_max > output_min
# convert data to torch and type
if isinstance(data, zarr.Array):
data = data[:]
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
if dtype is not None:
data = data.type(dtype)
# convert shape
dim = 1
if last_n_dims > 0:
dim = np.prod(data.shape[-last_n_dims:])
data = data.reshape(-1, dim)
# compute input stats min max mean std
input_min, _ = data.min(axis=0)
input_max, _ = data.max(axis=0)
input_mean = data.mean(axis=0)
input_std = data.std(axis=0)
# compute scale and offset
if mode == "limits":
if fit_offset:
# unit scale
input_range = input_max - input_min
ignore_dim = input_range < range_eps
input_range[ignore_dim] = output_max - output_min
scale = (output_max - output_min) / input_range
offset = output_min - scale * input_min
offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
# ignore dims scaled to mean of output max and min
else:
# use this when data is pre-zero-centered.
assert output_max > 0
assert output_min < 0
# unit abs
output_abs = min(abs(output_min), abs(output_max))
input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
ignore_dim = input_abs < range_eps
input_abs[ignore_dim] = output_abs
# don't scale constant channels
scale = output_abs / input_abs
offset = torch.zeros_like(input_mean)
elif mode == "gaussian":
ignore_dim = input_std < range_eps
scale = input_std.clone()
scale[ignore_dim] = 1
scale = 1 / scale
offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean)
# save
this_params = nn.ParameterDict(
{
"scale": scale,
"offset": offset,
"input_stats": nn.ParameterDict(
{"min": input_min, "max": input_max, "mean": input_mean, "std": input_std}
),
}
)
for p in this_params.parameters():
p.requires_grad_(False)
return this_params
def _normalize(x, params, forward=True):
assert "scale" in params
if isinstance(x, np.ndarray):
x = torch.from_numpy(x)
scale = params["scale"]
offset = params["offset"]
x = x.to(device=scale.device, dtype=scale.dtype)
src_shape = x.shape
x = x.reshape(-1, scale.shape[0])
x = x * scale + offset if forward else (x - offset) / scale
x = x.reshape(src_shape)
return x
def test():
data = torch.zeros((100, 10, 9, 2)).uniform_()
data[..., 0, 0] = 0
normalizer = SingleFieldLinearNormalizer()
normalizer.fit(data, mode="limits", last_n_dims=2)
datan = normalizer.normalize(data)
assert datan.shape == data.shape
assert np.allclose(datan.max(), 1.0)
assert np.allclose(datan.min(), -1.0)
dataun = normalizer.unnormalize(datan)
assert torch.allclose(data, dataun, atol=1e-7)
_ = normalizer.get_input_stats()
_ = normalizer.get_output_stats()
normalizer = SingleFieldLinearNormalizer()
normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False)
datan = normalizer.normalize(data)
assert datan.shape == data.shape
assert np.allclose(datan.max(), 1.0, atol=1e-3)
assert np.allclose(datan.min(), 0.0, atol=1e-3)
dataun = normalizer.unnormalize(datan)
assert torch.allclose(data, dataun, atol=1e-7)
data = torch.zeros((100, 10, 9, 2)).uniform_()
normalizer = SingleFieldLinearNormalizer()
normalizer.fit(data, mode="gaussian", last_n_dims=0)
datan = normalizer.normalize(data)
assert datan.shape == data.shape
assert np.allclose(datan.mean(), 0.0, atol=1e-3)
assert np.allclose(datan.std(), 1.0, atol=1e-3)
dataun = normalizer.unnormalize(datan)
assert torch.allclose(data, dataun, atol=1e-7)
# dict
data = torch.zeros((100, 10, 9, 2)).uniform_()
data[..., 0, 0] = 0
normalizer = LinearNormalizer()
normalizer.fit(data, mode="limits", last_n_dims=2)
datan = normalizer.normalize(data)
assert datan.shape == data.shape
assert np.allclose(datan.max(), 1.0)
assert np.allclose(datan.min(), -1.0)
dataun = normalizer.unnormalize(datan)
assert torch.allclose(data, dataun, atol=1e-7)
_ = normalizer.get_input_stats()
_ = normalizer.get_output_stats()
data = {
"obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512,
"action": torch.zeros((1000, 128, 2)).uniform_() * 512,
}
normalizer = LinearNormalizer()
normalizer.fit(data)
datan = normalizer.normalize(data)
dataun = normalizer.unnormalize(datan)
for key in data:
assert torch.allclose(data[key], dataun[key], atol=1e-4)
_ = normalizer.get_input_stats()
_ = normalizer.get_output_stats()
state_dict = normalizer.state_dict()
n = LinearNormalizer()
n.load_state_dict(state_dict)
datan = n.normalize(data)
dataun = n.unnormalize(datan)
for key in data:
assert torch.allclose(data[key], dataun[key], atol=1e-4)

View File

@ -0,0 +1,19 @@
import math
import torch
import torch.nn as nn
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb

View File

@ -0,0 +1,971 @@
"""
A collection of utilities for working with nested tensor structures consisting
of numpy arrays and torch tensors.
"""
import collections
import numpy as np
import torch
def recursive_dict_list_tuple_apply(x, type_func_dict):
"""
Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
{data_type: function_to_apply}.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
type_func_dict (dict): a mapping from data types to the functions to be
applied for each data type.
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
assert list not in type_func_dict
assert tuple not in type_func_dict
assert dict not in type_func_dict
if isinstance(x, (dict, collections.OrderedDict)):
new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {}
for k, v in x.items():
new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
return new_x
elif isinstance(x, (list, tuple)):
ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
if isinstance(x, tuple):
ret = tuple(ret)
return ret
else:
for t, f in type_func_dict.items():
if isinstance(x, t):
return f(x)
else:
raise NotImplementedError("Cannot handle data type %s" % str(type(x)))
def map_tensor(x, func):
"""
Apply function @func to torch.Tensor objects in a nested dictionary or
list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
func (function): function to apply to each tensor
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: func,
type(None): lambda x: x,
},
)
def map_ndarray(x, func):
"""
Apply function @func to np.ndarray objects in a nested dictionary or
list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
func (function): function to apply to each array
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
np.ndarray: func,
type(None): lambda x: x,
},
)
def map_tensor_ndarray(x, tensor_func, ndarray_func):
"""
Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
np.ndarray objects in a nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
tensor_func (function): function to apply to each tensor
ndarray_Func (function): function to apply to each array
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: tensor_func,
np.ndarray: ndarray_func,
type(None): lambda x: x,
},
)
def clone(x):
"""
Clones all torch tensors and numpy arrays in nested dictionary or list
or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.clone(),
np.ndarray: lambda x: x.copy(),
type(None): lambda x: x,
},
)
def detach(x):
"""
Detaches all torch tensors in nested dictionary or list
or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.detach(),
},
)
def to_batch(x):
"""
Introduces a leading batch dimension of 1 for all torch tensors and numpy
arrays in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x[None, ...],
np.ndarray: lambda x: x[None, ...],
type(None): lambda x: x,
},
)
def to_sequence(x):
"""
Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
arrays in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x[:, None, ...],
np.ndarray: lambda x: x[:, None, ...],
type(None): lambda x: x,
},
)
def index_at_time(x, ind):
"""
Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
ind (int): index
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x[:, ind, ...],
np.ndarray: lambda x: x[:, ind, ...],
type(None): lambda x: x,
},
)
def unsqueeze(x, dim):
"""
Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
dim (int): dimension
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.unsqueeze(dim=dim),
np.ndarray: lambda x: np.expand_dims(x, axis=dim),
type(None): lambda x: x,
},
)
def contiguous(x):
"""
Makes all torch tensors and numpy arrays contiguous in nested dictionary or
list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.contiguous(),
np.ndarray: lambda x: np.ascontiguousarray(x),
type(None): lambda x: x,
},
)
def to_device(x, device):
"""
Sends all torch tensors in nested dictionary or list or tuple to device
@device, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
device (torch.Device): device to send tensors to
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x, d=device: x.to(d),
type(None): lambda x: x,
},
)
def to_tensor(x):
"""
Converts all numpy arrays in nested dictionary or list or tuple to
torch tensors (and leaves existing torch Tensors as-is), and returns
a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x,
np.ndarray: lambda x: torch.from_numpy(x),
type(None): lambda x: x,
},
)
def to_numpy(x):
"""
Converts all torch tensors in nested dictionary or list or tuple to
numpy (and leaves existing numpy arrays as-is), and returns
a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
def f(tensor):
if tensor.is_cuda:
return tensor.detach().cpu().numpy()
else:
return tensor.detach().numpy()
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: f,
np.ndarray: lambda x: x,
type(None): lambda x: x,
},
)
def to_list(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to a list, and returns a new nested structure. Useful for
json encoding.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
def f(tensor):
if tensor.is_cuda:
return tensor.detach().cpu().numpy().tolist()
else:
return tensor.detach().numpy().tolist()
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: f,
np.ndarray: lambda x: x.tolist(),
type(None): lambda x: x,
},
)
def to_float(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to float type entries, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.float(),
np.ndarray: lambda x: x.astype(np.float32),
type(None): lambda x: x,
},
)
def to_uint8(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to uint8 type entries, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.byte(),
np.ndarray: lambda x: x.astype(np.uint8),
type(None): lambda x: x,
},
)
def to_torch(x, device):
"""
Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
torch tensors on device @device and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
device (torch.Device): device to send tensors to
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return to_device(to_float(to_tensor(x)), device)
def to_one_hot_single(tensor, num_class):
"""
Convert tensor to one-hot representation, assuming a certain number of total class labels.
Args:
tensor (torch.Tensor): tensor containing integer labels
num_class (int): number of classes
Returns:
x (torch.Tensor): tensor containing one-hot representation of labels
"""
x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device)
x.scatter_(-1, tensor.unsqueeze(-1), 1)
return x
def to_one_hot(tensor, num_class):
"""
Convert all tensors in nested dictionary or list or tuple to one-hot representation,
assuming a certain number of total class labels.
Args:
tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
num_class (int): number of classes
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc))
def flatten_single(x, begin_axis=1):
"""
Flatten a tensor in all dimensions from @begin_axis onwards.
Args:
x (torch.Tensor): tensor to flatten
begin_axis (int): which axis to flatten from
Returns:
y (torch.Tensor): flattened tensor
"""
fixed_size = x.size()[:begin_axis]
_s = list(fixed_size) + [-1]
return x.reshape(*_s)
def flatten(x, begin_axis=1):
"""
Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): which axis to flatten from
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
},
)
def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
"""
Reshape selected dimensions in a tensor to a target dimension.
Args:
x (torch.Tensor): tensor to reshape
begin_axis (int): begin dimension
end_axis (int): end dimension
target_dims (tuple or list): target shape for the range of dimensions
(@begin_axis, @end_axis)
Returns:
y (torch.Tensor): reshaped tensor
"""
assert begin_axis <= end_axis
assert begin_axis >= 0
assert end_axis < len(x.shape)
assert isinstance(target_dims, (tuple, list))
s = x.shape
final_s = []
for i in range(len(s)):
if i == begin_axis:
final_s.extend(target_dims)
elif i < begin_axis or i > end_axis:
final_s.append(s[i])
return x.reshape(*final_s)
def reshape_dimensions(x, begin_axis, end_axis, target_dims):
"""
Reshape selected dimensions for all tensors in nested dictionary or list or tuple
to a target dimension.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): begin dimension
end_axis (int): end dimension
target_dims (tuple or list): target shape for the range of dimensions
(@begin_axis, @end_axis)
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=t
),
np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=t
),
type(None): lambda x: x,
},
)
def join_dimensions(x, begin_axis, end_axis):
"""
Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
all tensors in nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): begin dimension
end_axis (int): end dimension
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=[-1]
),
np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=[-1]
),
type(None): lambda x: x,
},
)
def expand_at_single(x, size, dim):
"""
Expand a tensor at a single dimension @dim by @size
Args:
x (torch.Tensor): input tensor
size (int): size to expand
dim (int): dimension to expand
Returns:
y (torch.Tensor): expanded tensor
"""
assert dim < x.ndimension()
assert x.shape[dim] == 1
expand_dims = [-1] * x.ndimension()
expand_dims[dim] = size
return x.expand(*expand_dims)
def expand_at(x, size, dim):
"""
Expand all tensors in nested dictionary or list or tuple at a single
dimension @dim by @size.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size to expand
dim (int): dimension to expand
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
def unsqueeze_expand_at(x, size, dim):
"""
Unsqueeze and expand a tensor at a dimension @dim by @size.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size to expand
dim (int): dimension to unsqueeze and expand
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
x = unsqueeze(x, dim)
return expand_at(x, size, dim)
def repeat_by_expand_at(x, repeats, dim):
"""
Repeat a dimension by combining expand and reshape operations.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
repeats (int): number of times to repeat the target dimension
dim (int): dimension to repeat on
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
x = unsqueeze_expand_at(x, repeats, dim + 1)
return join_dimensions(x, dim, dim + 1)
def named_reduce_single(x, reduction, dim):
"""
Reduce tensor at a dimension by named reduction functions.
Args:
x (torch.Tensor): tensor to be reduced
reduction (str): one of ["sum", "max", "mean", "flatten"]
dim (int): dimension to be reduced (or begin axis for flatten)
Returns:
y (torch.Tensor): reduced tensor
"""
assert x.ndimension() > dim
assert reduction in ["sum", "max", "mean", "flatten"]
if reduction == "flatten":
x = flatten(x, begin_axis=dim)
elif reduction == "max":
x = torch.max(x, dim=dim)[0] # [B, D]
elif reduction == "sum":
x = torch.sum(x, dim=dim)
else:
x = torch.mean(x, dim=dim)
return x
def named_reduce(x, reduction, dim):
"""
Reduces all tensors in nested dictionary or list or tuple at a dimension
using a named reduction function.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
reduction (str): one of ["sum", "max", "mean", "flatten"]
dim (int): dimension to be reduced (or begin axis for flatten)
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
"""
This function indexes out a target dimension of a tensor in a structured way,
by allowing a different value to be selected for each member of a flat index
tensor (@indices) corresponding to a source dimension. This can be interpreted
as moving along the source dimension, using the corresponding index value
in @indices to select values for all other dimensions outside of the
source and target dimensions. A common use case is to gather values
in target dimension 1 for each batch member (target dimension 0).
Args:
x (torch.Tensor): tensor to gather values for
target_dim (int): dimension to gather values along
source_dim (int): dimension to hold constant and use for gathering values
from the other dimensions
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
@source_dim
Returns:
y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
"""
assert len(indices.shape) == 1
assert x.shape[source_dim] == indices.shape[0]
# unsqueeze in all dimensions except the source dimension
new_shape = [1] * x.ndimension()
new_shape[source_dim] = -1
indices = indices.reshape(*new_shape)
# repeat in all dimensions - but preserve shape of source dimension,
# and make sure target_dimension has singleton dimension
expand_shape = list(x.shape)
expand_shape[source_dim] = -1
expand_shape[target_dim] = 1
indices = indices.expand(*expand_shape)
out = x.gather(dim=target_dim, index=indices)
return out.squeeze(target_dim)
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
"""
Apply @gather_along_dim_with_dim_single to all tensors in a nested
dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
target_dim (int): dimension to gather values along
source_dim (int): dimension to hold constant and use for gathering values
from the other dimensions
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
@source_dim
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(
x, lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i)
)
def gather_sequence_single(seq, indices):
"""
Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
the batch given an index for each sequence.
Args:
seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
indices (torch.Tensor): tensor indices of shape [B]
Return:
y (torch.Tensor): indexed tensor of shape [B, ....]
"""
return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices)
def gather_sequence(seq, indices):
"""
Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
for tensors with leading dimensions [B, T, ...].
Args:
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
indices (torch.Tensor): tensor indices of shape [B]
Returns:
y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
"""
return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices)
def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None):
"""
Pad input tensor or array @seq in the time dimension (dimension 1).
Args:
seq (np.ndarray or torch.Tensor): sequence to be padded
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
batched (bool): if sequence has the batch dimension
pad_same (bool): if pad by duplicating
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
Returns:
padded sequence (np.ndarray or torch.Tensor)
"""
assert isinstance(seq, (np.ndarray, torch.Tensor))
assert pad_same or pad_values is not None
if pad_values is not None:
assert isinstance(pad_values, float)
repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave
concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like
seq_dim = 1 if batched else 0
begin_pad = []
end_pad = []
if padding[0] > 0:
pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
begin_pad.append(repeat_func(pad, padding[0], seq_dim))
if padding[1] > 0:
pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
end_pad.append(repeat_func(pad, padding[1], seq_dim))
return concat_func(begin_pad + [seq] + end_pad, seq_dim)
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
"""
Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
Args:
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
batched (bool): if sequence has the batch dimension
pad_same (bool): if pad by duplicating
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
Returns:
padded sequence (dict or list or tuple)
"""
return recursive_dict_list_tuple_apply(
seq,
{
torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
x, p, b, ps, pv
),
np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
x, p, b, ps, pv
),
type(None): lambda x: x,
},
)
def assert_size_at_dim_single(x, size, dim, msg):
"""
Ensure that array or tensor @x has size @size in dim @dim.
Args:
x (np.ndarray or torch.Tensor): input array or tensor
size (int): size that tensors should have at @dim
dim (int): dimension to check
msg (str): text to display if assertion fails
"""
assert x.shape[dim] == size, msg
def assert_size_at_dim(x, size, dim, msg):
"""
Ensure that arrays and tensors in nested dictionary or list or tuple have
size @size in dim @dim.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size that tensors should have at @dim
dim (int): dimension to check
"""
map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
def get_shape(x):
"""
Get all shapes of arrays and tensors in nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple that contains each array or
tensor's shape
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.shape,
np.ndarray: lambda x: x.shape,
type(None): lambda x: x,
},
)
def list_of_flat_dict_to_dict_of_list(list_of_dict):
"""
Helper function to go from a list of flat dictionaries to a dictionary of lists.
By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
floats, etc.
Args:
list_of_dict (list): list of flat dictionaries
Returns:
dict_of_list (dict): dictionary of lists
"""
assert isinstance(list_of_dict, list)
dic = collections.OrderedDict()
for i in range(len(list_of_dict)):
for k in list_of_dict[i]:
if k not in dic:
dic[k] = []
dic[k].append(list_of_dict[i][k])
return dic
def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""):
"""
Flatten a nested dict or list to a list.
For example, given a dict
{
a: 1
b: {
c: 2
}
c: 3
}
the function would return [(a, 1), (b_c, 2), (c, 3)]
Args:
d (dict, list): a nested dict or list to be flattened
parent_key (str): recursion helper
sep (str): separator for nesting keys
item_key (str): recursion helper
Returns:
list: a list of (key, value) tuples
"""
items = []
if isinstance(d, (tuple, list)):
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
for i, v in enumerate(d):
items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
return items
elif isinstance(d, dict):
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
for k, v in d.items():
assert isinstance(k, str)
items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
return items
else:
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
return [(new_key, d)]
def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs):
"""
Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
outputs to [B, T, ...].
Args:
inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
op: a layer op that accepts inputs
activation: activation to apply at the output
inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
inputs_as_args (bool) whether to feed input as a args list to the op
kwargs (dict): other kwargs to supply to the op
Returns:
outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
"""
batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
inputs = join_dimensions(inputs, 0, 1)
if inputs_as_kwargs:
outputs = op(**inputs, **kwargs)
elif inputs_as_args:
outputs = op(*inputs, **kwargs)
else:
outputs = op(inputs, **kwargs)
if activation is not None:
outputs = map_tensor(outputs, activation)
outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len))
return outputs

View File

@ -4,10 +4,10 @@ import time
import hydra
import torch
import torch.nn as nn
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
from .multi_image_obs_encoder import MultiImageObsEncoder
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
class DiffusionPolicy(nn.Module):

View File

@ -0,0 +1,76 @@
from typing import Callable, Dict
import torch
import torch.nn as nn
import torchvision
def get_resnet(name, weights=None, **kwargs):
"""
name: resnet18, resnet34, resnet50
weights: "IMAGENET1K_V1", "r3m"
"""
# load r3m weights
if (weights == "r3m") or (weights == "R3M"):
return get_r3m(name=name, **kwargs)
func = getattr(torchvision.models, name)
resnet = func(weights=weights, **kwargs)
resnet.fc = torch.nn.Identity()
return resnet
def get_r3m(name, **kwargs):
"""
name: resnet18, resnet34, resnet50
"""
import r3m
r3m.device = "cpu"
model = r3m.load_r3m(name)
r3m_model = model.module
resnet_model = r3m_model.convnet
resnet_model = resnet_model.to("cpu")
return resnet_model
def dict_apply(
x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor]
) -> Dict[str, torch.Tensor]:
result = {}
for key, value in x.items():
if isinstance(value, dict):
result[key] = dict_apply(value, func)
else:
result[key] = func(value)
return result
def replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
predicate: Return true if the module is to be replaced.
func: Return new module to use.
"""
if predicate(root_module):
return func(root_module)
bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parent, k in bn_list:
parent_module = root_module
if len(parent) > 0:
parent_module = root_module.get_submodule(".".join(parent))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
assert len(bn_list) == 0
return root_module

View File

@ -0,0 +1,614 @@
from __future__ import annotations
import math
import numbers
import os
from functools import cached_property
import numcodecs
import numpy as np
import zarr
def check_chunks_compatible(chunks: tuple, shape: tuple):
assert len(shape) == len(chunks)
for c in chunks:
assert isinstance(c, numbers.Integral)
assert c > 0
def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
old_arr = group[name]
if chunks is None:
chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
check_chunks_compatible(chunks, old_arr.shape)
if compressor is None:
compressor = old_arr.compressor
if (chunks == old_arr.chunks) and (compressor == old_arr.compressor):
# no change
return old_arr
# rechunk recompress
group.move(name, tmp_key)
old_arr = group[tmp_key]
n_copied, n_skipped, n_bytes_copied = zarr.copy(
source=old_arr,
dest=group,
name=name,
chunks=chunks,
compressor=compressor,
)
del group[tmp_key]
arr = group[name]
return arr
def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=None):
"""
Common shapes
T,D
T,N,D
T,H,W,C
T,N,H,W,C
"""
itemsize = np.dtype(dtype).itemsize
# reversed
rshape = list(shape[::-1])
if max_chunk_length is not None:
rshape[-1] = int(max_chunk_length)
split_idx = len(shape) - 1
for i in range(len(shape) - 1):
this_chunk_bytes = itemsize * np.prod(rshape[:i])
next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
split_idx = i
rchunks = rshape[:split_idx]
item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
this_max_chunk_length = rshape[split_idx]
next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
rchunks.append(next_chunk_length)
len_diff = len(shape) - len(rchunks)
rchunks.extend([1] * len_diff)
chunks = tuple(rchunks[::-1])
# print(np.prod(chunks) * itemsize / target_chunk_bytes)
return chunks
class ReplayBuffer:
"""
Zarr-based temporal datastructure.
Assumes first dimension to be time. Only chunk in time dimension.
"""
def __init__(self, root: zarr.Group | dict[str, dict]):
"""
Dummy constructor. Use copy_from* and create_from* class methods instead.
"""
assert "data" in root
assert "meta" in root
assert "episode_ends" in root["meta"]
for value in root["data"].values():
assert value.shape[0] == root["meta"]["episode_ends"][-1]
self.root = root
# ============= create constructors ===============
@classmethod
def create_empty_zarr(cls, storage=None, root=None):
if root is None:
if storage is None:
storage = zarr.MemoryStore()
root = zarr.group(store=storage)
root.require_group("data", overwrite=False)
meta = root.require_group("meta", overwrite=False)
if "episode_ends" not in meta:
meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
return cls(root=root)
@classmethod
def create_empty_numpy(cls):
root = {"data": {}, "meta": {"episode_ends": np.zeros((0,), dtype=np.int64)}}
return cls(root=root)
@classmethod
def create_from_group(cls, group, **kwargs):
if "data" not in group:
# create from stratch
buffer = cls.create_empty_zarr(root=group, **kwargs)
else:
# already exist
buffer = cls(root=group, **kwargs)
return buffer
@classmethod
def create_from_path(cls, zarr_path, mode="r", **kwargs):
"""
Open a on-disk zarr directly (for dataset larger than memory).
Slower.
"""
group = zarr.open(os.path.expanduser(zarr_path), mode)
return cls.create_from_group(group, **kwargs)
# ============= copy constructors ===============
@classmethod
def copy_from_store(
cls,
src_store,
store=None,
keys=None,
chunks: dict[str, tuple] | None = None,
compressors: dict | str | numcodecs.abc.Codec | None = None,
if_exists="replace",
**kwargs,
):
"""
Load to memory.
"""
src_root = zarr.group(src_store)
if chunks is None:
chunks = {}
if compressors is None:
compressors = {}
root = None
if store is None:
# numpy backend
meta = {}
for key, value in src_root["meta"].items():
if len(value.shape) == 0:
meta[key] = np.array(value)
else:
meta[key] = value[:]
if keys is None:
keys = src_root["data"].keys()
data = {}
for key in keys:
arr = src_root["data"][key]
data[key] = arr[:]
root = {"meta": meta, "data": data}
else:
root = zarr.group(store=store)
# copy without recompression
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
)
data_group = root.create_group("data", overwrite=True)
if keys is None:
keys = src_root["data"].keys()
for key in keys:
value = src_root["data"][key]
cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
if cks == value.chunks and cpr == value.compressor:
# copy without recompression
this_path = "/data/" + key
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
source=src_store,
dest=store,
source_path=this_path,
dest_path=this_path,
if_exists=if_exists,
)
else:
# copy with recompression
n_copied, n_skipped, n_bytes_copied = zarr.copy(
source=value,
dest=data_group,
name=key,
chunks=cks,
compressor=cpr,
if_exists=if_exists,
)
buffer = cls(root=root)
return buffer
@classmethod
def copy_from_path(
cls,
zarr_path,
backend=None,
store=None,
keys=None,
chunks: dict[str, tuple] | None = None,
compressors: dict | str | numcodecs.abc.Codec | None = None,
if_exists="replace",
**kwargs,
):
"""
Copy a on-disk zarr to in-memory compressed.
Recommended
"""
if chunks is None:
chunks = {}
if compressors is None:
compressors = {}
if backend == "numpy":
print("backend argument is deprecated!")
store = None
group = zarr.open(os.path.expanduser(zarr_path), "r")
return cls.copy_from_store(
src_store=group.store,
store=store,
keys=keys,
chunks=chunks,
compressors=compressors,
if_exists=if_exists,
**kwargs,
)
# ============= save methods ===============
def save_to_store(
self,
store,
chunks: dict[str, tuple] | None = None,
compressors: str | numcodecs.abc.Codec | dict | None = None,
if_exists="replace",
**kwargs,
):
root = zarr.group(store)
if chunks is None:
chunks = {}
if compressors is None:
compressors = {}
if self.backend == "zarr":
# recompression free copy
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
source=self.root.store,
dest=store,
source_path="/meta",
dest_path="/meta",
if_exists=if_exists,
)
else:
meta_group = root.create_group("meta", overwrite=True)
# save meta, no chunking
for key, value in self.root["meta"].items():
_ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)
# save data, chunk
data_group = root.create_group("data", overwrite=True)
for key, value in self.root["data"].items():
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
if isinstance(value, zarr.Array):
if cks == value.chunks and cpr == value.compressor:
# copy without recompression
this_path = "/data/" + key
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
source=self.root.store,
dest=store,
source_path=this_path,
dest_path=this_path,
if_exists=if_exists,
)
else:
# copy with recompression
n_copied, n_skipped, n_bytes_copied = zarr.copy(
source=value,
dest=data_group,
name=key,
chunks=cks,
compressor=cpr,
if_exists=if_exists,
)
else:
# numpy
_ = data_group.array(name=key, data=value, chunks=cks, compressor=cpr)
return store
def save_to_path(
self,
zarr_path,
chunks: dict[str, tuple] | None = None,
compressors: str | numcodecs.abc.Codec | dict | None = None,
if_exists="replace",
**kwargs,
):
if chunks is None:
chunks = {}
if compressors is None:
compressors = {}
store = zarr.DirectoryStore(os.path.expanduser(zarr_path))
return self.save_to_store(
store, chunks=chunks, compressors=compressors, if_exists=if_exists, **kwargs
)
@staticmethod
def resolve_compressor(compressor="default"):
if compressor == "default":
compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
elif compressor == "disk":
compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
return compressor
@classmethod
def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
# allows compressor to be explicitly set to None
cpr = "nil"
if isinstance(compressors, dict):
if key in compressors:
cpr = cls.resolve_compressor(compressors[key])
elif isinstance(array, zarr.Array):
cpr = array.compressor
else:
cpr = cls.resolve_compressor(compressors)
# backup default
if cpr == "nil":
cpr = cls.resolve_compressor("default")
return cpr
@classmethod
def _resolve_array_chunks(cls, chunks: dict | tuple, key, array):
cks = None
if isinstance(chunks, dict):
if key in chunks:
cks = chunks[key]
elif isinstance(array, zarr.Array):
cks = array.chunks
elif isinstance(chunks, tuple):
cks = chunks
else:
raise TypeError(f"Unsupported chunks type {type(chunks)}")
# backup default
if cks is None:
cks = get_optimal_chunks(shape=array.shape, dtype=array.dtype)
# check
check_chunks_compatible(chunks=cks, shape=array.shape)
return cks
# ============= properties =================
@cached_property
def data(self):
return self.root["data"]
@cached_property
def meta(self):
return self.root["meta"]
def update_meta(self, data):
# sanitize data
np_data = {}
for key, value in data.items():
if isinstance(value, np.ndarray):
np_data[key] = value
else:
arr = np.array(value)
if arr.dtype == object:
raise TypeError(f"Invalid value type {type(value)}")
np_data[key] = arr
meta_group = self.meta
if self.backend == "zarr":
for key, value in np_data.items():
_ = meta_group.array(
name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
)
else:
meta_group.update(np_data)
return meta_group
@property
def episode_ends(self):
return self.meta["episode_ends"]
def get_episode_idxs(self):
import numba
numba.jit(nopython=True)
def _get_episode_idxs(episode_ends):
result = np.zeros((episode_ends[-1],), dtype=np.int64)
for i in range(len(episode_ends)):
start = 0
if i > 0:
start = episode_ends[i - 1]
end = episode_ends[i]
for idx in range(start, end):
result[idx] = i
return result
return _get_episode_idxs(self.episode_ends)
@property
def backend(self):
backend = "numpy"
if isinstance(self.root, zarr.Group):
backend = "zarr"
return backend
# =========== dict-like API ==============
def __repr__(self) -> str:
if self.backend == "zarr":
return str(self.root.tree())
else:
return super().__repr__()
def keys(self):
return self.data.keys()
def values(self):
return self.data.values()
def items(self):
return self.data.items()
def __getitem__(self, key):
return self.data[key]
def __contains__(self, key):
return key in self.data
# =========== our API ==============
@property
def n_steps(self):
if len(self.episode_ends) == 0:
return 0
return self.episode_ends[-1]
@property
def n_episodes(self):
return len(self.episode_ends)
@property
def chunk_size(self):
if self.backend == "zarr":
return next(iter(self.data.arrays()))[-1].chunks[0]
return None
@property
def episode_lengths(self):
ends = self.episode_ends[:]
ends = np.insert(ends, 0, 0)
lengths = np.diff(ends)
return lengths
def add_episode(
self,
data: dict[str, np.ndarray],
chunks: dict[str, tuple] | None = None,
compressors: str | numcodecs.abc.Codec | dict | None = None,
):
if chunks is None:
chunks = {}
if compressors is None:
compressors = {}
assert len(data) > 0
is_zarr = self.backend == "zarr"
curr_len = self.n_steps
episode_length = None
for value in data.values():
assert len(value.shape) >= 1
if episode_length is None:
episode_length = len(value)
else:
assert episode_length == len(value)
new_len = curr_len + episode_length
for key, value in data.items():
new_shape = (new_len,) + value.shape[1:]
# create array
if key not in self.data:
if is_zarr:
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
arr = self.data.zeros(
name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
)
else:
# copy data to prevent modify
arr = np.zeros(shape=new_shape, dtype=value.dtype)
self.data[key] = arr
else:
arr = self.data[key]
assert value.shape[1:] == arr.shape[1:]
# same method for both zarr and numpy
if is_zarr:
arr.resize(new_shape)
else:
arr.resize(new_shape, refcheck=False)
# copy data
arr[-value.shape[0] :] = value
# append to episode ends
episode_ends = self.episode_ends
if is_zarr:
episode_ends.resize(episode_ends.shape[0] + 1)
else:
episode_ends.resize(episode_ends.shape[0] + 1, refcheck=False)
episode_ends[-1] = new_len
# rechunk
if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))
def drop_episode(self):
is_zarr = self.backend == "zarr"
episode_ends = self.episode_ends[:].copy()
assert len(episode_ends) > 0
start_idx = 0
if len(episode_ends) > 1:
start_idx = episode_ends[-2]
for value in self.data.values():
new_shape = (start_idx,) + value.shape[1:]
if is_zarr:
value.resize(new_shape)
else:
value.resize(new_shape, refcheck=False)
if is_zarr:
self.episode_ends.resize(len(episode_ends) - 1)
else:
self.episode_ends.resize(len(episode_ends) - 1, refcheck=False)
def pop_episode(self):
assert self.n_episodes > 0
episode = self.get_episode(self.n_episodes - 1, copy=True)
self.drop_episode()
return episode
def extend(self, data):
self.add_episode(data)
def get_episode(self, idx, copy=False):
idx = list(range(len(self.episode_ends)))[idx]
start_idx = 0
if idx > 0:
start_idx = self.episode_ends[idx - 1]
end_idx = self.episode_ends[idx]
result = self.get_steps_slice(start_idx, end_idx, copy=copy)
return result
def get_episode_slice(self, idx):
start_idx = 0
if idx > 0:
start_idx = self.episode_ends[idx - 1]
end_idx = self.episode_ends[idx]
return slice(start_idx, end_idx)
def get_steps_slice(self, start, stop, step=None, copy=False):
_slice = slice(start, stop, step)
result = {}
for key, value in self.data.items():
x = value[_slice]
if copy and isinstance(value, np.ndarray):
x = x.copy()
result[key] = x
return result
# =========== chunking =============
def get_chunks(self) -> dict:
assert self.backend == "zarr"
chunks = {}
for key, value in self.data.items():
chunks[key] = value.chunks
return chunks
def set_chunks(self, chunks: dict):
assert self.backend == "zarr"
for key, value in chunks.items():
if key in self.data:
arr = self.data[key]
if value != arr.chunks:
check_chunks_compatible(chunks=value, shape=arr.shape)
rechunk_recompress_array(self.data, key, chunks=value)
def get_compressors(self) -> dict:
assert self.backend == "zarr"
compressors = {}
for key, value in self.data.items():
compressors[key] = value.compressor
return compressors
def set_compressors(self, compressors: dict):
assert self.backend == "zarr"
for key, value in compressors.items():
if key in self.data:
arr = self.data[key]
compressor = self.resolve_compressor(value)
if compressor != arr.compressor:
rechunk_recompress_array(self.data, key, compressor=compressor)

View File

@ -1,6 +1,6 @@
def make_policy(cfg):
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc import TDMPC
from lerobot.common.policies.tdmpc.policy import TDMPC
policy = TDMPC(cfg.policy, cfg.device)
elif cfg.policy.name == "diffusion":

View File

@ -8,7 +8,7 @@ import numpy as np
import torch
import torch.nn as nn
import lerobot.common.policies.tdmpc_helper as h
import lerobot.common.policies.tdmpc.helper as h
FIRST_FRAME = 0

View File

@ -74,7 +74,6 @@ noise_scheduler:
prediction_type: epsilon # or sample
obs_encoder:
# _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
shape_meta: ${shape_meta}
# resize_shape: null
# crop_shape: [76, 76]
@ -85,12 +84,12 @@ obs_encoder:
imagenet_norm: True
rgb_model:
_target_: diffusion_policy.model.vision.model_getter.get_resnet
_target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet
name: resnet18
weights: null
ema:
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
_target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
update_after_step: 0
inv_gamma: 1.0
power: 0.75

17
poetry.lock generated
View File

@ -477,21 +477,6 @@ test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisi
torch = ["accelerate (>=0.11.0)", "torch (>=1.4,<2.2.0)"]
training = ["Jinja2", "accelerate (>=0.11.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"]
[[package]]
name = "diffusion_policy"
version = "0.0.0"
description = ""
optional = false
python-versions = "*"
files = []
develop = false
[package.source]
type = "git"
url = "https://github.com/real-stanford/diffusion_policy"
reference = "HEAD"
resolved_reference = "548a52bbb105518058e27bf34dcf90bf6f73681a"
[[package]]
name = "distlib"
version = "0.3.8"
@ -3140,4 +3125,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "9c3e86956dd11bc8d7823e5e6c5e74a073051b495f71f96179113d99791f7ca0"
content-hash = "c4d83579aed1c8c2e54cad7c8ec81b95a09ab8faff74fc9a4cb20bd00e4ddec6"

View File

@ -45,7 +45,6 @@ mujoco = "^3.1.2"
mujoco-py = "^2.1.2.14"
gym = "^0.26.2"
opencv-python = "^4.9.0.80"
diffusion-policy = {git = "https://github.com/real-stanford/diffusion_policy"}
diffusers = "^0.26.3"
torchvision = "^0.17.1"
h5py = "^3.10.0"

View File

@ -3,7 +3,7 @@ from tensordict import TensorDict
from torchrl.envs.utils import check_env_specs, step_mdp
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.pusht import PushtEnv
from lerobot.common.envs.pusht.env import PushtEnv
from lerobot.common.envs.simxarm import SimxarmEnv
from .utils import init_config