Fixing predict probabilities

This commit is contained in:
Ian Barber
2011-12-18 00:53:01 +00:00
parent 677e2f827b
commit 23091cd976
2 changed files with 25 additions and 8 deletions

23
svm.c
View File

@@ -970,21 +970,23 @@ PHP_METHOD(svmmodel, predict)
}
/* }}} */
/** {{{ SvmModel::predict_probability(array data)
/** {{{ SvmModel::predict_probability(array data, array probabilities)
Predicts based on the model
*/
PHP_METHOD(svmmodel, predict_probability)
{
php_svm_model_object *intern;
double predict_probability;
int nr_classes;
int nr_classes, i;
double *estimates;
struct svm_node *x;
int max_nr_attr = 64;
zval *arr;
int *labels;
zval *arr;
zval *retarr = NULL;
/* we want an array of data to be passed in */
if (zend_parse_parameters(ZEND_NUM_ARGS() TSRMLS_CC, "a", &arr) == FAILURE) {
if (zend_parse_parameters(ZEND_NUM_ARGS() TSRMLS_CC, "az", &arr, &retarr) == FAILURE) {
return;
}
@@ -996,9 +998,20 @@ PHP_METHOD(svmmodel, predict_probability)
x = php_svm_get_data_from_array(arr TSRMLS_CC);
nr_classes = svm_get_nr_class(intern->model);
estimates = safe_emalloc(nr_classes, sizeof(double), 0);
labels = safe_emalloc(nr_classes, sizeof(int), 0);
predict_probability = svm_predict_probability(intern->model, x, estimates);
predict_probability = estimates[(int)predict_probability];
if (retarr != NULL) {
zval_dtor(retarr);
array_init(retarr);
svm_get_labels(intern->model, labels);
for (i = 0; i < nr_classes; ++i) {
add_index_double(retarr, labels[i], estimates[i]);
}
}
efree(estimates);
efree(labels);
efree(x);
RETURN_DOUBLE(predict_probability);

View File

@@ -11,7 +11,7 @@ $svm->setOptions(array(
SVM::OPT_TYPE => SVM::C_SVC,
SVM::OPT_KERNEL_TYPE => SVM::KERNEL_LINEAR,
SVM::OPT_P => 0.1, // epsilon 0.1
SVM::OPT_PROBABILITY => true
SVM::OPT_PROBABILITY => 1
));
$model = $svm->train(dirname(__FILE__) . '/abalone.scale');
@@ -27,8 +27,12 @@ if($model) {
8 => -0.704036
);
$class = $model->predict($data);
$result = $model->predict_probability($data);
if($class == 9 && $result > 0) {
$return = array();
$result = $model->predict_probability($data, $return);
arsort($return);
reset($return);
$key = key($return);
if($class == 9 && $key == 9) {
echo "ok";
} else {
echo "predict failed: $class $result";