1##########################################################################
2# Copyright (c) 2025 - 2025 Altair Engineering Inc. All Rights Reserved
3# Contains trade secrets of Altair Engineering, Inc. Copyright notice
4# does not imply publication. Decompilation or disassembly of this
5# software is strictly prohibited.
6##########################################################################
7
8from __future__ import annotations
9
10import alt.hst.gui.custom_plot as custom_plot
11import alt.hst.utl.hstgettext as hsttr
12import typing
13import os
14import string
15
16
17# -----------------------------------------------------------------------------
18# -----------------------------------------------------------------------------
class DemoPlotlyScatterSettings(custom_plot.CustomPlotSettings):
    """
    Settings for :class:`DemoPlotlyScatter`.

    Four settings are exposed, each addressed by one of the name constants
    below: the marker size, the text size, the trendline type (one of the
    ``TRENDLINE_CHOICE_*`` values) and the window size used by the rolling
    mean and EWM trendlines.
    """

    # Setting names.
    MARKER_SIZE: typing.Final[str] = 'marker_size'
    TEXT_SIZE: typing.Final[str] = 'text_size'
    TRENDLINE: typing.Final[str] = 'trendline'
    # Allowed values of the TRENDLINE setting.
    TRENDLINE_CHOICE_OLS: typing.Final[str] = 'ols'
    TRENDLINE_CHOICE_KRR: typing.Final[str] = 'krr'
    TRENDLINE_CHOICE_ROLLING: typing.Final[str] = 'rolling'
    TRENDLINE_CHOICE_EWM: typing.Final[str] = 'ewm'
    TRENDLINE_CHOICE_NONE: typing.Final[str] = 'none'
    WINDOW: typing.Final[str] = 'window'

    # Default values, shared by __init__ and getDefault so that the initial
    # state and the reported defaults cannot drift apart.
    _DEFAULT_MARKER_SIZE: typing.Final[int] = 20
    _DEFAULT_TEXT_SIZE: typing.Final[int] = 20
    _DEFAULT_TRENDLINE: typing.Final[str] = TRENDLINE_CHOICE_NONE
    _DEFAULT_WINDOW: typing.Final[int] = 5

    # Marker and text sizes are valid in the half-open range [0, _MAX_SIZE).
    _MAX_SIZE: typing.Final[int] = 100
    # Smallest usable trendline window: pandas ``ewm(span=...)`` rejects
    # spans below 1, and a window of 0 cannot yield a meaningful rolling mean.
    _MIN_WINDOW: typing.Final[int] = 1

    # ---------------------------------------------------------------------------
    def __init__(self) -> None:
        super().__init__()
        self._markerSizeValue = self._DEFAULT_MARKER_SIZE
        self._textSizeValue = self._DEFAULT_TEXT_SIZE
        self._trendline = self._DEFAULT_TRENDLINE
        self._window = self._DEFAULT_WINDOW

    # ---------------------------------------------------------------------------
    def getNames(self) -> typing.List[str]:
        """Return the names of all settings, in display order."""
        return [self.MARKER_SIZE, self.TEXT_SIZE, self.TRENDLINE, self.WINDOW]

    # ---------------------------------------------------------------------------
    def getLabel(self, name: str) -> typing.Optional[str]:
        """Return the (translated) label for setting *name*, or None if unknown."""
        if name == self.MARKER_SIZE:
            return hsttr.gt('Marker Size')
        elif name == self.TEXT_SIZE:
            return hsttr.gt('Text Size')
        elif name == self.TRENDLINE:
            return hsttr.gt('Trendline')
        elif name == self.WINDOW:
            return hsttr.gt('Window Size')
        return None

    # ---------------------------------------------------------------------------
    def getValue(self, name: str) -> typing.Optional[typing.Any]:
        """Return the current value of setting *name*, or None if unknown."""
        if name == self.MARKER_SIZE:
            return self._markerSizeValue
        elif name == self.TEXT_SIZE:
            return self._textSizeValue
        elif name == self.TRENDLINE:
            return self._trendline
        elif name == self.WINDOW:
            return self._window
        return None

    # ---------------------------------------------------------------------------
    def setValue(self, name: str, value: typing.Any) -> None:
        """
        Store *value* for setting *name*; unknown names are ignored.

        Sizes and the window are coerced with ``int()`` (which raises
        ValueError/TypeError for non-numeric input); the trendline choice is
        coerced with ``str()``. Range checks are left to :meth:`getValid`.
        """
        if name == self.MARKER_SIZE:
            self._markerSizeValue = int(value)
        elif name == self.TEXT_SIZE:
            self._textSizeValue = int(value)
        elif name == self.TRENDLINE:
            self._trendline = str(value)
        elif name == self.WINDOW:
            self._window = int(value)

    # ---------------------------------------------------------------------------
    def getDescription(self, name: str) -> typing.Optional[str]:
        """Return the (translated) description of setting *name*, or None."""
        if name == self.MARKER_SIZE:
            return hsttr.gt('The size of individual markers in the plot')
        elif name == self.TEXT_SIZE:
            return hsttr.gt('The size of text in the plot')
        elif name == self.TRENDLINE:
            return hsttr.gt('Show trendline in the plot')
        elif name == self.WINDOW:
            return hsttr.gt('Window size for rolling mean and EWM trendlines')
        return None

    # ---------------------------------------------------------------------------
    def getDefault(self, name: str) -> typing.Optional[typing.Any]:
        """Return the default value of setting *name*, or None if unknown."""
        if name == self.MARKER_SIZE:
            return self._DEFAULT_MARKER_SIZE
        elif name == self.TEXT_SIZE:
            return self._DEFAULT_TEXT_SIZE
        elif name == self.TRENDLINE:
            return self._DEFAULT_TRENDLINE
        elif name == self.WINDOW:
            return self._DEFAULT_WINDOW
        return None

    # -------------------------------------------------------------------------
    def getChoices(self, name: str) -> typing.List[typing.Any]:
        """Return the allowed values for a choice setting; [] for free-form ones."""
        if name == self.TRENDLINE:
            return [self.TRENDLINE_CHOICE_NONE,
                    self.TRENDLINE_CHOICE_OLS,
                    self.TRENDLINE_CHOICE_KRR,
                    self.TRENDLINE_CHOICE_ROLLING,
                    self.TRENDLINE_CHOICE_EWM]
        return []

    # -------------------------------------------------------------------------
    def getChoiceLabel(self, name: str, value: typing.Any) -> str:
        """Return the short display label of choice *value* of setting *name*."""
        if name == self.TRENDLINE:
            if value == self.TRENDLINE_CHOICE_NONE:
                return hsttr.gt('None')
            elif value == self.TRENDLINE_CHOICE_OLS:
                return 'OLS'
            elif value == self.TRENDLINE_CHOICE_KRR:
                return 'KRR'
            elif value == self.TRENDLINE_CHOICE_ROLLING:
                return hsttr.gt('Rolling')
            elif value == self.TRENDLINE_CHOICE_EWM:
                return 'EWMA'
        return ''

    # -------------------------------------------------------------------------
    def getChoiceDescription(self, name: str, value: typing.Any) -> str:
        """Return the long description of choice *value* of setting *name*."""
        if name == self.TRENDLINE:
            if value == self.TRENDLINE_CHOICE_NONE:
                return hsttr.gt('None')
            elif value == self.TRENDLINE_CHOICE_OLS:
                return hsttr.gt('Ordinary Least Squares')
            elif value == self.TRENDLINE_CHOICE_KRR:
                return hsttr.gt('Kernel Ridge Regression')
            elif value == self.TRENDLINE_CHOICE_ROLLING:
                return hsttr.gt('Rolling Mean')
            elif value == self.TRENDLINE_CHOICE_EWM:
                return hsttr.gt('Exponentially Weighted Moving Average')
        return ''

    # ---------------------------------------------------------------------------
    def getValid(self, name: str) -> bool:
        """
        Return whether the current value of setting *name* is acceptable.

        - Marker/text size: within [0, 100).
        - Trendline: one of the values returned by :meth:`getChoices`
          (``setValue`` stores any string, so this must be checked here).
        - Window: at least 1, since pandas EWM rejects smaller spans.
        Unknown names are reported as valid.
        """
        if name == self.MARKER_SIZE:
            return 0 <= self._markerSizeValue < self._MAX_SIZE
        elif name == self.TEXT_SIZE:
            return 0 <= self._textSizeValue < self._MAX_SIZE
        elif name == self.TRENDLINE:
            return self._trendline in self.getChoices(self.TRENDLINE)
        elif name == self.WINDOW:
            return self._window >= self._MIN_WINDOW
        return True

    # ---------------------------------------------------------------------------
    def isSettingEnabled(self, name: str) -> bool:
        """
        Return whether setting *name* is currently editable.

        The window only matters for the rolling-mean and EWM trendlines, so it
        is disabled for every other trendline choice.
        """
        if name == self.WINDOW:
            return self._trendline in [self.TRENDLINE_CHOICE_ROLLING, self.TRENDLINE_CHOICE_EWM]
        return True
156
157
158# -----------------------------------------------------------------------------
159# -----------------------------------------------------------------------------
class DemoPlotlyScatter(custom_plot.CustomPlot):
    """
    A demo plot class that uses Plotly for plotting.

    Each selected response (y) is drawn against one selected variable (x) as
    a marker series, optionally overlaid with a per-response trendline chosen
    in :class:`DemoPlotlyScatterSettings`.
    """

    # Section / table names declared in makeChannelConfig; plot() looks the
    # user's selections up by these same names, so they are defined once here.
    _INPUTS_SECTION: typing.Final[str] = 'Inputs'
    _OUTPUTS_SECTION: typing.Final[str] = 'Outputs'
    _VARIABLES_TABLE: typing.Final[str] = 'Variables'
    _RESPONSES_TABLE: typing.Final[str] = 'Responses'
    _GRADIENTS_TABLE: typing.Final[str] = 'My Gradients Table'

    # Colour used by getPointSelectionScript to highlight selected points.
    _SELECTED_POINT_COLOR: typing.Final[str] = '#ffdb0f'

    # -------------------------------------------------------------------------
    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): the meaning of both flags is defined by
        # custom_plot.CustomPlot.setUpdateControlsEnabled - confirm there.
        self.setUpdateControlsEnabled(True, True)

    # -------------------------------------------------------------------------
    @classmethod
    def getDefaultSettings(cls) -> DemoPlotlyScatterSettings:
        """
        Returns the settings for the plot.
        """
        return DemoPlotlyScatterSettings()

    # -------------------------------------------------------------------------
    @classmethod
    def getIcon(cls) -> str:
        """Return the path of the icon SVG next to this file, themed for dark/light."""
        iconName = 'scatter_2D_dark.svg' if custom_plot.isDarkTheme() else 'scatter_2D.svg'
        return os.path.join(os.path.dirname(__file__), iconName)

    # -------------------------------------------------------------------------
    @classmethod
    def getThumbnail(cls) -> str:
        """Return the thumbnail path; the icon is reused as the thumbnail."""
        return cls.getIcon()

    # -------------------------------------------------------------------------
    def makeChannelConfig(self) -> custom_plot.ChannelConfig:
        """
        Describe the channel-selection UI: an Inputs section with a
        single-select Variables table, and an Outputs section with
        multi-select Responses and Gradients tables.
        """
        config = custom_plot.ChannelConfig()

        # Add an Inputs tab
        inputSection = config.addSection(self._INPUTS_SECTION)

        # Add a Variables table to the Inputs section
        inputSection.addTable(self._VARIABLES_TABLE, [custom_plot.ChannelTypes.VARIABLES],
                              allowMultiSelect=False)

        # Add an Outputs tab
        outputSection = config.addSection(self._OUTPUTS_SECTION)

        # Add a Responses table and Gradients table to the Outputs section
        outputSection.addTable(self._RESPONSES_TABLE, [custom_plot.ChannelTypes.RESPONSES],
                               allowMultiSelect=True)
        outputSection.addTable(self._GRADIENTS_TABLE,
                               [custom_plot.ChannelTypes.GRADIENTS],
                               allowMultiSelect=True)

        return config

    # -------------------------------------------------------------------------
    def getEventSetupScript(self) -> str:
        """
        Return JavaScript that attaches a ``plotly_click`` handler to the graph.

        The handler forwards the point and curve number of the clicked point
        (the last one, if Plotly reports several) to
        ``hst_plot.handlePointClickEvent``.
        """
        # Body of the JavaScript function that handles the event
        script = """
        var pointNumber='', curveNumber='', colors=[];
        for(var i=0; i < data.points.length; i++){
            pointNumber = data.points[i].pointNumber;
            curveNumber = data.points[i].curveNumber;
        };
        hst_plot.handlePointClickEvent(pointNumber, curveNumber);
        """

        # Script that connects the event handler to the Plotly graph
        return """
        var plot = document.querySelector(".plotly-graph-div");
        if (plot === null) {
            console.log("Plotly graph object not found");
        } else {""" + self.createEventHandler('plot', 'plotly_click', script) + "}"

    # -------------------------------------------------------------------------
    def getPointSelectionScript(self, points: typing.List[typing.Union[str, int]],
                                _unused_curves: typing.Optional[typing.List[typing.Union[str, int]]] = None) -> str:
        """
        Return JavaScript that highlights the given point indices.

        Every trace that has markers is reset to its default marker style, the
        entries at the given indices are recoloured in each such trace, and all
        markers are resized to the MARKER_SIZE setting. The script does
        nothing if the graph element is missing or has no traces.

        Args:
            points: Indices of the points to highlight. Each is written
                verbatim into the script as a JavaScript array index.
            _unused_curves: Accepted for interface compatibility; ignored -
                the highlight is applied to every marker trace.

        Returns:
            The JavaScript source to run in the plot page.
        """
        markerSize = self.getSettings().getValue(DemoPlotlyScatterSettings.MARKER_SIZE)
        # Guard against a missing graph div, like getEventSetupScript does;
        # otherwise `plot.data` throws a TypeError in the page.
        script = string.Template("""
        var plot = document.querySelector(".plotly-graph-div");
        if (plot !== null && plot.data !== undefined && plot.data.length > 0) {
            var resetUpdate = { marker: { color: null, size: null } };
            var curves = [];
            for (var i = 0; i < plot.data.length; i++) { curves.push(i); };
            for (var i = 0; i < curves.length; i++) {
                var curve = curves[i];
                var colors = [];
                if (plot._fullData[curve].marker === undefined) {
                    continue;
                }

                Plotly.restyle(plot, resetUpdate, [curve]);
                for (var j = 0; j < plot.data[curve].x.length; j++) {
                    colors.push(plot._fullData[curve].marker.color);
                };
                ${POINT_COLORS}
                var update = { marker: { color: colors, size: ${MARKER_SIZE} } };
                Plotly.restyle(plot, update, [curve]);
            }
        }
        """)
        pointColors = ''.join(
            f'colors[{point}] = "{self._SELECTED_POINT_COLOR}";' for point in points)
        return script.substitute(POINT_COLORS=pointColors,
                                 MARKER_SIZE=str(markerSize))

    # -------------------------------------------------------------------------
    def getTrendline(self, xData: typing.Sequence[typing.Union[str, int, float]],
                     yData: typing.Sequence[typing.Union[str, int, float]]) -> typing.List[float]:
        """
        Compute the trendline selected in the settings for one data series.

        The (x, y) pairs are stable-sorted by x before fitting. The returned
        values are therefore aligned with ``sorted(xData)``, not with the
        original order of *xData*, and must be plotted against sorted x.

        - OLS: linear-regression predictions at each x.
        - KRR: kernel-ridge (RBF) predictions at each x.
        - Rolling / EWM: rolling mean / exponentially weighted mean of y,
          using the WINDOW setting as window size / span.
        - Otherwise (e.g. ``none``): the y values themselves, in sorted order.

        Args:
            xData: Values of the x variable.
            yData: Values of the response, same length as *xData*.

        Returns:
            One trendline value per input point (empty for empty input).
            Rolling means begin with NaN until the window fills.
        """
        settings = self.getSettings()
        trendline = settings.getValue(DemoPlotlyScatterSettings.TRENDLINE)
        trendWindow = typing.cast(int, settings.getValue(DemoPlotlyScatterSettings.WINDOW))
        import pandas as pd

        df = pd.DataFrame({'x': xData, 'y': yData})
        if df.empty:
            # Nothing to fit; the regressors would raise on empty input.
            return []
        # Stable sort so points with equal x keep their input order.
        df = df.sort_values('x', kind='mergesort')

        if trendline == DemoPlotlyScatterSettings.TRENDLINE_CHOICE_OLS:
            from sklearn import linear_model  # type: ignore[import-untyped]
            xColumn = df['x'].to_numpy().reshape(-1, 1)
            model = linear_model.LinearRegression()
            model.fit(xColumn, df['y'].to_numpy())
            return typing.cast(typing.List[float], model.predict(xColumn).tolist())
        if trendline == DemoPlotlyScatterSettings.TRENDLINE_CHOICE_KRR:
            from sklearn.kernel_ridge import KernelRidge  # type: ignore[import-untyped]
            xColumn = df['x'].to_numpy().reshape(-1, 1)
            model = KernelRidge(alpha=0.1, kernel='rbf', gamma=0.1)
            model.fit(xColumn, df['y'].to_numpy())
            # Predict directly at the data's x. Re-interpolating predictions
            # from an evenly spaced grid was only an approximation, and it
            # failed for a single point or when all x are equal.
            return typing.cast(typing.List[float], model.predict(xColumn).tolist())
        if trendline == DemoPlotlyScatterSettings.TRENDLINE_CHOICE_ROLLING:
            return typing.cast(typing.List[float],
                               df['y'].rolling(window=trendWindow).mean().tolist())
        if trendline == DemoPlotlyScatterSettings.TRENDLINE_CHOICE_EWM:
            return typing.cast(typing.List[float],
                               df['y'].ewm(span=trendWindow).mean().tolist())
        return typing.cast(typing.List[float], df['y'].tolist())

    # -------------------------------------------------------------------------
    def plot(self, channelSelection: custom_plot.ChannelSelection) -> typing.Tuple[custom_plot.PlotOutputType, str]:
        """
        Render the scatter plot for the current channel selection.

        Args:
            channelSelection: The user's selections in the tables declared by
                :meth:`makeChannelConfig`.

        Returns:
            ``(HTML_STRING, '')`` if no variable is selected; otherwise
            ``(HTML_FILE, path)`` of the HTML file written by Plotly.

        Raises:
            ValueError: If a selected item is not of the expected channel type.
        """
        import plotly.graph_objects as go  # type: ignore[import-untyped]
        import plotly
        import numpy as np

        # Get the selections from the Variables table of the Inputs section
        variables = channelSelection.getItems(self._INPUTS_SECTION, self._VARIABLES_TABLE)
        if not variables:
            return custom_plot.PlotOutputType.HTML_STRING, ''

        # Make sure the item is actually a Variable type
        variable = variables[0]
        if variable.getType() != custom_plot.ChannelTypes.VARIABLES:
            raise ValueError(f'Invalid variable type: {variable.getType()}')

        # Get the variable label and its data (shared by every response)
        variableLabel = variable.getLabel()
        dataset = self.getDataSet()
        xData = dataset.getStoredEvaluationValues(variable.getVarname())
        xArray = np.array(xData)

        settings = self.getSettings()
        showTrendline = (settings.getValue(DemoPlotlyScatterSettings.TRENDLINE)
                         != DemoPlotlyScatterSettings.TRENDLINE_CHOICE_NONE)
        # getTrendline returns values aligned with x sorted ascending, so the
        # trendline must be drawn against the sorted x values.
        xSorted = np.sort(xArray) if showTrendline else None

        fig = go.Figure()

        # Add the scatter plot for each response in the Responses table of the Outputs section
        responses = channelSelection.getItems(self._OUTPUTS_SECTION, self._RESPONSES_TABLE)
        for response in responses:
            self.checkIfStopped()
            if response.getType() != custom_plot.ChannelTypes.RESPONSES:
                raise ValueError(f'Invalid response type: {response.getType()}')
            yData = dataset.getStoredEvaluationValues(response.getVarname())
            fig.add_scatter(x=xArray, y=np.array(yData), mode='markers',
                            name=response.getLabel(), showlegend=True)

            # Add a trendline if the setting is enabled
            if showTrendline:
                trendValues = self.getTrendline(xData, yData)
                fig.add_scatter(
                    x=xSorted,
                    y=np.array(trendValues),
                    mode='lines',
                    name=f"{response.getLabel()} Trendline",
                    line=dict(width=4),
                    showlegend=True
                )

        # Update the plot layout with the custom plot settings
        markerSize = settings.getValue(DemoPlotlyScatterSettings.MARKER_SIZE)
        textSize = settings.getValue(DemoPlotlyScatterSettings.TEXT_SIZE)
        fig.update_traces(marker_size=markerSize)
        fig.update_layout(
            font=dict(size=textSize),
            title_text='Evaluation Data Scatter',
            xaxis=dict(title=dict(text=variableLabel)),
            legend=dict(title=dict(text='Responses')),
            modebar_remove=['toImage'],
        )

        # Check if the plot has been stopped by the user before going any
        # further. This raises an exception if it has.
        self.checkIfStopped()

        # If you are returning a file path, make sure the file is saved to
        # a writable location. Keep in mind that the file will contain
        # data from your HyperStudy session and should be cleaned up
        # after use.
        return custom_plot.PlotOutputType.HTML_FILE, str(plotly.offline.plot(
            fig, include_plotlyjs=True, output_type='file',
            filename=self.getPlotOutputFile(), auto_open=False))

        # Alternative: return an HTML string instead of a file. plotly.js
        # cannot be inlined in the HTML due to limitations of the renderer,
        # so you must use the CDN version or have your working directory be a
        # writable location. Note: the CDN version makes a network request to
        # load the plotly.js library.
        # return custom_plot.PlotOutputType.HTML_STRING, str(plotly.offline.plot(
        #     fig, include_plotlyjs='cdn', output_type='div'))
386
387
388# -----------------------------------------------------------------------------
def getPlotClass() -> typing.Type[custom_plot.CustomPlot]:
    """
    Return the custom plot class implemented by this module.

    The class object itself is returned, not an instance.
    """
    pluginClass = DemoPlotlyScatter
    return pluginClass