diff --git a/svm.c b/svm.c index 983084a..b67cbe8 100644 --- a/svm.c +++ b/svm.c @@ -970,21 +970,23 @@ PHP_METHOD(svmmodel, predict) } /* }}} */ -/** {{{ SvmModel::predict_probability(array data) +/** {{{ SvmModel::predict_probability(array data, array probabilities) Predicts based on the model */ PHP_METHOD(svmmodel, predict_probability) { php_svm_model_object *intern; double predict_probability; - int nr_classes; + int nr_classes, i; double *estimates; struct svm_node *x; int max_nr_attr = 64; - zval *arr; + int *labels; + zval *arr; + zval *retarr = NULL; /* we want an array of data to be passed in */ - if (zend_parse_parameters(ZEND_NUM_ARGS() TSRMLS_CC, "a", &arr) == FAILURE) { + if (zend_parse_parameters(ZEND_NUM_ARGS() TSRMLS_CC, "az", &arr, &retarr) == FAILURE) { return; } @@ -996,9 +998,20 @@ PHP_METHOD(svmmodel, predict_probability) x = php_svm_get_data_from_array(arr TSRMLS_CC); nr_classes = svm_get_nr_class(intern->model); estimates = safe_emalloc(nr_classes, sizeof(double), 0); + labels = safe_emalloc(nr_classes, sizeof(int), 0); predict_probability = svm_predict_probability(intern->model, x, estimates); - predict_probability = estimates[(int)predict_probability]; + + if (retarr != NULL) { + zval_dtor(retarr); + array_init(retarr); + svm_get_labels(intern->model, labels); + for (i = 0; i < nr_classes; ++i) { + add_index_double(retarr, labels[i], estimates[i]); + } + } + efree(estimates); + efree(labels); efree(x); RETURN_DOUBLE(predict_probability); diff --git a/tests/014_predict_probability.phpt b/tests/014_predict_probability.phpt index 9837679..7630bdf 100644 --- a/tests/014_predict_probability.phpt +++ b/tests/014_predict_probability.phpt @@ -11,7 +11,7 @@ $svm->setOptions(array( SVM::OPT_TYPE => SVM::C_SVC, SVM::OPT_KERNEL_TYPE => SVM::KERNEL_LINEAR, SVM::OPT_P => 0.1, // epsilon 0.1 - SVM::OPT_PROBABILITY => true + SVM::OPT_PROBABILITY => 1 )); $model = $svm->train(dirname(__FILE__) . '/abalone.scale'); @@ -27,8 +27,12 @@ if($model) { 8 => -0.704036 ); $class = $model->predict($data); - $result = $model->predict_probability($data); - if($class == 9 && $result > 0) { + $return = array(); + $result = $model->predict_probability($data, $return); + arsort($return); + reset($return); + $key = key($return); + if($class == 9 && $key == 9) { echo "ok"; } else { echo "predict failed: $class $result";