Skip to content
Snippets Groups Projects
Commit 22b130b8 authored by Nikolaj Tatti
Browse files

updated version of asso: faster for large tiles

parent eb122ab5
No related branches found
No related tags found
No related merge requests found
......@@ -4,7 +4,7 @@ OBJS = $(SRCS:.cpp=.o)
CC = g++
LDFLAGS= -lm -O3
CFLAGS= -g -O3 -Wall -Wextra #-Werror
CFLAGS= -g -O3 -MMD -Wall -Wextra #-Werror
all : asso
......@@ -21,9 +21,4 @@ clean:
zip:
zip asso.zip *.cpp *.h Makefile
depend:
makedepend -Y -- $(CFLAGS) -- $(SRCS)
# DO NOT DELETE
asso.o: dataset.h defines.h
dataset.o: dataset.h defines.h
-include $(SRCS:%.cpp=%.d)
......@@ -4,6 +4,13 @@
#include <getopt.h>
#include <math.h>
// Partition of the data columns into groups, with one row list per group.
//
//   map[c]    - id of the group that column c currently belongs to.
//   counts[g] - number of columns currently assigned to group g.
//   rows[g]   - row-index list stored for group g.
//
// A freshly constructed cover has a single group (id 0) that contains all
// `dim` columns and has an empty row list; extend() splits groups and merges
// rows into them.
struct tilecover {
    // explicit: a bare integer must not silently convert into a tilecover.
    explicit tilecover(uint32_t dim) : rows(dim), map(dim)
    {
        // Make room for up to `dim` group counters up front, then register
        // the initial group that owns every column.
        counts.reserve(dim);
        counts.push_back(dim);
    }

    uintmatrix rows;
    uintvector counts;
    uintvector map;
};
template<class InputIt1, class InputIt2>
uint32_t intersection_cnt(InputIt1 first1, InputIt1 last1,
......@@ -24,6 +31,7 @@ uint32_t intersection_cnt(InputIt1 first1, InputIt1 last1,
return cnt;
}
void
index_union(const uintvector & a, const uintvector & b, uintvector & out)
{
......@@ -31,8 +39,50 @@ index_union(const uintvector & a, const uintvector & b, uintvector & out)
out.resize(0);
out.reserve(size);
std::set_union(a.begin(), a.end(), b.begin(), b.end(), std::back_inserter(out));
assert(size == out.size());
}
// Merges the tile (cols x rows) into the tile cover `c`.
//
// Every column listed in `cols` ends up in a group whose row list is the
// union of its previous group's row list and `rows`:
//   * if all columns of a group are in `cols`, the group's row list is
//     extended in place;
//   * otherwise the listed columns are moved into a newly created group
//     (one new group per split source group), and the source group keeps
//     its remaining columns and its old row list.
//
// c    - cover to update (map, counts and rows are modified).
// cols - column indices of the tile.
//        NOTE(review): assumed to contain no duplicates - confirm with caller.
// rows - row indices of the tile, passed to index_union().
void
extend(tilecover & c, const uintvector & cols, const uintvector & rows)
{
    const uint32_t groupcnt = c.counts.size();

    // hits[g]: number of tile columns that currently sit in group g.
    uintvector hits(groupcnt);
    for (uint32_t i = 0; i < cols.size(); i++)
        hits[c.map[cols[i]]]++;

    // whole[g]: the tile contains every column of group g, so the group
    // does not need to be split.
    boolvector whole(groupcnt);
    for (uint32_t g = 0; g < groupcnt; g++)
        whole[g] = hits[g] == c.counts[g];

    // merged[g]: the in-place union for group g has already been done.
    // The union is idempotent, so doing it once per group (instead of once
    // per column of the group) gives the same result with less work.
    boolvector merged(groupcnt);

    // remap[g]: id of the group split off from g; 0 means "not created yet".
    // 0 is a safe sentinel because new ids start at c.counts.size() >= 1.
    uintvector remap(groupcnt);

    for (uint32_t i = 0; i < cols.size(); i++) {
        const uint32_t col = cols[i];
        const uint32_t id = c.map[col];

        if (whole[id]) {
            if (!merged[id]) {
                // index_union() cannot write into one of its inputs, so
                // build the union in a temporary and swap it in.
                uintvector out;
                index_union(rows, c.rows[id], out);
                c.rows[id].swap(out);
                merged[id] = true;
            }
            continue;
        }

        if (remap[id] == 0) {
            // First column leaving this group: create the new group with
            // row list (old rows of the group) U (rows of the tile).
            remap[id] = c.counts.size();
            index_union(rows, c.rows[id], c.rows[remap[id]]);
            c.counts.push_back(0);
        }

        // Move the column from its old group to the split-off group.
        c.map[col] = remap[id];
        c.counts[remap[id]]++;
        c.counts[id]--;
    }
}
void
find_basis(const dataset & d, double threshold, dataset & basis)
{
......@@ -69,25 +119,191 @@ count_column(const uintvector & dcol, const uintvector & ccol, double weight, do
while (ind1 < ccol.size() && ind2 < dcol.size()) {
if (ccol[ind1] <= dcol[ind2]) {
if (ccol[ind1] == dcol[ind2]) ind2++;
//if (score[ccol[ind1]] == 0) inds.push_back(ccol[ind1]);
score[ccol[ind1]] += 1;
ind1++;
//printf("foo\n");
}
else {
//if (score[dcol[ind2]] == 0) inds.push_back(dcol[ind2]);
score[dcol[ind2]] += weight + 1;
ind2++;
}
}
for (;ind1 < ccol.size(); ind1++) {
//if (score[ccol[ind1]] == 0) inds.push_back(ccol[ind1]);
score[ccol[ind1]] += 1;
}
for (;ind2 < dcol.size(); ind2++) {
//if (score[dcol[ind2]] == 0) inds.push_back(dcol[ind2]);
score[dcol[ind2]] += weight + 1;
}
}
void
cover(const dataset & d, const dataset & basis, double weight, uint32_t k, uintvector & selected, dataset & decomp, uintvector & counts)
transpose(const dataset & src, dataset & dst)
{
dst.rowcnt = src.cols.size();
dst.cols.resize(src.rowcnt);
uintvector counts(src.rowcnt);
for (uint32_t i = 0; i < src.cols.size(); i++) {
for (uint32_t j = 0; j < src.cols[i].size(); j++) {
counts[src.cols[i][j]]++;
}
}
for (uint32_t i = 0; i < src.rowcnt; i++)
dst.cols[i].reserve(counts[i]);
for (uint32_t i = 0; i < src.cols.size(); i++) {
for (uint32_t j = 0; j < src.cols[i].size(); j++) {
dst.cols[src.cols[i][j]].push_back(i);
}
}
}
double
find_rows(const uintvector & tile, const dataset & d, const dataset & covered, double weight, uintvector & rows)
{
doublevector cand_score(d.rowcnt, 0);
uint32_t cnt = tile.size();
for (uint32_t j = 0; j < cnt; j++) {
uint32_t c = tile[j];
const uintvector & dcol = d.cols[c];
const uintvector & ccol = covered.cols[c];
count_column(dcol, ccol, weight, cand_score);
}
double cand_gain = 0;
uint32_t size = 0;
for (uint32_t j = 0; j < d.rowcnt; j++) {
if (cand_score[j] > cnt) {
size++;
cand_gain += cand_score[j] - cnt;
}
}
rows.resize(0);
rows.reserve(size);
for (uint32_t j = 0; j < d.rowcnt; j++) {
if (cand_score[j] > cnt) {
rows.push_back(j);
}
}
return cand_gain;
}
double
find_rows(const uintvector & tile, const dataset & d, const tilecover & tc, double weight, uintvector & rows)
{
doublevector cand_score(d.rowcnt, 0);
uint32_t cnt = tile.size();
uintvector counts(tc.counts.size());
for (uint32_t j = 0; j < cnt; j++) {
uint32_t c = tile[j];
const uintvector & dcol = d.cols[c];
for (uint32_t i = 0; i < dcol.size(); i++)
cand_score[dcol[i]] += weight + 1;
counts[tc.map[c]]++;
}
for (uint32_t j = 0; j < counts.size(); j++) {
if (counts[j] == 0) continue;
const uintvector & dcol = tc.rows[j];
for (uint32_t i = 0; i < dcol.size(); i++)
cand_score[dcol[i]] += counts[j];
}
double cand_gain = 0;
uint32_t size = 0;
for (uint32_t j = 0; j < d.rowcnt; j++) {
if (cand_score[j] > cnt) {
size++;
cand_gain += cand_score[j] - cnt;
}
}
rows.resize(0);
rows.reserve(size);
for (uint32_t j = 0; j < d.rowcnt; j++) {
if (cand_score[j] > cnt) {
rows.push_back(j);
}
}
return cand_gain;
}
// Debug check: computes the per-row scores of `tile` twice - once through
// the grouped cover `tc` and once through the explicit `covered` dataset
// with count_column() - prints diagnostics to stdout and asserts that both
// score vectors are equal.
//
// tile    - column indices of the tile to check.
// d       - data; d.rowcnt rows are scored.
// tc      - tile cover used for the first computation.
// weight  - weight used by both computations.
// rows    - unused.
// orig    - unused.
// covered - per-column row lists used for the second computation.
void
test(const uintvector & tile, const dataset & d, const tilecover & tc, double weight, uintvector & rows, const dataset & orig, const dataset & covered)
{
    // Kept in the signature for existing callers; not needed here.
    (void) rows;
    (void) orig;

    doublevector cand_score(d.rowcnt, 0);
    uint32_t cnt = tile.size();

    // Score via the tile cover: data entries first, then one pass per group.
    uintvector counts(tc.counts.size());
    for (uint32_t j = 0; j < cnt; j++) {
        uint32_t c = tile[j];
        const uintvector & dcol = d.cols[c];
        for (uint32_t i = 0; i < dcol.size(); i++)
            cand_score[dcol[i]] += weight + 1;
        counts[tc.map[c]]++;
    }

    for (uint32_t j = 0; j < counts.size(); j++) {
        const uintvector & dcol = tc.rows[j];
        // Unsigned conversions with explicit casts: the arguments are
        // uint32_t and size_t, which "%d" does not match.
        printf("count: %u %lu\n", static_cast<unsigned>(counts[j]), static_cast<unsigned long>(dcol.size()));
        if (counts[j] == 0) continue;
        for (uint32_t i = 0; i < dcol.size(); i++)
            cand_score[dcol[i]] += counts[j];
    }

    // Reference score via the explicit covered dataset.
    doublevector cand_score1(d.rowcnt, 0);
    for (uint32_t j = 0; j < cnt; j++) {
        uint32_t c = tile[j];
        printf("%u ", static_cast<unsigned>(c));
        const uintvector & dcol = d.cols[c];
        const uintvector & ccol = covered.cols[c];
        count_column(dcol, ccol, weight, cand_score1);
    }
    printf("\n");

    // Report every row on which the two computations disagree.
    for (uint32_t i = 0; i < d.rowcnt; i++) {
        if (cand_score[i] != cand_score1[i]) {
            printf("%u %f %f\n", static_cast<unsigned>(i), cand_score[i], cand_score1[i]);
        }
    }

    // NOTE(review): exact floating-point comparison; the two paths add their
    // terms in different orders, so a weight that is not exactly
    // representable might trip this without a real bug - confirm.
    assert(cand_score == cand_score1);
}
void
cover(dataset & d, const dataset & basis, double weight, uint32_t k, uintvector & selected, dataset & decomp, doublevector & counts)
{
decomp.cols.resize(k);
decomp.rowcnt = d.rowcnt;
......@@ -98,68 +314,92 @@ cover(const dataset & d, const dataset & basis, double weight, uint32_t k, uintv
counts.resize(4);
uint32_t tilecnt = basis.cols.size();
doublevector scores(tilecnt);
uintmatrix cands(tilecnt);
dataset dep;
transpose(basis, dep);
tilecover tc(d.cols.size());
for (uint32_t i = 0; i < tilecnt; i++) {
scores[i] = find_rows(basis.cols[i], d, tc, weight, cands[i]);
if (i % 100 == 0) fprintf(stderr, "%d \r", i);
}
for (uint32_t b = 0; b < k; b++) {
doublevector best_score;
double best_gain = -1;
uint32_t best = -1;
for (uint32_t i = 0; i < basis.cols.size(); i++) {
doublevector cand_score(d.rowcnt, -double(basis.cols[i].size()));
//printf("%d\n", i);
for (uint32_t j = 0; j < basis.cols[i].size(); j++) {
uint32_t c = basis.cols[i][j];
const uintvector & dcol = d.cols[c];
const uintvector & ccol = covered.cols[c];
count_column(dcol, ccol, weight, cand_score);
}
uint32_t best = std::distance(scores.begin(), std::max_element(scores.begin(), scores.end()));
double cand_gain = 0;
for (uint32_t j = 0; j < cand_score.size(); j++) {
if (cand_score[j] > 0) cand_gain += cand_score[j];
}
selected[b] = best;
decomp.cols[b].swap(cands[best]);
if (cand_gain > best_gain) {
best_score.swap(cand_score);
best_gain = cand_gain;
best = i;
}
//printf("%f %d\n", cand_gain, best);
const uintvector & cols = basis.cols[best];
const uintvector & rows = decomp.cols[b];
boolvector active(d.rowcnt);
for (uint32_t i = 0; i < rows.size(); i++) {
active[rows[i]] = true;
}
//printf("BEST: %d\n", best);
selected[b] = best;
uint32_t size = 0;
for (uint32_t i = 0; i < best_score.size(); i++) {
//printf("%f ", best_score[i]);
if (best_score[i] > 0) size++;
for (uint32_t i = 0; i < cols.size(); i++) {
uintvector & r = d.cols[cols[i]];
uint32_t shift = 0;
for (uint32_t i = 0; i < r.size(); i++) {
if (active[r[i]])
shift++;
else
r[i - shift] = r[i];
}
r.resize(r.size() - shift);
counts[3] += shift;
}
decomp.cols[b].resize(0);
decomp.cols[b].reserve(size);
for (uint32_t i = 0; i < best_score.size(); i++) {
if (best_score[i] > 0) decomp.cols[b].push_back(i);
extend(tc, cols, rows);
boolvector update(tilecnt);
//fprintf(stderr, "%d %d\n", cols.size(), rows.size());
for (uint32_t i = 0; i < cols.size(); i++) {
//printf("%d %d\n", basis.cols[b][i], dep.cols.size());
uintvector & c = dep.cols[cols[i]];
for (uint32_t j = 0; j < c.size(); j++) {
update[c[j]] = true;
}
}
for (uint32_t i = 0; i < basis.cols[best].size(); i++) {
uint32_t c = basis.cols[best][i];
uintvector out;
index_union(decomp.cols[b], covered.cols[c], out);
covered.cols[c].swap(out);
//printf("%d %d\n", c, covered.cols[c].size());
for (uint32_t i = 0; i < tilecnt; i++) {
if (i % 100 == 0) fprintf(stderr, "%d %d \r", b, i);
if (update[i]) {
scores[i] = find_rows(basis.cols[i], d, tc, weight, cands[i]);
}
}
scores[best] = 0;
}
// Compute stats
for (uint32_t i = 0; i < d.cols.size(); i++) {
uint32_t size = intersection_cnt(d.cols[i].begin(), d.cols[i].end(), covered.cols[i].begin(), covered.cols[i].end());
counts[3] += size;
counts[2] += covered.cols[i].size() - size;
counts[1] += d.cols[i].size() - size;
counts[0] += d.rowcnt - covered.cols[i].size() - d.cols[i].size() + size;
counts[1] += d.cols[i].size();
}
for (uint32_t i = 0; i < tc.counts.size(); i++) {
counts[2] += tc.counts[i]*double(tc.rows[i].size());
//printf("%d %d\n", tc.counts[i], tc.rows[i].size());
}
counts[2] -= counts[3];
counts[0] = double(d.rowcnt) * d.cols.size() - counts[1] - counts[2] - counts[3];
}
void
read_sparse(FILE *f, dataset & d)
{
......@@ -232,6 +472,7 @@ main(int argc, char **argv)
{"weight", required_argument, NULL, 'w'},
{"out", required_argument, NULL, 'o'},
{"in", required_argument, NULL, 'i'},
{"basis", required_argument, NULL, 'b'},
{"order", required_argument, NULL, 'k'},
{"help", no_argument, NULL, 'h'},
{ NULL, 0, NULL, 0 }
......@@ -239,18 +480,20 @@ main(int argc, char **argv)
char *inname = NULL;
char *outname = NULL;
char *basisname = NULL;
double threshold = 1;
double weight = 1;
uint32_t k = 1;
int ch;
while ((ch = getopt_long(argc, argv, "ho:i:w:t:k:", longopts, NULL)) != -1) {
while ((ch = getopt_long(argc, argv, "ho:i:w:t:k:b:", longopts, NULL)) != -1) {
switch (ch) {
case 'h':
printf("Usage: %s -i <input file> -o <output file> [-k <order>] [-t <threshold>] [-w <weight>] [-h]\n", argv[0]);
printf(" -h print this help\n");
printf(" -i input file\n");
printf(" -o output file\n");
printf(" -b basis file\n");
printf(" -w weight for covering 1 (default = 1)\n");
printf(" -k decomposition order(default = 1)\n");
printf(" -t threshold for combining columns when searching for candidate basis (default = 1)\n");
......@@ -262,6 +505,9 @@ main(int argc, char **argv)
case 'o':
outname = optarg;
break;
case 'b':
basisname = optarg;
break;
case 'w':
weight = atof(optarg);
break;
......@@ -283,12 +529,18 @@ main(int argc, char **argv)
if (inname) fclose(f);
dataset basis;
find_basis(d, threshold, basis);
if (basisname) {
f = fopen(basisname, "r");
read_sparse(f, basis);
}
else
find_basis(d, threshold, basis);
dataset decomp;
uintvector selected;
uintvector counts;
doublevector counts;
cover(d, basis, weight, k, selected, decomp, counts);
//print_sparse(stdout, basis);
......@@ -296,7 +548,7 @@ main(int argc, char **argv)
FILE *out = stdout;
if (outname) out = fopen(outname, "w");
fprintf(out, "%d %d %d %d\n", counts[0], counts[1], counts[2], counts[3]);
fprintf(out, "%.0f %.0f %.0f %.0f\n", counts[0], counts[1], counts[2], counts[3]);
print_sparse(out, basis, selected);
print_sparse(out, decomp);
......
......@@ -6,6 +6,7 @@
// Container aliases shared by the sources.

// List of unsigned 32-bit values.
typedef std::vector<uint32_t> uintvector;
// List of doubles.
typedef std::vector<double> doublevector;
// List of flags. Note that std::vector<bool> is the standard library's
// bit-packed specialization, so elements are accessed through proxy objects.
typedef std::vector<bool> boolvector;
// List of uintvectors.
typedef std::vector<uintvector> uintmatrix;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment