Commit Graph

57 Commits

Author SHA1 Message Date
AdilZouitine 7361a11a4d Refactor SAC configuration and policy to support discrete actions
- Removed GraspCriticNetworkConfig class and integrated its parameters into SACConfig.
- Added num_discrete_actions parameter to SACConfig for better action handling.
- Updated SACPolicy to conditionally create grasp critic networks based on num_discrete_actions.
- Enhanced grasp critic forward pass to handle discrete actions and compute losses accordingly.
2025-04-18 15:10:22 +02:00
pre-commit-ci[bot] 88d26ae976 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:10:22 +02:00
s1lent4gnt 3a2308d86f Add grasp critic to the training loop
- Integrated the grasp critic gradient update to the training loop in learner_server
- Added Adam optimizer and configured grasp critic learning rate in configuration_sac
- Added target critics networks update after the critics gradient step
2025-04-18 15:10:22 +02:00
s1lent4gnt 66693965c0 Add grasp critic
- Implemented grasp critic to evaluate gripper actions
- Added corresponding config parameters for tuning
2025-04-18 15:10:22 +02:00
AdilZouitine 5b49601072 Fix convergence of sac, multiple torch compile on the same model caused divergence 2025-04-18 15:10:22 +02:00
AdilZouitine 0185a0b6fd Fix cuda graph break 2025-04-18 15:10:22 +02:00
pre-commit-ci[bot] eb44a06a9b [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:10:22 +02:00
AdilZouitine 4d5ecb082e Refactor SACPolicy for improved type annotations and readability
- Enhanced type annotations for variables in the `SACPolicy` class to improve code clarity.
- Updated method calls to use keyword arguments for better readability.
- Streamlined the extraction of batch components, ensuring consistent typing across the class methods.
2025-04-18 15:10:22 +02:00
AdilZouitine 6e687e2910 Refactor SACPolicy and learner_server for improved clarity and functionality
- Updated the `forward` method in `SACPolicy` to handle loss computation for actor, critic, and temperature models.
- Replaced direct calls to `compute_loss_*` methods with a unified `forward` method in `learner_server`.
- Enhanced batch processing by consolidating input parameters into a single dictionary for better readability and maintainability.
- Removed redundant code and improved documentation for clarity.
2025-04-18 15:10:22 +02:00
AdilZouitine 3beab33fac Refactor imports in modeling_sac.py for improved organization
- Rearranged import statements for better readability.
- Removed unused imports and streamlined the code structure.
2025-04-18 15:10:22 +02:00
Michel Aractingi 05a237ce10 Added gripper control mechanism to gym_manipulator
Moved HilSerl env config to configs/env/configs.py
fixes in actor_server and modeling_sac and configuration_sac
added the possibility of ignoring missing keys in env_cfg in get_features_from_env_config function
2025-04-18 15:10:22 +02:00
AdilZouitine db897a1619 [WIP] Update SAC configuration and environment settings
- Reduced frame rate in `ManiskillEnvConfig` from 400 to 200.
- Enhanced `SACConfig` with new dataclasses for actor, learner, and network configurations.
- Improved input and output feature management in `SACConfig`.
- Refactored `actor_server` and `learner_server` to access configuration properties directly.
- Updated training pipeline to validate configurations and handle dataset repo IDs more robustly.
2025-04-18 15:09:46 +02:00
AdilZouitine 80d566eb56 Handle new config with sac 2025-04-18 15:09:27 +02:00
pre-commit-ci[bot] 0ea27704f6 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:09:25 +02:00
pre-commit-ci[bot] 1c8daf11fd [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:07:46 +02:00
pre-commit-ci[bot] 42f95e827d [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:06:52 +02:00
AdilZouitine 618ed00d45 Initialize log_alpha with the logarithm of temperature_init in SACPolicy
- Updated the SACPolicy class to set log_alpha using the logarithm of the initial temperature value from the configuration.
2025-04-18 15:06:52 +02:00
pre-commit-ci[bot] 50d8db481e [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:06:52 +02:00
AdilZouitine e4a5971ffd Remove unused functions and imports from modeling_sac.py
- Deleted the `find_and_copy_params` function and the `Ensemble` class, as they were deemed unnecessary.
- Cleaned up imports by removing `from_modules` from `tensordict` to enhance code clarity.
- Simplified the assertion in the `Policy` class for better readability.
2025-04-18 15:06:52 +02:00
pre-commit-ci[bot] fd74c194b6 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:06:52 +02:00
AdilZouitine 0959694bab Refactor SACPolicy and learner server for improved replay buffer management
- Updated SACPolicy to create critic heads using a list comprehension for better readability.
- Simplified the saving and loading of models using `save_model` and `load_model` functions from the safetensors library.
- Introduced `initialize_offline_replay_buffer` function in the learner server to streamline offline dataset handling and replay buffer initialization.
- Enhanced logging for dataset loading processes to improve traceability during training.
2025-04-18 15:06:52 +02:00
AdilZouitine 66816fd871 Enhance SAC configuration and policy with gradient clipping and temperature management
- Introduced `grad_clip_norm` parameter in SAC configuration for gradient clipping
- Updated SACPolicy to store temperature as an instance variable for consistent usage
- Modified loss calculations in SACPolicy to utilize the instance temperature
- Enhanced MLP and CriticHead to support a customizable final activation function
- Implemented gradient clipping in the learner server during training steps for both actor and critic
- Added tracking for gradient norms in training information
2025-04-18 15:06:52 +02:00
pre-commit-ci[bot] 599326508f [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:06:52 +02:00
AdilZouitine 2f04d0d2b9 Add custom save and load methods for SAC policy
- Implement `_save_pretrained` method to handle TensorDict state saving
- Add `_from_pretrained` class method for loading SAC policy from files
- Create utility function `find_and_copy_params` to handle parameter copying
2025-04-18 15:06:52 +02:00
AdilZouitine e002c5ec56 Remove torch.no_grad decorator and optimize next action prediction in SAC policy
- Removed `@torch.no_grad` decorator from Unnormalize forward method

- Added TODO comment for optimizing next action prediction in SAC policy
- Minor formatting adjustment in NaN assertion for log standard deviation
Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
2025-04-18 15:06:52 +02:00
pre-commit-ci[bot] 85fe8a3f4e [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:06:51 +02:00
Michel Aractingi d3b84ecd6f Added caching function in the learner_server and modeling sac in order to limit the number of forward passes through the pretrained encoder when its frozen.
Added tensordict dependencies
Updated the version of torch and torchvision

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:58 +02:00
AdilZouitine 85242cac67 Refactor SAC policy with performance optimizations and multi-camera support
- Introduced Ensemble and CriticHead classes for more efficient critic network handling
- Added support for multiple camera inputs in observation encoder
- Optimized image encoding by batching image processing
- Updated configuration for ManiSkill environment with reduced image size and action scaling
- Compiled critic networks for improved performance
- Simplified normalization and ensemble handling in critic networks
Co-authored-by: michel-aractingi <michel.aractingi@gmail.com>
2025-04-18 15:04:44 +02:00
Michel Aractingi 0d88a5ee09 - Fixed big issue in the loading of the policy parameters sent by the learner to the actor -- pass only the actor to the `update_policy_parameters` and remove `strict=False`
- Fixed big issue in the normalization of the actions in the `forward` function of the critic -- remove the `torch.no_grad` decorator in `normalize.py` in the normalization function
- Fixed performance issue to boost the optimization frequency by setting the storage device to be the same as the device of learning.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:44 +02:00
Michel Aractingi 5195f40fd3 Hardcoded some normalization parameters. TODO refactor
Added masking actions on the level of the intervention actions and offline dataset

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:43 +02:00
Michel Aractingi 98c6557869 fix log_alpha in modeling_sac: change to nn.parameter
added pretrained vision model in policy

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:43 +02:00
Michel Aractingi 9784d8a47f Several fixes to move the actor_server and learner_server code from the maniskill environment to the real robot environment.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi d2c41b35db - Refactor observation encoder in `modeling_sac.py`
- added `torch.compile` to the actor and learner servers.
- organized imports in `train_sac.py`
- optimized the parameters push by not sending the frozen pre-trained encoder.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Yoel bc7b6d3daf [Port HIL-SERL] Add HF vision encoder option in SAC (#651)
Added support with custom pretrained vision encoder to the modeling sac implementation. Great job @ChorntonYoel !
2025-04-18 15:04:13 +02:00
Michel Aractingi aebea08a99 Added support for checkpointing the policy. We can save and load the policy state dict, optimizers state, optimization step and interaction step
Added functions for converting the replay buffer from and to LeRobotDataset. When we want to save the replay buffer, we convert it first to LeRobotDataset format and save it locally and vice-versa.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi 8cd44ae163 - Added additional logging information in wandb around the timings of the policy loop and optimization loop.
- Optimized critic design that improves the performance of the learner loop by a factor of 2
- Cleaned the code and fixed style issues

- Completed the config with actor_learner_config field that contains host-ip and port elemnts that are necessary for the actor-learner servers.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
AdilZouitine c8b1132846 Stable version of rlpd + drq 2025-04-18 15:04:10 +02:00
Adil Zouitine 760d60ad4b Change SAC policy implementation with configuration and modeling classes 2025-04-18 15:03:51 +02:00
Adil Zouitine 875c0271b7 SAC works 2025-04-18 15:03:51 +02:00
Adil Zouitine 57344bfde5 [WIP] correct sac implementation 2025-04-18 15:03:51 +02:00
Adil Zouitine 46827fb002 Add rlpd tricks 2025-04-18 15:03:51 +02:00
Adil Zouitine 2fd78879f6 SAC works 2025-04-18 15:03:51 +02:00
Adil Zouitine e8449e9630 remove breakpoint 2025-04-18 15:03:51 +02:00
Adil Zouitine a0e2be8b92 [WIP] correct sac implementation 2025-04-18 15:03:51 +02:00
Eugene Mironov d1d6ffd23c [Port HIL_SERL] Final fixes for the Reward Classifier (#598) 2025-04-18 15:03:01 +02:00
Michel Aractingi c6ca9523de split encoder for critic and actor 2025-04-18 15:03:01 +02:00
Michel Aractingi 642e3a3274 style fixes 2025-04-18 15:03:01 +02:00
KeWang1017 146148c48c Refactor SAC configuration and policy for improved action sampling and stability
- Updated SACConfig to replace standard deviation parameterization with log_std_min and log_std_max for better control over action distributions.
- Modified SACPolicy to streamline action selection and log probability calculations, enhancing stochastic behavior.
- Removed deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
2025-04-18 15:03:01 +02:00
KeWang1017 8f15835daa Refine SAC configuration and policy for enhanced performance
- Updated standard deviation parameterization in SACConfig to 'softplus' with defined min and max values for improved stability.
- Modified action sampling in SACPolicy to use reparameterized sampling, ensuring better gradient flow and log probability calculations.
- Cleaned up log probability calculations in TanhMultivariateNormalDiag for clarity and efficiency.
- Increased evaluation frequency in YAML configuration to 50000 for more efficient training cycles.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
2025-04-18 15:03:01 +02:00
KeWang1017 022bd65125 Refactor SACPolicy for improved action sampling and standard deviation handling
- Updated action selection to use distribution sampling and log probabilities for better stochastic behavior.
- Enhanced standard deviation clamping to prevent extreme values, ensuring stability in policy outputs.
- Cleaned up code by removing unnecessary comments and improving readability.

These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference.
2025-04-18 15:03:01 +02:00