# Source code for mtpy.modeling.modem.model_manipulator

"""
==================
ModEM
==================

# Generate files for ModEM

# revised by JP 2017
# revised by AK 2017 to bring across functionality from ak branch

"""

import os

import numpy as np
from matplotlib import cm as cm, pyplot as plt, colorbar as mcb, colors as colors, widgets as widgets

from mtpy.imaging import mtplottools as mtplottools
from .data import Data
from .model import Model

# public API of this module
__all__ = [
    "ModelManipulator",
]


class ModelManipulator(Model):
    """
    Interactive map-view editor for the resistivity values of a ModEM model.

    The model is read with ``read_model_file`` and one depth slice is shown
    at a time.  Drag a rectangle over the map to set every cell it touches
    to the resistivity chosen with the slider (log10 scale); use the keys
    below to move between depth slices, copy layers or undo.  At the moment
    the model is only plotted in map view.

    :Example: ::

        >>> from mtpy.modeling.modem.model_manipulator import ModelManipulator
        >>> mm = ModelManipulator(model_fn=r"/home/MT/ModEM/Inv1/model.rho",
        ...                       data_fn=r"/home/MT/ModEM/Inv1/data.dat")

    =================== =======================================================
    Buttons             Description
    =================== =======================================================
    '='                 increase depth to next vertical node (deeper)
    '-'                 decrease depth to next vertical node (shallower)
    'q'                 quit the plot and write the model file
                        (see rewrite_model_file)
    'a'                 copies the above horizontal layer to the present layer
    'b'                 copies the below horizontal layer to present layer
    'u'                 undo the last rectangle edit: restores the affected
                        cells to the values read from the model file
    =================== =======================================================

    =================== =======================================================
    Attributes          Description
    =================== =======================================================
    ax1                 matplotlib.axes instance for mesh plot of the model
    ax2                 matplotlib.axes instance of colorbar
    cb                  matplotlib.colorbar instance for colorbar
    cid_depth           connection id of the key_press_event callback
    cmap                matplotlib.colormap instance
    cmax                maximum of the colorbar, log10 of resistivity
    cmin                minimum of the colorbar, log10 of resistivity
    cov_arr             np.ndarray of ones shaped like res_model
    data_fn             full path to data file
    depth_index         integer index of the depth slice being plotted
    dscale              map scaling, 1000 for 'km' and 1 for 'm'
    east_line_xlist     x coordinates of east mesh lines for faster plotting
    east_line_ylist     y coordinates of east mesh lines for faster plotting
    fdict               dictionary of font properties
    fig                 matplotlib.figure instance
    fig_dpi             resolution of figure in dots-per-inch
    fig_num             number of figure instance
    fig_size            size of figure in inches
    font_size           size of font in points
    grid_east           location of east nodes in relative coordinates
    grid_north          location of north nodes in relative coordinates
    grid_z              location of vertical nodes in relative coordinates
    m_height            median height of horizontal cells (map units)
    m_width             median width of horizontal cells (map units)
    map_scale           [ 'm' | 'km' ] scale of map
    mesh_east           np.meshgrid of east, north (east component)
    mesh_north          np.meshgrid of east, north (north component)
    mesh_plot           matplotlib.axes.pcolormesh instance
    model_fn            full path to model file
    model_fn_basename   file name used when writing the edited model
    nodes_east          spacing between east nodes
    nodes_north         spacing between north nodes
    nodes_z             spacing between vertical nodes
    north_line_xlist    x coordinates of north mesh lines for faster plotting
    north_line_ylist    y coordinates of north mesh lines for faster plotting
    plot_yn             [ 'y' | 'n' ] plot on instantiation
    rect_selector       matplotlib.widgets.RectangleSelector instance
    res_copy            copy of res_model as read, used by undo
    res_dict            dictionary mapping resistivity value to 1-based index
    res_list            list of resistivity values (linear scale) that sets
                        the colorbar and slider range
    res_model           np.ndarray(n_north, n_east, n_z) of resistivity
                        values (linear scale)
    res_value           resistivity painted by a rectangle selection
    save_path           directory the edited model is written to
    slider_res          matplotlib.widgets.Slider controlling res_value (log10)
    station_east        station locations in east direction
    station_north       station locations in north direction
    xchange, ychange,   east, north and depth indices of the last
    zchange             rectangle edit (used by undo)
    xlimits             limits of plot in e-w direction
    ylimits             limits of plot in n-s direction
    =================== =======================================================
    """

    def __init__(self, model_fn=None, data_fn=None, **kwargs):
        # be sure to initialize Model
        Model.__init__(self, model_fn=model_fn, **kwargs)

        self.data_fn = data_fn
        self.model_fn_basename = kwargs.pop('model_fn_basename',
                                            'ModEM_Model_rw.ws')

        # default output directory: next to the model file, else next to
        # the data file, else the current working directory
        if self.model_fn is not None:
            self.save_path = os.path.dirname(self.model_fn)
        elif self.data_fn is not None:
            self.save_path = os.path.dirname(self.data_fn)
        else:
            self.save_path = os.getcwd()

        # station locations in relative coordinates read from data file
        self.station_east = None
        self.station_north = None

        # --> set map scale
        self.map_scale = kwargs.pop('map_scale', 'km')

        self.m_width = 100
        self.m_height = 100

        # --> scale the map coordinates (model coordinates are divided by
        #     dscale before plotting)
        if self.map_scale == 'km':
            self.dscale = 1000.
        elif self.map_scale == 'm':
            self.dscale = 1.
        else:
            # fail here rather than with an AttributeError on self.dscale
            # the first time something is plotted
            raise ValueError("map_scale must be 'km' or 'm', "
                             "got {!r}".format(self.map_scale))

        # figure attributes
        self.fig = None
        self.ax1 = None
        self.ax2 = None
        self.cb = None
        self.east_line_xlist = None
        self.east_line_ylist = None
        self.north_line_xlist = None
        self.north_line_ylist = None

        # last rectangle edit (east indices, north indices, depth index);
        # the 'u' key restores exactly these cells
        self.xchange = None
        self.ychange = None
        self.zchange = None
        # copy of the model as read, filled by get_model (or plot)
        self.res_copy = None

        # make a default resistivity list to change values
        self._res_sea = 0.3
        self._res_air = 1E12
        self.res_dict = None
        res_list = kwargs.pop('res_list', None)
        if res_list is None:
            res_list = np.array([self._res_sea, 1, 10, 50, 100, 500, 1000,
                                 5000], dtype=float)
        # always go through set_res_list so res_dict is also built for a
        # user-supplied list
        self.set_res_list(res_list)

        # set initial resistivity value
        self.res_value = self.res_list[0]

        self.cov_arr = None

        # --> set map limits
        self.xlimits = kwargs.pop('xlimits', None)
        self.ylimits = kwargs.pop('ylimits', None)

        self.font_size = kwargs.pop('font_size', 7)
        self.fig_dpi = kwargs.pop('fig_dpi', 300)
        self.fig_num = kwargs.pop('fig_num', 1)
        self.fig_size = kwargs.pop('fig_size', [6, 6])
        self.cmap = kwargs.pop('cmap', cm.jet_r)
        self.depth_index = kwargs.pop('depth_index', 0)

        self.fdict = {'size': self.font_size + 2, 'weight': 'bold'}

        self.subplot_wspace = kwargs.pop('subplot_wspace', .3)
        self.subplot_hspace = kwargs.pop('subplot_hspace', .0)
        self.subplot_right = kwargs.pop('subplot_right', .8)
        self.subplot_left = kwargs.pop('subplot_left', .01)
        self.subplot_top = kwargs.pop('subplot_top', .93)
        self.subplot_bottom = kwargs.pop('subplot_bottom', .1)

        # plot on initialization
        self.plot_yn = kwargs.pop('plot_yn', 'y')
        if self.plot_yn == 'y':
            self.get_model()
            self.plot()

    def set_res_list(self, res_list):
        """
        set res_list and rebuild res_dict to correspond.

        res_dict maps each resistivity value to its 1-based position in
        res_list.  If a figure is already open it is closed and replotted
        so the colorbar and slider range follow the new list.
        """
        self.res_list = res_list
        # make a dictionary of values to write to file.
        self.res_dict = {res: ii for ii, res in enumerate(self.res_list, 1)}
        if self.fig is not None:
            # close our own figure, not whichever one happens to be current
            plt.close(self.fig)
            self.plot()

    # ---read files-------------------------------------------------------------
    def get_model(self):
        """
        read in the model file (and the data file if data_fn is set).

        Sets res_model and the grid attributes through read_model_file,
        cov_arr (ones shaped like res_model), station_east/station_north
        (from the data file), m_height/m_width (median cell sizes in map
        units) and res_copy (used by undo).
        """
        # --> read in model file
        self.read_model_file()

        self.cov_arr = np.ones_like(self.res_model)

        # --> read in data file if given
        if self.data_fn is not None:
            md_data = Data()
            md_data.read_data_file(self.data_fn)

            # get station locations
            self.station_east = md_data.station_locations.rel_east
            self.station_north = md_data.station_locations.rel_north

        # get cell block sizes, ignoring the 5 padding cells on each side
        self.m_height = np.median(self.nodes_north[5:-5]) / self.dscale
        self.m_width = np.median(self.nodes_east[5:-5]) / self.dscale

        # make a copy of original in case there are unwanted changes
        self.res_copy = self.res_model.copy()

    # ---plot model-------------------------------------------------------------
    def plot(self):
        """
        plot the current depth slice with the interactive controls:

            - a colorbar (log10 resistivity)
            - a slider choosing the resistivity to paint (log10 scale)
            - a rectangle selector that paints the selected cells
            - key bindings listed in the class docstring
        """
        # set plot properties (these change matplotlib's global rcParams)
        plt.rcParams['font.size'] = self.font_size
        plt.rcParams['figure.subplot.left'] = self.subplot_left
        plt.rcParams['figure.subplot.right'] = self.subplot_right
        plt.rcParams['figure.subplot.bottom'] = self.subplot_bottom
        plt.rcParams['figure.subplot.top'] = self.subplot_top

        # make sure there is a model to plot
        if self.res_model is None:
            self.get_model()
        if self.res_copy is None:
            # model was loaded some other way; keep a copy for undo
            self.res_copy = self.res_model.copy()

        self.cmin = np.floor(np.log10(min(self.res_list)))
        self.cmax = np.ceil(np.log10(max(self.res_list)))

        # grid nodes are cell edges, one more than cells in each direction,
        # which is what pcolormesh expects for the mesh coordinates
        plot_east = self.grid_east / self.dscale
        plot_north = self.grid_north / self.dscale

        # make a mesh grid for plotting
        # the 'ij' makes sure the resulting grid is in east, north
        self.mesh_east, self.mesh_north = np.meshgrid(plot_east,
                                                      plot_north,
                                                      indexing='ij')

        self.fig = plt.figure(self.fig_num, self.fig_size, dpi=self.fig_dpi)
        plt.clf()
        self.ax1 = self.fig.add_subplot(1, 1, 1, aspect='equal')

        # transpose to make x--east and y--north
        plot_res = np.log10(self.res_model[:, :, self.depth_index].T)

        self.mesh_plot = self.ax1.pcolormesh(self.mesh_east,
                                             self.mesh_north,
                                             plot_res,
                                             cmap=self.cmap,
                                             vmin=self.cmin,
                                             vmax=self.cmax)

        # on plus or minus change depth slice
        self.cid_depth = self.mesh_plot.figure.canvas.mpl_connect(
            'key_press_event', self._on_key_callback)

        # plot the stations
        self._plot_stations()

        # set axis properties
        if self.xlimits is not None:
            self.ax1.set_xlim(self.xlimits)
        else:
            self.ax1.set_xlim(left=self.grid_east.min() / self.dscale,
                              right=self.grid_east.max() / self.dscale)

        if self.ylimits is not None:
            self.ax1.set_ylim(self.ylimits)
        else:
            self.ax1.set_ylim(bottom=self.grid_north.min() / self.dscale,
                              top=self.grid_north.max() / self.dscale)

        self._set_labels()

        # plot the grid; the line lists are cached for redraw_plot
        self._build_grid_lines()
        self._plot_grid_lines()

        # plot the colorbar
        self.ax2 = self.fig.add_axes([.81, .45, .16, .03])
        self.ax2.xaxis.set_ticks_position('top')
        self.cb = mcb.ColorbarBase(self.ax2,
                                   cmap=self.cmap,
                                   norm=colors.Normalize(vmin=self.cmin,
                                                         vmax=self.cmax),
                                   orientation='horizontal')

        self.cb.set_label(r'Resistivity ($\Omega \cdot$m)',
                          fontdict={'size': self.font_size})
        self.cb.set_ticks(np.arange(self.cmin, self.cmax + 1))
        self.cb.set_ticklabels([mtplottools.labeldict[cc]
                                for cc in np.arange(self.cmin,
                                                    self.cmax + 1)])

        # resistivity slider (log10 scale).  Start it at the value that will
        # actually be painted so the display and res_value agree; if
        # res_value lies outside the slider range, clip both.
        slider_ax = self.fig.add_axes([.81, .5, .16, .03])
        log_res = float(np.log10(self.res_value))
        if not self.cmin <= log_res <= self.cmax:
            log_res = float(np.clip(log_res, self.cmin, self.cmax))
            self.res_value = 10 ** log_res
        self.slider_res = widgets.Slider(slider_ax, 'Resistivity',
                                         self.cmin, self.cmax,
                                         valinit=log_res)
        # connect before show(): with a non-interactive backend show()
        # blocks until the window is closed, so connecting afterwards would
        # leave the slider without effect while the editor is open
        self.slider_res.on_changed(self.set_res_value)

        # make a rectangular selector (a box is the default shape; the
        # drawtype argument was removed from matplotlib)
        self.rect_selector = widgets.RectangleSelector(self.ax1,
                                                       self.rect_onselect,
                                                       useblit=True)

        plt.show()

    def _plot_stations(self):
        """mark each station read from the data file with a '*' on ax1."""
        if self.station_east is None:
            return
        for ee, nn in zip(self.station_east, self.station_north):
            self.ax1.text(ee / self.dscale, nn / self.dscale,
                          '*',
                          verticalalignment='center',
                          horizontalalignment='center',
                          fontdict={'size': self.font_size - 2,
                                    'weight': 'bold'})

    def _set_labels(self):
        """set the axis labels and the depth title of ax1."""
        self.ax1.set_ylabel('Northing (' + self.map_scale + ')',
                            fontdict=self.fdict)
        self.ax1.set_xlabel('Easting (' + self.map_scale + ')',
                            fontdict=self.fdict)

        depth_title = self.grid_z[self.depth_index] / self.dscale
        self.ax1.set_title('Depth = {:.3f} '.format(depth_title) +
                           '(' + self.map_scale + ')',
                           fontdict=self.fdict)

    def _build_grid_lines(self):
        """
        build None-separated coordinate lists for the mesh lines.

        All east (and all north) mesh lines can then be drawn with one
        plot call each.
        """
        north_min = self.grid_north.min() / self.dscale
        north_max = self.grid_north.max() / self.dscale
        east_min = self.grid_east.min() / self.dscale
        east_max = self.grid_east.max() / self.dscale

        self.east_line_xlist = []
        self.east_line_ylist = []
        for xx in self.grid_east:
            self.east_line_xlist.extend([xx / self.dscale, xx / self.dscale,
                                         None])
            self.east_line_ylist.extend([north_min, north_max, None])

        self.north_line_xlist = []
        self.north_line_ylist = []
        for yy in self.grid_north:
            self.north_line_xlist.extend([east_min, east_max, None])
            self.north_line_ylist.extend([yy / self.dscale, yy / self.dscale,
                                          None])

    def _plot_grid_lines(self):
        """draw the cached mesh lines on ax1."""
        self.ax1.plot(self.east_line_xlist, self.east_line_ylist,
                      lw=.25, color='k')
        self.ax1.plot(self.north_line_xlist, self.north_line_ylist,
                      lw=.25, color='k')

    def redraw_plot(self):
        """
        redraw the current depth slice on ax1, keeping the current zoom
        unless xlimits/ylimits are set.
        """
        current_xlimits = self.ax1.get_xlim()
        current_ylimits = self.ax1.get_ylim()
        self.ax1.cla()

        plot_res = np.log10(self.res_model[:, :, self.depth_index].T)

        self.mesh_plot = self.ax1.pcolormesh(self.mesh_east,
                                             self.mesh_north,
                                             plot_res,
                                             cmap=self.cmap,
                                             vmin=self.cmin,
                                             vmax=self.cmax)

        # plot the stations
        self._plot_stations()

        # set axis properties
        if self.xlimits is not None:
            self.ax1.set_xlim(self.xlimits)
        else:
            self.ax1.set_xlim(current_xlimits)

        if self.ylimits is not None:
            self.ax1.set_ylim(self.ylimits)
        else:
            self.ax1.set_ylim(current_ylimits)

        self._set_labels()

        # plot finite element mesh
        self._plot_grid_lines()

        # be sure to redraw the canvas
        self.fig.canvas.draw()

    def set_res_value(self, val):
        """slider callback: val is log10(resistivity)."""
        self.res_value = 10 ** val
        print('set resistivity to ', self.res_value)

    def _on_key_callback(self, event):
        """
        key press handler; see the Buttons table in the class docstring.
        """
        self.event_change_depth = event
        key = event.key
        # index of the deepest layer that exists in res_model (grid_z holds
        # layer edges, so it can be one longer than the number of layers)
        max_index = self.res_model.shape[2] - 1

        # go down a layer on push of +/= keys
        if key == '=':
            self.depth_index += 1

            if self.depth_index > max_index:
                self.depth_index = max_index
                print('already at deepest depth')

            print('Plotting Depth {0:.3f}'.format(self.grid_z[self.depth_index] /
                                                  self.dscale) +
                  '(' + self.map_scale + ')')

            self.redraw_plot()

        # go up a layer on push of - key
        elif key == '-':
            self.depth_index -= 1

            if self.depth_index < 0:
                self.depth_index = 0

            print('Plotting Depth {0:.3f} '.format(self.grid_z[self.depth_index] /
                                                   self.dscale) +
                  '(' + self.map_scale + ')')

            self.redraw_plot()

        # exit plot on press of q
        elif key == 'q':
            event.canvas.mpl_disconnect(self.cid_depth)
            plt.close(event.canvas.figure)
            self.rewrite_model_file()

        # copy the layer above
        elif key == 'a':
            if self.depth_index == 0:
                print('No layers above')
            else:
                self.res_model[:, :, self.depth_index] = \
                    self.res_model[:, :, self.depth_index - 1]

            self.redraw_plot()

        # copy the layer below
        elif key == 'b':
            if self.depth_index >= max_index:
                print('No more layers below')
            else:
                self.res_model[:, :, self.depth_index] = \
                    self.res_model[:, :, self.depth_index + 1]

            self.redraw_plot()

        # undo the last rectangle edit by restoring the values as read
        elif key == 'u':
            if (self.xchange is None or self.ychange is None or
                    self.res_copy is None):
                print('Nothing to undo')
                return
            # restore the slice that was edited, not the one on screen
            zz = self.zchange if self.zchange is not None \
                else self.depth_index
            rows, cols = np.ix_(np.atleast_1d(self.ychange),
                                np.atleast_1d(self.xchange))
            self.res_model[rows, cols, zz] = self.res_copy[rows, cols, zz]
            self.redraw_plot()

    def change_model_res(self, xchange, ychange):
        """
        set the resistivity of the selected cells to res_value.

        :param xchange: east cell index or sequence of east indices
        :param ychange: north cell index or sequence of north indices

        Every (north, east) combination at the current depth slice is
        changed.  The change is recorded so the 'u' key can undo it.
        """
        rows, cols = np.ix_(np.atleast_1d(ychange), np.atleast_1d(xchange))
        self.res_model[rows, cols, self.depth_index] = self.res_value

        # remember what was changed (and where) for undo
        self.xchange = xchange
        self.ychange = ychange
        self.zchange = self.depth_index

        self.redraw_plot()

    def rect_onselect(self, eclick, erelease):
        """
        RectangleSelector callback: paint the cells under the rectangle
        with the current resistivity value.
        """
        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata
        # an event outside the data area carries no data coordinates
        if any(vv is None for vv in (x1, y1, x2, y2)):
            return

        self.xchange = self._get_east_index(x1, x2)
        self.ychange = self._get_north_index(y1, y2)

        # reset values of resistivity
        self.change_model_res(self.xchange, self.ychange)

    def _get_cell_index(self, nodes, n_cells, v1, v2):
        """
        return the indices of the cells touched by the interval [v1, v2].

        :param nodes: node locations along one axis in model units (divided
            by self.dscale here); assumed sorted ascending, with cell i lying
            between nodes[i] and nodes[i + 1]
        :param n_cells: number of cells along that axis in res_model
        :param v1, v2: interval ends in map units, in either order
        :returns: np.ndarray of consecutive cell indices, clipped to
            [0, n_cells - 1] and never empty
        """
        edges = np.asarray(nodes, dtype=float) / self.dscale
        lo, hi = sorted((v1, v2))
        # cell containing lo (a point exactly on a node belongs to the cell
        # to its right) and cell containing hi (a point exactly on a node
        # belongs to the cell to its left)
        first = int(np.searchsorted(edges, lo, side='right')) - 1
        last = int(np.searchsorted(edges, hi, side='left')) - 1
        # clip so selections reaching past the mesh never wrap around to
        # index -1 or run off the end of res_model
        first = min(max(first, 0), n_cells - 1)
        last = min(max(last, first), n_cells - 1)
        return np.arange(first, last + 1)

    def _get_east_index(self, x1, x2):
        """east cell indices touched by the east interval [x1, x2]."""
        return self._get_cell_index(self.grid_east, self.res_model.shape[1],
                                    x1, x2)

    def _get_north_index(self, y1, y2):
        """north cell indices touched by the north interval [y1, y2]."""
        return self._get_cell_index(self.grid_north, self.res_model.shape[0],
                                    y1, y2)

    def rewrite_model_file(self, model_fn=None, save_path=None,
                           model_fn_basename=None):
        """
        write the (possibly edited) model with write_model_file.

        :param model_fn: assigned to self.model_fn before writing (also
            when None)
        :param save_path: if given, replaces self.save_path
        :param model_fn_basename: if given, replaces self.model_fn_basename

        NOTE(review): how write_model_file combines model_fn, save_path and
        model_fn_basename is defined in Model, not here -- confirm there.
        """
        if save_path is not None:
            self.save_path = save_path

        self.model_fn = model_fn

        if model_fn_basename is not None:
            self.model_fn_basename = model_fn_basename

        self.write_model_file()