Fix done in pusht, Fix --time in sbatch

parent 664cfb2023
commit 591985c67d

@@ -88,10 +88,6 @@ def add_tee(


 class PushtExperienceReplay(TensorDictReplayBuffer):
-    # available_datasets = [
-    #     "xarm_lift_medium",
-    # ]
-
     def __init__(
         self,
         dataset_id,

@@ -233,6 +229,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         block_angle = state[:, 4]

         reward = torch.zeros(num_frames, 1)
+        success = torch.zeros(num_frames, 1, dtype=torch.bool)
         done = torch.zeros(num_frames, 1, dtype=torch.bool)
         for i in range(num_frames):
             space = pymunk.Space()

@@ -257,7 +254,10 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
             goal_area = goal_geom.area
             coverage = intersection_area / goal_area
             reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
-            done[i] = coverage > SUCCESS_THRESHOLD
+            success[i] = coverage > SUCCESS_THRESHOLD
+
+        # last step of demonstration is considered done
+        done[-1] = True

         episode = TensorDict(
             {
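
Taken together, these two hunks change what the trajectory annotations mean: success becomes a per-frame flag recording whether goal coverage exceeds SUCCESS_THRESHOLD, while done no longer fires mid-episode and instead marks only the last frame of each demonstration. A minimal vectorized sketch of the new semantics, with an assumed threshold value and made-up coverage numbers:

import torch

SUCCESS_THRESHOLD = 0.95  # assumed value; the diff only references the constant
coverage = torch.tensor([0.20, 0.50, 0.97, 0.99])  # hypothetical per-frame goal coverage
num_frames = coverage.shape[0]

# reward grows with coverage and saturates at 1 once the threshold is reached
reward = (coverage / SUCCESS_THRESHOLD).clamp(0, 1).unsqueeze(1)
# success is True for every frame whose coverage exceeds the threshold
success = (coverage > SUCCESS_THRESHOLD).unsqueeze(1)
# done now marks only the final frame of the demonstration
done = torch.zeros(num_frames, 1, dtype=torch.bool)
done[-1] = True
# here: success = [False, False, True, True], done = [False, False, False, True]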

@@ -271,6 +271,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
                 # TODO: verify that reward and done are aligned with image and agent_pos
                 ("next", "reward"): reward[1:],
                 ("next", "done"): done[1:],
+                ("next", "success"): success[1:],
             },
             batch_size=num_frames - 1,
         )
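
The [1:] slices implement the buffer's ("next", ...) convention: row t of the episode pairs frame t with the reward, done, and success values observed at frame t + 1, so num_frames frames yield num_frames - 1 transitions, which is why batch_size is num_frames - 1. A sketch continuing the snippet above (the tensors come from the previous example):

from tensordict import TensorDict

episode = TensorDict(
    {
        ("next", "reward"): reward[1:],    # reward observed after each transition
        ("next", "done"): done[1:],        # True only on the transition into the final frame
        ("next", "success"): success[1:],  # the new flag, shifted the same way as reward and done
    },
    batch_size=num_frames - 1,
)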

@@ -3,7 +3,7 @@
 #SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU)
 #SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs)
 #SBATCH --cpus-per-task=8 # number of cores per task (8x8 = 64 cores, or all the cores)
-#SBATCH --time=02:00:00
+#SBATCH --time=2-00:00:00
 #SBATCH --output=/home/rcadene/slurm/%j.out
 #SBATCH --error=/home/rcadene/slurm/%j.err
 #SBATCH --qos=low
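
The sbatch change is a units fix: SLURM's --time option reads the days field before the dash (D-HH:MM:SS), so 02:00:00 requested a two-hour wall-clock limit, whereas 2-00:00:00 requests two days, presumably the limit the job actually needs.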