001/*- 002 ******************************************************************************* 003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd. 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 * 009 * Contributors: 010 * Peter Chang - initial API and implementation and/or initial documentation 011 *******************************************************************************/ 012 013package org.eclipse.january.dataset; 014 015import java.io.Serializable; 016import java.lang.reflect.Array; 017import java.util.ArrayList; 018import java.util.Arrays; 019import java.util.List; 020 021import org.apache.commons.math3.util.MathArrays; 022import org.eclipse.january.DatasetException; 023import org.slf4j.Logger; 024import org.slf4j.LoggerFactory; 025 026/** 027 * Utilities for manipulating datasets 028 */ 029@SuppressWarnings("unchecked") 030public class DatasetUtils { 031 032 /** 033 * Setup the logging facilities 034 */ 035 transient protected static final Logger utilsLogger = LoggerFactory.getLogger(DatasetUtils.class); 036 037 /** 038 * Append copy of dataset with another dataset along n-th axis 039 * 040 * @param a 041 * @param b 042 * @param axis 043 * number of axis (negative number counts from last) 044 * @return appended dataset 045 */ 046 public static Dataset append(IDataset a, IDataset b, int axis) { 047 final int[] ashape = a.getShape(); 048 final int rank = ashape.length; 049 final int[] bshape = b.getShape(); 050 if (rank != bshape.length) { 051 throw new IllegalArgumentException("Incompatible number of dimensions"); 052 } 053 if (axis >= rank) { 054 throw new IllegalArgumentException("Axis specified exceeds array dimensions"); 055 } else if (axis > -rank) { 056 if (axis < 0) 057 axis += rank; 058 } else { 059 throw new 
IllegalArgumentException("Axis specified is less than " + (-rank)); 060 } 061 062 for (int i = 0; i < rank; i++) { 063 if (i != axis && ashape[i] != bshape[i]) { 064 throw new IllegalArgumentException("Incompatible dimensions"); 065 } 066 } 067 final int[] nshape = new int[rank]; 068 for (int i = 0; i < rank; i++) { 069 nshape[i] = ashape[i]; 070 } 071 nshape[axis] += bshape[axis]; 072 final int ot = DTypeUtils.getDType(b); 073 final int dt = DTypeUtils.getDType(a); 074 @SuppressWarnings("deprecation") 075 Dataset ds = DatasetFactory.zeros(a.getElementsPerItem(), nshape, dt > ot ? dt : ot); 076 IndexIterator iter = ds.getIterator(true); 077 int[] pos = iter.getPos(); 078 while (iter.hasNext()) { 079 int d = ashape[axis]; 080 if (pos[axis] < d) { 081 ds.setObjectAbs(iter.index, a.getObject(pos)); 082 } else { 083 pos[axis] -= d; 084 ds.setObjectAbs(iter.index, b.getObject(pos)); 085 pos[axis] += d; 086 } 087 } 088 089 return ds; 090 } 091 092 /** 093 * Changes specific items of dataset by replacing them with other array 094 * @param a 095 * @param indices dataset interpreted as integers 096 * @param values 097 * @return changed dataset 098 */ 099 public static <T extends Dataset> T put(final T a, final Dataset indices, Object values) { 100 IndexIterator it = indices.getIterator(); 101 Dataset vd = DatasetFactory.createFromObject(values).flatten(); 102 int vlen = vd.getSize(); 103 int v = 0; 104 while (it.hasNext()) { 105 if (v >= vlen) v -= vlen; 106 107 a.setObjectAbs((int) indices.getElementLongAbs(it.index), vd.getObjectAbs(v++)); 108 } 109 return a; 110 } 111 112 /** 113 * Changes specific items of dataset by replacing them with other array 114 * @param a 115 * @param indices 116 * @param values 117 * @return changed dataset 118 */ 119 public static <T extends Dataset> T put(final T a, final int[] indices, Object values) { 120 int ilen = indices.length; 121 Dataset vd = DatasetFactory.createFromObject(values).flatten(); 122 int vlen = vd.getSize(); 123 for (int 
i = 0, v= 0; i < ilen; i++) { 124 if (v >= vlen) v -= vlen; 125 126 a.setObjectAbs(indices[i], vd.getObjectAbs(v++)); 127 } 128 return a; 129 } 130 131 /** 132 * Take items from dataset along an axis 133 * @param indices dataset interpreted as integers 134 * @param axis if null, then use flattened view 135 * @return a sub-array 136 */ 137 public static <T extends Dataset> T take(final T a, final Dataset indices, Integer axis) { 138 IntegerDataset indexes = (IntegerDataset) indices.flatten().cast(Dataset.INT32); 139 return take(a, indexes.getData(), axis); 140 } 141 142 /** 143 * Take items from dataset along an axis 144 * @param indices 145 * @param axis if null, then use flattened view 146 * @return a sub-array 147 */ 148 @SuppressWarnings("deprecation") 149 public static <T extends Dataset> T take(final T a, final int[] indices, Integer axis) { 150 if (indices == null || indices.length == 0) { 151 utilsLogger.error("No indices given"); 152 throw new IllegalArgumentException("No indices given"); 153 } 154 int[] ashape = a.getShape(); 155 final int rank = ashape.length; 156 final int at = a.getDType(); 157 final int ilen = indices.length; 158 final int is = a.getElementsPerItem(); 159 160 Dataset result; 161 if (axis == null) { 162 ashape = new int[1]; 163 ashape[0] = ilen; 164 result = DatasetFactory.zeros(is, ashape, at); 165 Serializable src = a.getBuffer(); 166 for (int i = 0; i < ilen; i++) { 167 ((AbstractDataset) result).setItemDirect(i, indices[i], src); 168 } 169 } else { 170 axis = a.checkAxis(axis); 171 ashape[axis] = ilen; 172 result = DatasetFactory.zeros(is, ashape, at); 173 174 int[] dpos = new int[rank]; 175 int[] spos = new int[rank]; 176 boolean[] axes = new boolean[rank]; 177 Arrays.fill(axes, true); 178 axes[axis] = false; 179 Serializable src = a.getBuffer(); 180 for (int i = 0; i < ilen; i++) { 181 spos[axis] = indices[i]; 182 dpos[axis] = i; 183 SliceIterator siter = a.getSliceIteratorFromAxes(spos, axes); 184 SliceIterator diter = 
result.getSliceIteratorFromAxes(dpos, axes); 185 186 while (siter.hasNext() && diter.hasNext()) { 187 ((AbstractDataset) result).setItemDirect(diter.index, siter.index, src); 188 } 189 } 190 } 191 result.setDirty(); 192 return (T) result; 193 } 194 195 /** 196 * Construct a dataset that contains the original dataset repeated the number 197 * of times in each axis given by corresponding entries in the reps array 198 * 199 * @param a 200 * @param reps 201 * @return tiled dataset 202 */ 203 public static Dataset tile(final IDataset a, int... reps) { 204 int[] shape = a.getShape(); 205 int rank = shape.length; 206 final int rlen = reps.length; 207 208 // expand shape 209 if (rank < rlen) { 210 int[] newShape = new int[rlen]; 211 int extraRank = rlen - rank; 212 for (int i = 0; i < extraRank; i++) { 213 newShape[i] = 1; 214 } 215 for (int i = 0; i < rank; i++) { 216 newShape[i+extraRank] = shape[i]; 217 } 218 219 shape = newShape; 220 rank = rlen; 221 } else if (rank > rlen) { 222 int[] newReps = new int[rank]; 223 int extraRank = rank - rlen; 224 for (int i = 0; i < extraRank; i++) { 225 newReps[i] = 1; 226 } 227 for (int i = 0; i < rlen; i++) { 228 newReps[i+extraRank] = reps[i]; 229 } 230 reps = newReps; 231 } 232 233 // calculate new shape 234 int[] newShape = new int[rank]; 235 for (int i = 0; i < rank; i++) { 236 newShape[i] = shape[i]*reps[i]; 237 } 238 239 @SuppressWarnings("deprecation") 240 Dataset tdata = DatasetFactory.zeros(a.getElementsPerItem(), newShape, DTypeUtils.getDType(a)); 241 242 // decide which way to put slices 243 boolean manyColumns; 244 if (rank == 1) 245 manyColumns = true; 246 else 247 manyColumns = shape[rank-1] > 64; 248 249 if (manyColumns) { 250 // generate each start point and put a slice in 251 IndexIterator iter = tdata.getSliceIterator(null, null, shape); 252 SliceIterator siter = (SliceIterator) tdata.getSliceIterator(null, shape, null); 253 final int[] pos = iter.getPos(); 254 while (iter.hasNext()) { 255 siter.setStart(pos); 256 
tdata.setSlice(a, siter); 257 } 258 259 } else { 260 // for each value, set slice given by repeats 261 final int[] skip = new int[rank]; 262 for (int i = 0; i < rank; i++) { 263 if (reps[i] == 1) { 264 skip[i] = newShape[i]; 265 } else { 266 skip[i] = shape[i]; 267 } 268 } 269 270 Dataset aa = convertToDataset(a); 271 IndexIterator ita = aa.getIterator(true); 272 final int[] pos = ita.getPos(); 273 274 final int[] sstart = new int[rank]; 275 final int extra = rank - pos.length; 276 for (int i = 0; i < extra; i++) { 277 sstart[i] = 0; 278 } 279 SliceIterator siter = (SliceIterator) tdata.getSliceIterator(sstart, null, skip); 280 while (ita.hasNext()) { 281 for (int i = 0; i < pos.length; i++) { 282 sstart[i + extra] = pos[i]; 283 } 284 siter.setStart(sstart); 285 tdata.setSlice(aa.getObjectAbs(ita.index), siter); 286 } 287 } 288 289 return tdata; 290 } 291 292 /** 293 * Permute copy of dataset's axes so that given order is old order: 294 * <pre> 295 * axisPerm = (p(0), p(1),...) => newdata(n(0), n(1),...) = olddata(o(0), o(1), ...) 296 * such that n(i) = o(p(i)) for all i 297 * </pre> 298 * I.e. for a 3D dataset (1,0,2) implies the new dataset has its 1st dimension 299 * running along the old dataset's 2nd dimension and the new 2nd is the old 1st. 300 * The 3rd dimension is left unchanged. 301 * 302 * @param a 303 * @param axes if null or zero length then axes order reversed 304 * @return remapped copy of data 305 */ 306 public static Dataset transpose(final IDataset a, int... 
axes) { 307 return convertToDataset(a).transpose(axes); 308 } 309 310 /** 311 * Swap two axes in dataset 312 * @param a 313 * @param axis1 314 * @param axis2 315 * @return swapped dataset 316 */ 317 public static Dataset swapAxes(final IDataset a, int axis1, int axis2) { 318 return convertToDataset(a).swapAxes(axis1, axis2); 319 } 320 321 /** 322 * @param a 323 * @return sorted flattened copy of dataset 324 */ 325 public static <T extends Dataset> T sort(final T a) { 326 return sort(a, (Integer) null); 327 } 328 329 /** 330 * @param a 331 * @param axis to sort along 332 * @return dataset sorted along axis 333 */ 334 public static <T extends Dataset> T sort(final T a, final Integer axis) { 335 Dataset s = a.clone(); 336 return (T) s.sort(axis); 337 } 338 339 /** 340 * Sort in place given dataset and reorder ancillary datasets too 341 * @param a dataset to be sorted 342 * @param b ancillary datasets 343 */ 344 public static void sort(Dataset a, Dataset... b) { 345 if (!DTypeUtils.isDTypeNumerical(a.getDType())) { 346 throw new UnsupportedOperationException("Sorting non-numerical datasets not supported yet"); 347 } 348 349 // gather all datasets as double dataset copies 350 DoubleDataset s = copy(DoubleDataset.class, a); 351 int l = b == null ? 
0 : b.length; 352 DoubleDataset[] t = new DoubleDataset[l]; 353 int n = 0; 354 for (int i = 0; i < l; i++) { 355 if (b[i] != null) { 356 if (!DTypeUtils.isDTypeNumerical(b[i].getDType())) { 357 throw new UnsupportedOperationException("Sorting non-numerical datasets not supported yet"); 358 } 359 t[i] = copy(DoubleDataset.class, b[i]); 360 n++; 361 } 362 } 363 364 double[][] y = new double[n][]; 365 for (int i = 0, j = 0; i < l; i++) { 366 if (t[i] != null) { 367 y[j++] = t[i].getData(); 368 } 369 } 370 371 MathArrays.sortInPlace(s.getData(), y); 372 373 a.setSlice(s); 374 for (int i = 0; i < l; i++) { 375 if (b[i] != null) { 376 b[i].setSlice(t[i]); 377 } 378 } 379 } 380 381 /** 382 * Concatenate the set of datasets along given axis 383 * @param as 384 * @param axis 385 * @return concatenated dataset 386 */ 387 public static Dataset concatenate(final IDataset[] as, final int axis) { 388 if (as == null || as.length == 0) { 389 utilsLogger.error("No datasets given"); 390 throw new IllegalArgumentException("No datasets given"); 391 } 392 IDataset a = as[0]; 393 if (as.length == 1) { 394 return convertToDataset(a.clone()); 395 } 396 int[] ashape = a.getShape(); 397 int at = DTypeUtils.getDType(a); 398 int anum = as.length; 399 int isize = a.getElementsPerItem(); 400 401 int i = 1; 402 for (; i < anum; i++) { 403 if (at != DTypeUtils.getDType(as[i])) { 404 utilsLogger.error("Datasets are not of same type"); 405 break; 406 } 407 if (!ShapeUtils.areShapesCompatible(ashape, as[i].getShape(), axis)) { 408 utilsLogger.error("Datasets' shapes are not equal"); 409 break; 410 } 411 final int is = as[i].getElementsPerItem(); 412 if (isize < is) 413 isize = is; 414 } 415 if (i < anum) { 416 utilsLogger.error("Dataset are not compatible"); 417 throw new IllegalArgumentException("Datasets are not compatible"); 418 } 419 420 for (i = 1; i < anum; i++) { 421 ashape[axis] += as[i].getShape()[axis]; 422 } 423 424 @SuppressWarnings("deprecation") 425 Dataset result = 
DatasetFactory.zeros(isize, ashape, at); 426 427 int[] start = new int[ashape.length]; 428 int[] stop = ashape; 429 stop[axis] = 0; 430 for (i = 0; i < anum; i++) { 431 IDataset b = as[i]; 432 int[] bshape = b.getShape(); 433 stop[axis] += bshape[axis]; 434 result.setSlice(b, start, stop, null); 435 start[axis] += bshape[axis]; 436 } 437 438 return result; 439 } 440 441 /** 442 * Split a dataset into equal sections along given axis 443 * @param a 444 * @param sections 445 * @param axis 446 * @param checkEqual makes sure the division is into equal parts 447 * @return list of split datasets 448 */ 449 public static List<Dataset> split(final Dataset a, int sections, final int axis, final boolean checkEqual) { 450 int[] ashape = a.getShapeRef(); 451 int rank = ashape.length; 452 if (axis > rank) { 453 utilsLogger.error("Axis exceeds rank of dataset"); 454 throw new IllegalArgumentException("Axis exceeds rank of dataset"); 455 } 456 int imax = ashape[axis]; 457 if (checkEqual && (imax%sections) != 0) { 458 utilsLogger.error("Number of sections does not divide axis into equal parts"); 459 throw new IllegalArgumentException("Number of sections does not divide axis into equal parts"); 460 } 461 int n = (imax + sections - 1) / sections; 462 int[] indices = new int[sections-1]; 463 for (int i = 1; i < sections; i++) 464 indices[i-1] = n*i; 465 return split(a, indices, axis); 466 } 467 468 /** 469 * Split a dataset into parts along given axis 470 * @param a 471 * @param indices 472 * @param axis 473 * @return list of split datasets 474 */ 475 @SuppressWarnings("deprecation") 476 public static List<Dataset> split(final Dataset a, int[] indices, final int axis) { 477 final int[] ashape = a.getShapeRef(); 478 final int rank = ashape.length; 479 if (axis > rank) { 480 utilsLogger.error("Axis exceeds rank of dataset"); 481 throw new IllegalArgumentException("Axis exceeds rank of dataset"); 482 } 483 final int imax = ashape[axis]; 484 485 final List<Dataset> result = new 
ArrayList<Dataset>(); 486 487 final int[] nshape = ashape.clone(); 488 final int is = a.getElementsPerItem(); 489 490 int oind = 0; 491 final int[] start = new int[rank]; 492 final int[] stop = new int[rank]; 493 final int[] step = new int[rank]; 494 for (int i = 0; i < rank; i++) { 495 start[i] = 0; 496 stop[i] = ashape[i]; 497 step[i] = 1; 498 } 499 for (int ind : indices) { 500 if (ind > imax) { 501 result.add(DatasetFactory.zeros(is, new int[] {0}, a.getDType())); 502 } else { 503 nshape[axis] = ind - oind; 504 start[axis] = oind; 505 stop[axis] = ind; 506 Dataset n = DatasetFactory.zeros(is, nshape, a.getDType()); 507 IndexIterator iter = a.getSliceIterator(start, stop, step); 508 509 a.fillDataset(n, iter); 510 result.add(n); 511 oind = ind; 512 } 513 } 514 515 if (imax > oind) { 516 nshape[axis] = imax - oind; 517 start[axis] = oind; 518 stop[axis] = imax; 519 Dataset n = DatasetFactory.zeros(is, nshape, a.getDType()); 520 IndexIterator iter = a.getSliceIterator(start, stop, step); 521 522 a.fillDataset(n, iter); 523 result.add(n); 524 } 525 526 return result; 527 } 528 529 /** 530 * Constructs a dataset which has its elements along an axis replicated from 531 * the original dataset by the number of times given in the repeats array. 
532 * 533 * By default, axis=-1 implies using a flattened version of the input dataset 534 * 535 * @param a 536 * @param repeats 537 * @param axis 538 * @return dataset 539 */ 540 public static <T extends Dataset> T repeat(T a, int[] repeats, int axis) { 541 Serializable buf = a.getBuffer(); 542 int[] shape = a.getShape(); 543 int rank = shape.length; 544 final int is = a.getElementsPerItem(); 545 546 if (axis >= rank) { 547 utilsLogger.warn("Axis value is out of bounds"); 548 throw new IllegalArgumentException("Axis value is out of bounds"); 549 } 550 551 int alen; 552 if (axis < 0) { 553 alen = a.getSize(); 554 axis = 0; 555 rank = 1; 556 shape[0] = alen; 557 } else { 558 alen = shape[axis]; 559 } 560 int rlen = repeats.length; 561 if (rlen != 1 && rlen != alen) { 562 utilsLogger.warn("Repeats array should have length of 1 or match chosen axis"); 563 throw new IllegalArgumentException("Repeats array should have length of 1 or match chosen axis"); 564 } 565 566 for (int i = 0; i < rlen; i++) { 567 if (repeats[i] < 0) { 568 utilsLogger.warn("Negative repeat value is not allowed"); 569 throw new IllegalArgumentException("Negative repeat value is not allowed"); 570 } 571 } 572 573 int[] newShape = new int[rank]; 574 for (int i = 0; i < rank; i ++) 575 newShape[i] = shape[i]; 576 577 // do single repeat separately 578 if (repeats.length == 1) { 579 newShape[axis] *= repeats[0]; 580 } else { 581 int nlen = 0; 582 for (int i = 0; i < alen; i++) { 583 nlen += repeats[i]; 584 } 585 newShape[axis] = nlen; 586 } 587 588 @SuppressWarnings("deprecation") 589 Dataset rdata = DatasetFactory.zeros(is, newShape, a.getDType()); 590 Serializable nbuf = rdata.getBuffer(); 591 592 int csize = is; // chunk size 593 for (int i = axis+1; i < rank; i++) { 594 csize *= newShape[i]; 595 } 596 int nout = 1; 597 for (int i = 0; i < axis; i++) { 598 nout *= newShape[i]; 599 } 600 601 int oi = 0; 602 int ni = 0; 603 if (rlen == 1) { // do single repeat separately 604 for (int i = 0; i < nout; 
i++) { 605 for (int j = 0; j < shape[axis]; j++) { 606 for (int k = 0; k < repeats[0]; k++) { 607 System.arraycopy(buf, oi, nbuf, ni, csize); 608 ni += csize; 609 } 610 oi += csize; 611 } 612 } 613 } else { 614 for (int i = 0; i < nout; i++) { 615 for (int j = 0; j < shape[axis]; j++) { 616 for (int k = 0; k < repeats[j]; k++) { 617 System.arraycopy(buf, oi, nbuf, ni, csize); 618 ni += csize; 619 } 620 oi += csize; 621 } 622 } 623 } 624 625 return (T) rdata; 626 } 627 628 /** 629 * Resize a dataset 630 * @param a 631 * @param shape 632 * @return new dataset with new shape and items that are truncated or repeated, as necessary 633 */ 634 public static <T extends Dataset> T resize(final T a, final int... shape) { 635 int size = a.getSize(); 636 @SuppressWarnings("deprecation") 637 Dataset rdata = DatasetFactory.zeros(a.getElementsPerItem(), shape, a.getDType()); 638 IndexIterator it = rdata.getIterator(); 639 while (it.hasNext()) { 640 rdata.setObjectAbs(it.index, a.getObjectAbs(it.index % size)); 641 } 642 643 return (T) rdata; 644 } 645 646 /** 647 * Copy and cast a dataset 648 * 649 * @param d 650 * The dataset to be copied 651 * @param dtype dataset type 652 * @return copied dataset of given type 653 */ 654 public static Dataset copy(final IDataset d, final int dtype) { 655 Dataset a = convertToDataset(d); 656 657 Dataset c = null; 658 try { 659 // copy across the data 660 switch (dtype) { 661 case Dataset.STRING: 662 c = new StringDataset(a); 663 break; 664 case Dataset.BOOL: 665 c = new BooleanDataset(a); 666 break; 667 case Dataset.INT8: 668 if (a instanceof CompoundDataset) 669 c = new CompoundByteDataset(a); 670 else 671 c = new ByteDataset(a); 672 break; 673 case Dataset.INT16: 674 if (a instanceof CompoundDataset) 675 c = new CompoundShortDataset(a); 676 else 677 c = new ShortDataset(a); 678 break; 679 case Dataset.INT32: 680 if (a instanceof CompoundDataset) 681 c = new CompoundIntegerDataset(a); 682 else 683 c = new IntegerDataset(a); 684 break; 685 case 
Dataset.INT64: 686 if (a instanceof CompoundDataset) 687 c = new CompoundLongDataset(a); 688 else 689 c = new LongDataset(a); 690 break; 691 case Dataset.ARRAYINT8: 692 if (a instanceof CompoundDataset) 693 c = new CompoundByteDataset((CompoundDataset) a); 694 else 695 c = new CompoundByteDataset(a); 696 break; 697 case Dataset.ARRAYINT16: 698 if (a instanceof CompoundDataset) 699 c = new CompoundShortDataset((CompoundDataset) a); 700 else 701 c = new CompoundShortDataset(a); 702 break; 703 case Dataset.ARRAYINT32: 704 if (a instanceof CompoundDataset) 705 c = new CompoundIntegerDataset((CompoundDataset) a); 706 else 707 c = new CompoundIntegerDataset(a); 708 break; 709 case Dataset.ARRAYINT64: 710 if (a instanceof CompoundDataset) 711 c = new CompoundLongDataset((CompoundDataset) a); 712 else 713 c = new CompoundLongDataset(a); 714 break; 715 case Dataset.FLOAT32: 716 c = new FloatDataset(a); 717 break; 718 case Dataset.FLOAT64: 719 c = new DoubleDataset(a); 720 break; 721 case Dataset.ARRAYFLOAT32: 722 if (a instanceof CompoundDataset) 723 c = new CompoundFloatDataset((CompoundDataset) a); 724 else 725 c = new CompoundFloatDataset(a); 726 break; 727 case Dataset.ARRAYFLOAT64: 728 if (a instanceof CompoundDataset) 729 c = new CompoundDoubleDataset((CompoundDataset) a); 730 else 731 c = new CompoundDoubleDataset(a); 732 break; 733 case Dataset.COMPLEX64: 734 c = new ComplexFloatDataset(a); 735 break; 736 case Dataset.COMPLEX128: 737 c = new ComplexDoubleDataset(a); 738 break; 739 case Dataset.RGB: 740 if (a instanceof CompoundDataset) 741 c = RGBDataset.createFromCompoundDataset((CompoundDataset) a); 742 else 743 c = new RGBDataset(a); 744 break; 745 default: 746 utilsLogger.error("Dataset of unknown type!"); 747 break; 748 } 749 } catch (OutOfMemoryError e) { 750 utilsLogger.error("Not enough memory available to create dataset"); 751 throw new OutOfMemoryError("Not enough memory available to create dataset"); 752 } 753 754 return c; 755 } 756 757 /** 758 * Copy 
and cast a dataset 759 * 760 * @param clazz dataset class 761 * @param d 762 * The dataset to be copied 763 * @return copied dataset of given type 764 */ 765 public static <T extends Dataset> T copy(Class<T> clazz, final IDataset d) { 766 return (T) copy(d, DTypeUtils.getDType(clazz)); 767 } 768 769 770 /** 771 * Cast a dataset 772 * 773 * @param d 774 * The dataset to be cast. 775 * @param dtype dataset type 776 * @return dataset of given type (or same dataset if already of the right type) 777 */ 778 public static Dataset cast(final IDataset d, final int dtype) { 779 Dataset a = convertToDataset(d); 780 781 if (a.getDType() == dtype) { 782 return a; 783 } 784 return copy(d, dtype); 785 } 786 787 /** 788 * Cast a dataset 789 * 790 * @param clazz dataset class 791 * @param d 792 * The dataset to be cast. 793 * @return dataset of given type (or same dataset if already of the right type) 794 */ 795 public static <T extends Dataset> T cast(Class<T> clazz, final IDataset d) { 796 return (T) cast(d, DTypeUtils.getDType(clazz)); 797 } 798 799 /** 800 * Cast a dataset 801 * 802 * @param d 803 * The dataset to be cast. 
804 * @param repeat repeat elements over item 805 * @param dtype dataset type 806 * @param isize item size 807 */ 808 public static Dataset cast(final IDataset d, final boolean repeat, final int dtype, final int isize) { 809 Dataset a = convertToDataset(d); 810 811 if (a.getDType() == dtype && a.getElementsPerItem() == isize) { 812 return a; 813 } 814 if (isize <= 0) { 815 utilsLogger.error("Item size is invalid (>0)"); 816 throw new IllegalArgumentException("Item size is invalid (>0)"); 817 } 818 if (isize > 1 && dtype <= Dataset.FLOAT64) { 819 utilsLogger.error("Item size is inconsistent with dataset type"); 820 throw new IllegalArgumentException("Item size is inconsistent with dataset type"); 821 } 822 823 Dataset c = null; 824 825 try { 826 // copy across the data 827 switch (dtype) { 828 case Dataset.BOOL: 829 c = new BooleanDataset(a); 830 break; 831 case Dataset.INT8: 832 c = new ByteDataset(a); 833 break; 834 case Dataset.INT16: 835 c = new ShortDataset(a); 836 break; 837 case Dataset.INT32: 838 c = new IntegerDataset(a); 839 break; 840 case Dataset.INT64: 841 c = new LongDataset(a); 842 break; 843 case Dataset.ARRAYINT8: 844 c = new CompoundByteDataset(isize, repeat, a); 845 break; 846 case Dataset.ARRAYINT16: 847 c = new CompoundShortDataset(isize, repeat, a); 848 break; 849 case Dataset.ARRAYINT32: 850 c = new CompoundIntegerDataset(isize, repeat, a); 851 break; 852 case Dataset.ARRAYINT64: 853 c = new CompoundLongDataset(isize, repeat, a); 854 break; 855 case Dataset.FLOAT32: 856 c = new FloatDataset(a); 857 break; 858 case Dataset.FLOAT64: 859 c = new DoubleDataset(a); 860 break; 861 case Dataset.ARRAYFLOAT32: 862 c = new CompoundFloatDataset(isize, repeat, a); 863 break; 864 case Dataset.ARRAYFLOAT64: 865 c = new CompoundDoubleDataset(isize, repeat, a); 866 break; 867 case Dataset.COMPLEX64: 868 c = new ComplexFloatDataset(a); 869 break; 870 case Dataset.COMPLEX128: 871 c = new ComplexDoubleDataset(a); 872 break; 873 default: 874 
utilsLogger.error("Dataset of unknown type!"); 875 break; 876 } 877 } catch (OutOfMemoryError e) { 878 utilsLogger.error("Not enough memory available to create dataset"); 879 throw new OutOfMemoryError("Not enough memory available to create dataset"); 880 } 881 882 return c; 883 } 884 885 /** 886 * Cast array of datasets to a compound dataset 887 * 888 * @param a 889 * The datasets to be cast. 890 */ 891 public static CompoundDataset cast(final Dataset[] a, final int dtype) { 892 CompoundDataset c = null; 893 894 switch (dtype) { 895 case Dataset.INT8: 896 case Dataset.ARRAYINT8: 897 c = new CompoundByteDataset(a); 898 break; 899 case Dataset.INT16: 900 case Dataset.ARRAYINT16: 901 c = new CompoundShortDataset(a); 902 break; 903 case Dataset.INT32: 904 case Dataset.ARRAYINT32: 905 c = new CompoundIntegerDataset(a); 906 break; 907 case Dataset.INT64: 908 case Dataset.ARRAYINT64: 909 c = new CompoundLongDataset(a); 910 break; 911 case Dataset.FLOAT32: 912 case Dataset.ARRAYFLOAT32: 913 c = new CompoundFloatDataset(a); 914 break; 915 case Dataset.FLOAT64: 916 case Dataset.ARRAYFLOAT64: 917 c = new CompoundDoubleDataset(a); 918 break; 919 case Dataset.COMPLEX64: 920 if (a.length != 2) { 921 throw new IllegalArgumentException("Need two datasets for complex dataset type"); 922 } 923 c = new ComplexFloatDataset(a[0], a[1]); 924 break; 925 case Dataset.COMPLEX128: 926 if (a.length != 2) { 927 throw new IllegalArgumentException("Need two datasets for complex dataset type"); 928 } 929 c = new ComplexDoubleDataset(a[0], a[1]); 930 break; 931 default: 932 utilsLogger.error("Dataset of unsupported type!"); 933 break; 934 } 935 936 return c; 937 } 938 939 /** 940 * Make a dataset unsigned by promoting it to a wider dataset type and unwrapping the signs 941 * of its content 942 * @param a 943 * @return unsigned dataset or original if it is not an integer dataset 944 */ 945 public static Dataset makeUnsigned(IDataset a) { 946 Dataset d = convertToDataset(a); 947 int dtype = 
d.getDType(); 948 switch (dtype) { 949 case Dataset.INT32: 950 d = new LongDataset(d); 951 unwrapUnsigned(d, 32); 952 break; 953 case Dataset.INT16: 954 d = new IntegerDataset(d); 955 unwrapUnsigned(d, 16); 956 break; 957 case Dataset.INT8: 958 d = new ShortDataset(d); 959 unwrapUnsigned(d, 8); 960 break; 961 case Dataset.ARRAYINT32: 962 d = new CompoundLongDataset(d); 963 unwrapUnsigned(d, 32); 964 break; 965 case Dataset.ARRAYINT16: 966 d = new CompoundIntegerDataset(d); 967 unwrapUnsigned(d, 16); 968 break; 969 case Dataset.ARRAYINT8: 970 d = new CompoundShortDataset(d); 971 unwrapUnsigned(d, 8); 972 break; 973 } 974 return d; 975 } 976 977 /** 978 * Unwrap dataset elements so that all elements are unsigned 979 * @param a dataset 980 * @param bitWidth width of original primitive in bits 981 */ 982 public static void unwrapUnsigned(Dataset a, final int bitWidth) { 983 final int dtype = a.getDType(); 984 final double dv = 1L << bitWidth; 985 final int isize = a.getElementsPerItem(); 986 IndexIterator it = a.getIterator(); 987 988 switch (dtype) { 989 case Dataset.BOOL: 990 break; 991 case Dataset.INT8: 992 break; 993 case Dataset.INT16: 994 ShortDataset sds = (ShortDataset) a; 995 final short soffset = (short) dv; 996 while (it.hasNext()) { 997 final short x = sds.getAbs(it.index); 998 if (x < 0) 999 sds.setAbs(it.index, (short) (x + soffset)); 1000 } 1001 break; 1002 case Dataset.INT32: 1003 IntegerDataset ids = (IntegerDataset) a; 1004 final int ioffset = (int) dv; 1005 while (it.hasNext()) { 1006 final int x = ids.getAbs(it.index); 1007 if (x < 0) 1008 ids.setAbs(it.index, x + ioffset); 1009 } 1010 break; 1011 case Dataset.INT64: 1012 LongDataset lds = (LongDataset) a; 1013 final long loffset = (long) dv; 1014 while (it.hasNext()) { 1015 final long x = lds.getAbs(it.index); 1016 if (x < 0) 1017 lds.setAbs(it.index, x + loffset); 1018 } 1019 break; 1020 case Dataset.FLOAT32: 1021 FloatDataset fds = (FloatDataset) a; 1022 final float foffset = (float) dv; 1023 
while (it.hasNext()) { 1024 final float x = fds.getAbs(it.index); 1025 if (x < 0) 1026 fds.setAbs(it.index, x + foffset); 1027 } 1028 break; 1029 case Dataset.FLOAT64: 1030 DoubleDataset dds = (DoubleDataset) a; 1031 final double doffset = dv; 1032 while (it.hasNext()) { 1033 final double x = dds.getAbs(it.index); 1034 if (x < 0) 1035 dds.setAbs(it.index, x + doffset); 1036 } 1037 break; 1038 case Dataset.ARRAYINT8: 1039 break; 1040 case Dataset.ARRAYINT16: 1041 CompoundShortDataset csds = (CompoundShortDataset) a; 1042 final short csoffset = (short) dv; 1043 final short[] csa = new short[isize]; 1044 while (it.hasNext()) { 1045 csds.getAbs(it.index, csa); 1046 boolean dirty = false; 1047 for (int i = 0; i < isize; i++) { 1048 short x = csa[i]; 1049 if (x < 0) { 1050 csa[i] = (short) (x + csoffset); 1051 dirty = true; 1052 } 1053 } 1054 if (dirty) 1055 csds.setAbs(it.index, csa); 1056 } 1057 break; 1058 case Dataset.ARRAYINT32: 1059 CompoundIntegerDataset cids = (CompoundIntegerDataset) a; 1060 final int cioffset = (int) dv; 1061 final int[] cia = new int[isize]; 1062 while (it.hasNext()) { 1063 cids.getAbs(it.index, cia); 1064 boolean dirty = false; 1065 for (int i = 0; i < isize; i++) { 1066 int x = cia[i]; 1067 if (x < 0) { 1068 cia[i] = x + cioffset; 1069 dirty = true; 1070 } 1071 } 1072 if (dirty) 1073 cids.setAbs(it.index, cia); 1074 } 1075 break; 1076 case Dataset.ARRAYINT64: 1077 CompoundLongDataset clds = (CompoundLongDataset) a; 1078 final long cloffset = (long) dv; 1079 final long[] cla = new long[isize]; 1080 while (it.hasNext()) { 1081 clds.getAbs(it.index, cla); 1082 boolean dirty = false; 1083 for (int i = 0; i < isize; i++) { 1084 long x = cla[i]; 1085 if (x < 0) { 1086 cla[i] = x + cloffset; 1087 dirty = true; 1088 } 1089 } 1090 if (dirty) 1091 clds.setAbs(it.index, cla); 1092 } 1093 break; 1094 default: 1095 utilsLogger.error("Dataset of unsupported type for this method"); 1096 break; 1097 } 1098 } 1099 1100 /** 1101 * @param rows 1102 * @param 
cols number of columns
	 * @param offset index of the diagonal to fill: 0 is the main diagonal, a positive value
	 *            selects a diagonal above it and a negative value one below it
	 * @param dtype dataset type of the returned dataset
	 * @return a new 2d dataset of given shape and type, filled with ones on the (offset) diagonal
	 */
	public static Dataset eye(final int rows, final int cols, final int offset, final int dtype) {
		int[] shape = new int[] {rows, cols};
		@SuppressWarnings("deprecation")
		Dataset a = DatasetFactory.zeros(shape, dtype);

		// start at (0, offset); for a negative offset step down the diagonal until the column is valid
		int[] pos = new int[] {0, offset};
		while (pos[1] < 0) {
			pos[0]++;
			pos[1]++;
		}
		// walk the diagonal until it leaves the matrix through the bottom or the right edge
		while (pos[0] < rows && pos[1] < cols) {
			a.set(1, pos);
			pos[0]++;
			pos[1]++;
		}

		return a;
	}

	/**
	 * Create a (off-)diagonal matrix from items in dataset, or extract a diagonal from a matrix.
	 * <p>
	 * For a 1D dataset of length n, the result is a square 2D dataset of side n + |offset| whose
	 * (offset) diagonal holds the items of the input and which is zero elsewhere. For a 2D dataset,
	 * the result is a 1D dataset holding the items found on the (offset) diagonal; it is empty when
	 * that diagonal lies outside the input.
	 * @param a dataset of rank one or two
	 * @param offset index of the diagonal: 0 is the main diagonal, positive above it, negative below it
	 * @return diagonal matrix (for 1D input) or diagonal items (for 2D input), with the same dataset
	 *         type and number of elements per item as the input
	 * @throws IllegalArgumentException if the rank of the dataset is not one or two
	 */
	@SuppressWarnings("deprecation")
	public static <T extends Dataset> T diag(final T a, final int offset) {
		final int dtype = a.getDType();
		final int rank = a.getRank();
		final int is = a.getElementsPerItem();

		if (rank == 0 || rank > 2) {
			utilsLogger.error("Rank of dataset should be one or two");
			throw new IllegalArgumentException("Rank of dataset should be one or two");
		}

		Dataset result;
		final int[] shape = a.getShapeRef();
		if (rank == 1) {
			// the matrix must be large enough to hold all n items on a diagonal displaced by offset
			int side = shape[0] + Math.abs(offset);
			// NOTE(review): pos is first passed as the shape and then reused as the cursor below;
			// this relies on zeros(...) not keeping a reference to the array - confirm
			int[] pos = new int[] {side, side};
			result = DatasetFactory.zeros(is, pos, dtype);
			if (offset >= 0) {
				pos[0] = 0;
				pos[1] = offset;
			} else {
				pos[0] = -offset;
				pos[1] = 0;
			}
			int i = 0;
			while (pos[0] < side && pos[1] < side) {
				result.set(a.getObject(i++), pos);
				pos[0]++;
				pos[1]++;
			}
		} else {
			// length of the requested diagonal, clamped to zero when it lies outside the matrix
			int side = offset >= 0 ? Math.min(shape[0], shape[1]-offset) : Math.min(shape[0]+offset, shape[1]);
			if (side < 0)
				side = 0;
			result = DatasetFactory.zeros(is, new int[] {side}, dtype);

			if (side > 0) {
				int[] pos = offset >= 0 ? new int[] { 0, offset } : new int[] { -offset, 0 };
				int i = 0;
				while (pos[0] < shape[0] && pos[1] < shape[1]) {
					result.set(a.getObject(pos), i++);
					pos[0]++;
					pos[1]++;
				}
			}
		}

		return (T) result;
	}

	/**
	 * Slice (or fully load), if necessary, a lazy dataset, otherwise take a slice view and
	 * convert to our dataset implementation. If a slice is necessary, this may cause resource
	 * problems when used on large datasets and throw runtime exceptions
	 * @param lazy can be null
	 * @return Converted dataset or null
	 * @throws DatasetException
	 */
	public static Dataset sliceAndConvertLazyDataset(ILazyDataset lazy) throws DatasetException {
		if (lazy == null)
			return null;

		// an in-memory dataset only needs a view; anything else has to be sliced (loaded)
		IDataset data = lazy instanceof IDataset ? (IDataset) lazy.getSliceView() : lazy.getSlice();

		return convertToDataset(data);
	}

	/**
	 * Convert (if necessary) a dataset obeying the interface to our implementation.
	 * <p>
	 * Instances of {@link Dataset} are returned as they are; any other implementation is copied,
	 * item by item, into a new dataset of the same type and shape, which also receives its name
	 * and errors.
	 * @param data can be null
	 * @return Converted dataset or null
	 * @throws IllegalArgumentException if the dataset reports a non-positive number of elements per item
	 */
	public static Dataset convertToDataset(IDataset data) {
		if (data == null)
			return null;

		if (data instanceof Dataset) {
			return (Dataset) data;
		}

		int dtype = DTypeUtils.getDType(data);

		final int isize = data.getElementsPerItem();
		if (isize <= 0) {
			throw new IllegalArgumentException("Datasets with " + isize + " elements per item not supported");
		}

		@SuppressWarnings("deprecation")
		final Dataset result = DatasetFactory.zeros(isize, data.getShape(), dtype);
		result.setName(data.getName());

		// iterate over the new dataset, reading the source at the same position with the
		// primitive getter matching the dataset type; other types go through getObject
		final IndexIterator it = result.getIterator(true);
		final int[] pos = it.getPos();
		switch (dtype) {
		case Dataset.BOOL:
			while (it.hasNext()) {
				result.setObjectAbs(it.index, data.getBoolean(pos));
			}
			break;
		case Dataset.INT8:
			while (it.hasNext()) {
				result.setObjectAbs(it.index, data.getByte(pos));
			}
			break;
		case Dataset.INT16:
			while (it.hasNext()) {
				result.setObjectAbs(it.index, data.getShort(pos));
			}
			break;
		case Dataset.INT32:
			while (it.hasNext()) {
				result.setObjectAbs(it.index, data.getInt(pos));
			}
			break;
		case Dataset.INT64:
			while (it.hasNext()) {
				result.setObjectAbs(it.index, data.getLong(pos));
			}
			break;
		case Dataset.FLOAT32:
			while (it.hasNext()) {
				result.setObjectAbs(it.index, data.getFloat(pos));
			}
			break;
		case Dataset.FLOAT64:
			while (it.hasNext()) {
				result.setObjectAbs(it.index, data.getDouble(pos));
			}
			break;
		default:
			while (it.hasNext()) {
				result.setObjectAbs(it.index, data.getObject(pos));
			}
			break;
		}

		result.setErrors(data.getErrors());
		return result;
	}

	/**
	 * Create a compound dataset from given datasets. The dataset type of the first dataset
	 * selects the type of the compound dataset.
	 * @param datasets
	 * @return compound dataset or null if none given
	 */
	public static CompoundDataset createCompoundDataset(final Dataset... datasets) {
		if (datasets == null || datasets.length == 0)
			return null;

		return createCompoundDataset(datasets[0].getDType(), datasets);
	}

	/**
	 * Create a compound dataset from given datasets
	 * @param dtype dataset type of the result; an element type and its compound (array) type both
	 *            select the same compound class
	 * @param datasets
	 * @return compound dataset or null if none given
	 * @throws IllegalArgumentException if more than two datasets are given for a complex type, or
	 *             other than one or three datasets for RGB
	 * @throws UnsupportedOperationException if the dataset type is not supported
	 */
	public static CompoundDataset createCompoundDataset(final int dtype, final Dataset... datasets) {
		if (datasets == null || datasets.length == 0)
			return null;

		switch (dtype) {
		case Dataset.INT8:
		case Dataset.ARRAYINT8:
			return new CompoundByteDataset(datasets);
		case Dataset.INT16:
		case Dataset.ARRAYINT16:
			return new CompoundShortDataset(datasets);
		case Dataset.INT32:
		case Dataset.ARRAYINT32:
			return new CompoundIntegerDataset(datasets);
		case Dataset.INT64:
		case Dataset.ARRAYINT64:
			return new CompoundLongDataset(datasets);
		case Dataset.FLOAT32:
		case Dataset.ARRAYFLOAT32:
			return new CompoundFloatDataset(datasets);
		case Dataset.FLOAT64:
		case Dataset.ARRAYFLOAT64:
			return new CompoundDoubleDataset(datasets);
		case Dataset.COMPLEX64:
		case Dataset.COMPLEX128:
			// one dataset (passed alone to the complex constructor) or a pair of datasets
			if (datasets.length > 2) {
				utilsLogger.error("At most two datasets are allowed");
				throw new IllegalArgumentException("At most two datasets are allowed");
			} else if (datasets.length == 2) {
				return dtype == Dataset.COMPLEX64 ? new ComplexFloatDataset(datasets[0], datasets[1]) : new ComplexDoubleDataset(datasets[0], datasets[1]);
			}
			return dtype == Dataset.COMPLEX64 ? new ComplexFloatDataset(datasets[0]) : new ComplexDoubleDataset(datasets[0]);
		case Dataset.RGB:
			if (datasets.length == 1) {
				return new RGBDataset(datasets[0]);
			} else if (datasets.length == 3) {
				return new RGBDataset(datasets[0], datasets[1], datasets[2]);
			} else {
				utilsLogger.error("Only one or three datasets are allowed to create a RGB dataset");
				throw new IllegalArgumentException("Only one or three datasets are allowed to create a RGB dataset");
			}
		default:
			utilsLogger.error("Dataset type not supported for this operation");
			throw new UnsupportedOperationException("Dataset type not supported");
		}
	}

	/**
	 * Create a compound dataset from given datasets
	 * @param clazz dataset class (mapped to a dataset type with DTypeUtils.getDType)
	 * @param datasets
	 * @return compound dataset or null if none given
	 */
	public static <T extends CompoundDataset> T createCompoundDataset(Class<T> clazz, final Dataset... datasets) {
		return (T) createCompoundDataset(DTypeUtils.getDType(clazz), datasets);
	}

	/**
	 * Create a compound dataset from given dataset by grouping its elements into items of
	 * itemSize elements.
	 * <p>
	 * Trailing dimensions are merged until their product is a multiple of itemSize; that merged
	 * dimension becomes product / itemSize in the new shape. The input's buffer (from getBuffer) is
	 * handed straight to the compound constructor - NOTE(review): whether the data is then shared
	 * or copied depends on that constructor; confirm.
	 * @param dataset
	 * @param itemSize
	 * @return compound dataset
	 * @throws IllegalArgumentException if the number of items is not a multiple of itemSize
	 * @throws UnsupportedOperationException if the dataset type is not supported
	 */
	public static CompoundDataset createCompoundDataset(final Dataset dataset, final int itemSize) {
		int[] shape = dataset.getShapeRef();
		int[] nshape = shape;
		if (shape != null && itemSize > 1) {
			int size = ShapeUtils.calcSize(shape);
			if (size % itemSize != 0) {
				throw new IllegalArgumentException("Input dataset has number of items that is not a multiple of itemSize");
			}
			// find the innermost run of dimensions whose product is divisible by itemSize
			int d = shape.length;
			int l = 1;
			while (--d >= 0) {
				l *= shape[d];
				if (l % itemSize == 0) {
					break;
				}
			}
			assert d >= 0;
			nshape = new int[d + 1];
			for (int i = 0; i < d; i++) {
				nshape[i] = shape[i];
			}
			nshape[d] = l / itemSize;
		}
		switch (dataset.getDType()) {
		case Dataset.INT8:
			return new CompoundByteDataset(itemSize, (byte[]) dataset.getBuffer(), nshape);
		case Dataset.INT16:
			return new CompoundShortDataset(itemSize, (short[]) dataset.getBuffer(), nshape);
		case Dataset.INT32:
			return new CompoundIntegerDataset(itemSize, (int[]) dataset.getBuffer(), nshape);
		case Dataset.INT64:
			return new CompoundLongDataset(itemSize, (long[]) dataset.getBuffer(), nshape);
		case Dataset.FLOAT32:
			return new CompoundFloatDataset(itemSize, (float[]) dataset.getBuffer(), nshape);
		case Dataset.FLOAT64:
			return new CompoundDoubleDataset(itemSize, (double[]) dataset.getBuffer(), nshape);
		default:
			utilsLogger.error("Dataset type not supported for this operation");
			throw new UnsupportedOperationException("Dataset type not supported");
		}
	}


	/**
	 * Create a compound dataset by using last axis as elements of an item
	 * @param a
	 * @param shareData if true, then share data
	 * @return compound dataset
	 * @throws UnsupportedOperationException if the dataset is not of an integer (INT8 to INT64) or
	 *             floating point (FLOAT32, FLOAT64) type
	 */
	public static CompoundDataset createCompoundDatasetFromLastAxis(final Dataset a, final boolean shareData) {
		switch (a.getDType()) {
		case Dataset.INT8:
			return CompoundByteDataset.createCompoundDatasetWithLastDimension(a, shareData);
		case Dataset.INT16:
			return CompoundShortDataset.createCompoundDatasetWithLastDimension(a, shareData);
		case Dataset.INT32:
			return CompoundIntegerDataset.createCompoundDatasetWithLastDimension(a, shareData);
		case Dataset.INT64:
			return CompoundLongDataset.createCompoundDatasetWithLastDimension(a, shareData);
		case Dataset.FLOAT32:
			return CompoundFloatDataset.createCompoundDatasetWithLastDimension(a, shareData);
		case Dataset.FLOAT64:
			return CompoundDoubleDataset.createCompoundDatasetWithLastDimension(a, shareData);
		default:
			utilsLogger.error("Dataset type not supported for this operation");
			throw new UnsupportedOperationException("Dataset type not supported");
		}
	}

	/**
	 * Create a dataset from a compound dataset by using elements of an item as last axis
	 * <p>
	 * In the case where the number of elements is one, the last axis is squeezed out.
	 * Delegates to {@link CompoundDataset#asNonCompoundDataset(boolean)}.
	 * @param a
	 * @param shareData if true, then share data
	 * @return non-compound dataset
	 */
	public static Dataset createDatasetFromCompoundDataset(final CompoundDataset a, final boolean shareData) {
		return a.asNonCompoundDataset(shareData);
	}

	/**
	 * Create a copy that has been coerced to an appropriate dataset type
	 * depending on the input object's class
	 * <p>
	 * The clone of a is cast to the best type combining a's type and the type derived
	 * from obj's class.
	 *
	 * @param a
	 * @param obj must not be null (its class is inspected)
	 * @return coerced copy of dataset
	 */
	public static Dataset coerce(Dataset a, Object obj) {
		final int dt = a.getDType();
		final int ot = DTypeUtils.getDTypeFromClass(obj.getClass());

		return cast(a.clone(), DTypeUtils.getBestDType(dt, ot));
	}

	/**
	 * Function that returns a normalised dataset which is bounded between 0 and 1,
	 * computed as (a - min) / (max - min).
	 * <p>
	 * NOTE(review): for a constant dataset the divisor is zero; the outcome then depends on idivide.
	 * @param a dataset
	 * @return normalised dataset
	 */
	public static Dataset norm(Dataset a) {
		double amin = a.min().doubleValue();
		double aptp = a.max().doubleValue() - amin;
		Dataset temp = Maths.subtract(a, amin);
		temp.idivide(aptp);
		return temp;
	}

	/**
	 * Function that returns a normalised compound dataset which is bounded between 0 and 1. There
	 * are (at least) two ways to normalise a compound dataset: per element - the extrema of each element
	 * position in a compound item are used, i.e. many min/max pairs; over all elements - the extrema across
	 * all elements are used, i.e. one min/max pair.
1467 * @param a dataset 1468 * @param overAllElements if true, then normalise over all elements in each item 1469 * @return normalised dataset 1470 */ 1471 public static CompoundDataset norm(CompoundDataset a, boolean overAllElements) { 1472 double[] amin = a.minItem(); 1473 double[] amax = a.maxItem(); 1474 final int is = a.getElementsPerItem(); 1475 Dataset result; 1476 1477 if (overAllElements) { 1478 Arrays.sort(amin); 1479 Arrays.sort(amax); 1480 double aptp = amax[0] - amin[0]; 1481 1482 result = Maths.subtract(a, amin[0]); 1483 result.idivide(aptp); 1484 } else { 1485 double[] aptp = new double[is]; 1486 for (int j = 0; j < is; j++) { 1487 aptp[j] = amax[j] - amin[j]; 1488 } 1489 1490 result = Maths.subtract(a, amin); 1491 result.idivide(aptp); 1492 } 1493 return (CompoundDataset) result; 1494 } 1495 1496 /** 1497 * Function that returns a normalised dataset which is bounded between 0 and 1 1498 * and has been distributed on a log10 scale 1499 * @param a dataset 1500 * @return normalised dataset 1501 */ 1502 public static Dataset lognorm(Dataset a) { 1503 double amin = a.min().doubleValue(); 1504 double aptp = Math.log10(a.max().doubleValue() - amin + 1.); 1505 Dataset temp = Maths.subtract(a, amin - 1.); 1506 temp = Maths.log10(temp); 1507 temp = Maths.divide(temp, aptp); 1508 return temp; 1509 } 1510 1511 /** 1512 * Function that returns a normalised dataset which is bounded between 0 and 1 1513 * and has been distributed on a natural log scale 1514 * @param a dataset 1515 * @return normalised dataset 1516 */ 1517 public static Dataset lnnorm(Dataset a) { 1518 double amin = a.min().doubleValue(); 1519 double aptp = Math.log(a.max().doubleValue() - amin + 1.); 1520 Dataset temp = Maths.subtract(a, amin - 1.); 1521 temp = Maths.log(temp); 1522 temp = Maths.divide(temp, aptp); 1523 return temp; 1524 } 1525 1526 /** 1527 * Construct a list of datasets where each represents a coordinate varying over the hypergrid 1528 * formed by the input list of axes 1529 * 
1530 * @param axes an array of 1D datasets representing axes 1531 * @return a list of coordinate datasets 1532 */ 1533 public static List<Dataset> meshGrid(final Dataset... axes) { 1534 List<Dataset> result = new ArrayList<Dataset>(); 1535 int rank = axes.length; 1536 1537 if (rank < 2) { 1538 utilsLogger.error("Two or more axes datasets are required"); 1539 throw new IllegalArgumentException("Two or more axes datasets are required"); 1540 } 1541 1542 int[] nshape = new int[rank]; 1543 1544 for (int i = 0; i < rank; i++) { 1545 Dataset axis = axes[i]; 1546 if (axis.getRank() != 1) { 1547 utilsLogger.error("Given axis is not 1D"); 1548 throw new IllegalArgumentException("Given axis is not 1D"); 1549 } 1550 nshape[i] = axis.getSize(); 1551 } 1552 1553 for (int i = 0; i < rank; i++) { 1554 Dataset axis = axes[i]; 1555 @SuppressWarnings("deprecation") 1556 Dataset coord = DatasetFactory.zeros(nshape, axis.getDType()); 1557 result.add(coord); 1558 1559 final int alen = axis.getSize(); 1560 for (int j = 0; j < alen; j++) { 1561 final Object obj = axis.getObjectAbs(j); 1562 PositionIterator pi = coord.getPositionIterator(i); 1563 final int[] pos = pi.getPos(); 1564 1565 pos[i] = j; 1566 while (pi.hasNext()) { 1567 coord.set(obj, pos); 1568 } 1569 } 1570 } 1571 1572 return result; 1573 } 1574 1575 /** 1576 * Generate an index dataset for given dataset where sub-datasets contain index values 1577 * 1578 * @return an index dataset 1579 */ 1580 public static IntegerDataset indices(int... 
shape) { 1581 // now create another dataset to plot against 1582 final int rank = shape.length; 1583 int[] nshape = new int[rank+1]; 1584 nshape[0] = rank; 1585 for (int i = 0; i < rank; i++) { 1586 nshape[i+1] = shape[i]; 1587 } 1588 1589 IntegerDataset index = new IntegerDataset(nshape); 1590 1591 if (rank == 1) { 1592 final int alen = shape[0]; 1593 int[] pos = new int[2]; 1594 for (int j = 0; j < alen; j++) { 1595 pos[1] = j; 1596 index.set(j, pos); 1597 } 1598 } else { 1599 for (int i = 1; i <= rank; i++) { 1600 final int alen = nshape[i]; 1601 for (int j = 0; j < alen; j++) { 1602 PositionIterator pi = index.getPositionIterator(0, i); 1603 final int[] pos = pi.getPos(); 1604 1605 pos[0] = i-1; 1606 pos[i] = j; 1607 while (pi.hasNext()) { 1608 index.set(j, pos); 1609 } 1610 } 1611 } 1612 } 1613 return index; 1614 } 1615 1616 /** 1617 * Get the centroid value of a dataset, this function works out the centroid in every direction 1618 * 1619 * @param a 1620 * the dataset to be analysed 1621 * @param bases the optional array of base coordinates to use as weights. 1622 * This defaults to the mid-point of indices 1623 * @return a double array containing the centroid for each dimension 1624 */ 1625 public static double[] centroid(Dataset a, Dataset... 
bases) { 1626 int rank = a.getRank(); 1627 if (bases.length > 0 && bases.length != rank) { 1628 throw new IllegalArgumentException("Number of bases must be zero or match rank of dataset"); 1629 } 1630 1631 int[] shape = a.getShapeRef(); 1632 if (bases.length == rank) { 1633 for (int i = 0; i < rank; i++) { 1634 Dataset b = bases[i]; 1635 if (b.getRank() != 1 && b.getSize() != shape[i]) { 1636 throw new IllegalArgumentException("A base does not have shape to match given dataset"); 1637 } 1638 } 1639 } 1640 1641 double[] dc = new double[rank]; 1642 if (rank == 0) 1643 return dc; 1644 1645 final PositionIterator iter = new PositionIterator(shape); 1646 final int[] pos = iter.getPos(); 1647 1648 double tsum = 0.0; 1649 while (iter.hasNext()) { 1650 double val = a.getDouble(pos); 1651 tsum += val; 1652 for (int d = 0; d < rank; d++) { 1653 Dataset b = bases.length == 0 ? null : bases[d]; 1654 if (b == null) { 1655 dc[d] += (pos[d] + 0.5) * val; 1656 } else { 1657 dc[d] += b.getElementDoubleAbs(pos[d]) * val; 1658 } 1659 } 1660 } 1661 1662 for (int d = 0; d < rank; d++) { 1663 dc[d] /= tsum; 1664 } 1665 return dc; 1666 } 1667 1668 /** 1669 * Find linearly-interpolated crossing points where the given dataset crosses the given value 1670 * 1671 * @param d 1672 * @param value 1673 * @return list of interpolated indices 1674 */ 1675 public static List<Double> crossings(Dataset d, double value) { 1676 if (d.getRank() != 1) { 1677 utilsLogger.error("Only 1d datasets supported"); 1678 throw new UnsupportedOperationException("Only 1d datasets supported"); 1679 } 1680 List<Double> results = new ArrayList<Double>(); 1681 1682 // run through all pairs of points on the line and see if value lies within 1683 IndexIterator it = d.getIterator(); 1684 double y1, y2; 1685 1686 y2 = it.hasNext() ? 
y2 = d.getElementDoubleAbs(it.index) : 0; 1687 double x = 1; 1688 while (it.hasNext()) { 1689 y1 = y2; 1690 y2 = d.getElementDoubleAbs(it.index); 1691 // check if value lies within pair [y1, y2] 1692 if ((y1 <= value && value < y2) || (y1 > value && y2 <= value)) { 1693 final double f = (value - y2)/(y2 - y1); // negative distance from right to left 1694 results.add(x + f); 1695 } 1696 x++; 1697 } 1698 if (y2 == value) { // add end point of it intersects 1699 results.add(x); 1700 } 1701 1702 return results; 1703 } 1704 1705 /** 1706 * Find x values of all the crossing points of the dataset with the given y value 1707 * 1708 * @param xAxis 1709 * Dataset of the X axis that needs to be looked at 1710 * @param yAxis 1711 * Dataset of the Y axis that needs to be looked at 1712 * @param yValue 1713 * The y value the X values are required for 1714 * @return An list of doubles containing all the X coordinates of where the line crosses 1715 */ 1716 public static List<Double> crossings(Dataset xAxis, Dataset yAxis, double yValue) { 1717 List<Double> results = new ArrayList<Double>(); 1718 1719 List<Double> indices = crossings(yAxis, yValue); 1720 1721 for (double xi : indices) { 1722 results.add(Maths.interpolate(xAxis, xi)); 1723 } 1724 return results; 1725 } 1726 1727 /** 1728 * Function that uses the crossings function but prunes the result, so that multiple crossings within a 1729 * certain proportion of the overall range of the x values 1730 * 1731 * @param xAxis 1732 * Dataset of the X axis 1733 * @param yAxis 1734 * Dataset of the Y axis 1735 * @param yValue 1736 * The y value the x values are required for 1737 * @param xRangeProportion 1738 * The proportion of the overall x spread used to prune result 1739 * @return A list containing all the unique crossing points 1740 */ 1741 public static List<Double> crossings(Dataset xAxis, Dataset yAxis, double yValue, double xRangeProportion) { 1742 // get the values found 1743 List<Double> vals = crossings(xAxis, yAxis, 
yValue); 1744 1745 // use the proportion to calculate the error spacing 1746 double error = xRangeProportion * xAxis.peakToPeak().doubleValue(); 1747 1748 int i = 0; 1749 // now go through and check for groups of three crossings which are all 1750 // within the boundaries 1751 while (i < vals.size() - 3) { 1752 double v1 = Math.abs(vals.get(i) - vals.get(i + 2)); 1753 if (v1 < error) { 1754 // these 3 points should be treated as one 1755 // make the first point equal to the average of them all 1756 vals.set(i + 2, ((vals.get(i) + vals.get(i + 1) + vals.get(i + 2)) / 3.0)); 1757 // remove the other offending points 1758 vals.remove(i); 1759 vals.remove(i); 1760 } else { 1761 i++; 1762 } 1763 } 1764 1765 // once the thinning process has been completed, return the pruned list 1766 return vals; 1767 } 1768 1769 // recursive function 1770 private static void setRow(Object row, Dataset a, int... pos) { 1771 final int l = Array.getLength(row); 1772 final int rank = pos.length; 1773 final int[] npos = Arrays.copyOf(pos, rank+1); 1774 Object r; 1775 if (rank+1 < a.getRank()) { 1776 for (int i = 0; i < l; i++) { 1777 npos[rank] = i; 1778 r = Array.get(row, i); 1779 setRow(r, a, npos); 1780 } 1781 } else { 1782 for (int i = 0; i < l; i++) { 1783 npos[rank] = i; 1784 r = a.getObject(npos); 1785 Array.set(row, i, r); 1786 } 1787 } 1788 } 1789 1790 /** 1791 * Create Java array (of arrays) from dataset 1792 * @param a dataset 1793 * @return Java array (of arrays...) 
1794 */ 1795 public static Object createJavaArray(Dataset a) { 1796 if (a.getElementsPerItem() > 1) { 1797 a = createDatasetFromCompoundDataset((CompoundDataset) a, true); 1798 } 1799 Object matrix; 1800 1801 switch (a.getDType()) { 1802 case Dataset.BOOL: 1803 matrix = Array.newInstance(boolean.class, a.getShape()); 1804 break; 1805 case Dataset.INT8: 1806 matrix = Array.newInstance(byte.class, a.getShape()); 1807 break; 1808 case Dataset.INT16: 1809 matrix = Array.newInstance(short.class, a.getShape()); 1810 break; 1811 case Dataset.INT32: 1812 matrix = Array.newInstance(int.class, a.getShape()); 1813 break; 1814 case Dataset.INT64: 1815 matrix = Array.newInstance(long.class, a.getShape()); 1816 break; 1817 case Dataset.FLOAT32: 1818 matrix = Array.newInstance(float.class, a.getShape()); 1819 break; 1820 case Dataset.FLOAT64: 1821 matrix = Array.newInstance(double.class, a.getShape()); 1822 break; 1823 default: 1824 utilsLogger.error("Dataset type not supported"); 1825 throw new IllegalArgumentException("Dataset type not supported"); 1826 } 1827 1828 // populate matrix 1829 setRow(matrix, a); 1830 return matrix; 1831 } 1832 1833 /** 1834 * Removes NaNs and infinities from floating point datasets. 1835 * All other dataset types are ignored. 
1836 * 1837 * @param a dataset 1838 * @param value replacement value 1839 */ 1840 public static void removeNansAndInfinities(Dataset a, final Number value) { 1841 if (a instanceof DoubleDataset) { 1842 final double dvalue = DTypeUtils.toReal(value); 1843 final DoubleDataset set = (DoubleDataset) a; 1844 final IndexIterator it = set.getIterator(); 1845 final double[] data = set.getData(); 1846 while (it.hasNext()) { 1847 double x = data[it.index]; 1848 if (Double.isNaN(x) || Double.isInfinite(x)) 1849 data[it.index] = dvalue; 1850 } 1851 } else if (a instanceof FloatDataset) { 1852 final float fvalue = (float) DTypeUtils.toReal(value); 1853 final FloatDataset set = (FloatDataset) a; 1854 final IndexIterator it = set.getIterator(); 1855 final float[] data = set.getData(); 1856 while (it.hasNext()) { 1857 float x = data[it.index]; 1858 if (Float.isNaN(x) || Float.isInfinite(x)) 1859 data[it.index] = fvalue; 1860 } 1861 } else if (a instanceof CompoundDoubleDataset) { 1862 final double dvalue = DTypeUtils.toReal(value); 1863 final CompoundDoubleDataset set = (CompoundDoubleDataset) a; 1864 final int is = set.getElementsPerItem(); 1865 final IndexIterator it = set.getIterator(); 1866 final double[] data = set.getData(); 1867 while (it.hasNext()) { 1868 for (int j = 0; j < is; j++) { 1869 double x = data[it.index + j]; 1870 if (Double.isNaN(x) || Double.isInfinite(x)) 1871 data[it.index + j] = dvalue; 1872 } 1873 } 1874 } else if (a instanceof CompoundFloatDataset) { 1875 final float fvalue = (float) DTypeUtils.toReal(value); 1876 final CompoundFloatDataset set = (CompoundFloatDataset) a; 1877 final int is = set.getElementsPerItem(); 1878 final IndexIterator it = set.getIterator(); 1879 final float[] data = set.getData(); 1880 while (it.hasNext()) { 1881 for (int j = 0; j < is; j++) { 1882 float x = data[it.index + j]; 1883 if (Float.isNaN(x) || Float.isInfinite(x)) 1884 data[it.index + j] = fvalue; 1885 } 1886 } 1887 } 1888 } 1889 1890 /** 1891 * Make floating point 
datasets contain only finite values. Infinities and NaNs are replaced 1892 * by +/- MAX_VALUE and 0, respectively. 1893 * All other dataset types are ignored. 1894 * 1895 * @param a dataset 1896 */ 1897 public static void makeFinite(Dataset a) { 1898 if (a instanceof DoubleDataset) { 1899 final DoubleDataset set = (DoubleDataset) a; 1900 final IndexIterator it = set.getIterator(); 1901 final double[] data = set.getData(); 1902 while (it.hasNext()) { 1903 final double x = data[it.index]; 1904 if (Double.isNaN(x)) 1905 data[it.index] = 0; 1906 else if (Double.isInfinite(x)) 1907 data[it.index] = x > 0 ? Double.MAX_VALUE : -Double.MAX_VALUE; 1908 } 1909 } else if (a instanceof FloatDataset) { 1910 final FloatDataset set = (FloatDataset) a; 1911 final IndexIterator it = set.getIterator(); 1912 final float[] data = set.getData(); 1913 while (it.hasNext()) { 1914 final float x = data[it.index]; 1915 if (Float.isNaN(x)) 1916 data[it.index] = 0; 1917 else if (Float.isInfinite(x)) 1918 data[it.index] = x > 0 ? Float.MAX_VALUE : -Float.MAX_VALUE; 1919 } 1920 } else if (a instanceof CompoundDoubleDataset) { 1921 final CompoundDoubleDataset set = (CompoundDoubleDataset) a; 1922 final int is = set.getElementsPerItem(); 1923 final IndexIterator it = set.getIterator(); 1924 final double[] data = set.getData(); 1925 while (it.hasNext()) { 1926 for (int j = 0; j < is; j++) { 1927 final double x = data[it.index + j]; 1928 if (Double.isNaN(x)) 1929 data[it.index + j] = 0; 1930 else if (Double.isInfinite(x)) 1931 data[it.index + j] = x > 0 ? 
Double.MAX_VALUE : -Double.MAX_VALUE; 1932 } 1933 } 1934 } else if (a instanceof CompoundFloatDataset) { 1935 final CompoundFloatDataset set = (CompoundFloatDataset) a; 1936 final int is = set.getElementsPerItem(); 1937 final IndexIterator it = set.getIterator(); 1938 final float[] data = set.getData(); 1939 while (it.hasNext()) { 1940 for (int j = 0; j < is; j++) { 1941 final float x = data[it.index + j]; 1942 if (Float.isNaN(x)) 1943 data[it.index + j] = 0; 1944 else if (Float.isInfinite(x)) 1945 data[it.index + j] = x > 0 ? Float.MAX_VALUE : -Float.MAX_VALUE; 1946 } 1947 } 1948 } 1949 } 1950 1951 /** 1952 * Find absolute index of first value in dataset that is equal to given number 1953 * @param a 1954 * @param n 1955 * @return absolute index (if greater than a.getSize() then no value found) 1956 */ 1957 public static int findIndexEqualTo(final Dataset a, final double n) { 1958 IndexIterator iter = a.getIterator(); 1959 while (iter.hasNext()) { 1960 if (a.getElementDoubleAbs(iter.index) == n) 1961 break; 1962 } 1963 1964 return iter.index; 1965 } 1966 1967 /** 1968 * Find absolute index of first value in dataset that is greater than given number 1969 * @param a 1970 * @param n 1971 * @return absolute index (if greater than a.getSize() then no value found) 1972 */ 1973 public static int findIndexGreaterThan(final Dataset a, final double n) { 1974 IndexIterator iter = a.getIterator(); 1975 while (iter.hasNext()) { 1976 if (a.getElementDoubleAbs(iter.index) > n) 1977 break; 1978 } 1979 1980 return iter.index; 1981 } 1982 1983 /** 1984 * Find absolute index of first value in dataset that is greater than or equal to given number 1985 * @param a 1986 * @param n 1987 * @return absolute index (if greater than a.getSize() then no value found) 1988 */ 1989 public static int findIndexGreaterThanOrEqualTo(final Dataset a, final double n) { 1990 IndexIterator iter = a.getIterator(); 1991 while (iter.hasNext()) { 1992 if (a.getElementDoubleAbs(iter.index) >= n) 1993 break; 
1994 } 1995 1996 return iter.index; 1997 } 1998 1999 /** 2000 * Find absolute index of first value in dataset that is less than given number 2001 * @param a 2002 * @param n 2003 * @return absolute index (if greater than a.getSize() then no value found) 2004 */ 2005 public static int findIndexLessThan(final Dataset a, final double n) { 2006 IndexIterator iter = a.getIterator(); 2007 while (iter.hasNext()) { 2008 if (a.getElementDoubleAbs(iter.index) < n) 2009 break; 2010 } 2011 2012 return iter.index; 2013 } 2014 2015 /** 2016 * Find absolute index of first value in dataset that is less than or equal to given number 2017 * @param a 2018 * @param n 2019 * @return absolute index (if greater than a.getSize() then no value found) 2020 */ 2021 public static int findIndexLessThanOrEqualTo(final Dataset a, final double n) { 2022 IndexIterator iter = a.getIterator(); 2023 while (iter.hasNext()) { 2024 if (a.getElementDoubleAbs(iter.index) <= n) 2025 break; 2026 } 2027 2028 return iter.index; 2029 } 2030 2031 /** 2032 * Find first occurrences in one dataset of values given in another sorted dataset 2033 * @param a 2034 * @param values sorted 1D dataset of values to find 2035 * @return absolute indexes of those first occurrences (-1 is used to indicate value not found) 2036 */ 2037 public static IntegerDataset findFirstOccurrences(final Dataset a, final Dataset values) { 2038 if (values.getRank() != 1) { 2039 throw new IllegalArgumentException("Values dataset must be 1D"); 2040 } 2041 IntegerDataset indexes = new IntegerDataset(values.getSize()); 2042 indexes.fill(-1); 2043 2044 IndexIterator it = a.getIterator(); 2045 final int n = values.getSize(); 2046 if (values.getDType() == Dataset.INT64) { 2047 while (it.hasNext()) { 2048 long x = a.getElementLongAbs(it.index); 2049 2050 int l = 0; // binary search to find value in sorted dataset 2051 long vl = values.getLong(l); 2052 if (x <= vl) { 2053 if (x == vl && indexes.getAbs(l) < 0) 2054 indexes.setAbs(l, it.index); 2055 
continue; 2056 } 2057 int h = n - 1; 2058 long vh = values.getLong(h); 2059 if (x >= vh) { 2060 if (x == vh && indexes.getAbs(h) < 0) 2061 indexes.setAbs(h, it.index); 2062 continue; 2063 } 2064 while (h - l > 1) { 2065 int m = (l + h) / 2; 2066 long vm = values.getLong(m); 2067 if (x < vm) { 2068 h = m; 2069 } else if (x > vm) { 2070 l = m; 2071 } else { 2072 if (indexes.getAbs(m) < 0) 2073 indexes.setAbs(m, it.index); 2074 break; 2075 } 2076 } 2077 } 2078 } else { 2079 while (it.hasNext()) { 2080 double x = a.getElementDoubleAbs(it.index); 2081 2082 int l = 0; // binary search to find value in sorted dataset 2083 double vl = values.getDouble(l); 2084 if (x <= vl) { 2085 if (x == vl && indexes.getAbs(l) < 0) 2086 indexes.setAbs(l, it.index); 2087 continue; 2088 } 2089 int h = n - 1; 2090 double vh = values.getDouble(h); 2091 if (x >= vh) { 2092 if (x == vh && indexes.getAbs(h) < 0) 2093 indexes.setAbs(h, it.index); 2094 continue; 2095 } 2096 while (h - l > 1) { 2097 int m = (l + h) / 2; 2098 double vm = values.getDouble(m); 2099 if (x < vm) { 2100 h = m; 2101 } else if (x > vm) { 2102 l = m; 2103 } else { 2104 if (indexes.getAbs(m) < 0) 2105 indexes.setAbs(m, it.index); 2106 break; 2107 } 2108 } 2109 } 2110 } 2111 return indexes; 2112 } 2113 2114 /** 2115 * Find indexes in sorted dataset of values for each value in other dataset 2116 * @param a 2117 * @param values sorted 1D dataset of values to find 2118 * @return absolute indexes of values (-1 is used to indicate value not found) 2119 */ 2120 public static IntegerDataset findIndexesForValues(final Dataset a, final Dataset values) { 2121 if (values.getRank() != 1) { 2122 throw new IllegalArgumentException("Values dataset must be 1D"); 2123 } 2124 IntegerDataset indexes = new IntegerDataset(a.getSize()); 2125 indexes.fill(-1); 2126 2127 IndexIterator it = a.getIterator(); 2128 int i = -1; 2129 final int n = values.getSize(); 2130 if (values.getDType() == Dataset.INT64) { 2131 while (it.hasNext()) { 2132 i++; 2133 
long x = a.getElementLongAbs(it.index); 2134 2135 int l = 0; // binary search to find value in sorted dataset 2136 long vl = values.getLong(l); 2137 if (x <= vl) { 2138 if (x == vl) 2139 indexes.setAbs(i, l); 2140 continue; 2141 } 2142 int h = n - 1; 2143 long vh = values.getLong(h); 2144 if (x >= vh) { 2145 if (x == vh) 2146 indexes.setAbs(i, h); 2147 continue; 2148 } 2149 while (h - l > 1) { 2150 int m = (l + h) / 2; 2151 long vm = values.getLong(m); 2152 if (x < vm) { 2153 h = m; 2154 } else if (x > vm) { 2155 l = m; 2156 } else { 2157 indexes.setAbs(i, m); 2158 break; 2159 } 2160 } 2161 } 2162 } else { 2163 while (it.hasNext()) { 2164 i++; 2165 double x = a.getElementDoubleAbs(it.index); 2166 2167 int l = 0; // binary search to find value in sorted dataset 2168 double vl = values.getDouble(l); 2169 if (x <= vl) { 2170 if (x == vl) 2171 indexes.setAbs(i, l); 2172 continue; 2173 } 2174 int h = n - 1; 2175 double vh = values.getDouble(h); 2176 if (x >= vh) { 2177 if (x == vh) 2178 indexes.setAbs(i, h); 2179 continue; 2180 } 2181 while (h - l > 1) { 2182 int m = (l + h) / 2; 2183 double vm = values.getDouble(m); 2184 if (x < vm) { 2185 h = m; 2186 } else if (x > vm) { 2187 l = m; 2188 } else { 2189 indexes.setAbs(i, m); 2190 break; 2191 } 2192 } 2193 } 2194 } 2195 2196 return indexes; 2197 } 2198 2199 /** 2200 * Roll items over given axis by given amount 2201 * @param a 2202 * @param shift 2203 * @param axis if null, then roll flattened dataset 2204 * @return rolled dataset 2205 */ 2206 public static <T extends Dataset> T roll(final T a, final int shift, final Integer axis) { 2207 Dataset r = DatasetFactory.zeros(a); 2208 int is = a.getElementsPerItem(); 2209 if (axis == null) { 2210 IndexIterator it = a.getIterator(); 2211 int s = r.getSize(); 2212 int i = shift % s; 2213 if (i < 0) 2214 i += s; 2215 while (it.hasNext()) { 2216 r.setObjectAbs(i, a.getObjectAbs(it.index)); 2217 i += is; 2218 if (i >= s) { 2219 i %= s; 2220 } 2221 } 2222 } else { 2223 
PositionIterator pi = a.getPositionIterator(axis); 2224 int s = a.getShapeRef()[axis]; 2225 @SuppressWarnings("deprecation") 2226 Dataset u = DatasetFactory.zeros(is, new int[] {s}, a.getDType()); 2227 Dataset v = DatasetFactory.zeros(u); 2228 int[] pos = pi.getPos(); 2229 boolean[] hit = pi.getOmit(); 2230 while (pi.hasNext()) { 2231 a.copyItemsFromAxes(pos, hit, u); 2232 int i = shift % s; 2233 if (i < 0) 2234 i += s; 2235 for (int j = 0; j < s; j++) { 2236 v.setObjectAbs(i, u.getObjectAbs(j*is)); 2237 i += is; 2238 if (i >= s) { 2239 i %= s; 2240 } 2241 } 2242 r.setItemsOnAxes(pos, hit, v.getBuffer()); 2243 } 2244 } 2245 return (T) r; 2246 } 2247 2248 /** 2249 * Roll the specified axis backwards until it lies in given position 2250 * @param a 2251 * @param axis The rolled axis (index in shape array). Other axes are left unchanged in relative positions 2252 * @param start The position with it right of the destination of the rolled axis 2253 * @return dataset with rolled axis 2254 */ 2255 public static <T extends Dataset> T rollAxis(final T a, int axis, int start) { 2256 int r = a.getRank(); 2257 if (axis < 0) 2258 axis += r; 2259 if (axis < 0 || axis >= r) { 2260 throw new IllegalArgumentException("Axis is out of range: it should be >= 0 and < " + r); 2261 } 2262 if (start < 0) 2263 start += r; 2264 if (start < 0 || start > r) { 2265 throw new IllegalArgumentException("Start is out of range: it should be >= 0 and <= " + r); 2266 } 2267 if (axis < start) 2268 start--; 2269 2270 if (axis == start) 2271 return a; 2272 2273 ArrayList<Integer> axes = new ArrayList<Integer>(); 2274 for (int i = 0; i < r; i++) { 2275 if (i != axis) { 2276 axes.add(i); 2277 } 2278 } 2279 axes.add(start, axis); 2280 int[] aa = new int[r]; 2281 for (int i = 0; i < r; i++) { 2282 aa[i] = axes.get(i); 2283 } 2284 return (T) a.getTransposedView(aa); 2285 } 2286 2287 private static SliceND createFlippedSlice(final Dataset a, int axis) { 2288 int[] shape = a.getShapeRef(); 2289 SliceND slice = 
new SliceND(shape); 2290 slice.flip(axis); 2291 return slice; 2292 } 2293 2294 /** 2295 * Flip items in left/right direction, column-wise, or along second axis 2296 * @param a dataset must be at least 2D 2297 * @return view of flipped dataset 2298 */ 2299 public static <T extends Dataset> T flipLeftRight(final T a) { 2300 if (a.getRank() < 2) { 2301 throw new IllegalArgumentException("Dataset must be at least 2D"); 2302 } 2303 return (T) a.getSliceView(createFlippedSlice(a, 1)); 2304 } 2305 2306 /** 2307 * Flip items in up/down direction, row-wise, or along first axis 2308 * @param a dataset 2309 * @return view of flipped dataset 2310 */ 2311 public static <T extends Dataset> T flipUpDown(final T a) { 2312 return (T) a.getSliceView(createFlippedSlice(a, 0)); 2313 } 2314 2315 /** 2316 * Rotate items in first two dimension by 90 degrees anti-clockwise 2317 * @param a dataset must be at least 2D 2318 * @return view of flipped dataset 2319 */ 2320 public static <T extends Dataset> T rotate90(final T a) { 2321 return rotate90(a, 1); 2322 } 2323 2324 /** 2325 * Rotate items in first two dimension by 90 degrees anti-clockwise 2326 * @param a dataset must be at least 2D 2327 * @param k number of 90-degree rotations 2328 * @return view of flipped dataset 2329 */ 2330 public static <T extends Dataset> T rotate90(final T a, int k) { 2331 k = k % 4; 2332 while (k < 0) { 2333 k += 4; 2334 } 2335 int r = a.getRank(); 2336 if (r < 2) { 2337 throw new IllegalArgumentException("Dataset must be at least 2D"); 2338 } 2339 switch (k) { 2340 case 1: case 3: 2341 int[] axes = new int[r]; 2342 axes[0] = 1; 2343 axes[1] = 0; 2344 for (int i = 2; i < r; i++) { 2345 axes[i] = i; 2346 } 2347 Dataset t = a.getTransposedView(axes); 2348 return (T) t.getSliceView(createFlippedSlice(t, k == 1 ? 
0 : 1)); 2349 case 2: 2350 SliceND s = createFlippedSlice(a, 0); 2351 s.flip(1); 2352 return (T) a.getSliceView(s); 2353 default: 2354 case 0: 2355 return a; 2356 } 2357 } 2358 2359 /** 2360 * Select content according where condition is true. All inputs are broadcasted to a maximum shape 2361 * @param condition boolean dataset 2362 * @param x 2363 * @param y 2364 * @return dataset where content is x or y depending on whether condition is true or otherwise 2365 */ 2366 public static Dataset select(BooleanDataset condition, Object x, Object y) { 2367 Object[] all = new Object[] {condition, x, y}; 2368 Dataset[] dAll = BroadcastUtils.convertAndBroadcast(all); 2369 condition = (BooleanDataset) dAll[0]; 2370 Dataset dx = dAll[1]; 2371 Dataset dy = dAll[2]; 2372 int dt = DTypeUtils.getBestDType(dx.getDType(),dy.getDType()); 2373 int ds = Math.max(dx.getElementsPerItem(), dy.getElementsPerItem()); 2374 2375 @SuppressWarnings("deprecation") 2376 Dataset r = DatasetFactory.zeros(ds, condition.getShapeRef(), dt); 2377 IndexIterator iter = condition.getIterator(true); 2378 final int[] pos = iter.getPos(); 2379 int i = 0; 2380 while (iter.hasNext()) { 2381 r.setObjectAbs(i++, condition.getElementBooleanAbs(iter.index) ? dx.getObject(pos) : dy.getObject(pos)); 2382 } 2383 return r; 2384 } 2385 2386 /** 2387 * Select content from choices where condition is true, otherwise use default. 
All inputs are broadcasted to a maximum shape 2388 * @param conditions array of boolean datasets 2389 * @param choices array of datasets or objects 2390 * @param def default value (can be a dataset) 2391 * @return dataset 2392 */ 2393 public static Dataset select(BooleanDataset[] conditions, Object[] choices, Object def) { 2394 final int n = conditions.length; 2395 if (choices.length != n) { 2396 throw new IllegalArgumentException("Choices list is not same length as conditions list"); 2397 } 2398 Object[] all = new Object[2*n]; 2399 System.arraycopy(conditions, 0, all, 0, n); 2400 System.arraycopy(choices, 0, all, n, n); 2401 Dataset[] dAll = BroadcastUtils.convertAndBroadcast(all); 2402 conditions = new BooleanDataset[n]; 2403 Dataset[] dChoices = new Dataset[n]; 2404 System.arraycopy(dAll, 0, conditions, 0, n); 2405 System.arraycopy(dAll, n, dChoices, 0, n); 2406 int dt = -1; 2407 int ds = -1; 2408 for (int i = 0; i < n; i++) { 2409 Dataset a = dChoices[i]; 2410 int t = a.getDType(); 2411 if (t > dt) 2412 dt = t; 2413 int s = a.getElementsPerItem(); 2414 if (s > ds) 2415 ds = s; 2416 } 2417 if (dt < 0 || ds < 1) { 2418 throw new IllegalArgumentException("Dataset types of choices are invalid"); 2419 } 2420 2421 @SuppressWarnings("deprecation") 2422 Dataset r = DatasetFactory.zeros(ds, conditions[0].getShapeRef(), dt); 2423 Dataset d = DatasetFactory.createFromObject(def).getBroadcastView(r.getShapeRef()); 2424 PositionIterator iter = new PositionIterator(r.getShapeRef()); 2425 final int[] pos = iter.getPos(); 2426 int i = 0; 2427 while (iter.hasNext()) { 2428 int j = 0; 2429 for (; j < n; j++) { 2430 if (conditions[j].get(pos)) { 2431 r.setObjectAbs(i++, dChoices[j].getObject(pos)); 2432 break; 2433 } 2434 } 2435 if (j == n) { 2436 r.setObjectAbs(i++, d.getObject(pos)); 2437 } 2438 } 2439 return r; 2440 } 2441 2442 /** 2443 * Choose content from choices where condition is true, otherwise use default. 
All inputs are broadcasted to a maximum shape 2444 * @param index integer dataset (ideally, items should be in [0, n) range, if there are n choices) 2445 * @param choices array of datasets or objects 2446 * @param throwAIOOBE if true, throw array index out of bound exception 2447 * @param clip true to clip else wrap indices out of bounds; only used when throwAOOBE is false 2448 * @return dataset 2449 */ 2450 public static Dataset choose(IntegerDataset index, Object[] choices, boolean throwAIOOBE, boolean clip) { 2451 final int n = choices.length; 2452 Object[] all = new Object[n + 1]; 2453 System.arraycopy(choices, 0, all, 0, n); 2454 all[n] = index; 2455 Dataset[] dChoices = BroadcastUtils.convertAndBroadcast(all); 2456 int dt = -1; 2457 int ds = -1; 2458 int mr = -1; 2459 for (int i = 0; i < n; i++) { 2460 Dataset a = dChoices[i]; 2461 int r = a.getRank(); 2462 if (r > mr) 2463 mr = r; 2464 int t = a.getDType(); 2465 if (t > dt) 2466 dt = t; 2467 int s = a.getElementsPerItem(); 2468 if (s > ds) 2469 ds = s; 2470 } 2471 if (dt < 0 || ds < 1) { 2472 throw new IllegalArgumentException("Dataset types of choices are invalid"); 2473 } 2474 index = (IntegerDataset) dChoices[n]; 2475 dChoices[n] = null; 2476 2477 @SuppressWarnings("deprecation") 2478 Dataset r = DatasetFactory.zeros(ds, index.getShape(), dt); 2479 IndexIterator iter = index.getIterator(true); 2480 final int[] pos = iter.getPos(); 2481 int i = 0; 2482 while (iter.hasNext()) { 2483 int j = index.getAbs(iter.index); 2484 if (j < 0) { 2485 if (throwAIOOBE) 2486 throw new ArrayIndexOutOfBoundsException(j); 2487 if (clip) { 2488 j = 0; 2489 } else { 2490 j %= n; 2491 j += n; // as remainder still negative 2492 } 2493 } 2494 if (j >= n) { 2495 if (throwAIOOBE) 2496 throw new ArrayIndexOutOfBoundsException(j); 2497 if (clip) { 2498 j = n - 1; 2499 } else { 2500 j %= n; 2501 } 2502 } 2503 Dataset c = dChoices[j]; 2504 r.setObjectAbs(i++, c.getObject(pos)); 2505 } 2506 return r; 2507 } 2508 2509 /** 2510 * 
Calculate positions in given shape from a dataset of 1-D indexes 2511 * @param indices 2512 * @param shape 2513 * @return list of positions as integer datasets 2514 */ 2515 public static List<IntegerDataset> calcPositionsFromIndexes(Dataset indices, int[] shape) { 2516 int rank = shape.length; 2517 List<IntegerDataset> posns = new ArrayList<IntegerDataset>(); 2518 int[] iShape = indices.getShapeRef(); 2519 for (int i = 0; i < rank; i++) { 2520 posns.add(new IntegerDataset(iShape)); 2521 } 2522 IndexIterator it = indices.getIterator(true); 2523 int[] pos = it.getPos(); 2524 while (it.hasNext()) { 2525 int n = indices.getInt(pos); 2526 int[] p = ShapeUtils.getNDPositionFromShape(n, shape); 2527 for (int i = 0; i < rank; i++) { 2528 posns.get(i).setItem(p[i], pos); 2529 } 2530 } 2531 return posns; 2532 } 2533 2534 2535 /** 2536 * Calculate indexes in given shape from datasets of position 2537 * @param positions as a list of datasets where each holds the position in a dimension 2538 * @param shape 2539 * @param mode either null, zero-length, unit length or length of rank of shape where 2540 * 0 = raise exception, 1 = wrap, 2 = clip 2541 * @return indexes as an integer dataset 2542 */ 2543 public static IntegerDataset calcIndexesFromPositions(List<? extends Dataset> positions, int[] shape, int... 
mode) { 2544 int rank = shape.length; 2545 if (positions.size() != rank) { 2546 throw new IllegalArgumentException("Number of position datasets must be equal to rank of shape"); 2547 } 2548 2549 if (mode == null || mode.length == 0) { 2550 mode = new int[rank]; 2551 } else if (mode.length == 1) { 2552 int m = mode[0]; 2553 mode = new int[rank]; 2554 Arrays.fill(mode, m); 2555 } else if (mode.length != rank) { 2556 throw new IllegalArgumentException("Mode length greater than one must match rank of shape"); 2557 } 2558 for (int i = 0; i < rank; i++) { 2559 int m = mode[i]; 2560 if (m < 0 || m > 2) { 2561 throw new IllegalArgumentException("Unknown mode value - it must be 0, 1, or 2"); 2562 } 2563 } 2564 2565 Dataset p = positions.get(0); 2566 IntegerDataset indexes = new IntegerDataset(p.getShapeRef()); 2567 IndexIterator it = p.getIterator(true); 2568 int[] iPos = it.getPos(); 2569 int[] tPos = new int[rank]; 2570 while (it.hasNext()) { 2571 for (int i = 0; i < rank; i++) { 2572 p = positions.get(i); 2573 int j = p.getInt(iPos); 2574 int d = shape[i]; 2575 if (mode[i] == 0) { 2576 if (j < 0 || j >= d) { 2577 throw new ArrayIndexOutOfBoundsException("Position value exceeds dimension in shape"); 2578 } 2579 } else if (mode[i] == 1) { 2580 while (j < 0) 2581 j += d; 2582 while (j >= d) 2583 j -= d; 2584 } else { 2585 if (j < 0) 2586 j = 0; 2587 if (j >= d) 2588 j = d - 1; 2589 } 2590 tPos[i] = j; 2591 } 2592 indexes.set(ShapeUtils.getFlat1DIndex(shape, tPos), iPos); 2593 } 2594 2595 return indexes; 2596 } 2597 2598 /** 2599 * Serialize dataset by flattening it. Discards metadata 2600 * @param data 2601 * @return some java array 2602 */ 2603 public static Serializable serializeDataset(final IDataset data) { 2604 Dataset d = convertToDataset(data.getSliceView()); 2605 d.clearMetadata(null); 2606 return d.flatten().getBuffer(); 2607 } 2608 2609 /** 2610 * Extract values where condition is non-zero. 
This is similar to Dataset#getByBoolean but supports broadcasting 2611 * @param data 2612 * @param condition should be broadcastable to data 2613 * @return 1-D dataset of values 2614 */ 2615 @SuppressWarnings("deprecation") 2616 public static Dataset extract(final IDataset data, final IDataset condition) { 2617 Dataset a = convertToDataset(data.getSliceView()); 2618 Dataset b = cast(condition.getSliceView(), Dataset.BOOL); 2619 2620 try { 2621 return a.getByBoolean(b); 2622 } catch (IllegalArgumentException e) { 2623 final int length = ((Number) b.sum()).intValue(); 2624 2625 BroadcastPairIterator it = new BroadcastPairIterator(a, b, null, false); 2626 int size = ShapeUtils.calcSize(it.getShape()); 2627 Dataset c; 2628 if (length < size) { 2629 int[] ashape = it.getFirstShape(); 2630 int[] bshape = it.getSecondShape(); 2631 int r = ashape.length; 2632 size = length; 2633 for (int i = 0; i < r; i++) { 2634 int s = ashape[i]; 2635 if (s > 1 && bshape[i] == 1) { 2636 size *= s; 2637 } 2638 } 2639 } 2640 c = DatasetFactory.zeros(new int[] {size}, a.getDType()); 2641 2642 int i = 0; 2643 if (it.isOutputDouble()) { 2644 while (it.hasNext()) { 2645 if (it.bLong != 0) { 2646 c.setObjectAbs(i++, it.aDouble); 2647 } 2648 } 2649 } else { 2650 while (it.hasNext()) { 2651 if (it.bLong != 0) { 2652 c.setObjectAbs(i++, it.aLong); 2653 } 2654 } 2655 } 2656 2657 return c; 2658 } 2659 } 2660}