From 47f8b501121b08f63f92cdf557c6adc532bf38f1 Mon Sep 17 00:00:00 2001 From: Anto01 Date: Tue, 20 Jan 2026 14:14:14 -0500 Subject: [PATCH] fix(canvas): Bug fixes for node movement, drag-drop, config panel, and introspection - SpecRenderer: Add localNodes state with applyNodeChanges for smooth node dragging - SpecRenderer: Fix getDefaultNodeData() - extractor uses 'custom_function' type with function definition - SpecRenderer: Fix constraint default - use constraint_type instead of type - CanvasView: Show config panel INSTEAD of chat when node selected (not blocked) - NodeConfigPanelV2: Enable showHeader for code editor toolbar (Generate/Snippets/Validate/Test buttons) - NodeConfigPanelV2: Pass studyId to IntrospectionPanel - IntrospectionPanel: Accept studyId prop and use correct API endpoint - optimization.py: Search multiple directories for model files including 1_setup/model/ --- .../backend/api/routes/optimization.py | 1995 ++++++++++------- .../src/components/canvas/SpecRenderer.tsx | 43 +- .../canvas/panels/IntrospectionPanel.tsx | 39 +- .../canvas/panels/NodeConfigPanelV2.tsx | 49 +- .../frontend/src/pages/CanvasView.tsx | 10 +- 5 files changed, 1214 insertions(+), 922 deletions(-) diff --git a/atomizer-dashboard/backend/api/routes/optimization.py b/atomizer-dashboard/backend/api/routes/optimization.py index f4d10568..af870ee8 100644 --- a/atomizer-dashboard/backend/api/routes/optimization.py +++ b/atomizer-dashboard/backend/api/routes/optimization.py @@ -66,7 +66,7 @@ def resolve_study_path(study_id: str) -> Path: # Scan topic folders for nested structure for topic_dir in STUDIES_DIR.iterdir(): - if topic_dir.is_dir() and not topic_dir.name.startswith('.'): + if topic_dir.is_dir() and not topic_dir.name.startswith("."): study_dir = topic_dir / study_id if study_dir.exists() and study_dir.is_dir(): if _is_valid_study_dir(study_dir): @@ -78,9 +78,9 @@ def resolve_study_path(study_id: str) -> Path: def _is_valid_study_dir(study_dir: Path) -> bool: """Check if a directory is a valid study directory.""" return ( - (study_dir / "1_setup").exists() or - (study_dir / "optimization_config.json").exists() or - (study_dir / "atomizer_spec.json").exists() + (study_dir / "1_setup").exists() + or (study_dir / "optimization_config.json").exists() + or (study_dir / "atomizer_spec.json").exists() ) @@ -103,13 +103,13 @@ def is_optimization_running(study_id: str) -> bool: except HTTPException: return False - for proc in psutil.process_iter(['pid', 'name', 'cmdline', 'cwd']): + for proc in psutil.process_iter(["pid", "name", "cmdline", "cwd"]): try: - cmdline = proc.info.get('cmdline') or [] - cmdline_str = ' '.join(cmdline) if cmdline else '' + cmdline = proc.info.get("cmdline") or [] + cmdline_str = " ".join(cmdline) if cmdline else "" # Check if this is a Python process running run_optimization.py for this study - if 'python' in cmdline_str.lower() and 'run_optimization' in cmdline_str: + if "python" in cmdline_str.lower() and "run_optimization" in cmdline_str: if study_id in cmdline_str or str(study_dir) in cmdline_str: return True except (psutil.NoSuchProcess, psutil.AccessDenied): @@ -118,7 +118,9 @@ def is_optimization_running(study_id: str) -> bool: return False -def get_accurate_study_status(study_id: str, trial_count: int, total_trials: int, has_db: bool) -> str: +def get_accurate_study_status( + study_id: str, trial_count: int, total_trials: int, has_db: bool +) -> str: """Determine accurate study status based on multiple factors. Status can be: @@ -218,14 +220,14 @@ def _load_study_info(study_dir: Path, topic: Optional[str] = None) -> Optional[d trial_count = len(history) if history: # Find best trial - best_trial = min(history, key=lambda x: x['objective']) - best_value = best_trial['objective'] + best_trial = min(history, key=lambda x: x["objective"]) + best_value = best_trial["objective"] # Get total trials from config (supports both formats) total_trials = ( - config.get('optimization_settings', {}).get('n_trials') or - config.get('optimization', {}).get('n_trials') or - config.get('trials', {}).get('n_trials', 50) + config.get("optimization_settings", {}).get("n_trials") + or config.get("optimization", {}).get("n_trials") + or config.get("trials", {}).get("n_trials", 50) ) # Get accurate status using process detection @@ -259,15 +261,12 @@ def _load_study_info(study_dir: Path, topic: Optional[str] = None) -> Optional[d "name": study_dir.name.replace("_", " ").title(), "topic": topic, # NEW: topic field for grouping "status": status, - "progress": { - "current": trial_count, - "total": total_trials - }, + "progress": {"current": trial_count, "total": total_trials}, "best_value": best_value, - "target": config.get('target', {}).get('value'), + "target": config.get("target", {}).get("value"), "path": str(study_dir), "created_at": created_at, - "last_modified": last_modified + "last_modified": last_modified, } @@ -289,7 +288,7 @@ async def list_studies(): for item in STUDIES_DIR.iterdir(): if not item.is_dir(): continue - if item.name.startswith('.'): + if item.name.startswith("."): continue # Check if this is a study (flat structure) or a topic folder (nested structure) @@ -306,11 +305,13 @@ async def list_studies(): for sub_item in item.iterdir(): if not sub_item.is_dir(): continue - if sub_item.name.startswith('.'): + if sub_item.name.startswith("."): continue # Check if this subdirectory is a study - sub_is_study = (sub_item / "1_setup").exists() or (sub_item / "optimization_config.json").exists() + sub_is_study = (sub_item / "1_setup").exists() or ( + sub_item / "optimization_config.json" + ).exists() if sub_is_study: study_info = _load_study_info(sub_item, topic=item.name) if study_info: @@ -321,6 +322,7 @@ async def list_studies(): except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to list studies: {str(e)}") + @router.get("/studies/{study_id}/status") async def get_study_status(study_id: str): """Get detailed status of a specific study""" @@ -370,23 +372,26 @@ async def get_study_status(study_id: str): trial_id, trial_number, best_value = result # Get parameters for this trial - cursor.execute(""" + cursor.execute( + """ SELECT param_name, param_value FROM trial_params WHERE trial_id = ? - """, (trial_id,)) + """, + (trial_id,), + ) params = {row[0]: row[1] for row in cursor.fetchall()} best_trial = { "trial_number": trial_number, "objective": best_value, "design_variables": params, - "results": {"first_frequency": best_value} + "results": {"first_frequency": best_value}, } conn.close() - total_trials = config.get('optimization_settings', {}).get('n_trials', 50) + total_trials = config.get("optimization_settings", {}).get("n_trials", 50) status = get_accurate_study_status(study_id, trial_count, total_trials, True) return { @@ -395,11 +400,11 @@ async def get_study_status(study_id: str): "progress": { "current": trial_count, "total": total_trials, - "percentage": (trial_count / total_trials * 100) if total_trials > 0 else 0 + "percentage": (trial_count / total_trials * 100) if total_trials > 0 else 0, }, "best_trial": best_trial, "pruned_trials": pruned_count, - "config": config + "config": config, } # Legacy: Read from JSON history @@ -407,20 +412,20 @@ async def get_study_status(study_id: str): return { "study_id": study_id, "status": "not_started", - "progress": {"current": 0, "total": config.get('trials', {}).get('n_trials', 50)}, - "config": config + "progress": {"current": 0, "total": config.get("trials", {}).get("n_trials", 50)}, + "config": config, } with open(history_file) as f: history = json.load(f) trial_count = len(history) - total_trials = config.get('trials', {}).get('n_trials', 50) + total_trials = config.get("trials", {}).get("n_trials", 50) # Find best trial best_trial = None if history: - best_trial = min(history, key=lambda x: x['objective']) + best_trial = min(history, key=lambda x: x["objective"]) # Check for pruning data pruning_file = results_dir / "pruning_history.json" @@ -438,11 +443,11 @@ async def get_study_status(study_id: str): "progress": { "current": trial_count, "total": total_trials, - "percentage": (trial_count / total_trials * 100) if total_trials > 0 else 0 + "percentage": (trial_count / total_trials * 100) if total_trials > 0 else 0, }, "best_trial": best_trial, "pruned_trials": pruned_count, - "config": config + "config": config, } except FileNotFoundError: @@ -450,6 +455,7 @@ async def get_study_status(study_id: str): except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to get study status: {str(e)}") + @router.get("/studies/{study_id}/history") async def get_optimization_history(study_id: str, limit: Optional[int] = None): """Get optimization history (all trials)""" @@ -467,33 +473,42 @@ async def get_optimization_history(study_id: str, limit: Optional[int] = None): # Get all completed trials FROM ALL STUDIES in the database # This handles adaptive optimizations that create multiple Optuna studies # (e.g., v11_fea for FEA trials, v11_iter1_nn for NN trials, etc.) - cursor.execute(""" + cursor.execute( + """ SELECT t.trial_id, t.number, t.datetime_start, t.datetime_complete, s.study_name FROM trials t JOIN studies s ON t.study_id = s.study_id WHERE t.state = 'COMPLETE' ORDER BY t.datetime_start DESC - """ + (f" LIMIT {limit}" if limit else "")) + """ + + (f" LIMIT {limit}" if limit else "") + ) trial_rows = cursor.fetchall() trials = [] for trial_id, trial_num, start_time, end_time, study_name in trial_rows: # Get objectives for this trial - cursor.execute(""" + cursor.execute( + """ SELECT value FROM trial_values WHERE trial_id = ? ORDER BY objective - """, (trial_id,)) + """, + (trial_id,), + ) values = [row[0] for row in cursor.fetchall()] # Get parameters for this trial - cursor.execute(""" + cursor.execute( + """ SELECT param_name, param_value FROM trial_params WHERE trial_id = ? - """, (trial_id,)) + """, + (trial_id,), + ) params = {} for param_name, param_value in cursor.fetchall(): try: @@ -502,11 +517,14 @@ async def get_optimization_history(study_id: str, limit: Optional[int] = None): params[param_name] = param_value # Get user attributes (extracted results: mass, frequency, stress, displacement, etc.) - cursor.execute(""" + cursor.execute( + """ SELECT key, value_json FROM trial_user_attributes WHERE trial_id = ? - """, (trial_id,)) + """, + (trial_id,), + ) user_attrs = {} for key, value_json in cursor.fetchall(): try: @@ -524,7 +542,9 @@ async def get_optimization_history(study_id: str, limit: Optional[int] = None): # Include numeric values and lists of numbers if isinstance(val, (int, float)): results[key] = val - elif isinstance(val, list) and len(val) > 0 and isinstance(val[0], (int, float)): + elif ( + isinstance(val, list) and len(val) > 0 and isinstance(val[0], (int, float)) + ): # For lists, store as-is (e.g., Zernike coefficients) results[key] = val elif key == "objectives" and isinstance(val, dict): @@ -541,7 +561,9 @@ async def get_optimization_history(study_id: str, limit: Optional[int] = None): design_vars_from_attrs = user_attrs.get("design_vars", {}) # Merge with params (prefer user_attrs design_vars if available) - final_design_vars = {**params, **design_vars_from_attrs} if design_vars_from_attrs else params + final_design_vars = ( + {**params, **design_vars_from_attrs} if design_vars_from_attrs else params + ) # Extract source for FEA vs NN differentiation source = user_attrs.get("source", "FEA") # Default to FEA for legacy studies @@ -553,20 +575,24 @@ async def get_optimization_history(study_id: str, limit: Optional[int] = None): # trial_id is unique across all studies in the database unique_trial_num = iter_num if iter_num is not None else trial_id - trials.append({ - "trial_number": unique_trial_num, - "trial_id": trial_id, # Keep original for debugging - "optuna_trial_num": trial_num, # Keep original Optuna trial number - "objective": values[0] if len(values) > 0 else None, # Primary objective - "objectives": values if len(values) > 1 else None, # All objectives for multi-objective - "design_variables": final_design_vars, # Use merged design vars - "results": results, - "user_attrs": user_attrs, # Include all user attributes - "source": source, # FEA or NN - "start_time": start_time, - "end_time": end_time, - "study_name": study_name # Include for debugging - }) + trials.append( + { + "trial_number": unique_trial_num, + "trial_id": trial_id, # Keep original for debugging + "optuna_trial_num": trial_num, # Keep original Optuna trial number + "objective": values[0] if len(values) > 0 else None, # Primary objective + "objectives": values + if len(values) > 1 + else None, # All objectives for multi-objective + "design_variables": final_design_vars, # Use merged design vars + "results": results, + "user_attrs": user_attrs, # Include all user attributes + "source": source, # FEA or NN + "start_time": start_time, + "end_time": end_time, + "study_name": study_name, # Include for debugging + } + ) conn.close() return {"trials": trials} @@ -589,6 +615,7 @@ async def get_optimization_history(study_id: str, limit: Optional[int] = None): except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to get history: {str(e)}") + @router.get("/studies/{study_id}/pruning") async def get_pruning_history(study_id: str): """Get pruning diagnostics from Optuna database or legacy JSON file""" @@ -615,19 +642,25 @@ async def get_pruning_history(study_id: str): pruned_trials = [] for trial_id, trial_num, start_time, end_time in pruned_rows: # Get parameters for this trial - cursor.execute(""" + cursor.execute( + """ SELECT param_name, param_value FROM trial_params WHERE trial_id = ? - """, (trial_id,)) + """, + (trial_id,), + ) params = {row[0]: row[1] for row in cursor.fetchall()} # Get user attributes (may contain pruning cause) - cursor.execute(""" + cursor.execute( + """ SELECT key, value_json FROM trial_user_attributes WHERE trial_id = ? - """, (trial_id,)) + """, + (trial_id,), + ) user_attrs = {} for key, value_json in cursor.fetchall(): try: @@ -635,13 +668,15 @@ async def get_pruning_history(study_id: str): except (ValueError, TypeError): user_attrs[key] = value_json - pruned_trials.append({ - "trial_number": trial_num, - "params": params, - "pruning_cause": user_attrs.get("pruning_cause", "Unknown"), - "start_time": start_time, - "end_time": end_time - }) + pruned_trials.append( + { + "trial_number": trial_num, + "params": params, + "pruning_cause": user_attrs.get("pruning_cause", "Unknown"), + "start_time": start_time, + "end_time": end_time, + } + ) conn.close() return {"pruned_trials": pruned_trials, "count": len(pruned_trials)} @@ -660,6 +695,7 @@ async def get_pruning_history(study_id: str): except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to get pruning history: {str(e)}") + def _infer_objective_unit(objective: Dict) -> str: """Infer unit from objective name and description""" name = objective.get("name", "").lower() @@ -683,12 +719,14 @@ def _infer_objective_unit(objective: Dict) -> str: # Check if unit is explicitly mentioned in description (e.g., "(N/mm)") import re - unit_match = re.search(r'\(([^)]+)\)', desc) + + unit_match = re.search(r"\(([^)]+)\)", desc) if unit_match: return unit_match.group(1) return "" # No unit found + @router.get("/studies/{study_id}/metadata") async def get_study_metadata(study_id: str): """Read optimization_config.json for objectives, design vars, units (Protocol 13)""" @@ -703,7 +741,9 @@ async def get_study_metadata(study_id: str): config_file = study_dir / "1_setup" / "optimization_config.json" if not config_file.exists(): - raise HTTPException(status_code=404, detail=f"Config file not found for study {study_id}") + raise HTTPException( + status_code=404, detail=f"Config file not found for study {study_id}" + ) with open(config_file) as f: config = json.load(f) @@ -740,7 +780,7 @@ async def get_study_metadata(study_id: str): "description": config.get("description", ""), "sampler": sampler, "algorithm": algorithm, - "n_trials": optimization.get("n_trials", 100) + "n_trials": optimization.get("n_trials", 100), } except FileNotFoundError: @@ -748,6 +788,7 @@ async def get_study_metadata(study_id: str): except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to get study metadata: {str(e)}") + @router.get("/studies/{study_id}/optimizer-state") async def get_optimizer_state(study_id: str): """ @@ -845,12 +886,14 @@ async def get_optimizer_state(study_id: str): objectives = [] if "objectives" in config: for obj in config["objectives"]: - objectives.append({ - "name": obj.get("name", "objective"), - "direction": obj.get("direction", "minimize"), - "current_best": best_value if len(objectives) == 0 else None, - "unit": obj.get("units", "") - }) + objectives.append( + { + "name": obj.get("name", "objective"), + "direction": obj.get("direction", "minimize"), + "current_best": best_value if len(objectives) == 0 else None, + "unit": obj.get("units", ""), + } + ) # Sampler descriptions sampler_descriptions = { @@ -859,7 +902,7 @@ async def get_optimizer_state(study_id: str): "NSGAIIISampler": "NSGA-III - Reference-point based multi-objective optimization", "CmaEsSampler": "CMA-ES - Covariance Matrix Adaptation Evolution Strategy", "RandomSampler": "Random sampling - Uniform random search across parameter space", - "QMCSampler": "Quasi-Monte Carlo - Low-discrepancy sequence sampling" + "QMCSampler": "Quasi-Monte Carlo - Low-discrepancy sequence sampling", } return { @@ -871,16 +914,24 @@ async def get_optimizer_state(study_id: str): "current_strategy": f"{sampler_name} sampling", "sampler": { "name": sampler_name, - "description": sampler_descriptions.get(sampler_name, f"{sampler_name} optimization algorithm") + "description": sampler_descriptions.get( + sampler_name, f"{sampler_name} optimization algorithm" + ), }, "objectives": objectives, "plan": { "total_phases": 4, - "current_phase": {"initializing": 0, "exploration": 1, "exploitation": 2, "refinement": 3, "convergence": 4}.get(phase, 0), - "phases": ["Exploration", "Exploitation", "Refinement", "Convergence"] + "current_phase": { + "initializing": 0, + "exploration": 1, + "exploitation": 2, + "refinement": 3, + "convergence": 4, + }.get(phase, 0), + "phases": ["Exploration", "Exploitation", "Refinement", "Convergence"], }, "completed_trials": completed_trials, - "total_trials": n_trials + "total_trials": n_trials, } return {"available": False} @@ -890,6 +941,7 @@ async def get_optimizer_state(study_id: str): except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to get optimizer state: {str(e)}") + @router.get("/studies/{study_id}/pareto-front") async def get_pareto_front(study_id: str): """Get Pareto-optimal solutions for multi-objective studies (Protocol 13)""" @@ -922,10 +974,10 @@ async def get_pareto_front(study_id: str): "values": t.values, "params": t.params, "user_attrs": dict(t.user_attrs), - "constraint_satisfied": t.user_attrs.get("constraint_satisfied", True) + "constraint_satisfied": t.user_attrs.get("constraint_satisfied", True), } for t in pareto_trials - ] + ], } except FileNotFoundError: @@ -951,27 +1003,25 @@ async def get_nn_pareto_front(study_id: str): # Transform to match Trial interface format transformed = [] for trial in nn_pareto: - transformed.append({ - "trial_number": trial.get("trial_number"), - "values": [trial.get("mass"), trial.get("frequency")], - "params": trial.get("params", {}), - "user_attrs": { + transformed.append( + { + "trial_number": trial.get("trial_number"), + "values": [trial.get("mass"), trial.get("frequency")], + "params": trial.get("params", {}), + "user_attrs": { + "source": "NN", + "feasible": trial.get("feasible", False), + "predicted_stress": trial.get("predicted_stress"), + "predicted_displacement": trial.get("predicted_displacement"), + "mass": trial.get("mass"), + "frequency": trial.get("frequency"), + }, + "constraint_satisfied": trial.get("feasible", False), "source": "NN", - "feasible": trial.get("feasible", False), - "predicted_stress": trial.get("predicted_stress"), - "predicted_displacement": trial.get("predicted_displacement"), - "mass": trial.get("mass"), - "frequency": trial.get("frequency") - }, - "constraint_satisfied": trial.get("feasible", False), - "source": "NN" - }) + } + ) - return { - "has_nn_results": True, - "pareto_front": transformed, - "count": len(transformed) - } + return {"has_nn_results": True, "pareto_front": transformed, "count": len(transformed)} except FileNotFoundError: raise HTTPException(status_code=404, detail=f"Study {study_id} not found") @@ -1000,7 +1050,7 @@ async def get_nn_optimization_state(study_id: str): "pareto_front_size": state.get("pareto_front_size", 0), "best_mass": state.get("best_mass"), "best_frequency": state.get("best_frequency"), - "timestamp": state.get("timestamp") + "timestamp": state.get("timestamp"), } except FileNotFoundError: @@ -1014,7 +1064,7 @@ async def create_study( config: str = Form(...), prt_file: Optional[UploadFile] = File(None), sim_file: Optional[UploadFile] = File(None), - fem_file: Optional[UploadFile] = File(None) + fem_file: Optional[UploadFile] = File(None), ): """ Create a new optimization study @@ -1027,7 +1077,7 @@ async def create_study( try: # Parse config config_data = json.loads(config) - study_name = config_data.get("name") # Changed from study_name to name to match frontend + study_name = config_data.get("name") # Changed from study_name to name to match frontend if not study_name: raise HTTPException(status_code=400, detail="name is required in config") @@ -1047,31 +1097,31 @@ async def create_study( # Save config file config_file = setup_dir / "optimization_config.json" - with open(config_file, 'w') as f: + with open(config_file, "w") as f: json.dump(config_data, f, indent=2) # Save uploaded files files_saved = {} if prt_file: prt_path = model_dir / prt_file.filename - with open(prt_path, 'wb') as f: + with open(prt_path, "wb") as f: content = await prt_file.read() f.write(content) - files_saved['prt_file'] = str(prt_path) + files_saved["prt_file"] = str(prt_path) if sim_file: sim_path = model_dir / sim_file.filename - with open(sim_path, 'wb') as f: + with open(sim_path, "wb") as f: content = await sim_file.read() f.write(content) - files_saved['sim_file'] = str(sim_path) + files_saved["sim_file"] = str(sim_path) if fem_file: fem_path = model_dir / fem_file.filename - with open(fem_path, 'wb') as f: + with open(fem_path, "wb") as f: content = await fem_file.read() f.write(content) - files_saved['fem_file'] = str(fem_path) + files_saved["fem_file"] = str(fem_path) return JSONResponse( status_code=201, @@ -1081,18 +1131,19 @@ async def create_study( "study_path": str(study_dir), "config_path": str(config_file), "files_saved": files_saved, - "message": f"Study {study_name} created successfully. Ready to run optimization." - } + "message": f"Study {study_name} created successfully. Ready to run optimization.", + }, ) except json.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON in config: {str(e)}") except Exception as e: # Clean up on error - if 'study_dir' in locals() and study_dir.exists(): + if "study_dir" in locals() and study_dir.exists(): shutil.rmtree(study_dir) raise HTTPException(status_code=500, detail=f"Failed to create study: {str(e)}") + @router.post("/studies/{study_id}/convert-mesh") async def convert_study_mesh(study_id: str): """ @@ -1117,7 +1168,7 @@ async def convert_study_mesh(study_id: str): "gltf_path": str(output_path), "gltf_url": f"/api/optimization/studies/{study_id}/mesh/model.gltf", "metadata_url": f"/api/optimization/studies/{study_id}/mesh/model.json", - "message": "Mesh converted successfully" + "message": "Mesh converted successfully", } else: raise HTTPException(status_code=500, detail="Mesh conversion failed") @@ -1127,6 +1178,7 @@ async def convert_study_mesh(study_id: str): except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to convert mesh: {str(e)}") + @router.get("/studies/{study_id}/mesh/{filename}") async def get_mesh_file(study_id: str, filename: str): """ @@ -1135,7 +1187,7 @@ async def get_mesh_file(study_id: str, filename: str): """ try: # Validate filename to prevent directory traversal - if '..' in filename or '/' in filename or '\\' in filename: + if ".." in filename or "/" in filename or "\\" in filename: raise HTTPException(status_code=400, detail="Invalid filename") study_dir = resolve_study_path(study_id) @@ -1149,25 +1201,22 @@ async def get_mesh_file(study_id: str, filename: str): # Determine content type suffix = file_path.suffix.lower() content_types = { - '.gltf': 'model/gltf+json', - '.bin': 'application/octet-stream', - '.json': 'application/json', - '.glb': 'model/gltf-binary' + ".gltf": "model/gltf+json", + ".bin": "application/octet-stream", + ".json": "application/json", + ".glb": "model/gltf-binary", } - content_type = content_types.get(suffix, 'application/octet-stream') + content_type = content_types.get(suffix, "application/octet-stream") - return FileResponse( - path=str(file_path), - media_type=content_type, - filename=filename - ) + return FileResponse(path=str(file_path), media_type=content_type, filename=filename) except FileNotFoundError: raise HTTPException(status_code=404, detail=f"File not found") except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to serve mesh file: {str(e)}") + @router.get("/studies/{study_id}/optuna-url") async def get_optuna_dashboard_url(study_id: str): """ @@ -1186,15 +1235,20 @@ async def get_optuna_dashboard_url(study_id: str): study_db = results_dir / "study.db" if not study_db.exists(): - raise HTTPException(status_code=404, detail=f"No Optuna database found for study {study_id}") + raise HTTPException( + status_code=404, detail=f"No Optuna database found for study {study_id}" + ) # Get the study name from the database (may differ from folder name) import optuna + storage = optuna.storages.RDBStorage(f"sqlite:///{study_db}") studies = storage.get_all_studies() if not studies: - raise HTTPException(status_code=404, detail=f"No Optuna study found in database for {study_id}") + raise HTTPException( + status_code=404, detail=f"No Optuna study found in database for {study_id}" + ) # Use the actual study name from the database optuna_study_name = studies[0].study_name @@ -1207,7 +1261,7 @@ async def get_optuna_dashboard_url(study_id: str): "database_path": f"studies/{study_id}/2_results/study.db", "dashboard_url": f"http://localhost:8081/dashboard/studies/{studies[0]._study_id}", "dashboard_base": "http://localhost:8081", - "note": "Optuna dashboard must be started with: sqlite:///studies/{study_id}/2_results/study.db" + "note": "Optuna dashboard must be started with: sqlite:///studies/{study_id}/2_results/study.db", } except FileNotFoundError: @@ -1215,11 +1269,10 @@ async def get_optuna_dashboard_url(study_id: str): except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to get Optuna URL: {str(e)}") + @router.post("/studies/{study_id}/generate-report") async def generate_report( - study_id: str, - format: str = "markdown", - include_llm_summary: bool = False + study_id: str, format: str = "markdown", include_llm_summary: bool = False ): """ Generate an optimization report in the specified format @@ -1238,9 +1291,12 @@ async def generate_report( raise HTTPException(status_code=404, detail=f"Study {study_id} not found") # Validate format - valid_formats = ['markdown', 'md', 'html', 'pdf'] + valid_formats = ["markdown", "md", "html", "pdf"] if format.lower() not in valid_formats: - raise HTTPException(status_code=400, detail=f"Invalid format. Must be one of: {', '.join(valid_formats)}") + raise HTTPException( + status_code=400, + detail=f"Invalid format. Must be one of: {', '.join(valid_formats)}", + ) # Import report generator sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent)) @@ -1250,7 +1306,7 @@ async def generate_report( output_path = generate_study_report( study_dir=study_dir, output_format=format.lower(), - include_llm_summary=include_llm_summary + include_llm_summary=include_llm_summary, ) if output_path and output_path.exists(): @@ -1263,7 +1319,7 @@ async def generate_report( "file_path": str(output_path), "download_url": f"/api/optimization/studies/{study_id}/reports/{output_path.name}", "file_size": output_path.stat().st_size, - "message": f"Report generated successfully in {format} format" + "message": f"Report generated successfully in {format} format", } else: raise HTTPException(status_code=500, detail="Report generation failed") @@ -1273,6 +1329,7 @@ async def generate_report( except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to generate report: {str(e)}") + @router.get("/studies/{study_id}/reports/{filename}") async def download_report(study_id: str, filename: str): """ @@ -1287,7 +1344,7 @@ async def download_report(study_id: str, filename: str): """ try: # Validate filename to prevent directory traversal - if '..' in filename or '/' in filename or '\\' in filename: + if ".." in filename or "/" in filename or "\\" in filename: raise HTTPException(status_code=400, detail="Invalid filename") study_dir = resolve_study_path(study_id) @@ -1301,19 +1358,19 @@ async def download_report(study_id: str, filename: str): # Determine content type suffix = file_path.suffix.lower() content_types = { - '.md': 'text/markdown', - '.html': 'text/html', - '.pdf': 'application/pdf', - '.json': 'application/json' + ".md": "text/markdown", + ".html": "text/html", + ".pdf": "application/pdf", + ".json": "application/json", } - content_type = content_types.get(suffix, 'application/octet-stream') + content_type = content_types.get(suffix, "application/octet-stream") return FileResponse( path=str(file_path), media_type=content_type, filename=filename, - headers={"Content-Disposition": f"attachment; filename={filename}"} + headers={"Content-Disposition": f"attachment; filename={filename}"}, ) except FileNotFoundError: @@ -1361,25 +1418,25 @@ async def get_console_output(study_id: str, lines: int = 200): "lines": [], "total_lines": 0, "log_file": None, - "message": "No log file found. Optimization may not have started yet." + "message": "No log file found. Optimization may not have started yet.", } # Read the last N lines efficiently - with open(log_path_used, 'r', encoding='utf-8', errors='replace') as f: + with open(log_path_used, "r", encoding="utf-8", errors="replace") as f: all_lines = f.readlines() # Get last N lines last_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines # Clean up lines (remove trailing newlines) - last_lines = [line.rstrip('\n\r') for line in last_lines] + last_lines = [line.rstrip("\n\r") for line in last_lines] return { "lines": last_lines, "total_lines": len(all_lines), "displayed_lines": len(last_lines), "log_file": str(log_path_used), - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now().isoformat(), } except HTTPException: @@ -1411,14 +1468,10 @@ async def get_study_report(study_id: str): if not report_path.exists(): raise HTTPException(status_code=404, detail="No STUDY_REPORT.md found for this study") - with open(report_path, 'r', encoding='utf-8') as f: + with open(report_path, "r", encoding="utf-8") as f: content = f.read() - return { - "content": content, - "path": str(report_path), - "study_id": study_id - } + return {"content": content, "path": str(report_path), "study_id": study_id} except HTTPException: raise @@ -1430,6 +1483,7 @@ async def get_study_report(study_id: str): # Study README and Config Endpoints # ============================================================================ + @router.get("/studies/{study_id}/readme") async def get_study_readme(study_id: str): """ @@ -1460,7 +1514,7 @@ async def get_study_readme(study_id: str): for path in readme_paths: if path.exists(): readme_path = path - with open(path, 'r', encoding='utf-8') as f: + with open(path, "r", encoding="utf-8") as f: readme_content = f.read() break @@ -1474,15 +1528,15 @@ async def get_study_readme(study_id: str): with open(config_file) as f: config = json.load(f) - readme_content = f"""# {config.get('study_name', study_id)} + readme_content = f"""# {config.get("study_name", study_id)} -{config.get('description', 'No description available.')} +{config.get("description", "No description available.")} ## Design Variables -{chr(10).join([f"- **{dv['name']}**: {dv.get('min', '?')} - {dv.get('max', '?')} {dv.get('units', '')}" for dv in config.get('design_variables', [])])} +{chr(10).join([f"- **{dv['name']}**: {dv.get('min', '?')} - {dv.get('max', '?')} {dv.get('units', '')}" for dv in config.get("design_variables", [])])} ## Objectives -{chr(10).join([f"- **{obj['name']}**: {obj.get('description', '')} ({obj.get('direction', 'minimize')})" for obj in config.get('objectives', [])])} +{chr(10).join([f"- **{obj['name']}**: {obj.get('description', '')} ({obj.get('direction', 'minimize')})" for obj in config.get("objectives", [])])} """ else: readme_content = f"# {study_id}\n\nNo README or configuration found for this study." @@ -1490,7 +1544,7 @@ async def get_study_readme(study_id: str): return { "content": readme_content, "path": str(readme_path) if readme_path else None, - "study_id": study_id + "study_id": study_id, } except HTTPException: @@ -1524,7 +1578,7 @@ async def get_study_image(study_id: str, image_path: str): raise HTTPException(status_code=404, detail=f"Study {study_id} not found") # Sanitize path to prevent directory traversal - image_path = image_path.replace('..', '').lstrip('/') + image_path = image_path.replace("..", "").lstrip("/") # Try multiple locations for the image possible_paths = [ @@ -1547,14 +1601,14 @@ async def get_study_image(study_id: str, image_path: str): # Determine media type suffix = image_file.suffix.lower() media_types = { - '.png': 'image/png', - '.jpg': 'image/jpeg', - '.jpeg': 'image/jpeg', - '.gif': 'image/gif', - '.svg': 'image/svg+xml', - '.webp': 'image/webp', + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".svg": "image/svg+xml", + ".webp": "image/webp", } - media_type = media_types.get(suffix, 'application/octet-stream') + media_type = media_types.get(suffix, "application/octet-stream") return FileResponse(image_file, media_type=media_type) @@ -1594,7 +1648,7 @@ async def get_study_config(study_id: str): "config": config, "path": str(spec_file), "study_id": study_id, - "source": "atomizer_spec" + "source": "atomizer_spec", } # Priority 2: Legacy optimization_config.json @@ -1603,7 +1657,9 @@ async def get_study_config(study_id: str): config_file = study_dir / "optimization_config.json" if not config_file.exists(): - raise HTTPException(status_code=404, detail=f"Config file not found for study {study_id}") + raise HTTPException( + status_code=404, detail=f"Config file not found for study {study_id}" + ) with open(config_file) as f: config = json.load(f) @@ -1612,7 +1668,7 @@ async def get_study_config(study_id: str): "config": config, "path": str(config_file), "study_id": study_id, - "source": "legacy_config" + "source": "legacy_config", } except HTTPException: @@ -1631,59 +1687,67 @@ def _transform_spec_to_config(spec: dict, study_id: str) -> dict: design_variables = [] for dv in spec.get("design_variables", []): bounds = dv.get("bounds", {}) - design_variables.append({ - "name": dv.get("name"), - "expression_name": dv.get("expression_name"), - "type": "float" if dv.get("type") == "continuous" else dv.get("type", "float"), - "min": bounds.get("min"), - "max": bounds.get("max"), - "low": bounds.get("min"), # Alias for compatibility - "high": bounds.get("max"), # Alias for compatibility - "baseline": dv.get("baseline"), - "unit": dv.get("units"), - "units": dv.get("units"), - "enabled": dv.get("enabled", True) - }) + design_variables.append( + { + "name": dv.get("name"), + "expression_name": dv.get("expression_name"), + "type": "float" if dv.get("type") == "continuous" else dv.get("type", "float"), + "min": bounds.get("min"), + "max": bounds.get("max"), + "low": bounds.get("min"), # Alias for compatibility + "high": bounds.get("max"), # Alias for compatibility + "baseline": dv.get("baseline"), + "unit": dv.get("units"), + "units": dv.get("units"), + "enabled": dv.get("enabled", True), + } + ) # Transform objectives objectives = [] for obj in spec.get("objectives", []): source = obj.get("source", {}) - objectives.append({ - "name": obj.get("name"), - "direction": obj.get("direction", "minimize"), - "weight": obj.get("weight", 1.0), - "target": obj.get("target"), - "unit": obj.get("units"), - "units": obj.get("units"), - "extractor_id": source.get("extractor_id"), - "output_key": source.get("output_key") - }) + objectives.append( + { + "name": obj.get("name"), + "direction": obj.get("direction", "minimize"), + "weight": obj.get("weight", 1.0), + "target": obj.get("target"), + "unit": obj.get("units"), + "units": obj.get("units"), + "extractor_id": source.get("extractor_id"), + "output_key": source.get("output_key"), + } + ) # Transform constraints constraints = [] for con in spec.get("constraints", []): - constraints.append({ - "name": con.get("name"), - "type": _operator_to_type(con.get("operator", "<=")), - "operator": con.get("operator"), - "max_value": con.get("threshold") if con.get("operator") in ["<=", "<"] else None, - "min_value": con.get("threshold") if con.get("operator") in [">=", ">"] else None, - "bound": con.get("threshold"), - "unit": con.get("units"), - "units": con.get("units") - }) + constraints.append( + { + "name": con.get("name"), + "type": _operator_to_type(con.get("operator", "<=")), + "operator": con.get("operator"), + "max_value": con.get("threshold") if con.get("operator") in ["<=", "<"] else None, + "min_value": con.get("threshold") if con.get("operator") in [">=", ">"] else None, + "bound": con.get("threshold"), + "unit": con.get("units"), + "units": con.get("units"), + } + ) # Transform extractors extractors = [] for ext in spec.get("extractors", []): - extractors.append({ - "name": ext.get("name"), - "type": ext.get("type"), - "builtin": ext.get("builtin", True), - "config": ext.get("config", {}), - "outputs": ext.get("outputs", []) - }) + extractors.append( + { + "name": ext.get("name"), + "type": ext.get("type"), + "builtin": ext.get("builtin", True), + "config": ext.get("config", {}), + "outputs": ext.get("outputs", []), + } + ) # Get algorithm info algorithm = optimization.get("algorithm", {}) @@ -1702,19 +1766,21 @@ def _transform_spec_to_config(spec: dict, study_id: str) -> dict: "algorithm": algorithm.get("type", "TPE"), "n_trials": budget.get("max_trials", 100), "max_time_hours": budget.get("max_time_hours"), - "convergence_patience": budget.get("convergence_patience") + "convergence_patience": budget.get("convergence_patience"), }, "optimization_settings": { "sampler": algorithm.get("type", "TPE"), - "n_trials": budget.get("max_trials", 100) + "n_trials": budget.get("max_trials", 100), }, "algorithm": { "name": "Optuna", "sampler": algorithm.get("type", "TPE"), - "n_trials": budget.get("max_trials", 100) + "n_trials": budget.get("max_trials", 100), }, "model": model, - "sim_file": model.get("sim", {}).get("path") if isinstance(model.get("sim"), dict) else None + "sim_file": model.get("sim", {}).get("path") + if isinstance(model.get("sim"), dict) + else None, } return config @@ -1722,14 +1788,7 @@ def _transform_spec_to_config(spec: dict, study_id: str) -> dict: def _operator_to_type(operator: str) -> str: """Convert constraint operator to legacy type string.""" - mapping = { - "<=": "le", - "<": "le", - ">=": "ge", - ">": "ge", - "==": "eq", - "=": "eq" - } + mapping = {"<=": "le", "<": "le", ">=": "ge", ">": "ge", "==": "eq", "=": "eq"} return mapping.get(operator, "le") @@ -1754,7 +1813,7 @@ def _find_optimization_process(study_id: str) -> Optional[psutil.Process]: """ study_dir = resolve_study_path(study_id) study_dir_str = str(study_dir).lower() - study_dir_normalized = study_dir_str.replace('/', '\\') + study_dir_normalized = study_dir_str.replace("/", "\\") # Strategy 1: Check tracked processes first (most reliable) if study_id in _running_processes: @@ -1770,28 +1829,30 @@ def _find_optimization_process(study_id: str) -> Optional[psutil.Process]: del _running_processes[study_id] # Strategy 2 & 3: Scan all Python processes - for proc in psutil.process_iter(['pid', 'name', 'cmdline', 'cwd']): + for proc in psutil.process_iter(["pid", "name", "cmdline", "cwd"]): try: # Only check Python processes - proc_name = (proc.info.get('name') or '').lower() - if 'python' not in proc_name: + proc_name = (proc.info.get("name") or "").lower() + if "python" not in proc_name: continue - cmdline = proc.info.get('cmdline') or [] - cmdline_str = ' '.join(cmdline) if cmdline else '' + cmdline = proc.info.get("cmdline") or [] + cmdline_str = " ".join(cmdline) if cmdline else "" cmdline_lower = cmdline_str.lower() # Check if this is an optimization script (run_optimization.py or run_sat_optimization.py) - is_opt_script = ('run_optimization' in cmdline_lower or - 'run_sat_optimization' in cmdline_lower or - 'run_imso' in cmdline_lower) + is_opt_script = ( + "run_optimization" in cmdline_lower + or "run_sat_optimization" in cmdline_lower + or "run_imso" in cmdline_lower + ) if not is_opt_script: continue # Strategy 2: Check CWD (most reliable on Windows) try: - proc_cwd = proc.cwd().lower().replace('/', '\\') + proc_cwd = proc.cwd().lower().replace("/", "\\") if proc_cwd == study_dir_normalized or study_dir_normalized in proc_cwd: # Track this process for future lookups _running_processes[study_id] = proc.pid @@ -1800,7 +1861,7 @@ def _find_optimization_process(study_id: str) -> Optional[psutil.Process]: pass # Strategy 3: Check cmdline for study path or study_id - cmdline_normalized = cmdline_lower.replace('/', '\\') + cmdline_normalized = cmdline_lower.replace("/", "\\") if study_dir_normalized in cmdline_normalized: _running_processes[study_id] = proc.pid return proc @@ -1898,7 +1959,8 @@ async def get_process_status(study_id: str): result = cursor.fetchone() if result: import re - match = re.search(r'iter(\d+)', result[0]) + + match = re.search(r"iter(\d+)", result[0]) if match: iteration = int(match.group(1)) @@ -1914,6 +1976,7 @@ async def get_process_status(study_id: str): if len(timestamps) >= 2: from datetime import datetime as dt + try: # Parse timestamps and calculate average time between trials times = [] @@ -1922,7 +1985,7 @@ async def get_process_status(study_id: str): # Handle different timestamp formats ts_str = ts[0] try: - times.append(dt.fromisoformat(ts_str.replace('Z', '+00:00'))) + times.append(dt.fromisoformat(ts_str.replace("Z", "+00:00"))) except ValueError: try: times.append(dt.strptime(ts_str, "%Y-%m-%d %H:%M:%S.%f")) @@ -1932,9 +1995,14 @@ async def get_process_status(study_id: str): if len(times) >= 2: # Sort ascending and calculate time differences times.sort() - diffs = [(times[i+1] - times[i]).total_seconds() for i in range(len(times)-1)] + diffs = [ + (times[i + 1] - times[i]).total_seconds() + for i in range(len(times) - 1) + ] # Filter out outliers (e.g., breaks in optimization) - diffs = [d for d in diffs if d > 0 and d < 7200] # Max 2 hours between trials + diffs = [ + d for d in diffs if d > 0 and d < 7200 + ] # Max 2 hours between trials if diffs: time_per_trial = sum(diffs) / len(diffs) rate_per_hour = 3600 / time_per_trial if time_per_trial > 0 else 0 @@ -1982,7 +2050,7 @@ async def get_process_status(study_id: str): "eta_formatted": eta_formatted, "rate_per_hour": round(rate_per_hour, 2) if rate_per_hour else None, "start_time": start_time, - "study_id": study_id + "study_id": study_id, } except HTTPException: @@ -2004,15 +2072,15 @@ class StartOptimizationRequest(BaseModel): def _detect_script_type(script_path: Path) -> str: """Detect the type of optimization script (SAT, IMSO, etc.)""" try: - content = script_path.read_text(encoding='utf-8', errors='ignore') - if 'run_sat_optimization' in content or 'SAT' in content or 'Self-Aware Turbo' in content: - return 'sat' - elif 'IMSO' in content or 'intelligent_optimizer' in content: - return 'imso' + content = script_path.read_text(encoding="utf-8", errors="ignore") + if "run_sat_optimization" in content or "SAT" in content or "Self-Aware Turbo" in content: + return "sat" + elif "IMSO" in content or "intelligent_optimizer" in content: + return "imso" else: - return 'generic' + return "generic" except: - return 'generic' + return "generic" @router.post("/studies/{study_id}/start") @@ -2039,7 +2107,7 @@ async def start_optimization(study_id: str, request: StartOptimizationRequest = return { "success": False, "message": f"Optimization already running (PID: {existing_proc.pid})", - "pid": existing_proc.pid + "pid": existing_proc.pid, } # Find run_optimization.py or run_sat_optimization.py @@ -2047,7 +2115,9 @@ async def start_optimization(study_id: str, request: StartOptimizationRequest = if not run_script.exists(): run_script = study_dir / "run_optimization.py" if not run_script.exists(): - raise HTTPException(status_code=404, detail=f"No optimization script found for study {study_id}") + raise HTTPException( + status_code=404, detail=f"No optimization script found for study {study_id}" + ) # Detect script type and build appropriate command script_type = _detect_script_type(run_script) @@ -2055,7 +2125,7 @@ async def start_optimization(study_id: str, request: StartOptimizationRequest = cmd = [python_exe, str(run_script)] if request: - if script_type == 'sat': + if script_type == "sat": # SAT scripts use --trials cmd.extend(["--trials", str(request.trials)]) if request.freshStart: @@ -2077,7 +2147,7 @@ async def start_optimization(study_id: str, request: StartOptimizationRequest = cwd=str(study_dir), stdout=subprocess.PIPE, stderr=subprocess.PIPE, - start_new_session=True + start_new_session=True, ) _running_processes[study_id] = proc.pid @@ -2086,8 +2156,8 @@ async def start_optimization(study_id: str, request: StartOptimizationRequest = "success": True, "message": f"Optimization started successfully ({script_type} script)", "pid": proc.pid, - "command": ' '.join(cmd), - "script_type": script_type + "command": " ".join(cmd), + "script_type": script_type, } except HTTPException: @@ -2125,10 +2195,7 @@ async def stop_optimization(study_id: str, request: StopRequest = None): proc = _find_optimization_process(study_id) if not proc: - return { - "success": False, - "message": "No running optimization process found" - } + return {"success": False, "message": "No running optimization process found"} pid = proc.pid killed_pids = [] @@ -2183,16 +2250,13 @@ async def stop_optimization(study_id: str, request: StopRequest = None): "success": True, "message": f"Optimization killed (PID: {pid}, +{len(children)} children)", "pid": pid, - "killed_pids": killed_pids + "killed_pids": killed_pids, } except psutil.NoSuchProcess: if study_id in _running_processes: del _running_processes[study_id] - return { - "success": True, - "message": "Process already terminated" - } + return {"success": True, "message": "Process already terminated"} except HTTPException: raise @@ -2223,17 +2287,11 @@ async def pause_optimization(study_id: str): proc = _find_optimization_process(study_id) if not proc: - return { - "success": False, - "message": "No running optimization process found" - } + return {"success": False, "message": "No running optimization process found"} # Check if already paused if _paused_processes.get(study_id, False): - return { - "success": False, - "message": "Optimization is already paused" - } + return {"success": False, "message": "Optimization is already paused"} pid = proc.pid @@ -2261,14 +2319,11 @@ async def pause_optimization(study_id: str): return { "success": True, "message": f"Optimization paused (PID: {pid}, +{len(children)} children)", - "pid": pid + "pid": pid, } except psutil.NoSuchProcess: - return { - "success": False, - "message": "Process no longer exists" - } + return {"success": False, "message": "Process no longer exists"} except HTTPException: raise @@ -2299,17 +2354,11 @@ async def resume_optimization(study_id: str): proc = _find_optimization_process(study_id) if not proc: - return { - "success": False, - "message": "No optimization process found" - } + return {"success": False, "message": "No optimization process found"} # Check if actually paused if not _paused_processes.get(study_id, False): - return { - "success": False, - "message": "Optimization is not paused" - } + return {"success": False, "message": "Optimization is not paused"} pid = proc.pid @@ -2337,15 +2386,12 @@ async def resume_optimization(study_id: str): return { "success": True, "message": f"Optimization resumed (PID: {pid}, +{len(children)} children)", - "pid": pid + "pid": pid, } except psutil.NoSuchProcess: _paused_processes[study_id] = False - return { - "success": False, - "message": "Process no longer exists" - } + return {"success": False, "message": "Process no longer exists"} except HTTPException: raise @@ -2380,7 +2426,7 @@ async def validate_optimization(study_id: str, request: ValidateRequest = None): if existing_proc: return { "success": False, - "message": "Cannot validate while optimization is running. Stop optimization first." + "message": "Cannot validate while optimization is running. Stop optimization first.", } # Look for final_validation.py script @@ -2406,14 +2452,14 @@ async def validate_optimization(study_id: str, request: ValidateRequest = None): cwd=str(study_dir), stdout=subprocess.PIPE, stderr=subprocess.PIPE, - start_new_session=True + start_new_session=True, ) return { "success": True, "message": f"Validation started for top {top_n} NN predictions", "pid": proc.pid, - "command": ' '.join(cmd) + "command": " ".join(cmd), } except HTTPException: @@ -2428,6 +2474,7 @@ async def validate_optimization(study_id: str, request: ValidateRequest = None): _optuna_processes: Dict[str, subprocess.Popen] = {} + @router.post("/studies/{study_id}/optuna-dashboard") async def launch_optuna_dashboard(study_id: str): """ @@ -2445,7 +2492,7 @@ async def launch_optuna_dashboard(study_id: str): def is_port_in_use(port: int) -> bool: """Check if a port is already in use""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 + return s.connect_ex(("localhost", port)) == 0 try: study_dir = resolve_study_path(study_id) @@ -2457,7 +2504,9 @@ async def launch_optuna_dashboard(study_id: str): study_db = results_dir / "study.db" if not study_db.exists(): - raise HTTPException(status_code=404, detail=f"No Optuna database found for study {study_id}") + raise HTTPException( + status_code=404, detail=f"No Optuna database found for study {study_id}" + ) port = 8081 @@ -2471,14 +2520,14 @@ async def launch_optuna_dashboard(study_id: str): "success": True, "url": f"http://localhost:{port}", "pid": proc.pid, - "message": "Optuna dashboard already running" + "message": "Optuna dashboard already running", } # Port in use but not by us - still return success since dashboard is available return { "success": True, "url": f"http://localhost:{port}", "pid": None, - "message": "Optuna dashboard already running on port 8081" + "message": "Optuna dashboard already running on port 8081", } # Launch optuna-dashboard using CLI command (more robust than Python import) @@ -2492,7 +2541,7 @@ async def launch_optuna_dashboard(study_id: str): optuna_dashboard_cmd = "optuna-dashboard" - if platform.system() == 'Windows': + if platform.system() == "Windows": # Try to find optuna-dashboard in the conda environment conda_paths = [ Path(r"C:\Users\antoi\anaconda3\envs\atomizer\Scripts\optuna-dashboard.exe"), @@ -2511,13 +2560,13 @@ async def launch_optuna_dashboard(study_id: str): # Return helpful error message raise HTTPException( status_code=500, - detail="optuna-dashboard not found. Install with: pip install optuna-dashboard" + detail="optuna-dashboard not found. Install with: pip install optuna-dashboard", ) cmd = [optuna_dashboard_cmd, storage_url, "--port", str(port), "--host", "0.0.0.0"] # On Windows, use CREATE_NEW_PROCESS_GROUP and DETACHED_PROCESS flags - if platform.system() == 'Windows': + if platform.system() == "Windows": # Windows-specific: create detached process DETACHED_PROCESS = 0x00000008 CREATE_NEW_PROCESS_GROUP = 0x00000200 @@ -2525,14 +2574,11 @@ async def launch_optuna_dashboard(study_id: str): cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - creationflags=DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP + creationflags=DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP, ) else: proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - start_new_session=True + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, start_new_session=True ) _optuna_processes[study_id] = proc @@ -2546,7 +2592,7 @@ async def launch_optuna_dashboard(study_id: str): "success": True, "url": f"http://localhost:{port}", "pid": proc.pid, - "message": "Optuna dashboard launched successfully" + "message": "Optuna dashboard launched successfully", } # Check if process died if proc.poll() is not None: @@ -2555,10 +2601,7 @@ async def launch_optuna_dashboard(study_id: str): stderr = proc.stderr.read().decode() if proc.stderr else "" except: pass - return { - "success": False, - "message": f"Failed to start Optuna dashboard: {stderr}" - } + return {"success": False, "message": f"Failed to start Optuna dashboard: {stderr}"} time.sleep(0.5) # Timeout - process might still be starting @@ -2567,7 +2610,7 @@ async def launch_optuna_dashboard(study_id: str): "success": True, "url": f"http://localhost:{port}", "pid": proc.pid, - "message": "Optuna dashboard starting (may take a moment)" + "message": "Optuna dashboard starting (may take a moment)", } else: stderr = "" @@ -2575,10 +2618,7 @@ async def launch_optuna_dashboard(study_id: str): stderr = proc.stderr.read().decode() if proc.stderr else "" except: pass - return { - "success": False, - "message": f"Failed to start Optuna dashboard: {stderr}" - } + return {"success": False, "message": f"Failed to start Optuna dashboard: {stderr}"} except HTTPException: raise @@ -2597,7 +2637,7 @@ async def check_optuna_dashboard_available(): import platform import shutil as sh - if platform.system() == 'Windows': + if platform.system() == "Windows": # Check conda environment paths conda_paths = [ Path(r"C:\Users\antoi\anaconda3\envs\atomizer\Scripts\optuna-dashboard.exe"), @@ -2608,23 +2648,19 @@ async def check_optuna_dashboard_available(): return { "available": True, "path": str(conda_path), - "message": "optuna-dashboard is installed" + "message": "optuna-dashboard is installed", } # Fallback: check system PATH found_path = sh.which("optuna-dashboard") if found_path: - return { - "available": True, - "path": found_path, - "message": "optuna-dashboard is installed" - } + return {"available": True, "path": found_path, "message": "optuna-dashboard is installed"} return { "available": False, "path": None, "message": "optuna-dashboard not found", - "install_instructions": "pip install optuna-dashboard" + "install_instructions": "pip install optuna-dashboard", } @@ -2632,6 +2668,7 @@ async def check_optuna_dashboard_available(): # Model Files Endpoint # ============================================================================ + @router.get("/studies/{study_id}/model-files") async def get_model_files(study_id: str): """ @@ -2654,39 +2691,52 @@ async def get_model_files(study_id: str): study_dir / "1_setup" / "model", study_dir / "model", study_dir / "1_setup", - study_dir + study_dir, ] model_files = [] model_dir_path = None # NX and FEA file extensions to look for - nx_extensions = {'.prt', '.sim', '.fem', '.bdf', '.dat', '.op2', '.f06', '.inp'} + nx_extensions = {".prt", ".sim", ".fem", ".bdf", ".dat", ".op2", ".f06", ".inp"} for model_dir in model_dirs: if model_dir.exists() and model_dir.is_dir(): for file_path in model_dir.iterdir(): if file_path.is_file() and file_path.suffix.lower() in nx_extensions: - model_files.append({ - "name": file_path.name, - "path": str(file_path), - "extension": file_path.suffix.lower(), - "size_bytes": file_path.stat().st_size, - "size_display": _format_file_size(file_path.stat().st_size), - "modified": datetime.fromtimestamp(file_path.stat().st_mtime).isoformat() - }) + model_files.append( + { + "name": file_path.name, + "path": str(file_path), + "extension": file_path.suffix.lower(), + "size_bytes": file_path.stat().st_size, + "size_display": _format_file_size(file_path.stat().st_size), + "modified": datetime.fromtimestamp( + file_path.stat().st_mtime + ).isoformat(), + } + ) if model_dir_path is None: model_dir_path = str(model_dir) # Sort by extension for better display (prt first, then sim, fem, etc.) - extension_order = {'.prt': 0, '.sim': 1, '.fem': 2, '.bdf': 3, '.dat': 4, '.op2': 5, '.f06': 6, '.inp': 7} - model_files.sort(key=lambda x: (extension_order.get(x['extension'], 99), x['name'])) + extension_order = { + ".prt": 0, + ".sim": 1, + ".fem": 2, + ".bdf": 3, + ".dat": 4, + ".op2": 5, + ".f06": 6, + ".inp": 7, + } + model_files.sort(key=lambda x: (extension_order.get(x["extension"], 99), x["name"])) return { "study_id": study_id, "model_dir": model_dir_path or str(study_dir / "1_setup" / "model"), "files": model_files, - "count": len(model_files) + "count": len(model_files), } except HTTPException: @@ -2753,16 +2803,12 @@ async def open_model_folder(study_id: str, folder_type: str = "model"): else: # Linux subprocess.Popen(["xdg-open", str(target_dir)]) - return { - "success": True, - "message": f"Opened {target_dir}", - "path": str(target_dir) - } + return {"success": True, "message": f"Opened {target_dir}", "path": str(target_dir)} except Exception as e: return { "success": False, "message": f"Failed to open folder: {str(e)}", - "path": str(target_dir) + "path": str(target_dir), } except HTTPException: @@ -2788,7 +2834,7 @@ async def get_best_solution(study_id: str): "best_trial": None, "first_trial": None, "improvements": {}, - "total_trials": 0 + "total_trials": 0, } conn = sqlite3.connect(str(db_path)) @@ -2827,61 +2873,72 @@ async def get_best_solution(study_id: str): improvements = {} if best_row: - best_trial_id = best_row['trial_id'] + best_trial_id = best_row["trial_id"] # Get design variables - cursor.execute(""" + cursor.execute( + """ SELECT param_name, param_value FROM trial_params WHERE trial_id = ? - """, (best_trial_id,)) - params = {row['param_name']: row['param_value'] for row in cursor.fetchall()} + """, + (best_trial_id,), + ) + params = {row["param_name"]: row["param_value"] for row in cursor.fetchall()} # Get user attributes (including results) - cursor.execute(""" + cursor.execute( + """ SELECT key, value_json FROM trial_user_attributes WHERE trial_id = ? - """, (best_trial_id,)) + """, + (best_trial_id,), + ) user_attrs = {} for row in cursor.fetchall(): try: - user_attrs[row['key']] = json.loads(row['value_json']) + user_attrs[row["key"]] = json.loads(row["value_json"]) except: - user_attrs[row['key']] = row['value_json'] + user_attrs[row["key"]] = row["value_json"] best_trial = { - "trial_number": best_row['number'], - "objective": best_row['objective'], + "trial_number": best_row["number"], + "objective": best_row["objective"], "design_variables": params, "user_attrs": user_attrs, - "timestamp": best_row['timestamp'] + "timestamp": best_row["timestamp"], } if first_row: - first_trial_id = first_row['trial_id'] + first_trial_id = first_row["trial_id"] - cursor.execute(""" + cursor.execute( + """ SELECT param_name, param_value FROM trial_params WHERE trial_id = ? - """, (first_trial_id,)) - first_params = {row['param_name']: row['param_value'] for row in cursor.fetchall()} + """, + (first_trial_id,), + ) + first_params = {row["param_name"]: row["param_value"] for row in cursor.fetchall()} first_trial = { - "trial_number": first_row['number'], - "objective": first_row['objective'], - "design_variables": first_params + "trial_number": first_row["number"], + "objective": first_row["objective"], + "design_variables": first_params, } # Calculate improvement - if best_row and first_row['objective'] != 0: - improvement_pct = ((first_row['objective'] - best_row['objective']) / abs(first_row['objective'])) * 100 + if best_row and first_row["objective"] != 0: + improvement_pct = ( + (first_row["objective"] - best_row["objective"]) / abs(first_row["objective"]) + ) * 100 improvements["objective"] = { - "initial": first_row['objective'], - "final": best_row['objective'], + "initial": first_row["objective"], + "final": best_row["objective"], "improvement_pct": round(improvement_pct, 2), - "absolute_change": round(first_row['objective'] - best_row['objective'], 6) + "absolute_change": round(first_row["objective"] - best_row["objective"], 6), } conn.close() @@ -2891,7 +2948,7 @@ async def get_best_solution(study_id: str): "best_trial": best_trial, "first_trial": first_trial, "improvements": improvements, - "total_trials": total_trials + "total_trials": total_trials, } except HTTPException: @@ -2932,68 +2989,78 @@ async def get_study_runs(study_id: str): runs = [] for study_row in studies: - optuna_study_id = study_row['study_id'] - study_name = study_row['study_name'] + optuna_study_id = study_row["study_id"] + study_name = study_row["study_name"] # Get trial count - cursor.execute(""" + cursor.execute( + """ SELECT COUNT(*) FROM trials WHERE study_id = ? AND state = 'COMPLETE' - """, (optuna_study_id,)) + """, + (optuna_study_id,), + ) trial_count = cursor.fetchone()[0] if trial_count == 0: continue # Get best value (first objective) - cursor.execute(""" + cursor.execute( + """ SELECT MIN(tv.value) as best_value FROM trial_values tv JOIN trials t ON tv.trial_id = t.trial_id WHERE t.study_id = ? AND t.state = 'COMPLETE' AND tv.objective = 0 - """, (optuna_study_id,)) + """, + (optuna_study_id,), + ) best_result = cursor.fetchone() - best_value = best_result['best_value'] if best_result else None + best_value = best_result["best_value"] if best_result else None # Get average value - cursor.execute(""" + cursor.execute( + """ SELECT AVG(tv.value) as avg_value FROM trial_values tv JOIN trials t ON tv.trial_id = t.trial_id WHERE t.study_id = ? AND t.state = 'COMPLETE' AND tv.objective = 0 - """, (optuna_study_id,)) + """, + (optuna_study_id,), + ) avg_result = cursor.fetchone() - avg_value = avg_result['avg_value'] if avg_result else None + avg_value = avg_result["avg_value"] if avg_result else None # Get time range - cursor.execute(""" + cursor.execute( + """ SELECT MIN(datetime_start) as first_trial, MAX(datetime_complete) as last_trial FROM trials WHERE study_id = ? AND state = 'COMPLETE' - """, (optuna_study_id,)) + """, + (optuna_study_id,), + ) time_result = cursor.fetchone() # Determine source type (FEA or NN) source = "NN" if "_nn" in study_name.lower() else "FEA" - runs.append({ - "run_id": optuna_study_id, - "name": study_name, - "source": source, - "trial_count": trial_count, - "best_value": best_value, - "avg_value": avg_value, - "first_trial": time_result['first_trial'] if time_result else None, - "last_trial": time_result['last_trial'] if time_result else None - }) + runs.append( + { + "run_id": optuna_study_id, + "name": study_name, + "source": source, + "trial_count": trial_count, + "best_value": best_value, + "avg_value": avg_value, + "first_trial": time_result["first_trial"] if time_result else None, + "last_trial": time_result["last_trial"] if time_result else None, + } + ) conn.close() - return { - "runs": runs, - "total_runs": len(runs), - "study_id": study_id - } + return {"runs": runs, "total_runs": len(runs), "study_id": study_id} except HTTPException: raise @@ -3016,122 +3083,126 @@ def intent_to_config(intent: dict, existing_config: Optional[dict] = None) -> di config = existing_config.copy() if existing_config else {} # Metadata - if intent.get('model', {}).get('path'): - model_path = Path(intent['model']['path']).name - if 'simulation' not in config: - config['simulation'] = {} - config['simulation']['model_file'] = model_path + if intent.get("model", {}).get("path"): + model_path = Path(intent["model"]["path"]).name + if "simulation" not in config: + config["simulation"] = {} + config["simulation"]["model_file"] = model_path # Try to infer other files from model name - base_name = model_path.replace('.prt', '') - if not config['simulation'].get('fem_file'): - config['simulation']['fem_file'] = f"{base_name}_fem1.fem" - if not config['simulation'].get('sim_file'): - config['simulation']['sim_file'] = f"{base_name}_sim1.sim" + base_name = model_path.replace(".prt", "") + if not config["simulation"].get("fem_file"): + config["simulation"]["fem_file"] = f"{base_name}_fem1.fem" + if not config["simulation"].get("sim_file"): + config["simulation"]["sim_file"] = f"{base_name}_sim1.sim" # Solver - if intent.get('solver', {}).get('type'): - solver_type = intent['solver']['type'] - if 'simulation' not in config: - config['simulation'] = {} - config['simulation']['solver'] = 'nastran' + if intent.get("solver", {}).get("type"): + solver_type = intent["solver"]["type"] + if "simulation" not in config: + config["simulation"] = {} + config["simulation"]["solver"] = "nastran" # Map SOL types to analysis_types sol_to_analysis = { - 'SOL101': ['static'], - 'SOL103': ['modal'], - 'SOL105': ['buckling'], - 'SOL106': ['nonlinear'], - 'SOL111': ['modal', 'frequency_response'], - 'SOL112': ['modal', 'transient'], + "SOL101": ["static"], + "SOL103": ["modal"], + "SOL105": ["buckling"], + "SOL106": ["nonlinear"], + "SOL111": ["modal", "frequency_response"], + "SOL112": ["modal", "transient"], } - config['simulation']['analysis_types'] = sol_to_analysis.get(solver_type, ['static']) + config["simulation"]["analysis_types"] = sol_to_analysis.get(solver_type, ["static"]) # Design Variables - if intent.get('design_variables'): - config['design_variables'] = [] - for dv in intent['design_variables']: - config['design_variables'].append({ - 'parameter': dv.get('name', dv.get('expression_name', '')), - 'bounds': [dv.get('min', 0), dv.get('max', 100)], - 'description': dv.get('description', f"Design variable: {dv.get('name', '')}"), - }) + if intent.get("design_variables"): + config["design_variables"] = [] + for dv in intent["design_variables"]: + config["design_variables"].append( + { + "parameter": dv.get("name", dv.get("expression_name", "")), + "bounds": [dv.get("min", 0), dv.get("max", 100)], + "description": dv.get("description", f"Design variable: {dv.get('name', '')}"), + } + ) # Extractors → used for objectives/constraints extraction extractor_map = {} - if intent.get('extractors'): - for ext in intent['extractors']: - ext_id = ext.get('id', '') - ext_name = ext.get('name', '') + if intent.get("extractors"): + for ext in intent["extractors"]: + ext_id = ext.get("id", "") + ext_name = ext.get("name", "") extractor_map[ext_name] = ext # Objectives - if intent.get('objectives'): - config['objectives'] = [] - for obj in intent['objectives']: + if intent.get("objectives"): + config["objectives"] = [] + for obj in intent["objectives"]: obj_config = { - 'name': obj.get('name', 'objective'), - 'goal': obj.get('direction', 'minimize'), - 'weight': obj.get('weight', 1.0), - 'description': obj.get('description', f"Objective: {obj.get('name', '')}"), + "name": obj.get("name", "objective"), + "goal": obj.get("direction", "minimize"), + "weight": obj.get("weight", 1.0), + "description": obj.get("description", f"Objective: {obj.get('name', '')}"), } # Add extraction config if extractor referenced - extractor_name = obj.get('extractor') + extractor_name = obj.get("extractor") if extractor_name and extractor_name in extractor_map: ext = extractor_map[extractor_name] - ext_config = ext.get('config', {}) - obj_config['extraction'] = { - 'action': _extractor_id_to_action(ext.get('id', '')), - 'domain': 'result_extraction', - 'params': ext_config, + ext_config = ext.get("config", {}) + obj_config["extraction"] = { + "action": _extractor_id_to_action(ext.get("id", "")), + "domain": "result_extraction", + "params": ext_config, } - config['objectives'].append(obj_config) + config["objectives"].append(obj_config) # Constraints - if intent.get('constraints'): - config['constraints'] = [] - for con in intent['constraints']: - op = con.get('operator', '<=') - con_type = 'less_than' if '<' in op else 'greater_than' if '>' in op else 'equal_to' + if intent.get("constraints"): + config["constraints"] = [] + for con in intent["constraints"]: + op = con.get("operator", "<=") + con_type = "less_than" if "<" in op else "greater_than" if ">" in op else "equal_to" con_config = { - 'name': con.get('name', 'constraint'), - 'type': con_type, - 'threshold': con.get('value', 0), - 'description': con.get('description', f"Constraint: {con.get('name', '')}"), + "name": con.get("name", "constraint"), + "type": con_type, + "threshold": con.get("value", 0), + "description": con.get("description", f"Constraint: {con.get('name', '')}"), } # Add extraction config if extractor referenced - extractor_name = con.get('extractor') + extractor_name = con.get("extractor") if extractor_name and extractor_name in extractor_map: ext = extractor_map[extractor_name] - ext_config = ext.get('config', {}) - con_config['extraction'] = { - 'action': _extractor_id_to_action(ext.get('id', '')), - 'domain': 'result_extraction', - 'params': ext_config, + ext_config = ext.get("config", {}) + con_config["extraction"] = { + "action": _extractor_id_to_action(ext.get("id", "")), + "domain": "result_extraction", + "params": ext_config, } - config['constraints'].append(con_config) + config["constraints"].append(con_config) # Optimization settings - if intent.get('optimization'): - opt = intent['optimization'] - if 'optimization_settings' not in config: - config['optimization_settings'] = {} - if opt.get('max_trials'): - config['optimization_settings']['n_trials'] = opt['max_trials'] - if opt.get('method'): + if intent.get("optimization"): + opt = intent["optimization"] + if "optimization_settings" not in config: + config["optimization_settings"] = {} + if opt.get("max_trials"): + config["optimization_settings"]["n_trials"] = opt["max_trials"] + if opt.get("method"): # Map method names to Optuna sampler names method_map = { - 'TPE': 'TPESampler', - 'CMA-ES': 'CmaEsSampler', - 'NSGA-II': 'NSGAIISampler', - 'RandomSearch': 'RandomSampler', - 'GP-BO': 'GPSampler', + "TPE": "TPESampler", + "CMA-ES": "CmaEsSampler", + "NSGA-II": "NSGAIISampler", + "RandomSearch": "RandomSampler", + "GP-BO": "GPSampler", } - config['optimization_settings']['sampler'] = method_map.get(opt['method'], opt['method']) + config["optimization_settings"]["sampler"] = method_map.get( + opt["method"], opt["method"] + ) # Surrogate - if intent.get('surrogate', {}).get('enabled'): - config['surrogate'] = { - 'type': intent['surrogate'].get('type', 'MLP'), - 'min_trials': intent['surrogate'].get('min_trials', 20), + if intent.get("surrogate", {}).get("enabled"): + config["surrogate"] = { + "type": intent["surrogate"].get("type", "MLP"), + "min_trials": intent["surrogate"].get("min_trials", 20), } return config @@ -3140,24 +3211,24 @@ def intent_to_config(intent: dict, existing_config: Optional[dict] = None) -> di def _extractor_id_to_action(ext_id: str) -> str: """Map extractor IDs (E1, E2, etc.) to extraction action names.""" action_map = { - 'E1': 'extract_displacement', - 'E2': 'extract_frequency', - 'E3': 'extract_stress', - 'E4': 'extract_mass', - 'E5': 'extract_mass', - 'E8': 'extract_zernike', - 'E9': 'extract_zernike', - 'E10': 'extract_zernike', - 'displacement': 'extract_displacement', - 'frequency': 'extract_frequency', - 'stress': 'extract_stress', - 'mass': 'extract_mass', - 'mass_bdf': 'extract_mass', - 'mass_cad': 'extract_mass', - 'zernike': 'extract_zernike', - 'zernike_opd': 'extract_zernike', + "E1": "extract_displacement", + "E2": "extract_frequency", + "E3": "extract_stress", + "E4": "extract_mass", + "E5": "extract_mass", + "E8": "extract_zernike", + "E9": "extract_zernike", + "E10": "extract_zernike", + "displacement": "extract_displacement", + "frequency": "extract_frequency", + "stress": "extract_stress", + "mass": "extract_mass", + "mass_bdf": "extract_mass", + "mass_cad": "extract_mass", + "zernike": "extract_zernike", + "zernike_opd": "extract_zernike", } - return action_map.get(ext_id, 'extract_displacement') + return action_map.get(ext_id, "extract_displacement") @router.put("/studies/{study_id}/config") @@ -3186,7 +3257,7 @@ async def update_study_config(study_id: str, request: UpdateConfigRequest): if is_optimization_running(study_id): raise HTTPException( status_code=409, - detail="Cannot modify config while optimization is running. Stop the optimization first." + detail="Cannot modify config while optimization is running. Stop the optimization first.", ) # Find config file location @@ -3195,10 +3266,12 @@ async def update_study_config(study_id: str, request: UpdateConfigRequest): config_file = study_dir / "optimization_config.json" if not config_file.exists(): - raise HTTPException(status_code=404, detail=f"Config file not found for study {study_id}") + raise HTTPException( + status_code=404, detail=f"Config file not found for study {study_id}" + ) # Backup existing config - backup_file = config_file.with_suffix('.json.backup') + backup_file = config_file.with_suffix(".json.backup") shutil.copy(config_file, backup_file) # Determine which format was provided @@ -3207,24 +3280,23 @@ async def update_study_config(study_id: str, request: UpdateConfigRequest): new_config = request.config elif request.intent is not None: # Convert intent to config, merging with existing - with open(config_file, 'r') as f: + with open(config_file, "r") as f: existing_config = json.load(f) new_config = intent_to_config(request.intent, existing_config) else: raise HTTPException( - status_code=400, - detail="Request must include either 'config' or 'intent' field" + status_code=400, detail="Request must include either 'config' or 'intent' field" ) # Write new config - with open(config_file, 'w') as f: + with open(config_file, "w") as f: json.dump(new_config, f, indent=2) return { "success": True, "message": "Configuration updated successfully", "path": str(config_file), - "backup_path": str(backup_file) + "backup_path": str(backup_file), } except HTTPException: @@ -3237,6 +3309,7 @@ async def update_study_config(study_id: str, request: UpdateConfigRequest): # Zernike Analysis Endpoints # ============================================================================ + @router.get("/studies/{study_id}/zernike-available") async def get_zernike_available_trials(study_id: str): """ @@ -3256,11 +3329,11 @@ async def get_zernike_available_trials(study_id: str): available_trials = [] for d in iter_base.iterdir(): - if d.is_dir() and d.name.startswith('iter'): + if d.is_dir() and d.name.startswith("iter"): # Check for OP2 file op2_files = list(d.glob("*.op2")) if op2_files: - iter_num_str = d.name.replace('iter', '') + iter_num_str = d.name.replace("iter", "") try: iter_num = int(iter_num_str) # Map iter number to trial number (iter1 -> trial 0, etc.) @@ -3276,7 +3349,7 @@ async def get_zernike_available_trials(study_id: str): return { "study_id": study_id, "available_trials": available_trials, - "count": len(available_trials) + "count": len(available_trials), } except HTTPException: @@ -3325,7 +3398,7 @@ async def get_trial_zernike(study_id: str, trial_number: int): if iter_dir is None: raise HTTPException( status_code=404, - detail=f"No FEA results for trial {trial_number}. This trial may have used surrogate model (NN) prediction instead of full FEA simulation. Zernike analysis requires OP2 results from actual FEA runs." + detail=f"No FEA results for trial {trial_number}. This trial may have used surrogate model (NN) prediction instead of full FEA simulation. Zernike analysis requires OP2 results from actual FEA runs.", ) # Check for OP2 file BEFORE doing expensive imports @@ -3333,7 +3406,7 @@ async def get_trial_zernike(study_id: str, trial_number: int): if not op2_files: raise HTTPException( status_code=404, - detail=f"No OP2 results file found in {iter_dir.name}. FEA may not have completed." + detail=f"No OP2 results file found in {iter_dir.name}. FEA may not have completed.", ) # Only import heavy dependencies after we know we have an OP2 file @@ -3356,12 +3429,15 @@ async def get_trial_zernike(study_id: str, trial_number: int): PANCAKE = 3.0 # Z-axis range multiplier PLOT_DOWNSAMPLE = 5000 # Reduced for faster loading FILTER_LOW_ORDERS = 4 - COLORSCALE = 'Turbo' # Colorscale: 'RdBu_r', 'Viridis', 'Plasma', 'Turbo' + COLORSCALE = "Turbo" # Colorscale: 'RdBu_r', 'Viridis', 'Plasma', 'Turbo' SUBCASE_MAP = { - '1': '90', '2': '20', '3': '40', '4': '60', + "1": "90", + "2": "20", + "3": "40", + "4": "60", } - REF_SUBCASE = '2' + REF_SUBCASE = "2" def noll_indices(j: int): if j < 1: @@ -3372,9 +3448,9 @@ async def get_trial_zernike(study_id: str, trial_number: int): if n == 0: ms = [0] elif n % 2 == 0: - ms = [0] + [m for k in range(1, n//2 + 1) for m in (-2*k, 2*k)] + ms = [0] + [m for k in range(1, n // 2 + 1) for m in (-2 * k, 2 * k)] else: - ms = [m for k in range(0, (n+1)//2) for m in (-(2*k+1), (2*k+1))] + ms = [m for k in range(0, (n + 1) // 2) for m in (-(2 * k + 1), (2 * k + 1))] for m in ms: count += 1 if count == j: @@ -3384,26 +3460,44 @@ async def get_trial_zernike(study_id: str, trial_number: int): def zernike_noll(j: int, r: np.ndarray, th: np.ndarray) -> np.ndarray: n, m = noll_indices(j) R = np.zeros_like(r) - for s in range((n-abs(m))//2 + 1): - c = ((-1)**s * factorial(n-s) / - (factorial(s) * - factorial((n+abs(m))//2 - s) * - factorial((n-abs(m))//2 - s))) - R += c * r**(n-2*s) + for s in range((n - abs(m)) // 2 + 1): + c = ( + (-1) ** s + * factorial(n - s) + / ( + factorial(s) + * factorial((n + abs(m)) // 2 - s) + * factorial((n - abs(m)) // 2 - s) + ) + ) + R += c * r ** (n - 2 * s) if m == 0: return R - return R * (np.cos(m*th) if m > 0 else np.sin(-m*th)) + return R * (np.cos(m * th) if m > 0 else np.sin(-m * th)) def zernike_common_name(n: int, m: int) -> str: names = { - (0, 0): "Piston", (1, -1): "Tilt X", (1, 1): "Tilt Y", - (2, 0): "Defocus", (2, -2): "Astig 45°", (2, 2): "Astig 0°", - (3, -1): "Coma X", (3, 1): "Coma Y", (3, -3): "Trefoil X", (3, 3): "Trefoil Y", - (4, 0): "Primary Spherical", (4, -2): "Sec Astig X", (4, 2): "Sec Astig Y", - (4, -4): "Quadrafoil X", (4, 4): "Quadrafoil Y", - (5, -1): "Sec Coma X", (5, 1): "Sec Coma Y", - (5, -3): "Sec Trefoil X", (5, 3): "Sec Trefoil Y", - (5, -5): "Pentafoil X", (5, 5): "Pentafoil Y", + (0, 0): "Piston", + (1, -1): "Tilt X", + (1, 1): "Tilt Y", + (2, 0): "Defocus", + (2, -2): "Astig 45°", + (2, 2): "Astig 0°", + (3, -1): "Coma X", + (3, 1): "Coma Y", + (3, -3): "Trefoil X", + (3, 3): "Trefoil Y", + (4, 0): "Primary Spherical", + (4, -2): "Sec Astig X", + (4, 2): "Sec Astig Y", + (4, -4): "Quadrafoil X", + (4, 4): "Quadrafoil Y", + (5, -1): "Sec Coma X", + (5, 1): "Sec Coma Y", + (5, -3): "Sec Trefoil X", + (5, 3): "Sec Trefoil Y", + (5, -5): "Pentafoil X", + (5, 5): "Pentafoil Y", (6, 0): "Sec Spherical", } return names.get((n, m), f"Z(n={n}, m={m})") @@ -3415,18 +3509,24 @@ async def get_trial_zernike(study_id: str, trial_number: int): def compute_manufacturing_metrics(coefficients: np.ndarray) -> dict: """Compute manufacturing-related aberration metrics.""" return { - 'defocus_nm': float(abs(coefficients[3])), # J4 - 'astigmatism_rms': float(np.sqrt(coefficients[4]**2 + coefficients[5]**2)), # J5+J6 - 'coma_rms': float(np.sqrt(coefficients[6]**2 + coefficients[7]**2)), # J7+J8 - 'trefoil_rms': float(np.sqrt(coefficients[8]**2 + coefficients[9]**2)), # J9+J10 - 'spherical_nm': float(abs(coefficients[10])) if len(coefficients) > 10 else 0.0, # J11 + "defocus_nm": float(abs(coefficients[3])), # J4 + "astigmatism_rms": float( + np.sqrt(coefficients[4] ** 2 + coefficients[5] ** 2) + ), # J5+J6 + "coma_rms": float(np.sqrt(coefficients[6] ** 2 + coefficients[7] ** 2)), # J7+J8 + "trefoil_rms": float( + np.sqrt(coefficients[8] ** 2 + coefficients[9] ** 2) + ), # J9+J10 + "spherical_nm": float(abs(coefficients[10])) + if len(coefficients) > 10 + else 0.0, # J11 } def compute_rms_filter_j1to3(X, Y, W_nm, coefficients, R): """Compute RMS with J1-J3 filtered (keeping defocus for optician workload).""" Xc = X - np.mean(X) Yc = Y - np.mean(Y) - r = np.hypot(Xc/R, Yc/R) + r = np.hypot(Xc / R, Yc / R) th = np.arctan2(Yc, Xc) Z_j1to3 = np.column_stack([zernike_noll(j, r, th) for j in range(1, 4)]) W_filter_j1to3 = W_nm - Z_j1to3 @ coefficients[:3] @@ -3441,20 +3541,20 @@ async def get_trial_zernike(study_id: str, trial_number: int): rms_global: float, rms_filtered: float, ref_title: str = "20 deg", - abs_pair = None, + abs_pair=None, is_manufacturing: bool = False, mfg_metrics: dict = None, - correction_metrics: dict = None + correction_metrics: dict = None, ) -> str: """Generate HTML string for Zernike visualization with full tables.""" # Compute residual surface (filtered) Xc = X - np.mean(X) Yc = Y - np.mean(Y) R = float(np.max(np.hypot(Xc, Yc))) - r = np.hypot(Xc/R, Yc/R) + r = np.hypot(Xc / R, Yc / R) th = np.arctan2(Yc, Xc) - Z = np.column_stack([zernike_noll(j, r, th) for j in range(1, N_MODES+1)]) + Z = np.column_stack([zernike_noll(j, r, th) for j in range(1, N_MODES + 1)]) W_res_filt = W_nm - Z[:, :FILTER_LOW_ORDERS].dot(coefficients[:FILTER_LOW_ORDERS]) # Compute J1-J3 filtered RMS (optician workload metric) @@ -3479,33 +3579,33 @@ async def get_trial_zernike(study_id: str, trial_number: int): if tri.triangles is not None and len(tri.triangles) > 0: i_idx, j_idx, k_idx = tri.triangles.T surface_trace = go.Mesh3d( - x=Xp.tolist(), y=Yp.tolist(), z=res_amp.tolist(), - i=i_idx.tolist(), j=j_idx.tolist(), k=k_idx.tolist(), + x=Xp.tolist(), + y=Yp.tolist(), + z=res_amp.tolist(), + i=i_idx.tolist(), + j=j_idx.tolist(), + k=k_idx.tolist(), intensity=res_amp.tolist(), colorscale=COLORSCALE, opacity=1.0, flatshading=False, # Smooth shading lighting=dict( - ambient=0.4, - diffuse=0.8, - specular=0.3, - roughness=0.5, - fresnel=0.2 + ambient=0.4, diffuse=0.8, specular=0.3, roughness=0.5, fresnel=0.2 ), lightposition=dict(x=100, y=200, z=300), showscale=True, colorbar=dict( - title=dict(text="Residual (nm)", side='right'), + title=dict(text="Residual (nm)", side="right"), thickness=15, len=0.6, - tickformat=".1f" + tickformat=".1f", ), - hovertemplate="X: %{x:.1f}
Y: %{y:.1f}
Residual: %{z:.2f} nm" + hovertemplate="X: %{x:.1f}
Y: %{y:.1f}
Residual: %{z:.2f} nm", ) except Exception as e: print(f"Triangulation failed: {e}") - labels = [zernike_label(j) for j in range(1, N_MODES+1)] + labels = [zernike_label(j) for j in range(1, N_MODES + 1)] coeff_abs = np.abs(coefficients) mfg = compute_manufacturing_metrics(coefficients) @@ -3513,12 +3613,15 @@ async def get_trial_zernike(study_id: str, trial_number: int): if is_manufacturing and mfg_metrics and correction_metrics: # Manufacturing view: 5 rows fig = make_subplots( - rows=5, cols=1, - specs=[[{"type": "scene"}], - [{"type": "table"}], - [{"type": "table"}], - [{"type": "table"}], - [{"type": "xy"}]], + rows=5, + cols=1, + specs=[ + [{"type": "scene"}], + [{"type": "table"}], + [{"type": "table"}], + [{"type": "table"}], + [{"type": "xy"}], + ], row_heights=[0.35, 0.10, 0.15, 0.15, 0.25], vertical_spacing=0.025, subplot_titles=[ @@ -3526,25 +3629,28 @@ async def get_trial_zernike(study_id: str, trial_number: int): "RMS Metrics", "Mode Magnitudes (Absolute 90 deg)", "Pre-Correction (90 deg - 20 deg)", - f"Zernike Coefficients ({N_MODES} modes)" - ] + f"Zernike Coefficients ({N_MODES} modes)", + ], ) else: # Standard relative view: 4 rows with full coefficient table fig = make_subplots( - rows=4, cols=1, - specs=[[{"type": "scene"}], - [{"type": "table"}], - [{"type": "table"}], - [{"type": "xy"}]], + rows=4, + cols=1, + specs=[ + [{"type": "scene"}], + [{"type": "table"}], + [{"type": "table"}], + [{"type": "xy"}], + ], row_heights=[0.40, 0.12, 0.28, 0.20], vertical_spacing=0.03, subplot_titles=[ f"Surface Residual (relative to {ref_title})", "RMS Metrics", f"Zernike Coefficients ({N_MODES} modes)", - "Top 20 |Zernike Coefficients| (nm)" - ] + "Top 20 |Zernike Coefficients| (nm)", + ], ) # Add surface mesh (or fallback to scatter) @@ -3552,123 +3658,181 @@ async def get_trial_zernike(study_id: str, trial_number: int): fig.add_trace(surface_trace, row=1, col=1) else: # Fallback to scatter if triangulation failed - fig.add_trace(go.Scatter3d( - x=Xp.tolist(), y=Yp.tolist(), z=res_amp.tolist(), - mode='markers', - marker=dict(size=2, color=res_amp.tolist(), colorscale=COLORSCALE, showscale=True), - showlegend=False - ), row=1, col=1) + fig.add_trace( + go.Scatter3d( + x=Xp.tolist(), + y=Yp.tolist(), + z=res_amp.tolist(), + mode="markers", + marker=dict( + size=2, color=res_amp.tolist(), colorscale=COLORSCALE, showscale=True + ), + showlegend=False, + ), + row=1, + col=1, + ) fig.update_scenes( - camera=dict( - eye=dict(x=1.2, y=1.2, z=0.8), - up=dict(x=0, y=0, z=1) - ), + camera=dict(eye=dict(x=1.2, y=1.2, z=0.8), up=dict(x=0, y=0, z=1)), xaxis=dict( title="X (mm)", showgrid=True, - gridcolor='rgba(128,128,128,0.3)', + gridcolor="rgba(128,128,128,0.3)", showbackground=True, - backgroundcolor='rgba(240,240,240,0.9)' + backgroundcolor="rgba(240,240,240,0.9)", ), yaxis=dict( title="Y (mm)", showgrid=True, - gridcolor='rgba(128,128,128,0.3)', + gridcolor="rgba(128,128,128,0.3)", showbackground=True, - backgroundcolor='rgba(240,240,240,0.9)' + backgroundcolor="rgba(240,240,240,0.9)", ), zaxis=dict( title="Residual (nm)", range=[-max_amp * PANCAKE, max_amp * PANCAKE], showgrid=True, - gridcolor='rgba(128,128,128,0.3)', + gridcolor="rgba(128,128,128,0.3)", showbackground=True, - backgroundcolor='rgba(230,230,250,0.9)' + backgroundcolor="rgba(230,230,250,0.9)", ), - aspectmode='manual', - aspectratio=dict(x=1, y=1, z=0.4) + aspectmode="manual", + aspectratio=dict(x=1, y=1, z=0.4), ) # Row 2: RMS table with all metrics if abs_pair is not None: abs_global, abs_filtered = abs_pair - fig.add_trace(go.Table( - header=dict( - values=["Metric", "Relative (nm)", "Absolute (nm)"], - align="left", - fill_color='rgb(55, 83, 109)', - font=dict(color='white', size=12) + fig.add_trace( + go.Table( + header=dict( + values=[ + "Metric", + "Relative (nm)", + "Absolute (nm)", + ], + align="left", + fill_color="rgb(55, 83, 109)", + font=dict(color="white", size=12), + ), + cells=dict( + values=[ + [ + "Global RMS", + "Filtered RMS (J1-J4)", + "Filtered RMS (J1-J3, w/ defocus)", + ], + [ + f"{rms_global:.2f}", + f"{rms_filtered:.2f}", + f"{rms_filter_j1to3:.2f}", + ], + [f"{abs_global:.2f}", f"{abs_filtered:.2f}", "-"], + ], + align="left", + fill_color="rgb(243, 243, 243)", + ), ), - cells=dict( - values=[ - ["Global RMS", "Filtered RMS (J1-J4)", "Filtered RMS (J1-J3, w/ defocus)"], - [f"{rms_global:.2f}", f"{rms_filtered:.2f}", f"{rms_filter_j1to3:.2f}"], - [f"{abs_global:.2f}", f"{abs_filtered:.2f}", "-"], - ], - align="left", - fill_color='rgb(243, 243, 243)' - ) - ), row=2, col=1) + row=2, + col=1, + ) else: - fig.add_trace(go.Table( - header=dict( - values=["Metric", "Value (nm)"], - align="left", - fill_color='rgb(55, 83, 109)', - font=dict(color='white', size=12) + fig.add_trace( + go.Table( + header=dict( + values=["Metric", "Value (nm)"], + align="left", + fill_color="rgb(55, 83, 109)", + font=dict(color="white", size=12), + ), + cells=dict( + values=[ + [ + "Global RMS", + "Filtered RMS (J1-J4)", + "Filtered RMS (J1-J3, w/ defocus)", + ], + [ + f"{rms_global:.2f}", + f"{rms_filtered:.2f}", + f"{rms_filter_j1to3:.2f}", + ], + ], + align="left", + fill_color="rgb(243, 243, 243)", + ), ), - cells=dict( - values=[ - ["Global RMS", "Filtered RMS (J1-J4)", "Filtered RMS (J1-J3, w/ defocus)"], - [f"{rms_global:.2f}", f"{rms_filtered:.2f}", f"{rms_filter_j1to3:.2f}"] - ], - align="left", - fill_color='rgb(243, 243, 243)' - ) - ), row=2, col=1) + row=2, + col=1, + ) if is_manufacturing and mfg_metrics and correction_metrics: # Row 3: Mode magnitudes at 90 deg (absolute) - fig.add_trace(go.Table( - header=dict( - values=["Mode", "Value (nm)"], - align="left", - fill_color='rgb(55, 83, 109)', - font=dict(color='white', size=11) + fig.add_trace( + go.Table( + header=dict( + values=["Mode", "Value (nm)"], + align="left", + fill_color="rgb(55, 83, 109)", + font=dict(color="white", size=11), + ), + cells=dict( + values=[ + [ + "Defocus (J4)", + "Astigmatism (J5+J6)", + "Coma (J7+J8)", + "Trefoil (J9+J10)", + "Spherical (J11)", + ], + [ + f"{mfg_metrics['defocus_nm']:.2f}", + f"{mfg_metrics['astigmatism_rms']:.2f}", + f"{mfg_metrics['coma_rms']:.2f}", + f"{mfg_metrics['trefoil_rms']:.2f}", + f"{mfg_metrics['spherical_nm']:.2f}", + ], + ], + align="left", + fill_color="rgb(243, 243, 243)", + ), ), - cells=dict( - values=[ - ["Defocus (J4)", "Astigmatism (J5+J6)", "Coma (J7+J8)", "Trefoil (J9+J10)", "Spherical (J11)"], - [f"{mfg_metrics['defocus_nm']:.2f}", f"{mfg_metrics['astigmatism_rms']:.2f}", - f"{mfg_metrics['coma_rms']:.2f}", f"{mfg_metrics['trefoil_rms']:.2f}", - f"{mfg_metrics['spherical_nm']:.2f}"] - ], - align="left", - fill_color='rgb(243, 243, 243)' - ) - ), row=3, col=1) + row=3, + col=1, + ) # Row 4: Pre-correction (90 deg - 20 deg) - fig.add_trace(go.Table( - header=dict( - values=["Correction Mode", "Value (nm)"], - align="left", - fill_color='rgb(55, 83, 109)', - font=dict(color='white', size=11) + fig.add_trace( + go.Table( + header=dict( + values=["Correction Mode", "Value (nm)"], + align="left", + fill_color="rgb(55, 83, 109)", + font=dict(color="white", size=11), + ), + cells=dict( + values=[ + [ + "Total RMS (J1-J3 filter)", + "Defocus (J4)", + "Astigmatism (J5+J6)", + "Coma (J7+J8)", + ], + [ + f"{correction_metrics.get('rms_filter_j1to3', 0):.2f}", + f"{correction_metrics['defocus_nm']:.2f}", + f"{correction_metrics['astigmatism_rms']:.2f}", + f"{correction_metrics['coma_rms']:.2f}", + ], + ], + align="left", + fill_color="rgb(243, 243, 243)", + ), ), - cells=dict( - values=[ - ["Total RMS (J1-J3 filter)", "Defocus (J4)", "Astigmatism (J5+J6)", "Coma (J7+J8)"], - [f"{correction_metrics.get('rms_filter_j1to3', 0):.2f}", - f"{correction_metrics['defocus_nm']:.2f}", - f"{correction_metrics['astigmatism_rms']:.2f}", - f"{correction_metrics['coma_rms']:.2f}"] - ], - align="left", - fill_color='rgb(243, 243, 243)' - ) - ), row=4, col=1) + row=4, + col=1, + ) # Row 5: Bar chart sorted_idx = np.argsort(coeff_abs)[::-1][:20] @@ -3676,35 +3840,45 @@ async def get_trial_zernike(study_id: str, trial_number: int): go.Bar( x=[float(coeff_abs[i]) for i in sorted_idx], y=[labels[i] for i in sorted_idx], - orientation='h', - marker_color='rgb(55, 83, 109)', + orientation="h", + marker_color="rgb(55, 83, 109)", hovertemplate="%{y}
|Coeff| = %{x:.3f} nm", - showlegend=False + showlegend=False, ), - row=5, col=1 + row=5, + col=1, ) else: # Row 3: Full coefficient table - fig.add_trace(go.Table( - header=dict( - values=["Noll j", "Mode Name", "Coeff (nm)", "|Coeff| (nm)"], - align="left", - fill_color='rgb(55, 83, 109)', - font=dict(color='white', size=11) + fig.add_trace( + go.Table( + header=dict( + values=[ + "Noll j", + "Mode Name", + "Coeff (nm)", + "|Coeff| (nm)", + ], + align="left", + fill_color="rgb(55, 83, 109)", + font=dict(color="white", size=11), + ), + cells=dict( + values=[ + list(range(1, N_MODES + 1)), + labels, + [f"{c:+.3f}" for c in coefficients], + [f"{abs(c):.3f}" for c in coefficients], + ], + align="left", + fill_color="rgb(243, 243, 243)", + font=dict(size=10), + height=22, + ), ), - cells=dict( - values=[ - list(range(1, N_MODES+1)), - labels, - [f"{c:+.3f}" for c in coefficients], - [f"{abs(c):.3f}" for c in coefficients] - ], - align="left", - fill_color='rgb(243, 243, 243)', - font=dict(size=10), - height=22 - ) - ), row=3, col=1) + row=3, + col=1, + ) # Row 4: Bar chart - top 20 modes by magnitude sorted_idx = np.argsort(coeff_abs)[::-1][:20] @@ -3712,28 +3886,25 @@ async def get_trial_zernike(study_id: str, trial_number: int): go.Bar( x=[float(coeff_abs[i]) for i in sorted_idx], y=[labels[i] for i in sorted_idx], - orientation='h', - marker_color='rgb(55, 83, 109)', + orientation="h", + marker_color="rgb(55, 83, 109)", hovertemplate="%{y}
|Coeff| = %{x:.3f} nm", - showlegend=False + showlegend=False, ), - row=4, col=1 + row=4, + col=1, ) fig.update_layout( width=1400, height=1800 if is_manufacturing else 1600, margin=dict(t=80, b=20, l=20, r=20), - title=dict( - text=f"{title}", - font=dict(size=20), - x=0.5 - ), - paper_bgcolor='white', - plot_bgcolor='white' + title=dict(text=f"{title}", font=dict(size=20), x=0.5), + paper_bgcolor="white", + plot_bgcolor="white", ) - return fig.to_html(include_plotlyjs='cdn', full_html=True) + return fig.to_html(include_plotlyjs="cdn", full_html=True) # ===================================================================== # NEW: Use OPD method (accounts for lateral X,Y displacement) @@ -3748,20 +3919,17 @@ async def get_trial_zernike(study_id: str, trial_number: int): str(op2_path), bdf_path=str(bdf_path), n_modes=N_MODES, - filter_orders=FILTER_LOW_ORDERS + filter_orders=FILTER_LOW_ORDERS, ) except Exception as e: print(f"OPD extractor failed, falling back to Standard: {e}") use_opd = False # Also create Standard extractor for comparison - std_extractor = ZernikeExtractor(str(op2_path), displacement_unit='mm', n_modes=N_MODES) + std_extractor = ZernikeExtractor(str(op2_path), displacement_unit="mm", n_modes=N_MODES) def generate_dual_method_html( - title: str, - target_sc: str, - ref_sc: str, - is_manufacturing: bool = False + title: str, target_sc: str, ref_sc: str, is_manufacturing: bool = False ) -> tuple: """Generate HTML with OPD method and displacement component views. @@ -3788,25 +3956,25 @@ async def get_trial_zernike(study_id: str, trial_number: int): opd_ref_data = opd_extractor._build_figure_opd_data(ref_sc) # Build relative displacement arrays (node-by-node) - ref_map = {int(nid): i for i, nid in enumerate(opd_ref_data['node_ids'])} + ref_map = {int(nid): i for i, nid in enumerate(opd_ref_data["node_ids"])} X_list, Y_list, WFE_list = [], [], [] dx_list, dy_list, dz_list = [], [], [] - for i, nid in enumerate(opd_data['node_ids']): + for i, nid in enumerate(opd_data["node_ids"]): nid = int(nid) if nid not in ref_map: continue ref_idx = ref_map[nid] # Use deformed coordinates from OPD - X_list.append(opd_data['x_deformed'][i]) - Y_list.append(opd_data['y_deformed'][i]) - WFE_list.append(opd_data['wfe_nm'][i] - opd_ref_data['wfe_nm'][ref_idx]) + X_list.append(opd_data["x_deformed"][i]) + Y_list.append(opd_data["y_deformed"][i]) + WFE_list.append(opd_data["wfe_nm"][i] - opd_ref_data["wfe_nm"][ref_idx]) # Relative displacements (target - reference) - dx_list.append(opd_data['dx'][i] - opd_ref_data['dx'][ref_idx]) - dy_list.append(opd_data['dy'][i] - opd_ref_data['dy'][ref_idx]) - dz_list.append(opd_data['dz'][i] - opd_ref_data['dz'][ref_idx]) + dx_list.append(opd_data["dx"][i] - opd_ref_data["dx"][ref_idx]) + dy_list.append(opd_data["dy"][i] - opd_ref_data["dy"][ref_idx]) + dz_list.append(opd_data["dz"][i] - opd_ref_data["dz"][ref_idx]) X = np.array(X_list) Y = np.array(Y_list) @@ -3820,19 +3988,19 @@ async def get_trial_zernike(study_id: str, trial_number: int): max_lateral = float(np.max(np.abs(lateral_um))) rms_lateral = float(np.sqrt(np.mean(lateral_um**2))) - rms_global_opd = opd_rel['relative_global_rms_nm'] - rms_filtered_opd = opd_rel['relative_filtered_rms_nm'] - coefficients = np.array(opd_rel.get('delta_coefficients', std_rel['coefficients'])) + rms_global_opd = opd_rel["relative_global_rms_nm"] + rms_filtered_opd = opd_rel["relative_filtered_rms_nm"] + coefficients = np.array(opd_rel.get("delta_coefficients", std_rel["coefficients"])) else: # Fallback to Standard method arrays target_disp = std_extractor.displacements[target_sc] ref_disp = std_extractor.displacements[ref_sc] - ref_map = {int(nid): i for i, nid in enumerate(ref_disp['node_ids'])} + ref_map = {int(nid): i for i, nid in enumerate(ref_disp["node_ids"])} X_list, Y_list, W_list = [], [], [] dx_list, dy_list, dz_list = [], [], [] - for i, nid in enumerate(target_disp['node_ids']): + for i, nid in enumerate(target_disp["node_ids"]): nid = int(nid) if nid not in ref_map: continue @@ -3844,14 +4012,20 @@ async def get_trial_zernike(study_id: str, trial_number: int): X_list.append(geo[0]) Y_list.append(geo[1]) - target_wfe = target_disp['disp'][i, 2] * std_extractor.wfe_factor - ref_wfe = ref_disp['disp'][ref_idx, 2] * std_extractor.wfe_factor + target_wfe = target_disp["disp"][i, 2] * std_extractor.wfe_factor + ref_wfe = ref_disp["disp"][ref_idx, 2] * std_extractor.wfe_factor W_list.append(target_wfe - ref_wfe) # Relative displacements (mm to µm) - dx_list.append((target_disp['disp'][i, 0] - ref_disp['disp'][ref_idx, 0]) * 1000.0) - dy_list.append((target_disp['disp'][i, 1] - ref_disp['disp'][ref_idx, 1]) * 1000.0) - dz_list.append((target_disp['disp'][i, 2] - ref_disp['disp'][ref_idx, 2]) * 1000.0) + dx_list.append( + (target_disp["disp"][i, 0] - ref_disp["disp"][ref_idx, 0]) * 1000.0 + ) + dy_list.append( + (target_disp["disp"][i, 1] - ref_disp["disp"][ref_idx, 1]) * 1000.0 + ) + dz_list.append( + (target_disp["disp"][i, 2] - ref_disp["disp"][ref_idx, 2]) * 1000.0 + ) X = np.array(X_list) Y = np.array(Y_list) @@ -3864,22 +4038,22 @@ async def get_trial_zernike(study_id: str, trial_number: int): max_lateral = float(np.max(np.abs(lateral_um))) rms_lateral = float(np.sqrt(np.mean(lateral_um**2))) - rms_global_opd = std_rel['relative_global_rms_nm'] - rms_filtered_opd = std_rel['relative_filtered_rms_nm'] - coefficients = np.array(std_rel['coefficients']) + rms_global_opd = std_rel["relative_global_rms_nm"] + rms_filtered_opd = std_rel["relative_filtered_rms_nm"] + coefficients = np.array(std_rel["coefficients"]) # Standard method RMS values - rms_global_std = std_rel['relative_global_rms_nm'] - rms_filtered_std = std_rel['relative_filtered_rms_nm'] + rms_global_std = std_rel["relative_global_rms_nm"] + rms_filtered_std = std_rel["relative_filtered_rms_nm"] # Compute residual surface Xc = X - np.mean(X) Yc = Y - np.mean(Y) R = float(np.max(np.hypot(Xc, Yc))) - r = np.hypot(Xc/R, Yc/R) + r = np.hypot(Xc / R, Yc / R) th = np.arctan2(Yc, Xc) - Z_basis = np.column_stack([zernike_noll(j, r, th) for j in range(1, N_MODES+1)]) + Z_basis = np.column_stack([zernike_noll(j, r, th) for j in range(1, N_MODES + 1)]) W_res_filt = W - Z_basis[:, :FILTER_LOW_ORDERS].dot(coefficients[:FILTER_LOW_ORDERS]) # Downsample for display @@ -3904,30 +4078,41 @@ async def get_trial_zernike(study_id: str, trial_number: int): if tri.triangles is not None and len(tri.triangles) > 0: i_idx, j_idx, k_idx = tri.triangles.T return go.Mesh3d( - x=Xp.tolist(), y=Yp.tolist(), z=Zp.tolist(), - i=i_idx.tolist(), j=j_idx.tolist(), k=k_idx.tolist(), + x=Xp.tolist(), + y=Yp.tolist(), + z=Zp.tolist(), + i=i_idx.tolist(), + j=j_idx.tolist(), + k=k_idx.tolist(), intensity=Zp.tolist(), colorscale=colorscale, opacity=1.0, flatshading=False, - lighting=dict(ambient=0.4, diffuse=0.8, specular=0.3, roughness=0.5, fresnel=0.2), + lighting=dict( + ambient=0.4, diffuse=0.8, specular=0.3, roughness=0.5, fresnel=0.2 + ), lightposition=dict(x=100, y=200, z=300), showscale=True, - colorbar=dict(title=dict(text=colorbar_title, side='right'), thickness=15, len=0.5), - hovertemplate=f"X: %{{x:.1f}}
Y: %{{y:.1f}}
{unit}: %{{z:.3f}}" + colorbar=dict( + title=dict(text=colorbar_title, side="right"), thickness=15, len=0.5 + ), + hovertemplate=f"X: %{{x:.1f}}
Y: %{{y:.1f}}
{unit}: %{{z:.3f}}", ) except Exception: pass return go.Scatter3d( - x=Xp.tolist(), y=Yp.tolist(), z=Zp.tolist(), - mode='markers', marker=dict(size=2, color=Zp.tolist(), colorscale=colorscale, showscale=True) + x=Xp.tolist(), + y=Yp.tolist(), + z=Zp.tolist(), + mode="markers", + marker=dict(size=2, color=Zp.tolist(), colorscale=colorscale, showscale=True), ) # Build traces for each view trace_wfe = build_mesh_trace(res_amp, COLORSCALE, "WFE (nm)", "WFE nm") - trace_dx = build_mesh_trace(dxp, 'RdBu_r', "ΔX (µm)", "ΔX µm") - trace_dy = build_mesh_trace(dyp, 'RdBu_r', "ΔY (µm)", "ΔY µm") - trace_dz = build_mesh_trace(dzp, 'RdBu_r', "ΔZ (µm)", "ΔZ µm") + trace_dx = build_mesh_trace(dxp, "RdBu_r", "ΔX (µm)", "ΔX µm") + trace_dy = build_mesh_trace(dyp, "RdBu_r", "ΔY (µm)", "ΔY µm") + trace_dz = build_mesh_trace(dzp, "RdBu_r", "ΔZ (µm)", "ΔZ µm") # Create figure with dropdown to switch views fig = go.Figure() @@ -3949,18 +4134,31 @@ async def get_trial_zernike(study_id: str, trial_number: int): dict( type="buttons", direction="right", - x=0.0, y=1.12, + x=0.0, + y=1.12, xanchor="left", showactive=True, buttons=[ - dict(label="WFE (nm)", method="update", - args=[{"visible": [True, False, False, False]}]), - dict(label="ΔX (µm)", method="update", - args=[{"visible": [False, True, False, False]}]), - dict(label="ΔY (µm)", method="update", - args=[{"visible": [False, False, True, False]}]), - dict(label="ΔZ (µm)", method="update", - args=[{"visible": [False, False, False, True]}]), + dict( + label="WFE (nm)", + method="update", + args=[{"visible": [True, False, False, False]}], + ), + dict( + label="ΔX (µm)", + method="update", + args=[{"visible": [False, True, False, False]}], + ), + dict( + label="ΔY (µm)", + method="update", + args=[{"visible": [False, False, True, False]}], + ), + dict( + label="ΔZ (µm)", + method="update", + args=[{"visible": [False, False, False, True]}], + ), ], font=dict(size=12), pad=dict(r=10, t=10), @@ -3969,12 +4167,16 @@ async def get_trial_zernike(study_id: str, trial_number: int): ) # Compute method difference - pct_diff = 100.0 * (rms_filtered_opd - rms_filtered_std) / rms_filtered_std if rms_filtered_std > 0 else 0.0 + pct_diff = ( + 100.0 * (rms_filtered_opd - rms_filtered_std) / rms_filtered_std + if rms_filtered_std > 0 + else 0.0 + ) # Annotations for metrics method_label = "OPD (X,Y,Z)" if use_opd else "Standard (Z-only)" annotations_text = f""" -Method: {method_label} {'← More Accurate' if use_opd else '(BDF not found)'} +Method: {method_label} {"← More Accurate" if use_opd else "(BDF not found)"} RMS Metrics (Filtered J1-J4): • OPD: {rms_filtered_opd:.2f} nm • Standard: {rms_filtered_std:.2f} nm @@ -3989,19 +4191,39 @@ async def get_trial_zernike(study_id: str, trial_number: int): """ # Z-axis range for different views - max_disp = max(float(np.max(np.abs(dxp))), float(np.max(np.abs(dyp))), float(np.max(np.abs(dzp))), 0.1) + max_disp = max( + float(np.max(np.abs(dxp))), + float(np.max(np.abs(dyp))), + float(np.max(np.abs(dzp))), + 0.1, + ) fig.update_layout( scene=dict( camera=dict(eye=dict(x=1.2, y=1.2, z=0.8), up=dict(x=0, y=0, z=1)), - xaxis=dict(title="X (mm)", showgrid=True, gridcolor='rgba(128,128,128,0.3)', - showbackground=True, backgroundcolor='rgba(240,240,240,0.9)'), - yaxis=dict(title="Y (mm)", showgrid=True, gridcolor='rgba(128,128,128,0.3)', - showbackground=True, backgroundcolor='rgba(240,240,240,0.9)'), - zaxis=dict(title="Value", showgrid=True, gridcolor='rgba(128,128,128,0.3)', - showbackground=True, backgroundcolor='rgba(230,230,250,0.9)'), - aspectmode='manual', - aspectratio=dict(x=1, y=1, z=0.4) + xaxis=dict( + title="X (mm)", + showgrid=True, + gridcolor="rgba(128,128,128,0.3)", + showbackground=True, + backgroundcolor="rgba(240,240,240,0.9)", + ), + yaxis=dict( + title="Y (mm)", + showgrid=True, + gridcolor="rgba(128,128,128,0.3)", + showbackground=True, + backgroundcolor="rgba(240,240,240,0.9)", + ), + zaxis=dict( + title="Value", + showgrid=True, + gridcolor="rgba(128,128,128,0.3)", + showbackground=True, + backgroundcolor="rgba(230,230,250,0.9)", + ), + aspectmode="manual", + aspectratio=dict(x=1, y=1, z=0.4), ), width=1400, height=900, @@ -4009,51 +4231,62 @@ async def get_trial_zernike(study_id: str, trial_number: int): title=dict( text=f"{title}
Click buttons to switch: WFE, ΔX, ΔY, ΔZ", font=dict(size=18), - x=0.5 + x=0.5, ), - paper_bgcolor='white', - plot_bgcolor='white', + paper_bgcolor="white", + plot_bgcolor="white", annotations=[ dict( - text=annotations_text.replace('\n', '
'), - xref="paper", yref="paper", - x=1.02, y=0.98, - xanchor="left", yanchor="top", + text=annotations_text.replace("\n", "
"), + xref="paper", + yref="paper", + x=1.02, + y=0.98, + xanchor="left", + yanchor="top", showarrow=False, font=dict(family="monospace", size=11), align="left", bgcolor="rgba(255,255,255,0.9)", bordercolor="rgba(0,0,0,0.3)", borderwidth=1, - borderpad=8 + borderpad=8, ), dict( text="View:", - xref="paper", yref="paper", - x=0.0, y=1.15, - xanchor="left", yanchor="top", + xref="paper", + yref="paper", + x=0.0, + y=1.15, + xanchor="left", + yanchor="top", showarrow=False, - font=dict(size=12) - ) - ] + font=dict(size=12), + ), + ], ) - html_content = fig.to_html(include_plotlyjs='cdn', full_html=True) + html_content = fig.to_html(include_plotlyjs="cdn", full_html=True) - return (html_content, rms_global_opd, rms_filtered_opd, { - 'max_lateral_um': max_lateral, - 'rms_lateral_um': rms_lateral, - 'method': 'opd' if use_opd else 'standard', - 'rms_std': rms_filtered_std, - 'pct_diff': pct_diff - }) + return ( + html_content, + rms_global_opd, + rms_filtered_opd, + { + "max_lateral_um": max_lateral, + "rms_lateral_um": rms_lateral, + "method": "opd" if use_opd else "standard", + "rms_std": rms_filtered_std, + "pct_diff": pct_diff, + }, + ) # Generate results for each comparison results = {} comparisons = [ - ('3', '2', '40_vs_20', '40 deg vs 20 deg'), - ('4', '2', '60_vs_20', '60 deg vs 20 deg'), - ('1', '2', '90_vs_20', '90 deg vs 20 deg (manufacturing)'), + ("3", "2", "40_vs_20", "40 deg vs 20 deg"), + ("4", "2", "60_vs_20", "60 deg vs 20 deg"), + ("1", "2", "90_vs_20", "90 deg vs 20 deg (manufacturing)"), ] for target_sc, ref_sc, key, title_suffix in comparisons: @@ -4062,13 +4295,13 @@ async def get_trial_zernike(study_id: str, trial_number: int): target_angle = SUBCASE_MAP.get(target_sc, target_sc) ref_angle = SUBCASE_MAP.get(ref_sc, ref_sc) - is_mfg = (key == '90_vs_20') + is_mfg = key == "90_vs_20" html_content, rms_global, rms_filtered, lateral_stats = generate_dual_method_html( title=f"iter{trial_number}: {target_angle}° vs {ref_angle}°", target_sc=target_sc, ref_sc=ref_sc, - is_manufacturing=is_mfg + is_manufacturing=is_mfg, ) results[key] = { @@ -4076,17 +4309,17 @@ async def get_trial_zernike(study_id: str, trial_number: int): "rms_global": rms_global, "rms_filtered": rms_filtered, "title": f"{target_angle}° vs {ref_angle}°", - "method": lateral_stats['method'], - "rms_std": lateral_stats['rms_std'], - "pct_diff": lateral_stats['pct_diff'], - "max_lateral_um": lateral_stats['max_lateral_um'], - "rms_lateral_um": lateral_stats['rms_lateral_um'] + "method": lateral_stats["method"], + "rms_std": lateral_stats["rms_std"], + "pct_diff": lateral_stats["pct_diff"], + "max_lateral_um": lateral_stats["max_lateral_um"], + "rms_lateral_um": lateral_stats["rms_lateral_um"], } if not results: raise HTTPException( status_code=500, - detail="Failed to generate Zernike analysis. Check if subcases are available." + detail="Failed to generate Zernike analysis. Check if subcases are available.", ) return { @@ -4094,15 +4327,18 @@ async def get_trial_zernike(study_id: str, trial_number: int): "trial_number": trial_number, "comparisons": results, "available_comparisons": list(results.keys()), - "method": "opd" if use_opd else "standard" + "method": "opd" if use_opd else "standard", } except HTTPException: raise except Exception as e: import traceback + traceback.print_exc() - raise HTTPException(status_code=500, detail=f"Failed to generate Zernike analysis: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to generate Zernike analysis: {str(e)}" + ) @router.get("/studies/{study_id}/export/{format}") @@ -4134,44 +4370,54 @@ async def export_study_data(study_id: str, format: str): trials_data = [] for row in cursor.fetchall(): - trial_id = row['trial_id'] + trial_id = row["trial_id"] # Get params - cursor.execute(""" + cursor.execute( + """ SELECT param_name, param_value FROM trial_params WHERE trial_id = ? - """, (trial_id,)) - params = {r['param_name']: r['param_value'] for r in cursor.fetchall()} + """, + (trial_id,), + ) + params = {r["param_name"]: r["param_value"] for r in cursor.fetchall()} # Get user attrs - cursor.execute(""" + cursor.execute( + """ SELECT key, value_json FROM trial_user_attributes WHERE trial_id = ? - """, (trial_id,)) + """, + (trial_id,), + ) user_attrs = {} for r in cursor.fetchall(): try: - user_attrs[r['key']] = json.loads(r['value_json']) + user_attrs[r["key"]] = json.loads(r["value_json"]) except: - user_attrs[r['key']] = r['value_json'] + user_attrs[r["key"]] = r["value_json"] - trials_data.append({ - "trial_number": row['number'], - "objective": row['objective'], - "params": params, - "user_attrs": user_attrs - }) + trials_data.append( + { + "trial_number": row["number"], + "objective": row["objective"], + "params": params, + "user_attrs": user_attrs, + } + ) conn.close() if format.lower() == "json": - return JSONResponse(content={ - "study_id": study_id, - "total_trials": len(trials_data), - "trials": trials_data - }) + return JSONResponse( + content={ + "study_id": study_id, + "total_trials": len(trials_data), + "trials": trials_data, + } + ) elif format.lower() == "csv": import io @@ -4184,30 +4430,28 @@ async def export_study_data(study_id: str, format: str): output = io.StringIO() # Get all param names - param_names = sorted(set( - key for trial in trials_data - for key in trial['params'].keys() - )) + param_names = sorted( + set(key for trial in trials_data for key in trial["params"].keys()) + ) - fieldnames = ['trial_number', 'objective'] + param_names + fieldnames = ["trial_number", "objective"] + param_names writer = csv.DictWriter(output, fieldnames=fieldnames) writer.writeheader() for trial in trials_data: - row_data = { - 'trial_number': trial['trial_number'], - 'objective': trial['objective'] - } - row_data.update(trial['params']) + row_data = {"trial_number": trial["trial_number"], "objective": trial["objective"]} + row_data.update(trial["params"]) writer.writerow(row_data) csv_content = output.getvalue() - return JSONResponse(content={ - "filename": f"{study_id}_data.csv", - "content": csv_content, - "content_type": "text/csv" - }) + return JSONResponse( + content={ + "filename": f"{study_id}_data.csv", + "content": csv_content, + "content_type": "text/csv", + } + ) elif format.lower() == "config": # Export optimization config @@ -4215,13 +4459,15 @@ async def export_study_data(study_id: str, format: str): config_path = setup_dir / "optimization_config.json" if config_path.exists(): - with open(config_path, 'r') as f: + with open(config_path, "r") as f: config = json.load(f) - return JSONResponse(content={ - "filename": f"{study_id}_config.json", - "content": json.dumps(config, indent=2), - "content_type": "application/json" - }) + return JSONResponse( + content={ + "filename": f"{study_id}_config.json", + "content": json.dumps(config, indent=2), + "content_type": "application/json", + } + ) else: raise HTTPException(status_code=404, detail="Config file not found") @@ -4238,6 +4484,7 @@ async def export_study_data(study_id: str, format: str): # NX Model Introspection Endpoints # ============================================================================ + @router.get("/studies/{study_id}/nx/introspect") async def introspect_nx_model(study_id: str, force: bool = False): """ @@ -4255,21 +4502,40 @@ async def introspect_nx_model(study_id: str, force: bool = False): """ try: study_dir = resolve_study_path(study_id) + print(f"[introspect] study_id={study_id}, study_dir={study_dir}") if not study_dir.exists(): raise HTTPException(status_code=404, detail=f"Study {study_id} not found") - # Find model directory - model_dir = study_dir / "1_model" - if not model_dir.exists(): - model_dir = study_dir / "0_model" - if not model_dir.exists(): - raise HTTPException(status_code=404, detail=f"Model directory not found for {study_id}") + # Find model directory - check multiple possible locations + model_dir = None + possible_dirs = [ + study_dir / "1_setup" / "model", # Standard Atomizer structure + study_dir / "1_model", + study_dir / "0_model", + study_dir / "model", + study_dir / "1_setup", # Model files directly in 1_setup + ] + for possible_dir in possible_dirs: + if possible_dir.exists(): + prt_files = list(possible_dir.glob("*.prt")) + print( + f"[introspect] checking {possible_dir}: exists=True, prt_count={len(prt_files)}" + ) + if prt_files: + model_dir = possible_dir + break + else: + print(f"[introspect] checking {possible_dir}: exists=False") + + if model_dir is None: + detail = f"[V2] No model dir for {study_id} in {study_dir}" + raise HTTPException(status_code=404, detail=detail) # Find .prt file prt_files = list(model_dir.glob("*.prt")) # Exclude idealized parts - prt_files = [f for f in prt_files if '_i.prt' not in f.name.lower()] + prt_files = [f for f in prt_files if "_i.prt" not in f.name.lower()] if not prt_files: raise HTTPException(status_code=404, detail=f"No .prt files found in {model_dir}") @@ -4280,43 +4546,30 @@ async def introspect_nx_model(study_id: str, force: bool = False): cache_file = model_dir / "_introspection_cache.json" if cache_file.exists() and not force: try: - with open(cache_file, 'r') as f: + with open(cache_file, "r") as f: cached = json.load(f) # Check if cache is for same file - if cached.get('part_file') == str(prt_file): - return { - "study_id": study_id, - "cached": True, - "introspection": cached - } + if cached.get("part_file") == str(prt_file): + return {"study_id": study_id, "cached": True, "introspection": cached} except: pass # Invalid cache, re-run # Run introspection try: from optimization_engine.extractors.introspect_part import introspect_part + result = introspect_part(str(prt_file), str(model_dir), verbose=False) # Cache results - with open(cache_file, 'w') as f: + with open(cache_file, "w") as f: json.dump(result, f, indent=2) - return { - "study_id": study_id, - "cached": False, - "introspection": result - } + return {"study_id": study_id, "cached": False, "introspection": result} except ImportError: - raise HTTPException( - status_code=500, - detail="introspect_part module not available" - ) + raise HTTPException(status_code=500, detail="introspect_part module not available") except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Introspection failed: {str(e)}" - ) + raise HTTPException(status_code=500, detail=f"Introspection failed: {str(e)}") except HTTPException: raise @@ -4343,31 +4596,40 @@ async def get_nx_expressions(study_id: str): if not study_dir.exists(): raise HTTPException(status_code=404, detail=f"Study {study_id} not found") + # Find model directory - check multiple possible locations + model_dir = None + for possible_dir in [ + study_dir / "1_setup" / "model", # Standard Atomizer structure + study_dir / "1_model", + study_dir / "0_model", + study_dir / "model", + study_dir / "1_setup", # Model files directly in 1_setup + ]: + if possible_dir.exists(): + # Check if it has .prt files + if list(possible_dir.glob("*.prt")): + model_dir = possible_dir + break + + if model_dir is None: + raise HTTPException( + status_code=404, detail=f"[expr] No model dir for {study_id} in {study_dir}" + ) + # Check cache first - model_dir = study_dir / "1_model" - if not model_dir.exists(): - model_dir = study_dir / "0_model" + cache_file = model_dir / "_introspection_cache.json" - cache_file = model_dir / "_introspection_cache.json" if model_dir.exists() else None - - if cache_file and cache_file.exists(): - with open(cache_file, 'r') as f: + if cache_file.exists(): + with open(cache_file, "r") as f: cached = json.load(f) - expressions = cached.get('expressions', {}).get('user', []) - return { - "study_id": study_id, - "expressions": expressions, - "count": len(expressions) - } + expressions = cached.get("expressions", {}).get("user", []) + return {"study_id": study_id, "expressions": expressions, "count": len(expressions)} # No cache, need to run introspection - # Find .prt file - if not model_dir or not model_dir.exists(): - raise HTTPException(status_code=404, detail=f"Model directory not found for {study_id}") prt_files = list(model_dir.glob("*.prt")) - prt_files = [f for f in prt_files if '_i.prt' not in f.name.lower()] + prt_files = [f for f in prt_files if "_i.prt" not in f.name.lower()] if not prt_files: raise HTTPException(status_code=404, detail=f"No .prt files found") @@ -4376,18 +4638,15 @@ async def get_nx_expressions(study_id: str): try: from optimization_engine.extractors.introspect_part import introspect_part + result = introspect_part(str(prt_file), str(model_dir), verbose=False) # Cache for future - with open(cache_file, 'w') as f: + with open(cache_file, "w") as f: json.dump(result, f, indent=2) - expressions = result.get('expressions', {}).get('user', []) - return { - "study_id": study_id, - "expressions": expressions, - "count": len(expressions) - } + expressions = result.get("expressions", {}).get("user", []) + return {"study_id": study_id, "expressions": expressions, "count": len(expressions)} except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to get expressions: {str(e)}") diff --git a/atomizer-dashboard/frontend/src/components/canvas/SpecRenderer.tsx b/atomizer-dashboard/frontend/src/components/canvas/SpecRenderer.tsx index 0c3cfa34..5951240d 100644 --- a/atomizer-dashboard/frontend/src/components/canvas/SpecRenderer.tsx +++ b/atomizer-dashboard/frontend/src/components/canvas/SpecRenderer.tsx @@ -10,7 +10,7 @@ * P2.7-P2.10: SpecRenderer component with node/edge/selection handling */ -import { useCallback, useRef, useEffect, useMemo, DragEvent } from 'react'; +import { useCallback, useRef, useEffect, useMemo, useState, DragEvent } from 'react'; import ReactFlow, { Background, Controls, @@ -22,6 +22,7 @@ import ReactFlow, { NodeChange, EdgeChange, Connection, + applyNodeChanges, } from 'reactflow'; import 'reactflow/dist/style.css'; @@ -74,8 +75,28 @@ function getDefaultNodeData(type: AddableNodeType, position: { x: number; y: num case 'extractor': return { name: `extractor_${timestamp}`, - type: 'custom', + type: 'custom_function', // Must be valid ExtractorType + builtin: false, enabled: true, + // Custom function extractors need a function definition + function: { + name: 'extract', + source_code: `def extract(op2_path: str, config: dict = None) -> dict: + """ + Custom extractor function. + + Args: + op2_path: Path to the OP2 results file + config: Optional configuration dict + + Returns: + Dictionary with extracted values + """ + # TODO: Implement extraction logic + return {'value': 0.0} +`, + }, + outputs: [{ name: 'value', metric: 'custom' }], canvas_position: position, }; case 'objective': @@ -90,7 +111,8 @@ function getDefaultNodeData(type: AddableNodeType, position: { x: number; y: num case 'constraint': return { name: `constraint_${timestamp}`, - type: 'upper', + constraint_type: 'hard', // Must be 'hard' or 'soft' + operator: '<=', limit: 1.0, source_extractor_id: null, source_output: null, @@ -208,12 +230,23 @@ function SpecRendererInner({ nodesRef.current = nodes; }, [nodes]); + // Track local node state for smooth dragging + const [localNodes, setLocalNodes] = useState(nodes); + + // Sync local nodes with spec-derived nodes when spec changes + useEffect(() => { + setLocalNodes(nodes); + }, [nodes]); + // Handle node position changes const onNodesChange = useCallback( (changes: NodeChange[]) => { if (!editable) return; - // Handle position changes + // Apply changes to local state for smooth dragging + setLocalNodes((nds) => applyNodeChanges(changes, nds)); + + // Handle position changes - save to spec when drag ends for (const change of changes) { if (change.type === 'position' && change.position && change.dragging === false) { // Dragging ended - update spec @@ -458,7 +491,7 @@ function SpecRendererInner({ )} void; } @@ -56,7 +57,7 @@ interface IntrospectionResult { warnings: string[]; } -export function IntrospectionPanel({ filePath, onClose }: IntrospectionPanelProps) { +export function IntrospectionPanel({ filePath, studyId, onClose }: IntrospectionPanelProps) { const [result, setResult] = useState(null); const [isLoading, setIsLoading] = useState(false); const [error, setError] = useState(null); @@ -73,21 +74,37 @@ export function IntrospectionPanel({ filePath, onClose }: IntrospectionPanelProp setIsLoading(true); setError(null); try { - const res = await fetch('/api/nx/introspect', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ file_path: filePath }), - }); - if (!res.ok) throw new Error('Introspection failed'); + let res; + + // If we have a studyId, use the study-aware introspection endpoint + if (studyId) { + // Don't encode studyId - it may contain slashes for nested paths (e.g., M1_Mirror/study_name) + res = await fetch(`/api/optimization/studies/${studyId}/nx/introspect`); + } else { + // Fallback to direct path introspection + res = await fetch('/api/nx/introspect', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ file_path: filePath }), + }); + } + + if (!res.ok) { + const errData = await res.json().catch(() => ({})); + throw new Error(errData.detail || 'Introspection failed'); + } const data = await res.json(); - setResult(data); + + // Handle different response formats + setResult(data.introspection || data); } catch (e) { - setError('Failed to introspect model'); - console.error(e); + const msg = e instanceof Error ? e.message : 'Failed to introspect model'; + setError(msg); + console.error('Introspection error:', e); } finally { setIsLoading(false); } - }, [filePath]); + }, [filePath, studyId]); useEffect(() => { runIntrospection(); diff --git a/atomizer-dashboard/frontend/src/components/canvas/panels/NodeConfigPanelV2.tsx b/atomizer-dashboard/frontend/src/components/canvas/panels/NodeConfigPanelV2.tsx index bff7291a..943082eb 100644 --- a/atomizer-dashboard/frontend/src/components/canvas/panels/NodeConfigPanelV2.tsx +++ b/atomizer-dashboard/frontend/src/components/canvas/panels/NodeConfigPanelV2.tsx @@ -254,6 +254,7 @@ export function NodeConfigPanelV2({ onClose }: NodeConfigPanelV2Props) {
setShowIntrospection(false)} />
@@ -313,6 +314,7 @@ function ModelNodeConfig({ spec }: SpecConfigProps) {
setShowIntrospection(false)} />
@@ -694,38 +696,21 @@ function ExtractorNodeConfig({ node, onChange }: ExtractorNodeConfigProps) { {showCodeEditor && (
- {/* Modal Header */} -
-
- - Custom Extractor: {node.name} - .py -
- -
- - {/* Code Editor */} -
- o.name) || []} - onChange={handleCodeChange} - onRequestGeneration={handleRequestGeneration} - onRequestStreamingGeneration={handleStreamingGeneration} - onRun={handleValidateCode} - onTest={handleTestCode} - onClose={() => setShowCodeEditor(false)} - showHeader={false} - height="100%" - studyId={studyId || undefined} - /> -
+ {/* Code Editor with built-in header containing toolbar buttons */} + o.name) || []} + onChange={handleCodeChange} + onRequestGeneration={handleRequestGeneration} + onRequestStreamingGeneration={handleStreamingGeneration} + onRun={handleValidateCode} + onTest={handleTestCode} + onClose={() => setShowCodeEditor(false)} + showHeader={true} + height="100%" + studyId={studyId || undefined} + />
)} diff --git a/atomizer-dashboard/frontend/src/pages/CanvasView.tsx b/atomizer-dashboard/frontend/src/pages/CanvasView.tsx index 45670d19..b5784eb0 100644 --- a/atomizer-dashboard/frontend/src/pages/CanvasView.tsx +++ b/atomizer-dashboard/frontend/src/pages/CanvasView.tsx @@ -472,7 +472,8 @@ export function CanvasView() { {/* Config Panel - use V2 for spec mode, legacy for AtomizerCanvas */} - {selectedNodeId && !showChat && ( + {/* Shows INSTEAD of chat when a node is selected */} + {selectedNodeId ? ( useSpecMode ? ( useSpecStore.getState().clearSelection()} /> ) : ( @@ -480,10 +481,7 @@ export function CanvasView() { ) - )} - - {/* Chat/Assistant Panel */} - {showChat && ( + ) : showChat ? (
{/* Chat Header */}
@@ -524,7 +522,7 @@ export function CanvasView() { isConnected={isConnected} />
- )} + ) : null} {/* Template Selector Modal */}