fixes and updated comments

Michel Aractingi 2024-11-26 09:46:59 +00:00
parent 15090c2544
commit 16edbbdeee
3 changed files with 223 additions and 202 deletions

View File

@@ -19,7 +19,7 @@ from dataclasses import dataclass, field
 @dataclass
 class TDMPC2Config:
-    """Configuration class for TDMPCPolicy.
+    """Configuration class for TDMPC2Policy.
 
     Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
     camera observations.
@@ -77,18 +77,9 @@ class TDMPC2Config:
         image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
             is applied. Note that the input images are assumed to be square for this augmentation.
         reward_coeff: Loss weighting coefficient for the reward regression loss.
-        expectile_weight: Weighting (τ) used in expectile regression for the state value function (V).
-            v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to
-            be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do
-            because v_target is obtained by evaluating the learned state-action value functions (Q) with
-            in-sample actions that may not be always optimal.
         value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
             value (V) expectile regression loss.
         consistency_coeff: Loss weighting coefficient for the consistency loss.
-        advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage
-            weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages
-            are clamped at 100.0.
-        pi_coeff: Loss weighting coefficient for the action regression loss.
         temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
            steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
            current time step.
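
Note: the `temporal_decay_coeff` documented above weights the per-step losses by rho^t over the planning horizon, so later steps contribute less. A minimal, self-contained sketch of that weighting, independent of the LeRobot classes and with made-up numbers:

    import torch

    horizon = 5
    temporal_decay_coeff = 0.5
    # Per-step losses for `horizon` steps starting at the current time step (dummy values).
    per_step_losses = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])
    # rho^t down-weights the loss at future time steps.
    rho = temporal_decay_coeff ** torch.arange(horizon)
    weighted_loss = (rho * per_step_losses).mean()
    print(rho)            # tensor([1.0000, 0.5000, 0.2500, 0.1250, 0.0625])
    print(weighted_loss)  # tensor(0.3875)
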
@@ -126,9 +117,12 @@ class TDMPC2Config:
     state_encoder_hidden_dim: int = 256
     latent_dim: int = 512
     q_ensemble_size: int = 5
+    num_enc_layers: int = 2
     mlp_dim: int = 512
     # Reinforcement learning.
     discount: float = 0.9
+    simnorm_dim: int = 8
+    dropout: float = 0.01
     # actor
     log_std_min: float = -10
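
Note: `simnorm_dim`, added above, is the group size used by TD-MPC2's simplicial normalization (SimNorm): the latent vector is split into groups of `simnorm_dim` channels and a softmax is applied within each group. A minimal sketch of that operation (an illustration, not the project's SimNorm class):

    import torch
    import torch.nn.functional as F

    def simnorm(z: torch.Tensor, group_dim: int = 8) -> torch.Tensor:
        """Softmax over groups of `group_dim` channels, then flatten back."""
        batch, latent_dim = z.shape
        z = z.view(batch, latent_dim // group_dim, group_dim)
        return F.softmax(z, dim=-1).view(batch, latent_dim)

    z = torch.randn(2, 512)            # latent_dim = 512, as in the config above
    print(simnorm(z, 8).shape)         # torch.Size([2, 512])
    print(simnorm(z, 8)[0, :8].sum())  # each group of 8 sums to 1
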

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su, # Copyright 2024 Nicklas Hansen and The HuggingFace Inc. team.
# and The HuggingFace Inc. team. All rights reserved. # All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,11 +14,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Implementation of Finetuning Offline World Models in the Real World.
-
-The comments in this code may sometimes refer to these references:
-    TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
-    FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
+"""Implementation of TD-MPC2: Scalable, Robust World Models for Continuous Control
+
+We refer to the main paper and codebase:
+    TD-MPC2 paper: (https://arxiv.org/abs/2310.16828)
+    TD-MPC2 code: (https://github.com/nicklashansen/tdmpc2)
 """
 
 # ruff: noqa: N806
@ -38,8 +38,16 @@ from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.tdmpc2.tdmpc2_utils import (
NormedLinear,
SimNorm,
gaussian_logprob,
soft_cross_entropy,
squash,
two_hot_inv,
)
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash, soft_cross_entropy
class TDMPC2Policy( class TDMPC2Policy(
nn.Module, nn.Module,
@ -48,22 +56,7 @@ class TDMPC2Policy(
repo_url="https://github.com/huggingface/lerobot", repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "tdmpc2"], tags=["robotics", "tdmpc2"],
): ):
"""Implementation of TD-MPC2 learning + inference. """Implementation of TD-MPC2 learning + inference."""
Please note several warnings for this policy.
- Evaluation of pretrained weights created with the original FOWM code
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
process communication to use the xarm environment from FOWM. This is because our xarm
environment uses newer dependencies and does not match the environment in FOWM. See
https://github.com/huggingface/lerobot/pull/103 for implementation details.
- We have NOT checked that training on LeRobot reproduces the results from FOWM.
- Nevertheless, we have verified that we can train TD-MPC for PushT. See
`lerobot/configs/policy/tdmpc2_pusht_keypoints.yaml`.
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
match our xarm environment.
"""
name = "tdmpc2" name = "tdmpc2"
@ -83,6 +76,8 @@ class TDMPC2Policy(
config = TDMPC2Config() config = TDMPC2Config()
self.config = config self.config = config
self.model = TDMPC2WorldModel(config) self.model = TDMPC2WorldModel(config)
# TODO (michel-aractingi) temp fix for gpu
self.model = self.model.to("cuda:0")
if config.input_normalization_modes is not None: if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize( self.normalize_inputs = Normalize(
@ -109,7 +104,9 @@ class TDMPC2Policy(
self._use_env_state = True self._use_env_state = True
self.scale = RunningScale(self.config.target_model_momentum) self.scale = RunningScale(self.config.target_model_momentum)
self.discount = self.config.discount #TODO (michel-aractingi) downscale discount according to episode length self.discount = (
self.config.discount
) # TODO (michel-aractingi) downscale discount according to episode length
self.reset() self.reset()
@ -204,7 +201,7 @@ class TDMPC2Policy(
for t in range(self.config.horizon): for t in range(self.config.horizon):
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be # Note: Adding a small amount of noise here doesn't hurt during inference and may even be
# helpful for CEM. # helpful for CEM.
pi_actions[t] = self.model.pi(_z, self.config.min_std)[0] pi_actions[t] = self.model.pi(_z)[0]
_z = self.model.latent_dynamics(_z, pi_actions[t]) _z = self.model.latent_dynamics(_z, pi_actions[t])
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
@ -249,13 +246,16 @@ class TDMPC2Policy(
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value)) score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score /= score.sum(axis=0, keepdim=True) score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim) # (horizon, batch, action_dim)
mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (score.sum(0) + 1e-9) mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (
einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9
)
std = torch.sqrt( std = torch.sqrt(
torch.sum( torch.sum(
einops.rearrange(score, "n b -> n b 1") einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2, * (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
dim=1, dim=1,
) / (score.sum(0) + 1e-9) )
/ (einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9)
).clamp_(self.config.min_std, self.config.max_std) ).clamp_(self.config.min_std, self.config.max_std)
# Keep track of the mean for warm-starting subsequent steps. # Keep track of the mean for warm-starting subsequent steps.
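
Note: the change above broadcasts the normalizer `score.sum(0)`, of shape (batch,), to (1, batch, 1) so that it divides the (horizon, batch, action_dim) sums elementwise. A shape-only sketch of the same elite-weighted mean/std computation, with hypothetical sizes and random stand-in tensors:

    import torch
    import einops

    horizon, n_elites, batch, action_dim = 5, 6, 2, 4  # hypothetical sizes
    elite_actions = torch.randn(horizon, n_elites, batch, action_dim)
    score = torch.softmax(torch.randn(n_elites, batch), dim=0)  # weights over elites, sum to 1 per batch item

    weights = einops.rearrange(score, "n b -> n b 1")
    mean = torch.sum(weights * elite_actions, dim=1) / (
        einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9
    )
    std = torch.sqrt(
        torch.sum(weights * (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2, dim=1)
        / (einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9)
    )
    print(mean.shape, std.shape)  # torch.Size([5, 2, 4]) torch.Size([5, 2, 4])
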
@ -287,10 +287,10 @@ class TDMPC2Policy(
G += running_discount * reward G += running_discount * reward
running_discount *= self.config.discount running_discount *= self.config.discount
#next_action = self.model.pi(z)[0] # (batch, action_dim) # next_action = self.model.pi(z)[0] # (batch, action_dim)
#terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch) # terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)
return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type='avg') return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type="avg")
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss. """Run the batch through the model and compute the loss.
@ -358,7 +358,8 @@ class TDMPC2Policy(
pi = self.model.pi(z_targets)[0] pi = self.model.pi(z_targets)[0]
td_targets = ( td_targets = (
reward reward
+ self.config.discount * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze() + self.config.discount
* self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
) )
# Compute losses. # Compute losses.
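
Note: the TD target above has the form r + gamma * min_k Q_k(z', pi(z')), where taking the minimum over target Q estimates counteracts value overestimation. A standalone sketch of that target computation with hypothetical ensemble outputs:

    import torch

    discount = 0.9
    reward = torch.tensor([0.5, 1.0])          # (batch,)
    # Hypothetical target-Q ensemble predictions for the next latent state and policy action: (ensemble, batch)
    q_ensemble = torch.tensor([[2.0, 3.0],
                               [1.5, 3.5],
                               [2.5, 2.8]])
    td_target = reward + discount * q_ensemble.min(dim=0).values
    print(td_target)  # tensor([1.8500, 3.5200])
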
@ -429,10 +430,12 @@ class TDMPC2Policy(
self.scale.update(qs[0]) self.scale.update(qs[0])
qs = self.scale(qs) qs = self.scale(qs)
rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(-1) rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(
-1
)
pi_loss = ( pi_loss = (
(self.config.entropy_coef * log_pis - qs).mean(dim=(1,2)) (self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2))
* rho * rho
# * temporal_loss_coeffs # * temporal_loss_coeffs
# `action_preds` depends on the first observation and the actions. # `action_preds` depends on the first observation and the actions.
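
Note: the policy objective above maximizes the scale-normalized Q value with an entropy bonus, weighted by rho^t over the horizon; minimizing `entropy_coef * log_pi - Q` is the usual SAC-style form. A small self-contained sketch with dummy tensors in place of the model outputs (the exact tensor shapes in the LeRobot code may differ):

    import torch

    horizon, batch, entropy_coef, decay = 4, 3, 1e-4, 0.5
    log_pis = torch.randn(horizon, batch, 1)  # log-probs of actions sampled from the policy
    qs = torch.randn(horizon, batch, 1)       # scale-normalized Q estimates for those actions

    # rho^t decays the contribution of later time steps.
    rho = torch.pow(torch.tensor(decay), torch.arange(horizon))
    pi_loss = ((entropy_coef * log_pis - qs).mean(dim=(1, 2)) * rho).mean()
    print(pi_loss.shape)  # torch.Size([]) -- a scalar loss
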
@ -470,6 +473,7 @@ class TDMPC2Policy(
"""Update the target model's using polyak averaging.""" """Update the target model's using polyak averaging."""
self.model.update_target_Q() self.model.update_target_Q()
class TDMPC2WorldModel(nn.Module): class TDMPC2WorldModel(nn.Module):
"""Latent dynamics model used in TD-MPC2.""" """Latent dynamics model used in TD-MPC2."""
@ -480,28 +484,39 @@ class TDMPC2WorldModel(nn.Module):
self._encoder = TDMPC2ObservationEncoder(config) self._encoder = TDMPC2ObservationEncoder(config)
# Define latent dynamics head # Define latent dynamics head
self._dynamics = nn.Sequential(NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), self._dynamics = nn.Sequential(
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim), NormedLinear(config.mlp_dim, config.mlp_dim),
NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim))) NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)),
)
# Define reward head # Define reward head
self._reward = nn.Sequential(NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), self._reward = nn.Sequential(
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim), NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, max(config.num_bins, 1))) nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
)
# Define policy head # Define policy head
self._pi = nn.Sequential(NormedLinear(config.latent_dim, config.mlp_dim), self._pi = nn.Sequential(
NormedLinear(config.latent_dim, config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim), NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0])) nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]),
)
# Define ensemble of Q functions # Define ensemble of Q functions
self._Qs = nn.ModuleList( self._Qs = nn.ModuleList(
[ [
nn.Sequential( nn.Sequential(
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim, dropout=config.dropout), NormedLinear(
config.latent_dim + config.output_shapes["action"][0],
config.mlp_dim,
dropout=config.dropout,
),
NormedLinear(config.mlp_dim, config.mlp_dim), NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, max(config.num_bins, 1)) nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
) for _ in range(config.q_ensemble_size) )
for _ in range(config.q_ensemble_size)
] ]
) )
@ -520,6 +535,7 @@ class TDMPC2WorldModel(nn.Module):
Custom weight initializations proposed in TD-MPC2. Custom weight initializations proposed in TD-MPC2.
""" """
def _apply_fn(m): def _apply_fn(m):
if isinstance(m, nn.Linear): if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02) nn.init.trunc_normal_(m.weight, std=0.02)
@ -529,7 +545,7 @@ class TDMPC2WorldModel(nn.Module):
for i, p in enumerate(m): for i, p in enumerate(m):
if p.dim() == 3: # Linear if p.dim() == 3: # Linear
nn.init.trunc_normal_(p, std=0.02) # Weight nn.init.trunc_normal_(p, std=0.02) # Weight
nn.init.constant_(m[i+1], 0) # Bias nn.init.constant_(m[i + 1], 0) # Bias
self.apply(_apply_fn) self.apply(_apply_fn)
@ -641,7 +657,7 @@ class TDMPC2WorldModel(nn.Module):
Soft-update target Q-networks using Polyak averaging. Soft-update target Q-networks using Polyak averaging.
""" """
with torch.no_grad(): with torch.no_grad():
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()): for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters(), strict=False):
p_target.data.lerp_(p.data, self.config.target_model_momentum) p_target.data.lerp_(p.data, self.config.target_model_momentum)
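
Note: Polyak averaging, as used above, nudges each target parameter toward the online parameter by a small factor: theta_target <- theta_target + tau * (theta - theta_target). A standalone sketch with two throwaway linear layers (tau here is a stand-in for the momentum value used in the config):

    import torch
    from torch import nn

    online = nn.Linear(4, 4)
    target = nn.Linear(4, 4)
    tau = 0.01  # stand-in for the target-model momentum

    with torch.no_grad():
        for p, p_target in zip(online.parameters(), target.parameters(), strict=False):
            # lerp_ implements p_target <- p_target + tau * (p - p_target)
            p_target.data.lerp_(p.data, tau)
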
@ -672,7 +688,7 @@ class TDMPC2ObservationEncoder(nn.Module):
) )
dummy_batch = torch.zeros(1, *config.input_shapes[obs_key]) dummy_batch = torch.zeros(1, *config.input_shapes[obs_key])
with torch.inference_mode(): with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:] out_shape = encoder_module(dummy_batch).shape[1:]
encoder_module.extend( encoder_module.extend(
nn.Sequential( nn.Sequential(
nn.Flatten(), nn.Flatten(),
@@ -680,28 +696,30 @@ class TDMPC2ObservationEncoder(nn.Module):
                     )
                 )
-            elif "observation.state" in config.input_shapes:
+            elif (
+                "observation.state" in config.input_shapes
+                or "observation.environment_state" in config.input_shapes
+            ):
                 encoder_module = nn.ModuleList()
-                encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
+                encoder_module.append(
+                    NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim)
+                )
                 assert config.num_enc_layers > 0
                 for _ in range(config.num_enc_layers - 1):
-                    encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
-                encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
-                encoder_module = nn.Sequential(*encoder_module)
-
-            elif "observation.environment_state" in config.input_shapes:
-                encoder_module = nn.ModuleList()
-                encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
-                assert config.num_enc_layers > 0
-                for _ in range(config.num_enc_layers - 1):
-                    encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
-                encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
+                    encoder_module.append(
+                        NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim)
+                    )
+                encoder_module.append(
+                    NormedLinear(
+                        config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)
+                    )
+                )
                 encoder_module = nn.Sequential(*encoder_module)
             else:
                 raise NotImplementedError(f"No corresponding encoder module for key {obs_key}.")
-            encoder_dict[obs_key] = encoder_module
+            encoder_dict[obs_key.replace(".", "")] = encoder_module
 
         self.encoder = nn.ModuleDict(encoder_dict)
@@ -714,9 +732,11 @@ class TDMPC2ObservationEncoder(nn.Module):
         feat = []
         for obs_key in self.config.input_shapes:
             if "observation.image" in obs_key:
-                feat.append(flatten_forward_unflatten(self.encoder[obs_key], obs_dict[obs_key]))
+                feat.append(
+                    flatten_forward_unflatten(self.encoder[obs_key.replace(".", "")], obs_dict[obs_key])
+                )
             else:
-                feat.append(self.encoder[obs_key](obs_dict[obs_key]))
+                feat.append(self.encoder[obs_key.replace(".", "")](obs_dict[obs_key]))
         return torch.stack(feat, dim=0).mean(0)
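
Note: the `obs_key.replace(".", "")` above is needed because `nn.ModuleDict` registers each entry as a submodule, and PyTorch module names may not contain dots; keys such as "observation.state" therefore have to be sanitized before registration and looked up the same way at forward time. A quick demonstration:

    import torch
    from torch import nn

    try:
        nn.ModuleDict({"observation.state": nn.Linear(4, 8)})
    except KeyError as e:
        print(e)  # module name can't contain "."

    obs_key = "observation.state"
    encoder = nn.ModuleDict({obs_key.replace(".", ""): nn.Linear(4, 8)})
    out = encoder[obs_key.replace(".", "")](torch.randn(1, 4))
    print(out.shape)  # torch.Size([1, 8])
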

View File

@@ -13,7 +13,7 @@ class Ensemble(nn.Module):
         super().__init__()
         modules = nn.ModuleList(modules)
         fn, params, _ = combine_state_for_ensemble(modules)
-        self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **kwargs)
+        self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness="different", **kwargs)
         self.params = nn.ParameterList([nn.Parameter(p) for p in params])
         self._repr = str(modules)
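
Note: `combine_state_for_ensemble` comes from the older functorch API. The same vectorized-ensemble idea can be sketched with the current `torch.func` utilities (`stack_module_state`, `functional_call`, `vmap`); this is an illustration of the technique, not a drop-in replacement for the Ensemble class above:

    import copy
    import torch
    from torch import nn
    from torch.func import functional_call, stack_module_state

    # Five independent Q-style MLPs with identical architecture.
    models = [nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1)) for _ in range(5)]
    params, buffers = stack_module_state(models)

    # A stateless "meta" copy used only to describe the architecture to functional_call.
    base = copy.deepcopy(models[0]).to("meta")

    def call_one(p, b, x):
        return functional_call(base, (p, b), (x,))

    x = torch.randn(4, 8)
    # vmap over the stacked parameters; the same input batch is fed to every ensemble member.
    out = torch.vmap(call_one, in_dims=(0, 0, None))(params, buffers, x)
    print(out.shape)  # torch.Size([5, 4, 1])
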
@ -21,7 +21,8 @@ class Ensemble(nn.Module):
return self.vmap([p for p in self.params], (), *args, **kwargs) return self.vmap([p for p in self.params], (), *args, **kwargs)
def __repr__(self): def __repr__(self):
return 'Vectorized ' + self._repr return "Vectorized " + self._repr
class SimNorm(nn.Module): class SimNorm(nn.Module):
""" """
@ -48,7 +49,7 @@ class NormedLinear(nn.Linear):
Linear layer with LayerNorm, activation, and optionally dropout. Linear layer with LayerNorm, activation, and optionally dropout.
""" """
def __init__(self, *args, dropout=0., act=nn.Mish(inplace=True), **kwargs): def __init__(self, *args, dropout=0.0, act=nn.Mish(inplace=True), **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.ln = nn.LayerNorm(self.out_features) self.ln = nn.LayerNorm(self.out_features)
self.act = act self.act = act
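
Note: the building block configured above is a linear layer followed by LayerNorm and an activation, with optional dropout. A minimal stand-alone sketch of that pattern (illustrative only; the exact placement of dropout in the upstream implementation may differ):

    import torch
    from torch import nn

    class TinyNormedLinear(nn.Linear):
        """Linear layer followed by optional dropout, LayerNorm, and an activation (illustration)."""

        def __init__(self, *args, dropout=0.0, act=None, **kwargs):
            super().__init__(*args, **kwargs)
            self.ln = nn.LayerNorm(self.out_features)
            self.act = act if act is not None else nn.Mish(inplace=True)
            self.dropout = nn.Dropout(dropout) if dropout else None

        def forward(self, x):
            x = super().forward(x)
            if self.dropout is not None:
                x = self.dropout(x)
            return self.act(self.ln(x))

    layer = TinyNormedLinear(16, 32, dropout=0.01)
    print(layer(torch.randn(4, 16)).shape)  # torch.Size([4, 32])
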
@@ -62,16 +63,21 @@ class NormedLinear(nn.Linear):
 
     def __repr__(self):
         repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
-        return f"NormedLinear(in_features={self.in_features}, "\
-            f"out_features={self.out_features}, "\
-            f"bias={self.bias is not None}{repr_dropout}, "\
+        return (
+            f"NormedLinear(in_features={self.in_features}, "
+            f"out_features={self.out_features}, "
+            f"bias={self.bias is not None}{repr_dropout}, "
             f"act={self.act.__class__.__name__})"
+        )
 
 
 def soft_cross_entropy(pred, target, cfg):
     """Computes the cross entropy loss between predictions and soft targets."""
     pred = F.log_softmax(pred, dim=-1)
     target = two_hot(target, cfg)
+    import pudb
+    pudb.set_trace()
     return -(target * pred).sum(-1, keepdim=True)
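
Note: soft cross-entropy with a two-hot target, as above, is simply `-(target * log_softmax(pred)).sum(-1)`. A toy sketch with an explicit soft target, independent of the `cfg` and `two_hot` helpers:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5)                          # predicted distribution over 5 reward/value bins
    target = torch.tensor([[0.0, 0.7, 0.3, 0.0, 0.0],   # soft "two-hot" targets: mass split across two
                           [0.0, 0.0, 0.0, 0.4, 0.6]])  # adjacent bins, each row sums to 1
    loss = -(target * F.log_softmax(logits, dim=-1)).sum(-1, keepdim=True)
    print(loss.shape)  # torch.Size([2, 1])
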
@@ -135,14 +141,15 @@ def two_hot(x, cfg):
         return x
     elif cfg.num_bins == 1:
         return symlog(x)
-    x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1)
+    x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
     bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
-    bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)
+    bin_offset = (x - cfg.vmin) / cfg.bin_size - bin_idx.float()
     soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
-    soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset)
-    soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
+    soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset)
+    soft_two_hot.scatter_(1, (bin_idx + 1) % cfg.num_bins, bin_offset)
     return soft_two_hot
 
 
 def two_hot_inv(x, bins):
     """Converts a batch of soft two-hot encoded vectors to scalars."""
     num_bins = bins.shape[0]
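
Note: two-hot encoding, as in `two_hot` above, places a scalar on a uniform grid of bins over [vmin, vmax] (in symlog space in the actual code) and splits its probability mass between the two nearest bins; `two_hot_inv` reverses this by taking a softmax-weighted expectation of the predicted logits over the bin centers. A self-contained sketch of the round trip, without symlog and with hypothetical bin settings:

    import torch

    vmin, vmax, num_bins = -2.0, 2.0, 5
    bin_size = (vmax - vmin) / (num_bins - 1)
    bins = torch.linspace(vmin, vmax, num_bins)  # bin centers: [-2, -1, 0, 1, 2]

    x = torch.tensor([[0.3], [-1.5]])            # (batch, 1) scalars to encode
    x = torch.clamp(x, vmin, vmax)
    bin_idx = torch.floor((x - vmin) / bin_size).long()
    bin_offset = (x - vmin) / bin_size - bin_idx.float()
    soft_two_hot = torch.zeros(x.size(0), num_bins)
    soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset)
    soft_two_hot.scatter_(1, (bin_idx + 1) % num_bins, bin_offset)
    print(soft_two_hot)  # e.g. 0.3 -> [0.0, 0.0, 0.7, 0.3, 0.0]

    # Decoding: the expectation over bin centers recovers the original scalar.
    decoded = (soft_two_hot * bins).sum(-1)
    print(decoded)       # tensor([ 0.3000, -1.5000])
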