[upgrade] unified interface and an easier example

* embed all functions and new concepts into legged_robot as the base implementation
* add the namedarraytuple concept borrowed from astooke/rlpyt
* use namedarraytuple in the rollout storage and wrap each minibatch in a single object
* add the `rollout_file` concept to store rollout data and demonstrations in files
* add a state estimator module/algorithm implementation
* completely rewrite the `play.py` example
* rename `actions_scaled_torque_clipped` to `actions_scaled_clipped`
* add example onboard code for deploying on the Go2.
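A rough sketch of the namedarraytuple idea (borrowed from astooke/rlpyt) that the new rollout storage builds on: a named tuple of tensors that is indexed field-by-field, so a whole minibatch can be passed around as one object. This is an illustrative stand-in with made-up field names, not the repository's actual implementation.

```python
from collections import namedtuple

import torch

# Illustrative minibatch container; the real rsl_rl storage defines its own fields.
Minibatch = namedtuple("Minibatch", ["obs", "actions", "rewards"])

def index_minibatch(batch, idx):
    """Apply the same index/slice to every field, mimicking namedarraytuple indexing."""
    return Minibatch(*(field[idx] for field in batch))

batch = Minibatch(
    obs=torch.zeros(8, 48),
    actions=torch.zeros(8, 12),
    rewards=torch.zeros(8),
)
first_two = index_minibatch(batch, slice(0, 2))  # every field now has leading dim 2
```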
Ziwen Zhuang 2024-05-24 19:01:43 +08:00
parent 96317ac12f
commit 1ffd6d7c05
108 changed files with 47664 additions and 3080 deletions

View File

@ -16,7 +16,7 @@ Conference on Robot Learning (CoRL) 2023, **Oral**, **Best Systems Paper Award F
## Repository Structure ##
* `legged_gym`: contains the isaacgym environment and config files.
- `legged_gym/legged_gym/envs/a1/`: contains all the training config files.
- `legged_gym/legged_gym/envs/{robot}/`: contains all the training config files for a specific robot
- `legged_gym/legged_gym/envs/base/`: contains all the environment implementation.
- `legged_gym/legged_gym/utils/terrain/`: contains the terrain generation code.
* `rsl_rl`: contains the network module and algorithm implementation. You can copy this folder directly to your robot.
@ -24,19 +24,23 @@ Conference on Robot Learning (CoRL) 2023, **Oral**, **Best Systems Paper Award F
- `rsl_rl/rsl_rl/modules/`: contains the network module implementation.
## Training in Simulation ##
To install and run the code for training A1 in simulation, please clone this repository and follow the instructions in [legged_gym/README.md](legged_gym/README.md).
To install and run the code for training A1/Go2 in simulation, please clone this repository and follow the instructions in [legged_gym/README.md](legged_gym/README.md).
## Hardware Deployment ##
To deploy the trained model on your real robot, please follow the instructions in [Deploy.md](Deploy.md).
To deploy the trained model on your Unitree Go1 robot, please follow the instructions in [Deploy-Go1.md](onboard_codes/Deploy-Go1.md).
To deploy the trained model on your Unitree Go2 robot, please follow the instructions in [Deploy-Go2.md](onboard_codes/Deploy-Go2.md).
## Troubleshooting ##
If you cannot run the distillation part, or all graphics computation goes to GPU 0 despite having multiple GPUs and setting CUDA_VISIBLE_DEVICES, please use Docker to isolate each GPU.
## To Do (will be done before Nov 2023) ##
- [x] Go1 training configuration (not from scratch)
## To Do ##
- [x] Go1 training configuration (does not guarantee the same performance as the paper)
- [ ] A1 deployment code
- [x] Go1 deployment code
- [x] Go2 training configuration example (does not guarantee the same performance as the paper)
- [x] Go2 deployment code example
## Citation ##
If you find this project helpful to your research, please consider citing us! It is really important to us.

View File

@ -19,6 +19,53 @@ This is the tutorial for training the skill policy and distilling the parkour po
## Usage ##
***Always run your script in the root path of this legged_gym folder (which contains a `setup.py` file).***
### Go2 example (newer and simpler, does not guarantee performance) ###
1. Train a walking policy for planar locomotion
```python
python legged_gym/scripts/train.py --headless --task go2
```
Training logs will be saved in `logs/rough_go2`.
2. Train the parkour policy. In this example, we use **scandot** for terrain perception, so we remove the **crawl** skill.
- Update the `"{Your trained walking model directory}"` value in the config file `legged_gym/legged_gym/envs/go2/go2_field_config.py` with the trained walking policy folder name.
- Run ```python legged_gym/scripts/train.py --headless --task go2_field```
The training logs will be saved in `logs/field_go2`.
3. Distill the parkour policy
- Update the following literals in the config file `legged_gym/legged_gym/envs/go2/go2_distill_config.py`
- `"{Your trained oracle parkour model folder}"`: The oracle parkour policy folder name in the last step.
- `"{The latest model filename in the directory}"`: The model file name in the oracle parkour policy folder in the last step.
- `"{A temporary directory to store collected trajectory}"`: A temporary directory to store the collected trajectory data.
- Calibrate your depth camera extrinsic pose and update the `position` and `rotation` fields in the `sensor.forward_camera` class (see the sketch after this list).
- Run the distillation process (choose one of the two options)
1. Run the distillation process in a single process if you believe your GPU is more powerful than an Nvidia RTX 3090.
Set `multi_process_` to `False` in the config file.
Run ```python legged_gym/scripts/train.py --headless --task go2_distill```
2. Run the distillation process in multiple processes on multiple GPUs.
Run ```python legged_gym/scripts/train.py --headless --task go2_distill```
Find the log directory generated by the training process when it prompts that it is waiting (e.g. **Jul18_07-22-08_Go2_10skills_fromJul16_07-38-08**).
Run ```python legged_gym/scripts/collect.py --headless --task go2_distill --log --load_run {the log directory name}``` in another terminal on another GPU. (You can run multiple collectors in parallel.)
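Below is a minimal sketch of the camera-extrinsic fields that the calibration step above refers to. Only the `position` and `rotation` fields of the `sensor.forward_camera` class come from this README; the surrounding nesting and the numeric values are placeholders, so check `go2_distill_config.py` for the actual layout.

```python
# Hypothetical illustration only: measure your camera's mounting pose and fill in your numbers.
class sensor:
    class forward_camera:
        position = [0.27, 0.0, 0.03]  # camera offset from the robot base frame, in meters (example values)
        rotation = [0.0, 0.45, 0.0]   # camera roll/pitch/yaw relative to the base, in radians (example values)
```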
### A1 example ###
1. The specialized skill policy is trained using `a1_field_config.py` as task `a1_field`
Run the command `python legged_gym/scripts/train.py --headless --task a1_field`
@ -31,11 +78,11 @@ This is the tutorial for training the skill policy and distilling the parkour po
With `python legged_gym/scripts/collect.py --headless --task a1_distill --load_run {your training run}` you launch the collector. The process will load the training policy and start collecting data. The collected data will be saved in the directory prompted by the trainer. Remove it after you finish distillation.
### Train a walk policy ###
#### Train a walk policy ####
Launch the training with `python legged_gym/scripts/train.py --headless --task a1_field`. You will find the training log in `logs/a1_field`. The folder name is also the run name.
### Train each separate skill ###
#### Train each separate skill ####
- Launch the script with task `a1_climb`, `a1_leap`, `a1_crawl`, or `a1_tilt`. The training log will also be saved in `logs/a1_field`.
@ -47,7 +94,7 @@ Launch the training by `python legged_gym/scripts/train.py --headless --task a1_
- Do remember to update the `load_run` field to the corresponding log directory so that the policy from the previous stage is loaded.
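As a rough illustration (simplified and partly hypothetical), pointing a skill config at the previous stage usually looks like the snippet below, mirroring the `runner` sections of the A1 skill configs in this commit; the run directory name is a placeholder and `field_a1` is the experiment name used by those configs.

```python
import os.path as osp

logs_root = "logs"  # repository-level logs folder

class runner:  # simplified stand-in for the skill config's runner class
    resume = True
    # point load_run at the previous stage's log directory (the folder name is the run name)
    load_run = osp.join(logs_root, "field_a1", "{your previous run directory name}")
```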
### Distill the parkour policy ###
#### Distill the parkour policy ####
**You will need at least two GPUs that can render in IsaacGym, each with at least 24GB of memory (typically an RTX 3090).**
@ -82,3 +129,5 @@ Launch the training by `python legged_gym/scripts/train.py --headless --task a1_
```bash
python legged_gym/scripts/play.py --task {task} --load_run {run_name}
```
Here, `{run_name}` can be the absolute path of your log directory (the one containing the `config.json` file).

View File

@ -32,7 +32,7 @@ from legged_gym import LEGGED_GYM_ROOT_DIR, LEGGED_GYM_ENVS_DIR
from legged_gym.envs.a1.a1_config import A1RoughCfg, A1RoughCfgPPO, A1PlaneCfg, A1RoughCfgTPPO
from .base.legged_robot import LeggedRobot
from .base.legged_robot_field import LeggedRobotField
from .base.legged_robot_noisy import LeggedRobotNoisy
from .base.robot_field_noisy import RobotFieldNoisy
from .anymal_c.anymal import Anymal
from .anymal_c.mixed_terrains.anymal_c_rough_config import AnymalCRoughCfg, AnymalCRoughCfgPPO
from .anymal_c.flat.anymal_c_flat_config import AnymalCFlatCfg, AnymalCFlatCfgPPO
@ -44,6 +44,9 @@ from .a1.a1_field_config import A1FieldCfg, A1FieldCfgPPO
from .a1.a1_field_distill_config import A1FieldDistillCfg, A1FieldDistillCfgPPO
from .go1.go1_field_config import Go1FieldCfg, Go1FieldCfgPPO
from .go1.go1_field_distill_config import Go1FieldDistillCfg, Go1FieldDistillCfgPPO
from .go2.go2_config import Go2RoughCfg, Go2RoughCfgPPO
from .go2.go2_field_config import Go2FieldCfg, Go2FieldCfgPPO
from .go2.go2_distill_config import Go2DistillCfg, Go2DistillCfgPPO
import os
@ -54,35 +57,15 @@ task_registry.register( "anymal_c_rough", Anymal, AnymalCRoughCfg(), AnymalCRoug
task_registry.register( "anymal_c_flat", Anymal, AnymalCFlatCfg(), AnymalCFlatCfgPPO() )
task_registry.register( "anymal_b", Anymal, AnymalBRoughCfg(), AnymalBRoughCfgPPO() )
task_registry.register( "a1", LeggedRobot, A1RoughCfg(), A1RoughCfgPPO() )
task_registry.register( "a1_teacher", LeggedRobot, A1PlaneCfg(), A1RoughCfgTPPO() )
task_registry.register( "a1_field", LeggedRobotNoisy, A1FieldCfg(), A1FieldCfgPPO() )
task_registry.register( "a1_distill", LeggedRobotNoisy, A1FieldDistillCfg(), A1FieldDistillCfgPPO() )
task_registry.register( "cassie", Cassie, CassieRoughCfg(), CassieRoughCfgPPO() )
task_registry.register( "go1_field", LeggedRobotNoisy, Go1FieldCfg(), Go1FieldCfgPPO())
task_registry.register( "go1_distill", LeggedRobotNoisy, Go1FieldDistillCfg(), Go1FieldDistillCfgPPO())
task_registry.register( "go1_field", LeggedRobot, Go1FieldCfg(), Go1FieldCfgPPO())
task_registry.register( "go1_distill", LeggedRobot, Go1FieldDistillCfg(), Go1FieldDistillCfgPPO())
task_registry.register( "go2", LeggedRobot, Go2RoughCfg(), Go2RoughCfgPPO() )
task_registry.register( "go2_field", RobotFieldNoisy, Go2FieldCfg(), Go2FieldCfgPPO() )
task_registry.register( "go2_distill", RobotFieldNoisy, Go2DistillCfg(), Go2DistillCfgPPO() )
## The following tasks are for the convenience of the open-source release
from .a1.a1_remote_config import A1RemoteCfg, A1RemoteCfgPPO
task_registry.register( "a1_remote", LeggedRobotNoisy, A1RemoteCfg(), A1RemoteCfgPPO() )
from .a1.a1_jump_config import A1JumpCfg, A1JumpCfgPPO
task_registry.register( "a1_jump", LeggedRobotNoisy, A1JumpCfg(), A1JumpCfgPPO() )
from .a1.a1_down_config import A1DownCfg, A1DownCfgPPO
task_registry.register( "a1_down", LeggedRobotNoisy, A1DownCfg(), A1DownCfgPPO() )
from .a1.a1_leap_config import A1LeapCfg, A1LeapCfgPPO
task_registry.register( "a1_leap", LeggedRobotNoisy, A1LeapCfg(), A1LeapCfgPPO() )
from .a1.a1_crawl_config import A1CrawlCfg, A1CrawlCfgPPO
task_registry.register( "a1_crawl", LeggedRobotNoisy, A1CrawlCfg(), A1CrawlCfgPPO() )
from .a1.a1_tilt_config import A1TiltCfg, A1TiltCfgPPO
task_registry.register( "a1_tilt", LeggedRobotNoisy, A1TiltCfg(), A1TiltCfgPPO() )
task_registry.register( "a1_remote", LeggedRobot, A1RemoteCfg(), A1RemoteCfgPPO() )
from .go1.go1_remote_config import Go1RemoteCfg, Go1RemoteCfgPPO
task_registry.register( "go1_remote", LeggedRobotNoisy, Go1RemoteCfg(), Go1RemoteCfgPPO() )
from .go1.go1_jump_config import Go1JumpCfg, Go1JumpCfgPPO
task_registry.register( "go1_jump", LeggedRobotNoisy, Go1JumpCfg(), Go1JumpCfgPPO() )
from .go1.go1_down_config import Go1DownCfg, Go1DownCfgPPO
task_registry.register( "go1_down", LeggedRobotNoisy, Go1DownCfg(), Go1DownCfgPPO() )
from .go1.go1_leap_config import Go1LeapCfg, Go1LeapCfgPPO
task_registry.register( "go1_leap", LeggedRobotNoisy, Go1LeapCfg(), Go1LeapCfgPPO() )
from .go1.go1_crawl_config import Go1CrawlCfg, Go1CrawlCfgPPO
task_registry.register( "go1_crawl", LeggedRobotNoisy, Go1CrawlCfg(), Go1CrawlCfgPPO() )
from .go1.go1_tilt_config import Go1TiltCfg, Go1TiltCfgPPO
task_registry.register( "go1_tilt", LeggedRobotNoisy, Go1TiltCfg(), Go1TiltCfgPPO() )
task_registry.register( "go1_remote", LeggedRobot, Go1RemoteCfg(), Go1RemoteCfgPPO() )

View File

@ -1,4 +1,5 @@
import numpy as np
import os.path as osp
from legged_gym.envs.a1.a1_field_config import A1FieldCfg, A1FieldCfgPPO
from legged_gym.utils.helpers import merge_dict
@ -7,7 +8,6 @@ class A1CrawlCfg( A1FieldCfg ):
#### uncomment this to train non-virtual terrain
class sensor( A1FieldCfg.sensor ):
class proprioception( A1FieldCfg.sensor.proprioception ):
delay_action_obs = True
latency_range = [0.04-0.0025, 0.04+0.0075]
#### uncomment the above to train non-virtual terrain
@ -28,11 +28,11 @@ class A1CrawlCfg( A1FieldCfg ):
wall_height= 0.6,
no_perlin_at_obstacle= False,
),
virtual_terrain= True, # Change this to False for real terrain
virtual_terrain= False, # Change this to False for real terrain
))
TerrainPerlin_kwargs = merge_dict(A1FieldCfg.terrain.TerrainPerlin_kwargs, dict(
zScale= 0.1,
zScale= 0.12,
))
class commands( A1FieldCfg.commands ):
@ -41,6 +41,9 @@ class A1CrawlCfg( A1FieldCfg ):
lin_vel_y = [0.0, 0.0]
ang_vel_yaw = [0., 0.]
class asset( A1FieldCfg.asset ):
terminate_after_contacts_on = ["base"]
class termination( A1FieldCfg.termination ):
# additional factors that determine whether to terminate the episode
termination_terms = [
@ -51,16 +54,29 @@ class A1CrawlCfg( A1FieldCfg ):
"out_of_track",
]
class domain_rand( A1FieldCfg.domain_rand ):
init_base_rot_range = dict(
roll= [-0.1, 0.1],
pitch= [-0.1, 0.1],
)
class rewards( A1FieldCfg.rewards ):
class scales:
tracking_ang_vel = 0.05
world_vel_l2norm = -1.
legs_energy_substeps = -2e-5
legs_energy_substeps = -1e-5
alive = 2.
penetrate_depth = -6e-2 # comment this out if training non-virtual terrain
penetrate_volume = -6e-2 # comment this out if training non-virtual terrain
exceed_dof_pos_limits = -1e-1
exceed_torque_limits_i = -2e-1
# penetrate_depth = -6e-2 # comment this out if training non-virtual terrain
# penetrate_volume = -6e-2 # comment this out if training non-virtual terrain
exceed_dof_pos_limits = -8e-1
# exceed_torque_limits_i = -2e-1
exceed_torque_limits_l1norm = -4e-1
# collision = -0.05
# tilt_cond = 0.1
torques = -1e-5
yaw_abs = -0.1
lin_pos_y = -0.1
soft_dof_pos_limit = 0.9
class curriculum( A1FieldCfg.curriculum ):
penetrate_volume_threshold_harder = 1500
@ -69,27 +85,49 @@ class A1CrawlCfg( A1FieldCfg ):
penetrate_depth_threshold_easier = 400
logs_root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))), "logs")
class A1CrawlCfgPPO( A1FieldCfgPPO ):
class algorithm( A1FieldCfgPPO.algorithm ):
entropy_coef = 0.0
clip_min_std = 0.2
clip_min_std = 0.1
class runner( A1FieldCfgPPO.runner ):
policy_class_name = "ActorCriticRecurrent"
experiment_name = "field_a1"
run_name = "".join(["Skill",
resume = True
load_run = "{Your traind walking model directory}"
load_run = "{Your virtually trained crawling model directory}"
# load_run = "Aug21_06-12-58_Skillcrawl_propDelay0.00-0.05_virtual"
# load_run = osp.join(logs_root, "field_a1_oracle/May21_05-25-19_Skills_crawl_pEnergy2e-5_rAlive1_pPenV6e-2_pPenD6e-2_pPosY0.2_kp50_noContactTerminate_aScale0.5")
# load_run = osp.join(logs_root, "field_a1_oracle/Sep26_01-38-19_Skills_crawl_propDelay0.04-0.05_pEnergy-4e-5_pTorqueL13e-01_kp40_fromMay21_05-25-19")
# load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Sep26_14-30-24_Skills_crawl_propDelay0.04-0.05_pEnergy-2e-5_pDof8e-01_pTorqueL14e-01_rTilt5e-01_pCollision0.2_maxPushAng0.5_kp40_fromSep26_01-38-19")
# load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Oct09_09-58-26_Skills_crawl_propDelay0.04-0.05_pEnergy-1e-5_pDof8e-01_pTorqueL14e-01_maxPushAng0.0_kp40_fromSep26_14-30-24")
load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Oct11_12-19-00_Skills_crawl_propDelay0.04-0.05_pEnergy-1e-5_pDof8e-01_pTorqueL14e-01_pPosY0.1_maxPushAng0.3_kp40_fromOct09_09-58-26")
run_name = "".join(["Skills_",
("Multi" if len(A1CrawlCfg.terrain.BarrierTrack_kwargs["options"]) > 1 else (A1CrawlCfg.terrain.BarrierTrack_kwargs["options"][0] if A1CrawlCfg.terrain.BarrierTrack_kwargs["options"] else "PlaneWalking")),
("_comXRange{:.1f}-{:.1f}".format(A1CrawlCfg.domain_rand.com_range.x[0], A1CrawlCfg.domain_rand.com_range.x[1])),
("_noLinVel" if not A1CrawlCfg.env.use_lin_vel else ""),
("_propDelay{:.2f}-{:.2f}".format(
A1CrawlCfg.sensor.proprioception.latency_range[0],
A1CrawlCfg.sensor.proprioception.latency_range[1],
) if A1CrawlCfg.sensor.proprioception.delay_action_obs else ""
),
("_pEnergy" + np.format_float_scientific(A1CrawlCfg.rewards.scales.legs_energy_substeps, precision= 1, exp_digits= 1, trim= "-") if A1CrawlCfg.rewards.scales.legs_energy_substeps != 0. else ""),
# ("_pPenD{:.0e}".format(A1CrawlCfg.rewards.scales.penetrate_depth) if getattr(A1CrawlCfg.rewards.scales, "penetrate_depth", 0.) != 0. else ""),
("_pEnergySubsteps" + np.format_float_scientific(A1CrawlCfg.rewards.scales.legs_energy_substeps, precision= 1, exp_digits= 1, trim= "-") if getattr(A1CrawlCfg.rewards.scales, "legs_energy_substeps", 0.) != 0. else ""),
("_pDof{:.0e}".format(-A1CrawlCfg.rewards.scales.exceed_dof_pos_limits) if getattr(A1CrawlCfg.rewards.scales, "exceed_dof_pos_limits", 0.) != 0 else ""),
("_pTorque" + np.format_float_scientific(-A1CrawlCfg.rewards.scales.torques, precision= 1, exp_digits= 1, trim= "-") if getattr(A1CrawlCfg.rewards.scales, "torques", 0.) != 0 else ""),
("_pTorqueL1{:.0e}".format(-A1CrawlCfg.rewards.scales.exceed_torque_limits_l1norm) if getattr(A1CrawlCfg.rewards.scales, "exceed_torque_limits_l1norm", 0.) != 0 else ""),
# ("_rTilt{:.0e}".format(A1CrawlCfg.rewards.scales.tilt_cond) if getattr(A1CrawlCfg.rewards.scales, "tilt_cond", 0.) != 0 else ""),
# ("_pYaw{:.1f}".format(-A1CrawlCfg.rewards.scales.yaw_abs) if getattr(A1CrawlCfg.rewards.scales, "yaw_abs", 0.) != 0 else ""),
# ("_pPosY{:.1f}".format(-A1CrawlCfg.rewards.scales.lin_pos_y) if getattr(A1CrawlCfg.rewards.scales, "lin_pos_y", 0.) != 0 else ""),
# ("_pCollision{:.1f}".format(-A1CrawlCfg.rewards.scales.collision) if getattr(A1CrawlCfg.rewards.scales, "collision", 0.) != 0 else ""),
# ("_kp{:d}".format(int(A1CrawlCfg.control.stiffness["joint"])) if A1CrawlCfg.control.stiffness["joint"] != 50 else ""),
("_noDelayActObs" if not A1CrawlCfg.sensor.proprioception.delay_action_obs else ""),
("_noTanh"),
("_virtual" if A1CrawlCfg.terrain.BarrierTrack_kwargs["virtual_terrain"] else ""),
("_noResume" if not resume else "_from" + "_".join(load_run.split("/")[-1].split("_")[:2])),
])
resume = True
load_run = "{Your traind walking model directory}"
load_run = "{Your virtually trained crawling model directory}"
max_iterations = 20000
save_interval = 500

View File

@ -200,6 +200,7 @@ class A1FieldCfg( A1RoughCfg ):
exceed_dof_pos_limits = -1e-1
exceed_torque_limits_i = -2e-1
soft_dof_pos_limit = 0.01
only_positive_rewards = False
class normalization( A1RoughCfg.normalization ):
class obs_scales( A1RoughCfg.normalization.obs_scales ):
@ -286,7 +287,7 @@ class A1FieldCfgPPO( A1RoughCfgPPO ):
("_propDelay{:.2f}-{:.2f}".format(
A1FieldCfg.sensor.proprioception.latency_range[0],
A1FieldCfg.sensor.proprioception.latency_range[1],
) if A1FieldCfg.sensor.proprioception.delay_action_obs else ""
) if A1FieldCfg.sensor.proprioception.latency_range[1] > 0. else ""
),
("_aScale{:d}{:d}{:d}".format(
int(A1FieldCfg.control.action_scale[0] * 10),
@ -297,6 +298,6 @@ class A1FieldCfgPPO( A1RoughCfgPPO ):
),
])
resume = False
max_iterations = 10000
max_iterations = 5000
save_interval = 500

View File

@ -1,15 +1,15 @@
import numpy as np
from os import path as osp
from legged_gym.envs.a1.a1_field_config import A1FieldCfg, A1FieldCfgPPO
from legged_gym.utils.helpers import merge_dict
class A1LeapCfg( A1FieldCfg ):
#### uncomment this to train non-virtual terrain
# class sensor( A1FieldCfg.sensor ):
# class proprioception( A1FieldCfg.sensor.proprioception ):
# delay_action_obs = True
# latency_range = [0.04-0.0025, 0.04+0.0075]
#### uncomment the above to train non-virtual terrain
### uncomment this to train non-virtual terrain
class sensor( A1FieldCfg.sensor ):
class proprioception( A1FieldCfg.sensor.proprioception ):
latency_range = [0.04-0.0025, 0.04+0.0075]
### uncomment the above to train non-virtual terrain
class terrain( A1FieldCfg.terrain ):
max_init_terrain_level = 2
@ -22,16 +22,16 @@ class A1LeapCfg( A1FieldCfg ):
"leap",
],
leap= dict(
length= (0.2, 1.0),
length= (0.2, 0.8),
depth= (0.4, 0.8),
height= 0.2,
height= 0.12,
),
virtual_terrain= False, # Change this to False for real terrain
no_perlin_threshold= 0.06,
virtual_terrain= True, # Change this to False for real terrain
no_perlin_threshold= 0.1,
))
TerrainPerlin_kwargs = merge_dict(A1FieldCfg.terrain.TerrainPerlin_kwargs, dict(
zScale= [0.05, 0.1],
zScale= [0.05, 0.15],
))
class commands( A1FieldCfg.commands ):
@ -57,16 +57,28 @@ class A1LeapCfg( A1FieldCfg ):
threshold= 2.0,
))
class domain_rand( A1FieldCfg.domain_rand ):
init_base_rot_range = dict(
roll= [-0.1, 0.1],
pitch= [-0.1, 0.1],
)
class rewards( A1FieldCfg.rewards ):
class scales:
tracking_ang_vel = 0.05
world_vel_l2norm = -1.
legs_energy_substeps = -1e-6
# legs_energy_substeps = -8e-6
alive = 2.
penetrate_depth = -4e-3
penetrate_volume = -4e-3
exceed_dof_pos_limits = -1e-1
exceed_torque_limits_i = -2e-1
penetrate_depth = -1e-2
penetrate_volume = -1e-2
exceed_dof_pos_limits = -4e-1
exceed_torque_limits_l1norm = -8e-1
# feet_contact_forces = -1e-2
torques = -2e-5
collision = -0.5
lin_pos_y = -0.1
yaw_abs = -0.1
soft_dof_pos_limit = 0.5
class curriculum( A1FieldCfg.curriculum ):
penetrate_volume_threshold_harder = 9000
@ -75,6 +87,7 @@ class A1LeapCfg( A1FieldCfg ):
penetrate_depth_threshold_easier = 5000
logs_root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))), "logs")
class A1LeapCfgPPO( A1FieldCfgPPO ):
class algorithm( A1FieldCfgPPO.algorithm ):
entropy_coef = 0.0
@ -83,19 +96,39 @@ class A1LeapCfgPPO( A1FieldCfgPPO ):
class runner( A1FieldCfgPPO.runner ):
policy_class_name = "ActorCriticRecurrent"
experiment_name = "field_a1"
run_name = "".join(["Skill",
resume = True
load_run = "{Your traind walking model directory}"
load_run = "{Your virtually trained leap model directory}"
# load_run = osp.join(logs_root, "field_a1_oracle/Jun04_01-03-59_Skills_leap_pEnergySubsteps2e-6_rAlive2_pPenV4e-3_pPenD4e-3_pPosY0.20_pYaw0.20_pTorqueExceedSquare1e-3_leapH0.2_propDelay0.04-0.05_noPerlinRate0.2_aScale0.5")
# load_run = "Sep27_02-44-48_Skills_leap_propDelay0.04-0.05_pDofLimit8e-01_pCollision0.1_kp40_kd0.5fromJun04_01-03-59"
# load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Sep27_14-56-25_Skills_leap_propDelay0.04-0.05_pEnergySubsteps-8e-06_pDofLimit8e-01_pCollision0.1_kp40_kd0.5fromSep27_02-44-48")
# load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Oct05_02-16-22_Skills_leap_propDelay0.04-0.05_pEnergySubsteps-8e-06_pPenD8.e-3_pDofLimit4e-01_pCollision0.5_kp40_kd0.5fromSep27_14-56-25")
load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Oct09_09-51-58_Skills_leap_propDelay0.04-0.05_pEnergySubsteps-8e-06_pPenD1.e-2_pDofLimit4e-01_pCollision0.5_kp40_kd0.5fromOct05_02-16-22")
run_name = "".join(["Skills_",
("Multi" if len(A1LeapCfg.terrain.BarrierTrack_kwargs["options"]) > 1 else (A1LeapCfg.terrain.BarrierTrack_kwargs["options"][0] if A1LeapCfg.terrain.BarrierTrack_kwargs["options"] else "PlaneWalking")),
("_comXRange{:.1f}-{:.1f}".format(A1LeapCfg.domain_rand.com_range.x[0], A1LeapCfg.domain_rand.com_range.x[1])),
("_noLinVel" if not A1LeapCfg.env.use_lin_vel else ""),
("_propDelay{:.2f}-{:.2f}".format(
A1LeapCfg.sensor.proprioception.latency_range[0],
A1LeapCfg.sensor.proprioception.latency_range[1],
) if A1LeapCfg.sensor.proprioception.delay_action_obs else ""
),
("_pEnergySubsteps{:.0e}".format(A1LeapCfg.rewards.scales.legs_energy_substeps) if A1LeapCfg.rewards.scales.legs_energy_substeps != -2e-6 else ""),
("_pEnergySubsteps{:.0e}".format(A1LeapCfg.rewards.scales.legs_energy_substeps) if getattr(A1LeapCfg.rewards.scales, "legs_energy_substeps", -2e-6) != -2e-6 else ""),
("_pTorques" + np.format_float_scientific(-A1LeapCfg.rewards.scales.torques, precision=1, exp_digits=1) if getattr(A1LeapCfg.rewards.scales, "torques", 0.) != -0. else ""),
# ("_pPenD" + np.format_float_scientific(-A1LeapCfg.rewards.scales.penetrate_depth, precision=1, exp_digits=1) if A1LeapCfg.rewards.scales.penetrate_depth != -4e-3 else ""),
# ("_pYaw{:.1f}".format(-A1LeapCfg.rewards.scales.yaw_abs) if getattr(A1LeapCfg.rewards.scales, "yaw_abs", 0.) != 0. else ""),
# ("_pDofLimit{:.0e}".format(-A1LeapCfg.rewards.scales.exceed_dof_pos_limits) if getattr(A1LeapCfg.rewards.scales, "exceed_dof_pos_limits", 0.) != 0. else ""),
# ("_pCollision{:.1f}".format(-A1LeapCfg.rewards.scales.collision) if getattr(A1LeapCfg.rewards.scales, "collision", 0.) != 0. else ""),
("_pContactForces" + np.format_float_scientific(-A1LeapCfg.rewards.scales.feet_contact_forces, precision=1, exp_digits=1) if getattr(A1LeapCfg.rewards.scales, "feet_contact_forces", 0.) != 0. else ""),
("_leapHeight{:.1f}".format(A1LeapCfg.terrain.BarrierTrack_kwargs["leap"]["height"]) if A1LeapCfg.terrain.BarrierTrack_kwargs["leap"]["height"] != 0.2 else ""),
# ("_kp{:d}".format(int(A1LeapCfg.control.stiffness["joint"])) if A1LeapCfg.control.stiffness["joint"] != 50 else ""),
# ("_kd{:.1f}".format(A1LeapCfg.control.damping["joint"]) if A1LeapCfg.control.damping["joint"] != 0. else ""),
("_noDelayActObs" if not A1LeapCfg.sensor.proprioception.delay_action_obs else ""),
("_noTanh"),
("_virtual" if A1LeapCfg.terrain.BarrierTrack_kwargs["virtual_terrain"] else ""),
("_noResume" if not resume else "from" + "_".join(load_run.split("/")[-1].split("_")[:2])),
])
resume = True
load_run = "{Your traind walking model directory}"
load_run = "{Your virtually trained leap model directory}"
max_iterations = 20000
save_interval = 500

View File

@ -1,452 +0,0 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import json
import os
import os.path as osp
from collections import OrderedDict
from typing import Tuple
import rospy
from unitree_legged_msgs.msg import LowState
from unitree_legged_msgs.msg import LegsCmd
from unitree_legged_msgs.msg import Float32MultiArrayStamped
from std_msgs.msg import Float32MultiArray
from geometry_msgs.msg import Twist, Pose
from nav_msgs.msg import Odometry
from sensor_msgs.msg import Image
import ros_numpy
@torch.no_grad()
def resize2d(img, size):
return (F.adaptive_avg_pool2d(Variable(img), size)).data
@torch.jit.script
def quat_rotate_inverse(q, v):
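# Rotate each row of v by the inverse of the corresponding quaternion q (x, y, z, w layout):
# v' = (2*w^2 - 1)*v - 2*w*(q_vec x v) + 2*(q_vec . v)*q_vec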
shape = q.shape
q_w = q[:, -1]
q_vec = q[:, :3]
a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * \
torch.bmm(q_vec.view(shape[0], 1, 3), v.view(
shape[0], 3, 1)).squeeze(-1) * 2.0
return a - b + c
class UnitreeA1Real:
""" This is the handler that works for ROS 1 on unitree. """
def __init__(self,
robot_namespace= "a112138",
low_state_topic= "/low_state",
legs_cmd_topic= "/legs_cmd",
forward_depth_topic = "/camera/depth/image_rect_raw",
forward_depth_embedding_dims = None,
odom_topic= "/odom/filtered",
lin_vel_deadband= 0.2,
ang_vel_deadband= 0.1,
move_by_wireless_remote= False, # if True, command will not listen to move_cmd_subscriber, but wireless remote.
cfg= dict(),
extra_cfg= dict(),
model_device= torch.device("cpu"),
):
"""
NOTE:
* Must call start_ros() before using this class's get_obs() and send_action()
* Joint order of simulation and of real A1 protocol are different, see dof_names
* We store all joints values in the order of simulation in this class
Args:
forward_depth_embedding_dims: If a real number, the obs will not be built as a normal env.
The segment of obs will be substituted by the embedding of the forward depth image from the
ROS topic.
cfg: same config from a1_config but a dict object.
extra_cfg: some other configs that are hard to load from file.
"""
self.model_device = model_device
self.num_envs = 1
self.robot_namespace = robot_namespace
self.low_state_topic = low_state_topic
self.legs_cmd_topic = legs_cmd_topic
self.forward_depth_topic = forward_depth_topic
self.forward_depth_embedding_dims = forward_depth_embedding_dims
self.odom_topic = odom_topic
self.lin_vel_deadband = lin_vel_deadband
self.ang_vel_deadband = ang_vel_deadband
self.move_by_wireless_remote = move_by_wireless_remote
self.cfg = cfg
self.extra_cfg = dict(
torque_limits= torch.tensor([33.5] * 12, dtype= torch.float32, device= self.model_device, requires_grad= False), # Nm
# torque_limits= torch.tensor([1, 5, 5] * 4, dtype= torch.float32, device= self.model_device, requires_grad= False), # Nm
dof_map= [ # from isaacgym simulation joint order to URDF order
3, 4, 5,
0, 1, 2,
9, 10,11,
6, 7, 8,
], # real_joint_idx = dof_map[sim_joint_idx]
dof_names= [ # NOTE: order matters
"FL_hip_joint",
"FL_thigh_joint",
"FL_calf_joint",
"FR_hip_joint",
"FR_thigh_joint",
"FR_calf_joint",
"RL_hip_joint",
"RL_thigh_joint",
"RL_calf_joint",
"RR_hip_joint",
"RR_thigh_joint",
"RR_calf_joint",
],
# motor strength is multiplied directly to the action.
motor_strength= torch.ones(12, dtype= torch.float32, device= self.model_device, requires_grad= False),
); self.extra_cfg.update(extra_cfg)
if "torque_limits" in self.cfg["control"]:
self.extra_cfg["torque_limits"][:] = self.cfg["control"]["torque_limits"]
self.command_buf = torch.zeros((self.num_envs, 3,), device= self.model_device, dtype= torch.float32) # zeros for initialization
self.actions = torch.zeros((1, 12), device= model_device, dtype= torch.float32)
self.process_configs()
def start_ros(self):
# initialize several buffers so that the system works even without message update.
# self.low_state_buffer = LowState() # not initialized, let input message update it.
self.base_position_buffer = torch.zeros((self.num_envs, 3), device= self.model_device, requires_grad= False)
self.legs_cmd_publisher = rospy.Publisher(
self.robot_namespace + self.legs_cmd_topic,
LegsCmd,
queue_size= 1,
)
# self.debug_publisher = rospy.Publisher(
# "/DNNmodel_debug",
# Float32MultiArray,
# queue_size= 1,
# )
# NOTE: this launches the subscriber callback function
self.low_state_subscriber = rospy.Subscriber(
self.robot_namespace + self.low_state_topic,
LowState,
self.update_low_state,
queue_size= 1,
)
self.odom_subscriber = rospy.Subscriber(
self.robot_namespace + self.odom_topic,
Odometry,
self.update_base_pose,
queue_size= 1,
)
if not self.move_by_wireless_remote:
self.move_cmd_subscriber = rospy.Subscriber(
"/cmd_vel",
Twist,
self.update_move_cmd,
queue_size= 1,
)
if "forward_depth" in self.all_obs_components:
if not self.forward_depth_embedding_dims:
self.forward_depth_subscriber = rospy.Subscriber(
self.robot_namespace + self.forward_depth_topic,
Image,
self.update_forward_depth,
queue_size= 1,
)
else:
self.forward_depth_subscriber = rospy.Subscriber(
self.robot_namespace + self.forward_depth_topic,
Float32MultiArrayStamped,
self.update_forward_depth_embedding,
queue_size= 1,
)
self.pose_cmd_subscriber = rospy.Subscriber(
"/body_pose",
Pose,
self.dummy_handler,
queue_size= 1,
)
def wait_untill_ros_working(self):
rate = rospy.Rate(100)
while not hasattr(self, "low_state_buffer"):
rate.sleep()
rospy.loginfo("UnitreeA1Real.low_state_buffer acquired, stop waiting.")
def process_configs(self):
self.up_axis_idx = 2 # 2 for z, 1 for y -> adapt gravity accordingly
self.gravity_vec = torch.zeros((self.num_envs, 3), dtype= torch.float32)
self.gravity_vec[:, self.up_axis_idx] = -1
self.obs_scales = self.cfg["normalization"]["obs_scales"]
self.obs_scales["dof_pos"] = torch.tensor(self.obs_scales["dof_pos"], device= self.model_device, dtype= torch.float32)
if not isinstance(self.cfg["control"]["damping"]["joint"], (list, tuple)):
self.cfg["control"]["damping"]["joint"] = [self.cfg["control"]["damping"]["joint"]] * 12
if not isinstance(self.cfg["control"]["stiffness"]["joint"], (list, tuple)):
self.cfg["control"]["stiffness"]["joint"] = [self.cfg["control"]["stiffness"]["joint"]] * 12
self.d_gains = torch.tensor(self.cfg["control"]["damping"]["joint"], device= self.model_device, dtype= torch.float32)
self.p_gains = torch.tensor(self.cfg["control"]["stiffness"]["joint"], device= self.model_device, dtype= torch.float32)
self.default_dof_pos = torch.zeros(12, device= self.model_device, dtype= torch.float32)
for i in range(12):
name = self.extra_cfg["dof_names"][i]
default_joint_angle = self.cfg["init_state"]["default_joint_angles"][name]
self.default_dof_pos[i] = default_joint_angle
self.torque_limits = self.extra_cfg["torque_limits"]
self.commands_scale = torch.tensor([
self.obs_scales["lin_vel"],
self.obs_scales["lin_vel"],
self.obs_scales["lin_vel"],
], device= self.model_device, requires_grad= False)
self.obs_segments = self.get_obs_segment_from_components(self.cfg["env"]["obs_components"])
self.num_obs = self.get_num_obs_from_components(self.cfg["env"]["obs_components"])
components = self.cfg["env"].get("privileged_obs_components", None)
self.privileged_obs_segments = None if components is None else self.get_num_obs_from_components(components)
self.num_privileged_obs = None if components is None else self.get_num_obs_from_components(components)
self.all_obs_components = self.cfg["env"]["obs_components"] + (self.cfg["env"].get("privileged_obs_components", []) if components is not None else [])
# store config values to attributes to improve speed
self.clip_obs = self.cfg["normalization"]["clip_observations"]
self.control_type = self.cfg["control"]["control_type"]
self.action_scale = self.cfg["control"]["action_scale"]
self.motor_strength = self.extra_cfg["motor_strength"]
self.clip_actions = self.cfg["normalization"]["clip_actions"]
self.dof_map = self.extra_cfg["dof_map"]
# get ROS params for hardware configs
self.joint_limits_high = torch.tensor([
rospy.get_param(self.robot_namespace + "/joint_limits/{}_max".format(s), 0.) \
for s in ["hip", "thigh", "calf"] * 4
])
self.joint_limits_low = torch.tensor([
rospy.get_param(self.robot_namespace + "/joint_limits/{}_min".format(s), 0.) \
for s in ["hip", "thigh", "calf"] * 4
])
if "forward_depth" in self.all_obs_components:
resolution = self.cfg["sensor"]["forward_camera"].get(
"output_resolution",
self.cfg["sensor"]["forward_camera"]["resolution"],
)
if not self.forward_depth_embedding_dims:
self.forward_depth_buf = torch.zeros(
(self.num_envs, *resolution),
device= self.model_device,
dtype= torch.float32,
)
else:
self.forward_depth_embedding_buf = torch.zeros(
(1, self.forward_depth_embedding_dims),
device= self.model_device,
dtype= torch.float32,
)
def _init_height_points(self):
""" Returns points at which the height measurments are sampled (in base frame)
Returns:
[torch.Tensor]: Tensor of shape (num_envs, self.num_height_points, 3)
"""
return None
def _get_heights(self):
""" TODO: get estimated terrain heights around the robot base """
# currently return a zero tensor with valid size
return torch.zeros(self.num_envs, 187, device= self.model_device, requires_grad= False)
def clip_by_torque_limit(self, actions_scaled):
""" Different from simulation, we reverse the process and clip the actions directly,
so that the PD controller runs on the robot rather than in our script.
"""
control_type = self.cfg["control"]["control_type"]
if control_type == "P":
p_limits_low = (-self.torque_limits) + self.d_gains*self.dof_vel
p_limits_high = (self.torque_limits) + self.d_gains*self.dof_vel
actions_low = (p_limits_low/self.p_gains) - self.default_dof_pos + self.dof_pos
actions_high = (p_limits_high/self.p_gains) - self.default_dof_pos + self.dof_pos
else:
raise NotImplementedError
return torch.clip(actions_scaled, actions_low, actions_high)
""" Get obs components and cat to a single obs input """
def _get_proprioception_obs(self):
# base_ang_vel = quat_rotate_inverse(
# torch.tensor(self.low_state_buffer.imu.quaternion).unsqueeze(0),
# torch.tensor(self.low_state_buffer.imu.gyroscope).unsqueeze(0),
# ).to(self.model_device)
# NOTE: Different from the isaacgym.
# The angular velocity is already in base frame, no need to rotate
base_ang_vel = torch.tensor(self.low_state_buffer.imu.gyroscope, device= self.model_device).unsqueeze(0)
projected_gravity = quat_rotate_inverse(
torch.tensor(self.low_state_buffer.imu.quaternion).unsqueeze(0),
self.gravity_vec,
).to(self.model_device)
self.dof_pos = dof_pos = torch.tensor([
self.low_state_buffer.motorState[self.dof_map[i]].q for i in range(12)
], dtype= torch.float32, device= self.model_device).unsqueeze(0)
self.dof_vel = dof_vel = torch.tensor([
self.low_state_buffer.motorState[self.dof_map[i]].dq for i in range(12)
], dtype= torch.float32, device= self.model_device).unsqueeze(0)
return torch.cat([
torch.zeros((1, 3), device= self.model_device), # no linear velocity
base_ang_vel * self.obs_scales["ang_vel"],
projected_gravity,
self.command_buf * self.commands_scale,
(dof_pos - self.default_dof_pos) * self.obs_scales["dof_pos"],
dof_vel * self.obs_scales["dof_vel"],
self.actions
], dim= -1)
def _get_forward_depth_obs(self):
if not self.forward_depth_embedding_dims:
return self.forward_depth_buf.flatten(start_dim= 1)
else:
return self.forward_depth_embedding_buf.flatten(start_dim= 1)
def compute_observation(self):
""" use the updated low_state_buffer to compute observation vector """
assert hasattr(self, "legs_cmd_publisher"), "start_ros() not called, ROS handlers are not initialized!"
obs_segments = self.obs_segments
obs = []
for k, v in obs_segments.items():
obs.append(
getattr(self, "_get_" + k + "_obs")() * \
self.obs_scales.get(k, 1.)
)
obs = torch.cat(obs, dim= 1)
self.obs_buf = obs
""" The methods combined with outer model forms the step function
NOTE: the outer user handles the loop frequency.
"""
def send_action(self, actions):
""" The function that send commands to the real robot.
"""
self.actions = torch.clip(actions, -self.clip_actions, self.clip_actions)
actions = self.actions * self.motor_strength
robot_coordinates_action = self.clip_by_torque_limit(actions * self.action_scale) + self.default_dof_pos.unsqueeze(0)
# robot_coordinates_action = self.actions * self.action_scale + self.default_dof_pos.unsqueeze(0)
# debugging and logging
# transfered_action = torch.zeros_like(self.actions[0])
# for i in range(12):
# transfered_action[self.dof_map[i]] = self.actions[0, i] + self.default_dof_pos[i]
# self.debug_publisher.publish(Float32MultiArray(data=
# transfered_action\
# .cpu().numpy().astype(np.float32).tolist()
# ))
# restrict the target action delta in order to avoid robot shutdown (maybe there is another solution)
# robot_coordinates_action = torch.clip(
# robot_coordinates_action,
# self.dof_pos - 0.3,
# self.dof_pos + 0.3,
# )
# wrap the message and publish
self.publish_legs_cmd(robot_coordinates_action)
def publish_legs_cmd(self, robot_coordinates_action, kp= None, kd= None):
""" publish the joint position directly to the robot. NOTE: The joint order from input should
be in simulation order. The value should be absolute rather than relative to dof_pos.
"""
robot_coordinates_action = torch.clip(
robot_coordinates_action.cpu(),
self.joint_limits_low,
self.joint_limits_high,
)
legs_cmd = LegsCmd()
for sim_joint_idx in range(12):
real_joint_idx = self.dof_map[sim_joint_idx]
legs_cmd.cmd[real_joint_idx].mode = 10
legs_cmd.cmd[real_joint_idx].q = robot_coordinates_action[0, sim_joint_idx] if self.control_type == "P" else rospy.get_param(self.robot_namespace + "/PosStopF", (2.146e+9))
legs_cmd.cmd[real_joint_idx].dq = 0.
legs_cmd.cmd[real_joint_idx].tau = 0.
legs_cmd.cmd[real_joint_idx].Kp = self.p_gains[sim_joint_idx] if kp is None else kp
legs_cmd.cmd[real_joint_idx].Kd = self.d_gains[sim_joint_idx] if kd is None else kd
self.legs_cmd_publisher.publish(legs_cmd)
def get_obs(self):
""" The function that refreshes the buffer and return the observation vector.
"""
self.compute_observation()
self.obs_buf = torch.clip(self.obs_buf, -self.clip_obs, self.clip_obs)
return self.obs_buf.to(self.model_device)
""" Copied from legged_robot_field. Please check whether these are consistent. """
def get_obs_segment_from_components(self, components):
segments = OrderedDict()
if "proprioception" in components:
segments["proprioception"] = (48,)
if "height_measurements" in components:
segments["height_measurements"] = (187,)
if "forward_depth" in components:
resolution = self.cfg["sensor"]["forward_camera"].get(
"output_resolution",
self.cfg["sensor"]["forward_camera"]["resolution"],
)
segments["forward_depth"] = (1, *resolution)
# The following components are only for rebuilding the non-actor module.
# DO NOT use these in actor network and check consistency with simulator implementation.
if "base_pose" in components:
segments["base_pose"] = (6,) # xyz + rpy
if "robot_config" in components:
segments["robot_config"] = (1 + 3 + 1 + 12,)
if "engaging_block" in components:
# This could be wrong, please check the implementation of BarrierTrack
segments["engaging_block"] = (1 + (4 + 1) + 2,)
if "sidewall_distance" in components:
segments["sidewall_distance"] = (2,)
return segments
def get_num_obs_from_components(self, components):
obs_segments = self.get_obs_segment_from_components(components)
num_obs = 0
for k, v in obs_segments.items():
num_obs += np.prod(v)
return num_obs
""" ROS callbacks and handlers that update the buffer """
def update_low_state(self, ros_msg):
self.low_state_buffer = ros_msg
if self.move_by_wireless_remote:
self.command_buf[0, 0] = self.low_state_buffer.wirelessRemote.ly
self.command_buf[0, 1] = -self.low_state_buffer.wirelessRemote.lx # right-moving stick is positive
self.command_buf[0, 2] = -self.low_state_buffer.wirelessRemote.rx # right-moving stick is positive
# set the command to zero if it is too small
if np.linalg.norm(self.command_buf[0, :2]) < self.lin_vel_deadband:
self.command_buf[0, :2] = 0.
if np.abs(self.command_buf[0, 2]) < self.ang_vel_deadband:
self.command_buf[0, 2] = 0.
def update_base_pose(self, ros_msg):
""" update robot odometry for position """
self.base_position_buffer[0, 0] = ros_msg.pose.pose.position.x
self.base_position_buffer[0, 1] = ros_msg.pose.pose.position.y
self.base_position_buffer[0, 2] = ros_msg.pose.pose.position.z
def update_move_cmd(self, ros_msg):
self.command_buf[0, 0] = ros_msg.linear.x
self.command_buf[0, 1] = ros_msg.linear.y
self.command_buf[0, 2] = ros_msg.angular.z
def update_forward_depth(self, ros_msg):
# TODO not checked.
self.forward_depth_header = ros_msg.header
buf = ros_numpy.numpify(ros_msg)
self.forward_depth_buf = resize2d(
torch.from_numpy(buf.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(self.model_device),
self.forward_depth_buf.shape[-2:],
)
def update_forward_depth_embedding(self, ros_msg):
self.forward_depth_embedding_stamp = ros_msg.header.stamp
self.forward_depth_embedding_buf[:] = torch.tensor(ros_msg.data).unsqueeze(0) # (1, d)
def dummy_handler(self, ros_msg):
""" To meet the need of teleop-legged-robots requirements """
pass

View File

@ -1,377 +0,0 @@
#!/home/unitree/agility_ziwenz_venv/bin/python
import os
import os.path as osp
import json
import numpy as np
import torch
from collections import OrderedDict
from functools import partial
from typing import Tuple
import rospy
from std_msgs.msg import Float32MultiArray
from sensor_msgs.msg import Image
import ros_numpy
from a1_real import UnitreeA1Real, resize2d
from rsl_rl import modules
from rsl_rl.utils.utils import get_obs_slice
@torch.no_grad()
def handle_forward_depth(ros_msg, model, publisher, output_resolution, device):
""" The callback function to handle the forward depth and send the embedding through ROS topic """
buf = ros_numpy.numpify(ros_msg).astype(np.float32)
forward_depth_buf = resize2d(
torch.from_numpy(buf).unsqueeze(0).unsqueeze(0).to(device),
output_resolution,
)
embedding = model(forward_depth_buf)
ros_data = embedding.reshape(-1).cpu().numpy().astype(np.float32)
publisher.publish(Float32MultiArray(data= ros_data.tolist()))
class StandOnlyModel(torch.nn.Module):
def __init__(self, action_scale, dof_pos_scale, tolerance= 0.1, delta= 0.1):
rospy.loginfo("Using stand only model, please make sure the proprioception is 48 dim.")
rospy.loginfo("Using stand only model, -36 to -24 must be joint position.")
super().__init__()
if isinstance(action_scale, (tuple, list)):
self.register_buffer("action_scale", torch.tensor(action_scale))
else:
self.action_scale = action_scale
if isinstance(dof_pos_scale, (tuple, list)):
self.register_buffer("dof_pos_scale", torch.tensor(dof_pos_scale))
else:
self.dof_pos_scale = dof_pos_scale
self.tolerance = tolerance
self.delta = delta
def forward(self, obs):
joint_positions = obs[..., -36:-24] / self.dof_pos_scale
diff_large_mask = torch.abs(joint_positions) > self.tolerance
target_positions = torch.zeros_like(joint_positions)
target_positions[diff_large_mask] = joint_positions[diff_large_mask] - self.delta * torch.sign(joint_positions[diff_large_mask])
return torch.clip(
target_positions / self.action_scale,
-1.0, 1.0,
)
def reset(self, *args, **kwargs):
pass
def load_walk_policy(env, model_dir):
""" Load the walk policy from the model directory """
if model_dir == None:
model = StandOnlyModel(
action_scale= env.action_scale,
dof_pos_scale= env.obs_scales["dof_pos"],
)
policy = torch.jit.script(model)
else:
with open(osp.join(model_dir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
obs_components = config_dict["env"]["obs_components"]
privileged_obs_components = config_dict["env"].get("privileged_obs_components", obs_components)
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
num_actor_obs= env.get_num_obs_from_components(obs_components),
num_critic_obs= env.get_num_obs_from_components(privileged_obs_components),
num_actions= 12,
**config_dict["policy"],
)
model_names = [i for i in os.listdir(model_dir) if i.startswith("model_")]
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
state_dict = torch.load(osp.join(model_dir, model_names[-1]), map_location= "cpu")
model.load_state_dict(state_dict["model_state_dict"])
model_action_scale = torch.tensor(config_dict["control"]["action_scale"]) if isinstance(config_dict["control"]["action_scale"], (tuple, list)) else torch.tensor([config_dict["control"]["action_scale"]])[0]
if not (torch.is_tensor(model_action_scale) and (model_action_scale == env.action_scale).all()):
action_rescale_ratio = model_action_scale / env.action_scale
print("walk_policy action scaling:", action_rescale_ratio.tolist())
else:
action_rescale_ratio = 1.0
memory_module = model.memory_a
actor_mlp = model.actor
@torch.jit.script
def policy_run(obs):
recurrent_embedding = memory_module(obs)
actions = actor_mlp(recurrent_embedding.squeeze(0))
return actions
if (torch.is_tensor(action_rescale_ratio) and (action_rescale_ratio == 1.).all()) \
or (not torch.is_tensor(action_rescale_ratio) and action_rescale_ratio == 1.):
policy = policy_run
else:
policy = lambda x: policy_run(x) * action_rescale_ratio
return policy, model
def standup_procedure(env, ros_rate, angle_tolerance= 0.05, kp= None, kd= None, device= "cpu"):
rospy.loginfo("Robot standing up, please wait ...")
target_pos = torch.zeros((1, 12), device= device, dtype= torch.float32)
while not rospy.is_shutdown():
dof_pos = [env.low_state_buffer.motorState[env.dof_map[i]].q for i in range(12)]
diff = [env.default_dof_pos[i].item() - dof_pos[i] for i in range(12)]
direction = [1 if i > 0 else -1 for i in diff]
if all([abs(i) < angle_tolerance for i in diff]):
break
print("max joint error (rad):", max([abs(i) for i in diff]), end= "\r")
for i in range(12):
target_pos[0, i] = dof_pos[i] + direction[i] * angle_tolerance if abs(diff[i]) > angle_tolerance else env.default_dof_pos[i]
env.publish_legs_cmd(target_pos,
kp= kp,
kd= kd,
)
ros_rate.sleep()
rospy.loginfo("Robot stood up! press R1 on the remote control to continue ...")
while not rospy.is_shutdown():
if env.low_state_buffer.wirelessRemote.btn.components.R1:
break
if env.low_state_buffer.wirelessRemote.btn.components.L2 or env.low_state_buffer.wirelessRemote.btn.components.R2:
env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= 20, kd= 0.5)
rospy.signal_shutdown("Controller send stop signal, exiting")
exit(0)
env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= kp, kd= kd)
ros_rate.sleep()
rospy.loginfo("Robot standing up procedure finished!")
class SkilledA1Real(UnitreeA1Real):
""" Some additional methods to help the execution of skill policy """
def __init__(self, *args,
skill_mode_threhold= 0.1,
skill_vel_range= [0.0, 1.0],
**kwargs,
):
self.skill_mode_threhold = skill_mode_threhold
self.skill_vel_range = skill_vel_range
super().__init__(*args, **kwargs)
def is_skill_mode(self):
if self.move_by_wireless_remote:
return self.low_state_buffer.wirelessRemote.ry > self.skill_mode_threhold
else:
# Not implemented yet
return False
def update_low_state(self, ros_msg):
self.low_state_buffer = ros_msg
if self.move_by_wireless_remote and ros_msg.wirelessRemote.ry > self.skill_mode_threhold:
skill_vel = (self.low_state_buffer.wirelessRemote.ry - self.skill_mode_threhold) / (1.0 - self.skill_mode_threhold)
skill_vel *= self.skill_vel_range[1] - self.skill_vel_range[0]
skill_vel += self.skill_vel_range[0]
self.command_buf[0, 0] = skill_vel
self.command_buf[0, 1] = 0.
self.command_buf[0, 2] = 0.
return
return super().update_low_state(ros_msg)
def main(args):
log_level = rospy.DEBUG if args.debug else rospy.INFO
rospy.init_node("a1_legged_gym_" + args.mode, anonymous= True, log_level= log_level)
with open(osp.join(args.logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
duration = config_dict["sim"]["dt"] * config_dict["control"]["decimation"] # in sec
# config_dict["control"]["stiffness"]["joint"] -= 2.5 # kp
model_device = torch.device("cpu") if args.mode == "upboard" else torch.device("cuda")
unitree_real_env = SkilledA1Real(
robot_namespace= args.namespace,
cfg= config_dict,
forward_depth_topic= "/visual_embedding" if args.mode == "upboard" else "/camera/depth/image_rect_raw",
forward_depth_embedding_dims= config_dict["policy"]["visual_latent_size"] if args.mode == "upboard" else None,
move_by_wireless_remote= True,
skill_vel_range= config_dict["commands"]["ranges"]["lin_vel_x"],
model_device= model_device,
# extra_cfg= dict(
# motor_strength= torch.tensor([
# 1., 1./0.9, 1./0.9,
# 1., 1./0.9, 1./0.9,
# 1., 1., 1.,
# 1., 1., 1.,
# ], dtype= torch.float32, device= model_device, requires_grad= False),
# ),
)
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
num_actor_obs= unitree_real_env.num_obs,
num_critic_obs= unitree_real_env.num_privileged_obs,
num_actions= 12,
obs_segments= unitree_real_env.obs_segments,
privileged_obs_segments= unitree_real_env.privileged_obs_segments,
**config_dict["policy"],
)
config_dict["terrain"]["measure_heights"] = False
# load the model with the latest checkpoint
model_names = [i for i in os.listdir(args.logdir) if i.startswith("model_")]
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
model.load_state_dict(state_dict["model_state_dict"])
model.to(model_device)
model.eval()
rospy.loginfo("duration: {}, motor Kp: {}, motor Kd: {}".format(
duration,
config_dict["control"]["stiffness"]["joint"],
config_dict["control"]["damping"]["joint"],
))
rospy.loginfo("[Env] torque limit: {:.1f}".format(unitree_real_env.torque_limits.mean().item()))
rospy.loginfo("[Env] action scale: {:.1f}".format(unitree_real_env.action_scale))
rospy.loginfo("[Env] motor strength: {}".format(unitree_real_env.motor_strength))
if args.mode == "jetson":
embeding_publisher = rospy.Publisher(
args.namespace + "/visual_embedding",
Float32MultiArray,
queue_size= 1,
)
# extract and build the torch ScriptFunction
visual_encoder = model.visual_encoder
visual_encoder = torch.jit.script(visual_encoder)
forward_depth_subscriber = rospy.Subscriber(
args.namespace + "/camera/depth/image_rect_raw",
Image,
partial(handle_forward_depth,
model= visual_encoder,
publisher= embeding_publisher,
output_resolution= config_dict["sensor"]["forward_camera"].get(
"output_resolution",
config_dict["sensor"]["forward_camera"]["resolution"],
),
device= model_device,
),
queue_size= 1,
)
rospy.spin()
elif args.mode == "upboard":
# extract and build the torch ScriptFunction
memory_module = model.memory_a
actor_mlp = model.actor
@torch.jit.script
def policy(obs):
recurrent_embedding = memory_module(obs)
actions = actor_mlp(recurrent_embedding.squeeze(0))
return actions
walk_policy, walk_model = load_walk_policy(unitree_real_env, args.walkdir)
using_walk_policy = True # switch between skill policy and walk policy
unitree_real_env.start_ros()
unitree_real_env.wait_untill_ros_working()
rate = rospy.Rate(1 / duration)
with torch.no_grad():
if not args.debug:
standup_procedure(unitree_real_env, rate,
angle_tolerance= 0.1,
kp= 50,
kd= 1.,
device= model_device,
)
while not rospy.is_shutdown():
# inference_start_time = rospy.get_time()
# check remote controller and decide which policy to use
if unitree_real_env.is_skill_mode():
if using_walk_policy:
rospy.loginfo_throttle(0.1, "switch to skill policy")
using_walk_policy = False
model.reset()
else:
if not using_walk_policy:
rospy.loginfo_throttle(0.1, "switch to walk policy")
using_walk_policy = True
walk_model.reset()
if not using_walk_policy:
obs = unitree_real_env.get_obs()
actions = policy(obs)
else:
walk_obs = unitree_real_env._get_proprioception_obs()
actions = walk_policy(walk_obs)
unitree_real_env.send_action(actions)
# unitree_real_env.send_action(torch.zeros((1, 12)))
# inference_duration = rospy.get_time() - inference_start_time
# rospy.loginfo("inference duration: {:.3f}".format(inference_duration))
# rospy.loginfo("visual_latency: %f", rospy.get_time() - unitree_real_env.forward_depth_embedding_stamp.to_sec())
# motor_temperatures = [motor_state.temperature for motor_state in unitree_real_env.low_state_buffer.motorState]
# rospy.loginfo_throttle(10, " ".join(["motor_temperatures:"] + ["{:d},".format(t) for t in motor_temperatures[:12]]))
rate.sleep()
if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.down:
rospy.loginfo_throttle(0.1, "model reset")
model.reset()
walk_model.reset()
if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.L2 or unitree_real_env.low_state_buffer.wirelessRemote.btn.components.R2:
unitree_real_env.publish_legs_cmd(unitree_real_env.default_dof_pos.unsqueeze(0), kp= 20, kd= 0.5)
rospy.signal_shutdown("Controller send stop signal, exiting")
elif args.mode == "full":
# extract and build the torch ScriptFunction
visual_obs_slice = get_obs_slice(unitree_real_env.obs_segments, "forward_depth")
visual_encoder = model.visual_encoder
memory_module = model.memory_a
actor_mlp = model.actor
@torch.jit.script
def policy(observations: torch.Tensor, obs_start: int, obs_stop: int, obs_shape: Tuple[int, int, int]):
visual_latent = visual_encoder(
observations[..., obs_start:obs_stop].reshape(-1, *obs_shape)
).reshape(1, -1)
obs = torch.cat([
observations[..., :obs_start],
visual_latent,
observations[..., obs_stop:],
], dim= -1)
recurrent_embedding = memory_module(obs)
actions = actor_mlp(recurrent_embedding.squeeze(0))
return actions
unitree_real_env.start_ros()
unitree_real_env.wait_untill_ros_working()
rate = rospy.Rate(1 / duration)
with torch.no_grad():
while not rospy.is_shutdown():
# inference_start_time = rospy.get_time()
obs = unitree_real_env.get_obs()
actions = policy(obs,
obs_start= visual_obs_slice[0].start.item(),
obs_stop= visual_obs_slice[0].stop.item(),
obs_shape= visual_obs_slice[1],
)
unitree_real_env.send_action(actions)
# inference_duration = rospy.get_time() - inference_start_time
motor_temperatures = [motor_state.temperature for motor_state in unitree_real_env.low_state_buffer.motorState]
rospy.loginfo_throttle(10, " ".join(["motor_temperatures:"] + ["{:d},".format(t) for t in motor_temperatures[:12]]))
rate.sleep()
if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.L2 or unitree_real_env.low_state_buffer.wirelessRemote.btn.components.R2:
unitree_real_env.publish_legs_cmd(unitree_real_env.default_dof_pos.unsqueeze(0), kp= 20, kd= 0.5)
rospy.signal_shutdown("Controller send stop signal, exiting")
else:
rospy.logfatal("Unknown mode, exiting")
if __name__ == "__main__":
""" The script to run the A1 script in ROS.
It's designed as a main function and not designed to be a scalable code.
"""
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--namespace",
type= str,
default= "/a112138",
)
parser.add_argument("--logdir",
type= str,
help= "The log directory of the trained model",
)
parser.add_argument("--walkdir",
type= str,
help= "The log directory of the walking model, not for the skills.",
default= None,
)
parser.add_argument("--mode",
type= str,
help= "The mode to determine which computer to run on.",
choices= ["jetson", "upboard", "full"],
)
parser.add_argument("--debug",
action= "store_true",
)
args = parser.parse_args()
main(args)

View File

@ -1,15 +1,15 @@
import numpy as np
from os import path as osp
from legged_gym.envs.a1.a1_field_config import A1FieldCfg, A1FieldCfgPPO
from legged_gym.utils.helpers import merge_dict
class A1TiltCfg( A1FieldCfg ):
#### uncomment this to train non-virtual terrain
# class sensor( A1FieldCfg.sensor ):
# class proprioception( A1FieldCfg.sensor.proprioception ):
# delay_action_obs = True
# latency_range = [0.04-0.0025, 0.04+0.0075]
#### uncomment the above to train non-virtual terrain
### uncomment this to train non-virtual terrain
class sensor( A1FieldCfg.sensor ):
class proprioception( A1FieldCfg.sensor.proprioception ):
latency_range = [0.04-0.0025, 0.04+0.0075]
### uncomment the above to train non-virtual terrain
class terrain( A1FieldCfg.terrain ):
max_init_terrain_level = 2
@@ -41,6 +41,9 @@ class A1TiltCfg( A1FieldCfg ):
lin_vel_y = [0.0, 0.0]
ang_vel_yaw = [0., 0.]
class asset( A1FieldCfg.asset ):
penalize_contacts_on = ["base"]
class termination( A1FieldCfg.termination ):
# additional factors that determines whether to terminates the episode
termination_terms = [
@@ -52,8 +55,12 @@ class A1TiltCfg( A1FieldCfg ):
]
class domain_rand( A1FieldCfg.domain_rand ):
# push_robots = True # use for virtual training
push_robots = False # use for non-virtual training
push_robots = True # use for virtual training
# push_robots = False # use for non-virtual training
init_base_rot_range = dict(
roll= [-0.1, 0.1],
pitch= [-0.1, 0.1],
)
class rewards( A1FieldCfg.rewards ):
class scales:
@@ -61,10 +68,12 @@ class A1TiltCfg( A1FieldCfg ):
world_vel_l2norm = -1.
legs_energy_substeps = -1e-5
alive = 2.
penetrate_depth = -3e-3
penetrate_volume = -3e-3
exceed_dof_pos_limits = -1e-1
exceed_torque_limits_i = -2e-1
penetrate_depth = -2e-3
penetrate_volume = -2e-3
exceed_dof_pos_limits = -8e-1
exceed_torque_limits_l1norm = -8e-1
tilt_cond = 8e-3
collision = -0.1
class curriculum( A1FieldCfg.curriculum ):
penetrate_volume_threshold_harder = 4000
@@ -73,6 +82,7 @@ class A1TiltCfg( A1FieldCfg ):
penetrate_depth_threshold_easier = 300
logs_root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))), "logs")
class A1TiltCfgPPO( A1FieldCfgPPO ):
class algorithm( A1FieldCfgPPO.algorithm ):
entropy_coef = 0.0
@@ -81,24 +91,38 @@ class A1TiltCfgPPO( A1FieldCfgPPO ):
class runner( A1FieldCfgPPO.runner ):
policy_class_name = "ActorCriticRecurrent"
experiment_name = "field_a1"
run_name = "".join(["Skill",
("Multi" if len(A1TiltCfg.terrain.BarrierTrack_kwargs["options"]) > 1 else (A1TiltCfg.terrain.BarrierTrack_kwargs["options"][0] if A1TiltCfg.terrain.BarrierTrack_kwargs["options"] else "PlaneWalking")),
("_propDelay{:.2f}-{:.2f}".format(
A1TiltCfg.sensor.proprioception.latency_range[0],
A1TiltCfg.sensor.proprioception.latency_range[1],
) if A1TiltCfg.sensor.proprioception.delay_action_obs else ""
),
("_pPenV" + np.format_float_scientific(-A1TiltCfg.rewards.scales.penetrate_volume, trim= "-", exp_digits= 1) if getattr(A1TiltCfg.rewards.scales, "penetrate_volume", 0.) < 0. else ""),
("_pPenD" + np.format_float_scientific(-A1TiltCfg.rewards.scales.penetrate_depth, trim= "-", exp_digits= 1) if getattr(A1TiltCfg.rewards.scales, "penetrate_depth", 0.) < 0. else ""),
("_noPush" if not A1TiltCfg.domain_rand.push_robots else ""),
("_tiltMax{:.2f}".format(A1TiltCfg.terrain.BarrierTrack_kwargs["tilt"]["width"][1])),
("_virtual" if A1TiltCfg.terrain.BarrierTrack_kwargs["virtual_terrain"] else ""),
])
resume = True
load_run = "{Your traind walking model directory}"
load_run = "{Your virtual terrain model directory}"
load_run = "Aug17_11-13-14_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
load_run = "Aug23_22-03-41_Skilltilt_pPenV3e-3_pPenD3e-3_tiltMax0.40_virtual"
# load_run = osp.join(logs_root, "field_a1_oracle/Aug08_05-22-52_Skills_tilt_pEnergySubsteps1e-5_rAlive1_pPenV5e-3_pPenD5e-3_pPosY0.50_pYaw0.50_rTilt7e-1_pTorqueExceedIndicate1e-1_virtualTerrain_propDelay0.04-0.05_push")
# load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Sep27_13-59-27_Skills_tilt_propDelay0.04-0.05_pEnergySubsteps1e-5_pPenD2e-3_pDofLimit8e-1_rTilt8e-03_pCollision0.1_noPush_kp40_kd0.5_tiltMax0.40fromAug08_05-22-52")
load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Oct11_12-24-22_Skills_tilt_propDelay0.04-0.05_pEnergySubsteps1e-5_pPenD2e-3_pDofLimit8e-1_rTilt8e-03_pCollision0.1_PushRobot_kp40_kd0.5_tiltMax0.40fromSep27_13-59-27")
run_name = "".join(["Skills_",
("Multi" if len(A1TiltCfg.terrain.BarrierTrack_kwargs["options"]) > 1 else (A1TiltCfg.terrain.BarrierTrack_kwargs["options"][0] if A1TiltCfg.terrain.BarrierTrack_kwargs["options"] else "PlaneWalking")),
("_comXRange{:.1f}-{:.1f}".format(A1TiltCfg.domain_rand.com_range.x[0], A1TiltCfg.domain_rand.com_range.x[1])),
("_noLinVel" if not A1TiltCfg.env.use_lin_vel else ""),
("_propDelay{:.2f}-{:.2f}".format(
A1TiltCfg.sensor.proprioception.latency_range[0],
A1TiltCfg.sensor.proprioception.latency_range[1],
) if A1TiltCfg.sensor.proprioception.delay_action_obs else ""
),
# ("_pEnergySubsteps" + np.format_float_scientific(-A1TiltCfg.rewards.scales.legs_energy_substeps, trim= "-", exp_digits= 1) if getattr(A1TiltCfg.rewards.scales, "legs_energy_substeps", 0.) < 0. else ""),
# ("_pPenV" + np.format_float_scientific(-A1TiltCfg.rewards.scales.penetrate_volume, trim= "-", exp_digits= 1) if getattr(A1TiltCfg.rewards.scales, "penetrate_volume", 0.) < 0. else ""),
("_pPenD" + np.format_float_scientific(-A1TiltCfg.rewards.scales.penetrate_depth, trim= "-", exp_digits= 1) if getattr(A1TiltCfg.rewards.scales, "penetrate_depth", 0.) < 0. else ""),
("_pDofLimit" + np.format_float_scientific(-A1TiltCfg.rewards.scales.exceed_dof_pos_limits, trim= "-", exp_digits= 1) if getattr(A1TiltCfg.rewards.scales, "exceed_dof_pos_limits", 0.) < 0. else ""),
# ("_rTilt{:.0e}".format(A1TiltCfg.rewards.scales.tilt_cond) if getattr(A1TiltCfg.rewards.scales, "tilt_cond", 0.) > 0. else ""),
# ("_pCollision{:.1f}".format(-A1TiltCfg.rewards.scales.collision) if getattr(A1TiltCfg.rewards.scales, "collision", 0.) != 0. else ""),
("_noPush" if not A1TiltCfg.domain_rand.push_robots else "_PushRobot"),
# ("_kp{:d}".format(int(A1TiltCfg.control.stiffness["joint"])) if A1TiltCfg.control.stiffness["joint"] != 50 else ""),
# ("_kd{:.1f}".format(A1TiltCfg.control.damping["joint"]) if A1TiltCfg.control.damping["joint"] != 0. else ""),
("_noTanh"),
("_tiltMax{:.2f}".format(A1TiltCfg.terrain.BarrierTrack_kwargs["tilt"]["width"][1])),
("_virtual" if A1TiltCfg.terrain.BarrierTrack_kwargs["virtual_terrain"] else ""),
("_noResume" if not resume else "from" + "_".join(load_run.split("/")[-1].split("_")[:2])),
])
max_iterations = 20000
save_interval = 500

View File

@@ -1,309 +0,0 @@
import os
import os.path as osp
import numpy as np
import torch
import json
from functools import partial
from collections import OrderedDict
from a1_real import UnitreeA1Real, resize2d
from rsl_rl import modules
import rospy
from unitree_legged_msgs.msg import Float32MultiArrayStamped
from sensor_msgs.msg import Image
import ros_numpy
import pyrealsense2 as rs
def get_encoder_script(logdir):
with open(osp.join(logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
model_device = torch.device("cuda")
unitree_real_env = UnitreeA1Real(
robot_namespace= "DummyUnitreeA1Real",
cfg= config_dict,
forward_depth_topic= "", # this env only computes parameters to build the model
forward_depth_embedding_dims= None,
model_device= model_device,
)
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
num_actor_obs= unitree_real_env.num_obs,
num_critic_obs= unitree_real_env.num_privileged_obs,
num_actions= 12,
obs_segments= unitree_real_env.obs_segments,
privileged_obs_segments= unitree_real_env.privileged_obs_segments,
**config_dict["policy"],
)
model_names = [i for i in os.listdir(logdir) if i.startswith("model_")]
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
model.load_state_dict(state_dict["model_state_dict"])
model.to(model_device)
model.eval()
visual_encoder = model.visual_encoder
script = torch.jit.script(visual_encoder)
return script, model_device
def get_input_filter(args):
""" This is the filter different from the simulator, but try to close the gap. """
with open(osp.join(args.logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
image_resolution = config_dict["sensor"]["forward_camera"].get(
"output_resolution",
config_dict["sensor"]["forward_camera"]["resolution"],
)
depth_range = config_dict["sensor"]["forward_camera"].get(
"depth_range",
[0.0, 3.0],
)
depth_range = (depth_range[0] * 1000, depth_range[1] * 1000) # [m] -> [mm]
crop_top, crop_bottom, crop_left, crop_right = args.crop_top, args.crop_bottom, args.crop_left, args.crop_right
crop_far = args.crop_far * 1000
def input_filter(depth_image: torch.Tensor,
crop_top: int,
crop_bottom: int,
crop_left: int,
crop_right: int,
crop_far: float,
depth_min: int,
depth_max: int,
output_height: int,
output_width: int,
):
""" depth_image must have shape [1, 1, H, W] """
depth_image = depth_image[:, :,
crop_top: -crop_bottom-1,
crop_left: -crop_right-1,
]
depth_image[depth_image > crop_far] = depth_max
depth_image = torch.clip(
depth_image,
depth_min,
depth_max,
) / (depth_max - depth_min)
depth_image = resize2d(depth_image, (output_height, output_width))
return depth_image
# input_filter = torch.jit.script(input_filter)
return partial(input_filter,
crop_top= crop_top,
crop_bottom= crop_bottom,
crop_left= crop_left,
crop_right= crop_right,
crop_far= crop_far,
depth_min= depth_range[0],
depth_max= depth_range[1],
output_height= image_resolution[0],
output_width= image_resolution[1],
), depth_range
def get_started_pipeline(
height= 480,
width= 640,
fps= 30,
enable_rgb= False,
):
# By default, rgb is not used.
pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.depth, width, height, rs.format.z16, fps)
if enable_rgb:
config.enable_stream(rs.stream.color, width, height, rs.format.rgb8, fps)
profile = pipeline.start(config)
# build the sensor filter
hole_filling_filter = rs.hole_filling_filter(2)
spatial_filter = rs.spatial_filter()
spatial_filter.set_option(rs.option.filter_magnitude, 5)
spatial_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
spatial_filter.set_option(rs.option.filter_smooth_delta, 1)
spatial_filter.set_option(rs.option.holes_fill, 4)
temporal_filter = rs.temporal_filter()
temporal_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
temporal_filter.set_option(rs.option.filter_smooth_delta, 1)
# decimation_filter = rs.decimation_filter()
# decimation_filter.set_option(rs.option.filter_magnitude, 2)
def filter_func(frame):
frame = hole_filling_filter.process(frame)
frame = spatial_filter.process(frame)
frame = temporal_filter.process(frame)
# frame = decimation_filter.process(frame)
return frame
return pipeline, filter_func
def main(args):
rospy.init_node("a1_legged_gym_jetson", anonymous= True)
input_filter, depth_range = get_input_filter(args)
model_script, model_device = get_encoder_script(args.logdir)
with open(osp.join(args.logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
if config_dict.get("sensor", dict()).get("forward_camera", dict()).get("refresh_duration", None) is not None:
refresh_duration = config_dict["sensor"]["forward_camera"]["refresh_duration"]
ros_rate = rospy.Rate(1.0 / refresh_duration)
rospy.loginfo("Using refresh duration {}s".format(refresh_duration))
else:
ros_rate = rospy.Rate(args.fps)
rs_pipeline, rs_filters = get_started_pipeline(
height= args.height,
width= args.width,
fps= args.fps,
enable_rgb= args.enable_rgb,
)
embedding_publisher = rospy.Publisher(
args.namespace + "/visual_embedding",
Float32MultiArrayStamped,
queue_size= 1,
)
if args.enable_vis:
depth_image_publisher = rospy.Publisher(
args.namespace + "/camera/depth/image_rect_raw",
Image,
queue_size= 1,
)
network_input_publisher = rospy.Publisher(
args.namespace + "/camera/depth/network_input_raw",
Image,
queue_size= 1,
)
if args.enable_rgb:
rgb_image_publisher = rospy.Publisher(
args.namespace + "/camera/color/image_raw",
Image,
queue_size= 1,
)
rospy.loginfo("Depth range is clipped to [{}, {}] and normalized".format(depth_range[0], depth_range[1]))
rospy.loginfo("ROS, model, realsense have been initialized.")
if args.enable_vis:
rospy.loginfo("Visualization enabled, sending depth{} images".format(", rgb" if args.enable_rgb else ""))
try:
embedding_msg = Float32MultiArrayStamped()
embedding_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
frame_got = False
while not rospy.is_shutdown():
# Wait for the depth image
frames = rs_pipeline.wait_for_frames()
embedding_msg.header.stamp = rospy.Time.now()
depth_frame = frames.get_depth_frame()
if not depth_frame:
continue
if not frame_got:
frame_got = True
rospy.loginfo("Realsense frame recieved. Sending embeddings...")
if args.enable_rgb:
color_frame = frames.get_color_frame()
# Use this branch to log the time when image is acquired
if args.enable_vis and color_frame is not None:
color_frame = np.asanyarray(color_frame.get_data())
rgb_image_msg = ros_numpy.msgify(Image, color_frame, encoding= "rgb8")
rgb_image_msg.header.stamp = rospy.Time.now()
rgb_image_msg.header.frame_id = args.namespace + "/camera_color_optical_frame"
rgb_image_publisher.publish(rgb_image_msg)
# Process the depth image and publish
depth_frame = rs_filters(depth_frame)
depth_image_ = np.asanyarray(depth_frame.get_data())
depth_image = torch.from_numpy(depth_image_.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(model_device)
depth_image = input_filter(depth_image)
with torch.no_grad():
depth_embedding = model_script(depth_image).reshape(-1).cpu().numpy()
embedding_msg.header.seq += 1
embedding_msg.data = depth_embedding.tolist()
embedding_publisher.publish(embedding_msg)
# Publish the acquired image if needed
if args.enable_vis:
depth_image_msg = ros_numpy.msgify(Image, depth_image_, encoding= "16UC1")
depth_image_msg.header.stamp = rospy.Time.now()
depth_image_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
depth_image_publisher.publish(depth_image_msg)
network_input_np = (\
depth_image.detach().cpu().numpy()[0, 0] * (depth_range[1] - depth_range[0]) \
+ depth_range[0]
).astype(np.uint16)
network_input_msg = ros_numpy.msgify(Image, network_input_np, encoding= "16UC1")
network_input_msg.header.stamp = rospy.Time.now()
network_input_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
network_input_publisher.publish(network_input_msg)
ros_rate.sleep()
finally:
rs_pipeline.stop()
if __name__ == "__main__":
""" This script is designed to load the model and process the realsense image directly
from realsense SDK without realsense ROS wrapper
"""
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--namespace",
type= str,
default= "/a112138",
)
parser.add_argument("--logdir",
type= str,
help= "The log directory of the trained model",
)
parser.add_argument("--height",
type= int,
default= 240,
help= "The height of the realsense image",
)
parser.add_argument("--width",
type= int,
default= 424,
help= "The width of the realsense image",
)
parser.add_argument("--fps",
type= int,
default= 30,
help= "The fps of the realsense image",
)
parser.add_argument("--crop_left",
type= int,
default= 60,
help= "num of pixel to crop in the original pyrealsense readings."
)
parser.add_argument("--crop_right",
type= int,
default= 46,
help= "num of pixel to crop in the original pyrealsense readings."
)
parser.add_argument("--crop_top",
type= int,
default= 0,
help= "num of pixel to crop in the original pyrealsense readings."
)
parser.add_argument("--crop_bottom",
type= int,
default= 0,
help= "num of pixel to crop in the original pyrealsense readings."
)
parser.add_argument("--crop_far",
type= float,
default= 3.0,
help= "asside from the config far limit, make all depth readings larger than this value to be 3.0 in un-normalized network input."
)
parser.add_argument("--enable_rgb",
action= "store_true",
help= "Whether to enable rgb image",
)
parser.add_argument("--enable_vis",
action= "store_true",
help= "Whether to publish realsense image",
)
args = parser.parse_args()
main(args)

View File

@@ -34,6 +34,8 @@ from isaacgym import gymutil
import numpy as np
import torch
from legged_gym.utils.webviewer import WebViewer
# Base class for RL tasks
class BaseTask():
@@ -66,6 +68,12 @@ class BaseTask():
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
self.extras = {}
# create envs, sim and viewer
self.create_sim()
self.gym.prepare_sim(self.sim)
# allocate buffers
self.obs_buf = torch.zeros(self.num_envs, self.num_obs, device=self.device, dtype=torch.float)
self.rew_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.float)
@@ -78,12 +86,6 @@ class BaseTask():
self.privileged_obs_buf = None
# self.num_privileged_obs = self.num_obs
self.extras = {}
# create envs, sim and viewer
self.create_sim()
self.gym.prepare_sim(self.sim)
# todo: read from config
self.enable_viewer_sync = True
self.viewer = None
@@ -98,6 +100,13 @@ class BaseTask():
self.gym.subscribe_viewer_keyboard_event(
self.viewer, gymapi.KEY_V, "toggle_viewer_sync")
def start_webviewer(self, port= 5000):
""" This method must be called after the env is fully initialized """
print("Starting webviewer on port: ", port)
print("env is passed as a parameter to the webviewer")
self.webviewer = WebViewer(host= "127.0.0.1", port= port)
self.webviewer.setup(self)
def get_observations(self):
return self.obs_buf
@@ -141,4 +150,9 @@ class BaseTask():
if sync_frame_time:
self.gym.sync_frame_time(self.sim)
else:
self.gym.poll_viewer_events(self.viewer)
self.gym.poll_viewer_events(self.viewer)
if hasattr(self, "webviewer"):
self.webviewer.render(fetch_results=True,
step_graphics=True,
render_all_camera_sensors=True,
wait_for_page_load=True)

File diff suppressed because it is too large

View File

@@ -129,6 +129,7 @@ class LeggedRobotCfg(BaseConfig):
max_push_vel_xy = 1.
max_push_vel_ang = 0.
init_dof_pos_ratio_range = [0.5, 1.5]
init_base_vel_range = [-1., 1.]
class rewards:
class scales:
@@ -160,6 +161,7 @@ class LeggedRobotCfg(BaseConfig):
class obs_scales:
lin_vel = 2.0
ang_vel = 0.25
commands = [2., 2., 0.25] # matching lin_vel and ang_vel scales
dof_pos = 1.0
dof_vel = 0.05
height_measurements = 5.0
@@ -182,6 +184,12 @@ class LeggedRobotCfg(BaseConfig):
ref_env = 0
pos = [10, 0, 6] # [m]
lookat = [11., 5, 3.] # [m]
stream_depth = False # for webviewer
draw_commands = True # for debugger
class commands:
color = [0.1, 0.8, 0.1] # rgb
size = 0.5
class sim:
dt = 0.005

File diff suppressed because it is too large

View File

@@ -1,14 +1,13 @@
import random
from isaacgym.torch_utils import torch_rand_float, get_euler_xyz, quat_from_euler_xyz, tf_apply
from isaacgym import gymtorch, gymapi, gymutil
import numpy as np
from isaacgym.torch_utils import torch_rand_float, tf_apply
from isaacgym import gymutil, gymapi
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from legged_gym.envs.base.legged_robot_field import LeggedRobotField
class LeggedRobotNoisy(LeggedRobotField):
class LeggedRobotNoisyMixin:
""" This class should be independent from the terrain, but depend on the sensors of the parent
class.
"""
@@ -16,22 +15,23 @@ class LeggedRobotNoisy(LeggedRobotField):
def clip_position_action_by_torque_limit(self, actions_scaled):
""" For position control, scaled actions should be in the coordinate of robot default dof pos
"""
if hasattr(self, "proprioception_output"):
dof_vel = self.proprioception_output[:, -24:-12] / self.obs_scales.dof_vel
dof_pos_ = self.proprioception_output[:, -36:-24] / self.obs_scales.dof_pos
if hasattr(self, "dof_vel_obs_output_buffer"):
dof_vel = self.dof_vel_obs_output_buffer / self.obs_scales.dof_vel
else:
dof_vel = self.dof_vel
if hasattr(self, "dof_pos_obs_output_buffer"):
dof_pos_ = self.dof_pos_obs_output_buffer / self.obs_scales.dof_pos
else:
dof_pos_ = self.dof_pos - self.default_dof_pos
p_limits_low = (-self.torque_limits) + self.d_gains*dof_vel
p_limits_high = (self.torque_limits) + self.d_gains*dof_vel
actions_low = (p_limits_low/self.p_gains) + dof_pos_
actions_high = (p_limits_high/self.p_gains) + dof_pos_
actions_scaled_torque_clipped = torch.clip(actions_scaled, actions_low, actions_high)
return actions_scaled_torque_clipped
actions_scaled_clipped = torch.clip(actions_scaled, actions_low, actions_high)
return actions_scaled_clipped
def pre_physics_step(self, actions):
self.forward_depth_refreshed = False # in case _get_forward_depth_obs is called multiple times
self.proprioception_refreshed = False
self.set_buffers_refreshed_to_false()
return_ = super().pre_physics_step(actions)
if isinstance(self.cfg.control.action_scale, (tuple, list)):
@@ -40,24 +40,24 @@ class LeggedRobotNoisy(LeggedRobotField):
self.actions_scaled = self.actions * self.cfg.control.action_scale
control_type = self.cfg.control.control_type
if control_type == "P":
actions_scaled_torque_clipped = self.clip_position_action_by_torque_limit(self.actions_scaled)
actions_scaled_clipped = self.clip_position_action_by_torque_limit(self.actions_scaled)
else:
raise NotImplementedError
else:
actions_scaled_torque_clipped = self.actions * self.cfg.control.action_scale
actions_scaled_clipped = self.actions * self.cfg.control.action_scale
if getattr(self.cfg.control, "action_delay", False):
# always put the latest action at the end of the buffer
self.actions_history_buffer = torch.roll(self.actions_history_buffer, shifts= -1, dims= 0)
self.actions_history_buffer[-1] = actions_scaled_torque_clipped
self.actions_history_buffer[-1] = actions_scaled_clipped
# get the delayed action
self.action_delayed_frames = ((self.current_action_delay / self.dt) + 1).to(int)
self.actions_scaled_torque_clipped = self.actions_history_buffer[
-self.action_delayed_frames,
action_delayed_frames = ((self.action_delay_buffer / self.dt) + 1).to(int)
self.actions_scaled_clipped = self.actions_history_buffer[
-action_delayed_frames,
torch.arange(self.num_envs, device= self.device),
]
else:
self.actions_scaled_torque_clipped = actions_scaled_torque_clipped
self.actions_scaled_clipped = actions_scaled_clipped
return return_
@@ -69,12 +69,12 @@ class LeggedRobotNoisy(LeggedRobotField):
return super()._compute_torques(actions)
else:
if hasattr(self, "motor_strength"):
actions_scaled_torque_clipped = self.motor_strength * self.actions_scaled_torque_clipped
actions_scaled_clipped = self.motor_strength * self.actions_scaled_clipped
else:
actions_scaled_torque_clipped = self.actions_scaled_torque_clipped
actions_scaled_clipped = self.actions_scaled_clipped
control_type = self.cfg.control.control_type
if control_type == "P":
torques = self.p_gains * (actions_scaled_torque_clipped + self.default_dof_pos - self.dof_pos) \
torques = self.p_gains * (actions_scaled_clipped + self.default_dof_pos - self.dof_pos) \
- self.d_gains * self.dof_vel
else:
raise NotImplementedError
@@ -93,9 +93,7 @@ class LeggedRobotNoisy(LeggedRobotField):
self.max_torques,
)
### The set torque limit is usually smaller than the value in the robot datasheet
self.torque_exceed_count_substep[(torch.abs(self.torques) > self.torque_limits).any(dim= -1)] += 1
### Hack to check the torque limit exceeding by your own value.
# self.torque_exceed_count_envstep[(torch.abs(self.torques) > 38.).any(dim= -1)] += 1
self.torque_exceed_count_substep[(torch.abs(self.torques) > self.torque_limits * self.cfg.rewards.soft_torque_limit).any(dim= -1)] += 1
### count how many times in the episode the robot is out of dof pos limit (summing all dofs)
self.out_of_dof_pos_limit_count_substep += self._reward_dof_pos_limits().int()
@@ -130,43 +128,11 @@ class LeggedRobotNoisy(LeggedRobotField):
if len(resample_env_ids) > 0:
self._resample_action_delay(resample_env_ids)
if hasattr(self, "proprioception_buffer"):
resampling_time = getattr(self.cfg.sensor.proprioception, "latency_resampling_time", self.dt)
resample_env_ids = (self.episode_length_buf % int(resampling_time / self.dt) == 0).nonzero(as_tuple= False).flatten()
if len(resample_env_ids) > 0:
self._resample_proprioception_latency(resample_env_ids)
if hasattr(self, "forward_depth_buffer"):
resampling_time = getattr(self.cfg.sensor.forward_camera, "latency_resampling_time", self.dt)
resample_env_ids = (self.episode_length_buf % int(resampling_time / self.dt) == 0).nonzero(as_tuple= False).flatten()
if len(resample_env_ids) > 0:
self._resample_forward_camera_latency(resample_env_ids)
for sensor_name in self.available_sensors:
if hasattr(self, sensor_name + "_latency_buffer"):
self._resample_sensor_latency_if_needed(sensor_name)
self.torque_exceed_count_envstep[(torch.abs(self.substep_torques) > self.torque_limits).any(dim= 1).any(dim= 1)] += 1
def _resample_action_delay(self, env_ids):
self.current_action_delay[env_ids] = torch_rand_float(
self.cfg.control.action_delay_range[0],
self.cfg.control.action_delay_range[1],
(len(env_ids), 1),
device= self.device,
).flatten()
def _resample_proprioception_latency(self, env_ids):
self.current_proprioception_latency[env_ids] = torch_rand_float(
self.cfg.sensor.proprioception.latency_range[0],
self.cfg.sensor.proprioception.latency_range[1],
(len(env_ids), 1),
device= self.device,
).flatten()
def _resample_forward_camera_latency(self, env_ids):
self.current_forward_camera_latency[env_ids] = torch_rand_float(
self.cfg.sensor.forward_camera.latency_range[0],
self.cfg.sensor.forward_camera.latency_range[1],
(len(env_ids), 1),
device= self.device,
).flatten()
self.torque_exceed_count_envstep[(torch.abs(self.substep_torques) > self.torque_limits * self.cfg.rewards.soft_torque_limit).any(dim= 1).any(dim= 1)] += 1
def _init_buffers(self):
return_ = super()._init_buffers()
@@ -174,9 +140,29 @@ class LeggedRobotNoisy(LeggedRobotField):
if getattr(self.cfg.control, "action_delay", False):
assert hasattr(self.cfg.control, "action_delay_range") and hasattr(self.cfg.control, "action_delay_resample_time"), "Please specify action_delay_range and action_delay_resample_time in the config file."
""" Used in pre-physics step """
self.cfg.control.action_history_buffer_length = int((self.cfg.control.action_delay_range[1] + self.dt) / self.dt)
self.actions_history_buffer = torch.zeros(
self.build_action_delay_buffer()
self.component_governed_by_sensor = dict()
for sensor_name in self.available_sensors:
if hasattr(self.cfg.sensor, sensor_name):
self.set_latency_buffer_for_sensor(sensor_name)
for component in getattr(self.cfg.sensor, sensor_name).obs_components:
assert not hasattr(self, component + "_obs_buffer"), "The obs component {} already has a buffer and corresponding sensor. Should not be governed also by {}".format(component, sensor_name)
self.set_obs_buffers_for_component(component, sensor_name)
if component == "forward_depth":
self.build_depth_image_processor_buffers(sensor_name)
self.max_torques = torch.zeros_like(self.torques[..., 0])
self.torque_exceed_count_substep = torch.zeros_like(self.torques[..., 0], dtype= torch.int32) # The number of substeps that the torque exceeds the limit
self.torque_exceed_count_envstep = torch.zeros_like(self.torques[..., 0], dtype= torch.int32) # The number of envsteps that the torque exceeds the limit
self.out_of_dof_pos_limit_count_substep = torch.zeros_like(self.torques[..., 0], dtype= torch.int32) # The number of substeps that the dof pos exceeds the limit
return return_
def build_action_delay_buffer(self):
""" Used in pre-physics step """
self.cfg.control.action_history_buffer_length = int((self.cfg.control.action_delay_range[1] + self.dt) / self.dt)
self.actions_history_buffer = torch.zeros(
(
self.cfg.control.action_history_buffer_length,
self.num_envs,
@@ -185,66 +171,27 @@ class LeggedRobotNoisy(LeggedRobotField):
dtype= torch.float32,
device= self.device,
)
self.current_action_delay = torch_rand_float(
self.action_delay_buffer = torch_rand_float(
self.cfg.control.action_delay_range[0],
self.cfg.control.action_delay_range[1],
(self.num_envs, 1),
device= self.device,
).flatten()
self.action_delayed_frames = ((self.current_action_delay / self.dt) + 1).to(int)
self.action_delayed_frames = ((self.action_delay_buffer / self.dt) + 1).to(int)
if "proprioception" in all_obs_components and hasattr(self.cfg.sensor, "proprioception"):
""" Adding proprioception delay buffer """
self.cfg.sensor.proprioception.buffer_length = int((self.cfg.sensor.proprioception.latency_range[1] + self.dt) / self.dt)
self.proprioception_buffer = torch.zeros(
(
self.cfg.sensor.proprioception.buffer_length,
self.num_envs,
self.get_num_obs_from_components(["proprioception"]),
),
dtype= torch.float32,
device= self.device,
)
self.current_proprioception_latency = torch_rand_float(
self.cfg.sensor.proprioception.latency_range[0],
self.cfg.sensor.proprioception.latency_range[1],
(self.num_envs, 1),
device= self.device,
).flatten()
self.proprioception_delayed_frames = ((self.current_proprioception_latency / self.dt) + 1).to(int)
if "forward_depth" in all_obs_components and hasattr(self.cfg.sensor, "forward_camera"):
output_resolution = getattr(self.cfg.sensor.forward_camera, "output_resolution", self.cfg.sensor.forward_camera.resolution)
self.cfg.sensor.forward_camera.buffer_length = int((self.cfg.sensor.forward_camera.latency_range[1] + self.cfg.sensor.forward_camera.refresh_duration) / self.dt)
self.forward_depth_buffer = torch.zeros(
(
self.cfg.sensor.forward_camera.buffer_length,
self.num_envs,
1,
output_resolution[0],
output_resolution[1],
),
dtype= torch.float32,
device= self.device,
)
self.forward_depth_delayed_frames = torch.ones((self.num_envs,), device= self.device, dtype= int) * self.cfg.sensor.forward_camera.buffer_length
self.current_forward_camera_latency = torch_rand_float(
self.cfg.sensor.forward_camera.latency_range[0],
self.cfg.sensor.forward_camera.latency_range[1],
(self.num_envs, 1),
device= self.device,
).flatten()
if hasattr(self.cfg.sensor.forward_camera, "resized_resolution"):
self.forward_depth_resize_transform = T.Resize(
def build_depth_image_processor_buffers(self, sensor_name):
assert sensor_name == "forward_camera", "Only forward_camera is supported for now."
if hasattr(getattr(self.cfg.sensor, sensor_name), "resized_resolution"):
self.forward_depth_resize_transform = T.Resize(
self.cfg.sensor.forward_camera.resized_resolution,
interpolation= T.InterpolationMode.BICUBIC,
)
self.contour_detection_kernel = torch.zeros(
(8, 1, 3, 3),
dtype= torch.float32,
device= self.device,
)
# empirical values chosen to be more sensitive to vertical edges
(8, 1, 3, 3),
dtype= torch.float32,
device= self.device,
)
# empirical values chosen to be more sensitive to vertical edges
self.contour_detection_kernel[0, :, 1, 1] = 0.5
self.contour_detection_kernel[0, :, 0, 0] = -0.5
self.contour_detection_kernel[1, :, 1, 1] = 0.1
@@ -261,40 +208,176 @@ class LeggedRobotNoisy(LeggedRobotField):
self.contour_detection_kernel[6, :, 2, 1] = -0.1
self.contour_detection_kernel[7, :, 1, 1] = 0.5
self.contour_detection_kernel[7, :, 2, 2] = -0.5
""" Considering sensors are not necessarily matching observation component,
we need to set the buffers for each obs component and latency buffer for each sensor.
"""
def set_obs_buffers_for_component(self, component, sensor_name):
buffer_length = int(getattr(self.cfg.sensor, sensor_name).latency_range[1] / self.dt) + 1
# use super().get_obs_segment_from_components() to get the obs shape, so that post processing
# does not override the buffer shape
obs_buffer = torch.zeros(
(
buffer_length,
self.num_envs,
*(self.get_obs_segment_from_components([component])[component]), # tuple(obs_shape)
),
dtype= torch.float32,
device= self.device,
)
setattr(self, component + "_obs_buffer", obs_buffer)
setattr(self, component + "_obs_refreshed", False)
self.component_governed_by_sensor[component] = sensor_name
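# Example (sketch): with cfg.sensor.proprioception.obs_components containing "dof_pos",
# this creates self.dof_pos_obs_buffer with shape (buffer_length, num_envs, 12),
# sets self.dof_pos_obs_refreshed = False, and records
# self.component_governed_by_sensor["dof_pos"] = "proprioception".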
self.max_torques = torch.zeros_like(self.torques[..., 0])
self.torque_exceed_count_substep = torch.zeros_like(self.torques[..., 0], dtype= torch.int32) # The number of substeps that the torque exceeds the limit
self.torque_exceed_count_envstep = torch.zeros_like(self.torques[..., 0], dtype= torch.int32) # The number of envsteps that the torque exceeds the limit
self.out_of_dof_pos_limit_count_substep = torch.zeros_like(self.torques[..., 0], dtype= torch.int32) # The number of substeps that the dof pos exceeds the limit
def set_latency_buffer_for_sensor(self, sensor_name):
latency_buffer = torch_rand_float(
getattr(self.cfg.sensor, sensor_name).latency_range[0],
getattr(self.cfg.sensor, sensor_name).latency_range[1],
(self.num_envs, 1),
device= self.device,
).flatten()
# using setattr to set the buffer
setattr(self, sensor_name + "_latency_buffer", latency_buffer)
if "camera" in sensor_name:
setattr(
self,
sensor_name + "_delayed_frames",
torch.zeros_like(latency_buffer, dtype= torch.long, device= self.device),
)
def _resample_sensor_latency_if_needed(self, sensor_name):
resampling_time = getattr(getattr(self.cfg.sensor, sensor_name), "latency_resampling_time", self.dt)
resample_env_ids = (self.episode_length_buf % int(resampling_time / self.dt) == 0).nonzero(as_tuple= False).flatten()
if len(resample_env_ids) > 0:
getattr(self, sensor_name + "_latency_buffer")[resample_env_ids] = torch_rand_float(
getattr(getattr(self.cfg.sensor, sensor_name), "latency_range")[0],
getattr(getattr(self.cfg.sensor, sensor_name), "latency_range")[1],
(len(resample_env_ids), 1),
device= self.device,
).flatten()
return return_
def _resample_action_delay(self, env_ids):
self.action_delay_buffer[env_ids] = torch_rand_float(
self.cfg.control.action_delay_range[0],
self.cfg.control.action_delay_range[1],
(len(env_ids), 1),
device= self.device,
).flatten()
def _reset_buffers(self, env_ids):
return_ = super()._reset_buffers(env_ids)
if hasattr(self, "actions_history_buffer"):
self.actions_history_buffer[:, env_ids] = 0.
self.action_delayed_frames[env_ids] = self.cfg.control.action_history_buffer_length
if hasattr(self, "forward_depth_buffer"):
self.forward_depth_buffer[:, env_ids] = 0.
self.forward_depth_delayed_frames[env_ids] = self.cfg.sensor.forward_camera.buffer_length
if hasattr(self, "proprioception_buffer"):
self.proprioception_buffer[:, env_ids] = 0.
self.proprioception_delayed_frames[env_ids] = self.cfg.sensor.proprioception.buffer_length
for sensor_name in self.available_sensors:
if not hasattr(self.cfg.sensor, sensor_name):
continue
for component in getattr(self.cfg.sensor, sensor_name).obs_components:
if hasattr(self, component + "_obs_buffer"):
getattr(self, component + "_obs_buffer")[:, env_ids] = 0.
setattr(self, component + "_obs_refreshed", False)
if "camera" in sensor_name:
getattr(self, sensor_name + "_delayed_frames")[env_ids] = 0
return return_
def _draw_debug_vis(self):
return_ = super()._draw_debug_vis()
def _build_forward_depth_intrinsic(self):
sim_raw_resolution = self.cfg.sensor.forward_camera.resolution
sim_cropping_h = self.cfg.sensor.forward_camera.crop_top_bottom
sim_cropping_w = self.cfg.sensor.forward_camera.crop_left_right
cropped_resolution = [ # (H, W)
sim_raw_resolution[0] - sum(sim_cropping_h),
sim_raw_resolution[1] - sum(sim_cropping_w),
]
network_input_resolution = self.cfg.sensor.forward_camera.output_resolution
x_fov = torch.mean(torch.tensor(self.cfg.sensor.forward_camera.horizontal_fov).float()) \
/ 180 * np.pi
fx = (sim_raw_resolution[1]) / (2 * torch.tan(x_fov / 2))
fy = fx # * (sim_raw_resolution[0] / sim_raw_resolution[1])
fx = fx * network_input_resolution[1] / cropped_resolution[1]
fy = fy * network_input_resolution[0] / cropped_resolution[0]
cx = (sim_raw_resolution[1] / 2) - sim_cropping_w[0]
cy = (sim_raw_resolution[0] / 2) - sim_cropping_h[0]
cx = cx * network_input_resolution[1] / cropped_resolution[1]
cy = cy * network_input_resolution[0] / cropped_resolution[0]
self.forward_depth_intrinsic = torch.tensor([
[fx, 0., cx],
[0., fy, cy],
[0., 0., 1.],
], device= self.device)
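# Pinhole-model sketch of the values above: fx = W_raw / (2 * tan(fov_x / 2)); (cx, cy) start at the
# raw image center and are shifted by the crop offsets; both focal lengths and the principal point
# are then rescaled from the cropped resolution to the network input resolution.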
x_arr = torch.linspace(
0,
network_input_resolution[1] - 1,
network_input_resolution[1],
)
y_arr = torch.linspace(
0,
network_input_resolution[0] - 1,
network_input_resolution[0],
)
x_mesh, y_mesh = torch.meshgrid(x_arr, y_arr, indexing= "xy")
# (H, W, 2) -> (H * W, 3) row wise
self.forward_depth_pixel_mesh = torch.stack([
x_mesh,
y_mesh,
torch.ones_like(x_mesh),
], dim= -1).view(-1, 3).float().to(self.device)
def _draw_pointcloud_from_depth_image(self,
env_h, sensor_h,
camera_intrinsic,
depth_image, # torch tensor (H, W)
offset = [-2, 0, 1],
):
"""
Args:
offset: drawing points directly at the sensor position would interfere with the
depth image rendering, so we offset the drawn pointcloud by this amount
"""
assert self.num_envs == 1, "LeggedRobotNoisy: Only implemented when num_envs == 1"
camera_transform = self.gym.get_camera_transform(self.sim, env_h, sensor_h)
camera_intrinsic_inv = torch.inverse(camera_intrinsic)
# (H * W, 3) -> (3, H * W) -> (H * W, 3)
depth_image = depth_image * (self.cfg.sensor.forward_camera.depth_range[1] - self.cfg.sensor.forward_camera.depth_range[0]) + self.cfg.sensor.forward_camera.depth_range[0]
points = camera_intrinsic_inv @ self.forward_depth_pixel_mesh.T * depth_image.view(-1)
points = points.T
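# Back-projection sketch: each pixel (u, v) with depth d maps to d * K^-1 @ [u, v, 1]^T in the camera
# optical frame (+z forward); the (z, -x, -y) reordering below converts it to the +x-forward convention
# before applying the camera transform and the drawing offset.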
sphere_geom = gymutil.WireframeSphereGeometry(0.008, 8, 8, None, color= (0.9, 0.1, 0.9))
for p in points:
sphere_pose = gymapi.Transform(
p= camera_transform.transform_point(gymapi.Vec3(
p[2] + offset[0],
-p[0] + offset[1],
-p[1] + offset[2],
)), # +z forward to +x forward
r= None,
)
gymutil.draw_lines(sphere_geom, self.gym, self.viewer, env_h, sphere_pose)
def _draw_sensor_reading_vis(self, env_h, sensor_hd):
return_ = super()._draw_sensor_reading_vis(env_h, sensor_hd)
if hasattr(self, "forward_depth_output"):
if self.num_envs == 1:
import matplotlib.pyplot as plt
forward_depth_np = self.forward_depth_output[0, 0].detach().cpu().numpy() # (H, W)
plt.imshow(forward_depth_np, cmap= "gray", vmin= 0, vmax= 1)
plt.pause(0.001)
if getattr(self.cfg.viewer, "forward_depth_as_pointcloud", False):
if not hasattr(self, "forward_depth_intrinsic"):
self._build_forward_depth_intrinsic()
for sensor_name, sensor_h in sensor_hd.items():
if "forward_camera" in sensor_name:
self._draw_pointcloud_from_depth_image(
env_h, sensor_h,
self.forward_depth_intrinsic,
self.forward_depth_output[0, 0].detach(),
)
else:
import matplotlib.pyplot as plt
forward_depth_np = self.forward_depth_output[0, 0].detach().cpu().numpy() # (H, W)
plt.imshow(forward_depth_np, cmap= "gray", vmin= 0, vmax= 1)
plt.pause(0.001)
else:
print("LeggedRobotNoisy: More than one robot, stop showing camera image")
return return_
""" Steps to simulate stereo camera depth image """
""" ########## Steps to simulate stereo camera depth image ########## """
def _add_depth_contour(self, depth_images):
mask = F.max_pool2d(
torch.abs(F.conv2d(depth_images, self.contour_detection_kernel, padding= 1)).max(dim= -3, keepdim= True)[0],
@@ -545,32 +628,133 @@ class LeggedRobotNoisy(LeggedRobotField):
depth_images_ = self._crop_depth_images(depth_images_)
if hasattr(self, "forward_depth_resize_transform"):
depth_images_ = self.forward_depth_resize_transform(depth_images_)
depth_images_ = depth_images_.clip(0, 1)
return depth_images_.unsqueeze(0) # (1, N, 1, H, W)
""" ########## Override the observation functions to add latencies and artifacts ########## """
def set_buffers_refreshed_to_false(self):
for sensor_name in self.available_sensors:
if hasattr(self.cfg.sensor, sensor_name):
for component in getattr(self.cfg.sensor, sensor_name).obs_components:
setattr(self, component + "_obs_refreshed", False)
@torch.no_grad()
def _get_ang_vel_obs(self, privileged= False):
if hasattr(self, "ang_vel_obs_buffer") and (not self.ang_vel_obs_refreshed) and (not privileged):
self.ang_vel_obs_buffer = torch.cat([
self.ang_vel_obs_buffer[1:],
super()._get_ang_vel_obs().unsqueeze(0),
], dim= 0)
component_governed_by = self.component_governed_by_sensor["ang_vel"]
buffer_delayed_frames = ((getattr(self, component_governed_by + "_latency_buffer") / self.dt) + 1).to(int)
self.ang_vel_obs_output_buffer = self.ang_vel_obs_buffer[
-buffer_delayed_frames,
torch.arange(self.num_envs, device= self.device),
].clone()
self.ang_vel_obs_refreshed = True
if privileged or not hasattr(self, "ang_vel_obs_buffer"):
return super()._get_ang_vel_obs(privileged)
return self.ang_vel_obs_output_buffer
@torch.no_grad()
def _get_projected_gravity_obs(self, privileged= False):
if hasattr(self, "projected_gravity_obs_buffer") and (not self.projected_gravity_obs_refreshed) and (not privileged):
self.projected_gravity_obs_buffer = torch.cat([
self.projected_gravity_obs_buffer[1:],
super()._get_projected_gravity_obs().unsqueeze(0),
], dim= 0)
component_governed_by = self.component_governed_by_sensor["projected_gravity"]
buffer_delayed_frames = ((getattr(self, component_governed_by + "_latency_buffer") / self.dt) + 1).to(int)
self.projected_gravity_obs_output_buffer = self.projected_gravity_obs_buffer[
-buffer_delayed_frames,
torch.arange(self.num_envs, device= self.device),
].clone()
self.projected_gravity_obs_refreshed = True
if privileged or not hasattr(self, "projected_gravity_obs_buffer"):
return super()._get_projected_gravity_obs(privileged)
return self.projected_gravity_obs_output_buffer
@torch.no_grad()
def _get_commands_obs(self, privileged= False):
if hasattr(self, "commands_obs_buffer") and (not self.commands_obs_refreshed) and (not privileged):
self.commands_obs_buffer = torch.cat([
self.commands_obs_buffer[1:],
super()._get_commands_obs().unsqueeze(0),
], dim= 0)
component_governed_by = self.component_governed_by_sensor["commands"]
buffer_delayed_frames = ((getattr(self, component_governed_by + "_latency_buffer") / self.dt) + 1).to(int)
self.commands_obs_output_buffer = self.commands_obs_buffer[
-buffer_delayed_frames,
torch.arange(self.num_envs, device= self.device),
].clone()
self.commands_obs_refreshed = True
if privileged or not hasattr(self, "commands_obs_buffer"):
return super()._get_commands_obs(privileged)
return self.commands_obs_output_buffer
@torch.no_grad()
def _get_dof_pos_obs(self, privileged= False):
if hasattr(self, "dof_pos_obs_buffer") and (not self.dof_pos_obs_refreshed) and (not privileged):
self.dof_pos_obs_buffer = torch.cat([
self.dof_pos_obs_buffer[1:],
super()._get_dof_pos_obs().unsqueeze(0),
], dim= 0)
component_governed_by = self.component_governed_by_sensor["dof_pos"]
buffer_delayed_frames = ((getattr(self, component_governed_by + "_latency_buffer") / self.dt) + 1).to(int)
self.dof_pos_obs_output_buffer = self.dof_pos_obs_buffer[
-buffer_delayed_frames,
torch.arange(self.num_envs, device= self.device),
].clone()
self.dof_pos_obs_refreshed = True
if privileged or not hasattr(self, "dof_pos_obs_buffer"):
return super()._get_dof_pos_obs(privileged)
return self.dof_pos_obs_output_buffer
@torch.no_grad()
def _get_dof_vel_obs(self, privileged= False):
if hasattr(self, "dof_vel_obs_buffer") and (not self.dof_vel_obs_refreshed) and (not privileged):
self.dof_vel_obs_buffer = torch.cat([
self.dof_vel_obs_buffer[1:],
super()._get_dof_vel_obs().unsqueeze(0),
], dim= 0)
component_governed_by = self.component_governed_by_sensor["dof_vel"]
buffer_delayed_frames = ((getattr(self, component_governed_by + "_latency_buffer") / self.dt) + 1).to(int)
self.dof_vel_obs_output_buffer = self.dof_vel_obs_buffer[
-buffer_delayed_frames,
torch.arange(self.num_envs, device= self.device),
].clone()
self.dof_vel_obs_refreshed = True
if privileged or not hasattr(self, "dof_vel_obs_buffer"):
return super()._get_dof_vel_obs(privileged)
return self.dof_vel_obs_output_buffer
@torch.no_grad()
def _get_forward_depth_obs(self, privileged= False):
if not self.forward_depth_refreshed and hasattr(self.cfg.sensor, "forward_camera") and (not privileged):
self.forward_depth_buffer = torch.cat([
self.forward_depth_buffer[1:],
if hasattr(self, "forward_depth_obs_buffer") and (not self.forward_depth_obs_refreshed) and hasattr(self.cfg.sensor, "forward_camera") and (not privileged):
# TODO: any variables named with "forward_camera" here should be rearanged
# in a individual class method.
self.forward_depth_obs_buffer = torch.cat([
self.forward_depth_obs_buffer[1:],
self._process_depth_image(self.sensor_tensor_dict["forward_depth"]),
], dim= 0)
delay_refresh_mask = (self.episode_length_buf % int(self.cfg.sensor.forward_camera.refresh_duration / self.dt)) == 0
# NOTE: if the requested delay exceeds the buffer length, the oldest buffered image should be used.
frame_select = (self.current_forward_camera_latency / self.dt).to(int)
self.forward_depth_delayed_frames = torch.where(
frame_select = (self.forward_camera_latency_buffer / self.dt).to(int)
self.forward_camera_delayed_frames = torch.where(
delay_refresh_mask,
torch.minimum(
frame_select,
self.forward_depth_delayed_frames + 1,
self.forward_camera_delayed_frames + 1,
),
self.forward_depth_delayed_frames + 1,
self.forward_camera_delayed_frames + 1,
)
self.forward_depth_delayed_frames = torch.clip(
self.forward_depth_delayed_frames,
self.forward_camera_delayed_frames = torch.clip(
self.forward_camera_delayed_frames,
0,
self.cfg.sensor.forward_camera.buffer_length,
self.forward_depth_obs_buffer.shape[0],
)
self.forward_depth_output = self.forward_depth_buffer[
-self.forward_depth_delayed_frames,
self.forward_depth_output = self.forward_depth_obs_buffer[
-self.forward_camera_delayed_frames,
torch.arange(self.num_envs, device= self.device),
].clone()
self.forward_depth_refreshed = True
@@ -579,29 +763,6 @@ class LeggedRobotNoisy(LeggedRobotField):
return self.forward_depth_output.flatten(start_dim= 1)
def _get_proprioception_obs(self, privileged= False):
if not self.proprioception_refreshed and hasattr(self.cfg.sensor, "proprioception") and (not privileged):
self.proprioception_buffer = torch.cat([
self.proprioception_buffer[1:],
super()._get_proprioception_obs().unsqueeze(0),
], dim= 0)
# NOTE: if the delayed frames is greater than the last frame, the last image should be used. [0.04-0.0075, 0.04+0.0025]
self.proprioception_delayed_frames = ((self.current_proprioception_latency / self.dt) + 1).to(int)
self.proprioception_output = self.proprioception_buffer[
-self.proprioception_delayed_frames,
torch.arange(self.num_envs, device= self.device),
].clone()
### NOTE: WARN: ERROR: remove this code in final version, no action delay should be used.
if getattr(self.cfg.sensor.proprioception, "delay_action_obs", False) or getattr(self.cfg.sensor.proprioception, "delay_privileged_action_obs", False):
raise ValueError("LeggedRobotNoisy: No action delay should be used. Please remove these settings")
# The last-action is not delayed.
self.proprioception_output[:, -12:] = self.proprioception_buffer[-1, :, -12:]
self.proprioception_refreshed = True
if not hasattr(self.cfg.sensor, "proprioception") or privileged:
return super()._get_proprioception_obs(privileged)
return self.proprioception_output.flatten(start_dim= 1)
def get_obs_segment_from_components(self, obs_components):
obs_segments = super().get_obs_segment_from_components(obs_components)
if "forward_depth" in obs_components:
@@ -611,27 +772,3 @@ class LeggedRobotNoisy(LeggedRobotField):
self.cfg.sensor.forward_camera.resolution,
))
return obs_segments
def _reward_exceed_torque_limits_i(self):
""" Indicator function """
max_torques = torch.abs(self.substep_torques).max(dim= 1)[0]
exceed_torque_each_dof = max_torques > self.torque_limits
exceed_torque = exceed_torque_each_dof.any(dim= 1)
return exceed_torque.to(torch.float32)
def _reward_exceed_torque_limits_square(self):
""" square function for exceeding part """
exceeded_torques = torch.abs(self.substep_torques) - self.torque_limits
exceeded_torques[exceeded_torques < 0.] = 0.
# sum along decimation axis and dof axis
return torch.square(exceeded_torques).sum(dim= 1).sum(dim= 1)
def _reward_exceed_torque_limits_l1norm(self):
""" square function for exceeding part """
exceeded_torques = torch.abs(self.substep_torques) - self.torque_limits
exceeded_torques[exceeded_torques < 0.] = 0.
# sum along decimation axis and dof axis
return torch.norm(exceeded_torques, p= 1, dim= -1).sum(dim= 1)
def _reward_exceed_dof_pos_limits(self):
return self.substep_exceed_dof_pos_limits.to(torch.float32).sum(dim= -1).mean(dim= -1)

View File

@@ -0,0 +1,9 @@
from legged_gym.envs.base.legged_robot_field import LeggedRobotField
from legged_gym.envs.base.legged_robot_noisy import LeggedRobotNoisyMixin
class RobotFieldNoisy(LeggedRobotNoisyMixin, LeggedRobotField):
""" Using inheritance to combine the two classes.
Then, LeggedRobotNoisyMixin and LeggedRobot can also be used elsewhere.
"""
pass
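# MRO sketch: method lookups hit LeggedRobotNoisyMixin first, and its super() calls fall through to
# LeggedRobotField (and, presumably, the base LeggedRobot implementation underneath).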

View File

@@ -0,0 +1,280 @@
""" Basic model configs for Unitree Go2 """
import numpy as np
import os.path as osp
from legged_gym.envs.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO
go2_action_scale = 0.5
go2_const_dof_range = dict(
Hip_max= 1.0472,
Hip_min= -1.0472,
Front_Thigh_max= 3.4907,
Front_Thigh_min= -1.5708,
Rear_Thingh_max= 4.5379,
Rear_Thingh_min= -0.5236,
Calf_max= -0.83776,
Calf_min= -2.7227,
)
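# Joint position limits [rad] as exposed by the SDK; a sketch of their use: consumed below via
# Go2RoughCfg.asset.sdk_dof_range.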
class Go2RoughCfg( LeggedRobotCfg ):
class env:
num_envs = 4096
num_observations = None # not used; use obs_components
num_privileged_obs = None # not used; use privileged_obs_components
use_lin_vel = False # to be decided
num_actions = 12
send_timeouts = True # send time out information to the algorithm
episode_length_s = 20 # episode length in seconds
obs_components = [
"lin_vel",
"ang_vel",
"projected_gravity",
"commands",
"dof_pos",
"dof_vel",
"last_actions",
"height_measurements",
]
class sensor:
class proprioception:
obs_components = ["ang_vel", "projected_gravity", "commands", "dof_pos", "dof_vel"]
latency_range = [0.005, 0.045] # [s]
latency_resample_time = 5.0 # [s]
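# Sketch of how this latency is consumed (see LeggedRobotNoisyMixin): each listed obs component keeps
# a ring buffer and is read back with a per-env delay of int(latency / dt) + 1 frames, with the latency
# resampled per environment from latency_range.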
class terrain:
selected = "TerrainPerlin"
mesh_type = None
measure_heights = True
# x: [-0.5, 1.5], y: [-0.5, 0.5] range for go2
measured_points_x = [i for i in np.arange(-0.5, 1.51, 0.1)]
measured_points_y = [i for i in np.arange(-0.5, 0.51, 0.1)]
horizontal_scale = 0.025 # [m]
vertical_scale = 0.005 # [m]
border_size = 5 # [m]
curriculum = False
static_friction = 1.0
dynamic_friction = 1.0
restitution = 0.
max_init_terrain_level = 5 # starting curriculum state
terrain_length = 4.
terrain_width = 4.
num_rows= 16 # number of terrain rows (levels)
num_cols = 16 # number of terrain cols (types)
slope_treshold = 1.
TerrainPerlin_kwargs = dict(
zScale= 0.07,
frequency= 10,
)
class commands( LeggedRobotCfg.commands ):
heading_command = False
resampling_time = 5 # [s]
lin_cmd_cutoff = 0.2
ang_cmd_cutoff = 0.2
class ranges( LeggedRobotCfg.commands.ranges ):
lin_vel_x = [-1.0, 1.5]
lin_vel_y = [-1., 1.]
ang_vel_yaw = [-2., 2.]
class init_state( LeggedRobotCfg.init_state ):
pos = [0., 0., 0.5] # [m]
default_joint_angles = { # 12 joints in the order of simulation
"FL_hip_joint": 0.1,
"FL_thigh_joint": 0.7,
"FL_calf_joint": -1.5,
"FR_hip_joint": -0.1,
"FR_thigh_joint": 0.7,
"FR_calf_joint": -1.5,
"RL_hip_joint": 0.1,
"RL_thigh_joint": 1.0,
"RL_calf_joint": -1.5,
"RR_hip_joint": -0.1,
"RR_thigh_joint": 1.0,
"RR_calf_joint": -1.5,
}
class control( LeggedRobotCfg.control ):
stiffness = {'joint': 40.}
damping = {'joint': 1.}
action_scale = go2_action_scale
computer_clip_torque = False
motor_clip_torque = True
class asset( LeggedRobotCfg.asset ):
file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/go2/urdf/go2.urdf'
name = "go2"
foot_name = "foot"
front_hip_names = ["FL_hip_joint", "FR_hip_joint"]
rear_hip_names = ["RL_hip_joint", "RR_hip_joint"]
penalize_contacts_on = ["thigh", "calf"]
terminate_after_contacts_on = ["base"]
sdk_dof_range = go2_const_dof_range
dof_velocity_override = 35.
class termination:
termination_terms = [
"roll",
"pitch",
]
roll_kwargs = dict(
threshold= 3.0, # [rad]
)
pitch_kwargs = dict(
threshold= 3.0, # [rad] # for leap, jump
)
class domain_rand( LeggedRobotCfg.domain_rand ):
randomize_com = True
class com_range:
x = [-0.2, 0.2]
y = [-0.1, 0.1]
z = [-0.05, 0.05]
randomize_motor = True
leg_motor_strength_range = [0.8, 1.2]
randomize_base_mass = True
added_mass_range = [1.0, 3.0]
randomize_friction = True
friction_range = [0., 2.]
init_base_pos_range = dict(
x= [0.05, 0.6],
y= [-0.25, 0.25],
)
init_base_rot_range = dict(
roll= [-0.75, 0.75],
pitch= [-0.75, 0.75],
)
init_base_vel_range = dict(
x= [-0.2, 1.5],
y= [-0.2, 0.2],
z= [-0.2, 0.2],
roll= [-1., 1.],
pitch= [-1., 1.],
yaw= [-1., 1.],
)
init_dof_vel_range = [-5, 5]
push_robots = True
max_push_vel_xy = 0.5 # [m/s]
push_interval_s = 2
class rewards( LeggedRobotCfg.rewards ):
class scales:
tracking_lin_vel = 1.
tracking_ang_vel = 1.
energy_substeps = -2e-5
stand_still = -2.
dof_error_named = -1.
dof_error = -0.01
# penalty for hardware safety
exceed_dof_pos_limits = -0.4
exceed_torque_limits_l1norm = -0.4
dof_vel_limits = -0.4
dof_error_names = ["FL_hip_joint", "FR_hip_joint", "RL_hip_joint", "RR_hip_joint"]
only_positive_rewards = False
soft_dof_vel_limit = 0.9
soft_dof_pos_limit = 0.9
soft_torque_limit = 0.9
class normalization( LeggedRobotCfg.normalization ):
class obs_scales( LeggedRobotCfg.normalization.obs_scales ):
lin_vel = 1.
height_measurements_offset = -0.2
clip_actions_method = None # let the policy learn not to exceed the limits
class noise( LeggedRobotCfg.noise ):
add_noise = False
class viewer( LeggedRobotCfg.viewer ):
pos = [-1., -1., 0.4]
lookat = [0., 0., 0.3]
class sim( LeggedRobotCfg.sim ):
body_measure_points = { # transform are related to body frame
"base": dict(
x= [i for i in np.arange(-0.24, 0.41, 0.03)],
y= [-0.08, -0.04, 0.0, 0.04, 0.08],
z= [i for i in np.arange(-0.061, 0.071, 0.03)],
transform= [0., 0., 0.005, 0., 0., 0.],
),
"thigh": dict(
x= [
-0.16, -0.158, -0.156, -0.154, -0.152,
-0.15, -0.145, -0.14, -0.135, -0.13, -0.125, -0.12, -0.115, -0.11, -0.105, -0.1, -0.095, -0.09, -0.085, -0.08, -0.075, -0.07, -0.065, -0.05,
0.0, 0.05, 0.1,
],
y= [-0.015, -0.01, 0.0, -0.01, 0.015],
z= [-0.03, -0.015, 0.0, 0.015],
transform= [0., 0., -0.1, 0., 1.57079632679, 0.],
),
"calf": dict(
x= [i for i in np.arange(-0.13, 0.111, 0.03)],
y= [-0.015, 0.0, 0.015],
z= [-0.015, 0.0, 0.015],
transform= [0., 0., -0.11, 0., 1.57079632679, 0.],
),
}
logs_root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))), "logs")
class Go2RoughCfgPPO( LeggedRobotCfgPPO ):
class algorithm( LeggedRobotCfgPPO.algorithm ):
entropy_coef = 0.01
clip_min_std = 0.2
learning_rate = 1e-4
optimizer_class_name = "AdamW"
class policy( LeggedRobotCfgPPO.policy ):
# configs for estimator module
estimator_obs_components = [
"ang_vel",
"projected_gravity",
"commands",
"dof_pos",
"dof_vel",
"last_actions",
]
estimator_target_components = ["lin_vel"]
replace_state_prob = 1.0
class estimator_kwargs:
hidden_sizes = [128, 64]
nonlinearity = "CELU"
# configs for (critic) encoder
encoder_component_names = ["height_measurements"]
encoder_class_name = "MlpModel"
class encoder_kwargs:
hidden_sizes = [128, 64]
nonlinearity = "CELU"
encoder_output_size = 32
critic_encoder_component_names = ["height_measurements"]
init_noise_std = 0.5
# configs for policy: using recurrent policy with GRU
rnn_type = 'gru'
mu_activation = None
class runner( LeggedRobotCfgPPO.runner ):
policy_class_name = "EncoderStateAcRecurrent"
algorithm_class_name = "EstimatorPPO"
experiment_name = "rough_go2"
resume = False
load_run = None
run_name = "".join(["Go2Rough",
("_pEnergy" + np.format_float_scientific(Go2RoughCfg.rewards.scales.energy_substeps, precision= 1, trim= "-") if Go2RoughCfg.rewards.scales.energy_substeps != 0 else ""),
("_pDofErr" + np.format_float_scientific(Go2RoughCfg.rewards.scales.dof_error, precision= 1, trim= "-") if Go2RoughCfg.rewards.scales.dof_error != 0 else ""),
("_pDofErrN" + np.format_float_scientific(Go2RoughCfg.rewards.scales.dof_error_named, precision= 1, trim= "-") if Go2RoughCfg.rewards.scales.dof_error_named != 0 else ""),
("_pStand" + np.format_float_scientific(Go2RoughCfg.rewards.scales.stand_still, precision= 1, trim= "-") if Go2RoughCfg.rewards.scales.stand_still != 0 else ""),
("_noResume" if not resume else "_from" + "_".join(load_run.split("/")[-1].split("_")[:2])),
])
max_iterations = 2000
save_interval = 2000
log_interval = 100

View File

@@ -0,0 +1,233 @@
""" Config to train the whole parkour oracle policy """
import numpy as np
from os import path as osp
from collections import OrderedDict
from datetime import datetime
from legged_gym.utils.helpers import merge_dict
from legged_gym.envs.go2.go2_field_config import Go2FieldCfg, Go2FieldCfgPPO, Go2RoughCfgPPO
multi_process_ = True
class Go2DistillCfg( Go2FieldCfg ):
class env( Go2FieldCfg.env ):
num_envs = 256
obs_components = [
"lin_vel",
"ang_vel",
"projected_gravity",
"commands",
"dof_pos",
"dof_vel",
"last_actions",
"forward_depth",
]
privileged_obs_components = [
"lin_vel",
"ang_vel",
"projected_gravity",
"commands",
"dof_pos",
"dof_vel",
"last_actions",
"height_measurements",
]
class terrain( Go2FieldCfg.terrain ):
if multi_process_:
num_rows = 4
num_cols = 1
curriculum = False
BarrierTrack_kwargs = merge_dict(Go2FieldCfg.terrain.BarrierTrack_kwargs, dict(
leap= dict(
length= [0.05, 0.8],
depth= [0.5, 0.8],
height= 0.15, # expected leap height over the gap
fake_offset= 0.1,
),
))
class sensor( Go2FieldCfg.sensor ):
class forward_camera:
obs_components = ["forward_depth"]
resolution = [int(480/4), int(640/4)]
position = dict(
mean= [0.24, -0.0175, 0.12],
std= [0.01, 0.0025, 0.03],
)
rotation = dict(
lower= [-0.1, 0.37, -0.1],
upper= [0.1, 0.43, 0.1],
)
resized_resolution = [48, 64]
output_resolution = [48, 64]
horizontal_fov = [86, 90]
crop_top_bottom = [int(48/4), 0]
crop_left_right = [int(28/4), int(36/4)]
near_plane = 0.05
depth_range = [0.0, 3.0]
latency_range = [0.08, 0.142]
latency_resample_time = 5.0
refresh_duration = 1/10 # [s]
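# simulated depth pipeline, as implied by the fields above: frames are rendered at
# 120x160 (480x640 / 4), cropped/resized down to output_resolution (48x64), clipped to
# depth_range, and delivered with a randomized latency drawn from latency_range,
# refreshing every refresh_duration seconds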
class commands( Go2FieldCfg.commands ):
# mixing command sampling with goal-based command updates: only the x-axis speed is
# sampled from a (high-speed) range, while the y-axis and yaw commands are left unconstrained
lin_cmd_cutoff = 0.2
class ranges( Go2FieldCfg.commands.ranges ):
# lin_vel_x = [0.6, 1.8]
lin_vel_x = [-0.6, 2.0]
is_goal_based = True
class goal_based:
# the ratios below are applied to the goal position expressed in the robot frame
x_ratio = None # sample from lin_vel_x range
y_ratio = 1.2
yaw_ratio = 0.8
follow_cmd_cutoff = True
x_stop_by_yaw_threshold = 1. # stop when yaw is over this threshold [rad]
class normalization( Go2FieldCfg.normalization ):
class obs_scales( Go2FieldCfg.normalization.obs_scales ):
forward_depth = 1.0
class noise( Go2FieldCfg.noise ):
add_noise = False
class noise_scales( Go2FieldCfg.noise.noise_scales ):
forward_depth = 0.0
### noise for simulating sensors
commands = 0.1
lin_vel = 0.1
ang_vel = 0.1
projected_gravity = 0.02
dof_pos = 0.02
dof_vel = 0.2
last_actions = 0.
### noise for simulating sensors
class forward_depth:
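# hand-crafted noise model for the simulated stereo depth sensor: block artifacts,
# spark pixels and far-distance "sky" artifacts are injected into the rendered depth
# to roughly mimic a real stereo camera (each parameter below controls one effect)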
stereo_min_distance = 0.175 # when using (480, 640) resolution
stereo_far_distance = 1.2
stereo_far_noise_std = 0.08
stereo_near_noise_std = 0.02
stereo_full_block_artifacts_prob = 0.008
stereo_full_block_values = [0.0, 0.25, 0.5, 1., 3.]
stereo_full_block_height_mean_std = [62, 1.5]
stereo_full_block_width_mean_std = [3, 0.01]
stereo_half_block_spark_prob = 0.02
stereo_half_block_value = 3000
sky_artifacts_prob = 0.0001
sky_artifacts_far_distance = 2.
sky_artifacts_values = [0.6, 1., 1.2, 1.5, 1.8]
sky_artifacts_height_mean_std = [2, 3.2]
sky_artifacts_width_mean_std = [2, 3.2]
class curriculum:
no_moveup_when_fall = False
class sim( Go2FieldCfg.sim ):
no_camera = False
logs_root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))), "logs")
class Go2DistillCfgPPO( Go2FieldCfgPPO ):
class algorithm( Go2FieldCfgPPO.algorithm ):
entropy_coef = 0.0
using_ppo = False
num_learning_epochs = 8
num_mini_batches = 2
distill_target = "l1"
learning_rate = 3e-4
optimizer_class_name = "AdamW"
teacher_act_prob = 0.
distillation_loss_coef = 1.0
# update_times_scale = 100
action_labels_from_sample = False
teacher_policy_class_name = "EncoderStateAcRecurrent"
teacher_ac_path = osp.join(logs_root, "field_go2",
"{Your trained oracle parkour model directory}",
"{The latest model filename in the directory}"
)
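# distillation setup (as implied by the flags above): using_ppo = False disables the
# PPO surrogate so only the distillation loss (L1, see distill_target) is optimized
# against the oracle policy loaded from teacher_ac_path; teacher_act_prob = 0. means
# rollouts are collected with the student's own actions, DAgger-style, not the teacher's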
class teacher_policy( Go2FieldCfgPPO.policy ):
num_actor_obs = 48 + 21 * 11
num_critic_obs = 48 + 21 * 11
num_actions = 12
obs_segments = OrderedDict([
("lin_vel", (3,)),
("ang_vel", (3,)),
("projected_gravity", (3,)),
("commands", (3,)),
("dof_pos", (12,)),
("dof_vel", (12,)),
("last_actions", (12,)), # till here: 3+3+3+3+12+12+12 = 48
("height_measurements", (1, 21, 11)),
])
class policy( Go2RoughCfgPPO.policy ):
# configs for estimator module
estimator_obs_components = [
"ang_vel",
"projected_gravity",
"commands",
"dof_pos",
"dof_vel",
"last_actions",
]
estimator_target_components = ["lin_vel"]
replace_state_prob = 1.0
class estimator_kwargs:
hidden_sizes = [128, 64]
nonlinearity = "CELU"
# configs for visual encoder
encoder_component_names = ["forward_depth"]
encoder_class_name = "Conv2dHeadModel"
class encoder_kwargs:
channels = [16, 32, 32]
kernel_sizes = [5, 4, 3]
strides = [2, 2, 1]
hidden_sizes = [128]
use_maxpool = True
nonlinearity = "LeakyReLU"
# configs for critic encoder
critic_encoder_component_names = ["height_measurements"]
critic_encoder_class_name = "MlpModel"
class critic_encoder_kwargs:
hidden_sizes = [128, 64]
nonlinearity = "CELU"
encoder_output_size = 32
init_noise_std = 0.1
if multi_process_:
runner_class_name = "TwoStageRunner"
class runner( Go2FieldCfgPPO.runner ):
policy_class_name = "EncoderStateAcRecurrent"
algorithm_class_name = "EstimatorTPPO"
experiment_name = "distill_go2"
num_steps_per_env = 32
if multi_process_:
pretrain_iterations = -1
class pretrain_dataset:
data_dir = "{A temporary directory to store collected trajectory}"
dataset_loops = -1
random_shuffle_traj_order = True
keep_latest_n_trajs = 1500
starting_frame_range = [0, 50]
resume = True
load_run = osp.join(logs_root, "field_go2",
"{Your trained oracle parkour model directory}",
)
ckpt_manipulator = "replace_encoder0" if "field_go2" in load_run else None
run_name = "".join(["Go2_",
("{:d}skills".format(len(Go2DistillCfg.terrain.BarrierTrack_kwargs["options"]))),
("_noResume" if not resume else "_from" + "_".join(load_run.split("/")[-1].split("_")[:2])),
])
max_iterations = 60000
log_interval = 100

View File

@@ -0,0 +1,229 @@
""" Config to train the whole parkour oracle policy """
import numpy as np
from os import path as osp
from collections import OrderedDict
from legged_gym.envs.go2.go2_config import Go2RoughCfg, Go2RoughCfgPPO
class Go2FieldCfg( Go2RoughCfg ):
class init_state( Go2RoughCfg.init_state ):
pos = [0.0, 0.0, 0.7]
zero_actions = False
class sensor( Go2RoughCfg.sensor):
class proprioception( Go2RoughCfg.sensor.proprioception ):
# latency_range = [0.0, 0.0]
latency_range = [0.005, 0.045] # [s]
class terrain( Go2RoughCfg.terrain ):
num_rows = 10
num_cols = 40
selected = "BarrierTrack"
slope_treshold = 20.
max_init_terrain_level = 2
curriculum = True
pad_unavailable_info = True
BarrierTrack_kwargs = dict(
options= [
"jump",
"leap",
"hurdle",
"down",
"tilted_ramp",
"stairsup",
"stairsdown",
"discrete_rect",
"slope",
"wave",
], # each race track will permute all the options
jump= dict(
height= [0.05, 0.5],
depth= [0.1, 0.3],
# fake_offset= 0.1,
),
leap= dict(
length= [0.05, 0.8],
depth= [0.5, 0.8],
height= 0.2, # expected leap height over the gap
# fake_offset= 0.1,
),
hurdle= dict(
height= [0.05, 0.5],
depth= [0.2, 0.5],
# fake_offset= 0.1,
curved_top_rate= 0.1,
),
down= dict(
height= [0.1, 0.6],
depth= [0.3, 0.5],
),
tilted_ramp= dict(
tilt_angle= [0.2, 0.5],
switch_spacing= 0.,
spacing_curriculum= False,
overlap_size= 0.2,
depth= [-0.1, 0.1],
length= [0.6, 1.2],
),
slope= dict(
slope_angle= [0.2, 0.42],
length= [1.2, 2.2],
use_mean_height_offset= True,
face_angle= [-3.14, 0, 1.57, -1.57],
no_perlin_rate= 0.2,
length_curriculum= True,
),
slopeup= dict(
slope_angle= [0.2, 0.42],
length= [1.2, 2.2],
use_mean_height_offset= True,
face_angle= [-0.2, 0.2],
no_perlin_rate= 0.2,
length_curriculum= True,
),
slopedown= dict(
slope_angle= [0.2, 0.42],
length= [1.2, 2.2],
use_mean_height_offset= True,
face_angle= [-0.2, 0.2],
no_perlin_rate= 0.2,
length_curriculum= True,
),
stairsup= dict(
height= [0.1, 0.3],
length= [0.3, 0.5],
residual_distance= 0.05,
num_steps= [3, 19],
num_steps_curriculum= True,
),
stairsdown= dict(
height= [0.1, 0.3],
length= [0.3, 0.5],
num_steps= [3, 19],
num_steps_curriculum= True,
),
discrete_rect= dict(
max_height= [0.05, 0.2],
max_size= 0.6,
min_size= 0.2,
num_rects= 10,
),
wave= dict(
amplitude= [0.1, 0.15], # in meter
frequency= [0.6, 1.0], # in 1/meter
),
track_width= 3.2,
track_block_length= 2.4,
wall_thickness= (0.01, 0.6),
wall_height= [-0.5, 2.0],
add_perlin_noise= True,
border_perlin_noise= True,
border_height= 0.,
virtual_terrain= False,
draw_virtual_terrain= True,
engaging_next_threshold= 0.8,
engaging_finish_threshold= 0.,
curriculum_perlin= False,
no_perlin_threshold= 0.1,
randomize_obstacle_order= True,
n_obstacles_per_track= 1,
)
class commands( Go2RoughCfg.commands ):
# mixing command sampling with goal-based command updates: only the x-axis speed is
# sampled from a (high-speed) range, while the y-axis and yaw commands are left unconstrained
lin_cmd_cutoff = 0.2
class ranges( Go2RoughCfg.commands.ranges ):
# lin_vel_x = [0.6, 1.8]
lin_vel_x = [-0.6, 2.0]
is_goal_based = True
class goal_based:
# the ratios below are applied to the goal position expressed in the robot frame
x_ratio = None # sample from lin_vel_x range
y_ratio = 1.2
yaw_ratio = 1.
follow_cmd_cutoff = True
x_stop_by_yaw_threshold = 1. # stop when yaw is over this threshold [rad]
class asset( Go2RoughCfg.asset ):
terminate_after_contacts_on = []
penalize_contacts_on = ["thigh", "calf", "base"]
class termination( Go2RoughCfg.termination ):
roll_kwargs = dict(
threshold= 1.4, # [rad]
)
pitch_kwargs = dict(
threshold= 1.6, # [rad]
)
timeout_at_border = True
timeout_at_finished = False
class rewards( Go2RoughCfg.rewards ):
class scales:
tracking_lin_vel = 1.
tracking_ang_vel = 1.
energy_substeps = -2e-7
torques = -1e-7
stand_still = -1.
dof_error_named = -1.
dof_error = -0.005
collision = -0.05
lazy_stop = -3.
# penalty for hardware safety
exceed_dof_pos_limits = -0.1
exceed_torque_limits_l1norm = -0.1
# penetration penalty
penetrate_depth = -0.05
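# (interpretation) penetrate_depth penalizes how deep the robot's volume sample points
# sink into the virtual obstacles, the soft-dynamics constraint of the parkour recipe;
# the curriculum thresholds below use the accumulated penetration to move envs to
# harder or easier terrain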
class noise( Go2RoughCfg.noise ):
add_noise = False
class curriculum:
penetrate_depth_threshold_harder = 100
penetrate_depth_threshold_easier = 200
no_moveup_when_fall = True
logs_root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))), "logs")
class Go2FieldCfgPPO( Go2RoughCfgPPO ):
class algorithm( Go2RoughCfgPPO.algorithm ):
entropy_coef = 0.0
class runner( Go2RoughCfgPPO.runner ):
experiment_name = "field_go2"
resume = True
load_run = osp.join(logs_root, "rough_go2",
"{Your trained walking model directory}",
)
run_name = "".join(["Go2_",
("{:d}skills".format(len(Go2FieldCfg.terrain.BarrierTrack_kwargs["options"]))),
("_pEnergy" + np.format_float_scientific(-Go2FieldCfg.rewards.scales.energy_substeps, precision=2)),
# ("_pDofErr" + np.format_float_scientific(-Go2FieldCfg.rewards.scales.dof_error, precision=2) if getattr(Go2FieldCfg.rewards.scales, "dof_error", 0.) != 0. else ""),
# ("_pHipDofErr" + np.format_float_scientific(-Go2FieldCfg.rewards.scales.dof_error_named, precision=2) if getattr(Go2FieldCfg.rewards.scales, "dof_error_named", 0.) != 0. else ""),
# ("_pStand" + np.format_float_scientific(Go2FieldCfg.rewards.scales.stand_still, precision=2)),
# ("_pTerm" + np.format_float_scientific(Go2FieldCfg.rewards.scales.termination, precision=2) if hasattr(Go2FieldCfg.rewards.scales, "termination") else ""),
("_pTorques" + np.format_float_scientific(Go2FieldCfg.rewards.scales.torques, precision=2) if hasattr(Go2FieldCfg.rewards.scales, "torques") else ""),
# ("_pColl" + np.format_float_scientific(Go2FieldCfg.rewards.scales.collision, precision=2) if hasattr(Go2FieldCfg.rewards.scales, "collision") else ""),
("_pLazyStop" + np.format_float_scientific(Go2FieldCfg.rewards.scales.lazy_stop, precision=2) if hasattr(Go2FieldCfg.rewards.scales, "lazy_stop") else ""),
# ("_trackSigma" + np.format_float_scientific(Go2FieldCfg.rewards.tracking_sigma, precision=2) if Go2FieldCfg.rewards.tracking_sigma != 0.25 else ""),
# ("_pPenV" + np.format_float_scientific(-Go2FieldCfg.rewards.scales.penetrate_volume, precision=2)),
("_pPenD" + np.format_float_scientific(-Go2FieldCfg.rewards.scales.penetrate_depth, precision=2)),
# ("_pTorqueL1" + np.format_float_scientific(-Go2FieldCfg.rewards.scales.exceed_torque_limits_l1norm, precision=2)),
("_penEasier{:d}".format(Go2FieldCfg.curriculum.penetrate_depth_threshold_easier)),
("_penHarder{:d}".format(Go2FieldCfg.curriculum.penetrate_depth_threshold_harder)),
# ("_leapMin" + np.format_float_scientific(Go2FieldCfg.terrain.BarrierTrack_kwargs["leap"]["length"][0], precision=2)),
("_leapHeight" + np.format_float_scientific(Go2FieldCfg.terrain.BarrierTrack_kwargs["leap"]["height"], precision=2)),
("_motorTorqueClip" if Go2FieldCfg.control.motor_clip_torque else ""),
# ("_noMoveupWhenFall" if Go2FieldCfg.curriculum.no_moveup_when_fall else ""),
("_noResume" if not resume else "_from" + "_".join(load_run.split("/")[-1].split("_")[:2])),
])
max_iterations = 38000
save_interval = 10000
log_interval = 100

View File

@@ -8,8 +8,10 @@ import os
import json
import os.path as osp
from legged_gym import LEGGED_GYM_ROOT_DIR
from legged_gym.envs import *
from legged_gym.utils import get_args, task_registry
from legged_gym.utils import get_args
from legged_gym.utils.task_registry import task_registry
from legged_gym.utils.helpers import update_cfg_from_args, class_to_dict, update_class_from_dict
from legged_gym.debugger import break_into_debugger
@@ -17,55 +19,22 @@ from rsl_rl.modules import build_actor_critic
from rsl_rl.runners.dagger_saver import DemonstrationSaver, DaggerSaver
def main(args):
RunnerCls = DaggerSaver
# RunnerCls = DemonstrationSaver
RunnerCls = DaggerSaver if args.load_run else DemonstrationSaver
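# (assumption from the class names) DemonstrationSaver rolls out only the teacher policy
# to seed the dataset, while DaggerSaver also loads the student run given by --load_run
# and collects data under the student's actions, so passing --load_run switches to
# DAgger-style collection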
success_traj_only = False
teacher_act_prob = 0.1
env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
if RunnerCls == DaggerSaver:
with open(os.path.join("logs", train_cfg.runner.experiment_name, args.load_run, "config.json"), "r") as f:
d = json.load(f, object_pairs_hook= OrderedDict)
update_class_from_dict(env_cfg, d, strict= True)
update_class_from_dict(train_cfg, d, strict= True)
# Some custom settings
# ####### customized option to increase data distribution #######
action_sample_std = 0.0
####### customized option to increase data distribution #######
# env_cfg.env.num_envs = 6
# env_cfg.terrain.curriculum = True
# env_cfg.terrain.max_init_terrain_level = 0
# env_cfg.terrain.border_size = 1.
############# some predefined options #############
if len(env_cfg.terrain.BarrierTrack_kwargs["options"]) == 1:
env_cfg.terrain.num_rows = 20; env_cfg.terrain.num_cols = 30
else: # for parkour env
# >>> option 1
env_cfg.terrain.BarrierTrack_kwargs["track_block_length"] = 2.8
env_cfg.terrain.BarrierTrack_kwargs["track_width"] = 2.4
env_cfg.terrain.BarrierTrack_kwargs["wall_thickness"] = (0.0, 0.6)
env_cfg.domain_rand.init_base_pos_range["x"] = (0.4, 1.8)
env_cfg.terrain.num_rows = 12; env_cfg.terrain.num_cols = 10
# >>> option 2
# env_cfg.terrain.BarrierTrack_kwargs["track_block_length"] = 3.
# env_cfg.terrain.BarrierTrack_kwargs["track_width"] = 4.0
# env_cfg.terrain.BarrierTrack_kwargs["wall_height"] = (-0.5, -0.2)
# env_cfg.terrain.BarrierTrack_kwargs["wall_thickness"] = (0.0, 1.4)
# env_cfg.domain_rand.init_base_pos_range["x"] = (1.6, 2.0)
# env_cfg.terrain.num_rows = 16; env_cfg.terrain.num_cols = 5
# >>> option 3
# env_cfg.terrain.BarrierTrack_kwargs["track_block_length"] = 1.6
# env_cfg.terrain.BarrierTrack_kwargs["track_width"] = 2.2
# env_cfg.terrain.BarrierTrack_kwargs["wall_height"] = (-0.5, 0.1)
# env_cfg.terrain.BarrierTrack_kwargs["wall_thickness"] = (0.0, 0.5)
# env_cfg.domain_rand.init_base_pos_range["x"] = (0.2, 0.9)
# env_cfg.terrain.BarrierTrack_kwargs["n_obstacles_per_track"] = 1
# action_sample_std = 0.1
# env_cfg.terrain.num_rows = 22; env_cfg.terrain.num_cols = 16
pass
if (env_cfg.terrain.BarrierTrack_kwargs["options"][0] == "leap") and all(i == env_cfg.terrain.BarrierTrack_kwargs["options"][0] for i in env_cfg.terrain.BarrierTrack_kwargs["options"]):
######### For leap, because the platform is usually higher than the ground.
env_cfg.terrain.num_rows = 80
env_cfg.terrain.num_cols = 1
env_cfg.terrain.BarrierTrack_kwargs["track_width"] = 1.6
env_cfg.terrain.BarrierTrack_kwargs["wall_thickness"] = (0.01, 0.5)
env_cfg.terrain.BarrierTrack_kwargs["wall_height"] = (-0.4, 0.2) # randomize incase of terrain that have side walls
env_cfg.terrain.BarrierTrack_kwargs["border_height"] = -0.4
env_cfg.terrain.num_rows = 8; env_cfg.terrain.num_cols = 40
# Done custom settings
env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
@@ -73,82 +42,81 @@ def main(args):
config = class_to_dict(train_cfg)
config.update(class_to_dict(env_cfg))
teacher_act_prob = config["algorithm"]["teacher_act_prob"] if args.teacher_prob is None else args.teacher_prob
action_std = config["policy"]["init_noise_std"] if args.action_std is None else args.action_std
# create teacher policy
policy = build_actor_critic(
env,
train_cfg.algorithm.teacher_policy_class_name,
config["algorithm"]["teacher_policy_class_name"],
config["algorithm"]["teacher_policy"],
).to(env.device)
# load the policy if possible
if train_cfg.algorithm.teacher_ac_path is not None:
state_dict = torch.load(train_cfg.algorithm.teacher_ac_path, map_location= "cpu")
if config["algorithm"]["teacher_ac_path"] is not None:
if "{LEGGED_GYM_ROOT_DIR}" in config["algorithm"]["teacher_ac_path"]:
config["algorithm"]["teacher_ac_path"] = config["algorithm"]["teacher_ac_path"].format(LEGGED_GYM_ROOT_DIR= LEGGED_GYM_ROOT_DIR)
state_dict = torch.load(config["algorithm"]["teacher_ac_path"], map_location= "cpu")
teacher_actor_critic_state_dict = state_dict["model_state_dict"]
policy.load_state_dict(teacher_actor_critic_state_dict)
# build runner
track_header = "".join(env_cfg.terrain.BarrierTrack_kwargs["options"])
if env_cfg.commands.ranges.lin_vel_x[1] > 0.0:
cmd_vel = "_cmd{:.1f}-{:.1f}".format(env_cfg.commands.ranges.lin_vel_x[0], env_cfg.commands.ranges.lin_vel_x[1])
elif env_cfg.commands.ranges.lin_vel_x[1] == 0. and len(env_cfg.terrain.BarrierTrack_kwargs["options"]) == 1 \
or (env_cfg.terrain.BarrierTrack_kwargs["options"][0] == env_cfg.terrain.BarrierTrack_kwargs["options"][1]):
obstacle_id = env.terrain.track_options_id_dict[env_cfg.terrain.BarrierTrack_kwargs["options"][0]]
try:
overrided_vel = train_cfg.algorithm.teacher_policy.cmd_vel_mapping[obstacle_id]
except:
overrided_vel = train_cfg.algorithm.teacher_policy.cmd_vel_mapping[str(obstacle_id)]
cmd_vel = "_cmdOverride{:.1f}".format(overrided_vel)
else:
cmd_vel = "_cmdMutex"
datadir = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))), "logs")
runner_kwargs = dict(
env= env,
policy= policy,
save_dir= osp.join(
train_cfg.runner.pretrain_dataset.scan_dir if RunnerCls == DaggerSaver else osp.join(datadir, "distill_{}_dagger".format(args.task.split("_")[0])),
config["runner"]["pretrain_dataset"]["data_dir"] if RunnerCls == DaggerSaver else osp.join("/localdata_ssd/zzw/athletic-isaac_tmp", "{}_dagger".format(config["runner"]["experiment_name"])),
datetime.now().strftime('%b%d_%H-%M-%S') + "_" + "".join([
track_header,
cmd_vel,
"_lowBorder" if env_cfg.terrain.BarrierTrack_kwargs["border_height"] < 0 else "",
"_trackWidth{:.1f}".format(env_cfg.terrain.BarrierTrack_kwargs["track_width"]) if env_cfg.terrain.BarrierTrack_kwargs["track_width"] < 1.8 else "",
"_blockLength{:.1f}".format(env_cfg.terrain.BarrierTrack_kwargs["track_block_length"]) if env_cfg.terrain.BarrierTrack_kwargs["track_block_length"] > 1.6 else "",
"_addMassMin{:.1f}".format(env_cfg.domain_rand.added_mass_range[0]) if env_cfg.domain_rand.added_mass_range[0] > 1. else "",
"_comMean{:.2f}".format((env_cfg.domain_rand.com_range.x[0] + env_cfg.domain_rand.com_range.x[1])/2),
"_1cols" if env_cfg.terrain.num_cols == 1 else "",
"_teacherProb{:.1f}".format(teacher_act_prob),
"_randOrder" if env_cfg.terrain.BarrierTrack_kwargs.get("randomize_obstacle_order", False) else "",
("_noPerlinRate{:.1f}".format(
(env_cfg.terrain.BarrierTrack_kwargs["no_perlin_threshold"] - env_cfg.terrain.TerrainPerlin_kwargs["zScale"][0]) / \
(env_cfg.terrain.TerrainPerlin_kwargs["zScale"][1] - env_cfg.terrain.TerrainPerlin_kwargs["zScale"][0])
)),
) if isinstance(env_cfg.terrain.TerrainPerlin_kwargs["zScale"], (tuple, list)) else ""),
("_fric{:.1f}-{:.1f}".format(*env_cfg.domain_rand.friction_range)),
"_successOnly" if success_traj_only else "",
"_aStd{:.2f}".format(action_sample_std) if (action_sample_std > 0. and RunnerCls == DaggerSaver) else "",
"_aStd{:.2f}".format(action_std) if (action_std > 0. and RunnerCls == DaggerSaver) else "",
] + ([] if RunnerCls == DemonstrationSaver else ["_" + "_".join(args.load_run.split("_")[:2])])
),
),
rollout_storage_length= 256,
min_timesteps= 1e6, # 1e6,
min_episodes= 2e4 if RunnerCls == DaggerSaver else 2e-3,
min_timesteps= 1e9, # 1e6,
min_episodes= 1e6 if RunnerCls == DaggerSaver else 2e5,
use_critic_obs= True,
success_traj_only= success_traj_only,
obs_disassemble_mapping= dict(
forward_depth= "normalized_image",
),
demo_by_sample= config["algorithm"].get("action_labels_from_sample", False),
)
if RunnerCls == DaggerSaver:
# kwargs for dagger saver
runner_kwargs.update(dict(
training_policy_logdir= osp.join(
"logs",
train_cfg.runner.experiment_name,
config["runner"]["experiment_name"],
args.load_run,
),
teacher_act_prob= teacher_act_prob,
update_times_scale= config["algorithm"]["update_times_scale"],
action_sample_std= action_sample_std,
update_times_scale= config["algorithm"].get("update_times_scale", 1e5),
action_sample_std= action_std,
log_to_tensorboard= args.log,
))
runner = RunnerCls(**runner_kwargs)
runner.collect_and_save(config= config)
if __name__ == "__main__":
args = get_args()
args = get_args(
custom_args= [
{"name": "--teacher_prob", "type": float, "default": None, "help": "probability of using teacher's action"},
{"name": "--action_std", "type": float, "default": None, "help": "override the action sample std during rollout. None for using model's std"},
{"name": "--log", "action": "store_true", "help": "log the data to tensorboard"},
],
)
main(args)

View File

@@ -39,7 +39,8 @@ import isaacgym
from isaacgym import gymtorch, gymapi
from isaacgym.torch_utils import *
from legged_gym.envs import *
from legged_gym.utils import get_args, export_policy_as_jit, task_registry, Logger
from legged_gym.utils import get_args, export_policy_as_jit, Logger
from legged_gym.utils.task_registry import task_registry
from legged_gym.utils.helpers import update_class_from_dict
from legged_gym.utils.observation import get_obs_slice
from legged_gym.debugger import break_into_debugger
@@ -78,18 +79,30 @@ def create_recording_camera(gym, env_handle,
@torch.no_grad()
def play(args):
env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
with open(os.path.join("logs", train_cfg.runner.experiment_name, args.load_run, "config.json"), "r") as f:
d = json.load(f, object_pairs_hook= OrderedDict)
update_class_from_dict(env_cfg, d, strict= True)
update_class_from_dict(train_cfg, d, strict= True)
if args.load_cfg:
with open(os.path.join("logs", train_cfg.runner.experiment_name, args.load_run, "config.json"), "r") as f:
d = json.load(f, object_pairs_hook= OrderedDict)
update_class_from_dict(env_cfg, d, strict= True)
update_class_from_dict(train_cfg, d, strict= True)
# override some parameters for testing
if env_cfg.terrain.selected == "BarrierTrack":
env_cfg.env.num_envs = min(env_cfg.env.num_envs, 1)
env_cfg.env.episode_length_s = 20
env_cfg.terrain.max_init_terrain_level = 0
env_cfg.terrain.num_rows = 1
env_cfg.terrain.num_cols = 1
env_cfg.terrain.num_rows = 4
env_cfg.terrain.num_cols = 8
env_cfg.terrain.BarrierTrack_kwargs["options"] = [
"jump",
"leap",
"down",
"hurdle",
"tilted_ramp",
"stairsup",
"discrete_rect",
"wave",
]
env_cfg.terrain.BarrierTrack_kwargs["leap"]["fake_offset"] = 0.1
else:
env_cfg.env.num_envs = min(env_cfg.env.num_envs, 1)
env_cfg.env.episode_length_s = 60
@@ -98,67 +111,43 @@ def play(args):
env_cfg.terrain.max_init_terrain_level = 0
env_cfg.terrain.num_rows = 1
env_cfg.terrain.num_cols = 1
env_cfg.terrain.curriculum = False
env_cfg.terrain.BarrierTrack_kwargs["options"] = [
# "crawl",
"jump",
# "leap",
# "tilt",
]
if "one_obstacle_per_track" in env_cfg.terrain.BarrierTrack_kwargs.keys():
env_cfg.terrain.BarrierTrack_kwargs.pop("one_obstacle_per_track")
env_cfg.terrain.BarrierTrack_kwargs["n_obstacles_per_track"] = 2
# env_cfg.terrain.curriculum = False
# env_cfg.asset.fix_base_link = True
env_cfg.env.episode_length_s = 1000
env_cfg.commands.resampling_time = int(1e16)
env_cfg.commands.ranges.lin_vel_x = [1.2, 1.2]
if "distill" in args.task:
env_cfg.commands.ranges.lin_vel_x = [0.0, 0.0]
env_cfg.commands.ranges.lin_vel_y = [-0., 0.]
env_cfg.commands.ranges.ang_vel_yaw = [-0., 0.]
env_cfg.domain_rand.push_robots = False
env_cfg.domain_rand.init_base_pos_range = dict(
x= [0.6, 0.6],
y= [-0.05, 0.05],
)
env_cfg.termination.termination_terms = []
# env_cfg.termination.termination_terms = []
env_cfg.termination.timeout_at_border = False
env_cfg.termination.timeout_at_finished = False
env_cfg.viewer.debug_viz = False # in a1_distill, setting this to true will constantly show the egocentric depth view.
env_cfg.viewer.debug_viz = True
env_cfg.viewer.draw_measure_heights = False
env_cfg.viewer.draw_height_measurements = False
env_cfg.viewer.draw_volume_sample_points = False
train_cfg.runner.resume = True
env_cfg.viewer.draw_sensors = False
if hasattr(env_cfg.terrain, "BarrierTrack_kwargs"):
env_cfg.terrain.BarrierTrack_kwargs["draw_virtual_terrain"] = True
# train_cfg.runner.resume = (args.load_run is not None)
train_cfg.runner_class_name = "OnPolicyRunner"
if "distill" in args.task: # to save the memory
train_cfg.algorithm.teacher_policy.sub_policy_paths = []
train_cfg.algorithm.teacher_policy_class_name = "ActorCritic"
train_cfg.algorithm.teacher_policy = dict(
num_actor_obs= 48,
num_critic_obs= 48,
num_actions= 12,
)
if args.no_throw:
env_cfg.init_state.pos[2] = 0.4
env_cfg.domain_rand.init_base_pos_range["x"] = [0.4, 0.4]
env_cfg.domain_rand.init_base_vel_range = [0., 0.]
env_cfg.domain_rand.init_dof_vel_range = [0., 0.]
env_cfg.domain_rand.init_base_rot_range["roll"] = [0., 0.]
env_cfg.domain_rand.init_base_rot_range["pitch"] = [0., 0.]
env_cfg.domain_rand.init_base_rot_range["yaw"] = [0., 0.]
env_cfg.domain_rand.init_base_vel_range = [0., 0.]
env_cfg.domain_rand.init_dof_pos_ratio_range = [1., 1.]
######### Some hacks to run ActorCriticMutex policy ##########
if False: # for a1
train_cfg.runner.policy_class_name = "ActorCriticClimbMutex"
train_cfg.policy.sub_policy_class_name = "ActorCriticRecurrent"
logs_root = "logs"
train_cfg.policy.sub_policy_paths = [ # must in the order of obstacle ID
logs_root + "/field_a1_oracle/Jun03_00-01-38_SkillsPlaneWalking_pEnergySubsteps1e-5_rAlive2_pTorqueExceedIndicate1e+1_noCurriculum_propDelay0.04-0.05_noPerlinRate-2.0_nSubsteps4_envFreq50.0_aScale244",
logs_root + "/field_a1_oracle/Aug08_05-22-52_Skills_tilt_pEnergySubsteps1e-5_rAlive1_pPenV5e-3_pPenD5e-3_pPosY0.50_pYaw0.50_rTilt7e-1_pTorqueExceedIndicate1e-1_virtualTerrain_propDelay0.04-0.05_push/",
logs_root + "/field_a1_oracle/May21_05-25-19_Skills_crawl_pEnergy2e-5_rAlive1_pPenV6e-2_pPenD6e-2_pPosY0.2_kp50_noContactTerminate_aScale0.5/",
logs_root + "/field_a1_oracle/Jun03_00-33-08_Skills_climb_pEnergySubsteps2e-6_rAlive2_pTorqueExceedIndicate2e-1_propDelay0.04-0.05_noPerlinRate0.2_nSubsteps4_envFreq50.0_aScale0.5",
logs_root + "/field_a1_oracle/Jun04_01-03-59_Skills_leap_pEnergySubsteps2e-6_rAlive2_pPenV4e-3_pPenD4e-3_pPosY0.20_pYaw0.20_pTorqueExceedSquare1e-3_leapH0.2_propDelay0.04-0.05_noPerlinRate0.2_aScale0.5",
]
train_cfg.policy.jump_down_policy_path = logs_root + "/field_a1_oracle/Aug30_16-12-14_Skills_climb_climbDownProb0.5_propDelay0.04-0.05"
train_cfg.runner.resume = False
env_cfg.env.use_lin_vel = True
train_cfg.policy.cmd_vel_mapping = {
0: 1.0,
1: 0.5,
2: 0.8,
3: 1.2,
4: 1.5,
}
if args.task == "a1_distill":
env_cfg.env.obs_components = env_cfg.env.privileged_obs_components
env_cfg.env.privileged_obs_gets_privilege = False
# default camera position
# env_cfg.viewer.lookat = [0.6, 1.2, 0.5]
# env_cfg.viewer.pos = [0.6, 0., 0.5]
# prepare environment
env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
@@ -168,9 +157,8 @@ def play(args):
critic_obs = env.get_privileged_observations()
# register debugging options to manually trigger disruption
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_P, "push_robot")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_L, "press_robot")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_J, "action_jitter")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_Q, "exit")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_ESCAPE, "exit")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_R, "agent_full_reset")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_U, "full_reset")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_C, "resample_commands")
@@ -180,7 +168,24 @@ def play(args):
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_D, "rightward")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_F, "leftturn")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_G, "rightturn")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_B, "leftdrag")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_M, "rightdrag")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_X, "stop")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_K, "mark")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_I, "more_plots")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_T, "switch_teacher")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_O, "lean_fronter")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_L, "lean_backer")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_COMMA, "lean_lefter")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_PERIOD, "lean_righter")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_DOWN, "slower")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_UP, "faster")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_LEFT, "lefter")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_RIGHT, "righter")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_LEFT_BRACKET, "terrain_left")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_RIGHT_BRACKET, "terrain_right")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_MINUS, "terrain_back")
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_EQUAL, "terrain_forward")
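# summary of the extra bindings handled below: UP/DOWN adjust the x-velocity command,
# LEFT/RIGHT toggle the yaw command, O/L and ,/. adjust the orientation (lean) command,
# [ ] move the robot across terrain columns and - = across terrain rows, X zeros all
# commands, K plots the logged states, I opens the extra joint velocity/torque plots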
# load policy
ppo_runner, train_cfg = task_registry.make_alg_runner(
env=env,
@@ -200,6 +205,7 @@ def play(args):
export_policy_as_jit(ppo_runner.alg.actor_critic, path)
print('Exported policy as jit script to: ', path)
if RECORD_FRAMES:
os.makedirs(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", args.frames_dir), exist_ok= True)
transform = gymapi.Transform()
transform.p = gymapi.Vec3(*env_cfg.viewer.pos)
transform.r = gymapi.Quat.from_euler_zyx(0., 0., -np.pi/2)
@@ -212,7 +218,7 @@ def play(args):
logger = Logger(env.dt)
robot_index = 0 # which robot is used for logging
joint_index = 4 # which joint is used for logging
stop_state_log = 512 # number of steps before plotting states
stop_state_log = args.plot_time # number of steps before plotting states
stop_rew_log = env.max_episode_length + 1 # number of steps before print average episode rewards
camera_position = np.array(env_cfg.viewer.pos, dtype=np.float64)
camera_vel = np.array([0.6, 0., 0.])
@@ -248,12 +254,9 @@ def play(args):
if ui_event.action == "push_robot" and ui_event.value > 0:
# manually trigger a push on the robot
env._push_robots()
if ui_event.action == "press_robot" and ui_event.value > 0:
env.root_states[:, 9] = torch_rand_float(-env.cfg.domain_rand.max_push_vel_xy, 0, (env.num_envs, 1), device=env.device).squeeze(1)
env.gym.set_actor_root_state_tensor(env.sim, gymtorch.unwrap_tensor(env.all_root_states))
if ui_event.action == "action_jitter" and ui_event.value > 0:
# assuming wrong action is taken
obs, critic_obs, rews, dones, infos = env.step(torch.tanh(torch.randn_like(actions)))
obs, critic_obs, rews, dones, infos = env.step(actions + torch.randn_like(actions) * 0.2)
if ui_event.action == "exit" and ui_event.value > 0:
print("exit")
exit(0)
@@ -263,24 +266,169 @@ def play(args):
if ui_event.action == "full_reset" and ui_event.value > 0:
print("full_reset")
agent_model.reset()
if hasattr(ppo_runner.alg, "teacher_actor_critic"):
ppo_runner.alg.teacher_actor_critic.reset()
# print(env._get_terrain_curriculum_move([robot_index]))
obs, _ = env.reset()
if ui_event.action == "resample_commands" and ui_event.value > 0:
print("resample_commands")
env._resample_commands(torch.arange(env.num_envs, device= env.device))
if ui_event.action == "stop" and ui_event.value > 0:
if hasattr(env, "sampled_x_cmd_buffer"):
env.sampled_x_cmd_buffer[:] = 0
env.commands[:, :] = 0
if hasattr(env, "orientation_cmds"):
env.orientation_cmds[:] = env.gravity_vec
# env.stop_position.copy_(env.root_states[:, :3])
# env.command_ranges["lin_vel_x"] = [0, 0]
# env.command_ranges["lin_vel_y"] = [0, 0]
# env.command_ranges["ang_vel_yaw"] = [0, 0]
if ui_event.action == "forward" and ui_event.value > 0:
env.commands[:, 0] = env_cfg.commands.ranges.lin_vel_x[1]
# env.command_ranges["lin_vel_x"] = [env_cfg.commands.ranges.lin_vel_x[1], env_cfg.commands.ranges.lin_vel_x[1]]
if ui_event.action == "backward" and ui_event.value > 0:
env.commands[:, 0] = env_cfg.commands.ranges.lin_vel_x[0]
# env.command_ranges["lin_vel_x"] = [env_cfg.commands.ranges.lin_vel_x[0], env_cfg.commands.ranges.lin_vel_x[0]]
if ui_event.action == "leftward" and ui_event.value > 0:
env.commands[:, 1] = env_cfg.commands.ranges.lin_vel_y[1]
# env.command_ranges["lin_vel_y"] = [env_cfg.commands.ranges.lin_vel_y[1], env_cfg.commands.ranges.lin_vel_y[1]]
if ui_event.action == "rightward" and ui_event.value > 0:
env.commands[:, 1] = env_cfg.commands.ranges.lin_vel_y[0]
# env.command_ranges["lin_vel_y"] = [env_cfg.commands.ranges.lin_vel_y[0], env_cfg.commands.ranges.lin_vel_y[0]]
if ui_event.action == "leftturn" and ui_event.value > 0:
env.commands[:, 2] = env_cfg.commands.ranges.ang_vel_yaw[1]
# env.command_ranges["ang_vel_yaw"] = [env_cfg.commands.ranges.ang_vel_yaw[1], env_cfg.commands.ranges.ang_vel_yaw[1]]
if ui_event.action == "rightturn" and ui_event.value > 0:
env.commands[:, 2] = env_cfg.commands.ranges.ang_vel_yaw[0]
# env.command_ranges["ang_vel_yaw"] = [env_cfg.commands.ranges.ang_vel_yaw[0], env_cfg.commands.ranges.ang_vel_yaw[0]]
if ui_event.action == "leftdrag" and ui_event.value > 0:
env.root_states[:, 7:10] += quat_rotate(env.base_quat, torch.tensor([[0., 0.5, 0.]], device= env.device))
env.gym.set_actor_root_state_tensor(env.sim, gymtorch.unwrap_tensor(env.all_root_states))
if ui_event.action == "rightdrag" and ui_event.value > 0:
env.root_states[:, 7:10] += quat_rotate(env.base_quat, torch.tensor([[0., -0.5, 0.]], device= env.device))
env.gym.set_actor_root_state_tensor(env.sim, gymtorch.unwrap_tensor(env.all_root_states))
if ui_event.action == "mark" and ui_event.value > 0:
logger.plot_states()
if ui_event.action == "more_plots" and ui_event.value > 0:
logger.plot_additional_states()
if ui_event.action == "switch_teacher" and ui_event.value > 0:
args.show_teacher = not args.show_teacher
print("show_teacher:", args.show_teacher)
if ui_event.action == "lean_fronter" and ui_event.value > 0 and hasattr(env, "orientation_cmds"):
env.orientation_cmds[:, 0] += 0.1
print("orientation_cmds:", env.orientation_cmds[:, 0])
if ui_event.action == "lean_backer" and ui_event.value > 0 and hasattr(env, "orientation_cmds"):
env.orientation_cmds[:, 0] -= 0.1
print("orientation_cmds:", env.orientation_cmds[:, 0])
if ui_event.action == "lean_lefter" and ui_event.value > 0 and hasattr(env, "orientation_cmds"):
env.orientation_cmds[:, 1] += 0.1
print("orientation_cmds:", env.orientation_cmds[:, 1])
if ui_event.action == "lean_righter" and ui_event.value > 0 and hasattr(env, "orientation_cmds"):
env.orientation_cmds[:, 1] -= 0.1
print("orientation_cmds:", env.orientation_cmds[:, 1])
if ui_event.action == "slower" and ui_event.value > 0:
if hasattr(env, "sampled_x_cmd_buffer"):
env.sampled_x_cmd_buffer[:] -= 0.2
env.commands[:, 0] -= 0.2
print("command_x:", env.commands[:, 0])
if ui_event.action == "faster" and ui_event.value > 0:
if hasattr(env, "sampled_x_cmd_buffer"):
env.sampled_x_cmd_buffer[:] += 0.2
env.commands[:, 0] += 0.2
print("command_x:", env.commands[:, 0])
if ui_event.action == "lefter" and ui_event.value > 0:
if env.commands[:, 2] < 0:
env.commands[:, 2] = 0.
else:
env.commands[:, 2] += 0.4
print("command_yaw:", env.commands[:, 2])
if ui_event.action == "righter" and ui_event.value > 0:
if env.commands[:, 2] > 0:
env.commands[:, 2] = 0.
else:
env.commands[:, 2] -= 0.4
print("command_yaw:", env.commands[:, 2])
if ui_event.action == "terrain_forward" and ui_event.value > 0:
# env.cfg.terrain.curriculum = False
env.terrain_levels[:] += 1
env.terrain_levels = torch.clip(
env.terrain_levels,
min= 0,
max= env.cfg.terrain.num_rows - 1,
)
print("before", env.terrain_levels)
env.env_origins[:] = env.terrain_origins[env.terrain_levels[:], env.terrain_types[:]]
env.reset()
agent_model.reset()
print("after", env.terrain_levels)
if ui_event.action == "terrain_back" and ui_event.value > 0:
# env.cfg.terrain.curriculum = False
env.terrain_levels[:] -= 1
env.terrain_levels = torch.clip(
env.terrain_levels,
min= 0,
max= env.cfg.terrain.num_rows - 1,
)
print("before", env.terrain_levels)
env.env_origins[:] = env.terrain_origins[env.terrain_levels[:], env.terrain_types[:]]
env.reset()
agent_model.reset()
print("after", env.terrain_levels)
if ui_event.action == "terrain_right" and ui_event.value > 0:
# env.cfg.terrain.curriculum = False
env.terrain_types[:] -= 1
env.terrain_types = torch.clip(
env.terrain_types,
min= 0,
max= env.cfg.terrain.num_cols - 1,
)
env.env_origins[:] = env.terrain_origins[env.terrain_levels[:], env.terrain_types[:]]
env.reset()
agent_model.reset()
if ui_event.action == "terrain_left" and ui_event.value > 0:
# env.cfg.terrain.curriculum = False
env.terrain_types[:] += 1
env.terrain_types = torch.clip(
env.terrain_types,
min= 0,
max= env.cfg.terrain.num_cols - 1,
)
env.env_origins[:] = env.terrain_origins[env.terrain_levels[:], env.terrain_types[:]]
env.reset()
agent_model.reset()
# if (env.contact_forces[robot_index, env.feet_indices, 2] > 200).any():
# print("contact_forces:", env.contact_forces[robot_index, env.feet_indices, 2])
# if (abs(env.substep_torques[robot_index]) > 35.).any():
# exceed_idxs = torch.where(abs(env.substep_torques[robot_index]) > 35.)
# print("substep_torques:", exceed_idxs[1], env.substep_torques[robot_index][exceed_idxs[0], exceed_idxs[1]])
if env.torque_exceed_count_envstep[robot_index].any():
print("substep torque exceed limit ratio",
(torch.abs(env.substep_torques[robot_index]) / (env.torque_limits.unsqueeze(0))).max(),
"joint index",
torch.where((torch.abs(env.substep_torques[robot_index]) > env.torque_limits.unsqueeze(0) * env.cfg.rewards.soft_torque_limit).any(dim= 0))[0],
"timestep", i,
)
env.torque_exceed_count_envstep[robot_index] = 0
# if (torch.abs(env.torques[robot_index]) > env.torque_limits.unsqueeze(0) * env.cfg.rewards.soft_torque_limit).any():
# print("torque exceed limit ratio",
# (torch.abs(env.torques[robot_index]) / (env.torque_limits.unsqueeze(0))).max(),
# "joint index",
# torch.where((torch.abs(env.torques[robot_index]) > env.torque_limits.unsqueeze(0) * env.cfg.rewards.soft_torque_limit).any(dim= 0))[0],
# "timestep", i,
# )
# dof_exceed_mask = ((env.dof_pos[robot_index] > env.dof_pos_limits[:, 1]) | (env.dof_pos[robot_index] < env.dof_pos_limits[:, 0]))
# if dof_exceed_mask.any():
# print("dof pos exceed limit: joint index",
# torch.where(dof_exceed_mask)[0],
# "amount",
# torch.maximum(
# env.dof_pos[robot_index][dof_exceed_mask] - env.dof_pos_limits[dof_exceed_mask][:, 1],
# env.dof_pos_limits[dof_exceed_mask][:, 0] - env.dof_pos[robot_index][dof_exceed_mask],
# ),
# "dof value:",
# env.dof_pos[robot_index][dof_exceed_mask],
# "timestep", i,
# )
if i < stop_state_log:
if torch.is_tensor(env.cfg.control.action_scale):
@@ -308,6 +456,9 @@ def play(args):
"student_action": actions[robot_index, 2].item(),
"teacher_action": teacher_actions[robot_index, 2].item(),
"reward": rews[robot_index].item(),
'all_dof_vel': env.substep_dof_vel[robot_index].mean(-2).cpu().numpy(),
'all_dof_torque': env.substep_torques[robot_index].mean(-2).cpu().numpy(),
"power": torch.max(torch.sum(env.substep_torques * env.substep_dof_vel, dim= -1), dim= -1)[0][robot_index].item(),
}
)
elif i==stop_state_log:
@@ -323,8 +474,11 @@ def play(args):
if dones.any():
agent_model.reset(dones)
print("env dones,{} because has timeout".format("" if env.time_out_buf[dones].any() else " not"))
print(infos)
if env.time_out_buf[dones].any():
print("env dones because of timeout")
else:
print("env dones because of failure")
# print(infos)
if i % 100 == 0:
print("frame_rate:" , 100/(time.time_ns() - start_time) * 1e9,
"command_x:", env.commands[robot_index, 0],
@@ -333,8 +487,45 @@ def play(args):
if __name__ == '__main__':
EXPORT_POLICY = False
RECORD_FRAMES = False
MOVE_CAMERA = True
CAMERA_FOLLOW = True
args = get_args()
play(args)
args = get_args([
dict(name= "--slow", type= float, default= 0., help= "slow down the simulation by sleep secs (float) every frame"),
dict(name= "--show_teacher", action= "store_true", default= False, help= "show teacher actions"),
dict(name= "--no_teacher", action= "store_true", default= False, help= "whether to disable teacher policy when running the script"),
dict(name= "--zero_act_until", type= int, default= 0., help= "zero action until this step"),
dict(name= "--sample", action= "store_true", default= False, help= "sample actions from policy"),
dict(name= "--plot_time", type= int, default= -1, help= "plot states after this time"),
dict(name= "--no_throw", action= "store_true", default= False),
dict(name= "--load_cfg", action= "store_true", default= False, help= "use the config from the logdir"),
dict(name= "--record", action= "store_true", default= False, help= "record frames"),
dict(name= "--frames_dir", type= str, default= "images", help= "which folder to store intermediate recorded frames."),
])
MOVE_CAMERA = (args.num_envs is None)
CAMERA_FOLLOW = MOVE_CAMERA
RECORD_FRAMES = args.record
try:
play(args)
except KeyboardInterrupt:
print("KeyboardInterrupt")
finally:
if RECORD_FRAMES and args.load_run is not None:
import subprocess
print("converting frames to video")
log_dir = args.load_run if os.path.isabs(args.load_run) \
else os.path.join(
LEGGED_GYM_ROOT_DIR,
"logs",
task_registry.get_cfgs(name=args.task)[1].runner.experiment_name,
args.load_run,
)
subprocess.run(["ffmpeg",
"-framerate", "50",
"-r", "50",
"-i", "logs/images/%04d.png",
"-c:v", "libx264",
"-hide_banner", "-loglevel", "error",
os.path.join(log_dir, f"video_{args.checkpoint}.mp4")
])
print("done converting frames to video, deleting frame images")
for f in os.listdir(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", args.frames_dir)):
os.remove(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", args.frames_dir, f))
print("done deleting frame images")

View File

@@ -29,12 +29,14 @@
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
import numpy as np
np.float = np.float32
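# shim: np.float was removed in newer numpy releases (>= 1.24); aliasing it here keeps
# older isaacgym / helper code that still references np.float working (assumed intent)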
import os
from datetime import datetime
import isaacgym
from legged_gym.envs import *
from legged_gym.utils import get_args, task_registry
from legged_gym.utils import get_args
from legged_gym.utils.task_registry import task_registry
import torch
from legged_gym.debugger import break_into_debugger

View File

@@ -29,7 +29,6 @@
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
from .helpers import class_to_dict, get_load_path, get_args, export_policy_as_jit, set_seed, update_class_from_dict
from .task_registry import task_registry
from .logger import Logger
from .math import *
from .terrain.terrain import Terrain

View File

@@ -62,6 +62,12 @@ class Logger:
self.plot_process = Process(target=self._plot)
self.plot_process.start()
def plot_additional_states(self):
self.plot_vel_process = Process(target=self._plot_vel)
self.plot_vel_process.start()
self.plot_torque_process = Process(target=self._plot_torque)
self.plot_torque_process.start()
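# the joint-velocity and joint-torque figures are drawn in separate processes,
# mirroring _plot above, so the matplotlib windows do not block the simulation loop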
def _plot(self):
nb_rows = 4
nb_cols = 3
@@ -98,7 +104,6 @@ class Logger:
a = axs[0, 2]
if log["base_pitch"]:
a.plot(time, log["base_pitch"], label='measured')
a.plot(time, [-0.75] * len(time), label= 'thresh')
# if log["command_yaw"]: a.plot(time, log["command_yaw"], label='commanded')
a.set(xlabel='time [s]', ylabel='base ang [rad]', title='Base pitch')
a.legend()
@@ -157,6 +162,58 @@ class Logger:
a.legend(fontsize = 5)
plt.show()
def _plot_vel(self):
log= self.state_log
nb_rows = int(np.sqrt(log['all_dof_vel'][0].shape[0]))
nb_cols = int(np.ceil(log['all_dof_vel'][0].shape[0] / nb_rows))
nb_rows, nb_cols = nb_cols, nb_rows
fig, axs = plt.subplots(nb_rows, nb_cols)
for key, value in self.state_log.items():
time = np.linspace(0, len(value)*self.dt, len(value))
break
# plot joint velocities
for i in range(nb_rows):
for j in range(nb_cols):
if i*nb_cols+j < log['all_dof_vel'][0].shape[0]:
a = axs[i][j]
a.plot(
time,
[all_dof_vel[i*nb_cols+j] for all_dof_vel in log['all_dof_vel']],
label='measured',
)
a.set(xlabel='time [s]', ylabel='Velocity [rad/s]', title=f'Joint Velocity {i*nb_cols+j}')
a.legend()
else:
break
plt.show()
def _plot_torque(self):
log= self.state_log
nb_rows = int(np.sqrt(log['all_dof_torque'][0].shape[0]))
nb_cols = int(np.ceil(log['all_dof_torque'][0].shape[0] / nb_rows))
nb_rows, nb_cols = nb_cols, nb_rows
fig, axs = plt.subplots(nb_rows, nb_cols)
for key, value in self.state_log.items():
time = np.linspace(0, len(value)*self.dt, len(value))
break
# plot joint torques
for i in range(nb_rows):
for j in range(nb_cols):
if i*nb_cols+j < log['all_dof_torque'][0].shape[0]:
a = axs[i][j]
a.plot(
time,
[all_dof_torque[i*nb_cols+j] for all_dof_torque in log['all_dof_torque']],
label='measured',
)
a.set(xlabel='time [s]', ylabel='Torque [Nm]', title=f'Joint Torque {i*nb_cols+j}')
a.legend()
else:
break
plt.show()
def print_rewards(self):
print("Average rewards per second:")
for key, values in self.rew_log.items():

File diff suppressed because it is too large

View File

@@ -1,4 +1,5 @@
import numpy as np
import torch
from numpy.random import choice
from scipy import interpolate
@@ -13,13 +14,14 @@ class TerrainPerlin:
self.env_width = cfg.terrain_width
self.xSize = cfg.terrain_length * cfg.num_rows # int(cfg.horizontal_scale * cfg.tot_cols)
self.ySize = cfg.terrain_width * cfg.num_cols # int(cfg.horizontal_scale * cfg.tot_rows)
self.tot_cols = int(self.xSize / cfg.horizontal_scale)
self.tot_rows = int(self.ySize / cfg.horizontal_scale)
self.tot_rows = int(self.xSize / cfg.horizontal_scale)
self.tot_cols = int(self.ySize / cfg.horizontal_scale)
assert(self.xSize == cfg.horizontal_scale * self.tot_rows and self.ySize == cfg.horizontal_scale * self.tot_cols)
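# convention: rows index the x direction and cols the y direction, so heightsamples has
# shape (tot_rows, tot_cols) and is indexed as [x_px, y_px] in get_terrain_heights below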
self.heightsamples_float = self.generate_fractal_noise_2d(self.xSize, self.ySize, self.tot_rows, self.tot_cols, **cfg.TerrainPerlin_kwargs)
# self.heightsamples_float[self.tot_cols//2 - 100:, :] += 100000
# self.heightsamples_float[self.tot_cols//2 - 40: self.tot_cols//2 + 40, :] = np.mean(self.heightsamples_float)
self.heightsamples = (self.heightsamples_float * (1 / cfg.vertical_scale)).astype(np.int16)
self.heightfield_raw_pyt = torch.tensor(self.heightsamples, device= "cpu")
print("Terrain heightsamples shape: ", self.heightsamples.shape)
@@ -111,3 +113,31 @@ class TerrainPerlin:
int(origin_y / self.cfg.horizontal_scale),
] * self.cfg.vertical_scale,
]
self.heightfield_raw_pyt = torch.from_numpy(self.heightsamples).to(device= self.device).float()
def in_terrain_range(self, pos):
""" Check if the given position still have terrain underneath. (same x/y, but z is different)
pos: (batch_size, 3) torch.Tensor
"""
return torch.logical_and(
pos[..., :2] >= 0,
pos[..., :2] < torch.tensor([self.xSize, self.ySize], device= self.device),
).all(dim= -1)
@torch.no_grad()
def get_terrain_heights(self, points):
""" Get the terrain heights below the given points """
points_shape = points.shape
points = points.view(-1, 3)
points_x_px = (points[:, 0] / self.cfg.horizontal_scale).to(int)
points_y_px = (points[:, 1] / self.cfg.horizontal_scale).to(int)
out_of_range_mask = torch.logical_or(
torch.logical_or(points_x_px < 0, points_x_px >= self.heightfield_raw_pyt.shape[0]),
torch.logical_or(points_y_px < 0, points_y_px >= self.heightfield_raw_pyt.shape[1]),
)
points_x_px = torch.clip(points_x_px, 0, self.heightfield_raw_pyt.shape[0] - 1)
points_y_px = torch.clip(points_y_px, 0, self.heightfield_raw_pyt.shape[1] - 1)
heights = self.heightfield_raw_pyt[points_x_px, points_y_px] * self.cfg.vertical_scale
heights[out_of_range_mask] = - torch.inf
heights = heights.view(points_shape[:-1])
return heights

View File

@@ -29,6 +29,7 @@
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
import numpy as np
import torch
from numpy.random import choice
from scipy import interpolate
@@ -162,6 +163,65 @@ class Terrain:
y2 = int((self.env_width/2. + 1) / terrain.horizontal_scale)
env_origin_z = np.max(terrain.height_field_raw[x1:x2, y1:y2])*terrain.vertical_scale
self.env_origins[i, j] = [env_origin_x, env_origin_y, env_origin_z]
def _create_heightfield(self, gym, sim, device= "cpu"):
""" Adds a heightfield terrain to the simulation, sets parameters based on the cfg.
"""
hf_params = gym.HeightFieldParams()
hf_params.column_scale = self.cfg.horizontal_scale
hf_params.row_scale = self.cfg.horizontal_scale
hf_params.vertical_scale = self.cfg.vertical_scale
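# note: nbRows/nbColumns are set from tot_cols/tot_rows (swapped), following the
# heightfield row/column convention used by upstream legged_gym's _create_heightfield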
hf_params.nbRows = self.tot_cols
hf_params.nbColumns = self.tot_rows
hf_params.transform.p.x = -self.cfg.border_size
hf_params.transform.p.y = -self.cfg.border_size
hf_params.transform.p.z = 0.0
hf_params.static_friction = self.cfg.static_friction
hf_params.dynamic_friction = self.cfg.dynamic_friction
hf_params.restitution = self.cfg.restitution
self.gym.add_heightfield(sim, self.heightsamples, hf_params)
def _create_trimesh(self, gym, sim, device= "cpu"):
""" Adds a triangle mesh terrain to the simulation, sets parameters based on the cfg.
# """
tm_params = gym.TriangleMeshParams()
tm_params.nb_vertices = self.vertices.shape[0]
tm_params.nb_triangles = self.triangles.shape[0]
tm_params.transform.p.x = -self.cfg.border_size
tm_params.transform.p.y = -self.cfg.border_size
tm_params.transform.p.z = 0.0
tm_params.static_friction = self.cfg.static_friction
tm_params.dynamic_friction = self.cfg.dynamic_friction
tm_params.restitution = self.cfg.restitution
self.gym.add_triangle_mesh(self.sim, self.vertices.flatten(order='C'), self.triangles.flatten(order='C'), tm_params)
def add_terrain_to_sim(self, gym, sim, device= "cpu"):
if self.type == "heightfield":
self._create_heightfield(gym, sim, device)
elif self.type == "trimesh":
self._create_trimesh(gym, sim, device)
else:
raise NotImplementedError("Terrain type {} not implemented".format(self.type))
def get_terrain_heights(self, points):
""" Return the z coordinate of the terrain where just below the given points. """
num_robots = points.shape[0]
points += self.cfg.border_size
points = (points/self.cfg.horizontal_scale).long()
px = points[:, :, 0].view(-1)
py = points[:, :, 1].view(-1)
px = torch.clip(px, 0, self.heightsamples.shape[0]-2)
py = torch.clip(py, 0, self.heightsamples.shape[1]-2)
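# look up the cell and its +x / +y neighbours and keep the minimum, a conservative
# estimate of the supporting terrain height under each point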
heights1 = self.heightsamples[px, py]
heights2 = self.heightsamples[px+1, py]
heights3 = self.heightsamples[px, py+1]
heights = torch.min(heights1, heights2)
heights = torch.min(heights, heights3)
return heights.view(num_robots, -1) * self.cfg.vertical_scale
def gap_terrain(terrain, gap_size, platform_size=1.):
gap_size = int(gap_size / terrain.horizontal_scale)

View File

@@ -0,0 +1,2 @@
# then you can import it like `from legged_gym.utils.webviewer import WebViewer`
from .webviewer import WebViewer

View File

@@ -0,0 +1,80 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<style>
html, body {
width: 100%; height: 100%;
margin: 0; overflow: hidden; display: block;
background-color: #000;
}
</style>
</head>
<body>
<div>
<canvas id="canvas" tabindex='1'></canvas>
</div>
<script>
var canvas, context, image;
function sendInputRequest(data){
let xmlRequest = new XMLHttpRequest();
xmlRequest.open("POST", "{{ url_for('_route_input_event') }}", true);
xmlRequest.setRequestHeader("Content-Type", "application/json");
xmlRequest.send(JSON.stringify(data));
}
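// keyboard and mouse events are POSTed to the Flask backend (_route_input_event),
// which maps them to camera orbit / pan / zoom and viewer controls on the Python side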
window.onload = function(){
canvas = document.getElementById("canvas");
context = canvas.getContext('2d');
image = new Image();
image.src = "{{ url_for('_route_stream') }}";
image1 = new Image();
image1.src = "{{ url_for('_route_stream_depth') }}";
canvas.width = window.innerWidth;
canvas.height = window.innerHeight;
window.addEventListener('resize', function(){
canvas.width = window.innerWidth;
canvas.height = window.innerHeight;
}, false);
window.setInterval(function(){
let ratio = image.naturalWidth / image.naturalHeight;
context.drawImage(image, 0, 0, canvas.width, canvas.width / ratio);
let imageHeight = canvas.width / ratio;
context.drawImage(image1, 0, imageHeight, canvas.width, canvas.width / image1.naturalWidth * image1.naturalHeight);
}, 50);
canvas.addEventListener('keydown', function(event){
if(event.keyCode != 18)
sendInputRequest({key: event.keyCode});
}, false);
canvas.addEventListener('mousemove', function(event){
if(event.buttons){
let data = {dx: event.movementX, dy: event.movementY};
if(event.altKey && event.buttons == 1){
data.key = 18;
data.mouse = "left";
}
else if(event.buttons == 2)
data.mouse = "right";
else if(event.buttons == 4)
data.mouse = "middle";
else
return;
sendInputRequest(data);
}
}, false);
canvas.addEventListener('wheel', function(event){
sendInputRequest({mouse: "wheel", dz: Math.sign(event.deltaY)});
}, false);
}
</script>
</body>
</html>

View File

@@ -0,0 +1,442 @@
from typing import List, Optional
import logging
import math
import threading
import numpy as np
import torch
from legged_gym import LEGGED_GYM_ROOT_DIR
import os
try:
import flask
except ImportError:
flask = None
try:
import imageio
import isaacgym
import isaacgym.torch_utils as torch_utils
from isaacgym import gymapi
except ImportError:
imageio = None
isaacgym = None
torch_utils = None
gymapi = None
def cartesian_to_spherical(x, y, z):
r = np.sqrt(x**2 + y**2 + z**2)
theta = np.arccos(z/r) if r != 0 else 0
phi = np.arctan2(y, x)
return r, theta, phi
def spherical_to_cartesian(r, theta, phi):
x = r * np.sin(theta) * np.cos(phi)
y = r * np.sin(theta) * np.sin(phi)
z = r * np.cos(theta)
return x, y, z
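# typical usage (a sketch, not enforced by the class): construct the WebViewer before
# training/playing starts, call viewer.setup(env) once after the env is created, then
# call viewer.render() every control step after the physics has been stepped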
class WebViewer:
def __init__(self, host: str = "127.0.0.1", port: int = 5000) -> None:
"""
Web viewer for Isaac Gym
:param host: Host address (default: "127.0.0.1")
:type host: str
:param port: Port number (default: 5000)
:type port: int
"""
self._app = flask.Flask(__name__)
self._app.add_url_rule("/", view_func=self._route_index)
self._app.add_url_rule("/_route_stream", view_func=self._route_stream)
self._app.add_url_rule("/_route_stream_depth", view_func=self._route_stream_depth)
self._app.add_url_rule("/_route_input_event", view_func=self._route_input_event, methods=["POST"])
self._log = logging.getLogger('werkzeug')
self._log.disabled = True
self._app.logger.disabled = True
self._image = None
self._image_depth = None
self._camera_id = 0
self._camera_type = gymapi.IMAGE_COLOR
# get from self._env and stream to webviewer (expected shape: num_envs, 1, height, width)
self._depth_image_buffer_name = "forward_depth_output"
self._notified = False
self._wait_for_page = True
self._pause_stream = False
self._event_load = threading.Event()
self._event_stream = threading.Event()
self._event_stream_depth = threading.Event()
# start server
self._thread = threading.Thread(target=lambda: \
self._app.run(host=host, port=port, debug=False, use_reloader=False), daemon=True)
self._thread.start()
print(f"\nStarting web viewer on http://{host}:{port}/\n")
def _route_index(self) -> 'flask.Response':
"""Render the web page
:return: Flask response
:rtype: flask.Response
"""
with open(os.path.join(os.path.dirname(__file__), "webviewer.html"), 'r', encoding='utf-8') as file:
template = file.read()
self._event_load.set()
return flask.render_template_string(template)
def _route_stream(self) -> 'flask.Response':
"""Stream the image to the web page
:return: Flask response
:rtype: flask.Response
"""
return flask.Response(self._stream(), mimetype='multipart/x-mixed-replace; boundary=frame')
def _route_stream_depth(self) -> 'flask.Response':
return flask.Response(self._stream_depth(), mimetype='multipart/x-mixed-replace; boundary=frame')
def _route_input_event(self) -> 'flask.Response':
# get keyboard and mouse inputs
data = flask.request.get_json()
key, mouse = data.get("key", None), data.get("mouse", None)
dx, dy, dz = data.get("dx", None), data.get("dy", None), data.get("dz", None)
transform = self._gym.get_camera_transform(self._sim,
self._envs[self._camera_id],
self._cameras[self._camera_id])
# zoom in/out
if mouse == "wheel":
# compute zoom vector
r, theta, phi = cartesian_to_spherical(*self.cam_pos_rel)
r += 0.05 * dz
self.cam_pos_rel = spherical_to_cartesian(r, theta, phi)
# orbit camera
elif mouse == "left":
# convert mouse movement to angle
dx *= 0.2 * math.pi / 180
dy *= 0.2 * math.pi / 180
r, theta, phi = cartesian_to_spherical(*self.cam_pos_rel)
theta -= dy
phi -= dx
self.cam_pos_rel = spherical_to_cartesian(r, theta, phi)
# pan camera
elif mouse == "right":
# convert mouse movement to angle
dx *= -0.2 * math.pi / 180
dy *= -0.2 * math.pi / 180
r, theta, phi = cartesian_to_spherical(*self.cam_pos_rel)
theta += dy
phi += dx
self.cam_pos_rel = spherical_to_cartesian(r, theta, phi)
elif key == 219: # prev
self._camera_id = (self._camera_id-1) % self._env.num_envs
return flask.Response(status=200)
elif key == 221: # next
self._camera_id = (self._camera_id+1) % self._env.num_envs
return flask.Response(status=200)
# pause stream (V: 86)
elif key == 86:
self._pause_stream = not self._pause_stream
return flask.Response(status=200)
# change image type (T: 84)
elif key == 84:
if self._camera_type == gymapi.IMAGE_COLOR:
self._camera_type = gymapi.IMAGE_DEPTH
elif self._camera_type == gymapi.IMAGE_DEPTH:
self._camera_type = gymapi.IMAGE_COLOR
return flask.Response(status=200)
else:
return flask.Response(status=200)
return flask.Response(status=200)
def _stream(self) -> bytes:
"""Format the image to be streamed
:return: Image encoded as Content-Type
:rtype: bytes
"""
while True:
self._event_stream.wait()
# prepare image
image = imageio.imwrite("<bytes>", self._image, format="JPEG")
# stream image
yield (b'--frame\r\n'
b'Content-Type: image/jpeg\r\n\r\n' + image + b'\r\n')
self._event_stream.clear()
self._notified = False
def _stream_depth(self) -> bytes:
while self._env.cfg.viewer.stream_depth:
self._event_stream_depth.wait()
# prepare image
image = imageio.imwrite("<bytes>", self._image_depth, format="JPEG")
# stream image
yield (b'--frame\r\n'
b'Content-Type: image/jpeg\r\n\r\n' + image + b'\r\n')
self._event_stream_depth.clear()
def attach_view_camera(self, i, env_handle, actor_handle, root_pos):
if True:
camera_props = gymapi.CameraProperties()
camera_props.width = 960
camera_props.height = 540
# camera_props.enable_tensors = True
# camera_props.horizontal_fov = camera_horizontal_fov
camera_handle = self._gym.create_camera_sensor(env_handle, camera_props)
self._cameras.append(camera_handle)
cam_pos = root_pos + np.array([0, 1, 0.5])
self._gym.set_camera_location(camera_handle, env_handle, gymapi.Vec3(*cam_pos), gymapi.Vec3(*root_pos))
def setup(self, env) -> None:
"""Setup the web viewer
:param gym: The gym
:type gym: isaacgym.gymapi.Gym
:param sim: Simulation handle
:type sim: isaacgym.gymapi.Sim
:param envs: Environment handles
:type envs: list of ints
:param cameras: Camera handles
:type cameras: list of ints
"""
self._gym = env.gym
self._sim = env.sim
self._envs = env.envs
self._cameras = []
self._env = env
self.cam_pos_rel = np.array([0, 2, 1])
for i in range(self._env.num_envs):
root_pos = self._env.root_states[i, :3].cpu().numpy()
self.attach_view_camera(i, self._envs[i], self._env.actor_handles[i], root_pos)
def render(self,
fetch_results: bool = True,
step_graphics: bool = True,
render_all_camera_sensors: bool = True,
wait_for_page_load: bool = True) -> None:
"""Render and get the image from the current camera
This function must be called after the simulation is stepped (post_physics_step).
        The following Isaac Gym functions are called before getting the image.
        Each of these calls can be skipped by setting the corresponding argument to False
- fetch_results
- step_graphics
- render_all_camera_sensors
:param fetch_results: Call Gym.fetch_results method (default: True)
:type fetch_results: bool
:param step_graphics: Call Gym.step_graphics method (default: True)
:type step_graphics: bool
:param render_all_camera_sensors: Call Gym.render_all_camera_sensors method (default: True)
:type render_all_camera_sensors: bool
:param wait_for_page_load: Wait for the page to load (default: True)
:type wait_for_page_load: bool
"""
# wait for page to load
if self._wait_for_page:
if wait_for_page_load:
if not self._event_load.is_set():
print("Waiting for web page to begin loading...")
self._event_load.wait()
self._event_load.clear()
self._wait_for_page = False
# pause stream
if self._pause_stream:
return
if self._notified:
return
# isaac gym API
if fetch_results:
self._gym.fetch_results(self._sim, True)
if step_graphics:
self._gym.step_graphics(self._sim)
if render_all_camera_sensors:
self._gym.render_all_camera_sensors(self._sim)
# get image
image = self._gym.get_camera_image(self._sim,
self._envs[self._camera_id],
self._cameras[self._camera_id],
self._camera_type)
if self._camera_type == gymapi.IMAGE_COLOR:
self._image = image.reshape(image.shape[0], -1, 4)[..., :3]
elif self._camera_type == gymapi.IMAGE_DEPTH:
self._image = -image.reshape(image.shape[0], -1)
minimum = 0 if np.isinf(np.min(self._image)) else np.min(self._image)
maximum = 5 if np.isinf(np.max(self._image)) else np.max(self._image)
self._image = np.clip(1 - (self._image - minimum) / (maximum - minimum), 0, 1)
self._image = np.uint8(255 * self._image)
else:
raise ValueError("Unsupported camera type")
if self._env.cfg.viewer.stream_depth:
self._image_depth = getattr(self._env, self._depth_image_buffer_name)[self._camera_id, -1].cpu().numpy() # expected value range: [0, 1]
self._image_depth = np.uint8(255 * self._image_depth)
root_pos = self._env.root_states[self._camera_id, :3].cpu().numpy()
cam_pos = root_pos + self.cam_pos_rel
self._gym.set_camera_location(self._cameras[self._camera_id], self._envs[self._camera_id], gymapi.Vec3(*cam_pos), gymapi.Vec3(*root_pos))
# notify stream thread
self._event_stream.set()
if self._env.cfg.viewer.stream_depth:
self._event_stream_depth.set()
self._notified = True
def ik(jacobian_end_effector: torch.Tensor,
current_position: torch.Tensor,
current_orientation: torch.Tensor,
goal_position: torch.Tensor,
goal_orientation: Optional[torch.Tensor] = None,
damping_factor: float = 0.05,
squeeze_output: bool = True) -> torch.Tensor:
"""
Inverse kinematics using damped least squares method
:param jacobian_end_effector: End effector's jacobian
:type jacobian_end_effector: torch.Tensor
:param current_position: End effector's current position
:type current_position: torch.Tensor
:param current_orientation: End effector's current orientation
:type current_orientation: torch.Tensor
:param goal_position: End effector's goal position
:type goal_position: torch.Tensor
:param goal_orientation: End effector's goal orientation (default: None)
:type goal_orientation: torch.Tensor or None
:param damping_factor: Damping factor (default: 0.05)
:type damping_factor: float
:param squeeze_output: Squeeze output (default: True)
:type squeeze_output: bool
:return: Change in joint angles
:rtype: torch.Tensor
"""
if goal_orientation is None:
goal_orientation = current_orientation
# compute error
q = torch_utils.quat_mul(goal_orientation, torch_utils.quat_conjugate(current_orientation))
error = torch.cat([goal_position - current_position, # position error
q[:, 0:3] * torch.sign(q[:, 3]).unsqueeze(-1)], # orientation error
dim=-1).unsqueeze(-1)
# solve damped least squares (dO = J.T * V)
transpose = torch.transpose(jacobian_end_effector, 1, 2)
lmbda = torch.eye(6, device=jacobian_end_effector.device) * (damping_factor ** 2)
if squeeze_output:
return (transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ error).squeeze(dim=2)
else:
return transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ error
def print_arguments(args):
print("")
print("Arguments")
for a in args.__dict__:
print(f" |-- {a}: {args.__getattribute__(a)}")
def print_asset_options(asset_options: 'isaacgym.gymapi.AssetOptions', asset_name: str = ""):
attrs = ["angular_damping", "armature", "collapse_fixed_joints", "convex_decomposition_from_submeshes",
"default_dof_drive_mode", "density", "disable_gravity", "fix_base_link", "flip_visual_attachments",
"linear_damping", "max_angular_velocity", "max_linear_velocity", "mesh_normal_mode", "min_particle_mass",
"override_com", "override_inertia", "replace_cylinder_with_capsule", "tendon_limit_stiffness", "thickness",
"use_mesh_materials", "use_physx_armature", "vhacd_enabled"] # vhacd_params
print("\nAsset options{}".format(f" ({asset_name})" if asset_name else ""))
for attr in attrs:
print(" |-- {}: {}".format(attr, getattr(asset_options, attr) if hasattr(asset_options, attr) else "--"))
# vhacd attributes
if attr == "vhacd_enabled" and hasattr(asset_options, attr) and getattr(asset_options, attr):
vhacd_attrs = ["alpha", "beta", "concavity", "convex_hull_approximation", "convex_hull_downsampling",
"max_convex_hulls", "max_num_vertices_per_ch", "min_volume_per_ch", "mode", "ocl_acceleration",
"pca", "plane_downsampling", "project_hull_vertices", "resolution"]
print(" |-- vhacd_params:")
for vhacd_attr in vhacd_attrs:
print(" | |-- {}: {}".format(vhacd_attr, getattr(asset_options.vhacd_params, vhacd_attr) \
if hasattr(asset_options.vhacd_params, vhacd_attr) else "--"))
def print_sim_components(gym, sim):
print("")
print("Sim components")
print(" |-- env count:", gym.get_env_count(sim))
print(" |-- actor count:", gym.get_sim_actor_count(sim))
print(" |-- rigid body count:", gym.get_sim_rigid_body_count(sim))
print(" |-- joint count:", gym.get_sim_joint_count(sim))
print(" |-- dof count:", gym.get_sim_dof_count(sim))
print(" |-- force sensor count:", gym.get_sim_force_sensor_count(sim))
def print_env_components(gym, env):
print("")
print("Env components")
print(" |-- actor count:", gym.get_actor_count(env))
print(" |-- rigid body count:", gym.get_env_rigid_body_count(env))
print(" |-- joint count:", gym.get_env_joint_count(env))
print(" |-- dof count:", gym.get_env_dof_count(env))
def print_actor_components(gym, env, actor):
print("")
print("Actor components")
print(" |-- rigid body count:", gym.get_actor_rigid_body_count(env, actor))
print(" |-- joint count:", gym.get_actor_joint_count(env, actor))
print(" |-- dof count:", gym.get_actor_dof_count(env, actor))
print(" |-- actuator count:", gym.get_actor_actuator_count(env, actor))
print(" |-- rigid shape count:", gym.get_actor_rigid_shape_count(env, actor))
print(" |-- soft body count:", gym.get_actor_soft_body_count(env, actor))
print(" |-- tendon count:", gym.get_actor_tendon_count(env, actor))
def print_dof_properties(gymapi, props):
print("")
print("DOF properties")
print(" |-- hasLimits:", props["hasLimits"])
print(" |-- lower:", props["lower"])
print(" |-- upper:", props["upper"])
print(" |-- driveMode:", props["driveMode"])
print(" | |-- {}: gymapi.DOF_MODE_NONE".format(int(gymapi.DOF_MODE_NONE)))
print(" | |-- {}: gymapi.DOF_MODE_POS".format(int(gymapi.DOF_MODE_POS)))
print(" | |-- {}: gymapi.DOF_MODE_VEL".format(int(gymapi.DOF_MODE_VEL)))
print(" | |-- {}: gymapi.DOF_MODE_EFFORT".format(int(gymapi.DOF_MODE_EFFORT)))
print(" |-- stiffness:", props["stiffness"])
print(" |-- damping:", props["damping"])
print(" |-- velocity (max):", props["velocity"])
print(" |-- effort (max):", props["effort"])
print(" |-- friction:", props["friction"])
print(" |-- armature:", props["armature"])
def print_links_and_dofs(gym, asset):
link_dict = gym.get_asset_rigid_body_dict(asset)
dof_dict = gym.get_asset_dof_dict(asset)
print("")
print("Links")
for k in link_dict:
print(f" |-- {k}: {link_dict[k]}")
print("DOFs")
for k in dof_dict:
print(f" |-- {k}: {dof_dict[k]}")

View File

@ -12,6 +12,7 @@ setup(
install_requires=['isaacgym',
'rsl-rl',
'matplotlib',
'tensorboard',
'tensorboardX',
'debugpy']
)

View File

@ -1,9 +1,9 @@
# Deploy the model on your real Unitree robot
# Deploy the model on your real Unitree Go1 robot
This version shows an example of how to deploy the model on the Unitree Go1 robot (with Nvidia Jetson NX).
## Install dependencies on Go1
To deploy the trained model on Go1, please set up a folder on your robot, e.g. `parkour`, and copy the `rsl_rl` folder to it. Then, extract the zip files in `go1_ckpts` to the `parkour` folder. Finally, copy all the files in `onboard_script` to the `parkour` folder.
To deploy the trained model on Go1, please set up a folder on your robot, e.g. `parkour`, and copy the `rsl_rl` folder to it. Then, extract the zip files in `go1_ckpts` to the `parkour` folder. Finally, copy all the files in `onboard_script/go1` to the `parkour` folder.
1. Install ROS and the [unitree ros package for Go1](https://github.com/Tsinghua-MARS-Lab/unitree_ros_real.git) and follow the instructions to set up the robot on branch `go1`
@ -29,7 +29,7 @@ To deploy the trained model on Go1, please set up a folder on your robot, e.g. `
4. 3D print the camera mount for Go1 using the step file in `go1_ckpts/go1_camMount_30Down.step`. Use the two pairs of screw holes to mount the Intel Realsense D435i camera on the robot, as shown in the picture below.
<p align="center">
<img src="images/go1_camMount_30Down.png" width="50%"/>
<img src="../images/go1_camMount_30Down.png" width="50%"/>
</p>
## Run the model on Go1

View File

@ -0,0 +1,73 @@
# Deploy the model on your real Unitree Go2 robot
This file shows an example of how to deploy the model on the Unitree Go2 robot (with an onboard Nvidia Jetson).
The code is a quick start for deployment and follows the simulation setup as closely as possible. You can modify it to fit your own project.
## Install dependencies on Go2
1. Take the Nvidia Jetson Orin as an example; make sure your JetPack and related software are up to date.
2. Install ROS and the [unitree ros package for Go2](https://support.unitree.com/home/en/developer/ROS2_service)
3. Set up a folder on your robot for this project, e.g. `parkour`. Then `cd` into it.
4. Create a python virtual env and install the dependencies.
    - Install PyTorch in a Python 3 virtual environment.
```bash
sudo apt-get install python3-pip python3-dev python3-venv
python3 -m venv parkour_venv
source parkour_venv/bin/activate
```
    - Download the PyTorch v1.10.0 pip wheel file from [here](https://forums.developer.nvidia.com/t/pytorch-for-jetson/72048). Then install it with
```bash
pip install torch-1.10.0-cp36-cp36m-linux_aarch64.whl
```
    - Install `ros2_numpy` from [here](https://github.com/nitesh-subedi/ros2_numpy) into a new colcon workspace, wherever you prefer.
```bash
pip install transformations pybase64
mkdir -p ros2_numpy_ws/src
cd ros2_numpy_ws/src
git clone https://github.com/nitesh-subedi/ros2_numpy.git
cd ../
colcon build
```
5. Copy the folders of this project.
- copy the `rsl_rl` folder to the `parkour` folder.
- copy the distilled parkour log folder (e.g. **Jul18_07-22-08_Go2_10skills_fromJul16_07-38-08**) to the `parkour` folder.
- copy all the files in `onboard_script/go2` to the `parkour` folder.
6. Install rsl_rl and other dependencies.
```bash
pip install -e ./rsl_rl
```
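Before moving on, a quick import check can confirm the environment is usable (a minimal sketch, assuming the steps above completed and `parkour_venv` is active):
```python
# Quick import check for the onboard environment (run inside parkour_venv;
# rclpy and pyrealsense2 can be checked the same way once ROS 2 is sourced).
import torch
import rsl_rl
print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
```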
## Run the model on Go2
***Disclaimer:*** *Always put a safety belt on the robot when the robot moves. The robot may fall down and cause damage to itself or the environment.*
1. Put the robot on the ground, power on the robot, and **turn off the builtin sport service**. Make sure your Intel Realsense D435i camera is connected to the robot and mounted where you calibrated it.
> To turn off the builtin sport service, please refer to the [official guide](https://support.unitree.com/home/zh/developer/Basic_motion_control) and [official example](https://github.com/unitreerobotics/unitree_sdk2/blob/main/example/low_level/stand_example_go2.cpp#L184)
2. Launch 2 terminals onboard (e.g. two ssh connections from your computer), named T_visual and T_run. Source the Unitree ROS environment, the `ros2_numpy_ws` workspace, and the Python virtual environment in both terminals.
3. In T_visual, run
```bash
cd parkour
python go2_visual.py --logdir Jul18_07-22-08_Go2_10skills_fromJul16_07-38-08
```
where `Jul18_07-22-08_Go2_10skills_fromJul16_07-38-08` is the logdir of the distilled model.
4. In T_run, run
```bash
cd parkour
python go2_run.py --logdir Jul18_07-22-08_Go2_10skills_fromJul16_07-38-08
```
where `Jul18_07-22-08_Go2_10skills_fromJul16_07-38-08` is the logdir of the distilled model.
Currently, the robot will not actually move its motors; you can still inspect the ROS topics. If you want to let the robot move, add the `--nodryrun` argument on the command line, but **be careful**.
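The two processes communicate over ROS 2: `go2_visual.py` publishes the depth embedding on the `/forward_depth_embedding` topic as a `Float32MultiArray`, and `go2_run.py` subscribes to it. To check that the visual pipeline is alive without starting the run node, a minimal subscriber sketch (the node name is arbitrary) is:
```python
# Minimal sketch: echo the depth-embedding topic published by go2_visual.py.
# The topic name and message type come from the onboard scripts in this commit.
import rclpy
from rclpy.node import Node
from std_msgs.msg import Float32MultiArray

class EmbeddingEcho(Node):
    def __init__(self):
        super().__init__("embedding_echo")
        self.create_subscription(
            Float32MultiArray, "/forward_depth_embedding", self.callback, 1)

    def callback(self, msg):
        self.get_logger().info(f"embedding length: {len(msg.data)}")

rclpy.init()
rclpy.spin(EmbeddingEcho())
```
Run it in a third sourced terminal; if nothing prints, check the camera connection and the T_visual terminal output before disabling dryrun.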


View File

@ -0,0 +1,183 @@
import rclpy
from rclpy.node import Node
from unitree_ros2_real import UnitreeRos2Real, get_euler_xyz
import os
import os.path as osp
import json
import time
from collections import OrderedDict
from copy import deepcopy
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from rsl_rl import modules
class ZeroActModel(torch.nn.Module):
def __init__(self, angle_tolerance= 0.15, delta= 0.2):
super().__init__()
self.angle_tolerance = angle_tolerance
self.delta = delta
def forward(self, dof_pos):
target = torch.zeros_like(dof_pos)
diff = dof_pos - target
diff_large_mask = torch.abs(diff) > self.angle_tolerance
target[diff_large_mask] = dof_pos[diff_large_mask] \
- self.delta * torch.sign(diff[diff_large_mask])
return target
class Go2Node(UnitreeRos2Real):
def __init__(self, *args, **kwargs):
super().__init__(*args, robot_class_name= "Go2", **kwargs)
def register_models(self, stand_model, task_model, task_policy):
self.stand_model = stand_model
self.task_model = task_model
self.task_policy = task_policy
self.use_stand_policy = True # Start with standing model
def start_main_loop_timer(self, duration):
self.main_loop_timer = self.create_timer(
duration, # in sec
self.main_loop,
)
def main_loop(self):
if (self.joy_stick_buffer.keys & self.WirelessButtons.L1) and self.use_stand_policy:
self.get_logger().info("L1 pressed, stop using stand policy")
self.use_stand_policy = False
if self.use_stand_policy:
obs = self._get_dof_pos_obs() # do not multiply by obs_scales["dof_pos"]
action = self.stand_model(obs)
if (action == 0).all():
self.get_logger().info("All actions are zero, it's time to switch to the policy", throttle_duration_sec= 1)
# else:
# print("maximum dof error: {:.3f}".format(action.abs().max().item(), end= "\r"))
self.send_action(action / self.action_scale)
else:
# start_time = time.monotonic()
obs = self.get_obs()
# obs_time = time.monotonic()
action = self.task_policy(obs)
# policy_time = time.monotonic()
self.send_action(action)
# self.send_action(self._get_dof_pos_obs() / self.action_scale)
# publish_time = time.monotonic()
# print(
# "obs_time: {:.5f}".format(obs_time - start_time),
# "policy_time: {:.5f}".format(policy_time - obs_time),
# "publish_time: {:.5f}".format(publish_time - policy_time),
# )
if (self.joy_stick_buffer.keys & self.WirelessButtons.Y):
self.get_logger().info("Y pressed, reset the policy")
self.task_model.reset()
@torch.inference_mode()
def main(args):
rclpy.init()
assert args.logdir is not None, "Please provide a logdir"
with open(osp.join(args.logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
# modify the config_dict if needed
config_dict["control"]["computer_clip_torque"] = True
duration = config_dict["sim"]["dt"] * config_dict["control"]["decimation"] # in sec
device = "cuda"
env_node = Go2Node(
"go2",
# low_cmd_topic= "low_cmd_dryrun", # for the dryrun safety
cfg= config_dict,
replace_obs_with_embeddings= ["forward_depth"],
model_device= device,
dryrun= not args.nodryrun,
)
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
num_actor_obs = env_node.num_obs,
num_critic_obs = env_node.num_privileged_obs,
num_actions= env_node.num_actions,
obs_segments= env_node.obs_segments,
privileged_obs_segments= env_node.privileged_obs_segments,
**config_dict["policy"],
)
# load the model with the latest checkpoint
model_names = [i for i in os.listdir(args.logdir) if i.startswith("model_")]
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
model.load_state_dict(state_dict["model_state_dict"])
model.eval()
model.to(device)
env_node.get_logger().info("Model loaded from: {}".format(osp.join(args.logdir, model_names[-1])))
env_node.get_logger().info("Control Duration: {} sec".format(duration))
env_node.get_logger().info("Motor Stiffness (kp): {}".format(env_node.p_gains))
env_node.get_logger().info("Motor Damping (kd): {}".format(env_node.d_gains))
# zero_act_model to start the safe standing
zero_act_model = ZeroActModel()
zero_act_model = torch.jit.script(zero_act_model)
# magically modify the model to use the components other than the forward depth encoders
memory_a = model.memory_a
mlp = model.actor
@torch.jit.script
def policy(obs: torch.Tensor):
rnn_embedding = memory_a(obs)
action = mlp(rnn_embedding)
return action
if hasattr(model, "replace_state_prob"):
# the case where lin_vel is estimated by the state estimator
memory_s = model.memory_s
estimator = model.state_estimator
rnn_policy = policy
@torch.jit.script
def policy(obs: torch.Tensor):
estimator_input = obs[:, 3:48]
memory_s_embedding = memory_s(estimator_input)
estimated_state = estimator(memory_s_embedding)
obs[:, :3] = estimated_state
return rnn_policy(obs)
env_node.register_models(
zero_act_model,
model,
policy,
)
env_node.start_ros_handlers()
if args.loop_mode == "while":
rclpy.spin_once(env_node, timeout_sec= 0.)
env_node.get_logger().info("Model and Policy are ready")
while rclpy.ok():
main_loop_time = time.monotonic()
env_node.main_loop()
rclpy.spin_once(env_node, timeout_sec= 0.)
# env_node.get_logger().info("loop time: {:f}".format((time.monotonic() - main_loop_time)))
time.sleep(max(0, duration - (time.monotonic() - main_loop_time)))
elif args.loop_mode == "timer":
env_node.start_main_loop_timer(duration)
rclpy.spin(env_node)
rclpy.shutdown()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--logdir", type= str, default= None, help= "The directory which contains the config.json and model_*.pt files")
parser.add_argument("--nodryrun", action= "store_true", default= False, help= "Disable dryrun mode")
parser.add_argument("--loop_mode", type= str, default= "timer",
choices= ["while", "timer"],
help= "Select which mode to run the main policy control iteration",
)
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,383 @@
import rclpy
from rclpy.node import Node
from unitree_ros2_real import UnitreeRos2Real
from std_msgs.msg import Float32MultiArray
from sensor_msgs.msg import Image, CameraInfo
import os
import os.path as osp
import json
import time
from collections import OrderedDict
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from rsl_rl import modules
import pyrealsense2 as rs
import ros2_numpy as rnp
@torch.no_grad()
def resize2d(img, size):
return (F.adaptive_avg_pool2d(Variable(img), size)).data
class VisualHandlerNode(Node):
""" A wapper class for the realsense camera """
def __init__(self,
cfg: dict,
cropping: list = [0, 0, 0, 0], # top, bottom, left, right
rs_resolution: tuple = (480, 270), # width, height for the realsense camera)
rs_fps: int= 30,
depth_input_topic= "/camera/forward_depth",
rgb_topic= "/camera/forward_rgb",
camera_info_topic= "/camera/camera_info",
enable_rgb= False,
forward_depth_embedding_topic= "/forward_depth_embedding",
):
super().__init__("forward_depth_embedding")
self.cfg = cfg
self.cropping = cropping
self.rs_resolution = rs_resolution
self.rs_fps = rs_fps
self.depth_input_topic = depth_input_topic
self.rgb_topic= rgb_topic
self.camera_info_topic = camera_info_topic
self.enable_rgb= enable_rgb
self.forward_depth_embedding_topic = forward_depth_embedding_topic
self.parse_args()
self.start_pipeline()
self.start_ros_handlers()
def parse_args(self):
self.output_resolution = self.cfg["sensor"]["forward_camera"].get(
"output_resolution",
self.cfg["sensor"]["forward_camera"]["resolution"],
)
depth_range = self.cfg["sensor"]["forward_camera"].get(
"depth_range",
[0.0, 3.0],
)
self.depth_range = (depth_range[0] * 1000, depth_range[1] * 1000) # [m] -> [mm]
def start_pipeline(self):
self.rs_pipeline = rs.pipeline()
self.rs_config = rs.config()
self.rs_config.enable_stream(
rs.stream.depth,
self.rs_resolution[0],
self.rs_resolution[1],
rs.format.z16,
self.rs_fps,
)
if self.enable_rgb:
self.rs_config.enable_stream(
rs.stream.color,
self.rs_resolution[0],
self.rs_resolution[1],
rs.format.rgb8,
self.rs_fps,
)
self.rs_profile = self.rs_pipeline.start(self.rs_config)
self.rs_align = rs.align(rs.stream.depth)
# build rs builtin filters
# self.rs_decimation_filter = rs.decimation_filter()
# self.rs_decimation_filter.set_option(rs.option.filter_magnitude, 6)
self.rs_hole_filling_filter = rs.hole_filling_filter()
self.rs_spatial_filter = rs.spatial_filter()
self.rs_spatial_filter.set_option(rs.option.filter_magnitude, 5)
self.rs_spatial_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
self.rs_spatial_filter.set_option(rs.option.filter_smooth_delta, 1)
self.rs_spatial_filter.set_option(rs.option.holes_fill, 4)
self.rs_temporal_filter = rs.temporal_filter()
self.rs_temporal_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
self.rs_temporal_filter.set_option(rs.option.filter_smooth_delta, 1)
# using a list of filters to define the filtering order
self.rs_filters = [
# self.rs_decimation_filter,
self.rs_hole_filling_filter,
self.rs_spatial_filter,
self.rs_temporal_filter,
]
if self.enable_rgb:
# get frame with longer waiting time to start the system
# I know what's going on, but when enabling rgb, this solves the problem.
rs_frame = self.rs_pipeline.wait_for_frames(int(
self.cfg["sensor"]["forward_camera"]["latency_range"][1] * 10000 # ms * 10
))
def start_ros_handlers(self):
self.depth_input_pub = self.create_publisher(
Image,
self.depth_input_topic,
1,
)
if self.enable_rgb:
self.rgb_pub = self.create_publisher(
Image,
self.rgb_topic,
1,
)
self.camera_info_pub = self.create_publisher(
CameraInfo,
self.camera_info_topic,
1,
)
# fill in critical info of processed camera info based on simulated data
# NOTE: simply because realsense's camera_info does not match our network input.
# It is easier to compute this way.
self.camera_info_msg = CameraInfo()
self.camera_info_msg.header.frame_id = "d435_sim_depth_link"
self.camera_info_msg.height = self.output_resolution[0]
self.camera_info_msg.width = self.output_resolution[1]
self.camera_info_msg.distortion_model = "plumb_bob"
self.camera_info_msg.d = [0., 0., 0., 0., 0.]
sim_raw_resolution = self.cfg["sensor"]["forward_camera"]["resolution"]
sim_cropping_h = self.cfg["sensor"]["forward_camera"]["crop_top_bottom"]
sim_cropping_w = self.cfg["sensor"]["forward_camera"]["crop_left_right"]
cropped_resolution = [ # (H, W)
sim_raw_resolution[0] - sum(sim_cropping_h),
sim_raw_resolution[1] - sum(sim_cropping_w),
]
network_input_resolution = self.cfg["sensor"]["forward_camera"]["output_resolution"]
x_fov = sum(self.cfg["sensor"]["forward_camera"]["horizontal_fov"]) / 2 / 180 * np.pi
fx = (sim_raw_resolution[1]) / (2 * np.tan(x_fov / 2))
fy = fx
fx = fx * network_input_resolution[1] / cropped_resolution[1]
fy = fy * network_input_resolution[0] / cropped_resolution[0]
cx = (sim_raw_resolution[1] / 2) - sim_cropping_w[0]
cy = (sim_raw_resolution[0] / 2) - sim_cropping_h[0]
cx = cx * network_input_resolution[1] / cropped_resolution[1]
cy = cy * network_input_resolution[0] / cropped_resolution[0]
self.camera_info_msg.k = [
fx, 0., cx,
0., fy, cy,
0., 0., 1.,
]
self.camera_info_msg.r = [1., 0., 0., 0., 1., 0., 0., 0., 1.]
self.camera_info_msg.p = [
fx, 0., cx, 0.,
0., fy, cy, 0.,
0., 0., 1., 0.,
]
self.camera_info_msg.binning_x = 0
self.camera_info_msg.binning_y = 0
self.camera_info_msg.roi.do_rectify = False
self.create_timer(
self.cfg["sensor"]["forward_camera"]["refresh_duration"],
self.publish_camera_info_callback,
)
self.forward_depth_embedding_pub = self.create_publisher(
Float32MultiArray,
self.forward_depth_embedding_topic,
1,
)
self.get_logger().info("ros handlers started")
def publish_camera_info_callback(self):
self.camera_info_msg.header.stamp = self.get_clock().now().to_msg()
self.get_logger().info("camera info published", once= True)
self.camera_info_pub.publish(self.camera_info_msg)
def get_depth_frame(self):
# read from pyrealsense2, preprocess and write the model embedding to the buffer
rs_frame = self.rs_pipeline.wait_for_frames(int(
self.cfg["sensor"]["forward_camera"]["latency_range"][1] * 1000 # ms
))
if self.enable_rgb:
rs_frame = self.rs_align.process(rs_frame)
depth_frame = rs_frame.get_depth_frame()
if not depth_frame:
self.get_logger().error("No depth frame", throttle_duration_sec= 1)
return
color_frame = rs_frame.get_color_frame()
if color_frame:
rgb_image_np = np.asanyarray(color_frame.get_data())
rgb_image_np = np.rot90(rgb_image_np, k= 2) # since the camera is inverted
rgb_image_np = rgb_image_np[
self.cropping[0]: -self.cropping[1]-1,
self.cropping[2]: -self.cropping[3]-1,
]
rgb_image_msg = rnp.msgify(Image, rgb_image_np, encoding= "rgb8")
rgb_image_msg.header.stamp = self.get_clock().now().to_msg()
rgb_image_msg.header.frame_id = "d435_sim_depth_link"
self.rgb_pub.publish(rgb_image_msg)
self.get_logger().info("rgb image published", once= True)
        # apply realsense filters
for rs_filter in self.rs_filters:
depth_frame = rs_filter.process(depth_frame)
depth_image_np = np.asanyarray(depth_frame.get_data())
# rotate 180 degree because d435i on h1 head is mounted inverted
depth_image_np = np.rot90(depth_image_np, k= 2) # k = 2 for rotate 90 degree twice
depth_image_pyt = torch.from_numpy(depth_image_np.astype(np.float32)).unsqueeze(0).unsqueeze(0)
# apply torch filters
depth_image_pyt = depth_image_pyt[:, :,
self.cropping[0]: -self.cropping[1]-1,
self.cropping[2]: -self.cropping[3]-1,
]
depth_image_pyt = torch.clip(depth_image_pyt, self.depth_range[0], self.depth_range[1]) / (self.depth_range[1] - self.depth_range[0])
depth_image_pyt = resize2d(depth_image_pyt, self.output_resolution)
# publish the depth image input to ros topic
self.get_logger().info("depth range: {}-{}".format(*self.depth_range), once= True)
depth_input_data = (
depth_image_pyt.detach().cpu().numpy() * (self.depth_range[1] - self.depth_range[0]) + self.depth_range[0]
).astype(np.uint16)[0, 0] # (h, w) unit [mm]
# DEBUG: centering the depth image
# depth_input_data = depth_input_data.copy()
# depth_input_data[int(depth_input_data.shape[0] / 2), :] = 0
# depth_input_data[:, int(depth_input_data.shape[1] / 2)] = 0
depth_input_msg = rnp.msgify(Image, depth_input_data, encoding= "16UC1")
depth_input_msg.header.stamp = self.get_clock().now().to_msg()
depth_input_msg.header.frame_id = "d435_sim_depth_link"
self.depth_input_pub.publish(depth_input_msg)
self.get_logger().info("depth input published", once= True)
return depth_image_pyt
def publish_depth_embedding(self, embedding):
msg = Float32MultiArray()
msg.data = embedding.squeeze().detach().cpu().numpy().tolist()
self.forward_depth_embedding_pub.publish(msg)
self.get_logger().info("depth embedding published", once= True)
def register_models(self, visual_encoder):
self.visual_encoder = visual_encoder
def start_main_loop_timer(self, duration):
self.create_timer(
duration,
self.main_loop,
)
def main_loop(self):
depth_image_pyt = self.get_depth_frame()
if depth_image_pyt is not None:
embedding = self.visual_encoder(depth_image_pyt)
self.publish_depth_embedding(embedding)
else:
self.get_logger().warn("One frame of depth embedding if not acquired")
@torch.inference_mode()
def main(args):
rclpy.init()
assert args.logdir is not None, "Please provide a logdir"
with open(osp.join(args.logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
device = "cpu"
duration = config_dict["sensor"]["forward_camera"]["refresh_duration"] # in sec
visual_node = VisualHandlerNode(
cfg= json.load(open(osp.join(args.logdir, "config.json"), "r")),
cropping= [args.crop_top, args.crop_bottom, args.crop_left, args.crop_right],
rs_resolution= (args.width, args.height),
rs_fps= args.fps,
enable_rgb= args.rgb,
)
env_node = UnitreeRos2Real(
"visual_h1",
low_cmd_topic= "low_cmd_dryrun", # This node should not publish any command at all
cfg= config_dict,
model_device= device,
robot_class_name= "Go2",
dryrun= True, # The robot node in this process should not run at all
)
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
num_actor_obs = env_node.num_obs,
num_critic_obs = env_node.num_privileged_obs,
num_actions= env_node.num_actions,
obs_segments= env_node.obs_segments,
privileged_obs_segments= env_node.privileged_obs_segments,
**config_dict["policy"],
)
# load the model with the latest checkpoint
model_names = [i for i in os.listdir(args.logdir) if i.startswith("model_")]
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
model.load_state_dict(state_dict["model_state_dict"])
model.to(device)
model = model.encoders[0] # the first encoder is the visual encoder
env_node.destroy_node()
visual_node.get_logger().info("Embedding send duration: {:.2f} sec".format(duration))
visual_node.register_models(model)
if args.loop_mode == "while":
rclpy.spin_once(visual_node, timeout_sec= 0.)
while rclpy.ok():
main_loop_time = time.monotonic()
visual_node.main_loop()
rclpy.spin_once(visual_node, timeout_sec= 0.)
time.sleep(max(0, duration - (time.monotonic() - main_loop_time)))
elif args.loop_mode == "timer":
visual_node.start_main_loop_timer(duration)
rclpy.spin(visual_node)
visual_node.destroy_node()
rclpy.shutdown()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--logdir", type= str, default= None, help= "The directory which contains the config.json and model_*.pt files")
parser.add_argument("--height",
type= int,
default= 480,
help= "The height of the realsense image",
)
parser.add_argument("--width",
type= int,
default= 640,
help= "The width of the realsense image",
)
parser.add_argument("--fps",
type= int,
default= 30,
help= "The fps request to the rs pipeline",
)
parser.add_argument("--crop_left",
type= int,
default= 28,
help= "num of pixel to crop in the original pyrealsense readings."
)
parser.add_argument("--crop_right",
type= int,
default= 36,
help= "num of pixel to crop in the original pyrealsense readings."
)
parser.add_argument("--crop_top",
type= int,
default= 48,
help= "num of pixel to crop in the original pyrealsense readings."
)
parser.add_argument("--crop_bottom",
type= int,
default= 0,
help= "num of pixel to crop in the original pyrealsense readings."
)
parser.add_argument("--rgb",
action= "store_true",
default= False,
help= "Set to enable rgb visualization",
)
parser.add_argument("--loop_mode", type= str, default= "timer",
choices= ["while", "timer"],
help= "Select which mode to run the main policy control iteration",
)
args = parser.parse_args()
main(args)
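The `CameraInfo` published by `start_ros_handlers` is recomputed from the simulated camera config rather than read from the RealSense driver: `fx` follows the pinhole relation `fx = W / (2 * tan(FOV_x / 2))` at the raw simulated resolution, and `fx`, `fy`, `cx`, `cy` are then rescaled from the cropped resolution to the network input resolution. A standalone sketch with placeholder numbers (not the shipped config values):

```python
# Standalone sketch of the intrinsics computation above; all numbers are
# placeholders, not values taken from the shipped config.json.
import numpy as np

raw_res = (48, 64)                 # simulated camera (H, W), assumed
crop_tb, crop_lr = (8, 0), (4, 4)  # crop_top_bottom, crop_left_right, assumed
net_res = (48, 64)                 # network input resolution (H, W), assumed
x_fov = np.deg2rad(87.0)           # horizontal FOV in radians, assumed

cropped = (raw_res[0] - sum(crop_tb), raw_res[1] - sum(crop_lr))
fx = raw_res[1] / (2 * np.tan(x_fov / 2))
fy = fx
fx *= net_res[1] / cropped[1]
fy *= net_res[0] / cropped[0]
cx = (raw_res[1] / 2 - crop_lr[0]) * net_res[1] / cropped[1]
cy = (raw_res[0] / 2 - crop_tb[0]) * net_res[0] / cropped[0]
print(f"fx={fx:.2f} fy={fy:.2f} cx={cx:.2f} cy={cy:.2f}")
```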

View File

@ -0,0 +1,633 @@
import os, sys
import rclpy
from rclpy.node import Node
from unitree_go.msg import (
WirelessController,
LowState,
# MotorState,
# IMUState,
LowCmd,
# MotorCmd,
)
from std_msgs.msg import Float32MultiArray
if os.uname().machine in ["x86_64", "amd64"]:
sys.path.append(os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"x86",
))
elif os.uname().machine == "aarch64":
sys.path.append(os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"aarch64",
))
from crc_module import get_crc
from multiprocessing import Process
from collections import OrderedDict
import numpy as np
import torch
@torch.jit.script
def quat_from_euler_xyz(roll, pitch, yaw):
cy = torch.cos(yaw * 0.5)
sy = torch.sin(yaw * 0.5)
cr = torch.cos(roll * 0.5)
sr = torch.sin(roll * 0.5)
cp = torch.cos(pitch * 0.5)
sp = torch.sin(pitch * 0.5)
qw = cy * cr * cp + sy * sr * sp
qx = cy * sr * cp - sy * cr * sp
qy = cy * cr * sp + sy * sr * cp
qz = sy * cr * cp - cy * sr * sp
return torch.stack([qx, qy, qz, qw], dim=-1)
@torch.jit.script
def quat_rotate_inverse(q, v):
""" q must be in x, y, z, w order """
shape = q.shape
q_w = q[:, -1]
q_vec = q[:, :3]
a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * \
torch.bmm(q_vec.view(shape[0], 1, 3), v.view(
shape[0], 3, 1)).squeeze(-1) * 2.0
return a - b + c
@torch.jit.script
def copysign(a, b):
# type: (float, Tensor) -> Tensor
a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])
return torch.abs(a) * torch.sign(b)
@torch.jit.script
def get_euler_xyz(q):
qx, qy, qz, qw = 0, 1, 2, 3
# roll (x-axis rotation)
sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz])
cosr_cosp = q[:, qw] * q[:, qw] - q[:, qx] * \
q[:, qx] - q[:, qy] * q[:, qy] + q[:, qz] * q[:, qz]
roll = torch.atan2(sinr_cosp, cosr_cosp)
# pitch (y-axis rotation)
sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx])
pitch = torch.where(torch.abs(sinp) >= 1, copysign(
np.pi / 2.0, sinp), torch.asin(sinp))
# yaw (z-axis rotation)
siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy])
cosy_cosp = q[:, qw] * q[:, qw] + q[:, qx] * \
q[:, qx] - q[:, qy] * q[:, qy] - q[:, qz] * q[:, qz]
yaw = torch.atan2(siny_cosp, cosy_cosp)
return roll % (2*np.pi), pitch % (2*np.pi), yaw % (2*np.pi)
class RobotCfgs:
class H1:
pass
class Go2:
NUM_DOF = 12
NUM_ACTIONS = 12
dof_map = [ # from isaacgym simulation joint order to real robot joint order
3, 4, 5,
0, 1, 2,
9, 10, 11,
6, 7, 8,
]
dof_names = [ # NOTE: order matters. This list is the order in simulation.
"FL_hip_joint",
"FL_thigh_joint",
"FL_calf_joint",
"FR_hip_joint",
"FR_thigh_joint",
"FR_calf_joint",
"RL_hip_joint",
"RL_thigh_joint",
"RL_calf_joint",
"RR_hip_joint",
"RR_thigh_joint",
"RR_calf_joint",
]
dof_signs = [1.] * 12
joint_limits_high = torch.tensor([
1.0472, 3.4907, -0.83776,
1.0472, 3.4907, -0.83776,
1.0472, 4.5379, -0.83776,
1.0472, 4.5379, -0.83776,
], device= "cpu", dtype= torch.float32)
joint_limits_low = torch.tensor([
-1.0472, -1.5708, -2.7227,
-1.0472, -1.5708, -2.7227,
-1.0472, -0.5236, -2.7227,
-1.0472, -0.5236, -2.7227,
], device= "cpu", dtype= torch.float32)
torque_limits = torch.tensor([ # from urdf and in simulation order
25, 40, 40,
25, 40, 40,
25, 40, 40,
25, 40, 40,
], device= "cpu", dtype= torch.float32)
turn_on_motor_mode = [0x01] * 12
class UnitreeRos2Real(Node):
""" A proxy implementation of the real H1 robot. """
class WirelessButtons:
R1 = 0b00000001 # 1
L1 = 0b00000010 # 2
start = 0b00000100 # 4
select = 0b00001000 # 8
R2 = 0b00010000 # 16
L2 = 0b00100000 # 32
F1 = 0b01000000 # 64
F2 = 0b10000000 # 128
A = 0b100000000 # 256
B = 0b1000000000 # 512
X = 0b10000000000 # 1024
Y = 0b100000000000 # 2048
up = 0b1000000000000 # 4096
right = 0b10000000000000 # 8192
down = 0b100000000000000 # 16384
left = 0b1000000000000000 # 32768
def __init__(self,
robot_namespace= None,
low_state_topic= "/lowstate",
low_cmd_topic= "/lowcmd",
joy_stick_topic= "/wirelesscontroller",
forward_depth_topic= None, # if None and still need access, set to str "pyrealsense"
forward_depth_embedding_topic= "/forward_depth_embedding",
cfg= dict(),
lin_vel_deadband= 0.1,
ang_vel_deadband= 0.1,
cmd_px_range= [0.4, 1.0], # check joy_stick_callback (p for positive, n for negative)
cmd_nx_range= [0.4, 0.8], # check joy_stick_callback (p for positive, n for negative)
cmd_py_range= [0.4, 0.8], # check joy_stick_callback (p for positive, n for negative)
cmd_ny_range= [0.4, 0.8], # check joy_stick_callback (p for positive, n for negative)
cmd_pyaw_range= [0.4, 1.6], # check joy_stick_callback (p for positive, n for negative)
cmd_nyaw_range= [0.4, 1.6], # check joy_stick_callback (p for positive, n for negative)
        replace_obs_with_embeddings= [], # a list of strings, e.g. ["forward_depth"] then the corresponding obs will be processed by _get_forward_depth_embedding_obs()
move_by_wireless_remote= True, # if True, the robot will be controlled by a wireless remote
model_device= "cpu",
dof_pos_protect_ratio= 1.1, # if the dof_pos is out of the range of this ratio, the process will shutdown.
robot_class_name= "H1",
dryrun= True, # if True, the robot will not send commands to the real robot
):
super().__init__("unitree_ros2_real")
self.NUM_DOF = getattr(RobotCfgs, robot_class_name).NUM_DOF
self.NUM_ACTIONS = getattr(RobotCfgs, robot_class_name).NUM_ACTIONS
self.robot_namespace = robot_namespace
self.low_state_topic = low_state_topic
        # Generate a unique cmd topic so that the low_cmd will not be sent to the robot's motors.
self.low_cmd_topic = low_cmd_topic if not dryrun else low_cmd_topic + "_dryrun_" + str(np.random.randint(0, 65535))
self.joy_stick_topic = joy_stick_topic
self.forward_depth_topic = forward_depth_topic
self.forward_depth_embedding_topic = forward_depth_embedding_topic
self.cfg = cfg
self.lin_vel_deadband = lin_vel_deadband
self.ang_vel_deadband = ang_vel_deadband
self.cmd_px_range = cmd_px_range
self.cmd_nx_range = cmd_nx_range
self.cmd_py_range = cmd_py_range
self.cmd_ny_range = cmd_ny_range
self.cmd_pyaw_range = cmd_pyaw_range
self.cmd_nyaw_range = cmd_nyaw_range
self.replace_obs_with_embeddings = replace_obs_with_embeddings
self.move_by_wireless_remote = move_by_wireless_remote
self.model_device = model_device
self.dof_pos_protect_ratio = dof_pos_protect_ratio
self.robot_class_name = robot_class_name
self.dryrun = dryrun
self.dof_map = getattr(RobotCfgs, robot_class_name).dof_map
self.dof_names = getattr(RobotCfgs, robot_class_name).dof_names
self.dof_signs = getattr(RobotCfgs, robot_class_name).dof_signs
self.turn_on_motor_mode = getattr(RobotCfgs, robot_class_name).turn_on_motor_mode
self.parse_config()
def parse_config(self):
""" parse, set attributes from config dict, initialize buffers to speed up the computation """
self.up_axis_idx = 2 # 2 for z, 1 for y -> adapt gravity accordingly
self.gravity_vec = torch.zeros((1, 3), device= self.model_device, dtype= torch.float32)
self.gravity_vec[:, self.up_axis_idx] = -1
# observations
self.clip_obs = self.cfg["normalization"]["clip_observations"]
self.obs_scales = self.cfg["normalization"]["obs_scales"]
for k, v in self.obs_scales.items():
if isinstance(v, (list, tuple)):
self.obs_scales[k] = torch.tensor(v, device= self.model_device, dtype= torch.float32)
# check whether there are embeddings in obs_components and launch encoder process later
if len(self.replace_obs_with_embeddings) > 0:
for comp in self.replace_obs_with_embeddings:
self.get_logger().warn(f"{comp} will be replaced with its embedding when get_obs, don't forget to launch the corresponding process before running the policy.")
self.obs_segments = self.get_obs_segment_from_components(self.cfg["env"]["obs_components"])
self.num_obs = self.get_num_obs_from_components(self.cfg["env"]["obs_components"])
if "privileged_obs_components" in self.cfg["env"].keys():
self.privileged_obs_segments = self.get_obs_segment_from_components(self.cfg["env"]["privileged_obs_components"])
self.num_privileged_obs = self.get_num_obs_from_components(self.cfg["env"]["privileged_obs_components"])
for obs_component in self.cfg["env"]["obs_components"]:
if "orientation_cmds" in obs_component:
self.roll_pitch_yaw_cmd = torch.zeros(1, 3, device= self.model_device, dtype= torch.float32)
# controls
self.control_type = self.cfg["control"]["control_type"]
if not (self.control_type == "P"):
raise NotImplementedError("Only position control is supported for now.")
self.p_gains = []
for i in range(self.NUM_DOF):
name = self.dof_names[i] # set p_gains in simulation order
for k, v in self.cfg["control"]["stiffness"].items():
if k in name:
self.p_gains.append(v)
break # only one match
self.p_gains = torch.tensor(self.p_gains, device= self.model_device, dtype= torch.float32)
self.d_gains = []
for i in range(self.NUM_DOF):
name = self.dof_names[i] # set d_gains in simulation order
for k, v in self.cfg["control"]["damping"].items():
if k in name:
self.d_gains.append(v)
break
self.d_gains = torch.tensor(self.d_gains, device= self.model_device, dtype= torch.float32)
self.default_dof_pos = torch.zeros(self.NUM_DOF, device= self.model_device, dtype= torch.float32)
self.dof_pos_ = torch.empty(1, self.NUM_DOF, device= self.model_device, dtype= torch.float32)
self.dof_vel_ = torch.empty(1, self.NUM_DOF, device= self.model_device, dtype= torch.float32)
for i in range(self.NUM_DOF):
name = self.dof_names[i]
default_joint_angle = self.cfg["init_state"]["default_joint_angles"][name]
# in simulation order.
self.default_dof_pos[i] = default_joint_angle
self.computer_clip_torque = self.cfg["control"].get("computer_clip_torque", True)
self.get_logger().info("Computer Clip Torque (onboard) is " + str(self.computer_clip_torque))
self.torque_limits = getattr(RobotCfgs, self.robot_class_name).torque_limits.to(self.model_device)
if self.computer_clip_torque:
assert hasattr(self, "torque_limits") and (len(self.torque_limits) == self.NUM_DOF), f"torque_limits must be set with the length of {self.NUM_DOF} if computer_clip_torque is True"
self.get_logger().info("[Env] torque limit: " + ",".join("{:.1f}".format(x) for x in self.torque_limits))
# actions
self.num_actions = self.NUM_ACTIONS
self.action_scale = self.cfg["control"]["action_scale"]
self.get_logger().info("[Env] action scale: {:.1f}".format(self.action_scale))
self.clip_actions = self.cfg["normalization"]["clip_actions"]
if self.cfg["normalization"].get("clip_actions_method", None) == "hard":
self.get_logger().info("clip_actions_method with hard mode")
self.get_logger().info("clip_actions_high: " + str(self.cfg["normalization"]["clip_actions_high"]))
self.get_logger().info("clip_actions_low: " + str(self.cfg["normalization"]["clip_actions_low"]))
self.clip_actions_method = "hard"
self.clip_actions_low = torch.tensor(self.cfg["normalization"]["clip_actions_low"], device= self.model_device, dtype= torch.float32)
self.clip_actions_high = torch.tensor(self.cfg["normalization"]["clip_actions_high"], device= self.model_device, dtype= torch.float32)
else:
self.get_logger().info("clip_actions_method is " + str(self.cfg["normalization"].get("clip_actions_method", None)))
self.actions = torch.zeros(self.NUM_ACTIONS, device= self.model_device, dtype= torch.float32)
# hardware related, in simulation order
self.joint_limits_high = getattr(RobotCfgs, self.robot_class_name).joint_limits_high.to(self.model_device)
self.joint_limits_low = getattr(RobotCfgs, self.robot_class_name).joint_limits_low.to(self.model_device)
joint_pos_mid = (self.joint_limits_high + self.joint_limits_low) / 2
joint_pos_range = (self.joint_limits_high - self.joint_limits_low) / 2
self.joint_pos_protect_high = joint_pos_mid + joint_pos_range * self.dof_pos_protect_ratio
self.joint_pos_protect_low = joint_pos_mid - joint_pos_range * self.dof_pos_protect_ratio
def start_ros_handlers(self):
""" after initializing the env and policy, register ros related callbacks and topics
"""
# ROS publishers
self.low_cmd_pub = self.create_publisher(
LowCmd,
self.low_cmd_topic,
1
)
self.low_cmd_buffer = LowCmd()
# ROS subscribers
self.low_state_sub = self.create_subscription(
LowState,
self.low_state_topic,
self._low_state_callback,
1
)
self.joy_stick_sub = self.create_subscription(
WirelessController,
self.joy_stick_topic,
self._joy_stick_callback,
1
)
if self.forward_depth_topic is not None:
self.forward_camera_sub = self.create_subscription(
Image,
self.forward_depth_topic,
self._forward_depth_callback,
1
)
if self.forward_depth_embedding_topic is not None and "forward_depth" in self.replace_obs_with_embeddings:
self.forward_depth_embedding_sub = self.create_subscription(
Float32MultiArray,
self.forward_depth_embedding_topic,
self._forward_depth_embedding_callback,
1,
)
self.get_logger().info("ROS handlers started, waiting to recieve critical low state and wireless controller messages.")
if not self.dryrun:
self.get_logger().warn(f"You are running the code in no-dryrun mode and publishing to '{self.low_cmd_topic}', Please keep safe.")
else:
self.get_logger().warn(f"You are publishing low cmd to '{self.low_cmd_topic}' because of dryrun mode, Please check and be safe.")
while rclpy.ok():
rclpy.spin_once(self)
if hasattr(self, "low_state_buffer") and hasattr(self, "joy_stick_buffer"):
break
self.get_logger().info("Low state message received, the robot is ready to go.")
""" ROS callbacks and handlers that update the buffer """
def _low_state_callback(self, msg):
""" store and handle proprioception data """
self.low_state_buffer = msg # keep the latest low state
# refresh dof_pos and dof_vel
for sim_idx in range(self.NUM_DOF):
real_idx = self.dof_map[sim_idx]
self.dof_pos_[0, sim_idx] = self.low_state_buffer.motor_state[real_idx].q * self.dof_signs[sim_idx]
for sim_idx in range(self.NUM_DOF):
real_idx = self.dof_map[sim_idx]
self.dof_vel_[0, sim_idx] = self.low_state_buffer.motor_state[real_idx].dq * self.dof_signs[sim_idx]
# automatic safety check
for sim_idx in range(self.NUM_DOF):
real_idx = self.dof_map[sim_idx]
if self.dof_pos_[0, sim_idx] > self.joint_pos_protect_high[sim_idx] or \
self.dof_pos_[0, sim_idx] < self.joint_pos_protect_low[sim_idx]:
self.get_logger().error(f"Joint {sim_idx}(sim), {real_idx}(real) position out of range at {self.low_state_buffer.motor_state[real_idx].q}")
self.get_logger().error("The motors and this process shuts down.")
self._turn_off_motors()
raise SystemExit()
def _joy_stick_callback(self, msg):
self.joy_stick_buffer = msg
if self.move_by_wireless_remote:
# left-y for forward/backward
ly = msg.ly
if ly > self.lin_vel_deadband:
vx = (ly - self.lin_vel_deadband) / (1 - self.lin_vel_deadband) # (0, 1)
vx = vx * (self.cmd_px_range[1] - self.cmd_px_range[0]) + self.cmd_px_range[0]
elif ly < -self.lin_vel_deadband:
vx = (ly + self.lin_vel_deadband) / (1 - self.lin_vel_deadband) # (-1, 0)
vx = vx * (self.cmd_nx_range[1] - self.cmd_nx_range[0]) - self.cmd_nx_range[0]
else:
vx = 0
# left-x for turning left/right
lx = -msg.lx
if lx > self.ang_vel_deadband:
yaw = (lx - self.ang_vel_deadband) / (1 - self.ang_vel_deadband)
yaw = yaw * (self.cmd_pyaw_range[1] - self.cmd_pyaw_range[0]) + self.cmd_pyaw_range[0]
elif lx < -self.ang_vel_deadband:
yaw = (lx + self.ang_vel_deadband) / (1 - self.ang_vel_deadband)
yaw = yaw * (self.cmd_nyaw_range[1] - self.cmd_nyaw_range[0]) - self.cmd_nyaw_range[0]
else:
yaw = 0
# right-x for side moving left/right
rx = -msg.rx
if rx > self.lin_vel_deadband:
vy = (rx - self.lin_vel_deadband) / (1 - self.lin_vel_deadband)
vy = vy * (self.cmd_py_range[1] - self.cmd_py_range[0]) + self.cmd_py_range[0]
elif rx < -self.lin_vel_deadband:
vy = (rx + self.lin_vel_deadband) / (1 - self.lin_vel_deadband)
vy = vy * (self.cmd_ny_range[1] - self.cmd_ny_range[0]) - self.cmd_ny_range[0]
else:
vy = 0
self.xyyaw_command = torch.tensor([vx, vy, yaw], device= self.model_device, dtype= torch.float32)
# refer to Unitree Remote Control data structure, msg.keys is a bit mask
# 00000000 00000001 means pressing the 0-th button (R1)
# 00000000 00000010 means pressing the 1-th button (L1)
# 10000000 00000000 means pressing the 15-th button (left)
if (msg.keys & self.WirelessButtons.R2) or (msg.keys & self.WirelessButtons.L2): # R2 or L2 is pressed
self.get_logger().warn("R2 or L2 is pressed, the motors and this process shuts down.")
self._turn_off_motors()
raise SystemExit()
# roll-pitch target
if hasattr(self, "roll_pitch_yaw_cmd"):
if (msg.keys & self.WirelessButtons.up):
self.roll_pitch_yaw_cmd[0, 1] += 0.1
self.get_logger().info("Pitch Command: " + str(self.roll_pitch_yaw_cmd))
if (msg.keys & self.WirelessButtons.down):
self.roll_pitch_yaw_cmd[0, 1] -= 0.1
self.get_logger().info("Pitch Command: " + str(self.roll_pitch_yaw_cmd))
if (msg.keys & self.WirelessButtons.left):
self.roll_pitch_yaw_cmd[0, 0] -= 0.1
self.get_logger().info("Roll Command: " + str(self.roll_pitch_yaw_cmd))
if (msg.keys & self.WirelessButtons.right):
self.roll_pitch_yaw_cmd[0, 0] += 0.1
self.get_logger().info("Roll Command: " + str(self.roll_pitch_yaw_cmd))
def _forward_depth_callback(self, msg):
""" store and handle depth camera data """
pass
def _forward_depth_embedding_callback(self, msg):
self.forward_depth_embedding_buffer = torch.tensor(msg.data, device= self.model_device, dtype= torch.float32).view(1, -1)
""" Done: ROS callbacks and handlers that update the buffer """
""" refresh observation buffer and corresponding sub-functions """
def _get_lin_vel_obs(self):
return torch.zeros(1, 3, device= self.model_device, dtype= torch.float32)
def _get_ang_vel_obs(self):
buffer = torch.from_numpy(self.low_state_buffer.imu_state.gyroscope).unsqueeze(0)
return buffer
def _get_projected_gravity_obs(self):
quat_xyzw = torch.tensor([
self.low_state_buffer.imu_state.quaternion[1],
self.low_state_buffer.imu_state.quaternion[2],
self.low_state_buffer.imu_state.quaternion[3],
self.low_state_buffer.imu_state.quaternion[0],
], device= self.model_device, dtype= torch.float32).unsqueeze(0)
return quat_rotate_inverse(
quat_xyzw,
self.gravity_vec,
)
def _get_commands_obs(self):
return self.xyyaw_command.unsqueeze(0) # (1, 3)
def _get_dof_pos_obs(self):
return self.dof_pos_ - self.default_dof_pos.unsqueeze(0)
def _get_dof_vel_obs(self):
return self.dof_vel_
def _get_last_actions_obs(self):
return self.actions
def _get_forward_depth_embedding_obs(self):
return self.forward_depth_embedding_buffer
def _get_forward_depth_obs(self):
raise NotImplementedError()
def _get_orientation_cmds_obs(self):
return quat_rotate_inverse(
quat_from_euler_xyz(self.roll_pitch_yaw_cmd[:, 0], self.roll_pitch_yaw_cmd[:, 1], self.roll_pitch_yaw_cmd[:, 2]),
self.gravity_vec,
)
def get_num_obs_from_components(self, components):
obs_segments = self.get_obs_segment_from_components(components)
num_obs = 0
for k, v in obs_segments.items():
num_obs += np.prod(v)
return num_obs
def get_obs_segment_from_components(self, components):
""" Observation segment is defined as a list of lists/ints defining the tensor shape with
corresponding order.
"""
segments = OrderedDict()
if "lin_vel" in components:
print("Warning: lin_vel is not typically available or accurate enough on the real robot. Will return zeros.")
segments["lin_vel"] = (3,)
if "ang_vel" in components:
segments["ang_vel"] = (3,)
if "projected_gravity" in components:
segments["projected_gravity"] = (3,)
if "commands" in components:
segments["commands"] = (3,)
if "dof_pos" in components:
segments["dof_pos"] = (self.NUM_DOF,)
if "dof_vel" in components:
segments["dof_vel"] = (self.NUM_DOF,)
if "last_actions" in components:
segments["last_actions"] = (self.NUM_ACTIONS,)
if "height_measurements" in components:
print("Warning: height_measurements is not typically available on the real robot.")
segments["height_measurements"] = (1, len(self.cfg["terrain"]["measured_points_x"]), len(self.cfg["terrain"]["measured_points_y"]))
if "forward_depth" in components:
if "output_resolution" in self.cfg["sensor"]["forward_camera"]:
segments["forward_depth"] = (1, *self.cfg["sensor"]["forward_camera"]["output_resolution"])
else:
segments["forward_depth"] = (1, *self.cfg["sensor"]["forward_camera"]["resolution"])
if "base_pose" in components:
segments["base_pose"] = (6,) # xyz + rpy
if "robot_config" in components:
""" Related to robot_config_buffer attribute, Be careful to change. """
# robot shape friction
# CoM (Center of Mass) x, y, z
# base mass (payload)
# motor strength for each joint
print("Warning: height_measurements is not typically available on the real robot.")
segments["robot_config"] = (1 + 3 + 1 + self.NUM_ACTIONS,)
""" NOTE: The following components are not directly set in legged_robot.py.
Please check the order or extend the class implementation if needed.
"""
if "joints_target" in components:
# o[0] for target value, o[1] for whether the target should be tracked (1) or not (0)
segments["joints_target"] = (2, self.NUM_DOF)
if "projected_gravity_target" in components:
# projected_gravity for which the robot should track the target
# last value as a mask of whether to follow the target or not
segments["projected_gravity_target"] = (3+1,)
if "orientation_cmds" in components:
segments["orientation_cmds"] = (3,)
return segments
def get_obs(self, obs_segments= None):
""" Extract from the buffers and build the 1d observation tensor
Each get ... obs function does not do the obs_scale multiplication.
NOTE: obs_buffer has the batch dimension, whose size is 1.
"""
if obs_segments is None:
obs_segments = self.obs_segments
obs_buffer = []
for k, v in obs_segments.items():
if k in self.replace_obs_with_embeddings:
obs_component_value = getattr(self, "_get_" + k + "_embedding_obs")()
else:
obs_component_value = getattr(self, "_get_" + k + "_obs")() * self.obs_scales.get(k, 1.0)
obs_buffer.append(obs_component_value)
obs_buffer = torch.cat(obs_buffer, dim=1)
obs_buffer = torch.clamp(obs_buffer, -self.clip_obs, self.clip_obs)
return obs_buffer
""" Done: refresh observation buffer and corresponding sub-functions """
""" Control related functions """
def clip_action_before_scale(self, action):
action = torch.clip(action, -self.clip_actions, self.clip_actions)
if getattr(self, "clip_actions_method", None) == "hard":
action = torch.clip(action, self.clip_actions_low, self.clip_actions_high)
return action
def clip_by_torque_limit(self, actions_scaled):
""" Different from simulation, we reverse the process and clip the actions directly,
so that the PD controller runs in robot but not our script.
"""
control_type = self.cfg["control"]["control_type"]
if control_type == "P":
p_limits_low = (-self.torque_limits) + self.d_gains*self.dof_vel_
p_limits_high = (self.torque_limits) + self.d_gains*self.dof_vel_
actions_low = (p_limits_low/self.p_gains) - self.default_dof_pos + self.dof_pos_
actions_high = (p_limits_high/self.p_gains) - self.default_dof_pos + self.dof_pos_
else:
raise NotImplementedError
return torch.clip(actions_scaled, actions_low, actions_high)
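
For clarity, a hedged sketch of the derivation behind these bounds, assuming the standard position-control law `tau = p_gains * (action_scaled + default_dof_pos - dof_pos) - d_gains * dof_vel` used in simulation: clipping `tau` to the torque limits and solving for `action_scaled` yields exactly the bounds computed above. The values below are made up.

```python
import torch

# made-up per-joint values, just to exercise the formula
torque_limits = torch.tensor([33.5])
p_gains, d_gains = torch.tensor([40.0]), torch.tensor([0.5])
dof_pos, dof_vel = torch.tensor([0.1]), torch.tensor([2.0])
default_dof_pos = torch.tensor([0.0])

# tau = p_gains * (a + default_dof_pos - dof_pos) - d_gains * dof_vel, constrained to [-lim, +lim]
a_low = ((-torque_limits + d_gains * dof_vel) / p_gains) - default_dof_pos + dof_pos
a_high = ((torque_limits + d_gains * dof_vel) / p_gains) - default_dof_pos + dof_pos
# any action_scaled clipped into [a_low, a_high] keeps the commanded torque within the limits
```
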
def send_action(self, actions):
""" Send the action to the robot motors, which does the preprocessing
just like env.step in simulation.
Thus, the actions has the batch dimension, whose size is 1.
"""
self.actions = self.clip_action_before_scale(actions)
if self.computer_clip_torque:
clipped_scaled_action = self.clip_by_torque_limit(actions * self.action_scale)
else:
self.get_logger().warn("Computer Clip Torque is False, the robot may be damaged.", throttle_duration_sec= 1)
clipped_scaled_action = actions * self.action_scale
robot_coordinates_action = clipped_scaled_action + self.default_dof_pos.unsqueeze(0)
self._publish_legs_cmd(robot_coordinates_action[0])
""" Done: Control related functions """
""" functions that actually publish the commands and take effect """
def _publish_legs_cmd(self, robot_coordinates_action: torch.Tensor):
""" Publish the joint commands to the robot legs in robot coordinates system.
robot_coordinates_action: shape (NUM_DOF,), in simulation order.
"""
for sim_idx in range(self.NUM_DOF):
real_idx = self.dof_map[sim_idx]
if not self.dryrun:
self.low_cmd_buffer.motor_cmd[real_idx].mode = self.turn_on_motor_mode[sim_idx]
self.low_cmd_buffer.motor_cmd[real_idx].q = robot_coordinates_action[sim_idx].item() * self.dof_signs[sim_idx]
self.low_cmd_buffer.motor_cmd[real_idx].dq = 0.
self.low_cmd_buffer.motor_cmd[real_idx].tau = 0.
self.low_cmd_buffer.motor_cmd[real_idx].kp = self.p_gains[sim_idx].item()
self.low_cmd_buffer.motor_cmd[real_idx].kd = self.d_gains[sim_idx].item()
self.low_cmd_buffer.crc = get_crc(self.low_cmd_buffer)
self.low_cmd_pub.publish(self.low_cmd_buffer)
def _turn_off_motors(self):
""" Turn off the motors """
for sim_idx in range(self.NUM_DOF):
real_idx = self.dof_map[sim_idx]
self.low_cmd_buffer.motor_cmd[real_idx].mode = 0x00
self.low_cmd_buffer.motor_cmd[real_idx].q = 0.
self.low_cmd_buffer.motor_cmd[real_idx].dq = 0.
self.low_cmd_buffer.motor_cmd[real_idx].tau = 0.
self.low_cmd_buffer.motor_cmd[real_idx].kp = 0.
self.low_cmd_buffer.motor_cmd[real_idx].kd = 0.
self.low_cmd_buffer.crc = get_crc(self.low_cmd_buffer)
self.low_cmd_pub.publish(self.low_cmd_buffer)
""" Done: functions that actually publish the commands and take effect """

Binary file not shown.

View File

@ -30,3 +30,4 @@
from .ppo import PPO
from .tppo import TPPO
from .estimator import EstimatorPPO, EstimatorTPPO

View File

@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from rsl_rl.algorithms.ppo import PPO
from rsl_rl.algorithms.tppo import TPPO
from rsl_rl.utils.utils import unpad_trajectories, get_subobs_by_components
from rsl_rl.storage.rollout_storage import SarsaRolloutStorage
class EstimatorAlgoMixin:
""" A supervised algorithm implementation that trains a state predictor in the policy model
"""
def __init__(self,
*args,
estimator_loss_func= "mse_loss",
estimator_loss_kwargs= dict(),
**kwargs,
):
super().__init__(*args, **kwargs)
self.estimator_obs_components = self.actor_critic.estimator_obs_components
self.estimator_target_obs_components = self.actor_critic.estimator_target_components
self.estimator_loss_func = estimator_loss_func
self.estimator_loss_kwargs = estimator_loss_kwargs
def compute_losses(self, minibatch):
losses, inter_vars, stats = super().compute_losses(minibatch)
# Use the critic_obs from the same timestep for estimation target
# Not considering predicting the next state for now.
estimation_target = get_subobs_by_components(
minibatch.critic_obs,
component_names= self.estimator_target_obs_components,
obs_segments= self.actor_critic.privileged_obs_segments,
)
if self.actor_critic.is_recurrent:
estimation_target = unpad_trajectories(estimation_target, minibatch.masks)
# actor_critic must compute the estimated_state_ during act() as an intermediate variable
estimation = unpad_trajectories(self.actor_critic.get_estimated_state(), minibatch.masks)
estimator_loss = getattr(F, self.estimator_loss_func)(
estimation,
estimation_target,
**self.estimator_loss_kwargs,
reduction= "none",
).sum(dim= -1)
losses["estimator_loss"] = estimator_loss.mean()
return losses, inter_vars, stats
class EstimatorPPO(EstimatorAlgoMixin, PPO):
pass
class EstimatorTPPO(EstimatorAlgoMixin, TPPO):
pass
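
Because the loss is resolved via `getattr(F, estimator_loss_func)`, any `torch.nn.functional` loss that accepts `reduction="none"` can be selected. A hedged, config-style sketch (the chosen loss and its kwargs are illustrative, not from a shipped config):

```python
# Illustrative algorithm kwargs for EstimatorPPO / EstimatorTPPO.
algorithm_kwargs = dict(
    estimator_loss_func="smooth_l1_loss",   # default above is "mse_loss"
    estimator_loss_kwargs=dict(beta=0.1),   # forwarded to F.smooth_l1_loss
)
```
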

View File

@ -68,7 +68,6 @@ class PPO:
self.actor_critic.to(self.device)
self.storage = None # initialized later
self.optimizer = getattr(optim, optimizer_class_name)(self.actor_critic.parameters(), lr=learning_rate)
self.transition = RolloutStorage.Transition()
# PPO parameters
self.clip_param = clip_param
@ -86,6 +85,7 @@ class PPO:
self.current_learning_iteration = 0
def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
self.transition = RolloutStorage.Transition()
self.storage = RolloutStorage(num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device)
def test_mode(self):
@ -108,7 +108,7 @@ class PPO:
self.transition.critic_observations = critic_obs
return self.transition.actions
def process_env_step(self, rewards, dones, infos):
def process_env_step(self, rewards, dones, infos, next_obs, next_critic_obs):
self.transition.rewards = rewards.clone()
self.transition.dones = dones
# Bootstrapping on time outs
@ -162,9 +162,9 @@ class PPO:
return mean_losses, average_stats
def compute_losses(self, minibatch):
self.actor_critic.act(minibatch.obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0])
self.actor_critic.act(minibatch.obs, masks=minibatch.masks, hidden_states=minibatch.hidden_states.actor)
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(minibatch.actions)
value_batch = self.actor_critic.evaluate(minibatch.critic_obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
value_batch = self.actor_critic.evaluate(minibatch.critic_obs, masks=minibatch.masks, hidden_states=minibatch.hidden_states.critic)
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
try:
@ -222,3 +222,22 @@ class PPO:
if self.use_clipped_value_loss:
inter_vars["value_clipped"] = value_clipped
return return_, inter_vars, dict()
def state_dict(self):
state_dict = {
"model_state_dict": self.actor_critic.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
}
if hasattr(self, "lr_scheduler"):
state_dict["lr_scheduler_state_dict"] = self.lr_scheduler.state_dict()
return state_dict
def load_state_dict(self, state_dict):
self.actor_critic.load_state_dict(state_dict["model_state_dict"])
if "optimizer_state_dict" in state_dict:
self.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
if hasattr(self, "lr_scheduler"):
self.lr_scheduler.load_state_dict(state_dict["lr_scheduler_state_dict"])
elif "lr_scheduler_state_dict" in state_dict:
print("Warning: lr scheduler state dict loaded but no lr scheduler is initialized. Ignored.")

View File

@ -6,17 +6,19 @@ import torch.nn.functional as F
import torch.optim as optim
import rsl_rl.modules as modules
from rsl_rl.utils import unpad_trajectories
from rsl_rl.storage.rollout_storage import ActionLabelRollout
from rsl_rl.algorithms.ppo import PPO
from legged_gym import LEGGED_GYM_ROOT_DIR
# assumes the learning iteration stays within a roughly predictable iteration scale
def GET_TEACHER_ACT_PROB_FUNC(option, iteration_scale):
TEACHER_ACT_PROB_options = {
def GET_PROB_FUNC(option, iteration_scale):
PROB_options = {
"linear": (lambda x: max(0, 1 - 1 / iteration_scale * x)),
"exp": (lambda x: max(0, (1 - 1 / iteration_scale) ** x)),
"tanh": (lambda x: max(0, 0.5 * (1 - torch.tanh(1 / iteration_scale * (x - iteration_scale))))),
}
return TEACHER_ACT_PROB_options[option]
return PROB_options[option]
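
For intuition, these schedules decay a probability from 1 toward 0 over roughly `iteration_scale` updates; a quick illustration with approximate values:

```python
# Example decay values for iteration_scale = 1000 (numbers are approximate).
prob_lin = GET_PROB_FUNC("linear", 1000)
prob_exp = GET_PROB_FUNC("exp", 1000)
[prob_lin(i) for i in (0, 500, 1000, 2000)]             # -> roughly [1.0, 0.5, 0.0, 0]
[round(prob_exp(i), 3) for i in (0, 500, 1000, 2000)]   # -> roughly [1.0, 0.606, 0.368, 0.135]
```
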
class TPPO(PPO):
def __init__(self,
@ -28,35 +30,53 @@ class TPPO(PPO):
teacher_act_prob= "exp", # a number or a schedule name mapping the iteration to a probability (0 ~ 1) of acting with the teacher policy
update_times_scale= 100, # a rough estimation of how many times the update will be called
using_ppo= True, # If False, compute_losses will skip the PPO loss computation and fall back to DAgger
distillation_loss_coef= 1.,
distillation_loss_coef= 1., # can also be a string selecting a schedule function that scales the distillation loss
distill_target= "real",
distill_latent_coef= 1.,
distill_latent_target= "real",
distill_latent_obs_component_mapping= None,
buffer_dilation_ratio= 1.,
lr_scheduler_class_name= None,
lr_scheduler= dict(),
hidden_state_resample_prob= 0.0, # if > 0, some hidden states in the minibatch will be resampled
action_labels_from_sample= False, # if True, the action labels are sampled via the teacher policy.act instead of the mean from policy.act_inference
**kwargs,
):
"""
Args:
- distill_latent_obs_component_mapping: a dict of
{student_obs_component_name: teacher_obs_component_name}
only when both policy are the instance of EncoderActorCriticMixin
"""
super().__init__(*args, **kwargs)
self.label_action_with_critic_obs = label_action_with_critic_obs
self.teacher_act_prob = teacher_act_prob
self.update_times_scale = update_times_scale
if isinstance(self.teacher_act_prob, str):
self.teacher_act_prob = GET_TEACHER_ACT_PROB_FUNC(self.teacher_act_prob, update_times_scale)
self.teacher_act_prob = GET_PROB_FUNC(self.teacher_act_prob, update_times_scale)
else:
self.__teacher_act_prob = self.teacher_act_prob
self.teacher_act_prob = lambda x: self.__teacher_act_prob
self.using_ppo = using_ppo
self.distillation_loss_coef = distillation_loss_coef
self.__distillation_loss_coef = distillation_loss_coef
if isinstance(self.__distillation_loss_coef, str):
self.distillation_loss_coef_func = GET_PROB_FUNC(self.__distillation_loss_coef, update_times_scale)
self.distill_target = distill_target
self.distill_latent_coef = distill_latent_coef
self.distill_latent_target = distill_latent_target
self.distill_latent_obs_component_mapping = distill_latent_obs_component_mapping
self.buffer_dilation_ratio = buffer_dilation_ratio
self.lr_scheduler_class_name = lr_scheduler_class_name
self.lr_scheduler_kwargs = lr_scheduler
self.hidden_state_resample_prob = hidden_state_resample_prob
self.action_labels_from_sample = action_labels_from_sample
self.transition = ActionLabelRollout.Transition()
# build and load teacher network
teacher_actor_critic = getattr(modules, teacher_policy_class_name)(**teacher_policy)
if not teacher_ac_path is None:
if "{LEGGED_GYM_ROOT_DIR}" in teacher_ac_path:
teacher_ac_path = teacher_ac_path.format(LEGGED_GYM_ROOT_DIR= LEGGED_GYM_ROOT_DIR)
state_dict = torch.load(teacher_ac_path, map_location= "cpu")
teacher_actor_critic_state_dict = state_dict["model_state_dict"]
teacher_actor_critic.load_state_dict(teacher_actor_critic_state_dict)
@ -84,8 +104,12 @@ class TPPO(PPO):
def act(self, obs, critic_obs):
# get actions
return_ = super().act(obs, critic_obs)
if self.label_action_with_critic_obs:
if self.label_action_with_critic_obs and self.action_labels_from_sample:
self.transition.action_labels = self.teacher_actor_critic.act(critic_obs).detach()
elif self.label_action_with_critic_obs:
self.transition.action_labels = self.teacher_actor_critic.act_inference(critic_obs).detach()
elif self.action_labels_from_sample:
self.transition.action_labels = self.teacher_actor_critic.act(obs).detach()
else:
self.transition.action_labels = self.teacher_actor_critic.act_inference(obs).detach()
@ -96,8 +120,8 @@ class TPPO(PPO):
return return_
def process_env_step(self, rewards, dones, infos):
return_ = super().process_env_step(rewards, dones, infos)
def process_env_step(self, rewards, dones, infos, next_obs, next_critic_obs):
return_ = super().process_env_step(rewards, dones, infos, next_obs, next_critic_obs)
self.teacher_actor_critic.reset(dones)
# resample teacher action mask for those dones env
self.use_teacher_act_mask[dones] = torch.rand(dones.sum(), device= self.device) < self.teacher_act_prob(self.current_learning_iteration)
@ -107,7 +131,7 @@ class TPPO(PPO):
""" The interface to collect transition from dataset rather than env """
super().act(transition.observation, transition.privileged_observation)
self.transition.action_labels = transition.action
super().process_env_step(transition.reward, transition.done, infos)
super().process_env_step(transition.reward, transition.done, infos, transition.next_observation, transition.next_privileged_observation)
def compute_returns(self, last_critic_obs):
if not self.using_ppo:
@ -124,30 +148,30 @@ class TPPO(PPO):
def compute_losses(self, minibatch):
if self.hidden_state_resample_prob > 0.0:
# assuming the hidden states are from an LSTM or GRU, whose values are always between -1 and 1
hidden_state_example = minibatch.hid_states[0][0] if isinstance(minibatch.hid_states[0], tuple) else minibatch.hid_states[0]
hidden_state_example = minibatch.hidden_states[0][0] if isinstance(minibatch.hidden_states[0], tuple) else minibatch.hidden_states[0]
resample_mask = torch.rand(hidden_state_example.shape[1], device= self.device) < self.hidden_state_resample_prob
# for each hidden state, resample from -1 to 1
if isinstance(minibatch.hid_states[0], tuple):
if isinstance(minibatch.hidden_states[0], tuple):
# for LSTM not tested
# iterate through actor and critic hidden state
minibatch = minibatch._replace(hid_states= tuple(
minibatch = minibatch._replace(hidden_states= tuple(
tuple(
torch.where(
resample_mask.unsqueeze(-1).unsqueeze(-1),
torch.rand_like(minibatch.hid_states[i][j], device= self.device) * 2 - 1,
minibatch.hid_states[i][j],
) for j in range(len(minibatch.hid_states[i]))
) for i in range(len(minibatch.hid_states))
torch.rand_like(minibatch.hidden_states[i][j], device= self.device) * 2 - 1,
minibatch.hidden_states[i][j],
) for j in range(len(minibatch.hidden_states[i]))
) for i in range(len(minibatch.hidden_states))
))
else:
# for GRU
# iterate through actor and critic hidden state
minibatch = minibatch._replace(hid_states= tuple(
minibatch = minibatch._replace(hidden_states= tuple(
torch.where(
resample_mask.unsqueeze(-1),
torch.rand_like(minibatch.hid_states[i], device= self.device) * 2 - 1,
minibatch.hid_states[i],
) for i in range(len(minibatch.hid_states))
torch.rand_like(minibatch.hidden_states[i], device= self.device) * 2 - 1,
minibatch.hidden_states[i],
) for i in range(len(minibatch.hidden_states))
))
if self.using_ppo:
@ -156,7 +180,7 @@ class TPPO(PPO):
losses = dict()
inter_vars = dict()
stats = dict()
self.actor_critic.act(minibatch.obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0])
self.actor_critic.act(minibatch.obs, masks=minibatch.masks, hidden_states=minibatch.hidden_states.actor)
# distillation loss (with teacher actor)
if self.distill_target == "real":
@ -164,6 +188,12 @@ class TPPO(PPO):
self.actor_critic.action_mean - minibatch.action_labels,
dim= -1
)
elif self.distill_target == "mse_sum":
dist_loss = F.mse_loss(
self.actor_critic.action_mean,
minibatch.action_labels,
reduction= "none",
).sum(-1)
elif self.distill_target == "l1":
dist_loss = torch.norm(
self.actor_critic.action_mean - minibatch.action_labels,
@ -187,6 +217,11 @@ class TPPO(PPO):
(minibatch.action_labels + 1) * 0.5, # (n, t, d)
reduction= "none",
).mean(-1) * 2 * l1 / self.actor_critic.action_mean.shape[-1] # (n, t)
elif self.distill_target == "max_log_prob":
action_labels_log_prob = self.actor_critic.get_actions_log_prob(minibatch.action_labels)
dist_loss = -action_labels_log_prob
elif self.distill_target == "kl":
raise NotImplementedError()
if "tanh" in self.distill_target:
stats["l1distance"] = torch.norm(
@ -199,7 +234,54 @@ class TPPO(PPO):
dim= -1,
p= 1
).mean().detach()
# update distillation loss coef if applicable
self.distillation_loss_coef = self.distillation_loss_coef_func(self.current_learning_iteration) if hasattr(self, "distillation_loss_coef_func") else self.__distillation_loss_coef
losses["distillation_loss"] = dist_loss.mean()
# distill latent embedding
if self.distill_latent_obs_component_mapping is not None:
for k, v in self.distill_latent_obs_component_mapping.items():
# get the latent embedding
latent = self.actor_critic.get_encoder_latent(
minibatch.obs,
k,
)
with torch.no_grad():
target_latent = self.teacher_actor_critic.get_encoder_latent(
minibatch.critic_obs,
v,
)
if self.actor_critic.is_recurrent:
latent = unpad_trajectories(latent, minibatch.masks)
target_latent = unpad_trajectories(target_latent, minibatch.masks)
if self.distill_latent_target == "real":
dist_loss = torch.norm(
latent - target_latent,
dim= -1,
)
elif self.distill_latent_target == "l1":
dist_loss = torch.norm(
latent - target_latent,
dim= -1,
p= 1,
)
elif self.distill_latent_target == "tanh":
dist_loss = F.binary_cross_entropy(
(latent + 1) * 0.5,
(target_latent + 1) * 0.5,
) * 2
elif self.distill_latent_target == "scaled_tanh":
l1 = torch.norm(
latent - target_latent,
dim= -1,
p= 1,
)
dist_loss = F.binary_cross_entropy(
(latent + 1) * 0.5,
(target_latent + 1) * 0.5, # (n, t, d)
reduction= "none",
).mean(-1) * 2 * l1 / latent.shape[-1] # (n, t)
setattr(self, f"distill_latent_{k}_coef", self.distill_latent_coef)
losses[f"distill_latent_{k}"] = dist_loss.mean()
return losses, inter_vars, stats
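
As a side-by-side reference, a hedged illustration of three of the `distill_target` options above on dummy tensors (shapes and values are made up):

```python
import torch
import torch.nn.functional as F

mean   = torch.zeros(4, 12)   # student action mean, (batch, action_dim)
labels = torch.ones(4, 12)    # teacher action labels

real_loss = torch.norm(mean - labels, dim=-1)                    # "real": per-sample L2 distance
l1_loss   = torch.norm(mean - labels, dim=-1, p=1)               # "l1": per-sample L1 distance
mse_sum   = F.mse_loss(mean, labels, reduction="none").sum(-1)   # "mse_sum": summed squared error
```
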

View File

@ -33,6 +33,9 @@ from .actor_critic_recurrent import ActorCriticRecurrent
from .visual_actor_critic import VisualDeterministicRecurrent, VisualDeterministicAC
from .actor_critic_mutex import ActorCriticMutex
from .actor_critic_field_mutex import ActorCriticFieldMutex, ActorCriticClimbMutex
from .encoder_actor_critic import EncoderActorCriticMixin, EncoderActorCritic, EncoderActorCriticRecurrent
from .state_estimator import EstimatorMixin, StateAc, StateAcRecurrent
from .all_mixer import EncoderStateAc, EncoderStateAcRecurrent
def build_actor_critic(env, policy_class_name, policy_cfg):
""" NOTE: This method allows to hack the policy kwargs by adding the env attributes to the policy_cfg. """

View File

@ -45,11 +45,20 @@ class ActorCritic(nn.Module):
activation='elu',
init_noise_std=1.0,
mu_activation= None, # If set, the last layer will be added with a special activation layer.
obs_segments= None,
privileged_obs_segments= None, # No need
**kwargs):
if kwargs:
print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()]))
super(ActorCritic, self).__init__()
# obs_segments is an ordered dict that maps each observation component (string) to its shape.
# We use this to slice the observations into the correct components.
if privileged_obs_segments is None:
privileged_obs_segments = obs_segments
self.obs_segments = obs_segments
self.privileged_obs_segments = privileged_obs_segments
activation = get_activation(activation)
mlp_input_dim_a = num_actor_obs

View File

@ -36,6 +36,9 @@ from torch.distributions import Normal
from torch.nn.modules import rnn
from .actor_critic import ActorCritic, get_activation
from rsl_rl.utils import unpad_trajectories
from rsl_rl.utils.collections import namedarraytuple, is_namedarraytuple
ActorCriticHiddenState = namedarraytuple('ActorCriticHiddenState', ['actor', 'critic'])
class ActorCriticRecurrent(ActorCritic):
is_recurrent = True
@ -74,39 +77,47 @@ class ActorCriticRecurrent(ActorCritic):
def act(self, observations, masks=None, hidden_states=None):
input_a = self.memory_a(observations, masks, hidden_states)
return super().act(input_a.squeeze(0))
return super().act(input_a)
def act_inference(self, observations):
input_a = self.memory_a(observations)
return super().act_inference(input_a.squeeze(0))
return super().act_inference(input_a)
def evaluate(self, critic_observations, masks=None, hidden_states=None):
input_c = self.memory_c(critic_observations, masks, hidden_states)
return super().evaluate(input_c.squeeze(0))
return super().evaluate(input_c)
def get_hidden_states(self):
return self.memory_a.hidden_states, self.memory_c.hidden_states
return ActorCriticHiddenState(self.memory_a.hidden_states, self.memory_c.hidden_states)
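
A small, self-contained illustration of the namedarraytuple-based hidden states introduced above; the placeholder strings stand in for hidden-state tensors:

```python
from rsl_rl.modules.actor_critic_recurrent import ActorCriticHiddenState

hid = ActorCriticHiddenState("h_actor", "h_critic")
hid.actor                            # -> "h_actor"; PPO indexes minibatch.hidden_states.actor this way
tuple(hid)                           # -> ("h_actor", "h_critic"); the plain-tuple fallback used in Memory.forward
hid = hid._replace(actor="h_new")    # namedtuple-style replacement, as used in EstimatorMixin.get_hidden_states
```
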
LstmHiddenState = namedarraytuple('LstmHiddenState', ['hidden', 'cell'])
class Memory(torch.nn.Module):
def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
super().__init__()
# RNN
# RNN: currently only GRU and LSTM are supported
rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
self.hidden_states = None
def forward(self, input, masks=None, hidden_states=None):
batch_mode = masks is not None
batch_mode = hidden_states is not None
if batch_mode:
# batch mode (policy update): need saved hidden states
if hidden_states is None:
raise ValueError("Hidden states not passed to memory module during policy update")
if is_namedarraytuple(hidden_states):
hidden_states = tuple(hidden_states)
out, _ = self.rnn(input, hidden_states)
out = unpad_trajectories(out, masks)
if not masks is None:
# in this case, user can choose whether to unpad the output or not
out = unpad_trajectories(out, masks)
else:
# inference mode (collection): use hidden states of last step
if is_namedarraytuple(self.hidden_states):
self.hidden_states = tuple(self.hidden_states)
out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
if isinstance(self.hidden_states, tuple):
self.hidden_states = LstmHiddenState(*self.hidden_states)
out = out.squeeze(0) # remove the time dimension
return out
def reset(self, dones=None):

View File

@ -0,0 +1,13 @@
""" A file put all mixin class combinations """
from .actor_critic import ActorCritic
from .actor_critic_recurrent import ActorCriticRecurrent
from .encoder_actor_critic import EncoderActorCriticMixin
from .state_estimator import EstimatorMixin
class EncoderStateAc(EstimatorMixin, EncoderActorCriticMixin, ActorCritic):
pass
class EncoderStateAcRecurrent(EstimatorMixin, EncoderActorCriticMixin, ActorCriticRecurrent):
def load_misaligned_state_dict(self, module, obs_segments, privileged_obs_segments=None):
pass

View File

@ -0,0 +1,188 @@
import numpy as np
import torch
import torch.nn as nn
from rsl_rl.modules.mlp import MlpModel
from rsl_rl.modules.conv2d import Conv2dHeadModel
from rsl_rl.utils.utils import get_obs_slice
class EncoderActorCriticMixin:
""" A general implementation where a seperate encoder is used to embed the obs/privileged_obs """
def __init__(self,
num_actor_obs,
num_critic_obs,
num_actions,
obs_segments= None,
privileged_obs_segments= None,
encoder_component_names= [], # allow multiple encoders
encoder_class_name= "MlpModel", # accept list of names (in the same order as encoder_component_names)
encoder_kwargs= dict(), # accept list of kwargs (in the same order as encoder_component_names),
encoder_output_size= None,
critic_encoder_component_names= None, # None, "shared", or a list of names (in the same order as encoder_component_names)
critic_encoder_class_name= None, # accept list of names (in the same order as encoder_component_names)
critic_encoder_kwargs= None, # accept list of kwargs (in the same order as encoder_component_names),
**kwargs,
):
""" NOTE: recurrent encoder is not implemented and tested yet.
"""
self.num_actor_obs = num_actor_obs
self.num_critic_obs = num_critic_obs
self.num_actions = num_actions
self.obs_segments = obs_segments
self.privileged_obs_segments = privileged_obs_segments
self.encoder_component_names = encoder_component_names
self.encoder_class_name = encoder_class_name
self.encoder_kwargs = encoder_kwargs
self.encoder_output_size = encoder_output_size
self.critic_encoder_component_names = critic_encoder_component_names
self.critic_encoder_class_name = critic_encoder_class_name if not critic_encoder_class_name is None else encoder_class_name
self.critic_encoder_kwargs = critic_encoder_kwargs if not critic_encoder_kwargs is None else encoder_kwargs
self.obs_segments = obs_segments
self.privileged_obs_segments = privileged_obs_segments
self.prepare_obs_slices()
super().__init__(
num_actor_obs - sum([s[0].stop - s[0].start for s in self.encoder_obs_slices]) + len(self.encoder_obs_slices) * self.encoder_output_size,
num_critic_obs if self.critic_encoder_component_names is None else num_critic_obs - sum([s[0].stop - s[0].start for s in self.critic_encoder_obs_slices]) + len(self.critic_encoder_obs_slices) * self.encoder_output_size,
num_actions,
obs_segments= obs_segments,
privileged_obs_segments= privileged_obs_segments,
**kwargs,
)
self.encoders = self.build_encoders(
self.encoder_component_names,
self.encoder_class_name,
self.encoder_obs_slices,
self.encoder_kwargs,
self.encoder_output_size,
)
if not (self.critic_encoder_component_names is None or self.critic_encoder_component_names == "shared"):
self.critic_encoders = self.build_encoders(
self.critic_encoder_component_names,
self.critic_encoder_class_name,
self.critic_encoder_obs_slices,
self.critic_encoder_kwargs,
self.encoder_output_size,
)
def prepare_obs_slices(self):
# NOTE: encoders are stored in the same order as encoder_component_names.
# latents_order stores the order in which each output latent is concatenated back
# with the rest of the obs vector.
self.encoder_obs_slices = [get_obs_slice(self.obs_segments, name) for name in self.encoder_component_names]
self.latents_order = [i for i in range(len(self.encoder_obs_slices))]
self.latents_order.sort(key= lambda i: self.encoder_obs_slices[i][0].start)
if self.critic_encoder_component_names is not None:
if self.critic_encoder_component_names == "shared":
self.critic_encoder_obs_slices = self.encoder_obs_slices
else:
critic_obs_segments = self.obs_segments if self.privileged_obs_segments is None else self.privileged_obs_segments
self.critic_encoder_obs_slices = [get_obs_slice(critic_obs_segments, name) for name in self.critic_encoder_component_names]
self.critic_latents_order = [i for i in range(len(self.critic_encoder_obs_slices))]
self.critic_latents_order.sort(key= lambda i: self.critic_encoder_obs_slices[i][0].start)
def build_encoders(self, component_names, class_name, obs_slices, kwargs, encoder_output_size):
encoders = nn.ModuleList()
for component_i, name in enumerate(component_names):
model_class_name = class_name[component_i] if isinstance(class_name, (tuple, list)) else class_name
obs_slice = obs_slices[component_i]
model_kwargs = kwargs[component_i] if isinstance(kwargs, (tuple, list)) else kwargs
model_kwargs = model_kwargs.copy() # 1-level shallow copy
# This code is not clean enough, need to sort out later
if model_class_name == "MlpModel":
hidden_sizes = model_kwargs.pop("hidden_sizes") + [encoder_output_size,]
encoders.append(MlpModel(
np.prod(obs_slice[1]),
hidden_sizes= hidden_sizes,
output_size= None,
**model_kwargs,
))
elif model_class_name == "Conv2dHeadModel":
hidden_sizes = model_kwargs.pop("hidden_sizes") + [encoder_output_size,]
encoders.append(Conv2dHeadModel(
obs_slice[1],
hidden_sizes= hidden_sizes,
output_size= None,
**model_kwargs,
))
else:
raise NotImplementedError(f"Encoder for {model_class_name} on {name} not implemented")
return encoders
def embed_encoders_latent(self, observations, obs_slices, encoders, latents_order):
leading_dims = observations.shape[:-1]
latents = []
for encoder_i, encoder in enumerate(encoders):
# This code is not clean enough, need to sort out later
if isinstance(encoder, MlpModel):
latents.append(encoder(
observations[..., obs_slices[encoder_i][0]].reshape(-1, np.prod(obs_slices[encoder_i][1]))
).reshape(*leading_dims, -1))
elif isinstance(encoder, Conv2dHeadModel):
latents.append(encoder(
observations[..., obs_slices[encoder_i][0]].reshape(-1, *obs_slices[encoder_i][1])
).reshape(*leading_dims, -1))
else:
raise NotImplementedError(f"Encoder for {type(encoder)} not implemented")
# replace the obs vector with the latent vector in each obs_slice[0] (the slice of obs)
embedded_obs = []
embedded_obs.append(observations[..., :obs_slices[latents_order[0]][0].start])
for order_i in range(len(latents)- 1):
current_idx = latents_order[order_i]
next_idx = latents_order[order_i + 1]
embedded_obs.append(latents[current_idx])
embedded_obs.append(observations[..., obs_slices[current_idx][0].stop: obs_slices[next_idx][0].start])
current_idx = latents_order[-1]
next_idx = None
embedded_obs.append(latents[current_idx])
embedded_obs.append(observations[..., obs_slices[current_idx][0].stop:])
return torch.cat(embedded_obs, dim= -1)
def get_encoder_latent(self, observations, obs_component, critic= False):
""" Get the latent vector from the encoder of the specified obs_component
"""
leading_dims = observations.shape[:-1]
encoder_obs_components = self.critic_encoder_component_names if critic else self.encoder_component_names
encoder_obs_slices = self.critic_encoder_obs_slices if critic else self.encoder_obs_slices
encoders = self.critic_encoders if critic else self.encoders
for i, name in enumerate(encoder_obs_components):
if name == obs_component:
obs_component_var = observations[..., encoder_obs_slices[i][0]]
if isinstance(encoders[i], MlpModel):
obs_component_var = obs_component_var.reshape(-1, np.prod(encoder_obs_slices[i][1]))
elif isinstance(encoders[i], Conv2dHeadModel):
obs_component_var = obs_component_var.reshape(-1, *encoder_obs_slices[i][1])
latent = encoders[i](obs_component_var).reshape(*leading_dims, -1)
return latent
raise ValueError(f"obs_component {obs_component} not found in encoder_obs_components")
def act(self, observations, **kwargs):
obs = self.embed_encoders_latent(observations, self.encoder_obs_slices, self.encoders, self.latents_order)
return super().act(obs, **kwargs)
def act_inference(self, observations):
obs = self.embed_encoders_latent(observations, self.encoder_obs_slices, self.encoders, self.latents_order)
return super().act_inference(obs)
def evaluate(self, critic_observations, masks=None, hidden_states=None):
if self.critic_encoder_component_names == "shared":
obs = self.embed_encoders_latent(critic_observations, self.encoder_obs_slices, self.encoders, self.latents_order)
elif self.critic_encoder_component_names is None:
obs = critic_observations
else:
obs = self.embed_encoders_latent(critic_observations, self.critic_encoder_obs_slices, self.critic_encoders, self.critic_latents_order)
return super().evaluate(obs, masks, hidden_states)
from .actor_critic import ActorCritic
class EncoderActorCritic(EncoderActorCriticMixin, ActorCritic):
pass
from .actor_critic_recurrent import ActorCriticRecurrent
class EncoderActorCriticRecurrent(EncoderActorCriticMixin, ActorCriticRecurrent):
pass
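
A hedged, config-style sketch of how this mixin can be parameterized (the component name, sizes, and values are illustrative, not from a shipped config): each named component is cut out of the observation vector and replaced by its encoder's latent of size `encoder_output_size`.

```python
# Illustrative policy kwargs for EncoderActorCritic(Recurrent).
policy_kwargs = dict(
    encoder_component_names=["height_measurements"],            # components to embed
    encoder_class_name="MlpModel",
    encoder_kwargs=dict(hidden_sizes=[256, 128], nonlinearity="ReLU"),
    encoder_output_size=32,                                     # each latent replaces its obs slice with 32 values
    critic_encoder_component_names="shared",                    # critic reuses the actor's encoders
)
```
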

View File

@ -24,6 +24,8 @@ class MlpModel(torch.nn.Module):
hidden_sizes = [hidden_sizes]
elif hidden_sizes is None:
hidden_sizes = []
if isinstance(nonlinearity, str):
nonlinearity = getattr(torch.nn, nonlinearity)
hidden_layers = [torch.nn.Linear(n_in, n_out) for n_in, n_out in
zip([input_size] + hidden_sizes[:-1], hidden_sizes)]
sequence = list()

View File

@ -0,0 +1,165 @@
import numpy as np
import torch
import torch.nn as nn
from .actor_critic import ActorCritic
from .actor_critic_recurrent import ActorCriticRecurrent, ActorCriticHiddenState
from rsl_rl.modules.mlp import MlpModel
from rsl_rl.modules.actor_critic_recurrent import Memory
from rsl_rl.utils import unpad_trajectories
from rsl_rl.utils.utils import get_subobs_size, get_subobs_by_components, substitute_estimated_state
from rsl_rl.utils.collections import namedarraytuple
EstimatorActorHiddenState = namedarraytuple('EstimatorActorHiddenState', [
'estimator',
'actor',
])
class EstimatorMixin:
def __init__(self,
*args,
estimator_obs_components= None, # a list of strings used to get obs slices
estimator_target_components= None, # a list of strings used to get obs slices
estimator_kwargs= dict(),
use_actor_rnn= False, # if set and self.is_recurrent, use actor-rnn output as the input of the state estimator directly
replace_state_prob= 0., # if 0~1, replace the actor observation with the estimated state with this probability
**kwargs,
):
super().__init__(
*args,
**kwargs,
)
self.estimator_obs_components = estimator_obs_components
self.estimator_target_components = estimator_target_components
self.estimator_kwargs = estimator_kwargs
self.use_actor_rnn = use_actor_rnn
self.replace_state_prob = replace_state_prob
assert (self.replace_state_prob <= 0.) or (not self.use_actor_rnn), "You cannot replace (part of) the actor's observation after the actor has already used its memory module."
self.build_estimator(**kwargs)
def build_estimator(self, **kwargs):
""" This implementation is not flexible enough, but it is enough for now. """
estimator_input_size = get_subobs_size(self.obs_segments, self.estimator_obs_components)
estimator_output_size = get_subobs_size(self.obs_segments, self.estimator_target_components)
if self.is_recurrent:
# estimate required state using a recurrent network
if self.use_actor_rnn:
estimator_input_size = self.memory_a.rnn.hidden_size
self.state_estimator = MlpModel(
input_size= estimator_input_size,
output_size= estimator_output_size,
**self.estimator_kwargs,
)
else:
self.memory_s = Memory(
estimator_input_size,
type= kwargs.get("rnn_type", "lstm"),
num_layers= kwargs.get("rnn_num_layers", 1),
hidden_size= kwargs.get("rnn_hidden_size", 256),
)
self.state_estimator = MlpModel(
input_size= self.memory_s.rnn.hidden_size,
output_size= estimator_output_size,
**self.estimator_kwargs,
)
else:
# estimate required state using a feedforward network
self.state_estimator = MlpModel(
input_size= estimator_input_size,
output_size= estimator_output_size,
**self.estimator_kwargs,
)
def reset(self, dones=None):
super().reset(dones)
if self.is_recurrent and not self.use_actor_rnn:
self.memory_s.reset(dones)
def act(self, observations, masks=None, hidden_states=None, inference= False):
observations = observations.clone()
if inference:
assert masks is None and hidden_states is None, "Inference mode does not support masks and hidden_states. "
if self.is_recurrent and self.use_actor_rnn:
# NOTE:
# In this branch, it may require a redesign of the entire actor_critic.act interface.
input_s = self.memory_a(observations, masks, hidden_states)
action = ActorCritic.act(self, input_s)
self.estimated_state_ = self.state_estimator(input_s)
elif self.is_recurrent and not self.use_actor_rnn:
# TODO: allows non-recurrent state estimator with recurrent actor
subobs = get_subobs_by_components(
observations,
component_names= self.estimator_obs_components,
obs_segments= self.obs_segments,
)
input_s = self.memory_s(
subobs,
None, # use None to prevent unpadding
hidden_states if hidden_states is None else hidden_states.estimator,
)
# QUESTION: after memory_s, the estimated_state is already unpadded. How to get
# the padded format and feed it into actor's observation?
# SOLUTION: modify the code of Memory module, use masks= None to stop the unpadding.
self.estimated_state_ = self.state_estimator(input_s)
use_estimated_state_mask = torch.rand_like(observations[..., 0]) < self.replace_state_prob
observations[use_estimated_state_mask] = substitute_estimated_state(
observations[use_estimated_state_mask],
self.estimator_target_components,
self.estimated_state_[use_estimated_state_mask].detach(),
self.obs_segments,
)
if inference:
action = super().act_inference(observations)
else:
action = super().act(
observations,
masks= masks,
hidden_states= hidden_states if hidden_states is None else hidden_states.actor,
)
else:
# both state estimator and actor are feedforward (non-recurrent)
subobs = get_subobs_by_components(
observations,
component_names= self.estimator_obs_components,
obs_segments= self.obs_segments,
)
self.estimated_state_ = self.state_estimator(subobs)
use_estimated_state_mask = torch.rand_like(observations[..., 0]) < self.replace_state_prob
observations[use_estimated_state_mask] = substitute_estimated_state(
observations[use_estimated_state_mask],
self.estimator_target_components,
self.estimated_state_[use_estimated_state_mask].detach(),
self.obs_segments,
)
if inference:
action = super().act_inference(observations)
else:
action = super().act(observations, masks= masks, hidden_states= hidden_states)
return action
def act_inference(self, observations):
return self.act(observations, inference= True)
""" No modification required for evaluate() """
def get_estimated_state(self):
""" In order to maintain the same interface of ActorCritic(Recurrent),
the user must call this function after calling act() to get the estimated state.
"""
return self.estimated_state_
def get_hidden_states(self):
return_ = super().get_hidden_states()
if self.is_recurrent and not self.use_actor_rnn:
return_ = return_._replace(actor= EstimatorActorHiddenState(
self.memory_s.hidden_states,
return_.actor,
))
return return_
class StateAc(EstimatorMixin, ActorCritic):
pass
class StateAcRecurrent(EstimatorMixin, ActorCriticRecurrent):
pass
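
A hedged configuration sketch for the mixin above (component names and sizes are illustrative): the estimator regresses the base linear velocity from proprioceptive components and, with `replace_state_prob=1.0`, always substitutes the estimate into the actor's observation.

```python
# Illustrative policy kwargs for StateAc / StateAcRecurrent.
policy_kwargs = dict(
    estimator_obs_components=["ang_vel", "projected_gravity", "dof_pos", "dof_vel", "last_actions"],
    estimator_target_components=["lin_vel"],
    estimator_kwargs=dict(hidden_sizes=[128, 64], nonlinearity="ReLU"),
    replace_state_prob=1.0,   # always feed the estimate to the actor
    use_actor_rnn=False,
)
```
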

View File

@ -6,10 +6,12 @@ import time
import numpy as np
import torch
from tensorboardX import SummaryWriter
from tabulate import tabulate
from rsl_rl.modules import build_actor_critic
from rsl_rl.runners.demonstration import DemonstrationSaver
from rsl_rl.algorithms.tppo import GET_TEACHER_ACT_PROB_FUNC
from rsl_rl.algorithms.tppo import GET_PROB_FUNC
class DaggerSaver(DemonstrationSaver):
""" This demonstration saver will rollout the trajectory by running the student policy
@ -21,6 +23,7 @@ class DaggerSaver(DemonstrationSaver):
teacher_act_prob= "exp",
update_times_scale= 5000,
action_sample_std= 0.0, # if > 0, add Gaussian noise to the action in effort.
log_to_tensorboard= False, # if True, log the rollout episode info to tensorboard
**kwargs,
):
"""
@ -32,9 +35,18 @@ class DaggerSaver(DemonstrationSaver):
self.teacher_act_prob = teacher_act_prob
self.update_times_scale = update_times_scale
self.action_sample_std = action_sample_std
self.log_to_tensorboard = log_to_tensorboard
if self.log_to_tensorboard:
self.tb_writer = SummaryWriter(
log_dir= osp.join(
self.training_policy_logdir,
"_".join(["collector", *(osp.basename(self.save_dir).split("_")[:2])]),
),
flush_secs= 10,
)
if isinstance(self.teacher_act_prob, str):
self.teacher_act_prob = GET_TEACHER_ACT_PROB_FUNC(self.teacher_act_prob, update_times_scale)
self.teacher_act_prob = GET_PROB_FUNC(self.teacher_act_prob, update_times_scale)
else:
self.__teacher_act_prob = self.teacher_act_prob
self.teacher_act_prob = lambda x: self.__teacher_act_prob
@ -47,6 +59,11 @@ class DaggerSaver(DemonstrationSaver):
self.build_training_policy()
return return_
def init_storage_buffer(self):
return_ = super().init_storage_buffer()
self.rollout_episode_infos = []
return return_
def build_training_policy(self):
""" Load the latest training policy model. """
with open(osp.join(self.training_policy_logdir, "config.json"), "r") as f:
@ -77,21 +94,26 @@ class DaggerSaver(DemonstrationSaver):
time.sleep(0.1)
self.training_policy.load_state_dict(loaded_dict["model_state_dict"])
self.training_policy_iteration = loaded_dict["iter"]
# override the action std in self.training_policy
with torch.no_grad():
if self.action_sample_std > 0:
self.training_policy.std[:] = self.action_sample_std
print("Training policy iteration: {}".format(self.training_policy_iteration))
self.use_teacher_act_mask = torch.rand(self.env.num_envs) < self.teacher_act_prob(self.training_policy_iteration)
def get_transition(self):
if self.use_critic_obs:
teacher_actions = self.policy.act_inference(self.critic_obs)
else:
teacher_actions = self.policy.act_inference(self.obs)
teacher_actions = self.get_policy_actions()
actions = self.training_policy.act(self.obs)
if self.action_sample_std > 0:
actions += torch.randn_like(actions) * self.action_sample_std
actions[self.use_teacher_act_mask] = teacher_actions[self.use_teacher_act_mask]
n_obs, n_critic_obs, rewards, dones, infos = self.env.step(actions)
# Use teacher actions to label the trajectory, no matter what the student policy does
return teacher_actions, rewards, dones, infos, n_obs, n_critic_obs
def add_transition(self, step_i, infos):
return_ = super().add_transition(step_i, infos)
if "episode" in infos:
self.rollout_episode_infos.append(infos["episode"])
return return_
def policy_reset(self, dones):
return_ = super().policy_reset(dones)
@ -103,3 +125,33 @@ class DaggerSaver(DemonstrationSaver):
""" Also check whether need to load the latest training policy model. """
self.load_latest_training_policy()
return super().check_stop()
def print_log(self):
# Copied from the runner logging mechanism. TODO: merge these implementations into one.
ep_table = []
for key in self.rollout_episode_infos[0].keys():
infotensor = torch.tensor([], device= self.env.device)
for ep_info in self.rollout_episode_infos:
if not isinstance(ep_info[key], torch.Tensor):
ep_info[key] = torch.Tensor([ep_info[key]])
if len(ep_info[key].shape) == 0:
ep_info[key] = ep_info[key].unsqueeze(0)
infotensor = torch.cat((infotensor, ep_info[key].to(self.env.device)))
if "_max" in key:
infotensor = infotensor[~infotensor.isnan()]
value = torch.max(infotensor) if len(infotensor) > 0 else torch.tensor(float("nan"))
elif "_min" in key:
infotensor = infotensor[~infotensor.isnan()]
value = torch.min(infotensor) if len(infotensor) > 0 else torch.tensor(float("nan"))
else:
value = torch.nanmean(infotensor)
if self.log_to_tensorboard:
self.tb_writer.add_scalar('Episode/' + key, value, self.training_policy_iteration)
ep_table.append(("Episode/" + key, value.detach().cpu().item()))
# NOTE: assuming the dagger trainer's iteration is always faster than the collector's iteration
# Otherwise, the training_policy will not be updated.
self.training_policy_iteration += 1
print("Sampling saved for training policy iteration: {}".format(self.training_policy_iteration))
print(tabulate(ep_table))
self.rollout_episode_infos = []
return super().print_log()

View File

@ -2,6 +2,7 @@ import os
import os.path as osp
import json
import pickle
import time
import numpy as np
import torch
@ -21,12 +22,14 @@ class DemonstrationSaver:
success_traj_only = False, # if true, only the trajectories not terminated by timeout will be dumped.
use_critic_obs= False,
obs_disassemble_mapping= None,
demo_by_sample= False,
):
"""
Args:
obs_disassemble_mapping (dict):
If set, the obs segment will be compressed using given type.
example: {"forward_depth": "normalized_image", "forward_rgb": "normalized_image"}
obs_disassemble_mapping (dict): If set, the obs segment will be compressed using the given
type. example: {"forward_depth": "normalized_image", "forward_rgb": "normalized_image"}
demo_by_sample (bool): if True, the action will be sampled (policy.act) from the
policy instead of using the mean (policy.act_inference).
"""
self.env = env
self.policy = policy
@ -38,6 +41,7 @@ class DemonstrationSaver:
self.use_critic_obs = use_critic_obs
self.success_traj_only = success_traj_only
self.obs_disassemble_mapping = obs_disassemble_mapping
self.demo_by_sample = demo_by_sample
self.RolloutStorageCls = RolloutStorage
def init_traj_handlers(self):
@ -106,11 +110,19 @@ class DemonstrationSaver:
self.policy_reset(dones)
self.obs, self.critic_obs = n_obs, n_critic_obs
def get_transition(self):
if self.use_critic_obs:
def get_policy_actions(self):
if self.use_critic_obs and self.demo_by_sample:
actions = self.policy.act(self.critic_obs)
elif self.use_critic_obs:
actions = self.policy.act_inference(self.critic_obs)
elif self.demo_by_sample:
actions = self.policy.act(self.obs)
else:
actions = self.policy.act_inference(self.obs)
return actions
def get_transition(self):
actions = self.get_policy_actions()
n_obs, n_critic_obs, rewards, dones, infos = self.env.step(actions)
return actions, rewards, dones, infos, n_obs, n_critic_obs
@ -153,10 +165,16 @@ class DemonstrationSaver:
pickle.dump(trajectory, f)
self.dumped_traj_lengths[env_i] += step_slice.stop - step_slice.start
self.total_timesteps += step_slice.stop - step_slice.start
with open(osp.join(self.save_dir, "metadata.json"), "w") as f:
def dump_metadata(self):
self.metadata["total_timesteps"] = self.total_timesteps.item() if isinstance(self.total_timesteps, np.int64) else self.total_timesteps
self.metadata["total_trajectories"] = self.total_traj_completed
with open(osp.join(self.save_dir, 'metadata.json'), 'w') as f:
json.dump(self.metadata, f, indent= 4)
def wrap_up_trajectory(self, env_i, step_slice):
# wrap up from the rollout_storage based on `step_slice`. Thus, `step_slice` must include
# the `done` step if it exists.
trajectory = dict(
privileged_observations= self.rollout_storage.privileged_observations[step_slice, env_i].cpu().numpy(),
actions= self.rollout_storage.actions[step_slice, env_i].cpu().numpy(),
@ -224,8 +242,7 @@ class DemonstrationSaver:
self.dump_to_file(rollout_env_i, slice(0, self.rollout_storage.num_transitions_per_env))
else:
start_idx = 0
di = 0
while di < done_idxs.shape[0]:
for di in range(done_idxs.shape[0]):
end_idx = done_idxs[di].item()
# dump and update the traj_idx for this env
@ -233,7 +250,9 @@ class DemonstrationSaver:
self.update_traj_handler(rollout_env_i, slice(start_idx, end_idx+1))
start_idx = end_idx + 1
di += 1
if start_idx < self.rollout_storage.num_transitions_per_env:
self.dump_to_file(rollout_env_i, slice(start_idx, self.rollout_storage.num_transitions_per_env))
self.dump_metadata()
def collect_and_save(self, config= None):
""" Run the rolllout to collect the demonstration data and save it to the file """
@ -276,6 +295,11 @@ class DemonstrationSaver:
def print_log(self):
""" print the log """
self.print_log_time = time.monotonic()
if hasattr(self, "last_print_log_time"):
print("time elapsed:", self.print_log_time - self.last_print_log_time)
print("throughput:", self.total_timesteps / (self.print_log_time - self.last_print_log_time))
self.last_print_log_time = self.print_log_time
print("total_timesteps:", self.total_timesteps)
print("total_trajectories", self.total_traj_completed)
@ -292,8 +316,5 @@ class DemonstrationSaver:
os.rmdir(traj_dir)
for timestep_count in self.dumped_traj_lengths:
self.total_timesteps += timestep_count
self.metadata["total_timesteps"] = self.total_timesteps.item() if isinstance(self.total_timesteps, np.int64) else self.total_timesteps
self.metadata["total_trajectories"] = self.total_traj_completed
with open(osp.join(self.save_dir, 'metadata.json'), 'w') as f:
json.dump(self.metadata, f, indent= 4)
self.dump_metadata()
print(f"Saved dataset in {self.save_dir}")

View File

@ -33,12 +33,13 @@ import os
from collections import deque
import statistics
from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
import torch
import rsl_rl.algorithms as algorithms
import rsl_rl.modules as modules
from rsl_rl.env import VecEnv
from rsl_rl.utils import ckpt_manipulator
class OnPolicyRunner:
@ -99,12 +100,15 @@ class OnPolicyRunner:
cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
print("Initialization done, start learning.")
print("NOTE: you may see a bunch of `NaN or Inf found in input tensor` once and appears in the log. Just ignore it if it does not affect the performance.")
start_iter = self.current_learning_iteration
tot_iter = self.current_learning_iteration + num_learning_iterations
tot_start_time = time.time()
start = time.time()
while self.current_learning_iteration < tot_iter:
start = time.time()
# Rollout
with torch.inference_mode():
with torch.inference_mode(self.cfg.get("inference_mode_rollout", True)):
for i in range(self.num_steps_per_env):
obs, critic_obs, rewards, dones, infos = self.rollout_step(obs, critic_obs)
@ -137,6 +141,7 @@ class OnPolicyRunner:
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))
ep_infos.clear()
self.current_learning_iteration = self.current_learning_iteration + 1
start = time.time()
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))
@ -145,12 +150,12 @@ class OnPolicyRunner:
obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
critic_obs = privileged_obs if privileged_obs is not None else obs
obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
self.alg.process_env_step(rewards, dones, infos)
self.alg.process_env_step(rewards, dones, infos, obs, critic_obs)
return obs, critic_obs, rewards, dones, infos
def log(self, locs, width=80, pad=35):
self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
self.tot_time += locs['collection_time'] + locs['learn_time']
self.tot_time = time.time() - locs['tot_start_time']
iteration_time = locs['collection_time'] + locs['learn_time']
ep_string = f''
@ -164,7 +169,14 @@ class OnPolicyRunner:
if len(ep_info[key].shape) == 0:
ep_info[key] = ep_info[key].unsqueeze(0)
infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
value = torch.mean(infotensor)
if "_max" in key:
infotensor = infotensor[~infotensor.isnan()]
value = torch.max(infotensor) if len(infotensor) > 0 else torch.tensor(float("nan"))
elif "_min" in key:
infotensor = infotensor[~infotensor.isnan()]
value = torch.min(infotensor) if len(infotensor) > 0 else torch.tensor(float("nan"))
else:
value = torch.nanmean(infotensor)
self.writer.add_scalar('Episode/' + key, value, self.current_learning_iteration)
ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
mean_std = self.alg.actor_critic.action_std.mean()
@ -181,10 +193,12 @@ class OnPolicyRunner:
self.writer.add_scalar('Perf/collection time', locs['collection_time'], self.current_learning_iteration)
self.writer.add_scalar('Perf/learning_time', locs['learn_time'], self.current_learning_iteration)
self.writer.add_scalar('Perf/gpu_allocated', torch.cuda.memory_allocated(self.device) / 1024 ** 3, self.current_learning_iteration)
self.writer.add_scalar('Perf/gpu_occupied', torch.cuda.mem_get_info(self.device)[1] / 1024 ** 3, self.current_learning_iteration)
self.writer.add_scalar('Perf/gpu_global_free_mem', torch.cuda.mem_get_info(self.device)[0] / 1024 ** 3, self.current_learning_iteration)
self.writer.add_scalar('Perf/gpu_total', torch.cuda.mem_get_info(self.device)[1] / 1024 ** 3, self.current_learning_iteration)
self.writer.add_scalar('Train/mean_reward_each_timestep', statistics.mean(locs['rframebuffer']), self.current_learning_iteration)
if len(locs['rewbuffer']) > 0:
self.writer.add_scalar('Train/mean_reward', statistics.mean(locs['rewbuffer']), self.current_learning_iteration)
self.writer.add_scalar('Train/ratio_above_mean_reward', statistics.mean([(1. if rew > statistics.mean(locs['rewbuffer']) else 0) for rew in locs['rewbuffer']]), self.current_learning_iteration)
self.writer.add_scalar('Train/mean_episode_length', statistics.mean(locs['lenbuffer']), self.current_learning_iteration)
self.writer.add_scalar('Train/mean_reward/time', statistics.mean(locs['rewbuffer']), self.tot_time)
self.writer.add_scalar('Train/mean_episode_length/time', statistics.mean(locs['lenbuffer']), self.tot_time)
@ -192,29 +206,37 @@ class OnPolicyRunner:
str = f" \033[1m Learning iteration {self.current_learning_iteration}/{locs['tot_iter']} \033[0m "
if len(locs['rewbuffer']) > 0:
log_string = (f"""{'#' * width}\n"""
f"""{str.center(width, ' ')}\n\n"""
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
f"""{'Value function loss:':>{pad}} {locs["losses"]['value_loss']:.4f}\n"""
f"""{'Surrogate loss:':>{pad}} {locs["losses"]['surrogate_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n"""
)
log_string = (
f"""{'#' * width}\n"""
f"""{str.center(width, ' ')}\n\n"""
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
)
for k, v in locs["losses"].items():
log_string += f"""{k:>{pad}} {v.item():.4f}\n"""
log_string += (
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n"""
)
else:
log_string = (f"""{'#' * width}\n"""
f"""{str.center(width, ' ')}\n\n"""
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
f"""{'Value function loss:':>{pad}} {locs["losses"]['value_loss']:.4f}\n"""
f"""{'Surrogate loss:':>{pad}} {locs["losses"]['surrogate_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n"""
)
log_string = (
f"""{'#' * width}\n"""
f"""{str.center(width, ' ')}\n\n"""
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
)
for k, v in locs["losses"].items():
log_string += f"""{k:>{pad}} {v.item():.4f}\n"""
log_string += (
f"""{'Value function loss:':>{pad}} {locs["losses"]['value_loss']:.4f}\n"""
f"""{'Surrogate loss:':>{pad}} {locs["losses"]['surrogate_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n"""
)
log_string += ep_string
log_string += (f"""{'-' * width}\n"""
@ -226,29 +248,30 @@ class OnPolicyRunner:
print(log_string)
def save(self, path, infos=None):
run_state_dict = {
'model_state_dict': self.alg.actor_critic.state_dict(),
'optimizer_state_dict': self.alg.optimizer.state_dict(),
run_state_dict = self.alg.state_dict()
run_state_dict.update({
'iter': self.current_learning_iteration,
'infos': infos,
}
if hasattr(self.alg, "lr_scheduler"):
run_state_dict["lr_scheduler_state_dict"] = self.alg.lr_scheduler.state_dict()
})
torch.save(run_state_dict, path)
def load(self, path, load_optimizer=True):
loaded_dict = torch.load(path)
self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
if load_optimizer and "optimizer_state_dict" in loaded_dict:
self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
if "lr_scheduler_state_dict" in loaded_dict:
if not hasattr(self.alg, "lr_scheduler"):
print("Warning: lr_scheduler_state_dict found in checkpoint but no lr_scheduler in algorithm. Ignoring.")
else:
self.alg.lr_scheduler.load_state_dict(loaded_dict["lr_scheduler_state_dict"])
elif hasattr(self.alg, "lr_scheduler"):
print("Warning: lr_scheduler_state_dict not found in checkpoint but lr_scheduler in algorithm. Ignoring.")
if self.cfg.get("ckpt_manipulator", False):
# supposed to be a string specifying which function to use
print("\033[1;36m Warning: using a hacky way to load the model. \033[0m")
loaded_dict = getattr(ckpt_manipulator, self.cfg["ckpt_manipulator"])(
loaded_dict,
self.alg.state_dict(),
)
print("\033[1;36m Done: using a hacky way to load the model. \033[0m")
self.alg.load_state_dict(loaded_dict)
self.current_learning_iteration = loaded_dict['iter']
if self.cfg.get("ckpt_manipulator", False):
try:
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))
except:
print("\033[1;36m Save manipulated checkpoint failed, ignored... \033[0m")
return loaded_dict['infos']
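The `ckpt_manipulator` hook looks up a function by the name stored in the config and calls it with the freshly loaded checkpoint and the algorithm's current state dict. A hypothetical manipulator with that call signature is sketched below; the key names (`optimizer_state_dict`, `iter`) are assumptions based on the older checkpoint layout, not taken from the new `rsl_rl` code:

```python
def drop_optimizer_state(loaded_dict, alg_state_dict):
    """Hypothetical manipulator: keep everything from the checkpoint except the
    optimizer state, which is taken from the freshly constructed algorithm instead
    (i.e. the optimizer restarts from scratch)."""
    manipulated = dict(loaded_dict)  # keep 'iter', 'infos', model weights, ...
    if "optimizer_state_dict" in alg_state_dict:  # assumed key name
        manipulated["optimizer_state_dict"] = alg_state_dict["optimizer_state_dict"]
    return manipulated
```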
def get_inference_policy(self, device=None):

View File

@@ -3,7 +3,7 @@ import os
import torch
from rsl_rl.runners.on_policy_runner import OnPolicyRunner
from rsl_rl.storage.rollout_dataset import RolloutDataset
from rsl_rl.storage.rollout_files.rollout_dataset import RolloutDataset
class TwoStageRunner(OnPolicyRunner):
""" A runner that have a pretrain stage which is used to collect demonstration data """
@@ -17,7 +17,7 @@ class TwoStageRunner(OnPolicyRunner):
self.rollout_dataset = RolloutDataset(
**self.cfg["pretrain_dataset"],
num_envs= self.env.num_envs,
rl_device= self.alg.device,
device= self.alg.device,
)
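The runner forwards `self.cfg["pretrain_dataset"]` straight into the dataset constructor, now with the renamed `device` argument. A hedged sketch of such a config; the keyword names are taken from the removed `rollout_dataset.py` shown further below, and the new `rollout_files` implementation may accept a different set:

```python
# Hypothetical runner config; num_envs and device are supplied by the runner itself.
runner_cfg = dict(
    pretrain_dataset=dict(
        scan_dir="logs/rollout_demos",   # placeholder path to collected demonstrations
        dataset_loops=1,
        random_shuffle_traj_order=True,
        keep_latest_n_trajs=2000,
        starting_frame_range=[0, 50],
    ),
)
```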
def rollout_step(self, obs, critic_obs):
@@ -31,8 +31,8 @@ class TwoStageRunner(OnPolicyRunner):
if not transition is None:
self.alg.collect_transition_from_dataset(transition, infos)
return (
transition.observation,
transition.privileged_observation,
transition.next_observation,
transition.next_privileged_observation,
transition.reward,
transition.done,
infos,

View File

@@ -1,359 +0,0 @@
import os
import os.path as osp
import pickle
from collections import namedtuple, OrderedDict
import json
import random
import numpy as np
import torch
from torch.utils.data import IterableDataset, get_worker_info
import rsl_rl.utils.data_compresser as compresser
class RolloutDataset(IterableDataset):
Transitions = namedtuple("Transitions", [
"observation", "privileged_observation", "action", "reward", "done",
])
def __init__(self,
data_dir= None,
scan_dir= None,
num_envs= 1,
dataset_loops: int= 1,
subset_traj= None, # (start_idx, end_idx) as a slice
random_shuffle_traj_order= False, # If True, the trajectories will be loaded in a random order
keep_latest_ratio= 1.0, # If < 1., only keeps this ratio of the latest trajectories
keep_latest_n_trajs= 0, # If > 0 and there are more than n trajectories, ignores keep_latest_ratio and keeps the latest n trajectories.
starting_frame_range= [0, 1], # if set, the starting timestep will be uniformly chosen from this range when each new trajectory is loaded.
# if the sampled starting frame is bigger than the trajectory length, the starting frame will be 0
load_data_to_device= True, # If True, the traj_data will be loaded directly to rl_device rather than kept as np arrays
rl_device= "cpu",
):
""" choose data_dir or scan_dir, but not both. If scan_dir is chosen, the dataset will scan the
directory and treat each direct subdirectory as a dataset every time it is initialized.
"""
self.data_dir = data_dir
self.scan_dir = scan_dir
self.num_envs = num_envs
self.max_loops = dataset_loops
self.subset_traj = subset_traj
self.random_shuffle_traj_order = random_shuffle_traj_order
self.keep_latest_ratio = keep_latest_ratio
self.keep_latest_n_trajs = keep_latest_n_trajs
self.starting_frame_range = starting_frame_range
self.load_data_to_device = load_data_to_device
self.rl_device = rl_device
# check arguments
assert not (self.data_dir is None and self.scan_dir is None), "data_dir and scan_dir cannot be both None"
self.num_looped = 0
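For reference, a minimal construction sketch of this (now removed) dataset class; the paths are placeholders, and exactly one of `data_dir`/`scan_dir` should be given:

```python
dataset = RolloutDataset(
    scan_dir="logs/collected_rollouts",  # every sub-dir with a metadata.json is treated as a dataset
    num_envs=4096,
    rl_device="cuda:0",
)
# or, pointing at a single dataset directory:
# dataset = RolloutDataset(data_dir="logs/collected_rollouts/run_000", num_envs=4096)
```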
def initialize(self):
self.load_dataset_directory()
if self.subset_traj is not None:
self.unused_traj_dirs = self.unused_traj_dirs[self.subset_traj[0]: self.subset_traj[1]]
if self.keep_latest_ratio < 1. or self.keep_latest_n_trajs > 0:
self.unused_traj_dirs = sorted(
self.unused_traj_dirs,
key= lambda x: os.stat(x).st_ctime,
)
if self.keep_latest_n_trajs > 0:
self.unused_traj_dirs = self.unused_traj_dirs[-self.keep_latest_n_trajs:]
else:
self.unused_traj_dirs = self.unused_traj_dirs[int(len(self.unused_traj_dirs) * self.keep_latest_ratio):]
print("Using a subset of trajectories, total number of trajectories: ", len(self.unused_traj_dirs))
if self.random_shuffle_traj_order:
random.shuffle(self.unused_traj_dirs)
# attributes that handles trajectory files for each env
self.current_traj_dirs = [None for _ in range(self.num_envs)]
self.trajectory_files = [[] for _ in range(self.num_envs)]
self.traj_file_idxs = np.zeros(self.num_envs, dtype= np.int32)
self.traj_step_idxs = np.zeros(self.num_envs, dtype= np.int32)
self.traj_datas = [None for _ in range(self.num_envs)]
env_idx = 0
while env_idx < self.num_envs:
if len(self.unused_traj_dirs) == 0:
print("Not enough trajectories, waiting to re-initialize. Press Enter to continue....")
input()
self.initialize()
return
starting_frame = torch.randint(self.starting_frame_range[0], self.starting_frame_range[1], (1,)).item()
update_result = self.update_traj_handle(env_idx, self.unused_traj_dirs.pop(0), starting_frame)
if update_result:
env_idx += 1
self.dataset_drained = False
def update_traj_handle(self, env_idx, traj_dir, starting_step_idx= 0):
""" Load and update the trajectory handle for a given env_idx.
Also update traj_step_idxs.
Return whether the trajectory is successfully loaded
"""
self.current_traj_dirs[env_idx] = traj_dir
try:
self.trajectory_files[env_idx] = sorted(
os.listdir(self.current_traj_dirs[env_idx]),
key= lambda x: int(x.split("_")[1]),
)
self.traj_file_idxs[env_idx] = 0
except:
self.nullify_traj_handles(env_idx)
return False
self.traj_datas[env_idx] = self.load_traj_data(
env_idx,
self.traj_file_idxs[env_idx],
new_episode= True,
)
if self.traj_datas[env_idx] is None:
self.nullify_traj_handles(env_idx)
return False
# The numbers in the file name give the timestep slice stored in that file
current_file_max_timestep = int(self.trajectory_files[env_idx][self.traj_file_idxs[env_idx]].split(".")[0].split("_")[2]) - 1
while current_file_max_timestep < starting_step_idx:
self.traj_file_idxs[env_idx] += 1
if self.traj_file_idxs[env_idx] >= len(self.trajectory_files[env_idx]):
# trajectory length is shorter than starting_step_idx, set starting_step_idx to 0
starting_step_idx = 0
self.traj_file_idxs[env_idx] = 0
break
current_file_max_timestep = int(self.trajectory_files[env_idx][self.traj_file_idxs[env_idx]].split(".")[0].split("_")[2]) - 1
current_file_min_step = int(self.trajectory_files[env_idx][self.traj_file_idxs[env_idx]].split(".")[0].split("_")[1])
self.traj_step_idxs[env_idx] = starting_step_idx - current_file_min_step
if self.traj_file_idxs[env_idx] > 0:
# reload the traj_data because traj_file_idxs is updated
self.traj_datas[env_idx] = self.load_traj_data(
env_idx,
self.traj_file_idxs[env_idx],
new_episode= True,
)
if self.traj_datas[env_idx] is None:
self.nullify_traj_handles(env_idx)
return False
return True
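The slicing logic above assumes each trajectory directory contains files whose names embed the frame range as `<prefix>_<start>_<end>.<ext>`; the exact prefix and extension are not fixed by this code. A small hypothetical helper mirroring the `split()` calls used here:

```python
def parse_frame_range(filename: str):
    """Hypothetical helper: parse '<prefix>_<start>_<end>.<ext>' into (start, end)."""
    stem = filename.split(".")[0]
    parts = stem.split("_")
    return int(parts[1]), int(parts[2])

# e.g. parse_frame_range("frames_100_200.pickle") -> (100, 200)
```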
def nullify_traj_handles(self, env_idx):
self.current_traj_dirs[env_idx] = ""
self.trajectory_files[env_idx] = []
self.traj_file_idxs[env_idx] = 0
self.traj_step_idxs[env_idx] = 0
self.traj_datas[env_idx] = None
def load_dataset_directory(self):
if self.scan_dir is not None:
if not osp.isdir(self.scan_dir):
print("RolloutDataset: scan_dir {} does not exist, creating...".format(self.scan_dir))
os.makedirs(self.scan_dir)
self.data_dir = sorted([
osp.join(self.scan_dir, x) \
for x in os.listdir(self.scan_dir) \
if osp.isdir(osp.join(self.scan_dir, x)) and osp.isfile(osp.join(self.scan_dir, x, "metadata.json"))
])
if isinstance(self.data_dir, list):
total_timesteps = 0
self.unused_traj_dirs = []
for data_dir in self.data_dir:
try:
new_trajectories = sorted([
osp.join(data_dir, x) \
for x in os.listdir(data_dir) \
if x.startswith("trajectory_") and len(os.listdir(osp.join(data_dir, x))) > 0
], key= lambda x: int(x.split("_")[-1]))
except:
continue
self.unused_traj_dirs.extend(new_trajectories)
try:
with open(osp.join(data_dir, "metadata.json"), "r") as f:
self.metadata = json.load(f, object_pairs_hook= OrderedDict)
total_timesteps += self.metadata["total_timesteps"]
except:
pass # skip
print("RolloutDataset: Loaded data from multiple directories. The metadata is from the last directory.")
print("RolloutDataset: Total number of timesteps: ", total_timesteps)
print("RolloutDataset: Total number of trajectories: ", len(self.unused_traj_dirs))
else:
self.unused_traj_dirs = sorted([
osp.join(self.data_dir, x) \
for x in os.listdir(self.data_dir) \
if x.startswith("trajectory_") and len(os.listdir(osp.join(self.data_dir, x))) > 0
], key= lambda x: int(x.split("_")[-1]))
with open(osp.join(self.data_dir, "metadata.json"), "r") as f:
self.metadata = json.load(f, object_pairs_hook= OrderedDict)
# check if this dataset is initialized in worker process
worker_info = get_worker_info()
if worker_info is not None:
self.dataset_loops = 1 # Let the sampler handle the loops
worker_id = worker_info.id
num_workers = worker_info.num_workers
trajs_per_worker = len(self.unused_traj_dirs) // num_workers
self.unused_traj_dirs = self.unused_traj_dirs[worker_id * trajs_per_worker: (worker_id + 1) * trajs_per_worker]
if worker_id == num_workers - 1:
self.unused_traj_dirs.extend(self.unused_traj_dirs[:(len(self.unused_traj_dirs) % num_workers)])
print("RolloutDataset: Worker {} of {} initialized with {} trajectories".format(
worker_id, num_workers, len(self.unused_traj_dirs)
))
def assmeble_obs_components(self, traj_data):
assert "obs_segments" in self.metadata, "Corrupted metadata, obs_segments not found in metadata"
observations = []
for component_name in self.metadata["obs_segments"].keys():
obs_component = traj_data.pop("obs_" + component_name)
if component_name in self.metadata["obs_disassemble_mapping"]:
obs_component = getattr(
compresser,
"decompress_" + self.metadata["obs_disassemble_mapping"][component_name],
)(obs_component)
observations.append(obs_component)
traj_data["observations"] = np.concatenate(observations, axis= -1) # (n_steps, d_obs)
return traj_data
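An illustrative slice of `metadata.json`, written as a Python dict, showing the two fields this method relies on. The component and codec names are made up; whatever name appears in `obs_disassemble_mapping` is assumed to have a matching `decompress_<name>` function in `data_compresser`:

```python
metadata = {
    # ordered observation components and their shapes (hypothetical example)
    "obs_segments": {"proprioception": (48,), "depth_image": (1, 48, 64)},
    # components stored compressed in the pickle files, keyed by component name
    "obs_disassemble_mapping": {"depth_image": "some_codec"},  # hypothetical codec name
}
# "obs_depth_image" would then be decoded via compresser.decompress_some_codec(...)
# before being concatenated back into the flat observation vector.
```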
def load_traj_data(self, env_idx, traj_file_idx, new_episode= False):
""" If new_episode, set the 0-th frame to done, making sure the agent is reset.
"""
traj_dir = self.current_traj_dirs[env_idx]
try:
with open(osp.join(traj_dir, self.trajectory_files[env_idx][traj_file_idx]), "rb") as f:
traj_data = pickle.load(f)
except:
try:
traj_file = osp.join(traj_dir, self.trajectory_files[env_idx][traj_file_idx])
print("Failed to load", traj_file)
except:
print("Failed to load file")
# The caller will know that the file is absent and will then switch to a new trajectory
return None
# reassemble the observation components if they are disassembled in the pickle files
if "obs_disassemble_mapping" in self.metadata:
traj_data = self.assmeble_obs_components(traj_data)
if self.load_data_to_device:
for data_key, data_val in traj_data.items():
traj_data[data_key] = torch.from_numpy(data_val).to(self.rl_device)
if new_episode:
# add done flag to the 0-th step of newly loaded trajectory
traj_data["dones"][0] = True
return traj_data
def get_transition_batch(self):
if not hasattr(self, "dataset_drained"):
# initialize the dataset if it is not used as a iterator
self.initialize()
observations = []
privileged_observations = []
actions = []
rewards = []
dones = []
time_outs = []
if self.dataset_drained:
return None, None
for env_idx in range(self.num_envs):
traj_data = self.traj_datas[env_idx]
traj_step_idx = self.traj_step_idxs[env_idx]
observations.append(traj_data["observations"][traj_step_idx])
privileged_observations.append(traj_data["privileged_observations"][traj_step_idx])
actions.append(traj_data["actions"][traj_step_idx])
rewards.append(traj_data["rewards"][traj_step_idx])
dones.append(traj_data["dones"][traj_step_idx])
if "timeouts" in traj_data: time_outs.append(traj_data["timeouts"][traj_step_idx])
self.traj_step_idxs[env_idx] += 1
traj_update_result = self.update_traj_data_if_needed(env_idx)
if traj_update_result == "drained":
self.dataset_drained = True
return None, None
elif traj_update_result == "new_traj":
dones[-1][:] = True
if torch.is_tensor(observations[0]):
observations = torch.stack(observations)
else:
observations = torch.from_numpy(np.stack(observations)).to(self.rl_device)
if torch.is_tensor(privileged_observations[0]):
privileged_observations = torch.stack(privileged_observations)
else:
privileged_observations = torch.from_numpy(np.stack(privileged_observations)).to(self.rl_device)
if torch.is_tensor(actions[0]):
actions = torch.stack(actions)
else:
actions = torch.from_numpy(np.stack(actions)).to(self.rl_device)
if torch.is_tensor(rewards[0]):
rewards = torch.stack(rewards).squeeze(-1) # to remove the last dimension as the simulator env
else:
rewards = torch.from_numpy(np.stack(rewards)).to(self.rl_device).squeeze(-1)
if torch.is_tensor(dones[0]):
dones = torch.stack(dones).to(bool).squeeze(-1)
else:
dones = torch.from_numpy(np.stack(dones)).to(self.rl_device).to(bool).squeeze(-1)
infos = dict()
if time_outs:
if torch.is_tensor(time_outs[0]):
infos["time_outs"] = torch.stack(time_outs)
else:
infos["time_outs"] = torch.from_numpy(np.stack(time_outs)).to(self.rl_device)
infos["num_looped"] = self.num_looped
return self.Transitions(
observation= observations,
privileged_observation= privileged_observations,
action= actions,
reward= rewards,
done= dones,
), infos
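A usage sketch of how a caller might drain this method directly (the TwoStageRunner above does something similar through `collect_transition_from_dataset`); `dataset` is assumed to be a constructed `RolloutDataset`:

```python
transitions, infos = dataset.get_transition_batch()
while transitions is not None:
    obs = transitions.observation    # (num_envs, d_obs)
    rewards = transitions.reward     # (num_envs,)
    dones = transitions.done         # (num_envs,) bool
    # ... feed the batch to the learner ...
    transitions, infos = dataset.get_transition_batch()
```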
def update_traj_data_if_needed(self, env_idx):
""" Return 'new_file', 'new_traj', 'drained', or None
"""
traj_data = self.traj_datas[env_idx]
if self.traj_step_idxs[env_idx] >= len(traj_data["rewards"]):
# to next file
self.traj_file_idxs[env_idx] += 1
self.traj_step_idxs[env_idx] = 0
traj_data = None
new_episode = False
while traj_data is None:
if self.traj_file_idxs[env_idx] >= len(self.trajectory_files[env_idx]):
# to next trajectory
if len(self.unused_traj_dirs) == 0 or not osp.isdir(self.unused_traj_dirs[0]):
if self.max_loops > 0 and self.num_looped >= self.max_loops:
return 'drained'
else:
self.num_looped += 1
self.initialize()
return 'new_traj'
starting_frame = torch.randint(self.starting_frame_range[0], self.starting_frame_range[1], (1,)).item()
self.update_traj_handle(env_idx, self.unused_traj_dirs.pop(0), starting_frame)
traj_data = self.traj_datas[env_idx]
else:
traj_data = self.load_traj_data(
env_idx,
self.traj_file_idxs[env_idx],
new_episode= new_episode,
)
if traj_data is None:
self.nullify_traj_handles(env_idx)
else:
self.traj_datas[env_idx] = traj_data
return 'new_file'
return None
def set_traj_idx(self, traj_idx, env_idx= 0):
""" Allow users to select a specific trajectory to start from """
self.current_traj_dirs[env_idx] = self.unused_traj_dirs[traj_idx]
self.traj_file_idxs[env_idx] = 0
self.traj_step_idxs[env_idx] = 0
self.trajectory_files[env_idx] = sorted(
os.listdir(self.current_traj_dirs[env_idx]),
key= lambda x: int(x.split("_")[1]),
)
self.traj_datas[env_idx] = self.load_traj_data(env_idx, self.traj_file_idxs[env_idx])
self.dataset_drained = False
##### Interfaces for the IterableDataset #####
def __iter__(self):
self.initialize()
transition_batch, infos = self.get_transition_batch()
while transition_batch is not None:
yield transition_batch, infos
transition_batch, infos = self.get_transition_batch()
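Because the class is an `IterableDataset` that shards its trajectory list per worker (see `load_dataset_directory`), it can also be wrapped in a standard `DataLoader`. A sketch, assuming the data is kept on the CPU (`load_data_to_device=False` or `rl_device="cpu"`) when worker processes are used:

```python
from torch.utils.data import DataLoader

# batch_size=None keeps the per-env batching that get_transition_batch already does.
loader = DataLoader(dataset, batch_size=None, num_workers=2)
for transitions, infos in loader:
    pass  # consume one (num_envs, ...) transition batch per step
```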

Some files were not shown because too many files have changed in this diff.