[HIL-SERL PORT] Fix linter issues (#588)

Eugene Mironov 2024-12-23 16:44:29 +07:00 committed by AdilZouitine
parent 6340d9d17c
commit d96edbf5ac
1 changed file with 8 additions and 12 deletions


@@ -19,21 +19,18 @@
 from collections import deque
 from copy import deepcopy
 from functools import partial
+from typing import Callable, Optional, Sequence, Tuple
 import einops
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
+from huggingface_hub import PyTorchModelHubMixin
 from torch import Tensor
-from huggingface_hub import PyTorchModelHubMixin
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.sac.configuration_sac import SACConfig
-import numpy as np
-from typing import Callable, Optional, Tuple, Sequence
 class SACPolicy(
@@ -290,10 +287,7 @@ class Critic(nn.Module):
         observations = observations.to(self.device)
         actions = actions.to(self.device)
-        if self.encoder is not None:
-            obs_enc = self.encoder(observations)
-        else:
-            obs_enc = observations
+        obs_enc = observations if self.encoder is None else self.encoder(observations)
         inputs = torch.cat([obs_enc, actions], dim=-1)
         x = self.network(inputs)
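This hunk collapses the encoder branch into a single conditional expression. A minimal sketch of the same pattern follows; MinimalCritic and its layer sizes are illustrative and not the repository's Critic class, and the original's device handling is omitted:

import torch
import torch.nn as nn
from typing import Optional

class MinimalCritic(nn.Module):
    """Illustrative critic: score an (observation, action) pair with an MLP."""

    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int = 256, encoder: Optional[nn.Module] = None):
        super().__init__()
        self.encoder = encoder  # optional observation encoder; None means raw observations are used
        self.network = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, observations: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        # Same shape as the hunk above: encode only when an encoder exists, then concatenate and score.
        obs_enc = observations if self.encoder is None else self.encoder(observations)
        return self.network(torch.cat([obs_enc, actions], dim=-1))

critic = MinimalCritic(obs_dim=4, action_dim=2)
q_value = critic(torch.randn(8, 4), torch.randn(8, 2))  # shape (8, 1)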
@@ -563,6 +557,8 @@ class LagrangeMultiplier(nn.Module):
 # 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
 # This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
 class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
+    DEFAULT_SAMPLE_SHAPE = torch.Size()
     def __init__(
         self,
         loc: torch.Tensor,
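The comment block above describes a tanh-squashed diagonal Gaussian. A minimal sketch of that construction using stock torch.distributions primitives, not the TanhMultivariateNormalDiag class defined in this file (make_tanh_normal is an illustrative helper name):

import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, TanhTransform

def make_tanh_normal(loc, scale_diag, low=None, high=None):
    # Diagonal Gaussian over the action dimension.
    base = Independent(Normal(loc, scale_diag), 1)
    # Squash samples into (-1, 1); cache_size=1 avoids an unstable atanh in log_prob.
    transforms = [TanhTransform(cache_size=1)]
    if low is not None and high is not None:
        # Affine map from (-1, 1) onto (low, high).
        transforms.append(AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2))
    return TransformedDistribution(base, transforms)

dist = make_tanh_normal(torch.zeros(3), torch.ones(3))
action = dist.rsample()           # differentiable (reparameterized) sample in (-1, 1)
log_prob = dist.log_prob(action)  # includes the tanh log-det-Jacobian correction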
@@ -611,7 +607,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         return mode
-    def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
+    def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
         """
         Reparameterized sample from the distribution
         """
@@ -643,7 +639,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         return log_prob
-    def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]:
+    def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Sample from the distribution and compute log probability
         """