Add a custom plot to a benchmark#

Benchopt provides a set of default plots to visualize the results of a benchmark. These can be complemented with custom plots, defined in the benchmark, to visualize the results in a different way. Custom plots live in the plots directory of the benchmark, as Python files containing classes that inherit from benchopt.BasePlot. This page details the API to generate custom visualizations for your benchmark.

Structure of a custom plot#

A custom plot is defined by a class inheriting from benchopt.BasePlot and implementing:

  • name: The name of the plot. It appears in the plot selection menu of the HTML interface, and it is the name used to select this plot in the config files of your benchmark.

  • type: The type of the plot, which defines how the output of plot will be rendered. Supported types are "scatter", "bar_chart", "boxplot", "table" and "image".

  • options: A dictionary defining the options available for the plot. Typically, this is used to produce different plots depending on the parameters of the dataset or objective, or to expose display customization options. Each key is the name of an option, associated with the list of its possible values. If one of the keys objective, dataset, solver or objective_column is associated with the value ..., its possible values are automatically inferred from the results DataFrame, as all unique values associated with this key.

  • plot(self, df, **kwargs): Returns the data needed to produce one plot, which is rendered with the plotly or matplotlib backend. The method takes the results DataFrame and the option values as arguments, and returns the plot data. The expected output depends on the plot's type and is detailed below for each type.

  • get_metadata(self, df, **kwargs): Returns global information about the plot, such as the title and axis labels. The method takes the results DataFrame and the option values as arguments, and returns the metadata of the plot, which is specific to each plot type.

The get_metadata method sets global properties of the resulting visualization, while the plot method outputs the data necessary to render it. The visualization is rendered using either the plotly or the matplotlib backend.

from benchopt import BasePlot

class Plot(BasePlot):
    name = "My Custom Plot"
    type = "scatter"  # or "bar_chart", "boxplot", "table" or "image"
    options = {
        "dataset": ...,         # Automatic options from DataFrame columns
        "objective": ...,
        "my_parameter": [1, 2], # custom options
    }

    # The inputs args of this method correspond to `df` and
    # the keys in the `options` dictionary.
    def plot(self, df, dataset, objective, my_parameter):
        # ... process df ...
        return plot_data

    def get_metadata(self, df, dataset, objective, my_parameter):
        return {
            "title": f"Plot for {dataset}",
            "xlabel": "X Label",
            "ylabel": "Y Label",
        }

Plot Options#

The options dictionary keys define the arguments passed to plot and get_metadata. Special keys like dataset, objective, solver will automatically try to match columns in the dataframe. Using ... as a value will populate the options with all unique values from the dataframe column {key}_name (e.g. dataset_name).
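
The option resolution described above can be pictured with a minimal, dependency-free sketch. It is not benchopt's implementation: the rows list is a hypothetical stand-in for the results DataFrame, resolve_options is a hypothetical helper, and sorting the inferred values is an assumption made only to keep the output deterministic.

```python
from itertools import product

# Hypothetical stand-in for the results DataFrame: one dict per row.
rows = [
    {"dataset_name": "simulated", "objective_name": "lasso"},
    {"dataset_name": "simulated", "objective_name": "lasso"},
    {"dataset_name": "leukemia", "objective_name": "lasso"},
]

options = {"dataset": ..., "objective": ..., "my_parameter": [1, 2]}


def resolve_options(options, rows):
    """Replace each `...` with the unique values of the `{key}_name` column."""
    resolved = {}
    for key, values in options.items():
        if values is ...:
            # Sorting is an assumption of this sketch, not documented behavior.
            values = sorted({row[f"{key}_name"] for row in rows})
        resolved[key] = list(values)
    return resolved


resolved = resolve_options(options, rows)

# Each combination of option values matches one call signature of
# plot(df, **kwargs) and get_metadata(df, **kwargs).
combinations = [
    dict(zip(resolved, combo)) for combo in product(*resolved.values())
]
```

Here resolved["dataset"] holds the two dataset names found in the rows, and combinations lists the four keyword-argument sets (two datasets, one objective, two values of my_parameter).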

Scatter Plot#

For a scatter plot, the plot method should return a list of dictionaries, where each dictionary represents a trace in the plot. Each dictionary must contain:

  • x: A list of x values.

  • y: A list of y values.

  • label: The label of the trace.

  • color (optional): The color of the trace.

  • marker (optional): The marker style of the trace.

  • y_low, y_high (optional): Lists of values bounding the uncertainty along the y-axis. They are displayed as a shaded area around the curve.

  • x_low, x_high (optional): Lists of values bounding the uncertainty along the x-axis. They are displayed as a shaded area around the curve. You can use either y_low/y_high or x_low/x_high, but not both.
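
As an illustration of the keys listed above, the following sketch builds one trace with a y-axis uncertainty band and checks it against the rules of this section. The data, the choice of min/max for the band, and the helper check_scatter_trace are all hypothetical; the equal-length check on x and y is an assumption of the sketch rather than a documented requirement.

```python
from statistics import median

# Hypothetical data: one time per stopping value, and one list of
# objective values (one entry per repetition) for each stopping value.
times = [0.1, 0.2, 0.4]
values_per_step = [[1.0, 1.2, 0.9], [0.5, 0.6, 0.4], [0.1, 0.3, 0.2]]

trace = {
    "x": times,
    "y": [median(v) for v in values_per_step],
    # The band uses min/max over repetitions; quantiles would work too.
    "y_low": [min(v) for v in values_per_step],
    "y_high": [max(v) for v in values_per_step],
    "label": "my-solver",
}


def check_scatter_trace(trace):
    """Raise ValueError if the trace breaks the rules listed above."""
    missing = {"x", "y", "label"} - trace.keys()
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    has_x_band = "x_low" in trace or "x_high" in trace
    has_y_band = "y_low" in trace or "y_high" in trace
    if has_x_band and has_y_band:
        raise ValueError("use either x_low/x_high or y_low/y_high, not both")
    # Assumption of this sketch: x and y describe the same points.
    if len(trace["x"]) != len(trace["y"]):
        raise ValueError("x and y must have the same length")
    return trace


check_scatter_trace(trace)
```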

The metadata dictionary returned by get_metadata should contain:

  • title: The title of the plot.

  • xlabel: The label of the x-axis.

  • ylabel: The label of the y-axis.

  • grid (optional, default=True): Whether to show grid lines in the plot. This only affects the matplotlib backend, not the html page.

  • scale (optional, default="loglog"): The scale of the axes in the matplotlib backend, can be either "linear", "semilog-x", "semilog-y" or "loglog".
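
One way to read the scale values is as a pair of per-axis scales. The mapping below is an assumption based on the usual meaning of these names (for instance, "semilog-x" taken as a logarithmic x-axis with a linear y-axis); it is not taken from benchopt's code, and axis_scales is a hypothetical helper.

```python
# Assumed reading of each scale name as an (x-axis, y-axis) pair.
SCALES = {
    "linear": ("linear", "linear"),
    "semilog-x": ("log", "linear"),
    "semilog-y": ("linear", "log"),
    "loglog": ("log", "log"),
}


def axis_scales(metadata):
    """Return the (xscale, yscale) pair for a scatter metadata dict."""
    # "loglog" is the documented default when the key is absent.
    scale = metadata.get("scale", "loglog")
    if scale not in SCALES:
        raise ValueError(f"unknown scale: {scale!r}")
    return SCALES[scale]


xscale, yscale = axis_scales({"title": "Convergence", "scale": "semilog-y"})
```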

def plot(self, df, dataset, objective, my_parameter):
    # Filter the dataframe
    df = df.query(
        "dataset_name == @dataset and objective_name == @objective"
    )

    plot_traces = []
    for solver, df_solver in df.groupby('solver_name'):
        # Compute the median over the repetitions
        curve = (
            df_solver.groupby("stop_val")[["time", "objective_value"]]
            .median()
        )
        plot_traces.append({
            "x": curve['time'].tolist(),
            "y": curve['objective_value'].tolist(),
            "label": solver,
            **self.get_style(solver)
        })
    return plot_traces

def get_metadata(self, df, dataset, objective, my_parameter):
    return {
        "title": f"Convergence for {dataset}",
        "xlabel": "Time [sec]",
        "ylabel": "Objective value",
    }

Note

To keep a consistent style across figures, you can use the helper get_style, as described in Plotting Utilities.

Bar Chart#

For a bar chart, the plot method should return a list of dictionaries, where each dictionary represents a bar. For each bar, the median value will be used to determine its height, while the individual values will be displayed as scatter points. The dictionary should contain:

  • y: The list of values for the bar (the median will be the height of the bar).

  • label: The label of the bar.

  • color (optional): The color of the bar.

  • text (optional): The text to display on the bar.
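
The relation between the values in y and the drawn bar can be sketched as follows. The bars and their values are hypothetical; only the rule "bar height is the median of y" comes from the description above.

```python
from statistics import median

# Hypothetical bars: each "y" holds one value per repetition.
bars = [
    {"y": [1.25, 1.0, 1.5], "label": "solver-a"},
    {"y": [2.0, 2.5], "label": "solver-b"},
]

# The bar height is the median of the values; the individual values
# are the points shown as a scatter on top of the bar.
heights = {bar["label"]: median(bar["y"]) for bar in bars}
```

With these values, solver-a gets a bar of height 1.25 and solver-b a bar of height 2.25.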

The metadata dictionary returned by get_metadata should contain:

  • title: The title of the plot.

  • ylabel: The label of the y-axis.

  • grid (optional, default=True): Whether to show grid lines on the y-axis in the plot. This only affects the matplotlib backend, not the html page.

def plot(self, df, dataset, objective, **kwargs):
    df = df.query(
        "dataset_name == @dataset and objective_name == @objective"
    )
    bars = []
    for solver, df_solver in df.groupby('solver_name'):
        # Select the total runtime for each repetition
        runtimes = df_solver.groupby("idx_rep")["runtime"].last()
        bars.append({
            "y": runtimes.tolist(),
            "label": solver,
            "text": "",
            "color": self.get_style(solver)['color']
        })
    return bars

def get_metadata(self, df, dataset, objective, **kwargs):
    return {
        "title": f"Average times for {objective} on {dataset}",
        "ylabel": "Time [sec]",
    }

Box Plot#

For a box plot, the plot method should return a list of dictionaries, where each dictionary represents a box. Each dictionary should contain:

  • x: The x coordinate.

  • y: The values of the box for the corresponding x coordinate.

  • label: The label of the box.

  • color (optional): The color of the box.

The metadata dictionary returned by get_metadata should contain:

  • title: The title of the plot.

  • xlabel: The label of the x-axis.

  • ylabel: The label of the y-axis.

  • grid (optional, default=True): Whether to show grid lines on the y-axis in the plot. This only affects the matplotlib backend, not the html page.

  • box_width (optional, default=0.6): The width of the boxes, only affects the matplotlib backend, not the html page.

  • showfliers (optional, default=False): Whether to show fliers in the boxplot. Fliers are points that are outside the whiskers of the boxplot, which represent outliers in the data. This only affects the matplotlib backend, not the html page.
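
To make the notion of fliers concrete, the sketch below computes the quartiles of one box and the points that fall outside the whiskers. It assumes whiskers at 1.5 times the interquartile range, which is matplotlib's default (whis=1.5); whether benchopt changes that setting is not stated here. The values and the helper box_summary are hypothetical.

```python
from statistics import quantiles


def box_summary(values, whis=1.5):
    """Return the quartiles and the fliers of one box."""
    # method="inclusive" interpolates linearly between data points.
    q1, q2, q3 = quantiles(values, n=4, method="inclusive")
    iqr = q3 - q1
    low, high = q1 - whis * iqr, q3 + whis * iqr
    fliers = [v for v in values if v < low or v > high]
    return {"q1": q1, "median": q2, "q3": q3, "fliers": fliers}


# Hypothetical final objective values over five repetitions.
summary = box_summary([1.0, 2.0, 3.0, 4.0, 100.0])
```

Here the value 100.0 lies above the upper whisker limit (4.0 + 1.5 * 2.0 = 7.0), so it is the only flier and would be hidden with showfliers=False.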

def plot(self, df, dataset, objective, **kwargs):
    df = df.query(
        "dataset_name == @dataset and objective_name == @objective"
    )
    plot_data = []
    for solver, df_solver in df.groupby('solver_name'):
        # Example: boxplot for the final objective values
        # for each solver
        final_objective_value = (
            df_solver.groupby("idx_rep")['objective_value'].last()
        )
        plot_data.append({
            "x": [solver],
            "y": [final_objective_value.tolist()],
            "label": solver,
            "color": self.get_style(solver)['color']
        })
    return plot_data

def get_metadata(self, df, dataset, objective, **kwargs):
    return {
        "title": f"Boxplot for {objective} on {dataset}",
        "xlabel": "Solver",
        "ylabel": "Objective value",
    }

Table Plot#

For a table plot, the plot method should return a list of lists, where each inner list represents a row in the table. The metadata dictionary returned by get_metadata should contain:

  • title: The title of the plot.

  • columns: A list of column names.

def plot(self, df, dataset, objective, **kwargs):
    df = df.query(
        "dataset_name == @dataset and objective_name == @objective"
    )
    rows = []
    for solver, df_solver in df.groupby('solver_name'):
        # Example: table with solver name and mean time
        # when using `sampling_strategy = 'run_once'`
        rows.append([solver, df_solver['time'].mean()])
    return rows

def get_metadata(self, df, dataset, objective, **kwargs):
    return {
        "title": f"Summary for {dataset}",
        "columns": ["Solver", "Mean Time [sec]"],
    }
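
Since the rows come from plot and the column names from get_metadata, it is easy for the two to drift apart. The following sketch checks that every row has one entry per column and prints the table as plain text for a quick look; format_table is a hypothetical helper and is unrelated to how benchopt renders tables.

```python
def format_table(columns, rows):
    """Render rows under their column names, checking the row lengths."""
    for row in rows:
        if len(row) != len(columns):
            raise ValueError(
                f"row {row!r} has {len(row)} entries, "
                f"expected {len(columns)}"
            )
    cells = [[str(c) for c in columns]]
    cells += [[str(c) for c in row] for row in rows]
    # Width of each column is the width of its longest cell.
    widths = [max(len(r[i]) for r in cells) for i in range(len(columns))]
    lines = [
        "  ".join(c.ljust(w) for c, w in zip(r, widths)).rstrip()
        for r in cells
    ]
    return "\n".join(lines)


# Hypothetical output of plot() and of get_metadata()["columns"].
rows = [["cd", 0.5], ["ista", 1.25]]
columns = ["Solver", "Mean Time [sec]"]
table = format_table(columns, rows)
```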

Image Plot#

For an image plot, the plot method should return a list of dictionaries, where each dictionary represents one image card displayed in a grid. Each dictionary must contain:

  • image: Either an image-compatible array (rendered as a PNG) or a list of image-compatible arrays (rendered as an animated GIF showing per-iteration progress). A pre-computed base64 data URI or a URL is also accepted. If set to None, an empty image is created, which can be used for alignment purposes.

Optional keys:

  • label: Text displayed below the image card.

Arrays are expected to have values in [0, 1] and are converted automatically using Pillow, so no manual encoding is needed.
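
If your raw data is not already in [0, 1], it has to be rescaled before being returned. The sketch below shows min-max rescaling on a nested list, used here as a dependency-free stand-in for an array; rescale_to_unit_interval is a hypothetical helper, and with NumPy arrays the same arithmetic applies elementwise.

```python
def rescale_to_unit_interval(image):
    """Min-max rescale a 2D nested list so its values lie in [0, 1]."""
    flat = [value for row in image for value in row]
    low, high = min(flat), max(flat)
    if high == low:
        # Constant image: map every pixel to 0.0 to avoid dividing by zero.
        return [[0.0 for _ in row] for row in image]
    return [[(value - low) / (high - low) for value in row] for row in image]


# Hypothetical 2x2 grayscale image with values in [0, 10].
scaled = rescale_to_unit_interval([[0, 5], [10, 5]])
```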

The metadata dictionary returned by get_metadata should contain:

  • title: The title displayed above the grid.

  • ncols: Number of columns in the grid (default: min(n_images, 3)).
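
The default for ncols can be sketched as below. Only the default min(n_images, 3) comes from the description above; the number of rows, computed by rounding up, and the helper grid_shape are assumptions of this sketch about how such a grid is typically filled.

```python
from math import ceil


def grid_shape(n_images, ncols=None):
    """Return (nrows, ncols) for a grid holding n_images cards."""
    if ncols is None:
        # Documented default; max(1, ...) only guards the empty case.
        ncols = max(1, min(n_images, 3))
    # Assumption: cards fill the grid row by row, so round up.
    nrows = ceil(n_images / ncols)
    return nrows, ncols


default_shape = grid_shape(5)
custom_shape = grid_shape(4, ncols=2)
```

With five images and no ncols, the grid has two rows of three columns, the last row being partly empty; an image entry set to None can be used to pad such a row.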

Note

In the HTML result page, animated GIFs are rendered when a list of arrays is provided. In the matplotlib backend, each image card is shown as a static subplot using the last frame for animated sequences.

Plotting Utilities#

To ensure consistency across plots (e.g., using the same color and marker for a given solver), benchopt.BasePlot provides the helper method get_style(label). This method returns a dictionary with color and marker keys, which can be directly unpacked into the trace dictionary for scatter plots or used to select the color for other plot types. It automatically assigns a color from the default color palette based on the hash of the label, ensuring that the same solver always gets the same color.

# Usage in plot()
style = self.get_style(solver_name)
trace = {
    # ...
    "color": style["color"],
    "marker": style["marker"]
}
# Or simply:
trace = {
    # ...
    **self.get_style(solver_name)
}
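
The idea of deriving a style from a hash of the label can be illustrated with the sketch below. It is not the code behind get_style: the palette, the marker names and the function sketch_style are hypothetical, and hashlib is used because Python's built-in hash() of a string changes between interpreter runs.

```python
import hashlib

# Hypothetical palette and marker names, for illustration only.
PALETTE = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]
MARKERS = ["circle", "square", "diamond", "cross"]


def sketch_style(label):
    """Map a label to a color and marker that are stable across runs."""
    digest = int(hashlib.sha256(label.encode("utf-8")).hexdigest(), 16)
    return {
        "color": PALETTE[digest % len(PALETTE)],
        "marker": MARKERS[digest % len(MARKERS)],
    }


style_first = sketch_style("my-solver")
style_again = sketch_style("my-solver")
```

Calling the function twice with the same label gives the same dictionary, which is the property that keeps a solver's color identical from one plot to the next. Two different labels can still collide on the same color when the palette is small.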