Remove unused functions and imports from modeling_sac.py

- Deleted the `find_and_copy_params` function, the `Ensemble` class, and the `flatten_forward_unflatten` helper, as they were deemed unnecessary.
- Cleaned up imports by removing the now-unused `from_modules` import from `tensordict`, and dropped the `tensordict` requirement from `pyproject.toml` (core dependencies and the `hilserl` extra).
- Simplified the assertion formatting in the `Policy` class for better readability.
AdilZouitine 2025-03-19 18:53:01 +00:00
parent 80e766c05c
commit c5c921cd7c
2 changed files with 4 additions and 104 deletions

modeling_sac.py

@@ -23,7 +23,6 @@ from pathlib import Path
 import einops
 import numpy as np
-from tensordict import from_modules
 import torch
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
@@ -446,50 +445,6 @@ class MLP(nn.Module):
         return self.net(x)
 
-
-def find_and_copy_params(
-    original_state_dict: dict[str, torch.Tensor],
-    loaded_state_dict: dict[str, torch.Tensor],
-    pattern: str,
-    match_type: str = "contains",
-) -> list[str]:
-    """Find and copy parameters from original state dict to loaded state dict based on a pattern.
-
-    This function can search for keys in different ways based on the match_type:
-    - "exact": The key must exactly match the pattern
-    - "contains": The key must contain the pattern anywhere
-    - "startswith": The key must start with the pattern
-    - "endswith": The key must end with the pattern
-
-    Args:
-        original_state_dict: The source state dictionary
-        loaded_state_dict: The target state dictionary
-        pattern: The pattern to search for in keys
-        match_type: How to match the pattern (exact, contains, startswith, endswith)
-
-    Returns:
-        list[str]: List of keys that were copied
-    """
-    copied_keys = []
-    for key in original_state_dict:
-        should_copy = False
-        if match_type == "exact":
-            should_copy = key == pattern
-        elif match_type == "contains":
-            should_copy = pattern in key
-        elif match_type == "startswith":
-            should_copy = key.startswith(pattern)
-        elif match_type == "endswith":
-            should_copy = key.endswith(pattern)
-
-        if should_copy:
-            loaded_state_dict[key] = original_state_dict[key]
-            copied_keys.append(key)
-
-    return copied_keys
-
-
 class CriticHead(nn.Module):
     def __init__(
         self,
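For reference, the removed helper filtered one state dict into another by key pattern. A minimal, hypothetical usage sketch (the keys and shapes below are invented for illustration):

    import torch

    # Invented state dicts for illustration only.
    original = {
        "critic.net.0.weight": torch.ones(2, 2),
        "actor.net.0.weight": torch.zeros(2, 2),
    }
    loaded = {}

    # Copy every parameter whose key contains "critic".
    copied = find_and_copy_params(original, loaded, pattern="critic", match_type="contains")
    print(copied)  # ['critic.net.0.weight']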
@@ -682,9 +637,9 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
-            assert not torch.isnan(
-                log_std
-            ).any(), "[ERROR] log_std became NaN after std_layer!"
+            assert not torch.isnan(log_std).any(), (
+                "[ERROR] log_std became NaN after std_layer!"
+            )
 
             if self.use_tanh_squash:
                 log_std = torch.tanh(log_std)
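For context, the tanh here bounds the raw `log_std` output of the network. A common follow-up in SAC implementations rescales the squashed value into a fixed range; a sketch under the assumption of `LOG_STD_MIN`/`LOG_STD_MAX` bounds (names and values invented, not necessarily this file's exact code):

    import torch

    LOG_STD_MIN, LOG_STD_MAX = -5.0, 2.0  # assumed bounds, for illustration

    def squash_log_std(log_std: torch.Tensor) -> torch.Tensor:
        # Map an unbounded network output into [LOG_STD_MIN, LOG_STD_MAX] via tanh.
        log_std = torch.tanh(log_std)
        return LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1.0)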
@@ -932,60 +887,6 @@ class Identity(nn.Module):
         return x
 
-
-class Ensemble(nn.Module):
-    """
-    Vectorized ensemble of modules.
-    """
-
-    def __init__(self, modules, **kwargs):
-        super().__init__()
-        # combine_state_for_ensemble causes graph breaks
-        self.params = from_modules(*modules, as_module=True)
-        with self.params[0].data.to("meta").to_module(modules[0]):
-            self.module = deepcopy(modules[0])
-        self._repr = str(modules[0])
-        self._n = len(modules)
-
-    def __len__(self):
-        return self._n
-
-    def _call(self, params, *args, **kwargs):
-        with params.to_module(self.module):
-            return self.module(*args, **kwargs)
-
-    def forward(self, *args, **kwargs):
-        return torch.vmap(self._call, (0, None), randomness="different")(
-            self.params, *args, **kwargs
-        )
-
-    def __repr__(self):
-        return f"Vectorized {len(self)}x " + self._repr
-
-
-# TODO (azouitine): I think in our case this function is not useful; we should remove it
-# after some investigation
-# borrowed from tdmpc
-def flatten_forward_unflatten(
-    fn: Callable[[Tensor], Tensor], image_tensor: Tensor
-) -> Tensor:
-    """Helper to temporarily flatten extra dims at the start of the image tensor.
-
-    Args:
-        fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
-            (B, *), where * is any number of dimensions.
-        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
-            can be more than 1 dimension, generally different from *.
-
-    Returns:
-        A return value from the callable reshaped to (**, *).
-    """
-    if image_tensor.ndim == 4:
-        return fn(image_tensor)
-    start_dims = image_tensor.shape[:-3]
-    inp = torch.flatten(image_tensor, end_dim=-4)
-    flat_out = fn(inp)
-    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
-
-
 def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
     converted_params = {}
     for outer_key, inner_dict in normalization_params.items():
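The deleted `Ensemble` stacked the parameters of N identical modules with tensordict's `from_modules` and evaluated all members in a single `torch.vmap` call (`in_dims=(0, None)` maps over the stacked parameters while broadcasting the input). A minimal usage sketch of the removed class, with invented module sizes and assuming `tensordict` is still installed:

    import torch
    import torch.nn as nn

    critics = [nn.Linear(4, 1) for _ in range(3)]  # hypothetical ensemble members
    ensemble = Ensemble(critics)

    x = torch.randn(8, 4)
    q_values = ensemble(x)  # shape (3, 8, 1): one output per ensemble member

Likewise, the removed `flatten_forward_unflatten` let a callable that expects (B, C, H, W) consume tensors with extra leading dims, e.g.:

    images = torch.randn(5, 8, 3, 32, 32)  # (T, B, C, H, W), sizes invented
    encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 16))
    feats = flatten_forward_unflatten(encoder, images)  # shape (5, 8, 16)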

pyproject.toml

@@ -71,7 +71,6 @@ dependencies = [
     "pyzmq>=26.2.1",
     "rerun-sdk>=0.21.0",
     "termcolor>=2.4.0",
-    "tensordict>=0.0.1",
     "torch>=2.2.1",
     "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
     "torchmetrics>=1.6.0",
@@ -89,7 +88,7 @@ dora = [
 ]
 dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
 feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
-hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "tensordict>=0.0.1"]
+hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
 intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
 mani_skill = ["mani-skill"]
 pi0 = ["transformers>=4.48.0"]