- Refactor observation encoder in `modeling_sac.py`
- Add `torch.compile` to the actor and learner servers
- Organize imports in `train_sac.py`
- Optimize the parameters push by not sending the frozen pre-trained encoder

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
commit 506821c7df (parent f1c8bfe01e)
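A minimal sketch (not part of the commit) of the parameters-push optimization described above: when the pre-trained vision encoder is frozen its weights never change, so the learner can drop the `encoder.*` entries from the actor's state dict before serializing, and the actor reloads with `strict=False`. The helper names `push_trainable_actor_params` and `load_pushed_params` are hypothetical; the attribute names (`policy.actor`, `vision_encoder_name`, `freeze_vision_encoder`) and the queue/serialization pattern follow the diff below.

```python
import io

import torch


def push_trainable_actor_params(policy, parameters_queue):
    """Hypothetical learner-side helper: push only actor parameters that actually change."""
    params_dict = policy.actor.state_dict()
    if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
        # The frozen pre-trained encoder never changes, so drop it to shrink the payload.
        params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}
    params_dict = {k: v.to("cpu") for k, v in params_dict.items()}

    buf = io.BytesIO()
    torch.save(params_dict, buf)  # serialize before sending to the actor process
    parameters_queue.put(buf.getvalue())


def load_pushed_params(policy, serialized_params, device):
    """Hypothetical actor-side helper: strict=False tolerates the missing frozen-encoder keys."""
    state_dict = torch.load(io.BytesIO(serialized_params), map_location=device)
    policy.actor.load_state_dict(state_dict, strict=False)
```

On the actor side this pairs with the `strict=False` load shown in the `act_with_policy` hunk further down.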
@@ -55,9 +55,10 @@ class SACConfig:
     )
     camera_number: int = 1
     # Add type annotations for these fields:
-    vision_encoder_name: str = field(default="microsoft/resnet-18")
+    vision_encoder_name: str | None = field(default="microsoft/resnet-18")
+    freeze_vision_encoder: bool = True
     image_encoder_hidden_dim: int = 32
-    shared_encoder: bool = False
+    shared_encoder: bool = True
     discount: float = 0.99
     temperature_init: float = 1.0
     num_critics: int = 2
@@ -312,7 +312,7 @@ class CriticEnsemble(nn.Module):
     def __init__(
         self,
         encoder: Optional[nn.Module],
-        network_list: nn.Module,
+        network_list: nn.ModuleList,
         init_final: Optional[float] = None,
     ):
         super().__init__()
@@ -320,6 +320,12 @@ class CriticEnsemble(nn.Module):
         self.network_list = network_list
         self.init_final = init_final

+        self.parameters_to_optimize = []
+        # Handle the case where part of the encoder is frozen
+        if self.encoder is not None:
+            self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
+
+        self.parameters_to_optimize += list(self.network_list.parameters())
         # Find the last Linear layer's output dimension
         for layer in reversed(network_list[0].net):
             if isinstance(layer, nn.Linear):
@@ -342,6 +348,7 @@ class CriticEnsemble(nn.Module):
             self.output_layers.append(output_layer)

         self.output_layers = nn.ModuleList(self.output_layers)
+        self.parameters_to_optimize += list(self.output_layers.parameters())

     def forward(
         self,
@@ -474,61 +481,25 @@ class SACObservationEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.has_pretrained_vision_encoder = False
+        self.parameters_to_optimize = []
+
+        self.aggregation_size: int = 0
         if "observation.image" in config.input_shapes:
             self.camera_number = config.camera_number
-            self.aggregation_size: int = 0

             if self.config.vision_encoder_name is not None:
+                self.image_enc_layers = PretrainedImageEncoder(config)
                 self.has_pretrained_vision_encoder = True
-                self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder()
-                self.freeze_encoder()
-                self.image_enc_proj = nn.Sequential(
-                    nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
-                    nn.LayerNorm(config.latent_dim),
-                    nn.Tanh(),
-                )
             else:
-                self.image_enc_layers = nn.Sequential(
-                    nn.Conv2d(
-                        in_channels=config.input_shapes["observation.image"][0],
-                        out_channels=config.image_encoder_hidden_dim,
-                        kernel_size=7,
-                        stride=2,
-                    ),
-                    nn.ReLU(),
-                    nn.Conv2d(
-                        in_channels=config.image_encoder_hidden_dim,
-                        out_channels=config.image_encoder_hidden_dim,
-                        kernel_size=5,
-                        stride=2,
-                    ),
-                    nn.ReLU(),
-                    nn.Conv2d(
-                        in_channels=config.image_encoder_hidden_dim,
-                        out_channels=config.image_encoder_hidden_dim,
-                        kernel_size=3,
-                        stride=2,
-                    ),
-                    nn.ReLU(),
-                    nn.Conv2d(
-                        in_channels=config.image_encoder_hidden_dim,
-                        out_channels=config.image_encoder_hidden_dim,
-                        kernel_size=3,
-                        stride=2,
-                    ),
-                    nn.ReLU(),
-                )
-                dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
-                with torch.inference_mode():
-                    self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
-                self.image_enc_layers.extend(
-                    nn.Sequential(
-                        nn.Flatten(),
-                        nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
-                        nn.LayerNorm(config.latent_dim),
-                        nn.Tanh(),
-                    )
-                )
+                self.image_enc_layers = DefaultImageEncoder(config)
             self.aggregation_size += config.latent_dim * self.camera_number
+
+            if config.freeze_vision_encoder:
+                freeze_image_encoder(self.image_enc_layers)
+            else:
+                self.parameters_to_optimize += list(self.image_enc_layers.parameters())

         if "observation.state" in config.input_shapes:
             self.state_enc_layers = nn.Sequential(
                 nn.Linear(
@@ -539,6 +510,8 @@ class SACObservationEncoder(nn.Module):
             )
             self.aggregation_size += config.latent_dim
+
+            self.parameters_to_optimize += list(self.state_enc_layers.parameters())

         if "observation.environment_state" in config.input_shapes:
             self.env_state_enc_layers = nn.Sequential(
                 nn.Linear(
@@ -548,26 +521,11 @@ class SACObservationEncoder(nn.Module):
                 nn.LayerNorm(normalized_shape=config.latent_dim),
                 nn.Tanh(),
             )
             self.aggregation_size += config.latent_dim
+            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())

         self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
+        self.parameters_to_optimize += list(self.aggregation_layer.parameters())

-    def _load_pretrained_vision_encoder(self):
-        """Set up CNN encoder"""
-        from transformers import AutoModel
-
-        self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
-        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
-            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
-        elif hasattr(self.image_enc_layers, "fc"):
-            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
-        else:
-            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
-        return self.image_enc_layers, self.image_enc_out_shape
-
-    def freeze_encoder(self):
-        """Freeze all parameters in the encoder"""
-        for param in self.image_enc_layers.parameters():
-            param.requires_grad = False
-
     def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
         """Encode the image and/or state vector.
@@ -579,12 +537,10 @@ class SACObservationEncoder(nn.Module):
         # Concatenate all images along the channel dimension.
         image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
         for image_key in image_keys:
-            if self.has_pretrained_vision_encoder:
-                enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output
-                enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
-            else:
-                enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
+            enc_feat = self.image_enc_layers(obs_dict[image_key])

+            # if not self.has_pretrained_vision_encoder:
+            #     enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
             feat.append(enc_feat)
         if "observation.environment_state" in self.config.input_shapes:
             feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
@@ -602,10 +558,107 @@ class SACObservationEncoder(nn.Module):
         return self.config.latent_dim

+
+class DefaultImageEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.image_enc_layers = nn.Sequential(
+            nn.Conv2d(
+                in_channels=config.input_shapes["observation.image"][0],
+                out_channels=config.image_encoder_hidden_dim,
+                kernel_size=7,
+                stride=2,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=config.image_encoder_hidden_dim,
+                out_channels=config.image_encoder_hidden_dim,
+                kernel_size=5,
+                stride=2,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=config.image_encoder_hidden_dim,
+                out_channels=config.image_encoder_hidden_dim,
+                kernel_size=3,
+                stride=2,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=config.image_encoder_hidden_dim,
+                out_channels=config.image_encoder_hidden_dim,
+                kernel_size=3,
+                stride=2,
+            ),
+            nn.ReLU(),
+        )
+        dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
+        with torch.inference_mode():
+            self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
+        self.image_enc_layers.extend(
+            nn.Sequential(
+                nn.Flatten(),
+                nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+                nn.LayerNorm(config.latent_dim),
+                nn.Tanh(),
+            )
+        )
+
+    def forward(self, x):
+        return self.image_enc_layers(x)
+
+
+class PretrainedImageEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
+        self.image_enc_proj = nn.Sequential(
+            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+            nn.LayerNorm(config.latent_dim),
+            nn.Tanh(),
+        )
+
+    def _load_pretrained_vision_encoder(self, config):
+        """Set up CNN encoder"""
+        from transformers import AutoModel
+
+        self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name)
+        # self.image_enc_layers.pooler = Identity()
+
+        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
+            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
+        elif hasattr(self.image_enc_layers, "fc"):
+            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
+        else:
+            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
+        return self.image_enc_layers, self.image_enc_out_shape
+
+    def forward(self, x):
+        # TODO: (maractingi, azouitine) check that the forward pass of the pretrained model
+        # doesn't reach the classifier layer, because we don't need it
+        enc_feat = self.image_enc_layers(x).pooler_output
+        enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
+        return enc_feat
+
+
+def freeze_image_encoder(image_encoder: nn.Module):
+    """Freeze all parameters in the encoder"""
+    for param in image_encoder.parameters():
+        param.requires_grad = False
+
+
 def orthogonal_init():
     return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)

+
+class Identity(nn.Module):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
+
 # TODO (azouitine): I think in our case this function is not useful; we should remove it
 # after some investigation
 # borrowed from tdmpc
@@ -626,3 +679,54 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
     inp = torch.flatten(image_tensor, end_dim=-4)
     flat_out = fn(inp)
     return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
+
+
+if __name__ == "__main__":
+    # Test the SACObservationEncoder
+    import time
+
+    config = SACConfig()
+    config.num_critics = 10
+    encoder = SACObservationEncoder(config)
+    actor_encoder = SACObservationEncoder(config)
+    encoder = torch.compile(encoder)
+    critic_ensemble = CriticEnsemble(
+        encoder=encoder,
+        network_list=nn.ModuleList(
+            [
+                MLP(
+                    input_dim=encoder.output_dim + config.output_shapes["action"][0],
+                    **config.critic_network_kwargs,
+                )
+                for _ in range(config.num_critics)
+            ]
+        ),
+    )
+    actor = Policy(
+        encoder=actor_encoder,
+        network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
+        action_dim=config.output_shapes["action"][0],
+        encoder_is_shared=config.shared_encoder,
+        **config.policy_kwargs,
+    )
+    encoder = encoder.to("cuda:0")
+    critic_ensemble = torch.compile(critic_ensemble)
+    critic_ensemble = critic_ensemble.to("cuda:0")
+    actor = torch.compile(actor)
+    actor = actor.to("cuda:0")
+    obs_dict = {
+        "observation.image": torch.randn(1, 3, 84, 84),
+        "observation.state": torch.randn(1, 4),
+    }
+    actions = torch.randn(1, 2).to("cuda:0")
+    obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
+    print("compiling...")
+    # q_value = critic_ensemble(obs_dict, actions)
+    action = actor(obs_dict)
+    print("compiled")
+    start = time.perf_counter()
+    for _ in range(1000):
+        # features = encoder(obs_dict)
+        action = actor(obs_dict)
+        # q_value = critic_ensemble(obs_dict, actions)
+    print("Time taken:", time.perf_counter() - start)
@@ -52,6 +52,8 @@ policy:
   n_action_steps: 1

   shared_encoder: true
+  # vision_encoder_name: null
+  freeze_vision_encoder: false
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]
@@ -191,6 +191,7 @@ def act_with_policy(cfg: DictConfig):
     #     pretrained_policy_name_or_path=None,
     #     device=device,
     # )
+    policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)

     # HACK for maniskill
@@ -237,7 +238,9 @@ def act_with_policy(cfg: DictConfig):
                 logging.debug("[ACTOR] Load new parameters from Learner.")
                 state_dict = parameters_queue.get()
                 state_dict = move_state_dict_to_device(state_dict, device=device)
-                policy.actor.load_state_dict(state_dict)
+                # strict=False for the case when the image encoder is frozen and not sent through
+                # the network. Be careful: this might cause issues if the wrong keys are passed.
+                policy.actor.load_state_dict(state_dict, strict=False)

             if len(list_transition_to_send_to_learner) > 0:
                 logging.debug(
@@ -259,6 +259,9 @@ def learner_push_parameters(
     while True:
         with policy_lock:
             params_dict = policy.actor.state_dict()
+            if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
+                params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}
+
         params_dict = move_state_dict_to_device(params_dict, device="cpu")
         # Serialize
         buf = io.BytesIO()
@@ -541,6 +544,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         dataset_stats=None,
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
+    # compile policy
+    policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)

     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
@@ -13,26 +13,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 import functools
-from pprint import pformat
+import logging
 import random
-from typing import Optional, Sequence, TypedDict, Callable
+from pprint import pformat
+from typing import Callable, Optional, Sequence, TypedDict

 import hydra
 import torch
 import torch.nn.functional as F
-from torch import nn
-from tqdm import tqdm
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from torch import nn
+from tqdm import tqdm

 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.envs.factory import make_env, make_maniskill_env
-from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
+from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy