diff --git a/lerobot/scripts/server/policy_server.py b/lerobot/scripts/server/policy_server.py index b2eb00fa..841273bf 100644 --- a/lerobot/scripts/server/policy_server.py +++ b/lerobot/scripts/server/policy_server.py @@ -17,15 +17,11 @@ def get_device(): class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): def __init__(self, policy: PreTrainedPolicy = None): + # TODO: Add code for loading and using policy for inference self.policy = policy + # TODO: Add device specification for policy inference at init + self.observation = None - # TODO: Add device specification for policy inference - # self.observation = None - self.observation = async_inference_pb2.Observation( - transfer_state=2, - data=np.array([1], dtype=np.float32).tobytes() - ) - self.lock = threading.Lock() # keeping a list of all observations received from the robot client self.observations = [] @@ -62,7 +58,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): def StreamActions(self, request, context): """Stream actions to the robot client""" client_id = context.peer() - print(f"Client {client_id} connected for action streaming") + # print(f"Client {client_id} connected for action streaming") with self.lock: yield self._generate_and_queue_action(self.observation) @@ -80,17 +76,20 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): raise NotImplementedError("Not implemented") def _generate_and_queue_action(self, observation): - """Generate an action based on the observation (dummy logic). + """Generate a buffer of actions based on the observation (dummy logic). Mainly used for testing purposes""" - # Debinarize the observation data + time.sleep(2) + # Debinarize observation data data = np.frombuffer( observation.data, dtype=np.float32 ) # dummy transform on the observation data - action = (data * 1.4).sum() - # map action to bytes - action_data = np.array([action], dtype=np.float32).tobytes() + action_content = (data * 2).sum().item() + action_data = (action_content * np.ones( + shape=(10, 5), # 10 5-dimensional actions + dtype=np.float32 + )).tobytes() action = async_inference_pb2.Action( transfer_state=observation.transfer_state,