Example 05 - Create a 2D Scatter plot for post-processing

  1##########################################################################
  2# Copyright (c) 2025 - 2025 Altair Engineering Inc.  All Rights Reserved
  3# Contains trade secrets of Altair Engineering, Inc.  Copyright notice
  4# does not imply publication.  Decompilation or disassembly of this
  5# software is strictly prohibited.
  6##########################################################################
  7
  8from __future__ import annotations
  9
 10import alt.hst.gui.custom_plot as custom_plot
 11import alt.hst.utl.hstgettext as hsttr
 12import typing
 13import os
 14import string
 15
 16
 17# -----------------------------------------------------------------------------
 18# -----------------------------------------------------------------------------
 19class DemoPlotlyScatterSettings(custom_plot.CustomPlotSettings):
 20    MARKER_SIZE: typing.Final[str] = 'marker_size'
 21    TEXT_SIZE: typing.Final[str] = 'text_size'
 22    TRENDLINE: typing.Final[str] = 'trendline'
 23    TRENDLINE_CHOICE_OLS: typing.Final[str] = 'ols'
 24    TRENDLINE_CHOICE_KRR: typing.Final[str] = 'krr'
 25    TRENDLINE_CHOICE_ROLLING: typing.Final[str] = 'rolling'
 26    TRENDLINE_CHOICE_EWM: typing.Final[str] = 'ewm'
 27    TRENDLINE_CHOICE_NONE: typing.Final[str] = 'none'
 28    WINDOW: typing.Final[str] = 'window'
 29
 30    # ---------------------------------------------------------------------------
 31    def __init__(self) -> None:
 32        super().__init__()
 33        self._markerSizeValue = 20
 34        self._textSizeValue = 20
 35        self._trendline = self.TRENDLINE_CHOICE_NONE
 36        self._window = 5
 37
 38    # ---------------------------------------------------------------------------
 39    def getNames(self) -> typing.List[str]:
 40        return [self.MARKER_SIZE, self.TEXT_SIZE, self.TRENDLINE, self.WINDOW]
 41    
 42    # ---------------------------------------------------------------------------
 43    def getLabel(self, name: str) -> typing.Optional[str]:
 44        if name == self.MARKER_SIZE:
 45            return hsttr.gt('Marker Size')
 46        elif name == self.TEXT_SIZE:
 47            return hsttr.gt('Text Size')
 48        elif name == self.TRENDLINE:
 49            return hsttr.gt('Trendline')
 50        elif name == self.WINDOW:
 51            return hsttr.gt('Window Size')
 52        return None
 53    
 54    # ---------------------------------------------------------------------------
 55    def getValue(self, name: str) -> typing.Optional[typing.Any]:
 56        if name == self.MARKER_SIZE:
 57            return self._markerSizeValue
 58        elif name == self.TEXT_SIZE:
 59            return self._textSizeValue
 60        elif name == self.TRENDLINE:
 61            return self._trendline
 62        elif name == self.WINDOW:
 63            return self._window
 64        return None
 65    
 66    # ---------------------------------------------------------------------------
 67    def setValue(self, name: str, value: typing.Any) -> None:
 68        if name == self.MARKER_SIZE:
 69            self._markerSizeValue = int(value)
 70        elif name == self.TEXT_SIZE:
 71            self._textSizeValue = int(value)
 72        elif name == self.TRENDLINE:
 73            self._trendline = str(value)
 74        elif name == self.WINDOW:
 75            self._window = int(value)
 76    
 77    # ---------------------------------------------------------------------------
 78    def getDescription(self, name: str) -> typing.Optional[str]:
 79        if name == self.MARKER_SIZE:
 80            return hsttr.gt('The size of individual markers in the plot')
 81        elif name == self.TEXT_SIZE:
 82            return hsttr.gt('The size of text in the plot')
 83        elif name == self.TRENDLINE:
 84            return hsttr.gt('Show trendline in the plot')
 85        elif name == self.WINDOW:
 86            return hsttr.gt('Window size for rolling mean and EWM trendlines')
 87        return None
 88    
 89    # ---------------------------------------------------------------------------
 90    def getDefault(self, name: str) -> typing.Optional[typing.Any]:
 91        if name == self.MARKER_SIZE:
 92            return 20
 93        elif name == self.TEXT_SIZE:
 94            return 20
 95        elif name == self.TRENDLINE:
 96            return self.TRENDLINE_CHOICE_NONE
 97        elif name == self.WINDOW:
 98            return 5
 99        return None
100    
101    # -------------------------------------------------------------------------
102    def getChoices(self, name: str) -> typing.List[typing.Any]:
103        if name == self.TRENDLINE:
104            return [self.TRENDLINE_CHOICE_NONE,
105                    self.TRENDLINE_CHOICE_OLS, 
106                    self.TRENDLINE_CHOICE_KRR,
107                    self.TRENDLINE_CHOICE_ROLLING, 
108                    self.TRENDLINE_CHOICE_EWM]
109        return []
110    
111    # -------------------------------------------------------------------------
112    def getChoiceLabel(self, name: str, value: typing.Any) -> str:
113        if name == self.TRENDLINE:
114            if value == self.TRENDLINE_CHOICE_NONE:
115                return hsttr.gt('None')
116            elif value == self.TRENDLINE_CHOICE_OLS:
117                return 'OLS'
118            elif value == self.TRENDLINE_CHOICE_KRR:
119                return 'KRR'
120            elif value == self.TRENDLINE_CHOICE_ROLLING:
121                return hsttr.gt('Rolling')
122            elif value == self.TRENDLINE_CHOICE_EWM:
123                return 'EWMA'
124        return ''
125    
126    # -------------------------------------------------------------------------
127    def getChoiceDescription(self, name: str, value: typing.Any) -> str:
128        if name == self.TRENDLINE:
129            if value == self.TRENDLINE_CHOICE_NONE:
130                return hsttr.gt('None')
131            elif value == self.TRENDLINE_CHOICE_OLS:
132                return hsttr.gt('Ordinary Least Squares')
133            elif value == self.TRENDLINE_CHOICE_KRR:
134                return hsttr.gt('Kernel Ridge Regression')
135            elif value == self.TRENDLINE_CHOICE_ROLLING:
136                return hsttr.gt('Rolling Mean')
137            elif value == self.TRENDLINE_CHOICE_EWM:
138                return hsttr.gt('Exponentially Weighted Moving Average')
139        return ''
140    
141    # ---------------------------------------------------------------------------
142    def getValid(self, name: str) -> bool:
143        if name == self.MARKER_SIZE:
144            return self._markerSizeValue >= 0 and self._markerSizeValue < 100
145        elif name == self.TEXT_SIZE:
146            return self._textSizeValue >= 0 and self._textSizeValue < 100
147        elif name == self.WINDOW:
148            return self._window >= 0
149        return True
150    
151    # ---------------------------------------------------------------------------
152    def isSettingEnabled(self, name: str) -> bool:
153        if name == self.WINDOW:
154            return self._trendline in [self.TRENDLINE_CHOICE_ROLLING, self.TRENDLINE_CHOICE_EWM]
155        return True
156
157
158# -----------------------------------------------------------------------------
159# -----------------------------------------------------------------------------
class DemoPlotlyScatter(custom_plot.CustomPlot):
    """
    A demo plot class that uses Plotly for plotting.

    Plots the stored evaluation values of one selected input variable (x axis)
    against each selected response, optionally adding a trendline per response
    as configured by DemoPlotlyScatterSettings. Point clicks are forwarded to
    ``hst_plot.handlePointClickEvent`` and selected points can be highlighted
    via the JavaScript from getPointSelectionScript().
    """

    # -------------------------------------------------------------------------
    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): the meaning of both flags is defined by
        # custom_plot.CustomPlot.setUpdateControlsEnabled — confirm there.
        self.setUpdateControlsEnabled(True, True)

    # -------------------------------------------------------------------------
    @classmethod
    def getDefaultSettings(cls) -> DemoPlotlyScatterSettings:
        """
        Returns the settings for the plot.
        """
        return DemoPlotlyScatterSettings()

    # -------------------------------------------------------------------------
    @classmethod
    def getIcon(cls) -> str:
        """Return the path of the icon next to this file, themed for dark/light UI."""
        if custom_plot.isDarkTheme():
            return os.path.join(os.path.dirname(__file__), 'scatter_2D_dark.svg')
        return os.path.join(os.path.dirname(__file__), 'scatter_2D.svg')

    # -------------------------------------------------------------------------
    @classmethod
    def getThumbnail(cls) -> str:
        """Use the icon as the thumbnail."""
        return cls.getIcon()

    # -------------------------------------------------------------------------
    def makeChannelConfig(self) -> custom_plot.ChannelConfig:
        """
        Describe which channels the user can select for this plot.

        'Inputs' holds a single-select 'Variables' table; 'Outputs' holds
        multi-select 'Responses' and 'My Gradients Table' tables. Note that
        plot() currently reads only 'Variables' and 'Responses'.
        """
        config = custom_plot.ChannelConfig()

        # Add an Inputs tab
        inputSection = config.addSection('Inputs')

        # Add a Variables table to the Inputs section
        inputSection.addTable('Variables', [custom_plot.ChannelTypes.VARIABLES],
                              allowMultiSelect=False)

        # Add an Outputs tab
        outputSection = config.addSection('Outputs')

        # Add a Responses table and Gradients table to the Outputs section
        outputSection.addTable('Responses', [custom_plot.ChannelTypes.RESPONSES],
                               allowMultiSelect=True)
        outputSection.addTable('My Gradients Table',
                               [custom_plot.ChannelTypes.GRADIENTS],
                               allowMultiSelect=True)

        return config

    # -------------------------------------------------------------------------
    def getEventSetupScript(self) -> str:
        """
        Return JavaScript that attaches a ``plotly_click`` handler to the first
        ``.plotly-graph-div`` in the page.

        The handler passes the point and curve number of the last clicked point
        to ``hst_plot.handlePointClickEvent``. If no Plotly graph is found, a
        message is logged to the browser console instead.
        """
        # Setup the body of the JavaScript function that will handle the event
        script = """
            var pointNumber='', curveNumber='', colors=[];
            for(var i=0; i < data.points.length; i++){
                pointNumber = data.points[i].pointNumber;
                curveNumber = data.points[i].curveNumber;
            };
            hst_plot.handlePointClickEvent(pointNumber, curveNumber);
        """

        # Create the script that will connect the event handler to the Plotly graph
        return """
            var plot = document.querySelector(".plotly-graph-div");
            if (plot === null) {
                console.log("Plotly graph object not found");
            } else {""" + self.createEventHandler('plot', 'plotly_click', script) + "}"

    # -------------------------------------------------------------------------
    def getPointSelectionScript(self, points: typing.List[typing.Union[str, int]],
                                _unused_curves: typing.Optional[typing.List[typing.Union[str, int]]] = None) -> str:
        """
        Return JavaScript that highlights *points* in every trace with markers.

        Each such trace first has its marker color/size reset, then the points
        at the given indices are colored ``#ffdb0f`` and all markers get the
        configured marker size. The script does nothing if no Plotly graph
        exists.

        :param points: Point indices, inserted verbatim into the script — they
            are expected to be integer-like (e.g. from the click handler).
        :param _unused_curves: Ignored; the selection applies to all traces.
        """
        markerSize = self.getSettings().getValue(DemoPlotlyScatterSettings.MARKER_SIZE)
        script = string.Template("""
            var plot = document.querySelector(".plotly-graph-div");
            if (plot !== null && plot.data.length > 0) {
                var resetUpdate = { marker: { color: null, size: null } };
                var curves = [];
                for (var i = 0; i < plot.data.length; i++) { curves.push(i); };
                for (var i = 0; i < curves.length; i++) {
                    var curve = curves[i];
                    var colors = [];
                    if (plot._fullData[curve].marker === undefined) {
                        continue;
                    }

                    Plotly.restyle(plot, resetUpdate, [curve]);
                    for (var j = 0; j < plot.data[curve].x.length; j++) {
                        colors.push(plot._fullData[curve].marker.color);
                    };
                    ${POINT_COLORS}
                    var update = { marker: { color: colors, size: ${MARKER_SIZE} } };
                    Plotly.restyle(plot, update, [curve]);
                }
            }
        """)
        pointColors = ''.join(f'colors[{point}] = "#ffdb0f";' for point in points)
        return script.substitute(POINT_COLORS=pointColors,
                                 MARKER_SIZE=str(markerSize))

    # -------------------------------------------------------------------------
    def getTrendline(self, xData: typing.Sequence[typing.Union[str, int, float]],
                     yData: typing.Sequence[typing.Union[str, int, float]]) -> typing.List[float]:
        """
        Compute the trendline selected in the plot settings.

        The (x, y) pairs are sorted by x before the trend is computed, so the
        returned values correspond to the x values in ascending order, not to
        the input order. Plot them against the sorted x values.

        :param xData: Input variable values.
        :param yData: Response values, one per entry of *xData*.
        :return: One trend value per data point, ordered by ascending x. The
            rolling mean is NaN for the first ``window - 1`` points. An unknown
            or 'none' choice returns the y values ordered by x. Empty input
            gives an empty list.
        """
        import pandas as pd

        settings = self.getSettings()
        trendline = settings.getValue(DemoPlotlyScatterSettings.TRENDLINE)
        trendWindow = typing.cast(int, settings.getValue(DemoPlotlyScatterSettings.WINDOW))

        df = pd.DataFrame({'x': xData, 'y': yData}).sort_values('x')
        if df.empty:
            # Nothing to fit; the regressions below would raise on zero samples.
            return []

        if trendline in (DemoPlotlyScatterSettings.TRENDLINE_CHOICE_OLS,
                         DemoPlotlyScatterSettings.TRENDLINE_CHOICE_KRR):
            import numpy as np

            if trendline == DemoPlotlyScatterSettings.TRENDLINE_CHOICE_OLS:
                from sklearn import linear_model  # type: ignore[import-untyped]
                model = linear_model.LinearRegression()
            else:
                from sklearn.kernel_ridge import KernelRidge  # type: ignore[import-untyped]
                model = KernelRidge(alpha=0.1, kernel='rbf', gamma=0.1)

            # scikit-learn expects a 2D feature matrix: one column, one row per point.
            xColumn = np.asarray(df['x']).reshape(-1, 1)
            model.fit(xColumn, np.asarray(df['y']))
            # Evaluate at the (sorted) data x values so the result lines up with them.
            return typing.cast(typing.List[float], model.predict(xColumn).tolist())
        if trendline == DemoPlotlyScatterSettings.TRENDLINE_CHOICE_ROLLING:
            return typing.cast(typing.List[float],
                               df['y'].rolling(window=trendWindow).mean().tolist())
        if trendline == DemoPlotlyScatterSettings.TRENDLINE_CHOICE_EWM:
            return typing.cast(typing.List[float],
                               df['y'].ewm(span=trendWindow).mean().tolist())
        # No (or unknown) trendline: the raw y values, in the same x order.
        return typing.cast(typing.List[float], df['y'].tolist())

    # -------------------------------------------------------------------------
    def plot(self, channelSelection: custom_plot.ChannelSelection) -> typing.Tuple[custom_plot.PlotOutputType, str]:
        """
        Build the scatter plot for *channelSelection* and write it to HTML.

        The first selected variable provides the x values; each selected
        response becomes a marker trace, followed by a trendline trace when a
        trendline is configured.

        :return: ``(HTML_STRING, '')`` when no variable is selected, otherwise
            ``(HTML_FILE, path)`` of the written HTML file.
        :raises ValueError: if a selected item has an unexpected channel type.
        """
        import plotly.graph_objects as go  # type: ignore[import-untyped]
        import plotly
        import numpy as np

        settings = self.getSettings()

        # Get the selections from the Variables table of the Inputs section
        variables = channelSelection.getItems('Inputs', 'Variables')
        if not variables:
            return custom_plot.PlotOutputType.HTML_STRING, ''

        # Make sure the item is actually a Variable type
        variable = variables[0]
        if variable.getType() != custom_plot.ChannelTypes.VARIABLES:
            raise ValueError(f'Invalid variable type: {variable.getType()}')

        # Get the variable label and its stored evaluation values (the x axis)
        variableLabel = variable.getLabel()
        dataset = self.getDataSet()
        xData = dataset.getStoredEvaluationValues(variable.getVarname())
        xArray = np.array(xData)

        showTrendline = (settings.getValue(DemoPlotlyScatterSettings.TRENDLINE)
                         != DemoPlotlyScatterSettings.TRENDLINE_CHOICE_NONE)
        # getTrendline() returns values ordered by ascending x, so trendlines are
        # drawn against the sorted x values (np.sort, like pandas' sort_values,
        # puts NaN last).
        xSorted = np.sort(xArray) if showTrendline else xArray

        fig = go.Figure()

        # Add the scatter plot for each response in the Responses table of the Outputs section
        for response in channelSelection.getItems('Outputs', 'Responses'):
            self.checkIfStopped()
            if response.getType() != custom_plot.ChannelTypes.RESPONSES:
                raise ValueError(f'Invalid response type: {response.getType()}')
            yData = dataset.getStoredEvaluationValues(response.getVarname())
            fig.add_scatter(x=xArray, y=np.array(yData), mode='markers',
                            name=response.getLabel(), showlegend=True)

            # Add a trendline if the setting is enabled (and there is data to fit)
            if showTrendline:
                trendline = self.getTrendline(xData, yData)
                if trendline:
                    fig.add_scatter(
                        x=xSorted,
                        y=np.array(trendline),
                        mode='lines',
                        name=f"{response.getLabel()} Trendline",
                        line=dict(width=4),
                        showlegend=True,
                    )

        # Update the plot layout with the custom plot settings
        markerSize = settings.getValue(DemoPlotlyScatterSettings.MARKER_SIZE)
        textSize = settings.getValue(DemoPlotlyScatterSettings.TEXT_SIZE)
        fig.update_traces(marker_size=markerSize)
        fig.update_layout(
            font=dict(size=textSize),
            title_text='Evaluation Data Scatter',
            xaxis=dict(title=dict(text=variableLabel)),
            legend=dict(title=dict(text='Responses')),
            modebar_remove=['toImage'],
        )

        # Check if the plot has been stopped by the user before going any further.
        # This raises an exception if it has.
        self.checkIfStopped()

        # The figure is written to the file given by getPlotOutputFile(). If you
        # return a file path, make sure the file is saved to a writable location.
        # Keep in mind that the file will contain data from your HyperStudy
        # session and should be cleaned up after use.
        #
        # If you return a string instead, plotly.js cannot be inlined in the HTML
        # due to limitations of the renderer. You must use the CDN version or
        # have your working directory be a writable location. Note: the CDN
        # version will make a network request to load the plotly.js library.
        #     return custom_plot.PlotOutputType.HTML_STRING, str(plotly.offline.plot(
        #         fig, include_plotlyjs='cdn', output_type='div'))
        outputFile = plotly.offline.plot(
            fig, include_plotlyjs=True, output_type='file',
            filename=self.getPlotOutputFile(), auto_open=False)
        return custom_plot.PlotOutputType.HTML_FILE, str(outputFile)
386
387
388# -----------------------------------------------------------------------------
def getPlotClass() -> typing.Type[custom_plot.CustomPlot]:
    """Return the custom plot class defined by this module."""
    plotClass: typing.Type[custom_plot.CustomPlot] = DemoPlotlyScatter
    return plotClass