From c5c921cd7cf6dc554d84a77968b493e725410e8f Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Wed, 19 Mar 2025 18:53:01 +0000 Subject: [PATCH] Remove unused functions and imports from modeling_sac.py - Deleted the `find_and_copy_params` function and the `Ensemble` class, as they were deemed unnecessary. - Cleaned up imports by removing `from_modules` from `tensordict` to enhance code clarity. - Simplified the assertion in the `Policy` class for better readability. --- lerobot/common/policies/sac/modeling_sac.py | 105 +------------------- pyproject.toml | 3 +- 2 files changed, 4 insertions(+), 104 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index e634af3f..43221d5c 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -23,7 +23,6 @@ from pathlib import Path import einops import numpy as np -from tensordict import from_modules import torch import torch.nn as nn import torch.nn.functional as F # noqa: N812 @@ -446,50 +445,6 @@ class MLP(nn.Module): return self.net(x) -def find_and_copy_params( - original_state_dict: dict[str, torch.Tensor], - loaded_state_dict: dict[str, torch.Tensor], - pattern: str, - match_type: str = "contains", -) -> list[str]: - """Find and copy parameters from original state dict to loaded state dict based on a pattern. - - This function can search for keys in different ways based on the match_type: - - "exact": The key must exactly match the pattern - - "contains": The key must contain the pattern anywhere - - "startswith": The key must start with the pattern - - "endswith": The key must end with the pattern - - Args: - original_state_dict: The source state dictionary - loaded_state_dict: The target state dictionary - pattern: The pattern to search for in keys - match_type: How to match the pattern (exact, contains, startswith, endswith) - - Returns: - list[str]: List of keys that were copied - """ - copied_keys = [] - - for key in original_state_dict: - should_copy = False - - if match_type == "exact": - should_copy = key == pattern - elif match_type == "contains": - should_copy = pattern in key - elif match_type == "startswith": - should_copy = key.startswith(pattern) - elif match_type == "endswith": - should_copy = key.endswith(pattern) - - if should_copy: - loaded_state_dict[key] = original_state_dict[key] - copied_keys.append(key) - - return copied_keys - - class CriticHead(nn.Module): def __init__( self, @@ -682,9 +637,9 @@ class Policy(nn.Module): # Compute standard deviations if self.fixed_std is None: log_std = self.std_layer(outputs) - assert not torch.isnan( - log_std - ).any(), "[ERROR] log_std became NaN after std_layer!" + assert not torch.isnan(log_std).any(), ( + "[ERROR] log_std became NaN after std_layer!" + ) if self.use_tanh_squash: log_std = torch.tanh(log_std) @@ -932,60 +887,6 @@ class Identity(nn.Module): return x -class Ensemble(nn.Module): - """ - Vectorized ensemble of modules. - """ - - def __init__(self, modules, **kwargs): - super().__init__() - # combine_state_for_ensemble causes graph breaks - self.params = from_modules(*modules, as_module=True) - with self.params[0].data.to("meta").to_module(modules[0]): - self.module = deepcopy(modules[0]) - self._repr = str(modules[0]) - self._n = len(modules) - - def __len__(self): - return self._n - - def _call(self, params, *args, **kwargs): - with params.to_module(self.module): - return self.module(*args, **kwargs) - - def forward(self, *args, **kwargs): - return torch.vmap(self._call, (0, None), randomness="different")( - self.params, *args, **kwargs - ) - - def __repr__(self): - return f"Vectorized {len(self)}x " + self._repr - - -# TODO (azouitine): I think in our case this function is not usefull we should remove it -# after some investigation -# borrowed from tdmpc -def flatten_forward_unflatten( - fn: Callable[[Tensor], Tensor], image_tensor: Tensor -) -> Tensor: - """Helper to temporarily flatten extra dims at the start of the image tensor. - - Args: - fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return - (B, *), where * is any number of dimensions. - image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and - can be more than 1 dimensions, generally different from *. - Returns: - A return value from the callable reshaped to (**, *). - """ - if image_tensor.ndim == 4: - return fn(image_tensor) - start_dims = image_tensor.shape[:-3] - inp = torch.flatten(image_tensor, end_dim=-4) - flat_out = fn(inp) - return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:])) - - def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict: converted_params = {} for outer_key, inner_dict in normalization_params.items(): diff --git a/pyproject.toml b/pyproject.toml index 6f884d44..89577a4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,6 @@ dependencies = [ "pyzmq>=26.2.1", "rerun-sdk>=0.21.0", "termcolor>=2.4.0", - "tensordict>=0.0.1", "torch>=2.2.1", "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))", "torchmetrics>=1.6.0", @@ -89,7 +88,7 @@ dora = [ ] dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"] feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] -hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "tensordict>=0.0.1"] +hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"] intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] mani_skill = ["mani-skill"] pi0 = ["transformers>=4.48.0"]