Commit 3b7138b
Merge pull request #16 from JeanKossaifi/main
Adds efficient factorized linear implementations
JeanKossaifi authored Feb 20, 2022
2 parents 9b58f0f + 9e9902d commit 3b7138b
Showing 5 changed files with 179 additions and 41 deletions.
83 changes: 58 additions & 25 deletions tltorch/functional/factorized_linear.py
@@ -1,28 +1,61 @@
import torch
import tensorly as tl
from tensorly.tt_tensor import TTTensor

from .factorized_tensordot import tensor_dot_tucker, tensor_dot_cp

tl.set_backend('pytorch')

# Author: Taylor Lee Patti <taylorpatti@g.harvard.edu>

def tt_factorized_linear(tt_vec, ttm_weights):
    """Contracts a TT tensor with a TT matrix and returns a TT tensor.

    Parameters
    ----------
    tt_vec : tensor train tensor
    ttm_weights : tensor train matrix

    Returns
    -------
    The tensor train tensor obtained by contracting the TT tensor with the TT matrix.
    """
    ncores = len(tt_vec)
    contr_layer = []
    for i in range(ncores):
        dimW, dimX = ttm_weights[i].shape, tt_vec[i].shape
        contr = tl.einsum('abc,debf->adecf', tt_vec[i], ttm_weights[i])
        contr_layer.append(tl.reshape(contr, (dimW[0]*dimX[0], dimW[1], dimW[3]*dimX[2])))
    return TTTensor(contr_layer)
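
A minimal usage sketch (an illustration, not part of the diff; shapes are made up). It assumes the core layouts implied by the einsum above: vector cores as (rank_in, mode, rank_out) and matrix cores as (rank_in, out_mode, in_mode, rank_out):

import torch
import tensorly as tl
tl.set_backend('pytorch')

# Toy TT vector with modes (2, 3) and a TT matrix mapping (2, 3) -> (4, 5)
tt_vec = [torch.randn(1, 2, 2), torch.randn(2, 3, 1)]            # TT cores of the input
ttm_weights = [torch.randn(1, 4, 2, 2), torch.randn(2, 5, 3, 1)]  # TT-matrix cores
out = tt_factorized_linear(tt_vec, ttm_weights)                   # TTTensor with modes (4, 5)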
# Author: Jean Kossaifi

def linear_tucker(tensor, tucker_matrix, transpose=True):
    if transpose:
        contraction_axis = 1
    else:
        contraction_axis = 0
    n_rows = len(tucker_matrix.tensorized_shape[contraction_axis])
    tensor = tensor.reshape(-1, *tucker_matrix.tensorized_shape[contraction_axis])

    modes_tensor = list(range(tensor.ndim - n_rows, tensor.ndim))
    if transpose:
        modes_tucker = list(range(n_rows, tucker_matrix.order))
    else:
        modes_tucker = list(range(n_rows))

    return tensor_dot_tucker(tensor, tucker_matrix, (modes_tensor, modes_tucker))
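
A usage sketch (illustration, not part of the diff; shapes mirror the tests at the bottom of this commit, and the same pattern applies to linear_cp and linear_blocktt below):

import torch
from tltorch.factorized_tensors import TensorizedTensor

# Weight tensorized as (out_shape, in_shape) = ((6, 2), (4, 5)); in_features = 20
weight = TensorizedTensor.new(((6, 2), (4, 5)), rank=3, factorization='tucker')
weight.normal_()
x = torch.randn(2, 20)
out = linear_tucker(x, weight)        # equivalent to x @ weight.to_matrix().T
out = out.reshape(x.shape[0], -1)     # flatten the tensorized output to (2, 12)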

def linear_cp(tensor, cp_matrix, transpose=True):
    if transpose:
        out_features, in_features = len(cp_matrix.tensorized_shape[0]), len(cp_matrix.tensorized_shape[1])
        in_shape = cp_matrix.tensorized_shape[1]
        modes_cp = list(range(out_features, cp_matrix.order))
    else:
        in_features, out_features = len(cp_matrix.tensorized_shape[0]), len(cp_matrix.tensorized_shape[1])
        in_shape = cp_matrix.tensorized_shape[0]
        modes_cp = list(range(in_features))
    tensor = tensor.reshape(-1, *in_shape)

    modes_tensor = list(range(1, tensor.ndim))

    return tensor_dot_cp(tensor, cp_matrix, (modes_tensor, modes_cp))


def linear_blocktt(tensor, tt_matrix, transpose=True):
    if transpose:
        contraction_axis = 1
    else:
        contraction_axis = 0
    ndim = len(tt_matrix.tensorized_shape[contraction_axis])
    tensor = tensor.reshape(-1, *tt_matrix.tensorized_shape[contraction_axis])

    # Build an einsum equation from character codes: 'a' is the batch symbol,
    # the next ndim symbols index the input modes, the following ndim the
    # output modes, and the last ndim+1 the TT ranks chaining the cores.
    bs = 'a'
    start = ord(bs) + 1
    in_idx = bs + ''.join(chr(start + i) for i in range(ndim))
    factors_idx = []
    for i in range(ndim):
        if transpose:
            idx = [start+ndim*2+i, start+ndim+i, start+i, start+ndim*2+i+1]
        else:
            idx = [start+ndim*2+i, start+i, start+ndim+i, start+ndim*2+i+1]
        factors_idx.append(''.join(chr(j) for j in idx))
    out_idx = bs + ''.join(chr(start + ndim + i) for i in range(ndim))
    eq = in_idx + ',' + ','.join(factors_idx) + '->' + out_idx
    res = tl.einsum(eq, tensor, *tt_matrix.factors)
    return tl.reshape(res, (tl.shape(res)[0], -1))
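
To make the index bookkeeping concrete (an illustration, not part of the diff): for ndim = 2 and transpose=True the loop builds eq == 'abc,fdbg,gech->ade', where 'a' is the batch, 'b'/'c' are the contracted input modes, 'd'/'e' are the surviving output modes, and 'f'/'g'/'h' are the TT ranks chaining the cores. A quick numerical check under the same assumed shapes as the tests below:

import torch
from tltorch.factorized_tensors import TensorizedTensor

x = torch.randn(2, 4 * 5)
weight = TensorizedTensor.new(((6, 2), (4, 5)), rank=3, factorization='blocktt')
weight.normal_()
res = linear_blocktt(x, weight)  # already flattened to (2, 12) by the final tl.reshape
# res should match torch.matmul(x, weight.to_matrix().T) up to float precision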

79 changes: 79 additions & 0 deletions tltorch/functional/factorized_tensordot.py
@@ -0,0 +1,79 @@
# Author: Jean Kossaifi

import tensorly as tl
from tensorly.tenalg.tenalg_utils import _validate_contraction_modes
tl.set_backend('pytorch')

einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'


def tensor_dot_tucker(tensor, tucker, modes):
    modes_tensor, modes_tucker = _validate_contraction_modes(tl.shape(tensor), tucker.tensor_shape, modes)
    input_order = tensor.ndim
    weight_order = tucker.order

    sorted_modes_tucker = sorted(modes_tucker, reverse=True)
    sorted_modes_tensor = sorted(modes_tensor, reverse=True)

    # Symbols for the rank dimensions of the Tucker core
    rank_sym = [einsum_symbols[i] for i in range(weight_order)]

    # Symbols for the Tucker weight dimensions
    tucker_sym = [einsum_symbols[i+weight_order] for i in range(weight_order)]

    # Symbols for the input tensor
    tensor_sym = [einsum_symbols[i+2*weight_order] for i in range(tensor.ndim)]

    # Output: input + weight symbols, after removing the contracted symbols
    output_sym = tensor_sym + tucker_sym
    for m in sorted_modes_tucker:
        output_sym.pop(m+input_order)
    for m in sorted_modes_tensor:
        output_sym.pop(m)
    # Contracted tensor modes reuse the symbols of the matching weight modes
    for i, e in enumerate(modes_tensor):
        tensor_sym[e] = tucker_sym[modes_tucker[i]]

    # Form the actual equation: tensor, core, factors -> output
    eq = ''.join(tensor_sym)
    eq += ',' + ''.join(rank_sym)
    eq += ',' + ','.join(f'{s}{r}' for s, r in zip(tucker_sym, rank_sym))
    eq += '->' + ''.join(output_sym)

    return tl.einsum(eq, tensor, tucker.core, *tucker.factors)
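
For intuition (a worked example with assumed shapes, not from the diff): contracting modes (1, 2) of a (batch, 4, 5) tensor with modes (2, 3) of an order-4 Tucker weight of shape (6, 2, 4, 5) yields

eq == 'igh,abcd,ea,fb,gc,hd->ief'

where 'abcd' indexes the core, each pair like 'ea' matches a factor to its rank symbol, the contracted input modes reuse the weight symbols 'g'/'h', and 'ief' keeps the batch plus the two output modes. This is exactly the call linear_tucker makes with transpose=True.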


def tensor_dot_cp(tensor, cp, modes):
    """Contracts a tensor with a CP tensor in factorized form.

    Returns
    -------
    For matrix contractions, this is equivalent to tensor x cp.to_matrix().T
    """
    try:
        cp_shape = cp.tensor_shape
    except AttributeError:
        cp_shape = cp.shape
    modes_tensor, modes_cp = _validate_contraction_modes(tl.shape(tensor), cp_shape, modes)

    tensor_order = tl.ndim(tensor)
    # The CP rank is 'a', so tensor symbols start at 'b'
    start = ord('b')
    eq_in = ''.join(chr(start + index) for index in range(tensor_order))
    eq_factors = []
    eq_res = ''.join(eq_in[i] if i not in modes_tensor else '' for i in range(tensor_order))
    counter_joint = 0  # contraction modes, shared indices between tensor and CP
    counter_free = 0   # new uncontracted modes from the CP
    for i in range(len(cp.factors)):
        if i in modes_cp:
            eq_factors.append(f'{eq_in[modes_tensor[counter_joint]]}a')
            counter_joint += 1
        else:
            eq_factors.append(f'{chr(start+tensor_order+counter_free)}a')
            eq_res += f'{chr(start+tensor_order+counter_free)}'
            counter_free += 1

    eq = eq_in + ',a,' + ','.join(eq_factors) + '->' + eq_res
    res = tl.einsum(eq, tensor, cp.weights, *cp.factors)

    return res
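
A usage sketch (assumed shapes, not part of the diff), matching the call linear_cp makes with transpose=True:

import torch
from tltorch.factorized_tensors import TensorizedTensor

x = torch.randn(2, 4, 5)  # batch plus tensorized input modes
cp_weight = TensorizedTensor.new(((6, 2), (4, 5)), rank=3, factorization='cp')
cp_weight.normal_()
# Contracts tensor modes (1, 2) with CP modes (2, 3);
# generated equation: 'bcd,a,ea,fa,ca,da->bef', output shape (2, 6, 2)
out = tensor_dot_cp(x, cp_weight, ([1, 2], [2, 3]))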
10 changes: 9 additions & 1 deletion tltorch/functional/linear.py
@@ -1,7 +1,9 @@
import numpy as np
import torch
import torch.nn.functional as F
from ..factorized_tensors import TensorizedTensor
from .factorized_linear import linear_blocktt, linear_cp, linear_tucker

import tensorly as tl
tl.set_backend('pytorch')
@@ -20,9 +22,15 @@ def factorized_linear(x, weight, bias=None, in_features=None):
    # Weights are in the form (out_features, in_features)
    # PyTorch's linear returns dot(x, weight.T)!
    if isinstance(weight, TensorizedTensor):
        if weight._factorization == 'cp':
            out = linear_cp(x, weight)
        elif weight._factorization == 'tucker':
            out = linear_tucker(x, weight)
        elif weight._factorization == 'blocktt':
            out = linear_blocktt(x, weight)
        else:
            out = None
        if out is not None:
            return out if bias is None else out + bias
        # no efficient implementation available: fall back to reconstruction
        weight = weight.to_matrix()
    else:
        weight = weight.to_tensor()

    return F.linear(x, torch.reshape(weight, (-1, in_features)), bias=bias)
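
An end-to-end sketch (assumed shapes, not part of the diff; it also assumes factorized_linear is exported from tltorch.functional):

import torch
from tltorch.factorized_tensors import TensorizedTensor
from tltorch.functional import factorized_linear

x = torch.randn(2, 20)
weight = TensorizedTensor.new(((6, 2), (4, 5)), rank=3, factorization='tucker')
weight.normal_()
out = factorized_linear(x, weight, bias=None, in_features=20)
# the efficient tucker and cp paths return the output in tensorized form, here (2, 6, 2)
out = out.reshape(x.shape[0], -1)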
15 changes: 0 additions & 15 deletions tltorch/functional/tests/test_factored_linear.py

This file was deleted.

33 changes: 33 additions & 0 deletions tltorch/functional/tests/test_factorized_linear.py
@@ -0,0 +1,33 @@
from ...factorized_tensors import TensorizedTensor
from ..factorized_linear import linear_tucker, linear_blocktt, linear_cp
import torch

import tensorly as tl
tl.set_backend('pytorch')
from tensorly import testing
from tensorly.utils import prod

import pytest

# Author: Jean Kossaifi


@pytest.mark.parametrize('factorization, factorized_linear',
                         [('tucker', linear_tucker), ('blocktt', linear_blocktt), ('cp', linear_cp)])
def test_factorized_linear(factorization, factorized_linear):
    in_shape = (4, 5)
    in_dim = prod(in_shape)
    out_shape = (6, 2)
    rank = 3
    batch_size = 2

    tensor = tl.randn((batch_size, in_dim))
    fact_weight = TensorizedTensor.new((out_shape, in_shape), rank=rank,
                                       factorization=factorization)
    fact_weight.normal_()
    full_weight = fact_weight.to_matrix()
    true_res = torch.matmul(tensor, full_weight.T)
    res = factorized_linear(tensor, fact_weight, transpose=True)
    res = res.reshape(batch_size, -1)
    testing.assert_array_almost_equal(true_res, res, decimal=5)
