fix: server predicts multiple actions for a given observation, VLA-like

This commit is contained in:
Francesco Capuano 2025-04-15 11:59:59 +02:00
parent 485d64c8f4
commit b9b7492132
1 changed files with 12 additions and 13 deletions

View File

@ -17,14 +17,10 @@ def get_device():
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
def __init__(self, policy: PreTrainedPolicy = None):
# TODO: Add code for loading and using policy for inference
self.policy = policy
# TODO: Add device specification for policy inference
# self.observation = None
self.observation = async_inference_pb2.Observation(
transfer_state=2,
data=np.array([1], dtype=np.float32).tobytes()
)
# TODO: Add device specification for policy inference at init
self.observation = None
self.lock = threading.Lock()
# keeping a list of all observations received from the robot client
@ -62,7 +58,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
def StreamActions(self, request, context):
"""Stream actions to the robot client"""
client_id = context.peer()
print(f"Client {client_id} connected for action streaming")
# print(f"Client {client_id} connected for action streaming")
with self.lock:
yield self._generate_and_queue_action(self.observation)
@ -80,17 +76,20 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
raise NotImplementedError("Not implemented")
def _generate_and_queue_action(self, observation):
"""Generate an action based on the observation (dummy logic).
"""Generate a buffer of actions based on the observation (dummy logic).
Mainly used for testing purposes"""
# Debinarize the observation data
time.sleep(2)
# Debinarize observation data
data = np.frombuffer(
observation.data,
dtype=np.float32
)
# dummy transform on the observation data
action = (data * 1.4).sum()
# map action to bytes
action_data = np.array([action], dtype=np.float32).tobytes()
action_content = (data * 2).sum().item()
action_data = (action_content * np.ones(
shape=(10, 5), # 10 5-dimensional actions
dtype=np.float32
)).tobytes()
action = async_inference_pb2.Action(
transfer_state=observation.transfer_state,