Source code for luckyrobots.client

"""
LuckyEngine gRPC client.

Uses checked-in Python stubs generated from the `.proto` files under
`src/luckyrobots/grpc/proto/`.
"""

from __future__ import annotations

import logging
import math
import statistics
import time
from types import SimpleNamespace
from typing import Any, Optional

import grpc  # type: ignore

logger = logging.getLogger("luckyrobots.client")

try:
    from .grpc.generated import agent_pb2  # type: ignore
    from .grpc.generated import agent_pb2_grpc  # type: ignore
    from .grpc.generated import common_pb2  # type: ignore
    from .grpc.generated import debug_pb2  # type: ignore
    from .grpc.generated import debug_pb2_grpc  # type: ignore
    from .grpc.generated import mujoco_pb2  # type: ignore
    from .grpc.generated import mujoco_pb2_grpc  # type: ignore
    from .grpc.generated import scene_pb2  # type: ignore
    from .grpc.generated import scene_pb2_grpc  # type: ignore
except Exception as e:  # pragma: no cover
    raise ImportError(
        "Missing generated gRPC stubs. Regenerate them from the protos in "
        "src/luckyrobots/grpc/proto into src/luckyrobots/grpc/generated."
    ) from e

from .models import ObservationResponse
from .models.observation import CameraFrame
from .models.benchmark import BenchmarkResult
from . import sim_contract


class GrpcConnectionError(Exception):
    """Error signalling that the gRPC connection is unavailable or failed."""

    def __init__(self, message: str):
        """Log the failure at WARNING level and initialise the exception.

        Args:
            message: Human-readable description of the connection problem.
        """
        # Constructing the exception is itself logged, so every connection
        # failure leaves a trace even if a caller later swallows it.
        logger.warning("gRPC connection error: %s", message)
        super().__init__(message)
class LuckyEngineClient:
    """
    Client for connecting to the LuckyEngine gRPC server.

    Provides access to gRPC services for RL training:
    - AgentService: stepping, resets
    - SceneService: simulation mode control
    - MujocoService: health checks, joint state

    Usage:
        client = LuckyEngineClient(host="127.0.0.1", port=50051)
        client.connect()
        client.wait_for_server()

        schema = client.get_agent_schema()
        obs = client.step(actions=[0.0] * 12)

        client.close()
    """

    # Schema-cache key used when the caller does not name an agent.
    _DEFAULT_AGENT_KEY = "agent_0"

    # Action-vector length used by benchmark() when no schema has been cached.
    _DEFAULT_BENCHMARK_ACTION_SIZE = 12

    # Mode name -> integer sent in SetSimulationModeRequest.mode.
    _SIMULATION_MODES = {
        "realtime": 0,
        "deterministic": 1,
        "fast": 2,
    }

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 50051,
        timeout: float = 5.0,
        *,
        robot_name: Optional[str] = None,
    ) -> None:
        """
        Initialize the LuckyEngine gRPC client.

        No network activity happens here; call connect() to open the channel.

        Args:
            host: gRPC server host address.
            port: gRPC server port.
            timeout: Default timeout for RPC calls in seconds.
            robot_name: Default robot name for calls that require it.
        """
        self.host = host
        self.port = port
        self.timeout = timeout
        self._robot_name = robot_name

        self._channel = None

        # Service stubs (populated after connect)
        self._scene = None
        self._mujoco = None
        self._agent = None
        self._debug = None

        # Cached agent schemas: agent_name -> (observation_names, action_names)
        self._schema_cache: dict[str, tuple[list[str], list[str]]] = {}

        # Camera requests included on every Step RPC (configured via configure_cameras).
        self._camera_requests: list = []

        # Protobuf modules (for discoverability + explicit imports).
        self._pb = SimpleNamespace(
            common=common_pb2,
            scene=scene_pb2,
            mujoco=mujoco_pb2,
            agent=agent_pb2,
            debug=debug_pb2,
        )

    # ── Internal helpers ──

    def _resolve_timeout(self, timeout: Optional[float]) -> float:
        """Return ``timeout`` when truthy, otherwise the client default.

        Both ``None`` and ``0`` select ``self.timeout``.
        """
        return timeout or self.timeout

    def _resolve_robot_name(self, robot_name: str) -> str:
        """Return ``robot_name`` or the configured default.

        Raises:
            ValueError: If neither the argument nor the default is set.
        """
        resolved = robot_name or self._robot_name
        if not resolved:
            raise ValueError("robot_name is required")
        return resolved

    @staticmethod
    def _require_stub(stub: Any) -> Any:
        """Return ``stub``, raising GrpcConnectionError if it is not set."""
        if stub is None:
            raise GrpcConnectionError("Not connected. Call connect() first.")
        return stub

    # ── Connection management ──

    def connect(self) -> None:
        """
        Open an insecure gRPC channel to the server and create service stubs.

        Opening the channel does not verify that the server is reachable; use
        health_check() or wait_for_server() for that. If a channel is already
        open it is closed first so repeated calls do not leak channels.

        Raises:
            GrpcConnectionError: If the channel cannot be created.
        """
        target = f"{self.host}:{self.port}"

        if self._channel is not None:
            # Re-connecting: release the previous channel instead of leaking it.
            self.close()

        logger.info("Connecting to LuckyEngine gRPC server at %s", target)
        try:
            channel = grpc.insecure_channel(target)
        except Exception as e:
            raise GrpcConnectionError(
                f"Failed to open gRPC channel to {target}: {e}"
            ) from e

        self._channel = channel

        # Create service stubs
        self._scene = scene_pb2_grpc.SceneServiceStub(channel)
        self._mujoco = mujoco_pb2_grpc.MujocoServiceStub(channel)
        self._agent = agent_pb2_grpc.AgentServiceStub(channel)
        self._debug = debug_pb2_grpc.DebugServiceStub(channel)

        logger.info("Channel opened to %s (server not verified yet)", target)

    def close(self) -> None:
        """Close the gRPC channel and drop all service stubs.

        Does nothing when no channel is open. Errors raised while closing the
        channel are logged at DEBUG level and otherwise ignored.
        """
        if self._channel is None:
            return

        try:
            self._channel.close()
        except Exception as e:
            logger.debug("Error closing gRPC channel: %s", e)

        self._channel = None
        self._scene = None
        self._mujoco = None
        self._agent = None
        self._debug = None
        logger.info("gRPC channel closed")

    def is_connected(self) -> bool:
        """Return True if a channel object exists.

        This only reflects local state; it does not contact the server.
        """
        return self._channel is not None

    def health_check(self, timeout: Optional[float] = None) -> bool:
        """
        Perform a health check by calling GetMujocoInfo.

        Args:
            timeout: Timeout in seconds (uses default if None or 0).

        Returns:
            True if server responds, False otherwise (including when the
            client is not connected).
        """
        if not self.is_connected():
            return False

        try:
            self._mujoco.GetMujocoInfo(
                self.pb.mujoco.GetMujocoInfoRequest(robot_name=self._robot_name or ""),
                timeout=self._resolve_timeout(timeout),
            )
        except Exception as e:
            # Any failure (RPC error, deadline, ...) counts as "not healthy".
            logger.debug("Health check failed: %s", e)
            return False
        return True

    def wait_for_server(
        self, timeout: float = 30.0, poll_interval: float = 0.5
    ) -> bool:
        """
        Wait for the gRPC server to become available.

        Repeatedly connects (if needed) and runs health_check() until it
        succeeds or ``timeout`` seconds have elapsed.

        Args:
            timeout: Maximum time to wait in seconds.
            poll_interval: Time between connection attempts.

        Returns:
            True if server became available, False if timeout.
        """
        start = time.perf_counter()

        while True:
            remaining = timeout - (time.perf_counter() - start)
            if remaining <= 0:
                break

            if not self.is_connected():
                try:
                    self.connect()
                except Exception as e:
                    # Best-effort: keep polling, but leave a trace of the failure.
                    logger.debug("Connect attempt failed, will retry: %s", e)

            if self.health_check(timeout=min(poll_interval, remaining)):
                logger.info(
                    "Connected to LuckyEngine gRPC server at %s:%s",
                    self.host,
                    self.port,
                )
                return True

            # Never sleep past the overall deadline.
            remaining = timeout - (time.perf_counter() - start)
            if remaining <= 0:
                break
            time.sleep(max(0.0, min(poll_interval, remaining)))

        return False

    @property
    def pb(self) -> Any:
        """Access protobuf modules grouped by domain (e.g., `client.pb.scene`)."""
        return self._pb

    @property
    def robot_name(self) -> Optional[str]:
        """Default robot name used by calls that accept an optional robot_name."""
        return self._robot_name

    def set_robot_name(self, robot_name: str) -> None:
        """Set the default robot name used by calls that accept an optional robot_name."""
        self._robot_name = robot_name

    @property
    def scene(self) -> Any:
        """SceneService stub.

        Raises:
            GrpcConnectionError: If connect() has not been called.
        """
        return self._require_stub(self._scene)

    @property
    def mujoco(self) -> Any:
        """MujocoService stub.

        Raises:
            GrpcConnectionError: If connect() has not been called.
        """
        return self._require_stub(self._mujoco)

    @property
    def agent(self) -> Any:
        """AgentService stub.

        Raises:
            GrpcConnectionError: If connect() has not been called.
        """
        return self._require_stub(self._agent)

    @property
    def debug(self) -> Any:
        """DebugService stub.

        Raises:
            GrpcConnectionError: If connect() has not been called.
        """
        return self._require_stub(self._debug)

    # ── Camera configuration ──

    def configure_cameras(self, cameras: list[dict]) -> None:
        """Configure cameras to capture on every Step RPC.

        Replaces any previously configured cameras; pass an empty list to
        stop requesting camera frames.

        Args:
            cameras: List of camera configs. Each dict has keys:
                name: Camera entity name in the scene.
                width: Desired image width (0 = native resolution).
                height: Desired image height (0 = native resolution).

        Raises:
            KeyError: If a camera config has no "name" key.
        """
        self._camera_requests = [
            self.pb.agent.GetCameraFrameRequest(
                name=c["name"],
                width=c.get("width", 0),
                height=c.get("height", 0),
            )
            for c in cameras
        ]

    # ── MujocoService RPCs ──

    def get_joint_state(self, robot_name: str = "", timeout: Optional[float] = None):
        """Get current joint state (positions and velocities).

        Args:
            robot_name: Robot entity name (uses default if empty).
            timeout: RPC timeout in seconds.

        Returns:
            GetJointStateResponse with state.positions (qpos) and
            state.velocities (qvel).

        Raises:
            ValueError: If no robot name is given and no default is set.
            GrpcConnectionError: If the client is not connected.
        """
        timeout = self._resolve_timeout(timeout)
        robot_name = self._resolve_robot_name(robot_name)

        return self.mujoco.GetJointState(
            self.pb.mujoco.GetJointStateRequest(robot_name=robot_name),
            timeout=timeout,
        )

    def get_mujoco_info(self, robot_name: str = "", timeout: Optional[float] = None):
        """Get MuJoCo model information (joint names, limits, etc.).

        Args:
            robot_name: Robot entity name (uses default if empty).
            timeout: RPC timeout in seconds.

        Returns:
            The GetMujocoInfo RPC response.

        Raises:
            ValueError: If no robot name is given and no default is set.
            GrpcConnectionError: If the client is not connected.
        """
        timeout = self._resolve_timeout(timeout)
        robot_name = self._resolve_robot_name(robot_name)

        return self.mujoco.GetMujocoInfo(
            self.pb.mujoco.GetMujocoInfoRequest(robot_name=robot_name),
            timeout=timeout,
        )

    # ── AgentService RPCs ──

    def get_agent_schema(self, agent_name: str = "", timeout: Optional[float] = None):
        """Get agent schema (observation/action sizes and names).

        The schema is cached for subsequent step() calls to enable named
        access to observation values.

        Args:
            agent_name: Agent name (empty = default agent).
            timeout: RPC timeout.

        Returns:
            GetAgentSchemaResponse with schema containing observation_names,
            action_names, observation_size, and action_size.
        """
        resp = self.agent.GetAgentSchema(
            self.pb.agent.GetAgentSchemaRequest(agent_name=agent_name),
            timeout=self._resolve_timeout(timeout),
        )

        # Cache the schema for named observation access
        schema = getattr(resp, "schema", None)
        if schema is not None:
            cache_key = agent_name or self._DEFAULT_AGENT_KEY
            obs_names = list(schema.observation_names) if schema.observation_names else []
            action_names = list(schema.action_names) if schema.action_names else []
            self._schema_cache[cache_key] = (obs_names, action_names)
            logger.debug(
                "Cached schema for %s: %d obs names, %d action names",
                cache_key,
                len(obs_names),
                len(action_names),
            )

        return resp

    def reset_agent(
        self,
        agent_name: str = "",
        randomization_cfg: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        """
        Reset a specific agent.

        Args:
            agent_name: Agent logical name. Empty string means default agent.
            randomization_cfg: Optional simulation contract config for this reset.
            timeout: Timeout in seconds (uses default if None).

        Returns:
            ResetAgentResponse with success and message fields.
        """
        request_kwargs = {"agent_name": agent_name}

        if randomization_cfg is not None:
            # Only attach a contract when the caller supplied one.
            request_kwargs["simulation_contract"] = sim_contract.to_proto(
                self.pb.agent, randomization_cfg
            )

        return self.agent.ResetAgent(
            self.pb.agent.ResetAgentRequest(**request_kwargs),
            timeout=self._resolve_timeout(timeout),
        )

    def step(
        self,
        actions: list[float],
        agent_name: str = "",
        step_timeout_s: float = 0.0,
        timeout: Optional[float] = None,
    ) -> ObservationResponse:
        """
        Synchronous RL step: apply action, wait for physics, return observation.

        Args:
            actions: Action vector to apply for this step.
            agent_name: Agent name (empty = default agent).
            step_timeout_s: Server-side timeout for waiting for the physics step
                (seconds). 0 means use server default.
            timeout: RPC timeout in seconds.

        Returns:
            ObservationResponse with observation after physics step.

        Raises:
            RuntimeError: If the RPC deadline is exceeded on the client side,
                or the server reports ``success == False``.
            grpc.RpcError: For any other RPC failure.
        """
        timeout = self._resolve_timeout(timeout)

        try:
            resp = self.agent.Step(
                self.pb.agent.StepRequest(
                    agent_name=agent_name,
                    actions=actions,
                    timeout_s=step_timeout_s,
                    camera_requests=self._camera_requests,
                ),
                timeout=timeout,
            )
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                raise RuntimeError(
                    f"Client-side gRPC timeout ({timeout}s): the server did not respond in time. "
                    "This usually means the engine is frozen or the network is unreachable."
                ) from e
            raise

        if not resp.success:
            raise RuntimeError(
                f"Server-side physics timeout: {resp.message} "
                f"(server waited up to its configured timeout for the physics step to complete)"
            )

        agent_frame = resp.observation
        observations = list(agent_frame.observations) if agent_frame.observations else []
        actions_out = list(agent_frame.actions) if agent_frame.actions else []
        timestamp_ms = getattr(agent_frame, "timestamp_ms", 0)
        frame_number = getattr(agent_frame, "frame_number", 0)

        # Names are only available if get_agent_schema() was called for this agent.
        cache_key = agent_name or self._DEFAULT_AGENT_KEY
        obs_names, action_names = self._schema_cache.get(cache_key, (None, None))

        camera_frames = [
            CameraFrame(
                name=nf.name,
                data=bytes(nf.frame.data),
                width=nf.frame.width,
                height=nf.frame.height,
                channels=nf.frame.channels,
                frame_number=nf.frame.frame_number,
            )
            for nf in resp.camera_frames
        ]

        return ObservationResponse(
            observation=observations,
            actions=actions_out,
            timestamp_ms=timestamp_ms,
            frame_number=frame_number,
            agent_name=cache_key,
            observation_names=obs_names,
            action_names=action_names,
            camera_frames=camera_frames,
        )

    # ── Progress reporting ──

    def report_progress(
        self,
        *,
        run_id: str = "",
        task_name: str = "",
        policy_name: str = "",
        phase: str = "",
        current_episode: int = 0,
        total_episodes: int = 0,
        current_step: int = 0,
        max_steps: int = 0,
        elapsed_s: float = 0.0,
        status_text: str = "",
        finished: bool = False,
    ) -> None:
        """Report evaluation/training progress to the engine for UI display.

        Fire-and-forget: errors are logged but never raised. The RPC uses a
        fixed 1.0 second timeout rather than the client default.
        """
        try:
            self.agent.ReportProgress(
                self.pb.agent.ProgressReport(
                    run_id=run_id,
                    task_name=task_name,
                    policy_name=policy_name,
                    phase=phase,
                    current_episode=current_episode,
                    total_episodes=total_episodes,
                    current_step=current_step,
                    max_steps=max_steps,
                    elapsed_s=elapsed_s,
                    status_text=status_text,
                    finished=finished,
                ),
                timeout=1.0,
            )
        except Exception as e:
            logger.debug("report_progress failed (non-fatal): %s", e)

    # ── SceneService RPCs ──

    def set_simulation_mode(
        self,
        mode: str = "fast",
        timeout: Optional[float] = None,
    ):
        """
        Set simulation timing mode.

        Args:
            mode: "realtime", "deterministic", or "fast" (case-insensitive)
                - realtime: Physics runs at 1x wall-clock speed
                - deterministic: Physics runs at fixed rate
                - fast: Physics runs as fast as possible (for RL training)
                Any other value falls back to "fast" and logs a warning.
            timeout: RPC timeout in seconds.

        Returns:
            SetSimulationModeResponse with success and current mode.
        """
        mode_value = self._SIMULATION_MODES.get(mode.lower())
        if mode_value is None:
            # Keep the historical fallback, but make a typo visible in the logs.
            logger.warning(
                "Unknown simulation mode %r; falling back to 'fast'", mode
            )
            mode_value = self._SIMULATION_MODES["fast"]

        return self.scene.SetSimulationMode(
            self.pb.scene.SetSimulationModeRequest(mode=mode_value),
            timeout=self._resolve_timeout(timeout),
        )

    # ── Benchmarking ──

    def benchmark(
        self,
        duration_seconds: float = 5.0,
        method: str = "step",
        print_results: bool = False,
    ) -> BenchmarkResult:
        """Benchmark a client method by calling it in a tight loop.

        For "step", a zero action vector is sent to the default agent. Its
        length is the number of action names in the cached default-agent
        schema (see get_agent_schema()) when available, otherwise 12.

        Args:
            duration_seconds: How long to run the benchmark.
            method: Method to benchmark. Currently supports "step".
            print_results: Print results to stdout.

        Returns:
            BenchmarkResult with timing statistics (latencies in milliseconds).

        Raises:
            ValueError: If method is not recognized.
        """
        if method != "step":
            raise ValueError(
                f"Unknown method '{method}'. Supported: 'step'"
            )

        # Use zero actions for benchmarking, sized from the cached schema if any.
        # NOTE(review): assumes len(action_names) matches the agent's action size.
        _, cached_action_names = self._schema_cache.get(
            self._DEFAULT_AGENT_KEY, (None, None)
        )
        action_size = (
            len(cached_action_names)
            if cached_action_names
            else self._DEFAULT_BENCHMARK_ACTION_SIZE
        )
        zero_actions = [0.0] * action_size

        def call_fn() -> None:
            self.step(actions=zero_actions)

        latencies: list[float] = []
        start = time.perf_counter()
        deadline = start + duration_seconds

        while time.perf_counter() < deadline:
            t0 = time.perf_counter()
            call_fn()
            t1 = time.perf_counter()
            latencies.append((t1 - t0) * 1000.0)  # ms

        elapsed = time.perf_counter() - start
        count = len(latencies)

        if count == 0:
            # Nothing ran (e.g. non-positive duration): report all-zero stats.
            result = BenchmarkResult(
                method=method,
                duration_seconds=elapsed,
                frame_count=0,
                actual_fps=0.0,
                avg_latency_ms=0.0,
                min_latency_ms=0.0,
                max_latency_ms=0.0,
                std_latency_ms=0.0,
                p50_latency_ms=0.0,
                p99_latency_ms=0.0,
            )
        else:
            sorted_lat = sorted(latencies)
            # Percentiles pick the sample at floor(q * (count - 1)); no interpolation.
            p50_idx = int(math.floor(0.50 * (count - 1)))
            p99_idx = int(math.floor(0.99 * (count - 1)))

            result = BenchmarkResult(
                method=method,
                duration_seconds=elapsed,
                frame_count=count,
                actual_fps=count / elapsed if elapsed > 0 else 0.0,
                avg_latency_ms=statistics.mean(latencies),
                min_latency_ms=sorted_lat[0],
                max_latency_ms=sorted_lat[-1],
                # stdev needs at least two samples.
                std_latency_ms=statistics.stdev(latencies) if count > 1 else 0.0,
                p50_latency_ms=sorted_lat[p50_idx],
                p99_latency_ms=sorted_lat[p99_idx],
            )

        if print_results:
            print(f"\n--- Benchmark: {method} ({elapsed:.1f}s) ---")
            print(f" Frames: {result.frame_count}")
            print(f" FPS: {result.actual_fps:.1f}")
            print(f" Avg: {result.avg_latency_ms:.2f} ms")
            print(f" Min: {result.min_latency_ms:.2f} ms")
            print(f" Max: {result.max_latency_ms:.2f} ms")
            print(f" Std: {result.std_latency_ms:.2f} ms")
            print(f" P50: {result.p50_latency_ms:.2f} ms")
            print(f" P99: {result.p99_latency_ms:.2f} ms")

        return result