diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..d1e5314 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,79 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## What This Project Does + +RSIPI enables real-time control of KUKA industrial robots from Python via the RSI (Robot Sensor Interface) protocol. The robot sends its position ~250 times/second over UDP, and this library lets you send back position corrections to control the robot externally. + +## Build & Development Commands + +```bash +# Install dependencies +pip install -e . + +# Or install from requirements (if present) +pip install pandas>=2.0 numpy>=1.22 matplotlib>=3.5 lxml>=4.9 scipy>=1.8 + +# Run the CLI +python -m RSIPI.rsi_cli --config RSI_EthernetConfig.xml + +# Run the echo server (for offline testing without a real robot) +python -m RSIPI.rsi_echo_server +``` + +**No test suite exists** - testing is done via the echo server simulation and example scripts in `examples/`. + +## Architecture + +### Core Communication Flow + +``` +KUKA Robot Controller <--UDP/XML--> NetworkProcess <--multiprocessing.Manager--> RSIClient <-- RSIAPI/CLI +``` + +1. **NetworkProcess** (`network_handler.py`) - Runs in separate process via `multiprocessing.Process`. Binds to UDP socket, receives XML from robot, parses into `receive_variables`, sends XML from `send_variables` back to robot. Uses `start_event` to wait for explicit start signal. + +2. **RSIClient** (`rsi_client.py`) - Orchestrates the system. Initializes ConfigParser, SafetyManager, and NetworkProcess. Uses `multiprocessing.Manager` dicts for thread-safe variable sharing between processes. + +3. **RSIAPI** (`rsi_api.py`) - High-level API wrapping RSIClient. Runs RSIClient in a daemon thread. Provides trajectory planning, logging, plotting, and safety controls. + +4. **RSICommandLineInterface** (`rsi_cli.py`) - Interactive CLI that wraps RSIAPI. + +### Key Shared State + +Variables are shared between processes using `multiprocessing.Manager().dict()`: +- `send_variables` - Values to send to robot (RKorr corrections, digital outputs, etc.) +- `receive_variables` - Values received from robot (RIst position, ASPos joints, IPOC timestamp) + +### Configuration + +`RSI_EthernetConfig.xml` defines: +- Network settings (IP, port) in `` section +- Send variables in `` - what the robot receives from us +- Receive variables in `` - what we receive from robot + +Variable tags like `DEF_RIst` get the `DEF_` prefix stripped and are expanded using `internal_structure` in ConfigParser to full dicts (e.g., `RIst: {X, Y, Z, A, B, C}`). + +### Safety Layer + +**SafetyManager** (`safety_manager.py`) validates all outgoing values against configurable limits. Can load limits from `.rsi.xml` files. Supports emergency stop and safety override modes. + +### Trajectory Execution + +`TrajectoryPlanner` generates interpolated waypoints. `execute_trajectory()` in RSIAPI uses asyncio to send points at specified rate (default 12ms for Cartesian, 400ms for joints). + +## Important Patterns + +- **IPOC synchronization**: The robot sends an IPOC (timestamp) value. The response must include `IPOC + 4` to maintain sync. This is handled automatically in `NetworkProcess.process_received_data()`. + +- **Lazy client initialization**: RSIAPI uses `_ensure_client()` pattern - RSIClient is created on first use, not at RSIAPI instantiation. + +- **Non-blocking start**: `start_rsi()` runs the client loop in a daemon thread. The NetworkProcess waits on `start_event` before binding the socket. + +## File Locations + +- Source code: `src/RSIPI/` +- Example scripts: `examples/` +- Config template: `RSI_EthernetConfig.xml` +- Logs written to: `logs/` (created at runtime) diff --git a/pyproject.toml b/pyproject.toml index 7ce9b06..4833b5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,8 +25,17 @@ classifiers = [ "Operating System :: OS Independent", ] +[project.optional-dependencies] +dev = [ + "pytest>=7.0", +] + [tool.setuptools] package-dir = {"" = "src"} [tool.setuptools.packages.find] -where = ["src"] \ No newline at end of file +where = ["src"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["src"] \ No newline at end of file diff --git a/src/RSIPI/config_parser.py b/src/RSIPI/config_parser.py index 8afa5cb..8ffa6c9 100644 --- a/src/RSIPI/config_parser.py +++ b/src/RSIPI/config_parser.py @@ -17,7 +17,7 @@ class ConfigParser: config_file (str): Path to the RSI_EthernetConfig.xml file. rsi_limits_file (str, optional): Path to .rsi.xml file containing safety limits. """ - from src.RSIPI.rsi_limit_parser import parse_rsi_limits + from .rsi_limit_parser import parse_rsi_limits self.config_file = config_file self.rsi_limits_file = rsi_limits_file diff --git a/src/RSIPI/echo_server_gui.py b/src/RSIPI/echo_server_gui.py index a977c75..1255cc5 100644 --- a/src/RSIPI/echo_server_gui.py +++ b/src/RSIPI/echo_server_gui.py @@ -2,7 +2,7 @@ import tkinter as tk from tkinter import ttk, filedialog import threading import time -from src.RSIPI.rsi_echo_server import EchoServer +from .rsi_echo_server import EchoServer import matplotlib.pyplot as plt from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg from mpl_toolkits.mplot3d import Axes3D diff --git a/src/RSIPI/kuka_visualiser.py b/src/RSIPI/kuka_visualiser.py index 0c3fb84..d02ecf5 100644 --- a/src/RSIPI/kuka_visualiser.py +++ b/src/RSIPI/kuka_visualiser.py @@ -150,7 +150,7 @@ if __name__ == "__main__": args = parser.parse_args() if args.limits: - from src.RSIPI.rsi_limit_parser import parse_rsi_limits + from .rsi_limit_parser import parse_rsi_limits limits = parse_rsi_limits(args.limits) visualiser = KukaRSIVisualiser(args.csv_file, safety_limits=limits) else: diff --git a/src/RSIPI/network_handler.py b/src/RSIPI/network_handler.py index f42e2b9..c9d6cb3 100644 --- a/src/RSIPI/network_handler.py +++ b/src/RSIPI/network_handler.py @@ -2,47 +2,116 @@ import multiprocessing import socket import logging import xml.etree.ElementTree as ET +import os +import datetime +from queue import Empty from .xml_handler import XMLGenerator from .safety_manager import SafetyManager + +class CSVLogger(multiprocessing.Process): + """Separate process for writing CSV logs without blocking the network loop.""" + + def __init__(self, log_queue, stop_event, filename): + super().__init__() + self.log_queue = log_queue + self.stop_event = stop_event + self.filename = filename + self.daemon = True + + def run(self): + """Write log entries from queue to CSV file.""" + # Ensure logs directory exists + log_dir = os.path.dirname(self.filename) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + + header_written = False + + try: + with open(self.filename, 'w', newline='') as f: + while not self.stop_event.is_set(): + try: + entry = self.log_queue.get(timeout=0.5) + if entry is None: # Poison pill + break + + # Write header on first entry + if not header_written: + headers = ['Timestamp'] + list(entry.keys()) + f.write(','.join(headers) + '\n') + header_written = True + + # Write data row + timestamp = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S.%f")[:-3] + values = [timestamp] + [str(v) for v in entry.values()] + f.write(','.join(values) + '\n') + f.flush() + + except Empty: + continue + except Exception as e: + logging.error(f"CSV logging error: {e}") + + except Exception as e: + logging.error(f"Failed to open log file {self.filename}: {e}") + + class NetworkProcess(multiprocessing.Process): """Handles UDP communication and optional CSV logging in a separate process.""" - def __init__(self, ip, port, send_variables, receive_variables, stop_event, config_parser, start_event): + def __init__(self, ip, port, send_variables, receive_variables, stop_event, config_parser, start_event, command_queue): super().__init__() self.send_variables = send_variables self.receive_variables = receive_variables self.stop_event = stop_event - self.start_event = start_event # ✅ NEW + self.start_event = start_event self.config_parser = config_parser - self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.command_queue = command_queue self.safety_manager = SafetyManager(config_parser.safety_limits) self.client_address = (ip, port) self.logging_active = multiprocessing.Value('b', False) - self.log_filename = multiprocessing.Array('c', 256) - self.csv_process = None self.controller_ip_and_port = None + self.udp_socket = None + + # Logging infrastructure (created when logging starts) + self.log_queue = None + self.log_stop_event = None + self.csv_logger = None def run(self): """Start the network loop.""" - self.start_event.wait() # ✅ Wait until RSIClient sends start signal + # Wait for start signal, but check stop_event periodically to allow clean shutdown + while not self.start_event.wait(timeout=0.5): + if self.stop_event.is_set(): + logging.info("Network process stopped before starting") + return try: - if not self.is_valid_ip(self.client_address[0]): - logging.warning(f"Invalid IP address '{self.client_address[0]}'. Falling back to '0.0.0.0'.") - self.client_address = ('0.0.0.0', self.client_address[1]) + self._setup_socket() + self._run_loop() + finally: + self._cleanup() - self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.udp_socket.bind(self.client_address) - logging.info(f"✅ Network process bound on {self.client_address}") + def _setup_socket(self): + """Create and bind the UDP socket.""" + if not self.is_valid_ip(self.client_address[0]): + logging.warning(f"Invalid IP address '{self.client_address[0]}'. Falling back to '0.0.0.0'.") + self.client_address = ('0.0.0.0', self.client_address[1]) - except OSError as e: - logging.error(f"❌ Failed to bind to {self.client_address}: {e}") - raise + self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.udp_socket.bind(self.client_address) + logging.info(f"Network process bound on {self.client_address}") + def _run_loop(self): + """Main communication loop.""" while not self.stop_event.is_set(): + # Check for commands (non-blocking) + self._process_commands() + try: self.udp_socket.settimeout(5) data_received, self.controller_ip_and_port = self.udp_socket.recvfrom(1024) @@ -51,13 +120,115 @@ class NetworkProcess(multiprocessing.Process): send_xml = XMLGenerator.generate_send_xml(self.send_variables, self.config_parser.network_settings) self.udp_socket.sendto(send_xml.encode(), self.controller_ip_and_port) - if self.logging_active.value: - self.log_to_csv() + if self.logging_active.value and self.log_queue: + self._queue_log_entry() except socket.timeout: - logging.error("[WARNING] No message received within timeout period.") + logging.warning("No message received within timeout period") except Exception as e: - logging.error(f"[ERROR] Network process error: {e}") + logging.error(f"Network process error: {e}") + + def _process_commands(self): + """Process any pending commands from the parent process.""" + try: + while True: + cmd = self.command_queue.get_nowait() + if cmd is None: + continue + + action = cmd.get('action') + if action == 'start_logging': + self._start_logging(cmd.get('filename')) + elif action == 'stop_logging': + self._stop_logging() + + except Empty: + pass + except Exception as e: + logging.error(f"Error processing command: {e}") + + def _queue_log_entry(self): + """Queue current state for CSV logging (non-blocking).""" + try: + entry = {} + # Flatten send variables + for key, value in dict(self.send_variables).items(): + if isinstance(value, dict): + for subkey, subval in value.items(): + entry[f"Send.{key}.{subkey}"] = subval + else: + entry[f"Send.{key}"] = value + + # Flatten receive variables + for key, value in dict(self.receive_variables).items(): + if isinstance(value, dict): + for subkey, subval in value.items(): + entry[f"Receive.{key}.{subkey}"] = subval + else: + entry[f"Receive.{key}"] = value + + # Non-blocking put - drop entry if queue is full + try: + self.log_queue.put_nowait(entry) + except: + pass # Queue full, skip this entry rather than block + + except Exception as e: + logging.debug(f"Failed to queue log entry: {e}") + + def _start_logging(self, filename): + """Start CSV logging to the specified file.""" + if self.logging_active.value: + logging.warning("Logging already active") + return + + self.log_queue = multiprocessing.Queue(maxsize=1000) + self.log_stop_event = multiprocessing.Event() + + self.csv_logger = CSVLogger(self.log_queue, self.log_stop_event, filename) + self.csv_logger.start() + + self.logging_active.value = True + logging.info(f"CSV logging started: {filename}") + + def _stop_logging(self): + """Stop CSV logging.""" + if not self.logging_active.value: + return + + self.logging_active.value = False + + if self.log_queue: + try: + self.log_queue.put_nowait(None) # Poison pill + except: + pass + + if self.log_stop_event: + self.log_stop_event.set() + + if self.csv_logger and self.csv_logger.is_alive(): + self.csv_logger.join(timeout=2) + if self.csv_logger.is_alive(): + self.csv_logger.terminate() + + self.csv_logger = None + self.log_queue = None + self.log_stop_event = None + logging.info("CSV logging stopped") + + def _cleanup(self): + """Clean up resources.""" + # Stop logging first + self._stop_logging() + + if self.udp_socket: + try: + self.udp_socket.close() + logging.info("Network socket closed") + except Exception as e: + logging.error(f"Error closing socket: {e}") + self.udp_socket = None @staticmethod def is_valid_ip(ip): @@ -83,4 +254,4 @@ class NetworkProcess(multiprocessing.Process): self.receive_variables["IPOC"] = received_ipoc self.send_variables["IPOC"] = received_ipoc + 4 except Exception as e: - logging.error(f"[ERROR] Error parsing received message: {e}") + logging.error(f"Error parsing received message: {e}") diff --git a/src/RSIPI/rsi_api.py b/src/RSIPI/rsi_api.py index 6893ca5..6fee0d7 100644 --- a/src/RSIPI/rsi_api.py +++ b/src/RSIPI/rsi_api.py @@ -9,9 +9,9 @@ from .inject_rsi_to_krl import inject_rsi_to_krl import threading from .trajectory_planner import generate_trajectory, execute_trajectory import datetime -from src.RSIPI.static_plotter import StaticPlotter # Make sure this file exists as described +from .static_plotter import StaticPlotter import os -from src.RSIPI.live_plotter import LivePlotter +from .live_plotter import LivePlotter from threading import Thread import asyncio @@ -49,6 +49,7 @@ class RSIAPI: self.client.stop() return "RSI stopped." + @staticmethod def generate_report(filename, format_type): """ Generate a statistical report from a CSV log file. diff --git a/src/RSIPI/rsi_client.py b/src/RSIPI/rsi_client.py index c4eafaf..c7664da 100644 --- a/src/RSIPI/rsi_client.py +++ b/src/RSIPI/rsi_client.py @@ -1,17 +1,43 @@ import logging import multiprocessing import time +from enum import Enum, auto +from threading import Lock from .config_parser import ConfigParser from .network_handler import NetworkProcess from .safety_manager import SafetyManager import threading + +class ClientState(Enum): + """Connection states for RSIClient.""" + INITIALIZED = auto() # After __init__, network process spawned but not started + STARTING = auto() # Start signal sent, waiting for network to be ready + RUNNING = auto() # Actively communicating with robot + STOPPING = auto() # Shutdown in progress + STOPPED = auto() # Fully stopped, cannot be restarted (use reconnect) + ERROR = auto() # Error state + + class RSIClient: """Main RSI API class that integrates network, config handling, and message processing.""" + # Valid state transitions + _VALID_TRANSITIONS = { + ClientState.INITIALIZED: {ClientState.STARTING, ClientState.STOPPING}, + ClientState.STARTING: {ClientState.RUNNING, ClientState.STOPPING, ClientState.ERROR}, + ClientState.RUNNING: {ClientState.STOPPING, ClientState.ERROR}, + ClientState.STOPPING: {ClientState.STOPPED, ClientState.ERROR}, + ClientState.STOPPED: {ClientState.INITIALIZED}, # Via reconnect + ClientState.ERROR: {ClientState.STOPPING, ClientState.INITIALIZED}, # Via reconnect + } + def __init__(self, config_file, rsi_limits_file=None): logging.info(f"Loading RSI configuration from {config_file}...") + self._state = ClientState.INITIALIZED + self._state_lock = Lock() + self.config_parser = ConfigParser(config_file, rsi_limits_file) network_settings = self.config_parser.get_network_settings() @@ -19,11 +45,15 @@ class RSIClient: self.send_variables = self.manager.dict(self.config_parser.send_variables) self.receive_variables = self.manager.dict(self.config_parser.receive_variables) self.stop_event = multiprocessing.Event() - self.start_event = multiprocessing.Event() # ✅ NEW + self.start_event = multiprocessing.Event() + self.command_queue = multiprocessing.Queue() self.safety_manager = SafetyManager(self.config_parser.safety_limits) - # ✅ Create NetworkProcess but don't start communication yet + # Shared logging state (readable from parent process) + self._logging_active = multiprocessing.Value('b', False) + + # Create NetworkProcess but don't start communication yet self.network_process = NetworkProcess( network_settings["ip"], network_settings["port"], @@ -31,17 +61,56 @@ class RSIClient: self.receive_variables, self.stop_event, self.config_parser, - self.start_event + self.start_event, + self.command_queue ) + # Share the logging_active flag + self.network_process.logging_active = self._logging_active self.network_process.start() + self.logger = None + self.running = False + self.thread = None + + @property + def state(self) -> ClientState: + """Get current client state (thread-safe).""" + with self._state_lock: + return self._state + + def _transition_to(self, new_state: ClientState) -> bool: + """ + Attempt to transition to a new state. + + Returns: + True if transition was valid and completed, False otherwise. + """ + with self._state_lock: + if new_state in self._VALID_TRANSITIONS.get(self._state, set()): + old_state = self._state + self._state = new_state + logging.debug(f"State transition: {old_state.name} -> {new_state.name}") + return True + else: + logging.warning( + f"Invalid state transition attempted: {self._state.name} -> {new_state.name}" + ) + return False def start(self): """Send start signal to NetworkProcess and run control loop.""" + if not self._transition_to(ClientState.STARTING): + logging.error("Cannot start: invalid state") + return + logging.info("RSIClient sending start signal to NetworkProcess...") self.start_event.set() - self.running = True + if not self._transition_to(ClientState.RUNNING): + logging.error("Failed to transition to RUNNING state") + return + + self.running = True logging.info("RSI Client Started") try: @@ -51,39 +120,58 @@ class RSIClient: self.stop() except Exception as e: logging.error(f"RSI Client encountered an error: {e}") + self._transition_to(ClientState.ERROR) def stop(self): """Stop the network process and the client thread safely.""" - logging.info("🛑 Stopping RSI Client...") + if self.state in (ClientState.STOPPED, ClientState.STOPPING): + logging.debug("Already stopped or stopping") + return + + if not self._transition_to(ClientState.STOPPING): + logging.warning("Could not transition to STOPPING state") + # Continue anyway to ensure cleanup + + logging.info("Stopping RSI Client...") self.running = False - self.stop_event.set() # ✅ Tell network process to exit nicely + self.stop_event.set() if self.network_process and self.network_process.is_alive(): - self.network_process.join(timeout=3) # ✅ Give it time to shutdown + self.network_process.join(timeout=3) if self.network_process.is_alive(): - logging.warning("⚠️ Forcing network process termination...") + logging.warning("Forcing network process termination...") self.network_process.terminate() self.network_process.join() - if hasattr(self, "thread") and self.thread and self.thread.is_alive(): - self.thread.join() + if self.thread and self.thread.is_alive(): + self.thread.join(timeout=2) self.thread = None - logging.info("✅ RSI Client Stopped") + self._transition_to(ClientState.STOPPED) + logging.info("RSI Client Stopped") def reconnect(self): """Reconnects the network process safely.""" logging.info("Reconnecting RSI Client network...") + # Stop if currently running + if self.state in (ClientState.RUNNING, ClientState.STARTING): + self.stop() + if self.network_process and self.network_process.is_alive(): self.stop_event.set() self.network_process.terminate() self.network_process.join() - # Fresh new events + # Reset to initialized state + with self._state_lock: + self._state = ClientState.INITIALIZED + + # Fresh new events and queue self.stop_event = multiprocessing.Event() self.start_event = multiprocessing.Event() + self.command_queue = multiprocessing.Queue() # Create new network process network_settings = self.config_parser.get_network_settings() @@ -94,10 +182,32 @@ class RSIClient: self.receive_variables, self.stop_event, self.config_parser, - self.start_event + self.start_event, + self.command_queue ) + self.network_process.logging_active = self._logging_active self.network_process.start() # Fresh control thread self.thread = threading.Thread(target=self.start, daemon=True) self.thread.start() + + def is_running(self) -> bool: + """Check if client is in running state.""" + return self.state == ClientState.RUNNING + + def is_stopped(self) -> bool: + """Check if client is fully stopped.""" + return self.state == ClientState.STOPPED + + def start_logging(self, filename): + """Start CSV logging to the specified file.""" + self.command_queue.put({'action': 'start_logging', 'filename': filename}) + + def stop_logging(self): + """Stop CSV logging.""" + self.command_queue.put({'action': 'stop_logging'}) + + def is_logging_active(self) -> bool: + """Check if CSV logging is currently active.""" + return self._logging_active.value diff --git a/src/RSIPI/rsi_config.py b/src/RSIPI/rsi_config.py index 7de83aa..7031090 100644 --- a/src/RSIPI/rsi_config.py +++ b/src/RSIPI/rsi_config.py @@ -1,6 +1,6 @@ import xml.etree.ElementTree as ET import logging -from src.RSIPI.rsi_limit_parser import parse_rsi_limits +from .rsi_limit_parser import parse_rsi_limits # ✅ Configure Logging (toggleable) LOGGING_ENABLED = False # Change too False to silence logging output diff --git a/src/RSIPI/rsi_echo_server.py b/src/RSIPI/rsi_echo_server.py index 6d07478..6340d55 100644 --- a/src/RSIPI/rsi_echo_server.py +++ b/src/RSIPI/rsi_echo_server.py @@ -3,7 +3,7 @@ import time import xml.etree.ElementTree as ET import logging import threading -from src.RSIPI.rsi_config import RSIConfig +from .rsi_config import RSIConfig # ✅ Toggle logging for debugging purposes LOGGING_ENABLED = True diff --git a/src/RSIPI/trajectory_planner.py b/src/RSIPI/trajectory_planner.py index a735033..a225abe 100644 --- a/src/RSIPI/trajectory_planner.py +++ b/src/RSIPI/trajectory_planner.py @@ -1,4 +1,4 @@ -from RSIPI.safety_manager import SafetyManager +from .safety_manager import SafetyManager import time def generate_trajectory(start, end, steps=100, space="cartesian", mode="absolute", include_resets=False): diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..693e8ec --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# RSIPI Tests diff --git a/tests/__pycache__/__init__.cpython-313.pyc b/tests/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..233849b Binary files /dev/null and b/tests/__pycache__/__init__.cpython-313.pyc differ diff --git a/tests/__pycache__/test_safety_manager.cpython-313-pytest-8.4.1.pyc b/tests/__pycache__/test_safety_manager.cpython-313-pytest-8.4.1.pyc new file mode 100644 index 0000000..c9b8378 Binary files /dev/null and b/tests/__pycache__/test_safety_manager.cpython-313-pytest-8.4.1.pyc differ diff --git a/tests/__pycache__/test_trajectory_planner.cpython-313-pytest-8.4.1.pyc b/tests/__pycache__/test_trajectory_planner.cpython-313-pytest-8.4.1.pyc new file mode 100644 index 0000000..060de21 Binary files /dev/null and b/tests/__pycache__/test_trajectory_planner.cpython-313-pytest-8.4.1.pyc differ diff --git a/tests/__pycache__/test_xml_handler.cpython-313-pytest-8.4.1.pyc b/tests/__pycache__/test_xml_handler.cpython-313-pytest-8.4.1.pyc new file mode 100644 index 0000000..f7b8286 Binary files /dev/null and b/tests/__pycache__/test_xml_handler.cpython-313-pytest-8.4.1.pyc differ diff --git a/tests/test_safety_manager.py b/tests/test_safety_manager.py new file mode 100644 index 0000000..e6cc04e --- /dev/null +++ b/tests/test_safety_manager.py @@ -0,0 +1,168 @@ +"""Tests for SafetyManager.""" +import pytest +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from RSIPI.safety_manager import SafetyManager + + +class TestValidate: + """Tests for SafetyManager.validate()""" + + def test_validate_within_limits(self): + """Test that values within limits pass through unchanged.""" + limits = {"RKorr.X": (-5.0, 5.0)} + sm = SafetyManager(limits) + + result = sm.validate("RKorr.X", 3.0) + assert result == 3.0 + + def test_validate_at_boundary(self): + """Test that values at exact boundaries pass.""" + limits = {"RKorr.X": (-5.0, 5.0)} + sm = SafetyManager(limits) + + assert sm.validate("RKorr.X", -5.0) == -5.0 + assert sm.validate("RKorr.X", 5.0) == 5.0 + + def test_validate_exceeds_max(self): + """Test that values exceeding max raise ValueError.""" + limits = {"RKorr.X": (-5.0, 5.0)} + sm = SafetyManager(limits) + + with pytest.raises(ValueError, match="out of bounds"): + sm.validate("RKorr.X", 5.1) + + def test_validate_below_min(self): + """Test that values below min raise ValueError.""" + limits = {"RKorr.X": (-5.0, 5.0)} + sm = SafetyManager(limits) + + with pytest.raises(ValueError, match="out of bounds"): + sm.validate("RKorr.X", -5.1) + + def test_validate_unlisted_path(self): + """Test that paths without defined limits pass through.""" + limits = {"RKorr.X": (-5.0, 5.0)} + sm = SafetyManager(limits) + + # RKorr.Y has no limit defined - should pass + result = sm.validate("RKorr.Y", 1000.0) + assert result == 1000.0 + + def test_validate_with_override(self): + """Test that override bypasses all limit checks.""" + limits = {"RKorr.X": (-5.0, 5.0)} + sm = SafetyManager(limits) + sm.override_safety(True) + + # Should pass despite being out of bounds + result = sm.validate("RKorr.X", 100.0) + assert result == 100.0 + + +class TestEmergencyStop: + """Tests for emergency stop functionality.""" + + def test_estop_blocks_validation(self): + """Test that e-stop blocks all validation.""" + sm = SafetyManager() + sm.emergency_stop() + + with pytest.raises(RuntimeError, match="E-STOP"): + sm.validate("RKorr.X", 0.0) + + def test_estop_reset(self): + """Test that e-stop can be reset.""" + sm = SafetyManager() + sm.emergency_stop() + + assert sm.is_stopped() is True + + sm.reset_stop() + + assert sm.is_stopped() is False + # Should work again + result = sm.validate("RKorr.X", 1.0) + assert result == 1.0 + + +class TestSetLimit: + """Tests for runtime limit modification.""" + + def test_set_new_limit(self): + """Test adding a new limit at runtime.""" + sm = SafetyManager() + + sm.set_limit("RKorr.Y", -10.0, 10.0) + + assert sm.validate("RKorr.Y", 5.0) == 5.0 + with pytest.raises(ValueError): + sm.validate("RKorr.Y", 15.0) + + def test_override_existing_limit(self): + """Test overriding an existing limit.""" + limits = {"RKorr.X": (-5.0, 5.0)} + sm = SafetyManager(limits) + + # Original limit blocks this + with pytest.raises(ValueError): + sm.validate("RKorr.X", 8.0) + + # Override limit + sm.set_limit("RKorr.X", -10.0, 10.0) + + # Now it should pass + assert sm.validate("RKorr.X", 8.0) == 8.0 + + def test_get_limits(self): + """Test retrieving all limits.""" + limits = {"RKorr.X": (-5.0, 5.0), "AKorr.A1": (-6.0, 6.0)} + sm = SafetyManager(limits) + + retrieved = sm.get_limits() + assert retrieved == limits + # Should be a copy + retrieved["new"] = (0, 1) + assert "new" not in sm.get_limits() + + +class TestStaticChecks: + """Tests for static limit checking methods.""" + + def test_cartesian_limits_valid(self): + """Test valid Cartesian pose passes check.""" + pose = {"X": 500, "Y": -200, "Z": 1000, "A": 0, "B": 0, "C": 0} + assert SafetyManager.check_cartesian_limits(pose) is True + + def test_cartesian_limits_z_too_low(self): + """Test Z below zero fails.""" + pose = {"X": 0, "Y": 0, "Z": -100} + assert SafetyManager.check_cartesian_limits(pose) is False + + def test_cartesian_limits_x_too_high(self): + """Test X exceeding max fails.""" + pose = {"X": 2000, "Y": 0, "Z": 500} + assert SafetyManager.check_cartesian_limits(pose) is False + + def test_cartesian_limits_partial_pose(self): + """Test partial pose (missing keys) passes if present keys are valid.""" + pose = {"X": 100, "Z": 500} # Missing Y, A, B, C + assert SafetyManager.check_cartesian_limits(pose) is True + + def test_joint_limits_valid(self): + """Test valid joint pose passes check.""" + pose = {"A1": 0, "A2": -45, "A3": 90, "A4": 0, "A5": 0, "A6": 180} + assert SafetyManager.check_joint_limits(pose) is True + + def test_joint_limits_a1_exceeded(self): + """Test A1 exceeding limit fails.""" + pose = {"A1": 200} # Limit is -185 to 185 + assert SafetyManager.check_joint_limits(pose) is False + + def test_joint_limits_a5_exceeded(self): + """Test A5 exceeding its tighter limit fails.""" + pose = {"A5": 150} # Limit is -130 to 130 + assert SafetyManager.check_joint_limits(pose) is False diff --git a/tests/test_trajectory_planner.py b/tests/test_trajectory_planner.py new file mode 100644 index 0000000..7f51125 --- /dev/null +++ b/tests/test_trajectory_planner.py @@ -0,0 +1,136 @@ +"""Tests for trajectory planner.""" +import pytest +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from RSIPI.trajectory_planner import generate_trajectory + + +class TestGenerateTrajectory: + """Tests for generate_trajectory()""" + + def test_basic_linear_interpolation(self): + """Test basic linear interpolation between two points.""" + start = {"X": 0, "Y": 0, "Z": 0} + end = {"X": 100, "Y": 200, "Z": 300} + + traj = generate_trajectory(start, end, steps=10, space="cartesian", mode="absolute") + + assert len(traj) == 10 + + # First point should be 1/10 of the way + assert traj[0]["X"] == pytest.approx(10.0) + assert traj[0]["Y"] == pytest.approx(20.0) + assert traj[0]["Z"] == pytest.approx(30.0) + + # Last point should be at end + assert traj[-1]["X"] == pytest.approx(100.0) + assert traj[-1]["Y"] == pytest.approx(200.0) + assert traj[-1]["Z"] == pytest.approx(300.0) + + def test_relative_mode(self): + """Test relative mode generates incremental deltas.""" + start = {"X": 0, "Y": 0} + end = {"X": 100, "Y": 50} + + traj = generate_trajectory(start, end, steps=10, space="cartesian", mode="relative") + + # Each step should be the same delta + for point in traj: + assert point["X"] == pytest.approx(10.0) # 100/10 + assert point["Y"] == pytest.approx(5.0) # 50/10 + + def test_relative_mode_with_resets(self): + """Test relative mode with reset steps.""" + start = {"X": 0} + end = {"X": 30} + + traj = generate_trajectory(start, end, steps=3, space="cartesian", mode="relative", include_resets=True) + + # Should have 6 points: delta, reset, delta, reset, delta, reset + assert len(traj) == 6 + + # Odd indices should be zero (reset points) + assert traj[1]["X"] == 0.0 + assert traj[3]["X"] == 0.0 + assert traj[5]["X"] == 0.0 + + # Even indices should be deltas + assert traj[0]["X"] == pytest.approx(10.0) + assert traj[2]["X"] == pytest.approx(10.0) + assert traj[4]["X"] == pytest.approx(10.0) + + def test_joint_space(self): + """Test trajectory generation in joint space.""" + start = {"A1": 0, "A2": 0} + end = {"A1": 90, "A2": -45} + + traj = generate_trajectory(start, end, steps=9, space="joint", mode="absolute") + + assert len(traj) == 9 + + # Final point should match end + assert traj[-1]["A1"] == pytest.approx(90.0) + assert traj[-1]["A2"] == pytest.approx(-45.0) + + def test_negative_movement(self): + """Test trajectory with negative direction.""" + start = {"X": 100, "Y": 100} + end = {"X": 0, "Y": -100} + + traj = generate_trajectory(start, end, steps=10, space="cartesian", mode="absolute") + + # Should decrease linearly + assert traj[0]["X"] == pytest.approx(90.0) + assert traj[0]["Y"] == pytest.approx(80.0) + assert traj[-1]["X"] == pytest.approx(0.0) + assert traj[-1]["Y"] == pytest.approx(-100.0) + + def test_single_step(self): + """Test trajectory with single step goes directly to end.""" + start = {"X": 0} + end = {"X": 100} + + traj = generate_trajectory(start, end, steps=1, space="cartesian", mode="absolute") + + assert len(traj) == 1 + assert traj[0]["X"] == pytest.approx(100.0) + + def test_invalid_mode_raises(self): + """Test that invalid mode raises ValueError.""" + start = {"X": 0} + end = {"X": 100} + + with pytest.raises(ValueError, match="mode must be"): + generate_trajectory(start, end, mode="invalid") + + def test_invalid_space_raises(self): + """Test that invalid space raises ValueError.""" + start = {"X": 0} + end = {"X": 100} + + with pytest.raises(ValueError, match="space must be"): + generate_trajectory(start, end, space="invalid") + + def test_absolute_mode_ignores_include_resets(self): + """Test that absolute mode ignores include_resets flag.""" + start = {"X": 0} + end = {"X": 100} + + # Even with include_resets=True, absolute mode should not add resets + traj = generate_trajectory(start, end, steps=5, space="cartesian", mode="absolute", include_resets=True) + + assert len(traj) == 5 # No extra reset points + + def test_preserves_all_axes(self): + """Test that all axes from start are preserved in trajectory.""" + start = {"X": 0, "Y": 0, "Z": 0, "A": 0, "B": 0, "C": 0} + end = {"X": 10, "Y": 20, "Z": 30, "A": 5, "B": 10, "C": 15} + + traj = generate_trajectory(start, end, steps=2, space="cartesian", mode="absolute") + + # Each point should have all axes + for point in traj: + assert set(point.keys()) == {"X", "Y", "Z", "A", "B", "C"} diff --git a/tests/test_xml_handler.py b/tests/test_xml_handler.py new file mode 100644 index 0000000..714fa6c --- /dev/null +++ b/tests/test_xml_handler.py @@ -0,0 +1,112 @@ +"""Tests for XMLGenerator.""" +import pytest +import xml.etree.ElementTree as ET +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from RSIPI.xml_handler import XMLGenerator + + +class TestGenerateSendXML: + """Tests for XMLGenerator.generate_send_xml()""" + + def test_simple_values(self): + """Test XML generation with simple scalar values.""" + send_vars = { + "IPOC": 12345, + "DiL": 1 + } + network_settings = {"sentype": "ImFree"} + + xml_str = XMLGenerator.generate_send_xml(send_vars, network_settings) + + # Parse and verify + root = ET.fromstring(xml_str) + assert root.tag == "Sen" + assert root.get("Type") == "ImFree" + assert root.find("IPOC").text == "12345" + assert root.find("DiL").text == "1" + + def test_nested_dict_values(self): + """Test XML generation with nested dictionary values (attributes).""" + send_vars = { + "RKorr": {"X": 100.5, "Y": -50.25, "Z": 0.0} + } + network_settings = {"sentype": "TestType"} + + xml_str = XMLGenerator.generate_send_xml(send_vars, network_settings) + + root = ET.fromstring(xml_str) + rkorr = root.find("RKorr") + assert rkorr is not None + assert float(rkorr.get("X")) == 100.50 + assert float(rkorr.get("Y")) == -50.25 + assert float(rkorr.get("Z")) == 0.0 + + def test_free_field_skipped(self): + """Test that FREE field is skipped in XML output.""" + send_vars = { + "IPOC": 1, + "FREE": 999 + } + network_settings = {"sentype": "ImFree"} + + xml_str = XMLGenerator.generate_send_xml(send_vars, network_settings) + + root = ET.fromstring(xml_str) + assert root.find("FREE") is None + assert root.find("IPOC") is not None + + def test_mixed_values(self): + """Test XML generation with mixed scalar and nested values.""" + send_vars = { + "IPOC": 5000, + "RKorr": {"X": 10.0, "Y": 20.0}, + "Digout": {"o1": 1, "o2": 0} + } + network_settings = {"sentype": "Mixed"} + + xml_str = XMLGenerator.generate_send_xml(send_vars, network_settings) + + root = ET.fromstring(xml_str) + assert root.find("IPOC").text == "5000" + assert float(root.find("RKorr").get("X")) == 10.0 + assert float(root.find("Digout").get("o1")) == 1.0 + + +class TestGenerateReceiveXML: + """Tests for XMLGenerator.generate_receive_xml()""" + + def test_simple_receive(self): + """Test receive XML generation with scalar values.""" + receive_vars = { + "IPOC": 12345, + "BMode": "Active" + } + + xml_str = XMLGenerator.generate_receive_xml(receive_vars) + + root = ET.fromstring(xml_str) + assert root.tag == "Rob" + assert root.get("Type") == "KUKA" + assert root.find("IPOC").text == "12345" + assert root.find("BMode").text == "Active" + + def test_nested_receive(self): + """Test receive XML generation with position data.""" + receive_vars = { + "RIst": {"X": 500.0, "Y": 100.0, "Z": 800.0}, + "ASPos": {"A1": 0.0, "A2": -45.0, "A3": 90.0} + } + + xml_str = XMLGenerator.generate_receive_xml(receive_vars) + + root = ET.fromstring(xml_str) + rist = root.find("RIst") + assert float(rist.get("X")) == 500.0 + assert float(rist.get("Z")) == 800.0 + + aspos = root.find("ASPos") + assert float(aspos.get("A2")) == -45.0