// asso.cpp
// Author: Nikolaj Tatti
#include "dataset.h"
#include <algorithm>
#include <stdio.h>
#include <stdlib.h>
#include <getopt.h>
#include <math.h>
/*
 * Return how many elements the sorted ranges [first1, last1) and
 * [first2, last2) have in common.
 *
 * Both ranges must be sorted by operator<. Duplicates are matched
 * pairwise, so an element occurring m times in one range and n times in
 * the other contributes min(m, n) to the result (the same elements that
 * std::set_intersection would emit).
 */
template<class InputIt1, class InputIt2>
uint32_t intersection_cnt(InputIt1 first1, InputIt1 last1,
	InputIt2 first2, InputIt2 last2)
{
	uint32_t common = 0;
	for (; first1 != last1 && first2 != last2; ) {
		if (*first2 < *first1) {
			// Second range is behind: skip its element.
			++first2;
		} else if (*first1 < *first2) {
			// First range is behind: skip its element.
			++first1;
		} else {
			// Equal elements: count the match and consume both.
			++common;
			++first1;
			++first2;
		}
	}
	return common;
}
/*
 * Replace the contents of out with the sorted union of the sorted index
 * lists a and b.
 *
 * Capacity for exactly |a| + |b| - |a ∩ b| elements is reserved up front,
 * so std::set_union fills out without further reallocation. out is cleared
 * before a and b are read, so it must not be the same object as either
 * input.
 */
void
index_union(const uintvector & a, const uintvector & b, uintvector & out)
{
	uint32_t expected = a.size() + b.size();
	expected -= intersection_cnt(a.begin(), a.end(), b.begin(), b.end());

	out.resize(0);
	out.reserve(expected);

	std::set_union(a.begin(), a.end(),
		b.begin(), b.end(),
		std::back_inserter(out));
}
void
find_basis(const dataset & d, double threshold, dataset & basis)
{
uint32_t dim = d.cols.size();
basis.cols.resize(dim);
basis.rowcnt = dim;
for (uint32_t i = 0; i < dim; i++) {
doublevector c(dim);
for (uint32_t j = 0; j < dim; j++) {
c[j] = intersection_cnt(d.cols[i].begin(), d.cols[i].end(), d.cols[j].begin(), d.cols[j].end());
c[j] /= d.cols[i].size();
}
uint32_t size = 0;
for (uint32_t j = 0; j < dim; j++) {
if (c[j] >= threshold) size++;
}
basis.cols[i].reserve(size);
for (uint32_t j = 0; j < dim; j++) {
if (c[j] >= threshold) {
basis.cols[i].push_back(j);
}
}
}
}
void
count_column(const uintvector & dcol, const uintvector & ccol, double weight, doublevector & score)
{
uint32_t ind1 = 0;
uint32_t ind2 = 0;
while (ind1 < ccol.size() && ind2 < dcol.size()) {
if (ccol[ind1] <= dcol[ind2]) {
if (ccol[ind1] == dcol[ind2]) ind2++;
score[ccol[ind1]] += 1;
ind1++;
//printf("foo\n");
}
else {
score[dcol[ind2]] += weight + 1;
ind2++;
}
}
for (;ind1 < ccol.size(); ind1++) {
score[ccol[ind1]] += 1;
}
for (;ind2 < dcol.size(); ind2++) {
score[dcol[ind2]] += weight + 1;
}
}
void
cover(const dataset & d, const dataset & basis, double weight, uint32_t k, uintvector & selected, dataset & decomp, uintvector & counts)
{
decomp.cols.resize(k);
decomp.rowcnt = d.rowcnt;
selected.resize(k);
dataset covered;
covered.cols.resize(d.cols.size());
counts.resize(4);
for (uint32_t b = 0; b < k; b++) {
doublevector best_score;
double best_gain = -1;
uint32_t best = -1;
for (uint32_t i = 0; i < basis.cols.size(); i++) {
doublevector cand_score(d.rowcnt, -double(basis.cols[i].size()));
//printf("%d\n", i);
for (uint32_t j = 0; j < basis.cols[i].size(); j++) {
uint32_t c = basis.cols[i][j];
const uintvector & dcol = d.cols[c];
const uintvector & ccol = covered.cols[c];
count_column(dcol, ccol, weight, cand_score);
}
double cand_gain = 0;
for (uint32_t j = 0; j < cand_score.size(); j++) {
if (cand_score[j] > 0) cand_gain += cand_score[j];
}
if (cand_gain > best_gain) {
best_score.swap(cand_score);
best_gain = cand_gain;
best = i;
}
//printf("%f %d\n", cand_gain, best);
}
//printf("BEST: %d\n", best);
selected[b] = best;
uint32_t size = 0;
for (uint32_t i = 0; i < best_score.size(); i++) {
//printf("%f ", best_score[i]);
if (best_score[i] > 0) size++;
}
decomp.cols[b].resize(0);
decomp.cols[b].reserve(size);
for (uint32_t i = 0; i < best_score.size(); i++) {
if (best_score[i] > 0) decomp.cols[b].push_back(i);
}
for (uint32_t i = 0; i < basis.cols[best].size(); i++) {
uint32_t c = basis.cols[best][i];
uintvector out;
index_union(decomp.cols[b], covered.cols[c], out);
covered.cols[c].swap(out);
//printf("%d %d\n", c, covered.cols[c].size());
}
}
// Compute stats
for (uint32_t i = 0; i < d.cols.size(); i++) {
uint32_t size = intersection_cnt(d.cols[i].begin(), d.cols[i].end(), covered.cols[i].begin(), covered.cols[i].end());
counts[3] += size;
counts[2] += covered.cols[i].size() - size;
counts[1] += d.cols[i].size() - size;
counts[0] += d.rowcnt - covered.cols[i].size() - d.cols[i].size() + size;
}
}
void
read_sparse(FILE *f, dataset & d)
{
uint32_t dim;
fscanf(f, "%d%d", &d.rowcnt, &dim);
d.cols.resize(dim);
uintvector counts(dim);
for (uint32_t i = 0; i < d.rowcnt; i++) {
uint32_t cnt;
fscanf(f, "%d", &cnt);
for (uint32_t j = 0; j < cnt; j++) {
uint32_t ind;
fscanf(f, "%d", &ind);
counts[ind - 1]++;
}
}
for (uint32_t i = 0; i < dim; i++)
d.cols[i].reserve(counts[i]);
rewind(f);
fscanf(f, "%*d%*d");
for (uint32_t i = 0; i < d.rowcnt; i++) {
uint32_t cnt;
fscanf(f, "%d", &cnt);
for (uint32_t j = 0; j < cnt; j++) {
uint32_t ind;
fscanf(f, "%d", &ind);
d.cols[ind - 1].push_back(i);
}
}
}
void
print_sparse(FILE *f, const dataset & d)
{
fprintf(f, "%lu\n%d\n", d.cols.size(), d.rowcnt);
for (uint32_t i = 0; i < d.cols.size(); i++) {
fprintf(f, "%lu", d.cols[i].size());
for (uint32_t j = 0; j < d.cols[i].size(); j++) {
fprintf(f, " %d", 1 + d.cols[i][j]);
}
fprintf(f, "\n");
}
}
void
print_sparse(FILE *f, const dataset & d, const uintvector & selected)
{
fprintf(f, "%lu\n%d\n", selected.size(), d.rowcnt);
for (uint32_t i = 0; i < selected.size(); i++) {
fprintf(f, "%lu", d.cols[selected[i]].size());
for (uint32_t j = 0; j < d.cols[selected[i]].size(); j++) {
fprintf(f, " %d", 1 + d.cols[selected[i]][j]);
}
fprintf(f, "\n");
}
}
int
main(int argc, char **argv)
{
static struct option longopts[] = {
{"threshold", required_argument, NULL, 't'},
{"weight", required_argument, NULL, 'w'},
{"out", required_argument, NULL, 'o'},
{"in", required_argument, NULL, 'i'},
{"order", required_argument, NULL, 'k'},
{"help", no_argument, NULL, 'h'},
{ NULL, 0, NULL, 0 }
};
char *inname = NULL;
char *outname = NULL;
double threshold = 1;
double weight = 1;
uint32_t k = 1;
int ch;
while ((ch = getopt_long(argc, argv, "ho:i:w:t:k:", longopts, NULL)) != -1) {
switch (ch) {
case 'h':
printf("Usage: %s -i <input file> -o <output file> [-k <order>] [-t <threshold>] [-w <weight>] [-h]\n", argv[0]);
printf(" -h print this help\n");
printf(" -i input file\n");
printf(" -o output file\n");
printf(" -w weight for covering 1 (default = 1)\n");
printf(" -k decomposition order(default = 1)\n");
printf(" -t threshold for combining columns when searching for candidate basis (default = 1)\n");
return 0;
break;
case 'i':
inname = optarg;
break;
case 'o':
outname = optarg;
break;
case 'w':
weight = atof(optarg);
break;
case 'k':
k = atoi(optarg);
break;
case 't':
threshold = atof(optarg);
break;
}
}
dataset d;
FILE *f = stdin;
if (inname) f = fopen(inname, "r");
read_sparse(f, d);
if (inname) fclose(f);
dataset basis;
find_basis(d, threshold, basis);
dataset decomp;
uintvector selected;
uintvector counts;
cover(d, basis, weight, k, selected, decomp, counts);
//print_sparse(stdout, basis);
FILE *out = stdout;
if (outname) out = fopen(outname, "w");
fprintf(out, "%d %d %d %d\n", counts[0], counts[1], counts[2], counts[3]);
print_sparse(out, basis, selected);
print_sparse(out, decomp);
if (outname) fclose(out);
return 0;
}