Fixes for PR #3
commit 2c05b75f45
parent 7e024fdce6
@@ -44,6 +44,9 @@ class PushtEnv(EnvBase):
         if not _has_gym:
             raise ImportError("Cannot import gym.")
 
+        # TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
+        # from diffusion_policy.env.pusht.pusht_env import PushTEnv
+
         if not from_pixels:
             raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
         from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
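The `_has_gym` check above is an optional-dependency guard; its definition is outside this hunk, so the sketch below only illustrates the usual pattern (the `importlib` probe and the trimmed constructor are assumptions, not code from this PR):

# Illustrative only: how a module-level _has_gym flag is typically defined.
# The real definition is not shown in this diff.
import importlib.util

_has_gym = importlib.util.find_spec("gym") is not None


class PushtEnv:
    def __init__(self, from_pixels: bool = True):
        if not _has_gym:
            raise ImportError("Cannot import gym.")
        if not from_pixels:
            raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")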
@@ -1,3 +1,5 @@
+# ruff: noqa: N806
+
 from copy import deepcopy
 
 import einops
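The new `# ruff: noqa: N806` line is a file-level suppression of Ruff's N806 rule ("variable in function should be lowercase"), which is what lets the per-line `# noqa: N806` comments be dropped in the hunks that follow. A minimal illustration with a made-up function:

# Before: every capitalized local needs its own pragma.
def min_q_before(q1, q2):
    Q_min = min(q1, q2)  # noqa: N806
    return Q_min


# After: a single "# ruff: noqa: N806" at the top of the file covers the whole
# module, so the per-line pragmas can be deleted.
def min_q_after(q1, q2):
    Q_min = min(q1, q2)
    return Q_min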
@@ -78,7 +80,7 @@ class TOLD(nn.Module):
             return torch.stack([q(x) for q in self._Qs], dim=0)
 
         idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
-        Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)  # noqa: N806
+        Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
 
 
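The hunk above is the ensemble read-out of TOLD's Q-function: two critics are drawn at random from the ensemble of `cfg.num_q` Q-networks and either their element-wise minimum or their mean is returned, while `return_type="all"` stacks every critic so callers can use the spread as an uncertainty proxy. A self-contained sketch of that pattern (module sizes and names are illustrative, not the repository's):

import numpy as np
import torch
from torch import nn


class QEnsemble(nn.Module):
    """Illustrative ensemble read-out: min or mean over two randomly chosen critics."""

    def __init__(self, num_q: int = 5, dim: int = 32):
        super().__init__()
        self._Qs = nn.ModuleList([nn.Linear(dim, 1) for _ in range(num_q)])
        self.num_q = num_q

    def forward(self, x: torch.Tensor, return_type: str = "min") -> torch.Tensor:
        if return_type == "all":
            # Stack every critic's estimate; .std(dim=0) gives an uncertainty proxy.
            return torch.stack([q(x) for q in self._Qs], dim=0)
        # Random pair of critics, clipped-double-Q style.
        idxs = np.random.choice(self.num_q, 2, replace=False)
        Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
        return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2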
@@ -146,21 +148,21 @@ class TDMPC(nn.Module):
     @torch.no_grad()
     def estimate_value(self, z, actions, horizon):
         """Estimate value of a trajectory starting at latent state z and executing given actions."""
-        G, discount = 0, 1  # noqa: N806
+        G, discount = 0, 1
         for t in range(horizon):
             if self.cfg.uncertainty_cost > 0:
-                G -= (  # noqa: N806
+                G -= (
                     discount
                     * self.cfg.uncertainty_cost
                     * self.model.Q(z, actions[t], return_type="all").std(dim=0)
                 )
             z, reward = self.model.next(z, actions[t])
-            G += discount * reward  # noqa: N806
+            G += discount * reward
             discount *= self.cfg.discount
         pi = self.model.pi(z, self.cfg.min_std)
-        G += discount * self.model.Q(z, pi, return_type="min")  # noqa: N806
+        G += discount * self.model.Q(z, pi, return_type="min")
         if self.cfg.uncertainty_cost > 0:
-            G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0)  # noqa: N806
+            G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0)
         return G
 
     @torch.no_grad()
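For readers skimming the diff: estimate_value rolls a latent trajectory forward through the learned model, accumulates discounted rewards, subtracts an optional uncertainty penalty (the std across the Q ensemble scaled by cfg.uncertainty_cost), and bootstraps the tail with the min-Q value of the policy action. A stripped-down sketch of the same return computation with plain floats (the inputs stand in for the model calls and are not the repository's API, and the terminal uncertainty term is omitted):

def estimate_value_sketch(rewards, terminal_value, discount_factor=0.99,
                          uncertainty=None, uncertainty_cost=0.0):
    """Discounted return over a planned horizon plus a bootstrapped terminal value.

    rewards[t] and uncertainty[t] stand in for model.next(...) and
    Q(..., return_type="all").std(dim=0) in the real code.
    """
    G, discount = 0.0, 1.0
    for t, reward in enumerate(rewards):
        if uncertainty is not None and uncertainty_cost > 0:
            # Penalize disagreement among the critics at this step.
            G -= discount * uncertainty_cost * uncertainty[t]
        G += discount * reward
        discount *= discount_factor
    # Terminal bootstrap: min-Q of the policy's action at the final latent state.
    G += discount * terminal_value
    return G


# Example: three-step plan with rewards [1, 1, 1] and a terminal value of 10.
print(estimate_value_sketch([1.0, 1.0, 1.0], terminal_value=10.0))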