Small fix, Refactor diffusion, Diffusion runs (TODO: remove normalization in diffusion)
This commit is contained in:
parent
45b4ecb727
commit
80785f8d0e
|
@ -0,0 +1,246 @@
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
from einops import reduce
|
||||||
|
|
||||||
|
from diffusion_policy.common.pytorch_util import dict_apply
|
||||||
|
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
|
||||||
|
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
|
||||||
|
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
|
||||||
|
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionUnetImagePolicy(BaseImagePolicy):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
shape_meta: dict,
|
||||||
|
noise_scheduler: DDPMScheduler,
|
||||||
|
obs_encoder: MultiImageObsEncoder,
|
||||||
|
horizon,
|
||||||
|
n_action_steps,
|
||||||
|
n_obs_steps,
|
||||||
|
num_inference_steps=None,
|
||||||
|
obs_as_global_cond=True,
|
||||||
|
diffusion_step_embed_dim=256,
|
||||||
|
down_dims=(256, 512, 1024),
|
||||||
|
kernel_size=5,
|
||||||
|
n_groups=8,
|
||||||
|
cond_predict_scale=True,
|
||||||
|
# parameters passed to step
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# parse shapes
|
||||||
|
action_shape = shape_meta["action"]["shape"]
|
||||||
|
assert len(action_shape) == 1
|
||||||
|
action_dim = action_shape[0]
|
||||||
|
# get feature dim
|
||||||
|
obs_feature_dim = obs_encoder.output_shape()[0]
|
||||||
|
|
||||||
|
# create diffusion model
|
||||||
|
input_dim = action_dim + obs_feature_dim
|
||||||
|
global_cond_dim = None
|
||||||
|
if obs_as_global_cond:
|
||||||
|
input_dim = action_dim
|
||||||
|
global_cond_dim = obs_feature_dim * n_obs_steps
|
||||||
|
|
||||||
|
model = ConditionalUnet1D(
|
||||||
|
input_dim=input_dim,
|
||||||
|
local_cond_dim=None,
|
||||||
|
global_cond_dim=global_cond_dim,
|
||||||
|
diffusion_step_embed_dim=diffusion_step_embed_dim,
|
||||||
|
down_dims=down_dims,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
n_groups=n_groups,
|
||||||
|
cond_predict_scale=cond_predict_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.obs_encoder = obs_encoder
|
||||||
|
self.model = model
|
||||||
|
self.noise_scheduler = noise_scheduler
|
||||||
|
self.mask_generator = LowdimMaskGenerator(
|
||||||
|
action_dim=action_dim,
|
||||||
|
obs_dim=0 if obs_as_global_cond else obs_feature_dim,
|
||||||
|
max_n_obs_steps=n_obs_steps,
|
||||||
|
fix_obs_steps=True,
|
||||||
|
action_visible=False,
|
||||||
|
)
|
||||||
|
self.horizon = horizon
|
||||||
|
self.obs_feature_dim = obs_feature_dim
|
||||||
|
self.action_dim = action_dim
|
||||||
|
self.n_action_steps = n_action_steps
|
||||||
|
self.n_obs_steps = n_obs_steps
|
||||||
|
self.obs_as_global_cond = obs_as_global_cond
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
if num_inference_steps is None:
|
||||||
|
num_inference_steps = noise_scheduler.config.num_train_timesteps
|
||||||
|
self.num_inference_steps = num_inference_steps
|
||||||
|
|
||||||
|
# ========= inference ============
|
||||||
|
def conditional_sample(
|
||||||
|
self,
|
||||||
|
condition_data,
|
||||||
|
condition_mask,
|
||||||
|
local_cond=None,
|
||||||
|
global_cond=None,
|
||||||
|
generator=None,
|
||||||
|
# keyword arguments to scheduler.step
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
model = self.model
|
||||||
|
scheduler = self.noise_scheduler
|
||||||
|
|
||||||
|
trajectory = torch.randn(
|
||||||
|
size=condition_data.shape,
|
||||||
|
dtype=condition_data.dtype,
|
||||||
|
device=condition_data.device,
|
||||||
|
generator=generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# set step values
|
||||||
|
scheduler.set_timesteps(self.num_inference_steps)
|
||||||
|
|
||||||
|
for t in scheduler.timesteps:
|
||||||
|
# 1. apply conditioning
|
||||||
|
trajectory[condition_mask] = condition_data[condition_mask]
|
||||||
|
|
||||||
|
# 2. predict model output
|
||||||
|
model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond)
|
||||||
|
|
||||||
|
# 3. compute previous image: x_t -> x_t-1
|
||||||
|
trajectory = scheduler.step(
|
||||||
|
model_output,
|
||||||
|
t,
|
||||||
|
trajectory,
|
||||||
|
generator=generator,
|
||||||
|
# **kwargs # TODO(rcadene): in diffusion_policy, expected to be {}
|
||||||
|
).prev_sample
|
||||||
|
|
||||||
|
# finally make sure conditioning is enforced
|
||||||
|
trajectory[condition_mask] = condition_data[condition_mask]
|
||||||
|
|
||||||
|
return trajectory
|
||||||
|
|
||||||
|
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
obs_dict: must include "obs" key
|
||||||
|
result: must include "action" key
|
||||||
|
"""
|
||||||
|
assert "past_action" not in obs_dict # not implemented yet
|
||||||
|
nobs = obs_dict
|
||||||
|
value = next(iter(nobs.values()))
|
||||||
|
bsize, n_obs_steps = value.shape[:2]
|
||||||
|
horizon = self.horizon
|
||||||
|
action_dim = self.action_dim
|
||||||
|
obs_dim = self.obs_feature_dim
|
||||||
|
assert self.n_obs_steps == n_obs_steps
|
||||||
|
|
||||||
|
# build input
|
||||||
|
device = self.device
|
||||||
|
dtype = self.dtype
|
||||||
|
|
||||||
|
# handle different ways of passing observation
|
||||||
|
local_cond = None
|
||||||
|
global_cond = None
|
||||||
|
if self.obs_as_global_cond:
|
||||||
|
# condition through global feature
|
||||||
|
this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
|
||||||
|
nobs_features = self.obs_encoder(this_nobs)
|
||||||
|
# reshape back to B, Do
|
||||||
|
global_cond = nobs_features.reshape(bsize, -1)
|
||||||
|
# empty data for action
|
||||||
|
cond_data = torch.zeros(size=(bsize, horizon, action_dim), device=device, dtype=dtype)
|
||||||
|
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
# condition through impainting
|
||||||
|
this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
|
||||||
|
nobs_features = self.obs_encoder(this_nobs)
|
||||||
|
# reshape back to B, T, Do
|
||||||
|
nobs_features = nobs_features.reshape(bsize, n_obs_steps, -1)
|
||||||
|
cond_data = torch.zeros(size=(bsize, horizon, action_dim + obs_dim), device=device, dtype=dtype)
|
||||||
|
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||||
|
cond_data[:, :n_obs_steps, action_dim:] = nobs_features
|
||||||
|
cond_mask[:, :n_obs_steps, action_dim:] = True
|
||||||
|
|
||||||
|
# run sampling
|
||||||
|
nsample = self.conditional_sample(
|
||||||
|
cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
action_pred = nsample[..., :action_dim]
|
||||||
|
|
||||||
|
# get action
|
||||||
|
start = n_obs_steps - 1
|
||||||
|
end = start + self.n_action_steps
|
||||||
|
action = action_pred[:, start:end]
|
||||||
|
|
||||||
|
result = {"action": action, "action_pred": action_pred}
|
||||||
|
return result
|
||||||
|
|
||||||
|
def compute_loss(self, batch):
|
||||||
|
assert "valid_mask" not in batch
|
||||||
|
nobs = batch["obs"]
|
||||||
|
nactions = batch["action"]
|
||||||
|
batch_size = nactions.shape[0]
|
||||||
|
horizon = nactions.shape[1]
|
||||||
|
|
||||||
|
# handle different ways of passing observation
|
||||||
|
local_cond = None
|
||||||
|
global_cond = None
|
||||||
|
trajectory = nactions
|
||||||
|
cond_data = trajectory
|
||||||
|
if self.obs_as_global_cond:
|
||||||
|
# reshape B, T, ... to B*T
|
||||||
|
this_nobs = dict_apply(nobs, lambda x: x[:, : self.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
|
||||||
|
nobs_features = self.obs_encoder(this_nobs)
|
||||||
|
# reshape back to B, Do
|
||||||
|
global_cond = nobs_features.reshape(batch_size, -1)
|
||||||
|
else:
|
||||||
|
# reshape B, T, ... to B*T
|
||||||
|
this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
|
||||||
|
nobs_features = self.obs_encoder(this_nobs)
|
||||||
|
# reshape back to B, T, Do
|
||||||
|
nobs_features = nobs_features.reshape(batch_size, horizon, -1)
|
||||||
|
cond_data = torch.cat([nactions, nobs_features], dim=-1)
|
||||||
|
trajectory = cond_data.detach()
|
||||||
|
|
||||||
|
# generate impainting mask
|
||||||
|
condition_mask = self.mask_generator(trajectory.shape)
|
||||||
|
|
||||||
|
# Sample noise that we'll add to the images
|
||||||
|
noise = torch.randn(trajectory.shape, device=trajectory.device)
|
||||||
|
bsz = trajectory.shape[0]
|
||||||
|
# Sample a random timestep for each image
|
||||||
|
timesteps = torch.randint(
|
||||||
|
0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=trajectory.device
|
||||||
|
).long()
|
||||||
|
# Add noise to the clean images according to the noise magnitude at each timestep
|
||||||
|
# (this is the forward diffusion process)
|
||||||
|
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
|
||||||
|
|
||||||
|
# compute loss mask
|
||||||
|
loss_mask = ~condition_mask
|
||||||
|
|
||||||
|
# apply conditioning
|
||||||
|
noisy_trajectory[condition_mask] = cond_data[condition_mask]
|
||||||
|
|
||||||
|
# Predict the noise residual
|
||||||
|
pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
|
||||||
|
|
||||||
|
pred_type = self.noise_scheduler.config.prediction_type
|
||||||
|
if pred_type == "epsilon":
|
||||||
|
target = noise
|
||||||
|
elif pred_type == "sample":
|
||||||
|
target = trajectory
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported prediction type {pred_type}")
|
||||||
|
|
||||||
|
loss = F.mse_loss(pred, target, reduction="none")
|
||||||
|
loss = loss * loss_mask.type(loss.dtype)
|
||||||
|
loss = reduce(loss, "b ... -> b (...)", "mean")
|
||||||
|
loss = loss.mean()
|
||||||
|
return loss
|
|
@ -0,0 +1,189 @@
|
||||||
|
import copy
|
||||||
|
from typing import Dict, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
from diffusion_policy.common.pytorch_util import replace_submodules
|
||||||
|
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
||||||
|
from diffusion_policy.model.vision.crop_randomizer import CropRandomizer
|
||||||
|
|
||||||
|
|
||||||
|
class MultiImageObsEncoder(ModuleAttrMixin):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
shape_meta: dict,
|
||||||
|
rgb_model: Union[nn.Module, Dict[str, nn.Module]],
|
||||||
|
resize_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
|
||||||
|
crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
|
||||||
|
random_crop: bool = True,
|
||||||
|
# replace BatchNorm with GroupNorm
|
||||||
|
use_group_norm: bool = False,
|
||||||
|
# use single rgb model for all rgb inputs
|
||||||
|
share_rgb_model: bool = False,
|
||||||
|
# renormalize rgb input with imagenet normalization
|
||||||
|
# assuming input in [0,1]
|
||||||
|
imagenet_norm: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Assumes rgb input: B,C,H,W
|
||||||
|
Assumes low_dim input: B,D
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
rgb_keys = []
|
||||||
|
low_dim_keys = []
|
||||||
|
key_model_map = nn.ModuleDict()
|
||||||
|
key_transform_map = nn.ModuleDict()
|
||||||
|
key_shape_map = {}
|
||||||
|
|
||||||
|
# handle sharing vision backbone
|
||||||
|
if share_rgb_model:
|
||||||
|
assert isinstance(rgb_model, nn.Module)
|
||||||
|
key_model_map["rgb"] = rgb_model
|
||||||
|
|
||||||
|
obs_shape_meta = shape_meta["obs"]
|
||||||
|
for key, attr in obs_shape_meta.items():
|
||||||
|
shape = tuple(attr["shape"])
|
||||||
|
type = attr.get("type", "low_dim")
|
||||||
|
key_shape_map[key] = shape
|
||||||
|
if type == "rgb":
|
||||||
|
rgb_keys.append(key)
|
||||||
|
# configure model for this key
|
||||||
|
this_model = None
|
||||||
|
if not share_rgb_model:
|
||||||
|
if isinstance(rgb_model, dict):
|
||||||
|
# have provided model for each key
|
||||||
|
this_model = rgb_model[key]
|
||||||
|
else:
|
||||||
|
assert isinstance(rgb_model, nn.Module)
|
||||||
|
# have a copy of the rgb model
|
||||||
|
this_model = copy.deepcopy(rgb_model)
|
||||||
|
|
||||||
|
if this_model is not None:
|
||||||
|
if use_group_norm:
|
||||||
|
this_model = replace_submodules(
|
||||||
|
root_module=this_model,
|
||||||
|
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||||
|
func=lambda x: nn.GroupNorm(
|
||||||
|
num_groups=x.num_features // 16, num_channels=x.num_features
|
||||||
|
),
|
||||||
|
)
|
||||||
|
key_model_map[key] = this_model
|
||||||
|
|
||||||
|
# configure resize
|
||||||
|
input_shape = shape
|
||||||
|
this_resizer = nn.Identity()
|
||||||
|
if resize_shape is not None:
|
||||||
|
if isinstance(resize_shape, dict):
|
||||||
|
h, w = resize_shape[key]
|
||||||
|
else:
|
||||||
|
h, w = resize_shape
|
||||||
|
this_resizer = torchvision.transforms.Resize(size=(h, w))
|
||||||
|
input_shape = (shape[0], h, w)
|
||||||
|
|
||||||
|
# configure randomizer
|
||||||
|
this_randomizer = nn.Identity()
|
||||||
|
if crop_shape is not None:
|
||||||
|
if isinstance(crop_shape, dict):
|
||||||
|
h, w = crop_shape[key]
|
||||||
|
else:
|
||||||
|
h, w = crop_shape
|
||||||
|
if random_crop:
|
||||||
|
this_randomizer = CropRandomizer(
|
||||||
|
input_shape=input_shape, crop_height=h, crop_width=w, num_crops=1, pos_enc=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
|
||||||
|
# configure normalizer
|
||||||
|
this_normalizer = nn.Identity()
|
||||||
|
if imagenet_norm:
|
||||||
|
# TODO(rcadene): move normalizer to dataset and env
|
||||||
|
this_normalizer = torchvision.transforms.Normalize(
|
||||||
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||||
|
)
|
||||||
|
|
||||||
|
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
|
||||||
|
key_transform_map[key] = this_transform
|
||||||
|
elif type == "low_dim":
|
||||||
|
low_dim_keys.append(key)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported obs type: {type}")
|
||||||
|
rgb_keys = sorted(rgb_keys)
|
||||||
|
low_dim_keys = sorted(low_dim_keys)
|
||||||
|
|
||||||
|
self.shape_meta = shape_meta
|
||||||
|
self.key_model_map = key_model_map
|
||||||
|
self.key_transform_map = key_transform_map
|
||||||
|
self.share_rgb_model = share_rgb_model
|
||||||
|
self.rgb_keys = rgb_keys
|
||||||
|
self.low_dim_keys = low_dim_keys
|
||||||
|
self.key_shape_map = key_shape_map
|
||||||
|
|
||||||
|
def forward(self, obs_dict):
|
||||||
|
batch_size = None
|
||||||
|
features = []
|
||||||
|
# process rgb input
|
||||||
|
if self.share_rgb_model:
|
||||||
|
# pass all rgb obs to rgb model
|
||||||
|
imgs = []
|
||||||
|
for key in self.rgb_keys:
|
||||||
|
img = obs_dict[key]
|
||||||
|
if batch_size is None:
|
||||||
|
batch_size = img.shape[0]
|
||||||
|
else:
|
||||||
|
assert batch_size == img.shape[0]
|
||||||
|
assert img.shape[1:] == self.key_shape_map[key]
|
||||||
|
img = self.key_transform_map[key](img)
|
||||||
|
imgs.append(img)
|
||||||
|
# (N*B,C,H,W)
|
||||||
|
imgs = torch.cat(imgs, dim=0)
|
||||||
|
# (N*B,D)
|
||||||
|
feature = self.key_model_map["rgb"](imgs)
|
||||||
|
# (N,B,D)
|
||||||
|
feature = feature.reshape(-1, batch_size, *feature.shape[1:])
|
||||||
|
# (B,N,D)
|
||||||
|
feature = torch.moveaxis(feature, 0, 1)
|
||||||
|
# (B,N*D)
|
||||||
|
feature = feature.reshape(batch_size, -1)
|
||||||
|
features.append(feature)
|
||||||
|
else:
|
||||||
|
# run each rgb obs to independent models
|
||||||
|
for key in self.rgb_keys:
|
||||||
|
img = obs_dict[key]
|
||||||
|
if batch_size is None:
|
||||||
|
batch_size = img.shape[0]
|
||||||
|
else:
|
||||||
|
assert batch_size == img.shape[0]
|
||||||
|
assert img.shape[1:] == self.key_shape_map[key]
|
||||||
|
img = self.key_transform_map[key](img)
|
||||||
|
feature = self.key_model_map[key](img)
|
||||||
|
features.append(feature)
|
||||||
|
|
||||||
|
# process lowdim input
|
||||||
|
for key in self.low_dim_keys:
|
||||||
|
data = obs_dict[key]
|
||||||
|
if batch_size is None:
|
||||||
|
batch_size = data.shape[0]
|
||||||
|
else:
|
||||||
|
assert batch_size == data.shape[0]
|
||||||
|
assert data.shape[1:] == self.key_shape_map[key]
|
||||||
|
features.append(data)
|
||||||
|
|
||||||
|
# concatenate all features
|
||||||
|
result = torch.cat(features, dim=-1)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def output_shape(self):
|
||||||
|
example_obs_dict = {}
|
||||||
|
obs_shape_meta = self.shape_meta["obs"]
|
||||||
|
batch_size = 1
|
||||||
|
for key, attr in obs_shape_meta.items():
|
||||||
|
shape = tuple(attr["shape"])
|
||||||
|
this_obs = torch.zeros((batch_size,) + shape, dtype=self.dtype, device=self.device)
|
||||||
|
example_obs_dict[key] = this_obs
|
||||||
|
example_output = self.forward(example_obs_dict)
|
||||||
|
output_shape = example_output.shape[1:]
|
||||||
|
return output_shape
|
|
@ -1,6 +1,7 @@
|
||||||
import copy
|
import copy
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import einops
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -8,8 +9,9 @@ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
|
||||||
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||||
from diffusion_policy.model.vision.model_getter import get_resnet
|
from diffusion_policy.model.vision.model_getter import get_resnet
|
||||||
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
|
|
||||||
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||||
|
from .multi_image_obs_encoder import MultiImageObsEncoder
|
||||||
|
|
||||||
FIRST_ACTION = 0
|
FIRST_ACTION = 0
|
||||||
|
|
||||||
|
@ -99,10 +101,15 @@ class DiffusionPolicy(nn.Module):
|
||||||
# TODO(rcadene): remove unused step_count
|
# TODO(rcadene): remove unused step_count
|
||||||
del step_count
|
del step_count
|
||||||
|
|
||||||
|
# TODO(rcadene): remove unsqueeze hack...
|
||||||
|
if observation["image"].ndim == 3:
|
||||||
|
observation["image"] = observation["image"].unsqueeze(0)
|
||||||
|
observation["state"] = observation["state"].unsqueeze(0)
|
||||||
|
|
||||||
obs_dict = {
|
obs_dict = {
|
||||||
# c h w -> b t c h w (b=1, t=1)
|
# TODO(rcadene): hack to add temporal dim
|
||||||
"image": observation["image"][None, None, ...],
|
"image": einops.rearrange(observation["image"], "b c h w -> b 1 c h w"),
|
||||||
"agent_pos": observation["state"][None, None, ...],
|
"agent_pos": einops.rearrange(observation["state"], "b c -> b 1 c"),
|
||||||
}
|
}
|
||||||
out = self.diffusion.predict_action(obs_dict)
|
out = self.diffusion.predict_action(obs_dict)
|
||||||
|
|
|
@ -4,7 +4,7 @@ def make_policy(cfg):
|
||||||
|
|
||||||
policy = TDMPC(cfg.policy)
|
policy = TDMPC(cfg.policy)
|
||||||
elif cfg.policy.name == "diffusion":
|
elif cfg.policy.name == "diffusion":
|
||||||
from lerobot.common.policies.diffusion import DiffusionPolicy
|
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||||
|
|
||||||
policy = DiffusionPolicy(
|
policy = DiffusionPolicy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
|
|
|
@ -138,9 +138,6 @@ class TDMPC(nn.Module):
|
||||||
"state": observation["state"].contiguous(),
|
"state": observation["state"].contiguous(),
|
||||||
}
|
}
|
||||||
action = self.act(obs, t0=t0, step=self.step.item())
|
action = self.act(obs, t0=t0, step=self.step.item())
|
||||||
|
|
||||||
# TODO(rcadene): hack to postprocess action (e.g. unnormalize)
|
|
||||||
# action = action * self.action_std + self.action_mean
|
|
||||||
return action
|
return action
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
|
@ -147,7 +147,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
env = make_env(cfg, transform=offline_buffer._transform)
|
env = make_env(cfg, transform=offline_buffer._transform)
|
||||||
|
|
||||||
logging.info("make_policy")
|
logging.info("make_policy")
|
||||||
policy = make_policy(cfg, transform=offline_buffer._transform)
|
policy = make_policy(cfg)
|
||||||
|
|
||||||
td_policy = TensorDictModule(
|
td_policy = TensorDictModule(
|
||||||
policy,
|
policy,
|
||||||
|
|
Loading…
Reference in New Issue