Differences

This shows you the differences between two versions of the page.

Link to this comparison view

goplayer:cptwrappers [2011/02/16 22:28]
aiartificer Update to condNodeList parameter description in SQLDiscreteCPTBN
goplayer:cptwrappers [2011/02/17 23:10] (current)
aiartificer
Line 50: Line 50:
  
 <code python SQLDiscreteCPTBN.py> <code python SQLDiscreteCPTBN.py>
 +import MySQLdb
 +
 class SQLDiscreteCPTBN(CPTWrapper): class SQLDiscreteCPTBN(CPTWrapper):
  class BNStateKey(object):  class BNStateKey(object):
Line 67: Line 69:
  return retVal;  return retVal;
   
- def __init__(self, nodeOrder, map, stateSpace, db, jftTable):+ def __init__(self, nodeOrder, map, stateSpace, database, username, password, jftTable):
  '''  '''
  nodeOrder: [node1, node2, node3, ...]  nodeOrder: [node1, node2, node3, ...]
  map: {node1:[], node2:[node1], node3:[node1, node2], ...}  map: {node1:[], node2:[node1], node3:[node1, node2], ...}
  stateSpace: {node1:['a', 'b', 'c'], node2:[1, 3, 5, 7], node3:['high', 'low'], ...}  stateSpace: {node1:['a', 'b', 'c'], node2:[1, 3, 5, 7], node3:['high', 'low'], ...}
- db: DB connection 
  '''  '''
 + self.__database = database;
 + self.__username = username;
 + self.__password = password;
  self.__nodeMap = map;  self.__nodeMap = map;
  self.__orderedNodeList = nodeOrder;  self.__orderedNodeList = nodeOrder;
  self.__stateSpace = stateSpace;  self.__stateSpace = stateSpace;
- self.__cfts = self.__JFTtoCFT(nodeOrder, map, db, jftTable);+ self.__cfts = self.__JFTtoCFT(nodeOrder, map, jftTable);
  def getNodes(self):  def getNodes(self):
  return self.__orderedNodeList;  return self.__orderedNodeList;
Line 84: Line 88:
  def getStates(self, node):  def getStates(self, node):
  return self.__stateSpace(node);  return self.__stateSpace(node);
- def __JFTtoCFT(self, nodeOrder, map, db, jftTable):+ def __JFTtoCFT(self, nodeOrder, map, jftTable):
  '''  '''
  return: {"{node3:1, node1:0, node2:3}":1, "{node3:1, node1:2, node2:3}":1, ...]  return: {"{node3:1, node1:0, node2:3}":1, "{node3:1, node1:2, node2:3}":1, ...]
Line 91: Line 95:
   
  # Cycle through each node in order  # Cycle through each node in order
 + db = MySQLdb.connect(passwd=self.__password, db=self.__database, user=self.__username);
  c = db.cursor(cursorclass=MySQLdb.cursors.DictCursor);  c = db.cursor(cursorclass=MySQLdb.cursors.DictCursor);
  for node in nodeOrder:  for node in nodeOrder:
Line 121: Line 126:
   
  retVal[node] = tableName;  retVal[node] = tableName;
- +  
 + # Close DB connections`
  c.close();  c.close();
 + db.close();
   
  return retVal;  return retVal;
- def stateFreq(self, db, node, *args, **kwargs):+ def stateFreq(self, node, *args, **kwargs):
  '''  '''
- stateFreq(db, node, [condNodeList], [evidence={}], [cft=True or False])+ stateFreq(node, [condNodeList], [evidence={}], [cft=True or False])
  -> state table: ({'A':0.3, 'B':8, 'frequency':4}, ...) OR  -> state table: ({'A':0.3, 'B':8, 'frequency':4}, ...) OR
  -> conditional frequency table: {"{'A':0.3, 'B':8}":4, ...}  -> conditional frequency table: {"{'A':0.3, 'B':8}":4, ...}
   
- db: DB connection of type MySQLdb.connections.Connection 
  node: The node in this BN to examine  node: The node in this BN to examine
  condNodeList: A conditional node list identifying each of the  condNodeList: A conditional node list identifying each of the
Line 148: Line 154:
   
  # Construct SQL statement  # Construct SQL statement
 + db = MySQLdb.connect(passwd=self.__password, db=self.__database, user=self.__username);
  c = db.cursor(cursorclass=MySQLdb.cursors.DictCursor);  c = db.cursor(cursorclass=MySQLdb.cursors.DictCursor);
  selectSQL = "SELECT ";  selectSQL = "SELECT ";
Line 159: Line 166:
  selectSQL += " WHERE ";  selectSQL += " WHERE ";
  for eNode in evidence.keys():  for eNode in evidence.keys():
- selectSQL += eNode + " = " + str(evidence[eNode]) + " AND ";+ if eNode in self.getParents(node): 
 + selectSQL += eNode + " = " + str(evidence[eNode]) + " AND ";
  selectSQL = selectSQL[:len(selectSQL) - 5];  selectSQL = selectSQL[:len(selectSQL) - 5];
  selectSQL += " GROUP BY " + nodeListSQL.rstrip(", ");  selectSQL += " GROUP BY " + nodeListSQL.rstrip(", ");
Line 176: Line 184:
  retVal = newRetVal;  retVal = newRetVal;
   
 + # Close DB connections`
 + c.close();
 + db.close();
 +
  return retVal;  return retVal;
 </code> </code>

goplayer/cptwrappers.txt · Last modified: 2011/02/17 23:10 by aiartificer