backup wip

parent 53b36dcaab
commit a48789d629
@@ -29,6 +29,7 @@ TODO(alexander-soare): Use batch-first throughout.
 import logging
 from collections import deque
 from copy import deepcopy
 from functools import partial
 from typing import Callable

 import einops
@@ -310,7 +311,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
         return G

-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
+    def forward(self, batch: dict[str, Tensor], step) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.

         Returns a dictionary with loss as a tensor, and other information as native floats.
@@ -328,7 +329,21 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             if not os.path.exists("/tmp/mutex.txt"):
                 sleep(0.01)
                 continue
-            batch.update(torch.load("/tmp/batch.pth"))
+            batch_ = torch.load("/tmp/batch.pth")
+            print(f"STEP {step}")
+            assert torch.equal(batch['index'], batch_['index'])
+            assert torch.equal(batch['episode_index'], batch_['episode_index'])
+            if not torch.equal(batch["observation.image"], batch_["observation.image"]):
+                import cv2
+                for b, fn in [(batch, "outputs/img.png"), (batch_, "outputs/img_.png")]:
+                    cv2.imwrite(
+                        fn,
+                        (b["observation.image"][0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(
+                            np.uint8
+                        ),
+                    )
+                assert False
+            assert torch.equal(batch["observation.state"], batch_["observation.state"])
             os.remove("/tmp/mutex.txt")
             break

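The debug block added above assumes a second process acting as the writer side of the /tmp/batch.pth + /tmp/mutex.txt handshake: it dumps the reference batch and then creates the mutex file that this reader polls for. A minimal, hypothetical sketch of that writer (not part of this commit or of lerobot) could look like:

# Hypothetical writer-side counterpart of the /tmp/mutex.txt handshake polled by
# the debug block above. A sketch only; not part of this commit.
import os
import time

import torch


def publish_batch(batch, batch_path="/tmp/batch.pth", mutex_path="/tmp/mutex.txt"):
    # Wait until the reader has consumed and removed the previous mutex file.
    while os.path.exists(mutex_path):
        time.sleep(0.01)
    # Write the batch first, then create the mutex file to signal "batch is ready".
    torch.save(batch, batch_path)
    with open(mutex_path, "w") as f:
        f.write("ready")

Because the mutex file is only created after torch.save returns, the reader never loads a half-written batch file; removing the mutex on the reader side then hands the turn back to the writer.
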
@@ -343,12 +358,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         reward = batch["next.reward"]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}

-        # # Apply random image augmentations.
-        # if self.config.max_random_shift_ratio > 0:
-        #     observations["observation.image"] = flatten_forward_unflatten(
-        #         partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
-        #         observations["observation.image"],
-        #     )
+        # Apply random image augmentations.
+        if self.config.max_random_shift_ratio > 0:
+            observations["observation.image"] = flatten_forward_unflatten(
+                partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
+                observations["observation.image"],
+            )

         # Get the current observation for predicting trajectories, and all future observations for use in
         # the latent consistency loss and TD loss.

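For context, the re-enabled augmentation above wraps random_shifts_aug in flatten_forward_unflatten so that an image op written for (B, C, H, W) inputs can be applied to a (T, B, C, H, W) batch. A rough sketch of that flatten-apply-unflatten pattern, assuming that shape convention (an illustration, not the library's exact implementation):

# Rough sketch of the flatten -> apply -> unflatten pattern behind the call above.
# Assumes images shaped (T, B, C, H, W); illustrative only, not lerobot's exact code.
from typing import Callable

from torch import Tensor


def flatten_forward_unflatten_sketch(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
    if image_tensor.ndim == 4:
        # Already (B, C, H, W): nothing to fold.
        return fn(image_tensor)
    # Fold all leading dims (e.g. time and batch) into a single batch dim.
    lead_dims = image_tensor.shape[:-3]
    flat = image_tensor.flatten(start_dim=0, end_dim=len(lead_dims) - 1)
    out = fn(flat)
    # Restore the original leading dims on the output.
    return out.reshape(*lead_dims, *out.shape[1:])

In the diff above, fn is partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio), applied to observations["observation.image"].
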
@@ -103,13 +103,14 @@ def update_policy(
     grad_scaler: GradScaler,
     lr_scheduler=None,
     use_amp: bool = False,
+    step: int = 0
 ):
     """Returns a dictionary of items for logging."""
     start_time = time.perf_counter()
     device = get_device_from_parameters(policy)
     policy.train()
     with torch.autocast(device_type=device.type) if use_amp else nullcontext():
-        output_dict = policy.forward(batch)
+        output_dict = policy.forward(batch, step)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)
         loss = output_dict["loss"]
     grad_scaler.scale(loss).backward()
@@ -445,7 +446,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         sampler = EpisodeAwareSampler(
             offline_dataset.episode_data_index,
             drop_n_last_frames=cfg.training.drop_n_last_frames,
-            shuffle=True,
+            shuffle=False,  # TODO(now)
         )
     else:
         shuffle = False  # TODO(now)
@@ -484,6 +485,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
             use_amp=cfg.use_amp,
+            step=step
         )

         train_info["dataloading_s"] = dataloading_s
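Taken together, the train-side hunks simply thread the training step index through to the policy so the temporary debug print in TDMPCPolicy.forward can report it: train() passes step=step to update_policy, which calls policy.forward(batch, step). A stripped-down sketch of that call chain (illustrative stand-ins, not the real lerobot signatures):

# Stripped-down sketch of the step threading introduced by this commit:
# train loop -> update_policy(..., step=...) -> policy.forward(batch, step).
# TinyPolicy and the loop below are illustrative stand-ins, not lerobot code.
import torch
from torch import Tensor, nn


class TinyPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, batch: dict[str, Tensor], step: int) -> dict[str, Tensor | float]:
        print(f"STEP {step}")  # mirrors the debug print added in TDMPCPolicy.forward
        loss = self.linear(batch["observation.state"]).mean()
        return {"loss": loss}


def update_policy(policy: nn.Module, batch: dict[str, Tensor], step: int = 0) -> dict:
    output_dict = policy.forward(batch, step)
    output_dict["loss"].backward()
    return {"loss": output_dict["loss"].item()}


policy = TinyPolicy()
for step in range(3):
    batch = {"observation.state": torch.ones(2, 4)}
    update_policy(policy, batch, step=step)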