/* ============================================================================ Name : dbrv.c Author : Stephen Cannon Version : 0.1 Copyright : Copyright 2011 Stephen Cannon Description : ============================================================================ * * This file is part of LikelihoodWeighting. * * LikelihoodWeighting is free software: you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as published by * the Free Software Foundation, either version 3 of the License, or (at your * option) any later version. * * LikelihoodWeighting is distributed in the hope that it will be useful, but * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public * License for more details. * * You should have received a copy of the GNU Lesser General Public License * along with LikelihoodWeighting. If not, see . * */ #include #include #include #include #include "dbrv.h" #include "DBConnector.h" char dbrv__init(DBRV *self, const char *modelName, const char *nodeName, char enumNode, State *stateList, size_t numStates, ParentTypePair *parentTypeMap, size_t numParents, const char *connectionString) { size_t i = 0; size_t c = 0; char error = 0; // Initialize attibutes self->numParams = 0; self->strOutValue[0] = 0; self->enumNode = enumNode; // Define node states self->numStates = numStates; self->stateList = (State *)calloc(numStates, sizeof(State)); for(i = 0; i < numStates; i++) { self->stateList[i].type = stateList[i].type; switch(stateList[i].type) { case STATE_VALUE: self->stateList[i].value.scalarValue = stateList[i].value.scalarValue; break; case STATE_RANGE: self->stateList[i].value.rangeValue[0] = stateList[i].value.rangeValue[0]; self->stateList[i].value.rangeValue[1] = stateList[i].value.rangeValue[1]; break; case STATE_ENUM: for(c = 0 ; c < MAX_SQL_ID_SIZE && stateList[i].value.enumValue[c] != 0; c++) self->stateList[i].value.enumValue[c] = stateList[i].value.enumValue[c]; self->stateList[i].value.enumValue[c] = 0; break; default: return ERR_INVALID_STATE_TYPE; } } // Initialize query SQL string self->querySQL[0] = 0; self->querySQL[LARGE_BUFFER_SIZE - 1] = 0; // Get node name self->nodeName = nodeName; // Get model name self->modelName = modelName; // Generate parent class DBConnector if(error = DBConnector__init(&self->super, connectionString)) return error; // Build query SQL if(error = __dbrv__prepareQuerySQL(self, enumNode, parentTypeMap, numParents)) return error; return error; } /** * @param parentTypeMap {parent1:(1, 0, &pVal1, 1024, &pValLen1), parent2:(0, 10, &pVal2, 0, &pValLen2), ...} */ char dbrv__start(DBRV *self) { char error = 0; // Connect to DB if(error = DBConnector__connect(&self->super)) return error; // Prepare statement if(error = DBConnector__prepareStatement(&self->super, self->querySQL, self->paramBindings, self->numParams, self->columnBindings, 2)) return error; return error; } /** * @param parentTypeMap {parent1:(1, 0, &pVal1, 1024, &pValLen1), parent2:(0, 10, &pVal2, 0, &pValLen2), ...} */ char __dbrv__prepareQuerySQL(DBRV *self, char enumNode, ParentTypePair *parentTypeMap, size_t numParents) { char error = 0; size_t offset = 0; size_t i = 0; // Verify fewer than 32 parents if(numParents > 32) return ERR_TOO_MANY_PARENTS; // Start SELECT statement if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "SELECT `", &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, self->nodeName, &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "`, SUM(frequency) as frequency FROM ", &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, self->modelName, &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "_", &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, self->nodeName, &offset)) return -1; // Setup WHERE clause if(numParents > 0) if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, " WHERE `", &offset)) return -1; for(i = 0; i < numParents; i++) { const char *pNodeName = 0; // Get parent node name pNodeName = parentTypeMap[i].parent; // Get parent value if(parentTypeMap[i].enumeratedType) { if(i > 0) if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, ", ", &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, pNodeName, &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "` = ?", &offset)) return -1; } else { if(i > 0) if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, ", ", &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, pNodeName, &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "` >= ? AND `", &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, pNodeName, &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "` < ?", &offset)) return -1; } // bind parameter if(__dbrv__setupParamBinding(self, pNodeName, &parentTypeMap[i])) return -1; } // Setup GROUP BY clause if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, " GROUP BY `", &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, self->nodeName, &offset)) return -1; if(overwriteString(self->querySQL, LARGE_BUFFER_SIZE, "`;", &offset)) return -1; // Define column bindings if(__dbrv__setupColumnBinding(self, enumNode)) return -1; return 0; } char __dbrv__setupParamBinding(DBRV *self, const char *parentName, ParentTypePair *parentType) { char *strParamValue = 0; double *dParamValue = 0; SQLLEN *paramLenPtr = 0; if(parentType->enumeratedType) { // Define parameter binding self->paramBindings[self->numParams].ioType = SQL_PARAM_INPUT; self->paramBindings[self->numParams].valueType = SQL_C_CHAR; self->paramBindings[self->numParams].paramType = SQL_CHAR; self->paramBindings[self->numParams].columnSize = 0; self->paramBindings[self->numParams].decimalDigits = parentType->precision; strParamValue = (char *)malloc(sizeof(char)*BUFFER_SIZE); self->paramBindings[self->numParams].paramValuePtr = strParamValue; self->paramBindings[self->numParams].bufferLength = BUFFER_SIZE; paramLenPtr = (SQLLEN *)malloc(sizeof(SQLLEN)); self->paramBindings[self->numParams].indPtr = paramLenPtr; } else { // Define parameter binding self->paramBindings[self->numParams].ioType = SQL_PARAM_INPUT; self->paramBindings[self->numParams].valueType = SQL_C_DOUBLE; self->paramBindings[self->numParams].paramType = SQL_DOUBLE; self->paramBindings[self->numParams].columnSize = 0; self->paramBindings[self->numParams].decimalDigits = parentType->precision; dParamValue = (double *)malloc(sizeof(double)); self->paramBindings[self->numParams].paramValuePtr = dParamValue; self->paramBindings[self->numParams].bufferLength = 0; paramLenPtr = (SQLLEN *)malloc(sizeof(SQLLEN)); self->paramBindings[self->numParams].indPtr = paramLenPtr; self->numParams++; // Define parameter binding for second parameter self->paramBindings[self->numParams].ioType = SQL_PARAM_INPUT; self->paramBindings[self->numParams].valueType = SQL_C_DOUBLE; self->paramBindings[self->numParams].paramType = SQL_DOUBLE; self->paramBindings[self->numParams].columnSize = 0; self->paramBindings[self->numParams].decimalDigits = parentType->precision; dParamValue = (double *)malloc(sizeof(double)); self->paramBindings[self->numParams].paramValuePtr = dParamValue; self->paramBindings[self->numParams].bufferLength = 0; paramLenPtr = (SQLLEN *)malloc(sizeof(SQLLEN)); self->paramBindings[self->numParams].indPtr = paramLenPtr; } self->numParams++; return 0; } char __dbrv__setupColumnBinding(DBRV *self, char enumNode) { if(enumNode) { self->columnBindings[0].type = SQL_C_CHAR; } else { self->columnBindings[0].type = SQL_C_DOUBLE; } self->columnBindings[0].valuePtr = 0; self->columnBindings[0].bufferLength = 0; self->columnBindings[0].indPtr = 0; self->columnBindings[1].type = SQL_C_LONG; self->columnBindings[1].valuePtr = &self->lastFreq; self->columnBindings[1].bufferLength = 0; self->columnBindings[1].indPtr = &self->lastFreqID; return 0; } /** * @param ev Evidence nodes presented in the order they were presented in the * parentTypeMap parameter when the start() method was called. * @param numEvNodes * @param probs [OUT] */ char dbrv__logP(DBRV *self, Evidence *ev, size_t numEvNodes, double *probs) { size_t i = 0; size_t pBindIndex = 0; char error = 0; size_t totalFreq = 0; // Define parent values for(i = 0; i < numEvNodes; i++) { const State *binState; size_t index; // Get bin state if(error = dbrv__getState(ev[i].rv, ev[i].state, &binState, &index)) return error; // Bind parameter values switch(binState->type) { case STATE_RANGE: self->paramBindings[pBindIndex++].paramValuePtr = (SQLPOINTER)&binState->value.rangeValue[0]; self->paramBindings[pBindIndex++].paramValuePtr = (SQLPOINTER)&binState->value.rangeValue[1]; break; case STATE_ENUM: self->paramBindings[pBindIndex++].paramValuePtr = (SQLPOINTER)&binState->value.enumValue; break; default: return ERR_INVALID_STATE_TYPE; break; } } // Get CPT for(i = 0; i < self->numStates; i++) probs[i] = 0; if(error = DBConnector__executePreparedStatement(&self->super)) return error; while(DBConnector__fetchExecutedStatementResult(&self->super, &error)) { const State *binState; size_t index; if(error = dbrv__getStateFromScalar(self, self->dOutValue, &binState, &index)) return error; probs[index] += self->lastFreq; totalFreq += self->lastFreq; } if(error) return error; for(i = 0; i < self->numStates; i++) probs[i] /= totalFreq; return error; } char dbrv__randomSample(DBRV *self, Evidence *ev, size_t numEvNodes, State *sample) { return 1; } char dbrv__getState(DBRV *self, const State *valState, const State **binState, size_t *i) { *i = 0; switch(valState->type) { case STATE_VALUE: return dbrv__getStateFromScalar(self, valState->value.scalarValue, binState, i); break; case STATE_RANGE: for(*i = 0; *i < self->numStates; *i++) { const State presentBinState = self->stateList[*i]; if(presentBinState.type != STATE_RANGE) return ERR_INVALID_STATE_TYPE; if(valState == &presentBinState) { *binState = &presentBinState; return 0; } } return ERR_NO_BIN_STATE_FOR_VALUE; break; case STATE_ENUM: return dbrv__getStateFromEnum(self, valState->value.enumValue, binState, i); break; default: return ERR_INVALID_STATE_TYPE; break; } return ERR_GENERAL_ERROR; } char dbrv__getStateFromScalar(DBRV *self, double value, const State **binState, size_t *i) { *i = 0; for(*i = 0; *i < self->numStates; *i++) { const State presentBinState = self->stateList[*i]; if(presentBinState.type != STATE_RANGE) return ERR_INVALID_STATE_TYPE; if(value >= presentBinState.value.rangeValue[0] && value < presentBinState.value.rangeValue[1]) { *binState = &presentBinState; return 0; } } return ERR_NO_BIN_STATE_FOR_VALUE; } char dbrv__getStateFromRange(DBRV *self, double low, double high, const State **binState, size_t *i) { *i = 0; for(*i = 0; *i < self->numStates; *i++) { const State presentBinState = self->stateList[*i]; if(presentBinState.type != STATE_RANGE) return ERR_INVALID_STATE_TYPE; if(low == presentBinState.value.rangeValue[0] && high == presentBinState.value.rangeValue[1]) { *binState = &presentBinState; return 0; } } return ERR_NO_BIN_STATE_FOR_VALUE; } char dbrv__getStateFromEnum(DBRV *self, const char value[MAX_SQL_ID_SIZE], const State **binState, size_t *i) { *i = 0; for(*i = 0; *i < self->numStates; *i++) { const State presentBinState = self->stateList[*i]; if(presentBinState.type != STATE_ENUM) return ERR_INVALID_STATE_TYPE; if(strcmp(value, presentBinState.value.enumValue) == 0) { *binState = &presentBinState; return 0; } } return ERR_NO_BIN_STATE_FOR_VALUE; } char dbrv__stop(DBRV *self) { char error = 0; if(error = DBConnector__disconnect(&self->super)) return error; return error; } char dbrv__del(DBRV *self) { char error = 0; size_t i = 0; // Deallocate states free(self->stateList); // Deallocate parameter values and length pointers for(i = 0; i < self->numParams; i++) { free(self->paramBindings[i].paramValuePtr); free(self->paramBindings[i].indPtr); } // Clean up parent class if(error = DBConnector__del(&self->super)) return error; return error; } char overwriteString(char *strDest, size_t numElements, const char *strSrc, size_t *offset) { size_t srcBufferLen; size_t i; // Verify there won't be a buffer overrun srcBufferLen = strlen(strSrc); if(numElements <= srcBufferLen + *offset) return -1; // Overwrite destination string with source string for (i = 0; i < srcBufferLen; i++) strDest[i + *offset] = strSrc[i]; // Update offset *offset += srcBufferLen; strDest[*offset] = 0; return 0; }