Pré-traitement : tirage aléatoire d'un pourcentage indiqué d'éléments annotés.
Showing
7 changed files
with
321 additions
and
0 deletions
... | @@ -14,4 +14,37 @@ | ... | @@ -14,4 +14,37 @@ |
14 | <artifactId>grog-recommendation-preprocess</artifactId> | 14 | <artifactId>grog-recommendation-preprocess</artifactId> |
15 | <version>3.0-SNAPSHOT</version> | 15 | <version>3.0-SNAPSHOT</version> |
16 | <packaging>jar</packaging> | 16 | <packaging>jar</packaging> |
17 | + <dependencies> | ||
18 | + <dependency> | ||
19 | + <groupId>org.springframework.boot</groupId> | ||
20 | + <artifactId>spring-boot-starter</artifactId> | ||
21 | + </dependency> | ||
22 | + <dependency> | ||
23 | + <groupId>org.springframework.boot</groupId> | ||
24 | + <artifactId>spring-boot-starter-test</artifactId> | ||
25 | + <scope>test</scope> | ||
26 | + </dependency> | ||
27 | + <dependency> | ||
28 | + <groupId>org.apache.commons</groupId> | ||
29 | + <artifactId>commons-csv</artifactId> | ||
30 | + <version>1.3</version> | ||
31 | + </dependency> | ||
32 | + </dependencies> | ||
33 | + | ||
34 | + <build> | ||
35 | + <plugins> | ||
36 | + <plugin> | ||
37 | + <groupId>org.springframework.boot</groupId> | ||
38 | + <artifactId>spring-boot-maven-plugin</artifactId> | ||
39 | + <version>1.5.2.RELEASE</version> | ||
40 | + <executions> | ||
41 | + <execution> | ||
42 | + <goals> | ||
43 | + <goal>repackage</goal> | ||
44 | + </goals> | ||
45 | + </execution> | ||
46 | + </executions> | ||
47 | + </plugin> | ||
48 | + </plugins> | ||
49 | + </build> | ||
17 | </project> | 50 | </project> |
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
1 | +package org.legrog.recommendation.preprocess; | ||
2 | + | ||
/**
 * Immutable holder for a user identifier and an item identifier.
 *
 * <p>The class is intentionally left open for extension: {@code RatingElement}
 * adds a rating on top of the two identifiers.
 */
public class AssociationElement {
    private final Long userId;
    private final Long itemId;

    /**
     * Creates an association between a user and an item.
     *
     * @param userId identifier of the user
     * @param itemId identifier of the item
     */
    public AssociationElement(Long userId, Long itemId) {
        this.userId = userId;
        this.itemId = itemId;
    }

    /**
     * @return the user identifier given at construction time
     */
    public Long getUserId() {
        return userId;
    }

    /**
     * @return the item identifier given at construction time
     */
    public Long getItemId() {
        return itemId;
    }
}
1 | +package org.legrog.recommendation.preprocess; | ||
2 | + | ||
3 | +import org.springframework.boot.SpringApplication; | ||
4 | +import org.springframework.boot.autoconfigure.SpringBootApplication; | ||
5 | + | ||
6 | +@SpringBootApplication | ||
7 | +public class PreprocessingApplication { | ||
8 | + | ||
9 | + public static void main(String[] args) { | ||
10 | + SpringApplication.run(PreprocessingApplication.class, args); | ||
11 | + } | ||
12 | + | ||
13 | +} |
1 | +package org.legrog.recommendation.preprocess; | ||
2 | + | ||
3 | +import org.apache.commons.csv.CSVFormat; | ||
4 | +import org.apache.commons.csv.CSVPrinter; | ||
5 | +import org.apache.commons.csv.CSVRecord; | ||
6 | +import org.slf4j.Logger; | ||
7 | +import org.slf4j.LoggerFactory; | ||
8 | +import org.springframework.beans.factory.annotation.Value; | ||
9 | +import org.springframework.boot.ApplicationArguments; | ||
10 | +import org.springframework.boot.ApplicationRunner; | ||
11 | +import org.springframework.stereotype.Component; | ||
12 | + | ||
13 | +import java.io.*; | ||
14 | +import java.util.ArrayList; | ||
15 | +import java.util.List; | ||
16 | +import java.util.Properties; | ||
17 | +import java.util.Random; | ||
18 | +import java.util.stream.Collectors; | ||
19 | +import java.util.stream.StreamSupport; | ||
20 | + | ||
21 | +@Component | ||
22 | +public class PreprocessingRunner implements ApplicationRunner { | ||
23 | + | ||
24 | + Logger logger = LoggerFactory.getLogger(getClass()); | ||
25 | + | ||
26 | + @Value("${parameters.filename}") | ||
27 | + private String parametersFilename; | ||
28 | + | ||
29 | + @Value("${data.dir}") | ||
30 | + private String dataDir; | ||
31 | + | ||
32 | + @Value("${collectionComplete.filename}") | ||
33 | + private String collectionCompleteFilename; | ||
34 | + | ||
35 | + @Value("${ratingComplete.filename}") | ||
36 | + private String ratingCompleteFilename; | ||
37 | + | ||
38 | + @Value("${collectionSample.filename}") | ||
39 | + private String collectionSampleFilename; | ||
40 | + | ||
41 | + @Value("${ratingSample.filename}") | ||
42 | + private String ratingSampleFilename; | ||
43 | + | ||
44 | + @Value("${collectionAnnotated.filename}") | ||
45 | + private String collectionAnnotatedFilename; | ||
46 | + | ||
47 | + @Value("${ratingAnnotated.filename}") | ||
48 | + private String ratingAnnotatedFilename; | ||
49 | + | ||
50 | + private String completeFilename; | ||
51 | + private String sampleFilename; | ||
52 | + private String annontatedFilename; | ||
53 | + | ||
54 | + private Boolean ratings; | ||
55 | + | ||
56 | + private int annotatePercent; | ||
57 | + | ||
58 | + @Override | ||
59 | + public void run(ApplicationArguments applicationArguments) throws Exception { | ||
60 | + loadParameters(); | ||
61 | + setFilenames(); | ||
62 | + List<AssociationElement> associationElements = loadAssociationElements(new File(dataDir, completeFilename)); | ||
63 | + List<Integer> annotateIndexes = chooseAnnotated(associationElements.size()); | ||
64 | + writeSampleAndAnnotated(new File(dataDir, sampleFilename), new File(dataDir, annontatedFilename), annotateIndexes, associationElements); | ||
65 | + } | ||
66 | + | ||
67 | + private List<Integer> chooseAnnotated(int size) { | ||
68 | + List<Integer> annotatedChosen = new ArrayList<>(); | ||
69 | + | ||
70 | + Random random = new Random(); | ||
71 | + Integer randomInteger; | ||
72 | + | ||
73 | + while (annotatedChosen.size() <= size * annotatePercent / 100.0) { | ||
74 | + randomInteger = new Integer(random.nextInt(size)); | ||
75 | + if (!annotatedChosen.contains(randomInteger)) { | ||
76 | + annotatedChosen.add(randomInteger); | ||
77 | + } | ||
78 | + } | ||
79 | + | ||
80 | + return annotatedChosen; | ||
81 | + } | ||
82 | + | ||
83 | + private void writeSampleAndAnnotated(File sampleFile, File annotatedFile, List<Integer> annotateIndexes, List<AssociationElement> associationElements) throws PreprocessingException { | ||
84 | + try { | ||
85 | + AssociationElement associationElement; | ||
86 | + if (ratings) { | ||
87 | + RatingElement ratingElement; | ||
88 | + CSVFormat ratingsFormat = CSVFormat.TDF.withHeader("itemId", "userId", "rating"); | ||
89 | + CSVPrinter samplePrinter = new CSVPrinter(new FileWriter(sampleFile), ratingsFormat); | ||
90 | + CSVPrinter annotatedPrinter = new CSVPrinter(new FileWriter(annotatedFile), ratingsFormat); | ||
91 | + | ||
92 | + for (int i = 0; i < associationElements.size(); i++) { | ||
93 | + ratingElement = (RatingElement) associationElements.get(i); | ||
94 | + Integer index = new Integer(i); | ||
95 | + if (annotateIndexes.contains(index)) { | ||
96 | + annotatedPrinter.printRecord(ratingElement.getItemId(), ratingElement.getUserId(), ratingElement.getRating()); | ||
97 | + } else { | ||
98 | + samplePrinter.printRecord(ratingElement.getItemId(), ratingElement.getUserId(), ratingElement.getRating()); | ||
99 | + } | ||
100 | + } | ||
101 | + samplePrinter.close(); | ||
102 | + annotatedPrinter.close(); | ||
103 | + | ||
104 | + } else { | ||
105 | + CSVFormat collectionsFormat = CSVFormat.TDF.withHeader("itemId", "userId"); | ||
106 | + CSVPrinter samplePrinter = new CSVPrinter(new FileWriter(sampleFile), collectionsFormat); | ||
107 | + CSVPrinter annotatedPrinter = new CSVPrinter(new FileWriter(annotatedFile), collectionsFormat); | ||
108 | + | ||
109 | + for (int i = 0; i < associationElements.size(); i++) { | ||
110 | + associationElement = associationElements.get(i); | ||
111 | + Integer index = new Integer(i); | ||
112 | + if (annotateIndexes.contains(index)) { | ||
113 | + annotatedPrinter.printRecord(associationElement.getItemId(), associationElement.getUserId()); | ||
114 | + } else { | ||
115 | + samplePrinter.printRecord(associationElement.getItemId(), associationElement.getUserId()); | ||
116 | + } | ||
117 | + } | ||
118 | + samplePrinter.close(); | ||
119 | + annotatedPrinter.close(); | ||
120 | + | ||
121 | + } | ||
122 | + } catch (IOException e) { | ||
123 | + throw new PreprocessingException("Can't write sample or annotated file " + dataDir + sampleFilename + " / " + annontatedFilename, e); | ||
124 | + } | ||
125 | + } | ||
126 | + private List<AssociationElement> loadAssociationElements(File file) throws PreprocessingException { | ||
127 | + try (Reader in = new InputStreamReader(new FileInputStream(file))) { | ||
128 | + Iterable<CSVRecord> records = CSVFormat.TDF.withFirstRecordAsHeader().parse(in); | ||
129 | + | ||
130 | + if (ratings) { | ||
131 | + return StreamSupport.stream(records.spliterator(), false) | ||
132 | + .map((record) -> new RatingElement( | ||
133 | + Long.parseLong(record.get("userId")), | ||
134 | + Long.parseLong(record.get("itemId")), | ||
135 | + Integer.parseInt(record.get("rating"))) | ||
136 | + ) | ||
137 | + .collect(Collectors.toList()); | ||
138 | + } else { | ||
139 | + return StreamSupport.stream(records.spliterator(), false) | ||
140 | + .map((record) -> new AssociationElement( | ||
141 | + Long.parseLong(record.get("userId")), | ||
142 | + Long.parseLong(record.get("itemId"))) | ||
143 | + ) | ||
144 | + .collect(Collectors.toList()); | ||
145 | + } | ||
146 | + | ||
147 | + } catch (IOException e) { | ||
148 | + throw new PreprocessingException("Can't read CSV file " + file, e); | ||
149 | + } | ||
150 | + } | ||
151 | + | ||
152 | + private void setFilenames() { | ||
153 | + if (ratings) { | ||
154 | + completeFilename = ratingCompleteFilename; | ||
155 | + sampleFilename = ratingSampleFilename; | ||
156 | + annontatedFilename = ratingAnnotatedFilename; | ||
157 | + } else { | ||
158 | + completeFilename = collectionCompleteFilename; | ||
159 | + sampleFilename = collectionSampleFilename; | ||
160 | + annontatedFilename = collectionAnnotatedFilename; | ||
161 | + } | ||
162 | + } | ||
163 | + | ||
164 | + private void loadParameters() throws PreprocessingException { | ||
165 | + try (InputStream in = new FileInputStream(new File(dataDir, parametersFilename))) { | ||
166 | + Properties properties = new Properties(); | ||
167 | + properties.load(in); | ||
168 | + | ||
169 | + if (properties.containsKey("ratings")) { | ||
170 | + logger.trace("ratings {}", properties.getProperty("ratings")); | ||
171 | + if (Boolean.parseBoolean(properties.getProperty("ratings"))) { | ||
172 | + ratings = Boolean.TRUE; | ||
173 | + } else { | ||
174 | + ratings = Boolean.FALSE; | ||
175 | + } | ||
176 | + } else { | ||
177 | + // by default, takes collection | ||
178 | + ratings = Boolean.FALSE; | ||
179 | + } | ||
180 | + | ||
181 | + if (properties.containsKey("annotatePercent")) { | ||
182 | + annotatePercent = Integer.parseInt(properties.getProperty("annotatePercent")); | ||
183 | + } else { | ||
184 | + // default top size is 10 | ||
185 | + annotatePercent = 1; | ||
186 | + } | ||
187 | + | ||
188 | + } catch (IOException e) { | ||
189 | + throw new PreprocessingException("Can't read parameters properties file " + dataDir + parametersFilename, e); | ||
190 | + } | ||
191 | + | ||
192 | + } | ||
193 | + private class PreprocessingException extends Exception { | ||
194 | + public PreprocessingException() { | ||
195 | + super(); | ||
196 | + } | ||
197 | + | ||
198 | + public PreprocessingException(String message) { | ||
199 | + super(message); | ||
200 | + } | ||
201 | + | ||
202 | + public PreprocessingException(String message, Throwable cause) { | ||
203 | + super(message, cause); | ||
204 | + } | ||
205 | + | ||
206 | + public PreprocessingException(Throwable cause) { | ||
207 | + super(cause); | ||
208 | + } | ||
209 | + | ||
210 | + protected PreprocessingException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) { | ||
211 | + super(message, cause, enableSuppression, writableStackTrace); | ||
212 | + } | ||
213 | + } | ||
214 | +} |
1 | +package org.legrog.recommendation.preprocess; | ||
2 | + | ||
3 | +public class RatingElement extends AssociationElement { | ||
4 | + private Integer rating; | ||
5 | + | ||
6 | + public RatingElement(Long userId, Long itemId, Integer rating) { | ||
7 | + super(userId, itemId); | ||
8 | + this.rating = rating; | ||
9 | + } | ||
10 | + | ||
11 | + public Integer getRating() { | ||
12 | + return rating; | ||
13 | + } | ||
14 | +} |
grog-recommendation/grog-recommendation-preprocess/src/main/resources/application.properties
0 → 100644
1 | +parameters.filename=${parameters.filename} | ||
2 | +collectionSample.filename=${collectionSample.filename} | ||
3 | +ratingSample.filename=${ratingSample.filename} | ||
4 | +recommandations.filename=${recommandations.filename} | ||
5 | +coverage.filename=${coverage.filename} | ||
6 | +data.dir=dumb/ | ||
7 | +collectionComplete.filename=${collectionComplete.filename} | ||
8 | +ratingComplete.filename=${ratingComplete.filename} | ||
9 | +collectionAnnotated.filename=${collectionAnnotated.filename} | ||
10 | +ratingAnnotated.filename=${ratingAnnotated.filename} | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
1 | +<?xml version="1.0" encoding="UTF-8"?> | ||
2 | +<configuration> | ||
3 | + | ||
4 | + <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender"> | ||
5 | + <!-- encoders are assigned the type | ||
6 | + ch.qos.logback.classic.encoder.PatternLayoutEncoder by default --> | ||
7 | + <encoder> | ||
8 | + <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern> | ||
9 | + </encoder> | ||
10 | + </appender> | ||
11 | + | ||
12 | + <logger name="org.legrog" level="DEBUG"/> | ||
13 | + <logger name="org.legrog.recommendation.preprocess" level="TRACE"/> | ||
14 | + | ||
15 | + <root level="warn"> | ||
16 | + <appender-ref ref="STDOUT" /> | ||
17 | + </root> | ||
18 | +</configuration> | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment