/*
============================================================================
Name : dbrv.c
Author : Stephen Cannon
Version : 0.1
Copyright : Copyright 2011 Stephen Cannon
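Description : Random variable whose conditional probability table is computed from frequency counts stored in a database table.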
============================================================================
*
* This file is part of LikelihoodWeighting.
*
* LikelihoodWeighting is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or (at your
* option) any later version.
*
* LikelihoodWeighting is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
* License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with LikelihoodWeighting. If not, see <http://www.gnu.org/licenses/>.
*
*/
#include
#include
#include
#include
#include "dbrv.h"
#include "DBConnector.h"
/**
 * Initialize a database-backed random variable.
 *
 * Copies the caller's state list into a newly allocated array (released by
 * dbrv__del()), stores the model/node names by pointer (they are not copied,
 * so they must outlive self), initializes the DBConnector base object and
 * builds the parameterized CPT query via __dbrv__prepareQuerySQL().
 *
 * @param self             object to initialize
 * @param modelName        model name; the query reads table "<modelName>_<nodeName>"
 * @param nodeName         this node's column name
 * @param enumNode         non-zero if this node's values are enumerated strings
 * @param stateList        numStates states to copy; the caller keeps ownership
 * @param numStates        number of entries in stateList
 * @param parentTypeMap    parent descriptors, forwarded to __dbrv__prepareQuerySQL
 * @param numParents       number of entries in parentTypeMap
 * @param connectionString forwarded to DBConnector__init
 * @return 0 on success; -1 if the state list cannot be allocated;
 *         ERR_INVALID_STATE_TYPE for an unknown state type; otherwise the
 *         error returned by DBConnector__init or __dbrv__prepareQuerySQL.
 */
char dbrv__init(DBRV *self,
                const char *modelName,
                const char *nodeName,
                char enumNode,
                State *stateList,
                size_t numStates,
                ParentTypePair *parentTypeMap,
                size_t numParents,
                const char *connectionString)
{
    size_t i = 0;
    char error = 0;

    // Initialize attributes
    self->numParams = 0;
    self->strOutValue[0] = 0;
    self->enumNode = enumNode;

    // Copy node states
    self->numStates = numStates;
    self->stateList = (State *)calloc(numStates, sizeof(State));
    if(self->stateList == NULL && numStates > 0)
    {
        self->numStates = 0;
        return -1;
    }
    for(i = 0; i < numStates; i++)
    {
        const State *src = &stateList[i];
        State *dst = &self->stateList[i];
        dst->type = src->type;
        switch(src->type)
        {
            case STATE_VALUE:
                dst->value.scalarValue = src->value.scalarValue;
                break;
            case STATE_RANGE:
                dst->value.rangeValue[0] = src->value.rangeValue[0];
                dst->value.rangeValue[1] = src->value.rangeValue[1];
                break;
            case STATE_ENUM:
            {
                // Bound the copy by the destination array and reserve one
                // slot for the terminator; stopping only at MAX_SQL_ID_SIZE
                // and then writing enumValue[c] = 0 could write past the end.
                const size_t cap = sizeof(dst->value.enumValue);
                size_t c = 0;
                while(c + 1 < cap && src->value.enumValue[c] != 0)
                {
                    dst->value.enumValue[c] = src->value.enumValue[c];
                    c++;
                }
                dst->value.enumValue[c] = 0;
                break;
            }
            default:
                // Drop the partially copied list so nothing stale survives
                // (dbrv__del() frees stateList, and free(NULL) is a no-op).
                free(self->stateList);
                self->stateList = NULL;
                self->numStates = 0;
                return ERR_INVALID_STATE_TYPE;
        }
    }

    // Initialize query SQL string
    self->querySQL[0] = 0;
    self->querySQL[LARGE_BUFFER_SIZE - 1] = 0;

    // Names are stored by pointer, not copied
    self->nodeName = nodeName;
    self->modelName = modelName;

    // Initialize parent class DBConnector
    if((error = DBConnector__init(&self->super, connectionString)) != 0)
        return error;

    // Build query SQL and its parameter/column bindings
    return __dbrv__prepareQuerySQL(self, enumNode, parentTypeMap, numParents);
}
/**
 * Connect to the database and prepare the CPT query built during
 * dbrv__init().
 *
 * Passes the query text, the self->numParams parameter bindings and the two
 * result-column bindings (node value, summed frequency) to
 * DBConnector__prepareStatement.
 *
 * @param self initialized random variable
 * @return 0 on success, otherwise the error reported by DBConnector__connect
 *         or DBConnector__prepareStatement.
 */
char dbrv__start(DBRV *self)
{
    char rc = DBConnector__connect(&self->super);
    if(rc != 0)
        return rc;

    rc = DBConnector__prepareStatement(&self->super,
                                       self->querySQL,
                                       self->paramBindings,
                                       self->numParams,
                                       self->columnBindings,
                                       2);
    return rc;
}
/**
 * Build the parameterized CPT query for this node into self->querySQL and
 * register the matching parameter and result-column bindings.
 *
 * The generated statement has the form
 *   SELECT `node`, SUM(frequency) as frequency FROM model_node
 *   WHERE `p1` = ? AND `p2` >= ? AND `p2` < ? ... GROUP BY `node`;
 * with one placeholder per enumerated parent and a half-open [low, high)
 * pair of placeholders per non-enumerated parent.
 *
 * @param self          object whose modelName/nodeName are already set
 * @param enumNode      non-zero if this node's own values are enumerated
 * @param parentTypeMap {parent1:(1, 0, &pVal1, 1024, &pValLen1), parent2:(0, 10, &pVal2, 0, &pValLen2), ...}
 * @param numParents    number of entries in parentTypeMap (at most 32)
 * @return 0 on success; ERR_TOO_MANY_PARENTS; or -1 if the query does not
 *         fit in LARGE_BUFFER_SIZE or a binding cannot be set up.
 */
char __dbrv__prepareQuerySQL(DBRV *self, char enumNode, ParentTypePair *parentTypeMap, size_t numParents)
{
    size_t offset = 0;
    size_t i = 0;

    // Allow at most 32 parents
    if(numParents > 32)
        return ERR_TOO_MANY_PARENTS;

    // SELECT clause
    if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "SELECT `", &offset)) return -1;
    if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, self->nodeName, &offset)) return -1;
    if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "`, SUM(frequency) as frequency FROM ", &offset)) return -1;
    if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, self->modelName, &offset)) return -1;
    if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "_", &offset)) return -1;
    if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, self->nodeName, &offset)) return -1;

    // WHERE clause: one condition per parent, all of which must hold
    if(numParents > 0)
        if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, " WHERE ", &offset)) return -1;
    for(i = 0; i < numParents; i++)
    {
        const char *pNodeName = parentTypeMap[i].parent;

        // Conditions must be joined with AND (a comma is a syntax error),
        // and every parent name needs its own opening backtick.
        if(i > 0)
            if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, " AND ", &offset)) return -1;
        if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "`", &offset)) return -1;
        if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, pNodeName, &offset)) return -1;

        if(parentTypeMap[i].enumeratedType)
        {
            // Exact match on the enumerated value
            if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "` = ?", &offset)) return -1;
        }
        else
        {
            // Half-open range [low, high)
            if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "` >= ? AND `", &offset)) return -1;
            if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, pNodeName, &offset)) return -1;
            if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "` < ?", &offset)) return -1;
        }

        // Register the parameter binding(s) for this parent's placeholder(s)
        if(__dbrv__setupParamBinding(self, pNodeName, &parentTypeMap[i])) return -1;
    }

    // GROUP BY clause
    if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, " GROUP BY `", &offset)) return -1;
    if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, self->nodeName, &offset)) return -1;
    if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "`;", &offset)) return -1;

    // Define column bindings
    if(__dbrv__setupColumnBinding(self, enumNode)) return -1;
    return 0;
}
char __dbrv__setupParamBinding(DBRV *self, const char *parentName, ParentTypePair *parentType)
{
char *strParamValue = 0;
double *dParamValue = 0;
SQLLEN *paramLenPtr = 0;
if(parentType->enumeratedType)
{
// Define parameter binding
self->paramBindings[self->numParams].ioType = SQL_PARAM_INPUT;
self->paramBindings[self->numParams].valueType = SQL_C_CHAR;
self->paramBindings[self->numParams].paramType = SQL_CHAR;
self->paramBindings[self->numParams].columnSize = 0;
self->paramBindings[self->numParams].decimalDigits = parentType->precision;
strParamValue = (char *)malloc(sizeof(char)*BUFFER_SIZE);
self->paramBindings[self->numParams].paramValuePtr = strParamValue;
self->paramBindings[self->numParams].bufferLength = BUFFER_SIZE;
paramLenPtr = (SQLLEN *)malloc(sizeof(SQLLEN));
self->paramBindings[self->numParams].indPtr = paramLenPtr;
}
else
{
// Define parameter binding
self->paramBindings[self->numParams].ioType = SQL_PARAM_INPUT;
self->paramBindings[self->numParams].valueType = SQL_C_DOUBLE;
self->paramBindings[self->numParams].paramType = SQL_DOUBLE;
self->paramBindings[self->numParams].columnSize = 0;
self->paramBindings[self->numParams].decimalDigits = parentType->precision;
dParamValue = (double *)malloc(sizeof(double));
self->paramBindings[self->numParams].paramValuePtr = dParamValue;
self->paramBindings[self->numParams].bufferLength = 0;
paramLenPtr = (SQLLEN *)malloc(sizeof(SQLLEN));
self->paramBindings[self->numParams].indPtr = paramLenPtr;
self->numParams++;
// Define parameter binding for second parameter
self->paramBindings[self->numParams].ioType = SQL_PARAM_INPUT;
self->paramBindings[self->numParams].valueType = SQL_C_DOUBLE;
self->paramBindings[self->numParams].paramType = SQL_DOUBLE;
self->paramBindings[self->numParams].columnSize = 0;
self->paramBindings[self->numParams].decimalDigits = parentType->precision;
dParamValue = (double *)malloc(sizeof(double));
self->paramBindings[self->numParams].paramValuePtr = dParamValue;
self->paramBindings[self->numParams].bufferLength = 0;
paramLenPtr = (SQLLEN *)malloc(sizeof(SQLLEN));
self->paramBindings[self->numParams].indPtr = paramLenPtr;
}
self->numParams++;
return 0;
}
/**
 * Describe the two result columns of the CPT query.
 *
 * Column 0 (the node's value) is fetched into self->strOutValue for
 * enumerated nodes and into self->dOutValue otherwise; dbrv__logP reads
 * these fields after every fetch, so they must be bound here. A NULL value
 * pointer would leave them unwritten. Column 1 (SUM(frequency)) is fetched
 * into self->lastFreq, with its indicator in self->lastFreqID.
 *
 * @param self     object being initialized
 * @param enumNode non-zero if the node's values are enumerated strings
 * @return always 0
 */
char __dbrv__setupColumnBinding(DBRV *self, char enumNode)
{
    if(enumNode)
    {
        self->columnBindings[0].type = SQL_C_CHAR;
        self->columnBindings[0].valuePtr = self->strOutValue;
        self->columnBindings[0].bufferLength = sizeof(self->strOutValue);
    }
    else
    {
        self->columnBindings[0].type = SQL_C_DOUBLE;
        self->columnBindings[0].valuePtr = &self->dOutValue;
        self->columnBindings[0].bufferLength = 0;
    }
    self->columnBindings[0].indPtr = 0;

    self->columnBindings[1].type = SQL_C_LONG;
    self->columnBindings[1].valuePtr = &self->lastFreq;
    self->columnBindings[1].bufferLength = 0;
    self->columnBindings[1].indPtr = &self->lastFreqID;
    return 0;
}
/**
 * Compute this node's conditional distribution given parent evidence.
 *
 * Each evidence item is mapped to a bin of its parent (via dbrv__getState on
 * ev[i].rv). The bin's value is written into the parameter buffers allocated
 * in __dbrv__setupParamBinding: one string for an enumerated parent, or a
 * lower and an upper bound for a ranged parent. The prepared query is then
 * executed, and the returned frequencies are accumulated per bin of this
 * node and normalized.
 *
 * @param ev         Evidence nodes presented in the order they were presented
 *                   in the parentTypeMap parameter when dbrv__init() was called.
 * @param numEvNodes number of entries in ev
 * @param probs      [OUT] self->numStates probabilities. If the query returns
 *                   no frequency mass, every entry is left at 0.
 * @return 0 on success; ERR_GENERAL_ERROR if ev needs more parameters than
 *         were bound; ERR_INVALID_STATE_TYPE if an evidence bin does not match
 *         its binding's type; -1 if an enumerated value does not fit its
 *         buffer; otherwise the error from the state lookup or DBConnector.
 */
char dbrv__logP(DBRV *self, Evidence *ev, size_t numEvNodes, double *probs)
{
    size_t i = 0;
    size_t pBindIndex = 0;
    char error = 0;
    size_t totalFreq = 0;

    // Write the parent values into the existing parameter buffers in place.
    // paramValuePtr must keep pointing at the heap buffers: dbrv__del() frees
    // them. Presumably DBConnector also bound these addresses when the
    // statement was prepared, so re-pointing would not reach the driver.
    for(i = 0; i < numEvNodes; i++)
    {
        const State *binState = 0;
        size_t index = 0;
        size_t offset = 0;

        // Get bin state
        if((error = dbrv__getState(ev[i].rv, ev[i].state, &binState, &index)) != 0)
            return error;

        switch(binState->type)
        {
            case STATE_RANGE:
                if(pBindIndex + 2 > self->numParams)
                    return ERR_GENERAL_ERROR;
                if(self->paramBindings[pBindIndex].valueType != SQL_C_DOUBLE ||
                   self->paramBindings[pBindIndex + 1].valueType != SQL_C_DOUBLE)
                    return ERR_INVALID_STATE_TYPE;
                *(double *)self->paramBindings[pBindIndex].paramValuePtr = binState->value.rangeValue[0];
                *(SQLLEN *)self->paramBindings[pBindIndex].indPtr = (SQLLEN)sizeof(double);
                pBindIndex++;
                *(double *)self->paramBindings[pBindIndex].paramValuePtr = binState->value.rangeValue[1];
                *(SQLLEN *)self->paramBindings[pBindIndex].indPtr = (SQLLEN)sizeof(double);
                pBindIndex++;
                break;
            case STATE_ENUM:
                if(pBindIndex + 1 > self->numParams)
                    return ERR_GENERAL_ERROR;
                if(self->paramBindings[pBindIndex].valueType != SQL_C_CHAR)
                    return ERR_INVALID_STATE_TYPE;
                // Bounded, NUL-terminated copy into the BUFFER_SIZE buffer
                if(overwriteString((char *)self->paramBindings[pBindIndex].paramValuePtr,
                                   BUFFER_SIZE, binState->value.enumValue, &offset))
                    return -1;
                *(SQLLEN *)self->paramBindings[pBindIndex].indPtr = SQL_NTS;
                pBindIndex++;
                break;
            default:
                return ERR_INVALID_STATE_TYPE;
        }
    }

    // Get CPT
    for(i = 0; i < self->numStates; i++)
        probs[i] = 0;
    if((error = DBConnector__executePreparedStatement(&self->super)) != 0)
        return error;
    while(DBConnector__fetchExecutedStatementResult(&self->super, &error))
    {
        const State *binState = 0;
        size_t index = 0;

        // Enumerated nodes fetch into strOutValue, others into dOutValue
        if(self->enumNode)
            error = dbrv__getStateFromEnum(self, self->strOutValue, &binState, &index);
        else
            error = dbrv__getStateFromScalar(self, self->dOutValue, &binState, &index);
        if(error)
            return error;
        probs[index] += self->lastFreq;
        totalFreq += self->lastFreq;
    }
    if(error)
        return error;

    // No matching rows: leave all zeros rather than dividing 0 by 0 (NaN)
    if(totalFreq == 0)
        return 0;
    for(i = 0; i < self->numStates; i++)
        probs[i] /= (double)totalFreq;
    return 0;
}
/**
 * Placeholder for sampling this node given evidence.
 *
 * Not implemented: every argument is ignored and 1 is always returned.
 */
char dbrv__randomSample(DBRV *self, Evidence *ev, size_t numEvNodes, State *sample)
{
    (void)self;
    (void)ev;
    (void)numEvNodes;
    (void)sample;
    return 1; /* not implemented */
}
/**
 * Find the bin of self that corresponds to a given state.
 *
 * A STATE_VALUE is located by scalar containment, a STATE_RANGE by exact
 * bounds, and a STATE_ENUM by string equality.
 *
 * @param self     random variable whose bins are searched
 * @param valState state to look up
 * @param binState [OUT] pointer into self->stateList on success
 * @param i        [OUT] index of the bin
 * @return 0 on success, ERR_NO_BIN_STATE_FOR_VALUE if no bin matches,
 *         ERR_INVALID_STATE_TYPE on a type mismatch.
 */
char dbrv__getState(DBRV *self, const State *valState, const State **binState, size_t *i)
{
    *i = 0;
    switch(valState->type)
    {
        case STATE_VALUE:
            return dbrv__getStateFromScalar(self, valState->value.scalarValue, binState, i);
        case STATE_RANGE:
            // The old loop compared valState with the address of a local copy,
            // which can never be equal, and advanced the pointer i instead of
            // *i. Match on the bounds instead.
            return dbrv__getStateFromRange(self,
                                           valState->value.rangeValue[0],
                                           valState->value.rangeValue[1],
                                           binState, i);
        case STATE_ENUM:
            return dbrv__getStateFromEnum(self, valState->value.enumValue, binState, i);
        default:
            return ERR_INVALID_STATE_TYPE;
    }
    return ERR_GENERAL_ERROR;
}
/**
 * Find the range bin [low, high) of self that contains value.
 *
 * @param self     random variable whose bins are searched; all bins must be
 *                 STATE_RANGE
 * @param value    scalar to locate
 * @param binState [OUT] pointer into self->stateList on success
 * @param i        [OUT] index of the matching bin. On error, the index at
 *                 which the search stopped.
 * @return 0 on success, ERR_INVALID_STATE_TYPE if a non-range bin is
 *         encountered first, ERR_NO_BIN_STATE_FOR_VALUE if none contains value.
 */
char dbrv__getStateFromScalar(DBRV *self, double value, const State **binState, size_t *i)
{
    // (*i)++, not *i++: the latter advances the pointer and leaves the index
    // at 0, so the loop never terminated correctly.
    for(*i = 0; *i < self->numStates; (*i)++)
    {
        // Point into the persistent list; a local copy would leave
        // *binState dangling after return.
        const State *candidate = &self->stateList[*i];
        if(candidate->type != STATE_RANGE)
            return ERR_INVALID_STATE_TYPE;
        if(value >= candidate->value.rangeValue[0] &&
           value < candidate->value.rangeValue[1])
        {
            *binState = candidate;
            return 0;
        }
    }
    return ERR_NO_BIN_STATE_FOR_VALUE;
}
/**
 * Find the range bin of self whose bounds are exactly [low, high).
 *
 * Bounds are compared with ==, so the caller must pass the same double values
 * the bins were defined with.
 *
 * @param binState [OUT] pointer into self->stateList on success
 * @param i        [OUT] index of the matching bin. On error, the index at
 *                 which the search stopped.
 * @return 0 on success, ERR_INVALID_STATE_TYPE if a non-range bin is
 *         encountered first, ERR_NO_BIN_STATE_FOR_VALUE if none matches.
 */
char dbrv__getStateFromRange(DBRV *self, double low, double high, const State **binState, size_t *i)
{
    // (*i)++, not *i++ (which increments the pointer, not the index)
    for(*i = 0; *i < self->numStates; (*i)++)
    {
        // Persistent element, not a local copy, so *binState stays valid
        const State *candidate = &self->stateList[*i];
        if(candidate->type != STATE_RANGE)
            return ERR_INVALID_STATE_TYPE;
        if(low == candidate->value.rangeValue[0] &&
           high == candidate->value.rangeValue[1])
        {
            *binState = candidate;
            return 0;
        }
    }
    return ERR_NO_BIN_STATE_FOR_VALUE;
}
/**
 * Find the enumerated bin of self whose value equals the given string.
 *
 * @param value    NUL-terminated enumerated value
 * @param binState [OUT] pointer into self->stateList on success
 * @param i        [OUT] index of the matching bin. On error, the index at
 *                 which the search stopped.
 * @return 0 on success, ERR_INVALID_STATE_TYPE if a non-enum bin is
 *         encountered first, ERR_NO_BIN_STATE_FOR_VALUE if none matches.
 */
char dbrv__getStateFromEnum(DBRV *self, const char value[MAX_SQL_ID_SIZE], const State **binState, size_t *i)
{
    // (*i)++, not *i++ (which increments the pointer, not the index)
    for(*i = 0; *i < self->numStates; (*i)++)
    {
        // Persistent element, not a local copy, so *binState stays valid
        const State *candidate = &self->stateList[*i];
        if(candidate->type != STATE_ENUM)
            return ERR_INVALID_STATE_TYPE;
        if(strcmp(value, candidate->value.enumValue) == 0)
        {
            *binState = candidate;
            return 0;
        }
    }
    return ERR_NO_BIN_STATE_FOR_VALUE;
}
/**
 * Disconnect from the database.
 *
 * @param self started random variable
 * @return 0 on success, otherwise the error from DBConnector__disconnect.
 */
char dbrv__stop(DBRV *self)
{
    return DBConnector__disconnect(&self->super);
}
/**
 * Release everything allocated by dbrv__init() and clean up the DBConnector
 * base object.
 *
 * Freed pointers are reset to NULL and the counts to 0. A second call
 * therefore frees nothing twice (free(NULL) is a no-op), and nothing can use
 * the released buffers afterwards.
 *
 * @param self object to tear down
 * @return 0 on success, otherwise the error from DBConnector__del.
 */
char dbrv__del(DBRV *self)
{
    size_t i = 0;

    // Deallocate states
    free(self->stateList);
    self->stateList = NULL;
    self->numStates = 0;

    // Deallocate parameter values and indicator buffers
    for(i = 0; i < self->numParams; i++)
    {
        free(self->paramBindings[i].paramValuePtr);
        self->paramBindings[i].paramValuePtr = NULL;
        free(self->paramBindings[i].indPtr);
        self->paramBindings[i].indPtr = NULL;
    }
    self->numParams = 0;

    // Clean up parent class
    return DBConnector__del(&self->super);
}
/**
 * Append strSrc to strDest at position *offset, keeping strDest
 * NUL-terminated.
 *
 * @param strDest     destination buffer of numElements chars
 * @param numElements capacity of strDest, including the terminator
 * @param strSrc      NUL-terminated string to append
 * @param offset      [IN/OUT] write position; advanced by strlen(strSrc) on
 *                    success
 * @return 0 on success; -1 if the text plus its terminator would not fit, in
 *         which case neither strDest nor *offset is modified.
 */
char overwriteString(char *strDest, size_t numElements, const char *strSrc, size_t *offset)
{
    const size_t start = *offset;
    const size_t len = strlen(strSrc);
    char *dst;
    const char *src;
    const char *end;

    /* Reject the write if the text and its terminator would overrun */
    if(len + start >= numElements)
        return -1;

    dst = strDest + start;
    src = strSrc;
    end = strSrc + len;
    while(src != end)
        *dst++ = *src++;
    *dst = '\0';

    *offset = start + len;
    return 0;
}