pyvene.models.modeling_utils.gather_neurons

gather_neurons(tensor_input, unit, unit_locations_as_list, device=None)

Gather the neurons to be intervened on from tensor_input, as selected by unit and unit_locations_as_list.

Parameters:

tensor_input – tensors of shape (batch_size, sequence_length, …) if unit is “pos” or “h”; tensors of shape (batch_size, num_heads, sequence_length, …) if unit is “h.pos”.

unit – the intervention units to gather. Units can be “h” (head number), “pos” (position in the sequence), or “dim” (a particular dimension in the embedding space). When intervening on multiple units, they are ordered and separated by a period, as in “h.pos”. Currently only “pos”, “h”, and “h.pos” are supported.

unit_locations_as_list – tuple of lists of lists of positions to gather in tensor_input, according to the unit.

Returns:

the gathered tensor, tensor_output.
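The unit semantics described above can be illustrated with a small sketch. This is not pyvene's implementation, which works on PyTorch tensors: here NumPy's np.take_along_axis stands in for torch.gather, and the helper names gather_along_axis1 and gather_neurons_sketch are hypothetical. The device argument is dropped because NumPy arrays have no device. The sketch also assumes that every batch example selects the same number of locations, and that unit_locations_as_list is a list of lists (one list per batch example) for “pos” or “h”, and a (head_locations, position_locations) pair for “h.pos”.

```python
import numpy as np


def gather_along_axis1(x, idx):
    """Select entries of x along axis 1, independently for each batch row.

    x:   array of shape (batch, n, ...)
    idx: integer array-like of shape (batch, k) with values in [0, n)
    Returns an array of shape (batch, k, ...).
    """
    idx = np.asarray(idx)
    # Reshape idx to (batch, k, 1, 1, ...) and broadcast it over the trailing
    # dims, mirroring torch.gather with an expanded index tensor.
    idx = idx.reshape(idx.shape + (1,) * (x.ndim - 2))
    idx = np.broadcast_to(idx, idx.shape[:2] + x.shape[2:])
    return np.take_along_axis(x, idx, axis=1)


def gather_neurons_sketch(tensor_input, unit, unit_locations_as_list):
    """Hypothetical NumPy analogue of gather_neurons (device omitted)."""
    if unit in ("pos", "h"):
        # Single unit: gather directly along axis 1.
        return gather_along_axis1(tensor_input, unit_locations_as_list)
    if unit == "h.pos":
        # Assumed layout: (batch, num_heads, seq_len, head_dim).
        head_locs, pos_locs = unit_locations_as_list
        # Step 1: pick heads -> (batch, k_heads, seq_len, head_dim).
        heads = gather_along_axis1(tensor_input, head_locs)
        # Step 2: move seq_len to axis 1, pick positions, then move it back.
        swapped = np.swapaxes(heads, 1, 2)              # (batch, seq_len, k_heads, head_dim)
        picked = gather_along_axis1(swapped, pos_locs)  # (batch, k_pos, k_heads, head_dim)
        return np.swapaxes(picked, 1, 2)                # (batch, k_heads, k_pos, head_dim)
    raise ValueError(f"unsupported unit: {unit!r}")


# "pos": (batch_size=2, sequence_length=5, hidden=3); pick 2 positions per example.
x = np.arange(2 * 5 * 3).reshape(2, 5, 3)
out_pos = gather_neurons_sketch(x, "pos", [[0, 4], [1, 2]])
print(out_pos.shape)  # (2, 2, 3)

# "h.pos": (batch_size=2, num_heads=4, sequence_length=5, head_dim=3);
# pick 2 heads and 1 position per example.
x4 = np.arange(2 * 4 * 5 * 3).reshape(2, 4, 5, 3)
out_hpos = gather_neurons_sketch(x4, "h.pos", ([[1, 3], [0, 2]], [[0], [4]]))
print(out_hpos.shape)  # (2, 2, 1, 3)
```

For “h.pos” the sketch gathers heads first and then positions within the selected heads, so the result has shape (batch_size, num_selected_heads, num_selected_positions, head_dim).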