Refactor core architecture and add test coverage
- Fix socket lifecycle: create in child process, add cleanup with try/finally - Add ClientState enum with validated state transitions to prevent invalid operations - Decouple CSV logging from network loop using queue-based CSVLogger process - Fix broken imports: change absolute (src.RSIPI.x) to relative (.x) across 7 files - Add missing @staticmethod decorator to generate_report() - Add command queue for inter-process communication (logging control) - Add 34 unit tests for XMLGenerator, SafetyManager, and trajectory_planner - Add pytest configuration to pyproject.toml - Add CLAUDE.md with architecture documentation
This commit is contained in:
parent
219ffc531c
commit
7bfe5cccf1
79
CLAUDE.md
Normal file
79
CLAUDE.md
Normal file
@ -0,0 +1,79 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## What This Project Does
|
||||
|
||||
RSIPI enables real-time control of KUKA industrial robots from Python via the RSI (Robot Sensor Interface) protocol. The robot sends its position ~250 times/second over UDP, and this library lets you send back position corrections to control the robot externally.
|
||||
|
||||
## Build & Development Commands
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pip install -e .
|
||||
|
||||
# Or install from requirements (if present)
|
||||
pip install pandas>=2.0 numpy>=1.22 matplotlib>=3.5 lxml>=4.9 scipy>=1.8
|
||||
|
||||
# Run the CLI
|
||||
python -m RSIPI.rsi_cli --config RSI_EthernetConfig.xml
|
||||
|
||||
# Run the echo server (for offline testing without a real robot)
|
||||
python -m RSIPI.rsi_echo_server
|
||||
```
|
||||
|
||||
**No test suite exists** - testing is done via the echo server simulation and example scripts in `examples/`.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Communication Flow
|
||||
|
||||
```
|
||||
KUKA Robot Controller <--UDP/XML--> NetworkProcess <--multiprocessing.Manager--> RSIClient <-- RSIAPI/CLI
|
||||
```
|
||||
|
||||
1. **NetworkProcess** (`network_handler.py`) - Runs in separate process via `multiprocessing.Process`. Binds to UDP socket, receives XML from robot, parses into `receive_variables`, sends XML from `send_variables` back to robot. Uses `start_event` to wait for explicit start signal.
|
||||
|
||||
2. **RSIClient** (`rsi_client.py`) - Orchestrates the system. Initializes ConfigParser, SafetyManager, and NetworkProcess. Uses `multiprocessing.Manager` dicts for thread-safe variable sharing between processes.
|
||||
|
||||
3. **RSIAPI** (`rsi_api.py`) - High-level API wrapping RSIClient. Runs RSIClient in a daemon thread. Provides trajectory planning, logging, plotting, and safety controls.
|
||||
|
||||
4. **RSICommandLineInterface** (`rsi_cli.py`) - Interactive CLI that wraps RSIAPI.
|
||||
|
||||
### Key Shared State
|
||||
|
||||
Variables are shared between processes using `multiprocessing.Manager().dict()`:
|
||||
- `send_variables` - Values to send to robot (RKorr corrections, digital outputs, etc.)
|
||||
- `receive_variables` - Values received from robot (RIst position, ASPos joints, IPOC timestamp)
|
||||
|
||||
### Configuration
|
||||
|
||||
`RSI_EthernetConfig.xml` defines:
|
||||
- Network settings (IP, port) in `<CONFIG>` section
|
||||
- Send variables in `<SEND><ELEMENTS>` - what the robot receives from us
|
||||
- Receive variables in `<RECEIVE><ELEMENTS>` - what we receive from robot
|
||||
|
||||
Variable tags like `DEF_RIst` get the `DEF_` prefix stripped and are expanded using `internal_structure` in ConfigParser to full dicts (e.g., `RIst: {X, Y, Z, A, B, C}`).
|
||||
|
||||
### Safety Layer
|
||||
|
||||
**SafetyManager** (`safety_manager.py`) validates all outgoing values against configurable limits. Can load limits from `.rsi.xml` files. Supports emergency stop and safety override modes.
|
||||
|
||||
### Trajectory Execution
|
||||
|
||||
`TrajectoryPlanner` generates interpolated waypoints. `execute_trajectory()` in RSIAPI uses asyncio to send points at specified rate (default 12ms for Cartesian, 400ms for joints).
|
||||
|
||||
## Important Patterns
|
||||
|
||||
- **IPOC synchronization**: The robot sends an IPOC (timestamp) value. The response must include `IPOC + 4` to maintain sync. This is handled automatically in `NetworkProcess.process_received_data()`.
|
||||
|
||||
- **Lazy client initialization**: RSIAPI uses `_ensure_client()` pattern - RSIClient is created on first use, not at RSIAPI instantiation.
|
||||
|
||||
- **Non-blocking start**: `start_rsi()` runs the client loop in a daemon thread. The NetworkProcess waits on `start_event` before binding the socket.
|
||||
|
||||
## File Locations
|
||||
|
||||
- Source code: `src/RSIPI/`
|
||||
- Example scripts: `examples/`
|
||||
- Config template: `RSI_EthernetConfig.xml`
|
||||
- Logs written to: `logs/` (created at runtime)
|
||||
@ -25,8 +25,17 @@ classifiers = [
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.0",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
package-dir = {"" = "src"}
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
where = ["src"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["src"]
|
||||
@ -17,7 +17,7 @@ class ConfigParser:
|
||||
config_file (str): Path to the RSI_EthernetConfig.xml file.
|
||||
rsi_limits_file (str, optional): Path to .rsi.xml file containing safety limits.
|
||||
"""
|
||||
from src.RSIPI.rsi_limit_parser import parse_rsi_limits
|
||||
from .rsi_limit_parser import parse_rsi_limits
|
||||
|
||||
self.config_file = config_file
|
||||
self.rsi_limits_file = rsi_limits_file
|
||||
|
||||
@ -2,7 +2,7 @@ import tkinter as tk
|
||||
from tkinter import ttk, filedialog
|
||||
import threading
|
||||
import time
|
||||
from src.RSIPI.rsi_echo_server import EchoServer
|
||||
from .rsi_echo_server import EchoServer
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
|
||||
@ -150,7 +150,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.limits:
|
||||
from src.RSIPI.rsi_limit_parser import parse_rsi_limits
|
||||
from .rsi_limit_parser import parse_rsi_limits
|
||||
limits = parse_rsi_limits(args.limits)
|
||||
visualiser = KukaRSIVisualiser(args.csv_file, safety_limits=limits)
|
||||
else:
|
||||
|
||||
@ -2,47 +2,116 @@ import multiprocessing
|
||||
import socket
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
import os
|
||||
import datetime
|
||||
from queue import Empty
|
||||
from .xml_handler import XMLGenerator
|
||||
from .safety_manager import SafetyManager
|
||||
|
||||
|
||||
class CSVLogger(multiprocessing.Process):
|
||||
"""Separate process for writing CSV logs without blocking the network loop."""
|
||||
|
||||
def __init__(self, log_queue, stop_event, filename):
|
||||
super().__init__()
|
||||
self.log_queue = log_queue
|
||||
self.stop_event = stop_event
|
||||
self.filename = filename
|
||||
self.daemon = True
|
||||
|
||||
def run(self):
|
||||
"""Write log entries from queue to CSV file."""
|
||||
# Ensure logs directory exists
|
||||
log_dir = os.path.dirname(self.filename)
|
||||
if log_dir and not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
header_written = False
|
||||
|
||||
try:
|
||||
with open(self.filename, 'w', newline='') as f:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
entry = self.log_queue.get(timeout=0.5)
|
||||
if entry is None: # Poison pill
|
||||
break
|
||||
|
||||
# Write header on first entry
|
||||
if not header_written:
|
||||
headers = ['Timestamp'] + list(entry.keys())
|
||||
f.write(','.join(headers) + '\n')
|
||||
header_written = True
|
||||
|
||||
# Write data row
|
||||
timestamp = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S.%f")[:-3]
|
||||
values = [timestamp] + [str(v) for v in entry.values()]
|
||||
f.write(','.join(values) + '\n')
|
||||
f.flush()
|
||||
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logging.error(f"CSV logging error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to open log file {self.filename}: {e}")
|
||||
|
||||
|
||||
class NetworkProcess(multiprocessing.Process):
|
||||
"""Handles UDP communication and optional CSV logging in a separate process."""
|
||||
|
||||
def __init__(self, ip, port, send_variables, receive_variables, stop_event, config_parser, start_event):
|
||||
def __init__(self, ip, port, send_variables, receive_variables, stop_event, config_parser, start_event, command_queue):
|
||||
super().__init__()
|
||||
self.send_variables = send_variables
|
||||
self.receive_variables = receive_variables
|
||||
self.stop_event = stop_event
|
||||
self.start_event = start_event # ✅ NEW
|
||||
self.start_event = start_event
|
||||
self.config_parser = config_parser
|
||||
self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.command_queue = command_queue
|
||||
self.safety_manager = SafetyManager(config_parser.safety_limits)
|
||||
|
||||
self.client_address = (ip, port)
|
||||
self.logging_active = multiprocessing.Value('b', False)
|
||||
self.log_filename = multiprocessing.Array('c', 256)
|
||||
self.csv_process = None
|
||||
|
||||
self.controller_ip_and_port = None
|
||||
self.udp_socket = None
|
||||
|
||||
# Logging infrastructure (created when logging starts)
|
||||
self.log_queue = None
|
||||
self.log_stop_event = None
|
||||
self.csv_logger = None
|
||||
|
||||
def run(self):
|
||||
"""Start the network loop."""
|
||||
self.start_event.wait() # ✅ Wait until RSIClient sends start signal
|
||||
# Wait for start signal, but check stop_event periodically to allow clean shutdown
|
||||
while not self.start_event.wait(timeout=0.5):
|
||||
if self.stop_event.is_set():
|
||||
logging.info("Network process stopped before starting")
|
||||
return
|
||||
|
||||
try:
|
||||
if not self.is_valid_ip(self.client_address[0]):
|
||||
logging.warning(f"Invalid IP address '{self.client_address[0]}'. Falling back to '0.0.0.0'.")
|
||||
self.client_address = ('0.0.0.0', self.client_address[1])
|
||||
self._setup_socket()
|
||||
self._run_loop()
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.udp_socket.bind(self.client_address)
|
||||
logging.info(f"✅ Network process bound on {self.client_address}")
|
||||
def _setup_socket(self):
|
||||
"""Create and bind the UDP socket."""
|
||||
if not self.is_valid_ip(self.client_address[0]):
|
||||
logging.warning(f"Invalid IP address '{self.client_address[0]}'. Falling back to '0.0.0.0'.")
|
||||
self.client_address = ('0.0.0.0', self.client_address[1])
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"❌ Failed to bind to {self.client_address}: {e}")
|
||||
raise
|
||||
self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.udp_socket.bind(self.client_address)
|
||||
logging.info(f"Network process bound on {self.client_address}")
|
||||
|
||||
def _run_loop(self):
|
||||
"""Main communication loop."""
|
||||
while not self.stop_event.is_set():
|
||||
# Check for commands (non-blocking)
|
||||
self._process_commands()
|
||||
|
||||
try:
|
||||
self.udp_socket.settimeout(5)
|
||||
data_received, self.controller_ip_and_port = self.udp_socket.recvfrom(1024)
|
||||
@ -51,13 +120,115 @@ class NetworkProcess(multiprocessing.Process):
|
||||
send_xml = XMLGenerator.generate_send_xml(self.send_variables, self.config_parser.network_settings)
|
||||
self.udp_socket.sendto(send_xml.encode(), self.controller_ip_and_port)
|
||||
|
||||
if self.logging_active.value:
|
||||
self.log_to_csv()
|
||||
if self.logging_active.value and self.log_queue:
|
||||
self._queue_log_entry()
|
||||
|
||||
except socket.timeout:
|
||||
logging.error("[WARNING] No message received within timeout period.")
|
||||
logging.warning("No message received within timeout period")
|
||||
except Exception as e:
|
||||
logging.error(f"[ERROR] Network process error: {e}")
|
||||
logging.error(f"Network process error: {e}")
|
||||
|
||||
def _process_commands(self):
|
||||
"""Process any pending commands from the parent process."""
|
||||
try:
|
||||
while True:
|
||||
cmd = self.command_queue.get_nowait()
|
||||
if cmd is None:
|
||||
continue
|
||||
|
||||
action = cmd.get('action')
|
||||
if action == 'start_logging':
|
||||
self._start_logging(cmd.get('filename'))
|
||||
elif action == 'stop_logging':
|
||||
self._stop_logging()
|
||||
|
||||
except Empty:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing command: {e}")
|
||||
|
||||
def _queue_log_entry(self):
|
||||
"""Queue current state for CSV logging (non-blocking)."""
|
||||
try:
|
||||
entry = {}
|
||||
# Flatten send variables
|
||||
for key, value in dict(self.send_variables).items():
|
||||
if isinstance(value, dict):
|
||||
for subkey, subval in value.items():
|
||||
entry[f"Send.{key}.{subkey}"] = subval
|
||||
else:
|
||||
entry[f"Send.{key}"] = value
|
||||
|
||||
# Flatten receive variables
|
||||
for key, value in dict(self.receive_variables).items():
|
||||
if isinstance(value, dict):
|
||||
for subkey, subval in value.items():
|
||||
entry[f"Receive.{key}.{subkey}"] = subval
|
||||
else:
|
||||
entry[f"Receive.{key}"] = value
|
||||
|
||||
# Non-blocking put - drop entry if queue is full
|
||||
try:
|
||||
self.log_queue.put_nowait(entry)
|
||||
except:
|
||||
pass # Queue full, skip this entry rather than block
|
||||
|
||||
except Exception as e:
|
||||
logging.debug(f"Failed to queue log entry: {e}")
|
||||
|
||||
def _start_logging(self, filename):
|
||||
"""Start CSV logging to the specified file."""
|
||||
if self.logging_active.value:
|
||||
logging.warning("Logging already active")
|
||||
return
|
||||
|
||||
self.log_queue = multiprocessing.Queue(maxsize=1000)
|
||||
self.log_stop_event = multiprocessing.Event()
|
||||
|
||||
self.csv_logger = CSVLogger(self.log_queue, self.log_stop_event, filename)
|
||||
self.csv_logger.start()
|
||||
|
||||
self.logging_active.value = True
|
||||
logging.info(f"CSV logging started: {filename}")
|
||||
|
||||
def _stop_logging(self):
|
||||
"""Stop CSV logging."""
|
||||
if not self.logging_active.value:
|
||||
return
|
||||
|
||||
self.logging_active.value = False
|
||||
|
||||
if self.log_queue:
|
||||
try:
|
||||
self.log_queue.put_nowait(None) # Poison pill
|
||||
except:
|
||||
pass
|
||||
|
||||
if self.log_stop_event:
|
||||
self.log_stop_event.set()
|
||||
|
||||
if self.csv_logger and self.csv_logger.is_alive():
|
||||
self.csv_logger.join(timeout=2)
|
||||
if self.csv_logger.is_alive():
|
||||
self.csv_logger.terminate()
|
||||
|
||||
self.csv_logger = None
|
||||
self.log_queue = None
|
||||
self.log_stop_event = None
|
||||
logging.info("CSV logging stopped")
|
||||
|
||||
def _cleanup(self):
|
||||
"""Clean up resources."""
|
||||
# Stop logging first
|
||||
self._stop_logging()
|
||||
|
||||
if self.udp_socket:
|
||||
try:
|
||||
self.udp_socket.close()
|
||||
logging.info("Network socket closed")
|
||||
except Exception as e:
|
||||
logging.error(f"Error closing socket: {e}")
|
||||
self.udp_socket = None
|
||||
|
||||
@staticmethod
|
||||
def is_valid_ip(ip):
|
||||
@ -83,4 +254,4 @@ class NetworkProcess(multiprocessing.Process):
|
||||
self.receive_variables["IPOC"] = received_ipoc
|
||||
self.send_variables["IPOC"] = received_ipoc + 4
|
||||
except Exception as e:
|
||||
logging.error(f"[ERROR] Error parsing received message: {e}")
|
||||
logging.error(f"Error parsing received message: {e}")
|
||||
|
||||
@ -9,9 +9,9 @@ from .inject_rsi_to_krl import inject_rsi_to_krl
|
||||
import threading
|
||||
from .trajectory_planner import generate_trajectory, execute_trajectory
|
||||
import datetime
|
||||
from src.RSIPI.static_plotter import StaticPlotter # Make sure this file exists as described
|
||||
from .static_plotter import StaticPlotter
|
||||
import os
|
||||
from src.RSIPI.live_plotter import LivePlotter
|
||||
from .live_plotter import LivePlotter
|
||||
from threading import Thread
|
||||
import asyncio
|
||||
|
||||
@ -49,6 +49,7 @@ class RSIAPI:
|
||||
self.client.stop()
|
||||
return "RSI stopped."
|
||||
|
||||
@staticmethod
|
||||
def generate_report(filename, format_type):
|
||||
"""
|
||||
Generate a statistical report from a CSV log file.
|
||||
|
||||
@ -1,17 +1,43 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
from enum import Enum, auto
|
||||
from threading import Lock
|
||||
from .config_parser import ConfigParser
|
||||
from .network_handler import NetworkProcess
|
||||
from .safety_manager import SafetyManager
|
||||
import threading
|
||||
|
||||
|
||||
class ClientState(Enum):
|
||||
"""Connection states for RSIClient."""
|
||||
INITIALIZED = auto() # After __init__, network process spawned but not started
|
||||
STARTING = auto() # Start signal sent, waiting for network to be ready
|
||||
RUNNING = auto() # Actively communicating with robot
|
||||
STOPPING = auto() # Shutdown in progress
|
||||
STOPPED = auto() # Fully stopped, cannot be restarted (use reconnect)
|
||||
ERROR = auto() # Error state
|
||||
|
||||
|
||||
class RSIClient:
|
||||
"""Main RSI API class that integrates network, config handling, and message processing."""
|
||||
|
||||
# Valid state transitions
|
||||
_VALID_TRANSITIONS = {
|
||||
ClientState.INITIALIZED: {ClientState.STARTING, ClientState.STOPPING},
|
||||
ClientState.STARTING: {ClientState.RUNNING, ClientState.STOPPING, ClientState.ERROR},
|
||||
ClientState.RUNNING: {ClientState.STOPPING, ClientState.ERROR},
|
||||
ClientState.STOPPING: {ClientState.STOPPED, ClientState.ERROR},
|
||||
ClientState.STOPPED: {ClientState.INITIALIZED}, # Via reconnect
|
||||
ClientState.ERROR: {ClientState.STOPPING, ClientState.INITIALIZED}, # Via reconnect
|
||||
}
|
||||
|
||||
def __init__(self, config_file, rsi_limits_file=None):
|
||||
logging.info(f"Loading RSI configuration from {config_file}...")
|
||||
|
||||
self._state = ClientState.INITIALIZED
|
||||
self._state_lock = Lock()
|
||||
|
||||
self.config_parser = ConfigParser(config_file, rsi_limits_file)
|
||||
network_settings = self.config_parser.get_network_settings()
|
||||
|
||||
@ -19,11 +45,15 @@ class RSIClient:
|
||||
self.send_variables = self.manager.dict(self.config_parser.send_variables)
|
||||
self.receive_variables = self.manager.dict(self.config_parser.receive_variables)
|
||||
self.stop_event = multiprocessing.Event()
|
||||
self.start_event = multiprocessing.Event() # ✅ NEW
|
||||
self.start_event = multiprocessing.Event()
|
||||
self.command_queue = multiprocessing.Queue()
|
||||
|
||||
self.safety_manager = SafetyManager(self.config_parser.safety_limits)
|
||||
|
||||
# ✅ Create NetworkProcess but don't start communication yet
|
||||
# Shared logging state (readable from parent process)
|
||||
self._logging_active = multiprocessing.Value('b', False)
|
||||
|
||||
# Create NetworkProcess but don't start communication yet
|
||||
self.network_process = NetworkProcess(
|
||||
network_settings["ip"],
|
||||
network_settings["port"],
|
||||
@ -31,17 +61,56 @@ class RSIClient:
|
||||
self.receive_variables,
|
||||
self.stop_event,
|
||||
self.config_parser,
|
||||
self.start_event
|
||||
self.start_event,
|
||||
self.command_queue
|
||||
)
|
||||
# Share the logging_active flag
|
||||
self.network_process.logging_active = self._logging_active
|
||||
self.network_process.start()
|
||||
|
||||
self.logger = None
|
||||
self.running = False
|
||||
self.thread = None
|
||||
|
||||
@property
|
||||
def state(self) -> ClientState:
|
||||
"""Get current client state (thread-safe)."""
|
||||
with self._state_lock:
|
||||
return self._state
|
||||
|
||||
def _transition_to(self, new_state: ClientState) -> bool:
|
||||
"""
|
||||
Attempt to transition to a new state.
|
||||
|
||||
Returns:
|
||||
True if transition was valid and completed, False otherwise.
|
||||
"""
|
||||
with self._state_lock:
|
||||
if new_state in self._VALID_TRANSITIONS.get(self._state, set()):
|
||||
old_state = self._state
|
||||
self._state = new_state
|
||||
logging.debug(f"State transition: {old_state.name} -> {new_state.name}")
|
||||
return True
|
||||
else:
|
||||
logging.warning(
|
||||
f"Invalid state transition attempted: {self._state.name} -> {new_state.name}"
|
||||
)
|
||||
return False
|
||||
|
||||
def start(self):
|
||||
"""Send start signal to NetworkProcess and run control loop."""
|
||||
if not self._transition_to(ClientState.STARTING):
|
||||
logging.error("Cannot start: invalid state")
|
||||
return
|
||||
|
||||
logging.info("RSIClient sending start signal to NetworkProcess...")
|
||||
self.start_event.set()
|
||||
self.running = True
|
||||
|
||||
if not self._transition_to(ClientState.RUNNING):
|
||||
logging.error("Failed to transition to RUNNING state")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
logging.info("RSI Client Started")
|
||||
|
||||
try:
|
||||
@ -51,39 +120,58 @@ class RSIClient:
|
||||
self.stop()
|
||||
except Exception as e:
|
||||
logging.error(f"RSI Client encountered an error: {e}")
|
||||
self._transition_to(ClientState.ERROR)
|
||||
|
||||
def stop(self):
|
||||
"""Stop the network process and the client thread safely."""
|
||||
logging.info("🛑 Stopping RSI Client...")
|
||||
if self.state in (ClientState.STOPPED, ClientState.STOPPING):
|
||||
logging.debug("Already stopped or stopping")
|
||||
return
|
||||
|
||||
if not self._transition_to(ClientState.STOPPING):
|
||||
logging.warning("Could not transition to STOPPING state")
|
||||
# Continue anyway to ensure cleanup
|
||||
|
||||
logging.info("Stopping RSI Client...")
|
||||
|
||||
self.running = False
|
||||
self.stop_event.set() # ✅ Tell network process to exit nicely
|
||||
self.stop_event.set()
|
||||
|
||||
if self.network_process and self.network_process.is_alive():
|
||||
self.network_process.join(timeout=3) # ✅ Give it time to shutdown
|
||||
self.network_process.join(timeout=3)
|
||||
if self.network_process.is_alive():
|
||||
logging.warning("⚠️ Forcing network process termination...")
|
||||
logging.warning("Forcing network process termination...")
|
||||
self.network_process.terminate()
|
||||
self.network_process.join()
|
||||
|
||||
if hasattr(self, "thread") and self.thread and self.thread.is_alive():
|
||||
self.thread.join()
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread.join(timeout=2)
|
||||
self.thread = None
|
||||
|
||||
logging.info("✅ RSI Client Stopped")
|
||||
self._transition_to(ClientState.STOPPED)
|
||||
logging.info("RSI Client Stopped")
|
||||
|
||||
def reconnect(self):
|
||||
"""Reconnects the network process safely."""
|
||||
logging.info("Reconnecting RSI Client network...")
|
||||
|
||||
# Stop if currently running
|
||||
if self.state in (ClientState.RUNNING, ClientState.STARTING):
|
||||
self.stop()
|
||||
|
||||
if self.network_process and self.network_process.is_alive():
|
||||
self.stop_event.set()
|
||||
self.network_process.terminate()
|
||||
self.network_process.join()
|
||||
|
||||
# Fresh new events
|
||||
# Reset to initialized state
|
||||
with self._state_lock:
|
||||
self._state = ClientState.INITIALIZED
|
||||
|
||||
# Fresh new events and queue
|
||||
self.stop_event = multiprocessing.Event()
|
||||
self.start_event = multiprocessing.Event()
|
||||
self.command_queue = multiprocessing.Queue()
|
||||
|
||||
# Create new network process
|
||||
network_settings = self.config_parser.get_network_settings()
|
||||
@ -94,10 +182,32 @@ class RSIClient:
|
||||
self.receive_variables,
|
||||
self.stop_event,
|
||||
self.config_parser,
|
||||
self.start_event
|
||||
self.start_event,
|
||||
self.command_queue
|
||||
)
|
||||
self.network_process.logging_active = self._logging_active
|
||||
self.network_process.start()
|
||||
|
||||
# Fresh control thread
|
||||
self.thread = threading.Thread(target=self.start, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check if client is in running state."""
|
||||
return self.state == ClientState.RUNNING
|
||||
|
||||
def is_stopped(self) -> bool:
|
||||
"""Check if client is fully stopped."""
|
||||
return self.state == ClientState.STOPPED
|
||||
|
||||
def start_logging(self, filename):
|
||||
"""Start CSV logging to the specified file."""
|
||||
self.command_queue.put({'action': 'start_logging', 'filename': filename})
|
||||
|
||||
def stop_logging(self):
|
||||
"""Stop CSV logging."""
|
||||
self.command_queue.put({'action': 'stop_logging'})
|
||||
|
||||
def is_logging_active(self) -> bool:
|
||||
"""Check if CSV logging is currently active."""
|
||||
return self._logging_active.value
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import xml.etree.ElementTree as ET
|
||||
import logging
|
||||
from src.RSIPI.rsi_limit_parser import parse_rsi_limits
|
||||
from .rsi_limit_parser import parse_rsi_limits
|
||||
|
||||
# ✅ Configure Logging (toggleable)
|
||||
LOGGING_ENABLED = False # Change too False to silence logging output
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
import xml.etree.ElementTree as ET
|
||||
import logging
|
||||
import threading
|
||||
from src.RSIPI.rsi_config import RSIConfig
|
||||
from .rsi_config import RSIConfig
|
||||
|
||||
# ✅ Toggle logging for debugging purposes
|
||||
LOGGING_ENABLED = True
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from RSIPI.safety_manager import SafetyManager
|
||||
from .safety_manager import SafetyManager
|
||||
import time
|
||||
|
||||
def generate_trajectory(start, end, steps=100, space="cartesian", mode="absolute", include_resets=False):
|
||||
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# RSIPI Tests
|
||||
BIN
tests/__pycache__/__init__.cpython-313.pyc
Normal file
BIN
tests/__pycache__/__init__.cpython-313.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tests/__pycache__/test_xml_handler.cpython-313-pytest-8.4.1.pyc
Normal file
BIN
tests/__pycache__/test_xml_handler.cpython-313-pytest-8.4.1.pyc
Normal file
Binary file not shown.
168
tests/test_safety_manager.py
Normal file
168
tests/test_safety_manager.py
Normal file
@ -0,0 +1,168 @@
|
||||
"""Tests for SafetyManager."""
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
from RSIPI.safety_manager import SafetyManager
|
||||
|
||||
|
||||
class TestValidate:
|
||||
"""Tests for SafetyManager.validate()"""
|
||||
|
||||
def test_validate_within_limits(self):
|
||||
"""Test that values within limits pass through unchanged."""
|
||||
limits = {"RKorr.X": (-5.0, 5.0)}
|
||||
sm = SafetyManager(limits)
|
||||
|
||||
result = sm.validate("RKorr.X", 3.0)
|
||||
assert result == 3.0
|
||||
|
||||
def test_validate_at_boundary(self):
|
||||
"""Test that values at exact boundaries pass."""
|
||||
limits = {"RKorr.X": (-5.0, 5.0)}
|
||||
sm = SafetyManager(limits)
|
||||
|
||||
assert sm.validate("RKorr.X", -5.0) == -5.0
|
||||
assert sm.validate("RKorr.X", 5.0) == 5.0
|
||||
|
||||
def test_validate_exceeds_max(self):
|
||||
"""Test that values exceeding max raise ValueError."""
|
||||
limits = {"RKorr.X": (-5.0, 5.0)}
|
||||
sm = SafetyManager(limits)
|
||||
|
||||
with pytest.raises(ValueError, match="out of bounds"):
|
||||
sm.validate("RKorr.X", 5.1)
|
||||
|
||||
def test_validate_below_min(self):
|
||||
"""Test that values below min raise ValueError."""
|
||||
limits = {"RKorr.X": (-5.0, 5.0)}
|
||||
sm = SafetyManager(limits)
|
||||
|
||||
with pytest.raises(ValueError, match="out of bounds"):
|
||||
sm.validate("RKorr.X", -5.1)
|
||||
|
||||
def test_validate_unlisted_path(self):
|
||||
"""Test that paths without defined limits pass through."""
|
||||
limits = {"RKorr.X": (-5.0, 5.0)}
|
||||
sm = SafetyManager(limits)
|
||||
|
||||
# RKorr.Y has no limit defined - should pass
|
||||
result = sm.validate("RKorr.Y", 1000.0)
|
||||
assert result == 1000.0
|
||||
|
||||
def test_validate_with_override(self):
|
||||
"""Test that override bypasses all limit checks."""
|
||||
limits = {"RKorr.X": (-5.0, 5.0)}
|
||||
sm = SafetyManager(limits)
|
||||
sm.override_safety(True)
|
||||
|
||||
# Should pass despite being out of bounds
|
||||
result = sm.validate("RKorr.X", 100.0)
|
||||
assert result == 100.0
|
||||
|
||||
|
||||
class TestEmergencyStop:
|
||||
"""Tests for emergency stop functionality."""
|
||||
|
||||
def test_estop_blocks_validation(self):
|
||||
"""Test that e-stop blocks all validation."""
|
||||
sm = SafetyManager()
|
||||
sm.emergency_stop()
|
||||
|
||||
with pytest.raises(RuntimeError, match="E-STOP"):
|
||||
sm.validate("RKorr.X", 0.0)
|
||||
|
||||
def test_estop_reset(self):
|
||||
"""Test that e-stop can be reset."""
|
||||
sm = SafetyManager()
|
||||
sm.emergency_stop()
|
||||
|
||||
assert sm.is_stopped() is True
|
||||
|
||||
sm.reset_stop()
|
||||
|
||||
assert sm.is_stopped() is False
|
||||
# Should work again
|
||||
result = sm.validate("RKorr.X", 1.0)
|
||||
assert result == 1.0
|
||||
|
||||
|
||||
class TestSetLimit:
|
||||
"""Tests for runtime limit modification."""
|
||||
|
||||
def test_set_new_limit(self):
|
||||
"""Test adding a new limit at runtime."""
|
||||
sm = SafetyManager()
|
||||
|
||||
sm.set_limit("RKorr.Y", -10.0, 10.0)
|
||||
|
||||
assert sm.validate("RKorr.Y", 5.0) == 5.0
|
||||
with pytest.raises(ValueError):
|
||||
sm.validate("RKorr.Y", 15.0)
|
||||
|
||||
def test_override_existing_limit(self):
|
||||
"""Test overriding an existing limit."""
|
||||
limits = {"RKorr.X": (-5.0, 5.0)}
|
||||
sm = SafetyManager(limits)
|
||||
|
||||
# Original limit blocks this
|
||||
with pytest.raises(ValueError):
|
||||
sm.validate("RKorr.X", 8.0)
|
||||
|
||||
# Override limit
|
||||
sm.set_limit("RKorr.X", -10.0, 10.0)
|
||||
|
||||
# Now it should pass
|
||||
assert sm.validate("RKorr.X", 8.0) == 8.0
|
||||
|
||||
def test_get_limits(self):
|
||||
"""Test retrieving all limits."""
|
||||
limits = {"RKorr.X": (-5.0, 5.0), "AKorr.A1": (-6.0, 6.0)}
|
||||
sm = SafetyManager(limits)
|
||||
|
||||
retrieved = sm.get_limits()
|
||||
assert retrieved == limits
|
||||
# Should be a copy
|
||||
retrieved["new"] = (0, 1)
|
||||
assert "new" not in sm.get_limits()
|
||||
|
||||
|
||||
class TestStaticChecks:
|
||||
"""Tests for static limit checking methods."""
|
||||
|
||||
def test_cartesian_limits_valid(self):
|
||||
"""Test valid Cartesian pose passes check."""
|
||||
pose = {"X": 500, "Y": -200, "Z": 1000, "A": 0, "B": 0, "C": 0}
|
||||
assert SafetyManager.check_cartesian_limits(pose) is True
|
||||
|
||||
def test_cartesian_limits_z_too_low(self):
|
||||
"""Test Z below zero fails."""
|
||||
pose = {"X": 0, "Y": 0, "Z": -100}
|
||||
assert SafetyManager.check_cartesian_limits(pose) is False
|
||||
|
||||
def test_cartesian_limits_x_too_high(self):
|
||||
"""Test X exceeding max fails."""
|
||||
pose = {"X": 2000, "Y": 0, "Z": 500}
|
||||
assert SafetyManager.check_cartesian_limits(pose) is False
|
||||
|
||||
def test_cartesian_limits_partial_pose(self):
|
||||
"""Test partial pose (missing keys) passes if present keys are valid."""
|
||||
pose = {"X": 100, "Z": 500} # Missing Y, A, B, C
|
||||
assert SafetyManager.check_cartesian_limits(pose) is True
|
||||
|
||||
def test_joint_limits_valid(self):
|
||||
"""Test valid joint pose passes check."""
|
||||
pose = {"A1": 0, "A2": -45, "A3": 90, "A4": 0, "A5": 0, "A6": 180}
|
||||
assert SafetyManager.check_joint_limits(pose) is True
|
||||
|
||||
def test_joint_limits_a1_exceeded(self):
|
||||
"""Test A1 exceeding limit fails."""
|
||||
pose = {"A1": 200} # Limit is -185 to 185
|
||||
assert SafetyManager.check_joint_limits(pose) is False
|
||||
|
||||
def test_joint_limits_a5_exceeded(self):
|
||||
"""Test A5 exceeding its tighter limit fails."""
|
||||
pose = {"A5": 150} # Limit is -130 to 130
|
||||
assert SafetyManager.check_joint_limits(pose) is False
|
||||
136
tests/test_trajectory_planner.py
Normal file
136
tests/test_trajectory_planner.py
Normal file
@ -0,0 +1,136 @@
|
||||
"""Tests for trajectory planner."""
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
from RSIPI.trajectory_planner import generate_trajectory
|
||||
|
||||
|
||||
class TestGenerateTrajectory:
|
||||
"""Tests for generate_trajectory()"""
|
||||
|
||||
def test_basic_linear_interpolation(self):
|
||||
"""Test basic linear interpolation between two points."""
|
||||
start = {"X": 0, "Y": 0, "Z": 0}
|
||||
end = {"X": 100, "Y": 200, "Z": 300}
|
||||
|
||||
traj = generate_trajectory(start, end, steps=10, space="cartesian", mode="absolute")
|
||||
|
||||
assert len(traj) == 10
|
||||
|
||||
# First point should be 1/10 of the way
|
||||
assert traj[0]["X"] == pytest.approx(10.0)
|
||||
assert traj[0]["Y"] == pytest.approx(20.0)
|
||||
assert traj[0]["Z"] == pytest.approx(30.0)
|
||||
|
||||
# Last point should be at end
|
||||
assert traj[-1]["X"] == pytest.approx(100.0)
|
||||
assert traj[-1]["Y"] == pytest.approx(200.0)
|
||||
assert traj[-1]["Z"] == pytest.approx(300.0)
|
||||
|
||||
def test_relative_mode(self):
|
||||
"""Test relative mode generates incremental deltas."""
|
||||
start = {"X": 0, "Y": 0}
|
||||
end = {"X": 100, "Y": 50}
|
||||
|
||||
traj = generate_trajectory(start, end, steps=10, space="cartesian", mode="relative")
|
||||
|
||||
# Each step should be the same delta
|
||||
for point in traj:
|
||||
assert point["X"] == pytest.approx(10.0) # 100/10
|
||||
assert point["Y"] == pytest.approx(5.0) # 50/10
|
||||
|
||||
def test_relative_mode_with_resets(self):
|
||||
"""Test relative mode with reset steps."""
|
||||
start = {"X": 0}
|
||||
end = {"X": 30}
|
||||
|
||||
traj = generate_trajectory(start, end, steps=3, space="cartesian", mode="relative", include_resets=True)
|
||||
|
||||
# Should have 6 points: delta, reset, delta, reset, delta, reset
|
||||
assert len(traj) == 6
|
||||
|
||||
# Odd indices should be zero (reset points)
|
||||
assert traj[1]["X"] == 0.0
|
||||
assert traj[3]["X"] == 0.0
|
||||
assert traj[5]["X"] == 0.0
|
||||
|
||||
# Even indices should be deltas
|
||||
assert traj[0]["X"] == pytest.approx(10.0)
|
||||
assert traj[2]["X"] == pytest.approx(10.0)
|
||||
assert traj[4]["X"] == pytest.approx(10.0)
|
||||
|
||||
def test_joint_space(self):
|
||||
"""Test trajectory generation in joint space."""
|
||||
start = {"A1": 0, "A2": 0}
|
||||
end = {"A1": 90, "A2": -45}
|
||||
|
||||
traj = generate_trajectory(start, end, steps=9, space="joint", mode="absolute")
|
||||
|
||||
assert len(traj) == 9
|
||||
|
||||
# Final point should match end
|
||||
assert traj[-1]["A1"] == pytest.approx(90.0)
|
||||
assert traj[-1]["A2"] == pytest.approx(-45.0)
|
||||
|
||||
def test_negative_movement(self):
|
||||
"""Test trajectory with negative direction."""
|
||||
start = {"X": 100, "Y": 100}
|
||||
end = {"X": 0, "Y": -100}
|
||||
|
||||
traj = generate_trajectory(start, end, steps=10, space="cartesian", mode="absolute")
|
||||
|
||||
# Should decrease linearly
|
||||
assert traj[0]["X"] == pytest.approx(90.0)
|
||||
assert traj[0]["Y"] == pytest.approx(80.0)
|
||||
assert traj[-1]["X"] == pytest.approx(0.0)
|
||||
assert traj[-1]["Y"] == pytest.approx(-100.0)
|
||||
|
||||
def test_single_step(self):
|
||||
"""Test trajectory with single step goes directly to end."""
|
||||
start = {"X": 0}
|
||||
end = {"X": 100}
|
||||
|
||||
traj = generate_trajectory(start, end, steps=1, space="cartesian", mode="absolute")
|
||||
|
||||
assert len(traj) == 1
|
||||
assert traj[0]["X"] == pytest.approx(100.0)
|
||||
|
||||
def test_invalid_mode_raises(self):
|
||||
"""Test that invalid mode raises ValueError."""
|
||||
start = {"X": 0}
|
||||
end = {"X": 100}
|
||||
|
||||
with pytest.raises(ValueError, match="mode must be"):
|
||||
generate_trajectory(start, end, mode="invalid")
|
||||
|
||||
def test_invalid_space_raises(self):
|
||||
"""Test that invalid space raises ValueError."""
|
||||
start = {"X": 0}
|
||||
end = {"X": 100}
|
||||
|
||||
with pytest.raises(ValueError, match="space must be"):
|
||||
generate_trajectory(start, end, space="invalid")
|
||||
|
||||
def test_absolute_mode_ignores_include_resets(self):
|
||||
"""Test that absolute mode ignores include_resets flag."""
|
||||
start = {"X": 0}
|
||||
end = {"X": 100}
|
||||
|
||||
# Even with include_resets=True, absolute mode should not add resets
|
||||
traj = generate_trajectory(start, end, steps=5, space="cartesian", mode="absolute", include_resets=True)
|
||||
|
||||
assert len(traj) == 5 # No extra reset points
|
||||
|
||||
def test_preserves_all_axes(self):
|
||||
"""Test that all axes from start are preserved in trajectory."""
|
||||
start = {"X": 0, "Y": 0, "Z": 0, "A": 0, "B": 0, "C": 0}
|
||||
end = {"X": 10, "Y": 20, "Z": 30, "A": 5, "B": 10, "C": 15}
|
||||
|
||||
traj = generate_trajectory(start, end, steps=2, space="cartesian", mode="absolute")
|
||||
|
||||
# Each point should have all axes
|
||||
for point in traj:
|
||||
assert set(point.keys()) == {"X", "Y", "Z", "A", "B", "C"}
|
||||
112
tests/test_xml_handler.py
Normal file
112
tests/test_xml_handler.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""Tests for XMLGenerator."""
|
||||
import pytest
|
||||
import xml.etree.ElementTree as ET
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
from RSIPI.xml_handler import XMLGenerator
|
||||
|
||||
|
||||
class TestGenerateSendXML:
|
||||
"""Tests for XMLGenerator.generate_send_xml()"""
|
||||
|
||||
def test_simple_values(self):
|
||||
"""Test XML generation with simple scalar values."""
|
||||
send_vars = {
|
||||
"IPOC": 12345,
|
||||
"DiL": 1
|
||||
}
|
||||
network_settings = {"sentype": "ImFree"}
|
||||
|
||||
xml_str = XMLGenerator.generate_send_xml(send_vars, network_settings)
|
||||
|
||||
# Parse and verify
|
||||
root = ET.fromstring(xml_str)
|
||||
assert root.tag == "Sen"
|
||||
assert root.get("Type") == "ImFree"
|
||||
assert root.find("IPOC").text == "12345"
|
||||
assert root.find("DiL").text == "1"
|
||||
|
||||
def test_nested_dict_values(self):
|
||||
"""Test XML generation with nested dictionary values (attributes)."""
|
||||
send_vars = {
|
||||
"RKorr": {"X": 100.5, "Y": -50.25, "Z": 0.0}
|
||||
}
|
||||
network_settings = {"sentype": "TestType"}
|
||||
|
||||
xml_str = XMLGenerator.generate_send_xml(send_vars, network_settings)
|
||||
|
||||
root = ET.fromstring(xml_str)
|
||||
rkorr = root.find("RKorr")
|
||||
assert rkorr is not None
|
||||
assert float(rkorr.get("X")) == 100.50
|
||||
assert float(rkorr.get("Y")) == -50.25
|
||||
assert float(rkorr.get("Z")) == 0.0
|
||||
|
||||
def test_free_field_skipped(self):
|
||||
"""Test that FREE field is skipped in XML output."""
|
||||
send_vars = {
|
||||
"IPOC": 1,
|
||||
"FREE": 999
|
||||
}
|
||||
network_settings = {"sentype": "ImFree"}
|
||||
|
||||
xml_str = XMLGenerator.generate_send_xml(send_vars, network_settings)
|
||||
|
||||
root = ET.fromstring(xml_str)
|
||||
assert root.find("FREE") is None
|
||||
assert root.find("IPOC") is not None
|
||||
|
||||
def test_mixed_values(self):
|
||||
"""Test XML generation with mixed scalar and nested values."""
|
||||
send_vars = {
|
||||
"IPOC": 5000,
|
||||
"RKorr": {"X": 10.0, "Y": 20.0},
|
||||
"Digout": {"o1": 1, "o2": 0}
|
||||
}
|
||||
network_settings = {"sentype": "Mixed"}
|
||||
|
||||
xml_str = XMLGenerator.generate_send_xml(send_vars, network_settings)
|
||||
|
||||
root = ET.fromstring(xml_str)
|
||||
assert root.find("IPOC").text == "5000"
|
||||
assert float(root.find("RKorr").get("X")) == 10.0
|
||||
assert float(root.find("Digout").get("o1")) == 1.0
|
||||
|
||||
|
||||
class TestGenerateReceiveXML:
|
||||
"""Tests for XMLGenerator.generate_receive_xml()"""
|
||||
|
||||
def test_simple_receive(self):
|
||||
"""Test receive XML generation with scalar values."""
|
||||
receive_vars = {
|
||||
"IPOC": 12345,
|
||||
"BMode": "Active"
|
||||
}
|
||||
|
||||
xml_str = XMLGenerator.generate_receive_xml(receive_vars)
|
||||
|
||||
root = ET.fromstring(xml_str)
|
||||
assert root.tag == "Rob"
|
||||
assert root.get("Type") == "KUKA"
|
||||
assert root.find("IPOC").text == "12345"
|
||||
assert root.find("BMode").text == "Active"
|
||||
|
||||
def test_nested_receive(self):
|
||||
"""Test receive XML generation with position data."""
|
||||
receive_vars = {
|
||||
"RIst": {"X": 500.0, "Y": 100.0, "Z": 800.0},
|
||||
"ASPos": {"A1": 0.0, "A2": -45.0, "A3": 90.0}
|
||||
}
|
||||
|
||||
xml_str = XMLGenerator.generate_receive_xml(receive_vars)
|
||||
|
||||
root = ET.fromstring(xml_str)
|
||||
rist = root.find("RIst")
|
||||
assert float(rist.get("X")) == 500.0
|
||||
assert float(rist.get("Z")) == 800.0
|
||||
|
||||
aspos = root.find("ASPos")
|
||||
assert float(aspos.get("A2")) == -45.0
|
||||
Loading…
Reference in New Issue
Block a user