Source code for parmed.amber.mask

"""
Module for evaluating Amber Mask strings and translating them into lists in
which a selected atom is 1 and one that's not is 0.
"""
from __future__ import division, print_function, absolute_import

from ..exceptions import MaskError
from ..periodic_table import AtomicNum
from ..utils.six.moves import range

#+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

class AmberMask(object):
    """
    What is hopefully a fully-fledged Amber mask parser implemented in Python.

    The mask string is processed in three stages: it is tokenized into an
    infix expression (``_tokenize``), converted to postfix/RPN notation
    (``_torpn``) and finally evaluated against the atoms of ``parm``
    (``_evaluate``).

    Parameters
    ----------
    parm : Structure
        The topology structure for which to select atoms
    mask : str
        The mask string that selects a subset of atoms
    """

    #======================================================

    def __init__(self, parm, mask):
        self.parm = parm
        # Leading/trailing whitespace never carries meaning in a mask
        self.mask = mask.strip()

    #======================================================

    def __str__(self):
        return self.mask

    #======================================================

    def Selected(self, invert=False):
        """
        Generator that returns the indexes of selected atoms

        Parameters
        ----------
        invert : bool, optional
            If True, all atoms *not* selected by the mask will be returned

        Returns
        -------
        generator of int
            Each iteration will yield the index of the next atom that has been
            selected by the mask. Atom indices are 0-based
        """
        for i, v in enumerate(self.Selection(invert=invert)):
            if v:
                yield i

    #======================================================

    def Selection(self, prnlev=0, invert=False):
        """
        Parses the mask and analyzes the result to return an atom selection
        array

        Parameters
        ----------
        prnlev : int, optional
            Print debug information on the processing of the Amber mask string.
            This is mainly useful if you are modifying the mask parser. Default
            value is 0 (no printout), values between 1 and 8 control the level
            of output (larger values produce more output). Default 0
        invert : bool, optional
            If True, the returned array will invert the selection of the mask
            (i.e., selected atoms will not be selected and vice-versa)

        Returns
        -------
        mask : list of int
            A list with length equal to the number of atoms in the assigned
            :class:`Structure <parmed.structure.Structure>` instance. Selected
            atoms will have a value of 1 placed in the corresponding slot in
            the return list while atoms not selected will be assigned 0.

        Raises
        ------
        MaskError
            If the mask string cannot be parsed or evaluated
        """
        from sys import stderr, stdout
        if prnlev > 2:
            stderr.write('In AmberMask.Selection(), debug active!\n')
        if prnlev > 5:
            stdout.write('original mask: ==%s==\n' % self.mask)

        # 0) See if we got the default "all" mask(*) and return accordingly.
        #    The shortcut has to honor ``invert`` as well: the inverse of
        #    "everything" is "nothing".
        if self.mask.strip() == '*':
            fill = 0 if invert else 1
            return [fill for atom in self.parm.atoms]

        # 1) preprocess input expression
        infix = self._tokenize(prnlev)
        if prnlev > 5:
            stdout.write('tokenized mask: ==%s==\n' % infix)

        # 2) construct postfix (RPN) notation
        postfix = self._torpn(infix, prnlev)
        if prnlev > 5:
            stdout.write('postfix mask: ==%s==\n' % postfix)

        # 3) evaluate the postfix notation
        selection = self._evaluate(postfix, prnlev)
        if invert:
            return [1-i for i in selection]
        return selection

    #======================================================

    def _tokenize(self, prnlev):
        """
        Tokenizes the mask string into individual selections:

        1. remove spaces
        2. isolate 'operands' into brackets [...]
        3. split expressions of the type :1-10@CA,CB into 2 parts; the 2 parts
           are joined with & operator and (for the sake of preserving
           precedence of other operators) enclosed by (...); i.e.
           :1-10@CA,CB is split into (:1-10 & @CA,CB)
        4. do basic error checking

        Parameters
        ----------
        prnlev : int
            Debug print level (not used by this method)

        Returns
        -------
        str
            The infix expression, terminated by a newline character that
            ``_torpn`` uses as its end-of-input marker

        Raises
        ------
        MaskError
            On malformed syntax, or if a distance operator is used while
            ``self.parm.coordinates`` is None
        """
        buffer = '' # keeping track of a single operand
        infix = ''  # value that is returned at the end
        # flag == 0: means new operand or operand was completed & ended with ]
        # flag == 1: means operand with ":" read
        # flag == 2: means operand with "@" read
        # flag == 3: means '<' or '>' read, waiting for numbers
        flag = 0
        i = 0
        while i < len(self.mask):
            p = self.mask[i]
            # skip whitespace
            if p.isspace():
                i += 1
                continue
            # If p is an operator, is the last character, or is a ()...
            elif (self._isOperator(p) or i == len(self.mask) - 1 or
                  p in ['(', ')']):
                # Deal with the last character being a wildcard that we have to
                # convert
                if p == '=' and i == len(self.mask) - 1: # wildcard
                    if flag > 0:
                        p = '*'
                    else:
                        raise MaskError("AmberMask: '=' not in name list syntax")
                # If this is the end of an operand, terminate the buffer, flush
                # it to infix, and reset flag to 0 and empty the buffer
                if flag > 0:
                    if i == len(self.mask) - 1 and p != ')':
                        buffer += p
                    buffer += '])'
                    flag = 0
                    infix += buffer
                    buffer = ''
                if i != len(self.mask) - 1 or p == ')':
                    infix += p
                # else if p is >,<
                if p in ['<', '>']:
                    buffer = '([%s' % p
                    # The character right after the comparison must be : or @,
                    # so consume it here
                    i += 1
                    try:
                        p = self.mask[i]
                    except IndexError:
                        raise MaskError('Bad distance syntax [%s]' % self.mask)
                    buffer += p
                    flag = 3
                    if self.parm.coordinates is None:
                        raise MaskError('<,> operators require coordinates')
                    if p not in [':', '@']:
                        raise MaskError('Bad syntax [%s]' % self.mask)
            elif self._isOperand(p):
                if flag == 0:
                    buffer = '(['
                    flag = 1
                    # The only operand allowed to start without : or @ is *
                    if p != '*':
                        raise MaskError('Bad syntax [%s]' % self.mask)
                if p == '=': # wildcard
                    if flag > 0:
                        p = '*'
                    else:
                        raise MaskError("'=' not in name list syntax")
                buffer += p
            elif p == ':':
                if flag == 0:
                    buffer = '([:'
                    flag = 1
                else:
                    buffer += '])|([:'
                    flag = 1
            elif p == '@':
                if flag == 0:
                    buffer = '([@'
                    flag = 2
                elif flag == 1:
                    buffer += ']&[@'
                    flag = 2
                elif flag == 2:
                    buffer += '])|([@'
                    flag = 2
            else:
                raise MaskError('Unknown symbol (%s) expression' % p)
            i += 1
        # end while i < len(self.mask):

        # Check that each operand has at least 4 characters: [:1] and [@C], etc.
        i = 0
        n = 1 # number of characters in current operand
        flag = 0
        while i < len(infix):
            p = infix[i]
            if p == '[':
                n += 1
                flag = 1
            elif p == ']':
                if n < 4 and infix[i-1] != '*':
                    raise MaskError('empty token in infix')
                n = 1
            else:
                if flag == 1:
                    n += 1
            i += 1

        return infix + '\n' # terminating \n for next step

    #======================================================

    def _isOperator(self, char):
        """ Determines if a character is an operator """
        return len(char) == 1 and char in '!&|<>'

    #======================================================

    def _isOperand(self, char):
        """ Determines if a character is an operand """
        return len(char) == 1 and (char in "\\*/%-?,'.=+_" or char.isalnum())

    #======================================================

    def _torpn(self, infix, prnlev):
        """
        Converts the infix to an RPN array

        Parameters
        ----------
        infix : str
            Newline-terminated infix expression as returned by ``_tokenize``
        prnlev : int
            Debug print level (not used by this method)

        Returns
        -------
        str
            The expression in postfix notation; operands stay wrapped in [...]

        Raises
        ------
        MaskError
            If parentheses are unbalanced or an unknown symbol is encountered
        """
        postfix = ''
        stack = ['\n'] # use a list as a stack. Then pop() works as expected
        flag = 0       # nonzero while copying the inside of a [...] operand
        i = 0
        while i < len(infix):
            p = infix[i]
            if p == '[':
                postfix += p
                flag = 1
            elif p == ']':
                postfix += p
                flag = 0
            elif flag:
                postfix += p
            elif p == '(':
                stack.append(p)
            elif p == ')':
                pp = stack.pop()
                while pp != '(':
                    if pp == '\n':
                        raise MaskError('Unbalanced parentheses in Mask.')
                    postfix += pp
                    pp = stack.pop()
                # At this point both ()s are discarded
            elif p == '\n':
                pp = stack.pop()
                while pp != '\n':
                    if pp == '(':
                        raise MaskError('Unbalanced parentheses in Mask.')
                    postfix += pp
                    pp = stack.pop()
            elif self._isOperator(p):
                P1 = self._priority(p)
                P2 = self._priority(stack[len(stack)-1])
                if P1 > P2:
                    stack.append(p)
                else:
                    # Flush every stacked operator that binds at least as
                    # tightly as p before pushing p itself
                    while P1 <= P2:
                        pp = stack.pop()
                        postfix += pp
                        P1 = self._priority(p)
                        P2 = self._priority(stack[len(stack)-1])
                    stack.append(p)
            else:
                raise MaskError('Unknown symbol %s' % p) # should not reach here
            i += 1
        # end while i < len(infix):
        return postfix

    #======================================================

    def _evaluate(self, postfix, prnlev):
        """
        Evaluates a postfix in RPN format and returns a selection array

        Parameters
        ----------
        postfix : str
            Postfix expression as returned by ``_torpn``
        prnlev : int
            Debug print level; values above 7 write the number of selected
            atoms to stderr

        Returns
        -------
        The single operand left on the evaluation stack once the whole
        expression has been processed

        Raises
        ------
        MaskError
            If an operator is missing operands, an unknown symbol is found, or
            operands are left over at the end
        """
        from sys import stderr
        buffer = ''
        stack = []
        pos = 0 # position in postfix

        while pos < len(postfix):
            p = postfix[pos]
            if p == '[':
                buffer = ''
            elif p == ']': # end of the token
                ptoken = buffer
                pmask = self._selectElemMask(ptoken)
                stack.append(pmask)
            elif self._isOperand(p) or p in [':', '@']:
                buffer += p
            elif p in ['&', '|']:
                pmask1 = None
                pmask2 = None
                try:
                    pmask1 = stack.pop()
                    pmask2 = stack.pop()
                    pmask = self._binop(p, pmask1, pmask2)
                except IndexError:
                    raise MaskError('Illegal binary operation')
                stack.append(pmask)
            elif p in ['<', '>']:
                if pos < len(postfix)-1 and postfix[pos+1] in [':', '@']:
                    # Part of a distance token such as [<:3.0], not an operator
                    buffer += p
                else:
                    try:
                        pmask1 = stack.pop() # distance criteria
                        pmask2 = stack.pop()
                        pmask = self._selectDistd(pmask1, pmask2)
                    except IndexError:
                        # A malformed distance expression is an error, like
                        # every other malformed expression; do not silently
                        # report an empty selection
                        raise MaskError('Illegal distance operation')
                    stack.append(pmask)
            elif p == '!':
                try:
                    pmask1 = stack.pop()
                except IndexError:
                    raise MaskError('Illegal ! operation')
                pmask = self._neg(pmask1)
                stack.append(pmask)
            else:
                raise MaskError('Unknown symbol evaluating RPN: %s' % p)
            pos += 1
        # end while i < len(postfix)

        try:
            pmask = stack.pop()
        except IndexError:
            raise MaskError('Empty stack -- no available operands')

        if stack:
            raise MaskError('There may be missing operands in the mask')

        if prnlev > 7:
            stderr.write('%d atoms selected by %s' % (sum(pmask), self.mask))

        return pmask

    #======================================================

    def _neg(self, pmask1):
        """ Negates a given mask """
        return pmask1.Not()

    #======================================================

    def _selectDistd(self, pmask1, pmask2):
        """
        Selects atoms based on a distance criteria

        Parameters
        ----------
        pmask1 : str
            The distance criteria: a comparison character (< or >) followed by
            : (select whole residues) or @ (select atoms) and the distance
        pmask2 : sequence of int
            The selection of atoms from which the distance is evaluated

        Returns
        -------
        _mask
            Mask with 1 for every atom satisfying the distance criteria with
            respect to at least one atom selected in ``pmask2``

        Raises
        ------
        MaskError
            If the criteria string is malformed or the distance is not a number
        """
        # pmask1 is either @<number> or :<number>, and represents the distance
        # criteria. pmask2 is the selection of atoms from which the distance is
        # evaluated.
        atoms = self.parm.atoms
        pmask = _mask(len(atoms))
        # Determine if we want > or <
        if pmask1[0] == '<':
            compare = lambda x, y: x < y
        elif pmask1[0] == '>':
            compare = lambda x, y: x > y
        else:
            # Should never execute this
            raise MaskError('Unknown comparison criteria for distance mask: %s'
                            % pmask1[0])
        pmask1 = pmask1[1:]
        if pmask1[0] not in ':@':
            # Should never execute this
            raise MaskError('Bad distance criteria for mask: %s' % pmask1)
        try:
            distance = float(pmask1[1:])
        except (TypeError, ValueError):
            raise MaskError('Distance must be a number: %s' % pmask1[1:])
        distance *= distance # Faster to compare square of distance
        # First select all atoms that satisfy the distance. If we ended up
        # choosing residues, then we will go back through afterwards and select
        # entire residues when one of the atoms in that residue is selected.
        idxlist = [i for i, val in enumerate(pmask2) if val == 1]
        for i, atomi in enumerate(atoms):
            for j in idxlist:
                atomj = atoms[j]
                dx = atomi.xx - atomj.xx
                dy = atomi.xy - atomj.xy
                dz = atomi.xz - atomj.xz
                d2 = dx*dx + dy*dy + dz*dz
                if compare(d2, distance):
                    pmask[i] = 1
                    break
        # Now see if we have to select all atoms in residues with any selected
        # atoms
        if pmask1[0] == ':':
            for res in self.parm.residues:
                if any(pmask[atom.idx] == 1 for atom in res.atoms):
                    for atom in res.atoms:
                        pmask[atom.idx] = 1
        return pmask

    #======================================================

    def _selectElemMask(self, ptoken):
        """
        Selects an element mask

        Parameters
        ----------
        ptoken : str
            A single operand without its enclosing brackets: a residue list
            starting with ':', an atom list starting with '@', a lone '*', or a
            distance criteria starting with '<' or '>'

        Returns
        -------
        _mask or str
            The atom selection for the operand. Distance criteria are returned
            unchanged as a string for later use by ``_selectDistd``

        Raises
        ------
        MaskError
            If the token is none of the recognized forms
        """
        # some constants
        ALL = 0
        NUMLIST = 1
        NAMELIST = 2
        TYPELIST = 3
        ELEMLIST = 4

        # define the mask object and empty buffer
        pmask = _mask(len(self.parm.atoms))
        buffer = ''
        buffer_p = 0

        # This is a residue NUMber LIST
        if ptoken.startswith(':'):
            reslist = NUMLIST
            pos = 1
            while pos < len(ptoken):
                p = ptoken[pos]
                buffer += p
                buffer_p += 1
                if p == '*' and ptoken[pos-1] != '\\':
                    if buffer_p == 1 and (pos == len(ptoken) - 1 or
                                          ptoken[pos+1] == ','):
                        reslist = ALL
                    elif reslist == NUMLIST:
                        reslist = NAMELIST
                elif p.isalpha() or p in '_?*':
                    reslist = NAMELIST
                # buffer_p == 0 marks "end of token reached": only then is the
                # accumulated buffer dispatched
                if pos == len(ptoken) - 1:
                    buffer_p = 0
                if len(buffer) != 0 and buffer_p == 0:
                    if reslist == ALL:
                        pmask.select_all()
                    elif reslist == NUMLIST:
                        self._residue_numlist(buffer, pmask)
                    elif reslist == NAMELIST:
                        self._residue_namelist(buffer, pmask)
                    reslist = NUMLIST
                pos += 1
        elif ptoken.startswith('@'):
            atomlist = NUMLIST
            pos = 1
            while pos < len(ptoken):
                p = ptoken[pos]
                buffer += p
                buffer_p += 1
                if p == '*' and ptoken[pos-1] != "\\":
                    if atomlist == NUMLIST:
                        atomlist = NAMELIST
                elif p.isalpha() or p in '?*_':
                    if atomlist == NUMLIST:
                        atomlist = NAMELIST
                elif p == '%':
                    atomlist = TYPELIST
                elif p == '/':
                    atomlist = ELEMLIST
                if pos == len(ptoken) - 1:
                    buffer_p = 0
                if len(buffer) != 0 and buffer_p == 0:
                    if atomlist == ALL:
                        pmask.select_all()
                    elif atomlist == NUMLIST:
                        self._atom_numlist(buffer, pmask)
                    elif atomlist == NAMELIST:
                        self._atom_namelist(buffer, pmask)
                    elif atomlist == TYPELIST:
                        # buffer[1:] drops the leading % marker
                        self._atom_typelist(buffer[1:], pmask)
                    elif atomlist == ELEMLIST:
                        # buffer[1:] drops the leading / marker
                        self._atom_elemlist(buffer[1:], pmask)
                pos += 1
        elif ptoken.strip() == '*':
            pmask.select_all()
        elif ptoken[0] in ['<', '>']:
            return ptoken
        else:
            # Should never reach here
            raise MaskError('Mask is missing : and @')
        # end if ':' in ptoken:

        return pmask

    #======================================================

    def _atom_numlist(self, instring, mask):
        """
        Fills a _mask based on atom numbers

        Parameters
        ----------
        instring : str
            Comma-separated list of 1-based atom numbers and ranges (e.g.
            1,4-7)
        mask : _mask
            The mask to modify in place

        Raises
        ------
        MaskError
            If a character other than a digit, comma or dash is found
        """
        buffer = ''
        pos = 0
        at1 = at2 = dash = 0
        while pos < len(instring):
            p = instring[pos]
            if p.isdigit():
                buffer += p
            if p == ',' or pos == len(instring) - 1:
                if dash == 0:
                    at1 = int(buffer)
                    self._atnum_select(at1, at1, mask)
                else:
                    at2 = int(buffer)
                    self._atnum_select(at1, at2, mask)
                    dash = 0
                buffer = ''
            elif p == '-':
                at1 = int(buffer)
                dash = 1
                buffer = ''
            if not (p.isdigit() or p in [',', '-']):
                raise MaskError('Unknown symbol in atom number parsing [%s]'%p)
            pos += 1

    #======================================================

    def _atom_namelist(self, instring, mask, key='name'):
        """
        Fills a _mask based on atom names/types

        Parameters
        ----------
        instring : str
            Comma-separated list of names; entries that start with a digit and
            contain a dash are treated as atom number ranges
        mask : _mask
            The mask to modify in place
        key : str, optional
            The atom attribute to match against. Default is 'name'

        Raises
        ------
        MaskError
            If an unrecognized character is found
        """
        buffer = ''
        pos = 0
        while pos < len(instring):
            p = instring[pos]
            if p.isalnum() or p in "\\*?+'-_":
                buffer += p
            if p == ',' or pos == len(instring) - 1:
                if '-' in buffer and buffer[0].isdigit():
                    self._atom_numlist(buffer, mask)
                else:
                    self._atname_select(buffer, mask, key)
                buffer = ''
            if not (p.isalnum() or p in "\\,?*'+-_"):
                raise MaskError('Unrecognized symbol in atom name parsing [%s]'
                                % p)
            pos += 1

    #======================================================

    def _atom_typelist(self, buffer, mask):
        """ Fills a _mask based on atom types """
        self._atom_namelist(buffer, mask, key='type')

    #======================================================

    def _atom_elemlist(self, buffer, mask):
        """
        Fills a _mask based on atom elements. Element symbols are translated
        into atomic numbers and compared with each atom's ``atomic_number``
        """
        self._atom_namelist(buffer, mask, key='element')

    #======================================================

    def _residue_numlist(self, instring, mask):
        """
        Fills a _mask based on residue numbers

        Parameters
        ----------
        instring : str
            Comma-separated list of 1-based residue numbers and ranges
        mask : _mask
            The mask to modify in place

        Raises
        ------
        MaskError
            If the end of a residue range is not a valid integer
        """
        buffer = ''
        pos = 0
        at1 = at2 = dash = 0
        while pos < len(instring):
            p = instring[pos]
            if p.isdigit():
                buffer += p
            if p == ',' or pos == len(instring) - 1:
                if dash == 0:
                    at1 = int(buffer)
                    self._resnum_select(at1, at1, mask)
                else:
                    try:
                        at2 = int(buffer)
                    except ValueError:
                        raise MaskError('Bad mask: error in integer conversion')
                    self._resnum_select(at1, at2, mask)
                    dash = 0
                buffer = ''
            elif p == '-':
                at1 = int(buffer)
                dash = 1
                buffer = ''
            pos += 1

    #======================================================

    def _residue_namelist(self, instring, mask):
        """
        Fills a _mask based on residue names

        Parameters
        ----------
        instring : str
            Comma-separated list of residue names; entries that start with a
            digit and contain a dash are treated as residue number ranges
        mask : _mask
            The mask to modify in place

        Raises
        ------
        MaskError
            If an unrecognized character is found
        """
        buffer = ''
        pos = 0
        while pos < len(instring):
            p = instring[pos]
            if p.isalnum() or p in ['*', '?', '+', "'", '-']:
                buffer += p
            if p == ',' or pos == len(instring) - 1:
                if '-' in buffer and buffer[0].isdigit():
                    self._residue_numlist(buffer, mask)
                else:
                    self._resname_select(buffer, mask)
                buffer = ''
            if not (p.isalnum() or p in ",?*'+-"):
                raise MaskError('Unknown symbol in residue name parsing [%s]'
                                % p)
            pos += 1

    #======================================================

    def _atnum_select(self, at1, at2, mask):
        """
        Fills a _mask array between atom numbers at1 and at2 (1-based,
        inclusive)
        """
        for i in range(at1-1, at2):
            mask[i] = 1

    #======================================================

    def _resnum_select(self, res1, res2, mask):
        """
        Fills a _mask array between residues res1 and res2 (1-based,
        inclusive)
        """
        for i, atom in enumerate(self.parm.atoms):
            res = atom.residue.idx + 1
            if res >= res1 and res <= res2:
                mask[i] = 1

    #======================================================

    def _atname_select(self, atname, mask, key='name'):
        """
        Fills a _mask array with all atom names of a given name

        Parameters
        ----------
        atname : str
            The name pattern to match. If it consists only of digits it is
            treated as a 1-based atom number instead
        mask : _mask
            The mask to modify in place
        key : str, optional
            The atom attribute to compare with; 'element' compares atomic
            numbers instead. Default is 'name'

        Raises
        ------
        MaskError
            If ``key`` is 'element' and ``atname`` is not a known element
        """
        if atname.isdigit():
            atname = int(atname) - 1
            for i, atom in enumerate(self.parm.atoms):
                mask[i] = mask[i] | int(atname == i)
        elif key == 'element':
            try:
                for i, atom in enumerate(self.parm.atoms):
                    mask[i] = mask[i] | int(AtomicNum[atname] ==
                                            atom.atomic_number)
            except KeyError:
                raise MaskError('Unknown element %s' % atname)
        else:
            for i, atom in enumerate(self.parm.atoms):
                mask[i] = mask[i] | int(_nameMatch(atname, getattr(atom, key)))

    #======================================================

    def _resname_select(self, resname, mask):
        """
        Fills a _mask array with all residue names of a given name. A name
        made only of digits that matches no residue name is also compared with
        the 1-based residue number
        """
        for i, atm in enumerate(self.parm.atoms):
            if _nameMatch(resname, atm.residue.name):
                mask[i] = 1
            elif resname.isdigit():
                mask[i] = mask[i] | int(int(resname) == atm.residue.idx + 1)

    #======================================================

    def _binop(self, op, pmask1, pmask2):
        """ Does a binary operation on a pair of masks """
        if op == '&':
            return pmask1.And(pmask2)
        if op == '|':
            return pmask1.Or(pmask2)
        raise MaskError('Unknown operator [%s]' % op)

    #======================================================

    def _priority(self, op):
        """
        Returns the precedence of an operator (larger binds more tightly).
        The '(' and newline markers that live on the operator stack get the
        lowest priorities so that no real operator ever pops them
        """
        if op in ['>', '<']:
            return 6
        if op in ['!']:
            return 5
        if op in ['&']:
            return 4
        if op in ['|']:
            return 3
        if op in ['(']:
            return 2
        if op in ['\n']:
            return 1
        raise MaskError('Unknown operator [%s] in Mask ==%s==' %
                        (op, self.mask))
#+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

def _nameMatch(atnam1, atnam2):
    """
    Determines if atnam1 matches atnam2, where atnam1 can have wildcards and
    spaces are ignored. atnam2 should come from the prmtop. We'll use regex to
    do this.

    An unescaped ``*`` in atnam1 matches zero or more non-whitespace
    characters and an unescaped ``?`` matches exactly one non-whitespace
    character. ``+`` always stands for a literal plus sign. A wildcard
    preceded by a backslash stands for that literal character.

    Then, we remove the first match of the resulting pattern from atnam2. If
    it's a complete match, then our result will be a blank string (and will
    evaluate to False for boolean conditions). If it's not blank, then it's
    not a full match and we return False

    Parameters
    ----------
    atnam1 : str
        The name pattern taken from the mask
    atnam2 : str
        The name to test against the pattern

    Returns
    -------
    bool
        True if the whole of atnam2 matches atnam1, False otherwise
    """
    import re
    atnam1 = str(atnam1).replace(' ', '')
    atnam2 = str(atnam2).replace(' ', '')
    # Replace amber mask wildcards with appropriate regex wildcards and protect
    # the + (but protect backslashes). Escaped characters are parked behind a
    # placeholder while the unescaped ones are translated and are then restored
    # as *escaped* regex literals, so that they never act as regex quantifiers
    R = '<!PROTECT!>'
    atnam1 = atnam1.replace('\\*', R).replace('*', r'\S*').replace(R, r'\*')
    atnam1 = atnam1.replace('\\?', R).replace('?', r'\S').replace(R, r'\?')
    atnam1 = atnam1.replace('\\+', R).replace('+', r'\+').replace(R, r'\+')
    # Now replace just the first match of atnam1 in atnam2 with '', and
    # return *not* that
    return atnam1 == atnam2 or not bool(re.sub(atnam1, '', atnam2, count=1))

#+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

class _mask(list):
    """
    Mask array; only used by AmberMask. It is a fixed-length list holding one
    integer per atom (1 for selected, 0 for not selected), so every list
    method that would change its length raises a MaskError

    Parameters
    ----------
    natom : int
        The number of entries in the mask; all of them start out as 0
    """

    def __init__(self, natom):
        self.natom = natom
        list.__init__(self, [0] * natom)

    def append(self, *args, **kwargs):
        raise MaskError('_mask is a fixed-length array!')

    def extend(self, *args, **kwargs):
        raise MaskError('_mask is a fixed-length array!')

    def pop(self, *args, **kwargs):
        # Popping would shrink the mask, so it is rejected exactly like the
        # other length-changing methods
        raise MaskError('_mask is a fixed-length array!')

    def remove(self, *args, **kwargs):
        raise MaskError('_mask is a fixed-length array!')

    def And(self, other):
        """ Returns a new mask selecting the atoms selected in both masks """
        if self.natom != other.natom:
            raise MaskError("_mask: and() requires another mask of equal size!")
        new_mask = _mask(self.natom)
        for i, (mine, theirs) in enumerate(zip(self, other)):
            new_mask[i] = int(mine and theirs)
        return new_mask

    def Or(self, other):
        """ Returns a new mask selecting the atoms selected in either mask """
        if self.natom != other.natom:
            raise MaskError('_mask: or() requires another mask of equal size!')
        new_mask = _mask(self.natom)
        for i, (mine, theirs) in enumerate(zip(self, other)):
            new_mask[i] = int(mine or theirs)
        return new_mask

    def Not(self):
        """ Returns a new mask in which every entry is flipped """
        new_mask = _mask(self.natom)
        for i in range(self.natom):
            new_mask[i] = 1 - self[i]
        return new_mask

    def select_all(self):
        """ Selects every atom of this mask in place """
        for i in range(self.natom):
            self[i] = 1

#+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+