Fixes for PR #3
This commit is contained in:
parent
7e024fdce6
commit
2c05b75f45
|
@ -44,6 +44,9 @@ class PushtEnv(EnvBase):
|
|||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
|
||||
# from diffusion_policy.env.pusht.pusht_env import PushTEnv
|
||||
|
||||
if not from_pixels:
|
||||
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
# ruff: noqa: N806
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import einops
|
||||
|
@ -78,7 +80,7 @@ class TOLD(nn.Module):
|
|||
return torch.stack([q(x) for q in self._Qs], dim=0)
|
||||
|
||||
idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
|
||||
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x) # noqa: N806
|
||||
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
|
||||
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
|
||||
|
||||
|
||||
|
@ -146,21 +148,21 @@ class TDMPC(nn.Module):
|
|||
@torch.no_grad()
|
||||
def estimate_value(self, z, actions, horizon):
|
||||
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
||||
G, discount = 0, 1 # noqa: N806
|
||||
G, discount = 0, 1
|
||||
for t in range(horizon):
|
||||
if self.cfg.uncertainty_cost > 0:
|
||||
G -= ( # noqa: N806
|
||||
G -= (
|
||||
discount
|
||||
* self.cfg.uncertainty_cost
|
||||
* self.model.Q(z, actions[t], return_type="all").std(dim=0)
|
||||
)
|
||||
z, reward = self.model.next(z, actions[t])
|
||||
G += discount * reward # noqa: N806
|
||||
G += discount * reward
|
||||
discount *= self.cfg.discount
|
||||
pi = self.model.pi(z, self.cfg.min_std)
|
||||
G += discount * self.model.Q(z, pi, return_type="min") # noqa: N806
|
||||
G += discount * self.model.Q(z, pi, return_type="min")
|
||||
if self.cfg.uncertainty_cost > 0:
|
||||
G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0) # noqa: N806
|
||||
G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0)
|
||||
return G
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
Loading…
Reference in New Issue