|
Wrapper |
|
1 /* 2 * Copyright (c) 1998-2001, The University of Sheffield. 3 * 4 * This file is part of GATE (see http://gate.ac.uk/), and is free 5 * software, licenced under the GNU Library General Public License, 6 * Version 2, June 1991 (in the distribution as file licence.html, 7 * and also available at http://gate.ac.uk/gate/licence.html). 8 * 9 * Valentin Tablan 21/11/2002 10 * 11 * $Id: Wrapper.java,v 1.3 2002/11/28 09:41:18 valyt Exp $ 12 * 13 */ 14 package gate.creole.ml.weka; 15 16 import java.util.*; 17 import java.io.*; 18 import javax.swing.*; 19 import java.util.zip.*; 20 21 import org.jdom.Element; 22 23 import weka.core.*; 24 import weka.classifiers.*; 25 26 import gate.creole.ml.*; 27 import gate.*; 28 import gate.creole.*; 29 import gate.util.*; 30 import gate.event.*; 31 import gate.gui.*; 32 33 /** 34 * Wrapper class for the WEKA Machine Learning Engine. 35 * {@ see http://www.cs.waikato.ac.nz/ml/weka/} 36 */ 37 38 public class Wrapper implements MLEngine, ActionsPublisher { 39 40 public Wrapper() { 41 actionsList = new ArrayList(); 42 actionsList.add(new LoadModelAction()); 43 actionsList.add(new SaveModelAction()); 44 actionsList.add(new SaveDatasetAsArffAction()); 45 } 46 47 public void setOptions(Element optionsElem) { 48 this.optionsElement = optionsElem; 49 } 50 51 public void addTrainingInstance(List attributeValues) 52 throws ExecutionException{ 53 Instance instance = buildInstance(attributeValues); 54 dataset.add(instance); 55 if(classifier != null){ 56 if(classifier instanceof UpdateableClassifier){ 57 //the classifier can learn on the fly; we need to update it 58 try{ 59 ((UpdateableClassifier)classifier).updateClassifier(instance); 60 }catch(Exception e){ 61 throw new GateRuntimeException( 62 "Could not update updateable classifier! Problem was:\n" + 63 e.toString()); 64 } 65 }else{ 66 //the classifier is not updatebale; we need to mark the dataset as changed 67 datasetChanged = true; 68 } 69 } 70 } 71 72 /** 73 * Constructs an instance valid for the current dataset from a list of 74 * attribute values. 75 * @param attributeValues the values for the attributes. 76 * @return an {@link weka.core.Instance} value. 77 */ 78 protected Instance buildInstance(List attributeValues) 79 throws ExecutionException{ 80 //sanity check 81 if(attributeValues.size() != datasetDefinition.getAttributes().size()){ 82 throw new ExecutionException( 83 "The number of attributes provided is wrong for this dataset!"); 84 } 85 86 double[] values = new double[datasetDefinition.getAttributes().size()]; 87 int index = 0; 88 Iterator attrIter = datasetDefinition.getAttributes().iterator(); 89 Iterator valuesIter = attributeValues.iterator(); 90 91 Instance instance = new Instance(attributeValues.size()); 92 instance.setDataset(dataset); 93 94 while(attrIter.hasNext()){ 95 gate.creole.ml.Attribute attr = (gate.creole.ml.Attribute)attrIter.next(); 96 String value = (String)valuesIter.next(); 97 if(value == null){ 98 instance.setMissing(index); 99 }else{ 100 if(attr.getFeature() == null){ 101 //boolean attribute ->the value should already be true/false 102 instance.setValue(index, value); 103 }else{ 104 //nominal or numeric attribute 105 if(attr.getValues() != null && !attr.getValues().isEmpty()){ 106 //nominal attribute 107 if(attr.getValues().contains(value)){ 108 instance.setValue(index, value); 109 }else{ 110 Out.prln("Warning: invalid value: \"" + value + 111 "\" for attribute " + attr.getName() + " was ignored!"); 112 instance.setMissing(index); 113 } 114 }else{ 115 //numeric attribute 116 try{ 117 double db = Double.parseDouble(value); 118 instance.setValue(index, db); 119 }catch(Exception e){ 120 Out.prln("Warning: invalid numeric value: \"" + value + 121 "\" for attribute " + attr.getName() + " was ignored!"); 122 instance.setMissing(index); 123 } 124 } 125 } 126 } 127 index ++; 128 } 129 return instance; 130 } 131 132 public void setDatasetDefinition(DatasetDefintion definition) { 133 this.datasetDefinition = definition; 134 } 135 136 public Object classifyInstance(List attributeValues) 137 throws ExecutionException { 138 Instance instance = buildInstance(attributeValues); 139 // double result; 140 141 try{ 142 if(classifier instanceof UpdateableClassifier){ 143 return convertAttributeValue(classifier.classifyInstance(instance)); 144 }else{ 145 if(datasetChanged){ 146 if(sListener != null) sListener.statusChanged("[Re]building model..."); 147 classifier.buildClassifier(dataset); 148 datasetChanged = false; 149 if(sListener != null) sListener.statusChanged(""); 150 } 151 152 if(confidenceThreshold > 0 && 153 dataset.classAttribute().type() == weka.core.Attribute.NOMINAL){ 154 //confidence set; use probability distribution 155 156 double[] distribution = null; 157 distribution = ((DistributionClassifier)classifier). 158 distributionForInstance(instance); 159 160 List res = new ArrayList(); 161 for(int i = 0; i < distribution.length; i++){ 162 if(distribution[i] >= confidenceThreshold){ 163 res.add(dataset.classAttribute().value(i)); 164 } 165 } 166 return res; 167 168 }else{ 169 //confidence not set; use simple classification 170 return convertAttributeValue(classifier.classifyInstance(instance)); 171 } 172 } 173 }catch(Exception e){ 174 throw new ExecutionException(e); 175 } 176 } 177 178 protected Object convertAttributeValue(double value){ 179 gate.creole.ml.Attribute classAttr = datasetDefinition.getClassAttribute(); 180 List classValues = classAttr.getValues(); 181 if(classValues != null && !classValues.isEmpty()){ 182 //nominal attribute 183 return dataset.attribute(datasetDefinition.getClassIndex()). 184 value((int)value); 185 }else{ 186 if(classAttr.getFeature() == null){ 187 //boolean attribute 188 return dataset.attribute(datasetDefinition.getClassIndex()). 189 value((int)value); 190 }else{ 191 //numeric attribute 192 return new Double(value); 193 } 194 } 195 } 196 /** 197 * Initialises the classifier and prepares for running. 198 * @throws GateException 199 */ 200 public void init() throws GateException{ 201 //see if we can shout about what we're doing 202 sListener = null; 203 Map listeners = MainFrame.getListeners(); 204 if(listeners != null){ 205 sListener = (StatusListener)listeners.get("gate.event.StatusListener"); 206 } 207 208 //find the classifier to be used 209 if(sListener != null) sListener.statusChanged("Initialising classifier..."); 210 Element classifierElem = optionsElement.getChild("CLASSIFIER"); 211 if(classifierElem == null){ 212 Out.prln("Warning (WEKA ML engine): no classifier selected;" + 213 " dataset collection only!"); 214 classifier = null; 215 }else{ 216 String classifierClassName = classifierElem.getTextTrim(); 217 218 219 //get the options for the classiffier 220 if(sListener != null) sListener.statusChanged("Setting classifier options..."); 221 String[] options; 222 Element classifierOptionsElem = optionsElement.getChild("CLASSIFIER-OPTIONS"); 223 if(classifierOptionsElem == null){ 224 options = new String[]{}; 225 }else{ 226 List optionsList = new ArrayList(); 227 StringTokenizer strTok = 228 new StringTokenizer(classifierOptionsElem.getTextTrim() , " ", false); 229 while(strTok.hasMoreTokens()){ 230 optionsList.add(strTok.nextToken()); 231 } 232 options = (String[])optionsList.toArray(new String[optionsList.size()]); 233 } 234 235 try{ 236 classifier = Classifier.forName(classifierClassName, options); 237 }catch(Exception e){ 238 throw new GateException(e); 239 } 240 Element anElement = optionsElement.getChild("CONFIDENCE-THRESHOLD"); 241 if(anElement != null){ 242 try{ 243 confidenceThreshold = Double.parseDouble(anElement.getTextTrim()); 244 }catch(Exception e){ 245 throw new GateException( 246 "Could not parse confidence threshold value: " + 247 anElement.getTextTrim() + "!"); 248 } 249 if(!(classifier instanceof DistributionClassifier)){ 250 throw new GateException( 251 "Cannot use confidence threshold with classifier: " + 252 classifier.getClass().getName() + "!"); 253 } 254 } 255 256 } 257 258 //initialise the dataset 259 if(sListener != null) sListener.statusChanged("Initialising dataset..."); 260 FastVector attributes = new FastVector(); 261 weka.core.Attribute classAttribute; 262 Iterator attIter = datasetDefinition.getAttributes().iterator(); 263 while(attIter.hasNext()){ 264 gate.creole.ml.Attribute aGateAttr = 265 (gate.creole.ml.Attribute)attIter.next(); 266 weka.core.Attribute aWekaAttribute = null; 267 if(aGateAttr.getValues() != null && !aGateAttr.getValues().isEmpty()){ 268 //nominal attribute 269 FastVector attrValues = new FastVector(aGateAttr.getValues().size()); 270 Iterator valIter = aGateAttr.getValues().iterator(); 271 while(valIter.hasNext()){ 272 attrValues.addElement(valIter.next()); 273 } 274 aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(), 275 attrValues); 276 }else{ 277 if(aGateAttr.getFeature() == null){ 278 //boolean attribute ([lack of] presence of an annotation) 279 FastVector attrValues = new FastVector(2); 280 attrValues.addElement("true"); 281 attrValues.addElement("false"); 282 aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(), 283 attrValues); 284 }else{ 285 //feature is not null but no values provided -> numeric attribute 286 aWekaAttribute = new weka.core.Attribute(aGateAttr.getName()); 287 } 288 } 289 if(aGateAttr.isClass()) classAttribute = aWekaAttribute; 290 attributes.addElement(aWekaAttribute); 291 } 292 293 dataset = new Instances("Weka ML Engine Dataset", attributes, 0); 294 dataset.setClassIndex(datasetDefinition.getClassIndex()); 295 296 if(classifier != null && classifier instanceof UpdateableClassifier){ 297 try{ 298 classifier.buildClassifier(dataset); 299 }catch(Exception e){ 300 throw new ResourceInstantiationException(e); 301 } 302 } 303 if(sListener != null) sListener.statusChanged(""); 304 } 305 306 307 /** 308 * Loads the state of this engine from previously saved data. 309 * @param is 310 */ 311 protected void load(InputStream is) throws IOException{ 312 if(sListener != null) sListener.statusChanged("Loading model..."); 313 ObjectInputStream ois = new ObjectInputStream(is); 314 try{ 315 classifier = (Classifier)ois.readObject(); 316 dataset = (Instances)ois.readObject(); 317 datasetDefinition = (DatasetDefintion)ois.readObject(); 318 datasetChanged = ois.readBoolean(); 319 confidenceThreshold = ois.readDouble(); 320 }catch(ClassNotFoundException cnfe){ 321 throw new GateRuntimeException(cnfe.toString()); 322 } 323 ois.close(); 324 if(sListener != null) sListener.statusChanged(""); 325 } 326 327 /** 328 * Saves the state of the engine for reuse at a later time. 329 * @param os 330 */ 331 protected void save(OutputStream os) throws IOException{ 332 if(sListener != null) sListener.statusChanged("Saving model..."); 333 ObjectOutputStream oos = new ObjectOutputStream(os); 334 oos.writeObject(classifier); 335 oos.writeObject(dataset); 336 oos.writeObject(datasetDefinition); 337 oos.writeBoolean(datasetChanged); 338 oos.writeDouble(confidenceThreshold); 339 oos.flush(); 340 oos.close(); 341 if(sListener != null) sListener.statusChanged(""); 342 } 343 344 /** 345 * Gets the list of actions that can be performed on this resource. 346 * @return a List of Action objects (or null values) 347 */ 348 public List getActions(){ 349 return actionsList; 350 } 351 352 /** 353 * Registers the PR using the engine with the engine itself. 354 * @param pr the processing resource that owns this engine. 355 */ 356 public void setOwnerPR(ProcessingResource pr){ 357 this.owner = pr; 358 } 359 360 361 protected class SaveDatasetAsArffAction extends javax.swing.AbstractAction{ 362 public SaveDatasetAsArffAction(){ 363 super("Save dataset as ARFF"); 364 putValue(SHORT_DESCRIPTION, "Saves the ML model to a file in ARFF format"); 365 } 366 367 public void actionPerformed(java.awt.event.ActionEvent evt){ 368 Runnable runnable = new Runnable(){ 369 public void run(){ 370 JFileChooser fileChooser = MainFrame.getFileChooser(); 371 fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter()); 372 fileChooser.setFileSelectionMode(fileChooser.FILES_ONLY); 373 fileChooser.setMultiSelectionEnabled(false); 374 if(fileChooser.showSaveDialog(null) == fileChooser.APPROVE_OPTION){ 375 File file = fileChooser.getSelectedFile(); 376 try{ 377 MainFrame.lockGUI("Saving dataset..."); 378 FileWriter fw = new FileWriter(file.getCanonicalPath(), false); 379 fw.write(dataset.toString()); 380 fw.flush(); 381 fw.close(); 382 }catch(IOException ioe){ 383 JOptionPane.showMessageDialog(null, 384 "Error!\n"+ 385 ioe.toString(), 386 "Gate", JOptionPane.ERROR_MESSAGE); 387 ioe.printStackTrace(Err.getPrintWriter()); 388 }finally{ 389 MainFrame.unlockGUI(); 390 } 391 } 392 } 393 }; 394 395 Thread thread = new Thread(runnable, "DatasetSaver(ARFF)"); 396 thread.setPriority(Thread.MIN_PRIORITY); 397 thread.start(); 398 } 399 } 400 401 402 protected class SaveModelAction extends javax.swing.AbstractAction{ 403 public SaveModelAction(){ 404 super("Save model"); 405 putValue(SHORT_DESCRIPTION, "Saves the ML model to a file"); 406 } 407 408 public void actionPerformed(java.awt.event.ActionEvent evt){ 409 Runnable runnable = new Runnable(){ 410 public void run(){ 411 JFileChooser fileChooser = MainFrame.getFileChooser(); 412 fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter()); 413 fileChooser.setFileSelectionMode(fileChooser.FILES_ONLY); 414 fileChooser.setMultiSelectionEnabled(false); 415 if(fileChooser.showSaveDialog(null) == fileChooser.APPROVE_OPTION){ 416 File file = fileChooser.getSelectedFile(); 417 try{ 418 MainFrame.lockGUI("Saving ML model..."); 419 save(new GZIPOutputStream( 420 new FileOutputStream(file.getCanonicalPath(), false))); 421 }catch(IOException ioe){ 422 JOptionPane.showMessageDialog(null, 423 "Error!\n"+ 424 ioe.toString(), 425 "Gate", JOptionPane.ERROR_MESSAGE); 426 ioe.printStackTrace(Err.getPrintWriter()); 427 }finally{ 428 MainFrame.unlockGUI(); 429 } 430 } 431 } 432 }; 433 Thread thread = new Thread(runnable, "ModelSaver(serialisation)"); 434 thread.setPriority(Thread.MIN_PRIORITY); 435 thread.start(); 436 } 437 } 438 439 protected class LoadModelAction extends javax.swing.AbstractAction{ 440 public LoadModelAction(){ 441 super("Load model"); 442 putValue(SHORT_DESCRIPTION, "Loads a ML model from a file"); 443 } 444 445 public void actionPerformed(java.awt.event.ActionEvent evt){ 446 Runnable runnable = new Runnable(){ 447 public void run(){ 448 JFileChooser fileChooser = MainFrame.getFileChooser(); 449 fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter()); 450 fileChooser.setFileSelectionMode(fileChooser.FILES_ONLY); 451 fileChooser.setMultiSelectionEnabled(false); 452 if(fileChooser.showOpenDialog(null) == fileChooser.APPROVE_OPTION){ 453 File file = fileChooser.getSelectedFile(); 454 try{ 455 MainFrame.lockGUI("Loading model..."); 456 load(new GZIPInputStream(new FileInputStream(file))); 457 }catch(IOException ioe){ 458 JOptionPane.showMessageDialog(null, 459 "Error!\n"+ 460 ioe.toString(), 461 "Gate", JOptionPane.ERROR_MESSAGE); 462 ioe.printStackTrace(Err.getPrintWriter()); 463 }finally{ 464 MainFrame.unlockGUI(); 465 } 466 } 467 } 468 }; 469 Thread thread = new Thread(runnable, "ModelLoader(serialisation)"); 470 thread.setPriority(Thread.MIN_PRIORITY); 471 thread.start(); 472 } 473 } 474 475 476 477 protected DatasetDefintion datasetDefinition; 478 479 double confidenceThreshold = 0; 480 481 /** 482 * The WEKA classifier used by this wrapper 483 */ 484 protected Classifier classifier; 485 486 /** 487 * The dataset used for training 488 */ 489 protected Instances dataset; 490 491 /** 492 * The JDom element contaning the options fro this wrapper. 493 */ 494 protected Element optionsElement; 495 496 /** 497 * Marks whether the dataset was changed since the last time the classifier 498 * was built. 499 */ 500 protected boolean datasetChanged = false; 501 502 protected List actionsList; 503 504 protected ProcessingResource owner; 505 506 protected StatusListener sListener; 507 }
|
Wrapper |
|