[HIL-SERL PORT] Fix linter issues (#588)
This commit is contained in:
parent
70b652f791
commit
b53d6e0ff2
|
@ -19,21 +19,18 @@
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial
|
from typing import Callable, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||||
import numpy as np
|
|
||||||
from typing import Callable, Optional, Tuple, Sequence
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SACPolicy(
|
class SACPolicy(
|
||||||
|
@ -290,10 +287,7 @@ class Critic(nn.Module):
|
||||||
observations = observations.to(self.device)
|
observations = observations.to(self.device)
|
||||||
actions = actions.to(self.device)
|
actions = actions.to(self.device)
|
||||||
|
|
||||||
if self.encoder is not None:
|
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||||
obs_enc = self.encoder(observations)
|
|
||||||
else:
|
|
||||||
obs_enc = observations
|
|
||||||
|
|
||||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||||
x = self.network(inputs)
|
x = self.network(inputs)
|
||||||
|
@ -563,6 +557,8 @@ class LagrangeMultiplier(nn.Module):
|
||||||
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
|
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
|
||||||
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
|
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
|
||||||
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||||
|
DEFAULT_SAMPLE_SHAPE = torch.Size()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
loc: torch.Tensor,
|
loc: torch.Tensor,
|
||||||
|
@ -611,7 +607,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||||
|
|
||||||
return mode
|
return mode
|
||||||
|
|
||||||
def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
|
def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Reparameterized sample from the distribution
|
Reparameterized sample from the distribution
|
||||||
"""
|
"""
|
||||||
|
@ -643,7 +639,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||||
|
|
||||||
return log_prob
|
return log_prob
|
||||||
|
|
||||||
def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]:
|
def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Sample from the distribution and compute log probability
|
Sample from the distribution and compute log probability
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue