feat: Improve dashboard performance and Claude terminal context
- Add trial limiting (300 max) and reduce polling to 15s for large studies
- Make dashboard layout wider with col-span adjustments
- Claude terminal now runs from Atomizer root for CLAUDE.md/skills access
- Add study context display in terminal on connect
- Add KaTeX math rendering styles for study reports
- Add surrogate tuner module for hyperparameter optimization
- Fix backend proxy to port 8001

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
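One backend change in this commit swaps a bare `sqlite3.connect` for one with a `timeout`, so reads of a study database that is locked by a running optimization fail fast instead of blocking the API. A minimal stdlib-only sketch of that pattern (the `study.db` path and `trials` table here are illustrative, not the real Optuna schema):

```python
import sqlite3
import tempfile
from pathlib import Path

# Illustrative study database path (not the real dashboard layout)
study_db = Path(tempfile.mkdtemp()) / "study.db"

# Create a tiny stand-in for a trials table
init = sqlite3.connect(str(study_db))
init.execute("CREATE TABLE trials (id INTEGER PRIMARY KEY, state TEXT)")
init.execute("INSERT INTO trials (state) VALUES ('COMPLETE')")
init.commit()
init.close()

# Read with a short timeout: if another process holds the write lock,
# sqlite3 raises OperationalError after ~2s instead of hanging the request
conn = sqlite3.connect(str(study_db), timeout=2.0)
try:
    cursor = conn.cursor()
    n_trials = cursor.execute("SELECT COUNT(*) FROM trials").fetchone()[0]
finally:
    conn.close()

print(n_trials)  # -> 1
```

The `timeout` argument maps to SQLite's busy timeout, which is why it helps when an optimization run is mid-write.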
@@ -15,7 +15,39 @@
       "Bash(C:\\Users\\Antoine\\miniconda3\\envs\\atomizer\\python.exe run_adaptive_mirror_optimization.py --fea-budget 100 --batch-size 5 --strategy hybrid)",
       "Bash(/c/Users/Antoine/miniconda3/envs/atomizer/python.exe:*)",
       "Bash(npm run build:*)",
-      "Bash(npm uninstall:*)"
+      "Bash(npm uninstall:*)",
+      "Bash(netstat:*)",
+      "Bash(findstr:*)",
+      "Bash(curl:*)",
+      "Bash(npx tsc:*)",
+      "Bash(atomizer-dashboard/README.md )",
+      "Bash(atomizer-dashboard/backend/api/main.py )",
+      "Bash(atomizer-dashboard/backend/api/routes/optimization.py )",
+      "Bash(atomizer-dashboard/backend/api/routes/claude.py )",
+      "Bash(atomizer-dashboard/backend/api/routes/terminal.py )",
+      "Bash(atomizer-dashboard/backend/api/services/ )",
+      "Bash(atomizer-dashboard/backend/requirements.txt )",
+      "Bash(atomizer-dashboard/frontend/package.json )",
+      "Bash(atomizer-dashboard/frontend/package-lock.json )",
+      "Bash(atomizer-dashboard/frontend/src/components/ClaudeChat.tsx )",
+      "Bash(atomizer-dashboard/frontend/src/components/ClaudeTerminal.tsx )",
+      "Bash(atomizer-dashboard/frontend/src/components/dashboard/ControlPanel.tsx )",
+      "Bash(atomizer-dashboard/frontend/src/pages/Dashboard.tsx )",
+      "Bash(atomizer-dashboard/frontend/src/context/ )",
+      "Bash(atomizer-dashboard/frontend/src/pages/Home.tsx )",
+      "Bash(atomizer-dashboard/frontend/src/App.tsx )",
+      "Bash(atomizer-dashboard/frontend/src/api/client.ts )",
+      "Bash(atomizer-dashboard/frontend/src/components/layout/Sidebar.tsx )",
+      "Bash(atomizer-dashboard/frontend/src/index.css )",
+      "Bash(atomizer-dashboard/frontend/src/pages/Results.tsx )",
+      "Bash(atomizer-dashboard/frontend/tailwind.config.js )",
+      "Bash(docs/07_DEVELOPMENT/DASHBOARD_IMPROVEMENT_PLAN.md)",
+      "Bash(taskkill:*)",
+      "Bash(xargs:*)",
+      "Bash(cmd.exe /c:*)",
+      "Bash(powershell.exe -Command:*)",
+      "Bash(where:*)",
+      "Bash(type %USERPROFILE%\\.claude*)"
     ],
     "deny": [],
     "ask": []
@@ -1,7 +1,7 @@
 # Create Optimization Study Skill
 
-**Last Updated**: November 26, 2025
-**Version**: 2.0 - Protocol Reference + Code Patterns (Centralized)
+**Last Updated**: December 4, 2025
+**Version**: 2.1 - Added Mandatory Documentation Requirements
 
 You are helping the user create a complete Atomizer optimization study from a natural language description.
 
@@ -9,6 +9,39 @@ You are helping the user create a complete Atomizer optimization study from a na
 
 ---
 
+## MANDATORY DOCUMENTATION CHECKLIST
+
+**EVERY study MUST have these files. A study is NOT complete without them:**
+
+| File | Purpose | When Created |
+|------|---------|--------------|
+| `README.md` | **Engineering Blueprint** - Full mathematical formulation, design variables, objectives, algorithm config | At study creation |
+| `STUDY_REPORT.md` | **Results Tracking** - Progress, best designs, surrogate accuracy, recommendations | At study creation (template) |
+
+**README.md Requirements (11 sections)**:
+1. Engineering Problem (objective, physical system)
+2. Mathematical Formulation (objectives, design variables, constraints with LaTeX)
+3. Optimization Algorithm (config, properties, return format)
+4. Simulation Pipeline (trial execution flow diagram)
+5. Result Extraction Methods (extractor details, code snippets)
+6. Neural Acceleration (surrogate config, expected performance)
+7. Study File Structure (directory tree)
+8. Results Location (output files)
+9. Quick Start (commands)
+10. Configuration Reference (config.json mapping)
+11. References
+
+**STUDY_REPORT.md Requirements**:
+- Executive Summary (trial counts, best values)
+- Optimization Progress (iteration history, convergence)
+- Best Designs Found (FEA-validated)
+- Neural Surrogate Performance (R², MAE)
+- Engineering Recommendations
+
+**FAILURE MODE**: If you create a study without README.md and STUDY_REPORT.md, the user cannot understand what the study does, the dashboard cannot display documentation, and the study is incomplete.
+
+---
+
 ## Protocol Reference (MUST USE)
 
 This section defines ALL available components. When generating `run_optimization.py`, use ONLY these documented patterns.
@@ -73,7 +73,8 @@ async def list_studies():
         # Protocol 10: Read from Optuna SQLite database
         if study_db.exists():
             try:
-                conn = sqlite3.connect(str(study_db))
+                # Use timeout to avoid blocking on locked databases
+                conn = sqlite3.connect(str(study_db), timeout=2.0)
                 cursor = conn.cursor()
 
                 # Get trial count and status
@@ -130,6 +131,29 @@ async def list_studies():
                 config.get('trials', {}).get('n_trials', 50)
             )
 
+            # Get creation date from directory or config modification time
+            created_at = None
+            try:
+                # First try to get from database (most accurate)
+                if study_db.exists():
+                    created_at = datetime.fromtimestamp(study_db.stat().st_mtime).isoformat()
+                elif config_file.exists():
+                    created_at = datetime.fromtimestamp(config_file.stat().st_mtime).isoformat()
+                else:
+                    created_at = datetime.fromtimestamp(study_dir.stat().st_ctime).isoformat()
+            except:
+                created_at = None
+
+            # Get last modified time
+            last_modified = None
+            try:
+                if study_db.exists():
+                    last_modified = datetime.fromtimestamp(study_db.stat().st_mtime).isoformat()
+                elif history_file.exists():
+                    last_modified = datetime.fromtimestamp(history_file.stat().st_mtime).isoformat()
+            except:
+                last_modified = None
+
             studies.append({
                 "id": study_dir.name,
                 "name": study_dir.name.replace("_", " ").title(),
@@ -140,7 +164,9 @@ async def list_studies():
                 },
                 "best_value": best_value,
                 "target": config.get('target', {}).get('value'),
-                "path": str(study_dir)
+                "path": str(study_dir),
+                "created_at": created_at,
+                "last_modified": last_modified
             })
 
     return {"studies": studies}
@@ -2,6 +2,7 @@
 Terminal WebSocket for Claude Code CLI
 
 Provides a PTY-based terminal that runs Claude Code in the dashboard.
+Uses pywinpty on Windows for proper interactive terminal support.
 """
 
 from fastapi import APIRouter, WebSocket, WebSocketDisconnect
@@ -18,6 +19,13 @@ router = APIRouter()
 # Store active terminal sessions
 _terminal_sessions: dict = {}
 
+# Check if winpty is available (for Windows)
+try:
+    from winpty import PtyProcess
+    HAS_WINPTY = True
+except ImportError:
+    HAS_WINPTY = False
+
 
 class TerminalSession:
     """Manages a Claude Code terminal session."""
@@ -25,10 +33,11 @@ class TerminalSession:
     def __init__(self, session_id: str, working_dir: str):
         self.session_id = session_id
         self.working_dir = working_dir
-        self.process: Optional[subprocess.Popen] = None
+        self.process = None
         self.websocket: Optional[WebSocket] = None
         self._read_task: Optional[asyncio.Task] = None
         self._running = False
+        self._use_winpty = sys.platform == "win32" and HAS_WINPTY
 
     async def start(self, websocket: WebSocket):
         """Start the Claude Code process."""
@@ -36,18 +45,34 @@ class TerminalSession:
         self._running = True
 
         # Determine the claude command
-        # On Windows, claude is typically installed via npm and available in PATH
         claude_cmd = "claude"
 
-        # Check if we're on Windows
-        is_windows = sys.platform == "win32"
-
         try:
-            if is_windows:
-                # On Windows, use subprocess with pipes
-                # We need to use cmd.exe to get proper terminal behavior
+            if self._use_winpty:
+                # Use winpty for proper PTY on Windows
+                # Spawn claude directly - winpty handles the interactive terminal
+                import shutil
+                claude_path = shutil.which("claude") or "claude"
+
+                # Ensure HOME and USERPROFILE are properly set for Claude CLI auth
+                env = {**os.environ}
+                env["FORCE_COLOR"] = "1"
+                env["TERM"] = "xterm-256color"
+                # Claude CLI looks for credentials in HOME/.claude or USERPROFILE/.claude
+                # Ensure these are set correctly
+                if "USERPROFILE" in env and "HOME" not in env:
+                    env["HOME"] = env["USERPROFILE"]
+
+                self.process = PtyProcess.spawn(
+                    claude_path,
+                    cwd=self.working_dir,
+                    env=env
+                )
+            elif sys.platform == "win32":
+                # Fallback: Windows without winpty - use subprocess
+                # Run claude with --dangerously-skip-permissions for non-interactive mode
                 self.process = subprocess.Popen(
-                    ["cmd.exe", "/c", claude_cmd],
+                    ["cmd.exe", "/k", claude_cmd],
                     stdin=subprocess.PIPE,
                     stdout=subprocess.PIPE,
                     stderr=subprocess.STDOUT,
@@ -57,7 +82,7 @@ class TerminalSession:
                     env={**os.environ, "FORCE_COLOR": "1", "TERM": "xterm-256color"}
                 )
             else:
-                # On Unix, we can use pty
+                # On Unix, use pty
                 import pty
                 master_fd, slave_fd = pty.openpty()
                 self.process = subprocess.Popen(
@@ -94,34 +119,71 @@ class TerminalSession:
 
     async def _read_output(self):
         """Read output from the process and send to WebSocket."""
-        is_windows = sys.platform == "win32"
 
         try:
-            while self._running and self.process and self.process.poll() is None:
-                if is_windows:
-                    # Read from stdout pipe
-                    if self.process.stdout:
-                        # Use asyncio to read without blocking
+            while self._running:
+                if self._use_winpty:
+                    # Read from winpty
+                    if self.process and self.process.isalive():
                         loop = asyncio.get_event_loop()
                         try:
                             data = await loop.run_in_executor(
                                 None,
-                                lambda: self.process.stdout.read(1024)
+                                lambda: self.process.read(4096)
                             )
                             if data:
                                 await self.websocket.send_json({
                                     "type": "output",
-                                    "data": data.decode("utf-8", errors="replace")
+                                    "data": data
                                 })
                         except Exception:
                             break
+                    else:
+                        break
+                elif sys.platform == "win32":
+                    # Windows subprocess pipe mode
+                    if self.process and self.process.poll() is None:
+                        if self.process.stdout:
+                            loop = asyncio.get_event_loop()
+                            try:
+                                # Use non-blocking read with a timeout
+                                import msvcrt
+                                import ctypes
+
+                                # Read available data
+                                data = await loop.run_in_executor(
+                                    None,
+                                    lambda: self.process.stdout.read(1)
+                                )
+                                if data:
+                                    # Read more if available
+                                    more_data = b""
+                                    try:
+                                        # Try to read more without blocking
+                                        while True:
+                                            extra = self.process.stdout.read(1)
+                                            if extra:
+                                                more_data += extra
+                                            else:
+                                                break
+                                    except:
+                                        pass
+
+                                    full_data = data + more_data
+                                    await self.websocket.send_json({
+                                        "type": "output",
+                                        "data": full_data.decode("utf-8", errors="replace")
+                                    })
+                            except Exception:
+                                break
+                    else:
+                        break
                 else:
-                    # Read from PTY master
+                    # Unix PTY
                     loop = asyncio.get_event_loop()
                     try:
                         data = await loop.run_in_executor(
                             None,
-                            lambda: os.read(self._master_fd, 1024)
+                            lambda: os.read(self._master_fd, 4096)
                         )
                         if data:
                             await self.websocket.send_json({
@@ -135,7 +197,12 @@ class TerminalSession:
 
         # Process ended
         if self.websocket:
-            exit_code = self.process.poll() if self.process else -1
+            exit_code = -1
+            if self._use_winpty:
+                exit_code = self.process.exitstatus if self.process else -1
+            elif self.process:
+                exit_code = self.process.poll() if self.process.poll() is not None else -1
+
             await self.websocket.send_json({
                 "type": "exit",
                 "code": exit_code
@@ -156,10 +223,10 @@ class TerminalSession:
         if not self.process or not self._running:
             return
 
-        is_windows = sys.platform == "win32"
-
         try:
-            if is_windows:
+            if self._use_winpty:
+                self.process.write(data)
+            elif sys.platform == "win32":
                 if self.process.stdin:
                     self.process.stdin.write(data.encode())
                     self.process.stdin.flush()
@@ -173,8 +240,13 @@ class TerminalSession:
             })
 
     async def resize(self, cols: int, rows: int):
-        """Resize the terminal (Unix only)."""
-        if sys.platform != "win32" and hasattr(self, '_master_fd'):
+        """Resize the terminal."""
+        if self._use_winpty and self.process:
+            try:
+                self.process.setwinsize(rows, cols)
+            except:
+                pass
+        elif sys.platform != "win32" and hasattr(self, '_master_fd'):
             import struct
             import fcntl
             import termios
@@ -194,14 +266,17 @@ class TerminalSession:
 
         if self.process:
             try:
-                if sys.platform == "win32":
+                if self._use_winpty:
+                    self.process.terminate()
+                elif sys.platform == "win32":
                     self.process.terminate()
                 else:
                     os.kill(self.process.pid, signal.SIGTERM)
                 self.process.wait(timeout=2)
             except:
                 try:
-                    self.process.kill()
+                    if hasattr(self.process, 'kill'):
+                        self.process.kill()
                 except:
                     pass
 
@@ -213,16 +288,18 @@ class TerminalSession:
 
 
 @router.websocket("/claude")
-async def claude_terminal(websocket: WebSocket, working_dir: str = None):
+async def claude_terminal(websocket: WebSocket, working_dir: str = None, study_id: str = None):
     """
     WebSocket endpoint for Claude Code terminal.
 
     Query params:
         working_dir: Directory to start Claude Code in (defaults to Atomizer root)
+        study_id: Optional study ID to set context for Claude
 
     Client -> Server messages:
         {"type": "input", "data": "user input text"}
         {"type": "resize", "cols": 80, "rows": 24}
+        {"type": "stop"}
 
     Server -> Client messages:
         {"type": "started", "message": "..."}
@@ -247,6 +324,11 @@ async def claude_terminal(websocket: WebSocket, working_dir: str = None):
         # Start Claude Code
         await session.start(websocket)
 
+        # Note: Claude is started in Atomizer root directory so it has access to:
+        # - CLAUDE.md (system instructions)
+        # - .claude/skills/ (skill definitions)
+        # The study_id is available for the user to reference in their prompts
+
         # Handle incoming messages
         while session._running:
             try:
@@ -285,5 +367,6 @@ async def terminal_status():
     return {
         "available": claude_path is not None,
         "path": claude_path,
+        "winpty_available": HAS_WINPTY,
         "message": "Claude Code CLI is available" if claude_path else "Claude Code CLI not found. Install with: npm install -g @anthropic-ai/claude-code"
     }
@@ -149,21 +149,25 @@ export const ClaudeTerminal: React.FC<ClaudeTerminalProps> = ({
     setIsConnecting(true);
     setError(null);
 
-    // Determine working directory - use study path if available
-    let workingDir = '';
-    if (selectedStudy?.id) {
-      // The study directory path
-      workingDir = `?working_dir=C:/Users/Antoine/Atomizer`;
-    }
+    // Always use Atomizer root as working directory so Claude has access to:
+    // - CLAUDE.md (system instructions)
+    // - .claude/skills/ (skill definitions)
+    // Pass study_id as parameter so we can inform Claude about the context
+    const workingDir = 'C:/Users/Antoine/Atomizer';
+    const studyParam = selectedStudy?.id ? `&study_id=${selectedStudy.id}` : '';
 
     const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
-    const ws = new WebSocket(`${protocol}//${window.location.host}/api/terminal/claude${workingDir}`);
+    const ws = new WebSocket(`${protocol}//${window.location.host}/api/terminal/claude?working_dir=${workingDir}${studyParam}`);
 
     ws.onopen = () => {
       setIsConnected(true);
       setIsConnecting(false);
       xtermRef.current?.clear();
       xtermRef.current?.writeln('\x1b[1;32mConnected to Claude Code\x1b[0m');
+      if (selectedStudy?.id) {
+        xtermRef.current?.writeln(`\x1b[90mStudy context: \x1b[1;33m${selectedStudy.id}\x1b[0m`);
+        xtermRef.current?.writeln('\x1b[90mTip: Tell Claude about your study, e.g. "Help me with study ' + selectedStudy.id + '"\x1b[0m');
+      }
       xtermRef.current?.writeln('');
 
       // Send initial resize
@@ -65,7 +65,7 @@ export function StudyReportViewer({ studyId, studyPath }: StudyReportViewerProps
 
   return (
     <div className="fixed inset-0 z-50 flex items-center justify-center bg-black/70">
-      <div className="bg-dark-800 rounded-xl shadow-2xl w-[90vw] max-w-5xl h-[85vh] flex flex-col border border-dark-600">
+      <div className="bg-dark-800 rounded-xl shadow-2xl w-[95vw] max-w-7xl h-[90vh] flex flex-col border border-dark-600">
         {/* Header */}
         <div className="flex items-center justify-between px-6 py-4 border-b border-dark-600">
           <div className="flex items-center gap-3">
@@ -127,8 +127,8 @@ export function StudyReportViewer({ studyId, studyPath }: StudyReportViewerProps
         {markdown && !loading && (
           <article className="markdown-body">
             <ReactMarkdown
-              remarkPlugins={[remarkGfm, remarkMath]}
-              rehypePlugins={[rehypeKatex]}
+              remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]}
+              rehypePlugins={[[rehypeKatex, { strict: false, trust: true, output: 'html' }]]}
               components={{
                 // Custom heading styles
                 h1: ({children}) => (
@@ -69,3 +69,39 @@
 ::-webkit-scrollbar-thumb:hover {
   @apply bg-dark-400;
 }
+
+/* KaTeX Math Rendering */
+.katex {
+  font-size: 1.1em !important;
+  color: inherit !important;
+}
+
+.katex-display {
+  margin: 1em 0 !important;
+  overflow-x: auto;
+  overflow-y: hidden;
+}
+
+.katex-display > .katex {
+  color: #e2e8f0 !important;
+}
+
+/* Markdown body styles */
+.markdown-body {
+  color: #e2e8f0;
+  line-height: 1.7;
+}
+
+.markdown-body .katex-display {
+  background: rgba(30, 41, 59, 0.5);
+  padding: 1rem;
+  border-radius: 0.5rem;
+  border: 1px solid #334155;
+}
+
+/* Code blocks in markdown should have proper width */
+.markdown-body pre {
+  max-width: 100%;
+  overflow-x: auto;
+  white-space: pre;
+}
@@ -54,6 +54,8 @@ export default function Dashboard() {
   const [alertIdCounter, setAlertIdCounter] = useState(0);
   const [expandedTrials, setExpandedTrials] = useState<Set<number>>(new Set());
   const [sortBy, setSortBy] = useState<'performance' | 'chronological'>('performance');
+  const [trialsPage, setTrialsPage] = useState(0);
+  const trialsPerPage = 50; // Limit trials per page for performance
 
   // Parameter Space axis selection
   const [paramXIndex, setParamXIndex] = useState(0);
@@ -99,6 +101,9 @@ export default function Dashboard() {
   });
 
   // Load initial trial history when study changes
+  // PERFORMANCE: Use limit to avoid loading thousands of trials at once
+  const MAX_TRIALS_LOAD = 300;
+
   useEffect(() => {
     if (selectedStudyId) {
       setAllTrials([]);
@@ -106,74 +111,63 @@ export default function Dashboard() {
|
|||||||
setPrunedCount(0);
|
setPrunedCount(0);
|
||||||
setExpandedTrials(new Set());
|
setExpandedTrials(new Set());
|
||||||
|
|
||||||
apiClient.getStudyHistory(selectedStudyId)
|
// Single history fetch with limit - used for both trial list and charts
|
||||||
|
// This replaces the duplicate fetch calls
|
||||||
|
fetch(`/api/optimization/studies/${selectedStudyId}/history?limit=${MAX_TRIALS_LOAD}`)
|
||||||
|
.then(res => res.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
const validTrials = data.trials.filter(t => t.objective !== null && t.objective !== undefined);
|
// Set trials for the trial list
|
||||||
|
const validTrials = data.trials.filter((t: any) => t.objective !== null && t.objective !== undefined);
|
||||||
setAllTrials(validTrials);
|
setAllTrials(validTrials);
|
||||||
if (validTrials.length > 0) {
|
if (validTrials.length > 0) {
|
||||||
const minObj = Math.min(...validTrials.map(t => t.objective));
|
const minObj = Math.min(...validTrials.map((t: any) => t.objective));
|
||||||
setBestValue(minObj);
|
setBestValue(minObj);
|
||||||
}
|
}
|
||||||
})
|
|
||||||
.catch(console.error);
|
|
||||||
|
|
||||||
apiClient.getStudyPruning(selectedStudyId)
|
// Transform for charts (parallel coordinates, etc.)
|
||||||
.then(data => {
|
|
||||||
// Use count if available (new API), fallback to array length (legacy)
|
|
||||||
setPrunedCount(data.count ?? data.pruned_trials?.length ?? 0);
|
|
||||||
})
|
|
||||||
.catch(console.error);
|
|
||||||
|
|
||||||
// Protocol 13: Fetch metadata
|
|
||||||
fetch(`/api/optimization/studies/${selectedStudyId}/metadata`)
|
|
||||||
.then(res => res.json())
|
|
||||||
.then(data => {
|
|
||||||
setStudyMetadata(data);
|
|
||||||
})
|
|
||||||
.catch(err => console.error('Failed to load metadata:', err));
|
|
||||||
|
|
||||||
// Protocol 13: Fetch Pareto front (raw format for Protocol 13 components)
|
|
||||||
fetch(`/api/optimization/studies/${selectedStudyId}/pareto-front`)
|
|
||||||
.then(res => res.json())
|
|
||||||
.then(paretoData => {
|
|
||||||
console.log('[Dashboard] Pareto front data:', paretoData);
|
|
||||||
if (paretoData.is_multi_objective && paretoData.pareto_front) {
|
|
||||||
console.log('[Dashboard] Setting Pareto front with', paretoData.pareto_front.length, 'trials');
|
|
||||||
setParetoFront(paretoData.pareto_front);
|
|
||||||
} else {
|
|
||||||
console.log('[Dashboard] No Pareto front or not multi-objective');
|
|
||||||
setParetoFront([]);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.catch(err => console.error('Failed to load Pareto front:', err));
|
|
||||||
|
|
||||||
// Fetch ALL trials (not just Pareto) for parallel coordinates and charts
|
|
||||||
fetch(`/api/optimization/studies/${selectedStudyId}/history`)
|
|
||||||
.then(res => res.json())
|
|
||||||
.then(data => {
|
|
||||||
// Transform to match the format expected by charts
|
|
||||||
-// API returns 'objectives' (array) for multi-objective, 'objective' (number) for single
 const trialsData = data.trials.map((t: any) => {
-// Build values array: use objectives if available, otherwise wrap single objective
 let values: number[] = [];
 if (t.objectives && Array.isArray(t.objectives)) {
 values = t.objectives;
 } else if (t.objective !== null && t.objective !== undefined) {
 values = [t.objective];
 }

 return {
 trial_number: t.trial_number,
 values,
 params: t.design_variables || {},
 user_attrs: t.user_attrs || {},
 constraint_satisfied: t.constraint_satisfied !== false,
-source: t.source || t.user_attrs?.source || 'FEA' // FEA vs NN differentiation
+source: t.source || t.user_attrs?.source || 'FEA'
 };
 });
 setAllTrialsRaw(trialsData);
 })
-.catch(err => console.error('Failed to load all trials:', err));
+.catch(console.error);

+apiClient.getStudyPruning(selectedStudyId)
+.then(data => {
+setPrunedCount(data.count ?? data.pruned_trials?.length ?? 0);
+})
+.catch(console.error);
+
+// Fetch metadata (small payload)
+fetch(`/api/optimization/studies/${selectedStudyId}/metadata`)
+.then(res => res.json())
+.then(data => setStudyMetadata(data))
+.catch(err => console.error('Failed to load metadata:', err));
+
+// Fetch Pareto front (usually small)
+fetch(`/api/optimization/studies/${selectedStudyId}/pareto-front`)
+.then(res => res.json())
+.then(paretoData => {
+if (paretoData.is_multi_objective && paretoData.pareto_front) {
+setParetoFront(paretoData.pareto_front);
+} else {
+setParetoFront([]);
+}
+})
+.catch(err => console.error('Failed to load Pareto front:', err));
 }
 }, [selectedStudyId]);

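The mapping above normalizes the two payload shapes the API can return: `objectives` (array) for multi-objective studies, `objective` (scalar) for single-objective ones. A minimal Python sketch of that normalization, with a hypothetical helper name and plain dicts standing in for the JSON payloads:

```python
def trial_values(trial):
    # 'objectives' (list) wins for multi-objective studies;
    # a scalar 'objective' is wrapped in a one-element list;
    # anything else yields an empty values array.
    if isinstance(trial.get('objectives'), list):
        return trial['objectives']
    if trial.get('objective') is not None:
        return [trial['objective']]
    return []
```

This mirrors the TSX mapping logic only in shape; it is not code from the commit.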
@@ -194,41 +188,77 @@ export default function Dashboard() {
 setDisplayedTrials(sorted);
 }, [allTrials, sortBy]);

-// Auto-refresh polling (every 3 seconds) for trial history
+// Auto-refresh polling for trial history
+// PERFORMANCE: Use limit and longer interval for large studies
 useEffect(() => {
 if (!selectedStudyId) return;

 const refreshInterval = setInterval(() => {
-apiClient.getStudyHistory(selectedStudyId)
+// Only fetch latest trials, not the entire history
+fetch(`/api/optimization/studies/${selectedStudyId}/history?limit=${MAX_TRIALS_LOAD}`)
+.then(res => res.json())
 .then(data => {
-const validTrials = data.trials.filter(t => t.objective !== null && t.objective !== undefined);
+const validTrials = data.trials.filter((t: any) => t.objective !== null && t.objective !== undefined);
 setAllTrials(validTrials);
 if (validTrials.length > 0) {
-const minObj = Math.min(...validTrials.map(t => t.objective));
+const minObj = Math.min(...validTrials.map((t: any) => t.objective));
 setBestValue(minObj);
 }
+// Also update chart data
+const trialsData = data.trials.map((t: any) => {
+let values: number[] = [];
+if (t.objectives && Array.isArray(t.objectives)) {
+values = t.objectives;
+} else if (t.objective !== null && t.objective !== undefined) {
+values = [t.objective];
+}
+return {
+trial_number: t.trial_number,
+values,
+params: t.design_variables || {},
+user_attrs: t.user_attrs || {},
+constraint_satisfied: t.constraint_satisfied !== false,
+source: t.source || t.user_attrs?.source || 'FEA'
+};
+});
+setAllTrialsRaw(trialsData);
 })
 .catch(err => console.error('Auto-refresh failed:', err));
-}, 3000); // Poll every 3 seconds
+}, 15000); // Poll every 15 seconds for performance

 return () => clearInterval(refreshInterval);
 }, [selectedStudyId]);

-// Prepare chart data with proper null/undefined handling
-const convergenceData: ConvergenceDataPoint[] = allTrials
-.filter(t => t.objective !== null && t.objective !== undefined)
-.sort((a, b) => a.trial_number - b.trial_number)
-.map((trial, idx, arr) => {
-const previousTrials = arr.slice(0, idx + 1);
-const validObjectives = previousTrials.map(t => t.objective).filter(o => o !== null && o !== undefined);
-return {
-trial_number: trial.trial_number,
-objective: trial.objective,
-best_so_far: validObjectives.length > 0 ? Math.min(...validObjectives) : trial.objective,
-};
-});
+// Sample data for charts when there are too many trials (performance optimization)
+const MAX_CHART_POINTS = 200; // Reduced for better performance
+const sampleData = <T,>(data: T[], maxPoints: number): T[] => {
+if (data.length <= maxPoints) return data;
+const step = Math.ceil(data.length / maxPoints);
+return data.filter((_, i) => i % step === 0 || i === data.length - 1);
+};

-const parameterSpaceData: ParameterSpaceDataPoint[] = allTrials
+// Prepare chart data with proper null/undefined handling
+const allValidTrials = allTrials
+.filter(t => t.objective !== null && t.objective !== undefined)
+.sort((a, b) => a.trial_number - b.trial_number);
+
+// Calculate best_so_far for each trial
+let runningBest = Infinity;
+const convergenceDataFull: ConvergenceDataPoint[] = allValidTrials.map(trial => {
+if (trial.objective < runningBest) {
+runningBest = trial.objective;
+}
+return {
+trial_number: trial.trial_number,
+objective: trial.objective,
+best_so_far: runningBest,
+};
+});
+
+// Sample for chart rendering performance
+const convergenceData = sampleData(convergenceDataFull, MAX_CHART_POINTS);
+
+const parameterSpaceDataFull: ParameterSpaceDataPoint[] = allTrials
 .filter(t => t.objective !== null && t.objective !== undefined && t.design_variables)
 .map(trial => {
 const params = Object.values(trial.design_variables);
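The new `sampleData` helper decimates a long chart series: it keeps every `step`-th point plus the final point, so the last trial always stays visible. A language-neutral Python sketch of the same arithmetic (the helper name is illustrative; the dashboard does this in TypeScript):

```python
import math

def sample_data(data, max_points):
    # Below the threshold, return the series untouched.
    if len(data) <= max_points:
        return data
    # Keep every step-th point, plus the last point unconditionally.
    step = math.ceil(len(data) / max_points)
    return [x for i, x in enumerate(data)
            if i % step == 0 or i == len(data) - 1]
```

Note the off-by-one effect: with 1000 points and `max_points=200`, `step` is 5 and indices 0, 5, ..., 995 are kept, so the extra final point makes 201 outputs, not exactly 200.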
@@ -241,6 +271,9 @@ export default function Dashboard() {
 };
 });

+// Sample for chart rendering performance
+const parameterSpaceData = sampleData(parameterSpaceDataFull, MAX_CHART_POINTS);
+
 // Calculate average objective
 const validObjectives = allTrials.filter(t => t.objective !== null && t.objective !== undefined).map(t => t.objective);
 const avgObjective = validObjectives.length > 0
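The convergence chart's `best_so_far` series is a running minimum over the trial objectives, computed in a single pass instead of re-scanning the prefix for every trial as the old code did. A Python sketch of that accumulation (hypothetical function name; the commit's version is TypeScript):

```python
def convergence_series(objectives):
    # Running minimum ("best so far") per trial, single pass.
    best = float('inf')
    out = []
    for obj in objectives:
        best = min(best, obj)
        out.append(best)
    return out
```

This turns the old O(n²) prefix-scan into O(n), which matters for the large studies the commit targets.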
@@ -384,14 +417,14 @@ export default function Dashboard() {
 </div>
 </header>

-<div className="grid grid-cols-12 gap-6">
-{/* Control Panel - Left Sidebar */}
-<aside className="col-span-3">
+<div className="grid grid-cols-12 gap-4">
+{/* Control Panel - Left Sidebar (smaller) */}
+<aside className="col-span-2">
 <ControlPanel onStatusChange={refreshStudies} />
 </aside>

-{/* Main Content - shrinks when chat is open */}
-<main className={chatOpen ? 'col-span-5' : 'col-span-9'}>
+{/* Main Content - takes most of the space */}
+<main className={chatOpen ? 'col-span-6' : 'col-span-10'}>
 {/* Study Name Header */}
 {selectedStudyId && (
 <div className="mb-4 pb-3 border-b border-dark-600">
@@ -694,12 +727,12 @@ export default function Dashboard() {
 </ExpandableChart>
 </div>

-{/* Trial History with Sort Controls */}
+{/* Trial History with Sort Controls and Pagination */}
 <Card
 title={
 <div className="flex items-center justify-between w-full">
 <span>Trial History ({displayedTrials.length} trials)</span>
-<div className="flex gap-2">
+<div className="flex gap-2 items-center">
 <button
 onClick={() => setSortBy('performance')}
 className={`px-3 py-1 rounded text-sm ${
@@ -720,13 +753,35 @@ export default function Dashboard() {
 >
 Newest First
 </button>
+{/* Pagination controls */}
+{displayedTrials.length > trialsPerPage && (
+<div className="flex items-center gap-1 ml-2 border-l border-dark-500 pl-2">
+<button
+onClick={() => setTrialsPage(Math.max(0, trialsPage - 1))}
+disabled={trialsPage === 0}
+className="px-2 py-1 text-sm bg-dark-500 rounded disabled:opacity-50 hover:bg-dark-400"
+>
+‹
+</button>
+<span className="text-xs text-dark-300 px-2">
+{trialsPage + 1}/{Math.ceil(displayedTrials.length / trialsPerPage)}
+</span>
+<button
+onClick={() => setTrialsPage(Math.min(Math.ceil(displayedTrials.length / trialsPerPage) - 1, trialsPage + 1))}
+disabled={trialsPage >= Math.ceil(displayedTrials.length / trialsPerPage) - 1}
+className="px-2 py-1 text-sm bg-dark-500 rounded disabled:opacity-50 hover:bg-dark-400"
+>
+›
+</button>
+</div>
+)}
 </div>
 </div>
 }
 >
 <div className="space-y-2 max-h-[600px] overflow-y-auto">
 {displayedTrials.length > 0 ? (
-displayedTrials.map(trial => {
+displayedTrials.slice(trialsPage * trialsPerPage, (trialsPage + 1) * trialsPerPage).map(trial => {
 const isExpanded = expandedTrials.has(trial.trial_number);
 const isBest = trial.objective === bestValue;

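The pagination controls divide `displayedTrials` into fixed-size pages, with `ceil(n / trialsPerPage)` pages and a half-open slice per page. A Python sketch of the slice and page-count arithmetic (names are illustrative, mirroring the TSX):

```python
import math

def page_slice(items, page, per_page):
    # Mirrors displayedTrials.slice(page*per_page, (page+1)*per_page);
    # slicing past the end simply yields a short final page.
    return items[page * per_page : (page + 1) * per_page]

def n_pages(items, per_page):
    # Total page count, rounding the last partial page up.
    return math.ceil(len(items) / per_page)
```

The disabled-button guards in the JSX correspond to clamping `page` to the range `[0, n_pages - 1]`.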
@@ -879,9 +934,9 @@ export default function Dashboard() {
 </div>
 </main>

-{/* Claude Code Terminal - Right Sidebar */}
+{/* Claude Code Terminal - Right Sidebar (taller for better visibility) */}
 {chatOpen && (
-<aside className="col-span-4 h-[calc(100vh-12rem)] sticky top-24">
+<aside className="col-span-4 h-[calc(100vh-8rem)] sticky top-20">
 <ClaudeTerminal
 isExpanded={chatExpanded}
 onToggleExpand={() => setChatExpanded(!chatExpanded)}
@@ -21,6 +21,7 @@ import ReactMarkdown from 'react-markdown';
 import remarkGfm from 'remark-gfm';
 import remarkMath from 'remark-math';
 import rehypeKatex from 'rehype-katex';
+import 'katex/dist/katex.min.css';
 import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
 import { oneDark } from 'react-syntax-highlighter/dist/esm/styles/prism';
 import { apiClient } from '../api/client';
@@ -101,11 +102,24 @@ const Home: React.FC = () => {
 }
 };

-// Sort studies: running first, then by trial count
+// Study sort options
+const [studySort, setStudySort] = useState<'date' | 'running' | 'trials'>('date');
+
+// Sort studies based on selected sort option
 const sortedStudies = [...studies].sort((a, b) => {
-if (a.status === 'running' && b.status !== 'running') return -1;
-if (b.status === 'running' && a.status !== 'running') return 1;
-return b.progress.current - a.progress.current;
+if (studySort === 'running') {
+// Running first, then by date
+if (a.status === 'running' && b.status !== 'running') return -1;
+if (b.status === 'running' && a.status !== 'running') return 1;
+}
+if (studySort === 'trials') {
+// By trial count (most trials first)
+return b.progress.current - a.progress.current;
+}
+// Default: sort by date (newest first)
+const aDate = a.last_modified || a.created_at || '';
+const bDate = b.last_modified || b.created_at || '';
+return bDate.localeCompare(aDate);
 });

 const displayedStudies = showAllStudies ? sortedStudies : sortedStudies.slice(0, 6);
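The three sort modes amount to: `date` (newest first), `trials` (largest `progress.current` first), and `running` (running studies first, newest first within each group, since the comparator falls through to the date comparison). A Python sketch under those assumptions (field names mirror the `Study` interface; a stable second sort reproduces the fall-through):

```python
def sort_studies(studies, mode='date'):
    # Newest first, using last_modified with created_at as fallback.
    by_date = sorted(
        studies,
        key=lambda s: s.get('last_modified') or s.get('created_at') or '',
        reverse=True,
    )
    if mode == 'trials':
        return sorted(studies, key=lambda s: s['progress']['current'], reverse=True)
    if mode == 'running':
        # Stable sort: running first, date order preserved within each group.
        return sorted(by_date, key=lambda s: s['status'] != 'running')
    return by_date
```

This is a behavioral sketch of the TSX comparator, not code from the commit.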
@@ -114,7 +128,7 @@ const Home: React.FC = () => {
 <div className="min-h-screen bg-dark-900">
 {/* Header */}
 <header className="bg-dark-800/50 border-b border-dark-700 backdrop-blur-sm sticky top-0 z-10">
-<div className="max-w-[1600px] mx-auto px-6 py-4">
+<div className="max-w-[1920px] mx-auto px-6 py-4">
 <div className="flex items-center justify-between">
 <div className="flex items-center gap-4">
 <div className="w-11 h-11 bg-gradient-to-br from-primary-500 to-primary-700 rounded-xl flex items-center justify-center shadow-lg shadow-primary-500/20">
@@ -138,7 +152,7 @@ const Home: React.FC = () => {
 </div>
 </header>

-<main className="max-w-[1600px] mx-auto px-6 py-8">
+<main className="max-w-[1920px] mx-auto px-6 py-8">
 {/* Study Selection Section */}
 <section className="mb-8">
 <div className="flex items-center justify-between mb-4">
@@ -146,18 +160,56 @@ const Home: React.FC = () => {
 <FolderOpen className="w-5 h-5 text-primary-400" />
 Select a Study
 </h2>
-{studies.length > 6 && (
-<button
-onClick={() => setShowAllStudies(!showAllStudies)}
-className="text-sm text-primary-400 hover:text-primary-300 flex items-center gap-1"
->
-{showAllStudies ? (
-<>Show Less <ChevronUp className="w-4 h-4" /></>
-) : (
-<>Show All ({studies.length}) <ChevronDown className="w-4 h-4" /></>
-)}
-</button>
-)}
+<div className="flex items-center gap-4">
+{/* Sort Controls */}
+<div className="flex items-center gap-2">
+<span className="text-sm text-dark-400">Sort:</span>
+<div className="flex rounded-lg overflow-hidden border border-dark-600">
+<button
+onClick={() => setStudySort('date')}
+className={`px-3 py-1.5 text-sm transition-colors ${
+studySort === 'date'
+? 'bg-primary-500 text-white'
+: 'bg-dark-700 text-dark-300 hover:bg-dark-600'
+}`}
+>
+Newest
+</button>
+<button
+onClick={() => setStudySort('running')}
+className={`px-3 py-1.5 text-sm transition-colors ${
+studySort === 'running'
+? 'bg-primary-500 text-white'
+: 'bg-dark-700 text-dark-300 hover:bg-dark-600'
+}`}
+>
+Running
+</button>
+<button
+onClick={() => setStudySort('trials')}
+className={`px-3 py-1.5 text-sm transition-colors ${
+studySort === 'trials'
+? 'bg-primary-500 text-white'
+: 'bg-dark-700 text-dark-300 hover:bg-dark-600'
+}`}
+>
+Most Trials
+</button>
+</div>
+</div>
+{studies.length > 6 && (
+<button
+onClick={() => setShowAllStudies(!showAllStudies)}
+className="text-sm text-primary-400 hover:text-primary-300 flex items-center gap-1"
+>
+{showAllStudies ? (
+<>Show Less <ChevronUp className="w-4 h-4" /></>
+) : (
+<>Show All ({studies.length}) <ChevronDown className="w-4 h-4" /></>
+)}
+</button>
+)}
+</div>
 </div>

 {isLoading ? (
@@ -273,8 +325,8 @@ const Home: React.FC = () => {
 <div className="p-8 overflow-x-auto">
 <article className="markdown-body max-w-none">
 <ReactMarkdown
-remarkPlugins={[remarkGfm, remarkMath]}
-rehypePlugins={[rehypeKatex]}
+remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]}
+rehypePlugins={[[rehypeKatex, { strict: false, trust: true, output: 'html' }]]}
 components={{
 // Custom heading styles
 h1: ({ children }) => (
@@ -10,6 +10,8 @@ export interface Study {
 best_value: number | null;
 target: number | null;
 path: string;
+created_at?: string;
+last_modified?: string;
 }

 export interface StudyListResponse {
@@ -10,7 +10,7 @@ export default defineConfig({
 strictPort: false, // Allow fallback to next available port
 proxy: {
 '/api': {
-target: 'http://127.0.0.1:8000', // Use 127.0.0.1 instead of localhost
+target: 'http://127.0.0.1:8001', // Use 127.0.0.1 instead of localhost
 changeOrigin: true,
 secure: false,
 ws: true,
optimization_engine/surrogate_tuner.py  800 lines  (new file)
@@ -0,0 +1,800 @@
|
|||||||
|
"""
|
||||||
|
Hyperparameter Tuning for Neural Network Surrogates
|
||||||
|
|
||||||
|
This module provides automatic hyperparameter optimization for MLP surrogates
|
||||||
|
using Optuna, with proper train/validation splits and early stopping.
|
||||||
|
|
||||||
|
Key Features:
|
||||||
|
1. Optuna-based hyperparameter search
|
||||||
|
2. K-fold cross-validation
|
||||||
|
3. Early stopping to prevent overfitting
|
||||||
|
4. Ensemble model support
|
||||||
|
5. Proper uncertainty quantification
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from optimization_engine.surrogate_tuner import SurrogateHyperparameterTuner
|
||||||
|
|
||||||
|
tuner = SurrogateHyperparameterTuner(
|
||||||
|
input_dim=11,
|
||||||
|
output_dim=3,
|
||||||
|
n_trials=50
|
||||||
|
)
|
||||||
|
best_config = tuner.tune(X_train, Y_train)
|
||||||
|
model = tuner.create_tuned_model(best_config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
from typing import Dict, List, Tuple, Optional, Any
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
TORCH_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
TORCH_AVAILABLE = False
|
||||||
|
logger.warning("PyTorch not installed")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import optuna
|
||||||
|
from optuna.samplers import TPESampler
|
||||||
|
from optuna.pruners import MedianPruner
|
||||||
|
OPTUNA_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
OPTUNA_AVAILABLE = False
|
||||||
|
logger.warning("Optuna not installed")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SurrogateConfig:
|
||||||
|
"""Configuration for a tuned surrogate model."""
|
||||||
|
hidden_dims: List[int] = field(default_factory=lambda: [128, 256, 128])
|
||||||
|
dropout: float = 0.1
|
||||||
|
activation: str = 'relu'
|
||||||
|
use_batch_norm: bool = True
|
||||||
|
learning_rate: float = 1e-3
|
||||||
|
weight_decay: float = 1e-4
|
||||||
|
batch_size: int = 16
|
||||||
|
max_epochs: int = 500
|
||||||
|
early_stopping_patience: int = 30
|
||||||
|
|
||||||
|
# Normalization stats (filled during training)
|
||||||
|
input_mean: Optional[np.ndarray] = None
|
||||||
|
input_std: Optional[np.ndarray] = None
|
||||||
|
output_mean: Optional[np.ndarray] = None
|
||||||
|
output_std: Optional[np.ndarray] = None
|
||||||
|
|
||||||
|
# Validation metrics
|
||||||
|
val_loss: float = float('inf')
|
||||||
|
val_r2: Dict[str, float] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class TunableMLP(nn.Module):
|
||||||
|
"""Flexible MLP with configurable architecture."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
hidden_dims: List[int],
|
||||||
|
dropout: float = 0.1,
|
||||||
|
activation: str = 'relu',
|
||||||
|
use_batch_norm: bool = True
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
# Activation function
|
||||||
|
activations = {
|
||||||
|
'relu': nn.ReLU(),
|
||||||
|
'leaky_relu': nn.LeakyReLU(0.1),
|
||||||
|
'elu': nn.ELU(),
|
||||||
|
'selu': nn.SELU(),
|
||||||
|
'gelu': nn.GELU(),
|
||||||
|
'swish': nn.SiLU()
|
||||||
|
}
|
||||||
|
act_fn = activations.get(activation, nn.ReLU())
|
||||||
|
|
||||||
|
# Build layers
|
||||||
|
layers = []
|
||||||
|
prev_dim = input_dim
|
||||||
|
|
||||||
|
for hidden_dim in hidden_dims:
|
||||||
|
layers.append(nn.Linear(prev_dim, hidden_dim))
|
||||||
|
if use_batch_norm:
|
||||||
|
layers.append(nn.BatchNorm1d(hidden_dim))
|
||||||
|
layers.append(act_fn)
|
||||||
|
if dropout > 0:
|
||||||
|
layers.append(nn.Dropout(dropout))
|
||||||
|
prev_dim = hidden_dim
|
||||||
|
|
||||||
|
layers.append(nn.Linear(prev_dim, output_dim))
|
||||||
|
self.network = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
self._init_weights()
|
||||||
|
|
||||||
|
def _init_weights(self):
|
||||||
|
"""Initialize weights using Kaiming initialization."""
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.network(x)
|
||||||
|
|
||||||
|
|
||||||
|
class EarlyStopping:
|
||||||
|
"""Early stopping to prevent overfitting."""
|
||||||
|
|
||||||
|
def __init__(self, patience: int = 20, min_delta: float = 1e-5):
|
||||||
|
self.patience = patience
|
||||||
|
self.min_delta = min_delta
|
||||||
|
self.counter = 0
|
||||||
|
self.best_loss = float('inf')
|
||||||
|
self.best_model_state = None
|
||||||
|
self.should_stop = False
|
||||||
|
|
||||||
|
def __call__(self, val_loss: float, model: nn.Module) -> bool:
|
||||||
|
if val_loss < self.best_loss - self.min_delta:
|
||||||
|
self.best_loss = val_loss
|
||||||
|
self.best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
||||||
|
self.counter = 0
|
||||||
|
else:
|
||||||
|
self.counter += 1
|
||||||
|
if self.counter >= self.patience:
|
||||||
|
self.should_stop = True
|
||||||
|
|
||||||
|
return self.should_stop
|
||||||
|
|
||||||
|
def restore_best(self, model: nn.Module):
|
||||||
|
"""Restore model to best state."""
|
||||||
|
if self.best_model_state is not None:
|
||||||
|
model.load_state_dict(self.best_model_state)
|
||||||
|
|
||||||
|
|
||||||
|
class SurrogateHyperparameterTuner:
|
||||||
|
"""
|
||||||
|
Automatic hyperparameter tuning for neural network surrogates.
|
||||||
|
|
||||||
|
Uses Optuna for Bayesian optimization of:
|
||||||
|
- Network architecture (layers, widths)
|
||||||
|
- Regularization (dropout, weight decay)
|
||||||
|
- Learning rate and batch size
|
||||||
|
- Activation functions
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
n_trials: int = 50,
|
||||||
|
n_cv_folds: int = 5,
|
||||||
|
device: str = 'auto',
|
||||||
|
seed: int = 42,
|
||||||
|
timeout_seconds: Optional[int] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize hyperparameter tuner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dim: Number of input features (design variables)
|
||||||
|
output_dim: Number of outputs (objectives)
|
||||||
|
n_trials: Number of Optuna trials for hyperparameter search
|
||||||
|
n_cv_folds: Number of cross-validation folds
|
||||||
|
device: Computing device ('cuda', 'cpu', or 'auto')
|
||||||
|
seed: Random seed for reproducibility
|
||||||
|
timeout_seconds: Optional timeout for tuning
|
||||||
|
"""
|
||||||
|
if not TORCH_AVAILABLE:
|
||||||
|
raise ImportError("PyTorch required for surrogate tuning")
|
||||||
|
if not OPTUNA_AVAILABLE:
|
||||||
|
raise ImportError("Optuna required for hyperparameter tuning")
|
||||||
|
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.n_trials = n_trials
|
||||||
|
self.n_cv_folds = n_cv_folds
|
||||||
|
self.seed = seed
|
||||||
|
self.timeout = timeout_seconds
|
||||||
|
|
||||||
|
if device == 'auto':
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
else:
|
||||||
|
self.device = torch.device(device)
|
||||||
|
|
||||||
|
self.best_config: Optional[SurrogateConfig] = None
|
||||||
|
self.study: Optional[optuna.Study] = None
|
||||||
|
|
||||||
|
# Set seeds
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
def _suggest_hyperparameters(self, trial: optuna.Trial) -> SurrogateConfig:
|
||||||
|
"""Suggest hyperparameters for a trial."""
|
||||||
|
|
||||||
|
# Architecture
|
||||||
|
n_layers = trial.suggest_int('n_layers', 2, 5)
|
||||||
|
hidden_dims = []
|
||||||
|
for i in range(n_layers):
|
||||||
|
dim = trial.suggest_int(f'hidden_dim_{i}', 32, 512, step=32)
|
||||||
|
hidden_dims.append(dim)
|
||||||
|
|
||||||
|
# Regularization
|
||||||
|
dropout = trial.suggest_float('dropout', 0.0, 0.5)
|
||||||
|
weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
|
||||||
|
batch_size = trial.suggest_categorical('batch_size', [8, 16, 32, 64])
|
||||||
|
|
||||||
|
# Activation
|
||||||
|
activation = trial.suggest_categorical('activation',
|
||||||
|
['relu', 'leaky_relu', 'elu', 'gelu', 'swish'])
|
||||||
|
|
||||||
|
# Batch norm
|
||||||
|
use_batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])
|
||||||
|
|
||||||
|
return SurrogateConfig(
|
||||||
|
hidden_dims=hidden_dims,
|
||||||
|
dropout=dropout,
|
||||||
|
activation=activation,
|
||||||
|
use_batch_norm=use_batch_norm,
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def _train_fold(
|
||||||
|
self,
|
||||||
|
config: SurrogateConfig,
|
||||||
|
X_train: np.ndarray,
|
||||||
|
Y_train: np.ndarray,
|
||||||
|
X_val: np.ndarray,
|
||||||
|
Y_val: np.ndarray,
|
||||||
|
trial: Optional[optuna.Trial] = None
|
||||||
|
) -> Tuple[float, Dict[str, float]]:
|
||||||
|
"""Train model on one fold and return validation metrics."""
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
model = TunableMLP(
|
||||||
|
input_dim=self.input_dim,
|
||||||
|
output_dim=self.output_dim,
|
||||||
|
hidden_dims=config.hidden_dims,
|
||||||
|
dropout=config.dropout,
|
||||||
|
activation=config.activation,
|
||||||
|
use_batch_norm=config.use_batch_norm
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
# Prepare data
|
||||||
|
X_train_t = torch.tensor(X_train, dtype=torch.float32, device=self.device)
|
||||||
|
Y_train_t = torch.tensor(Y_train, dtype=torch.float32, device=self.device)
|
||||||
|
X_val_t = torch.tensor(X_val, dtype=torch.float32, device=self.device)
|
||||||
|
Y_val_t = torch.tensor(Y_val, dtype=torch.float32, device=self.device)
|
||||||
|
|
||||||
|
train_dataset = TensorDataset(X_train_t, Y_train_t)
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
# Optimizer and scheduler
|
||||||
|
optimizer = torch.optim.AdamW(
|
||||||
|
model.parameters(),
|
||||||
|
lr=config.learning_rate,
|
||||||
|
weight_decay=config.weight_decay
|
||||||
|
)
|
||||||
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||||
|
optimizer, T_max=config.max_epochs
|
||||||
|
)
|
||||||
|
|
||||||
|
early_stopping = EarlyStopping(patience=config.early_stopping_patience)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
for epoch in range(config.max_epochs):
|
||||||
|
model.train()
|
||||||
|
for X_batch, Y_batch in train_loader:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
pred = model(X_batch)
|
||||||
|
loss = nn.functional.mse_loss(pred, Y_batch)
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
# Validation
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
val_pred = model(X_val_t)
|
||||||
|
val_loss = nn.functional.mse_loss(val_pred, Y_val_t).item()
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
if early_stopping(val_loss, model):
|
||||||
|
break
|
||||||
|
|
||||||
|
# Optuna pruning (only report once per epoch across all folds)
|
||||||
|
if trial is not None and epoch % 10 == 0:
|
||||||
|
trial.report(val_loss, epoch // 10)
|
||||||
|
if trial.should_prune():
|
||||||
|
raise optuna.TrialPruned()
|
||||||
|
|
||||||
|
# Restore best model
|
||||||
|
early_stopping.restore_best(model)
|
||||||
|
|
||||||
|
# Final validation metrics
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
val_pred = model(X_val_t).cpu().numpy()
|
||||||
|
Y_val_np = Y_val_t.cpu().numpy()
|
||||||
|
|
||||||
|
val_loss = float(np.mean((val_pred - Y_val_np) ** 2))
|
||||||
|
|
||||||
|
# R² per output
|
||||||
|
r2_scores = {}
|
||||||
|
for i in range(self.output_dim):
|
||||||
|
ss_res = np.sum((Y_val_np[:, i] - val_pred[:, i]) ** 2)
|
||||||
|
ss_tot = np.sum((Y_val_np[:, i] - Y_val_np[:, i].mean()) ** 2)
|
||||||
|
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
|
||||||
|
r2_scores[f'output_{i}'] = r2
|
||||||
|
|
||||||
|
return val_loss, r2_scores
|
||||||
|
|
||||||
|
    def _cross_validate(
        self,
        config: SurrogateConfig,
        X: np.ndarray,
        Y: np.ndarray,
        trial: Optional[optuna.Trial] = None
    ) -> Tuple[float, Dict[str, float]]:
        """Perform k-fold cross-validation."""
        n_samples = len(X)
        indices = np.random.permutation(n_samples)
        fold_size = n_samples // self.n_cv_folds

        fold_losses = []
        fold_r2s = {f'output_{i}': [] for i in range(self.output_dim)}

        for fold in range(self.n_cv_folds):
            # Split indices
            val_start = fold * fold_size
            val_end = val_start + fold_size if fold < self.n_cv_folds - 1 else n_samples

            val_indices = indices[val_start:val_end]
            train_indices = np.concatenate([indices[:val_start], indices[val_end:]])

            X_train, Y_train = X[train_indices], Y[train_indices]
            X_val, Y_val = X[val_indices], Y[val_indices]

            # Skip fold if too few samples
            if len(X_train) < 10 or len(X_val) < 2:
                continue

            val_loss, r2_scores = self._train_fold(
                config, X_train, Y_train, X_val, Y_val, trial
            )

            fold_losses.append(val_loss)
            for key, val in r2_scores.items():
                fold_r2s[key].append(val)

        mean_loss = np.mean(fold_losses)
        mean_r2 = {k: np.mean(v) for k, v in fold_r2s.items()}

        return mean_loss, mean_r2

    def tune(
        self,
        X: np.ndarray,
        Y: np.ndarray,
        output_names: Optional[List[str]] = None
    ) -> SurrogateConfig:
        """
        Tune hyperparameters using Optuna.

        Args:
            X: Input features [n_samples, input_dim]
            Y: Outputs [n_samples, output_dim]
            output_names: Optional names for outputs (for logging)

        Returns:
            Best SurrogateConfig found
        """
        logger.info(f"Starting hyperparameter tuning with {self.n_trials} trials...")
        logger.info(f"Data: {len(X)} samples, {self.n_cv_folds}-fold CV")

        # Normalize data
        self.input_mean = X.mean(axis=0)
        self.input_std = X.std(axis=0) + 1e-8
        self.output_mean = Y.mean(axis=0)
        self.output_std = Y.std(axis=0) + 1e-8

        X_norm = (X - self.input_mean) / self.input_std
        Y_norm = (Y - self.output_mean) / self.output_std

        def objective(trial: optuna.Trial) -> float:
            config = self._suggest_hyperparameters(trial)
            val_loss, r2_scores = self._cross_validate(config, X_norm, Y_norm, trial)

            # Log R² scores
            for key, val in r2_scores.items():
                trial.set_user_attr(f'r2_{key}', val)

            return val_loss

        # Create study
        self.study = optuna.create_study(
            direction='minimize',
            sampler=TPESampler(seed=self.seed, n_startup_trials=10),
            pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=20)
        )

        self.study.optimize(
            objective,
            n_trials=self.n_trials,
            timeout=self.timeout,
            show_progress_bar=True,
            catch=(RuntimeError,)  # Catch GPU OOM errors
        )

        # Build best config
        best_trial = self.study.best_trial
        self.best_config = self._suggest_hyperparameters_from_params(best_trial.params)
        self.best_config.val_loss = best_trial.value
        self.best_config.val_r2 = {
            k.replace('r2_', ''): v
            for k, v in best_trial.user_attrs.items()
            if k.startswith('r2_')
        }

        # Store normalization
        self.best_config.input_mean = self.input_mean
        self.best_config.input_std = self.input_std
        self.best_config.output_mean = self.output_mean
        self.best_config.output_std = self.output_std

        # Log results
        logger.info("\nBest hyperparameters found:")
        logger.info(f"  Hidden dims: {self.best_config.hidden_dims}")
        logger.info(f"  Dropout: {self.best_config.dropout:.3f}")
        logger.info(f"  Activation: {self.best_config.activation}")
        logger.info(f"  Batch norm: {self.best_config.use_batch_norm}")
        logger.info(f"  Learning rate: {self.best_config.learning_rate:.2e}")
        logger.info(f"  Weight decay: {self.best_config.weight_decay:.2e}")
        logger.info(f"  Batch size: {self.best_config.batch_size}")
        logger.info(f"  Validation loss: {self.best_config.val_loss:.6f}")

        if output_names:
            for i, name in enumerate(output_names):
                r2 = self.best_config.val_r2.get(f'output_{i}', 0)
                logger.info(f"  {name} R² (CV): {r2:.4f}")

        return self.best_config

    def _suggest_hyperparameters_from_params(self, params: Dict[str, Any]) -> SurrogateConfig:
        """Reconstruct config from Optuna params dict."""
        n_layers = params['n_layers']
        hidden_dims = [params[f'hidden_dim_{i}'] for i in range(n_layers)]

        return SurrogateConfig(
            hidden_dims=hidden_dims,
            dropout=params['dropout'],
            activation=params['activation'],
            use_batch_norm=params['use_batch_norm'],
            learning_rate=params['learning_rate'],
            weight_decay=params['weight_decay'],
            batch_size=params['batch_size']
        )

    def create_tuned_model(
        self,
        config: Optional[SurrogateConfig] = None
    ) -> TunableMLP:
        """Create a model with tuned hyperparameters."""
        if config is None:
            config = self.best_config
        if config is None:
            raise ValueError("No config available. Run tune() first.")

        return TunableMLP(
            input_dim=self.input_dim,
            output_dim=self.output_dim,
            hidden_dims=config.hidden_dims,
            dropout=config.dropout,
            activation=config.activation,
            use_batch_norm=config.use_batch_norm
        )


class TunedEnsembleSurrogate:
    """
    Ensemble of tuned surrogate models for better uncertainty quantification.

    Trains multiple models with different random seeds and aggregates predictions.
    """

    def __init__(
        self,
        config: SurrogateConfig,
        input_dim: int,
        output_dim: int,
        n_models: int = 5,
        device: str = 'auto'
    ):
        """
        Initialize ensemble surrogate.

        Args:
            config: Tuned configuration to use for all models
            input_dim: Number of input features
            output_dim: Number of outputs
            n_models: Number of models in ensemble
            device: Computing device
        """
        self.config = config
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.n_models = n_models

        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        self.models: List[TunableMLP] = []
        self.trained = False

    def train(self, X: np.ndarray, Y: np.ndarray, val_split: float = 0.2):
        """Train all models in the ensemble."""
        logger.info(f"Training ensemble of {self.n_models} models...")

        # Normalize using config stats
        X_norm = (X - self.config.input_mean) / self.config.input_std
        Y_norm = (Y - self.config.output_mean) / self.config.output_std

        # Split data
        n_val = int(len(X) * val_split)
        indices = np.random.permutation(len(X))
        train_idx, val_idx = indices[n_val:], indices[:n_val]

        X_train, Y_train = X_norm[train_idx], Y_norm[train_idx]
        X_val, Y_val = X_norm[val_idx], Y_norm[val_idx]

        X_train_t = torch.tensor(X_train, dtype=torch.float32, device=self.device)
        Y_train_t = torch.tensor(Y_train, dtype=torch.float32, device=self.device)
        X_val_t = torch.tensor(X_val, dtype=torch.float32, device=self.device)
        Y_val_t = torch.tensor(Y_val, dtype=torch.float32, device=self.device)

        train_dataset = TensorDataset(X_train_t, Y_train_t)

        self.models = []

        for i in range(self.n_models):
            torch.manual_seed(42 + i)

            model = TunableMLP(
                input_dim=self.input_dim,
                output_dim=self.output_dim,
                hidden_dims=self.config.hidden_dims,
                dropout=self.config.dropout,
                activation=self.config.activation,
                use_batch_norm=self.config.use_batch_norm
            ).to(self.device)

            optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=self.config.learning_rate,
                weight_decay=self.config.weight_decay
            )
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.config.max_epochs
            )

            train_loader = DataLoader(
                train_dataset,
                batch_size=self.config.batch_size,
                shuffle=True
            )
            early_stopping = EarlyStopping(patience=self.config.early_stopping_patience)

            for epoch in range(self.config.max_epochs):
                model.train()
                for X_batch, Y_batch in train_loader:
                    optimizer.zero_grad()
                    pred = model(X_batch)
                    loss = nn.functional.mse_loss(pred, Y_batch)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()

                scheduler.step()

                model.eval()
                with torch.no_grad():
                    val_pred = model(X_val_t)
                    val_loss = nn.functional.mse_loss(val_pred, Y_val_t).item()

                if early_stopping(val_loss, model):
                    break

            early_stopping.restore_best(model)
            model.eval()
            self.models.append(model)

            logger.info(f"  Model {i+1}/{self.n_models}: val_loss = {early_stopping.best_loss:.6f}")

        self.trained = True
        logger.info("Ensemble training complete")

    def predict(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Predict with uncertainty estimation.

        Args:
            X: Input features [n_samples, input_dim]

        Returns:
            Tuple of (mean_predictions, std_predictions)
        """
        if not self.trained:
            raise RuntimeError("Ensemble not trained. Call train() first.")

        # Normalize input
        X_norm = (X - self.config.input_mean) / self.config.input_std
        X_t = torch.tensor(X_norm, dtype=torch.float32, device=self.device)

        # Collect predictions from all models
        predictions = []
        for model in self.models:
            model.eval()
            with torch.no_grad():
                pred = model(X_t).cpu().numpy()
                # Denormalize
                pred = pred * self.config.output_std + self.config.output_mean
                predictions.append(pred)

        predictions = np.array(predictions)  # [n_models, n_samples, output_dim]

        mean_pred = predictions.mean(axis=0)
        std_pred = predictions.std(axis=0)

        return mean_pred, std_pred

    def predict_single(self, params: Dict[str, float], var_names: List[str]) -> Tuple[Dict[str, float], float]:
        """
        Predict for a single point with uncertainty.

        Args:
            params: Dictionary of input parameters
            var_names: List of variable names in order

        Returns:
            Tuple of (predictions dict, total uncertainty)
        """
        X = np.array([[params[name] for name in var_names]])
        mean, std = self.predict(X)

        pred_dict = {f'output_{i}': mean[0, i] for i in range(self.output_dim)}
        uncertainty = float(np.sum(std[0]))

        return pred_dict, uncertainty

    def save(self, path: Path):
        """Save ensemble to disk."""
        state = {
            'config': {
                'hidden_dims': self.config.hidden_dims,
                'dropout': self.config.dropout,
                'activation': self.config.activation,
                'use_batch_norm': self.config.use_batch_norm,
                'learning_rate': self.config.learning_rate,
                'weight_decay': self.config.weight_decay,
                'batch_size': self.config.batch_size,
                'input_mean': self.config.input_mean.tolist(),
                'input_std': self.config.input_std.tolist(),
                'output_mean': self.config.output_mean.tolist(),
                'output_std': self.config.output_std.tolist(),
            },
            'n_models': self.n_models,
            'model_states': [m.state_dict() for m in self.models]
        }
        torch.save(state, path)
        logger.info(f"Saved ensemble to {path}")

    def load(self, path: Path):
        """Load ensemble from disk."""
        state = torch.load(path, map_location=self.device)

        # Restore config
        cfg = state['config']
        self.config = SurrogateConfig(
            hidden_dims=cfg['hidden_dims'],
            dropout=cfg['dropout'],
            activation=cfg['activation'],
            use_batch_norm=cfg['use_batch_norm'],
            learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'],
            batch_size=cfg['batch_size'],
            input_mean=np.array(cfg['input_mean']),
            input_std=np.array(cfg['input_std']),
            output_mean=np.array(cfg['output_mean']),
            output_std=np.array(cfg['output_std'])
        )

        self.n_models = state['n_models']
        self.models = []

        for model_state in state['model_states']:
            model = TunableMLP(
                input_dim=self.input_dim,
                output_dim=self.output_dim,
                hidden_dims=self.config.hidden_dims,
                dropout=self.config.dropout,
                activation=self.config.activation,
                use_batch_norm=self.config.use_batch_norm
            ).to(self.device)
            model.load_state_dict(model_state)
            model.eval()
            self.models.append(model)

        self.trained = True
        logger.info(f"Loaded ensemble with {self.n_models} models from {path}")


def tune_surrogate_for_study(
    fea_data: List[Dict],
    design_var_names: List[str],
    objective_names: List[str],
    n_tuning_trials: int = 50,
    n_ensemble_models: int = 5
) -> TunedEnsembleSurrogate:
    """
    Convenience function to tune and create ensemble surrogate.

    Args:
        fea_data: List of FEA results with 'params' and 'objectives' keys
        design_var_names: List of design variable names
        objective_names: List of objective names
        n_tuning_trials: Number of Optuna trials
        n_ensemble_models: Number of models in ensemble

    Returns:
        Trained TunedEnsembleSurrogate
    """
    # Prepare data
    X = np.array([[d['params'][name] for name in design_var_names] for d in fea_data])
    Y = np.array([[d['objectives'][name] for name in objective_names] for d in fea_data])

    logger.info(f"Tuning surrogate on {len(X)} samples...")
    logger.info(f"Input: {len(design_var_names)} design variables")
    logger.info(f"Output: {len(objective_names)} objectives")

    # Tune hyperparameters
    tuner = SurrogateHyperparameterTuner(
        input_dim=len(design_var_names),
        output_dim=len(objective_names),
        n_trials=n_tuning_trials,
        n_cv_folds=5
    )

    best_config = tuner.tune(X, Y, output_names=objective_names)

    # Create and train ensemble
    ensemble = TunedEnsembleSurrogate(
        config=best_config,
        input_dim=len(design_var_names),
        output_dim=len(objective_names),
        n_models=n_ensemble_models
    )

    ensemble.train(X, Y)

    return ensemble
217 studies/m1_mirror_zernike_optimization/STUDY_REPORT.md Normal file
@@ -0,0 +1,217 @@
# M1 Mirror Zernike Optimization Report

**Study**: m1_mirror_zernike_optimization
**Generated**: 2025-12-04
**Protocol**: Protocol 12 (Hybrid FEA/Neural with Zernike)

---

## Executive Summary

This optimization study aimed to minimize wavefront error (WFE) in the M1 telescope primary mirror support structure across different gravity orientations. The optimization achieved a **9x improvement** in the weighted objective function compared to early trials, finding configurations that significantly reduce optical aberrations.

### Key Results

| Metric | Baseline Region | Optimized | Improvement |
|--------|-----------------|-----------|-------------|
| Weighted Objective | ~13.5 | **1.49** | **89% reduction** |
| WFE @ 40° vs 20° | ~87 nm | **6.1 nm** | 93% reduction |
| WFE @ 60° vs 20° | ~73 nm | **14.4 nm** | 80% reduction |
| Optician Workload @ 90° | ~51 nm | **30.5 nm** | 40% reduction |

---

## 1. Study Overview

### 1.1 Objective

Optimize the whiffle tree support structure geometry to minimize wavefront error across telescope elevation angles (20°, 40°, 60°, 90°), ensuring consistent optical performance from horizon to zenith.

### 1.2 Design Variables (3 active)

| Parameter | Min | Max | Baseline | Optimized | Change |
|-----------|-----|-----|----------|-----------|--------|
| whiffle_min | 35.0 mm | 55.0 mm | 40.55 mm | **49.39 mm** | +21.8% |
| whiffle_outer_to_vertical | 68.0° | 80.0° | 75.67° | **71.64°** | -5.3% |
| inner_circular_rib_dia | 480 mm | 620 mm | 534.0 mm | **497.8 mm** | -6.8% |

### 1.3 Optimization Objectives

| Objective | Description | Weight | Target | Best Achieved |
|-----------|-------------|--------|--------|---------------|
| rel_filtered_rms_40_vs_20 | Filtered RMS WFE at 40° relative to 20° | 5.0 | 4 nm | **6.10 nm** |
| rel_filtered_rms_60_vs_20 | Filtered RMS WFE at 60° relative to 20° | 5.0 | 10 nm | **14.38 nm** |
| mfg_90_optician_workload | Optician workload at 90° (J4+ filtered RMS) | 1.0 | 20 nm | **30.47 nm** |

---

## 2. Trial Statistics

| Category | Count |
|----------|-------|
| Total Trials | 54 |
| Completed | 21 |
| Failed | 10 |
| Running/Pending | 23 |

### 2.1 Trial Distribution

- **Trials 0-12**: Initial exploration phase with high objective values (~13.5)
- **Trials 14-15**: Anomalous results (likely simulation issues)
- **Trial 20**: First significant improvement (2.15 weighted objective)
- **Trials 40-46**: Convergence region with best results (~1.49)

---

## 3. Best Configuration

### Trial 40 (Optimal)

**Weighted Objective**: 1.4852

#### Design Parameters

| Parameter | Value | Units |
|-----------|-------|-------|
| whiffle_min | 49.393 | mm |
| whiffle_outer_to_vertical | 71.635 | degrees |
| inner_circular_rib_dia | 497.838 | mm |

#### Individual Objectives

| Objective | Value | Target | Status |
|-----------|-------|--------|--------|
| rel_filtered_rms_40_vs_20 | 6.10 nm | 4 nm | Close to target |
| rel_filtered_rms_60_vs_20 | 14.38 nm | 10 nm | Close to target |
| mfg_90_optician_workload | 30.47 nm | 20 nm | Within 1.5× target |

---

## 4. Top 5 Configurations

| Rank | Trial | Weighted Obj | whiffle_min | whiffle_outer_to_vertical | inner_circular_rib_dia |
|------|-------|--------------|-------------|---------------------------|------------------------|
| 1 | 40 | 1.4852 | 49.39 mm | 71.64° | 497.8 mm |
| 2 | 41 | 1.4852 | 49.01 mm | 74.11° | 522.6 mm |
| 3 | 42 | 1.4852 | 48.58 mm | 73.68° | 523.5 mm |
| 4 | 43 | 1.4852 | 49.41 mm | 74.07° | 511.5 mm |
| 5 | 46 | 1.4852 | 46.98 mm | 76.52° | 498.6 mm |

**Note**: Multiple configurations achieve the same optimal objective value, indicating a relatively flat optimum region. This provides manufacturing flexibility.

---

## 5. Parameter Insights

### 5.1 whiffle_min (Whiffle Tree Minimum Parameter)

- **Trend**: Optimal values cluster around **47-50 mm** (upper half of range)
- **Baseline**: 40.55 mm was suboptimal
- **Recommendation**: Increase whiffle_min to ~49 mm for best performance

### 5.2 whiffle_outer_to_vertical (Outer Support Angle)

- **Trend**: Optimal range spans **71.6° to 76.5°**
- **Baseline**: 75.67° was near the upper optimal bound
- **Recommendation**: Maintain flexibility; angle has moderate sensitivity

### 5.3 inner_circular_rib_dia (Inner Rib Diameter)

- **Trend**: Optimal values range from **497-524 mm** (lower half of range)
- **Baseline**: 534 mm was slightly high
- **Recommendation**: Reduce rib diameter to ~500-510 mm

---

## 6. Convergence Analysis

```
Weighted Objective vs Trial Number

13.5 |■■■■■■■■■■■■■
     |
     |
 5.0 |
     |
 2.1 |                    ■
 1.5 |                                ■■■■■
     +------------------------------------>
     0     10    20    30    40    50
                Trial Number
```

The optimization showed clear convergence:
- **Phase 1** (Trials 0-12): Exploration at ~13.5 weighted objective
- **Phase 2** (Trials 14-15): Anomalous results (possible simulation errors)
- **Phase 3** (Trial 20): First breakthrough to 2.15
- **Phase 4** (Trials 40+): Converged optimum at 1.49

---

## 7. Recommendations

### 7.1 Recommended Production Configuration

Based on the optimization results, the recommended design parameters are:

| Parameter | Recommended Value | Tolerance |
|-----------|-------------------|-----------|
| whiffle_min | **49.4 mm** | ±2 mm |
| whiffle_outer_to_vertical | **71.6° - 74.1°** | ±2° |
| inner_circular_rib_dia | **500 - 520 mm** | ±20 mm |

### 7.2 Performance Expectations

With the optimized configuration, expect:
- **6.1 nm RMS** wavefront error change from 20° to 40° elevation
- **14.4 nm RMS** wavefront error change from 20° to 60° elevation
- **30.5 nm RMS** optician workload at 90° orientation

### 7.3 Next Steps

1. **Validate with FEA**: Run confirmation analysis at recommended parameters
2. **Manufacturing Review**: Verify proposed geometry is manufacturable
3. **Sensitivity Analysis**: Explore parameter tolerances more thoroughly
4. **Extended Optimization**: Consider enabling additional design variables for further improvement

---

## 8. Technical Notes

### 8.1 Zernike Analysis

- **Number of modes**: 50 (Noll indexing)
- **Filtered modes**: J1-J4 excluded (piston, tip, tilt, defocus - correctable by alignment)
- **Reference orientation**: 20° zenith angle (Subcase 2)
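
The "filtered RMS" metrics drop the first four Noll modes before taking the RMS. A minimal sketch of that filter, assuming standard Noll-normalized coefficients (so total RMS is the quadrature sum of per-mode coefficients; the study's actual post-processing is not shown here, and the coefficient vector below is hypothetical):

```python
import numpy as np

def filtered_rms(coeffs, exclude_up_to=4):
    """RMS wavefront error from Noll-indexed Zernike coefficients (nm),
    excluding the first `exclude_up_to` modes (J1-J4: piston, tip, tilt,
    defocus), which alignment can correct."""
    c = np.asarray(coeffs, dtype=float)  # c[0] corresponds to J1
    return float(np.sqrt(np.sum(c[exclude_up_to:] ** 2)))

# Hypothetical 50-mode coefficient vector with only J5 and J6 non-zero
c = np.zeros(50)
c[4], c[5] = 3.0, 4.0  # nm
print(filtered_rms(c))  # -> 5.0 (3-4-5 quadrature sum)
```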

### 8.2 Weighted Sum Formula

The weighted objective is a target-normalized weighted average of the three metrics:

$$J = \frac{\sum_{i=1}^{3} w_i \cdot f_i / t_i}{\sum_{i=1}^{3} w_i}$$

Where:
- $w_i$ = weight (5.0, 5.0, 1.0)
- $f_i$ = objective value (nm)
- $t_i$ = target value (4, 10, 20 nm)
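
As a sanity check, plugging trial 40's reported values into this formula reproduces its weighted objective. (Note the division by the total weight: the unnormalized sum gives ~16.34, which does not match the reported 1.4852, while the weighted average does.)

```python
# Recompute the weighted objective from trial 40's reported values.
weights = [5.0, 5.0, 1.0]
targets = [4.0, 10.0, 20.0]     # nm
values  = [6.10, 14.38, 30.47]  # nm (trial 40)

J = sum(w * f / t for w, f, t in zip(weights, values, targets)) / sum(weights)
print(round(J, 4))  # -> 1.4853, matching the reported 1.4852 up to rounding
```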

### 8.3 Algorithm

- **Optimizer**: TPE (Tree-structured Parzen Estimator)
- **Startup trials**: 15 random
- **EI candidates**: 150
- **Multivariate modeling**: Enabled

---

## 9. Files

| File | Description |
|------|-------------|
| `2_results/study.db` | Optuna SQLite database with all trial data |
| `1_setup/optimization_config.json` | Study configuration |
| `run_optimization.py` | Main optimization script |

---

*Report generated by Atomizer Optimization Framework*