import { create } from 'zustand';
import { Node, Edge, addEdge, applyNodeChanges, applyEdgeChanges, Connection, NodeChange, EdgeChange } from 'reactflow';
import { CanvasNodeData, NodeType } from '../lib/canvas/schema';
import { validateGraph, ValidationResult } from '../lib/canvas/validation';
import { serializeToIntent, OptimizationIntent } from '../lib/canvas/intent';

/**
 * State held by the canvas store: the React Flow graph (nodes/edges), the id
 * of the node currently selected in the UI, and the last validation result,
 * together with the actions that read or mutate them.
 */
interface CanvasState {
  nodes: Node[];
  edges: Edge[];
  selectedNode: string | null;
  validation: ValidationResult;

  // Actions
  onNodesChange: (changes: NodeChange[]) => void;
  onEdgesChange: (changes: EdgeChange[]) => void;
  onConnect: (connection: Connection) => void;
  addNode: (type: NodeType, position: { x: number; y: number }) => void;
  updateNodeData: (nodeId: string, data: Partial<CanvasNodeData>) => void;
  selectNode: (nodeId: string | null) => void;
  deleteSelected: () => void;
  validate: () => ValidationResult;
  toIntent: () => OptimizationIntent;
  clear: () => void;
  loadFromIntent: (intent: OptimizationIntent) => void;
  loadFromConfig: (config: OptimizationConfig) => void;
}

/**
 * Optimization config structure (from optimization_config.json).
 * Every field is optional; `loadFromConfig` substitutes defaults where needed.
 */
export interface OptimizationConfig {
  study_name?: string;
  model?: {
    path?: string;
    type?: string;
  };
  solver?: {
    type?: string;
    solution?: number;
  };
  design_variables?: Array<{
    name: string;
    expression_name?: string;
    lower: number;
    upper: number;
    type?: string;
  }>;
  objectives?: Array<{
    name: string;
    direction?: string;
    weight?: number;
    extractor?: string;
  }>;
  constraints?: Array<{
    name: string;
    type?: string;
    value?: number;
    extractor?: string;
  }>;
  method?: string;
  max_trials?: number;
  surrogate?: {
    type?: string;
    min_trials?: number;
  } | null;
}

// Module-level counter used to mint node ids ("node_1", "node_2", ...).
// It is reset by `clear` and `loadFromIntent`, which both replace every node.
let nodeIdCounter = 0;
const getNodeId = () => `node_${++nodeIdCounter}`;

/**
 * Default `data` payload for a freshly created node of `type`.
 * The label is the capitalised type name (except designVar), `configured`
 * starts false, and surrogate nodes start disabled.
 */
const getDefaultData = (type: NodeType): CanvasNodeData => {
  const base = { label: type.charAt(0).toUpperCase() + type.slice(1), configured: false };

  switch (type) {
    case 'model':
      return { ...base, type: 'model' };
    case 'solver':
      return { ...base, type: 'solver' };
    case 'designVar':
      return { ...base, type: 'designVar', label: 'Design Variable' };
    case 'extractor':
      return { ...base, type: 'extractor' };
    case 'objective':
      return { ...base, type: 'objective' };
    case 'constraint':
      return { ...base, type: 'constraint' };
    case 'algorithm':
      return { ...base, type: 'algorithm' };
    case 'surrogate':
      return { ...base, type: 'surrogate', enabled: false };
    default:
      return { ...base, type } as CanvasNodeData;
  }
};

/** Layout constants for auto-arrangement (pixels). */
const LAYOUT = {
  startX: 100,
  startY: 100,
  colWidth: 250,
  rowHeight: 120,
};

/**
 * A fresh "not yet validated" result. This is a factory rather than a shared
 * constant so that each store state gets its own error/warning arrays.
 */
const emptyValidation = (): ValidationResult => ({ valid: false, errors: [], warnings: [] });

/** Display names for known extractor ids; unknown ids are shown verbatim. */
const EXTRACTOR_NAMES: Record<string, string> = {
  'E1': 'Displacement',
  'E2': 'Frequency',
  'E3': 'Solid Stress',
  'E4': 'BDF Mass',
  'E5': 'CAD Mass',
  'E8': 'Zernike (OP2)',
  'E9': 'Zernike (CSV)',
  'E10': 'Zernike (RMS)',
};

/**
 * Convert the optimization_config.json format into an `OptimizationIntent`.
 *
 * Notes on the mapping:
 * - The solver type becomes `SOL<solution>`. `config.solver.type` is not used.
 * - A constraint with `type === 'upper'` maps to `<=`; any other type maps to `>=`.
 * - Extractors are not listed in the config. They are derived from the ids
 *   referenced by objectives (first) and constraints, de-duplicated in order.
 */
function configToIntent(config: OptimizationConfig): OptimizationIntent {
  const intent: OptimizationIntent = {
    version: '1.0',
    source: 'canvas',
    timestamp: new Date().toISOString(),
    model: {
      path: config.model?.path,
      type: config.model?.type,
    },
    solver: {
      type: config.solver?.solution ? `SOL${config.solver.solution}` : undefined,
    },
    design_variables: (config.design_variables || []).map(dv => ({
      name: dv.expression_name || dv.name,
      min: dv.lower,
      max: dv.upper,
    })),
    extractors: [], // filled in below from the objective/constraint references
    objectives: (config.objectives || []).map(obj => ({
      name: obj.name,
      direction: (obj.direction as 'minimize' | 'maximize') || 'minimize',
      // `??` (not `||`) so an explicit weight of 0 is preserved.
      weight: obj.weight ?? 1,
      extractor: obj.extractor || '',
    })),
    constraints: (config.constraints || []).map(con => ({
      name: con.name,
      operator: con.type === 'upper' ? '<=' : '>=',
      value: con.value ?? 0,
      extractor: con.extractor || '',
    })),
    optimization: {
      method: config.method,
      max_trials: config.max_trials,
    },
    surrogate: config.surrogate ? {
      enabled: true,
      type: config.surrogate.type,
      min_trials: config.surrogate.min_trials,
    } : undefined,
  };

  // Collect extractor ids in first-seen order; empty strings mean "none".
  const extractorIds = new Set<string>();
  for (const obj of intent.objectives) {
    if (obj.extractor) extractorIds.add(obj.extractor);
  }
  for (const con of intent.constraints) {
    if (con.extractor) extractorIds.add(con.extractor);
  }

  intent.extractors = Array.from(extractorIds).map(id => ({
    id,
    name: EXTRACTOR_NAMES[id] ?? id,
  }));

  return intent;
}

/** Zustand store backing the optimization canvas. */
export const useCanvasStore = create<CanvasState>((set, get) => ({
  nodes: [],
  edges: [],
  selectedNode: null,
  validation: emptyValidation(),

  onNodesChange: (changes) => {
    set({ nodes: applyNodeChanges(changes, get().nodes) });
  },

  onEdgesChange: (changes) => {
    set({ edges: applyEdgeChanges(changes, get().edges) });
  },

  onConnect: (connection) => {
    set({ edges: addEdge(connection, get().edges) });
  },

  addNode: (type, position) => {
    const newNode: Node = {
      id: getNodeId(),
      type,
      position,
      data: getDefaultData(type),
    };
    set({ nodes: [...get().nodes, newNode] });
  },

  /** Shallow-merge `data` into the data of the node with id `nodeId` (no-op if absent). */
  updateNodeData: (nodeId, data) => {
    set({
      nodes: get().nodes.map((node) =>
        node.id === nodeId ? { ...node, data: { ...node.data, ...data } } : node
      ),
    });
  },

  selectNode: (nodeId) => {
    set({ selectedNode: nodeId });
  },

  /** Remove the selected node and every edge touching it, then clear the selection. */
  deleteSelected: () => {
    const { selectedNode, nodes, edges } = get();
    if (!selectedNode) return;

    set({
      nodes: nodes.filter((n) => n.id !== selectedNode),
      edges: edges.filter((e) => e.source !== selectedNode && e.target !== selectedNode),
      selectedNode: null,
    });
  },

  /** Run `validateGraph` on the current graph, store the result, and return it. */
  validate: () => {
    const { nodes, edges } = get();
    const result = validateGraph(nodes, edges);
    set({ validation: result });
    return result;
  },

  toIntent: () => {
    const { nodes, edges } = get();
    return serializeToIntent(nodes, edges);
  },

  clear: () => {
    set({
      nodes: [],
      edges: [],
      selectedNode: null,
      validation: emptyValidation(),
    });
    nodeIdCounter = 0;
  },

  /**
   * Replace the whole graph with nodes built from `intent`, laid out in columns:
   * model (0), design variables (0, rows below the model), solver (1),
   * extractors (2), objectives then constraints (3), algorithm (4) and an
   * optional surrogate (5). Every created node is marked `configured: true`.
   * The node id counter is reset, so ids restart at node_1.
   */
  loadFromIntent: (intent) => {
    // Clear existing
    nodeIdCounter = 0;
    const nodes: Node[] = [];
    const edges: Edge[] = [];
    let col = 0;
    let row = 0;

    // Helper to create a node at the current (col, row) grid position.
    const createNode = (type: NodeType, data: Partial<CanvasNodeData>, colOffset = 0): string => {
      const id = getNodeId();
      nodes.push({
        id,
        type,
        position: {
          x: LAYOUT.startX + (col + colOffset) * LAYOUT.colWidth,
          y: LAYOUT.startY + row * LAYOUT.rowHeight,
        },
        data: { ...getDefaultData(type), ...data, configured: true } as CanvasNodeData,
      });
      return id;
    };

    // Model node (column 0)
    col = 0;
    const modelId = createNode('model', {
      label: 'Model',
      filePath: intent.model?.path,
      fileType: intent.model?.type as 'prt' | 'fem' | 'sim' | undefined,
    });

    // Solver node (column 1)
    col = 1;
    const solverId = createNode('solver', {
      label: 'Solver',
      // NOTE(review): `as any` because the intent's solver type is not tied to
      // the schema's solverType union here; consider narrowing it.
      solverType: intent.solver?.type as any,
    });
    edges.push({ id: `e_${modelId}_${solverId}`, source: modelId, target: solverId });

    // Design variables (column 0, one per row below the model)
    col = 0;
    row = 1;
    const dvIds: string[] = [];
    for (const dv of intent.design_variables || []) {
      const dvId = createNode('designVar', {
        label: dv.name,
        expressionName: dv.name,
        minValue: dv.min,
        maxValue: dv.max,
        unit: dv.unit,
      });
      dvIds.push(dvId);
      edges.push({ id: `e_${dvId}_${modelId}`, source: dvId, target: modelId });
      row++;
    }

    // Extractors (column 2), each fed by the solver
    col = 2;
    row = 0;
    const extractorMap: Record<string, string> = {}; // extractor id -> canvas node id
    for (const ext of intent.extractors || []) {
      const extId = createNode('extractor', {
        label: ext.name,
        extractorId: ext.id,
        extractorName: ext.name,
        config: ext.config,
      });
      extractorMap[ext.id] = extId;
      edges.push({ id: `e_${solverId}_${extId}`, source: solverId, target: extId });
      row++;
    }

    // Objectives (column 3)
    col = 3;
    row = 0;
    const objIds: string[] = [];
    for (const obj of intent.objectives || []) {
      const objId = createNode('objective', {
        label: obj.name,
        name: obj.name,
        direction: obj.direction,
        weight: obj.weight,
      });
      objIds.push(objId);

      // Connect to the extractor only if it was created above.
      if (obj.extractor && extractorMap[obj.extractor]) {
        edges.push({
          id: `e_${extractorMap[obj.extractor]}_${objId}`,
          source: extractorMap[obj.extractor],
          target: objId,
        });
      }
      row++;
    }

    // Constraints (column 3, continuing below the objectives)
    const conIds: string[] = [];
    for (const con of intent.constraints || []) {
      const conId = createNode('constraint', {
        label: con.name,
        name: con.name,
        operator: con.operator as any,
        value: con.value,
      });
      conIds.push(conId);

      if (con.extractor && extractorMap[con.extractor]) {
        edges.push({
          id: `e_${extractorMap[con.extractor]}_${conId}`,
          source: extractorMap[con.extractor],
          target: conId,
        });
      }
      row++;
    }

    // Algorithm (column 4)
    col = 4;
    row = 0;
    const algoId = createNode('algorithm', {
      label: 'Algorithm',
      method: intent.optimization?.method as any,
      maxTrials: intent.optimization?.max_trials,
    });

    // Connect all objectives and constraints to the algorithm.
    for (const objId of objIds) {
      edges.push({ id: `e_${objId}_${algoId}`, source: objId, target: algoId });
    }
    for (const conId of conIds) {
      edges.push({ id: `e_${conId}_${algoId}`, source: conId, target: algoId });
    }

    // Surrogate (column 5, only when enabled)
    if (intent.surrogate?.enabled) {
      col = 5;
      const surId = createNode('surrogate', {
        label: 'Surrogate',
        enabled: true,
        modelType: intent.surrogate.type as any,
        minTrials: intent.surrogate.min_trials,
      });
      edges.push({ id: `e_${algoId}_${surId}`, source: algoId, target: surId });
    }

    set({
      nodes,
      edges,
      selectedNode: null,
      validation: emptyValidation(),
    });
  },

  /** Convert an optimization_config.json object to an intent and load it. */
  loadFromConfig: (config) => {
    get().loadFromIntent(configToIntent(config));
  },
}));