Jean-Francois Leveque

Calcul de précision et rappel

...@@ -10,6 +10,7 @@ public class PostprocessingExpert { ...@@ -10,6 +10,7 @@ public class PostprocessingExpert {
10 10
11 List<PostprocessingSample> sampleList; 11 List<PostprocessingSample> sampleList;
12 List<PostprocessingSample> recommendationList; 12 List<PostprocessingSample> recommendationList;
13 + List<PostprocessingSample> annotatedList;
13 Logger logger = LoggerFactory.getLogger(getClass()); 14 Logger logger = LoggerFactory.getLogger(getClass());
14 15
15 Set<Long> sampleItemIds; 16 Set<Long> sampleItemIds;
...@@ -19,19 +20,30 @@ public class PostprocessingExpert { ...@@ -19,19 +20,30 @@ public class PostprocessingExpert {
19 int recommendableItemCount; 20 int recommendableItemCount;
20 int recommendedItemCount; 21 int recommendedItemCount;
21 int recommendableItemUserCount; 22 int recommendableItemUserCount;
23 + int annotatedItemUserCount;
22 int recommendedItemUserCount; 24 int recommendedItemUserCount;
25 + int validRecommendationCount;
23 26
/**
 * Creates an expert working on the three given lists.
 * The lists are stored by reference; nothing is computed until {@code analyze()} is called.
 *
 * @param sampleList         samples the recommendations were built from
 * @param recommendationList item-user recommendations to evaluate
 * @param annotatedList      annotated item-user associations that recommendations are checked against
 */
public PostprocessingExpert(List<PostprocessingSample> sampleList,
                            List<PostprocessingSample> recommendationList,
                            List<PostprocessingSample> annotatedList) {
    this.annotatedList = annotatedList;
    this.recommendationList = recommendationList;
    this.sampleList = sampleList;
}
28 33
29 - public PostprocessingCoverage getCoverage() { 34 + public void analyze() {
30 analyzeSample(); 35 analyzeSample();
31 analyzeRecommendations(); 36 analyzeRecommendations();
37 + }
38 +
39 + public PostprocessingCoverage getCoverage() {
32 return computeCoverage(); 40 return computeCoverage();
33 } 41 }
34 42
43 + public PostprocessingPrecisionRecall getPrecisionRecall() {
44 + return computePrecisionRecall();
45 + }
46 +
35 protected void analyzeSample() { 47 protected void analyzeSample() {
36 48
37 sampleItemIds = new HashSet<>(); 49 sampleItemIds = new HashSet<>();
...@@ -55,8 +67,8 @@ public class PostprocessingExpert { ...@@ -55,8 +67,8 @@ public class PostprocessingExpert {
55 } 67 }
56 68
57 recommendableItemCount = sampleItemIds.size(); 69 recommendableItemCount = sampleItemIds.size();
58 - logger.trace("Nombre d'objets recommandables {}", recommendableItemCount); 70 + logger.trace("C: Nombre d'objets recommandables {}", recommendableItemCount);
59 - logger.trace("Taille de la matrice item-user {}", sampleItemIds.size() * sampleUserIds.size()); 71 + logger.trace("C: Taille de la matrice item-user {}", sampleItemIds.size() * sampleUserIds.size());
60 72
61 int sampleCoupleCount = 0; 73 int sampleCoupleCount = 0;
62 for (Long itemId : sampleItemIds) { 74 for (Long itemId : sampleItemIds) {
...@@ -64,13 +76,18 @@ public class PostprocessingExpert { ...@@ -64,13 +76,18 @@ public class PostprocessingExpert {
64 } 76 }
65 77
66 recommendableItemUserCount = sampleItemIds.size() * sampleUserIds.size() - sampleCoupleCount; 78 recommendableItemUserCount = sampleItemIds.size() * sampleUserIds.size() - sampleCoupleCount;
67 - logger.trace("Nombre de couples item-user dans l'échantillon {}", sampleCoupleCount); 79 + logger.trace("C: Nombre de couples item-user dans l'échantillon {}", sampleCoupleCount);
68 - logger.trace("Nombre de couples item-user recommandables {}", recommendableItemUserCount); 80 + logger.trace("C: Nombre de couples item-user recommandables {}", recommendableItemUserCount);
69 } 81 }
70 82
71 protected void analyzeRecommendations() { 83 protected void analyzeRecommendations() {
72 recommendedItemUserCount = 0; 84 recommendedItemUserCount = 0;
85 + validRecommendationCount = 0;
73 recommendedItemIds = new HashSet<>(); 86 recommendedItemIds = new HashSet<>();
87 + for (PostprocessingSample annote : annotatedList) {
88 + logger.trace("Annotated item {}, user {}", annote.getItemId(), annote.getUserId());
89 + }
90 +
74 for (PostprocessingSample reco : recommendationList) { 91 for (PostprocessingSample reco : recommendationList) {
75 Long itemId = reco.getItemId(); 92 Long itemId = reco.getItemId();
76 Long userId = reco.getUserId(); 93 Long userId = reco.getUserId();
...@@ -81,10 +98,17 @@ public class PostprocessingExpert { ...@@ -81,10 +98,17 @@ public class PostprocessingExpert {
81 recommendedItemUserCount++; 98 recommendedItemUserCount++;
82 } 99 }
83 } 100 }
101 + logger.trace("Recommendation item {}, user {}", reco.getItemId(), reco.getUserId());
102 + if (annotatedList.contains(reco)) {
103 + validRecommendationCount++;
104 + }
84 } 105 }
85 recommendedItemCount = recommendedItemIds.size(); 106 recommendedItemCount = recommendedItemIds.size();
86 - logger.trace("Nombre d'objets recommandés {}", recommendedItemCount); 107 + logger.trace("C: Nombre d'objets recommandés {}", recommendedItemCount);
87 - logger.trace("Nombre de couples item-user recommandés {}", recommendedItemUserCount); 108 + logger.trace("C/PR: Nombre de couples item-user recommandés {}", recommendedItemUserCount);
109 + annotatedItemUserCount = annotatedList.size();
110 + logger.trace("PR: Nombre d'associations annotées {}", annotatedItemUserCount);
111 + logger.trace("PR: Nombre de recommandations annotées {}", validRecommendationCount);
88 } 112 }
89 113
90 protected PostprocessingCoverage computeCoverage() { 114 protected PostprocessingCoverage computeCoverage() {
...@@ -92,12 +116,12 @@ public class PostprocessingExpert { ...@@ -92,12 +116,12 @@ public class PostprocessingExpert {
92 float c2; 116 float c2;
93 int c3; 117 int c3;
94 118
95 - logger.trace("Nombre d'objets recommandés {}", recommendedItemCount); 119 + logger.trace("C: Nombre d'objets recommandés {}", recommendedItemCount);
96 - logger.trace("Nombre d'objets recommandables {}", recommendableItemCount); 120 + logger.trace("C: Nombre d'objets recommandables {}", recommendableItemCount);
97 c1 = (float) recommendedItemCount / recommendableItemCount; 121 c1 = (float) recommendedItemCount / recommendableItemCount;
98 logger.trace("c1 {}", String.format(Locale.FRENCH, "%.3f", c1)); 122 logger.trace("c1 {}", String.format(Locale.FRENCH, "%.3f", c1));
99 - logger.trace("Nombre de couples item-user recommandés {}", recommendedItemUserCount); 123 + logger.trace("C: Nombre de couples item-user recommandés {}", recommendedItemUserCount);
100 - logger.trace("Nombre de couples item-user recommandables {}", recommendableItemUserCount); 124 + logger.trace("C: Nombre de couples item-user recommandables {}", recommendableItemUserCount);
101 c2 = (float) recommendedItemUserCount / recommendableItemUserCount; 125 c2 = (float) recommendedItemUserCount / recommendableItemUserCount;
102 logger.trace("c2 {}", String.format(Locale.FRENCH, "%.3f", c2)); 126 logger.trace("c2 {}", String.format(Locale.FRENCH, "%.3f", c2));
103 c3 = recommendedItemCount; 127 c3 = recommendedItemCount;
...@@ -106,8 +130,19 @@ public class PostprocessingExpert { ...@@ -106,8 +130,19 @@ public class PostprocessingExpert {
106 return new PostprocessingCoverage(c1,c2, c3); 130 return new PostprocessingCoverage(c1,c2, c3);
107 } 131 }
108 132
133 + protected PostprocessingPrecisionRecall computePrecisionRecall() {
134 + float precision;
135 + float recall;
109 136
137 + logger.trace("PR: nombre de recommandations annotées {}", validRecommendationCount);
138 + logger.trace("PR: nombre de recommandations {}", recommendedItemUserCount);
139 + precision = (float) validRecommendationCount / recommendedItemUserCount;
140 + logger.trace("PR: précision {}", String.format(Locale.FRENCH, "%.3f", precision));
141 + logger.trace("PR: nombre d'associations annotées {}", annotatedItemUserCount);
142 + recall = (float) validRecommendationCount / annotatedItemUserCount;
143 + logger.trace("PR: rappel {}", String.format(Locale.FRENCH, "%.3f", recall));
110 144
111 - 145 + return new PostprocessingPrecisionRecall(precision, recall);
146 + }
112 147
113 } 148 }
......
1 +package org.legrog.recommendation.postprocess;
2 +
3 +
/**
 * Immutable holder for a precision / recall pair.
 * The values are stored exactly as given; no range check or normalisation is applied.
 */
public class PostprocessingPrecisionRecall {
    // Set once in the constructor and never reassigned.
    private final float precision;
    private final float recall;

    /**
     * @param precision precision value to carry
     * @param recall    recall value to carry
     */
    public PostprocessingPrecisionRecall(float precision, float recall) {
        this.precision = precision;
        this.recall = recall;
    }

    /**
     * @return the precision value passed at construction
     */
    public float getPrecision() {
        return precision;
    }

    /**
     * @return the recall value passed at construction
     */
    public float getRecall() {
        return recall;
    }
}
...@@ -11,10 +11,7 @@ import org.springframework.boot.ApplicationRunner; ...@@ -11,10 +11,7 @@ import org.springframework.boot.ApplicationRunner;
11 import org.springframework.stereotype.Component; 11 import org.springframework.stereotype.Component;
12 12
13 import java.io.*; 13 import java.io.*;
14 -import java.util.List; 14 +import java.util.*;
15 -import java.util.Locale;
16 -import java.util.Properties;
17 -import java.util.Set;
18 import java.util.stream.Collectors; 15 import java.util.stream.Collectors;
19 import java.util.stream.StreamSupport; 16 import java.util.stream.StreamSupport;
20 17
...@@ -33,14 +30,24 @@ public class PostprocessingRunner implements ApplicationRunner { ...@@ -33,14 +30,24 @@ public class PostprocessingRunner implements ApplicationRunner {
33 @Value("${ratingSample.filename}") 30 @Value("${ratingSample.filename}")
34 private String ratingSampleFilename; 31 private String ratingSampleFilename;
35 32
33 + @Value("${collectionAnnotated.filename}")
34 + private String collectionAnnotatedFilename;
35 +
36 + @Value("${ratingAnnotated.filename}")
37 + private String ratingAnnotatedFilename;
38 +
36 @Value("${recommandations.filename}") 39 @Value("${recommandations.filename}")
37 private String recommandationsFilename; 40 private String recommandationsFilename;
38 41
39 @Value("${coverage.filename}") 42 @Value("${coverage.filename}")
40 private String coverageFilename; 43 private String coverageFilename;
41 44
45 + @Value("${precisionRecall.filename}")
46 + private String precisionRecallFilename;
47 +
42 private Logger logger = LoggerFactory.getLogger(getClass()); 48 private Logger logger = LoggerFactory.getLogger(getClass());
43 private String sampleFilename; 49 private String sampleFilename;
50 + private String annotatedFilename;
44 51
45 52
46 @Override 53 @Override
...@@ -49,11 +56,15 @@ public class PostprocessingRunner implements ApplicationRunner { ...@@ -49,11 +56,15 @@ public class PostprocessingRunner implements ApplicationRunner {
49 loadSampleFilename(); 56 loadSampleFilename();
50 List<PostprocessingSample> samples = loadCsvSample(new File(dataDir, sampleFilename)); 57 List<PostprocessingSample> samples = loadCsvSample(new File(dataDir, sampleFilename));
51 List<PostprocessingSample> recommendations = loadCsvSample(new File(dataDir, recommandationsFilename)); 58 List<PostprocessingSample> recommendations = loadCsvSample(new File(dataDir, recommandationsFilename));
59 + List<PostprocessingSample> annotated = loadCsvSample(new File(dataDir, annotatedFilename));
52 60
53 - PostprocessingExpert expert = new PostprocessingExpert(samples, recommendations); 61 + PostprocessingExpert expert = new PostprocessingExpert(samples, recommendations, annotated);
62 + expert.analyze();
54 PostprocessingCoverage coverage = expert.getCoverage(); 63 PostprocessingCoverage coverage = expert.getCoverage();
64 + PostprocessingPrecisionRecall precisionRecall = expert.getPrecisionRecall();
55 65
56 writeCsvCoverage(coverage, dataDir, coverageFilename); 66 writeCsvCoverage(coverage, dataDir, coverageFilename);
67 + writeCsvPrecisionRecall(precisionRecall, dataDir, precisionRecallFilename);
57 } 68 }
58 69
59 private void writeCsvCoverage(PostprocessingCoverage coverage, String dataDir, String coverageFilename) throws PostprocessingException { 70 private void writeCsvCoverage(PostprocessingCoverage coverage, String dataDir, String coverageFilename) throws PostprocessingException {
...@@ -69,6 +80,19 @@ public class PostprocessingRunner implements ApplicationRunner { ...@@ -69,6 +80,19 @@ public class PostprocessingRunner implements ApplicationRunner {
69 80
70 } 81 }
71 82
83 + private void writeCsvPrecisionRecall(PostprocessingPrecisionRecall precisionRecall, String dataDir, String precisionRecallFilename) throws PostprocessingException {
84 + try {
85 + CSVPrinter csvPrinter = new CSVPrinter(new FileWriter(new File(dataDir, precisionRecallFilename)),
86 + CSVFormat.TDF.withHeader("Precision", "Recall"));
87 + csvPrinter.printRecord(String.format(Locale.FRENCH, "%.3f", precisionRecall.getPrecision()),
88 + String.format(Locale.FRENCH, "%.3f", precisionRecall.getRecall()));
89 + csvPrinter.close();
90 + } catch (IOException e) {
91 + throw new PostprocessingException("Can't write coverage file " + dataDir + precisionRecallFilename, e);
92 + }
93 +
94 + }
95 +
72 /** 96 /**
73 * read csv (TDF) file and map it to a list of PostprocessingSample 97 * read csv (TDF) file and map it to a list of PostprocessingSample
74 * 98 *
...@@ -77,6 +101,10 @@ public class PostprocessingRunner implements ApplicationRunner { ...@@ -77,6 +101,10 @@ public class PostprocessingRunner implements ApplicationRunner {
77 * @throws PostprocessingException 101 * @throws PostprocessingException
78 */ 102 */
79 private List<PostprocessingSample> loadCsvSample(File file) throws PostprocessingException { 103 private List<PostprocessingSample> loadCsvSample(File file) throws PostprocessingException {
104 + if (!file.exists() || file.isDirectory()) {
105 + return new LinkedList<>();
106 + }
107 +
80 try (Reader in = new InputStreamReader(new FileInputStream(file))) { 108 try (Reader in = new InputStreamReader(new FileInputStream(file))) {
81 Iterable<CSVRecord> records = CSVFormat.TDF.withFirstRecordAsHeader().parse(in); 109 Iterable<CSVRecord> records = CSVFormat.TDF.withFirstRecordAsHeader().parse(in);
82 110
...@@ -109,12 +137,15 @@ public class PostprocessingRunner implements ApplicationRunner { ...@@ -109,12 +137,15 @@ public class PostprocessingRunner implements ApplicationRunner {
109 logger.trace("ratings {}", properties.getProperty("ratings")); 137 logger.trace("ratings {}", properties.getProperty("ratings"));
110 if (Boolean.parseBoolean(properties.getProperty("ratings"))) { 138 if (Boolean.parseBoolean(properties.getProperty("ratings"))) {
111 sampleFilename = ratingSampleFilename; 139 sampleFilename = ratingSampleFilename;
140 + annotatedFilename = ratingAnnotatedFilename;
112 } else { 141 } else {
113 sampleFilename = collectionSampleFilename; 142 sampleFilename = collectionSampleFilename;
143 + annotatedFilename = collectionAnnotatedFilename;
114 } 144 }
115 } else { 145 } else {
116 // by default, takes collection 146 // by default, takes collection
117 sampleFilename = collectionSampleFilename; 147 sampleFilename = collectionSampleFilename;
148 + annotatedFilename = collectionAnnotatedFilename;
118 } 149 }
119 } catch (IOException e) { 150 } catch (IOException e) {
120 throw new PostprocessingException("Can't read properties file " + parametersFilename, e); 151 throw new PostprocessingException("Can't read properties file " + parametersFilename, e);
......
...@@ -9,6 +9,15 @@ public class PostprocessingSample { ...@@ -9,6 +9,15 @@ public class PostprocessingSample {
9 this.itemId = itemId; 9 this.itemId = itemId;
10 } 10 }
11 11
12 + public boolean equals(Object obj) {
13 + if (obj instanceof PostprocessingSample) {
14 + PostprocessingSample postprocessingSample = (PostprocessingSample) obj;
15 + return this.itemId == postprocessingSample.getItemId() && this.userId == postprocessingSample.getUserId();
16 + } else {
17 + return false;
18 + }
19 + }
20 +
12 public Long getUserId() { 21 public Long getUserId() {
13 return userId; 22 return userId;
14 } 23 }
......