From dc89e53d8dc0ae5404921ba3b14e314303329e53 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 23 Dec 2024 16:44:29 +0700 Subject: [PATCH] [HIL-SERL PORT] Fix linter issues (#588) --- lerobot/common/policies/sac/modeling_sac.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index de8283de..c5e3f209 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -19,21 +19,18 @@ from collections import deque from copy import deepcopy -from functools import partial +from typing import Callable, Optional, Sequence, Tuple import einops - +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F # noqa: N812 +from huggingface_hub import PyTorchModelHubMixin from torch import Tensor -from huggingface_hub import PyTorchModelHubMixin from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.sac.configuration_sac import SACConfig -import numpy as np -from typing import Callable, Optional, Tuple, Sequence - class SACPolicy( @@ -290,10 +287,7 @@ class Critic(nn.Module): observations = observations.to(self.device) actions = actions.to(self.device) - if self.encoder is not None: - obs_enc = self.encoder(observations) - else: - obs_enc = observations + obs_enc = observations if self.encoder is None else self.encoder(observations) inputs = torch.cat([obs_enc, actions], dim=-1) x = self.network(inputs) @@ -563,6 +557,8 @@ class LagrangeMultiplier(nn.Module): # 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation # This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): + DEFAULT_SAMPLE_SHAPE = torch.Size() + def __init__( self, loc: torch.Tensor, @@ -611,7 +607,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): return mode - def rsample(self, sample_shape=torch.Size()) -> torch.Tensor: + def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor: """ Reparameterized sample from the distribution """ @@ -643,7 +639,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): return log_prob - def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]: + def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]: """ Sample from the distribution and compute log probability """