[upgrade] unified interface and a easier example
* embed all function and new concepts into legged_robot for base implementation * add namedarraytuple concept borrowed from astooke/rlpyt * use namedarraytuple in rollout storage and summarize a minibatch object * add `rollout_file` concept to store rollout data in files and demonstrations * add state estimator module/algo implementation * complete rewrite `play.py` example * rename `actions_scaled_torque_clipped` to `actions_scaled_clipped` * add example onboard codes for deploying on Go2.
This commit is contained in:
parent
96317ac12f
commit
1ffd6d7c05
14
README.md
14
README.md
|
@ -16,7 +16,7 @@ Conference on Robot Learning (CoRL) 2023, **Oral**, **Best Systems Paper Award F
|
|||
|
||||
## Repository Structure ##
|
||||
* `legged_gym`: contains the isaacgym environment and config files.
|
||||
- `legged_gym/legged_gym/envs/a1/`: contains all the training config files.
|
||||
- `legged_gym/legged_gym/envs/{robot}/`: contains all the training config files for a specific robot
|
||||
- `legged_gym/legged_gym/envs/base/`: contains all the environment implementation.
|
||||
- `legged_gym/legged_gym/utils/terrain/`: contains the terrain generation code.
|
||||
* `rsl_rl`: contains the network module and algorithm implementation. You can copy this folder directly to your robot.
|
||||
|
@ -24,19 +24,23 @@ Conference on Robot Learning (CoRL) 2023, **Oral**, **Best Systems Paper Award F
|
|||
- `rsl_rl/rsl_rl/modules/`: contains the network module implementation.
|
||||
|
||||
## Training in Simulation ##
|
||||
To install and run the code for training A1 in simulation, please clone this repository and follow the instructions in [legged_gym/README.md](legged_gym/README.md).
|
||||
To install and run the code for training A1/Go2 in simulation, please clone this repository and follow the instructions in [legged_gym/README.md](legged_gym/README.md).
|
||||
|
||||
## Hardware Deployment ##
|
||||
To deploy the trained model on your real robot, please follow the instructions in [Deploy.md](Deploy.md).
|
||||
To deploy the trained model on your unitree Go1 robot, please follow the instructions in [Deploy-Go1.md](onboard_codes/Deploy-Go1.md) for deploying on the Unittree Go1 robot.
|
||||
|
||||
To deploy the trained model on your unitree Go2 robot, please follow the instructions in [Deploy-Go2.md](onboard_codes/Deploy-Go2.md) for deploying on the Unittree Go2 robot.
|
||||
|
||||
|
||||
## Trouble Shooting ##
|
||||
If you cannot run the distillation part or all graphics computing goes to GPU 0 dispite you have multiple GPUs and have set the CUDA_VISIBLE_DEVICES, please use docker to isolate each GPU.
|
||||
|
||||
## To Do (will be done before Nov 2023) ##
|
||||
- [x] Go1 training configuration (not from scratch)
|
||||
## To Do ##
|
||||
- [x] Go1 training configuration (does not guarantee the same performance as the paper)
|
||||
- [ ] A1 deployment code
|
||||
- [x] Go1 deployment code
|
||||
- [x] Go2 training configuration example (does not guarantee the same performance as the paper)
|
||||
- [x] Go2 deployment code example
|
||||
|
||||
## Citation ##
|
||||
If you find this project helpful to your research, please consider cite us! This is really important to us.
|
||||
|
|
|
@ -19,6 +19,53 @@ This is the tutorial for training the skill policy and distilling the parkour po
|
|||
## Usage ##
|
||||
***Always run your script in the root path of this legged_gym folder (which contains a `setup.py` file).***
|
||||
|
||||
### Go2 example (newer and simplier, does not guarantee performance) ###
|
||||
|
||||
1. Train a walking policy for planar locomotion
|
||||
```python
|
||||
python legged_gym/scripts/train.py --headless --task go2
|
||||
```
|
||||
Training logs will be saved in `logs/rough_go2`.
|
||||
|
||||
2. Train the parkour policy. In this example, we use **scandot** for terrain perception, so we remove the **crawl** skill.
|
||||
|
||||
- Update the `"{Your trained walking model directory}"` value in the config file `legged_gym/legged_gym/envs/go2/go2_field_config.py` with the trained walking policy folder name.
|
||||
|
||||
- Run ```python legged_gym/scripts/train.py --headless --task go2_field```
|
||||
|
||||
The training logs will be saved in `logs/field_go2`.
|
||||
|
||||
3. Distill the parkour policy
|
||||
|
||||
- Update the following literals in the config file `legged_gym/legged_gym/envs/go2/go2_distill_config.py`
|
||||
|
||||
- `"{Your trained oracle parkour model folder}"`: The oracle parkour policy folder name in the last step.
|
||||
|
||||
- `"{The latest model filename in the directory}"`: The model file name in the oracle parkour policy folder in the last step.
|
||||
|
||||
- `"{A temporary directory to store collected trajectory}"`: A temporary directory to store the collected trajectory data.
|
||||
|
||||
- Calibrate your depth camera extrinsic pose and update the `position` and `rotation` field in `sensor.forward_camera` class
|
||||
|
||||
- Run distillation process (choose one among the two options)
|
||||
|
||||
1. Run the distillation process in a single process if you believe you have a powerful GPU than Nvidia RTX 3090.
|
||||
|
||||
Set `multi_process_` to `False` in the config file.
|
||||
|
||||
Run ```python legged_gym/scripts/train.py --headless --task go2_distill```
|
||||
|
||||
2. Run the distillation process in multiple processes with multiple GPUs
|
||||
|
||||
Run ```python legged_gym/scripts/train.py --headless --task go2_distill```
|
||||
|
||||
Find the log directory generated by the training process when prompted waiting. (e.g. **Jul18_07-22-08_Go2_10skills_fromJul16_07-38-08**)
|
||||
|
||||
Run ```python legged_gym/scripts/collect.py --headless --task go2_distill --log --load_run {the log directory name}``` in another terminal on another GPU. (Can run multiple collectors in parallel)
|
||||
|
||||
|
||||
### A1 example ###
|
||||
|
||||
1. The specialized skill policy is trained using `a1_field_config.py` as task `a1_field`
|
||||
|
||||
Run command with `python legged_gym/scripts/train.py --headless --task a1_field`
|
||||
|
@ -31,11 +78,11 @@ This is the tutorial for training the skill policy and distilling the parkour po
|
|||
|
||||
With `python legged_gym/scripts/collect.py --headless --task a1_distill --load_run {your training run}` you lauched the collector. The process will load the training policy and start collecting the data. The collected data will be saved in the directory prompted by the trainer. Remove it after you finish distillation.
|
||||
|
||||
### Train a walk policy ###
|
||||
#### Train a walk policy ####
|
||||
|
||||
Launch the training by `python legged_gym/scripts/train.py --headless --task a1_field`. You will find the training log in `logs/a1_field`. The folder name is also the run name.
|
||||
|
||||
### Train each separate skill ###
|
||||
#### Train each separate skill ####
|
||||
|
||||
- Launch the scirpt with task `a1_climb`, `a1_leap`, `a1_crawl`, `a1_tilt`. The training log will also be saved in `logs/a1_field`.
|
||||
|
||||
|
@ -47,7 +94,7 @@ Launch the training by `python legged_gym/scripts/train.py --headless --task a1_
|
|||
|
||||
- Do remember to update the `load_run` field in the corresponding log directory to load the policy from the previous stage.
|
||||
|
||||
### Distill the parkour policy ###
|
||||
#### Distill the parkour policy ####
|
||||
|
||||
**You will need at least two GPUs that can render in IsaacGym and have at least 24GB of memory. (typically RTX 3090)**
|
||||
|
||||
|
@ -82,3 +129,5 @@ Launch the training by `python legged_gym/scripts/train.py --headless --task a1_
|
|||
```bash
|
||||
python legged_gym/scripts/play.py --task {task} --load_run {run_name}
|
||||
```
|
||||
|
||||
Where `{run_name}` can be the absolute path of your log directory (which contains the `config.json` file).
|
||||
|
|
|
@ -32,7 +32,7 @@ from legged_gym import LEGGED_GYM_ROOT_DIR, LEGGED_GYM_ENVS_DIR
|
|||
from legged_gym.envs.a1.a1_config import A1RoughCfg, A1RoughCfgPPO, A1PlaneCfg, A1RoughCfgTPPO
|
||||
from .base.legged_robot import LeggedRobot
|
||||
from .base.legged_robot_field import LeggedRobotField
|
||||
from .base.legged_robot_noisy import LeggedRobotNoisy
|
||||
from .base.robot_field_noisy import RobotFieldNoisy
|
||||
from .anymal_c.anymal import Anymal
|
||||
from .anymal_c.mixed_terrains.anymal_c_rough_config import AnymalCRoughCfg, AnymalCRoughCfgPPO
|
||||
from .anymal_c.flat.anymal_c_flat_config import AnymalCFlatCfg, AnymalCFlatCfgPPO
|
||||
|
@ -44,6 +44,9 @@ from .a1.a1_field_config import A1FieldCfg, A1FieldCfgPPO
|
|||
from .a1.a1_field_distill_config import A1FieldDistillCfg, A1FieldDistillCfgPPO
|
||||
from .go1.go1_field_config import Go1FieldCfg, Go1FieldCfgPPO
|
||||
from .go1.go1_field_distill_config import Go1FieldDistillCfg, Go1FieldDistillCfgPPO
|
||||
from .go2.go2_config import Go2RoughCfg, Go2RoughCfgPPO
|
||||
from .go2.go2_field_config import Go2FieldCfg, Go2FieldCfgPPO
|
||||
from .go2.go2_distill_config import Go2DistillCfg, Go2DistillCfgPPO
|
||||
|
||||
|
||||
import os
|
||||
|
@ -54,35 +57,15 @@ task_registry.register( "anymal_c_rough", Anymal, AnymalCRoughCfg(), AnymalCRoug
|
|||
task_registry.register( "anymal_c_flat", Anymal, AnymalCFlatCfg(), AnymalCFlatCfgPPO() )
|
||||
task_registry.register( "anymal_b", Anymal, AnymalBRoughCfg(), AnymalBRoughCfgPPO() )
|
||||
task_registry.register( "a1", LeggedRobot, A1RoughCfg(), A1RoughCfgPPO() )
|
||||
task_registry.register( "a1_teacher", LeggedRobot, A1PlaneCfg(), A1RoughCfgTPPO() )
|
||||
task_registry.register( "a1_field", LeggedRobotNoisy, A1FieldCfg(), A1FieldCfgPPO() )
|
||||
task_registry.register( "a1_distill", LeggedRobotNoisy, A1FieldDistillCfg(), A1FieldDistillCfgPPO() )
|
||||
task_registry.register( "cassie", Cassie, CassieRoughCfg(), CassieRoughCfgPPO() )
|
||||
task_registry.register( "go1_field", LeggedRobotNoisy, Go1FieldCfg(), Go1FieldCfgPPO())
|
||||
task_registry.register( "go1_distill", LeggedRobotNoisy, Go1FieldDistillCfg(), Go1FieldDistillCfgPPO())
|
||||
task_registry.register( "go1_field", LeggedRobot, Go1FieldCfg(), Go1FieldCfgPPO())
|
||||
task_registry.register( "go1_distill", LeggedRobot, Go1FieldDistillCfg(), Go1FieldDistillCfgPPO())
|
||||
task_registry.register( "go2", LeggedRobot, Go2RoughCfg(), Go2RoughCfgPPO() )
|
||||
task_registry.register( "go2_field", RobotFieldNoisy, Go2FieldCfg(), Go2FieldCfgPPO() )
|
||||
task_registry.register( "go2_distill", RobotFieldNoisy, Go2DistillCfg(), Go2DistillCfgPPO() )
|
||||
|
||||
## The following tasks are for the convinience of opensource
|
||||
from .a1.a1_remote_config import A1RemoteCfg, A1RemoteCfgPPO
|
||||
task_registry.register( "a1_remote", LeggedRobotNoisy, A1RemoteCfg(), A1RemoteCfgPPO() )
|
||||
from .a1.a1_jump_config import A1JumpCfg, A1JumpCfgPPO
|
||||
task_registry.register( "a1_jump", LeggedRobotNoisy, A1JumpCfg(), A1JumpCfgPPO() )
|
||||
from .a1.a1_down_config import A1DownCfg, A1DownCfgPPO
|
||||
task_registry.register( "a1_down", LeggedRobotNoisy, A1DownCfg(), A1DownCfgPPO() )
|
||||
from .a1.a1_leap_config import A1LeapCfg, A1LeapCfgPPO
|
||||
task_registry.register( "a1_leap", LeggedRobotNoisy, A1LeapCfg(), A1LeapCfgPPO() )
|
||||
from .a1.a1_crawl_config import A1CrawlCfg, A1CrawlCfgPPO
|
||||
task_registry.register( "a1_crawl", LeggedRobotNoisy, A1CrawlCfg(), A1CrawlCfgPPO() )
|
||||
from .a1.a1_tilt_config import A1TiltCfg, A1TiltCfgPPO
|
||||
task_registry.register( "a1_tilt", LeggedRobotNoisy, A1TiltCfg(), A1TiltCfgPPO() )
|
||||
task_registry.register( "a1_remote", LeggedRobot, A1RemoteCfg(), A1RemoteCfgPPO() )
|
||||
from .go1.go1_remote_config import Go1RemoteCfg, Go1RemoteCfgPPO
|
||||
task_registry.register( "go1_remote", LeggedRobotNoisy, Go1RemoteCfg(), Go1RemoteCfgPPO() )
|
||||
from .go1.go1_jump_config import Go1JumpCfg, Go1JumpCfgPPO
|
||||
task_registry.register( "go1_jump", LeggedRobotNoisy, Go1JumpCfg(), Go1JumpCfgPPO() )
|
||||
from .go1.go1_down_config import Go1DownCfg, Go1DownCfgPPO
|
||||
task_registry.register( "go1_down", LeggedRobotNoisy, Go1DownCfg(), Go1DownCfgPPO() )
|
||||
from .go1.go1_leap_config import Go1LeapCfg, Go1LeapCfgPPO
|
||||
task_registry.register( "go1_leap", LeggedRobotNoisy, Go1LeapCfg(), Go1LeapCfgPPO() )
|
||||
from .go1.go1_crawl_config import Go1CrawlCfg, Go1CrawlCfgPPO
|
||||
task_registry.register( "go1_crawl", LeggedRobotNoisy, Go1CrawlCfg(), Go1CrawlCfgPPO() )
|
||||
from .go1.go1_tilt_config import Go1TiltCfg, Go1TiltCfgPPO
|
||||
task_registry.register( "go1_tilt", LeggedRobotNoisy, Go1TiltCfg(), Go1TiltCfgPPO() )
|
||||
task_registry.register( "go1_remote", LeggedRobot, Go1RemoteCfg(), Go1RemoteCfgPPO() )
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import numpy as np
|
||||
import os.path as osp
|
||||
from legged_gym.envs.a1.a1_field_config import A1FieldCfg, A1FieldCfgPPO
|
||||
from legged_gym.utils.helpers import merge_dict
|
||||
|
||||
|
@ -7,7 +8,6 @@ class A1CrawlCfg( A1FieldCfg ):
|
|||
#### uncomment this to train non-virtual terrain
|
||||
class sensor( A1FieldCfg.sensor ):
|
||||
class proprioception( A1FieldCfg.sensor.proprioception ):
|
||||
delay_action_obs = True
|
||||
latency_range = [0.04-0.0025, 0.04+0.0075]
|
||||
#### uncomment the above to train non-virtual terrain
|
||||
|
||||
|
@ -28,11 +28,11 @@ class A1CrawlCfg( A1FieldCfg ):
|
|||
wall_height= 0.6,
|
||||
no_perlin_at_obstacle= False,
|
||||
),
|
||||
virtual_terrain= True, # Change this to False for real terrain
|
||||
virtual_terrain= False, # Change this to False for real terrain
|
||||
))
|
||||
|
||||
TerrainPerlin_kwargs = merge_dict(A1FieldCfg.terrain.TerrainPerlin_kwargs, dict(
|
||||
zScale= 0.1,
|
||||
zScale= 0.12,
|
||||
))
|
||||
|
||||
class commands( A1FieldCfg.commands ):
|
||||
|
@ -41,6 +41,9 @@ class A1CrawlCfg( A1FieldCfg ):
|
|||
lin_vel_y = [0.0, 0.0]
|
||||
ang_vel_yaw = [0., 0.]
|
||||
|
||||
class asset( A1FieldCfg.asset ):
|
||||
terminate_after_contacts_on = ["base"]
|
||||
|
||||
class termination( A1FieldCfg.termination ):
|
||||
# additional factors that determines whether to terminates the episode
|
||||
termination_terms = [
|
||||
|
@ -51,16 +54,29 @@ class A1CrawlCfg( A1FieldCfg ):
|
|||
"out_of_track",
|
||||
]
|
||||
|
||||
class domain_rand( A1FieldCfg.domain_rand ):
|
||||
init_base_rot_range = dict(
|
||||
roll= [-0.1, 0.1],
|
||||
pitch= [-0.1, 0.1],
|
||||
)
|
||||
|
||||
class rewards( A1FieldCfg.rewards ):
|
||||
class scales:
|
||||
tracking_ang_vel = 0.05
|
||||
world_vel_l2norm = -1.
|
||||
legs_energy_substeps = -2e-5
|
||||
legs_energy_substeps = -1e-5
|
||||
alive = 2.
|
||||
penetrate_depth = -6e-2 # comment this out if trianing non-virtual terrain
|
||||
penetrate_volume = -6e-2 # comment this out if trianing non-virtual terrain
|
||||
exceed_dof_pos_limits = -1e-1
|
||||
exceed_torque_limits_i = -2e-1
|
||||
# penetrate_depth = -6e-2 # comment this out if trianing non-virtual terrain
|
||||
# penetrate_volume = -6e-2 # comment this out if trianing non-virtual terrain
|
||||
exceed_dof_pos_limits = -8e-1
|
||||
# exceed_torque_limits_i = -2e-1
|
||||
exceed_torque_limits_l1norm = -4e-1
|
||||
# collision = -0.05
|
||||
# tilt_cond = 0.1
|
||||
torques = -1e-5
|
||||
yaw_abs = -0.1
|
||||
lin_pos_y = -0.1
|
||||
soft_dof_pos_limit = 0.9
|
||||
|
||||
class curriculum( A1FieldCfg.curriculum ):
|
||||
penetrate_volume_threshold_harder = 1500
|
||||
|
@ -69,27 +85,49 @@ class A1CrawlCfg( A1FieldCfg ):
|
|||
penetrate_depth_threshold_easier = 400
|
||||
|
||||
|
||||
logs_root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))), "logs")
|
||||
class A1CrawlCfgPPO( A1FieldCfgPPO ):
|
||||
class algorithm( A1FieldCfgPPO.algorithm ):
|
||||
entropy_coef = 0.0
|
||||
clip_min_std = 0.2
|
||||
clip_min_std = 0.1
|
||||
|
||||
class runner( A1FieldCfgPPO.runner ):
|
||||
policy_class_name = "ActorCriticRecurrent"
|
||||
experiment_name = "field_a1"
|
||||
run_name = "".join(["Skill",
|
||||
resume = True
|
||||
load_run = "{Your traind walking model directory}"
|
||||
load_run = "{Your virtually trained crawling model directory}"
|
||||
# load_run = "Aug21_06-12-58_Skillcrawl_propDelay0.00-0.05_virtual"
|
||||
# load_run = osp.join(logs_root, "field_a1_oracle/May21_05-25-19_Skills_crawl_pEnergy2e-5_rAlive1_pPenV6e-2_pPenD6e-2_pPosY0.2_kp50_noContactTerminate_aScale0.5")
|
||||
# load_run = osp.join(logs_root, "field_a1_oracle/Sep26_01-38-19_Skills_crawl_propDelay0.04-0.05_pEnergy-4e-5_pTorqueL13e-01_kp40_fromMay21_05-25-19")
|
||||
# load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Sep26_14-30-24_Skills_crawl_propDelay0.04-0.05_pEnergy-2e-5_pDof8e-01_pTorqueL14e-01_rTilt5e-01_pCollision0.2_maxPushAng0.5_kp40_fromSep26_01-38-19")
|
||||
# load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Oct09_09-58-26_Skills_crawl_propDelay0.04-0.05_pEnergy-1e-5_pDof8e-01_pTorqueL14e-01_maxPushAng0.0_kp40_fromSep26_14-30-24")
|
||||
load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Oct11_12-19-00_Skills_crawl_propDelay0.04-0.05_pEnergy-1e-5_pDof8e-01_pTorqueL14e-01_pPosY0.1_maxPushAng0.3_kp40_fromOct09_09-58-26")
|
||||
|
||||
run_name = "".join(["Skills_",
|
||||
("Multi" if len(A1CrawlCfg.terrain.BarrierTrack_kwargs["options"]) > 1 else (A1CrawlCfg.terrain.BarrierTrack_kwargs["options"][0] if A1CrawlCfg.terrain.BarrierTrack_kwargs["options"] else "PlaneWalking")),
|
||||
("_comXRange{:.1f}-{:.1f}".format(A1CrawlCfg.domain_rand.com_range.x[0], A1CrawlCfg.domain_rand.com_range.x[1])),
|
||||
("_noLinVel" if not A1CrawlCfg.env.use_lin_vel else ""),
|
||||
("_propDelay{:.2f}-{:.2f}".format(
|
||||
A1CrawlCfg.sensor.proprioception.latency_range[0],
|
||||
A1CrawlCfg.sensor.proprioception.latency_range[1],
|
||||
) if A1CrawlCfg.sensor.proprioception.delay_action_obs else ""
|
||||
),
|
||||
("_pEnergy" + np.format_float_scientific(A1CrawlCfg.rewards.scales.legs_energy_substeps, precision= 1, exp_digits= 1, trim= "-") if A1CrawlCfg.rewards.scales.legs_energy_substeps != 0. else ""),
|
||||
# ("_pPenD{:.0e}".format(A1CrawlCfg.rewards.scales.penetrate_depth) if getattr(A1CrawlCfg.rewards.scales, "penetrate_depth", 0.) != 0. else ""),
|
||||
("_pEnergySubsteps" + np.format_float_scientific(A1CrawlCfg.rewards.scales.legs_energy_substeps, precision= 1, exp_digits= 1, trim= "-") if getattr(A1CrawlCfg.rewards.scales, "legs_energy_substeps", 0.) != 0. else ""),
|
||||
("_pDof{:.0e}".format(-A1CrawlCfg.rewards.scales.exceed_dof_pos_limits) if getattr(A1CrawlCfg.rewards.scales, "exceed_dof_pos_limits", 0.) != 0 else ""),
|
||||
("_pTorque" + np.format_float_scientific(-A1CrawlCfg.rewards.scales.torques, precision= 1, exp_digits= 1, trim= "-") if getattr(A1CrawlCfg.rewards.scales, "torques", 0.) != 0 else ""),
|
||||
("_pTorqueL1{:.0e}".format(-A1CrawlCfg.rewards.scales.exceed_torque_limits_l1norm) if getattr(A1CrawlCfg.rewards.scales, "exceed_torque_limits_l1norm", 0.) != 0 else ""),
|
||||
# ("_rTilt{:.0e}".format(A1CrawlCfg.rewards.scales.tilt_cond) if getattr(A1CrawlCfg.rewards.scales, "tilt_cond", 0.) != 0 else ""),
|
||||
# ("_pYaw{:.1f}".format(-A1CrawlCfg.rewards.scales.yaw_abs) if getattr(A1CrawlCfg.rewards.scales, "yaw_abs", 0.) != 0 else ""),
|
||||
# ("_pPosY{:.1f}".format(-A1CrawlCfg.rewards.scales.lin_pos_y) if getattr(A1CrawlCfg.rewards.scales, "lin_pos_y", 0.) != 0 else ""),
|
||||
# ("_pCollision{:.1f}".format(-A1CrawlCfg.rewards.scales.collision) if getattr(A1CrawlCfg.rewards.scales, "collision", 0.) != 0 else ""),
|
||||
# ("_kp{:d}".format(int(A1CrawlCfg.control.stiffness["joint"])) if A1CrawlCfg.control.stiffness["joint"] != 50 else ""),
|
||||
("_noDelayActObs" if not A1CrawlCfg.sensor.proprioception.delay_action_obs else ""),
|
||||
("_noTanh"),
|
||||
("_virtual" if A1CrawlCfg.terrain.BarrierTrack_kwargs["virtual_terrain"] else ""),
|
||||
("_noResume" if not resume else "_from" + "_".join(load_run.split("/")[-1].split("_")[:2])),
|
||||
])
|
||||
resume = True
|
||||
load_run = "{Your traind walking model directory}"
|
||||
load_run = "{Your virtually trained crawling model directory}"
|
||||
max_iterations = 20000
|
||||
save_interval = 500
|
||||
|
|
@ -200,6 +200,7 @@ class A1FieldCfg( A1RoughCfg ):
|
|||
exceed_dof_pos_limits = -1e-1
|
||||
exceed_torque_limits_i = -2e-1
|
||||
soft_dof_pos_limit = 0.01
|
||||
only_positive_rewards = False
|
||||
|
||||
class normalization( A1RoughCfg.normalization ):
|
||||
class obs_scales( A1RoughCfg.normalization.obs_scales ):
|
||||
|
@ -286,7 +287,7 @@ class A1FieldCfgPPO( A1RoughCfgPPO ):
|
|||
("_propDelay{:.2f}-{:.2f}".format(
|
||||
A1FieldCfg.sensor.proprioception.latency_range[0],
|
||||
A1FieldCfg.sensor.proprioception.latency_range[1],
|
||||
) if A1FieldCfg.sensor.proprioception.delay_action_obs else ""
|
||||
) if A1FieldCfg.sensor.proprioception.latency_range[1] > 0. else ""
|
||||
),
|
||||
("_aScale{:d}{:d}{:d}".format(
|
||||
int(A1FieldCfg.control.action_scale[0] * 10),
|
||||
|
@ -297,6 +298,6 @@ class A1FieldCfgPPO( A1RoughCfgPPO ):
|
|||
),
|
||||
])
|
||||
resume = False
|
||||
max_iterations = 10000
|
||||
max_iterations = 5000
|
||||
save_interval = 500
|
||||
|
|
@ -1,15 +1,15 @@
|
|||
import numpy as np
|
||||
from os import path as osp
|
||||
from legged_gym.envs.a1.a1_field_config import A1FieldCfg, A1FieldCfgPPO
|
||||
from legged_gym.utils.helpers import merge_dict
|
||||
|
||||
class A1LeapCfg( A1FieldCfg ):
|
||||
|
||||
#### uncomment this to train non-virtual terrain
|
||||
# class sensor( A1FieldCfg.sensor ):
|
||||
# class proprioception( A1FieldCfg.sensor.proprioception ):
|
||||
# delay_action_obs = True
|
||||
# latency_range = [0.04-0.0025, 0.04+0.0075]
|
||||
#### uncomment the above to train non-virtual terrain
|
||||
### uncomment this to train non-virtual terrain
|
||||
class sensor( A1FieldCfg.sensor ):
|
||||
class proprioception( A1FieldCfg.sensor.proprioception ):
|
||||
latency_range = [0.04-0.0025, 0.04+0.0075]
|
||||
### uncomment the above to train non-virtual terrain
|
||||
|
||||
class terrain( A1FieldCfg.terrain ):
|
||||
max_init_terrain_level = 2
|
||||
|
@ -22,16 +22,16 @@ class A1LeapCfg( A1FieldCfg ):
|
|||
"leap",
|
||||
],
|
||||
leap= dict(
|
||||
length= (0.2, 1.0),
|
||||
length= (0.2, 0.8),
|
||||
depth= (0.4, 0.8),
|
||||
height= 0.2,
|
||||
height= 0.12,
|
||||
),
|
||||
virtual_terrain= False, # Change this to False for real terrain
|
||||
no_perlin_threshold= 0.06,
|
||||
virtual_terrain= True, # Change this to False for real terrain
|
||||
no_perlin_threshold= 0.1,
|
||||
))
|
||||
|
||||
TerrainPerlin_kwargs = merge_dict(A1FieldCfg.terrain.TerrainPerlin_kwargs, dict(
|
||||
zScale= [0.05, 0.1],
|
||||
zScale= [0.05, 0.15],
|
||||
))
|
||||
|
||||
class commands( A1FieldCfg.commands ):
|
||||
|
@ -57,16 +57,28 @@ class A1LeapCfg( A1FieldCfg ):
|
|||
threshold= 2.0,
|
||||
))
|
||||
|
||||
class domain_rand( A1FieldCfg.domain_rand ):
|
||||
init_base_rot_range = dict(
|
||||
roll= [-0.1, 0.1],
|
||||
pitch= [-0.1, 0.1],
|
||||
)
|
||||
|
||||
class rewards( A1FieldCfg.rewards ):
|
||||
class scales:
|
||||
tracking_ang_vel = 0.05
|
||||
world_vel_l2norm = -1.
|
||||
legs_energy_substeps = -1e-6
|
||||
# legs_energy_substeps = -8e-6
|
||||
alive = 2.
|
||||
penetrate_depth = -4e-3
|
||||
penetrate_volume = -4e-3
|
||||
exceed_dof_pos_limits = -1e-1
|
||||
exceed_torque_limits_i = -2e-1
|
||||
penetrate_depth = -1e-2
|
||||
penetrate_volume = -1e-2
|
||||
exceed_dof_pos_limits = -4e-1
|
||||
exceed_torque_limits_l1norm = -8e-1
|
||||
# feet_contact_forces = -1e-2
|
||||
torques = -2e-5
|
||||
collision = -0.5
|
||||
lin_pos_y = -0.1
|
||||
yaw_abs = -0.1
|
||||
soft_dof_pos_limit = 0.5
|
||||
|
||||
class curriculum( A1FieldCfg.curriculum ):
|
||||
penetrate_volume_threshold_harder = 9000
|
||||
|
@ -75,6 +87,7 @@ class A1LeapCfg( A1FieldCfg ):
|
|||
penetrate_depth_threshold_easier = 5000
|
||||
|
||||
|
||||
logs_root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))), "logs")
|
||||
class A1LeapCfgPPO( A1FieldCfgPPO ):
|
||||
class algorithm( A1FieldCfgPPO.algorithm ):
|
||||
entropy_coef = 0.0
|
||||
|
@ -83,19 +96,39 @@ class A1LeapCfgPPO( A1FieldCfgPPO ):
|
|||
class runner( A1FieldCfgPPO.runner ):
|
||||
policy_class_name = "ActorCriticRecurrent"
|
||||
experiment_name = "field_a1"
|
||||
run_name = "".join(["Skill",
|
||||
resume = True
|
||||
load_run = "{Your traind walking model directory}"
|
||||
load_run = "{Your virtually trained leap model directory}"
|
||||
# load_run = osp.join(logs_root, "field_a1_oracle/Jun04_01-03-59_Skills_leap_pEnergySubsteps2e-6_rAlive2_pPenV4e-3_pPenD4e-3_pPosY0.20_pYaw0.20_pTorqueExceedSquare1e-3_leapH0.2_propDelay0.04-0.05_noPerlinRate0.2_aScale0.5")
|
||||
# load_run = "Sep27_02-44-48_Skills_leap_propDelay0.04-0.05_pDofLimit8e-01_pCollision0.1_kp40_kd0.5fromJun04_01-03-59"
|
||||
# load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Sep27_14-56-25_Skills_leap_propDelay0.04-0.05_pEnergySubsteps-8e-06_pDofLimit8e-01_pCollision0.1_kp40_kd0.5fromSep27_02-44-48")
|
||||
# load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Oct05_02-16-22_Skills_leap_propDelay0.04-0.05_pEnergySubsteps-8e-06_pPenD8.e-3_pDofLimit4e-01_pCollision0.5_kp40_kd0.5fromSep27_14-56-25")
|
||||
load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Oct09_09-51-58_Skills_leap_propDelay0.04-0.05_pEnergySubsteps-8e-06_pPenD1.e-2_pDofLimit4e-01_pCollision0.5_kp40_kd0.5fromOct05_02-16-22")
|
||||
|
||||
run_name = "".join(["Skills_",
|
||||
("Multi" if len(A1LeapCfg.terrain.BarrierTrack_kwargs["options"]) > 1 else (A1LeapCfg.terrain.BarrierTrack_kwargs["options"][0] if A1LeapCfg.terrain.BarrierTrack_kwargs["options"] else "PlaneWalking")),
|
||||
("_comXRange{:.1f}-{:.1f}".format(A1LeapCfg.domain_rand.com_range.x[0], A1LeapCfg.domain_rand.com_range.x[1])),
|
||||
("_noLinVel" if not A1LeapCfg.env.use_lin_vel else ""),
|
||||
("_propDelay{:.2f}-{:.2f}".format(
|
||||
A1LeapCfg.sensor.proprioception.latency_range[0],
|
||||
A1LeapCfg.sensor.proprioception.latency_range[1],
|
||||
) if A1LeapCfg.sensor.proprioception.delay_action_obs else ""
|
||||
),
|
||||
("_pEnergySubsteps{:.0e}".format(A1LeapCfg.rewards.scales.legs_energy_substeps) if A1LeapCfg.rewards.scales.legs_energy_substeps != -2e-6 else ""),
|
||||
("_pEnergySubsteps{:.0e}".format(A1LeapCfg.rewards.scales.legs_energy_substeps) if getattr(A1LeapCfg.rewards.scales, "legs_energy_substeps", -2e-6) != -2e-6 else ""),
|
||||
("_pTorques" + np.format_float_scientific(-A1LeapCfg.rewards.scales.torques, precision=1, exp_digits=1) if getattr(A1LeapCfg.rewards.scales, "torques", 0.) != -0. else ""),
|
||||
# ("_pPenD" + np.format_float_scientific(-A1LeapCfg.rewards.scales.penetrate_depth, precision=1, exp_digits=1) if A1LeapCfg.rewards.scales.penetrate_depth != -4e-3 else ""),
|
||||
# ("_pYaw{:.1f}".format(-A1LeapCfg.rewards.scales.yaw_abs) if getattr(A1LeapCfg.rewards.scales, "yaw_abs", 0.) != 0. else ""),
|
||||
# ("_pDofLimit{:.0e}".format(-A1LeapCfg.rewards.scales.exceed_dof_pos_limits) if getattr(A1LeapCfg.rewards.scales, "exceed_dof_pos_limits", 0.) != 0. else ""),
|
||||
# ("_pCollision{:.1f}".format(-A1LeapCfg.rewards.scales.collision) if getattr(A1LeapCfg.rewards.scales, "collision", 0.) != 0. else ""),
|
||||
("_pContactForces" + np.format_float_scientific(-A1LeapCfg.rewards.scales.feet_contact_forces, precision=1, exp_digits=1) if getattr(A1LeapCfg.rewards.scales, "feet_contact_forces", 0.) != 0. else ""),
|
||||
("_leapHeight{:.1f}".format(A1LeapCfg.terrain.BarrierTrack_kwargs["leap"]["height"]) if A1LeapCfg.terrain.BarrierTrack_kwargs["leap"]["height"] != 0.2 else ""),
|
||||
# ("_kp{:d}".format(int(A1LeapCfg.control.stiffness["joint"])) if A1LeapCfg.control.stiffness["joint"] != 50 else ""),
|
||||
# ("_kd{:.1f}".format(A1LeapCfg.control.damping["joint"]) if A1LeapCfg.control.damping["joint"] != 0. else ""),
|
||||
("_noDelayActObs" if not A1LeapCfg.sensor.proprioception.delay_action_obs else ""),
|
||||
("_noTanh"),
|
||||
("_virtual" if A1LeapCfg.terrain.BarrierTrack_kwargs["virtual_terrain"] else ""),
|
||||
("_noResume" if not resume else "from" + "_".join(load_run.split("/")[-1].split("_")[:2])),
|
||||
])
|
||||
resume = True
|
||||
load_run = "{Your traind walking model directory}"
|
||||
load_run = "{Your virtually trained leap model directory}"
|
||||
max_iterations = 20000
|
||||
save_interval = 500
|
||||
|
|
@ -1,452 +0,0 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
import json
|
||||
import os
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
from typing import Tuple
|
||||
|
||||
import rospy
|
||||
from unitree_legged_msgs.msg import LowState
|
||||
from unitree_legged_msgs.msg import LegsCmd
|
||||
from unitree_legged_msgs.msg import Float32MultiArrayStamped
|
||||
from std_msgs.msg import Float32MultiArray
|
||||
from geometry_msgs.msg import Twist, Pose
|
||||
from nav_msgs.msg import Odometry
|
||||
from sensor_msgs.msg import Image
|
||||
|
||||
import ros_numpy
|
||||
|
||||
@torch.no_grad()
|
||||
def resize2d(img, size):
|
||||
return (F.adaptive_avg_pool2d(Variable(img), size)).data
|
||||
|
||||
@torch.jit.script
|
||||
def quat_rotate_inverse(q, v):
|
||||
shape = q.shape
|
||||
q_w = q[:, -1]
|
||||
q_vec = q[:, :3]
|
||||
a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)
|
||||
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
|
||||
c = q_vec * \
|
||||
torch.bmm(q_vec.view(shape[0], 1, 3), v.view(
|
||||
shape[0], 3, 1)).squeeze(-1) * 2.0
|
||||
return a - b + c
|
||||
|
||||
class UnitreeA1Real:
|
||||
""" This is the handler that works for ROS 1 on unitree. """
|
||||
def __init__(self,
|
||||
robot_namespace= "a112138",
|
||||
low_state_topic= "/low_state",
|
||||
legs_cmd_topic= "/legs_cmd",
|
||||
forward_depth_topic = "/camera/depth/image_rect_raw",
|
||||
forward_depth_embedding_dims = None,
|
||||
odom_topic= "/odom/filtered",
|
||||
lin_vel_deadband= 0.2,
|
||||
ang_vel_deadband= 0.1,
|
||||
move_by_wireless_remote= False, # if True, command will not listen to move_cmd_subscriber, but wireless remote.
|
||||
cfg= dict(),
|
||||
extra_cfg= dict(),
|
||||
model_device= torch.device("cpu"),
|
||||
):
|
||||
"""
|
||||
NOTE:
|
||||
* Must call start_ros() before using this class's get_obs() and send_action()
|
||||
* Joint order of simulation and of real A1 protocol are different, see dof_names
|
||||
* We store all joints values in the order of simulation in this class
|
||||
Args:
|
||||
forward_depth_embedding_dims: If a real number, the obs will not be built as a normal env.
|
||||
The segment of obs will be subsituted by the embedding of forward depth image from the
|
||||
ROS topic.
|
||||
cfg: same config from a1_config but a dict object.
|
||||
extra_cfg: some other configs that is hard to load from file.
|
||||
"""
|
||||
self.model_device = model_device
|
||||
self.num_envs = 1
|
||||
self.robot_namespace = robot_namespace
|
||||
self.low_state_topic = low_state_topic
|
||||
self.legs_cmd_topic = legs_cmd_topic
|
||||
self.forward_depth_topic = forward_depth_topic
|
||||
self.forward_depth_embedding_dims = forward_depth_embedding_dims
|
||||
self.odom_topic = odom_topic
|
||||
self.lin_vel_deadband = lin_vel_deadband
|
||||
self.ang_vel_deadband = ang_vel_deadband
|
||||
self.move_by_wireless_remote = move_by_wireless_remote
|
||||
self.cfg = cfg
|
||||
self.extra_cfg = dict(
|
||||
torque_limits= torch.tensor([33.5] * 12, dtype= torch.float32, device= self.model_device, requires_grad= False), # Nm
|
||||
# torque_limits= torch.tensor([1, 5, 5] * 4, dtype= torch.float32, device= self.model_device, requires_grad= False), # Nm
|
||||
dof_map= [ # from isaacgym simulation joint order to URDF order
|
||||
3, 4, 5,
|
||||
0, 1, 2,
|
||||
9, 10,11,
|
||||
6, 7, 8,
|
||||
], # real_joint_idx = dof_map[sim_joint_idx]
|
||||
dof_names= [ # NOTE: order matters
|
||||
"FL_hip_joint",
|
||||
"FL_thigh_joint",
|
||||
"FL_calf_joint",
|
||||
|
||||
"FR_hip_joint",
|
||||
"FR_thigh_joint",
|
||||
"FR_calf_joint",
|
||||
|
||||
"RL_hip_joint",
|
||||
"RL_thigh_joint",
|
||||
"RL_calf_joint",
|
||||
|
||||
"RR_hip_joint",
|
||||
"RR_thigh_joint",
|
||||
"RR_calf_joint",
|
||||
],
|
||||
# motor strength is multiplied directly to the action.
|
||||
motor_strength= torch.ones(12, dtype= torch.float32, device= self.model_device, requires_grad= False),
|
||||
); self.extra_cfg.update(extra_cfg)
|
||||
if "torque_limits" in self.cfg["control"]:
|
||||
self.extra_cfg["torque_limits"][:] = self.cfg["control"]["torque_limits"]
|
||||
self.command_buf = torch.zeros((self.num_envs, 3,), device= self.model_device, dtype= torch.float32) # zeros for initialization
|
||||
self.actions = torch.zeros((1, 12), device= model_device, dtype= torch.float32)
|
||||
|
||||
self.process_configs()
|
||||
|
||||
def start_ros(self):
|
||||
# initialze several buffers so that the system works even without message update.
|
||||
# self.low_state_buffer = LowState() # not initialized, let input message update it.
|
||||
self.base_position_buffer = torch.zeros((self.num_envs, 3), device= self.model_device, requires_grad= False)
|
||||
self.legs_cmd_publisher = rospy.Publisher(
|
||||
self.robot_namespace + self.legs_cmd_topic,
|
||||
LegsCmd,
|
||||
queue_size= 1,
|
||||
)
|
||||
# self.debug_publisher = rospy.Publisher(
|
||||
# "/DNNmodel_debug",
|
||||
# Float32MultiArray,
|
||||
# queue_size= 1,
|
||||
# )
|
||||
# NOTE: this launches the subscriber callback function
|
||||
self.low_state_subscriber = rospy.Subscriber(
|
||||
self.robot_namespace + self.low_state_topic,
|
||||
LowState,
|
||||
self.update_low_state,
|
||||
queue_size= 1,
|
||||
)
|
||||
self.odom_subscriber = rospy.Subscriber(
|
||||
self.robot_namespace + self.odom_topic,
|
||||
Odometry,
|
||||
self.update_base_pose,
|
||||
queue_size= 1,
|
||||
)
|
||||
if not self.move_by_wireless_remote:
|
||||
self.move_cmd_subscriber = rospy.Subscriber(
|
||||
"/cmd_vel",
|
||||
Twist,
|
||||
self.update_move_cmd,
|
||||
queue_size= 1,
|
||||
)
|
||||
if "forward_depth" in self.all_obs_components:
|
||||
if not self.forward_depth_embedding_dims:
|
||||
self.forward_depth_subscriber = rospy.Subscriber(
|
||||
self.robot_namespace + self.forward_depth_topic,
|
||||
Image,
|
||||
self.update_forward_depth,
|
||||
queue_size= 1,
|
||||
)
|
||||
else:
|
||||
self.forward_depth_subscriber = rospy.Subscriber(
|
||||
self.robot_namespace + self.forward_depth_topic,
|
||||
Float32MultiArrayStamped,
|
||||
self.update_forward_depth_embedding,
|
||||
queue_size= 1,
|
||||
)
|
||||
self.pose_cmd_subscriber = rospy.Subscriber(
|
||||
"/body_pose",
|
||||
Pose,
|
||||
self.dummy_handler,
|
||||
queue_size= 1,
|
||||
)
|
||||
|
||||
def wait_untill_ros_working(self):
|
||||
rate = rospy.Rate(100)
|
||||
while not hasattr(self, "low_state_buffer"):
|
||||
rate.sleep()
|
||||
rospy.loginfo("UnitreeA1Real.low_state_buffer acquired, stop waiting.")
|
||||
|
||||
def process_configs(self):
|
||||
self.up_axis_idx = 2 # 2 for z, 1 for y -> adapt gravity accordingly
|
||||
self.gravity_vec = torch.zeros((self.num_envs, 3), dtype= torch.float32)
|
||||
self.gravity_vec[:, self.up_axis_idx] = -1
|
||||
|
||||
self.obs_scales = self.cfg["normalization"]["obs_scales"]
|
||||
self.obs_scales["dof_pos"] = torch.tensor(self.obs_scales["dof_pos"], device= self.model_device, dtype= torch.float32)
|
||||
if not isinstance(self.cfg["control"]["damping"]["joint"], (list, tuple)):
|
||||
self.cfg["control"]["damping"]["joint"] = [self.cfg["control"]["damping"]["joint"]] * 12
|
||||
if not isinstance(self.cfg["control"]["stiffness"]["joint"], (list, tuple)):
|
||||
self.cfg["control"]["stiffness"]["joint"] = [self.cfg["control"]["stiffness"]["joint"]] * 12
|
||||
self.d_gains = torch.tensor(self.cfg["control"]["damping"]["joint"], device= self.model_device, dtype= torch.float32)
|
||||
self.p_gains = torch.tensor(self.cfg["control"]["stiffness"]["joint"], device= self.model_device, dtype= torch.float32)
|
||||
self.default_dof_pos = torch.zeros(12, device= self.model_device, dtype= torch.float32)
|
||||
for i in range(12):
|
||||
name = self.extra_cfg["dof_names"][i]
|
||||
default_joint_angle = self.cfg["init_state"]["default_joint_angles"][name]
|
||||
self.default_dof_pos[i] = default_joint_angle
|
||||
|
||||
self.torque_limits = self.extra_cfg["torque_limits"]
|
||||
|
||||
self.commands_scale = torch.tensor([
|
||||
self.obs_scales["lin_vel"],
|
||||
self.obs_scales["lin_vel"],
|
||||
self.obs_scales["lin_vel"],
|
||||
], device= self.model_device, requires_grad= False)
|
||||
|
||||
self.obs_segments = self.get_obs_segment_from_components(self.cfg["env"]["obs_components"])
|
||||
self.num_obs = self.get_num_obs_from_components(self.cfg["env"]["obs_components"])
|
||||
components = self.cfg["env"].get("privileged_obs_components", None)
|
||||
self.privileged_obs_segments = None if components is None else self.get_num_obs_from_components(components)
|
||||
self.num_privileged_obs = None if components is None else self.get_num_obs_from_components(components)
|
||||
self.all_obs_components = self.cfg["env"]["obs_components"] + (self.cfg["env"].get("privileged_obs_components", []) if components is not None else [])
|
||||
|
||||
# store config values to attributes to improve speed
|
||||
self.clip_obs = self.cfg["normalization"]["clip_observations"]
|
||||
self.control_type = self.cfg["control"]["control_type"]
|
||||
self.action_scale = self.cfg["control"]["action_scale"]
|
||||
self.motor_strength = self.extra_cfg["motor_strength"]
|
||||
self.clip_actions = self.cfg["normalization"]["clip_actions"]
|
||||
self.dof_map = self.extra_cfg["dof_map"]
|
||||
|
||||
# get ROS params for hardware configs
|
||||
self.joint_limits_high = torch.tensor([
|
||||
rospy.get_param(self.robot_namespace + "/joint_limits/{}_max".format(s), 0.) \
|
||||
for s in ["hip", "thigh", "calf"] * 4
|
||||
])
|
||||
self.joint_limits_low = torch.tensor([
|
||||
rospy.get_param(self.robot_namespace + "/joint_limits/{}_min".format(s), 0.) \
|
||||
for s in ["hip", "thigh", "calf"] * 4
|
||||
])
|
||||
|
||||
if "forward_depth" in self.all_obs_components:
|
||||
resolution = self.cfg["sensor"]["forward_camera"].get(
|
||||
"output_resolution",
|
||||
self.cfg["sensor"]["forward_camera"]["resolution"],
|
||||
)
|
||||
if not self.forward_depth_embedding_dims:
|
||||
self.forward_depth_buf = torch.zeros(
|
||||
(self.num_envs, *resolution),
|
||||
device= self.model_device,
|
||||
dtype= torch.float32,
|
||||
)
|
||||
else:
|
||||
self.forward_depth_embedding_buf = torch.zeros(
|
||||
(1, self.forward_depth_embedding_dims),
|
||||
device= self.model_device,
|
||||
dtype= torch.float32,
|
||||
)
|
||||
|
||||
def _init_height_points(self):
|
||||
""" Returns points at which the height measurments are sampled (in base frame)
|
||||
|
||||
Returns:
|
||||
[torch.Tensor]: Tensor of shape (num_envs, self.num_height_points, 3)
|
||||
"""
|
||||
return None
|
||||
|
||||
def _get_heights(self):
|
||||
""" TODO: get estimated terrain heights around the robot base """
|
||||
# currently return a zero tensor with valid size
|
||||
return torch.zeros(self.num_envs, 187, device= self.model_device, requires_grad= False)
|
||||
|
||||
def clip_by_torque_limit(self, actions_scaled):
|
||||
""" Different from simulation, we reverse the process and clip the actions directly,
|
||||
so that the PD controller runs in robot but not our script.
|
||||
"""
|
||||
control_type = self.cfg["control"]["control_type"]
|
||||
if control_type == "P":
|
||||
p_limits_low = (-self.torque_limits) + self.d_gains*self.dof_vel
|
||||
p_limits_high = (self.torque_limits) + self.d_gains*self.dof_vel
|
||||
actions_low = (p_limits_low/self.p_gains) - self.default_dof_pos + self.dof_pos
|
||||
actions_high = (p_limits_high/self.p_gains) - self.default_dof_pos + self.dof_pos
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return torch.clip(actions_scaled, actions_low, actions_high)
|
||||
|
||||
""" Get obs components and cat to a single obs input """
|
||||
def _get_proprioception_obs(self):
|
||||
# base_ang_vel = quat_rotate_inverse(
|
||||
# torch.tensor(self.low_state_buffer.imu.quaternion).unsqueeze(0),
|
||||
# torch.tensor(self.low_state_buffer.imu.gyroscope).unsqueeze(0),
|
||||
# ).to(self.model_device)
|
||||
# NOTE: Different from the isaacgym.
|
||||
# The anglar velocity is already in base frame, no need to rotate
|
||||
base_ang_vel = torch.tensor(self.low_state_buffer.imu.gyroscope, device= self.model_device).unsqueeze(0)
|
||||
projected_gravity = quat_rotate_inverse(
|
||||
torch.tensor(self.low_state_buffer.imu.quaternion).unsqueeze(0),
|
||||
self.gravity_vec,
|
||||
).to(self.model_device)
|
||||
self.dof_pos = dof_pos = torch.tensor([
|
||||
self.low_state_buffer.motorState[self.dof_map[i]].q for i in range(12)
|
||||
], dtype= torch.float32, device= self.model_device).unsqueeze(0)
|
||||
self.dof_vel = dof_vel = torch.tensor([
|
||||
self.low_state_buffer.motorState[self.dof_map[i]].dq for i in range(12)
|
||||
], dtype= torch.float32, device= self.model_device).unsqueeze(0)
|
||||
|
||||
return torch.cat([
|
||||
torch.zeros((1, 3), device= self.model_device), # no linear velocity
|
||||
base_ang_vel * self.obs_scales["ang_vel"],
|
||||
projected_gravity,
|
||||
self.command_buf * self.commands_scale,
|
||||
(dof_pos - self.default_dof_pos) * self.obs_scales["dof_pos"],
|
||||
dof_vel * self.obs_scales["dof_vel"],
|
||||
self.actions
|
||||
], dim= -1)
|
||||
|
||||
def _get_forward_depth_obs(self):
|
||||
if not self.forward_depth_embedding_dims:
|
||||
return self.forward_depth_buf.flatten(start_dim= 1)
|
||||
else:
|
||||
return self.forward_depth_embedding_buf.flatten(start_dim= 1)
|
||||
|
||||
def compute_observation(self):
|
||||
""" use the updated low_state_buffer to compute observation vector """
|
||||
assert hasattr(self, "legs_cmd_publisher"), "start_ros() not called, ROS handlers are not initialized!"
|
||||
obs_segments = self.obs_segments
|
||||
obs = []
|
||||
for k, v in obs_segments.items():
|
||||
obs.append(
|
||||
getattr(self, "_get_" + k + "_obs")() * \
|
||||
self.obs_scales.get(k, 1.)
|
||||
)
|
||||
obs = torch.cat(obs, dim= 1)
|
||||
self.obs_buf = obs
|
||||
|
||||
""" The methods combined with outer model forms the step function
|
||||
NOTE: the outer user handles the loop frequency.
|
||||
"""
|
||||
def send_action(self, actions):
|
||||
""" The function that send commands to the real robot.
|
||||
"""
|
||||
self.actions = torch.clip(actions, -self.clip_actions, self.clip_actions)
|
||||
actions = self.actions * self.motor_strength
|
||||
robot_coordinates_action = self.clip_by_torque_limit(actions * self.action_scale) + self.default_dof_pos.unsqueeze(0)
|
||||
# robot_coordinates_action = self.actions * self.action_scale + self.default_dof_pos.unsqueeze(0)
|
||||
|
||||
# debugging and logging
|
||||
# transfered_action = torch.zeros_like(self.actions[0])
|
||||
# for i in range(12):
|
||||
# transfered_action[self.dof_map[i]] = self.actions[0, i] + self.default_dof_pos[i]
|
||||
# self.debug_publisher.publish(Float32MultiArray(data=
|
||||
# transfered_action\
|
||||
# .cpu().numpy().astype(np.float32).tolist()
|
||||
# ))
|
||||
|
||||
# restrict the target action delta in order to avoid robot shutdown (maybe there is another solution)
|
||||
# robot_coordinates_action = torch.clip(
|
||||
# robot_coordinates_action,
|
||||
# self.dof_pos - 0.3,
|
||||
# self.dof_pos + 0.3,
|
||||
# )
|
||||
|
||||
# wrap the message and publish
|
||||
self.publish_legs_cmd(robot_coordinates_action)
|
||||
|
||||
def publish_legs_cmd(self, robot_coordinates_action, kp= None, kd= None):
|
||||
""" publish the joint position directly to the robot. NOTE: The joint order from input should
|
||||
be in simulation order. The value should be absolute value rather than related to dof_pos.
|
||||
"""
|
||||
robot_coordinates_action = torch.clip(
|
||||
robot_coordinates_action.cpu(),
|
||||
self.joint_limits_low,
|
||||
self.joint_limits_high,
|
||||
)
|
||||
legs_cmd = LegsCmd()
|
||||
for sim_joint_idx in range(12):
|
||||
real_joint_idx = self.dof_map[sim_joint_idx]
|
||||
legs_cmd.cmd[real_joint_idx].mode = 10
|
||||
legs_cmd.cmd[real_joint_idx].q = robot_coordinates_action[0, sim_joint_idx] if self.control_type == "P" else rospy.get_param(self.robot_namespace + "/PosStopF", (2.146e+9))
|
||||
legs_cmd.cmd[real_joint_idx].dq = 0.
|
||||
legs_cmd.cmd[real_joint_idx].tau = 0.
|
||||
legs_cmd.cmd[real_joint_idx].Kp = self.p_gains[sim_joint_idx] if kp is None else kp
|
||||
legs_cmd.cmd[real_joint_idx].Kd = self.d_gains[sim_joint_idx] if kd is None else kd
|
||||
self.legs_cmd_publisher.publish(legs_cmd)
|
||||
|
||||
def get_obs(self):
|
||||
""" The function that refreshes the buffer and return the observation vector.
|
||||
"""
|
||||
self.compute_observation()
|
||||
self.obs_buf = torch.clip(self.obs_buf, -self.clip_obs, self.clip_obs)
|
||||
return self.obs_buf.to(self.model_device)
|
||||
|
||||
""" Copied from legged_robot_field. Please check whether these are consistent. """
|
||||
def get_obs_segment_from_components(self, components):
|
||||
segments = OrderedDict()
|
||||
if "proprioception" in components:
|
||||
segments["proprioception"] = (48,)
|
||||
if "height_measurements" in components:
|
||||
segments["height_measurements"] = (187,)
|
||||
if "forward_depth" in components:
|
||||
resolution = self.cfg["sensor"]["forward_camera"].get(
|
||||
"output_resolution",
|
||||
self.cfg["sensor"]["forward_camera"]["resolution"],
|
||||
)
|
||||
segments["forward_depth"] = (1, *resolution)
|
||||
# The following components are only for rebuilding the non-actor module.
|
||||
# DO NOT use these in actor network and check consistency with simulator implementation.
|
||||
if "base_pose" in components:
|
||||
segments["base_pose"] = (6,) # xyz + rpy
|
||||
if "robot_config" in components:
|
||||
segments["robot_config"] = (1 + 3 + 1 + 12,)
|
||||
if "engaging_block" in components:
|
||||
# This could be wrong, please check the implementation of BarrierTrack
|
||||
segments["engaging_block"] = (1 + (4 + 1) + 2,)
|
||||
if "sidewall_distance" in components:
|
||||
segments["sidewall_distance"] = (2,)
|
||||
return segments
|
||||
|
||||
def get_num_obs_from_components(self, components):
|
||||
obs_segments = self.get_obs_segment_from_components(components)
|
||||
num_obs = 0
|
||||
for k, v in obs_segments.items():
|
||||
num_obs += np.prod(v)
|
||||
return num_obs
|
||||
|
||||
""" ROS callbacks and handlers that update the buffer """
|
||||
def update_low_state(self, ros_msg):
|
||||
self.low_state_buffer = ros_msg
|
||||
if self.move_by_wireless_remote:
|
||||
self.command_buf[0, 0] = self.low_state_buffer.wirelessRemote.ly
|
||||
self.command_buf[0, 1] = -self.low_state_buffer.wirelessRemote.lx # right-moving stick is positive
|
||||
self.command_buf[0, 2] = -self.low_state_buffer.wirelessRemote.rx # right-moving stick is positive
|
||||
# set the command to zero if it is too small
|
||||
if np.linalg.norm(self.command_buf[0, :2]) < self.lin_vel_deadband:
|
||||
self.command_buf[0, :2] = 0.
|
||||
if np.abs(self.command_buf[0, 2]) < self.ang_vel_deadband:
|
||||
self.command_buf[0, 2] = 0.
|
||||
|
||||
def update_base_pose(self, ros_msg):
|
||||
""" update robot odometry for position """
|
||||
self.base_position_buffer[0, 0] = ros_msg.pose.pose.position.x
|
||||
self.base_position_buffer[0, 1] = ros_msg.pose.pose.position.y
|
||||
self.base_position_buffer[0, 2] = ros_msg.pose.pose.position.z
|
||||
|
||||
def update_move_cmd(self, ros_msg):
|
||||
self.command_buf[0, 0] = ros_msg.linear.x
|
||||
self.command_buf[0, 1] = ros_msg.linear.y
|
||||
self.command_buf[0, 2] = ros_msg.angular.z
|
||||
|
||||
def update_forward_depth(self, ros_msg):
|
||||
# TODO not checked.
|
||||
self.forward_depth_header = ros_msg.header
|
||||
buf = ros_numpy.numpify(ros_msg)
|
||||
self.forward_depth_buf = resize2d(
|
||||
torch.from_numpy(buf.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(self.model_device),
|
||||
self.forward_depth_buf.shape[-2:],
|
||||
)
|
||||
|
||||
def update_forward_depth_embedding(self, ros_msg):
|
||||
self.forward_depth_embedding_stamp = ros_msg.header.stamp
|
||||
self.forward_depth_embedding_buf[:] = torch.tensor(ros_msg.data).unsqueeze(0) # (1, d)
|
||||
|
||||
def dummy_handler(self, ros_msg):
|
||||
""" To meet the need of teleop-legged-robots requirements """
|
||||
pass
|
|
@ -1,377 +0,0 @@
|
|||
#!/home/unitree/agility_ziwenz_venv/bin/python
|
||||
import os
|
||||
import os.path as osp
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
import rospy
|
||||
from std_msgs.msg import Float32MultiArray
|
||||
from sensor_msgs.msg import Image
|
||||
import ros_numpy
|
||||
|
||||
from a1_real import UnitreeA1Real, resize2d
|
||||
from rsl_rl import modules
|
||||
from rsl_rl.utils.utils import get_obs_slice
|
||||
|
||||
@torch.no_grad()
|
||||
def handle_forward_depth(ros_msg, model, publisher, output_resolution, device):
|
||||
""" The callback function to handle the forward depth and send the embedding through ROS topic """
|
||||
buf = ros_numpy.numpify(ros_msg).astype(np.float32)
|
||||
forward_depth_buf = resize2d(
|
||||
torch.from_numpy(buf).unsqueeze(0).unsqueeze(0).to(device),
|
||||
output_resolution,
|
||||
)
|
||||
embedding = model(forward_depth_buf)
|
||||
ros_data = embedding.reshape(-1).cpu().numpy().astype(np.float32)
|
||||
publisher.publish(Float32MultiArray(data= ros_data.tolist()))
|
||||
|
||||
class StandOnlyModel(torch.nn.Module):
|
||||
def __init__(self, action_scale, dof_pos_scale, tolerance= 0.1, delta= 0.1):
|
||||
rospy.loginfo("Using stand only model, please make sure the proprioception is 48 dim.")
|
||||
rospy.loginfo("Using stand only model, -36 to -24 must be joint position.")
|
||||
super().__init__()
|
||||
if isinstance(action_scale, (tuple, list)):
|
||||
self.register_buffer("action_scale", torch.tensor(action_scale))
|
||||
else:
|
||||
self.action_scale = action_scale
|
||||
if isinstance(dof_pos_scale, (tuple, list)):
|
||||
self.register_buffer("dof_pos_scale", torch.tensor(dof_pos_scale))
|
||||
else:
|
||||
self.dof_pos_scale = dof_pos_scale
|
||||
self.tolerance = tolerance
|
||||
self.delta = delta
|
||||
|
||||
def forward(self, obs):
|
||||
joint_positions = obs[..., -36:-24] / self.dof_pos_scale
|
||||
diff_large_mask = torch.abs(joint_positions) > self.tolerance
|
||||
target_positions = torch.zeros_like(joint_positions)
|
||||
target_positions[diff_large_mask] = joint_positions[diff_large_mask] - self.delta * torch.sign(joint_positions[diff_large_mask])
|
||||
return torch.clip(
|
||||
target_positions / self.action_scale,
|
||||
-1.0, 1.0,
|
||||
)
|
||||
|
||||
def reset(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def load_walk_policy(env, model_dir):
|
||||
""" Load the walk policy from the model directory """
|
||||
if model_dir == None:
|
||||
model = StandOnlyModel(
|
||||
action_scale= env.action_scale,
|
||||
dof_pos_scale= env.obs_scales["dof_pos"],
|
||||
)
|
||||
policy = torch.jit.script(model)
|
||||
|
||||
else:
|
||||
with open(osp.join(model_dir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
obs_components = config_dict["env"]["obs_components"]
|
||||
privileged_obs_components = config_dict["env"].get("privileged_obs_components", obs_components)
|
||||
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
|
||||
num_actor_obs= env.get_num_obs_from_components(obs_components),
|
||||
num_critic_obs= env.get_num_obs_from_components(privileged_obs_components),
|
||||
num_actions= 12,
|
||||
**config_dict["policy"],
|
||||
)
|
||||
model_names = [i for i in os.listdir(model_dir) if i.startswith("model_")]
|
||||
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
|
||||
state_dict = torch.load(osp.join(model_dir, model_names[-1]), map_location= "cpu")
|
||||
model.load_state_dict(state_dict["model_state_dict"])
|
||||
model_action_scale = torch.tensor(config_dict["control"]["action_scale"]) if isinstance(config_dict["control"]["action_scale"], (tuple, list)) else torch.tensor([config_dict["control"]["action_scale"]])[0]
|
||||
if not (torch.is_tensor(model_action_scale) and (model_action_scale == env.action_scale).all()):
|
||||
action_rescale_ratio = model_action_scale / env.action_scale
|
||||
print("walk_policy action scaling:", action_rescale_ratio.tolist())
|
||||
else:
|
||||
action_rescale_ratio = 1.0
|
||||
memory_module = model.memory_a
|
||||
actor_mlp = model.actor
|
||||
@torch.jit.script
|
||||
def policy_run(obs):
|
||||
recurrent_embedding = memory_module(obs)
|
||||
actions = actor_mlp(recurrent_embedding.squeeze(0))
|
||||
return actions
|
||||
if (torch.is_tensor(action_rescale_ratio) and (action_rescale_ratio == 1.).all()) \
|
||||
or (not torch.is_tensor(action_rescale_ratio) and action_rescale_ratio == 1.):
|
||||
policy = policy_run
|
||||
else:
|
||||
policy = lambda x: policy_run(x) * action_rescale_ratio
|
||||
|
||||
return policy, model
|
||||
|
||||
def standup_procedure(env, ros_rate, angle_tolerance= 0.05, kp= None, kd= None, device= "cpu"):
|
||||
rospy.loginfo("Robot standing up, please wait ...")
|
||||
|
||||
target_pos = torch.zeros((1, 12), device= device, dtype= torch.float32)
|
||||
while not rospy.is_shutdown():
|
||||
dof_pos = [env.low_state_buffer.motorState[env.dof_map[i]].q for i in range(12)]
|
||||
diff = [env.default_dof_pos[i].item() - dof_pos[i] for i in range(12)]
|
||||
direction = [1 if i > 0 else -1 for i in diff]
|
||||
if all([abs(i) < angle_tolerance for i in diff]):
|
||||
break
|
||||
print("max joint error (rad):", max([abs(i) for i in diff]), end= "\r")
|
||||
for i in range(12):
|
||||
target_pos[0, i] = dof_pos[i] + direction[i] * angle_tolerance if abs(diff[i]) > angle_tolerance else env.default_dof_pos[i]
|
||||
env.publish_legs_cmd(target_pos,
|
||||
kp= kp,
|
||||
kd= kd,
|
||||
)
|
||||
ros_rate.sleep()
|
||||
|
||||
rospy.loginfo("Robot stood up! press R1 on the remote control to continue ...")
|
||||
while not rospy.is_shutdown():
|
||||
if env.low_state_buffer.wirelessRemote.btn.components.R1:
|
||||
break
|
||||
if env.low_state_buffer.wirelessRemote.btn.components.L2 or env.low_state_buffer.wirelessRemote.btn.components.R2:
|
||||
env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= 20, kd= 0.5)
|
||||
rospy.signal_shutdown("Controller send stop signal, exiting")
|
||||
exit(0)
|
||||
env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= kp, kd= kd)
|
||||
ros_rate.sleep()
|
||||
rospy.loginfo("Robot standing up procedure finished!")
|
||||
|
||||
class SkilledA1Real(UnitreeA1Real):
|
||||
""" Some additional methods to help the execution of skill policy """
|
||||
def __init__(self, *args,
|
||||
skill_mode_threhold= 0.1,
|
||||
skill_vel_range= [0.0, 1.0],
|
||||
**kwargs,
|
||||
):
|
||||
self.skill_mode_threhold = skill_mode_threhold
|
||||
self.skill_vel_range = skill_vel_range
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def is_skill_mode(self):
|
||||
if self.move_by_wireless_remote:
|
||||
return self.low_state_buffer.wirelessRemote.ry > self.skill_mode_threhold
|
||||
else:
|
||||
# Not implemented yet
|
||||
return False
|
||||
|
||||
def update_low_state(self, ros_msg):
|
||||
self.low_state_buffer = ros_msg
|
||||
if self.move_by_wireless_remote and ros_msg.wirelessRemote.ry > self.skill_mode_threhold:
|
||||
skill_vel = (self.low_state_buffer.wirelessRemote.ry - self.skill_mode_threhold) / (1.0 - self.skill_mode_threhold)
|
||||
skill_vel *= self.skill_vel_range[1] - self.skill_vel_range[0]
|
||||
skill_vel += self.skill_vel_range[0]
|
||||
self.command_buf[0, 0] = skill_vel
|
||||
self.command_buf[0, 1] = 0.
|
||||
self.command_buf[0, 2] = 0.
|
||||
return
|
||||
return super().update_low_state(ros_msg)
|
||||
|
||||
def main(args):
|
||||
log_level = rospy.DEBUG if args.debug else rospy.INFO
|
||||
rospy.init_node("a1_legged_gym_" + args.mode, anonymous= True, log_level= log_level)
|
||||
|
||||
with open(osp.join(args.logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
duration = config_dict["sim"]["dt"] * config_dict["control"]["decimation"] # in sec
|
||||
# config_dict["control"]["stiffness"]["joint"] -= 2.5 # kp
|
||||
|
||||
model_device = torch.device("cpu") if args.mode == "upboard" else torch.device("cuda")
|
||||
|
||||
unitree_real_env = SkilledA1Real(
|
||||
robot_namespace= args.namespace,
|
||||
cfg= config_dict,
|
||||
forward_depth_topic= "/visual_embedding" if args.mode == "upboard" else "/camera/depth/image_rect_raw",
|
||||
forward_depth_embedding_dims= config_dict["policy"]["visual_latent_size"] if args.mode == "upboard" else None,
|
||||
move_by_wireless_remote= True,
|
||||
skill_vel_range= config_dict["commands"]["ranges"]["lin_vel_x"],
|
||||
model_device= model_device,
|
||||
# extra_cfg= dict(
|
||||
# motor_strength= torch.tensor([
|
||||
# 1., 1./0.9, 1./0.9,
|
||||
# 1., 1./0.9, 1./0.9,
|
||||
# 1., 1., 1.,
|
||||
# 1., 1., 1.,
|
||||
# ], dtype= torch.float32, device= model_device, requires_grad= False),
|
||||
# ),
|
||||
)
|
||||
|
||||
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
|
||||
num_actor_obs= unitree_real_env.num_obs,
|
||||
num_critic_obs= unitree_real_env.num_privileged_obs,
|
||||
num_actions= 12,
|
||||
obs_segments= unitree_real_env.obs_segments,
|
||||
privileged_obs_segments= unitree_real_env.privileged_obs_segments,
|
||||
**config_dict["policy"],
|
||||
)
|
||||
config_dict["terrain"]["measure_heights"] = False
|
||||
# load the model with the latest checkpoint
|
||||
model_names = [i for i in os.listdir(args.logdir) if i.startswith("model_")]
|
||||
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
|
||||
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
|
||||
model.load_state_dict(state_dict["model_state_dict"])
|
||||
model.to(model_device)
|
||||
model.eval()
|
||||
|
||||
rospy.loginfo("duration: {}, motor Kp: {}, motor Kd: {}".format(
|
||||
duration,
|
||||
config_dict["control"]["stiffness"]["joint"],
|
||||
config_dict["control"]["damping"]["joint"],
|
||||
))
|
||||
rospy.loginfo("[Env] torque limit: {:.1f}".format(unitree_real_env.torque_limits.mean().item()))
|
||||
rospy.loginfo("[Env] action scale: {:.1f}".format(unitree_real_env.action_scale))
|
||||
rospy.loginfo("[Env] motor strength: {}".format(unitree_real_env.motor_strength))
|
||||
|
||||
if args.mode == "jetson":
|
||||
embeding_publisher = rospy.Publisher(
|
||||
args.namespace + "/visual_embedding",
|
||||
Float32MultiArray,
|
||||
queue_size= 1,
|
||||
)
|
||||
# extract and build the torch ScriptFunction
|
||||
visual_encoder = model.visual_encoder
|
||||
visual_encoder = torch.jit.script(visual_encoder)
|
||||
|
||||
forward_depth_subscriber = rospy.Subscriber(
|
||||
args.namespace + "/camera/depth/image_rect_raw",
|
||||
Image,
|
||||
partial(handle_forward_depth,
|
||||
model= visual_encoder,
|
||||
publisher= embeding_publisher,
|
||||
output_resolution= config_dict["sensor"]["forward_camera"].get(
|
||||
"output_resolution",
|
||||
config_dict["sensor"]["forward_camera"]["resolution"],
|
||||
),
|
||||
device= model_device,
|
||||
),
|
||||
queue_size= 1,
|
||||
)
|
||||
rospy.spin()
|
||||
elif args.mode == "upboard":
|
||||
# extract and build the torch ScriptFunction
|
||||
memory_module = model.memory_a
|
||||
actor_mlp = model.actor
|
||||
@torch.jit.script
|
||||
def policy(obs):
|
||||
recurrent_embedding = memory_module(obs)
|
||||
actions = actor_mlp(recurrent_embedding.squeeze(0))
|
||||
return actions
|
||||
|
||||
walk_policy, walk_model = load_walk_policy(unitree_real_env, args.walkdir)
|
||||
|
||||
using_walk_policy = True # switch between skill policy and walk policy
|
||||
unitree_real_env.start_ros()
|
||||
unitree_real_env.wait_untill_ros_working()
|
||||
rate = rospy.Rate(1 / duration)
|
||||
with torch.no_grad():
|
||||
if not args.debug:
|
||||
standup_procedure(unitree_real_env, rate,
|
||||
angle_tolerance= 0.1,
|
||||
kp= 50,
|
||||
kd= 1.,
|
||||
device= model_device,
|
||||
)
|
||||
while not rospy.is_shutdown():
|
||||
# inference_start_time = rospy.get_time()
|
||||
# check remote controller and decide which policy to use
|
||||
if unitree_real_env.is_skill_mode():
|
||||
if using_walk_policy:
|
||||
rospy.loginfo_throttle(0.1, "switch to skill policy")
|
||||
using_walk_policy = False
|
||||
model.reset()
|
||||
else:
|
||||
if not using_walk_policy:
|
||||
rospy.loginfo_throttle(0.1, "switch to walk policy")
|
||||
using_walk_policy = True
|
||||
walk_model.reset()
|
||||
if not using_walk_policy:
|
||||
obs = unitree_real_env.get_obs()
|
||||
actions = policy(obs)
|
||||
else:
|
||||
walk_obs = unitree_real_env._get_proprioception_obs()
|
||||
actions = walk_policy(walk_obs)
|
||||
unitree_real_env.send_action(actions)
|
||||
# unitree_real_env.send_action(torch.zeros((1, 12)))
|
||||
# inference_duration = rospy.get_time() - inference_start_time
|
||||
# rospy.loginfo("inference duration: {:.3f}".format(inference_duration))
|
||||
# rospy.loginfo("visual_latency: %f", rospy.get_time() - unitree_real_env.forward_depth_embedding_stamp.to_sec())
|
||||
# motor_temperatures = [motor_state.temperature for motor_state in unitree_real_env.low_state_buffer.motorState]
|
||||
# rospy.loginfo_throttle(10, " ".join(["motor_temperatures:"] + ["{:d},".format(t) for t in motor_temperatures[:12]]))
|
||||
rate.sleep()
|
||||
if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.down:
|
||||
rospy.loginfo_throttle(0.1, "model reset")
|
||||
model.reset()
|
||||
walk_model.reset()
|
||||
if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.L2 or unitree_real_env.low_state_buffer.wirelessRemote.btn.components.R2:
|
||||
unitree_real_env.publish_legs_cmd(unitree_real_env.default_dof_pos.unsqueeze(0), kp= 20, kd= 0.5)
|
||||
rospy.signal_shutdown("Controller send stop signal, exiting")
|
||||
elif args.mode == "full":
|
||||
# extract and build the torch ScriptFunction
|
||||
visual_obs_slice = get_obs_slice(unitree_real_env.obs_segments, "forward_depth")
|
||||
visual_encoder = model.visual_encoder
|
||||
memory_module = model.memory_a
|
||||
actor_mlp = model.actor
|
||||
@torch.jit.script
|
||||
def policy(observations: torch.Tensor, obs_start: int, obs_stop: int, obs_shape: Tuple[int, int, int]):
|
||||
visual_latent = visual_encoder(
|
||||
observations[..., obs_start:obs_stop].reshape(-1, *obs_shape)
|
||||
).reshape(1, -1)
|
||||
obs = torch.cat([
|
||||
observations[..., :obs_start],
|
||||
visual_latent,
|
||||
observations[..., obs_stop:],
|
||||
], dim= -1)
|
||||
recurrent_embedding = memory_module(obs)
|
||||
actions = actor_mlp(recurrent_embedding.squeeze(0))
|
||||
return actions
|
||||
|
||||
unitree_real_env.start_ros()
|
||||
unitree_real_env.wait_untill_ros_working()
|
||||
rate = rospy.Rate(1 / duration)
|
||||
with torch.no_grad():
|
||||
while not rospy.is_shutdown():
|
||||
# inference_start_time = rospy.get_time()
|
||||
obs = unitree_real_env.get_obs()
|
||||
actions = policy(obs,
|
||||
obs_start= visual_obs_slice[0].start.item(),
|
||||
obs_stop= visual_obs_slice[0].stop.item(),
|
||||
obs_shape= visual_obs_slice[1],
|
||||
)
|
||||
unitree_real_env.send_action(actions)
|
||||
# inference_duration = rospy.get_time() - inference_start_time
|
||||
motor_temperatures = [motor_state.temperature for motor_state in unitree_real_env.low_state_buffer.motorState]
|
||||
rospy.loginfo_throttle(10, " ".join(["motor_temperatures:"] + ["{:d},".format(t) for t in motor_temperatures[:12]]))
|
||||
rate.sleep()
|
||||
if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.L2 or unitree_real_env.low_state_buffer.wirelessRemote.btn.components.R2:
|
||||
unitree_real_env.publish_legs_cmd(unitree_real_env.default_dof_pos.unsqueeze(0), kp= 20, kd= 0.5)
|
||||
rospy.signal_shutdown("Controller send stop signal, exiting")
|
||||
else:
|
||||
rospy.logfatal("Unknown mode, exiting")
|
||||
|
||||
if __name__ == "__main__":
|
||||
""" The script to run the A1 script in ROS.
|
||||
It's designed as a main function and not designed to be a scalable code.
|
||||
"""
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--namespace",
|
||||
type= str,
|
||||
default= "/a112138",
|
||||
)
|
||||
parser.add_argument("--logdir",
|
||||
type= str,
|
||||
help= "The log directory of the trained model",
|
||||
)
|
||||
parser.add_argument("--walkdir",
|
||||
type= str,
|
||||
help= "The log directory of the walking model, not for the skills.",
|
||||
default= None,
|
||||
)
|
||||
parser.add_argument("--mode",
|
||||
type= str,
|
||||
help= "The mode to determine which computer to run on.",
|
||||
choices= ["jetson", "upboard", "full"],
|
||||
)
|
||||
parser.add_argument("--debug",
|
||||
action= "store_true",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -1,15 +1,15 @@
|
|||
import numpy as np
|
||||
from os import path as osp
|
||||
from legged_gym.envs.a1.a1_field_config import A1FieldCfg, A1FieldCfgPPO
|
||||
from legged_gym.utils.helpers import merge_dict
|
||||
|
||||
class A1TiltCfg( A1FieldCfg ):
|
||||
|
||||
#### uncomment this to train non-virtual terrain
|
||||
# class sensor( A1FieldCfg.sensor ):
|
||||
# class proprioception( A1FieldCfg.sensor.proprioception ):
|
||||
# delay_action_obs = True
|
||||
# latency_range = [0.04-0.0025, 0.04+0.0075]
|
||||
#### uncomment the above to train non-virtual terrain
|
||||
### uncomment this to train non-virtual terrain
|
||||
class sensor( A1FieldCfg.sensor ):
|
||||
class proprioception( A1FieldCfg.sensor.proprioception ):
|
||||
latency_range = [0.04-0.0025, 0.04+0.0075]
|
||||
### uncomment the above to train non-virtual terrain
|
||||
|
||||
class terrain( A1FieldCfg.terrain ):
|
||||
max_init_terrain_level = 2
|
||||
|
@ -41,6 +41,9 @@ class A1TiltCfg( A1FieldCfg ):
|
|||
lin_vel_y = [0.0, 0.0]
|
||||
ang_vel_yaw = [0., 0.]
|
||||
|
||||
class asset( A1FieldCfg.asset ):
|
||||
penalize_contacts_on = ["base"]
|
||||
|
||||
class termination( A1FieldCfg.termination ):
|
||||
# additional factors that determines whether to terminates the episode
|
||||
termination_terms = [
|
||||
|
@ -52,8 +55,12 @@ class A1TiltCfg( A1FieldCfg ):
|
|||
]
|
||||
|
||||
class domain_rand( A1FieldCfg.domain_rand ):
|
||||
# push_robots = True # use for virtual training
|
||||
push_robots = False # use for non-virtual training
|
||||
push_robots = True # use for virtual training
|
||||
# push_robots = False # use for non-virtual training
|
||||
init_base_rot_range = dict(
|
||||
roll= [-0.1, 0.1],
|
||||
pitch= [-0.1, 0.1],
|
||||
)
|
||||
|
||||
class rewards( A1FieldCfg.rewards ):
|
||||
class scales:
|
||||
|
@ -61,10 +68,12 @@ class A1TiltCfg( A1FieldCfg ):
|
|||
world_vel_l2norm = -1.
|
||||
legs_energy_substeps = -1e-5
|
||||
alive = 2.
|
||||
penetrate_depth = -3e-3
|
||||
penetrate_volume = -3e-3
|
||||
exceed_dof_pos_limits = -1e-1
|
||||
exceed_torque_limits_i = -2e-1
|
||||
penetrate_depth = -2e-3
|
||||
penetrate_volume = -2e-3
|
||||
exceed_dof_pos_limits = -8e-1
|
||||
exceed_torque_limits_l1norm = -8e-1
|
||||
tilt_cond = 8e-3
|
||||
collision = -0.1
|
||||
|
||||
class curriculum( A1FieldCfg.curriculum ):
|
||||
penetrate_volume_threshold_harder = 4000
|
||||
|
@ -73,6 +82,7 @@ class A1TiltCfg( A1FieldCfg ):
|
|||
penetrate_depth_threshold_easier = 300
|
||||
|
||||
|
||||
logs_root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))), "logs")
|
||||
class A1TiltCfgPPO( A1FieldCfgPPO ):
|
||||
class algorithm( A1FieldCfgPPO.algorithm ):
|
||||
entropy_coef = 0.0
|
||||
|
@ -81,24 +91,38 @@ class A1TiltCfgPPO( A1FieldCfgPPO ):
|
|||
class runner( A1FieldCfgPPO.runner ):
|
||||
policy_class_name = "ActorCriticRecurrent"
|
||||
experiment_name = "field_a1"
|
||||
run_name = "".join(["Skill",
|
||||
("Multi" if len(A1TiltCfg.terrain.BarrierTrack_kwargs["options"]) > 1 else (A1TiltCfg.terrain.BarrierTrack_kwargs["options"][0] if A1TiltCfg.terrain.BarrierTrack_kwargs["options"] else "PlaneWalking")),
|
||||
("_propDelay{:.2f}-{:.2f}".format(
|
||||
A1TiltCfg.sensor.proprioception.latency_range[0],
|
||||
A1TiltCfg.sensor.proprioception.latency_range[1],
|
||||
) if A1TiltCfg.sensor.proprioception.delay_action_obs else ""
|
||||
),
|
||||
("_pPenV" + np.format_float_scientific(-A1TiltCfg.rewards.scales.penetrate_volume, trim= "-", exp_digits= 1) if getattr(A1TiltCfg.rewards.scales, "penetrate_volume", 0.) < 0. else ""),
|
||||
("_pPenD" + np.format_float_scientific(-A1TiltCfg.rewards.scales.penetrate_depth, trim= "-", exp_digits= 1) if getattr(A1TiltCfg.rewards.scales, "penetrate_depth", 0.) < 0. else ""),
|
||||
("_noPush" if not A1TiltCfg.domain_rand.push_robots else ""),
|
||||
("_tiltMax{:.2f}".format(A1TiltCfg.terrain.BarrierTrack_kwargs["tilt"]["width"][1])),
|
||||
("_virtual" if A1TiltCfg.terrain.BarrierTrack_kwargs["virtual_terrain"] else ""),
|
||||
])
|
||||
resume = True
|
||||
load_run = "{Your traind walking model directory}"
|
||||
load_run = "{Your virtual terrain model directory}"
|
||||
load_run = "Aug17_11-13-14_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
|
||||
load_run = "Aug23_22-03-41_Skilltilt_pPenV3e-3_pPenD3e-3_tiltMax0.40_virtual"
|
||||
# load_run = osp.join(logs_root, "field_a1_oracle/Aug08_05-22-52_Skills_tilt_pEnergySubsteps1e-5_rAlive1_pPenV5e-3_pPenD5e-3_pPosY0.50_pYaw0.50_rTilt7e-1_pTorqueExceedIndicate1e-1_virtualTerrain_propDelay0.04-0.05_push")
|
||||
# load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Sep27_13-59-27_Skills_tilt_propDelay0.04-0.05_pEnergySubsteps1e-5_pPenD2e-3_pDofLimit8e-1_rTilt8e-03_pCollision0.1_noPush_kp40_kd0.5_tiltMax0.40fromAug08_05-22-52")
|
||||
load_run = osp.join(logs_root, "field_a1_noTanh_oracle", "Oct11_12-24-22_Skills_tilt_propDelay0.04-0.05_pEnergySubsteps1e-5_pPenD2e-3_pDofLimit8e-1_rTilt8e-03_pCollision0.1_PushRobot_kp40_kd0.5_tiltMax0.40fromSep27_13-59-27")
|
||||
|
||||
run_name = "".join(["Skills_",
|
||||
("Multi" if len(A1TiltCfg.terrain.BarrierTrack_kwargs["options"]) > 1 else (A1TiltCfg.terrain.BarrierTrack_kwargs["options"][0] if A1TiltCfg.terrain.BarrierTrack_kwargs["options"] else "PlaneWalking")),
|
||||
("_comXRange{:.1f}-{:.1f}".format(A1TiltCfg.domain_rand.com_range.x[0], A1TiltCfg.domain_rand.com_range.x[1])),
|
||||
("_noLinVel" if not A1TiltCfg.env.use_lin_vel else ""),
|
||||
("_propDelay{:.2f}-{:.2f}".format(
|
||||
A1TiltCfg.sensor.proprioception.latency_range[0],
|
||||
A1TiltCfg.sensor.proprioception.latency_range[1],
|
||||
) if A1TiltCfg.sensor.proprioception.delay_action_obs else ""
|
||||
),
|
||||
# ("_pEnergySubsteps" + np.format_float_scientific(-A1TiltCfg.rewards.scales.legs_energy_substeps, trim= "-", exp_digits= 1) if getattr(A1TiltCfg.rewards.scales, "legs_energy_substeps", 0.) < 0. else ""),
|
||||
# ("_pPenV" + np.format_float_scientific(-A1TiltCfg.rewards.scales.penetrate_volume, trim= "-", exp_digits= 1) if getattr(A1TiltCfg.rewards.scales, "penetrate_volume", 0.) < 0. else ""),
|
||||
("_pPenD" + np.format_float_scientific(-A1TiltCfg.rewards.scales.penetrate_depth, trim= "-", exp_digits= 1) if getattr(A1TiltCfg.rewards.scales, "penetrate_depth", 0.) < 0. else ""),
|
||||
("_pDofLimit" + np.format_float_scientific(-A1TiltCfg.rewards.scales.exceed_dof_pos_limits, trim= "-", exp_digits= 1) if getattr(A1TiltCfg.rewards.scales, "exceed_dof_pos_limits", 0.) < 0. else ""),
|
||||
# ("_rTilt{:.0e}".format(A1TiltCfg.rewards.scales.tilt_cond) if getattr(A1TiltCfg.rewards.scales, "tilt_cond", 0.) > 0. else ""),
|
||||
# ("_pCollision{:.1f}".format(-A1TiltCfg.rewards.scales.collision) if getattr(A1TiltCfg.rewards.scales, "collision", 0.) != 0. else ""),
|
||||
("_noPush" if not A1TiltCfg.domain_rand.push_robots else "_PushRobot"),
|
||||
# ("_kp{:d}".format(int(A1TiltCfg.control.stiffness["joint"])) if A1TiltCfg.control.stiffness["joint"] != 50 else ""),
|
||||
# ("_kd{:.1f}".format(A1TiltCfg.control.damping["joint"]) if A1TiltCfg.control.damping["joint"] != 0. else ""),
|
||||
("_noTanh"),
|
||||
("_tiltMax{:.2f}".format(A1TiltCfg.terrain.BarrierTrack_kwargs["tilt"]["width"][1])),
|
||||
("_virtual" if A1TiltCfg.terrain.BarrierTrack_kwargs["virtual_terrain"] else ""),
|
||||
("_noResume" if not resume else "from" + "_".join(load_run.split("/")[-1].split("_")[:2])),
|
||||
])
|
||||
max_iterations = 20000
|
||||
save_interval = 500
|
||||
|
|
@ -1,309 +0,0 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import torch
|
||||
import json
|
||||
from functools import partial
|
||||
from collections import OrderedDict
|
||||
|
||||
from a1_real import UnitreeA1Real, resize2d
|
||||
from rsl_rl import modules
|
||||
|
||||
import rospy
|
||||
from unitree_legged_msgs.msg import Float32MultiArrayStamped
|
||||
from sensor_msgs.msg import Image
|
||||
import ros_numpy
|
||||
|
||||
import pyrealsense2 as rs
|
||||
|
||||
def get_encoder_script(logdir):
|
||||
with open(osp.join(logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
|
||||
model_device = torch.device("cuda")
|
||||
|
||||
unitree_real_env = UnitreeA1Real(
|
||||
robot_namespace= "DummyUnitreeA1Real",
|
||||
cfg= config_dict,
|
||||
forward_depth_topic= "", # this env only computes parameters to build the model
|
||||
forward_depth_embedding_dims= None,
|
||||
model_device= model_device,
|
||||
)
|
||||
|
||||
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
|
||||
num_actor_obs= unitree_real_env.num_obs,
|
||||
num_critic_obs= unitree_real_env.num_privileged_obs,
|
||||
num_actions= 12,
|
||||
obs_segments= unitree_real_env.obs_segments,
|
||||
privileged_obs_segments= unitree_real_env.privileged_obs_segments,
|
||||
**config_dict["policy"],
|
||||
)
|
||||
model_names = [i for i in os.listdir(logdir) if i.startswith("model_")]
|
||||
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
|
||||
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
|
||||
model.load_state_dict(state_dict["model_state_dict"])
|
||||
model.to(model_device)
|
||||
model.eval()
|
||||
|
||||
visual_encoder = model.visual_encoder
|
||||
script = torch.jit.script(visual_encoder)
|
||||
|
||||
return script, model_device
|
||||
|
||||
def get_input_filter(args):
|
||||
""" This is the filter different from the simulator, but try to close the gap. """
|
||||
with open(osp.join(args.logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
image_resolution = config_dict["sensor"]["forward_camera"].get(
|
||||
"output_resolution",
|
||||
config_dict["sensor"]["forward_camera"]["resolution"],
|
||||
)
|
||||
depth_range = config_dict["sensor"]["forward_camera"].get(
|
||||
"depth_range",
|
||||
[0.0, 3.0],
|
||||
)
|
||||
depth_range = (depth_range[0] * 1000, depth_range[1] * 1000) # [m] -> [mm]
|
||||
crop_top, crop_bottom, crop_left, crop_right = args.crop_top, args.crop_bottom, args.crop_left, args.crop_right
|
||||
crop_far = args.crop_far * 1000
|
||||
|
||||
def input_filter(depth_image: torch.Tensor,
|
||||
crop_top: int,
|
||||
crop_bottom: int,
|
||||
crop_left: int,
|
||||
crop_right: int,
|
||||
crop_far: float,
|
||||
depth_min: int,
|
||||
depth_max: int,
|
||||
output_height: int,
|
||||
output_width: int,
|
||||
):
|
||||
""" depth_image must have shape [1, 1, H, W] """
|
||||
depth_image = depth_image[:, :,
|
||||
crop_top: -crop_bottom-1,
|
||||
crop_left: -crop_right-1,
|
||||
]
|
||||
depth_image[depth_image > crop_far] = depth_max
|
||||
depth_image = torch.clip(
|
||||
depth_image,
|
||||
depth_min,
|
||||
depth_max,
|
||||
) / (depth_max - depth_min)
|
||||
depth_image = resize2d(depth_image, (output_height, output_width))
|
||||
return depth_image
|
||||
# input_filter = torch.jit.script(input_filter)
|
||||
|
||||
return partial(input_filter,
|
||||
crop_top= crop_top,
|
||||
crop_bottom= crop_bottom,
|
||||
crop_left= crop_left,
|
||||
crop_right= crop_right,
|
||||
crop_far= crop_far,
|
||||
depth_min= depth_range[0],
|
||||
depth_max= depth_range[1],
|
||||
output_height= image_resolution[0],
|
||||
output_width= image_resolution[1],
|
||||
), depth_range
|
||||
|
||||
def get_started_pipeline(
|
||||
height= 480,
|
||||
width= 640,
|
||||
fps= 30,
|
||||
enable_rgb= False,
|
||||
):
|
||||
# By default, rgb is not used.
|
||||
pipeline = rs.pipeline()
|
||||
config = rs.config()
|
||||
config.enable_stream(rs.stream.depth, width, height, rs.format.z16, fps)
|
||||
if enable_rgb:
|
||||
config.enable_stream(rs.stream.color, width, height, rs.format.rgb8, fps)
|
||||
profile = pipeline.start(config)
|
||||
|
||||
# build the sensor filter
|
||||
hole_filling_filter = rs.hole_filling_filter(2)
|
||||
spatial_filter = rs.spatial_filter()
|
||||
spatial_filter.set_option(rs.option.filter_magnitude, 5)
|
||||
spatial_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
|
||||
spatial_filter.set_option(rs.option.filter_smooth_delta, 1)
|
||||
spatial_filter.set_option(rs.option.holes_fill, 4)
|
||||
temporal_filter = rs.temporal_filter()
|
||||
temporal_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
|
||||
temporal_filter.set_option(rs.option.filter_smooth_delta, 1)
|
||||
# decimation_filter = rs.decimation_filter()
|
||||
# decimation_filter.set_option(rs.option.filter_magnitude, 2)
|
||||
|
||||
def filter_func(frame):
|
||||
frame = hole_filling_filter.process(frame)
|
||||
frame = spatial_filter.process(frame)
|
||||
frame = temporal_filter.process(frame)
|
||||
# frame = decimation_filter.process(frame)
|
||||
return frame
|
||||
|
||||
return pipeline, filter_func
|
||||
|
||||
def main(args):
|
||||
rospy.init_node("a1_legged_gym_jetson", anonymous= True)
|
||||
|
||||
input_filter, depth_range = get_input_filter(args)
|
||||
model_script, model_device = get_encoder_script(args.logdir)
|
||||
with open(osp.join(args.logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
if config_dict.get("sensor", dict()).get("forward_camera", dict()).get("refresh_duration", None) is not None:
|
||||
refresh_duration = config_dict["sensor"]["forward_camera"]["refresh_duration"]
|
||||
ros_rate = rospy.Rate(1.0 / refresh_duration)
|
||||
rospy.loginfo("Using refresh duration {}s".format(refresh_duration))
|
||||
else:
|
||||
ros_rate = rospy.Rate(args.fps)
|
||||
|
||||
rs_pipeline, rs_filters = get_started_pipeline(
|
||||
height= args.height,
|
||||
width= args.width,
|
||||
fps= args.fps,
|
||||
enable_rgb= args.enable_rgb,
|
||||
)
|
||||
|
||||
embedding_publisher = rospy.Publisher(
|
||||
args.namespace + "/visual_embedding",
|
||||
Float32MultiArrayStamped,
|
||||
queue_size= 1,
|
||||
)
|
||||
|
||||
if args.enable_vis:
|
||||
depth_image_publisher = rospy.Publisher(
|
||||
args.namespace + "/camera/depth/image_rect_raw",
|
||||
Image,
|
||||
queue_size= 1,
|
||||
)
|
||||
network_input_publisher = rospy.Publisher(
|
||||
args.namespace + "/camera/depth/network_input_raw",
|
||||
Image,
|
||||
queue_size= 1,
|
||||
)
|
||||
if args.enable_rgb:
|
||||
rgb_image_publisher = rospy.Publisher(
|
||||
args.namespace + "/camera/color/image_raw",
|
||||
Image,
|
||||
queue_size= 1,
|
||||
)
|
||||
|
||||
rospy.loginfo("Depth range is clipped to [{}, {}] and normalized".format(depth_range[0], depth_range[1]))
|
||||
rospy.loginfo("ROS, model, realsense have been initialized.")
|
||||
if args.enable_vis:
|
||||
rospy.loginfo("Visualization enabled, sending depth{} images".format(", rgb" if args.enable_rgb else ""))
|
||||
try:
|
||||
embedding_msg = Float32MultiArrayStamped()
|
||||
embedding_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
|
||||
frame_got = False
|
||||
while not rospy.is_shutdown():
|
||||
# Wait for the depth image
|
||||
frames = rs_pipeline.wait_for_frames()
|
||||
embedding_msg.header.stamp = rospy.Time.now()
|
||||
depth_frame = frames.get_depth_frame()
|
||||
if not depth_frame:
|
||||
continue
|
||||
if not frame_got:
|
||||
frame_got = True
|
||||
rospy.loginfo("Realsense frame recieved. Sending embeddings...")
|
||||
if args.enable_rgb:
|
||||
color_frame = frames.get_color_frame()
|
||||
# Use this branch to log the time when image is acquired
|
||||
if args.enable_vis and not color_frame is None:
|
||||
color_frame = np.asanyarray(color_frame.get_data())
|
||||
rgb_image_msg = ros_numpy.msgify(Image, color_frame, encoding= "rgb8")
|
||||
rgb_image_msg.header.stamp = rospy.Time.now()
|
||||
rgb_image_msg.header.frame_id = args.namespace + "/camera_color_optical_frame"
|
||||
rgb_image_publisher.publish(rgb_image_msg)
|
||||
|
||||
# Process the depth image and publish
|
||||
depth_frame = rs_filters(depth_frame)
|
||||
depth_image_ = np.asanyarray(depth_frame.get_data())
|
||||
depth_image = torch.from_numpy(depth_image_.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(model_device)
|
||||
depth_image = input_filter(depth_image)
|
||||
with torch.no_grad():
|
||||
depth_embedding = model_script(depth_image).reshape(-1).cpu().numpy()
|
||||
embedding_msg.header.seq += 1
|
||||
embedding_msg.data = depth_embedding.tolist()
|
||||
embedding_publisher.publish(embedding_msg)
|
||||
|
||||
# Publish the acquired image if needed
|
||||
if args.enable_vis:
|
||||
depth_image_msg = ros_numpy.msgify(Image, depth_image_, encoding= "16UC1")
|
||||
depth_image_msg.header.stamp = rospy.Time.now()
|
||||
depth_image_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
|
||||
depth_image_publisher.publish(depth_image_msg)
|
||||
network_input_np = (\
|
||||
depth_image.detach().cpu().numpy()[0, 0] * (depth_range[1] - depth_range[0]) \
|
||||
+ depth_range[0]
|
||||
).astype(np.uint16)
|
||||
network_input_msg = ros_numpy.msgify(Image, network_input_np, encoding= "16UC1")
|
||||
network_input_msg.header.stamp = rospy.Time.now()
|
||||
network_input_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
|
||||
network_input_publisher.publish(network_input_msg)
|
||||
ros_rate.sleep()
|
||||
finally:
|
||||
rs_pipeline.stop()
|
||||
|
||||
if __name__ == "__main__":
|
||||
""" This script is designed to load the model and process the realsense image directly
|
||||
from realsense SDK without realsense ROS wrapper
|
||||
"""
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--namespace",
|
||||
type= str,
|
||||
default= "/a112138",
|
||||
)
|
||||
parser.add_argument("--logdir",
|
||||
type= str,
|
||||
help= "The log directory of the trained model",
|
||||
)
|
||||
parser.add_argument("--height",
|
||||
type= int,
|
||||
default= 240,
|
||||
help= "The height of the realsense image",
|
||||
)
|
||||
parser.add_argument("--width",
|
||||
type= int,
|
||||
default= 424,
|
||||
help= "The width of the realsense image",
|
||||
)
|
||||
parser.add_argument("--fps",
|
||||
type= int,
|
||||
default= 30,
|
||||
help= "The fps of the realsense image",
|
||||
)
|
||||
parser.add_argument("--crop_left",
|
||||
type= int,
|
||||
default= 60,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_right",
|
||||
type= int,
|
||||
default= 46,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_top",
|
||||
type= int,
|
||||
default= 0,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_bottom",
|
||||
type= int,
|
||||
default= 0,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_far",
|
||||
type= float,
|
||||
default= 3.0,
|
||||
help= "asside from the config far limit, make all depth readings larger than this value to be 3.0 in un-normalized network input."
|
||||
)
|
||||
parser.add_argument("--enable_rgb",
|
||||
action= "store_true",
|
||||
help= "Whether to enable rgb image",
|
||||
)
|
||||
parser.add_argument("--enable_vis",
|
||||
action= "store_true",
|
||||
help= "Whether to publish realsense image",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -34,6 +34,8 @@ from isaacgym import gymutil
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from legged_gym.utils.webviewer import WebViewer
|
||||
|
||||
# Base class for RL tasks
|
||||
class BaseTask():
|
||||
|
||||
|
@ -66,6 +68,12 @@ class BaseTask():
|
|||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
|
||||
self.extras = {}
|
||||
|
||||
# create envs, sim and viewer
|
||||
self.create_sim()
|
||||
self.gym.prepare_sim(self.sim)
|
||||
|
||||
# allocate buffers
|
||||
self.obs_buf = torch.zeros(self.num_envs, self.num_obs, device=self.device, dtype=torch.float)
|
||||
self.rew_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.float)
|
||||
|
@ -78,12 +86,6 @@ class BaseTask():
|
|||
self.privileged_obs_buf = None
|
||||
# self.num_privileged_obs = self.num_obs
|
||||
|
||||
self.extras = {}
|
||||
|
||||
# create envs, sim and viewer
|
||||
self.create_sim()
|
||||
self.gym.prepare_sim(self.sim)
|
||||
|
||||
# todo: read from config
|
||||
self.enable_viewer_sync = True
|
||||
self.viewer = None
|
||||
|
@ -98,6 +100,13 @@ class BaseTask():
|
|||
self.gym.subscribe_viewer_keyboard_event(
|
||||
self.viewer, gymapi.KEY_V, "toggle_viewer_sync")
|
||||
|
||||
def start_webviewer(self, port= 5000):
|
||||
""" This method must be called after the env is fully initialized """
|
||||
print("Starting webviewer on port: ", port)
|
||||
print("env is passed as a parameter to the webviewer")
|
||||
self.webviewer = WebViewer(host= "127.0.0.1", port= port)
|
||||
self.webviewer.setup(self)
|
||||
|
||||
def get_observations(self):
|
||||
return self.obs_buf
|
||||
|
||||
|
@ -141,4 +150,9 @@ class BaseTask():
|
|||
if sync_frame_time:
|
||||
self.gym.sync_frame_time(self.sim)
|
||||
else:
|
||||
self.gym.poll_viewer_events(self.viewer)
|
||||
self.gym.poll_viewer_events(self.viewer)
|
||||
if hasattr(self, "webviewer"):
|
||||
self.webviewer.render(fetch_results=True,
|
||||
step_graphics=True,
|
||||
render_all_camera_sensors=True,
|
||||
wait_for_page_load=True)
|
File diff suppressed because it is too large
Load Diff
|
@ -129,6 +129,7 @@ class LeggedRobotCfg(BaseConfig):
|
|||
max_push_vel_xy = 1.
|
||||
max_push_vel_ang = 0.
|
||||
init_dof_pos_ratio_range = [0.5, 1.5]
|
||||
init_base_vel_range = [-1., 1.]
|
||||
|
||||
class rewards:
|
||||
class scales:
|
||||
|
@ -160,6 +161,7 @@ class LeggedRobotCfg(BaseConfig):
|
|||
class obs_scales:
|
||||
lin_vel = 2.0
|
||||
ang_vel = 0.25
|
||||
commands = [2., 2., 0.25] # matching lin_vel and ang_vel scales
|
||||
dof_pos = 1.0
|
||||
dof_vel = 0.05
|
||||
height_measurements = 5.0
|
||||
|
@ -182,6 +184,12 @@ class LeggedRobotCfg(BaseConfig):
|
|||
ref_env = 0
|
||||
pos = [10, 0, 6] # [m]
|
||||
lookat = [11., 5, 3.] # [m]
|
||||
stream_depth = False # for webviewer
|
||||
|
||||
draw_commands = True # for debugger
|
||||
class commands:
|
||||
color = [0.1, 0.8, 0.1] # rgb
|
||||
size = 0.5
|
||||
|
||||
class sim:
|
||||
dt = 0.005
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,14 +1,13 @@
|
|||
import random
|
||||
|
||||
from isaacgym.torch_utils import torch_rand_float, get_euler_xyz, quat_from_euler_xyz, tf_apply
|
||||
from isaacgym import gymtorch, gymapi, gymutil
|
||||
import numpy as np
|
||||
from isaacgym.torch_utils import torch_rand_float, tf_apply
|
||||
from isaacgym import gymutil, gymapi
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as T
|
||||
|
||||
from legged_gym.envs.base.legged_robot_field import LeggedRobotField
|
||||
|
||||
class LeggedRobotNoisy(LeggedRobotField):
|
||||
class LeggedRobotNoisyMixin:
|
||||
""" This class should be independent from the terrain, but depend on the sensors of the parent
|
||||
class.
|
||||
"""
|
||||
|
@ -16,22 +15,23 @@ class LeggedRobotNoisy(LeggedRobotField):
|
|||
def clip_position_action_by_torque_limit(self, actions_scaled):
|
||||
""" For position control, scaled actions should be in the coordinate of robot default dof pos
|
||||
"""
|
||||
if hasattr(self, "proprioception_output"):
|
||||
dof_vel = self.proprioception_output[:, -24:-12] / self.obs_scales.dof_vel
|
||||
dof_pos_ = self.proprioception_output[:, -36:-24] / self.obs_scales.dof_pos
|
||||
if hasattr(self, "dof_vel_obs_output_buffer"):
|
||||
dof_vel = self.dof_vel_obs_output_buffer / self.obs_scales.dof_vel
|
||||
else:
|
||||
dof_vel = self.dof_vel
|
||||
if hasattr(self, "dof_pos_obs_output_buffer"):
|
||||
dof_pos_ = self.dof_pos_obs_output_buffer / self.obs_scales.dof_pos
|
||||
else:
|
||||
dof_pos_ = self.dof_pos - self.default_dof_pos
|
||||
p_limits_low = (-self.torque_limits) + self.d_gains*dof_vel
|
||||
p_limits_high = (self.torque_limits) + self.d_gains*dof_vel
|
||||
actions_low = (p_limits_low/self.p_gains) + dof_pos_
|
||||
actions_high = (p_limits_high/self.p_gains) + dof_pos_
|
||||
actions_scaled_torque_clipped = torch.clip(actions_scaled, actions_low, actions_high)
|
||||
return actions_scaled_torque_clipped
|
||||
actions_scaled_clipped = torch.clip(actions_scaled, actions_low, actions_high)
|
||||
return actions_scaled_clipped
|
||||
|
||||
def pre_physics_step(self, actions):
|
||||
self.forward_depth_refreshed = False # incase _get_forward_depth_obs is called multiple times
|
||||
self.proprioception_refreshed = False
|
||||
self.set_buffers_refreshed_to_false()
|
||||
return_ = super().pre_physics_step(actions)
|
||||
|
||||
if isinstance(self.cfg.control.action_scale, (tuple, list)):
|
||||
|
@ -40,24 +40,24 @@ class LeggedRobotNoisy(LeggedRobotField):
|
|||
self.actions_scaled = self.actions * self.cfg.control.action_scale
|
||||
control_type = self.cfg.control.control_type
|
||||
if control_type == "P":
|
||||
actions_scaled_torque_clipped = self.clip_position_action_by_torque_limit(self.actions_scaled)
|
||||
actions_scaled_clipped = self.clip_position_action_by_torque_limit(self.actions_scaled)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
actions_scaled_torque_clipped = self.actions * self.cfg.control.action_scale
|
||||
actions_scaled_clipped = self.actions * self.cfg.control.action_scale
|
||||
|
||||
if getattr(self.cfg.control, "action_delay", False):
|
||||
# always put the latest action at the end of the buffer
|
||||
self.actions_history_buffer = torch.roll(self.actions_history_buffer, shifts= -1, dims= 0)
|
||||
self.actions_history_buffer[-1] = actions_scaled_torque_clipped
|
||||
self.actions_history_buffer[-1] = actions_scaled_clipped
|
||||
# get the delayed action
|
||||
self.action_delayed_frames = ((self.current_action_delay / self.dt) + 1).to(int)
|
||||
self.actions_scaled_torque_clipped = self.actions_history_buffer[
|
||||
-self.action_delayed_frames,
|
||||
action_delayed_frames = ((self.action_delay_buffer / self.dt) + 1).to(int)
|
||||
self.actions_scaled_clipped = self.actions_history_buffer[
|
||||
-action_delayed_frames,
|
||||
torch.arange(self.num_envs, device= self.device),
|
||||
]
|
||||
else:
|
||||
self.actions_scaled_torque_clipped = actions_scaled_torque_clipped
|
||||
self.actions_scaled_clipped = actions_scaled_clipped
|
||||
|
||||
return return_
|
||||
|
||||
|
@ -69,12 +69,12 @@ class LeggedRobotNoisy(LeggedRobotField):
|
|||
return super()._compute_torques(actions)
|
||||
else:
|
||||
if hasattr(self, "motor_strength"):
|
||||
actions_scaled_torque_clipped = self.motor_strength * self.actions_scaled_torque_clipped
|
||||
actions_scaled_clipped = self.motor_strength * self.actions_scaled_clipped
|
||||
else:
|
||||
actions_scaled_torque_clipped = self.actions_scaled_torque_clipped
|
||||
actions_scaled_clipped = self.actions_scaled_clipped
|
||||
control_type = self.cfg.control.control_type
|
||||
if control_type == "P":
|
||||
torques = self.p_gains * (actions_scaled_torque_clipped + self.default_dof_pos - self.dof_pos) \
|
||||
torques = self.p_gains * (actions_scaled_clipped + self.default_dof_pos - self.dof_pos) \
|
||||
- self.d_gains * self.dof_vel
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
@ -93,9 +93,7 @@ class LeggedRobotNoisy(LeggedRobotField):
|
|||
self.max_torques,
|
||||
)
|
||||
### The set torque limit is usally smaller than the robot dataset
|
||||
self.torque_exceed_count_substep[(torch.abs(self.torques) > self.torque_limits).any(dim= -1)] += 1
|
||||
### Hack to check the torque limit exceeding by your own value.
|
||||
# self.torque_exceed_count_envstep[(torch.abs(self.torques) > 38.).any(dim= -1)] += 1
|
||||
self.torque_exceed_count_substep[(torch.abs(self.torques) > self.torque_limits * self.cfg.rewards.soft_torque_limit).any(dim= -1)] += 1
|
||||
|
||||
### count how many times in the episode the robot is out of dof pos limit (summing all dofs)
|
||||
self.out_of_dof_pos_limit_count_substep += self._reward_dof_pos_limits().int()
|
||||
|
@ -130,43 +128,11 @@ class LeggedRobotNoisy(LeggedRobotField):
|
|||
if len(resample_env_ids) > 0:
|
||||
self._resample_action_delay(resample_env_ids)
|
||||
|
||||
if hasattr(self, "proprioception_buffer"):
|
||||
resampling_time = getattr(self.cfg.sensor.proprioception, "latency_resampling_time", self.dt)
|
||||
resample_env_ids = (self.episode_length_buf % int(resampling_time / self.dt) == 0).nonzero(as_tuple= False).flatten()
|
||||
if len(resample_env_ids) > 0:
|
||||
self._resample_proprioception_latency(resample_env_ids)
|
||||
|
||||
if hasattr(self, "forward_depth_buffer"):
|
||||
resampling_time = getattr(self.cfg.sensor.forward_camera, "latency_resampling_time", self.dt)
|
||||
resample_env_ids = (self.episode_length_buf % int(resampling_time / self.dt) == 0).nonzero(as_tuple= False).flatten()
|
||||
if len(resample_env_ids) > 0:
|
||||
self._resample_forward_camera_latency(resample_env_ids)
|
||||
for sensor_name in self.available_sensors:
|
||||
if hasattr(self, sensor_name + "_latency_buffer"):
|
||||
self._resample_sensor_latency_if_needed(sensor_name)
|
||||
|
||||
self.torque_exceed_count_envstep[(torch.abs(self.substep_torques) > self.torque_limits).any(dim= 1).any(dim= 1)] += 1
|
||||
|
||||
def _resample_action_delay(self, env_ids):
|
||||
self.current_action_delay[env_ids] = torch_rand_float(
|
||||
self.cfg.control.action_delay_range[0],
|
||||
self.cfg.control.action_delay_range[1],
|
||||
(len(env_ids), 1),
|
||||
device= self.device,
|
||||
).flatten()
|
||||
|
||||
def _resample_proprioception_latency(self, env_ids):
|
||||
self.current_proprioception_latency[env_ids] = torch_rand_float(
|
||||
self.cfg.sensor.proprioception.latency_range[0],
|
||||
self.cfg.sensor.proprioception.latency_range[1],
|
||||
(len(env_ids), 1),
|
||||
device= self.device,
|
||||
).flatten()
|
||||
|
||||
def _resample_forward_camera_latency(self, env_ids):
|
||||
self.current_forward_camera_latency[env_ids] = torch_rand_float(
|
||||
self.cfg.sensor.forward_camera.latency_range[0],
|
||||
self.cfg.sensor.forward_camera.latency_range[1],
|
||||
(len(env_ids), 1),
|
||||
device= self.device,
|
||||
).flatten()
|
||||
self.torque_exceed_count_envstep[(torch.abs(self.substep_torques) > self.torque_limits * self.cfg.rewards.soft_torque_limit).any(dim= 1).any(dim= 1)] += 1
|
||||
|
||||
def _init_buffers(self):
|
||||
return_ = super()._init_buffers()
|
||||
|
@ -174,9 +140,29 @@ class LeggedRobotNoisy(LeggedRobotField):
|
|||
|
||||
if getattr(self.cfg.control, "action_delay", False):
|
||||
assert hasattr(self.cfg.control, "action_delay_range") and hasattr(self.cfg.control, "action_delay_resample_time"), "Please specify action_delay_range and action_delay_resample_time in the config file."
|
||||
""" Used in pre-physics step """
|
||||
self.cfg.control.action_history_buffer_length = int((self.cfg.control.action_delay_range[1] + self.dt) / self.dt)
|
||||
self.actions_history_buffer = torch.zeros(
|
||||
self.build_action_delay_buffer()
|
||||
|
||||
self.component_governed_by_sensor = dict()
|
||||
for sensor_name in self.available_sensors:
|
||||
if hasattr(self.cfg.sensor, sensor_name):
|
||||
self.set_latency_buffer_for_sensor(sensor_name)
|
||||
for component in getattr(self.cfg.sensor, sensor_name).obs_components:
|
||||
assert not hasattr(self, component + "_obs_buffer"), "The obs component {} already has a buffer and corresponding sensor. Should not be governed also by {}".format(component, sensor_name)
|
||||
self.set_obs_buffers_for_component(component, sensor_name)
|
||||
if component == "forward_depth":
|
||||
self.build_depth_image_processor_buffers(sensor_name)
|
||||
|
||||
self.max_torques = torch.zeros_like(self.torques[..., 0])
|
||||
self.torque_exceed_count_substep = torch.zeros_like(self.torques[..., 0], dtype= torch.int32) # The number of substeps that the torque exceeds the limit
|
||||
self.torque_exceed_count_envstep = torch.zeros_like(self.torques[..., 0], dtype= torch.int32) # The number of envsteps that the torque exceeds the limit
|
||||
self.out_of_dof_pos_limit_count_substep = torch.zeros_like(self.torques[..., 0], dtype= torch.int32) # The number of substeps that the dof pos exceeds the limit
|
||||
|
||||
return return_
|
||||
|
||||
def build_action_delay_buffer(self):
|
||||
""" Used in pre-physics step """
|
||||
self.cfg.control.action_history_buffer_length = int((self.cfg.control.action_delay_range[1] + self.dt) / self.dt)
|
||||
self.actions_history_buffer = torch.zeros(
|
||||
(
|
||||
self.cfg.control.action_history_buffer_length,
|
||||
self.num_envs,
|
||||
|
@ -185,66 +171,27 @@ class LeggedRobotNoisy(LeggedRobotField):
|
|||
dtype= torch.float32,
|
||||
device= self.device,
|
||||
)
|
||||
self.current_action_delay = torch_rand_float(
|
||||
self.action_delay_buffer = torch_rand_float(
|
||||
self.cfg.control.action_delay_range[0],
|
||||
self.cfg.control.action_delay_range[1],
|
||||
(self.num_envs, 1),
|
||||
device= self.device,
|
||||
).flatten()
|
||||
self.action_delayed_frames = ((self.current_action_delay / self.dt) + 1).to(int)
|
||||
self.action_delayed_frames = ((self.action_delay_buffer / self.dt) + 1).to(int)
|
||||
|
||||
if "proprioception" in all_obs_components and hasattr(self.cfg.sensor, "proprioception"):
|
||||
""" Adding proprioception delay buffer """
|
||||
self.cfg.sensor.proprioception.buffer_length = int((self.cfg.sensor.proprioception.latency_range[1] + self.dt) / self.dt)
|
||||
self.proprioception_buffer = torch.zeros(
|
||||
(
|
||||
self.cfg.sensor.proprioception.buffer_length,
|
||||
self.num_envs,
|
||||
self.get_num_obs_from_components(["proprioception"]),
|
||||
),
|
||||
dtype= torch.float32,
|
||||
device= self.device,
|
||||
)
|
||||
self.current_proprioception_latency = torch_rand_float(
|
||||
self.cfg.sensor.proprioception.latency_range[0],
|
||||
self.cfg.sensor.proprioception.latency_range[1],
|
||||
(self.num_envs, 1),
|
||||
device= self.device,
|
||||
).flatten()
|
||||
self.proprioception_delayed_frames = ((self.current_proprioception_latency / self.dt) + 1).to(int)
|
||||
|
||||
if "forward_depth" in all_obs_components and hasattr(self.cfg.sensor, "forward_camera"):
|
||||
output_resolution = getattr(self.cfg.sensor.forward_camera, "output_resolution", self.cfg.sensor.forward_camera.resolution)
|
||||
self.cfg.sensor.forward_camera.buffer_length = int((self.cfg.sensor.forward_camera.latency_range[1] + self.cfg.sensor.forward_camera.refresh_duration) / self.dt)
|
||||
self.forward_depth_buffer = torch.zeros(
|
||||
(
|
||||
self.cfg.sensor.forward_camera.buffer_length,
|
||||
self.num_envs,
|
||||
1,
|
||||
output_resolution[0],
|
||||
output_resolution[1],
|
||||
),
|
||||
dtype= torch.float32,
|
||||
device= self.device,
|
||||
)
|
||||
self.forward_depth_delayed_frames = torch.ones((self.num_envs,), device= self.device, dtype= int) * self.cfg.sensor.forward_camera.buffer_length
|
||||
self.current_forward_camera_latency = torch_rand_float(
|
||||
self.cfg.sensor.forward_camera.latency_range[0],
|
||||
self.cfg.sensor.forward_camera.latency_range[1],
|
||||
(self.num_envs, 1),
|
||||
device= self.device,
|
||||
).flatten()
|
||||
if hasattr(self.cfg.sensor.forward_camera, "resized_resolution"):
|
||||
self.forward_depth_resize_transform = T.Resize(
|
||||
def build_depth_image_processor_buffers(self, sensor_name):
|
||||
assert sensor_name == "forward_camera", "Only forward_camera is supported for now."
|
||||
if hasattr(getattr(self.cfg.sensor, sensor_name), "resized_resolution"):
|
||||
self.forward_depth_resize_transform = T.Resize(
|
||||
self.cfg.sensor.forward_camera.resized_resolution,
|
||||
interpolation= T.InterpolationMode.BICUBIC,
|
||||
)
|
||||
self.contour_detection_kernel = torch.zeros(
|
||||
(8, 1, 3, 3),
|
||||
dtype= torch.float32,
|
||||
device= self.device,
|
||||
)
|
||||
# emperical values to be more sensitive to vertical edges
|
||||
(8, 1, 3, 3),
|
||||
dtype= torch.float32,
|
||||
device= self.device,
|
||||
)
|
||||
# emperical values to be more sensitive to vertical edges
|
||||
self.contour_detection_kernel[0, :, 1, 1] = 0.5
|
||||
self.contour_detection_kernel[0, :, 0, 0] = -0.5
|
||||
self.contour_detection_kernel[1, :, 1, 1] = 0.1
|
||||
|
@ -261,40 +208,176 @@ class LeggedRobotNoisy(LeggedRobotField):
|
|||
self.contour_detection_kernel[6, :, 2, 1] = -0.1
|
||||
self.contour_detection_kernel[7, :, 1, 1] = 0.5
|
||||
self.contour_detection_kernel[7, :, 2, 2] = -0.5
|
||||
|
||||
""" Considering sensors are not necessarily matching observation component,
|
||||
we need to set the buffers for each obs component and latency buffer for each sensor.
|
||||
"""
|
||||
def set_obs_buffers_for_component(self, component, sensor_name):
|
||||
buffer_length = int(getattr(self.cfg.sensor, sensor_name).latency_range[1] / self.dt) + 1
|
||||
# use super().get_obs_segment_from_components() to get the obs shape to prevent post processing
|
||||
# overrides the buffer shape
|
||||
obs_buffer = torch.zeros(
|
||||
(
|
||||
buffer_length,
|
||||
self.num_envs,
|
||||
*(self.get_obs_segment_from_components([component])[component]), # tuple(obs_shape)
|
||||
),
|
||||
dtype= torch.float32,
|
||||
device= self.device,
|
||||
)
|
||||
setattr(self, component + "_obs_buffer", obs_buffer)
|
||||
setattr(self, component + "_obs_refreshed", False)
|
||||
self.component_governed_by_sensor[component] = sensor_name
|
||||
|
||||
self.max_torques = torch.zeros_like(self.torques[..., 0])
|
||||
self.torque_exceed_count_substep = torch.zeros_like(self.torques[..., 0], dtype= torch.int32) # The number of substeps that the torque exceeds the limit
|
||||
self.torque_exceed_count_envstep = torch.zeros_like(self.torques[..., 0], dtype= torch.int32) # The number of envsteps that the torque exceeds the limit
|
||||
self.out_of_dof_pos_limit_count_substep = torch.zeros_like(self.torques[..., 0], dtype= torch.int32) # The number of substeps that the dof pos exceeds the limit
|
||||
def set_latency_buffer_for_sensor(self, sensor_name):
|
||||
latency_buffer = torch_rand_float(
|
||||
getattr(self.cfg.sensor, sensor_name).latency_range[0],
|
||||
getattr(self.cfg.sensor, sensor_name).latency_range[1],
|
||||
(self.num_envs, 1),
|
||||
device= self.device,
|
||||
).flatten()
|
||||
# using setattr to set the buffer
|
||||
setattr(self, sensor_name + "_latency_buffer", latency_buffer)
|
||||
if "camera" in sensor_name:
|
||||
setattr(
|
||||
self,
|
||||
sensor_name + "_delayed_frames",
|
||||
torch.zeros_like(latency_buffer, dtype= torch.long, device= self.device),
|
||||
)
|
||||
|
||||
def _resample_sensor_latency_if_needed(self, sensor_name):
|
||||
resampling_time = getattr(getattr(self.cfg.sensor, sensor_name), "latency_resampling_time", self.dt)
|
||||
resample_env_ids = (self.episode_length_buf % int(resampling_time / self.dt) == 0).nonzero(as_tuple= False).flatten()
|
||||
if len(resample_env_ids) > 0:
|
||||
getattr(self, sensor_name + "_latency_buffer")[resample_env_ids] = torch_rand_float(
|
||||
getattr(getattr(self.cfg.sensor, sensor_name), "latency_range")[0],
|
||||
getattr(getattr(self.cfg.sensor, sensor_name), "latency_range")[1],
|
||||
(len(resample_env_ids), 1),
|
||||
device= self.device,
|
||||
).flatten()
|
||||
|
||||
return return_
|
||||
def _resample_action_delay(self, env_ids):
|
||||
self.action_delay_buffer[env_ids] = torch_rand_float(
|
||||
self.cfg.control.action_delay_range[0],
|
||||
self.cfg.control.action_delay_range[1],
|
||||
(len(env_ids), 1),
|
||||
device= self.device,
|
||||
).flatten()
|
||||
|
||||
def _reset_buffers(self, env_ids):
|
||||
return_ = super()._reset_buffers(env_ids)
|
||||
if hasattr(self, "actions_history_buffer"):
|
||||
self.actions_history_buffer[:, env_ids] = 0.
|
||||
self.action_delayed_frames[env_ids] = self.cfg.control.action_history_buffer_length
|
||||
if hasattr(self, "forward_depth_buffer"):
|
||||
self.forward_depth_buffer[:, env_ids] = 0.
|
||||
self.forward_depth_delayed_frames[env_ids] = self.cfg.sensor.forward_camera.buffer_length
|
||||
if hasattr(self, "proprioception_buffer"):
|
||||
self.proprioception_buffer[:, env_ids] = 0.
|
||||
self.proprioception_delayed_frames[env_ids] = self.cfg.sensor.proprioception.buffer_length
|
||||
for sensor_name in self.available_sensors:
|
||||
if not hasattr(self.cfg.sensor, sensor_name):
|
||||
continue
|
||||
for component in getattr(self.cfg.sensor, sensor_name).obs_components:
|
||||
if hasattr(self, component + "_obs_buffer"):
|
||||
getattr(self, component + "_obs_buffer")[:, env_ids] = 0.
|
||||
setattr(self, component + "_obs_refreshed", False)
|
||||
if "camera" in sensor_name:
|
||||
getattr(self, sensor_name + "_delayed_frames")[env_ids] = 0
|
||||
return return_
|
||||
|
||||
def _draw_debug_vis(self):
|
||||
return_ = super()._draw_debug_vis()
|
||||
def _build_forward_depth_intrinsic(self):
|
||||
sim_raw_resolution = self.cfg.sensor.forward_camera.resolution
|
||||
sim_cropping_h = self.cfg.sensor.forward_camera.crop_top_bottom
|
||||
sim_cropping_w = self.cfg.sensor.forward_camera.crop_left_right
|
||||
cropped_resolution = [ # (H, W)
|
||||
sim_raw_resolution[0] - sum(sim_cropping_h),
|
||||
sim_raw_resolution[1] - sum(sim_cropping_w),
|
||||
]
|
||||
network_input_resolution = self.cfg.sensor.forward_camera.output_resolution
|
||||
x_fov = torch.mean(torch.tensor(self.cfg.sensor.forward_camera.horizontal_fov).float()) \
|
||||
/ 180 * np.pi
|
||||
fx = (sim_raw_resolution[1]) / (2 * torch.tan(x_fov / 2))
|
||||
fy = fx # * (sim_raw_resolution[0] / sim_raw_resolution[1])
|
||||
fx = fx * network_input_resolution[1] / cropped_resolution[1]
|
||||
fy = fy * network_input_resolution[0] / cropped_resolution[0]
|
||||
cx = (sim_raw_resolution[1] / 2) - sim_cropping_w[0]
|
||||
cy = (sim_raw_resolution[0] / 2) - sim_cropping_h[0]
|
||||
cx = cx * network_input_resolution[1] / cropped_resolution[1]
|
||||
cy = cy * network_input_resolution[0] / cropped_resolution[0]
|
||||
self.forward_depth_intrinsic = torch.tensor([
|
||||
[fx, 0., cx],
|
||||
[0., fy, cy],
|
||||
[0., 0., 1.],
|
||||
], device= self.device)
|
||||
x_arr = torch.linspace(
|
||||
0,
|
||||
network_input_resolution[1] - 1,
|
||||
network_input_resolution[1],
|
||||
)
|
||||
y_arr = torch.linspace(
|
||||
0,
|
||||
network_input_resolution[0] - 1,
|
||||
network_input_resolution[0],
|
||||
)
|
||||
x_mesh, y_mesh = torch.meshgrid(x_arr, y_arr, indexing= "xy")
|
||||
# (H, W, 2) -> (H * W, 3) row wise
|
||||
self.forward_depth_pixel_mesh = torch.stack([
|
||||
x_mesh,
|
||||
y_mesh,
|
||||
torch.ones_like(x_mesh),
|
||||
], dim= -1).view(-1, 3).float().to(self.device)
|
||||
|
||||
def _draw_pointcloud_from_depth_image(self,
|
||||
env_h, sensor_h,
|
||||
camera_intrinsic,
|
||||
depth_image, # torch tensor (H, W)
|
||||
offset = [-2, 0, 1],
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
offset: drawing points directly based on the sensor will interfere with the
|
||||
depth image rendering. so we need to provide the offset to the pointcloud
|
||||
"""
|
||||
assert self.num_envs == 1, "LeggedRobotNoisy: Only implemented when num_envs == 1"
|
||||
camera_transform = self.gym.get_camera_transform(self.sim, env_h, sensor_h)
|
||||
camera_intrinsic_inv = torch.inverse(camera_intrinsic)
|
||||
|
||||
# (H * W, 3) -> (3, H * W) -> (H * W, 3)
|
||||
depth_image = depth_image * (self.cfg.sensor.forward_camera.depth_range[1] - self.cfg.sensor.forward_camera.depth_range[0]) + self.cfg.sensor.forward_camera.depth_range[0]
|
||||
points = camera_intrinsic_inv @ self.forward_depth_pixel_mesh.T * depth_image.view(-1)
|
||||
points = points.T
|
||||
|
||||
sphere_geom = gymutil.WireframeSphereGeometry(0.008, 8, 8, None, color= (0.9, 0.1, 0.9))
|
||||
for p in points:
|
||||
sphere_pose = gymapi.Transform(
|
||||
p= camera_transform.transform_point(gymapi.Vec3(
|
||||
p[2] + offset[0],
|
||||
-p[0] + offset[1],
|
||||
-p[1] + offset[2],
|
||||
)), # +z forward to +x forward
|
||||
r= None,
|
||||
)
|
||||
gymutil.draw_lines(sphere_geom, self.gym, self.viewer, env_h, sphere_pose)
|
||||
|
||||
def _draw_sensor_reading_vis(self, env_h, sensor_hd):
|
||||
return_ = super()._draw_sensor_reading_vis(env_h, sensor_hd)
|
||||
if hasattr(self, "forward_depth_output"):
|
||||
if self.num_envs == 1:
|
||||
import matplotlib.pyplot as plt
|
||||
forward_depth_np = self.forward_depth_output[0, 0].detach().cpu().numpy() # (H, W)
|
||||
plt.imshow(forward_depth_np, cmap= "gray", vmin= 0, vmax= 1)
|
||||
plt.pause(0.001)
|
||||
if getattr(self.cfg.viewer, "forward_depth_as_pointcloud", False):
|
||||
if not hasattr(self, "forward_depth_intrinsic"):
|
||||
self._build_forward_depth_intrinsic()
|
||||
for sensor_name, sensor_h in sensor_hd.items():
|
||||
if "forward_camera" in sensor_name:
|
||||
self._draw_pointcloud_from_depth_image(
|
||||
env_h, sensor_h,
|
||||
self.forward_depth_intrinsic,
|
||||
self.forward_depth_output[0, 0].detach(),
|
||||
)
|
||||
else:
|
||||
import matplotlib.pyplot as plt
|
||||
forward_depth_np = self.forward_depth_output[0, 0].detach().cpu().numpy() # (H, W)
|
||||
plt.imshow(forward_depth_np, cmap= "gray", vmin= 0, vmax= 1)
|
||||
plt.pause(0.001)
|
||||
else:
|
||||
print("LeggedRobotNoisy: More than one robot, stop showing camera image")
|
||||
return return_
|
||||
|
||||
""" Steps to simulate stereo camera depth image """
|
||||
""" ########## Steps to simulate stereo camera depth image ########## """
|
||||
def _add_depth_contour(self, depth_images):
|
||||
mask = F.max_pool2d(
|
||||
torch.abs(F.conv2d(depth_images, self.contour_detection_kernel, padding= 1)).max(dim= -3, keepdim= True)[0],
|
||||
|
@ -545,32 +628,133 @@ class LeggedRobotNoisy(LeggedRobotField):
|
|||
depth_images_ = self._crop_depth_images(depth_images_)
|
||||
if hasattr(self, "forward_depth_resize_transform"):
|
||||
depth_images_ = self.forward_depth_resize_transform(depth_images_)
|
||||
depth_images_ = depth_images_.clip(0, 1)
|
||||
return depth_images_.unsqueeze(0) # (1, N, 1, H, W)
|
||||
|
||||
""" ########## Override the observation functions to add latencies and artifacts ########## """
|
||||
def set_buffers_refreshed_to_false(self):
|
||||
for sensor_name in self.available_sensors:
|
||||
if hasattr(self.cfg.sensor, sensor_name):
|
||||
for component in getattr(self.cfg.sensor, sensor_name).obs_components:
|
||||
setattr(self, component + "_obs_refreshed", False)
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_ang_vel_obs(self, privileged= False):
|
||||
if hasattr(self, "ang_vel_obs_buffer") and (not self.ang_vel_obs_refreshed) and (not privileged):
|
||||
self.ang_vel_obs_buffer = torch.cat([
|
||||
self.ang_vel_obs_buffer[1:],
|
||||
super()._get_ang_vel_obs().unsqueeze(0),
|
||||
], dim= 0)
|
||||
component_governed_by = self.component_governed_by_sensor["ang_vel"]
|
||||
buffer_delayed_frames = ((getattr(self, component_governed_by + "_latency_buffer") / self.dt) + 1).to(int)
|
||||
self.ang_vel_obs_output_buffer = self.ang_vel_obs_buffer[
|
||||
-buffer_delayed_frames,
|
||||
torch.arange(self.num_envs, device= self.device),
|
||||
].clone()
|
||||
self.ang_vel_obs_refreshed = True
|
||||
if privileged or not hasattr(self, "ang_vel_obs_buffer"):
|
||||
return super()._get_ang_vel_obs(privileged)
|
||||
return self.ang_vel_obs_output_buffer
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_projected_gravity_obs(self, privileged= False):
|
||||
if hasattr(self, "projected_gravity_obs_buffer") and (not self.projected_gravity_obs_refreshed) and (not privileged):
|
||||
self.projected_gravity_obs_buffer = torch.cat([
|
||||
self.projected_gravity_obs_buffer[1:],
|
||||
super()._get_projected_gravity_obs().unsqueeze(0),
|
||||
], dim= 0)
|
||||
component_governed_by = self.component_governed_by_sensor["projected_gravity"]
|
||||
buffer_delayed_frames = ((getattr(self, component_governed_by + "_latency_buffer") / self.dt) + 1).to(int)
|
||||
self.projected_gravity_obs_output_buffer = self.projected_gravity_obs_buffer[
|
||||
-buffer_delayed_frames,
|
||||
torch.arange(self.num_envs, device= self.device),
|
||||
].clone()
|
||||
self.projected_gravity_obs_refreshed = True
|
||||
if privileged or not hasattr(self, "projected_gravity_obs_buffer"):
|
||||
return super()._get_projected_gravity_obs(privileged)
|
||||
return self.projected_gravity_obs_output_buffer
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_commands_obs(self, privileged= False):
|
||||
if hasattr(self, "commands_obs_buffer") and (not self.commands_obs_refreshed) and (not privileged):
|
||||
self.commands_obs_buffer = torch.cat([
|
||||
self.commands_obs_buffer[1:],
|
||||
super()._get_commands_obs().unsqueeze(0),
|
||||
], dim= 0)
|
||||
component_governed_by = self.component_governed_by_sensor["commands"]
|
||||
buffer_delayed_frames = ((getattr(self, component_governed_by + "_latency_buffer") / self.dt) + 1).to(int)
|
||||
self.commands_obs_output_buffer = self.commands_obs_buffer[
|
||||
-buffer_delayed_frames,
|
||||
torch.arange(self.num_envs, device= self.device),
|
||||
].clone()
|
||||
self.commands_obs_refreshed = True
|
||||
if privileged or not hasattr(self, "commands_obs_buffer"):
|
||||
return super()._get_commands_obs(privileged)
|
||||
return self.commands_obs_output_buffer
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_dof_pos_obs(self, privileged= False):
|
||||
if hasattr(self, "dof_pos_obs_buffer") and (not self.dof_pos_obs_refreshed) and (not privileged):
|
||||
self.dof_pos_obs_buffer = torch.cat([
|
||||
self.dof_pos_obs_buffer[1:],
|
||||
super()._get_dof_pos_obs().unsqueeze(0),
|
||||
], dim= 0)
|
||||
component_governed_by = self.component_governed_by_sensor["dof_pos"]
|
||||
buffer_delayed_frames = ((getattr(self, component_governed_by + "_latency_buffer") / self.dt) + 1).to(int)
|
||||
self.dof_pos_obs_output_buffer = self.dof_pos_obs_buffer[
|
||||
-buffer_delayed_frames,
|
||||
torch.arange(self.num_envs, device= self.device),
|
||||
].clone()
|
||||
self.dof_pos_obs_refreshed = True
|
||||
if privileged or not hasattr(self, "dof_pos_obs_buffer"):
|
||||
return super()._get_dof_pos_obs(privileged)
|
||||
return self.dof_pos_obs_output_buffer
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_dof_vel_obs(self, privileged= False):
|
||||
if hasattr(self, "dof_vel_obs_buffer") and (not self.dof_vel_obs_refreshed) and (not privileged):
|
||||
self.dof_vel_obs_buffer = torch.cat([
|
||||
self.dof_vel_obs_buffer[1:],
|
||||
super()._get_dof_vel_obs().unsqueeze(0),
|
||||
], dim= 0)
|
||||
component_governed_by = self.component_governed_by_sensor["dof_vel"]
|
||||
buffer_delayed_frames = ((getattr(self, component_governed_by + "_latency_buffer") / self.dt) + 1).to(int)
|
||||
self.dof_vel_obs_output_buffer = self.dof_vel_obs_buffer[
|
||||
-buffer_delayed_frames,
|
||||
torch.arange(self.num_envs, device= self.device),
|
||||
].clone()
|
||||
self.dof_vel_obs_refreshed = True
|
||||
if privileged or not hasattr(self, "dof_vel_obs_buffer"):
|
||||
return super()._get_dof_vel_obs(privileged)
|
||||
return self.dof_vel_obs_output_buffer
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_forward_depth_obs(self, privileged= False):
|
||||
if not self.forward_depth_refreshed and hasattr(self.cfg.sensor, "forward_camera") and (not privileged):
|
||||
self.forward_depth_buffer = torch.cat([
|
||||
self.forward_depth_buffer[1:],
|
||||
if hasattr(self, "forward_depth_obs_buffer") and (not self.forward_depth_obs_refreshed) and hasattr(self.cfg.sensor, "forward_camera") and (not privileged):
|
||||
# TODO: any variables named with "forward_camera" here should be rearanged
|
||||
# in a individual class method.
|
||||
self.forward_depth_obs_buffer = torch.cat([
|
||||
self.forward_depth_obs_buffer[1:],
|
||||
self._process_depth_image(self.sensor_tensor_dict["forward_depth"]),
|
||||
], dim= 0)
|
||||
delay_refresh_mask = (self.episode_length_buf % int(self.cfg.sensor.forward_camera.refresh_duration / self.dt)) == 0
|
||||
# NOTE: if the delayed frames is greater than the last frame, the last image should be used.
|
||||
frame_select = (self.current_forward_camera_latency / self.dt).to(int)
|
||||
self.forward_depth_delayed_frames = torch.where(
|
||||
frame_select = (self.forward_camera_latency_buffer / self.dt).to(int)
|
||||
self.forward_camera_delayed_frames = torch.where(
|
||||
delay_refresh_mask,
|
||||
torch.minimum(
|
||||
frame_select,
|
||||
self.forward_depth_delayed_frames + 1,
|
||||
self.forward_camera_delayed_frames + 1,
|
||||
),
|
||||
self.forward_depth_delayed_frames + 1,
|
||||
self.forward_camera_delayed_frames + 1,
|
||||
)
|
||||
self.forward_depth_delayed_frames = torch.clip(
|
||||
self.forward_depth_delayed_frames,
|
||||
self.forward_camera_delayed_frames = torch.clip(
|
||||
self.forward_camera_delayed_frames,
|
||||
0,
|
||||
self.cfg.sensor.forward_camera.buffer_length,
|
||||
self.forward_depth_obs_buffer.shape[0],
|
||||
)
|
||||
self.forward_depth_output = self.forward_depth_buffer[
|
||||
-self.forward_depth_delayed_frames,
|
||||
self.forward_depth_output = self.forward_depth_obs_buffer[
|
||||
-self.forward_camera_delayed_frames,
|
||||
torch.arange(self.num_envs, device= self.device),
|
||||
].clone()
|
||||
self.forward_depth_refreshed = True
|
||||
|
@ -579,29 +763,6 @@ class LeggedRobotNoisy(LeggedRobotField):
|
|||
|
||||
return self.forward_depth_output.flatten(start_dim= 1)
|
||||
|
||||
def _get_proprioception_obs(self, privileged= False):
|
||||
if not self.proprioception_refreshed and hasattr(self.cfg.sensor, "proprioception") and (not privileged):
|
||||
self.proprioception_buffer = torch.cat([
|
||||
self.proprioception_buffer[1:],
|
||||
super()._get_proprioception_obs().unsqueeze(0),
|
||||
], dim= 0)
|
||||
# NOTE: if the delayed frames is greater than the last frame, the last image should be used. [0.04-0.0075, 0.04+0.0025]
|
||||
self.proprioception_delayed_frames = ((self.current_proprioception_latency / self.dt) + 1).to(int)
|
||||
self.proprioception_output = self.proprioception_buffer[
|
||||
-self.proprioception_delayed_frames,
|
||||
torch.arange(self.num_envs, device= self.device),
|
||||
].clone()
|
||||
### NOTE: WARN: ERROR: remove this code in final version, no action delay should be used.
|
||||
if getattr(self.cfg.sensor.proprioception, "delay_action_obs", False) or getattr(self.cfg.sensor.proprioception, "delay_privileged_action_obs", False):
|
||||
raise ValueError("LeggedRobotNoisy: No action delay should be used. Please remove these settings")
|
||||
# The last-action is not delayed.
|
||||
self.proprioception_output[:, -12:] = self.proprioception_buffer[-1, :, -12:]
|
||||
self.proprioception_refreshed = True
|
||||
if not hasattr(self.cfg.sensor, "proprioception") or privileged:
|
||||
return super()._get_proprioception_obs(privileged)
|
||||
|
||||
return self.proprioception_output.flatten(start_dim= 1)
|
||||
|
||||
def get_obs_segment_from_components(self, obs_components):
|
||||
obs_segments = super().get_obs_segment_from_components(obs_components)
|
||||
if "forward_depth" in obs_components:
|
||||
|
@ -611,27 +772,3 @@ class LeggedRobotNoisy(LeggedRobotField):
|
|||
self.cfg.sensor.forward_camera.resolution,
|
||||
))
|
||||
return obs_segments
|
||||
|
||||
def _reward_exceed_torque_limits_i(self):
|
||||
""" Indicator function """
|
||||
max_torques = torch.abs(self.substep_torques).max(dim= 1)[0]
|
||||
exceed_torque_each_dof = max_torques > self.torque_limits
|
||||
exceed_torque = exceed_torque_each_dof.any(dim= 1)
|
||||
return exceed_torque.to(torch.float32)
|
||||
|
||||
def _reward_exceed_torque_limits_square(self):
|
||||
""" square function for exceeding part """
|
||||
exceeded_torques = torch.abs(self.substep_torques) - self.torque_limits
|
||||
exceeded_torques[exceeded_torques < 0.] = 0.
|
||||
# sum along decimation axis and dof axis
|
||||
return torch.square(exceeded_torques).sum(dim= 1).sum(dim= 1)
|
||||
|
||||
def _reward_exceed_torque_limits_l1norm(self):
|
||||
""" square function for exceeding part """
|
||||
exceeded_torques = torch.abs(self.substep_torques) - self.torque_limits
|
||||
exceeded_torques[exceeded_torques < 0.] = 0.
|
||||
# sum along decimation axis and dof axis
|
||||
return torch.norm(exceeded_torques, p= 1, dim= -1).sum(dim= 1)
|
||||
|
||||
def _reward_exceed_dof_pos_limits(self):
|
||||
return self.substep_exceed_dof_pos_limits.to(torch.float32).sum(dim= -1).mean(dim= -1)
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
|
||||
from legged_gym.envs.base.legged_robot_field import LeggedRobotField
|
||||
from legged_gym.envs.base.legged_robot_noisy import LeggedRobotNoisyMixin
|
||||
|
||||
class RobotFieldNoisy(LeggedRobotNoisyMixin, LeggedRobotField):
|
||||
""" Using inheritance to combine the two classes.
|
||||
Then, LeggedRobotNoisyMixin and LeggedRobot can also be used elsewhere.
|
||||
"""
|
||||
pass
|
|
@ -0,0 +1,280 @@
|
|||
""" Basic model configs for Unitree Go2 """
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
|
||||
from legged_gym.envs.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO
|
||||
|
||||
go2_action_scale = 0.5
|
||||
go2_const_dof_range = dict(
|
||||
Hip_max= 1.0472,
|
||||
Hip_min= -1.0472,
|
||||
Front_Thigh_max= 3.4907,
|
||||
Front_Thigh_min= -1.5708,
|
||||
Rear_Thingh_max= 4.5379,
|
||||
Rear_Thingh_min= -0.5236,
|
||||
Calf_max= -0.83776,
|
||||
Calf_min= -2.7227,
|
||||
)
|
||||
|
||||
class Go2RoughCfg( LeggedRobotCfg ):
|
||||
class env:
|
||||
num_envs = 4096
|
||||
num_observations = None # No use, use obs_components
|
||||
num_privileged_obs = None # No use, use privileged_obs_components
|
||||
|
||||
use_lin_vel = False # to be decided
|
||||
num_actions = 12
|
||||
send_timeouts = True # send time out information to the algorithm
|
||||
episode_length_s = 20 # episode length in seconds
|
||||
|
||||
obs_components = [
|
||||
"lin_vel",
|
||||
"ang_vel",
|
||||
"projected_gravity",
|
||||
"commands",
|
||||
"dof_pos",
|
||||
"dof_vel",
|
||||
"last_actions",
|
||||
"height_measurements",
|
||||
]
|
||||
|
||||
class sensor:
|
||||
class proprioception:
|
||||
obs_components = ["ang_vel", "projected_gravity", "commands", "dof_pos", "dof_vel"]
|
||||
latency_range = [0.005, 0.045] # [s]
|
||||
latency_resample_time = 5.0 # [s]
|
||||
|
||||
class terrain:
|
||||
selected = "TerrainPerlin"
|
||||
mesh_type = None
|
||||
measure_heights = True
|
||||
# x: [-0.5, 1.5], y: [-0.5, 0.5] range for go2
|
||||
measured_points_x = [i for i in np.arange(-0.5, 1.51, 0.1)]
|
||||
measured_points_y = [i for i in np.arange(-0.5, 0.51, 0.1)]
|
||||
horizontal_scale = 0.025 # [m]
|
||||
vertical_scale = 0.005 # [m]
|
||||
border_size = 5 # [m]
|
||||
curriculum = False
|
||||
static_friction = 1.0
|
||||
dynamic_friction = 1.0
|
||||
restitution = 0.
|
||||
max_init_terrain_level = 5 # starting curriculum state
|
||||
terrain_length = 4.
|
||||
terrain_width = 4.
|
||||
num_rows= 16 # number of terrain rows (levels)
|
||||
num_cols = 16 # number of terrain cols (types)
|
||||
slope_treshold = 1.
|
||||
|
||||
TerrainPerlin_kwargs = dict(
|
||||
zScale= 0.07,
|
||||
frequency= 10,
|
||||
)
|
||||
|
||||
class commands( LeggedRobotCfg.commands ):
|
||||
heading_command = False
|
||||
resampling_time = 5 # [s]
|
||||
lin_cmd_cutoff = 0.2
|
||||
ang_cmd_cutoff = 0.2
|
||||
class ranges( LeggedRobotCfg.commands.ranges ):
|
||||
lin_vel_x = [-1.0, 1.5]
|
||||
lin_vel_y = [-1., 1.]
|
||||
ang_vel_yaw = [-2., 2.]
|
||||
|
||||
class init_state( LeggedRobotCfg.init_state ):
|
||||
pos = [0., 0., 0.5] # [m]
|
||||
default_joint_angles = { # 12 joints in the order of simulation
|
||||
"FL_hip_joint": 0.1,
|
||||
"FL_thigh_joint": 0.7,
|
||||
"FL_calf_joint": -1.5,
|
||||
"FR_hip_joint": -0.1,
|
||||
"FR_thigh_joint": 0.7,
|
||||
"FR_calf_joint": -1.5,
|
||||
"RL_hip_joint": 0.1,
|
||||
"RL_thigh_joint": 1.0,
|
||||
"RL_calf_joint": -1.5,
|
||||
"RR_hip_joint": -0.1,
|
||||
"RR_thigh_joint": 1.0,
|
||||
"RR_calf_joint": -1.5,
|
||||
}
|
||||
|
||||
class control( LeggedRobotCfg.control ):
|
||||
stiffness = {'joint': 40.}
|
||||
damping = {'joint': 1.}
|
||||
action_scale = go2_action_scale
|
||||
computer_clip_torque = False
|
||||
motor_clip_torque = True
|
||||
|
||||
class asset( LeggedRobotCfg.asset ):
|
||||
file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/go2/urdf/go2.urdf'
|
||||
name = "go2"
|
||||
foot_name = "foot"
|
||||
front_hip_names = ["FL_hip_joint", "FR_hip_joint"]
|
||||
rear_hip_names = ["RL_hip_joint", "RR_hip_joint"]
|
||||
penalize_contacts_on = ["thigh", "calf"]
|
||||
terminate_after_contacts_on = ["base"]
|
||||
sdk_dof_range = go2_const_dof_range
|
||||
dof_velocity_override = 35.
|
||||
|
||||
class termination:
|
||||
termination_terms = [
|
||||
"roll",
|
||||
"pitch",
|
||||
]
|
||||
|
||||
roll_kwargs = dict(
|
||||
threshold= 3.0, # [rad]
|
||||
)
|
||||
pitch_kwargs = dict(
|
||||
threshold= 3.0, # [rad] # for leap, jump
|
||||
)
|
||||
|
||||
class domain_rand( LeggedRobotCfg.domain_rand ):
|
||||
randomize_com = True
|
||||
class com_range:
|
||||
x = [-0.2, 0.2]
|
||||
y = [-0.1, 0.1]
|
||||
z = [-0.05, 0.05]
|
||||
|
||||
randomize_motor = True
|
||||
leg_motor_strength_range = [0.8, 1.2]
|
||||
|
||||
randomize_base_mass = True
|
||||
added_mass_range = [1.0, 3.0]
|
||||
|
||||
randomize_friction = True
|
||||
friction_range = [0., 2.]
|
||||
|
||||
init_base_pos_range = dict(
|
||||
x= [0.05, 0.6],
|
||||
y= [-0.25, 0.25],
|
||||
)
|
||||
init_base_rot_range = dict(
|
||||
roll= [-0.75, 0.75],
|
||||
pitch= [-0.75, 0.75],
|
||||
)
|
||||
init_base_vel_range = dict(
|
||||
x= [-0.2, 1.5],
|
||||
y= [-0.2, 0.2],
|
||||
z= [-0.2, 0.2],
|
||||
roll= [-1., 1.],
|
||||
pitch= [-1., 1.],
|
||||
yaw= [-1., 1.],
|
||||
)
|
||||
init_dof_vel_range = [-5, 5]
|
||||
|
||||
push_robots = True
|
||||
max_push_vel_xy = 0.5 # [m/s]
|
||||
push_interval_s = 2
|
||||
|
||||
class rewards( LeggedRobotCfg.rewards ):
|
||||
class scales:
|
||||
tracking_lin_vel = 1.
|
||||
tracking_ang_vel = 1.
|
||||
energy_substeps = -2e-5
|
||||
stand_still = -2.
|
||||
dof_error_named = -1.
|
||||
dof_error = -0.01
|
||||
# penalty for hardware safety
|
||||
exceed_dof_pos_limits = -0.4
|
||||
exceed_torque_limits_l1norm = -0.4
|
||||
dof_vel_limits = -0.4
|
||||
dof_error_names = ["FL_hip_joint", "FR_hip_joint", "RL_hip_joint", "RR_hip_joint"]
|
||||
only_positive_rewards = False
|
||||
soft_dof_vel_limit = 0.9
|
||||
soft_dof_pos_limit = 0.9
|
||||
soft_torque_limit = 0.9
|
||||
|
||||
class normalization( LeggedRobotCfg.normalization ):
|
||||
class obs_scales( LeggedRobotCfg.normalization.obs_scales ):
|
||||
lin_vel = 1.
|
||||
height_measurements_offset = -0.2
|
||||
clip_actions_method = None # let the policy learns not to exceed the limits
|
||||
|
||||
class noise( LeggedRobotCfg.noise ):
|
||||
add_noise = False
|
||||
|
||||
class viewer( LeggedRobotCfg.viewer ):
|
||||
pos = [-1., -1., 0.4]
|
||||
lookat = [0., 0., 0.3]
|
||||
|
||||
class sim( LeggedRobotCfg.sim ):
|
||||
body_measure_points = { # transform are related to body frame
|
||||
"base": dict(
|
||||
x= [i for i in np.arange(-0.24, 0.41, 0.03)],
|
||||
y= [-0.08, -0.04, 0.0, 0.04, 0.08],
|
||||
z= [i for i in np.arange(-0.061, 0.071, 0.03)],
|
||||
transform= [0., 0., 0.005, 0., 0., 0.],
|
||||
),
|
||||
"thigh": dict(
|
||||
x= [
|
||||
-0.16, -0.158, -0.156, -0.154, -0.152,
|
||||
-0.15, -0.145, -0.14, -0.135, -0.13, -0.125, -0.12, -0.115, -0.11, -0.105, -0.1, -0.095, -0.09, -0.085, -0.08, -0.075, -0.07, -0.065, -0.05,
|
||||
0.0, 0.05, 0.1,
|
||||
],
|
||||
y= [-0.015, -0.01, 0.0, -0.01, 0.015],
|
||||
z= [-0.03, -0.015, 0.0, 0.015],
|
||||
transform= [0., 0., -0.1, 0., 1.57079632679, 0.],
|
||||
),
|
||||
"calf": dict(
|
||||
x= [i for i in np.arange(-0.13, 0.111, 0.03)],
|
||||
y= [-0.015, 0.0, 0.015],
|
||||
z= [-0.015, 0.0, 0.015],
|
||||
transform= [0., 0., -0.11, 0., 1.57079632679, 0.],
|
||||
),
|
||||
}
|
||||
|
||||
logs_root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))), "logs")
|
||||
class Go2RoughCfgPPO( LeggedRobotCfgPPO ):
|
||||
class algorithm( LeggedRobotCfgPPO.algorithm ):
|
||||
entropy_coef = 0.01
|
||||
clip_min_std = 0.2
|
||||
learning_rate = 1e-4
|
||||
optimizer_class_name = "AdamW"
|
||||
|
||||
class policy( LeggedRobotCfgPPO.policy ):
|
||||
# configs for estimator module
|
||||
estimator_obs_components = [
|
||||
"ang_vel",
|
||||
"projected_gravity",
|
||||
"commands",
|
||||
"dof_pos",
|
||||
"dof_vel",
|
||||
"last_actions",
|
||||
]
|
||||
estimator_target_components = ["lin_vel"]
|
||||
replace_state_prob = 1.0
|
||||
class estimator_kwargs:
|
||||
hidden_sizes = [128, 64]
|
||||
nonlinearity = "CELU"
|
||||
# configs for (critic) encoder
|
||||
encoder_component_names = ["height_measurements"]
|
||||
encoder_class_name = "MlpModel"
|
||||
class encoder_kwargs:
|
||||
hidden_sizes = [128, 64]
|
||||
nonlinearity = "CELU"
|
||||
encoder_output_size = 32
|
||||
critic_encoder_component_names = ["height_measurements"]
|
||||
init_noise_std = 0.5
|
||||
# configs for policy: using recurrent policy with GRU
|
||||
rnn_type = 'gru'
|
||||
mu_activation = None
|
||||
|
||||
class runner( LeggedRobotCfgPPO.runner ):
|
||||
policy_class_name = "EncoderStateAcRecurrent"
|
||||
algorithm_class_name = "EstimatorPPO"
|
||||
experiment_name = "rough_go2"
|
||||
|
||||
resume = False
|
||||
load_run = None
|
||||
|
||||
run_name = "".join(["Go2Rough",
|
||||
("_pEnergy" + np.format_float_scientific(Go2RoughCfg.rewards.scales.energy_substeps, precision= 1, trim= "-") if Go2RoughCfg.rewards.scales.energy_substeps != 0 else ""),
|
||||
("_pDofErr" + np.format_float_scientific(Go2RoughCfg.rewards.scales.dof_error, precision= 1, trim= "-") if Go2RoughCfg.rewards.scales.dof_error != 0 else ""),
|
||||
("_pDofErrN" + np.format_float_scientific(Go2RoughCfg.rewards.scales.dof_error_named, precision= 1, trim= "-") if Go2RoughCfg.rewards.scales.dof_error_named != 0 else ""),
|
||||
("_pStand" + np.format_float_scientific(Go2RoughCfg.rewards.scales.stand_still, precision= 1, trim= "-") if Go2RoughCfg.rewards.scales.stand_still != 0 else ""),
|
||||
("_noResume" if not resume else "_from" + "_".join(load_run.split("/")[-1].split("_")[:2])),
|
||||
])
|
||||
|
||||
max_iterations = 2000
|
||||
save_interval = 2000
|
||||
log_interval = 100
|
|
@ -0,0 +1,233 @@
|
|||
""" Config to train the whole parkour oracle policy """
|
||||
import numpy as np
|
||||
from os import path as osp
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
|
||||
from legged_gym.utils.helpers import merge_dict
|
||||
from legged_gym.envs.go2.go2_field_config import Go2FieldCfg, Go2FieldCfgPPO, Go2RoughCfgPPO
|
||||
|
||||
multi_process_ = True
|
||||
class Go2DistillCfg( Go2FieldCfg ):
|
||||
class env( Go2FieldCfg.env ):
|
||||
num_envs = 256
|
||||
obs_components = [
|
||||
"lin_vel",
|
||||
"ang_vel",
|
||||
"projected_gravity",
|
||||
"commands",
|
||||
"dof_pos",
|
||||
"dof_vel",
|
||||
"last_actions",
|
||||
"forward_depth",
|
||||
]
|
||||
|
||||
privileged_obs_components = [
|
||||
"lin_vel",
|
||||
"ang_vel",
|
||||
"projected_gravity",
|
||||
"commands",
|
||||
"dof_pos",
|
||||
"dof_vel",
|
||||
"last_actions",
|
||||
"height_measurements",
|
||||
]
|
||||
|
||||
class terrain( Go2FieldCfg.terrain ):
|
||||
if multi_process_:
|
||||
num_rows = 4
|
||||
num_cols = 1
|
||||
curriculum = False
|
||||
|
||||
BarrierTrack_kwargs = merge_dict(Go2FieldCfg.terrain.BarrierTrack_kwargs, dict(
|
||||
leap= dict(
|
||||
length= [0.05, 0.8],
|
||||
depth= [0.5, 0.8],
|
||||
height= 0.15, # expected leap height over the gap
|
||||
fake_offset= 0.1,
|
||||
),
|
||||
))
|
||||
|
||||
class sensor( Go2FieldCfg.sensor ):
|
||||
class forward_camera:
|
||||
obs_components = ["forward_depth"]
|
||||
resolution = [int(480/4), int(640/4)]
|
||||
position = dict(
|
||||
mean= [0.24, -0.0175, 0.12],
|
||||
std= [0.01, 0.0025, 0.03],
|
||||
)
|
||||
rotation = dict(
|
||||
lower= [-0.1, 0.37, -0.1],
|
||||
upper= [0.1, 0.43, 0.1],
|
||||
)
|
||||
resized_resolution = [48, 64]
|
||||
output_resolution = [48, 64]
|
||||
horizontal_fov = [86, 90]
|
||||
crop_top_bottom = [int(48/4), 0]
|
||||
crop_left_right = [int(28/4), int(36/4)]
|
||||
near_plane = 0.05
|
||||
depth_range = [0.0, 3.0]
|
||||
|
||||
latency_range = [0.08, 0.142]
|
||||
latency_resample_time = 5.0
|
||||
refresh_duration = 1/10 # [s]
|
||||
|
||||
class commands( Go2FieldCfg.commands ):
|
||||
# a mixture of command sampling and goal_based command update allows only high speed range
|
||||
# in x-axis but no limits on y-axis and yaw-axis
|
||||
lin_cmd_cutoff = 0.2
|
||||
class ranges( Go2FieldCfg.commands.ranges ):
|
||||
# lin_vel_x = [0.6, 1.8]
|
||||
lin_vel_x = [-0.6, 2.0]
|
||||
|
||||
is_goal_based = True
|
||||
class goal_based:
|
||||
# the ratios are related to the goal position in robot frame
|
||||
x_ratio = None # sample from lin_vel_x range
|
||||
y_ratio = 1.2
|
||||
yaw_ratio = 0.8
|
||||
follow_cmd_cutoff = True
|
||||
x_stop_by_yaw_threshold = 1. # stop when yaw is over this threshold [rad]
|
||||
|
||||
class normalization( Go2FieldCfg.normalization ):
|
||||
class obs_scales( Go2FieldCfg.normalization.obs_scales ):
|
||||
forward_depth = 1.0
|
||||
|
||||
class noise( Go2FieldCfg.noise ):
|
||||
add_noise = False
|
||||
class noise_scales( Go2FieldCfg.noise.noise_scales ):
|
||||
forward_depth = 0.0
|
||||
### noise for simulating sensors
|
||||
commands = 0.1
|
||||
lin_vel = 0.1
|
||||
ang_vel = 0.1
|
||||
projected_gravity = 0.02
|
||||
dof_pos = 0.02
|
||||
dof_vel = 0.2
|
||||
last_actions = 0.
|
||||
### noise for simulating sensors
|
||||
class forward_depth:
|
||||
stereo_min_distance = 0.175 # when using (480, 640) resolution
|
||||
stereo_far_distance = 1.2
|
||||
stereo_far_noise_std = 0.08
|
||||
stereo_near_noise_std = 0.02
|
||||
stereo_full_block_artifacts_prob = 0.008
|
||||
stereo_full_block_values = [0.0, 0.25, 0.5, 1., 3.]
|
||||
stereo_full_block_height_mean_std = [62, 1.5]
|
||||
stereo_full_block_width_mean_std = [3, 0.01]
|
||||
stereo_half_block_spark_prob = 0.02
|
||||
stereo_half_block_value = 3000
|
||||
sky_artifacts_prob = 0.0001
|
||||
sky_artifacts_far_distance = 2.
|
||||
sky_artifacts_values = [0.6, 1., 1.2, 1.5, 1.8]
|
||||
sky_artifacts_height_mean_std = [2, 3.2]
|
||||
sky_artifacts_width_mean_std = [2, 3.2]
|
||||
|
||||
class curriculum:
|
||||
no_moveup_when_fall = False
|
||||
|
||||
class sim( Go2FieldCfg.sim ):
|
||||
no_camera = False
|
||||
|
||||
logs_root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))), "logs")
|
||||
class Go2DistillCfgPPO( Go2FieldCfgPPO ):
|
||||
class algorithm( Go2FieldCfgPPO.algorithm ):
|
||||
entropy_coef = 0.0
|
||||
using_ppo = False
|
||||
num_learning_epochs = 8
|
||||
num_mini_batches = 2
|
||||
distill_target = "l1"
|
||||
learning_rate = 3e-4
|
||||
optimizer_class_name = "AdamW"
|
||||
teacher_act_prob = 0.
|
||||
distillation_loss_coef = 1.0
|
||||
# update_times_scale = 100
|
||||
action_labels_from_sample = False
|
||||
|
||||
teacher_policy_class_name = "EncoderStateAcRecurrent"
|
||||
teacher_ac_path = osp.join(logs_root, "field_go2",
|
||||
"{Your trained oracle parkour model directory}",
|
||||
"{The latest model filename in the directory}"
|
||||
)
|
||||
|
||||
class teacher_policy( Go2FieldCfgPPO.policy ):
|
||||
num_actor_obs = 48 + 21 * 11
|
||||
num_critic_obs = 48 + 21 * 11
|
||||
num_actions = 12
|
||||
obs_segments = OrderedDict([
|
||||
("lin_vel", (3,)),
|
||||
("ang_vel", (3,)),
|
||||
("projected_gravity", (3,)),
|
||||
("commands", (3,)),
|
||||
("dof_pos", (12,)),
|
||||
("dof_vel", (12,)),
|
||||
("last_actions", (12,)), # till here: 3+3+3+3+12+12+12 = 48
|
||||
("height_measurements", (1, 21, 11)),
|
||||
])
|
||||
|
||||
class policy( Go2RoughCfgPPO.policy ):
|
||||
# configs for estimator module
|
||||
estimator_obs_components = [
|
||||
"ang_vel",
|
||||
"projected_gravity",
|
||||
"commands",
|
||||
"dof_pos",
|
||||
"dof_vel",
|
||||
"last_actions",
|
||||
]
|
||||
estimator_target_components = ["lin_vel"]
|
||||
replace_state_prob = 1.0
|
||||
class estimator_kwargs:
|
||||
hidden_sizes = [128, 64]
|
||||
nonlinearity = "CELU"
|
||||
# configs for visual encoder
|
||||
encoder_component_names = ["forward_depth"]
|
||||
encoder_class_name = "Conv2dHeadModel"
|
||||
class encoder_kwargs:
|
||||
channels = [16, 32, 32]
|
||||
kernel_sizes = [5, 4, 3]
|
||||
strides = [2, 2, 1]
|
||||
hidden_sizes = [128]
|
||||
use_maxpool = True
|
||||
nonlinearity = "LeakyReLU"
|
||||
# configs for critic encoder
|
||||
critic_encoder_component_names = ["height_measurements"]
|
||||
critic_encoder_class_name = "MlpModel"
|
||||
class critic_encoder_kwargs:
|
||||
hidden_sizes = [128, 64]
|
||||
nonlinearity = "CELU"
|
||||
encoder_output_size = 32
|
||||
|
||||
init_noise_std = 0.1
|
||||
|
||||
if multi_process_:
|
||||
runner_class_name = "TwoStageRunner"
|
||||
class runner( Go2FieldCfgPPO.runner ):
|
||||
policy_class_name = "EncoderStateAcRecurrent"
|
||||
algorithm_class_name = "EstimatorTPPO"
|
||||
experiment_name = "distill_go2"
|
||||
num_steps_per_env = 32
|
||||
|
||||
if multi_process_:
|
||||
pretrain_iterations = -1
|
||||
class pretrain_dataset:
|
||||
data_dir = "{A temporary directory to store collected trajectory}"
|
||||
dataset_loops = -1
|
||||
random_shuffle_traj_order = True
|
||||
keep_latest_n_trajs = 1500
|
||||
starting_frame_range = [0, 50]
|
||||
|
||||
resume = True
|
||||
load_run = osp.join(logs_root, "field_go2",
|
||||
"{Your trained oracle parkour model directory}",
|
||||
)
|
||||
ckpt_manipulator = "replace_encoder0" if "field_go2" in load_run else None
|
||||
|
||||
run_name = "".join(["Go2_",
|
||||
("{:d}skills".format(len(Go2DistillCfg.terrain.BarrierTrack_kwargs["options"]))),
|
||||
("_noResume" if not resume else "_from" + "_".join(load_run.split("/")[-1].split("_")[:2])),
|
||||
])
|
||||
|
||||
max_iterations = 60000
|
||||
log_interval = 100
|
||||
|
|
@ -0,0 +1,229 @@
|
|||
""" Config to train the whole parkour oracle policy """
|
||||
import numpy as np
|
||||
from os import path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
from legged_gym.envs.go2.go2_config import Go2RoughCfg, Go2RoughCfgPPO
|
||||
|
||||
class Go2FieldCfg( Go2RoughCfg ):
|
||||
class init_state( Go2RoughCfg.init_state ):
|
||||
pos = [0.0, 0.0, 0.7]
|
||||
zero_actions = False
|
||||
|
||||
class sensor( Go2RoughCfg.sensor):
|
||||
class proprioception( Go2RoughCfg.sensor.proprioception ):
|
||||
# latency_range = [0.0, 0.0]
|
||||
latency_range = [0.005, 0.045] # [s]
|
||||
|
||||
class terrain( Go2RoughCfg.terrain ):
|
||||
num_rows = 10
|
||||
num_cols = 40
|
||||
selected = "BarrierTrack"
|
||||
slope_treshold = 20.
|
||||
|
||||
max_init_terrain_level = 2
|
||||
curriculum = True
|
||||
|
||||
pad_unavailable_info = True
|
||||
BarrierTrack_kwargs = dict(
|
||||
options= [
|
||||
"jump",
|
||||
"leap",
|
||||
"hurdle",
|
||||
"down",
|
||||
"tilted_ramp",
|
||||
"stairsup",
|
||||
"stairsdown",
|
||||
"discrete_rect",
|
||||
"slope",
|
||||
"wave",
|
||||
], # each race track will permute all the options
|
||||
jump= dict(
|
||||
height= [0.05, 0.5],
|
||||
depth= [0.1, 0.3],
|
||||
# fake_offset= 0.1,
|
||||
),
|
||||
leap= dict(
|
||||
length= [0.05, 0.8],
|
||||
depth= [0.5, 0.8],
|
||||
height= 0.2, # expected leap height over the gap
|
||||
# fake_offset= 0.1,
|
||||
),
|
||||
hurdle= dict(
|
||||
height= [0.05, 0.5],
|
||||
depth= [0.2, 0.5],
|
||||
# fake_offset= 0.1,
|
||||
curved_top_rate= 0.1,
|
||||
),
|
||||
down= dict(
|
||||
height= [0.1, 0.6],
|
||||
depth= [0.3, 0.5],
|
||||
),
|
||||
tilted_ramp= dict(
|
||||
tilt_angle= [0.2, 0.5],
|
||||
switch_spacing= 0.,
|
||||
spacing_curriculum= False,
|
||||
overlap_size= 0.2,
|
||||
depth= [-0.1, 0.1],
|
||||
length= [0.6, 1.2],
|
||||
),
|
||||
slope= dict(
|
||||
slope_angle= [0.2, 0.42],
|
||||
length= [1.2, 2.2],
|
||||
use_mean_height_offset= True,
|
||||
face_angle= [-3.14, 0, 1.57, -1.57],
|
||||
no_perlin_rate= 0.2,
|
||||
length_curriculum= True,
|
||||
),
|
||||
slopeup= dict(
|
||||
slope_angle= [0.2, 0.42],
|
||||
length= [1.2, 2.2],
|
||||
use_mean_height_offset= True,
|
||||
face_angle= [-0.2, 0.2],
|
||||
no_perlin_rate= 0.2,
|
||||
length_curriculum= True,
|
||||
),
|
||||
slopedown= dict(
|
||||
slope_angle= [0.2, 0.42],
|
||||
length= [1.2, 2.2],
|
||||
use_mean_height_offset= True,
|
||||
face_angle= [-0.2, 0.2],
|
||||
no_perlin_rate= 0.2,
|
||||
length_curriculum= True,
|
||||
),
|
||||
stairsup= dict(
|
||||
height= [0.1, 0.3],
|
||||
length= [0.3, 0.5],
|
||||
residual_distance= 0.05,
|
||||
num_steps= [3, 19],
|
||||
num_steps_curriculum= True,
|
||||
),
|
||||
stairsdown= dict(
|
||||
height= [0.1, 0.3],
|
||||
length= [0.3, 0.5],
|
||||
num_steps= [3, 19],
|
||||
num_steps_curriculum= True,
|
||||
),
|
||||
discrete_rect= dict(
|
||||
max_height= [0.05, 0.2],
|
||||
max_size= 0.6,
|
||||
min_size= 0.2,
|
||||
num_rects= 10,
|
||||
),
|
||||
wave= dict(
|
||||
amplitude= [0.1, 0.15], # in meter
|
||||
frequency= [0.6, 1.0], # in 1/meter
|
||||
),
|
||||
track_width= 3.2,
|
||||
track_block_length= 2.4,
|
||||
wall_thickness= (0.01, 0.6),
|
||||
wall_height= [-0.5, 2.0],
|
||||
add_perlin_noise= True,
|
||||
border_perlin_noise= True,
|
||||
border_height= 0.,
|
||||
virtual_terrain= False,
|
||||
draw_virtual_terrain= True,
|
||||
engaging_next_threshold= 0.8,
|
||||
engaging_finish_threshold= 0.,
|
||||
curriculum_perlin= False,
|
||||
no_perlin_threshold= 0.1,
|
||||
randomize_obstacle_order= True,
|
||||
n_obstacles_per_track= 1,
|
||||
)
|
||||
|
||||
class commands( Go2RoughCfg.commands ):
|
||||
# a mixture of command sampling and goal_based command update allows only high speed range
|
||||
# in x-axis but no limits on y-axis and yaw-axis
|
||||
lin_cmd_cutoff = 0.2
|
||||
class ranges( Go2RoughCfg.commands.ranges ):
|
||||
# lin_vel_x = [0.6, 1.8]
|
||||
lin_vel_x = [-0.6, 2.0]
|
||||
|
||||
is_goal_based = True
|
||||
class goal_based:
|
||||
# the ratios are related to the goal position in robot frame
|
||||
x_ratio = None # sample from lin_vel_x range
|
||||
y_ratio = 1.2
|
||||
yaw_ratio = 1.
|
||||
follow_cmd_cutoff = True
|
||||
x_stop_by_yaw_threshold = 1. # stop when yaw is over this threshold [rad]
|
||||
|
||||
class asset( Go2RoughCfg.asset ):
|
||||
terminate_after_contacts_on = []
|
||||
penalize_contacts_on = ["thigh", "calf", "base"]
|
||||
|
||||
class termination( Go2RoughCfg.termination ):
|
||||
roll_kwargs = dict(
|
||||
threshold= 1.4, # [rad]
|
||||
)
|
||||
pitch_kwargs = dict(
|
||||
threshold= 1.6, # [rad]
|
||||
)
|
||||
timeout_at_border = True
|
||||
timeout_at_finished = False
|
||||
|
||||
class rewards( Go2RoughCfg.rewards ):
|
||||
class scales:
|
||||
tracking_lin_vel = 1.
|
||||
tracking_ang_vel = 1.
|
||||
energy_substeps = -2e-7
|
||||
torques = -1e-7
|
||||
stand_still = -1.
|
||||
dof_error_named = -1.
|
||||
dof_error = -0.005
|
||||
collision = -0.05
|
||||
lazy_stop = -3.
|
||||
# penalty for hardware safety
|
||||
exceed_dof_pos_limits = -0.1
|
||||
exceed_torque_limits_l1norm = -0.1
|
||||
# penetration penalty
|
||||
penetrate_depth = -0.05
|
||||
|
||||
class noise( Go2RoughCfg.noise ):
|
||||
add_noise = False
|
||||
|
||||
class curriculum:
|
||||
penetrate_depth_threshold_harder = 100
|
||||
penetrate_depth_threshold_easier = 200
|
||||
no_moveup_when_fall = True
|
||||
|
||||
logs_root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))), "logs")
|
||||
class Go2FieldCfgPPO( Go2RoughCfgPPO ):
|
||||
class algorithm( Go2RoughCfgPPO.algorithm ):
|
||||
entropy_coef = 0.0
|
||||
|
||||
class runner( Go2RoughCfgPPO.runner ):
|
||||
experiment_name = "field_go2"
|
||||
|
||||
resume = True
|
||||
load_run = osp.join(logs_root, "rough_go2",
|
||||
"{Your trained walking model directory}",
|
||||
)
|
||||
|
||||
run_name = "".join(["Go2_",
|
||||
("{:d}skills".format(len(Go2FieldCfg.terrain.BarrierTrack_kwargs["options"]))),
|
||||
("_pEnergy" + np.format_float_scientific(-Go2FieldCfg.rewards.scales.energy_substeps, precision=2)),
|
||||
# ("_pDofErr" + np.format_float_scientific(-Go2FieldCfg.rewards.scales.dof_error, precision=2) if getattr(Go2FieldCfg.rewards.scales, "dof_error", 0.) != 0. else ""),
|
||||
# ("_pHipDofErr" + np.format_float_scientific(-Go2FieldCfg.rewards.scales.dof_error_named, precision=2) if getattr(Go2FieldCfg.rewards.scales, "dof_error_named", 0.) != 0. else ""),
|
||||
# ("_pStand" + np.format_float_scientific(Go2FieldCfg.rewards.scales.stand_still, precision=2)),
|
||||
# ("_pTerm" + np.format_float_scientific(Go2FieldCfg.rewards.scales.termination, precision=2) if hasattr(Go2FieldCfg.rewards.scales, "termination") else ""),
|
||||
("_pTorques" + np.format_float_scientific(Go2FieldCfg.rewards.scales.torques, precision=2) if hasattr(Go2FieldCfg.rewards.scales, "torques") else ""),
|
||||
# ("_pColl" + np.format_float_scientific(Go2FieldCfg.rewards.scales.collision, precision=2) if hasattr(Go2FieldCfg.rewards.scales, "collision") else ""),
|
||||
("_pLazyStop" + np.format_float_scientific(Go2FieldCfg.rewards.scales.lazy_stop, precision=2) if hasattr(Go2FieldCfg.rewards.scales, "lazy_stop") else ""),
|
||||
# ("_trackSigma" + np.format_float_scientific(Go2FieldCfg.rewards.tracking_sigma, precision=2) if Go2FieldCfg.rewards.tracking_sigma != 0.25 else ""),
|
||||
# ("_pPenV" + np.format_float_scientific(-Go2FieldCfg.rewards.scales.penetrate_volume, precision=2)),
|
||||
("_pPenD" + np.format_float_scientific(-Go2FieldCfg.rewards.scales.penetrate_depth, precision=2)),
|
||||
# ("_pTorqueL1" + np.format_float_scientific(-Go2FieldCfg.rewards.scales.exceed_torque_limits_l1norm, precision=2)),
|
||||
("_penEasier{:d}".format(Go2FieldCfg.curriculum.penetrate_depth_threshold_easier)),
|
||||
("_penHarder{:d}".format(Go2FieldCfg.curriculum.penetrate_depth_threshold_harder)),
|
||||
# ("_leapMin" + np.format_float_scientific(Go2FieldCfg.terrain.BarrierTrack_kwargs["leap"]["length"][0], precision=2)),
|
||||
("_leapHeight" + np.format_float_scientific(Go2FieldCfg.terrain.BarrierTrack_kwargs["leap"]["height"], precision=2)),
|
||||
("_motorTorqueClip" if Go2FieldCfg.control.motor_clip_torque else ""),
|
||||
# ("_noMoveupWhenFall" if Go2FieldCfg.curriculum.no_moveup_when_fall else ""),
|
||||
("_noResume" if not resume else "_from" + "_".join(load_run.split("/")[-1].split("_")[:2])),
|
||||
])
|
||||
|
||||
max_iterations = 38000
|
||||
save_interval = 10000
|
||||
log_interval = 100
|
||||
|
|
@ -8,8 +8,10 @@ import os
|
|||
import json
|
||||
import os.path as osp
|
||||
|
||||
from legged_gym import LEGGED_GYM_ROOT_DIR
|
||||
from legged_gym.envs import *
|
||||
from legged_gym.utils import get_args, task_registry
|
||||
from legged_gym.utils import get_args
|
||||
from legged_gym.utils.task_registry import task_registry
|
||||
from legged_gym.utils.helpers import update_cfg_from_args, class_to_dict, update_class_from_dict
|
||||
from legged_gym.debugger import break_into_debugger
|
||||
|
||||
|
@ -17,55 +19,22 @@ from rsl_rl.modules import build_actor_critic
|
|||
from rsl_rl.runners.dagger_saver import DemonstrationSaver, DaggerSaver
|
||||
|
||||
def main(args):
|
||||
RunnerCls = DaggerSaver
|
||||
# RunnerCls = DemonstrationSaver
|
||||
RunnerCls = DaggerSaver if args.load_run else DemonstrationSaver
|
||||
success_traj_only = False
|
||||
teacher_act_prob = 0.1
|
||||
env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
|
||||
if RunnerCls == DaggerSaver:
|
||||
with open(os.path.join("logs", train_cfg.runner.experiment_name, args.load_run, "config.json"), "r") as f:
|
||||
d = json.load(f, object_pairs_hook= OrderedDict)
|
||||
update_class_from_dict(env_cfg, d, strict= True)
|
||||
update_class_from_dict(train_cfg, d, strict= True)
|
||||
|
||||
# Some custom settings
|
||||
# ####### customized option to increase data distribution #######
|
||||
action_sample_std = 0.0
|
||||
|
||||
####### customized option to increase data distribution #######
|
||||
# env_cfg.env.num_envs = 6
|
||||
# env_cfg.terrain.curriculum = True
|
||||
# env_cfg.terrain.max_init_terrain_level = 0
|
||||
# env_cfg.terrain.border_size = 1.
|
||||
############# some predefined options #############
|
||||
if len(env_cfg.terrain.BarrierTrack_kwargs["options"]) == 1:
|
||||
env_cfg.terrain.num_rows = 20; env_cfg.terrain.num_cols = 30
|
||||
else: # for parkour env
|
||||
# >>> option 1
|
||||
env_cfg.terrain.BarrierTrack_kwargs["track_block_length"] = 2.8
|
||||
env_cfg.terrain.BarrierTrack_kwargs["track_width"] = 2.4
|
||||
env_cfg.terrain.BarrierTrack_kwargs["wall_thickness"] = (0.0, 0.6)
|
||||
env_cfg.domain_rand.init_base_pos_range["x"] = (0.4, 1.8)
|
||||
env_cfg.terrain.num_rows = 12; env_cfg.terrain.num_cols = 10
|
||||
# >>> option 2
|
||||
# env_cfg.terrain.BarrierTrack_kwargs["track_block_length"] = 3.
|
||||
# env_cfg.terrain.BarrierTrack_kwargs["track_width"] = 4.0
|
||||
# env_cfg.terrain.BarrierTrack_kwargs["wall_height"] = (-0.5, -0.2)
|
||||
# env_cfg.terrain.BarrierTrack_kwargs["wall_thickness"] = (0.0, 1.4)
|
||||
# env_cfg.domain_rand.init_base_pos_range["x"] = (1.6, 2.0)
|
||||
# env_cfg.terrain.num_rows = 16; env_cfg.terrain.num_cols = 5
|
||||
# >>> option 3
|
||||
# env_cfg.terrain.BarrierTrack_kwargs["track_block_length"] = 1.6
|
||||
# env_cfg.terrain.BarrierTrack_kwargs["track_width"] = 2.2
|
||||
# env_cfg.terrain.BarrierTrack_kwargs["wall_height"] = (-0.5, 0.1)
|
||||
# env_cfg.terrain.BarrierTrack_kwargs["wall_thickness"] = (0.0, 0.5)
|
||||
# env_cfg.domain_rand.init_base_pos_range["x"] = (0.2, 0.9)
|
||||
# env_cfg.terrain.BarrierTrack_kwargs["n_obstacles_per_track"] = 1
|
||||
# action_sample_std = 0.1
|
||||
# env_cfg.terrain.num_rows = 22; env_cfg.terrain.num_cols = 16
|
||||
pass
|
||||
if (env_cfg.terrain.BarrierTrack_kwargs["options"][0] == "leap") and all(i == env_cfg.terrain.BarrierTrack_kwargs["options"][0] for i in env_cfg.terrain.BarrierTrack_kwargs["options"]):
|
||||
######### For leap, because the platform is usually higher than the ground.
|
||||
env_cfg.terrain.num_rows = 80
|
||||
env_cfg.terrain.num_cols = 1
|
||||
env_cfg.terrain.BarrierTrack_kwargs["track_width"] = 1.6
|
||||
env_cfg.terrain.BarrierTrack_kwargs["wall_thickness"] = (0.01, 0.5)
|
||||
env_cfg.terrain.BarrierTrack_kwargs["wall_height"] = (-0.4, 0.2) # randomize incase of terrain that have side walls
|
||||
env_cfg.terrain.BarrierTrack_kwargs["border_height"] = -0.4
|
||||
env_cfg.terrain.num_rows = 8; env_cfg.terrain.num_cols = 40
|
||||
# Done custom settings
|
||||
|
||||
env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
|
||||
|
@ -73,82 +42,81 @@ def main(args):
|
|||
|
||||
config = class_to_dict(train_cfg)
|
||||
config.update(class_to_dict(env_cfg))
|
||||
teacher_act_prob = config["algorithm"]["teacher_act_prob"] if args.teacher_prob is None else args.teacher_prob
|
||||
action_std = config["policy"]["init_noise_std"] if args.action_std is None else args.action_std
|
||||
|
||||
# create teacher policy
|
||||
policy = build_actor_critic(
|
||||
env,
|
||||
train_cfg.algorithm.teacher_policy_class_name,
|
||||
config["algorithm"]["teacher_policy_class_name"],
|
||||
config["algorithm"]["teacher_policy"],
|
||||
).to(env.device)
|
||||
# load the policy is possible
|
||||
if train_cfg.algorithm.teacher_ac_path is not None:
|
||||
state_dict = torch.load(train_cfg.algorithm.teacher_ac_path, map_location= "cpu")
|
||||
if config["algorithm"]["teacher_ac_path"] is not None:
|
||||
if "{LEGGED_GYM_ROOT_DIR}" in config["algorithm"]["teacher_ac_path"]:
|
||||
config["algorithm"]["teacher_ac_path"] = config["algorithm"]["teacher_ac_path"].format(LEGGED_GYM_ROOT_DIR= LEGGED_GYM_ROOT_DIR)
|
||||
state_dict = torch.load(config["algorithm"]["teacher_ac_path"], map_location= "cpu")
|
||||
teacher_actor_critic_state_dict = state_dict["model_state_dict"]
|
||||
policy.load_state_dict(teacher_actor_critic_state_dict)
|
||||
|
||||
# build runner
|
||||
track_header = "".join(env_cfg.terrain.BarrierTrack_kwargs["options"])
|
||||
if env_cfg.commands.ranges.lin_vel_x[1] > 0.0:
|
||||
cmd_vel = "_cmd{:.1f}-{:.1f}".format(env_cfg.commands.ranges.lin_vel_x[0], env_cfg.commands.ranges.lin_vel_x[1])
|
||||
elif env_cfg.commands.ranges.lin_vel_x[1] == 0. and len(env_cfg.terrain.BarrierTrack_kwargs["options"]) == 1 \
|
||||
or (env_cfg.terrain.BarrierTrack_kwargs["options"][0] == env_cfg.terrain.BarrierTrack_kwargs["options"][1]):
|
||||
obstacle_id = env.terrain.track_options_id_dict[env_cfg.terrain.BarrierTrack_kwargs["options"][0]]
|
||||
try:
|
||||
overrided_vel = train_cfg.algorithm.teacher_policy.cmd_vel_mapping[obstacle_id]
|
||||
except:
|
||||
overrided_vel = train_cfg.algorithm.teacher_policy.cmd_vel_mapping[str(obstacle_id)]
|
||||
cmd_vel = "_cmdOverride{:.1f}".format(overrided_vel)
|
||||
else:
|
||||
cmd_vel = "_cmdMutex"
|
||||
|
||||
datadir = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))), "logs")
|
||||
runner_kwargs = dict(
|
||||
env= env,
|
||||
policy= policy,
|
||||
save_dir= osp.join(
|
||||
train_cfg.runner.pretrain_dataset.scan_dir if RunnerCls == DaggerSaver else osp.join(datadir, "distill_{}_dagger".format(args.task.split("_")[0])),
|
||||
config["runner"]["pretrain_dataset"]["data_dir"] if RunnerCls == DaggerSaver else osp.join("/localdata_ssd/zzw/athletic-isaac_tmp", "{}_dagger".format(config["runner"]["experiment_name"])),
|
||||
datetime.now().strftime('%b%d_%H-%M-%S') + "_" + "".join([
|
||||
track_header,
|
||||
cmd_vel,
|
||||
"_lowBorder" if env_cfg.terrain.BarrierTrack_kwargs["border_height"] < 0 else "",
|
||||
"_trackWidth{:.1f}".format(env_cfg.terrain.BarrierTrack_kwargs["track_width"]) if env_cfg.terrain.BarrierTrack_kwargs["track_width"] < 1.8 else "",
|
||||
"_blockLength{:.1f}".format(env_cfg.terrain.BarrierTrack_kwargs["track_block_length"]) if env_cfg.terrain.BarrierTrack_kwargs["track_block_length"] > 1.6 else "",
|
||||
"_addMassMin{:.1f}".format(env_cfg.domain_rand.added_mass_range[0]) if env_cfg.domain_rand.added_mass_range[0] > 1. else "",
|
||||
"_comMean{:.2f}".format((env_cfg.domain_rand.com_range.x[0] + env_cfg.domain_rand.com_range.x[1])/2),
|
||||
"_1cols" if env_cfg.terrain.num_cols == 1 else "",
|
||||
"_teacherProb{:.1f}".format(teacher_act_prob),
|
||||
"_randOrder" if env_cfg.terrain.BarrierTrack_kwargs.get("randomize_obstacle_order", False) else "",
|
||||
("_noPerlinRate{:.1f}".format(
|
||||
(env_cfg.terrain.BarrierTrack_kwargs["no_perlin_threshold"] - env_cfg.terrain.TerrainPerlin_kwargs["zScale"][0]) / \
|
||||
(env_cfg.terrain.TerrainPerlin_kwargs["zScale"][1] - env_cfg.terrain.TerrainPerlin_kwargs["zScale"][0])
|
||||
)),
|
||||
) if isinstance(env_cfg.terrain.TerrainPerlin_kwargs["zScale"], (tuple, list)) else ""),
|
||||
("_fric{:.1f}-{:.1f}".format(*env_cfg.domain_rand.friction_range)),
|
||||
"_successOnly" if success_traj_only else "",
|
||||
"_aStd{:.2f}".format(action_sample_std) if (action_sample_std > 0. and RunnerCls == DaggerSaver) else "",
|
||||
"_aStd{:.2f}".format(action_std) if (action_std > 0. and RunnerCls == DaggerSaver) else "",
|
||||
] + ([] if RunnerCls == DemonstrationSaver else ["_" + "_".join(args.load_run.split("_")[:2])])
|
||||
),
|
||||
),
|
||||
rollout_storage_length= 256,
|
||||
min_timesteps= 1e6, # 1e6,
|
||||
min_episodes= 2e4 if RunnerCls == DaggerSaver else 2e-3,
|
||||
min_timesteps= 1e9, # 1e6,
|
||||
min_episodes= 1e6 if RunnerCls == DaggerSaver else 2e5,
|
||||
use_critic_obs= True,
|
||||
success_traj_only= success_traj_only,
|
||||
obs_disassemble_mapping= dict(
|
||||
forward_depth= "normalized_image",
|
||||
),
|
||||
demo_by_sample= config["algorithm"].get("action_labels_from_sample", False),
|
||||
)
|
||||
if RunnerCls == DaggerSaver:
|
||||
# kwargs for dagger saver
|
||||
runner_kwargs.update(dict(
|
||||
training_policy_logdir= osp.join(
|
||||
"logs",
|
||||
train_cfg.runner.experiment_name,
|
||||
config["runner"]["experiment_name"],
|
||||
args.load_run,
|
||||
),
|
||||
teacher_act_prob= teacher_act_prob,
|
||||
update_times_scale= config["algorithm"]["update_times_scale"],
|
||||
action_sample_std= action_sample_std,
|
||||
update_times_scale= config["algorithm"].get("update_times_scale", 1e5),
|
||||
action_sample_std= action_std,
|
||||
log_to_tensorboard= args.log,
|
||||
))
|
||||
runner = RunnerCls(**runner_kwargs)
|
||||
runner.collect_and_save(config= config)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
args = get_args(
|
||||
custom_args= [
|
||||
{"name": "--teacher_prob", "type": float, "default": None, "help": "probability of using teacher's action"},
|
||||
{"name": "--action_std", "type": float, "default": None, "help": "override the action sample std during rollout. None for using model's std"},
|
||||
{"name": "--log", "action": "store_true", "help": "log the data to tensorboard"},
|
||||
],
|
||||
)
|
||||
main(args)
|
||||
|
|
|
@ -39,7 +39,8 @@ import isaacgym
|
|||
from isaacgym import gymtorch, gymapi
|
||||
from isaacgym.torch_utils import *
|
||||
from legged_gym.envs import *
|
||||
from legged_gym.utils import get_args, export_policy_as_jit, task_registry, Logger
|
||||
from legged_gym.utils import get_args, export_policy_as_jit, Logger
|
||||
from legged_gym.utils.task_registry import task_registry
|
||||
from legged_gym.utils.helpers import update_class_from_dict
|
||||
from legged_gym.utils.observation import get_obs_slice
|
||||
from legged_gym.debugger import break_into_debugger
|
||||
|
@ -78,18 +79,30 @@ def create_recording_camera(gym, env_handle,
|
|||
@torch.no_grad()
|
||||
def play(args):
|
||||
env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
|
||||
with open(os.path.join("logs", train_cfg.runner.experiment_name, args.load_run, "config.json"), "r") as f:
|
||||
d = json.load(f, object_pairs_hook= OrderedDict)
|
||||
update_class_from_dict(env_cfg, d, strict= True)
|
||||
update_class_from_dict(train_cfg, d, strict= True)
|
||||
if args.load_cfg:
|
||||
with open(os.path.join("logs", train_cfg.runner.experiment_name, args.load_run, "config.json"), "r") as f:
|
||||
d = json.load(f, object_pairs_hook= OrderedDict)
|
||||
update_class_from_dict(env_cfg, d, strict= True)
|
||||
update_class_from_dict(train_cfg, d, strict= True)
|
||||
|
||||
# override some parameters for testing
|
||||
if env_cfg.terrain.selected == "BarrierTrack":
|
||||
env_cfg.env.num_envs = min(env_cfg.env.num_envs, 1)
|
||||
env_cfg.env.episode_length_s = 20
|
||||
env_cfg.terrain.max_init_terrain_level = 0
|
||||
env_cfg.terrain.num_rows = 1
|
||||
env_cfg.terrain.num_cols = 1
|
||||
env_cfg.terrain.num_rows = 4
|
||||
env_cfg.terrain.num_cols = 8
|
||||
env_cfg.terrain.BarrierTrack_kwargs["options"] = [
|
||||
"jump",
|
||||
"leap",
|
||||
"down",
|
||||
"hurdle",
|
||||
"tilted_ramp",
|
||||
"stairsup",
|
||||
"discrete_rect",
|
||||
"wave",
|
||||
]
|
||||
env_cfg.terrain.BarrierTrack_kwargs["leap"]["fake_offset"] = 0.1
|
||||
else:
|
||||
env_cfg.env.num_envs = min(env_cfg.env.num_envs, 1)
|
||||
env_cfg.env.episode_length_s = 60
|
||||
|
@ -98,67 +111,43 @@ def play(args):
|
|||
env_cfg.terrain.max_init_terrain_level = 0
|
||||
env_cfg.terrain.num_rows = 1
|
||||
env_cfg.terrain.num_cols = 1
|
||||
env_cfg.terrain.curriculum = False
|
||||
env_cfg.terrain.BarrierTrack_kwargs["options"] = [
|
||||
# "crawl",
|
||||
"jump",
|
||||
# "leap",
|
||||
# "tilt",
|
||||
]
|
||||
if "one_obstacle_per_track" in env_cfg.terrain.BarrierTrack_kwargs.keys():
|
||||
env_cfg.terrain.BarrierTrack_kwargs.pop("one_obstacle_per_track")
|
||||
env_cfg.terrain.BarrierTrack_kwargs["n_obstacles_per_track"] = 2
|
||||
# env_cfg.terrain.curriculum = False
|
||||
# env_cfg.asset.fix_base_link = True
|
||||
env_cfg.env.episode_length_s = 1000
|
||||
env_cfg.commands.resampling_time = int(1e16)
|
||||
env_cfg.commands.ranges.lin_vel_x = [1.2, 1.2]
|
||||
if "distill" in args.task:
|
||||
env_cfg.commands.ranges.lin_vel_x = [0.0, 0.0]
|
||||
env_cfg.commands.ranges.lin_vel_y = [-0., 0.]
|
||||
env_cfg.commands.ranges.ang_vel_yaw = [-0., 0.]
|
||||
env_cfg.domain_rand.push_robots = False
|
||||
env_cfg.domain_rand.init_base_pos_range = dict(
|
||||
x= [0.6, 0.6],
|
||||
y= [-0.05, 0.05],
|
||||
)
|
||||
env_cfg.termination.termination_terms = []
|
||||
# env_cfg.termination.termination_terms = []
|
||||
env_cfg.termination.timeout_at_border = False
|
||||
env_cfg.termination.timeout_at_finished = False
|
||||
env_cfg.viewer.debug_viz = False # in a1_distill, setting this to true will constantly showing the egocentric depth view.
|
||||
env_cfg.viewer.debug_viz = True
|
||||
env_cfg.viewer.draw_measure_heights = False
|
||||
env_cfg.viewer.draw_height_measurements = False
|
||||
env_cfg.viewer.draw_volume_sample_points = False
|
||||
train_cfg.runner.resume = True
|
||||
env_cfg.viewer.draw_sensors = False
|
||||
if hasattr(env_cfg.terrain, "BarrierTrack_kwargs"):
|
||||
env_cfg.terrain.BarrierTrack_kwargs["draw_virtual_terrain"] = True
|
||||
# train_cfg.runner.resume = (args.load_run is not None)
|
||||
train_cfg.runner_class_name = "OnPolicyRunner"
|
||||
if "distill" in args.task: # to save the memory
|
||||
train_cfg.algorithm.teacher_policy.sub_policy_paths = []
|
||||
train_cfg.algorithm.teacher_policy_class_name = "ActorCritic"
|
||||
train_cfg.algorithm.teacher_policy = dict(
|
||||
num_actor_obs= 48,
|
||||
num_critic_obs= 48,
|
||||
num_actions= 12,
|
||||
)
|
||||
|
||||
if args.no_throw:
|
||||
env_cfg.init_state.pos[2] = 0.4
|
||||
env_cfg.domain_rand.init_base_pos_range["x"] = [0.4, 0.4]
|
||||
env_cfg.domain_rand.init_base_vel_range = [0., 0.]
|
||||
env_cfg.domain_rand.init_dof_vel_range = [0., 0.]
|
||||
env_cfg.domain_rand.init_base_rot_range["roll"] = [0., 0.]
|
||||
env_cfg.domain_rand.init_base_rot_range["pitch"] = [0., 0.]
|
||||
env_cfg.domain_rand.init_base_rot_range["yaw"] = [0., 0.]
|
||||
env_cfg.domain_rand.init_base_vel_range = [0., 0.]
|
||||
env_cfg.domain_rand.init_dof_pos_ratio_range = [1., 1.]
|
||||
|
||||
######### Some hacks to run ActorCriticMutex policy ##########
|
||||
if False: # for a1
|
||||
train_cfg.runner.policy_class_name = "ActorCriticClimbMutex"
|
||||
train_cfg.policy.sub_policy_class_name = "ActorCriticRecurrent"
|
||||
logs_root = "logs"
|
||||
train_cfg.policy.sub_policy_paths = [ # must in the order of obstacle ID
|
||||
logs_root + "/field_a1_oracle/Jun03_00-01-38_SkillsPlaneWalking_pEnergySubsteps1e-5_rAlive2_pTorqueExceedIndicate1e+1_noCurriculum_propDelay0.04-0.05_noPerlinRate-2.0_nSubsteps4_envFreq50.0_aScale244",
|
||||
logs_root + "/field_a1_oracle/Aug08_05-22-52_Skills_tilt_pEnergySubsteps1e-5_rAlive1_pPenV5e-3_pPenD5e-3_pPosY0.50_pYaw0.50_rTilt7e-1_pTorqueExceedIndicate1e-1_virtualTerrain_propDelay0.04-0.05_push/",
|
||||
logs_root + "/field_a1_oracle/May21_05-25-19_Skills_crawl_pEnergy2e-5_rAlive1_pPenV6e-2_pPenD6e-2_pPosY0.2_kp50_noContactTerminate_aScale0.5/",
|
||||
logs_root + "/field_a1_oracle/Jun03_00-33-08_Skills_climb_pEnergySubsteps2e-6_rAlive2_pTorqueExceedIndicate2e-1_propDelay0.04-0.05_noPerlinRate0.2_nSubsteps4_envFreq50.0_aScale0.5",
|
||||
logs_root + "/field_a1_oracle/Jun04_01-03-59_Skills_leap_pEnergySubsteps2e-6_rAlive2_pPenV4e-3_pPenD4e-3_pPosY0.20_pYaw0.20_pTorqueExceedSquare1e-3_leapH0.2_propDelay0.04-0.05_noPerlinRate0.2_aScale0.5",
|
||||
]
|
||||
train_cfg.policy.jump_down_policy_path = logs_root + "/field_a1_oracle/Aug30_16-12-14_Skills_climb_climbDownProb0.5_propDelay0.04-0.05"
|
||||
train_cfg.runner.resume = False
|
||||
env_cfg.env.use_lin_vel = True
|
||||
train_cfg.policy.cmd_vel_mapping = {
|
||||
0: 1.0,
|
||||
1: 0.5,
|
||||
2: 0.8,
|
||||
3: 1.2,
|
||||
4: 1.5,
|
||||
}
|
||||
if args.task == "a1_distill":
|
||||
env_cfg.env.obs_components = env_cfg.env.privileged_obs_components
|
||||
env_cfg.env.privileged_obs_gets_privilege = False
|
||||
# default camera position
|
||||
# env_cfg.viewer.lookat = [0.6, 1.2, 0.5]
|
||||
# env_cfg.viewer.pos = [0.6, 0., 0.5]
|
||||
|
||||
# prepare environment
|
||||
env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
|
||||
|
@ -168,9 +157,8 @@ def play(args):
|
|||
critic_obs = env.get_privileged_observations()
|
||||
# register debugging options to manually trigger disruption
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_P, "push_robot")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_L, "press_robot")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_J, "action_jitter")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_Q, "exit")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_ESCAPE, "exit")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_R, "agent_full_reset")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_U, "full_reset")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_C, "resample_commands")
|
||||
|
@ -180,7 +168,24 @@ def play(args):
|
|||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_D, "rightward")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_F, "leftturn")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_G, "rightturn")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_B, "leftdrag")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_M, "rightdrag")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_X, "stop")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_K, "mark")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_I, "more_plots")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_T, "switch_teacher")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_O, "lean_fronter")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_L, "lean_backer")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_COMMA, "lean_lefter")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_PERIOD, "lean_righter")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_DOWN, "slower")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_UP, "faster")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_LEFT, "lefter")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_RIGHT, "righter")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_LEFT_BRACKET, "terrain_left")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_RIGHT_BRACKET, "terrain_right")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_MINUS, "terrain_back")
|
||||
env.gym.subscribe_viewer_keyboard_event(env.viewer, isaacgym.gymapi.KEY_EQUAL, "terrain_forward")
|
||||
# load policy
|
||||
ppo_runner, train_cfg = task_registry.make_alg_runner(
|
||||
env=env,
|
||||
|
@ -200,6 +205,7 @@ def play(args):
|
|||
export_policy_as_jit(ppo_runner.alg.actor_critic, path)
|
||||
print('Exported policy as jit script to: ', path)
|
||||
if RECORD_FRAMES:
|
||||
os.mkdir(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", "images"), exist_ok= True)
|
||||
transform = gymapi.Transform()
|
||||
transform.p = gymapi.Vec3(*env_cfg.viewer.pos)
|
||||
transform.r = gymapi.Quat.from_euler_zyx(0., 0., -np.pi/2)
|
||||
|
@ -212,7 +218,7 @@ def play(args):
|
|||
logger = Logger(env.dt)
|
||||
robot_index = 0 # which robot is used for logging
|
||||
joint_index = 4 # which joint is used for logging
|
||||
stop_state_log = 512 # number of steps before plotting states
|
||||
stop_state_log = args.plot_time # number of steps before plotting states
|
||||
stop_rew_log = env.max_episode_length + 1 # number of steps before print average episode rewards
|
||||
camera_position = np.array(env_cfg.viewer.pos, dtype=np.float64)
|
||||
camera_vel = np.array([0.6, 0., 0.])
|
||||
|
@ -248,12 +254,9 @@ def play(args):
|
|||
if ui_event.action == "push_robot" and ui_event.value > 0:
|
||||
# manully trigger to push the robot
|
||||
env._push_robots()
|
||||
if ui_event.action == "press_robot" and ui_event.value > 0:
|
||||
env.root_states[:, 9] = torch_rand_float(-env.cfg.domain_rand.max_push_vel_xy, 0, (env.num_envs, 1), device=env.device).squeeze(1)
|
||||
env.gym.set_actor_root_state_tensor(env.sim, gymtorch.unwrap_tensor(env.all_root_states))
|
||||
if ui_event.action == "action_jitter" and ui_event.value > 0:
|
||||
# assuming wrong action is taken
|
||||
obs, critic_obs, rews, dones, infos = env.step(torch.tanh(torch.randn_like(actions)))
|
||||
obs, critic_obs, rews, dones, infos = env.step(actions + torch.randn_like(actions) * 0.2)
|
||||
if ui_event.action == "exit" and ui_event.value > 0:
|
||||
print("exit")
|
||||
exit(0)
|
||||
|
@ -263,24 +266,169 @@ def play(args):
|
|||
if ui_event.action == "full_reset" and ui_event.value > 0:
|
||||
print("full_reset")
|
||||
agent_model.reset()
|
||||
if hasattr(ppo_runner.alg, "teacher_actor_critic"):
|
||||
ppo_runner.alg.teacher_actor_critic.reset()
|
||||
# print(env._get_terrain_curriculum_move([robot_index]))
|
||||
obs, _ = env.reset()
|
||||
if ui_event.action == "resample_commands" and ui_event.value > 0:
|
||||
print("resample_commands")
|
||||
env._resample_commands(torch.arange(env.num_envs, device= env.device))
|
||||
if ui_event.action == "stop" and ui_event.value > 0:
|
||||
if hasattr(env, "sampled_x_cmd_buffer"):
|
||||
env.sampled_x_cmd_buffer[:] = 0
|
||||
env.commands[:, :] = 0
|
||||
if hasattr(env, "orientation_cmds"):
|
||||
env.orientation_cmds[:] = env.gravity_vec
|
||||
# env.stop_position.copy_(env.root_states[:, :3])
|
||||
# env.command_ranges["lin_vel_x"] = [0, 0]
|
||||
# env.command_ranges["lin_vel_y"] = [0, 0]
|
||||
# env.command_ranges["ang_vel_yaw"] = [0, 0]
|
||||
if ui_event.action == "forward" and ui_event.value > 0:
|
||||
env.commands[:, 0] = env_cfg.commands.ranges.lin_vel_x[1]
|
||||
# env.command_ranges["lin_vel_x"] = [env_cfg.commands.ranges.lin_vel_x[1], env_cfg.commands.ranges.lin_vel_x[1]]
|
||||
if ui_event.action == "backward" and ui_event.value > 0:
|
||||
env.commands[:, 0] = env_cfg.commands.ranges.lin_vel_x[0]
|
||||
# env.command_ranges["lin_vel_x"] = [env_cfg.commands.ranges.lin_vel_x[0], env_cfg.commands.ranges.lin_vel_x[0]]
|
||||
if ui_event.action == "leftward" and ui_event.value > 0:
|
||||
env.commands[:, 1] = env_cfg.commands.ranges.lin_vel_y[1]
|
||||
# env.command_ranges["lin_vel_y"] = [env_cfg.commands.ranges.lin_vel_y[1], env_cfg.commands.ranges.lin_vel_y[1]]
|
||||
if ui_event.action == "rightward" and ui_event.value > 0:
|
||||
env.commands[:, 1] = env_cfg.commands.ranges.lin_vel_y[0]
|
||||
# env.command_ranges["lin_vel_y"] = [env_cfg.commands.ranges.lin_vel_y[0], env_cfg.commands.ranges.lin_vel_y[0]]
|
||||
if ui_event.action == "leftturn" and ui_event.value > 0:
|
||||
env.commands[:, 2] = env_cfg.commands.ranges.ang_vel_yaw[1]
|
||||
# env.command_ranges["ang_vel_yaw"] = [env_cfg.commands.ranges.ang_vel_yaw[1], env_cfg.commands.ranges.ang_vel_yaw[1]]
|
||||
if ui_event.action == "rightturn" and ui_event.value > 0:
|
||||
env.commands[:, 2] = env_cfg.commands.ranges.ang_vel_yaw[0]
|
||||
# env.command_ranges["ang_vel_yaw"] = [env_cfg.commands.ranges.ang_vel_yaw[0], env_cfg.commands.ranges.ang_vel_yaw[0]]
|
||||
if ui_event.action == "leftdrag" and ui_event.value > 0:
|
||||
env.root_states[:, 7:10] += quat_rotate(env.base_quat, torch.tensor([[0., 0.5, 0.]], device= env.device))
|
||||
env.gym.set_actor_root_state_tensor(env.sim, gymtorch.unwrap_tensor(env.all_root_states))
|
||||
if ui_event.action == "rightdrag" and ui_event.value > 0:
|
||||
env.root_states[:, 7:10] += quat_rotate(env.base_quat, torch.tensor([[0., -0.5, 0.]], device= env.device))
|
||||
env.gym.set_actor_root_state_tensor(env.sim, gymtorch.unwrap_tensor(env.all_root_states))
|
||||
if ui_event.action == "mark" and ui_event.value > 0:
|
||||
logger.plot_states()
|
||||
if ui_event.action == "more_plots" and ui_event.value > 0:
|
||||
logger.plot_additional_states()
|
||||
if ui_event.action == "switch_teacher" and ui_event.value > 0:
|
||||
args.show_teacher = not args.show_teacher
|
||||
print("show_teacher:", args.show_teacher)
|
||||
if ui_event.action == "lean_fronter" and ui_event.value > 0 and hasattr(env, "orientation_cmds"):
|
||||
env.orientation_cmds[:, 0] += 0.1
|
||||
print("orientation_cmds:", env.orientation_cmds[:, 0])
|
||||
if ui_event.action == "lean_backer" and ui_event.value > 0 and hasattr(env, "orientation_cmds"):
|
||||
env.orientation_cmds[:, 0] -= 0.1
|
||||
print("orientation_cmds:", env.orientation_cmds[:, 0])
|
||||
if ui_event.action == "lean_lefter" and ui_event.value > 0 and hasattr(env, "orientation_cmds"):
|
||||
env.orientation_cmds[:, 1] += 0.1
|
||||
print("orientation_cmds:", env.orientation_cmds[:, 1])
|
||||
if ui_event.action == "lean_righter" and ui_event.value > 0 and hasattr(env, "orientation_cmds"):
|
||||
env.orientation_cmds[:, 1] -= 0.1
|
||||
print("orientation_cmds:", env.orientation_cmds[:, 1])
|
||||
if ui_event.action == "slower" and ui_event.value > 0:
|
||||
if hasattr(env, "sampled_x_cmd_buffer"):
|
||||
env.sampled_x_cmd_buffer[:] -= 0.2
|
||||
env.commands[:, 0] -= 0.2
|
||||
print("command_x:", env.commands[:, 0])
|
||||
if ui_event.action == "faster" and ui_event.value > 0:
|
||||
if hasattr(env, "sampled_x_cmd_buffer"):
|
||||
env.sampled_x_cmd_buffer[:] += 0.2
|
||||
env.commands[:, 0] += 0.2
|
||||
print("command_x:", env.commands[:, 0])
|
||||
if ui_event.action == "lefter" and ui_event.value > 0:
|
||||
if env.commands[:, 2] < 0:
|
||||
env.commands[:, 2] = 0.
|
||||
else:
|
||||
env.commands[:, 2] += 0.4
|
||||
print("command_yaw:", env.commands[:, 2])
|
||||
if ui_event.action == "righter" and ui_event.value > 0:
|
||||
if env.commands[:, 2] > 0:
|
||||
env.commands[:, 2] = 0.
|
||||
else:
|
||||
env.commands[:, 2] -= 0.4
|
||||
print("command_yaw:", env.commands[:, 2])
|
||||
if ui_event.action == "terrain_forward" and ui_event.value > 0:
|
||||
# env.cfg.terrain.curriculum = False
|
||||
env.terrain_levels[:] += 1
|
||||
env.terrain_levels = torch.clip(
|
||||
env.terrain_levels,
|
||||
min= 0,
|
||||
max= env.cfg.terrain.num_rows - 1,
|
||||
)
|
||||
print("before", env.terrain_levels)
|
||||
env.env_origins[:] = env.terrain_origins[env.terrain_levels[:], env.terrain_types[:]]
|
||||
env.reset()
|
||||
agent_model.reset()
|
||||
print("after", env.terrain_levels)
|
||||
if ui_event.action == "terrain_back" and ui_event.value > 0:
|
||||
# env.cfg.terrain.curriculum = False
|
||||
env.terrain_levels[:] -= 1
|
||||
env.terrain_levels = torch.clip(
|
||||
env.terrain_levels,
|
||||
min= 0,
|
||||
max= env.cfg.terrain.num_rows - 1,
|
||||
)
|
||||
print("before", env.terrain_levels)
|
||||
env.env_origins[:] = env.terrain_origins[env.terrain_levels[:], env.terrain_types[:]]
|
||||
env.reset()
|
||||
agent_model.reset()
|
||||
print("after", env.terrain_levels)
|
||||
if ui_event.action == "terrain_right" and ui_event.value > 0:
|
||||
# env.cfg.terrain.curriculum = False
|
||||
env.terrain_types[:] -= 1
|
||||
env.terrain_types = torch.clip(
|
||||
env.terrain_types,
|
||||
min= 0,
|
||||
max= env.cfg.terrain.num_cols - 1,
|
||||
)
|
||||
env.env_origins[:] = env.terrain_origins[env.terrain_levels[:], env.terrain_types[:]]
|
||||
env.reset()
|
||||
agent_model.reset()
|
||||
if ui_event.action == "terrain_left" and ui_event.value > 0:
|
||||
# env.cfg.terrain.curriculum = False
|
||||
env.terrain_types[:] += 1
|
||||
env.terrain_types = torch.clip(
|
||||
env.terrain_types,
|
||||
min= 0,
|
||||
max= env.cfg.terrain.num_cols - 1,
|
||||
)
|
||||
env.env_origins[:] = env.terrain_origins[env.terrain_levels[:], env.terrain_types[:]]
|
||||
env.reset()
|
||||
agent_model.reset()
|
||||
# if (env.contact_forces[robot_index, env.feet_indices, 2] > 200).any():
|
||||
# print("contact_forces:", env.contact_forces[robot_index, env.feet_indices, 2])
|
||||
# if (abs(env.substep_torques[robot_index]) > 35.).any():
|
||||
# exceed_idxs = torch.where(abs(env.substep_torques[robot_index]) > 35.)
|
||||
# print("substep_torques:", exceed_idxs[1], env.substep_torques[robot_index][exceed_idxs[0], exceed_idxs[1]])
|
||||
if env.torque_exceed_count_envstep[robot_index].any():
|
||||
print("substep torque exceed limit ratio",
|
||||
(torch.abs(env.substep_torques[robot_index]) / (env.torque_limits.unsqueeze(0))).max(),
|
||||
"joint index",
|
||||
torch.where((torch.abs(env.substep_torques[robot_index]) > env.torque_limits.unsqueeze(0) * env.cfg.rewards.soft_torque_limit).any(dim= 0))[0],
|
||||
"timestep", i,
|
||||
)
|
||||
env.torque_exceed_count_envstep[robot_index] = 0
|
||||
# if (torch.abs(env.torques[robot_index]) > env.torque_limits.unsqueeze(0) * env.cfg.rewards.soft_torque_limit).any():
|
||||
# print("torque exceed limit ratio",
|
||||
# (torch.abs(env.torques[robot_index]) / (env.torque_limits.unsqueeze(0))).max(),
|
||||
# "joint index",
|
||||
# torch.where((torch.abs(env.torques[robot_index]) > env.torque_limits.unsqueeze(0) * env.cfg.rewards.soft_torque_limit).any(dim= 0))[0],
|
||||
# "timestep", i,
|
||||
# )
|
||||
# dof_exceed_mask = ((env.dof_pos[robot_index] > env.dof_pos_limits[:, 1]) | (env.dof_pos[robot_index] < env.dof_pos_limits[:, 0]))
|
||||
# if dof_exceed_mask.any():
|
||||
# print("dof pos exceed limit: joint index",
|
||||
# torch.where(dof_exceed_mask)[0],
|
||||
# "amount",
|
||||
# torch.maximum(
|
||||
# env.dof_pos[robot_index][dof_exceed_mask] - env.dof_pos_limits[dof_exceed_mask][:, 1],
|
||||
# env.dof_pos_limits[dof_exceed_mask][:, 0] - env.dof_pos[robot_index][dof_exceed_mask],
|
||||
# ),
|
||||
# "dof value:",
|
||||
# env.dof_pos[robot_index][dof_exceed_mask],
|
||||
# "timestep", i,
|
||||
# )
|
||||
|
||||
if i < stop_state_log:
|
||||
if torch.is_tensor(env.cfg.control.action_scale):
|
||||
|
@ -308,6 +456,9 @@ def play(args):
|
|||
"student_action": actions[robot_index, 2].item(),
|
||||
"teacher_action": teacher_actions[robot_index, 2].item(),
|
||||
"reward": rews[robot_index].item(),
|
||||
'all_dof_vel': env.substep_dof_vel[robot_index].mean(-2).cpu().numpy(),
|
||||
'all_dof_torque': env.substep_torques[robot_index].mean(-2).cpu().numpy(),
|
||||
"power": torch.max(torch.sum(env.substep_torques * env.substep_dof_vel, dim= -1), dim= -1)[0][robot_index].item(),
|
||||
}
|
||||
)
|
||||
elif i==stop_state_log:
|
||||
|
@ -323,8 +474,11 @@ def play(args):
|
|||
|
||||
if dones.any():
|
||||
agent_model.reset(dones)
|
||||
print("env dones,{} because has timeout".format("" if env.time_out_buf[dones].any() else " not"))
|
||||
print(infos)
|
||||
if env.time_out_buf[dones].any():
|
||||
print("env dones because of timeout")
|
||||
else:
|
||||
print("env dones because of failure")
|
||||
# print(infos)
|
||||
if i % 100 == 0:
|
||||
print("frame_rate:" , 100/(time.time_ns() - start_time) * 1e9,
|
||||
"command_x:", env.commands[robot_index, 0],
|
||||
|
@ -333,8 +487,45 @@ def play(args):
|
|||
|
||||
if __name__ == '__main__':
|
||||
EXPORT_POLICY = False
|
||||
RECORD_FRAMES = False
|
||||
MOVE_CAMERA = True
|
||||
CAMERA_FOLLOW = True
|
||||
args = get_args()
|
||||
play(args)
|
||||
args = get_args([
|
||||
dict(name= "--slow", type= float, default= 0., help= "slow down the simulation by sleep secs (float) every frame"),
|
||||
dict(name= "--show_teacher", action= "store_true", default= False, help= "show teacher actions"),
|
||||
dict(name= "--no_teacher", action= "store_true", default= False, help= "whether to disable teacher policy when running the script"),
|
||||
dict(name= "--zero_act_until", type= int, default= 0., help= "zero action until this step"),
|
||||
dict(name= "--sample", action= "store_true", default= False, help= "sample actions from policy"),
|
||||
dict(name= "--plot_time", type= int, default= -1, help= "plot states after this time"),
|
||||
dict(name= "--no_throw", action= "store_true", default= False),
|
||||
dict(name= "--load_cfg", action= "store_true", default= False, help= "use the config from the logdir"),
|
||||
dict(name= "--record", action= "store_true", default= False, help= "record frames"),
|
||||
dict(name= "--frames_dir", type= str, default= "images", help= "which folder to store intermediate recorded frames."),
|
||||
])
|
||||
MOVE_CAMERA = (args.num_envs is None)
|
||||
CAMERA_FOLLOW = MOVE_CAMERA
|
||||
RECORD_FRAMES = args.record
|
||||
try:
|
||||
play(args)
|
||||
except KeyboardInterrupt:
|
||||
print("KeyboardInterrupt")
|
||||
finally:
|
||||
if RECORD_FRAMES and args.load_run is not None:
|
||||
import subprocess
|
||||
print("converting frames to video")
|
||||
log_dir = args.load_run if os.path.isabs(args.load_run) \
|
||||
else os.path.join(
|
||||
LEGGED_GYM_ROOT_DIR,
|
||||
"logs",
|
||||
task_registry.get_cfgs(name=args.task)[1].runner.experiment_name,
|
||||
args.load_run,
|
||||
)
|
||||
subprocess.run(["ffmpeg",
|
||||
"-framerate", "50",
|
||||
"-r", "50",
|
||||
"-i", "logs/images/%04d.png",
|
||||
"-c:v", "libx264",
|
||||
"-hide_banner", "-loglevel", "error",
|
||||
os.path.join(log_dir, f"video_{args.checkpoint}.mp4")
|
||||
])
|
||||
print("done converting frames to video, deleting frame images")
|
||||
for f in os.listdir(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", "images")):
|
||||
os.remove(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", args.frames_dir, f))
|
||||
print("done deleting frame images")
|
|
@ -29,12 +29,14 @@
|
|||
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
|
||||
|
||||
import numpy as np
|
||||
np.float = np.float32
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import isaacgym
|
||||
from legged_gym.envs import *
|
||||
from legged_gym.utils import get_args, task_registry
|
||||
from legged_gym.utils import get_args
|
||||
from legged_gym.utils.task_registry import task_registry
|
||||
import torch
|
||||
|
||||
from legged_gym.debugger import break_into_debugger
|
||||
|
|
|
@ -29,7 +29,6 @@
|
|||
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
|
||||
|
||||
from .helpers import class_to_dict, get_load_path, get_args, export_policy_as_jit, set_seed, update_class_from_dict
|
||||
from .task_registry import task_registry
|
||||
from .logger import Logger
|
||||
from .math import *
|
||||
from .terrain.terrain import Terrain
|
|
@ -62,6 +62,12 @@ class Logger:
|
|||
self.plot_process = Process(target=self._plot)
|
||||
self.plot_process.start()
|
||||
|
||||
def plot_additional_states(self):
|
||||
self.plot_vel_process = Process(target=self._plot_vel)
|
||||
self.plot_vel_process.start()
|
||||
self.plot_torque_process = Process(target=self._plot_torque)
|
||||
self.plot_torque_process.start()
|
||||
|
||||
def _plot(self):
|
||||
nb_rows = 4
|
||||
nb_cols = 3
|
||||
|
@ -98,7 +104,6 @@ class Logger:
|
|||
a = axs[0, 2]
|
||||
if log["base_pitch"]:
|
||||
a.plot(time, log["base_pitch"], label='measured')
|
||||
a.plot(time, [-0.75] * len(time), label= 'thresh')
|
||||
# if log["command_yaw"]: a.plot(time, log["command_yaw"], label='commanded')
|
||||
a.set(xlabel='time [s]', ylabel='base ang [rad]', title='Base pitch')
|
||||
a.legend()
|
||||
|
@ -157,6 +162,58 @@ class Logger:
|
|||
a.legend(fontsize = 5)
|
||||
plt.show()
|
||||
|
||||
def _plot_vel(self):
|
||||
log= self.state_log
|
||||
nb_rows = int(np.sqrt(log['all_dof_vel'][0].shape[0]))
|
||||
nb_cols = int(np.ceil(log['all_dof_vel'][0].shape[0] / nb_rows))
|
||||
nb_rows, nb_cols = nb_cols, nb_rows
|
||||
fig, axs = plt.subplots(nb_rows, nb_cols)
|
||||
for key, value in self.state_log.items():
|
||||
time = np.linspace(0, len(value)*self.dt, len(value))
|
||||
break
|
||||
|
||||
# plot joint velocities
|
||||
for i in range(nb_rows):
|
||||
for j in range(nb_cols):
|
||||
if i*nb_cols+j < log['all_dof_vel'][0].shape[0]:
|
||||
a = axs[i][j]
|
||||
a.plot(
|
||||
time,
|
||||
[all_dof_vel[i*nb_cols+j] for all_dof_vel in log['all_dof_vel']],
|
||||
label='measured',
|
||||
)
|
||||
a.set(xlabel='time [s]', ylabel='Velocity [rad/s]', title=f'Joint Velocity {i*nb_cols+j}')
|
||||
a.legend()
|
||||
else:
|
||||
break
|
||||
plt.show()
|
||||
|
||||
def _plot_torque(self):
|
||||
log= self.state_log
|
||||
nb_rows = int(np.sqrt(log['all_dof_torque'][0].shape[0]))
|
||||
nb_cols = int(np.ceil(log['all_dof_torque'][0].shape[0] / nb_rows))
|
||||
nb_rows, nb_cols = nb_cols, nb_rows
|
||||
fig, axs = plt.subplots(nb_rows, nb_cols)
|
||||
for key, value in self.state_log.items():
|
||||
time = np.linspace(0, len(value)*self.dt, len(value))
|
||||
break
|
||||
|
||||
# plot joint torques
|
||||
for i in range(nb_rows):
|
||||
for j in range(nb_cols):
|
||||
if i*nb_cols+j < log['all_dof_torque'][0].shape[0]:
|
||||
a = axs[i][j]
|
||||
a.plot(
|
||||
time,
|
||||
[all_dof_torque[i*nb_cols+j] for all_dof_torque in log['all_dof_torque']],
|
||||
label='measured',
|
||||
)
|
||||
a.set(xlabel='time [s]', ylabel='Torque [Nm]', title=f'Joint Torque {i*nb_cols+j}')
|
||||
a.legend()
|
||||
else:
|
||||
break
|
||||
plt.show()
|
||||
|
||||
def print_rewards(self):
|
||||
print("Average rewards per second:")
|
||||
for key, values in self.rew_log.items():
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,4 +1,5 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from numpy.random import choice
|
||||
from scipy import interpolate
|
||||
|
||||
|
@ -13,13 +14,14 @@ class TerrainPerlin:
|
|||
self.env_width = cfg.terrain_width
|
||||
self.xSize = cfg.terrain_length * cfg.num_rows # int(cfg.horizontal_scale * cfg.tot_cols)
|
||||
self.ySize = cfg.terrain_width * cfg.num_cols # int(cfg.horizontal_scale * cfg.tot_rows)
|
||||
self.tot_cols = int(self.xSize / cfg.horizontal_scale)
|
||||
self.tot_rows = int(self.ySize / cfg.horizontal_scale)
|
||||
self.tot_rows = int(self.xSize / cfg.horizontal_scale)
|
||||
self.tot_cols = int(self.ySize / cfg.horizontal_scale)
|
||||
assert(self.xSize == cfg.horizontal_scale * self.tot_rows and self.ySize == cfg.horizontal_scale * self.tot_cols)
|
||||
self.heightsamples_float = self.generate_fractal_noise_2d(self.xSize, self.ySize, self.tot_rows, self.tot_cols, **cfg.TerrainPerlin_kwargs)
|
||||
# self.heightsamples_float[self.tot_cols//2 - 100:, :] += 100000
|
||||
# self.heightsamples_float[self.tot_cols//2 - 40: self.tot_cols//2 + 40, :] = np.mean(self.heightsamples_float)
|
||||
self.heightsamples = (self.heightsamples_float * (1 / cfg.vertical_scale)).astype(np.int16)
|
||||
self.heightfield_raw_pyt = torch.tensor(self.heightsamples, device= "cpu")
|
||||
|
||||
|
||||
print("Terrain heightsamples shape: ", self.heightsamples.shape)
|
||||
|
@ -111,3 +113,31 @@ class TerrainPerlin:
|
|||
int(origin_y / self.cfg.horizontal_scale),
|
||||
] * self.cfg.vertical_scale,
|
||||
]
|
||||
self.heightfield_raw_pyt = torch.from_numpy(self.heightsamples).to(device= self.device).float()
|
||||
|
||||
def in_terrain_range(self, pos):
|
||||
""" Check if the given position still have terrain underneath. (same x/y, but z is different)
|
||||
pos: (batch_size, 3) torch.Tensor
|
||||
"""
|
||||
return torch.logical_and(
|
||||
pos[..., :2] >= 0,
|
||||
pos[..., :2] < torch.tensor([self.xSize, self.ySize], device= self.device),
|
||||
).all(dim= -1)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_terrain_heights(self, points):
|
||||
""" Get the terrain heights below the given points """
|
||||
points_shape = points.shape
|
||||
points = points.view(-1, 3)
|
||||
points_x_px = (points[:, 0] / self.cfg.horizontal_scale).to(int)
|
||||
points_y_px = (points[:, 1] / self.cfg.horizontal_scale).to(int)
|
||||
out_of_range_mask = torch.logical_or(
|
||||
torch.logical_or(points_x_px < 0, points_x_px >= self.heightfield_raw_pyt.shape[0]),
|
||||
torch.logical_or(points_y_px < 0, points_y_px >= self.heightfield_raw_pyt.shape[1]),
|
||||
)
|
||||
points_x_px = torch.clip(points_x_px, 0, self.heightfield_raw_pyt.shape[0] - 1)
|
||||
points_y_px = torch.clip(points_y_px, 0, self.heightfield_raw_pyt.shape[1] - 1)
|
||||
heights = self.heightfield_raw_pyt[points_x_px, points_y_px] * self.cfg.vertical_scale
|
||||
heights[out_of_range_mask] = - torch.inf
|
||||
heights = heights.view(points_shape[:-1])
|
||||
return heights
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.random import choice
|
||||
from scipy import interpolate
|
||||
|
||||
|
@ -162,6 +163,65 @@ class Terrain:
|
|||
y2 = int((self.env_width/2. + 1) / terrain.horizontal_scale)
|
||||
env_origin_z = np.max(terrain.height_field_raw[x1:x2, y1:y2])*terrain.vertical_scale
|
||||
self.env_origins[i, j] = [env_origin_x, env_origin_y, env_origin_z]
|
||||
|
||||
def _create_heightfield(self, gym, sim, device= "cpu"):
|
||||
""" Adds a heightfield terrain to the simulation, sets parameters based on the cfg.
|
||||
"""
|
||||
hf_params = gym.HeightFieldParams()
|
||||
hf_params.column_scale = self.cfg.horizontal_scale
|
||||
hf_params.row_scale = self.cfg.horizontal_scale
|
||||
hf_params.vertical_scale = self.cfg.vertical_scale
|
||||
hf_params.nbRows = self.tot_cols
|
||||
hf_params.nbColumns = self.tot_rows
|
||||
hf_params.transform.p.x = -self.cfg.border_size
|
||||
hf_params.transform.p.y = -self.cfg.border_size
|
||||
hf_params.transform.p.z = 0.0
|
||||
hf_params.static_friction = self.cfg.static_friction
|
||||
hf_params.dynamic_friction = self.cfg.dynamic_friction
|
||||
hf_params.restitution = self.cfg.restitution
|
||||
|
||||
self.gym.add_heightfield(sim, self.heightsamples, hf_params)
|
||||
|
||||
def _create_trimesh(self, gym, sim, device= "cpu"):
|
||||
""" Adds a triangle mesh terrain to the simulation, sets parameters based on the cfg.
|
||||
# """
|
||||
tm_params = gym.TriangleMeshParams()
|
||||
tm_params.nb_vertices = self.vertices.shape[0]
|
||||
tm_params.nb_triangles = self.triangles.shape[0]
|
||||
|
||||
tm_params.transform.p.x = -self.cfg.border_size
|
||||
tm_params.transform.p.y = -self.cfg.border_size
|
||||
tm_params.transform.p.z = 0.0
|
||||
tm_params.static_friction = self.cfg.static_friction
|
||||
tm_params.dynamic_friction = self.cfg.dynamic_friction
|
||||
tm_params.restitution = self.cfg.restitution
|
||||
self.gym.add_triangle_mesh(self.sim, self.vertices.flatten(order='C'), self.triangles.flatten(order='C'), tm_params)
|
||||
|
||||
def add_terrain_to_sim(self, gym, sim, device= "cpu"):
|
||||
if self.type == "heightfield":
|
||||
self._create_heightfield(gym, sim, device)
|
||||
elif self.type == "trimesh":
|
||||
self._create_trimesh(gym, sim, device)
|
||||
else:
|
||||
raise NotImplementedError("Terrain type {} not implemented".format(self.type))
|
||||
|
||||
def get_terrain_heights(self, points):
|
||||
""" Return the z coordinate of the terrain where just below the given points. """
|
||||
num_robots = points.shape[0]
|
||||
points += self.cfg.border_size
|
||||
points = (points/self.cfg.horizontal_scale).long()
|
||||
px = points[:, :, 0].view(-1)
|
||||
py = points[:, :, 1].view(-1)
|
||||
px = torch.clip(px, 0, self.heightsamples.shape[0]-2)
|
||||
py = torch.clip(py, 0, self.heightsamples.shape[1]-2)
|
||||
|
||||
heights1 = self.heightsamples[px, py]
|
||||
heights2 = self.heightsamples[px+1, py]
|
||||
heights3 = self.heightsamples[px, py+1]
|
||||
heights = torch.min(heights1, heights2)
|
||||
heights = torch.min(heights, heights3)
|
||||
|
||||
return heights.view(num_robots, -1) * self.cfg.vertical_scale
|
||||
|
||||
def gap_terrain(terrain, gap_size, platform_size=1.):
|
||||
gap_size = int(gap_size / terrain.horizontal_scale)
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
# then you can import it like `from legged_gym.utils.webviewer import WebViewer``
|
||||
from .webviewer import WebViewer
|
|
@ -0,0 +1,80 @@
|
|||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
|
||||
<style>
|
||||
html, body {
|
||||
width: 100%; height: 100%;
|
||||
margin: 0; overflow: hidden; display: block;
|
||||
background-color: #000;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div>
|
||||
<canvas id="canvas" tabindex='1'></canvas>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
var canvas, context, image;
|
||||
|
||||
function sendInputRequest(data){
|
||||
let xmlRequest = new XMLHttpRequest();
|
||||
xmlRequest.open("POST", "{{ url_for('_route_input_event') }}", true);
|
||||
xmlRequest.setRequestHeader("Content-Type", "application/json");
|
||||
xmlRequest.send(JSON.stringify(data));
|
||||
}
|
||||
|
||||
window.onload = function(){
|
||||
canvas = document.getElementById("canvas");
|
||||
context = canvas.getContext('2d');
|
||||
image = new Image();
|
||||
image.src = "{{ url_for('_route_stream') }}";
|
||||
image1 = new Image();
|
||||
image1.src = "{{ url_for('_route_stream_depth') }}";
|
||||
|
||||
canvas.width = window.innerWidth;
|
||||
canvas.height = window.innerHeight;
|
||||
|
||||
window.addEventListener('resize', function(){
|
||||
canvas.width = window.innerWidth;
|
||||
canvas.height = window.innerHeight;
|
||||
}, false);
|
||||
|
||||
window.setInterval(function(){
|
||||
let ratio = image.naturalWidth / image.naturalHeight;
|
||||
context.drawImage(image, 0, 0, canvas.width, canvas.width / ratio);
|
||||
let imageHeight = canvas.width / ratio;
|
||||
context.drawImage(image1, 0, imageHeight, canvas.width, canvas.width / image1.naturalWidth * image1.naturalHeight);
|
||||
}, 50);
|
||||
|
||||
canvas.addEventListener('keydown', function(event){
|
||||
if(event.keyCode != 18)
|
||||
sendInputRequest({key: event.keyCode});
|
||||
}, false);
|
||||
|
||||
canvas.addEventListener('mousemove', function(event){
|
||||
if(event.buttons){
|
||||
let data = {dx: event.movementX, dy: event.movementY};
|
||||
if(event.altKey && event.buttons == 1){
|
||||
data.key = 18;
|
||||
data.mouse = "left";
|
||||
}
|
||||
else if(event.buttons == 2)
|
||||
data.mouse = "right";
|
||||
else if(event.buttons == 4)
|
||||
data.mouse = "middle";
|
||||
else
|
||||
return;
|
||||
sendInputRequest(data);
|
||||
}
|
||||
}, false);
|
||||
|
||||
canvas.addEventListener('wheel', function(event){
|
||||
sendInputRequest({mouse: "wheel", dz: Math.sign(event.deltaY)});
|
||||
}, false);
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
|
@ -0,0 +1,442 @@
|
|||
from typing import List, Optional
|
||||
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from legged_gym import LEGGED_GYM_ROOT_DIR
|
||||
import os
|
||||
|
||||
try:
|
||||
import flask
|
||||
except ImportError:
|
||||
flask = None
|
||||
|
||||
try:
|
||||
import imageio
|
||||
import isaacgym
|
||||
import isaacgym.torch_utils as torch_utils
|
||||
from isaacgym import gymapi
|
||||
except ImportError:
|
||||
imageio = None
|
||||
isaacgym = None
|
||||
torch_utils = None
|
||||
gymapi = None
|
||||
|
||||
|
||||
def cartesian_to_spherical(x, y, z):
|
||||
r = np.sqrt(x**2 + y**2 + z**2)
|
||||
theta = np.arccos(z/r) if r != 0 else 0
|
||||
phi = np.arctan2(y, x)
|
||||
return r, theta, phi
|
||||
|
||||
def spherical_to_cartesian(r, theta, phi):
|
||||
x = r * np.sin(theta) * np.cos(phi)
|
||||
y = r * np.sin(theta) * np.sin(phi)
|
||||
z = r * np.cos(theta)
|
||||
return x, y, z
|
||||
|
||||
class WebViewer:
|
||||
def __init__(self, host: str = "127.0.0.1", port: int = 5000) -> None:
|
||||
"""
|
||||
Web viewer for Isaac Gym
|
||||
|
||||
:param host: Host address (default: "127.0.0.1")
|
||||
:type host: str
|
||||
:param port: Port number (default: 5000)
|
||||
:type port: int
|
||||
"""
|
||||
self._app = flask.Flask(__name__)
|
||||
self._app.add_url_rule("/", view_func=self._route_index)
|
||||
self._app.add_url_rule("/_route_stream", view_func=self._route_stream)
|
||||
self._app.add_url_rule("/_route_stream_depth", view_func=self._route_stream_depth)
|
||||
self._app.add_url_rule("/_route_input_event", view_func=self._route_input_event, methods=["POST"])
|
||||
|
||||
self._log = logging.getLogger('werkzeug')
|
||||
self._log.disabled = True
|
||||
self._app.logger.disabled = True
|
||||
|
||||
self._image = None
|
||||
self._image_depth = None
|
||||
self._camera_id = 0
|
||||
self._camera_type = gymapi.IMAGE_COLOR
|
||||
# get from self._env and stream to webviewer (expected shape: num_envs, 1, height, width)
|
||||
self._depth_image_buffer_name = "forward_depth_output"
|
||||
self._notified = False
|
||||
self._wait_for_page = True
|
||||
self._pause_stream = False
|
||||
self._event_load = threading.Event()
|
||||
self._event_stream = threading.Event()
|
||||
self._event_stream_depth = threading.Event()
|
||||
|
||||
# start server
|
||||
self._thread = threading.Thread(target=lambda: \
|
||||
self._app.run(host=host, port=port, debug=False, use_reloader=False), daemon=True)
|
||||
self._thread.start()
|
||||
print(f"\nStarting web viewer on http://{host}:{port}/\n")
|
||||
|
||||
def _route_index(self) -> 'flask.Response':
|
||||
"""Render the web page
|
||||
|
||||
:return: Flask response
|
||||
:rtype: flask.Response
|
||||
"""
|
||||
with open(os.path.join(os.path.dirname(__file__), "webviewer.html"), 'r', encoding='utf-8') as file:
|
||||
template = file.read()
|
||||
self._event_load.set()
|
||||
return flask.render_template_string(template)
|
||||
|
||||
def _route_stream(self) -> 'flask.Response':
|
||||
"""Stream the image to the web page
|
||||
|
||||
:return: Flask response
|
||||
:rtype: flask.Response
|
||||
"""
|
||||
return flask.Response(self._stream(), mimetype='multipart/x-mixed-replace; boundary=frame')
|
||||
|
||||
def _route_stream_depth(self) -> 'flask.Response':
|
||||
return flask.Response(self._stream_depth(), mimetype='multipart/x-mixed-replace; boundary=frame')
|
||||
|
||||
def _route_input_event(self) -> 'flask.Response':
|
||||
|
||||
# get keyboard and mouse inputs
|
||||
data = flask.request.get_json()
|
||||
key, mouse = data.get("key", None), data.get("mouse", None)
|
||||
dx, dy, dz = data.get("dx", None), data.get("dy", None), data.get("dz", None)
|
||||
|
||||
transform = self._gym.get_camera_transform(self._sim,
|
||||
self._envs[self._camera_id],
|
||||
self._cameras[self._camera_id])
|
||||
|
||||
# zoom in/out
|
||||
if mouse == "wheel":
|
||||
# compute zoom vector
|
||||
r, theta, phi = cartesian_to_spherical(*self.cam_pos_rel)
|
||||
r += 0.05 * dz
|
||||
self.cam_pos_rel = spherical_to_cartesian(r, theta, phi)
|
||||
|
||||
# orbit camera
|
||||
elif mouse == "left":
|
||||
# convert mouse movement to angle
|
||||
dx *= 0.2 * math.pi / 180
|
||||
dy *= 0.2 * math.pi / 180
|
||||
|
||||
r, theta, phi = cartesian_to_spherical(*self.cam_pos_rel)
|
||||
theta -= dy
|
||||
phi -= dx
|
||||
self.cam_pos_rel = spherical_to_cartesian(r, theta, phi)
|
||||
|
||||
# pan camera
|
||||
elif mouse == "right":
|
||||
# convert mouse movement to angle
|
||||
dx *= -0.2 * math.pi / 180
|
||||
dy *= -0.2 * math.pi / 180
|
||||
|
||||
r, theta, phi = cartesian_to_spherical(*self.cam_pos_rel)
|
||||
theta += dy
|
||||
phi += dx
|
||||
self.cam_pos_rel = spherical_to_cartesian(r, theta, phi)
|
||||
|
||||
elif key == 219: # prev
|
||||
self._camera_id = (self._camera_id-1) % self._env.num_envs
|
||||
return flask.Response(status=200)
|
||||
|
||||
elif key == 221: # next
|
||||
self._camera_id = (self._camera_id+1) % self._env.num_envs
|
||||
return flask.Response(status=200)
|
||||
|
||||
# pause stream (V: 86)
|
||||
elif key == 86:
|
||||
self._pause_stream = not self._pause_stream
|
||||
return flask.Response(status=200)
|
||||
|
||||
# change image type (T: 84)
|
||||
elif key == 84:
|
||||
if self._camera_type == gymapi.IMAGE_COLOR:
|
||||
self._camera_type = gymapi.IMAGE_DEPTH
|
||||
elif self._camera_type == gymapi.IMAGE_DEPTH:
|
||||
self._camera_type = gymapi.IMAGE_COLOR
|
||||
return flask.Response(status=200)
|
||||
|
||||
else:
|
||||
return flask.Response(status=200)
|
||||
|
||||
return flask.Response(status=200)
|
||||
|
||||
def _stream(self) -> bytes:
|
||||
"""Format the image to be streamed
|
||||
|
||||
:return: Image encoded as Content-Type
|
||||
:rtype: bytes
|
||||
"""
|
||||
while True:
|
||||
self._event_stream.wait()
|
||||
|
||||
# prepare image
|
||||
image = imageio.imwrite("<bytes>", self._image, format="JPEG")
|
||||
|
||||
# stream image
|
||||
yield (b'--frame\r\n'
|
||||
b'Content-Type: image/jpeg\r\n\r\n' + image + b'\r\n')
|
||||
|
||||
self._event_stream.clear()
|
||||
self._notified = False
|
||||
|
||||
def _stream_depth(self) -> bytes:
|
||||
while self._env.cfg.viewer.stream_depth:
|
||||
self._event_stream_depth.wait()
|
||||
|
||||
# prepare image
|
||||
image = imageio.imwrite("<bytes>", self._image_depth, format="JPEG")
|
||||
|
||||
# stream image
|
||||
yield (b'--frame\r\n'
|
||||
b'Content-Type: image/jpeg\r\n\r\n' + image + b'\r\n')
|
||||
|
||||
self._event_stream_depth.clear()
|
||||
|
||||
def attach_view_camera(self, i, env_handle, actor_handle, root_pos):
|
||||
if True:
|
||||
camera_props = gymapi.CameraProperties()
|
||||
camera_props.width = 960
|
||||
camera_props.height = 540
|
||||
# camera_props.enable_tensors = True
|
||||
# camera_props.horizontal_fov = camera_horizontal_fov
|
||||
|
||||
camera_handle = self._gym.create_camera_sensor(env_handle, camera_props)
|
||||
self._cameras.append(camera_handle)
|
||||
|
||||
cam_pos = root_pos + np.array([0, 1, 0.5])
|
||||
self._gym.set_camera_location(camera_handle, env_handle, gymapi.Vec3(*cam_pos), gymapi.Vec3(*root_pos))
|
||||
|
||||
def setup(self, env) -> None:
|
||||
"""Setup the web viewer
|
||||
|
||||
:param gym: The gym
|
||||
:type gym: isaacgym.gymapi.Gym
|
||||
:param sim: Simulation handle
|
||||
:type sim: isaacgym.gymapi.Sim
|
||||
:param envs: Environment handles
|
||||
:type envs: list of ints
|
||||
:param cameras: Camera handles
|
||||
:type cameras: list of ints
|
||||
"""
|
||||
self._gym = env.gym
|
||||
self._sim = env.sim
|
||||
self._envs = env.envs
|
||||
self._cameras = []
|
||||
self._env = env
|
||||
self.cam_pos_rel = np.array([0, 2, 1])
|
||||
for i in range(self._env.num_envs):
|
||||
root_pos = self._env.root_states[i, :3].cpu().numpy()
|
||||
self.attach_view_camera(i, self._envs[i], self._env.actor_handles[i], root_pos)
|
||||
|
||||
def render(self,
|
||||
fetch_results: bool = True,
|
||||
step_graphics: bool = True,
|
||||
render_all_camera_sensors: bool = True,
|
||||
wait_for_page_load: bool = True) -> None:
|
||||
"""Render and get the image from the current camera
|
||||
|
||||
This function must be called after the simulation is stepped (post_physics_step).
|
||||
The following Isaac Gym functions are called before get the image.
|
||||
Their calling can be skipped by setting the corresponding argument to False
|
||||
|
||||
- fetch_results
|
||||
- step_graphics
|
||||
- render_all_camera_sensors
|
||||
|
||||
:param fetch_results: Call Gym.fetch_results method (default: True)
|
||||
:type fetch_results: bool
|
||||
:param step_graphics: Call Gym.step_graphics method (default: True)
|
||||
:type step_graphics: bool
|
||||
:param render_all_camera_sensors: Call Gym.render_all_camera_sensors method (default: True)
|
||||
:type render_all_camera_sensors: bool
|
||||
:param wait_for_page_load: Wait for the page to load (default: True)
|
||||
:type wait_for_page_load: bool
|
||||
"""
|
||||
# wait for page to load
|
||||
if self._wait_for_page:
|
||||
if wait_for_page_load:
|
||||
if not self._event_load.is_set():
|
||||
print("Waiting for web page to begin loading...")
|
||||
self._event_load.wait()
|
||||
self._event_load.clear()
|
||||
self._wait_for_page = False
|
||||
|
||||
# pause stream
|
||||
if self._pause_stream:
|
||||
return
|
||||
|
||||
if self._notified:
|
||||
return
|
||||
|
||||
# isaac gym API
|
||||
if fetch_results:
|
||||
self._gym.fetch_results(self._sim, True)
|
||||
if step_graphics:
|
||||
self._gym.step_graphics(self._sim)
|
||||
if render_all_camera_sensors:
|
||||
self._gym.render_all_camera_sensors(self._sim)
|
||||
|
||||
# get image
|
||||
image = self._gym.get_camera_image(self._sim,
|
||||
self._envs[self._camera_id],
|
||||
self._cameras[self._camera_id],
|
||||
self._camera_type)
|
||||
if self._camera_type == gymapi.IMAGE_COLOR:
|
||||
self._image = image.reshape(image.shape[0], -1, 4)[..., :3]
|
||||
elif self._camera_type == gymapi.IMAGE_DEPTH:
|
||||
self._image = -image.reshape(image.shape[0], -1)
|
||||
minimum = 0 if np.isinf(np.min(self._image)) else np.min(self._image)
|
||||
maximum = 5 if np.isinf(np.max(self._image)) else np.max(self._image)
|
||||
self._image = np.clip(1 - (self._image - minimum) / (maximum - minimum), 0, 1)
|
||||
self._image = np.uint8(255 * self._image)
|
||||
else:
|
||||
raise ValueError("Unsupported camera type")
|
||||
|
||||
if self._env.cfg.viewer.stream_depth:
|
||||
self._image_depth = getattr(self._env, self._depth_image_buffer_name)[self._camera_id, -1].cpu().numpy() # expected value range: [0, 1]
|
||||
self._image_depth = np.uint8(255 * self._image_depth)
|
||||
|
||||
root_pos = self._env.root_states[self._camera_id, :3].cpu().numpy()
|
||||
cam_pos = root_pos + self.cam_pos_rel
|
||||
self._gym.set_camera_location(self._cameras[self._camera_id], self._envs[self._camera_id], gymapi.Vec3(*cam_pos), gymapi.Vec3(*root_pos))
|
||||
|
||||
# notify stream thread
|
||||
self._event_stream.set()
|
||||
if self._env.cfg.viewer.stream_depth:
|
||||
self._event_stream_depth.set()
|
||||
self._notified = True
|
||||
|
||||
|
||||
def ik(jacobian_end_effector: torch.Tensor,
|
||||
current_position: torch.Tensor,
|
||||
current_orientation: torch.Tensor,
|
||||
goal_position: torch.Tensor,
|
||||
goal_orientation: Optional[torch.Tensor] = None,
|
||||
damping_factor: float = 0.05,
|
||||
squeeze_output: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Inverse kinematics using damped least squares method
|
||||
|
||||
:param jacobian_end_effector: End effector's jacobian
|
||||
:type jacobian_end_effector: torch.Tensor
|
||||
:param current_position: End effector's current position
|
||||
:type current_position: torch.Tensor
|
||||
:param current_orientation: End effector's current orientation
|
||||
:type current_orientation: torch.Tensor
|
||||
:param goal_position: End effector's goal position
|
||||
:type goal_position: torch.Tensor
|
||||
:param goal_orientation: End effector's goal orientation (default: None)
|
||||
:type goal_orientation: torch.Tensor or None
|
||||
:param damping_factor: Damping factor (default: 0.05)
|
||||
:type damping_factor: float
|
||||
:param squeeze_output: Squeeze output (default: True)
|
||||
:type squeeze_output: bool
|
||||
|
||||
:return: Change in joint angles
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if goal_orientation is None:
|
||||
goal_orientation = current_orientation
|
||||
|
||||
# compute error
|
||||
q = torch_utils.quat_mul(goal_orientation, torch_utils.quat_conjugate(current_orientation))
|
||||
error = torch.cat([goal_position - current_position, # position error
|
||||
q[:, 0:3] * torch.sign(q[:, 3]).unsqueeze(-1)], # orientation error
|
||||
dim=-1).unsqueeze(-1)
|
||||
|
||||
# solve damped least squares (dO = J.T * V)
|
||||
transpose = torch.transpose(jacobian_end_effector, 1, 2)
|
||||
lmbda = torch.eye(6, device=jacobian_end_effector.device) * (damping_factor ** 2)
|
||||
if squeeze_output:
|
||||
return (transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ error).squeeze(dim=2)
|
||||
else:
|
||||
return transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ error
|
||||
|
||||
def print_arguments(args):
|
||||
print("")
|
||||
print("Arguments")
|
||||
for a in args.__dict__:
|
||||
print(f" |-- {a}: {args.__getattribute__(a)}")
|
||||
|
||||
def print_asset_options(asset_options: 'isaacgym.gymapi.AssetOptions', asset_name: str = ""):
|
||||
attrs = ["angular_damping", "armature", "collapse_fixed_joints", "convex_decomposition_from_submeshes",
|
||||
"default_dof_drive_mode", "density", "disable_gravity", "fix_base_link", "flip_visual_attachments",
|
||||
"linear_damping", "max_angular_velocity", "max_linear_velocity", "mesh_normal_mode", "min_particle_mass",
|
||||
"override_com", "override_inertia", "replace_cylinder_with_capsule", "tendon_limit_stiffness", "thickness",
|
||||
"use_mesh_materials", "use_physx_armature", "vhacd_enabled"] # vhacd_params
|
||||
print("\nAsset options{}".format(f" ({asset_name})" if asset_name else ""))
|
||||
for attr in attrs:
|
||||
print(" |-- {}: {}".format(attr, getattr(asset_options, attr) if hasattr(asset_options, attr) else "--"))
|
||||
# vhacd attributes
|
||||
if attr == "vhacd_enabled" and hasattr(asset_options, attr) and getattr(asset_options, attr):
|
||||
vhacd_attrs = ["alpha", "beta", "concavity", "convex_hull_approximation", "convex_hull_downsampling",
|
||||
"max_convex_hulls", "max_num_vertices_per_ch", "min_volume_per_ch", "mode", "ocl_acceleration",
|
||||
"pca", "plane_downsampling", "project_hull_vertices", "resolution"]
|
||||
print(" |-- vhacd_params:")
|
||||
for vhacd_attr in vhacd_attrs:
|
||||
print(" | |-- {}: {}".format(vhacd_attr, getattr(asset_options.vhacd_params, vhacd_attr) \
|
||||
if hasattr(asset_options.vhacd_params, vhacd_attr) else "--"))
|
||||
|
||||
def print_sim_components(gym, sim):
|
||||
print("")
|
||||
print("Sim components")
|
||||
print(" |-- env count:", gym.get_env_count(sim))
|
||||
print(" |-- actor count:", gym.get_sim_actor_count(sim))
|
||||
print(" |-- rigid body count:", gym.get_sim_rigid_body_count(sim))
|
||||
print(" |-- joint count:", gym.get_sim_joint_count(sim))
|
||||
print(" |-- dof count:", gym.get_sim_dof_count(sim))
|
||||
print(" |-- force sensor count:", gym.get_sim_force_sensor_count(sim))
|
||||
|
||||
def print_env_components(gym, env):
|
||||
print("")
|
||||
print("Env components")
|
||||
print(" |-- actor count:", gym.get_actor_count(env))
|
||||
print(" |-- rigid body count:", gym.get_env_rigid_body_count(env))
|
||||
print(" |-- joint count:", gym.get_env_joint_count(env))
|
||||
print(" |-- dof count:", gym.get_env_dof_count(env))
|
||||
|
||||
def print_actor_components(gym, env, actor):
|
||||
print("")
|
||||
print("Actor components")
|
||||
print(" |-- rigid body count:", gym.get_actor_rigid_body_count(env, actor))
|
||||
print(" |-- joint count:", gym.get_actor_joint_count(env, actor))
|
||||
print(" |-- dof count:", gym.get_actor_dof_count(env, actor))
|
||||
print(" |-- actuator count:", gym.get_actor_actuator_count(env, actor))
|
||||
print(" |-- rigid shape count:", gym.get_actor_rigid_shape_count(env, actor))
|
||||
print(" |-- soft body count:", gym.get_actor_soft_body_count(env, actor))
|
||||
print(" |-- tendon count:", gym.get_actor_tendon_count(env, actor))
|
||||
|
||||
def print_dof_properties(gymapi, props):
|
||||
print("")
|
||||
print("DOF properties")
|
||||
print(" |-- hasLimits:", props["hasLimits"])
|
||||
print(" |-- lower:", props["lower"])
|
||||
print(" |-- upper:", props["upper"])
|
||||
print(" |-- driveMode:", props["driveMode"])
|
||||
print(" | |-- {}: gymapi.DOF_MODE_NONE".format(int(gymapi.DOF_MODE_NONE)))
|
||||
print(" | |-- {}: gymapi.DOF_MODE_POS".format(int(gymapi.DOF_MODE_POS)))
|
||||
print(" | |-- {}: gymapi.DOF_MODE_VEL".format(int(gymapi.DOF_MODE_VEL)))
|
||||
print(" | |-- {}: gymapi.DOF_MODE_EFFORT".format(int(gymapi.DOF_MODE_EFFORT)))
|
||||
print(" |-- stiffness:", props["stiffness"])
|
||||
print(" |-- damping:", props["damping"])
|
||||
print(" |-- velocity (max):", props["velocity"])
|
||||
print(" |-- effort (max):", props["effort"])
|
||||
print(" |-- friction:", props["friction"])
|
||||
print(" |-- armature:", props["armature"])
|
||||
|
||||
def print_links_and_dofs(gym, asset):
|
||||
link_dict = gym.get_asset_rigid_body_dict(asset)
|
||||
dof_dict = gym.get_asset_dof_dict(asset)
|
||||
|
||||
print("")
|
||||
print("Links")
|
||||
for k in link_dict:
|
||||
print(f" |-- {k}: {link_dict[k]}")
|
||||
print("DOFs")
|
||||
for k in dof_dict:
|
||||
print(f" |-- {k}: {dof_dict[k]}")
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904144240.urdf
Executable file
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904144240.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904182446.urdf
Executable file
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904182446.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904193628.urdf
Executable file
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904193628.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904193644.urdf
Executable file
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904193644.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904193653.urdf
Executable file
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904193653.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904193705.urdf
Executable file
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904193705.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904193710.urdf
Executable file
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904193710.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904193711.urdf
Executable file
1189
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230904193711.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112047.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112047.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112143.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112143.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112217.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112217.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112735.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112735.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112749.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112749.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112752.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112752.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112944.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905112944.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113000.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113000.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113059.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113059.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113633.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113633.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113754.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113754.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113755.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113755.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113839.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113839.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113840.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905113840.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905114759.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905114759.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905114819.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905114819.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115021.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115021.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115022.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115022.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115040.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115040.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115215.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115215.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115226.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115226.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115300.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115300.urdf
Executable file
File diff suppressed because it is too large
Load Diff
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115313.urdf
Executable file
1217
legged_gym/resources/robots/go2/urdf/.history/go2_description_20230905115313.urdf
Executable file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
After Width: | Height: | Size: 213 KiB |
Binary file not shown.
After Width: | Height: | Size: 200 KiB |
File diff suppressed because it is too large
Load Diff
|
@ -12,6 +12,7 @@ setup(
|
|||
install_requires=['isaacgym',
|
||||
'rsl-rl',
|
||||
'matplotlib',
|
||||
'tensorboard',
|
||||
'tensorboardX',
|
||||
'debugpy']
|
||||
)
|
|
@ -1,9 +1,9 @@
|
|||
# Deploy the model on your real Unitree robot
|
||||
# Deploy the model on your real Unitree Go1 robot
|
||||
|
||||
This version shows an example of how to deploy the model on the Unittree Go1 robot (with Nvidia Jetson NX).
|
||||
|
||||
## Install dependencies on Go1
|
||||
To deploy the trained model on Go1, please set up a folder on your robot, e.g. `parkour`, and copy the `rsl_rl` folder to it. Then, extract the zip files in `go1_ckpts` to the `parkour` folder. Finally, copy all the files in `onboard_script` to the `parkour` folder.
|
||||
To deploy the trained model on Go1, please set up a folder on your robot, e.g. `parkour`, and copy the `rsl_rl` folder to it. Then, extract the zip files in `go1_ckpts` to the `parkour` folder. Finally, copy all the files in `onboard_script/go1` to the `parkour` folder.
|
||||
|
||||
1. Install ROS and the [unitree ros package for Go1](https://github.com/Tsinghua-MARS-Lab/unitree_ros_real.git) and follow the instructions to set up the robot on branch `go1`
|
||||
|
||||
|
@ -29,7 +29,7 @@ To deploy the trained model on Go1, please set up a folder on your robot, e.g. `
|
|||
4. 3D print the camera mount for Go1 using the step file in `go1_ckpts/go1_camMount_30Down.step`. Use the two pairs of screw holes to mount the Intel Realsense D435i camera on the robot, as shown in the picture below.
|
||||
|
||||
<p align="center">
|
||||
<img src="images/go1_camMount_30Down.png" width="50%"/>
|
||||
<img src="../images/go1_camMount_30Down.png" width="50%"/>
|
||||
</p>
|
||||
|
||||
## Run the model on Go1
|
|
@ -0,0 +1,73 @@
|
|||
# Deploy the model on your real Unitree Go2 robot
|
||||
|
||||
This file shows an example of how to deploy the model on the Unittree Go2 robot (with Nvidia Jetson NX).
|
||||
|
||||
The code is a quick start of the deployment and fit the simulation as much as possible. You can modify the code to fit your own project.
|
||||
|
||||
## Install dependencies on Go2
|
||||
|
||||
1. Take Nvidia Jetson Orin as an exmaple, make sure your JetPack and related software are up-to-date.
|
||||
|
||||
2. Install ROS and the [unitree ros package for Go2](https://support.unitree.com/home/en/developer/ROS2_service)
|
||||
|
||||
3. Set up a folder on your robot for this project, e.g. `parkour`. Then `cd` into it.
|
||||
|
||||
4. Create a python virtual env and install the dependencies.
|
||||
|
||||
- Install pytorch on a Python 3 environment.
|
||||
```bash
|
||||
sudo apt-get install python3-pip python3-dev python3-venv
|
||||
python3 -m venv parkour_venv
|
||||
source parkour_venv/bin/activate
|
||||
```
|
||||
|
||||
- Download the pip wheel file from [here](https://forums.developer.nvidia.com/t/pytorch-for-jetson/72048) with v1.10.0. Then install it with
|
||||
```bash
|
||||
pip install torch-1.10.0-cp36-cp36m-linux_aarch64.whl
|
||||
```
|
||||
|
||||
- Install `ros2-numpy` from [here](https://github.com/nitesh-subedi/ros2_numpy) in a new colcon_ws, where you prefer.
|
||||
```bash
|
||||
pip install transformations pybase64
|
||||
mkdir -p ros2_numpy_ws/src
|
||||
cd ros2_numpy_ws/src
|
||||
git clone https://github.com/nitesh-subedi/ros2_numpy.git
|
||||
cd ../
|
||||
colcon build
|
||||
```
|
||||
|
||||
4. Copy folders of this project.
|
||||
- copy the `rsl_rl` folder to the `parkour` folder.
|
||||
- copy the distilled parkour log folder (e.g. **Jul18_07-22-08_Go2_10skills_fromJul16_07-38-08**) to the `parkour` folder.
|
||||
- copy all the files in `onboard_script/go2` to the `parkour` folder.
|
||||
|
||||
3. Install rsl_rl and other dependencies.
|
||||
```bash
|
||||
pip install -e ./rsl_rl
|
||||
```
|
||||
|
||||
## Run the model on Go2
|
||||
|
||||
***Disclaimer:*** *Always put a safety belt on the robot when the robot moves. The robot may fall down and cause damage to itself or the environment.*
|
||||
|
||||
1. Put the robot on the ground, power on the robot, and **turn off the builtin sport service**. Make sure your Intel Realsense D435i camera is connected to the robot and the camera is installed where you calibrated the camera.
|
||||
|
||||
> To turn off the builtin sport service, please refer to the [official guide](https://support.unitree.com/home/zh/developer/Basic_motion_control) and [official example](https://github.com/unitreerobotics/unitree_sdk2/blob/main/example/low_level/stand_example_go2.cpp#L184)
|
||||
|
||||
2. Launch 2 terminals onboard (whether 2 ssh connections from your computer or something else), named T_visual, T_run. Source the Unitree ROS environment, `ros2_numpy_ws` and the python virtual environment in both terminals.
|
||||
|
||||
3. In T_visual, run
|
||||
```bash
|
||||
cd parkour
|
||||
python go2_visual.py --logdir Jul18_07-22-08_Go2_10skills_fromJul16_07-38-08
|
||||
```
|
||||
where `Jul18_07-22-08_Go2_10skills_fromJul16_07-38-08` is the logdir of the distilled model.
|
||||
|
||||
4. In T_run, run
|
||||
```bash
|
||||
cd parkour
|
||||
python go2_run.py --logdir Jul18_07-22-08_Go2_10skills_fromJul16_07-38-08
|
||||
```
|
||||
where `Jul18_07-22-08_Go2_10skills_fromJul16_07-38-08` is the logdir of the distilled model.
|
||||
|
||||
Currently, the robot will not actually move its motors. You may see the ros topics. If you want to let the robot move, you can add argument `--nodryrun` in the command line, but **be careful**.
|
Binary file not shown.
|
@ -0,0 +1,183 @@
|
|||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from unitree_ros2_real import UnitreeRos2Real, get_euler_xyz
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import json
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
from rsl_rl import modules
|
||||
|
||||
class ZeroActModel(torch.nn.Module):
|
||||
def __init__(self, angle_tolerance= 0.15, delta= 0.2):
|
||||
super().__init__()
|
||||
self.angle_tolerance = angle_tolerance
|
||||
self.delta = delta
|
||||
|
||||
def forward(self, dof_pos):
|
||||
target = torch.zeros_like(dof_pos)
|
||||
diff = dof_pos - target
|
||||
diff_large_mask = torch.abs(diff) > self.angle_tolerance
|
||||
target[diff_large_mask] = dof_pos[diff_large_mask] \
|
||||
- self.delta * torch.sign(diff[diff_large_mask])
|
||||
return target
|
||||
|
||||
class Go2Node(UnitreeRos2Real):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, robot_class_name= "Go2", **kwargs)
|
||||
|
||||
def register_models(self, stand_model, task_model, task_policy):
|
||||
self.stand_model = stand_model
|
||||
self.task_model = task_model
|
||||
|
||||
self.task_policy = task_policy
|
||||
self.use_stand_policy = True # Start with standing model
|
||||
|
||||
def start_main_loop_timer(self, duration):
|
||||
self.main_loop_timer = self.create_timer(
|
||||
duration, # in sec
|
||||
self.main_loop,
|
||||
)
|
||||
|
||||
def main_loop(self):
|
||||
if (self.joy_stick_buffer.keys & self.WirelessButtons.L1) and self.use_stand_policy:
|
||||
self.get_logger().info("L1 pressed, stop using stand policy")
|
||||
self.use_stand_policy = False
|
||||
if self.use_stand_policy:
|
||||
obs = self._get_dof_pos_obs() # do not multiply by obs_scales["dof_pos"]
|
||||
action = self.stand_model(obs)
|
||||
if (action == 0).all():
|
||||
self.get_logger().info("All actions are zero, it's time to switch to the policy", throttle_duration_sec= 1)
|
||||
# else:
|
||||
# print("maximum dof error: {:.3f}".format(action.abs().max().item(), end= "\r"))
|
||||
self.send_action(action / self.action_scale)
|
||||
else:
|
||||
# start_time = time.monotonic()
|
||||
obs = self.get_obs()
|
||||
# obs_time = time.monotonic()
|
||||
action = self.task_policy(obs)
|
||||
# policy_time = time.monotonic()
|
||||
self.send_action(action)
|
||||
# self.send_action(self._get_dof_pos_obs() / self.action_scale)
|
||||
# publish_time = time.monotonic()
|
||||
# print(
|
||||
# "obs_time: {:.5f}".format(obs_time - start_time),
|
||||
# "policy_time: {:.5f}".format(policy_time - obs_time),
|
||||
# "publish_time: {:.5f}".format(publish_time - policy_time),
|
||||
# )
|
||||
if (self.joy_stick_buffer.keys & self.WirelessButtons.Y):
|
||||
self.get_logger().info("Y pressed, reset the policy")
|
||||
self.task_model.reset()
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(args):
|
||||
rclpy.init()
|
||||
|
||||
assert args.logdir is not None, "Please provide a logdir"
|
||||
with open(osp.join(args.logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
|
||||
# modify the config_dict if needed
|
||||
config_dict["control"]["computer_clip_torque"] = True
|
||||
|
||||
duration = config_dict["sim"]["dt"] * config_dict["control"]["decimation"] # in sec
|
||||
device = "cuda"
|
||||
|
||||
env_node = Go2Node(
|
||||
"go2",
|
||||
# low_cmd_topic= "low_cmd_dryrun", # for the dryrun safety
|
||||
cfg= config_dict,
|
||||
replace_obs_with_embeddings= ["forward_depth"],
|
||||
model_device= device,
|
||||
dryrun= not args.nodryrun,
|
||||
)
|
||||
|
||||
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
|
||||
num_actor_obs = env_node.num_obs,
|
||||
num_critic_obs = env_node.num_privileged_obs,
|
||||
num_actions= env_node.num_actions,
|
||||
obs_segments= env_node.obs_segments,
|
||||
privileged_obs_segments= env_node.privileged_obs_segments,
|
||||
**config_dict["policy"],
|
||||
)
|
||||
# load the model with the latest checkpoint
|
||||
model_names = [i for i in os.listdir(args.logdir) if i.startswith("model_")]
|
||||
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
|
||||
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
|
||||
model.load_state_dict(state_dict["model_state_dict"])
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
env_node.get_logger().info("Model loaded from: {}".format(osp.join(args.logdir, model_names[-1])))
|
||||
env_node.get_logger().info("Control Duration: {} sec".format(duration))
|
||||
env_node.get_logger().info("Motor Stiffness (kp): {}".format(env_node.p_gains))
|
||||
env_node.get_logger().info("Motor Damping (kd): {}".format(env_node.d_gains))
|
||||
|
||||
|
||||
# zero_act_model to start the safe standing
|
||||
zero_act_model = ZeroActModel()
|
||||
zero_act_model = torch.jit.script(zero_act_model)
|
||||
|
||||
# magically modify the model to use the components other than the forward depth encoders
|
||||
memory_a = model.memory_a
|
||||
mlp = model.actor
|
||||
@torch.jit.script
|
||||
def policy(obs: torch.Tensor):
|
||||
rnn_embedding = memory_a(obs)
|
||||
action = mlp(rnn_embedding)
|
||||
return action
|
||||
if hasattr(model, "replace_state_prob"):
|
||||
# the case where lin_vel is estimated by the state estimator
|
||||
memory_s = model.memory_s
|
||||
estimator = model.state_estimator
|
||||
rnn_policy = policy
|
||||
@torch.jit.script
|
||||
def policy(obs: torch.Tensor):
|
||||
estimator_input = obs[:, 3:48]
|
||||
memory_s_embedding = memory_s(estimator_input)
|
||||
estimated_state = estimator(memory_s_embedding)
|
||||
obs[:, :3] = estimated_state
|
||||
return rnn_policy(obs)
|
||||
|
||||
env_node.register_models(
|
||||
zero_act_model,
|
||||
model,
|
||||
policy,
|
||||
)
|
||||
env_node.start_ros_handlers()
|
||||
if args.loop_mode == "while":
|
||||
rclpy.spin_once(env_node, timeout_sec= 0.)
|
||||
env_node.get_logger().info("Model and Policy are ready")
|
||||
while rclpy.ok():
|
||||
main_loop_time = time.monotonic()
|
||||
env_node.main_loop()
|
||||
rclpy.spin_once(env_node, timeout_sec= 0.)
|
||||
# env_node.get_logger().info("loop time: {:f}".format((time.monotonic() - main_loop_time)))
|
||||
time.sleep(max(0, duration - (time.monotonic() - main_loop_time)))
|
||||
elif args.loop_mode == "timer":
|
||||
env_node.start_main_loop_timer(duration)
|
||||
rclpy.spin(env_node)
|
||||
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--logdir", type= str, default= None, help= "The directory which contains the config.json and model_*.pt files")
|
||||
parser.add_argument("--nodryrun", action= "store_true", default= False, help= "Disable dryrun mode")
|
||||
parser.add_argument("--loop_mode", type= str, default= "timer",
|
||||
choices= ["while", "timer"],
|
||||
help= "Select which mode to run the main policy control iteration",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,383 @@
|
|||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from unitree_ros2_real import UnitreeRos2Real
|
||||
from std_msgs.msg import Float32MultiArray
|
||||
from sensor_msgs.msg import Image, CameraInfo
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import json
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
from rsl_rl import modules
|
||||
|
||||
import pyrealsense2 as rs
|
||||
import ros2_numpy as rnp
|
||||
|
||||
@torch.no_grad()
|
||||
def resize2d(img, size):
|
||||
return (F.adaptive_avg_pool2d(Variable(img), size)).data
|
||||
|
||||
class VisualHandlerNode(Node):
|
||||
""" A wapper class for the realsense camera """
|
||||
def __init__(self,
|
||||
cfg: dict,
|
||||
cropping: list = [0, 0, 0, 0], # top, bottom, left, right
|
||||
rs_resolution: tuple = (480, 270), # width, height for the realsense camera)
|
||||
rs_fps: int= 30,
|
||||
depth_input_topic= "/camera/forward_depth",
|
||||
rgb_topic= "/camera/forward_rgb",
|
||||
camera_info_topic= "/camera/camera_info",
|
||||
enable_rgb= False,
|
||||
forward_depth_embedding_topic= "/forward_depth_embedding",
|
||||
):
|
||||
super().__init__("forward_depth_embedding")
|
||||
self.cfg = cfg
|
||||
self.cropping = cropping
|
||||
self.rs_resolution = rs_resolution
|
||||
self.rs_fps = rs_fps
|
||||
self.depth_input_topic = depth_input_topic
|
||||
self.rgb_topic= rgb_topic
|
||||
self.camera_info_topic = camera_info_topic
|
||||
self.enable_rgb= enable_rgb
|
||||
self.forward_depth_embedding_topic = forward_depth_embedding_topic
|
||||
|
||||
self.parse_args()
|
||||
self.start_pipeline()
|
||||
self.start_ros_handlers()
|
||||
|
||||
def parse_args(self):
|
||||
self.output_resolution = self.cfg["sensor"]["forward_camera"].get(
|
||||
"output_resolution",
|
||||
self.cfg["sensor"]["forward_camera"]["resolution"],
|
||||
)
|
||||
depth_range = self.cfg["sensor"]["forward_camera"].get(
|
||||
"depth_range",
|
||||
[0.0, 3.0],
|
||||
)
|
||||
self.depth_range = (depth_range[0] * 1000, depth_range[1] * 1000) # [m] -> [mm]
|
||||
|
||||
def start_pipeline(self):
|
||||
self.rs_pipeline = rs.pipeline()
|
||||
self.rs_config = rs.config()
|
||||
self.rs_config.enable_stream(
|
||||
rs.stream.depth,
|
||||
self.rs_resolution[0],
|
||||
self.rs_resolution[1],
|
||||
rs.format.z16,
|
||||
self.rs_fps,
|
||||
)
|
||||
if self.enable_rgb:
|
||||
self.rs_config.enable_stream(
|
||||
rs.stream.color,
|
||||
self.rs_resolution[0],
|
||||
self.rs_resolution[1],
|
||||
rs.format.rgb8,
|
||||
self.rs_fps,
|
||||
)
|
||||
self.rs_profile = self.rs_pipeline.start(self.rs_config)
|
||||
|
||||
self.rs_align = rs.align(rs.stream.depth)
|
||||
|
||||
# build rs builtin filters
|
||||
# self.rs_decimation_filter = rs.decimation_filter()
|
||||
# self.rs_decimation_filter.set_option(rs.option.filter_magnitude, 6)
|
||||
self.rs_hole_filling_filter = rs.hole_filling_filter()
|
||||
self.rs_spatial_filter = rs.spatial_filter()
|
||||
self.rs_spatial_filter.set_option(rs.option.filter_magnitude, 5)
|
||||
self.rs_spatial_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
|
||||
self.rs_spatial_filter.set_option(rs.option.filter_smooth_delta, 1)
|
||||
self.rs_spatial_filter.set_option(rs.option.holes_fill, 4)
|
||||
self.rs_temporal_filter = rs.temporal_filter()
|
||||
self.rs_temporal_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
|
||||
self.rs_temporal_filter.set_option(rs.option.filter_smooth_delta, 1)
|
||||
# using a list of filters to define the filtering order
|
||||
self.rs_filters = [
|
||||
# self.rs_decimation_filter,
|
||||
self.rs_hole_filling_filter,
|
||||
self.rs_spatial_filter,
|
||||
self.rs_temporal_filter,
|
||||
]
|
||||
|
||||
if self.enable_rgb:
|
||||
# get frame with longer waiting time to start the system
|
||||
# I know what's going on, but when enabling rgb, this solves the problem.
|
||||
rs_frame = self.rs_pipeline.wait_for_frames(int(
|
||||
self.cfg["sensor"]["forward_camera"]["latency_range"][1] * 10000 # ms * 10
|
||||
))
|
||||
|
||||
def start_ros_handlers(self):
|
||||
self.depth_input_pub = self.create_publisher(
|
||||
Image,
|
||||
self.depth_input_topic,
|
||||
1,
|
||||
)
|
||||
if self.enable_rgb:
|
||||
self.rgb_pub = self.create_publisher(
|
||||
Image,
|
||||
self.rgb_topic,
|
||||
1,
|
||||
)
|
||||
self.camera_info_pub = self.create_publisher(
|
||||
CameraInfo,
|
||||
self.camera_info_topic,
|
||||
1,
|
||||
)
|
||||
# fill in critical info of processed camera info based on simulated data
|
||||
# NOTE: simply because realsense's camera_info does not match our network input.
|
||||
# It is easier to compute this way.
|
||||
self.camera_info_msg = CameraInfo()
|
||||
self.camera_info_msg.header.frame_id = "d435_sim_depth_link"
|
||||
self.camera_info_msg.height = self.output_resolution[0]
|
||||
self.camera_info_msg.width = self.output_resolution[1]
|
||||
self.camera_info_msg.distortion_model = "plumb_bob"
|
||||
self.camera_info_msg.d = [0., 0., 0., 0., 0.]
|
||||
sim_raw_resolution = self.cfg["sensor"]["forward_camera"]["resolution"]
|
||||
sim_cropping_h = self.cfg["sensor"]["forward_camera"]["crop_top_bottom"]
|
||||
sim_cropping_w = self.cfg["sensor"]["forward_camera"]["crop_left_right"]
|
||||
cropped_resolution = [ # (H, W)
|
||||
sim_raw_resolution[0] - sum(sim_cropping_h),
|
||||
sim_raw_resolution[1] - sum(sim_cropping_w),
|
||||
]
|
||||
network_input_resolution = self.cfg["sensor"]["forward_camera"]["output_resolution"]
|
||||
x_fov = sum(self.cfg["sensor"]["forward_camera"]["horizontal_fov"]) / 2 / 180 * np.pi
|
||||
fx = (sim_raw_resolution[1]) / (2 * np.tan(x_fov / 2))
|
||||
fy = fx
|
||||
fx = fx * network_input_resolution[1] / cropped_resolution[1]
|
||||
fy = fy * network_input_resolution[0] / cropped_resolution[0]
|
||||
cx = (sim_raw_resolution[1] / 2) - sim_cropping_w[0]
|
||||
cy = (sim_raw_resolution[0] / 2) - sim_cropping_h[0]
|
||||
cx = cx * network_input_resolution[1] / cropped_resolution[1]
|
||||
cy = cy * network_input_resolution[0] / cropped_resolution[0]
|
||||
self.camera_info_msg.k = [
|
||||
fx, 0., cx,
|
||||
0., fy, cy,
|
||||
0., 0., 1.,
|
||||
]
|
||||
self.camera_info_msg.r = [1., 0., 0., 0., 1., 0., 0., 0., 1.]
|
||||
self.camera_info_msg.p = [
|
||||
fx, 0., cx, 0.,
|
||||
0., fy, cy, 0.,
|
||||
0., 0., 1., 0.,
|
||||
]
|
||||
self.camera_info_msg.binning_x = 0
|
||||
self.camera_info_msg.binning_y = 0
|
||||
self.camera_info_msg.roi.do_rectify = False
|
||||
self.create_timer(
|
||||
self.cfg["sensor"]["forward_camera"]["refresh_duration"],
|
||||
self.publish_camera_info_callback,
|
||||
)
|
||||
|
||||
self.forward_depth_embedding_pub = self.create_publisher(
|
||||
Float32MultiArray,
|
||||
self.forward_depth_embedding_topic,
|
||||
1,
|
||||
)
|
||||
self.get_logger().info("ros handlers started")
|
||||
|
||||
def publish_camera_info_callback(self):
|
||||
self.camera_info_msg.header.stamp = self.get_clock().now().to_msg()
|
||||
self.get_logger().info("camera info published", once= True)
|
||||
self.camera_info_pub.publish(self.camera_info_msg)
|
||||
|
||||
def get_depth_frame(self):
|
||||
# read from pyrealsense2, preprocess and write the model embedding to the buffer
|
||||
rs_frame = self.rs_pipeline.wait_for_frames(int(
|
||||
self.cfg["sensor"]["forward_camera"]["latency_range"][1] * 1000 # ms
|
||||
))
|
||||
if self.enable_rgb:
|
||||
rs_frame = self.rs_align.process(rs_frame)
|
||||
depth_frame = rs_frame.get_depth_frame()
|
||||
if not depth_frame:
|
||||
self.get_logger().error("No depth frame", throttle_duration_sec= 1)
|
||||
return
|
||||
color_frame = rs_frame.get_color_frame()
|
||||
if color_frame:
|
||||
rgb_image_np = np.asanyarray(color_frame.get_data())
|
||||
rgb_image_np = np.rot90(rgb_image_np, k= 2) # since the camera is inverted
|
||||
rgb_image_np = rgb_image_np[
|
||||
self.cropping[0]: -self.cropping[1]-1,
|
||||
self.cropping[2]: -self.cropping[3]-1,
|
||||
]
|
||||
rgb_image_msg = rnp.msgify(Image, rgb_image_np, encoding= "rgb8")
|
||||
rgb_image_msg.header.stamp = self.get_clock().now().to_msg()
|
||||
rgb_image_msg.header.frame_id = "d435_sim_depth_link"
|
||||
self.rgb_pub.publish(rgb_image_msg)
|
||||
self.get_logger().info("rgb image published", once= True)
|
||||
|
||||
# apply relsense filters
|
||||
for rs_filter in self.rs_filters:
|
||||
depth_frame = rs_filter.process(depth_frame)
|
||||
depth_image_np = np.asanyarray(depth_frame.get_data())
|
||||
# rotate 180 degree because d435i on h1 head is mounted inverted
|
||||
depth_image_np = np.rot90(depth_image_np, k= 2) # k = 2 for rotate 90 degree twice
|
||||
depth_image_pyt = torch.from_numpy(depth_image_np.astype(np.float32)).unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# apply torch filters
|
||||
depth_image_pyt = depth_image_pyt[:, :,
|
||||
self.cropping[0]: -self.cropping[1]-1,
|
||||
self.cropping[2]: -self.cropping[3]-1,
|
||||
]
|
||||
depth_image_pyt = torch.clip(depth_image_pyt, self.depth_range[0], self.depth_range[1]) / (self.depth_range[1] - self.depth_range[0])
|
||||
depth_image_pyt = resize2d(depth_image_pyt, self.output_resolution)
|
||||
|
||||
# publish the depth image input to ros topic
|
||||
self.get_logger().info("depth range: {}-{}".format(*self.depth_range), once= True)
|
||||
depth_input_data = (
|
||||
depth_image_pyt.detach().cpu().numpy() * (self.depth_range[1] - self.depth_range[0]) + self.depth_range[0]
|
||||
).astype(np.uint16)[0, 0] # (h, w) unit [mm]
|
||||
# DEBUG: centering the depth image
|
||||
# depth_input_data = depth_input_data.copy()
|
||||
# depth_input_data[int(depth_input_data.shape[0] / 2), :] = 0
|
||||
# depth_input_data[:, int(depth_input_data.shape[1] / 2)] = 0
|
||||
|
||||
depth_input_msg = rnp.msgify(Image, depth_input_data, encoding= "16UC1")
|
||||
depth_input_msg.header.stamp = self.get_clock().now().to_msg()
|
||||
depth_input_msg.header.frame_id = "d435_sim_depth_link"
|
||||
self.depth_input_pub.publish(depth_input_msg)
|
||||
self.get_logger().info("depth input published", once= True)
|
||||
|
||||
return depth_image_pyt
|
||||
|
||||
def publish_depth_embedding(self, embedding):
|
||||
msg = Float32MultiArray()
|
||||
msg.data = embedding.squeeze().detach().cpu().numpy().tolist()
|
||||
self.forward_depth_embedding_pub.publish(msg)
|
||||
self.get_logger().info("depth embedding published", once= True)
|
||||
|
||||
def register_models(self, visual_encoder):
|
||||
self.visual_encoder = visual_encoder
|
||||
|
||||
def start_main_loop_timer(self, duration):
|
||||
self.create_timer(
|
||||
duration,
|
||||
self.main_loop,
|
||||
)
|
||||
|
||||
def main_loop(self):
|
||||
depth_image_pyt = self.get_depth_frame()
|
||||
if depth_image_pyt is not None:
|
||||
embedding = self.visual_encoder(depth_image_pyt)
|
||||
self.publish_depth_embedding(embedding)
|
||||
else:
|
||||
self.get_logger().warn("One frame of depth embedding if not acquired")
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(args):
|
||||
rclpy.init()
|
||||
|
||||
assert args.logdir is not None, "Please provide a logdir"
|
||||
with open(osp.join(args.logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
|
||||
device = "cpu"
|
||||
duration = config_dict["sensor"]["forward_camera"]["refresh_duration"] # in sec
|
||||
|
||||
visual_node = VisualHandlerNode(
|
||||
cfg= json.load(open(osp.join(args.logdir, "config.json"), "r")),
|
||||
cropping= [args.crop_top, args.crop_bottom, args.crop_left, args.crop_right],
|
||||
rs_resolution= (args.width, args.height),
|
||||
rs_fps= args.fps,
|
||||
enable_rgb= args.rgb,
|
||||
)
|
||||
|
||||
env_node = UnitreeRos2Real(
|
||||
"visual_h1",
|
||||
low_cmd_topic= "low_cmd_dryrun", # This node should not publish any command at all
|
||||
cfg= config_dict,
|
||||
model_device= device,
|
||||
robot_class_name= "Go2",
|
||||
dryrun= True, # The robot node in this process should not run at all
|
||||
)
|
||||
|
||||
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
|
||||
num_actor_obs = env_node.num_obs,
|
||||
num_critic_obs = env_node.num_privileged_obs,
|
||||
num_actions= env_node.num_actions,
|
||||
obs_segments= env_node.obs_segments,
|
||||
privileged_obs_segments= env_node.privileged_obs_segments,
|
||||
**config_dict["policy"],
|
||||
)
|
||||
# load the model with the latest checkpoint
|
||||
model_names = [i for i in os.listdir(args.logdir) if i.startswith("model_")]
|
||||
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
|
||||
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
|
||||
model.load_state_dict(state_dict["model_state_dict"])
|
||||
model.to(device)
|
||||
model = model.encoders[0] # the first encoder is the visual encoder
|
||||
env_node.destroy_node()
|
||||
|
||||
visual_node.get_logger().info("Embedding send duration: {:.2f} sec".format(duration))
|
||||
visual_node.register_models(model)
|
||||
if args.loop_mode == "while":
|
||||
rclpy.spin_once(visual_node, timeout_sec= 0.)
|
||||
while rclpy.ok():
|
||||
main_loop_time = time.monotonic()
|
||||
visual_node.main_loop()
|
||||
rclpy.spin_once(visual_node, timeout_sec= 0.)
|
||||
time.sleep(max(0, duration - (time.monotonic() - main_loop_time)))
|
||||
elif args.loop_mode == "timer":
|
||||
visual_node.start_main_loop_timer(duration)
|
||||
rclpy.spin(visual_node)
|
||||
|
||||
visual_node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--logdir", type= str, default= None, help= "The directory which contains the config.json and model_*.pt files")
|
||||
|
||||
parser.add_argument("--height",
|
||||
type= int,
|
||||
default= 480,
|
||||
help= "The height of the realsense image",
|
||||
)
|
||||
parser.add_argument("--width",
|
||||
type= int,
|
||||
default= 640,
|
||||
help= "The width of the realsense image",
|
||||
)
|
||||
parser.add_argument("--fps",
|
||||
type= int,
|
||||
default= 30,
|
||||
help= "The fps request to the rs pipeline",
|
||||
)
|
||||
parser.add_argument("--crop_left",
|
||||
type= int,
|
||||
default= 28,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_right",
|
||||
type= int,
|
||||
default= 36,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_top",
|
||||
type= int,
|
||||
default= 48,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_bottom",
|
||||
type= int,
|
||||
default= 0,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--rgb",
|
||||
action= "store_true",
|
||||
default= False,
|
||||
help= "Set to enable rgb visualization",
|
||||
)
|
||||
parser.add_argument("--loop_mode", type= str, default= "timer",
|
||||
choices= ["while", "timer"],
|
||||
help= "Select which mode to run the main policy control iteration",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,633 @@
|
|||
import os, sys
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from unitree_go.msg import (
|
||||
WirelessController,
|
||||
LowState,
|
||||
# MotorState,
|
||||
# IMUState,
|
||||
LowCmd,
|
||||
# MotorCmd,
|
||||
)
|
||||
from std_msgs.msg import Float32MultiArray
|
||||
if os.uname().machine in ["x86_64", "amd64"]:
|
||||
sys.path.append(os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"x86",
|
||||
))
|
||||
elif os.uname().machine == "aarch64":
|
||||
sys.path.append(os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"aarch64",
|
||||
))
|
||||
from crc_module import get_crc
|
||||
|
||||
from multiprocessing import Process
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@torch.jit.script
|
||||
def quat_from_euler_xyz(roll, pitch, yaw):
|
||||
cy = torch.cos(yaw * 0.5)
|
||||
sy = torch.sin(yaw * 0.5)
|
||||
cr = torch.cos(roll * 0.5)
|
||||
sr = torch.sin(roll * 0.5)
|
||||
cp = torch.cos(pitch * 0.5)
|
||||
sp = torch.sin(pitch * 0.5)
|
||||
|
||||
qw = cy * cr * cp + sy * sr * sp
|
||||
qx = cy * sr * cp - sy * cr * sp
|
||||
qy = cy * cr * sp + sy * sr * cp
|
||||
qz = sy * cr * cp - cy * sr * sp
|
||||
|
||||
return torch.stack([qx, qy, qz, qw], dim=-1)
|
||||
|
||||
@torch.jit.script
|
||||
def quat_rotate_inverse(q, v):
|
||||
""" q must be in x, y, z, w order """
|
||||
shape = q.shape
|
||||
q_w = q[:, -1]
|
||||
q_vec = q[:, :3]
|
||||
a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)
|
||||
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
|
||||
c = q_vec * \
|
||||
torch.bmm(q_vec.view(shape[0], 1, 3), v.view(
|
||||
shape[0], 3, 1)).squeeze(-1) * 2.0
|
||||
return a - b + c
|
||||
|
||||
@torch.jit.script
|
||||
def copysign(a, b):
|
||||
# type: (float, Tensor) -> Tensor
|
||||
a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])
|
||||
return torch.abs(a) * torch.sign(b)
|
||||
|
||||
@torch.jit.script
|
||||
def get_euler_xyz(q):
|
||||
qx, qy, qz, qw = 0, 1, 2, 3
|
||||
# roll (x-axis rotation)
|
||||
sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz])
|
||||
cosr_cosp = q[:, qw] * q[:, qw] - q[:, qx] * \
|
||||
q[:, qx] - q[:, qy] * q[:, qy] + q[:, qz] * q[:, qz]
|
||||
roll = torch.atan2(sinr_cosp, cosr_cosp)
|
||||
|
||||
# pitch (y-axis rotation)
|
||||
sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx])
|
||||
pitch = torch.where(torch.abs(sinp) >= 1, copysign(
|
||||
np.pi / 2.0, sinp), torch.asin(sinp))
|
||||
|
||||
# yaw (z-axis rotation)
|
||||
siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy])
|
||||
cosy_cosp = q[:, qw] * q[:, qw] + q[:, qx] * \
|
||||
q[:, qx] - q[:, qy] * q[:, qy] - q[:, qz] * q[:, qz]
|
||||
yaw = torch.atan2(siny_cosp, cosy_cosp)
|
||||
|
||||
return roll % (2*np.pi), pitch % (2*np.pi), yaw % (2*np.pi)
|
||||
|
||||
class RobotCfgs:
|
||||
class H1:
|
||||
pass
|
||||
|
||||
class Go2:
|
||||
NUM_DOF = 12
|
||||
NUM_ACTIONS = 12
|
||||
dof_map = [ # from isaacgym simulation joint order to real robot joint order
|
||||
3, 4, 5,
|
||||
0, 1, 2,
|
||||
9, 10, 11,
|
||||
6, 7, 8,
|
||||
]
|
||||
dof_names = [ # NOTE: order matters. This list is the order in simulation.
|
||||
"FL_hip_joint",
|
||||
"FL_thigh_joint",
|
||||
"FL_calf_joint",
|
||||
"FR_hip_joint",
|
||||
"FR_thigh_joint",
|
||||
"FR_calf_joint",
|
||||
"RL_hip_joint",
|
||||
"RL_thigh_joint",
|
||||
"RL_calf_joint",
|
||||
"RR_hip_joint",
|
||||
"RR_thigh_joint",
|
||||
"RR_calf_joint",
|
||||
]
|
||||
dof_signs = [1.] * 12
|
||||
joint_limits_high = torch.tensor([
|
||||
1.0472, 3.4907, -0.83776,
|
||||
1.0472, 3.4907, -0.83776,
|
||||
1.0472, 4.5379, -0.83776,
|
||||
1.0472, 4.5379, -0.83776,
|
||||
], device= "cpu", dtype= torch.float32)
|
||||
joint_limits_low = torch.tensor([
|
||||
-1.0472, -1.5708, -2.7227,
|
||||
-1.0472, -1.5708, -2.7227,
|
||||
-1.0472, -0.5236, -2.7227,
|
||||
-1.0472, -0.5236, -2.7227,
|
||||
], device= "cpu", dtype= torch.float32)
|
||||
torque_limits = torch.tensor([ # from urdf and in simulation order
|
||||
25, 40, 40,
|
||||
25, 40, 40,
|
||||
25, 40, 40,
|
||||
25, 40, 40,
|
||||
], device= "cpu", dtype= torch.float32)
|
||||
turn_on_motor_mode = [0x01] * 12
|
||||
|
||||
|
||||
class UnitreeRos2Real(Node):
|
||||
""" A proxy implementation of the real H1 robot. """
|
||||
class WirelessButtons:
|
||||
R1 = 0b00000001 # 1
|
||||
L1 = 0b00000010 # 2
|
||||
start = 0b00000100 # 4
|
||||
select = 0b00001000 # 8
|
||||
R2 = 0b00010000 # 16
|
||||
L2 = 0b00100000 # 32
|
||||
F1 = 0b01000000 # 64
|
||||
F2 = 0b10000000 # 128
|
||||
A = 0b100000000 # 256
|
||||
B = 0b1000000000 # 512
|
||||
X = 0b10000000000 # 1024
|
||||
Y = 0b100000000000 # 2048
|
||||
up = 0b1000000000000 # 4096
|
||||
right = 0b10000000000000 # 8192
|
||||
down = 0b100000000000000 # 16384
|
||||
left = 0b1000000000000000 # 32768
|
||||
|
||||
def __init__(self,
|
||||
robot_namespace= None,
|
||||
low_state_topic= "/lowstate",
|
||||
low_cmd_topic= "/lowcmd",
|
||||
joy_stick_topic= "/wirelesscontroller",
|
||||
forward_depth_topic= None, # if None and still need access, set to str "pyrealsense"
|
||||
forward_depth_embedding_topic= "/forward_depth_embedding",
|
||||
cfg= dict(),
|
||||
lin_vel_deadband= 0.1,
|
||||
ang_vel_deadband= 0.1,
|
||||
cmd_px_range= [0.4, 1.0], # check joy_stick_callback (p for positive, n for negative)
|
||||
cmd_nx_range= [0.4, 0.8], # check joy_stick_callback (p for positive, n for negative)
|
||||
cmd_py_range= [0.4, 0.8], # check joy_stick_callback (p for positive, n for negative)
|
||||
cmd_ny_range= [0.4, 0.8], # check joy_stick_callback (p for positive, n for negative)
|
||||
cmd_pyaw_range= [0.4, 1.6], # check joy_stick_callback (p for positive, n for negative)
|
||||
cmd_nyaw_range= [0.4, 1.6], # check joy_stick_callback (p for positive, n for negative)
|
||||
replace_obs_with_embeddings= [], # a list of strings, e.g. ["forward_depth"] then the corrseponding obs will be processed by _get_forward_depth_embedding_obs()
|
||||
move_by_wireless_remote= True, # if True, the robot will be controlled by a wireless remote
|
||||
model_device= "cpu",
|
||||
dof_pos_protect_ratio= 1.1, # if the dof_pos is out of the range of this ratio, the process will shutdown.
|
||||
robot_class_name= "H1",
|
||||
dryrun= True, # if True, the robot will not send commands to the real robot
|
||||
):
|
||||
super().__init__("unitree_ros2_real")
|
||||
self.NUM_DOF = getattr(RobotCfgs, robot_class_name).NUM_DOF
|
||||
self.NUM_ACTIONS = getattr(RobotCfgs, robot_class_name).NUM_ACTIONS
|
||||
self.robot_namespace = robot_namespace
|
||||
self.low_state_topic = low_state_topic
|
||||
# Generate a unique cmd topic so that the low_cmd will not send to the robot's motor.
|
||||
self.low_cmd_topic = low_cmd_topic if not dryrun else low_cmd_topic + "_dryrun_" + str(np.random.randint(0, 65535))
|
||||
self.joy_stick_topic = joy_stick_topic
|
||||
self.forward_depth_topic = forward_depth_topic
|
||||
self.forward_depth_embedding_topic = forward_depth_embedding_topic
|
||||
self.cfg = cfg
|
||||
self.lin_vel_deadband = lin_vel_deadband
|
||||
self.ang_vel_deadband = ang_vel_deadband
|
||||
self.cmd_px_range = cmd_px_range
|
||||
self.cmd_nx_range = cmd_nx_range
|
||||
self.cmd_py_range = cmd_py_range
|
||||
self.cmd_ny_range = cmd_ny_range
|
||||
self.cmd_pyaw_range = cmd_pyaw_range
|
||||
self.cmd_nyaw_range = cmd_nyaw_range
|
||||
self.replace_obs_with_embeddings = replace_obs_with_embeddings
|
||||
self.move_by_wireless_remote = move_by_wireless_remote
|
||||
self.model_device = model_device
|
||||
self.dof_pos_protect_ratio = dof_pos_protect_ratio
|
||||
self.robot_class_name = robot_class_name
|
||||
self.dryrun = dryrun
|
||||
|
||||
self.dof_map = getattr(RobotCfgs, robot_class_name).dof_map
|
||||
self.dof_names = getattr(RobotCfgs, robot_class_name).dof_names
|
||||
self.dof_signs = getattr(RobotCfgs, robot_class_name).dof_signs
|
||||
self.turn_on_motor_mode = getattr(RobotCfgs, robot_class_name).turn_on_motor_mode
|
||||
|
||||
self.parse_config()
|
||||
|
||||
def parse_config(self):
|
||||
""" parse, set attributes from config dict, initialize buffers to speed up the computation """
|
||||
self.up_axis_idx = 2 # 2 for z, 1 for y -> adapt gravity accordingly
|
||||
self.gravity_vec = torch.zeros((1, 3), device= self.model_device, dtype= torch.float32)
|
||||
self.gravity_vec[:, self.up_axis_idx] = -1
|
||||
|
||||
# observations
|
||||
self.clip_obs = self.cfg["normalization"]["clip_observations"]
|
||||
self.obs_scales = self.cfg["normalization"]["obs_scales"]
|
||||
for k, v in self.obs_scales.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
self.obs_scales[k] = torch.tensor(v, device= self.model_device, dtype= torch.float32)
|
||||
# check whether there are embeddings in obs_components and launch encoder process later
|
||||
if len(self.replace_obs_with_embeddings) > 0:
|
||||
for comp in self.replace_obs_with_embeddings:
|
||||
self.get_logger().warn(f"{comp} will be replaced with its embedding when get_obs, don't forget to launch the corresponding process before running the policy.")
|
||||
self.obs_segments = self.get_obs_segment_from_components(self.cfg["env"]["obs_components"])
|
||||
self.num_obs = self.get_num_obs_from_components(self.cfg["env"]["obs_components"])
|
||||
if "privileged_obs_components" in self.cfg["env"].keys():
|
||||
self.privileged_obs_segments = self.get_obs_segment_from_components(self.cfg["env"]["privileged_obs_components"])
|
||||
self.num_privileged_obs = self.get_num_obs_from_components(self.cfg["env"]["privileged_obs_components"])
|
||||
for obs_component in self.cfg["env"]["obs_components"]:
|
||||
if "orientation_cmds" in obs_component:
|
||||
self.roll_pitch_yaw_cmd = torch.zeros(1, 3, device= self.model_device, dtype= torch.float32)
|
||||
|
||||
# controls
|
||||
self.control_type = self.cfg["control"]["control_type"]
|
||||
if not (self.control_type == "P"):
|
||||
raise NotImplementedError("Only position control is supported for now.")
|
||||
self.p_gains = []
|
||||
for i in range(self.NUM_DOF):
|
||||
name = self.dof_names[i] # set p_gains in simulation order
|
||||
for k, v in self.cfg["control"]["stiffness"].items():
|
||||
if k in name:
|
||||
self.p_gains.append(v)
|
||||
break # only one match
|
||||
self.p_gains = torch.tensor(self.p_gains, device= self.model_device, dtype= torch.float32)
|
||||
self.d_gains = []
|
||||
for i in range(self.NUM_DOF):
|
||||
name = self.dof_names[i] # set d_gains in simulation order
|
||||
for k, v in self.cfg["control"]["damping"].items():
|
||||
if k in name:
|
||||
self.d_gains.append(v)
|
||||
break
|
||||
self.d_gains = torch.tensor(self.d_gains, device= self.model_device, dtype= torch.float32)
|
||||
self.default_dof_pos = torch.zeros(self.NUM_DOF, device= self.model_device, dtype= torch.float32)
|
||||
self.dof_pos_ = torch.empty(1, self.NUM_DOF, device= self.model_device, dtype= torch.float32)
|
||||
self.dof_vel_ = torch.empty(1, self.NUM_DOF, device= self.model_device, dtype= torch.float32)
|
||||
for i in range(self.NUM_DOF):
|
||||
name = self.dof_names[i]
|
||||
default_joint_angle = self.cfg["init_state"]["default_joint_angles"][name]
|
||||
# in simulation order.
|
||||
self.default_dof_pos[i] = default_joint_angle
|
||||
self.computer_clip_torque = self.cfg["control"].get("computer_clip_torque", True)
|
||||
self.get_logger().info("Computer Clip Torque (onboard) is " + str(self.computer_clip_torque))
|
||||
self.torque_limits = getattr(RobotCfgs, self.robot_class_name).torque_limits.to(self.model_device)
|
||||
if self.computer_clip_torque:
|
||||
assert hasattr(self, "torque_limits") and (len(self.torque_limits) == self.NUM_DOF), f"torque_limits must be set with the length of {self.NUM_DOF} if computer_clip_torque is True"
|
||||
self.get_logger().info("[Env] torque limit: " + ",".join("{:.1f}".format(x) for x in self.torque_limits))
|
||||
|
||||
# actions
|
||||
self.num_actions = self.NUM_ACTIONS
|
||||
self.action_scale = self.cfg["control"]["action_scale"]
|
||||
self.get_logger().info("[Env] action scale: {:.1f}".format(self.action_scale))
|
||||
self.clip_actions = self.cfg["normalization"]["clip_actions"]
|
||||
if self.cfg["normalization"].get("clip_actions_method", None) == "hard":
|
||||
self.get_logger().info("clip_actions_method with hard mode")
|
||||
self.get_logger().info("clip_actions_high: " + str(self.cfg["normalization"]["clip_actions_high"]))
|
||||
self.get_logger().info("clip_actions_low: " + str(self.cfg["normalization"]["clip_actions_low"]))
|
||||
self.clip_actions_method = "hard"
|
||||
self.clip_actions_low = torch.tensor(self.cfg["normalization"]["clip_actions_low"], device= self.model_device, dtype= torch.float32)
|
||||
self.clip_actions_high = torch.tensor(self.cfg["normalization"]["clip_actions_high"], device= self.model_device, dtype= torch.float32)
|
||||
else:
|
||||
self.get_logger().info("clip_actions_method is " + str(self.cfg["normalization"].get("clip_actions_method", None)))
|
||||
self.actions = torch.zeros(self.NUM_ACTIONS, device= self.model_device, dtype= torch.float32)
|
||||
|
||||
|
||||
# hardware related, in simulation order
|
||||
self.joint_limits_high = getattr(RobotCfgs, self.robot_class_name).joint_limits_high.to(self.model_device)
|
||||
self.joint_limits_low = getattr(RobotCfgs, self.robot_class_name).joint_limits_low.to(self.model_device)
|
||||
joint_pos_mid = (self.joint_limits_high + self.joint_limits_low) / 2
|
||||
joint_pos_range = (self.joint_limits_high - self.joint_limits_low) / 2
|
||||
self.joint_pos_protect_high = joint_pos_mid + joint_pos_range * self.dof_pos_protect_ratio
|
||||
self.joint_pos_protect_low = joint_pos_mid - joint_pos_range * self.dof_pos_protect_ratio
|
||||
|
||||
def start_ros_handlers(self):
|
||||
""" after initializing the env and policy, register ros related callbacks and topics
|
||||
"""
|
||||
|
||||
# ROS publishers
|
||||
self.low_cmd_pub = self.create_publisher(
|
||||
LowCmd,
|
||||
self.low_cmd_topic,
|
||||
1
|
||||
)
|
||||
self.low_cmd_buffer = LowCmd()
|
||||
|
||||
# ROS subscribers
|
||||
self.low_state_sub = self.create_subscription(
|
||||
LowState,
|
||||
self.low_state_topic,
|
||||
self._low_state_callback,
|
||||
1
|
||||
)
|
||||
self.joy_stick_sub = self.create_subscription(
|
||||
WirelessController,
|
||||
self.joy_stick_topic,
|
||||
self._joy_stick_callback,
|
||||
1
|
||||
)
|
||||
|
||||
if self.forward_depth_topic is not None:
|
||||
self.forward_camera_sub = self.create_subscription(
|
||||
Image,
|
||||
self.forward_depth_topic,
|
||||
self._forward_depth_callback,
|
||||
1
|
||||
)
|
||||
|
||||
if self.forward_depth_embedding_topic is not None and "forward_depth" in self.replace_obs_with_embeddings:
|
||||
self.forward_depth_embedding_sub = self.create_subscription(
|
||||
Float32MultiArray,
|
||||
self.forward_depth_embedding_topic,
|
||||
self._forward_depth_embedding_callback,
|
||||
1,
|
||||
)
|
||||
|
||||
self.get_logger().info("ROS handlers started, waiting to recieve critical low state and wireless controller messages.")
|
||||
if not self.dryrun:
|
||||
self.get_logger().warn(f"You are running the code in no-dryrun mode and publishing to '{self.low_cmd_topic}', Please keep safe.")
|
||||
else:
|
||||
self.get_logger().warn(f"You are publishing low cmd to '{self.low_cmd_topic}' because of dryrun mode, Please check and be safe.")
|
||||
while rclpy.ok():
|
||||
rclpy.spin_once(self)
|
||||
if hasattr(self, "low_state_buffer") and hasattr(self, "joy_stick_buffer"):
|
||||
break
|
||||
self.get_logger().info("Low state message received, the robot is ready to go.")
|
||||
|
||||
""" ROS callbacks and handlers that update the buffer """
|
||||
def _low_state_callback(self, msg):
|
||||
""" store and handle proprioception data """
|
||||
self.low_state_buffer = msg # keep the latest low state
|
||||
|
||||
# refresh dof_pos and dof_vel
|
||||
for sim_idx in range(self.NUM_DOF):
|
||||
real_idx = self.dof_map[sim_idx]
|
||||
self.dof_pos_[0, sim_idx] = self.low_state_buffer.motor_state[real_idx].q * self.dof_signs[sim_idx]
|
||||
for sim_idx in range(self.NUM_DOF):
|
||||
real_idx = self.dof_map[sim_idx]
|
||||
self.dof_vel_[0, sim_idx] = self.low_state_buffer.motor_state[real_idx].dq * self.dof_signs[sim_idx]
|
||||
# automatic safety check
|
||||
for sim_idx in range(self.NUM_DOF):
|
||||
real_idx = self.dof_map[sim_idx]
|
||||
if self.dof_pos_[0, sim_idx] > self.joint_pos_protect_high[sim_idx] or \
|
||||
self.dof_pos_[0, sim_idx] < self.joint_pos_protect_low[sim_idx]:
|
||||
self.get_logger().error(f"Joint {sim_idx}(sim), {real_idx}(real) position out of range at {self.low_state_buffer.motor_state[real_idx].q}")
|
||||
self.get_logger().error("The motors and this process shuts down.")
|
||||
self._turn_off_motors()
|
||||
raise SystemExit()
|
||||
|
||||
def _joy_stick_callback(self, msg):
|
||||
self.joy_stick_buffer = msg
|
||||
if self.move_by_wireless_remote:
|
||||
# left-y for forward/backward
|
||||
ly = msg.ly
|
||||
if ly > self.lin_vel_deadband:
|
||||
vx = (ly - self.lin_vel_deadband) / (1 - self.lin_vel_deadband) # (0, 1)
|
||||
vx = vx * (self.cmd_px_range[1] - self.cmd_px_range[0]) + self.cmd_px_range[0]
|
||||
elif ly < -self.lin_vel_deadband:
|
||||
vx = (ly + self.lin_vel_deadband) / (1 - self.lin_vel_deadband) # (-1, 0)
|
||||
vx = vx * (self.cmd_nx_range[1] - self.cmd_nx_range[0]) - self.cmd_nx_range[0]
|
||||
else:
|
||||
vx = 0
|
||||
# left-x for turning left/right
|
||||
lx = -msg.lx
|
||||
if lx > self.ang_vel_deadband:
|
||||
yaw = (lx - self.ang_vel_deadband) / (1 - self.ang_vel_deadband)
|
||||
yaw = yaw * (self.cmd_pyaw_range[1] - self.cmd_pyaw_range[0]) + self.cmd_pyaw_range[0]
|
||||
elif lx < -self.ang_vel_deadband:
|
||||
yaw = (lx + self.ang_vel_deadband) / (1 - self.ang_vel_deadband)
|
||||
yaw = yaw * (self.cmd_nyaw_range[1] - self.cmd_nyaw_range[0]) - self.cmd_nyaw_range[0]
|
||||
else:
|
||||
yaw = 0
|
||||
# right-x for side moving left/right
|
||||
rx = -msg.rx
|
||||
if rx > self.lin_vel_deadband:
|
||||
vy = (rx - self.lin_vel_deadband) / (1 - self.lin_vel_deadband)
|
||||
vy = vy * (self.cmd_py_range[1] - self.cmd_py_range[0]) + self.cmd_py_range[0]
|
||||
elif rx < -self.lin_vel_deadband:
|
||||
vy = (rx + self.lin_vel_deadband) / (1 - self.lin_vel_deadband)
|
||||
vy = vy * (self.cmd_ny_range[1] - self.cmd_ny_range[0]) - self.cmd_ny_range[0]
|
||||
else:
|
||||
vy = 0
|
||||
self.xyyaw_command = torch.tensor([vx, vy, yaw], device= self.model_device, dtype= torch.float32)
|
||||
|
||||
# refer to Unitree Remote Control data structure, msg.keys is a bit mask
|
||||
# 00000000 00000001 means pressing the 0-th button (R1)
|
||||
# 00000000 00000010 means pressing the 1-th button (L1)
|
||||
# 10000000 00000000 means pressing the 15-th button (left)
|
||||
if (msg.keys & self.WirelessButtons.R2) or (msg.keys & self.WirelessButtons.L2): # R2 or L2 is pressed
|
||||
self.get_logger().warn("R2 or L2 is pressed, the motors and this process shuts down.")
|
||||
self._turn_off_motors()
|
||||
raise SystemExit()
|
||||
|
||||
# roll-pitch target
|
||||
if hasattr(self, "roll_pitch_yaw_cmd"):
|
||||
if (msg.keys & self.WirelessButtons.up):
|
||||
self.roll_pitch_yaw_cmd[0, 1] += 0.1
|
||||
self.get_logger().info("Pitch Command: " + str(self.roll_pitch_yaw_cmd))
|
||||
if (msg.keys & self.WirelessButtons.down):
|
||||
self.roll_pitch_yaw_cmd[0, 1] -= 0.1
|
||||
self.get_logger().info("Pitch Command: " + str(self.roll_pitch_yaw_cmd))
|
||||
if (msg.keys & self.WirelessButtons.left):
|
||||
self.roll_pitch_yaw_cmd[0, 0] -= 0.1
|
||||
self.get_logger().info("Roll Command: " + str(self.roll_pitch_yaw_cmd))
|
||||
if (msg.keys & self.WirelessButtons.right):
|
||||
self.roll_pitch_yaw_cmd[0, 0] += 0.1
|
||||
self.get_logger().info("Roll Command: " + str(self.roll_pitch_yaw_cmd))
|
||||
|
||||
def _forward_depth_callback(self, msg):
|
||||
""" store and handle depth camera data """
|
||||
pass
|
||||
|
||||
def _forward_depth_embedding_callback(self, msg):
|
||||
self.forward_depth_embedding_buffer = torch.tensor(msg.data, device= self.model_device, dtype= torch.float32).view(1, -1)
|
||||
|
||||
""" Done: ROS callbacks and handlers that update the buffer """
|
||||
|
||||
""" refresh observation buffer and corresponding sub-functions """
|
||||
def _get_lin_vel_obs(self):
|
||||
return torch.zeros(1, 3, device= self.model_device, dtype= torch.float32)
|
||||
|
||||
def _get_ang_vel_obs(self):
|
||||
buffer = torch.from_numpy(self.low_state_buffer.imu_state.gyroscope).unsqueeze(0)
|
||||
return buffer
|
||||
|
||||
def _get_projected_gravity_obs(self):
|
||||
quat_xyzw = torch.tensor([
|
||||
self.low_state_buffer.imu_state.quaternion[1],
|
||||
self.low_state_buffer.imu_state.quaternion[2],
|
||||
self.low_state_buffer.imu_state.quaternion[3],
|
||||
self.low_state_buffer.imu_state.quaternion[0],
|
||||
], device= self.model_device, dtype= torch.float32).unsqueeze(0)
|
||||
return quat_rotate_inverse(
|
||||
quat_xyzw,
|
||||
self.gravity_vec,
|
||||
)
|
||||
|
||||
def _get_commands_obs(self):
|
||||
return self.xyyaw_command.unsqueeze(0) # (1, 3)
|
||||
|
||||
def _get_dof_pos_obs(self):
|
||||
return self.dof_pos_ - self.default_dof_pos.unsqueeze(0)
|
||||
|
||||
def _get_dof_vel_obs(self):
|
||||
return self.dof_vel_
|
||||
|
||||
def _get_last_actions_obs(self):
|
||||
return self.actions
|
||||
|
||||
def _get_forward_depth_embedding_obs(self):
|
||||
return self.forward_depth_embedding_buffer
|
||||
|
||||
def _get_forward_depth_obs(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_orientation_cmds_obs(self):
|
||||
return quat_rotate_inverse(
|
||||
quat_from_euler_xyz(self.roll_pitch_yaw_cmd[:, 0], self.roll_pitch_yaw_cmd[:, 1], self.roll_pitch_yaw_cmd[:, 2]),
|
||||
self.gravity_vec,
|
||||
)
|
||||
|
||||
def get_num_obs_from_components(self, components):
|
||||
obs_segments = self.get_obs_segment_from_components(components)
|
||||
num_obs = 0
|
||||
for k, v in obs_segments.items():
|
||||
num_obs += np.prod(v)
|
||||
return num_obs
|
||||
|
||||
def get_obs_segment_from_components(self, components):
|
||||
""" Observation segment is defined as a list of lists/ints defining the tensor shape with
|
||||
corresponding order.
|
||||
"""
|
||||
segments = OrderedDict()
|
||||
if "lin_vel" in components:
|
||||
print("Warning: lin_vel is not typically available or accurate enough on the real robot. Will return zeros.")
|
||||
segments["lin_vel"] = (3,)
|
||||
if "ang_vel" in components:
|
||||
segments["ang_vel"] = (3,)
|
||||
if "projected_gravity" in components:
|
||||
segments["projected_gravity"] = (3,)
|
||||
if "commands" in components:
|
||||
segments["commands"] = (3,)
|
||||
if "dof_pos" in components:
|
||||
segments["dof_pos"] = (self.NUM_DOF,)
|
||||
if "dof_vel" in components:
|
||||
segments["dof_vel"] = (self.NUM_DOF,)
|
||||
if "last_actions" in components:
|
||||
segments["last_actions"] = (self.NUM_ACTIONS,)
|
||||
if "height_measurements" in components:
|
||||
print("Warning: height_measurements is not typically available on the real robot.")
|
||||
segments["height_measurements"] = (1, len(self.cfg["terrain"]["measured_points_x"]), len(self.cfg["terrain"]["measured_points_y"]))
|
||||
if "forward_depth" in components:
|
||||
if "output_resolution" in self.cfg["sensor"]["forward_camera"]:
|
||||
segments["forward_depth"] = (1, *self.cfg["sensor"]["forward_camera"]["output_resolution"])
|
||||
else:
|
||||
segments["forward_depth"] = (1, *self.cfg["sensor"]["forward_camera"]["resolution"])
|
||||
if "base_pose" in components:
|
||||
segments["base_pose"] = (6,) # xyz + rpy
|
||||
if "robot_config" in components:
|
||||
""" Related to robot_config_buffer attribute, Be careful to change. """
|
||||
# robot shape friction
|
||||
# CoM (Center of Mass) x, y, z
|
||||
# base mass (payload)
|
||||
# motor strength for each joint
|
||||
print("Warning: height_measurements is not typically available on the real robot.")
|
||||
segments["robot_config"] = (1 + 3 + 1 + self.NUM_ACTIONS,)
|
||||
|
||||
""" NOTE: The following components are not directly set in legged_robot.py.
|
||||
Please check the order or extend the class implementation if needed.
|
||||
"""
|
||||
if "joints_target" in components:
|
||||
# o[0] for target value, 0[1] for wether the target should be tracked (1) or not (0)
|
||||
segments["joints_target"] = (2, self.NUM_DOF)
|
||||
if "projected_gravity_target" in components:
|
||||
# projected_gravity for which the robot should track the target
|
||||
# last value as a mask of whether to follow the target or not
|
||||
segments["projected_gravity_target"] = (3+1,)
|
||||
if "orientation_cmds" in components:
|
||||
segments["orientation_cmds"] = (3,)
|
||||
return segments
|
||||
|
||||
def get_obs(self, obs_segments= None):
|
||||
""" Extract from the buffers and build the 1d observation tensor
|
||||
Each get ... obs function does not do the obs_scale multiplication.
|
||||
NOTE: obs_buffer has the batch dimension, whose size is 1.
|
||||
"""
|
||||
if obs_segments is None:
|
||||
obs_segments = self.obs_segments
|
||||
obs_buffer = []
|
||||
for k, v in obs_segments.items():
|
||||
if k in self.replace_obs_with_embeddings:
|
||||
obs_component_value = getattr(self, "_get_" + k + "_embedding_obs")()
|
||||
else:
|
||||
obs_component_value = getattr(self, "_get_" + k + "_obs")() * self.obs_scales.get(k, 1.0)
|
||||
obs_buffer.append(obs_component_value)
|
||||
obs_buffer = torch.cat(obs_buffer, dim=1)
|
||||
obs_buffer = torch.clamp(obs_buffer, -self.clip_obs, self.clip_obs)
|
||||
return obs_buffer
|
||||
""" Done: refresh observation buffer and corresponding sub-functions """
|
||||
|
||||
""" Control related functions """
|
||||
def clip_action_before_scale(self, action):
|
||||
action = torch.clip(action, -self.clip_actions, self.clip_actions)
|
||||
if getattr(self, "clip_actions_method", None) == "hard":
|
||||
action = torch.clip(action, self.clip_actions_low, self.clip_actions_high)
|
||||
return action
|
||||
|
||||
def clip_by_torque_limit(self, actions_scaled):
|
||||
""" Different from simulation, we reverse the process and clip the actions directly,
|
||||
so that the PD controller runs in robot but not our script.
|
||||
"""
|
||||
control_type = self.cfg["control"]["control_type"]
|
||||
if control_type == "P":
|
||||
p_limits_low = (-self.torque_limits) + self.d_gains*self.dof_vel_
|
||||
p_limits_high = (self.torque_limits) + self.d_gains*self.dof_vel_
|
||||
actions_low = (p_limits_low/self.p_gains) - self.default_dof_pos + self.dof_pos_
|
||||
actions_high = (p_limits_high/self.p_gains) - self.default_dof_pos + self.dof_pos_
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return torch.clip(actions_scaled, actions_low, actions_high)
|
||||
|
||||
def send_action(self, actions):
|
||||
""" Send the action to the robot motors, which does the preprocessing
|
||||
just like env.step in simulation.
|
||||
Thus, the actions has the batch dimension, whose size is 1.
|
||||
"""
|
||||
self.actions = self.clip_action_before_scale(actions)
|
||||
if self.computer_clip_torque:
|
||||
clipped_scaled_action = self.clip_by_torque_limit(actions * self.action_scale)
|
||||
else:
|
||||
self.get_logger().warn("Computer Clip Torque is False, the robot may be damaged.", throttle_duration_sec= 1)
|
||||
clipped_scaled_action = actions * self.action_scale
|
||||
robot_coordinates_action = clipped_scaled_action + self.default_dof_pos.unsqueeze(0)
|
||||
|
||||
self._publish_legs_cmd(robot_coordinates_action[0])
|
||||
|
||||
""" Done: Control related functions """
|
||||
|
||||
""" functions that actually publish the commands and take effect """
|
||||
def _publish_legs_cmd(self, robot_coordinates_action: torch.Tensor):
|
||||
""" Publish the joint commands to the robot legs in robot coordinates system.
|
||||
robot_coordinates_action: shape (NUM_DOF,), in simulation order.
|
||||
"""
|
||||
for sim_idx in range(self.NUM_DOF):
|
||||
real_idx = self.dof_map[sim_idx]
|
||||
if not self.dryrun:
|
||||
self.low_cmd_buffer.motor_cmd[real_idx].mode = self.turn_on_motor_mode[sim_idx]
|
||||
self.low_cmd_buffer.motor_cmd[real_idx].q = robot_coordinates_action[sim_idx].item() * self.dof_signs[sim_idx]
|
||||
self.low_cmd_buffer.motor_cmd[real_idx].dq = 0.
|
||||
self.low_cmd_buffer.motor_cmd[real_idx].tau = 0.
|
||||
self.low_cmd_buffer.motor_cmd[real_idx].kp = self.p_gains[sim_idx].item()
|
||||
self.low_cmd_buffer.motor_cmd[real_idx].kd = self.d_gains[sim_idx].item()
|
||||
|
||||
self.low_cmd_buffer.crc = get_crc(self.low_cmd_buffer)
|
||||
self.low_cmd_pub.publish(self.low_cmd_buffer)
|
||||
|
||||
def _turn_off_motors(self):
|
||||
""" Turn off the motors """
|
||||
for sim_idx in range(self.NUM_DOF):
|
||||
real_idx = self.dof_map[sim_idx]
|
||||
self.low_cmd_buffer.motor_cmd[real_idx].mode = 0x00
|
||||
self.low_cmd_buffer.motor_cmd[real_idx].q = 0.
|
||||
self.low_cmd_buffer.motor_cmd[real_idx].dq = 0.
|
||||
self.low_cmd_buffer.motor_cmd[real_idx].tau = 0.
|
||||
self.low_cmd_buffer.motor_cmd[real_idx].kp = 0.
|
||||
self.low_cmd_buffer.motor_cmd[real_idx].kd = 0.
|
||||
self.low_cmd_buffer.crc = get_crc(self.low_cmd_buffer)
|
||||
self.low_cmd_pub.publish(self.low_cmd_buffer)
|
||||
""" Done: functions that actually publish the commands and take effect """
|
Binary file not shown.
|
@ -30,3 +30,4 @@
|
|||
|
||||
from .ppo import PPO
|
||||
from .tppo import TPPO
|
||||
from .estimator import EstimatorPPO, EstimatorTPPO
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
|
||||
from rsl_rl.algorithms.ppo import PPO
|
||||
from rsl_rl.algorithms.tppo import TPPO
|
||||
from rsl_rl.utils.utils import unpad_trajectories, get_subobs_by_components
|
||||
from rsl_rl.storage.rollout_storage import SarsaRolloutStorage
|
||||
|
||||
class EstimatorAlgoMixin:
|
||||
""" A supervised algorithm implementation that trains a state predictor in the policy model
|
||||
"""
|
||||
def __init__(self,
|
||||
*args,
|
||||
estimator_loss_func= "mse_loss",
|
||||
estimator_loss_kwargs= dict(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.estimator_obs_components = self.actor_critic.estimator_obs_components
|
||||
self.estimator_target_obs_components = self.actor_critic.estimator_target_components
|
||||
self.estimator_loss_func = estimator_loss_func
|
||||
self.estimator_loss_kwargs = estimator_loss_kwargs
|
||||
|
||||
def compute_losses(self, minibatch):
|
||||
losses, inter_vars, stats = super().compute_losses(minibatch)
|
||||
|
||||
# Use the critic_obs from the same timestep for estimation target
|
||||
# Not considering predicting the next state for now.
|
||||
estimation_target = get_subobs_by_components(
|
||||
minibatch.critic_obs,
|
||||
component_names= self.estimator_target_obs_components,
|
||||
obs_segments= self.actor_critic.privileged_obs_segments,
|
||||
)
|
||||
if self.actor_critic.is_recurrent:
|
||||
estimation_target = unpad_trajectories(estimation_target, minibatch.masks)
|
||||
|
||||
# actor_critic must compute the estimated_state_ during act() as a intermediate variable
|
||||
estimation = unpad_trajectories(self.actor_critic.get_estimated_state(), minibatch.masks)
|
||||
|
||||
estimator_loss = getattr(F, self.estimator_loss_func)(
|
||||
estimation,
|
||||
estimation_target,
|
||||
**self.estimator_loss_kwargs,
|
||||
reduction= "none",
|
||||
).sum(dim= -1)
|
||||
|
||||
losses["estimator_loss"] = estimator_loss.mean()
|
||||
|
||||
return losses, inter_vars, stats
|
||||
|
||||
class EstimatorPPO(EstimatorAlgoMixin, PPO):
|
||||
pass
|
||||
|
||||
class EstimatorTPPO(EstimatorAlgoMixin, TPPO):
|
||||
pass
|
|
@ -68,7 +68,6 @@ class PPO:
|
|||
self.actor_critic.to(self.device)
|
||||
self.storage = None # initialized later
|
||||
self.optimizer = getattr(optim, optimizer_class_name)(self.actor_critic.parameters(), lr=learning_rate)
|
||||
self.transition = RolloutStorage.Transition()
|
||||
|
||||
# PPO parameters
|
||||
self.clip_param = clip_param
|
||||
|
@ -86,6 +85,7 @@ class PPO:
|
|||
self.current_learning_iteration = 0
|
||||
|
||||
def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
|
||||
self.transition = RolloutStorage.Transition()
|
||||
self.storage = RolloutStorage(num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device)
|
||||
|
||||
def test_mode(self):
|
||||
|
@ -108,7 +108,7 @@ class PPO:
|
|||
self.transition.critic_observations = critic_obs
|
||||
return self.transition.actions
|
||||
|
||||
def process_env_step(self, rewards, dones, infos):
|
||||
def process_env_step(self, rewards, dones, infos, next_obs, next_critic_obs):
|
||||
self.transition.rewards = rewards.clone()
|
||||
self.transition.dones = dones
|
||||
# Bootstrapping on time outs
|
||||
|
@ -162,9 +162,9 @@ class PPO:
|
|||
return mean_losses, average_stats
|
||||
|
||||
def compute_losses(self, minibatch):
|
||||
self.actor_critic.act(minibatch.obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0])
|
||||
self.actor_critic.act(minibatch.obs, masks=minibatch.masks, hidden_states=minibatch.hidden_states.actor)
|
||||
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(minibatch.actions)
|
||||
value_batch = self.actor_critic.evaluate(minibatch.critic_obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
|
||||
value_batch = self.actor_critic.evaluate(minibatch.critic_obs, masks=minibatch.masks, hidden_states=minibatch.hidden_states.critic)
|
||||
mu_batch = self.actor_critic.action_mean
|
||||
sigma_batch = self.actor_critic.action_std
|
||||
try:
|
||||
|
@ -222,3 +222,22 @@ class PPO:
|
|||
if self.use_clipped_value_loss:
|
||||
inter_vars["value_clipped"] = value_clipped
|
||||
return return_, inter_vars, dict()
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {
|
||||
"model_state_dict": self.actor_critic.state_dict(),
|
||||
"optimizer_state_dict": self.optimizer.state_dict(),
|
||||
}
|
||||
if hasattr(self, "lr_scheduler"):
|
||||
state_dict["lr_scheduler_state_dict"] = self.lr_scheduler.state_dict()
|
||||
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.actor_critic.load_state_dict(state_dict["model_state_dict"])
|
||||
if "optimizer_state_dict" in state_dict:
|
||||
self.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
|
||||
if hasattr(self, "lr_scheduler"):
|
||||
self.lr_scheduler.load_state_dict(state_dict["lr_scheduler_state_dict"])
|
||||
elif "lr_scheduler_state_dict" in state_dict:
|
||||
print("Warning: lr scheduler state dict loaded but no lr scheduler is initialized. Ignored.")
|
||||
|
|
|
@ -6,17 +6,19 @@ import torch.nn.functional as F
|
|||
import torch.optim as optim
|
||||
|
||||
import rsl_rl.modules as modules
|
||||
from rsl_rl.utils import unpad_trajectories
|
||||
from rsl_rl.storage.rollout_storage import ActionLabelRollout
|
||||
from rsl_rl.algorithms.ppo import PPO
|
||||
from legged_gym import LEGGED_GYM_ROOT_DIR
|
||||
|
||||
# assuming learning iteration is at an assumable iteration scale
|
||||
def GET_TEACHER_ACT_PROB_FUNC(option, iteration_scale):
|
||||
TEACHER_ACT_PROB_options = {
|
||||
def GET_PROB_FUNC(option, iteration_scale):
|
||||
PROB_options = {
|
||||
"linear": (lambda x: max(0, 1 - 1 / iteration_scale * x)),
|
||||
"exp": (lambda x: max(0, (1 - 1 / iteration_scale) ** x)),
|
||||
"tanh": (lambda x: max(0, 0.5 * (1 - torch.tanh(1 / iteration_scale * (x - iteration_scale))))),
|
||||
}
|
||||
return TEACHER_ACT_PROB_options[option]
|
||||
return PROB_options[option]
|
||||
|
||||
class TPPO(PPO):
|
||||
def __init__(self,
|
||||
|
@ -28,35 +30,53 @@ class TPPO(PPO):
|
|||
teacher_act_prob= "exp", # a number or a callable to (0 ~ 1) to the selection of act using teacher policy
|
||||
update_times_scale= 100, # a rough estimation of how many times the update will be called
|
||||
using_ppo= True, # If False, compute_losses will skip ppo loss computation and returns to DAGGR
|
||||
distillation_loss_coef= 1.,
|
||||
distillation_loss_coef= 1., # can also be string to select a prob function to scale the distillation loss
|
||||
distill_target= "real",
|
||||
distill_latent_coef= 1.,
|
||||
distill_latent_target= "real",
|
||||
distill_latent_obs_component_mapping= None,
|
||||
buffer_dilation_ratio= 1.,
|
||||
lr_scheduler_class_name= None,
|
||||
lr_scheduler= dict(),
|
||||
hidden_state_resample_prob= 0.0, # if > 0, Some hidden state in the minibatch will be resampled
|
||||
action_labels_from_sample= False, # if True, the action labels from teacher policy will be from policy.act instead of policy.act_inference
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
- distill_latent_obs_component_mapping: a dict of
|
||||
{student_obs_component_name: teacher_obs_component_name}
|
||||
only when both policy are the instance of EncoderActorCriticMixin
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.label_action_with_critic_obs = label_action_with_critic_obs
|
||||
self.teacher_act_prob = teacher_act_prob
|
||||
self.update_times_scale = update_times_scale
|
||||
if isinstance(self.teacher_act_prob, str):
|
||||
self.teacher_act_prob = GET_TEACHER_ACT_PROB_FUNC(self.teacher_act_prob, update_times_scale)
|
||||
self.teacher_act_prob = GET_PROB_FUNC(self.teacher_act_prob, update_times_scale)
|
||||
else:
|
||||
self.__teacher_act_prob = self.teacher_act_prob
|
||||
self.teacher_act_prob = lambda x: self.__teacher_act_prob
|
||||
self.using_ppo = using_ppo
|
||||
self.distillation_loss_coef = distillation_loss_coef
|
||||
self.__distillation_loss_coef = distillation_loss_coef
|
||||
if isinstance(self.__distillation_loss_coef, str):
|
||||
self.distillation_loss_coef_func = GET_PROB_FUNC(self.__distillation_loss_coef, update_times_scale)
|
||||
self.distill_target = distill_target
|
||||
self.distill_latent_coef = distill_latent_coef
|
||||
self.distill_latent_target = distill_latent_target
|
||||
self.distill_latent_obs_component_mapping = distill_latent_obs_component_mapping
|
||||
self.buffer_dilation_ratio = buffer_dilation_ratio
|
||||
self.lr_scheduler_class_name = lr_scheduler_class_name
|
||||
self.lr_scheduler_kwargs = lr_scheduler
|
||||
self.hidden_state_resample_prob = hidden_state_resample_prob
|
||||
self.action_labels_from_sample = action_labels_from_sample
|
||||
self.transition = ActionLabelRollout.Transition()
|
||||
|
||||
# build and load teacher network
|
||||
teacher_actor_critic = getattr(modules, teacher_policy_class_name)(**teacher_policy)
|
||||
if not teacher_ac_path is None:
|
||||
if "{LEGGED_GYM_ROOT_DIR}" in teacher_ac_path:
|
||||
teacher_ac_path = teacher_ac_path.format(LEGGED_GYM_ROOT_DIR= LEGGED_GYM_ROOT_DIR)
|
||||
state_dict = torch.load(teacher_ac_path, map_location= "cpu")
|
||||
teacher_actor_critic_state_dict = state_dict["model_state_dict"]
|
||||
teacher_actor_critic.load_state_dict(teacher_actor_critic_state_dict)
|
||||
|
@ -84,8 +104,12 @@ class TPPO(PPO):
|
|||
def act(self, obs, critic_obs):
|
||||
# get actions
|
||||
return_ = super().act(obs, critic_obs)
|
||||
if self.label_action_with_critic_obs:
|
||||
if self.label_action_with_critic_obs and self.action_labels_from_sample:
|
||||
self.transition.action_labels = self.teacher_actor_critic.act(critic_obs).detach()
|
||||
elif self.label_action_with_critic_obs:
|
||||
self.transition.action_labels = self.teacher_actor_critic.act_inference(critic_obs).detach()
|
||||
elif self.action_labels_from_sample:
|
||||
self.transition.action_labels = self.teacher_actor_critic.act(obs).detach()
|
||||
else:
|
||||
self.transition.action_labels = self.teacher_actor_critic.act_inference(obs).detach()
|
||||
|
||||
|
@ -96,8 +120,8 @@ class TPPO(PPO):
|
|||
|
||||
return return_
|
||||
|
||||
def process_env_step(self, rewards, dones, infos):
|
||||
return_ = super().process_env_step(rewards, dones, infos)
|
||||
def process_env_step(self, rewards, dones, infos, next_obs, next_critic_obs):
|
||||
return_ = super().process_env_step(rewards, dones, infos, next_obs, next_critic_obs)
|
||||
self.teacher_actor_critic.reset(dones)
|
||||
# resample teacher action mask for those dones env
|
||||
self.use_teacher_act_mask[dones] = torch.rand(dones.sum(), device= self.device) < self.teacher_act_prob(self.current_learning_iteration)
|
||||
|
@ -107,7 +131,7 @@ class TPPO(PPO):
|
|||
""" The interface to collect transition from dataset rather than env """
|
||||
super().act(transition.observation, transition.privileged_observation)
|
||||
self.transition.action_labels = transition.action
|
||||
super().process_env_step(transition.reward, transition.done, infos)
|
||||
super().process_env_step(transition.reward, transition.done, infos, transition.next_observation, transition.next_privileged_observation)
|
||||
|
||||
def compute_returns(self, last_critic_obs):
|
||||
if not self.using_ppo:
|
||||
|
@ -124,30 +148,30 @@ class TPPO(PPO):
|
|||
def compute_losses(self, minibatch):
|
||||
if self.hidden_state_resample_prob > 0.0:
|
||||
# assuming the hidden states are from LSTM or GRU, which are always betwein -1 and 1
|
||||
hidden_state_example = minibatch.hid_states[0][0] if isinstance(minibatch.hid_states[0], tuple) else minibatch.hid_states[0]
|
||||
hidden_state_example = minibatch.hidden_states[0][0] if isinstance(minibatch.hidden_states[0], tuple) else minibatch.hidden_states[0]
|
||||
resample_mask = torch.rand(hidden_state_example.shape[1], device= self.device) < self.hidden_state_resample_prob
|
||||
# for each hidden state, resample from -1 to 1
|
||||
if isinstance(minibatch.hid_states[0], tuple):
|
||||
if isinstance(minibatch.hidden_states[0], tuple):
|
||||
# for LSTM not tested
|
||||
# iterate through actor and critic hidden state
|
||||
minibatch = minibatch._replace(hid_states= tuple(
|
||||
minibatch = minibatch._replace(hidden_states= tuple(
|
||||
tuple(
|
||||
torch.where(
|
||||
resample_mask.unsqueeze(-1).unsqueeze(-1),
|
||||
torch.rand_like(minibatch.hid_states[i][j], device= self.device) * 2 - 1,
|
||||
minibatch.hid_states[i][j],
|
||||
) for j in range(len(minibatch.hid_states[i]))
|
||||
) for i in range(len(minibatch.hid_states))
|
||||
torch.rand_like(minibatch.hidden_states[i][j], device= self.device) * 2 - 1,
|
||||
minibatch.hidden_states[i][j],
|
||||
) for j in range(len(minibatch.hidden_states[i]))
|
||||
) for i in range(len(minibatch.hidden_states))
|
||||
))
|
||||
else:
|
||||
# for GRU
|
||||
# iterate through actor and critic hidden state
|
||||
minibatch = minibatch._replace(hid_states= tuple(
|
||||
minibatch = minibatch._replace(hidden_states= tuple(
|
||||
torch.where(
|
||||
resample_mask.unsqueeze(-1),
|
||||
torch.rand_like(minibatch.hid_states[i], device= self.device) * 2 - 1,
|
||||
minibatch.hid_states[i],
|
||||
) for i in range(len(minibatch.hid_states))
|
||||
torch.rand_like(minibatch.hidden_states[i], device= self.device) * 2 - 1,
|
||||
minibatch.hidden_states[i],
|
||||
) for i in range(len(minibatch.hidden_states))
|
||||
))
|
||||
|
||||
if self.using_ppo:
|
||||
|
@ -156,7 +180,7 @@ class TPPO(PPO):
|
|||
losses = dict()
|
||||
inter_vars = dict()
|
||||
stats = dict()
|
||||
self.actor_critic.act(minibatch.obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0])
|
||||
self.actor_critic.act(minibatch.obs, masks=minibatch.masks, hidden_states=minibatch.hidden_states.actor)
|
||||
|
||||
# distillation loss (with teacher actor)
|
||||
if self.distill_target == "real":
|
||||
|
@ -164,6 +188,12 @@ class TPPO(PPO):
|
|||
self.actor_critic.action_mean - minibatch.action_labels,
|
||||
dim= -1
|
||||
)
|
||||
elif self.distill_target == "mse_sum":
|
||||
dist_loss = F.mse_loss(
|
||||
self.actor_critic.action_mean,
|
||||
minibatch.action_labels,
|
||||
reduction= "none",
|
||||
).sum(-1)
|
||||
elif self.distill_target == "l1":
|
||||
dist_loss = torch.norm(
|
||||
self.actor_critic.action_mean - minibatch.action_labels,
|
||||
|
@ -187,6 +217,11 @@ class TPPO(PPO):
|
|||
(minibatch.action_labels + 1) * 0.5, # (n, t, d)
|
||||
reduction= "none",
|
||||
).mean(-1) * 2 * l1 / self.actor_critic.action_mean.shape[-1] # (n, t)
|
||||
elif self.distill_target == "max_log_prob":
|
||||
action_labels_log_prob = self.actor_critic.get_actions_log_prob(minibatch.action_labels)
|
||||
dist_loss = -action_labels_log_prob
|
||||
elif self.distill_target == "kl":
|
||||
raise NotImplementedError()
|
||||
|
||||
if "tanh" in self.distill_target:
|
||||
stats["l1distance"] = torch.norm(
|
||||
|
@ -199,7 +234,54 @@ class TPPO(PPO):
|
|||
dim= -1,
|
||||
p= 1
|
||||
).mean().detach()
|
||||
# update distillation loss coef if applicable
|
||||
self.distillation_loss_coef = self.distillation_loss_coef_func(self.current_learning_iteration) if hasattr(self, "distillation_loss_coef_func") else self.__distillation_loss_coef
|
||||
losses["distillation_loss"] = dist_loss.mean()
|
||||
|
||||
# distill latent embedding
|
||||
if self.distill_latent_obs_component_mapping is not None:
|
||||
for k, v in self.distill_latent_obs_component_mapping.items():
|
||||
# get the latent embedding
|
||||
latent = self.actor_critic.get_encoder_latent(
|
||||
minibatch.obs,
|
||||
k,
|
||||
)
|
||||
with torch.no_grad():
|
||||
target_latent = self.teacher_actor_critic.get_encoder_latent(
|
||||
minibatch.critic_obs,
|
||||
v,
|
||||
)
|
||||
if self.actor_critic.is_recurrent:
|
||||
latent = unpad_trajectories(latent, minibatch.masks)
|
||||
target_latent = unpad_trajectories(target_latent, minibatch.masks)
|
||||
if self.distill_latent_target == "real":
|
||||
dist_loss = torch.norm(
|
||||
latent - target_latent,
|
||||
dim= -1,
|
||||
)
|
||||
elif self.distill_latent_target == "l1":
|
||||
dist_loss = torch.norm(
|
||||
latent - target_latent,
|
||||
dim= -1,
|
||||
p= 1,
|
||||
)
|
||||
elif self.distill_latent_target == "tanh":
|
||||
dist_loss = F.binary_cross_entropy(
|
||||
(latent + 1) * 0.5,
|
||||
(target_latent + 1) * 0.5,
|
||||
) * 2
|
||||
elif self.distill_latent_target == "scaled_tanh":
|
||||
l1 = torch.norm(
|
||||
latent - target_latent,
|
||||
dim= -1,
|
||||
p= 1,
|
||||
)
|
||||
dist_loss = F.binary_cross_entropy(
|
||||
(latent + 1) * 0.5,
|
||||
(target_latent + 1) * 0.5, # (n, t, d)
|
||||
reduction= "none",
|
||||
).mean(-1) * 2 * l1 / latent.shape[-1] # (n, t)
|
||||
setattr(self, f"distill_latent_{k}_coef", self.distill_latent_coef)
|
||||
losses[f"distill_latent_{k}"] = dist_loss.mean()
|
||||
|
||||
return losses, inter_vars, stats
|
||||
|
|
|
@ -33,6 +33,9 @@ from .actor_critic_recurrent import ActorCriticRecurrent
|
|||
from .visual_actor_critic import VisualDeterministicRecurrent, VisualDeterministicAC
|
||||
from .actor_critic_mutex import ActorCriticMutex
|
||||
from .actor_critic_field_mutex import ActorCriticFieldMutex, ActorCriticClimbMutex
|
||||
from .encoder_actor_critic import EncoderActorCriticMixin, EncoderActorCritic, EncoderActorCriticRecurrent
|
||||
from .state_estimator import EstimatorMixin, StateAc, StateAcRecurrent
|
||||
from .all_mixer import EncoderStateAc, EncoderStateAcRecurrent
|
||||
|
||||
def build_actor_critic(env, policy_class_name, policy_cfg):
|
||||
""" NOTE: This method allows to hack the policy kwargs by adding the env attributes to the policy_cfg. """
|
||||
|
|
|
@ -45,11 +45,20 @@ class ActorCritic(nn.Module):
|
|||
activation='elu',
|
||||
init_noise_std=1.0,
|
||||
mu_activation= None, # If set, the last layer will be added with a special activation layer.
|
||||
obs_segments= None,
|
||||
privileged_obs_segments= None, # No need
|
||||
**kwargs):
|
||||
if kwargs:
|
||||
print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()]))
|
||||
super(ActorCritic, self).__init__()
|
||||
|
||||
# obs_segmnets is a ordered dict that contains the observation components (string) and its shape.
|
||||
# We use this to slice the observations into the correct components.
|
||||
if privileged_obs_segments is None:
|
||||
privileged_obs_segments = obs_segments
|
||||
self.obs_segments = obs_segments
|
||||
self.privileged_obs_segments = privileged_obs_segments
|
||||
|
||||
activation = get_activation(activation)
|
||||
|
||||
mlp_input_dim_a = num_actor_obs
|
||||
|
|
|
@ -36,6 +36,9 @@ from torch.distributions import Normal
|
|||
from torch.nn.modules import rnn
|
||||
from .actor_critic import ActorCritic, get_activation
|
||||
from rsl_rl.utils import unpad_trajectories
|
||||
from rsl_rl.utils.collections import namedarraytuple, is_namedarraytuple
|
||||
|
||||
ActorCriticHiddenState = namedarraytuple('ActorCriticHiddenState', ['actor', 'critic'])
|
||||
|
||||
class ActorCriticRecurrent(ActorCritic):
|
||||
is_recurrent = True
|
||||
|
@ -74,39 +77,47 @@ class ActorCriticRecurrent(ActorCritic):
|
|||
|
||||
def act(self, observations, masks=None, hidden_states=None):
|
||||
input_a = self.memory_a(observations, masks, hidden_states)
|
||||
return super().act(input_a.squeeze(0))
|
||||
return super().act(input_a)
|
||||
|
||||
def act_inference(self, observations):
|
||||
input_a = self.memory_a(observations)
|
||||
return super().act_inference(input_a.squeeze(0))
|
||||
return super().act_inference(input_a)
|
||||
|
||||
def evaluate(self, critic_observations, masks=None, hidden_states=None):
|
||||
input_c = self.memory_c(critic_observations, masks, hidden_states)
|
||||
return super().evaluate(input_c.squeeze(0))
|
||||
return super().evaluate(input_c)
|
||||
|
||||
def get_hidden_states(self):
|
||||
return self.memory_a.hidden_states, self.memory_c.hidden_states
|
||||
return ActorCriticHiddenState(self.memory_a.hidden_states, self.memory_c.hidden_states)
|
||||
|
||||
LstmHiddenState = namedarraytuple('LstmHiddenState', ['hidden', 'cell'])
|
||||
|
||||
class Memory(torch.nn.Module):
|
||||
def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
|
||||
super().__init__()
|
||||
# RNN
|
||||
# RNN currently support only GRU and LSTM
|
||||
rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
|
||||
self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
|
||||
self.hidden_states = None
|
||||
|
||||
def forward(self, input, masks=None, hidden_states=None):
|
||||
batch_mode = masks is not None
|
||||
batch_mode = hidden_states is not None
|
||||
if batch_mode:
|
||||
# batch mode (policy update): need saved hidden states
|
||||
if hidden_states is None:
|
||||
raise ValueError("Hidden states not passed to memory module during policy update")
|
||||
if is_namedarraytuple(hidden_states):
|
||||
hidden_states = tuple(hidden_states)
|
||||
out, _ = self.rnn(input, hidden_states)
|
||||
out = unpad_trajectories(out, masks)
|
||||
if not masks is None:
|
||||
# in this case, user can choose whether to unpad the output or not
|
||||
out = unpad_trajectories(out, masks)
|
||||
else:
|
||||
# inference mode (collection): use hidden states of last step
|
||||
if is_namedarraytuple(self.hidden_states):
|
||||
self.hidden_states = tuple(self.hidden_states)
|
||||
out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
|
||||
if isinstance(self.hidden_states, tuple):
|
||||
self.hidden_states = LstmHiddenState(*self.hidden_states)
|
||||
out = out.squeeze(0) # remove the time dimension
|
||||
return out
|
||||
|
||||
def reset(self, dones=None):
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
""" A file put all mixin class combinations """
|
||||
from .actor_critic import ActorCritic
|
||||
from .actor_critic_recurrent import ActorCriticRecurrent
|
||||
from .encoder_actor_critic import EncoderActorCriticMixin
|
||||
from .state_estimator import EstimatorMixin
|
||||
|
||||
class EncoderStateAc(EstimatorMixin, EncoderActorCriticMixin, ActorCritic):
|
||||
pass
|
||||
|
||||
class EncoderStateAcRecurrent(EstimatorMixin, EncoderActorCriticMixin, ActorCriticRecurrent):
|
||||
|
||||
def load_misaligned_state_dict(self, module, obs_segments, privileged_obs_segments=None):
|
||||
pass
|
|
@ -0,0 +1,188 @@
|
|||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from rsl_rl.modules.mlp import MlpModel
|
||||
from rsl_rl.modules.conv2d import Conv2dHeadModel
|
||||
from rsl_rl.utils.utils import get_obs_slice
|
||||
|
||||
class EncoderActorCriticMixin:
|
||||
""" A general implementation where a seperate encoder is used to embed the obs/privileged_obs """
|
||||
def __init__(self,
|
||||
num_actor_obs,
|
||||
num_critic_obs,
|
||||
num_actions,
|
||||
obs_segments= None,
|
||||
privileged_obs_segments= None,
|
||||
encoder_component_names= [], # allow multiple encoders
|
||||
encoder_class_name= "MlpModel", # accept list of names (in the same order as encoder_component_names)
|
||||
encoder_kwargs= dict(), # accept list of kwargs (in the same order as encoder_component_names),
|
||||
encoder_output_size= None,
|
||||
critic_encoder_component_names= None, # None, "shared", or a list of names (in the same order as encoder_component_names)
|
||||
critic_encoder_class_name= None, # accept list of names (in the same order as encoder_component_names)
|
||||
critic_encoder_kwargs= None, # accept list of kwargs (in the same order as encoder_component_names),
|
||||
**kwargs,
|
||||
):
|
||||
""" NOTE: recurrent encoder is not implemented and tested yet.
|
||||
"""
|
||||
self.num_actor_obs = num_actor_obs
|
||||
self.num_critic_obs = num_critic_obs
|
||||
self.num_actions = num_actions
|
||||
self.obs_segments = obs_segments
|
||||
self.privileged_obs_segments = privileged_obs_segments
|
||||
self.encoder_component_names = encoder_component_names
|
||||
self.encoder_class_name = encoder_class_name
|
||||
self.encoder_kwargs = encoder_kwargs
|
||||
self.encoder_output_size = encoder_output_size
|
||||
self.critic_encoder_component_names = critic_encoder_component_names
|
||||
self.critic_encoder_class_name = critic_encoder_class_name if not critic_encoder_class_name is None else encoder_class_name
|
||||
self.critic_encoder_kwargs = critic_encoder_kwargs if not critic_encoder_kwargs is None else encoder_kwargs
|
||||
self.obs_segments = obs_segments
|
||||
self.privileged_obs_segments = privileged_obs_segments
|
||||
|
||||
self.prepare_obs_slices()
|
||||
|
||||
super().__init__(
|
||||
num_actor_obs - sum([s[0].stop - s[0].start for s in self.encoder_obs_slices]) + len(self.encoder_obs_slices) * self.encoder_output_size,
|
||||
num_critic_obs if self.critic_encoder_component_names is None else num_critic_obs - sum([s[0].stop - s[0].start for s in self.critic_encoder_obs_slices]) + len(self.critic_encoder_obs_slices) * self.encoder_output_size,
|
||||
num_actions,
|
||||
obs_segments= obs_segments,
|
||||
privileged_obs_segments= privileged_obs_segments,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.encoders = self.build_encoders(
|
||||
self.encoder_component_names,
|
||||
self.encoder_class_name,
|
||||
self.encoder_obs_slices,
|
||||
self.encoder_kwargs,
|
||||
self.encoder_output_size,
|
||||
)
|
||||
if not (self.critic_encoder_component_names is None or self.critic_encoder_component_names == "shared"):
|
||||
self.critic_encoders = self.build_encoders(
|
||||
self.critic_encoder_component_names,
|
||||
self.critic_encoder_class_name,
|
||||
self.critic_encoder_obs_slices,
|
||||
self.critic_encoder_kwargs,
|
||||
self.encoder_output_size,
|
||||
)
|
||||
|
||||
def prepare_obs_slices(self):
|
||||
# NOTE: encoders are stored in the order of obs_component_names respectively.
|
||||
# latents_order stores the order of how each output latent should be concatenated with
|
||||
# the rest of the obs vector.
|
||||
self.encoder_obs_slices = [get_obs_slice(self.obs_segments, name) for name in self.encoder_component_names]
|
||||
self.latents_order = [i for i in range(len(self.encoder_obs_slices))]
|
||||
self.latents_order.sort(key= lambda i: self.encoder_obs_slices[i][0].start)
|
||||
if self.critic_encoder_component_names is not None:
|
||||
if self.critic_encoder_component_names == "shared":
|
||||
self.critic_encoder_obs_slices = self.encoder_obs_slices
|
||||
else:
|
||||
critic_obs_segments = self.obs_segments if self.privileged_obs_segments is None else self.privileged_obs_segments
|
||||
self.critic_encoder_obs_slices = [get_obs_slice(critic_obs_segments, name) for name in self.critic_encoder_component_names]
|
||||
self.critic_latents_order = [i for i in range(len(self.critic_encoder_obs_slices))]
|
||||
self.critic_latents_order.sort(key= lambda i: self.critic_encoder_obs_slices[i][0].start)
|
||||
|
||||
def build_encoders(self, component_names, class_name, obs_slices, kwargs, encoder_output_size):
|
||||
encoders = nn.ModuleList()
|
||||
for component_i, name in enumerate(component_names):
|
||||
model_class_name = class_name[component_i] if isinstance(class_name, (tuple, list)) else class_name
|
||||
obs_slice = obs_slices[component_i]
|
||||
model_kwargs = kwargs[component_i] if isinstance(kwargs, (tuple, list)) else kwargs
|
||||
model_kwargs = model_kwargs.copy() # 1-level shallow copy
|
||||
# This code is not clean enough, need to sort out later
|
||||
if model_class_name == "MlpModel":
|
||||
hidden_sizes = model_kwargs.pop("hidden_sizes") + [encoder_output_size,]
|
||||
encoders.append(MlpModel(
|
||||
np.prod(obs_slice[1]),
|
||||
hidden_sizes= hidden_sizes,
|
||||
output_size= None,
|
||||
**model_kwargs,
|
||||
))
|
||||
elif model_class_name == "Conv2dHeadModel":
|
||||
hidden_sizes = model_kwargs.pop("hidden_sizes") + [encoder_output_size,]
|
||||
encoders.append(Conv2dHeadModel(
|
||||
obs_slice[1],
|
||||
hidden_sizes= hidden_sizes,
|
||||
output_size= None,
|
||||
**model_kwargs,
|
||||
))
|
||||
else:
|
||||
raise NotImplementedError(f"Encoder for {model_class_name} on {name} not implemented")
|
||||
return encoders
|
||||
|
||||
def embed_encoders_latent(self, observations, obs_slices, encoders, latents_order):
|
||||
leading_dims = observations.shape[:-1]
|
||||
latents = []
|
||||
for encoder_i, encoder in enumerate(encoders):
|
||||
# This code is not clean enough, need to sort out later
|
||||
if isinstance(encoder, MlpModel):
|
||||
latents.append(encoder(
|
||||
observations[..., obs_slices[encoder_i][0]].reshape(-1, np.prod(obs_slices[encoder_i][1]))
|
||||
).reshape(*leading_dims, -1))
|
||||
elif isinstance(encoder, Conv2dHeadModel):
|
||||
latents.append(encoder(
|
||||
observations[..., obs_slices[encoder_i][0]].reshape(-1, *obs_slices[encoder_i][1])
|
||||
).reshape(*leading_dims, -1))
|
||||
else:
|
||||
raise NotImplementedError(f"Encoder for {type(encoder)} not implemented")
|
||||
# replace the obs vector with the latent vector in eace obs_slice[0] (the slice of obs)
|
||||
embedded_obs = []
|
||||
embedded_obs.append(observations[..., :obs_slices[latents_order[0]][0].start])
|
||||
for order_i in range(len(latents)- 1):
|
||||
current_idx = latents_order[order_i]
|
||||
next_idx = latents_order[order_i + 1]
|
||||
embedded_obs.append(latents[current_idx])
|
||||
embedded_obs.append(observations[..., obs_slices[current_idx][0].stop: obs_slices[next_idx][0].start])
|
||||
current_idx = latents_order[-1]
|
||||
next_idx = None
|
||||
embedded_obs.append(latents[current_idx])
|
||||
embedded_obs.append(observations[..., obs_slices[current_idx][0].stop:])
|
||||
|
||||
return torch.cat(embedded_obs, dim= -1)
|
||||
|
||||
def get_encoder_latent(self, observations, obs_component, critic= False):
|
||||
""" Get the latent vector from the encoder of the specified obs_component
|
||||
"""
|
||||
leading_dims = observations.shape[:-1]
|
||||
encoder_obs_components = self.critic_encoder_component_names if critic else self.encoder_component_names
|
||||
encoder_obs_slices = self.critic_encoder_obs_slices if critic else self.encoder_obs_slices
|
||||
encoders = self.critic_encoders if critic else self.encoders
|
||||
|
||||
for i, name in enumerate(encoder_obs_components):
|
||||
if name == obs_component:
|
||||
obs_component_var = observations[..., encoder_obs_slices[i][0]]
|
||||
if isinstance(encoders[i], MlpModel):
|
||||
obs_component_var = obs_component_var.reshape(-1, np.prod(encoder_obs_slices[i][1]))
|
||||
elif isinstance(encoders[i], Conv2dHeadModel):
|
||||
obs_component_var = obs_component_var.reshape(-1, *encoder_obs_slices[i][1])
|
||||
latent = encoders[i](obs_component_var).reshape(*leading_dims, -1)
|
||||
return latent
|
||||
raise ValueError(f"obs_component {obs_component} not found in encoder_obs_components")
|
||||
|
||||
def act(self, observations, **kwargs):
|
||||
obs = self.embed_encoders_latent(observations, self.encoder_obs_slices, self.encoders, self.latents_order)
|
||||
return super().act(obs, **kwargs)
|
||||
|
||||
def act_inference(self, observations):
|
||||
obs = self.embed_encoders_latent(observations, self.encoder_obs_slices, self.encoders, self.latents_order)
|
||||
return super().act_inference(obs)
|
||||
|
||||
def evaluate(self, critic_observations, masks=None, hidden_states=None):
|
||||
if self.critic_encoder_component_names == "shared":
|
||||
obs = self.embed_encoders_latent(critic_observations, self.encoder_obs_slices, self.encoders, self.latents_order)
|
||||
elif self.critic_encoder_component_names is None:
|
||||
obs = critic_observations
|
||||
else:
|
||||
obs = self.embed_encoders_latent(critic_observations, self.critic_encoder_obs_slices, self.critic_encoders, self.critic_latents_order)
|
||||
return super().evaluate(obs, masks, hidden_states)
|
||||
|
||||
from .actor_critic import ActorCritic
|
||||
class EncoderActorCritic(EncoderActorCriticMixin, ActorCritic):
|
||||
pass
|
||||
|
||||
from .actor_critic_recurrent import ActorCriticRecurrent
|
||||
class EncoderActorCriticRecurrent(EncoderActorCriticMixin, ActorCriticRecurrent):
|
||||
pass
|
||||
|
|
@ -24,6 +24,8 @@ class MlpModel(torch.nn.Module):
|
|||
hidden_sizes = [hidden_sizes]
|
||||
elif hidden_sizes is None:
|
||||
hidden_sizes = []
|
||||
if isinstance(nonlinearity, str):
|
||||
nonlinearity = getattr(torch.nn, nonlinearity)
|
||||
hidden_layers = [torch.nn.Linear(n_in, n_out) for n_in, n_out in
|
||||
zip([input_size] + hidden_sizes[:-1], hidden_sizes)]
|
||||
sequence = list()
|
||||
|
|
|
@ -0,0 +1,165 @@
|
|||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .actor_critic import ActorCritic
|
||||
from .actor_critic_recurrent import ActorCriticRecurrent, ActorCriticHiddenState
|
||||
from rsl_rl.modules.mlp import MlpModel
|
||||
from rsl_rl.modules.actor_critic_recurrent import Memory
|
||||
from rsl_rl.utils import unpad_trajectories
|
||||
from rsl_rl.utils.utils import get_subobs_size, get_subobs_by_components, substitute_estimated_state
|
||||
from rsl_rl.utils.collections import namedarraytuple
|
||||
|
||||
EstimatorActorHiddenState = namedarraytuple('EstimatorActorHiddenState', [
|
||||
'estimator',
|
||||
'actor',
|
||||
])
|
||||
|
||||
class EstimatorMixin:
|
||||
def __init__(self,
|
||||
*args,
|
||||
estimator_obs_components= None, # a list of strings used to get obs slices
|
||||
estimator_target_components= None, # a list of strings used to get obs slices
|
||||
estimator_kwargs= dict(),
|
||||
use_actor_rnn= False, # if set and self.is_recurrent, use actor-rnn output as the input of the state estimator directly
|
||||
replace_state_prob= 0., # if 0~1, replace the actor observation with the estimated state with this probability
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
self.estimator_obs_components = estimator_obs_components
|
||||
self.estimator_target_components = estimator_target_components
|
||||
self.estimator_kwargs = estimator_kwargs
|
||||
self.use_actor_rnn = use_actor_rnn
|
||||
self.replace_state_prob = replace_state_prob
|
||||
assert (self.replace_state_prob <= 0.) or (not self.use_actor_rnn), "You cannot replace the actor's observation (part) after the actor already used it's memory module. "
|
||||
self.build_estimator(**kwargs)
|
||||
|
||||
def build_estimator(self, **kwargs):
|
||||
""" This implementation is not flexible enough, but it is enough for now. """
|
||||
estimator_input_size = get_subobs_size(self.obs_segments, self.estimator_obs_components)
|
||||
estimator_output_size = get_subobs_size(self.obs_segments, self.estimator_target_components)
|
||||
if self.is_recurrent:
|
||||
# estimate required state using a recurrent network
|
||||
if self.use_actor_rnn:
|
||||
estimator_input_size = self.memory_a.rnn.hidden_size
|
||||
self.state_estimator = MlpModel(
|
||||
input_size= estimator_input_size,
|
||||
output_size= estimator_output_size,
|
||||
**self.estimator_kwargs,
|
||||
)
|
||||
else:
|
||||
self.memory_s = Memory(
|
||||
estimator_input_size,
|
||||
type= kwargs.get("rnn_type", "lstm"),
|
||||
num_layers= kwargs.get("rnn_num_layers", 1),
|
||||
hidden_size= kwargs.get("rnn_hidden_size", 256),
|
||||
)
|
||||
self.state_estimator = MlpModel(
|
||||
input_size= self.memory_s.rnn.hidden_size,
|
||||
output_size= estimator_output_size,
|
||||
**self.estimator_kwargs,
|
||||
)
|
||||
else:
|
||||
# estimate required state using a feedforward network
|
||||
self.state_estimator = MlpModel(
|
||||
input_size= estimator_input_size,
|
||||
output_size= estimator_output_size,
|
||||
**self.estimator_kwargs,
|
||||
)
|
||||
|
||||
def reset(self, dones=None):
|
||||
super().reset(dones)
|
||||
if self.is_recurrent and not self.use_actor_rnn:
|
||||
self.memory_s.reset(dones)
|
||||
|
||||
def act(self, observations, masks=None, hidden_states=None, inference= False):
|
||||
observations = observations.clone()
|
||||
if inference:
|
||||
assert masks is None and hidden_states is None, "Inference mode does not support masks and hidden_states. "
|
||||
if self.is_recurrent and self.use_actor_rnn:
|
||||
# NOTE:
|
||||
# In this branch, it may requires a redesign for the entire actor_critic.act interface.
|
||||
input_s = self.memory_a(observations, masks, hidden_states)
|
||||
action = ActorCritic.act(self, input_s)
|
||||
self.estimated_state_ = self.state_estimator(input_s)
|
||||
elif self.is_recurrent and not self.use_actor_rnn:
|
||||
# TODO: allows non-recurrent state estimator with recurrent actor
|
||||
subobs = get_subobs_by_components(
|
||||
observations,
|
||||
component_names= self.estimator_obs_components,
|
||||
obs_segments= self.obs_segments,
|
||||
)
|
||||
input_s = self.memory_s(
|
||||
subobs,
|
||||
None, # use None to prevent unpadding
|
||||
hidden_states if hidden_states is None else hidden_states.estimator,
|
||||
)
|
||||
# QUESTION: after memory_s, the estimated_state is already unpadded. How to get
|
||||
# the padded format and feed it into actor's observation?
|
||||
# SOLUTION: modify the code of Memory module, use masks= None to stop the unpadding.
|
||||
self.estimated_state_ = self.state_estimator(input_s)
|
||||
use_estimated_state_mask = torch.rand_like(observations[..., 0]) < self.replace_state_prob
|
||||
observations[use_estimated_state_mask] = substitute_estimated_state(
|
||||
observations[use_estimated_state_mask],
|
||||
self.estimator_target_components,
|
||||
self.estimated_state_[use_estimated_state_mask].detach(),
|
||||
self.obs_segments,
|
||||
)
|
||||
if inference:
|
||||
action = super().act_inference(observations)
|
||||
else:
|
||||
action = super().act(
|
||||
observations,
|
||||
masks= masks,
|
||||
hidden_states= hidden_states if hidden_states is None else hidden_states.actor,
|
||||
)
|
||||
else:
|
||||
# both state estimator and actor are feedforward (non-recurrent)
|
||||
subobs = get_subobs_by_components(
|
||||
observations,
|
||||
component_names= self.estimator_obs_components,
|
||||
obs_segments= self.obs_segments,
|
||||
)
|
||||
self.estimated_state_ = self.state_estimator(subobs)
|
||||
use_estimated_state_mask = torch.rand_like(observations[..., 0]) < self.replace_state_prob
|
||||
observations[use_estimated_state_mask] = substitute_estimated_state(
|
||||
observations[use_estimated_state_mask],
|
||||
self.estimator_target_components,
|
||||
self.estimated_state_[use_estimated_state_mask].detach(),
|
||||
self.obs_segments,
|
||||
)
|
||||
if inference:
|
||||
action = super().act_inference(observations)
|
||||
else:
|
||||
action = super().act(observations, masks= masks, hidden_states= hidden_states)
|
||||
return action
|
||||
|
||||
def act_inference(self, observations):
|
||||
return self.act(observations, inference= True)
|
||||
|
||||
""" No modification required for evaluate() """
|
||||
|
||||
def get_estimated_state(self):
|
||||
""" In order to maintain the same interface of ActorCritic(Recurrent),
|
||||
the user must call this function after calling act() to get the estimated state.
|
||||
"""
|
||||
return self.estimated_state_
|
||||
|
||||
def get_hidden_states(self):
|
||||
return_ = super().get_hidden_states()
|
||||
if self.is_recurrent and not self.use_actor_rnn:
|
||||
return_ = return_._replace(actor= EstimatorActorHiddenState(
|
||||
self.memory_s.hidden_states,
|
||||
return_.actor,
|
||||
))
|
||||
return return_
|
||||
|
||||
class StateAc(EstimatorMixin, ActorCritic):
|
||||
pass
|
||||
|
||||
class StateAcRecurrent(EstimatorMixin, ActorCriticRecurrent):
|
||||
pass
|
|
@ -6,10 +6,12 @@ import time
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensorboardX import SummaryWriter
|
||||
from tabulate import tabulate
|
||||
|
||||
from rsl_rl.modules import build_actor_critic
|
||||
from rsl_rl.runners.demonstration import DemonstrationSaver
|
||||
from rsl_rl.algorithms.tppo import GET_TEACHER_ACT_PROB_FUNC
|
||||
from rsl_rl.algorithms.tppo import GET_PROB_FUNC
|
||||
|
||||
class DaggerSaver(DemonstrationSaver):
|
||||
""" This demonstration saver will rollout the trajectory by running the student policy
|
||||
|
@ -21,6 +23,7 @@ class DaggerSaver(DemonstrationSaver):
|
|||
teacher_act_prob= "exp",
|
||||
update_times_scale= 5000,
|
||||
action_sample_std= 0.0, # if > 0, add Gaussian noise to the action in effort.
|
||||
log_to_tensorboard= False, # if True, log the rollout episode info to tensorboard
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -32,9 +35,18 @@ class DaggerSaver(DemonstrationSaver):
|
|||
self.teacher_act_prob = teacher_act_prob
|
||||
self.update_times_scale = update_times_scale
|
||||
self.action_sample_std = action_sample_std
|
||||
self.log_to_tensorboard = log_to_tensorboard
|
||||
if self.log_to_tensorboard:
|
||||
self.tb_writer = SummaryWriter(
|
||||
log_dir= osp.join(
|
||||
self.training_policy_logdir,
|
||||
"_".join(["collector", *(osp.basename(self.save_dir).split("_")[:2])]),
|
||||
),
|
||||
flush_secs= 10,
|
||||
)
|
||||
|
||||
if isinstance(self.teacher_act_prob, str):
|
||||
self.teacher_act_prob = GET_TEACHER_ACT_PROB_FUNC(self.teacher_act_prob, update_times_scale)
|
||||
self.teacher_act_prob = GET_PROB_FUNC(self.teacher_act_prob, update_times_scale)
|
||||
else:
|
||||
self.__teacher_act_prob = self.teacher_act_prob
|
||||
self.teacher_act_prob = lambda x: self.__teacher_act_prob
|
||||
|
@ -47,6 +59,11 @@ class DaggerSaver(DemonstrationSaver):
|
|||
self.build_training_policy()
|
||||
return return_
|
||||
|
||||
def init_storage_buffer(self):
|
||||
return_ = super().init_storage_buffer()
|
||||
self.rollout_episode_infos = []
|
||||
return return_
|
||||
|
||||
def build_training_policy(self):
|
||||
""" Load the latest training policy model. """
|
||||
with open(osp.join(self.training_policy_logdir, "config.json"), "r") as f:
|
||||
|
@ -77,21 +94,26 @@ class DaggerSaver(DemonstrationSaver):
|
|||
time.sleep(0.1)
|
||||
self.training_policy.load_state_dict(loaded_dict["model_state_dict"])
|
||||
self.training_policy_iteration = loaded_dict["iter"]
|
||||
# override the action std in self.training_policy
|
||||
with torch.no_grad():
|
||||
if self.action_sample_std > 0:
|
||||
self.training_policy.std[:] = self.action_sample_std
|
||||
print("Training policy iteration: {}".format(self.training_policy_iteration))
|
||||
self.use_teacher_act_mask = torch.rand(self.env.num_envs) < self.teacher_act_prob(self.training_policy_iteration)
|
||||
|
||||
def get_transition(self):
|
||||
if self.use_critic_obs:
|
||||
teacher_actions = self.policy.act_inference(self.critic_obs)
|
||||
else:
|
||||
teacher_actions = self.policy.act_inference(self.obs)
|
||||
teacher_actions = self.get_policy_actions()
|
||||
actions = self.training_policy.act(self.obs)
|
||||
if self.action_sample_std > 0:
|
||||
actions += torch.randn_like(actions) * self.action_sample_std
|
||||
actions[self.use_teacher_act_mask] = teacher_actions[self.use_teacher_act_mask]
|
||||
n_obs, n_critic_obs, rewards, dones, infos = self.env.step(actions)
|
||||
# Use teacher actions to label the trajectory, no matter what the student policy does
|
||||
return teacher_actions, rewards, dones, infos, n_obs, n_critic_obs
|
||||
|
||||
def add_transition(self, step_i, infos):
|
||||
return_ = super().add_transition(step_i, infos)
|
||||
if "episode" in infos:
|
||||
self.rollout_episode_infos.append(infos["episode"])
|
||||
return return_
|
||||
|
||||
def policy_reset(self, dones):
|
||||
return_ = super().policy_reset(dones)
|
||||
|
@ -103,3 +125,33 @@ class DaggerSaver(DemonstrationSaver):
|
|||
""" Also check whether need to load the latest training policy model. """
|
||||
self.load_latest_training_policy()
|
||||
return super().check_stop()
|
||||
|
||||
def print_log(self):
|
||||
# Copy from runner logging mechanism. TODO: optimize these implementation into one.
|
||||
ep_table = []
|
||||
for key in self.rollout_episode_infos[0].keys():
|
||||
infotensor = torch.tensor([], device= self.env.device)
|
||||
for ep_info in self.rollout_episode_infos:
|
||||
if not isinstance(ep_info[key], torch.Tensor):
|
||||
ep_info[key] = torch.Tensor([ep_info[key]])
|
||||
if len(ep_info[key].shape) == 0:
|
||||
ep_info[key] = ep_info[key].unsqueeze(0)
|
||||
infotensor = torch.cat((infotensor, ep_info[key].to(self.env.device)))
|
||||
if "_max" in key:
|
||||
infotensor = infotensor[~infotensor.isnan()]
|
||||
value = torch.max(infotensor) if len(infotensor) > 0 else torch.tensor(float("nan"))
|
||||
elif "_min" in key:
|
||||
infotensor = infotensor[~infotensor.isnan()]
|
||||
value = torch.min(infotensor) if len(infotensor) > 0 else torch.tensor(float("nan"))
|
||||
else:
|
||||
value = torch.nanmean(infotensor)
|
||||
if self.log_to_tensorboard:
|
||||
self.tb_writer.add_scalar('Episode/' + key, value, self.training_policy_iteration)
|
||||
ep_table.append(("Episode/" + key, value.detach().cpu().item()))
|
||||
# NOTE: assuming dagger trainner's iteration is always faster than collector's iteration
|
||||
# Otherwise, the training_policy will not be updated.
|
||||
self.training_policy_iteration += 1
|
||||
print("Sampling saved for training policy iteration: {}".format(self.training_policy_iteration))
|
||||
print(tabulate(ep_table))
|
||||
self.rollout_episode_infos = []
|
||||
return super().print_log()
|
||||
|
|
|
@ -2,6 +2,7 @@ import os
|
|||
import os.path as osp
|
||||
import json
|
||||
import pickle
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -21,12 +22,14 @@ class DemonstrationSaver:
|
|||
success_traj_only = False, # if true, the trajectory terminated no by timeout will be dumped.
|
||||
use_critic_obs= False,
|
||||
obs_disassemble_mapping= None,
|
||||
demo_by_sample= False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
obs_disassemble_mapping (dict):
|
||||
If set, the obs segment will be compressed using given type.
|
||||
example: {"forward_depth": "normalized_image", "forward_rgb": "normalized_image"}
|
||||
obs_disassemble_mapping (dict): If set, the obs segment will be compressed using given
|
||||
type. example: {"forward_depth": "normalized_image", "forward_rgb": "normalized_image"}
|
||||
demo_by_sample (bool): # if True, the action will be sampled (policy.act) from the
|
||||
policy instead of using the mean (policy.act_inference).
|
||||
"""
|
||||
self.env = env
|
||||
self.policy = policy
|
||||
|
@ -38,6 +41,7 @@ class DemonstrationSaver:
|
|||
self.use_critic_obs = use_critic_obs
|
||||
self.success_traj_only = success_traj_only
|
||||
self.obs_disassemble_mapping = obs_disassemble_mapping
|
||||
self.demo_by_sample = demo_by_sample
|
||||
self.RolloutStorageCls = RolloutStorage
|
||||
|
||||
def init_traj_handlers(self):
|
||||
|
@ -106,11 +110,19 @@ class DemonstrationSaver:
|
|||
self.policy_reset(dones)
|
||||
self.obs, self.critic_obs = n_obs, n_critic_obs
|
||||
|
||||
def get_transition(self):
|
||||
if self.use_critic_obs:
|
||||
def get_policy_actions(self):
|
||||
if self.use_critic_obs and self.demo_by_sample:
|
||||
actions = self.policy.act(self.critic_obs)
|
||||
elif self.use_critic_obs:
|
||||
actions = self.policy.act_inference(self.critic_obs)
|
||||
elif self.demo_by_sample:
|
||||
actions = self.policy.act(self.obs)
|
||||
else:
|
||||
actions = self.policy.act_inference(self.obs)
|
||||
return actions
|
||||
|
||||
def get_transition(self):
|
||||
actions = self.get_policy_actions()
|
||||
n_obs, n_critic_obs, rewards, dones, infos = self.env.step(actions)
|
||||
return actions, rewards, dones, infos, n_obs, n_critic_obs
|
||||
|
||||
|
@ -153,10 +165,16 @@ class DemonstrationSaver:
|
|||
pickle.dump(trajectory, f)
|
||||
self.dumped_traj_lengths[env_i] += step_slice.stop - step_slice.start
|
||||
self.total_timesteps += step_slice.stop - step_slice.start
|
||||
with open(osp.join(self.save_dir, "metadata.json"), "w") as f:
|
||||
|
||||
def dump_metadata(self):
|
||||
self.metadata["total_timesteps"] = self.total_timesteps.item() if isinstance(self.total_timesteps, np.int64) else self.total_timesteps
|
||||
self.metadata["total_trajectories"] = self.total_traj_completed
|
||||
with open(osp.join(self.save_dir, 'metadata.json'), 'w') as f:
|
||||
json.dump(self.metadata, f, indent= 4)
|
||||
|
||||
def wrap_up_trajectory(self, env_i, step_slice):
|
||||
# wrap up from the rollout_storage based on `step_slice`. Thus, `step_slice` must include
|
||||
# the `done` step if exist.
|
||||
trajectory = dict(
|
||||
privileged_observations= self.rollout_storage.privileged_observations[step_slice, env_i].cpu().numpy(),
|
||||
actions= self.rollout_storage.actions[step_slice, env_i].cpu().numpy(),
|
||||
|
@ -224,8 +242,7 @@ class DemonstrationSaver:
|
|||
self.dump_to_file(rollout_env_i, slice(0, self.rollout_storage.num_transitions_per_env))
|
||||
else:
|
||||
start_idx = 0
|
||||
di = 0
|
||||
while di < done_idxs.shape[0]:
|
||||
for di in range(done_idxs.shape[0]):
|
||||
end_idx = done_idxs[di].item()
|
||||
|
||||
# dump and update the traj_idx for this env
|
||||
|
@ -233,7 +250,9 @@ class DemonstrationSaver:
|
|||
self.update_traj_handler(rollout_env_i, slice(start_idx, end_idx+1))
|
||||
|
||||
start_idx = end_idx + 1
|
||||
di += 1
|
||||
if start_idx < self.rollout_storage.num_transitions_per_env:
|
||||
self.dump_to_file(rollout_env_i, slice(start_idx, self.rollout_storage.num_transitions_per_env))
|
||||
self.dump_metadata()
|
||||
|
||||
def collect_and_save(self, config= None):
|
||||
""" Run the rolllout to collect the demonstration data and save it to the file """
|
||||
|
@ -276,6 +295,11 @@ class DemonstrationSaver:
|
|||
|
||||
def print_log(self):
|
||||
""" print the log """
|
||||
self.print_log_time = time.monotonic()
|
||||
if hasattr(self, "last_print_log_time"):
|
||||
print("time elapsed:", self.print_log_time - self.last_print_log_time)
|
||||
print("throughput:", self.total_timesteps / (self.print_log_time - self.last_print_log_time))
|
||||
self.last_print_log_time = self.print_log_time
|
||||
print("total_timesteps:", self.total_timesteps)
|
||||
print("total_trajectories", self.total_traj_completed)
|
||||
|
||||
|
@ -292,8 +316,5 @@ class DemonstrationSaver:
|
|||
os.rmdir(traj_dir)
|
||||
for timestep_count in self.dumped_traj_lengths:
|
||||
self.total_timesteps += timestep_count
|
||||
self.metadata["total_timesteps"] = self.total_timesteps.item() if isinstance(self.total_timesteps, np.int64) else self.total_timesteps
|
||||
self.metadata["total_trajectories"] = self.total_traj_completed
|
||||
with open(osp.join(self.save_dir, 'metadata.json'), 'w') as f:
|
||||
json.dump(self.metadata, f, indent= 4)
|
||||
self.dump_metadata()
|
||||
print(f"Saved dataset in {self.save_dir}")
|
||||
|
|
|
@ -33,12 +33,13 @@ import os
|
|||
from collections import deque
|
||||
import statistics
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tensorboardX import SummaryWriter
|
||||
import torch
|
||||
|
||||
import rsl_rl.algorithms as algorithms
|
||||
import rsl_rl.modules as modules
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.utils import ckpt_manipulator
|
||||
|
||||
|
||||
class OnPolicyRunner:
|
||||
|
@ -99,12 +100,15 @@ class OnPolicyRunner:
|
|||
cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
|
||||
cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
|
||||
|
||||
print("Initialization done, start learning.")
|
||||
print("NOTE: you may see a bunch of `NaN or Inf found in input tensor` once and appears in the log. Just ignore it if it does not affect the performance.")
|
||||
start_iter = self.current_learning_iteration
|
||||
tot_iter = self.current_learning_iteration + num_learning_iterations
|
||||
tot_start_time = time.time()
|
||||
start = time.time()
|
||||
while self.current_learning_iteration < tot_iter:
|
||||
start = time.time()
|
||||
# Rollout
|
||||
with torch.inference_mode():
|
||||
with torch.inference_mode(self.cfg.get("inference_mode_rollout", True)):
|
||||
for i in range(self.num_steps_per_env):
|
||||
obs, critic_obs, rewards, dones, infos = self.rollout_step(obs, critic_obs)
|
||||
|
||||
|
@ -137,6 +141,7 @@ class OnPolicyRunner:
|
|||
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))
|
||||
ep_infos.clear()
|
||||
self.current_learning_iteration = self.current_learning_iteration + 1
|
||||
start = time.time()
|
||||
|
||||
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))
|
||||
|
||||
|
@ -145,12 +150,12 @@ class OnPolicyRunner:
|
|||
obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
|
||||
critic_obs = privileged_obs if privileged_obs is not None else obs
|
||||
obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
|
||||
self.alg.process_env_step(rewards, dones, infos)
|
||||
self.alg.process_env_step(rewards, dones, infos, obs, critic_obs)
|
||||
return obs, critic_obs, rewards, dones, infos
|
||||
|
||||
def log(self, locs, width=80, pad=35):
|
||||
self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
|
||||
self.tot_time += locs['collection_time'] + locs['learn_time']
|
||||
self.tot_time = time.time() - locs['tot_start_time']
|
||||
iteration_time = locs['collection_time'] + locs['learn_time']
|
||||
|
||||
ep_string = f''
|
||||
|
@ -164,7 +169,14 @@ class OnPolicyRunner:
|
|||
if len(ep_info[key].shape) == 0:
|
||||
ep_info[key] = ep_info[key].unsqueeze(0)
|
||||
infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
|
||||
value = torch.mean(infotensor)
|
||||
if "_max" in key:
|
||||
infotensor = infotensor[~infotensor.isnan()]
|
||||
value = torch.max(infotensor) if len(infotensor) > 0 else torch.tensor(float("nan"))
|
||||
elif "_min" in key:
|
||||
infotensor = infotensor[~infotensor.isnan()]
|
||||
value = torch.min(infotensor) if len(infotensor) > 0 else torch.tensor(float("nan"))
|
||||
else:
|
||||
value = torch.nanmean(infotensor)
|
||||
self.writer.add_scalar('Episode/' + key, value, self.current_learning_iteration)
|
||||
ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
|
||||
mean_std = self.alg.actor_critic.action_std.mean()
|
||||
|
@ -181,10 +193,12 @@ class OnPolicyRunner:
|
|||
self.writer.add_scalar('Perf/collection time', locs['collection_time'], self.current_learning_iteration)
|
||||
self.writer.add_scalar('Perf/learning_time', locs['learn_time'], self.current_learning_iteration)
|
||||
self.writer.add_scalar('Perf/gpu_allocated', torch.cuda.memory_allocated(self.device) / 1024 ** 3, self.current_learning_iteration)
|
||||
self.writer.add_scalar('Perf/gpu_occupied', torch.cuda.mem_get_info(self.device)[1] / 1024 ** 3, self.current_learning_iteration)
|
||||
self.writer.add_scalar('Perf/gpu_global_free_mem', torch.cuda.mem_get_info(self.device)[0] / 1024 ** 3, self.current_learning_iteration)
|
||||
self.writer.add_scalar('Perf/gpu_total', torch.cuda.mem_get_info(self.device)[1] / 1024 ** 3, self.current_learning_iteration)
|
||||
self.writer.add_scalar('Train/mean_reward_each_timestep', statistics.mean(locs['rframebuffer']), self.current_learning_iteration)
|
||||
if len(locs['rewbuffer']) > 0:
|
||||
self.writer.add_scalar('Train/mean_reward', statistics.mean(locs['rewbuffer']), self.current_learning_iteration)
|
||||
self.writer.add_scalar('Train/ratio_above_mean_reward', statistics.mean([(1. if rew > statistics.mean(locs['rewbuffer']) else 0) for rew in locs['rewbuffer']]), self.current_learning_iteration)
|
||||
self.writer.add_scalar('Train/mean_episode_length', statistics.mean(locs['lenbuffer']), self.current_learning_iteration)
|
||||
self.writer.add_scalar('Train/mean_reward/time', statistics.mean(locs['rewbuffer']), self.tot_time)
|
||||
self.writer.add_scalar('Train/mean_episode_length/time', statistics.mean(locs['lenbuffer']), self.tot_time)
|
||||
|
@ -192,29 +206,37 @@ class OnPolicyRunner:
|
|||
str = f" \033[1m Learning iteration {self.current_learning_iteration}/{locs['tot_iter']} \033[0m "
|
||||
|
||||
if len(locs['rewbuffer']) > 0:
|
||||
log_string = (f"""{'#' * width}\n"""
|
||||
f"""{str.center(width, ' ')}\n\n"""
|
||||
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
|
||||
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
|
||||
f"""{'Value function loss:':>{pad}} {locs["losses"]['value_loss']:.4f}\n"""
|
||||
f"""{'Surrogate loss:':>{pad}} {locs["losses"]['surrogate_loss']:.4f}\n"""
|
||||
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
|
||||
f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
|
||||
f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
|
||||
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
|
||||
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n"""
|
||||
)
|
||||
log_string = (
|
||||
f"""{'#' * width}\n"""
|
||||
f"""{str.center(width, ' ')}\n\n"""
|
||||
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
|
||||
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
|
||||
)
|
||||
for k, v in locs["losses"].items():
|
||||
log_string += f"""{k:>{pad}} {v.item():.4f}\n"""
|
||||
log_string += (
|
||||
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
|
||||
f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
|
||||
f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
|
||||
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
|
||||
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n"""
|
||||
)
|
||||
else:
|
||||
log_string = (f"""{'#' * width}\n"""
|
||||
f"""{str.center(width, ' ')}\n\n"""
|
||||
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
|
||||
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
|
||||
f"""{'Value function loss:':>{pad}} {locs["losses"]['value_loss']:.4f}\n"""
|
||||
f"""{'Surrogate loss:':>{pad}} {locs["losses"]['surrogate_loss']:.4f}\n"""
|
||||
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
|
||||
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
|
||||
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n"""
|
||||
)
|
||||
log_string = (
|
||||
f"""{'#' * width}\n"""
|
||||
f"""{str.center(width, ' ')}\n\n"""
|
||||
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
|
||||
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
|
||||
)
|
||||
for k. v in locs["losses"].items():
|
||||
log_string += f"""{k:>{pad}} {v.item():.4f}\n"""
|
||||
log_string += (
|
||||
f"""{'Value function loss:':>{pad}} {locs["losses"]['value_loss']:.4f}\n"""
|
||||
f"""{'Surrogate loss:':>{pad}} {locs["losses"]['surrogate_loss']:.4f}\n"""
|
||||
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
|
||||
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
|
||||
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n"""
|
||||
)
|
||||
|
||||
log_string += ep_string
|
||||
log_string += (f"""{'-' * width}\n"""
|
||||
|
@ -226,29 +248,30 @@ class OnPolicyRunner:
|
|||
print(log_string)
|
||||
|
||||
def save(self, path, infos=None):
|
||||
run_state_dict = {
|
||||
'model_state_dict': self.alg.actor_critic.state_dict(),
|
||||
'optimizer_state_dict': self.alg.optimizer.state_dict(),
|
||||
run_state_dict = self.alg.state_dict()
|
||||
run_state_dict.update({
|
||||
'iter': self.current_learning_iteration,
|
||||
'infos': infos,
|
||||
}
|
||||
if hasattr(self.alg, "lr_scheduler"):
|
||||
run_state_dict["lr_scheduler_state_dict"] = self.alg.lr_scheduler.state_dict()
|
||||
})
|
||||
torch.save(run_state_dict, path)
|
||||
|
||||
def load(self, path, load_optimizer=True):
|
||||
loaded_dict = torch.load(path)
|
||||
self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
|
||||
if load_optimizer and "optimizer_state_dict" in loaded_dict:
|
||||
self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
|
||||
if "lr_scheduler_state_dict" in loaded_dict:
|
||||
if not hasattr(self.alg, "lr_scheduler"):
|
||||
print("Warning: lr_scheduler_state_dict found in checkpoint but no lr_scheduler in algorithm. Ignoring.")
|
||||
else:
|
||||
self.alg.lr_scheduler.load_state_dict(loaded_dict["lr_scheduler_state_dict"])
|
||||
elif hasattr(self.alg, "lr_scheduler"):
|
||||
print("Warning: lr_scheduler_state_dict not found in checkpoint but lr_scheduler in algorithm. Ignoring.")
|
||||
if self.cfg.get("ckpt_manipulator", False):
|
||||
# suppose to be a string specifying which function to use
|
||||
print("\033[1;36m Warning: using a hacky way to load the model. \033[0m")
|
||||
loaded_dict = getattr(ckpt_manipulator, self.cfg["ckpt_manipulator"])(
|
||||
loaded_dict,
|
||||
self.alg.state_dict(),
|
||||
)
|
||||
print("\033[1;36m Done: using a hacky way to load the model. \033[0m")
|
||||
self.alg.load_state_dict(loaded_dict)
|
||||
self.current_learning_iteration = loaded_dict['iter']
|
||||
if self.cfg.get("ckpt_manipulator", False):
|
||||
try:
|
||||
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))
|
||||
except:
|
||||
print("\033[1;36m Save manipulated checkpoint failed, ignored... \033[0m")
|
||||
return loaded_dict['infos']
|
||||
|
||||
def get_inference_policy(self, device=None):
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import torch
|
||||
|
||||
from rsl_rl.runners.on_policy_runner import OnPolicyRunner
|
||||
from rsl_rl.storage.rollout_dataset import RolloutDataset
|
||||
from rsl_rl.storage.rollout_files.rollout_dataset import RolloutDataset
|
||||
|
||||
class TwoStageRunner(OnPolicyRunner):
|
||||
""" A runner that have a pretrain stage which is used to collect demonstration data """
|
||||
|
@ -17,7 +17,7 @@ class TwoStageRunner(OnPolicyRunner):
|
|||
self.rollout_dataset = RolloutDataset(
|
||||
**self.cfg["pretrain_dataset"],
|
||||
num_envs= self.env.num_envs,
|
||||
rl_device= self.alg.device,
|
||||
device= self.alg.device,
|
||||
)
|
||||
|
||||
def rollout_step(self, obs, critic_obs):
|
||||
|
@ -31,8 +31,8 @@ class TwoStageRunner(OnPolicyRunner):
|
|||
if not transition is None:
|
||||
self.alg.collect_transition_from_dataset(transition, infos)
|
||||
return (
|
||||
transition.observation,
|
||||
transition.privileged_observation,
|
||||
transition.next_observation,
|
||||
transition.next_privileged_observation,
|
||||
transition.reward,
|
||||
transition.done,
|
||||
infos,
|
||||
|
|
|
@ -1,359 +0,0 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
import pickle
|
||||
from collections import namedtuple, OrderedDict
|
||||
import json
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import IterableDataset, get_worker_info
|
||||
|
||||
import rsl_rl.utils.data_compresser as compresser
|
||||
|
||||
class RolloutDataset(IterableDataset):
|
||||
Transitions = namedtuple("Transitions", [
|
||||
"observation", "privileged_observation", "action", "reward", "done",
|
||||
])
|
||||
|
||||
def __init__(self,
|
||||
data_dir= None,
|
||||
scan_dir= None,
|
||||
num_envs= 1,
|
||||
dataset_loops: int= 1,
|
||||
subset_traj= None, # (start_idx, end_idx) as a slice
|
||||
random_shuffle_traj_order= False, # If True, the traj_data will be loaded directoy to rl_device in a random order
|
||||
keep_latest_ratio= 1.0, # If < 1., only keeps a certain ratio of the latest trajectories
|
||||
keep_latest_n_trajs= 0, # If > 0 and more than n_trajectories, ignores keep_latest_ratio and keeps the latest n trajectories.
|
||||
starting_frame_range= [0, 1], # if set, the starting timestep will be uniformly chose from this, when each new trajectory is loaded.
|
||||
# if sampled starting frame is bigger than the trajectory length, starting frame will be 0
|
||||
load_data_to_device= True, # If True, the traj_data will be loaded directoy to rl_device rather than np array
|
||||
rl_device= "cpu",
|
||||
):
|
||||
""" choose data_dir or scan_dir, but not both. If scan_dir is chosen, the dataset will scan the
|
||||
directory and treat each direct subdirectory as a dataset everytime it is initialized.
|
||||
"""
|
||||
self.data_dir = data_dir
|
||||
self.scan_dir = scan_dir
|
||||
self.num_envs = num_envs
|
||||
self.max_loops = dataset_loops
|
||||
self.subset_traj = subset_traj
|
||||
self.random_shuffle_traj_order = random_shuffle_traj_order
|
||||
self.keep_latest_ratio = keep_latest_ratio
|
||||
self.keep_latest_n_trajs = keep_latest_n_trajs
|
||||
self.starting_frame_range = starting_frame_range
|
||||
self.load_data_to_device = load_data_to_device
|
||||
self.rl_device = rl_device
|
||||
|
||||
# check arguments
|
||||
assert not (self.data_dir is None and self.scan_dir is None), "data_dir and scan_dir cannot be both None"
|
||||
|
||||
self.num_looped = 0
|
||||
|
||||
def initialize(self):
|
||||
self.load_dataset_directory()
|
||||
if self.subset_traj is not None:
|
||||
self.unused_traj_dirs = self.unused_traj_dirs[self.subset_traj[0]: self.subset_traj[1]]
|
||||
if self.keep_latest_ratio < 1. or self.keep_latest_n_trajs > 0:
|
||||
self.unused_traj_dirs = sorted(
|
||||
self.unused_traj_dirs,
|
||||
key= lambda x: os.stat(x).st_ctime,
|
||||
)
|
||||
if self.keep_latest_n_trajs > 0:
|
||||
self.unused_traj_dirs = self.unused_traj_dirs[-self.keep_latest_n_trajs:]
|
||||
else:
|
||||
self.unused_traj_dirs = self.unused_traj_dirs[int(len(self.unused_traj_dirs) * self.keep_latest_ratio):]
|
||||
print("Using a subset of trajectories, total number of trajectories: ", len(self.unused_traj_dirs))
|
||||
if self.random_shuffle_traj_order:
|
||||
random.shuffle(self.unused_traj_dirs)
|
||||
|
||||
# attributes that handles trajectory files for each env
|
||||
self.current_traj_dirs = [None for _ in range(self.num_envs)]
|
||||
self.trajectory_files = [[] for _ in range(self.num_envs)]
|
||||
self.traj_file_idxs = np.zeros(self.num_envs, dtype= np.int32)
|
||||
self.traj_step_idxs = np.zeros(self.num_envs, dtype= np.int32)
|
||||
self.traj_datas = [None for _ in range(self.num_envs)]
|
||||
|
||||
env_idx = 0
|
||||
while env_idx < self.num_envs:
|
||||
if len(self.unused_traj_dirs) == 0:
|
||||
print("Not enough trajectories, waiting to re-initialize. Press Enter to continue....")
|
||||
input()
|
||||
self.initialize()
|
||||
return
|
||||
starting_frame = torch.randint(self.starting_frame_range[0], self.starting_frame_range[1], (1,)).item()
|
||||
update_result = self.update_traj_handle(env_idx, self.unused_traj_dirs.pop(0), starting_frame)
|
||||
if update_result:
|
||||
env_idx += 1
|
||||
|
||||
self.dataset_drained = False
|
||||
|
||||
def update_traj_handle(self, env_idx, traj_dir, starting_step_idx= 0):
|
||||
""" Load and update the trajectory handle for a given env_idx.
|
||||
Also update traj_step_idxs.
|
||||
Return whether the trajectory is successfully loaded
|
||||
"""
|
||||
self.current_traj_dirs[env_idx] = traj_dir
|
||||
try:
|
||||
self.trajectory_files[env_idx] = sorted(
|
||||
os.listdir(self.current_traj_dirs[env_idx]),
|
||||
key= lambda x: int(x.split("_")[1]),
|
||||
)
|
||||
self.traj_file_idxs[env_idx] = 0
|
||||
except:
|
||||
self.nullify_traj_handles(env_idx)
|
||||
return False
|
||||
self.traj_datas[env_idx] = self.load_traj_data(
|
||||
env_idx,
|
||||
self.traj_file_idxs[env_idx],
|
||||
new_episode= True,
|
||||
)
|
||||
if self.traj_datas[env_idx] is None:
|
||||
self.nullify_traj_handles(env_idx)
|
||||
return False
|
||||
|
||||
# The number in the file name is the timestep slice
|
||||
current_file_max_timestep = int(self.trajectory_files[env_idx][self.traj_file_idxs[env_idx]].split(".")[0].split("_")[2]) - 1
|
||||
while current_file_max_timestep < starting_step_idx:
|
||||
self.traj_file_idxs[env_idx] += 1
|
||||
if self.traj_file_idxs[env_idx] >= len(self.trajectory_files[env_idx]):
|
||||
# trajectory length is shorter than starting_step_idx, set starting_step_idx to 0
|
||||
starting_step_idx = 0
|
||||
self.traj_file_idxs[env_idx] = 0
|
||||
break
|
||||
current_file_max_timestep = int(self.trajectory_files[env_idx][self.traj_file_idxs[env_idx]].split(".")[0].split("_")[2]) - 1
|
||||
|
||||
current_file_min_step = int(self.trajectory_files[env_idx][self.traj_file_idxs[env_idx]].split(".")[0].split("_")[1])
|
||||
self.traj_step_idxs[env_idx] = starting_step_idx - current_file_min_step
|
||||
if self.traj_file_idxs[env_idx] > 0:
|
||||
# reload the traj_data because traj_file_idxs is updated
|
||||
self.traj_datas[env_idx] = self.load_traj_data(
|
||||
env_idx,
|
||||
self.traj_file_idxs[env_idx],
|
||||
new_episode= True,
|
||||
)
|
||||
if self.traj_datas[env_idx] is None:
|
||||
self.nullify_traj_handles(env_idx)
|
||||
return False
|
||||
return True
|
||||
|
||||
def nullify_traj_handles(self, env_idx):
|
||||
self.current_traj_dirs[env_idx] = ""
|
||||
self.trajectory_files[env_idx] = []
|
||||
self.traj_file_idxs[env_idx] = 0
|
||||
self.traj_step_idxs[env_idx] = 0
|
||||
self.traj_datas[env_idx] = None
|
||||
|
||||
def load_dataset_directory(self):
|
||||
if self.scan_dir is not None:
|
||||
if not osp.isdir(self.scan_dir):
|
||||
print("RolloutDataset: scan_dir {} does not exist, creating...".format(self.scan_dir))
|
||||
os.makedirs(self.scan_dir)
|
||||
self.data_dir = sorted([
|
||||
osp.join(self.scan_dir, x) \
|
||||
for x in os.listdir(self.scan_dir) \
|
||||
if osp.isdir(osp.join(self.scan_dir, x)) and osp.isfile(osp.join(self.scan_dir, x, "metadata.json"))
|
||||
])
|
||||
if isinstance(self.data_dir, list):
|
||||
total_timesteps = 0
|
||||
self.unused_traj_dirs = []
|
||||
for data_dir in self.data_dir:
|
||||
try:
|
||||
new_trajectories = sorted([
|
||||
osp.join(data_dir, x) \
|
||||
for x in os.listdir(data_dir) \
|
||||
if x.startswith("trajectory_") and len(os.listdir(osp.join(data_dir, x))) > 0
|
||||
], key= lambda x: int(x.split("_")[-1]))
|
||||
except:
|
||||
continue
|
||||
self.unused_traj_dirs.extend(new_trajectories)
|
||||
try:
|
||||
with open(osp.join(data_dir, "metadata.json"), "r") as f:
|
||||
self.metadata = json.load(f, object_pairs_hook= OrderedDict)
|
||||
total_timesteps += self.metadata["total_timesteps"]
|
||||
except:
|
||||
pass # skip
|
||||
print("RolloutDataset: Loaded data from multiple directories. The metadata is from the last directory.")
|
||||
print("RolloutDataset: Total number of timesteps: ", total_timesteps)
|
||||
print("RolloutDataset: Total number of trajectories: ", len(self.unused_traj_dirs))
|
||||
else:
|
||||
self.unused_traj_dirs = sorted([
|
||||
osp.join(self.data_dir, x) \
|
||||
for x in os.listdir(self.data_dir) \
|
||||
if x.startswith("trajectory_") and len(os.listdir(osp.join(self.data_dir, x))) > 0
|
||||
], key= lambda x: int(x.split("_")[-1]))
|
||||
with open(osp.join(self.data_dir, "metadata.json"), "r") as f:
|
||||
self.metadata = json.load(f, object_pairs_hook= OrderedDict)
|
||||
|
||||
# check if this dataset is initialized in worker process
|
||||
worker_info = get_worker_info()
|
||||
if worker_info is not None:
|
||||
self.dataset_loops = 1 # Let the sampler handle the loops
|
||||
worker_id = worker_info.id
|
||||
num_workers = worker_info.num_workers
|
||||
trajs_per_worker = len(self.unused_traj_dirs) // num_workers
|
||||
self.unused_traj_dirs = self.unused_traj_dirs[worker_id * trajs_per_worker: (worker_id + 1) * trajs_per_worker]
|
||||
if worker_id == num_workers - 1:
|
||||
self.unused_traj_dirs.extend(self.unused_traj_dirs[:(len(self.unused_traj_dirs) % num_workers)])
|
||||
print("RolloutDataset: Worker {} of {} initialized with {} trajectories".format(
|
||||
worker_id, num_workers, len(self.unused_traj_dirs)
|
||||
))
|
||||
|
||||
def assmeble_obs_components(self, traj_data):
|
||||
assert "obs_segments" in self.metadata, "Corrupted metadata, obs_segments not found in metadata"
|
||||
observations = []
|
||||
for component_name in self.metadata["obs_segments"].keys():
|
||||
obs_component = traj_data.pop("obs_" + component_name)
|
||||
if component_name in self.metadata["obs_disassemble_mapping"]:
|
||||
obs_component = getattr(
|
||||
compresser,
|
||||
"decompress_" + self.metadata["obs_disassemble_mapping"][component_name],
|
||||
)(obs_component)
|
||||
observations.append(obs_component)
|
||||
traj_data["observations"] = np.concatenate(observations, axis= -1) # (n_steps, d_obs)
|
||||
return traj_data
|
||||
|
||||
def load_traj_data(self, env_idx, traj_file_idx, new_episode= False):
|
||||
""" If new_episode, set the 0-th frame to done, making sure the agent is reset.
|
||||
"""
|
||||
traj_dir = self.current_traj_dirs[env_idx]
|
||||
try:
|
||||
with open(osp.join(traj_dir, self.trajectory_files[env_idx][traj_file_idx]), "rb") as f:
|
||||
traj_data = pickle.load(f)
|
||||
except:
|
||||
try:
|
||||
traj_file = osp.join(traj_dir, self.trajectory_files[env_idx][traj_file_idx])
|
||||
print("Failed to load", traj_file)
|
||||
except:
|
||||
print("Failed to load file")
|
||||
# The caller will know that the file is abscent, then switch to a new trajectory
|
||||
return None
|
||||
# connect the observation components if they are disassambled in pickle files
|
||||
if "obs_disassemble_mapping" in self.metadata:
|
||||
traj_data = self.assmeble_obs_components(traj_data)
|
||||
if self.load_data_to_device:
|
||||
for data_key, data_val in traj_data.items():
|
||||
traj_data[data_key] = torch.from_numpy(data_val).to(self.rl_device)
|
||||
if new_episode:
|
||||
# add done flag to the 0-th step of newly loaded trajectory
|
||||
traj_data["dones"][0] = True
|
||||
return traj_data
|
||||
|
||||
def get_transition_batch(self):
|
||||
if not hasattr(self, "dataset_drained"):
|
||||
# initialize the dataset if it is not used as a iterator
|
||||
self.initialize()
|
||||
observations = []
|
||||
privileged_observations = []
|
||||
actions = []
|
||||
rewards = []
|
||||
dones = []
|
||||
time_outs = []
|
||||
if self.dataset_drained:
|
||||
return None, None
|
||||
for env_idx in range(self.num_envs):
|
||||
traj_data = self.traj_datas[env_idx]
|
||||
traj_step_idx = self.traj_step_idxs[env_idx]
|
||||
observations.append(traj_data["observations"][traj_step_idx])
|
||||
privileged_observations.append(traj_data["privileged_observations"][traj_step_idx])
|
||||
actions.append(traj_data["actions"][traj_step_idx])
|
||||
rewards.append(traj_data["rewards"][traj_step_idx])
|
||||
dones.append(traj_data["dones"][traj_step_idx])
|
||||
if "timeouts" in traj_data: time_outs.append(traj_data["timeouts"][traj_step_idx])
|
||||
self.traj_step_idxs[env_idx] += 1
|
||||
traj_update_result = self.update_traj_data_if_needed(env_idx)
|
||||
if traj_update_result == "drained":
|
||||
self.dataset_drained = True
|
||||
return None, None
|
||||
elif traj_update_result == "new_traj":
|
||||
dones[-1][:] = True
|
||||
if torch.is_tensor(observations[0]):
|
||||
observations = torch.stack(observations)
|
||||
else:
|
||||
observations = torch.from_numpy(np.stack(observations)).to(self.rl_device)
|
||||
if torch.is_tensor(privileged_observations[0]):
|
||||
privileged_observations = torch.stack(privileged_observations)
|
||||
else:
|
||||
privileged_observations = torch.from_numpy(np.stack(privileged_observations)).to(self.rl_device)
|
||||
if torch.is_tensor(actions[0]):
|
||||
actions = torch.stack(actions)
|
||||
else:
|
||||
actions = torch.from_numpy(np.stack(actions)).to(self.rl_device)
|
||||
if torch.is_tensor(rewards[0]):
|
||||
rewards = torch.stack(rewards).squeeze(-1) # to remove the last dimension as the simulator env
|
||||
else:
|
||||
rewards = torch.from_numpy(np.stack(rewards)).to(self.rl_device).squeeze(-1)
|
||||
if torch.is_tensor(dones[0]):
|
||||
dones = torch.stack(dones).to(bool).squeeze(-1)
|
||||
else:
|
||||
dones = torch.from_numpy(np.stack(dones)).to(self.rl_device).to(bool).squeeze(-1)
|
||||
infos = dict()
|
||||
if time_outs:
|
||||
if torch.is_tensor(time_outs[0]):
|
||||
infos["time_outs"] = torch.stack(time_outs)
|
||||
else:
|
||||
infos["time_outs"] = torch.from_numpy(np.stack(time_outs)).to(self.rl_device)
|
||||
infos["num_looped"] = self.num_looped
|
||||
return self.Transitions(
|
||||
observation= observations,
|
||||
privileged_observation= privileged_observations,
|
||||
action= actions,
|
||||
reward= rewards,
|
||||
done= dones,
|
||||
), infos
|
||||
|
||||
def update_traj_data_if_needed(self, env_idx):
|
||||
""" Return 'new_file', 'new_traj', 'drained', or None
|
||||
"""
|
||||
traj_data = self.traj_datas[env_idx]
|
||||
if self.traj_step_idxs[env_idx] >= len(traj_data["rewards"]):
|
||||
# to next file
|
||||
self.traj_file_idxs[env_idx] += 1
|
||||
self.traj_step_idxs[env_idx] = 0
|
||||
traj_data = None
|
||||
new_episode = False
|
||||
while traj_data is None:
|
||||
if self.traj_file_idxs[env_idx] >= len(self.trajectory_files[env_idx]):
|
||||
# to next trajectory
|
||||
if len(self.unused_traj_dirs) == 0 or not osp.isdir(self.unused_traj_dirs[0]):
|
||||
if self.max_loops > 0 and self.num_looped >= self.max_loops:
|
||||
return 'drained'
|
||||
else:
|
||||
self.num_looped += 1
|
||||
self.initialize()
|
||||
return 'new_traj'
|
||||
starting_frame = torch.randint(self.starting_frame_range[0], self.starting_frame_range[1], (1,)).item()
|
||||
self.update_traj_handle(env_idx, self.unused_traj_dirs.pop(0), starting_frame)
|
||||
traj_data = self.traj_datas[env_idx]
|
||||
else:
|
||||
traj_data = self.load_traj_data(
|
||||
env_idx,
|
||||
self.traj_file_idxs[env_idx],
|
||||
new_episode= new_episode,
|
||||
)
|
||||
if traj_data is None:
|
||||
self.nullify_traj_handles(env_idx)
|
||||
else:
|
||||
self.traj_datas[env_idx] = traj_data
|
||||
return 'new_file'
|
||||
return None
|
||||
|
||||
def set_traj_idx(self, traj_idx, env_idx= 0):
|
||||
""" Allow users to select a specific trajectory to start from """
|
||||
self.current_traj_dirs[env_idx] = self.unused_traj_dirs[traj_idx]
|
||||
self.traj_file_idxs[env_idx] = 0
|
||||
self.traj_step_idxs[env_idx] = 0
|
||||
self.trajectory_files[env_idx] = sorted(
|
||||
os.listdir(self.current_traj_dirs[env_idx]),
|
||||
key= lambda x: int(x.split("_")[1]),
|
||||
)
|
||||
self.traj_datas[env_idx] = self.load_traj_data(env_idx, self.traj_file_idxs[env_idx])
|
||||
self.dataset_drained = False
|
||||
|
||||
##### Interfaces for the IterableDataset #####
|
||||
def __iter__(self):
|
||||
self.initialize()
|
||||
transition_batch, infos = self.get_transition_batch()
|
||||
while transition_batch is not None:
|
||||
yield transition_batch, infos
|
||||
transition_batch, infos = self.get_transition_batch()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue