Differences

This shows you the differences between two versions of the page.

Link to this comparison view

goplayer:dbrv [2011/09/08 06:00] (current)
aiartificer created
Line 1: Line 1:
 +====== DB RV ======
  
 +<code C DBRV.h>
 +/*
 + ============================================================================
 + Name        : dbrv.h
 + Author      : Stephen Cannon
 + Version     : 0.1
 + Copyright   : Copyright 2011 Stephen Cannon
 + Description :
 + ============================================================================
 + *
 + * This file is part of LikelihoodWeighting.
 + *
 + * LikelihoodWeighting is free software: you can redistribute it and/or modify
 + * it under the terms of the GNU Lesser General Public License as published by
 + * the Free Software Foundation, either version 3 of the License, or (at your
 + * option) any later version.
 + *
 + * LikelihoodWeighting is distributed in the hope that it will be useful, but
 + * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 + * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
 + * License for more details.
 + *
 + * You should have received a copy of the GNU Lesser General Public License
 + * along with LikelihoodWeighting.  If not, see <http://www.gnu.org/licenses/>.
 + *
 + */
 +
 +#ifndef DBRV_H_
 +#define DBRV_H_
 +
 +#include <windows.h>
 +#include <stdio.h>
 +#include <sql.h>
 +#include <sqlext.h>
 +#include "DBConnector.h"
 +
 +#ifdef __cplusplus
 +extern "C" {
 +#endif
 +
 +
 +enum {MAX_SQL_ID_SIZE = 65};
 +
 +enum
 +{
 + ERR_TOO_MANY_PARENTS = -13,
 + ERR_INVALID_STATE_TYPE = -14,
 + ERR_NO_BIN_STATE_FOR_VALUE = -15
 +};
 +
 +typedef struct __ParentTypePair__
 +{
 + const char *parent;
 + char enumeratedType;
 + SQLSMALLINT precision;
 +} ParentTypePair;
 +
 +typedef enum __StateType__
 +{
 + STATE_VALUE,
 + STATE_RANGE,
 + STATE_ENUM
 +} StateType;
 +
 +// TODO For enums at least, maybe an ID table and make enumValue a char *?
 +typedef struct __State__
 +{
 + StateType type;
 + union
 + {
 + char enumValue[MAX_SQL_ID_SIZE];
 + double scalarValue;
 + double rangeValue[2];
 + } value;
 +} State;
 +
 +typedef struct __DBRV__
 +{
 + const char *modelName;
 + const char *nodeName;
 + char enumNode;
 +
 + State *stateList;
 + size_t numStates;
 +
 + ParamBinding paramBindings[64]; // No more than ~32 parents per node
 + size_t numParams;
 +
 + double dOutValue;
 + char strOutValue[BUFFER_SIZE];
 + size_t lastFreq;
 + SQLINTEGER lastFreqID, outValueID;
 + ColBinding columnBindings[2];
 +
 + char querySQL[LARGE_BUFFER_SIZE];
 +
 + DBConnector super;
 +} DBRV;
 +
 +typedef struct __Evidence__
 +{
 + DBRV *rv;
 + State *state;
 +} Evidence;
 +
 +
 +char dbrv__init(DBRV *self,
 + const char *modelName,
 + const char *nodename,
 + char enumNode,
 + State *stateList,
 + size_t numStates,
 + ParentTypePair *parentTypeMap,
 + size_t numParents,
 + const char *connectionString);
 +char dbrv__start(DBRV *self);
 +char __dbrv__prepareQuerySQL(DBRV *self, char enumNode, ParentTypePair *parentTypeMap, size_t numParents);
 +char __dbrv__setupParamBinding(DBRV *self, const char *parentName, ParentTypePair *parentType);
 +char __dbrv__setupColumnBinding(DBRV *self, char enumNode);
 +char dbrv__logP(DBRV *self, Evidence *ev, size_t numEvNodes, double *probs);
 +char dbrv__randomSample(DBRV *self, Evidence *ev, size_t numEvNodes, State *sample);
 +char dbrv__getState(DBRV *self, const State *valState, const State **binState, size_t *i);
 +char dbrv__getStateFromScalar(DBRV *self, double value, const State **binState, size_t *i);
 +char dbrv__getStateFromRange(DBRV *self, double low, double high, const State **binState, size_t *i);
 +char dbrv__getStateFromEnum(DBRV *self, const char value[MAX_SQL_ID_SIZE], const State **binState, size_t *i);
 +char dbrv__stop(DBRV *self);
 +char dbrv__del(DBRV *self);
 +
 +char overwriteString(char *strDest, size_t numElements, const char *strSrc, size_t *offset);
 +
 +#ifdef __cplusplus
 +}
 +#endif
 +
 +#endif /* DBRV_H_ */
 +</code>
 +
 +<code C DBRV.c>
 +/*
 + ============================================================================
 + Name        : dbrv.c
 + Author      : Stephen Cannon
 + Version     : 0.1
 + Copyright   : Copyright 2011 Stephen Cannon
 + Description :
 + ============================================================================
 + *
 + * This file is part of LikelihoodWeighting.
 + *
 + * LikelihoodWeighting is free software: you can redistribute it and/or modify
 + * it under the terms of the GNU Lesser General Public License as published by
 + * the Free Software Foundation, either version 3 of the License, or (at your
 + * option) any later version.
 + *
 + * LikelihoodWeighting is distributed in the hope that it will be useful, but
 + * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 + * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
 + * License for more details.
 + *
 + * You should have received a copy of the GNU Lesser General Public License
 + * along with LikelihoodWeighting.  If not, see <http://www.gnu.org/licenses/>.
 + *
 + */
 +#include <stdio.h>
 +#include <stdlib.h>
 +#include <string.h>
 +#include <errno.h>
 +
 +#include "dbrv.h"
 +#include "DBConnector.h"
 +
 +
 +char dbrv__init(DBRV *self,
 + const char *modelName,
 + const char *nodeName,
 + char enumNode,
 + State *stateList,
 + size_t numStates,
 + ParentTypePair *parentTypeMap,
 + size_t numParents,
 + const char *connectionString)
 +{
 + size_t i = 0;
 + size_t c = 0;
 + char error = 0;
 +
 + // Initialize attibutes
 + self->numParams = 0;
 + self->strOutValue[0] = 0;
 + self->enumNode = enumNode;
 +
 + // Define node states
 + self->numStates = numStates;
 + self->stateList = (State *)calloc(numStates, sizeof(State));
 + for(i = 0; i < numStates; i++)
 + {
 + self->stateList[i].type = stateList[i].type;
 + switch(stateList[i].type)
 + {
 + case STATE_VALUE:
 + self->stateList[i].value.scalarValue = stateList[i].value.scalarValue;
 + break;
 + case STATE_RANGE:
 + self->stateList[i].value.rangeValue[0] = stateList[i].value.rangeValue[0];
 + self->stateList[i].value.rangeValue[1] = stateList[i].value.rangeValue[1];
 + break;
 + case STATE_ENUM:
 + for(c = 0 ; c < MAX_SQL_ID_SIZE && stateList[i].value.enumValue[c] != 0; c++)
 + self->stateList[i].value.enumValue[c] = stateList[i].value.enumValue[c];
 + self->stateList[i].value.enumValue[c] = 0;
 + break;
 + default:
 + return ERR_INVALID_STATE_TYPE;
 + }
 + }
 +
 + // Initialize query SQL string
 + self->querySQL[0] = 0;
 + self->querySQL[LARGE_BUFFER_SIZE - 1] = 0;
 +
 + // Get node name
 + self->nodeName = nodeName;
 +
 + // Get model name
 + self->modelName = modelName;
 +
 + // Generate parent class DBConnector
 + if(error = DBConnector__init(&self->super, connectionString))
 + return error;
 +
 + // Build query SQL
 + if(error = __dbrv__prepareQuerySQL(self, enumNode, parentTypeMap, numParents))
 + return error;
 +
 + return error;
 +}
 +
 +/**
 + * @param parentTypeMap {parent1:(1, 0, &pVal1, 1024, &pValLen1), parent2:(0, 10, &pVal2, 0, &pValLen2), ...}
 + */
 +char dbrv__start(DBRV *self)
 +{
 + char error = 0;
 +
 + // Connect to DB
 + if(error = DBConnector__connect(&self->super))
 + return error;
 +
 + // Prepare statement
 + if(error = DBConnector__prepareStatement(&self->super, self->querySQL, self->paramBindings, self->numParams, self->columnBindings, 2))
 + return error;
 +
 + return error;
 +}
 +
 +/**
 + * @param parentTypeMap {parent1:(1, 0, &pVal1, 1024, &pValLen1), parent2:(0, 10, &pVal2, 0, &pValLen2), ...}
 + */
 +char __dbrv__prepareQuerySQL(DBRV *self, char enumNode, ParentTypePair *parentTypeMap, size_t numParents)
 +{
 + char error = 0;
 + size_t offset = 0;
 + size_t i = 0;
 +
 + // Verify fewer than 32 parents
 + if(numParents > 32)
 + return ERR_TOO_MANY_PARENTS;
 +
 + // Start SELECT statement
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "SELECT `", &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, self->nodeName, &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "`, SUM(frequency) as frequency FROM ", &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, self->modelName, &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "_", &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, self->nodeName, &offset)) return -1;
 +
 + // Setup WHERE clause
 + if(numParents > 0)
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, " WHERE `", &offset)) return -1;
 + for(i = 0; i < numParents; i++)
 + {
 + const char *pNodeName = 0;
 +
 + // Get parent node name
 + pNodeName = parentTypeMap[i].parent;
 +
 + // Get parent value
 + if(parentTypeMap[i].enumeratedType)
 + {
 + if(i > 0)
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, ", ", &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, pNodeName, &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "` = ?", &offset)) return -1;
 + }
 + else
 + {
 + if(i > 0)
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, ", ", &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, pNodeName, &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "` >= ? AND `", &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, pNodeName, &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "` < ?", &offset)) return -1;
 + }
 +
 + // bind parameter
 + if(__dbrv__setupParamBinding(self, pNodeName, &parentTypeMap[i])) return -1;
 + }
 +
 + // Setup GROUP BY clause
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, " GROUP BY `", &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, self->nodeName, &offset)) return -1;
 + if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "`;", &offset)) return -1;
 +
 + // Define column bindings
 + if(__dbrv__setupColumnBinding(self, enumNode)) return -1;
 +
 + return 0;
 +}
 +
 +char __dbrv__setupParamBinding(DBRV *self, const char *parentName, ParentTypePair *parentType)
 +{
 + char *strParamValue = 0;
 + double *dParamValue = 0;
 + SQLLEN *paramLenPtr = 0;
 +
 + if(parentType->enumeratedType)
 + {
 + // Define parameter binding
 + self->paramBindings[self->numParams].ioType = SQL_PARAM_INPUT;
 + self->paramBindings[self->numParams].valueType = SQL_C_CHAR;
 + self->paramBindings[self->numParams].paramType = SQL_CHAR;
 + self->paramBindings[self->numParams].columnSize = 0;
 + self->paramBindings[self->numParams].decimalDigits = parentType->precision;
 + strParamValue = (char *)malloc(sizeof(char)*BUFFER_SIZE);
 + self->paramBindings[self->numParams].paramValuePtr = strParamValue;
 + self->paramBindings[self->numParams].bufferLength = BUFFER_SIZE;
 + paramLenPtr = (SQLLEN *)malloc(sizeof(SQLLEN));
 + self->paramBindings[self->numParams].indPtr = paramLenPtr;
 +
 + }
 + else
 + {
 + // Define parameter binding
 + self->paramBindings[self->numParams].ioType = SQL_PARAM_INPUT;
 + self->paramBindings[self->numParams].valueType = SQL_C_DOUBLE;
 + self->paramBindings[self->numParams].paramType = SQL_DOUBLE;
 + self->paramBindings[self->numParams].columnSize = 0;
 + self->paramBindings[self->numParams].decimalDigits = parentType->precision;
 + dParamValue = (double *)malloc(sizeof(double));
 + self->paramBindings[self->numParams].paramValuePtr = dParamValue;
 + self->paramBindings[self->numParams].bufferLength = 0;
 + paramLenPtr = (SQLLEN *)malloc(sizeof(SQLLEN));
 + self->paramBindings[self->numParams].indPtr = paramLenPtr;
 + self->numParams++;
 +
 + // Define parameter binding for second parameter
 + self->paramBindings[self->numParams].ioType = SQL_PARAM_INPUT;
 + self->paramBindings[self->numParams].valueType = SQL_C_DOUBLE;
 + self->paramBindings[self->numParams].paramType = SQL_DOUBLE;
 + self->paramBindings[self->numParams].columnSize = 0;
 + self->paramBindings[self->numParams].decimalDigits = parentType->precision;
 + dParamValue = (double *)malloc(sizeof(double));
 + self->paramBindings[self->numParams].paramValuePtr = dParamValue;
 + self->paramBindings[self->numParams].bufferLength = 0;
 + paramLenPtr = (SQLLEN *)malloc(sizeof(SQLLEN));
 + self->paramBindings[self->numParams].indPtr = paramLenPtr;
 + }
 + self->numParams++;
 +
 + return 0;
 +}
 +
 +char __dbrv__setupColumnBinding(DBRV *self, char enumNode)
 +{
 + if(enumNode)
 + {
 + self->columnBindings[0].type = SQL_C_CHAR;
 + }
 + else
 + {
 + self->columnBindings[0].type = SQL_C_DOUBLE;
 + }
 + self->columnBindings[0].valuePtr = 0;
 + self->columnBindings[0].bufferLength = 0;
 + self->columnBindings[0].indPtr = 0;
 +
 + self->columnBindings[1].type = SQL_C_LONG;
 + self->columnBindings[1].valuePtr = &self->lastFreq;
 + self->columnBindings[1].bufferLength = 0;
 + self->columnBindings[1].indPtr = &self->lastFreqID;
 +
 + return 0;
 +}
 +
 +/**
 + * @param ev Evidence nodes presented in the order they were presented in the
 + * parentTypeMap parameter when the start() method was called.
 + * @param numEvNodes
 + * @param probs [OUT] 
 + */
 +char dbrv__logP(DBRV *self, Evidence *ev, size_t numEvNodes, double *probs)
 +{
 + size_t i = 0;
 + size_t pBindIndex = 0;
 + char error = 0;
 + size_t totalFreq = 0;
 +
 + // Define parent values
 + for(i = 0; i < numEvNodes; i++)
 + {
 + const State *binState;
 + size_t index;
 +
 + // Get bin state
 + if(error = dbrv__getState(ev[i].rv, ev[i].state, &binState, &index)) return error;
 +
 + // Bind parameter values
 + switch(binState->type)
 + {
 + case STATE_RANGE:
 + self->paramBindings[pBindIndex++].paramValuePtr = (SQLPOINTER)&binState->value.rangeValue[0];
 + self->paramBindings[pBindIndex++].paramValuePtr = (SQLPOINTER)&binState->value.rangeValue[1];
 + break;
 + case STATE_ENUM:
 + self->paramBindings[pBindIndex++].paramValuePtr = (SQLPOINTER)&binState->value.enumValue;
 + break;
 + default:
 + return ERR_INVALID_STATE_TYPE;
 + break;
 + }
 + }
 +
 + // Get CPT
 + for(i = 0; i < self->numStates; i++)
 + probs[i] = 0;
 + if(error = DBConnector__executePreparedStatement(&self->super)) return error;
 + while(DBConnector__fetchExecutedStatementResult(&self->super, &error))
 + {
 + const State *binState;
 + size_t index;
 +
 + if(error = dbrv__getStateFromScalar(self, self->dOutValue, &binState, &index)) return error;
 + probs[index] += self->lastFreq;
 + totalFreq += self->lastFreq;
 + }
 + if(error)
 + return error;
 + for(i = 0; i < self->numStates; i++)
 + probs[i] /= totalFreq;
 +
 + return error;
 +}
 +
 +char dbrv__randomSample(DBRV *self, Evidence *ev, size_t numEvNodes, State *sample)
 +{
 + return 1;
 +}
 +
 +char dbrv__getState(DBRV *self, const State *valState, const State **binState, size_t *i)
 +{
 + *i = 0;
 +
 + switch(valState->type)
 + {
 + case STATE_VALUE:
 + return dbrv__getStateFromScalar(self, valState->value.scalarValue, binState, i);
 + break;
 + case STATE_RANGE:
 + for(*i = 0; *i < self->numStates; *i++)
 + {
 + const State presentBinState = self->stateList[*i];
 +
 + if(presentBinState.type != STATE_RANGE)
 + return ERR_INVALID_STATE_TYPE;
 + if(valState == &presentBinState)
 + {
 + *binState = &presentBinState;
 + return 0;
 + }
 + }
 +
 + return ERR_NO_BIN_STATE_FOR_VALUE;
 + break;
 + case STATE_ENUM:
 + return dbrv__getStateFromEnum(self, valState->value.enumValue, binState, i);
 + break;
 + default:
 + return ERR_INVALID_STATE_TYPE;
 + break;
 + }
 +
 + return ERR_GENERAL_ERROR;
 +}
 +char dbrv__getStateFromScalar(DBRV *self, double value, const State **binState, size_t *i)
 +{
 + *i = 0;
 +
 + for(*i = 0; *i < self->numStates; *i++)
 + {
 + const State presentBinState = self->stateList[*i];
 +
 + if(presentBinState.type != STATE_RANGE)
 + return ERR_INVALID_STATE_TYPE;
 + if(value >= presentBinState.value.rangeValue[0] &&
 + value < presentBinState.value.rangeValue[1])
 + {
 + *binState = &presentBinState;
 + return 0;
 + }
 + }
 +
 + return ERR_NO_BIN_STATE_FOR_VALUE;
 +}
 +char dbrv__getStateFromRange(DBRV *self, double low, double high, const State **binState, size_t *i)
 +{
 + *i = 0;
 +
 + for(*i = 0; *i < self->numStates; *i++)
 + {
 + const State presentBinState = self->stateList[*i];
 +
 + if(presentBinState.type != STATE_RANGE)
 + return ERR_INVALID_STATE_TYPE;
 + if(low == presentBinState.value.rangeValue[0] &&
 + high == presentBinState.value.rangeValue[1])
 + {
 + *binState = &presentBinState;
 + return 0;
 + }
 + }
 +
 + return ERR_NO_BIN_STATE_FOR_VALUE;
 +}
 +char dbrv__getStateFromEnum(DBRV *self, const char value[MAX_SQL_ID_SIZE], const State **binState, size_t *i)
 +{
 + *i = 0;
 +
 + for(*i = 0; *i < self->numStates; *i++)
 + {
 + const State presentBinState = self->stateList[*i];
 +
 +
 + if(presentBinState.type != STATE_ENUM)
 + return ERR_INVALID_STATE_TYPE;
 + if(strcmp(value, presentBinState.value.enumValue) == 0)
 + {
 + *binState = &presentBinState;
 + return 0;
 + }
 + }
 +
 + return ERR_NO_BIN_STATE_FOR_VALUE;
 +}
 +
 +char dbrv__stop(DBRV *self)
 +{
 + char error = 0;
 +
 + if(error = DBConnector__disconnect(&self->super))
 + return error;
 +
 + return error;
 +}
 +char dbrv__del(DBRV *self)
 +{
 + char error = 0;
 + size_t i = 0;
 +
 + // Deallocate states
 + free(self->stateList);
 +
 + // Deallocate parameter values and length pointers
 + for(i = 0; i < self->numParams; i++)
 + {
 + free(self->paramBindings[i].paramValuePtr);
 + free(self->paramBindings[i].indPtr);
 + }
 +
 + // Clean up parent class
 + if(error = DBConnector__del(&self->super))
 + return error;
 +
 + return error;
 +}
 +
 +char overwriteString(char *strDest, size_t numElements, const char *strSrc, size_t *offset)
 +{
 + size_t srcBufferLen;
 + size_t i;
 +
 + // Verify there won't be a buffer overrun
 + srcBufferLen = strlen(strSrc);
 + if(numElements <= srcBufferLen + *offset)
 + return -1;
 +
 + // Overwrite destination string with source string
 + for (i = 0; i < srcBufferLen; i++)
 + strDest[i + *offset] = strSrc[i];
 +
 + // Update offset
 + *offset += srcBufferLen;
 + strDest[*offset] = 0;
 +
 + return 0;
 +}
 +</code>
 +
 +~~ODT~~
 +\\
 +\\
 +\\
 +\\
 +\\
 +\\
 +\\
 +\\
 +\\
 +\\
 +~~DISCUSSION~~

goplayer/dbrv.txt · Last modified: 2011/09/08 06:00 by aiartificer