Python CPT Wrappers

CPTWrapper.py
class CPTWrapper(object):
    '''Marker base class shared by the discrete CPT wrappers.

    SimpleDiscreteCPTBN and SQLDiscreteCPTBN derive from it; it defines no
    attributes or methods of its own.
    '''
SimpleDiscreteCPTBN.py
class SimpleDiscreteCPTBN(CPTWrapper):
    '''In-memory conditional frequency counts for a discrete Bayesian network.

    For every node, counts how often each joint assignment of the node and
    its parents occurs in a column-oriented joint frequency table, and answers
    add-one-smoothed frequency queries through stateFreq().
    '''

    def __init__(self, nodeOrder, map, stateSpace, jft):
        '''
        nodeOrder: [node1, node2, node3, ...]
        map: {node1:[], node2:[node1], node3:[node1, node2], ...}
            Parent list of every node.
        stateSpace: {node1:['a', 'b', 'c'], node2:[1, 3, 5, 7], node3:['high', 'low'], ...}
        jft: {node1:[0, 2, 3, 1, 1, ...], node2:[3, 3, 1, 1, 2, ...], node3:[1, 1, 0, 0, 0, ...], ...}
            Column-oriented data: entry i of every list belongs to data
            point i.  Values must be hashable.
            NOTE(review): all columns are assumed to have the same length.
        '''
        self.__nodeMap = map
        self.__orderedNodeList = nodeOrder
        self.__stateSpace = stateSpace
        self.__cfts = self.__JFTtoCFT(nodeOrder, map, jft)
        # Number of observations; a network without nodes has none.
        self.dataPoints = len(jft[nodeOrder[0]]) if nodeOrder else 0

    @staticmethod
    def __stateKey(state):
        '''Return an order-independent, hashable key for a {node: value} dict.

        repr() of a dict depends on key insertion order, so equal assignments
        built in different orders must not be keyed by their repr.
        '''
        return frozenset(state.items())

    def __JFTtoCFT(self, nodeOrder, map, jft):
        '''Count every (node, parents) assignment observed in jft.

        return: {key: count, ...} where key is __stateKey() of a state such as
        {node3:1, node1:0, node2:3} (the node's value plus one value per
        parent, taken from the same data point).
        '''
        retVal = {}

        for node in nodeOrder:
            parents = map[node]
            # One state per data point of this node's column.
            for i, value in enumerate(jft[node]):
                state = {node: value}
                for parent in parents:
                    state[parent] = jft[parent][i]
                key = self.__stateKey(state)
                retVal[key] = retVal.get(key, 0) + 1

        return retVal

    def stateFreq(self, state, node):
        '''Return the add-one-smoothed frequency of `state` for `node`.

        state: {node: value, parent1: value1, ...} -- the node's value plus
            one value per parent; key order does not matter.
        node: node whose number of states is added to the denominator.

        return: (count + 1, dataPoints + len(stateSpace[node]))
        NOTE(review): the denominator uses the total number of data points,
        not the count of the parent assignment -- confirm this is intended.
        '''
        # .get() so that querying an unseen state does not add a key.
        count = self.__cfts.get(self.__stateKey(state), 0)
        return (count + 1, self.dataPoints + len(self.__stateSpace[node]))
SQLDiscreteCPTBN.py
import MySQLdb
 
class SQLDiscreteCPTBN(CPTWrapper):
    '''Frequency tables for a discrete Bayesian network stored in MySQL.

    The joint frequency table `jftTable` is queried with one column per node
    plus a `frequency` column (as used by the SQL built below).  For every
    node the constructor makes sure an aggregate table named
    <jftTable>_<node>_<parent1>_<parent2>... exists.  That table holds
    SUM(frequency) grouped by the node and its parents, and stateFreq()
    queries it.  Each operation opens its own connection and always closes
    it afterwards.
    '''

    class BNStateKey(object):
        '''String key for a state dict whose entries follow a fixed node order.'''

        def __init__(self, nodeOrder, state):
            # nodeOrder: sequence of node names that fixes the key order.
            # state: {node: value} mapping; nodes not in nodeOrder are
            # left out of the key.
            self.nodeOrder = nodeOrder
            self.state = state

        def __repr__(self):
            '''Render "{'n1':v1, 'n2':v2, }".

            Only the nodes of nodeOrder that are present in state are
            included, in nodeOrder order.  Values are rendered with str(),
            and every entry (including the last) is followed by ", ".
            stateFreq() uses this exact text as a dictionary key, so the
            format must stay stable.
            '''
            entries = ["'" + str(node) + "':" + str(self.state[node]) + ", "
                       for node in self.nodeOrder if node in self.state]
            return "{" + "".join(entries) + "}"

    def __init__(self, nodeOrder, map, stateSpace, database, username, password, jftTable):
        '''
        nodeOrder: [node1, node2, node3, ...]
        map: {node1:[], node2:[node1], node3:[node1, node2], ...}
        stateSpace: {node1:['a', 'b', 'c'], node2:[1, 3, 5, 7], node3:['high', 'low'], ...}
        database, username, password: MySQL settings, passed to
            MySQLdb.connect as db, user and passwd.
        jftTable: name of the joint frequency table; node names are used as
            its column names.

        Connects to the database and creates any missing per-node aggregate
        tables before returning.
        '''
        self.__database = database
        self.__username = username
        self.__password = password
        self.__nodeMap = map
        self.__orderedNodeList = nodeOrder
        self.__stateSpace = stateSpace
        # node -> name of its aggregate frequency table.
        self.__cfts = self.__JFTtoCFT(nodeOrder, map, jftTable)

    def getNodes(self):
        '''Return the node order given to the constructor.'''
        return self.__orderedNodeList

    def getParents(self, node):
        '''Return the parent list of `node`.'''
        return self.__nodeMap[node]

    def getStates(self, node):
        '''Return the list of possible states of `node`.'''
        return self.__stateSpace[node]

    def __connect(self):
        '''Open a new MySQL connection with the stored credentials.'''
        return MySQLdb.connect(passwd=self.__password, db=self.__database, user=self.__username)

    def __JFTtoCFT(self, nodeOrder, map, jftTable):
        '''Ensure an aggregate frequency table exists for every node.

        For each node, creates (if absent) the table
        <jftTable>_<node>_<parent1>_... with the node column, its parent
        columns and SUM(frequency) grouped by those columns.

        return: {node: aggregateTableName, ...}
        '''
        from contextlib import closing

        retVal = {}
        with closing(self.__connect()) as db:
            with closing(db.cursor(cursorclass=MySQLdb.cursors.DictCursor)) as c:
                for node in nodeOrder:
                    columns = [node] + list(map[node])
                    tableName = "_".join([jftTable] + columns)

                    # In LIKE patterns '_' and '%' are wildcards; escape them
                    # so that only this exact table name can match.
                    pattern = (tableName.replace("\\", "\\\\")
                               .replace("_", "\\_")
                               .replace("%", "\\%"))
                    c.execute("SHOW TABLES LIKE %s", (pattern,))
                    if not c.fetchall():
                        nodeListSQL = ", ".join(columns)
                        createSQL = ("CREATE TABLE IF NOT EXISTS " + tableName
                                     + " SELECT " + nodeListSQL
                                     + ", SUM(frequency) as frequency "
                                     + "FROM " + jftTable + " "
                                     + "GROUP BY " + nodeListSQL)
                        c.execute(createSQL)
                        db.commit()

                    retVal[node] = tableName

        return retVal

    def stateFreq(self, node, *args, **kwargs):
        '''
        stateFreq(node, [condNodeList], [evidence={}], [cft=True or False])
        -> state table: ({'A':0.3, 'B':8, 'frequency':4}, ...) OR
        -> conditional frequency table: {"{'A':0.3, 'B':8, }":4, ...}

        node: The node in this BN to examine.
        condNodeList: A conditional node list identifying each of the
        conditional nodes of interest.  This parameter is optional and is of
        the form [node1, node2, ...].  Nodes in this list must be a subset of
        the parent nodes of the given node.
        cft: True will cause this method to return a conditional frequency
        table keyed by BNStateKey text.  False (the default) will cause this
        method to return a state table (the rows fetched from the database).
        evidence: A dictionary mapping conditional nodes to a value.  The nodes
        in the dictionary need not be in the conditional node list.  Entries
        whose node is not a parent of `node` are ignored.  This parameter is
        optional.
        '''
        from contextlib import closing

        condNodeList = args[0] if args else []
        nodeListSQL = ", ".join([node] + list(condNodeList))

        selectSQL = ("SELECT " + nodeListSQL
                     + ", SUM(frequency) as frequency FROM " + self.__cfts[node])

        # Evidence values are passed as query parameters so string states
        # are quoted and escaped by the driver instead of pasted into SQL.
        evidence = kwargs.get('evidence') or {}
        parents = self.getParents(node)
        conditions = []
        params = []
        for eNode, value in evidence.items():
            if eNode in parents:
                conditions.append(eNode + " = %s")
                params.append(value)
        if conditions:
            selectSQL += " WHERE " + " AND ".join(conditions)
        selectSQL += " GROUP BY " + nodeListSQL

        # Gather results; the connection and cursor are closed even on error.
        with closing(self.__connect()) as db:
            with closing(db.cursor(cursorclass=MySQLdb.cursors.DictCursor)) as c:
                c.execute(selectSQL, tuple(params) if params else None)
                retVal = c.fetchall()

        # Build CFT if asked to
        if kwargs.get('cft'):
            nodes = self.getNodes()
            cftTable = {}
            for row in retVal:
                frequency = row.pop('frequency')
                cftTable[repr(self.BNStateKey(nodes, row))] = frequency
            retVal = cftTable

        return retVal

Export page to Open Document format









You could leave a comment if you were logged in.

goplayer/cptwrappers.txt · Last modified: 2023/02/24 23:05 (external edit)