import { create } from 'zustand';
import {
  Node,
  Edge,
  addEdge,
  applyNodeChanges,
  applyEdgeChanges,
  Connection,
  NodeChange,
  EdgeChange,
} from 'reactflow';
import { CanvasNodeData, NodeType } from '../lib/canvas/schema';
import { validateGraph, ValidationResult } from '../lib/canvas/validation';
import { serializeToIntent, OptimizationIntent } from '../lib/canvas/intent';

/** Shape of the canvas store: graph contents, selection, last validation result, and actions. */
interface CanvasState {
  nodes: Node[];
  edges: Edge[];
  /** Selected node id. `selectNode` clears `selectedEdge`, and `selectEdge` clears this. */
  selectedNode: string | null;
  /** Selected edge id. `selectEdge` clears `selectedNode`, and `selectNode` clears this. */
  selectedEdge: string | null;
  /** Result of the last `validate()` call; reset to an empty, non-valid result by `clear` and the loaders. */
  validation: ValidationResult;

  // Actions
  onNodesChange: (changes: NodeChange[]) => void;
  onEdgesChange: (changes: EdgeChange[]) => void;
  onConnect: (connection: Connection) => void;
  addNode: (type: NodeType, position: { x: number; y: number }) => void;
  updateNodeData: (nodeId: string, data: Partial<CanvasNodeData>) => void;
  selectNode: (nodeId: string | null) => void;
  selectEdge: (edgeId: string | null) => void;
  deleteSelected: () => void;
  deleteEdge: (edgeId: string) => void;
  validate: () => ValidationResult;
  toIntent: () => OptimizationIntent;
  clear: () => void;
  loadFromIntent: (intent: OptimizationIntent) => void;
  loadFromConfig: (config: OptimizationConfig) => void;
}

// Optimization config structure (from optimization_config.json)
export interface OptimizationConfig {
  study_name?: string;
  model?: {
    path?: string;
    type?: string;
  };
  solver?: {
    type?: string;
    solution?: number;
  };
  design_variables?: Array<{
    name: string;
    expression_name?: string;
    lower: number;
    upper: number;
    type?: string;
  }>;
  objectives?: Array<{
    name: string;
    direction?: string;
    weight?: number;
    extractor?: string;
  }>;
  constraints?: Array<{
    name: string;
    type?: string;
    value?: number;
    extractor?: string;
  }>;
  method?: string;
  max_trials?: number;
  surrogate?: {
    type?: string;
    min_trials?: number;
  } | null;
  /**
   * Alternative location for the sampler and trial count. `loadFromConfig` reads
   * these only when the top-level `method` / `max_trials` are not set.
   */
  optimization?: {
    sampler?: string;
    n_trials?: number;
  };
}

let nodeIdCounter = 0;
const getNodeId = () => `node_${++nodeIdCounter}`;

/** A fresh "not yet validated" result; a new object per call so state never shares arrays. */
const emptyValidation = (): ValidationResult => ({ valid: false, errors: [], warnings: [] });

/**
 * Default data for a newly created node of `type`: a capitalized label and
 * `configured: false`, plus per-type tweaks (design variables get a friendlier
 * label, surrogates start disabled).
 */
const getDefaultData = (type: NodeType): CanvasNodeData => {
  const base = { label: type.charAt(0).toUpperCase() + type.slice(1), configured: false };

  switch (type) {
    case 'model':
      return { ...base, type: 'model' };
    case 'solver':
      return { ...base, type: 'solver' };
    case 'designVar':
      return { ...base, type: 'designVar', label: 'Design Variable' };
    case 'extractor':
      return { ...base, type: 'extractor' };
    case 'objective':
      return { ...base, type: 'objective' };
    case 'constraint':
      return { ...base, type: 'constraint' };
    case 'algorithm':
      return { ...base, type: 'algorithm' };
    case 'surrogate':
      return { ...base, type: 'surrogate', enabled: false };
    default:
      return { ...base, type } as CanvasNodeData;
  }
};

// Layout constants for auto-arrangement (grid used by `loadFromIntent`).
const LAYOUT = {
  startX: 100,
  startY: 100,
  colWidth: 250,
  rowHeight: 120,
};

/** Display names for extractor ids referenced by optimization_config.json. */
const CONFIG_EXTRACTOR_NAMES = new Map<string, string>([
  ['E1', 'Displacement'],
  ['E2', 'Frequency'],
  ['E3', 'Solid Stress'],
  ['E4', 'BDF Mass'],
  ['E5', 'CAD Mass'],
  ['E8', 'Zernike (OP2)'],
  ['E9', 'Zernike (CSV)'],
  ['E10', 'Zernike (RMS)'],
]);

/** Extractor assumed when a config has objectives but none names an extractor (CAD Mass). */
const DEFAULT_CONFIG_EXTRACTOR = 'E5';

/**
 * Global store for the optimization canvas: React Flow nodes/edges, selection,
 * validation, and conversion to/from `OptimizationIntent` and `OptimizationConfig`.
 */
export const useCanvasStore = create<CanvasState>((set, get) => ({
  nodes: [],
  edges: [],
  selectedNode: null,
  selectedEdge: null,
  validation: emptyValidation(),

  onNodesChange: (changes) => {
    const { nodes, selectedNode } = get();
    const next = applyNodeChanges(changes, nodes);
    // If a 'remove' change deletes the selected node, clear the selection so it doesn't dangle.
    const selectionRemoved = changes.some((c) => c.type === 'remove' && c.id === selectedNode);
    set(selectionRemoved ? { nodes: next, selectedNode: null } : { nodes: next });
  },

  onEdgesChange: (changes) => {
    const { edges, selectedEdge } = get();
    const next = applyEdgeChanges(changes, edges);
    // Same as onNodesChange: don't keep a selection pointing at a removed edge.
    const selectionRemoved = changes.some((c) => c.type === 'remove' && c.id === selectedEdge);
    set(selectionRemoved ? { edges: next, selectedEdge: null } : { edges: next });
  },

  onConnect: (connection) => {
    set({ edges: addEdge(connection, get().edges) });
  },

  addNode: (type, position) => {
    const newNode: Node = {
      id: getNodeId(),
      type,
      position,
      data: getDefaultData(type),
    };
    set({ nodes: [...get().nodes, newNode] });
  },

  /** Shallow-merge `data` into the data of node `nodeId`; other nodes are untouched. */
  updateNodeData: (nodeId, data) => {
    set({
      nodes: get().nodes.map((node) =>
        node.id === nodeId ? { ...node, data: { ...node.data, ...data } } : node
      ),
    });
  },

  selectNode: (nodeId) => {
    set({ selectedNode: nodeId, selectedEdge: null });
  },

  selectEdge: (edgeId) => {
    set({ selectedEdge: edgeId, selectedNode: null });
  },

  /** Delete the selected edge, or else the selected node together with its incident edges. */
  deleteSelected: () => {
    const { selectedNode, selectedEdge, nodes, edges } = get();

    if (selectedEdge) {
      set({
        edges: edges.filter((e) => e.id !== selectedEdge),
        selectedEdge: null,
      });
      return;
    }

    if (!selectedNode) return;

    set({
      nodes: nodes.filter((n) => n.id !== selectedNode),
      edges: edges.filter((e) => e.source !== selectedNode && e.target !== selectedNode),
      selectedNode: null,
    });
  },

  /** Remove edge `edgeId`; the edge selection is cleared only if it was that edge. */
  deleteEdge: (edgeId) => {
    const { edges, selectedEdge } = get();
    set({
      edges: edges.filter((e) => e.id !== edgeId),
      selectedEdge: selectedEdge === edgeId ? null : selectedEdge,
    });
  },

  /** Run `validateGraph` on the current graph, store the result, and return it. */
  validate: () => {
    const { nodes, edges } = get();
    const result = validateGraph(nodes, edges);
    set({ validation: result });
    return result;
  },

  toIntent: () => {
    const { nodes, edges } = get();
    return serializeToIntent(nodes, edges);
  },

  /** Empty the canvas and restart node ids at `node_1`. */
  clear: () => {
    set({
      nodes: [],
      edges: [],
      selectedNode: null,
      selectedEdge: null,
      validation: emptyValidation(),
    });
    nodeIdCounter = 0;
  },

  /**
   * Replace the canvas with a graph built from `intent`, laid out on the
   * `LAYOUT` grid: model/design variables (col 0), solver (1), extractors (2),
   * objectives then constraints (3), algorithm (4), optional surrogate (5).
   */
  loadFromIntent: (intent) => {
    // The graph is rebuilt from scratch, so node ids restart at node_1.
    nodeIdCounter = 0;
    const nodes: Node[] = [];
    const edges: Edge[] = [];

    // Grid cursor read by createNode; each section below positions it first.
    let col = 0;
    let row = 0;

    /** Push a configured node at the current (col, row) grid cell and return its id. */
    const createNode = (type: NodeType, data: Partial<CanvasNodeData>): string => {
      const id = getNodeId();
      nodes.push({
        id,
        type,
        position: {
          x: LAYOUT.startX + col * LAYOUT.colWidth,
          y: LAYOUT.startY + row * LAYOUT.rowHeight,
        },
        data: { ...getDefaultData(type), ...data, configured: true } as CanvasNodeData,
      });
      return id;
    };

    /** Add a `source -> target` edge with the id scheme `e_<source>_<target>`. */
    const connect = (source: string, target: string): void => {
      edges.push({ id: `e_${source}_${target}`, source, target });
    };

    // NOTE(review): the `as any` casts below pass free-form intent strings into the
    // schema's typed fields unchecked; consider validating them against the schema.

    // Model (column 0) -> solver (column 1), both on row 0.
    col = 0;
    const modelId = createNode('model', {
      label: 'Model',
      filePath: intent.model?.path,
      fileType: intent.model?.type as 'prt' | 'fem' | 'sim' | undefined,
    });

    col = 1;
    const solverId = createNode('solver', {
      label: 'Solver',
      solverType: intent.solver?.type as any,
    });
    connect(modelId, solverId);

    // Design variables: column 0, one per row below the model; each feeds the model.
    col = 0;
    row = 1;
    for (const dv of intent.design_variables || []) {
      const dvId = createNode('designVar', {
        label: dv.name,
        expressionName: dv.name,
        minValue: dv.min,
        maxValue: dv.max,
        unit: dv.unit,
      });
      connect(dvId, modelId);
      row++;
    }

    // Extractors (column 2), each fed by the solver. Map: intent extractor id -> node id.
    col = 2;
    row = 0;
    const extractorMap = new Map<string, string>();
    for (const ext of intent.extractors || []) {
      const extId = createNode('extractor', {
        label: ext.name,
        extractorId: ext.id,
        extractorName: ext.name,
        config: ext.config,
      });
      extractorMap.set(ext.id, extId);
      connect(solverId, extId);
      row++;
    }

    // Objectives (column 3), connected to their extractor when it is known.
    col = 3;
    row = 0;
    const objIds: string[] = [];
    for (const obj of intent.objectives || []) {
      const objId = createNode('objective', {
        label: obj.name,
        name: obj.name,
        direction: obj.direction,
        weight: obj.weight,
      });
      objIds.push(objId);
      const extNodeId = obj.extractor ? extractorMap.get(obj.extractor) : undefined;
      if (extNodeId) connect(extNodeId, objId);
      row++;
    }

    // Constraints: still column 3, continuing on the rows after the objectives.
    const conIds: string[] = [];
    for (const con of intent.constraints || []) {
      const conId = createNode('constraint', {
        label: con.name,
        name: con.name,
        operator: con.operator as any,
        value: con.value,
      });
      conIds.push(conId);
      const extNodeId = con.extractor ? extractorMap.get(con.extractor) : undefined;
      if (extNodeId) connect(extNodeId, conId);
      row++;
    }

    // Algorithm (column 4, row 0), fed by every objective and constraint.
    col = 4;
    row = 0;
    const algoId = createNode('algorithm', {
      label: 'Algorithm',
      method: intent.optimization?.method as any,
      maxTrials: intent.optimization?.max_trials,
    });
    for (const objId of objIds) connect(objId, algoId);
    for (const conId of conIds) connect(conId, algoId);

    // Surrogate (column 5, row 0), only when enabled in the intent.
    if (intent.surrogate?.enabled) {
      col = 5;
      const surId = createNode('surrogate', {
        label: 'Surrogate',
        enabled: true,
        modelType: intent.surrogate.type as any,
        minTrials: intent.surrogate.min_trials,
      });
      connect(algoId, surId);
    }

    set({
      nodes,
      edges,
      selectedNode: null,
      selectedEdge: null,
      validation: emptyValidation(),
    });
  },

  /**
   * Replace the canvas with a graph built from an optimization_config.json
   * structure. Extractors are not listed in the config; they are inferred from
   * the `extractor` ids on objectives and constraints.
   */
  loadFromConfig: (config) => {
    // The graph is rebuilt from scratch, so node ids restart at node_1.
    nodeIdCounter = 0;
    const nodes: Node[] = [];
    const edges: Edge[] = [];

    // Column x positions and row spacing for this layout.
    const COLS = {
      modelDvar: 50,
      solver: 280,
      extractor: 510,
      objCon: 740,
      algo: 970,
      surrogate: 1200,
    };
    const ROW_HEIGHT = 100;
    const START_Y = 50;
    const rowY = (rowIndex: number): number => START_Y + rowIndex * ROW_HEIGHT;

    /** Push a configured node at (x, y) and return its id. */
    const createNode = (
      type: NodeType,
      x: number,
      y: number,
      data: Partial<CanvasNodeData>
    ): string => {
      const id = getNodeId();
      nodes.push({
        id,
        type,
        position: { x, y },
        data: { ...getDefaultData(type), ...data, configured: true } as CanvasNodeData,
      });
      return id;
    };

    // NOTE(review): the `as any` casts below pass free-form config strings into the
    // schema's typed fields unchecked; consider validating them against the schema.

    // 1. Model node (labelled with the study name when present).
    const modelId = createNode('model', COLS.modelDvar, START_Y, {
      label: config.study_name || 'Model',
      filePath: config.model?.path,
      fileType: config.model?.type as 'prt' | 'fem' | 'sim' | undefined,
    });

    // 2. Solver node; a numeric solution id N becomes solver type "SOL<N>".
    const solverType = config.solver?.solution ? `SOL${config.solver.solution}` : undefined;
    const solverId = createNode('solver', COLS.solver, START_Y, {
      label: 'Solver',
      solverType: solverType as any,
    });
    edges.push({ id: `e_model_solver`, source: modelId, target: solverId });

    // 3. Design variables: column 0, starting on the row below the model.
    (config.design_variables ?? []).forEach((dv, index) => {
      const dvRow = index + 1;
      const dvId = createNode('designVar', COLS.modelDvar, rowY(dvRow), {
        label: dv.expression_name || dv.name,
        expressionName: dv.expression_name || dv.name,
        minValue: dv.lower,
        maxValue: dv.upper,
      });
      edges.push({ id: `e_dv_${dvRow}_model`, source: dvId, target: modelId });
    });

    // 4. Extractors: unique ids referenced by objectives, then constraints, in first-seen order.
    const extractorIds = new Set<string>();
    for (const item of [...(config.objectives ?? []), ...(config.constraints ?? [])]) {
      if (item.extractor) extractorIds.add(item.extractor);
    }
    if (extractorIds.size === 0 && (config.objectives?.length ?? 0) > 0) {
      extractorIds.add(DEFAULT_CONFIG_EXTRACTOR);
    }

    // Map: config extractor id -> canvas node id.
    const extractorMap = new Map<string, string>();
    let extRow = 0;
    for (const extId of extractorIds) {
      const displayName = CONFIG_EXTRACTOR_NAMES.get(extId) ?? extId;
      const nodeId = createNode('extractor', COLS.extractor, rowY(extRow), {
        label: displayName,
        extractorId: extId,
        extractorName: displayName,
      });
      extractorMap.set(extId, nodeId);
      edges.push({ id: `e_solver_ext_${extRow}`, source: solverId, target: nodeId });
      extRow++;
    }

    // Objectives/constraints that name no extractor are wired to the first extractor node.
    const fallbackExtNodeId: string | undefined = [...extractorMap.values()][0];

    // 5. Objectives (column 3, from row 0).
    const objectives = config.objectives ?? [];
    const objIds = objectives.map((obj, objRow) => {
      const objId = createNode('objective', COLS.objCon, rowY(objRow), {
        label: obj.name,
        name: obj.name,
        direction: (obj.direction as 'minimize' | 'maximize') || 'minimize',
        // `??` keeps an explicit weight of 0; only a missing weight defaults to 1.
        weight: obj.weight ?? 1,
      });
      const extNodeId = obj.extractor ? extractorMap.get(obj.extractor) : fallbackExtNodeId;
      if (extNodeId) {
        edges.push({ id: `e_ext_obj_${objRow}`, source: extNodeId, target: objId });
      }
      return objId;
    });

    // 6. Constraints (column 3, on the rows after the objectives).
    //    type 'upper' maps to '<='; any other value (or none) maps to '>='.
    const conIds = (config.constraints ?? []).map((con, index) => {
      const conRow = objectives.length + index;
      const conId = createNode('constraint', COLS.objCon, rowY(conRow), {
        label: con.name,
        name: con.name,
        operator: (con.type === 'upper' ? '<=' : '>=') as any,
        value: con.value ?? 0,
      });
      const extNodeId = con.extractor ? extractorMap.get(con.extractor) : fallbackExtNodeId;
      if (extNodeId) {
        edges.push({ id: `e_ext_con_${conRow}`, source: extNodeId, target: conId });
      }
      return conId;
    });

    // 7. Algorithm node: top-level fields win, then the `optimization` section,
    //    then defaults (empty / 0 values fall through because of `||`).
    const method = config.method || config.optimization?.sampler || 'TPE';
    const maxTrials = config.max_trials || config.optimization?.n_trials || 100;
    const algoId = createNode('algorithm', COLS.algo, START_Y, {
      label: 'Algorithm',
      method: method as any,
      maxTrials,
    });
    objIds.forEach((objId, i) => {
      edges.push({ id: `e_obj_${i}_algo`, source: objId, target: algoId });
    });
    conIds.forEach((conId, i) => {
      edges.push({ id: `e_con_${i}_algo`, source: conId, target: algoId });
    });

    // 8. Surrogate node, whenever a surrogate section is present (non-null).
    if (config.surrogate) {
      const surId = createNode('surrogate', COLS.surrogate, START_Y, {
        label: 'Surrogate',
        enabled: true,
        modelType: config.surrogate.type as any,
        minTrials: config.surrogate.min_trials,
      });
      edges.push({ id: `e_algo_sur`, source: algoId, target: surId });
    }

    set({
      nodes,
      edges,
      selectedNode: null,
      selectedEdge: null,
      validation: emptyValidation(),
    });
  },
}));