DNC Memory Modules

This is a conversion of https://github.com/deepmind/dnc to TensorFlow v2.x.

DNC Module

class ftw.tf.networks.dnc.dnc.DNC(access_config, controller_config, output_size, clip_value=None, name='dnc')

Bases: sonnet.src.recurrent.RNNCore

DNC core module.

Contains controller and memory access module.

__init__(access_config, controller_config, output_size, clip_value=None, name='dnc')

Initializes the DNC core.

Args:

access_config: dictionary of access module configurations.
controller_config: dictionary of controller (LSTM) module configurations.
output_size: output dimension size of core.
clip_value: if specified, clips controller and core output values to the range [-clip_value, clip_value].
name: module name (default 'dnc').
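The effect of clip_value can be illustrated with a minimal sketch in plain Python. It is not the module's TensorFlow code; clip_if_enabled is a hypothetical helper name, and treating None as "no clipping" is an assumption based on the default shown in the signature.

```python
def clip_if_enabled(values, clip_value=None):
    """Clamp each value to [-clip_value, clip_value]; pass through when clip_value is None."""
    if clip_value is None:
        return list(values)
    return [max(-clip_value, min(clip_value, v)) for v in values]


# With clipping enabled, out-of-range values saturate at the bounds.
print(clip_if_enabled([-30.0, 0.5, 42.0], clip_value=20.0))
# With the default (None), values are returned unchanged.
print(clip_if_enabled([-30.0, 0.5, 42.0]))
```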

initial_state(batch_size: int, **unused_kwargs)

Constructs an initial state for this core.

Args:
batch_size: An int or an integral scalar tensor representing batch size.
**unused_kwargs: Optional keyword arguments (ignored).
Returns:
Arbitrarily nested initial state for this core.

output_size

DNC State NamedTuple

class ftw.tf.networks.dnc.dnc.DNCState(access_output, access_state, controller_state)

Bases: tuple

access_output

Alias for field number 0

access_state

Alias for field number 1

controller_state

Alias for field number 2
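Because DNCState is a tuple subclass, its fields can be read by name or by the field numbers listed above. The sketch below uses a stand-in namedtuple with the same field order; the string values are placeholders, not real tensors.

```python
from collections import namedtuple

# Stand-in with the same field order as the documented DNCState.
DNCState = namedtuple(
    "DNCState", ("access_output", "access_state", "controller_state"))

state = DNCState(
    access_output="reads", access_state="memory", controller_state="lstm")

# Fields are reachable by name or by position.
assert state.access_output is state[0]
assert state.controller_state is state[2]

# Tuples are immutable; _replace returns an updated copy.
next_state = state._replace(access_output="new reads")
print(next_state.access_output, "|", state.access_output)
```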

Memory Access Module

class ftw.tf.networks.dnc.access.MemoryAccess(memory_size=128, word_size=20, num_reads=1, num_writes=1, name='memory_access')

Bases: sonnet.src.recurrent.RNNCore

Access module of the Differentiable Neural Computer.

This memory module supports multiple read and write heads. It makes use of:

  • addressing.TemporalLinkage to track the temporal ordering of writes in memory for each write head.
  • addressing.Freeness to keep track of memory usage, where usage increases when a memory location is written to, and decreases when the controller signals that a location it has read from can be freed.

Write-address selection is done by an interpolation between content-based lookup and using unused memory.

Read-address selection is done by an interpolation of content-based lookup and following the link graph in the forward or backwards read direction.
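The two interpolations can be sketched for a single head in plain Python. This follows the formulation in the DNC paper rather than this module's code; the gate names (allocation_gate, write_gate, read_mode) and both helper functions are assumptions introduced for illustration.

```python
def mix_write_weights(content_w, allocation_w, allocation_gate, write_gate):
    """Blend content lookup with allocation to unused memory, scaled by write_gate."""
    return [
        write_gate * (allocation_gate * a + (1.0 - allocation_gate) * c)
        for c, a in zip(content_w, allocation_w)
    ]


def mix_read_weights(backward_w, content_w, forward_w, read_mode):
    """Blend backward-link, content, and forward-link addresses.

    read_mode is a (backward, content, forward) triple assumed to sum to 1.
    """
    pi_b, pi_c, pi_f = read_mode
    return [
        pi_b * b + pi_c * c + pi_f * f
        for b, c, f in zip(backward_w, content_w, forward_w)
    ]


content = [0.7, 0.2, 0.1]      # content-based lookup result
allocation = [0.0, 0.0, 1.0]   # slot 2 is the least used

# allocation_gate=1 writes to unused memory; allocation_gate=0 writes by content.
print(mix_write_weights(content, allocation, allocation_gate=1.0, write_gate=1.0))
print(mix_write_weights(content, allocation, allocation_gate=0.0, write_gate=1.0))

# A read mode of (0, 0, 1) follows the link graph forwards only.
print(mix_read_weights([1.0, 0.0, 0.0], content, [0.0, 0.0, 1.0], (0.0, 0.0, 1.0)))
```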

__init__(memory_size=128, word_size=20, num_reads=1, num_writes=1, name='memory_access')

Creates a MemoryAccess module.

Args:
memory_size: The number of memory slots (N in the DNC paper).
word_size: The width of each memory slot (W in the DNC paper).
num_reads: The number of read heads (R in the DNC paper).
num_writes: The number of write heads (fixed at 1 in the paper).
name: The name of the module.

initial_state(batch_size: int, **unused_kwargs)

Constructs an initial state for this core.

Args:
batch_size: An int or an integral scalar tensor representing batch size.
**unused_kwargs: Optional keyword arguments (ignored).
Returns:
Arbitrarily nested initial state for this core.

output_size

Returns the output shape.

Memory Access State NamedTuple

class ftw.tf.networks.dnc.access.AccessState(memory, read_weights, write_weights, linkage, usage)

Bases: tuple

linkage

Alias for field number 3

memory

Alias for field number 0

read_weights

Alias for field number 1

usage

Alias for field number 4

write_weights

Alias for field number 2

Addressing Module

DNC addressing modules.

class ftw.tf.networks.dnc.addressing.CosineWeights(num_heads, word_size, strength_op=<function softplus>, name='cosine_weights')

Bases: sonnet.src.base.Module

Cosine-weighted attention.

Calculates the cosine similarity between a query and each word in memory, then applies a weighted softmax to return a sharp distribution.
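The similarity step alone can be sketched as follows for one query and a small memory, using plain Python lists. The eps term guarding against division by zero is an assumption; the module's exact numerical safeguards are not documented here.

```python
import math


def cosine_similarity(query, word, eps=1e-6):
    """Cosine of the angle between query and word (eps is an assumed safeguard)."""
    dot = sum(q * w for q, w in zip(query, word))
    norm_q = math.sqrt(sum(q * q for q in query))
    norm_w = math.sqrt(sum(w * w for w in word))
    return dot / (norm_q * norm_w + eps)


memory = [
    [1.0, 0.0],  # same direction as the query
    [0.0, 1.0],  # orthogonal to the query
    [1.0, 1.0],  # 45 degrees from the query
]
query = [2.0, 0.0]

# One similarity per memory word; magnitude of the query does not matter.
similarities = [cosine_similarity(query, word) for word in memory]
print([round(s, 3) for s in similarities])
```

These similarities are the activations that the weighted softmax then turns into a distribution over memory slots.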

class ftw.tf.networks.dnc.addressing.Freeness(memory_size, name='freeness')

Bases: sonnet.src.recurrent.RNNCore

Memory usage that is increased by writing and decreased by reading.

This module is a pseudo-RNNCore whose state is a tensor with values in the range [0, 1] indicating the usage of each of the memory_size memory slots.

The usage is:

  • Increased by writing, where usage is increased towards 1 at the write addresses.
  • Decreased by reading, where usage is decreased after reading from a location when free_gate is close to 1.

The function write_allocation_weights can be invoked to get free locations to write to for a number of write heads.
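The two usage rules above can be sketched for a single batch element in plain Python. The formulas follow the DNC paper's usage update, and the function names, the argument layout, and the way multiple heads are combined are assumptions, not this module's verified implementation.

```python
import math


def usage_after_write(usage, write_weights):
    """Push usage towards 1 wherever any write head wrote.

    usage: [memory_size]; write_weights: [num_writes][memory_size].
    """
    n = len(usage)
    # Probability-style union of the write heads at each slot.
    combined = [1.0 - math.prod(1.0 - w[i] for w in write_weights) for i in range(n)]
    return [u + (1.0 - u) * c for u, c in zip(usage, combined)]


def usage_after_read(usage, free_gate, read_weights):
    """Shrink usage at slots that were read with free_gate close to 1.

    free_gate: [num_reads]; read_weights: [num_reads][memory_size].
    """
    n = len(usage)
    keep = [
        math.prod(1.0 - g * r[i] for g, r in zip(free_gate, read_weights))
        for i in range(n)
    ]
    return [u * k for u, k in zip(usage, keep)]


usage = [0.0, 0.5, 1.0]
usage = usage_after_write(usage, [[1.0, 0.0, 0.0]])       # write to slot 0
print(usage)  # [1.0, 0.5, 1.0]
usage = usage_after_read(usage, [1.0], [[0.0, 0.0, 1.0]])  # read and free slot 2
print(usage)  # [1.0, 0.5, 0.0]
```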

initial_state(batch_size: int, **unused_kwargs)

Constructs an initial state for this core.

Args:
batch_size: An int or an integral scalar tensor representing batch size.
**unused_kwargs: Optional keyword arguments (ignored).
Returns:
Arbitrarily nested initial state for this core.

write_allocation_weights(usage, write_gates, num_writes)

Calculates freeness-based locations for writing to.

This finds unused memory by ranking the memory locations by usage, for each write head. (For more than one write head, we use a “simulated new usage” which takes into account the fact that the previous write head will increase the usage in that area of the memory.)

Args:
usage: A tensor of shape [batch_size, memory_size] representing current memory usage.
write_gates: A tensor of shape [batch_size, num_writes] with values in the range [0, 1] indicating how much each write head writes to the address returned here (and hence how much usage increases).
num_writes: The number of write heads to calculate write weights for.

Returns:
A tensor of shape [batch_size, num_writes, memory_size] containing the freeness-based write locations. Note that this isn’t scaled by write_gate; this scaling must be applied externally.
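The ranking and the "simulated new usage" can be sketched for one batch element in plain Python. The allocation formula follows the DNC paper (each slot gets its free fraction times the usage of all less-used slots); the helper names are hypothetical, and this simplified signature omits num_writes because it is implied by the length of write_gates.

```python
def allocation_weights(usage):
    """Allocation weighting for one write head from usage in [0, 1]."""
    order = sorted(range(len(usage)), key=lambda i: usage[i])  # least used first
    alloc = [0.0] * len(usage)
    used_so_far = 1.0  # product of usage over all less-used slots
    for i in order:
        alloc[i] = (1.0 - usage[i]) * used_so_far
        used_so_far *= usage[i]
    return alloc


def write_allocation_weights(usage, write_gates):
    """One allocation weighting per write head, unscaled by the write gates."""
    usage = list(usage)
    result = []
    for gate in write_gates:
        alloc = allocation_weights(usage)
        result.append(alloc)
        # Simulate the usage increase caused by this head before ranking again.
        usage = [u + (1.0 - u) * gate * a for u, a in zip(usage, alloc)]
    return result


# Two free slots and two write heads: each head is sent to a different slot.
print(write_allocation_weights([0.0, 0.0, 1.0], write_gates=[1.0, 1.0]))
# [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
```

Without the simulated usage update, both heads in this example would be directed to slot 0.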
class ftw.tf.networks.dnc.addressing.TemporalLinkage(memory_size, num_writes, name='temporal_linkage')

Bases: sonnet.src.recurrent.RNNCore

Keeps track of write order for forward and backward addressing.

This is a pseudo-RNNCore module, whose state is a pair (link, precedence_weights), where link is a (collection of) graphs for (possibly multiple) write heads (represented by a tensor with values in the range [0, 1]), and precedence_weights records the “previous write locations” used to build the link graphs.

The function directional_read_weights computes addresses following the forward and backward directions in the link graphs.
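How the (link, precedence_weights) state evolves can be sketched for a single write head and a single batch element. The update rule follows the DNC paper's temporal link equations; update_linkage is a hypothetical helper, not this module's API.

```python
def update_linkage(link, precedence, write_weights):
    """Return the new (link, precedence) after one write.

    link[i][j] is the degree to which slot i was written right after slot j.
    """
    n = len(write_weights)
    new_link = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # no self-links
            decay = 1.0 - write_weights[i] - write_weights[j]
            new_link[i][j] = decay * link[i][j] + write_weights[i] * precedence[j]
    # Precedence decays by the total amount written, then records this write.
    remaining = 1.0 - sum(write_weights)
    new_precedence = [remaining * p + w for p, w in zip(precedence, write_weights)]
    return new_link, new_precedence


n = 3
link = [[0.0] * n for _ in range(n)]
precedence = [0.0] * n

link, precedence = update_linkage(link, precedence, [1.0, 0.0, 0.0])  # write slot 0
link, precedence = update_linkage(link, precedence, [0.0, 1.0, 0.0])  # then slot 1

print(link[1][0])   # 1.0: slot 1 was written right after slot 0
print(precedence)   # [0.0, 1.0, 0.0]: slot 1 is the most recent write
```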

directional_read_weights(link, prev_read_weights, forward)

Calculates the forward or the backward read weights.

For each read head (at a given address), there are num_writes link graphs to follow. Thus this function computes a read address for each of the num_reads * num_writes pairs of read and write heads.

Args:
link: tensor of shape [batch_size, num_writes, memory_size, memory_size] representing the link graphs L_t.
prev_read_weights: tensor of shape [batch_size, num_reads, memory_size] containing the previous read weights w_{t-1}^r.
forward: Boolean indicating whether to follow the “future” direction in the link graph (True) or the “past” direction (False).
Returns:
tensor of shape [batch_size, num_reads, num_writes, memory_size]
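A minimal sketch of the computation for one batch element, using nested Python lists in place of tensors. It assumes the convention that link[w][i][j] is the degree to which slot i was written right after slot j by write head w, as in the DNC paper; the batch dimension is dropped for brevity.

```python
def directional_read_weights(link, prev_read_weights, forward):
    """link: [num_writes][N][N]; prev_read_weights: [num_reads][N].

    Returns nested lists of shape [num_reads][num_writes][N].
    """
    result = []
    for read_w in prev_read_weights:
        n = len(read_w)
        per_write = []
        for graph in link:
            if forward:
                # Move to slots written right after the previously read slots.
                weights = [sum(graph[i][j] * read_w[j] for j in range(n))
                           for i in range(n)]
            else:
                # Move to slots written right before the previously read slots.
                weights = [sum(graph[j][i] * read_w[j] for j in range(n))
                           for i in range(n)]
            per_write.append(weights)
        result.append(per_write)
    return result


# One write head that wrote slot 0, then slot 1, then slot 2.
link = [[[0.0, 0.0, 0.0],
         [1.0, 0.0, 0.0],
         [0.0, 1.0, 0.0]]]
prev_read = [[0.0, 1.0, 0.0]]  # one read head, previously at slot 1

print(directional_read_weights(link, prev_read, forward=True))   # [[[0.0, 0.0, 1.0]]]
print(directional_read_weights(link, prev_read, forward=False))  # [[[1.0, 0.0, 0.0]]]
```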
initial_state(batch_size: int, **unused_kwargs)

Constructs an initial state for this core.

Args:
batch_size: An int or an integral scalar tensor representing batch size.
**unused_kwargs: Optional keyword arguments (ignored).
Returns:
Arbitrarily nested initial state for this core.

class ftw.tf.networks.dnc.addressing.TemporalLinkageState(link, precedence_weights)

Bases: tuple

link

Alias for field number 0

precedence_weights

Alias for field number 1

ftw.tf.networks.dnc.addressing.weighted_softmax(activations, strengths, strengths_op)

Returns softmax over activations multiplied by positive strengths.

Args:
activations: A tensor of shape [batch_size, num_heads, memory_size], of activations to be transformed. Softmax is taken over the last dimension.
strengths: A tensor of shape [batch_size, num_heads] containing strengths to multiply by the activations prior to the softmax.
strengths_op: An operation to transform strengths before softmax.

Returns:
A tensor of same shape as activations with weighted softmax applied.
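A plain-Python sketch of the same computation on nested lists, with shapes mirroring the Args above. Defaulting strengths_op to softplus is an assumption taken from the CosineWeights signature; the real function requires the argument explicitly and operates on tensors.

```python
import math


def softplus(x):
    """log(1 + exp(x)), which maps any real strength to a positive value."""
    return math.log1p(math.exp(x))


def softmax(xs):
    """Numerically stable softmax over a list."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_softmax(activations, strengths, strengths_op=softplus):
    """activations: [batch][heads][memory]; strengths: [batch][heads]."""
    return [
        [softmax([strengths_op(s) * a for a in head_acts])
         for head_acts, s in zip(batch_acts, batch_strengths)]
        for batch_acts, batch_strengths in zip(activations, strengths)
    ]


acts = [[[1.0, 0.0, 0.7]]]  # batch of 1, one head, three memory slots
soft = weighted_softmax(acts, [[0.0]])    # transformed strength is about 0.69
sharp = weighted_softmax(acts, [[10.0]])  # transformed strength is about 10

# A larger strength concentrates the distribution on the best-matching slot.
print([round(p, 3) for p in soft[0][0]])
print([round(p, 3) for p in sharp[0][0]])
```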