[HIL-SERL PORT] Fix linter issues (#588)
This commit is contained in:
parent
6340d9d17c
commit
d96edbf5ac
|
@ -19,21 +19,18 @@
|
|||
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Sequence, Tuple
|
||||
|
||||
import einops
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor
|
||||
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
import numpy as np
|
||||
from typing import Callable, Optional, Tuple, Sequence
|
||||
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
|
@ -290,10 +287,7 @@ class Critic(nn.Module):
|
|||
observations = observations.to(self.device)
|
||||
actions = actions.to(self.device)
|
||||
|
||||
if self.encoder is not None:
|
||||
obs_enc = self.encoder(observations)
|
||||
else:
|
||||
obs_enc = observations
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
x = self.network(inputs)
|
||||
|
@ -563,6 +557,8 @@ class LagrangeMultiplier(nn.Module):
|
|||
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
|
||||
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
|
||||
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||
DEFAULT_SAMPLE_SHAPE = torch.Size()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loc: torch.Tensor,
|
||||
|
@ -611,7 +607,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
|||
|
||||
return mode
|
||||
|
||||
def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
|
||||
def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
|
||||
"""
|
||||
Reparameterized sample from the distribution
|
||||
"""
|
||||
|
@ -643,7 +639,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
|||
|
||||
return log_prob
|
||||
|
||||
def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sample from the distribution and compute log probability
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue