// Data for a judge/decision model with a selectively observed outcome:
// outcomes y are only available for the "observed" units (decision T = 1);
// for "censored" units (T = 0) only decisions and features are available.
data {
  int<lower=1> D; // Number of private features (length of each X row vector)
  int<lower=0> N_obs;  // Number of "observed observations" (with T = 1)
  int<lower=0> N_cens; // Number of "censored observations" (with T = 0)
  int<lower=1> M; // Number of judges
  int<lower=1> K; // Number of groups (coefficients follow a random walk across groups 1..K)
  real<lower=0> sigma_tau; // Scale of the half-normal priors placed on tau_XT, tau_XY, tau_ZT, tau_ZY in the model block
  int<lower=1, upper=M> jj_obs[N_obs];   // Judge index (1..M) of each observed observation
  int<lower=1, upper=M> jj_cens[N_cens]; // Judge index (1..M) of each censored observation
  int<lower=1, upper=K> kk_obs[N_obs];   // Group index (1..K) of each observed observation
  int<lower=1, upper=K> kk_cens[N_cens]; // Group index (1..K) of each censored observation
  int<lower=0, upper=1> dec_obs[N_obs];   // Decisions T for observed observations (presumably all 1; bounds only enforce 0/1)
  int<lower=0, upper=1> dec_cens[N_cens]; // Decisions T for censored observations (presumably all 0; bounds only enforce 0/1)
  row_vector[D] X_obs[N_obs];   // Private features of "observed observations" (with T = 1)
  row_vector[D] X_cens[N_cens]; // Private features of "censored observations" (with T = 0)
  int<lower=0, upper=1> y_obs[N_obs]; // Binary outcomes, available only for observed observations
}

// Model parameters. NOTE: declaration order defines the unconstrained
// parameter vector (and output column order), so do not reorder casually.
parameters {
  // Per-observation latent variable; given a std_normal prior in the model
  // block and entering both the decision and outcome linear predictors.
  vector[N_obs] Z_obs;
  vector[N_cens] Z_cens;
  // Judge-specific intercepts of the decision model.
  // NOTE(review): the model block has no sampling statement for alpha_T, so it
  // carries an implicit improper flat prior; a judge whose decisions are all
  // 0 or all 1 may then yield an improper posterior — consider adding one.
  real alpha_T[M];  // Judge-specific intercepts
  
  // Per-group feature coefficients for the decision (XT) and outcome (XY)
  // models; given a random-walk prior across groups in the model block.
  vector[D] beta_XT[K]; 
  vector[D] beta_XY[K];
  
  // Raw, non-negative latent-variable coefficients (half std_normal via the
  // lower bound). Element 1 is used as-is; elements 2..K are scaled by
  // tau_ZT / tau_ZY in transformed parameters (non-centred increments).
  vector<lower=0>[K] beta_ZT_raw; // Coefficient for the latent variable.
  vector<lower=0>[K] beta_ZY_raw; // Coefficient for the latent variable.
  
  // Random-walk step scales (half-normal(0, sigma_tau) priors in the model block):
  // XT / XY for the feature coefficients, ZT / ZY for the latent-variable increments.
  real<lower=0> tau_XT;
  real<lower=0> tau_XY;
  real<lower=0> tau_ZT;
  real<lower=0> tau_ZY;
}

transformed parameters{
  // Latent-variable coefficients built from non-negative increments:
  //   beta_Z*[1]   = beta_Z*_raw[1]                 (unscaled)
  //   beta_Z*[k]   = tau_Z* * beta_Z*_raw[k], k >= 2 (non-centred scaling)
  // and each group's effective coefficient is the running total
  //   beta_Z*_cumulative[k] = beta_Z*[1] + ... + beta_Z*[k],
  // which is therefore non-decreasing in the group index k.
  vector<lower=0>[K] beta_ZT_cumulative;
  vector<lower=0>[K] beta_ZY_cumulative;
  
  vector<lower=0>[K] beta_ZT;
  vector<lower=0>[K] beta_ZY;
  
  // Copy the raw values; group 1 keeps them unscaled.
  beta_ZT = beta_ZT_raw;
  beta_ZY = beta_ZY_raw;
  
  // Rescale groups 2..K. With K == 1 the range 2:1 is empty and the body never runs.
  for(k in 2:K){
    beta_ZT[k] = tau_ZT * beta_ZT_raw[k];
    beta_ZY[k] = tau_ZY * beta_ZY_raw[k];
  }
  
  beta_ZY_cumulative = cumulative_sum(beta_ZY);
  beta_ZT_cumulative = cumulative_sum(beta_ZT);
}

model {
  // Linear predictors (logit scale), filled below and consumed by single
  // vectorized likelihood statements, which Stan evaluates far more
  // efficiently than one sampling statement per observation.
  vector[N_obs] eta_T_obs;   // decision model, observed observations
  vector[N_obs] eta_Y_obs;   // outcome model, observed observations
  vector[N_cens] eta_T_cens; // decision model, censored observations
  
  // Latent variable: standard normal for every observation.
  Z_obs  ~ std_normal();
  Z_cens ~ std_normal();
  
  // Half-normal priors on the random-walk scales (all declared <lower=0>).
  tau_XY ~ normal(0, sigma_tau);
  tau_XT ~ normal(0, sigma_tau);
  tau_ZY ~ normal(0, sigma_tau);
  tau_ZT ~ normal(0, sigma_tau);
  
  // First group anchors the random walk; raw Z coefficients are half std_normal.
  beta_XT[1]  ~ std_normal(); // first group
  beta_XY[1]  ~ std_normal();
  beta_ZT_raw ~ std_normal();
  beta_ZY_raw ~ std_normal();
  
  // Random-walk prior across groups. For K == 1 the range 2:1 is empty,
  // so no explicit K >= 2 guard is needed.
  for(i in 2:K){
    beta_XT[i] ~ normal(beta_XT[i-1], tau_XT); // ith group
    beta_XY[i] ~ normal(beta_XY[i-1], tau_XY);
  }
  
  // NOTE(review): alpha_T has no prior statement here, i.e. an implicit
  // improper flat prior. Left unchanged to preserve the model as specified.
  
  for(i in 1:N_obs){
    eta_T_obs[i] = alpha_T[jj_obs[i]] + X_obs[i] * beta_XT[kk_obs[i]] + beta_ZT_cumulative[kk_obs[i]] * Z_obs[i];
    eta_Y_obs[i] = X_obs[i] * beta_XY[kk_obs[i]] + beta_ZY_cumulative[kk_obs[i]] * Z_obs[i];
  }
  
  for(i in 1:N_cens)
    eta_T_cens[i] = alpha_T[jj_cens[i]] + X_cens[i] * beta_XT[kk_cens[i]] + beta_ZT_cumulative[kk_cens[i]] * Z_cens[i];
  
  // Likelihood: decisions for all observations; outcomes only where observed.
  dec_obs  ~ bernoulli_logit(eta_T_obs);
  y_obs    ~ bernoulli_logit(eta_Y_obs);
  dec_cens ~ bernoulli_logit(eta_T_cens);
}

generated quantities {
  // Outcome-model probability P(y = 1) for each censored observation,
  // evaluated at the current draw (features, group coefficients, and that
  // observation's latent Z). The outcome itself is never observed for these units.
  real<lower=0, upper=1> y_est[N_cens];
  
  for(n in 1:N_cens){
    int g = kk_cens[n]; // group of censored observation n
    y_est[n] = inv_logit(X_cens[n] * beta_XY[g] + beta_ZY_cumulative[g] * Z_cens[n]);
  }
}