diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index dec8b465..4015492d 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -25,13 +25,13 @@ from glob import glob from pathlib import Path import torch +import wandb from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from omegaconf import DictConfig, OmegaConf from termcolor import colored from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler -import wandb from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import get_global_random_state, set_global_random_state diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 911a265b..8a6bcfbd 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -122,12 +122,12 @@ def predict_action(observation, policy, device, use_amp): def init_keyboard_listener(assign_rewards=False): """ - Initializes a keyboard listener to enable early termination of an episode - or environment reset by pressing the right arrow key ('->'). This may require + Initializes a keyboard listener to enable early termination of an episode + or environment reset by pressing the right arrow key ('->'). This may require sudo permissions to allow the terminal to monitor keyboard events. Args: - assign_rewards (bool): If True, allows annotating the collected trajectory + assign_rewards (bool): If True, allows annotating the collected trajectory with a binary reward at the end of the episode to indicate success. """ events = {} diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index 8dea68c6..86fa90f2 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -22,6 +22,7 @@ from pprint import pformat import hydra import torch import torch.nn as nn +import wandb from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf from termcolor import colored @@ -30,7 +31,6 @@ from torch.cuda.amp import GradScaler from torch.utils.data import DataLoader, WeightedRandomSampler, random_split from tqdm import tqdm -import wandb from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.logger import Logger