Refactor core architecture and add test coverage

- Fix socket lifecycle: create in child process, add cleanup with try/finally
- Add ClientState enum with validated state transitions to prevent invalid operations
- Decouple CSV logging from network loop using queue-based CSVLogger process
- Fix broken imports: change absolute (src.RSIPI.x) to relative (.x) across 7 files
- Add missing @staticmethod decorator to generate_report()
- Add command queue for inter-process communication (logging control)
- Add 34 unit tests for XMLGenerator, SafetyManager, and trajectory_planner
- Add pytest configuration to pyproject.toml
- Add CLAUDE.md with architecture documentation
This commit is contained in:
Adam 2026-01-16 20:09:56 +00:00
parent 219ffc531c
commit 7bfe5cccf1
19 changed files with 829 additions and 42 deletions

79
CLAUDE.md Normal file
View File

@ -0,0 +1,79 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## What This Project Does
RSIPI enables real-time control of KUKA industrial robots from Python via the RSI (Robot Sensor Interface) protocol. The robot sends its position ~250 times/second over UDP, and this library lets you send back position corrections to control the robot externally.
## Build & Development Commands
```bash
# Install dependencies
pip install -e .
# Or install from requirements (if present)
pip install pandas>=2.0 numpy>=1.22 matplotlib>=3.5 lxml>=4.9 scipy>=1.8
# Run the CLI
python -m RSIPI.rsi_cli --config RSI_EthernetConfig.xml
# Run the echo server (for offline testing without a real robot)
python -m RSIPI.rsi_echo_server
```
**No test suite exists** - testing is done via the echo server simulation and example scripts in `examples/`.
## Architecture
### Core Communication Flow
```
KUKA Robot Controller <--UDP/XML--> NetworkProcess <--multiprocessing.Manager--> RSIClient <-- RSIAPI/CLI
```
1. **NetworkProcess** (`network_handler.py`) - Runs in separate process via `multiprocessing.Process`. Binds to UDP socket, receives XML from robot, parses into `receive_variables`, sends XML from `send_variables` back to robot. Uses `start_event` to wait for explicit start signal.
2. **RSIClient** (`rsi_client.py`) - Orchestrates the system. Initializes ConfigParser, SafetyManager, and NetworkProcess. Uses `multiprocessing.Manager` dicts for thread-safe variable sharing between processes.
3. **RSIAPI** (`rsi_api.py`) - High-level API wrapping RSIClient. Runs RSIClient in a daemon thread. Provides trajectory planning, logging, plotting, and safety controls.
4. **RSICommandLineInterface** (`rsi_cli.py`) - Interactive CLI that wraps RSIAPI.
### Key Shared State
Variables are shared between processes using `multiprocessing.Manager().dict()`:
- `send_variables` - Values to send to robot (RKorr corrections, digital outputs, etc.)
- `receive_variables` - Values received from robot (RIst position, ASPos joints, IPOC timestamp)
### Configuration
`RSI_EthernetConfig.xml` defines:
- Network settings (IP, port) in `<CONFIG>` section
- Send variables in `<SEND><ELEMENTS>` - what the robot receives from us
- Receive variables in `<RECEIVE><ELEMENTS>` - what we receive from robot
Variable tags like `DEF_RIst` get the `DEF_` prefix stripped and are expanded using `internal_structure` in ConfigParser to full dicts (e.g., `RIst: {X, Y, Z, A, B, C}`).
### Safety Layer
**SafetyManager** (`safety_manager.py`) validates all outgoing values against configurable limits. Can load limits from `.rsi.xml` files. Supports emergency stop and safety override modes.
### Trajectory Execution
`TrajectoryPlanner` generates interpolated waypoints. `execute_trajectory()` in RSIAPI uses asyncio to send points at specified rate (default 12ms for Cartesian, 400ms for joints).
## Important Patterns
- **IPOC synchronization**: The robot sends an IPOC (timestamp) value. The response must include `IPOC + 4` to maintain sync. This is handled automatically in `NetworkProcess.process_received_data()`.
- **Lazy client initialization**: RSIAPI uses `_ensure_client()` pattern - RSIClient is created on first use, not at RSIAPI instantiation.
- **Non-blocking start**: `start_rsi()` runs the client loop in a daemon thread. The NetworkProcess waits on `start_event` before binding the socket.
## File Locations
- Source code: `src/RSIPI/`
- Example scripts: `examples/`
- Config template: `RSI_EthernetConfig.xml`
- Logs written to: `logs/` (created at runtime)

View File

@ -25,8 +25,17 @@ classifiers = [
"Operating System :: OS Independent",
]
[project.optional-dependencies]
dev = [
"pytest>=7.0",
]
[tool.setuptools]
package-dir = {"" = "src"}
[tool.setuptools.packages.find]
where = ["src"]
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["src"]

View File

@ -17,7 +17,7 @@ class ConfigParser:
config_file (str): Path to the RSI_EthernetConfig.xml file.
rsi_limits_file (str, optional): Path to .rsi.xml file containing safety limits.
"""
from src.RSIPI.rsi_limit_parser import parse_rsi_limits
from .rsi_limit_parser import parse_rsi_limits
self.config_file = config_file
self.rsi_limits_file = rsi_limits_file

View File

@ -2,7 +2,7 @@ import tkinter as tk
from tkinter import ttk, filedialog
import threading
import time
from src.RSIPI.rsi_echo_server import EchoServer
from .rsi_echo_server import EchoServer
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from mpl_toolkits.mplot3d import Axes3D

View File

@ -150,7 +150,7 @@ if __name__ == "__main__":
args = parser.parse_args()
if args.limits:
from src.RSIPI.rsi_limit_parser import parse_rsi_limits
from .rsi_limit_parser import parse_rsi_limits
limits = parse_rsi_limits(args.limits)
visualiser = KukaRSIVisualiser(args.csv_file, safety_limits=limits)
else:

View File

@ -2,47 +2,116 @@ import multiprocessing
import socket
import logging
import xml.etree.ElementTree as ET
import os
import datetime
from queue import Empty
from .xml_handler import XMLGenerator
from .safety_manager import SafetyManager
class CSVLogger(multiprocessing.Process):
"""Separate process for writing CSV logs without blocking the network loop."""
def __init__(self, log_queue, stop_event, filename):
super().__init__()
self.log_queue = log_queue
self.stop_event = stop_event
self.filename = filename
self.daemon = True
def run(self):
"""Write log entries from queue to CSV file."""
# Ensure logs directory exists
log_dir = os.path.dirname(self.filename)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
header_written = False
try:
with open(self.filename, 'w', newline='') as f:
while not self.stop_event.is_set():
try:
entry = self.log_queue.get(timeout=0.5)
if entry is None: # Poison pill
break
# Write header on first entry
if not header_written:
headers = ['Timestamp'] + list(entry.keys())
f.write(','.join(headers) + '\n')
header_written = True
# Write data row
timestamp = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S.%f")[:-3]
values = [timestamp] + [str(v) for v in entry.values()]
f.write(','.join(values) + '\n')
f.flush()
except Empty:
continue
except Exception as e:
logging.error(f"CSV logging error: {e}")
except Exception as e:
logging.error(f"Failed to open log file {self.filename}: {e}")
class NetworkProcess(multiprocessing.Process):
"""Handles UDP communication and optional CSV logging in a separate process."""
def __init__(self, ip, port, send_variables, receive_variables, stop_event, config_parser, start_event):
def __init__(self, ip, port, send_variables, receive_variables, stop_event, config_parser, start_event, command_queue):
super().__init__()
self.send_variables = send_variables
self.receive_variables = receive_variables
self.stop_event = stop_event
self.start_event = start_event # ✅ NEW
self.start_event = start_event
self.config_parser = config_parser
self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.command_queue = command_queue
self.safety_manager = SafetyManager(config_parser.safety_limits)
self.client_address = (ip, port)
self.logging_active = multiprocessing.Value('b', False)
self.log_filename = multiprocessing.Array('c', 256)
self.csv_process = None
self.controller_ip_and_port = None
self.udp_socket = None
# Logging infrastructure (created when logging starts)
self.log_queue = None
self.log_stop_event = None
self.csv_logger = None
def run(self):
"""Start the network loop."""
self.start_event.wait() # ✅ Wait until RSIClient sends start signal
# Wait for start signal, but check stop_event periodically to allow clean shutdown
while not self.start_event.wait(timeout=0.5):
if self.stop_event.is_set():
logging.info("Network process stopped before starting")
return
try:
self._setup_socket()
self._run_loop()
finally:
self._cleanup()
def _setup_socket(self):
"""Create and bind the UDP socket."""
if not self.is_valid_ip(self.client_address[0]):
logging.warning(f"Invalid IP address '{self.client_address[0]}'. Falling back to '0.0.0.0'.")
self.client_address = ('0.0.0.0', self.client_address[1])
self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.udp_socket.bind(self.client_address)
logging.info(f"✅ Network process bound on {self.client_address}")
except OSError as e:
logging.error(f"❌ Failed to bind to {self.client_address}: {e}")
raise
logging.info(f"Network process bound on {self.client_address}")
def _run_loop(self):
"""Main communication loop."""
while not self.stop_event.is_set():
# Check for commands (non-blocking)
self._process_commands()
try:
self.udp_socket.settimeout(5)
data_received, self.controller_ip_and_port = self.udp_socket.recvfrom(1024)
@ -51,13 +120,115 @@ class NetworkProcess(multiprocessing.Process):
send_xml = XMLGenerator.generate_send_xml(self.send_variables, self.config_parser.network_settings)
self.udp_socket.sendto(send_xml.encode(), self.controller_ip_and_port)
if self.logging_active.value:
self.log_to_csv()
if self.logging_active.value and self.log_queue:
self._queue_log_entry()
except socket.timeout:
logging.error("[WARNING] No message received within timeout period.")
logging.warning("No message received within timeout period")
except Exception as e:
logging.error(f"[ERROR] Network process error: {e}")
logging.error(f"Network process error: {e}")
def _process_commands(self):
"""Process any pending commands from the parent process."""
try:
while True:
cmd = self.command_queue.get_nowait()
if cmd is None:
continue
action = cmd.get('action')
if action == 'start_logging':
self._start_logging(cmd.get('filename'))
elif action == 'stop_logging':
self._stop_logging()
except Empty:
pass
except Exception as e:
logging.error(f"Error processing command: {e}")
def _queue_log_entry(self):
"""Queue current state for CSV logging (non-blocking)."""
try:
entry = {}
# Flatten send variables
for key, value in dict(self.send_variables).items():
if isinstance(value, dict):
for subkey, subval in value.items():
entry[f"Send.{key}.{subkey}"] = subval
else:
entry[f"Send.{key}"] = value
# Flatten receive variables
for key, value in dict(self.receive_variables).items():
if isinstance(value, dict):
for subkey, subval in value.items():
entry[f"Receive.{key}.{subkey}"] = subval
else:
entry[f"Receive.{key}"] = value
# Non-blocking put - drop entry if queue is full
try:
self.log_queue.put_nowait(entry)
except:
pass # Queue full, skip this entry rather than block
except Exception as e:
logging.debug(f"Failed to queue log entry: {e}")
def _start_logging(self, filename):
"""Start CSV logging to the specified file."""
if self.logging_active.value:
logging.warning("Logging already active")
return
self.log_queue = multiprocessing.Queue(maxsize=1000)
self.log_stop_event = multiprocessing.Event()
self.csv_logger = CSVLogger(self.log_queue, self.log_stop_event, filename)
self.csv_logger.start()
self.logging_active.value = True
logging.info(f"CSV logging started: {filename}")
def _stop_logging(self):
"""Stop CSV logging."""
if not self.logging_active.value:
return
self.logging_active.value = False
if self.log_queue:
try:
self.log_queue.put_nowait(None) # Poison pill
except:
pass
if self.log_stop_event:
self.log_stop_event.set()
if self.csv_logger and self.csv_logger.is_alive():
self.csv_logger.join(timeout=2)
if self.csv_logger.is_alive():
self.csv_logger.terminate()
self.csv_logger = None
self.log_queue = None
self.log_stop_event = None
logging.info("CSV logging stopped")
def _cleanup(self):
"""Clean up resources."""
# Stop logging first
self._stop_logging()
if self.udp_socket:
try:
self.udp_socket.close()
logging.info("Network socket closed")
except Exception as e:
logging.error(f"Error closing socket: {e}")
self.udp_socket = None
@staticmethod
def is_valid_ip(ip):
@ -83,4 +254,4 @@ class NetworkProcess(multiprocessing.Process):
self.receive_variables["IPOC"] = received_ipoc
self.send_variables["IPOC"] = received_ipoc + 4
except Exception as e:
logging.error(f"[ERROR] Error parsing received message: {e}")
logging.error(f"Error parsing received message: {e}")

View File

@ -9,9 +9,9 @@ from .inject_rsi_to_krl import inject_rsi_to_krl
import threading
from .trajectory_planner import generate_trajectory, execute_trajectory
import datetime
from src.RSIPI.static_plotter import StaticPlotter # Make sure this file exists as described
from .static_plotter import StaticPlotter
import os
from src.RSIPI.live_plotter import LivePlotter
from .live_plotter import LivePlotter
from threading import Thread
import asyncio
@ -49,6 +49,7 @@ class RSIAPI:
self.client.stop()
return "RSI stopped."
@staticmethod
def generate_report(filename, format_type):
"""
Generate a statistical report from a CSV log file.

View File

@ -1,17 +1,43 @@
import logging
import multiprocessing
import time
from enum import Enum, auto
from threading import Lock
from .config_parser import ConfigParser
from .network_handler import NetworkProcess
from .safety_manager import SafetyManager
import threading
class ClientState(Enum):
"""Connection states for RSIClient."""
INITIALIZED = auto() # After __init__, network process spawned but not started
STARTING = auto() # Start signal sent, waiting for network to be ready
RUNNING = auto() # Actively communicating with robot
STOPPING = auto() # Shutdown in progress
STOPPED = auto() # Fully stopped, cannot be restarted (use reconnect)
ERROR = auto() # Error state
class RSIClient:
"""Main RSI API class that integrates network, config handling, and message processing."""
# Valid state transitions
_VALID_TRANSITIONS = {
ClientState.INITIALIZED: {ClientState.STARTING, ClientState.STOPPING},
ClientState.STARTING: {ClientState.RUNNING, ClientState.STOPPING, ClientState.ERROR},
ClientState.RUNNING: {ClientState.STOPPING, ClientState.ERROR},
ClientState.STOPPING: {ClientState.STOPPED, ClientState.ERROR},
ClientState.STOPPED: {ClientState.INITIALIZED}, # Via reconnect
ClientState.ERROR: {ClientState.STOPPING, ClientState.INITIALIZED}, # Via reconnect
}
def __init__(self, config_file, rsi_limits_file=None):
logging.info(f"Loading RSI configuration from {config_file}...")
self._state = ClientState.INITIALIZED
self._state_lock = Lock()
self.config_parser = ConfigParser(config_file, rsi_limits_file)
network_settings = self.config_parser.get_network_settings()
@ -19,11 +45,15 @@ class RSIClient:
self.send_variables = self.manager.dict(self.config_parser.send_variables)
self.receive_variables = self.manager.dict(self.config_parser.receive_variables)
self.stop_event = multiprocessing.Event()
self.start_event = multiprocessing.Event() # ✅ NEW
self.start_event = multiprocessing.Event()
self.command_queue = multiprocessing.Queue()
self.safety_manager = SafetyManager(self.config_parser.safety_limits)
# ✅ Create NetworkProcess but don't start communication yet
# Shared logging state (readable from parent process)
self._logging_active = multiprocessing.Value('b', False)
# Create NetworkProcess but don't start communication yet
self.network_process = NetworkProcess(
network_settings["ip"],
network_settings["port"],
@ -31,17 +61,56 @@ class RSIClient:
self.receive_variables,
self.stop_event,
self.config_parser,
self.start_event
self.start_event,
self.command_queue
)
# Share the logging_active flag
self.network_process.logging_active = self._logging_active
self.network_process.start()
self.logger = None
self.running = False
self.thread = None
@property
def state(self) -> ClientState:
"""Get current client state (thread-safe)."""
with self._state_lock:
return self._state
def _transition_to(self, new_state: ClientState) -> bool:
"""
Attempt to transition to a new state.
Returns:
True if transition was valid and completed, False otherwise.
"""
with self._state_lock:
if new_state in self._VALID_TRANSITIONS.get(self._state, set()):
old_state = self._state
self._state = new_state
logging.debug(f"State transition: {old_state.name} -> {new_state.name}")
return True
else:
logging.warning(
f"Invalid state transition attempted: {self._state.name} -> {new_state.name}"
)
return False
def start(self):
"""Send start signal to NetworkProcess and run control loop."""
if not self._transition_to(ClientState.STARTING):
logging.error("Cannot start: invalid state")
return
logging.info("RSIClient sending start signal to NetworkProcess...")
self.start_event.set()
self.running = True
if not self._transition_to(ClientState.RUNNING):
logging.error("Failed to transition to RUNNING state")
return
self.running = True
logging.info("RSI Client Started")
try:
@ -51,39 +120,58 @@ class RSIClient:
self.stop()
except Exception as e:
logging.error(f"RSI Client encountered an error: {e}")
self._transition_to(ClientState.ERROR)
def stop(self):
"""Stop the network process and the client thread safely."""
logging.info("🛑 Stopping RSI Client...")
if self.state in (ClientState.STOPPED, ClientState.STOPPING):
logging.debug("Already stopped or stopping")
return
if not self._transition_to(ClientState.STOPPING):
logging.warning("Could not transition to STOPPING state")
# Continue anyway to ensure cleanup
logging.info("Stopping RSI Client...")
self.running = False
self.stop_event.set() # ✅ Tell network process to exit nicely
self.stop_event.set()
if self.network_process and self.network_process.is_alive():
self.network_process.join(timeout=3) # ✅ Give it time to shutdown
self.network_process.join(timeout=3)
if self.network_process.is_alive():
logging.warning("⚠️ Forcing network process termination...")
logging.warning("Forcing network process termination...")
self.network_process.terminate()
self.network_process.join()
if hasattr(self, "thread") and self.thread and self.thread.is_alive():
self.thread.join()
if self.thread and self.thread.is_alive():
self.thread.join(timeout=2)
self.thread = None
logging.info("✅ RSI Client Stopped")
self._transition_to(ClientState.STOPPED)
logging.info("RSI Client Stopped")
def reconnect(self):
"""Reconnects the network process safely."""
logging.info("Reconnecting RSI Client network...")
# Stop if currently running
if self.state in (ClientState.RUNNING, ClientState.STARTING):
self.stop()
if self.network_process and self.network_process.is_alive():
self.stop_event.set()
self.network_process.terminate()
self.network_process.join()
# Fresh new events
# Reset to initialized state
with self._state_lock:
self._state = ClientState.INITIALIZED
# Fresh new events and queue
self.stop_event = multiprocessing.Event()
self.start_event = multiprocessing.Event()
self.command_queue = multiprocessing.Queue()
# Create new network process
network_settings = self.config_parser.get_network_settings()
@ -94,10 +182,32 @@ class RSIClient:
self.receive_variables,
self.stop_event,
self.config_parser,
self.start_event
self.start_event,
self.command_queue
)
self.network_process.logging_active = self._logging_active
self.network_process.start()
# Fresh control thread
self.thread = threading.Thread(target=self.start, daemon=True)
self.thread.start()
def is_running(self) -> bool:
"""Check if client is in running state."""
return self.state == ClientState.RUNNING
def is_stopped(self) -> bool:
"""Check if client is fully stopped."""
return self.state == ClientState.STOPPED
def start_logging(self, filename):
"""Start CSV logging to the specified file."""
self.command_queue.put({'action': 'start_logging', 'filename': filename})
def stop_logging(self):
"""Stop CSV logging."""
self.command_queue.put({'action': 'stop_logging'})
def is_logging_active(self) -> bool:
"""Check if CSV logging is currently active."""
return self._logging_active.value

View File

@ -1,6 +1,6 @@
import xml.etree.ElementTree as ET
import logging
from src.RSIPI.rsi_limit_parser import parse_rsi_limits
from .rsi_limit_parser import parse_rsi_limits
# ✅ Configure Logging (toggleable)
LOGGING_ENABLED = False # Change too False to silence logging output

View File

@ -3,7 +3,7 @@ import time
import xml.etree.ElementTree as ET
import logging
import threading
from src.RSIPI.rsi_config import RSIConfig
from .rsi_config import RSIConfig
# ✅ Toggle logging for debugging purposes
LOGGING_ENABLED = True

View File

@ -1,4 +1,4 @@
from RSIPI.safety_manager import SafetyManager
from .safety_manager import SafetyManager
import time
def generate_trajectory(start, end, steps=100, space="cartesian", mode="absolute", include_resets=False):

1
tests/__init__.py Normal file
View File

@ -0,0 +1 @@
# RSIPI Tests

Binary file not shown.

View File

@ -0,0 +1,168 @@
"""Tests for SafetyManager."""
import pytest
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from RSIPI.safety_manager import SafetyManager
class TestValidate:
"""Tests for SafetyManager.validate()"""
def test_validate_within_limits(self):
"""Test that values within limits pass through unchanged."""
limits = {"RKorr.X": (-5.0, 5.0)}
sm = SafetyManager(limits)
result = sm.validate("RKorr.X", 3.0)
assert result == 3.0
def test_validate_at_boundary(self):
"""Test that values at exact boundaries pass."""
limits = {"RKorr.X": (-5.0, 5.0)}
sm = SafetyManager(limits)
assert sm.validate("RKorr.X", -5.0) == -5.0
assert sm.validate("RKorr.X", 5.0) == 5.0
def test_validate_exceeds_max(self):
"""Test that values exceeding max raise ValueError."""
limits = {"RKorr.X": (-5.0, 5.0)}
sm = SafetyManager(limits)
with pytest.raises(ValueError, match="out of bounds"):
sm.validate("RKorr.X", 5.1)
def test_validate_below_min(self):
"""Test that values below min raise ValueError."""
limits = {"RKorr.X": (-5.0, 5.0)}
sm = SafetyManager(limits)
with pytest.raises(ValueError, match="out of bounds"):
sm.validate("RKorr.X", -5.1)
def test_validate_unlisted_path(self):
"""Test that paths without defined limits pass through."""
limits = {"RKorr.X": (-5.0, 5.0)}
sm = SafetyManager(limits)
# RKorr.Y has no limit defined - should pass
result = sm.validate("RKorr.Y", 1000.0)
assert result == 1000.0
def test_validate_with_override(self):
"""Test that override bypasses all limit checks."""
limits = {"RKorr.X": (-5.0, 5.0)}
sm = SafetyManager(limits)
sm.override_safety(True)
# Should pass despite being out of bounds
result = sm.validate("RKorr.X", 100.0)
assert result == 100.0
class TestEmergencyStop:
"""Tests for emergency stop functionality."""
def test_estop_blocks_validation(self):
"""Test that e-stop blocks all validation."""
sm = SafetyManager()
sm.emergency_stop()
with pytest.raises(RuntimeError, match="E-STOP"):
sm.validate("RKorr.X", 0.0)
def test_estop_reset(self):
"""Test that e-stop can be reset."""
sm = SafetyManager()
sm.emergency_stop()
assert sm.is_stopped() is True
sm.reset_stop()
assert sm.is_stopped() is False
# Should work again
result = sm.validate("RKorr.X", 1.0)
assert result == 1.0
class TestSetLimit:
"""Tests for runtime limit modification."""
def test_set_new_limit(self):
"""Test adding a new limit at runtime."""
sm = SafetyManager()
sm.set_limit("RKorr.Y", -10.0, 10.0)
assert sm.validate("RKorr.Y", 5.0) == 5.0
with pytest.raises(ValueError):
sm.validate("RKorr.Y", 15.0)
def test_override_existing_limit(self):
"""Test overriding an existing limit."""
limits = {"RKorr.X": (-5.0, 5.0)}
sm = SafetyManager(limits)
# Original limit blocks this
with pytest.raises(ValueError):
sm.validate("RKorr.X", 8.0)
# Override limit
sm.set_limit("RKorr.X", -10.0, 10.0)
# Now it should pass
assert sm.validate("RKorr.X", 8.0) == 8.0
def test_get_limits(self):
"""Test retrieving all limits."""
limits = {"RKorr.X": (-5.0, 5.0), "AKorr.A1": (-6.0, 6.0)}
sm = SafetyManager(limits)
retrieved = sm.get_limits()
assert retrieved == limits
# Should be a copy
retrieved["new"] = (0, 1)
assert "new" not in sm.get_limits()
class TestStaticChecks:
"""Tests for static limit checking methods."""
def test_cartesian_limits_valid(self):
"""Test valid Cartesian pose passes check."""
pose = {"X": 500, "Y": -200, "Z": 1000, "A": 0, "B": 0, "C": 0}
assert SafetyManager.check_cartesian_limits(pose) is True
def test_cartesian_limits_z_too_low(self):
"""Test Z below zero fails."""
pose = {"X": 0, "Y": 0, "Z": -100}
assert SafetyManager.check_cartesian_limits(pose) is False
def test_cartesian_limits_x_too_high(self):
"""Test X exceeding max fails."""
pose = {"X": 2000, "Y": 0, "Z": 500}
assert SafetyManager.check_cartesian_limits(pose) is False
def test_cartesian_limits_partial_pose(self):
"""Test partial pose (missing keys) passes if present keys are valid."""
pose = {"X": 100, "Z": 500} # Missing Y, A, B, C
assert SafetyManager.check_cartesian_limits(pose) is True
def test_joint_limits_valid(self):
"""Test valid joint pose passes check."""
pose = {"A1": 0, "A2": -45, "A3": 90, "A4": 0, "A5": 0, "A6": 180}
assert SafetyManager.check_joint_limits(pose) is True
def test_joint_limits_a1_exceeded(self):
"""Test A1 exceeding limit fails."""
pose = {"A1": 200} # Limit is -185 to 185
assert SafetyManager.check_joint_limits(pose) is False
def test_joint_limits_a5_exceeded(self):
"""Test A5 exceeding its tighter limit fails."""
pose = {"A5": 150} # Limit is -130 to 130
assert SafetyManager.check_joint_limits(pose) is False

View File

@ -0,0 +1,136 @@
"""Tests for trajectory planner."""
import pytest
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from RSIPI.trajectory_planner import generate_trajectory
class TestGenerateTrajectory:
"""Tests for generate_trajectory()"""
def test_basic_linear_interpolation(self):
"""Test basic linear interpolation between two points."""
start = {"X": 0, "Y": 0, "Z": 0}
end = {"X": 100, "Y": 200, "Z": 300}
traj = generate_trajectory(start, end, steps=10, space="cartesian", mode="absolute")
assert len(traj) == 10
# First point should be 1/10 of the way
assert traj[0]["X"] == pytest.approx(10.0)
assert traj[0]["Y"] == pytest.approx(20.0)
assert traj[0]["Z"] == pytest.approx(30.0)
# Last point should be at end
assert traj[-1]["X"] == pytest.approx(100.0)
assert traj[-1]["Y"] == pytest.approx(200.0)
assert traj[-1]["Z"] == pytest.approx(300.0)
def test_relative_mode(self):
"""Test relative mode generates incremental deltas."""
start = {"X": 0, "Y": 0}
end = {"X": 100, "Y": 50}
traj = generate_trajectory(start, end, steps=10, space="cartesian", mode="relative")
# Each step should be the same delta
for point in traj:
assert point["X"] == pytest.approx(10.0) # 100/10
assert point["Y"] == pytest.approx(5.0) # 50/10
def test_relative_mode_with_resets(self):
"""Test relative mode with reset steps."""
start = {"X": 0}
end = {"X": 30}
traj = generate_trajectory(start, end, steps=3, space="cartesian", mode="relative", include_resets=True)
# Should have 6 points: delta, reset, delta, reset, delta, reset
assert len(traj) == 6
# Odd indices should be zero (reset points)
assert traj[1]["X"] == 0.0
assert traj[3]["X"] == 0.0
assert traj[5]["X"] == 0.0
# Even indices should be deltas
assert traj[0]["X"] == pytest.approx(10.0)
assert traj[2]["X"] == pytest.approx(10.0)
assert traj[4]["X"] == pytest.approx(10.0)
def test_joint_space(self):
"""Test trajectory generation in joint space."""
start = {"A1": 0, "A2": 0}
end = {"A1": 90, "A2": -45}
traj = generate_trajectory(start, end, steps=9, space="joint", mode="absolute")
assert len(traj) == 9
# Final point should match end
assert traj[-1]["A1"] == pytest.approx(90.0)
assert traj[-1]["A2"] == pytest.approx(-45.0)
def test_negative_movement(self):
"""Test trajectory with negative direction."""
start = {"X": 100, "Y": 100}
end = {"X": 0, "Y": -100}
traj = generate_trajectory(start, end, steps=10, space="cartesian", mode="absolute")
# Should decrease linearly
assert traj[0]["X"] == pytest.approx(90.0)
assert traj[0]["Y"] == pytest.approx(80.0)
assert traj[-1]["X"] == pytest.approx(0.0)
assert traj[-1]["Y"] == pytest.approx(-100.0)
def test_single_step(self):
"""Test trajectory with single step goes directly to end."""
start = {"X": 0}
end = {"X": 100}
traj = generate_trajectory(start, end, steps=1, space="cartesian", mode="absolute")
assert len(traj) == 1
assert traj[0]["X"] == pytest.approx(100.0)
def test_invalid_mode_raises(self):
"""Test that invalid mode raises ValueError."""
start = {"X": 0}
end = {"X": 100}
with pytest.raises(ValueError, match="mode must be"):
generate_trajectory(start, end, mode="invalid")
def test_invalid_space_raises(self):
"""Test that invalid space raises ValueError."""
start = {"X": 0}
end = {"X": 100}
with pytest.raises(ValueError, match="space must be"):
generate_trajectory(start, end, space="invalid")
def test_absolute_mode_ignores_include_resets(self):
"""Test that absolute mode ignores include_resets flag."""
start = {"X": 0}
end = {"X": 100}
# Even with include_resets=True, absolute mode should not add resets
traj = generate_trajectory(start, end, steps=5, space="cartesian", mode="absolute", include_resets=True)
assert len(traj) == 5 # No extra reset points
def test_preserves_all_axes(self):
"""Test that all axes from start are preserved in trajectory."""
start = {"X": 0, "Y": 0, "Z": 0, "A": 0, "B": 0, "C": 0}
end = {"X": 10, "Y": 20, "Z": 30, "A": 5, "B": 10, "C": 15}
traj = generate_trajectory(start, end, steps=2, space="cartesian", mode="absolute")
# Each point should have all axes
for point in traj:
assert set(point.keys()) == {"X", "Y", "Z", "A", "B", "C"}

112
tests/test_xml_handler.py Normal file
View File

@ -0,0 +1,112 @@
"""Tests for XMLGenerator."""
import pytest
import xml.etree.ElementTree as ET
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from RSIPI.xml_handler import XMLGenerator
class TestGenerateSendXML:
"""Tests for XMLGenerator.generate_send_xml()"""
def test_simple_values(self):
"""Test XML generation with simple scalar values."""
send_vars = {
"IPOC": 12345,
"DiL": 1
}
network_settings = {"sentype": "ImFree"}
xml_str = XMLGenerator.generate_send_xml(send_vars, network_settings)
# Parse and verify
root = ET.fromstring(xml_str)
assert root.tag == "Sen"
assert root.get("Type") == "ImFree"
assert root.find("IPOC").text == "12345"
assert root.find("DiL").text == "1"
def test_nested_dict_values(self):
"""Test XML generation with nested dictionary values (attributes)."""
send_vars = {
"RKorr": {"X": 100.5, "Y": -50.25, "Z": 0.0}
}
network_settings = {"sentype": "TestType"}
xml_str = XMLGenerator.generate_send_xml(send_vars, network_settings)
root = ET.fromstring(xml_str)
rkorr = root.find("RKorr")
assert rkorr is not None
assert float(rkorr.get("X")) == 100.50
assert float(rkorr.get("Y")) == -50.25
assert float(rkorr.get("Z")) == 0.0
def test_free_field_skipped(self):
"""Test that FREE field is skipped in XML output."""
send_vars = {
"IPOC": 1,
"FREE": 999
}
network_settings = {"sentype": "ImFree"}
xml_str = XMLGenerator.generate_send_xml(send_vars, network_settings)
root = ET.fromstring(xml_str)
assert root.find("FREE") is None
assert root.find("IPOC") is not None
def test_mixed_values(self):
"""Test XML generation with mixed scalar and nested values."""
send_vars = {
"IPOC": 5000,
"RKorr": {"X": 10.0, "Y": 20.0},
"Digout": {"o1": 1, "o2": 0}
}
network_settings = {"sentype": "Mixed"}
xml_str = XMLGenerator.generate_send_xml(send_vars, network_settings)
root = ET.fromstring(xml_str)
assert root.find("IPOC").text == "5000"
assert float(root.find("RKorr").get("X")) == 10.0
assert float(root.find("Digout").get("o1")) == 1.0
class TestGenerateReceiveXML:
"""Tests for XMLGenerator.generate_receive_xml()"""
def test_simple_receive(self):
"""Test receive XML generation with scalar values."""
receive_vars = {
"IPOC": 12345,
"BMode": "Active"
}
xml_str = XMLGenerator.generate_receive_xml(receive_vars)
root = ET.fromstring(xml_str)
assert root.tag == "Rob"
assert root.get("Type") == "KUKA"
assert root.find("IPOC").text == "12345"
assert root.find("BMode").text == "Active"
def test_nested_receive(self):
"""Test receive XML generation with position data."""
receive_vars = {
"RIst": {"X": 500.0, "Y": 100.0, "Z": 800.0},
"ASPos": {"A1": 0.0, "A2": -45.0, "A3": 90.0}
}
xml_str = XMLGenerator.generate_receive_xml(receive_vars)
root = ET.fromstring(xml_str)
rist = root.find("RIst")
assert float(rist.get("X")) == 500.0
assert float(rist.get("Z")) == 800.0
aspos = root.find("ASPos")
assert float(aspos.get("A2")) == -45.0