CorpusWord2Vec.java
- /*******************************************************************************
- * Copyright (C) 2020 Ram Sadasiv
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- ******************************************************************************/
- package io.outofprintmagazine.corpus.batch.impl;
- import java.io.File;
- import java.io.FileInputStream;
- import java.util.ArrayList;
- import java.util.Collection;
- import java.util.HashMap;
- import java.util.Iterator;
- import java.util.List;
- import java.util.Map;
- import org.apache.logging.log4j.LogManager;
- import org.apache.logging.log4j.Logger;
- import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
- import org.deeplearning4j.models.word2vec.Word2Vec;
- import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
- import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
- import com.fasterxml.jackson.databind.JsonNode;
- import com.fasterxml.jackson.databind.node.ArrayNode;
- import com.fasterxml.jackson.databind.node.ObjectNode;
- import io.outofprintmagazine.corpus.batch.CorpusBatchStep;
- import io.outofprintmagazine.corpus.batch.ICorpusBatchStep;
- public class CorpusWord2Vec extends CorpusBatchStep implements ICorpusBatchStep {
-
- private static final Logger logger = LogManager.getLogger(CorpusWord2Vec.class);
- @SuppressWarnings("unused")
- private Logger getLogger() {
- return logger;
- }
-
- private Map<String, Collection<String>> sentenceCollections = new HashMap<String, Collection<String>>();
-
- protected Map<String, Collection<String>> getSentenceCollections() {
- return sentenceCollections;
- }
-
- protected Collection<String> getSentenceCollection(String name) {
- return getSentenceCollections().get(name);
- }
-
- public CorpusWord2Vec() {
- super();
- getSentenceCollections().put("Tokens", new ArrayList<String>());
- getSentenceCollections().put("Lemmas", new ArrayList<String>());
- getSentenceCollections().put("Lemmas_POS", new ArrayList<String>());
- }
-
- @Override
- public ObjectNode getDefaultProperties() {
- ObjectNode properties = getMapper().createObjectNode();
- properties.put("minWordFrequency", 2);
- properties.put("iterations", 100);
- properties.put("layerSize", 50);
- properties.put("windowSize", 5);
- return properties;
- }
-
- @Override
- public ArrayNode run(ArrayNode input) {
- ArrayNode retval = super.run(input);
- try {
- storeModel("Tokens");
- storeModel("Lemmas");
- storeModel("Lemmas_POS");
- }
- catch (Exception e) {
- e.printStackTrace();
- getLogger().error(e);
- }
- return retval;
- }
-
-
- @Override
- public ArrayNode runOne(ObjectNode inputStepItem) throws Exception {
- ArrayNode retval = getMapper().createArrayNode();
- ObjectNode outputStepItem = copyInputToOutput(inputStepItem);
- ObjectNode oop = (ObjectNode) getJsonNodeFromStorage(inputStepItem, "oopNLPStorage");
- getSentenceCollection("Tokens").addAll(docToTokenStrings(oop, new NodeToToken()));
- getSentenceCollection("Lemmas").addAll(docToTokenStrings(oop, new NodeToLemma()));
- getSentenceCollection("Lemmas_POS").addAll(docToTokenStrings(oop, new NodeToLemma_POS()));
- outputStepItem.put(
- "oopNLPCorpusAggregatesWord2Vec_TokensStorage",
- getStorage().getScratchFilePath(
- getData().getCorpusBatchId(),
- getData().getCorpusBatchStepId(),
- "CORPUS_AGGREGATES_WORD2VEC_Tokens.word2vec"
- )
- );
- outputStepItem.put(
- "oopNLPCorpusAggregatesWord2Vec_LemmasStorage",
- getStorage().getScratchFilePath(
- getData().getCorpusBatchId(),
- getData().getCorpusBatchStepId(),
- "CORPUS_AGGREGATES_WORD2VEC_Lemmas.word2vec"
- )
- );
- outputStepItem.put(
- "oopNLPCorpusAggregatesWord2Vec_Lemmas_POSStorage",
- getStorage().getScratchFilePath(
- getData().getCorpusBatchId(),
- getData().getCorpusBatchStepId(),
- "CORPUS_AGGREGATES_WORD2VEC_Lemmas_POS.word2vec"
- )
- );
- retval.add(outputStepItem);
- return retval;
- }
-
- protected void storeModel(String prefix) throws Exception {
- Word2Vec vec = new Word2Vec.Builder()
- .minWordFrequency(getData().getProperties().get("minWordFrequency").asInt())
- .iterations(getData().getProperties().get("iterations").asInt())
- .layerSize(getData().getProperties().get("layerSize").asInt())
- .seed(42)
- .windowSize(getData().getProperties().get("windowSize").asInt())
- .iterate(
- new CollectionSentenceIterator(
- getSentenceCollection(prefix)
- )
- )
- .tokenizerFactory(new DefaultTokenizerFactory())
- .build();
- vec.fit();
- File f = File.createTempFile(prefix, "word2vec");
- WordVectorSerializer.writeWord2VecModel(vec, f);
- FileInputStream fin = null;
- try {
- fin = new FileInputStream(f);
- getStorage().storeScratchFileStream(
- getData().getCorpusId(),
- getOutputScratchFilePath("CORPUS_AGGREGATES_WORD2VEC_" + prefix, "word2vec"),
- fin
- );
- }
- finally {
- if (fin != null) {
- fin.close();
- }
- }
- }
- abstract class NodeToString {
- abstract String nodeToString(ObjectNode node);
- }
-
- class NodeToToken extends NodeToString {
- public String nodeToString(ObjectNode node) {
- return node.get("TokensAnnotation").get("word").asText();
- }
- }
-
- class NodeToLemma extends NodeToString {
- public String nodeToString(ObjectNode node) {
- return node.get("TokensAnnotation").get("lemma").asText();
- }
- }
-
- class NodeToLemma_POS extends NodeToString {
- public String nodeToString(ObjectNode node) {
- return node.get("TokensAnnotation").get("lemma").asText() + "_" + node.get("TokensAnnotation").get("pos").asText();
- }
- }
-
- protected List<String> docToTokenStrings(ObjectNode doc, NodeToString serializer) {
- ArrayNode sentences = (ArrayNode) doc.get("sentences");
- List<String> cleanedSentences = new ArrayList<String>();
- Iterator<JsonNode> sentencesIter = sentences.iterator();
- while (sentencesIter.hasNext()) {
- JsonNode sentenceNode = sentencesIter.next();
- StringBuffer buf = new StringBuffer();
- ArrayNode tokensNode = (ArrayNode) sentenceNode.get("tokens");
- Iterator<JsonNode> tokensIter = tokensNode.iterator();
- while (tokensIter.hasNext()) {
- ObjectNode tokenNode = (ObjectNode) tokensIter.next();
- //keep all the verbs
- if (
- tokenNode.has("OOPActionlessVerbsAnnotation")
- || tokenNode.has("OOPVerbsAnnotation")
- ) {
- buf.append(serializer.nodeToString(tokenNode));
- buf.append(" ");
- }
- else if (
- !tokenNode.get("TokensAnnotation").get("pos").asText().equals("POS")
- && !tokenNode.get("TokensAnnotation").get("pos").asText().equals("NFP")
- && !tokenNode.has("OOPPunctuationMarkAnnotation")
- //&& !stopWords.contains(tokenNode.get("TokensAnnotation").get("lemma").asText().toLowerCase())
- //&& !tokenNode.has("OOPCommonWordsAnnotation")
- && !tokenNode.has("OOPFunctionWordsAnnotation")
- ) {
- buf.append(serializer.nodeToString(tokenNode));
- buf.append(" ");
- }
- }
-
- cleanedSentences.add(buf.toString());
- }
- return cleanedSentences;
- }
- }