Python CPT (Conditional Probability Table) Wrappers
- CPTWrapper.py
class CPTWrapper(object):
    """Common base class shared by the CPT wrapper implementations.

    It defines no attributes or methods of its own; subclasses such as
    SimpleDiscreteCPTBN and SQLDiscreteCPTBN supply all behavior.
    """
- SimpleDiscreteCPTBN.py
class SimpleDiscreteCPTBN(CPTWrapper):
    """In-memory CPT wrapper built from a joint frequency table (JFT).

    At construction time every (node, parents) state combination present in
    the JFT is counted once. stateFreq() then answers frequency queries with
    add-one smoothing.
    """

    def __init__(self, nodeOrder, map, stateSpace, jft):
        '''
        nodeOrder: [node1, node2, node3, ...]
        map: {node1:[], node2:[node1], node3:[node1, node2], ...}
        stateSpace: {node1:['a', 'b', 'c'], node2:[1, 3, 5, 7], node3:['high', 'low'], ...}
        jft: {node1:[0, 2, 3, 1, 1, ...], node2:[3, 3, 1, 1, 2, ...], node3:[1, 1, 0, 0, 0, ...], ...}
        '''
        self.__nodeMap = map
        self.__orderedNodeList = nodeOrder
        self.__stateSpace = stateSpace
        self.__cfts = self.__JFTtoCFT(nodeOrder, map, jft)
        # Each JFT column holds one entry per data point; with no nodes there
        # is no column to measure, so there are no data points.
        self.dataPoints = len(jft[nodeOrder[0]]) if nodeOrder else 0

    @staticmethod
    def __stateKey(state):
        """Return a hashable key for a state dict that ignores its key order."""
        # Sorting the repr'd pairs makes {a:1, b:2} and {b:2, a:1} map to the
        # same key. Using repr keeps values of different types (e.g. 1 vs '1')
        # distinct, as the original repr(state) keys did.
        return tuple(sorted((repr(k), repr(v)) for k, v in state.items()))

    def __JFTtoCFT(self, nodeOrder, map, jft):
        '''
        Count how often each (node, parents) state combination occurs.

        return: {key({node3:1, node1:0, node2:3}): 1, key({node3:1, node1:2, node2:3}): 1, ...}
                where key() is __stateKey, so lookups do not depend on the
                insertion order of the state dict.
        '''
        from collections import defaultdict

        counts = defaultdict(int)
        for node in nodeOrder:
            columns = [node] + list(map[node])
            # One state per data point: the node's value plus each parent's
            # value at the same row index.
            for i in range(len(jft[node])):
                state = dict((name, jft[name][i]) for name in columns)
                counts[self.__stateKey(state)] += 1
        # Return a plain dict so later lookups of unseen states cannot insert
        # zero-count entries.
        return dict(counts)

    def stateFreq(self, state, node):
        '''
        Return the add-one smoothed frequency of `state` as (numerator, denominator).

        state: {node: value, parent1: value, ...}. It must contain exactly
               `node` and its parents to match a counted combination; the
               key order does not matter.
        node:  the node whose state-space size is added to the denominator.

        The numerator is the observed count + 1. The denominator is the number
        of data points + len(stateSpace[node]).
        '''
        count = self.__cfts.get(self.__stateKey(state), 0)
        return (count + 1, self.dataPoints + len(self.__stateSpace[node]))
- SQLDiscreteCPTBN.py
import MySQLdb
import MySQLdb.cursors


class SQLDiscreteCPTBN(CPTWrapper):
    """MySQL-backed CPT wrapper.

    At construction time one aggregated table per node is created in MySQL
    from the joint frequency table, unless it already exists. Each table
    holds the node column, its parent columns and SUM(frequency) grouped by
    those columns. stateFreq() queries those tables.
    """

    class BNStateKey(object):
        """String key for a state dict, rendered in a fixed node order.

        repr() lists only the nodes of `nodeOrder` that are present in
        `state`, in that order, so equal states give equal keys regardless
        of the dict's own ordering.
        """

        def __init__(self, nodeOrder, state):
            # Exposed as plain attributes. Do not add same-named methods:
            # these instance attributes would shadow them.
            self.nodeOrder = nodeOrder
            self.state = state

        def __repr__(self):
            # Format kept as-is (including the trailing ", " before "}") so
            # keys stay identical to those produced previously.
            retVal = "{"
            for node in self.nodeOrder:
                if node in self.state:
                    retVal += "'" + str(node) + "':" + str(self.state[node]) + ", "
            retVal += "}"
            return retVal

    def __init__(self, nodeOrder, map, stateSpace, database, username, password, jftTable):
        '''
        nodeOrder: [node1, node2, node3, ...]
        map: {node1:[], node2:[node1], node3:[node1, node2], ...}
        stateSpace: {node1:['a', 'b', 'c'], node2:[1, 3, 5, 7], node3:['high', 'low'], ...}
        database, username, password: MySQL connection credentials.
        jftTable: name of the joint frequency table; it must have one column
                  per node plus a `frequency` column.
        '''
        self.__database = database
        self.__username = username
        self.__password = password
        self.__nodeMap = map
        self.__orderedNodeList = nodeOrder
        self.__stateSpace = stateSpace
        self.__cfts = self.__JFTtoCFT(nodeOrder, map, jftTable)

    def getNodes(self):
        """Return the ordered list of nodes in this network."""
        return self.__orderedNodeList

    def getParents(self, node):
        """Return the list of parents of `node`."""
        return self.__nodeMap[node]

    def getStates(self, node):
        """Return the state space (list of possible values) of `node`."""
        return self.__stateSpace[node]

    def __connect(self):
        """Open a new MySQL connection with the stored credentials."""
        return MySQLdb.connect(passwd=self.__password, db=self.__database,
                               user=self.__username)

    @staticmethod
    def __tableExists(cursor, tableName):
        """Return True if a table named exactly `tableName` exists."""
        # In LIKE patterns '_' and '%' are wildcards. Escape them (and the
        # escape character itself) so only the exact table name matches.
        pattern = (tableName.replace("\\", "\\\\")
                            .replace("%", "\\%")
                            .replace("_", "\\_"))
        cursor.execute("SHOW TABLES LIKE %s", (pattern,))
        return len(cursor.fetchall()) > 0

    def __JFTtoCFT(self, nodeOrder, map, jftTable):
        '''
        Ensure an aggregated frequency table exists for every node.

        The table for a node is named <jftTable>_<node>_<parent1>_<parent2>...
        Existing tables are reused, not rebuilt.

        return: {node1: tableName1, node2: tableName2, ...}
        '''
        tables = {}
        db = self.__connect()
        try:
            c = db.cursor(cursorclass=MySQLdb.cursors.DictCursor)
            try:
                for node in nodeOrder:
                    columns = [node] + list(map[node])
                    tableName = "_".join([jftTable] + columns)
                    if not self.__tableExists(c, tableName):
                        columnSQL = ", ".join(columns)
                        # Identifiers cannot be bound as query parameters, so
                        # node and table names must be trusted configuration.
                        createSQL = ("CREATE TABLE IF NOT EXISTS " + tableName +
                                     " SELECT " + columnSQL +
                                     ", SUM(frequency) as frequency " +
                                     "FROM " + jftTable + " " +
                                     "GROUP BY " + columnSQL)
                        c.execute(createSQL)
                        db.commit()
                    tables[node] = tableName
            finally:
                c.close()
        finally:
            db.close()
        return tables

    def stateFreq(self, node, *args, **kwargs):
        '''
        stateFreq(node, [condNodeList], [evidence={}], [cft=True or False])
            -> state table: ({'A':0.3, 'B':8, 'frequency':4}, ...)
            OR
            -> conditional frequency table: {"{'A':0.3, 'B':8, }":4, ...}

        node: The node in this BN to examine.
        condNodeList: An optional conditional node list of the form
            [node1, node2, ...] naming the conditional nodes of interest.
            Nodes in this list must be a subset of the parent nodes of the
            given node.
        cft: True returns a conditional frequency table keyed by
            repr(BNStateKey(...)). False (the default) returns a state table.
        evidence: An optional dictionary mapping conditional nodes to a
            value. The nodes in it need not be in the conditional node list.
            Entries for nodes that are not parents of `node` are ignored.
            Values are bound as query parameters.
        '''
        condNodeList = args[0] if len(args) > 0 else []
        columnSQL = ", ".join([node] + list(condNodeList))

        # Build the WHERE clause only from evidence on this node's parents.
        conditions = []
        params = []
        evidence = kwargs.get('evidence') or {}
        if evidence:
            parents = self.getParents(node)
            for eNode in evidence:
                if eNode in parents:
                    conditions.append(eNode + " = %s")
                    params.append(evidence[eNode])

        selectSQL = ("SELECT " + columnSQL +
                     ", SUM(frequency) as frequency FROM " + self.__cfts[node])
        if conditions:
            selectSQL += " WHERE " + " AND ".join(conditions)
        selectSQL += " GROUP BY " + columnSQL

        db = self.__connect()
        try:
            c = db.cursor(cursorclass=MySQLdb.cursors.DictCursor)
            try:
                # Pass None when there are no parameters, so the driver does
                # not apply %-interpolation to the query text.
                c.execute(selectSQL, tuple(params) if params else None)
                retVal = c.fetchall()
            finally:
                c.close()
        finally:
            db.close()

        # Build the CFT if asked to.
        if kwargs.get('cft'):
            cft = {}
            for row in retVal:
                freq = row.pop('frequency')
                cft[repr(self.BNStateKey(self.getNodes(), row))] = freq
            retVal = cft
        return retVal
You could leave a comment if you were logged in.