301 lines
8.9 KiB
TypeScript
301 lines
8.9 KiB
TypeScript
|
|
/**
 * PlotlyParetoPlot - Interactive Pareto front visualization using Plotly
 *
 * Features:
 * - 2D scatter with Pareto front highlighted
 * - 3D scatter when three or more objectives are available (axes selectable)
 * - Hover tooltips with trial details
 * - Click to select trials (TODO: no click handler is wired up yet)
 * - FEA vs NN differentiation
 * - Zoom, pan, and export
 */
|
||
|
|
|
||
|
|
import { useMemo, useState } from 'react';
|
||
|
|
import Plot from 'react-plotly.js';
|
||
|
|
|
||
|
|
/** One optimization trial as handed to the plot. */
type Trial = {
  /** Identifier; Pareto membership is decided by matching this number. */
  trial_number: number;
  /** Objective values, read by index in step with the `objectives` prop. */
  values: number[];
  /** Trial parameters (not read by the plot itself). */
  params: Record<string, number>;
  /**
   * Free-form attributes. The plot may read `source` from here, and objective
   * values keyed by objective name when `values` lacks an entry.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  user_attrs?: Record<string, any>;
  /** Where the trial came from; when absent the plot falls back to `user_attrs.source`, then 'FEA'. */
  source?: 'FEA' | 'NN' | 'V10_FEA';
};
|
||
|
|
|
||
|
|
/** Display metadata for a single optimization objective. */
type Objective = {
  /** Axis / hover label; also the fallback key looked up in `trial.user_attrs`. */
  name: string;
  /** Optimization sense (currently not used for rendering). */
  direction?: 'minimize' | 'maximize';
  /** Unit suffix appended to values in hover text. */
  unit?: string;
};
|
||
|
|
|
||
|
|
/** Props accepted by the PlotlyParetoPlot component. */
type PlotlyParetoPlotProps = {
  /** Every trial to draw. */
  trials: Trial[];
  /** Pareto-optimal trials; matched against `trials` by `trial_number`. */
  paretoFront: Trial[];
  /** Objective metadata; index n describes `trial.values[n]`. */
  objectives: Objective[];
  /** Plot height in pixels (the component defaults it to 500). */
  height?: number;
};
|
||
|
|
|
||
|
|
/** Marker styling for one scatter trace (passed through to Plotly's `marker`). */
interface TraceStyle {
  color: string;
  symbol: string;
  size: number;
  opacity: number;
}

/**
 * Resolve a trial's origin label.
 *
 * An explicit `trial.source` wins, then `user_attrs.source`; any falsy value
 * falls back to 'FEA'.
 */
function getTrialSource(trial: Trial): string {
  return trial.source || trial.user_attrs?.source || 'FEA';
}

/**
 * Read objective `idx` for a trial.
 *
 * Prefers `trial.values[idx]`, then `trial.user_attrs[objectives[idx].name]`.
 * Only finite numbers are accepted: a `null`/NaN entry (e.g. from JSON) or a
 * non-numeric attribute would otherwise crash `toFixed` in the hover text.
 *
 * @returns The value, or `null` when none is available. Plotly skips `null`
 *   points instead of drawing them at a fabricated 0.
 */
function getObjectiveValue(trial: Trial, idx: number, objectives: Objective[]): number | null {
  const direct: unknown = trial.values?.[idx];
  if (typeof direct === 'number' && Number.isFinite(direct)) {
    return direct;
  }
  const name = objectives[idx]?.name;
  if (name === undefined) {
    return null;
  }
  const fromAttrs: unknown = trial.user_attrs?.[name];
  return typeof fromAttrs === 'number' && Number.isFinite(fromAttrs) ? fromAttrs : null;
}

/** Build the `<br>`-separated hover text: trial number, every objective value, and source. */
function buildHoverText(trial: Trial, objectives: Objective[]): string {
  const lines = [`Trial #${trial.trial_number}`];
  objectives.forEach((obj, idx) => {
    const val = getObjectiveValue(trial, idx, objectives);
    const shown = val === null ? 'n/a' : `${val.toFixed(4)}${obj.unit ? ` ${obj.unit}` : ''}`;
    lines.push(`${obj.name}: ${shown}`);
  });
  lines.push(`Source: ${getTrialSource(trial)}`);
  return lines.join('<br>');
}

/**
 * Interactive Pareto-front scatter plot.
 *
 * Renders in 2D by default, or in 3D when at least three objectives exist and
 * the 3D view is selected. Trials are split into three traces:
 * - Pareto-front members, matched by `trial_number` against `paretoFront`;
 * - NN-sourced trials;
 * - everything else, labelled FEA.
 *
 * @param trials - All trials to plot; an empty list renders a placeholder message.
 * @param paretoFront - Pareto-optimal trials; only their `trial_number`s are used.
 * @param objectives - Objective metadata; index n labels `trial.values[n]`.
 * @param height - Plot height in pixels (default 500).
 */
export function PlotlyParetoPlot({
  trials,
  paretoFront,
  objectives,
  height = 500
}: PlotlyParetoPlotProps) {
  const [viewMode, setViewMode] = useState<'2d' | '3d'>(objectives.length >= 3 ? '3d' : '2d');
  const [selectedObjectives, setSelectedObjectives] = useState<[number, number, number]>([0, 1, 2]);

  // 3D needs a third objective; re-checked each render in case `objectives` shrinks.
  const is3D = viewMode === '3d' && objectives.length >= 3;

  // Clamp stored axis choices into the current objective range. This keeps a
  // shrinking `objectives` prop, or a study with fewer objectives than the
  // [0, 1, 2] default, from plotting a non-existent objective. It also keeps the
  // <select>s in sync with what is actually drawn.
  const maxIdx = Math.max(objectives.length - 1, 0);
  const clampIdx = (idx: number): number => Math.min(Math.max(idx, 0), maxIdx);
  const xIdx = clampIdx(selectedObjectives[0]);
  const yIdx = clampIdx(selectedObjectives[1]);
  const zIdx = clampIdx(selectedObjectives[2]);

  const paretoSet = useMemo(() => new Set(paretoFront.map(t => t.trial_number)), [paretoFront]);

  // Pareto membership takes precedence over source, so each trial lands in exactly one trace.
  const { feaTrials, nnTrials, paretoTrials } = useMemo(() => {
    const fea: Trial[] = [];
    const nn: Trial[] = [];
    const pareto: Trial[] = [];
    for (const t of trials) {
      if (paretoSet.has(t.trial_number)) {
        pareto.push(t);
      } else if (getTrialSource(t) === 'NN') {
        nn.push(t);
      } else {
        fea.push(t);
      }
    }
    return { feaTrials: fea, nnTrials: nn, paretoTrials: pareto };
  }, [trials, paretoSet]);

  const traces = useMemo(() => {
    const createTrace = (trialList: Trial[], name: string, style: TraceStyle) => ({
      type: is3D ? 'scatter3d' : 'scatter',
      mode: 'markers',
      name,
      x: trialList.map(t => getObjectiveValue(t, xIdx, objectives)),
      y: trialList.map(t => getObjectiveValue(t, yIdx, objectives)),
      ...(is3D ? { z: trialList.map(t => getObjectiveValue(t, zIdx, objectives)) } : {}),
      text: trialList.map(t => buildHoverText(t, objectives)),
      hoverinfo: 'text',
      marker: { ...style, line: { color: '#fff', width: 1 } }
    });

    return [
      // FEA trials (background, less prominent)
      createTrace(feaTrials, `FEA (${feaTrials.length})`, { color: '#93C5FD', symbol: 'circle', size: 8, opacity: 0.6 }),
      // NN trials (background, less prominent)
      createTrace(nnTrials, `NN (${nnTrials.length})`, { color: '#FDBA74', symbol: 'cross', size: 8, opacity: 0.5 }),
      // Pareto front (highlighted)
      createTrace(paretoTrials, `Pareto (${paretoTrials.length})`, { color: '#10B981', symbol: 'diamond', size: 12, opacity: 1.0 })
    ].filter(trace => trace.x.length > 0); // empty groups get no trace and no legend entry
  }, [feaTrials, nnTrials, paretoTrials, objectives, is3D, xIdx, yIdx, zIdx]);

  if (!trials.length) {
    return (
      <div className="flex items-center justify-center h-64 text-gray-500">
        No trial data available
      </div>
    );
  }

  const axisStyle = { gridcolor: '#E5E7EB', zerolinecolor: '#D1D5DB' };
  const legendStyle = {
    x: 1,
    y: 1,
    bgcolor: 'rgba(255,255,255,0.8)',
    bordercolor: '#E5E7EB',
    borderwidth: 1
  };
  const baseLayout = {
    height,
    paper_bgcolor: 'rgba(0,0,0,0)',
    plot_bgcolor: 'rgba(0,0,0,0)',
    font: { family: 'Inter, system-ui, sans-serif' }
  };
  const xTitle = objectives[xIdx]?.name || 'Objective 1';
  const yTitle = objectives[yIdx]?.name || 'Objective 2';
  const zTitle = objectives[zIdx]?.name || 'Objective 3';

  // Loosely typed because Plotly's type definitions are not imported in this file.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const layout: any = is3D
    ? {
        ...baseLayout,
        margin: { l: 50, r: 50, t: 30, b: 50 },
        scene: {
          xaxis: { title: xTitle, ...axisStyle },
          yaxis: { title: yTitle, ...axisStyle },
          zaxis: { title: zTitle, ...axisStyle },
          bgcolor: 'rgba(0,0,0,0)'
        },
        legend: legendStyle
      }
    : {
        ...baseLayout,
        margin: { l: 60, r: 30, t: 30, b: 60 },
        xaxis: { title: xTitle, ...axisStyle },
        yaxis: { title: yTitle, ...axisStyle },
        legend: { ...legendStyle, xanchor: 'right' },
        hovermode: 'closest'
      };

  const selectClass = 'px-2 py-1 border border-gray-300 rounded text-sm';
  const renderOptions = () =>
    objectives.map((obj, idx) => (
      <option key={idx} value={idx}>{obj.name}</option>
    ));
  const toggleClass = (active: boolean) =>
    `px-3 py-1 text-sm ${active ? 'bg-blue-500 text-white' : 'bg-gray-100 text-gray-700 hover:bg-gray-200'}`;

  return (
    <div className="w-full">
      {/* Controls */}
      <div className="flex gap-4 items-center justify-between mb-3">
        <div className="flex gap-2 items-center">
          {objectives.length >= 3 && (
            <div className="flex rounded-lg overflow-hidden border border-gray-300">
              {/* type="button" so the toggle never submits an enclosing form */}
              <button type="button" onClick={() => setViewMode('2d')} className={toggleClass(viewMode === '2d')}>
                2D
              </button>
              <button type="button" onClick={() => setViewMode('3d')} className={toggleClass(viewMode === '3d')}>
                3D
              </button>
            </div>
          )}
        </div>

        {/* Objective selectors */}
        <div className="flex gap-2 items-center text-sm">
          <label className="text-gray-600">X:</label>
          <select
            value={xIdx}
            onChange={(e) => setSelectedObjectives([Number(e.target.value), yIdx, zIdx])}
            className={selectClass}
          >
            {renderOptions()}
          </select>

          <label className="text-gray-600 ml-2">Y:</label>
          <select
            value={yIdx}
            onChange={(e) => setSelectedObjectives([xIdx, Number(e.target.value), zIdx])}
            className={selectClass}
          >
            {renderOptions()}
          </select>

          {is3D && (
            <>
              <label className="text-gray-600 ml-2">Z:</label>
              <select
                value={zIdx}
                onChange={(e) => setSelectedObjectives([xIdx, yIdx, Number(e.target.value)])}
                className={selectClass}
              >
                {renderOptions()}
              </select>
            </>
          )}
        </div>
      </div>

      <Plot
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        data={traces as any}
        layout={layout}
        config={{
          displayModeBar: true,
          displaylogo: false,
          modeBarButtonsToRemove: ['lasso2d'],
          toImageButtonOptions: {
            format: 'png',
            filename: 'pareto_front',
            height: 800,
            width: 1200,
            scale: 2
          }
        }}
        style={{ width: '100%' }}
      />
    </div>
  );
}
|