// File: Atomizer/atomizer-dashboard/frontend/src/hooks/useCanvasStore.ts (TypeScript)

import { create } from 'zustand';
import { Node, Edge, addEdge, applyNodeChanges, applyEdgeChanges, Connection, NodeChange, EdgeChange } from 'reactflow';
import { CanvasNodeData, NodeType } from '../lib/canvas/schema';
import { validateGraph, ValidationResult } from '../lib/canvas/validation';
import { serializeToIntent, OptimizationIntent } from '../lib/canvas/intent';
/**
 * Contract of the canvas store: the React Flow graph being edited, the
 * current selection, the last validation result, and the actions that
 * mutate or serialize them.
 */
interface CanvasState {
  // ---- State ----------------------------------------------------------

  /** Nodes currently placed on the canvas. */
  nodes: Node<CanvasNodeData>[];
  /** Edges connecting the nodes. */
  edges: Edge[];
  /** ID of the selected node, or `null` when nothing is selected. */
  selectedNode: string | null;
  /** Result stored by the most recent `validate()` call. */
  validation: ValidationResult;

  // ---- React Flow event handlers --------------------------------------

  /** Applies a batch of React Flow node changes to `nodes`. */
  onNodesChange: (changes: NodeChange[]) => void;
  /** Applies a batch of React Flow edge changes to `edges`. */
  onEdgesChange: (changes: EdgeChange[]) => void;
  /** Adds an edge for a newly made connection. */
  onConnect: (connection: Connection) => void;

  // ---- Graph editing --------------------------------------------------

  /** Appends a node of the given type, with default data, at `position`. */
  addNode: (type: NodeType, position: { x: number; y: number }) => void;
  /** Shallow-merges `data` into the data of the node with ID `nodeId`. */
  updateNodeData: (nodeId: string, data: Partial<CanvasNodeData>) => void;
  /** Sets (or, with `null`, clears) the selected node. */
  selectNode: (nodeId: string | null) => void;
  /** Removes the selected node and every edge attached to it. */
  deleteSelected: () => void;

  // ---- Validation / serialization -------------------------------------

  /** Validates the current graph, stores the result and returns it. */
  validate: () => ValidationResult;
  /** Serializes the current graph to an optimization intent. */
  toIntent: () => OptimizationIntent;

  // ---- Whole-canvas operations ----------------------------------------

  /** Empties the canvas and resets selection and validation. */
  clear: () => void;
  /** Replaces the canvas with an auto-arranged graph built from `intent`. */
  loadFromIntent: (intent: OptimizationIntent) => void;
  /** Converts an optimization config to an intent and loads it. */
  loadFromConfig: (config: OptimizationConfig) => void;
}
/**
 * Optimization config structure (from optimization_config.json).
 *
 * Every top-level section is optional; `loadFromConfig` substitutes empty
 * lists or `undefined` for whatever is missing.
 */
export interface OptimizationConfig {
  study_name?: string;

  /** Model file reference; copied to the intent's `model` section. */
  model?: {
    path?: string;
    type?: string;
  };

  /** Solver settings. `loadFromConfig` derives the solver type as `SOL<solution>`. */
  solver?: {
    type?: string;
    solution?: number;
  };

  /** Design variables; `expression_name` takes precedence over `name` when loading. */
  design_variables?: {
    name: string;
    expression_name?: string;
    /** Lower bound (becomes `min` in the intent). */
    lower: number;
    /** Upper bound (becomes `max` in the intent). */
    upper: number;
    type?: string;
  }[];

  /** Objectives; `direction` falls back to `'minimize'` and `weight` to 1 when loading. */
  objectives?: {
    name: string;
    direction?: string;
    weight?: number;
    /** Extractor ID feeding this objective (e.g. `'E1'`). */
    extractor?: string;
  }[];

  /** Constraints; a `type` of `'upper'` loads as `<=`, anything else as `>=`. */
  constraints?: {
    name: string;
    type?: string;
    value?: number;
    /** Extractor ID feeding this constraint (e.g. `'E1'`). */
    extractor?: string;
  }[];

  /** Optimization method; copied to the intent's `optimization.method`. */
  method?: string;
  /** Trial budget; copied to the intent's `optimization.max_trials`. */
  max_trials?: number;

  /** Surrogate settings; any non-null object is loaded as an enabled surrogate. */
  surrogate?: {
    type?: string;
    min_trials?: number;
  } | null;
}
// Monotonic counter behind generated node IDs. Kept as a module-level `let`
// because the store resets it to 0 when the canvas is cleared or reloaded.
let nodeIdCounter = 0;

/** Returns the next node ID (`node_1`, `node_2`, ...), advancing the counter. */
const getNodeId = (): string => {
  nodeIdCounter += 1;
  return 'node_' + String(nodeIdCounter);
};
/**
 * Builds the initial, unconfigured data payload for a new node of `type`.
 *
 * The label is the type name with its first letter upper-cased, except for
 * `designVar`, which is labelled "Design Variable". Surrogate nodes also
 * start with `enabled: false`.
 */
const getDefaultData = (type: NodeType): CanvasNodeData => {
  const configured = false;
  const label = `${type.charAt(0).toUpperCase()}${type.slice(1)}`;

  // Each guard narrows `type` to a single literal, so the returned object is
  // checked against the matching member of the node-data union.
  if (type === 'model') return { label, configured, type };
  if (type === 'solver') return { label, configured, type };
  if (type === 'designVar') return { label: 'Design Variable', configured, type };
  if (type === 'extractor') return { label, configured, type };
  if (type === 'objective') return { label, configured, type };
  if (type === 'constraint') return { label, configured, type };
  if (type === 'algorithm') return { label, configured, type };
  if (type === 'surrogate') return { label, configured, type, enabled: false };

  // Any other node type gets just the common fields.
  return { label, configured, type } as CanvasNodeData;
};
/**
 * Grid used when auto-arranging nodes: a node in grid cell (col, row) is
 * placed at (startX + col * colWidth, startY + row * rowHeight).
 */
const LAYOUT = {
  // Position of grid cell (0, 0).
  startX: 100,
  startY: 100,
  // Horizontal distance between adjacent columns.
  colWidth: 250,
  // Vertical distance between adjacent rows.
  rowHeight: 120,
};
/**
 * Display names for the known extractor IDs. Held in a Map so that an
 * arbitrary ID coming from a config file (e.g. "constructor") can never
 * resolve to an inherited Object.prototype member.
 */
const EXTRACTOR_DISPLAY_NAMES: ReadonlyMap<string, string> = new Map([
  ['E1', 'Displacement'],
  ['E2', 'Frequency'],
  ['E3', 'Solid Stress'],
  ['E4', 'BDF Mass'],
  ['E5', 'CAD Mass'],
  ['E8', 'Zernike (OP2)'],
  ['E9', 'Zernike (CSV)'],
  ['E10', 'Zernike (RMS)'],
]);

/** Returns a fresh "not validated yet" result; a new object each time so resets never share arrays. */
const createEmptyValidation = (): ValidationResult => ({ valid: false, errors: [], warnings: [] });

/**
 * Zustand store backing the optimization canvas.
 *
 * Holds the React Flow graph (nodes and edges), the selected node ID and the
 * latest validation result, and exposes actions to edit the graph, validate
 * it, serialize it to an intent, and rebuild it from an intent or from an
 * optimization config.
 */
export const useCanvasStore = create<CanvasState>((set, get) => ({
  nodes: [],
  edges: [],
  selectedNode: null,
  validation: createEmptyValidation(),

  // Apply React Flow node changes. If the selected node is removed by one of
  // them, clear the selection too so `selectedNode` never dangles.
  onNodesChange: (changes) => {
    const { nodes, selectedNode } = get();
    const nextNodes = applyNodeChanges(changes, nodes);
    const selectionRemoved =
      selectedNode !== null &&
      changes.some((change) => change.type === 'remove' && change.id === selectedNode);
    if (selectionRemoved) {
      set({ nodes: nextNodes, selectedNode: null });
    } else {
      set({ nodes: nextNodes });
    }
  },

  // Apply React Flow edge changes.
  onEdgesChange: (changes) => {
    set({ edges: applyEdgeChanges(changes, get().edges) });
  },

  // Add an edge for a connection the user just made.
  onConnect: (connection) => {
    set({ edges: addEdge(connection, get().edges) });
  },

  // Append a new node of `type` with default data at `position`.
  addNode: (type, position) => {
    const newNode: Node<CanvasNodeData> = {
      id: getNodeId(),
      type,
      position,
      data: getDefaultData(type),
    };
    set({ nodes: [...get().nodes, newNode] });
  },

  // Shallow-merge `data` into the matching node's data; other nodes are untouched.
  updateNodeData: (nodeId, data) => {
    set({
      nodes: get().nodes.map((node) =>
        node.id === nodeId ? { ...node, data: { ...node.data, ...data } } : node
      ),
    });
  },

  selectNode: (nodeId) => {
    set({ selectedNode: nodeId });
  },

  // Remove the selected node together with every edge that touches it.
  deleteSelected: () => {
    const { selectedNode, nodes, edges } = get();
    if (!selectedNode) return;
    set({
      nodes: nodes.filter((n) => n.id !== selectedNode),
      edges: edges.filter((e) => e.source !== selectedNode && e.target !== selectedNode),
      selectedNode: null,
    });
  },

  // Validate the current graph, remember the result, and return it.
  validate: () => {
    const { nodes, edges } = get();
    const result = validateGraph(nodes, edges);
    set({ validation: result });
    return result;
  },

  // Serialize the current graph to an optimization intent.
  toIntent: () => {
    const { nodes, edges } = get();
    return serializeToIntent(nodes, edges);
  },

  // Empty the canvas and restart node ID numbering.
  clear: () => {
    set({
      nodes: [],
      edges: [],
      selectedNode: null,
      validation: createEmptyValidation(),
    });
    nodeIdCounter = 0;
  },

  // Replace the canvas with a graph built from `intent`, laid out on the
  // LAYOUT grid: model and design variables in column 0, solver in 1,
  // extractors in 2, objectives then constraints in 3, algorithm in 4 and an
  // optional surrogate in 5.
  loadFromIntent: (intent) => {
    // The whole graph is replaced, so ID numbering restarts from node_1.
    nodeIdCounter = 0;
    const nodes: Node<CanvasNodeData>[] = [];
    const edges: Edge[] = [];

    // Creates a configured node in grid cell (col, row) and returns its ID.
    const createNode = (
      type: NodeType,
      data: Partial<CanvasNodeData>,
      col: number,
      row: number
    ): string => {
      const id = getNodeId();
      nodes.push({
        id,
        type,
        position: {
          x: LAYOUT.startX + col * LAYOUT.colWidth,
          y: LAYOUT.startY + row * LAYOUT.rowHeight,
        },
        data: { ...getDefaultData(type), ...data, configured: true } as CanvasNodeData,
      });
      return id;
    };

    // Adds a directed edge whose ID is derived from its endpoints.
    const connect = (source: string, target: string): void => {
      edges.push({ id: `e_${source}_${target}`, source, target });
    };

    // NOTE(review): the `as any` casts below bridge the intent's loosely typed
    // strings to the node-data unions; tighten once the schema types align.

    // Model (column 0) feeds the solver (column 1).
    const modelId = createNode(
      'model',
      {
        label: 'Model',
        filePath: intent.model?.path,
        fileType: intent.model?.type as 'prt' | 'fem' | 'sim' | undefined,
      },
      0,
      0
    );
    const solverId = createNode(
      'solver',
      {
        label: 'Solver',
        solverType: intent.solver?.type as any,
      },
      1,
      0
    );
    connect(modelId, solverId);

    // Design variables: column 0, stacked below the model, each feeding it.
    let row = 1;
    for (const dv of intent.design_variables || []) {
      const dvId = createNode(
        'designVar',
        {
          label: dv.name,
          expressionName: dv.name,
          minValue: dv.min,
          maxValue: dv.max,
          unit: dv.unit,
        },
        0,
        row
      );
      connect(dvId, modelId);
      row++;
    }

    // Extractors: column 2, each fed by the solver. Remember which canvas node
    // represents each intent extractor ID so objectives/constraints can link.
    const extractorNodeIds = new Map<string, string>();
    row = 0;
    for (const ext of intent.extractors || []) {
      const extId = createNode(
        'extractor',
        {
          label: ext.name,
          extractorId: ext.id,
          extractorName: ext.name,
          config: ext.config,
        },
        2,
        row
      );
      extractorNodeIds.set(ext.id, extId);
      connect(solverId, extId);
      row++;
    }

    // Objectives: column 3, linked to their extractor when it is present.
    row = 0;
    const objIds: string[] = [];
    for (const obj of intent.objectives || []) {
      const objId = createNode(
        'objective',
        {
          label: obj.name,
          name: obj.name,
          direction: obj.direction,
          weight: obj.weight,
        },
        3,
        row
      );
      objIds.push(objId);
      const sourceId = obj.extractor ? extractorNodeIds.get(obj.extractor) : undefined;
      if (sourceId) {
        connect(sourceId, objId);
      }
      row++;
    }

    // Constraints: same column, continuing in the rows below the objectives.
    const conIds: string[] = [];
    for (const con of intent.constraints || []) {
      const conId = createNode(
        'constraint',
        {
          label: con.name,
          name: con.name,
          operator: con.operator as any,
          value: con.value,
        },
        3,
        row
      );
      conIds.push(conId);
      const sourceId = con.extractor ? extractorNodeIds.get(con.extractor) : undefined;
      if (sourceId) {
        connect(sourceId, conId);
      }
      row++;
    }

    // Algorithm: column 4, fed by every objective and then every constraint.
    const algoId = createNode(
      'algorithm',
      {
        label: 'Algorithm',
        method: intent.optimization?.method as any,
        maxTrials: intent.optimization?.max_trials,
      },
      4,
      0
    );
    for (const objId of objIds) {
      connect(objId, algoId);
    }
    for (const conId of conIds) {
      connect(conId, algoId);
    }

    // Surrogate: column 5, only when the intent enables it.
    if (intent.surrogate?.enabled) {
      const surId = createNode(
        'surrogate',
        {
          label: 'Surrogate',
          enabled: true,
          modelType: intent.surrogate.type as any,
          minTrials: intent.surrogate.min_trials,
        },
        5,
        0
      );
      connect(algoId, surId);
    }

    set({
      nodes,
      edges,
      selectedNode: null,
      validation: createEmptyValidation(),
    });
  },

  // Convert the optimization_config.json format to an intent, then load it.
  loadFromConfig: (config) => {
    const intent: OptimizationIntent = {
      version: '1.0',
      source: 'canvas',
      timestamp: new Date().toISOString(),
      model: {
        path: config.model?.path,
        type: config.model?.type,
      },
      solver: {
        type: config.solver?.solution ? `SOL${config.solver.solution}` : undefined,
      },
      design_variables: (config.design_variables || []).map((dv) => ({
        name: dv.expression_name || dv.name,
        min: dv.lower,
        max: dv.upper,
      })),
      extractors: [], // Filled in below from the objectives and constraints.
      objectives: (config.objectives || []).map((obj) => ({
        name: obj.name,
        direction: (obj.direction as 'minimize' | 'maximize') || 'minimize',
        // `??` rather than `||`: an explicit weight of 0 must be preserved.
        weight: obj.weight ?? 1,
        extractor: obj.extractor || '',
      })),
      constraints: (config.constraints || []).map((con) => ({
        name: con.name,
        operator: con.type === 'upper' ? '<=' : '>=',
        value: con.value ?? 0,
        extractor: con.extractor || '',
      })),
      optimization: {
        method: config.method,
        max_trials: config.max_trials,
      },
      surrogate: config.surrogate
        ? {
            enabled: true,
            type: config.surrogate.type,
            min_trials: config.surrogate.min_trials,
          }
        : undefined,
    };

    // The config has no extractor list of its own: collect the distinct,
    // non-empty extractor IDs referenced by objectives and constraints.
    const extractorIds = new Set<string>();
    for (const obj of intent.objectives) {
      if (obj.extractor) extractorIds.add(obj.extractor);
    }
    for (const con of intent.constraints) {
      if (con.extractor) extractorIds.add(con.extractor);
    }
    intent.extractors = Array.from(extractorIds).map((id) => ({
      id,
      // Unknown IDs are shown as-is.
      name: EXTRACTOR_DISPLAY_NAMES.get(id) ?? id,
    }));

    get().loadFromIntent(intent);
  },
}));