fix: server predicts multiple actions for a given observation, VLA-like
This commit is contained in:
parent
485d64c8f4
commit
b9b7492132
|
@ -17,14 +17,10 @@ def get_device():
|
|||
|
||||
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
||||
def __init__(self, policy: PreTrainedPolicy = None):
|
||||
# TODO: Add code for loading and using policy for inference
|
||||
self.policy = policy
|
||||
|
||||
# TODO: Add device specification for policy inference
|
||||
# self.observation = None
|
||||
self.observation = async_inference_pb2.Observation(
|
||||
transfer_state=2,
|
||||
data=np.array([1], dtype=np.float32).tobytes()
|
||||
)
|
||||
# TODO: Add device specification for policy inference at init
|
||||
self.observation = None
|
||||
|
||||
self.lock = threading.Lock()
|
||||
# keeping a list of all observations received from the robot client
|
||||
|
@ -62,7 +58,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||
def StreamActions(self, request, context):
|
||||
"""Stream actions to the robot client"""
|
||||
client_id = context.peer()
|
||||
print(f"Client {client_id} connected for action streaming")
|
||||
# print(f"Client {client_id} connected for action streaming")
|
||||
|
||||
with self.lock:
|
||||
yield self._generate_and_queue_action(self.observation)
|
||||
|
@ -80,17 +76,20 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||
raise NotImplementedError("Not implemented")
|
||||
|
||||
def _generate_and_queue_action(self, observation):
|
||||
"""Generate an action based on the observation (dummy logic).
|
||||
"""Generate a buffer of actions based on the observation (dummy logic).
|
||||
Mainly used for testing purposes"""
|
||||
# Debinarize the observation data
|
||||
time.sleep(2)
|
||||
# Debinarize observation data
|
||||
data = np.frombuffer(
|
||||
observation.data,
|
||||
dtype=np.float32
|
||||
)
|
||||
# dummy transform on the observation data
|
||||
action = (data * 1.4).sum()
|
||||
# map action to bytes
|
||||
action_data = np.array([action], dtype=np.float32).tobytes()
|
||||
action_content = (data * 2).sum().item()
|
||||
action_data = (action_content * np.ones(
|
||||
shape=(10, 5), # 10 5-dimensional actions
|
||||
dtype=np.float32
|
||||
)).tobytes()
|
||||
|
||||
action = async_inference_pb2.Action(
|
||||
transfer_state=observation.transfer_state,
|
||||
|
|
Loading…
Reference in New Issue