Merge pull request #16 from Cadene/user/aliberts/2024_03_09_integrate_diffusion_policy

Integrate diffusion policy

Commit d4ea4f0ad1
CI workflow:

@@ -69,10 +69,7 @@ jobs:
           key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
       - name: Install dependencies
         if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
-        run: |
-          poetry install --no-interaction --no-root
-          git clone https://github.com/real-stanford/diffusion_policy
-          cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/
+        run: poetry install --no-interaction --no-root
       - name: Save cached venv
         if: |
           steps.restore-dependencies-cache.outputs.cache-hit != 'true' &&
.gitignore:

@@ -1,6 +1,3 @@
-# Custom
-diffusion_policy
-
 # Logging
 logs
 tmp
Pre-commit config:

@@ -1,4 +1,4 @@
-exclude: ^(data/|tests/|diffusion_policy/)
+exclude: ^(data/|tests/)
 default_language_version:
   python: python3.10
 repos:
README:

@@ -24,12 +24,6 @@ mkdir ~/tmp
 export TMPDIR='~/tmp'
 ```

-Install `diffusion_policy` #HACK
-```
-# from this directory
-git clone https://github.com/real-stanford/diffusion_policy
-cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/
-```

 ## Usage
PushT dataset module:

@@ -8,8 +8,6 @@ import pymunk
 import torch
 import torchrl
 import tqdm
-from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
-from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
 from tensordict import TensorDict
 from torchrl.data.replay_buffers.samplers import SliceSampler
 from torchrl.data.replay_buffers.storages import TensorStorage
@@ -17,11 +15,12 @@ from torchrl.data.replay_buffers.writers import Writer

 from lerobot.common.datasets.abstract import AbstractExperienceReplay
 from lerobot.common.datasets.utils import download_and_extract_zip
+from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
+from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer

 # as defined in env
 SUCCESS_THRESHOLD = 0.95  # 95% coverage,

-DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
 PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
 PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
@@ -49,8 +48,10 @@ def add_tee(
     angle,
     scale=30,
     color="LightSlateGray",
-    mask=DEFAULT_TEE_MASK,
+    mask=None,
 ):
+    if mask is None:
+        mask = pymunk.ShapeFilter.ALL_MASKS()
     mass = 1
     length = 4
     vertices1 = [
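This hunk replaces the module-level `DEFAULT_TEE_MASK` constant (removed in the previous hunk) with a `None` sentinel resolved inside the function, so the default is computed at call time rather than at import time. A minimal sketch of the pattern, with the function body elided:

```python
import pymunk

def add_tee(position, angle, scale=30, color="LightSlateGray", mask=None):
    # Resolve the default at call time; avoids a module-level
    # constant that runs pymunk code at import time.
    if mask is None:
        mask = pymunk.ShapeFilter.ALL_MASKS()
    ...
```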
Env factory (`make_env`):

@@ -18,7 +18,7 @@ def make_env(cfg, transform=None):
         kwargs["task"] = cfg.env.task
         clsfunc = SimxarmEnv
     elif cfg.env.name == "pusht":
-        from lerobot.common.envs.pusht import PushtEnv
+        from lerobot.common.envs.pusht.env import PushtEnv

         # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
PushtEnv wrapper (lerobot.common.envs.pusht.env):

@@ -17,7 +17,6 @@ from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
 from lerobot.common.utils import set_seed

 _has_gym = importlib.util.find_spec("gym") is not None
-_has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _has_gym


 class PushtEnv(EnvBase):
@@ -45,17 +44,15 @@ class PushtEnv(EnvBase):
         if from_pixels:
             assert image_size

-            if not _has_diffpolicy:
-                raise ImportError("Cannot import diffusion_policy.")
             if not _has_gym:
                 raise ImportError("Cannot import gym.")

             # TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
-            # from diffusion_policy.env.pusht.pusht_env import PushTEnv
+            # from lerobot.common.envs.pusht.pusht_env import PushTEnv

             if not from_pixels:
                 raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
-            from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
+            from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv

             self._env = PushTImageEnv(render_size=self.image_size)
New file (lerobot.common.envs.pusht.pusht_env):

@@ -0,0 +1,378 @@
import collections

import cv2
import gym
import numpy as np
import pygame
import pymunk
import pymunk.pygame_util
import shapely.geometry as sg
import skimage.transform as st
from gym import spaces
from pymunk.vec2d import Vec2d

from lerobot.common.envs.pusht.pymunk_override import DrawOptions


def pymunk_to_shapely(body, shapes):
    geoms = []
    for shape in shapes:
        if isinstance(shape, pymunk.shapes.Poly):
            verts = [body.local_to_world(v) for v in shape.get_vertices()]
            verts += [verts[0]]
            geoms.append(sg.Polygon(verts))
        else:
            raise RuntimeError(f"Unsupported shape type {type(shape)}")
    geom = sg.MultiPolygon(geoms)
    return geom


class PushTEnv(gym.Env):
    metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
    reward_range = (0.0, 1.0)

    def __init__(
        self,
        legacy=False,
        block_cog=None,
        damping=None,
        render_action=True,
        render_size=96,
        reset_to_state=None,
    ):
        self._seed = None
        self.seed()
        self.window_size = ws = 512  # The size of the PyGame window
        self.render_size = render_size
        self.sim_hz = 100
        # Local controller params.
        self.k_p, self.k_v = 100, 20  # PD control.
        self.control_hz = self.metadata["video.frames_per_second"]
        # legacy set_state for data compatibility
        self.legacy = legacy

        # agent_pos, block_pos, block_angle
        self.observation_space = spaces.Box(
            low=np.array([0, 0, 0, 0, 0], dtype=np.float64),
            high=np.array([ws, ws, ws, ws, np.pi * 2], dtype=np.float64),
            shape=(5,),
            dtype=np.float64,
        )

        # positional goal for agent
        self.action_space = spaces.Box(
            low=np.array([0, 0], dtype=np.float64),
            high=np.array([ws, ws], dtype=np.float64),
            shape=(2,),
            dtype=np.float64,
        )

        self.block_cog = block_cog
        self.damping = damping
        self.render_action = render_action

        """
        If human-rendering is used, `self.window` will be a reference
        to the window that we draw to. `self.clock` will be a clock that is used
        to ensure that the environment is rendered at the correct framerate in
        human-mode. They will remain `None` until human-mode is used for the
        first time.
        """
        self.window = None
        self.clock = None
        self.screen = None

        self.space = None
        self.teleop = None
        self.render_buffer = None
        self.latest_action = None
        self.reset_to_state = reset_to_state

    def reset(self):
        seed = self._seed
        self._setup()
        if self.block_cog is not None:
            self.block.center_of_gravity = self.block_cog
        if self.damping is not None:
            self.space.damping = self.damping

        # use legacy RandomState for compatibility
        state = self.reset_to_state
        if state is None:
            rs = np.random.RandomState(seed=seed)
            state = np.array(
                [
                    rs.randint(50, 450),
                    rs.randint(50, 450),
                    rs.randint(100, 400),
                    rs.randint(100, 400),
                    rs.randn() * 2 * np.pi - np.pi,
                ]
            )
        self._set_state(state)

        observation = self._get_obs()
        return observation

    def step(self, action):
        dt = 1.0 / self.sim_hz
        self.n_contact_points = 0
        n_steps = self.sim_hz // self.control_hz
        if action is not None:
            self.latest_action = action
            for _ in range(n_steps):
                # Step PD control.
                # self.agent.velocity = self.k_p * (act - self.agent.position)  # P control works too.
                acceleration = self.k_p * (action - self.agent.position) + self.k_v * (
                    Vec2d(0, 0) - self.agent.velocity
                )
                self.agent.velocity += acceleration * dt

                # Step physics.
                self.space.step(dt)

        # compute reward
        goal_body = self._get_goal_pose_body(self.goal_pose)
        goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
        block_geom = pymunk_to_shapely(self.block, self.block.shapes)

        intersection_area = goal_geom.intersection(block_geom).area
        goal_area = goal_geom.area
        coverage = intersection_area / goal_area
        reward = np.clip(coverage / self.success_threshold, 0, 1)
        done = coverage > self.success_threshold

        observation = self._get_obs()
        info = self._get_info()

        return observation, reward, done, info

    def render(self, mode):
        return self._render_frame(mode)

    def teleop_agent(self):
        TeleopAgent = collections.namedtuple("TeleopAgent", ["act"])

        def act(obs):
            act = None
            mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
            if self.teleop or (mouse_position - self.agent.position).length < 30:
                self.teleop = True
                act = mouse_position
            return act

        return TeleopAgent(act)

    def _get_obs(self):
        obs = np.array(
            tuple(self.agent.position) + tuple(self.block.position) + (self.block.angle % (2 * np.pi),)
        )
        return obs

    def _get_goal_pose_body(self, pose):
        mass = 1
        inertia = pymunk.moment_for_box(mass, (50, 100))
        body = pymunk.Body(mass, inertia)
        # preserving the legacy assignment order for compatibility
        # the order here doesn't matter somehow, maybe because CoM is aligned with body origin
        body.position = pose[:2].tolist()
        body.angle = pose[2]
        return body

    def _get_info(self):
        n_steps = self.sim_hz // self.control_hz
        n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
        info = {
            "pos_agent": np.array(self.agent.position),
            "vel_agent": np.array(self.agent.velocity),
            "block_pose": np.array(list(self.block.position) + [self.block.angle]),
            "goal_pose": self.goal_pose,
            "n_contacts": n_contact_points_per_step,
        }
        return info

    def _render_frame(self, mode):
        if self.window is None and mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.window_size, self.window_size))
        if self.clock is None and mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        self.screen = canvas

        draw_options = DrawOptions(canvas)

        # Draw goal pose.
        goal_body = self._get_goal_pose_body(self.goal_pose)
        for shape in self.block.shapes:
            goal_points = [
                pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface)
                for v in shape.get_vertices()
            ]
            goal_points += [goal_points[0]]
            pygame.draw.polygon(canvas, self.goal_color, goal_points)

        # Draw agent and block.
        self.space.debug_draw(draw_options)

        if mode == "human":
            # The following line copies our drawings from `canvas` to the visible window
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # the clock is already ticked during step for "human"

        img = np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))
        img = cv2.resize(img, (self.render_size, self.render_size))
        if self.render_action and self.latest_action is not None:
            action = np.array(self.latest_action)
            coord = (action / 512 * 96).astype(np.int32)
            marker_size = int(8 / 96 * self.render_size)
            thickness = int(1 / 96 * self.render_size)
            cv2.drawMarker(
                img,
                coord,
                color=(255, 0, 0),
                markerType=cv2.MARKER_CROSS,
                markerSize=marker_size,
                thickness=thickness,
            )
        return img

    def close(self):
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()

    def seed(self, seed=None):
        if seed is None:
            seed = np.random.randint(0, 25536)
        self._seed = seed
        self.np_random = np.random.default_rng(seed)

    def _handle_collision(self, arbiter, space, data):
        self.n_contact_points += len(arbiter.contact_point_set.points)

    def _set_state(self, state):
        if isinstance(state, np.ndarray):
            state = state.tolist()
        pos_agent = state[:2]
        pos_block = state[2:4]
        rot_block = state[4]
        self.agent.position = pos_agent
        # setting angle rotates with respect to center of mass
        # therefore will modify the geometric position
        # if not the same as CoM
        # therefore should be modified first.
        if self.legacy:
            # for compatibility with legacy data
            self.block.position = pos_block
            self.block.angle = rot_block
        else:
            self.block.angle = rot_block
            self.block.position = pos_block

        # Run physics to take effect
        self.space.step(1.0 / self.sim_hz)

    def _set_state_local(self, state_local):
        agent_pos_local = state_local[:2]
        block_pose_local = state_local[2:]
        tf_img_obj = st.AffineTransform(translation=self.goal_pose[:2], rotation=self.goal_pose[2])
        tf_obj_new = st.AffineTransform(translation=block_pose_local[:2], rotation=block_pose_local[2])
        tf_img_new = st.AffineTransform(matrix=tf_img_obj.params @ tf_obj_new.params)
        agent_pos_new = tf_img_new(agent_pos_local)
        new_state = np.array(list(agent_pos_new[0]) + list(tf_img_new.translation) + [tf_img_new.rotation])
        self._set_state(new_state)
        return new_state

    def _setup(self):
        self.space = pymunk.Space()
        self.space.gravity = 0, 0
        self.space.damping = 0
        self.teleop = False
        self.render_buffer = []

        # Add walls.
        walls = [
            self._add_segment((5, 506), (5, 5), 2),
            self._add_segment((5, 5), (506, 5), 2),
            self._add_segment((506, 5), (506, 506), 2),
            self._add_segment((5, 506), (506, 506), 2),
        ]
        self.space.add(*walls)

        # Add agent, block, and goal zone.
        self.agent = self.add_circle((256, 400), 15)
        self.block = self.add_tee((256, 300), 0)
        self.goal_color = pygame.Color("LightGreen")
        self.goal_pose = np.array([256, 256, np.pi / 4])  # x, y, theta (in radians)

        # Add collision handling
        self.collision_handeler = self.space.add_collision_handler(0, 0)
        self.collision_handeler.post_solve = self._handle_collision
        self.n_contact_points = 0

        self.max_score = 50 * 100
        self.success_threshold = 0.95  # 95% coverage.

    def _add_segment(self, a, b, radius):
        shape = pymunk.Segment(self.space.static_body, a, b, radius)
        shape.color = pygame.Color("LightGray")  # https://htmlcolorcodes.com/color-names
        return shape

    def add_circle(self, position, radius):
        body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
        body.position = position
        body.friction = 1
        shape = pymunk.Circle(body, radius)
        shape.color = pygame.Color("RoyalBlue")
        self.space.add(body, shape)
        return body

    def add_box(self, position, height, width):
        mass = 1
        inertia = pymunk.moment_for_box(mass, (height, width))
        body = pymunk.Body(mass, inertia)
        body.position = position
        shape = pymunk.Poly.create_box(body, (height, width))
        shape.color = pygame.Color("LightSlateGray")
        self.space.add(body, shape)
        return body

    def add_tee(self, position, angle, scale=30, color="LightSlateGray", mask=None):
        if mask is None:
            mask = pymunk.ShapeFilter.ALL_MASKS()
        mass = 1
        length = 4
        vertices1 = [
            (-length * scale / 2, scale),
            (length * scale / 2, scale),
            (length * scale / 2, 0),
            (-length * scale / 2, 0),
        ]
        inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
        vertices2 = [
            (-scale / 2, scale),
            (-scale / 2, length * scale),
            (scale / 2, length * scale),
            (scale / 2, scale),
        ]
        inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)  # note: vertices1, as in the upstream implementation
        body = pymunk.Body(mass, inertia1 + inertia2)
        shape1 = pymunk.Poly(body, vertices1)
        shape2 = pymunk.Poly(body, vertices2)
        shape1.color = pygame.Color(color)
        shape2.color = pygame.Color(color)
        shape1.filter = pymunk.ShapeFilter(mask=mask)
        shape2.filter = pymunk.ShapeFilter(mask=mask)
        body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
        body.position = position
        body.angle = angle
        body.friction = 1
        self.space.add(body, shape1, shape2)
        return body
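For orientation, a minimal rollout sketch for this environment (illustrative only, not part of the diff): the action is a 2-D goal position for the agent, `step` returns the classic 4-tuple gym API, and reward is goal-zone coverage clipped to [0, 1].

```python
import numpy as np

env = PushTEnv(render_size=96)
obs = env.reset()  # 5-D state: agent x/y, block x/y, block angle
for _ in range(10):
    action = env.action_space.sample()  # target position in [0, 512)^2
    obs, reward, done, info = env.step(np.array(action))
    if done:  # coverage exceeded success_threshold
        break
img = env.render(mode="rgb_array")  # (render_size, render_size, 3) uint8 frame
env.close()
```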
New file (lerobot.common.envs.pusht.pusht_image_env):

@@ -0,0 +1,55 @@
import cv2
import numpy as np
from gym import spaces

from lerobot.common.envs.pusht.pusht_env import PushTEnv


class PushTImageEnv(PushTEnv):
    metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}

    def __init__(self, legacy=False, block_cog=None, damping=None, render_size=96):
        super().__init__(
            legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False
        )
        ws = self.window_size
        self.observation_space = spaces.Dict(
            {
                "image": spaces.Box(low=0, high=1, shape=(3, render_size, render_size), dtype=np.float32),
                "agent_pos": spaces.Box(low=0, high=ws, shape=(2,), dtype=np.float32),
            }
        )
        self.render_cache = None

    def _get_obs(self):
        img = super()._render_frame(mode="rgb_array")

        agent_pos = np.array(self.agent.position)
        img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
        obs = {"image": img_obs, "agent_pos": agent_pos}

        # draw action
        if self.latest_action is not None:
            action = np.array(self.latest_action)
            coord = (action / 512 * 96).astype(np.int32)
            marker_size = int(8 / 96 * self.render_size)
            thickness = int(1 / 96 * self.render_size)
            cv2.drawMarker(
                img,
                coord,
                color=(255, 0, 0),
                markerType=cv2.MARKER_CROSS,
                markerSize=marker_size,
                thickness=thickness,
            )
        self.render_cache = img

        return obs

    def render(self, mode):
        assert mode == "rgb_array"

        if self.render_cache is None:
            self._get_obs()

        return self.render_cache
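A short sketch of how the dict observation from this wrapper is consumed (assumes only the class above):

```python
env = PushTImageEnv(render_size=96)
obs = env.reset()
# obs["image"]: float32, shape (3, 96, 96), scaled to [0, 1]
# obs["agent_pos"]: shape (2,), agent position in window coordinates
obs, reward, done, info = env.step(env.action_space.sample())
frame = env.render(mode="rgb_array")  # cached frame with the action marker drawn
```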
New file (lerobot.common.envs.pusht.pymunk_override):

@@ -0,0 +1,244 @@
# ----------------------------------------------------------------------------
# pymunk
# Copyright (c) 2007-2016 Victor Blomqvist
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ----------------------------------------------------------------------------

"""This submodule contains helper functions to help with quick prototyping
using pymunk together with pygame.

Intended to help with debugging and prototyping, not for actual production use
in a full application. The methods contained in this module are opinionated
about your coordinate system and not in any way optimized.
"""

__docformat__ = "reStructuredText"

__all__ = [
    "DrawOptions",
    "get_mouse_pos",
    "to_pygame",
    "from_pygame",
    # "lighten",
    "positive_y_is_up",
]

from typing import Sequence, Tuple

import numpy as np
import pygame
import pymunk
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d

positive_y_is_up: bool = False
"""Make increasing values of y point upwards.

When True::

    y
    ^
    |      . (3, 3)
    |
    |   . (2, 2)
    |
    +------ > x

When False::

    +------ > x
    |
    |   . (2, 2)
    |
    |      . (3, 3)
    v
    y

"""


class DrawOptions(pymunk.SpaceDebugDrawOptions):
    def __init__(self, surface: pygame.Surface) -> None:
        """Draw a pymunk.Space on a pygame.Surface object.

        Typical usage::

        >>> import pymunk
        >>> surface = pygame.Surface((10,10))
        >>> space = pymunk.Space()
        >>> options = pymunk.pygame_util.DrawOptions(surface)
        >>> space.debug_draw(options)

        You can control the color of a shape by setting shape.color to the color
        you want it drawn in::

        >>> c = pymunk.Circle(None, 10)
        >>> c.color = pygame.Color("pink")

        See pygame_util.demo.py for a full example

        Since pygame uses a coordinate system where y points down (in contrast
        to many other cases), you either have to make the physics simulation
        with Pymunk also behave in that way, or flip everything when you draw.

        The easiest is probably to just make the simulation behave the same
        way as Pygame does. In that way all coordinates used are in the same
        orientation and easy to reason about::

        >>> space = pymunk.Space()
        >>> space.gravity = (0, -1000)
        >>> body = pymunk.Body()
        >>> body.position = (0, 0)  # will be positioned in the top left corner
        >>> space.debug_draw(options)

        To flip the drawing it's possible to set the module property
        :py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
        the simulation upside down before drawing::

        >>> positive_y_is_up = True
        >>> body = pymunk.Body()
        >>> body.position = (0, 0)
        >>> # Body will be positioned in the bottom left corner

        :Parameters:
                surface : pygame.Surface
                    Surface that the objects will be drawn on
        """
        self.surface = surface
        super().__init__()

    def draw_circle(
        self,
        pos: Vec2d,
        angle: float,
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        p = to_pygame(pos, self.surface)

        pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
        pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius - 4), 0)

        # circle_edge = pos + Vec2d(radius, 0).rotated(angle)
        # p2 = to_pygame(circle_edge, self.surface)
        # line_r = 2 if radius > 20 else 1
        # pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)

    def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
        p1 = to_pygame(a, self.surface)
        p2 = to_pygame(b, self.surface)

        pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])

    def draw_fat_segment(
        self,
        a: Tuple[float, float],
        b: Tuple[float, float],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        p1 = to_pygame(a, self.surface)
        p2 = to_pygame(b, self.surface)

        r = round(max(1, radius * 2))
        pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
        if r > 2:
            orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
            if orthog[0] == 0 and orthog[1] == 0:
                return
            scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
            orthog[0] = round(orthog[0] * scale)
            orthog[1] = round(orthog[1] * scale)
            points = [
                (p1[0] - orthog[0], p1[1] - orthog[1]),
                (p1[0] + orthog[0], p1[1] + orthog[1]),
                (p2[0] + orthog[0], p2[1] + orthog[1]),
                (p2[0] - orthog[0], p2[1] - orthog[1]),
            ]
            pygame.draw.polygon(self.surface, fill_color.as_int(), points)
            pygame.draw.circle(
                self.surface,
                fill_color.as_int(),
                (round(p1[0]), round(p1[1])),
                round(radius),
            )
            pygame.draw.circle(
                self.surface,
                fill_color.as_int(),
                (round(p2[0]), round(p2[1])),
                round(radius),
            )

    def draw_polygon(
        self,
        verts: Sequence[Tuple[float, float]],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        ps = [to_pygame(v, self.surface) for v in verts]
        ps += [ps[0]]

        radius = 2
        pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)

        if radius > 0:
            for i in range(len(verts)):
                a = verts[i]
                b = verts[(i + 1) % len(verts)]
                self.draw_fat_segment(a, b, radius, fill_color, fill_color)

    def draw_dot(self, size: float, pos: Tuple[float, float], color: SpaceDebugColor) -> None:
        p = to_pygame(pos, self.surface)
        pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)


def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]:
    """Get position of the mouse pointer in pymunk coordinates."""
    p = pygame.mouse.get_pos()
    return from_pygame(p, surface)


def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
    """Convenience method to convert pymunk coordinates to pygame surface
    local coordinates.

    Note that in case positive_y_is_up is False, this function won't actually do
    anything except converting the point to integers.
    """
    if positive_y_is_up:
        return round(p[0]), surface.get_height() - round(p[1])
    else:
        return round(p[0]), round(p[1])


def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
    """Convenience method to convert pygame surface local coordinates to
    pymunk coordinates
    """
    return to_pygame(p, surface)


def light_color(color: SpaceDebugColor):
    color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
    color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
    return color
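As a quick check of this module, an offscreen-rendering sketch (illustrative; assumes only pygame, pymunk, and the `DrawOptions` class above):

```python
import pygame
import pymunk

surface = pygame.Surface((512, 512))  # offscreen; no window needed
surface.fill((255, 255, 255))

space = pymunk.Space()
body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
body.position = (256, 256)
space.add(body, pymunk.Circle(body, 15))

options = DrawOptions(surface)
space.debug_draw(options)  # dispatches to draw_circle/draw_polygon/... above
```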
Diffusion policy module (DiffusionUnetImagePolicy):

@@ -5,11 +5,33 @@ import torch.nn.functional as F  # noqa: N812
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from einops import reduce

-from diffusion_policy.common.pytorch_util import dict_apply
-from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
-from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
-from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
-from diffusion_policy.policy.base_image_policy import BaseImagePolicy
+from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
+from lerobot.common.policies.diffusion.model.mask_generator import LowdimMaskGenerator
+from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
+from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
+from lerobot.common.policies.diffusion.model.normalizer import LinearNormalizer
+from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
+
+
+class BaseImagePolicy(ModuleAttrMixin):
+    # init accepts keyword argument shape_meta, see config/task/*_image.yaml
+
+    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """
+        obs_dict:
+            str: B,To,*
+        return: B,Ta,Da
+        """
+        raise NotImplementedError()
+
+    # reset state for stateful policies
+    def reset(self):
+        pass
+
+    # ========== training ===========
+    # no standard training interface except setting normalizer
+    def set_normalizer(self, normalizer: LinearNormalizer):
+        raise NotImplementedError()
+
+
 class DiffusionUnetImagePolicy(BaseImagePolicy):
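The inlined `BaseImagePolicy` above defines the interface that `DiffusionUnetImagePolicy` implements: `predict_action` maps a dict of batched observation tensors to a dict of action tensors. A toy subclass sketch (hypothetical, for illustration only; assumes `BaseImagePolicy` is in scope):

```python
from typing import Dict

import torch


class RandomImagePolicy(BaseImagePolicy):
    """Toy policy: ignores observations and returns random actions."""

    def __init__(self, action_dim: int = 2, horizon: int = 8):
        super().__init__()
        self.action_dim = action_dim
        self.horizon = horizon

    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = next(iter(obs_dict.values())).shape[0]
        action = torch.rand(batch_size, self.horizon, self.action_dim)  # (B, Ta, Da)
        return {"action": action}
```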
New file (lerobot.common.policies.diffusion.model.conditional_unet1d):

@@ -0,0 +1,286 @@
import logging
from typing import Union

import einops
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d
from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb

logger = logging.getLogger(__name__)


class ConditionalResidualBlock1D(nn.Module):
    def __init__(
        self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False
    ):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
                Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
            ]
        )

        # FiLM modulation https://arxiv.org/abs/1709.07871
        # predicts per-channel scale and bias
        cond_channels = out_channels
        if cond_predict_scale:
            cond_channels = out_channels * 2
        self.cond_predict_scale = cond_predict_scale
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, cond_channels),
            Rearrange("batch t -> batch t 1"),
        )

        # make sure dimensions are compatible
        self.residual_conv = (
            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        )

    def forward(self, x, cond):
        """
        x : [ batch_size x in_channels x horizon ]
        cond : [ batch_size x cond_dim ]

        returns:
        out : [ batch_size x out_channels x horizon ]
        """
        out = self.blocks[0](x)
        embed = self.cond_encoder(cond)
        if self.cond_predict_scale:
            embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
            scale = embed[:, 0, ...]
            bias = embed[:, 1, ...]
            out = scale * out + bias
        else:
            out = out + embed
        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out


class ConditionalUnet1D(nn.Module):
    def __init__(
        self,
        input_dim,
        local_cond_dim=None,
        global_cond_dim=None,
        diffusion_step_embed_dim=256,
        down_dims=None,
        kernel_size=3,
        n_groups=8,
        cond_predict_scale=False,
    ):
        super().__init__()
        if down_dims is None:
            down_dims = [256, 512, 1024]

        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed
        if global_cond_dim is not None:
            cond_dim += global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))

        local_cond_encoder = None
        if local_cond_dim is not None:
            _, dim_out = in_out[0]
            dim_in = local_cond_dim
            local_cond_encoder = nn.ModuleList(
                [
                    # down encoder
                    ConditionalResidualBlock1D(
                        dim_in,
                        dim_out,
                        cond_dim=cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                    ),
                    # up encoder
                    ConditionalResidualBlock1D(
                        dim_in,
                        dim_out,
                        cond_dim=cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                    ),
                ]
            )

        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList(
            [
                ConditionalResidualBlock1D(
                    mid_dim,
                    mid_dim,
                    cond_dim=cond_dim,
                    kernel_size=kernel_size,
                    n_groups=n_groups,
                    cond_predict_scale=cond_predict_scale,
                ),
                ConditionalResidualBlock1D(
                    mid_dim,
                    mid_dim,
                    cond_dim=cond_dim,
                    kernel_size=kernel_size,
                    n_groups=n_groups,
                    cond_predict_scale=cond_predict_scale,
                ),
            ]
        )

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(
                nn.ModuleList(
                    [
                        ConditionalResidualBlock1D(
                            dim_in,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                            cond_predict_scale=cond_predict_scale,
                        ),
                        ConditionalResidualBlock1D(
                            dim_out,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                            cond_predict_scale=cond_predict_scale,
                        ),
                        Downsample1d(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(
                nn.ModuleList(
                    [
                        ConditionalResidualBlock1D(
                            dim_out * 2,
                            dim_in,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                            cond_predict_scale=cond_predict_scale,
                        ),
                        ConditionalResidualBlock1D(
                            dim_in,
                            dim_in,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                            cond_predict_scale=cond_predict_scale,
                        ),
                        Upsample1d(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.local_cond_encoder = local_cond_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        local_cond=None,
        global_cond=None,
        **kwargs,
    ):
        """
        x: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        local_cond: (B,T,local_cond_dim)
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        sample = einops.rearrange(sample, "b h t -> b t h")

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([global_feature, global_cond], axis=-1)

        # encode local features
        h_local = []
        if local_cond is not None:
            local_cond = einops.rearrange(local_cond, "b h t -> b t h")
            resnet, resnet2 = self.local_cond_encoder
            x = resnet(local_cond, global_feature)
            h_local.append(x)
            x = resnet2(local_cond, global_feature)
            h_local.append(x)

        x = sample
        h = []
        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            x = resnet(x, global_feature)
            if idx == 0 and len(h_local) > 0:
                x = x + h_local[0]
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            # The correct condition should be:
            # if idx == (len(self.up_modules)-1) and len(h_local) > 0:
            # However this change will break compatibility with published checkpoints.
            # Therefore it is left as a comment.
            if idx == len(self.up_modules) and len(h_local) > 0:
                x = x + h_local[1]
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        x = einops.rearrange(x, "b t h -> b h t")
        return x
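A shape-level smoke test for the U-Net above (hypothetical sizes; assumes the module and its `SinusoidalPosEmb`/`Conv1dBlock` dependencies import cleanly). The horizon must be divisible by 2 per downsampling stage:

```python
import torch

net = ConditionalUnet1D(input_dim=2, global_cond_dim=66, down_dims=[256, 512, 1024])
sample = torch.randn(4, 16, 2)           # (B, T, input_dim)
timestep = torch.randint(0, 100, (4,))   # one diffusion step per batch element
global_cond = torch.randn(4, 66)         # (B, global_cond_dim)
out = net(sample, timestep, global_cond=global_cond)
assert out.shape == sample.shape         # output matches (B, T, input_dim)
```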
New file (lerobot.common.policies.diffusion.model.conv1d_components):

@@ -0,0 +1,47 @@
import torch.nn as nn

# from einops.layers.torch import Rearrange


class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            # Rearrange('batch channels horizon -> batch channels 1 horizon'),
            nn.GroupNorm(n_groups, out_channels),
            # Rearrange('batch channels 1 horizon -> batch channels horizon'),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


# def test():
#     cb = Conv1dBlock(256, 128, kernel_size=3)
#     x = torch.zeros((1,256,16))
#     o = cb(x)
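The shape behavior of these blocks, in a sketch with hypothetical sizes: `Conv1dBlock` preserves sequence length via `kernel_size // 2` padding, `Downsample1d` halves it with a stride-2 conv, and `Upsample1d` doubles it with a transposed conv:

```python
import torch

x = torch.zeros(1, 256, 16)                  # (B, C, T)
y = Conv1dBlock(256, 128, kernel_size=3)(x)  # -> (1, 128, 16), length preserved
d = Downsample1d(128)(y)                     # -> (1, 128, 8), stride-2 conv
u = Upsample1d(128)(d)                       # -> (1, 128, 16), transposed conv
```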
New file (defines CropRandomizer; the diff is truncated partway through this file):

@@ -0,0 +1,294 @@
import torch
import torch.nn as nn
import torchvision.transforms.functional as ttf

import lerobot.common.policies.diffusion.model.tensor_utils as tu


class CropRandomizer(nn.Module):
    """
    Randomly sample crops at input, and then average across crop features at output.
    """

    def __init__(
        self,
        input_shape,
        crop_height,
        crop_width,
        num_crops=1,
        pos_enc=False,
    ):
        """
        Args:
            input_shape (tuple, list): shape of input (not including batch dimension)
            crop_height (int): crop height
            crop_width (int): crop width
            num_crops (int): number of random crops to take
            pos_enc (bool): if True, add 2 channels to the output to encode the spatial
                location of the cropped pixels in the source image
        """
        super().__init__()

        assert len(input_shape) == 3  # (C, H, W)
        assert crop_height < input_shape[1]
        assert crop_width < input_shape[2]

        self.input_shape = input_shape
        self.crop_height = crop_height
        self.crop_width = crop_width
        self.num_crops = num_crops
        self.pos_enc = pos_enc

    def output_shape_in(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module. Corresponds to
        the @forward_in operation, where raw inputs (usually observation modalities)
        are passed in.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """

        # outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
        # the number of crops are reshaped into the batch dimension, increasing the batch
        # size from B to B * N
        out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
        return [out_c, self.crop_height, self.crop_width]

    def output_shape_out(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module. Corresponds to
        the @forward_out operation, where processed inputs (usually encoded observation
        modalities) are passed in.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """

        # since the forward_out operation splits [B * N, ...] -> [B, N, ...]
        # and then pools to result in [B, ...], only the batch dimension changes,
        # and so the other dimensions retain their shape.
        return list(input_shape)

    def forward_in(self, inputs):
        """
        Samples N random crops for each input in the batch, and then reshapes
        inputs to [B * N, ...].
        """
        assert len(inputs.shape) >= 3  # must have at least (C, H, W) dimensions
        if self.training:
            # generate random crops
            out, _ = sample_random_image_crops(
                images=inputs,
                crop_height=self.crop_height,
                crop_width=self.crop_width,
                num_crops=self.num_crops,
                pos_enc=self.pos_enc,
            )
            # [B, N, ...] -> [B * N, ...]
            return tu.join_dimensions(out, 0, 1)
        else:
            # take center crop during eval
            out = ttf.center_crop(img=inputs, output_size=(self.crop_height, self.crop_width))
            if self.num_crops > 1:
                B, C, H, W = out.shape  # noqa: N806
                out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W)
                # [B * N, ...]
            return out

    def forward_out(self, inputs):
        """
        Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then averages across N
        to result in shape [B, ...] to make sure the network output is consistent with
        what would have happened if there were no randomization.
        """
        if self.num_crops <= 1:
            return inputs
        else:
            batch_size = inputs.shape[0] // self.num_crops
            out = tu.reshape_dimensions(
                inputs, begin_axis=0, end_axis=0, target_dims=(batch_size, self.num_crops)
            )
            return out.mean(dim=1)

    def forward(self, inputs):
        return self.forward_in(inputs)

    def __repr__(self):
        """Pretty print network."""
        header = "{}".format(str(self.__class__.__name__))
        msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
            self.input_shape, self.crop_height, self.crop_width, self.num_crops
        )
        return msg


def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
    """
    Crops images at the locations specified by @crop_indices. Crops will be
    taken across all channels.

    Args:
        images (torch.Tensor): batch of images of shape [..., C, H, W]

        crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
            N is the number of crops to take per image and each entry corresponds
            to the pixel height and width of where to take the crop. Note that
            the indices can also be of shape [..., 2] if only 1 crop should
            be taken per image. Leading dimensions must be consistent with
            @images argument. Each index specifies the top left of the crop.
            Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
            H and W are the height and width of @images and CH and CW are
            @crop_height and @crop_width.

        crop_height (int): height of crop to take

        crop_width (int): width of crop to take

    Returns:
        crops (torch.Tensor): cropped images of shape [..., C, @crop_height, @crop_width]
    """

    # make sure length of input shapes is consistent
    assert crop_indices.shape[-1] == 2
    ndim_im_shape = len(images.shape)
    ndim_indices_shape = len(crop_indices.shape)
    assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2)

    # maybe pad so that @crop_indices is shape [..., N, 2]
    is_padded = False
    if ndim_im_shape == ndim_indices_shape + 2:
        crop_indices = crop_indices.unsqueeze(-2)
        is_padded = True

    # make sure leading dimensions between images and indices are consistent
    assert images.shape[:-3] == crop_indices.shape[:-2]

    device = images.device
    image_c, image_h, image_w = images.shape[-3:]
    num_crops = crop_indices.shape[-2]

    # make sure @crop_indices are in valid range
    assert (crop_indices[..., 0] >= 0).all().item()
    assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
    assert (crop_indices[..., 1] >= 0).all().item()
    assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()

    # convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.

    # 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
    crop_ind_grid_h = torch.arange(crop_height).to(device)
    crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1)
    # 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
    crop_ind_grid_w = torch.arange(crop_width).to(device)
    crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0)
    # combine into shape [CH, CW, 2]
    crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)

    # Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
    # After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
    # shape array that tells us which pixels from the corresponding source image to grab.
    grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2]
    all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape)

    # For using @torch.gather, convert to flat indices from 2D indices, and also
    # repeat across the channel dimension. To get flat index of each pixel to grab for
    # each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
    all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1]  # shape [..., N, CH, CW]
    all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3)  # shape [..., N, C, CH, CW]
    all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2)  # shape [..., N, C, CH * CW]

    # Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
    images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
    images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
    crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
    # [..., N, C, CH * CW] -> [..., N, C, CH, CW]
    reshape_axis = len(crops.shape) - 1
    crops = tu.reshape_dimensions(
        crops, begin_axis=reshape_axis, end_axis=reshape_axis, target_dims=(crop_height, crop_width)
    )

    if is_padded:
        # undo padding -> [..., C, CH, CW]
        crops = crops.squeeze(-4)
    return crops


def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False):
    """
    For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
    @images.

    Args:
        images (torch.Tensor): batch of images of shape [..., C, H, W]

        crop_height (int): height of crop to take

        crop_width (int): width of crop to take

        num_crops (int): number of crops to sample

        pos_enc (bool): if True, also add 2 channels to the outputs that give a spatial
            encoding of the original source pixel locations. This means that the
            output crops will contain information about where in the source image
            it was sampled from.

    Returns:
        crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
            if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)

        crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
    """
    device = images.device

    # maybe add 2 channels of spatial encoding to the source image
    source_im = images
    if pos_enc:
        # spatial encoding [y, x] in [0, 1]
        h, w = source_im.shape[-2:]
        pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
        pos_y = pos_y.float().to(device) / float(h)
        pos_x = pos_x.float().to(device) / float(w)
        position_enc = torch.stack((pos_y, pos_x))  # shape [C, H, W]

        # unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
        leading_shape = source_im.shape[:-3]
        position_enc = position_enc[(None,) * len(leading_shape)]
        position_enc = position_enc.expand(*leading_shape, -1, -1, -1)

        # concat across channel dimension with input
        source_im = torch.cat((source_im, position_enc), dim=-3)

    # make sure sample boundaries ensure crops are fully within the images
    image_c, image_h, image_w = source_im.shape[-3:]
    max_sample_h = image_h - crop_height
    max_sample_w = image_w - crop_width

    # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
    # Each gets @num_crops samples - typically this will just be the batch dimension (B), so
    # we will sample [B, N] indices, but this supports having more than one leading dimension,
    # or possibly no leading dimension.
    #
    # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
    crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
    crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
|
||||||
|
crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2]
|
||||||
|
|
||||||
|
crops = crop_image_from_indices(
|
||||||
|
images=source_im,
|
||||||
|
crop_indices=crop_inds,
|
||||||
|
crop_height=crop_height,
|
||||||
|
crop_width=crop_width,
|
||||||
|
)
|
||||||
|
|
||||||
|
return crops, crop_inds
|
|
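For orientation, a minimal usage sketch (not part of the diff) of the sampler defined above; the shapes follow the docstrings, and the call assumes the module's own imports are in scope:

```
import torch

images = torch.rand(8, 3, 96, 96)  # [B, C, H, W]
crops, crop_inds = sample_random_image_crops(
    images, crop_height=84, crop_width=84, num_crops=1, pos_enc=True
)
print(crops.shape)      # torch.Size([8, 1, 5, 84, 84]) -- 3 image channels + 2 pos-enc channels
print(crop_inds.shape)  # torch.Size([8, 1, 2])
```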
@ -0,0 +1,41 @@
import torch
import torch.nn as nn


class DictOfTensorMixin(nn.Module):
    def __init__(self, params_dict=None):
        super().__init__()
        if params_dict is None:
            params_dict = nn.ParameterDict()
        self.params_dict = params_dict

    @property
    def device(self):
        return next(iter(self.parameters())).device

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        def dfs_add(dest, keys, value: torch.Tensor):
            if len(keys) == 1:
                dest[keys[0]] = value
                return

            if keys[0] not in dest:
                dest[keys[0]] = nn.ParameterDict()
            dfs_add(dest[keys[0]], keys[1:], value)

        def load_dict(state_dict, prefix):
            out_dict = nn.ParameterDict()
            for key, value in state_dict.items():
                value: torch.Tensor
                if key.startswith(prefix):
                    param_keys = key[len(prefix) :].split(".")[1:]
                    # if len(param_keys) == 0:
                    #     import pdb; pdb.set_trace()
                    dfs_add(out_dict, param_keys, value.clone())
            return out_dict

        self.params_dict = load_dict(state_dict, prefix + "params_dict")
        self.params_dict.requires_grad_(False)
        return
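A quick sketch (not part of the diff) of what the custom `_load_from_state_dict` above enables: rebuilding an arbitrarily nested `ParameterDict` from a flat state dict, mirroring the round-trip exercised by the normalizer's `test()` further below:

```
import torch
import torch.nn as nn

src = DictOfTensorMixin()
src.params_dict["stats"] = nn.ParameterDict({"mean": nn.Parameter(torch.zeros(3))})

dst = DictOfTensorMixin()
dst.load_state_dict(src.state_dict())  # flat keys like "params_dict.stats.mean" are re-nested
print(dst.params_dict["stats"]["mean"].shape)  # torch.Size([3])
```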
@ -0,0 +1,84 @@
import torch
from torch.nn.modules.batchnorm import _BatchNorm


class EMAModel:
    """
    Exponential Moving Average of model weights
    """

    def __init__(
        self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
    ):
        """
        @crowsonkb's notes on EMA Warmup:
        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
        to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
        gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
        at 215.4k steps).
        Args:
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
        """

        self.averaged_model = model
        self.averaged_model.eval()
        self.averaged_model.requires_grad_(False)

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        self.decay = 0.0
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)
        value = 1 - (1 + step / self.inv_gamma) ** -self.power

        if step <= 0:
            return 0.0

        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, new_model):
        self.decay = self.get_decay(self.optimization_step)

        # old_all_dataptrs = set()
        # for param in new_model.parameters():
        #     data_ptr = param.data_ptr()
        #     if data_ptr != 0:
        #         old_all_dataptrs.add(data_ptr)

        # all_dataptrs = set()
        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
            for param, ema_param in zip(
                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
            ):
                # iterate over immediate parameters only.
                if isinstance(param, dict):
                    raise RuntimeError("Dict parameter not supported")

                # data_ptr = param.data_ptr()
                # if data_ptr != 0:
                #     all_dataptrs.add(data_ptr)

                if isinstance(module, _BatchNorm):
                    # skip batchnorms
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                elif not param.requires_grad:
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                else:
                    ema_param.mul_(self.decay)
                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)

        # verify that iterating over module and then parameters is identical to parameters recursively.
        # assert old_all_dataptrs == all_dataptrs
        self.optimization_step += 1
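To make the warmup schedule concrete, a small sketch (not part of the diff) that evaluates `get_decay` at a few steps with the defaults gamma=1, power=2/3, matching the figures quoted in the docstring, and shows the intended train-loop usage:

```
import copy

import torch.nn as nn

net = nn.Linear(4, 2)
ema = EMAModel(copy.deepcopy(net))
for step in (0, 1_000, 31_600, 1_000_000):
    print(step, ema.get_decay(step))
# 0 -> 0.0, 1000 -> ~0.99, 31600 -> ~0.999, 1000000 -> 0.9999 (capped by max_value)

# typical usage: after each optimizer.step(), call ema.step(net)
```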
@ -0,0 +1,46 @@
from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, Optimizer, Optional, SchedulerType, Union


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    **kwargs,
):
    """
    Adds **kwargs compared to diffusers' original implementation.

    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, **kwargs)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    return schedule_func(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs
    )
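A brief sketch (not part of the diff) of calling the wrapper; "cosine" is one of the names accepted by diffusers' `SchedulerType`:

```
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=50_000,
)
optimizer.step()
lr_scheduler.step()  # once per optimization step
```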
@ -0,0 +1,65 @@
import torch

from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin


class LowdimMaskGenerator(ModuleAttrMixin):
    def __init__(
        self,
        action_dim,
        obs_dim,
        # obs mask setup
        max_n_obs_steps=2,
        fix_obs_steps=True,
        # action mask
        action_visible=False,
    ):
        super().__init__()
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.max_n_obs_steps = max_n_obs_steps
        self.fix_obs_steps = fix_obs_steps
        self.action_visible = action_visible

    @torch.no_grad()
    def forward(self, shape, seed=None):
        device = self.device
        B, T, D = shape  # noqa: N806
        assert (self.action_dim + self.obs_dim) == D

        # create all tensors on this device
        rng = torch.Generator(device=device)
        if seed is not None:
            rng = rng.manual_seed(seed)

        # generate dim mask
        dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device)
        is_action_dim = dim_mask.clone()
        is_action_dim[..., : self.action_dim] = True
        is_obs_dim = ~is_action_dim

        # generate obs mask
        if self.fix_obs_steps:
            obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device)
        else:
            obs_steps = torch.randint(
                low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device
            )

        steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T)
        obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
        obs_mask = obs_mask & is_obs_dim

        # generate action mask
        if self.action_visible:
            action_steps = torch.maximum(
                obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device)
            )
            action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
            action_mask = action_mask & is_action_dim

        mask = obs_mask
        if self.action_visible:
            mask = mask | action_mask

        return mask
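For intuition, a sketch (not part of the diff) of the conditioning mask this produces for a batch of 1, horizon 4, and action/obs dims 2 + 3:

```
gen = LowdimMaskGenerator(action_dim=2, obs_dim=3, max_n_obs_steps=2, action_visible=True)
mask = gen((1, 4, 5))
# mask[0, t, d] is True where ground truth is visible to the model:
# obs dims (d >= 2) at t in {0, 1}, and action dims (d < 2) at t = 0 only.
```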
@ -0,0 +1,15 @@
import torch.nn as nn


class ModuleAttrMixin(nn.Module):
    def __init__(self):
        super().__init__()
        self._dummy_variable = nn.Parameter()

    @property
    def device(self):
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        return next(iter(self.parameters())).dtype
@ -5,9 +5,9 @@ import torch
import torch.nn as nn
import torchvision

from diffusion_policy.common.pytorch_util import replace_submodules
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.model.vision.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules


class MultiImageObsEncoder(ModuleAttrMixin):
@ -0,0 +1,358 @@
from typing import Dict, Union

import numpy as np
import torch
import torch.nn as nn
import zarr

from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply


class LinearNormalizer(DictOfTensorMixin):
    avaliable_modes = ["limits", "gaussian"]

    @torch.no_grad()
    def fit(
        self,
        data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
        last_n_dims=1,
        dtype=torch.float32,
        mode="limits",
        output_max=1.0,
        output_min=-1.0,
        range_eps=1e-4,
        fit_offset=True,
    ):
        if isinstance(data, dict):
            for key, value in data.items():
                self.params_dict[key] = _fit(
                    value,
                    last_n_dims=last_n_dims,
                    dtype=dtype,
                    mode=mode,
                    output_max=output_max,
                    output_min=output_min,
                    range_eps=range_eps,
                    fit_offset=fit_offset,
                )
        else:
            self.params_dict["_default"] = _fit(
                data,
                last_n_dims=last_n_dims,
                dtype=dtype,
                mode=mode,
                output_max=output_max,
                output_min=output_min,
                range_eps=range_eps,
                fit_offset=fit_offset,
            )

    def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
        return self.normalize(x)

    def __getitem__(self, key: str):
        return SingleFieldLinearNormalizer(self.params_dict[key])

    def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"):
        self.params_dict[key] = value.params_dict

    def _normalize_impl(self, x, forward=True):
        if isinstance(x, dict):
            result = {}
            for key, value in x.items():
                params = self.params_dict[key]
                result[key] = _normalize(value, params, forward=forward)
            return result
        else:
            if "_default" not in self.params_dict:
                raise RuntimeError("Not initialized")
            params = self.params_dict["_default"]
            return _normalize(x, params, forward=forward)

    def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
        return self._normalize_impl(x, forward=True)

    def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
        return self._normalize_impl(x, forward=False)

    def get_input_stats(self) -> Dict:
        if len(self.params_dict) == 0:
            raise RuntimeError("Not initialized")
        if len(self.params_dict) == 1 and "_default" in self.params_dict:
            return self.params_dict["_default"]["input_stats"]

        result = {}
        for key, value in self.params_dict.items():
            if key != "_default":
                result[key] = value["input_stats"]
        return result

    def get_output_stats(self, key="_default"):
        input_stats = self.get_input_stats()
        if "min" in input_stats:
            # no dict
            return dict_apply(input_stats, self.normalize)

        result = {}
        for key, group in input_stats.items():
            this_dict = {}
            for name, value in group.items():
                this_dict[name] = self.normalize({key: value})[key]
            result[key] = this_dict
        return result


class SingleFieldLinearNormalizer(DictOfTensorMixin):
    avaliable_modes = ["limits", "gaussian"]

    @torch.no_grad()
    def fit(
        self,
        data: Union[torch.Tensor, np.ndarray, zarr.Array],
        last_n_dims=1,
        dtype=torch.float32,
        mode="limits",
        output_max=1.0,
        output_min=-1.0,
        range_eps=1e-4,
        fit_offset=True,
    ):
        self.params_dict = _fit(
            data,
            last_n_dims=last_n_dims,
            dtype=dtype,
            mode=mode,
            output_max=output_max,
            output_min=output_min,
            range_eps=range_eps,
            fit_offset=fit_offset,
        )

    @classmethod
    def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
        obj = cls()
        obj.fit(data, **kwargs)
        return obj

    @classmethod
    def create_manual(
        cls,
        scale: Union[torch.Tensor, np.ndarray],
        offset: Union[torch.Tensor, np.ndarray],
        input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]],
    ):
        def to_tensor(x):
            if not isinstance(x, torch.Tensor):
                x = torch.from_numpy(x)
            x = x.flatten()
            return x

        # check
        for x in [offset] + list(input_stats_dict.values()):
            assert x.shape == scale.shape
            assert x.dtype == scale.dtype

        params_dict = nn.ParameterDict(
            {
                "scale": to_tensor(scale),
                "offset": to_tensor(offset),
                "input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)),
            }
        )
        return cls(params_dict)

    @classmethod
    def create_identity(cls, dtype=torch.float32):
        scale = torch.tensor([1], dtype=dtype)
        offset = torch.tensor([0], dtype=dtype)
        input_stats_dict = {
            "min": torch.tensor([-1], dtype=dtype),
            "max": torch.tensor([1], dtype=dtype),
            "mean": torch.tensor([0], dtype=dtype),
            "std": torch.tensor([1], dtype=dtype),
        }
        return cls.create_manual(scale, offset, input_stats_dict)

    def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        return _normalize(x, self.params_dict, forward=True)

    def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        return _normalize(x, self.params_dict, forward=False)

    def get_input_stats(self):
        return self.params_dict["input_stats"]

    def get_output_stats(self):
        return dict_apply(self.params_dict["input_stats"], self.normalize)

    def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        return self.normalize(x)


def _fit(
    data: Union[torch.Tensor, np.ndarray, zarr.Array],
    last_n_dims=1,
    dtype=torch.float32,
    mode="limits",
    output_max=1.0,
    output_min=-1.0,
    range_eps=1e-4,
    fit_offset=True,
):
    assert mode in ["limits", "gaussian"]
    assert last_n_dims >= 0
    assert output_max > output_min

    # convert data to torch and type
    if isinstance(data, zarr.Array):
        data = data[:]
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)
    if dtype is not None:
        data = data.type(dtype)

    # convert shape
    dim = 1
    if last_n_dims > 0:
        dim = np.prod(data.shape[-last_n_dims:])
    data = data.reshape(-1, dim)

    # compute input stats min max mean std
    input_min, _ = data.min(axis=0)
    input_max, _ = data.max(axis=0)
    input_mean = data.mean(axis=0)
    input_std = data.std(axis=0)

    # compute scale and offset
    if mode == "limits":
        if fit_offset:
            # unit scale
            input_range = input_max - input_min
            ignore_dim = input_range < range_eps
            input_range[ignore_dim] = output_max - output_min
            scale = (output_max - output_min) / input_range
            offset = output_min - scale * input_min
            offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
            # ignore dims scaled to mean of output max and min
        else:
            # use this when data is pre-zero-centered.
            assert output_max > 0
            assert output_min < 0
            # unit abs
            output_abs = min(abs(output_min), abs(output_max))
            input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
            ignore_dim = input_abs < range_eps
            input_abs[ignore_dim] = output_abs
            # don't scale constant channels
            scale = output_abs / input_abs
            offset = torch.zeros_like(input_mean)
    elif mode == "gaussian":
        ignore_dim = input_std < range_eps
        scale = input_std.clone()
        scale[ignore_dim] = 1
        scale = 1 / scale

        offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean)

    # save
    this_params = nn.ParameterDict(
        {
            "scale": scale,
            "offset": offset,
            "input_stats": nn.ParameterDict(
                {"min": input_min, "max": input_max, "mean": input_mean, "std": input_std}
            ),
        }
    )
    for p in this_params.parameters():
        p.requires_grad_(False)
    return this_params


def _normalize(x, params, forward=True):
    assert "scale" in params
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    scale = params["scale"]
    offset = params["offset"]
    x = x.to(device=scale.device, dtype=scale.dtype)
    src_shape = x.shape
    x = x.reshape(-1, scale.shape[0])
    x = x * scale + offset if forward else (x - offset) / scale
    x = x.reshape(src_shape)
    return x


def test():
    data = torch.zeros((100, 10, 9, 2)).uniform_()
    data[..., 0, 0] = 0

    normalizer = SingleFieldLinearNormalizer()
    normalizer.fit(data, mode="limits", last_n_dims=2)
    datan = normalizer.normalize(data)
    assert datan.shape == data.shape
    assert np.allclose(datan.max(), 1.0)
    assert np.allclose(datan.min(), -1.0)
    dataun = normalizer.unnormalize(datan)
    assert torch.allclose(data, dataun, atol=1e-7)

    _ = normalizer.get_input_stats()
    _ = normalizer.get_output_stats()

    normalizer = SingleFieldLinearNormalizer()
    normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False)
    datan = normalizer.normalize(data)
    assert datan.shape == data.shape
    assert np.allclose(datan.max(), 1.0, atol=1e-3)
    assert np.allclose(datan.min(), 0.0, atol=1e-3)
    dataun = normalizer.unnormalize(datan)
    assert torch.allclose(data, dataun, atol=1e-7)

    data = torch.zeros((100, 10, 9, 2)).uniform_()
    normalizer = SingleFieldLinearNormalizer()
    normalizer.fit(data, mode="gaussian", last_n_dims=0)
    datan = normalizer.normalize(data)
    assert datan.shape == data.shape
    assert np.allclose(datan.mean(), 0.0, atol=1e-3)
    assert np.allclose(datan.std(), 1.0, atol=1e-3)
    dataun = normalizer.unnormalize(datan)
    assert torch.allclose(data, dataun, atol=1e-7)

    # dict
    data = torch.zeros((100, 10, 9, 2)).uniform_()
    data[..., 0, 0] = 0

    normalizer = LinearNormalizer()
    normalizer.fit(data, mode="limits", last_n_dims=2)
    datan = normalizer.normalize(data)
    assert datan.shape == data.shape
    assert np.allclose(datan.max(), 1.0)
    assert np.allclose(datan.min(), -1.0)
    dataun = normalizer.unnormalize(datan)
    assert torch.allclose(data, dataun, atol=1e-7)

    _ = normalizer.get_input_stats()
    _ = normalizer.get_output_stats()

    data = {
        "obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512,
        "action": torch.zeros((1000, 128, 2)).uniform_() * 512,
    }
    normalizer = LinearNormalizer()
    normalizer.fit(data)
    datan = normalizer.normalize(data)
    dataun = normalizer.unnormalize(datan)
    for key in data:
        assert torch.allclose(data[key], dataun[key], atol=1e-4)

    _ = normalizer.get_input_stats()
    _ = normalizer.get_output_stats()

    state_dict = normalizer.state_dict()
    n = LinearNormalizer()
    n.load_state_dict(state_dict)
    datan = n.normalize(data)
    dataun = n.unnormalize(datan)
    for key in data:
        assert torch.allclose(data[key], dataun[key], atol=1e-4)
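To make the "limits" arithmetic in `_fit` concrete: for inputs spanning [0, 4] mapped to [-1, 1], `scale = (1 - (-1)) / (4 - 0) = 0.5` and `offset = -1 - 0.5 * 0 = -1`. A tiny check (not part of the diff):

```
import torch

norm = SingleFieldLinearNormalizer.create_fit(torch.linspace(0, 4, 101).unsqueeze(-1))
print(float(norm.params_dict["scale"]), float(norm.params_dict["offset"]))  # 0.5 -1.0
print(norm.normalize(torch.tensor([[0.0], [2.0], [4.0]])).squeeze())  # tensor([-1., 0., 1.])
```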
@ -0,0 +1,19 @@
import math

import torch
import torch.nn as nn


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
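A short sketch (not part of the diff): embedding a batch of diffusion timesteps. The frequencies span 1 down to 1/10000 geometrically across `half_dim` channels, following the standard transformer convention:

```
import torch

emb = SinusoidalPosEmb(dim=128)
t = torch.tensor([0.0, 10.0, 100.0])  # e.g. diffusion timesteps
print(emb(t).shape)  # torch.Size([3, 128]) -- 64 sine + 64 cosine channels
```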
@ -0,0 +1,971 @@
"""
A collection of utilities for working with nested tensor structures consisting
of numpy arrays and torch tensors.
"""
import collections

import numpy as np
import torch


def recursive_dict_list_tuple_apply(x, type_func_dict):
    """
    Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
    {data_type: function_to_apply}.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        type_func_dict (dict): a mapping from data types to the functions to be
            applied for each data type.

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    assert list not in type_func_dict
    assert tuple not in type_func_dict
    assert dict not in type_func_dict

    if isinstance(x, (dict, collections.OrderedDict)):
        new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {}
        for k, v in x.items():
            new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
        return new_x
    elif isinstance(x, (list, tuple)):
        ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
        if isinstance(x, tuple):
            ret = tuple(ret)
        return ret
    else:
        for t, f in type_func_dict.items():
            if isinstance(x, t):
                return f(x)
        else:
            raise NotImplementedError("Cannot handle data type %s" % str(type(x)))


def map_tensor(x, func):
    """
    Apply function @func to torch.Tensor objects in a nested dictionary or
    list or tuple.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        func (function): function to apply to each tensor

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: func,
            type(None): lambda x: x,
        },
    )

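# For illustration (comment added for this write-up, not in the original file):
# map_tensor transforms only the tensors in a nested structure and passes None
# through, e.g.
#
#   batch = {"obs": {"image": torch.zeros(2, 3), "mask": None}, "ids": [torch.ones(2)]}
#   doubled = map_tensor(batch, lambda t: t * 2)
#   # `doubled` keeps the exact nesting; only the two tensors were scaled.
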
def map_ndarray(x, func):
    """
    Apply function @func to np.ndarray objects in a nested dictionary or
    list or tuple.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        func (function): function to apply to each array

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            np.ndarray: func,
            type(None): lambda x: x,
        },
    )


def map_tensor_ndarray(x, tensor_func, ndarray_func):
    """
    Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
    np.ndarray objects in a nested dictionary or list or tuple.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        tensor_func (function): function to apply to each tensor
        ndarray_func (function): function to apply to each array

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: tensor_func,
            np.ndarray: ndarray_func,
            type(None): lambda x: x,
        },
    )


def clone(x):
    """
    Clones all torch tensors and numpy arrays in nested dictionary or list
    or tuple and returns a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x.clone(),
            np.ndarray: lambda x: x.copy(),
            type(None): lambda x: x,
        },
    )


def detach(x):
    """
    Detaches all torch tensors in nested dictionary or list
    or tuple and returns a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x.detach(),
        },
    )


def to_batch(x):
    """
    Introduces a leading batch dimension of 1 for all torch tensors and numpy
    arrays in nested dictionary or list or tuple and returns a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x[None, ...],
            np.ndarray: lambda x: x[None, ...],
            type(None): lambda x: x,
        },
    )


def to_sequence(x):
    """
    Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
    arrays in nested dictionary or list or tuple and returns a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x[:, None, ...],
            np.ndarray: lambda x: x[:, None, ...],
            type(None): lambda x: x,
        },
    )


def index_at_time(x, ind):
    """
    Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
    nested dictionary or list or tuple and returns a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        ind (int): index

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x[:, ind, ...],
            np.ndarray: lambda x: x[:, ind, ...],
            type(None): lambda x: x,
        },
    )


def unsqueeze(x, dim):
    """
    Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
    in nested dictionary or list or tuple and returns a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        dim (int): dimension

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x.unsqueeze(dim=dim),
            np.ndarray: lambda x: np.expand_dims(x, axis=dim),
            type(None): lambda x: x,
        },
    )


def contiguous(x):
    """
    Makes all torch tensors and numpy arrays contiguous in nested dictionary or
    list or tuple and returns a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x.contiguous(),
            np.ndarray: lambda x: np.ascontiguousarray(x),
            type(None): lambda x: x,
        },
    )


def to_device(x, device):
    """
    Sends all torch tensors in nested dictionary or list or tuple to device
    @device, and returns a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        device (torch.Device): device to send tensors to

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x, d=device: x.to(d),
            type(None): lambda x: x,
        },
    )


def to_tensor(x):
    """
    Converts all numpy arrays in nested dictionary or list or tuple to
    torch tensors (and leaves existing torch Tensors as-is), and returns
    a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x,
            np.ndarray: lambda x: torch.from_numpy(x),
            type(None): lambda x: x,
        },
    )


def to_numpy(x):
    """
    Converts all torch tensors in nested dictionary or list or tuple to
    numpy (and leaves existing numpy arrays as-is), and returns
    a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """

    def f(tensor):
        if tensor.is_cuda:
            return tensor.detach().cpu().numpy()
        else:
            return tensor.detach().numpy()

    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: f,
            np.ndarray: lambda x: x,
            type(None): lambda x: x,
        },
    )


def to_list(x):
    """
    Converts all torch tensors and numpy arrays in nested dictionary or list
    or tuple to a list, and returns a new nested structure. Useful for
    json encoding.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """

    def f(tensor):
        if tensor.is_cuda:
            return tensor.detach().cpu().numpy().tolist()
        else:
            return tensor.detach().numpy().tolist()

    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: f,
            np.ndarray: lambda x: x.tolist(),
            type(None): lambda x: x,
        },
    )


def to_float(x):
    """
    Converts all torch tensors and numpy arrays in nested dictionary or list
    or tuple to float type entries, and returns a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x.float(),
            np.ndarray: lambda x: x.astype(np.float32),
            type(None): lambda x: x,
        },
    )


def to_uint8(x):
    """
    Converts all torch tensors and numpy arrays in nested dictionary or list
    or tuple to uint8 type entries, and returns a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x.byte(),
            np.ndarray: lambda x: x.astype(np.uint8),
            type(None): lambda x: x,
        },
    )


def to_torch(x, device):
    """
    Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
    torch tensors on device @device and returns a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        device (torch.Device): device to send tensors to

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return to_device(to_float(to_tensor(x)), device)


def to_one_hot_single(tensor, num_class):
    """
    Convert tensor to one-hot representation, assuming a certain number of total class labels.

    Args:
        tensor (torch.Tensor): tensor containing integer labels
        num_class (int): number of classes

    Returns:
        x (torch.Tensor): tensor containing one-hot representation of labels
    """
    x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device)
    x.scatter_(-1, tensor.unsqueeze(-1), 1)
    return x


def to_one_hot(tensor, num_class):
    """
    Convert all tensors in nested dictionary or list or tuple to one-hot representation,
    assuming a certain number of total class labels.

    Args:
        tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
        num_class (int): number of classes

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc))


def flatten_single(x, begin_axis=1):
    """
    Flatten a tensor in all dimensions from @begin_axis onwards.

    Args:
        x (torch.Tensor): tensor to flatten
        begin_axis (int): which axis to flatten from

    Returns:
        y (torch.Tensor): flattened tensor
    """
    fixed_size = x.size()[:begin_axis]
    _s = list(fixed_size) + [-1]
    return x.reshape(*_s)


def flatten(x, begin_axis=1):
    """
    Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        begin_axis (int): which axis to flatten from

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
        },
    )


def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
    """
    Reshape selected dimensions in a tensor to a target dimension.

    Args:
        x (torch.Tensor): tensor to reshape
        begin_axis (int): begin dimension
        end_axis (int): end dimension
        target_dims (tuple or list): target shape for the range of dimensions
            (@begin_axis, @end_axis)

    Returns:
        y (torch.Tensor): reshaped tensor
    """
    assert begin_axis <= end_axis
    assert begin_axis >= 0
    assert end_axis < len(x.shape)
    assert isinstance(target_dims, (tuple, list))
    s = x.shape
    final_s = []
    for i in range(len(s)):
        if i == begin_axis:
            final_s.extend(target_dims)
        elif i < begin_axis or i > end_axis:
            final_s.append(s[i])
    return x.reshape(*final_s)


def reshape_dimensions(x, begin_axis, end_axis, target_dims):
    """
    Reshape selected dimensions for all tensors in nested dictionary or list or tuple
    to a target dimension.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        begin_axis (int): begin dimension
        end_axis (int): end dimension
        target_dims (tuple or list): target shape for the range of dimensions
            (@begin_axis, @end_axis)

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
                x, begin_axis=b, end_axis=e, target_dims=t
            ),
            np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
                x, begin_axis=b, end_axis=e, target_dims=t
            ),
            type(None): lambda x: x,
        },
    )


def join_dimensions(x, begin_axis, end_axis):
    """
    Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
    all tensors in nested dictionary or list or tuple.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        begin_axis (int): begin dimension
        end_axis (int): end dimension

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
                x, begin_axis=b, end_axis=e, target_dims=[-1]
            ),
            np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
                x, begin_axis=b, end_axis=e, target_dims=[-1]
            ),
            type(None): lambda x: x,
        },
    )


def expand_at_single(x, size, dim):
    """
    Expand a tensor at a single dimension @dim by @size

    Args:
        x (torch.Tensor): input tensor
        size (int): size to expand
        dim (int): dimension to expand

    Returns:
        y (torch.Tensor): expanded tensor
    """
    assert dim < x.ndimension()
    assert x.shape[dim] == 1
    expand_dims = [-1] * x.ndimension()
    expand_dims[dim] = size
    return x.expand(*expand_dims)


def expand_at(x, size, dim):
    """
    Expand all tensors in nested dictionary or list or tuple at a single
    dimension @dim by @size.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        size (int): size to expand
        dim (int): dimension to expand

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))


def unsqueeze_expand_at(x, size, dim):
    """
    Unsqueeze and expand a tensor at a dimension @dim by @size.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        size (int): size to expand
        dim (int): dimension to unsqueeze and expand

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    x = unsqueeze(x, dim)
    return expand_at(x, size, dim)


def repeat_by_expand_at(x, repeats, dim):
    """
    Repeat a dimension by combining expand and reshape operations.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        repeats (int): number of times to repeat the target dimension
        dim (int): dimension to repeat on

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    x = unsqueeze_expand_at(x, repeats, dim + 1)
    return join_dimensions(x, dim, dim + 1)


def named_reduce_single(x, reduction, dim):
    """
    Reduce tensor at a dimension by named reduction functions.

    Args:
        x (torch.Tensor): tensor to be reduced
        reduction (str): one of ["sum", "max", "mean", "flatten"]
        dim (int): dimension to be reduced (or begin axis for flatten)

    Returns:
        y (torch.Tensor): reduced tensor
    """
    assert x.ndimension() > dim
    assert reduction in ["sum", "max", "mean", "flatten"]
    if reduction == "flatten":
        x = flatten(x, begin_axis=dim)
    elif reduction == "max":
        x = torch.max(x, dim=dim)[0]  # [B, D]
    elif reduction == "sum":
        x = torch.sum(x, dim=dim)
    else:
        x = torch.mean(x, dim=dim)
    return x


def named_reduce(x, reduction, dim):
    """
    Reduces all tensors in nested dictionary or list or tuple at a dimension
    using a named reduction function.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        reduction (str): one of ["sum", "max", "mean", "flatten"]
        dim (int): dimension to be reduced (or begin axis for flatten)

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))


def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
    """
    This function indexes out a target dimension of a tensor in a structured way,
    by allowing a different value to be selected for each member of a flat index
    tensor (@indices) corresponding to a source dimension. This can be interpreted
    as moving along the source dimension, using the corresponding index value
    in @indices to select values for all other dimensions outside of the
    source and target dimensions. A common use case is to gather values
    in target dimension 1 for each batch member (target dimension 0).

    Args:
        x (torch.Tensor): tensor to gather values for
        target_dim (int): dimension to gather values along
        source_dim (int): dimension to hold constant and use for gathering values
            from the other dimensions
        indices (torch.Tensor): flat index tensor with same shape as tensor @x along
            @source_dim

    Returns:
        y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
    """
    assert len(indices.shape) == 1
    assert x.shape[source_dim] == indices.shape[0]

    # unsqueeze in all dimensions except the source dimension
    new_shape = [1] * x.ndimension()
    new_shape[source_dim] = -1
    indices = indices.reshape(*new_shape)

    # repeat in all dimensions - but preserve shape of source dimension,
    # and make sure target_dimension has singleton dimension
    expand_shape = list(x.shape)
    expand_shape[source_dim] = -1
    expand_shape[target_dim] = 1
    indices = indices.expand(*expand_shape)

    out = x.gather(dim=target_dim, index=indices)
    return out.squeeze(target_dim)

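# Worked example (comment added for this write-up, not in the original file):
# with x of shape [B, T, D], target_dim=1, source_dim=0 and indices of shape [B],
# the helper above selects one timestep per batch element:
#
#   x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
#   inds = torch.tensor([0, 2])
#   y = gather_along_dim_with_dim_single(x, target_dim=1, source_dim=0, indices=inds)
#   # y.shape == (2, 4); y[0] equals x[0, 0], y[1] equals x[1, 2]
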
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
|
||||||
|
"""
|
||||||
|
Apply @gather_along_dim_with_dim_single to all tensors in a nested
|
||||||
|
dictionary or list or tuple.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||||
|
target_dim (int): dimension to gather values along
|
||||||
|
source_dim (int): dimension to hold constant and use for gathering values
|
||||||
|
from the other dimensions
|
||||||
|
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
|
||||||
|
@source_dim
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
y (dict or list or tuple): new nested dict-list-tuple
|
||||||
|
"""
|
||||||
|
return map_tensor(
|
||||||
|
x, lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def gather_sequence_single(seq, indices):
|
||||||
|
"""
|
||||||
|
Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
|
||||||
|
the batch given an index for each sequence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
|
||||||
|
indices (torch.Tensor): tensor indices of shape [B]
|
||||||
|
|
||||||
|
Return:
|
||||||
|
y (torch.Tensor): indexed tensor of shape [B, ....]
|
||||||
|
"""
|
||||||
|
return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices)
|
||||||
|
|
||||||
|
|
||||||
|
def gather_sequence(seq, indices):
|
||||||
|
"""
|
||||||
|
Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
|
||||||
|
for tensors with leading dimensions [B, T, ...].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
|
||||||
|
of leading dimensions [B, T, ...]
|
||||||
|
indices (torch.Tensor): tensor indices of shape [B]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
|
||||||
|
"""
|
||||||
|
return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices)
|
||||||
|
|
||||||
|
|
||||||
|
def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None):
    """
    Pad input tensor or array @seq in the time dimension (dimension 1).

    Args:
        seq (np.ndarray or torch.Tensor): sequence to be padded
        padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
        batched (bool): whether the sequence has a leading batch dimension
        pad_same (bool): whether to pad by duplicating the edge values
        pad_values (scalar or (ndarray, Tensor)): values to pad with if not pad_same

    Returns:
        padded sequence (np.ndarray or torch.Tensor)
    """
    assert isinstance(seq, (np.ndarray, torch.Tensor))
    assert pad_same or pad_values is not None
    if pad_values is not None:
        assert isinstance(pad_values, float)
    repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave
    concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
    ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like
    seq_dim = 1 if batched else 0

    begin_pad = []
    end_pad = []

    if padding[0] > 0:
        pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
        begin_pad.append(repeat_func(pad, padding[0], seq_dim))
    if padding[1] > 0:
        pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
        end_pad.append(repeat_func(pad, padding[1], seq_dim))

    return concat_func(begin_pad + [seq] + end_pad, seq_dim)
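A quick sketch of the edge-padding behavior, assuming an unbatched `(T, D)` array (illustrative values only):

```python
import numpy as np

seq = np.arange(6, dtype=np.float32).reshape(3, 2)  # T=3, D=2
padded = pad_sequence_single(seq, padding=(1, 2), pad_same=True)
assert padded.shape == (6, 2)  # one frame prepended, two appended
# padded[0] duplicates seq[0]; padded[-1] and padded[-2] duplicate seq[-1]
```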
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
    """
    Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).

    Args:
        seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
            of leading dimensions [B, T, ...]
        padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
        batched (bool): whether the sequences have a leading batch dimension
        pad_same (bool): whether to pad by duplicating the edge values
        pad_values (scalar or (ndarray, Tensor)): values to pad with if not pad_same

    Returns:
        padded sequence (dict or list or tuple)
    """
    return recursive_dict_list_tuple_apply(
        seq,
        {
            torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
                x, p, b, ps, pv
            ),
            np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
                x, p, b, ps, pv
            ),
            type(None): lambda x: x,
        },
    )


def assert_size_at_dim_single(x, size, dim, msg):
    """
    Ensure that array or tensor @x has size @size in dim @dim.

    Args:
        x (np.ndarray or torch.Tensor): input array or tensor
        size (int): size that tensors should have at @dim
        dim (int): dimension to check
        msg (str): text to display if the assertion fails
    """
    assert x.shape[dim] == size, msg


def assert_size_at_dim(x, size, dim, msg):
    """
    Ensure that arrays and tensors in nested dictionary or list or tuple have
    size @size in dim @dim.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        size (int): size that tensors should have at @dim
        dim (int): dimension to check
        msg (str): text to display if the assertion fails
    """
    map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
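As a usage sketch (hypothetical batch dict), the nested variant checks every tensor at once:

```python
import torch

batch = {"obs": torch.zeros(32, 3, 96, 96), "action": torch.zeros(32, 2)}
assert_size_at_dim(batch, size=32, dim=0, msg="unexpected batch size")
```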
def get_shape(x):
    """
    Get all shapes of arrays and tensors in nested dictionary or list or tuple.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple that contains each array or
            tensor's shape
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x.shape,
            np.ndarray: lambda x: x.shape,
            type(None): lambda x: x,
        },
    )


def list_of_flat_dict_to_dict_of_list(list_of_dict):
    """
    Helper function to go from a list of flat dictionaries to a dictionary of lists.
    By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
    floats, etc.

    Args:
        list_of_dict (list): list of flat dictionaries

    Returns:
        dict_of_list (dict): dictionary of lists
    """
    assert isinstance(list_of_dict, list)
    dic = collections.OrderedDict()
    for i in range(len(list_of_dict)):
        for k in list_of_dict[i]:
            if k not in dic:
                dic[k] = []
            dic[k].append(list_of_dict[i][k])
    return dic
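For example (hypothetical per-step logs), the conversion turns a list of step dicts into per-key lists:

```python
steps = [{"reward": 1.0, "done": False}, {"reward": 0.5, "done": True}]
print(list_of_flat_dict_to_dict_of_list(steps))
# OrderedDict([('reward', [1.0, 0.5]), ('done', [False, True])])
```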
def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""):
    """
    Flatten a nested dict or list to a list.

    For example, given a dict
    {
        a: 1
        b: {
            c: 2
        }
        c: 3
    }

    the function would return [(a, 1), (b_c, 2), (c, 3)]

    Args:
        d (dict, list): a nested dict or list to be flattened
        parent_key (str): recursion helper
        sep (str): separator for nesting keys
        item_key (str): recursion helper

    Returns:
        list: a list of (key, value) tuples
    """
    items = []
    if isinstance(d, (tuple, list)):
        new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
        for i, v in enumerate(d):
            items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
        return items
    elif isinstance(d, dict):
        new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
        for k, v in d.items():
            assert isinstance(k, str)
            items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
        return items
    else:
        new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
        return [(new_key, d)]
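Mirroring the docstring example, a quick sketch of the flattened output:

```python
d = {"a": 1, "b": {"c": 2}, "c": 3}
print(flatten_nested_dict_list(d))
# [('a', 1), ('b_c', 2), ('c', 3)]
```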
def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs):
    """
    Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
    batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
    Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
    outputs to [B, T, ...].

    Args:
        inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
            of leading dimensions [B, T, ...]
        op: a layer op that accepts inputs
        activation: activation to apply at the output
        inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
        inputs_as_args (bool): whether to feed input as an args list to the op
        kwargs (dict): other kwargs to supply to the op

    Returns:
        outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
    """
    batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
    inputs = join_dimensions(inputs, 0, 1)
    if inputs_as_kwargs:
        outputs = op(**inputs, **kwargs)
    elif inputs_as_args:
        outputs = op(*inputs, **kwargs)
    else:
        outputs = op(inputs, **kwargs)

    if activation is not None:
        outputs = map_tensor(outputs, activation)
    outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len))
    return outputs
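A minimal sketch of applying one op across batch and time (assumes `join_dimensions` and `reshape_dimensions` from earlier in this module; sizes are illustrative):

```python
import torch
import torch.nn as nn

op = nn.Linear(8, 16)
x = torch.randn(2, 5, 8)  # B=2, T=5, D=8
y = time_distributed(x, op)  # internally reshapes to (10, 8) and back
assert y.shape == (2, 5, 16)
```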
@@ -4,10 +4,10 @@ import time
 import hydra
 import torch
 import torch.nn as nn
-from diffusion_policy.model.common.lr_scheduler import get_scheduler
-from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
-from .multi_image_obs_encoder import MultiImageObsEncoder
+from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
+from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
+from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder

 class DiffusionPolicy(nn.Module):
@@ -0,0 +1,76 @@
from typing import Callable, Dict

import torch
import torch.nn as nn
import torchvision


def get_resnet(name, weights=None, **kwargs):
    """
    name: resnet18, resnet34, resnet50
    weights: "IMAGENET1K_V1", "r3m"
    """
    # load r3m weights
    if (weights == "r3m") or (weights == "R3M"):
        return get_r3m(name=name, **kwargs)

    func = getattr(torchvision.models, name)
    resnet = func(weights=weights, **kwargs)
    # drop the classification head; the encoder outputs pooled features
    resnet.fc = torch.nn.Identity()
    return resnet


def get_r3m(name, **kwargs):
    """
    name: resnet18, resnet34, resnet50
    """
    import r3m

    r3m.device = "cpu"
    model = r3m.load_r3m(name)
    r3m_model = model.module
    resnet_model = r3m_model.convnet
    resnet_model = resnet_model.to("cpu")
    return resnet_model


def dict_apply(
    x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor]
) -> Dict[str, torch.Tensor]:
    result = {}
    for key, value in x.items():
        if isinstance(value, dict):
            result[key] = dict_apply(value, func)
        else:
            result[key] = func(value)
    return result


def replace_submodules(
    root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
    """
    predicate: Return true if the module is to be replaced.
    func: Return new module to use.
    """
    if predicate(root_module):
        return func(root_module)

    bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
    for *parent, k in bn_list:
        parent_module = root_module
        if len(parent) > 0:
            parent_module = root_module.get_submodule(".".join(parent))
        if isinstance(parent_module, nn.Sequential):
            src_module = parent_module[int(k)]
        else:
            src_module = getattr(parent_module, k)
        tgt_module = func(src_module)
        if isinstance(parent_module, nn.Sequential):
            parent_module[int(k)] = tgt_module
        else:
            setattr(parent_module, k, tgt_module)
    # verify that all modules selected by the predicate have been replaced
    bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
    assert len(bn_list) == 0
    return root_module
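A usage sketch (not from the diff): `replace_submodules` is the usual tool for swapping normalization layers in a vision encoder, e.g. replacing every `BatchNorm2d` with `GroupNorm`; the group size here is an illustrative choice:

```python
import torch.nn as nn
import torchvision

resnet = torchvision.models.resnet18(weights=None)
resnet = replace_submodules(
    root_module=resnet,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    # 16 channels per group is an assumption for illustration
    func=lambda m: nn.GroupNorm(m.num_features // 16, m.num_features),
)
```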
@@ -0,0 +1,614 @@
from __future__ import annotations

import math
import numbers
import os
from functools import cached_property

import numcodecs
import numpy as np
import zarr


def check_chunks_compatible(chunks: tuple, shape: tuple):
    assert len(shape) == len(chunks)
    for c in chunks:
        assert isinstance(c, numbers.Integral)
        assert c > 0


def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
    old_arr = group[name]
    if chunks is None:
        chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
    check_chunks_compatible(chunks, old_arr.shape)

    if compressor is None:
        compressor = old_arr.compressor

    if (chunks == old_arr.chunks) and (compressor == old_arr.compressor):
        # no change
        return old_arr

    # rechunk and recompress: move the array aside, then copy it back
    # with the new chunks/compressor
    group.move(name, tmp_key)
    old_arr = group[tmp_key]
    n_copied, n_skipped, n_bytes_copied = zarr.copy(
        source=old_arr,
        dest=group,
        name=name,
        chunks=chunks,
        compressor=compressor,
    )
    del group[tmp_key]
    arr = group[name]
    return arr
def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=None):
    """
    Common shapes
    T,D
    T,N,D
    T,H,W,C
    T,N,H,W,C
    """
    itemsize = np.dtype(dtype).itemsize
    # reversed shape
    rshape = list(shape[::-1])
    if max_chunk_length is not None:
        rshape[-1] = int(max_chunk_length)
    split_idx = len(shape) - 1
    for i in range(len(shape) - 1):
        this_chunk_bytes = itemsize * np.prod(rshape[:i])
        next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
        if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
            split_idx = i

    rchunks = rshape[:split_idx]
    item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
    this_max_chunk_length = rshape[split_idx]
    next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
    rchunks.append(next_chunk_length)
    len_diff = len(shape) - len(rchunks)
    rchunks.extend([1] * len_diff)
    chunks = tuple(rchunks[::-1])
    # print(np.prod(chunks) * itemsize / target_chunk_bytes)
    return chunks
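To make the heuristic concrete, a worked example with assumed shapes: for 1000 RGB frames of 96x96 uint8, one frame is 96 * 96 * 3 = 27,648 bytes, so roughly 2e6 / 27,648 ≈ 73 frames fit in the ~2 MB target:

```python
import numpy as np

chunks = get_optimal_chunks(shape=(1000, 96, 96, 3), dtype=np.uint8)
print(chunks)  # (73, 96, 96, 3): each chunk holds 73 whole frames, ~2 MB
```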
class ReplayBuffer:
    """
    Zarr-based temporal data structure.
    Assumes the first dimension is time; arrays are chunked only along the time dimension.
    """

    def __init__(self, root: zarr.Group | dict[str, dict]):
        """
        Dummy constructor. Use copy_from* and create_from* class methods instead.
        """
        assert "data" in root
        assert "meta" in root
        assert "episode_ends" in root["meta"]
        for value in root["data"].values():
            assert value.shape[0] == root["meta"]["episode_ends"][-1]
        self.root = root

    # ============= create constructors ===============
    @classmethod
    def create_empty_zarr(cls, storage=None, root=None):
        if root is None:
            if storage is None:
                storage = zarr.MemoryStore()
            root = zarr.group(store=storage)
        root.require_group("data", overwrite=False)
        meta = root.require_group("meta", overwrite=False)
        if "episode_ends" not in meta:
            meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
        return cls(root=root)

    @classmethod
    def create_empty_numpy(cls):
        root = {"data": {}, "meta": {"episode_ends": np.zeros((0,), dtype=np.int64)}}
        return cls(root=root)

    @classmethod
    def create_from_group(cls, group, **kwargs):
        if "data" not in group:
            # create from scratch
            buffer = cls.create_empty_zarr(root=group, **kwargs)
        else:
            # already exists
            buffer = cls(root=group, **kwargs)
        return buffer

    @classmethod
    def create_from_path(cls, zarr_path, mode="r", **kwargs):
        """
        Open an on-disk zarr directly (for datasets larger than memory).
        Slower.
        """
        group = zarr.open(os.path.expanduser(zarr_path), mode)
        return cls.create_from_group(group, **kwargs)
    # ============= copy constructors ===============
    @classmethod
    def copy_from_store(
        cls,
        src_store,
        store=None,
        keys=None,
        chunks: dict[str, tuple] | None = None,
        compressors: dict | str | numcodecs.abc.Codec | None = None,
        if_exists="replace",
        **kwargs,
    ):
        """
        Load to memory.
        """
        src_root = zarr.group(src_store)
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        root = None
        if store is None:
            # numpy backend
            meta = {}
            for key, value in src_root["meta"].items():
                if len(value.shape) == 0:
                    meta[key] = np.array(value)
                else:
                    meta[key] = value[:]

            if keys is None:
                keys = src_root["data"].keys()
            data = {}
            for key in keys:
                arr = src_root["data"][key]
                data[key] = arr[:]

            root = {"meta": meta, "data": data}
        else:
            root = zarr.group(store=store)
            # copy without recompression
            n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
            )
            data_group = root.create_group("data", overwrite=True)
            if keys is None:
                keys = src_root["data"].keys()
            for key in keys:
                value = src_root["data"][key]
                cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
                cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
                if cks == value.chunks and cpr == value.compressor:
                    # copy without recompression
                    this_path = "/data/" + key
                    n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                        source=src_store,
                        dest=store,
                        source_path=this_path,
                        dest_path=this_path,
                        if_exists=if_exists,
                    )
                else:
                    # copy with recompression
                    n_copied, n_skipped, n_bytes_copied = zarr.copy(
                        source=value,
                        dest=data_group,
                        name=key,
                        chunks=cks,
                        compressor=cpr,
                        if_exists=if_exists,
                    )
        buffer = cls(root=root)
        return buffer

    @classmethod
    def copy_from_path(
        cls,
        zarr_path,
        backend=None,
        store=None,
        keys=None,
        chunks: dict[str, tuple] | None = None,
        compressors: dict | str | numcodecs.abc.Codec | None = None,
        if_exists="replace",
        **kwargs,
    ):
        """
        Copy an on-disk zarr into an in-memory compressed store.
        Recommended.
        """
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        if backend == "numpy":
            print("backend argument is deprecated!")
            store = None
        group = zarr.open(os.path.expanduser(zarr_path), "r")
        return cls.copy_from_store(
            src_store=group.store,
            store=store,
            keys=keys,
            chunks=chunks,
            compressors=compressors,
            if_exists=if_exists,
            **kwargs,
        )
    # ============= save methods ===============
    def save_to_store(
        self,
        store,
        chunks: dict[str, tuple] | None = None,
        compressors: str | numcodecs.abc.Codec | dict | None = None,
        if_exists="replace",
        **kwargs,
    ):
        root = zarr.group(store)
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        if self.backend == "zarr":
            # recompression-free copy
            n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                source=self.root.store,
                dest=store,
                source_path="/meta",
                dest_path="/meta",
                if_exists=if_exists,
            )
        else:
            meta_group = root.create_group("meta", overwrite=True)
            # save meta, no chunking
            for key, value in self.root["meta"].items():
                _ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)

        # save data, chunked
        data_group = root.create_group("data", overwrite=True)
        for key, value in self.root["data"].items():
            cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
            cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
            if isinstance(value, zarr.Array):
                if cks == value.chunks and cpr == value.compressor:
                    # copy without recompression
                    this_path = "/data/" + key
                    n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                        source=self.root.store,
                        dest=store,
                        source_path=this_path,
                        dest_path=this_path,
                        if_exists=if_exists,
                    )
                else:
                    # copy with recompression
                    n_copied, n_skipped, n_bytes_copied = zarr.copy(
                        source=value,
                        dest=data_group,
                        name=key,
                        chunks=cks,
                        compressor=cpr,
                        if_exists=if_exists,
                    )
            else:
                # numpy
                _ = data_group.array(name=key, data=value, chunks=cks, compressor=cpr)
        return store

    def save_to_path(
        self,
        zarr_path,
        chunks: dict[str, tuple] | None = None,
        compressors: str | numcodecs.abc.Codec | dict | None = None,
        if_exists="replace",
        **kwargs,
    ):
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        store = zarr.DirectoryStore(os.path.expanduser(zarr_path))
        return self.save_to_store(
            store, chunks=chunks, compressors=compressors, if_exists=if_exists, **kwargs
        )
    @staticmethod
    def resolve_compressor(compressor="default"):
        if compressor == "default":
            compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
        elif compressor == "disk":
            compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
        return compressor

    @classmethod
    def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
        # allows compressor to be explicitly set to None
        cpr = "nil"
        if isinstance(compressors, dict):
            if key in compressors:
                cpr = cls.resolve_compressor(compressors[key])
            elif isinstance(array, zarr.Array):
                cpr = array.compressor
        else:
            cpr = cls.resolve_compressor(compressors)
        # backup default
        if cpr == "nil":
            cpr = cls.resolve_compressor("default")
        return cpr

    @classmethod
    def _resolve_array_chunks(cls, chunks: dict | tuple, key, array):
        cks = None
        if isinstance(chunks, dict):
            if key in chunks:
                cks = chunks[key]
            elif isinstance(array, zarr.Array):
                cks = array.chunks
        elif isinstance(chunks, tuple):
            cks = chunks
        else:
            raise TypeError(f"Unsupported chunks type {type(chunks)}")
        # backup default
        if cks is None:
            cks = get_optimal_chunks(shape=array.shape, dtype=array.dtype)
        # check
        check_chunks_compatible(chunks=cks, shape=array.shape)
        return cks
    # ============= properties =================
    @cached_property
    def data(self):
        return self.root["data"]

    @cached_property
    def meta(self):
        return self.root["meta"]

    def update_meta(self, data):
        # sanitize data
        np_data = {}
        for key, value in data.items():
            if isinstance(value, np.ndarray):
                np_data[key] = value
            else:
                arr = np.array(value)
                if arr.dtype == object:
                    raise TypeError(f"Invalid value type {type(value)}")
                np_data[key] = arr

        meta_group = self.meta
        if self.backend == "zarr":
            for key, value in np_data.items():
                _ = meta_group.array(
                    name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
                )
        else:
            meta_group.update(np_data)

        return meta_group
    @property
    def episode_ends(self):
        return self.meta["episode_ends"]

    def get_episode_idxs(self):
        import numba

        @numba.jit(nopython=True)
        def _get_episode_idxs(episode_ends):
            result = np.zeros((episode_ends[-1],), dtype=np.int64)
            for i in range(len(episode_ends)):
                start = 0
                if i > 0:
                    start = episode_ends[i - 1]
                end = episode_ends[i]
                for idx in range(start, end):
                    result[idx] = i
            return result

        return _get_episode_idxs(self.episode_ends)
    @property
    def backend(self):
        backend = "numpy"
        if isinstance(self.root, zarr.Group):
            backend = "zarr"
        return backend

    # =========== dict-like API ==============
    def __repr__(self) -> str:
        if self.backend == "zarr":
            return str(self.root.tree())
        else:
            return super().__repr__()

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()

    def __getitem__(self, key):
        return self.data[key]

    def __contains__(self, key):
        return key in self.data

    # =========== our API ==============
    @property
    def n_steps(self):
        if len(self.episode_ends) == 0:
            return 0
        return self.episode_ends[-1]

    @property
    def n_episodes(self):
        return len(self.episode_ends)

    @property
    def chunk_size(self):
        if self.backend == "zarr":
            return next(iter(self.data.arrays()))[-1].chunks[0]
        return None

    @property
    def episode_lengths(self):
        ends = self.episode_ends[:]
        ends = np.insert(ends, 0, 0)
        lengths = np.diff(ends)
        return lengths
    def add_episode(
        self,
        data: dict[str, np.ndarray],
        chunks: dict[str, tuple] | None = None,
        compressors: str | numcodecs.abc.Codec | dict | None = None,
    ):
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        assert len(data) > 0
        is_zarr = self.backend == "zarr"

        curr_len = self.n_steps
        episode_length = None
        for value in data.values():
            assert len(value.shape) >= 1
            if episode_length is None:
                episode_length = len(value)
            else:
                assert episode_length == len(value)
        new_len = curr_len + episode_length

        for key, value in data.items():
            new_shape = (new_len,) + value.shape[1:]
            # create array
            if key not in self.data:
                if is_zarr:
                    cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
                    cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
                    arr = self.data.zeros(
                        name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
                    )
                else:
                    # allocate a fresh array to avoid keeping a reference to the input
                    arr = np.zeros(shape=new_shape, dtype=value.dtype)
                    self.data[key] = arr
            else:
                arr = self.data[key]
                assert value.shape[1:] == arr.shape[1:]
                # same method for both zarr and numpy
                if is_zarr:
                    arr.resize(new_shape)
                else:
                    arr.resize(new_shape, refcheck=False)
            # copy data
            arr[-value.shape[0] :] = value

        # append to episode ends
        episode_ends = self.episode_ends
        if is_zarr:
            episode_ends.resize(episode_ends.shape[0] + 1)
        else:
            episode_ends.resize(episode_ends.shape[0] + 1, refcheck=False)
        episode_ends[-1] = new_len

        # rechunk
        if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
            rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))
    def drop_episode(self):
        is_zarr = self.backend == "zarr"
        episode_ends = self.episode_ends[:].copy()
        assert len(episode_ends) > 0
        start_idx = 0
        if len(episode_ends) > 1:
            start_idx = episode_ends[-2]
        for value in self.data.values():
            new_shape = (start_idx,) + value.shape[1:]
            if is_zarr:
                value.resize(new_shape)
            else:
                value.resize(new_shape, refcheck=False)
        if is_zarr:
            self.episode_ends.resize(len(episode_ends) - 1)
        else:
            self.episode_ends.resize(len(episode_ends) - 1, refcheck=False)

    def pop_episode(self):
        assert self.n_episodes > 0
        episode = self.get_episode(self.n_episodes - 1, copy=True)
        self.drop_episode()
        return episode

    def extend(self, data):
        self.add_episode(data)

    def get_episode(self, idx, copy=False):
        idx = list(range(len(self.episode_ends)))[idx]
        start_idx = 0
        if idx > 0:
            start_idx = self.episode_ends[idx - 1]
        end_idx = self.episode_ends[idx]
        result = self.get_steps_slice(start_idx, end_idx, copy=copy)
        return result

    def get_episode_slice(self, idx):
        start_idx = 0
        if idx > 0:
            start_idx = self.episode_ends[idx - 1]
        end_idx = self.episode_ends[idx]
        return slice(start_idx, end_idx)

    def get_steps_slice(self, start, stop, step=None, copy=False):
        _slice = slice(start, stop, step)

        result = {}
        for key, value in self.data.items():
            x = value[_slice]
            if copy and isinstance(value, np.ndarray):
                x = x.copy()
            result[key] = x
        return result
    # =========== chunking =============
    def get_chunks(self) -> dict:
        assert self.backend == "zarr"
        chunks = {}
        for key, value in self.data.items():
            chunks[key] = value.chunks
        return chunks

    def set_chunks(self, chunks: dict):
        assert self.backend == "zarr"
        for key, value in chunks.items():
            if key in self.data:
                arr = self.data[key]
                if value != arr.chunks:
                    check_chunks_compatible(chunks=value, shape=arr.shape)
                    rechunk_recompress_array(self.data, key, chunks=value)

    def get_compressors(self) -> dict:
        assert self.backend == "zarr"
        compressors = {}
        for key, value in self.data.items():
            compressors[key] = value.compressor
        return compressors

    def set_compressors(self, compressors: dict):
        assert self.backend == "zarr"
        for key, value in compressors.items():
            if key in self.data:
                arr = self.data[key]
                compressor = self.resolve_compressor(value)
                if compressor != arr.compressor:
                    rechunk_recompress_array(self.data, key, compressor=compressor)
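To make the intended workflow concrete, a minimal sketch (not part of the diff; shapes are illustrative) of building an in-memory buffer and reading an episode back:

```python
import numpy as np

buffer = ReplayBuffer.create_empty_numpy()
for _ in range(2):
    episode = {
        "obs": np.random.rand(10, 5).astype(np.float32),     # T=10 steps
        "action": np.random.rand(10, 2).astype(np.float32),
    }
    buffer.add_episode(episode)

assert buffer.n_episodes == 2 and buffer.n_steps == 20
first = buffer.get_episode(0)  # dict of arrays, each with leading dim 10
```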
@@ -1,6 +1,6 @@
 def make_policy(cfg):
     if cfg.policy.name == "tdmpc":
-        from lerobot.common.policies.tdmpc import TDMPC
+        from lerobot.common.policies.tdmpc.policy import TDMPC

         policy = TDMPC(cfg.policy, cfg.device)
     elif cfg.policy.name == "diffusion":
@@ -8,7 +8,7 @@ import numpy as np
 import torch
 import torch.nn as nn

-import lerobot.common.policies.tdmpc_helper as h
+import lerobot.common.policies.tdmpc.helper as h

 FIRST_FRAME = 0
@@ -74,7 +74,6 @@ noise_scheduler:
   prediction_type: epsilon # or sample

 obs_encoder:
-  # _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
   shape_meta: ${shape_meta}
   # resize_shape: null
   # crop_shape: [76, 76]
@@ -85,12 +84,12 @@ obs_encoder:
   imagenet_norm: True

   rgb_model:
-    _target_: diffusion_policy.model.vision.model_getter.get_resnet
+    _target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet
     name: resnet18
     weights: null

 ema:
-  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
+  _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
   update_after_step: 0
   inv_gamma: 1.0
   power: 0.75
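For context, a sketch of how the relocated `_target_` resolves at runtime (assuming Hydra's standard `instantiate` API; the inline dict stands in for the loaded config):

```python
import hydra.utils

rgb_model = hydra.utils.instantiate(
    {
        "_target_": "lerobot.common.policies.diffusion.pytorch_utils.get_resnet",
        "name": "resnet18",
        "weights": None,
    }
)
```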
@@ -477,21 +477,6 @@ test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisi
 torch = ["accelerate (>=0.11.0)", "torch (>=1.4,<2.2.0)"]
 training = ["Jinja2", "accelerate (>=0.11.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"]

-[[package]]
-name = "diffusion_policy"
-version = "0.0.0"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.source]
-type = "git"
-url = "https://github.com/real-stanford/diffusion_policy"
-reference = "HEAD"
-resolved_reference = "548a52bbb105518058e27bf34dcf90bf6f73681a"
-
 [[package]]
 name = "distlib"
 version = "0.3.8"
@@ -3140,4 +3125,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "9c3e86956dd11bc8d7823e5e6c5e74a073051b495f71f96179113d99791f7ca0"
+content-hash = "c4d83579aed1c8c2e54cad7c8ec81b95a09ab8faff74fc9a4cb20bd00e4ddec6"
@@ -45,7 +45,6 @@ mujoco = "^3.1.2"
 mujoco-py = "^2.1.2.14"
 gym = "^0.26.2"
 opencv-python = "^4.9.0.80"
-diffusion-policy = {git = "https://github.com/real-stanford/diffusion_policy"}
 diffusers = "^0.26.3"
 torchvision = "^0.17.1"
 h5py = "^3.10.0"
@@ -3,7 +3,7 @@ from tensordict import TensorDict
 from torchrl.envs.utils import check_env_specs, step_mdp

 from lerobot.common.envs.factory import make_env
-from lerobot.common.envs.pusht import PushtEnv
+from lerobot.common.envs.pusht.env import PushtEnv
 from lerobot.common.envs.simxarm import SimxarmEnv

 from .utils import init_config