From f38c08a5b555a846a0d4c224f617b08ffd86a28d Mon Sep 17 00:00:00 2001 From: Benjamin Franzke Date: Fri, 11 May 2012 11:26:30 +0200 Subject: Store all non-emta prediction methods in a cell-array --- run_tests.m | 55 +++++++++++++++++++++++++------------------------------ 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/run_tests.m b/run_tests.m index e89a295..4e6b1d5 100644 --- a/run_tests.m +++ b/run_tests.m @@ -6,44 +6,39 @@ train_data = q( 1:28, :); real_data = q(29:42, :); -mean_data = mean_pred(p, train_data); -regress_data = regress_pred(p, train_data); -quad_data = quad_regress_pred(p, train_data); -log_data = log_regress_pred(p, train_data); -sevenday_data = sevenday_pred(p, train_data); -random_data = rand_pred(p, train_data); -regress2_data = regress_frequency_removal(p, train_data); +pred_methods = { + { 'mean', @mean_pred }, + { 'regress', @regress_pred }, + { 'quad', @quad_regress_pred }, + { 'log reg.', @log_regress_pred }, + { 'sevenday' @sevenday_pred }, + { 'random' @rand_pred }, + { 'regress2' @regress_frequency_removal } +}; +num_methods = size(pred_methods, 1); + +pred_list = {}; +for i = 1:num_methods + pred_list{i} = pred_methods{i}{2}(p, train_data); +end -pred_list = {mean_data regress_data quad_data log_data sevenday_data random_data regress2_data}; +qerr = zeros(1, num_methods); +terr = zeros(1, num_methods); +err = zeros(num_methods, size(real_data,2)); +for i = 1:num_methods + [qerr(i), terr(i)] = calc_error(pred_methods{i}{1}, real_data, pred_list{i}); + err(i, :) = sum(abs(real_data - pred_list{i})); +end opt_data = opt_pred(real_data, pred_list); - -% plot prediction quallity -[meqerr, meterr] = calc_error('mean', real_data, mean_data); -[reqerr, reterr] = calc_error('regress', real_data, regress_data); -% quadratic just for reference, it sucks more than mean-predicition -[quqerr, quterr] = calc_error('quad reg.',real_data, quad_data); -[loqerr, loterr] = calc_error('log reg.', real_data, log_data); -[seqerr, seterr] = calc_error('sevenday', real_data, sevenday_data); -[raqerr, raterr] = calc_error('random', real_data, random_data); [opqerr, opterr] = calc_error('optimize', real_data, opt_data); -[re2qerr, re2terr] = calc_error('regress2', real_data, regress2_data); +opt_err = sum(abs(real_data - opt_data)); +qerr = [ qerr opqerr ]; +terr = [ terr opterr ]; -qerr = [meqerr reqerr quqerr loqerr seqerr raqerr opqerr re2qerr]; -terr = [meterr reterr quterr loterr seterr raterr opterr re2terr]; bar(qerr); bar(terr); -mean_err = sum(abs(real_data - mean_data)); -regress_err = sum(abs(real_data - regress_data)); -quad_err = sum(abs(real_data - quad_data)); -log_err = sum(abs(real_data - log_data)); -sevenday_err = sum(abs(real_data - sevenday_data)); -random_err = sum(abs(real_data - random_data)); -opt_err = sum(abs(real_data - opt_data)); -regress2_err = sum(abs(real_data - regress2_data)); - -err = [mean_err;regress_err;quad_err;log_err;sevenday_err;random_err;regress2_err]; [min_err, err_idx] = min(err); printf('global min. error: %d\n', sum(min_err)); printf('local min count:'); -- cgit