Skip to content

Commit

Permalink
Merge pull request #870 from IGS/705-resolution-issue-for-umaptsne-in…
Browse files Browse the repository at this point in the history
…-gear-20

705 resolution issue for umaptsne in gear 20
  • Loading branch information
adkinsrs authored Aug 16, 2024
2 parents 8f8be49 + 75b3fbe commit a643743
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 25 deletions.
48 changes: 32 additions & 16 deletions www/api/resources/tsne_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import io
import os
import re
from math import ceil
from math import ceil, log2

import geardb
import matplotlib as mpl
Expand Down Expand Up @@ -32,13 +32,14 @@
NUM_LEGENDS_PER_COL = 12 # Max number of legend items per column allowed in vertical legend
NUM_HORIZONTAL_COLS = 8 # Number of columns in horizontal legend

def calculate_figure_height(num_plots):
def calculate_figure_height(num_plots, span=1):
"""Determine height of tsne plot based on number of group elements."""
return (num_plots * 2) + (num_plots -1)
return ((num_plots * 4) * span) + (num_plots - 1)

def calculate_figure_width(num_plots):
def calculate_figure_width(num_plots, span=1):
"""Determine width of tsne plot based on number of group elements."""
return (num_plots * 6) + (num_plots -1)
# The + (num_plots - 1) is to account for the space between plots
return ((num_plots * 2) * span) + (num_plots - 1)

def calculate_num_legend_cols(group_len):
"""Determine number of columns legend should have in tSNE plot."""
Expand Down Expand Up @@ -195,8 +196,13 @@ def post(self, dataset_id):
projection_id = req.get('projection_id', None) # projection id of csv output
colorblind_mode = req.get('colorblind_mode', False)
high_dpi = req.get('high_dpi', False)
grid_spec = req.get('grid_spec', "1/1/2/2") # start_row/start_col/end_row/end_col (end not inclusive)
sc.settings.figdir = '/tmp/'

# convert max_columns to int
if max_columns:
max_columns = int(max_columns)

if not dataset_id:
return {
"success": -1,
Expand Down Expand Up @@ -444,7 +450,7 @@ def post(self, dataset_id):
elif color_idx_name in selected.obs:
# Alternative method. Associate with hexcodes already stored in the dataframe
# Making the assumption that these values are hexcodes
grouped = selected.obs.groupby([colorize_by, color_idx_name])
grouped = selected.obs.groupby([colorize_by, color_idx_name], observed=False)
# Ensure one-to-one mapping between category and hexcodes
if len(selected.obs[colorize_by].unique()) == len(grouped):
# Test if names are color hexcodes and use those if applicable (if first is good, assume all are)
Expand Down Expand Up @@ -473,7 +479,7 @@ def post(self, dataset_id):

max_cols = num_plots
if max_columns:
max_cols = min(int(max_columns), num_plots)
max_cols = min(max_columns, num_plots)

selected.obs["gene_expression"] = [float(x) for x in selected[:,selected.var.index.isin([selected_gene])].X]
max_expression = max(selected.obs["gene_expression"].tolist())
Expand All @@ -482,7 +488,6 @@ def post(self, dataset_id):
if order and plot_by_group in order:
column_order = order[plot_by_group]


for _,name in enumerate(column_order):
# Copy gene expression dataseries to observation
# Filter only expression values for a particular group.
Expand All @@ -509,9 +514,22 @@ def post(self, dataset_id):
io_fig = sc.pl.embedding(selected, **kwargs)
ax = io_fig.get_axes()

# break grid_spec into spans
grid_spec = grid_spec.split('/')
grid_spec = [int(x) for x in grid_spec]
row_span = grid_spec[2] - grid_spec[0]
col_span = ceil((grid_spec[3] - grid_spec[1]) / 3) # Generally these plots span columns in multiples of 4.

# Set the figsize (in inches)
dpi = io_fig.dpi # default dpi is 100, but will be saved as 150 later on
# With 2 plots as a default (gene expression and colorize_by), we want to grow the figure size slowly based on the number of plots

num_plots_wide = max_columns if max_columns else num_plots
num_plots_high = ceil(num_plots / num_plots_wide)

# set the figsize based on the number of plots
io_fig.set_figheight(calculate_figure_height(num_plots))
io_fig.set_figwidth(calculate_figure_width(num_plots))
io_fig.set_figwidth(calculate_figure_width(num_plots_wide, col_span))
io_fig.set_figheight(calculate_figure_height(num_plots_high, row_span))

# rename axes labels
if type(ax) == list:
Expand Down Expand Up @@ -542,21 +560,19 @@ def post(self, dataset_id):
else:
rename_axes_labels(ax, x_axis, y_axis)


# Close adata so that we do not have a stale opened object
if selected.isbacked:
selected.file.close()

with io.BytesIO() as io_pic:
# ? From what I'm reading and seeing, this line does not seem to make a difference if bbox_inches is set to "tight"
io_fig.tight_layout() # This crops out much of the whitespace around the plot. The "savefig" line does this with the legend too

# Set the saved figure dpi based on the number of observations in the dataset after filtering
if high_dpi:
dpi = max(150, int(selected.shape[0] / 100))
sc.settings.set_figure_params(dpi_save=dpi)
# if high_dpi, double the figsize height
io_fig.set_figheight(calculate_figure_height(num_plots) * 2)
# Double the height and width of the figure to maintain the same size
io_fig.set_figwidth(num_plots_wide * 10)
io_fig.set_figheight(num_plots_high * 10)

io_fig.savefig(io_pic, format='png', bbox_inches="tight")
else:
# Moved this to the end to prevent any issues with the dpi setting
Expand Down
2 changes: 1 addition & 1 deletion www/css/expression.css
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ ul#go-terms li {
width: 1080px; /* grid width + 2*border width */
}

/* tSNE plots */
.js-tile .card-image img {
object-fit: fill;
width: fit-content;
height: 100%;
margin: auto;
Expand Down
2 changes: 1 addition & 1 deletion www/css/projection.css
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ ul#go-terms li {
width: 1080px; /* grid width + 2*border width */
}

/* tSNE plots */
.js-tile .card-image img {
object-fit: fill;
width: fit-content;
height: 100%;
margin: auto;
Expand Down
24 changes: 17 additions & 7 deletions www/js/classes/tilegrid.js
Original file line number Diff line number Diff line change
Expand Up @@ -1203,15 +1203,16 @@ class DatasetTile {
// Determine how "download_png" is handled for scanpy plots
const downloadPNG = document.querySelector(`#tile-${this.tile.tileId} .dropdown-item[data-tool="download-png"]`);
if (downloadPNG) {
downloadPNG.classList.remove("is-hidden");

// Remove any existing event listeners
downloadPNG.removeEventListener("click", async (event) => {
await this.getScanpyPNG(display);
});
// If I use the existing "download image" button after switching displays, all previous tsne-static displays will
// also be downloaded becuase event listeners are not removed. So, I will remove the button and re-add it.
// Source -> https://stackoverflow.com/a/9251864

// Once = true so that the event listener is only called once
downloadPNG.addEventListener("click", async (event) => {
const newDownloadPNG = downloadPNG.cloneNode(true);
downloadPNG.parentNode.replaceChild(newDownloadPNG, downloadPNG);

newDownloadPNG.classList.remove("is-hidden");
newDownloadPNG.addEventListener("click", async (event) => {
// get the download URL
await this.getScanpyPNG(display);
});
Expand Down Expand Up @@ -1470,6 +1471,11 @@ class DatasetTile {
const plotType = display.plot_type;
const plotConfig = display.plotly_config;

const tileElement = document.getElementById(`tile-${this.tile.tileId}`);
if (!this.isZoomed) {
plotConfig.grid_spec = tileElement.style.gridArea // add grid spec to plot config
}

const plotContainer = document.querySelector(`#tile-${this.tile.tileId} .card-image`);
if (!plotContainer) return; // tile was removed before data was returned
plotContainer.replaceChildren(); // erase plot
Expand Down Expand Up @@ -1523,6 +1529,10 @@ class DatasetTile {
const plotConfig = JSON.parse(JSON.stringify(display.plotly_config));
plotConfig.high_dpi = true;

const tileElement = document.getElementById(`tile-${this.tile.tileId}`);
if (!this.isZoomed) {
plotConfig.grid_spec = tileElement.style.gridArea // add grid spec to plot config
}

const data = await apiCallsMixin.fetchTsneImage(datasetId, analysisObj, plotType, plotConfig);
if (data?.success < 1) {
Expand Down

0 comments on commit a643743

Please sign in to comment.