Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Javadoc framework fixes #250

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public Constraint(Ops tf) {
*
* @param weights the weights
* @return the constrained weights
* @param <T> the data type for the weights and return value
*/
public abstract <T extends TNumber> Operand<T> call(Operand<T> weights);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,42 @@
import org.tensorflow.framework.data.impl.TakeDataset;
import org.tensorflow.framework.data.impl.TensorSliceDataset;
import org.tensorflow.framework.data.impl.TextLineDataset;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.family.TType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import org.tensorflow.types.family.TType;

/**
* Represents a potentially large list of independent elements (samples), and allows iteration and
* transformations to be performed across these elements.
*/
public abstract class Dataset implements Iterable<List<Operand<?>>> {
protected Ops tf;
private Operand<?> variant;
private List<Class<? extends TType>> outputTypes;
private List<Shape> outputShapes;
private final Operand<?> variant;
private final List<Class<? extends TType>> outputTypes;
private final List<Shape> outputShapes;

/**
* Creates a Dataset
*
* @param tf The TensorFlow Ops
* @param variant the Operand that represents the dataset.
* @param outputTypes A list of classes corresponding to the tensor type of each component of a
* dataset element.
* @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of
* a dataset element.
*/
public Dataset(
Ops tf, Operand<?> variant, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
Ops tf,
Operand<?> variant,
List<Class<? extends TType>> outputTypes,
List<Shape> outputShapes) {
if (tf == null) {
throw new IllegalArgumentException("Ops accessor cannot be null.");
}
Expand All @@ -61,13 +74,65 @@ public Dataset(
this.outputShapes = outputShapes;
}

/**
 * Copy constructor: builds a Dataset that reuses the state of an existing one.
 *
 * <p>The Ops accessor, the variant operand, the output types and the output shapes are all taken
 * by reference from {@code other}; none of them is copied.
 *
 * @param other the dataset whose fields are shared with the new instance
 */
protected Dataset(Dataset other) {
  outputShapes = other.outputShapes;
  outputTypes = other.outputTypes;
  variant = other.variant;
  tf = other.tf;
}

/**
 * Builds an in-memory {@code Dataset} by slicing the supplied tensors.
 *
 * <p>Every element of the returned dataset is a {@code List<Operand<?>>} holding slices (e.g.
 * batches) of the provided tensors.
 *
 * @param tf Ops accessor
 * @param tensors the components of the dataset (e.g. features, labels), each given as an
 *     {@code Operand<?>}
 * @param outputTypes the tensor type class of each component of the dataset
 * @return a new {@code Dataset}, implemented by {@code TensorSliceDataset}
 */
public static Dataset fromTensorSlices(
    Ops tf, List<Operand<?>> tensors, List<Class<? extends TType>> outputTypes) {
  Dataset sliced = new TensorSliceDataset(tf, tensors, outputTypes);
  return sliced;
}

/**
 * Builds a Dataset made of the records stored in a TFRecord file.
 *
 * <p>The file name, compression type and buffer size are each wrapped with {@code tf.constant}
 * before being handed to {@code TFRecordDataset}.
 *
 * @param tf the TensorFlow Ops
 * @param filename the name of the file containing the TFRecords
 * @param compressionType the compression type, either "" (no compression), "ZLIB", or "GZIP"
 * @param bufferSize the number of bytes in the read buffer
 * @return a Dataset comprising records from a TFRecord file
 */
public static Dataset tfRecordDataset(
    Ops tf, String filename, String compressionType, long bufferSize) {
  Dataset records =
      new TFRecordDataset(
          tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize));
  return records;
}

/**
 * Builds a Dataset made of the lines of a text file.
 *
 * <p>The file name, compression type and buffer size are each wrapped with {@code tf.constant}
 * before being handed to {@code TextLineDataset}.
 *
 * @param tf the TensorFlow Ops
 * @param filename the name of the file containing the text lines
 * @param compressionType the compression type, either "" (no compression), "ZLIB", or "GZIP"
 * @param bufferSize the number of bytes in the read buffer
 * @return a Dataset comprising lines from a text file
 */
public static Dataset textLineDataset(
    Ops tf, String filename, String compressionType, long bufferSize) {
  Dataset lines =
      new TextLineDataset(
          tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize));
  return lines;
}

/**
* Groups elements of this dataset into batches.
*
Expand Down Expand Up @@ -127,11 +192,12 @@ public final Dataset take(long count) {
* Returns a new Dataset which maps a function across all elements from this dataset, on a single
* component of each element.
*
* <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components: (features,
* labels).
* <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components:
* (features, labels).
*
* <p>Calling {@code dataset.mapOneComponent(0, features -> tf.math.mul(features, tf.constant(2)))} will
* map the function over the `features` component of each element, multiplying each by 2.
* <p>Calling {@code dataset.mapOneComponent(0, features -> tf.math.mul(features,
* tf.constant(2)))} will map the function over the `features` component of each element,
* multiplying each by 2.
*
* @param index The index of the component to transform.
* @param mapper The function to apply to the target component.
Expand All @@ -150,8 +216,8 @@ public Dataset mapOneComponent(int index, Function<Operand<?>, Operand<?>> mappe
* Returns a new Dataset which maps a function across all elements from this dataset, on all
* components of each element.
*
* <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components: (features,
* labels).
* <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components:
* (features, labels).
*
* <p>Calling {@code dataset.mapAllComponents(component -> tf.math.mul(component,
* tf.constant(2)))} will map the function over both the `features` and `labels` components of
Expand All @@ -172,8 +238,8 @@ public Dataset mapAllComponents(Function<Operand<?>, Operand<?>> mapper) {
/**
* Returns a new Dataset which maps a function over all elements returned by this dataset.
*
* <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components: (features,
* labels).
* <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components:
* (features, labels).
*
* <p>Calling
*
Expand Down Expand Up @@ -254,53 +320,42 @@ public DatasetIterator makeOneShotIterator() {
}

/**
* Creates an in-memory `Dataset` whose elements are slices of the given tensors. Each element of
* this dataset will be a {@code List<Operand<?>>}, representing slices (e.g. batches) of the
* provided tensors.
*
* @param tf Ops Accessor
* @param tensors A list of {@code Operand<?>} representing components of this dataset (e.g.
* features, labels)
* @param outputTypes A list of tensor type classes representing the data type of each component of
* this dataset.
* @return A new `Dataset`
*/
public static Dataset fromTensorSlices(
Ops tf, List<Operand<?>> tensors, List<Class<? extends TType>> outputTypes) {
return new TensorSliceDataset(tf, tensors, outputTypes);
}

/**
* Creates a Dataset backed by a `TFRecordDataset`.
*
* <p>The file name, compression type and buffer size are each converted with `tf.constant` before
* being passed on.
*
* @param tf the TensorFlow Ops
* @param filename the name of the file to read
* @param compressionType the compression type
* @param bufferSize the size of the read buffer
* @return a new `TFRecordDataset` built from the given arguments
*/
public static Dataset tfRecordDataset(
Ops tf, String filename, String compressionType, long bufferSize) {
return new TFRecordDataset(
tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize));
}

/**
* Creates a Dataset backed by a `TextLineDataset`.
*
* <p>The file name, compression type and buffer size are each converted with `tf.constant` before
* being passed on.
*
* @param tf the TensorFlow Ops
* @param filename the name of the file to read
* @param compressionType the compression type
* @param bufferSize the size of the read buffer
* @return a new `TextLineDataset` built from the given arguments
*/
public static Dataset textLineDataset(
Ops tf, String filename, String compressionType, long bufferSize) {
return new TextLineDataset(
tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize));
}

/**
 * Returns the variant operand that represents this dataset.
 *
 * @return the {@code Operand<?>} stored in this dataset's {@code variant} field
 */
public Operand<?> getVariant() {
  return this.variant;
}

/**
 * Returns the output types of this dataset, one entry per component.
 *
 * @return the list of tensor type classes held by this dataset (the stored list itself, not a
 *     copy)
 */
public List<Class<? extends TType>> getOutputTypes() {
  return outputTypes;
}

/**
 * Returns the output shapes of this dataset, one entry per component.
 *
 * @return the list of shapes held by this dataset (the stored list itself, not a copy)
 */
public List<Shape> getOutputShapes() {
  return outputShapes;
}

/**
 * Returns the TensorFlow Ops accessor this dataset was created with.
 *
 * @return the {@code Ops} instance stored in this dataset
 */
public Ops getOpsInstance() {
  return tf;
}

/** {@inheritDoc} */
@Override
public String toString() {
return "Dataset{"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.family.TType;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.types.family.TType;

/**
* Represents the state of an iteration through a tf.data Dataset. DatasetIterator is not a
Expand Down Expand Up @@ -102,21 +102,21 @@ public class DatasetIterator implements Iterable<List<Operand<?>>> {
public static final String EMPTY_SHARED_NAME = "";

protected Ops tf;

private Operand<?> iteratorResource;
private Op initializer;

protected List<Class<? extends TType>> outputTypes;
protected List<Shape> outputShapes;
private final Operand<?> iteratorResource;
private Op initializer;

/**
* Creates a DatasetIterator
*
* @param tf Ops accessor corresponding to the same `ExecutionEnvironment` as the
* `iteratorResource`.
* @param iteratorResource An Operand representing the iterator (e.g. constructed from
* `tf.data.iterator` or `tf.data.anonymousIterator`)
* @param initializer An `Op` that should be run to initialize this iterator
* @param outputTypes A list of classes corresponding to the tensor type of each component of
* a dataset element.
* @param outputTypes A list of classes corresponding to the tensor type of each component of a
* dataset element.
* @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of
* a dataset element.
*/
Expand All @@ -134,6 +134,18 @@ public DatasetIterator(
this.outputShapes = outputShapes;
}

/**
* Creates a DatasetIterator
*
* @param tf Ops accessor corresponding to the same `ExecutionEnvironment` as the
* `iteratorResource`.
* @param iteratorResource An Operand representing the iterator (e.g. constructed from
* `tf.data.iterator` or `tf.data.anonymousIterator`)
* @param outputTypes A list of classes corresponding to the tensor type of each component of a
* dataset element.
* @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of
* a dataset element.
*/
public DatasetIterator(
Ops tf,
Operand<?> iteratorResource,
Expand All @@ -145,6 +157,11 @@ public DatasetIterator(
this.outputShapes = outputShapes;
}

/**
* Creates a DatasetIterator from another DatasetIterator
*
* @param other the other DatasetIterator
*/
protected DatasetIterator(DatasetIterator other) {
this.tf = other.tf;
this.iteratorResource = other.iteratorResource;
Expand All @@ -153,6 +170,26 @@ protected DatasetIterator(DatasetIterator other) {
this.outputShapes = other.outputShapes;
}

/**
 * Builds a new iterator for the "structure" described by {@code outputTypes} and
 * {@code outputShapes}.
 *
 * <p>When the execution environment behind {@code tf} is a {@code Graph}, the iterator resource
 * is created through {@code tf.data.iterator} with {@code EMPTY_SHARED_NAME}; otherwise the
 * handle of an anonymous iterator is used.
 *
 * @param tf Ops accessor
 * @param outputTypes A list of classes representing the tensor type of each component of a
 *     dataset element.
 * @param outputShapes A list of Shape objects representing the shape of each component of a
 *     dataset element.
 * @return A new DatasetIterator wrapping the created iterator resource
 */
public static DatasetIterator fromStructure(
    Ops tf, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
  final boolean graphMode = tf.scope().env() instanceof Graph;

  Operand<?> resource;
  if (graphMode) {
    resource = tf.data.iterator(EMPTY_SHARED_NAME, "", outputTypes, outputShapes);
  } else {
    resource = tf.data.anonymousIterator(outputTypes, outputShapes).handle();
  }

  return new DatasetIterator(tf, resource, outputTypes, outputShapes);
}

/**
* Returns a list of {@code Operand<?>} representing the components of the next dataset element.
*
Expand Down Expand Up @@ -226,37 +263,33 @@ public Op makeInitializer(Dataset dataset) {
}

/**
* Creates a new iterator from a "structure" defined by `outputShapes` and `outputTypes`.
*
* @param tf Ops accessor
* @param outputTypes A list of classes representing the tensor type of each component of a
* dataset element.
* @param outputShapes A list of Shape objects representing the shape of each component of a
* dataset element.
* @return A new DatasetIterator
*/
public static DatasetIterator fromStructure(
Ops tf, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
// A Graph environment uses a named iterator op; any other environment uses an anonymous one.
Operand<?> iteratorResource =
tf.scope().env() instanceof Graph
? tf.data.iterator(EMPTY_SHARED_NAME, "", outputTypes, outputShapes)
: tf.data.anonymousIterator(outputTypes, outputShapes).handle();

return new DatasetIterator(tf, iteratorResource, outputTypes, outputShapes);
}

/**
 * Returns the operand that represents the underlying iterator resource.
 *
 * @return the {@code Operand<?>} stored in this iterator's {@code iteratorResource} field
 */
public Operand<?> getIteratorResource() {
  return this.iteratorResource;
}

/**
 * Returns the initializer op currently stored in this iterator.
 *
 * @return the value of the {@code initializer} field
 */
public Op getInitializer() {
  return this.initializer;
}

/**
 * Returns the TensorFlow Ops accessor held by this iterator.
 *
 * @return the {@code Ops} instance stored in this iterator
 */
public Ops getOpsInstance() {
  return this.tf;
}

/** {@inheritDoc} */
@Override
public Iterator<List<Operand<?>>> iterator() {

Expand Down
Loading