diff --git a/lerobot/common/policies/dot/configuration_dot.py b/lerobot/common/policies/dot/configuration_dot.py
index 5b5a7aba..33d23f73 100644
--- a/lerobot/common/policies/dot/configuration_dot.py
+++ b/lerobot/common/policies/dot/configuration_dot.py
@@ -35,16 +35,27 @@ class DOTConfig(PreTrainedConfig):
         default_factory=lambda: {
             "VISUAL": NormalizationMode.MEAN_STD,
             "STATE": NormalizationMode.MIN_MAX,
+            "ENV": NormalizationMode.MIN_MAX,
             "ACTION": NormalizationMode.MIN_MAX,
         }
     )
 
+    # Not sure if there is a better way to do this with new config system.
+    override_dataset_stats: bool = False
+    new_dataset_stats: dict[str, dict[str, list[float]]] = field(
+        default_factory=lambda: {
+            "action": {"max": [512.0] * 2, "min": [0.0] * 2},
+            "observation.environment_state": {"max": [512.0] * 16, "min": [0.0] * 16},
+            "observation.state": {"max": [512.0] * 2, "min": [0.0] * 2},
+        }
+    )
+
     # Architecture.
     vision_backbone: str = "resnet18"
     pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
     pre_norm: bool = True
     lora_rank: int = 20
-    merge_lora: bool = True
+    merge_lora: bool = False
 
     dim_model: int = 128
     n_heads: int = 8
diff --git a/lerobot/common/policies/dot/modeling_dot.py b/lerobot/common/policies/dot/modeling_dot.py
index 477cfdf8..5f392cbc 100644
--- a/lerobot/common/policies/dot/modeling_dot.py
+++ b/lerobot/common/policies/dot/modeling_dot.py
@@ -45,7 +45,7 @@ class DOT(nn.Module):
         self.n_features += self.config.n_obs_steps
 
         if self.config.env_state_feature:
-            self.projections["env_state"] = nn.Linear(
+            self.projections["environment_state"] = nn.Linear(
                 self.config.env_state_feature.shape[0], self.config.dim_model
             )
             self.n_features += self.config.n_obs_steps
@@ -54,7 +54,7 @@
         obs_mapping = {
             "images": "observation.images",
             "state": "observation.state",
-            "env_state": "observation.environment_state",
+            "environment_state": "observation.environment_state",
         }
         self.obs_mapping = {k: v for k, v in obs_mapping.items() if k in self.projections_names}
 
@@ -154,6 +154,15 @@
         config.validate_features()
         self.config = config
 
+        if config.override_dataset_stats:
+            if dataset_stats is None:
+                dataset_stats = {}
+            for k, v in config.new_dataset_stats.items():
+                if k not in dataset_stats:
+                    dataset_stats[k] = {}
+                for k1, v1 in v.items():
+                    dataset_stats[k][k1] = torch.tensor(v1)
+
         self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
@@ -164,11 +173,12 @@
 
         self.model = DOT(self.config)
 
-        self.return_every_n = self.config.return_every_n
         self.state_noise = self.config.state_noise
         self.crop_scale = self.config.crop_scale
         self.alpha = self.config.alpha
         self.inference_horizon = self.config.inference_horizon
+        self.return_every_n = self.config.return_every_n
+        self.predict_every_n = self.config.predict_every_n
 
         # Inference action chunking and observation queues
         self._old_predictions = None
@@ -196,12 +206,14 @@
             config.rescale_shape, interpolation=InterpolationMode.NEAREST
         )
 
-        self.predict_every_n = self.config.predict_every_n
         self.step = 0
+        self.last_action = None
 
     def reset(self):
         self._old_predictions = None
         self._input_buffers = {}
+        self.last_action = None
+        self.step = 0
 
     def get_optim_params(self) -> dict:
         return self.model.parameters()
@@ -346,9 +358,9 @@
         return loss_dict
 
     @classmethod
-    def from_pretrained(cls, *args, **kwargs):
+    def from_pretrained(cls, pretrained_name_or_path, *args, **kwargs):
        """Load model from pretrained checkpoint and merge LoRA after loading"""
-        policy = super().from_pretrained(*args, **kwargs)
+        policy = super().from_pretrained(pretrained_name_or_path, *args, **kwargs)
 
         if getattr(policy.config, "merge_lora", False):
             print("Merging LoRA after loading pretrained model...")
@@ -368,8 +380,9 @@ class LoRAConv2d(nn.Module):
         fan_in = in_channels * kh * kw
 
         # LoRA parameters
-        self.lora_A = nn.Parameter(torch.normal(0, 0.02, (out_channels, rank)))
-        self.lora_B = nn.Parameter(torch.normal(0, 0.02, (rank, fan_in)))
+        std = 1 / math.sqrt(fan_in)
+        self.lora_A = nn.Parameter(torch.normal(0, std, (out_channels, rank)))
+        self.lora_B = nn.Parameter(torch.normal(0, std, (rank, fan_in)))
 
     def forward(self, x):
         lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape)
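A note on the new override_dataset_stats path, for reviewers: the loop added to DOTPolicy.__init__ unconditionally replaces each stat listed in config.new_dataset_stats with a tensor built from the configured list, creating the per-key dict if it is missing. The standalone sketch below (plain torch, with hypothetical dataset_stats contents, not part of the patch) shows the resulting behaviour: listed entries are overwritten, everything else is left untouched.

import torch

# Shaped like DOTConfig.new_dataset_stats: plain Python lists per stat name.
new_dataset_stats = {
    "action": {"max": [512.0] * 2, "min": [0.0] * 2},
    "observation.state": {"max": [512.0] * 2, "min": [0.0] * 2},
}

# Hypothetical stats as they might come from the dataset (already tensors).
dataset_stats = {
    "action": {"min": torch.tensor([10.0, 10.0]), "mean": torch.tensor([100.0, 100.0])},
}

# Same loop as in the patch: every configured stat replaces the dataset value;
# stats not listed ("mean" here) and keys not listed survive unchanged.
for k, v in new_dataset_stats.items():
    if k not in dataset_stats:
        dataset_stats[k] = {}
    for k1, v1 in v.items():
        dataset_stats[k][k1] = torch.tensor(v1)

print(dataset_stats["action"]["min"])             # tensor([0., 0.])     -> overridden
print(dataset_stats["action"]["mean"])            # tensor([100., 100.]) -> preserved
print(dataset_stats["observation.state"]["max"])  # tensor([512., 512.]) -> added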