Added normalization schemes and style checks
This commit is contained in:
parent
b0e2fcdba7
commit
04da4dd3e3
|
@ -25,13 +25,13 @@ from glob import glob
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import wandb
|
|
||||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
||||||
|
import wandb
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,6 @@ import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClassifierConfig:
|
class ClassifierConfig:
|
||||||
|
|
|
@ -23,9 +23,11 @@ class ClassifierOutput:
|
||||||
self.hidden_states = hidden_states
|
self.hidden_states = hidden_states
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (f"ClassifierOutput(logits={self.logits}, "
|
return (
|
||||||
f"probabilities={self.probabilities}, "
|
f"ClassifierOutput(logits={self.logits}, "
|
||||||
f"hidden_states={self.hidden_states})")
|
f"probabilities={self.probabilities}, "
|
||||||
|
f"hidden_states={self.hidden_states})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Classifier(
|
class Classifier(
|
||||||
|
@ -74,7 +76,7 @@ class Classifier(
|
||||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported CNN architecture")
|
raise ValueError("Unsupported CNN architecture")
|
||||||
|
|
||||||
self.encoder = self.encoder.to(self.config.device)
|
self.encoder = self.encoder.to(self.config.device)
|
||||||
|
|
||||||
def _freeze_encoder(self) -> None:
|
def _freeze_encoder(self) -> None:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team.
|
# Copyright 2024 The HuggingFace Inc. team.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team.
|
# Copyright 2024 The HuggingFace Inc. team.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -26,4 +26,4 @@ class HILSerlPolicy(
|
||||||
repo_url="https://github.com/huggingface/lerobot",
|
repo_url="https://github.com/huggingface/lerobot",
|
||||||
tags=["robotics", "hilserl"],
|
tags=["robotics", "hilserl"],
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -24,7 +24,7 @@ python lerobot/scripts/eval_on_robot.py \
|
||||||
```
|
```
|
||||||
|
|
||||||
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
|
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
|
||||||
for running training on the real robot.
|
for running training on the real robot.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -47,7 +47,7 @@ from lerobot.common.utils.utils import (
|
||||||
|
|
||||||
|
|
||||||
def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
|
def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
|
||||||
"""Run a batched policy rollout on the real robot.
|
"""Run a batched policy rollout on the real robot.
|
||||||
|
|
||||||
The return dictionary contains:
|
The return dictionary contains:
|
||||||
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
||||||
|
@ -64,7 +64,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
|
||||||
extraneous elements from the sequences above.
|
extraneous elements from the sequences above.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
robot: The robot class that defines the interface with the real robot.
|
robot: The robot class that defines the interface with the real robot.
|
||||||
policy: The policy. Must be a PyTorch nn module.
|
policy: The policy. Must be a PyTorch nn module.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -77,7 +77,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
|
||||||
listener, events = init_keyboard_listener()
|
listener, events = init_keyboard_listener()
|
||||||
|
|
||||||
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
|
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
|
||||||
# policy.reset()
|
# policy.reset()
|
||||||
|
|
||||||
# Get observation from real robot
|
# Get observation from real robot
|
||||||
observation = robot.capture_observation()
|
observation = robot.capture_observation()
|
||||||
|
|
|
@ -95,12 +95,14 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||||
lr_scheduler = None
|
lr_scheduler = None
|
||||||
|
|
||||||
elif policy.name == "sac":
|
elif policy.name == "sac":
|
||||||
optimizer = torch.optim.Adam([
|
optimizer = torch.optim.Adam(
|
||||||
{'params': policy.actor.parameters(), 'lr': policy.config.actor_lr},
|
[
|
||||||
{'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr},
|
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
|
||||||
{'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr},
|
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
|
||||||
])
|
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
|
||||||
lr_scheduler = None
|
]
|
||||||
|
)
|
||||||
|
lr_scheduler = None
|
||||||
|
|
||||||
elif cfg.policy.name == "vqbet":
|
elif cfg.policy.name == "vqbet":
|
||||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
||||||
|
|
|
@ -22,7 +22,6 @@ from pprint import pformat
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import wandb
|
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
|
||||||
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import wandb
|
||||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.logger import Logger
|
from lerobot.common.logger import Logger
|
||||||
|
|
Loading…
Reference in New Issue