69 lines
2.7 KiB
Python
69 lines
2.7 KiB
Python
|
from loguru import logger
|
||
|
|
||
|
import sys
|
||
|
|
||
|
sys.path.append("pyrat39/")
|
||
|
from analysis_param import AnalysisParam
|
||
|
from sound_pyrat.analyzer_split import analyze_splits
|
||
|
from pyrat import load_img
|
||
|
from sound_pyrat.domains.domain_factory import create_box_maker
|
||
|
import torch
|
||
|
|
||
|
from sound_pyrat.utils_analyzer import analyze, read_inputs, read_model, analyze_adv
|
||
|
import warnings
|
||
|
warnings.filterwarnings("ignore", message="invalid value encountered in *")
|
||
|
warnings.filterwarnings("ignore", message="divide by zero encountered in reciprocal")
|
||
|
|
||
|
def launch_pyrat(model_path, property_path, domains=None, split_timeout=0):
    """Run a PyRAT analysis of a model against a property file.

    Args:
        model_path: path to the model
        property_path: path to the property file
        domains: a python list of domains to propagate in the network from
            "zonotopes" or "poly"; ``None`` (the default) is treated as an
            empty list
        split_timeout: ``0`` runs a single ``analyze`` pass and prints its
            output bounds, result and time; any other value is stored in
            ``params.timeout`` and the work is delegated to
            ``analyze_splits``

    Returns:
        The ``(result, time)`` pair read from the analysis result object.
    """
    # Build a fresh list on every call instead of sharing one mutable default.
    domains = [] if domains is None else domains

    # remove() without a handler id detaches every loguru handler.
    logger.remove()

    params = AnalysisParam(domains=domains, squeeze=True, verbose=False, check="skip", by_layer=True)
    inputs = read_inputs(model_path=model_path, prop_path=property_path, params=params)

    pyrat_model, bounds, to_verify, to_counter = inputs
    # Rewrite a (1, 3) input shape to (3,) on the model's first input.
    if pyrat_model.root[0].input[0].input_shape == (1, 3):
        pyrat_model.root[0].input[0].input_shape = (3,)
    bounds = params.box_maker(bounds[0], bounds[1]).to_type(pyrat_model.dtype)
    # NOTE(review): the shape was already rewritten to (3,) above whenever it
    # was (1, 3), so this branch looks unreachable. Kept as-is to preserve
    # behaviour -- confirm whether the bounds were meant to be reshaped too.
    if pyrat_model.root[0].input[0].input_shape == (1, 3):
        bounds = bounds.reshape((3,))

    if split_timeout == 0:
        result = analyze(bounds, pyrat_model, to_verify, to_counter, params)
        print("Output bounds:\n", result.single_res.output_bounds)
        print(f"Result = {result.result}, Time = {result.time:.2f} s")
    else:
        params.timeout = split_timeout
        result = analyze_splits(bounds, pyrat_model, to_verify, to_counter, params)
        result.print_res()

    return result.result, result.time
|
||
|
|
||
|
def local_robustness(model_path, image, label, pert, domains=None, total_labels=10):
    """Run PyRAT's ``analyze_adv`` on a single image.

    Args:
        model_path: path to the model; passed to ``read_model`` and
            ``analyze_adv``
        image: the input image handed to ``analyze_adv``
        label: forwarded to ``AnalysisParam`` as ``true_label``
        pert: forwarded to ``AnalysisParam`` as ``epsilon``
        domains: a python list of domains forwarded to ``AnalysisParam``;
            ``None`` (the default) is treated as an empty list
        total_labels: forwarded to ``AnalysisParam`` as ``total_labels``;
            defaults to 10, the value that used to be hard-coded here

    Returns:
        The ``(result, time)`` pair read from the analysis result object.
    """
    # Build a fresh list on every call instead of sharing one mutable default.
    domains = [] if domains is None else domains

    box_maker = create_box_maker("torch", sound=False, dtype=torch.float32)
    params = AnalysisParam(
        total_labels=total_labels,
        epsilon=pert,
        box_maker=box_maker,
        domains=domains,
        by_layer=True,
        force_analysis=True,
        true_label=label,
    )

    pyrat_model = read_model(model_path, params=params)
    res = analyze_adv(image, model_path, pyrat_model, params)

    return res.result, res.time
|
||
|
|
||
|
def read_images(csv_path="fmnist/img_labels.csv", image_dir="fmnist/images", limit=1):
    """Load the images listed in a ``name,label`` CSV file.

    Args:
        csv_path: path of the CSV file; each row is ``<image name>,<label>``
        image_dir: directory prefix joined with each image name before the
            path is given to ``load_img``
        limit: maximum number of rows read from the start of the file
            (non-negative), or ``None`` to read every row; the default of 1
            matches the previous hard-coded ``lines[:1]``

    Returns:
        A list of ``(image, label)`` tuples where ``label`` is an ``int``.

    Raises:
        ValueError: if a row does not split into exactly two comma-separated
            fields, or its label is not an integer.
    """
    # remove() without a handler id detaches every loguru handler.
    logger.remove()

    images = []
    with open(csv_path) as f:
        # Stream the file so only the requested rows are ever read.
        for index, line in enumerate(f):
            if limit is not None and index >= limit:
                break
            name, label = line.split(",")
            image = load_img(f"{image_dir}/{name}", grayscale=False)
            # strip() already drops the trailing newline along with spaces.
            images.append((image, int(label.strip())))

    return images
|