backup wip
This commit is contained in:
parent a48789d629
commit df8f95f157
@@ -331,16 +331,15 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 continue
             batch_ = torch.load("/tmp/batch.pth")
             print(f"STEP {step}")
-            assert torch.equal(batch['index'], batch_['index'])
-            assert torch.equal(batch['episode_index'], batch_['episode_index'])
+            assert torch.equal(batch["index"], batch_["index"])
+            assert torch.equal(batch["episode_index"], batch_["episode_index"])
             if not torch.equal(batch["observation.image"], batch_["observation.image"]):
                 import cv2

                 for b, fn in [(batch, "outputs/img.png"), (batch_, "outputs/img_.png")]:
                     cv2.imwrite(
                         fn,
-                        (b["observation.image"][0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(
-                            np.uint8
-                        ),
+                        (b["observation.image"][0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8),
                     )
                 assert False
             assert torch.equal(batch["observation.state"], batch_["observation.state"])
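
Note on the debug block above: it reloads a batch that an earlier run dumped to /tmp/batch.pth and asserts, key by key, that the current batch matches it, writing the first image of each batch to disk when the image tensors diverge. Below is a minimal sketch of that save-and-compare pattern, assuming the batch is a flat dict of tensors; the helper names save_reference_batch and diff_keys are illustrative, not part of the repository.

import torch

def save_reference_batch(batch: dict, path: str = "/tmp/batch.pth") -> None:
    # Persist a batch once so that later runs can be checked against it.
    torch.save(batch, path)

def diff_keys(batch: dict, path: str = "/tmp/batch.pth") -> list[str]:
    # Return the keys whose tensors no longer match the saved reference batch.
    reference = torch.load(path)
    return [key for key in batch if not torch.equal(batch[key], reference[key])]

The permute(1, 2, 0), the multiplication by 255, and the cast to np.uint8 in the hunk are there because cv2.imwrite expects a height x width x channel uint8 array for a standard PNG write; it also interprets the channels as BGR, so an RGB observation will be written with swapped colors.
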
@@ -103,7 +103,7 @@ def update_policy(
     grad_scaler: GradScaler,
     lr_scheduler=None,
     use_amp: bool = False,
-    step: int = 0
+    step: int = 0,
 ):
     """Returns a dictionary of items for logging."""
     start_time = time.perf_counter()
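
For context on the signature changed here: update_policy receives a GradScaler, an optional lr_scheduler, a use_amp flag, and now a step counter, and returns a dict of items for logging. The sketch below shows how such an AMP-guarded update is commonly wired with torch.autocast and GradScaler; the parameters not visible in the hunk (policy, batch, optimizer, grad_clip_norm) and the body itself are assumptions, not the file's actual implementation, and step is only accepted for logging/debugging hooks.

import time
import torch
from torch.cuda.amp import GradScaler

def update_policy(
    policy,
    batch,
    optimizer,
    grad_clip_norm,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
    step: int = 0,
):
    """Returns a dictionary of items for logging."""
    start_time = time.perf_counter()
    device = next(policy.parameters()).device
    with torch.autocast(device_type=device.type, enabled=use_amp):
        output_dict = policy.forward(batch)
        loss = output_dict["loss"]
    grad_scaler.scale(loss).backward()
    # Unscale before clipping so the norm is computed on the true gradients.
    grad_scaler.unscale_(optimizer)
    grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_clip_norm)
    grad_scaler.step(optimizer)  # skipped internally if inf/nan gradients were found
    grad_scaler.update()
    optimizer.zero_grad()
    if lr_scheduler is not None:
        lr_scheduler.step()
    return {
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        "lr": optimizer.param_groups[0]["lr"],
        "update_s": time.perf_counter() - start_time,
    }
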
@@ -446,7 +446,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         sampler = EpisodeAwareSampler(
             offline_dataset.episode_data_index,
             drop_n_last_frames=cfg.training.drop_n_last_frames,
-            shuffle=False,  # TODO(now)
+            shuffle=False,  # TODO(now)
         )
     else:
         shuffle = False  # TODO(now)
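
The sampler branch above builds an EpisodeAwareSampler from the dataset's episode_data_index, drops the last cfg.training.drop_n_last_frames frames of each episode, and keeps shuffle=False for now. Dropping trailing frames is typically done so that multi-step targets or action chunks do not run past an episode boundary. Below is a minimal sketch of how such a sampler is handed to a DataLoader; the import path, the wrapper function, and the cfg fields other than drop_n_last_frames are assumptions.

from torch.utils.data import DataLoader

# Assumed import path for the sampler referenced in the hunk above.
from lerobot.common.datasets.sampler import EpisodeAwareSampler

def make_offline_dataloader(offline_dataset, cfg) -> DataLoader:
    sampler = EpisodeAwareSampler(
        offline_dataset.episode_data_index,
        drop_n_last_frames=cfg.training.drop_n_last_frames,
        shuffle=False,  # as in the diff above (sampler-level shuffling disabled)
    )
    # `sampler` and `shuffle=True` are mutually exclusive in DataLoader,
    # so no shuffle argument is passed here.
    return DataLoader(
        offline_dataset,
        batch_size=cfg.training.batch_size,  # assumed config field
        num_workers=cfg.training.num_workers,  # assumed config field
        sampler=sampler,
        pin_memory=True,
    )
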
@@ -485,7 +485,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
             use_amp=cfg.use_amp,
-            step=step
+            step=step,
         )

         train_info["dataloading_s"] = dataloading_s