Remove unused functions and imports from modeling_sac.py
- Deleted the `find_and_copy_params` function and the `Ensemble` class, as they were deemed unnecessary. - Cleaned up imports by removing `from_modules` from `tensordict` to enhance code clarity. - Simplified the assertion in the `Policy` class for better readability.
This commit is contained in:
parent
36f9ccd851
commit
e4a5971ffd
|
@ -23,7 +23,6 @@ from pathlib import Path
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensordict import from_modules
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
@ -446,50 +445,6 @@ class MLP(nn.Module):
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
def find_and_copy_params(
|
|
||||||
original_state_dict: dict[str, torch.Tensor],
|
|
||||||
loaded_state_dict: dict[str, torch.Tensor],
|
|
||||||
pattern: str,
|
|
||||||
match_type: str = "contains",
|
|
||||||
) -> list[str]:
|
|
||||||
"""Find and copy parameters from original state dict to loaded state dict based on a pattern.
|
|
||||||
|
|
||||||
This function can search for keys in different ways based on the match_type:
|
|
||||||
- "exact": The key must exactly match the pattern
|
|
||||||
- "contains": The key must contain the pattern anywhere
|
|
||||||
- "startswith": The key must start with the pattern
|
|
||||||
- "endswith": The key must end with the pattern
|
|
||||||
|
|
||||||
Args:
|
|
||||||
original_state_dict: The source state dictionary
|
|
||||||
loaded_state_dict: The target state dictionary
|
|
||||||
pattern: The pattern to search for in keys
|
|
||||||
match_type: How to match the pattern (exact, contains, startswith, endswith)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[str]: List of keys that were copied
|
|
||||||
"""
|
|
||||||
copied_keys = []
|
|
||||||
|
|
||||||
for key in original_state_dict:
|
|
||||||
should_copy = False
|
|
||||||
|
|
||||||
if match_type == "exact":
|
|
||||||
should_copy = key == pattern
|
|
||||||
elif match_type == "contains":
|
|
||||||
should_copy = pattern in key
|
|
||||||
elif match_type == "startswith":
|
|
||||||
should_copy = key.startswith(pattern)
|
|
||||||
elif match_type == "endswith":
|
|
||||||
should_copy = key.endswith(pattern)
|
|
||||||
|
|
||||||
if should_copy:
|
|
||||||
loaded_state_dict[key] = original_state_dict[key]
|
|
||||||
copied_keys.append(key)
|
|
||||||
|
|
||||||
return copied_keys
|
|
||||||
|
|
||||||
|
|
||||||
class CriticHead(nn.Module):
|
class CriticHead(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -682,9 +637,9 @@ class Policy(nn.Module):
|
||||||
# Compute standard deviations
|
# Compute standard deviations
|
||||||
if self.fixed_std is None:
|
if self.fixed_std is None:
|
||||||
log_std = self.std_layer(outputs)
|
log_std = self.std_layer(outputs)
|
||||||
assert not torch.isnan(
|
assert not torch.isnan(log_std).any(), (
|
||||||
log_std
|
"[ERROR] log_std became NaN after std_layer!"
|
||||||
).any(), "[ERROR] log_std became NaN after std_layer!"
|
)
|
||||||
|
|
||||||
if self.use_tanh_squash:
|
if self.use_tanh_squash:
|
||||||
log_std = torch.tanh(log_std)
|
log_std = torch.tanh(log_std)
|
||||||
|
@ -932,60 +887,6 @@ class Identity(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Ensemble(nn.Module):
|
|
||||||
"""
|
|
||||||
Vectorized ensemble of modules.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, modules, **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
# combine_state_for_ensemble causes graph breaks
|
|
||||||
self.params = from_modules(*modules, as_module=True)
|
|
||||||
with self.params[0].data.to("meta").to_module(modules[0]):
|
|
||||||
self.module = deepcopy(modules[0])
|
|
||||||
self._repr = str(modules[0])
|
|
||||||
self._n = len(modules)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self._n
|
|
||||||
|
|
||||||
def _call(self, params, *args, **kwargs):
|
|
||||||
with params.to_module(self.module):
|
|
||||||
return self.module(*args, **kwargs)
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
return torch.vmap(self._call, (0, None), randomness="different")(
|
|
||||||
self.params, *args, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"Vectorized {len(self)}x " + self._repr
|
|
||||||
|
|
||||||
|
|
||||||
# TODO (azouitine): I think in our case this function is not usefull we should remove it
|
|
||||||
# after some investigation
|
|
||||||
# borrowed from tdmpc
|
|
||||||
def flatten_forward_unflatten(
|
|
||||||
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
|
|
||||||
) -> Tensor:
|
|
||||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
|
||||||
(B, *), where * is any number of dimensions.
|
|
||||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
|
||||||
can be more than 1 dimensions, generally different from *.
|
|
||||||
Returns:
|
|
||||||
A return value from the callable reshaped to (**, *).
|
|
||||||
"""
|
|
||||||
if image_tensor.ndim == 4:
|
|
||||||
return fn(image_tensor)
|
|
||||||
start_dims = image_tensor.shape[:-3]
|
|
||||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
|
||||||
flat_out = fn(inp)
|
|
||||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||||
converted_params = {}
|
converted_params = {}
|
||||||
for outer_key, inner_dict in normalization_params.items():
|
for outer_key, inner_dict in normalization_params.items():
|
||||||
|
|
|
@ -71,7 +71,6 @@ dependencies = [
|
||||||
"pyzmq>=26.2.1",
|
"pyzmq>=26.2.1",
|
||||||
"rerun-sdk>=0.21.0",
|
"rerun-sdk>=0.21.0",
|
||||||
"termcolor>=2.4.0",
|
"termcolor>=2.4.0",
|
||||||
"tensordict>=0.0.1",
|
|
||||||
"torch>=2.2.1",
|
"torch>=2.2.1",
|
||||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||||
"torchmetrics>=1.6.0",
|
"torchmetrics>=1.6.0",
|
||||||
|
@ -89,7 +88,7 @@ dora = [
|
||||||
]
|
]
|
||||||
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
||||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||||
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "tensordict>=0.0.1"]
|
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
|
||||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||||
mani_skill = ["mani-skill"]
|
mani_skill = ["mani-skill"]
|
||||||
pi0 = ["transformers>=4.48.0"]
|
pi0 = ["transformers>=4.48.0"]
|
||||||
|
|
Loading…
Reference in New Issue