1   /*
2    *  Copyright (c) 1998-2001, The University of Sheffield.
3    *
4    *  This file is part of GATE (see http://gate.ac.uk/), and is free
5    *  software, licenced under the GNU Library General Public License,
6    *  Version 2, June 1991 (in the distribution as file licence.html,
7    *  and also available at http://gate.ac.uk/gate/licence.html).
8    *
9    *  Valentin Tablan 21/11/2002
10   *
11   *  $Id: Wrapper.java,v 1.3 2002/11/28 09:41:18 valyt Exp $
12   *
13   */
14  package gate.creole.ml.weka;
15  
16  import java.util.*;
17  import java.io.*;
18  import javax.swing.*;
19  import java.util.zip.*;
20  
21  import org.jdom.Element;
22  
23  import weka.core.*;
24  import weka.classifiers.*;
25  
26  import gate.creole.ml.*;
27  import gate.*;
28  import gate.creole.*;
29  import gate.util.*;
30  import gate.event.*;
31  import gate.gui.*;
32  
33  /**
34   * Wrapper class for the WEKA Machine Learning Engine.
35   * {@ see http://www.cs.waikato.ac.nz/ml/weka/}
36   */
37  
38  public class Wrapper implements MLEngine, ActionsPublisher {
39  
40    public Wrapper() {
41      actionsList = new ArrayList();
42      actionsList.add(new LoadModelAction());
43      actionsList.add(new SaveModelAction());
44      actionsList.add(new SaveDatasetAsArffAction());
45    }
46  
47    public void setOptions(Element optionsElem) {
48      this.optionsElement = optionsElem;
49    }
50  
51    public void addTrainingInstance(List attributeValues)
52                throws ExecutionException{
53      Instance instance = buildInstance(attributeValues);
54      dataset.add(instance);
55      if(classifier != null){
56        if(classifier instanceof UpdateableClassifier){
57          //the classifier can learn on the fly; we need to update it
58          try{
59            ((UpdateableClassifier)classifier).updateClassifier(instance);
60          }catch(Exception e){
61            throw new GateRuntimeException(
62              "Could not update updateable classifier! Problem was:\n" +
63              e.toString());
64          }
65        }else{
66          //the classifier is not updatebale; we need to mark the dataset as changed
67          datasetChanged = true;
68        }
69      }
70    }
71  
72    /**
73     * Constructs an instance valid for the current dataset from a list of
74     * attribute values.
75     * @param attributeValues the values for the attributes.
76     * @return an {@link weka.core.Instance} value.
77     */
78    protected Instance buildInstance(List attributeValues)
79              throws ExecutionException{
80      //sanity check
81      if(attributeValues.size() != datasetDefinition.getAttributes().size()){
82        throw new ExecutionException(
83          "The number of attributes provided is wrong for this dataset!");
84      }
85  
86      double[] values = new double[datasetDefinition.getAttributes().size()];
87      int index = 0;
88      Iterator attrIter = datasetDefinition.getAttributes().iterator();
89      Iterator valuesIter = attributeValues.iterator();
90  
91      Instance instance = new Instance(attributeValues.size());
92      instance.setDataset(dataset);
93  
94      while(attrIter.hasNext()){
95        gate.creole.ml.Attribute attr = (gate.creole.ml.Attribute)attrIter.next();
96        String value = (String)valuesIter.next();
97        if(value == null){
98          instance.setMissing(index);
99        }else{
100         if(attr.getFeature() == null){
101           //boolean attribute ->the value should already be true/false
102           instance.setValue(index, value);
103         }else{
104           //nominal or numeric attribute
105           if(attr.getValues() != null && !attr.getValues().isEmpty()){
106             //nominal attribute
107             if(attr.getValues().contains(value)){
108               instance.setValue(index, value);
109             }else{
110               Out.prln("Warning: invalid value: \"" + value +
111                        "\" for attribute " + attr.getName() + " was ignored!");
112               instance.setMissing(index);
113             }
114           }else{
115             //numeric attribute
116             try{
117               double db = Double.parseDouble(value);
118               instance.setValue(index, db);
119             }catch(Exception e){
120               Out.prln("Warning: invalid numeric value: \"" + value +
121                        "\" for attribute " + attr.getName() + " was ignored!");
122               instance.setMissing(index);
123             }
124           }
125         }
126       }
127       index ++;
128     }
129     return instance;
130   }
131 
132   public void setDatasetDefinition(DatasetDefintion definition) {
133     this.datasetDefinition = definition;
134   }
135 
136   public Object classifyInstance(List attributeValues)
137          throws ExecutionException {
138     Instance instance = buildInstance(attributeValues);
139 //    double result;
140 
141     try{
142       if(classifier instanceof UpdateableClassifier){
143         return convertAttributeValue(classifier.classifyInstance(instance));
144       }else{
145         if(datasetChanged){
146           if(sListener != null) sListener.statusChanged("[Re]building model...");
147           classifier.buildClassifier(dataset);
148           datasetChanged = false;
149           if(sListener != null) sListener.statusChanged("");
150         }
151 
152         if(confidenceThreshold > 0 &&
153            dataset.classAttribute().type() == weka.core.Attribute.NOMINAL){
154           //confidence set; use probability distribution
155 
156           double[] distribution = null;
157           distribution = ((DistributionClassifier)classifier).
158                                   distributionForInstance(instance);
159 
160           List res = new ArrayList();
161           for(int i = 0; i < distribution.length; i++){
162             if(distribution[i] >= confidenceThreshold){
163               res.add(dataset.classAttribute().value(i));
164             }
165           }
166           return res;
167 
168         }else{
169           //confidence not set; use simple classification
170           return convertAttributeValue(classifier.classifyInstance(instance));
171         }
172       }
173     }catch(Exception e){
174       throw new ExecutionException(e);
175     }
176   }
177 
178   protected Object convertAttributeValue(double value){
179     gate.creole.ml.Attribute classAttr = datasetDefinition.getClassAttribute();
180     List classValues = classAttr.getValues();
181     if(classValues != null && !classValues.isEmpty()){
182       //nominal attribute
183       return dataset.attribute(datasetDefinition.getClassIndex()).
184                      value((int)value);
185     }else{
186       if(classAttr.getFeature() == null){
187         //boolean attribute
188         return dataset.attribute(datasetDefinition.getClassIndex()).
189                        value((int)value);
190       }else{
191         //numeric attribute
192         return new Double(value);
193       }
194     }
195   }
196   /**
197    * Initialises the classifier and prepares for running.
198    * @throws GateException
199    */
200   public void init() throws GateException{
201     //see if we can shout about what we're doing
202     sListener = null;
203     Map listeners = MainFrame.getListeners();
204     if(listeners != null){
205       sListener = (StatusListener)listeners.get("gate.event.StatusListener");
206     }
207 
208     //find the classifier to be used
209     if(sListener != null) sListener.statusChanged("Initialising classifier...");
210     Element classifierElem = optionsElement.getChild("CLASSIFIER");
211     if(classifierElem == null){
212       Out.prln("Warning (WEKA ML engine): no classifier selected;" +
213                " dataset collection only!");
214       classifier = null;
215     }else{
216       String classifierClassName = classifierElem.getTextTrim();
217 
218 
219       //get the options for the classiffier
220       if(sListener != null) sListener.statusChanged("Setting classifier options...");
221       String[] options;
222       Element classifierOptionsElem = optionsElement.getChild("CLASSIFIER-OPTIONS");
223       if(classifierOptionsElem == null){
224         options = new String[]{};
225       }else{
226         List optionsList = new ArrayList();
227         StringTokenizer strTok =
228           new StringTokenizer(classifierOptionsElem.getTextTrim() , " ", false);
229         while(strTok.hasMoreTokens()){
230           optionsList.add(strTok.nextToken());
231         }
232         options = (String[])optionsList.toArray(new String[optionsList.size()]);
233       }
234 
235       try{
236         classifier = Classifier.forName(classifierClassName, options);
237       }catch(Exception e){
238         throw new GateException(e);
239       }
240       Element anElement = optionsElement.getChild("CONFIDENCE-THRESHOLD");
241       if(anElement != null){
242         try{
243           confidenceThreshold = Double.parseDouble(anElement.getTextTrim());
244         }catch(Exception e){
245           throw new GateException(
246             "Could not parse confidence threshold value: " +
247             anElement.getTextTrim() + "!");
248         }
249         if(!(classifier instanceof DistributionClassifier)){
250           throw new GateException(
251             "Cannot use confidence threshold with classifier: " +
252             classifier.getClass().getName() + "!");
253         }
254       }
255 
256     }
257 
258     //initialise the dataset
259     if(sListener != null) sListener.statusChanged("Initialising dataset...");
260     FastVector attributes = new FastVector();
261     weka.core.Attribute classAttribute;
262     Iterator attIter = datasetDefinition.getAttributes().iterator();
263     while(attIter.hasNext()){
264       gate.creole.ml.Attribute aGateAttr =
265         (gate.creole.ml.Attribute)attIter.next();
266       weka.core.Attribute aWekaAttribute = null;
267       if(aGateAttr.getValues() != null && !aGateAttr.getValues().isEmpty()){
268         //nominal attribute
269         FastVector attrValues = new FastVector(aGateAttr.getValues().size());
270         Iterator valIter = aGateAttr.getValues().iterator();
271         while(valIter.hasNext()){
272           attrValues.addElement(valIter.next());
273         }
274         aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(),
275                                                  attrValues);
276       }else{
277         if(aGateAttr.getFeature() == null){
278           //boolean attribute ([lack of] presence of an annotation)
279           FastVector attrValues = new FastVector(2);
280           attrValues.addElement("true");
281           attrValues.addElement("false");
282           aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(),
283                                                    attrValues);
284         }else{
285           //feature is not null but no values provided -> numeric attribute
286           aWekaAttribute = new weka.core.Attribute(aGateAttr.getName());
287         }
288       }
289       if(aGateAttr.isClass()) classAttribute = aWekaAttribute;
290       attributes.addElement(aWekaAttribute);
291     }
292 
293     dataset = new Instances("Weka ML Engine Dataset", attributes, 0);
294     dataset.setClassIndex(datasetDefinition.getClassIndex());
295 
296     if(classifier != null && classifier instanceof UpdateableClassifier){
297       try{
298         classifier.buildClassifier(dataset);
299       }catch(Exception e){
300         throw new ResourceInstantiationException(e);
301       }
302     }
303     if(sListener != null) sListener.statusChanged("");
304   }
305 
306 
307   /**
308    * Loads the state of this engine from previously saved data.
309    * @param is
310    */
311   protected void load(InputStream is) throws IOException{
312     if(sListener != null) sListener.statusChanged("Loading model...");
313     ObjectInputStream ois = new ObjectInputStream(is);
314     try{
315       classifier = (Classifier)ois.readObject();
316       dataset = (Instances)ois.readObject();
317       datasetDefinition = (DatasetDefintion)ois.readObject();
318       datasetChanged = ois.readBoolean();
319       confidenceThreshold = ois.readDouble();
320     }catch(ClassNotFoundException cnfe){
321       throw new GateRuntimeException(cnfe.toString());
322     }
323     ois.close();
324     if(sListener != null) sListener.statusChanged("");
325   }
326 
327   /**
328    * Saves the state of the engine for reuse at a later time.
329    * @param os
330    */
331   protected void save(OutputStream os) throws IOException{
332     if(sListener != null) sListener.statusChanged("Saving model...");
333     ObjectOutputStream oos = new ObjectOutputStream(os);
334     oos.writeObject(classifier);
335     oos.writeObject(dataset);
336     oos.writeObject(datasetDefinition);
337     oos.writeBoolean(datasetChanged);
338     oos.writeDouble(confidenceThreshold);
339     oos.flush();
340     oos.close();
341     if(sListener != null) sListener.statusChanged("");
342   }
343 
344   /**
345    * Gets the list of actions that can be performed on this resource.
346    * @return a List of Action objects (or null values)
347    */
348   public List getActions(){
349     return actionsList;
350   }
351 
352   /**
353    * Registers the PR using the engine with the engine itself.
354    * @param pr the processing resource that owns this engine.
355    */
356   public void setOwnerPR(ProcessingResource pr){
357     this.owner = pr;
358   }
359 
360 
361   protected class SaveDatasetAsArffAction extends javax.swing.AbstractAction{
362     public SaveDatasetAsArffAction(){
363       super("Save dataset as ARFF");
364       putValue(SHORT_DESCRIPTION, "Saves the ML model to a file in ARFF format");
365     }
366 
367     public void actionPerformed(java.awt.event.ActionEvent evt){
368       Runnable runnable = new Runnable(){
369         public void run(){
370           JFileChooser fileChooser = MainFrame.getFileChooser();
371           fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
372           fileChooser.setFileSelectionMode(fileChooser.FILES_ONLY);
373           fileChooser.setMultiSelectionEnabled(false);
374           if(fileChooser.showSaveDialog(null) == fileChooser.APPROVE_OPTION){
375             File file = fileChooser.getSelectedFile();
376             try{
377               MainFrame.lockGUI("Saving dataset...");
378               FileWriter fw = new FileWriter(file.getCanonicalPath(), false);
379               fw.write(dataset.toString());
380               fw.flush();
381               fw.close();
382             }catch(IOException ioe){
383               JOptionPane.showMessageDialog(null,
384                               "Error!\n"+
385                                ioe.toString(),
386                                "Gate", JOptionPane.ERROR_MESSAGE);
387               ioe.printStackTrace(Err.getPrintWriter());
388             }finally{
389               MainFrame.unlockGUI();
390             }
391           }
392         }
393       };
394 
395       Thread thread = new Thread(runnable, "DatasetSaver(ARFF)");
396       thread.setPriority(Thread.MIN_PRIORITY);
397       thread.start();
398     }
399   }
400 
401 
402   protected class SaveModelAction extends javax.swing.AbstractAction{
403     public SaveModelAction(){
404       super("Save model");
405       putValue(SHORT_DESCRIPTION, "Saves the ML model to a file");
406     }
407 
408     public void actionPerformed(java.awt.event.ActionEvent evt){
409       Runnable runnable = new Runnable(){
410         public void run(){
411           JFileChooser fileChooser = MainFrame.getFileChooser();
412           fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
413           fileChooser.setFileSelectionMode(fileChooser.FILES_ONLY);
414           fileChooser.setMultiSelectionEnabled(false);
415           if(fileChooser.showSaveDialog(null) == fileChooser.APPROVE_OPTION){
416             File file = fileChooser.getSelectedFile();
417             try{
418               MainFrame.lockGUI("Saving ML model...");
419               save(new GZIPOutputStream(
420                    new FileOutputStream(file.getCanonicalPath(), false)));
421             }catch(IOException ioe){
422               JOptionPane.showMessageDialog(null,
423                               "Error!\n"+
424                                ioe.toString(),
425                                "Gate", JOptionPane.ERROR_MESSAGE);
426               ioe.printStackTrace(Err.getPrintWriter());
427             }finally{
428               MainFrame.unlockGUI();
429             }
430           }
431         }
432       };
433       Thread thread = new Thread(runnable, "ModelSaver(serialisation)");
434       thread.setPriority(Thread.MIN_PRIORITY);
435       thread.start();
436     }
437   }
438 
439   protected class LoadModelAction extends javax.swing.AbstractAction{
440     public LoadModelAction(){
441       super("Load model");
442       putValue(SHORT_DESCRIPTION, "Loads a ML model from a file");
443     }
444 
445     public void actionPerformed(java.awt.event.ActionEvent evt){
446       Runnable runnable = new Runnable(){
447         public void run(){
448           JFileChooser fileChooser = MainFrame.getFileChooser();
449           fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
450           fileChooser.setFileSelectionMode(fileChooser.FILES_ONLY);
451           fileChooser.setMultiSelectionEnabled(false);
452           if(fileChooser.showOpenDialog(null) == fileChooser.APPROVE_OPTION){
453             File file = fileChooser.getSelectedFile();
454             try{
455               MainFrame.lockGUI("Loading model...");
456               load(new GZIPInputStream(new FileInputStream(file)));
457             }catch(IOException ioe){
458               JOptionPane.showMessageDialog(null,
459                               "Error!\n"+
460                                ioe.toString(),
461                                "Gate", JOptionPane.ERROR_MESSAGE);
462               ioe.printStackTrace(Err.getPrintWriter());
463             }finally{
464               MainFrame.unlockGUI();
465             }
466           }
467         }
468       };
469       Thread thread = new Thread(runnable, "ModelLoader(serialisation)");
470       thread.setPriority(Thread.MIN_PRIORITY);
471       thread.start();
472     }
473   }
474 
475 
476 
477   protected DatasetDefintion datasetDefinition;
478 
479   double confidenceThreshold = 0;
480 
481   /**
482    * The WEKA classifier used by this wrapper
483    */
484   protected Classifier classifier;
485 
486   /**
487    * The dataset used for training
488    */
489   protected Instances dataset;
490 
491   /**
492    * The JDom element contaning the options fro this wrapper.
493    */
494   protected Element optionsElement;
495 
496   /**
497    * Marks whether the dataset was changed since the last time the classifier
498    * was built.
499    */
500   protected boolean datasetChanged = false;
501 
502   protected List actionsList;
503 
504   protected ProcessingResource owner;
505 
506   protected StatusListener sListener;
507 }