Python Snippets
Package abbreviations
Unless otherwise specified, the snippets use the following conventions for package abbreviations:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
A few short bits
Ternary operator
If you also miss the ?: operator from C/C++, then have no fear!
x = 5 if (y < 6) else (y+3)
Note that the operand order is different from C/C++: the first two operands are swapped, so the condition sits in the middle. As in C++, be careful with your brackets.
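A short sketch of the conditional expression in a few common forms (the function names here are just for illustration). Python only evaluates the branch that is selected, so it can also guard against errors such as division by zero:

```python
def pick(y):
    # Python version of the C expression  x = (y < 6) ? 5 : (y + 3);
    return 5 if (y < 6) else (y + 3)

def sign_word(y):
    # Conditional expressions can be nested; brackets keep the intent readable
    return "negative" if y < 0 else ("zero" if y == 0 else "positive")

def safe_inverse(y):
    # Only the chosen branch is evaluated, so 1 / y is never computed when y == 0
    return 1 / y if y != 0 else float('inf')
```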
Colour order
If you want a list of the default colour ordering in matplotlib, you can use
colours = plt.rcParams['axes.prop_cycle'].by_key()['color']
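One place this list is handy is keeping a line and some extra markers in matching colours. A minimal sketch, assuming the default property cycle; `colour_for` is a hypothetical helper that wraps around once the list runs out, in the same way matplotlib's own cycle repeats:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

colours = plt.rcParams['axes.prop_cycle'].by_key()['color']

def colour_for(i):
    # Wrap around when i exceeds the number of colours in the cycle
    return colours[i % len(colours)]

fig, ax = plt.subplots()
x = np.linspace(0, 1, 10)
for i in range(3):
    # Line and scatter markers share the same colour from the default cycle
    ax.plot(x, x + i, color=colour_for(i))
    ax.scatter(x, x + i + 0.5, color=colour_for(i), marker='x')
```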
Enumerate through a list
If you want to loop through a list and keep track of both the position and the entry, enumerate is your friend!
for index, entry in enumerate(a_list_of_stuff):
    print(index, entry)
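Two variations that often come up, sketched with made-up lists: enumerate takes a start argument if you want the counter to begin at something other than 0, and it combines with zip when walking through two lists in step:

```python
fruits = ['apple', 'banana', 'cherry']
prices = [0.5, 0.25, 2.0]

# start=1 makes the counter begin at 1 instead of 0
numbered = [f"{index}. {entry}" for index, entry in enumerate(fruits, start=1)]

# enumerate over zip gives the position plus one entry from each list
rows = [(index, fruit, price) for index, (fruit, price) in enumerate(zip(fruits, prices))]
```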
Improving exponential notation in colour bars
The default behaviour for handling exponents in colour bars isn't the prettiest, and the exponent can sometimes overlap the plot window itself. While there are options for shifting the exponent around to fit better, I find it looks a lot better to simply place it in the ylabel for the colour bar. This also provides a nice opportunity to add units!
Function Definition
First, we'll define a function, because it keeps things tidy.
To use it, simply pass the handle for a colour bar to the function ScientificCbar.
It has some optional arguments that may be useful:
- units (string): Set this argument to the desired string to have it included in the colour bar label. The default is the empty string.
- orientation (string): Specifies whether the colour bar is 'vertical' or 'horizontal'. Default is 'vertical'.
- centre (boolean): If True (the default), the colour bar will be forcibly centred around centre_val.
- centre_val (double): The value about which the colour bar is centred, if centre=True. Default is 0.
- labelpad (double): Extra space to pad between the label and the colour bar. The larger it is, the further the label is from the colour bar. Default is 20.
- label (string): Identification label to be prepended to the colour bar label, e.g. the name of the field being plotted.
Notes:
- If you need more significant digits, modify the format string {0:.2g} (used to build tick_labels) by changing the 2 to the desired number of digits.
- This setting limits the number of ticks along the colour bar (MaxNLocator(nbins=5) allows at most 5 intervals, i.e. 6 ticks): this is to avoid clutter while giving basic information on scaling. If you need / want more, simply modify the nbins argument passed to MaxNLocator.
import matplotlib as mpl
import matplotlib.ticker
import numpy as np
def ScientificCbar(cbar, units='',
        orientation='vertical', centre=True, centre_val=0.,
        labelpad=20, label=''):
    # If requested, centre the colour bar
    if centre:
        cb_vals = cbar.mappable.get_clim()
        cv = np.max(np.abs(np.array([val - centre_val for val in cb_vals])))
        cbar.mappable.set_clim(centre_val-cv, centre_val+cv)
    # Limit the number of ticks on the colour bar
    tick_locator = mpl.ticker.MaxNLocator(nbins=5)
    cbar.locator = tick_locator
    cbar.update_ticks()
    ticks = cbar.get_ticks()
    # Re-scale the values to avoid the poorly-placed exponent
    # above the colour bar (guard against log10(0) if all ticks are zero)
    max_tick = np.max(np.abs(ticks))
    scale = np.floor(np.log10(max_tick)) if max_tick > 0 else 0.
    # Label
    cb_label = '$\\times10^{' + '{0:d}'.format(int(scale)) + '}$ ' + units
    if len(label) > 0:
        cb_label = label + '\n' + cb_label
    # Tick labels
    tick_labels = ["{0:.2g}".format(tick/(10**scale)) for tick in ticks]
    # Instead, simply add a ylabel to the colour bar giving the scale.
    # Fix the tick positions before relabelling them, so labels and ticks stay paired.
    # Rotations must be numbers: strings such as '-90' are rejected by recent matplotlib.
    if orientation == 'vertical':
        if scale != 0.:
            cbar.set_ticks(ticks)
            cbar.ax.set_yticklabels(tick_labels)
        cbar.ax.set_ylabel(cb_label, rotation=-90, labelpad=labelpad)
    elif orientation == 'horizontal':
        if scale != 0.:
            cbar.set_ticks(ticks)
            cbar.ax.set_xticklabels(tick_labels)
        cbar.ax.set_xlabel(cb_label, rotation=0, labelpad=labelpad)
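To see what the rescaling step does without building a figure, here is the same arithmetic pulled out into a small stand-alone function. `scale_tick_labels` is a hypothetical helper for illustration, not part of ScientificCbar; it assumes the ticks have already been chosen by the locator:

```python
import numpy as np

def scale_tick_labels(ticks, digits=2):
    """Hypothetical helper mirroring the exponent/label logic in ScientificCbar."""
    ticks = np.asarray(ticks, dtype=float)
    max_abs = np.max(np.abs(ticks))
    # Power of ten of the largest tick; fall back to 0 if every tick is zero
    exponent = int(np.floor(np.log10(max_abs))) if max_abs > 0 else 0
    fmt = "{0:." + str(digits) + "g}"
    labels = [fmt.format(tick / 10.0**exponent) for tick in ticks]
    # The exponent that would go into the colour bar label
    exp_label = r'$\times10^{' + str(exponent) + '}$'
    return exponent, labels, exp_label

exponent, labels, exp_label = scale_tick_labels([-2e-5, -1e-5, 0, 1e-5, 2e-5])
```

For ticks of order 1e-5 this gives an exponent of -5, tick labels from -2 to 2, and the label string $\times10^{-5}$.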
Sample Usage / Organizing Multiple Colour Bars
This is just a snippet, not a full working example.
One could certainly combine the three loops into one, but they're kept separate here because, in practice, the contents of the first loop are likely to be a lot messier.
fig, axes = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(7,7))
# Create an array to store the pcolormesh objects
qs = np.zeros(axes.shape, dtype='object')
# Also create an array to store the colour bars
cbars = np.zeros(axes.shape, dtype='object')
# Suppose arrays x and y already exist.
# Suppose data is 4D, with the first two dimensions corresponding to y,x
# and the last two being some parameter index
for ii in range(axes.shape[0]):
    for jj in range(axes.shape[1]):
        qs[ii,jj] = axes[ii,jj].pcolormesh(x, y, data[:,:,ii,jj])
for ii in range(axes.shape[0]):
    for jj in range(axes.shape[1]):
        cbars[ii,jj] = plt.colorbar(qs[ii,jj], ax=axes[ii,jj])
for ii in range(axes.shape[0]):
    for jj in range(axes.shape[1]):
        ScientificCbar(cbars[ii,jj])
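If you want to try the layout before plugging in real data, the following sketch fills in the assumed x, y and 4D data arrays with synthetic random values (the sizes are arbitrary) and builds one colour bar per panel. The ScientificCbar loop is left out so the block runs on its own; once the function is defined it can be applied to cbars exactly as above:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-ins for the x, y and data arrays assumed in the snippet
x = np.linspace(0, 1, 30)
y = np.linspace(0, 1, 20)
rng = np.random.default_rng(0)
data = rng.normal(size=(len(y), len(x), 3, 3))

fig, axes = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(7, 7))
qs = np.empty(axes.shape, dtype=object)
cbars = np.empty(axes.shape, dtype=object)

for ii in range(axes.shape[0]):
    for jj in range(axes.shape[1]):
        # Scale panels differently so each colour bar ends up with its own range
        panel = data[:, :, ii, jj] * 10.0**(ii - jj)
        qs[ii, jj] = axes[ii, jj].pcolormesh(x, y, panel, shading='auto')

for ii in range(axes.shape[0]):
    for jj in range(axes.shape[1]):
        # Each colour bar steals space from its own panel
        cbars[ii, jj] = fig.colorbar(qs[ii, jj], ax=axes[ii, jj])
```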
Sample Usage / Common Colour Bar for Multiple Plots
In this snippet, suppose that we have a 2-by-N axis array 'axes'. Further suppose that we have already created plots and stored the handles in 'qs' (as above).
We would now like to create a common colour bar for each row, apply our ScientificCbar function, and ensure that all of the plots in each row have the same colour bar limits.
The following does exactly this.
# Create array to store colour bars
cbars = np.zeros((2,), dtype='object')
# Initialize colour bars to take values from the left-most column, but to take space equally from the entire row.
for ii in range(len(cbars)):
    cbars[ii] = plt.colorbar(qs[ii,0], ax=axes[ii,:])
# Scientific-ify the colour bars
for cbar in cbars:
    ScientificCbar(cbar)
# Force the plots in the other columns to have the same clim as the first column.
for ii in range(len(cbars)):
    for jj in range(1, qs.shape[1]):
        qs[ii,jj].set_clim(qs[ii,0].get_clim())
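Because the limits come from the left-most column only, panels in other columns whose data extend beyond that range will saturate. One alternative, sketched below with synthetic data (the array names, the 4D [y, x, row, column] layout and the sizes are assumptions), is to compute the limits over the whole row before creating the shared colour bar. If you also apply ScientificCbar, its centring only changes the left-most mappable, so the clim would still need to be copied to the other columns afterwards:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Synthetic 2-by-N example; columns are given different amplitudes on purpose
N = 3
x = np.linspace(0, 1, 25)
y = np.linspace(0, 1, 20)
rng = np.random.default_rng(1)
data = rng.normal(size=(len(y), len(x), 2, N)) * np.arange(1, N + 1)

fig, axes = plt.subplots(2, N, sharex=True, sharey=True, figsize=(8, 5))
qs = np.empty(axes.shape, dtype=object)
for ii in range(2):
    for jj in range(N):
        qs[ii, jj] = axes[ii, jj].pcolormesh(x, y, data[:, :, ii, jj], shading='auto')

cbars = np.empty((2,), dtype=object)
for ii in range(2):
    # Limits over every panel in the row, so no column saturates
    row_min, row_max = data[:, :, ii, :].min(), data[:, :, ii, :].max()
    for q in qs[ii, :]:
        q.set_clim(row_min, row_max)
    # One colour bar per row, taking space from the whole row
    cbars[ii] = fig.colorbar(qs[ii, 0], ax=axes[ii, :])
```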
Creating Movie Files
Matplotlib has some built-in animation tools, but I've generally found that they only work well for very simple figures; for more complex figures, I find them cumbersome or crash-prone. [Note: this may have changed since I last used them, but I still like using the manual approach, since it can be parallelized trivially.]
The general idea is just to produce a directory full of png files (such as frame_0014.png) and then combine them into a single movie file using ffmpeg.
A basic example for frame generation is given below. Since each frame is independent, the for loop can (generally) be trivially parallelized with mpi4py.
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 2 * np.pi, 50)
t = np.linspace(0, 2 * np.pi, 50)
fig, axes = plt.subplots(1, 1, figsize=(6,6))
for It in range(len(t)):
    axes.clear()
    y = np.sin( x - t[It] )
    axes.plot( x, y, linewidth = 1 + t[It] )
    # Zero-padded frame index, e.g. frame_0014.png
    fig.savefig('frame_{0:04d}.png'.format(It), dpi=200)
plt.close(fig)
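A minimal sketch of how the frames could be divided among MPI processes. It assumes a round-robin split (process r draws frames r, r + size, ...), which is just one reasonable choice; `frames_for_rank` is a hypothetical helper, and the block falls back to a single serial process when mpi4py is not installed:

```python
import numpy as np

def frames_for_rank(n_frames, rank, size):
    # Round-robin split: rank r handles frames r, r + size, r + 2*size, ...
    return list(range(rank, n_frames, size))

try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()
except ImportError:
    # No mpi4py available: behave like a single serial process
    comm = None
    rank, size = 0, 1

t = np.linspace(0, 2 * np.pi, 50)
# In the frame-generation loop, iterate over my_frames instead of range(len(t))
my_frames = frames_for_rank(len(t), rank, size)
```

Each process then runs the plotting loop over its own my_frames. When running under MPI, a comm.Barrier() call after the loop ensures every frame has been written before rank 0 alone goes on to merge them.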
Merging Images into a Movie
The following function provides a clean wrapper for merging a collection of images (such as pngs) into a movie (such as an mp4).
Notes:
- This requires that ffmpeg be installed and accessible from the command line.
- The first argument, frame_filenames, is a pattern string that indicates the files. For example, if the files are img_000.png, img_001.png, ..., then you would pass 'img_%03d.png'.
- The second argument, movie_name, is the desired output filename. This should include the extension, for example 'my_mov.mp4'.
- The (optional) argument fps specifies the frame rate of the output. Default is 12 frames per second.
- The first two arguments may be relative paths, so you can merge files in one directory and save the video in another.
- The log files ffmpeg.log and ffmpeg.err contain the outputs from stdout and stderr, respectively.
For the frame-generation example above, you would simply call merge_to_mp4('frame_%04d.png', 'my_movie.mp4') in your Python script.
[Note: if parallelized, then only one process should call the merge function, and only after all frames have been written!]
import subprocess
def merge_to_mp4(frame_filenames, movie_name, fps=12):
    # Send ffmpeg's stdout and stderr to log files; the with-block closes them even if ffmpeg fails
    with open("ffmpeg.log", "w") as f_log, open("ffmpeg.err", "w") as f_err:
        cmd = ['ffmpeg', '-framerate', str(fps), '-i', frame_filenames, '-y',
               '-q', '1', '-threads', '0', '-pix_fmt', 'yuv420p', movie_name]
        # Return ffmpeg's exit code (0 on success)
        return subprocess.call(cmd, stdout=f_log, stderr=f_err)
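If ffmpeg is missing, subprocess.call raises FileNotFoundError, and a failed encode only shows up in ffmpeg.err. A hypothetical, more defensive variant is sketched below; build_ffmpeg_cmd and merge_to_mp4_checked are illustrative names, and the ffmpeg flags are the same ones used above. It checks for ffmpeg on the PATH first and raises if the exit code is non-zero:

```python
import shutil
import subprocess

def build_ffmpeg_cmd(frame_filenames, movie_name, fps=12):
    # Same command line as merge_to_mp4, split out so it can be inspected or tested
    return ['ffmpeg', '-framerate', str(fps), '-i', frame_filenames, '-y',
            '-q', '1', '-threads', '0', '-pix_fmt', 'yuv420p', movie_name]

def merge_to_mp4_checked(frame_filenames, movie_name, fps=12):
    # Fail early with a clear message if ffmpeg is not on the PATH
    if shutil.which('ffmpeg') is None:
        raise FileNotFoundError("ffmpeg was not found on the PATH")
    cmd = build_ffmpeg_cmd(frame_filenames, movie_name, fps)
    with open("ffmpeg.log", "w") as f_log, open("ffmpeg.err", "w") as f_err:
        result = subprocess.run(cmd, stdout=f_log, stderr=f_err)
    # Surface encoding failures instead of silently leaving them in the log
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg exited with code {result.returncode}; see ffmpeg.err")
```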
Plotting Multiple Lines, Coloured by a Colour Bar
A common desire, particularly for plotting a sweep across a parameter (such as a time evolution), is to plot multiple line plots, but to have them be coloured sequentially, following a colour map, and to have a colour bar showing the evolution.
The below example illustrates exactly how to do this.
The block that defines Nx, Nt, x, t and h is purely there to generate a sample data set: a rightward-propagating Gaussian that steadily widens.
The start / stop / skip block sub-samples the data so that we don't have too many lines.
The list comprehension defining h_lines creates the list of desired line sets (one array of (x, h) pairs per selected time).
The colour map is specified by cmap, and set_array determines which variable (here, time) is mapped onto it.
The points_from_LineCollection block adds dots indicating the data points (much like line style '-o'), using the same colour map as the lines. A function definition is included, since it can be convenient for tidying code when repeatedly using this method.
import numpy as np
import matplotlib
from matplotlib.collections import LineCollection
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 1)
Nx, Nt = 100, 100
x = np.linspace(-1, 1, Nx)
t = np.linspace( 0, 5, Nt)
X, T = np.meshgrid(x, t)
h = np.exp(-( (X - T/5 + 0.5) / (0.2 + T/10))**2)
## Determine which times we want to show
start = 0 # start at the beginning
stop = Nt//2 # only plot first half of simulation
skip = (stop-start) // 10 # only use 10 lines to plot evolution
plot_inds = np.arange(start, stop, skip)
# Select out the data that we want to use for creating lines
h_lines = [np.column_stack([x, h[ii,:]]) for ii in plot_inds]
# Create a LineCollection object which holds them together
cmap = matplotlib.cm.viridis
h_segments = LineCollection(h_lines, cmap = cmap)
# Set which variable is used with the colour map (determines colours)
h_segments.set_array(t[plot_inds])
# Plot the lines
ax.add_collection(h_segments)
# Overlay dots - shows where the data points are
def points_from_LineCollection( h_segments ):
    # Flatten every (x, y) vertex of every line segment
    x = np.array([i[0] for j in h_segments.get_segments() for i in j])
    y = np.array([i[1] for j in h_segments.get_segments() for i in j])
    # Repeat each line's colour value once per vertex (assumes all lines have the same length)
    c = np.array([col for col in h_segments.get_array() for _ in range(len(h_segments.get_segments()[0]))])
    return x, y, c
x,y,c = points_from_LineCollection( h_segments )
ax.scatter(x, y, c = c, cmap = cmap, s = 2)
# Need to manually update the bounds
ax.axis('tight')
# Add a colour bar showing time
cbar = plt.colorbar(h_segments, ax = ax)
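An alternative that avoids the LineCollection and the manual point extraction is to plot each line with ax.plot, picking its colour from the colour map through a Normalize object, and to build the colour bar from a ScalarMappable. This is a different technique from the one above, sketched here on the same sample data; it trades the single collection object for one Line2D per curve:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
import numpy as np

# Same sample data as above: a rightward-propagating, widening Gaussian
Nx, Nt = 100, 100
x = np.linspace(-1, 1, Nx)
t = np.linspace(0, 5, Nt)
X, T = np.meshgrid(x, t)
h = np.exp(-((X - T/5 + 0.5) / (0.2 + T/10))**2)
plot_inds = np.arange(0, Nt//2, (Nt//2) // 10)

cmap = plt.get_cmap('viridis')
# Map the plotted times onto [0, 1] for the colour map
norm = Normalize(vmin=t[plot_inds].min(), vmax=t[plot_inds].max())

fig, ax = plt.subplots(1, 1)
for ii in plot_inds:
    # '-o' draws the line and the data points together in one colour
    ax.plot(x, h[ii, :], '-o', markersize=2, color=cmap(norm(t[ii])))

# A ScalarMappable carries the norm and colour map for the colour bar
sm = ScalarMappable(norm=norm, cmap=cmap)
cbar = fig.colorbar(sm, ax=ax)
```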
Speeding-up Repeated Interpolations
Suppose you have two grids, $x_1, x_2, \dots, x_n$ and $y_1, y_2, \dots, y_n$, and a collection of functions that are sampled on the first grid (i.e. $f_1(x_1, x_2, \dots, x_n), f_2(x_1, x_2, \dots, x_n), \dots, f_m(x_1, x_2, \dots, x_n)$) that you want to have on the second grid.
You can use the built-in interpolation functions (such as griddata from scipy), but that can be very slow / expensive, since each call has to re-create the grid-to-grid information (build the triangulation's simplices, compute interpolation weights, etc.).
The problem (and solution) is well described in this StackOverflow post. The solution is reproduced here, along with a few modifications to account for some Python 3 and SciPy changes.
Notes:
- This only uses linear interpolation.
- This does not use 'fill values'. You can identify target points that are outside the convex hull of the source points by searching for weights (wts) that are negative.
import numpy as np
from scipy.spatial import Delaunay
def interp_weights(xyz, uvw):
    # Number of spatial dimensions (columns of the point arrays)
    d = xyz.shape[1]
    tri = Delaunay(xyz)
    # Index of the simplex containing each target point (-1 if outside the hull)
    simplex = tri.find_simplex(uvw)
    vertices = np.take(tri.simplices, simplex, axis=0)
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uvw - temp[:, d]
    # Barycentric coordinates; the last one follows from the weights summing to 1
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))
def interpolate(values, vtx, wts):
    # Weighted sum of the values at each target point's simplex vertices
    return np.einsum('nj,nj->n', np.take(values, vtx), wts)
x_1 = np.linspace(0, 1, 100)
x_2 = np.linspace(0, 2, 150)
y_1 = np.linspace(0, 1, 500)
y_2 = np.linspace(0, 2, 700)
X_1, X_2 = np.meshgrid(x_1, x_2)
Y_1, Y_2 = np.meshgrid(y_1, y_2)
f_x = np.sin( 2*np.pi*X_1 ) + np.exp(-X_2)
g_x = 2*X_1 / (1 + X_2)
source_points = np.vstack( ( X_1.ravel(), X_2.ravel() ) ).T
target_points = np.vstack( ( Y_1.ravel(), Y_2.ravel() ) ).T
# The bulk of run-time happens here, which only needs to be done once
vtx, wts = interp_weights( source_points, target_points )
# The individual interpolations are now much faster, and adding additional interpolations is very cheap.
f_y = interpolate( f_x.ravel(), vtx, wts ).reshape(Y_1.shape)
g_y = interpolate( g_x.ravel(), vtx, wts ).reshape(Y_1.shape)
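Following the note about negative weights, a fill value can be added on top of the same weights. The sketch below is an assumption-based extension, not part of the StackOverflow solution: interpolate_with_fill is a hypothetical name, and the small tolerance is there because points lying exactly on a simplex edge can produce weights like -1e-16 from round-off. It uses a linear test function, which linear interpolation should reproduce exactly inside the hull:

```python
import numpy as np
from scipy.spatial import Delaunay

def interp_weights(xyz, uvw):
    # Same weight computation as above
    d = xyz.shape[1]
    tri = Delaunay(xyz)
    simplex = tri.find_simplex(uvw)
    vertices = np.take(tri.simplices, simplex, axis=0)
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uvw - temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

def interpolate_with_fill(values, vtx, wts, fill_value=np.nan, tol=1e-12):
    result = np.einsum('nj,nj->n', np.take(values, vtx), wts)
    # A clearly negative weight means the point lies outside the convex hull
    outside = np.any(wts < -tol, axis=1)
    result[outside] = fill_value
    return result

# Small source grid on the unit square and a linear test function
x = np.linspace(0, 1, 11)
X1, X2 = np.meshgrid(x, x)
source_points = np.column_stack((X1.ravel(), X2.ravel()))
f_x = 2 * source_points[:, 0] + 3 * source_points[:, 1]

# Two points inside the square and one far outside it
target_points = np.array([[0.5, 0.5], [0.25, 0.75], [2.0, 2.0]])
vtx, wts = interp_weights(source_points, target_points)
f_y = interpolate_with_fill(f_x, vtx, wts)
```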
Histogram with Percent on Vertical Axis
Would you like to have a histogram that shows percentage on the vertical axis? Well here you are! (thanks to this StackOverflow post)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
fig, ax = plt.subplots(1, 1, gridspec_kw = dict(left = 0.15, right = 0.99, bottom = 0.15, top = 0.98 ) )
sample_data = np.random.randn(500)
counts, edges = np.histogram( sample_data, bins = 'auto' )
ax.hist( sample_data, bins = edges, weights = np.ones(len(sample_data)) / len(sample_data) )
ax.grid(True, linestyle = '--', alpha = 0.7)
ax.set_xlabel('Data value')
ax.yaxis.set_major_formatter(PercentFormatter(1, decimals = 0))
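The weights of 1/N make each bar the fraction of samples that fall in its bin, so the bar heights sum to 1 (i.e. 100%). This is different from density=True, which also divides by the bin width, so that the area rather than the bar heights sums to 1. A numpy-only check of that relationship, using synthetic data:

```python
import numpy as np

rng = np.random.default_rng(0)
sample_data = rng.standard_normal(500)
_, edges = np.histogram(sample_data, bins='auto')

# Fraction of samples per bin, as used for the percent histogram
weights = np.ones(len(sample_data)) / len(sample_data)
fractions, _ = np.histogram(sample_data, bins=edges, weights=weights)

# Probability density: counts / (N * bin width)
density, _ = np.histogram(sample_data, bins=edges, density=True)
```

Multiplying the density by the bin widths recovers the per-bin fractions.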