Fixed small issues

commit 489cdc2ace
parent 78d3ba8db2
@@ -35,16 +35,27 @@ class DOTConfig(PreTrainedConfig):
         default_factory=lambda: {
             "VISUAL": NormalizationMode.MEAN_STD,
             "STATE": NormalizationMode.MIN_MAX,
             "ENV": NormalizationMode.MIN_MAX,
             "ACTION": NormalizationMode.MIN_MAX,
         }
     )
 
+    # Not sure if there is a better way to do this with new config system.
+    override_dataset_stats: bool = False
+    new_dataset_stats: dict[str, dict[str, list[float]]] = field(
+        default_factory=lambda: {
+            "action": {"max": [512.0] * 2, "min": [0.0] * 2},
+            "observation.environment_state": {"max": [512.0] * 16, "min": [0.0] * 16},
+            "observation.state": {"max": [512.0] * 2, "min": [0.0] * 2},
+        }
+    )
+
     # Architecture.
     vision_backbone: str = "resnet18"
     pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
     pre_norm: bool = True
     lora_rank: int = 20
-    merge_lora: bool = True
+    merge_lora: bool = False
+
     dim_model: int = 128
     n_heads: int = 8
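Note: the new override_dataset_stats / new_dataset_stats fields let the policy inject hard-coded min/max statistics (here a 512 x 512 pixel workspace) instead of relying on statistics computed from the training dataset. A minimal sketch of what MIN_MAX normalization does with such bounds; the [-1, 1] target range is an assumption about the Normalize implementation, not something shown in this diff:

import torch

# Hypothetical helper, not part of the repo: MIN_MAX scaling into [-1, 1].
def min_max_normalize(x: torch.Tensor, min_: torch.Tensor, max_: torch.Tensor) -> torch.Tensor:
    return (x - min_) / (max_ - min_) * 2.0 - 1.0

# Bounds mirroring new_dataset_stats above: a 2-D action in a 512 x 512 workspace.
min_, max_ = torch.zeros(2), torch.full((2,), 512.0)
print(min_max_normalize(torch.tensor([256.0, 512.0]), min_, max_))  # tensor([0., 1.])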
@@ -45,7 +45,7 @@ class DOT(nn.Module):
             self.n_features += self.config.n_obs_steps
 
         if self.config.env_state_feature:
-            self.projections["env_state"] = nn.Linear(
+            self.projections["environment_state"] = nn.Linear(
                 self.config.env_state_feature.shape[0], self.config.dim_model
             )
             self.n_features += self.config.n_obs_steps
@@ -54,7 +54,7 @@ class DOT(nn.Module):
         obs_mapping = {
             "images": "observation.images",
             "state": "observation.state",
-            "env_state": "observation.environment_state",
+            "environment_state": "observation.environment_state",
         }
         self.obs_mapping = {k: v for k, v in obs_mapping.items() if k in self.projections_names}
 
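Note: the two hunks above rename the env_state key to environment_state so that the projection dict and obs_mapping agree with the feature naming used elsewhere in the model (the exact motivation is not visible in this diff). The comprehension on the last line keeps only entries whose key is also a projection name, so a mismatched key silently drops that input. An illustrative, standalone version of the filter:

# Illustrative only: if the projection were registered as "environment_state"
# while obs_mapping still used "env_state", the comprehension would drop the entry.
projections_names = ["images", "state", "environment_state"]
obs_mapping = {
    "images": "observation.images",
    "state": "observation.state",
    "env_state": "observation.environment_state",  # stale key
}
kept = {k: v for k, v in obs_mapping.items() if k in projections_names}
print(sorted(kept))  # ['images', 'state'] -- the environment state is silently ignored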
@@ -154,6 +154,15 @@ class DOTPolicy(PreTrainedPolicy):
         config.validate_features()
         self.config = config
 
+        if config.override_dataset_stats:
+            if dataset_stats is None:
+                dataset_stats = {}
+            for k, v in config.new_dataset_stats.items():
+                if k not in dataset_stats:
+                    dataset_stats[k] = {}
+                for k1, v1 in v.items():
+                    dataset_stats[k][k1] = torch.tensor(v1)
+
         self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
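Note: the added block makes the hard-coded statistics take precedence over whatever was computed from the dataset for the same keys, and creates the entry when the dataset provides none. A standalone sketch of that behaviour with made-up numbers:

import torch

# Made-up stats standing in for what a dataset might provide.
dataset_stats = {"action": {"max": torch.tensor([300.0, 300.0]), "mean": torch.tensor([150.0, 150.0])}}
new_dataset_stats = {"action": {"max": [512.0] * 2, "min": [0.0] * 2}}

for k, v in new_dataset_stats.items():
    if k not in dataset_stats:
        dataset_stats[k] = {}
    for k1, v1 in v.items():
        dataset_stats[k][k1] = torch.tensor(v1)

print(dataset_stats["action"]["max"])   # tensor([512., 512.])  overridden
print(dataset_stats["action"]["mean"])  # tensor([150., 150.])  left untouched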
@@ -164,11 +173,12 @@ class DOTPolicy(PreTrainedPolicy):
 
         self.model = DOT(self.config)
 
-        self.return_every_n = self.config.return_every_n
         self.state_noise = self.config.state_noise
         self.crop_scale = self.config.crop_scale
         self.alpha = self.config.alpha
         self.inference_horizon = self.config.inference_horizon
+        self.return_every_n = self.config.return_every_n
+        self.predict_every_n = self.config.predict_every_n
 
         # Inference action chunking and observation queues
         self._old_predictions = None
@@ -196,12 +206,14 @@ class DOTPolicy(PreTrainedPolicy):
             config.rescale_shape, interpolation=InterpolationMode.NEAREST
         )
 
-        self.predict_every_n = self.config.predict_every_n
         self.step = 0
+        self.last_action = None
 
     def reset(self):
         self._old_predictions = None
         self._input_buffers = {}
+        self.last_action = None
+        self.step = 0
 
     def get_optim_params(self) -> dict:
         return self.model.parameters()
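Note: the policy keeps per-episode state (step counter, last action, cached predictions, input buffers), and this hunk makes reset() clear all of it so nothing leaks across rollouts. A minimal sketch of the pattern with a hypothetical policy class (not the real DOTPolicy):

# Hypothetical illustration: every attribute that accumulates per-episode state
# in __init__ is also cleared in reset(), keeping evaluation episodes independent.
class StatefulPolicy:
    def __init__(self):
        self._old_predictions = None
        self._input_buffers = {}
        self.last_action = None
        self.step = 0

    def reset(self):
        self._old_predictions = None
        self._input_buffers = {}
        self.last_action = None
        self.step = 0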
@@ -346,9 +358,9 @@ class DOTPolicy(PreTrainedPolicy):
         return loss_dict
 
     @classmethod
-    def from_pretrained(cls, *args, **kwargs):
+    def from_pretrained(cls, pretrained_name_or_path, *args, **kwargs):
         """Load model from pretrained checkpoint and merge LoRA after loading"""
-        policy = super().from_pretrained(*args, **kwargs)
+        policy = super().from_pretrained(pretrained_name_or_path, *args, **kwargs)
 
         if getattr(policy.config, "merge_lora", False):
             print("Merging LoRA after loading pretrained model...")
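Note: the explicit pretrained_name_or_path parameter forwards the checkpoint location to the parent from_pretrained, and LoRA merging only happens when config.merge_lora is set (now False by default). A hedged usage sketch; the import path and checkpoint directory are assumptions, not taken from this diff:

from lerobot.common.policies.dot.modeling_dot import DOTPolicy  # assumed module path

# Placeholder path to a trained checkpoint directory.
policy = DOTPolicy.from_pretrained("outputs/train/dot/checkpoints/last/pretrained_model")
print(policy.config.merge_lora)  # False by default after this commit; when True in the
                                 # saved config, from_pretrained folds LoRA into the backbone
policy.eval()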
@@ -368,8 +380,9 @@ class LoRAConv2d(nn.Module):
         fan_in = in_channels * kh * kw
 
         # LoRA parameters
-        self.lora_A = nn.Parameter(torch.normal(0, 0.02, (out_channels, rank)))
-        self.lora_B = nn.Parameter(torch.normal(0, 0.02, (rank, fan_in)))
+        std = 1 / math.sqrt(fan_in)
+        self.lora_A = nn.Parameter(torch.normal(0, std, (out_channels, rank)))
+        self.lora_B = nn.Parameter(torch.normal(0, std, (rank, fan_in)))
 
     def forward(self, x):
         lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape)
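Note: the initialization now scales both LoRA factors by 1/sqrt(fan_in) instead of a fixed 0.02, keeping the magnitude of the low-rank update in line with a conventional fan-in-based conv initialization regardless of kernel size. A standalone shape check of the update computed in forward, with made-up dimensions:

import math
import torch

out_channels, in_channels, kh, kw, rank = 16, 8, 3, 3, 4
fan_in = in_channels * kh * kw

std = 1 / math.sqrt(fan_in)
lora_A = torch.normal(0, std, (out_channels, rank))
lora_B = torch.normal(0, std, (rank, fan_in))

# Same reshape as LoRAConv2d.forward: the rank-r product recovers a full conv kernel.
lora_update = (lora_A @ lora_B).view(out_channels, in_channels, kh, kw)
print(lora_update.shape)  # torch.Size([16, 8, 3, 3])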