JR Utily

refactor post processing

package org.legrog.recommendation.postprocess;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.CommandLineRunner;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.autoconfigure.orm.jpa.HibernateJpaAutoConfiguration;
import org.springframework.context.annotation.Bean;
import java.io.*;
import java.util.*;
/**
 * Spring Boot application that reads a sample file and a recommendations file
 * (both tab-delimited with a header row containing {@code itemId} and {@code userId})
 * and logs three coverage values (c1, c2, c3) at trace level.
 */
@SpringBootApplication
public class PostprocessingApplication {

    Logger logger = LoggerFactory.getLogger(getClass());

    @Value("${parameters.filename}")
    String parametersFilename;

    @Value("${data.dir}")
    String dataDir;

    @Value("${collectionSample.filename}")
    String collectionSampleFilename;

    @Value("${ratingSample.filename}")
    String ratingSampleFilename;

    @Value("${recommandations.filename}")
    String recommandationsFilename;

    /** Sample file selected by {@link #loadSampleFilename()}. */
    String sampleFilename;

    /** Properties loaded from {@code dataDir + parametersFilename}; empty if loading failed. */
    Properties properties;

    /** Distinct item ids found in the sample file. */
    Set<Long> sampleItemIds;

    /** Distinct item ids found in the recommendations file. */
    Set<Long> recommendedItemIds;

    /** Distinct user ids found in the sample file. */
    Set<Long> sampleUserIds;

    /** For each item id of the sample, the user ids it is paired with in the sample. */
    Map<Long, Set<Long>> sampleItemUserIds;

    int recommendableItemCount;
    int recommendedItemCount;

    /** items x users minus the pairs already in the sample; {@code long} because the product can exceed {@code int}. */
    long recommendableItemUserCount;

    int recommendedItemUserCount;

    public static void main(String[] args) {
        SpringApplication.run(PostprocessingApplication.class, args);
    }

    /**
     * Registers the post-processing as a {@link CommandLineRunner}.
     *
     * @return a runner that ignores its arguments and calls {@link #run()}
     */
    @Bean
    public CommandLineRunner postprocess() {
        return (args) -> this.run();
    }

    /**
     * Runs the whole pipeline: load parameters, pick the sample file,
     * analyze the sample, analyze the recommendations, then log the coverage.
     */
    public void run() {
        loadParametersProperties();
        loadSampleFilename();
        analyzeSample();
        analyzeRecommendations();
        computeCoverage();
    }

    /**
     * Logs the coverage values derived from the counters filled by
     * {@link #analyzeSample()} and {@link #analyzeRecommendations()}.
     * <ul>
     * <li>c1 = recommended item count / recommendable item count</li>
     * <li>c2 = recommended item-user pair count / recommendable item-user pair count</li>
     * <li>c3 = recommended item count</li>
     * </ul>
     * The ratios use float division, so a zero denominator yields NaN or Infinity
     * rather than an exception.
     */
    void computeCoverage() {
        logger.trace("Nombre d'objets recommandés {}", recommendedItemCount);
        logger.trace("Nombre d'objets recommandables {}", recommendableItemCount);
        float c1 = (float) recommendedItemCount / recommendableItemCount;
        logger.trace("c1 {}", String.format("%.3f", c1));

        logger.trace("Nombre de couples item-user recommandés {}", recommendedItemUserCount);
        logger.trace("Nombre de couples item-user recommandables {}", recommendableItemUserCount);
        float c2 = (float) recommendedItemUserCount / recommendableItemUserCount;
        logger.trace("c2 {}", String.format("%.3f", c2));

        int c3 = recommendedItemCount;
        logger.trace("c3 {}", c3);
    }

    /**
     * Reads the recommendations file and fills {@code recommendedItemIds},
     * {@code recommendedItemCount} and {@code recommendedItemUserCount}.
     * A recommendation row is counted as a recommended pair only when its item is
     * present in the sample and its user is not already paired with that item there.
     * I/O failures are logged and leave the counters at whatever was computed so far.
     */
    void analyzeRecommendations() {
        recommendedItemUserCount = 0;
        recommendedItemIds = new HashSet<>();
        // try-with-resources: the reader is closed even when parsing fails
        try (Reader in = new InputStreamReader(new FileInputStream(dataDir + recommandationsFilename))) {
            Iterable<CSVRecord> records = CSVFormat.TDF.withFirstRecordAsHeader().parse(in);
            for (CSVRecord record : records) {
                Long itemId = Long.parseLong(record.get("itemId"));
                Long userId = Long.parseLong(record.get("userId"));
                recommendedItemIds.add(itemId);
                Set<Long> sampleUsers = sampleItemUserIds.get(itemId);
                if (sampleUsers != null && !sampleUsers.contains(userId)) {
                    recommendedItemUserCount++;
                }
            }
            recommendedItemCount = recommendedItemIds.size();
            logger.trace("Nombre d'objets recommandés {}", recommendedItemCount);
            logger.trace("Nombre de couples item-user recommandés {}", recommendedItemUserCount);
        } catch (IOException e) {
            // pass the throwable itself so the full stack trace is logged
            logger.error("analyzeRecommendations IOException : {}", e.getMessage(), e);
        }
    }

    /**
     * Reads the sample file and fills {@code sampleItemIds}, {@code sampleUserIds},
     * {@code sampleItemUserIds}, {@code recommendableItemCount} and
     * {@code recommendableItemUserCount}.
     * I/O failures are logged; the collections are then left as read so far (possibly empty).
     */
    void analyzeSample() {
        sampleItemIds = new HashSet<>();
        sampleUserIds = new HashSet<>();
        sampleItemUserIds = new HashMap<>();
        // try-with-resources: the reader is closed even when parsing fails
        try (Reader in = new InputStreamReader(new FileInputStream(dataDir + sampleFilename))) {
            Iterable<CSVRecord> records = CSVFormat.TDF.withFirstRecordAsHeader().parse(in);
            for (CSVRecord record : records) {
                Long itemId = Long.parseLong(record.get("itemId"));
                Long userId = Long.parseLong(record.get("userId"));
                sampleItemIds.add(itemId);
                sampleUserIds.add(userId);
                sampleItemUserIds.computeIfAbsent(itemId, key -> new HashSet<>()).add(userId);
            }
            recommendableItemCount = sampleItemIds.size();
            logger.trace("Nombre d'objets recommandables {}", recommendableItemCount);

            // computed in long: items * users overflows int for large samples
            long matrixSize = (long) sampleItemIds.size() * sampleUserIds.size();
            logger.trace("Taille de la matrice item-user {}", matrixSize);

            long sampleCoupleCount = 0;
            for (Set<Long> userIds : sampleItemUserIds.values()) {
                sampleCoupleCount += userIds.size();
            }
            recommendableItemUserCount = matrixSize - sampleCoupleCount;
            logger.trace("Nombre de couples item-user dans l'échantillon {}", sampleCoupleCount);
            logger.trace("Nombre de couples item-user recommandables {}", recommendableItemUserCount);
        } catch (IOException e) {
            // pass the throwable itself so the full stack trace is logged
            logger.error("analyzeSample IOException : {}", e.getMessage(), e);
        }
    }

    /**
     * Chooses the sample file from the {@code ratings} property: the rating sample
     * when it parses as {@code true}, the collection sample otherwise.
     * When the property is absent the collection sample is used, so that
     * {@code sampleFilename} is never left null.
     */
    void loadSampleFilename() {
        if (properties.containsKey("ratings")) {
            logger.trace("ratings {}", properties.getProperty("ratings"));
        }
        // Boolean.parseBoolean(null) is false, which covers the missing-key default
        if (Boolean.parseBoolean(properties.getProperty("ratings"))) {
            sampleFilename = ratingSampleFilename;
        } else {
            sampleFilename = collectionSampleFilename;
        }
        logger.trace("sampleFilename {}", sampleFilename);
    }

    /**
     * Loads {@code dataDir + parametersFilename} into {@code properties}.
     * On I/O failure the error is logged and {@code properties} holds whatever
     * was loaded before the failure (possibly nothing).
     */
    void loadParametersProperties() {
        Properties loaded = new Properties();
        // try-with-resources: the stream is closed even when load() throws
        try (InputStream in = new FileInputStream(dataDir + parametersFilename)) {
            loaded.load(in);
        } catch (IOException e) {
            logger.error("loadParametersProperties IOException : {}", e.getMessage(), e);
        }
        this.properties = loaded;
    }
}
......
package org.legrog.recommendation.postprocess;
/**
 * Immutable holder for the three coverage values c1, c2 and c3.
 */
public class PostprocessingCoverage {

    // set once in the constructor and never modified afterwards
    private final float c1;
    private final float c2;
    private final int c3;

    /**
     * @param c1 first coverage value
     * @param c2 second coverage value
     * @param c3 third coverage value
     */
    public PostprocessingCoverage(float c1, float c2, int c3) {
        this.c1 = c1;
        this.c2 = c2;
        this.c3 = c3;
    }

    /** @return the c1 value given at construction */
    public float getC1() {
        return c1;
    }

    /** @return the c2 value given at construction */
    public float getC2() {
        return c2;
    }

    /** @return the c3 value given at construction */
    public int getC3() {
        return c3;
    }
}
package org.legrog.recommendation.postprocess;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
/**
 * Computes coverage values from a list of sample (user, item) pairs and a list of
 * recommended (user, item) pairs.
 */
public class PostprocessingExpert {

    List<PostprocessingSample> sampleList;
    List<PostprocessingSample> recommendationList;
    Logger logger = LoggerFactory.getLogger(getClass());

    /** Distinct item ids found in the sample list. */
    Set<Long> sampleItemIds;

    /** Distinct item ids found in the recommendation list. */
    Set<Long> recommendedItemIds;

    /** Distinct user ids found in the sample list. */
    Set<Long> sampleUserIds;

    /** For each item id of the sample, the user ids it is paired with in the sample. */
    Map<Long, Set<Long>> sampleItemUserIds;

    int recommendableItemCount;
    int recommendedItemCount;

    /** items x users minus the pairs already in the sample; {@code long} because the product can exceed {@code int}. */
    long recommendableItemUserCount;

    int recommendedItemUserCount;

    /**
     * @param sampleList         the (user, item) pairs of the sample
     * @param recommendationList the (user, item) pairs that were recommended
     */
    public PostprocessingExpert(List<PostprocessingSample> sampleList, List<PostprocessingSample> recommendationList) {
        this.sampleList = sampleList;
        this.recommendationList = recommendationList;
    }

    /**
     * Analyzes the sample, then the recommendations, and derives the coverage.
     *
     * @return the computed coverage values
     */
    public PostprocessingCoverage getCoverage() {
        analyzeSample();
        analyzeRecommendations();
        return computeCoverage();
    }

    /**
     * Fills {@code sampleItemIds}, {@code sampleUserIds}, {@code sampleItemUserIds},
     * {@code recommendableItemCount} and {@code recommendableItemUserCount}
     * from {@code sampleList}.
     */
    protected void analyzeSample() {
        sampleItemIds = new HashSet<>();
        sampleUserIds = new HashSet<>();
        sampleItemUserIds = new HashMap<>();
        for (PostprocessingSample sample : sampleList) {
            Long itemId = sample.getItemId();
            Long userId = sample.getUserId();
            sampleItemIds.add(itemId);
            sampleUserIds.add(userId);
            sampleItemUserIds.computeIfAbsent(itemId, key -> new HashSet<>()).add(userId);
        }
        recommendableItemCount = sampleItemIds.size();
        logger.trace("Nombre d'objets recommandables {}", recommendableItemCount);

        // computed in long: items * users overflows int for large samples
        long matrixSize = (long) sampleItemIds.size() * sampleUserIds.size();
        logger.trace("Taille de la matrice item-user {}", matrixSize);

        long sampleCoupleCount = 0;
        for (Set<Long> userIds : sampleItemUserIds.values()) {
            sampleCoupleCount += userIds.size();
        }
        recommendableItemUserCount = matrixSize - sampleCoupleCount;
        logger.trace("Nombre de couples item-user dans l'échantillon {}", sampleCoupleCount);
        logger.trace("Nombre de couples item-user recommandables {}", recommendableItemUserCount);
    }

    /**
     * Fills {@code recommendedItemIds}, {@code recommendedItemCount} and
     * {@code recommendedItemUserCount} from {@code recommendationList}.
     * A recommendation is counted as a recommended pair only when its item is present
     * in the sample and its user is not already paired with that item there.
     * Must run after {@link #analyzeSample()}, which builds {@code sampleItemUserIds}.
     */
    protected void analyzeRecommendations() {
        recommendedItemUserCount = 0;
        recommendedItemIds = new HashSet<>();
        for (PostprocessingSample reco : recommendationList) {
            Long itemId = reco.getItemId();
            Long userId = reco.getUserId();
            recommendedItemIds.add(itemId);
            Set<Long> sampleUsers = sampleItemUserIds.get(itemId);
            if (sampleUsers != null && !sampleUsers.contains(userId)) {
                recommendedItemUserCount++;
            }
        }
        recommendedItemCount = recommendedItemIds.size();
        logger.trace("Nombre d'objets recommandés {}", recommendedItemCount);
        logger.trace("Nombre de couples item-user recommandés {}", recommendedItemUserCount);
    }

    /**
     * Builds the coverage from the counters filled by the two analyze steps.
     * <ul>
     * <li>c1 = recommended item count / recommendable item count</li>
     * <li>c2 = recommended item-user pair count / recommendable item-user pair count</li>
     * <li>c3 = recommended item count</li>
     * </ul>
     * The ratios use float division, so a zero denominator yields NaN or Infinity
     * rather than an exception.
     *
     * @return a new {@link PostprocessingCoverage} holding c1, c2 and c3
     */
    protected PostprocessingCoverage computeCoverage() {
        logger.trace("Nombre d'objets recommandés {}", recommendedItemCount);
        logger.trace("Nombre d'objets recommandables {}", recommendableItemCount);
        float c1 = (float) recommendedItemCount / recommendableItemCount;
        logger.trace("c1 {}", String.format("%.3f", c1));

        logger.trace("Nombre de couples item-user recommandés {}", recommendedItemUserCount);
        logger.trace("Nombre de couples item-user recommandables {}", recommendableItemUserCount);
        float c2 = (float) recommendedItemUserCount / recommendableItemUserCount;
        logger.trace("c2 {}", String.format("%.3f", c2));

        int c3 = recommendedItemCount;
        logger.trace("c3 {}", c3);
        return new PostprocessingCoverage(c1, c2, c3);
    }
}
package org.legrog.recommendation.postprocess;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;
import java.io.*;
import java.util.List;
import java.util.Properties;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
/**
 * Application runner that selects the sample file, loads the sample and the
 * recommendations from tab-delimited files, and asks a {@link PostprocessingExpert}
 * for the coverage.
 */
@Component
public class PostprocessingRunner implements ApplicationRunner {

    @Value("${parameters.filename}")
    private String parametersFilename;

    @Value("${data.dir}")
    private String dataDir;

    @Value("${collectionSample.filename}")
    private String collectionSampleFilename;

    @Value("${ratingSample.filename}")
    private String ratingSampleFilename;

    @Value("${recommandations.filename}")
    private String recommandationsFilename;

    private Logger logger = LoggerFactory.getLogger(getClass());

    /** Sample file chosen by {@link #loadSampleFilename()}. */
    private String sampleFilename;

    /**
     * Loads the sample and the recommendations from {@code dataDir} and computes the coverage.
     *
     * @param args application arguments (not used)
     * @throws Exception if the properties file or one of the CSV files cannot be read
     */
    @Override
    public void run(ApplicationArguments args) throws Exception {
        loadSampleFilename();

        List<PostprocessingSample> samples = loadCsvSample(new File(dataDir, sampleFilename));
        List<PostprocessingSample> recommendations = loadCsvSample(new File(dataDir, recommandationsFilename));

        PostprocessingExpert expert = new PostprocessingExpert(samples, recommendations);
        PostprocessingCoverage coverage = expert.getCoverage();

        //todo write coverage in a file to be read by user
        //...
    }

    /**
     * Reads a tab-delimited file with a header row and maps every record to a
     * {@link PostprocessingSample} built from its {@code userId} and {@code itemId} columns.
     *
     * @param file the file to read
     * @return one sample per record, in file order
     * @throws PostprocessingException if the file cannot be opened or read
     */
    private List<PostprocessingSample> loadCsvSample(File file) throws PostprocessingException {
        try (Reader in = new InputStreamReader(new FileInputStream(file))) {
            Iterable<CSVRecord> records = CSVFormat.TDF.withFirstRecordAsHeader().parse(in);
            return StreamSupport.stream(records.spliterator(), false)
                    .map((record) -> new PostprocessingSample(
                            Long.parseLong(record.get("userId")),
                            Long.parseLong(record.get("itemId"))))
                    .collect(Collectors.toList());
        } catch (IOException e) {
            // report the file that actually failed: this method also loads the recommendations file
            throw new PostprocessingException("Can't read CSV file " + file, e);
        }
    }

    /**
     * Reads the properties file {@code parametersFilename} from {@code dataDir} and
     * selects the sample file from its {@code ratings} property: the rating sample
     * when it parses as {@code true}, the collection sample otherwise (including
     * when the property is absent).
     *
     * todo replace this by a command line switch ?
     *
     * @throws PostprocessingException if the properties file cannot be read
     */
    private void loadSampleFilename() throws PostprocessingException {
        try (InputStream in = new FileInputStream(new File(dataDir, parametersFilename))) {
            Properties properties = new Properties();
            properties.load(in);

            if (properties.containsKey("ratings")) {
                logger.trace("ratings {}", properties.getProperty("ratings"));
            }
            // Boolean.parseBoolean(null) is false, so a missing key falls back to the collection sample
            if (Boolean.parseBoolean(properties.getProperty("ratings"))) {
                sampleFilename = ratingSampleFilename;
            } else {
                sampleFilename = collectionSampleFilename;
            }
        } catch (IOException e) {
            throw new PostprocessingException("Can't read properties file " + parametersFilename, e);
        }
    }

    /**
     * Checked exception raised when an input file of the post-processing cannot be read.
     * Declared static: it does not need a reference to the enclosing runner.
     */
    private static class PostprocessingException extends Exception {

        public PostprocessingException() {
            super();
        }

        public PostprocessingException(String message) {
            super(message);
        }

        public PostprocessingException(String message, Throwable cause) {
            super(message, cause);
        }

        public PostprocessingException(Throwable cause) {
            super(cause);
        }

        protected PostprocessingException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
            super(message, cause, enableSuppression, writableStackTrace);
        }
    }
}
package org.legrog.recommendation.postprocess;
/**
 * Immutable (user id, item id) pair.
 */
public class PostprocessingSample {

    // set once in the constructor and never modified afterwards
    private final Long userId;
    private final Long itemId;

    /**
     * @param userId id of the user
     * @param itemId id of the item
     */
    public PostprocessingSample(Long userId, Long itemId) {
        this.userId = userId;
        this.itemId = itemId;
    }

    /** @return the user id given at construction */
    public Long getUserId() {
        return userId;
    }

    /** @return the item id given at construction */
    public Long getItemId() {
        return itemId;
    }
}