Skip to content
Snippets Groups Projects
Commit 1a527826 authored by Nikolaj Tatti's avatar Nikolaj Tatti
Browse files

initial version

parents
No related branches found
No related tags found
No related merge requests found
SRCS = asso.cpp dataset.cpp
OBJS = $(SRCS:.cpp=.o)
CC = g++
LDFLAGS= -lm -O3
CFLAGS= -g -O3 -Wall -Wextra #-Werror
all : asso
asso: $(OBJS)
$(CC) $(LDFLAGS) -o $@ $(OBJS)
%.o : %.cpp
$(CC) $(CFLAGS) -c $*.cpp
clean:
-rm *.o asso
zip:
zip asso.zip *.cpp *.h Makefile
depend:
makedepend -Y -- $(CFLAGS) -- $(SRCS)
# DO NOT DELETE
asso.o: defines.h
dataset.o: dataset.h defines.h
#include "dataset.h"
#include <algorithm>
#include <stdio.h>
#include <getopt.h>
#include <math.h>
template<class InputIt1, class InputIt2>
uint32_t intersection_cnt(InputIt1 first1, InputIt1 last1,
InputIt2 first2, InputIt2 last2)
{
uint32_t cnt = 0;
while (first1 != last1 && first2 != last2) {
if (*first1 < *first2) {
++first1;
} else {
if (!(*first2 < *first1)) {
cnt++;
++first1;
}
++first2;
}
}
return cnt;
}
void
index_union(const uintvector & a, const uintvector & b, uintvector & out)
{
uint32_t size = a.size() + b.size() - intersection_cnt(a.begin(), a.end(), b.begin(), b.end());
out.resize(0);
out.reserve(size);
std::set_union(a.begin(), a.end(), b.begin(), b.end(), std::back_inserter(out));
}
void
find_basis(const dataset & d, double threshold, dataset & basis)
{
uint32_t dim = d.cols.size();
basis.cols.resize(dim);
basis.rowcnt = dim;
for (uint32_t i = 0; i < dim; i++) {
doublevector c(dim);
for (uint32_t j = 0; j < dim; j++) {
c[j] = intersection_cnt(d.cols[i].begin(), d.cols[i].end(), d.cols[j].begin(), d.cols[j].end());
c[j] /= d.cols[i].size();
}
uint32_t size = 0;
for (uint32_t j = 0; j < dim; j++) {
if (c[j] >= threshold) size++;
}
basis.cols[i].reserve(size);
for (uint32_t j = 0; j < dim; j++) {
if (c[j] >= threshold) {
basis.cols[i].push_back(j);
}
}
}
}
void
count_column(const uintvector & dcol, const uintvector & ccol, double weight, doublevector & score)
{
uint32_t ind1 = 0;
uint32_t ind2 = 0;
while (ind1 < ccol.size() && ind2 < dcol.size()) {
if (ccol[ind1] <= dcol[ind2]) {
if (ccol[ind1] == dcol[ind2]) ind2++;
score[ccol[ind1]] += 1;
ind1++;
//printf("foo\n");
}
else {
score[dcol[ind2]] += weight + 1;
ind2++;
}
}
for (;ind1 < ccol.size(); ind1++) {
score[ccol[ind1]] += 1;
}
for (;ind2 < dcol.size(); ind2++) {
score[dcol[ind2]] += weight + 1;
}
}
void
cover(const dataset & d, const dataset & basis, double weight, uint32_t k, uintvector & selected, dataset & decomp, uintvector & counts)
{
decomp.cols.resize(k);
decomp.rowcnt = d.rowcnt;
selected.resize(k);
dataset covered;
covered.cols.resize(d.cols.size());
counts.resize(4);
for (uint32_t b = 0; b < k; b++) {
doublevector best_score;
double best_gain = -1;
uint32_t best = -1;
for (uint32_t i = 0; i < basis.cols.size(); i++) {
doublevector cand_score(d.rowcnt, -double(basis.cols[i].size()));
//printf("%d\n", i);
for (uint32_t j = 0; j < basis.cols[i].size(); j++) {
uint32_t c = basis.cols[i][j];
const uintvector & dcol = d.cols[c];
const uintvector & ccol = covered.cols[c];
count_column(dcol, ccol, weight, cand_score);
}
double cand_gain = 0;
for (uint32_t j = 0; j < cand_score.size(); j++) {
if (cand_score[j] > 0) cand_gain += cand_score[j];
}
if (cand_gain > best_gain) {
best_score.swap(cand_score);
best_gain = cand_gain;
best = i;
}
//printf("%f %d\n", cand_gain, best);
}
//printf("BEST: %d\n", best);
selected[b] = best;
uint32_t size = 0;
for (uint32_t i = 0; i < best_score.size(); i++) {
//printf("%f ", best_score[i]);
if (best_score[i] > 0) size++;
}
decomp.cols[b].resize(0);
decomp.cols[b].reserve(size);
for (uint32_t i = 0; i < best_score.size(); i++) {
if (best_score[i] > 0) decomp.cols[b].push_back(i);
}
for (uint32_t i = 0; i < basis.cols[best].size(); i++) {
uint32_t c = basis.cols[best][i];
uintvector out;
index_union(decomp.cols[b], covered.cols[c], out);
covered.cols[c].swap(out);
//printf("%d %d\n", c, covered.cols[c].size());
}
}
// Compute stats
for (uint32_t i = 0; i < d.cols.size(); i++) {
uint32_t size = intersection_cnt(d.cols[i].begin(), d.cols[i].end(), covered.cols[i].begin(), covered.cols[i].end());
counts[3] += size;
counts[2] += covered.cols[i].size() - size;
counts[1] += d.cols[i].size() - size;
counts[0] += d.rowcnt - covered.cols[i].size() - d.cols[i].size() + size;
}
}
void
read_sparse(FILE *f, dataset & d)
{
uint32_t dim;
fscanf(f, "%d%d", &d.rowcnt, &dim);
d.cols.resize(dim);
uintvector counts(dim);
for (uint32_t i = 0; i < d.rowcnt; i++) {
uint32_t cnt;
fscanf(f, "%d", &cnt);
for (uint32_t j = 0; j < cnt; j++) {
uint32_t ind;
fscanf(f, "%d", &ind);
counts[ind - 1]++;
}
}
for (uint32_t i = 0; i < dim; i++)
d.cols[i].reserve(counts[i]);
rewind(f);
fscanf(f, "%*d%*d");
for (uint32_t i = 0; i < d.rowcnt; i++) {
uint32_t cnt;
fscanf(f, "%d", &cnt);
for (uint32_t j = 0; j < cnt; j++) {
uint32_t ind;
fscanf(f, "%d", &ind);
d.cols[ind - 1].push_back(i);
}
}
}
void
print_sparse(FILE *f, const dataset & d)
{
fprintf(f, "%lu\n%d\n", d.cols.size(), d.rowcnt);
for (uint32_t i = 0; i < d.cols.size(); i++) {
fprintf(f, "%lu", d.cols[i].size());
for (uint32_t j = 0; j < d.cols[i].size(); j++) {
fprintf(f, " %d", 1 + d.cols[i][j]);
}
fprintf(f, "\n");
}
}
void
print_sparse(FILE *f, const dataset & d, const uintvector & selected)
{
fprintf(f, "%lu\n%d\n", selected.size(), d.rowcnt);
for (uint32_t i = 0; i < selected.size(); i++) {
fprintf(f, "%lu", d.cols[selected[i]].size());
for (uint32_t j = 0; j < d.cols[selected[i]].size(); j++) {
fprintf(f, " %d", 1 + d.cols[selected[i]][j]);
}
fprintf(f, "\n");
}
}
int
main(int argc, char **argv)
{
static struct option longopts[] = {
{"threshold", required_argument, NULL, 't'},
{"weight", required_argument, NULL, 'w'},
{"out", required_argument, NULL, 'o'},
{"in", required_argument, NULL, 'i'},
{"order", required_argument, NULL, 'k'},
{"help", no_argument, NULL, 'h'},
{ NULL, 0, NULL, 0 }
};
char *inname = NULL;
char *outname = NULL;
double threshold = 1;
double weight = 1;
uint32_t k = 1;
int ch;
while ((ch = getopt_long(argc, argv, "ho:i:w:t:k:", longopts, NULL)) != -1) {
switch (ch) {
case 'h':
printf("Usage: %s -i <input file> -o <output file> [-k <order>] [-t <threshold>] [-w <weight>] [-h]\n", argv[0]);
printf(" -h print this help\n");
printf(" -i input file\n");
printf(" -o output file\n");
printf(" -w weight for covering 1 (default = 1)\n");
printf(" -k decomposition order(default = 1)\n");
printf(" -t threshold for combining columns when searching for candidate basis (default = 1)\n");
return 0;
break;
case 'i':
inname = optarg;
break;
case 'o':
outname = optarg;
break;
case 'w':
weight = atof(optarg);
break;
case 'k':
k = atoi(optarg);
break;
case 't':
threshold = atof(optarg);
break;
}
}
dataset d;
FILE *f = stdin;
if (inname) f = fopen(inname, "r");
read_sparse(f, d);
if (inname) fclose(f);
dataset basis;
find_basis(d, threshold, basis);
dataset decomp;
uintvector selected;
uintvector counts;
cover(d, basis, weight, k, selected, decomp, counts);
//print_sparse(stdout, basis);
FILE *out = stdout;
if (outname) out = fopen(outname, "w");
fprintf(out, "%d %d %d %d\n", counts[0], counts[1], counts[2], counts[3]);
print_sparse(out, basis, selected);
print_sparse(out, decomp);
if (outname) fclose(out);
return 0;
}
#include "dataset.h"
#ifndef DATASET_H
#define DATASET_H
#include "defines.h"
struct dataset {
uintmatrix cols;
uint32_t rowcnt;
};
#endif
#ifndef DEFINES_H
#define DEFINES_H
#include <stdint.h>
#include <vector>
typedef std::vector<uint32_t> uintvector;
typedef std::vector<double> doublevector;
typedef std::vector<uintvector> uintmatrix;
#endif
import sys
import subprocess
import argparse
from math import log, sqrt
def bernllh(c1, c2):
if c1 == 0 or c2 == 0:
return 0.0
p = c2 / float(c1 + c2)
return -c1*log(1 - p) - c2*log(p)
def cost(counts):
return bernllh(counts[0], counts[1]) + bernllh(counts[2], counts[3])
querycnt = 0
def query(inname, k, weight):
global querycnt
querycnt += 1
threshold = 1 / (weight + 1.0)
out = subprocess.check_output(['../asso/asso', '-i', inname, '-k', str(k), '-w', str(weight), '-t', str(threshold)])
counts = [float(x) for x in out.split(None, 4)[:4]]
return cost(counts), counts, out
def datasize(inname):
f = open(inname)
rowcnt = int(f.readline().strip())
dim = int(f.readline().strip())
onecnt = sum([int(line.split(None, 1)[0]) for line in f])
return rowcnt, dim, onecnt
def alpha(q, s):
if q == 1:
return 0
if s == 0:
return float('inf')
return (log(q) - log(s)) / (log(1 - s) - log(1 - q))
def alpha_from_counts(counts):
q = 1
s = 0
if counts[3] + counts[2] > 0:
q = counts[3] / float(counts[3] + counts[2])
if counts[1] + counts[0] > 0:
s = counts[1] / float(counts[1] + counts[0])
return alpha(q, s)
def sgrid(inname, approx, k):
rowcnt, dim, onecnt = datasize(inname)
cellcnt = float(rowcnt*dim)
density = onecnt / cellcnt
best = (float('inf'), None, None)
cand = query(inname, k, 1 / cellcnt)
best = min(best, cand)
cand = query(inname, k, cellcnt)
best = min(best, cand)
par = alpha(density, 1 / cellcnt)
while par >= alpha(1 - 1 / cellcnt, density):
cand = query(inname, k, par)
best = min(best, cand)
par /= (1 + approx)
return best
def sgrid_with_pruning(inname, approx, k):
rowcnt, dim, onecnt = datasize(inname)
cellcnt = float(rowcnt*dim)
density = onecnt / cellcnt
best = None
best_counts = None
best_cost = float('inf')
best = (float('inf'), None, None)
lbound = alpha(1 - 1 / cellcnt, density)*(1 + approx)
ubound = alpha(density, 1 / cellcnt)
windows = [(lbound, ubound)]
minpar = float('inf')
maxpar = 0
while len(windows) > 0:
lower, upper = windows.pop()
if lower*(1 + approx) < upper:
mid = sqrt(lower * upper)
#print lower, upper, mid
cand = query(inname, k, mid)
best = min(best, cand)
par = alpha_from_counts(cand[1])
minpar = min(par, minpar)
maxpar = max(par, maxpar)
windows.append([lower, min(mid, par)])
windows.append([max(mid, par), upper])
if minpar > lbound:
cand = query(inname, k, lbound)
best = min(best, cand)
if minpar > 0:
cand = query(inname, k, 1 / cellcnt)
best = min(best, cand)
if minpar < ubound:
cand = query(inname, k, ubound)
best = min(best, cand)
if minpar < float('inf'):
cand = query(inname, k, cellcnt)
best = min(best, cand)
return best
parser = argparse.ArgumentParser(description='Grid search for BMF')
parser.add_argument('-i', help='input file')
parser.add_argument('-k', help='number of components')
parser.add_argument('-a', type=float, help='approximation')
parser.add_argument('-w', type=float, help='fixed weight parameter, no grid search')
parser.add_argument('-p', help='use pruning', action='store_true')
parser.add_argument('-o', help='output file')
args = parser.parse_args()
if args.w:
best_cost, counts, out = query(args.i, args.k, args.w)
elif args.p:
best_cost, counts, out = sgrid_with_pruning(args.i, args.a, args.k)
else:
best_cost, counts, out = sgrid(args.i, args.a, args.k)
print best_cost, querycnt, alpha_from_counts(counts)
of = open(args.o, 'w')
print >>of, out
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment