Remove unused functions and imports from modeling_sac.py
- Deleted the `find_and_copy_params` function and the `Ensemble` class, as they were deemed unnecessary.
- Removed the `flatten_forward_unflatten` helper, which a TODO comment had already flagged for removal.
- Cleaned up imports by removing the now-unused `from_modules` import from `tensordict`, and dropped the `tensordict` dependency from pyproject.toml.
- Simplified the assertion formatting in the `Policy` class for better readability.
parent 80e766c05c
commit c5c921cd7c
modeling_sac.py

@@ -23,7 +23,6 @@ from pathlib import Path
 import einops
 import numpy as np
-from tensordict import from_modules
 import torch
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
@@ -446,50 +445,6 @@ class MLP(nn.Module):
         return self.net(x)


-def find_and_copy_params(
-    original_state_dict: dict[str, torch.Tensor],
-    loaded_state_dict: dict[str, torch.Tensor],
-    pattern: str,
-    match_type: str = "contains",
-) -> list[str]:
-    """Find and copy parameters from original state dict to loaded state dict based on a pattern.
-
-    This function can search for keys in different ways based on the match_type:
-    - "exact": The key must exactly match the pattern
-    - "contains": The key must contain the pattern anywhere
-    - "startswith": The key must start with the pattern
-    - "endswith": The key must end with the pattern
-
-    Args:
-        original_state_dict: The source state dictionary
-        loaded_state_dict: The target state dictionary
-        pattern: The pattern to search for in keys
-        match_type: How to match the pattern (exact, contains, startswith, endswith)
-
-    Returns:
-        list[str]: List of keys that were copied
-    """
-    copied_keys = []
-
-    for key in original_state_dict:
-        should_copy = False
-
-        if match_type == "exact":
-            should_copy = key == pattern
-        elif match_type == "contains":
-            should_copy = pattern in key
-        elif match_type == "startswith":
-            should_copy = key.startswith(pattern)
-        elif match_type == "endswith":
-            should_copy = key.endswith(pattern)
-
-        if should_copy:
-            loaded_state_dict[key] = original_state_dict[key]
-            copied_keys.append(key)
-
-    return copied_keys
-
-
 class CriticHead(nn.Module):
     def __init__(
         self,
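For reference, a minimal usage sketch of the helper removed above; the state dicts and key names here are made up for illustration, not taken from any real checkpoint:

```python
import torch

# Hypothetical state dicts; in practice these would come from checkpoint loading.
original = {"encoder.conv.weight": torch.ones(3), "critic.fc.weight": torch.zeros(2)}
loaded = {"critic.fc.weight": torch.randn(2)}

# Copy every parameter whose key starts with "encoder." into `loaded`.
copied = find_and_copy_params(original, loaded, pattern="encoder.", match_type="startswith")
print(copied)  # ['encoder.conv.weight']
```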
@@ -682,9 +637,9 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
-            assert not torch.isnan(
-                log_std
-            ).any(), "[ERROR] log_std became NaN after std_layer!"
+            assert not torch.isnan(log_std).any(), (
+                "[ERROR] log_std became NaN after std_layer!"
+            )

             if self.use_tanh_squash:
                 log_std = torch.tanh(log_std)
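The change above only reflows the assert to match formatter line-length rules; behavior is identical. For context, a minimal sketch of the std-head pattern visible in this hunk — the layer dimensions and batch size are assumptions for illustration, not taken from modeling_sac.py:

```python
import torch
import torch.nn as nn

# Assumed dimensions, for illustration only.
outputs = torch.randn(4, 256)  # trunk features
std_layer = nn.Linear(256, 6)  # per-action-dim log-std head

log_std = std_layer(outputs)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"

use_tanh_squash = True
if use_tanh_squash:
    # tanh bounds log_std to (-1, 1) before any rescaling to a bounded range
    log_std = torch.tanh(log_std)
```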
@@ -932,60 +887,6 @@ class Identity(nn.Module):
         return x


-class Ensemble(nn.Module):
-    """
-    Vectorized ensemble of modules.
-    """
-
-    def __init__(self, modules, **kwargs):
-        super().__init__()
-        # combine_state_for_ensemble causes graph breaks
-        self.params = from_modules(*modules, as_module=True)
-        with self.params[0].data.to("meta").to_module(modules[0]):
-            self.module = deepcopy(modules[0])
-        self._repr = str(modules[0])
-        self._n = len(modules)
-
-    def __len__(self):
-        return self._n
-
-    def _call(self, params, *args, **kwargs):
-        with params.to_module(self.module):
-            return self.module(*args, **kwargs)
-
-    def forward(self, *args, **kwargs):
-        return torch.vmap(self._call, (0, None), randomness="different")(
-            self.params, *args, **kwargs
-        )
-
-    def __repr__(self):
-        return f"Vectorized {len(self)}x " + self._repr
-
-
-# TODO (azouitine): I think in our case this function is not usefull we should remove it
-# after some investigation
-# borrowed from tdmpc
-def flatten_forward_unflatten(
-    fn: Callable[[Tensor], Tensor], image_tensor: Tensor
-) -> Tensor:
-    """Helper to temporarily flatten extra dims at the start of the image tensor.
-
-    Args:
-        fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
-            (B, *), where * is any number of dimensions.
-        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
-            can be more than 1 dimensions, generally different from *.
-    Returns:
-        A return value from the callable reshaped to (**, *).
-    """
-    if image_tensor.ndim == 4:
-        return fn(image_tensor)
-    start_dims = image_tensor.shape[:-3]
-    inp = torch.flatten(image_tensor, end_dim=-4)
-    flat_out = fn(inp)
-    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
-
-
 def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
     converted_params = {}
     for outer_key, inner_dict in normalization_params.items():
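The removed `Ensemble` stacked the parameters of N structurally identical modules with `tensordict.from_modules` and evaluated all of them in one batched call via `torch.vmap`. A minimal usage sketch, assuming the class definition deleted above is still in scope and `tensordict` is installed; the critic shapes are invented for illustration:

```python
import torch
import torch.nn as nn

# Three structurally identical critic heads (shapes invented for illustration).
critics = [nn.Linear(8, 1) for _ in range(3)]
ensemble = Ensemble(critics)  # the class shown in the deleted hunk above

x = torch.randn(16, 8)
# vmap maps over the stacked params (dim 0) while broadcasting x,
# so all three critics run in one batched call.
q_values = ensemble(x)
print(len(ensemble), q_values.shape)  # 3 torch.Size([3, 16, 1])
```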
pyproject.toml

@@ -71,7 +71,6 @@ dependencies = [
     "pyzmq>=26.2.1",
     "rerun-sdk>=0.21.0",
     "termcolor>=2.4.0",
-    "tensordict>=0.0.1",
     "torch>=2.2.1",
     "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
     "torchmetrics>=1.6.0",
@@ -89,7 +88,7 @@ dora = [
 ]
 dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
 feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
-hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "tensordict>=0.0.1"]
+hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
 intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
 mani_skill = ["mani-skill"]
 pi0 = ["transformers>=4.48.0"]