summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--run_tests.m55
1 files changed, 25 insertions, 30 deletions
diff --git a/run_tests.m b/run_tests.m
index e89a295..4e6b1d5 100644
--- a/run_tests.m
+++ b/run_tests.m
@@ -6,44 +6,39 @@
train_data = q( 1:28, :);
real_data = q(29:42, :);
-mean_data = mean_pred(p, train_data);
-regress_data = regress_pred(p, train_data);
-quad_data = quad_regress_pred(p, train_data);
-log_data = log_regress_pred(p, train_data);
-sevenday_data = sevenday_pred(p, train_data);
-random_data = rand_pred(p, train_data);
-regress2_data = regress_frequency_removal(p, train_data);
+pred_methods = {
+ { 'mean', @mean_pred },
+ { 'regress', @regress_pred },
+ { 'quad', @quad_regress_pred },
+ { 'log reg.', @log_regress_pred },
+ { 'sevenday' @sevenday_pred },
+ { 'random' @rand_pred },
+ { 'regress2' @regress_frequency_removal }
+};
+num_methods = size(pred_methods, 1);
+
+pred_list = {};
+for i = 1:num_methods
+ pred_list{i} = pred_methods{i}{2}(p, train_data);
+end
-pred_list = {mean_data regress_data quad_data log_data sevenday_data random_data regress2_data};
+qerr = zeros(1, num_methods);
+terr = zeros(1, num_methods);
+err = zeros(num_methods, size(real_data,2));
+for i = 1:num_methods
+ [qerr(i), terr(i)] = calc_error(pred_methods{i}{1}, real_data, pred_list{i});
+ err(i, :) = sum(abs(real_data - pred_list{i}));
+end
opt_data = opt_pred(real_data, pred_list);
-
-% plot prediction quallity
-[meqerr, meterr] = calc_error('mean', real_data, mean_data);
-[reqerr, reterr] = calc_error('regress', real_data, regress_data);
-% quadratic just for reference, it sucks more than mean-predicition
-[quqerr, quterr] = calc_error('quad reg.',real_data, quad_data);
-[loqerr, loterr] = calc_error('log reg.', real_data, log_data);
-[seqerr, seterr] = calc_error('sevenday', real_data, sevenday_data);
-[raqerr, raterr] = calc_error('random', real_data, random_data);
[opqerr, opterr] = calc_error('optimize', real_data, opt_data);
-[re2qerr, re2terr] = calc_error('regress2', real_data, regress2_data);
+opt_err = sum(abs(real_data - opt_data));
+qerr = [ qerr opqerr ];
+terr = [ terr opterr ];
-qerr = [meqerr reqerr quqerr loqerr seqerr raqerr opqerr re2qerr];
-terr = [meterr reterr quterr loterr seterr raterr opterr re2terr];
bar(qerr);
bar(terr);
-mean_err = sum(abs(real_data - mean_data));
-regress_err = sum(abs(real_data - regress_data));
-quad_err = sum(abs(real_data - quad_data));
-log_err = sum(abs(real_data - log_data));
-sevenday_err = sum(abs(real_data - sevenday_data));
-random_err = sum(abs(real_data - random_data));
-opt_err = sum(abs(real_data - opt_data));
-regress2_err = sum(abs(real_data - regress2_data));
-
-err = [mean_err;regress_err;quad_err;log_err;sevenday_err;random_err;regress2_err];
[min_err, err_idx] = min(err);
printf('global min. error: %d\n', sum(min_err));
printf('local min count:');