Fixed small issues

Ilia 2025-02-09 19:09:41 +07:00
parent 78d3ba8db2
commit 489cdc2ace
2 changed files with 33 additions and 9 deletions

@@ -35,16 +35,27 @@ class DOTConfig(PreTrainedConfig):
         default_factory=lambda: {
             "VISUAL": NormalizationMode.MEAN_STD,
             "STATE": NormalizationMode.MIN_MAX,
             "ENV": NormalizationMode.MIN_MAX,
             "ACTION": NormalizationMode.MIN_MAX,
         }
     )
+    # Not sure if there is a better way to do this with the new config system.
+    override_dataset_stats: bool = False
+    new_dataset_stats: dict[str, dict[str, list[float]]] = field(
+        default_factory=lambda: {
+            "action": {"max": [512.0] * 2, "min": [0.0] * 2},
+            "observation.environment_state": {"max": [512.0] * 16, "min": [0.0] * 16},
+            "observation.state": {"max": [512.0] * 2, "min": [0.0] * 2},
+        }
+    )
 
     # Architecture.
     vision_backbone: str = "resnet18"
     pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
 
     pre_norm: bool = True
     lora_rank: int = 20
-    merge_lora: bool = True
+    merge_lora: bool = False
 
     dim_model: int = 128
     n_heads: int = 8

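A minimal sketch of how the new override fields might be used from user code; the import path is an assumption, not something this diff confirms:

import_path_assumed = True
from lerobot.common.policies.dot.configuration_dot import DOTConfig  # assumed path

config = DOTConfig(
    override_dataset_stats=True,
    new_dataset_stats={
        # Per-feature min/max; list lengths must match the feature shapes.
        "action": {"max": [512.0] * 2, "min": [0.0] * 2},
    },
)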
@@ -45,7 +45,7 @@ class DOT(nn.Module):
             self.n_features += self.config.n_obs_steps
 
         if self.config.env_state_feature:
-            self.projections["env_state"] = nn.Linear(
+            self.projections["environment_state"] = nn.Linear(
                 self.config.env_state_feature.shape[0], self.config.dim_model
             )
             self.n_features += self.config.n_obs_steps
@@ -54,7 +54,7 @@ class DOT(nn.Module):
         obs_mapping = {
             "images": "observation.images",
             "state": "observation.state",
-            "env_state": "observation.environment_state",
+            "environment_state": "observation.environment_state",
         }
         self.obs_mapping = {k: v for k, v in obs_mapping.items() if k in self.projections_names}
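The rename keeps the projection keys consistent with obs_mapping, which is what lets the model fetch each input tensor from the batch by a single name. A sketch of that lookup pattern; the loop body is assumed from context, not taken from this diff:

# Sketch: matching keys let each projection find its batch entry.
for name, batch_key in self.obs_mapping.items():
    projected = self.projections[name](batch[batch_key])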
@@ -154,6 +154,15 @@ class DOTPolicy(PreTrainedPolicy):
         config.validate_features()
         self.config = config
 
+        if config.override_dataset_stats:
+            if dataset_stats is None:
+                dataset_stats = {}
+            for k, v in config.new_dataset_stats.items():
+                if k not in dataset_stats:
+                    dataset_stats[k] = {}
+                for k1, v1 in v.items():
+                    dataset_stats[k][k1] = torch.tensor(v1)
+
         self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
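Note that the loop overwrites stats for every feature listed in new_dataset_stats; only unlisted features keep their dataset-derived values. With the default config values, the override materializes to tensors like this (a worked example, not code from the diff):

import torch

# Worked example: what the default "action" entry becomes after the loop.
dataset_stats = {
    "action": {
        "max": torch.tensor([512.0, 512.0]),  # from [512.0] * 2
        "min": torch.tensor([0.0, 0.0]),      # from [0.0] * 2
    }
}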
@@ -164,11 +173,12 @@ class DOTPolicy(PreTrainedPolicy):
         self.model = DOT(self.config)
+        self.return_every_n = self.config.return_every_n
         self.state_noise = self.config.state_noise
         self.crop_scale = self.config.crop_scale
         self.alpha = self.config.alpha
         self.inference_horizon = self.config.inference_horizon
-        self.return_every_n = self.config.return_every_n
+        self.predict_every_n = self.config.predict_every_n
 
         # Inference action chunking and observation queues
         self._old_predictions = None
@@ -196,12 +206,14 @@ class DOTPolicy(PreTrainedPolicy):
             config.rescale_shape, interpolation=InterpolationMode.NEAREST
         )
 
-        self.predict_every_n = self.config.predict_every_n
+        self.step = 0
+        self.last_action = None
 
     def reset(self):
         self._old_predictions = None
         self._input_buffers = {}
+        self.last_action = None
         self.step = 0
 
     def get_optim_params(self) -> dict:
         return self.model.parameters()
@@ -346,9 +358,9 @@ class DOTPolicy(PreTrainedPolicy):
         return loss_dict
 
     @classmethod
-    def from_pretrained(cls, *args, **kwargs):
+    def from_pretrained(cls, pretrained_name_or_path, *args, **kwargs):
         """Load model from pretrained checkpoint and merge LoRA after loading"""
-        policy = super().from_pretrained(*args, **kwargs)
+        policy = super().from_pretrained(pretrained_name_or_path, *args, **kwargs)
 
         if getattr(policy.config, "merge_lora", False):
             print("Merging LoRA after loading pretrained model...")
@@ -368,8 +380,9 @@ class LoRAConv2d(nn.Module):
         fan_in = in_channels * kh * kw
 
         # LoRA parameters
-        self.lora_A = nn.Parameter(torch.normal(0, 0.02, (out_channels, rank)))
-        self.lora_B = nn.Parameter(torch.normal(0, 0.02, (rank, fan_in)))
+        std = 1 / math.sqrt(fan_in)
+        self.lora_A = nn.Parameter(torch.normal(0, std, (out_channels, rank)))
+        self.lora_B = nn.Parameter(torch.normal(0, std, (rank, fan_in)))
 
     def forward(self, x):
         lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape)
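Scaling the init with fan-in keeps the magnitude of the A @ B update roughly independent of kernel size, in the spirit of LeCun initialization, rather than using the fixed 0.02 from before. A quick numeric check under an assumed 3x3 kernel with 64 input channels:

import math

fan_in = 64 * 3 * 3          # in_channels * kh * kw
std = 1 / math.sqrt(fan_in)  # ~0.042, versus the fixed 0.02 used previously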