DNC Memory Modules
This is a port of the DeepMind DNC implementation (https://github.com/deepmind/dnc) to TensorFlow 2.x.
DNC Module
class ftw.tf.networks.dnc.dnc.DNC(access_config, controller_config, output_size, clip_value=None, name='dnc')
Bases: sonnet.src.recurrent.RNNCore
DNC core module.
Contains the controller and the memory access module.
__init__(access_config, controller_config, output_size, clip_value=None, name='dnc')
Initializes the DNC core.
Args:
- access_config: dictionary of access module configurations.
- controller_config: dictionary of controller (LSTM) module configurations.
- output_size: output dimension size of the core.
- clip_value: if specified, clips controller and core output values to the range [-clip_value, clip_value].
- name: module name (default 'dnc').
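The effect of clip_value can be illustrated with a minimal plain-Python sketch. The helper name clip_to_range is hypothetical and not part of this package. The module itself operates on tensors, presumably with something like tf.clip_by_value, which is an assumption rather than something this page states.

```python
def clip_to_range(values, clip_value=None):
    """Clip each value to [-clip_value, clip_value].

    Hypothetical helper for illustration only. If clip_value is None,
    the values are returned unchanged, mirroring the "if specified" wording.
    """
    if clip_value is None:
        return list(values)
    return [max(-clip_value, min(clip_value, v)) for v in values]


clipped = clip_to_range([-30.0, -1.5, 0.0, 2.0, 25.0], clip_value=20.0)
unclipped = clip_to_range([-30.0, 25.0])
```

Here only the two out-of-range entries change in clipped, while unclipped passes through untouched because no clip_value was given.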
initial_state(batch_size: int, **unused_kwargs)
Constructs an initial state for this core.
Args:
- batch_size: An int or an integral scalar tensor representing the batch size.
- **unused_kwargs: Optional keyword arguments (unused).
Returns:
- Arbitrarily nested initial state for this core.
output_size
DNC State NamedTuple
Memory Access Module
class ftw.tf.networks.dnc.access.MemoryAccess(memory_size=128, word_size=20, num_reads=1, num_writes=1, name='memory_access')
Bases: sonnet.src.recurrent.RNNCore
Access module of the Differentiable Neural Computer.
This memory module supports multiple read and write heads. It makes use of:
- addressing.TemporalLinkage to track the temporal ordering of writes in memory for each write head.
- addressing.Freeness to keep track of memory usage, where usage increases when a memory location is written to and decreases when the controller indicates that a location it has read from can be freed.
Write-address selection is done by an interpolation between content-based lookup and using unused memory.
Read-address selection is done by an interpolation of content-based lookup and following the link graph in the forward or backwards read direction.
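The two interpolations can be sketched for a single head in plain Python. The formulas follow the DNC paper; the function and parameter names (write_gate, allocation_gate, read_mode) are illustrative assumptions, and this page does not confirm that the port uses the same names or layout.

```python
def interpolate_write_weights(write_gate, allocation_gate,
                              allocation_weights, content_weights):
    """Blend allocation-based and content-based write addresses.

    allocation_gate in [0, 1] chooses between unused memory (1.0) and
    content lookup (0.0); write_gate in [0, 1] scales the overall write.
    """
    return [
        write_gate * (allocation_gate * a + (1.0 - allocation_gate) * c)
        for a, c in zip(allocation_weights, content_weights)
    ]


def interpolate_read_weights(read_mode, backward_weights,
                             content_weights, forward_weights):
    """Blend backward, content, and forward read addresses.

    read_mode is assumed to be a (backward, content, forward) triple that
    sums to 1, for example the output of a softmax.
    """
    b_mode, c_mode, f_mode = read_mode
    return [
        b_mode * b + c_mode * c + f_mode * f
        for b, c, f in zip(backward_weights, content_weights, forward_weights)
    ]


write_w = interpolate_write_weights(
    1.0, 0.25, allocation_weights=[1.0, 0.0, 0.0], content_weights=[0.0, 0.0, 1.0])
read_w = interpolate_read_weights(
    (0.5, 0.25, 0.25),
    backward_weights=[1.0, 0.0, 0.0],
    content_weights=[0.0, 1.0, 0.0],
    forward_weights=[0.0, 0.0, 1.0])
```

With several write heads there is one backward and one forward mode per write head, so the read mode has 2 * num_writes + 1 entries instead of three.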
__init__(memory_size=128, word_size=20, num_reads=1, num_writes=1, name='memory_access')
Creates a MemoryAccess module.
Args:
- memory_size: The number of memory slots (N in the DNC paper).
- word_size: The width of each memory slot (W in the DNC paper).
- num_reads: The number of read heads (R in the DNC paper).
- num_writes: The number of write heads (fixed at 1 in the paper).
- name: The name of the module.
initial_state(batch_size: int, **unused_kwargs)
Constructs an initial state for this core.
Args:
- batch_size: An int or an integral scalar tensor representing the batch size.
- **unused_kwargs: Optional keyword arguments (unused).
Returns:
- Arbitrarily nested initial state for this core.
output_size
Returns the output shape.
Memory Access State NamedTuple
class ftw.tf.networks.dnc.access.AccessState(memory, read_weights, write_weights, linkage, usage)
Bases: tuple
- linkage: Alias for field number 3
- memory: Alias for field number 0
- read_weights: Alias for field number 1
- usage: Alias for field number 4
- write_weights: Alias for field number 2
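The "alias for field number" entries are the standard behaviour of a Python namedtuple: each field is reachable by name or by position. The sketch below rebuilds an equivalent tuple with collections.namedtuple. The string values are placeholders standing in for tensors, and the definition is an illustration rather than the package's own source.

```python
import collections

# Same field order as the documented AccessState signature.
AccessState = collections.namedtuple(
    "AccessState",
    ("memory", "read_weights", "write_weights", "linkage", "usage"),
)

# Placeholder strings stand in for the real tensors.
state = AccessState(
    memory="memory", read_weights="read_weights",
    write_weights="write_weights", linkage="linkage", usage="usage",
)

# Positional and named access refer to the same element.
linkage_by_index = state[3]
linkage_by_name = state.linkage
```

Because the tuple is immutable, an updated state is built with state._replace(usage=new_usage) rather than by assignment.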
Addressing Module
DNC addressing modules.
class ftw.tf.networks.dnc.addressing.CosineWeights(num_heads, word_size, strength_op=<function softplus>, name='cosine_weights')
Bases: sonnet.src.base.Module
Cosine-weighted attention.
Calculates the cosine similarity between a query and each word in memory, then applies a weighted softmax to return a sharp distribution.
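The similarity step can be sketched in plain Python for one query and one memory matrix. The function name and the epsilon term that guards against division by zero are assumptions for illustration; the module itself works on batched tensors, and the softmax step is not shown here.

```python
import math


def cosine_similarity(key, memory, epsilon=1e-6):
    """Return the cosine similarity between `key` and every row of `memory`.

    key: list of word_size floats (the query).
    memory: list of memory_size rows, each a list of word_size floats.
    epsilon: small constant (assumed) that avoids dividing by zero.
    """
    key_norm = math.sqrt(sum(k * k for k in key))
    similarities = []
    for word in memory:
        dot = sum(k * w for k, w in zip(key, word))
        word_norm = math.sqrt(sum(w * w for w in word))
        similarities.append(dot / (key_norm * word_norm + epsilon))
    return similarities


# A word equal to the key, an orthogonal word, and an opposite word.
sims = cosine_similarity([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
```

The three similarities come out close to 1, 0, and -1 respectively, which is the ranking the subsequent softmax turns into read or write weights.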
class ftw.tf.networks.dnc.addressing.Freeness(memory_size, name='freeness')
Bases: sonnet.src.recurrent.RNNCore
Memory usage that is increased by writing and decreased by reading.
This module is a pseudo-RNNCore whose state is a tensor with values in the range [0, 1] indicating the usage of each of memory_size memory slots.
The usage is:
- Increased by writing, where usage is increased towards 1 at the write addresses.
- Decreased by reading, where usage is decreased after reading from a location when free_gate is close to 1.
The function write_allocation_weights can be invoked to get free locations to write to for a number of write heads.
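The two usage updates can be sketched in plain Python for a single batch element. The formulas follow the DNC paper and the original DeepMind code; the function names are hypothetical, and this page does not confirm the port's exact internals.

```python
def usage_after_write(usage, write_weights):
    """Increase usage towards 1 at the write addresses.

    usage: list of memory_size floats in [0, 1].
    write_weights: list of per-head weightings, each of length memory_size.
    """
    new_usage = []
    for i, u in enumerate(usage):
        # Probability that no head wrote to slot i.
        not_written = 1.0
        for head in write_weights:
            not_written *= 1.0 - head[i]
        combined = 1.0 - not_written
        new_usage.append(u + (1.0 - u) * combined)
    return new_usage


def usage_after_read(usage, free_gates, read_weights):
    """Decrease usage at slots that were read with a free gate close to 1.

    free_gates: one float in [0, 1] per read head.
    read_weights: list of per-head weightings, each of length memory_size.
    """
    new_usage = []
    for i, u in enumerate(usage):
        retention = 1.0
        for gate, head in zip(free_gates, read_weights):
            retention *= 1.0 - gate * head[i]
        new_usage.append(u * retention)
    return new_usage


written = usage_after_write([0.0, 0.5], [[1.0, 0.0]])
freed = usage_after_read(written, [1.0], [[1.0, 0.0]])
kept = usage_after_read(written, [0.0], [[1.0, 0.0]])
```

Writing fully to slot 0 raises its usage to 1. Reading it with a free gate of 1 releases it again, while a free gate of 0 leaves the usage unchanged.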
initial_state(batch_size: int, **unused_kwargs)
Constructs an initial state for this core.
Args:
- batch_size: An int or an integral scalar tensor representing the batch size.
- **unused_kwargs: Optional keyword arguments (unused).
Returns:
- Arbitrarily nested initial state for this core.
write_allocation_weights(usage, write_gates, num_writes)
Calculates freeness-based locations for writing to.
This finds unused memory by ranking the memory locations by usage, for each write head. (For more than one write head, we use a “simulated new usage” which takes into account the fact that the previous write head will increase the usage in that area of the memory.)
Args:
- usage: A tensor of shape [batch_size, memory_size] representing current memory usage.
- write_gates: A tensor of shape [batch_size, num_writes] with values in the range [0, 1] indicating how much each write head does writing based on the address returned here (and hence how much usage increases).
- num_writes: The number of write heads to calculate write weights for.
Returns:
- A tensor of shape [batch_size, num_writes, memory_size] containing the freeness-based write locations. Note that this isn’t scaled by write_gate; this scaling must be applied externally.
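The ranking and the "simulated new usage" can be sketched in plain Python for one batch element. The allocation rule follows the DNC paper (sort by usage, then weight each slot by its freeness times the product of the usages ranked before it). The function names are hypothetical, and the small epsilon the original DeepMind code adds for numerical stability is omitted here.

```python
def allocation_weights(usage):
    """Allocation weighting for one write head.

    The least-used slot receives the most weight; each later slot is
    discounted by the usage of all slots ranked before it.
    """
    order = sorted(range(len(usage)), key=lambda i: usage[i])  # least used first
    weights = [0.0] * len(usage)
    running_product = 1.0
    for i in order:
        weights[i] = (1.0 - usage[i]) * running_product
        running_product *= usage[i]
    return weights


def allocation_weights_per_head(usage, write_gates):
    """One allocation weighting per write head, unscaled by the write gate.

    After each head, usage is updated as if that head had written, so the
    next head is steered towards a different free slot.
    """
    usage = list(usage)
    per_head = []
    for gate in write_gates:
        weights = allocation_weights(usage)
        per_head.append(weights)
        usage = [u + (1.0 - u) * gate * w for u, w in zip(usage, weights)]
    return per_head


single = allocation_weights([0.5, 0.5])
two_heads = allocation_weights_per_head([0.0, 0.0, 1.0], write_gates=[1.0, 1.0])
```

In the two-head call the first head claims slot 0 and the second head, seeing slot 0 as used, claims slot 1. The fully used slot 2 receives no weight from either head.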
class ftw.tf.networks.dnc.addressing.TemporalLinkage(memory_size, num_writes, name='temporal_linkage')
Bases: sonnet.src.recurrent.RNNCore
Keeps track of write order for forward and backward addressing.
This is a pseudo-RNNCore module, whose state is a pair (link, precedence_weights), where link is a (collection of) graphs for (possibly multiple) write heads (represented by a tensor with values in the range [0, 1]), and precedence_weights records the “previous write locations” used to build the link graphs.
The function directional_read_weights computes addresses following the forward and backward directions in the link graphs.
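How the state pair (link, precedence_weights) evolves can be sketched in plain Python for one write head and one batch element. The update rule is the one given in the DNC paper; the function name is hypothetical and the real module operates on batched tensors with one link graph per write head.

```python
def update_linkage(link, precedence, write_weights):
    """Return the new (link, precedence) pair after one write.

    link[i][j] close to 1 means slot i was written right after slot j.
    precedence[j] is the degree to which slot j was the last one written.
    """
    n = len(write_weights)
    new_link = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # a slot never links to itself
            new_link[i][j] = (
                (1.0 - write_weights[i] - write_weights[j]) * link[i][j]
                + write_weights[i] * precedence[j]
            )
    write_sum = sum(write_weights)
    new_precedence = [
        (1.0 - write_sum) * p + w for p, w in zip(precedence, write_weights)
    ]
    return new_link, new_precedence


n = 3
link = [[0.0] * n for _ in range(n)]
precedence = [0.0] * n

# Write to slot 0, then to slot 1.
link, precedence = update_linkage(link, precedence, [1.0, 0.0, 0.0])
link, precedence = update_linkage(link, precedence, [0.0, 1.0, 0.0])
```

After the two writes, link[1][0] is 1.0, recording that slot 1 was written immediately after slot 0, and precedence has moved to slot 1.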
directional_read_weights(link, prev_read_weights, forward)
Calculates the forward or the backward read weights.
For each read head (at a given address), there are num_writes link graphs to follow. Thus this function computes a read address for each of the num_reads * num_writes pairs of read and write heads.
Args:
- link: tensor of shape `[batch_size, num_writes, memory_size, memory_size]` representing the link graphs L_t.
- prev_read_weights: tensor of shape `[batch_size, num_reads, memory_size]` containing the previous read weights w_{t-1}^r.
- forward: Boolean indicating whether to follow the “future” direction in the link graph (True) or the “past” direction (False).
Returns:
- tensor of shape [batch_size, num_reads, num_writes, memory_size]
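Following a link graph amounts to a matrix-vector product: the forward direction multiplies by the link matrix and the backward direction by its transpose. The plain-Python sketch below handles one read head and one link graph. The function name follow_link is hypothetical, and the convention that link[i][j] means "slot i was written after slot j" is taken from the DNC paper.

```python
def follow_link(link, prev_read_weights, forward):
    """Shift read weights one step along the write order.

    forward=True moves towards slots written later ("future"),
    forward=False towards slots written earlier ("past").
    """
    n = len(prev_read_weights)
    if forward:
        # f[i] = sum_j link[i][j] * w[j]
        return [
            sum(link[i][j] * prev_read_weights[j] for j in range(n))
            for i in range(n)
        ]
    # b[i] = sum_j link[j][i] * w[j]
    return [
        sum(link[j][i] * prev_read_weights[j] for j in range(n))
        for i in range(n)
    ]


# Slot 1 was written right after slot 0; slot 2 was never written.
example_link = [
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
]

forward_weights = follow_link(example_link, [1.0, 0.0, 0.0], forward=True)
backward_weights = follow_link(example_link, [0.0, 1.0, 0.0], forward=False)
```

Starting from slot 0, the forward step lands on slot 1; starting from slot 1, the backward step returns to slot 0. The documented method does this for every pair of read head and write head, which is why its result has a num_writes axis.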
class ftw.tf.networks.dnc.addressing.TemporalLinkageState(link, precedence_weights)
Bases: tuple
- link: Alias for field number 0
- precedence_weights: Alias for field number 1
ftw.tf.networks.dnc.addressing.weighted_softmax(activations, strengths, strengths_op)
Returns softmax over activations multiplied by positive strengths.
Args:
- activations: A tensor of shape [batch_size, num_heads, memory_size], of activations to be transformed. Softmax is taken over the last dimension.
- strengths: A tensor of shape [batch_size, num_heads] containing strengths to multiply by the activations prior to the softmax.
- strengths_op: An operation to transform strengths before softmax.
Returns:
- A tensor of same shape as activations with weighted softmax applied.
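For a single head, the computation can be sketched in plain Python as below. The softplus default mirrors the strength_op default shown for CosineWeights. The function body is an illustration under that assumption, not the package's source, and it handles one activation vector instead of a batched tensor.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x)); always positive."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def weighted_softmax(activations, strength, strength_op=softplus):
    """Softmax over `activations` after scaling by a transformed strength.

    activations: list of memory_size floats for one head.
    strength: raw (unconstrained) scalar for that head.
    strength_op: maps the raw strength to the multiplier that is applied.
    """
    scale = strength_op(strength)
    scaled = [a * scale for a in activations]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


activations = [1.0, 0.0, -1.0]
soft = weighted_softmax(activations, 0.0)
sharp = weighted_softmax(activations, 5.0)
uniform = weighted_softmax(activations, 0.0, strength_op=lambda s: s)
```

A larger strength concentrates more probability on the best-matching slot, so sharp[0] exceeds soft[0]. With an identity strength_op and a strength of 0 the activations are wiped out and the distribution is uniform.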