Metadata-Version: 2.1
Name: stitchnet
Version: 0.2.2
Summary: 
Author: Surat Teerapittayanon
Author-email: steerapi@gmail.com
Requires-Python: >=3.9,<3.12
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.9
Requires-Dist: datasets (>=2.13.1,<3.0.0)
Requires-Dist: evaluate (>=0.4.0,<0.5.0)
Requires-Dist: graphviz (>=0.20.1,<0.21.0)
Requires-Dist: netron (>=7.0.4,<8.0.0)
Requires-Dist: numpy (>=1.23.1,<2.0.0)
Requires-Dist: onnx (==1.13.1)
Requires-Dist: onnx-tool (==0.7.3)
Requires-Dist: onnx2torch (>=1.4.1,<2.0.0)
Requires-Dist: onnxoptimizer (>=0.3.0,<0.4.0)
Requires-Dist: onnxruntime (>=1.12.1,<2.0.0)
Requires-Dist: pyppeteer (>=1.0.2,<2.0.0)
Requires-Dist: pyre-extensions (>=0.0.30,<0.0.31)
Requires-Dist: scipy (>=1.7.3,<2.0.0)
Requires-Dist: seaborn (>=0.11.2,<0.12.0)
Requires-Dist: skl2onnx (>=1.12.0,<1.13.0)
Requires-Dist: torch (>=1.12.1,<2.0.0)
Requires-Dist: torchvision (>=0.13.1,<0.14.0)
Requires-Dist: transformers (>=4.25.1,<5.0.0)
Description-Content-Type: text/markdown

StitchNet: Composing Neural Networks from Pre-Trained Fragments
=============

We propose StitchNet, a novel neural network creation paradigm that stitches together fragments (one or more consecutive network layers) from multiple pre-trained neural networks. StitchNet allows the creation of high-performing neural networks without the large compute and data requirements of traditional model creation via backpropagation training. We use Centered Kernel Alignment (CKA) as a compatibility measure to efficiently guide the selection of fragments when composing a network for a given task, tailored to specific accuracy needs and computing resource constraints. We then show that these fragments can be stitched together to create neural networks with accuracy comparable to traditionally trained networks at a fraction of the computing resource and data requirements. Finally, we explore a novel on-the-fly personalized model creation and inference application enabled by this new paradigm.
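
To make the compatibility measure concrete, here is a minimal sketch of linear CKA between two sets of activations computed on the same inputs. It illustrates the general technique only; whether StitchNet uses this exact variant (linear rather than kernel CKA) is an assumption, and the function name `linear_cka` is not part of the package.

```python
import numpy as np


def linear_cka(x, y):
    """Linear CKA between activations x (n, p) and y (n, q) from the same n inputs."""
    # Center each feature column so the measure ignores constant offsets.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    # ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return cross / (norm_x * norm_y)
```

The score lies between 0 and 1 and is unchanged by orthogonal transformations and isotropic scaling of either input, which is why it can compare layers taken from different networks.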

Installation
=============

    pip install stitchnet
    
Usage
=============
    
    import stitchnet
    
    # load the beans dataset from huggingface
    from stitchnet import load_hf_dataset
    dataset_train, dataset_val = load_hf_dataset('beans', train_split='validation', val_split='test', label_column='labels', seed=47)

    # prepare stitching dataset
    import numpy as np
    from tqdm import tqdm
    stitching_dataset = np.vstack([x['pixel_values'] for x in tqdm(dataset_train.select(range(32)))])

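    # generate a single stitchnet by sampling
    # (assumes `generate` is exported by the stitchnet package, like `load_hf_dataset`)
    from stitchnet import generate
    score, net = generate(stitching_dataset, threshold=0, totalThreshold=0, maxDepth=10, K=2, sample=True)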
    
    # generate multiple stitchnets
    generator = generate(stitching_dataset, threshold=0.8, totalThreshold=0.8, maxDepth=10, K=2, sample=False)
    for score,net in generator:
        print(score,net)
    
    # print macs and params
    net.get_macs_params() # {'macs': 4488343528.0, 'params': 25653096}
    
    # save onnx
    net.save_onnx('./_data/net') # saving to ./_results/net.onnx
        
    # draw the stitchnet
    net.draw_svg('./_data/net') # saving to ./_results/net.svg
    
    # train a classifier
    net.fit(dataset_train, label_column="labels")
    
    # use it for prediction
    net.predict_files(['./_results/test.jpg']) # [{'score': [0.8, 0.2, 0.0], 'label': 0}]
    
    # evaluate with validation dataset
    net.evaluate_dataset(dataset_val, label_column='labels') # {'accuracy': 0.7421875}
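
When several stitchnets are generated, one way to respect a compute budget is to keep the best-scoring candidate whose MACs fit under a limit. The sketch below relies only on what the usage above shows: the generator yields `(score, net)` pairs and `net.get_macs_params()` returns a dict with a `'macs'` entry. The helper name `pick_within_budget` is hypothetical, and it assumes a higher score is better.

```python
def pick_within_budget(candidates, max_macs):
    """Return the highest-scoring (score, net) pair whose MACs fit the budget, or None."""
    best = None
    for score, net in candidates:
        macs = net.get_macs_params()["macs"]
        # Keep the candidate only if it fits and beats the current best score.
        if macs <= max_macs and (best is None or score > best[0]):
            best = (score, net)
    return best
```

For example, `pick_within_budget(generator, max_macs=5e9)` would consume the generator from the usage above and return `None` if no candidate fits.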

CUDA
=============
See https://pytorch.org/get-started/previous-versions/ to install the PyTorch build that matches your CUDA version. For example:

    # CUDA 11.6
    pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
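
After installing, a quick check like the following can confirm whether CUDA is visible. It is a minimal sketch using `torch.cuda.is_available()` and `onnxruntime.get_available_providers()`; the helper name `cuda_status` is hypothetical, and a value of `None` means the library could not be imported.

```python
import importlib


def cuda_status():
    """Report whether PyTorch and ONNX Runtime can see a CUDA device."""
    status = {"torch_cuda": None, "onnxruntime_cuda": None}
    try:
        torch = importlib.import_module("torch")
        status["torch_cuda"] = bool(torch.cuda.is_available())
    except ImportError:
        pass  # torch is not installed; leave the entry as None
    try:
        ort = importlib.import_module("onnxruntime")
        providers = ort.get_available_providers()
        status["onnxruntime_cuda"] = "CUDAExecutionProvider" in providers
    except ImportError:
        pass  # onnxruntime is not installed; leave the entry as None
    return status


print(cuda_status())
```

Note that the plain `onnxruntime` wheel is CPU-only; the CUDA execution provider ships in the separate `onnxruntime-gpu` package.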


Experiment Notebooks
=============

1. Download the Dogs vs. Cats dataset from https://www.kaggle.com/c/dogs-vs-cats/data and put the training data in the _data/dogs_cats/raw/train folder
2. See 00_prepare_data.ipynb to split the images into cats and dogs folders
3. See 01_download_networks.ipynb to download the pretrained networks from Torchvision
4. See 02_generate_fragments.ipynb to generate fragments from the pretrained networks
5. See 03_stitchnet.ipynb to generate stitchnets
6. See 04_render_graph.ipynb to create SVG images of the network graphs using netron
7. See 05_eval_original_networks.ipynb to evaluate the original pretrained networks
8. See 06_finetuning.ipynb to generate the finetuning result
9. See 07_ensemble.ipynb to generate the ensemble result
10. See 08_number_of_samples_for_stitching.ipynb to experiment with varying the number of samples used for stitching
11. See 09_plot_results.ipynb to plot the figures of the results for the paper
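
The split in steps 1 and 2 can be sketched as follows. This is not the code from 00_prepare_data.ipynb, which is not reproduced here. It assumes the Kaggle training files are named `cat.<n>.jpg` and `dog.<n>.jpg`; the function name `split_by_label` and the output layout (`<out_dir>/cat`, `<out_dir>/dog`) are hypothetical.

```python
import shutil
from pathlib import Path


def split_by_label(raw_dir, out_dir, labels=("cat", "dog")):
    """Copy files named '<label>.<n>.jpg' into out_dir/<label>/ and return per-label counts."""
    raw_dir, out_dir = Path(raw_dir), Path(out_dir)
    counts = {label: 0 for label in labels}
    for path in sorted(raw_dir.glob("*.jpg")):
        # The label is the part of the file name before the first dot.
        label = path.name.split(".", 1)[0]
        if label not in counts:
            continue  # skip files that do not follow the assumed naming scheme
        target = out_dir / label
        target.mkdir(parents=True, exist_ok=True)
        shutil.copy2(path, target / path.name)
        counts[label] += 1
    return counts
```

Called as `split_by_label("_data/dogs_cats/raw/train", "<output folder>")`, it copies rather than moves the files, so the raw download stays intact.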


Installation using conda
=============

Create a new conda env

    conda create -n stitchnet python=3.10
    
Activate stitchnet conda env

    conda activate stitchnet

If you use conda with an NVIDIA GPU, also install cuDNN, which the CUDA runtime for ONNX requires

    conda install -c conda-forge cudnn
    
Install poetry

    curl -sSL https://install.python-poetry.org | python3 -

Install dependencies using poetry 

    poetry install


