Jean-Francois Leveque

Calcul de précision et rappel

......@@ -10,6 +10,7 @@ public class PostprocessingExpert {
List<PostprocessingSample> sampleList;
List<PostprocessingSample> recommendationList;
List<PostprocessingSample> annotatedList;
Logger logger = LoggerFactory.getLogger(getClass());
Set<Long> sampleItemIds;
......@@ -19,19 +20,30 @@ public class PostprocessingExpert {
int recommendableItemCount;
int recommendedItemCount;
int recommendableItemUserCount;
int annotatedItemUserCount;
int recommendedItemUserCount;
int validRecommendationCount;
public PostprocessingExpert(List<PostprocessingSample> sampleList, List<PostprocessingSample> recommendationList) {
public PostprocessingExpert(List<PostprocessingSample> sampleList, List<PostprocessingSample> recommendationList,
List<PostprocessingSample> annotatedList) {
this.sampleList = sampleList;
this.recommendationList = recommendationList;
this.annotatedList = annotatedList;
}
public PostprocessingCoverage getCoverage() {
public void analyze() {
analyzeSample();
analyzeRecommendations();
}
public PostprocessingCoverage getCoverage() {
return computeCoverage();
}
public PostprocessingPrecisionRecall getPrecisionRecall() {
return computePrecisionRecall();
}
protected void analyzeSample() {
sampleItemIds = new HashSet<>();
......@@ -55,8 +67,8 @@ public class PostprocessingExpert {
}
recommendableItemCount = sampleItemIds.size();
logger.trace("Nombre d'objets recommandables {}", recommendableItemCount);
logger.trace("Taille de la matrice item-user {}", sampleItemIds.size() * sampleUserIds.size());
logger.trace("C: Nombre d'objets recommandables {}", recommendableItemCount);
logger.trace("C: Taille de la matrice item-user {}", sampleItemIds.size() * sampleUserIds.size());
int sampleCoupleCount = 0;
for (Long itemId : sampleItemIds) {
......@@ -64,13 +76,18 @@ public class PostprocessingExpert {
}
recommendableItemUserCount = sampleItemIds.size() * sampleUserIds.size() - sampleCoupleCount;
logger.trace("Nombre de couples item-user dans l'échantillon {}", sampleCoupleCount);
logger.trace("Nombre de couples item-user recommandables {}", recommendableItemUserCount);
logger.trace("C: Nombre de couples item-user dans l'échantillon {}", sampleCoupleCount);
logger.trace("C: Nombre de couples item-user recommandables {}", recommendableItemUserCount);
}
protected void analyzeRecommendations() {
recommendedItemUserCount = 0;
validRecommendationCount = 0;
recommendedItemIds = new HashSet<>();
for (PostprocessingSample annote : annotatedList) {
logger.trace("Annotated item {}, user {}", annote.getItemId(), annote.getUserId());
}
for (PostprocessingSample reco : recommendationList) {
Long itemId = reco.getItemId();
Long userId = reco.getUserId();
......@@ -81,10 +98,17 @@ public class PostprocessingExpert {
recommendedItemUserCount++;
}
}
logger.trace("Recommendation item {}, user {}", reco.getItemId(), reco.getUserId());
if (annotatedList.contains(reco)) {
validRecommendationCount++;
}
}
recommendedItemCount = recommendedItemIds.size();
logger.trace("Nombre d'objets recommandés {}", recommendedItemCount);
logger.trace("Nombre de couples item-user recommandés {}", recommendedItemUserCount);
logger.trace("C: Nombre d'objets recommandés {}", recommendedItemCount);
logger.trace("C/PR: Nombre de couples item-user recommandés {}", recommendedItemUserCount);
annotatedItemUserCount = annotatedList.size();
logger.trace("PR: Nombre d'associations annotées {}", annotatedItemUserCount);
logger.trace("PR: Nombre de recommandations annotées {}", validRecommendationCount);
}
protected PostprocessingCoverage computeCoverage() {
......@@ -92,12 +116,12 @@ public class PostprocessingExpert {
float c2;
int c3;
logger.trace("Nombre d'objets recommandés {}", recommendedItemCount);
logger.trace("Nombre d'objets recommandables {}", recommendableItemCount);
logger.trace("C: Nombre d'objets recommandés {}", recommendedItemCount);
logger.trace("C: Nombre d'objets recommandables {}", recommendableItemCount);
c1 = (float) recommendedItemCount / recommendableItemCount;
logger.trace("c1 {}", String.format(Locale.FRENCH, "%.3f", c1));
logger.trace("Nombre de couples item-user recommandés {}", recommendedItemUserCount);
logger.trace("Nombre de couples item-user recommandables {}", recommendableItemUserCount);
logger.trace("C: Nombre de couples item-user recommandés {}", recommendedItemUserCount);
logger.trace("C: Nombre de couples item-user recommandables {}", recommendableItemUserCount);
c2 = (float) recommendedItemUserCount / recommendableItemUserCount;
logger.trace("c2 {}", String.format(Locale.FRENCH, "%.3f", c2));
c3 = recommendedItemCount;
......@@ -106,8 +130,19 @@ public class PostprocessingExpert {
return new PostprocessingCoverage(c1,c2, c3);
}
protected PostprocessingPrecisionRecall computePrecisionRecall() {
float precision;
float recall;
logger.trace("PR: nombre de recommandations annotées {}", validRecommendationCount);
logger.trace("PR: nombre de recommandations {}", recommendedItemUserCount);
precision = (float) validRecommendationCount / recommendedItemUserCount;
logger.trace("PR: précision {}", String.format(Locale.FRENCH, "%.3f", precision));
logger.trace("PR: nombre d'associations annotées {}", annotatedItemUserCount);
recall = (float) validRecommendationCount / annotatedItemUserCount;
logger.trace("PR: rappel {}", String.format(Locale.FRENCH, "%.3f", recall));
return new PostprocessingPrecisionRecall(precision, recall);
}
}
......
package org.legrog.recommendation.postprocess;
public class PostprocessingPrecisionRecall {
private float precision;
private float recall;
public PostprocessingPrecisionRecall(float precision, float recall) {
this.precision = precision;
this.recall = recall;
}
public float getPrecision() {
return precision;
}
public float getRecall() {
return recall;
}
}
......@@ -11,10 +11,7 @@ import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;
import java.io.*;
import java.util.List;
import java.util.Locale;
import java.util.Properties;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
......@@ -33,14 +30,24 @@ public class PostprocessingRunner implements ApplicationRunner {
@Value("${ratingSample.filename}")
private String ratingSampleFilename;
@Value("${collectionAnnotated.filename}")
private String collectionAnnotatedFilename;
@Value("${ratingAnnotated.filename}")
private String ratingAnnotatedFilename;
@Value("${recommandations.filename}")
private String recommandationsFilename;
@Value("${coverage.filename}")
private String coverageFilename;
@Value("${precisionRecall.filename}")
private String precisionRecallFilename;
private Logger logger = LoggerFactory.getLogger(getClass());
private String sampleFilename;
private String annotatedFilename;
@Override
......@@ -49,11 +56,15 @@ public class PostprocessingRunner implements ApplicationRunner {
loadSampleFilename();
List<PostprocessingSample> samples = loadCsvSample(new File(dataDir, sampleFilename));
List<PostprocessingSample> recommendations = loadCsvSample(new File(dataDir, recommandationsFilename));
List<PostprocessingSample> annotated = loadCsvSample(new File(dataDir, annotatedFilename));
PostprocessingExpert expert = new PostprocessingExpert(samples, recommendations);
PostprocessingExpert expert = new PostprocessingExpert(samples, recommendations, annotated);
expert.analyze();
PostprocessingCoverage coverage = expert.getCoverage();
PostprocessingPrecisionRecall precisionRecall = expert.getPrecisionRecall();
writeCsvCoverage(coverage, dataDir, coverageFilename);
writeCsvPrecisionRecall(precisionRecall, dataDir, precisionRecallFilename);
}
private void writeCsvCoverage(PostprocessingCoverage coverage, String dataDir, String coverageFilename) throws PostprocessingException {
......@@ -69,6 +80,19 @@ public class PostprocessingRunner implements ApplicationRunner {
}
private void writeCsvPrecisionRecall(PostprocessingPrecisionRecall precisionRecall, String dataDir, String precisionRecallFilename) throws PostprocessingException {
try {
CSVPrinter csvPrinter = new CSVPrinter(new FileWriter(new File(dataDir, precisionRecallFilename)),
CSVFormat.TDF.withHeader("Precision", "Recall"));
csvPrinter.printRecord(String.format(Locale.FRENCH, "%.3f", precisionRecall.getPrecision()),
String.format(Locale.FRENCH, "%.3f", precisionRecall.getRecall()));
csvPrinter.close();
} catch (IOException e) {
throw new PostprocessingException("Can't write coverage file " + dataDir + precisionRecallFilename, e);
}
}
/**
* read csv (TDF) file and map it to a list of PostprocessingSample
*
......@@ -77,6 +101,10 @@ public class PostprocessingRunner implements ApplicationRunner {
* @throws PostprocessingException
*/
private List<PostprocessingSample> loadCsvSample(File file) throws PostprocessingException {
if (!file.exists() || file.isDirectory()) {
return new LinkedList<>();
}
try (Reader in = new InputStreamReader(new FileInputStream(file))) {
Iterable<CSVRecord> records = CSVFormat.TDF.withFirstRecordAsHeader().parse(in);
......@@ -109,12 +137,15 @@ public class PostprocessingRunner implements ApplicationRunner {
logger.trace("ratings {}", properties.getProperty("ratings"));
if (Boolean.parseBoolean(properties.getProperty("ratings"))) {
sampleFilename = ratingSampleFilename;
annotatedFilename = ratingAnnotatedFilename;
} else {
sampleFilename = collectionSampleFilename;
annotatedFilename = collectionAnnotatedFilename;
}
} else {
// by default, takes collection
sampleFilename = collectionSampleFilename;
annotatedFilename = collectionAnnotatedFilename;
}
} catch (IOException e) {
throw new PostprocessingException("Can't read properties file " + parametersFilename, e);
......
......@@ -9,6 +9,15 @@ public class PostprocessingSample {
this.itemId = itemId;
}
public boolean equals(Object obj) {
if (obj instanceof PostprocessingSample) {
PostprocessingSample postprocessingSample = (PostprocessingSample) obj;
return this.itemId == postprocessingSample.getItemId() && this.userId == postprocessingSample.getUserId();
} else {
return false;
}
}
public Long getUserId() {
return userId;
}
......