// asso.cpp
#include "dataset.h"
#include <algorithm>
#include <stdio.h>
#include <getopt.h>
#include <math.h>


/**
 * Count the elements shared by two ranges, walking both in lockstep.
 *
 * Elements are compared with operator< only; the walk assumes both ranges
 * are sorted in ascending order. Every match consumes one element from each
 * range, so a value repeated in both ranges is counted once per matched pair.
 *
 * @return number of matched pairs.
 */
template<class InputIt1, class InputIt2>
uint32_t intersection_cnt(InputIt1 first1, InputIt1 last1,
                          InputIt2 first2, InputIt2 last2)
{
    uint32_t matches = 0;
    while (first1 != last1 && first2 != last2) {
        // Left element is smaller: it cannot match anything further right.
        if (*first1 < *first2) {
            ++first1;
            continue;
        }
        // Right element is smaller: skip it.
        if (*first2 < *first1) {
            ++first2;
            continue;
        }
        // Neither is smaller, so the two elements are equivalent.
        ++matches;
        ++first1;
        ++first2;
    }
    return matches;
}

/**
 * Write the union of the index lists a and b into out.
 *
 * out is emptied first and its storage is reserved for exactly
 * |a| + |b| - |a and b| entries before std::set_union appends the result.
 * a and b are assumed to be sorted ascending, as std::set_union requires.
 */
void
index_union(const uintvector & a, const uintvector & b, uintvector & out)
{
	// Entries present in both lists would otherwise be counted twice.
	const uint32_t shared = intersection_cnt(a.begin(), a.end(), b.begin(), b.end());
	const uint32_t total = a.size() + b.size() - shared;

	out.resize(0);
	out.reserve(total);
	std::set_union(a.begin(), a.end(), b.begin(), b.end(), std::back_inserter(out));
}

void
find_basis(const dataset & d, double threshold, dataset & basis)
{
	uint32_t dim = d.cols.size();

	basis.cols.resize(dim);
	basis.rowcnt = dim;

	for (uint32_t i = 0; i < dim; i++) {
		doublevector c(dim);
		for (uint32_t j = 0; j < dim; j++) {
			c[j] = intersection_cnt(d.cols[i].begin(), d.cols[i].end(), d.cols[j].begin(), d.cols[j].end());
			c[j] /= d.cols[i].size();
		}

		uint32_t size = 0;
		for (uint32_t j = 0; j < dim; j++) {
			if (c[j] >= threshold) size++;
		}
		basis.cols[i].reserve(size);
		for (uint32_t j = 0; j < dim; j++) {
			if (c[j] >= threshold) {
				basis.cols[i].push_back(j);
			}
		}
	}
}

void
count_column(const uintvector & dcol, const uintvector & ccol, double weight, doublevector & score)
{
	uint32_t ind1 = 0;
	uint32_t ind2 = 0;
	while (ind1 < ccol.size() && ind2 < dcol.size()) {
		if (ccol[ind1] <= dcol[ind2]) {
			if (ccol[ind1] == dcol[ind2]) ind2++;
			score[ccol[ind1]] += 1;
			ind1++;
			//printf("foo\n");
		}
		else {
			score[dcol[ind2]] += weight + 1;
			ind2++;
		}
	}
	for (;ind1 < ccol.size(); ind1++) {
		score[ccol[ind1]] += 1;
	}
	for (;ind2 < dcol.size(); ind2++) {
		score[dcol[ind2]] += weight + 1;
	}
}

/**
 * Greedily pick k candidates from basis to cover the entries of d.
 *
 * In each round every candidate is scored per row: for each column the
 * candidate lists, a row contributes 0 if that cell is already covered,
 * +weight if the cell is set in d and not yet covered, and -1 otherwise.
 * A candidate's gain is the sum of its positive row scores; the candidate
 * with the highest gain wins (the first one on ties). The rows with a
 * positive score become the round's decomp column and are marked as covered
 * in every column the candidate lists.
 *
 * @param d         input data, stored as per-column sorted row-index lists
 * @param basis     candidates; each column lists column indices of d
 * @param weight    reward for newly covering a set cell of d
 * @param k         number of rounds; treated as 0 when basis has no columns
 * @param selected  output: index into basis.cols chosen in each round
 * @param decomp    output: for each round, the rows using the chosen candidate
 * @param counts    output, reset to four totals over all cells:
 *                  [0] in neither d nor the cover, [1] in d only,
 *                  [2] in the cover only, [3] in both
 */
void
cover(const dataset & d, const dataset & basis, double weight, uint32_t k, uintvector & selected, dataset & decomp, uintvector & counts)
{
	// With no candidates there is nothing to select; running a round would
	// index basis.cols with an invalid position, so run zero rounds instead.
	if (basis.cols.size() == 0) k = 0;

	decomp.cols.resize(k);
	decomp.rowcnt = d.rowcnt;
	selected.resize(k);

	dataset covered;
	covered.cols.resize(d.cols.size());

	// The totals are accumulated with += below, so start from zero even if
	// the caller passed a vector that already had contents.
	counts.resize(0);
	counts.resize(4);


	for (uint32_t b = 0; b < k; b++) {
		doublevector best_score;
		double best_gain = -1;
		// Gains are never negative, so the first candidate always beats the
		// initial best_gain and best is always overwritten.
		uint32_t best = 0;
		for (uint32_t i = 0; i < basis.cols.size(); i++) {
			// Start every row at -(number of columns in the candidate);
			// count_column then adds back 1 or weight + 1 per column.
			doublevector cand_score(d.rowcnt, -double(basis.cols[i].size()));
			for (uint32_t j = 0; j < basis.cols[i].size(); j++) {
				uint32_t c = basis.cols[i][j];
				const uintvector & dcol = d.cols[c];
				const uintvector & ccol = covered.cols[c];
				count_column(dcol, ccol, weight, cand_score);
			}

			double cand_gain = 0;
			for (uint32_t j = 0; j < cand_score.size(); j++) {
				if (cand_score[j] > 0) cand_gain += cand_score[j];
			}

			if (cand_gain > best_gain) {
				best_score.swap(cand_score);
				best_gain = cand_gain;
				best = i;
			}
		}

		selected[b] = best;

		// Count the rows with a positive score so the column can be
		// reserved exactly before it is filled.
		uint32_t size = 0;
		for (uint32_t i = 0; i < best_score.size(); i++) {
			if (best_score[i] > 0) size++;
		}

		decomp.cols[b].resize(0);
		decomp.cols[b].reserve(size);
		for (uint32_t i = 0; i < best_score.size(); i++) {
			if (best_score[i] > 0) decomp.cols[b].push_back(i);
		}

		// Mark the chosen rows as covered in every column of the winner.
		for (uint32_t i = 0; i < basis.cols[best].size(); i++) {
			uint32_t c = basis.cols[best][i];
			uintvector out;
			index_union(decomp.cols[b], covered.cols[c], out);
			covered.cols[c].swap(out);
		}
	}

	// Compute stats
	for (uint32_t i = 0; i < d.cols.size(); i++) {
		uint32_t size =  intersection_cnt(d.cols[i].begin(), d.cols[i].end(), covered.cols[i].begin(), covered.cols[i].end());
		counts[3] += size;
		counts[2] += covered.cols[i].size() - size;
		counts[1] += d.cols[i].size() - size;
		counts[0] += d.rowcnt - covered.cols[i].size() - d.cols[i].size() + size;
	}
}


/**
 * Read a sparse binary matrix from f into d, stored column-wise.
 *
 * Expected layout: the row count and the column count, then for each row
 * the number of entries followed by that many 1-based column indices.
 * d.cols[c] receives the 0-based indices of the rows that contain column
 * c + 1, in the order the rows are read.
 *
 * The stream is read in a single pass, so it also works on input that
 * cannot be rewound (for example a pipe on stdin). The trade-off is that
 * columns grow by push_back instead of being reserved to their exact size.
 *
 * On a NULL stream or a malformed header d is left empty. If the body ends
 * early, the rows read so far are kept and a message goes to stderr.
 * Column indices outside 1..dim are skipped with a message.
 */
void
read_sparse(FILE *f, dataset & d)
{
	d.rowcnt = 0;
	d.cols.resize(0);
	if (f == NULL) return;

	// Read into plain unsigned ints so that %u matches the argument type.
	unsigned int rows = 0;
	unsigned int dim = 0;
	if (fscanf(f, "%u%u", &rows, &dim) != 2) {
		fprintf(stderr, "read_sparse: missing or malformed header\n");
		return;
	}
	d.rowcnt = rows;
	d.cols.resize(dim);

	for (uint32_t i = 0; i < rows; i++) {
		unsigned int cnt = 0;
		if (fscanf(f, "%u", &cnt) != 1) {
			fprintf(stderr, "read_sparse: input ended at row %u of %u\n", (unsigned int) i + 1, rows);
			return;
		}
		for (uint32_t j = 0; j < cnt; j++) {
			unsigned int ind = 0;
			if (fscanf(f, "%u", &ind) != 1) {
				fprintf(stderr, "read_sparse: input ended inside row %u\n", (unsigned int) i + 1);
				return;
			}
			// Indices in the file are 1-based; reject anything that would
			// land outside d.cols.
			if (ind < 1 || ind > dim) {
				fprintf(stderr, "read_sparse: row %u: column index %u out of range 1..%u, skipped\n", (unsigned int) i + 1, ind, dim);
				continue;
			}
			d.cols[ind - 1].push_back(i);
		}
	}
}

void
print_sparse(FILE *f, const dataset & d)
{
	fprintf(f, "%lu\n%d\n", d.cols.size(), d.rowcnt);

	for (uint32_t i = 0; i < d.cols.size(); i++) {
		fprintf(f, "%lu", d.cols[i].size());
		for (uint32_t j = 0; j < d.cols[i].size(); j++) {
			fprintf(f, " %d", 1 + d.cols[i][j]);
		}
		fprintf(f, "\n");
	}
}

void
print_sparse(FILE *f, const dataset & d, const uintvector & selected)
{
	fprintf(f, "%lu\n%d\n", selected.size(), d.rowcnt);

	for (uint32_t i = 0; i < selected.size(); i++) {
		fprintf(f, "%lu", d.cols[selected[i]].size());
		for (uint32_t j = 0; j < d.cols[selected[i]].size(); j++) {
			fprintf(f, " %d", 1 + d.cols[selected[i]][j]);
		}
		fprintf(f, "\n");
	}
}


/**
 * Command-line driver: read a sparse matrix, build the candidate basis,
 * run the greedy cover and write the result.
 *
 * Input comes from -i (default stdin) and output goes to -o (default
 * stdout). The output is the four cover counts on one line, then the
 * selected basis columns, then the decomposition columns.
 *
 * Returns 0 on success and 1 on a bad option, a negative order, or a file
 * that cannot be opened or written.
 */
int
main(int argc, char **argv)
{
    static struct option longopts[] = {
        {"threshold",       required_argument,  NULL, 't'},
        {"weight",          required_argument,  NULL, 'w'},
        {"out",             required_argument,  NULL, 'o'},
        {"in",              required_argument,  NULL, 'i'},
        {"order",           required_argument,  NULL, 'k'},
		{"help",            no_argument,        NULL, 'h'},
        { NULL,             0,                  NULL,  0 }
    };

	char *inname = NULL;
	char *outname = NULL;
	double threshold = 1;
	double weight = 1;
	uint32_t k = 1;

    int ch;
    while ((ch = getopt_long(argc, argv, "ho:i:w:t:k:", longopts, NULL)) != -1) {
        switch (ch) {
            case 'h':
                printf("Usage: %s -i <input file> -o <output file> [-k <order>] [-t <threshold>] [-w <weight>] [-h]\n", argv[0]);
                printf("  -h    print this help\n");
                printf("  -i    input file\n");
                printf("  -o    output file\n");
                printf("  -w    weight for covering 1 (default = 1)\n");
                printf("  -k    decomposition order(default = 1)\n");
                printf("  -t    threshold for combining columns when searching for candidate basis (default = 1)\n");
                return 0;
            case 'i':
                inname = optarg;
                break;
            case 'o':
                outname = optarg;
                break;
            case 'w':
				weight = atof(optarg);
                break;
            case 'k': {
				// A negative value would wrap to a huge unsigned order.
				int order = atoi(optarg);
				if (order < 0) {
					fprintf(stderr, "%s: order must not be negative\n", argv[0]);
					return 1;
				}
				k = order;
                break;
			}
            case 't':
				threshold = atof(optarg);
                break;
            default:
				// getopt_long has already reported the offending option.
				fprintf(stderr, "Try '%s -h' for usage.\n", argv[0]);
                return 1;
        }
    }


	dataset d;

	FILE *f = stdin;
	if (inname) {
		f = fopen(inname, "r");
		if (f == NULL) {
			perror(inname);
			return 1;
		}
	}
	read_sparse(f, d);
	if (inname) fclose(f);

	dataset basis;
	find_basis(d, threshold, basis);


	dataset decomp;
	uintvector selected;
	uintvector counts;

	cover(d, basis, weight, k, selected, decomp, counts);

	FILE *out = stdout;
	if (outname) {
		out = fopen(outname, "w");
		if (out == NULL) {
			perror(outname);
			return 1;
		}
	}

	// Casts keep the arguments in line with the %u specifiers.
	fprintf(out, "%u %u %u %u\n",
	        static_cast<unsigned int>(counts[0]), static_cast<unsigned int>(counts[1]),
	        static_cast<unsigned int>(counts[2]), static_cast<unsigned int>(counts[3]));
	print_sparse(out, basis, selected);
	print_sparse(out, decomp);

	// fclose flushes buffered output; a failure here means the file is
	// incomplete.
	if (outname && fclose(out) != 0) {
		perror(outname);
		return 1;
	}

	return 0;
}