Fixes for PR #3

Author: Simon Alibert
Date: 2024-02-29 21:46:41 +01:00
Parent: 7e024fdce6
Commit: 2c05b75f45
2 changed files with 11 additions and 6 deletions


@@ -44,6 +44,9 @@ class PushtEnv(EnvBase):
         if not _has_gym:
             raise ImportError("Cannot import gym.")

         # TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
         # from diffusion_policy.env.pusht.pusht_env import PushTEnv
+        if not from_pixels:
+            raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
+        from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
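For context, the guard added above means this wrapper only supports pixel observations for now. A minimal usage sketch (hypothetical: it assumes `from_pixels` is a keyword argument of `PushtEnv.__init__` and that the remaining constructor arguments have defaults):

    env = PushtEnv(from_pixels=True)   # supported: wraps diffusion_policy's PushTImageEnv
    env = PushtEnv(from_pixels=False)  # raises NotImplementedError: the faster, state-only PushTEnv is not wired up yet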


@@ -1,3 +1,5 @@
+# ruff: noqa: N806
+
 from copy import deepcopy

 import einops
@@ -78,7 +80,7 @@ class TOLD(nn.Module):
             return torch.stack([q(x) for q in self._Qs], dim=0)
         idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
-        Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)  # noqa: N806
+        Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
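Aside: the trailing `# noqa: N806` removals in this file pair with the `# ruff: noqa: N806` line added at its top. N806 is pep8-naming's "Variable in function should be lowercase" rule, which fires on the RL-convention capitalized names used here (`G`, `Q1`, `Q2`); one file-level directive replaces all the per-line suppressions. A minimal standalone illustration (hypothetical file, not part of this diff):

    # ruff: noqa: N806
    # The directive above exempts the whole file from N806
    # ("Variable in function should be lowercase"), so RL-style
    # capitalized names need no per-line `# noqa` comments.

    def clipped_double_q(q1: float, q2: float) -> float:
        Q1, Q2 = q1, q2  # would otherwise trigger N806 twice
        return min(Q1, Q2)

The `Q1, Q2` line in the hunk itself is clipped double-Q: two critics are drawn at random from the `num_q` ensemble and the minimum is used to curb value overestimation.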
@@ -146,21 +148,21 @@ class TDMPC(nn.Module):
     @torch.no_grad()
     def estimate_value(self, z, actions, horizon):
         """Estimate value of a trajectory starting at latent state z and executing given actions."""
-        G, discount = 0, 1  # noqa: N806
+        G, discount = 0, 1
         for t in range(horizon):
             if self.cfg.uncertainty_cost > 0:
-                G -= (  # noqa: N806
+                G -= (
                     discount
                     * self.cfg.uncertainty_cost
                     * self.model.Q(z, actions[t], return_type="all").std(dim=0)
                 )
             z, reward = self.model.next(z, actions[t])
-            G += discount * reward  # noqa: N806
+            G += discount * reward
             discount *= self.cfg.discount
         pi = self.model.pi(z, self.cfg.min_std)
-        G += discount * self.model.Q(z, pi, return_type="min")  # noqa: N806
+        G += discount * self.model.Q(z, pi, return_type="min")
         if self.cfg.uncertainty_cost > 0:
-            G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0)  # noqa: N806
+            G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0)
         return G

     @torch.no_grad()
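For reference, the return computed by `estimate_value` can be written as (with gamma = `cfg.discount`, lambda = `cfg.uncertainty_cost`, H = `horizon`, the pair (z_{t+1}, r_t) given by `model.next(z_t, a_t)`, and std taken over the Q-ensemble):

    G = \sum_{t=0}^{H-1} \gamma^t \bigl( r_t - \lambda \,\operatorname{std}_i Q_i(z_t, a_t) \bigr)
        + \gamma^H \bigl( Q_{\min}(z_H, \pi(z_H)) - \lambda \,\operatorname{std}_i Q_i(z_H, \pi(z_H)) \bigr)

i.e. the discounted reward of a latent-model rollout plus a terminal value under the policy, each penalized by lambda times the ensemble's disagreement whenever `uncertainty_cost > 0`.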